<a href="https://colab.research.google.com/github/Sama-Borhani/medical-llm-chatbot-lora-faiss-gradio/blob/main/notebooks/Med_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🩺 Medical LLM Chatbot — RAG + Gradio (Demo)

> **Educational use only — not medical advice.**

This Colab notebook sets up dependencies, builds a tiny FAISS index from sample text (or your own `.txt` files), loads `Writer/camel-5b-hf` in 4-bit (falling back to 8-bit), and launches a Gradio demo.

In [None]:
# Install required packages (quiet)
!pip -q install transformers accelerate bitsandbytes torch
!pip -q install sentence-transformers faiss-cpu
!pip -q install gradio PyPDF2 requests peft jinja2 python-dotenv
import torch, os, sys, json, time, re, requests, faiss, numpy as np
from pathlib import Path
print('GPU available:', torch.cuda.is_available())
if torch.cuda.is_available():
    print('GPU name:', torch.cuda.get_device_name(0))

## 🔑 Environment
Set a Hugging Face token **only if** your model needs it. (Camel-5B is public, so you can leave it blank.)

In [None]:
# Hugging Face access token. It is read from the HF_TOKEN environment variable
# so the secret never has to be pasted into (and saved with) the notebook.
# If required, set it before running this cell, e.g. `%env HF_TOKEN=...`.
# The empty-string default means "no token"; later cells pass `HF_TOKEN or None`.
HF_TOKEN = os.environ.get('HF_TOKEN', '')
MODEL_ID = 'Writer/camel-5b-hf'  # causal LM loaded by the model cell
EMBED_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'  # embedder used for indexing and queries
TOP_K = 4  # default number of snippets requested per question

## 🧠 Load base model (4-bit → fallback 8-bit)

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

# `token=` is the current keyword for Hub authentication; the older
# `use_auth_token=` spelling is deprecated. None means anonymous access.
hf_token = HF_TOKEN or None

tok = AutoTokenizer.from_pretrained(MODEL_ID, token=hf_token)
# Fall back to the EOS token when the tokenizer defines no pad token.
if tok.pad_token is None:
    tok.pad_token = tok.eos_token


def _load_quantized(quant_cfg):
    """Load MODEL_ID with the given bitsandbytes quantization config."""
    return AutoModelForCausalLM.from_pretrained(
        MODEL_ID, device_map='auto', quantization_config=quant_cfg,
        token=hf_token, torch_dtype=torch.float16
    )


# First attempt: 4-bit NF4 with double quantization and float16 compute.
bnb = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True,
                         bnb_4bit_quant_type='nf4', bnb_4bit_compute_dtype=torch.float16)
try:
    print('Loading 4-bit...')
    mdl = _load_quantized(bnb)
except Exception as e:
    # Deliberate best-effort: any 4-bit failure triggers a single 8-bit retry.
    # A failure of the retry propagates to the caller.
    print('4-bit failed:', e, '\nRetrying 8-bit...')
    bnb = BitsAndBytesConfig(load_in_8bit=True)
    mdl = _load_quantized(bnb)

# Device that tokenized inputs are moved to before generation.
device = mdl.device
print('Model loaded on', device)

## 📚 Corpus (add your own or use samples)
We’ll create two tiny sample files if `corpus/` contains no `.txt` files. You can also upload your own `.txt` files to the `corpus/` folder using the file sidebar.

In [None]:
# Directory holding the plain-text documents to index; created if missing.
corpus = Path('corpus')
corpus.mkdir(exist_ok=True)

# Seed the directory with two sample documents only when it has no .txt files,
# so user-supplied files are never overwritten.
if not any(corpus.glob('*.txt')):
    # Written as UTF-8 explicitly to match the UTF-8 read in the indexing cell,
    # instead of relying on the platform default encoding.
    (corpus / 'hypertension_basics.txt').write_text(
        'Hypertension is persistently elevated blood pressure. Lifestyle changes and, when needed, ACEi/ARB, thiazides, or CCBs are common first-line choices.',
        encoding='utf-8',
    )
    (corpus / 'type2_diabetes_summary.txt').write_text(
        'Type 2 diabetes is characterized by insulin resistance. Lifestyle changes and metformin are first-line; other agents depend on comorbidities.',
        encoding='utf-8',
    )
    print('Added two sample docs to corpus/.')
print('Corpus files:', [p.name for p in corpus.glob('*.txt')])

## 🔎 Build in-memory FAISS index

In [None]:
from sentence_transformers import SentenceTransformer

# Collect every non-empty .txt document in sorted path order.
# `texts` and `meta` stay aligned so a row number identifies both entries.
texts, meta = [], []
for doc_path in sorted(corpus.glob('*.txt')):
    body = doc_path.read_text(encoding='utf-8', errors='ignore').strip()
    if not body:
        continue  # skip files that are empty after stripping whitespace
    texts.append(body)
    meta.append({'path': str(doc_path), 'title': doc_path.stem})
assert texts, 'No text found in corpus/. Add .txt files.'

# Embed each document as one normalized vector and store the vectors in a
# flat inner-product FAISS index held in memory.
embedder = SentenceTransformer(EMBED_MODEL)
emb = embedder.encode(texts, convert_to_numpy=True, normalize_embeddings=True)
index = faiss.IndexFlatIP(emb.shape[1])
index.add(emb)
print('Index size:', index.ntotal)

## 🧩 Retrieval + Prompting helpers

In [None]:
from jinja2 import Template

# Instruction text placed at the top of every prompt.
SYSTEM = (
    'You are a careful medical assistant for educational purposes only. '
    'Answer briefly and cite snippet numbers like [#]. If unsure, say so.'
)

# Prompt template. `sources` is an iterable of (zero_based_index, text) pairs;
# each snippet is rendered as "[n] text" with 1-based numbering and cut to 220
# characters. The "..." marker is emitted only when the snippet was actually
# cut, so short snippets are not presented to the model as truncated.
TEMPLATE = Template(
    """
{{system}}

Context:
{% for i,s in sources %}[{{i+1}}] {{s[:220]}}{% if s|length > 220 %}...{% endif %}\n{% endfor %}

Question: {{q}}
Answer with: 1) Summary, 2) Bullet points, 3) References by [#].
"""
)

def retrieve(query, k=TOP_K):
    """Look up the indexed documents that best match ``query``.

    Args:
        query: Question text; embedded with the same ``embedder`` used to
            build ``index``.
        k: Number of neighbours requested from the index (defaults to TOP_K).

    Returns:
        A list of ``(text, metadata)`` tuples in the order the index returned
        them. Returned positions outside ``0 <= pos < len(texts)`` are skipped,
        so the list may hold fewer than ``k`` items.
    """
    query_vec = embedder.encode([query], normalize_embeddings=True).astype('float32')
    _scores, hit_rows = index.search(query_vec, k)
    n_docs = len(texts)
    # Out-of-range positions are dropped (presumably FAISS's placeholder when
    # fewer than k vectors exist; verify against the FAISS docs).
    return [(texts[pos], meta[pos]) for pos in hit_rows[0].tolist() if 0 <= pos < n_docs]

def generate_answer(question):
    """Answer ``question`` using retrieved snippets as context.

    Retrieves supporting snippets, renders them into the prompt template,
    runs greedy decoding, and returns only the newly generated text followed
    by a "References" line listing the retrieved document titles.

    Args:
        question: The user's question as plain text.

    Returns:
        The model's answer (without the prompt) plus a trailing
        ``References: [1] title, [2] title, ...`` line.
    """
    srcs = retrieve(question)
    prompt = TEMPLATE.render(
        system=SYSTEM, q=question,
        sources=list(enumerate(s for s, _ in srcs)),
    )
    toks = tok(prompt, return_tensors='pt').to(device)
    prompt_len = toks['input_ids'].shape[1]
    # Greedy decoding: `temperature` only applies when sampling, so it is not
    # passed here. pad_token_id is given explicitly so generate() does not have
    # to guess one.
    out = mdl.generate(
        **toks, max_new_tokens=300, do_sample=False,
        pad_token_id=tok.pad_token_id,
    )
    # generate() returns prompt + continuation for causal LMs; decode only the
    # continuation so the system text and context are not echoed to the user.
    text = tok.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
    # Append simple refs, numbered like the snippets in the prompt.
    refs = '\n\nReferences: ' + ', '.join(f"[{i+1}] {m['title']}" for i, (_, m) in enumerate(srcs))
    return text + refs

## 🚀 Launch Gradio

In [None]:
import gradio as gr
def chat_fn(msg, history):
    """Gradio handler: answer ``msg`` and append the turn to ``history``.

    Args:
        msg: Text from the input box (may be empty or None).
        history: Existing chat history as a list of ``[user, bot]`` pairs,
            or None when the chat is empty.

    Returns:
        ``(updated_history, '')``; the empty string clears the input box.
        Blank messages leave the history unchanged and call no model.
        Errors raised while answering are shown in the chat instead of
        propagating, so the UI stays usable.
    """
    history = history or []
    # Skip blank submissions: they would only trigger a pointless retrieval
    # and generation pass and add an empty turn to the chat.
    if not (msg or '').strip():
        return history, ''
    try:
        ans = generate_answer(msg)
    except Exception as e:
        ans = f"Error: {e}\nTry a simpler question."
    history = history + [[msg, ans + '\n\n⚠️ Not medical advice.']]
    return history, ''

# Build the chat UI: a header, the chat transcript, an input box and two buttons.
with gr.Blocks(title='Medical LLM Chatbot (Demo)') as demo:
    gr.Markdown('# Medical LLM Chatbot — Demo\n*Educational use only*')
    chat = gr.Chatbot(height=400)
    msg = gr.Textbox(label='Ask a medical question', lines=2)
    send = gr.Button('Send')
    clear = gr.Button('Clear')

    # The Send button and the textbox submit event share the same handler:
    # chat_fn receives (message, history) and returns (history, cleared box).
    for register in (send.click, msg.submit):
        register(fn=chat_fn, inputs=[msg, chat], outputs=[chat, msg])

    # Clear replaces the chat history with an empty list.
    clear.click(fn=lambda: [], inputs=None, outputs=chat)

demo.launch(share=True)