In [1]:
from torch_geometric.nn import TAGConv
from torch_geometric.nn import AGNNConv
from torch.nn import Sequential, Linear, ReLU
import torch.nn.functional as F
from torch.nn import LogSoftmax
import torch.nn as nn

In [None]:
class TemporalBoy(nn.Module):
    """LSTM encoder followed by two TAGConv graph-convolution layers.

    Node features are first run through a single-layer ``nn.LSTM``. The
    per-step LSTM output is then passed through two ``TAGConv`` layers, and
    the result is converted to log-probabilities with ``LogSoftmax`` over
    the last dimension.

    Args:
        node_dim: Size of each input node feature vector (the LSTM input size).
        edge_dim: Stored on the instance but not used by any layer here.
        output_dim: Number of output channels of the final TAGConv layer.
        hidden_dim: LSTM hidden size, and width of the first TAGConv layer.
        n_gnn_layers: Accepted for API compatibility but currently ignored;
            the model always uses exactly two TAGConv layers.
        K: Number of hops used by each TAGConv layer.
        dropout_rate: Dropout probability applied after the first TAGConv.
    """

    def __init__(self, node_dim, edge_dim, output_dim, hidden_dim=50, n_gnn_layers=2, K=2, dropout_rate=0):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim  # NOTE(review): not consumed by any layer yet
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        self.K = K
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(self.dropout_rate, inplace=False)
        # A single PReLU (one learnable slope by default) is shared by both
        # activation sites in forward(); the attribute name is kept as `relu`
        # so existing state_dicts still load.
        self.relu = nn.PReLU()
        self.softmax = LogSoftmax(dim=-1)

        self.gcn1 = TAGConv(hidden_dim, hidden_dim, K=K)
        self.gcn2 = TAGConv(hidden_dim, output_dim, K=K)

        # Single-layer LSTM. PyTorch applies LSTM dropout only *between*
        # stacked recurrent layers. Passing dropout_rate here was therefore a
        # no-op that raised a UserWarning whenever dropout_rate > 0. Graph-level
        # dropout is handled by self.dropout in forward().
        self.lstm = nn.LSTM(node_dim, hidden_dim)

    def forward(self, x, edge_index):
        """Apply the LSTM encoder and the two graph convolutions.

        Args:
            x: Node feature tensor fed to ``nn.LSTM``; its last dimension must
                equal ``node_dim``. NOTE(review): the LSTM is not
                ``batch_first``, so dim 0 is treated as the sequence axis.
                Presumably ``x`` is ``(time, num_nodes, node_dim)``. A 2-D
                ``(num_nodes, node_dim)`` input would make the LSTM iterate
                over nodes as if they were time steps. Confirm with the caller.
            edge_index: Graph connectivity, passed unchanged to both TAGConv
                layers.

        Returns:
            torch.Tensor: Log-probabilities over ``output_dim`` (last dim).
        """
        # Only the per-step output sequence is used. The final (h_n, c_n)
        # state is discarded (feeding it in is an open idea).
        x, _ = self.lstm(x)
        x = self.relu(x)

        x = self.gcn1(x, edge_index)
        x = self.dropout(x)
        x = self.relu(x)

        x = self.gcn2(x, edge_index)

        # LogSoftmax yields log-probabilities (use .exp() for probabilities).
        return self.softmax(x)