In [None]:
# imports
from transformers import AutoTokenizer, AutoModelForMaskedLM
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch
from tqdm.notebook import tqdm
import pickle
from datasets import Dataset
import seaborn as sns
from metrics import contextual_precision, contextual_recall, contextual_relevancy
import json


In [2]:
# !pip install faiss-cpu
# !pip uninstall faiss-cpu
# !pip install langchain_community
# !pip install sentence_transformers
# !pip install pyarrow
import faiss
from langchain_community.docstore.in_memory import InMemoryDocstore
from langchain_community.vectorstores import FAISS
from langchain_core.documents import Document
from sentence_transformers import SentenceTransformer, losses, InputExample
from torch.utils.data import Dataset, DataLoader, Subset

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [4]:
# Report which accelerator will be used. torch.cuda.get_device_name() raises
# when no CUDA device is available, so guard it to keep Restart & Run All
# working on CPU-only machines.
device_name = torch.cuda.get_device_name() if torch.cuda.is_available() else "cpu"
device_name

'NVIDIA GeForce RTX 3070 Laptop GPU'

<h1>Fine-tune an embedding adapter with LlamaIndex</h1>

In [None]:
!pip install llama-index
!pip install llama-index-embeddings-adapter
!pip install llama-index-finetuning

In [None]:
import torch
from typing import Any, List, Optional, Tuple#, Union
from llama_index.finetuning import EmbeddingAdapterFinetuneEngine
from llama_index.core.evaluation import EmbeddingQAFinetuneDataset
import pandas as pd
import numpy as np
from llama_index.embeddings.adapter.utils import TwoLayerNN
from llama_index.core.embeddings import resolve_embed_model
from tqdm import tqdm
from llama_index.embeddings.adapter import AdapterEmbeddingModel

In [None]:
# Training shard to fine-tune on (columns match MS MARCO: query, passages, answers, ...).
train_df = pd.read_parquet(path='train-00000-of-00007.parquet')

In [4]:
# Peek at the first rows to check the schema of the training data.
train_df.head(n=5)

Unnamed: 0,answers,passages,query,query_id,query_type,wellFormedAnswers
0,[The immediate impact of the success of the ma...,"{'is_selected': [1, 0, 0, 0, 0, 0, 0, 0, 0, 0]...",)what was the immediate impact of the success ...,1185869,DESCRIPTION,[]
1,[Restorative justice that fosters dialogue bet...,"{'is_selected': [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]...",_________ justice is designed to repair the ha...,1185868,DESCRIPTION,[]
2,[The reasons why Stalin wanted to control East...,"{'is_selected': [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]...",why did stalin want control of eastern europe,1185854,DESCRIPTION,[]
3,[Nails rust in water because water allows the ...,"{'is_selected': [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]...",why do nails get rusty,1185755,DESCRIPTION,[]
4,"[Depona Ab is a library in Vilhelmina, Sweden.]","{'is_selected': [0, 0, 0, 0, 0, 0, 1, 0, 0, 0]...",depona ab,1184773,DESCRIPTION,[]


In [None]:
# Build the three mappings EmbeddingQAFinetuneDataset expects:
#   train_queries:       query id   -> query text
#   corpus:              passage id -> passage text
#   train_relevant_docs: query id   -> list of relevant passage ids
# Query ids are the DataFrame index rendered as a string. Passage ids are
# "<query id>.<position in that row's passage list>".
# NOTE(review): every candidate passage of a query is marked relevant, not only
# the ones flagged in passages['is_selected']. This mirrors the evaluation below,
# which scores retrieval against the full passage_text list.
train_queries = {}
corpus = {}
train_relevant_docs = {}

for index, row in train_df.iterrows():
    query_id = str(index)
    passage_texts = row['passages']['passage_text']

    train_queries[query_id] = row['query']
    passage_ids = [f'{query_id}.{position}' for position in range(len(passage_texts))]
    corpus.update(zip(passage_ids, passage_texts))
    train_relevant_docs[query_id] = passage_ids

In [None]:
# Wrap the query / corpus / relevance mappings in LlamaIndex's QA fine-tuning container.
train_dataset = EmbeddingQAFinetuneDataset(
    corpus=corpus,
    queries=train_queries,
    relevant_docs=train_relevant_docs,
)

In [None]:
# Resolve the base embedding model. The "local:" prefix loads it through
# sentence-transformers (see the load log in this cell's output).
model_name = "sentence-transformers/all-MiniLM-L6-v2"
base_embed_model = resolve_embed_model("local:" + model_name)

  from .autonotebook import tqdm as notebook_tqdm


INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
INFO:sentence_transformers.SentenceTransformer:2 prompts are loaded, with the keys: ['query', 'text']
2 prompts are loaded, with the keys: ['query', 'text']


In [None]:
# Residual two-layer MLP on top of the base embeddings: 384 -> 512 -> 384.
# 384 is the embedding size of all-MiniLM-L6-v2, so input and output widths match.
adapter_model = TwoLayerNN(
    in_features=384,
    out_features=384,
    hidden_features=512,
    add_residual=True,
    bias=True,
)

In [None]:
# Set up the adapter fine-tuning engine. It trains `adapter_model` on top of
# `base_embed_model` and writes the result to `model_output_path`.
finetune_engine = EmbeddingAdapterFinetuneEngine(
    train_dataset,
    base_embed_model,
    model_output_path="chks/llama_index/all-MiniLM-L6-v2-finetuned-TwoLayerNN",  # path to save fine-tuned adapter
    adapter_model=adapter_model,
    epochs=5,
    verbose=False,
    # Reuse the device picked at the top of the notebook instead of hard-coding
    # "cuda", so this cell also runs on CPU-only machines.
    device=device.type,
    batch_size=64,
)

Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Batches: 100%|██████████| 1/1 [00:00<00:00,  1.43it/s]


In [None]:
# Train the adapter with the settings passed to the engine
# (dataset, epochs, batch size, device, output path).
finetune_engine.finetune()

In [None]:
# Reload a fine-tuned adapter from disk, plus the base SentenceTransformer it wraps.
# NOTE(review): this path ends in "-v2", but the engine cell saves to
# ".../all-MiniLM-L6-v2-finetuned-TwoLayerNN" (no suffix). This therefore loads a
# checkpoint from a different training run than the one above. Confirm this is
# intended.
model_path = "chks/llama_index/all-MiniLM-L6-v2-finetuned-TwoLayerNN-v2"
config_path = f"{model_path}/config.json"
model_weights_path = f"{model_path}/pytorch_model.bin"

# The saved config holds the TwoLayerNN constructor arguments.
with open(config_path, "r") as f:
    config = json.load(f)

# Rebuild the adapter with the same architecture it was trained with.
adapter_model = TwoLayerNN(
    in_features=config["in_features"],
    hidden_features=config["hidden_features"],
    out_features=config["out_features"],
    bias=config["bias"],
    activation_fn_str=config["activation_fn_str"],
    add_residual=config["add_residual"]
)

# weights_only=True restricts unpickling to tensors and containers, which also
# silences torch's FutureWarning. map_location="cpu" lets a checkpoint saved
# from a GPU load on any machine. The adapter's parameters are created on CPU
# here and are never moved.
adapter_model.load_state_dict(
    torch.load(model_weights_path, map_location="cpu", weights_only=True)
)
# Inference only: switch out of training mode.
adapter_model.eval()

# Load the base SentenceTransformer model on the selected device.
base_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)


INFO:sentence_transformers.SentenceTransformer:Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2
Load pretrained SentenceTransformer: sentence-transformers/all-MiniLM-L6-v2


  adapter_model.load_state_dict(torch.load(model_weights_path))


In [None]:
# Held-out split: its passages populate the vector store and its queries are used for evaluation.
test_df = pd.read_parquet(path='test-00000-of-00001.parquet')

In [None]:
# Exact (brute-force) inner-product index over 384-dim vectors, the adapter's
# output width. embedding_function L2-normalises its output, so inner product
# here equals cosine similarity.
# Note: this rebinds `index`, which the training loop used as a row label.
index = faiss.IndexFlatIP(384)
# index_gpu = faiss.index_cpu_to_all_gpus(index)

def embedding_function(text, encoder=None, adapter=None):
    """Embed text with the base SentenceTransformer followed by the fine-tuned adapter.

    Args:
        text: A single string, or a list of strings.
        encoder: Object with an ``encode`` method returning a numpy array.
            Defaults to the module-level ``base_model``.
        adapter: torch module applied to the base embeddings.
            Defaults to the module-level ``adapter_model``.

    Returns:
        A float32 ``torch.Tensor`` with no autograd history. For a single
        string it is one L2-normalised vector; for a list, each row is
        normalised independently. A zero vector stays zero rather than
        becoming NaN.
    """
    if encoder is None:
        encoder = base_model
    if adapter is None:
        adapter = adapter_model

    base_embeddings = torch.as_tensor(encoder.encode(text), dtype=torch.float32)
    # Pure inference: skip autograd bookkeeping instead of detaching afterwards.
    with torch.no_grad():
        adapted = adapter(base_embeddings)
    # Normalise along the last axis so batched input is normalised per row,
    # not by the norm of the whole matrix.
    return torch.nn.functional.normalize(adapted, p=2, dim=-1)


# LangChain FAISS wrapper around the raw index above, starting with an empty
# in-memory docstore and id mapping.
# NOTE: passing a plain function as `embedding_function` triggers LangChain's
# deprecation warning (it expects an `Embeddings` object). Wrapping
# embedding_function in an Embeddings subclass would remove the warning.
vector_store = FAISS(
    embedding_function=embedding_function,
    index=index,
    docstore=InMemoryDocstore(),
    index_to_docstore_id={},
)

`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


In [None]:
# Make vector database with the fine-tuned adapter embeddings.
# Each unique test-split passage within the token limit is embedded once.
# `index_to_docstore_id` maps FAISS row number -> passage text: rows are appended
# to the freshly created index in the same order the counter advances. The
# evaluation loop uses this mapping to turn search hits back into text.

MAX_PASSAGE_TOKENS = 256  # presumably the model's max sequence length; longer passages are skipped
ADD_BATCH_SIZE = 10000    # documents embedded/added per add_documents call; adjust as necessary

unique_passages = set()
index_to_docstore_id = {}
document_counter = 0
batch_documents = []

for passages in tqdm(test_df['passages']):
    for passage_text in passages['passage_text']:
        # Cheap set lookup first: duplicates never need to be tokenized.
        if passage_text in unique_passages:
            continue
        tokens = len(base_model.tokenizer(passage_text)['input_ids'])
        if tokens > MAX_PASSAGE_TOKENS:
            continue

        unique_passages.add(passage_text)
        batch_documents.append(Document(page_content=passage_text))
        index_to_docstore_id[document_counter] = passage_text
        document_counter += 1

        # Embed and add documents in batches.
        if len(batch_documents) >= ADD_BATCH_SIZE:
            vector_store.add_documents(documents=batch_documents)
            batch_documents = []

# Flush the final partial batch.
if batch_documents:
    vector_store.add_documents(documents=batch_documents)

In [None]:
# Persist the FAISS index and the row -> passage-text mapping so the evaluation
# below can reload them without re-embedding the corpus.
# Set SAVE_INDEX = True to (over)write the files on disk.
# METADATA_PATH must match the path the reload cell opens; the old commented-out
# code wrote "...metadata.pkl.pkl", which that cell would never find.
SAVE_INDEX = False
INDEX_PATH = "finetune_llamindex_v2_faiss_index.bin"
METADATA_PATH = "finetune_llamindex_v2_docstore_metadata.pkl"

if SAVE_INDEX:
    faiss.write_index(index, INDEX_PATH)
    with open(METADATA_PATH, "wb") as f:
        pickle.dump(index_to_docstore_id, f)

In [17]:
# Reload the persisted FAISS index and the FAISS-row -> passage-text mapping
# saved alongside it. Only unpickle files you produced yourself, because
# pickle.load can execute arbitrary code.
loaded_index = faiss.read_index(
    "finetune_llamindex_v2_faiss_index.bin"
)
with open(
    "finetune_llamindex_v2_docstore_metadata.pkl", mode="rb"
) as f:
    loaded_metadata = pickle.load(f)

In [None]:
# Evaluate retrieval on N randomly sampled test queries. For each query:
#   1. embed it,
#   2. fetch the top max(K_list) passages from the reloaded index,
#   3. score every cut-off K against that query's own candidate passages.
# Result layout: metrics[metric_name][K] -> list of per-query scores.
METRIC_FUNCTIONS = {
    'precision': contextual_precision,
    'recall': contextual_recall,
    'relevancy': contextual_relevancy,
}

N = 5000
K_list = [3, 5, 10, 100]
metrics = {name: {K: [] for K in K_list} for name in METRIC_FUNCTIONS}

np.random.seed(0)
# NOTE(review): np.random.choice samples with replacement by default, so some
# queries may be scored more than once. Kept as-is so results stay comparable
# with the saved metric pickles.
indices = np.random.choice(len(test_df), N)
max_k = max(K_list)  # do not rely on K_list being sorted

for i in tqdm(indices):
    # `i` is a row position, so index positionally.
    query_text = test_df['query'].iloc[i]
    reference_passages = test_df['passages'].iloc[i]['passage_text']

    query_embedding = embedding_function(query_text).numpy().astype('float32').reshape(1, -1)
    # Distinct name so the sampled `indices` being iterated is not shadowed.
    _, neighbor_ids = loaded_index.search(query_embedding, max_k)
    # FAISS pads with -1 when the index holds fewer than max_k vectors.
    retrieved_passages = [loaded_metadata[j] for j in neighbor_ids[0] if j >= 0]

    for K in K_list:
        top_k = retrieved_passages[:K]
        for name, metric_fn in METRIC_FUNCTIONS.items():
            metrics[name][K].append(metric_fn(top_k, reference_passages))

In [None]:
# Persist the per-query metrics so the summary cells below can reload them
# without re-running the evaluation. Set SAVE_METRICS = True to write the file.
# The filename embeds the sample size N (5000 in the evaluation cell).
SAVE_METRICS = False

if SAVE_METRICS:
    with open(f'metrics_{N}_finetuned_llama-index_v2.pkl', 'wb') as f:
        pickle.dump(metrics, f)

In [None]:
# Summary for the adapter fine-tuned on the base model (10 epochs of trainset-1).
# NOTE(review): the engine cell in this notebook uses epochs=5, so these numbers
# come from a separate run.
# Each cell is the mean per-query score; rows are cut-offs K, columns are metrics.
with open('metrics_5000_finetuned_llama-index_v2.pkl', 'rb') as f:
    results = pickle.load(f)

fin_results = pd.DataFrame({
    metric_name: {
        k: sum(scores) / len(scores)
        for k, scores in per_k.items()
    }
    for metric_name, per_k in results.items()
})
fin_results

Unnamed: 0,precision,recall,relevancy
3,0.70775,0.173215,0.575067
5,0.70002,0.257429,0.51284
10,0.662183,0.387006,0.38582
100,0.510808,0.689014,0.068714


In [None]:
# Summary for the adapter fine-tuned further from a checkpoint (10 epochs of trainset-2).
# Each cell is the mean per-query score; rows are cut-offs K, columns are metrics.
with open('metrics_5000_finetuned_llama-index_v3.pkl', 'rb') as f:
    results = pickle.load(f)

fin_results = pd.DataFrame({
    metric_name: {
        k: sum(scores) / len(scores)
        for k, scores in per_k.items()
    }
    for metric_name, per_k in results.items()
})
fin_results

Unnamed: 0,precision,recall,relevancy
3,0.710217,0.173821,0.577133
5,0.702145,0.257452,0.51284
10,0.663034,0.387201,0.386
100,0.510749,0.690078,0.06882
