# 🏥 Advanced Medical AI Assistant (Fixed)

This notebook contains the **fixed** pipeline for running the Medical AI Assistant on Google Colab with GPU support.

## Instructions
1. **Runtime**: Ensure you are using a GPU Runtime (Runtime > Change runtime type > T4 GPU).
2. **Configuration**: Enter your HuggingFace and Ngrok tokens in **Cell 3**.
3. **Run All**: Run the cells sequentially.

In [None]:
# --- Step 1: Install Dependencies ---
# Installs every third-party package the notebook uses into the interpreter
# that is running this kernel (sys.executable), not whatever "pip" is on PATH.
import subprocess
import sys

print("Installing dependencies... (This may take a few minutes)")

# Only requests is pinned; every other package resolves to pip's default choice.
packages = [
    # HTTP / numerics
    "requests==2.32.3",
    "torch",
    # Model loading, adapters, quantisation, training helpers
    "transformers",
    "peft",
    "bitsandbytes",
    "trl",
    "accelerate",
    "datasets",
    # Retrieval stack
    "langchain",
    "langchain-community",
    "langchain-huggingface",
    "chromadb",
    "sentence-transformers",
    # UI, tokenisation, document parsing, misc
    "gradio",
    "tiktoken",
    "pypdf",
    "scipy",
    "numpy",
    "huggingface_hub",
    # API server and tunnelling
    "fastapi",
    "uvicorn",
    "pyngrok",
    "nest-asyncio",
    "python-multipart",
]

# "-q" keeps pip quiet so the cell output stays short. check_call raises
# CalledProcessError when pip exits with a non-zero status.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *packages])
print("‚úÖ Dependencies installed.")

In [None]:
# --- Step 2: Imports ---
import os
import gc
import torch
import asyncio
import uvicorn
import nest_asyncio
from pyngrok import ngrok
from fastapi import FastAPI, UploadFile, File, Form
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, pipeline
)
from peft import PeftModel
from langchain_community.document_loaders import PyPDFLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_huggingface import HuggingFaceEmbeddings, HuggingFacePipeline
from langchain_community.vectorstores import Chroma
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from huggingface_hub import login

# Apply nest_asyncio to allow uvicorn to run in Jupyter/Colab environment.
# This must run before the server start-up cell, which schedules the uvicorn
# server on an event loop that the notebook is expected to have running already.
nest_asyncio.apply()
print("‚úÖ Imports successful.")

In [None]:
# --- Step 3: Configuration & Auth ---
# WARNING: replace the two placeholder values below with your own tokens.
# The placeholders double as "not configured" markers: the checks further down
# (and the tunnel set-up in the server cell) compare against these exact strings.
HF_TOKEN = "YOUR_HUGGINGFACE_TOKEN_HERE"
NGROK_AUTH_TOKEN = "YOUR_NGROK_AUTH_TOKEN_HERE"

# Model Config
BASE_MODEL_NAME = "google/gemma-2-9b-it"      # model id loaded first by load_model()
ADAPTER_NAME = "medical_assistant_adapter"    # local path checked for a PEFT adapter
CHROMA_DB_DIR = "./chroma_db"                 # on-disk location of the vector store

# Hugging Face login: skipped (with a warning) while the placeholder is in place.
if HF_TOKEN == "YOUR_HUGGINGFACE_TOKEN_HERE":
    print("‚ö†Ô∏è WARNING: You haven't set your HuggingFace token. Some models might fail.")
else:
    login(token=HF_TOKEN)

# ngrok auth: likewise only registered once a real token has been supplied.
if NGROK_AUTH_TOKEN == "YOUR_NGROK_AUTH_TOKEN_HERE":
    print("‚ö†Ô∏è WARNING: Ngrok Token not set. Public URL will fail.")
else:
    ngrok.set_auth_token(NGROK_AUTH_TOKEN)

print("‚úÖ Configuration set.")

In [None]:
# --- Step 4: Logic / RAG Class ---
class MedicalRAG:
    """Retrieval-augmented QA helper backed by a persisted Chroma vector store.

    Typical use: ``ingest_documents`` builds the store from PDF/text files,
    then ``setup_rag_pipeline`` wires a retriever over that store to a
    text-generation model and returns a ``RetrievalQA`` chain.
    """

    def __init__(self, persist_dir=CHROMA_DB_DIR):
        """Create the embedding function; the vector store itself is loaded lazily.

        Args:
            persist_dir: Directory where the Chroma store is (or will be) persisted.
        """
        self.persist_dir = persist_dir
        # Lightweight embeddings model suitable for CPU/Colab
        self.embedding_function = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        # Set by ingest_documents() or load_vector_store(); None until then.
        self.vectordb = None

    def _load_file(self, path):
        """Return the documents loaded from ``path``, or ``[]`` if it cannot be loaded.

        The extension check is case-insensitive so ``report.PDF`` is treated
        like ``report.pdf``. Unsupported types and loader errors are reported
        on stdout and yield an empty list instead of aborting the whole batch.
        """
        lowered = path.lower()
        try:
            if lowered.endswith(".pdf"):
                return PyPDFLoader(path).load()
            if lowered.endswith(".txt"):
                return TextLoader(path).load()
        except Exception as e:
            # Best effort: one unreadable file must not stop the other files.
            print(f"Error loading {path}: {e}")
            return []

        print(f"Skipping {path}: unsupported file type (expected .pdf or .txt)")
        return []

    def ingest_documents(self, file_paths):
        """Ingests medical documents (PDFs or Text) into the vector store.

        Args:
            file_paths: Iterable of file paths. A single path given as a plain
                string is accepted too (it would otherwise be iterated
                character by character).

        Returns:
            None. On success ``self.vectordb`` points at a freshly built store;
            if no text could be extracted the existing store is left untouched.
        """
        import shutil

        if isinstance(file_paths, str):
            file_paths = [file_paths]

        docs = []
        for path in file_paths:
            print(f"Loading {path}...")
            docs.extend(self._load_file(path))

        text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
        splits = text_splitter.split_documents(docs)

        if not splits:
            print("No text found in documents.")
            return

        print(f"Creating vector store with {len(splits)} chunks...")
        # Clean up existing DB to start fresh for demo purposes.
        # NOTE(review): if self.vectordb is already open on this directory the
        # files are removed underneath it - confirm Chroma tolerates this.
        if os.path.exists(self.persist_dir):
            shutil.rmtree(self.persist_dir)

        self.vectordb = Chroma.from_documents(
            documents=splits,
            embedding=self.embedding_function,
            persist_directory=self.persist_dir
        )
        print("Vector store created and saved.")

    def load_vector_store(self):
        """Open the persisted store if its directory exists.

        Returns:
            True when ``persist_dir`` exists and the store was opened, else False.
        """
        if os.path.exists(self.persist_dir):
            self.vectordb = Chroma(persist_directory=self.persist_dir, embedding_function=self.embedding_function)
            print("Loaded existing vector store.")
            return True
        return False

    def setup_rag_pipeline(self, model, tokenizer):
        """Sets up the RAG chain using the loaded model.

        Args:
            model: Causal language model handed to the text-generation pipeline.
            tokenizer: Tokenizer matching ``model``.

        Returns:
            A ``RetrievalQA`` chain that returns source documents, or ``None``
            when no vector store is available (callers then use the bare model).
        """
        if not self.vectordb:
            self.load_vector_store()

        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            # temperature / top_p only take effect when sampling is enabled;
            # without do_sample=True generation is greedy and they are ignored.
            do_sample=True,
            temperature=0.7,
            top_p=0.95,
            repetition_penalty=1.15
        )

        llm = HuggingFacePipeline(pipeline=pipe)

        # NOTE(review): the markers below are ChatML-style; confirm they match
        # the chat format expected by the tokenizer of the model being loaded.
        template = """<|im_start|>system
You are an advanced medical assistant. Use the following pieces of context to answer the user's question.
If the answer is not in the context, say you don't know, but try to be helpful based on general medical knowledge.
Always prioritize patient safety.
Context: {context}<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
"""
        QA_CHAIN_PROMPT = PromptTemplate.from_template(template)

        if self.vectordb:
            # Retrieve the 3 most similar chunks and "stuff" them into {context}.
            retriever = self.vectordb.as_retriever(search_kwargs={"k": 3})
            qa_chain = RetrievalQA.from_chain_type(
                llm=llm,
                chain_type="stuff",
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs={"prompt": QA_CHAIN_PROMPT}
            )
            return qa_chain
        else:
            print("No knowledge base found. Running in pure LLM mode.")
            return None

print("‚úÖ MedicalRAG class defined.")

In [None]:
# --- Step 5: Model Loading ---
def _release_gpu_memory():
    """Drop unreachable Python objects, then return cached CUDA blocks to the driver."""
    # Order matters: tensors kept alive only by garbage must be collected
    # first, otherwise empty_cache() has nothing of theirs to release.
    gc.collect()
    torch.cuda.empty_cache()


def _load_quantized(model_name, bnb_config):
    """Load ``model_name`` with the given quantization config, plus its tokenizer."""
    quantized = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto"
    )
    return quantized, AutoTokenizer.from_pretrained(model_name)


def load_model():
    """Load the 4-bit base model (or a small fallback) and an optional adapter.

    ``BASE_MODEL_NAME`` is tried first. If loading the model or its tokenizer
    fails for any reason, the TinyLlama chat model is loaded instead. When a
    path named ``ADAPTER_NAME`` exists it is applied as a PEFT adapter; if the
    adapter cannot be applied (for example because the fallback model was
    loaded) the error is printed and the un-adapted model is returned.

    Returns:
        tuple: ``(model, tokenizer)``.

    Raises:
        Exception: whatever the fallback load raises, if it also fails.
    """
    _release_gpu_memory()

    print(f"Loading base model: {BASE_MODEL_NAME}")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
    )

    base_model = None
    tokenizer = None
    try:
        base_model, tokenizer = _load_quantized(BASE_MODEL_NAME, bnb_config)
    except Exception as e:
        print(f"Error loading model {BASE_MODEL_NAME}: {e}")
        print("Falling back to a smaller model for demo purposes...")

    if base_model is None:
        # Done outside the except block so the exception (and the frames it
        # references) is already gone and the failed attempt can be reclaimed.
        _release_gpu_memory()
        # Fallback to a smaller model if the big one fails or needs access
        fallback_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        base_model, tokenizer = _load_quantized(fallback_model, bnb_config)

    if not os.path.exists(ADAPTER_NAME):
        print("Using Base Model (No Adapter found).")
        return base_model, tokenizer

    print(f"Loading Adapter: {ADAPTER_NAME}")
    try:
        model = PeftModel.from_pretrained(base_model, ADAPTER_NAME)
    except Exception as e:
        # An adapter that does not fit the loaded model should not take the
        # whole notebook down; report it and continue without the adapter.
        print(f"Error loading adapter {ADAPTER_NAME}: {e}")
        print("Using Base Model (Adapter could not be applied).")
        model = base_model

    return model, tokenizer

# Initialize Global Components
# These module-level objects are shared by the API handlers in the server cell.
print("Initializing System...")
# Returns the (possibly adapter-wrapped) model and its tokenizer.
model, tokenizer = load_model()
# Uses the default persist directory (CHROMA_DB_DIR).
rag_system = MedicalRAG()
# None when no vector store is available yet; the /query handler then calls the
# model directly instead of going through the retrieval chain.
qa_chain = rag_system.setup_rag_pipeline(model, tokenizer)
print("‚úÖ System Initialized.")

In [None]:
# --- Step 6: Server Startup ---
# FastAPI application object; the route decorators below register handlers on it
# and run_server() hands it to uvicorn.
app = FastAPI(title="Medical RAG GPU API")

class QueryRequest(BaseModel):
    """JSON body accepted by ``POST /query``."""

    # The user's question; passed to the RAG chain or, without one, to the model.
    message: str

@app.get("/")
def home():
    """Health-check endpoint.

    Returns:
        dict: ``{"status": "online", "model": <name>}``. The name is taken from
        the loaded model so that it stays accurate when ``load_model`` fell back
        to a different checkpoint; ``BASE_MODEL_NAME`` is used when the model
        object does not expose a name.
    """
    loaded_name = getattr(model, "name_or_path", None) or BASE_MODEL_NAME
    return {"status": "online", "model": loaded_name}

@app.post("/query")
def query_model(req: QueryRequest):
    """Answer a question, through the RAG chain when one is available.

    Args:
        req: Request body carrying the user's ``message``.

    Returns:
        dict: In RAG mode ``{"answer", "source_documents"}`` where each source
        is the first 200 characters of a retrieved chunk; in pure LLM mode
        ``{"answer"}``; on any failure ``{"error": <message>}``.
    """
    try:
        if qa_chain:
            # RAG Mode
            res = qa_chain.invoke({"query": req.message})
            sources = [d.page_content[:200] for d in res.get("source_documents", [])]
            return {"answer": res["result"], "source_documents": sources}

        # Pure LLM Mode
        inputs = tokenizer(req.message, return_tensors="pt").to(model.device)
        outputs = model.generate(**inputs, max_new_tokens=256)
        # generate() returns prompt + continuation for causal models; decode only
        # the newly generated tokens so the answer does not echo the question.
        prompt_len = inputs["input_ids"].shape[-1]
        answer = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        return {"answer": answer}
    except Exception as e:
        # Same error shape for both modes so clients only handle one format.
        return {"error": str(e)}

@app.post("/ingest")
async def ingest_file(file: UploadFile = File(...)):
    """Save an uploaded document, index it, and rebuild the QA chain.

    Args:
        file: The uploaded PDF or text file.

    Returns:
        dict: ``{"message": ...}`` once ingestion has run, or ``{"error": ...}``
        when the upload carries no usable filename.
    """
    global qa_chain

    # The filename is client-controlled. Keep only its final path component so
    # a name such as "../../x" cannot write outside the working directory.
    safe_name = os.path.basename((file.filename or "").replace("\\", "/"))
    if safe_name in ("", ".", ".."):
        return {"error": "Uploaded file has no usable filename."}

    file_location = f"./{safe_name}"
    with open(file_location, "wb") as file_object:
        file_object.write(await file.read())

    rag_system.ingest_documents([file_location])

    # Reload the pipeline to include new data
    qa_chain = rag_system.setup_rag_pipeline(model, tokenizer)

    return {"message": f"Successfully ingested {safe_name}"}

# Connect Ngrok
# A tunnel to local port 8000 (the port run_server() binds) is only opened once
# the placeholder token has been replaced; otherwise just warn.
if NGROK_AUTH_TOKEN == "YOUR_NGROK_AUTH_TOKEN_HERE":
    print("‚ö†Ô∏è NGROK_AUTH_TOKEN not set. Remote access will not work.")
else:
    public_url = ngrok.connect(8000).public_url
    print(f"\nüöÄ PUBLIC API URL: {public_url}\n")
    print("üëâ Copy this URL for your local client.")

async def run_server():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000 until the server stops."""
    server = uvicorn.Server(uvicorn.Config(app, host="0.0.0.0", port=8000))
    await server.serve()

# Run the server
print("Starting Server...")
if __name__ == "__main__":
    # get_running_loop() raises RuntimeError when no loop is running, which
    # cleanly separates "inside a notebook" from "plain script".
    try:
        running_loop = asyncio.get_running_loop()
    except RuntimeError:
        running_loop = None

    if running_loop is not None:
        # Notebook case: schedule the server on the existing loop so the cell
        # returns immediately. The task is kept in a module-level name because
        # the event loop holds only a weak reference to it; an unreferenced
        # task may be garbage-collected before it finishes.
        server_task = running_loop.create_task(run_server())
    else:
        # Script case: no loop yet, so run the server to completion here.
        asyncio.run(run_server())