<a href="https://colab.research.google.com/github/Danny2173/RAGproject/blob/main/RAG_Testing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install faiss-cpu
!pip install gradio -q

# Imports
import os, json, re, pickle, itertools
import numpy as np
import pandas as pd
import torch
import faiss
import matplotlib.pyplot as plt
from collections import defaultdict
from nltk.stem import PorterStemmer
from tqdm import tqdm
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import normalize
from transformers import (
    DPRQuestionEncoder,
    DPRQuestionEncoderTokenizer,
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    DPRContextEncoder,
    DPRContextEncoderTokenizer,
    LogitsProcessor,
    LogitsProcessorList
)
from peft import PeftModel
import gradio as gr

!git clone -q https://github.com/Danny2173/RAGproject.git /content/RAGproject
%cd /content/RAGproject

In [None]:
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

# Load tokenizer and LoRA model
# The LoRA adapter directory is attached on top of the t5-large base weights.
tokenizer = AutoTokenizer.from_pretrained("t5-large")
base_model = AutoModelForSeq2SeqLM.from_pretrained("t5-large")
model = PeftModel.from_pretrained(base_model, "/content/RAGproject/LoRA/t5-lora-final")
model = model.to(device)

# NOTE(review): pickle.load executes arbitrary code embedded in the file.
# These files come from the repository cloned above; only load them from a
# source you trust.

# Document texts; the pipelines index this by document position and return
# the selected entries as the retrieved context.
with open("/content/RAGproject/FAISS/corpus_texts.pkl", "rb") as f:
    corpus_texts = pickle.load(f)

# Per-document text used for the title/heading matching and for fitting the
# TF-IDF vectorizer (presumably the CUI-normalized form of the corpus; verify
# against the index-building notebook).
with open("/content/RAGproject/FAISS/normalized_for_index.pkl", "rb") as f:
    normalized_for_index = pickle.load(f)

# Per-document records; the pipelines read "source_url"/"source" and
# "review_info" from them with .get(), so entries are expected to be dict-like.
with open("/content/RAGproject/FAISS/corpus.pkl", "rb") as f:
    corpus = pickle.load(f)

# Prebuilt FAISS index that the pipelines search with the DPR query embedding.
index = faiss.read_index("/content/RAGproject/FAISS/faiss_index.index")


def count_heading_term_matches(doc_text, matched_main_terms):
    """Count how many distinct main terms occur in a document's headings.

    The heading text is built from the document's first line (treated as the
    title, and only when a newline follows it) followed by the text of every
    line starting with ``section:``, ``subsection:`` or ``subsubsection:``.
    Matching is a case-insensitive substring test, so a term also matches when
    it appears inside a longer word.

    Args:
        doc_text: Full document text.
        matched_main_terms: Iterable of term strings to look for.

    Returns:
        The number of distinct, non-empty terms (compared case-insensitively)
        found in the heading text.
    """
    doc_text = doc_text.lower()
    # The document is lower-cased, so the terms must be as well; otherwise any
    # term containing an uppercase letter could never match. Empty terms are
    # dropped because "" is a substring of every string.
    terms = {term.lower() for term in matched_main_terms if term}
    if not terms:
        return 0

    # Title: everything before the first newline (empty if there is none).
    title_match = re.match(r"^(.*?)\n", doc_text)
    title = title_match.group(1) if title_match else ""

    # Heading lines, collected level by level.
    section_matches = re.findall(r"^section:\s*(.*?)$", doc_text, flags=re.MULTILINE)
    subsection_matches = re.findall(r"^subsection:\s*(.*?)$", doc_text, flags=re.MULTILINE)
    subsubsection_matches = re.findall(r"^subsubsection:\s*(.*?)$", doc_text, flags=re.MULTILINE)

    heading_text = " ".join([title] + section_matches + subsection_matches + subsubsection_matches)

    return sum(1 for term in terms if term in heading_text)

# Load the lookup tables used by cui_normalization (these are dictionaries
# loaded from pickles, not functions):
#   term_to_CUI      - looked up with a lower-cased word n-gram, yields a CUI
#   cui_to_main_term - looked up with a CUI, yields the main term that is
#                      substituted into the text
# NOTE(review): pickle.load runs arbitrary code from the file; only load
# pickles from a trusted source.

with open("/content/RAGproject/Normalization/filtered_term_to_CUI.pkl", "rb") as f:
    term_to_CUI = pickle.load(f)

with open("/content/RAGproject/Normalization/filtered_cui_to_main_term.pkl", "rb") as f:
    cui_to_main_term = pickle.load(f)

# Creating ngrams and tracking indices
def ngram_tokenize_tokens(tokens, max_len=5):
    """Enumerate every contiguous word n-gram of up to ``max_len`` tokens.

    Args:
        tokens: Sequence of word strings.
        max_len: Largest n-gram size to produce.

    Returns:
        A list of ``(ngram, start, stop)`` tuples where ``ngram`` is the
        space-joined text of ``tokens[start:stop]``. Entries are ordered by
        ``start`` and, for equal ``start``, by increasing length.
    """
    total = len(tokens)
    return [
        (' '.join(tokens[start:stop]), start, stop)
        for start in range(total)
        # stop is exclusive and never runs past the end of the token list
        for stop in range(start + 1, min(start + max_len, total) + 1)
    ]

# Normalizing medical terms using main condition name
def cui_normalization(sentence, max_ngram_len=5):
    """Rewrite known medical terms in ``sentence`` to their main term.

    The sentence is split into alternating word / non-word tokens. Every
    lower-cased word n-gram (up to ``max_ngram_len`` words) is looked up in
    ``term_to_CUI``; hits whose CUI is present in ``cui_to_main_term`` become
    replacement candidates. Overlaps are resolved left to right, preferring
    the longest candidate at each start position.

    Args:
        sentence: Input text.
        max_ngram_len: Maximum number of words in a candidate term.

    Returns:
        A tuple ``(normalized_text, matched_main_terms, matched_cuis)``:
        the rewritten sentence, the main terms actually substituted (in text
        order), and a list of the CUIs of every candidate found. The CUI list
        also contains candidates that were later dropped because they
        overlapped a chosen span.
    """
    word_re = re.compile(r'\w+')
    tokens = re.findall(r'\w+|\W+', sentence)
    lowered_words = [tok.lower() for tok in tokens if word_re.match(tok)]

    # Collect every dictionary hit as (start_word, end_word, main_term).
    candidates = []
    matched_cuis = set()
    for ngram, start, end in ngram_tokenize_tokens(lowered_words, max_ngram_len):
        if ngram not in term_to_CUI:
            continue
        cui = term_to_CUI[ngram]
        if cui not in cui_to_main_term:
            continue
        candidates.append((start, end, cui_to_main_term[cui]))
        matched_cuis.add(cui)

    # Order by start position, longest span first within the same start.
    candidates.sort(key=lambda span: (span[0], span[0] - span[1]))

    # Keep a candidate only if none of its word positions is already taken.
    taken = set()
    span_by_start = {}
    for start, end, main_term in candidates:
        covered = range(start, end)
        if taken.isdisjoint(covered):
            span_by_start[start] = (end, main_term)
            taken.update(covered)

    # Rebuild the sentence, walking raw tokens while tracking the word index.
    pieces = []
    word_pos = 0
    pos = 0
    n_tokens = len(tokens)
    while pos < n_tokens:
        tok = tokens[pos]
        if word_re.match(tok):
            hit = span_by_start.get(word_pos)
            if hit is not None:
                stop, main_term = hit
                pieces.append(main_term)
                # Skip the replaced words together with the separators
                # between them; the separator after the last word is kept.
                remaining = stop - word_pos
                while remaining > 0 and pos < n_tokens:
                    if word_re.match(tokens[pos]):
                        remaining -= 1
                    pos += 1
                word_pos = stop
                continue
            word_pos += 1
        pieces.append(tok)
        pos += 1

    normalized_text = ''.join(pieces)
    matched_main_terms = [main_term for _, main_term in span_by_start.values()]
    return normalized_text, matched_main_terms, list(matched_cuis)

# Build TF-IDF vectorizer for exact matching
# The vectorizer is fitted on the same per-document text that the heading
# matching uses, with English stop words removed and the vocabulary capped at
# 10,000 features.
tfidf_vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
tfidf_matrix = tfidf_vectorizer.fit_transform(normalized_for_index)
# Row-normalize (sklearn's default is the L2 norm) so that the dot product with
# a normalized query vector in the pipelines behaves as a cosine similarity.
tfidf_matrix = normalize(tfidf_matrix)

# DPR question encoder and its tokenizer, used to embed the normalized query
# before the FAISS search.
q_encoder = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base").to(device)
q_tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")

In [None]:
def rag_pipeline_lora_top1title_concat(
    query,
    top_k=600,
    bonus_weight=0.1,
    max_answer_tokens=256,
    confidence_threshold=0.63,
    combine_k=4,
    source_info=False
):
    """Answer ``query`` from the concatenation of the top-ranked documents.

    Steps: normalize the query's medical terms, gather candidate documents
    from the FAISS search plus documents whose leading text mentions a matched
    term, score each candidate as
    ``0.9 * dense + 0.1 * tfidf + bonus_weight * heading_matches``, then feed
    the ``combine_k`` best documents as one context to the T5 model.

    Args:
        query: User question.
        top_k: Number of neighbours requested from FAISS (clamped to the
            index size).
        bonus_weight: Score added per main term found in a document's headings.
        max_answer_tokens: ``max_new_tokens`` passed to generation.
        confidence_threshold: If every selected document scores below this,
            no answer is generated.
        combine_k: Number of top documents concatenated into the context.
        source_info: When true, append source URLs and the review date found
            in the corpus records to the answer.

    Returns:
        ``(answer, retrieved_texts, scores, doc_indices)``. When the query
        contains no recognised term the result is ``(message, [], [], [])``;
        when confidence is too low it is
        ``(message, retrieved_texts, scores, [])``.
    """
    # 1. Normalize query; without a recognised medical term there is nothing
    #    to retrieve on.
    normalized_query, matched_main_terms, _ = cui_normalization(query)
    if not matched_main_terms:
        return "!Please rephrase the query using a specific medical condition!", [], [], []

    # 2. Title match indices. The document text is lower-cased, so the terms
    #    are lower-cased as well; the title text is computed once per document.
    lowered_terms = [term.lower() for term in matched_main_terms]
    heading_matched_indices = []
    for i, doc_text in enumerate(normalized_for_index):
        leading_text = doc_text.split('.')[0].lower()
        if any(term in leading_text for term in lowered_terms):
            heading_matched_indices.append(i)

    # 3. FAISS search with the unit-normalized DPR query embedding.
    inputs = q_tokenizer(normalized_query, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        q_emb = q_encoder(**inputs).pooler_output.cpu().numpy()
        q_emb /= np.linalg.norm(q_emb, axis=1, keepdims=True)

    # 3.1 Never ask for more neighbours than the index holds, and drop any
    #     out-of-range label (FAISS pads missing results with -1, which would
    #     otherwise silently index the last document).
    tk = min(int(top_k), index.ntotal)
    index_scores = {}
    if tk > 0:
        scores, indices = index.search(q_emb, tk)
        index_scores = {
            int(i): float(s)
            for i, s in zip(indices[0], scores[0])
            if 0 <= i < len(corpus_texts)
        }

    # 4. TF-IDF similarity of the query against every document.
    tfidf_query_vec = tfidf_vectorizer.transform([normalized_query])
    tfidf_query_vec = normalize(tfidf_query_vec)
    tfidf_scores = (tfidf_matrix @ tfidf_query_vec.T).toarray().ravel()

    # 5. Candidates are the FAISS hits plus the title-matched documents.
    all_indices = set(index_scores) | set(heading_matched_indices)

    # 6. Combine DPR and TF-IDF scores; title-only candidates get a fixed
    #    dense score of 0.1.
    boosted_index_scores = {}
    for idx in all_indices:
        heading_match_count = count_heading_term_matches(normalized_for_index[idx], matched_main_terms)
        dense_score = float(index_scores.get(idx, 0.1))
        sparse_score = float(tfidf_scores[idx])
        boosted_index_scores[idx] = 0.9 * dense_score + 0.1 * sparse_score + bonus_weight * heading_match_count

    # 7. Rank candidates by descending combined score.
    sorted_items = sorted(boosted_index_scores.items(), key=lambda x: -x[1])
    ranked_indices = [int(i) for i, _ in sorted_items]
    sorted_scores = [float(s) for _, s in sorted_items]

    # 8. Keep the combine_k best documents.
    used_k = min(int(combine_k), len(ranked_indices))
    top_indices = ranked_indices[:used_k]
    retrieved_texts = [corpus_texts[idx] for idx in top_indices]
    boosted_used = sorted_scores[:used_k]

    # 9. Threshold-based confidence check (also triggers when nothing was
    #    retrieved at all).
    if all(s < confidence_threshold for s in boosted_used):
        return "!I do not have enough confidence to answer your question. Please try rephrasing it.", retrieved_texts, boosted_used, []

    # 10. Combining context and query for input
    combined_context = "\n\n".join(txt.strip() for txt in retrieved_texts)
    t5_input = f"question: {query}\ncontext: {combined_context}"

    # 11. Tokenize (the input is truncated to the tokenizer's limit) and
    #     generate greedily.
    enc = tokenizer(t5_input, return_tensors="pt", truncation=True, padding=True).to(model.device)
    with torch.no_grad():
        out = model.generate(
            **enc,
            max_new_tokens=max_answer_tokens,
            do_sample=False
        )
    answer = tokenizer.decode(out[0], skip_special_tokens=True).strip()

    # 12. Append source info to answer
    if source_info:
        # A list keeps the sources in rank order; a set would emit them in an
        # arbitrary order that can change between runs.
        sources = []
        review_info = None

        for idx in top_indices:
            doc = corpus[idx]
            src = doc.get("source_url", doc.get("source", "Unknown"))
            if src and src not in sources:
                sources.append(src)

            # Use the review info of the highest-ranked document that has one.
            if not review_info:
                review_info = doc.get("review_info", None)

        if sources:
            answer += "\n\nFor more information visit:"
            answer += "".join(f"\n {src}" for src in sources)

        if isinstance(review_info, str) and "Page last reviewed:" in review_info:
            reviewed_part = review_info.split("Page last reviewed:")[1].split("Next review due")[0].strip()
            answer += f"\n\nInformation up-to-date as of: {reviewed_part}"

    return answer, retrieved_texts, boosted_used, top_indices


In [None]:
def rag_pipeline_lora_top1title_ragseq(
    query,
    top_k=600,
    bonus_weight=0.1,
    max_answer_tokens=128,
    confidence_threshold=0.63,
    top_m=5,
    alpha=0.85,
    source_info=False,
    return_chosen=False
):
    """Answer ``query`` by generating once per top document and keeping the best.

    Retrieval and ranking combine the FAISS score, the TF-IDF score and a
    heading bonus as
    ``0.9 * dense + 0.1 * tfidf + bonus_weight * heading_matches``. Each of
    the ``top_m`` best documents that reaches ``confidence_threshold`` is used
    alone as context; the candidate answers are compared by
    ``alpha * mean_token_logprob + (1 - alpha) * doc_score`` with a penalty of
    -2.0 for answers shorter than three words.

    Args:
        query: User question.
        top_k: Number of neighbours requested from FAISS (clamped to the
            index size).
        bonus_weight: Score added per main term found in a document's headings.
        max_answer_tokens: ``max_new_tokens`` passed to generation.
        confidence_threshold: Minimum document score required to generate
            from a document.
        top_m: Number of top-ranked documents considered.
        alpha: Weight of the generation log-probability in the blended score.
        source_info: When true, append the chosen document's source URL and
            review date to the answer.
        return_chosen: Accepted for backward compatibility; the returned tuple
            is the same for either value.

    Returns:
        ``(answer, top_doc_indices, scores, chosen_doc_index)``. When the
        query contains no recognised term the result is
        ``(message, [], [], None)``. When confidence is too low it is
        ``(message, retrieved_texts, scores, None)``; note that the second
        element holds document texts rather than indices in that case.
    """

    def _length_norm_seq_logprob(gen_out):
        """Mean log-probability of the generated tokens (-inf if none)."""
        scores = gen_out.scores
        if not scores:
            return float("-inf")
        # The last len(scores) ids of the sequence are the generated tokens.
        chosen = gen_out.sequences[0][-len(scores):]
        total = 0.0
        for step, logits in enumerate(scores):
            lp = torch.log_softmax(logits[0], dim=-1)
            total += float(lp[int(chosen[step].item())])
        return total / max(1, len(scores))

    # 1. Normalize query
    normalized_query, matched_main_terms, _ = cui_normalization(query)
    if not matched_main_terms:
        return "!Please rephrase the query using a specific medical condition!", [], [], None

    # 2. Title match indices. The document text is lower-cased, so the terms
    #    are lower-cased as well; the title text is computed once per document.
    lowered_terms = [term.lower() for term in matched_main_terms]
    heading_matched_indices = []
    for i, doc_text in enumerate(normalized_for_index):
        leading_text = doc_text.split('.')[0].lower()
        if any(term in leading_text for term in lowered_terms):
            heading_matched_indices.append(i)

    # 3. FAISS search with the unit-normalized DPR query embedding.
    inputs = q_tokenizer(normalized_query, return_tensors="pt", truncation=True).to(device)
    with torch.no_grad():
        q_emb = q_encoder(**inputs).pooler_output.cpu().numpy()
        q_emb /= np.linalg.norm(q_emb, axis=1, keepdims=True)

    # 3.1 Clamp the request to the index size, skip the search entirely when
    #     nothing can be requested, and drop out-of-range labels.
    tk = min(int(top_k), index.ntotal)
    index_scores = {}
    if tk > 0:
        scores, indices = index.search(q_emb, tk)
        index_scores = {
            int(i): float(s)
            for i, s in zip(indices[0], scores[0])
            if 0 <= i < len(corpus_texts)
        }

    # 4. TF-IDF similarity of the query against every document.
    tfidf_query_vec = tfidf_vectorizer.transform([normalized_query])
    tfidf_query_vec = normalize(tfidf_query_vec)
    tfidf_scores = (tfidf_matrix @ tfidf_query_vec.T).toarray().ravel()

    # 5. Candidates are the FAISS hits plus the title-matched documents.
    all_indices = set(index_scores) | set(heading_matched_indices)

    # 6. Combine DPR and TF-IDF scores; title-only candidates get a fixed
    #    dense score of 0.1.
    boosted_index_scores = {}
    for idx in all_indices:
        heading_match_count = count_heading_term_matches(normalized_for_index[idx], matched_main_terms)
        dense_score = float(index_scores.get(idx, 0.1))
        sparse_score = float(tfidf_scores[idx])
        boosted_index_scores[idx] = 0.9 * dense_score + 0.1 * sparse_score + bonus_weight * heading_match_count

    # 7. Rank candidates by descending combined score.
    sorted_items = sorted(boosted_index_scores.items(), key=lambda x: -x[1])
    ranked_indices = [int(i) for i, _ in sorted_items]
    sorted_scores = [float(s) for _, s in sorted_items]

    # 8. Retrieving top-M documents
    M = min(int(top_m), len(ranked_indices))
    top_indices = ranked_indices[:M]
    retrieved_texts = [corpus_texts[idx] for idx in top_indices]
    boosted_used = sorted_scores[:M]

    # 9. Threshold-based confidence check
    if all(s < confidence_threshold for s in boosted_used):
        return "!I do not have enough confidence to answer your question.", retrieved_texts, boosted_used, None

    # 10. Generation per document
    best = {"score": float("-inf"), "answer": None, "doc_idx": None}

    for doc_idx, doc_score in zip(top_indices, boosted_used):
        if doc_score < confidence_threshold:
            continue

        ctx_text = corpus_texts[doc_idx].strip()
        t5_input = f"question: {query}\ncontext: {ctx_text}"

        enc = tokenizer(t5_input, return_tensors="pt", truncation=True, padding=True).to(model.device)

        with torch.no_grad():
            out = model.generate(
                **enc,
                max_new_tokens=max_answer_tokens,
                do_sample=False,
                return_dict_in_generate=True,
                output_scores=True,
            )

        seq = tokenizer.decode(out.sequences[0], skip_special_tokens=True).strip()
        seq_lp = _length_norm_seq_logprob(out)

        # blended scoring + short-answer penalty
        length_penalty = -2.0 if len(seq.split()) < 3 else 0.0
        combined = alpha * seq_lp + (1 - alpha) * doc_score + length_penalty
        # NaN compares false against everything; rank it as the worst score.
        if np.isnan(combined):
            combined = float("-inf")

        # Always accept the first candidate so the answer is never left as
        # None when every blended score is -inf or NaN.
        if best["doc_idx"] is None or combined > best["score"]:
            best.update({"score": combined, "answer": seq, "doc_idx": doc_idx})

    # 11. Append source info
    if source_info and best["doc_idx"] is not None:
        doc = corpus[best["doc_idx"]]
        src = doc.get("source_url", doc.get("source", "Unknown"))
        review_info = doc.get("review_info", None)

        if src:
            best["answer"] += f"\n\nFor more information visit:\n {src}"
        if isinstance(review_info, str) and "Page last reviewed:" in review_info:
            reviewed_part = review_info.split("Page last reviewed:")[1].split("Next review due")[0].strip()
            best["answer"] += f"\n\nInformation up-to-date as of: {reviewed_part}"

    # Both values of return_chosen have always produced this same tuple.
    return best["answer"], top_indices, boosted_used, best["doc_idx"]


In [None]:
def ask_models(query):
    """Run both RAG pipelines on ``query`` and format the answers as Markdown.

    Each pipeline is called with ``source_info=True``; only the answer text
    (the first element of each four-element result) is shown.

    Args:
        query: The user's question.

    Returns:
        A Markdown string with the concat answer, a horizontal rule, then the
        RagSeq answer.
    """
    concat_answer, _, _, _ = rag_pipeline_lora_top1title_concat(
        query, source_info=True
    )
    ragseq_answer, _, _, _ = rag_pipeline_lora_top1title_ragseq(
        query, source_info=True
    )

    return (
        f"**Concat Pipeline Answer:**\n{concat_answer}\n\n---\n\n"
        f"**RagSeq Pipeline Answer:**\n{ragseq_answer}"
    )

# Minimal Gradio UI: a query box and a send button on the left, the formatted
# answers from both pipelines rendered as Markdown on the right.
with gr.Blocks() as demo:
    gr.Markdown("## RAG LoRA App\nAsk your question containing a medical condition and see answers from both pipelines.")

    with gr.Row():
        with gr.Column():
            query = gr.Textbox(label="Your Medical Query", placeholder="Type your medical query here.")
            btn = gr.Button("Send")
        with gr.Column():
            output = gr.Markdown()

    # Clicking the button sends the textbox content to ask_models and renders
    # its return value in the Markdown panel.
    btn.click(fn=ask_models, inputs=query, outputs=output)

# share=True requests a public gradio.live URL, so the app is reachable from
# outside the notebook while it is running.
demo.launch(share=True, quiet=True)

* Running on public URL: https://9f1196debc1519eba6.gradio.live


