In [2]:
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATv2Conv

In [4]:
class GraphAttentionNetwork(nn.Module):
    """Graph attention network built from PyTorch Geometric ``GATv2Conv`` layers.

    Maps node features plus graph connectivity to per-node log-probabilities
    over ``output_dim`` classes (``log_softmax`` over dimension 1).

    Architecture::

        dropout(input_dropout_rate)
        -> GATv2Conv(node_dim -> hidden_dim, `heads` heads, concatenated) -> ELU -> dropout
        -> (n_gnn_layers - 1) x [GATv2Conv(hidden_dim * heads -> hidden_dim, `heads` heads)
                                 -> ELU -> dropout]
        -> GATv2Conv(hidden_dim * heads -> output_dim, 1 head)
        -> log_softmax(dim=1)

    Args:
        node_dim (int): Number of input features per node.
        output_dim (int): Number of outputs (classes) per node.
        hidden_dim (int, optional): Hidden units per attention head in each hidden
            layer (default: 50). Because heads are concatenated, the hidden
            representation has ``hidden_dim * heads`` features.
        n_gnn_layers (int, optional): Number of hidden GATv2 layers placed before
            the output layer (default: 1, i.e. the classic two-layer GAT).
            Values below 1 behave like 1.
        heads (int, optional): Number of attention heads in each hidden layer
            (default: 1). The output layer always uses a single head.
        dropout_rate (float, optional): Dropout probability used both for the
            attention coefficients inside every ``GATv2Conv`` and for the hidden
            features after each hidden layer (default: 0).
        input_dropout_rate (float, optional): Dropout probability applied to the
            raw input node features (default: 0.6, the value used in the
            original GAT recipe).
    """

    def __init__(self, node_dim: int, output_dim: int, hidden_dim: int = 50,
                 n_gnn_layers: int = 1, heads: int = 1, dropout_rate: float = 0,
                 input_dropout_rate: float = 0.6):
        super().__init__()
        # First hidden layer; heads are concatenated (GATv2Conv default), so its
        # output width is hidden_dim * heads.
        self.gat1 = GATv2Conv(node_dim, hidden_dim, heads=heads, dropout=dropout_rate)
        # Additional hidden layers requested via n_gnn_layers. For the default
        # n_gnn_layers=1 this list is empty, so parameters, state_dict keys and
        # initialisation order match the original two-layer model exactly.
        self.hidden_convs = nn.ModuleList(
            GATv2Conv(hidden_dim * heads, hidden_dim, heads=heads, dropout=dropout_rate)
            for _ in range(n_gnn_layers - 1)
        )
        # Output layer: a single head so the result has exactly output_dim features.
        self.gat2 = GATv2Conv(hidden_dim * heads, output_dim, heads=1, dropout=dropout_rate)
        self.dropout = dropout_rate
        self.input_dropout = input_dropout_rate

    def forward(self, x, edge_index):
        """Run the network on one graph.

        Args:
            x: Node feature tensor passed to ``GATv2Conv``; presumably shaped
                ``(num_nodes, node_dim)`` — TODO confirm with callers.
            edge_index: Graph connectivity in PyTorch Geometric's edge-index
                format, as accepted by ``GATv2Conv``.

        Returns:
            Tensor of log-probabilities, normalised with ``log_softmax`` over
            dimension 1.
        """
        x = F.dropout(x, p=self.input_dropout, training=self.training)
        x = F.elu(self.gat1(x, edge_index))
        x = F.dropout(x, p=self.dropout, training=self.training)
        for conv in self.hidden_convs:
            x = F.elu(conv(x, edge_index))
            x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.gat2(x, edge_index)
        return F.log_softmax(x, dim=1)