<a href="https://colab.research.google.com/github/450586509/practical-ml/blob/master/sentence_bert_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install -U sentence-transformers


Collecting sentence-transformers
[?25l  Downloading https://files.pythonhosted.org/packages/ce/4b/0add07b1eebbbe83e77fb5ac4e72e87046c3fc2c9cb16f7d1cd8c6921a1d/sentence-transformers-0.3.7.2.tar.gz (59kB)
[K     |████████████████████████████████| 61kB 6.2MB/s 
[?25hCollecting transformers<3.4.0,>=3.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/19/22/aff234f4a841f8999e68a7a94bdd4b60b4cebcfeca5d67d61cd08c9179de/transformers-3.3.1-py3-none-any.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 13.4MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 49.3MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883

In [2]:
from torch.utils.data import DataLoader
import math
from sentence_transformers import SentenceTransformer,  SentencesDataset, LoggingHandler, losses, models, util
from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator
from sentence_transformers.readers import STSBenchmarkDataReader, InputExample
import logging
from datetime import datetime
import sys
import os
import gzip
import csv

In [3]:
# Configure the root logger: INFO level, timestamp-prefixed messages, with
# records handled by sentence-transformers' LoggingHandler.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    datefmt='%Y-%m-%d %H:%M:%S',
    handlers=[LoggingHandler()],
)

In [4]:
# Check whether the STS benchmark archive already exists at the target path;
# if it does not, download it from sbert.net.
sts_dataset_path = 'datasets/stsbenchmark.tsv.gz'
sts_dataset_url = 'https://sbert.net/datasets/stsbenchmark.tsv.gz'

dataset_is_cached = os.path.exists(sts_dataset_path)
if not dataset_is_cached:
    util.http_get(sts_dataset_url, sts_dataset_path)

100%|██████████| 392k/392k [00:02<00:00, 134kB/s]


In [5]:
# Training configuration: which pre-trained model to continue training, the
# batch size, the number of epochs, and a timestamped output directory.
model_name = 'bert-base-nli-mean-tokens'
train_batch_size = 16
num_epochs = 4

# The timestamp keeps separate runs from writing into the same directory.
run_timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
model_save_path = 'output/training_stsbenchmark_continue_training-' + model_name + '-' + run_timestamp

In [6]:
# Load the pre-trained sentence transformer model named by `model_name`.
# This `model` object is the one the training cells below continue to train.
model = SentenceTransformer(model_name)

2020-10-09 10:45:30 - Load pretrained SentenceTransformer: bert-base-nli-mean-tokens
2020-10-09 10:45:30 - Did not find folder bert-base-nli-mean-tokens. Assume to download model from server.
2020-10-09 10:45:30 - Downloading sentence transformer model from https://sbert.net/models/bert-base-nli-mean-tokens.zip and saving it at /root/.cache/torch/sentence_transformers/sbert.net_models_bert-base-nli-mean-tokens


100%|██████████| 405M/405M [00:49<00:00, 8.22MB/s]


2020-10-09 10:46:26 - Load SentenceTransformer from folder: /root/.cache/torch/sentence_transformers/sbert.net_models_bert-base-nli-mean-tokens
2020-10-09 10:46:29 - Use pytorch device: cuda


In [7]:
# Parse the gzipped TSV into sentence-pair examples, bucketed by the 'split' column.
train_samples, dev_samples, test_samples = [], [], []

# Rows whose split is neither 'dev' nor 'test' fall through to the training list.
held_out_buckets = {'dev': dev_samples, 'test': test_samples}

with gzip.open(sts_dataset_path, 'rt', encoding='utf8') as tsv_file:
    reader = csv.DictReader(tsv_file, delimiter='\t', quoting=csv.QUOTE_NONE)
    for row in reader:
        # Divide the gold score by 5 to normalize it to the range 0 ... 1
        normalized_score = float(row['score']) / 5.0
        sentence_pair = InputExample(texts=[row['sentence1'], row['sentence2']],
                                     label=normalized_score)
        held_out_buckets.get(row['split'], train_samples).append(sentence_pair)

In [8]:
# Wrap the training examples in a dataset, batch them in shuffled order,
# and create the cosine-similarity loss bound to the model being trained.
train_dataset = SentencesDataset(train_samples, model)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
)
train_loss = losses.CosineSimilarityLoss(model=model)

In [9]:
# Development set: measure correlation between cosine score and gold labels.
# The dev examples were already collected into `dev_samples` when the TSV was
# parsed; the log line below only marks this step in the output.
logging.info("Read STSbenchmark dev dataset")
# `name='sts-dev'` labels this evaluator's dataset.
evaluator = EmbeddingSimilarityEvaluator.from_input_examples(dev_samples, name='sts-dev')


2020-10-09 10:46:29 - Read STSbenchmark dev dataset


In [10]:
# Configure the warm-up: 10% of (training examples * epochs / batch size),
# rounded up to a whole number of steps.
estimated_total_steps = len(train_dataset) * num_epochs / train_batch_size
warmup_steps = math.ceil(estimated_total_steps * 0.1)
logging.info("Warmup-steps: {}".format(warmup_steps))

2020-10-09 10:46:29 - Warmup-steps: 144


In [11]:
# Train the model on the (dataloader, loss) objective for `num_epochs` epochs.
# The dev-set evaluator and the output path are handed to fit() as well.
model.fit(
    train_objectives=[(train_dataloader, train_loss)],
    epochs=num_epochs,
    warmup_steps=warmup_steps,
    evaluator=evaluator,
    evaluation_steps=1000,
    output_path=model_save_path,
)


HBox(children=(FloatProgress(value=0.0, description='Epoch', max=4.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Iteration', max=360.0, style=ProgressStyle(description_wi…


2020-10-09 10:48:31 - Evaluation the model on sts-dev dataset after epoch 0:
2020-10-09 10:48:36 - Cosine-Similarity :	Pearson: 0.8617	Spearman: 0.8632
2020-10-09 10:48:36 - Manhattan-Distance:	Pearson: 0.8557	Spearman: 0.8593
2020-10-09 10:48:36 - Euclidean-Distance:	Pearson: 0.8560	Spearman: 0.8597
2020-10-09 10:48:36 - Dot-Product-Similarity:	Pearson: 0.8445	Spearman: 0.8440
2020-10-09 10:48:36 - Save model to output/training_stsbenchmark_continue_training-bert-base-nli-mean-tokens-2020-10-09_10-45-30


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=360.0, style=ProgressStyle(description_wi…


2020-10-09 10:50:25 - Evaluation the model on sts-dev dataset after epoch 1:
2020-10-09 10:50:31 - Cosine-Similarity :	Pearson: 0.8699	Spearman: 0.8716
2020-10-09 10:50:31 - Manhattan-Distance:	Pearson: 0.8607	Spearman: 0.8657
2020-10-09 10:50:31 - Euclidean-Distance:	Pearson: 0.8611	Spearman: 0.8660
2020-10-09 10:50:31 - Dot-Product-Similarity:	Pearson: 0.8550	Spearman: 0.8559
2020-10-09 10:50:31 - Save model to output/training_stsbenchmark_continue_training-bert-base-nli-mean-tokens-2020-10-09_10-45-30


HBox(children=(FloatProgress(value=0.0, description='Iteration', max=360.0, style=ProgressStyle(description_wi…

KeyboardInterrupt: ignored