# Explore `sentence_transformers`
Started Monday, Dec. 6, 2021


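Try out the `sentence_transformers` package for fine-tuning bi-encoders. This notebook fine-tunes a transformer + pooling bi-encoder on a sentence-pair similarity dataset with `CosineSimilarityLoss`, selects the checkpoint on the dev split, and reports the score on the test split.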




## Setup



In [6]:
import math
import os
import random

import numpy as np
import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from sentence_transformers import SentenceTransformer, InputExample, models, losses, evaluation
from icecream import ic

from util import *


# Seed every RNG this notebook touches (Python `random`, NumPy, PyTorch) from the
# configured seed, then make PyTorch error out on ops lacking a deterministic kernel.
# NOTE(review): on CUDA, deterministic cuBLAS additionally needs the
# CUBLAS_WORKSPACE_CONFIG env var (see PyTorch reproducibility notes).
seed = config('random-seed')
for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    seed_fn(seed)
torch.use_deterministic_algorithms(True)

def seed_worker(worker_id):
    """``worker_init_fn`` for a DataLoader: reseed NumPy and ``random`` per worker.

    Derives a 32-bit seed from ``torch.initial_seed()`` (inside a DataLoader
    worker, that is the worker's own torch seed) and applies it to NumPy's
    global RNG and then Python's ``random``. ``worker_id`` is required by the
    ``worker_init_fn`` signature but is not used.
    """
    derived_seed = torch.initial_seed() % (1 << 32)  # NumPy seeds must fit in 32 bits
    for reseed in (np.random.seed, random.seed):
        reseed(derived_seed)

# Dedicated, seeded torch RNG following the PyTorch reproducibility recipe.
# NOTE(review): it only takes effect when passed to a DataLoader as `generator=g`.
g = torch.Generator().manual_seed(seed)

# Fine-tuning run: its config entry, the config of the dataset it uses,
# and the directory the fine-tuned model is written to.
name_tune = 'eg_sbert'
md_path = os.path.join(PATH_BASE, DIR_MDL, name_tune)
d_tune = config(f'fine-tune.{name_tune}')
d_dset = config(f'datasets.{d_tune["dataset_name"]}')



## Set up model



In [7]:
# Bi-encoder: a transformer backbone whose token embeddings are pooled into one
# vector per sentence; pooling options come from the fine-tuning config.
base = models.Transformer(
    d_tune['embedding_model_name'],
    max_seq_length=d_tune['max_seq_length'],
)
pool = models.Pooling(
    base.get_word_embedding_dimension(),  # token-embedding width of the backbone
    **d_tune['pooling_model_kwargs'],
)
model = SentenceTransformer(modules=[base, pool])



Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Prep Data


In [8]:
dnm = d_tune['dataset_name']

dset = load_dataset(dnm, name='en').shuffle(seed=seed)
# Sanity check: each split's size must match the count recorded in the dataset config.
for nm, ds in dset.items():
    n_expected = d_dset['n_sample'][nm]
    assert len(ds) == n_expected, f'split {nm!r}: expected {n_expected} samples, got {len(ds)}'

# Min-max scale similarity scores into [0, 1]: (score - min) / (max - min).
# Dividing by the span alone is only correct when the configured minimum is 0.
label_range = d_dset['label_range']
label_min = label_range['min']
label_span = label_range['max'] - label_min
assert label_span > 0, f'invalid label_range in dataset config: {label_range!r}'

dset_ = {
    nm: [
        InputExample(
            texts=[eg['sentence1'], eg['sentence2']],
            label=(eg['similarity_score'] - label_min) / label_span,
        )
        for eg in ds
    ]
    for nm, ds in dset.items()
}



Reusing dataset stsb_multi_mt (/Users/stefanh/.cache/huggingface/datasets/stsb_multi_mt/en/1.0.0/a5d260e4b7aa82d1ab7379523a005a366d9b124c76a5a5cf0c4c5365458b0ba9)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached shuffled indices for dataset at /Users/stefanh/.cache/huggingface/datasets/stsb_multi_mt/en/1.0.0/a5d260e4b7aa82d1ab7379523a005a366d9b124c76a5a5cf0c4c5365458b0ba9/cache-5015801d0d574272.arrow
Loading cached shuffled indices for dataset at /Users/stefanh/.cache/huggingface/datasets/stsb_multi_mt/en/1.0.0/a5d260e4b7aa82d1ab7379523a005a366d9b124c76a5a5cf0c4c5365458b0ba9/cache-aed9104de7b4bf7c.arrow
Loading cached shuffled indices for dataset at /Users/stefanh/.cache/huggingface/datasets/stsb_multi_mt/en/1.0.0/a5d260e4b7aa82d1ab7379523a005a366d9b124c76a5a5cf0c4c5365458b0ba9/cache-09d174b054d294eb.arrow


## Train


In [9]:
# Shuffle with the dedicated seeded generator `g`. Batch order is then reproducible
# and independent of how many draws other code has made from torch's global RNG.
dl_train = DataLoader(
    dset_['train'],
    shuffle=True,
    batch_size=d_tune['batch_size'],
    worker_init_fn=seed_worker,
    generator=g,
)
# Learning-rate warmup spans `warmup_frac` of all training steps (batches x epochs).
n_warmup = math.ceil(len(dl_train) * d_tune['n_epochs'] * d_tune['warmup_frac'])
dev_evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dset_['dev'], name=f'{dnm}, dev')
model.fit(
    train_objectives=[(dl_train, losses.CosineSimilarityLoss(model))],
    epochs=d_tune['n_epochs'],
    warmup_steps=n_warmup,
    evaluator=dev_evaluator,
    output_path=md_path,
)



Epoch:   0%|          | 0/4 [00:00<?, ?it/s]

Iteration:   0%|          | 0/360 [00:00<?, ?it/s]

KeyboardInterrupt: 

## Evaluate


In [None]:
# Reload the model the Train cell wrote to `md_path` and score it on the test split.
# Fail loudly if nothing was saved (e.g. training was interrupted before saving).
# Otherwise SentenceTransformer would treat the path as a hub model id and fail obscurely.
if not os.path.isdir(md_path):
    raise FileNotFoundError(f'No saved model at {md_path!r}; run the Train cell to completion first.')
model = SentenceTransformer(md_path)
test_evaluator = evaluation.EmbeddingSimilarityEvaluator.from_input_examples(dset_['test'], name=f'{dnm}, test')
test_evaluator(model, output_path=md_path)

