In [None]:
##### New Experiment Adaeze Experiment_1 
# Cell 1: Basic Imports and Data Loading

import os
import json
import torch
import numpy as np
from pathlib import Path
from collections import Counter
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, random_split
from torch_geometric.data import Data
from torch_geometric.loader import DataLoader as PyGDataLoader
from torch_geometric.nn import GINEConv, GATConv, GraphConv



# Load dataset. The loaded object is indexed below as a pair:
# [0] -> flattened data object, [1] -> dict of per-key offsets (see Cell 2).
# NOTE(review): recent PyTorch versions default torch.load to weights_only=True,
# which may refuse pickled PyG objects — confirm against the installed version.
file_path = 'dataset/ieee24/ieee24/processed_r/data.pt'
loaded_data = torch.load(file_path)

# Top-level structure of the loaded object
print("Loaded data type:", type(loaded_data))
print("Length of dataset tuple:", len(loaded_data))

# First element: report its type only (summary, not contents)
print("\nFirst element type:", type(loaded_data[0]))

# Second element: the offset/metadata dictionary
metadata_dict = loaded_data[1]
print("\nMetadata keys:", list(metadata_dict.keys()))

# Show at most the first 10 entries of every metadata value
for key in metadata_dict:
    val = metadata_dict[key]
    if hasattr(val, '__len__'):
        preview = val[:10]
    else:
        preview = "N/A"
    print(f"{key}: {type(val)}, first 10: {preview}")


In [None]:
# Cell 2: Add strict validation


def get_subgraph(data_flat, meta_dict, i):
    """
    Reconstruct the i-th subgraph from the flattened (collated) dataset.

    data_flat: The big flattened Data object (loaded_data[0]); must expose
               x, edge_index, edge_attr, edge_mask and y.
    meta_dict: The dictionary of offsets (loaded_data[1]). meta_dict['x'] and
               meta_dict['edge_index'] hold one start offset per subgraph plus
               a final end offset, so subgraph i spans [off[i], off[i+1]).
    i        : Index of the subgraph we want to reconstruct,
               0 <= i < len(meta_dict['x']) - 1

    returns: a PyG Data object representing the i-th subgraph, with edge_mask
             converted to float.

    raises:
        IndexError: if i is outside the valid subgraph range.
        ValueError: if the subgraph's edge_mask contains values other than 0/1.
    """
    # 0) Validate the index. A negative i would otherwise slice off[-1]:off[0]
    #    and silently return an empty / wrong subgraph.
    num_subgraphs = len(meta_dict['x']) - 1
    if not 0 <= i < num_subgraphs:
        raise IndexError(f"Subgraph index {i} out of range [0, {num_subgraphs})")

    # 1) Node offsets: nodes of subgraph i live in [x_start, x_end)
    x_start = meta_dict['x'][i].item()
    x_end = meta_dict['x'][i + 1].item()
    x_i = data_flat.x[x_start:x_end]

    # 2) Edge offsets: edges, edge features and edge labels live in [e_start, e_end).
    #    NOTE(review): edge_index is sliced as-is, i.e. it is assumed to already
    #    hold subgraph-local node indices — confirm against the dataset format.
    e_start = meta_dict['edge_index'][i].item()
    e_end = meta_dict['edge_index'][i + 1].item()
    edge_index_i = data_flat.edge_index[:, e_start:e_end]
    edge_attr_i = data_flat.edge_attr[e_start:e_end]

    # 3) TRUE binary edge labels (explanation mask), as float for the loss
    edge_mask_i = data_flat.edge_mask[e_start:e_end].float()

    # Strict validation. Comparing with scalars keeps the check on whatever
    # device edge_mask_i lives on (a fresh torch.tensor([0., 1.]) is always
    # created on the CPU and could mismatch a GPU-resident mask).
    is_binary = ((edge_mask_i == 0) | (edge_mask_i == 1)).all()
    if not is_binary:
        print(f"BAD SUBGRAPH {i}:")
        print("Unique values:", edge_mask_i.unique())
        print("Edge indices:", edge_index_i)
        raise ValueError("Edge mask contains non-binary values")

    # 4) Graph label (binary or multi-class), indexed directly by subgraph id
    y_i = data_flat.y[i]

    # 5) Assemble the per-subgraph Data object
    subgraph_i = Data(
        x=x_i,
        edge_index=edge_index_i,
        edge_attr=edge_attr_i,
        y=y_i.unsqueeze(0),  # keep a leading dim so batching concatenates labels
        edge_mask=edge_mask_i,  # binary edge labels
    )
    return subgraph_i

# Sanity check: rebuild the first subgraph and inspect its edge mask
i_test = 0
subgraph_0 = get_subgraph(data_flat=loaded_data[0], meta_dict=loaded_data[1], i=i_test)
print("Subgraph 0:")
print(subgraph_0)
# Expected to contain only 0. and/or 1.
print("Edge mask values:", subgraph_0.edge_mask.unique())


In [None]:
#Cell 3: Verification 

# Toy dataset used only to exercise the verification helpers below
# (replace with your actual dataset). Each spec is (graph label y, edge_mask);
# every toy graph has exactly 3 edges.
dataset = [
    Data(y=torch.tensor(label), edge_mask=torch.tensor(mask), num_edges=3)
    for label, mask in (
        (1, [0, 1, 0]),  # Category A or C
        (0, [0, 0, 0]),  # Category B or D
        (1, [0, 0, 1]),  # Category A or C
        (0, [0, 0, 0]),  # Category B or D
    )
]

def verify_edge_mask_coverage(dataset):
    """Check if edge_mask is defined for all graphs with cascading failures (y=1).

    A graph's edge_mask counts as defined when the attribute exists, is not None,
    and has exactly ``graph.num_edges`` entries. Graphs with y != 1 are not checked.

    Args:
        dataset: iterable of graph objects exposing ``y`` (single-element tensor),
            ``num_edges`` and, optionally, ``edge_mask``.

    Returns:
        bool: True if every y == 1 graph has a complete edge_mask (also True when
        there are no such graphs), False otherwise.
    """
    print("Verifying Edge Mask Coverage")
    has_cascading = 0
    has_edge_mask_defined = 0

    for i, graph in enumerate(dataset):
        # Only graphs with cascading failures (Categories A and C) are checked
        if graph.y.item() != 1:
            continue
        has_cascading += 1
        # getattr with a default: a graph that lacks the attribute entirely is
        # reported as missing instead of crashing the check.
        edge_mask = getattr(graph, 'edge_mask', None)
        if edge_mask is not None and len(edge_mask) == graph.num_edges:
            has_edge_mask_defined += 1
        else:
            print(f"Graph {i}: Missing or incomplete edge_mask for cascading failure graph.")

    print(f"Graphs with cascading failures: {has_cascading}")
    print(f"Graphs with defined edge_mask: {has_edge_mask_defined}")
    coverage_complete = has_cascading == has_edge_mask_defined
    if coverage_complete:
        print(" Edge mask coverage is complete for cascading failure graphs.")
    else:
        print("Edge mask is missing or incomplete for some cascading failure graphs.")
    return coverage_complete

def check_edge_label_distribution(dataset):
    """Examine the distribution of edge labels (1s and 0s) across graphs.

    Prints per-graph and overall counts of tripped (1) and non-tripped (0) edges.

    Args:
        dataset: iterable of graph objects exposing an ``edge_mask`` tensor.

    Returns:
        tuple: (total_edges, tripped_edges) summed over the dataset.
    """
    print("\n Checking Edge Label Distribution")
    total_edges = 0
    tripped_edges = 0

    for i, graph in enumerate(dataset):
        edge_mask = graph.edge_mask
        num_tripped = edge_mask.sum().item()
        total_edges += len(edge_mask)
        tripped_edges += num_tripped
        print(f"Graph {i}: {num_tripped} tripped edges (1s), {len(edge_mask) - num_tripped} non-tripped (0s)")

    print(f"Total edges: {total_edges}")
    print(f"Tripped edges (1s): {tripped_edges}")
    print(f"Non-tripped edges (0s): {total_edges - tripped_edges}")
    # Guard: an empty dataset (or one with no edges) previously raised ZeroDivisionError
    if total_edges > 0:
        print(f"Percentage of tripped edges: {(tripped_edges / total_edges * 100):.2f}%")
    else:
        print("Percentage of tripped edges: n/a (no edges)")
    return total_edges, tripped_edges

def validate_graph_edge_consistency(dataset):
    """Ensure edge_mask aligns with graph-level labels (y).

    Rules:
      * y == 1 (cascading failure, Categories A/C): at least one tripped edge.
      * y == 0 (no cascading failure, Categories B/D): no tripped edges.
      * any other label cannot be validated under this binary rule and is
        reported as an inconsistency rather than being silently skipped.

    Args:
        dataset: iterable of graph objects exposing ``y`` (single-element tensor)
            and ``edge_mask``.

    Returns:
        bool: True if every graph is consistent, False otherwise.
    """
    print("\n Validating Graph-Edge Label Consistency ")
    all_valid = True

    for i, graph in enumerate(dataset):
        edge_mask = graph.edge_mask
        y = graph.y.item()

        if y == 1:  # Categories A and C (cascading failures)
            if edge_mask.sum() == 0:
                print(f"Graph {i}: Inconsistent - y=1 but no tripped edges in edge_mask.")
                all_valid = False
            else:
                print(f"Graph {i}: Consistent - y=1 and tripped edges present.")
        elif y == 0:  # Categories B and D (no cascading failures)
            if edge_mask.sum() > 0:
                print(f"Graph {i}: Inconsistent - y=0 but tripped edges present in edge_mask.")
                all_valid = False
            else:
                print(f"Graph {i}: Consistent - y=0 and no tripped edges.")
        else:
            # Previously skipped silently, so the final summary could claim
            # consistency for graphs that were never checked.
            print(f"Graph {i}: Unexpected label y={y}; only binary labels (0/1) can be validated.")
            all_valid = False

    if all_valid:
        print("All graphs have consistent edge_mask and y labels.")
    else:
        print(" Some graphs have inconsistencies between edge_mask and y.")
    return all_valid

# Run every verification on the toy dataset, in order
for run_check in (
    verify_edge_mask_coverage,
    check_edge_label_distribution,
    validate_graph_edge_consistency,
):
    run_check(dataset)

In [None]:
# Cell 4: Create a PyTorch Dataset for our subgraphs

class PowerGraphDataset(Dataset):
    """Map-style dataset that rebuilds subgraphs lazily from the flattened data.

    Args:
        data_flat: The giant flattened Data object.
        meta_dict: Dictionary of offsets (per-subgraph start offsets plus a final
            end offset for 'x' and 'edge_index').
        indices: Subgraph indices to include; defaults to every subgraph.
        filter_category_A: If True, keep only subgraphs whose edge_mask sums to
            more than zero (at least one failed edge for a 0/1 mask).
    """

    def __init__(self, data_flat, meta_dict, indices=None, filter_category_A=True):
        super().__init__()
        self.data_flat = data_flat
        self.meta_dict = meta_dict
        self.filter_category_A = filter_category_A

        # The offset tensor has one more entry than there are subgraphs
        self.indices = range(len(meta_dict['x']) - 1) if indices is None else indices

        # Category A: graphs with cascading failures (DNS > 0)
        if self.filter_category_A:
            self.indices = self._filter_category_A()

    def _filter_category_A(self):
        """Return, as a list, the entries of self.indices whose edge_mask sums to > 0."""
        edge_offsets = self.meta_dict['edge_index']
        return [
            sub_id
            for sub_id in self.indices
            if self.data_flat.edge_mask[
                edge_offsets[sub_id].item():edge_offsets[sub_id + 1].item()
            ].sum() > 0
        ]

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # idx is a position within the (possibly filtered) index list
        return get_subgraph(self.data_flat, self.meta_dict, self.indices[idx])

# Keep only subgraphs with at least one failed edge (Category A)
full_dataset = PowerGraphDataset(
    data_flat=loaded_data[0],
    meta_dict=loaded_data[1],
    filter_category_A=True,
)
print("Total subgraphs in full_dataset:", len(full_dataset))

In [None]:
# Cell 5: edge-level label distribution over full_dataset (run after Cell 4).
# Guarded so an empty dataset (torch.cat([]) error) or one with no edges
# (ZeroDivisionError in the percentages) does not crash the cell.
if len(full_dataset) == 0:
    all_edge_masks = torch.empty(0)
else:
    all_edge_masks = torch.cat([graph.edge_mask for graph in full_dataset])
num_edges_total = all_edge_masks.numel()
# Masks are validated as 0/1 in get_subgraph, so the sum is an integer count
num_positive = int(all_edge_masks.sum().item())
num_negative = num_edges_total - num_positive

print("Edge label distribution:")
if num_edges_total == 0:
    print("- No edges found in full_dataset; nothing to report.")
else:
    print(f"- Failed edges (1): {num_positive} ({num_positive / num_edges_total:.2%})")
    print(f"- Stable edges (0): {num_negative} ({num_negative / num_edges_total:.2%})")

In [None]:
# Cell 6: Train/Val/Test split & DataLoaders (random split + weighted training sampler)

# 1) Split sizes: 80% train / 10% val / remainder test
# --------------------------------------------------------
# Edge labels are heavily imbalanced (~3.25% positive edges; see Cell 5's output).
# That imbalance is handled by the training sampler and the loss, not by the split.
SPLIT_SEED = 42
TRAIN_FRAC = 0.8
VAL_FRAC = 0.1

num_subgraphs = len(full_dataset)
train_size = int(TRAIN_FRAC * num_subgraphs)
val_size = int(VAL_FRAC * num_subgraphs)
test_size = num_subgraphs - train_size - val_size

# 2) Random split (NOT stratified). With filter_category_A=True (Cell 4) every
#    graph already contains at least one failed edge, so there is no
#    graph-level class left to stratify on. The global NumPy seed is kept so
#    the permutation (and any later cell relying on the global seed) is unchanged.
indices = np.arange(num_subgraphs)
np.random.seed(SPLIT_SEED)
np.random.shuffle(indices)

train_idx = indices[:train_size]
val_idx = indices[train_size:train_size + val_size]
test_idx = indices[train_size + val_size:]

train_dataset = torch.utils.data.Subset(full_dataset, train_idx)
val_dataset = torch.utils.data.Subset(full_dataset, val_idx)
test_dataset = torch.utils.data.Subset(full_dataset, test_idx)

print(f"Train set size: {len(train_dataset)}")
print(f"Val set size:   {len(val_dataset)}")
print(f"Test set size:  {len(test_dataset)}")

# 3) PyG DataLoaders; training oversamples graphs with a larger share of failed edges
batch_size = 32

# Sampling weight per graph = fraction of its edges that failed. A graph with no
# edges would give mean() == NaN, which the sampler cannot use, so it gets 0.
graph_weights = [
    graph.edge_mask.float().mean().item() if graph.edge_mask.numel() > 0 else 0.0
    for graph in full_dataset
]
train_sample_weights = [graph_weights[i] for i in train_idx]
# The sampler can only draw graphs with positive weight; fail here with a clear
# message rather than later, when the sampler is first iterated during training.
if not any(w > 0 for w in train_sample_weights):
    raise ValueError(
        "No training graph has a positive sampling weight (no failed edges); "
        "WeightedRandomSampler cannot draw samples."
    )
# Dedicated seeded generator so the sampled training order is reproducible
train_sampler = torch.utils.data.WeightedRandomSampler(
    train_sample_weights,
    len(train_idx),
    replacement=True,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

train_loader = PyGDataLoader(
    train_dataset, batch_size=batch_size, sampler=train_sampler
)
val_loader = PyGDataLoader(val_dataset, batch_size=batch_size, shuffle=False)
test_loader = PyGDataLoader(test_dataset, batch_size=batch_size, shuffle=False)

print("DataLoaders created with batch_size =", batch_size)


In [None]:
# Cell 7: Model 

class MLP(nn.Module):
    """Feed-forward network: (num_layers - 1) x [Linear -> ReLU], then a final Linear.

    Args:
        input_dim: width of the input features.
        hidden_dim: width of every hidden layer.
        output_dim: width of the output.
        num_layers: total number of Linear layers; values below 2 yield a single
            Linear(input_dim, output_dim).
    """

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers=2):
        super().__init__()
        # Exposes the input width as an attribute (presumably read by GINEConv
        # when it is given edge_dim — confirm).
        self.in_channels = input_dim

        # Layer widths along the hidden stack: input_dim, hidden_dim, hidden_dim, ...
        widths = [input_dim] + [hidden_dim] * max(num_layers - 1, 0)
        blocks = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            blocks += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output projection (no activation: callers get raw values / logits)
        blocks.append(nn.Linear(widths[-1], output_dim))
        self.mlp = nn.Sequential(*blocks)

    def forward(self, x):
        out = self.mlp(x)
        return out
    

class EdgeClassifier(nn.Module):
    """
    Edge-level failure classifier.

    Pipeline:
      1. Linear projection of node features to ``hidden_dim``.
      2. One GINEConv message-passing layer that also consumes edge features.
      3. For every edge (u, v): concatenate [h_u, h_v, edge_attr_uv] and feed it
         to an MLP that emits a single logit.

    Args:
        in_channels_node: node feature dimension (default 3).
        in_channels_edge: edge feature dimension (default 4).
        hidden_dim: width of node embeddings and of the MLP hidden layers.
        num_layers: number of Linear layers in the edge-scoring MLP only; the
            GNN part always has exactly one GINEConv layer.
    """

    def __init__(
        self,
        in_channels_node=3,
        in_channels_edge=4,
        hidden_dim=32,
        num_layers=2
    ):
        super().__init__()
        # Submodules are created in a fixed order (fc_in, gnn_mlp, conv, edge_mlp):
        # it determines parameter initialization order and the state_dict layout.
        self.fc_in = nn.Linear(in_channels_node, hidden_dim)

        # Update network handed to GINEConv (the same MLP instance)
        self.gnn_mlp = MLP(input_dim=hidden_dim, hidden_dim=hidden_dim, output_dim=hidden_dim)
        self.conv = GINEConv(nn=self.gnn_mlp, edge_dim=in_channels_edge)

        # Edge scorer input: both endpoint embeddings plus the raw edge features
        pair_feature_dim = 2 * hidden_dim + in_channels_edge
        self.edge_mlp = MLP(pair_feature_dim, hidden_dim, 1, num_layers=num_layers)

    def forward(self, x, edge_index, edge_attr):
        """
        Score every edge.

        x          : node features, [num_nodes, in_channels_node]
        edge_index : [2, num_edges]
        edge_attr  : edge features, [num_edges, in_channels_edge]

        returns:
            [num_edges] raw logits for 'edge fails or not'
        """
        # Lift node features to the hidden space, then one round of message passing
        node_emb = self.conv(self.fc_in(x), edge_index, edge_attr)

        # Per-edge representation = [source embedding, target embedding, edge features]
        src_emb = node_emb[edge_index[0]]
        dst_emb = node_emb[edge_index[1]]
        pair_feats = torch.cat((src_emb, dst_emb, edge_attr), dim=1)

        # Drop the trailing singleton dimension: [num_edges, 1] -> [num_edges]
        return self.edge_mlp(pair_feats).squeeze(-1)


In [None]:
# Cell 8: Training Loop for Edge Classification

device = torch.device("cpu")  # Force CPU for debugging

class FocalLoss(nn.Module):
    """Alpha-balanced binary focal loss on raw logits, averaged over all elements.

        loss_i = alpha_t * (1 - p_t) ** gamma * BCE_i
        p_t     = p      if label == 1 else 1 - p
        alpha_t = alpha  if label == 1 else 1 - alpha

    Args:
        alpha: weight of the positive class; negatives get 1 - alpha, so
            alpha > 0.5 up-weights positives. (Previously alpha multiplied every
            element equally, which only rescaled the loss.)
        gamma: focusing parameter; larger values down-weight easy examples.
    """

    def __init__(self, alpha=0.25, gamma=2.0):
        super().__init__()
        self.alpha = alpha  # Weight for positive class
        self.gamma = gamma  # Focuses on hard examples

    def forward(self, logits, labels):
        """logits and labels share a shape; labels are float 0/1 targets."""
        bce_loss = F.binary_cross_entropy_with_logits(logits, labels, reduction='none')
        # exp(-BCE) recovers p_t: p for label 1, 1 - p for label 0
        pt = torch.exp(-bce_loss)
        # Class-balancing weight: alpha for positives, 1 - alpha for negatives
        alpha_t = self.alpha * labels + (1.0 - self.alpha) * (1.0 - labels)
        focal_loss = (alpha_t * (1 - pt) ** self.gamma * bce_loss).mean()
        return focal_loss

# 2) Build the edge classifier and move it to the selected device
model = EdgeClassifier(
    in_channels_node=3,   # node feature width
    in_channels_edge=4,   # edge feature width
    hidden_dim=32,
    num_layers=2,         # Linear layers in the edge-scoring MLP
).to(device)

# 3) Loss function
# pos_weight = negative/positive edge ratio, capped at 100:1. It only feeds the
# BCEWithLogitsLoss alternative below; the active criterion is the focal loss.
# NOTE(review): 245398 / 8252 look like hard-coded negative/positive edge counts
# (~3.25% positives) — confirm they still match Cell 5's output.
pos_weight_value = min(245398.0 / 8252.0, 100.0)
pos_weight = torch.tensor([pos_weight_value], device=device)
# Alternative: criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
criterion = FocalLoss(alpha=0.75, gamma=2.0)

# The optimizer has to exist before the training helpers are called
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# 4) Helper function to train for 1 epoch
def train_one_epoch(model, loader, device, optimizer, loss_fn=None):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: callable as model(x, edge_index, edge_attr) -> per-edge logits.
        loader: iterable of batches exposing x, edge_index, edge_attr, edge_mask
            and .to(device).
        device: device each batch is moved to.
        optimizer: optimizer over the model's parameters.
        loss_fn: loss on (logits, float edge labels). Defaults to the
            module-level ``criterion`` so existing four-argument calls keep working.

    Returns:
        float: average of the per-batch loss values.

    Raises:
        RuntimeError: if any gradient contains NaN or Inf (checked before clipping).
        ValueError: if the loader yields no batches.
    """
    if loss_fn is None:
        loss_fn = criterion  # module-level loss from the setup cell
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted directly, so loaders without __len__ also work

    for batch in loader:
        batch = batch.to(device)

        logits = model(batch.x, batch.edge_index, batch.edge_attr)
        edge_labels = batch.edge_mask.float()
        loss = loss_fn(logits, edge_labels)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()

        # Fail fast on non-finite gradients; clipping cannot repair them
        for name, param in model.named_parameters():
            if param.grad is None:
                continue
            if torch.isnan(param.grad).any():
                print(f" NaN detected in gradients of {name}")
                raise RuntimeError("NaN detected in gradients")
            if torch.isinf(param.grad).any():
                print(f" Inf detected in gradients of {name}")
                raise RuntimeError("Inf detected in gradients")

        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_one_epoch: loader yielded no batches")
    return total_loss / num_batches

# 5) Enhanced evaluation metrics
def calculate_metrics(preds, labels):
    """Precision, recall and F1 of 0/1 predictions against 0/1 labels.

    Each denominator carries a 1e-8 term, so scores come out ~0 instead of
    raising when there are no predicted (or no actual) positives.

    Returns:
        (precision, recall, f1) as Python floats.
    """
    predicted_pos = preds == 1
    actual_pos = labels == 1
    true_pos = (predicted_pos & actual_pos).sum().item()
    false_pos = (predicted_pos & (labels == 0)).sum().item()
    false_neg = ((preds == 0) & actual_pos).sum().item()

    eps = 1e-8  # avoid division by zero
    precision = true_pos / (true_pos + false_pos + eps)
    recall = true_pos / (true_pos + false_neg + eps)
    f1 = 2 * (precision * recall) / (precision + recall + eps)
    return precision, recall, f1

# 6) Modified evaluate function
def evaluate(model, loader, device, loss_fn=None, threshold=0.4):
    """Evaluate an edge classifier over every batch of ``loader``.

    Precision, recall and F1 are computed once from true-positive,
    false-positive and false-negative counts accumulated over the whole loader
    (micro-averaged over all edges). Averaging per-batch scores instead is
    biased: these metrics are not additive, and a batch with no predicted
    positives scores precision ~0 regardless of its size.

    Args:
        model: callable as model(x, edge_index, edge_attr) -> per-edge logits.
        loader: iterable of batches exposing x, edge_index, edge_attr, edge_mask
            and .to(device).
        device: device each batch is moved to.
        loss_fn: loss on (logits, float edge labels); defaults to the
            module-level ``criterion``.
        threshold: sigmoid probability above which an edge is predicted as failed.

    Returns:
        (avg_loss, metrics): mean per-batch loss and a dict with keys
        'precision', 'recall', 'f1'.

    Raises:
        ValueError: if the loader yields no batches.
    """
    if loss_fn is None:
        loss_fn = criterion  # module-level loss from the setup cell
    model.eval()
    total_loss = 0.0
    num_batches = 0
    tp = fp = fn = 0

    with torch.no_grad():
        for batch in loader:
            batch = batch.to(device)
            logits = model(batch.x, batch.edge_index, batch.edge_attr)
            edge_labels = batch.edge_mask.float()

            total_loss += loss_fn(logits, edge_labels).item()
            num_batches += 1

            # Accumulate confusion counts globally across batches
            preds = torch.sigmoid(logits) > threshold
            actual_pos = edge_labels == 1
            actual_neg = edge_labels == 0
            tp += (preds & actual_pos).sum().item()
            fp += (preds & actual_neg).sum().item()
            fn += (~preds & actual_pos).sum().item()

    if num_batches == 0:
        raise ValueError("evaluate: loader yielded no batches")

    precision = tp / (tp + fp + 1e-8)  # 1e-8 avoids division by zero
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
    metrics = {'precision': precision, 'recall': recall, 'f1': f1}
    return total_loss / num_batches, metrics

# 7) Training loop with improved logging
num_epochs = 5

# Check explicitly that Cell 6's DataLoaders exist. Wrapping the whole loop in
# `except NameError` would also swallow unrelated NameErrors raised inside
# training/evaluation and misreport them as a missing DataLoader.
missing_loaders = [name for name in ("train_loader", "val_loader") if name not in globals()]
if missing_loaders:
    print(f" ERROR: Missing DataLoader! {', '.join(missing_loaders)} not defined")
    print("Make sure train_loader and val_loader are defined before running the training loop.")
else:
    for epoch in range(1, num_epochs + 1):
        # Train
        train_loss = train_one_epoch(model, train_loader, device, optimizer)

        # Validate
        val_loss, val_metrics = evaluate(model, val_loader, device)

        print(f"Epoch {epoch}/{num_epochs}")
        print(f"  Train Loss: {train_loss:.4f}")
        print(f"  Val Loss:   {val_loss:.4f}")
        print(f"  Val Precision: {val_metrics['precision']:.4f}")
        print(f"  Val Recall:    {val_metrics['recall']:.4f}")
        print(f"  Val F1:        {val_metrics['f1']:.4f}")
        print("-" * 50)
