<a href="https://colab.research.google.com/github/Maanisha27/MediTriage_/blob/main/MedTriage.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Triage-Program

In [None]:
import sys
import importlib.util
import subprocess


def _install_missing_packages():
    """Pip-install every required third-party package that is not importable yet.

    Availability is checked by *module* name, which differs from the pip
    distribution name for scikit-learn (imported as ``sklearn``). Only the
    missing distributions are installed.
    """
    required = {
        # pip distribution name -> importable module name
        "scikit-learn": "sklearn",
        "plotly": "plotly",
        "pandas": "pandas",
        "numpy": "numpy",
        "rich": "rich",
        "networkx": "networkx",
    }
    missing = [dist for dist, module in required.items()
               if importlib.util.find_spec(module) is None]
    if missing:
        print("Installing required packages...")
        # Invoke pip through the running interpreter: works in notebooks and in plain
        # Python alike, and raises CalledProcessError if the install fails.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", *missing])


_install_missing_packages()

# ------------------ Imports ------------------
import numpy as np
import pandas as pd
from datetime import datetime
import warnings, math, json, os, random
warnings.filterwarnings("ignore")

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from sklearn.metrics.pairwise import cosine_similarity

import plotly.express as px
import plotly.graph_objects as go
import networkx as nx

from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich.prompt import Prompt, IntPrompt, Confirm
from rich.progress import track

# Shared rich console used for all user-facing output in this module.
console = Console()
# Widen pandas' text rendering so wide patient tables are not truncated;
# set_option accepts several (pattern, value) pairs in one call.
pd.set_option("display.max_columns", 200, "display.width", 140)

# ------------------ Base data: specialists, profile vectors, collaboration graph, demo patients ------------------
# Specialist roster. waspas_for_patient reads expertise, availability, success_rate and
# resource_access as benefit criteria (higher is better) and workload as a cost criterion
# (lower is better). The list order also defines the row/column order of ADJ below.
SPECIALISTS = [
    {'id': 'CARD_01', 'label': 'Specialist_1', 'expertise': 92, 'availability': 80, 'success_rate': 88, 'resource_access': 90, 'workload': 6, 'specialization': 'Cardiac'},
    {'id': 'NEUR_02', 'label': 'Specialist_2', 'expertise': 89, 'availability': 70, 'success_rate': 90, 'resource_access': 75, 'workload': 4, 'specialization': 'Neurology'},
    {'id': 'TRMA_03', 'label': 'Specialist_3', 'expertise': 95, 'availability': 65, 'success_rate': 91, 'resource_access': 85, 'workload': 9, 'specialization': 'Trauma'},
    {'id': 'GENM_04', 'label': 'Specialist_4', 'expertise': 80, 'availability': 90, 'success_rate': 82, 'resource_access': 70, 'workload': 3, 'specialization': 'General Medicine'},
    {'id': 'EMER_05', 'label': 'Specialist_5', 'expertise': 88, 'availability': 85, 'success_rate': 87, 'resource_access': 80, 'workload': 5, 'specialization': 'Emergency'}
]

# Per-specialist profile vectors, compared by cosine similarity (similarity_scores) with the
# vector from build_patient_sym_vector, whose element order is:
# [cardiac, neuro, trauma, metabolic, severity, urgency, pain, age].
SPECIALIST_SYM = {
    'CARD_01': np.array([1,0,0,0,0.95,0.2,0.2,0.1]),
    'NEUR_02': np.array([0,1,0,0,0.2,0.95,0.2,0.1]),
    'TRMA_03': np.array([0,0,1,0,0.3,0.3,0.9,0.1]),
    'GENM_04': np.array([0.5,0.5,0.2,0.2,0.5,0.5,0.5,0.4]),
    'EMER_05': np.array([0.7,0.6,0.5,0.6,0.9,0.8,0.6,0.3])
}

# Symmetric 5x5 collaboration weights between specialists; index i refers to SPECIALISTS[i]
# (gnn_adjust_scores and the dashboard network graph index it positionally).
ADJ = np.array([
    [1.0,0.2,0.6,0.1,0.3],
    [0.2,1.0,0.2,0.1,0.4],
    [0.6,0.2,1.0,0.1,0.3],
    [0.1,0.1,0.1,1.0,0.2],
    [0.3,0.4,0.3,0.2,1.0]
])

# Demo patients, one row each:
# [id, age, temperature (°C), severity (0-100), urgency (0-100), resource_need (0-100),
#  waiting_impact (0-100), pain_level (0-10), condition description]
SAMPLE_PATIENTS = [
    ["P001", 60, 39.0, 90, 85, 80, 90, 7, "Acute Myocardial Infarction"],
    ["P002", 8, 37.5, 80, 85, 70, 80, 6, "Severe Asthma Exacerbation"],
    ["P003", 72, 38.2, 85, 90, 85, 88, 6, "Displaced Hip Fracture"],
    ["P004", 25, 36.8, 30, 25, 20, 20, 2, "Deep Laceration"],
    ["P005", 55, 39.5, 90, 95, 90, 90, 3, "Acute Ischemic Stroke"],
    ["P006", 60, 38.8, 85, 80, 75, 85, 7, "Severe Pneumonia"],
    ["P007", 35, 37.2, 70, 75, 60, 65, 8, "Appendicitis"],
    ["P008", 50, 36.9, 60, 55, 50, 60, 5, "Fractured Arm"],
    ["P009", 28, 39.2, 40, 50, 20, 30, 7, "Migraine Attack"],
    ["P010", 65, 37.8, 85, 80, 90, 85, 5, "Chronic Heart Failure"]
]

# ------------------ Utility functions (improved) ------------------
def ensure_dir(path):
    """Make sure directory *path* exists, creating missing parent directories too.

    Does nothing if the directory is already there; raises if *path* is an existing file.
    """
    if not os.path.isdir(path):
        os.makedirs(path, exist_ok=True)

def now_str():
    """Return the current local time as a compact ``YYYYmmdd_HHMMSS`` stamp (used in filenames)."""
    stamp = datetime.now()
    return f"{stamp:%Y%m%d_%H%M%S}"

def normalize_cols_by_vector_norm(X: np.ndarray) -> np.ndarray:
    """Scale every column of *X* to unit Euclidean length.

    All-zero columns are divided by 1 (left as zeros) instead of by 0.
    """
    col_norms = np.sqrt(np.sum(np.square(X), axis=0))
    safe_norms = np.where(col_norms == 0, 1.0, col_norms)
    return X / safe_norms

def topsis_scores(decision_matrix: np.ndarray, weights: np.ndarray, benefit_flags=None) -> np.ndarray:
    """Return the TOPSIS closeness coefficient (0..1, higher is better) for each row.

    Args:
        decision_matrix: alternatives x criteria matrix.
        weights: one weight per criterion; normalized to sum to 1.
        benefit_flags: per-criterion truthy = higher is better, falsy = lower is better.
            Defaults to all benefit criteria.
    """
    normalized = normalize_cols_by_vector_norm(decision_matrix)
    wvec = np.array(weights, dtype=float).flatten()
    wvec = wvec / (wvec.sum() + 1e-12)
    V = normalized * wvec
    n_criteria = V.shape[1]
    if benefit_flags is None:
        flags = np.array([True] * n_criteria)
    else:
        flags = np.array([bool(benefit_flags[j]) for j in range(n_criteria)])
    col_max = V.max(axis=0)
    col_min = V.min(axis=0)
    # ideal takes the best value per criterion, anti-ideal the worst
    ideal = np.where(flags, col_max, col_min)
    anti_ideal = np.where(flags, col_min, col_max)
    dist_ideal = np.sqrt(np.square(V - ideal).sum(axis=1))
    dist_anti = np.sqrt(np.square(V - anti_ideal).sum(axis=1))
    return dist_anti / (dist_ideal + dist_anti + 1e-12)

def promethee_netflow(decision_matrix: np.ndarray, weights: np.ndarray) -> np.ndarray:
    """Return smoothed PROMETHEE II net outranking flows, one per alternative (row).

    Each pairwise criterion difference is range-normalized and passed through a
    logistic preference function ``1 / (1 + exp(-5 * diff))``, then weighted.

    Args:
        decision_matrix: alternatives x criteria matrix (all criteria are benefit-type).
        weights: one weight per criterion; normalized to sum to 1.

    Returns:
        Array of length m with phi+ - phi-. With fewer than two alternatives there
        is nothing to compare against, so the flows are 0 rather than 0/0 = NaN.
    """
    X = np.asarray(decision_matrix, dtype=float)
    m = X.shape[0]
    if m < 2:
        return np.zeros(m)
    W = np.asarray(weights, dtype=float)
    W = W / (W.sum() + 1e-12)
    rng = X.max(axis=0) - X.min(axis=0)
    rng[rng == 0] = 1.0  # constant criteria contribute diff 0 instead of 0/0
    P = np.zeros((m, m))
    for i in range(m):
        # preferences of alternative i over every alternative at once: (m, n)
        diffs = (X[i] - X) / rng
        prefs = 1.0 / (1.0 + np.exp(-5.0 * diffs))
        P[i] = (prefs * W).sum(axis=1)
    np.fill_diagonal(P, 0.0)  # an alternative is not compared with itself
    phi_plus = P.sum(axis=1) / (m - 1)
    phi_minus = P.sum(axis=0) / (m - 1)
    return phi_plus - phi_minus

def fuzzy_mamdani_label(severity_norm, urgency_norm, waiting_norm):
    """Fuzzy rule-based triage on inputs scaled to roughly 0..1.

    Uses low/medium/high memberships and four rules with crisp consequents
    (0.95, 0.7, 0.45, 0.12), defuzzified by a strength-weighted average.

    Returns:
        (label, score): label is "Emergency/Immediate" (score >= 0.8), "Urgent" (>= 0.6),
        "Semi-Urgent" (>= 0.35) or "Routine"; score is a float.
    """
    def memberships(x):
        lo = max(0.0, min((0.5 - x) / 0.5, 1.0))
        mid = max(0.0, 1 - abs(x - 0.5) / 0.25)
        hi = max(0.0, min((x - 0.5) / 0.5, 1.0))
        return lo, mid, hi

    sev_l, sev_m, sev_h = memberships(severity_norm)
    urg_l, urg_m, urg_h = memberships(urgency_norm)
    wait_l, _wait_m, wait_h = memberships(waiting_norm)  # medium waiting is not used by any rule

    # (firing strength, crisp consequent) per rule
    rule_base = (
        (max(sev_h, urg_h, wait_h), 0.95),
        (max(min(sev_m, urg_m), min(sev_h, urg_m), min(sev_m, urg_h)), 0.7),
        (max(min(sev_m, urg_l), min(sev_l, urg_m)), 0.45),
        (min(sev_l, urg_l, wait_l), 0.12),
    )
    weighted_total = 0.0
    strength_total = 0.0
    for strength, consequent in rule_base:
        weighted_total += strength * consequent
        strength_total += strength
    score = weighted_total / (strength_total + 1e-12)

    for cutoff, name in ((0.8, "Emergency/Immediate"), (0.6, "Urgent"), (0.35, "Semi-Urgent")):
        if score >= cutoff:
            return name, float(score)
    return "Routine", float(score)

def waspas_for_patient(patient_sym_vector: np.ndarray, specialists: list, lam=0.5, weights=None):
    """Score specialists with WASPAS (blend of weighted-sum and weighted-product models).

    Args:
        patient_sym_vector: kept for interface compatibility. NOTE(review): it is not
            used, so these scores are identical for every patient.
        specialists: dicts with 'id', 'expertise', 'availability', 'success_rate',
            'resource_access' and 'workload' keys.
        lam: blend factor; 1.0 = pure weighted sum, 0.0 = pure weighted product.
        weights: optional criteria weights in the column order above; defaults to
            (0.35, 0.20, 0.25, 0.15, 0.05). Normalized to sum to 1.

    Returns:
        (ids, Q, Xn): specialist ids, one WASPAS score per specialist, and the
        normalized decision matrix (every entry in [0, 1]).
    """
    cols = ['expertise', 'availability', 'success_rate', 'resource_access', 'workload']
    ids = [s['id'] for s in specialists]
    if not specialists:
        return ids, np.zeros(0), np.zeros((0, len(cols)))
    mat = np.array([[s[c] for c in cols] for s in specialists], dtype=float)
    Xn = np.zeros_like(mat)
    for j, col in enumerate(cols):
        values = mat[:, j]
        if col == 'workload':
            # Cost criterion: ratio normalization min/x (best = 1).
            best = values.min()
            if best > 0:
                Xn[:, j] = best / values
            else:
                # min/x with min == 0: the zero-workload specialists score 1 and the
                # rest 0, keeping the column in [0, 1] instead of dividing by ~0.
                Xn[:, j] = np.where(values == best, 1.0, 0.0)
        else:
            # Benefit criterion: x/max (guard an all-zero column).
            mx = values.max() or 1.0
            Xn[:, j] = values / mx
    w = np.array([0.35, 0.20, 0.25, 0.15, 0.05] if weights is None else weights, dtype=float)
    w = w / (w.sum() + 1e-12)
    Q_wsm = (Xn * w).sum(axis=1)
    # clip so a zero entry does not zero out the whole product
    Q_wpm = np.prod(np.power(np.clip(Xn, 1e-9, None), w), axis=1)
    Q = lam * Q_wsm + (1 - lam) * Q_wpm
    return ids, Q, Xn

def gnn_adjust_scores(base_scores: np.ndarray, eta=0.2, avail_vector=None, workloads=None, iters=2):
    """Propagate availability influence over the specialist graph ADJ to adjust scores.

    The adjacency is re-weighted by each specialist's availability scaled by
    1/(1 + workload), which favors available, lightly loaded specialists. Each
    iteration adds ``eta * influence`` and divides by ``1 + eta * mean(adjacency)``.

    Args:
        base_scores: one score per specialist, in the same order as ADJ rows.
        eta: step size of each propagation step.
        avail_vector: availability in 0..1 as an ndarray; defaults to SPECIALISTS' values / 100.
        workloads: workloads as an ndarray; defaults to SPECIALISTS' values.
        iters: number of propagation steps.
    """
    if avail_vector is None:
        avail = np.array([s['availability'] for s in SPECIALISTS]) / 100.0
    else:
        avail = avail_vector
    load = np.array([s['workload'] for s in SPECIALISTS]) if workloads is None else workloads

    factor = avail.reshape(-1, 1) * (1.0 / (1.0 + load.reshape(-1, 1)))
    adj_mod = ADJ * (0.5 + 0.5 * (factor @ factor.T))
    influence = adj_mod.dot(avail)
    damping = 1.0 + eta * adj_mod.mean()

    adjusted = base_scores.copy().astype(float)
    for _ in range(iters):
        adjusted = (adjusted + eta * influence) / damping
    return adjusted

def similarity_scores(patient_sym_vector: np.ndarray, spec_sym_map: dict, specialists=None):
    """Cosine similarity between a patient profile vector and each specialist's profile.

    Args:
        patient_sym_vector: output of build_patient_sym_vector.
        spec_sym_map: specialist id -> profile vector; specialists missing here score 0.0.
        specialists: roster to score; defaults to SPECIALISTS when None.

    Returns:
        (ids, sims): specialist ids in roster order and a float array of similarities.
    """
    roster = SPECIALISTS if specialists is None else specialists
    ids = [s['id'] for s in roster]
    sims = []
    for sid in ids:
        profile = spec_sym_map.get(sid, None)
        if profile is None:
            sims.append(0.0)
            continue
        sims.append(float(cosine_similarity([patient_sym_vector], [profile])[0, 0]))
    return ids, np.array(sims)

def rank_like(scores: np.ndarray):
    """Convert scores into rank-based values.

    Returns:
        (scorelike, ranks): ranks are 1 for the highest score up to n for the lowest;
        scorelike maps rank 1 -> 1.0 and rank n -> 0.0 linearly (all 1.0 when n <= 1).
    """
    n = len(scores)
    order = np.argsort(-scores)
    ranks = np.empty_like(order)
    ranks[order] = np.arange(1, n + 1)
    if n <= 1:
        scorelike = np.ones_like(ranks, dtype=float)
    else:
        scorelike = (n - ranks) / (n - 1)
    return scorelike, ranks

def aggregate_ranks(ids, waspas_scores, gnn_scores, sim_scores, use_gnn=True, weights=(0.5,0.3,0.2)):
    """Fuse three specialist scorings by a weighted sum of their rank-based values.

    Args:
        ids: specialist ids aligned with the score arrays.
        waspas_scores, gnn_scores, sim_scores: per-specialist scores.
        use_gnn: when False the GNN term is dropped (its weight is not redistributed).
        weights: (waspas, gnn, similarity) weights.

    Returns:
        (final_ids, final_scores) sorted by fused score, highest first.
    """
    w_waspas, w_gnn, w_sim = weights
    waspas_part = w_waspas * rank_like(waspas_scores)[0]
    gnn_part = w_gnn * rank_like(gnn_scores)[0] if use_gnn else 0.0
    sim_part = w_sim * rank_like(sim_scores)[0]
    fused = waspas_part + gnn_part + sim_part
    order = np.argsort(-fused)
    return [ids[i] for i in order], fused[order]

def build_patient_sym_vector(record: dict):
    """Encode a patient record as the 8-element profile vector compared with SPECIALIST_SYM.

    Layout: [cardiac, neuro, trauma, metabolic, severity/100, urgency/100,
    pain_level/10, age/100]. The first four entries are 0/1 keyword flags derived
    from ``condition_desc``. Missing fields default to severity=50, urgency=50,
    pain_level=0 and age=40; a missing or None description sets no flags.
    """
    import re

    cond = (record.get('condition_desc') or '').lower()
    # Short abbreviations are matched as whole words only, so 'mi' does not fire on
    # words such as "migraine" or "ischemic"; longer stems still match as substrings.
    tokens = set(re.findall(r"[a-z0-9]+", cond))

    def flag(substrings, words=()):
        hit = any(k in cond for k in substrings) or any(w in tokens for w in words)
        return 1.0 if hit else 0.0

    cardiac = flag(['heart', 'card', 'angina'], words=['mi', 'stemi', 'nstemi'])
    neuro = flag(['stroke', 'neuro', 'seizure', 'paralysis'])
    trauma = flag(['trauma', 'fracture', 'injury', 'bleed', 'laceration'])
    metabolic = flag(['diabet', 'metabolic', 'keto', 'hypogly'])
    sev = record.get('severity', 50) / 100.0
    urg = record.get('urgency', 50) / 100.0
    pain = record.get('pain_level', 0) / 10.0
    age_adj = record.get('age', 40) / 100.0
    return np.array([cardiac, neuro, trauma, metabolic, sev, urg, pain, age_adj], dtype=float)

def age_vulnerability_from_age(age: int):
    """Map age in years to a 0-100 vulnerability score.

    >= 75 -> 90, 65-74 -> 70, <= 5 -> 85, 6-18 -> 60, otherwise 30.
    """
    if age >= 65:
        return 90.0 if age >= 75 else 70.0
    if age <= 18:
        return 85.0 if age <= 5 else 60.0
    return 30.0

# ------------------ Combined System Class (improved) ------------------
class CombinedMedicalSystem:
    """Patient registry with multi-criteria triage and specialist routing.

    Patients are plain dicts kept in ``self.patients``. ``run_triage`` annotates them
    in place (scores, priority, feature contributions) and re-sorts the list by
    ``ensemble_score``, highest first. Triage, routing and simulation runs append
    summary dicts to ``self.event_log``, which ``export_event_log`` writes as JSON.
    """

    # Decision-matrix criteria order: Severity, Urgency, Resource need,
    # Waiting impact, Age vulnerability.
    _CRITERIA = ('S', 'U', 'R', 'W', 'A')
    # Expert weights used when PCA weighting is off or there are fewer than 3 patients.
    _DEFAULT_WEIGHTS = (0.35, 0.30, 0.15, 0.15, 0.05)
    _ALGORITHMS = ("ensemble", "topsis", "promethee", "fuzzy")

    def __init__(self, specialists=None, spec_sym=None, log_dir="logs"):
        self.patients = []
        self.specialists = specialists.copy() if specialists else SPECIALISTS.copy()
        self.spec_sym = spec_sym.copy() if spec_sym else SPECIALIST_SYM.copy()
        self.last_weights = None  # criteria weights used by the latest run_triage
        self.log_dir = log_dir
        # NOTE(review): this directory is created, but the exports below write to the CWD.
        ensure_dir(self.log_dir)
        self.event_log = []  # list of dicts for JSON export

    # ------------------ helpers ------------------
    @staticmethod
    def _build_patient(pid, age, temp, sev, urg, res, wait, pain, cond):
        """Return a new patient dict with derived age vulnerability and fuzzy label/score."""
        label, fuzzy_score = fuzzy_mamdani_label(sev / 100.0, urg / 100.0, wait / 100.0)
        return {
            'id': pid, 'age': age, 'temp': temp,
            'severity': sev, 'urgency': urg, 'resource_need': res,
            'waiting_impact': wait, 'age_vulnerability': age_vulnerability_from_age(age),
            'pain_level': pain, 'condition_desc': cond,
            'registration_time': datetime.now().isoformat(),
            'fuzzy_label': label,
            'fuzzy_score': fuzzy_score,
        }

    @staticmethod
    def _minmax(x):
        """Linearly rescale *x* to [0, 1]; a constant array maps to all zeros."""
        return (x - x.min()) / (x.max() - x.min() + 1e-12)

    @staticmethod
    def _priority_from_score(score):
        """Map a triage score to a priority band (thresholds 0.8 / 0.6 / 0.4)."""
        if score >= 0.8:
            return "Critical"
        if score >= 0.6:
            return "High"
        if score >= 0.4:
            return "Medium"
        return "Low"

    # ------------------ patient intake ------------------
    def load_sample_patients(self, auto_run_triage=True):
        """Replace all patients with SAMPLE_PATIENTS, optionally running triage afterwards."""
        self.patients = [self._build_patient(*rec) for rec in SAMPLE_PATIENTS]
        if auto_run_triage:
            self.run_triage(use_pca=True)

    def add_patient_interactive(self):
        """Prompt for one patient on the console and append it (without triaging).

        Aborts without adding on a duplicate ID, unparsable input, or values outside
        0-100 (severity, urgency, resource need, waiting impact), 0-10 (pain) or a
        negative age.
        """
        console.print("\n[bold]Enter new patient details[/bold]")
        pid = Prompt.ask("Patient ID (e.g., P011)")
        if any(p['id'] == pid for p in self.patients):
            console.print("[red]ID exists — choose a unique ID.[/red]")
            return
        try:
            age = IntPrompt.ask("Age (years)")
            temp = float(Prompt.ask("Temperature (°C)", default="36.6"))
            cond = Prompt.ask("Condition short description")
            sev = float(Prompt.ask("Severity (0-100)", default="50"))
            urg = float(Prompt.ask("Urgency (0-100)", default="50"))
            res = float(Prompt.ask("Resource need (0-100)", default="50"))
            wait = float(Prompt.ask("Waiting impact (0-100)", default="10"))
            pain = float(Prompt.ask("Pain level (0-10)", default="1"))
        except (ValueError, EOFError) as e:
            console.print("[red]Invalid input. Aborting add.[/red]", e)
            return
        bounded = (("Severity", sev, 100), ("Urgency", urg, 100), ("Resource need", res, 100),
                   ("Waiting impact", wait, 100), ("Pain level", pain, 10))
        # `not 0 <= v <= hi` also rejects NaN
        bad = [name for name, value, hi in bounded if not 0 <= value <= hi]
        if age < 0:
            bad.append("Age")
        if bad:
            console.print(f"[red]Out-of-range value(s): {', '.join(bad)}. Aborting add.[/red]")
            return
        self.patients.append(self._build_patient(pid, age, temp, sev, urg, res, wait, pain, cond))
        console.print(f"[green]✅ Patient {pid} added.[/green]")

    # ------------------ triage ------------------
    def build_decision_matrix(self):
        """Return the patients x 5 criteria matrix (S, U, R, W, A) as floats; (0, 5) if empty."""
        if not self.patients:
            return np.zeros((0, 5))
        return np.array([
            [p['severity'], p['urgency'], p['resource_need'], p['waiting_impact'],
             p.get('age_vulnerability', 30.0)]
            for p in self.patients
        ], dtype=float)

    def compute_pca_weights(self, matrix, var_threshold=0.85):
        """Derive criteria weights from PCA on the standardized decision matrix.

        Keeps the smallest k components reaching *var_threshold* cumulative explained
        variance; weight_j = sum_i eigenvalue_i * |loading_ij|, normalized to sum to 1.
        With fewer than 3 rows the default expert weights are returned; with
        degenerate (zero or non-finite) results, equal weights are returned.
        """
        if matrix.shape[0] < 3:
            return np.array(self._DEFAULT_WEIGHTS)
        X = StandardScaler().fit_transform(matrix)
        pca = PCA().fit(X)
        cum = np.cumsum(pca.explained_variance_ratio_)
        k = int(np.argmax(cum >= var_threshold)) + 1
        weights = pca.explained_variance_[:k] @ np.abs(pca.components_[:k, :])
        if not np.all(np.isfinite(weights)) or weights.sum() == 0:
            weights = np.ones(matrix.shape[1])
        return weights / (weights.sum() + 1e-12)

    def explain_pca(self, matrix):
        """Return (components, explained_variance_ratio) of PCA on the standardized matrix.

        With fewer than two rows PCA is not meaningful, so empty arrays are returned.
        """
        if matrix.shape[0] < 2:
            return np.zeros((0, matrix.shape[1])), np.zeros(0)
        X = StandardScaler().fit_transform(matrix)
        pca = PCA(n_components=min(X.shape[0], X.shape[1])).fit(X)
        return pca.components_, pca.explained_variance_ratio_

    def run_triage(self, algorithm="ensemble", use_pca=True):
        """Score, prioritize and rank all patients.

        Args:
            algorithm: "ensemble" (0.5*TOPSIS + 0.3*PROMETHEE + 0.2*fuzzy, each min-max
                scaled), "topsis", "promethee" (min-max scaled netflow) or "fuzzy".
                Case/whitespace-insensitive; unknown names fall back to "ensemble"
                with a warning.
            use_pca: derive criteria weights by PCA when there are at least 3 patients.

        Returns:
            The criteria weights (S,U,R,W,A) used, or None when there are no patients.
        """
        D = self.build_decision_matrix()
        if D.size == 0:
            console.print("[red]No patients to triage.[/red]")
            return None
        algorithm = str(algorithm).strip().lower()
        if algorithm not in self._ALGORITHMS:
            console.print("[yellow]Unknown algorithm — using ensemble.[/yellow]")
            algorithm = "ensemble"

        if use_pca and D.shape[0] >= 3:
            w = self.compute_pca_weights(D)
            console.print(f"[green]PCA-derived weights (S,U,R,W,A):[/green] {np.round(w,3).tolist()}")
        else:
            w = np.array(self._DEFAULT_WEIGHTS)
            console.print(f"[yellow]Using default expert weights (S,U,R,W,A):[/yellow] {np.round(w,3).tolist()}")
        self.last_weights = w

        tops = topsis_scores(D, w)
        # PROMETHEE flows average over the *other* patients and can be NaN for a single
        # patient; treat that as neutral (0) so NaN never reaches scores or sorting.
        proms = np.nan_to_num(promethee_netflow(D, w))
        fuzzy_scores = np.array([
            fuzzy_mamdani_label(p['severity'] / 100.0, p['urgency'] / 100.0, p['waiting_impact'] / 100.0)[1]
            for p in self.patients
        ])

        if algorithm == "topsis":
            final_scores = tops
        elif algorithm == "promethee":
            final_scores = self._minmax(proms)
        elif algorithm == "fuzzy":
            final_scores = fuzzy_scores
        else:
            final_scores = (0.5 * self._minmax(tops) + 0.3 * self._minmax(proms)
                            + 0.2 * self._minmax(fuzzy_scores))

        for i, p in enumerate(self.patients):
            p['topsis_score'] = float(tops[i])
            p['promethee_score'] = float(proms[i])
            p['fuzzy_score'] = float(fuzzy_scores[i])
            p['ensemble_score'] = float(final_scores[i])
            p['priority'] = self._priority_from_score(final_scores[i])
            # Explainability: each criterion's share of the row total, times its weight.
            Drow = D[i]
            contrib_raw = (Drow / (Drow.sum() + 1e-12)) * w
            p['feature_contrib'] = dict(zip(self._CRITERIA, np.round(contrib_raw, 3).tolist()))

        self.patients.sort(key=lambda x: -x['ensemble_score'])
        console.print("[green]Triage complete: patients ranked by chosen algorithm/ensemble.[/green]")

        self.event_log.append({
            "event": "triage_run",
            "timestamp": datetime.now().isoformat(),
            "algorithm": algorithm,
            "num_patients": len(self.patients),
            "weights": self.last_weights.tolist()
        })
        return self.last_weights

    # ------------------ routing ------------------
    def route_patient(self, patient_id, use_gnn=True, lam=0.5):
        """Rank specialists for one patient by fusing WASPAS, GNN-adjusted and similarity scores.

        Confidence is 1 minus the mean, over specialists, of the standard deviation
        of the min-max scaled scores of the methods in use, i.e. it is high when the
        methods agree on each specialist. It is clipped to [0, 1].

        Returns:
            A dict with 'patient_id', 'waspas_scores', 'gnn_scores', 'sim_scores',
            'fused_ranking' (list of {'id', 'label', 'score'}, best first),
            'recommended' (first ranking entry or None) and 'confidence';
            None when the patient id is unknown.
        """
        p = next((x for x in self.patients if x['id'] == patient_id), None)
        if p is None:
            console.print("[red]Patient not found.[/red]")
            return None
        psym = build_patient_sym_vector(p)
        ids, waspas_scores, _ = waspas_for_patient(psym, self.specialists, lam=lam)
        avail = np.array([s['availability'] / 100.0 for s in self.specialists])
        workloads = np.array([s['workload'] for s in self.specialists])
        gnn_scores = gnn_adjust_scores(waspas_scores, eta=0.25, avail_vector=avail, workloads=workloads, iters=3)
        _, sim_scores = similarity_scores(psym, self.spec_sym, specialists=self.specialists)
        final_ids, final_scores = aggregate_ranks(ids, waspas_scores, gnn_scores, sim_scores,
                                                  use_gnn=use_gnn, weights=(0.5, 0.3, 0.2))

        used = [waspas_scores, sim_scores] + ([gnn_scores] if use_gnn else [])
        components = np.vstack([self._minmax(c) for c in used])
        if components.shape[1] == 0:
            conf = 0.0
        else:
            # std along axis 0 = disagreement between methods for each specialist
            conf = 1.0 - float(np.std(components, axis=0).mean())
        conf = max(0.0, min(1.0, conf))

        fused = []
        for sid, sc in zip(final_ids, final_scores):
            sp = next((s for s in self.specialists if s['id'] == sid), None)
            fused.append({'id': sid, 'label': sp['label'] if sp else sid, 'score': float(sc)})
        recommended = fused[0] if fused else None

        self.event_log.append({
            "event": "route_patient",
            "timestamp": datetime.now().isoformat(),
            "patient_id": patient_id,
            "waspas": dict(zip(ids, np.round(waspas_scores, 4))),
            "gnn": dict(zip(ids, np.round(gnn_scores, 4))),
            "sim": dict(zip(ids, np.round(sim_scores, 4))),
            "recommended": recommended,
            "confidence": conf
        })

        return {
            'patient_id': patient_id,
            'waspas_scores': dict(zip(ids, np.round(waspas_scores, 4))),
            'gnn_scores': dict(zip(ids, np.round(gnn_scores, 4))),
            'sim_scores': dict(zip(ids, np.round(sim_scores, 4))),
            'fused_ranking': fused,
            'recommended': recommended,
            'confidence': conf
        }

    # ------------------ export ------------------
    def export_results_csv(self, prefix="triage_results"):
        """Write the patient list to ``<prefix>_<timestamp>.csv`` in the CWD; return the name or None."""
        if not self.patients:
            console.print("[red]No patients to export.[/red]")
            return None
        df = pd.DataFrame(self.patients)
        fname = f"{prefix}_{now_str()}.csv"
        df.to_csv(fname, index=False)
        console.print(f"[green]✅ Exported to {fname}[/green]")
        return fname

    def export_event_log(self, prefix="event_log"):
        """Write the event log to ``<prefix>_<timestamp>.json`` in the CWD; return the name or None."""
        if not self.event_log:
            console.print("[yellow]No events logged yet.[/yellow]")
            return None
        fname = f"{prefix}_{now_str()}.json"
        with open(fname, "w", encoding="utf-8") as f:
            json.dump(self.event_log, f, indent=2)
        console.print(f"[green]✅ Event log saved to {fname}[/green]")
        return fname

    # ------------------ Visualizations ------------------
    def visualize_dashboard(self, show_plotly=True, save_prefix=None):
        """Build four Plotly figures: scores bar, PCA scatter, specialist network, radar.

        Re-runs triage first if any patient lacks an ensemble score. Optionally shows
        the figures and saves them as PNG (HTML fallback) with *save_prefix*.
        Returns (fig1, fig2, fig3, fig4), or None when there are no patients.
        """
        if not self.patients:
            console.print("[red]No patients to visualize.[/red]")
            return None
        # Checking every patient (not just the first) catches patients added after
        # the last triage run, which would otherwise be plotted with NaN scores.
        if any('ensemble_score' not in p for p in self.patients):
            self.run_triage(use_pca=True)
        df = pd.DataFrame(self.patients)

        # 1) Ensemble score bar
        fig1 = px.bar(df, x='id', y='ensemble_score', color='priority',
                      color_discrete_map={'Critical':'red','High':'orange','Medium':'yellow','Low':'green'},
                      title='Ensemble Emergency Scores per Patient', labels={'ensemble_score':'Ensemble Score','id':'Patient ID'})

        # 2) PCA 2D scatter of decision features (S,U,R,W,A)
        D = self.build_decision_matrix()
        if D.shape[0] >= 2:
            proj = PCA(n_components=2).fit_transform(StandardScaler().fit_transform(D))
            df['pca1'], df['pca2'] = proj[:, 0], proj[:, 1]
        else:
            # two components need at least two samples; plot a lone patient at the origin
            df['pca1'], df['pca2'] = 0.0, 0.0
        fig2 = px.scatter(df, x='pca1', y='pca2', color='priority', hover_data=['id','condition_desc','ensemble_score'],
                          title='PCA projection of triage features', labels={'pca1':'PC1','pca2':'PC2'})

        # 3) Specialist network graph (availability-weighted ADJ; indexed by roster position)
        G = nx.Graph()
        for s in self.specialists:
            G.add_node(s['id'], label=s['label'], availability=s['availability'], workload=s['workload'])
        for i, s1 in enumerate(self.specialists):
            for j, s2 in enumerate(self.specialists):
                if j <= i:
                    continue
                edge_weight = ADJ[i, j] * ((s1['availability'] + s2['availability']) / 200.0)
                if edge_weight > 0.05:
                    G.add_edge(s1['id'], s2['id'], weight=float(edge_weight))
        pos = nx.spring_layout(G, seed=42)
        edge_x, edge_y = [], []
        for a, b in G.edges():
            x0, y0 = pos[a]
            x1, y1 = pos[b]
            edge_x += [x0, x1, None]
            edge_y += [y0, y1, None]
        node_x = [pos[n][0] for n in G.nodes()]
        node_y = [pos[n][1] for n in G.nodes()]
        node_text = [f"{n}: {G.nodes[n]['label']} (A={G.nodes[n]['availability']},W={G.nodes[n]['workload']})" for n in G.nodes()]
        edge_trace = go.Scatter(x=edge_x, y=edge_y, mode='lines', line=dict(width=1,color='gray'), hoverinfo='none')
        node_trace = go.Scatter(x=node_x, y=node_y, mode='markers+text', text=[G.nodes[n]['label'] for n in G.nodes()],
                                hovertext=node_text, textposition="bottom center", marker=dict(size=[10+G.nodes[n]['availability']/10 for n in G.nodes()]))
        fig3 = go.Figure(data=[edge_trace, node_trace], layout=go.Layout(title='Specialist Collaboration Network', showlegend=False))

        # 4) Radar chart of the recommended specialist's capabilities for the top patient
        top_patient = df.iloc[0]
        r = self.route_patient(top_patient['id'])
        rec_id = r['recommended']['id'] if r and r.get('recommended') else None
        sp = next((s for s in self.specialists if s['id'] == rec_id), None)
        categories = ['expertise','availability','success_rate','resource_access','workload']
        fig4 = go.Figure()
        if sp is not None:
            values = [sp[c] for c in categories]
            # for the radar scale, invert workload (lower is better)
            values_scaled = [values[0], values[1], values[2], values[3], (100 - values[4])]
            fig4.add_trace(go.Scatterpolar(r=values_scaled, theta=categories, fill='toself', name=sp['label']))
        fig4.update_layout(title=f"Radar: Capabilities of recommended specialist for {top_patient['id']}", polar=dict(radialaxis=dict(visible=True)))

        figures = (fig1, fig2, fig3, fig4)
        if show_plotly:
            for fig in figures:
                fig.show()

        if save_prefix:
            for i, fig in enumerate(figures, start=1):
                try:
                    fname = f"{save_prefix}_fig{i}_{now_str()}.png"
                    fig.write_image(fname, scale=2)
                    console.print(f"[green]Saved:[/green] {fname}")
                except Exception:
                    # static export needs an extra renderer (e.g. kaleido); fall back to HTML
                    htmlname = f"{save_prefix}_fig{i}_{now_str()}.html"
                    fig.write_html(htmlname)
                    console.print(f"[green]Saved interactive HTML:[/green] {htmlname}")

        return figures

    # ------------------ Simulation to stress test and produce metrics ------------------
    def simulate_emergency(self, n=50, seed=42, noise=0.15, verbose=False):
        """Replace all patients with *n* noisy copies of SAMPLE_PATIENTS, triage, and test routing.

        Scores get Gaussian noise (sd 100 * noise) and are clipped to 0-100; ages get
        +/-10 years. A private ``random.Random(seed)`` makes runs reproducible without
        reseeding the global ``random`` module. Up to the first 30 patients (by triage
        rank) are routed. ``verbose`` is currently unused.

        Returns:
            {'routes_tested': int, 'avg_conf': float}
        """
        rng = random.Random(seed)
        self.patients = []
        for i in range(n):
            base = rng.choice(SAMPLE_PATIENTS)
            age = max(1, int(base[1] + rng.randint(-10, 10)))
            sev = min(100, max(0, int(base[3] + rng.gauss(0, 10) * noise * 10)))
            urg = min(100, max(0, int(base[4] + rng.gauss(0, 10) * noise * 10)))
            res = min(100, max(0, int(base[5] + rng.gauss(0, 10) * noise * 10)))
            wait = min(100, max(0, int(base[6] + rng.gauss(0, 10) * noise * 10)))
            pain = max(0, min(10, int(base[7] + rng.gauss(0, 2))))
            temp = 37.0 + rng.random() * 2
            self.patients.append(self._build_patient(f"SIM_{i+1:03d}", age, temp, sev, urg, res, wait, pain, base[8]))

        console.print(f"[cyan]Simulating {n} patients and running triage & routing...[/cyan]")
        self.run_triage(use_pca=True)
        stats = {'routes_tested': 0, 'avg_conf': 0.0}
        confs = []
        for p in track(self.patients[:min(30, len(self.patients))], description="Routing samples..."):
            r = self.route_patient(p['id'])
            if r:
                confs.append(r['confidence'])
                stats['routes_tested'] += 1
        stats['avg_conf'] = float(np.mean(confs)) if confs else 0.0
        console.print(f"[green]Simulation complete.[/green] Routes tested: {stats['routes_tested']}, Avg. confidence: {stats['avg_conf']:.3f}")
        self.event_log.append({"event":"simulation","timestamp":datetime.now().isoformat(),"n":n,"avg_conf":stats['avg_conf']})
        return stats

# ------------------ CLI Menu (improved) ------------------
def _print_patient_table(patients, title, columns):
    """Render *patients* as a rich table under a yellow *title*.

    ``columns`` is a sequence of ``(key, spec)`` pairs: with ``spec=None`` the cell is
    ``str(value)`` ('N/A' when the key is missing); otherwise it is
    ``format(value, spec)`` (0 when missing).
    """
    df = pd.DataFrame(patients)
    console.print(f"\n[bold yellow]{title}[/bold yellow]")
    table = Table(show_header=True, header_style="bold magenta")
    for key, _ in columns:
        table.add_column(key, justify="center")
    for _, row in df.iterrows():
        table.add_row(*[
            str(row.get(key, 'N/A')) if spec is None else format(row.get(key, 0), spec)
            for key, spec in columns
        ])
    console.print(table)


def main_menu():
    """Run the interactive console menu until the user picks Exit (11).

    Exit also exports the event log (when anything was logged).
    """
    summary_columns = [('id', None), ('age', None), ('condition_desc', None), ('severity', None),
                       ('urgency', None), ('ensemble_score', '.3f'), ('priority', None),
                       ('fuzzy_label', None)]
    full_columns = [('id', None), ('age', None), ('condition_desc', None), ('severity', None),
                    ('urgency', None), ('resource_need', None), ('waiting_impact', None),
                    ('pain_level', None), ('topsis_score', '.3f'), ('promethee_score', '.4f'),
                    ('fuzzy_score', '.3f'), ('ensemble_score', '.3f'), ('priority', None)]
    menu_lines = (
        "1. Load 10 fixed sample patients (demo)",
        "2. Add one patient manually",
        "3. Run triage (choose algorithm / ensemble)",
        "4. Show triage report (console)",
        "5. Route a patient to specialists (WASPAS + adaptive GNN + Similarity fusion)",
        "6. Visualize dashboard (Plotly)",
        "7. Export CSV of triage patient list",
        "8. Export event log (JSON)",
        "9. Simulation mode (stress test)",
        "10. Clear all patients",
        "11. Exit",
    )

    system = CombinedMedicalSystem()
    header_panel = Panel.fit("🚑 [bold red]AI in Medical Emergency Decision Making — v2.0[/bold red]", style="cyan")
    console.print(header_panel)
    console.print("[grey]Tip: Use option 1 to load demo data, then explore triage (3) and routing (5).[/grey]")
    while True:
        console.print("\n[bold]Menu:[/bold]")
        for line in menu_lines:
            console.print(line)
        # Deliberately no default: option 1 replaces every current patient with the
        # demo data, so an accidental Enter must not select it.
        choice = Prompt.ask("Choose (1-11)").strip()
        if choice == '1':
            system.load_sample_patients(auto_run_triage=True)
            _print_patient_table(system.patients, "==== TRIAGE REPORT ====", summary_columns)
        elif choice == '2':
            system.add_patient_interactive()
        elif choice == '3':
            alg = Prompt.ask("Choose triage algorithm [ensemble/topsis/promethee/fuzzy]", default="ensemble")
            weights = system.run_triage(algorithm=alg, use_pca=True)
            # run_triage returns None when there are no patients; skip the explanation then
            if weights is not None:
                console.print("[cyan]Feature weights used (S,U,R,W,A):[/cyan]", np.round(weights, 3).tolist())
                if len(system.patients) >= 2:  # PCA needs at least two patients
                    _, explained = system.explain_pca(system.build_decision_matrix())
                    console.print("[cyan]PCA explained variance (first components):[/cyan]", np.round(explained[:3], 3).tolist())
        elif choice == '4':
            if not system.patients:
                console.print("[red]No patients to show. Load or add patients first.[/red]")
                continue
            _print_patient_table(system.patients, "==== TRIAGE REPORT (Full) ====", full_columns)
        elif choice == '5':
            if not system.patients:
                console.print("[red]No patients available. Load or add first.[/red]")
                continue
            pid = Prompt.ask("Enter patient ID to route (e.g., P001)")
            res = system.route_patient(pid, use_gnn=True)
            if not res:
                console.print("[red]Routing failed or patient not found.[/red]")
                continue
            console.print(f"\n[bold yellow]Routing result for {pid} (top 3):[/bold yellow]")
            for r in res['fused_ranking'][:3]:
                console.print(f" - {r['label']} ({r['id']}) | fused score: {r['score']:.3f}")
            console.print(f"\n[cyan]Recommendation confidence:[/cyan] {res['confidence']:.3f}")
        elif choice == '6':
            if not system.patients:
                console.print("[red]No patients to visualize. Load or add first.[/red]")
                continue
            save_choice = Confirm.ask("Save visuals to disk?", default=False)
            prefix = Prompt.ask("Filename prefix", default="triage_viz") if save_choice else None
            system.visualize_dashboard(show_plotly=True, save_prefix=prefix)
        elif choice == '7':
            if not system.patients:
                console.print("[red]No patients to export.[/red]")
                continue
            system.export_results_csv()
        elif choice == '8':
            system.export_event_log()
        elif choice == '9':
            n = IntPrompt.ask("Number of simulated patients", default=50)
            stats = system.simulate_emergency(n=n, noise=0.12)
            console.print(f"[green]Simulation stats:[/green] {stats}")
        elif choice == '10':
            if Confirm.ask("Clear all patients?", default=False):
                system.patients.clear()
                system.last_weights = None  # weights described the patients just removed
                console.print("[green]Cleared all patient data.[/green]")
        elif choice == '11':
            console.print("[bold green]Exiting. Stay safe![/bold green]")
            system.export_event_log(prefix="event_log_before_exit")
            break
        else:
            console.print("[red]Invalid choice — please select a menu number.[/red]")

# Start the interactive menu (main_menu loops until the user chooses option 11)
# when this code runs as the main program.
if __name__ == "__main__":
    main_menu()


  if not pkgutil.find_loader(r):


Installing required packages...


6
