# Compare Embedding Function

## Create collections

In [1]:
import chromadb
from chromadb.utils.embedding_functions.sentence_transformer_embedding_function import SentenceTransformerEmbeddingFunction
from chromadb.utils.embedding_functions.ollama_embedding_function import OllamaEmbeddingFunction
from sentence_transformers import SentenceTransformer

client = chromadb.Client()

all_MiniLM_L6_v2_tuned = client.get_or_create_collection(
    name="all_MiniLM_L6_v2_tuned", 
    embedding_function=SentenceTransformerEmbeddingFunction(
        "sentence-transformers/all-MiniLM-L6-v2", device="cuda")
)

In [2]:
# Function to encode query using the adapter
def encode_query(query, base_model, adapter):
    device = next(adapter.parameters()).device
    query_emb = base_model.encode(query, convert_to_tensor=True).to(device)
    adapted_query_emb = adapter(query_emb)
    return adapted_query_emb.cpu().detach().numpy()

## Define evaluation functions

Copied from https://github.com/ALucek/linear-adapter-embedding/blob/main/Linear_Adapter.ipynb

In [3]:
def reciprocal_rank(retrieved_docs, ground_truth, k):
    try:
        rank = retrieved_docs.index(ground_truth) + 1
        return 1.0 / rank if rank <= k else 0.0
    except ValueError:
        return 0.0


def hit_rate(retrieved_docs, ground_truth, k):
    return 1.0 if ground_truth in retrieved_docs[:k] else 0.0


## Prepare dataset

In [4]:
from utils import get_train_test

train_cq, test_cq = get_train_test("../../data/chunk_question_pairs")
all_chunks = train_cq['chunk'].tolist()
all_chunks.extend(test_cq['chunk'].tolist())

## Embed chunks

In [5]:
def retrieve_documents_embeddings(collection, query_embedding, k=10):
    query_embedding_list = query_embedding.tolist()
    
    results = collection.query(
        query_embeddings=[query_embedding_list],
        n_results=k)
    return results['documents'][0]

In [6]:
from tqdm import tqdm
def insert_documents(collection, all_chunks):
    i = 0
    for chunk in tqdm(all_chunks):
        collection.add(
        documents=[chunk],
        ids=[f"chunk_{i}"]
        )
        i += 1

In [7]:
insert_documents(all_MiniLM_L6_v2_tuned, all_chunks)

100%|██████████| 24000/24000 [03:16<00:00, 122.01it/s]


## Evaluate result

In [8]:
import numpy as np

def evaluate_adapter(validation_data, collection, base_model, adapter, k=10):
    hit_rates = []
    reciprocal_ranks = []
    
    for _, row in validation_data.iterrows():
        question = row['question']
        ground_truth = row['chunk']
        
        # Generate embedding for the question
        question_embedding = encode_query(question, base_model, adapter)
        # Retrieve documents using the embedding
        retrieved_docs = retrieve_documents_embeddings(collection, question_embedding, k)
        
        # Calculate metrics
        hr = hit_rate(retrieved_docs, ground_truth, k)
        rr = reciprocal_rank(retrieved_docs, ground_truth, k)
        
        hit_rates.append(hr)
        reciprocal_ranks.append(rr)
    
    # Calculate average metrics
    avg_hit_rate = np.mean(hit_rates)
    avg_reciprocal_rank = np.mean(reciprocal_ranks)
    
    return {
        'average_hit_rate': avg_hit_rate,
        'average_reciprocal_rank': avg_reciprocal_rank
    }

In [None]:

# Later, loading and using the saved information
import torch
from utils import LinearAdapter


loaded_dict = torch.load('../../models/mc_emb/adapters/linear_adapter_20.pth')

base_model = SentenceTransformer('all-MiniLM-L6-v2')

# Recreate the adapter
loaded_adapter = LinearAdapter(base_model.get_sentence_embedding_dimension())  # Initialize with appropriate parameters
loaded_adapter.load_state_dict(loaded_dict['adapter_state_dict'])

# Access the training parameters
training_params = loaded_dict['adapter_kwargs']

print("Adapter loaded successfully.")
print("Training parameters used:")
for key, value in training_params.items():
    print(f"{key}: {value}")

In [None]:
results = evaluate_adapter(test_cq, all_MiniLM_L6_v2_tuned, base_model, loaded_adapter, k=10)
print(f"Average Hit Rate @10: {results['average_hit_rate']}")
print(f"Mean Reciprocal Rank @10: {results['average_reciprocal_rank']}")

NameError: name 'base_model' is not defined