In [33]:
# Cell 1: Basic Imports and Data Loading

import os
import json
import torch
import numpy as np
from pathlib import Path
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GINEConv, GATConv, GraphConv


# Load dataset
#file_path = 'dataset/ieee24/ieee24/processed_b/data.pt'
#file_path = 'dataset/ieee39/processed_b/data.pt'
#file_path = 'dataset/uk/processed_b/data.pt'
file_path = 'dataset/ieee118/processed_b/data.pt'

loaded_data = torch.load(file_path, map_location=torch.device('cpu'))


# Check dataset structure
print("Loaded data type:", type(loaded_data))
print("Length of dataset tuple:", len(loaded_data))

# Inspect first element (summary only)
print("\nFirst element type:", type(loaded_data[0]))

# Extract metadata dictionary
metadata_dict = loaded_data[1]
print("\nMetadata keys:", list(metadata_dict.keys()))

# Preview metadata values (first 10 entries)
for key, val in metadata_dict.items():
    preview = val[:10] if hasattr(val, '__len__') else "N/A"
    print(f"{key}: {type(val)}, first 10: {preview}")


  loaded_data = torch.load(file_path, map_location=torch.device('cpu'))


Loaded data type: <class 'tuple'>
Length of dataset tuple: 2

First element type: <class 'torch_geometric.data.data.Data'>

Metadata keys: ['x', 'edge_index', 'edge_attr', 'y', 'edge_mask', 'idx']
x: <class 'torch.Tensor'>, first 10: tensor([   0,  118,  236,  354,  472,  590,  708,  826,  944, 1062])
edge_index: <class 'torch.Tensor'>, first 10: tensor([   0,  370,  740, 1110, 1480, 1850, 2220, 2590, 2960, 3330])
edge_attr: <class 'torch.Tensor'>, first 10: tensor([   0,  370,  740, 1110, 1480, 1850, 2220, 2590, 2960, 3330])
y: <class 'torch.Tensor'>, first 10: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edge_mask: <class 'torch.Tensor'>, first 10: tensor([   0,  370,  740, 1110, 1480, 1850, 2220, 2590, 2960, 3330])
idx: <class 'torch.Tensor'>, first 10: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


In [34]:
# Cell 2: Add strict validation


def get_subgraph(data_flat, meta_dict, i):
    """
    data_flat: The big flattened Data object (loaded_data[0])
    meta_dict: The dictionary of offsets (loaded_data[1])
    i        : Index of the subgraph we want to reconstruct

    returns: a PyG Data object representing the i-th subgraph
    """
    # 1) Node offsets
    x_start = meta_dict['x'][i].item()
    x_end   = meta_dict['x'][i+1].item()
    x_i = data_flat.x[x_start:x_end]
    
    # 2) Edge offsets
    e_start = meta_dict['edge_index'][i].item()
    e_end   = meta_dict['edge_index'][i+1].item()
    edge_index_i = data_flat.edge_index[:, e_start:e_end]
    edge_attr_i  = data_flat.edge_attr[e_start:e_end]
    
    # 3) Load TRUE binary edge labels (explanation_mask)
    # --------------------------------------------------
    edge_mask_i = data_flat.edge_mask[e_start:e_end].float()  # Convert to float
    
    # NEW: Strict validation
    if not torch.all(torch.isin(edge_mask_i, torch.tensor([0., 1.]))):
        print(f"BAD SUBGRAPH {i}:")
        print("Unique values:", edge_mask_i.unique())
        print("Edge indices:", edge_index_i)
        raise ValueError("Edge mask contains non-binary values")
    
    # 4) Graph label (binary or multi-class)
    y_i = data_flat.y[i]
    
    # 5) Build a new Data object
    subgraph_i = Data(
        x=x_i,
        edge_index=edge_index_i,
        edge_attr=edge_attr_i,
        y=y_i.unsqueeze(0),  # Keep graph-level label if needed
        edge_mask=edge_mask_i  # Add binary edge labels
    )
    
    return subgraph_i

# Test subgraph reconstruction
i_test = 0
subgraph_0 = get_subgraph(loaded_data[0], loaded_data[1], i_test)
print("Subgraph 0:")
print(subgraph_0)
print("Edge mask values:", subgraph_0.edge_mask.unique())  # Should be [0., 1.]


Subgraph 0:
Data(x=[118, 3], edge_index=[2, 370], edge_attr=[370, 4], y=[1, 1], edge_mask=[370])
Edge mask values: tensor([0.])


In [35]:
#Cell 3: Verification 

# Sample dataset (replace with your actual dataset)
dataset = [
    Data(y=torch.tensor(1), edge_mask=torch.tensor([0, 1, 0]), num_edges=3),  # Category A or C
    Data(y=torch.tensor(0), edge_mask=torch.tensor([0, 0, 0]), num_edges=3),  # Category B or D
    Data(y=torch.tensor(1), edge_mask=torch.tensor([0, 0, 1]), num_edges=3),  # Category A or C
    Data(y=torch.tensor(0), edge_mask=torch.tensor([0, 0, 0]), num_edges=3)   # Category B or D
]

def verify_edge_mask_coverage(dataset):
    """Check if edge_mask is defined for all graphs with cascading failures (y=1)."""
    print("Verifying Edge Mask Coverage")
    has_cascading = 0
    has_edge_mask_defined = 0
    
    for i, graph in enumerate(dataset):
        if graph.y.item() == 1:  # Graphs with cascading failures (Categories A and C)
            has_cascading += 1
            if graph.edge_mask is not None and len(graph.edge_mask) == graph.num_edges:
                has_edge_mask_defined += 1
            else:
                print(f"Graph {i}: Missing or incomplete edge_mask for cascading failure graph.")
    
    print(f"Graphs with cascading failures: {has_cascading}")
    print(f"Graphs with defined edge_mask: {has_edge_mask_defined}")
    if has_cascading == has_edge_mask_defined:
        print(" Edge mask coverage is complete for cascading failure graphs.")
    else:
        print("Edge mask is missing or incomplete for some cascading failure graphs.")

def check_edge_label_distribution(dataset):
    """Examine the distribution of edge labels (1s and 0s) across graphs."""
    print("\n Checking Edge Label Distribution")
    total_edges = 0
    tripped_edges = 0
    
    for i, graph in enumerate(dataset):
        edge_mask = graph.edge_mask
        num_tripped = edge_mask.sum().item()
        total_edges += len(edge_mask)
        tripped_edges += num_tripped
        print(f"Graph {i}: {num_tripped} tripped edges (1s), {len(edge_mask) - num_tripped} non-tripped (0s)")
    
    print(f"Total edges: {total_edges}")
    print(f"Tripped edges (1s): {tripped_edges}")
    print(f"Non-tripped edges (0s): {total_edges - tripped_edges}")
    print(f"Percentage of tripped edges: {(tripped_edges / total_edges * 100):.2f}%")

def validate_graph_edge_consistency(dataset):
    """Ensure edge_mask aligns with graph-level labels (y)."""
    print("\n Validating Graph-Edge Label Consistency ")
    all_valid = True
    
    for i, graph in enumerate(dataset):
        edge_mask = graph.edge_mask
        y = graph.y.item()
        
        if y == 1:  # Categories A and C (cascading failures)
            if edge_mask.sum() == 0:
                print(f"Graph {i}: Inconsistent - y=1 but no tripped edges in edge_mask.")
                all_valid = False
            else:
                print(f"Graph {i}: Consistent - y=1 and tripped edges present.")
        elif y == 0:  # Categories B and D (no cascading failures)
            if edge_mask.sum() > 0:
                print(f"Graph {i}: Inconsistent - y=0 but tripped edges present in edge_mask.")
                all_valid = False
            else:
                print(f"Graph {i}: Consistent - y=0 and no tripped edges.")
    
    if all_valid:
        print("All graphs have consistent edge_mask and y labels.")
    else:
        print(" Some graphs have inconsistencies between edge_mask and y.")

# Run the verifications
verify_edge_mask_coverage(dataset)
check_edge_label_distribution(dataset)
validate_graph_edge_consistency(dataset)

Verifying Edge Mask Coverage
Graphs with cascading failures: 2
Graphs with defined edge_mask: 2
 Edge mask coverage is complete for cascading failure graphs.

 Checking Edge Label Distribution
Graph 0: 1 tripped edges (1s), 2 non-tripped (0s)
Graph 1: 0 tripped edges (1s), 3 non-tripped (0s)
Graph 2: 1 tripped edges (1s), 2 non-tripped (0s)
Graph 3: 0 tripped edges (1s), 3 non-tripped (0s)
Total edges: 12
Tripped edges (1s): 2
Non-tripped edges (0s): 10
Percentage of tripped edges: 16.67%

 Validating Graph-Edge Label Consistency 
Graph 0: Consistent - y=1 and tripped edges present.
Graph 1: Consistent - y=0 and no tripped edges.
Graph 2: Consistent - y=1 and tripped edges present.
Graph 3: Consistent - y=0 and no tripped edges.
All graphs have consistent edge_mask and y labels.


In [36]:
# Cell 4: Create a PyTorch Dataset for our subgraphs

class PowerGraphDataset(Dataset):
    def __init__(self, data_flat, meta_dict, indices=None, filter_category_A=True):
        """
        data_flat:  The giant flattened Data object
        meta_dict:  Dictionary of offsets
        indices:    Subgraph indices to include
        filter_category_A: If True, only include graphs with cascading failures (edge_mask != 0)
        """
        super().__init__()
        self.data_flat = data_flat
        self.meta_dict = meta_dict
        self.filter_category_A = filter_category_A
        
        if indices is None:
            # Default to all graphs (0 to num_subgraphs-1)
            self.indices = range(len(meta_dict['x']) - 1)
        else:
            self.indices = indices
        
        # Filter to Category A (DNS > 0 with cascading failures)
        if self.filter_category_A:
            self.indices = self._filter_category_A()
    
    def _filter_category_A(self):
        """Retain indices where edge_mask has at least one failure (1)"""
        valid_indices = []
        for idx in self.indices:
            e_start = self.meta_dict['edge_index'][idx].item()
            e_end = self.meta_dict['edge_index'][idx+1].item()
            edge_mask = self.data_flat.edge_mask[e_start:e_end]  # Use edge_mask
            if edge_mask.sum() > 0:  # At least one failed edge
                valid_indices.append(idx)
        return valid_indices
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        subgraph_id = self.indices[idx]
        return get_subgraph(self.data_flat, self.meta_dict, subgraph_id)

# Create dataset (only Category A graphs)
full_dataset = PowerGraphDataset(loaded_data[0], loaded_data[1], filter_category_A=True)
print("Total subgraphs in full_dataset:", len(full_dataset))

Total subgraphs in full_dataset: 1167


In [37]:
# After creating full_dataset (Cell 5):
all_edge_masks = torch.cat([batch.edge_mask for batch in full_dataset])
num_positive = all_edge_masks.sum().item()
num_negative = len(all_edge_masks) - num_positive

print(f"Edge label distribution:")
print(f"- Failed edges (1): {num_positive} ({num_positive / len(all_edge_masks):.2%})")
print(f"- Stable edges (0): {num_negative} ({num_negative / len(all_edge_masks):.2%})")

Edge label distribution:
- Failed edges (1): 2360.0 (0.55%)
- Stable edges (0): 428464.0 (99.45%)


In [38]:
# Cell 6: Train/Val/Test split & DataLoaders (with class-aware sampling)

import numpy as np
import torch
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch.utils.data import WeightedRandomSampler

# 1) Handle extreme class imbalance
# --------------------------------------------------------
num_subgraphs = len(full_dataset)
train_size   = int(0.8 * num_subgraphs)
val_size     = int(0.1 * num_subgraphs)
test_size    = num_subgraphs - train_size - val_size

# 2) Random split (with fixed seed for reproducibility)
indices = np.arange(num_subgraphs)
np.random.seed(42)
np.random.shuffle(indices)

train_idx = indices[:train_size]
val_idx   = indices[train_size:train_size+val_size]
test_idx  = indices[train_size+val_size:]

train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
val_dataset   = torch.utils.data.Subset(full_dataset, val_idx)
test_dataset  = torch.utils.data.Subset(full_dataset, test_idx)

print(f"Train set size: {len(train_dataset)}")
print(f"Val   set size: {len(val_dataset)}")
print(f"Test  set size: {len(test_dataset)}")

# 3) Build PyG DataLoaders with class-aware sampling
batch_size = 128

# Compute per-graph positive‐edge fraction
graph_pos_frac = [g.edge_mask.float().mean().item() for g in full_dataset]
# Inverse weighting: draw more from graphs with fewer positives
graph_weights  = [1.0 / (p if p > 0 else 1e-4) for p in graph_pos_frac]
train_weights  = [graph_weights[i] for i in train_idx]

train_sampler = WeightedRandomSampler(
    weights=train_weights,
    num_samples=len(train_idx),
    replacement=True
)

train_loader = PyGDataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler
)
val_loader = PyGDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = PyGDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("DataLoaders created with batch_size =", batch_size)


Train set size: 933
Val   set size: 116
Test  set size: 118
DataLoaders created with batch_size = 128


In [39]:
# Cell 6‑bis – compute pos_weight once (run AFTER Cell 6)

# Define the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─────────────────────────────────────────────────────────
#   Needed globally for class‑balanced BCE loss
# ─────────────────────────────────────────────────────────
all_edge_masks = torch.cat([g.edge_mask for g in full_dataset])
num_pos = all_edge_masks.sum().item()
num_neg = len(all_edge_masks) - num_pos
pos_weight = torch.tensor([(num_neg / num_pos) ** 0.5], device=device) 

print(f"pos_weight = {pos_weight.item():.1f}  (used by BCEWithLogitsLoss)")


pos_weight = 13.5  (used by BCEWithLogitsLoss)


In [40]:
# Cell 7: Model Architectures (aligned to enhanced pipeline)

import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log
from torch_geometric.nn import GINEConv, GATConv, NNConv

# ───── Helpers ────────────────────────────────────
# assumes you’ve already computed pos_weight (in Cell 6‑bis) and have it on the correct device
BIAS_INIT = -log(pos_weight.item())

def make_input_norm(n_feats):
    return nn.LayerNorm(n_feats, elementwise_affine=True)

# ───── GINE ───────────────────────────────────────
class GINEBasedClassifier(nn.Module):
    def __init__(self, in_channels_node=3, in_channels_edge=4, hidden_dim=64):
        super().__init__()
        # per‑node & per‑edge LayerNorm
        self.node_norm = make_input_norm(in_channels_node)
        self.edge_norm = make_input_norm(in_channels_edge)

        # initial projection
        self.fc_in = nn.Linear(in_channels_node, hidden_dim)
        # GINE conv needs an MLP
        mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim)
        )
        self.conv1 = GINEConv(nn=mlp, edge_dim=in_channels_edge)
        self.conv2 = GINEConv(nn=mlp, edge_dim=in_channels_edge)

        # edge‑prediction MLP
        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden_dim + in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1)
        )
        # bias init for class imbalance
        with torch.no_grad():
            self.edge_mlp[-1].bias.fill_(BIAS_INIT)

    def forward(self, x, edge_index, edge_attr):
        x         = self.node_norm(x)
        edge_attr = self.edge_norm(edge_attr)

        h = F.relu(self.fc_in(x))
        h = F.relu(self.conv1(h, edge_index, edge_attr))
        h = self.conv2(h, edge_index, edge_attr)

        h_u, h_v = h[edge_index[0]], h[edge_index[1]]
        return self.edge_mlp(torch.cat([h_u, h_v, edge_attr], dim=1)).squeeze()

# ───── GAT ────────────────────────────────────────
class GATBasedClassifier(nn.Module):
    def __init__(self, in_channels_node=3, in_channels_edge=4, hidden_dim=64):
        super().__init__()
        self.node_norm = make_input_norm(in_channels_node)
        self.edge_norm = make_input_norm(in_channels_edge)

        self.fc_in = nn.Linear(in_channels_node, hidden_dim)
        # 2 heads of size hidden_dim//4 each, concatenated back to hidden_dim
        self.conv1 = GATConv(hidden_dim, hidden_dim // 4, heads=4, concat=True)
        self.conv2 = GATConv(hidden_dim, hidden_dim // 4, heads=4, concat=True)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden_dim + in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1)
        )
        with torch.no_grad():
            self.edge_mlp[-1].bias.fill_(BIAS_INIT)

    def forward(self, x, edge_index, edge_attr):
        x         = self.node_norm(x)
        edge_attr = self.edge_norm(edge_attr)

        h = F.relu(self.fc_in(x))
        h = F.relu(self.conv1(h, edge_index))
        h = self.conv2(h, edge_index)

        h_u, h_v = h[edge_index[0]], h[edge_index[1]]
        return self.edge_mlp(torch.cat([h_u, h_v, edge_attr], dim=1)).squeeze()

# ───── Edge‑aware NNConv ───────────────────────────
class EdgeAwareGraphConvClassifier(nn.Module):
    def __init__(self, in_channels_node=3, in_channels_edge=4, hidden_dim=64):
        super().__init__()
        self.node_norm = make_input_norm(in_channels_node)
        self.edge_norm = make_input_norm(in_channels_edge)

        self.fc_in = nn.Linear(in_channels_node, hidden_dim)

        # build the edge network for NNConv
        edge_net = nn.Sequential(
            nn.Linear(in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * hidden_dim)
        )
        self.conv1 = NNConv(hidden_dim, hidden_dim, edge_net)
        self.conv2 = NNConv(hidden_dim, hidden_dim, edge_net)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden_dim + in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1)
        )
        with torch.no_grad():
            self.edge_mlp[-1].bias.fill_(BIAS_INIT)

    def forward(self, x, edge_index, edge_attr):
        x         = self.node_norm(x)
        edge_attr = self.edge_norm(edge_attr)

        h = F.relu(self.fc_in(x))
        h = F.relu(self.conv1(h, edge_index, edge_attr))
        h = self.conv2(h, edge_index, edge_attr)

        h_u, h_v = h[edge_index[0]], h[edge_index[1]]
        return self.edge_mlp(torch.cat([h_u, h_v, edge_attr], dim=1)).squeeze()


In [41]:
# ────────────────────────────────────────────
# 📍 Cell 8 – Training + Evaluation utilities
# ────────────────────────────────────────────

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ---- metrics ----------------------------------------------------
def calculate_metrics(preds, labels):
    tp = ((preds == 1) & (labels == 1)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    prec = tp / (tp + fp + 1e-8)
    rec  = tp / (tp + fn + 1e-8)
    f1   = 2 * prec * rec / (prec + rec + 1e-8)
    return prec, rec, f1

# ---- evaluation -------------------------------------------------
def evaluate(model, loader, device, criterion):
    model.eval()
    total_loss, all_logits, all_labels = 0.0, [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr)
            loss   = criterion(logits, batch.edge_mask.float())
            total_loss += loss.item()
            all_logits.append(logits)
            all_labels.append(batch.edge_mask.float())

    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)
    probs = torch.sigmoid(all_logits)

    # find best threshold on validation
    best_f1, best_thr = 0, 0
    for thr in np.logspace(-6, -1, 30):
        f1 = calculate_metrics((probs > thr).long(), all_labels.long())[2]
        if f1 > best_f1:
            best_f1, best_thr = f1, thr

    prec, rec, f1 = calculate_metrics((probs > best_thr).long(), all_labels.long())
    return total_loss / len(loader), {'precision': prec, 'recall': rec, 'f1': f1}

# ---- training + full history -----------------------------------
def train_and_evaluate_model(model, model_name,
                             train_loader, val_loader,
                             device,
                             num_epochs=30,
                             patience=7,
                             lr=5e-4):
    """
    Train `model`, evaluate on `val_loader` each epoch, 
    and return a history dict with lists for train_loss, val_loss, precision, recall, f1.
    """
    # you can swap in BCEWithLogitsLoss(pos_weight=...) if you like
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    history = {
        'train_loss': [],
        'val_loss':   [],
        'precision':  [],
        'recall':     [],
        'f1':         []
    }

    best_f1, wait = 0.0, 0
    for epoch in range(1, num_epochs+1):
        # --- training pass ---
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr)

            # in‑batch sampler for imbalance
            pos = batch.edge_mask.nonzero(as_tuple=True)[0]
            neg = (batch.edge_mask == 0).nonzero(as_tuple=True)[0]
            k   = min(len(neg), 5 * len(pos))
            if len(pos)>0 and k>0:
                neg_idx = neg[torch.randperm(len(neg), device=neg.device)[:k]]
                sel = torch.cat([pos, neg_idx])
            else:
                sel = torch.arange(len(batch.edge_mask), device=logits.device)

            loss = criterion(logits[sel], batch.edge_mask.float()[sel])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # --- validation pass ---
        val_loss, val_metrics = evaluate(model, val_loader, device, criterion)
        history['val_loss'].append(val_loss)
        history['precision'].append(val_metrics['precision'])
        history['recall'].append(val_metrics['recall'])
        history['f1'].append(val_metrics['f1'])

        print(f"{model_name} | Ep {epoch:02d}  Train {avg_train_loss:.4f}  "
              f"ValF1 {val_metrics['f1']:.4f}")

        # early stopping
        if val_metrics['f1'] > best_f1:
            best_f1, wait = val_metrics['f1'], 0
        else:
            wait += 1
            if wait >= patience:
                print(f"⏹️ early stop at {epoch}, best F1={best_f1:.4f}")
                break

    return history


In [42]:
# 📍 Cell 9 – Fast 2‑fold run across models with full metrics saving

import json
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import WeightedRandomSampler
from torch_geometric.loader import DataLoader as PyGDataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def run_kfold_full(model_name, model_class, dataset, indices,
                   K=2, num_epochs=100, patience=7):
    """
    Runs K‑fold CV, returns:
      - fold_histories: list of per-fold history dicts
      - summary: dict of mean±std for precision, recall, f1 across folds
    """
    kf = KFold(n_splits=K, shuffle=True, random_state=42)
    fold_histories = []
    all_prec, all_rec, all_f1 = [], [], []

    for fold, (train_idx, val_idx) in enumerate(kf.split(indices), 1):
        print(f"\n[{model_name}] Fold {fold}/{K}")
        # build subsets
        tr_ids = [indices[i] for i in train_idx]
        vl_ids = [indices[i] for i in val_idx]
        tr_ds = torch.utils.data.Subset(dataset, tr_ids)
        vl_ds = torch.utils.data.Subset(dataset, vl_ids)

        # weighted sampler to handle imbalance
        frac_pos = [g.edge_mask.float().mean().item() for g in dataset]
        weights = [1.0/(p if p>0 else 1e-4) for p in frac_pos]
        sampler = WeightedRandomSampler([weights[i] for i in tr_ids],
                                        len(tr_ids), replacement=True)

        tr_loader = PyGDataLoader(tr_ds, batch_size=128, sampler=sampler)
        vl_loader = PyGDataLoader(vl_ds, batch_size=128, shuffle=False)

        # init model
        model = model_class().to(device)
        # train and get history
        history = train_and_evaluate_model(
            model=model,
            model_name=f"{model_name}-F{fold}",
            train_loader=tr_loader,
            val_loader=vl_loader,
            device=device,
            num_epochs=num_epochs,
            patience=patience
        )
        # capture final fold metrics (last epoch on validation)
        final_prec = history['precision'][-1]
        final_rec  = history['recall'][-1]
        final_f1   = history['f1'][-1]

        all_prec.append(final_prec)
        all_rec.append(final_rec)
        all_f1.append(final_f1)

        # store everything
        fold_histories.append({
            'fold': fold,
            'train_loss': history['train_loss'],
            'val_loss':   history['val_loss'],
            'precision':  history['precision'],
            'recall':     history['recall'],
            'f1':         history['f1']
        })

    # compute mean±std
    summary = {
        'precision': {'mean': float(np.mean(all_prec)), 'std': float(np.std(all_prec))},
        'recall':    {'mean': float(np.mean(all_rec)),  'std': float(np.std(all_rec))},
        'f1':        {'mean': float(np.mean(all_f1)),   'std': float(np.std(all_f1))}
    }

    return fold_histories, summary

# ==== run all three models ====================================================
model_classes = {
    "GINE":        GINEBasedClassifier,
    "GAT":         GATBasedClassifier,
    "EdgeAwareGC": EdgeAwareGraphConvClassifier
}

all_results = {}
for name, cls in model_classes.items():
    histories, summary = run_kfold_full(
        model_name=name,
        model_class=cls,
        dataset=full_dataset,
        indices=train_idx,
        K=2,
        num_epochs=100,
        patience=7
    )
    all_results[name] = {
        'folds': histories,
        'summary': summary
    }

# save full JSON
with open("base_full_results_118_pbs.json", "w") as fp:
    json.dump(all_results, fp, indent=2)

# print summary
print("\n📊  Fast 2‑Fold Summary (mean ± std)")
for m, metrics in all_results.items():
    sp = metrics['summary']
    print(f"{m}:  F1 {sp['f1']['mean']:.4f}±{sp['f1']['std']:.4f}, "
          f"Prec {sp['precision']['mean']:.4f}±{sp['precision']['std']:.4f}, "
          f"Rec {sp['recall']['mean']:.4f}±{sp['recall']['std']:.4f}")


Using device: cuda

[GINE] Fold 1/2
GINE-F1 | Ep 01  Train 0.4746  ValF1 0.0109
GINE-F1 | Ep 02  Train 0.4164  ValF1 0.0128
GINE-F1 | Ep 03  Train 0.3743  ValF1 0.0215
GINE-F1 | Ep 04  Train 0.3457  ValF1 0.0203
GINE-F1 | Ep 05  Train 0.3025  ValF1 0.0306
GINE-F1 | Ep 06  Train 0.2610  ValF1 0.0342
GINE-F1 | Ep 07  Train 0.2380  ValF1 0.0641
GINE-F1 | Ep 08  Train 0.2073  ValF1 0.0688
GINE-F1 | Ep 09  Train 0.1910  ValF1 0.1197
GINE-F1 | Ep 10  Train 0.1712  ValF1 0.1506
GINE-F1 | Ep 11  Train 0.1601  ValF1 0.1777
GINE-F1 | Ep 12  Train 0.1411  ValF1 0.1804
GINE-F1 | Ep 13  Train 0.1262  ValF1 0.1805
GINE-F1 | Ep 14  Train 0.1210  ValF1 0.2306
GINE-F1 | Ep 15  Train 0.1123  ValF1 0.2105
GINE-F1 | Ep 16  Train 0.1054  ValF1 0.2554
GINE-F1 | Ep 17  Train 0.0888  ValF1 0.2832
GINE-F1 | Ep 18  Train 0.0950  ValF1 0.2529
GINE-F1 | Ep 19  Train 0.0916  ValF1 0.2832
GINE-F1 | Ep 20  Train 0.0780  ValF1 0.2605
GINE-F1 | Ep 21  Train 0.0905  ValF1 0.3228
GINE-F1 | Ep 22  Train 0.0764  ValF1 0.3