# Post-Training S-BERT

Post-trains S-BERT on triplet examples drawn from the merged SafeGraph/Mergelog tables. The notebook was last run on GameStop brand data, but it can be adapted to any triplet dataset in the same format.

## Setup

In [1]:
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer, SentencesDataset, LoggingHandler, losses
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.readers import TripletReader
import logging
from datetime import datetime

In [2]:
#### Just some code to print debug information to stdout
logging.basicConfig(format='%(asctime)s - %(message)s',
                    datefmt='%Y-%m-%d %H:%M:%S',
                    level=logging.INFO,
                    handlers=[LoggingHandler()])
#### /print debug information to stdout

## Initialize Training Objects and Parameters

In [17]:
# Training configuration and triplet reader
model_name = 'bert-base-nli-mean-tokens'
train_batch_size = 16
num_epochs = 4
model_save_path = 'output/training_stsbenchmark_continue_training-'+model_name+'-'+datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
trp_reader = TripletReader('data/training')

# Load a pre-trained sentence transformer model
model = SentenceTransformer(model_name)


2020-08-21 13:47:18 - Load pretrained SentenceTransformer: bert-base-nli-mean-tokens
2020-08-21 13:47:18 - Did not find a '/' or '\' in the name. Assume to download model from server.
2020-08-21 13:47:18 - Load SentenceTransformer from folder: /home/chad/.cache/torch/sentence_transformers/public.ukp.informatik.tu-darmstadt.de_reimers_sentence-transformers_v0.2_bert-base-nli-mean-tokens.zip
2020-08-21 13:47:18 - loading configuration file /home/chad/.cache/torch/sentence_transformers/public.ukp.informatik.tu-darmstadt.de_reimers_sentence-transformers_v0.2_bert-base-nli-mean-tokens.zip/0_BERT/config.json
2020-08-21 13:47:18 - Model config BertConfig {
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers

### Load Training set

Generated from the following datasets:
- `gamestop_strict`: Filters to users who have visited GameStop, and retains only online activity related to games
- `all_brands_ndomain10`: Merged SafeGraph and Mergelog data for users who visited at least 10 domains (first half of August)
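
The triplet file itself is not shown in this notebook. As a rough sketch, and assuming `TripletReader` is left at what are believed to be its defaults (tab-separated columns, no header, with anchor, positive and negative in columns 0 to 2), a file such as `gamestop_triplet.txt` could be written and sanity-checked as follows. The helper names and the example rows are hypothetical and are not taken from the actual pipeline; check the reader's signature in the installed version before relying on the layout.

```python
import csv
import os
import tempfile


def write_triplets(path, triplets):
    """Write (anchor, positive, negative) rows as one tab-separated line each."""
    with open(path, "w", encoding="utf-8") as f:
        for row in triplets:
            # Collapse tabs/newlines inside a text so they cannot break the columns.
            fields = [" ".join(text.split()) for text in row]
            f.write("\t".join(fields) + "\n")


def read_triplets(path):
    """Read the file back as a list of 3-tuples, without any quote handling."""
    with open(path, newline="", encoding="utf-8") as f:
        reader = csv.reader(f, delimiter="\t", quoting=csv.QUOTE_NONE)
        return [tuple(row) for row in reader]


# Made-up rows, only to show the expected shape of the data.
example_rows = [
    ("console game trade-in", "pre-owned video games", "garden furniture sale"),
    ("new controller release", "gaming headset review", "mortgage rate calculator"),
]

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "example_triplet.txt")
    write_triplets(path, example_rows)
    round_trip = read_triplets(path)
```

Reading the file back before training is a cheap way to catch rows that do not split into exactly three columns.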

In [6]:
# Convert the dataset to a DataLoader ready for training
logging.info("Read Gamestop Triplet train dataset")
train_dataset = SentencesDataset(trp_reader.get_examples('gamestop_triplet.txt'), model)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=train_batch_size)

2020-08-21 12:51:11 - Read Gamestop Triplet train dataset


Convert dataset: 100%|██████████| 1000000/1000000 [04:46<00:00, 3493.19it/s]

2020-08-21 12:56:01 - Num sentences: 1000000
2020-08-21 12:56:01 - Sentences 0 longer than max_seqence_length: 0
2020-08-21 12:56:01 - Sentences 1 longer than max_seqence_length: 0
2020-08-21 12:56:01 - Sentences 2 longer than max_seqence_length: 0





### Load Validation set

Generated from the following datasets:
- `safegraph_gamestop`: Filters to users who have visited GameStop and also visited at least 1 gaming-related page
- `full_merge_ndomain10`: Merged SafeGraph and Mergelog data for users who visited at least 10 domains (first half of July)

**Note:** A better development set could have been prepared by re-merging safegraph data over a different time-span

In [12]:
logging.info("Read STSbenchmark dev dataset")
dev_dataset = SentencesDataset(trp_reader.get_examples('gamestop_triplet_dev.txt'), model)
dev_dataloader = DataLoader(dev_dataset, shuffle=True, batch_size=train_batch_size)
evaluator = TripletEvaluator(dev_dataloader, name='trp-dev')

2020-08-21 13:05:08 - Read STSbenchmark dev dataset


Convert dataset: 100%|██████████| 100000/100000 [00:30<00:00, 3331.35it/s]

2020-08-21 13:05:40 - Num sentences: 100000
2020-08-21 13:05:40 - Sentences 0 longer than max_seqence_length: 0
2020-08-21 13:05:40 - Sentences 1 longer than max_seqence_length: 0
2020-08-21 13:05:40 - Sentences 2 longer than max_seqence_length: 0





## Post-Train Model

The `%%capture` cell magic stores this cell's output in `training_output`, so training can keep running while the page is closed. The stored output is replayed later with `training_output.show()`.

In [None]:
%%capture training_output
%%time
# Configure the training. The dev-set evaluator passed to model.fit runs every 1000 steps
warmup_steps = math.ceil(len(train_dataset) * num_epochs / train_batch_size * 0.1)  # 10% of total training steps for warm-up
logging.info("Warmup-steps: {}".format(warmup_steps))

train_loss = losses.TripletLoss(model=model, distance_metric=losses.TripletDistanceMetric.COSINE)

# Train the model
model.fit(train_objectives=[(train_dataloader, train_loss)],
          evaluator=evaluator,
          epochs=num_epochs,
          evaluation_steps=1000,
          warmup_steps=warmup_steps,
          output_path=model_save_path)
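
For reference, the objective behind `losses.TripletLoss` with `TripletDistanceMetric.COSINE` can be sketched in plain Python on toy vectors. This is a per-triplet illustration under simplifying assumptions, not the library's implementation: the function names are hypothetical, and the `margin` value is an arbitrary illustrative choice rather than the library default.

```python
import math


def cosine_distance(u, v):
    """Cosine distance, 1 - cos(u, v), for two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def triplet_loss(anchor, positive, negative, margin=0.5):
    """Hinge on the gap between anchor-positive and anchor-negative distance.

    The loss is zero once the negative is at least `margin` further from
    the anchor than the positive is.
    """
    d_pos = cosine_distance(anchor, positive)
    d_neg = cosine_distance(anchor, negative)
    return max(d_pos - d_neg + margin, 0.0)


# Toy 2-d "embeddings": the positive points the same way as the anchor,
# the negative is orthogonal to it.
anchor, positive, negative = [1.0, 0.0], [2.0, 0.0], [0.0, 3.0]
well_separated = triplet_loss(anchor, positive, negative)
swapped = triplet_loss(anchor, negative, positive)
```

Because cosine distance only ranges from 0 to 2, the margin has to be chosen with that scale in mind; a margin above 2 would keep the hinge active for every triplet.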

### Load Test set and Evaluate

As in the training cell, `%%capture` stores the evaluation cell's output (here in `eval_output`) so that it can finish while the page is closed.

Generated from the following datasets:
- `gamestop_raw`: Filters to any users who have visited GameStop
- `full_merge_ndomain7`: Merged SafeGraph and Mergelog data for users who visited at least 7 domains (first half of July)

**Note:** A better test set could have been prepared by re-merging safegraph data over a different time-span

In [None]:
%%capture eval_output

model = SentenceTransformer(model_save_path)
test_dataset = SentencesDataset(trp_reader.get_examples('gamestop_triplet_test.txt'), model)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=train_batch_size)
test_evaluator = TripletEvaluator(test_dataloader, name='trp-test')
test_evaluator(model, output_path=model_save_path)

In [24]:
training_output.show()
eval_output.show()

2020-08-22 15:42:10 - Warmup-steps: 25000
2020-08-22 15:42:10 - Load pretrained SentenceTransformer: output/training_stsbenchmark_continue_training-bert-base-nli-mean-tokens-2020-08-21_13-47-18
2020-08-22 15:42:10 - Load SentenceTransformer from folder: output/training_stsbenchmark_continue_training-bert-base-nli-mean-tokens-2020-08-21_13-47-18
2020-08-22 15:42:10 - loading configuration file output/training_stsbenchmark_continue_training-bert-base-nli-mean-tokens-2020-08-21_13-47-18/0_BERT/config.json
2020-08-22 15:42:10 - Model config BertConfig {
  "architectures": [
    "BertModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "type_vocab_size": 2,
  "vocab


0.73813
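
If the final value is the return value of `test_evaluator`, it would most naturally be read as triplet accuracy: the fraction of test triplets in which the anchor is closer to the positive than to the negative. The sketch below shows that calculation on toy vectors with cosine distance only. It is an assumption-laden illustration with hypothetical function names, not the `TripletEvaluator` code, which may also compare other distance functions.

```python
import math


def cosine_distance(u, v):
    """Cosine distance, 1 - cos(u, v), for two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def triplet_accuracy(triplets, distance=cosine_distance):
    """Share of (anchor, positive, negative) triplets ranked correctly."""
    correct = sum(
        1 for anchor, positive, negative in triplets
        if distance(anchor, positive) < distance(anchor, negative)
    )
    return correct / len(triplets)


# Two toy triplets: the first is ranked correctly, the second is not.
toy_triplets = [
    ([1.0, 0.0], [0.9, 0.1], [0.0, 1.0]),
    ([1.0, 0.0], [0.0, 1.0], [1.0, 0.1]),
]
toy_accuracy = triplet_accuracy(toy_triplets)
```

Under that reading, a score of 0.73813 would mean that roughly 74% of the test triplets were ranked correctly, against a chance level of 50%.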