# Predicting Substation Failure Class (Classes_2) with a Heterogeneous GNN
We construct a substation-level heterogeneous graph from incident and line data, attach a binary target (**Classes_2**), and train a relation-aware GNN that leverages **edge attributes**. We report metrics on a stratified train/val/test split and save the artifacts needed to reproduce the results.


In [1]:
# Cell 1 — Setup & Repro
import os, json, random
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn

from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, confusion_matrix

from torch_geometric.nn import HeteroConv, GATv2Conv
from torch_geometric.data import HeteroData

# One seed for every RNG the notebook touches (Python, NumPy, PyTorch).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Give up cuDNN autotuning in exchange for repeatable kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Use the GPU when PyTorch can see one, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device:", device)


Device: cuda


In [2]:
# Cell 2 — Paths & Config
# Incident-level CSV that every node feature and label is derived from.
DATA_PATH = "IncidentDataFinal.csv"
# Labeled graph file: read when extracting edges and rewritten when labels are attached.
GRAPH_PATH = "Hetero_Final_NW_graph_fixed_kara_labeled.pt"
# Column of the incidents CSV used as the node-level target.
TARGET_COL = "Classes_2"

# Output artifacts: best model weights, inference settings, per-node predictions.
MODEL_OUT = "hetero_gatv2_edge_best.pt"
CFG_OUT = "inference_config.json"
PRED_OUT = "gnn_predictions_by_node.csv"


In [3]:
# Cell 3 — Load incidents
# Read only the header first: it is enough to validate the schema and to see
# which date columns exist, so a malformed file fails before the full load.
_cols = pd.read_csv(DATA_PATH, nrows=0).columns
if 'Job Substation' not in _cols:
    raise ValueError(f"{DATA_PATH} must include 'Job Substation'.")

# Parse only the timestamp columns that are actually present.
_date_cols = [c for c in ('Job OFF Time', 'Job ON Time') if c in _cols]
incident_df = pd.read_csv(DATA_PATH, parse_dates=_date_cols)

# Clean the join key to a canonical form (UPPER + strip).
# NOTE(review): astype(str) turns a missing substation into the literal "NAN",
# which would later be grouped as a node of its own — confirm the CSV has no
# blank 'Job Substation' values.
incident_df['Job Substation'] = (
    incident_df['Job Substation'].astype(str).str.strip().str.upper()
)

print("Incidents shape:", incident_df.shape)


Incidents shape: (264458, 31)


In [4]:
# Cell 4 — Build node table (CSV-aligned)
def create_clean_node_features(incident_df):
    """Collapse incident rows into one feature row per 'Job Substation'.

    The input frame is copied, never modified. The storm flag column is
    mapped from Y/YES/N/NO (case-insensitive) to 1/0 before aggregation;
    any other value becomes NaN. Only columns that exist in the input are
    aggregated, and the result keeps 'Job Substation' first followed by the
    features in the order listed in the plan below.

    Returns:
        DataFrame with one row per distinct 'Job Substation'.
    """
    storm_col = 'Major Storm Event  Y (Yes) or N (No)'

    def _most_common(values):
        # Most frequent non-null value of the group; NaN when none is left.
        present = values.dropna()
        if present.empty:
            return np.nan
        return present.mode().iloc[0]

    frame = incident_df.copy()
    if storm_col in frame.columns:
        flags = frame[storm_col].astype(str).str.upper().str.strip()
        frame[storm_col] = flags.map({'Y': 1, 'YES': 1, 'N': 0, 'NO': 0})

    # column -> aggregation; the key order is also the output column order
    feature_plan = {
        # static substation attributes
        'X': 'mean',
        'Y': 'mean',
        'Voltage': 'mean',
        'Distribution, Substation, Transmission': _most_common,
        'Substation ID': _most_common,
        'PLANTSTATU': _most_common,
        'no_feeders': 'mean',
        'no_cities': 'mean',
        # network features (already present as columns of the CSV)
        'total_line_length': 'mean',
        'avg_line_voltage': 'mean',
        'num_connections': 'max',
        # prior-incident features
        'prior_avg_cust_affected': 'mean',
        'prior_avg_downtime': 'mean',
        'prior_incident_count': 'mean',
        'prior_avg_no_calls': 'mean',
        # incident features
        'Call Qty': 'sum',
        'Job Duration Mins': 'mean',
        'Custs Affected': 'sum',
        'Year': _most_common,
        storm_col: 'mean',
    }

    available = {
        col: how for col, how in feature_plan.items() if col in frame.columns
    }
    grouped = frame.groupby('Job Substation').agg(available).reset_index()

    # drop repeated column names (seen for 'Job Substation')
    grouped = grouped.loc[:, ~grouped.columns.duplicated()]

    ordered = ['Job Substation', *feature_plan]
    return grouped[[c for c in ordered if c in grouped.columns]].copy()

# Aggregate incidents into the per-substation node table.
node_features_clean = create_clean_node_features(incident_df)

# Re-apply the canonical key form (strip + upper) used for every name join.
node_features_clean['Job Substation'] = (
    node_features_clean['Job Substation']
    .astype(str)
    .str.strip()
    .str.upper()
)

print("node_features_clean:", node_features_clean.shape)
print(node_features_clean[['Job Substation', 'Substation ID', 'X', 'Y']].head())


node_features_clean: (194, 21)
       Job Substation  Substation ID          X          Y
0  3109:HONOR HEIGHTS           3109 -95.412010  35.771211
1      3110:RIVERSIDE           3110 -95.322400  35.762029
2    3111:FIVE TRIBES           3111 -95.396671  35.725030
3       3114:TENNYSON           3114 -95.396408  35.746585
4        3128:HANCOCK           3128 -95.347455  35.724962


In [5]:
# Cell 5 — Extract edge DataFrames from the loaded graph (PT object)

import pandas as pd
import numpy as np
import torch

# Ensure the graph is available as `g`: reuse the in-memory object when an
# earlier cell already produced one, otherwise read it from disk.
try:
    _ = g['substation'].x
except (NameError, KeyError, AttributeError, TypeError):
    # `g` is undefined or is not a graph with substation features.
    # NOTE: GRAPH_PATH is the same file the labeling cell writes later, so on
    # a fresh kernel it must already exist from a previous run.
    # weights_only=False: the file holds a pickled graph object, not a bare
    # tensor/state_dict, so the restricted loader cannot read it. Unpickling
    # can execute arbitrary code — only load files you created yourself.
    g = torch.load(GRAPH_PATH, weights_only=False)

# Node names (as stored in PT), canonicalised the same way as the CSV key.
node_names = [str(n).strip().upper() for n in getattr(g['substation'], 'node_ids', [])]
idx_to_name = dict(enumerate(node_names))

def _default_cols(dim):
    return [f"attr_{i}" for i in range(dim)]

def _cols_for(rel, dim):
    # Best-effort names based on your build shapes:
    # spatial: 8, temporal: 2, causal: 4
    if rel == 'spatial' and dim == 8:
        return ['has_line','is_nearby','line_voltage','line_length_km',
                'shared_cities','shared_feeders','distance_km','weight']
    if rel == 'temporal' and dim == 2:
        return ['total_weight','count']
    if rel == 'causal' and dim == 4:
        return ['z_score','cooccur_ratio','window_hrs','cooccur_count']
    return _default_cols(dim)

def _edges_to_frame(store, rel_name):
    """Flatten one relation's edge store into a source/target(+attributes) frame."""
    if hasattr(store, 'edge_index'):
        pairs = store.edge_index.detach().cpu().numpy()
    else:
        pairs = np.zeros((2, 0), dtype=int)

    def _name(i):
        # Indices without a stored name get a synthetic IDX_<n> label.
        return idx_to_name.get(int(i), f"IDX_{int(i)}")

    frame = pd.DataFrame({
        'source': [_name(i) for i in pairs[0]],
        'target': [_name(i) for i in pairs[1]],
    })

    attrs = getattr(store, 'edge_attr', None)
    if attrs is None or attrs.numel() == 0:
        return frame

    # edge_attr -> numeric columns (with best-effort names)
    values = attrs.detach().cpu().numpy()
    width = values.shape[1]
    names = _cols_for(rel_name, width)
    if len(names) != width:
        # Guard against a name list that does not match the attribute width.
        names = _default_cols(width)
    return pd.concat([frame, pd.DataFrame(values, columns=names)], axis=1)


# One DataFrame per relation, keyed by the middle element of the edge type,
# e.g. ('substation', 'spatial', 'substation') -> 'spatial'.
edge_frames = {}
for edge_type in g.edge_types:
    rel_name = edge_type[1]
    rel_frame = _edges_to_frame(g[edge_type], rel_name)
    edge_frames[rel_name] = rel_frame
    print(f"{rel_name}: extracted {len(rel_frame)} edges | columns: {list(rel_frame.columns)}")

# Expose as variables for later cells; a missing relation becomes an empty frame.
spatial_edges = edge_frames.get('spatial', pd.DataFrame(columns=['source', 'target']))
temporal_edges = edge_frames.get('temporal', pd.DataFrame(columns=['source', 'target']))
causal_edges = edge_frames.get('causal', pd.DataFrame(columns=['source', 'target']))

# Quick peek
for label, preview in zip(('spatial', 'temporal', 'causal'),
                          (spatial_edges, temporal_edges, causal_edges)):
    print(f"\n{label} head:")
    print(preview.head())


spatial: extracted 7738 edges | columns: ['source', 'target', 'has_line', 'is_nearby', 'line_voltage', 'line_length_km', 'shared_cities', 'shared_feeders', 'distance_km', 'weight']
temporal: extracted 19134 edges | columns: ['source', 'target', 'total_weight', 'count']
causal: extracted 7192 edges | columns: ['source', 'target', 'z_score', 'cooccur_ratio', 'window_hrs', 'cooccur_count']

spatial head:
               source          target  has_line  is_nearby  line_voltage  \
0  3109:HONOR HEIGHTS  3110:RIVERSIDE       0.0        1.0           0.0   
1  3109:HONOR HEIGHTS   3114:TENNYSON       1.0        1.0          69.0   
2  3109:HONOR HEIGHTS    3205:SAPULPA       0.0        1.0           0.0   
3  3109:HONOR HEIGHTS      3209:BIXBY       0.0        1.0           0.0   
4  3109:HONOR HEIGHTS  7605:DRUMRIGHT       0.0        1.0           0.0   

   line_length_km  shared_cities  shared_feeders  distance_km    weight  
0        0.000000            0.0             0.0     9.964664  0

  g = torch.load(GRAPH_PATH)


In [6]:
# Cell 6 — Build hetero graph keyed by Job Substation names
def build_hetero_by_name(node_df, spatial_edges, temporal_edges, causal_edges, save_path):
    """Assemble a HeteroData graph whose nodes are keyed by 'Job Substation'.

    Args:
        node_df: one row per substation; must contain 'Job Substation'. All
            numeric columns become the node feature matrix ``x``.
        spatial_edges, temporal_edges, causal_edges: frames with 'source' and
            'target' name columns plus optional numeric attribute columns.
            An empty frame is accepted and yields a relation with zero edges.
        save_path: file the finished graph is written to with ``torch.save``.

    Returns:
        The HeteroData object (also written to ``save_path``).

    Edges whose endpoints are not node names are dropped. Names on both sides
    are compared in canonical form (stripped, upper-cased).
    """
    nd = node_df.copy()
    nd['Job Substation'] = nd['Job Substation'].astype(str).str.strip().str.upper()

    # Numeric node features for x; inf/NaN are replaced by 0.
    # NOTE(review): every numeric column is used, which includes an
    # identifier-like 'Substation ID' when it is numeric — confirm intended.
    X_num = (nd.select_dtypes(include=[np.number])
               .replace([np.inf, -np.inf], np.nan)
               .fillna(0.0))
    node_names = nd['Job Substation'].tolist()
    name_to_idx = {n: i for i, n in enumerate(node_names)}

    data = HeteroData()
    data['substation'].x = torch.tensor(X_num.to_numpy(), dtype=torch.float32)
    data['substation'].node_ids = node_names
    if 'Substation ID' in nd.columns:
        data['substation'].substation_id = nd['Substation ID'].tolist()

    def _add(df, etype):
        # Canonicalise endpoint names the same way as the node key so that
        # case/whitespace differences do not silently drop edges.
        edges = df.copy()
        for col in ('source', 'target'):
            edges[col] = edges[col].astype(str).str.strip().str.upper()

        keep = edges['source'].isin(name_to_idx) & edges['target'].isin(name_to_idx)
        kept = edges.loc[keep]
        # Explicit int64: with zero kept edges `.map` yields an object-dtype
        # array, which torch.tensor cannot convert.
        src = kept['source'].map(name_to_idx).to_numpy(dtype=np.int64)
        dst = kept['target'].map(name_to_idx).to_numpy(dtype=np.int64)
        ei = torch.tensor(np.vstack([src, dst]), dtype=torch.long)
        data['substation', etype, 'substation'].edge_index = ei

        # Every remaining numeric column is an edge attribute.
        ea = (kept.drop(columns=['source', 'target'], errors='ignore')
                   .select_dtypes(include=[np.number])
                   .replace([np.inf, -np.inf], np.nan)
                   .fillna(0.0))
        if ea.shape[1] > 0:
            data['substation', etype, 'substation'].edge_attr = torch.tensor(
                ea.to_numpy(), dtype=torch.float32)
        print(f"{etype}: provided={len(df)} | kept={ei.size(1)}")

    _add(spatial_edges,  'spatial')
    _add(temporal_edges, 'temporal')
    _add(causal_edges,   'causal')

    torch.save(data, save_path)
    print("Saved graph ->", save_path)
    return data

# Rebuild the graph from the CSV-aligned node table and the extracted edges;
# this unlabeled copy goes to its own file.
g = build_hetero_by_name(
    node_df=node_features_clean,
    spatial_edges=spatial_edges,
    temporal_edges=temporal_edges,
    causal_edges=causal_edges,
    save_path="Hetero_Final_NW_graph_fixed_kara.pt",
)
print(g)


spatial: provided=7738 | kept=7738
temporal: provided=19134 | kept=19134
causal: provided=7192 | kept=7192
Saved graph -> Hetero_Final_NW_graph_fixed_kara.pt
HeteroData(
  substation={
    x=[194, 18],
    node_ids=[194],
    substation_id=[194],
  },
  (substation, spatial, substation)={
    edge_index=[2, 7738],
    edge_attr=[7738, 8],
  },
  (substation, temporal, substation)={
    edge_index=[2, 19134],
    edge_attr=[19134, 2],
  },
  (substation, causal, substation)={
    edge_index=[2, 7192],
    edge_attr=[7192, 4],
  }
)


In [7]:
# Cell 7 — Attach Classes_2 labels
def attach_classes2_by_name(graph, incident_df, target_col="Classes_2", save_path=None):
    """Write a per-substation majority label onto ``graph['substation']``.

    For each substation name the label is the most frequent rounded value of
    ``target_col`` among its incidents. Nodes with no usable value get
    ``y = -1`` and ``train_mask = False``. Note that ``train_mask`` here
    marks *labeled* nodes; it is not a train/val/test split.

    The graph is modified in place and returned. It is also saved with
    ``torch.save`` when ``save_path`` is given.

    Raises:
        ValueError: if ``target_col`` is not a column of ``incident_df``.
    """
    store = graph['substation']

    incidents = incident_df.copy()
    if target_col not in incidents.columns:
        raise ValueError(f"Missing '{target_col}' in incidents.")
    incidents[target_col] = pd.to_numeric(incidents[target_col], errors='coerce')
    incidents['Job Substation'] = (
        incidents['Job Substation'].astype(str).str.strip().str.upper()
    )

    def _majority(values):
        # Most frequent integer label of the group; NaN if nothing is usable.
        known = values.dropna()
        if not known.size:
            return np.nan
        return known.round().astype(int).mode().iloc[0]

    labels_by_name = incidents.groupby('Job Substation')[target_col].apply(_majority)

    node_names = [str(n).strip().upper() for n in store.node_ids]
    looked_up = [labels_by_name.get(name, np.nan) for name in node_names]
    mask_list = [not pd.isna(value) for value in looked_up]
    y_list = [int(value) if found else -1
              for value, found in zip(looked_up, mask_list)]

    store.y = torch.tensor(y_list, dtype=torch.long)
    store.train_mask = torch.tensor(mask_list, dtype=torch.bool)

    labeled = int(store.train_mask.sum().item())
    print(f"Attached '{target_col}' to {labeled}/{len(node_names)} nodes.")
    if save_path:
        torch.save(graph, save_path)
        print("Saved labeled graph ->", save_path)
    return graph

# Label the in-memory graph and persist it as the final labeled graph file.
g = attach_classes2_by_name(
    graph=g,
    incident_df=incident_df,
    target_col=TARGET_COL,
    save_path=GRAPH_PATH,
)


Attached 'Classes_2' to 194/194 nodes.
Saved labeled graph -> Hetero_Final_NW_graph_fixed_kara_labeled.pt


In [8]:
# Cell 8 — Prep: dicts + normalization + (optional) undirected spatial
# weights_only=False: the file holds the pickled graph object written by the
# labeling step, not a bare tensor/state_dict, so the restricted loader cannot
# read it. Unpickling can execute arbitrary code — only load trusted files.
g = torch.load(GRAPH_PATH, weights_only=False).to(device)


def _safe_std(t):
    """Column-wise std with zero or NaN entries replaced by 1.

    A constant column has std 0 and a single-row tensor has std NaN; dividing
    by 1 instead leaves the centered values unchanged rather than producing
    inf/NaN features.
    """
    std = t.std(dim=0, keepdim=True)
    std[~(std > 0)] = 1.0
    return std


# Per-relation inputs for the model: connectivity, z-scored edge attributes,
# and the attribute width (0 when the relation carries no attributes).
edge_index_dict = {rel: g[rel].edge_index.to(device) for rel in g.edge_types}
edge_attr_dict  = {}
edge_dim_dict   = {}
for rel in g.edge_types:
    ea = getattr(g[rel], 'edge_attr', None)
    if ea is not None and ea.numel() > 0:
        mean = ea.mean(dim=0, keepdim=True)
        edge_attr_dict[rel] = ((ea - mean) / _safe_std(ea)).to(device)
        edge_dim_dict[rel]  = ea.size(1)
    else:
        edge_attr_dict[rel] = None
        edge_dim_dict[rel]  = 0

# Node features normalization (z-score).
# NOTE: statistics are computed over all nodes, before any split is made.
x_raw = g['substation'].x.to(device)
x_mean = x_raw.mean(dim=0, keepdim=True)
x_std  = _safe_std(x_raw)
x_norm = (x_raw - x_mean) / x_std
x_dict = {'substation': x_norm}

y = g['substation'].y.to(device)
num_nodes = x_norm.size(0)

# Make SPATIAL edges undirected by appending each edge reversed, with the
# same (already normalized) attributes.
# NOTE(review): if the stored spatial edges already contain both directions,
# this doubles every edge — confirm how the source graph was built.
if ('substation','spatial','substation') in g.edge_types:
    rel = ('substation','spatial','substation')
    ei = edge_index_dict[rel]
    ea = edge_attr_dict[rel]
    rev_ei = torch.stack([ei[1], ei[0]], dim=0)
    edge_index_dict[rel] = torch.cat([ei, rev_ei], dim=1)
    if ea is not None:
        edge_attr_dict[rel] = torch.cat([ea, ea], dim=0)
    print("Spatial edges undirected ->", edge_index_dict[rel].size(1))


  g = torch.load(GRAPH_PATH).to(device)


Spatial edges undirected -> 15476


In [9]:
# Cell 9 — Splits & weights
y_np = y.detach().cpu().numpy()

# Only labeled nodes take part in the split: the labeling step marks nodes
# without a label as -1, which is not a class and would break np.bincount.
idx = np.flatnonzero(y_np >= 0)

# Stratified 70/15/15: hold out 30%, then halve the hold-out into val/test.
train_idx, tmp_idx = train_test_split(idx, test_size=0.30, stratify=y_np[idx], random_state=SEED)
val_idx, test_idx   = train_test_split(tmp_idx, test_size=0.50, stratify=y_np[tmp_idx], random_state=SEED)

train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device); train_mask[train_idx] = True
val_mask   = torch.zeros(num_nodes, dtype=torch.bool, device=device); val_mask[val_idx]   = True
test_mask  = torch.zeros(num_nodes, dtype=torch.bool, device=device); test_mask[test_idx]  = True

# Inverse-frequency class weights (largest class gets weight 1).
# minlength=2 keeps one weight per output class even if a class is absent,
# and the floor of 1 avoids dividing by zero in that case.
counts = np.bincount(y_np[idx], minlength=2)
class_weights = torch.tensor(counts.max() / np.maximum(counts, 1), dtype=torch.float32, device=device)
print("Class counts:", dict(enumerate(counts)))


Class counts: {0: 104, 1: 90}


In [10]:
# Cell 10 — Model
class HeteroGATv2Edge(nn.Module):
    """Two-layer heterogeneous GATv2 classifier for 'substation' nodes.

    Each layer is a ``HeteroConv`` holding one ``GATv2Conv`` per relation,
    combined across relations with 'mean'. Relations with a positive
    attribute width receive their edge attributes through ``edge_dim``.
    The first layer is followed by ReLU and dropout, the second by ReLU, and
    a linear head maps the 'substation' embeddings to class logits.

    Args:
        hidden: width of both convolution layers.
        out_channels: number of output classes.
        metadata: ``(node_types, edge_types)`` tuple, e.g. ``graph.metadata()``.
            Required despite the ``None`` default, which is kept only for
            signature compatibility.
        heads: attention heads per convolution (averaged, not concatenated).
        dropout: dropout rate used inside the convolutions and between layers.
        edge_dims: optional mapping from edge type to attribute width.
            Defaults to the notebook-level ``edge_dim_dict``.

    Raises:
        ValueError: if ``metadata`` is not provided.
    """

    def __init__(self, hidden=64, out_channels=2, metadata=None, heads=2, dropout=0.2,
                 edge_dims=None):
        super().__init__()
        if metadata is None:
            raise ValueError(
                "metadata=(node_types, edge_types) is required, e.g. graph.metadata()."
            )
        if edge_dims is None:
            # Fall back to the mapping built in the prep cell.
            edge_dims = edge_dim_dict

        self.dropout = dropout
        conv1_dict, conv2_dict = {}, {}
        for rel in metadata[1]:
            conv_kwargs = dict(heads=heads, concat=False,
                               add_self_loops=False, dropout=dropout)
            edim = edge_dims.get(rel, 0)
            if edim > 0:
                conv_kwargs['edge_dim'] = edim
            # Layers are created in the same interleaved order as before so a
            # fixed seed still yields the same initial weights.
            conv1_dict[rel] = GATv2Conv((-1, -1), hidden, **conv_kwargs)
            conv2_dict[rel] = GATv2Conv((-1, -1), hidden, **conv_kwargs)
        self.conv1 = HeteroConv(conv1_dict, aggr='mean')
        self.conv2 = HeteroConv(conv2_dict, aggr='mean')
        self.lin   = nn.Linear(hidden, out_channels)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        """Return class logits for every 'substation' node.

        All three arguments are dicts keyed by node type (``x_dict``) or edge
        type (the other two), as passed on to ``HeteroConv``.
        """
        x_dict = self.conv1(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        x_dict = {k: F.dropout(v, p=self.dropout, training=self.training)
                  for k, v in x_dict.items()}
        x_dict = self.conv2(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x_dict = {k: F.relu(v) for k, v in x_dict.items()}
        return self.lin(x_dict['substation'])

# Build the network from the graph's own node/edge type lists.
model = HeteroGATv2Edge(
    hidden=64,
    out_channels=2,
    metadata=g.metadata(),
    heads=2,
    dropout=0.2,
)
model = model.to(device)

opt = torch.optim.Adam(model.parameters(), lr=0.003, weight_decay=1e-5)

# mode='max' because the scheduler is stepped on validation F1 (higher is better).
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    opt, mode='max', patience=8, factor=0.5
)


In [12]:
# Cell 11 — Train
def train_step():
    """Run one full-graph optimisation step; return the training loss as a float.

    The loss is class-weighted cross-entropy with label smoothing, computed
    on the training nodes only. Gradients are clipped to norm 1.0.
    """
    model.train()
    opt.zero_grad()

    out = model(x_dict, edge_index_dict, edge_attr_dict)
    loss = F.cross_entropy(
        out[train_mask],
        y[train_mask],
        weight=class_weights,
        label_smoothing=0.05,
    )
    loss.backward()

    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    return loss.item()

@torch.no_grad()
def eval_f1(mask):
    """Return macro-F1 of argmax predictions on the nodes selected by ``mask``."""
    model.eval()
    scores = model(x_dict, edge_index_dict, edge_attr_dict)
    predicted = scores[mask].argmax(1).cpu().numpy()
    expected = y[mask].cpu().numpy()
    return f1_score(expected, predicted, average='macro')

# NOTE: `model`, `opt` and `scheduler` are created in the model cell; running
# this cell again without re-running that one continues from the already
# trained weights instead of starting from scratch.
best_f1, best_state, patience, best_epoch = -1, None, 0, 0
for epoch in range(1, 151):
    loss = train_step()
    val_f1 = eval_f1(val_mask)
    scheduler.step(val_f1)

    if val_f1 > best_f1:
        best_f1, patience, best_epoch = val_f1, 0, epoch
        # clone(): on a CPU device `.cpu()` returns the very same tensor, so
        # without a copy the snapshot would track later updates and the
        # "best" weights would end up being the last ones.
        best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
    else:
        patience += 1

    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:03d} | loss {loss:.4f} | val F1 {val_f1:.4f}")
    if patience >= 20:  # early stop
        break

# Load best and evaluate test (argmax @ 0.50). Evaluation mode and no_grad are
# set explicitly instead of relying on the state left by the last eval_f1 call.
model.load_state_dict({k: v.to(device) for k, v in best_state.items()})
model.eval()
with torch.no_grad():
    logits = model(x_dict, edge_index_dict, edge_attr_dict)
pred_test = logits[test_mask].argmax(1).cpu().numpy()
true_test = y[test_mask].cpu().numpy()
test_f1 = f1_score(true_test, pred_test, average='macro')

print(f"\nBest val F1: {best_f1:.4f} @ epoch {best_epoch} | Test F1 (0.50): {test_f1:.4f}")


Epoch 001 | loss 0.3464 | val F1 0.7238
Epoch 010 | loss 0.3109 | val F1 0.7238
Epoch 020 | loss 0.3304 | val F1 0.7238

Best val F1: 0.7238 @ epoch 1 | Test F1 (0.50): 0.6970


In [13]:
# Cell 12 — Threshold tuning
@torch.no_grad()
def logits_all():
    """Return eval-mode logits for every substation node, without gradients."""
    model.eval()
    all_scores = model(x_dict, edge_index_dict, edge_attr_dict)
    return all_scores

# Positive-class probability for every node, plus NumPy copies of the labels
# and of the validation/test masks for scikit-learn.
log = logits_all()
probs = torch.softmax(log, dim=1)[:, 1].cpu().numpy()
y_true = y.cpu().numpy()
val_mask_np = val_mask.cpu().numpy()
test_mask_np = test_mask.cpu().numpy()

# Score each candidate threshold on validation macro-F1; np.argmax keeps the
# first (lowest) threshold among ties, like a strict ">" scan would.
candidates = np.linspace(0.2, 0.8, 25)
val_scores = [
    f1_score(y_true[val_mask_np], (probs[val_mask_np] >= cut).astype(int), average='macro')
    for cut in candidates
]
best_pos = int(np.argmax(val_scores))
best_t = candidates[best_pos]
best_val = val_scores[best_pos]

# Apply the chosen threshold once to the held-out test nodes.
pred_test_thr = (probs[test_mask_np] >= best_t).astype(int)
test_f1_thr = f1_score(y_true[test_mask_np], pred_test_thr, average='macro')

print(f"Best val F1 via thresholding: {best_val:.4f} at t={best_t:.3f}")
print(f"Test F1 at tuned threshold:  {test_f1_thr:.4f}")


Best val F1 via thresholding: 0.7586 at t=0.325
Test F1 at tuned threshold:  0.7285


In [14]:
# Cell 13 — Reports
# Hard predictions at the tuned threshold for both evaluation splits.
pred_val_thr = (probs[val_mask_np] >= best_t).astype(int)
pred_test_thr = (probs[test_mask_np] >= best_t).astype(int)

# Same report layout for each split: per-class metrics, then confusion matrix.
for heading, truth, guess in (
    ("== Val @ tuned t ==", y_true[val_mask_np], pred_val_thr),
    ("\n== Test @ tuned t ==", y_true[test_mask_np], pred_test_thr),
):
    print(heading)
    print(classification_report(truth, guess, digits=4))
    print(confusion_matrix(truth, guess))


== Val @ tuned t ==
              precision    recall  f1-score   support

           0     0.8462    0.6875    0.7586        16
           1     0.6875    0.8462    0.7586        13

    accuracy                         0.7586        29
   macro avg     0.7668    0.7668    0.7586        29
weighted avg     0.7750    0.7586    0.7586        29

[[11  5]
 [ 2 11]]

== Test @ tuned t ==
              precision    recall  f1-score   support

           0     0.9000    0.5625    0.6923        16
           1     0.6500    0.9286    0.7647        14

    accuracy                         0.7333        30
   macro avg     0.7750    0.7455    0.7285        30
weighted avg     0.7833    0.7333    0.7261        30

[[ 9  7]
 [ 1 13]]


In [19]:
# Cell 14 — Save artifacts
sub = g['substation']
n_rows = len(probs)

# Fall back to synthetic names / empty IDs when the graph does not carry them.
node_names = [str(n) for n in getattr(sub, 'node_ids', [f"node_{i}" for i in range(n_rows)])]
sub_ids = getattr(sub, 'substation_id', [None] * len(node_names))


def split_tag(i):
    """Name the split node index ``i`` belongs to ("unused" if in none)."""
    for tag, mask in (("train", train_mask), ("val", val_mask), ("test", test_mask)):
        if mask[i]:
            return tag
    return "unused"


# Per-node predictions at both the default and the tuned threshold.
pred_df = pd.DataFrame({
    "node_idx": np.arange(n_rows),
    "Job Substation": node_names,
    "Substation ID": sub_ids,
    "y_true": y_true,
    "prob_1": probs,
    "pred@0.50": (probs >= 0.50).astype(int),
    f"pred@{best_t:.3f}": (probs >= best_t).astype(int),
    "split": list(map(split_tag, range(n_rows))),
})
pred_df.to_csv(PRED_OUT, index=False)

# Best weights, stored on CPU.
best_state_cpu = {name: tensor.cpu() for name, tensor in best_state.items()}
torch.save(best_state_cpu, MODEL_OUT)

# Values needed to repeat inference: threshold, node normalization
# statistics, edge attribute widths and graph metadata.
cfg = {
    "threshold": float(best_t),
    "x_mean": x_mean.detach().cpu().numpy().tolist(),
    "x_std":  x_std.detach().cpu().numpy().tolist(),
    "edge_dims": {str(rel): int(dim) for rel, dim in edge_dim_dict.items()},
    "metadata": {
        "node_types": g.node_types,
        "edge_types": [tuple(r) for r in g.edge_types],
    },
    "graph_path": GRAPH_PATH,
    "model_path": MODEL_OUT,
    "seed": SEED,
}
with open(CFG_OUT, "w") as f:
    json.dump(cfg, f, indent=2)

print("Saved:", MODEL_OUT, CFG_OUT, PRED_OUT)


Saved: hetero_gatv2_edge_best.pt inference_config.json gnn_predictions_by_node.csv


In [20]:
# Cell 15 — Ablation (no retrain; mask relations at inference)
@torch.no_grad()
def eval_with_mask(disable=frozenset()):
    """Score the trained model with some relations removed at inference time.

    Args:
        disable: collection of edge-type tuples to remove. A removed relation
            is passed as an empty edge list (and an empty attribute tensor
            when the relation has attributes). The default removes nothing.

    Returns:
        ``(val_f1, test_f1)``: macro-F1 at the tuned threshold ``best_t``.

    The model is not retrained, so this measures how much the fitted model
    relies on a relation, not what a model trained without it would achieve.
    """
    # Evaluate deterministically regardless of the mode the model was left in.
    model.eval()

    masked_ei = {}
    masked_ea = {}
    for rel in g.edge_types:
        if rel in disable:
            masked_ei[rel] = torch.empty((2, 0), dtype=edge_index_dict[rel].dtype, device=device)
            ea = edge_attr_dict[rel]
            masked_ea[rel] = None if ea is None else torch.empty((0, ea.size(1)), dtype=ea.dtype, device=device)
        else:
            masked_ei[rel] = edge_index_dict[rel]
            masked_ea[rel] = edge_attr_dict[rel]

    logits = model(x_dict, masked_ei, masked_ea)
    p = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
    pred = (p >= best_t).astype(int)
    f1v = f1_score(y_true[val_mask_np],  pred[val_mask_np],  average='macro')
    f1t = f1_score(y_true[test_mask_np], pred[test_mask_np], average='macro')
    return f1v, f1t

# Baseline (nothing removed) followed by one run per removed relation.
ablations = [
    set(),
    {('substation', 'spatial', 'substation')},
    {('substation', 'temporal', 'substation')},
    {('substation', 'causal', 'substation')},
]
for drop in ablations:
    if drop:
        tag = "drop:" + ",".join(str(r[1]) for r in drop)
    else:
        tag = "none"
    f1v, f1t = eval_with_mask(drop)
    print(f"{tag:>10} -> Val F1 {f1v:.4f} | Test F1 {f1t:.4f}")


      none -> Val F1 0.7586 | Test F1 0.7285
drop:spatial -> Val F1 0.3095 | Test F1 0.3182
drop:temporal -> Val F1 0.6836 | Test F1 0.6528
drop:causal -> Val F1 0.5894 | Test F1 0.4994


**Notes**
- Nodes: 194 substations; features aggregated per `Job Substation` from IncidentDataFinal.csv.
- Edges:
  - spatial: line-connectivity/proximity edges with attributes such as `line_length_km` and `distance_km`, made **undirected** at training time.
  - temporal: directed co-variation edges (`total_weight`, `count`).
  - causal: directed cause co-occurrence edges with `z_score`, `cooccur_count`, `cooccur_ratio`, `window_hrs`.
- Target: `Classes_2` (majority label per substation).
- Split: stratified 70/15/15 over nodes.
- Model: 2-layer Hetero GATv2 with per-relation `edge_dim`, mean aggregation, feature/edge z-scoring, label smoothing, grad clipping.
- Threshold tuning on the validation set (t = 0.325 in this run) raised test macro-F1 from 0.6970 (argmax at 0.50) to 0.7285.
- Artifacts saved for reproducibility: model weights, normalization parameters, tuned threshold, predictions CSV, and metadata.
