In [None]:
import pandas as pd
from sentence_transformers import SentenceTransformer
import faiss
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
# Load the ECG reference knowledge base; each row's "Content" cell is one
# retrievable passage.
df = pd.read_csv("/data/ECG_Reference_Level_Knowledge_Base.csv")

# A blank "Content" cell is read as float NaN, which the sentence encoder
# cannot embed. Drop such rows and renumber so that row positions in `df`,
# positions in `texts`, and the ids later assigned by the vector index agree.
df = df.dropna(subset=["Content"]).reset_index(drop=True)
texts = df["Content"].tolist()

In [None]:
# Sentence encoder used for both the knowledge-base passages and, later,
# the user queries.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# One vector per passage, returned as a NumPy array. Normalisation is
# requested so that inner-product search over these vectors ranks by
# cosine similarity.
embeddings = embedder.encode(
    texts,
    convert_to_numpy=True,
    normalize_embeddings=True,
)

In [6]:
# Width of a single embedding vector (second axis of the embeddings array).
embedding_dim = embeddings.shape[1]

# Exact (brute-force) inner-product index; passages are added in the same
# order as `texts`, so the ids returned by a search are row positions.
index = faiss.IndexFlatIP(embedding_dim)
index.add(embeddings)

In [7]:
def retrieve_context(query, top_k=2):
    """Return the knowledge-base passages most similar to ``query``.

    Args:
        query: Question text to embed and search with.
        top_k: Maximum number of passages to return. Values larger than the
            number of indexed passages are clamped.

    Returns:
        The retrieved passages joined by a blank line, best match first, or
        an empty string when nothing can be retrieved (empty index or
        ``top_k`` <= 0).
    """
    # Never ask the index for more neighbours than it holds: FAISS pads the
    # missing slots with id -1, and a positional lookup of -1 would silently
    # return the *last* row instead of "no result".
    k = min(int(top_k), index.ntotal)
    if k <= 0:
        return ""

    q_emb = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    _scores, idx = index.search(q_emb, k)

    # Select the column once, then pick rows by position; ids < 0 are
    # filtered defensively for the same padding reason as above.
    contents = df["Content"]
    results = [contents.iloc[int(i)] for i in idx[0] if i >= 0]
    return "\n\n".join(results)

In [None]:
# Load the generation model: a seq2seq checkpoint plus its matching tokenizer.
gen_model_name = "google/flan-t5-base"
gen_model = AutoModelForSeq2SeqLM.from_pretrained(gen_model_name)
gen_tokenizer = AutoTokenizer.from_pretrained(gen_model_name)

# Put the model in evaluation mode; this notebook only generates with it.
gen_model.eval()

In [None]:
#Answer function
def rag_answer(query, top_k=2, max_new_tokens=200, max_input_tokens=None):
    """Answer an ECG question using retrieved reference text as context.

    Args:
        query: The user's question.
        top_k: Number of passages to request from ``retrieve_context``.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_input_tokens: Upper bound on the tokenized prompt length.
            ``None`` (default) uses the tokenizer's ``model_max_length``.

    Returns:
        The decoded model output with surrounding whitespace stripped.
    """
    # retrieve relevant ECG text
    context = retrieve_context(query, top_k=top_k)

    def _build_prompt(ctx):
        # Same prompt wording for every call; only the context slot varies.
        return f"""
You are an expert in ECG interpretation.
Use the following reference information to answer the question clearly and accurately.

Context:
{ctx}

Question: {query}
Answer clearly with reasoning:
"""

    # The tokenizer truncates from the right, and the question plus the
    # answer instruction sit at the END of the prompt. A long context would
    # therefore push the question out of the model input. Shrink only the
    # context so that the fixed part of the prompt always fits.
    limit = gen_tokenizer.model_max_length if max_input_tokens is None else max_input_tokens
    scaffold_len = len(gen_tokenizer(_build_prompt(""))["input_ids"])
    budget = max(limit - scaffold_len, 0)
    context_ids = gen_tokenizer(context, add_special_tokens=False)["input_ids"]
    if len(context_ids) > budget:
        context = gen_tokenizer.decode(context_ids[:budget], skip_special_tokens=True)

    # build prompt for the LLM
    prompt = _build_prompt(context)

    # generate answer
    # truncation stays on as a safety net: token counts of the pieces are not
    # guaranteed to add up exactly to the token count of the full prompt.
    inputs = gen_tokenizer(
        prompt, return_tensors="pt", truncation=True, max_length=max_input_tokens
    )
    # Keep inputs on the model's device so this also works after the model
    # has been moved to an accelerator.
    inputs = inputs.to(gen_model.device)
    with torch.no_grad():
        outputs = gen_model.generate(**inputs, max_new_tokens=max_new_tokens)
    answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return answer.strip()

In [None]:
# Demo: run one sample question through the full retrieve-then-generate pipeline.
question = "what is the normal interval between the P wave and the QRS complex on an ECG?"

print("\n Question:", question)
print("\n Answer:\n")
answer = rag_answer(question)
print(answer)