In [None]:
import torch
import os
import typing
import torch_geometric

import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.datasets as datasets

from torch_geometric.nn import GCNConv

In [None]:
class GCN(nn.Module):
    """Plain stacked-GCN node classifier.

    ``n_layers`` GCNConv layers produce node embeddings (ReLU + dropout after
    every layer except the last), and one linear layer (``self.mlp``) maps
    those embeddings to per-class logits.
    """

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 dropout_ratio: float = 0.3):
        """
        Args:
            input_dim: dimensionality of the input node features
            hid_dim: output width of every GCNConv layer (the embedding width)
            n_classes: number of target classes (output width of ``self.mlp``)
            n_layers: number of GCNConv layers; must be at least 2
            dropout_ratio: probability used by dropout between layers
        """
        super().__init__()
        assert n_layers > 1
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio

        # Only the first layer changes width; the other n_layers - 1 keep hid_dim.
        in_dims = [input_dim] + [hid_dim] * (n_layers - 1)
        self.layers = nn.ModuleList(GCNConv(d, hid_dim) for d in in_dims)
        self.mlp = nn.Linear(hid_dim, n_classes)  # embeddings -> class logits

        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: ``self.mlp`` applied to the node embeddings."""
        return self.mlp(self.generate_node_embeddings(X, A))

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Run all GCNConv layers and return the last one's raw output.

        Hidden layers are followed by ReLU and dropout (dropout is active only
        in training mode). The final layer's output is returned as-is, and
        ``self.mlp`` is NOT applied.
        """
        *hidden_convs, output_conv = self.layers
        h = X
        for conv in hidden_convs:
            h = F.dropout(F.relu(conv(h, A)),
                          p=self.dropout_ratio, training=self.training)
        return output_conv(h, A)

    def param_init(self):
        """Xavier-uniform weights and zero biases for the classifier and every conv."""
        # The classifier first, then each conv in order.
        # GCNConv keeps its weight in an inner Linear object (``conv.lin``).
        weight_bias_pairs = [(self.mlp.weight, self.mlp.bias)]
        weight_bias_pairs += [(conv.lin.weight, conv.bias) for conv in self.layers]
        for weight, bias in weight_bias_pairs:
            nn.init.xavier_uniform_(weight)
            nn.init.zeros_(bias)

In [None]:
class SkipGCN(nn.Module):
    """GCN node classifier with additive skip connections.

    Skip connections are added only around the intermediate layers (neither
    the first layer, which changes width, nor the last one). ``self.mlp`` maps
    the resulting embeddings to per-class logits.
    """

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 dropout_ratio: float = 0.3):
        """
        Args:
          input_dim: dimensionality of the input node features
          hid_dim: output width of every GCNConv layer (the embedding width)
          n_classes: number of target classes (output width of ``self.mlp``)
          n_layers: number of GCNConv layers; must be at least 2
          dropout_ratio: probability used by dropout between layers
        """
        super().__init__()
        assert n_layers > 1
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio

        in_dims = [input_dim] + [hid_dim] * (n_layers - 1)
        self.layers = nn.ModuleList(GCNConv(d, hid_dim) for d in in_dims)
        self.mlp = nn.Linear(hid_dim, n_classes)

        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: ``self.mlp`` applied to the node embeddings."""
        return self.mlp(self.generate_node_embeddings(X, A))

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Compute node embeddings (raw output of the last GCNConv, no MLP)."""
        first_conv, *middle_convs, last_conv = self.layers

        h = F.dropout(F.relu(first_conv(X, A)),
                      p=self.dropout_ratio, training=self.training)

        for conv in middle_convs:
            out = F.relu(conv(h, A))
            # Dropout applied to a tensor of ones yields a reusable mask.
            # Sharing it means a unit dropped from the new output is also
            # dropped from the skip path. In eval mode the mask is all ones.
            keep = F.dropout(torch.ones_like(out),
                             p=self.dropout_ratio, training=self.training)
            h = out * keep + h * keep  # masked layer output + masked skip

        return last_conv(h, A)

    def param_init(self):
        """Xavier-uniform weights and zero biases for the classifier and every conv."""
        weight_bias_pairs = [(self.mlp.weight, self.mlp.bias)]
        weight_bias_pairs += [(conv.lin.weight, conv.bias) for conv in self.layers]
        for weight, bias in weight_bias_pairs:
            nn.init.xavier_uniform_(weight)
            nn.init.zeros_(bias)

In [None]:
class JumpKnowGCN(nn.Module):
    """GCN node classifier with jumping-knowledge (max-pooling) aggregation.

    The output of every GCNConv layer is kept. The node embedding is their
    element-wise maximum, and ``self.mlp`` maps that embedding to per-class
    logits. This matches the forward/embedding contract of the other GCN
    variants in this notebook.
    """

    def __init__(
        self,
        input_dim: int,
        hid_dim: int,
        n_classes: int,
        n_layers: int,
        dropout_ratio: float = 0.3):
        """
        Args:
            input_dim: input feature dimension
            hid_dim: hidden feature dimension (width of every layer output
                and of the pooled embedding)
            n_classes: number of target classes (output width of forward)
            n_layers: number of GCNConv layers; must be at least 2
            dropout_ratio: dropout ratio
        """
        super().__init__()
        assert n_layers > 1
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        # Classifier on top of the pooled embedding. Every layer outputs
        # hid_dim features, so this is what produces n_classes logits.
        self.mlp = nn.Linear(hid_dim, n_classes)

        self.param_init()

    def _layer_outputs(self, X, A) -> typing.List[torch.Tensor]:
        """Return the output of every GCNConv layer, in order.

        Hidden layers are followed by ReLU and dropout (dropout is active only
        in training mode). The final layer's output is raw (no activation or
        dropout). Every element has width ``hid_dim``.
        """
        outputs = []
        for l in self.layers[:-1]:
            X = l(X, A)
            X = F.relu(X)
            X = F.dropout(X, p=self.dropout_ratio, training=self.training)
            outputs.append(X)

        outputs.append(self.layers[-1](X, A))
        return outputs

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: ``self.mlp`` applied to the pooled embeddings."""
        return self.mlp(self.generate_node_embeddings(X, A))

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Max-pool all layer outputs into a single embedding (no MLP applied).

        Every layer is included, so the final GCNConv's parameters contribute
        to the embedding and therefore to the logits.
        """
        outputs = self._layer_outputs(X, A)
        return torch.stack(outputs).max(dim=0).values  # element-wise max over layers

    def param_init(self):
        """Xavier-uniform weights and zero biases for the classifier and every conv."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            # GCNConv keeps its weight in an inner Linear object (``conv.lin``).
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [None]:
class WeightedSkipGCN(nn.Module):
    """GCN node classifier whose skip connections use a learnable scalar weight.

    Each intermediate layer's (masked) output is scaled by the learnable
    parameter ``res_weight`` and then added to the (masked) previous
    representation. ``self.mlp`` maps the final embeddings to class logits.
    """

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 dropout_ratio: float = 0.3,
                 init_res_weight: float = 0.3):
        """
        Args:
          input_dim: input feature dimension
          hid_dim: hidden feature dimension
          n_classes: number of target classes
          n_layers: number of GCNConv layers; must be at least 2
          dropout_ratio: dropout ratio
          init_res_weight: initial value of the learnable scalar
            ``res_weight``. It multiplies each intermediate layer's output
            (not the carried-over representation) before the skip-add.
        """
        super().__init__()
        # With a single layer, layers[0] is also layers[-1] and would be
        # applied twice in generate_node_embeddings. That fails on a width
        # mismatch, or is silently wrong when input_dim == hid_dim.
        assert n_layers > 1
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio
        self.res_weight = nn.Parameter(torch.tensor(init_res_weight,
                                                    dtype=torch.float32))

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        self.mlp = nn.Linear(hid_dim, n_classes)
        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: ``self.mlp`` applied to the node embeddings."""
        X = self.generate_node_embeddings(X, A)
        return self.mlp(X)  # MLP maps to logits

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Generate node embeddings without applying the MLP."""
        # First GCNConv layer
        X = self.layers[0](X, A)
        X = F.relu(X)
        X = F.dropout(X, p=self.dropout_ratio, training=self.training)

        # Intermediate GCNConv layers with weighted skip connections
        for l in self.layers[1:-1]:
            residual = X  # previous layer's representation
            X = l(X, A)
            X = F.relu(X)
            # Dropout on a tensor of ones gives a mask that is shared by the
            # new output and the skip path (all ones in eval mode).
            dropout_mask = F.dropout(
                torch.ones_like(X),
                p=self.dropout_ratio,
                training=self.training)
            X = X * dropout_mask
            residual = residual * dropout_mask
            X = self.res_weight * X + residual  # scaled layer output + skip

        # Final GCNConv layer outputs raw embeddings
        return self.layers[-1](X, A)

    def param_init(self):
        """Xavier-uniform weights and zero biases for the classifier and every conv.

        ``res_weight`` is left at ``init_res_weight``.
        """
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            # GCNConv keeps its weight in an inner Linear object (``conv.lin``).
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)