# Fine-tuning

Reference: [Fine-Tuning Embeddings for RAG with Synthetic Data](https://medium.com/llamaindex-blog/fine-tuning-embeddings-for-rag-with-synthetic-data-e534409a3971)

Fine-tune an open-source `sentence-transformers` embedding model on our synthetically generated dataset.

## Load pretrained model

In [21]:
# from tqdm import tqdm
from sentence_transformers import SentenceTransformer
from sentence_transformers import losses
from sentence_transformers.evaluation import InformationRetrievalEvaluator

import json

from torch.utils.data import DataLoader
from sentence_transformers import InputExample

In [2]:
model_id = "BAAI/bge-small-en"
model = SentenceTransformer(model_id)

In [3]:
model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 512, 'do_lower_case': True}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': True, 'pooling_mode_mean_tokens': False, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
)
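
The `Pooling` line in the printout has `pooling_mode_cls_token` set to `True`, which indicates that the sentence embedding is taken from the first (`[CLS]`) token rather than averaged over all tokens. The sketch below contrasts the two strategies on hypothetical token vectors (2 dimensions instead of 384); the function names are illustrative and are not part of the library.

```python
def cls_pooling(token_embeddings):
    """Use the first token's vector ([CLS]) as the sentence embedding."""
    return token_embeddings[0]


def mean_pooling(token_embeddings):
    """Average all token vectors; shown only for contrast."""
    n = len(token_embeddings)
    dim = len(token_embeddings[0])
    return [sum(tok[j] for tok in token_embeddings) / n for j in range(dim)]


# Hypothetical token vectors for one sentence: 3 tokens, 2 dimensions.
toy_tokens = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

print(cls_pooling(toy_tokens))   # [1.0, 2.0]
print(mean_pooling(toy_tokens))  # [3.0, 4.0]
```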

## Dataloader

In [8]:
TRAIN_DATASET_FPATH = 'afa_docs/train_val_data/train_dataset.json'
VAL_DATASET_FPATH = 'afa_docs/train_val_data/val_dataset.json'

# We use a very small batch size so that this toy example runs on a local machine.
# In practice, it should be much larger.
BATCH_SIZE = 10

In [9]:
# The files are only read, so open them in read-only mode.
with open(TRAIN_DATASET_FPATH, 'r') as f:
    train_dataset = json.load(f)

with open(VAL_DATASET_FPATH, 'r') as f:
    val_dataset = json.load(f)

In [11]:
train_dataset.keys()

dict_keys(['queries', 'corpus', 'relevant_docs'])
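
Judging from how these keys are used when the training examples are built, `queries` maps a query ID to its text, `corpus` maps a node ID to its text, and `relevant_docs` maps a query ID to a list of relevant node IDs. The miniature below illustrates that layout; the real files are not shown here, so every ID and text in it is invented.

```python
# Hypothetical miniature of the dataset layout; IDs and texts are invented.
toy_dataset = {
    "queries": {"q1": "How do I reset my password?"},
    "corpus": {"n1": "To reset your password, open Settings and choose Reset."},
    "relevant_docs": {"q1": ["n1"]},
}

# Pair each query with the text of its first relevant node.
pairs = []
for query_id, query in toy_dataset["queries"].items():
    node_id = toy_dataset["relevant_docs"][query_id][0]  # first relevant node only
    pairs.append((query, toy_dataset["corpus"][node_id]))

print(pairs)
```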

In [13]:
dataset = train_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

examples = []
for query_id, query in queries.items():
    node_id = relevant_docs[query_id][0]
    text = corpus[node_id]
    example = InputExample(texts=[query, text])
    examples.append(example)

In [16]:
len(examples)

22

In [18]:
loader = DataLoader(
    examples, batch_size=BATCH_SIZE
)

## Loss

`MultipleNegativesRankingLoss` is well suited to datasets that contain only positive pairs, that is, pairs of similar texts such as paraphrases, duplicate questions, (query, response) pairs, or (source_language, target_language) pairs.

It works well for training retrieval embeddings from positive pairs such as (query, relevant_doc): within a batch of n pairs, the n-1 documents paired with the other queries serve as negatives for each query.

Performance usually improves as the batch size grows, because each query is then contrasted against more negatives.

For more details, see:
- [docs](https://www.sbert.net/docs/package_reference/losses.html)
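
To make the in-batch negatives idea concrete, here is a minimal pure-Python sketch of this kind of loss. It is not the library's implementation: the function names and the toy 2-D embeddings are hypothetical, and the use of cosine similarity with a scale of 20 is an assumption about the library defaults that should be checked against the linked docs.

```python
import math


def cosine(u, v):
    """Cosine similarity between two vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v)


def in_batch_negatives_loss(query_embs, doc_embs, scale=20.0):
    """Mean cross-entropy where doc i is the positive for query i
    and every other doc in the batch acts as a negative."""
    total = 0.0
    for i, q in enumerate(query_embs):
        logits = [scale * cosine(q, d) for d in doc_embs]
        # Log-sum-exp, shifted by the max logit for numerical stability.
        m = max(logits)
        log_z = m + math.log(sum(math.exp(x - m) for x in logits))
        total += log_z - logits[i]
    return total / len(query_embs)


# Hypothetical 2-D embeddings for a batch of three (query, doc) pairs.
toy_queries = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
aligned_docs = [[0.9, 0.1], [0.1, 0.9], [-0.9, 0.1]]
# Rotate the docs so that every positive sits at the wrong index.
misaligned_docs = [aligned_docs[1], aligned_docs[2], aligned_docs[0]]

loss_aligned = in_batch_negatives_loss(toy_queries, aligned_docs)
loss_misaligned = in_batch_negatives_loss(toy_queries, misaligned_docs)
print(loss_aligned, loss_misaligned)
```

The loss is close to zero when each query is most similar to its own document and grows when another document in the batch scores higher, which is why larger batches (more negatives per query) give a stronger training signal.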

In [20]:
loss = losses.MultipleNegativesRankingLoss(model)

## Evaluator 

We set up an evaluator on the validation split of the dataset to monitor how well the embedding model performs during training.

In [22]:
dataset = val_dataset

corpus = dataset['corpus']
queries = dataset['queries']
relevant_docs = dataset['relevant_docs']

evaluator = InformationRetrievalEvaluator(queries, corpus, relevant_docs)
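
The evaluator scores the model with ranking metrics such as accuracy@k and MRR@k. The following is a minimal sketch of two such metrics, written from their standard definitions rather than taken from the library; the function names and the toy rankings are hypothetical.

```python
def hit_rate_at_k(ranked, relevant, k):
    """Fraction of queries with at least one relevant doc in the top k."""
    hits = sum(
        1
        for qid, doc_ids in ranked.items()
        if any(doc_id in relevant[qid] for doc_id in doc_ids[:k])
    )
    return hits / len(ranked)


def mrr_at_k(ranked, relevant, k):
    """Mean reciprocal rank of the first relevant doc within the top k."""
    total = 0.0
    for qid, doc_ids in ranked.items():
        for rank, doc_id in enumerate(doc_ids[:k], start=1):
            if doc_id in relevant[qid]:
                total += 1.0 / rank
                break
    return total / len(ranked)


# Hypothetical rankings: doc IDs sorted by similarity to each query.
toy_ranked = {"q1": ["n1", "n2", "n3"], "q2": ["n3", "n1", "n2"]}
toy_relevant = {"q1": ["n1"], "q2": ["n2"]}

print(hit_rate_at_k(toy_ranked, toy_relevant, k=2))  # 0.5
print(mrr_at_k(toy_ranked, toy_relevant, k=3))       # (1/1 + 1/3) / 2
```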

## Training

The training loop is straightforward to set up thanks to the high-level training API of `sentence-transformers`.
All we need to do is plug in the data loader, loss function, and evaluator defined in the previous cells, along with a few minor settings.

In [23]:
# We train the model for very few epochs in this toy example.
# This should typically be higher for better performance.
EPOCHS = 2

In [24]:
warmup_steps = int(len(loader) * EPOCHS * 0.1)

model.fit(
    train_objectives=[(loader, loss)],
    epochs=EPOCHS,
    warmup_steps=warmup_steps,
    output_path='exp_finetune',
    show_progress_bar=True,
    evaluator=evaluator, 
    evaluation_steps=50,
)

Iteration: 100%|██████████| 3/3 [00:06<00:00,  2.06s/it]
Iteration: 100%|██████████| 3/3 [00:05<00:00,  1.91s/it]
Epoch: 100%|██████████| 2/2 [00:17<00:00,  8.76s/it]
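
It is worth working through the step arithmetic for this toy run. With 22 examples and a batch size of 10, each epoch has 3 steps, so the 10% warmup truncates to zero and the `evaluation_steps=50` threshold is never reached within an epoch. The sketch below reproduces that calculation with the values from the cells above; the constant names `NUM_EXAMPLES` and `EVALUATION_STEPS` are introduced here only for illustration.

```python
import math

NUM_EXAMPLES = 22       # len(examples) reported earlier
BATCH_SIZE = 10
EPOCHS = 2
EVALUATION_STEPS = 50   # value passed to model.fit

# A DataLoader that keeps the last partial batch yields ceil(N / batch_size) batches.
steps_per_epoch = math.ceil(NUM_EXAMPLES / BATCH_SIZE)
total_steps = steps_per_epoch * EPOCHS
warmup_steps = int(total_steps * 0.1)

print(steps_per_epoch)                     # 3, matching the "3/3" iteration bars
print(total_steps)                         # 6
print(warmup_steps)                        # 0: 10% of 6 steps truncates to zero
print(EVALUATION_STEPS > steps_per_epoch)  # True: no mid-epoch evaluation
```

On a realistically sized dataset the same formula gives a non-zero warmup, and `evaluation_steps` should be chosen smaller than the number of steps per epoch if mid-epoch evaluation is wanted.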
