In [None]:
import os
import time
import numpy as np
import pandas as pd
from llama_index.core import StorageContext, load_index_from_storage, QueryBundle
from llama_index.core.retrievers import VectorIndexRetriever
from llama_index.core.postprocessor import LLMRerank
from llama_index.embeddings.openai import OpenAIEmbedding
from langchain_core.prompts import ChatPromptTemplate
from langchain.chat_models import ChatOpenAI
from llama_index.llms.openai import OpenAI
from dotenv import load_dotenv
import tiktoken

# Tokenizer used to measure prompt sizes.
# NOTE(review): this is the gpt-4o encoding; the summarizer model configured
# below may tokenize differently, so counts approximate what that model sees.
tokenizer = tiktoken.encoding_for_model("gpt-4o")

def count_tokens(text: str) -> int:
    """Return the number of tokens in ``text`` under ``tokenizer``.

    Special-token markers (e.g. ``<|endoftext|>``) that occur literally in the
    text are counted as ordinary text instead of raising ``ValueError``, which
    is tiktoken's default for ``encode``. The prompts measured here embed
    retrieved document text whose content we do not control.
    """
    return len(tokenizer.encode(text, disallowed_special=()))

# Load environment variables (presumably the OpenAI credentials used by the
# clients below — confirm) from ../config.env when that file exists.
if os.path.exists("../config.env"):
    load_dotenv("../config.env")

# Load precomputed document summary embeddings plus the index that maps each
# embedding row to an on-disk vector store (column `vectorestore_path`).
embeddings = np.load("../data/processed/lp/summary_embeddings/embeddings.npy")
df = pd.read_csv("../data/processed/lp/summary_embeddings/index.tsv", sep="\t")

# retrieve() turns similarity positions over `embeddings` into rows of `df`,
# so both files must describe the same documents, one row each.
assert embeddings.ndim == 2, f"expected a 2-D embedding matrix, got shape {embeddings.shape}"
assert len(embeddings) == len(df), (
    f"embeddings.npy has {len(embeddings)} rows but index.tsv has {len(df)}; files are out of sync"
)
assert "vectorestore_path" in df.columns, "index.tsv is missing the 'vectorestore_path' column"

embedding_model = OpenAIEmbedding()

# LLM setup: gpt-4o for query expansion and LLM reranking,
# gpt-3.5-turbo for the final summary. Temperature 0 for repeatable outputs.
llm_expansion = ChatOpenAI(temperature=0.0, model="gpt-4o")
llm_summary = ChatOpenAI(temperature=0.0, model="gpt-3.5-turbo-0125")
llm_rerank = OpenAI(model="gpt-4o", temperature=0.0)
reranker = LLMRerank(llm=llm_rerank, top_n=3)  # LLMRerank returns at most 3 nodes per call

# Prompt template for query expansion
query_expansion_prompt = ChatPromptTemplate.from_messages([
    ("system", "You are an expert in HIV medicine."),
    ("user", (
        "Given the query below, provide a concise, comma-separated list of related terms and synonyms "
        "useful for document retrieval. Return only the list, no explanations.\n\n"
        "Query: {query}"
    ))
])

def cosine_similarity_numpy(query_vec: np.ndarray, matrix: np.ndarray) -> np.ndarray:
    """Cosine similarity between one query vector and every row of ``matrix``.

    Args:
        query_vec: 1-D vector of length ``d`` (a plain list is accepted too).
        matrix: 2-D array of shape ``(n, d)``.

    Returns:
        1-D float array of length ``n``. A row (or a query) with zero norm gets
        similarity 0 instead of NaN. Under ``argsort``, NaN sorts after every
        real score and would be picked as the best match.
    """
    query = np.asarray(query_vec, dtype=float)
    rows = np.asarray(matrix)

    dots = rows @ query
    denom = np.linalg.norm(rows, axis=1) * np.linalg.norm(query)
    # Divide only where the denominator is non-zero; elsewhere keep the 0 from `out`.
    return np.divide(dots, denom, out=np.zeros_like(dots), where=denom > 0)

def expand_query(query):
    """Ask the expansion LLM for retrieval-friendly related terms for ``query``.

    Returns:
        ``(expanded, seconds)``: the model's whitespace-stripped reply (the prompt
        asks for a comma-separated term list) and the wall-clock seconds spent.
    """
    t0 = time.time()
    reply = llm_expansion.invoke(query_expansion_prompt.format_messages(query=query))
    expanded_terms = reply.content.strip()
    return expanded_terms, time.time() - t0

def retrieve(expanded_query, top_k_docs=3, top_k_chunks=3, base_dir=".."):
    """Two-stage retrieval: choose documents by summary embedding, then chunks per document.

    Stage 1 embeds ``expanded_query`` and ranks every row of the precomputed
    summary ``embeddings`` by cosine similarity. The ``top_k_docs`` best rows
    select vector-store directories (``df["vectorestore_path"]`` joined onto
    ``base_dir``). Stage 2 loads each of those indexes and retrieves its
    ``top_k_chunks`` most similar nodes.

    Args:
        expanded_query: Text used both for the summary ranking and the chunk retrieval.
        top_k_docs: Number of documents (vector stores) to search; values <= 0 select none.
        top_k_chunks: Nodes retrieved from each selected vector store.
        base_dir: Directory the stored ``vectorestore_path`` values are relative to.

    Returns:
        ``(all_nodes, top_paths, seconds)``: nodes from every selected index,
        concatenated in document-rank order; the resolved store paths, best
        first; and wall-clock time including loading the indexes from disk.
    """
    start = time.time()
    query_vec = embedding_model.get_text_embedding(expanded_query)
    sims = cosine_similarity_numpy(query_vec, embeddings)

    # Rank once, best first. Slicing from the front (instead of `[-k:]`) keeps
    # k == 0 from selecting every document. Positional `iloc` matches the row
    # positions of `embeddings`, whatever labels `df` carries.
    ranked_rows = np.argsort(sims)[::-1][:max(0, top_k_docs)]
    top_paths = [
        os.path.join(base_dir, p)
        for p in df["vectorestore_path"].iloc[ranked_rows]
    ]

    all_nodes = []
    for path in top_paths:
        # Indexes are reloaded from disk on every call, so load time is part of
        # the measured retrieval time.
        ctx = StorageContext.from_defaults(persist_dir=path)
        index = load_index_from_storage(ctx)
        retriever = VectorIndexRetriever(index=index, similarity_top_k=top_k_chunks)
        all_nodes.extend(retriever.retrieve(expanded_query))

    return all_nodes, top_paths, time.time() - start

# def rerank(expanded_query, nodes):
#     start = time.time()
#     bundle = QueryBundle(expanded_query)
#     reranked = reranker.postprocess_nodes(nodes, bundle)
#     return reranked, time.time() - start

# def rerank(expanded_query, nodes):
#     start = time.time()
    
#     # Embed the expanded query
#     query_vec = embedding_model.get_text_embedding(expanded_query)
    
#     # Embed each node’s text
#     texts = [n.text for n in nodes]
#     node_vecs = embedding_model.get_text_embedding_batch(texts)
    
#     # Compute cosine similarities
#     sims = cosine_similarity_numpy(query_vec, np.array(node_vecs))
#     top_idxs = sims.argsort()[-2:][::-1]  # top 2 most similar
    
#     # Select top nodes
#     reranked = [nodes[i] for i in top_idxs]
#     return reranked, time.time() - start

def rerank(expanded_query, nodes, embedder, llm_reranker, top_n_cosine=5, top_n_llm=2):
    """
    Hybrid reranker:
    1. Use cosine similarity to pre-filter the top ``top_n_cosine`` nodes.
    2. Use the LLM reranker to select the top ``top_n_llm`` of those.

    Args:
        expanded_query: Query text used for embedding and for the LLM reranker.
        nodes: Candidate nodes; each must expose ``.text``.
        embedder: Object providing ``get_text_embedding`` and ``get_text_embedding_batch``.
        llm_reranker: Postprocessor providing ``postprocess_nodes(nodes, query_bundle)``.
        top_n_cosine: Nodes kept by the cosine pre-filter (values <= 0 keep none).
        top_n_llm: Maximum number of nodes returned after LLM reranking.

    Returns:
        ``(final_nodes, seconds)``. ``final_nodes`` is empty when there is nothing
        to rank. In that case neither the embedder nor the LLM is called.
    """
    start = time.time()

    # No candidates: skip the embedding/LLM round trips. An empty node list would
    # otherwise become a 1-D empty array that the row-wise similarity can't handle.
    if not nodes:
        return [], time.time() - start

    # Embed query and candidate nodes
    query_vec = embedder.get_text_embedding(expanded_query)
    node_vecs = embedder.get_text_embedding_batch([n.text for n in nodes])

    # Cosine pre-filter, best first. Slicing from the front (instead of `[-k:]`)
    # keeps k == 0 from selecting every node.
    sims = cosine_similarity_numpy(query_vec, np.array(node_vecs))
    top_cosine_idxs = np.argsort(sims)[::-1][:max(0, top_n_cosine)]
    prefiltered_nodes = [nodes[i] for i in top_cosine_idxs]
    if not prefiltered_nodes:
        return [], time.time() - start

    # LLM reranking; the reranker may return up to its own top_n, so trim here.
    bundle = QueryBundle(expanded_query)
    final_nodes = llm_reranker.postprocess_nodes(prefiltered_nodes, bundle)[:max(0, top_n_llm)]

    return final_nodes, time.time() - start


def summarize(query, contexts):
    """Summarize the given source texts for ``query`` with the summary LLM.

    Args:
        query: The question to answer; inserted verbatim into the prompt.
        contexts: Iterable of source-text strings, labelled "Source 1", "Source 2", ...

    Returns:
        ``(answer, seconds, input_tokens)``: the stripped model reply, wall-clock
        seconds spent (token counting included), and the prompt's token count
        from ``count_tokens``.

    NOTE(review): ``count_tokens`` may use a different tokenizer than the
    summary model, so ``input_tokens`` is likely approximate; confirm.
    """
    t0 = time.time()
    header = (
        "You're a clinical assistant helping a provider answer a question using HIV/AIDS guidelines.\n\n"
        f"Question: {query}\n\n"
        "Provide a detailed summary of the most relevant points from the following source texts using bullet points.\n\n"
    )
    numbered_sources = "\n\n".join(
        f"Source {idx}: {text}" for idx, text in enumerate(contexts, start=1)
    )
    prompt = header + numbered_sources
    n_tokens = count_tokens(prompt)
    answer = llm_summary.invoke(prompt).content.strip()
    return answer, time.time() - t0, n_tokens



In [None]:

# Example question list (replace/expand as needed)
queries = [
    "What are important drug interactions with dolutegravir?",
    "How should PrEP be provided to adolescent girls?",
    "When is cotrimoxazole prophylaxis indicated?",
    "What are the guidelines for ART failure?",
    "How do you manage HIV in pregnancy?",
    "When should infants start ART?",
    "What is the recommended PrEP regimen for men who have sex with men?",
    "How often should viral load be monitored?",
    "What is the preferred first-line regimen for adults?",
    "Can pregnant women use dolutegravir?",
    "When is tenofovir not recommended?",
    "How should HIV be managed in tuberculosis coinfection?",
    "What lab tests are used to monitor ART?",
    "When is second-line ART initiated?",
    "What adherence strategies are recommended?",
    "What are the contraindications to efavirenz?",
    "Can HIV be managed with a two-drug regimen?",
    "How do you handle treatment failure?",
    "When is regimen switching appropriate?",
    "What is the role of resistance testing?"
]

# Single source of truth for the output file, so the completion message below
# always names the file that was actually written.
TIMING_CSV = "timing_benchmark_with_rerank_gpt35_hybridrerank.csv"

timing_results = []

for q in queries:
    print(f"\n⏳ {q}")

    expanded, t_expand = expand_query(q)
    nodes, top_paths, t_retrieve = retrieve(expanded)
    reranked_nodes, t_rerank = rerank(
        expanded_query=expanded,
        nodes=nodes,
        embedder=embedding_model,
        llm_reranker=reranker,
        top_n_cosine=5,
        top_n_llm=2
    )
    answer, t_summarize, token_count = summarize(q, [n.text for n in reranked_nodes])

    # Per-stage times are rounded for readability; the total is summed from the
    # unrounded values, so it may differ slightly from the sum of the columns.
    timing_results.append({
        "question": q,
        "expand_time": round(t_expand, 2),
        "retrieve_time": round(t_retrieve, 2),
        "rerank_time": round(t_rerank, 2),
        "summarize_time": round(t_summarize, 2),
        "input_tokens_to_summarizer": token_count,
        "total_time": round(t_expand + t_retrieve + t_rerank + t_summarize, 2)
    })


# Save results under a dedicated name: `df` holds the summary index that
# retrieve() reads, so overwriting it would break any later run of the pipeline.
timing_df = pd.DataFrame(timing_results)
timing_df.to_csv(TIMING_CSV, index=False)

print(f"\n✅ Done. Timing results saved to {TIMING_CSV}")
timing_df
