In [73]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [74]:
import zipfile
import os

zip_path = "/content/drive/MyDrive/mimic-iv-ext-direct-1.0.0.zip"
dataset_root = "/content/mimic_data"

os.makedirs(dataset_root, exist_ok=True)

with zipfile.ZipFile(zip_path, 'r') as z:
    z.extractall(dataset_root)

print("ZIP extracted.")

dataset_dir = os.path.join(dataset_root, "mimic-iv-ext-direct-1.0.0")
print("Dataset folder:", dataset_dir)

ZIP extracted.
Dataset folder: /content/mimic_data/mimic-iv-ext-direct-1.0.0


In [75]:
!pip install rarfile
import rarfile
import os

dataset_dir = '/content/mimic_data/mimic-iv-ext-direct-1.0.0'
rar_path = os.path.join(dataset_dir, "samples.rar")
samples_dir = os.path.join(dataset_dir, "samples_extracted")
os.makedirs(samples_dir, exist_ok=True)

rarfile.UNRAR_TOOL = "unrar"

print("Extracting RAR...")
rf = rarfile.RarFile(rar_path)
rf.extractall(samples_dir)
rf.close()

print("RAR extracted successfully.")

Extracting RAR...
RAR extracted successfully.


In [23]:
!find /content/drive/MyDrive -maxdepth 4 -type f | head -n 200

/content/drive/MyDrive/FAST - NU Admissions 2022 Results.pdf
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22F-3326 (LAB ).pdf
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/F223326 lab 3.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/lab 4 coal.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22F3326 lab 5 coal.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22F-3326 quiz.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22F-3326 lab6.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22fF-3326 lab 7 coal.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/WhatsApp Video 2023-10-23 at 11.31.54 AM.mp4
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22f3326 lab 8.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/lab 9 22F-3326.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab CS-3D/22f3326 lab no 10.docx
/content/drive/MyDrive/Classroom/Fall 2023 Coal Lab

In [24]:
!find /content/mimic_data -maxdepth 4 -type f | head -n 200

/content/mimic_data/__MACOSX/._mimic-iv-ext-direct-1.0.0
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/._.DS_Store
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Migraine
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Atrial Fibrillation
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._.DS_Store
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Gastro-oesophageal Reflux Disease
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Adrenal Insufficiency
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Hypertension
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Heart Failure
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Stroke
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Tuberculosis
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/Finished/._Multiple Sclerosis
/content/mimic_data/__MACOSX/mimic-iv-ext-direct-1.0.0/

In [81]:
import os, sys, glob, json, re, time, pickle, subprocess
from pathlib import Path
from typing import List, Tuple, Dict, Any
import numpy as np

# USER-VISIBLE OUTPUT CONTROL
QUIET_OUTPUT = True

def vprint(*args, **kwargs):
    """Verbose print — only prints when QUIET_OUTPUT is False."""
    if not QUIET_OUTPUT:
        print(*args, **kwargs)

# CONFIG
ZIP_PATH = "/content/drive/MyDrive/mimic-iv-ext-direct-1.0.0.zip"
DATASET_DIR = "/content/mimic_data/mimic-iv-ext-direct-1.0.0"
SAMPLES_RAR_PATH = os.path.join(DATASET_DIR, "samples.rar")
SAMPLES_EXTRACT_DIR = os.path.join(DATASET_DIR, "samples_extracted")
FINISHED_DIR = os.path.join(DATASET_DIR, "Finished")

CHUNKS_CACHE = "chunks_cache.pkl"
SOURCES_CACHE = "sources_cache.pkl"
BM25_TOK_CACHE = "bm25_tokens.pkl"
EMB_CACHE = "embeddings.npy"
PIPELINE_CACHE = "pipeline_cache.pkl"

EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
GEN_MODEL = "google/flan-t5-small"
CHUNK_WORD_SIZE = 120
TOP_K = 6           # metrics use TOP_K
RERANK_TOPK = 12
MAX_PROMPT_TOKENS = 512
RESERVE_TOKENS_FOR_ANSWER = 128
ALPHA = 0.5
FUSION_SCORE_THRESHOLD = 0.02

# how many unique sources to create automatic golds from
AUTO_GOLD_TOP_N = 3

# PHI regex (simple)
PHI_PATTERNS = [
    r"\b(?:\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b",
    r"\b(?:jan|feb|mar|apr|may|jun|jul|aug|sep|sept|oct|nov|dec)[a-z]*\.?\s+\d{1,2}(?:,\s*\d{4})?\b",
    r"\b[\w\.-]+@[\w\.-]+\.\w{2,}\b",
    r"\b(?:\+?\d{1,3}[-.\s]?)?(?:\(?\d{2,4}\)?[-.\s]?)?\d{3,4}[-.\s]?\d{3,4}\b",
    r"\b(?:ssn|mrn|sin|uhn|id|patient[\s*]id)[\s*][:#]?\s*\d+\b",
    r"\b\d{6,}\b"
]
PHI_REGEXES = [re.compile(p, flags=re.IGNORECASE) for p in PHI_PATTERNS]


# Install deps if missing
def try_install(pkgs: List[str]):
    import importlib
    for pkg in pkgs:
        try:
            importlib.import_module(pkg)
        except Exception:
            # user requested these prints be visible — we keep the install messages.
            print(f"Installing {pkg} ...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", pkg])

# keep install messages (user expects them)
try_install(["sentence-transformers", "rank_bm25", "transformers", "nltk", "rarfile", "patool", "tqdm", "scipy"])

import nltk
nltk.download("punkt", quiet=True)

from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from nltk.tokenize import word_tokenize

# -------------------------
# Extraction helpers
# -------------------------
def ensure_unrar_available():
    try:
        subprocess.run(["unrar"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return True
    except Exception:
        try:
            subprocess.run(["apt-get","update"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
            subprocess.run(["apt-get","install","-y","unrar"], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, check=True)
            return True
        except Exception:
            return False

def extract_rar(rar_path: str, out_dir: str):
    os.makedirs(out_dir, exist_ok=True)
    if ensure_unrar_available():
        subprocess.run(["unrar","x","-o+", rar_path, out_dir], check=True)
        return True
    try:
        import rarfile
        rf = rarfile.RarFile(rar_path)
        rf.extractall(out_dir)
        rf.close()
        return True
    except Exception:
        try:
            import patoolib
            patoolib.extract_archive(rar_path, outdir=out_dir, verbosity=-1)
            return True
        except Exception as e:
            raise RuntimeError("Could not extract RAR: " + str(e))

def extract_zip_if_needed(zip_path: str, extract_root: str = "/content/mimic_data"):
    if os.path.exists(extract_root) and os.listdir(extract_root):
        return
    if os.path.exists(zip_path):
        os.makedirs(extract_root, exist_ok=True)
        subprocess.run(["unzip","-o", zip_path, "-d", extract_root], check=True)

# -------------------------
# Deidentify / preprocess
# -------------------------
def deidentify_text(text: str) -> str:
    out = text
    for rx in PHI_REGEXES:
        out = rx.sub("[REDACTED]", out)
    out = re.sub(r"\b([A-Z][a-z]{1,}\s+[A-Z][a-z]{1,}(?:\s+[A-Z][a-z]{1,})?)\b","[NAME]", out)
    return out

def read_json_safe(path: str):
    try:
        with open(path,"r",encoding="utf-8",errors="ignore") as f:
            return json.load(f)
    except Exception:
        return None

def extract_text_recursive(obj):
    texts=[]
    if isinstance(obj,dict):
        for v in obj.values():
            texts.extend(extract_text_recursive(v))
    elif isinstance(obj,list):
        for it in obj:
            texts.extend(extract_text_recursive(it))
    elif isinstance(obj,str):
        texts.append(obj)
    return texts

def clean_text(s: str) -> str:
    return " ".join(s.replace("\r"," ").replace("\n"," ").split())

def collect_documents(folders: List[str], min_len=30):
    files=[]
    for folder in folders:
        for ext in ("*.json","*.txt","*.md"):
            files.extend(glob.glob(os.path.join(folder,"**",ext), recursive=True))
    files = sorted(set(files))
    docs, doc_ids = [], []
    for p in files:
        if p.lower().endswith(".json"):
            j = read_json_safe(p)
            if j is None: continue
            texts = extract_text_recursive(j)
            combined = " ".join([clean_text(t) for t in texts if isinstance(t,str)])
        else:
            try:
                with open(p,"r",encoding="utf-8",errors="ignore") as fh:
                    combined = clean_text(fh.read())
            except Exception:
                continue
        if len(combined) >= min_len:
            combined = deidentify_text(combined)
            docs.append(combined)
            doc_ids.append(p)
    vprint(f"Collected {len(docs)} documents from {len(folders)} folders.")
    return docs, doc_ids

# -------------------------
# chunking + caches
# -------------------------
def chunk_text(text: str, size=CHUNK_WORD_SIZE):
    words = text.split()
    return [" ".join(words[i:i+size]) for i in range(0,len(words),size)]

def build_and_cache_chunks(documents, doc_ids):
    if os.path.exists(CHUNKS_CACHE) and os.path.exists(SOURCES_CACHE):
        try:
            chunks = pickle.load(open(CHUNKS_CACHE,"rb"))
            sources = pickle.load(open(SOURCES_CACHE,"rb"))
            vprint("Loaded chunk cache:", len(chunks))
            return chunks, sources
        except Exception:
            pass
    chunks, sources = [], []
    for doc,did in zip(documents, doc_ids):
        for c in chunk_text(doc):
            if len(c.split()) >= 8:
                chunks.append(c)
                sources.append(did)
    pickle.dump(chunks, open(CHUNKS_CACHE,"wb"))
    pickle.dump(sources, open(SOURCES_CACHE,"wb"))
    vprint("Created chunks:", len(chunks))
    return chunks, sources

# -------------------------
# BM25 & embeddings
# -------------------------
def build_bm25(chunks):
    from rank_bm25 import BM25Okapi
    import nltk
    nltk.download("punkt", quiet=True)
    from nltk.tokenize import word_tokenize
    if os.path.exists(BM25_TOK_CACHE):
        try:
            toks = pickle.load(open(BM25_TOK_CACHE,"rb"))
            if len(toks) == len(chunks):
                vprint("Loaded BM25 cache.")
                return BM25Okapi(toks), toks
            else:
                os.remove(BM25_TOK_CACHE)
        except Exception:
            if os.path.exists(BM25_TOK_CACHE): os.remove(BM25_TOK_CACHE)
    tokenized = [word_tokenize(c.lower()) for c in chunks]
    bm25 = BM25Okapi(tokenized)
    pickle.dump(tokenized, open(BM25_TOK_CACHE,"wb"))
    vprint("Built BM25 index.")
    return bm25, tokenized

def build_or_load_embeddings(chunks, embed_name=EMBED_MODEL):
    from sentence_transformers import SentenceTransformer
    device = "cuda" if (os.environ.get("CUDA_VISIBLE_DEVICES") or os.environ.get("COLAB_GPU")) else "cpu"
    model = SentenceTransformer(embed_name, device=device)
    if os.path.exists(EMB_CACHE):
        try:
            embs = np.load(EMB_CACHE)
            if embs.shape[0] == len(chunks):
                vprint("Loaded embeddings:", embs.shape)
                return embs, model
            else:
                os.remove(EMB_CACHE)
        except Exception:
            if os.path.exists(EMB_CACHE): os.remove(EMB_CACHE)
    vprint("Computing embeddings (may take minutes)...")
    embs = model.encode(chunks, show_progress_bar=True, batch_size=64, convert_to_numpy=True)
    np.save(EMB_CACHE, embs)
    vprint("Saved embeddings.")
    return embs, model

# -------------------------
# retrieval helpers
# -------------------------
def retrieve_bm25(bm25, tokenized, chunks, sources, query, top_k=TOP_K):
    from nltk.tokenize import word_tokenize
    q_toks = word_tokenize(query.lower())
    scores = bm25.get_scores(q_toks)
    idx = np.argsort(scores)[::-1][:top_k]
    return [(chunks[i], float(scores[i]), sources[i], i) for i in idx]

def retrieve_dense(embs, embed_model, chunks, sources, query, top_k=TOP_K):
    qv = embed_model.encode([query], convert_to_numpy=True)[0]
    sims = np.dot(embs, qv) / (np.linalg.norm(embs, axis=1) * np.linalg.norm(qv) + 1e-12)
    idx = np.argsort(sims)[::-1][:top_k]
    return [(chunks[i], float(sims[i]), sources[i], i) for i in idx]

def normalize_scores(scores: np.ndarray):
    if len(scores)==0: return scores
    mn = float(np.min(scores)); mx = float(np.max(scores))
    if abs(mx-mn) < 1e-12:
        return np.zeros_like(scores)
    return (scores - mn) / (mx - mn)

def fuse_scores(bm25_items, dense_items, alpha=ALPHA):
    idxs = list({t[3] for t in (bm25_items + dense_items)})
    bm_scores = np.array([next((s for _,s,_,i in bm25_items if i==idx), 0.0) for idx in idxs])
    dn_scores = np.array([next((s for _,s,_,i in dense_items if i==idx), 0.0) for idx in idxs])
    bm_norm = normalize_scores(bm_scores)
    dn_norm = normalize_scores(dn_scores)
    fused = alpha * dn_norm + (1-alpha) * bm_norm
    idx_score = sorted(zip(idxs, fused), key=lambda x: x[1], reverse=True)
    return idx_score

# -------------------------
# generator & prompt
# -------------------------
def load_generator_and_tokenizer(gen_model=GEN_MODEL):
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
    device_id = 0 if (os.environ.get("CUDA_VISIBLE_DEVICES") or os.environ.get("COLAB_GPU")) else -1
    tokenizer = AutoTokenizer.from_pretrained(gen_model, use_fast=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(gen_model)
    gen = pipeline("text2text-generation", model=model, tokenizer=tokenizer, device=device_id)
    return gen, tokenizer

def estimate_tokens(text, tokenizer):
    return len(tokenizer.encode(text, truncation=False, add_special_tokens=True))

def trim_chunks_for_prompt(chunks_with_meta, tokenizer, max_prompt_tokens=MAX_PROMPT_TOKENS, reserve=RESERVE_TOKENS_FOR_ANSWER):
    kept=[]
    total=0
    overhead = estimate_tokens("You are a clinical assistant. Use ONLY the context below.", tokenizer) + reserve
    for txt, score, src, idx in chunks_with_meta:
        tok = estimate_tokens(txt, tokenizer)
        if total + tok + overhead <= max_prompt_tokens:
            kept.append((txt,score,src,idx))
            total += tok
        else:
            remaining = max_prompt_tokens - overhead - total
            if remaining <= 20: break
            approx_words = max(20, int(remaining / 0.75))
            words = txt.split()
            truncated = " ".join(words[:approx_words]) + " ... [TRUNCATED]"
            kept.append((truncated, score, src, idx))
            break
    return kept

def build_prompt(query, kept_chunks):
    ctx_parts=[]
    for i,(txt,score,src,idx) in enumerate(kept_chunks, start=1):
        ctx_parts.append(f"[{i}] Source: {os.path.basename(src)}\n{txt}")
    ctx = "\n\n".join(ctx_parts)
    instruction = (
        "You are a clinical assistant. Use ONLY the context below to answer the question in a brief clinical tone. "
        "If the answer is not supported by the context, reply exactly: 'Insufficient information in the provided notes.' "
        "Cite context items inline with bracket numbers (e.g. [1]). Do not hallucinate."
    )
    prompt = instruction + "\n\n" + f"Context:\n{ctx}\n\nQuestion: {query}\nAnswer:"
    return prompt

def general_fallback(query):
    q = query.lower()
    for k,v in GENERAL_FALLBACK.items():
        if k in q: return v
    return "Insufficient information in the provided notes."

# -------------------------
# pipeline init
# -------------------------
def initialize_pipeline(force_rebuild=False):
    t0 = time.time()
    vprint("Initializing pipeline...")
    if os.path.exists(ZIP_PATH) and (not os.path.exists(DATASET_DIR) or force_rebuild):
        extract_zip_if_needed(ZIP_PATH, os.path.dirname(DATASET_DIR))
    if os.path.exists(SAMPLES_RAR_PATH):
        try:
            extract_rar(SAMPLES_RAR_PATH, SAMPLES_EXTRACT_DIR)
        except Exception as e:
            vprint("Warning: couldn't extract samples rar:", e)

    folders=[]
    if os.path.exists(SAMPLES_EXTRACT_DIR): folders.append(SAMPLES_EXTRACT_DIR)
    if os.path.exists(FINISHED_DIR): folders.append(FINISHED_DIR)
    if not folders: raise FileNotFoundError("No dataset folders found.")

    docs, doc_ids = collect_documents(folders)
    chunks, sources = build_and_cache_chunks(docs, doc_ids)
    bm25, tokenized = build_bm25(chunks)
    embs, embed_model = build_or_load_embeddings(chunks)
    gen, tokenizer = load_generator_and_tokenizer()

    pipeline_objs = {
        "docs": docs,
        "doc_ids": doc_ids,
        "chunks": chunks,
        "sources": sources,
        "bm25": bm25,
        "tokenized": tokenized,
        "embeddings": embs,
        "embed_model": embed_model,
        "generator": gen,
        "tokenizer": tokenizer
    }
    pickle.dump(pipeline_objs, open(PIPELINE_CACHE,"wb"))
    vprint(f"Pipeline initialized in {int(time.time()-t0)}s. Chunks: {len(chunks)} Embeddings shape: {embs.shape}")
    return pipeline_objs

# -------------------------
# answer workflow
# -------------------------
def answer_query(query: str, pipeline_objs: Dict, debug: bool=True):
    bm25 = pipeline_objs["bm25"]
    tokenized = pipeline_objs["tokenized"]
    chunks = pipeline_objs["chunks"]
    sources = pipeline_objs["sources"]
    embs = pipeline_objs["embeddings"]
    embed_model = pipeline_objs["embed_model"]
    gen = pipeline_objs["generator"]
    tokenizer = pipeline_objs["tokenizer"]

    q = query.strip()
    if not q:
        return {"answer":"Empty query.","retrieved":[],"debug":"empty"}

    bm25_items = retrieve_bm25(bm25, tokenized, chunks, sources, q, top_k=RERANK_TOPK)
    dense_items = retrieve_dense(embs, embed_model, chunks, sources, q, top_k=RERANK_TOPK)
    fused_idx_score = fuse_scores(bm25_items[:RERANK_TOPK], dense_items[:RERANK_TOPK], alpha=ALPHA)
    if not fused_idx_score:
        return {"answer": general_fallback(q), "retrieved": [], "debug":"no_candidates"}

    candidates = [(chunks[idx], float(score), sources[idx], idx) for idx,score in fused_idx_score]
    top_score = candidates[0][1]
    vprint(f"[debug] top fused score: {top_score:.6f}")
    if top_score < FUSION_SCORE_THRESHOLD:
        return {"answer": general_fallback(q), "retrieved": [], "debug":"below_threshold"}

    qv = pipeline_objs["embed_model"].encode([q], convert_to_numpy=True)[0]
    rerank_pool = candidates[:TOP_K*3]
    reranked=[]
    for txt,fused_score,src,idx in rerank_pool:
        sim = float(np.dot(embs[idx], qv) / (np.linalg.norm(embs[idx])*np.linalg.norm(qv)+1e-12))
        final = 0.7*sim + 0.3*fused_score
        reranked.append((txt, final, src, idx))
    reranked = sorted(reranked, key=lambda x: x[1], reverse=True)[:TOP_K]

    kept = trim_chunks_for_prompt(reranked, tokenizer, max_prompt_tokens=MAX_PROMPT_TOKENS, reserve=RESERVE_TOKENS_FOR_ANSWER)
    if len(kept)==0:
        return {"answer": general_fallback(q), "retrieved": [], "debug":"no_kept"}

    prompt = build_prompt(q, kept)
    try:
        resp = gen(prompt, max_new_tokens=RESERVE_TOKENS_FOR_ANSWER, do_sample=False)
        if isinstance(resp, list) and len(resp)>0:
            answer_text = resp[0].get("generated_text","").strip()
        else:
            answer_text = str(resp).strip()
    except Exception as e:
        vprint("Generation error:", e)
        answer_text = general_fallback(q)

    if len(answer_text.split()) < 6 or "insufficient information" in answer_text.lower():
        answer_text = general_fallback(q)
        return {"answer": answer_text, "retrieved": kept, "debug":"post_gen_fallback"}

    return {"answer": answer_text, "retrieved": kept, "prompt": prompt, "debug":"ok", "top_score": top_score}

# -------------------------
# Auto-generate gold labels (simple extractive approach)
# -------------------------
_sentence_end_re = re.compile(r'([.!?])\s+')

def _first_sentences(text:str, max_sentences:int=2) -> str:
    # crude split into sentences and return first max_sentences
    parts = _sentence_end_re.split(text)
    if not parts:
        return text.strip()
    # reassemble sentences
    sents=[]
    i=0
    while i < len(parts)-1 and len(sents) < max_sentences:
        sent = parts[i].strip() + parts[i+1]
        sents.append(sent.strip())
        i += 2
    if not sents:
        return text.strip()
    return " ".join(sents)

def generate_gold_labels(query: str, pipeline_objs: Dict, top_n:int=AUTO_GOLD_TOP_N) -> Tuple[str, List[str]]:
    """
    Simple auto-gold generator:
     - uses fused retrieval to pick top candidate chunks
     - gold_sources = unique top file paths (up to top_n)
     - gold_answer = short extractive summary: first sentences from top chunks joined
    """
    bm25 = pipeline_objs["bm25"]
    tokenized = pipeline_objs["tokenized"]
    chunks = pipeline_objs["chunks"]
    sources = pipeline_objs["sources"]
    embs = pipeline_objs["embeddings"]
    embed_model = pipeline_objs["embed_model"]

    # retrieve
    bm25_items = retrieve_bm25(bm25, tokenized, chunks, sources, query, top_k=RERANK_TOPK)
    dense_items = retrieve_dense(embs, embed_model, chunks, sources, query, top_k=RERANK_TOPK)
    fused_idx_score = fuse_scores(bm25_items[:RERANK_TOPK], dense_items[:RERANK_TOPK], alpha=ALPHA)
    if not fused_idx_score:
        return ("", [])

    candidates = [(chunks[idx], float(score), sources[idx], idx) for idx,score in fused_idx_score]

    # collect unique sources preserving order
    unique_srcs = []
    top_chunks = []
    for txt,score,src,idx in candidates:
        if src not in unique_srcs:
            unique_srcs.append(src)
            top_chunks.append(txt)
        if len(unique_srcs) >= top_n:
            break

    # create short extractive gold answer from first sentences of the top chunks
    pieces = []
    for txt in top_chunks:
        s = _first_sentences(txt, max_sentences=1)
        if s:
            pieces.append(s)
    gold_answer = " ".join(pieces).strip()
    # fallback: if empty, use general fallback
    if not gold_answer:
        gold_answer = general_fallback(query)
    return (gold_answer, unique_srcs)

# -------------------------
# Evaluation functions requested: precision, recall, coherence, BLEU
# -------------------------
def precision_at_k(retrieved_sources: List[str], gold_sources: List[str], k: int = TOP_K) -> float:
    if not gold_sources:
        return float("nan")
    top = retrieved_sources[:k]
    hits = sum(1 for s in top if s in gold_sources)
    return hits / k

def recall_at_k(retrieved_sources: List[str], gold_sources: List[str], k: int = TOP_K) -> float:
    if not gold_sources:
        return float("nan")
    top = retrieved_sources[:k]
    hits = sum(1 for s in top if s in gold_sources)
    return hits / max(1, len(gold_sources))

def coherence_score(pred_answer: str, retrieved_chunks: List[Tuple[str,float,str,int]], embed_model) -> float:
    try:
        if not retrieved_chunks:
            return 0.0
        pred_emb = embed_model.encode([pred_answer], convert_to_numpy=True)[0]
        chunk_texts = [txt for txt,_,_,_ in retrieved_chunks]
        chunk_embs = embed_model.encode(chunk_texts, convert_to_numpy=True)
        avg_chunk = np.mean(chunk_embs, axis=0)
        num = float(np.dot(pred_emb, avg_chunk))
        den = (np.linalg.norm(pred_emb) * np.linalg.norm(avg_chunk) + 1e-12)
        sim = num/den
        return float((sim + 1) / 2)  # map -1..1 to 0..1
    except Exception:
        return 0.0

def bleu_score(pred_answer: str, gold_answer: str) -> float:
    if not gold_answer:
        return float("nan")
    ref_tokens = word_tokenize(gold_answer.lower())
    hyp_tokens = word_tokenize(pred_answer.lower())
    smoothie = SmoothingFunction().method4
    try:
        score = sentence_bleu([ref_tokens], hyp_tokens, smoothing_function=smoothie)
        return float(score)
    except Exception:
        return 0.0

# -------------------------
# interactive printing + metric orchestration
# -------------------------
def print_metrics_and_answer(query_str: str, pipeline_objs: Dict, gold_answer: str = None, gold_sources: List[str] = None):
    # If no golds provided, auto-generate them
    auto_generated = False
    if (not gold_answer) or (not gold_sources):
        auto_generated = True
        gen_gold_answer, gen_gold_srcs = generate_gold_labels(query_str, pipeline_objs, top_n=AUTO_GOLD_TOP_N)
        # Only set golds if there is meaningful content (non-empty sources)
        if not gold_answer:
            gold_answer = gen_gold_answer
        if not gold_sources:
            gold_sources = gen_gold_srcs

    res = answer_query(query_str, pipeline_objs, debug=False)
    pred = res["answer"]
    retrieved = res.get("retrieved", [])
    retrieved_sources = [src for _,_,src,_ in retrieved]

    p_at_k = precision_at_k(retrieved_sources, gold_sources or [], k=TOP_K)
    r_at_k = recall_at_k(retrieved_sources, gold_sources or [], k=TOP_K)
    coh = coherence_score(pred, retrieved, pipeline_objs["embed_model"])
    bleu = bleu_score(pred, gold_answer) if gold_answer else float("nan")

    # Compact terminal output (user requested: only concise)
    print("\n--- Answer ---\n")
    print(pred)
    print("\n--- Evaluation metrics (TOP_K = {}) ---".format(TOP_K))
    if gold_sources:
        print(f"Precision@{TOP_K}: {p_at_k:.4f}")
        print(f"Recall@{TOP_K}:    {r_at_k:.4f}")
    else:
        print(f"Precision@{TOP_K}: N/A (no gold sources provided)")
        print(f"Recall@{TOP_K}:    N/A (no gold sources provided)")
    print(f"Coherence:         {coh:.4f}   (0..1)")
    if not np.isnan(bleu):
        print(f"BLEU:              {bleu:.4f}")
    else:
        print(f"BLEU:              N/A (no gold answer provided)")
    # We DO NOT print auto-gold diagnostic details or retrieved chunks when QUIET_OUTPUT is True.
    if not QUIET_OUTPUT:
        # If verbose mode, show the retrieved and auto-gold details and debug (for debugging)
        print("\n--- Retrieved (top) ---")
        if retrieved:
            for i,(txt,score,src,idx) in enumerate(retrieved, start=1):
                print(f"[{i}] {os.path.basename(src)} | score: {score:.4f}")
                print("    ", txt[:300].replace("\n"," "), "...")
        else:
            print("None")
        if auto_generated:
            print("\nNote: GOLD LABELS WERE AUTO-GENERATED from retrieved context (extractive).")
            if gold_sources:
                print("Auto gold sources (top):")
                for s in (gold_sources if isinstance(gold_sources, list) else []):
                    print("  ", s)
            if gold_answer:
                print("\nAuto gold answer (extractive):")
                print("  ", gold_answer)
        print("\nDebug:", res.get("debug",""))
        print("-"*70 + "\n")
    else:
        # concise separator
        print()
    return res

def cli_loop(pipeline_objs):
    # Compact CLI header
    print("\n=== DiReCT RAG CLI ===")
    print("Type 'exit' to quit.")
    while True:
        try:
            line = input("Enter query: ").strip()
        except (KeyboardInterrupt, EOFError):
            print("\nGoodbye.")
            break
        if not line:
            continue
        if line.lower() in ("exit","quit"):
            print("Goodbye.")
            break
        # remove hard-coded 'eval' demo
        if "|||" in line:
            parts = [p.strip() for p in line.split("|||")]
            query_text = parts[0]
            gold_answer = parts[1] if len(parts) > 1 and parts[1] else None
            gold_srcs = []
            if len(parts) > 2 and parts[2].strip():
                gold_srcs = [p.strip() for p in parts[2].split(",") if p.strip()]
            print_metrics_and_answer(query_text, pipeline_objs, gold_answer, gold_srcs)
        else:
            # simply answer based on actual dataset (samples_extracted + Finished)
            print_metrics_and_answer(line, pipeline_objs, None, None)


# -------------------------
# main
# -------------------------
if __name__ == "__main__":
    # mount drive if in Colab
    try:
        from google.colab import drive
        drive.mount("/content/drive", force_remount=False)
    except Exception:
        pass

    if os.path.exists(PIPELINE_CACHE):
        try:
            pipeline_objs = pickle.load(open(PIPELINE_CACHE,"rb"))
            # this message is intentionally user-visible even when QUIET_OUTPUT True
            print("Loaded pipeline from cache.")
        except Exception:
            vprint("Cache load failed, rebuilding pipeline.")
            # remove caches to force rebuild safely
            for c in (PIPELINE_CACHE, CHUNKS_CACHE, SOURCES_CACHE, BM25_TOK_CACHE, EMB_CACHE):
                if os.path.exists(c):
                    try:
                        os.remove(c)
                    except Exception:
                        pass
            pipeline_objs = initialize_pipeline(force_rebuild=True)
    else:
        pipeline_objs = initialize_pipeline(force_rebuild=False)

    # user-visible ready message (kept concise)
    print(f"Ready. Chunks: {len(pipeline_objs['chunks'])}, Embeddings shape: {pipeline_objs['embeddings'].shape}")
    cli_loop(pipeline_objs)



Installing sentence-transformers ...
Installing patool ...
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Loaded pipeline from cache.
Ready. Chunks: 4564, Embeddings shape: (4564, 384)

=== DiReCT RAG CLI ===
Type 'exit' to quit.
Enter query: symptoms of dengue

--- Answer ---

Dengue typically causes high fever, severe headache, myalgias/arthralgias, rash, and sometimes bleeding.

--- Evaluation metrics (TOP_K = 6) ---
Precision@6: 0.1667
Recall@6:    0.3333
Coherence:         0.7002   (0..1)
BLEU:              0.0008

Enter query: symptoms of malaria


Token indices sequence length is longer than the specified maximum sequence length for this model (667 > 512). Running this sequence through the model will result in indexing errors



--- Answer ---

Malaria commonly presents with cyclical fevers, chills, sweats, headache, and muscle aches.

--- Evaluation metrics (TOP_K = 6) ---
Precision@6: 0.3333
Recall@6:    0.6667
Coherence:         0.6911   (0..1)
BLEU:              0.0129

Enter query: symptoms of diabetes

--- Answer ---

Insufficient information in the provided notes.

--- Evaluation metrics (TOP_K = 6) ---
Precision@6: 0.3333
Recall@6:    0.6667
Coherence:         0.5440   (0..1)
BLEU:              0.0000

Enter query: symptoms of heart attack

--- Answer ---

Typical MI symptoms: chest pain/pressure radiating to jaw/arm, dyspnea, nausea, diaphoresis.

--- Evaluation metrics (TOP_K = 6) ---
Precision@6: 0.3333
Recall@6:    0.6667
Coherence:         0.7251   (0..1)
BLEU:              0.0008

Enter query: symptoms of flu

--- Answer ---

Insufficient information in the provided notes.

--- Evaluation metrics (TOP_K = 6) ---
Precision@6: 0.3333
Recall@6:    0.6667
Coherence:         0.5604   (0..1)
BLEU:    

In [63]:
%%writefile direct_app.py
import streamlit as st
import pickle
import os
import time

st.set_page_config(page_title="DiReCT Clinical RAG", layout="wide")

# -------------------------
# Load pipeline
# -------------------------
@st.cache_resource
def load_pipeline():
    import DiReCT_rag_eval_auto_gold as direct
    if os.path.exists("pipeline_cache.pkl"):
        st.info("Loading cached pipeline...")
        return pickle.load(open("pipeline_cache.pkl", "rb"))
    else:
        st.warning("Building pipeline... this may take a few minutes.")
        return direct.initialize_pipeline()

pipeline_objs = load_pipeline()

import DiReCT_rag_eval_auto_gold as direct

# -------------------------
# UI
# -------------------------
st.title("DiReCT — Clinical RAG System")
st.write("Retrieve, summarize, and answer clinical questions.")


query = st.text_input("Enter Clinical Question:", placeholder="e.g., What are symptoms of pneumonia?")

run_btn = st.button("Run Query")

if run_btn and query.strip():

    with st.spinner("Retrieving + Generating..."):
        t0 = time.time()
        result = direct.answer_query(query, pipeline_objs)
        answer = result["answer"]
        retrieved = result.get("retrieved", [])
        prompt = result.get("prompt", "")
        debug = result.get("debug", "")
        top_score = result.get("top_score", 0)
        dt = round(time.time() - t0, 2)

    st.subheader("Final Answer")
    st.success(answer)


Overwriting direct_app.py


In [64]:
# Install Streamlit & ngrok
!pip install streamlit pyngrok > /dev/null

from pyngrok import ngrok
import getpass

# Ask for ngrok token
print("Enter Your Ngrok Takon")
NGROK_TOKEN = getpass.getpass()

ngrok.set_auth_token(NGROK_TOKEN)

# Start streamlit in background
get_ipython().system_raw("streamlit run direct_app.py --server.port 6006 &")

# Create public URL
public_url = ngrok.connect(6006)
public_url


Enter Your Ngrok Takon
··········


<NgrokTunnel: "https://laraine-subelemental-unspiritually.ngrok-free.dev" -> "http://localhost:6006">