In [109]:
import json
import os
from typing import Any

import numpy as np
import pandas as pd
from dotenv import load_dotenv
from langchain.vectorstores import FAISS
from langchain_core.documents import Document
from langchain_openai import OpenAIEmbeddings, OpenAI
from pinecone import Pinecone
from pinecone import ServerlessSpec
from sentence_transformers import CrossEncoder
from tqdm.auto import tqdm

# Load variables from a local .env file so the API keys read below are available.
load_dotenv()

True

In [110]:
# The OpenAI key is required by every section of this notebook, so fail here
# (KeyError) if it is missing instead of later inside the embeddings client.
OPENAI_API_KEY = os.environ["OPENAI_API_KEY"]
# The Pinecone key is only needed by the Pinecone section; None when unset.
PINECONE_API_KEY = os.getenv("PINECONE_API_KEY")

# Query-side embedder. The model is pinned to the same one used to embed the
# documents so that query vectors and stored vectors stay comparable even if
# the library default changes.
EMBEDDINGS = OpenAIEmbeddings(model="text-embedding-ada-002", api_key=OPENAI_API_KEY)

In [111]:
# Load the chunked corpus. The file is decoded explicitly as UTF-8 because the
# text contains non-ASCII characters (e.g. curly quotes) and the platform
# default encoding is not guaranteed to be UTF-8.
with open('documents_with_ids.json', 'r', encoding='utf-8') as f_in:
    documents = json.load(f_in)

In [112]:
# Keep only the slices of the corpus that were used to create the ground truth.
# NOTE: this rebinds `documents`, so running the cell twice slices the already
# reduced list; reload the JSON first if the cell has to be re-run.
documents = [
    doc
    for start, stop in ((2100, 2150), (4100, 4200), (1100, 1150), (3100, 3150))
    for doc in documents[start:stop]
]

In [114]:
# Show the first sampled record to check its structure.
documents[0]

{'page_content': 'CHAPTER 5   THE CARdiAC ExAminATion 83\nRate of pulse\nPractised observers can estimate the rate quickly. Formal \ncounting over 30 seconds is accurate and requires only \nsimple mathematics to obtain the rate per minute. The \nnormal resting heart rate in adults is usually said to be \nbetween 60 and 100 beats per minute but a more \nsensible range is probably 55 to 95 (95% of normal \npeople). Bradycardia (from the Greek bradys ‘slow’ , \nkardia ‘heart’) is defined as a heart rate of less than 60 \nbeats per minute. Tachycardia (from the Greek tachys \n‘swift’ , kardia ‘heart’) is defined as a heart rate over 100 \nbeats per minute (see the OSCE ECGs nos 2, 3 and 4 \nat ). The causes of bradycardia and \ntachycardia are listed in Table 5.1.\nRhythm\nThe rhythm of the pulse can be regular or irregular. An \nirregular rhythm can be completely irregular with no \npattern (irregularly irregular or chaotic rhythm); this is \nusually due to atrial fibrillation (see Table 

In [177]:
def create_chunk_embedding(
    documents,
    index_path="../embeddings/faiss_index_openai",
    model="text-embedding-ada-002",
):
    """Embed the documents with OpenAI, build a FAISS index and save it to disk.

    Args:
        documents: Sequence of dicts, each with a "page_content" and an "id" key.
        index_path: Directory the index is written to with ``save_local``.
        model: Name of the OpenAI embedding model used to embed the documents.

    Returns:
        dict: ``{"openai_index": index_path}``, the location of the saved index.

    Raises:
        ValueError: If ``documents`` is empty (there is nothing to index).
    """
    processed_docs = [
        Document(page_content=doc["page_content"], metadata={"id": doc["id"]})
        for doc in documents
    ]
    if not processed_docs:
        raise ValueError("documents is empty; cannot build a FAISS index")

    # The id is kept in the metadata so retrieved hits can be matched against
    # the ground-truth document id during evaluation.
    openai_embeddings = OpenAIEmbeddings(model=model)
    faiss_index_openai = FAISS.from_documents(processed_docs, openai_embeddings)
    faiss_index_openai.save_local(index_path)

    return {"openai_index": index_path}

In [178]:
# Build and save the FAISS index for the sampled documents (calls the OpenAI embeddings API).
index_path_one = create_chunk_embedding(documents)

### Load and query the FAISS index

In [179]:
def load_faiss_index(index_path, model="text-embedding-ada-002"):
    """Load a FAISS index that was previously written with ``save_local``.

    Args:
        index_path: Directory containing the saved index.
        model: OpenAI embedding model used to embed queries. It should be the
            model the index was built with; otherwise query vectors and stored
            vectors are not comparable.

    Returns:
        The loaded FAISS vector store.
    """
    embeddings = OpenAIEmbeddings(model=model)
    # SECURITY: allow_dangerous_deserialization=True permits unpickling the
    # saved index data, and unpickling can execute arbitrary code. Only load
    # index directories created by this notebook or another trusted source.
    return FAISS.load_local(
        index_path,
        embeddings,
        allow_dangerous_deserialization=True,
    )

In [180]:
def query_faiss_index(faiss_index, query, k=20):
    """Run a similarity search for ``query["question"]`` and return the top ``k`` hits.

    Args:
        faiss_index: Object exposing ``similarity_search(text, k=...)``.
        query: Mapping that contains the search text under the "question" key.
        k: Number of results to request.

    Returns:
        Whatever ``similarity_search`` returns for the question text.
    """
    return faiss_index.similarity_search(query["question"], k=k)

In [181]:
# Reload the saved index from disk; this is the same path create_chunk_embedding writes to.
first_index = load_faiss_index("../embeddings/faiss_index_openai")

In [182]:
# Ground-truth questions; each row names the id of the document it was generated from.
df_ground_truth = pd.read_csv('questions.csv')

In [183]:
# One dict per question, which is the shape the evaluation loops below iterate over.
ground_truth = df_ground_truth.to_dict(orient='records')

In [184]:
# Show one ground-truth record to check its keys.
ground_truth[0]

{'question': 'Assessment of bradycardia or tachycardia',
 'case_prompt': 'A 68-year-old male presents to the clinic with complaints of lightheadedness and palpitations for the past week.',
 'document': 'e84c82d5'}

### Code to evaluate retrieval

In [272]:
def hit_rate(relevance_total):
    """Fraction of queries whose result list contains at least one relevant hit.

    Args:
        relevance_total: One sequence of booleans per query; each flag says
            whether the result at that rank is the expected document.

    Returns:
        float: Hit rate in [0, 1]. Returns 0.0 for an empty evaluation set
        instead of raising ZeroDivisionError.
    """
    total = len(relevance_total)
    if total == 0:
        return 0.0

    hits = sum(1 for line in relevance_total if any(line))
    return hits / total

In [273]:
def mrr(relevance_total):
    """Mean reciprocal rank of the first relevant result per query.

    Args:
        relevance_total: One sequence of booleans per query; each flag says
            whether the result at that rank is the expected document.

    Returns:
        float: Mean of ``1 / rank`` of the first relevant hit (0 for queries
        without a hit), in [0, 1]. Returns 0.0 for an empty evaluation set.
    """
    total = len(relevance_total)
    if total == 0:
        return 0.0

    total_score = 0.0
    for line in relevance_total:
        for rank, is_relevant in enumerate(line, start=1):
            if is_relevant:
                total_score += 1 / rank
                # Only the first relevant hit counts; without this break a
                # document id that appears more than once in the results would
                # add several reciprocal ranks and push the score above 1.
                break

    return total_score / total

### Evaluate FAISS retrieval

In [218]:
def evaluate(ground_truth):
    """Score plain FAISS retrieval against the ground truth.

    For every record the top 30 hits for its question are fetched from
    ``first_index`` and compared with the record's expected "document" id.

    Args:
        ground_truth: Iterable of dicts with "question" and "document" keys.

    Returns:
        dict: ``{"hit_rate": ..., "mrr": ...}`` over all records.
    """
    per_query_relevance = []

    for record in tqdm(ground_truth):
        expected_id = record['document']
        hits = query_faiss_index(first_index, record, k=30)
        per_query_relevance.append(
            [hit.metadata["id"] == expected_id for hit in hits]
        )

    metrics = {}
    metrics['hit_rate'] = hit_rate(per_query_relevance)
    metrics['mrr'] = mrr(per_query_relevance)
    return metrics

In [219]:
# Baseline: semantic search on the FAISS index, no reranking.
evaluate(ground_truth)

100%|██████████| 250/250 [03:28<00:00,  1.20it/s]


{'hit_rate': 0.92, 'mrr': 0.5351691156908999}

### FAISS retrieval with reranking evaluation

In [274]:
# Cross-encoder used by rerank_documents to score (query, passage) pairs.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [278]:
def rerank_documents(query, retrieved_docs):
    """Re-order retrieved documents by cross-encoder relevance to ``query``.

    Args:
        query: The question text.
        retrieved_docs: Iterable of documents exposing ``page_content`` and a
            ``metadata`` dict (as returned by the FAISS similarity search).

    Returns:
        list: The documents that have text, sorted from most to least relevant
        according to ``reranker.predict``. Documents whose ``page_content`` is
        None are reported and left out. Returns ``[]`` when nothing can be scored.
    """
    # Collect the scorable documents first so that scores[i] always belongs to
    # scorable_docs[i]. Zipping the scores with the unfiltered input would shift
    # every score after a skipped document onto the wrong document.
    scorable_docs = []
    for doc in retrieved_docs:
        if doc.page_content is None:
            print(f"Warning: Missing 'page_content' in metadata for document ID {doc.metadata.get('id', 'Unknown')}")
            continue
        scorable_docs.append(doc)

    if not scorable_docs:
        print("No valid documents found for reranking.")
        return []

    scores = reranker.predict([(query, doc.page_content) for doc in scorable_docs])

    # Sort on the score only; documents themselves are not comparable.
    ranked = sorted(zip(scores, scorable_docs), key=lambda pair: pair[0], reverse=True)
    return [doc for _, doc in ranked]


In [279]:
def evaluate_with_reranking(ground_truth):
    """Score FAISS retrieval followed by cross-encoder reranking.

    For every record the top 30 FAISS hits for its question are reranked with
    ``rerank_documents`` and then compared with the expected "document" id.

    Args:
        ground_truth: Iterable of dicts with "question" and "document" keys.

    Returns:
        dict: ``{"hit_rate": ..., "mrr": ...}`` over all records.
    """
    per_query_relevance = []

    for record in tqdm(ground_truth):
        expected_id = record['document']
        question = record['question']

        candidates = query_faiss_index(first_index, record, k=30)
        ordered = rerank_documents(question, candidates)

        per_query_relevance.append(
            [hit.metadata["id"] == expected_id for hit in ordered]
        )

    metrics = {}
    metrics['hit_rate'] = hit_rate(per_query_relevance)
    metrics['mrr'] = mrr(per_query_relevance)
    return metrics

In [280]:
# FAISS retrieval + cross-encoder reranking; compare with the baseline metrics above.
evaluate_with_reranking(ground_truth)

100%|██████████| 250/250 [18:27<00:00,  4.43s/it]


{'hit_rate': 0.92, 'mrr': 0.6199280855376507}

## Evaluate Pinecone Retrieval

In [134]:
# Embedding client used by generate_embeddings to embed the documents for Pinecone.
embeddings_model = OpenAIEmbeddings(model="text-embedding-ada-002")

def generate_embeddings(documents):
    """Embed the "page_content" of every document with ``embeddings_model``.

    Args:
        documents: Iterable of dicts that each contain a "page_content" key.

    Returns:
        The result of ``embeddings_model.embed_documents`` for the collected
        texts, in input order.
    """
    return embeddings_model.embed_documents(
        [record["page_content"] for record in documents]
    )

# Embed the sampled documents; these vectors are what gets upserted to Pinecone below.
chunked_document_embeddings = generate_embeddings(documents)

In [135]:
# Sanity check: there should be one embedding per sampled document.
print(f"Generated {len(chunked_document_embeddings)} embeddings.")

Generated 250 embeddings.


In [136]:
def combine_vector_and_text(documents: list[dict], doc_embeddings: list[list[float]]) -> list[dict]:
    """Pair each document with its embedding in the record shape used for upserting.

    Args:
        documents: Dicts with optional "id" and "page_content" keys.
        doc_embeddings: One embedding per document, in the same order.

    Returns:
        list[dict]: One ``{"id", "values", "metadata"}`` record per document.
        The metadata repeats the id and carries the text so query results can
        be matched and reranked without a second lookup. A missing id falls
        back to "unknown_id" and missing text to "".

    Raises:
        ValueError: If the two inputs differ in length. ``zip`` would otherwise
            silently drop the unmatched tail.
    """
    if len(documents) != len(doc_embeddings):
        raise ValueError(
            f"Got {len(documents)} documents but {len(doc_embeddings)} embeddings"
        )

    data_with_metadata = []
    for doc, embedding in zip(documents, doc_embeddings):
        doc_id = str(doc.get("id", "unknown_id"))
        data_with_metadata.append(
            {
                "id": doc_id,
                "values": embedding,
                "metadata": {"page_content": doc.get("page_content", ""), "id": doc_id},
            }
        )

    return data_with_metadata


In [137]:
# Build the upsert records from the sampled documents and their embeddings.
data = combine_vector_and_text(documents=documents, doc_embeddings=chunked_document_embeddings) 

In [161]:
pc = Pinecone(api_key=PINECONE_API_KEY)

# Only create the index when it does not exist yet: create_index fails for an
# existing name, which would break every re-run of this cell. The return value
# of create_index is deliberately not bound to `index`; the index handle is
# obtained separately with pc.Index("final").
if "final" not in pc.list_indexes().names():
    pc.create_index(
        name="final",
        dimension=1536,  # vector size of text-embedding-ada-002
        metric="cosine",
        spec=ServerlessSpec(
            cloud='aws',
            region='us-east-1',
        ),
    )

In [162]:
# Handle to the "final" index; the upsert and query helpers below use this global.
index = pc.Index("final")

In [163]:
def upsert_data_to_pinecone(data_with_metadata: list[dict[str, Any]], chunk_size: int = 1000) -> None:
    """Upsert the records into the module-level ``index`` in batches.

    Args:
        data_with_metadata: Records with "id", "values" and "metadata" keys.
        chunk_size: Maximum number of records sent per ``index.upsert`` call.

    Raises:
        ValueError: If ``chunk_size`` is smaller than 1. A negative step would
            otherwise yield an empty range and silently upload nothing.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be a positive integer")

    for start in range(0, len(data_with_metadata), chunk_size):
        index.upsert(vectors=data_with_metadata[start:start + chunk_size])


# Upload all records to the index with the default batch size.
upsert_data_to_pinecone(data_with_metadata= data)

In [164]:
def get_query_embeddings(query: str) -> list[float]:
    """Embed a single query string with the module-level ``EMBEDDINGS`` client.

    Args:
        query: The question text.

    Returns:
        The vector produced by ``EMBEDDINGS.embed_query``.
    """
    return EMBEDDINGS.embed_query(query)

In [192]:
def query_pinecone_index(
    query_embeddings: list[float], top_k: int = 20, include_metadata: bool = True
) -> dict[str, Any]:
    """Query the module-level ``index`` with an already embedded question.

    Args:
        query_embeddings: The query vector.
        top_k: Number of matches to request.
        include_metadata: Whether each match should carry its stored metadata;
            the evaluation code reads the document id from that metadata.

    Returns:
        The response of ``index.query``; callers read its "matches" entry.
    """
    return index.query(
        vector=query_embeddings, top_k=top_k, include_metadata=include_metadata
    )

In [197]:
def evaluate_pinecone(ground_truth):
    """Score plain Pinecone retrieval against the ground truth.

    For every record the question is embedded, the top 30 matches are fetched
    and each match's metadata id is compared with the expected "document" id.

    Args:
        ground_truth: Iterable of dicts with "question" and "document" keys.

    Returns:
        dict: ``{"hit_rate": ..., "mrr": ...}`` over all records.
    """
    per_query_relevance = []

    for record in tqdm(ground_truth):
        expected_id = record['document']
        question = record['question']
        query_vector = get_query_embeddings(question)
        response = query_pinecone_index(query_vector, top_k=30)

        per_query_relevance.append(
            [hit["metadata"]["id"] == expected_id for hit in response["matches"]]
        )

    metrics = {}
    metrics['hit_rate'] = hit_rate(per_query_relevance)
    metrics['mrr'] = mrr(per_query_relevance)
    return metrics

In [196]:
# Baseline: Pinecone vector search, no reranking.
evaluate_pinecone(ground_truth)

100%|██████████| 250/250 [03:55<00:00,  1.06it/s]


{'hit_rate': 0.92, 'mrr': 0.5272179545726603}

### Document reranking to improve Pinecone retrieval

In [200]:
# Same cross-encoder model as in the FAISS reranking section, re-created so this section can run on its own.
reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")

In [290]:
def rerank_documents(query, retrieved_docs):
    """Re-order Pinecone matches by cross-encoder relevance to ``query``.

    NOTE: this definition replaces the earlier ``rerank_documents`` that works
    on FAISS documents. After this cell has run, ``evaluate_with_reranking``
    can no longer be used unless the earlier definition is executed again.

    Args:
        query: The question text.
        retrieved_docs: Query response whose "matches" entry is a sequence of
            matches, each with a "metadata" mapping holding "page_content".

    Returns:
        list: The matches that have text, sorted from most to least relevant
        according to ``reranker.predict``. Matches without "page_content" are
        reported and left out instead of being sent to the cross-encoder as
        ``None``. Returns ``[]`` when nothing can be scored.
    """
    pairs = []
    scorable_matches = []

    for match in retrieved_docs["matches"]:
        text = match["metadata"].get("page_content")
        if text is None:
            print(f"Warning: Missing 'page_content' in metadata for document ID {match['metadata'].get('id', 'Unknown')}")
            continue

        # Both lists grow together, so scores[i] always belongs to scorable_matches[i].
        pairs.append((query, text))
        scorable_matches.append(match)

    if not pairs:
        print("No valid documents found for reranking.")
        return []

    scores = reranker.predict(pairs)

    # Sort on the score only; the matches themselves are not comparable.
    ranked = sorted(zip(scores, scorable_matches), key=lambda pair: pair[0], reverse=True)
    return [match for _, match in ranked]

In [None]:
def evaluate_pinecone_reranking(ground_truth):
    """Score Pinecone retrieval followed by cross-encoder reranking.

    For every record the question is embedded, the top 30 matches are fetched,
    reranked with ``rerank_documents`` and compared with the expected
    "document" id.

    Args:
        ground_truth: Iterable of dicts with "question" and "document" keys.

    Returns:
        dict: ``{"hit_rate": ..., "mrr": ...}`` over all records.
    """
    per_query_relevance = []

    for record in tqdm(ground_truth):
        expected_id = record['document']
        question = record['question']

        query_vector = get_query_embeddings(question)
        response = query_pinecone_index(query_vector, top_k=30)
        ordered = rerank_documents(question, response)

        per_query_relevance.append(
            [hit["metadata"]["id"] == expected_id for hit in ordered]
        )

    metrics = {}
    metrics['hit_rate'] = hit_rate(per_query_relevance)
    metrics['mrr'] = mrr(per_query_relevance)
    return metrics



In [217]:
# Pinecone retrieval + cross-encoder reranking; compare with the Pinecone baseline above.
evaluate_pinecone_reranking(ground_truth)

100%|██████████| 250/250 [17:55<00:00,  4.30s/it]


{'hit_rate': 0.92, 'mrr': 0.6199398867798867}