In [None]:
# Settings for autoreloading

%load_ext autoreload
%autoreload 2

https://github.com/LMissher/STGNN/blob/main/model.py

In [1]:
import torch

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

  from .autonotebook import tqdm as notebook_tqdm


In [14]:
print(DEVICE)

cuda


In [92]:
import torch
import torch.nn as nn
from typing import Tuple

class S_GNN(nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        # Module to obtain the latent representation of the input.
        # TODO: check if the latent representation is ok
        self.latent_encoder = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.Linear(input_dim, input_dim)
        )
        # Linear layer to model the spatial feature extraction.
        self.linear = nn.Linear(input_dim, input_dim)
    
    def forward(self, x: torch.Tensor):
        # TODO: Check if the input itself is an Adjacency Matrix, otherwise one must be provided as A
        # Get the latent representation of the input.
        p = self.latent_encoder(x)
        print(p.device)
        with torch.no_grad():
            # Get refined adjacency matrix: A_hat = A + I
            # TODO: A_hat should probably be provided by the model since it is just a refined adjacency matrix, unless we pass the adjacency matrix as an input
            A_hat = p + torch.eye(p.shape[0], p.shape[1], device=x.device)
            # Apply score function: Score(p1, p2) = p1^T p2.
            score = p.transpose(1, 0) @ p
            # Pair-wise relation between any road node.
            R = torch.relu(score).exp() / torch.relu(score).exp().sum()
            # Get the sparsified relation matrix
            R_hat = (A_hat > 0).float() * R + torch.eye(p.shape[0], p.shape[1], device=x.device)
            # Get refined degree matrix for R_hat
            D_hat = R_hat.sum(-1)
            D_hat = torch.diag_embed(D_hat)

            A = (D_hat ** -.5) @ R_hat @ (D_hat ** -.5)
        A = A.to(x.device)
        return torch.relu(self.linear(A @ x))

class GRU(nn.Module):
    def __init__(self, input_size: Tuple[int, int]) -> None:
        super().__init__()
        self.input_size = input_size
        n_nodes, n_features = self.input_size
        # Define a GRUCell layer for each node.
        gru_list = [nn.GRUCell(n_features, n_features) for _ in range(n_nodes)]
        self.gru_layers = nn.ModuleList(gru_list)
        
    def forward(self, x: torch.Tensor, x_h: torch.Tensor):
        # Get number of nodes.
        n_nodes, _ = self.input_size
        # Apply GRU layer on each node features.
        outs = [self.gru_layers[i](x[i,:], x_h[i,:]) for i in range(n_nodes)]
        # Stack the results by row.
        return torch.stack(outs, 0)#.to(x.device)

tensor([[12.0000,  3.3000],
        [ 5.0000,  3.0000]])

In [93]:
s = S_GNN(2).to(DEVICE)
a = torch.tensor([[12., 3.3],
                  [5., 3.]], device=DEVICE)
print(a)
d = s(a)

print(d)

tensor([[12.0000,  3.3000],
        [ 5.0000,  3.0000]], device='cuda:0')
cuda:0
tensor([[inf, nan],
        [inf, nan]], device='cuda:0', grad_fn=<ReluBackward0>)


In [94]:
s = GRU((2, 2)).to(DEVICE)
a = torch.tensor([[12., 3.3],
                  [5., 3.]], device=DEVICE)
b = torch.tensor([[7., 1.3],
                  [0., 3.]], device=DEVICE)
print(a)
d = s(a, b)

print(d)

tensor([[12.0000,  3.3000],
        [ 5.0000,  3.0000]], device='cuda:0')
tensor([[ 5.8793,  1.2576],
        [-0.4574,  1.6593]], device='cuda:0', grad_fn=<StackBackward0>)
