In [None]:
from __future__ import annotations

import os
from collections import OrderedDict
from datetime import timedelta
from os.path import join
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union

import numpy as np
import pandas as pd
import torch
from sklearn.preprocessing import (
    FunctionTransformer,
    MaxAbsScaler,
    MinMaxScaler,
    RobustScaler,
    StandardScaler,
)
from torch_geometric.data import Data, Dataset, HeteroData
from torch_geometric.loader import DataLoader

import data_utils

# Maps a date (np.datetime64 key) to the HeteroData graph stored for it.
DataDateDictType = Dict[np.datetime64, HeteroData]

# Union of the sklearn preprocessing transformers allowed where a scaler is expected.
ScalerType = Union[
    MinMaxScaler,
    StandardScaler,
    MaxAbsScaler,
    RobustScaler,
    FunctionTransformer,
]

In [None]:
class SeqData(Data):
    """Homogeneous graph carrying a sequence of snapshots over one shared node set.

    Attributes:
        edge_indices: edge indices stacked per snapshot; assumed shape
            (seq_len, 2, num_edges) -- TODO confirm for real data.
        xs: node features stacked per snapshot; assumed shape
            (seq_len, num_nodes, num_features).
        y: node targets; assumed node-first, i.e. (num_nodes, ...), so the
            default dim-0 concatenation already merges nodes.

    Mini-batching merges graphs along the *node* axis (disjoint union), not the
    sequence axis: ``xs`` is concatenated on dim 1, ``edge_indices`` on its last
    (edge) dim, and edge indices are shifted by the node count of each graph.
    """

    def __init__(self, edge_indices=None, xs=None, y=None):
        # Defaults of None keep the constructor callable with no arguments.
        super().__init__()
        self.edge_indices = edge_indices
        self.xs = xs
        self.y = y

    def __cat_dim__(self, key, value, *args, **kwargs):
        # Without this override both tensors are concatenated on dim 0 (the
        # sequence axis), which contradicts the node-count shift in __inc__ and
        # produces edge indices pointing past the nodes present in ``xs``.
        if key == "xs":
            return 1
        if key == "edge_indices":
            return -1
        return super().__cat_dim__(key, value, *args, **kwargs)

    def __inc__(self, key, value, *args, **kwargs):
        # Shift edge indices by this graph's node count (node axis of one snapshot).
        if key == "edge_indices":
            return self.xs[0].size(0)
        return super().__inc__(key, value, *args, **kwargs)

In [None]:
SEED = 0
torch.manual_seed(SEED)  # make the random demo tensors reproducible on Run All

n_graphs = 2  # number of graph snapshots stacked in each SeqData
n_nodes = 4

# Star graph: nodes 1, 2 and 3 each point at node 0.
edge_index = torch.tensor([
    [1, 2, 3],
    [0, 0, 0]
])

# One copy of the edge set per snapshot -> (n_graphs, 2, num_edges).
edge_indices = edge_index.repeat(n_graphs, 1, 1)

# xs: (n_graphs, n_nodes, 8); y is node-first: (n_nodes, n_graphs).
xs = torch.stack([torch.randint(1, 5, (n_nodes, 8)) for _ in range(n_graphs)])
y = torch.randint(6, 10, (n_nodes, n_graphs))

data = SeqData(edge_indices, xs, y)

# Batching of this sample is built and inspected in the following cell.
print(data)

In [None]:
# Collate two copies of the sample into one mini-batch and inspect the result;
# follow_batch adds the `xs_batch` / `y_batch` assignment vectors.
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2, follow_batch=["xs", "y"])
batch = next(iter(loader))

print(batch)
for attr_name in ("edge_indices", "xs", "xs_batch", "y", "y_batch"):
    print(getattr(batch, attr_name))

In [None]:
class SeqData(HeteroData):
    """Heterogeneous graph carrying a sequence of snapshots over one shared node set.

    NOTE(review): this redefines (shadows) the homogeneous ``SeqData`` from an
    earlier cell; every cell executed after this one sees this version.

    Node features ``x`` are assumed stacked per snapshot as
    (seq_len, num_nodes, num_features) -- TODO confirm. They are concatenated on
    dim 1 (the node axis) when batching. PyG infers ``num_nodes`` from ``x``
    along its cat dim (verify for your PyG version), so HeteroData's default
    ``__inc__`` then shifts ``edge_index`` by the per-graph node count.
    All other keys keep HeteroData's default concatenation/increment behaviour.
    """

    def __cat_dim__(self, key, value, store=None, *args, **kwargs):
        # Merge graphs on the node axis instead of the leading sequence axis.
        if key == "x":
            return 1
        return super().__cat_dim__(key, value, store, *args, **kwargs)

In [None]:
n_graphs = 2  # number of graph snapshots stacked in each sample
n_nodes = 4

# Node features per snapshot: (n_graphs, n_nodes, 8); targets are node-first: (n_nodes, n_graphs).
measurement_xs = torch.stack([torch.randint(1, 5, (n_nodes, 8)) for _ in range(n_graphs)])
measurement_y = torch.randint(6, 10, (n_nodes, n_graphs))

# Star graph: nodes 1, 2 and 3 each point at node 0.
measurement_edge_index = torch.tensor([
    [1, 2, 3],
    [0, 0, 0]
])

# Repeat this cell's own edge_index (one copy per snapshot) so the cell does not
# depend on an `edge_index` variable left over from an earlier cell.
measurement_edge_indices = measurement_edge_index.repeat(n_graphs, 1, 1)

data = SeqData()
data["measurement"].x = measurement_xs
data["measurement"].y = measurement_y
data["measurement", "flows", "measurement"].edge_index = measurement_edge_indices

# Batching of this sample is built and inspected in the following cell.
print(data)

In [None]:
# Collate two copies of the heterogeneous sample into one mini-batch and inspect it.
data_list = [data, data]
loader = DataLoader(data_list, batch_size=2)
batch = next(iter(loader))

print(batch)
# Heterogeneous batches hold tensors per node/edge type rather than as flat attributes.
print(batch["measurement", "flows", "measurement"].edge_index)
print(batch["measurement"].x)
print(batch["measurement"].y)