In [None]:
from __future__ import annotations

import numpy as np
import pandas as pd
import os
from os.path import join
from datetime import timedelta
from collections import OrderedDict

import torch
import torch_geometric

from sklearn.preprocessing import MinMaxScaler, StandardScaler
from sklearn.preprocessing import MaxAbsScaler, RobustScaler
from sklearn.preprocessing import FunctionTransformer

from typing import Type, Optional, Tuple, List, Union, Dict, Any, Callable

# Dict keyed by np.datetime64 values whose entries are HeteroData graphs
# (presumably one graph snapshot per date — confirm against the producer).
DataDateDictType = Dict[np.datetime64, torch_geometric.data.HeteroData]

# Type alias: any one of the scikit-learn scaler / transformer classes below.
ScalerType = Union[
    MinMaxScaler,
    StandardScaler,
    MaxAbsScaler,
    RobustScaler,
    FunctionTransformer,
]

In [None]:
from data import GraphFlowDataset

# Load (processing if requested) the on-disk dataset, then compare a single
# sample with a collated mini-batch of two samples.
processed_path = join("data", "processed")
dataset = GraphFlowDataset(root=processed_path, process=True)

data = dataset[0]

loader = torch_geometric.loader.DataLoader(
    dataset,
    batch_size=2,
    num_workers=8,
)
batch = next(iter(loader))

# Edge indices: single sample first, then the batched version.
flows_edge_type = ("measurement", "flows", "measurement")
for item in (data, batch):
    print(item[flows_edge_type].edge_index)

# Node feature shapes: single sample first, then the batched version.
for item in (data, batch):
    print(item["measurement"].x.shape)

In [None]:
class SeqData(torch_geometric.data.Data):
    """A sample holding a sequence of homogeneous graphs over the same nodes.

    ``edge_indices`` stacks one edge index per graph. On construction the edge
    index of graph ``i`` is shifted by ``i * N`` (with ``N = xs[0].size(0)``) so
    that, once ``xs`` is flattened graph-by-graph, each graph addresses its own
    block of nodes. ``__inc__`` supplies the matching shift between samples
    when PyG collates a batch, and ``__cat_dim__`` makes PyG stack ``xs`` along
    a new leading (batch) dimension instead of concatenating it.

    Args:
        edge_indices: Integer tensor whose dim 0 indexes graphs; presumably
            ``(num_graphs, 2, E)``.
        xs: Per-graph node features, ``xs[i]`` belonging to graph ``i``;
            presumably ``(num_graphs, N, D)``. Required whenever
            ``edge_indices`` is given.
        y: Targets, stored unchanged.

    NOTE(review): on a collated ``Batch`` the name ``num_graphs`` likely
    resolves to ``Batch.num_graphs`` (number of samples) rather than this
    per-sample graph count — confirm before relying on it there.
    """

    def __init__(self, edge_indices=None, xs=None, y=None):
        super().__init__()

        self.xs = xs
        self.y = y

        if edge_indices is None:
            self.edge_indices = edge_indices
            return

        n_graphs = edge_indices.size(0)
        n_nodes = self.xs[0].size(0)
        # Shift graph i by i * n_nodes with one broadcast add instead of a
        # Python loop; the view aligns the offsets with dim 0 for any rank.
        offsets = torch.arange(
            n_graphs, dtype=edge_indices.dtype, device=edge_indices.device
        ) * n_nodes
        offsets = offsets.view((-1,) + (1,) * (edge_indices.dim() - 1))
        self.edge_indices = edge_indices + offsets
        self.num_graphs = n_graphs

    def __inc__(self, key, value, *args, **kwargs):
        # Per-sample increment for edge_indices: all nodes of all graphs in
        # this sample (N * num_graphs); other keys use PyG's default.
        if key == "edge_indices":
            return self.xs[0].size(0) * self.num_graphs
        else:
            return super().__inc__(key, value, *args, **kwargs)

    def __cat_dim__(self, key, value, *args, **kwargs):
        # None makes PyG stack ``xs`` along a new leading dimension.
        if key == "xs":
            return None
        else:
            return super().__cat_dim__(key, value, *args, **kwargs)

1. Input: B x N x T x D
2. B * T graphs
3. N nodes each
4. D features

5. Transpose 1) -> B x T x N x D
6. Flatten 5) -> (B * T * N) x D (here we must make sure all edge indices are offset correctly)
7. Message passing
8. Unflatten 7) -> B x T x N x D' (D' can be a new output dimension)
9. Transpose 8) -> B x N x T x D'
10. Flatten 9) -> B' x T x D', where B' = B * N
11. Put 10) through the LSTM.

In [None]:
# Create a single datapoint containing a sequence of graphs over the same nodes.
# Seed the RNG so the random features/targets and the conv weights below are
# reproducible across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)

n_graphs = 3
n_outputs = 6
n_nodes = 4
n_features = 8

# Row 0 holds source nodes, row 1 target nodes: 0 -> 1, 0 -> 2, 0 -> 3.
edge_index = torch.tensor([
    [0, 0, 0],
    [1, 2, 3]
], dtype=torch.long)

# Same topology for every graph in the sequence: (n_graphs, 2, E).
edge_indices = edge_index.repeat(n_graphs, 1, 1)

# xs: (n_graphs, n_nodes, n_features); y: (n_nodes, n_outputs).
xs = torch.randint(1, 5, (n_graphs, n_nodes, n_features), dtype=torch.float32)
y = torch.randint(6, 10, (n_nodes, n_outputs))

data = SeqData(edge_indices=edge_indices, xs=xs, y=y)


# Create a batch from two copies of this datapoint
data_list = [data, data]

loader = torch_geometric.loader.DataLoader(data_list, batch_size=2,
                                           follow_batch=["xs", "y"])
batch = next(iter(loader))

print(batch)

conv = torch_geometric.nn.GCNConv(n_features, 32)

# Shapes of the single (un-batched) datapoint, for comparison with the batch.
print("\nSingular Data Dimensions:")
print(data.xs.shape)
print(data.edge_indices.shape)

def forward(conv, batch, verbose=True):
    """Apply ``conv`` to every graph of a collated ``SeqData`` batch at once.

    The per-sample graph sequences are merged into one large disconnected
    graph: ``xs`` is flattened to one row per node and the per-graph
    ``edge_indices`` are concatenated into a single ``(2, G * E)`` edge index
    (node offsets are expected to be applied already, as ``SeqData`` does).

    Args:
        conv: Callable ``conv(x, edge_index)``, e.g. a PyG ``GCNConv``.
        batch: Object exposing ``xs`` (``(B, T, N, D)`` for a collated
            ``SeqData`` batch) and ``edge_indices`` (``(G, 2, E)``).
        verbose: When True (default), print the intermediate shapes and tensors.

    Returns:
        The output of ``conv`` on the flattened node features.
    """
    edge_indices = batch.edge_indices
    xs = batch.xs

    if verbose:
        print("\nBatch Dimensions:")
        print(xs.shape)
        print(edge_indices.shape)

        print(xs)
        print(edge_indices)

    # (B, T, N, D) -> (B * T * N, D): one row per node of every graph.
    xs = xs.reshape(-1, xs.size(-1))
    # (G, 2, E) -> (G, E, 2) -> (G * E, 2) -> (2, G * E): a single edge index.
    edge_indices = edge_indices.permute((0, 2, 1)).flatten(0, 1).T

    if verbose:
        print("\nReshape Dimensions:")
        print(xs.shape)
        print(edge_indices.shape)

        print(xs)
        print(edge_indices)

        print("\nOut of convolution dimension:")
    out = conv(xs, edge_indices)
    if verbose:
        print(out.shape)
    return out


# Run the homogeneous demo: flatten the collated batch and apply the GCN layer.
forward(conv, batch)

In [None]:
from torch_geometric.data import Data, HeteroData

class HeteroSeqData(HeteroData):
    """A ``HeteroData`` sample holding a sequence of heterogeneous graphs.

    Node stores carry ``xs`` indexed per graph (``xs[i]`` = node features of
    graph ``i``); edge stores carry ``edge_indices`` with one ``(2, E)`` edge
    index per graph. On construction, graph ``i``'s source row is shifted by
    ``i * N_src`` and its destination row by ``i * N_dst`` so that, after
    flattening ``xs`` graph-by-graph per node type, each graph addresses its
    own block of nodes. ``__inc__`` supplies the matching per-sample shift when
    PyG collates samples, and ``__cat_dim__`` stacks ``xs``/``y`` along a new
    leading (batch) dimension.

    Args:
        mapping: ``{node_type: {"xs": ..., ...},
            (src, rel, dst): {"edge_indices": ...}, ...}``. The caller's dict
            is not modified; shifted edge indices are written to a copy.
        n_graphs: Graphs per sample. When omitted alongside a ``mapping`` it is
            taken from the leading dimension of the first ``edge_indices`` or
            ``xs`` tensor found.
        **kwargs: Forwarded to ``HeteroData``.
    """

    def __init__(self, mapping=None, n_graphs=None, **kwargs):
        if mapping:
            mapping = self._offset_mapping(mapping)
            if n_graphs is None:
                n_graphs = self._infer_n_graphs(mapping)
        # Initialise the stores exactly once, from the already-shifted copy.
        super().__init__(mapping, **kwargs)
        self.n_graphs = n_graphs

    @staticmethod
    def _offset_mapping(mapping):
        # Return a copy of ``mapping`` whose edge_indices are shifted per graph.
        # Each store dict is copied so the caller's mapping is never mutated
        # (re-using a mutated mapping would apply the offsets twice).
        adjusted = {
            key: dict(value) if isinstance(value, dict) else value
            for key, value in mapping.items()
        }
        for key, store in adjusted.items():
            # Only edge stores (tuple keys) that carry edge_indices are shifted.
            if not isinstance(key, tuple) or not isinstance(store, dict):
                continue
            if "edge_indices" not in store:
                continue
            src_type, _, dst_type = key
            edge_indices = store["edge_indices"]
            n_src = adjusted[src_type]["xs"][0].size(0)
            n_dst = adjusted[dst_type]["xs"][0].size(0)
            graph_ids = torch.arange(
                edge_indices.size(0), dtype=edge_indices.dtype, device=edge_indices.device
            )
            # Sources (row 0) advance by n_src per graph, destinations (row 1)
            # by n_dst — the two counts differ for cross-type relations.
            row_sizes = torch.tensor(
                [n_src, n_dst], dtype=edge_indices.dtype, device=edge_indices.device
            )
            store["edge_indices"] = (
                edge_indices + graph_ids.view(-1, 1, 1) * row_sizes.view(1, 2, 1)
            )
        return adjusted

    @staticmethod
    def _infer_n_graphs(mapping):
        # Leading dimension of the first per-graph tensor in ``mapping``, else None.
        for store in mapping.values():
            if not isinstance(store, dict):
                continue
            for name in ("edge_indices", "xs"):
                if name in store:
                    return store[name].size(0)
        return None

    def __inc__(self, key, value, store, *args, **kwargs):
        # Per-sample increment PyG accumulates while collating edge_indices:
        # n_graphs * N_src for row 0 and n_graphs * N_dst for row 1.
        if key == "edge_indices":
            src_type, _, dst_type = store._key
            n_src = self[src_type].xs[0].size(0)
            n_dst = self[dst_type].xs[0].size(0)
            return torch.tensor([[n_src], [n_dst]], device=value.device) * self.n_graphs
        return super().__inc__(key, value, store, *args, **kwargs)

    def __cat_dim__(self, key, value, store, *args, **kwargs):
        # None makes PyG stack xs / y along a new leading (batch) dimension
        # instead of concatenating them.
        if key in ("xs", "y"):
            return None
        return super().__cat_dim__(key, value, store, *args, **kwargs)

In [None]:
# Create a single datapoint containing a sequence of heterogeneous graphs
# with two node types, "measurement" and "subsub".
n_graphs = 3
n_msr_outputs = 6
n_msr_nodes = 4
n_msr_features = 8

n_subsub_nodes = 5
n_subsub_features = 12

# Row 0 holds source nodes, row 1 target nodes.
# measurement -> measurement: nodes 1, 2 and 3 point to node 0.
msr_edge_index = torch.tensor([
    [1, 2, 3],
    [0, 0, 0]
], dtype=torch.long)
# subsub -> subsub
subsub_edge_index = torch.tensor([
    [0, 1, 2, 3],
    [1, 4, 3, 4]
], dtype=torch.long)
# subsub -> measurement: subsub node 4 points to measurement node 1.
submsr_edge_index = torch.tensor([
    [4],
    [1]
], dtype=torch.long)

# Same topology for every graph in the sequence: (n_graphs, 2, E).
msr_edge_indices = msr_edge_index.repeat(n_graphs, 1, 1)
subsub_edge_indices = subsub_edge_index.repeat(n_graphs, 1, 1)
submsr_edge_indices = submsr_edge_index.repeat(n_graphs, 1, 1)

msr_xs = torch.randint(1, 5, (n_graphs, n_msr_nodes, n_msr_features), dtype=torch.float32)
msr_y = torch.randint(6, 10, (n_msr_nodes, n_msr_outputs))

subsub_xs = torch.randint(1, 5, (n_graphs, n_subsub_nodes, n_subsub_features), dtype=torch.float32)


mapping = {
    ("measurement", "flows", "measurement"): {"edge_indices": msr_edge_indices},
    "measurement": {"xs": msr_xs, "y": msr_y},
    ("subsub", "flows", "subsub"): {"edge_indices": subsub_edge_indices},
    "subsub": {"xs": subsub_xs},
    ("subsub", "in", "measurement"): {"edge_indices": submsr_edge_indices},
}

data = HeteroSeqData(mapping, n_graphs)

print(data)
print(data["measurement", "flows", "measurement"].edge_indices)
print(data["subsub", "flows", "subsub"].edge_indices)

In [None]:
# Collate two copies of the heterogeneous datapoint into one mini-batch and
# inspect the edge_indices after collation.
data_list = [data] * 2

loader = torch_geometric.loader.DataLoader(data_list, batch_size=2)
batch = next(iter(loader))
print(batch)

for edge_type in (
    ("measurement", "flows", "measurement"),
    ("subsub", "flows", "subsub"),
    ("subsub", "in", "measurement"),
):
    print(batch[edge_type].edge_indices)

In [None]:
def forward(batch, conv, lstm_hidden=64, n_outputs=6, verbose=True):
    """Run the graph-then-LSTM demo on a collated ``HeteroSeqData`` batch.

    Flattens every sequence of graphs into one big graph, applies the
    heterogeneous ``conv``, regroups the "measurement" embeddings into one time
    series per node, runs an LSTM over time and maps the last time step to
    ``n_outputs`` values per node.

    Args:
        batch: Collated batch whose node stores hold ``xs`` shaped
            ``(B, T, N, D)`` and whose edge stores hold ``edge_indices`` shaped
            ``(B * T, 2, E)``; the "measurement" store is required.
        conv: Callable ``conv(x_dict, edge_index_dict) -> dict`` such as
            ``torch_geometric.nn.HeteroConv``.
        lstm_hidden: Hidden size of the LSTM (default 64).
        n_outputs: Output features per measurement node (default 6).
        verbose: When True (default), print the intermediate shapes.

    Returns:
        Tensor of shape ``(B, N, n_outputs)`` for the "measurement" nodes.

    Note: the LSTM and linear head are freshly (randomly) initialised on every
    call, so this only demonstrates the shapes; it is not a trainable model.
    """
    def log(*items):
        if verbose:
            print(*items)

    new_batch = batch.clone()
    # Remember the (B, T, N) layout of the measurement nodes before flattening,
    # instead of hard-coding it.
    n_batch, n_steps, n_msr_nodes = batch["measurement"]["xs"].shape[:3]

    log("\nBatch Dimensions")
    log(new_batch)

    # (B, T, N, D) -> (B * T * N, D) for every node type.
    for node_type in new_batch.node_types:
        node_xs = new_batch[node_type]["xs"]
        new_batch[node_type]["xs"] = node_xs.reshape(-1, node_xs.size(-1))

    # (B * T, 2, E) -> (2, B * T * E): one edge index per edge type.
    for edge_type in new_batch.edge_types:
        type_edges = new_batch[edge_type]["edge_indices"]
        new_batch[edge_type]["edge_indices"] = type_edges.permute((0, 2, 1)).flatten(0, 1).T

    log("\nReshape Dimensions")
    log(new_batch)

    log("\nOut of convolution dimension:")
    out = conv(new_batch.xs_dict, new_batch.edge_indices_dict)
    log(out["measurement"].shape)
    if verbose:
        print(out["subsub"].shape)

    log("Reshape Out of Convolution")
    # (B * T * N, D') -> (B, T, N, D') -> (B, N, T, D') -> (B * N, T, D')
    msr_out = out["measurement"].reshape((n_batch, n_steps, n_msr_nodes, -1))
    log(msr_out.shape)
    msr_out = msr_out.permute((0, 2, 1, 3))
    log(msr_out.shape)
    msr_out = msr_out.flatten(0, 1)
    log(msr_out.shape)

    # The LSTM input size follows the conv output width.
    lstm = torch.nn.LSTM(msr_out.size(-1), lstm_hidden, batch_first=True)
    fc = torch.nn.Linear(lstm_hidden, n_outputs)

    log("Shape out of LSTM")
    # nn.LSTM returns (output, (h_n, c_n)); output is (B * N, T, lstm_hidden).
    seq_out, _ = lstm(msr_out)
    log(seq_out.shape)
    log("Reshape LSTM Out")
    # Keep the last time step and regroup per sample: (B, N, lstm_hidden).
    last_step = seq_out[:, -1, :].reshape((n_batch, n_msr_nodes, -1))
    log(last_step.shape)

    log("Final Output of Linear layer ")
    pred = fc(last_step)
    log(pred.shape)

    if verbose:
        print("y shape")
        print(new_batch["measurement"].y.shape)
    return pred


# One SAGEConv (32 output channels) per edge type; in_channels=(-1, -1)
# defers source/target input sizes to PyG's lazy initialisation.
hetero_edge_types = [
    ("measurement", "flows", "measurement"),
    ("subsub", "flows", "subsub"),
    ("subsub", "in", "measurement"),
]
hetero_conv = torch_geometric.nn.HeteroConv(
    {
        edge_type: torch_geometric.nn.SAGEConv(in_channels=(-1, -1), out_channels=32)
        for edge_type in hetero_edge_types
    }
)

forward(batch, hetero_conv)

1. Input: B x N x T x D
2. B * T graphs
3. N nodes each
4. D features

5. Transpose 1) -> B x T x N x D
6. Flatten 5) -> (B * T * N) x D (here we must make sure all edge indices are offset correctly)
7. Message passing
8. Unflatten 7) -> B x T x N x D' (D' can be a new output dimension)
9. Transpose 8) -> B x N x T x D'
10. Flatten 9) -> B' x T x D', where B' = B * N
11. Put 10) through the LSTM.