In [1]:
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import HeteroData
from sklearn.preprocessing import StandardScaler
import warnings
warnings.filterwarnings('ignore')

# Paths
DATA_PATH = "Incidents_imputed.xlsx"
GRAPH_PATH = "Hetro_Final_NW_graph_1.pt"


In [2]:
incident_df = pd.read_excel(DATA_PATH, parse_dates=['Job OFF Time', 'Job ON Time'])
hetero_graph = torch.load(GRAPH_PATH)

print(f"Loaded incidents: {len(incident_df)}")
print("Loaded Graph Structure:")
print(hetero_graph)


Loaded incidents: 292829
Loaded Graph Structure:
HeteroData(
  substation={
    x=[347, 8],
    node_ids=[347],
  },
  (substation, spatial, substation)={
    edge_index=[2, 23378],
    edge_attr=[23378, 8],
  },
  (substation, temporal, substation)={
    edge_index=[2, 38188],
    edge_attr=[38188, 2],
  },
  (substation, causal, substation)={
    edge_index=[2, 10136],
    edge_attr=[10136, 13],
  }
)


In [3]:
# Check if node IDs are stored in metadata
print("Graph Metadata:", hetero_graph.metadata)

# Check if there's a dictionary of node mappings
if hasattr(hetero_graph['substation'], 'node_ids'):
    print("Substation Node IDs:", hetero_graph['substation'].node_ids)
elif hasattr(hetero_graph['substation'], 'node_names'):
    print("Substation Node Names:", hetero_graph['substation'].node_names)
elif hasattr(hetero_graph['substation'], 'mapping'):
    print("Substation Mapping:", hetero_graph['substation'].mapping)
else:
    print("❌ No explicit substation node IDs found in metadata!")


Graph Metadata: <bound method HeteroData.metadata of HeteroData(
  substation={
    x=[347, 8],
    node_ids=[347],
  },
  (substation, spatial, substation)={
    edge_index=[2, 23378],
    edge_attr=[23378, 8],
  },
  (substation, temporal, substation)={
    edge_index=[2, 38188],
    edge_attr=[38188, 2],
  },
  (substation, causal, substation)={
    edge_index=[2, 10136],
    edge_attr=[10136, 13],
  }
)>
Substation Node IDs: ['7317:KONAWA OC PUMP', '7312:JUMPER CREEK', '7506:SASAKWA', '7412:PEARSON', '7508:EMAHAKA', '7417:TRIBBEY', '7307:LITTLE RIVER', '7410:MAUD TAP', '7505:WEWOKA', '7208:CYPRESS', '7429:MACOMB OC PUMP', '7306:FIXICO', '7321:LETHA', '7407:REMINGTON', '7405:SHAWNEE', '7430:INGLEWOOD', '7435:MISSION HILL', '7512:CROMWELL', '8617:SUNNYLANE', '8696:TINKER 6', '8687:TINKER FIELD 5', '8685:TINKER FIELD 4', '8697:TINKER FIELD 3', '8618:BARNES', '8686:GENERAL MOTORS', '7409:MCLOUD', '7411:DALE', '7436:WOLVERINE', '7437:MOBIL CHEMICAL', '7433:ROCK CREEK', '7432:ST GREGORY'

In [4]:
import numpy as np
import pandas as pd
import re

def create_target_variable(df, severe_causes=None, critical_equipment=None, return_thresholds=False):
    """
    Creates target labels using a 180-day window with cleaned substation names.
    """

    # 1 Clean and standardize Job Substation names (Embedded)
    def clean_substation_name(x):
        match = re.match(r'(\d+)\s*:\s*(.+)', str(x).strip().upper())
        return f"{match.group(1)}:{match.group(2).strip()}" if match else None

    df['Job Substation'] = df['Job Substation'].astype(str).apply(clean_substation_name)

    # Drop invalid substations (Ensuring before processing)
    dropped_count = df['Job Substation'].isna().sum()
    df.dropna(subset=['Job Substation'], inplace=True)

    # 2 Compute Cause Thresholds
    cause_stats = df.groupby('Cause Desc').agg(
        median_duration=('Job Duration Mins', 'median'),
        max_customers=('Custs Affected', 'max')
    )
    cause_thresholds = cause_stats.quantile(0.90)
    
    # 3 Identify Severe Causes
    if severe_causes is None:
        severe_causes = cause_stats[
            (cause_stats['median_duration'] > cause_thresholds['median_duration']) &
            (cause_stats['max_customers'] > cause_thresholds['max_customers'])
        ].index.unique()
    
    # 4 Identify Critical Equipment
    equip_stats = df.groupby('Equip Desc').agg(
        failure_freq=('Job Display ID', 'count'),
        saidi=('Job SAIDI', 'mean')
    )
    equipment_thresholds = equip_stats.quantile(0.90)
    
    if critical_equipment is None:
        critical_equipment = df.groupby('Equip Desc').filter(
            lambda x: (len(x) > equipment_thresholds['failure_freq']) and
                      (x['Job SAIDI'].mean() > equipment_thresholds['saidi'])
        )['Equip Desc'].unique()
    
    # 5 Sorting and Computing Intervals
    df_sorted = df.sort_values(['Job Substation', 'Equip Desc', 'Job OFF Time']).copy()
    df_sorted['days_until_next_incident'] = df_sorted.groupby(
        ['Job Substation', 'Equip Desc']
    )['Job OFF Time'].diff(-1).dt.days.abs()
    
    # 6 Assign `needs_replacement` flag
    df_sorted['needs_replacement'] = np.where(
        (df_sorted['Cause Desc'].isin(severe_causes)) & 
        (df_sorted['Equip Desc'].isin(critical_equipment)) & 
        ((df_sorted['days_until_next_incident'] > 180) | 
         df_sorted['days_until_next_incident'].isna()),
        1,
        0
    )

    # 7 Final Aggregation using Cleaned Substations
    final_targets = df_sorted.groupby('Job Substation')['needs_replacement'].max()

    # Debugging Info
    print(f"Dropped {dropped_count} invalid substations due to format issues before processing.")
    print(f"Final unique substations flagged: {final_targets.sum()} out of {final_targets.count()} total substations.")

    # Return results
    if return_thresholds:
        return final_targets, severe_causes, critical_equipment, df_sorted
    else:
        return final_targets, df_sorted

# Run the function using the loaded incident data
full_targets, filtered_incidents = create_target_variable(incident_df, return_thresholds=False)

# Debugging: Check NaN values in `days_until_next_incident`
print("\nSubstations with NaN in days_until_next_incident:")
print(filtered_incidents['days_until_next_incident'].isna().sum())

print("\nSubstations with days_until_next_incident > 180:")
print((filtered_incidents['days_until_next_incident'] > 180).sum())

print("\nFinal Target Distribution After Fix:")
print(full_targets.value_counts())


Dropped 2 invalid substations due to format issues before processing.
Final unique substations flagged: 178 out of 362 total substations.

Substations with NaN in days_until_next_incident:
10648

Substations with days_until_next_incident > 180:
15987

Final Target Distribution After Fix:
needs_replacement
0    184
1    178
Name: count, dtype: int64


In [5]:
import torch
import numpy as np

def integrate_target_variable(graph, target_variable):
    """
    Integrates the computed target labels into the graph structure safely.
    """

    # 1 Extract substation names from graph metadata
    substation_names = graph['substation'].node_ids  

    # 2 Validate alignment (Check if all substations exist in both)
    graph_subs = set(substation_names)
    target_subs = set(target_variable.index)

    missing_from_graph = target_subs - graph_subs
    missing_from_target = graph_subs - target_subs

    if missing_from_graph:
        print(f"⚠ Warning: {len(missing_from_graph)} substations in target are missing from graph.")
        print("Examples:", list(missing_from_graph)[:5])

    if missing_from_target:
        print(f"⚠ Warning: {len(missing_from_target)} substations in graph are missing from target.")
        print("Examples:", list(missing_from_target)[:5])

    # 3 Align using the graph's node order (Ensuring safe reindexing)
    target_variable = target_variable.reindex(substation_names).fillna(0)  # Fill missing targets with 0

    # 4 Convert to tensor and integrate into graph
    graph['substation'].y = torch.tensor(
        target_variable.values.astype(np.float32), dtype=torch.float
    )

    # 5 Final integrity check
    assert len(graph['substation'].y) == len(substation_names), "Target tensor size mismatch with graph nodes!"

    print("\n Targets integrated successfully!")
    print(f"Graph node count: {len(substation_names)}")
    print(f"Target distribution:\n{target_variable.value_counts()}")

    return graph

# Usage: Integrate targets into the loaded graph
hetero_graph = integrate_target_variable(
    graph=hetero_graph,
    target_variable=full_targets  
)


Examples: ['8168:CEMETERY RD', '3329:KEETOOWAH', '4233:COWBOY HILL', '3117:THREE RIVERS', '8417:ROUND BARN']

 Targets integrated successfully!
Graph node count: 347
Target distribution:
needs_replacement
0    177
1    170
Name: count, dtype: int64


In [6]:
# Get incident date range
start_date = incident_df['Job OFF Time'].min().strftime('%Y-%m-%d')
end_date = incident_df['Job OFF Time'].max().strftime('%Y-%m-%d')
print(f"Incidents span: {start_date} to {end_date}")

# Plot monthly incident counts
plt.figure(figsize=(12,4))
incident_df.set_index('Job OFF Time').resample('M')['Job Display ID'].count().plot()
plt.title("Incident Frequency Over Time")
plt.xlabel("Date")
plt.ylabel("# Incidents")
plt.axvline(pd.Timestamp('2022-06-01'), color='red', linestyle='--', label='Proposed Split')
plt.show()

Incidents span: 2015-01-01 to 2021-12-31


NameError: name 'plt' is not defined

In [7]:
# 1 Calculate label determination date for each substation
label_dates = incident_df.groupby('Job Substation')['Job OFF Time'].max() + pd.Timedelta(days=180)

# 2 Compute split dates (Train 60% / Val 20% / Test 20%)
split_train = label_dates.quantile(0.60)  # Train: 60%
split_val = label_dates.quantile(0.80)    # Validation: 20%, Test: 20%

# 3 Assign substations to train, validation, and test sets
train_subs = label_dates[label_dates <= split_train].index
val_subs = label_dates[(label_dates > split_train) & (label_dates <= split_val)].index
test_subs = label_dates[label_dates > split_val].index

# 4 Apply masks to the graph (Ensure `val_mask` is included)
hetero_graph['substation'].train_mask = torch.tensor(
    [n in train_subs for n in hetero_graph['substation'].node_ids], dtype=torch.bool
)
hetero_graph['substation'].val_mask = torch.tensor(  
    [n in val_subs for n in hetero_graph['substation'].node_ids], dtype=torch.bool
)
hetero_graph['substation'].test_mask = torch.tensor(
    [n in test_subs for n in hetero_graph['substation'].node_ids], dtype=torch.bool
)

# 5 Verify split distribution
train_count = hetero_graph['substation'].train_mask.sum().item()
val_count = hetero_graph['substation'].val_mask.sum().item()
test_count = hetero_graph['substation'].test_mask.sum().item()

print(f" Train: {train_count}, Val: {val_count}, Test: {test_count}")
if train_count + val_count + test_count != len(hetero_graph['substation'].node_ids):
    print(f"⚠ Warning: Some nodes may be missing from the split!")


 Train: 206, Val: 70, Test: 71


In [8]:
class WeightedGCNConv(nn.Module):
    def __init__(self, in_channels, out_channels, edge_attr_dim):
        super().__init__()
        self.gcn = GCNConv(in_channels, out_channels, add_self_loops=False)
        
        # Dynamically infer edge_attr_dim at runtime
        self.edge_mlp = nn.Sequential(
            nn.Linear(edge_attr_dim, 64),  # Adjusted dynamically
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )

        self.batch_norm = nn.BatchNorm1d(out_channels)
        self.residual_proj = nn.Linear(in_channels, out_channels) if in_channels != out_channels else nn.Identity()
        
        self._init_weights()

    def _init_weights(self):
        for layer in self.edge_mlp:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, mode='fan_out', nonlinearity='relu')
                nn.init.constant_(layer.bias, 0.1)

    def forward(self, x, edge_index, edge_attr):
        if edge_index.size(1) == 0:
            return self.residual_proj(x)  # Handle empty graphs safely

        #  Dynamically check edge attributes
        edge_dim = edge_attr.shape[1] if edge_attr.dim() > 1 else 1

        if edge_dim != self.edge_mlp[0].in_features:
            print(f"⚠ Adjusting edge MLP: Expected {self.edge_mlp[0].in_features}, got {edge_dim}")
            self.edge_mlp[0] = nn.Linear(edge_dim, 64).to(edge_attr.device)

        # Normalize edge attributes
        edge_attr = (edge_attr - edge_attr.mean(dim=0)) / (edge_attr.std(dim=0) + 1e-8)
        edge_attr = edge_attr.unsqueeze(-1) if edge_attr.dim() == 1 else edge_attr
            
        # Compute edge weights
        weights = self.edge_mlp(edge_attr).squeeze()

        # Apply GCN with weighted edges
        out = self.gcn(x, edge_index, edge_weight=weights + 1e-8)

        # Apply batch normalization
        out = self.batch_norm(out)

        return out + self.residual_proj(x)


In [9]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv

class PowerGridGNN(nn.Module):
    def __init__(self, in_channels, hidden_dim, edge_dims, num_layers=3):
        super().__init__()
        self.edge_types = list(edge_dims.keys())  # Dynamic edge types
        self.layers = nn.ModuleList()
        current_dim = in_channels
        
        # Validate edge dimensions
        assert set(edge_dims.keys()) == set(self.edge_types), "Mismatch in provided edge dimensions!"
        
        # Initialize GCN layers
        for _ in range(num_layers):
            self.layers.append(nn.ModuleDict({
                et: WeightedGCNConv(current_dim, hidden_dim, edge_dims[et])
                for et in self.edge_types
            }))
            current_dim = hidden_dim  # Maintain consistency across layers
            
        # Prediction head
        self.head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.BatchNorm1d(hidden_dim // 2),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim // 2, 1),
            nn.Sigmoid()
        )
        
        # Learnable aggregation across edge types
        self.aggregation_weights = nn.Parameter(torch.ones(len(self.edge_types)))
        
        # Initialize weights
        self._init_weights()

    def _init_weights(self):
        """Apply proper initialization to all layers"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')

    def _register_edge_stats(self, data):
        """Compute edge feature statistics from data and store them for normalization"""
        for et in self.edge_types:
            attr = data['substation', et, 'substation'].edge_attr
            self.register_buffer(f'{et}_mean', attr.mean(dim=0))
            self.register_buffer(f'{et}_std', attr.std(dim=0) + 1e-8)  # Prevent division by zero

    def forward(self, data, mode='train'):
        x = data['substation'].x
        edge_masks = self._get_edge_masks(data, mode)

        for layer in self.layers:
            messages = []
            for et in self.edge_types:
                edge_data = data['substation', et, 'substation']
                idx, attr = edge_data.edge_index, edge_data.edge_attr
                
                # Apply strict edge masking
                mask = edge_masks[et]
                idx, attr = idx[:, mask], attr[mask]
                
                # Handle empty edges
                if idx.shape[1] == 0:
                    messages.append(x.new_zeros(x.size(0), layer[et].gcn.out_channels))
                else:
                    # Edge Normalization
                    attr = (attr - getattr(self, f'{et}_mean')) / getattr(self, f'{et}_std')
                    messages.append(layer[et](x, idx, attr))

            # Learnable weighted aggregation across edge types
            x = torch.stack(messages, dim=0)  # Shape: (num_edge_types, num_nodes, hidden_dim)
            x = torch.sum(x * F.softmax(self.aggregation_weights, dim=0)[:, None, None], dim=0)
            x = F.relu(x)
            x = F.dropout(x, p=0.3, training=self.training)

        return self.head(x)

    def _get_edge_masks(self, data, mode='train'):
        """Create edge masks to strictly isolate train/test splits"""
        masks = {}
        split_mask = getattr(data['substation'], f"{mode}_mask").bool().to(data['substation'].x.device)

        for et in self.edge_types:
            edge_info = data['substation', et, 'substation']
            edge_index = edge_info.edge_index
            
            # Strict isolation: both nodes must be in the current split
            source_mask = split_mask[edge_index[0]]
            target_mask = split_mask[edge_index[1]]
            masks[et] = source_mask & target_mask  # Ensures strict split enforcement
        
        return masks


In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    """Focal Loss to handle class imbalance"""
    def __init__(self, alpha=0.75, gamma=2):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, inputs, targets):
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        return (self.alpha * (1 - pt) ** self.gamma * BCE_loss).mean()

# --- Device Setup ---
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hetero_graph = hetero_graph.to(device)

# --- Hyperparameters ---
in_channels = hetero_graph['substation'].x.size(1)
hidden_channels = 128
edge_attr_dims = {'spatial': 8, 'temporal': 2, 'causal': 13}
num_epochs = 200
patience = 20

# --- Model & Optimizer ---
model = PowerGridGNN(
    in_channels=in_channels,
    hidden_dim=hidden_channels,
    edge_dims=edge_attr_dims,
    num_layers=2
).to(device)

# Register edge stats for normalization (if needed)
model._register_edge_stats(hetero_graph)  

optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=1e-5)
criterion = FocalLoss(alpha=0.75, gamma=2)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)

# --- Training Loop ---
best_val_loss = float('inf')
no_improve = 0

for epoch in range(1, num_epochs + 1):
    model.train()
    optimizer.zero_grad()
    
    out = model(hetero_graph, mode='train')
    loss = criterion(out[hetero_graph['substation'].train_mask].squeeze(),
                     hetero_graph['substation'].y[hetero_graph['substation'].train_mask].squeeze().float())

    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Prevents exploding gradients
    optimizer.step()
    
    # --- Validation ---
    model.eval()
    with torch.no_grad():
        out_val = model(hetero_graph, mode='val')
        val_loss = F.binary_cross_entropy_with_logits(
            out_val[hetero_graph['substation'].val_mask].squeeze(),
            hetero_graph['substation'].y[hetero_graph['substation'].val_mask].squeeze().float()
        )
    scheduler.step(val_loss)
    
    # --- Early stopping ---
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        no_improve = 0
        torch.save(model.state_dict(), 'best_model.pt')
    else:
        no_improve += 1
    
    # --- Logging ---
    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch} | Train Loss: {loss.item():.4f} | Val Loss: {val_loss.item():.4f} | LR: {optimizer.param_groups[0]['lr']:.2e}")
    
    if no_improve >= patience:
        print("🛑 Early stopping!")
        break

# --- Final Test ---
model.load_state_dict(torch.load('best_model.pt'))
model.eval()
with torch.no_grad():
    out_test = model(hetero_graph, mode='test')
    probs = torch.sigmoid(out_test[hetero_graph['substation'].test_mask].squeeze())
    predictions = (probs > 0.6).float()
    
    # Compute performance metrics
    y_true = hetero_graph['substation'].y[hetero_graph['substation'].test_mask].cpu()
    y_pred = predictions.cpu()
    
    acc = (y_true == y_pred).float().mean()
    precision = (y_true * y_pred).sum() / (y_pred.sum() + 1e-8)
    recall = (y_true * y_pred).sum() / (y_true.sum() + 1e-8)
    f1 = 2 * (precision * recall) / (precision + recall + 1e-8)
    
    print("\n Final Test Metrics:")
    print(f"Accuracy: {acc:.4f} | Precision: {precision:.4f}")
    print(f"Recall: {recall:.4f} | F1: {f1:.4f}")


Epoch 1 | Train Loss: 0.2212 | Val Loss: 0.6851 | LR: 1.00e-03
Epoch 10 | Train Loss: 0.1994 | Val Loss: 0.6700 | LR: 1.00e-03
Epoch 20 | Train Loss: 0.1854 | Val Loss: 0.6690 | LR: 1.00e-03
Epoch 30 | Train Loss: 0.1704 | Val Loss: 0.6678 | LR: 1.00e-03
Epoch 40 | Train Loss: 0.1684 | Val Loss: 0.6648 | LR: 1.00e-03
Epoch 50 | Train Loss: 0.1601 | Val Loss: 0.6634 | LR: 1.00e-03
Epoch 60 | Train Loss: 0.1469 | Val Loss: 0.6605 | LR: 1.00e-03
Epoch 70 | Train Loss: 0.1448 | Val Loss: 0.6582 | LR: 1.00e-03
Epoch 80 | Train Loss: 0.1398 | Val Loss: 0.6563 | LR: 1.00e-03
Epoch 90 | Train Loss: 0.1417 | Val Loss: 0.6537 | LR: 1.00e-03
Epoch 100 | Train Loss: 0.1383 | Val Loss: 0.6511 | LR: 1.00e-03
Epoch 110 | Train Loss: 0.1397 | Val Loss: 0.6480 | LR: 1.00e-03
Epoch 120 | Train Loss: 0.1333 | Val Loss: 0.6453 | LR: 1.00e-03
Epoch 130 | Train Loss: 0.1336 | Val Loss: 0.6430 | LR: 1.00e-03
Epoch 140 | Train Loss: 0.1334 | Val Loss: 0.6415 | LR: 1.00e-03
Epoch 150 | Train Loss: 0.1288 | Val

In [11]:
import itertools
import torch
import torch.nn.functional as F
import numpy as np

# --- Define a function to train and validate the model ---
def train_and_validate(model, data, criterion, optimizer, scheduler, 
                       train_mask, val_mask, device, epochs=100, patience=20):
    best_val_loss = float('inf')
    no_improve = 0

    for epoch in range(1, epochs + 1):
        #  Training Step
        model.train()
        optimizer.zero_grad()
        out = model(data, mode='train')
        loss = criterion(
            out[train_mask].squeeze(),
            data['substation'].y[train_mask].squeeze().float()
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        #  Validation Step
        model.eval()
        with torch.no_grad():
            out_val = model(data, mode='val')
            val_loss = F.binary_cross_entropy_with_logits(
                out_val[val_mask].squeeze(),
                data['substation'].y[val_mask].squeeze().float()
            )
        scheduler.step(val_loss)

        #  Early stopping check
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve = 0
        else:
            no_improve += 1

        if no_improve >= patience:
            print(f" Early stopping at epoch {epoch} (Best Val Loss: {best_val_loss:.4f})")
            break

    return best_val_loss.item()

# --- Hyperparameter grid ---
learning_rates = [0.001, 0.0005]
hidden_dims = [128, 256]
num_layers_list = [2, 3]
weight_decays = [1e-4, 1e-5]

results = []

#  Hyperparameter tuning loop
for lr, hidden_dim, num_layers, wd in itertools.product(learning_rates, hidden_dims, num_layers_list, weight_decays):
    print(f" Testing: lr={lr}, hidden_dim={hidden_dim}, num_layers={num_layers}, wd={wd}")
    
    # Instantiate a fresh model for each run
    model = PowerGridGNN(
        in_channels=hetero_graph['substation'].x.size(1),
        hidden_dim=hidden_dim,
        edge_dims=edge_attr_dims,
        num_layers=num_layers
    ).to(device)
    model._register_edge_stats(hetero_graph)  # Edge normalization

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
    criterion = FocalLoss(alpha=0.75, gamma=2)
    
    # Train and Validate
    best_val_loss = train_and_validate(
        model, 
        hetero_graph, 
        criterion, 
        optimizer, 
        scheduler, 
        hetero_graph['substation'].train_mask, 
        hetero_graph['substation'].val_mask, 
        device
    )
    
    results.append({
        'lr': lr,
        'hidden_dim': hidden_dim,
        'num_layers': num_layers,
        'wd': wd,
        'val_loss': best_val_loss
    })
    
    print(f" Result: best_val_loss = {best_val_loss:.4f}\n")

# --- 🔹 Print Best Hyperparameter Configuration ---
best_config = min(results, key=lambda x: x['val_loss'])
print("\n Best Hyperparameter Configuration:")
print(best_config)


 Testing: lr=0.001, hidden_dim=128, num_layers=2, wd=0.0001
 Result: best_val_loss = 0.6544

 Testing: lr=0.001, hidden_dim=128, num_layers=2, wd=1e-05
 Result: best_val_loss = 0.6639

 Testing: lr=0.001, hidden_dim=128, num_layers=3, wd=0.0001
 Result: best_val_loss = 0.6438

 Testing: lr=0.001, hidden_dim=128, num_layers=3, wd=1e-05
 Result: best_val_loss = 0.6506

 Testing: lr=0.001, hidden_dim=256, num_layers=2, wd=0.0001
 Result: best_val_loss = 0.6332

 Testing: lr=0.001, hidden_dim=256, num_layers=2, wd=1e-05
 Result: best_val_loss = 0.6389

 Testing: lr=0.001, hidden_dim=256, num_layers=3, wd=0.0001
 Result: best_val_loss = 0.6369

 Testing: lr=0.001, hidden_dim=256, num_layers=3, wd=1e-05
 Result: best_val_loss = 0.6358

 Testing: lr=0.0005, hidden_dim=128, num_layers=2, wd=0.0001
 Result: best_val_loss = 0.6653

 Testing: lr=0.0005, hidden_dim=128, num_layers=2, wd=1e-05
 Result: best_val_loss = 0.6664

 Testing: lr=0.0005, hidden_dim=128, num_layers=3, wd=0.0001
 Result: bes

In [12]:
# --- Load Best Hyperparameters ---
best_params = {'lr': 0.001, 'hidden_dim': 256, 'num_layers': 2, 'wd': 0.0001}

# --- Initialize Model ---
best_model = PowerGridGNN(
    in_channels=hetero_graph['substation'].x.size(1),
    hidden_dim=best_params['hidden_dim'],
    edge_dims=edge_attr_dims,
    num_layers=best_params['num_layers']
).to(device)
best_model._register_edge_stats(hetero_graph)  # Normalize edge attributes

# --- Define Optimizer & Loss ---
optimizer = torch.optim.AdamW(best_model.parameters(), lr=best_params['lr'], weight_decay=best_params['wd'])
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=0.5, patience=10)
criterion = FocalLoss(alpha=0.75, gamma=2)  # Adjust if needed

# --- Train the Best Model ---
final_val_loss = train_and_validate(
    best_model, 
    hetero_graph, 
    criterion, 
    optimizer, 
    scheduler, 
    hetero_graph['substation'].train_mask, 
    hetero_graph['substation'].val_mask, 
    device, 
    epochs=200, 
    patience=20
)

print(f"✅ Final Model Training Completed. Best Val Loss: {final_val_loss:.4f}")


✅ Final Model Training Completed. Best Val Loss: 0.6159


In [13]:
import torch

# --- Load Best Model Weights ---
best_model.eval()
with torch.no_grad():
    out_test = best_model(hetero_graph, mode='test')
    probs = torch.sigmoid(out_test[hetero_graph['substation'].test_mask].squeeze())
    predictions = (probs > 0.5).float()

# --- Extract Ground Truth ---
y_true = hetero_graph['substation'].y[hetero_graph['substation'].test_mask].cpu()
y_pred = predictions.cpu()

# --- Compute Metrics ---
acc = (y_true == y_pred).float().mean()
precision = (y_true * y_pred).sum() / (y_pred.sum() + 1e-8)
recall = (y_true * y_pred).sum() / (y_true.sum() + 1e-8)
f1 = 2 * (precision * recall) / (precision + recall + 1e-8)

# --- Print Metrics ---
print("\n **Final Test Metrics**:")
print(f" Accuracy: {acc:.4f}")
print(f" Precision: {precision:.4f}")
print(f" Recall: {recall:.4f}")
print(f" F1 Score: {f1:.4f}")



 **Final Test Metrics**:
 Accuracy: 0.6901
 Precision: 0.6901
 Recall: 1.0000
 F1 Score: 0.8167
