# Finetuning Embedding Models

In this notebook, I fine-tune a sentence-transformers embedding model on a set of training queries to improve its retrieval performance.

In [18]:
# libraries (standard library)

from typing import List, Tuple, Dict, Any
import time
import os

# utilities
from tqdm.notebook import tqdm
# NOTE: this rebinds the name `print` to rich's print, shadowing the built-in
from rich import print
from dotenv import load_dotenv
# Load variables from ./.env into os.environ; override=True lets values in the
# file replace any already set in the environment. `env` is not used in the
# cells shown here.
env = load_dotenv('./.env', override=True)

# third-party / project libraries
from llama_index.finetuning import EmbeddingQAFinetuneDataset
from llama_index.finetuning import SentenceTransformersFinetuneEngine
# presumably a project-local wrapper around the Weaviate client — confirm
from weaviate_interface import WeaviateClient


### Import Training-Validation Datasets

In [19]:
# JSON splits to load; the numeric suffix presumably is the sample count of each file
training_path, valid_path = (
    f'./data/{split_name}.json'
    for split_name in ('training_data_300', 'validation_data_100')
)

In [20]:
# Deserialize the training and validation splits (training first, then validation)
training_set, valid_set = (
    EmbeddingQAFinetuneDataset.from_json(split_path)
    for split_path in (training_path, valid_path)
)

# A split's size is the number of queries it contains
num_training_examples, num_valid_examples = (
    len(split.queries) for split in (training_set, valid_set)
)
print(f'# Training Samples: {num_training_examples}\n# Validation Samples: {num_valid_examples}')

### Instantiate a Weaviate Client

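This step is only needed because Chris hard-coded the sentence-transformers MiniLM-L6-v2 model into the Weaviate client, so its name can be read directly from the client's attributes (`client.model_name_or_path`). In this run a different base model is set explicitly below, so the client is not strictly required.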

In [21]:
# Credentials come from the environment (populated from ./.env in the setup cell);
# a missing variable raises KeyError naming it.
api_key, url = (
    os.environ[var_name] for var_name in ('WEAVIATE_API_KEY', 'WEAVIATE_ENDPOINT')
)

# Connect to the Weaviate Cloud (WCS) instance
client = WeaviateClient(api_key, url)

# Both flags should be True when the instance is up and ready to serve requests
client.is_live(), client.is_ready()

(True, True)

### Choose the Base Model and Output Path

In [22]:
# Base model to fine-tune, as a Hugging Face Hub id of the form "<org>/<name>"
model_id = 'sentence-transformers/multi-qa-distilbert-cos-v1'
model_id

'sentence-transformers/multi-qa-distilbert-cos-v1'

In [23]:
# Always name your fine-tuned model so it is easy to identify later,
# especially if you plan on doing multiple training runs with different params.
# Including the number of training samples in the name is also a good idea.

# NOTE: this fine-tunes the `model_id` set in the previous cell, NOT the
# MiniLM-L6-v2 model configured in the Weaviate client. To fine-tune the
# client's model instead, set: model_id = client.model_name_or_path

# Use the last path component so this works for Hub ids ("org/name"),
# bare names ("name") and local paths ("./models/name").
model_ext = model_id.rstrip('/').split('/')[-1]

models_dir = './models'
if os.path.exists(models_dir):
    print(f'{models_dir} already exists')
# exist_ok makes this idempotent on re-runs; it still raises if
# models_dir exists but is a regular file rather than a directory.
os.makedirs(models_dir, exist_ok=True)

ft_model_name = f'finetuned-{model_ext}-{num_training_examples}'
model_outpath = os.path.join(models_dir, ft_model_name)

print(f'Model ID: {model_id}')
print(f'Model Outpath: {model_outpath}')

### Instantiate the Sentence-Transformers Fine-tuning Engine

This part will change depending on which embedding model you would like to fine-tune; `SentenceTransformersFinetuneEngine` is the engine for sentence-transformers models.

In [None]:
# Training hyperparameters, named so they are easy to find and tweak between runs
BATCH_SIZE = 32  # examples per gradient step
EPOCHS = 10      # full passes over the training set

# Instantiate a fine-tuning engine for a sentence-transformers model.
# The validation split is presumably used for evaluation during training —
# see the llama_index SentenceTransformersFinetuneEngine docs.
finetune_engine = SentenceTransformersFinetuneEngine(
    training_set,
    batch_size=BATCH_SIZE,
    model_id=model_id,
    model_output_path=model_outpath,
    val_dataset=valid_set,
    epochs=EPOCHS,
)

In [25]:
finetune_engine.finetune()

Epoch:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]

Iteration:   0%|          | 0/10 [00:00<?, ?it/s]