In [1]:
# Settings for autoreloading

%load_ext autoreload
%autoreload 2

https://github.com/LMissher/STGNN/blob/main/model.py

In [2]:
import torch

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
print(DEVICE)

cuda


In [87]:
import math
import torch
import torch.nn as nn
from typing import Tuple

class S_GNN(nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        # Module to obtain the latent representation of the input.
        # TODO: check if the latent representation is ok
        self.latent_encoder = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.Linear(64, input_dim)
        )
        # Linear layer to model the spatial feature extraction.
        self.linear = nn.Linear(input_dim, input_dim)
    
    def forward(self, x: torch.Tensor):
        # TODO: Check if the input itself is an Adjacency Matrix, otherwise one must be provided as A
        # Get the latent representation of the input.
        p = self.latent_encoder(x)
        
        # The following are parameters, henc no gradient storing is required.
        with torch.no_grad():
            # Apply score function: Score(p1, p2) = p1^T p2.
            score = p @ p.transpose(-1, -2)
            # Pair-wise relation between any road node.
            # TODO: Are the phi in the 
            R = torch.relu(score).exp() / torch.relu(score).exp().sum()
            # TODO: A_hat should probably be provided by the model since it is just a refined adjacency matrix, unless we pass the adjacency matrix as an input
            #A_hat = p + torch.eye(p.shape[0], p.shape[1], device=x.device)
            # Get refined adjacency matrix: A_hat = A + I
            A_hat = torch.rand(x.shape[0], x.shape[0], device=DEVICE) #R # TODO: CHANGE
            # Get the sparsified relation matrix
            R_hat = (A_hat > 0).float() * R + torch.eye(R.shape[0], R.shape[1], device=x.device) #torch.eye(p.shape[0], p.shape[1], device=x.device)
            # Get refined degree matrix for R_hat
            D_hat = R_hat.sum(-1)
            D_hat = torch.diag_embed(D_hat)
            # TODO: handle infinities and nones

            A = (D_hat ** -.5) @ R_hat @ (D_hat ** -.5)
            
            #print(A, x)
        return torch.relu(self.linear(A @ x))

class GRU(nn.Module):
    def __init__(self, input_size: Tuple[int, int], hidden_size: int = 64) -> None:
        super().__init__()
        self.input_size = input_size
        n_nodes, n_features = self.input_size
        # Define a GRUCell layer for each node.
        gru_list = [nn.GRUCell(n_features, hidden_size) for _ in range(n_nodes)]
        self.gru_layers = nn.ModuleList(gru_list)
        
    def forward(self, x: torch.Tensor, x_h: torch.Tensor):
        # Get number of nodes.
        _, n_nodes, _ = x.shape
        # Apply GRU layer on each node features.
        outs = [self.gru_layers[i](x[:, i,:], x_h[:, i,:]) for i in range(n_nodes)]
        # Stack the results by row.
        #print('o', torch.stack(outs, 0))
        return torch.stack(outs, 1)#.to(x.device)
    
class Transformer(nn.Module):
    def __init__(self, n_features: int, n_heads: int = 4,
                 hidden_dimension: int = 64) -> None:
        super().__init__()
        #_, n_features = input_size
        # self.queries_linear = nn.Linear(n_features, n_features)
        # self.keys_linear = nn.Linear(n_features, n_features)
        # self.values_linear = nn.Linear(n_features, n_features)
        self.multi_head_attention = nn.MultiheadAttention(n_features, n_heads)
        self.normalization = nn.LayerNorm(n_features)
        self.normalization_out = nn.LayerNorm(n_features)
        self.feed_forward = nn.Sequential(
            nn.Linear(n_features, hidden_dimension),
            nn.ReLU(),
            nn.Linear(hidden_dimension, n_features)
        )
        
    #def _attention(Q, K, V):
    #    d = K.shape[-1]
    #    return torch.softmax(((Q @ K.transpose(1, 0)) / d ** .5) @ V)
        
    def forward(self, x: torch.Tensor):
        # Get queries, keys and values.
        #Q = x.clone() #self.queries_linear(x)
        #K = x.clone() #self.keys_linear(x)
        #V = x.clone() #self.values_linear(x)
        
        # Multi head attention mechanism.
        out, _ = self.multi_head_attention(x, x, x)
        
        # Apply residual connection and batch normalization.
        out += x
        norm = self.normalization(out)
        
        # Apply feed forward module.
        out = self.feed_forward(norm)
        
        # Apply residual connection and batch normalization.
        out += norm
        return self.normalization_out(out)
    
class PositionalEncoding(nn.Module):

    def __init__(self, n_features: int, max_len: int = 5000) -> None:
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, n_features, 2) * (-math.log(10000.0) / n_features))
        pe = torch.zeros(max_len, 1, n_features)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.pe[:x.size(0)]

In [121]:
class SpatioTemporalGNN(nn.Module):
    def __init__(self, n_nodes, n_features, len_timeseries, hidden_dimension = 64, output_graphs = 4):
        super().__init__()
        # self.start_emb = nn.Linear(infea, outfea)
        # self.end_emb = nn.Linear(outfea, infea)
        self.s_gnns = nn.ModuleList([S_GNN((n_features)) for _ in range(len_timeseries)])
        self.hidden_s_gnns = nn.ModuleList([S_GNN((hidden_dimension)) for _ in range(len_timeseries -1)])
        self.grus = nn.ModuleList([GRU((n_nodes, n_features), hidden_dimension) for _ in range(len_timeseries)])
        self.positional_encoding = PositionalEncoding(hidden_dimension)
        self.transformers = nn.ModuleList([Transformer(hidden_dimension, hidden_dimension=hidden_dimension) for _ in range(len_timeseries)])
        
        self.prediction_layer = nn.Sequential(
            nn.Linear(hidden_dimension, n_features),
            nn.ReLU(),
            nn.Linear(n_features, n_features)
            #nn.Upsample(size=(1, output_graphs, n_nodes, n_features))
        )

        self.len_timeseries = len_timeseries
        self.hidden_dimension = hidden_dimension
        

    def forward(self, x):
        # x = x.unsqueeze(-1)
        # x = self.start_emb(x)
        sgnn_outs = []
        for i in range(self.len_timeseries):
            x_ = self.s_gnns[i](x[:, i])
            if i > 0:
                hidden_state = sgnn_outs[i-1] 
            else:
                batch_size, _, n_nodes, _ = x.shape
                hidden_state = torch.zeros(
                    (batch_size, n_nodes, self.hidden_dimension),
                    device=DEVICE)
            x_ = self.grus[i](x_, hidden_state)
            sgnn_outs.append(x_)
            if i < self.len_timeseries:
                self.hidden_s_gnns[i-1](x_)
                
        #x = self.positional_encoding(x)
        
        # TODO: stack row-wise and pass to the transformer
        out = torch.stack(sgnn_outs, 1)
        out = self.positional_encoding(out)
        # TODO: The weights of the transformer seem to be shared, pass just subsets of nodes to a single layer.
        transformer_outs = []
        for i in range(self.len_timeseries):
            transformer_outs.append(self.transformers[i](out[:, i]))
        out = torch.stack(transformer_outs, 1)
        
        return self.prediction_layer(out)

In [122]:
s = SpatioTemporalGNN(2, 4, 2).to(DEVICE)
a = torch.tensor([
    [[[12., -1, 4.4, 5.5],
      [5., 3., 4, -2]],
     [[12., -1, 4.4, 5.5],
      [5., 3., 4, -2]]]], device=DEVICE)
print(a.shape)
d = s(a)

print(d)

torch.Size([1, 2, 2, 4])
tensor([[[[nan, nan, nan, nan],
          [nan, nan, nan, nan]],

         [[nan, nan, nan, nan],
          [nan, nan, nan, nan]]]], device='cuda:0', grad_fn=<AddBackward0>)


In [36]:
s = S_GNN(4).to(DEVICE)
a = torch.tensor([[[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]],
                  [[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]]
                  ], device=DEVICE)
#print(a)
d = s(a)

print(d)

tensor([[[-0.9849,  2.0399,  1.5521,  0.3263],
         [-1.3656,  3.0266, -0.5121,  1.8576]],

        [[-0.9849,  2.0399,  1.5521,  0.3263],
         [-1.3656,  3.0266, -0.5121,  1.8576]]], device='cuda:0',
       grad_fn=<ViewBackward0>)
tensor([[[inf, inf],
         [inf, inf]],

        [[inf, inf],
         [inf, inf]]], device='cuda:0') tensor([[[12.0000, -1.0000,  4.4000,  5.5000],
         [ 5.0000,  3.0000,  4.0000, -2.0000]],

        [[12.0000, -1.0000,  4.4000,  5.5000],
         [ 5.0000,  3.0000,  4.0000, -2.0000]]], device='cuda:0')
tensor([[[nan, nan, nan, nan],
         [nan, nan, nan, nan]],

        [[nan, nan, nan, nan],
         [nan, nan, nan, nan]]], device='cuda:0', grad_fn=<ReluBackward0>)


In [28]:
s = GRU((2, 4)).to(DEVICE)
a = torch.tensor([[[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]],
                  [[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]]], device=DEVICE)
b = torch.randint(-2, 4, (2, 2, 64), dtype=torch.float32, device=DEVICE)  #tensor([[7., 1.3, -2, 4], [0., 3., -5, 1]], device=DEVICE)
print(a)
d = s(a, b)

print(d)

tensor([[[12.0000, -1.0000,  4.4000,  5.5000],
         [ 5.0000,  3.0000,  4.0000, -2.0000]],

        [[12.0000, -1.0000,  4.4000,  5.5000],
         [ 5.0000,  3.0000,  4.0000, -2.0000]]], device='cuda:0')
tensor([[[ 1.3005,  0.8127,  0.7641, -0.4614,  1.8968,  2.1198, -0.3613,
           1.6788,  0.6115,  0.9929, -0.0058,  1.6665,  0.3433, -0.1648,
          -0.1502,  0.6703,  0.1902,  1.1755, -0.0786,  0.5219,  1.8726,
           0.9069,  0.9686, -0.3306, -0.4870, -0.8275,  1.3886,  0.1678,
          -0.3717,  0.6982,  0.2038, -1.8341,  2.6371, -1.3918,  2.6478,
           0.2898,  2.8102, -0.1003,  1.8418,  0.2130, -0.0728, -1.5794,
           0.3305, -0.9606,  1.3234, -0.7743,  1.5515,  0.7421,  0.9246,
           1.2294,  0.3247,  2.3740, -0.4397, -0.1896,  1.5115,  0.4113,
          -0.9112,  0.1225, -1.7437,  0.3361, -0.1516,  0.3274, -0.1993,
          -0.8112],
         [-0.2053, -0.2915,  0.6704, -1.0087,  1.5262,  1.9237,  2.6816,
           1.7600, -0.4572, -0.7719, -1.2

In [37]:
s = Transformer(input_size=(2,4), n_heads = 4, hidden_dimension = 64).to(DEVICE)
a = torch.tensor([[[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]],
                  [[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]]], device=DEVICE)
print(a)
d = s(a)

print(d)

tensor([[[12.0000, -1.0000,  4.4000,  5.5000],
         [ 5.0000,  3.0000,  4.0000, -2.0000]],

        [[12.0000, -1.0000,  4.4000,  5.5000],
         [ 5.0000,  3.0000,  4.0000, -2.0000]]], device='cuda:0')
tensor([[[ 0.8325, -0.8137,  1.1405, -1.1593],
         [ 0.3592,  0.3865,  0.9393, -1.6850]],

        [[ 0.8325, -0.8137,  1.1405, -1.1593],
         [ 0.3592,  0.3865,  0.9393, -1.6850]]], device='cuda:0',
       grad_fn=<NativeLayerNormBackward0>)


In [42]:
s = PositionalEncoding(4).to(DEVICE)
a = torch.tensor([[[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]],
                  [[12., -1, 4.4, 5.5],
                  [5., 3., 4, -2]]
                  ], device=DEVICE)
#print(a)
d = s(a)

print(d)

tensor([[[12.0000,  0.0000,  4.4000,  6.5000],
         [ 5.0000,  4.0000,  4.0000, -1.0000]],

        [[12.8415, -0.4597,  4.4100,  6.4999],
         [ 5.8415,  3.5403,  4.0100, -1.0001]]], device='cuda:0')
