<a href="https://colab.research.google.com/github/Taaniya/sentence-embeddings-with-bert/blob/main/Finetuning_bert_for_semantic_search_with_SBERT_approach.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

##### This notebook fine-tunes the `bert-base-uncased` model to learn sentence embeddings with a siamese network (SBERT), trained with a softmax loss on the NLI datasets SNLI and MNLI

In [None]:
! pip install datasets transformers sentence-transformers

In [2]:
import datasets
from sentence_transformers import SentenceTransformer, InputExample, models
from sentence_transformers import losses

from torch.utils.data import DataLoader
from tqdm.auto import tqdm

### Data preparation

In [3]:
snli = datasets.load_dataset("snli", split='train')


In [4]:
snli

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 550152
})

##### Labels in NLI data
0 - entailment: the premise implies that the hypothesis is true

1 - neutral: the hypothesis may or may not be true given the premise; the premise neither implies nor contradicts it

2 - contradiction: the premise and the hypothesis cannot both be true
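As a quick illustration of how the integer labels map to these classes, the plain-Python sketch below decodes the label of the first SNLI row shown further down. `LABEL_NAMES` and `label_name` are hypothetical helpers written for this example; the order of the names follows the `ClassLabel` feature that `datasets` reports for this data.

```python
# Hypothetical helper: class names in the order used by the `label` ClassLabel feature.
LABEL_NAMES = ["entailment", "neutral", "contradiction"]


def label_name(label_id: int) -> str:
    """Map an integer NLI label (0, 1 or 2) to its class name."""
    if label_id not in (0, 1, 2):
        raise ValueError(f"not a valid NLI label: {label_id}")
    return LABEL_NAMES[label_id]


# First row of the SNLI training split.
example = {
    "premise": "A person on a horse jumps over a broken down airplane.",
    "hypothesis": "A person is training his horse for a competition.",
    "label": 1,
}
print(label_name(example["label"]))  # neutral
```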

In [5]:
snli[0]

{'premise': 'A person on a horse jumps over a broken down airplane.',
 'hypothesis': 'A person is training his horse for a competition.',
 'label': 1}

In [None]:
mnli = datasets.load_dataset('glue', 'mnli', split='train')

In [7]:
mnli

Dataset({
    features: ['premise', 'hypothesis', 'label', 'idx'],
    num_rows: 392702
})

In [8]:
mnli[0]

{'premise': 'Conceptually cream skimming has two basic dimensions - product and geography.',
 'hypothesis': 'Product and geography are what make cream skimming work. ',
 'label': 1,
 'idx': 0}

In [9]:
# remove idx column
mnli = mnli.remove_columns(['idx'])
mnli

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 392702
})

In [10]:
# concatenate snli and mnli datasets

dataset = datasets.concatenate_datasets([snli, mnli])

In [11]:
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942854
})

In [12]:
# check schema of dataset

dataset.features

{'premise': Value(dtype='string', id=None),
 'hypothesis': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}

In [13]:
# Some SNLI rows have label -1: no gold label was assigned because the annotators
# did not reach a consensus on entailment, neutral or contradiction. Drop these rows.
dataset = dataset.filter(lambda x: x['label'] != -1)


In [14]:
dataset

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942069
})
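The filter keeps a row only when its label is not -1. The plain-Python sketch below applies the same predicate to a few made-up rows (hypothetical data, not taken from SNLI) and shows why skipping this step is risky: any Python code that decodes labels by list indexing would silently read -1 as the last class instead of failing. It also checks how many rows the filter removed, using the two row counts printed above.

```python
# Hypothetical toy rows standing in for the NLI dataset.
rows = [
    {"premise": "A man plays a guitar.", "hypothesis": "A man makes music.", "label": 0},
    {"premise": "A dog runs.", "hypothesis": "An animal is asleep.", "label": 2},
    {"premise": "Two people talk.", "hypothesis": "They are friends.", "label": -1},
]


def has_gold_label(row: dict) -> bool:
    """Same predicate as the dataset filter: keep rows whose label is not -1."""
    return row["label"] != -1


kept = [row for row in rows if has_gold_label(row)]
print(len(rows) - len(kept))  # 1

# Negative indexing does not raise: -1 silently becomes the last class name.
print(["entailment", "neutral", "contradiction"][-1])  # contradiction

# Rows dropped from the concatenated SNLI + MNLI data (counts from the outputs above).
print(942854 - 942069)  # 785
```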

In [15]:
dataset.features

{'premise': Value(dtype='string', id=None),
 'hypothesis': Value(dtype='string', id=None),
 'label': ClassLabel(names=['entailment', 'neutral', 'contradiction'], id=None)}

In [16]:
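# flatten() returns a copy with nested (struct) columns expanded into top-level ones;
# this dataset has no nested columns, so the schema is unchanged
dataset.flatten()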

Dataset({
    features: ['premise', 'hypothesis', 'label'],
    num_rows: 942069
})

In [17]:
dataset[0]

{'premise': 'A person on a horse jumps over a broken down airplane.',
 'hypothesis': 'A person is training his horse for a competition.',
 'label': 1}

### Training


In [18]:
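import math

batch_size = 16
model_name = 'bert-base-uncased'
epochs = 1
# warmup is counted in optimizer steps (one per batch), so use 10% of the total number of batches
warmup_steps = int(0.1 * math.ceil(len(dataset) / batch_size) * epochs)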
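
`model.fit` counts warmup in optimizer steps, and with a single training objective each step consumes one batch. The arithmetic below, using the 942,069 rows left after filtering and a batch size of 16, compares 10% of the total steps with 10% of the number of examples. The second value is larger than the number of steps in one epoch, so with the linear warmup schedule that `model.fit` uses by default, the learning rate would still be ramping up when training ends.

```python
import math

num_examples = 942069  # rows left after filtering (from the dataset summary above)
batch_size = 16
epochs = 1

# One optimizer step per batch; the last, smaller batch still counts as a step.
steps_per_epoch = math.ceil(num_examples / batch_size)
total_steps = steps_per_epoch * epochs

# Warm up over the first 10% of optimizer steps.
warmup_from_steps = math.ceil(0.1 * total_steps)
# Taking 10% of the number of *examples* instead exceeds the whole run.
warmup_from_examples = int(0.1 * num_examples)

print(total_steps)           # 58880
print(warmup_from_steps)     # 5888
print(warmup_from_examples)  # 94206
```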

Convert the training examples into `InputExample` objects. The two `InputExample` parameters used here are:
* `texts` - a list of strings holding the text pair (or triplet) that makes up one training example
* `label` - the example's label, a float or an integer (here the integer NLI class)

In [19]:
InputExample

sentence_transformers.readers.InputExample.InputExample

In [20]:
train_samples = []

for row in tqdm(dataset):
  train_samples.append(InputExample(
      texts=[row['premise'], row['hypothesis']],
      label=row['label']
  ))


In [21]:
# Wrap the list of InputExample objects in a PyTorch DataLoader, which shuffles
# them and groups them into batches

train_dataloader = DataLoader(train_samples, shuffle=True, batch_size=batch_size)
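
To make the `shuffle=True` and `batch_size` behaviour concrete, here is a plain-Python approximation of what the `DataLoader` does with the list of samples on each pass: it draws a random order and slices it into batches, leaving a smaller final batch because `drop_last` defaults to `False`. `shuffled_batches` is a hypothetical helper for illustration, not PyTorch's implementation, and it omits collation (turning each batch of `InputExample`s into tensors).

```python
import math
import random


def shuffled_batches(samples, batch_size, seed=None):
    """Yield lists of `batch_size` samples in random order; the last batch may be smaller.

    Plain-Python illustration of DataLoader(shuffle=True, batch_size=...).
    """
    order = list(range(len(samples)))
    random.Random(seed).shuffle(order)  # one random permutation per pass
    for start in range(0, len(order), batch_size):
        yield [samples[i] for i in order[start:start + batch_size]]


samples = list(range(10))  # stand-ins for InputExample objects
batches = list(shuffled_batches(samples, batch_size=4, seed=0))
print([len(b) for b in batches])                    # [4, 4, 2]
print(len(batches) == math.ceil(len(samples) / 4))  # True
```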

In [None]:
# Construct the network

bert = models.Transformer('bert-base-uncased')
pooler = models.Pooling(bert.get_word_embedding_dimension(),     # 768 for BERT
                        pooling_mode_mean_tokens=True)

# `modules` takes a list of modules that are applied one after another

model = SentenceTransformer(modules=[bert, pooler])

In [23]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
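
The `Pooling` module printed above has `pooling_mode_mean_tokens=True`, so each sentence embedding is the average of BERT's 768-dimensional token vectors, with padding positions excluded via the attention mask. Below is a minimal NumPy sketch of masked mean pooling on a toy batch, assuming token embeddings of shape (batch, seq_len, dim); `mean_pool` is a hypothetical name, and the real module operates on PyTorch tensors.

```python
import numpy as np


def mean_pool(token_embeddings: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    """Average token vectors per sentence, ignoring padding positions.

    token_embeddings: (batch, seq_len, dim); attention_mask: (batch, seq_len) of 0/1.
    """
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, seq_len, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, dim)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                   # avoid division by zero
    return summed / counts


# Toy batch: 2 sentences, 3 token positions, 2-dim vectors.
tokens = np.array([
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[2.0, 0.0], [4.0, 2.0], [100.0, 100.0]],  # last position is padding
])
mask = np.array([[1, 1, 1], [1, 1, 0]])
print(mean_pool(tokens, mask))
# [[3. 4.]
#  [3. 1.]]
```

Without the mask, the padding vector in the second sentence would pull its embedding far away from the real tokens.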

#### Add Loss function

Because the NLI labels are categorical, we train this model with the classification objective described in the [SBERT paper](https://arxiv.org/pdf/1908.10084.pdf): the embeddings `u` and `v` of the two sentences are concatenated with their element-wise difference `|u - v|`, and a softmax classifier on top of this vector is trained with cross-entropy loss.
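
The objective can be sketched in NumPy following the formula in the SBERT paper, o = softmax(W_t(u, v, |u - v|)) with W_t of shape (3n, k). The toy sizes, random weights and the bias term `b` are assumptions for illustration (the sentence-transformers `SoftmaxLoss` applies a linear layer, which has a bias by default, to the same concatenation); in real training `u` and `v` come from the shared BERT + pooling encoder, and the gradient flows back into it.

```python
import numpy as np


def softmax(logits: np.ndarray) -> np.ndarray:
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def sbert_softmax_loss(u, v, labels, W, b):
    """Mean cross-entropy of softmax([u; v; |u - v|] @ W + b) against integer labels.

    u, v: (batch, n) sentence embeddings; W: (3n, k); b: (k,); labels: (batch,) in [0, k).
    """
    features = np.concatenate([u, v, np.abs(u - v)], axis=-1)  # (batch, 3n)
    probs = softmax(features @ W + b)                           # (batch, k)
    nll = -np.log(probs[np.arange(len(labels)), labels])        # per-example loss
    return probs, nll.mean()


rng = np.random.default_rng(0)
n, k, batch = 4, 3, 2  # toy sizes; BERT-base gives n = 768 and NLI gives k = 3
u = rng.normal(size=(batch, n))
v = rng.normal(size=(batch, n))
W = rng.normal(size=(3 * n, k)) * 0.1
b = np.zeros(k)
labels = np.array([0, 2])  # entailment, contradiction

probs, loss = sbert_softmax_loss(u, v, labels, W, b)
print(probs.shape, float(loss))
```

After training, the classifier is discarded and only the encoder is kept to produce sentence embeddings.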



In [24]:
loss = losses.SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3
)

In [None]:
# Fine-tune the model
model.fit(train_objectives=[(train_dataloader, loss)],
          epochs=epochs,
          warmup_steps=warmup_steps,
          output_path='sbert_test_b',
          show_progress_bar=True)

#### Inference - Encode the sentences

In [29]:
sentences = ['This is an example sentence', 'Each sentence is converted']

model = SentenceTransformer('sbert_test_b')
embeddings = model.encode(sentences)

In [30]:
embeddings

array([[ 0.02250259, -0.07829171, -0.02303071, ..., -0.0082793 ,
         0.02652686, -0.00201896],
       [ 0.04170233,  0.00109741, -0.01553419, ..., -0.02181627,
        -0.06359359, -0.00875289]], dtype=float32)
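
For semantic search, embeddings like these are usually compared with cosine similarity: the corpus is encoded once, and for each query the corpus entries are ranked by their similarity to the query embedding. The NumPy sketch below shows that ranking step on toy 3-dimensional vectors standing in for `model.encode(...)` output; `cosine_similarity` and `top_k` are hypothetical helpers, and the sentence-transformers package also provides utilities such as `util.cos_sim` for this.

```python
import numpy as np


def cosine_similarity(queries: np.ndarray, corpus: np.ndarray) -> np.ndarray:
    """Pairwise cosine similarity between rows of `queries` (q, d) and `corpus` (c, d)."""
    q = queries / np.linalg.norm(queries, axis=1, keepdims=True)
    c = corpus / np.linalg.norm(corpus, axis=1, keepdims=True)
    return q @ c.T  # (q, c)


def top_k(queries, corpus, k=2):
    """Indices of the k most similar corpus rows for each query, best first."""
    scores = cosine_similarity(queries, corpus)
    return np.argsort(-scores, axis=1)[:, :k]


# Toy vectors standing in for model.encode(...) output (real ones are 768-dim).
corpus_embeddings = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.7, 0.7, 0.0],
])
query_embedding = np.array([[0.9, 0.1, 0.0]])
print(top_k(query_embedding, corpus_embeddings, k=2))  # [[0 2]]
```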

#### References
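* Fine-tuning sentence transformers, video by James Briggs - https://youtu.be/aSx0jg9ZILo
* Sentence-Transformers documentation, Training Overview - https://www.sbert.net/docs/training/overview.html#training-overview
* Hugging Face blog post on training Sentence Transformers - https://huggingface.co/blog/how-to-train-sentence-transformers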