## 🧠 Brain Tumor AI Reporting System in Google Colab

This notebook implements a two-phase medical reporting system:
1.  **Phase 1:** A specialist model (`google/med-gemma`) analyzes raw medical data to produce a technical JSON output.
2.  **Phase 2:** An orchestrator model (`llama3-70b-8192` via Groq API) translates the technical data into a user-friendly report and then enters a conversational Q&A loop, using a RAG system powered by a Pinecone vector database.
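The hand-off between the two models can be sketched with plain functions. In this minimal, hypothetical sketch, `fake_specialist` and `fake_orchestrator` stand in for Med-Gemma and Llama 3. No model is called. Only the data flow (raw data, then a technical JSON dict, then report text) mirrors the notebook.

```python
import json
from typing import Callable


def two_phase_pipeline(
    medical_data: dict,
    specialist: Callable[[str], dict],
    orchestrator: Callable[[str, str], str],
) -> str:
    """Specialist turns raw data into a technical dict; orchestrator turns that into report text."""
    data_str = json.dumps(medical_data, indent=2)
    raw_analysis = specialist(data_str)  # technical JSON, already parsed into a dict
    return orchestrator(json.dumps(raw_analysis), data_str)  # user-facing text


# Hypothetical stand-ins for the two models
def fake_specialist(data: str) -> dict:
    count = len(json.loads(data)["findings"])
    return {"summary": f"Stub analysis of {count} finding(s)."}


def fake_orchestrator(raw: str, data: str) -> str:
    return "Report: " + json.loads(raw)["summary"]


report = two_phase_pipeline({"findings": [{"finding_type": "Mass"}]}, fake_specialist, fake_orchestrator)
print(report)  # Report: Stub analysis of 1 finding(s).
```

Keeping each model behind a function boundary like this is what lets the notebook swap a local model for an API model in either role.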

### Step 1: Install Required Libraries

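```python
!pip install -qU transformers accelerate bitsandbytes torch
# The Pinecone SDK is now published as `pinecone` (the older `pinecone-client` package is deprecated)
!pip install -qU langchain langchain-groq langchain-community langchain-pinecone python-dotenv pinecone sentence-transformers
```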

### Step 2: Set Up API Keys and Hugging Face Token

To run this notebook, you need three secret keys:
1.  `GROQ_API_KEY`: From [GroqCloud](https://console.groq.com/keys)
2.  `PINECONE_API_KEY`: From [Pinecone](https://www.pinecone.io/)
3.  `HF_TOKEN`: A Hugging Face token from [Hugging Face](https://huggingface.co/settings/tokens). Gemma models are gated, so the account that owns the token must also accept the model's license terms on its model page.

**Instructions:**
1. Click the '🔑' (key) icon in the left sidebar of Colab.
2. Click '+ Add new secret' for each of the three keys listed above.
3. Enable the 'Notebook access' toggle for each key.

```python
import os
from google.colab import userdata

# Set environment variables from Colab secrets
os.environ["GROQ_API_KEY"] = userdata.get('GROQ_API_KEY')
os.environ["PINECONE_API_KEY"] = userdata.get('PINECONE_API_KEY')

# Log in to the Hugging Face Hub so gated Gemma models can be downloaded
from huggingface_hub import login
login(token=userdata.get('HF_TOKEN'))
```
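A missing or misspelled secret usually surfaces much later as an authentication error deep inside a LangChain call. A small fail-fast check run right after the cell above makes the problem obvious. This is a minimal sketch; `missing_env_vars` is a hypothetical helper, and it checks only the two keys stored as environment variables (the Hugging Face token is consumed by `login()` instead).

```python
import os


def missing_env_vars(names):
    """Return the names from `names` that are unset or empty in os.environ."""
    return [name for name in names if not os.environ.get(name)]


# Hypothetical usage: run before any model or API call
required = ["GROQ_API_KEY", "PINECONE_API_KEY"]
missing = missing_env_vars(required)
if missing:
    print("Missing secrets:", ", ".join(missing))
else:
    print("All required API keys are set.")
```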

### Step 3: Imports and Initializations

```python
import json
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

from langchain_groq import ChatGroq
from langchain_core.prompts import PromptTemplate
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_core.runnables import RunnablePassthrough, RunnableLambda
from langchain.memory import ConversationBufferMemory

from pinecone import Pinecone, ServerlessSpec
from langchain_community.embeddings import HuggingFaceEmbeddings
try:
    # Preferred: the dedicated integration package (`pip install langchain-pinecone`)
    from langchain_pinecone import PineconeVectorStore as LangchainPinecone
except ImportError:
    # Fallback: the older community wrapper, deprecated in recent LangChain releases
    from langchain_community.vectorstores import Pinecone as LangchainPinecone

print("Libraries imported successfully.")
```

### Step 4: Set Up Pinecone Vector Database (RAG System)

This cell creates a serverless vector index in your Pinecone account and populates it with simulated PubMed abstracts about brain tumors. Setup only needs to succeed once: because `setup_pinecone_index()` first checks whether the index already exists, re-running the cell skips creation and upserting. After the first successful run, you can also comment out the `setup_pinecone_index()` call at the end of the cell.

```python
import time

PINECONE_INDEX_NAME = "brain-tumor-pubmed"

# Sample documents simulating PubMed abstracts
SAMPLE_DOCUMENTS = [
    "High-grade gliomas, such as glioblastoma, are highly aggressive brain tumors. On contrast-enhanced MRI, they classically present as irregularly shaped, ring-enhancing lesions due to a necrotic core and a hypervascular, leaky rim. The surrounding vasogenic edema is a result of the disruption of the blood-brain barrier by tumor-secreted factors.",
    "The standard of care for a suspected high-grade glioma, known as the Stupp protocol, often involves maximal safe surgical resection, followed by concurrent radiation therapy and chemotherapy with the alkylating agent temozolomide (TMZ).",
    "A differential diagnosis for a ring-enhancing lesion in the brain is critical. It primarily includes high-grade glioma (glioblastoma), brain metastasis from a primary cancer elsewhere (e.g., lung, breast, melanoma), and a pyogenic brain abscess. Advanced imaging techniques like MR Spectroscopy can help differentiate these.",
    "Mass effect is a crucial radiological sign, referring to the secondary effects of a large lesion, such as the displacement of normal brain structures. A midline shift greater than 5mm is often associated with increased intracranial pressure and can be a neurosurgical emergency.",
    "Stereotactic biopsy is a minimally invasive neurosurgical procedure used to obtain a tissue sample from a brain lesion for definitive histopathological diagnosis. This is often performed when a lesion is in an eloquent or deep-seated area where a full resection carries high risk.",
    "Magnetic Resonance Spectroscopy (MRS) is a non-invasive imaging technique that analyzes the chemical composition of brain tissue. In brain tumors, an elevated choline peak and a reduced N-acetylaspartate (NAA) peak are indicative of high cellular turnover and neuronal loss, respectively, suggesting malignancy."
]

def get_embeddings_model():
    return HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")

def get_pinecone_client():
    return Pinecone(api_key=os.environ["PINECONE_API_KEY"])

def setup_pinecone_index():
    pc = get_pinecone_client()
    embeddings = get_embeddings_model()
    # Embed a probe string to read the model's output dimension (384 for all-MiniLM-L6-v2)
    embedding_dimension = len(embeddings.embed_query("dimension probe"))

    if PINECONE_INDEX_NAME not in pc.list_indexes().names():
        print(f"Creating Pinecone index: {PINECONE_INDEX_NAME}...")
        pc.create_index(
            name=PINECONE_INDEX_NAME,
            dimension=embedding_dimension,
            metric="cosine",
            spec=ServerlessSpec(cloud='aws', region='us-east-1')
        )
        # A new serverless index takes a moment to initialize; wait until it reports ready
        while not pc.describe_index(PINECONE_INDEX_NAME).status["ready"]:
            time.sleep(1)
        print("Index created. Upserting documents...")
        LangchainPinecone.from_texts(texts=SAMPLE_DOCUMENTS, embedding=embeddings, index_name=PINECONE_INDEX_NAME)
        print("Documents upserted successfully.")
    else:
        print(f"Index '{PINECONE_INDEX_NAME}' already exists.")

def get_retriever():
    embeddings = get_embeddings_model()
    vectorstore = LangchainPinecone.from_existing_index(index_name=PINECONE_INDEX_NAME, embedding=embeddings)
    # Return the 2 most similar abstracts for each query
    return vectorstore.as_retriever(search_kwargs={"k": 2})

# --- Run the one-time setup ---
setup_pinecone_index()
```
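The retriever returned by `get_retriever()` asks Pinecone for the `k=2` stored abstracts whose embeddings are most similar to the question's embedding under the index's cosine metric. The brute-force sketch below illustrates that ranking on made-up 3-dimensional vectors (real `all-MiniLM-L6-v2` embeddings have 384 dimensions). It is a conceptual stand-in, not how Pinecone searches internally, and the document labels are hypothetical.

```python
import math


def cosine_similarity(a, b):
    """Cosine of the angle between vectors a and b (1.0 means same direction)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def top_k(query, index, k=2):
    """Brute-force search: return the k (doc_id, score) pairs most similar to query."""
    scored = [(doc_id, cosine_similarity(query, vec)) for doc_id, vec in index.items()]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]


# Toy "embeddings" with hypothetical values
toy_index = {
    "glioma_imaging": [0.9, 0.1, 0.0],
    "stupp_protocol": [0.1, 0.9, 0.1],
    "mass_effect": [0.7, 0.0, 0.7],
}
query = [1.0, 0.0, 0.2]
for doc_id, score in top_k(query, toy_index, k=2):
    print(f"{doc_id}: {score:.3f}")
```

Cosine similarity ignores vector length and compares only direction, which suits sentence embeddings whose magnitude carries little meaning.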

### Step 5: Load Specialist LLM (Med-Gemma)

Here, we download the `google/med-gemma` model from Hugging Face. We use 4-bit quantization (`BitsAndBytesConfig`) to shrink the model's weights so that it can fit in the memory of the T4 GPU available on the free Colab tier.
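The saving from quantization can be estimated with simple arithmetic: weight memory is roughly the parameter count times the bits per parameter, divided by 8 to get bytes. The sketch below assumes a hypothetical 4-billion-parameter model (this notebook does not state Med-Gemma's size) and counts weights only, ignoring activations, the KV cache, and the small per-block overhead that NF4 quantization adds.

```python
def weight_memory_gb(num_params, bits_per_param):
    """Approximate memory for model weights alone, in GB (1 GB = 1e9 bytes)."""
    return num_params * bits_per_param / 8 / 1e9


# Hypothetical 4-billion-parameter model
params = 4e9
for bits in (16, 4):
    print(f"{bits:>2}-bit: {weight_memory_gb(params, bits):.1f} GB")
# 16-bit: 8.0 GB
#  4-bit: 2.0 GB
```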

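```python
from transformers import pipeline

# Repository ID as given in this notebook. MedGemma checkpoints are published under
# versioned IDs (e.g. `google/medgemma-4b-it`); confirm the exact ID on the Hugging Face Hub.
model_id = "google/med-gemma"

quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    # T4 GPUs have no native bfloat16 support, so compute in float16
    bnb_4bit_compute_dtype=torch.float16,
)

tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=quantization_config,
    device_map="auto",
)

# Wrap the quantized model in a Hugging Face text-generation pipeline
med_gemma_pipeline = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=1024,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    return_full_text=False,  # return only the completion, so JsonOutputParser never sees the prompt
)

# Create a LangChain-compatible LLM from the local pipeline
specialist_llm = HuggingFacePipeline(pipeline=med_gemma_pipeline)
```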
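`JsonOutputParser` may fail if the specialist model surrounds its JSON object with extra prose despite the prompt's instructions. A more forgiving fallback is to scan the text for the first decodable JSON object. The helper below is a hypothetical sketch, not part of LangChain, built on the standard library's `json.JSONDecoder.raw_decode`; it could be wrapped in a `RunnableLambda` in place of the parser.

```python
import json


def extract_first_json_object(text):
    """Return the first JSON object (dict) embedded in text, or raise ValueError."""
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)  # parse starting at this brace
            if isinstance(obj, dict):
                return obj
        except json.JSONDecodeError:
            pass  # not valid JSON from here; try the next opening brace
        start = text.find("{", start + 1)
    raise ValueError("No JSON object found in model output")


sample = 'Here is the analysis:\n{"summary": "Ring-enhancing mass", "recommendations": []}\nDone.'
print(extract_first_json_object(sample)["summary"])  # Ring-enhancing mass
```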

### Step 6: Define LLMs and Prompts

```python
# Orchestrator LLM (via the Groq API). Groq retires models periodically: if
# `llama3-70b-8192` is rejected, substitute a current model ID from the GroqCloud console.
orchestrator_llm = ChatGroq(
    temperature=0.4,
    model_name="llama3-70b-8192",
    max_tokens=2048,
)

# --- PROMPTS ---
TECHNICAL_PROMPT_TEMPLATE = PromptTemplate.from_template(
    """
    You are a neuroradiology expert AI. Your task is to analyze the provided Brain MRI findings and convert them into a structured, raw JSON object. Focus exclusively on technical accuracy. Do not add conversational text or explanations. Your entire output must be only the JSON object.

    MRI Data:
    {medical_data}

    Required JSON Output Format:
    {{
      "summary": "One-sentence technical summary of the primary finding.",
      "detailed_analysis": {{
        "lesion_type": "Categorize the lesion based on its features (e.g., 'Ring-enhancing intra-axial mass').",
        "location_specifics": "Provide a detailed anatomical location.",
        "key_features": ["List key radiological features (e.g., 'Central necrosis', 'Significant vasogenic edema', 'Irregular enhancement')."],
        "mass_effect_details": "Describe the mass effect observed (e.g., 'Sulcal effacement with 4mm rightward midline shift')."
      }},
      "differential_diagnosis": ["List the most likely differential diagnoses in order of probability (e.g., 'High-grade glioma (e.g., glioblastoma)', 'Metastasis', 'Abscess')."],
      "recommendations": ["List technical next steps (e.g., 'Neurosurgical consultation for resection or biopsy', 'Advanced imaging such as MRS or Perfusion MRI')."]
    }}
    """
)

REPORT_GENERATION_TEMPLATE = PromptTemplate.from_template(
    """
    You are a compassionate medical communicator specializing in neurology. Your task is to translate a raw, technical neuroradiology analysis into a clear, well-structured report for a {target_audience}. Do not add any medical information not present in the technical analysis. Use the original patient data for context. Generate the final report using Markdown formatting.

    **Raw Technical Analysis:**
    {raw_analysis}

    **Original Patient Data:**
    {medical_data}
    """
)

QNA_RAG_TEMPLATE = PromptTemplate.from_template(
    """
    You are an AI assistant helping a user understand a brain tumor medical report. Use the provided chat history and the retrieved context from medical literature to answer the user's question. Frame your answer based on the retrieved context. If the context is relevant, relate it back to the specifics of the patient's case. If the context does not help, say that you cannot find information on that topic.

    **Chat History:**
    {chat_history}

    **Retrieved Context from Medical Literature:**
    {context}

    **User Question:**
    {question}

    **Answer:**
    """
)

print("LLMs and Prompts are ready.")
```
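`PromptTemplate.from_template` uses f-string-style formatting by default: single-brace names such as `{medical_data}` are input variables, and doubled braces `{{` and `}}` render as literal braces. That is why the JSON skeleton in `TECHNICAL_PROMPT_TEMPLATE` is written with doubled braces. Python's `str.format` follows the same escaping rule, so it is a quick way to preview how a fragment of the template renders; the fragment below is a shortened, illustrative excerpt.

```python
# Same escaping rule as an f-string-style PromptTemplate: {{ and }} become literal braces
template = 'MRI Data:\n{medical_data}\n\nRequired JSON Output Format:\n{{\n  "summary": "..."\n}}'
rendered = template.format(medical_data='{"patient_id": "P-78901"}')
print(rendered)
```

Note that braces inside the substituted value (the JSON-encoded patient data) are inserted verbatim and are not re-parsed as variables.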

### Step 7: Define Input Data and Run Phase 1 (Report Generation)

```python
medical_data = {
  "patient_id": "P-78901",
  "imaging_type": "Brain MRI with Contrast",
  "findings": [
    {
      "finding_type": "Mass",
      "location": "Left Frontal Lobe",
      "size_mm": [35, 30, 42],
      "characteristics": "Irregularly shaped, ring-enhancing lesion with central necrosis.",
      "associated_edema": "Significant vasogenic edema in the surrounding white matter.",
      "mass_effect": "Effacement of the adjacent cortical sulci and mild midline shift of 4mm to the right.",
      "confidence_score": 0.95
    }
  ],
  "image_metadata": {
    "sequence": "T1-weighted post-contrast",
    "slice_thickness_mm": 5
  }
}

def run_phase_1_report_generation(medical_data: dict):
    print("--- 🧠 Phase 1: Initial Report Generation ---")
    medical_data_str = json.dumps(medical_data, indent=2)

    specialist_chain = TECHNICAL_PROMPT_TEMPLATE | specialist_llm | JsonOutputParser()
    print("> Invoking Med-Gemma specialist model...")
    raw_analysis = specialist_chain.invoke({"medical_data": medical_data_str})
    print("< Specialist analysis received.")

    report_generation_chain = REPORT_GENERATION_TEMPLATE | orchestrator_llm | StrOutputParser()
    print("> Invoking Llama-3 orchestrator model for report generation...")
    final_report = report_generation_chain.invoke({
        "target_audience": "patient",
        "raw_analysis": json.dumps(raw_analysis, indent=2),
        "medical_data": medical_data_str
    })
    print("< Final report generated.")

    print("\n" + "="*60)
    print("        Final Medical Report")
    print("="*60)
    print(final_report)
    print("="*60 + "\n")
    return final_report

# Execute Phase 1
final_report = run_phase_1_report_generation(medical_data)
```
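`run_phase_1_report_generation` passes the specialist's JSON straight to the orchestrator. A cheap guard is to confirm that the parsed object contains the keys the technical prompt asks for before generating the patient report, for example by calling the check on `raw_analysis` right after the specialist chain returns. This is a minimal sketch: `REQUIRED_KEYS` and `DETAIL_KEYS` mirror the prompt's JSON format, and `missing_analysis_keys` is a hypothetical helper, not part of the notebook.

```python
REQUIRED_KEYS = {"summary", "detailed_analysis", "differential_diagnosis", "recommendations"}
DETAIL_KEYS = {"lesion_type", "location_specifics", "key_features", "mass_effect_details"}


def missing_analysis_keys(raw_analysis):
    """Return a list of expected keys absent from the specialist's parsed JSON."""
    missing = sorted(REQUIRED_KEYS - set(raw_analysis))
    details = raw_analysis.get("detailed_analysis")
    if isinstance(details, dict):
        # Report nested gaps with a dotted path, e.g. "detailed_analysis.key_features"
        missing += sorted("detailed_analysis." + k for k in DETAIL_KEYS - set(details))
    return missing


example = {
    "summary": "Ring-enhancing left frontal mass.",
    "detailed_analysis": {"lesion_type": "Ring-enhancing intra-axial mass"},
    "differential_diagnosis": [],
}
print(missing_analysis_keys(example))
```

An empty list means the structure matches; a non-empty one is a signal to re-run the specialist or stop before the orchestrator paraphrases an incomplete analysis.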

### Step 8: Run Phase 2 (Conversational Q&A)

```python
def format_chat_history(chat_history):
    if not chat_history:
        return "No history yet."
    return "\n".join([f"{msg.type.capitalize()}: {msg.content}" for msg in chat_history])

def format_docs(docs):
    # Join the retrieved Document objects into plain text for the prompt
    return "\n\n".join(doc.page_content for doc in docs)

def run_phase_2_conversational_qna(initial_report: str):
    print("--- 💬 Phase 2: Conversational Q&A ---")
    print("The system is now ready for your follow-up questions.")
    print("Type 'exit' to end the session.\n")

    retriever = get_retriever()
    memory = ConversationBufferMemory(return_messages=True, memory_key="chat_history")
    # Seed the memory with the Phase 1 report so every answer can refer to it
    memory.save_context(
        {"input": "Here is the initial report for context."},
        {"output": initial_report}
    )

    rag_chain = (
        {
            # Retrieve the top-k abstracts for the question and flatten them to text
            "context": retriever | RunnableLambda(format_docs),
            "question": RunnablePassthrough(),
            # Read the memory at call time so each turn sees the latest history
            "chat_history": RunnableLambda(lambda _: format_chat_history(memory.load_memory_variables({})["chat_history"]))
        }
        | QNA_RAG_TEMPLATE
        | orchestrator_llm
        | StrOutputParser()
    )

    while True:
        try:
            question = input("Your Question: ")
            if question.lower().strip() == 'exit':
                print("\nSession ended. Goodbye! 👋")
                break

            print("\n> Searching knowledge base and generating response...")
            answer = rag_chain.invoke(question)
            print(f"\nAssistant: {answer}\n")
            memory.save_context({"input": question}, {"output": answer})

        except (KeyboardInterrupt, EOFError):
            print("\nSession ended by user. Goodbye! 👋")
            break

# Execute Phase 2 if Phase 1 was successful
if final_report:
    run_phase_2_conversational_qna(final_report)
```
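`ConversationBufferMemory` keeps every turn, and the seeded report alone can be long, so over an extended session the Q&A prompt may approach the orchestrator's 8,192-token context window (the limit encoded in the name `llama3-70b-8192`). One simple mitigation is to always keep the seeded report and only the most recent messages. The sketch below uses plain `(role, content)` tuples as a hypothetical stand-in for LangChain message objects; `window_history`, `format_history`, and `max_recent` are illustrative names, and the same slicing could be applied inside the `chat_history` lambda.

```python
def window_history(messages, max_recent=6):
    """Keep the first two messages (the seeded report) plus the last `max_recent` messages."""
    head, tail = messages[:2], messages[2:]
    return head + tail[-max_recent:] if max_recent > 0 else head


def format_history(messages):
    """Render (role, content) pairs the same way format_chat_history renders messages."""
    return "\n".join(f"{role.capitalize()}: {content}" for role, content in messages)


history = [("human", "Here is the initial report for context."), ("ai", "<report>")]
for i in range(5):
    history += [("human", f"question {i}"), ("ai", f"answer {i}")]

trimmed = window_history(history, max_recent=4)
print(format_history(trimmed))
```

Counting messages is only a proxy for tokens; a stricter version would measure prompt length with the model's tokenizer before trimming.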