In [None]:
# ============================================================
# ONE-DATASET (Surprise) RUN: Embeddings + GNN + Text + Ensemble
# + GNN Explainer with visualization
# ============================================================

!pip -q install torch-geometric faiss-cpu sentence-transformers transformers accelerate scikit-learn networkx matplotlib

import os, random, gc
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F

import faiss
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

import networkx as nx
import matplotlib.pyplot as plt

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import k_hop_subgraph

# -------------------------
# 0) Reproducibility
# -------------------------
SEED = 42


def set_seed(seed=42):
    """Seed every RNG this notebook touches: Python, NumPy, PyTorch CPU and all CUDA devices."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)


set_seed(SEED)

def cleanup_cuda():
    """Run Python's garbage collector, then release PyTorch's cached GPU memory if CUDA exists."""
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()


# Single device string shared by every model/tensor placement below.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# -------------------------
# 1) CONFIG (Surprise only)
# -------------------------
# Input CSV and column names. SPLIT_COL encodes 0=train, 1=val, 2=test.
CSV_PATH = "/content/sample_data/Surprise_anon.csv"
TEXT_COL = "Sentence"
LABEL_COL = "Surprise"
SPLIT_COL = "Split"

# Pretrained checkpoints: SENT_MODEL produces the graph node features,
# BASE_TEXT_MODEL is fine-tuned as the text classifier.
SENT_MODEL = "pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb"
BASE_TEXT_MODEL = "dmis-lab/biobert-base-cased-v1.1"

# kNN graph: K neighbours searched per node; an edge is kept only when its
# similarity is >= THR.
K = 10
THR = 0.6

# GNN optimisation. GNN_PATIENCE is counted in evaluation checks (the
# training loop evaluates every 10 epochs), not in raw epochs.
GNN_HID = 256
GNN_DROPOUT = 0.5
GNN_LR = 0.0005
GNN_WD = 0.0005
GNN_EPOCHS = 300
GNN_PATIENCE = 25

# Text-model fine-tuning (tokenised sequences are truncated to MAX_LEN).
TEXT_EPOCHS = 4
TEXT_BS = 16
TEXT_LR = 0.00003
TEXT_WD = 0.01
MAX_LEN = 128

# Edge-mask explainer: subgraph radius, optimisation schedule, size/entropy
# regulariser weights, and how many top-ranked edges are plotted.
EXPL_NUM_HOPS = 2
EXPL_EPOCHS = 200
EXPL_LR = 0.05
EXPL_LAM_SIZE = 0.01
EXPL_LAM_ENT = 0.001
EXPL_TOP_EDGES = 25

# -------------------------
# 2) Helpers
# -------------------------
def acc_from_probs(p, ytrue):
    """Top-1 accuracy of the row-wise argmax of probability matrix ``p`` against ``ytrue``."""
    hits = np.argmax(p, axis=1) == ytrue
    return float(np.mean(hits))

def metrics_from_probs(p, ytrue):
    """Classification metrics for the row-wise argmax of probability matrix ``p``.

    Returns a dict of plain floats with keys "acc", "f1", "precision", "recall".
    F1/precision/recall use scikit-learn's default averaging (no ``average``
    argument is passed), and ill-defined scores are reported as 0.
    """
    y_hat = np.argmax(p, axis=1)
    scores = {"acc": accuracy_score(ytrue, y_hat)}
    for key, score_fn in (("f1", f1_score), ("precision", precision_score), ("recall", recall_score)):
        scores[key] = score_fn(ytrue, y_hat, zero_division=0)
    return {key: float(val) for key, val in scores.items()}

def make_class_weights(y, train_mask):
    """Inverse-frequency cross-entropy weights ``[w0, w1]`` from the training split.

    Each weight is ``n_train / (2 * n_class)``; only labels equal to 0 or 1 are
    counted. Falls back to ``[1.0, 1.0]`` if either class is absent from the
    training rows. The tensor is created on the module-level ``device``.
    """
    train_labels = np.array(y)[train_mask]
    n_pos = int((train_labels == 1).sum())
    n_neg = int((train_labels == 0).sum())
    if n_pos and n_neg:
        total = n_pos + n_neg
        weights = [total / (2.0 * n_neg), total / (2.0 * n_pos)]
    else:
        weights = [1.0, 1.0]
    return torch.tensor(weights, dtype=torch.float, device=device)

class WeightedGCN(nn.Module):
    """Two-layer GCN over a weighted graph: dropout -> GCNConv -> ReLU -> dropout -> GCNConv.

    Returns raw per-node class logits; ``edge_weight`` is forwarded to both
    convolutions unchanged.
    """

    def __init__(self, in_dim, hid=256, num_classes=2, dropout=0.5):
        super().__init__()
        self.dropout = dropout  # applied to inputs and to the hidden layer
        self.c1 = GCNConv(in_dim, hid)
        self.c2 = GCNConv(hid, num_classes)

    def forward(self, x, edge_index, edge_weight=None):
        hidden = F.dropout(x, p=self.dropout, training=self.training)
        hidden = F.relu(self.c1(hidden, edge_index, edge_weight))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.c2(hidden, edge_index, edge_weight)

class SimpleDataset(torch.utils.data.Dataset):
    """Map-style dataset pairing pre-tokenised columns with integer labels.

    ``enc`` maps a column name (e.g. "input_ids") to a per-example sequence;
    each item is a dict of tensors plus a scalar ``"labels"`` tensor.
    """

    def __init__(self, enc, y):
        self.enc = enc
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        sample = {}
        for name, column in self.enc.items():
            sample[name] = torch.tensor(column[idx])
        sample["labels"] = torch.tensor(int(self.y[idx]))
        return sample

def comp_metrics_hf(eval_pred):
    """``compute_metrics`` hook for the HF Trainer: accuracy/F1/precision/recall from logits.

    ``eval_pred`` is unpacked as ``(logits, labels)``; predictions are the
    row-wise argmax, and ill-defined F1/precision/recall are reported as 0.
    """
    logits, labels_ = eval_pred
    preds = np.argmax(logits, axis=1)
    scores = {"accuracy": accuracy_score(labels_, preds)}
    for key, score_fn in (("f1", f1_score), ("precision", precision_score), ("recall", recall_score)):
        scores[key] = score_fn(labels_, preds, zero_division=0)
    return scores

# -------------------------
# 3) Load shared models
# -------------------------
# Tokenizer for the fine-tuned text classifier, and the sentence encoder that
# produces graph node features.
tok = AutoTokenizer.from_pretrained(BASE_TEXT_MODEL)
sent_model = SentenceTransformer(SENT_MODEL, device=device)

def tokenize_list(text_list, max_len=128):
    """Tokenise ``text_list`` with the shared tokenizer.

    Sequences are truncated to ``max_len`` tokens and padded to the longest
    sequence in the list.
    """
    return tok(text_list, max_length=max_len, padding=True, truncation=True)

# -------------------------
# 4) Load dataset
# -------------------------
banner = "=" * 70
print("\n" + banner)
print("DATASET: Surprise")
print(banner)

# Keep only the three columns used below; rows with any missing value are dropped.
df = pd.read_csv(CSV_PATH)
df = df.loc[:, [TEXT_COL, LABEL_COL, SPLIT_COL]].dropna().reset_index(drop=True)

texts = df[TEXT_COL].astype(str).tolist()
labels = df[LABEL_COL].astype(int).tolist()
splits = df[SPLIT_COL].astype(int).tolist()

# Boolean row masks per split code (0=train, 1=val, 2=test).
splits_np = np.array(splits)
train_mask, val_mask, test_mask = (splits_np == code for code in (0, 1, 2))

print("Total:", len(df))
print("Train/Val/Test:", *(int(m.sum()) for m in (train_mask, val_mask, test_mask)))
print("Train label counts:", df[df[SPLIT_COL] == 0][LABEL_COL].value_counts().to_dict())

# -------------------------
# 5) Sentence embeddings for graph features
# -------------------------
# Unit-normalised embeddings, so the inner-product kNN search is cosine similarity.
X = sent_model.encode(texts, convert_to_tensor=True, normalize_embeddings=True, batch_size=64)

# Host-side float32 copy for the FAISS kNN search.
emb_host = X.detach().cpu().numpy()
X_np = emb_host.astype(np.float32)

# -------------------------
# 6) Leakage-free kNN graph
#    - build faiss index on TRAIN+VAL only
#    - connect TEST -> TRAIN+VAL only (no test-test edges)
#    Labels never cross splits. After symmetrisation, test-node *features*
#    do reach train/val nodes during message passing (transductive setup).
# -------------------------
idx_trainval = np.where(splits_np != 2)[0]
idx_test     = np.where(splits_np == 2)[0]

index = faiss.IndexFlatIP(X_np.shape[1])
index.add(X_np[idx_trainval])


def _collect_knn_edges(query_ids, sims_, nbrs_, k, thr, out):
    """Add thresholded kNN edges for ``query_ids`` into ``out`` (in place).

    ``nbrs_``/``sims_`` are FAISS search results whose neighbour ids are row
    positions into ``idx_trainval``. At most ``k`` non-self neighbours per
    query are considered, and pairs with similarity below ``thr`` are dropped.
    ``out`` maps an unordered node pair ``(min, max)`` to its similarity, so
    each pair is stored exactly once.
    """
    for row, i in enumerate(query_ids):
        i = int(i)
        considered = 0
        for nb_local, w in zip(nbrs_[row], sims_[row]):
            if considered >= k:
                break
            nb_local = int(nb_local)
            if nb_local < 0:
                # FAISS pads with -1 when the index holds fewer vectors than
                # requested; idx_trainval[-1] would silently be a wrong node.
                break
            nb = int(idx_trainval[nb_local])
            if nb == i:
                # Skip the self-match by identity, not by rank: with duplicate
                # embeddings the query itself need not come back first.
                continue
            considered += 1
            w = float(w)
            if w >= thr:
                key = (i, nb) if i < nb else (nb, i)
                out[key] = max(w, out.get(key, w))


pair_wt = {}

# (A) edges among train+val (K+1 hits so that K remain after dropping self)
sims, nbrs = index.search(X_np[idx_trainval], K + 1)
_collect_knn_edges(idx_trainval, sims, nbrs, K, THR, pair_wt)

# (B) test -> train+val (test nodes are not in the index, so no self-match)
if len(idx_test) > 0:
    sims_t, nbrs_t = index.search(X_np[idx_test], K)
    _collect_knn_edges(idx_test, sims_t, nbrs_t, K, THR, pair_wt)

# make undirected: every unordered pair appears exactly once per direction.
# Mutual kNN pairs are therefore not duplicated; GCNConv sums duplicate edges,
# which would double their weight relative to one-sided neighbours.
pairs = sorted(pair_wt)
edge_src = [u for u, _ in pairs] + [v for _, v in pairs]
edge_dst = [v for _, v in pairs] + [u for u, _ in pairs]
edge_wt  = [pair_wt[p] for p in pairs] * 2

edge_index = torch.tensor([edge_src, edge_dst], dtype=torch.long)
edge_weight = torch.tensor(edge_wt, dtype=torch.float)

print("Undirected edges:", int(edge_index.shape[1]))

# -------------------------
# 7) PyG data
# -------------------------
y = torch.tensor(labels, dtype=torch.long)

split_masks = {
    "train_mask": torch.tensor(train_mask),
    "val_mask": torch.tensor(val_mask),
    "test_mask": torch.tensor(test_mask),
}
data = Data(x=X.detach().cpu(), edge_index=edge_index, edge_weight=edge_weight, y=y, **split_masks)
data = data.to(device)

# Train-split class weights feed the GNN's cross-entropy loss.
class_weights = make_class_weights(labels, train_mask)
criterion = nn.CrossEntropyLoss(weight=class_weights)
print("Class weights:", class_weights.detach().cpu().tolist())

# -------------------------
# 8) Train GNN
# -------------------------
gnn = WeightedGCN(in_dim=data.x.size(1), hid=GNN_HID, num_classes=2, dropout=GNN_DROPOUT).to(device)
opt = torch.optim.AdamW(gnn.parameters(), lr=GNN_LR, weight_decay=GNN_WD)


def gnn_train(max_epochs=300, patience=25, eval_every=10):
    """Full-batch training of the module-level ``gnn`` with val-accuracy checkpointing.

    The model is evaluated in eval mode at epoch 1 and then every
    ``eval_every`` epochs. The state with the best validation accuracy is kept
    and restored at the end. Training stops after ``patience`` consecutive
    evaluations without improvement, i.e. about ``patience * eval_every`` epochs.

    Returns:
        Best validation accuracy observed (float); -1.0 if no evaluation ran.
    """
    eval_every = max(1, int(eval_every))
    best_val = -1
    best_state = None
    bad = 0

    for epoch in range(1, max_epochs + 1):
        gnn.train()
        opt.zero_grad()
        logits = gnn(data.x, data.edge_index, data.edge_weight)
        loss = criterion(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        torch.nn.utils.clip_grad_norm_(gnn.parameters(), 1.0)
        opt.step()

        if epoch % eval_every != 0 and epoch != 1:
            continue

        # Fresh forward pass in eval mode. The training logits above were made
        # with dropout active and *before* opt.step(), so scoring them would
        # pick checkpoints from noisy predictions of different weights.
        gnn.eval()
        with torch.no_grad():
            pred = gnn(data.x, data.edge_index, data.edge_weight).argmax(dim=1)
        tr = (pred[data.train_mask] == data.y[data.train_mask]).float().mean().item()
        va = (pred[data.val_mask]   == data.y[data.val_mask]).float().mean().item()
        te = (pred[data.test_mask]  == data.y[data.test_mask]).float().mean().item()
        print(f"GNN Epoch {epoch:03d} | loss {loss.item():.4f} | train {tr:.3f} | val {va:.3f} | test {te:.3f}")

        if va > best_val:
            best_val = va
            best_state = {k: v.detach().cpu().clone() for k, v in gnn.state_dict().items()}
            bad = 0
        else:
            bad += 1
            if bad >= patience:
                print(f"GNN Early stop at epoch {epoch} (best val={best_val:.3f})")
                break

    if best_state is not None:
        gnn.load_state_dict(best_state)
    return float(best_val)


best_val_gnn = gnn_train(max_epochs=GNN_EPOCHS, patience=GNN_PATIENCE)

@torch.no_grad()
def gnn_probs(mask_tensor):
    """Softmax class probabilities of the eval-mode GNN (full graph) for nodes selected by ``mask_tensor``."""
    gnn.eval()
    all_probs = F.softmax(gnn(data.x, data.edge_index, data.edge_weight), dim=1)
    return all_probs[mask_tensor].detach().cpu().numpy()

p_gnn_val, p_gnn_test = gnn_probs(data.val_mask), gnn_probs(data.test_mask)

# Ground-truth labels in the same row order as the masked probabilities.
y_val = df.loc[val_mask, LABEL_COL].astype(int).to_numpy()
y_test = df.loc[test_mask, LABEL_COL].astype(int).to_numpy()

gnn_test_metrics = metrics_from_probs(p_gnn_test, y_test)
print("\nGNN metrics (test):", gnn_test_metrics)

# -------------------------
# 9) Train Text model (BioBERT)
# -------------------------
X_train = df[train_mask][TEXT_COL].tolist()
y_train = df[train_mask][LABEL_COL].astype(int).tolist()
X_val   = df[val_mask][TEXT_COL].tolist()
y_val_l = df[val_mask][LABEL_COL].astype(int).tolist()
X_test  = df[test_mask][TEXT_COL].tolist()
y_test_l= df[test_mask][LABEL_COL].astype(int).tolist()

train_ds = SimpleDataset(tokenize_list(X_train, MAX_LEN), y_train)
val_ds   = SimpleDataset(tokenize_list(X_val,   MAX_LEN), y_val_l)
test_ds  = SimpleDataset(tokenize_list(X_test,  MAX_LEN), y_test_l)

text_model = AutoModelForSequenceClassification.from_pretrained(BASE_TEXT_MODEL, num_labels=2)

args = TrainingArguments(
    # Scratch dir named after the dataset actually loaded (LABEL_COL). It was
    # hard-coded as "./tmp_text_sadness", mislabelling this run and sharing the
    # directory with runs on other datasets.
    output_dir=f"./tmp_text_{LABEL_COL.lower()}",
    num_train_epochs=TEXT_EPOCHS,
    per_device_train_batch_size=TEXT_BS,
    per_device_eval_batch_size=TEXT_BS,
    learning_rate=TEXT_LR,
    weight_decay=TEXT_WD,
    eval_strategy="epoch",      # logs val metrics each epoch (no checkpoint selection)
    save_strategy="no",
    logging_steps=50,
    report_to="none",
    fp16=torch.cuda.is_available(),
    seed=SEED,
)

trainer = Trainer(
    model=text_model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    compute_metrics=comp_metrics_hf,
)
trainer.train()

@torch.no_grad()
def text_probs(dataset):
    """Class probabilities from the fine-tuned text model (module-level ``trainer``).

    With ``fp16`` enabled, ``Trainer.predict`` may return half-precision logits,
    and some models return a tuple whose first item is the logits. The logits
    are upcast to float32 before the softmax, so the CPU softmax is supported
    and the probabilities mixed into the ensemble keep full precision.
    """
    logits = trainer.predict(dataset).predictions
    if isinstance(logits, tuple):
        logits = logits[0]
    logits = torch.as_tensor(np.asarray(logits, dtype=np.float32))
    return torch.softmax(logits, dim=1).cpu().numpy()

p_text_val  = text_probs(val_ds)
p_text_test = text_probs(test_ds)

text_test_metrics = metrics_from_probs(p_text_test, y_test)
print("\nText metrics (test):", text_test_metrics)

# -------------------------
# 10) Ensemble (alpha tuned on VAL)
#     p_ens = alpha * p_text + (1 - alpha) * p_gnn, alpha in {0.0, 0.1, ..., 1.0}
# -------------------------
best_alpha = None
best_val_acc = -1
best_val_metrics = None
best_test_metrics = None

for alpha in [i / 10 for i in range(11)]:
    p_ens_val = alpha * p_text_val + (1 - alpha) * p_gnn_val
    val_acc = acc_from_probs(p_ens_val, y_val)

    # Strict ">" keeps the smallest (most GNN-weighted) alpha on ties.
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_alpha = alpha
        best_val_metrics = metrics_from_probs(p_ens_val, y_val)

# The test set is scored once, only for the alpha chosen on validation.
if best_alpha is not None:
    p_ens_test = best_alpha * p_text_test + (1 - best_alpha) * p_gnn_test
    best_test_metrics = metrics_from_probs(p_ens_test, y_test)

# Header names the dataset actually loaded (was hard-coded as "Sadness").
print(f"\n---------------- RESULTS ({LABEL_COL}) ----------------")
print(f"GNN best val acc: {best_val_gnn:.3f}")
print("GNN  (test):", gnn_test_metrics)
print("Text (test):", text_test_metrics)
print(f"Ensemble best alpha: {best_alpha}")
print("Ensemble (val):", best_val_metrics)
print("Ensemble(test):", best_test_metrics)

# ============================================================
# 11) STABLE RexYing-style Explainer (edge mask)
# ============================================================
@torch.no_grad()
def pick_test_node_prefer_correct(data, logits):
    """Return the lowest-index test node the model classifies correctly.

    Falls back to the first test node when none is correct (IndexError if the
    test mask is empty).
    """
    predicted = logits.argmax(dim=1).detach().cpu().numpy()
    truth = data.y.detach().cpu().numpy()
    candidates = torch.where(data.test_mask)[0].detach().cpu().numpy().tolist()
    first_correct = next((node for node in candidates if predicted[node] == truth[node]), None)
    chosen = candidates[0] if first_correct is None else first_correct
    return int(chosen)

def explain_node_edge_mask_stable(
    model,
    data,
    node_idx,
    num_hops=2,
    epochs=200,
    lr=0.05,
    lam_size=0.01,
    lam_ent=0.001,
    top_edges=25,
    eps=1e-9,
):
    """
    Learn a soft edge mask on the ``num_hops``-hop subgraph around ``node_idx``.

    The mask multiplies the existing edge weights. It is optimised with Adam so
    the centre node keeps the class the model predicts on the unmasked subgraph
    (negative log-likelihood), plus ``lam_size * sum(mask)`` and
    ``lam_ent * mean(binary entropy of mask)`` regularisers.

    Robust to:
      - empty subgraphs (adds a centre self-loop)
      - NaN/inf edge weights (sanitised, then clamped to [0, 1])
      - saturated masks (entropy computed from logits, finite even when the
        float32 sigmoid rounds to exactly 0 or 1)

    NOTE: the induced subgraph drops edges of its boundary nodes, so GCN degree
    normalisation may differ from the full graph. ``target_class`` can
    therefore occasionally differ from the full-graph prediction.

    Args:
        model: trained GNN called as ``model(x, edge_index, edge_weight)``; put
            in eval mode, and its parameters are frozen only while the mask is
            optimised (requires_grad flags are restored afterwards).
        data: PyG ``Data`` with ``x``, ``edge_index``, optional ``edge_weight``.
        node_idx: global index of the node to explain.
        num_hops, epochs, lr, lam_size, lam_ent: subgraph radius and optimiser
            settings.
        top_edges: maximum number of highest-mask edges returned in
            ``top_keep_idx``.
        eps: unused. It is kept so existing callers passing it still work; the
            logit-based entropy no longer needs clamping.

    Returns:
        ``(subset_orig, edge_index_sub, edge_mask, center_local, target_class,
        top_keep_idx)``:
        - ``subset_orig``: global ids of the subgraph nodes (numpy).
        - ``edge_index_sub``: local edge index (CPU tensor).
        - ``edge_mask``: learned per-edge mask in (0, 1) (numpy).
        - ``center_local``: local id of ``node_idx``.
        - ``target_class``: the class that was preserved.
        - ``top_keep_idx``: edge positions sorted by descending mask, truncated
          to ``top_edges``.
    """
    model.eval()

    subset, edge_index_sub, mapping, edge_id = k_hop_subgraph(
        node_idx=int(node_idx),
        num_hops=int(num_hops),
        edge_index=data.edge_index,
        relabel_nodes=True,
        num_nodes=data.num_nodes,
        flow="source_to_target",
    )

    subset_orig = subset.detach().cpu().numpy()
    center_local = int(mapping.item())

    x_sub = data.x[subset]

    # Subgraph edge weights; ``edge_id`` is k_hop_subgraph's boolean mask of
    # the kept edges.
    if data.edge_weight is None:
        base_ew = torch.ones(edge_index_sub.size(1), device=data.x.device)
    else:
        base_ew = data.edge_weight[edge_id].clone()

    # If no edges, add a self-loop so the explainer has something to optimise.
    if edge_index_sub.numel() == 0 or edge_index_sub.size(1) == 0:
        edge_index_sub = torch.tensor([[center_local], [center_local]],
                                      device=data.x.device, dtype=torch.long)
        base_ew = torch.ones(1, device=data.x.device)

    base_ew = torch.nan_to_num(base_ew, nan=0.0, posinf=1.0, neginf=0.0)
    base_ew = torch.clamp(base_ew, min=0.0, max=1.0)

    with torch.no_grad():
        out_clean = model(x_sub, edge_index_sub, base_ew)  # [subN, num_classes]
        target_class = int(out_clean[center_local].argmax().item())

    mask_logits = torch.nn.Parameter(torch.zeros(edge_index_sub.size(1), device=data.x.device))
    optm = torch.optim.Adam([mask_logits], lr=lr)

    # Freeze the GNN while optimising the mask. Gradients still reach the mask
    # through the edge weights, but nothing accumulates in the trained model's
    # .grad buffers.
    prev_requires_grad = [param.requires_grad for param in model.parameters()]
    for param in model.parameters():
        param.requires_grad_(False)
    try:
        for ep in range(1, epochs + 1):
            optm.zero_grad()
            mask = torch.sigmoid(mask_logits)
            masked_ew = base_ew * mask

            out = model(x_sub, edge_index_sub, masked_ew)

            # scalar loss (centre node): keep the originally predicted class
            logp = F.log_softmax(out[center_local], dim=-1)
            loss_pred = -logp[target_class]

            loss_size = lam_size * mask.sum()

            # Binary entropy from logits, using log(sigmoid(l)) = -softplus(-l)
            # and log(1 - sigmoid(l)) = -softplus(l). Clamping probabilities
            # with 1 - 1e-9 is a no-op in float32 (it rounds to 1.0) and yields
            # 0 * log(0) = NaN once a mask saturates.
            ent = mask * F.softplus(-mask_logits) + (1 - mask) * F.softplus(mask_logits)
            loss_ent = lam_ent * ent.mean()

            loss = loss_pred + loss_size + loss_ent
            if torch.isnan(loss) or torch.isinf(loss):
                loss = loss_pred

            loss.backward()
            optm.step()

            if ep % 50 == 0 or ep == 1:
                with torch.no_grad():
                    prob_t = torch.softmax(out[center_local], dim=-1)[target_class].item()
                    print(f"Explainer ep {ep:03d} | loss={float(loss.item()):.4f} | target_prob={prob_t:.4f} | mask_mean={float(mask.mean().item()):.3f}")
    finally:
        for param, flag in zip(model.parameters(), prev_requires_grad):
            param.requires_grad_(flag)

    edge_mask = torch.sigmoid(mask_logits).detach().cpu().numpy()
    order = np.argsort(-edge_mask)
    top_keep_idx = order[:min(top_edges, len(order))]

    return subset_orig, edge_index_sub.detach().cpu(), edge_mask, center_local, target_class, top_keep_idx

def plot_explanation_q1(
    subset_orig,
    edge_index_sub,
    edge_mask,
    center_local,
    top_keep_idx,
    true_full,
    pred_full,
    node_idx_global,
    sentence_text,
    save_path="sadness_explainer_q1.png",
    dataset_name="Surprise",
):
    """Draw, save and describe the top-importance edges of an edge-mask explanation.

    Node fill colour encodes the ground-truth label and the border style
    encodes the prediction (solid: 0, dashed: 1). Edge width scales with mask
    importance and the centre node is drawn larger. Labels are looked up in
    ``true_full``/``pred_full`` via ``subset_orig`` (local -> global ids).

    Args:
        edge_index_sub: local edge index as a CPU tensor.
        edge_mask: per-edge importance aligned with ``edge_index_sub``.
        top_keep_idx: edge positions to draw; edges with importance <= 1e-6
            are skipped.
        save_path: PNG output path (saved at 300 dpi).
        dataset_name: dataset name used in the title and caption. The default
            matches the CSV this notebook loads; it was hard-coded as "Sadness".
    """
    src = edge_index_sub[0].numpy()
    dst = edge_index_sub[1].numpy()

    G = nx.DiGraph()
    nodes_in_plot = {center_local}

    for ei in top_keep_idx:
        u = int(src[ei]); v = int(dst[ei])
        imp = float(edge_mask[ei])
        if imp <= 1e-6:
            continue
        G.add_edge(u, v, imp=imp)
        nodes_in_plot.update((u, v))

    G.add_nodes_from(nodes_in_plot)

    # fallback: ensure at least one edge is plotted
    if G.number_of_edges() == 0:
        best_imp = float(np.max(edge_mask)) if len(edge_mask) else 1.0
        G.add_edge(center_local, center_local, imp=best_imp)

    nodes_list = list(G.nodes())
    node_true = {n: int(true_full[int(subset_orig[n])]) for n in nodes_list}
    node_pred = {n: int(pred_full[int(subset_orig[n])]) for n in nodes_list}

    sizes = [1400 if n == center_local else 850 for n in nodes_list]
    node_colors = [node_true[n] for n in nodes_list]

    imps = np.array([G[u][v]["imp"] for u, v in G.edges()])
    if imps.max() - imps.min() < 1e-9:
        widths = np.ones_like(imps) * 2.5
    else:
        widths = 0.8 + 5.2 * (imps - imps.min()) / (imps.max() - imps.min())

    pos = nx.spring_layout(G, seed=42, k=0.7)

    plt.figure(figsize=(10, 8))
    # vmin/vmax pin the colour map to the label set {0, 1}. Without them,
    # matplotlib normalises to the observed range, so a subgraph whose nodes
    # all have label 1 would be drawn in label 0's colour.
    nx.draw_networkx_nodes(G, pos, nodelist=nodes_list, node_size=sizes, node_color=node_colors,
                           cmap=plt.cm.Set2, vmin=0, vmax=1)

    # border style encodes prediction
    ax = plt.gca()
    for n, size in zip(nodes_list, sizes):
        x0, y0 = pos[n]
        ls = "--" if node_pred[n] == 1 else "-"
        ax.scatter([x0], [y0],
                   s=(size * 1.05),
                   facecolors="none",
                   edgecolors="black",
                   linewidths=2,
                   linestyles=ls)

    nx.draw_networkx_edges(G, pos, width=widths, arrows=True, arrowstyle="-|>", arrowsize=12, alpha=0.9)
    labels_local = {n: str(int(subset_orig[n])) for n in nodes_list}
    nx.draw_networkx_labels(G, pos, labels=labels_local, font_size=9)

    plt.title(f"Local GNN Explanation ({dataset_name}) | test node={node_idx_global}")
    plt.axis("off")
    plt.tight_layout()
    plt.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()

    print(f"\nSaved figure: {save_path}")
    print("\nSentence (explained node):")
    print(sentence_text)

    print("\nCaption (paste-ready):")
    print(
        f"Local explanation of the GNN prediction for a representative test instance in the {dataset_name} dataset. "
        "A k-hop ego-subgraph is extracted around the target node, and an edge-importance mask is learned by "
        "optimizing prediction fidelity with sparsity and entropy regularization (RexYing-style explainer). "
        "Only the highest-importance edges are visualized; thicker arrows indicate higher importance. "
        "Node color denotes the ground-truth label (0/1), while border style denotes the predicted label "
        "(solid: predicted 0; dashed: predicted 1)."
    )

# -------------------------
# 12) Run explainer + plot
# -------------------------
gnn.eval()
with torch.no_grad():
    logits_full = gnn(data.x, data.edge_index, data.edge_weight)

true_full = data.y.detach().cpu().numpy()
pred_full = logits_full.argmax(dim=1).detach().cpu().numpy()

node_idx = pick_test_node_prefer_correct(data, logits_full)
print(f"\n[Explainer] Using TEST node={node_idx} (correct preferred).")
print("Sentence:", texts[node_idx])

subset_orig, edge_index_sub, edge_mask, center_local, target_class, top_keep_idx = explain_node_edge_mask_stable(
    model=gnn,
    data=data,
    node_idx=node_idx,
    num_hops=EXPL_NUM_HOPS,     # if too sparse, set 3
    epochs=EXPL_EPOCHS,
    lr=EXPL_LR,
    lam_size=EXPL_LAM_SIZE,     # if no edges selected, reduce to 0.001
    lam_ent=EXPL_LAM_ENT,
    top_edges=EXPL_TOP_EDGES,
)

# Name the figure after the dataset actually loaded. It was hard-coded as
# "sadness_explainer_q1.png", which mislabels this Surprise run and overwrites
# any figure saved under that name.
explainer_png = f"{LABEL_COL.lower()}_explainer_q1.png"

plot_explanation_q1(
    subset_orig=subset_orig,
    edge_index_sub=edge_index_sub,
    edge_mask=edge_mask,
    center_local=center_local,
    top_keep_idx=top_keep_idx,
    true_full=true_full,
    pred_full=pred_full,
    node_idx_global=node_idx,
    sentence_text=texts[node_idx],
    save_path=explainer_png,
)

cleanup_cuda()
