In [30]:
import torch
import pickle
import xarray as xr
import pandas as pd
import torch.nn as nn
from torch.nn import functional as F
from tsl.nn.layers import NodeEmbedding, DenseGraphConvOrderK, DiffConv, Norm
from tsl.nn.blocks.decoders import MLPDecoder
from tsl.nn.blocks.encoders.mlp import MLP
from einops.layers.torch import Rearrange
from snntorch import utils
from models.layers.SynapticChain import SynapticChain
from models.layers.LearnableWeight import LearnableWeight

from load_data import get_data
from DataLoader import WeatherDL

# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
data = get_data()

# Load the normalisation statistics inside a context manager so the file
# handle is closed deterministically instead of being left to the GC.
# NOTE: pickle can execute arbitrary code on load; only open trusted files.
with open('./min_max_test.pkl', 'rb') as min_max_file:
    min_max = pickle.load(min_max_file)

temporal_resolution = 6
window = 20

# Training loader over the 1978-1980 slice, normalised to [0, 1] with the
# statistics loaded above.
train = WeatherDL(
    data,
    time=slice('1978', '1980'),
    temporal_resolution=temporal_resolution,
    window=window,
    min=min_max['min'],
    max=min_max['max'],
    batch_size=1,
    num_workers=6,
    persistent_workers=True,
    prefetch_factor=2,
    shuffle=False,
    normalization_range={'min': 0, 'max': 1}
    # multiprocessing_context='fork'
    )

# Disabled validation loader, kept for reference.
# min, max = train.data_wrapper.getMinMaxValues()
# valid = WeatherDL(
#     data,
#     time='2003',
#     temporal_resolution=temporal_resolution,
#     min=min_max['min'],
#     max=min_max['max'],
#     batch_size=5,
#     num_workers=6,
#     persistent_workers=True,
#     prefetch_factor=2,
#     # multiprocessing_context='fork'
#     )
device


weather loader init
WeatherDataLoader done
weather loader exists
SpatioTemporalDataset init
SpatioTemporalDataset exists
StaticGraphLoader init
StaticGraphLoader exists


device(type='cuda', index=0)

In [31]:
from metrics.weighted_rmse2 import WeightedRMSE
from metrics.metric_utils import WeatherVariable

In [52]:
# Dataset dimensions: channels per node, number of nodes, forecast steps.
input_size, n_nodes, horizon = (
    getattr(train.spatio_temporal_dataset, attribute)
    for attribute in ('n_channels', 'n_nodes', 'horizon')
)

# Widths used by the experimental layers built further down.
hidden_size, learnable_size = 256, 32

# Pull a single batch to experiment with and display it.
batch = next(iter(train.data_loader))
batch

StaticBatch(
  input=(x=[b=1, t=20, n=2048, f=26], edge_index=[2, e=19328], edge_weight=[e=19328]),
  target=(y=[b=1, t=40, n=2048, f=26]),
  has_mask=False
)

In [56]:
# Node-weighted RMSE restricted to the ('t', 850) variable; ('z', 500) is
# deliberately left out of the variable list.
rmse = WeightedRMSE(
    train.data_wrapper.node_weights,
    variables=[WeatherVariable('t', 850)],
    min_max=min_max,
    denormalize=True,
)
rmse = rmse.to(batch.target.y.device)

# Score the last input step against target step index 11.
rmse(batch.input.x[:, -1:], batch.target.y[:, 11:12])

tensor(2.2591)

: 

In [29]:
# Same metric as before, but with the normalisation range spelled out so it
# matches the range the training loader was built with.
rmse = WeightedRMSE(
    train.data_wrapper.node_weights,
    variables=[WeatherVariable('t', 850)],
    min_max=min_max,
    denormalize=True,
    normalization_range={'min': 0, 'max':1},
)
rmse = rmse.to(batch.target.y.device)

# Score target step 0 against target step 12.
rmse(batch.target.y[:, 0:1], batch.target.y[:, 12:13])

tensor(4.4877)

In [38]:
# Channel widths of the two spatial stages (independent of layer creation).
new_hidden_size1 = hidden_size + learnable_size
new_hidden_size2 = new_hidden_size1 + learnable_size

# Layers are instantiated in the same order as before so that any seeded
# parameter initialisation is unaffected.
encoder = nn.Linear(input_size, hidden_size)
node_embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)

# Stage 1
learnable = LearnableWeight(n_nodes=n_nodes,
                            learnable_weight_dim=learnable_size)
space = DiffConv(in_channels=new_hidden_size1,
                 out_channels=new_hidden_size1,
                 k=2)

# Stage 2
learnable2 = LearnableWeight(n_nodes=n_nodes,
                             learnable_weight_dim=learnable_size)
space2 = DiffConv(in_channels=new_hidden_size2,
                  out_channels=new_hidden_size2,
                  k=2)

# The temporal chain consumes the concatenation of both stage outputs.
time_nn = SynapticChain(
    hidden_size=new_hidden_size2 + new_hidden_size1,
    n_layers=5,
    output_type='membrane_potential',
    return_last=True,
)

In [35]:
# Stage 1: encode + node embeddings -> learnable features -> graph convolution.
x = space(
    learnable(encoder(batch.input.x) + node_embeddings()),
    train.spatio_temporal_dataset.edge_index,
    train.spatio_temporal_dataset.edge_weight,
)

# Stage 2 runs on the stage-1 output.
x2 = space2(
    learnable2(x),
    train.spatio_temporal_dataset.edge_index,
    train.spatio_temporal_dataset.edge_weight,
)

# Concatenate both stage outputs on the feature axis (stage 2 first).
stacked = torch.cat((x2, x), dim=-1)

In [39]:
res = time_nn(stacked)

In [41]:
new_hidden_size1, new_hidden_size2

(288, 320)

In [40]:
stacked.size(), res.size()

(torch.Size([1, 112, 2048, 608]), torch.Size([1, 2048, 608]))

In [7]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.checkpoint import checkpoint
from tsl.nn.layers import NodeEmbedding, DenseGraphConvOrderK, DiffConv, Norm
from tsl.nn.blocks.decoders import MLPDecoder
from tsl.nn.blocks.encoders.mlp import MLP
from einops.layers.torch import Rearrange
from snntorch import utils
from typing_extensions import Literal

from models.layers.SynapticChain import SynapticChain
from models.layers.LearnableWeight import LearnableWeight
from models.layers.MultiParam import MultiParam

class SkipConnection(nn.Module):
    """Add ``x`` to the trailing slice of ``out`` along dimension 1."""

    def forward(self, x, out):
        # Keep only as many trailing entries of `out` (dim 1) as `x` has.
        steps = x.size(1)
        trailing = out[:, -steps:]
        return x + trailing

class Add(nn.Module):
    """Element-wise sum of the two inputs, wrapped as a module."""

    def forward(self, x, out):
        total = x + out
        return total

        

class TSNStacked(nn.Module):
    """Stacked dense graph-convolution blocks feeding one ``SynapticChain``.

    Pipeline implemented by :meth:`forward`:

    1. A linear encoder plus per-node embeddings map the input features to
       ``hidden_size`` channels.
    2. Each of the ``number_of_blocks`` blocks applies, in order, a
       ``LearnableWeight`` layer, a ``DenseGraphConvOrderK`` over the learned
       adjacency and a batch ``Norm``. The convolution and norm of block ``k``
       (``k = 1 .. number_of_blocks``) are built for
       ``hidden_size + k * learnable_feature_size`` channels.
    3. The outputs of all blocks are concatenated on the last axis and passed
       through a single ``SynapticChain`` created with ``return_last=True``.
    4. A ReLU followed by an ``MLPDecoder`` produces the forecast.

    Every stage after the encoder runs under activation checkpointing.

    Args:
        input_size: Number of input (and output) features per node.
        n_nodes: Number of graph nodes; must not be ``None``.
        horizon: Number of forecast steps requested from the decoder.
        hidden_size: Width of the encoder and of the node embeddings.
        ff_size: Accepted for interface compatibility; currently unused.
        gnn_kernel: Accepted for interface compatibility; currently unused
            (the dense convolution order is fixed to 3).
        output_type: Forwarded to ``SynapticChain``.
        learnable_feature_size: ``learnable_weight_dim`` of every
            ``LearnableWeight`` layer; also the per-block width increment.
        number_of_blocks: Number of spatial blocks; must be at least 1.
        number_of_temporal_steps: ``n_layers`` of the ``SynapticChain``.
        dropout: Probability of the ``nn.Dropout`` layer that is registered on
            the module (it is not applied in :meth:`forward`).

    Raises:
        ValueError: If ``number_of_blocks`` is smaller than 1.
    """

    def __init__(self, input_size: int, n_nodes: int, horizon: int,
                 hidden_size: int = 256,
                 ff_size: int = 256,
                 gnn_kernel: int = 2,
                 output_type: Literal["spike", "synaptic_current", "membrane_potential"] = "spike",
                 learnable_feature_size: int = 64,
                 number_of_blocks: int = 2,
                 number_of_temporal_steps: int = 3,
                 dropout: float = 0.3,
                 ) -> None:
        super(TSNStacked, self).__init__()

        # Validate before any layer that depends on these values is built.
        assert n_nodes is not None
        if number_of_blocks < 1:
            # forward() concatenates one tensor per block; with zero blocks
            # torch.cat would fail on an empty list at call time.
            raise ValueError(
                f"number_of_blocks must be >= 1, got {number_of_blocks}")

        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        self.output_type = output_type

        self.encoder = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.node_embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)
        self.dropout = nn.Dropout(dropout)

        # Embeddings whose product defines the learned dense adjacency.
        self.source_embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)
        self.target_embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)

        learnable_weights = []
        dense_sconvs = []
        norms = []
        block_width = hidden_size   # channel width the current block is built for
        concat_width = 0            # width of the concatenated block outputs
        for _ in range(number_of_blocks):
            learnable_weights.append(
                LearnableWeight(n_nodes=n_nodes,
                                learnable_weight_dim=learnable_feature_size))
            block_width = block_width + learnable_feature_size
            dense_sconvs.append(
                DenseGraphConvOrderK(input_size=block_width,
                                     output_size=block_width,
                                     support_len=1,
                                     order=3,  # spatial kernel size
                                     include_self=False,
                                     channel_last=True))
            norms.append(Norm('batch', block_width))
            concat_width += block_width

        self.learnable = nn.ModuleList(learnable_weights)
        # Kept (empty) so the module layout / state_dict stays unchanged.
        self.temporal = nn.ModuleList([])
        self.dense_sconvs = nn.ModuleList(dense_sconvs)
        self.norms = nn.ModuleList(norms)
        self.time_nn = SynapticChain(hidden_size=concat_width,
                                     n_layers=number_of_temporal_steps,
                                     output_type=output_type,
                                     return_last=True)

        self.readout = nn.Sequential(
            nn.ReLU(),
            MLPDecoder(input_size=concat_width,
                       hidden_size=2 * concat_width,
                       output_size=input_size,
                       horizon=horizon,
                       activation='relu'))

    def get_learned_adj(self):
        """Return the learned dense adjacency.

        Computed as ``softmax(relu(E_source @ E_target.T), dim=1)``, i.e. every
        row of the returned ``[n_nodes, n_nodes]`` matrix sums to one.
        """
        logits = F.relu(self.source_embeddings() @ self.target_embeddings().T)
        adj = torch.softmax(logits, dim=1)
        return adj

    def forward(self, x, edge_index, edge_weight):
        """Run the model on a window of node features.

        Args:
            x: Input tensor, laid out as ``[batch time nodes features]``.
            edge_index: Unused; only the learned adjacency is used.
            edge_weight: Unused; see ``edge_index``.

        Returns:
            The output of the readout decoder.

        Raises:
            AssertionError: If ``x`` contains NaN values.
        """
        assert not torch.isnan(x).any()

        x = self.encoder(x) + self.node_embeddings()
        adj_z = self.get_learned_adj()

        # `use_reentrant` is passed explicitly: PyTorch warns when it is
        # omitted and recommends the non-reentrant implementation.
        processed = []
        for add_features, space, norm in zip(self.learnable,
                                             self.dense_sconvs,
                                             self.norms):
            x = checkpoint(add_features, x, use_reentrant=False)
            x = checkpoint(space, x, adj_z, use_reentrant=False)
            x = checkpoint(norm, x, use_reentrant=False)
            processed.append(x)

        # One temporal pass over the concatenated outputs of all blocks.
        out = checkpoint(self.time_nn, torch.cat(processed, dim=-1),
                         use_reentrant=False)

        return checkpoint(self.readout, out, use_reentrant=False)
            
        # x_out = self.decoder(x_emb)  # linear decoder: z=[b n f] -> x_out=[b n t⋅f]
        # x_horizon = self.rearrange(x_out)
        # return x_horizon
        
        

In [8]:
# Two spatial blocks (32 extra channels each) and a 10-layer temporal chain
# that outputs the membrane potential.
m = TSNStacked(
    input_size, n_nodes, horizon,
    output_type='membrane_potential',
    learnable_feature_size=32,
    number_of_blocks=2,
    number_of_temporal_steps=10,
)

In [10]:
m

TSNStacked(
  (encoder): Linear(in_features=65, out_features=256, bias=True)
  (node_embeddings): NodeEmbedding(n_nodes=2048, embedding_size=256)
  (dropout): Dropout(p=0.3, inplace=False)
  (source_embeddings): NodeEmbedding(n_nodes=2048, embedding_size=256)
  (target_embeddings): NodeEmbedding(n_nodes=2048, embedding_size=256)
  (learnable): ModuleList(
    (0-1): 2 x LearnableWeight()
  )
  (temporal): ModuleList()
  (dense_sconvs): ModuleList(
    (0): DenseGraphConvOrderK(
      (mlp): Conv2d(864, 288, kernel_size=(1, 1), stride=(1, 1))
    )
    (1): DenseGraphConvOrderK(
      (mlp): Conv2d(960, 320, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (norms): ModuleList(
    (0): Norm(batch, 288)
    (1): Norm(batch, 320)
  )
  (time_nn): SynapticChain(
    (chain): Sequential(
      (0): Synaptic()
      (1): Linear(in_features=608, out_features=608, bias=True)
      (2): Synaptic()
      (3): Linear(in_features=608, out_features=608, bias=True)
      (4): Synaptic()
      (5): Lin

In [9]:
res = m(*batch.input)

In [6]:
res.size()

torch.Size([1, 40, 2048, 65])

In [6]:
# NOTE(review): the recorded output of this cell is
# `EinopsError: Shape mismatch, 1 != 2048`. The cause is not visible from the
# notebook; unlike the earlier RMSE cell that succeeded on two target slices,
# this one does not pass `normalization_range` — possibly related, to be
# confirmed against the WeightedRMSE implementation.
rmse = WeightedRMSE(
    train.data_wrapper.node_weights,
    variables=[
        # WeatherVariable('z', 500),
        WeatherVariable('t', 850)
        ],
    min_max=min_max,
    denormalize=True
    ).to(batch.target.y.device)
# Compare target step 0 with target step 23.
rmse(batch.target.y[:, 0:1], batch.target.y[:, 23:24])

EinopsError: Shape mismatch, 1 != 2048

In [15]:
from models.SynapticAttention import SynapticAttention
# Build the attention model with the notebook-wide dimensions.
model = SynapticAttention(
    input_size=input_size,
    n_nodes=n_nodes,
    hidden_size=hidden_size,
    horizon=horizon,
)

# Move model and batch to the selected device, then run one forward pass.
model.to(device)
batch.to(device)
res = model(*batch.input)

In [37]:
# Put the metric on the same device as the model output before scoring.
rmse = rmse.to(res.device)
# NOTE(review): this scores the last 40 *input* steps against the model output
# `res`. The other RMSE cells in this notebook pass a `batch.target.y` slice as
# the second argument — confirm whether the target was intended here.
rmse(batch.input.x[:, -40:], res)

tensor(4408.3960, device='cuda:0', grad_fn=<SqueezeBackward0>)

In [None]:
import torch.nn as nn
from tsl.nn.layers import NodeEmbedding, DiffConv, GATConv
from einops import rearrange
import snntorch as snn
from snntorch import functional as SF
# Fresh encoder / node-embedding pair for the experiments below.
encoder = nn.Linear(input_size, hidden_size)
embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)

# Encode the input window, then add the per-node embeddings.
encoded = encoder(batch.input.x)
emb = encoded + embeddings()

In [None]:
torch.zeros(1, emb.size(1), 1, 1).size()

In [None]:
res = GATConv(in_channels=hidden_size, out_channels=32)(emb[:, 0, :, :], batch.input.edge_index, batch.input.edge_weight)

In [None]:
res[0].size()

In [None]:
emb.size()

In [None]:
b, t, n, f = emb.size()

In [None]:
emb[:, 0, :, :].size()

In [None]:
# from typing import Optional, Tuple, Union

# import torch
# import torch.nn.functional as F
# from torch import Tensor
# from torch.nn import Parameter
# from torch_geometric.nn.conv import MessagePassing
# from torch_geometric.nn.dense.linear import Linear
# from torch_geometric.nn.inits import glorot, zeros
# from torch_geometric.typing import Adj, OptPairTensor, OptTensor
# from torch_geometric.utils import add_self_loops, remove_self_loops
# from torch_sparse import SparseTensor, set_diag

# from tsl.nn.functional import sparse_softmax

# import snntorch as snn


# class GSATConv(MessagePassing):
#     def __init__(
#         self,
#         in_channels: Union[int, Tuple[int, int]],
#         out_channels: int,
#         heads: int = 1,
#         concat: bool = True,
#         dim: int = -2,
#         negative_slope: float = 0.2,
#         dropout: float = 0.0,
#         add_self_loops: bool = True,
#         edge_dim: Optional[int] = None,
#         fill_value: Union[float, Tensor, str] = 'mean',
#         bias: bool = True,
#         **kwargs,
#     ):
#         kwargs.setdefault('aggr', 'add')
#         super().__init__(node_dim=dim, **kwargs)

#         self.in_channels = in_channels
#         self.out_channels = out_channels
#         self.heads = heads
#         self.concat = concat
#         self.negative_slope = negative_slope
#         self.dropout = dropout
#         self.add_self_loops = add_self_loops
#         self.edge_dim = edge_dim
#         self.fill_value = fill_value

#         if self.concat:
#             self.head_channels = self.out_channels // self.heads
#             assert self.head_channels * self.heads == self.out_channels, \
#                 "`out_channels` must be divisible by `heads`."
#         else:
#             self.head_channels = self.out_channels

#         # In case we are operating in bipartite graphs, we apply separate
#         # transformations 'lin_src' and 'lin_dst' to source and target nodes:
#         # if isinstance(in_channels, int):
#         self.lin_src = snn.Synaptic(
#             alpha=0.9,
#             beta=0.8,
#             learn_alpha=True,
#             learn_beta=True,
#             learn_threshold=True,
#             # init_hidden=True,
#             # output=True
#         )
#         # Linear(in_channels,
#         #                         heads * self.head_channels,
#         #                         bias=False,
#         #                         weight_initializer='glorot')
#         self.lin_dst = self.lin_src
#         # else:
#         #     self.lin_src = Linear(in_channels[0],
#         #                           heads * self.head_channels,
#         #                           False,
#         #                           weight_initializer='glorot')
#         #     self.lin_dst = Linear(in_channels[1],
#         #                           heads * self.head_channels,
#         #                           False,
#         #                           weight_initializer='glorot')

#         # The learnable parameters to compute attention coefficients:
#         self.att_src = Parameter(torch.Tensor(1, heads, self.head_channels))
#         self.att_dst = Parameter(torch.Tensor(1, heads, self.head_channels))

#         if edge_dim is not None:
#             self.edge_synaptic = snn.Synaptic(
#                 alpha=0.9,
#                 beta=0.8,
#                 learn_alpha=True,
#                 learn_beta=True,
#                 learn_threshold=True,
#                 # init_hidden=True,
#                 # output=True,
#             )
#             self.lin_edge = Linear(edge_dim,
#                                    heads * self.head_channels,
#                                    bias=False,
#                                    weight_initializer='glorot')
#             self.att_edge = Parameter(
#                 torch.Tensor(1, heads, self.head_channels))
#         else:
#             self.lin_edge = None
#             self.register_parameter('att_edge', None)

#         if bias and concat:
#             self.bias = Parameter(torch.Tensor(heads * self.head_channels))
#         elif bias and not concat:
#             self.bias = Parameter(torch.Tensor(out_channels))
#         else:
#             self.register_parameter('bias', None)

#         self.reset_parameters()

#     def reset_parameters(self):
#         # self.lin_src.reset_parameters()
#         # self.lin_dst.reset_parameters()
#         if self.lin_edge is not None:
#             self.lin_edge.reset_parameters()
#         glorot(self.att_src)
#         glorot(self.att_dst)
#         glorot(self.att_edge)
#         zeros(self.bias)

#     def forward(self,
#                 x: Union[Tensor, OptPairTensor],
#                 edge_index: Adj,
#                 edge_attr: OptTensor = None,
#                 need_weights: bool = False):
#         """"""
#         node_dim = self.node_dim
#         self.node_dim = (node_dim + x.dim()) if node_dim < 0 else node_dim
#         b, t, n, c = x.size()

#         N, H, C = n, self.heads, self.head_channels

#         syn, mem = self.lin_src.init_synaptic()
#         if self.edge_synaptic is not None:
#             syn_e, membrane_e = self.edge_synaptic.init_synaptic()
#         # syn, membrane_pot = synaptic.init_synaptic()
#         # We first transform the input node features. If a tuple is passed, we
#         # transform source and target node features via separate weights:
#         # if isinstance(x, Tensor):
#         # print("x size", x.size())
#         for timestep in range(t):
#             data_at_time = x[:, timestep, : , :]
#             spike, syn, mem = self.lin_src(data_at_time, syn, mem)
#             x_src = x_dst = mem.view(*data_at_time.shape[:-1], H, C)
#             # else:  # Tuple of source and target node features:
#             #     x_src, x_dst = x
#             #     x_src = self.lin_src(x_src).view(*x_src.shape[:-1], H, C)
#             #     if x_dst is not None:
#             #         x_dst = self.lin_dst(x_dst).view(*x_dst.shape[:-1], H, C)

#             x_node_features = (x_src, x_dst)

#             # Next, we compute node-level attention coefficients, both for source
#             # and target nodes (if present):
#             alpha_src = (x_src * self.att_src).sum(dim=-1)
#             alpha_dst = None if x_dst is None else (x_dst * self.att_dst).sum(-1)
#             alpha = (alpha_src, alpha_dst)

#             if self.add_self_loops:
#                 if isinstance(edge_index, Tensor):
#                     edge_index, edge_attr = remove_self_loops(
#                         edge_index, edge_attr)
#                     edge_index, edge_attr = add_self_loops(
#                         edge_index,
#                         edge_attr,
#                         fill_value=self.fill_value,
#                         num_nodes=N)
#                 elif isinstance(edge_index, SparseTensor):
#                     if self.edge_dim is None:
#                         edge_index = set_diag(edge_index)
#                     else:
#                         raise NotImplementedError(
#                             "The usage of 'edge_attr' and 'add_self_loops' "
#                             "simultaneously is currently not yet supported for "
#                             "'edge_index' in a 'SparseTensor' form")

#             # edge_updater_type: (alpha: OptPairTensor, edge_attr: OptTensor)
#             alpha, syn_e, membrane_e = self.edge_updater(edge_index, alpha=alpha, edge_attr=edge_attr, syn_e=syn_e, membrane_e=membrane_e)

#             # propagate_type: (x: OptPairTensor, alpha: Tensor)
#             out = self.propagate(edge_index, x=x_node_features, alpha=alpha, size=(N, N))

#             if self.concat:
#                 out = out.view(*out.shape[:-2], self.out_channels)
#             else:
#                 out = out.mean(dim=-2)

#             if self.bias is not None:
#                 out += self.bias

#             if need_weights:
#                 # alpha rearrange: [... e ... h] -> [e ... h]
#                 alpha = torch.movedim(alpha, self.node_dim, 0)
#                 if isinstance(edge_index, Tensor):
#                     alpha = (edge_index, alpha)
#                 elif isinstance(edge_index, SparseTensor):
#                     alpha = edge_index.set_value(alpha, layout='coo')
#             else:
#                 alpha = None

#             self.node_dim = node_dim

#         return out, alpha

#     def edge_update(self, alpha_j: Tensor, alpha_i: OptTensor,
#                     edge_attr: OptTensor, index: Tensor, ptr: OptTensor,
#                     size_i: Optional[int], membrane_e=None, syn_e=None) -> Tensor:
#         """"""
#         # Given edge-level attention coefficients for source and target nodes,
#         # we simply need to sum them up to "emulate" concatenation:
#         alpha = alpha_j if alpha_i is None else alpha_j + alpha_i

#         if edge_attr is not None:
#             if edge_attr.dim() == 1:
#                 edge_attr = edge_attr.view(-1, 1)
#             assert self.lin_edge is not None
#             # edge_attr = self.lin_edge(edge_attr)
#             spike, syn, mem = self.edge_synaptic(edge_attr, syn_e, membrane_e)
#             edge_attr = self.lin_edge(mem)
#             edge_attr = edge_attr.view(-1, self.heads, self.head_channels)
#             alpha_edge = (edge_attr * self.att_edge).sum(dim=-1)
#             shape = [1] * (alpha.ndim - 1) + [self.heads]
#             shape[self.node_dim] = alpha_edge.size(0)
#             alpha = alpha + alpha_edge.view(shape)

#         alpha = F.leaky_relu(alpha, self.negative_slope)
#         alpha = sparse_softmax(alpha,
#                                index,
#                                num_nodes=size_i,
#                                ptr=ptr,
#                                dim=self.node_dim)
#         alpha = F.dropout(alpha, p=self.dropout, training=self.training)
#         return alpha, syn, mem

#     def message(self, x_j: Tensor, alpha: Tensor) -> Tensor:
#         """"""
#         return alpha.unsqueeze(-1) * x_j

#     def __repr__(self) -> str:
#         return (f'{self.__class__.__name__}({self.in_channels}, '
#                 f'{self.out_channels}, heads={self.heads})')


In [None]:
# Width-preserving diffusion convolution (k=2) with a ReLU activation.
spatial = DiffConv(
    in_channels=hidden_size,
    out_channels=hidden_size,
    k=2,
    activation='relu',
)
emn_s = spatial(emb, batch.input.edge_index, batch.input.edge_weight)

In [None]:
stacked = torch.cat((emb, emn_s), dim=-1)
stacked.size()

In [None]:
from models.layers.GSATConv import GSATConv
# The input is twice `hidden_size` wide because `stacked` concatenates two
# `hidden_size`-wide tensors.
res = GSATConv(
    in_channels=2 * hidden_size,
    out_channels=hidden_size,
    dim=1,
    edge_dim=1,
)(stacked, batch.input.edge_index, batch.input.edge_weight)

In [None]:
res[0].size()

In [None]:
# `Parameter` is never imported in this notebook (only inside commented-out
# code), so reference it through the already imported `nn` namespace.
# The wrapped tensor is uninitialised memory; only its shape is meaningful.
nn.Parameter(torch.Tensor(1, 4, 64))

In [None]:
beta = 0.9
rlif = snn.RLeaky(beta=beta, linear_features=hidden_size)
spike, membrane_pot = rlif.init_rleaky()

In [None]:
from snntorch import surrogate
spike_grad=surrogate.atan(alpha=2.0)
thresh=1
l = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh, output=True)

In [None]:
torch.nn.Parameter(data=torch.tensor(0.9), requires_grad=True)

In [None]:
def get_multiplier(i):
    """Return ``2 * i``, clamped from below at 1."""
    doubled = i * 2
    # Alternative that was considered: i + 1
    return 1 if 1 > doubled else doubled

In [None]:
for i in range(3):
    print(f"size {emb.size(-1) * get_multiplier(i)}")

In [None]:
def _make_synaptic(n_features, alpha_init=0.9, beta_init=0.8, output=False):
    """Build a hidden-state ``snn.Synaptic`` layer with per-feature decay rates.

    The previous version passed ``torch.Tensor(n_features)`` as ``alpha`` and
    ``beta``; that constructor returns *uninitialised* memory, so the decay
    rates were arbitrary values. They are now filled with explicit constants.

    Args:
        n_features: Length of the per-feature ``alpha`` / ``beta`` vectors.
        alpha_init: Initial value for every entry of ``alpha``.
        beta_init: Initial value for every entry of ``beta``.
        output: Forwarded to ``snn.Synaptic``; the third layer below sets it to
            ``True`` so that its call can be unpacked into three values.
    """
    return snn.Synaptic(
        alpha=torch.full((n_features,), alpha_init),
        beta=torch.full((n_features,), beta_init),
        learn_alpha=True,
        learn_beta=True,
        learn_threshold=True,
        init_hidden=True,
        output=output,
    )


# Three identically configured layers, each as wide as the embedding;
# only the last one is created with output=True.
synaptic = _make_synaptic(emb.size(-1))
synaptic2 = _make_synaptic(emb.size(-1))
synaptic3 = _make_synaptic(emb.size(-1), output=True)

In [None]:
synaptic

In [None]:
linear = torch.nn.Linear(in_features=emb.size(-1), out_features=emb.size(-1))
# linear2 = torch.nn.Linear(in_features=emb.size(-1) * 2, out_features=emb.size(-1) * 2)
linear2 = torch.nn.Linear(in_features=emb.size(-1), out_features=emb.size(-1))

In [None]:
from models.layers.SynapticChain import SynapticChain
chain = SynapticChain(hidden_size=hidden_size, return_last=True)
spike, mem_pot, syn_cur = chain(emb)

In [None]:
assert not torch.isnan(syn_cur).any()

In [None]:
# Collect the last layer's spikes and membrane potentials for every time step.
spikes, mem_pots = [], []

for timestep in range(t):
    # synaptic -> linear -> synaptic2 -> linear2 -> synaptic3
    spike = synaptic(emb[:, timestep])
    l = linear(spike)
    spike2 = synaptic2(l)
    l2 = linear2(spike2)

    # Only the final layer's call is unpacked into (spike, syn, mem).
    spike3, syn3, membrane_pot3 = synaptic3(l2)

    spikes += [spike3]
    mem_pots += [membrane_pot3]
    

In [None]:
855900

In [None]:
spike3.size()

In [None]:
torch.cat((spike, emb[:, 0,:, :]), dim=2).size()

In [None]:
# `syn` is never assigned in this notebook (the cells that produced it are
# commented out), so inspect the triple returned by the last synaptic layer
# in the time loop instead.
spike3.size(), syn3.size(), membrane_pot3.size()

In [None]:
space = DiffConv(in_channels=hidden_size,
                                 out_channels=hidden_size,
                                 k=2)

In [None]:
stacked_spikes = torch.stack(spikes, 1)
post_space = space(stacked_spikes[:,-1,:,:], batch.edge_index, batch.edge_weight)

In [None]:
post_space

In [None]:
# Size the decoder from the tensor it consumes: `post_space` comes from a
# DiffConv built with out_channels=hidden_size, so a hard-coded width of 32
# only matches when hidden_size happens to be 32.
decoder = nn.Linear(post_space.size(-1), input_size * horizon)
post_decoder = decoder(post_space)

In [None]:
post_decoder.size()

In [None]:
from einops.layers.torch import Rearrange
# Split the flattened (t f) axis into separate time and feature axes and move
# time in front of the node axis.
# NOTE: this rebinds the global name `rearrange`, which an earlier cell
# imported from einops as a function (`from einops import rearrange`).
rearrange = Rearrange('b n (t f) -> b t n f', t=horizon)
post_rearange = rearrange(post_decoder)
# Display the resulting shape.
post_rearange.size()

In [None]:
post_rearange

In [None]:
torch.stack(spikes, 1).size(), torch.stack(spikes, 1).reshape((b,t,n,-1)).size()

In [None]:
spike.reshape((b,n,-1)).size()

In [None]:
snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)
                          

In [None]:
emb.size()

In [None]:
emb[:, 0, :, :].size()

In [None]:
spk, mem = rlif(emb[:, 0, :, :], spike, membrane_pot)

In [None]:
spk.size(), mem.size()

In [None]:
from models.TemporalSpikeGraphConvNet import TemporalSpikeGraphConvNet
from models.TemporalSynapticGraphConvNet import TemporalSynapticGraphConvNet

# Three-block synaptic graph model, five times the notebook's hidden width,
# reading out the membrane potential.
model = TemporalSynapticGraphConvNet(
    input_size=input_size,
    n_nodes=n_nodes,
    horizon=horizon,
    hidden_size=5 * hidden_size,
    output_type="membrane_potential",
    number_of_blocks=3,
)
model

In [None]:
device

In [None]:
model.to(device)
batch.to(device)
res = model(*batch.input)

In [None]:
res.size()

In [None]:
res