# Finetune a Two-Layer Feedforward Neural Network on top of [paraphrase-mpnet-base-v2](https://huggingface.co/sentence-transformers/paraphrase-mpnet-base-v2)
The weights of the pre-trained model are frozen.<br>
Loss function used: `MultipleNegativesRankingLoss`

In [1]:
import torch
from typing import Any, List, Optional, Tuple#, Union
from llama_index.core import SimpleDirectoryReader
from llama_index.core.base.embeddings.base import BaseEmbedding
from llama_index.core.node_parser import SentenceSplitter
from llama_index.embeddings.huggingface.base import HuggingFaceEmbedding
from llama_index.embeddings.huggingface.pooling import Pooling
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.finetuning.embeddings.adapter_utils import BaseAdapter
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
# Seed the global NumPy and PyTorch random number generators so repeated runs
# start from the same random state.
np.random.seed(0)
torch.manual_seed(0)
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

## Create dataset

In [2]:
# Load the knowledge-base chunks (one row per chunk: source document URL,
# chunk text and a precomputed embedding vector).
# NOTE(review): unpickling can execute arbitrary code; only load this file if
# it comes from a trusted source.
df_kb = pd.read_pickle('../../data/kb_chunks_emb.pkl')
# Quick sanity check of the size, followed by a preview of the first rows.
print(df_kb.shape)
df_kb.head(3)

(33545, 3)


Unnamed: 0,doc_url,chunk_content,embedding
0,https://ghr.nlm.nih.gov/condition/keratoderma-...,keratoderma with woolly hair : medlineplus gen...,"[-0.0039987266, 0.08037464, 0.049785912, -0.12..."
1,https://ghr.nlm.nih.gov/condition/keratoderma-...,"##ma, woolly hair, and a form of cardiomyopath...","[-0.09539697, -0.09132044, 0.0027289127, 0.005..."
2,https://ghr.nlm.nih.gov/condition/keratoderma-...,##pathy in people with this group of condition...,"[0.026278932, 0.060939535, 0.031438153, -0.044..."


In [3]:
# Load the training questions; each row holds a question, the list of URLs of
# its relevant documents, and the number of relevant chunks.
# NOTE(review): unpickling can execute arbitrary code; only load this file if
# it comes from a trusted source.
df_ques_url_train = pd.read_pickle('../../data/questions_relevant_urls_chunks_train.pkl')
# Quick sanity check of the size, followed by a preview of the first rows.
print(df_ques_url_train.shape)
df_ques_url_train.head(3)

(20000, 3)


Unnamed: 0,question,relevant_docs_urls,num_rel_chunks
0,What is (are) keratoderma with woolly hair ?,[https://ghr.nlm.nih.gov/condition/keratoderma...,5
1,How many people are affected by keratoderma wi...,[https://ghr.nlm.nih.gov/condition/keratoderma...,5
2,What are the genetic changes related to kerato...,[https://ghr.nlm.nih.gov/condition/keratoderma...,5


In [4]:
# Build the corpus as {document URL -> chunk text}.
# Only the FIRST chunk encountered for each URL is kept; later chunks of the
# same document are ignored because `setdefault` never overwrites a key.
# NOTE(review): confirm that representing a document by its first chunk only
# is intended.
corpus = {}
for doc_url, chunk_text in zip(df_kb['doc_url'], df_kb['chunk_content']):
    corpus.setdefault(doc_url, chunk_text)

In [5]:
# Query ids are the stringified DataFrame index labels, in row order.
question_ids = [str(idx) for idx in df_ques_url_train.index]

# {query id -> question text}
train_queries = dict(zip(question_ids, df_ques_url_train['question']))

# {query id -> relevant document URLs}, sharing the same ids as above so the
# two mappings line up key for key.
train_relevant_docs = dict(
    zip(question_ids, df_ques_url_train['relevant_docs_urls'])
)

In [6]:
# Bundle the queries, the corpus and the query -> relevant-document mapping
# into the dataset object consumed by the finetuning engine.
train_dataset = EmbeddingQAFinetuneDataset(
    queries=train_queries,
    corpus=corpus,
    relevant_docs=train_relevant_docs,
)

In [7]:
# Persist the training dataset as JSON so it can be reloaded without
# rebuilding it from the pickled DataFrames.
import os

# Create the output directory first: writing into a missing directory would
# raise FileNotFoundError on a fresh checkout.
os.makedirs('data', exist_ok=True)
train_dataset.save_json('data/train_dataset.json')

## Load the model and finetune for 8 epochs

In [7]:
# requires torch dependency
from llama_index.embeddings.adapter.utils import TwoLayerNN
from llama_index.core.embeddings import resolve_embed_model
from llama_index.embeddings.adapter import AdapterEmbeddingModel

In [8]:
# Hugging Face identifier of the base sentence-embedding model.
model_name = "sentence-transformers/paraphrase-mpnet-base-v2"
# Resolve the base embedding model from the "local:<model name>" spec; the
# adapter defined below is trained on top of this model's embeddings.
base_embed_model = resolve_embed_model(f"local:{model_name}")

INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/paraphrase-mpnet-base-v2
Load pretrained SentenceTransformer: sentence-transformers/paraphrase-mpnet-base-v2
INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']


In [9]:
# Two-layer feedforward adapter: embedding_dim -> 1024 hidden -> embedding_dim.
# Input and output widths are equal, so adapted embeddings keep the size of
# the embeddings fed in.
# NOTE(review): 768 is assumed to be the base model's embedding size - confirm.
embedding_dim = 768
adapter_model = TwoLayerNN(
    in_features=embedding_dim,
    hidden_features=1024,
    out_features=embedding_dim,
    bias=True,
    add_residual=True,
)

In [None]:
# Configure adapter finetuning on top of the base embedding model:
# 8 epochs, batch size 8, adapter saved under "mpnet_finetuned_ep8".
finetune_engine = EmbeddingAdapterFinetuneEngine(
    train_dataset,
    base_embed_model,
    model_output_path="mpnet_finetuned_ep8",
    adapter_model=adapter_model,
    epochs=8,
    verbose=False,
    # Use the GPU when one is present, otherwise fall back to CPU instead of
    # failing on machines without CUDA.
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=8,
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# Run the finetuning loop with the settings configured on the engine above
# (long-running; see the timing note below).
finetune_engine.finetune()

Epoch:   0%|          | 0/25 [00:00<?, ?it/s]

Iteration:   0%|          | 0/2500 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

**Note:** Fine-tuning uses about 918 MiB of GPU memory and takes 5108 s for 8 epochs.<br>
Time per epoch: ≈638 s (10 min 38 s).