# Try fine-tuning `Sentence BERT`
Since Thur. Dec. 9th, 2021

To set up our fine-tuning pipeline, we first verify the implementation by checking whether we can reproduce the results from [*Sentence-BERT: Sentence Embeddings using Siamese BERT-Networks*](https://arxiv.org/abs/1908.10084).
We rely only on Hugging Face libraries (`transformers`, `datasets`), not on [sentence-transformers](https://github.com/UKPLab/sentence-transformers).

Specifically, we target **Section 4.2, Supervised STS**, using the `SBERT-STSb-base` setting:
> The STS benchmark (STSb) (Cer et al., 2017) provides is a popular dataset to evaluate supervised STS systems.

> We use the training set to fine-tune SBERT using the regression objective function. At prediction time, we compute the cosine-similarity between the sentence embeddings. All systems are trained with 10 random seeds to counter variances.
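
The regression objective in the paper computes the cosine similarity between the two sentence embeddings `u` and `v` and trains with mean-squared error against the gold score. The sketch below shows that loss in plain Python for clarity; in the real pipeline the same arithmetic runs on PyTorch tensors. It assumes the STSb labels have already been rescaled from 0-5 to 0-1 (the sentence-transformers example divides by 5), and the function names are our own.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two equal-length embedding vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def regression_loss(pairs, targets):
    """Mean-squared error between cos(u, v) and the gold similarity.

    pairs:   list of (u, v) sentence-embedding pairs
    targets: gold scores, assumed already rescaled to [0, 1]
    """
    errors = [(cosine_similarity(u, v) - t) ** 2 for (u, v), t in zip(pairs, targets)]
    return sum(errors) / len(errors)
```

At prediction time only `cosine_similarity` is needed; the loss is used during fine-tuning.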

An example using [sentence-transformers](https://github.com/UKPLab/sentence-transformers): [training_stsbenchmark.py](https://github.com/UKPLab/sentence-transformers/blob/master/examples/training/sts/training_stsbenchmark.py).
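
To check reproduction we need the paper's metric: Spearman's rank correlation (reported ×100) between the predicted cosine similarities and the gold labels. A minimal stdlib sketch follows, with tied values given their average rank; in practice `scipy.stats.spearmanr` computes the same quantity. The helper names are hypothetical.

```python
import math


def average_ranks(values):
    """1-based ranks; tied values share the average of their positions."""
    order = sorted(range(len(values)), key=lambda i: values[i])
    ranks = [0.0] * len(values)
    i = 0
    while i < len(order):
        j = i
        # Extend j over the run of values tied with values[order[i]]
        while j + 1 < len(order) and values[order[j + 1]] == values[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of 1-based positions i+1 .. j+1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks


def spearman(xs, ys):
    """Spearman's rho: Pearson correlation of the two rank vectors."""
    rx, ry = average_ranks(xs), average_ranks(ys)
    n = len(rx)
    mx, my = sum(rx) / n, sum(ry) / n
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = math.sqrt(sum((a - mx) ** 2 for a in rx))
    sy = math.sqrt(sum((b - my) ** 2 for b in ry))
    return cov / (sx * sy)
```

Multiply the result by 100 to compare against the numbers in the paper's tables.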


## Setup



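```python
from transformers import AutoTokenizer, AutoModel
from datasets import load_dataset
from icecream import ic  # debug printing that shows expressions with their values

dnm = 'stsb_multi_mt'       # STS benchmark on the Hugging Face Hub; the 'en' config is the English STSb
mdnm = 'bert-base-uncased'  # pretrained checkpoint to fine-tune
seed = 77
```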
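
The paper trains every system with 10 random seeds, so the pipeline needs a way to seed all RNGs before each run. Below is a sketch of such a helper (hypothetical name `set_all_seeds`); `transformers.set_seed` provides similar behavior out of the box. Deriving the 10 seeds as consecutive offsets from `seed` is our own assumption, since the paper does not list the seeds it used.

```python
import random


def set_all_seeds(seed: int) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the RNG on all devices
    except ImportError:
        pass


BASE_SEED = 77  # same value as `seed` in the Setup cell
# Assumption: 10 consecutive seeds, one per training run
seeds = [BASE_SEED + i for i in range(10)]
```

Per-seed Spearman scores can then be summarized with `statistics.mean` and `statistics.stdev`.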



## Check STSb data



```python
# Load the English config; it has 'train', 'dev', and 'test' splits
dset = load_dataset(dnm, name='en')
ic(dset)
# ic(dset['dev'])

# for split in ['dev', 'train', 'test']:
#     dset = load_dataset(dnm, name='en', split=split)
#     ic(dset)

ic(len(dset['train']), dset['train'][0])
```




```
ic| dset: {'dev': Dataset({
              features: ['sentence1', 'sentence2', 'similarity_score'],
              num_rows: 1500
          }),
           'test': Dataset({
              features: ['sentence1', 'sentence2', 'similarity_score'],
              num_rows: 1379
          }),
           'train': Dataset({
              features: ['sentence1', 'sentence2', 'similarity_score'],
              num_rows: 5749
          })}
ic| len(dset['train']): 5749
    dset['train'][0]: {'sentence1': 'A plane is taking off.',
                       'sentence2': 'An air plane is taking off.',
                       'similarity_score': 5.0}
```


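
STSb `similarity_score` values range from 0 to 5, while the cosine similarity used as the prediction lies in [-1, 1], so the labels are usually rescaled to [0, 1] before computing the regression loss (the sentence-transformers example divides by 5). A minimal sketch, assuming a new column named `label`:

```python
MAX_SCORE = 5.0  # STSb scores range from 0 (unrelated) to 5 (equivalent)


def normalize_score(example: dict) -> dict:
    """Return a new `label` column: similarity_score rescaled to [0, 1]."""
    return {'label': example['similarity_score'] / MAX_SCORE}
```

Applied to every split with `dset = dset.map(normalize_score)`, this adds `label` next to the existing columns.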

## Check BERT model


```python
tokenizer = AutoTokenizer.from_pretrained(mdnm)
ic(tokenizer)

# AutoModel resolves to BertModel for this checkpoint, i.e. the same as
# BertModel.from_pretrained('bert-base-uncased')
model = AutoModel.from_pretrained(mdnm)
ic(model.__class__)
```




```
ic| tokenizer: PreTrainedTokenizerFast(name_or_path='bert-base-uncased', vocab_size=30522, model_max_len=512, is_fast=True, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
ic| model.__class__: <class 'transformers.models.bert.modeling_bert.BertModel'>
```

The warning is expected here: `AutoModel` loads only the bare BERT encoder, so the checkpoint's pretraining heads (the masked-LM head `cls.predictions.*` and the next-sentence head `cls.seq_relationship.*`) are discarded.


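
`BertModel` returns one hidden vector per token (`last_hidden_state`), whereas SBERT needs one vector per sentence; the paper's default pooling is MEAN over the token outputs. The stdlib sketch below (hypothetical helper `mean_pool`) shows the masked average for a single sentence, using the tokenizer's `attention_mask` to skip padding tokens.

```python
def mean_pool(token_embeddings, attention_mask):
    """Average the token vectors of one sentence, ignoring padding.

    token_embeddings: seq_len x hidden_size list of floats
    attention_mask:   seq_len list of 1 (real token) / 0 (padding)
    """
    hidden_size = len(token_embeddings[0])
    total = [0.0] * hidden_size
    count = 0
    for vec, keep in zip(token_embeddings, attention_mask):
        if keep:
            total = [t + x for t, x in zip(total, vec)]
            count += 1
    # Guard against an all-padding mask
    return [t / max(count, 1) for t in total]
```

On batched PyTorch tensors the same operation is a masked sum over the sequence dimension divided by the number of unmasked tokens.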