In [59]:
import os
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data, Batch
from torch_geometric.nn import GATv2Conv, global_mean_pool
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from rdkit import Chem
from rdkit.Chem import AllChem


In [60]:

def zscore_sample(sample: np.ndarray):
    mu = sample.mean(axis=0, keepdims=True)
    sigma = sample.std(axis=0, keepdims=True)
    sigma[sigma < 1e-6] = 1.0
    return (sample - mu) / sigma

def eeg_row_to_graph(channel_features, connectivity_threshold=0.6, knn_fallback=4):
    if channel_features is None or channel_features.shape[0] == 0:
        return None
    X = np.array(channel_features, dtype=np.float32)
    Xz = zscore_sample(X)
    num_nodes = Xz.shape[0]
    x = torch.tensor(Xz, dtype=torch.float)

    corr_matrix = pd.DataFrame(Xz).T.corr(method='pearson').fillna(0.0).values
    edges, weights = [], []
    for i in range(num_nodes):
        for j in range(i + 1, num_nodes):
            w = float(abs(corr_matrix[i, j]))
            if w > connectivity_threshold:
                edges.append([i, j]); edges.append([j, i])
                weights.append(w); weights.append(w)

    if len(edges) == 0:
        Xn = Xz / (np.linalg.norm(Xz, axis=1, keepdims=True) + 1e-8)
        S = np.matmul(Xn, Xn.T)
        np.fill_diagonal(S, -1.0)
        k = min(knn_fallback, max(1, num_nodes - 1))
        for i in range(num_nodes):
            nbrs = np.argpartition(-S[i], k)[:k]
            for j in nbrs:
                if j >= 0:
                    w = max(0.0, float(S[i, j]))
                    edges.append([i, j]); weights.append(w)

    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous() if edges else torch.empty((2, 0), dtype=torch.long)
    edge_attr  = torch.tensor(weights, dtype=torch.float).unsqueeze(-1) if weights else torch.empty((0, 1), dtype=torch.float)
    return Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

_ATOM_ORDER = ["C","N","O","S","F","Cl","Br","I","P","Si"]
def atom_feature_11(atom: Chem.Atom):
    sym = atom.GetSymbol()
    vec = [0.0]*11
    if sym in _ATOM_ORDER:
        vec[_ATOM_ORDER.index(sym)] = 1.0
    else:
        vec[-1] = 1.0  
    return vec

def mol_to_graph(smiles: str):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    AllChem.Compute2DCoords(mol)

    num_atoms = mol.GetNumAtoms()
    if num_atoms == 0:
        return None

    X = []
    for a in mol.GetAtoms():
        X.append(atom_feature_11(a))
    x = torch.tensor(X, dtype=torch.float)

    edges = []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        edges.append([i, j]); edges.append([j, i])
    if len(edges) == 0:
        edge_index = torch.empty((2,0), dtype=torch.long)
    else:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()

    return Data(x=x, edge_index=edge_index)


In [61]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool

class CrossAttention(nn.Module):
    def __init__(self, hidden_dim):
        super().__init__()
        self.query = nn.Linear(hidden_dim, hidden_dim)
        self.key   = nn.Linear(hidden_dim, hidden_dim)
        self.value = nn.Linear(hidden_dim, hidden_dim)
        self.scale = hidden_dim ** 0.5

    def forward(self, x_query, x_key_value):
        Q = self.query(x_query)
        K = self.key(x_key_value)
        V = self.value(x_key_value)
        attn_scores  = torch.matmul(Q, K.T) / self.scale
        attn_weights = torch.softmax(attn_scores, dim=-1)
        attended     = torch.matmul(attn_weights, V)
        return attended

class CrossAttentionGNN(nn.Module):
    """
    دقیقا مطابق Phase2.ipynb:
      - دو لایه GCN با اسامی gcn1/gcn2
      - ماژول cross_attention
      - classifier انتهایی (۴-کلاسه)
    بعلاوه: یک متد کمکی encode_pair برای گرفتن امبدینگ زوج قبل از کلاس‌فایر
    """
    def __init__(self, in_channels, hidden_channels):
        super().__init__()
        self.gcn1 = GCNConv(in_channels, hidden_channels)
        self.gcn2 = GCNConv(hidden_channels, hidden_channels)
        self.cross_attention = CrossAttention(hidden_channels)
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_channels, hidden_channels),
            nn.ReLU(),
            nn.Linear(hidden_channels, 4)
        )

    def encode_single(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = F.relu(self.gcn1(x, edge_index))
        x = F.relu(self.gcn2(x, edge_index))
        return x, batch  

    def encode_pair(self, data_a, data_b):
        x1, batch1 = self.encode_single(data_a)
        x2, batch2 = self.encode_single(data_b)
        attn1 = self.cross_attention(x1, x2)
        attn2 = self.cross_attention(x2, x1)
        pooled1 = global_mean_pool(attn1, batch1)
        pooled2 = global_mean_pool(attn2, batch2)
        combined = torch.cat([pooled1, pooled2], dim=1)  
        return combined

    def forward(self, data_a, data_b):
        combined = self.encode_pair(data_a, data_b)
        out = self.classifier(combined)
        return out


In [62]:
from torch_geometric.data import Batch
import pandas as pd
import os

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PHASE2_CKPT = "../Phase 2/trained_cross_attention_gnn.pth"

def infer_dims_from_state(state):
    """حدس hidden_dim و in_channels از روی state_dict چک‌پوینت."""
    hidden = None
   
    for k in ["cross_attention.query.weight", "cross_attention.value.weight", "cross_attention.key.weight"]:
        if k in state:
            w = state[k]
            hidden = w.shape[0]
            break

    in_ch = None
    for k in ["gcn1.lin_l.weight", "gcn1.lin.weight"]:
        if k in state:
            w = state[k] 
            in_ch = w.shape[1]
            break

    return hidden, in_ch

def load_phase2_model_smart(model_path, device):
    if not os.path.isfile(model_path):
        raise FileNotFoundError(f"Phase2 checkpoint not found: {model_path}")
    print(f"Loading trained Phase 2 model from: {model_path}")
    state = torch.load(model_path, map_location=device)

    hidden_dim, in_channels = infer_dims_from_state(state)
    if hidden_dim is None:
        raise RuntimeError("Could not infer hidden_dim from checkpoint (cross_attention.* not found).")
    if in_channels is None:
        in_channels = 11
        print(f"[WARN] Could not infer in_channels from checkpoint. Falling back to {in_channels}.")

    print(f"[Phase2] inferred hidden_dim={hidden_dim}, in_channels={in_channels}")

    model = CrossAttentionGNN(in_channels=in_channels, hidden_channels=hidden_dim)

    missing, unexpected = model.load_state_dict(state, strict=False)
    if missing:
        print("[Phase2] missing keys:", missing)
    if unexpected:
        print("[Phase2] unexpected keys:", unexpected)

    model.to(device).eval()
    projector = nn.Linear(2 * hidden_dim, 128).to(device)
    return model, projector, hidden_dim, in_channels

phase2_model, proj_pair_to_128, HIDDEN_PHASE2, IN_CH_PHASE2 = load_phase2_model_smart(PHASE2_CKPT, device)
pairs_df = pd.read_csv("../Phase 2/pairs.csv")
label_map = {"ADHD": 0, "Schizophrenia": 1, "Epilepsy": 2, "Supplement": 3}
num_classes = len(label_map)

@torch.no_grad()
def pair_embedding(smiles_a, smiles_b):
    g1 = mol_to_graph(smiles_a)
    g2 = mol_to_graph(smiles_b)
    if g1 is None or g2 is None:
        return None
    data_a = Batch.from_data_list([g1]).to(device)
    data_b = Batch.from_data_list([g2]).to(device)
    combined = phase2_model.encode_pair(data_a, data_b) 
    z = proj_pair_to_128(combined)                       
    z = F.normalize(z, p=2, dim=-1)
    return z.squeeze(0) 

drug_embeddings = torch.zeros((num_classes, 128), device=device)

for name, idx in label_map.items():
    rows = pairs_df[pairs_df['label'] == name]
    if rows.empty:
        print(f"[WARN] No row for {name} in pairs.csv")
        continue
    row = rows.iloc[0]
    z = pair_embedding(row['mol_1'], row['mol_2'])
    if z is None:
        print(f"[WARN] RDKit failed for {name}")
        continue
    drug_embeddings[idx] = z

print("Drug embeddings shape:", tuple(drug_embeddings.shape))
print(f"[Phase2] Using hidden_dim={HIDDEN_PHASE2}, in_channels={IN_CH_PHASE2}")
print(f"[Phase2] Using hidden_dim={HIDDEN_PHASE2}, in_channels={IN_CH_PHASE2}")
assert IN_CH_PHASE2 == 11, f"Phase2 expects {IN_CH_PHASE2} features per atom; adapt atom_feature_11 accordingly!"



Loading trained Phase 2 model from: ../Phase 2/trained_cross_attention_gnn.pth
[Phase2] inferred hidden_dim=100, in_channels=11
Drug embeddings shape: (4, 128)
[Phase2] Using hidden_dim=100, in_channels=11
[Phase2] Using hidden_dim=100, in_channels=11


In [63]:
import os, json, torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def build_drug_bank_quiet(pairs_df, device,
                          invalid_path="outputs_phase3/invalid_smiles.txt"):
    """
    Returns:
      Z: [M, 128] L2-normalized embeddings for all candidate drug pairs
      meta: list of dicts with keys: idx, label, mol1, mol2
      label_to_indices: dict label -> list of indices in Z
    Writes:
      invalid_path: one invalid pair per line (idx \t mol1 \t mol2)
    """
    os.makedirs("outputs_phase3", exist_ok=True)

    Z_list, meta = [], []
    label_to_indices = {}
    invalid = []

    for idx, row in pairs_df.iterrows():
        smi1, smi2 = row["mol_1"], row["mol_2"]
        lab = str(row["label"])

        z = None
        try:
            z = pair_embedding(smi1, smi2) 
        except Exception:
            z = None

        if (z is None) or (torch.isnan(z).any()):
            invalid.append((int(idx), smi1, smi2))
            continue

        z = F.normalize(z.to(device, dtype=torch.float), p=2, dim=-1)
        Z_list.append(z.unsqueeze(0))
        meta.append({"idx": int(idx), "label": lab, "mol1": smi1, "mol2": smi2})
        label_to_indices.setdefault(lab, []).append(len(Z_list) - 1)

    if not Z_list:
        raise RuntimeError("No valid drug embeddings built from pairs.csv")

    Z = torch.cat(Z_list, dim=0).to(device)  

    with open(invalid_path, "w", encoding="utf-8") as f:
        for i, a, b in invalid:
            f.write(f"{i}\t{a}\t{b}\n")

    print(f"[DrugBank] Built {Z.size(0)} embeddings; skipped {len(invalid)} invalid pairs "
          f"out of {len(pairs_df)} total.")

    return Z, meta, label_to_indices

z_path   = "outputs_phase3/drug_bank_Z.pt"
meta_path= "outputs_phase3/drug_bank_meta.json"
l2i_path = "outputs_phase3/drug_bank_label_to_indices.json"

if os.path.exists(z_path) and os.path.exists(meta_path) and os.path.exists(l2i_path):
    Z = torch.load(z_path, map_location=device)
    with open(meta_path, "r", encoding="utf-8") as f:
        meta = json.load(f)
    with open(l2i_path, "r", encoding="utf-8") as f:
        label_to_indices = json.load(f)
    print(f"[DrugBank] Loaded cached bank: Z{tuple(Z.shape)} | classes={list(label_to_indices.keys())}")
else:
    Z, meta, label_to_indices = build_drug_bank_quiet(pairs_df, device)
    torch.save(Z, z_path)
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
    with open(l2i_path, "w", encoding="utf-8") as f:
        json.dump(label_to_indices, f, ensure_ascii=False, indent=2)


[DrugBank] Loaded cached bank: Z(13818, 128) | classes=['Schizophrenia', 'ADHD', 'Epilepsy', 'Supplement']


In [64]:
CANONICAL_CHANNELS = [
    "FP1","FP2","AF3","AF4","F7","F3","FZ","F4","F8",
    "T3","T4","T5","T6","T7","T8","FT9","FT10",
    "FC5","FC6","C3","CZ","C4",
    "P7","P3","PZ","P4","P8",
    "O1","O2"
]

NODE_FEATURE_SUFFIXES = [
    "mean","std","entropy","hjorth_mobility",
    "delta_power","theta_power","alpha_power","beta_power","gamma_power"
]

LEGACY_TO_MODERN = {"T3":"T7","T4":"T8","T5":"P7","T6":"P8"}  

def _detect_channels(df: pd.DataFrame, dataset_tag: str):
    chs = []
    for ch in CANONICAL_CHANNELS:
        if dataset_tag == "sch":
            cand = [ch] + [k for k,v in LEGACY_TO_MODERN.items() if v == ch]
            ok = any(f"{c}_mean" in df.columns for c in cand)
        else:
            ok = f"{ch}_mean" in df.columns
        if ok:
            chs.append(ch)
    return chs

def _row_to_matrix(row: pd.Series, channels, dataset_tag: str):
    feats = []
    for ch in channels:
        vals = []
        for suf in NODE_FEATURE_SUFFIXES:
            col = f"{ch}_{suf}"
            val = row[col] if col in row else None
            if val is None and dataset_tag == "sch":
                legacy = next((k for k,v in LEGACY_TO_MODERN.items() if v == ch), None)
                if legacy is not None:
                    col2 = f"{legacy}_{suf}"
                    if col2 in row: val = row[col2]
            if val is None or pd.isna(val) or np.isinf(val):
                val = 0.0
            vals.append(float(val))
        feats.append(vals)
    X = np.array(feats, dtype=np.float32)  
    return X

def read_eeg_csv_to_graphs(csv_path: str, dataset_tag: str, class_id: int):
    df = pd.read_csv(csv_path)
    id_col = "subject_id" if "subject_id" in df.columns else ("file" if "file" in df.columns else None)
    channels = _detect_channels(df, dataset_tag)
    graphs, removed_healthy, failed = [], 0, 0

    for _, row in df.iterrows():
        if int(row["label"]) == 0:
            removed_healthy += 1
            continue
        X = _row_to_matrix(row, channels, dataset_tag)
        g = eeg_row_to_graph(X, connectivity_threshold=0.6, knn_fallback=6)
        if g is None: 
            failed += 1; continue
        g.y = torch.tensor([class_id], dtype=torch.long)
        uid = str(row[id_col]) if id_col and id_col in row else f"{dataset_tag}_unknown"
        g.uid = uid
        graphs.append(g)

    print(f"[EEG] {dataset_tag}: built {len(graphs)} graphs | skipped healthy: {removed_healthy} | failed: {failed} | channels: {len(channels)}")
    return graphs



ADHD_CSV = "adhd.csv"
SCH_CSV  = "sch.csv"
EPI_CSV  = "epi.csv"

PH3_LABELS = {"adhd":0, "sch":1, "epi":2}

all_graphs = []
if os.path.isfile(ADHD_CSV):
    all_graphs += read_eeg_csv_to_graphs(ADHD_CSV, "adhd", PH3_LABELS["adhd"])
if os.path.isfile(SCH_CSV):
    all_graphs += read_eeg_csv_to_graphs(SCH_CSV, "sch",  PH3_LABELS["sch"])
if os.path.isfile(EPI_CSV):
    all_graphs += read_eeg_csv_to_graphs(EPI_CSV, "epi",  PH3_LABELS["epi"])

print(f"[Phase3] Total EEG graphs (disease-only): {len(all_graphs)}")


[EEG] adhd: built 20400 graphs | skipped healthy: 20400 | failed: 0 | channels: 14
[EEG] sch: built 2700 graphs | skipped healthy: 2340 | failed: 0 | channels: 18
[EEG] epi: built 3423 graphs | skipped healthy: 3423 | failed: 0 | channels: 21
[Phase3] Total EEG graphs (disease-only): 26523


In [65]:

import numpy as np
from sklearn.model_selection import StratifiedGroupKFold, train_test_split
from torch_geometric.loader import DataLoader

def has_all(idxs, y, ncls=3):
    bc = np.bincount(y[idxs], minlength=ncls)
    return (bc > 0).all(), bc

groups = np.array([g.uid for g in all_graphs])
y = np.array([int(g.y.item()) for g in all_graphs])
idx = np.arange(len(all_graphs))

sgkf = StratifiedGroupKFold(n_splits=5, shuffle=True, random_state=42)
picked = None
for tr_idx, te_idx in sgkf.split(idx, y=y, groups=groups):
    
    va_idx, ts_idx = train_test_split(te_idx, test_size=0.5, random_state=42, stratify=y[te_idx])

    ok_tr, bc_tr = has_all(tr_idx, y)
    ok_va, bc_va = has_all(va_idx, y)
    ok_ts, bc_ts = has_all(ts_idx, y)

    if ok_tr and ok_va and ok_ts:
        picked = (tr_idx, va_idx, ts_idx)
        break


if picked is None:
    for tr_idx, te_idx in sgkf.split(idx, y=y, groups=groups):
        for shift in [0, len(te_idx)//5, len(te_idx)//3, len(te_idx)//2]:
            te2 = np.roll(te_idx, shift=shift)
            va_idx, ts_idx = train_test_split(te2, test_size=0.5, random_state=42, stratify=y[te2])
            ok_tr, bc_tr = has_all(tr_idx, y)
            ok_va, bc_va = has_all(va_idx, y)
            ok_ts, bc_ts = has_all(ts_idx, y)
            if ok_tr and ok_va and ok_ts:
                picked = (tr_idx, va_idx, ts_idx)
                break
        if picked is not None:
            break

if picked is None:
    raise RuntimeError("Could not find a split with all classes in train/val/test. Check group/labels.")

train_idx, val_idx, test_idx = picked

def subset(idxs): return [all_graphs[i] for i in idxs]
train_graphs = subset(train_idx)
val_graphs   = subset(val_idx)
test_graphs  = subset(test_idx)

print("[Split] sizes:", len(train_graphs), len(val_graphs), len(test_graphs))
print("[Dist] train:", np.bincount([int(g.y.item()) for g in train_graphs], minlength=3))
print("[Dist] val  :", np.bincount([int(g.y.item()) for g in val_graphs],   minlength=3))
print("[Dist] test :", np.bincount([int(g.y.item()) for g in test_graphs],  minlength=3))

train_loader = DataLoader(train_graphs, batch_size=64, shuffle=True)
val_loader   = DataLoader(val_graphs,   batch_size=64, shuffle=False)
test_loader  = DataLoader(test_graphs,  batch_size=64, shuffle=False)


[Split] sizes: 21081 2721 2721
[Dist] train: [16200  2160  2721]
[Dist] val  : [2100  270  351]
[Dist] test : [2100  270  351]


In [66]:
from torch_geometric.nn import GATv2Conv, global_mean_pool, global_max_pool
import torch.nn as nn
import torch.nn.functional as F
import torch

class EEGPhase3Model(nn.Module):
    def __init__(self, in_dim=9, hidden=128, heads=4, layers=2, dropout=0.3, out_classes=3):
        super().__init__()
        self.layers = nn.ModuleList()
        self.bns    = nn.ModuleList()
        d_in = in_dim
        for _ in range(layers):
            self.layers.append(GATv2Conv(d_in, hidden//heads, heads=heads, edge_dim=1, add_self_loops=False))
            self.bns.append(nn.BatchNorm1d(hidden))
            d_in = hidden
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden*2, hidden)
        self.cls  = nn.Linear(hidden, out_classes)

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch
        if edge_attr is not None and edge_attr.dim() == 1:
            edge_attr = edge_attr.unsqueeze(-1)
        x = torch.nan_to_num(x, nan=0.0, posinf=0.0, neginf=0.0)
        for conv, bn in zip(self.layers, self.bns):
            x = conv(x, edge_index, edge_attr=edge_attr)
            x = bn(x)
            x = F.relu(x)
            x = self.dropout(x)
        g = torch.cat([global_mean_pool(x, batch), global_max_pool(x, batch)], dim=-1)
        g = F.relu(self.proj(g))
        logits = self.cls(g)
        return logits, g
IN_DIM = train_graphs[0].x.size(1) 
print("IN_DIM =", IN_DIM)   
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model3 = EEGPhase3Model(in_dim=IN_DIM, hidden=128, heads=4, layers=2, dropout=0.3, out_classes=3).to(device)
sum(p.numel() for p in model3.parameters())


IN_DIM = 9


70147

In [67]:
from sklearn.metrics import accuracy_score, f1_score, classification_report

def class_weights(graphs, ncls=3):
    ys = [int(g.y.item()) for g in graphs]
    cnt = np.bincount(ys, minlength=ncls).astype(np.float32)
    w = cnt.max() / (cnt + 1e-6)
    return torch.tensor(w, dtype=torch.float)

cw = class_weights(train_graphs, ncls=3).to(device)
opt = torch.optim.AdamW(model3.parameters(), lr=2e-3, weight_decay=1e-4)

@torch.no_grad()
def eval_loader(loader):
    model3.eval()
    ys, ps = [], []
    for b in loader:
        b = b.to(device)
        logits, _ = model3(b)
        ps.extend(logits.argmax(-1).cpu().numpy().tolist())
        ys.extend(b.y.cpu().numpy().tolist())
    acc = accuracy_score(ys, ps)
    f1m = f1_score(ys, ps, average="macro", zero_division=0)
    rep = classification_report(ys, ps, labels=[0,1,2],
                                target_names=["ADHD","Schizophrenia","Epilepsy"], zero_division=0)
    return acc, f1m, rep

def train_one_epoch():
    model3.train()
    total = 0.0
    for b in train_loader:
        b = b.to(device)
        if b.edge_attr is not None and b.edge_attr.dim()==1:
            b.edge_attr = b.edge_attr.unsqueeze(-1)
        logits, _ = model3(b)
        loss = F.cross_entropy(logits, b.y, weight=cw)
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model3.parameters(), 2.0)
        opt.step()
        total += loss.item() * b.num_graphs
    return total / len(train_loader.dataset)

best_f1, best_state, patience, wo = -1.0, None, 8, 0
for epoch in range(1, 10):
    tr_loss = train_one_epoch()
    va_acc, va_f1, _ = eval_loader(val_loader)
    print(f"[Epoch {epoch:02d}] loss={tr_loss:.4f} | val_acc={va_acc:.4f} | val_f1={va_f1:.4f}")
    if va_f1 > best_f1:
        best_f1, best_state, wo = va_f1, {k:v.detach().cpu() for k,v in model3.state_dict().items()}, 0
    else:
        wo += 1
        if wo >= patience:
            print("Early stopping."); break

if best_state is not None:
    model3.load_state_dict(best_state)


[Epoch 01] loss=0.2279 | val_acc=0.9779 | val_f1=0.9560
[Epoch 02] loss=0.0961 | val_acc=0.9857 | val_f1=0.9701
[Epoch 03] loss=0.0669 | val_acc=0.9864 | val_f1=0.9716
[Epoch 04] loss=0.0613 | val_acc=0.9860 | val_f1=0.9697
[Epoch 05] loss=0.0489 | val_acc=0.9882 | val_f1=0.9750
[Epoch 06] loss=0.0447 | val_acc=0.9857 | val_f1=0.9693
[Epoch 07] loss=0.0355 | val_acc=0.9886 | val_f1=0.9753
[Epoch 08] loss=0.0425 | val_acc=0.9908 | val_f1=0.9806
[Epoch 09] loss=0.0350 | val_acc=0.9912 | val_f1=0.9811


In [68]:
te_acc, te_f1, te_rep = eval_loader(test_loader)
print("=== TEST REPORT ===")
print(te_rep)
print(f"Test Acc={te_acc:.4f} | Test Macro-F1={te_f1:.4f}")

os.makedirs("outputs_phase3", exist_ok=True)
torch.save(model3.state_dict(), "outputs_phase3/phase3_eeg_gatv2.pth")
print("Saved → outputs_phase3/phase3_eeg_gatv2.pth")


=== TEST REPORT ===
               precision    recall  f1-score   support

         ADHD       1.00      1.00      1.00      2100
Schizophrenia       0.98      0.97      0.97       270
     Epilepsy       0.97      0.98      0.97       351

     accuracy                           0.99      2721
    macro avg       0.98      0.98      0.98      2721
 weighted avg       0.99      0.99      0.99      2721

Test Acc=0.9919 | Test Macro-F1=0.9813
Saved → outputs_phase3/phase3_eeg_gatv2.pth


In [69]:

import os, json, torch
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


Z = torch.load("outputs_phase3/drug_bank_Z.pt", map_location=device)
with open("outputs_phase3/drug_bank_label_to_indices.json", "r", encoding="utf-8") as f:
    label_to_indices = json.load(f)


label_names = ["ADHD", "Schizophrenia", "Epilepsy"]                
drug_label_names = ["ADHD", "Schizophrenia", "Epilepsy", "Supplement"]  

TAU = 0.1   

model3.eval()

def ensure_edge_attr(batch):
    """Ensure edge_attr exists and is [E,1] float, since GATv2Conv(edge_dim=1) is used."""
    if not hasattr(batch, "edge_attr") or batch.edge_attr is None:
        E = batch.edge_index.size(1)
        batch.edge_attr = torch.ones((E, 1), device=batch.edge_index.device, dtype=torch.float)
    else:
        ea = batch.edge_attr
        if ea.dim() == 1:
            batch.edge_attr = ea.view(-1, 1).float()
        elif ea.dim() == 2 and ea.size(-1) != 1:
            batch.edge_attr = ea[:, :1].float()
        else:
            batch.edge_attr = ea.float()
    return batch

@torch.no_grad()
def run_model3(batch):
    """Call model3 no matter if it expects (data) or (x,edge_index,edge_attr,batch)."""
    try:
        out = model3(batch)
        if isinstance(out, tuple) and len(out) == 2:
            return out
        raise TypeError("Model returned unexpected output shape")
    except TypeError:
        b = ensure_edge_attr(batch)
        return model3(b.x, b.edge_index, b.edge_attr, b.batch)

@torch.no_grad()
def aggregate_class_probs(probs_row: torch.Tensor):
    """
    probs_row: [M] per-drug probabilities
    returns: dict class -> sum of probs over drugs of that class
    """
    out = {}
    for lab in drug_label_names:
        idxs = label_to_indices.get(lab, [])
        out[lab] = float(probs_row[idxs].sum().item()) if idxs else 0.0
    return out

@torch.no_grad()
def predict_full(batch):
    """
    For a batch of EEG graphs:
      - Get disease prediction (3-class)
      - Map EEG embedding to full drug bank probs (size M)
      - Aggregate to class-level probs (ADHD/Sch/Epi/Supplement)
    """
    logits, g = run_model3(batch)                            
    probs_disease = F.softmax(logits, dim=1)                   
    scores, probs_drugs = scores_for_all_drugs(g, Z, tau=TAU)  

    out_list = []
    for i in range(probs_drugs.size(0)):
        true_lbl = int(batch.y[i].item()) if hasattr(batch, "y") else None
        pred_lbl = int(torch.argmax(probs_disease[i]).item())
        out_list.append({
            "true_disease": label_names[true_lbl] if true_lbl is not None and true_lbl < len(label_names) else None,
            "pred_disease": label_names[pred_lbl],
            "disease_probs": {
                "ADHD": float(probs_disease[i,0]),
                "Schizophrenia": float(probs_disease[i,1]),
                "Epilepsy": float(probs_disease[i,2]),
            },
            
            "drug_class_probs": aggregate_class_probs(probs_drugs[i]),
        })
    return out_list


all_records = []
for b in test_loader:
    b = b.to(device)
    all_records.extend(predict_full(b))

os.makedirs("outputs_phase3", exist_ok=True)
out_path = "outputs_phase3/task2_fullbank_classonly.jsonl"
with open(out_path, "w", encoding="utf-8") as f:
    for rec in all_records:
        f.write(json.dumps(rec, ensure_ascii=False) + "\n")

print(f"Saved → {out_path}")
print("Example:", json.dumps(all_records[0], ensure_ascii=False, indent=2)[:1200], " ...")


Saved → outputs_phase3/task2_fullbank_classonly.jsonl
Example: {
  "true_disease": "ADHD",
  "pred_disease": "ADHD",
  "disease_probs": {
    "ADHD": 1.0,
    "Schizophrenia": 6.748279712809335e-09,
    "Epilepsy": 1.9434192033429554e-09
  },
  "drug_class_probs": {
    "ADHD": 0.07944796979427338,
    "Schizophrenia": 0.2082979828119278,
    "Epilepsy": 0.4988584518432617,
    "Supplement": 0.21339532732963562
  }
}  ...
