In [1]:
import torch
import pickle

from load_data import get_data
from DataLoader import WeatherDL

# Prefer the first CUDA device when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
device


device(type='cuda', index=0)

In [2]:
# Load the raw dataset and the min/max normalisation statistics.
data = get_data()
# Context manager so the file handle is closed deterministically (the original leaked it).
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
# NOTE(review): the file name suggests test-split statistics are used to normalise train/valid — confirm.
with open('./min_max_test.pkl', 'rb') as min_max_file:
    min_max = pickle.load(min_max_file)

# Loader settings shared by both splits; defined once so train and valid cannot silently diverge.
loader_kwargs = dict(
    temporal_resolution=3,
    min=min_max['min'],
    max=min_max['max'],
    batch_size=5,
    num_workers=6,
    persistent_workers=True,
    prefetch_factor=2,
    multiprocessing_context='fork',
)
train = WeatherDL(data, time=slice('2000', '2002'), **loader_kwargs)
# min, max = train.data_wrapper.getMinMaxValues()
valid = WeatherDL(data, time='2003', **loader_kwargs)

In [3]:
# Dataset dimensions used to size the models below.
n_nodes = train.spatio_temporal_dataset.n_nodes          # number of graph nodes
input_size = train.spatio_temporal_dataset.n_channels    # channels per node
horizon = train.spatio_temporal_dataset.horizon          # number of predicted time steps
hidden_size = 32                                         # width of the hidden representation

In [4]:
# Pull a single batch from the training loader to probe shapes interactively.
batch = next(iter(train.data_loader))

In [5]:
# Inspect the batch structure: input (x, edge_index, edge_weight) and target y.
batch

StaticBatch(
  input=(x=[b=5, t=112, n=2048, f=65], edge_index=[2, e=19328], edge_weight=[e=19328]),
  target=(y=[b=5, t=40, n=2048, f=65]),
  has_mask=False
)

In [6]:
import torch.nn as nn
from tsl.nn.layers import NodeEmbedding, DiffConv
from einops import rearrange
import snntorch as snn
from snntorch import functional as SF


In [7]:
# Per-node linear encoder into the hidden space, plus a learnable embedding for every node.
encoder = nn.Linear(input_size, hidden_size)
embeddings = NodeEmbedding(emb_size=hidden_size, n_nodes=n_nodes)

In [8]:
# Encode raw features into the hidden space and add the learnable per-node embedding.
# (Removed a stray `snn.Leaky` expression that had no effect.)
encoded = encoder(batch.input.x)
emb = encoded + embeddings()

In [9]:
# Shape of the embedded input; the recorded output is (batch, time, nodes, hidden_size).
emb.size()

torch.Size([5, 112, 2048, 32])

In [10]:
# Unpack (batch, time, nodes, features) for the loops below.
b, t, n, f = emb.shape

In [11]:
# Shape of a single time-step slice: (batch, nodes, features).
emb[:, 0, :, :].size()

torch.Size([5, 2048, 32])

In [12]:
# Recurrent LIF layer used for probing.
# NOTE(review): `beta` is later rebound to an nn.Parameter (In [15]); cells that use `beta`
# after that point see the Parameter, so results depend on execution order.
beta = 0.9
rlif = snn.RLeaky(beta=beta, linear_features=hidden_size)
spike, membrane_pot = rlif.init_rleaky()

In [13]:
from snntorch import surrogate
# Arctangent surrogate gradient used in place of the non-differentiable spike during backprop.
spike_grad=surrogate.atan(alpha=2.0)
thresh=1
# Standalone Leaky neuron that keeps its own hidden state (init_hidden=True) and returns (spike, mem).
l = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True,threshold=thresh, output=True)

In [14]:
# Scratch check of how a learnable scalar Parameter displays; the result is not stored.
torch.nn.Parameter(data=torch.tensor(0.9), requires_grad=True)

Parameter containing:
tensor(0.9000, requires_grad=True)

In [15]:
# Synaptic (two-state) LIF neuron with learnable scalar decay rates alpha and beta.
# NOTE(review): `syn` (and `beta`) are rebound by later cells; re-running out of order changes them.
alpha = torch.nn.Parameter(data=torch.tensor(0.9), requires_grad=True)
beta = torch.nn.Parameter(data=torch.tensor(0.9), requires_grad=True)
synaptic = snn.Synaptic(alpha=alpha, beta=beta)
syn, mem = synaptic.init_synaptic()

In [20]:
flat = emb.flatten(start_dim=2)
n_features = emb.size(-1)
# torch.Tensor(n) returns *uninitialised* memory, so the original alpha/beta held arbitrary
# values. Start every feature's decay rates from an explicit constant instead.
rsynaptic = snn.RSynaptic(
    linear_features=n_features,
    alpha=torch.full((n_features,), 0.9),
    beta=torch.full((n_features,), 0.9),
    learn_alpha=True, learn_beta=True, learn_recurrent=True, learn_threshold=True,
)
spike, syn, membrane_pot = rsynaptic.init_rsynaptic()

In [80]:
n_features = emb.size(-1)
# torch.Tensor(n) returns uninitialised memory; use an explicit initial decay rate per feature.
# NOTE(review): emb[:, 0] is (batch=5, nodes, features). A Conv2d given a 3-D tensor appears to treat
# it as unbatched (C, H, W), so conv2d_channels=5 only works because it equals batch_size, and the
# recurrent convolution would mix samples across the batch — confirm this is intended.
rsynaptic_conv = snn.RSynaptic(conv2d_channels=5, kernel_size=3,
                               alpha=torch.full((n_features,), 0.9),
                               beta=torch.full((n_features,), 0.9),
                               learn_alpha=True, learn_beta=True,
                               learn_recurrent=True, learn_threshold=True)
spike, syn, membrane_pot = rsynaptic_conv.init_rsynaptic()

# Single-step probe to check output shapes.
spike, syn, mem_p = rsynaptic_conv(emb[:, 0, :, :], spike, syn, membrane_pot)
spike.size(), emb.size()

(torch.Size([5, 2048, 32]), torch.Size([5, 112, 2048, 32]))

In [81]:
# True only if no neuron spiked (no element equals 1). Comparing against the scalar avoids the
# private `_is_all_true` and a CPU `ones` tensor that would mismatch a CUDA `spike`.
(spike != 1).all()

tensor(False)

In [82]:
# Unroll the recurrent synaptic layer over every time step.
# Start from a fresh state: the probe cell above already advanced `spike`/`syn` by one step
# (but not `membrane_pot`), which would otherwise leak inconsistent state into t=0 here.
spike, syn, membrane_pot = rsynaptic_conv.init_rsynaptic()
spikes = []
mem_pots = []
for timestep in range(t):
    spike, syn, membrane_pot = rsynaptic_conv(emb[:, timestep, :, :], spike, syn, membrane_pot)
    spikes.append(spike)
    mem_pots.append(membrane_pot)
    

In [83]:
# Spikes produced at the last unrolled time step.
spike

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 1., 

In [29]:
# NOTE(review): the recorded output (5, 65536 = 2048*32) is stale — it is from an earlier run on the
# flattened input (execution count 29); re-run to refresh.
spike.size(), syn.size(), membrane_pot.size()

(torch.Size([5, 65536]), torch.Size([5, 65536]), torch.Size([5, 65536]))

In [85]:
# Diffusion graph convolution of order k=2, hidden_size -> hidden_size features.
space = DiffConv(in_channels=hidden_size, out_channels=hidden_size, k=2)

In [97]:
# Stack per-step spikes along dim 1 and feed only the last time step to the graph convolution.
stacked_spikes = torch.stack(spikes, 1)
post_space = space(stacked_spikes[:,-1,:,:], batch.edge_index, batch.edge_weight)

In [99]:
# Output of the graph convolution for the last time step.
post_space

tensor([[[-0.0807, -0.1449, -0.0741,  ...,  0.1569,  0.1054,  0.3627],
         [-0.2199, -0.1943, -0.0111,  ...,  0.3141,  0.1207,  0.4086],
         [-0.1230, -0.1954, -0.2658,  ..., -0.0458,  0.1488,  0.2473],
         ...,
         [-0.1111, -0.1153, -0.2082,  ...,  0.1679,  0.0426,  0.2478],
         [-0.1377, -0.1824, -0.2921,  ..., -0.0328,  0.0171,  0.3081],
         [-0.0920, -0.2026, -0.1924,  ...,  0.0428, -0.0578,  0.2363]],

        [[-0.0778, -0.1913, -0.2179,  ..., -0.0746,  0.3959,  0.2268],
         [-0.1454, -0.2277, -0.1324,  ...,  0.1110,  0.2454,  0.1155],
         [-0.2592, -0.1894, -0.2258,  ...,  0.1127,  0.2355,  0.2972],
         ...,
         [-0.1200, -0.1963, -0.0906,  ..., -0.0205,  0.1414,  0.2800],
         [-0.0259, -0.1451,  0.0784,  ..., -0.0541, -0.0671,  0.2433],
         [-0.0808, -0.0644,  0.0456,  ..., -0.1054,  0.0526,  0.3027]],

        [[-0.0926, -0.1685, -0.2180,  ...,  0.2638,  0.1462,  0.4568],
         [-0.3127, -0.3623, -0.3031,  ...,  0

In [100]:
# Project each node's hidden features to `horizon` future steps of `input_size` channels.
# Uses `hidden_size` rather than a hard-coded 32, so it stays in sync with the encoder/DiffConv width.
decoder = nn.Linear(hidden_size, input_size * horizon)
post_decoder = decoder(post_space)

In [101]:
# Expected (batch, nodes, horizon * input_size).
post_decoder.size()

torch.Size([5, 2048, 2600])

In [102]:
from einops.layers.torch import Rearrange
# Named `to_horizon` (not `rearrange`) so it does not shadow the `einops.rearrange` function imported earlier.
# Splits the flat decoder output into (batch, horizon, nodes, features).
to_horizon = Rearrange('b n (t f) -> b t n f', t=horizon)
post_rearange = to_horizon(post_decoder)
post_rearange.size()

torch.Size([5, 40, 2048, 65])

In [103]:
# Final (batch, horizon, nodes, features) prediction from the manual pipeline.
post_rearange

tensor([[[[ 9.7866e-02,  3.3486e-01,  5.5667e-01,  ..., -3.6614e-01,
            1.8089e-01, -1.0009e-01],
          [ 2.7036e-01,  1.7795e-01,  2.6131e-01,  ..., -2.4860e-01,
            2.1535e-01, -3.6149e-02],
          [ 1.0726e-01,  2.5034e-01,  1.7120e-01,  ..., -2.1300e-01,
            9.9274e-02, -6.7685e-02],
          ...,
          [ 1.9004e-01,  2.9565e-01,  3.2597e-01,  ..., -2.7058e-01,
            1.3818e-01, -8.2204e-02],
          [ 2.0400e-01,  2.4904e-01,  2.9729e-01,  ..., -2.6582e-01,
            1.3087e-01, -9.8462e-02],
          [ 1.9219e-01,  2.1347e-01,  2.1782e-01,  ..., -1.9840e-01,
            1.1164e-01, -8.8364e-02]],

         [[-2.8088e-01,  2.0386e-01,  1.5833e-03,  ..., -2.6159e-01,
           -2.1512e-02, -1.0355e-01],
          [-1.1891e-01,  1.4279e-01,  1.4797e-05,  ..., -2.3179e-01,
           -4.2443e-02, -7.4534e-02],
          [-4.9644e-02,  2.0441e-01, -1.1441e-01,  ..., -2.0422e-01,
           -4.5939e-03, -2.6315e-02],
          ...,
     

In [86]:
# Sanity check: stacking along dim 1 already yields (b, t, n, f), so the reshape does not change the shape.
torch.stack(spikes, 1).size(), torch.stack(spikes, 1).reshape((b,t,n,-1)).size()

(torch.Size([5, 112, 2048, 32]), torch.Size([5, 112, 2048, 32]))

In [150]:
# Shape check of the last-step spikes viewed as (b, n, -1).
spike.reshape((b,n,-1)).size()

torch.Size([5, 2048, 32])

In [None]:
# Scratch construction of a Leaky neuron for display only.
# NOTE(review): if In [15] ran before this cell, `beta` here is the nn.Parameter, not 0.9.
snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, output=True,threshold=thresh)
                          

In [29]:
from snntorch import utils
from snntorch import surrogate
# Arctangent surrogate gradient (redefined here so this section can run on its own).
spike_grad = surrogate.atan(alpha=2.0)

class TemporalSpike(nn.Module):
    """Run a leaky integrate-and-fire neuron step by step over the time axis of a (b, t, n, f) tensor.

    Args:
        hidden_size: feature size of the recurrent ``RLeaky`` layer (currently unused in forward).
        beta: membrane decay rate of the LIF neurons.
        return_last: if True, return only the spikes / membrane potentials of the last time step.
        spike_grad: surrogate gradient used for the non-differentiable spike.
        thresh: firing threshold.
    """

    def __init__(self,
                 hidden_size = 32,
                 beta=0.9,
                 return_last=False,
                 spike_grad=surrogate.atan(alpha=2.0),
                 thresh=1) -> None:
        super(TemporalSpike, self).__init__()

        # NOTE(review): `rlif` and the state below are not used by forward(); they are kept so
        # existing checkpoints (state_dict keys) still load. Consider removing them.
        self.rlif = snn.RLeaky(beta=beta, linear_features=hidden_size)
        self.spike, self.membrane_potential = self.rlif.init_rleaky()
        # init_hidden=True: the neuron stores its membrane potential internally, so forward()
        # resets it at the start of every call.
        self.leaky = snn.Leaky(beta=beta, spike_grad=spike_grad, init_hidden=True, threshold=thresh, output=True)
        self.return_last = return_last

    def forward(self, x):
        """
        Args:
            x (tensor): (b, t, n, f)

        Returns:
            (spikes, membrane_potentials): each (b, t, n, f), or (b, n, f) when ``return_last`` is True.
        """
        # Clear the hidden membrane state left from the previous call; otherwise state (and its
        # autograd graph) leaks across batches.
        utils.reset(self)
        spikes, membrane_pots = [], []
        for timestep in range(x.size(1)):
            spike, membrane_potential = self.leaky(x[:, timestep, :, :])
            spikes.append(spike)
            membrane_pots.append(membrane_potential)

        # Collect-then-stack keeps outputs on x's device and dtype. The previous torch.zeros(x.size())
        # buffers always lived on the CPU, which breaks downstream layers when the model runs on CUDA.
        if self.return_last:
            return spikes[-1], membrane_pots[-1]
        return torch.stack(spikes, dim=1), torch.stack(membrane_pots, dim=1)

In [30]:
# NOTE(review): this cell ran before `temporal` is (re)defined in In [33]; the recorded repr (rlif only,
# no `leaky`) comes from an older TemporalSpike version and is stale.
temporal

TemporalSpike(
  (rlif): RLeaky(
    (recurrent): Linear(in_features=32, out_features=32, bias=True)
  )
)

In [31]:
# Re-check the embedded input shape before running the temporal block.
emb.size()

torch.Size([5, 112, 2048, 32])

In [33]:
# Run the standalone temporal block on the embeddings, keeping only the last time step.
temporal = TemporalSpike(return_last=True)
s, m = temporal(emb)

Leaky()


In [60]:
from einops.layers.torch import Rearrange
# Probe: re-insert a singleton time axis into the (b, n, f) last-step spikes. `r` is a throwaway name.
r = Rearrange('b (t n) f -> b t n f', t=1)
r(s).size()

torch.Size([5, 1, 2048, 32])

In [74]:
from einops.layers.torch import Rearrange

class SNNConvGraphNet(nn.Module):
    """Time-then-space forecaster: a spiking temporal encoder followed by a diffusion graph convolution.

    Pipeline: linear encoder + node embeddings -> TemporalSpike (last step only) ->
    DiffConv over the graph -> linear decoder reshaped to (b, horizon, n, input_size).

    Args:
        input_size: number of input channels per node.
        n_nodes: number of graph nodes (size of the node-embedding table).
        horizon: number of future time steps to predict.
        hidden_size: width of the hidden representation.
        temporal_layers: currently ignored; a single TemporalSpike layer is always built.
        gnn_kernel: diffusion order ``k`` passed to DiffConv.
        use_spike_for_output: feed spikes (True) or membrane potentials (False) to the GNN.
    """

    def __init__(self, input_size: int, n_nodes: int, horizon: int,
                 hidden_size: int = 32,
                 temporal_layers: int = 1,
                 gnn_kernel: int = 2,
                 use_spike_for_output = True) -> None:
        super(SNNConvGraphNet, self).__init__()
        self.use_spike_for_output = use_spike_for_output
        # Submodules are created in pipeline order (also fixes parameter-init RNG order).
        self.encoder = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.node_embeddings = NodeEmbedding(n_nodes=n_nodes, emb_size=hidden_size)
        self.time_nn = TemporalSpike(hidden_size=hidden_size, return_last=True)
        self.space_nn = DiffConv(in_channels=hidden_size, out_channels=hidden_size, k=gnn_kernel)
        self.decoder = nn.Linear(hidden_size, input_size * horizon)
        self.rearrange = Rearrange('b n (t f) -> b t n f', t=horizon)

    def forward(self, x, edge_index, edge_weight):
        """Predict ``horizon`` steps from x of shape (batch, time, nodes, features)."""
        # Encode and tag every node with its learnable identity embedding.
        hidden = self.encoder(x) + self.node_embeddings()
        # Temporal spiking pass: (b, t, n, f) -> last-step (b, n, f) spikes and potentials.
        spikes, potentials = self.time_nn(hidden)
        graph_input = spikes if self.use_spike_for_output else potentials
        # Spatial propagation, then decode to (b, n, t*f) and split into (b, t, n, f).
        z = self.space_nn(graph_input, edge_index, edge_weight)
        return self.rearrange(self.decoder(z))
        
        

In [75]:
# Instantiate the notebook-defined model with the dataset dimensions.
model = SNNConvGraphNet(input_size=input_size, n_nodes=n_nodes, horizon=horizon)

In [76]:
# Forward pass; batch.input presumably unpacks to (x, edge_index, edge_weight) as shown in the batch repr.
res = model(*batch.input)
res.size()

torch.Size([5, 40, 2048, 65])

In [37]:
# Stack the unrolled spikes to (b, t, n, f).
# NOTE(review): rebinds `s`, previously the TemporalSpike output from In [33].
s = torch.stack(spikes, dim=1)
s.size()

torch.Size([5, 112, 2048, 32])

In [49]:
# Spikes at time step 4.
spikes[4]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 1.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 

In [41]:
# Probe: write one (b, n, f) step into a preallocated (b, t, n, f) buffer.
a = torch.zeros(s.size())
a[:, 0, :, :] = spikes[0]

In [43]:
# Contents of the first time step of the buffer.
a[:, 0, :, :]

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 

In [30]:
# Shape of a single time-step slice: (batch, nodes, features).
emb[:, 0, :, :].size()

torch.Size([5, 2048, 32])

In [31]:
# One step of the recurrent LIF from In [12].
# NOTE(review): `spike`/`membrane_pot` here depend on which cells ran last (several cells rebind them).
spk, mem = rlif(emb[:, 0, :, :], spike, membrane_pot)

In [33]:
# Output shapes of the single recurrent LIF step.
spk.size(), mem.size()

(torch.Size([5, 2048, 32]), torch.Size([5, 2048, 32]))

In [106]:
from models.TemporalSpikeGraphConvNet import TemporalSpikeGraphConvNet

# Module-level model from models/ with an 8x wider hidden size; rebinds `model`.
model = TemporalSpikeGraphConvNet(input_size=input_size,
                               n_nodes=n_nodes,
                               horizon=horizon,
                               hidden_size=hidden_size * 8,
                               use_spike_for_output=True
                               )
model

TemporalSpikeGraphConvNet(
  (encoder): Linear(in_features=65, out_features=256, bias=True)
  (node_embeddings): NodeEmbedding(n_nodes=2048, embedding_size=256)
  (time_nn): SynapticSpike()
  (space_nn): DiffConv(256, 256)
  (decoder): Linear(in_features=256, out_features=2600, bias=True)
  (rearrange): Rearrange('b n (t f) -> b t n f', t=40)
)

In [108]:
# Confirm the target device before moving the model.
device

device(type='cuda', index=0)

In [110]:
# Move the model and the batch to the selected device, then run a forward pass.
# NOTE(review): relies on `batch.to` moving tensors in place (its return value is ignored) — confirm.
model.to(device)
batch.to(device)
res = model(*batch.input)

In [111]:
# Prediction of the module-level model.
res

tensor([[[[ 1.0360e+01,  3.1081e+00, -2.2486e-01,  ...,  5.8565e-01,
            4.0614e+00,  1.9673e-01],
          [ 1.1157e+01,  4.6511e+00, -3.0375e+00,  ...,  1.7725e+00,
            3.5030e+00,  9.6556e-01],
          [ 1.3178e+01,  1.5806e+00, -2.0622e+00,  ...,  2.4261e+00,
            3.6040e+00,  1.7883e+00],
          ...,
          [ 1.0598e+01,  1.2583e+00, -5.3719e+00,  ...,  6.9457e-01,
            1.4685e+00,  1.0432e+00],
          [ 1.2235e+01,  7.3006e-01, -2.1544e+00,  ...,  6.5154e-01,
            2.5411e+00, -3.9560e-02],
          [ 1.0556e+01,  2.7652e+00, -2.6616e+00,  ...,  4.5207e-01,
            2.1854e+00,  1.0029e+00]],

         [[-1.4813e+00,  1.8180e+00,  4.5394e-01,  ..., -3.9864e+00,
           -8.3428e+00, -6.9358e-02],
          [-1.2623e+00, -8.1147e-01, -2.4351e-01,  ..., -4.4298e+00,
           -1.0861e+01, -4.4648e-01],
          [-6.2133e-01, -3.0101e-04,  1.8729e-01,  ..., -5.1305e+00,
           -9.1202e+00, -1.8000e+00],
          ...,
     