In [24]:
# Cell 1: Basic Imports and Data Loading

import os
import json
import torch
import numpy as np
from pathlib import Path
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GINEConv, GATConv, GraphConv

# Load dataset
#file_path = 'dataset/ieee118/processed_b/data.pt'
#file_path = 'dataset/ieee24/ieee24/processed_b/data.pt'
#file_path = 'dataset/ieee39/processed_b/data.pt'
file_path = 'dataset/uk/processed_b/data.pt'

loaded_data = torch.load(file_path, map_location=torch.device('cpu'))

# Check dataset structure
print("Loaded data type:", type(loaded_data))
print("Length of dataset tuple:", len(loaded_data))

# Inspect first element (summary only)
print("\nFirst element type:", type(loaded_data[0]))

# Extract metadata dictionary
metadata_dict = loaded_data[1]
print("\nMetadata keys:", list(metadata_dict.keys()))

# Preview metadata values (first 10 entries)
for key, val in metadata_dict.items():
    preview = val[:10] if hasattr(val, '__len__') else "N/A"
    print(f"{key}: {type(val)}, first 10: {preview}")


Loaded data type: <class 'tuple'>
Length of dataset tuple: 2

First element type: <class 'torch_geometric.data.data.Data'>

Metadata keys: ['x', 'edge_index', 'edge_attr', 'y', 'edge_mask', 'idx']
x: <class 'torch.Tensor'>, first 10: tensor([  0,  29,  58,  87, 116, 145, 174, 203, 232, 261])
edge_index: <class 'torch.Tensor'>, first 10: tensor([   0,  196,  392,  588,  784,  980, 1176, 1372, 1568, 1764])
edge_attr: <class 'torch.Tensor'>, first 10: tensor([   0,  196,  392,  588,  784,  980, 1176, 1372, 1568, 1764])
y: <class 'torch.Tensor'>, first 10: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
edge_mask: <class 'torch.Tensor'>, first 10: tensor([   0,  196,  392,  588,  784,  980, 1176, 1372, 1568, 1764])
idx: <class 'torch.Tensor'>, first 10: tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])


  loaded_data = torch.load(file_path, map_location=torch.device('cpu'))


In [25]:
# Cell 2: Enhanced get_subgraph with new features

import networkx as nx
import torch
from torch_geometric.utils import degree, to_networkx
import numpy as np
from torch_geometric.data import Data

def get_subgraph(data_flat, meta_dict, i):
    """
    Enhanced function to reconstruct the i-th subgraph with additional features.

    data_flat: The big flattened Data object (loaded_data[0])
    meta_dict: The dictionary of offsets (loaded_data[1])
    i        : Index of the subgraph we want to reconstruct

    returns: a PyG Data object representing the i-th subgraph
    """
    # 1) Node offsets
    x_start = meta_dict['x'][i].item()
    x_end   = meta_dict['x'][i+1].item()
    x_i = data_flat.x[x_start:x_end]
    
    # 2) Edge offsets
    e_start = meta_dict['edge_index'][i].item()
    e_end   = meta_dict['edge_index'][i+1].item()
    edge_index_i = data_flat.edge_index[:, e_start:e_end]
    edge_attr_i  = data_flat.edge_attr[e_start:e_end]
    
    # 3) Load TRUE binary edge labels (explanation_mask)
    edge_mask_i = data_flat.edge_mask[e_start:e_end].float()
    
    # Strict validation for edge_mask
    if not torch.all(torch.isin(edge_mask_i, torch.tensor([0., 1.]))):
        print(f"BAD SUBGRAPH {i}:")
        print("Unique values:", edge_mask_i.unique())
        print("Edge indices:", edge_index_i)
        raise ValueError("Edge mask contains non-binary values")
    
    # =====================
    # New Node Features
    # =====================
    # Convert PyG edge index to NetworkX graph
    G = to_networkx(Data(edge_index=edge_index_i, num_nodes=x_i.shape[0]), to_undirected=True)

    # 1. Node Betweenness Centrality
    node_betweenness_dict = nx.betweenness_centrality(G)
    node_betweenness = torch.tensor([node_betweenness_dict.get(n, 0.0) for n in range(x_i.shape[0])], dtype=torch.float)

    # 2. Node Degree (Fixed `.reshape(-1)`)
    node_deg = degree(edge_index_i.reshape(-1), num_nodes=x_i.shape[0], dtype=torch.float)
    
    # 3. Corrected Voltage Magnitude
    voltage_mag = x_i[:, 2] + 1.0  # Convert deviation to absolute voltage
    voltage_dev = torch.abs(voltage_mag - 1.0).unsqueeze(1)
    
    # Concatenate new node features
    x_i = torch.cat([x_i, node_betweenness.unsqueeze(1), node_deg.unsqueeze(1), voltage_dev], dim=1)
    
    # =====================
    # New Edge Features
    # =====================
    # 1. Edge Betweenness Centrality
    edge_bc_dict = nx.edge_betweenness_centrality(G)
    #edge_bc = torch.tensor([edge_bc_dict.get(tuple(e.tolist()), 0.0) for e in edge_index_i.T], dtype=torch.float)
    edge_bc = torch.tensor([edge_bc_dict.get(tuple(sorted(e.tolist())), 0.0)
                        for e in edge_index_i.T], dtype=torch.float)


    # 2. Load Percentage (P / lr)
    P = edge_attr_i[:, 0]  # Active power
    lr = edge_attr_i[:, 3]  # Line rating
    load_pct = (P / (lr + 1e-8)).unsqueeze(1)  # Add epsilon to avoid division by zero
    
    # 3. Electrical Betweenness (simplified)
    Q = edge_attr_i[:, 1]  # Reactive power
    elec_betweenness = (torch.abs(P) + torch.abs(Q)).unsqueeze(1)
    
    # Concatenate new edge features
    edge_attr_i = torch.cat([
        edge_attr_i, 
        edge_bc.unsqueeze(1),
        load_pct,
        elec_betweenness
    ], dim=1)
    
    # 4) Graph label (binary or multi-class)
    y_i = data_flat.y[i]
    
    # 5) Build a new Data object
    subgraph_i = Data(
        x=x_i,
        edge_index=edge_index_i,
        edge_attr=edge_attr_i,
        y=y_i.unsqueeze(0),  # Keep graph-level label if needed
        edge_mask=edge_mask_i  # Add binary edge labels
    )
    
    return subgraph_i

# Test subgraph reconstruction
i_test = 0
subgraph_0 = get_subgraph(loaded_data[0], loaded_data[1], i_test)
print("Subgraph 0:")
print(subgraph_0)
print("Edge mask values:", subgraph_0.edge_mask.unique())  # Should be [0., 1.]


Subgraph 0:
Data(x=[29, 6], edge_index=[2, 196], edge_attr=[196, 7], y=[1, 1], edge_mask=[196])
Edge mask values: tensor([0.])


In [26]:
#Cell 3: Verification 

# Sample dataset (replace with your actual dataset)
dataset = [
    Data(y=torch.tensor(1), edge_mask=torch.tensor([0, 1, 0]), num_edges=3),  # Category A or C
    Data(y=torch.tensor(0), edge_mask=torch.tensor([0, 0, 0]), num_edges=3),  # Category B or D
    Data(y=torch.tensor(1), edge_mask=torch.tensor([0, 0, 1]), num_edges=3),  # Category A or C
    Data(y=torch.tensor(0), edge_mask=torch.tensor([0, 0, 0]), num_edges=3)   # Category B or D
]

def verify_edge_mask_coverage(dataset):
    """Check if edge_mask is defined for all graphs with cascading failures (y=1)."""
    print("Verifying Edge Mask Coverage")
    has_cascading = 0
    has_edge_mask_defined = 0
    
    for i, graph in enumerate(dataset):
        if graph.y.item() == 1:  # Graphs with cascading failures (Categories A and C)
            has_cascading += 1
            if graph.edge_mask is not None and len(graph.edge_mask) == graph.num_edges:
                has_edge_mask_defined += 1
            else:
                print(f"Graph {i}: Missing or incomplete edge_mask for cascading failure graph.")
    
    print(f"Graphs with cascading failures: {has_cascading}")
    print(f"Graphs with defined edge_mask: {has_edge_mask_defined}")
    if has_cascading == has_edge_mask_defined:
        print(" Edge mask coverage is complete for cascading failure graphs.")
    else:
        print("Edge mask is missing or incomplete for some cascading failure graphs.")

def check_edge_label_distribution(dataset):
    """Examine the distribution of edge labels (1s and 0s) across graphs."""
    print("\n Checking Edge Label Distribution")
    total_edges = 0
    tripped_edges = 0
    
    for i, graph in enumerate(dataset):
        edge_mask = graph.edge_mask
        num_tripped = edge_mask.sum().item()
        total_edges += len(edge_mask)
        tripped_edges += num_tripped
        print(f"Graph {i}: {num_tripped} tripped edges (1s), {len(edge_mask) - num_tripped} non-tripped (0s)")
    
    print(f"Total edges: {total_edges}")
    print(f"Tripped edges (1s): {tripped_edges}")
    print(f"Non-tripped edges (0s): {total_edges - tripped_edges}")
    print(f"Percentage of tripped edges: {(tripped_edges / total_edges * 100):.2f}%")

def validate_graph_edge_consistency(dataset):
    """Ensure edge_mask aligns with graph-level labels (y)."""
    print("\n Validating Graph-Edge Label Consistency ")
    all_valid = True
    
    for i, graph in enumerate(dataset):
        edge_mask = graph.edge_mask
        y = graph.y.item()
        
        if y == 1:  # Categories A and C (cascading failures)
            if edge_mask.sum() == 0:
                print(f"Graph {i}: Inconsistent - y=1 but no tripped edges in edge_mask.")
                all_valid = False
            else:
                print(f"Graph {i}: Consistent - y=1 and tripped edges present.")
        elif y == 0:  # Categories B and D (no cascading failures)
            if edge_mask.sum() > 0:
                print(f"Graph {i}: Inconsistent - y=0 but tripped edges present in edge_mask.")
                all_valid = False
            else:
                print(f"Graph {i}: Consistent - y=0 and no tripped edges.")
    
    if all_valid:
        print("All graphs have consistent edge_mask and y labels.")
    else:
        print(" Some graphs have inconsistencies between edge_mask and y.")

# Run the verifications
verify_edge_mask_coverage(dataset)
check_edge_label_distribution(dataset)
validate_graph_edge_consistency(dataset)

Verifying Edge Mask Coverage
Graphs with cascading failures: 2
Graphs with defined edge_mask: 2
 Edge mask coverage is complete for cascading failure graphs.

 Checking Edge Label Distribution
Graph 0: 1 tripped edges (1s), 2 non-tripped (0s)
Graph 1: 0 tripped edges (1s), 3 non-tripped (0s)
Graph 2: 1 tripped edges (1s), 2 non-tripped (0s)
Graph 3: 0 tripped edges (1s), 3 non-tripped (0s)
Total edges: 12
Tripped edges (1s): 2
Non-tripped edges (0s): 10
Percentage of tripped edges: 16.67%

 Validating Graph-Edge Label Consistency 
Graph 0: Consistent - y=1 and tripped edges present.
Graph 1: Consistent - y=0 and no tripped edges.
Graph 2: Consistent - y=1 and tripped edges present.
Graph 3: Consistent - y=0 and no tripped edges.
All graphs have consistent edge_mask and y labels.


In [27]:
# Cell 4: Create a PyTorch Dataset for our subgraphs

class PowerGraphDataset(Dataset):
    def __init__(self, data_flat, meta_dict, indices=None, filter_category_A=True):
        """
        data_flat:  The giant flattened Data object
        meta_dict:  Dictionary of offsets
        indices:    Subgraph indices to include
        filter_category_A: If True, only include graphs with cascading failures (edge_mask != 0)
        """
        super().__init__()
        self.data_flat = data_flat
        self.meta_dict = meta_dict
        self.filter_category_A = filter_category_A
        
        if indices is None:
            self.indices = list(range(len(meta_dict['x']) - 1))
        else:
            self.indices = indices
            
        # Filter to Category A (DNS > 0 with cascading failures)
        if self.filter_category_A:
            self.indices = self._filter_category_A()
    
    def _filter_category_A(self):
        """Retain indices where edge_mask has at least one failure (1)"""
        valid_indices = []
        for idx in self.indices:
            e_start = self.meta_dict['edge_index'][idx].item()
            e_end = self.meta_dict['edge_index'][idx+1].item()
            edge_mask = self.data_flat.edge_mask[e_start:e_end]  # Use edge_mask
            if edge_mask.sum() > 0:  # At least one failed edge
                valid_indices.append(idx)
        return valid_indices
    
    def __len__(self):
        return len(self.indices)
    
    def __getitem__(self, idx):
        subgraph_id = self.indices[idx]
        return get_subgraph(self.data_flat, self.meta_dict, subgraph_id)

# Create dataset (only Category A graphs)
full_dataset = PowerGraphDataset(loaded_data[0], loaded_data[1], filter_category_A=True)
print("Total subgraphs in full_dataset:", len(full_dataset))

Total subgraphs in full_dataset: 4531


In [28]:
# After creating full_dataset (Cell 5):
all_edge_masks = torch.cat([batch.edge_mask for batch in full_dataset])
num_positive = all_edge_masks.sum().item()
num_negative = len(all_edge_masks) - num_positive

print(f"Edge label distribution:")
print(f"- Failed edges (1): {num_positive} ({num_positive / len(all_edge_masks):.2%})")
print(f"- Stable edges (0): {num_negative} ({num_negative / len(all_edge_masks):.2%})")

Edge label distribution:
- Failed edges (1): 28168.0 (3.18%)
- Stable edges (0): 856874.0 (96.82%)


In [29]:
# 📍  Diagnostic Cell – feature range check
import torch

def feature_stats(tensor, name):
    print(f"{name:>18}:  min {tensor.min():>8.4f}   max {tensor.max():>8.4f}   mean {tensor.mean():>8.4f}")

# ── gather all node features ───────────────────────────────────────
all_nodes = torch.cat([g.x for g in full_dataset])          # [N_total, 6]
print("\nNode‑feature ranges (col order = original 3  + betweenness, degree, |V−1|)")
for i, fn in enumerate(["orig‑0", "orig‑1", "orig‑2",
                        "node_betweenness", "node_degree", "voltage_dev"]):
    feature_stats(all_nodes[:, i], fn)

# ── gather all edge features ───────────────────────────────────────
all_edges = torch.cat([g.edge_attr for g in full_dataset])  # [E_total, 7]
print("\nEdge‑feature ranges (col order = original 4 + edge_bc, load_pct, elec_betw)")
for i, fn in enumerate(["orig‑0", "orig‑1", "orig‑2", "orig‑3",
                        "edge_bc", "load_pct", "elec_betw"]):
    feature_stats(all_edges[:, i], fn)



Node‑feature ranges (col order = original 3  + betweenness, degree, |V−1|)
            orig‑0:  min  -0.6415   max   0.5421   mean  -0.0000
            orig‑1:  min  -0.4281   max   0.7893   mean   0.0000
            orig‑2:  min  -0.9447   max   0.1346   mean   0.0000
  node_betweenness:  min   0.0000   max   0.5105   mean   0.1095
       node_degree:  min   2.0000   max  24.0000   mean  13.4711
       voltage_dev:  min   0.0000   max   0.9447   mean   0.1130

Edge‑feature ranges (col order = original 4 + edge_bc, load_pct, elec_betw)
            orig‑0:  min  -0.6100   max   0.5772   mean  -0.0004
            orig‑1:  min  -0.8540   max   0.5069   mean  -0.0003
            orig‑2:  min  -0.1007   max   0.8993   mean  -0.0021
            orig‑3:  min  -0.3198   max   0.6802   mean   0.0030
           edge_bc:  min   0.0025   max   0.4680   mean   0.0802
          load_pct:  min -21.5343   max  25.4076   mean   0.1954
         elec_betw:  min   0.0012   max   1.2605   mean   0.2331


In [30]:
# Cell 6: Train/Val/Test split & DataLoaders (with class-aware splitting)

# 1) Handle extreme class imbalance
# --------------------------------------------------------
# Calculate split sizes based on the filtered Category A dataset
num_subgraphs = len(full_dataset)  
train_size = int(0.8 * num_subgraphs)   
val_size = int(0.1 * num_subgraphs)     
test_size = num_subgraphs - train_size - val_size  

# 2) Stratified split to preserve class distribution
indices = np.arange(num_subgraphs)
np.random.seed(42)
np.random.shuffle(indices)

train_idx = indices[:train_size]
val_idx = indices[train_size:train_size+val_size]
test_idx = indices[train_size+val_size:]

train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
val_dataset = torch.utils.data.Subset(full_dataset, val_idx)
test_dataset = torch.utils.data.Subset(full_dataset, test_idx)

print(f"Train set size: {len(train_dataset)}")
print(f"Val set size:   {len(val_dataset)}")
print(f"Test set size:  {len(test_dataset)}")

# 3) Build PyG DataLoaders with class-aware sampling
batch_size = 128

# Use weighted sampler to handle edge-level imbalance with inverse weighting
graph_pos_frac  = [g.edge_mask.float().mean().item() for g in full_dataset]
graph_weights   = [1.0 / (p if p > 0 else 1e-4) for p in graph_pos_frac]   # inverse proportion
train_weights   = [graph_weights[i] for i in train_idx]
train_sampler   = torch.utils.data.WeightedRandomSampler(train_weights,
                                                         len(train_idx),
                                                         replacement=True)

train_loader = PyGDataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler
)
val_loader = PyGDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = PyGDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("DataLoaders created with batch_size =", batch_size)


Train set size: 3624
Val set size:   453
Test set size:  454
DataLoaders created with batch_size = 128


In [31]:
# Cell 6‑bis – compute pos_weight once (run AFTER Cell 6)

# Define the device (CPU or GPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# ─────────────────────────────────────────────────────────
#   Needed globally for class‑balanced BCE loss
# ─────────────────────────────────────────────────────────
all_edge_masks = torch.cat([g.edge_mask for g in full_dataset])
num_pos = all_edge_masks.sum().item()
num_neg = len(all_edge_masks) - num_pos
pos_weight = torch.tensor([(num_neg / num_pos) ** 0.5], device=device) 

print(f"pos_weight = {pos_weight.item():.1f}  (used by BCEWithLogitsLoss)")


pos_weight = 5.5  (used by BCEWithLogitsLoss)


In [32]:
# ────────────────────────────────
# 📍 Cell 7 – Model Architectures
# ────────────────────────────────
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import log
from torch_geometric.nn import GINEConv, GATConv, NNConv

# ---- helpers ----------------------------------------------------
BIAS_INIT = -log(pos_weight.item())      #  –ln(√(neg/pos))

def make_input_norm(n_feats):
    return nn.LayerNorm(n_feats, elementwise_affine=True)

# ---- GINE -------------------------------------------------------
class GINEBasedClassifier(nn.Module):
    def __init__(self, in_channels_node=6, in_channels_edge=7, hidden_dim=64):
        super().__init__()
        # input normalisation
        self.node_norm = make_input_norm(in_channels_node)
        self.edge_norm = make_input_norm(in_channels_edge)

        self.fc_in   = nn.Linear(in_channels_node, hidden_dim)
        mlp          = nn.Sequential(nn.Linear(hidden_dim, hidden_dim),
                                     nn.ReLU(),
                                     nn.LayerNorm(hidden_dim))
        self.conv1   = GINEConv(nn=mlp, edge_dim=in_channels_edge)
        self.conv2   = GINEConv(nn=mlp, edge_dim=in_channels_edge)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden_dim + in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1)
        )
        with torch.no_grad():
            self.edge_mlp[-1].bias.fill_(BIAS_INIT)

    def forward(self, x, edge_index, edge_attr):
        x         = self.node_norm(x)
        edge_attr = self.edge_norm(edge_attr)

        h = F.relu(self.fc_in(x))
        h = F.relu(self.conv1(h, edge_index, edge_attr))
        h = self.conv2(h, edge_index, edge_attr)

        h_u, h_v = h[edge_index[0]], h[edge_index[1]]
        return self.edge_mlp(torch.cat([h_u, h_v, edge_attr], dim=1)).squeeze()


# ---- GAT --------------------------------------------------------
class GATBasedClassifier(nn.Module):
    def __init__(self, in_channels_node=6, in_channels_edge=7, hidden_dim=64):
        super().__init__()
        self.node_norm = make_input_norm(in_channels_node)
        self.edge_norm = make_input_norm(in_channels_edge)

        self.fc_in = nn.Linear(in_channels_node, hidden_dim)
        self.conv1 = GATConv(hidden_dim, hidden_dim // 4, heads=4, concat=True)
        self.conv2 = GATConv(hidden_dim, hidden_dim // 4, heads=4, concat=True)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden_dim + in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1)
        )
        with torch.no_grad():
            self.edge_mlp[-1].bias.fill_(BIAS_INIT)

    def forward(self, x, edge_index, edge_attr):
        x         = self.node_norm(x)
        edge_attr = self.edge_norm(edge_attr)

        h = F.relu(self.fc_in(x))
        h = F.relu(self.conv1(h, edge_index))
        h = self.conv2(h, edge_index)

        h_u, h_v = h[edge_index[0]], h[edge_index[1]]
        return self.edge_mlp(torch.cat([h_u, h_v, edge_attr], dim=1)).squeeze()


# ---- Edge‑aware GraphConv (NNConv) ------------------------------
class EdgeAwareGraphConvClassifier(nn.Module):
    def __init__(self, in_channels_node=6, in_channels_edge=7, hidden_dim=64):
        super().__init__()
        self.node_norm = make_input_norm(in_channels_node)
        self.edge_norm = make_input_norm(in_channels_edge)

        self.fc_in = nn.Linear(in_channels_node, hidden_dim)

        edge_net = nn.Sequential(
            nn.Linear(in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim * hidden_dim)
        )
        self.conv1 = NNConv(hidden_dim, hidden_dim, edge_net)
        self.conv2 = NNConv(hidden_dim, hidden_dim, edge_net)

        self.edge_mlp = nn.Sequential(
            nn.Linear(2*hidden_dim + in_channels_edge, hidden_dim),
            nn.ReLU(),
            nn.LayerNorm(hidden_dim),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1)
        )
        with torch.no_grad():
            self.edge_mlp[-1].bias.fill_(BIAS_INIT)

    def forward(self, x, edge_index, edge_attr):
        x         = self.node_norm(x)
        edge_attr = self.edge_norm(edge_attr)

        h = F.relu(self.fc_in(x))
        h = F.relu(self.conv1(h, edge_index, edge_attr))
        h = self.conv2(h, edge_index, edge_attr)

        h_u, h_v = h[edge_index[0]], h[edge_index[1]]
        return self.edge_mlp(torch.cat([h_u, h_v, edge_attr], dim=1)).squeeze()


In [33]:
# ────────────────────────────────────────────
# 📍 Cell 8 – Training + Evaluation utilities
# ────────────────────────────────────────────

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ---- metrics ----------------------------------------------------
def calculate_metrics(preds, labels):
    tp = ((preds == 1) & (labels == 1)).sum().item()
    fp = ((preds == 1) & (labels == 0)).sum().item()
    fn = ((preds == 0) & (labels == 1)).sum().item()
    prec = tp / (tp + fp + 1e-8)
    rec  = tp / (tp + fn + 1e-8)
    f1   = 2 * prec * rec / (prec + rec + 1e-8)
    return prec, rec, f1

# ---- evaluation -------------------------------------------------
def evaluate(model, loader, device, criterion):
    model.eval()
    total_loss, all_logits, all_labels = 0.0, [], []

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr)
            loss   = criterion(logits, batch.edge_mask.float())
            total_loss += loss.item()
            all_logits.append(logits)
            all_labels.append(batch.edge_mask.float())

    all_logits = torch.cat(all_logits)
    all_labels = torch.cat(all_labels)
    probs = torch.sigmoid(all_logits)

    # find best threshold on validation
    best_f1, best_thr = 0, 0
    for thr in np.logspace(-6, -1, 30):
        f1 = calculate_metrics((probs > thr).long(), all_labels.long())[2]
        if f1 > best_f1:
            best_f1, best_thr = f1, thr

    prec, rec, f1 = calculate_metrics((probs > best_thr).long(), all_labels.long())
    return total_loss / len(loader), {'precision': prec, 'recall': rec, 'f1': f1}

# ---- training + full history -----------------------------------
def train_and_evaluate_model(model, model_name,
                             train_loader, val_loader,
                             device,
                             num_epochs=30,
                             patience=7,
                             lr=5e-4):
    """
    Train `model`, evaluate on `val_loader` each epoch, 
    and return a history dict with lists for train_loss, val_loss, precision, recall, f1.
    """
    # you can swap in BCEWithLogitsLoss(pos_weight=...) if you like
    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-4)

    history = {
        'train_loss': [],
        'val_loss':   [],
        'precision':  [],
        'recall':     [],
        'f1':         []
    }

    best_f1, wait = 0.0, 0
    for epoch in range(1, num_epochs+1):
        # --- training pass ---
        model.train()
        running_loss = 0.0
        for batch in train_loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr)

            # in‑batch sampler for imbalance
            pos = batch.edge_mask.nonzero(as_tuple=True)[0]
            neg = (batch.edge_mask == 0).nonzero(as_tuple=True)[0]
            k   = min(len(neg), 5 * len(pos))
            if len(pos)>0 and k>0:
                neg_idx = neg[torch.randperm(len(neg), device=neg.device)[:k]]
                sel = torch.cat([pos, neg_idx])
            else:
                sel = torch.arange(len(batch.edge_mask), device=logits.device)

            loss = criterion(logits[sel], batch.edge_mask.float()[sel])
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            running_loss += loss.item()

        avg_train_loss = running_loss / len(train_loader)
        history['train_loss'].append(avg_train_loss)

        # --- validation pass ---
        val_loss, val_metrics = evaluate(model, val_loader, device, criterion)
        history['val_loss'].append(val_loss)
        history['precision'].append(val_metrics['precision'])
        history['recall'].append(val_metrics['recall'])
        history['f1'].append(val_metrics['f1'])

        print(f"{model_name} | Ep {epoch:02d}  Train {avg_train_loss:.4f}  "
              f"ValF1 {val_metrics['f1']:.4f}")

        # early stopping
        if val_metrics['f1'] > best_f1:
            best_f1, wait = val_metrics['f1'], 0
        else:
            wait += 1
            if wait >= patience:
                print(f"⏹️ early stop at {epoch}, best F1={best_f1:.4f}")
                break

    return history


In [34]:
# 📍 Cell 9 – Fast 2‑fold run across models with full metrics saving

import json
import numpy as np
from sklearn.model_selection import KFold
from torch.utils.data import WeightedRandomSampler
from torch_geometric.loader import DataLoader as PyGDataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

def run_kfold_full(model_name, model_class, dataset, indices,
                   K=2, num_epochs=100, patience=7):
    """
    Runs K‑fold CV, returns:
      - fold_histories: list of per-fold history dicts
      - summary: dict of mean±std for precision, recall, f1 across folds
    """
    kf = KFold(n_splits=K, shuffle=True, random_state=42)
    fold_histories = []
    all_prec, all_rec, all_f1 = [], [], []

    for fold, (train_idx, val_idx) in enumerate(kf.split(indices), 1):
        print(f"\n[{model_name}] Fold {fold}/{K}")
        # build subsets
        tr_ids = [indices[i] for i in train_idx]
        vl_ids = [indices[i] for i in val_idx]
        tr_ds = torch.utils.data.Subset(dataset, tr_ids)
        vl_ds = torch.utils.data.Subset(dataset, vl_ids)

        # weighted sampler to handle imbalance
        frac_pos = [g.edge_mask.float().mean().item() for g in dataset]
        weights = [1.0/(p if p>0 else 1e-4) for p in frac_pos]
        sampler = WeightedRandomSampler([weights[i] for i in tr_ids],
                                        len(tr_ids), replacement=True)

        tr_loader = PyGDataLoader(tr_ds, batch_size=128, sampler=sampler)
        vl_loader = PyGDataLoader(vl_ds, batch_size=128, shuffle=False)

        # init model
        model = model_class().to(device)
        # train and get history
        history = train_and_evaluate_model(
            model=model,
            model_name=f"{model_name}-F{fold}",
            train_loader=tr_loader,
            val_loader=vl_loader,
            device=device,
            num_epochs=num_epochs,
            patience=patience
        )
        # capture final fold metrics (last epoch on validation)
        final_prec = history['precision'][-1]
        final_rec  = history['recall'][-1]
        final_f1   = history['f1'][-1]

        all_prec.append(final_prec)
        all_rec.append(final_rec)
        all_f1.append(final_f1)

        # store everything
        fold_histories.append({
            'fold': fold,
            'train_loss': history['train_loss'],
            'val_loss':   history['val_loss'],
            'precision':  history['precision'],
            'recall':     history['recall'],
            'f1':         history['f1']
        })

    # compute mean±std
    summary = {
        'precision': {'mean': float(np.mean(all_prec)), 'std': float(np.std(all_prec))},
        'recall':    {'mean': float(np.mean(all_rec)),  'std': float(np.std(all_rec))},
        'f1':        {'mean': float(np.mean(all_f1)),   'std': float(np.std(all_f1))}
    }

    return fold_histories, summary

# ==== run all three models ====================================================
model_classes = {
    "GINE":        GINEBasedClassifier,
    "GAT":         GATBasedClassifier,
    "EdgeAwareGC": EdgeAwareGraphConvClassifier
}

all_results = {}
for name, cls in model_classes.items():
    histories, summary = run_kfold_full(
        model_name=name,
        model_class=cls,
        dataset=full_dataset,
        indices=train_idx,
        K=2,
        num_epochs=100,
        patience=7
    )
    all_results[name] = {
        'folds': histories,
        'summary': summary
    }

# save full JSON
with open("enhanced_full_results_UK_pbs.json", "w") as fp:
    json.dump(all_results, fp, indent=2)

# print summary
print("\n📊  Fast 2‑Fold Summary (mean ± std)")
for m, metrics in all_results.items():
    sp = metrics['summary']
    print(f"{m}:  F1 {sp['f1']['mean']:.4f}±{sp['f1']['std']:.4f}, "
          f"Prec {sp['precision']['mean']:.4f}±{sp['precision']['std']:.4f}, "
          f"Rec {sp['recall']['mean']:.4f}±{sp['recall']['std']:.4f}")


Using device: cuda

[GINE] Fold 1/2
GINE-F1 | Ep 01  Train 0.3275  ValF1 0.2807
GINE-F1 | Ep 02  Train 0.2161  ValF1 0.4466
GINE-F1 | Ep 03  Train 0.2022  ValF1 0.3968
GINE-F1 | Ep 04  Train 0.1882  ValF1 0.4319
GINE-F1 | Ep 05  Train 0.1879  ValF1 0.3199
GINE-F1 | Ep 06  Train 0.1772  ValF1 0.3632
GINE-F1 | Ep 07  Train 0.1727  ValF1 0.3928
GINE-F1 | Ep 08  Train 0.1639  ValF1 0.4337
GINE-F1 | Ep 09  Train 0.1615  ValF1 0.3879
⏹️ early stop at 9, best F1=0.4466

[GINE] Fold 2/2
GINE-F2 | Ep 01  Train 0.3523  ValF1 0.2581
GINE-F2 | Ep 02  Train 0.2306  ValF1 0.4687
GINE-F2 | Ep 03  Train 0.2075  ValF1 0.3874
GINE-F2 | Ep 04  Train 0.1929  ValF1 0.3628
GINE-F2 | Ep 05  Train 0.1824  ValF1 0.3517
GINE-F2 | Ep 06  Train 0.1753  ValF1 0.4235
GINE-F2 | Ep 07  Train 0.1778  ValF1 0.3943
GINE-F2 | Ep 08  Train 0.1604  ValF1 0.3570
GINE-F2 | Ep 09  Train 0.1579  ValF1 0.3818
⏹️ early stop at 9, best F1=0.4687

[GAT] Fold 1/2
GAT-F1 | Ep 01  Train 0.3987  ValF1 0.1258
GAT-F1 | Ep 02  Train 0.27

In [74]:
# Cell 10: Joint Node & Edge Feature Importance

import torch
import torch.nn.functional as F
import pandas as pd

# 1) Reinit models
models = {
    "GINE":        GINEBasedClassifier(in_channels_node=6, in_channels_edge=7).to(device),
    "GAT":         GATBasedClassifier(in_channels_node=6, in_channels_edge=7).to(device),
    "EdgeAwareGC": EdgeAwareGraphConvClassifier(in_channels_node=6, in_channels_edge=7).to(device),
}

# 2) Train each model on full training set
criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
for name, model in models.items():
    print(f"\nTraining {name} for feature‐importance…")
    train_one_model(
        model=model,
        name=f"{name}-FI",
        train_loader=train_loader,
        val_loader=val_loader,
        device=device,
        criterion=criterion,
        num_epochs=50,
        patience=10,
        lr=5e-4
    )

def analyze_importance(model, loader, device,
                       in_node_feats=6, in_edge_feats=7, hidden_dim=64):
    """Return (node_imp, edge_imp) arrays."""
    model.eval()
    # locate the two first‐linear layers:
    node_lin = model.fc_in            # [hidden_dim x in_node_feats]
    edge_lin = model.edge_mlp[0]      # [hidden_dim x (2*hidden_dim + in_edge_feats)]

    # precompute slice indices for edge‐features
    start = 2 * hidden_dim
    end   = start + in_edge_feats

    node_grads = torch.zeros(in_node_feats, device=device)
    edge_grads = torch.zeros(in_edge_feats, device=device)

    for batch in loader:
        model.zero_grad()
        batch = batch.to(device)
        logits = model(batch.x, batch.edge_index, batch.edge_attr)
        loss = criterion(logits, batch.edge_mask.float())
        loss.backward()

        # Node feature importance: sum abs weight‐grads over hidden_dim
        wg_node = node_lin.weight.grad.abs().sum(dim=0)  # size in_node_feats
        node_grads += wg_node

        # Edge feature importance: sum abs weight‐grads only on those last columns
        wg_edge = edge_lin.weight.grad.abs().sum(dim=0)[start:end]  # size in_edge_feats
        edge_grads += wg_edge

    # normalize by number of graphs
    N = len(loader.dataset)
    return (node_grads.cpu().numpy()/N,
            edge_grads.cpu().numpy()/N)

# 3) Compute for each model
results = {}
for name, model in models.items():
    print(f"\nAnalyzing {name} importance…")
    n_imp, e_imp = analyze_importance(model, train_loader, device)
    results[name] = (n_imp, e_imp)

# 4) Build DataFrame
node_names = ["Orig‐1","Orig‐2","Orig‐3","Node BC","Node Deg","Volt Dev"]
edge_names = ["P","Q","X","lr","Edge BC","Load %","Elec BC"]

df_node = pd.DataFrame({m: r[0] for m, r in results.items()}, index=node_names)
df_edge = pd.DataFrame({m: r[1] for m, r in results.items()}, index=edge_names)

df_node.index.name = "Node Feature"
df_edge.index.name = "Edge Feature"

# 5) Save & print
df_node.to_csv("node_importance.csv")
df_edge.to_csv("edge_importance.csv")

print("\nNode Feature Importance:\n", df_node)
print("\nEdge Feature Importance:\n", df_edge)



Training GINE for feature‐importance…
GINE-FI | Ep 01  Train 0.7552  ValF1 0.1506
GINE-FI | Ep 02  Train 0.4770  ValF1 0.1534
GINE-FI | Ep 03  Train 0.3844  ValF1 0.1818
GINE-FI | Ep 04  Train 0.3216  ValF1 0.1694
GINE-FI | Ep 05  Train 0.2742  ValF1 0.1886
GINE-FI | Ep 06  Train 0.2576  ValF1 0.2016
GINE-FI | Ep 07  Train 0.2427  ValF1 0.2469
GINE-FI | Ep 08  Train 0.2080  ValF1 0.3003
GINE-FI | Ep 09  Train 0.2254  ValF1 0.1903
GINE-FI | Ep 10  Train 0.1893  ValF1 0.2615
GINE-FI | Ep 11  Train 0.2045  ValF1 0.2203
GINE-FI | Ep 12  Train 0.1877  ValF1 0.3635
GINE-FI | Ep 13  Train 0.1725  ValF1 0.3379
GINE-FI | Ep 14  Train 0.1486  ValF1 0.4114
GINE-FI | Ep 15  Train 0.1563  ValF1 0.4538
GINE-FI | Ep 16  Train 0.1512  ValF1 0.5296
GINE-FI | Ep 17  Train 0.1183  ValF1 0.7718
GINE-FI | Ep 18  Train 0.1520  ValF1 0.4643
GINE-FI | Ep 19  Train 0.1080  ValF1 0.5513
GINE-FI | Ep 20  Train 0.1353  ValF1 0.4355
GINE-FI | Ep 21  Train 0.1179  ValF1 0.6479
GINE-FI | Ep 22  Train 0.1316  ValF1 

In [None]:
# Select a subgraph index to visualize (you can change it)
i_viz = 42

# Get the subgraph from your enhanced full_dataset
sample_graph = full_dataset[i_viz]

# Print graph stats for context
print("Subgraph info:")
print(f"Nodes: {sample_graph.num_nodes}, Edges: {sample_graph.num_edges}")
print(f"Graph-level label (y): {sample_graph.y.item()}")
print(f"Unique edge_mask values: {sample_graph.edge_mask.unique()}")


In [None]:
import networkx as nx
import matplotlib.pyplot as plt
from torch_geometric.utils import to_networkx

# Convert to NetworkX for plotting (undirected)
G = to_networkx(sample_graph, to_undirected=True)

# Map node features for coloring or sizing (e.g., betweenness or voltage deviation)
node_color = sample_graph.x[:, -1].numpy()  # Voltage deviation (last column)

# Map edge status: 1 = failed during cascade (cascading edge)
edge_color = [
    'red' if sample_graph.edge_mask[i].item() == 1 else 'gray'
    for i in range(sample_graph.edge_index.shape[1])
]

# Create layout for nice visuals
pos = nx.spring_layout(G, seed=42)


In [None]:
plt.figure(figsize=(8, 6))

# Draw edges and labels
nx.draw_networkx_edges(G, pos, edge_color='black')
nx.draw_networkx_labels(G, pos)

# Draw nodes with color map
nodes = nx.draw_networkx_nodes(
    G, pos,
    node_color=node_color,
    cmap='coolwarm',
    node_size=300
)

# Add colorbar using the actual mappable from draw_networkx_nodes
plt.colorbar(nodes, label='Voltage Deviation')

plt.title("Power Grid Topology (Node color: Voltage Deviation)")
plt.axis('off')
plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns
from matplotlib.cm import ScalarMappable

# Set clean seaborn styling
sns.set_style("whitegrid")

# Recreate failure steps
random.seed(42)
failed_edges_idx = [i for i, val in enumerate(sample_graph.edge_mask) if val == 1]
random.shuffle(failed_edges_idx)

step1 = failed_edges_idx[:len(failed_edges_idx)//3]
step2 = failed_edges_idx[len(failed_edges_idx)//3:2*len(failed_edges_idx)//3]
step3 = failed_edges_idx[2*len(failed_edges_idx)//3:]

cascade_steps = {
    "Initial Outage": step1,
    "Step 1 Propagation": step2,
    "Step 2 Propagation": step3
}
final_step = failed_edges_idx  # All failures

# Voltage deviation for node coloring
node_vals = sample_graph.x[:, -1].numpy()

# Plot setup
fig, axs = plt.subplots(2, 2, figsize=(14, 12))
titles = list(cascade_steps.keys()) + ["Final Cascade State"]
pos = nx.spring_layout(G, seed=42)
edges = list(G.edges())

# Loop through each subplot
for i, ax in enumerate(axs.flatten()):
    # Get edges to highlight for this step
    if i < 3:
        highlight = cascade_steps[titles[i]]
    else:
        highlight = final_step

    # Edge coloring
    edge_colors = ['crimson' if idx in highlight else '#cccccc' for idx in range(len(edges))]

    # Draw components
    nodes = nx.draw_networkx_nodes(G, pos, node_color=node_vals, cmap='coolwarm', node_size=400, ax=ax)
    nx.draw_networkx_edges(G, pos, edge_color=edge_colors, width=2.5, ax=ax)
    nx.draw_networkx_labels(G, pos, font_size=9, ax=ax)

    ax.set_title(titles[i], fontsize=14, fontweight='bold')
    ax.axis('off')

    # Add legend only in final plot
    if i == 3:
        legend_elements = [
            mpatches.Patch(color='crimson', label='Failed Edge'),
            mpatches.Patch(color='#cccccc', label='Active Edge'),
            mpatches.Patch(color='blue', label='Low Voltage Deviation'),
            mpatches.Patch(color='red', label='High Voltage Deviation')
        ]
        ax.legend(handles=legend_elements, loc='lower left', fontsize=10)

# Add shared colorbar for node voltage deviation
cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.7])  # [left, bottom, width, height]
sm = ScalarMappable(cmap='coolwarm')
sm.set_array(node_vals)
cbar = fig.colorbar(sm, cax=cbar_ax)
cbar.set_label("Voltage Deviation", fontsize=12)

# Title & layout
plt.suptitle("Simulated Cascading Failure Progression (Edge Failures Over Steps)", fontsize=18, fontweight='bold')
plt.tight_layout(rect=[0, 0, 0.9, 0.95])
plt.subplots_adjust(hspace=0.3)
plt.show()
