In [1]:
import torch
import torch.nn as nn
from torch_geometric.loader import DataLoader
import torch.optim as optim
import json
import os
import torch
import random
from tqdm import tqdm
from torch_geometric.data import InMemoryDataset, HeteroData

### Load Data and Preprocess for GNS

In [None]:
class PFDeltaDataset(InMemoryDataset):
    def __init__(self, root_dir='data', case_name='', split='train', transform=None, pre_transform=None, pre_filter=None, force_reload=False):
        self.split = split
        self.force_reload = force_reload
        root = os.path.join(root_dir, case_name)
        super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload)
        self.load(self.processed_paths[self._split_to_idx()])

    def _split_to_idx(self):
        return {'train': 0, 'val': 1, 'test': 2}[self.split]

    @property
    def raw_file_names(self):
        return sorted([f for f in os.listdir(self.raw_dir) if f.endswith('.json')])

    @property
    def processed_file_names(self):
        return ['train.pt', 'val.pt', 'test.pt']

    def build_heterodata(self, pm_case):
        data = HeteroData()

        network_data = pm_case['network']
        solution_data = pm_case['solution']['solution']

        PQ_bus_x, PQ_bus_y = [], []
        PV_bus_x, PV_bus_y = [], []
        PV_demand, PV_generation = [], []
        slack_x, slack_y = [], []
        slack_demand, slack_generation = [], []
        bus_x = []

        PV_to_bus, PQ_to_bus, slack_to_bus = [], [], []
        pq_idx, pv_idx, slack_idx = 0, 0, 0

        for bus_id_str, bus in sorted(network_data['bus'].items(), key=lambda x: int(x[0])):
            bus_id = int(bus_id_str)
            bus_idx = bus_id - 1
            bus_sol = solution_data['bus'][bus_id_str]

            # Shunts
            gs, bs = 0.0, 0.0
            for shunt in network_data['shunt'].values():
                if int(shunt['shunt_bus']) == bus_id:
                    gs += shunt['gs']
                    bs += shunt['bs']
            bus_x.append(torch.tensor([gs, bs]))

            # Load
            pd, qd = 0.0, 0.0
            for load in network_data['load'].values():
                if int(load['load_bus']) == bus_id:
                    pd += load['pd']
                    qd += load['qd']

            # Gen
            pg, qg = 0.0, 0.0
            for gen_id, gen in network_data['gen'].items():
                if int(gen['gen_bus']) == bus_id:
                    gen_sol = solution_data['gen'].get(gen_id)
                    if gen_sol:
                        pg += gen_sol['pg']
                        qg += gen_sol['qg']
                    else:
                        assert gen['gen_status'] == 0, f"Expected gen {gen_id} to be off."

            # Node features
            va, vm = bus_sol['va'], bus_sol['vm']
            if bus['bus_type'] == 1:
                PQ_bus_x.append(torch.tensor([pd, qd]))
                PQ_bus_y.append(torch.tensor([va, vm]))
                PQ_to_bus.append(torch.tensor([pq_idx, bus_idx]))
                pq_idx += 1
            elif bus['bus_type'] == 2:
                PV_bus_x.append(torch.tensor([pg - pd, vm]))
                PV_bus_y.append(torch.tensor([qg - qd, va]))
                PV_demand.append(torch.tensor([pd, qd]))
                PV_generation.append(torch.tensor([pg, qg]))
                PV_to_bus.append(torch.tensor([pv_idx, bus_idx]))
                pv_idx += 1
            elif bus['bus_type'] == 3:
                slack_x.append(torch.tensor([va, vm]))
                slack_y.append(torch.tensor([pg - pd, qg - qd]))
                slack_demand.append(torch.tensor([pd, qd]))
                slack_generation.append(torch.tensor([pg, qg]))
                slack_to_bus.append(torch.tensor([slack_idx, bus_idx]))
                slack_idx += 1

        # Edges
        edge_index, edge_attr, edge_label = [], [], []
        for branch_id_str, branch in sorted(network_data['branch'].items(), key=lambda x: int(x[0])):
            from_bus = int(branch['f_bus']) - 1
            to_bus = int(branch['t_bus']) - 1
            edge_index.append(torch.tensor([from_bus, to_bus]))
            edge_attr.append(torch.tensor([
                branch['br_r'], branch['br_x'],
                branch['g_fr'], branch['b_fr'],
                branch['g_to'], branch['b_to'],
                branch['tap'],  branch['shift']
            ]))

            branch_sol = solution_data['branch'].get(branch_id_str)
            if branch_sol:
                edge_label.append(torch.tensor([
                    branch_sol['pf'], branch_sol['qf'],
                    branch_sol['pt'], branch_sol['qt']
                ]))
            else:
                assert branch['br_status'] == 0, f"Expected branch {branch_id_str} to be outaged."

        # Create graph nodes and edges
        data['PQ'].x = torch.stack(PQ_bus_x)
        data['PQ'].y = torch.stack(PQ_bus_y)

        data['PV'].x = torch.stack(PV_bus_x)
        data['PV'].y = torch.stack(PV_bus_y)
        data['PV'].generation = torch.stack(PV_generation)
        data['PV'].demand = torch.stack(PV_demand)

        data['slack'].x = torch.stack(slack_x)
        data['slack'].y = torch.stack(slack_y)
        data['slack'].generation = torch.stack(slack_generation)
        data['slack'].demand = torch.stack(slack_demand)

        data['bus'].x = torch.stack(bus_x)

        data['bus', 'branch', 'bus'].edge_index = torch.stack(edge_index, dim=1)
        data['bus', 'branch', 'bus'].edge_attr = torch.stack(edge_attr)
        data['bus', 'branch', 'bus'].edge_label = torch.stack(edge_label)

        for link_name, edges in {
            ('PV', 'PV_link', 'bus'): PV_to_bus,
            ('PQ', 'PQ_link', 'bus'): PQ_to_bus,
            ('slack', 'slack_link', 'bus'): slack_to_bus
        }.items():
            edge_tensor = torch.stack(edges, dim=1)
            data[link_name].edge_index = edge_tensor
            data[(link_name[2], link_name[1], link_name[0])].edge_index = edge_tensor.flip(0)

        return data

    def process(self):
        fnames = self.raw_file_names
        random.shuffle(fnames)
        n = len(fnames)

        split_dict = {
            'train': fnames[:int(0.8 * n)],
            'val': fnames[int(0.8 * n): int(0.9 * n)],
            'test': fnames[int(0.9 * n):]
        }

        for split, files in split_dict.items():
            data_list = []
            print(f"Processing split: {split} ({len(files)} files)")
            for fname in tqdm(files, desc=f"Building {split} data"):
                with open(os.path.join(self.raw_dir, fname)) as f:
                    pm_case = json.load(f)
                data = self.build_heterodata(pm_case)
                data_list.append(data)

            data, slices = self.collate(data_list)
            torch.save((data, slices), os.path.join(self.processed_dir, f'{split}.pt'))

In [None]:
class PFDeltaGNSDataset(PFDeltaDataset):
    def __init__(self, root_dir='data', case_name='', split='train', transform=None, pre_transform=None, pre_filter=None, force_reload=False):
        super().__init__()
        # self.split = split
        # self.force_reload = force_reload
        # root = os.path.join(root_dir, case_name)
        # super().__init__(root, transform, pre_transform, pre_filter, force_reload=force_reload)
        # self.load(self.processed_paths[self._split_to_idx()])

In [29]:
import json

json_path = 'gns_data/case14/raw/instance_1.json'

with open(json_path, 'r') as f:
    data = json.load(f)

# Print top-level keys or preview content
print(type(data))
print(data.keys() if isinstance(data, dict) else data[:1])


<class 'dict'>
dict_keys(['network', 'solution'])


In [39]:
data['network']['slack']

KeyError: 'slack'

In [47]:
case_14_data = PFDeltaGNSDataset(root_dir='gns_data', case_name='case14')

Processing...


FileNotFoundError: [Errno 2] No such file or directory: 'data/raw'

In [14]:
case_14_train = torch.load('gns_data/case14/processed/train.pt')
case_14_val = torch.load('gns_data/case14/processed/val.pt')
case_14_test = torch.load('gns_data/case14/processed/test.pt')

In [21]:
example_data = case_14_data[0]

In [None]:
example_data
# in the bus variables we have shunt susceptance and conductance 
# we need to get Pmin, Psetpoints, Pmax for all the PV buses and store them somewhere 

HeteroData(
  PQ={
    x=[9, 2],
    y=[9, 2],
  },
  PV={
    x=[4, 2],
    y=[4, 2],
    generation=[4, 2],
    demand=[4, 2],
  },
  slack={
    x=[1, 2],
    y=[1, 2],
    generation=[1, 2],
    demand=[1, 2],
  },
  bus={ x=[14, 2] },
  (bus, branch, bus)={
    edge_index=[2, 20],
    edge_attr=[20, 8],
    edge_label=[19, 4],
  },
  (PV, PV_link, bus)={ edge_index=[2, 4] },
  (bus, PV_link, PV)={ edge_index=[2, 4] },
  (PQ, PQ_link, bus)={ edge_index=[2, 9] },
  (bus, PQ_link, PQ)={ edge_index=[2, 9] },
  (slack, slack_link, bus)={ edge_index=[2, 1] },
  (bus, slack_link, slack)={ edge_index=[2, 1] }
)

### Graph Neural Solver Architecture

In [None]:
class GraphNeuralSolver(nn.Module):
    def __init__(self, K, hidden_dim, L_input_dim, phi_input_dim, gamma):
        super().__init__()
        self.K = K
        self.gamma = gamma
        self.phi = nn.Sequential(
            nn.Linear(phi_input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.layers = nn.ModuleList(
            NNUpdate(L_input_dim, hidden_dim) for _ in range(K)
        )


    def forward(self, data):  
        """ """
        total_layer_loss = 0

        for k in range(self.K): 

            # Update P and Q values for all buses
            self.global_active_compensation(data) 

            # Compute local power imbalance variables and store power imbalance loss 
            layer_loss = self.local_power_imbalance(data)
            total_layer_loss += layer_loss * torch.pow(self.gamma, self.K - k)

            # Apply the neural network update block 
            self.apply_nn_update(data, k)
        
        return data, total_layer_loss


    def global_active_compensation(self, data):
        # calculate global P_joule
        # calculate P_global
        p_joule = None 
        p_global = None

        # Compute lambda
        lamb = self.compute_lambda(p_global, data)
        lamb = torch.max(0, lamb)

        # calculate P_gi slack (using lambda factor)
        # calculate P_gi non-slack buses

        # calculate Q_gi for all buses
        pass


    def compute_lambda(self, p_global, data):
        # simple scalar calculation
        pass


    def local_power_imbalance(self, data):
        # compute delta P
        # compute delta Q
        # compute delta S -- store these values in the data instance and scale by gamma factor
        pass
    

    def apply_nn_update(self, data, k): 
        # apply phi to the message vector at each node and on the line values 
        # apply the k-th layer NN udpate 
        self.layers[k] # pass inputs in
        pass


class NNUpdate(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.L_theta = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.L_v = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, 1)
        )
        self.L_m = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

    def forward(self, data):
        # apply V, theta, and m updates within the data itself
        pass



def kirchoff_law_violation_loss(layer_loss, data):
    # get loss per batch, avergae across samples and then backprop
    pass