In [58]:
import torch
import pickle

from load_data import get_data
from DataLoader import WeatherDL
from einops.layers.torch import Rearrange

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [2]:
data = get_data()
min_max = pickle.load(open('./min_max_test.pkl', 'rb'))
train = WeatherDL(
    data,
    time=slice('2000', '2002'),
    temporal_resolution=3,
    # window=84,
    min=min_max['min'],
    max=min_max['max'],
    batch_size=4,
    num_workers=6,
    persistent_workers=True,
    prefetch_factor=2,
    # multiprocessing_context='fork'
    )
# min, max = train.data_wrapper.getMinMaxValues()
valid = WeatherDL(
    data,
    time='2003',
    temporal_resolution=3,
    min=min_max['min'],
    max=min_max['max'],
    batch_size=5,
    num_workers=6,
    persistent_workers=True,
    prefetch_factor=2,
    # multiprocessing_context='fork'
    )

In [3]:
input_size = train.spatio_temporal_dataset.n_channels   # n channel
n_nodes = train.spatio_temporal_dataset.n_nodes         # n nodes
horizon = train.spatio_temporal_dataset.horizon         # n prediction time steps
hidden_size = 32
batch = next(iter(train.data_loader))

In [4]:
batch

StaticBatch(
  input=(x=[b=4, t=112, n=2048, f=65], edge_index=[2, e=19328], edge_weight=[e=19328]),
  target=(y=[b=4, t=40, n=2048, f=65]),
  has_mask=False
)

In [5]:
import torch.nn as nn
from tsl.nn.layers import NodeEmbedding, DiffConv
from einops import rearrange

In [6]:
encoder = nn.Linear(in_features=input_size, out_features=hidden_size)
embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)
encoded = encoder(batch.input.x)
emb = encoded + embeddings()

In [7]:
from graphesn import StaticGraphReservoir, Readout, initializer

reservoir = StaticGraphReservoir(num_layers=5, in_features=hidden_size, hidden_features=256)
reservoir.initialize_parameters(recurrent=initializer('uniform', rho=.9), input=initializer('uniform', scale=1))
# embeddings = reservoir(data.edge_index, data.x)

In [8]:
batch.input.x.size(-2)

2048

In [9]:
temp_emb = reservoir(batch.input.edge_index, batch.input.x, batch=0)

input torch.Size([4, 112, 2048, 65])
src size torch.Size([4, 256])
dim: 0, node_dim: -2


IndexError: Encountered an index error. Please ensure that all indices in 'edge_index' point to valid indices in the interval [0, 3] (got interval [0, 2047])

In [70]:
ones = torch.ones((2,2,2,2))
zeros = torch.zeros((2,2,2, 2))
twos = torch.ones((2,2, 3)) * 2
joint = torch.cat((ones, zeros), dim=1)

In [64]:
r = torch.arange(16, dtype=torch.float32).reshape((2,2,2,2))

In [65]:
r

tensor([[[[ 0.,  1.],
          [ 2.,  3.]],

         [[ 4.,  5.],
          [ 6.,  7.]]],


        [[[ 8.,  9.],
          [10., 11.]],

         [[12., 13.],
          [14., 15.]]]])

In [66]:
rt = Rearrange('b t n f -> b f n t')(r)
rt

tensor([[[[ 0.,  4.],
          [ 2.,  6.]],

         [[ 1.,  5.],
          [ 3.,  7.]]],


        [[[ 8., 12.],
          [10., 14.]],

         [[ 9., 13.],
          [11., 15.]]]])

In [78]:
rt.size()

torch.Size([2, 2, 2, 2])

In [77]:
torch.nn.functional.linear(rt, twos)

RuntimeError: t() expects a tensor with <= 2 dimensions, but self is 3D

In [75]:
torch.matmul(rt, twos), rt.size(), twos.size()

(tensor([[[[ 8.,  8.,  8.],
           [16., 16., 16.]],
 
          [[12., 12., 12.],
           [20., 20., 20.]]],
 
 
         [[[40., 40., 40.],
           [48., 48., 48.]],
 
          [[44., 44., 44.],
           [52., 52., 52.]]]]),
 torch.Size([2, 2, 2, 2]),
 torch.Size([2, 2, 3]))

In [83]:
l = torch.nn.Linear(2, 3)
l.weight = torch.nn.Parameter(twos)


In [86]:
l.weight

Parameter containing:
tensor([[[2., 2., 2.],
         [2., 2., 2.]],

        [[2., 2., 2.],
         [2., 2., 2.]]], requires_grad=True)

In [164]:
one = torch.ones((3,2))
r = torch.arange(2*3*3*5).reshape((2,3,3,5))
reduction = torch.einsum('btnc,aw -> banw', r, one)
r,reduction.size()

(tensor([[[[ 0,  1,  2,  3,  4],
           [ 5,  6,  7,  8,  9],
           [10, 11, 12, 13, 14]],
 
          [[15, 16, 17, 18, 19],
           [20, 21, 22, 23, 24],
           [25, 26, 27, 28, 29]],
 
          [[30, 31, 32, 33, 34],
           [35, 36, 37, 38, 39],
           [40, 41, 42, 43, 44]]],
 
 
         [[[45, 46, 47, 48, 49],
           [50, 51, 52, 53, 54],
           [55, 56, 57, 58, 59]],
 
          [[60, 61, 62, 63, 64],
           [65, 66, 67, 68, 69],
           [70, 71, 72, 73, 74]],
 
          [[75, 76, 77, 78, 79],
           [80, 81, 82, 83, 84],
           [85, 86, 87, 88, 89]]]]),
 torch.Size([2, 3, 3, 2]))

In [162]:
torch.matmul(r[0,:,0,0],torch.ones(3, dtype=torch.long))

tensor(45)