# Models
Here we set up graph convolutional networks for node classification.
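All of our GCN classes take four common inputs: *input_dim*, *hid_dim*, *n_classes*, and *n_layers*; the variants `DropEdgeGCN`, `JumpKnowGCN`, and `WeightedSkipDropGCN` additionally accept *dropout_ratio*.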

- The `forward` function should return a Tensor object: **logits**
- The `generate_node_embeddings` function should return a Tensor object: **node_embeddings**, which is the representation of the last layer.
- We use `F.relu` and `F.dropout` at the end of each layer.
- We assume all models will have at least 2 layers.

In [29]:
import inspect
import logging
import os
import typing

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric
import torch_geometric.datasets as datasets
from torch_geometric.nn import GCNConv

In [41]:
# Names of the model classes that can be requested by name; each entry must
# match a class defined in this file.
MODELS = [
    'GCN',
    'SkipGCN',
    'DropEdgeGCN',
    'JumpKnowGCN',
    'WeightedSkipGCN',
    'WeightedSkipDropGCN',
]

In [34]:
class GCN(nn.Module):
    """Stack of ``GCNConv`` layers followed by a linear classification head.

    Every convolution except the last is followed by ``act_fn``. The last
    convolution is applied without a non-linearity in ``forward`` so that the
    linear head receives its raw output.
    """

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 act_fn=F.relu):
        """
        Args:
            input_dim: input feature dimension
            hid_dim: hidden feature dimension
            n_classes: number of target classes
            n_layers: number of GCNConv layers (values below 2 yield a
                single convolution, which is then the final layer)
            act_fn: activation applied after each non-final layer and to the
                embeddings returned by ``generate_node_embeddings``
        """
        super().__init__()
        self.n_layers = n_layers

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        self.mlp = nn.Linear(hid_dim, n_classes)  # final MLP for generating logits
        self.act_fn = act_fn

        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits.

        Args:
            X: node features, passed to the first convolution
            A: graph connectivity, passed unchanged to every convolution
               (presumably an edge index / adjacency accepted by GCNConv --
               confirm against the caller)
        """
        X = self._forward_before_final_layer(X, A)
        # The final convolution must be applied here: it was previously
        # skipped, leaving its parameters unused.
        X = self.layers[-1](X, A)
        return self.mlp(X)  # do not apply non-linearity before MLP

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Return the activated output of the last convolution (no MLP)."""
        X = self._forward_before_final_layer(X, A)
        X = self.layers[-1](X, A)
        return self.act_fn(X)

    def _forward_before_final_layer(self, X, A) -> torch.Tensor:
        """Apply every convolution except the last, each followed by ``act_fn``."""
        for conv in self.layers[:-1]:
            X = conv(X, A)
            X = self.act_fn(X)
        return X

    def param_init(self):
        """Xavier-initialise all weights and zero all biases."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            # the learnable weight lives in each conv's inner Linear object
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [35]:
class DropEdgeGCN(nn.Module):
    """GCN with dropout on the node representations after each hidden layer.

    Note: despite the name, the code applies ``F.dropout`` to node features;
    the connectivity ``A`` is passed to every convolution unmodified.
    """

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 dropout_ratio: float = 0.3,
                 act_fn=F.relu):
        """
        Args:
            input_dim: input feature dimension
            hid_dim: hidden feature dimension
            n_classes: number of target classes
            n_layers: number of GCNConv layers (values below 2 yield a
                single convolution, which is then the final layer)
            dropout_ratio: probability passed to ``F.dropout``
            act_fn: activation applied after each non-final layer and to the
                embeddings returned by ``generate_node_embeddings``
        """
        # super(GCN, self) raised TypeError: this class does not derive from GCN.
        super().__init__()
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        self.mlp = nn.Linear(hid_dim, n_classes)  # final MLP for generating logits
        self.act_fn = act_fn

        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits.

        Args:
            X: node features, passed to the first convolution
            A: graph connectivity, passed unchanged to every convolution
               (presumably an edge index / adjacency accepted by GCNConv --
               confirm against the caller)
        """
        X = self._forward_before_final_layer(X, A)
        # The final convolution must be applied here: it was previously
        # skipped, leaving its parameters unused.
        X = self.layers[-1](X, A)
        return self.mlp(X)  # do not apply non-linearity & dropout before MLP

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Return the last convolution's output after activation and dropout (no MLP)."""
        X = self._forward_before_final_layer(X, A)
        X = self.layers[-1](X, A)
        X = self.act_fn(X)
        # dropout is only active while the module is in training mode
        return F.dropout(X, p=self.dropout_ratio, training=self.training)

    def _forward_before_final_layer(self, X, A) -> torch.Tensor:
        """Apply every convolution except the last, each followed by activation and dropout."""
        for conv in self.layers[:-1]:
            X = conv(X, A)
            X = self.act_fn(X)
            X = F.dropout(X, p=self.dropout_ratio, training=self.training)
        return X

    def param_init(self):
        """Xavier-initialise all weights and zero all biases."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            # the learnable weight lives in each conv's inner Linear object
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [36]:
class SkipGCN(nn.Module):
    """GCN whose intermediate layers add their input back to their output."""

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int):
        """
        Args:
          input_dim: input feature dimension
          hid_dim: hidden feature dimension
          n_classes: number of target classes
          n_layers: number of GCNConv layers (must be greater than 1)
        """
        super().__init__()
        assert n_layers > 1
        self.n_layers = n_layers

        # first conv maps input -> hidden, the remaining ones hidden -> hidden
        convs = nn.ModuleList()
        convs.append(GCNConv(input_dim, hid_dim))
        for _ in range(n_layers - 1):
            convs.append(GCNConv(hid_dim, hid_dim))
        self.layers = convs
        self.mlp = nn.Linear(hid_dim, n_classes)

        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: the linear head applied to the node embeddings."""
        embeddings = self.generate_node_embeddings(X, A)
        logits = self.mlp(embeddings)
        return logits

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Generate node embeddings without applying the MLP."""
        hidden = F.relu(self.layers[0](X, A))

        # intermediate layers: activated output plus the previous representation
        for conv in self.layers[1:-1]:
            hidden = F.relu(conv(hidden, A)) + hidden

        # the last conv is returned without a non-linearity
        return self.layers[-1](hidden, A)

    def param_init(self):
        """Xavier-initialise all weights and zero all biases."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [37]:
class JumpKnowGCN(nn.Module):
    """GCN whose node embeddings are an element-wise max over hidden layers.

    ``forward`` feeds the last convolution's raw output through the linear
    head; ``generate_node_embeddings`` max-pools the activated outputs of all
    convolutions except the last.
    """

    def __init__(
        self,
        input_dim: int,
        hid_dim: int,
        n_classes: int,
        n_layers: int,
        dropout_ratio: float = 0.3):
        """
        Args:
            input_dim: input feature dimension
            hid_dim: hidden feature dimension
            n_classes: number of target classes
            n_layers: number of GCNConv layers (must be greater than 1)
            dropout_ratio: stored on the instance; no dropout is applied by
                this class
        """
        super().__init__()
        assert n_layers > 1
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        self.mlp = nn.Linear(hid_dim, n_classes)

        self.param_init()

    def _layer_outputs(self, X, A) -> typing.List[torch.Tensor]:
        """Return the output of every convolution, in order.

        All entries but the last are ReLU-activated; the last entry is the
        final convolution's raw (un-activated) output.
        """
        outputs = []
        for conv in self.layers[:-1]:
            X = F.relu(conv(X, A))
            outputs.append(X)

        outputs.append(self.layers[-1](X, A))
        return outputs

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits.

        The final convolution is hid_dim wide, so it is mapped through the
        linear head to obtain n_classes columns (previously the hid_dim-wide
        tensor was returned directly and the head was never used).
        """
        return self.mlp(self._layer_outputs(X, A)[-1])

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Return the element-wise max over all activated hidden-layer outputs."""
        hidden_outputs = self._layer_outputs(X, A)[:-1]
        return torch.max(torch.stack(hidden_outputs), 0)[0]  # max pooling

    def param_init(self):
        """Xavier-initialise all weights and zero all biases."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [38]:
class WeightedSkipGCN(nn.Module):
    """GCN with skip connections whose new-representation term is scaled by a learnable scalar."""

    # Class-level logger so the residual weight can be inspected at DEBUG
    # level instead of being printed on every forward pass.
    _logger = logging.getLogger(__name__)

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 init_res_weight: float = 0.3):
        """
        Args:
          input_dim: input feature dimension
          hid_dim: hidden feature dimension
          n_classes: number of target classes
          n_layers: number of GCNConv layers (must be greater than 1)
          init_res_weight: initial value of the learnable scalar ``res_weight``

        Raises:
          ValueError: if ``n_layers`` is not greater than 1
        """
        super().__init__()
        # With a single layer, layers[0] and layers[-1] are the same module
        # and would be applied twice.
        if n_layers <= 1:
            raise ValueError(f"n_layers must be greater than 1, got {n_layers}")
        self.n_layers = n_layers
        self.res_weight = nn.Parameter(torch.tensor(init_res_weight,
                                                    dtype=torch.float32))

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        self.mlp = nn.Linear(hid_dim, n_classes)
        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: the linear head applied to the node embeddings."""
        X = self.generate_node_embeddings(X, A)
        # lazy %-formatting: the parameter is only rendered if DEBUG is enabled
        self._logger.debug("current residual weight: %s", self.res_weight)
        return self.mlp(X)  # MLP maps to logits

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Generate node embeddings without applying the MLP."""
        # First GCNConv layer
        X = self.layers[0](X, A)
        X = F.relu(X)

        # Intermediate GCNConv layers with skip connections
        for conv in self.layers[1:-1]:
            residual = X  # previous layer's representation
            X = conv(X, A)
            X = F.relu(X)
            # the learnable scalar scales the new representation, not the residual
            X = self.res_weight * X + residual

        # Final GCNConv layer outputs raw embeddings
        return self.layers[-1](X, A)

    def param_init(self):
        """Xavier-initialise all weights and zero all biases."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [39]:
class WeightedSkipDropGCN(nn.Module):
    """Weighted-skip GCN that also applies dropout to hidden representations."""

    # Class-level logger so the residual weight can be inspected at DEBUG
    # level instead of being printed on every forward pass.
    _logger = logging.getLogger(__name__)

    def __init__(self, input_dim: int, hid_dim: int,
                 n_classes: int, n_layers: int,
                 dropout_ratio: float = 0.3,
                 init_res_weight: float = 0.3):
        """
        Args:
          input_dim: input feature dimension
          hid_dim: hidden feature dimension
          n_classes: number of target classes
          n_layers: number of GCNConv layers (must be greater than 1)
          dropout_ratio: probability passed to ``F.dropout``
          init_res_weight: initial value of the learnable scalar ``res_weight``

        Raises:
          ValueError: if ``n_layers`` is not greater than 1
        """
        # super(WeightedSkipGCN, self) raised TypeError: this class does not
        # derive from WeightedSkipGCN.
        super().__init__()
        # With a single layer, layers[0] and layers[-1] are the same module
        # and would be applied twice.
        if n_layers <= 1:
            raise ValueError(f"n_layers must be greater than 1, got {n_layers}")
        self.n_layers = n_layers
        self.dropout_ratio = dropout_ratio
        self.res_weight = nn.Parameter(torch.tensor(init_res_weight,
                                                    dtype=torch.float32))

        layers = [GCNConv(input_dim, hid_dim)]
        layers += [GCNConv(hid_dim, hid_dim) for _ in range(1, n_layers)]
        self.layers = nn.ModuleList(layers)
        self.mlp = nn.Linear(hid_dim, n_classes)
        self.param_init()

    def forward(self, X, A) -> torch.Tensor:
        """Return class logits: the linear head applied to the node embeddings."""
        X = self.generate_node_embeddings(X, A)
        # lazy %-formatting: the parameter is only rendered if DEBUG is enabled
        self._logger.debug("current residual weight: %s", self.res_weight)
        return self.mlp(X)  # MLP maps to logits

    def generate_node_embeddings(self, X, A) -> torch.Tensor:
        """Generate node embeddings without applying the MLP."""
        # First GCNConv layer
        X = self.layers[0](X, A)
        X = F.relu(X)
        X = F.dropout(X, p=self.dropout_ratio, training=self.training)

        # Intermediate GCNConv layers with skip connections
        for conv in self.layers[1:-1]:
            residual = X  # previous layer's representation
            X = conv(X, A)
            X = F.relu(X)
            # Build one mask by applying dropout to a tensor of ones, then
            # use that same mask for both the new representation and the
            # residual so they are dropped at the same positions.
            dropout_mask = F.dropout(
                torch.ones_like(X),
                p=self.dropout_ratio,
                training=self.training)
            X = X * dropout_mask
            residual = residual * dropout_mask
            # the learnable scalar scales the new representation, not the residual
            X = self.res_weight * X + residual

        # Final GCNConv layer outputs raw embeddings
        return self.layers[-1](X, A)

    def param_init(self):
        """Xavier-initialise all weights and zero all biases."""
        nn.init.xavier_uniform_(self.mlp.weight)
        nn.init.zeros_(self.mlp.bias)
        for conv in self.layers:
            nn.init.xavier_uniform_(conv.lin.weight)
            nn.init.zeros_(conv.bias)

In [40]:
def set_model(params, device):
    """
    Returns the model initialised based on the configuration specified by `params`.

    Args:
        params: mapping with the keys 'model_name', 'input_dim', 'hid_dim',
            'n_classes' and 'n_layers'. The optional keys 'dropout_ratio'
            and 'init_res_weight' are forwarded only to model classes whose
            constructor declares a parameter of that name.
        device: device the model is moved to; if None the model is returned
            where it was constructed.

    Raises:
        NotImplementedError: if params['model_name'] is not listed in MODELS.
        KeyError: if a required key is missing from `params`.
    """
    model_name = params['model_name']
    if model_name not in MODELS:
        raise NotImplementedError(f"Unknown model name: {model_name!r}")
    model_cls = globals()[model_name]

    # Pass everything by keyword: the constructors do not share one
    # positional layout (e.g. the fifth positional argument is act_fn for
    # GCN but init_res_weight for WeightedSkipGCN).
    kwargs = {key: params[key]
              for key in ('input_dim', 'hid_dim', 'n_classes', 'n_layers')}
    accepted = inspect.signature(model_cls.__init__).parameters
    for key in ('dropout_ratio', 'init_res_weight'):
        if key in accepted and key in params:
            kwargs[key] = params[key]

    model = model_cls(**kwargs)
    if device is not None:
        model = model.to(device)
    return model