# Binary Substation Classification (Cust_Class) with Hetero GNN
We predict **Cust_Class** (0/1) at the substation level using the existing heterogeneous graph.
Labels are attached by majority vote per *Job Substation* (rows with labels not in {0,1} are dropped).
Training uses a robust stratified split and class‑balanced focal loss.


In [1]:

# Setup & Paths
import os, json, random, math
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torch import nn

from collections import Counter
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, classification_report, confusion_matrix

from torch_geometric.nn import HeteroConv, GATv2Conv

SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device:', device)

DATA_PATH = 'IncidentDataFinal.csv'
GRAPH_IN  = 'Hetero_Final_NW_graph_fixed_kara.pt'      # unlabeled graph you already built
GRAPH_OUT = 'Hetero_graph_kara_CustClass_labeled.pt'   # will save labeled (0/1 only)


Device: cuda


In [2]:

# Load incidents (CSV) and base graph
_cols = pd.read_csv(DATA_PATH, nrows=0).columns
_date_cols = [c for c in ['Job OFF Time','Job ON Time'] if c in _cols]
inc = pd.read_csv(DATA_PATH, parse_dates=_date_cols)

g = torch.load(GRAPH_IN)
print('Incidents:', inc.shape)
print(g)


Incidents: (264458, 31)
HeteroData(
  substation={
    x=[194, 18],
    node_ids=[194],
    substation_id=[194],
  },
  (substation, spatial, substation)={
    edge_index=[2, 7738],
    edge_attr=[7738, 8],
  },
  (substation, temporal, substation)={
    edge_index=[2, 19134],
    edge_attr=[19134, 2],
  },
  (substation, causal, substation)={
    edge_index=[2, 7192],
    edge_attr=[7192, 4],
  }
)


  g = torch.load(GRAPH_IN)


In [3]:

# Attach Cust_Class (keep only 0/1), majority per substation, save labeled graph
def attach_binary_label(graph, incident_df, target_col='Cust_Class', save_path=None):
    g = graph
    sub = g['substation']

    df = incident_df.copy()
    if target_col not in df.columns:
        raise ValueError(f"Missing '{target_col}' in incidents.")
    df['Job Substation'] = df['Job Substation'].astype(str).str.strip().str.upper()
    df[target_col] = pd.to_numeric(df[target_col], errors='coerce')

    # keep only labels 0/1
    df = df[df[target_col].isin([0,1])]

    labels_by_name = (
        df.groupby('Job Substation')[target_col]
          .apply(lambda s: s.dropna().astype(int).mode().iloc[0] if s.dropna().size else np.nan)
    )

    node_names = [str(n).strip().upper() for n in getattr(sub, 'node_ids', [])]
    y_list, mask_list = [], []
    for name in node_names:
        v = labels_by_name.get(name, np.nan)
        if pd.isna(v): y_list.append(-1); mask_list.append(False)
        else:          y_list.append(int(v)); mask_list.append(True)

    sub.y = torch.tensor(y_list, dtype=torch.long)
    sub.train_mask = torch.tensor(mask_list, dtype=torch.bool)

    labeled = int(sub.train_mask.sum().item())
    if labeled:
        yy = np.array(y_list)[np.array(mask_list)]
        vals, cnts = np.unique(yy, return_counts=True)
        print('Labeled counts (0/1):', dict(zip(vals.tolist(), cnts.tolist())))
    print(f"Labeled (0/1) nodes: {labeled}/{len(node_names)}")

    if save_path:
        torch.save(g, save_path); print('Saved ->', save_path)
    return g

g = attach_binary_label(g, inc, 'Cust_Class', save_path=GRAPH_OUT)


Labeled counts (0/1): {0: 194}
Labeled (0/1) nodes: 194/194
Saved -> Hetero_graph_kara_CustClass_labeled.pt


In [4]:

# Build tensors, normalize, make spatial undirected
g = torch.load(GRAPH_OUT).to(device)

# Edge dicts + per-relation z-score normalization
edge_index_dict = {rel: g[rel].edge_index.to(device) for rel in g.edge_types}
edge_attr_dict, edge_dim_dict = {}, {}
for rel in g.edge_types:
    ea = getattr(g[rel], 'edge_attr', None)
    if ea is not None and ea.numel() > 0:
        m, s = ea.mean(0, keepdim=True), ea.std(0, keepdim=True)
        s[s==0] = 1.0
        edge_attr_dict[rel] = ((ea - m) / s).to(device)
        edge_dim_dict[rel]  = ea.size(1)
    else:
        edge_attr_dict[rel] = None
        edge_dim_dict[rel]  = 0

# Node features z-score
x_raw = g['substation'].x.to(device)
x_mean, x_std = x_raw.mean(0, keepdim=True), x_raw.std(0, keepdim=True)
x_std[x_std==0] = 1.0
x_dict = {'substation': (x_raw - x_mean) / x_std}

# Labels
y = g['substation'].y.to(device)
num_nodes = x_dict['substation'].size(0)

# Make spatial undirected; keep temporal/causal directed
if ('substation','spatial','substation') in g.edge_types:
    rel = ('substation','spatial','substation')
    ei = edge_index_dict[rel]; ea = edge_attr_dict[rel]
    rev_ei = torch.stack([ei[1], ei[0]], dim=0)
    edge_index_dict[rel] = torch.cat([ei, rev_ei], dim=1)
    if ea is not None:
        edge_attr_dict[rel] = torch.cat([ea, ea], dim=0)
    print('Spatial edges (undirected):', edge_index_dict[rel].size(1))


  g = torch.load(GRAPH_OUT).to(device)


Spatial edges (undirected): 15476


In [5]:

# Robust stratified split with per-class minimums
def split_counts(mask, name):
    yy = y[mask].cpu().numpy()
    c  = dict(Counter(yy))
    print(f"{name:5} size={int(mask.sum().item()):3d} | class counts:", c)

def _check_feasibility(y_all, labeled_idx, train_frac, val_frac, test_frac, min_per_class):
    counts = Counter(y_all[labeled_idx])
    p_rest = 1.0 - train_frac
    p_val  = val_frac / (val_frac + test_frac)
    p_test = test_frac / (val_frac + test_frac)
    need_val  = {c: math.ceil((min_per_class)/(p_rest*p_val)  + 1e-9) for c in counts}
    need_test = {c: math.ceil((min_per_class)/(p_rest*p_test) + 1e-9) for c in counts}
    infeasible = {c: (counts[c], max(need_val[c], need_test[c])) for c in counts if counts[c] < max(need_val[c], need_test[c])}
    return counts, infeasible

def stratified_split_with_min(y_all, labeled_idx, train_frac=0.8, val_frac=0.1, test_frac=0.1,
                              min_per_class=2, base_seed=42, max_tries=500):
    assert abs(train_frac + val_frac + test_frac - 1.0) < 1e-6
    counts, infeasible = _check_feasibility(y_all, labeled_idx, train_frac, val_frac, test_frac, min_per_class)
    if infeasible:
        raise RuntimeError('Split infeasible for some classes; reduce min_per_class or adjust fractions.')

    y_lab = y_all[labeled_idx]
    for t in range(max_tries):
        seed = base_seed + t
        tr_idx, rest_idx = train_test_split(
            labeled_idx, test_size=(1 - train_frac), stratify=y_lab, random_state=seed
        )
        pos_map = {nid:i for i, nid in enumerate(labeled_idx)}
        y_rest  = np.array([y_lab[pos_map[nid]] for nid in rest_idx])

        val_idx, test_idx = train_test_split(
            rest_idx, test_size=test_frac/(val_frac+test_frac),
            stratify=y_rest, random_state=seed
        )
        ok_val  = all(v >= min_per_class for v in Counter(y_all[val_idx]).values())
        ok_test = all(v >= min_per_class for v in Counter(y_all[test_idx]).values())
        if ok_val and ok_test:
            return tr_idx, val_idx, test_idx, seed
    raise RuntimeError('Could not find a split meeting per-class minimums.')

# Labeled nodes only
labeled_idx  = torch.where(y >= 0)[0].cpu().numpy()
y_np         = y.cpu().numpy()

counts, infeasible = _check_feasibility(y_np, labeled_idx, 0.8, 0.1, 0.1, min_per_class=2)
print('Labeled per-class counts:', dict(counts))

train_idx_l, val_idx_l, test_idx_l, used_seed = stratified_split_with_min(
    y_np, labeled_idx, train_frac=0.8, val_frac=0.1, test_frac=0.1, min_per_class=2, base_seed=SEED
)

train_mask = torch.zeros(num_nodes, dtype=torch.bool, device=device); train_mask[torch.tensor(train_idx_l, device=device)] = True
val_mask   = torch.zeros(num_nodes, dtype=torch.bool, device=device); val_mask[torch.tensor(val_idx_l, device=device)]   = True
test_mask  = torch.zeros(num_nodes, dtype=torch.bool, device=device); test_mask[torch.tensor(test_idx_l, device=device)]  = True

print('Used split seed:', used_seed)
split_counts(train_mask, 'train')
split_counts(val_mask,   'val')
split_counts(test_mask,  'test')


Labeled per-class counts: {0: 194}
Used split seed: 42
train size=155 | class counts: {0: 155}
val   size= 19 | class counts: {0: 19}
test  size= 20 | class counts: {0: 20}


In [6]:

# Class-balanced focal for binary
def class_balanced_alpha(labels, num_classes=2, beta=0.9999):
    counts = np.bincount(labels, minlength=num_classes).astype(np.float32)
    effective_num = 1.0 - np.power(beta, counts)
    weights = (1.0 - beta) / np.maximum(effective_num, 1e-8)
    weights = weights / weights.sum() * num_classes
    return torch.tensor(weights, dtype=torch.float32, device=device), counts

ALPHA_CB, cls_counts = class_balanced_alpha(y_np[labeled_idx], num_classes=2, beta=0.9999)
print('Class-balanced alpha:', ALPHA_CB.detach().cpu().numpy(), '| counts:', cls_counts)

def focal_ce_loss(logits, targets, alpha=None, gamma=2.0):
    logp = F.log_softmax(logits, dim=1)
    p    = torch.exp(logp)
    ce   = F.nll_loss(logp, targets, reduction='none')
    pt   = p[torch.arange(p.size(0), device=logits.device), targets]
    loss = ((1 - pt) ** gamma) * ce
    if alpha is not None:
        loss = alpha[targets] * loss
    return loss.mean()


Class-balanced alpha: [1.0407374e-06 1.9999990e+00] | counts: [194.   0.]


In [7]:

# Model (GATv2 + edge_dim), binary (2 logits)
class HeteroGATv2Edge(nn.Module):
    def __init__(self, hidden=128, out_channels=2, metadata=None, heads=4, dropout=0.30):
        super().__init__()
        self.dropout = dropout
        conv1_dict, conv2_dict = {}, {}
        for rel in metadata[1]:
            edim = edge_dim_dict.get(rel, 0)
            if edim > 0:
                conv1 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  edge_dim=edim, add_self_loops=False, dropout=dropout)
                conv2 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  edge_dim=edim, add_self_loops=False, dropout=dropout)
            else:
                conv1 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  add_self_loops=False, dropout=dropout)
                conv2 = GATv2Conv((-1,-1), hidden, heads=heads, concat=False,
                                  add_self_loops=False, dropout=dropout)
            conv1_dict[rel] = conv1; conv2_dict[rel] = conv2
        self.conv1 = HeteroConv(conv1_dict, aggr='mean')
        self.conv2 = HeteroConv(conv2_dict, aggr='mean')
        self.bn1   = nn.BatchNorm1d(hidden)
        self.bn2   = nn.BatchNorm1d(hidden)
        self.lin   = nn.Linear(hidden, out_channels)

    def forward(self, x_dict, edge_index_dict, edge_attr_dict):
        x_dict = self.conv1(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x = self.bn1(x_dict['substation'])
        x = torch.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x_dict = {'substation': x}
        x_dict = self.conv2(x_dict, edge_index_dict, edge_attr_dict=edge_attr_dict)
        x = self.bn2(x_dict['substation'])
        x = torch.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        return self.lin(x)

model = HeteroGATv2Edge(metadata=g.metadata()).to(device)
opt = torch.optim.Adam(model.parameters(), lr=0.0015, weight_decay=2e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='max', patience=10, factor=0.5)


In [8]:

# Train + early stopping on macro-F1 (val)
def train_step():
    model.train()
    logits = model(x_dict, edge_index_dict, edge_attr_dict)
    loss = focal_ce_loss(logits[train_mask], y[train_mask], alpha=ALPHA_CB, gamma=2.0)
    opt.zero_grad(); loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    opt.step()
    return loss.item()

@torch.no_grad()
def eval_f1(mask):
    model.eval()
    logits = model(x_dict, edge_index_dict, edge_attr_dict)
    pred = logits[mask].argmax(1).cpu().numpy()
    true = y[mask].cpu().numpy()
    return f1_score(true, pred, average='macro')

best_f1, best_state, patience, best_epoch = -1, None, 0, 0
for epoch in range(1, 201):
    loss = train_step()
    val_f1 = eval_f1(val_mask)
    scheduler.step(val_f1)

    if val_f1 > best_f1:
        best_f1, best_state, patience, best_epoch = val_f1, {k:v.cpu() for k,v in model.state_dict().items()}, 0, epoch
    else:
        patience += 1

    if epoch % 10 == 0 or epoch == 1:
        print(f"Epoch {epoch:03d} | loss {loss:.4f} | val F1 {val_f1:.4f}")
    if patience >= 30:
        break

# Final test
model.load_state_dict({k:v.to(device) for k,v in best_state.items()})
@torch.no_grad()
def logits_all():
    model.eval()
    return model(x_dict, edge_index_dict, edge_attr_dict)

log = logits_all()
pred_test = log[test_mask].argmax(1).cpu().numpy()
true_test = y[test_mask].cpu().numpy()
test_f1 = f1_score(true_test, pred_test, average='macro')
print(f"\\nBest val F1: {best_f1:.4f} @ epoch {best_epoch} | Test F1: {test_f1:.4f}")


Epoch 001 | loss 0.0000 | val F1 0.0000
Epoch 010 | loss 0.0000 | val F1 0.0000
Epoch 020 | loss 0.0000 | val F1 0.0000
Epoch 030 | loss 0.0000 | val F1 0.0000
\nBest val F1: 0.0000 @ epoch 1 | Test F1: 0.0000


In [None]:

# Reports + optional threshold tuning for binary
from sklearn.metrics import roc_auc_score

pred_val = log[val_mask].argmax(1).cpu().numpy()
true_val = y[val_mask].cpu().numpy()

print('== Val (argmax) ==')
print(classification_report(true_val, pred_val, digits=4))
print(confusion_matrix(true_val, pred_val))

print('\\n== Test (argmax) ==')
print(classification_report(true_test, pred_test, digits=4))
print(confusion_matrix(true_test, pred_test))

# Optional: threshold on validation to maximize macro-F1
probs = torch.softmax(log, dim=1)[:,1].cpu().numpy()
val_mask_np = val_mask.cpu().numpy(); test_mask_np = test_mask.cpu().numpy()

best_t, best_val_f1 = 0.5, -1.0
for t in np.linspace(0.2, 0.8, 25):
    pv = (probs[val_mask_np] >= t).astype(int)
    f1v = f1_score(true_val, pv, average='macro')
    if f1v > best_val_f1:
        best_val_f1, best_t = f1v, t

pt = (probs[test_mask_np] >= best_t).astype(int)
f1t = f1_score(true_test, pt, average='macro')
print(f"\\nBest val F1 via thresholding: {best_val_f1:.4f} at t={best_t:.3f}")
print(f'Test F1 at tuned threshold:   {f1t:.4f}')
