<a href="https://colab.research.google.com/github/Gitstrong3333/MachineLearning_Python-/blob/main/RAG_MedicalAssistant.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
import pandas as pd

file_path = '/content/drive/MyDrive/medquad.csv'  # adjust if it's in a subfolder
df = pd.read_csv(file_path)

df.head()

Unnamed: 0,question,answer,source,focus_area
0,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma
1,What causes Glaucoma ?,"Nearly 2.7 million people have glaucoma, a lea...",NIHSeniorHealth,Glaucoma
2,What are the symptoms of Glaucoma ?,Symptoms of Glaucoma Glaucoma can develop in ...,NIHSeniorHealth,Glaucoma
3,What are the treatments for Glaucoma ?,"Although open-angle glaucoma cannot be cured, ...",NIHSeniorHealth,Glaucoma
4,What is (are) Glaucoma ?,Glaucoma is a group of diseases that can damag...,NIHSeniorHealth,Glaucoma


In [None]:
df.describe()


Unnamed: 0,question,answer,source,focus_area
count,16412,16407,16412,16398
unique,14984,15817,9,5126
top,What causes Causes of Diabetes ?,This condition is inherited in an autosomal re...,GHR,Breast Cancer
freq,20,348,5430,53


In [None]:
import sys
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
import argparse
from sklearn.model_selection import train_test_split
import torch.optim as optim
from tqdm.auto import tqdm


class RAGDataset(Dataset):
    """
    Custom Dataset class for loading query-context-answer pairs for training the RAG model.
    This class handles tokenizing the data and preparing it for PyTorch's DataLoader.
    """
    def __init__(self, dataframe, tokenizer, source_len, target_len):
        """
        Initialize the dataset.

        Args:
            dataframe (pd.DataFrame): The dataset containing query, context, and answer columns.
            tokenizer (transformers.PreTrainedTokenizer): Tokenizer for encoding text.
            source_len (int): Maximum length for the input sequence.
            target_len (int): Maximum length for the target sequence.
        """
        self.data = dataframe.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.source_len = source_len
        self.target_len = target_len
        self.query = self.data['query']
        self.context = self.data['context']
        self.answer = self.data['answer']

    def __len__(self):
        """
        Returns:
            int: The number of samples in the dataset.
        """
        return len(self.data)

    def __getitem__(self, idx):
        """
        Retrieve a single data point from the dataset.

        Args:
            idx (int): Index of the data point.

        Returns:
            dict: Dictionary containing tokenized input and target sequences.
        """
        query = str(self.query[idx])
        context = str(self.context[idx])
        answer = str(self.answer[idx])

        # combine query and context into a single input string
        source_text = f"query: {query} context: {context}"

        # tokenize the input string
        source = self.tokenizer.encode_plus(
            source_text, max_length=self.source_len, padding="max_length", truncation=True, return_tensors="pt"
        )
        # tokenize the answer string
        target = self.tokenizer.encode_plus(
            answer, max_length=self.target_len, padding="max_length", truncation=True, return_tensors="pt"
        )

        return {
            "input_ids": source["input_ids"].squeeze(),
            "attention_mask": source["attention_mask"].squeeze(),
            "labels": target["input_ids"].squeeze(),
        }


def preprocess_data(file_path):
    """
    Preprocess the dataset to include required columns and handle missing values.

    Args:
        file_path (str): Path to the dataset CSV file.

    Returns:
        pd.DataFrame: Preprocessed dataframe with 'query', 'context', and 'answer' columns.
    """
    # load the CSV file
    df = pd.read_csv(file_path)

    # retain only the 'question' and 'answer' columns
    df = df[['question', 'answer']]

    # drop rows with missing values
    df = df.dropna(subset=['question', 'answer'])

    # rename columns for consistency
    df = df.rename(columns={'question': 'query', 'answer': 'answer'})

    # add a 'context' column (using the answer as context for now)
    df['context'] = df['answer']
    return df


def train_epoch(model, loader, optimizer, device, epoch, logging_steps):
    """
    Train the model for one epoch.

    Args:
        model (torch.nn.Module): The model being trained.
        loader (DataLoader): DataLoader for the training data.
        optimizer (torch.optim.Optimizer): Optimizer for updating model weights.
        device (torch.device): Device to run the model on (CPU, GPU, etc.).
        epoch (int): Current epoch number.
        logging_steps (int): Frequency of logging progress during training.

    Returns:
        float: The average training loss for the epoch.
    """
    model.train()  # set the model to training mode
    total_loss = 0  # initialize total loss
    progress_bar = tqdm(loader, desc=f"Epoch {epoch}", disable=False)  # progress bar for tracking

    for step, batch in enumerate(progress_bar):
        # move inputs and labels to the specified device
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # forward pass through the model
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        # backward pass and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log the loss every `logging_steps`
        if (step + 1) % logging_steps == 0:
            progress_bar.set_postfix({"loss": loss.item()})

    # return the average loss for the epoch
    return total_loss / len(loader)


def main():
    """
    Main function to fine-tune the T5 model for Retrieval-Augmented Generation (RAG).
    This version handles Jupyter Notebook's extra arguments gracefully.
    """
    # simulating command-line arguments for Jupyter Notebook
    class Args:
        model_name = "t5-base"
        train_file = "/content/drive/MyDrive/medquad.csv"
        output_dir = "rag_model"
        batch_size = 8
        epochs = 3
        lr = 5e-5
        max_input_length = 512
        max_output_length = 150
        device = "cuda"
        logging_steps = 10

    args = Args()  # use the custom Args class to store arguments

    # enhanced device selection logic
    if args.device == "mps" and torch.backends.mps.is_available():
        device = torch.device("mps")
    elif args.device == "cuda" and torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # preprocess the data
    df = preprocess_data(args.train_file)

    # split data into training and validation sets
    train_df, val_df = train_test_split(df, test_size=0.1, random_state=42)

    # load the tokenizer and model
    tokenizer = T5Tokenizer.from_pretrained(args.model_name, legacy=False)
    model = T5ForConditionalGeneration.from_pretrained(args.model_name).to(device)

    # create DataLoaders for training and validation datasets
    train_dataset = RAGDataset(train_df, tokenizer, args.max_input_length, args.max_output_length)
    val_dataset = RAGDataset(val_df, tokenizer, args.max_input_length, args.max_output_length)
    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=args.batch_size, shuffle=False)

    # optimizer
    optimizer = optim.AdamW(model.parameters(), lr=args.lr)

    # training loop
    for epoch in range(1, args.epochs + 1):
        train_loss = train_epoch(model, train_loader, optimizer, device, epoch, args.logging_steps)
        print(f"Epoch {epoch} Training Loss: {train_loss:.4f}")

    # save the model and tokenizer
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Model saved to {args.output_dir}")

In [None]:
import pandas as pd
import faiss
from sentence_transformers import SentenceTransformer

# define parameters
csv_file = "/content/drive/MyDrive/medquad.csv"  # path to your dataset
updated_csv_file = "/content/drive/MyDrive/medquad_out.csv" # output dataset path
index_file = "/content/drive/MyDrive/context.index"  # path to save the FAISS index

# load the dataset
df = pd.read_csv(csv_file)

# add a 'context' column if it doesn't exist
if 'context' not in df.columns:
    print("No 'context' column found. Creating it from the 'answer' column.")
    if 'answer' not in df.columns:
        raise ValueError("The dataset must have an 'answer' column to create the 'context'.")
    df['context'] = df['answer']  # use 'answer' as the context

# save the updated dataset with the 'context' column
df.to_csv(updated_csv_file, index=False)
print(f"Updated dataset with 'context' column saved to {updated_csv_file}")

# use SentenceTransformer to generate embeddings for the context column
embedder = SentenceTransformer("all-MiniLM-L6-v2")
contexts = df["context"].tolist()
context_embeddings = embedder.encode(contexts, convert_to_tensor=False).astype("float32")

# create a FAISS index for the embeddings
dimension = context_embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)  # L2 (Euclidean) distance
index.add(context_embeddings)

# save the FAISS index
faiss.write_index(index, index_file)
print(f"FAISS index saved to {index_file}")

No 'context' column found. Creating it from the 'answer' column.
Updated dataset with 'context' column saved to /content/drive/MyDrive/medquad_out.csv


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


FAISS index saved to /content/drive/MyDrive/context.index


In [7]:
# =========================
# RAG: Build/Load Index + Generate Answers with T5
# =========================
import os
import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration

# ---------- CONFIG ----------
CSV_IN         = "/content/drive/MyDrive/medquad.csv"                 # original dataset
CSV_WITH_CTX   = "/content/drive/MyDrive/medquad_with_context.csv"    # output with 'context'
INDEX_PATH     = "/content/drive/MyDrive/context.index"               # FAISS index path

# Model: use a real public model id (safe, no login)
MODEL_ID       = "google/flan-t5-base"  # or "t5-small" for speed

# Retrieval settings
TOP_K          = 3
MAX_NEW_TOKENS = 150
NUM_BEAMS      = 5

# Device selection (cuda, mps, or cpu)
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ---------- 1) LOAD/CREATE DATA WITH 'context' ----------
def ensure_context(csv_in: str, csv_out: str) -> pd.DataFrame:
    df = pd.read_csv(csv_in)
    if 'context' not in df.columns:
        print("No 'context' column found. Creating it from the 'answer' column.")
        if 'answer' not in df.columns:
            raise ValueError("Dataset must have an 'answer' column to create 'context'.")
        df['context'] = df['answer'].astype(str)
    else:
        df['context'] = df['context'].astype(str)
    df.to_csv(csv_out, index=False)
    print(f"Saved dataset with 'context' → {csv_out} (rows: {len(df)})")
    return df

df = ensure_context(CSV_IN, CSV_WITH_CTX)

# ---------- 2) EMBEDDINGS + INDEX (build if missing) ----------
embedder = SentenceTransformer("all-MiniLM-L6-v2")

def build_faiss_index(texts: list[str], index_path: str) -> faiss.Index:
    # Cosine-friendly: normalize + Inner Product index
    emb = embedder.encode(
        texts,
        batch_size=64,
        convert_to_numpy=True,
        normalize_embeddings=True,  # unit vectors => cosine == inner product
        show_progress_bar=True
    ).astype("float32")

    d = emb.shape[1]
    index = faiss.IndexFlatIP(d)
    index.add(np.ascontiguousarray(emb))
    faiss.write_index(index, index_path)
    print(f"Built & saved FAISS index ({index.ntotal} vectors, dim={d}) → {index_path}")
    return index

def load_or_build_index(index_path: str, texts: list[str]) -> faiss.Index:
    if os.path.exists(index_path):
        index = faiss.read_index(index_path)
        print(f"Loaded FAISS index from {index_path} (ntotal={index.ntotal})")
        return index
    else:
        return build_faiss_index(texts, index_path)

index = load_or_build_index(INDEX_PATH, df["context"].tolist())

# ---------- 3) RETRIEVAL ----------
def retrieve_contexts(query: str, index: faiss.Index, df: pd.DataFrame, top_k: int = 3) -> list[str]:
    # Encode query
    qv = embedder.encode([query], convert_to_numpy=True).astype("float32")
    qv = np.ascontiguousarray(qv)

    # If index is Inner Product (cosine), normalize the query too.
    # (Index types with a `metric_type` attr report it; for older FAISS, try/except.)
    metric_ip = False
    try:
        metric_ip = getattr(index, "metric_type", None) == faiss.METRIC_INNER_PRODUCT
    except Exception:
        pass
    if metric_ip:
        faiss.normalize_L2(qv)

    distances, indices = index.search(qv, top_k)
    hits = []
    for i in indices[0]:
        if i >= 0 and i < len(df):
            hits.append(df.iloc[i]["context"])
    return hits

# ---------- 4) LOAD GENERATION MODEL ----------
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
model.eval()

# ---------- 5) RAG INFERENCE ----------
def answer_query(query: str, top_k: int = 3) -> dict:
    contexts = retrieve_contexts(query, index, df, top_k=top_k)
    # Stronger instruction to list symptoms from context
    input_text = (
        "You are a concise medical assistant.\n"
        "Using ONLY the context, list the common symptoms of diabetes as bullet points.\n"
        "If something isn't in the context, do not include it.\n"
        f"Question: {query}\n"
        f"Context: {' '.join(contexts)}\n"
        "Answer:\n- "
    )
    input_ids = tokenizer.encode(input_text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=180,
            num_beams=5,
            no_repeat_ngram_size=3,
            length_penalty=1.2,
            early_stopping=True
        )
    ans = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # keep bullet list clean
    ans = "\n".join(line for line in ans.splitlines() if line.strip())
    return {"query": query, "contexts": contexts, "answer": ans}


# ---------- 6) DEMO ----------
demo_query = "What are the symptoms of diabetes?"
result = answer_query(demo_query, top_k=TOP_K)

print("\n=== RESULT ===")
print(f"Query: {result['query']}")
print(f"Top-{TOP_K} Retrieved Context(s):")
for i, c in enumerate(result["contexts"], 1):
    print(f"  [{i}] {c[:300]}{'...' if len(c) > 300 else ''}")
print(f"\nGenerated Answer: {result['answer']}")


Using device: cpu
No 'context' column found. Creating it from the 'answer' column.
Saved dataset with 'context' → /content/drive/MyDrive/medquad_with_context.csv (rows: 16412)
Loaded FAISS index from /content/drive/MyDrive/context.index (ntotal=16412)


Token indices sequence length is longer than the specified maximum sequence length for this model (1111 > 512). Running this sequence through the model will result in indexing errors



=== RESULT ===
Query: What are the symptoms of diabetes?
Top-3 Retrieved Context(s):
  [1] The signs and symptoms of diabetes are
                
- being very thirsty  - urinating often  - feeling very hungry  - feeling very tired  - losing weight without trying  - sores that heal slowly  - dry, itchy skin  - feelings of pins and needles in your feet  - losing feeling in your feet  - blu...
  [2] Many people with diabetes experience one or more symptoms, including extreme thirst or hunger, a frequent need to urinate and/or fatigue. Some lose weight without trying. Additional signs include sores that heal slowly, dry, itchy skin, loss of feeling or tingling in the feet and blurry eyesight. So...
  [3] Diabetes is often called a "silent" disease because it can cause serious complications even before you have symptoms. Symptoms can also be so mild that you dont notice them. An estimated 8 million people in the United States have type 2 diabetes and dont know it, according to 2012 estima

In [8]:
# ================================
# RAG: Build/Load FAISS + Generate with T5
# ================================
import os
import re
import numpy as np
import pandas as pd
import faiss
import torch
from typing import List, Tuple
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration

# ---------- CONFIG ----------
CSV_IN         = "/content/drive/MyDrive/medquad.csv"                 # original dataset
CSV_WITH_CTX   = "/content/drive/MyDrive/medquad_with_context.csv"    # output with 'context'
INDEX_PATH     = "/content/drive/MyDrive/context.index"               # FAISS index path

# Public model (no auth needed). Use "t5-small" for faster CPU testing.
MODEL_ID       = "google/flan-t5-base"

# Retrieval / generation
TOP_K          = 3
MAX_NEW_TOKENS = 150
NUM_BEAMS      = 5

# Token budget for T5 (hard limit ~512). Keep margin for prompt & special tokens.
MAX_INPUT_TOKENS = 512
TOKEN_MARGIN     = 48   # leave headroom for prompt/meta

# ---------- DEVICE ----------
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ---------- 1) DATA: ensure 'context' ----------
def ensure_context(csv_in: str, csv_out: str) -> pd.DataFrame:
    df = pd.read_csv(csv_in)
    if 'context' not in df.columns:
        print("No 'context' column found. Creating it from the 'answer' column.")
        if 'answer' not in df.columns:
            raise ValueError("Dataset must have an 'answer' column to create 'context'.")
        df['context'] = df['answer'].astype(str)
    else:
        df['context'] = df['context'].astype(str)
    df.to_csv(csv_out, index=False)
    print(f"Saved dataset with 'context' → {csv_out} (rows: {len(df)})")
    return df

df = ensure_context(CSV_IN, CSV_WITH_CTX)

# ---------- 2) EMBEDDING MODEL ----------
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# ---------- 3) FAISS INDEX (build IP/cosine; load existing if present) ----------
def build_faiss_index(texts: List[str], index_path: str) -> faiss.Index:
    emb = embedder.encode(
        texts,
        batch_size=64,
        convert_to_numpy=True,
        normalize_embeddings=True,  # cosine ready
        show_progress_bar=True
    ).astype("float32")
    d = emb.shape[1]
    index = faiss.IndexFlatIP(d)  # inner product == cosine on unit vectors
    index.add(np.ascontiguousarray(emb))
    faiss.write_index(index, index_path)
    print(f"Built & saved FAISS index ({index.ntotal} vectors, dim={d}) → {index_path}")
    return index

def load_or_build_index(index_path: str, texts: List[str]) -> faiss.Index:
    if os.path.exists(index_path):
        index = faiss.read_index(index_path)
        try:
            metric = getattr(index, "metric_type", None)
            metric_name = "IP" if metric == faiss.METRIC_INNER_PRODUCT else "L2"
        except Exception:
            metric_name = "unknown"
        print(f"Loaded FAISS index from {index_path} (ntotal={index.ntotal}, metric={metric_name})")
        return index
    else:
        return build_faiss_index(texts, index_path)

index = load_or_build_index(INDEX_PATH, df["context"].tolist())

# ---------- 4) RETRIEVAL ----------
def _maybe_normalize_query_for_ip(qv: np.ndarray, index: faiss.Index) -> np.ndarray:
    metric_type = getattr(index, "metric_type", None)
    if metric_type == faiss.METRIC_INNER_PRODUCT:
        qv = qv.copy()
        faiss.normalize_L2(qv)
    return qv

def retrieve(query: str, index: faiss.Index, df: pd.DataFrame, top_k: int = 3) -> List[Tuple[int, float, str]]:
    # Encode query
    qv = embedder.encode([query], convert_to_numpy=True).astype("float32")
    qv = np.ascontiguousarray(qv)
    qv = _maybe_normalize_query_for_ip(qv, index)

    distances, idxs = index.search(qv, top_k)
    out = []
    for j, i in enumerate(idxs[0]):
        if 0 <= i < len(df):
            out.append((int(i), float(distances[0][j]), df.iloc[i]["context"]))
    return out

# ---------- 5) TOKENIZER + MODEL ----------
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
model.eval()

# ---------- 6) CONTEXT PREP: keep bullets & compact ----------
BULLET_PATTERN = re.compile(r"^\s*([-\*\u2022]|(\d+[\.\)]))\s+", re.UNICODE)

def keep_bullets_or_compact(text: str, max_lines: int = 20, max_chars: int = 800) -> str:
    """Prefer bullet lines; otherwise return a compact snippet."""
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    bullet_lines = [ln for ln in lines if BULLET_PATTERN.match(ln) or " - " in ln]
    chosen = bullet_lines if bullet_lines else lines
    snippet = "\n".join(chosen[:max_lines])
    if len(snippet) > max_chars:
        snippet = snippet[:max_chars].rsplit(" ", 1)[0] + "..."
    return snippet

def pack_contexts_for_budget(
    question: str,
    contexts: List[str],
    max_input_tokens: int = MAX_INPUT_TOKENS,
    token_margin: int = TOKEN_MARGIN
) -> str:
    """Pack as many preprocessed contexts as fit under the tokenizer budget."""
    header = (
        "You are a concise medical assistant.\n"
        "Using ONLY the context, list the common symptoms of diabetes as bullet points.\n"
        "If something isn't in the context, do not include it.\n"
        f"Question: {question}\n"
        "Context:\n"
    )
    # Token budget for contexts only
    header_ids = tokenizer.encode(header, add_special_tokens=False)
    remaining = max(8, max_input_tokens - token_margin - len(header_ids))

    packed_contexts: List[str] = []
    used = 0
    for ctx in contexts:
        processed = keep_bullets_or_compact(ctx)
        ids = tokenizer.encode(processed + "\n", add_special_tokens=False)
        if used + len(ids) <= remaining:
            packed_contexts.append(processed)
            used += len(ids)
        else:
            # Try a smaller slice of this context
            shortened = keep_bullets_or_compact(ctx, max_lines=10, max_chars=400)
            ids2 = tokenizer.encode(shortened + "\n", add_special_tokens=False)
            if used + len(ids2) <= remaining:
                packed_contexts.append(shortened)
                used += len(ids2)
            # else skip it
    body = "\n".join(packed_contexts)
    prompt = header + body + "\nAnswer:\n- "
    return prompt

# ---------- 7) QA FUNCTION ----------
def answer_query(question: str, top_k: int = TOP_K) -> dict:
    hits = retrieve(question, index, df, top_k=top_k)
    contexts = [h[2] for h in hits]  # texts only
    prompt = pack_contexts_for_budget(question, contexts, MAX_INPUT_TOKENS, TOKEN_MARGIN)

    # Final safety: enforce truncation in case
    input_ids = tokenizer.encode(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            num_beams=NUM_BEAMS,
            no_repeat_ngram_size=3,
            length_penalty=1.2,
            early_stopping=True,
        )
    ans = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
    # Clean bullet list
    ans = "\n".join(line for line in ans.splitlines() if line.strip())

    return {
        "query": question,
        "hits": hits,
        "answer": ans,
        "prompt_tokens": int(input_ids.shape[1]),
    }

# ---------- 8) DEMO ----------
demo_question = "What are the symptoms of diabetes?"
result = answer_query(demo_question, top_k=TOP_K)

print("\n=== RESULT ===")
print(f"Query: {result['query']}")
print(f"Prompt tokens used: {result['prompt_tokens']} / {MAX_INPUT_TOKENS}")
print(f"Top-{TOP_K} Retrieved Context(s):")
for rank, (row_idx, score, text) in enumerate(result["hits"], 1):
    snippet = text[:300] + ("..." if len(text) > 300 else "")
    print(f"  [{rank}] row={row_idx} score={score:.4f}  {snippet}")
print("\nGenerated Answer:")
print(result["answer"])


Using device: cpu
No 'context' column found. Creating it from the 'answer' column.
Saved dataset with 'context' → /content/drive/MyDrive/medquad_with_context.csv (rows: 16412)
Loaded FAISS index from /content/drive/MyDrive/context.index (ntotal=16412, metric=L2)

=== RESULT ===
Query: What are the symptoms of diabetes?
Prompt tokens used: 384 / 512
Top-3 Retrieved Context(s):
  [1] row=16214 score=0.4028  The signs and symptoms of diabetes are
                
- being very thirsty  - urinating often  - feeling very hungry  - feeling very tired  - losing weight without trying  - sores that heal slowly  - dry, itchy skin  - feelings of pins and needles in your feet  - losing feeling in your feet  - blu...
  [2] row=112 score=0.4611  Many people with diabetes experience one or more symptoms, including extreme thirst or hunger, a frequent need to urinate and/or fatigue. Some lose weight without trying. Additional signs include sores that heal slowly, dry, itchy skin, loss of feeling or tin

In [11]:
# Improvement of the above code

# ==========================================================
# RAG with FAISS + Sentence-Transformers + T5 (robust output)
# ==========================================================
# If needed in a fresh Colab:
# !pip install -q faiss-cpu sentence-transformers transformers

import os, re
from typing import List, Tuple
import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer
from transformers import T5Tokenizer, T5ForConditionalGeneration

# ---------------- CONFIG ----------------
CSV_IN         = "/content/drive/MyDrive/medquad.csv"                 # original dataset
CSV_WITH_CTX   = "/content/drive/MyDrive/medquad_with_context.csv"    # dataset with 'context'
INDEX_PATH     = "/content/drive/MyDrive/context.index"               # FAISS index path

# Public model (no auth needed). Use "t5-small" for faster CPU tests.
MODEL_ID       = "google/flan-t5-base"

TOP_K          = 3
MAX_NEW_TOKENS = 150
MIN_NEW_TOKENS = 60     # avoid ultra-short answers
NUM_BEAMS      = 4

# T5 input budget (hard limit ~512)
MAX_INPUT_TOKENS = 512
TOKEN_MARGIN     = 48    # reserve for prompt/special tokens

# ---------------- DEVICE ----------------
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

# ---------------- 1) DATA ----------------
def ensure_context(csv_in: str, csv_out: str) -> pd.DataFrame:
    df = pd.read_csv(csv_in)
    if "context" not in df.columns:
        print("No 'context' column found. Creating it from the 'answer' column.")
        if "answer" not in df.columns:
            raise ValueError("Dataset must have an 'answer' column to create 'context'.")
        df["context"] = df["answer"].astype(str)
    else:
        df["context"] = df["context"].astype(str)
    df.to_csv(csv_out, index=False)
    print(f"Saved dataset with 'context' → {csv_out} (rows: {len(df)})")
    return df

df = ensure_context(CSV_IN, CSV_WITH_CTX)

# ---------------- 2) EMBEDDINGS ----------------
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# ---------------- 3) FAISS INDEX ----------------
def build_faiss_index(texts: List[str], index_path: str) -> faiss.Index:
    emb = embedder.encode(
        texts,
        batch_size=64,
        convert_to_numpy=True,
        normalize_embeddings=True,  # cosine-ready (unit vectors)
        show_progress_bar=True
    ).astype("float32")
    d = emb.shape[1]
    index = faiss.IndexFlatIP(d)    # inner product == cosine on unit vectors
    index.add(np.ascontiguousarray(emb))
    faiss.write_index(index, index_path)
    print(f"Built & saved FAISS index ({index.ntotal} vectors, dim={d}) → {index_path}")
    return index

def load_or_build_index(index_path: str, texts: List[str]) -> faiss.Index:
    if os.path.exists(index_path):
        index = faiss.read_index(index_path)
        try:
            metric = getattr(index, "metric_type", None)
            metric_name = "IP" if metric == faiss.METRIC_INNER_PRODUCT else "L2"
        except Exception:
            metric_name = "unknown"
        print(f"Loaded FAISS index from {index_path} (ntotal={index.ntotal}, metric={metric_name})")
        return index
    else:
        return build_faiss_index(texts, index_path)

index = load_or_build_index(INDEX_PATH, df["context"].tolist())

# ---------------- 4) RETRIEVAL ----------------
def _maybe_normalize_query_for_ip(qv: np.ndarray, index: faiss.Index) -> np.ndarray:
    metric_type = getattr(index, "metric_type", None)
    if metric_type == faiss.METRIC_INNER_PRODUCT:
        qv = qv.copy()
        faiss.normalize_L2(qv)
    return qv

def retrieve(query: str, index: faiss.Index, df: pd.DataFrame, top_k: int = 3) -> List[Tuple[int, float, str]]:
    qv = embedder.encode([query], convert_to_numpy=True).astype("float32")
    qv = np.ascontiguousarray(qv)
    qv = _maybe_normalize_query_for_ip(qv, index)
    distances, idxs = index.search(qv, top_k)
    out = []
    for j, i in enumerate(idxs[0]):
        if 0 <= i < len(df):
            out.append((int(i), float(distances[0][j]), df.iloc[i]["context"]))
    return out

# ---------------- 5) MODEL ----------------
tokenizer = T5Tokenizer.from_pretrained(MODEL_ID, legacy=False)
model = T5ForConditionalGeneration.from_pretrained(MODEL_ID).to(device)
model.eval()

# ---------------- 6) CONTEXT PACKING ----------------
BULLET_PATTERN = re.compile(r"^\s*([-\*\u2022]|(\d+[\.\)]))\s+", re.UNICODE)

def keep_bullets_or_compact(text: str, max_lines: int = 20, max_chars: int = 800) -> str:
    """Prefer bullet lines; otherwise compact snippet."""
    lines = [ln.strip() for ln in text.splitlines() if ln.strip()]
    bullet_lines = [ln for ln in lines if BULLET_PATTERN.match(ln) or " - " in ln]
    chosen = bullet_lines if bullet_lines else lines
    snippet = "\n".join(chosen[:max_lines])
    if len(snippet) > max_chars:
        snippet = snippet[:max_chars].rsplit(" ", 1)[0] + "..."
    return snippet

def pack_contexts_for_budget(question: str, contexts: List[str]) -> str:
    # Generic instruction suitable for lists (symptoms/causes/etc.)
    header = (
        "You are a concise medical assistant.\n"
        "Using ONLY the context, answer the question as bullet points.\n"
        "If the answer is a list (e.g., symptoms/causes/treatments), return one item per line.\n"
        "Return ONLY bullet points; do not add extra text.\n"
        f"Question: {question}\n"
        "Context:\n"
    )
    header_ids = tokenizer.encode(header, add_special_tokens=False)
    remaining = max(8, MAX_INPUT_TOKENS - TOKEN_MARGIN - len(header_ids))

    packed, used = [], 0
    for ctx in contexts:
        processed = keep_bullets_or_compact(ctx)
        ids = tokenizer.encode(processed + "\n", add_special_tokens=False)
        if used + len(ids) <= remaining:
            packed.append(processed)
            used += len(ids)
        else:
            shortened = keep_bullets_or_compact(ctx, max_lines=10, max_chars=400)
            ids2 = tokenizer.encode(shortened + "\n", add_special_tokens=False)
            if used + len(ids2) <= remaining:
                packed.append(shortened)
                used += len(ids2)
    body = "\n".join(packed)
    return header + body + "\nAnswer:\n- "

# ---------------- 7) PARSERS & FALLBACK ----------------
def parse_bullets(text: str, max_items: int = 20) -> list[str]:
    """
    Extract bullet items from model text. Handles lines like
    '- a - b - c' (inline bullets) and standard bullet lines.
    Deduplicates and trims.
    """
    items = []

    def add_item(s: str):
        s = s.strip()
        s = re.sub(r"^([-\*\u2022]|\d+[\.\)])\s*", "", s)  # strip bullet/number
        s = re.sub(r"\s+", " ", s).strip(" .;:-").strip()
        if s:
            items.append(s)

    for raw in text.splitlines():
        line = raw.strip()
        if not line:
            continue

        if re.search(r"\s-\s", line):
            chunks = [c for c in re.split(r"\s-\s", line) if c.strip()]
            if not line.lstrip().startswith(("-", "•", "*")) and len(chunks) > 1:
                chunks = chunks[1:]  # drop preamble
            for ch in chunks:
                add_item(ch)
            continue

        if BULLET_PATTERN.match(line) or line.lstrip().startswith(("-", "•", "*")):
            add_item(line)

    seen, uniq = set(), []
    for it in items:
        key = re.sub(r"[^a-z0-9]+", " ", it.lower()).strip()
        if key and key not in seen:
            seen.add(key)
            uniq.append(it)
            if len(uniq) >= max_items:
                break
    return uniq

def extract_bullets_from_contexts(contexts: List[str]) -> str:
    raw_lines = []
    for ctx in contexts:
        for ln in ctx.splitlines():
            ln = ln.strip()
            if not ln:
                continue
            if BULLET_PATTERN.match(ln) or " - " in ln:
                ln = re.sub(r"\s*-\s*", " - ", ln).strip()
                ln = re.sub(r"^([-\*\u2022]|\d+[\.\)])\s*", "", ln).strip()
                raw_lines.append(ln)
    seen, uniq = set(), []
    for ln in raw_lines:
        key = re.sub(r"[^a-z0-9]+", " ", ln.lower()).strip()
        if key and key not in seen:
            seen.add(key)
            uniq.append(ln)
    return "" if not uniq else "- " + "\n- ".join(uniq)

# ---------------- 8) QA ----------------
def answer_query(question: str, top_k: int = TOP_K) -> dict:
    hits = retrieve(question, index, df, top_k=top_k)
    contexts = [h[2] for h in hits]

    prompt = pack_contexts_for_budget(question, contexts)
    input_ids = tokenizer.encode(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_INPUT_TOKENS
    ).to(device)

    with torch.no_grad():
        outputs = model.generate(
            input_ids,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            num_beams=NUM_BEAMS,
            no_repeat_ngram_size=3,
            length_penalty=1.0,
            early_stopping=False,
        )
    raw = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()

    # Prefer the model's bullets
    bullets = parse_bullets(raw, max_items=20)

    # Fallback to deterministic extraction if model is too terse
    if len(bullets) < 3:
        fallback = extract_bullets_from_contexts(contexts)
        final = fallback if fallback else "- " + "\n- ".join([ln for ln in raw.splitlines() if ln.strip()][:10])
    else:
        final = "- " + "\n- ".join(bullets[:12])

    return {
        "query": question,
        "hits": hits,
        "answer": final,
        "prompt_tokens": int(input_ids.shape[1]),
    }

# ---------------- 9) DEMO ----------------
demo_question = "What are the symptoms of diabetes?"
result = answer_query(demo_question, top_k=TOP_K)

print("\n=== RESULT ===")
print(f"Query: {result['query']}")
print(f"Prompt tokens used: {result['prompt_tokens']} / {MAX_INPUT_TOKENS}")
print(f"Top-{TOP_K} Retrieved Context(s):")
for rank, (row_idx, score, text) in enumerate(result["hits"], 1):
    snippet = text[:300] + ("..." if len(text) > 300 else "")
    print(f"  [{rank}] row={row_idx} score={score:.4f}  {snippet}")
print("\nGenerated Answer:\n" + result["answer"])


Using device: cpu
No 'context' column found. Creating it from the 'answer' column.
Saved dataset with 'context' → /content/drive/MyDrive/medquad_with_context.csv (rows: 16412)
Loaded FAISS index from /content/drive/MyDrive/context.index (ntotal=16412, metric=L2)

=== RESULT ===
Query: What are the symptoms of diabetes?
Prompt tokens used: 405 / 512
Top-3 Retrieved Context(s):
  [1] row=16214 score=0.4028  The signs and symptoms of diabetes are
                
- being very thirsty  - urinating often  - feeling very hungry  - feeling very tired  - losing weight without trying  - sores that heal slowly  - dry, itchy skin  - feelings of pins and needles in your feet  - losing feeling in your feet  - blu...
  [2] row=112 score=0.4611  Many people with diabetes experience one or more symptoms, including extreme thirst or hunger, a frequent need to urinate and/or fatigue. Some lose weight without trying. Additional signs include sores that heal slowly, dry, itchy skin, loss of feeling or tin