# 09 - MedGemma Integration (Kaggle)

Notebook para gerar um relatório clínico estruturado a partir dos outputs dos 3 modelos:
- Glaucoma (TransUNet)
- DR Grading (EfficientNet)
- Vascular (U-Net)

Inclui um fallback rule-based para quando o MedGemma não estiver disponível.

In [None]:
!pip -q install transformers accelerate sentencepiece

In [None]:
from dataclasses import dataclass
from typing import Optional, Tuple, Dict

@dataclass
class ScreeningResults:
    """Bundle of the three screening models' outputs for one retinal exam.

    Instances feed both the MedGemma prompt and the rule-based report.
    Construct with keyword arguments for clarity; positional order is the
    field order below.
    """

    # --- Glaucoma model (TransUNet) ---
    cdr: float  # cup-to-disc ratio estimate
    glaucoma_risk: str  # risk category label produced for the glaucoma finding
    # --- Diabetic retinopathy grading model (EfficientNet) ---
    dr_grade: int  # numeric DR grade
    dr_label: str  # human-readable name of that grade
    dr_conf: float  # model confidence; rendered as a percentage in reports
    # --- Vascular segmentation model (U-Net) ---
    vessel_density: float  # fraction of vessel pixels; rendered as a percentage

def build_prompt(r: ScreeningResults) -> str:
    """Render the instruction prompt that asks MedGemma for a structured report.

    The prompt has three blank-line-separated sections: the role/task
    statement, the formatted screening results, and the required report
    outline. It ends with a single trailing newline.

    Args:
        r: Screening results to embed (CDR to 3 decimals, confidence and
            vessel density as percentages with one decimal).

    Returns:
        The prompt text.
    """
    role_and_task = (
        "You are a clinical ophthalmology AI assistant. Based on the following "
        "automated retinal screening results, generate a structured clinical report."
    )
    results_section = "\n".join([
        "PATIENT SCREENING RESULTS:",
        f"- Glaucoma Assessment: CDR = {r.cdr:.3f} | Risk: {r.glaucoma_risk}",
        f"- Diabetic Retinopathy: Grade {r.dr_grade} ({r.dr_label}) | Confidence: {r.dr_conf:.1%}",
        f"- Vascular Analysis: Vessel density = {r.vessel_density:.1%} | Segmentation available",
    ])
    outline_section = "\n".join([
        "Generate a report with:",
        "1. FINDINGS",
        "2. RISK ASSESSMENT (low/moderate/high/emergent)",
        "3. RECOMMENDATIONS (follow-up interval, referrals, exams)",
        "4. DISCLAIMER (AI-assisted screening, not diagnosis)",
    ])
    return "\n\n".join((role_and_task, results_section, outline_section)) + "\n"

def overall_risk(r: ScreeningResults) -> str:
    """Return the overall triage level for a set of screening results.

    The result is the most severe of:

    * the DR grade thresholds: grade >= 4 emergent, >= 3 high, >= 2 moderate;
    * the cup-to-disc ratio thresholds: CDR >= 0.75 emergent, >= 0.65 high,
      >= 0.55 moderate;
    * the glaucoma model's own ``glaucoma_risk`` category, when it is one of
      ``low``/``moderate``/``high``/``emergent`` (case- and whitespace-
      insensitive). Any other label is ignored.

    Args:
        r: Screening results.

    Returns:
        One of ``"low"``, ``"moderate"``, ``"high"``, ``"emergent"``.
    """
    levels = ("low", "moderate", "high", "emergent")

    if r.dr_grade >= 4 or r.cdr >= 0.75:
        rank = 3
    elif r.dr_grade >= 3 or r.cdr >= 0.65:
        rank = 2
    elif r.dr_grade >= 2 or r.cdr >= 0.55:
        rank = 1
    else:
        rank = 0

    # Never report an overall triage level below the category the glaucoma
    # model itself assigned. Otherwise the same report can list
    # "risk category high" for glaucoma next to "Overall triage risk: moderate".
    label = str(r.glaucoma_risk).strip().lower()
    if label in levels:
        rank = max(rank, levels.index(label))

    return levels[rank]

def rule_based_report(r: ScreeningResults) -> str:
    """Build the deterministic fallback report used when MedGemma is unavailable.

    The report has four blank-line-separated sections (FINDINGS, RISK
    ASSESSMENT, RECOMMENDATIONS, DISCLAIMER). The follow-up recommendation is
    chosen from the triage level returned by ``overall_risk``.

    Args:
        r: Screening results.

    Returns:
        The report text, without a trailing newline.

    Raises:
        KeyError: if ``overall_risk`` returns a level with no follow-up text.
    """
    triage = overall_risk(r)

    if triage == "emergent":
        next_step = "Urgent ophthalmology referral within 24-72 hours."
    elif triage == "high":
        next_step = "Specialist assessment within 1-2 weeks."
    elif triage == "moderate":
        next_step = "Follow-up in 1-3 months with repeat retinal imaging."
    elif triage == "low":
        next_step = "Routine annual screening and risk-factor control."
    else:
        raise KeyError(triage)

    findings = [
        f"- Glaucoma screening: CDR {r.cdr:.3f}, risk category {r.glaucoma_risk}.",
        f"- DR grading: Grade {r.dr_grade} ({r.dr_label}), confidence {r.dr_conf:.1%}.",
        f"- Vascular analysis: estimated vessel density {r.vessel_density:.1%}.",
    ]
    sections = [
        ("1. FINDINGS", findings),
        ("2. RISK ASSESSMENT", [f"- Overall triage risk: {triage}."]),
        ("3. RECOMMENDATIONS", [
            f"- {next_step}",
            "- Correlate with clinical exam, IOP, OCT, and visual acuity.",
        ]),
        ("4. DISCLAIMER", [
            "- This is AI-assisted retinal screening support and not a definitive diagnosis.",
        ]),
    ]
    # Each section is its title followed by bullet lines; sections are separated
    # by one blank line.
    return "\n\n".join("\n".join([title, *bullets]) for title, bullets in sections)

def medgemma_generate(prompt: str, model_id: str = "google/medgemma-4b-it", max_new_tokens: int = 420) -> Tuple[Optional[str], Optional[str]]:
    try:
        import torch
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(model_id)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        with torch.no_grad():
            output_ids = model.generate(
                **inputs,
                max_new_tokens=max_new_tokens,
                temperature=0.2,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
            )
        text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
        if text.startswith(prompt):
            text = text[len(prompt):].strip()
        return text, None
    except Exception as e:
        return None, str(e)

def generate_report(r: ScreeningResults, use_medgemma: bool = True) -> Dict[str, str]:
    """Produce a clinical report, preferring MedGemma and falling back to rules.

    Args:
        r: Screening results.
        use_medgemma: If False, skip MedGemma and return the rule-based report.

    Returns:
        A dict with keys:

        * ``mode``: ``"rule_based_only"``, ``"fallback_rule_based"`` or
          ``"medgemma"``.
        * ``report``: the report text, never blank.
        * ``prompt``: the MedGemma prompt. It is always built, even when unused.
        * ``error``: why MedGemma was not used, or ``""``.
    """
    prompt = build_prompt(r)
    fallback = rule_based_report(r)
    if not use_medgemma:
        return {"mode": "rule_based_only", "report": fallback, "prompt": prompt, "error": ""}

    out, err = medgemma_generate(prompt)
    # A blank generation is as unusable as a failed one. Fall back so callers
    # always get a non-empty report.
    if out is None or not out.strip():
        reason = err or ("unknown" if out is None else "empty MedGemma output")
        return {"mode": "fallback_rule_based", "report": fallback, "prompt": prompt, "error": reason}
    return {"mode": "medgemma", "report": out, "prompt": prompt, "error": ""}

In [None]:
# Demo inputs: replace these with the real values inferred by your 3 models.
demo_inputs = {
    "cdr": 0.612,             # TransUNet
    "glaucoma_risk": "high",  # TransUNet
    "dr_grade": 2,            # EfficientNet
    "dr_label": "Moderate",   # EfficientNet
    "dr_conf": 0.9616,        # EfficientNet
    "vessel_density": 0.12,   # vascular segmentation (U-Net)
}
results = ScreeningResults(**demo_inputs)

output = generate_report(results, use_medgemma=True)

# Show which path produced the report (and why MedGemma was skipped, if it was).
print("MODE:", output["mode"])
if output["error"]:
    print("ERROR:", output["error"])
print("\nREPORT:\n")
print(output["report"])

In [None]:
# Optional: save the report and its metadata to Kaggle's working directory.
import json
import os
from dataclasses import asdict

# Single source of truth for the output location. The paths written and the
# paths printed below cannot drift apart.
OUT_DIR = "/kaggle/working/outputs/medgemma"
REPORT_PATH = os.path.join(OUT_DIR, "clinical_report.txt")
META_PATH = os.path.join(OUT_DIR, "clinical_report_meta.json")

os.makedirs(OUT_DIR, exist_ok=True)
with open(REPORT_PATH, "w", encoding="utf-8") as f:
    f.write(output["report"])

with open(META_PATH, "w", encoding="utf-8") as f:
    json.dump(
        {
            "mode": output["mode"],
            "error": output["error"],
            # asdict is the documented way to serialize a dataclass
            # (rather than reaching into its __dict__).
            "inputs": asdict(results),
            # Keep the exact prompt so AI-generated text can be audited later.
            "prompt": output["prompt"],
        },
        f,
        ensure_ascii=False,
        indent=2,
    )

print("Saved:")
print(f"- {REPORT_PATH}")
print(f"- {META_PATH}")

## Integração no app/demo

Use os outputs reais dos modelos treinados:
- `cdr` e `glaucoma_risk` do TransUNet
- `dr_grade`, `dr_label` e `dr_conf` do EfficientNet
- `vessel_density` da segmentacao vascular

Se o MedGemma não carregar no Kaggle, o fallback rule-based garante um relatório estável para a demo.