<a href="https://colab.research.google.com/github/Sakthi21-12/HealthAI-Intelligent-Healthcare-Assistant-Using-IBM-Granite/blob/main/HEALTHAI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install transformers torch gradio -q

In [2]:
import os, json, random, re, hashlib
from datetime import datetime, timedelta

import numpy as np, pandas as pd, plotly.express as px
import gradio as gr, torch
from transformers import AutoTokenizer, AutoModelForCausalLM

MODEL_NAME = "ibm-granite/granite-3.2-2b-instruct"

# Decide precision and placement once: fp16 with automatic device mapping when a
# CUDA GPU is present, otherwise fp32 on the default (CPU) device.
_has_gpu = torch.cuda.is_available()

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if _has_gpu else torch.float32,
    device_map="auto" if _has_gpu else None,
)

# Fall back to the EOS token for padding when the tokenizer defines no pad token.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# In-process memo of LLM outputs, keyed by the SHA-256 hex digest of the prompt.
CACHE = {}


def _digest(text):
    """Return the SHA-256 hex digest of *text* encoded as UTF-8."""
    return hashlib.sha256(text.encode()).hexdigest()


def cache_get(k):
    """Return the cached value for prompt *k*, or ``None`` if it was never stored."""
    return CACHE.get(_digest(k))


def cache_set(k, v):
    """Store *v* as the cached value for prompt *k* (overwriting any previous value)."""
    CACHE[_digest(k)] = v

# Emergency phrases that short-circuit to a "seek help" message instead of an LLM
# call. Matched against the lower-cased input.
_EMERGENCY_RE = re.compile(
    r"\b(chest pain|shortness of breath|difficulty breathing|severe bleeding"
    r"|suicide|suicidal|kill(?:ing)? myself)\b"
)

# Personal identifiers that must be removed before text is sent to the model:
# bare 10-digit numbers; NNN-NNN-NNNN / NNN.NNN.NNNN / NNN NNN NNNN and
# (NNN) NNN-NNNN phone formats; SSN-style NNN-NN-NNNN; and e-mail addresses.
# The e-mail pattern requires a local part and a dotted domain so ordinary text
# such as "pill @ 8 a.m." is not rejected.
_PII_RE = re.compile(
    r"\b\d{10}\b"
    r"|\b\d{3}[-.\s]\d{3}[-.\s]\d{4}\b"
    r"|\(\d{3}\)\s*\d{3}[-.\s]\d{4}\b"
    r"|\b\d{3}-\d{2}-\d{4}\b"
    r"|[\w.+-]+@[\w-]+(?:\.[\w-]+)+"
)


def precheck(txt):
    """Screen user text before it is sent to the LLM.

    Args:
        txt: Raw user input; may be ``None`` or blank.

    Returns:
        ``(ok, message)``. ``(True, "OK")`` means the text may be processed.
        Otherwise the result is ``(False, reason)`` with a user-facing
        explanation. Emergency phrases are checked before personal-info
        patterns.
    """
    if not txt or not txt.strip():
        return False, "Please enter details."
    if _EMERGENCY_RE.search(txt.lower()):
        return False, "⚠️ Emergency symptoms detected. Seek immediate medical help."
    if _PII_RE.search(txt):
        return False, "⚠️ Remove personal info (phone/email/ID) before using."
    return True, "OK"

def call_llm(prompt, max_new_tokens=200, max_input_tokens=1024):
    """Generate a greedy completion for *prompt* with the module-level model, memoized.

    Args:
        prompt: Raw text prompt passed to the module-level ``tokenizer``/``model``.
        max_new_tokens: Upper bound on the number of generated tokens.
        max_input_tokens: Prompt budget in tokens. Longer prompts keep their
            *last* tokens, because callers put the question/instruction at the end.

    Returns:
        Only the newly generated text (the prompt is excluded), stripped.
        Results are cached by exact prompt text via ``cache_get``/``cache_set``.
    """
    cached = cache_get(prompt)
    if cached is not None:  # an empty-string completion is still a valid cache hit
        return cached

    enc = tokenizer(prompt, return_tensors="pt")
    # Left-truncate so the trailing question/instruction survives long histories,
    # and move inputs onto whatever device the model lives on (CPU or GPU).
    enc = {k: v[:, -max_input_tokens:].to(model.device) for k, v in enc.items()}
    prompt_len = enc["input_ids"].shape[1]

    out_ids = model.generate(
        **enc,
        max_new_tokens=max_new_tokens,
        do_sample=False,  # greedy decoding; sampling knobs would be ignored
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode just the generated tokens. String-replacing the prompt in the full
    # decode breaks when detokenization doesn't round-trip or input was truncated.
    out = tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True).strip()
    cache_set(prompt, out)
    return out

def chat_fn(msg, st):
    """Answer one chat turn and append it to the session history.

    Args:
        msg: The user's message from the chat textbox.
        st: Session state dict with ``"profile"`` (dict) and ``"chat"``
            (list of ``(user, assistant)`` pairs); mutated in place.

    Returns:
        ``(chat_history, response_json, st, "")``. The trailing empty string
        clears the input textbox.
    """
    ok, m = precheck(msg)
    if not ok:
        st["chat"].append((msg, m))
        return st["chat"], {"response": m}, st, ""

    p = (
        f"Patient:{json.dumps(st['profile'])}\n"
        f"History:{json.dumps(st['chat'][-5:])}\n"
        f"Q:{msg}\n"
        # Show a real JSON example so the model's reply is parseable by json.loads.
        'Return JSON {"response": "", "sources": []}'
    )
    raw = call_llm(p, 150)

    # Models often wrap JSON in ```json fences; strip them before parsing.
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip())
    try:
        r = json.loads(cleaned)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        r = None
    # Valid JSON that is not an object (e.g. a bare string or list) has no .get().
    if not isinstance(r, dict):
        r = {"response": raw, "sources": []}

    reply = r.get("response", "")
    if not isinstance(reply, str):
        # The Chatbot renders text; coerce structured/None payloads instead of failing.
        reply = "" if reply is None else str(reply)
    st["chat"].append((msg, reply))
    return st["chat"], r, st, ""

def pred_fn(sym, st):
    """Ask the LLM for a symptom analysis.

    Args:
        sym: Free-text symptom description.
        st: Session state dict; only ``st["profile"]`` is read.

    Returns:
        ``(result, st)``. ``result`` is the model's JSON object when it parses
        to a dict. Otherwise it is a fallback dict with ``summary``,
        ``conditions`` and ``advice`` keys.
    """
    ok, m = precheck(sym)
    if not ok:
        return {"summary": m, "conditions": [], "advice": ""}, st
    p = f"Patient:{json.dumps(st['profile'])}\nSymptoms:{sym}\nReturn JSON summary,conditions,advice"
    raw = call_llm(p, 300)

    # Models often wrap JSON in ```json fences; strip them before parsing.
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip())
    try:
        r = json.loads(cleaned)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        r = None
    # Non-object JSON (string/list/number) would not match the expected shape.
    if not isinstance(r, dict):
        r = {"summary": raw, "conditions": [], "advice": "Informational only."}
    return r, st

def plan_fn(cond, st):
    """Ask the LLM for a treatment-plan outline for *cond*.

    Args:
        cond: Condition name entered by the user.
        st: Session state dict; only ``st["profile"]`` is read.

    Returns:
        ``(result, st)``. ``result`` is the model's JSON object when it parses
        to a dict. Otherwise it is a fallback dict with ``plan_summary``,
        ``lifestyle``, ``medications``, ``follow_up`` and ``disclaimer`` keys.
    """
    ok, m = precheck(cond)
    if not ok:
        return {"plan_summary": m, "medications": [], "lifestyle": [], "follow_up": "", "disclaimer": ""}, st
    p = f"Condition:{cond}\nPatient:{json.dumps(st['profile'])}\nReturn JSON plan_summary,lifestyle,medications,follow_up,disclaimer"
    raw = call_llm(p, 350)

    # Models often wrap JSON in ```json fences; strip them before parsing.
    cleaned = re.sub(r"^```(?:json)?\s*|\s*```$", "", raw.strip())
    try:
        r = json.loads(cleaned)
    except ValueError:  # json.JSONDecodeError is a ValueError subclass
        r = None
    # Non-object JSON (string/list/number) would not match the expected shape.
    if not isinstance(r, dict):
        r = {"plan_summary": raw, "lifestyle": [], "medications": [], "follow_up": "", "disclaimer": "Informational only."}
    return r, st

def gen_data(days=30):
    """Simulate ``days`` consecutive daily vital-sign records ending today.

    Each numeric value is a draw from ``np.random.normal`` around a fixed mean,
    clipped to a plausible range. The symptom is picked with ``random.choice``.
    Neither global RNG is seeded here.

    Returns:
        DataFrame with columns ``date, hr, sys, dia, glucose, symptom``:
        one row per day, oldest first.
    """
    first_day = datetime.now().date() - timedelta(days=days - 1)

    def draw(mean, spread, low, high):
        # Gaussian noise around `mean`, clipped into [low, high].
        return np.clip(mean + np.random.normal(0, spread), low, high)

    records = [
        {
            "date": first_day + timedelta(days=offset),
            "hr": int(draw(70, 5, 50, 120)),
            "sys": int(draw(120, 10, 90, 180)),
            "dia": int(draw(80, 6, 60, 120)),
            "glucose": round(draw(95, 15, 60, 200), 1),
            "symptom": random.choice(["cough", "fever", "headache", "fatigue", "none"]),
        }
        for offset in range(days)
    ]
    return pd.DataFrame(records)

def metrics(df):
    """Summarize the latest reading and the first-to-last change per vital sign.

    Returns:
        ``{column: {"value": last, "delta": last - first}}`` as floats for the
        ``hr``, ``sys``, ``dia`` and ``glucose`` columns of *df*.
    """
    summary = {}
    for column in ("hr", "sys", "dia", "glucose"):
        series = df[column]
        latest, earliest = series.iloc[-1], series.iloc[0]
        summary[column] = {"value": float(latest), "delta": float(latest - earliest)}
    return summary

# Custom stylesheet passed to gr.Blocks(css=...). Its class names are attached to
# components through elem_classes=... (center-heading, label-text, normal-text,
# compact-section, metric-card, metric-value, metric-delta, positive-delta,
# negative-delta).
# NOTE(review): .chat-message/.user-message/.ai-message are not referenced by any
# component in the code shown here. Also, h3 keeps the font-weight: 300 from the
# shared h1/h2/h3 rule because only h1 and h2 override it.
css_style = """
.gradio-container {
    background-color: #f0f8ff !important;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
h1, h2, h3 {
    font-weight: 300 !important;
    margin-bottom: 0.5rem !important;
}
h1 {
    font-size: 2.2rem !important;
    font-weight: 600 !important;
    color: #2c5aa0 !important;
}
h2 {
    font-size: 1.8rem !important;
    font-weight: 500 !important;
    color: #3a6cb3 !important;
}
h3 {
    font-size: 1.4rem !important;
    color: #4a7cc4 !important;
}
.center-heading {
    text-align: center;
}
.label-text {
    font-size: 1.1rem !important;
    font-weight: 500 !important;
    margin-bottom: 0.3rem !important;
}
.normal-text {
    font-size: 1.05rem !important;
    line-height: 1.5 !important;
}
.compact-section {
    margin-bottom: 1rem !important;
}
.chat-message {
    padding: 0.8rem;
    border-radius: 12px;
    margin-bottom: 0.8rem;
    max-width: 80%;
}
.user-message {
    background-color: #e1f5fe;
    margin-left: auto;
}
.ai-message {
    background-color: #f3e5f5;
    margin-right: auto;
}
.metric-card {
    background: white;
    border-radius: 10px;
    padding: 15px;
    box-shadow: 0 2px 6px rgba(0,0,0,0.1);
    margin-bottom: 15px;
}
.metric-value {
    font-size: 1.8rem;
    font-weight: 600;
    color: #2c5aa0;
}
.metric-delta {
    font-size: 1rem;
    font-weight: 500;
}
.positive-delta {
    color: #4caf50;
}
.negative-delta {
    color: #f44336;
}
"""

with gr.Blocks(theme=gr.themes.Soft(), css=css_style) as app:
    # Session state shared by every handler: "profile" holds the saved patient
    # form and "chat" holds (user, assistant) message pairs appended by chat_fn.
    st = gr.State({"profile": {}, "chat": []})

    with gr.Row():
        # Left column: title, disclaimer and the patient-profile form.
        with gr.Column(scale=1):
            gr.Markdown("# 🌤️ HealthAI — Smart Assistant", elem_classes="center-heading")
            gr.Markdown("### Your Personal Health Companion", elem_classes="center-heading")
            gr.Markdown("**⚠️ Informational only. Not medical advice.**", elem_classes="center-heading normal-text")

            # NOTE(review): the saved profile (including Name) is json.dumps'd into the
            # chat/prediction/plan prompts and is never screened by precheck().
            with gr.Accordion("Patient Profile", open=True):
                name = gr.Textbox(label="Name", elem_classes="label-text")
                age = gr.Number(label="Age", elem_classes="label-text")
                gender = gr.Dropdown(["Male","Female","Other"], label="Gender", elem_classes="label-text")
                history = gr.Textbox(label="Medical History", lines=2, elem_classes="label-text")
                meds = gr.Textbox(label="Current Medications", lines=2, elem_classes="label-text")
                allergies = gr.Textbox(label="Allergies", lines=2, elem_classes="label-text")
                save = gr.Button("Save Profile", variant="primary")
                # Hidden sink for save_fn's second output (the saved profile dict).
                prof_out = gr.JSON(visible=False)

                def save_fn(n, a, g, h, m, al, st):
                    """Store the profile form fields in session state.

                    Returns the updated state and the stored profile dict. The
                    dict feeds the hidden ``prof_out`` JSON component.
                    """
                    profile = dict(name=n, age=a, gender=g, history=h, medications=m, allergies=al)
                    st["profile"] = profile
                    return st, profile

                save.click(save_fn, [name, age, gender, history, meds, allergies, st], [st, prof_out])

        # Right column: the feature tabs.
        with gr.Column(scale=3):
            with gr.Tabs():
                with gr.TabItem("💬 Patient Chat"):
                    # NOTE(review): the captured output shows a Gradio warning raised at this
                    # line (likely about the tuple-style message format); chat_fn feeds it
                    # (user, assistant) tuples. Confirm against the installed Gradio version.
                    chat_ui = gr.Chatbot(elem_classes="normal-text", show_label=False)
                    msg = gr.Textbox(placeholder="Describe your symptoms or ask a health question...",
                                   elem_classes="normal-text", lines=2)
                    with gr.Row():
                        send = gr.Button("Send Message", variant="primary")
                        clear = gr.Button("Clear Chat")
                    # Hidden sink for the parsed/raw JSON response returned by chat_fn.
                    chat_json = gr.JSON(visible=False)

                    def clear_chat(st):
                        """Drop the stored chat history and blank the Chatbot display."""
                        st.update(chat=[])
                        return [], st

                    # Button click and Enter in the textbox both run chat_fn. Its 4-tuple
                    # return fills (history, hidden JSON, state, textbox); the final ""
                    # clears the textbox.
                    send.click(chat_fn, [msg, st], [chat_ui, chat_json, st, msg])
                    msg.submit(chat_fn, [msg, st], [chat_ui, chat_json, st, msg])
                    clear.click(clear_chat, [st], [chat_ui, st])

                with gr.TabItem("🔍 Disease Prediction"):
                    gr.Markdown("### Describe your symptoms for analysis", elem_classes="compact-section")
                    sym = gr.Textbox(label="Symptoms", placeholder="Describe what you're experiencing in detail...",
                                   lines=4, elem_classes="normal-text")
                    analyze = gr.Button("Analyze Symptoms", variant="primary")
                    pred_out = gr.JSON(label="Analysis Results")
                    analyze.click(pred_fn, [sym, st], [pred_out, st])

                with gr.TabItem("📋 Treatment Plans"):
                    gr.Markdown("### Generate a personalized treatment plan", elem_classes="compact-section")
                    # NOTE(review): the placeholder mentions selecting from predictions, but
                    # no visible code copies pred_fn's output into this box. The user types it.
                    cond = gr.Textbox(label="Condition", placeholder="Enter diagnosed condition or select from predictions...",
                                    elem_classes="normal-text")
                    gen_plan = gr.Button("Generate Plan", variant="primary")
                    plan_out = gr.JSON(label="Treatment Plan")
                    gen_plan.click(plan_fn, [cond, st], [plan_out, st])

                with gr.TabItem("📊 Health Analytics"):
                    with gr.Row():
                        with gr.Column(scale=2):
                            gr.Markdown("### Health Metrics Dashboard")
                            # Positional args are minimum=7, maximum=90, value=30.
                            days = gr.Slider(7, 90, 30, label="Time Period (days)", info="Select days to analyze")
                            gen = gr.Button("Generate Report", variant="primary")

                            # Metrics summary
                            # The value/delta texts below are static placeholders until
                            # "Generate Report" runs ana(). NOTE(review): ana() replaces only the
                            # text; the positive-delta/negative-delta classes (green/red) are fixed
                            # here and do not follow the actual direction of change.
                            gr.Markdown("#### Health Summary", elem_classes="compact-section")
                            with gr.Column(elem_classes="metric-card"):
                                gr.Markdown("**Heart Rate**", elem_classes="label-text")
                                with gr.Row():
                                    hr_value = gr.Markdown("74.0 bpm", elem_classes="metric-value")
                                    hr_delta = gr.Markdown("↑ 6.1", elem_classes="metric-delta positive-delta")

                            with gr.Column(elem_classes="metric-card"):
                                gr.Markdown("**Blood Pressure**", elem_classes="label-text")
                                with gr.Row():
                                    bp_value = gr.Markdown("120/80", elem_classes="metric-value")
                                    bp_delta = gr.Markdown("↑ 2.0", elem_classes="metric-delta positive-delta")

                            with gr.Column(elem_classes="metric-card"):
                                gr.Markdown("**Blood Glucose**", elem_classes="label-text")
                                with gr.Row():
                                    glu_value = gr.Markdown("101 mg/dL", elem_classes="metric-value")
                                    glu_delta = gr.Markdown("↓ 17.8", elem_classes="metric-delta negative-delta")

                            insight = gr.Textbox(label="AI Health Insights", elem_classes="normal-text")

                        # One plot per sub-tab; filled by ana().
                        with gr.Column(scale=3):
                            with gr.Tabs():
                                with gr.TabItem("Heart Rate"):
                                    hr_plot = gr.Plot()
                                with gr.TabItem("Blood Pressure"):
                                    bp_plot = gr.Plot()
                                with gr.TabItem("Glucose"):
                                    glu_plot = gr.Plot()
                                with gr.TabItem("Symptoms"):
                                    pie_plot = gr.Plot()

                    def ana(d, st):
                        """Render the analytics tab from freshly simulated readings.

                        Args:
                            d: Slider value, converted with ``int`` to a number of days.
                            st: Session state; returned untouched.

                        Returns:
                            A 12-tuple in ``gen.click`` output order:
                            - four Plotly figures (heart rate, blood pressure,
                              glucose, symptom pie);
                            - the LLM insight text;
                            - ``st``;
                            - value/delta display strings for heart rate, blood
                              pressure and glucose.

                        Readings come from ``gen_data`` (random values), not from
                        the saved patient profile.
                        """
                        def arrow(change):
                            """Format *change* as '↑ x.x' when non-negative, else '↓ |x.x|'."""
                            if change >= 0:
                                return f"↑ {change:.1f}"
                            return f"↓ {abs(change):.1f}"

                        frame = gen_data(int(d))
                        hr_fig = px.line(frame, x="date", y="hr", title="Heart Rate Trend")
                        bp_fig = px.line(frame, x="date", y=["sys", "dia"], title="Blood Pressure Trend")
                        glu_fig = px.line(frame, x="date", y="glucose", title="Blood Glucose Trend")
                        sym_fig = px.pie(frame, names="symptom", title="Symptom Frequency")
                        met = metrics(frame)
                        hr, systolic, diastolic, glucose = (met[k] for k in ("hr", "sys", "dia", "glucose"))

                        # Blood-pressure change is shown as the mean of the systolic and diastolic changes.
                        bp_change = (systolic["delta"] + diastolic["delta"]) / 2
                        displays = (
                            f"{hr['value']:.1f} bpm", arrow(hr["delta"]),
                            f"{systolic['value']:.0f}/{diastolic['value']:.0f}", arrow(bp_change),
                            f"{glucose['value']:.1f} mg/dL", arrow(glucose["delta"]),
                        )

                        insight_text = call_llm(
                            f"Given {json.dumps(met)} give 1-2 sentence health insight with red flags.", 120
                        )
                        return (hr_fig, bp_fig, glu_fig, sym_fig, insight_text, st, *displays)

                    # Output list order must match the 12-tuple returned by ana().
                    gen.click(ana, [days, st], [hr_plot, bp_plot, glu_plot, pie_plot, insight, st,
                                              hr_value, hr_delta, bp_value, bp_delta, glu_value, glu_delta])

    # Page-wide footer disclaimer, rendered below the two-column layout.
    gr.Markdown("**Note:** This tool provides health information for educational purposes only. Always consult a healthcare professional for medical advice.",
                elem_classes="center-heading normal-text")

if __name__ == "__main__":
    # Ensure the local "exports" directory exists before the app starts serving.
    os.makedirs("exports", exist_ok=True)
    # NOTE(review): share=True publishes a public *.gradio.live URL for an app that
    # collects health details; confirm this exposure is intended.
    app.launch(height=800, share=True)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

added_tokens.json:   0%|          | 0.00/87.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/701 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/786 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors.index.json: 0.00B [00:01, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/67.1M [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

  chat_ui = gr.Chatbot(elem_classes="normal-text", show_label=False)


Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://7ce5ec74dbaef3049a.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)
