In [1]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
import torch.optim as optim
import json
import os
import torch
import random
from tqdm import tqdm
from torch_geometric.data import InMemoryDataset, HeteroData

### Load Data and Preprocess for GNS

In [2]:
class PFDeltaDataset(InMemoryDataset):
    def __init__(self, root_dir='data', case_name='', split='train', transform=None, pre_transform=None, pre_filter=None, force_reload=False):
        self.split = split
        self.force_reload = force_reload
        root = os.path.join(root_dir, case_name)
        super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload)
        self.load(self.processed_paths[self._split_to_idx()]) 

    def _split_to_idx(self):
        return {'train': 0, 'val': 1, 'test': 2}[self.split]

    @property
    def raw_file_names(self):
        return sorted([f for f in os.listdir(self.raw_dir) if f.endswith('.json')])

    @property
    def processed_file_names(self):
        return ['train.pt', 'val.pt', 'test.pt']

    def build_heterodata(self, pm_case):
        data = HeteroData()

        network_data = pm_case['network']
        solution_data = pm_case['solution']['solution']

        PQ_bus_x, PQ_bus_y = [], []
        PV_bus_x, PV_bus_y = [], []
        PV_demand, PV_generation = [], []
        slack_x, slack_y = [], []
        slack_demand, slack_generation = [], []
        bus_x = []

        PV_to_bus, PQ_to_bus, slack_to_bus = [], [], []
        pq_idx, pv_idx, slack_idx = 0, 0, 0
        gen_limits, gen_setpoints, more_gen_data  = [], [], []

        for bus_id_str, bus in sorted(network_data['bus'].items(), key=lambda x: int(x[0])):
            bus_id = int(bus_id_str)
            bus_idx = bus_id - 1
            bus_sol = solution_data['bus'][bus_id_str]

            # Shunts 
            gs, bs = 0.0, 0.0
            for shunt in network_data['shunt'].values():
                if int(shunt['shunt_bus']) == bus_id:
                    gs += shunt['gs']
                    bs += shunt['bs']
            bus_x.append(torch.tensor([gs, bs]))

            # Load
            pd, qd = 0.0, 0.0
            for load in network_data['load'].values():
                if int(load['load_bus']) == bus_id:
                    pd += load['pd']
                    qd += load['qd']

            # Gen
            pg, qg = 0.0, 0.0
            for gen_id, gen in sorted(network_data['gen'].items(), key=lambda x: int(x[0])):
                if int(gen['gen_bus']) == bus_id: 
                    if gen['gen_status'] == 1:
                        gen_sol = solution_data['gen'][gen_id]
                        pg += gen_sol['pg']
                        qg += gen_sol['qg']
                    else:
                        assert solution_data['gen'].get(gen_id) is None, f"Expected gen {gen_id} to be off."

            # Node features
            va, vm = bus_sol['va'], bus_sol['vm']
            if bus['bus_type'] == 1:
                PQ_bus_x.append(torch.tensor([pd, qd]))
                PQ_bus_y.append(torch.tensor([va, vm]))
                PQ_to_bus.append(torch.tensor([pq_idx, bus_idx]))
                pq_idx += 1
            elif bus['bus_type'] == 2:
                PV_bus_x.append(torch.tensor([pg - pd, vm]))
                PV_bus_y.append(torch.tensor([qg - qd, va]))
                PV_demand.append(torch.tensor([pd, qd]))
                PV_generation.append(torch.tensor([pg, qg]))
                PV_to_bus.append(torch.tensor([pv_idx, bus_idx]))
                pv_idx += 1
            elif bus['bus_type'] == 3:
                slack_x.append(torch.tensor([va, vm]))
                slack_y.append(torch.tensor([pg - pd, qg - qd]))
                slack_demand.append(torch.tensor([pd, qd]))
                slack_generation.append(torch.tensor([pg, qg]))
                slack_to_bus.append(torch.tensor([slack_idx, bus_idx]))
                slack_idx += 1
        
        # Generator
        for gen_id, gen in sorted(network_data['gen'].items(), key=lambda x: int(x[0])):
            if gen['gen_status'] == 1:
                gen_sol = solution_data['gen'][gen_id]
                pmin, pmax, qmin, qmax = gen['pmin'], gen['pmax'], gen['qmin'], gen['qmax']
                pgen, qgen = gen_sol['pg'], gen_sol['qg']
                gen_limits.append(torch.tensor([pmin, pmax, qmin, qmax]))
                gen_setpoints.append(torch.tensor([pgen, qgen]))
                is_slack = torch.tensor(
                        1 if network_data['bus'][str(gen['gen_bus'])]['bus_type'] == 3 else 0,
                        dtype=torch.bool
                                )
                gen_bus = torch.tensor(gen['gen_bus']) - 1  # zero-indexed
                more_gen_data.append(torch.stack([gen_bus, is_slack]))
            else:
                assert solution_data['gen'].get(gen_id) is None, f"Expected gen {gen_id} to be off."

        # Edges
        edge_index, edge_attr, edge_label = [], [], []
        for branch_id_str, branch in sorted(network_data['branch'].items(), key=lambda x: int(x[0])):
            if branch['br_status'] == 0:
                continue  # Skip inactive branches

            from_bus = int(branch['f_bus']) - 1 
            to_bus = int(branch['t_bus']) - 1
            edge_index.append(torch.tensor([from_bus, to_bus]))
            edge_attr.append(torch.tensor([
                branch['br_r'], branch['br_x'],
                branch['g_fr'], branch['b_fr'],
                branch['g_to'], branch['b_to'], 
                branch['tap'],  branch['shift']
            ]))

            branch_sol = solution_data['branch'].get(branch_id_str)
            assert branch_sol is not None, f"Missing solution for active branch {branch_id_str}"

            if branch_sol:
                edge_label.append(torch.tensor([
                    branch_sol['pf'], branch_sol['qf'],
                    branch_sol['pt'], branch_sol['qt']
                ]))

        # Create graph nodes and edges
        data['PQ'].x = torch.stack(PQ_bus_x) 
        data['PQ'].y = torch.stack(PQ_bus_y) 

        data['PV'].x = torch.stack(PV_bus_x) 
        data['PV'].y = torch.stack(PV_bus_y) 
        data['PV'].generation = torch.stack(PV_generation) 
        data['PV'].demand = torch.stack(PV_demand) 

        data['slack'].x = torch.stack(slack_x) 
        data['slack'].y = torch.stack(slack_y) 
        data['slack'].generation = torch.stack(slack_generation) 
        data['slack'].demand = torch.stack(slack_demand) 

        data['bus'].x = torch.stack(bus_x)

        data['gen'].limits = torch.stack(gen_limits)
        data['gen'].setpoints = torch.stack(gen_setpoints)
        data['gen'].more_gen_data = torch.stack(more_gen_data)

        data['bus', 'branch', 'bus'].edge_index = torch.stack(edge_index, dim=1) 
        data['bus', 'branch', 'bus'].edge_attr = torch.stack(edge_attr) 
        data['bus', 'branch', 'bus'].edge_label = torch.stack(edge_label) 

        for link_name, edges in {
            ('PV', 'PV_link', 'bus'): PV_to_bus,
            ('PQ', 'PQ_link', 'bus'): PQ_to_bus,
            ('slack', 'slack_link', 'bus'): slack_to_bus
        }.items():
            edge_tensor = torch.stack(edges, dim=1) 
            data[link_name].edge_index = edge_tensor
            data[(link_name[2], link_name[1], link_name[0])].edge_index = edge_tensor.flip(0)

        return data

    def process(self):
        fnames = self.raw_file_names
        random.shuffle(fnames)
        n = len(fnames)

        split_dict = {
            'train': fnames[:int(0.8 * n)],
            'val': fnames[int(0.8 * n): int(0.9 * n)],
            'test': fnames[int(0.9 * n):]
        }

        for split, files in split_dict.items():
            data_list = []
            print(f"Processing split: {split} ({len(files)} files)")
            for fname in tqdm(files, desc=f"Building {split} data"):
                with open(os.path.join(self.raw_dir, fname)) as f:
                    pm_case = json.load(f)
                data = self.build_heterodata(pm_case)
                data_list.append(data)

            data, slices = self.collate(data_list)
            torch.save((data, slices), os.path.join(self.processed_dir, f'{split}.pt'))

In [7]:
class PFDeltaGNS(PFDeltaDataset): 
    def __init__(self, root_dir='data', case_name='', split='train', transform=None, pre_transform=None, pre_filter=None, force_reload=False):
        super().__init__(root_dir, case_name, split, transform, pre_transform, pre_filter, force_reload)

    def build_heterodata(self, pm_case):
        """ """
        # call base version
        data = super().build_heterodata(pm_case)
        num_buses = data['bus'].x.size(0)

        # Init bus-level fields
        v_buses      = torch.zeros(num_buses)
        theta_buses  = torch.zeros(num_buses)
        pd_buses     = torch.zeros(num_buses)
        qd_buses     = torch.zeros(num_buses)
        pg_buses     = torch.zeros(num_buses)
        qg_buses     = torch.zeros(num_buses)

        for ntype in ['PQ', 'PV', 'slack']:
            node = data[ntype]
            num_nodes = node.x.size(0)

            if ntype == 'PQ':
                # Flat start init
                node.x_gns = torch.tensor([1.0, 0.0]).repeat(num_nodes, 1)
                pd = node.x[:, 0]
                qd = node.x[:, 1]
                pg = torch.zeros_like(pd)
                qg = torch.zeros_like(qd)

            elif ntype == 'PV':
                v = node.x[:, 1:2]
                theta = torch.zeros_like(v)
                node.x_gns = torch.cat([v, theta], dim=1)
                pd = node.demand[:, 0]
                qd = node.demand[:, 1]
                pg = node.generation[:, 0]
                qg = node.generation[:, 1]

            elif ntype == 'slack':
                node.x_gns = node.x.clone()
                pd = node.demand[:, 0]
                qd = node.demand[:, 1]
                pg = node.generation[:, 0]
                qg = node.generation[:, 1]

            # Map to buses
            link_type = ntype + '_link'
            edge_index = data[(ntype, link_type, 'bus')].edge_index
            src, dst = edge_index  # src is local node index, dst is bus index

            x_gns = node.x_gns
            v_dst = x_gns[:, 0]
            theta_dst = x_gns[:, 1]

            v_buses.index_add_(0, dst, v_dst)
            theta_buses.index_add_(0, dst, theta_dst)
            pd_buses.index_add_(0, dst, pd)
            qd_buses.index_add_(0, dst, qd)
            pg_buses.index_add_(0, dst, pg)
            qg_buses.index_add_(0, dst, qg)

        # Store in bus
        data['bus'].v = v_buses
        data['bus'].theta = theta_buses
        data['bus'].pd = pd_buses
        data['bus'].qd = qd_buses
        data['bus'].pg = pg_buses
        data['bus'].qg = qg_buses
        data['bus'].delta_p = torch.zeros_like(v_buses)
        data['bus'].delta_q = torch.zeros_like(v_buses)

        return data


In [8]:
case_14_data = PFDeltaGNS(root_dir='data/gns_data', case_name='case14')

Processing...


Processing split: train (8091 files)


Building train data: 100%|██████████| 8091/8091 [00:49<00:00, 163.12it/s]


Processing split: val (1011 files)


Building val data: 100%|██████████| 1011/1011 [00:06<00:00, 147.02it/s]


Processing split: test (1012 files)


Building test data: 100%|██████████| 1012/1012 [00:06<00:00, 156.46it/s]
Done!


In [18]:
case_14_data[10]['bus']

{'x': tensor([[0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.1900],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000],
        [0.0000, 0.0000]]), 'v': tensor([0.0000, 1.0461, 0.9502, 1.0000, 1.0000, 1.0600, 1.0000, 1.0600, 1.0000,
        1.0000, 1.0000, 1.0000, 1.0000, 1.0000]), 'theta': tensor([1.0600, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000]), 'pd': tensor([0.0000, 0.2017, 0.9490, 0.4597, 0.0744, 0.1209, 0.0000, 0.0000, 0.2653,
        0.1063, 0.0337, 0.0683, 0.1415, 0.1516]), 'qd': tensor([0.0000, 0.0650, 0.4116, 0.1102, 0.0380, 0.0814, 0.0000, 0.0000, 0.1732,
        0.0652, 0.0221, 0.0325, 0.1024, 0.0360]), 'pg': tensor([2.1193, 0.5900, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.00

In [53]:
case_14_data[0]["PV", "PV_link", "bus"]

{'edge_index': tensor([[0, 1, 2, 3],
        [1, 2, 5, 7]])}

In [56]:
case_14_train = torch.load('data/gns_data/case14/processed/train.pt')
case_14_val = torch.load('data/gns_data/case14/processed/val.pt')
case_14_test = torch.load('data/gns_data/case14/processed/test.pt')

In [57]:
case_14_train[0]

HeteroData(
  PQ={
    x=[72819, 2],
    y=[72819, 2],
    x_gns=[72819, 2],
  },
  PV={
    x=[32364, 2],
    y=[32364, 2],
    generation=[32364, 2],
    demand=[32364, 2],
    x_gns=[32364, 2],
  },
  slack={
    x=[8091, 2],
    y=[8091, 2],
    generation=[8091, 2],
    demand=[8091, 2],
    x_gns=[8091, 2],
  },
  bus={
    x=[113274, 2],
    v=[113274],
    theta=[113274],
  },
  gen={
    limits=[517748, 4],
    setpoints=[517748, 2],
  },
  (bus, branch, bus)={
    edge_index=[2, 157202],
    edge_attr=[157202, 8],
    edge_label=[157202, 4],
  },
  (PV, PV_link, bus)={ edge_index=[2, 32364] },
  (bus, PV_link, PV)={ edge_index=[2, 32364] },
  (PQ, PQ_link, bus)={ edge_index=[2, 72819] },
  (bus, PQ_link, PQ)={ edge_index=[2, 72819] },
  (slack, slack_link, bus)={ edge_index=[2, 8091] },
  (bus, slack_link, slack)={ edge_index=[2, 8091] }
)

### Graph Neural Solver Architecture

In [None]:
class GraphNeuralSolver(nn.Module):
    def __init__(self, K, hidden_dim, L_input_dim, phi_input_dim, gamma, data):
        super().__init__()
        self.K = K
        self.gamma = gamma
        self.phi = nn.Sequential(
            nn.Linear(phi_input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.layers = nn.ModuleList(
            NNUpdate(L_input_dim, hidden_dim) for _ in range(K)
        )
        
        # Instantiate message vectors for each bus
        num_nodes = data['bus'].x.size(0)
        data['bus'].m = torch.zeros((num_nodes, hidden_dim))


    def forward(self, data):  
        """ """
        total_layer_loss = 0

        for k in range(self.K): 

            # Update P and Q values for all buses
            self.global_active_compensation(data) 

            # Compute local power imbalance variables and store power imbalance loss 
            layer_loss = self.local_power_imbalance(data)
            total_layer_loss += layer_loss * torch.pow(self.gamma, self.K - k)

            # Apply the neural network update block 
            self.apply_nn_update(data, k)
        
        return data, total_layer_loss


    def global_active_compensation(self, data):
        """ """
        # Compute global power demand 
        p_joule = self.compute_p_joule(data)   
        p_global = self.compute_p_global(data, p_joule)

        # Compute pg_slack and assign new pg value
        pg_slack = self.compute_pg_slack(p_global, data)
        data['bus'].pg[0] = pg_slack
        # TODO: check if the remaining buses need to be kept the same 

        # Compute qg values for each bus
        self.compute_qg(data)
        

    def compute_p_joule(self, data): 
        """ """
        # Extract edge index and attributes
        edge_index = data[('bus', 'branch', 'bus')].edge_index
        edge_attr = data[('bus', 'branch', 'bus')].edge_attr 
        src, dst = edge_index 

        # Edge features
        tau_ij = edge_attr[:, 6]
        shift_ij = edge_attr[: 7]

        # Line admittance features 
        br_r = edge_attr[:, 0]
        br_x = edge_attr[:, 1]
        y = 1 / (torch.complex(br_r, br_x))
        y_ij = torch.abs(y)
        delta_ij = torch.angle(y)

        # Node features
        v_i = data['bus'].v[src]
        v_j = data['bus'].v[dst]
        theta_i = data['bus'].theta[src]
        theta_j = data['bus'].theta[dst]

        # Compute p_global
        term1 = v_i * v_j * y_ij / tau_ij * (
            torch.sin(theta_i - theta_j - delta_ij - shift_ij) +
            torch.sin(theta_j - theta_i - delta_ij + shift_ij)
        )

        term2 = (v_i / tau_ij) ** 2 * y_ij * torch.sin(delta_ij)
        term3 = v_j ** 2 * y_ij * torch.sin(delta_ij)

        p_joule = torch.abs(term1 + term2 + term3).sum()

        return p_joule
    
    def compute_p_global(self, data, p_joule): 
        """ """
        # Extract relevant variables for computation 
        pd = data['bus'].pd 
        v = data['bus'].v
        g_shunt = data['bus'].x[:, 0]

        # Compute p_global
        p_global = (pd + (v ** 2) * g_shunt).sum() + p_joule
        
        return p_global

    def compute_pg_slack(self, p_global, data):
        """ """
        pg_setpoints = data['gen'].setpoints[:, 0] 
        pg_max_vals = data['gen'].limits[:, 1]
        pg_min_vals = data['gen'].limits[:, 0]
        is_slack = data['gen'].more_gen_data[:, 1] == 1 
        pg_setpoints_non_slack = pg_setpoints[~is_slack]
        pg_setpoint_slack = pg_setpoints[is_slack]
        pg_max_slack = pg_max_vals[is_slack] 
        pg_min_slack = pg_min_vals[is_slack]  
        
        if p_global < pg_setpoints.sum(): 
            lamb = (p_global - pg_setpoints_non_slack.sum() - pg_max_slack)
            lamb = lamb / 2*(pg_setpoint_slack - pg_min_slack)
        else: 
            lamb = (p_global - pg_setpoints_non_slack.sum() - 2*pg_setpoint_slack - pg_max_slack)
            lamb = lamb / 2*(pg_max_slack - pg_setpoint_slack)
        
        lamb = torch.max(0, lamb)
        if lamb < 1/2: 
            pg_slack = pg_min_slack + 2 * (pg_setpoint_slack - pg_min_slack) * lamb
        else: 
            pg_slack = 2 * pg_setpoint_slack - pg_max_slack + 2 * (pg_max_slack - pg_setpoint_slack) * lamb
        
        return pg_slack
    
    def compute_qg(self, data): 
        pass 

    def local_power_imbalance(self, data):
        # compute delta P
        # compute delta Q
        # compute delta S -- store these values in the data instance and scale by gamma factor
        pass
    

    def apply_nn_update(self, data, k): 
        # apply phi to the message vector at each node and on the line values 
        # apply the k-th layer NN udpate 
        self.layers[k] # pass inputs in
        pass


class NNUpdate(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.L_theta = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.L_v = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.L_m = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, data):
        # apply V, theta, and m updates within the data itself
        pass



def kirchoff_law_violation_loss(layer_loss, data):
    # get loss per batch, avergae across samples and then backprop
    pass