# Dr.Bot Project 

In this notebook we develop and train the model, which is later bundled into a package by the export-package code.

## Imports & configuration flags

Set up PyTorch + Hugging Face Transformers for an offline pipeline that runs on Kaggle GPU.

In [1]:
############################################################
# 0. Imports & configuration flags
############################################################
import re
import time
import os
import pickle
import requests
import random
import hashlib
import pandas as pd
import numpy as np
import torch
import transformers
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import Counter
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 
from cachetools import TTLCache, cached

## Pubmedbert model training code

**Overview:** The below cell seeds randomness for reproducibility, selects an accelerated device (GPU/MPS if available, else CPU), and loads the BiomedNLP-PubMedBERT encoder to build a lightweight label-embedding classifier. We preprocess the [MedQuAD CSV](https://www.kaggle.com/datasets/jpmiller/layoutlm?resource=download) by composing a training text (Severity -- question + answer), normalizing strings, and pruning rare labels, then tokenize with the PubMedBERT tokenizer. Instead of a heavy MLP head, we create a prototype for each focus area by encoding the label text into an embedding and classify by comparing each query’s [CLS] vector to these prototypes with a learnable temperature—fast, stable, and simple. The data are split into train/test, wrapped in PyTorch Dataset/DataLoader, and the model is trained briefly with AdamW and cross-entropy. After training, we persist everything needed for offline inference—classifier.pt (weights), label_embs.pt (prototypes), id2label.pkl (mapping), and the tokenizer—so downstream Kaggle packages can load pre-trained artifacts without re-training and generate responses quickly.

In [None]:
# Model selection and initialization on an accelerator for faster computation
import random
import numpy as np
import torch

# ─── 0) Seed ───────────────────────────────────────────────────────────────
# Seed Python, NumPy and PyTorch so repeated runs draw the same random numbers.
random.seed(13)
np.random.seed(13)
torch.manual_seed(13)

# ─── 1) Model & Device ─────────────────────────────────────────────────────
MODEL_NAME = "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract-fulltext"

# Prefer a CUDA GPU, then Apple MPS (Metal Performance Shaders), then CPU.
# The previous check only looked for MPS, so a CUDA machine silently fell back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
elif torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
else:
    DEVICE = torch.device("cpu")
print(f"Using device: {DEVICE}")

# ─── 2) Hyperparameters ────────────────────────────────────────────────────
BATCH_SIZE      = 8      # examples per optimisation step
EPOCHS          = 1      # passes over the training split
MAX_LEN         = 512    # tokenizer truncation / padding length
MIN_FREQ        = 4      # labels with fewer rows than this are pruned
VERBALIZE_LABEL = False  # False = embed raw label text; flip to True to compare verbalized labels

# ─── 3) Load your model and move it to the MPS device ───────────────────────
from transformers import AutoModel, AutoTokenizer

tok   = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME).to(DEVICE)'''

# Model training code

'''
# ─── 2) Load & preprocess CSV ────────────────────────────────────────────────
# Reads the MedQuAD CSV; the columns used below are Severity, question, answer
# (and focus_area further down).
df = pd.read_csv("/kaggle/input/medquad-dataset/medquad_reclassified.csv")
# build text column with severity, question, answer
# Severity is trimmed/lower-cased and mapped to a single tag; any value outside the
# three keys maps to NaN, which makes the concatenated text NaN as well.
df["text"] = (
    df["Severity"].str.strip().str.lower().map({
        'normal':   'SeverityNormal',
        'moderate': 'SeverityModerate',
        'extreme':  'SeverityExtreme'
    })
    + " -- "
    + df["question"].str.strip()
    + " "
    + df["answer"].str.strip()
)
# Rows whose text came out NaN (unmapped severity or missing question/answer) are dropped here.
df = df.dropna(subset=["text"]).reset_index(drop=True)
# Character class: ASCII hyphen, the range U+2010..U+2013, and the em dash.
dash_pat = r"[-‐-–—]"
# Every dash becomes a space — this also blanks the " -- " separator inserted above.
df["text"] = df["text"].str.replace(dash_pat, " ", regex=True)

def canon(lbl: str) -> str:
    """Return a canonical form of a label string.

    Unicode dash variants are mapped to an ASCII hyphen, every run of
    whitespace is collapsed to one space, and the result is trimmed and
    lower-cased.
    """
    with_ascii_dashes = re.sub(r"[‐-–—]", "-", lbl)
    single_spaced = re.sub(r"\s+", " ", with_ascii_dashes)
    return single_spaced.strip().lower()

# Normalise the label column: cast to str, replace dashes with spaces (same pattern
# as for the text column), then canonicalise with canon().
# NOTE(review): astype(str) presumably turns missing labels into the literal string
# "nan", which would then be treated as a regular label — confirm against the data.
df["focus_area"] = (
    df["focus_area"].astype(str)
       .str.replace(dash_pat, " ", regex=True)
       .apply(canon)
)
# prune rare labels
# Keep only labels that occur at least MIN_FREQ times, then drop all other rows.
valid_labels = df["focus_area"].value_counts()[lambda x: x >= MIN_FREQ].index
df = df[df["focus_area"].isin(valid_labels)].reset_index(drop=True)

# ─── 3) Tokenizer & BERT backbone ────────────────────────────────────────────
tok  = transformers.AutoTokenizer.from_pretrained(MODEL_NAME)
bert = transformers.AutoModel.from_pretrained(MODEL_NAME).to(DEVICE).eval()

# ─── 4) Build label↔️id maps & embeddings (no verbalisation!) ─────────────────
label2id = {l: i for i, l in enumerate(sorted(df["focus_area"].unique()))}
id2label = {i: l for l, i in label2id.items()}

@torch.no_grad()
def encode_text(txt: str) -> torch.Tensor:
    """Encode one string with the backbone and return its first-token vector on the CPU.

    The text is tokenized without padding (truncated to MAX_LEN), run through
    ``bert`` on DEVICE, and the hidden state at position 0 is squeezed and
    moved back to the CPU.
    """
    encoded = tok(
        txt,
        return_tensors="pt",
        truncation=True,
        max_length=MAX_LEN,
        padding=False,
    )
    encoded = encoded.to(DEVICE)
    hidden = bert(**encoded).last_hidden_state
    first_token = hidden[:, 0]
    return first_token.squeeze().cpu()

# If VERBALIZE_LABEL were True, you’d wrap here, but it’s False so we just use raw labels
label_embs = torch.stack([encode_text(label) for label in label2id.keys()])

# ─── 5) Train/Test Split ─────────────────────────────────────────────────────
train_df, test_df = train_test_split(
    df,
    test_size=max(int(0.15 * len(df)), len(label2id)),
    stratify=df["focus_area"],
    random_state=42
)

# ─── 6) Dataset & DataLoader ─────────────────────────────────────────────────
class QADataset(Dataset):
    """Training dataset: tokenizes the ``text`` column lazily, one row per access.

    Labels come from the ``focus_area`` column mapped through the module-level
    ``label2id`` dictionary.
    """

    def __init__(self, frame):
        self.texts = frame["text"].tolist()
        self.labels = frame["focus_area"].map(label2id).tolist()

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Pad/truncate every example to MAX_LEN so the default collate can stack them.
        encoded = tok(
            self.texts[idx],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=MAX_LEN,
        )
        sample = {}
        for field, tensor in encoded.items():
            sample[field] = tensor.squeeze()
        sample["label"] = torch.tensor(self.labels[idx])
        return sample

train_dl = DataLoader(QADataset(train_df), batch_size=BATCH_SIZE, shuffle=True)
test_dl  = DataLoader(QADataset(test_df),  batch_size=BATCH_SIZE)

# ─── 7) Label‑Embedding Classifier ────────────────────────────────────────────
class LabelEmbCls(nn.Module):
    """Classifier that scores the encoder's first-token vector against fixed label embeddings.

    ``lbl_E`` is stored as a non-trainable parameter on DEVICE; ``tau`` is a
    trainable scalar that divides the dot-product scores.
    """

    def __init__(self, base, lbl_emb):
        super().__init__()
        self.bert = base
        self.lbl_E = nn.Parameter(lbl_emb.to(DEVICE), requires_grad=False)
        self.tau = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        first_token = encoded.last_hidden_state[:, 0]
        scores = torch.matmul(first_token, self.lbl_E.T)
        return scores / self.tau

model     = LabelEmbCls(bert, label_embs).to(DEVICE)
criterion = nn.CrossEntropyLoss()
optim     = torch.optim.AdamW(model.parameters(), lr=3e-5)

# ─── 8) Training Loop ────────────────────────────────────────────────────────
# Fine-tune the classifier for EPOCHS passes over train_dl, showing a running
# mean loss on the progress bar and the epoch mean at the end of each pass.
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0   # sum of per-example losses seen so far this epoch
    seen = 0           # number of examples seen so far this epoch
    batch_bar = tqdm(train_dl, desc=f"Epoch {epoch+1}/{EPOCHS}", leave=True)
    for batch in batch_bar:
        labels = batch.pop("label").to(DEVICE)
        inputs = {k: v.to(DEVICE) for k, v in batch.items()}
        optim.zero_grad()
        loss = criterion(model(**inputs), labels)
        loss.backward()
        optim.step()
        # criterion returns the batch mean; scale back to a per-example sum.
        total_loss += loss.item() * labels.size(0)
        seen += labels.size(0)
        # Divide by the examples actually processed. The bar's own counter
        # (batch_bar.n) is only refreshed at display intervals and the final
        # batch may be smaller than BATCH_SIZE, so it is not a reliable denominator.
        avg_loss = total_loss / seen
        batch_bar.set_postfix({'loss': f'{avg_loss:.4f}'})
    print(f"Epoch {epoch+1} completed. Avg Loss: {total_loss/len(train_dl.dataset):.4f}")

print("Training complete.")
'''

#Save the trained model

'''
# Persist everything the inference cells reload: classifier weights, label
# embeddings, the id->label map and the tokenizer files, all in one folder.
import os
import torch
import pickle

# NOTE(review): Kaggle normally mounts /kaggle/input read-only — confirm; if so these
# writes must target a writable location (e.g. under /kaggle/working) and be uploaded afterwards.
ARTIFACT_DIR = "/kaggle/input/pubmedbert_model/pytorch/default/1"      # pick a folder
os.makedirs(ARTIFACT_DIR, exist_ok=True)

# 1a) Save the classifier weights
# The state dict covers the wrapped encoder (model.bert) as well as lbl_E and tau.
torch.save(model.state_dict(), os.path.join(ARTIFACT_DIR, "classifier.pt"))

# 1b) Save the label embeddings tensor and id2label map
torch.save(label_embs, os.path.join(ARTIFACT_DIR, "label_embs.pt"))
with open(os.path.join(ARTIFACT_DIR, "id2label.pkl"), "wb") as f:
    pickle.dump(id2label, f)

# 1c) Save the tokenizer
tok.save_pretrained(ARTIFACT_DIR)


'\nimport os\nimport torch\nimport pickle\n\nARTIFACT_DIR = "/kaggle/input/pubmedbert_model/pytorch/default/1"      # pick a folder\nos.makedirs(ARTIFACT_DIR, exist_ok=True)\n\n# 1a) Save the classifier weights\ntorch.save(model.state_dict(), os.path.join(ARTIFACT_DIR, "classifier.pt"))\n\n# 1b) Save the label embeddings tensor and id2label map\ntorch.save(label_embs, os.path.join(ARTIFACT_DIR, "label_embs.pt"))\nwith open(os.path.join(ARTIFACT_DIR, "id2label.pkl"), "wb") as f:\n    pickle.dump(id2label, f)\n\n# 1c) Save the tokenizer\ntok.save_pretrained(ARTIFACT_DIR)\n'

## Model Validation Code

## Runtime setup & paths (offline, Kaggle-friendly)

1. Device: choose GPU if available (cuda) else CPU.
2. Inputs mounted under /kaggle/input:
   1. ARTIFACT_DIR → classifier artifacts: tokenizer, label_embs.pt, id2label.pkl, classifier.pt.
   2. BACKBONE_DIR → PubMedBERT base weights (config.json + pytorch_model.bin).
   3. VAL_DATA_PATH → validation CSV with columns question (text) and focus_area (label).
3. Batching/length: BATCH_SIZE=32, MAX_LEN=64 for tokenization (tuned for speed on T4).
4. Offline-first: all HF loads use local_files_only=True (no internet calls).

In [3]:
# Inference device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Folder holding the tokenizer files, label_embs.pt, id2label.pkl and classifier.pt.
ARTIFACT_DIR = "/kaggle/input/pubmedbert_model/pytorch/default/1"
# Folder the PubMedBERT backbone is loaded from.
BACKBONE_DIR = "/kaggle/input/pubmedbert_base/pytorch/default/1"
# NOTE(review): this is the same CSV path the training cell reads, so the rows scored
# here presumably overlap the training split — confirm before quoting the accuracy as held-out.
VAL_DATA_PATH = "/kaggle/input/medquad-dataset/medquad_reclassified.csv"
BATCH_SIZE = 32
# NOTE(review): the training cell uses MAX_LEN = 512 on severity + question + answer text;
# here inputs are truncated to 64 tokens — confirm this mismatch is intended.
MAX_LEN = 64

## Reload tokenizer, embeddings, label map, and BERT backbone
1. Tokenizer loaded from your artifact directory to match training.
2. Backbone: AutoModel (encoder-only PubMedBERT) loaded from BACKBONE_DIR, moved to DEVICE, eval() for deterministic inference.
3. Label embeddings: label_embs.pt is a [num_labels, hidden_size] matrix; moved to DEVICE.
4. id2label.pkl: maps integer class IDs → human-readable focus area string.
   1. We’ll invert this later to get label2id for supervision.

In [4]:
# 1) Reload tokenizer, label embeddings, id2label, and backbone
# local_files_only=True keeps both loads from reaching out to the Hugging Face Hub.
tok = AutoTokenizer.from_pretrained(ARTIFACT_DIR, local_files_only=True)

from transformers import AutoModel
# The backbone is put in eval mode immediately; it is wrapped by the classifier below.
bert = AutoModel.from_pretrained(BACKBONE_DIR, local_files_only=True).to(DEVICE).eval()

# map_location already places the tensor on DEVICE; the trailing .to(DEVICE) is a no-op safeguard.
label_embs = torch.load(os.path.join(ARTIFACT_DIR, "label_embs.pt"), map_location=DEVICE).to(DEVICE)
# NOTE: pickle.load can execute arbitrary code — only load an id2label.pkl you produced yourself.
with open(os.path.join(ARTIFACT_DIR, "id2label.pkl"), "rb") as f:
    id2label = pickle.load(f)

2025-08-12 16:43:26.024993: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1755017006.271313      14 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1755017006.338751      14 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


## Label-Embedding classifier head (simple, fast)
1. Defines LabelEmbCls:
   1. Runs BERT, takes the [CLS] vector: last_hidden_state[:, 0].
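   2. Scores each label with the dot product between the [CLS] vector and that label's embedding, scaled by a temperature (the vectors are not normalized, so this is not a true cosine):
      logits = (CLS @ label_embsᵀ) / tau
   3. lbl_E is frozen (requires_grad=False); tau was learned during training and is restored from the checkpoint, so it stays fixed at inference.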
2. Why this head? Minimal parameters, stable, and efficient vs. adding an MLP on top.

In [5]:
class LabelEmbCls(nn.Module):
    """Classifier that scores the encoder's first-token vector against fixed label embeddings.

    ``lbl_E`` is held as a non-trainable parameter and ``tau`` as a scalar
    parameter that divides the dot-product scores.
    """

    def __init__(self, base, lbl_emb):
        super().__init__()
        self.bert = base
        self.lbl_E = nn.Parameter(lbl_emb, requires_grad=False)
        self.tau = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        first_token = encoded.last_hidden_state[:, 0]
        scores = torch.matmul(first_token, self.lbl_E.T)
        return scores / self.tau

model = LabelEmbCls(bert, label_embs).to(DEVICE)

## Restore classifier checkpoint safely
1. Loads classifier.pt from ARTIFACT_DIR.
2. If the checkpoint was saved with DataParallel, strips the module. prefix.
3. Uses strict=False so minor key differences don’t break loading.
4. Prints any missing/unexpected keys (should be empty for a perfect restore).
5. Sets model.eval() for inference-only behavior.

In [6]:
# Load the saved classifier weights onto DEVICE.
state = torch.load(os.path.join(ARTIFACT_DIR, "classifier.pt"), map_location=DEVICE)
# if saved with DataParallel:
# Every key then starts with "module."; strip it only as a leading prefix so a
# key that merely contains "module." elsewhere is left intact.
if any(k.startswith("module.") for k in state.keys()):
    state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state.items()
    }

# strict=False tolerates key differences; both lists should print empty for a full restore.
missing, unexpected = model.load_state_dict(state, strict=False)
print("missing:", missing, "unexpected:", unexpected)

model.eval()

missing: [] unexpected: []


LabelEmbCls(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_a

## Validation dataset & loader
1. Assumptions about CSV:
   1. question → free text to classify.
   2. focus_area → gold label string that must be present in label2id.
2. TextDataset:
   1. On __getitem__: tokenizes each question to fixed length (MAX_LEN) with **truncation + padding**.
   2. Returns input_ids, attention_mask (and token_type_ids if tokenizer provides), plus labels as integer IDs.
3. Label mapping: inverts id2label to get label2id; filters rows whose focus_area is missing from the mapping (prints a warning).
4. DataLoader: batches validation samples with BATCH_SIZE=32.

In [7]:
# 3) Prepare validation dataset
#    Expects a frame with 'question' (input text) and 'focus_area' (gold label) columns
class TextDataset(Dataset):
    """Validation dataset that tokenizes one question per access.

    Args:
        df: frame-like object whose ``question`` and ``focus_area`` columns expose ``tolist()``.
        tokenizer: callable tokenizer returning a mapping of tensors with a leading batch axis.
        max_len: fixed length every example is padded/truncated to.
        label2id: mapping from label string to integer id; an unknown label raises KeyError.
    """

    def __init__(self, df, tokenizer, max_len, label2id):
        self.texts = df["question"].tolist()
        self.labels = [label2id[name] for name in df["focus_area"].tolist()]
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoded = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt",
        )
        # Drop the leading batch axis so the DataLoader can stack examples itself.
        sample = {}
        for field, tensor in encoded.items():
            sample[field] = tensor.squeeze(0)
        sample["labels"] = torch.tensor(self.labels[idx], dtype=torch.long)
        return sample

# invert id2label for easy mapping
label2id = {v: k for k, v in id2label.items()}

# load your validation DataFrame
df_val = pd.read_csv(VAL_DATA_PATH)

# Filter out rows with focus_area not in label2id
# NOTE(review): the training cell canonicalises focus_area (dashes -> spaces, lower-case,
# collapsed whitespace) before building the label map, while this cell compares the raw
# CSV strings — rows whose raw label differs from its canonical form would be dropped
# here. Confirm the row count that survives this filter.
missing_labels = set(df_val["focus_area"].dropna()) - set(label2id.keys())
if missing_labels:
    print("Warning: The following labels are missing from label2id and will be dropped:", missing_labels)
df_val = df_val[df_val["focus_area"].isin(label2id.keys())].reset_index(drop=True)

# No shuffle: evaluation order does not matter.
val_ds = TextDataset(df_val, tok, MAX_LEN, label2id)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE)



## Evaluation loop (no grad)
1. Loss: nn.CrossEntropyLoss() over the model’s raw logits vs. integer labels.
2. Moves batch tensors to DEVICE; handles optional token_type_ids (some BERT tokenizers don’t return it).
3. Forward pass: logits = model(input_ids, attention_mask, token_type_ids).
4. Metrics: accumulates total loss (scaled by batch size), computes predicted class via argmax, and counts correct predictions.

In [8]:
# 4) Evaluation loop
# Accumulates the summed cross-entropy, the number of correct argmax predictions
# and the number of examples over the whole validation loader, without gradients.
loss_fn = nn.CrossEntropyLoss()
total_loss, correct, total = 0.0, 0, 0

with torch.no_grad():
    for batch in tqdm(val_loader, desc="Evaluating"):
        input_ids = batch["input_ids"].to(DEVICE)
        attention_mask = batch["attention_mask"].to(DEVICE)
        labels = batch["labels"].to(DEVICE)

        # some tokenizers return token_type_ids
        token_type_ids = batch.get("token_type_ids")
        if token_type_ids is not None:
            token_type_ids = token_type_ids.to(DEVICE)

        logits = model(input_ids, attention_mask, token_type_ids)
        loss = loss_fn(logits, labels)
        preds = logits.argmax(dim=-1)

        batch_n = labels.size(0)
        total_loss += loss.item() * batch_n   # loss is a batch mean; re-weight by batch size
        correct += (preds == labels).sum().item()
        total += batch_n

Evaluating:   0%|          | 0/83 [00:00<?, ?it/s]

## Report metrics
1. Computes avg_loss = total_loss / total_examples.
2. Computes accuracy = correct / total * 100.
3. Prints Validation loss (useful for sanity) and Validation accuracy (headline metric).

In [9]:
# 5) Report
# Mean loss and accuracy over every evaluated example. If no example was
# evaluated (e.g. every row was filtered out), report NaN with a warning
# instead of failing with ZeroDivisionError.
if total:
    avg_loss = total_loss / total
    accuracy = correct / total * 100
else:
    print("Warning: no validation examples were evaluated; metrics are undefined.")
    avg_loss = float("nan")
    accuracy = float("nan")
print(f"Validation loss: {avg_loss:.4f}")
print(f"Validation accuracy: {accuracy:.2f}%")

Validation loss: 0.0393
Validation accuracy: 99.05%


## a) Paths, runtime, and limits (offline & reproducible)
1. Device: selects GPU if available; otherwise CPU.
2. Mounted inputs:
   1. ARTIFACT_DIR → classifier artifacts (tokenizer, label_embs.pt, id2label.pkl, classifier.pt)
   2. BACKBONE_DIR → PubMedBERT base weights
   3. WIKI_DIR → curated .txt background files (one per focus area)
   4. PHYS_MODEL_DIR → Mistral-7B-Instruct (mistral-pytorch-7b-instruct-v0.1-hf-v1) (Apache-2.0) model
3. Preview/context knobs: control section teasers and max background size fed to the LLM.
4. Noise control: suppresses verbose TF logs.
5. Concept: Offline, deterministic pipeline → no network calls; easier to audit & reproduce.

## b) Helper utilities (normalize, section, preview, retrieve)
1. _norm() → canonicalizes labels/filenames (ASCII, lowercase, alnum) so lookups are robust.
2. split_sections() → splits article text into paragraph blocks (blank-line delimited).
3. preview_section() → shows a short, sentence-bounded teaser per section.
4. build_wiki_index() → maps normalized filename → absolute path for fast retrieval.
5. read_background() → loads the matched .txt content safely.
6. condense_for_context() → keeps early sections up to a cap for compact prompts.
7. Concept: Deterministic local retrieval (RAG-lite) keeps prompts focused and fast.

## c) Background indexing (one-time)
1. Scans WIKI_DIR and builds an O(1) lookup dict for background files.
2. Prints how many focus areas were indexed for transparency.
3. Concept: Pre-indexing avoids per-query I/O scans → lower latency.

## d) Classifier artifacts: tokenizer, embeddings, label map, backbone
1. Loads the exact tokenizer used in training from ARTIFACT_DIR.
2. Restores label_embs.pt (shape [num_labels, hidden]) and id2label.pkl (id → name).
3. Loads PubMedBERT encoder from BACKBONE_DIR and switches to eval().
4. Concept: Matching tokenizer + backbone ensures feature space consistency at eval time.

## f) Label-embedding head (lightweight classifier)
1. LabelEmbCls:
   1. Forward: BERT → take [CLS] vector.
   2. Compute logits = (CLS @ label_embeddingsᵀ) / τ with learnable temperature tau.
2. Why this: minimal params, stable, and fast; acts like a prototypical classifier.
3. Concept: Metric-learning vibe: classify via similarity to label prototypes instead of a big MLP head.

## g) Restore classifier checkpoint (robustly)
1. Loads classifier.pt; strips module. if saved under DataParallel.
2. Uses strict=False to tolerate benign key diffs; prints missing/unexpected keys for sanity.
3. Sets eval() for deterministic inference.
4. Concept: Defensive loading avoids brittle failures across training/eval environments.

## h) Physician generator (Mistral-7B-Instruct) — quantized
1. Loads tokenizer + model from PHYS_MODEL_DIR with local_files_only=True.
2. Tries 4-bit nf4 quantization (BitsAndBytesConfig) with device_map="auto" to fit Kaggle’s T4; falls back to fp16/fp32 if needed.
3. Ensures tokenizer has EOS & PAD (sets PAD=EOS); sets left truncation for long prompts.
4. Concept: Quantization = big model quality on small GPU; tokenizer PAD/EOS hygiene prevents padding errors and speeds up single-prompt inference.

In [10]:
import os, re, pickle, unicodedata
from pathlib import Path
import torch
import torch.nn as nn
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM

# ---------- PATHS ----------
ARTIFACT_DIR   = "/kaggle/input/pubmedbert_model/pytorch/default/1"      # tokenizer + label_embs.pt + id2label.pkl + classifier.pt
BACKBONE_DIR   = "/kaggle/input/pubmedbert_base/pytorch/default/1"       # config.json + pytorch_model.bin
WIKI_DIR       = "/kaggle/input/wiki-data"                               # folder with *.txt files
PHYS_MODEL_DIR = "/kaggle/input/physician_transformer/pytorch/default/1" # Mistral-7B-Instruct v0.1
# ---------------------------

# Preview + context limits
PREVIEW_SENTENCES = 3
PREVIEW_CHARS     = 600
CTX_CHAR_LIMIT    = 3500

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ---------- Helpers ----------
def _norm(s: str) -> str:
    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    return "".join(ch.lower() for ch in s if ch.isalnum())

def split_sections(text: str):
    """Split text into non-empty, trimmed blocks separated by one or more blank lines."""
    sections = []
    for chunk in re.split(r"(?:\r?\n\s*\r?\n)+", text.strip()):
        chunk = chunk.strip()
        if chunk:
            sections.append(chunk)
    return sections

def preview_section(sec: str, max_sents=PREVIEW_SENTENCES, max_chars=PREVIEW_CHARS):
    """Return a short teaser made of the leading sentences of a section.

    Whole sentences are taken from the start of ``sec`` until either
    ``max_sents`` sentences or ``max_chars`` characters would be exceeded.
    " …" is appended only when part of the section was left out.

    Args:
        sec: section text.
        max_sents: maximum number of sentences to keep.
        max_chars: character budget for the kept sentences.

    Returns:
        The teaser string; "" for an empty or whitespace-only section.
    """
    body = sec.strip()
    # Sentences are split after ., ! or ? followed by whitespace.
    sents = [s for s in re.split(r"(?<=[.!?])\s+", body) if s]
    out, total = [], 0
    for s in sents:
        if len(out) >= max_sents or total + len(s) > max_chars:
            break
        out.append(s)
        total += len(s) + 1  # +1 for the joining space
    if not out and sents and max_sents > 0:
        # The first sentence alone exceeds the budget: hard-cut it rather than
        # returning a preview that consists of nothing but the ellipsis.
        return body[:max_chars].rstrip() + " …"
    prev = " ".join(out).strip()
    # Decide on the ellipsis by sentences dropped, not by comparing lengths:
    # collapsing whitespace can shorten the text even when nothing was omitted.
    return prev + (" …" if len(out) < len(sents) else "")

def build_wiki_index(folder: str):
    """Walk ``folder`` recursively and map each .txt file's normalized stem to its path.

    When two files normalize to the same key, the one visited later wins.
    """
    assert os.path.isdir(folder), f"WIKI_DIR not found: {folder}"
    return {
        _norm(Path(name).stem): os.path.join(dirpath, name)
        for dirpath, _subdirs, names in os.walk(folder)
        for name in names
        if name.lower().endswith(".txt")
    }

def read_background(label: str, index: dict):
    """Look up the background file for ``label`` and return ``(text, path)``.

    Returns ``(None, None)`` when the normalized label is not in ``index``,
    the indexed path is not a file, or reading it fails.
    """
    path = index.get(_norm(label))
    if not path or not os.path.isfile(path):
        return None, None
    try:
        # Undecodable bytes are skipped rather than raising.
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
    except Exception:
        return None, None
    return content, path

def condense_for_context(text: str, char_limit: int = CTX_CHAR_LIMIT) -> str:
    """Keep leading sections of ``text`` while their combined size fits ``char_limit``.

    Each section is charged its length plus 2 (for the blank-line separator).
    If not even the first section fits, the raw text is cut at ``char_limit``.
    """
    kept = []
    used = 0
    for section in split_sections(text):
        cost = len(section) + 2
        if used + cost > char_limit:
            break
        kept.append(section)
        used += cost
    if not kept:
        return text[:char_limit]
    return "\n\n".join(kept)

# ---------- Background index ----------
wiki_index = build_wiki_index(WIKI_DIR)
print(f"[Info] Indexed {len(wiki_index)} background files from {WIKI_DIR}")

# ---------- Classifier ----------
tok = AutoTokenizer.from_pretrained(ARTIFACT_DIR, local_files_only=True)
label_embs = torch.load(os.path.join(ARTIFACT_DIR, "label_embs.pt"), map_location=DEVICE).to(DEVICE)
with open(os.path.join(ARTIFACT_DIR, "id2label.pkl"), "rb") as f:
    id2label = pickle.load(f)

bert = AutoModel.from_pretrained(BACKBONE_DIR, local_files_only=True).to(DEVICE).eval()

class LabelEmbCls(nn.Module):
    """Classifier that scores the encoder's first-token vector against fixed label embeddings.

    ``lbl_E`` is held as a non-trainable parameter and ``tau`` as a scalar
    parameter that divides the dot-product scores.
    """

    def __init__(self, base, lbl_emb):
        super().__init__()
        self.bert = base
        self.lbl_E = nn.Parameter(lbl_emb, requires_grad=False)
        self.tau = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        first_token = encoded.last_hidden_state[:, 0]
        scores = torch.matmul(first_token, self.lbl_E.T)
        return scores / self.tau

# Wrap the backbone with the label-embedding head and restore the trained weights.
model = LabelEmbCls(bert, label_embs).to(DEVICE)
state = torch.load(os.path.join(ARTIFACT_DIR, "classifier.pt"), map_location=DEVICE)
# Checkpoints saved from a DataParallel wrapper prefix every key with "module.";
# strip it only as a leading prefix so other parts of a key are never altered.
if isinstance(state, dict) and any(k.startswith("module.") for k in state.keys()):
    state = {
        (k[len("module."):] if k.startswith("module.") else k): v
        for k, v in state.items()
    }
# strict=False tolerates key differences, so surface any mismatch instead of
# silently running with weights that were never restored.
load_result = model.load_state_dict(state, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(f"[Warn] classifier.pt key mismatch - missing: {list(load_result.missing_keys)} | "
          f"unexpected: {list(load_result.unexpected_keys)}")
model.eval()

# ---------- Physician model (Mistral) ----------
# Loads the generator tokenizer and weights from local files only. A 4-bit NF4
# load is attempted when CUDA is present; on any failure (or without CUDA) the
# model is loaded unquantized and moved to DEVICE.
phys_tok = AutoTokenizer.from_pretrained(PHYS_MODEL_DIR, local_files_only=True)

load_kwargs = dict(local_files_only=True, low_cpu_mem_usage=True)
phys_model = None
if torch.cuda.is_available():
    # The from_pretrained call is inside the try so that a failure raised while the
    # quantized weights are loaded also triggers the fallback, not only a failure
    # in building the config.
    try:
        from transformers import BitsAndBytesConfig
        bnb_cfg = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        phys_model = AutoModelForCausalLM.from_pretrained(
            PHYS_MODEL_DIR,
            quantization_config=bnb_cfg,
            device_map="auto",
            **load_kwargs,
        )
    except Exception as exc:  # best-effort: report and fall back below
        print(f"[Warn] 4-bit load failed ({exc!r}); falling back to an unquantized load")
        phys_model = None

if phys_model is None:
    fallback_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    # Without device_map the weights land on the CPU, so move them explicitly to
    # the device the prompts are sent to.
    phys_model = AutoModelForCausalLM.from_pretrained(
        PHYS_MODEL_DIR, torch_dtype=fallback_dtype, **load_kwargs
    ).to(DEVICE)
phys_model.eval()

if phys_tok.eos_token_id is None:
    # "</s>" is the end-of-sequence token of the Mistral/Llama tokenizer family;
    # an empty string is not a usable special token.
    phys_tok.add_special_tokens({"eos_token": "</s>"})
    phys_model.resize_token_embeddings(len(phys_tok))
if phys_tok.pad_token_id is None:
    # Reuse EOS as PAD on the tokenizer too, so code that reads
    # phys_tok.pad_token_id gets a real id instead of None.
    phys_tok.pad_token = phys_tok.eos_token
phys_model.config.eos_token_id = phys_tok.eos_token_id
phys_model.config.pad_token_id = phys_tok.eos_token_id
# When a prompt is truncated, drop tokens from the start and keep the end.
phys_tok.truncation_side = "left"

[Info] Indexed 4644 background files from /kaggle/input/wiki-data


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

## a) Timed inference-only loop (fast, grounded, one-paragraph answers)
1. What this cell does (high level): Takes a user query → classifies its focus area with PubMedBERT → retrieves the matching local wiki .txt → builds a small context (token-budgeted) → asks Mistral to write one short paragraph (empathetic, no lists) → prints detailed timings.
2. Concept: Router-then-generate (RAG-lite) — classifier picks topic → local retrieval → LLM writes a paragraph grounded in that text.

## b) Knobs (speed/length trade-offs)
1. CTX_BG_TOKEN_BUDGET (default 450) — how many background tokens to keep. Lower = faster, Higher = more context.
2. MAX_NEW_TOKENS (120) / MIN_NEW_TOKENS (70) — target ~4–5 sentences.
3. SHOW_FOCUS — print the predicted focus label for transparency.
4. Concept: Token budgets cap both prompt size and response length for predictable latency.

## c) Helper utilities
1. _norm() → canonical label/filename matching (ASCII, lowercase, alnum).
2. _get_background_for_focus() → load the correct .txt by predicted label.
3. _clamp_by_tokens() → trim background by tokens (not characters) to fit the budget.
4. _finalize_paragraph() → collapse whitespace, strip bullets/numbering, and ensure a full sentence end (. ? !).
5. Concept: Deterministic local retrieval + post-processing guarantees clean, grounded paragraphs.

## d) answer_query_timed(q) — step by step
1. Classify: tokenize PREFIX + q (BERT format) → run PubMedBERT label-embedding head → focus label.
2. Retrieve+Clamp: load the focus area .txt and token-budget it (keeps prompt compact).
3. Prompt+Tokenize: instruction enforces one short paragraph, no lists, empathetic tone; no padding (single prompt).
4. Generate: Mistral runs greedy decoding with caching; PAD/EOS properly set; TF32 matmul enabled for speed.
5. Decode+Finalize: extract after “Answer”, strip numbering, end at last sentence delimiter.

It also prints:
[timing] classify | bg+clamp | tokenize | prompt_tokens | generate | total

6. Concept: Each stage is timed so you (and judges) can see where latency lives and why it’s controlled.

## e) Tuning strategies
1. Faster: set CTX_BG_TOKEN_BUDGET=350–400 and MAX_NEW_TOKENS≈100.
2. Longer: raise MAX_NEW_TOKENS (keep MIN_NEW_TOKENS proportionally).
3. More transparency: set SHOW_FOCUS=True to display the predicted focus area with each answer.
4. Concept: Clear, isolated knobs let you choose the balance between speed, length, and grounding.

In [None]:

# === Timed inference-only loop (uses already-loaded tok/model/phys_tok/phys_model/etc.) ===
import re, time, unicodedata, torch, os

# knobs
SHOW_FOCUS            = False      # True -> also print predicted focus
CTX_BG_TOKEN_BUDGET   = 450        # background tokens kept (speed/quality)
MAX_NEW_TOKENS        = 120        # aim for ~4–5 sentences
MIN_NEW_TOKENS        = 70

# defaults so we don't depend on globals
DEFAULT_PREFIX  = "SeverityNormal -- "
DEFAULT_MAX_LEN = 64

def _norm(s: str) -> str:
    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")
    return "".join(ch.lower() for ch in s if ch.isalnum())

def _get_background_for_focus(focus: str) -> str:
    """Return the background text indexed under ``focus``, or "" when unavailable.

    The label is normalized with ``_norm`` and looked up in the module-level
    ``wiki_index``; a missing entry or a failed read both yield "".
    """
    path = wiki_index.get(_norm(focus))
    if not path:
        return ""
    try:
        # Undecodable bytes are skipped rather than raising.
        with open(path, "r", encoding="utf-8", errors="ignore") as f:
            content = f.read()
    except Exception:
        return ""
    return content

def _clamp_by_tokens(text: str, tok, budget: int) -> str:
    ids = tok.encode(text, add_special_tokens=False)
    return text if len(ids) <= budget else tok.decode(ids[:budget], skip_special_tokens=True)

def _finalize_paragraph(s: str) -> str:
    s = s.replace("\n", " ").strip()
    s = re.sub(r"(^|\s)(?:\d+[\.\)]|[-*•])\s+", r" ", s)   # strip numbering/bullets
    s = re.sub(r"\s+", " ", s).strip()
    last = max(s.rfind("."), s.rfind("?"), s.rfind("!"))
    return s[: last + 1] if last != -1 else (s.rstrip(",;:- ") + ".")

def answer_query_timed(q: str, prefix: str = DEFAULT_PREFIX, max_len: int = DEFAULT_MAX_LEN) -> str:
    """Answer one user query and print per-stage timings.

    Stages: (1) classify the query into a focus area, (2) load that focus area's
    background text and clamp it to CTX_BG_TOKEN_BUDGET tokens, (3) build and
    tokenize the prompt, (4) generate greedily, (5) decode and tidy the answer.

    Args:
        q: the user's question.
        prefix: text prepended to ``q`` before classification.
        max_len: tokenizer length used for the classifier input.

    Returns:
        The finalized one-paragraph answer (also printed).

    NOTE(review): the training cell replaces dashes in its text with spaces, while
    the default prefix still contains " -- " — confirm the classifier input format.
    """
    t_all = time.perf_counter()

    # (1) classify
    t0 = time.perf_counter()
    enc = tok(prefix + q, truncation=True, max_length=max_len, padding="max_length", return_tensors="pt")
    enc = {k: v.to(DEVICE) for k, v in enc.items()}
    with torch.no_grad():
        focus = id2label[int(torch.argmax(model(**enc), dim=-1).item())]
    t_cls = time.perf_counter() - t0

    # (2) background + clamp
    t0 = time.perf_counter()
    bg = _get_background_for_focus(focus)
    context = _clamp_by_tokens(bg, phys_tok, CTX_BG_TOKEN_BUDGET)
    t_bg = time.perf_counter() - t0

    # (3) build prompt + tokenize (no padding)
    prompt = (
        "You are a board-certified physician. Using ONLY the background below, write ONE short paragraph "
        "(4–5 sentences). Be empathetic, give practical next steps, and mention urgent-care signs only if warranted. "
        "Do NOT use bullet points, numbering, or line breaks.\n\n"
        f"Background:\n{context}\n\n"
        f"User question: {q}\n"
        "Answer (one short paragraph, no lists or numbering):\n"
    )
    t0 = time.perf_counter()
    inputs = phys_tok(prompt, return_tensors="pt", truncation=True)
    n_in  = int(inputs["input_ids"].shape[1])
    # Send the prompt to wherever the generator's weights actually live.
    inputs = {k: v.to(phys_model.device) for k, v in inputs.items()}
    t_tok = time.perf_counter() - t0

    # (4) generate
    torch.backends.cuda.matmul.allow_tf32 = True
    # Fall back to EOS when the tokenizer defines no PAD token, so generate()
    # never receives pad_token_id=None.
    pad_id = phys_tok.pad_token_id if phys_tok.pad_token_id is not None else phys_tok.eos_token_id
    t0 = time.perf_counter()
    with torch.inference_mode():
        out = phys_model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            min_new_tokens=MIN_NEW_TOKENS,
            do_sample=False,
            use_cache=True,
            pad_token_id=pad_id,
            eos_token_id=phys_tok.eos_token_id,
        )
    t_gen = time.perf_counter() - t0

    # (5) decode + finalize
    # Decode only the tokens produced after the prompt. Splitting the full text on
    # the word "Answer" breaks whenever the background or the question contains it.
    ans = phys_tok.decode(out[0][n_in:], skip_special_tokens=True)
    ans = _finalize_paragraph(ans)

    t_total = time.perf_counter() - t_all
    print(f"[timing] classify: {t_cls:.2f}s | bg+clamp: {t_bg:.2f}s | tokenize: {t_tok:.2f}s | "
          f"prompt_tokens: {n_in} | generate: {t_gen:.2f}s | total: {t_total:.2f}s")
    if SHOW_FOCUS:
        print(f"Focus area: {focus}\n")
    print("Response:\n", ans, "\n")
    return ans

# Interactive prompt: read queries until the user types exit/quit or input ends.
print("Ready! Ask your query. Type 'exit' to quit.\n")
while True:
    try:
        q = input("Query ▶ ").strip()
    except (EOFError, KeyboardInterrupt):
        # End of input (e.g. a non-interactive run) or Ctrl-C: leave quietly.
        print("\nExiting.")
        break
    if q.lower() in ("exit", "quit"):
        print("Goodbye!")
        break
    if q:
        answer_query_timed(q)   # uses default prefix/max_len; override if needed

'\n# === Timed inference-only loop (uses already-loaded tok/model/phys_tok/phys_model/etc.) ===\nimport re, time, unicodedata, torch, os\n\n# knobs\nSHOW_FOCUS            = False      # True -> also print predicted focus\nCTX_BG_TOKEN_BUDGET   = 450        # background tokens kept (speed/quality)\nMAX_NEW_TOKENS        = 120        # aim for ~4–5 sentences\nMIN_NEW_TOKENS        = 70\n\n# defaults so we don\'t depend on globals\nDEFAULT_PREFIX  = "SeverityNormal -- "\nDEFAULT_MAX_LEN = 64\n\ndef _norm(s: str) -> str:\n    s = unicodedata.normalize("NFKD", s).encode("ascii", "ignore").decode("ascii")\n    return "".join(ch.lower() for ch in s if ch.isalnum())\n\ndef _get_background_for_focus(focus: str) -> str:\n    path = wiki_index.get(_norm(focus))\n    if not path: return ""\n    try:\n        with open(path, "r", encoding="utf-8", errors="ignore") as f:\n            return f.read()\n    except Exception:\n        return ""\n\ndef _clamp_by_tokens(text: str, tok, budget: int) -> str