In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import LogSoftmax
from torch.nn import Sequential, Linear, ReLU
from torch_geometric.nn import AGNNConv
from torch_geometric.nn import TAGConv

In [None]:
class SpatialFeatureAggregator(nn.Module):
    """Node-level classifier combining a TAGConv branch with a linear skip branch.

    Takes a graph in PyTorch Geometric ``Data`` format (only ``data.x`` and
    ``data.edge_index`` are read) and returns per-node log-probabilities.

    Architecture:
        * GNN branch: two stacked ``TAGConv`` layers
          (``node_dim -> hidden_dim -> hidden_dim``), each followed by
          dropout and PReLU.
        * Skip branch: ``Linear(node_dim, hidden_dim)`` followed by dropout.
        * Both branches are concatenated (``2 * hidden_dim`` features) and
          passed through PReLU, ``Linear(2*hidden_dim, 2*hidden_dim)``,
          dropout, PReLU, ``Linear(2*hidden_dim, output_dim)`` and finally
          ``LogSoftmax`` over the last dimension.

    Args:
        node_dim (int): Number of input features per node.
        edge_dim (int): Number of edge features. Stored on the module but
            currently unused: the TAGConv layers only receive ``edge_index``.
        output_dim (int): Number of output classes (required). Because the
            output goes through ``LogSoftmax(dim=-1)``, ``output_dim=1``
            would produce a constant 0 for every node.
        hidden_dim (int, optional): Width of each TAGConv layer and of the
            skip-branch linear layer (default: 50).
        n_gnn_layers (int, optional): Stored on the module but currently
            unused: the model always has exactly two TAGConv layers
            (default: 2).
        K (int, optional): Number of hops passed to each ``TAGConv`` layer
            (default: 2).
        dropout_rate (float, optional): Dropout probability applied after
            both TAGConv layers, the skip-branch linear layer and the hidden
            fully-connected layer (default: 0).
    """

    def __init__(self, node_dim: int, edge_dim: int, output_dim: int,
                 hidden_dim: int = 50, n_gnn_layers: int = 2, K: int = 2,
                 dropout_rate: float = 0):
        super().__init__()
        self.node_dim = node_dim
        self.edge_dim = edge_dim  # kept for reference; not used in forward
        self.output_dim = output_dim
        self.hidden_dim = hidden_dim
        # Kept for reference only; the architecture below is fixed at two
        # TAGConv layers. TODO(review): honour this or drop it.
        self.n_gnn_layers = n_gnn_layers
        self.K = K
        self.dropout_rate = dropout_rate

        # NOTE: submodules are registered in the original order so seeded
        # parameter initialisation and state_dict keys stay unchanged.
        self.dropout = nn.Dropout(self.dropout_rate, inplace=False)
        # Despite the name, this is a single PReLU with ONE learnable slope
        # (nn.PReLU() default) that is shared by every activation site in
        # forward. The attribute name is kept for checkpoint compatibility.
        self.relu = nn.PReLU()
        # Despite the name, this is LogSoftmax: forward returns
        # log-probabilities, not probabilities.
        self.softmax = LogSoftmax(dim=-1)

        self.gcn1 = TAGConv(node_dim, hidden_dim, K=K)
        self.gcn2 = TAGConv(hidden_dim, hidden_dim, K=K)
        self.fc1 = nn.Linear(node_dim, hidden_dim)
        self.fc2 = nn.Linear(2 * hidden_dim, 2 * hidden_dim)
        self.fc3 = nn.Linear(2 * hidden_dim, output_dim)

    def forward(self, data):
        """Compute per-node log-probabilities for a PyG graph.

        Args:
            data: A PyTorch Geometric ``Data``-like object; only ``data.x``
                and ``data.edge_index`` are read. Presumably ``data.x`` is
                ``(num_nodes, node_dim)`` — the ``torch.cat(..., dim=1)``
                below requires 2-D features.

        Returns:
            Tensor of log-probabilities with ``output_dim`` entries along the
            last dimension (``LogSoftmax`` over ``dim=-1``). Apply ``.exp()``
            for probabilities, or pair directly with ``NLLLoss``.
        """
        x = data.x
        edge_index = data.edge_index

        # GNN branch: two TAGConv layers, each followed by dropout and PReLU.
        x1 = self.gcn1(x=x, edge_index=edge_index)
        x1 = self.dropout(x1)
        x1 = self.relu(x1)

        x1 = self.gcn2(x=x1, edge_index=edge_index)
        x1 = self.dropout(x1)
        x1 = self.relu(x1)

        # Skip branch: node-wise linear projection of the raw features.
        # No activation here; one is applied after concatenation.
        x2 = self.fc1(x)
        x2 = self.dropout(x2)

        # Concatenate along the feature dimension -> 2 * hidden_dim features.
        # NOTE: x1 was already activated above, so it passes through the
        # shared PReLU a second time here.
        x = torch.cat([x1, x2], dim=1)
        x = self.relu(x)

        # Hidden fully-connected layer.
        x = self.fc2(x)
        x = self.dropout(x)
        x = self.relu(x)

        # Output layer (raw class scores).
        x = self.fc3(x)

        # LogSoftmax over classes -> log-probabilities (not probabilities).
        return self.softmax(x)