# 🏥 Advanced Medical AI Assistant

This notebook contains the complete pipeline for a Medical AI Assistant using **Gemma-2-9b-it** (or similar) with **RAG (Retrieval Augmented Generation)** and **Fine-Tuning** capabilities.

## Instructions
1. **Runtime**: Make sure you are using a GPU Runtime (Runtime > Change runtime type > T4 GPU or A100/L4 if available).
2. **HuggingFace Login**: You may need to log in to Hugging Face to access Gemma-2 models (`!huggingface-cli login`).
3. **Run All**: You can run all cells. The training step is optional and can be skipped if you just want to use the RAG system with the base model.
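
To see why the later cells load the model in 4-bit, here is a rough back-of-the-envelope sketch of weight memory. It assumes about 9 billion parameters (a round figure read off the "9b" in the model name, not a measured count) and ignores activations, the KV cache, and quantization overhead, so treat the numbers as lower bounds. A T4 has 16 GB of memory.

```python
# Rough weight-memory estimate. The parameter count is an assumed round figure.
ASSUMED_PARAMS = 9e9  # assumption: "9b" in the model name taken as ~9 billion


def weight_memory_gb(num_params: float, bits_per_param: float) -> float:
    """Approximate storage for the weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


fp16_gb = weight_memory_gb(ASSUMED_PARAMS, 16)  # half-precision weights
nf4_gb = weight_memory_gb(ASSUMED_PARAMS, 4)    # 4-bit quantized weights

print(f"fp16 weights:  ~{fp16_gb:.1f} GB")
print(f"4-bit weights: ~{nf4_gb:.1f} GB")
```

Under these assumptions the half-precision weights alone would exceed a T4's memory, while the 4-bit weights leave room for activations and the embedding model.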

In [15]:
# --- Step 1: Install Dependencies ---
# Note: Google Colab ships pre-installed packages that can conflict with newer libraries.
# You will likely see red text starting with "pip's dependency resolver...".
# This is expected. As long as the install completes, the conflict messages about 'requests' or 'opentelemetry' can be ignored.

!pip install -q requests==2.32.4 torch transformers peft bitsandbytes trl accelerate datasets langchain langchain-community langchain-huggingface chromadb sentence-transformers gradio tiktoken pypdf scipy numpy huggingface_hub

# Quick check to ensure core libraries loaded despite warning
try:
    import requests
    import transformers
    print("✅ Core libraries installed successfully (Components are ready).")
except ImportError as e:
    print(f"❌ Critical library missing: {e}")

✅ Core libraries installed successfully (Components are ready).


In [18]:
import os
import gc
import torch
import gradio as gr
from datasets import load_dataset, load_from_disk, Dataset
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig,
    TrainingArguments, pipeline
)
from peft import LoraConfig, prepare_model_for_kbit_training, get_peft_model, PeftModel
from trl import SFTTrainer
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from langchain_huggingface import HuggingFacePipeline
from huggingface_hub import login

# Global Configuration
# Use 'google/gemma-2-9b-it' if you have access, or 'unsloth/llama-3-8b-bnb-4bit' for faster inference
BASE_MODEL_NAME = "google/gemma-2-9b-it"
ADAPTER_NAME = "medical_assistant_adapter"
DATASET_PATH = "processed_medical_data"
CHROMA_DB_DIR = "./chroma_db"

print("Libraries loaded. Ready for authentication.")

Libraries loaded. Ready for authentication.


In [24]:
from huggingface_hub import login
login()



In [25]:
!huggingface-cli whoami


Kovidkodi007


In [29]:
class MedicalRAG:
    def __init__(self, persist_dir=CHROMA_DB_DIR):
        self.persist_dir = persist_dir
        # Lightweight embeddings model suitable for CPU/Colab
        self.embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        self.vectordb = None

    def ingest_documents(self, file_paths):
        """Ingests medical documents (PDFs or Text) into the vector store."""
        docs = []
        for path in file_paths:
            print(f"Loading {path}...")
            try:
                # Compare extensions case-insensitively so ".PDF" and ".TXT" are accepted
                if path.lower().endswith(".pdf"):
                    loader = PyPDFLoader(path)
                    docs.extend(loader.load())
                elif path.lower().endswith(".txt"):
                    loader = TextLoader(path)
                    docs.extend(loader.load())
                else:
                    print(f"Skipping unsupported file type: {path}")
            except Exception as e:
                print(f"Error loading {path}: {e}")

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)

        if not splits:
            print("No text found in documents.")
            return

        print(f"Creating vector store with {len(splits)} chunks...")
        # Clean up existing DB to start fresh for demo purposes
        if os.path.exists(self.persist_dir):
            import shutil
            shutil.rmtree(self.persist_dir)

        self.vectordb = Chroma.from_documents(
            documents=splits,
            embedding=self.embedding_function,
            persist_directory=self.persist_dir
        )
        # Chroma 0.4+ persists automatically when persist_directory is set,
        # so no explicit self.vectordb.persist() call is needed.
        print("Vector store created and saved.")

    def load_vector_store(self):
        if os.path.exists(self.persist_dir):
            self.vectordb = Chroma(persist_directory=self.persist_dir, embedding_function=self.embedding_function)
            print("Loaded existing vector store.")
            return True
        return False

    def setup_rag_pipeline(self, model, tokenizer):
        """Sets up the RAG chain using the loaded model."""
        # Use vector DB if available, else plain LLM
        if not self.vectordb:
            self.load_vector_store()

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            do_sample=True,  # temperature and top_p only take effect when sampling is enabled
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )

        llm = HuggingFacePipeline(pipeline=pipe)

        # Medical-focused prompt in Gemma's turn format (adjust if you change BASE_MODEL_NAME).
        # Gemma-2 has no separate system role, so the instructions go in the user turn.
        template = """<start_of_turn>user
You are an advanced medical assistant. Use the following pieces of context to answer the user's question.
If the answer is not in the context, say so. You may still answer from general medical knowledge, but warn that the answer is not from your verified sources.
Always prioritize patient safety. If the user describes a medical emergency, advise them to call emergency services immediately.
Context: {context}

Question: {question}<end_of_turn>
<start_of_turn>model
"""
        QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

        if self.vectordb:
            retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
            )
            return qa_chain
        else:
            print("No knowledge base found. Running in pure LLM mode.")
            return None
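
To make the `chunk_size=1000, chunk_overlap=200` setting concrete, here is a simplified sketch of overlapping chunking. It is a plain sliding window over characters, written for illustration only; `sliding_chunks` is a hypothetical helper, and `RecursiveCharacterTextSplitter` additionally tries to break on separators such as paragraphs and sentences, so its real chunk boundaries will differ.

```python
def sliding_chunks(text: str, chunk_size: int, chunk_overlap: int) -> list[str]:
    """Split text into fixed-size windows where consecutive windows share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap  # how far each new window advances
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # this window already reached the end of the text
    return chunks


# 2,500 characters of dummy text, split with the same settings as the notebook
sample = "".join(str(i % 10) for i in range(2500))
chunks = sliding_chunks(sample, chunk_size=1000, chunk_overlap=200)

print([len(c) for c in chunks])
print(chunks[0][-200:] == chunks[1][:200])  # the tail of one chunk repeats at the head of the next
```

The overlap means a sentence that straddles a boundary still appears whole in at least one chunk, at the cost of storing some text twice.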

In [30]:
# --- Step 6: Model Training (Optional) ---
# Set TRAIN_MODEL = True to run fine-tuning (~10-20 min on T4)
TRAIN_MODEL = False

def train_model():
    # Clean memory before training
    torch.cuda.empty_cache()
    gc.collect()

    print(f"Loading base model for training: {BASE_MODEL_NAME}")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=False,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        quantization_config=bnb_config,
        device_map="auto"
    )
    model.config.use_cache = False
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
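    # Only fall back to EOS for padding if the tokenizer does not define its own pad token
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token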

    model = prepare_model_for_kbit_training(model)
    peft_config = LoraConfig(
        lora_alpha=16, lora_dropout=0.1, r=64, bias="none",
        task_type="CAUSAL_LM",
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj"]
    )
    model = get_peft_model(model, peft_config)

    dataset = load_from_disk(DATASET_PATH)

    training_args = TrainingArguments(
        output_dir="./results",
        num_train_epochs=1,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=2e-4,
        fp16=True,
        logging_steps=10,
        max_steps=50,  # Quick demo training; overrides num_train_epochs. Increase for better results.
        optim="paged_adamw_32bit"
    )

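    # The model is already wrapped by get_peft_model above, so peft_config is not passed again.
    # Note: `dataset_text_field` and `tokenizer` follow older TRL releases; newer releases
    # expect `dataset_text_field` in SFTConfig and `processing_class` instead of `tokenizer`.
    trainer = SFTTrainer(
        model=model,
        train_dataset=dataset["train"],
        dataset_text_field="text",
        tokenizer=tokenizer,
        args=training_args,
    )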

    trainer.train()
    trainer.model.save_pretrained(ADAPTER_NAME)
    print("Training Complete. Adapter Saved!")

    # Cleanup to free VRAM for inference
    del model
    del trainer
    torch.cuda.empty_cache()
    gc.collect()

if TRAIN_MODEL:
    train_model()
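
As a quick sanity check on the training settings, the arithmetic below works out how much data the demo run actually sees. It assumes a single GPU (`num_devices = 1`, which is typical for Colab but is an assumption here). Because a positive `max_steps` takes precedence over `num_train_epochs` in `TrainingArguments`, the run stops after 50 optimizer steps regardless of dataset size.

```python
# Values copied from the TrainingArguments in the cell above
per_device_train_batch_size = 4
gradient_accumulation_steps = 4
max_steps = 50
num_devices = 1  # assumption: a single Colab GPU

# Sequences contributing to each optimizer update
effective_batch = per_device_train_batch_size * gradient_accumulation_steps * num_devices

# Total training sequences consumed before max_steps stops the run
sequences_seen = effective_batch * max_steps

print(f"Effective batch size: {effective_batch}")
print(f"Sequences seen in {max_steps} steps: {sequences_seen}")
```

If the training split is much larger than this, most of it is never used in the demo configuration, which is worth keeping in mind when judging the adapter's quality.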

In [31]:
# --- Step 7: Application Interface ---

# Ensure memory is clean
torch.cuda.empty_cache()
gc.collect()

print("Loading Inference Model...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto"
)

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

# Try loading adapter
if os.path.exists(ADAPTER_NAME):
    print(f"Loading Fine-Tuned Adapter: {ADAPTER_NAME}")
    model = PeftModel.from_pretrained(base_model, ADAPTER_NAME)
else:
    print("Adapter not found. Using Base Model.")
    model = base_model

# Initialize RAG
rag_system = MedicalRAG()
qa_chain = rag_system.setup_rag_pipeline(model, tokenizer)

def process_query(message, history):
    # Safety check for keywords
    emergency_keywords = ["suicide", "kill myself", "chest pain", "dying", "stroke", "heart attack", "collapsed"]
    if any(k in message.lower() for k in emergency_keywords):
        return "⚠️ CRITICAL ALERT: It sounds like you may be experiencing a medical emergency. Please call emergency services (911 or your local equivalent) immediately."

    try:
        # Use RAG chain if available
        if qa_chain:
            response = qa_chain.invoke({"query": message})
            answer = response["result"]

            # Strip chat-template markers if the prompt leaks into the output
            for marker in ("<|im_start|>assistant", "<start_of_turn>model"):
                if marker in answer:
                    answer = answer.split(marker)[-1].strip()

            sources = response.get("source_documents", [])
            if sources:
                source_text = "\n\n---\n**📚 Sources Used:**\n" + "\n".join([f"- {d.metadata.get('source', 'Doc')}" for d in sources])
                return answer + source_text
            return answer
        else:
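            # Pure LLM fallback: apply the model's chat template, then decode only the new tokens
            messages = [{"role": "user", "content": message}]
            inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_dict=True, return_tensors="pt").to(model.device)
            outputs = model.generate(**inputs, max_new_tokens=250)
            prompt_len = inputs["input_ids"].shape[-1]
            return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)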

    except Exception as e:
        return f"An error occurred: {str(e)}"

def ingest_file(files):
    if not files:
        return "No files uploaded."
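    # With type="filepath", Gradio passes plain path strings; older versions pass file objects
    file_paths = [f if isinstance(f, str) else f.name for f in files]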
    rag_system.ingest_documents(file_paths)

    # Re-init chain with new DB
    global qa_chain
    qa_chain = rag_system.setup_rag_pipeline(model, tokenizer)
    return f"Successfully processed {len(files)} documents. You can now ask questions about them!"

# --- Gradio Layout ---
with gr.Blocks(theme=gr.themes.Soft(primary_hue="teal", neutral_hue="slate")) as demo:
    gr.Markdown("# 🏥 AI Medical Assistant")
    gr.Markdown("This assistant combines **Advanced LLMs** with **Medical Knowledge Retrieval (RAG)** to provide accurate information.")

    with gr.Row():
        with gr.Column(scale=3):
            chat_interface = gr.ChatInterface(
                process_query,
                description="Ask me about symptoms, medical concepts, or upload reports (PDF/TXT) for analysis.",
            )
        with gr.Column(scale=1):
            gr.Markdown("### 📂 Knowledge Base")
            gr.Markdown("Upload medical guides, reports, or papers here to let the AI answer based on them.")
            file_upload = gr.File(label="Upload Documents (PDF/TXT)", file_count="multiple", type="filepath")
            ingest_btn = gr.Button("Process Documents", variant="primary")
            ingest_output = gr.Textbox(label="Status", interactive=False)

    ingest_btn.click(ingest_file, inputs=[file_upload], outputs=[ingest_output])

print("Launching Application...")
demo.launch(share=True, debug=True)

Loading Inference Model...


Adapter not found. Using Base Model.


Device set to use cuda:0

No knowledge base found. Running in pure LLM mode.



Launching Application...
Colab notebook detected. This cell will run indefinitely so that you can see errors and logs. To turn off, set debug=False in launch().
* Running on public URL: https://f5da5a5402d69abccf.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)


Keyboard interruption in main thread... closing server.
Killing tunnel 127.0.0.1:7860 <> https://f5da5a5402d69abccf.gradio.live
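
The safety check in `process_query` uses plain substring matching, so a message containing "keystroke" or "studying" would also trigger the emergency alert. One possible refinement, sketched below, matches whole words and phrases with a regular expression. `is_emergency` is a hypothetical helper, not part of the notebook, and this remains a coarse keyword filter: it will still miss misspellings and paraphrases.

```python
import re

# Same keyword list as in process_query
EMERGENCY_KEYWORDS = [
    "suicide", "kill myself", "chest pain", "dying",
    "stroke", "heart attack", "collapsed",
]

# \b anchors each keyword at word boundaries so it cannot match inside a longer word
_EMERGENCY_PATTERN = re.compile(
    r"\b(?:" + "|".join(re.escape(k) for k in EMERGENCY_KEYWORDS) + r")\b",
    re.IGNORECASE,
)


def is_emergency(message: str) -> bool:
    """Return True if the message contains an emergency keyword as a whole word or phrase."""
    return _EMERGENCY_PATTERN.search(message) is not None


print(is_emergency("I think I'm having a stroke"))  # whole-word match
print(is_emergency("Each keystroke is logged"))     # "stroke" appears only inside another word
print(is_emergency("CHEST PAIN since this morning"))  # case-insensitive phrase match
```

Whether the stricter match is preferable is a judgment call: for a safety alert, false positives are usually cheaper than false negatives, so the substring version may still be the more cautious choice.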


