In [None]:
# ===============================
# 1️⃣ INSTALL DEPENDENCIES
# ===============================
!pip install -q fastapi uvicorn pyngrok transformers accelerate torch sentencepiece nest_asyncio

# ===============================
# 2️⃣ IMPORTS
# ===============================
import torch
import nest_asyncio
import threading
from fastapi import FastAPI
from pyngrok import ngrok
import uvicorn
from transformers import AutoTokenizer, AutoModelForCausalLM
from google.colab import userdata
from pyngrok import ngrok
from pyngrok import conf



# Patch asyncio so that an event loop can be run while another one is already
# running. NOTE(review): presumably needed because the notebook kernel already
# owns a running loop and uvicorn is started later in this cell -- confirm.
nest_asyncio.apply()
from huggingface_hub import login
# Authenticate with the Hugging Face Hub. Called with no arguments, so the
# token is not taken from this file. NOTE(review): presumably required because
# the model repository is access-gated -- confirm.
login()
# Read the ngrok auth token from the Colab secret named "ngrok_auth" and set
# it on pyngrok's default configuration, which ngrok.connect() below uses.
conf.get_default().auth_token = userdata.get("ngrok_auth")
# ===============================
# 3️⃣ LOAD MEDGEMMA (GPU)
# ===============================
# Hugging Face Hub repository id of the checkpoint served by this notebook.
MODEL_ID = "google/medgemma-1.5-4b-it"

# --- Tokenizer -------------------------------------------------------------
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# --- Model -----------------------------------------------------------------
print("Loading model (this takes 1-2 mins)...")
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.bfloat16,  # bfloat16 weights
    device_map="auto",           # let the library decide device placement
)
# Inference only: switch the module to evaluation mode.
model = model.eval()

print("Model loaded successfully ‚úÖ")
print("GPU available:", torch.cuda.is_available())

# ===============================
# 4️⃣ BUILD FASTAPI APP
# ===============================
# The single FastAPI application: the route handlers in this cell register on
# it via @app.get / @app.post, and it is the object handed to uvicorn.run().
app = FastAPI()

@app.get("/")
def health():
    """Health-check endpoint: report that the server process is up.

    Returns:
        A dict with a single ``"status"`` key holding a fixed message.
        The model itself is not exercised by this call.
    """
    status_payload = {"status": "MedGemma server running"}
    return status_payload

def _extract_json_block(text: str) -> str:
    """Return the outermost ``{...}`` span of *text*, or ``"{}"`` if none exists.

    Markdown code fences (```json / ```) are stripped first. The result is
    NOT validated as JSON; it is only the substring from the first ``{`` to
    the last ``}``.
    """
    cleaned = text.strip().replace("```json", "").replace("```", "")

    start = cleaned.find("{")
    end = cleaned.rfind("}")

    # The closing brace must come after the opening one. This also covers the
    # "no brace at all" cases, and the stray-"}"-before-"{" case that would
    # otherwise slice out an empty string.
    if start == -1 or end < start:
        return "{}"
    return cleaned[start:end + 1]


@app.post("/generate")
def generate(note: str):
    """Turn a free-text discharge note into structured discharge guidance.

    Builds a two-message chat (fixed system prompt + the note), runs greedy
    generation on the module-level ``model`` and returns the JSON-looking
    part of the model's reply.

    Args:
        note: Discharge note text. Declared as a bare ``str`` parameter, so
            FastAPI reads it from the query string of the POST request.

    Returns:
        ``{"response": <str>}`` where the string is the outermost ``{...}``
        block of the model output, or ``"{}"`` when the output contains no
        such block. The string is not parsed or validated as JSON.
    """
    messages = [
        {
            "role": "system",
            "content": (
                "You are a senior clinical discharge copilot.\n"
                "Return ONLY valid JSON.\n"
                "Do NOT include explanations.\n"
                "Follow this schema:\n"
                "{"
                '"triage_level": "low|medium|high",'
                '"medications": [],'
                '"activity_guidance": [],'
                '"warning_signs": [],'
                '"red_flag_actions": [],'
                '"follow_up": [],'
                '"patient_instructions_simple": []'
                "}"
            )
        },
        {
            "role": "user",
            "content": f"Convert this discharge note:\n\n{note}"
        }
    ]

    # return_dict=True is required for the mapping-style use below
    # (``**inputs`` and ``inputs["input_ids"]``); without it the call can
    # return a bare tensor of token ids.
    # The tensors follow the model's own device instead of a hard-coded
    # "cuda", since the model was placed with device_map="auto".
    inputs = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt"
    ).to(model.device)

    prompt_length = inputs["input_ids"].shape[-1]

    with torch.no_grad():
        # eos_token_id is deliberately NOT overridden: the stop tokens from
        # the model's generation_config are used, which for chat checkpoints
        # may include an end-of-turn token in addition to tokenizer.eos_token.
        outputs = model.generate(
            **inputs,
            max_new_tokens=400,     # raise this if replies get truncated
            do_sample=False,        # greedy decoding for repeatable output
            pad_token_id=tokenizer.eos_token_id,
            repetition_penalty=1.1
        )

    # generate() returns prompt + continuation; keep only the continuation.
    generated_tokens = outputs[:, prompt_length:]

    response = tokenizer.batch_decode(
        generated_tokens,
        skip_special_tokens=True
    )[0]

    return {"response": _extract_json_block(response)}

# ===============================
# 5️⃣ START SERVER + NGROK
# ===============================
def run_server():
    """Serve ``app`` with uvicorn on all interfaces, port 8000."""
    uvicorn.run(app, host="0.0.0.0", port=8000)


# Open an ngrok tunnel to local port 8000 and show the resulting public URL.
public_url = ngrok.connect(8000)
print("üî• PUBLIC URL:", public_url)

# Run uvicorn on a separate thread so the notebook cell can finish.
server_thread = threading.Thread(target=run_server)
server_thread.start()

print("Server is running...")