In [2]:
import os
print(os.getcwd())

/mnt/ComplexSystemsLab/Rawan/GNN/16


In [None]:
# --- Imports & Setup -----------------------
import warnings
warnings.filterwarnings("ignore")

import sys
sys.modules["apex"] = None
sys.modules["apex.normalization"] = None

import os
import re
import math
import random
from collections import defaultdict, Counter

import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils import clip_grad_norm_

from torch_geometric.data import Data
from torch_geometric.loader import DataLoader

from sentence_transformers import SentenceTransformer

from sklearn.preprocessing import LabelEncoder
from sklearn.model_selection import GroupShuffleSplit, GroupKFold, StratifiedGroupKFold
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report, f1_score, precision_recall_fscore_support, matthews_corrcoef

from tqdm import tqdm
import joblib

In [4]:
# To ensures if we train with the same code and data tomorrow, we get identical results:
def set_seed(seed: int = 42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

SEED = 42
set_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# --- Load Labeled Narratives and Scores -------------
df = pd.read_csv("narratives.csv")
scores_df = pd.read_csv("scores.csv")

# ------------------ Drop missing or invalid emotion rows ---
df = df[df['Emotion'].notnull()].copy()

# --- Lowercase emotion labels -------------
df['Emotion'] = df['Emotion'].astype(str).str.strip().str.lower()

In [None]:
# --- Encode Emotions -------------
emotion_encoder = LabelEncoder()
df["emotion_label"] = emotion_encoder.fit_transform(df["Emotion"])
joblib.dump(emotion_encoder, "narrative_emotion_encoder.pkl")

['narrative_emotion_encoder.pkl']

In [None]:
# --- Load E5 embedder --------
embedder = SentenceTransformer("intfloat/e5-large-v2") #note to self: maybe another embeddeR?

In [8]:
# Ensure Sentence column is string and handle NaN
df['Sentence'] = df['Sentence'].astype(str).fillna("")

In [9]:
# Prepare sentences for e5 (instruction-tuned) 
sentences = [f"passage: {s}" for s in df['Sentence'].tolist()]

# Generate embeddings
embeddings = embedder.encode(
    sentences,
    show_progress_bar=True,
    convert_to_numpy=True,
    batch_size=64
)

# Save embeddings as list in a column
df["embedding"] = [emb.tolist() for emb in embeddings]

print("Embedding shape:", embeddings.shape)  # (num_sentences, 1024)

Batches:   0%|          | 0/142 [00:00<?, ?it/s]

Embedding shape: (9068, 1024)


In [10]:
# --- Compute Negativity Ratio ---
NEGATIVE_EMOTIONS = {
    "disapproval", "sadness", "anger", "grief", "fear", "disgust", "remorse", "annoyance",
    "disappointment", "embarrassment", "nervousness", "confusion"
}

negativity_ratio = (
    df.groupby(["Patient", "SessionNumber"])["Emotion"]
    .apply(lambda g: g.isin(NEGATIVE_EMOTIONS).sum() / len(g) if len(g) > 0 else 0)
    .reset_index(name="NegativityRatio")
)

In [26]:
negativity_ratio[]

Unnamed: 0,Patient,SessionNumber,NegativityRatio
0,P0001,S0001,0.028571
1,P0001,S0002,0.172414
2,P0001,S0003,0.000000
3,P0001,S0004,0.200000
4,P0001,S0005,0.025000
...,...,...,...
234,P0014,S0011,0.744681
235,P0014,S0012,0.804878
236,P0014,S0013,0.738095
237,P0014,S0014,0.702128


In [28]:
# --- Filter for Patient 14 ---
patient_14_data = df[df["Patient"] == "P0014"]

# --- Compute Negativity Ratio for Patient 14 only ---
negativity_ratio_p14 = (
    patient_14_data.groupby(["Patient", "SessionNumber"])["Emotion"]
    .apply(lambda g: g.isin(NEGATIVE_EMOTIONS).sum() / len(g) if len(g) > 0 else 0)
    .reset_index(name="NegativityRatio")
)

negativity_ratio_p14


Unnamed: 0,Patient,SessionNumber,NegativityRatio
0,P0014,S0001,0.435897
1,P0014,S0002,0.317073
2,P0014,S0003,0.5625
3,P0014,S0004,0.847826
4,P0014,S0005,0.782609
5,P0014,S0006,0.8
6,P0014,S0007,0.789474
7,P0014,S0008,0.742857
8,P0014,S0009,0.608696
9,P0014,S0010,0.767442


In [None]:
# -------- Merge Session Scores + Label Session Status ---
session_status = scores_df.merge(negativity_ratio, on=["Patient", "SessionNumber"], how="left")

def classify_session(row):
    if row["NegativityRatio"] < 0.2 and row["GAD-7_Score"] <= 10 and row["PHQ-9_Score"] <= 10:
        return "Improving"
    elif row["NegativityRatio"] > 0.6 and (row["GAD-7_Score"] > 15 or row["PHQ-9_Score"] > 15):
        return "Deteriorating"
    else:
        return "Neutral"

session_status["SessionStatus"] = session_status.apply(classify_session, axis=1)

In [12]:
status_encoder = LabelEncoder()

session_status["SessionStatusLabel"] = status_encoder.fit_transform(session_status["SessionStatus"])
joblib.dump(status_encoder, "session_status_encoder.pkl")

['session_status_encoder.pkl']

In [None]:
# --- Merge Status Info into Main Data -----==
df = df.merge(
    session_status[["Patient", "SessionNumber", "SessionStatus", "SessionStatusLabel"]],
    on=["Patient", "SessionNumber"],
    how="left"
)

In [14]:
print("Unique emotions:", df["Emotion"].nunique())
print("Encoded classes:", len(emotion_encoder.classes_))
print("Example emotions:", list(emotion_encoder.classes_))

Unique emotions: 26
Encoded classes: 26
Example emotions: ['admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring', 'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval', 'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'joy', 'love', 'nervousness', 'neutral', 'optimism', 'pride', 'realization', 'remorse', 'sadness', 'surprise']


In [None]:
# --- Build Therapy Graphs (temporal edges only) --------
therapy_graphs = []
session_meta = []   # to support grouped/stratified splitting . . . 

for (patient, session), group in df.groupby(["Patient", "SessionNumber"]):
    if len(group) < 2:
        continue  # tp skip too-short sessions

    x = torch.tensor(np.stack(group["embedding"].values), dtype=torch.float)
    y_node = torch.tensor(group["emotion_label"].values, dtype=torch.long)
    y_graph = torch.tensor([group["SessionStatusLabel"].iloc[0]], dtype=torch.long)

    # temporal bidirectional chain (Prev/next nodes) .... 
    edge_index = torch.tensor(
        [[i, i + 1] for i in range(len(group) - 1)] + [[i + 1, i] for i in range(len(group) - 1)],
        dtype=torch.long
    ).T

    sid = f"{patient}_{session}"
    graph = Data(
        x=x,
        edge_index=edge_index,
        y=y_node,
        graph_label=y_graph,
        session_id=sid,
        patient_id=str(patient)  # keep as string for grouping 
    )
    therapy_graphs.append(graph)
    session_meta.append({
        "session_id": sid,
        "patient": str(patient),
        "status_label": int(y_graph.item())
    })

print(f"Graphs created: {len(therapy_graphs)}")

Graphs created: 239


In [None]:
# ------------- Session-level dataframe for splits ---
meta_df = pd.DataFrame(session_meta)
print("SessionStatus label counts (all):")
print(meta_df["status_label"].value_counts().sort_index())

SessionStatus label counts (all):
status_label
0     14
1      9
2    216
Name: count, dtype: int64


In [None]:
sid_to_idx = {g.session_id: i for i, g in enumerate(therapy_graphs)}

target_test_frac = 0.15
n_splits_candidates = list(range(3, 11))
best = dict(gap=1e9, fold=None, n_splits=None, frac=None)

for n_splits in n_splits_candidates:
    try:
        sgkf = StratifiedGroupKFold(n_splits=n_splits, shuffle=True, random_state=SEED)
        for train_idx_tmp, test_idx_tmp in sgkf.split(
            meta_df.index,
            y=meta_df["status_label"],
            groups=meta_df["patient"]
        ):
            frac = len(test_idx_tmp) / len(meta_df)
            gap = abs(frac - target_test_frac)
            if gap < best["gap"]:
                best.update(gap=gap, fold=(train_idx_tmp, test_idx_tmp), n_splits=n_splits, frac=frac)
    except ValueError:
        continue

if best["fold"] is None:
    raise RuntimeError(
        "Could not find a valid grouped+stratified split. "
        "Consider relaxing the target fraction or adding more data."
    )

train_pool_idx, test_idx = best["fold"]
print(
    f"Selected test fraction ≈ {best['frac']:.2f} "
    f"(target {target_test_frac}), via StratifiedGroupKFold(n_splits={best['n_splits']})"
)

Selected test fraction ≈ 0.16 (target 0.15), via StratifiedGroupKFold(n_splits=3)


In [None]:
test_sessions = meta_df.iloc[test_idx]["session_id"].tolist()
cv_sessions = meta_df.iloc[train_pool_idx]["session_id"].tolist()

test_graphs = [therapy_graphs[sid_to_idx[sid]] for sid in test_sessions]
cv_graphs = [therapy_graphs[sid_to_idx[sid]] for sid in cv_sessions]

def show_dist(tag, subset_sessions):
    sub = meta_df.set_index("session_id").loc[subset_sessions]
    counts = sub["status_label"].value_counts(normalize=True).sort_index()
    print(f"{tag} sessions: {len(subset_sessions)} | status dist:")
    print((counts * 100).round(1).astype(str) + "%")

show_dist("TEST", test_sessions)
show_dist("CV-POOL", cv_sessions)

cv_meta = meta_df.set_index("session_id").loc[cv_sessions].reset_index(drop=False)
cv_groups = cv_meta["patient"].values

unique_groups = np.unique(cv_groups)
n_groups = len(unique_groups)
n_splits_cv = min(10, n_groups)  # cap at 10, but not more than available patients

if n_splits_cv < 2:
    raise ValueError(
        f"Not enough distinct patients in CV pool for GroupKFold. "
        f"Found {n_groups} group(s). Consider reducing test size or adding data."
    )

print(f"Using GroupKFold with n_splits={n_splits_cv} (unique patients in CV pool: {n_groups})")

gkf = GroupKFold(n_splits=n_splits_cv)

cv_folds = []
for fold, (tr_idx, va_idx) in enumerate(gkf.split(cv_meta.index, groups=cv_groups), start=1):
    train_sids = cv_meta.iloc[tr_idx]["session_id"].tolist()
    val_sids   = cv_meta.iloc[va_idx]["session_id"].tolist()
    cv_folds.append((train_sids, val_sids))
    print(f"Fold {fold}: train={len(train_sids)} val={len(val_sids)}")

TEST sessions: 38 | status dist:
status_label
0     2.6%
1    18.4%
2    78.9%
Name: proportion, dtype: object
CV-POOL sessions: 201 | status dist:
status_label
0     6.5%
1     1.0%
2    92.5%
Name: proportion, dtype: object
Using GroupKFold with n_splits=10 (unique patients in CV pool: 11)
Fold 1: train=136 val=65
Fold 2: train=186 val=15
Fold 3: train=186 val=15
Fold 4: train=186 val=15
Fold 5: train=186 val=15
Fold 6: train=186 val=15
Fold 7: train=186 val=15
Fold 8: train=186 val=15
Fold 9: train=186 val=15
Fold 10: train=185 val=16


In [None]:
def add_positional_feature_to_graph(g: Data) -> Data:
    n = g.x.size(0)
    if n == 1:
        pos = torch.tensor([[0.0]], dtype=torch.float, device=g.x.device)
    else:
        pos_vals = torch.arange(n, dtype=torch.float, device=g.x.device) / max(1, (n - 1))
        pos = pos_vals.view(-1, 1)
    if g.x.size(1) == 1024:  # raw E5 dims
        g.x = torch.cat([g.x, pos], dim=1)  # -> 1025 dims
    return g

for g in therapy_graphs:
    add_positional_feature_to_graph(g)

input_dim = therapy_graphs[0].x.size(1)
print(f"Node feature dim after adding positional feature: {input_dim}")

Node feature dim after adding positional feature: 1025


In [None]:
# ---== 2-layer GATv2 + LayerNorm + dropout --------
from torch_geometric.nn import GATv2Conv, GlobalAttention

class GATv2MultiTask_Small(nn.Module):
    def __init__(self, input_dim, hidden_dim, num_node_classes, num_graph_classes, heads=2, dropout=0.35):
        super().__init__()
        self.dropout_p = dropout

        self.gat1 = GATv2Conv(input_dim, hidden_dim, heads=heads, dropout=dropout)
        self.ln1  = nn.LayerNorm(hidden_dim * heads)

        self.gat2 = GATv2Conv(hidden_dim * heads, hidden_dim, heads=1, dropout=dropout)
        self.ln2  = nn.LayerNorm(hidden_dim)

        self.global_attention = GlobalAttention(
            gate_nn=nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.ReLU(),
                nn.Linear(hidden_dim, 1)
            )
        )

        self.dropout = nn.Dropout(dropout)

        self.node_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_node_classes)
        )

        self.graph_classifier = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim // 2, num_graph_classes)
        )

    def forward(self, x, edge_index, batch):
        x = self.gat1(x, edge_index); x = self.ln1(x); x = F.elu(x)
        x = self.gat2(x, edge_index); x = self.ln2(x); x = F.elu(x)
        x = self.dropout(x)

        node_out = self.node_classifier(x)
        graph_emb = self.global_attention(x, batch)
        graph_out = self.graph_classifier(graph_emb)
        return node_out, graph_out


In [None]:
# -======-- Utilities: loaders, softened class weights, losses, evaluation ---
def make_loaders_from_session_ids(train_sids, val_sids, batch_size=8):
    train_graphs = [therapy_graphs[sid_to_idx[sid]] for sid in train_sids]
    val_graphs   = [therapy_graphs[sid_to_idx[sid]] for sid in val_sids]
    train_loader = DataLoader(train_graphs, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_graphs, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader

def compute_class_weights_for_fold(train_sids, num_node_classes, num_graph_classes, eps=1e-6,
                                   min_w=0.5, max_w=2.0):
    # Node-level counts
    node_counts = np.zeros(num_node_classes, dtype=np.float64)
    # Graph-level counts
    graph_counts = np.zeros(num_graph_classes, dtype=np.float64)

    for sid in train_sids:
        g = therapy_graphs[sid_to_idx[sid]]
        y_nodes = g.y.cpu().numpy()
        for c in y_nodes:
            node_counts[int(c)] += 1.0
        graph_counts[int(g.graph_label.item())] += 1.0

    inv_node = np.zeros_like(node_counts, dtype=np.float64)
    mask_nz = node_counts > 0
    inv_node[mask_nz] = 1.0 / node_counts[mask_nz]
    if (~mask_nz).any():
        fill_val = np.median(inv_node[mask_nz]) if mask_nz.any() else 1.0
        inv_node[~mask_nz] = fill_val

    inv_graph = np.zeros_like(graph_counts, dtype=np.float64)
    mask_gz = graph_counts > 0
    inv_graph[mask_gz] = 1.0 / graph_counts[mask_gz]
    if (~mask_gz).any():
        fill_val_g = np.median(inv_graph[mask_gz]) if mask_gz.any() else 1.0
        inv_graph[~mask_gz] = fill_val_g

    node_weights = inv_node / inv_node.sum() * num_node_classes
    graph_weights = inv_graph / inv_graph.sum() * num_graph_classes

    node_weights = np.clip(node_weights, min_w, max_w)
    graph_weights = np.clip(graph_weights, min_w, max_w)

    node_weights_t = torch.tensor(node_weights, dtype=torch.float, device=device)
    graph_weights_t = torch.tensor(graph_weights, dtype=torch.float, device=device)
    return node_weights_t, graph_weights_t


def maybe_augment_x(x, sigma=0.01, p=0.5):
    if sigma <= 0 or p <= 0:
        return x
    if torch.rand(1, device=x.device).item() < p:
        return x + torch.randn_like(x) * sigma
    return x

In [22]:
def evaluate_epoch(model, loader, loss_node_fn, loss_graph_fn, lambda_node: float, lambda_graph: float):
    model.eval()
    total_loss = 0.0
    all_node_preds, all_node_labels = [], []
    all_graph_preds, all_graph_labels = [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            node_out, graph_out = model(batch.x, batch.edge_index, batch.batch)

            loss_n = loss_node_fn(node_out, batch.y)
            loss_g = loss_graph_fn(graph_out, batch.graph_label)
            loss = lambda_node * loss_n + lambda_graph * loss_g
            total_loss += float(loss.item())

            all_node_preds.extend(node_out.argmax(dim=1).detach().cpu().tolist())
            all_node_labels.extend(batch.y.detach().cpu().tolist())
            all_graph_preds.extend(graph_out.argmax(dim=1).detach().cpu().tolist())
            all_graph_labels.extend(batch.graph_label.detach().cpu().tolist())

    avg_loss = total_loss / max(1, len(loader))
    node_acc = accuracy_score(all_node_labels, all_node_preds) if all_node_labels else 0.0
    graph_acc = accuracy_score(all_graph_labels, all_graph_preds) if all_graph_labels else 0.0
    node_f1 = f1_score(all_node_labels, all_node_preds, average="macro", zero_division=0) if all_node_labels else 0.0
    graph_f1 = f1_score(all_graph_labels, all_graph_preds, average="macro", zero_division=0) if all_graph_labels else 0.0

    return {
        "loss": avg_loss,
        "node_acc": node_acc, "node_f1": node_f1,
        "graph_acc": graph_acc, "graph_f1": graph_f1,
        "node_preds": all_node_preds, "node_labels": all_node_labels,
        "graph_preds": all_graph_preds, "graph_labels": all_graph_labels
    }

In [None]:
# -------- Training Hyperparametersss ---
num_node_classes  = len(emotion_encoder.classes_)
num_graph_classes = len(status_encoder.classes_)

EPOCHS = 200          
BATCH_SIZE = 8
PATIENCE = 30         
LR = 1e-4            
WEIGHT_DECAY = 5e-4  
HEADS = 2
HIDDEN = 64
DROPOUT = 0.35

LAMBDA_GRAPH = 1.0
LAMBDA_NODE_START = 0.8   
LAMBDA_NODE_END   = 0.5
ANNEAL_START_FRAC = 0.40
ANNEAL_STRATEGY   = "linear"

print(
    "Loss balance config: "
    f"λ_graph={LAMBDA_GRAPH}, λ_node_start={LAMBDA_NODE_START}, "
    f"λ_node_end={LAMBDA_NODE_END}, anneal_start_frac={ANNEAL_START_FRAC}, "
    f"strategy={ANNEAL_STRATEGY}"
)

Loss balance config: λ_graph=1.0, λ_node_start=0.8, λ_node_end=0.5, anneal_start_frac=0.4, strategy=linear


In [None]:
test_loader = DataLoader(test_graphs, batch_size=BATCH_SIZE, shuffle=False)

fold_results = []

for fold_id, (train_sids, val_sids) in enumerate(cv_folds, start=1):
    print(f"\n========== Fold {fold_id}/{len(cv_folds)} ==========")
    train_loader, val_loader = make_loaders_from_session_ids(train_sids, val_sids, batch_size=BATCH_SIZE)

    node_w, graph_w = compute_class_weights_for_fold(
        train_sids, num_node_classes, num_graph_classes, min_w=0.5, max_w=2.0
    )


    model = GATv2MultiTask_Small(
        input_dim=input_dim,
        hidden_dim=HIDDEN,
        num_node_classes=num_node_classes,
        num_graph_classes=num_graph_classes,
        heads=HEADS,
        dropout=DROPOUT
    ).to(device)

    optimizer = torch.optim.Adam(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=LR, weight_decay=WEIGHT_DECAY
    )
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10, verbose=False
    )

    loss_node_fn  = nn.CrossEntropyLoss(weight=node_w, label_smoothing=0.05)
    loss_graph_fn = nn.CrossEntropyLoss(weight=graph_w)

    best_val_loss = float("inf")
    best_state = None
    patience_ctr = 0
    val_metrics = {"loss": float("inf"), "node_acc": 0, "graph_acc": 0, "node_f1": 0, "graph_f1": 0}

    pbar = tqdm(range(1, EPOCHS + 1), desc=f"Fold {fold_id}/{len(cv_folds)}", leave=False, ncols=120)

    start_epoch = int(ANNEAL_START_FRAC * EPOCHS)

    for epoch in pbar:
        if epoch < start_epoch:
            lambda_node  = LAMBDA_NODE_START
        else:
            if EPOCHS == start_epoch:
                t = 1.0
            else:
                t = (epoch - start_epoch) / max(1, (EPOCHS - start_epoch))
            if ANNEAL_STRATEGY == "cosine":
                w = 0.5 * (1 + math.cos(math.pi * t))
                lambda_node = LAMBDA_NODE_END + (LAMBDA_NODE_START - LAMBDA_NODE_END) * w
            else:
                lambda_node = LAMBDA_NODE_START + (LAMBDA_NODE_END - LAMBDA_NODE_START) * t

        lambda_graph = LAMBDA_GRAPH

        # --- training ---=================
        model.train()
        total_train_loss = 0.0
        all_node_preds, all_node_labels = [], []
        all_graph_preds, all_graph_labels = [], []

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            x_in = maybe_augment_x(batch.x, sigma=0.01, p=0.5)
            node_out, graph_out = model(x_in, batch.edge_index, batch.batch)

            loss_n = loss_node_fn(node_out, batch.y)
            loss_g = loss_graph_fn(graph_out, batch.graph_label)
            loss = lambda_node * loss_n + lambda_graph * loss_g
            loss.backward()

            clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()

            total_train_loss += float(loss.item())
            all_node_preds.extend(node_out.argmax(dim=1).detach().cpu().tolist())
            all_node_labels.extend(batch.y.detach().cpu().tolist())
            all_graph_preds.extend(graph_out.argmax(dim=1).detach().cpu().tolist())
            all_graph_labels.extend(batch.graph_label.detach().cpu().tolist())

        avg_train_loss = total_train_loss / max(1, len(train_loader))
        train_node_acc = accuracy_score(all_node_labels, all_node_preds) if all_node_labels else 0.0
        train_graph_acc = accuracy_score(all_graph_labels, all_graph_preds) if all_graph_labels else 0.0

        # --- validation ------------------- ---
        val_metrics = evaluate_epoch(model, val_loader, loss_node_fn, loss_graph_fn,
                                     lambda_node=lambda_node, lambda_graph=lambda_graph)
        scheduler.step(val_metrics["loss"])

        pbar.set_postfix({
            "λn": f"{lambda_node:.2f}",
            "TL": f"{avg_train_loss:.3f}",
            "VL": f"{val_metrics['loss']:.3f}",
            "Tn": f"{train_node_acc*100:.1f}",
            "Vn": f"{val_metrics['node_acc']*100:.1f}",
            "Tg": f"{train_graph_acc*100:.1f}",
            "Vg": f"{val_metrics['graph_acc']*100:.1f}",
        })

        if val_metrics["loss"] < best_val_loss - 1e-6:
            best_val_loss = val_metrics["loss"]
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            patience_ctr = 0
        else:
            patience_ctr += 1
            if patience_ctr >= PATIENCE:
                pbar.close()
                print(f"Early stopping at epoch {epoch}")
                break

    pbar.close()
    print(f"Fold {fold_id} done. Best ValLoss={best_val_loss:.4f} | "
          f"Val Node Acc={val_metrics['node_acc']*100:.1f}% | Val Graph Acc={val_metrics['graph_acc']*100:.1f}%")

    model_path = f"gnn_multitask_best_fold{fold_id}.pth"
    torch.save(best_state, model_path)
    print(f"Saved best model for fold {fold_id}: {model_path}")

    model.load_state_dict(best_state)
    model.to(device)

    test_lambda_node = LAMBDA_NODE_END
    test_lambda_graph = LAMBDA_GRAPH

    test_metrics = evaluate_epoch(model, test_loader, loss_node_fn, loss_graph_fn,
                                  lambda_node=test_lambda_node, lambda_graph=test_lambda_graph)

    all_node_labels = np.array(test_metrics["node_labels"])
    all_node_preds  = np.array(test_metrics["node_preds"])

    mcc_overall = matthews_corrcoef(all_node_labels, all_node_preds) if len(np.unique(all_node_labels)) > 1 else 0.0
    overall_acc = accuracy_score(all_node_labels, all_node_preds)

    precision, recall, f1, support = precision_recall_fscore_support(
        all_node_labels, all_node_preds, labels=range(num_node_classes), zero_division=0
    )

    print("\nPer-emotion metrics on TEST set:")
    print(f"{'Emotion':<20} {'Acc':>7} {'Prec':>7} {'Rec':>7} {'F1':>7} {'MCC':>7} {'Support':>8}")
    for idx in range(num_node_classes):
        cls_name = emotion_encoder.classes_[idx]
        cls_support = int(support[idx])
        cls_acc = (all_node_preds[all_node_labels == idx] == idx).mean() if cls_support > 0 else 0.0
        if cls_support > 0:
            cls_mcc = matthews_corrcoef((all_node_labels == idx).astype(int),
                                        (all_node_preds == idx).astype(int))
        else:
            cls_mcc = 0.0
        print(f"{cls_name:<20} {cls_acc*100:7.2f} {precision[idx]*100:7.2f} "
              f"{recall[idx]*100:7.2f} {f1[idx]*100:7.2f} {cls_mcc:7.3f} {cls_support:8d}")

    print(f"\nOverall accuracy: {overall_acc*100:.2f}%")
    print(f"Overall MCC: {mcc_overall:.3f}")

    fold_results.append({
        "fold": fold_id,
        "val_loss": round(best_val_loss, 4),
        "val_node_acc": round(val_metrics["node_acc"], 4),
        "val_node_f1": round(val_metrics["node_f1"], 4),
        "val_graph_acc": round(val_metrics["graph_acc"], 4),
        "val_graph_f1": round(val_metrics["graph_f1"], 4),
        "test_node_acc": round(overall_acc, 4),
        "test_node_f1": round(f1.mean(), 4),
        "test_graph_acc": round(accuracy_score(test_metrics["graph_labels"], test_metrics["graph_preds"]), 4) if test_metrics["graph_labels"] else 0.0,
        "test_graph_f1": round(f1_score(test_metrics["graph_labels"], test_metrics["graph_preds"], average="macro", zero_division=0), 4) if test_metrics["graph_labels"] else 0.0,
    })




                                                                                                                        

Fold 1 done. Best ValLoss=1.2261 | Val Node Acc=37.3% | Val Graph Acc=100.0%
Saved best model for fold 1: gnn_multitask_best_fold1.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              0.00    0.00    0.00    0.00  -0.009       49
amusement               0.00    0.00    0.00    0.00  -0.005       12
anger                   0.00    0.00    0.00    0.00  -0.002        5
annoyance               0.00    0.00    0.00    0.00  -0.004       42
approval               67.82   28.92   67.82   40.55   0.337      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity              66.67   14.81   66.67   24.24   0.309        6
desire                  5.94   40.00    5.94   10.34   0.134      101
disappointment         54.48   35.27   54.48   42.82   0.367      145
disapproval             0.00    0.00    0.00    0.00   0.000 

                                                                                                                        

Early stopping at epoch 31
Fold 2 done. Best ValLoss=3.9822 | Val Node Acc=29.9% | Val Graph Acc=20.0%
Saved best model for fold 2: gnn_multitask_best_fold2.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              2.04    1.79    2.04    1.90  -0.015       49
amusement               0.00    0.00    0.00    0.00   0.000       12
anger                   0.00    0.00    0.00    0.00   0.000        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval                1.72   30.00    1.72    3.26   0.048      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity              16.67    0.53   16.67    1.02   0.009        6
desire                  0.00    0.00    0.00    0.00   0.000      101
disappointment          0.00    0.00    0.00    0.00   0.000      145
disapproval             0.00    0.0

                                                                                                                        

Fold 3 done. Best ValLoss=1.2699 | Val Node Acc=34.8% | Val Graph Acc=100.0%
Saved best model for fold 3: gnn_multitask_best_fold3.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              2.04    8.33    2.04    3.28   0.026       49
amusement              16.67  100.00   16.67   28.57   0.407       12
anger                   0.00    0.00    0.00    0.00  -0.002        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               54.02   31.02   54.02   39.41   0.311      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity              16.67   25.00   16.67   20.00   0.202        6
desire                 10.89   55.00   10.89   18.18   0.225      101
disappointment         54.48   34.50   54.48   42.25   0.361      145
disapproval             0.00    0.00    0.00    0.00   0.000 

                                                                                                                        

Fold 4 done. Best ValLoss=1.2144 | Val Node Acc=32.9% | Val Graph Acc=100.0%
Saved best model for fold 4: gnn_multitask_best_fold4.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              4.08   22.22    4.08    6.90   0.083       49
amusement              33.33   50.00   33.33   40.00   0.405       12
anger                   0.00    0.00    0.00    0.00  -0.003        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               60.92   27.18   60.92   37.59   0.295      174
caring                  0.00    0.00    0.00    0.00  -0.006       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity               0.00    0.00    0.00    0.00  -0.005        6
desire                  1.98   33.33    1.98    3.74   0.068      101
disappointment         59.31   32.21   59.31   41.75   0.359      145
disapproval             0.00    0.00    0.00    0.00   0.000 

                                                                                                                        

Early stopping at epoch 84
Fold 5 done. Best ValLoss=2.3573 | Val Node Acc=28.5% | Val Graph Acc=93.3%
Saved best model for fold 5: gnn_multitask_best_fold5.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              0.00    0.00    0.00    0.00   0.000       49
amusement               0.00    0.00    0.00    0.00  -0.005       12
anger                   0.00    0.00    0.00    0.00   0.000        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               74.71   20.19   74.71   31.78   0.242      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity               0.00    0.00    0.00    0.00  -0.002        6
desire                  0.00    0.00    0.00    0.00   0.000      101
disappointment         67.59   24.56   67.59   36.03   0.309      145
disapproval             0.00    0.0

                                                                                                                        

Early stopping at epoch 55
Fold 6 done. Best ValLoss=2.5065 | Val Node Acc=26.9% | Val Graph Acc=80.0%
Saved best model for fold 6: gnn_multitask_best_fold6.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              0.00    0.00    0.00    0.00   0.000       49
amusement               0.00    0.00    0.00    0.00   0.000       12
anger                   0.00    0.00    0.00    0.00   0.000        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               70.11   20.37   70.11   31.57   0.232      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity               0.00    0.00    0.00    0.00   0.000        6
desire                  0.00    0.00    0.00    0.00   0.000      101
disappointment         49.66   32.00   49.66   38.92   0.322      145
disapproval             0.00    0.0

                                                                                                                        

Fold 7 done. Best ValLoss=1.3148 | Val Node Acc=28.8% | Val Graph Acc=100.0%
Saved best model for fold 7: gnn_multitask_best_fold7.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              6.12    7.89    6.12    6.90   0.043       49
amusement               0.00    0.00    0.00    0.00   0.000       12
anger                   0.00    0.00    0.00    0.00   0.000        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               60.92   28.57   60.92   38.90   0.310      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity              16.67   14.29   16.67   15.38   0.151        6
desire                  4.95   50.00    4.95    9.01   0.142      101
disappointment         68.97   32.89   68.97   44.54   0.400      145
disapproval             0.00    0.00    0.00    0.00   0.000 

                                                                                                                        

Fold 8 done. Best ValLoss=1.3149 | Val Node Acc=29.1% | Val Graph Acc=100.0%
Saved best model for fold 8: gnn_multitask_best_fold8.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              2.04   20.00    2.04    3.70   0.055       49
amusement               0.00    0.00    0.00    0.00  -0.004       12
anger                   0.00    0.00    0.00    0.00   0.000        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               58.62   25.95   58.62   35.98   0.274      174
caring                  3.03   50.00    3.03    5.71   0.119       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity               0.00    0.00    0.00    0.00   0.000        6
desire                  6.93   38.89    6.93   11.76   0.142      101
disappointment         48.97   33.49   48.97   39.78   0.331      145
disapproval             0.00    0.00    0.00    0.00   0.000 

                                                                                                                        

Fold 9 done. Best ValLoss=1.2163 | Val Node Acc=36.8% | Val Graph Acc=100.0%
Saved best model for fold 9: gnn_multitask_best_fold9.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              6.12   20.00    6.12    9.38   0.095       49
amusement              33.33   21.05   33.33   25.81   0.258       12
anger                   0.00    0.00    0.00    0.00  -0.001        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               54.02   28.66   54.02   37.45   0.288      174
caring                  3.03  100.00    3.03    5.88   0.172       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity               0.00    0.00    0.00    0.00   0.000        6
desire                  8.91   45.00    8.91   14.88   0.179      101
disappointment         60.00   35.51   60.00   44.62   0.390      145
disapproval             0.00    0.00    0.00    0.00   0.000 

                                                                                                                        

Early stopping at epoch 48
Fold 10 done. Best ValLoss=3.3545 | Val Node Acc=23.3% | Val Graph Acc=81.2%
Saved best model for fold 10: gnn_multitask_best_fold10.pth

Per-emotion metrics on TEST set:
Emotion                  Acc    Prec     Rec      F1     MCC  Support
admiration              0.00    0.00    0.00    0.00   0.000       49
amusement               0.00    0.00    0.00    0.00   0.000       12
anger                   0.00    0.00    0.00    0.00   0.000        5
annoyance               0.00    0.00    0.00    0.00   0.000       42
approval               56.90   20.84   56.90   30.51   0.205      174
caring                  0.00    0.00    0.00    0.00   0.000       33
confusion               0.00    0.00    0.00    0.00   0.000       30
curiosity               0.00    0.00    0.00    0.00   0.000        6
desire                  0.00    0.00    0.00    0.00   0.000      101
disappointment         83.45   21.27   83.45   33.89   0.313      145
disapproval             0.00    



In [None]:
#  Summaryyyy ---
results_df = pd.DataFrame(fold_results)

if results_df.empty:
    print("No fold results to display.")
else:
    col_order = [
        "fold",
        "val_node_acc", "val_node_f1", "val_graph_acc", "val_graph_f1",
        "test_node_acc", "test_node_f1", "test_graph_acc", "test_graph_f1",
        "val_loss"
    ]
    cols = [c for c in col_order if c in results_df.columns]
    df_summary = results_df[cols].copy()

    avg_row = df_summary.drop(columns=["fold"], errors="ignore").mean(numeric_only=True).to_dict()
    avg_row["fold"] = "AVG"
    df_summary = pd.concat([df_summary, pd.DataFrame([avg_row])], ignore_index=True)

    pct_cols = [c for c in df_summary.columns if c.endswith("_acc") or c.endswith("_f1")]
    df_num = df_summary.copy()
    for c in pct_cols:
        df_num[c] = pd.to_numeric(df_num[c], errors="coerce") * 100.0
    if "val_loss" in df_num.columns:
        df_num["val_loss"] = pd.to_numeric(df_num["val_loss"], errors="coerce")

    with pd.option_context('display.max_columns', None, 'display.width', 120):
        print("\nK-Fold Summary (best val + test) — last row is AVG")
        print(df_num.to_string(index=False, formatters={**{c: "{:.1f}".format for c in pct_cols},
                                                        **({"val_loss": "{:.3f}".format} if "val_loss" in df_num.columns else {})}))


K-Fold Summary (best val + test) — last row is AVG
fold val_node_acc val_node_f1 val_graph_acc val_graph_f1 test_node_acc test_node_f1 test_graph_acc test_graph_f1 val_loss
   1         37.2        14.2         100.0        100.0          33.9         22.2           86.8          50.8    1.226
   2         29.9         4.6          20.0         16.7           4.9          1.0           79.0          29.4    3.982
   3         34.8        22.7         100.0        100.0          34.6         20.9           84.2          45.1    1.270
   4         32.9        19.3         100.0        100.0          31.8         21.0           86.8          50.8    1.214
   5         28.5         8.8          93.3         48.3          24.6          6.8           79.0          29.4    2.357
   6         26.9         7.5          80.0         44.4          21.2          4.3           79.0          29.4    2.506
   7         28.8        21.0         100.0        100.0          33.6         22.5           