# Dr.Bot Project

## Exporting to Kaggle Package

Here we set up the code that initializes the model class by pulling the base model and then loading our trained weights. The resulting package is then carried forward for testing in the LLM-Output code.

In [None]:
#| default_exp core

In [None]:
#| export

# ===== Kaggle Package: package.Model =====
# Single-class package with .predict(str) -> str
import re
import time
import os
import pickle
import requests
import random
import hashlib
import pandas as pd
import numpy as np
import torch
import transformers
from torch import nn
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from collections import Counter
from tqdm.auto import tqdm
from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig 
from cachetools import TTLCache, cached
import os, re, unicodedata, pickle, torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
import kagglehub


# --- small helpers (no external deps) ---
def _norm(s: str) -> str:
    """Reduce *s* to a lowercase, ASCII-only, alphanumeric lookup key.

    The text is NFKD-decomposed, every character that cannot be encoded as
    ASCII is dropped, and only alphanumeric characters are kept.
    """
    decomposed = unicodedata.normalize("NFKD", s)
    ascii_text = decomposed.encode("ascii", errors="ignore").decode("ascii")
    kept = filter(str.isalnum, ascii_text)
    return "".join(kept).lower()

def _finalize_paragraph(s: str) -> str:
    """Collapse generated text into one paragraph ending on a sentence boundary.

    Processing order:
      1. newlines become spaces;
      2. list markers (``1.``, ``2)``, ``-``, ``*``, ``•``) that start the text
         or follow whitespace are removed;
      3. runs of whitespace are collapsed to a single space;
      4. the text is cut right after the last ``.``, ``?`` or ``!``.

    If the text contains no sentence terminator, trailing ``,;:-`` and spaces
    are stripped and a period is appended.

    Returns:
        The cleaned paragraph, or an empty string when nothing is left after
        cleaning (instead of a lone ".").
    """
    s = s.replace("\n", " ").strip()
    # NOTE(review): this pattern also matches a number that ends a sentence
    # (e.g. " 2020. "), so such numbers are dropped together with real markers.
    s = re.sub(r"(^|\s)(?:\d+[\.\)]|[-*•])\s+", " ", s)     # remove numbering/bullets
    s = re.sub(r"\s+", " ", s).strip()

    last = max(s.rfind(ch) for ch in ".?!")
    if last != -1:
        # Drop any unfinished trailing sentence.
        return s[: last + 1]

    s = s.rstrip(",;:- ")
    # Do not fabricate a "." out of empty / punctuation-only input.
    return s + "." if s else ""

class LabelEmbCls(torch.nn.Module):
    """BERT encoder followed by a frozen label-embedding head scaled by temperature τ.

    Logits are the dot products between the encoder's first-token ([CLS])
    hidden state and each row of the label-embedding matrix, divided by τ.
    """

    def __init__(self, base, lbl_emb):
        super().__init__()
        self.bert = base
        # Label embeddings are stored as a parameter but excluded from gradient updates.
        self.lbl_E = torch.nn.Parameter(lbl_emb, requires_grad=False)
        # Temperature τ, initialised to 1.0 (left trainable).
        self.tau = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return the temperature-scaled logits for a batch of encoded inputs."""
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Hidden state at position 0 of every sequence, i.e. the [CLS] token.
        first_token = encoded.last_hidden_state[:, 0]
        similarity = first_token @ self.lbl_E.T
        return similarity / self.tau

class Model:
    """
    Submission package.
    Usage (grader):
        package = kagglehub.package_import('your-username/your-notebook/versions/X')
        model = package.Model()
        print(model.predict("I have a headache..."))

    ``predict`` runs three steps: classify the question into a focus label,
    look up a background text file for that label, then prompt the physician
    LLM with background + question and return one cleaned-up paragraph.
    """

    def __init__(
        self,
        artifact_dir=None,
        backbone_dir=None,
        wiki_dir=None,
        phys_dir=None,
        prefix="SeverityNormal -- ",
        max_len=64,
        ctx_bg_token_budget=450,
        max_new_tokens=120,
        min_new_tokens=70,
    ):
        """Download (if needed) and load every artifact used by ``predict``.

        Args:
            artifact_dir: Directory holding the classifier tokenizer,
                ``label_embs.pt``, ``id2label.pkl`` and ``classifier.pt``.
                Fetched with ``kagglehub.model_download`` when ``None``.
            backbone_dir: Directory of the encoder loaded with ``AutoModel``.
                Fetched with ``kagglehub.model_download`` when ``None``.
            wiki_dir: Directory scanned recursively for ``.txt`` background
                files. Fetched with ``kagglehub.dataset_download`` when ``None``.
            phys_dir: Directory of the causal LM and its tokenizer.
                Fetched with ``kagglehub.model_download`` when ``None``.
            prefix: Text prepended to the question before classification.
            max_len: Token length the classifier input is truncated/padded to.
            ctx_bg_token_budget: Maximum number of LLM-tokenizer tokens of
                background text placed in the prompt.
            max_new_tokens: Upper bound on generated tokens.
            min_new_tokens: Lower bound on generated tokens.
        """
        import warnings

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.prefix = prefix
        self.max_len = max_len
        self.ctx_bg_token_budget = ctx_bg_token_budget
        self.max_new_tokens = max_new_tokens
        self.min_new_tokens = min_new_tokens

        # --- resolve inputs via kagglehub (NO /kaggle/input anywhere) ---
        if artifact_dir is None:
            artifact_dir = kagglehub.model_download("dhyeyk29/pubmedbert_model/PyTorch/default/1")
        if backbone_dir is None:
            backbone_dir = kagglehub.model_download("dhyeyk29/pubmedbert_base/PyTorch/default/1")
        if wiki_dir is None:
            wiki_dir = kagglehub.dataset_download("dhyeyk29/wikipedia-data")
        if phys_dir is None:
            phys_dir = kagglehub.model_download("dhyeyk29/physician_transformer/PyTorch/default/1")

        # --- build wiki index once (fast lookup) ---
        # Maps the normalized file stem to the file path; on a key collision
        # the file visited last wins.
        self.wiki_index = {}
        if os.path.isdir(wiki_dir):
            for root, _, files in os.walk(wiki_dir):
                for fn in files:
                    if fn.lower().endswith(".txt"):
                        self.wiki_index[_norm(Path(fn).stem)] = os.path.join(root, fn)

        # --- load classifier artifacts (offline) ---
        self.tok = AutoTokenizer.from_pretrained(artifact_dir, local_files_only=True)
        self.bert = AutoModel.from_pretrained(backbone_dir, local_files_only=True).to(self.device).eval()
        label_embs = torch.load(os.path.join(artifact_dir, "label_embs.pt"),
                                map_location=self.device).to(self.device)
        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only point artifact_dir at artifacts you trust.
        with open(os.path.join(artifact_dir, "id2label.pkl"), "rb") as f:
            self.id2label = pickle.load(f)

        self.cls_model = LabelEmbCls(self.bert, label_embs).to(self.device)
        state = torch.load(os.path.join(artifact_dir, "classifier.pt"), map_location=self.device)
        if isinstance(state, dict) and any(k.startswith("module.") for k in state):
            # Strip only a *leading* "module." (DataParallel-style wrapper prefix);
            # a blanket replace would also corrupt keys such as "submodule.x".
            wrapper = "module."
            state = {(k[len(wrapper):] if k.startswith(wrapper) else k): v
                     for k, v in state.items()}
        load_result = self.cls_model.load_state_dict(state, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            # strict=False would otherwise hide a checkpoint/model mismatch.
            warnings.warn(
                "classifier.pt did not match the classifier exactly: "
                f"missing={list(load_result.missing_keys)}, "
                f"unexpected={list(load_result.unexpected_keys)}",
                RuntimeWarning,
            )
        self.cls_model.eval()

        # --- load physician LLM (guarded quantization) ---
        self.phys_tok = AutoTokenizer.from_pretrained(phys_dir, local_files_only=True)
        self.phys_model = self._load_physician_model(phys_dir)

        # tokenizer/model hygiene
        if self.phys_tok.eos_token_id is None:
            self.phys_tok.add_special_tokens({"eos_token": "</s>"})
            self.phys_model.resize_token_embeddings(len(self.phys_tok))
        if self.phys_tok.pad_token_id is None:
            self.phys_tok.pad_token = self.phys_tok.eos_token
        self.phys_model.config.eos_token_id = self.phys_tok.eos_token_id
        self.phys_model.config.pad_token_id = self.phys_tok.pad_token_id
        # When the prompt is too long, drop tokens from its start so the
        # question and the answer cue at the end are kept.
        self.phys_tok.truncation_side = "left"

    # ---- private helpers ----
    def _load_physician_model(self, phys_dir):
        """Load the causal LM from *phys_dir* (offline) and put it in eval mode.

        Without CUDA the model is loaded in float32 on CPU. With CUDA a 4-bit
        NF4 load is attempted first; if building the quantization config or
        the quantized load itself fails, a warning is emitted and the model is
        loaded in fp16 instead.
        """
        import warnings

        common = dict(local_files_only=True, low_cpu_mem_usage=True)

        if not torch.cuda.is_available():
            return AutoModelForCausalLM.from_pretrained(
                phys_dir, torch_dtype=torch.float32, device_map={"": "cpu"}, **common
            ).eval()

        try:
            from transformers import BitsAndBytesConfig
            bnb_cfg = BitsAndBytesConfig(
                load_in_4bit=True, bnb_4bit_quant_type="nf4",
                bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.bfloat16
            )
            # The load must sit inside the guard: this call, not the config
            # constructor, is where the quantized load actually happens.
            return AutoModelForCausalLM.from_pretrained(
                phys_dir, quantization_config=bnb_cfg, device_map="auto", **common
            ).eval()
        except Exception as err:  # deliberate best-effort: any failure -> fp16
            warnings.warn(
                f"4-bit load of the physician model failed ({err!r}); falling back to fp16.",
                RuntimeWarning,
            )

        return AutoModelForCausalLM.from_pretrained(
            phys_dir, torch_dtype=torch.float16, device_map="auto", **common
        ).eval()

    def _classify(self, q: str) -> str:
        """Return the label with the highest classifier logit for ``prefix + q``."""
        enc = self.tok(self.prefix + q, truncation=True, max_length=self.max_len,
                       padding="max_length", return_tensors="pt")
        enc = {k: v.to(self.device) for k, v in enc.items()}
        with torch.no_grad():
            logits = self.cls_model(**enc)
            pred_id = int(torch.argmax(logits, dim=-1).item())
        return self.id2label[pred_id]

    def _get_background(self, focus: str) -> str:
        """Return the background text indexed under *focus*, or "" if unavailable."""
        path = self.wiki_index.get(_norm(focus))
        if not path:
            return ""
        try:
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                return f.read()
        except OSError:
            # Background is optional; an unreadable file just means no context.
            return ""

    def _clamp_by_tokens(self, text: str, budget: int) -> str:
        """Return *text* cut to at most *budget* tokens of the LLM tokenizer."""
        ids = self.phys_tok.encode(text, add_special_tokens=False)
        if len(ids) <= budget:
            return text
        return self.phys_tok.decode(ids[:budget], skip_special_tokens=True)

    def _generate_answer(self, prompt: str) -> str:
        """Greedy-decode a continuation of *prompt* and return only the new text.

        The prompt tokens are sliced off the generated sequence before
        decoding, so nothing from the prompt (background, question, or a
        stray "Answer" inside them) can end up in the returned string.
        """
        inputs = self.phys_tok(prompt, return_tensors="pt", truncation=True)  # no padding
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        prompt_len = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            gen = self.phys_model.generate(
                **inputs,
                max_new_tokens=self.max_new_tokens,
                min_new_tokens=self.min_new_tokens,
                do_sample=False,
                use_cache=True,
                pad_token_id=self.phys_tok.pad_token_id,
                eos_token_id=self.phys_tok.eos_token_id,
            )

        # The output of a causal LM's generate() starts with the prompt tokens.
        return self.phys_tok.decode(gen[0][prompt_len:], skip_special_tokens=True)

    # ---- public API ----
    def predict(self, question: str) -> str:
        """Takes a single question string, returns a single response string.

        An empty / whitespace-only / ``None`` question returns a fixed request
        for more detail without running any model.
        """
        question = (question or "").strip()
        if not question:
            return "Please provide a short description of your concern."

        # 1) classify
        focus = self._classify(question)

        # 2) retrieve background (optional but preferred)
        bg = self._get_background(focus)
        context = self._clamp_by_tokens(bg, self.ctx_bg_token_budget) if bg else ""

        # 3) prompt & generate (single short paragraph)
        prompt = (
            "You are a board-certified physician. Using ONLY the background below, write ONE short paragraph "
            "(4–5 sentences). Be empathetic, give practical next steps, and mention urgent-care signs only if warranted. "
            "Do NOT use bullet points, numbering, or line breaks.\n\n"
            f"Background:\n{context}\n\n"
            f"User question: {question}\n"
            "Answer (one short paragraph, no lists or numbering):\n"
        )

        return _finalize_paragraph(self._generate_answer(prompt))