# Recurrent GNN

In [1]:
# Importing necessary modules and libraries
import os.path as osp  # For handling file and directory paths
import torch  # PyTorch library for tensor computations and deep learning
import torch.nn as nn  # Neural network module in PyTorch
import torch.nn.functional as F  # Functional interface for PyTorch (e.g., activation functions)
import torch_geometric.transforms as T  # Transformations for graph data in PyTorch Geometric
import torch_geometric  # PyTorch Geometric library for graph-based deep learning
from torch_geometric.datasets import Planetoid, TUDataset  # Datasets for graph learning tasks
from torch_geometric.loader import DataLoader  # DataLoader for batching graph data
from torch_geometric.nn.inits import uniform  # Initialization utility for PyTorch Geometric
from torch.nn import Parameter as Param  # Parameter class for defining learnable parameters
from torch import Tensor  # Tensor class for type hinting

# Seed PyTorch's global random number generator for reproducibility
torch.manual_seed(42)

# Select the compute device: the Metal Performance Shaders (MPS) backend on
# macOS when PyTorch reports it as available, the CPU otherwise
if not torch.backends.mps.is_available():
    device = torch.device("cpu")  # No MPS support detected
    print("MPS not available, using CPU")
else:
    device = torch.device("mps")  # GPU acceleration through MPS
    print("Using MPS device:", device)

# Importing the MessagePassing class from PyTorch Geometric
# This class is used to define custom message-passing layers for graph neural networks
from torch_geometric.nn.conv import MessagePassing

  Referenced from: <CA14ED34-FA3D-31FE-B4AD-2B2A8446B324> /Users/michaels/anaconda3/envs/icgnn/lib/python3.11/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file)
  Referenced from: <CA14ED34-FA3D-31FE-B4AD-2B2A8446B324> /Users/michaels/anaconda3/envs/icgnn/lib/python3.11/site-packages/libpyg.so
  Reason: tried: '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/System/Volumes/Preboot/Cryptexes/OS/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file), '/Library/Frameworks/Python.framework/Versions/3.11/Python' (no such file)


Using MPS device: mps


In [2]:
# Name of the citation-network dataset to load. Kept in its own variable so that
# `dataset` only ever refers to the loaded dataset object, never to a string.
dataset_name = 'Cora'

# Transformations applied to a graph when it is read from the dataset:
# - T.RandomNodeSplit('train_rest', num_val=500, num_test=500): creates the
#   train/val/test node masks; 500 nodes go to validation, 500 to test, and
#   all remaining nodes are used for training.
# - T.TargetIndegree(): stores the normalized in-degree of each edge's target
#   node in `data.edge_attr`. It is an *edge* attribute (later used as the edge
#   weight during message passing); the node features are left unchanged.
transform = T.Compose([
    T.RandomNodeSplit('train_rest', num_val=500, num_test=500),
    T.TargetIndegree(),
])

# Directory where the dataset is stored or downloaded to: data/Cora
path = osp.join('data', dataset_name)

# Load the Planetoid dataset (Cora, CiteSeer and PubMed are available under this class)
# with the transformations defined above.
dataset = Planetoid(path, dataset_name, transform=transform)

# The dataset holds a single graph; fetch it and move it to the selected device.
data = dataset[0]
data = data.to(device)

### Graph Neural Network Model

The MLP class is used to instantiate the transition and output functions as simple feed-forward networks.

In [3]:
class MLP(nn.Module):
    """
    A simple feed-forward network: a stack of linear layers with a Tanh
    activation after every layer except the last one.

    The layers live in ``self.mlp`` (an ``nn.Sequential``) under the names
    ``lay_0``, ``act_0``, ``lay_1``, ... so they can be inspected by name.
    """
    def __init__(self, input_dim, hid_dims, out_dim):
        """
        Builds the layer stack.

        Args:
            input_dim (int): The number of input features.
            hid_dims (iterable of int): Number of units of each hidden layer, in order.
                Any iterable is accepted (list, tuple, ...); it may be empty, in which
                case the network is a single linear layer.
            out_dim (int): The number of output features.
        """
        super().__init__()

        # Container for the layers; modules are applied in insertion order.
        self.mlp = nn.Sequential()

        # Layer sizes from input to output. Unpacking (instead of list concatenation)
        # lets `hid_dims` be a tuple or any other iterable, not only a list.
        dims = [input_dim, *hid_dims, out_dim]

        # Consecutive pairs (dims[i], dims[i + 1]) define one linear layer each.
        for i in range(len(dims) - 1):
            self.mlp.add_module(f'lay_{i}', nn.Linear(in_features=dims[i], out_features=dims[i + 1]))

            # Non-linearity after every layer but the last, so the output stays unbounded.
            if i + 2 < len(dims):
                self.mlp.add_module(f'act_{i}', nn.Tanh())

    def reset_parameters(self):
        """
        Re-initializes the weights of all linear layers with Xavier normal
        initialization. Biases are left as they are.
        """
        for layer in self.mlp:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_normal_(layer.weight)

    def forward(self, x):
        """
        Applies the network to ``x``.

        Args:
            x (Tensor): Input tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor: Output tensor whose last dimension has size ``out_dim``.
        """
        return self.mlp(x)


The GNNM class puts together the state propagation and the readout of the nodes' states.

In [4]:
class GNNM(MessagePassing):
    """
    Recurrent graph neural network built on PyTorch Geometric's MessagePassing.

    Every node holds a state vector. ``forward`` repeatedly aggregates the states of
    each node's neighbours (optionally scaled by an edge weight), passes the aggregate
    through the ``transition`` MLP to obtain the new states, and stops after
    ``num_layers`` iterations or as soon as every node's state changed by less than
    ``eps``. The ``readout`` MLP then maps the final states to per-node
    log-probabilities.

    NOTE(review): the initial states are all zeros and the input node features
    (``data.x``) are never read by this class, so the output depends only on the
    graph structure and the edge weights. Confirm that this is intended.
    """

    def __init__(self, n_nodes, out_channels, features_dim, hid_dims, num_layers=50, eps=1e-3, aggr='add',
                 bias=True, **kwargs):
        """
        Initializes the GNNM class.

        Args:
            n_nodes (int): The number of nodes in the graph.
            out_channels (int): The number of output channels (e.g., number of classes).
            features_dim (int): The dimensionality of the node state vectors.
            hid_dims (list of int): Hidden layer sizes shared by the transition and readout MLPs.
            num_layers (int, optional): Maximum number of propagation iterations. Default is 50.
            eps (float, optional): Convergence threshold on the per-node state change. Default is 1e-3.
            aggr (str, optional): Aggregation scheme passed to MessagePassing. Default is 'add'.
            bias (bool, optional): Currently unused. Accepted for interface compatibility only;
                the MLPs always use the default bias setting of ``nn.Linear``.
            **kwargs: Additional arguments for the MessagePassing class.
        """
        super().__init__(aggr=aggr, **kwargs)

        # Initial node states: all zeros and frozen (requires_grad=False), so the
        # optimizer never changes them. Registered as a parameter so that they follow
        # the module across devices; the tensor is created on the default device and
        # the caller is expected to move the whole model with `.to(device)`.
        self.node_states = Param(torch.zeros((n_nodes, features_dim)), requires_grad=False)

        # Number of output channels (e.g., number of classes for classification).
        self.out_channels = out_channels

        # Convergence threshold for node state updates.
        self.eps = eps

        # Maximum number of propagation iterations.
        self.num_layers = num_layers

        # Transition function: maps aggregated neighbour states to new node states.
        self.transition = MLP(features_dim, hid_dims, features_dim)

        # Readout function: maps final node states to the output channels.
        self.readout = MLP(features_dim, hid_dims, out_channels)

        # Initialize the weights of both MLPs.
        self.reset_parameters()

        # Show the architectures of the transition and readout MLPs.
        print(self.transition)
        print(self.readout)

    def reset_parameters(self):
        """
        Re-initializes the weights of the transition and readout MLPs.
        """
        self.transition.reset_parameters()
        self.readout.reset_parameters()

    def forward(self, edge_index=None, edge_weight=None):
        """
        Runs the state propagation followed by the readout.

        Args:
            edge_index (optional): Graph connectivity to propagate over. When omitted,
                both the connectivity and the edge weights are taken from the
                module-level ``data`` object (``data.edge_index`` / ``data.edge_attr``),
                which is the original behaviour of this method.
            edge_weight (Tensor, optional): One weight per edge, used only when
                ``edge_index`` is given. ``None`` means unweighted messages.

        Returns:
            Tensor: The log-softmax output of the model for each node.
        """
        # Backward-compatible default: fall back to the notebook's global graph.
        # Both values are taken together so that weights always match the edges.
        if edge_index is None:
            edge_index = data.edge_index
            edge_weight = data.edge_attr

        # Start from the fixed initial states.
        node_states = self.node_states

        # Iterate the transition at most `num_layers` times.
        for _ in range(self.num_layers):
            # Aggregate the (weighted) neighbour states for every node.
            m = self.propagate(edge_index, x=node_states, edge_weight=edge_weight, size=None)

            # New states are a function of the aggregated messages only.
            new_states = self.transition(m)

            # The convergence check is bookkeeping, not part of the model, so it is
            # kept out of the autograd graph.
            with torch.no_grad():
                distance = torch.norm(new_states - node_states, dim=1)  # Per-node L2 change.
                converged = bool((distance < self.eps).all())

            node_states = new_states

            # Stop early once no node moved by `eps` or more.
            if converged:
                break

        # Map the final node states to class scores and normalize them.
        out = self.readout(node_states)
        return F.log_softmax(out, dim=-1)

    def message(self, x_j, edge_weight):
        """
        Computes the message sent along each edge.

        Args:
            x_j (Tensor): The states of the source (neighbouring) nodes, one row per edge.
            edge_weight (Tensor or None): The weights of the edges.

        Returns:
            Tensor: ``x_j`` unchanged when no weights are given, otherwise ``x_j`` scaled
            row-wise by the corresponding edge weight.
        """
        return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

    def message_and_aggregate(self, adj_t, x):
        """
        Fused message + aggregation step used when the connectivity is given as a
        sparse adjacency matrix.

        Args:
            adj_t (SparseTensor): The transposed adjacency matrix.
            x (Tensor): The node states.

        Returns:
            Tensor: The aggregated messages for each node.
        """
        # Imported locally because the name is not available at notebook level;
        # the original code called an undefined `matmul` here.
        from torch_geometric.utils import spmm

        # Sparse-dense matrix product, reduced with this layer's aggregation scheme.
        return spmm(adj_t, x, reduce=self.aggr)

    def __repr__(self):
        """
        Returns a string representation of the GNNM class.

        Returns:
            str: The class name, the number of output channels and the maximum number
            of propagation iterations.
        """
        return '{}({}, num_layers={})'.format(self.__class__.__name__,
                                              self.out_channels,
                                              self.num_layers)

In [5]:
# Instantiate the GNNM model with the specified parameters:
# - data.num_nodes: Number of nodes in the graph.
# - dataset.num_classes: Number of output classes for classification.
# - 32: Dimensionality of the node state vectors.
# - [64, 64, 64, 64, 64]: Hidden layer dimensions for the transition and readout MLPs.
# - eps=0.01: Convergence threshold for node state updates.
# Then move the model to the selected device (CPU or MPS).
model = GNNM(data.num_nodes, dataset.num_classes, 32, [64, 64, 64, 64, 64], eps=0.01).to(device)

# Adam optimizer over all model parameters with a learning rate of 0.001.
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Loss function for training.
# The model's forward pass already ends with log_softmax, i.e. it returns
# log-probabilities. The criterion that expects log-probabilities is the negative
# log-likelihood loss; nn.CrossEntropyLoss would apply log-softmax a second time.
loss_fn = nn.NLLLoss()

# Graph-level split of the dataset: first 10% of the graphs for testing, the rest
# for training.
# NOTE(review): train() and test() in this notebook work on the node masks of the
# single graph `data` and do not use these splits or loaders. If the dataset holds
# only one graph, len(dataset) // 10 is 0, so `test_dataset` is empty and
# `train_dataset` is the whole dataset.
test_dataset = dataset[:len(dataset) // 10]
train_dataset = dataset[len(dataset) // 10:]

# DataLoader objects over the two splits (default batch size, no shuffling).
test_loader = DataLoader(test_dataset)
train_loader = DataLoader(train_dataset)

# One optimization step over the training nodes.
def train():
    """
    Runs a single training step on the full graph.

    The model is switched to training mode, stale gradients are cleared, the loss
    is computed on the nodes selected by ``data.train_mask``, and one optimizer
    update is applied. Nothing is returned.
    """
    model.train()
    optimizer.zero_grad()

    # Full-graph forward pass: one row of predictions per node.
    predictions = model()

    # Only the training nodes contribute to the loss.
    train_mask = data.train_mask
    loss = loss_fn(predictions[train_mask], data.y[train_mask])

    loss.backward()   # Gradients of the loss w.r.t. the trainable parameters.
    optimizer.step()  # Parameter update.

# Define the testing function.
# Evaluation needs no gradients; disabling autograd avoids building (and keeping in
# memory) the computation graph of the whole propagation loop on every call.
@torch.no_grad()
def test():
    """
    Evaluates the model on the training, validation, and test nodes.

    The model is put in evaluation mode and run once on the full graph; the
    accuracy is then computed separately for the nodes selected by
    ``data.train_mask``, ``data.val_mask`` and ``data.test_mask``.

    Returns:
        list: Three floats, the accuracies on the train, validation, and test
        nodes, in that order.
    """
    model.eval()

    # One forward pass yields the per-class scores of every node.
    logits = model()

    accs = []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        # Predicted class = index of the highest score in each row.
        pred = logits[mask].argmax(dim=1)
        # Fraction of the masked nodes whose prediction matches the label.
        correct = pred.eq(data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs

# Run 50 epochs; each epoch is one training step followed by a full evaluation.
for epoch in range(1, 51):
    train()
    accs = test()  # [train, validation, test] accuracies
    train_acc, val_acc, test_acc = accs[0], accs[1], accs[2]
    # Report the three accuracies for this epoch.
    print('Epoch: {:03d}, Train Acc: {:.5f}, '
          'Val Acc: {:.5f}, Test Acc: {:.5f}'.format(epoch, train_acc, val_acc, test_acc))

MLP(
  (mlp): Sequential(
    (lay_0): Linear(in_features=32, out_features=64, bias=True)
    (act_0): Tanh()
    (lay_1): Linear(in_features=64, out_features=64, bias=True)
    (act_1): Tanh()
    (lay_2): Linear(in_features=64, out_features=64, bias=True)
    (act_2): Tanh()
    (lay_3): Linear(in_features=64, out_features=64, bias=True)
    (act_3): Tanh()
    (lay_4): Linear(in_features=64, out_features=64, bias=True)
    (act_4): Tanh()
    (lay_5): Linear(in_features=64, out_features=32, bias=True)
  )
)
MLP(
  (mlp): Sequential(
    (lay_0): Linear(in_features=32, out_features=64, bias=True)
    (act_0): Tanh()
    (lay_1): Linear(in_features=64, out_features=64, bias=True)
    (act_1): Tanh()
    (lay_2): Linear(in_features=64, out_features=64, bias=True)
    (act_2): Tanh()
    (lay_3): Linear(in_features=64, out_features=64, bias=True)
    (act_3): Tanh()
    (lay_4): Linear(in_features=64, out_features=64, bias=True)
    (act_4): Tanh()
    (lay_5): Linear(in_features=64, ou



Epoch: 001, Train Acc: 0.16393, Val Acc: 0.15400, Test Acc: 0.13800
Epoch: 002, Train Acc: 0.29859, Val Acc: 0.28800, Test Acc: 0.32400
Epoch: 003, Train Acc: 0.29859, Val Acc: 0.29600, Test Acc: 0.32000
Epoch: 004, Train Acc: 0.29801, Val Acc: 0.29600, Test Acc: 0.31800
Epoch: 005, Train Acc: 0.29801, Val Acc: 0.29800, Test Acc: 0.31800
Epoch: 006, Train Acc: 0.30094, Val Acc: 0.29600, Test Acc: 0.32000
Epoch: 007, Train Acc: 0.30269, Val Acc: 0.29800, Test Acc: 0.32200
Epoch: 008, Train Acc: 0.30386, Val Acc: 0.29600, Test Acc: 0.32200
Epoch: 009, Train Acc: 0.30445, Val Acc: 0.29200, Test Acc: 0.32000
Epoch: 010, Train Acc: 0.30152, Val Acc: 0.28800, Test Acc: 0.32400
Epoch: 011, Train Acc: 0.30152, Val Acc: 0.29000, Test Acc: 0.31600
Epoch: 012, Train Acc: 0.30269, Val Acc: 0.28800, Test Acc: 0.31400
Epoch: 013, Train Acc: 0.30386, Val Acc: 0.29200, Test Acc: 0.31800
Epoch: 014, Train Acc: 0.30621, Val Acc: 0.28800, Test Acc: 0.32200
Epoch: 015, Train Acc: 0.30386, Val Acc: 0.29000