In [None]:
import os
import sys
import tqdm
import copy
import json
import time
import math
import numpy as np
import pandas as pd
import random
import pickle
import logging
import networkx
import scipy.sparse
from enum import Enum, auto
from typing import Dict, NamedTuple, Tuple, List, Union, Optional, Callable, Any

import torch
from torch import Tensor
from torch.cuda import set_device
from torch_sparse import coalesce
from torch_geometric.data import Data, Batch, InMemoryDataset, download_url
from torch_geometric.utils import remove_self_loops, add_remaining_self_loops, to_undirected, get_laplacian, to_scipy_sparse_matrix, scatter, to_dense_adj
from torch_geometric.loader import DataLoader
from torch_geometric.typing import OptTensor, Adj
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.pool import global_mean_pool, global_add_pool
from torch_geometric.datasets import HeterophilousGraphDataset
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_geometric.graphgym.loader import index2mask, set_dataset_attr
from torch_geometric.nn.conv.gcn_conv import gcn_norm
import torch_geometric.transforms as T
from torchmetrics import MeanAbsoluteError, Accuracy, AveragePrecision, AUROC

from ogb.utils import smiles2graph
from ogb.utils.url import decide_download
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims
from ogb.utils.torch_util import replace_numpy_with_torchtensor

from hashlib import md5
from shutil import rmtree
from functools import partial
from os import getcwd
from os.path import join, dirname, abspath, isdir, exists

In [None]:
# Resolve the project root. `__file__` only exists when this code runs as a
# script/module; inside a Jupyter kernel it is undefined and raises NameError,
# in which case we fall back to the current working directory.
# NOTE(review): the script path climbs to the parent of the file's directory,
# whereas the fallback climbs two levels above the working directory itself —
# confirm this asymmetry is intended for the notebook's location.
try:
    ROOT_DIR = dirname(dirname(abspath(__file__)))
except NameError:
    ROOT_DIR = dirname(dirname(abspath(getcwd())))

To ensure reproducibility, we define the `set_seed()` function:

In [None]:
def set_seed(seed):
    """Seed every random number generator used here and pin cuDNN behaviour.

    Sets ``PYTHONHASHSEED`` and seeds Python's ``random``, NumPy and PyTorch
    (CPU plus all CUDA devices). When CUDA is available, cuDNN autotuning is
    switched off, deterministic kernels are requested and cuDNN itself is
    disabled.

    Args:
        seed: Integer seed shared by all generators.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # The cuDNN switches are only touched on CUDA machines.
    if not torch.cuda.is_available():
        return
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False

---

## Defining the model options

The `MolConv` class represents the base layer of the `MessagePassing` class. The default aggregation is SUM. When we call the model to compute the output, for each node it does the following:

1. `propagate()`: Gather the embeddings `x_j` of each node's neighbors (this happens before `message()` is called).

2. `message()`: Add the edge attributes to the neighbors' embeddings and scale this sum by the edge weights.

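3. Aggregation, then `update()`: Combine (aggregate) the resulting messages from (2) with the configured aggregation (SUM by default) to produce the node's new embedding; `update()` itself returns the aggregated result unchanged.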

In [None]:
class MolConv(MessagePassing):
    """Base message-passing layer shared by the weighted convolutions.

    A message is the neighbour embedding ``x_j``, shifted by the edge
    attributes when they are given and scaled by the edge weights when they
    are given. The aggregated messages are returned as-is by ``update``.
    """

    def __init__(self, aggr='add'):
        # Supported aggregations: 'add', 'mean' or 'max'.
        super().__init__(aggr=aggr)

    def message(self, x_j: Tensor, edge_attr: OptTensor, edge_weight: OptTensor = None) -> Tensor:
        """Build one message per edge from ``x_j``, ``edge_attr`` and ``edge_weight``."""
        msg = x_j if edge_attr is None else x_j + edge_attr
        if edge_weight is not None:
            # One scalar weight per edge, broadcast over the feature dimension.
            msg = edge_weight.view(-1, 1) * msg
        return msg

    def update(self, aggr_out: Tensor) -> Tensor:
        """Identity update: the aggregated messages are the new embeddings."""
        return aggr_out

The `WeightedGCNConv` class extends the previous base convolution layer. It requires:

- The number of input channels `in_channels` (dimension of a node's feature vector),

- The number of output channels `out_channels` (dimension of a node's feature vector after convolution),

- Whether or not to include a `bias` term in the linear layer, and

- other arguments `kwargs` passed to the `MolConv` class.

The forward pass is designed to:

1. Create a collection of the graph's edges without self-loops to avoid double counting of nodes' features in the normalization.

2. Create a collection of the edge attributes including the self-loops for the propagation step so that the node can "keep its prior state in memory".

3. Normalize the graph's edge weights by their degree.

4. Perform a propagation step and map the output using the linear layer to the desired number of output channels.

In [None]:
class WeightedGCNConv(MolConv):
    """GCN-style convolution with optional edge attributes and edge weights.

    The forward pass drops existing self-loops, appends one self-loop per node
    (edge attribute filled with 1), applies ``gcn_norm`` to the edge weights,
    propagates, and finally maps the result through a linear layer.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        bias: Whether the linear layer has a bias term.
        **kwargs: Forwarded to ``MolConv``; ``aggr`` defaults to ``'add'``.
    """

    def __init__(self, in_channels: int, out_channels: int, bias: bool, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.add_self_loops = True
        self.improved = False

        self.lin = torch.nn.Linear(in_channels, out_channels, bias=bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None, edge_weight: OptTensor = None) -> Tensor:
        """Run one normalized propagation step followed by the linear map.

        Args:
            x: Node features; ``x.shape[0]`` is used as the number of nodes.
            edge_index: Edge list of the graph.
            edge_attr: Optional per-edge attributes, aligned with ``edge_index``.
            edge_weight: Optional per-edge weights, aligned with ``edge_index``.

        Returns:
            The propagated node features mapped to ``out_channels``.
        """
        # Drop self-loops from the edge list AND from the per-edge tensors with
        # the same mask. Filtering only `edge_index` would leave `edge_attr` /
        # `edge_weight` misaligned whenever the input graph contains self-loops.
        loop_free_index, edge_attr = remove_self_loops(edge_index, edge_attr)
        if edge_weight is not None:
            edge_weight = remove_self_loops(edge_index, edge_weight)[1]
        edge_index = loop_free_index

        # Extend the attributes with one entry per self-loop; the matching
        # self-loop edges themselves are appended by `gcn_norm` below.
        _, edge_attr = add_remaining_self_loops(edge_index, edge_attr, fill_value=1, num_nodes=x.shape[0])

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        edge_index, edge_weight = gcn_norm(  # yapf: disable
            edge_index, edge_weight, x.size(self.node_dim),
            self.improved, self.add_self_loops, self.flow, x.dtype)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None, edge_attr=edge_attr)
        out = self.lin(out)
        return out

In the `WeightedGNNConv` class we remove the normalization step as compared to the `WeightedGCNConv` and before running the output of the propagation step through a linear layer, we first concatenate to it the original node's features.

In [None]:
class WeightedGNNConv(MolConv):
    """Un-normalized convolution that keeps the node's own features.

    The aggregated neighbourhood messages are concatenated to the node's
    input features and the pair is mapped through one linear layer.

    Args:
        in_channels: Size of each input node feature vector.
        out_channels: Size of each output node feature vector.
        aggr: Aggregation passed to ``MolConv``.
        bias: Whether the linear layer has a bias term.
    """

    def __init__(self, in_channels: int, out_channels: int, aggr='add', bias=True):
        super().__init__(aggr=aggr)
        self.in_channels, self.out_channels = in_channels, out_channels
        # The linear layer sees [x || aggregated messages], hence 2 * in_channels inputs.
        self.lin = torch.nn.Linear(2 * in_channels, out_channels, bias=bias)

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None, edge_weight: OptTensor = None) -> Tensor:
        """Propagate, concatenate with ``x`` along the last dimension, apply the linear layer."""
        neighbourhood = self.propagate(edge_index, x=x, edge_attr=edge_attr, edge_weight=edge_weight)
        return self.lin(torch.cat((x, neighbourhood), dim=-1))

A further simplification is implemented by `GraphLinear`, which simply passes the node's features through a linear layer and ignores the graph structure.

In [None]:
class GraphLinear(torch.nn.Linear):
    """Linear layer exposing the same ``forward`` signature as the graph convolutions."""

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None, edge_weight: OptTensor = None) -> Tensor:
        """Apply the linear map to ``x``; all graph arguments are ignored."""
        return torch.nn.Linear.forward(self, x)

Similarly to the `WeightedGCNConv`, the `WeightedGINConv` class extends the `MolConv` base convolution layer. However, instead of the `kwargs` it requires a `mlp_func` function which should return an architecture for a Multilayer Perceptron.

In the forward pass, we perform a propagation step like the one in the `WeightedGCNConv` class (but without removing self-loops or normalizing the edge weights). Instead of using a linear layer to map the output to another dimensionality, we add the node's own features, scaled by $(1 + \epsilon)$, to the output and run the sum through an MLP.

In [None]:
class WeightedGINConv(MolConv):
    """GIN-style convolution: ``mlp((1 + eps) * x + sum of messages)``.

    Args:
        in_channels: Forwarded to ``mlp_func`` as ``in_channels``.
        out_channels: Forwarded to ``mlp_func`` as ``out_channels``.
        bias: Forwarded to ``mlp_func`` as ``bias``.
        mlp_func: Factory called with the three keyword arguments above; its
            result is applied to the combined node representation.
    """

    def __init__(self, in_channels: int, out_channels: int, bias: bool, mlp_func: Callable):
        super().__init__(aggr="add")

        self.mlp = mlp_func(in_channels=in_channels, out_channels=out_channels, bias=bias)
        # Learnable scalar weighting the node's own features, initialised to 0.
        self.eps = torch.nn.Parameter(torch.Tensor([0]))

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: OptTensor = None, edge_weight: OptTensor = None) -> Tensor:
        """Aggregate neighbour messages, add the scaled self term, apply the MLP."""
        aggregated = self.propagate(edge_index, x=x, edge_weight=edge_weight, size=None, edge_attr=edge_attr)
        self_scale = 1 + self.eps.to(x.device)
        return self.mlp(self_scale * x + aggregated)

The `ModelType` class provides a container for the different types of models from before. In the `get_component_list` function we return a list of the full architecture of the model that we will use, which can later be instantiated with a call to `torch.nn.Sequential()`.

In [None]:
class ModelType(Enum):
    """Enumeration of the supported layer architectures."""
    GCN = auto()
    GIN = auto()
    LIN = auto()

    SUM_GNN = auto()
    MEAN_GNN = auto()

    @staticmethod
    def from_string(s: str):
        """Look up a member by name.

        Raises:
            ValueError: if ``s`` is not the name of a member; the message lists
                the offending value and the valid names.
        """
        try:
            return ModelType[s]
        except KeyError as err:
            valid = ', '.join(member.name for member in ModelType)
            raise ValueError(f"Unknown ModelType '{s}'; expected one of: {valid}") from err

    def load_component_cls(self):
        """Return the layer class implementing this model type."""
        if self is ModelType.GCN:
            return WeightedGCNConv
        elif self is ModelType.GIN:
            return WeightedGINConv
        elif self in [ModelType.SUM_GNN, ModelType.MEAN_GNN]:
            return WeightedGNNConv
        elif self is ModelType.LIN:
            return GraphLinear
        else:
            raise ValueError(f'model {self.name} not supported')

    def is_gcn(self):
        """Return True only for the GCN member."""
        return self is ModelType.GCN

    def get_component_list(self, in_dim: int, hidden_dim: int, out_dim: int, num_layers: int, bias: bool,
                           edges_required: bool, gin_mlp_func: Callable) -> List[torch.nn.Module]:
        """Instantiate the stack of layers for this model type.

        Layer widths are ``in_dim -> hidden_dim -> ... -> hidden_dim -> out_dim``
        with ``num_layers`` layers in total.

        Args:
            in_dim: Input width of the first layer.
            hidden_dim: Width between consecutive layers.
            out_dim: Output width of the last layer.
            num_layers: Number of layers to create.
            bias: Passed to every layer.
            edges_required: Must be False for ``LIN`` (checked with ``assert``).
            gin_mlp_func: MLP factory handed to every ``GIN`` layer.

        Returns:
            The list of instantiated layers, in order.

        Raises:
            ValueError: if the model type is not supported.
        """
        dim_list = [in_dim] + [hidden_dim] * (num_layers - 1) + [out_dim]
        dim_pairs = list(zip(dim_list[:-1], dim_list[1:]))

        if self is ModelType.LIN:
            # GraphLinear follows the torch.nn.Linear constructor (in_features / out_features).
            assert not edges_required, f'env does not support {self.name}'
            component_cls = self.load_component_cls()
            return [component_cls(in_features=in_dim_i, out_features=out_dim_i, bias=bias)
                    for in_dim_i, out_dim_i in dim_pairs]

        # Keyword arguments specific to each convolution type.
        if self is ModelType.GCN:
            extra_kwargs = {}
        elif self is ModelType.GIN:
            extra_kwargs = {'mlp_func': gin_mlp_func}
        elif self in [ModelType.SUM_GNN, ModelType.MEAN_GNN]:
            extra_kwargs = {'aggr': 'mean' if self is ModelType.MEAN_GNN else 'sum'}
        else:
            raise ValueError(f'model {self.name} not supported')

        component_cls = self.load_component_cls()
        return [component_cls(in_channels=in_dim_i, out_channels=out_dim_i, bias=bias, **extra_kwargs)
                for in_dim_i, out_dim_i in dim_pairs]

---

## Miscellaneous helper classes

First, we set up a helper class `BatchIdentity` that ensures differentiability of the loss in the backpropagation step when returned by the `Pool.get()` method.

In [None]:
class BatchIdentity(torch.nn.Module):
    """No-op module with the ``(x, batch)`` call signature of the pooling functions."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        # Any constructor arguments are accepted and ignored.
        torch.nn.Module.__init__(self)

    def forward(self, x: Tensor, batch: Tensor) -> Tensor:
        """Return ``x`` unchanged; ``batch`` is not used."""
        return x

The `Pool` class is a container for the various aggregation (pooling) types (the original authors' docstring mistakenly calls them activation types).

In [None]:
class Pool(Enum):
    """Enumeration of the supported graph-level pooling (aggregation) types."""
    NONE = auto()
    MEAN = auto()
    SUM = auto()

    @staticmethod
    def from_string(s: str):
        """Look up a member by name.

        Raises:
            ValueError: if ``s`` is not the name of a member; the message lists
                the offending value and the valid names.
        """
        try:
            return Pool[s]
        except KeyError as err:
            valid = ', '.join(member.name for member in Pool)
            raise ValueError(f"Unknown Pool '{s}'; expected one of: {valid}") from err

    def get(self):
        """Return the pooling callable for this member.

        ``MEAN`` and ``SUM`` return the corresponding global pooling function;
        ``NONE`` returns a fresh ``BatchIdentity`` instance.
        """
        if self is Pool.MEAN:
            return global_mean_pool
        if self is Pool.SUM:
            return global_add_pool
        if self is Pool.NONE:
            return BatchIdentity()
        raise ValueError(f'Pool {self.name} not supported')

The `ActivationType` class is a container where we can set up the various activation functions to be used depending on the dataset, as specified in the `DataSet` class. The activation function is also instantiated in the Action Net of the `Experiment` class.

In [None]:
class ActivationType(Enum):
    """Enumeration of the supported activation functions."""
    RELU = auto()
    GELU = auto()

    @staticmethod
    def from_string(s: str):
        """Look up a member by name.

        Raises:
            ValueError: if ``s`` is not the name of a member; the message lists
                the offending value and the valid names.
        """
        try:
            return ActivationType[s]
        except KeyError as err:
            valid = ', '.join(member.name for member in ActivationType)
            raise ValueError(f"Unknown ActivationType '{s}'; expected one of: {valid}") from err

    def get(self):
        """Return the functional form of the activation (``torch.nn.functional``)."""
        if self is ActivationType.RELU:
            return torch.nn.functional.relu
        if self is ActivationType.GELU:
            return torch.nn.functional.gelu
        raise ValueError(f'ActivationType {self.name} not supported')

    def nn(self) -> torch.nn.Module:
        """Return a new module instance of the activation (``torch.nn``)."""
        if self is ActivationType.RELU:
            return torch.nn.ReLU()
        if self is ActivationType.GELU:
            return torch.nn.GELU()
        raise ValueError(f'ActivationType {self.name} not supported')

The `GumbelArgs` class is a container for various parameters for the `Experiment` class.

In [None]:
class GumbelArgs(NamedTuple):
    """Immutable bundle of Gumbel-related settings.

    NOTE(review): how these fields are consumed is not visible in this part of
    the notebook; the per-field notes below are inferred from the names and
    types only and should be confirmed against the code that reads them.
    """
    learn_temp: bool  # presumably whether the temperature is learned — TODO confirm
    temp_model_type: ModelType  # a ModelType member; presumably the architecture of the temperature network — TODO confirm
    tau0: float  # presumably a base temperature value — TODO confirm
    temp: float  # presumably the temperature used when it is not learned — TODO confirm
    gin_mlp_func: Callable  # callable; by its name, the MLP factory handed to GIN layers — TODO confirm

The `Concat2NodeEncoder` is a prerequisite class for the `EnvArgs` class below. It essentially concatenates two node encoders.

In [None]:
class Concat2NodeEncoder(torch.nn.Module):
    """Chain two node encoders, feeding the output of the first into the second.

    The first encoder is built with an embedding size reduced by
    ``enc2_dim_pe``; the second is built with the full ``emb_dim`` and
    ``expand_x=False``.

    Args:
        enc1_cls: Factory called as ``enc1_cls(in_dim=..., emb_dim=...)``.
        enc2_cls: Factory called as ``enc2_cls(in_dim=..., emb_dim=..., expand_x=False)``.
        in_dim: Passed to both factories.
        emb_dim: Target embedding size.
        enc2_dim_pe: Width subtracted from ``emb_dim`` for the first encoder.
    """

    def __init__(self, enc1_cls, enc2_cls, in_dim, emb_dim, enc2_dim_pe):
        super().__init__()
        # PE dims can only be gathered once the cfg is loaded.
        enc1_width = emb_dim - enc2_dim_pe
        self.encoder1 = enc1_cls(in_dim=in_dim, emb_dim=enc1_width)
        self.encoder2 = enc2_cls(in_dim=in_dim, emb_dim=emb_dim, expand_x=False)

    def forward(self, x, pestat):
        """Apply ``encoder1`` then ``encoder2``, each receiving the same ``pestat``."""
        for encoder in (self.encoder1, self.encoder2):
            x = encoder(x, pestat)
        return x

Now, the `LossesAndMetrics` class provides a container for losses and metrics.

In [None]:
class LossesAndMetrics(NamedTuple):
    """Losses and metric values for the train, validation and test splits."""
    train_loss: float
    val_loss: float
    test_loss: float
    train_metric: float
    val_metric: float
    test_metric: float

    def get_fold_metrics(self):
        """Return the (train, val, test) metric values as a 1-D tensor of length 3."""
        per_split_metrics = [
            self.train_metric,
            self.val_metric,
            self.test_metric,
        ]
        return torch.tensor(per_split_metrics)

The previously defined `LossesAndMetrics` will, among others, be used in the following `MetricType` class, which provides a container for various metrics. It enables the following functionalities:

- `apply_metric()` instantiates the desired metric type and uses the object to compute the metric result which it returns as a simple float value.

- `is_classification()` returns `True` if the metric type belongs to a classification task.

- `is_multilabel()` returns `True` if the metric type belongs to a multi-label classification task.

- `get_task_loss()` returns various loss functions depending on the metric type.

- `get_out_dim()` returns the output dimension required for the model based on the number of classes of the dataset.

- `higher_is_better()` returns `True` if the model's performance is better when the metric's value is higher.

- `src_better_than_other()` compares two metric values and returns `True` if the first is better than the second, where "better" means larger when `higher_is_better()` is `True` and smaller otherwise.

- `get_worst_losses_n_metrics()` returns the worst possible LossesAndMetrics instance depending on the task.

In [None]:
class MetricType(Enum):
    """Enumeration of the supported evaluation metrics."""
    # classification
    ACCURACY = auto()
    MULTI_LABEL_AP = auto()
    AUC_ROC = auto()

    # regression
    MSE_MAE = auto()

    def apply_metric(self, scores: np.ndarray, target: np.ndarray) -> float:
        """Compute this metric for ``scores`` against ``target``.

        NumPy inputs are converted to tensors; tensors are used as they are.
        The number of classes/labels is read from ``scores.size(1)`` only for
        the metrics that need it, so the regression metric also accepts 1-D
        scores.

        Args:
            scores: Model outputs.
            target: Ground-truth values; cast to ``int`` for ``MULTI_LABEL_AP``.

        Returns:
            The metric value as a Python float.

        Raises:
            ValueError: if the metric type is not supported.
        """
        if isinstance(scores, np.ndarray):
            scores = torch.from_numpy(scores)
        if isinstance(target, np.ndarray):
            target = torch.from_numpy(target)

        if self is MetricType.ACCURACY:
            metric = Accuracy(task="multiclass", num_classes=scores.size(1))
        elif self is MetricType.MULTI_LABEL_AP:
            metric = AveragePrecision(task="multilabel", num_labels=scores.size(1))
            target = target.int()
        elif self is MetricType.MSE_MAE:
            metric = MeanAbsoluteError()
        elif self is MetricType.AUC_ROC:
            metric = AUROC(task="multiclass", num_classes=scores.size(1))
        else:
            raise ValueError(f'MetricType {self.name} not supported')

        metric = metric.to(scores.device)
        result = metric(scores, target)
        return result.item()

    def is_classification(self) -> bool:
        """Return True for the classification metrics, False for ``MSE_MAE``."""
        if self in [MetricType.AUC_ROC, MetricType.ACCURACY, MetricType.MULTI_LABEL_AP]:
            return True
        elif self is MetricType.MSE_MAE:
            return False
        else:
            raise ValueError(f'MetricType {self.name} not supported')

    def is_multilabel(self) -> bool:
        """Return True only for ``MULTI_LABEL_AP``."""
        return self is MetricType.MULTI_LABEL_AP

    def get_task_loss(self):
        """Return the loss module for the task.

        ``BCEWithLogitsLoss`` for multi-label classification,
        ``CrossEntropyLoss`` for the other classification metrics and
        ``MSELoss`` for ``MSE_MAE``.
        """
        if self.is_classification():
            if self.is_multilabel():
                return torch.nn.BCEWithLogitsLoss()
            return torch.nn.CrossEntropyLoss()
        elif self is MetricType.MSE_MAE:
            return torch.nn.MSELoss()
        else:
            raise ValueError(f'MetricType {self.name} not supported')

    def get_out_dim(self, dataset: List[Data]) -> int:
        """Return the model output width required for ``dataset``.

        Multi-label: second dimension of the first graph's ``y``. Other
        classification metrics: largest label over all graphs plus one.
        Regression: last dimension of the first graph's ``y``.
        """
        if self.is_classification():
            if self.is_multilabel():
                return dataset[0].y.shape[1]
            return int(max([data.y.max().item() for data in dataset]) + 1)
        return dataset[0].y.shape[-1]

    def higher_is_better(self):
        """Classification metrics are maximised; the regression metric is minimised."""
        return self.is_classification()

    def src_better_than_other(self, src: float, other: float) -> bool:
        """Return True if ``src`` is strictly better than ``other`` for this metric."""
        if self.higher_is_better():
            return src > other
        return src < other

    def get_worst_losses_n_metrics(self) -> LossesAndMetrics:
        """Return a ``LossesAndMetrics`` that any real result improves upon.

        Losses are ``+inf``; metrics are ``-inf`` for classification and
        ``+inf`` for regression.
        """
        worst_metric = -math.inf if self.is_classification() else math.inf
        return LossesAndMetrics(train_loss=math.inf, val_loss=math.inf, test_loss=math.inf,
                                train_metric=worst_metric, val_metric=worst_metric, test_metric=worst_metric)

The `EncoderLinear` class is a simple linear layer. It differs from the `GraphLinear` class defined earlier only in the extra arguments its forward pass accepts (`pestat` instead of the graph arguments); these arguments are ignored, so they do not affect the output of the model.

In [None]:
class EncoderLinear(torch.nn.Linear):
    """Linear node encoder exposing the ``(x, pestat)`` encoder call signature."""

    def forward(self, x: Tensor, pestat=None) -> Tensor:
        """Apply the linear map to ``x``; ``pestat`` is ignored."""
        return torch.nn.Linear.forward(self, x)

Next up, the `AtomEncoder` class is a network functioning as an encoder of atom features.

In [None]:
class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features and sum the per-feature embeddings.

    One embedding table of width ``emb_dim`` is created for every entry
    returned by ``get_atom_feature_dims()``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.atom_embedding_list = torch.nn.ModuleList()

        # Create and initialise the tables one at a time, in feature order.
        for vocab_size in get_atom_feature_dims():
            table = torch.nn.Embedding(vocab_size, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x, pestat):
        """Sum the embeddings of every column of ``x``; ``pestat`` is ignored."""
        summed = 0
        for column in range(x.shape[1]):
            summed += self.atom_embedding_list[column](x[:, column])

        return summed

Similarly to the `AtomEncoder`, the `BondEncoder` class is a network functioning as an encoder of the bonds, i.e. edges, between two atoms.

In [None]:
class BondEncoder(torch.nn.Module):
    """Embed categorical bond (edge) features and sum the per-feature embeddings.

    One embedding table of width ``emb_dim`` is created for every entry
    returned by ``get_bond_feature_dims()``.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.bond_embedding_list = torch.nn.ModuleList()

        # Create and initialise the tables one at a time, in feature order.
        for vocab_size in get_bond_feature_dims():
            table = torch.nn.Embedding(vocab_size, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        """Sum the embeddings of every column of ``edge_attr``."""
        summed = 0
        for column in range(edge_attr.shape[1]):
            summed += self.bond_embedding_list[column](edge_attr[:, column])

        return summed

Incorporating the three encoders above, the `DataSetEncoders` class defines how the nodes' and edges' features shall be encoded. It does not hold the encoded data but rather creates the encoder instances.

In [None]:
class DataSetEncoders(Enum):
    """Enumeration of the supported node/edge feature encoders."""
    NONE = auto()
    MOL = auto()

    @staticmethod
    def from_string(s: str):
        """Look up a member by name.

        Raises:
            ValueError: if ``s`` is not the name of a member; the message lists
                the offending value and the valid names.
        """
        try:
            return DataSetEncoders[s]
        except KeyError as err:
            valid = ', '.join(member.name for member in DataSetEncoders)
            raise ValueError(f"Unknown DataSetEncoders '{s}'; expected one of: {valid}") from err

    def node_encoder(self, in_dim: int, emb_dim: int):
        """Create the node encoder.

        ``NONE`` builds an ``EncoderLinear`` from ``in_dim`` to ``emb_dim``;
        ``MOL`` builds an ``AtomEncoder`` of width ``emb_dim`` (``in_dim`` is
        not used in that case).
        """
        if self is DataSetEncoders.NONE:
            return EncoderLinear(in_features=in_dim, out_features=emb_dim)
        if self is DataSetEncoders.MOL:
            return AtomEncoder(emb_dim)
        raise ValueError(f'DataSetEncoders {self.name} not supported')

    def edge_encoder(self, emb_dim: int, model_type):
        """Create the edge encoder, or return None when edges are not encoded.

        Only ``MOL`` combined with a non-GCN ``model_type`` yields a
        ``BondEncoder``; ``model_type`` is only consulted for ``MOL``.
        """
        if self is DataSetEncoders.NONE:
            return None
        if self is DataSetEncoders.MOL:
            return None if model_type.is_gcn() else BondEncoder(emb_dim)
        raise ValueError(f'DataSetEncoders {self.name} not supported')

    def use_encoders(self) -> bool:
        """Return True for every member except ``NONE``."""
        return self is not DataSetEncoders.NONE

After defining various default values for some parameters, the Laplace Positional Embedding node encoder class `LapPENodeEncoder` is a network that injects positional information into the node features via concatenation to the feature vector.

In [None]:
# Hyperparameters read by `LapPENodeEncoder.__init__`.
LAP_DIM_PE = 16 # size of the Laplace PE embedding
LAP_MODEL = 'DeepSet' # PE encoder model, 'Transformer' or 'DeepSet' (renamed from original code)
LAP_LAYERS = 2 # number of layers in the PE encoder model (renamed from original code)
N_HEADS = 4 # number of attention heads in the Transformer PE encoder
POST_LAYERS = 0 # number of MLP layers applied after pooling; 0 disables the MLP
LAP_MAX_FREQS = 10 # number of eigenvectors (frequencies)
LAP_RAW_NORM_TYPE = 'none' # raw PE normalization; 'batchnorm' (case-insensitive) enables BatchNorm1d (renamed from original code)
LAP_PASS_AS_VAR = False # stored as `pass_as_var` on the encoder (renamed from original code)

In [None]:
class LapPENodeEncoder(torch.nn.Module):
    """Laplace Positional Embedding node encoder.

    LapPE of size dim_pe (``LAP_DIM_PE``) will get appended to each node
    feature vector. If `expand_x` is set True and ``dim_emb - dim_pe > 0``,
    original node features will be first linearly projected to
    (dim_emb - dim_pe) size and then concatenated with LapPE; otherwise they
    are concatenated unchanged.

    All other hyperparameters are read from the module-level ``LAP_*``,
    ``N_HEADS`` and ``POST_LAYERS`` constants.

    Args:
        dim_in: Input size of the optional linear projection of `x`
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)

    Raises:
        ValueError: if ``LAP_MODEL`` is neither 'Transformer' nor 'DeepSet',
            or if ``LAP_DIM_PE`` is larger than ``dim_emb``.
    """

    def __init__(self, dim_in, dim_emb, expand_x=True):
        super().__init__()
        dim_pe = LAP_DIM_PE  # Size of Laplace PE embedding
        model_type = LAP_MODEL  # Encoder NN model type for PEs
        if model_type not in ['Transformer', 'DeepSet']:
            raise ValueError(f"Unexpected PE model {model_type}")
        self.model_type = model_type
        n_layers = LAP_LAYERS  # Num. layers in PE encoder model
        n_heads = N_HEADS  # Num. attention heads in Trf PE encoder
        post_n_layers = POST_LAYERS  # Num. layers to apply after pooling
        max_freqs = LAP_MAX_FREQS  # Num. eigenvectors (frequencies)
        norm_type = LAP_RAW_NORM_TYPE.lower()  # Raw PE normalization layer type
        self.pass_as_var = LAP_PASS_AS_VAR  # Pass PE also as a separate variable

        if dim_emb - dim_pe < 0: # formerly 1, but you could have zero feature size
            raise ValueError(f"LapPE size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        # The projection only exists when there is room left next to the PE.
        if expand_x and dim_emb - dim_pe > 0:
            self.linear_x = torch.nn.Linear(dim_in, dim_emb - dim_pe)
        self.expand_x = expand_x and dim_emb - dim_pe > 0

        # Initial projection of eigenvalue and the node's eigenvector value
        self.linear_A = torch.nn.Linear(2, dim_pe)
        if norm_type == 'batchnorm':
            self.raw_norm = torch.nn.BatchNorm1d(max_freqs)
        else:
            self.raw_norm = None

        activation = torch.nn.ReLU  # register.act_dict[cfg.gnn.act]
        if model_type == 'Transformer':
            # Transformer model for LapPE
            encoder_layer = torch.nn.TransformerEncoderLayer(d_model=dim_pe,
                                                       nhead=n_heads,
                                                       batch_first=True)
            self.pe_encoder = torch.nn.TransformerEncoder(encoder_layer,
                                                    num_layers=n_layers)
        else:
            # DeepSet model for LapPE
            layers = []
            if n_layers == 1:
                layers.append(activation())
            else:
                # NOTE: this replaces the `linear_A` created above with a wider
                # one; the first instance is discarded after having been
                # initialised.
                self.linear_A = torch.nn.Linear(2, 2 * dim_pe)
                layers.append(activation())
                for _ in range(n_layers - 2):
                    layers.append(torch.nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(activation())
                layers.append(torch.nn.Linear(2 * dim_pe, dim_pe))
                layers.append(activation())
            self.pe_encoder = torch.nn.Sequential(*layers)

        self.post_mlp = None
        if post_n_layers > 0:
            # MLP to apply post pooling
            layers = []
            if post_n_layers == 1:
                layers.append(torch.nn.Linear(dim_pe, dim_pe))
                layers.append(activation())
            else:
                layers.append(torch.nn.Linear(dim_pe, 2 * dim_pe))
                layers.append(activation())
                for _ in range(post_n_layers - 2):
                    layers.append(torch.nn.Linear(2 * dim_pe, 2 * dim_pe))
                    layers.append(activation())
                layers.append(torch.nn.Linear(2 * dim_pe, dim_pe))
                layers.append(activation())
            self.post_mlp = torch.nn.Sequential(*layers)

    def forward(self, x, pestat):
        """Encode the Laplacian PE and concatenate it to the node features.

        Args:
            x: Node features.
            pestat: Indexable pair ``(EigVals, EigVecs)``.
                NOTE(review): `EigVals` is concatenated along dim 2 with
                `EigVecs.unsqueeze(2)`, so it is presumably 3-D with a trailing
                dimension of 1 — confirm against the data preprocessing.

        Returns:
            ``x`` (projected through ``linear_x`` when ``expand_x`` is set)
            concatenated with the encoded PE along dimension 1.
        """
        EigVals = pestat[0]
        EigVecs = pestat[1]

        # During training, flip the sign of each eigenvector at random
        # (one +1/-1 factor per eigenvector, shared by all nodes).
        if self.training:
            sign_flip = torch.rand(EigVecs.size(1), device=EigVecs.device)
            sign_flip[sign_flip >= 0.5] = 1.0
            sign_flip[sign_flip < 0.5] = -1.0
            EigVecs = EigVecs * sign_flip.unsqueeze(0)

        pos_enc = torch.cat((EigVecs.unsqueeze(2), EigVals), dim=2) # (Num nodes) x (Num Eigenvectors) x 2
        # NaN entries mark missing eigenpairs; remember them before zeroing.
        empty_mask = torch.isnan(pos_enc)  # (Num nodes) x (Num Eigenvectors) x 2

        pos_enc[empty_mask] = 0  # (Num nodes) x (Num Eigenvectors) x 2
        # `raw_norm` is either a BatchNorm1d module or None.
        if self.raw_norm:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.linear_A(pos_enc)  # (Num nodes) x (Num Eigenvectors) x dim_pe

        # PE encoder: a Transformer or DeepSet model
        if self.model_type == 'Transformer':
            pos_enc = self.pe_encoder(src=pos_enc,
                                      src_key_padding_mask=empty_mask[:, :, 0])
        else:
            pos_enc = self.pe_encoder(pos_enc)

        # Remove masked sequences; must clone before overwriting masked elements
        pos_enc = pos_enc.clone().masked_fill_(empty_mask[:, :, 0].unsqueeze(2), 0.)

        # Sum pooling
        pos_enc = torch.sum(pos_enc, 1, keepdim=False)  # (Num nodes) x dim_pe

        # MLP post pooling
        if self.post_mlp is not None:
            pos_enc = self.post_mlp(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        if self.expand_x:
            h = self.linear_x(x)
        else:
            h = x
        # Concatenate final PEs to input embedding
        x = torch.cat((h, pos_enc), 1)
        return x

Again after defining some default parameter values, we introduce the Kernel-based Positional Embedding node encoder class `KernelPENodeEncoder`. In contrast to the `LapPENodeEncoder`, which uses Laplace eigenvalues and eigenvectors to encode global graph-level structural information, the `KernelPENodeEncoder` relies on precomputed kernel statistics such as Random Walk Structural Encodings.

In [None]:
# Hyperparameters read by `KernelPENodeEncoder.__init__`.
KER_DIM_PE = 28 # size of the kernel-based PE embedding
NUM_RW_STEPS = 20 # input width of the raw normalization and of the first PE encoder layer
KER_MODEL = 'Linear' # PE encoder model, 'mlp' or 'linear' after lower-casing (renamed from original code)
KER_LAYERS = 3 # number of layers in the PE encoder model; only used by the 'mlp' model (renamed from original code)
KER_RAW_NORM_TYPE = 'BatchNorm' # raw PE normalization; 'batchnorm' after lower-casing enables BatchNorm1d (renamed from original code)
KER_PASS_AS_VAR = False # stored as `pass_as_var` on the encoder (renamed from original code)

In [None]:
class KernelPENodeEncoder(torch.nn.Module):
    """Configurable kernel-based Positional Encoding node encoder.

    The choice of which kernel-based statistics to use is configurable through
    setting of `kernel_type` on a subclass; the base class cannot be
    instantiated directly.
    E.g., supported are 'RWSE', 'HKdiagSE', 'ElstaticSE'.

    PE of size `dim_pe` (``KER_DIM_PE``) will get appended to each node feature
    vector. If `expand_x` is set True and ``dim_emb - dim_pe > 0``, original
    node features will be first linearly projected to (dim_emb - dim_pe) size
    and then concatenated with PE; otherwise they are concatenated unchanged.

    All other hyperparameters are read from the module-level ``KER_*`` and
    ``NUM_RW_STEPS`` constants.

    Args:
        dim_in: Input size of the optional linear projection of `x`
        dim_emb: Size of final node embedding
        expand_x: Expand node features `x` from dim_in to (dim_emb - dim_pe)

    Raises:
        ValueError: if `kernel_type` is not set, if ``KER_DIM_PE`` is larger
            than ``dim_emb``, or if ``KER_MODEL`` is neither 'mlp' nor
            'linear' (case-insensitive).
    """

    kernel_type = None  # Instantiated type of the KernelPE, e.g. RWSE

    def __init__(self, dim_in, dim_emb, expand_x=True):
        super().__init__()
        if self.kernel_type is None:
            raise ValueError(f"{self.__class__.__name__} has to be "
                             f"preconfigured by setting 'kernel_type' class "
                             f"variable before calling the constructor.")

        dim_pe = KER_DIM_PE  # Size of the kernel-based PE embedding
        num_rw_steps = NUM_RW_STEPS
        model_type = KER_MODEL.lower()  # Encoder NN model type for PEs
        n_layers = KER_LAYERS  # Num. layers in PE encoder model
        norm_type = KER_RAW_NORM_TYPE.lower()  # Raw PE normalization layer type
        self.pass_as_var = KER_PASS_AS_VAR  # Pass PE also as a separate variable

        if dim_emb - dim_pe < 0:  # formerly 1, but you could have zero feature size
            raise ValueError(f"PE dim size {dim_pe} is too large for "
                             f"desired embedding size of {dim_emb}.")

        # The projection only exists when there is room left next to the PE.
        self.expand_x = expand_x and dim_emb - dim_pe > 0
        if self.expand_x:
            self.linear_x = torch.nn.Linear(dim_in, dim_emb - dim_pe)

        self.raw_norm = torch.nn.BatchNorm1d(num_rw_steps) if norm_type == 'batchnorm' else None

        activation = torch.nn.ReLU  # register.act_dict[cfg.gnn.act]
        if model_type == 'mlp':
            # Layer widths: a single layer maps straight to dim_pe; otherwise the
            # hidden layers are 2 * dim_pe wide and the last one maps to dim_pe.
            if n_layers == 1:
                widths = [num_rw_steps, dim_pe]
            else:
                widths = [num_rw_steps, 2 * dim_pe] + [2 * dim_pe] * (n_layers - 2) + [dim_pe]
            layers = []
            for fan_in, fan_out in zip(widths[:-1], widths[1:]):
                layers.append(torch.nn.Linear(fan_in, fan_out))
                layers.append(activation())
            self.pe_encoder = torch.nn.Sequential(*layers)
        elif model_type == 'linear':
            self.pe_encoder = torch.nn.Linear(num_rw_steps, dim_pe)
        else:
            raise ValueError(f"{self.__class__.__name__}: Does not support "
                             f"'{model_type}' encoder model.")

    def forward(self, x, pestat):
        """Encode the precomputed kernel statistics and concatenate them to ``x``.

        Args:
            x: Node features.
            pestat: Precomputed kernel statistics, one row per node.

        Returns:
            ``x`` (projected through ``linear_x`` when ``expand_x`` is set)
            concatenated with the encoded PE along dimension 1.
        """
        pos_enc = pestat  # (Num nodes) x (Num kernel times)
        if self.raw_norm is not None:
            pos_enc = self.raw_norm(pos_enc)
        pos_enc = self.pe_encoder(pos_enc)  # (Num nodes) x dim_pe

        # Expand node features if needed
        h = self.linear_x(x) if self.expand_x else x
        # Concatenate final PEs to input embedding
        return torch.cat((h, pos_enc), 1)

The `RWSENodeEncoder` is then just an implementation of a specific type of `KernelPENodeEncoder`.

In [None]:
class RWSENodeEncoder(KernelPENodeEncoder):
    """Random Walk Structural Encoding node encoder.

    Thin specialisation of `KernelPENodeEncoder`: it adds no behaviour of its
    own and only sets the `kernel_type` class attribute to 'RWSE'.
    """
    # Tag naming the kernel; presumably read by the base class -- confirm
    # against `KernelPENodeEncoder`.
    kernel_type = 'RWSE'

The `PosEncoder` class provides an abstraction for the various encoder classes defined earlier. With it we can do the following:

- `get()` returns an instance of the encoder.

- `DIM_PE()` returns the dimensionality of the encoding vector.

- `get_pe()` returns the positional encoding vector of each node in a graph.

In [None]:
class PosEncoder(Enum):
    """Selector for the positional/structural encoding applied to the nodes.

    NONE disables positional encodings, LAP selects the Laplacian
    eigenvector encoder and RWSE the random-walk structural encoder.
    """
    NONE = auto()
    LAP = auto()
    RWSE = auto()

    @staticmethod
    def from_string(s: str):
        """Return the member whose name equals `s`.

        Raises:
            ValueError: If `s` is not the name of a member. The message lists
                the accepted names.
        """
        try:
            return PosEncoder[s]
        except KeyError:
            options = ', '.join(member.name for member in PosEncoder)
            raise ValueError(f"Unknown PosEncoder '{s}'; expected one of: {options}") from None

    def get(self, in_dim: int, emb_dim: int, expand_x: bool = True):
        """Instantiate the encoder module selected by this member.

        Args:
            in_dim: Dimensionality of the raw node features.
            emb_dim: Dimensionality of the embedding the encoder should emit.
            expand_x: Forwarded to the encoder. Defaults to True so that a
                stand-alone encoder can be built from `in_dim`/`emb_dim` alone;
                callers that already pass the flag are unaffected.

        Returns:
            None for NONE, otherwise a `LapPENodeEncoder` or `RWSENodeEncoder`.
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LapPENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        elif self is PosEncoder.RWSE:
            return RWSENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

    def DIM_PE(self):
        """Return the positional-encoding width configured for this member.

        Returns:
            None for NONE, `LAP_DIM_PE` for LAP and `KER_DIM_PE` for RWSE.
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LAP_DIM_PE
        elif self is PosEncoder.RWSE:
            return KER_DIM_PE
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

    def get_pe(self, data: Data, device):
        """Fetch the precomputed encoding statistics of `data` on `device`.

        Returns:
            None for NONE; the list `[EigVals, EigVecs]` for LAP; the
            `pestat_RWSE` tensor for RWSE.
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return [data.EigVals.to(device), data.EigVecs.to(device)]
        elif self is PosEncoder.RWSE:
            return data.pestat_RWSE.to(device)
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

The `EnvArgs` class provides a container for the architecture-related hyperparameters. Its `load_net()` function builds the full model.

In [None]:
class EnvArgs(NamedTuple):
    """Architecture hyperparameters of the environment network.

    `load_net()` assembles the node encoder, the message-passing layers and
    the decoder from these values.
    """
    model_type: ModelType  # provides `get_component_list` for the message-passing layers
    num_layers: int  # forwarded to `get_component_list`
    env_dim: int  # embedding width used as input, hidden and output size of the layers

    # The following three flags are not read by `load_net`; presumably the
    # model that owns the network consumes them -- confirm against the caller.
    layer_norm: bool
    skip: bool
    batch_norm: bool
    dropout: float  # dropout probability inside the decoder MLP
    act_type: ActivationType  # `act_type.nn()` supplies the decoder activation module
    dec_num_layers: int  # number of linear layers in the decoder
    pos_enc: PosEncoder  # positional encoder selection
    dataset_encoders: DataSetEncoders  # dataset-specific node encoder selection

    metric_type: MetricType  # not read by `load_net`
    in_dim: int  # raw node feature dimensionality
    out_dim: int  # decoder output dimensionality

    gin_mlp_func: Callable  # forwarded to `get_component_list`

    def load_net(self) -> torch.nn.ModuleList:
        """Build the network as `[encoder] + message-passing layers + [decoder]`.

        Returns:
            A `ModuleList` whose first entry is the node encoder, whose last
            entry is the decoder, and whose middle entries are the layers
            returned by `model_type.get_component_list`.
        """
        if self.pos_enc is PosEncoder.NONE:
            enc_list = [self.dataset_encoders.node_encoder(in_dim=self.in_dim, emb_dim=self.env_dim)]
        elif self.dataset_encoders is DataSetEncoders.NONE:
            # `PosEncoder.get` requires `expand_x`; omitting it raised a
            # TypeError. A stand-alone positional encoder must expand the raw
            # features so that its output width is `env_dim`, which is the
            # input width of the message-passing layers built below.
            enc_list = [self.pos_enc.get(in_dim=self.in_dim, emb_dim=self.env_dim, expand_x=True)]
        else:
            enc_list = [Concat2NodeEncoder(enc1_cls=self.dataset_encoders.node_encoder,
                                           enc2_cls=self.pos_enc.get,
                                           in_dim=self.in_dim, emb_dim=self.env_dim,
                                           enc2_dim_pe=self.pos_enc.DIM_PE())]

        component_list = self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.env_dim,
                                                            out_dim=self.env_dim, num_layers=self.num_layers,
                                                            bias=True, edges_required=True,
                                                            gin_mlp_func=self.gin_mlp_func)

        if self.dec_num_layers > 1:
            # Create fresh modules for every hidden block. Multiplying a list
            # of modules repeats the *same* instances, which silently tied the
            # weights of all hidden decoder layers together.
            mlp_list = []
            for _ in range(self.dec_num_layers - 1):
                mlp_list += [torch.nn.Linear(self.env_dim, self.env_dim),
                             torch.nn.Dropout(self.dropout), self.act_type.nn()]
            mlp_list.append(torch.nn.Linear(self.env_dim, self.out_dim))
            dec_list = [torch.nn.Sequential(*mlp_list)]
        else:
            dec_list = [torch.nn.Linear(self.env_dim, self.out_dim)]

        return torch.nn.ModuleList(enc_list + component_list + dec_list)

The "policy" of the network, i.e. which state to choose for each node, is what the Action Network is trained to produce. To this end, we define the `ActionNetArgs` class which, similarly to the `EnvArgs` class, holds the hyperparameters of the Action Network; its `load_net()` function builds the network's layers.

In [None]:
class ActionNetArgs(NamedTuple):
    """Hyperparameters of the action network plus a factory for its layers."""
    model_type: ModelType  # provides `get_component_list`
    num_layers: int  # number of layers requested from `get_component_list`
    hidden_dim: int  # hidden width of the action network

    dropout: float
    act_type: ActivationType

    env_dim: int  # input width of the first action layer
    gin_mlp_func: Callable

    def load_net(self) -> torch.nn.ModuleList:
        """Create the action network layers (two output channels) as a `ModuleList`."""
        layer_kwargs = dict(in_dim=self.env_dim, hidden_dim=self.hidden_dim, out_dim=2,
                            num_layers=self.num_layers, bias=True, edges_required=False,
                            gin_mlp_func=self.gin_mlp_func)
        return torch.nn.ModuleList(self.model_type.get_component_list(**layer_kwargs))

Having defined `ActionNetArgs`, we can now implement the `ActionNet` class. It builds the network from the arguments: the layers come from the arguments' `load_net()` function, and the dropout rate and activation function are taken from the arguments as well. The `forward` function simply computes the output of the model.

In [None]:
class ActionNet(torch.nn.Module):
    """Network representing the agent's policy."""

    def __init__(self, action_args: ActionNetArgs):
        """Build the layers, dropout and activation described by `action_args`."""
        super().__init__()
        self.num_layers = action_args.num_layers
        self.net = action_args.load_net()
        self.dropout = torch.nn.Dropout(action_args.dropout)
        self.act = action_args.act_type.get()

    def forward(self, x: Tensor, edge_index: Adj, env_edge_attr: OptTensor, act_edge_attr: OptTensor) -> Tensor:
        """Run the layers in order and return the output of the last one.

        The first layer receives `env_edge_attr`; every later layer receives
        `act_edge_attr`. Dropout and the activation follow every layer except
        the last.
        """
        per_layer_attr = [env_edge_attr] + [act_edge_attr] * (self.num_layers - 1)
        *hidden_attrs, final_attr = per_layer_attr
        for layer, edge_attr in zip(self.net[:-1], hidden_attrs):
            x = self.act(self.dropout(layer(x=x, edge_index=edge_index, edge_attr=edge_attr)))
        return self.net[-1](x=x, edge_index=edge_index, edge_attr=final_attr)

---

## Preparing the Dataset

The `DatasetBySplit` class is a simple container that holds the data in training, validation and test splits.

In [None]:
class DatasetBySplit(NamedTuple):
    """Container holding the training, validation and test portions of a dataset.

    Each field is either a single `Data` object or a list of `Data` objects.
    """
    train: Union[Data, List[Data]]  # training split
    val: Union[Data, List[Data]]  # validation split
    test: Union[Data, List[Data]]  # test split

The `DataSetFamily` class is a simple container grouping the different datasets into predefined categories

In [None]:
class DataSetFamily(Enum):
    """Predefined categories used to group the datasets."""
    heterophilic = auto()
    synthetic = auto()
    social_networks = auto()
    proteins = auto()
    lrgb = auto()
    homophilic = auto()

    @staticmethod
    def from_string(s: str):
        """Return the member whose name equals `s`.

        Raises:
            ValueError: If `s` is not the name of a member. The message lists
                the accepted names instead of being empty.
        """
        try:
            return DataSetFamily[s]
        except KeyError:
            options = ', '.join(family.name for family in DataSetFamily)
            raise ValueError(f"Unknown DataSetFamily '{s}'; expected one of: {options}") from None

Before defining the final `DataSet` class, we need to implement two helper functions, `get_cosine_schedule_with_warmup` and `cosine_with_warmup_scheduler`. The latter is a simple wrapper, whereas the former returns a `LambdaLR` scheduler whose learning-rate multiplier grows linearly while the current step is smaller than the number of warmup steps and follows a cosine decay afterwards.

In [None]:
def get_cosine_schedule_with_warmup(
        optimizer: torch.optim.Optimizer, num_warmup_steps: int, num_training_steps: int,
        num_cycles: float = 0.5, last_epoch: int = -1):
    """
    Cosine learning-rate schedule with linear warmup.

    Implementation by Huggingface:
    https://github.com/huggingface/transformers/blob/v4.16.2/src/transformers/optimization.py

    During the first `num_warmup_steps` steps the learning rate rises linearly
    from (almost) 0 to the initial lr set in the optimizer; afterwards it
    follows the cosine function from the initial lr down to 0.
    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of waves in the cosine schedule (the default is to just
            decrease from the max value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Both spans are clamped to at least one step to avoid division by zero.
    warmup_span = float(max(1, num_warmup_steps))
    decay_span = float(max(1, num_training_steps - num_warmup_steps))

    def lr_multiplier(step):
        if step >= num_warmup_steps:
            fraction_done = float(step - num_warmup_steps) / decay_span
            cosine = math.cos(math.pi * float(num_cycles) * 2.0 * fraction_done)
            return max(0.0, 0.5 * (1.0 + cosine))
        # Linear warmup, floored at 1e-6 so the very first step is not exactly 0.
        return max(1e-6, float(step) / warmup_span)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_multiplier, last_epoch)

In [None]:
def cosine_with_warmup_scheduler(optimizer: torch.optim.Optimizer,
                                 num_warmup_epochs: int, max_epoch: int):
    """Epoch-based wrapper around `get_cosine_schedule_with_warmup`.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        num_warmup_epochs: Passed on as the number of warmup steps.
        max_epoch: Passed on as the total number of training steps.

    Returns:
        The scheduler created by `get_cosine_schedule_with_warmup`.
    """
    return get_cosine_schedule_with_warmup(optimizer=optimizer,
                                           num_warmup_steps=num_warmup_epochs,
                                           num_training_steps=max_epoch)

One final set of helper functions that we need to set up is `apply_transform` and its subfunctions. We start with `pre_transform_in_memory`. It applies a transformation function to every graph of a dataset, drops results that are `None`, and stores the transformed graphs back in the dataset.

In [None]:
def pre_transform_in_memory(dataset, transform_func, show_progress=False):
    """
    Pre-transform already loaded PyG dataset object.

    Apply transform function to a loaded PyG dataset object so that
    the transformed result is persistent for the lifespan of the object.
    This means the result is not saved to disk, as what PyG's `pre_transform`
    would do, but also the transform is applied only once and not at each
    data access as what PyG's `transform` hook does.

    Implementation is based on torch_geometric.data.in_memory_dataset.copy

    Args:
        dataset: PyG dataset object to modify
        transform_func: transformation function to apply to each data example;
            if None, the dataset is returned untouched
        show_progress: show tqdm progress bar

    Returns:
        The same dataset object, now holding the transformed examples.
        Examples for which `transform_func` returned a falsy value (e.g. None)
        are dropped.
    """
    if transform_func is None:
        return dataset

    # The notebook imports the `tqdm` *module* (`import tqdm`), so calling
    # `tqdm(...)` raised "TypeError: 'module' object is not callable".
    # Importing the progress-bar class locally is correct regardless of what
    # the global name `tqdm` is bound to.
    from tqdm import tqdm as progress_bar

    num_examples = len(dataset)
    data_list = [transform_func(dataset.get(i))
                 for i in progress_bar(range(num_examples),
                                       disable=not show_progress,
                                       mininterval=10,
                                       miniters=num_examples // 20)]
    data_list = list(filter(None, data_list))

    # Reset the index view and replace the collated storage with the new examples.
    dataset._indices = None
    dataset._data_list = data_list
    dataset.data, dataset.slices = dataset.collate(data_list)
    return dataset

Next up is the `eigvec_normalizer` function which allows for different methods of normalization of eigenvectors.

In [None]:
def eigvec_normalizer(EigVecs, EigVals, normalization="L2", eps=1e-12):
    """
    Normalize eigenvectors column-wise with one of several schemes.

    Supported values of `normalization`: "L1", "L2", "abs-max", "wavelength",
    "wavelength-asin" and "wavelength-soft". Denominators are clamped to at
    least `eps` before dividing.

    Raises:
        ValueError: If `normalization` is not one of the supported schemes.
    """

    def column_abs_max(vecs):
        # max|eigvec| per column, kept as a row vector
        return torch.max(vecs.abs(), dim=0, keepdim=True).values

    def guarded_sqrt(vals):
        # sqrt(eigval), with (near-)zero eigenvalues mapped to 1
        root = torch.sqrt(vals)
        root[vals < eps] = 1  # Problem with eigval = 0
        return root

    EigVals = EigVals.unsqueeze(0)

    if normalization == "L1":
        # eigvec / sum(abs(eigvec))
        denom = EigVecs.norm(p=1, dim=0, keepdim=True)
    elif normalization == "L2":
        # eigvec / sqrt(sum(eigvec^2))
        denom = EigVecs.norm(p=2, dim=0, keepdim=True)
    elif normalization == "abs-max":
        # eigvec / max|eigvec|
        denom = column_abs_max(EigVecs)
    elif normalization == "wavelength":
        # eigvec * pi / (2 * max|eigvec| * sqrt(eigval))
        denom = column_abs_max(EigVecs) * guarded_sqrt(EigVals) * 2 / np.pi
    elif normalization == "wavelength-asin":
        # arcsin(eigvec / max|eigvec|)  /  sqrt(eigval)
        scale = column_abs_max(EigVecs).clamp_min(eps).expand_as(EigVecs)
        EigVecs = torch.asin(EigVecs / scale)
        denom = guarded_sqrt(EigVals)
    elif normalization == "wavelength-soft":
        # eigvec / (softmax-weighted |eigvec| * sqrt(eigval))
        magnitudes = EigVecs.abs()
        soft_max = (torch.nn.functional.softmax(magnitudes, dim=0) * EigVecs.abs()).sum(dim=0, keepdim=True)
        denom = soft_max * guarded_sqrt(EigVals)
    else:
        raise ValueError(f"Unsupported normalization `{normalization}`")

    return EigVecs / denom.clamp_min(eps).expand_as(EigVecs)

Now, after defining some default constants, we implement the `get_lap_decomp_stats` function, which computes the Laplacian eigendecomposition-based positional encoding statistics of a given graph and returns the corresponding eigenvalues and eigenvectors.

In [None]:
# Laplacian normalization used for the LapPE eigendecomposition; the string
# 'none' (case-insensitive) is mapped to `None` in `compute_posenc_stats`.
LAPLACIAN_NORM = 'none'
# Maximum number of eigenvalue/eigenvector pairs kept per graph.
POSENC_MAX_FREQS = 10
# Eigenvector normalization scheme handed to `get_lap_decomp_stats`.
EIGVEC_NORM = 'L2'
#KERNEL = ''
# Random-walk step counts for which RWSE landing probabilities are computed.
TIMES = list(range(1, 21))

In [None]:
def get_lap_decomp_stats(evals, evects, max_freqs, eigvec_norm='L2'):
    """Compute Laplacian eigen-decomposition-based PE stats of the given graph.

    Args:
        evals, evects: Precomputed eigen-decomposition
        max_freqs: Maximum number of top smallest frequencies / eigenvecs to use
        eigvec_norm: Normalization for the eigen vectors of the Laplacian
    Returns:
        Tensor (num_nodes, max_freqs, 1) eigenvalues repeated for each node
        Tensor (num_nodes, max_freqs) of eigenvector values per node
    """
    N = len(evals)  # Number of nodes, including disconnected nodes.

    # Select the `max_freqs` smallest eigenvalues and their eigenvectors.
    keep = evals.argsort()[:max_freqs]
    evals = torch.from_numpy(np.real(evals[keep])).clamp_min(0)
    evects = torch.from_numpy(np.real(evects[:, keep])).float()

    EigVecs = eigvec_normalizer(evects, evals, normalization=eigvec_norm)
    EigVals = evals

    # Graphs with fewer nodes than requested frequencies are NaN-padded on the right.
    num_missing = max_freqs - N
    if num_missing > 0:
        EigVecs = torch.nn.functional.pad(EigVecs, (0, num_missing), value=float('nan'))
        EigVals = torch.nn.functional.pad(EigVals, (0, num_missing), value=float('nan'))

    # Repeat the eigenvalue row for every node and add a trailing axis.
    EigVals = EigVals.unsqueeze(0).repeat(N, 1).unsqueeze(2)

    return EigVals, EigVecs

The next thing we need is a function `get_rw_landing_probs` which computes, for every node, the probability that a random walk starting at that node is back at it after k steps, for each k in a given list of step counts.

In [None]:
def get_rw_landing_probs(ksteps, edge_index, edge_weight=None,
                         num_nodes=None, space_dim=0):
    """Compute Random Walk landing probabilities for given list of K steps.

    Args:
        ksteps: Non-empty list of k-steps for which to compute the RW landings
        edge_index: PyG sparse representation of the graph
        edge_weight: (optional) 1D tensor of edge weights; defaults to all ones.
            The weights enter both the degrees and the adjacency matrix.
        num_nodes: (optional) Number of nodes in the graph
        space_dim: (optional) Estimated dimensionality of the space. Used to
            correct the random-walk diagonal by a factor `k^(space_dim/2)`.
            In euclidean space, this correction means that the height of
            the gaussian distribution stays almost constant across the number of
            steps, if `space_dim` is the dimension of the euclidean space.

    Returns:
        2D Tensor with shape (num_nodes, len(ksteps)) with RW landing probs
    """
    if edge_weight is None:
        edge_weight = torch.ones(edge_index.size(1), device=edge_index.device)
    num_nodes = maybe_num_nodes(edge_index, num_nodes)
    source = edge_index[0]
    deg = scatter(edge_weight, source, dim=0, dim_size=num_nodes, reduce='sum')  # Out degrees.
    deg_inv = deg.pow(-1.)
    deg_inv.masked_fill_(deg_inv == float('inf'), 0)  # Isolated nodes: 1/0 -> 0.

    if edge_index.numel() == 0:
        # Use the weights' dtype rather than the integer dtype of `edge_index`,
        # so the edgeless case has the same dtype as the regular one.
        P = edge_weight.new_zeros((1, num_nodes, num_nodes))
    else:
        # P = D^-1 * A. The adjacency must carry the same weights that were
        # summed into the degrees, otherwise the rows of P are not normalized
        # for weighted graphs. With the default all-ones weights this equals
        # the unweighted adjacency.
        adj = to_dense_adj(edge_index, edge_attr=edge_weight, max_num_nodes=num_nodes)
        P = torch.diag(deg_inv) @ adj  # 1 x (Num nodes) x (Num nodes)

    rws = []
    if ksteps == list(range(min(ksteps), max(ksteps) + 1)):
        # Efficient way if ksteps are a consecutive sequence (most of the time the case)
        Pk = P.clone().detach().matrix_power(min(ksteps))
        for k in range(min(ksteps), max(ksteps) + 1):
            rws.append(torch.diagonal(Pk, dim1=-2, dim2=-1) * (k ** (space_dim / 2)))
            Pk = Pk @ P
    else:
        # Explicitly raising P to power k for each k \in ksteps.
        for k in ksteps:
            rws.append(torch.diagonal(P.matrix_power(k), dim1=-2, dim2=-1) *
                       (k ** (space_dim / 2)))

    rw_landing = torch.cat(rws, dim=0).transpose(0, 1)  # (Num nodes) x (K steps)

    return rw_landing

The two functions we just implemented will now be useful when coding the `compute_posenc_stats` which enables us to compute the positional encoding statistics for a graph.

In [None]:
def compute_posenc_stats(data, pos_encoder: PosEncoder, is_undirected):
    """Precompute positional encodings for the given graph.

    Supported PE statistics to precompute, selected by `pos_encoder`:
    PosEncoder.LAP: Laplacian eigen-decomposition, stored as `data.EigVals`
        and `data.EigVecs`.
    PosEncoder.RWSE: Random walk landing probabilities (diagonals of RW
        matrices), stored as `data.pestat_RWSE`.
    Any other value leaves `data` unchanged.

    Args:
        data: PyG graph
        pos_encoder: Which positional encoding to precompute
        is_undirected: True if the graph is expected to be undirected

    Returns:
        Extended PyG Data object.

    Raises:
        ValueError: If RWSE is requested but `TIMES` is empty.
    """
    if hasattr(data, 'num_nodes'):
        N = data.num_nodes  # Explicitly given number of nodes, e.g. ogbg-ppa
    else:
        N = data.x.shape[0]  # Number of nodes, including disconnected nodes.

    if pos_encoder is PosEncoder.LAP:
        # The Laplacian settings and the undirected edge index are only needed
        # here, so they are prepared inside this branch.
        laplacian_norm_type = LAPLACIAN_NORM.lower()
        if laplacian_norm_type == 'none':
            laplacian_norm_type = None
        if is_undirected:
            undir_edge_index = data.edge_index
        else:
            undir_edge_index = to_undirected(data.edge_index)

        # Eigen-decomposition with numpy.
        L = to_scipy_sparse_matrix(
            *get_laplacian(undir_edge_index, normalization=laplacian_norm_type,
                           num_nodes=N)
        )
        evals, evects = np.linalg.eigh(L.toarray())
        data.EigVals, data.EigVecs = get_lap_decomp_stats(
            evals=evals, evects=evects,
            max_freqs=POSENC_MAX_FREQS,
            eigvec_norm=EIGVEC_NORM)
    elif pos_encoder is PosEncoder.RWSE:
        if len(TIMES) == 0:
            raise ValueError("List of kernel times required for RWSE")
        data.pestat_RWSE = get_rw_landing_probs(ksteps=TIMES,
                                                edge_index=data.edge_index,
                                                num_nodes=N)

    return data

The constant `TASK` defines whether the task at hand is graph-level or node-level.

In [None]:
# Task level read by `set_dataset_splits`: 'graph' or 'node'.
TASK = 'graph'

The `set_dataset_splits` function sets the training, validation and test splits for a given dataset object and depending on graph/node-level tasks, ensuring that there are no overlaps between the splits.

In [None]:
def set_dataset_splits(dataset, splits):
    """Set given splits to the dataset object.

    Args:
        dataset: PyG dataset object
        splits: List of train/val/test split indices (lists, numpy arrays or
            tensors)

    Raises:
        ValueError: If any pair of splits has intersecting indices, or if the
            task level in `TASK` is not supported
    """
    # Materialize every split once as a set of plain Python values.
    # `set(tensor)` would contain 0-d tensors, which hash by identity, so
    # overlapping tensor splits would never be detected; `.tolist()` yields
    # comparable Python ints for both tensors and numpy arrays.
    index_sets = [set(split.tolist()) if hasattr(split, 'tolist') else set(split)
                  for split in splits]

    # First check whether splits intersect and raise error if so.
    for i in range(len(splits) - 1):
        for j in range(i + 1, len(splits)):
            n_intersect = len(index_sets[i] & index_sets[j])
            if n_intersect != 0:
                raise ValueError(
                    f"Splits must not have intersecting indices: "
                    f"split #{i} (n = {len(splits[i])}) and "
                    f"split #{j} (n = {len(splits[j])}) have "
                    f"{n_intersect} intersecting indices"
                )

    task_level = TASK
    if task_level == 'node':
        # Node-level task: store one boolean mask per split.
        split_names = ['train_mask', 'val_mask', 'test_mask']
        for split_name, split_index in zip(split_names, splits):
            mask = index2mask(split_index, size=dataset.data.y.shape[0])
            set_dataset_attr(dataset, split_name, mask, len(mask))

    elif task_level == 'graph':
        # Graph-level task: store the graph indices of each split.
        split_names = [
            'train_graph_index', 'val_graph_index', 'test_graph_index'
        ]
        for split_name, split_index in zip(split_names, splits):
            set_dataset_attr(dataset, split_name, split_index, len(split_index))

    else:
        raise ValueError(f"Unsupported dataset task level: {task_level}")

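Now we can finally implement the `apply_transform` function. It takes a dataset and a positional encoder, precomputes the corresponding positional encoding statistics for every graph in the dataset, and sets the standard train/validation/test splits if the dataset provides them.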

In [None]:
def apply_transform(dataset, pos_encoder):
    """Precompute positional-encoding statistics for all graphs of `dataset`.

    Args:
        dataset: PyG dataset object; modified and returned.
        pos_encoder: Positional encoder selection passed to
            `compute_posenc_stats`.

    Returns:
        The dataset with the precomputed statistics. If it carried a
        `split_idxs` attribute, the splits are registered via
        `set_dataset_splits` and the attribute is removed.
    """
    t_begin = time.perf_counter()
    logging.info("Precomputing Positional Encoding statistics for all graphs... ")

    # Estimate directedness based on 10 graphs to save time.
    is_undirected = all(graph.is_undirected() for graph in dataset[:10])
    logging.info(f"  ...estimated to be undirected: {is_undirected}")

    posenc_transform = partial(compute_posenc_stats, pos_encoder=pos_encoder, is_undirected=is_undirected)
    dataset = pre_transform_in_memory(dataset, posenc_transform, show_progress=True)

    # Format the elapsed time as HH:MM:SS followed by the hundredths of a second.
    elapsed = time.perf_counter() - t_begin
    clock = time.strftime('%H:%M:%S', time.gmtime(elapsed))
    hundredths = f'{elapsed:.2f}'[-3:]
    logging.info(f"Done! Took {clock + hundredths}")

    # Set standard dataset train/val/test splits
    if hasattr(dataset, 'split_idxs'):
        set_dataset_splits(dataset, dataset.split_idxs)
        delattr(dataset, 'split_idxs')
    return dataset

---

## Setting up the Dataset types

### Root Neighbours

We start by implementing the `RootNeighboursDataset`. It has the following methods:

- `get()` returns the generated data.

- `create_data()` returns the generated data, to be called at instantiation.

- `mask_task()` returns the training, validation and testing masks that shall be used.

- `generate_component()` returns one generated graph component, to be called in `create_data()`.

- `initialize_constants()` returns a dictionary with various dataset parameters.

- `generate_fold()` generates a single graph topology sample.

In [None]:
class RootNeighboursDataset(object):
    """Synthetic dataset of small two-level trees generated from a seeded RNG.

    Every component is made of three independently generated trees (one each
    for training, validation and testing). Only the root of a tree carries a
    non-zero label: the mean feature vector of its "hub" children.
    All random draws come from `self.generator`, so the generated data depends
    on the seed and on the order in which the draws are made.
    """

    def __init__(self, seed: int, print_flag: bool = False):
        """Seed the generator and build the whole dataset eagerly.

        Args:
            seed: Seed for the torch generator used for all random draws.
            print_flag: Stored as `self.plot_flag`; not read anywhere in this class.
        """
        super().__init__()
        self.seed = seed
        self.plot_flag = print_flag
        self.generator = torch.Generator().manual_seed(seed)
        self.constants_dict = self.initialize_constants()

        self._data = self.create_data()

    def get(self) -> Data:
        """Return the dataset built at construction time."""
        return self._data

    def create_data(self) -> Data:
        """Generate `NUM_COMPONENTS` components and merge them into one batch."""
        # Each component bundles a train, a val and a test tree.
        data_list = []
        for num in range(self.constants_dict['NUM_COMPONENTS']):
            data_list.append(self.generate_component())
        return Batch.from_data_list(data_list)

    def mask_task(self, num_nodes_per_fold: List[int]) -> Tuple[Tensor, Tensor, Tensor]:
        """Build the train/val/test node masks of one component.

        Args:
            num_nodes_per_fold: Node counts of the train, val and test trees,
                in that order.

        Returns:
            Three boolean masks over all nodes of the component, each with a
            single True entry at the first node of the respective tree.
        """
        num_nodes = sum(num_nodes_per_fold)
        train_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)
        val_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)
        test_mask = torch.zeros(size=(num_nodes,), dtype=torch.bool)

        # The first node of each tree is its root (node 0 in `generate_fold`);
        # this assumes the trees' nodes are concatenated in fold order.
        train_mask[0] = True
        val_mask[num_nodes_per_fold[0]] = True
        test_mask[num_nodes_per_fold[0] + num_nodes_per_fold[1]] = True
        return train_mask, val_mask, test_mask

    def generate_component(self) -> Data:
        """Generate one component consisting of a train, a val and a test tree."""
        data_per_fold, num_nodes_per_fold = [], []
        for fold_idx in range(3):
            # Only the first (training) tree is generated with eval=False.
            data = self.generate_fold(eval=(fold_idx != 0))
            num_nodes_per_fold.append(data.x.shape[0])
            data_per_fold.append(data)

        train_mask, val_mask, test_mask = self.mask_task(num_nodes_per_fold=num_nodes_per_fold)

        batch = Batch.from_data_list(data_per_fold)
        return Data(x=batch.x, edge_index=batch.edge_index, y=batch.y, train_mask=train_mask, val_mask=val_mask, test_mask=test_mask)

    def initialize_constants(self) -> Dict[str, int]:
        """Return the generation parameters of the dataset."""
        return {'NUM_COMPONENTS': 1000, 'MAX_HUBS': 3, 'MAX_1HOP_NEIGHBORS': 10, 'ADD_HUBS': 2, 'HUB_NEIGHBORS': 5, 'MAX_2HOP_NEIGHBORS': 3, 'NUM_FEATURES': 5}

    def generate_fold(self, eval: bool) -> Data:
        """Generate a single tree: a root, its 1-hop children and their children.

        Args:
            eval: If True, `ADD_HUBS` extra hubs are added on top of the
                randomly drawn number of hubs.

        Returns:
            `Data` with random features in [-2, 2), undirected edges and labels
            that are zero everywhere except at the root (node 0).
        """
        constant_dict = self.initialize_constants()
        MAX_HUBS, MAX_1HOP_NEIGHBORS, ADD_HUBS, HUB_NEIGHBORS, MAX_2HOP_NEIGHBORS, NUM_FEATURES =\
            [constant_dict[key] for key in ['MAX_HUBS', 'MAX_1HOP_NEIGHBORS', 'ADD_HUBS', 'HUB_NEIGHBORS', 'MAX_2HOP_NEIGHBORS', 'NUM_FEATURES']]

        assert MAX_HUBS + ADD_HUBS <= MAX_1HOP_NEIGHBORS
        add_hubs = ADD_HUBS if eval else 0
        num_hubs = torch.randint(1, MAX_HUBS + 1, size=(1,), generator=self.generator).item() + add_hubs
        num_1hop_neighbors = torch.randint(MAX_HUBS + add_hubs, MAX_1HOP_NEIGHBORS + 1, size=(1,), generator=self.generator).item()
        assert num_hubs <= num_1hop_neighbors

        # Child counts: hubs get exactly HUB_NEIGHBORS children, the remaining
        # 1-hop nodes a random count.
        # NOTE(review): the upper bound of randint is exclusive, so non-hub
        # nodes get between 1 and MAX_2HOP_NEIGHBORS - 1 children -- confirm
        # this is intended.
        list_num_2hop_neighbors = torch.randint(1, MAX_2HOP_NEIGHBORS, size=(num_1hop_neighbors - num_hubs,), generator=self.generator).tolist()
        list_num_2hop_neighbors = [HUB_NEIGHBORS] * num_hubs + list_num_2hop_neighbors

        # 2 hop edge index
        num_nodes = 1  # root node is 0
        idx_1hop_neighbors = []
        list_edge_index = []
        for num_2hop_neighbors in list_num_2hop_neighbors:
            # The next free node id becomes a 1-hop neighbour; its children
            # take the ids directly after it.
            idx_1hop_neighbors.append(num_nodes)
            if num_2hop_neighbors > 0:
                # Star edges from the 1-hop node (local id 0) to each of its children.
                clique_edge_index = torch.tensor([[0] * num_2hop_neighbors, list(range(1, num_2hop_neighbors + 1))])
                # clique_edge_index = torch.combinations(torch.arange(num_2hop_neighbors), r=2).T
                list_edge_index.append(clique_edge_index + num_nodes)

            num_nodes += num_2hop_neighbors + 1

        # 1 hop edge index
        idx_0hop = torch.tensor([0] * num_1hop_neighbors)
        idx_1hop_neighbors = torch.tensor(idx_1hop_neighbors)
        # The hubs are the first `num_hubs` 1-hop neighbours.
        hubs = idx_1hop_neighbors[:num_hubs]
        list_edge_index.append(torch.stack((idx_0hop, idx_1hop_neighbors), dim=0))
        edge_index = torch.cat(list_edge_index, dim=1)

        # undirect
        edge_index_other_direction = torch.stack((edge_index[1], edge_index[0]), dim=0)
        edge_index = torch.cat((edge_index_other_direction, edge_index), dim=1)

        # features
        x = 4 * torch.rand(size=(num_nodes, NUM_FEATURES), generator=self.generator) - 2

        # labels: only the root gets a target, the mean of its hubs' features
        y = torch.zeros_like(x)
        y[0] = torch.mean(x[hubs], dim=0)
        return Data(x=x, edge_index=edge_index, y=y)

### Cycles

In order to set up the `CyclesDataset`, we first need a function `make_undirected` which simply turns a directed graph into an undirected one by concatenating the inverse of the list of directed edges to itself.

In [None]:
def make_undirected(edge_index: Tensor) -> Tensor:
    """Return the reversed edges followed by the original edges.

    Args:
        edge_index: Tensor whose rows 0 and 1 hold the source and target nodes.

    Returns:
        Tensor with twice as many columns: first every edge with source and
        target swapped, then the input edges unchanged.
    """
    reversed_edges = edge_index[[1, 0]]
    return torch.cat((reversed_edges, edge_index), dim=1)

Now, the other helper function we need is `create_cycle`, which creates two graphs for each value of `cycle_size`. One is a single 0 - ... - (n-1) - 0 cycle; the other consists of two disjoint cycles on the same nodes, a triangle 0 - 1 - 2 - 0 and a cycle 3 - ... - (n-1) - 3.

In [None]:
def create_cycle(max_cycle: int) -> List[Data]:
    """Create two graphs for every size from 6 to `max_cycle` (inclusive).

    For each size n, the first graph (label 0) is a single cycle through all
    n nodes; the second (label 1) is a triangle on nodes 0-2 plus a cycle on
    nodes 3..n-1. Sizes in the lower third of the range are marked as
    training graphs, the middle third as validation and the rest as test.

    Args:
        max_cycle: Largest graph size to generate.

    Returns:
        List with two `Data` objects per size, ordered by size.
    """
    span = max_cycle + 1 - 6
    train_limit = span / 3 + 6
    val_limit = 2 * span / 3 + 6

    graphs = []
    for cycle_size in range(6, max_cycle + 1):
        # 0 = train, 1 = val, 2 = test
        if cycle_size < train_limit:
            split = 0
        elif cycle_size < val_limit:
            split = 1
        else:
            split = 2
        train_mask, val_mask, test_mask = (torch.full((1,), split == k, dtype=torch.bool)
                                           for k in range(3))

        x = torch.ones(size=(cycle_size, 1))
        nodes = list(range(cycle_size))
        # Single cycle: i -> i + 1, closing with (n - 1) -> 0.
        single_cycle = make_undirected(edge_index=torch.tensor([nodes, nodes[1:] + [0]]))
        # Triangle 0 -> 1 -> 2 -> 0 and a second cycle over nodes 3..n-1.
        two_cycles = make_undirected(edge_index=torch.tensor([nodes, [1, 2, 0] + nodes[4:] + [3]]))

        for label, edge_index in ((0, single_cycle), (1, two_cycles)):
            graphs.append(Data(x=x, edge_index=edge_index, y=torch.tensor([label], dtype=torch.long),
                               train_mask=train_mask, val_mask=val_mask, test_mask=test_mask))

    return graphs

The `CyclesDataset` class is then simply a wrapper around a call to the `create_cycle` function.

In [None]:
class CyclesDataset(object):
    """Holds the list of graphs produced by `create_cycle(max_cycle=13)`."""

    def __init__(self):
        super().__init__()
        # List of `Data` objects, two per cycle size from 6 to 13.
        self.data = create_cycle(max_cycle=13)

### Peptides

The `PeptidesFunctionalDataset` deals with the classification of organic molecules' graphs as different classes of biological functions. It contains the following methods:

- `_md5sum()` computes the md5 hash of the dataset to verify the file's integrity.

- `download()` downloads the dataset from the internet.

- `process()` converts each SMILE string (Simplified Molecular Input Line Entry System, a sort of text representation of a molecule's skeletal formula) of the dataset into a graph and applies an optional transformation.

- `get_idx_split()` returns the pre-computed training, validation and test splits.

In [None]:
class PeptidesFunctionalDataset(InMemoryDataset):
    """In-memory dataset of peptide molecular graphs with multi-label functional classes.

    The raw CSV is downloaded and MD5-verified, each SMILES string is turned
    into a graph with `smiles2graph`, and the collated result is cached on disk.
    """

    def __init__(self, root='data', smiles2graph=smiles2graph,
                 transform=None, pre_transform=None):
        """
        PyG dataset of 15,535 peptides represented as their molecular graph
        (SMILES) with 10-way multi-task binary classification of their
        functional classes.
        The goal is use the molecular representation of peptides instead
        of amino acid sequence representation ('peptide_seq' field in the file,
        provided for possible baseline benchmarking but not used here) to test
        GNNs' representation capability.
        The 10 classes represent the following functional classes (in order):
            ['antifungal', 'cell_cell_communication', 'anticancer',
            'drug_delivery_vehicle', 'antimicrobial', 'antiviral',
            'antihypertensive', 'antibacterial', 'antiparasitic', 'toxic']
        Args:
            root (string): Root directory where the dataset should be saved.
            smiles2graph (callable): A callable function that converts a SMILES
                string into a graph object. We use the OGB featurization.
                * The default smiles2graph requires rdkit to be installed *
            transform (callable, optional): Forwarded unchanged to
                `InMemoryDataset.__init__`.
            pre_transform (callable, optional): Forwarded to
                `InMemoryDataset.__init__`; `process()` applies it to every
                graph before the dataset is saved.
        """

        self.original_root = root
        self.smiles2graph = smiles2graph
        self.folder = join(root, 'peptides-functional')

        self.url = 'https://www.dropbox.com/s/ol2v01usvaxbsr8/peptide_multi_class_dataset.csv.gz?dl=1'
        self.version = '701eb743e899f4d793f0e13c8fa5a1b4'  # MD5 hash of the intended dataset file
        self.url_stratified_split = 'https://www.dropbox.com/s/j4zcnx2eipuo0xz/splits_random_stratified_peptide.pickle?dl=1'
        self.md5sum_stratified_split = '5a0114bdadc80b94fc7ae974f13ef061'

        # Check version and update if necessary.
        # A marker file named after the expected MD5 hash signals an up-to-date copy;
        # if the folder exists without it, offer to wipe the folder so it is rebuilt.
        release_tag = join(self.folder, self.version)
        if isdir(self.folder) and (not exists(release_tag)):
            print(f"{self.__class__.__name__} has been updated.")
            if input("Will you update the dataset now? (y/N)\n").lower() == 'y':
                rmtree(self.folder)

        super().__init__(self.folder, transform, pre_transform)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        """Name of the raw (compressed CSV) file expected in `raw_dir`."""
        return 'peptide_multi_class_dataset.csv.gz'

    @property
    def processed_file_names(self):
        """Name of the file the processed dataset is saved to."""
        return 'geometric_data_processed.pt'

    def _md5sum(self, path):
        """Return the hexadecimal MD5 digest of the file at `path`."""
        hash_md5 = md5()
        with open(path, 'rb') as f:
            # Feed the file in 1 MiB chunks so it is never held in memory at once;
            # the digest is identical to hashing the whole content in one call.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                hash_md5.update(chunk)
        return hash_md5.hexdigest()

    def download(self):
        """Download the raw dataset and the split file, verifying both MD5 hashes.

        Raises:
            ValueError: If the downloaded dataset file does not have the
                expected MD5 hash (`self.version`).
            AssertionError: If the downloaded split file does not have the
                expected MD5 hash.
        """
        if decide_download(self.url):
            path = download_url(self.url, self.raw_dir)
            # Save to disk the MD5 hash of the downloaded file.
            file_hash = self._md5sum(path)
            if file_hash != self.version:
                raise ValueError("Unexpected MD5 hash of the downloaded file")
            # Empty marker file named after the hash (looked up by the version check in __init__).
            with open(join(self.root, file_hash), 'w'):
                pass
            # Download train/val/test splits.
            path_split1 = download_url(self.url_stratified_split, self.root)
            assert self._md5sum(path_split1) == self.md5sum_stratified_split
        else:
            print('Stop download.')
            exit(-1)

    def process(self):
        """Convert every SMILES string of the raw CSV into a `Data` graph and save the result.

        The optional `pre_transform` is applied to each graph before collation.
        """
        # The file imports `tqdm` as a module (`import tqdm`), so the bare name is not
        # callable; import the progress-bar class locally under an unambiguous name.
        from tqdm import tqdm as progress_bar

        data_df = pd.read_csv(join(self.raw_dir,
                                       'peptide_multi_class_dataset.csv.gz'))
        smiles_list = data_df['smiles']

        print('Converting SMILES strings into graphs...')
        data_list = []
        for i in progress_bar(range(len(smiles_list))):
            data = Data()

            smiles = smiles_list[i]
            graph = self.smiles2graph(smiles)

            # Sanity checks: one feature row per edge and per node.
            assert (len(graph['edge_feat']) == graph['edge_index'].shape[1])
            assert (len(graph['node_feat']) == graph['num_nodes'])

            data.__num_nodes__ = int(graph['num_nodes'])
            data.edge_index = torch.from_numpy(graph['edge_index']).to(
                torch.int64)
            data.edge_attr = torch.from_numpy(graph['edge_feat']).to(
                torch.int64)
            data.x = torch.from_numpy(graph['node_feat']).to(torch.int64)
            # NOTE(review): `eval` executes whatever is stored in the 'labels' column.
            # The file is MD5-checked in `download()`, but consider `ast.literal_eval`
            # if the labels are plain Python literals -- TODO confirm.
            data.y = torch.Tensor([eval(data_df['labels'].iloc[i])])

            data_list.append(data)

        if self.pre_transform is not None:
            print('Applying pre_transform of graphs...')
            data_list = [self.pre_transform(data) for data in progress_bar(data_list)]

        data, slices = self.collate(data_list)

        print('Saving...')
        torch.save((data, slices), self.processed_paths[0])

    def get_idx_split(self):
        """ Get dataset splits.
        Returns:
            Dict with 'train', 'val', 'test', splits indices.
        """
        split_file = join(self.root,
                              "splits_random_stratified_peptide.pickle")
        # NOTE(review): unpickling can execute arbitrary code; the file's MD5 hash is
        # only checked at download time in `download()`.
        with open(split_file, 'rb') as f:
            splits = pickle.load(f)
        split_dict = replace_numpy_with_torchtensor(splits)
        return split_dict

### Planetoid

On the way to define the `Planetoid` dataset class, we start by coding `parse_index_file` which basically reads a file and returns a list of the integer values of each line.

In [None]:
def parse_index_file(filename):
    """Code taken from https://github.com/Yujun-Yan/Heterophily_and_oversmoothing/blob/main/process.py#L18

    Parse an index file that holds one integer per line.

    Args:
        filename: Path of the text file to read.

    Returns:
        List of the integers found in the file, in file order. Surrounding
        whitespace on each line is ignored.

    Raises:
        ValueError: If a line (after stripping) is not a valid integer.
    """
    # The context manager guarantees the handle is closed, even if a line fails to parse.
    with open(filename) as index_file:
        return [int(line.strip()) for line in index_file]

To load the dataset, we define the `full_load_citation` function which reads the files (features, labels, adjacency matrix, ...), storing the result in variables, and returning the dataset as a `Data` object.

In [None]:
def full_load_citation(dataset_str, raw_dir):
    """Code adapted from https://github.com/Yujun-Yan/Heterophily_and_oversmoothing/blob/main/process.py#L33

    Load a citation dataset from its pickled `ind.<dataset_str>.*` files and
    return it as a single `Data` graph.

    Args:
        dataset_str: Dataset name used inside the file names
            (`ind.<dataset_str>.x`, ..., `ind.<dataset_str>.test.index`).
        raw_dir: Directory that contains those files.

    Returns:
        A `Data` object with dense float features `x`, an undirected,
        self-loop-free `edge_index`, integer class labels `y`,
        `num_node_features`, and `non_valid_samples` (indices of nodes that
        were padded in or whose one-hot label row is all zero).

    Raises:
        AssertionError: If the resulting class ids are not exactly 0..C-1.
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for i in range(len(names)):
        path = join(raw_dir, "ind.{}.{}".format(dataset_str, names[i]))
        with open(path, 'rb') as f:
            # On Python 3 the pickles are decoded with latin1
            # (presumably because they were written by Python 2 -- TODO confirm).
            if sys.version_info > (3, 0):
                objects.append(pickle.load(f, encoding='latin1'))
            else:
                objects.append(pickle.load(f))

    x, y, tx, ty, allx, ally, graph = tuple(objects)
    # Test node indices in file order, and the same indices sorted ascending.
    test_idx_reorder = parse_index_file(join(raw_dir, "ind.{}.test.index".format(dataset_str)))
    test_idx_range = np.sort(test_idx_reorder)
    # Every index between the smallest and the largest test index, gaps included.
    test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
    tx_extended = scipy.sparse.lil_matrix((len(test_idx_range_full), x.shape[1]))
    if len(test_idx_range_full) != len(test_idx_range):
        # Fix citeseer dataset (there are some isolated nodes in the graph)
        # Find isolated nodes, add them as zero-vecs into the right position, mark them
        # Follow H2GCN code
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - min(test_idx_range), :] = ty
        ty = ty_extended
        # Indices inside the full range that are missing from the test index file.
        non_valid_samples = set(test_idx_range_full) - set(test_idx_range)
    else:
        non_valid_samples = set()
    # Stack `allx` on top of `tx`, then move the test rows to the positions given
    # by the (unsorted) test index file.
    features = scipy.sparse.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = networkx.adjacency_matrix(networkx.from_dict_of_lists(graph))

    # Same stacking and reordering for the label rows.
    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]
    # Rows whose labels sum to zero are also marked as non-valid.
    non_valid_samples = list(non_valid_samples.union(set(list(np.where(labels.sum(1) == 0)[0]))))
    # Collapse each label row to the index of its largest entry.
    labels = np.argmax(labels, axis=-1)

    features = features.todense()

    # Prepare in PyTorch Geometric Format
    sparse_mx = scipy.sparse.coo_matrix(adj).astype(np.float32)
    indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))
    shape = torch.Size(sparse_mx.shape)
    edge_index, _ = coalesce(indices, None, shape[0], shape[1])

    # Remove self-loops
    edge_index, _ = remove_self_loops(edge_index)
    # Make the graph undirected
    edge_index = to_undirected(edge_index)

    # The class ids must be exactly 0, 1, ..., C-1 with no gaps.
    assert (np.array_equal(np.unique(labels), np.arange(len(np.unique(labels)))))

    features = torch.FloatTensor(features)
    labels = torch.LongTensor(labels)
    non_valid_samples = torch.LongTensor(non_valid_samples)

    return Data(x=features, edge_index=edge_index, y=labels, num_node_features=features.size(1),
                non_valid_samples=non_valid_samples)

Finally, the `Planetoid` dataset holds the data loaded at instantiation. Its `process` function can apply an optional transform to the data.

In [None]:
class Planetoid(InMemoryDataset):
    """Single-graph citation dataset built from local `ind.<name>.*` files.

    Raw and processed files live under `<root>/<name>/raw` and
    `<root>/<name>/processed`. Nothing is downloaded: the raw files are read by
    `full_load_citation` when the dataset is processed.
    """

    def __init__(self, root: str, name: str,
                 transform: Optional[Callable] = None,
                 pre_transform: Optional[Callable] = None):
        # `name` must be set before the parent constructor runs, because the
        # directory properties below depend on it.
        self.name = name
        super().__init__(root, transform, pre_transform)

        stored = torch.load(self.processed_paths[0])
        self.data, self.slices = stored

        # Fetch the single graph and collate it again.
        graph = self.get(0)
        self.data, self.slices = self.collate([graph])

    @property
    def raw_dir(self) -> str:
        """Directory holding the raw `ind.*` files."""
        dataset_dir = join(self.root, self.name)
        return join(dataset_dir, 'raw')

    @property
    def processed_dir(self) -> str:
        """Directory holding the processed dataset file."""
        dataset_dir = join(self.root, self.name)
        return join(dataset_dir, 'processed')

    @property
    def raw_file_names(self) -> List[str]:
        """Raw file names, built from the lower-cased dataset name."""
        prefix = f'ind.{self.name.lower()}'
        suffixes = ('x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index')
        return [f'{prefix}.{suffix}' for suffix in suffixes]

    @property
    def processed_file_names(self) -> str:
        """Name of the processed dataset file."""
        return 'data.pt'

    def download(self):
        """No-op: this class does not download anything."""
        return None

    def process(self):
        """Load the raw files, apply the optional `pre_transform`, and save the result."""
        graph = full_load_citation(self.name, self.raw_dir)
        if self.pre_transform is not None:
            graph = self.pre_transform(graph)
        collated = self.collate([graph])
        torch.save(collated, self.processed_paths[0])

    def __repr__(self) -> str:
        return '{}()'.format(self.name)

---

## Defining the DataSet class

The `DataSet` class provides a container for the different datasets used in the paper. It has the following functionality:

- `get_family()` tells us which of the types from the `DataSetFamily` class the DataSet object is.

- `is_node_based()` returns `True` if the dataset is for node classification.

- `not_synthetic()` returns `True` if the dataset is real, i.e. not synthetic.

- `is_expressivity()` returns `True` if the dataset is for distinguishing graph structures, in our case that is only the Cycles dataset.

- `clip_grad()` returns `True` if gradient clipping is required, in our case that is only the LRGB dataset.

- `get_dataset_encoders()` tells us which encoder to use for the dataset, in our case only the LRGB dataset will get a special encoder.

- `get_folds()` returns the fold indices to be used in n-fold training, i.e. how many different models shall be trained.

- `load()` does the loading of the dataset from the data files and returns a list of Data-objects. Adjust the `root` and `name`, or `tu_dataset_name`, variable depending on the folder structure of where the data is located relative to this notebook.

- `select_fold_and_split()` takes the data returned from `load()` and splits it into the desired number of folds. It returns the folds within a `DatasetBySplit` object.

- `get_metric_type()` tells us which from the `MetricType` to use as a metric for model performance.

- `num_after_decimal()` returns the precision to use when displaying results.

- `env_activation_type()` tells us which activation function to use in the model.

- `gin_mlp_func()` returns an architecture for the MLP to be used for the model in the `WeightedGINConv` class.

- `optimizer()` returns the optimizer to use when training the model.

- `scheduler()` returns the scheduler to use when training the model.

- `get_split_mask()` returns the training / validation / test masks (of nodes) to use for node classification tasks, or the entire graph for graph classification tasks.

- `get_edge_ratio_node_mask()` is similar to `get_split_mask()` but only for node-level tasks.

- `asserts()` checks that various assertions hold on the data.

In [None]:
class DataSet(Enum):
    """
        an object for the different datasets

        Each member names one dataset; the methods map a member to its family,
        data loading and splitting logic, metric, optimizer, scheduler, and the
        constraints its experiment arguments must satisfy.
    """
    # heterophilic
    roman_empire = auto()
    amazon_ratings = auto()
    minesweeper = auto()
    tolokers = auto()
    questions = auto()

    # synthetic
    root_neighbours = auto()
    cycles = auto()

    # social networks
    imdb_binary = auto()
    imdb_multi = auto()
    reddit_binary = auto()
    reddit_multi = auto()

    # proteins
    enzymes = auto()
    proteins = auto()
    nci1 = auto()

    # lrgb
    func = auto()

    # homophilic
    cora = auto()
    pubmed = auto()

    @staticmethod
    def from_string(s: str):
        """Return the member named `s`.

        Raises:
            ValueError: If `s` is not the name of a member.
        """
        try:
            return DataSet[s]
        except KeyError:
            raise ValueError(f'{s!r} is not a valid DataSet name')

    def get_family(self) -> DataSetFamily:
        """Return the `DataSetFamily` this dataset belongs to."""
        if self in [DataSet.roman_empire, DataSet.amazon_ratings, DataSet.minesweeper,
                    DataSet.tolokers, DataSet.questions]:
            return DataSetFamily.heterophilic
        elif self in [DataSet.root_neighbours, DataSet.cycles]:
            return DataSetFamily.synthetic
        elif self in [DataSet.imdb_binary, DataSet.imdb_multi, DataSet.reddit_binary, DataSet.reddit_multi]:
            return DataSetFamily.social_networks
        elif self in [DataSet.enzymes, DataSet.proteins, DataSet.nci1]:
            return DataSetFamily.proteins
        elif self is DataSet.func:
            return DataSetFamily.lrgb
        elif self in [DataSet.cora, DataSet.pubmed]:
            return DataSetFamily.homophilic
        else:
            raise ValueError(f'DataSet {self.name} not supported in dataloader')

    def is_node_based(self) -> bool:
        """True for heterophilic and homophilic datasets and for `root_neighbours`."""
        return self.get_family() in [DataSetFamily.heterophilic, DataSetFamily.homophilic] or self is DataSet.root_neighbours

    def not_synthetic(self) -> bool:
        """True unless the dataset belongs to the synthetic family."""
        return self.get_family() is not DataSetFamily.synthetic

    def is_expressivity(self) -> bool:
        """True only for the `cycles` dataset."""
        return self is DataSet.cycles

    def clip_grad(self) -> bool:
        """True only for the lrgb family."""
        return self.get_family() is DataSetFamily.lrgb

    def get_dataset_encoders(self):
        """Return the `DataSetEncoders` member for this dataset (`MOL` for `func`, `NONE` otherwise)."""
        if self.get_family() in [DataSetFamily.heterophilic, DataSetFamily.synthetic, DataSetFamily.social_networks,
                                 DataSetFamily.proteins, DataSetFamily.homophilic]:
            return DataSetEncoders.NONE
        elif self is DataSet.func:
            return DataSetEncoders.MOL
        else:
            raise ValueError(f'DataSet {self.name} not supported in get_dataset_encoders')

    def get_folds(self, fold: int) -> List[int]:
        """Return the list of fold indices to run.

        Synthetic and lrgb datasets use the single fold 0, heterophilic and
        homophilic datasets use folds 0..9, and social-network / protein
        datasets use only the requested `fold`.
        """
        if self.get_family() in [DataSetFamily.synthetic, DataSetFamily.lrgb]:
            return list(range(1))
        elif self.get_family() in [DataSetFamily.heterophilic, DataSetFamily.homophilic]:
            return list(range(10))
        elif self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins]:
            return [fold]
        else:
            raise ValueError(f'DataSet {self.name} not supported in dataloader')

    def load(self, seed: int, pos_enc: PosEncoder) -> List[Data]:
        """Load the dataset from `<ROOT_DIR>/datasets`.

        Args:
            seed: Seed forwarded to `RootNeighboursDataset` (used by `root_neighbours` only).
            pos_enc: Positional encoder forwarded to `apply_transform` (used by `func` only).

        Returns:
            The loaded dataset. Single-graph datasets are wrapped in a
            one-element list; for `func` the transformed
            `PeptidesFunctionalDataset` is returned.
        """
        root = join(ROOT_DIR, 'datasets')
        if self.get_family() is DataSetFamily.heterophilic:
            name = self.name.replace('_', '-').capitalize()
            dataset = [HeterophilousGraphDataset(root=root, name=name, transform=T.ToUndirected())[0]]
        elif self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins]:
            # These datasets are read from a pre-saved `<NAME>.pt` file.
            tu_dataset_name = self.name.upper().replace('_', '-')
            root = join(ROOT_DIR, 'datasets', tu_dataset_name)
            dataset = torch.load(root + '.pt')
        elif self is DataSet.root_neighbours:
            dataset = [RootNeighboursDataset(seed=seed).get()]
        elif self is DataSet.cycles:
            dataset = CyclesDataset().data
        elif self is DataSet.func:
            dataset = PeptidesFunctionalDataset(root=root)
            dataset = apply_transform(dataset=dataset, pos_encoder=pos_enc)
        elif self.get_family() is DataSetFamily.homophilic:
            dataset = [Planetoid(root=root, name=self.name, transform=T.NormalizeFeatures())[0]]
        else:
            raise ValueError(f'DataSet {self.name} not supported in dataloader')
        return dataset

    def select_fold_and_split(self, dataset: List[Data], num_fold: int) -> DatasetBySplit:
        """Return the train / val / test view of `dataset` for fold `num_fold`.

        Node-level datasets return the same graph three times with the fold's
        masks set; graph-level datasets return three lists of graphs. The fold
        files of the social-network, protein and homophilic datasets are read
        from the relative `folds/` directory.

        Note:
            For homophilic datasets the masks are written onto `dataset[0]`
            itself, whereas heterophilic datasets are deep-copied first.
        """
        if self.get_family() is DataSetFamily.heterophilic:
            dataset_copy = copy.deepcopy(dataset)
            # The stored masks have one column per fold; keep the requested column.
            dataset_copy[0].train_mask = dataset_copy[0].train_mask[:, num_fold]
            dataset_copy[0].val_mask = dataset_copy[0].val_mask[:, num_fold]
            dataset_copy[0].test_mask = dataset_copy[0].test_mask[:, num_fold]
            return DatasetBySplit(train=dataset_copy, val=dataset_copy, test=dataset_copy)
        elif self.get_family() is DataSetFamily.synthetic:
            return DatasetBySplit(train=dataset, val=dataset, test=dataset)
        elif self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins]:
            tu_dataset_name = self.name.upper().replace('_', '-')
            # Read the split file inside a context manager so the handle is always closed.
            with open(f'folds/{tu_dataset_name}_splits.json', "r") as folds_file:
                original_fold_dict = json.load(folds_file)[num_fold]
            model_selection_dict = original_fold_dict['model_selection'][0]
            split_dict = {'train': model_selection_dict['train'], 'val': model_selection_dict['validation'],
                          'test': original_fold_dict['test']}
            dataset_by_splits = [[dataset[idx] for idx in split_dict[split]] for split in DatasetBySplit._fields]
            return DatasetBySplit(*dataset_by_splits)
        elif self is DataSet.func:
            split_idx = dataset.get_idx_split()
            dataset_by_splits = [[dataset[idx] for idx in split_idx[split]] for split in DatasetBySplit._fields]
            return DatasetBySplit(*dataset_by_splits)
        elif self.get_family() is DataSetFamily.homophilic:
            device = dataset[0].x.device
            with np.load(f'folds/{self.name}_split_0.6_0.2_{num_fold}.npz') as folds_file:
                train_mask = torch.tensor(folds_file['train_mask'], dtype=torch.bool, device=device)
                val_mask = torch.tensor(folds_file['val_mask'], dtype=torch.bool, device=device)
                test_mask = torch.tensor(folds_file['test_mask'], dtype=torch.bool, device=device)

            setattr(dataset[0], 'train_mask', train_mask)
            setattr(dataset[0], 'val_mask', val_mask)
            setattr(dataset[0], 'test_mask', test_mask)

            # Nodes flagged as non-valid are excluded from every split.
            dataset[0].train_mask[dataset[0].non_valid_samples] = False
            dataset[0].test_mask[dataset[0].non_valid_samples] = False
            dataset[0].val_mask[dataset[0].non_valid_samples] = False
            return DatasetBySplit(train=dataset, val=dataset, test=dataset)
        else:
            raise ValueError('NotImplemented')

    def get_metric_type(self) -> MetricType:
        """Return the `MetricType` used to evaluate this dataset."""
        if self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins, DataSetFamily.homophilic] or self in [DataSet.roman_empire, DataSet.amazon_ratings, DataSet.cycles]:
            return MetricType.ACCURACY
        elif self in [DataSet.minesweeper, DataSet.tolokers, DataSet.questions]:
            return MetricType.AUC_ROC
        elif self is DataSet.root_neighbours:
            return MetricType.MSE_MAE
        elif self is DataSet.func:
            return MetricType.MULTI_LABEL_AP
        else:
            raise ValueError(f'DataSet {self.name} not supported in dataloader')

    def num_after_decimal(self) -> int:
        """Number of decimals to show in results: 4 for lrgb, 2 otherwise."""
        return 4 if self.get_family() is DataSetFamily.lrgb else 2

    def env_activation_type(self) -> ActivationType:
        """Return GELU for heterophilic and lrgb datasets, RELU otherwise."""
        if self.get_family() in [DataSetFamily.heterophilic, DataSetFamily.lrgb]:
            return ActivationType.GELU
        else:
            return ActivationType.RELU

    def gin_mlp_func(self) -> Callable:
        """Return a factory `mlp_func(in_channels, out_channels, bias)` that builds the GIN MLP.

        `func` uses a hidden width of `out_channels`; social-network and
        protein datasets use a hidden width of `2 * in_channels`; all other
        datasets additionally insert a `BatchNorm1d` after the first linear layer.
        """
        if self is DataSet.func:
            def mlp_func(in_channels: int, out_channels: int, bias: bool):
                return torch.nn.Sequential(torch.nn.Linear(in_channels, out_channels, bias=bias),
                                           torch.nn.ReLU(), torch.nn.Linear(out_channels, out_channels, bias=bias))
        elif self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins]:
            def mlp_func(in_channels: int, out_channels: int, bias: bool):
                return torch.nn.Sequential(torch.nn.Linear(in_channels, 2 * in_channels, bias=bias),
                                           torch.nn.ReLU(), torch.nn.Linear(2 * in_channels, out_channels, bias=bias))
        else:
            def mlp_func(in_channels: int, out_channels: int, bias: bool):
                return torch.nn.Sequential(torch.nn.Linear(in_channels, 2 * in_channels, bias=bias),
                                           torch.nn.BatchNorm1d(2 * in_channels),
                                           torch.nn.ReLU(), torch.nn.Linear(2 * in_channels, out_channels, bias=bias))
        return mlp_func

    def optimizer(self, model, lr: float, weight_decay: float):
        """Return AdamW for the lrgb family and Adam for every other family."""
        if self.get_family() in [DataSetFamily.heterophilic, DataSetFamily.synthetic, DataSetFamily.social_networks,
                                 DataSetFamily.proteins, DataSetFamily.homophilic]:
            return torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        elif self.get_family() is DataSetFamily.lrgb:
            return torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
        else:
            raise ValueError(f'DataSet {self.name} not supported in dataloader')

    def scheduler(self, optimizer, step_size: Optional[int], gamma: Optional[float], num_warmup_epochs: Optional[int],
                  max_epochs: int):
        """Return the learning-rate scheduler for this dataset, or None if it uses none.

        The lrgb family uses `cosine_with_warmup_scheduler` (requires
        `num_warmup_epochs` and `max_epochs`), social-network and protein
        datasets use `StepLR` (requires `step_size` and `gamma`), and the
        remaining families use no scheduler.
        """
        if self.get_family() is DataSetFamily.lrgb:
            assert num_warmup_epochs is not None, 'cosine_with_warmup_scheduler\'s num_warmup_epochs is None'
            assert max_epochs is not None, 'cosine_with_warmup_scheduler\'s max_epochs is None'
            return cosine_with_warmup_scheduler(optimizer=optimizer, num_warmup_epochs=num_warmup_epochs,
                                                max_epoch=max_epochs)
        elif self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins]:
            assert step_size is not None, 'StepLR\'s step_size is None'
            assert gamma is not None, 'StepLR\'s gamma is None'
            return torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=step_size, gamma=gamma)
        elif self.get_family() in [DataSetFamily.heterophilic, DataSetFamily.synthetic, DataSetFamily.homophilic]:
            return None
        else:
            raise ValueError(f'DataSet {self.name} not supported in dataloader')

    def get_split_mask(self, data: Data, batch_size: int, split_mask_name: str) -> Tensor:
        """Return the mask stored on `data` under `split_mask_name`, or an all-True fallback.

        The fallback has one entry per node for node-based datasets and
        `batch_size` entries otherwise.
        """
        if hasattr(data, split_mask_name):
            return getattr(data, split_mask_name)
        elif self.is_node_based():
            return torch.ones(size=(data.x.shape[0],), dtype=torch.bool)
        else:
            return torch.ones(size=(batch_size,), dtype=torch.bool)

    def get_edge_ratio_node_mask(self, data: Data, split_mask_name: str) -> Tensor:
        """Return the mask stored on `data` under `split_mask_name`, or an all-True mask with one entry per node."""
        if hasattr(data, split_mask_name):
            return getattr(data, split_mask_name)
        else:
            return torch.ones(size=(data.x.shape[0],), dtype=torch.bool)

    def asserts(self, args):
        """Check that the experiment arguments in `args` are consistent with this dataset.

        Raises:
            AssertionError: If any of the constraints below is violated.
        """
        # model
        assert not(self.is_node_based()) or args['pool'] is Pool.NONE, "Node based datasets have no pooling"
        assert not(self.is_node_based()) or args['batch_norm'] is False, "Node based dataset cannot have batch norm"
        assert not(not(self.is_node_based()) and args['pool'] is Pool.NONE), "Graph based datasets need pooling"
        assert args['env_model_type'] is not ModelType.LIN, "The environment net can't be linear"

        # dataset dependant parameters
        assert self.get_family() in [DataSetFamily.social_networks, DataSetFamily.proteins] or args['fold'] is None, 'social networks and protein datasets are the only ones to use fold'
        assert self.get_family() not in [DataSetFamily.social_networks, DataSetFamily.proteins] or args['fold'] is not None, 'social networks and protein datasets must specify fold'
        assert self.get_family() is DataSetFamily.proteins or self.get_family() is DataSetFamily.social_networks or (args['step_size'] is None and args['gamma'] is None), 'proteins datasets are the only ones to use step_size and gamma'
        assert self.get_family() is DataSetFamily.lrgb or (args['num_warmup_epochs'] is None), 'lrgb datasets are the only ones to use num_warmup_epochs'
        # encoders
        assert self.get_family() is DataSetFamily.lrgb or (args['pos_enc'] is PosEncoder.NONE), 'lrgb datasets are the only ones to use pos_enc'

---

## Setting up the experiment parameters

Since we're running the code inside a jupyter notebook, it is easier to just use a dictionary instead of an argument parser. It is implemented as a function so that we can define default values in the parameters of `create_args`.

In [None]:
def create_args(
    dataset: DataSet.from_string=DataSet.roman_empire,
    pool: Pool.from_string=Pool.NONE,

    # gumbel
    learn_temp=False,
    temp_model_type: ModelType.from_string=ModelType.LIN,
    tau0: float=0.5,
    temp: float=0.01,

    # optimization
    max_epochs: int=3000,
    batch_size: int=32,
    lr: float=1e-3,
    dropout: float=0.2,

    # env cls parameters
    env_model_type: ModelType.from_string=ModelType.MEAN_GNN,
    env_num_layers: int=3,
    env_dim: int=128,
    skip=False,
    batch_norm=False,
    layer_norm=False,
    dec_num_layers: int=1,
    pos_enc: PosEncoder.from_string=PosEncoder.NONE,

    # policy cls parameters
    act_model_type: ModelType.from_string=ModelType.MEAN_GNN,
    act_num_layers: int=1,
    act_dim: int=16,

    # reproduce
    seed: int=0,
    gpu: int=None, # In the original argument parser, there is no default value defined here

    # dataset dependant parameters
    fold: int=None,

    # optimizer and scheduler
    weight_decay: float=0.0,
    ## for steplr scheduler only
    step_size: int=None,
    gamma: float=None,
    ## for cosine with warmup scheduler only
    num_warmup_epochs: int=None
):
    """Bundle the experiment parameters into a dictionary.

    Every parameter is returned under a key equal to its own name, in the
    order of the signature. Parameters that are not passed keep the defaults
    declared above.

    Returns:
        Dict mapping each parameter name to the value it was given.
    """
    return dict(
        # dataset and pooling
        dataset=dataset,
        pool=pool,
        # gumbel
        learn_temp=learn_temp,
        temp_model_type=temp_model_type,
        tau0=tau0,
        temp=temp,
        # optimization
        max_epochs=max_epochs,
        batch_size=batch_size,
        lr=lr,
        dropout=dropout,
        # environment network
        env_model_type=env_model_type,
        env_num_layers=env_num_layers,
        env_dim=env_dim,
        skip=skip,
        batch_norm=batch_norm,
        layer_norm=layer_norm,
        dec_num_layers=dec_num_layers,
        pos_enc=pos_enc,
        # action (policy) network
        act_model_type=act_model_type,
        act_num_layers=act_num_layers,
        act_dim=act_dim,
        # reproducibility
        seed=seed,
        gpu=gpu,
        # dataset dependant parameters
        fold=fold,
        # optimizer and scheduler
        weight_decay=weight_decay,
        step_size=step_size,
        gamma=gamma,
        num_warmup_epochs=num_warmup_epochs,
    )

Now we just create this dictionary as `ARGS`.

In [None]:
# Experiment configuration: every parameter keeps its `create_args` default,
# except the dataset, which is set to `root_neighbours`.
ARGS = create_args(dataset=DataSet.root_neighbours)

---

## Setting up the CoGNN

Before being able to implement the `CoGNN`, we first need to implement the `TempSoftPlus` class, which provides an optional functionality: a learned temperature. It is essentially a model that is constructed according to a `GumbelArgs` instance and that, for a given graph, computes the temperature dynamically while ensuring that it never reaches infinity.

In [None]:
class TempSoftPlus(torch.nn.Module):
    """Module that computes a temperature from node features.

    The output is `1 / (softplus(model(x)) + tau0)`, where `model` is the first
    component built by `gumbel_args.temp_model_type`; infinite entries are
    replaced by 0.
    """

    def __init__(self, gumbel_args: GumbelArgs, env_dim: int):
        super().__init__()
        # Single-layer model mapping `env_dim` features to one value, without bias.
        components = gumbel_args.temp_model_type.get_component_list(
            in_dim=env_dim,
            hidden_dim=env_dim,
            out_dim=1,
            num_layers=1,
            bias=False,
            edges_required=False,
            gin_mlp_func=gumbel_args.gin_mlp_func,
        )
        self.linear_model = torch.nn.ModuleList(components)
        self.softplus = torch.nn.Softplus(beta=1)
        self.tau0 = gumbel_args.tau0

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor):
        """Return the temperature computed from `x`, with infinite entries set to 0."""
        raw = self.linear_model[0](x=x, edge_index=edge_index, edge_attr=edge_attr)
        shifted = self.softplus(raw) + self.tau0
        # Invert in place, then zero out any entry that became infinite.
        temp = shifted.pow_(-1)
        infinite = temp == float('inf')
        return temp.masked_fill_(infinite, 0.)

Now we are ready to implement the `CoGNN` architecture. Something that is worth highlighting is the fact that we have two `ActionNet`, i.e. policies, namely `self.in_act_net` and `self.out_act_net` which decide whether to receive and broadcast, respectively. In other words, we do not have a mapping to the 4 possible states `Standard`, `Broadcast`, `Receive`, and `Isolate` but rather each of the decision to receive and broadcast are independent. Moreover, there are separate encoders of the edges for the environment and the action networks: `self.env_bond_encoder`  and `self.act_bond_encoder`.

Overall, the `CoGNN` class follows the typical design pattern of a `torch.nn.Module` model, i.e. its `forward` function computes the model's output. In particular, it can track edge ratio statistics, then it encodes the edges using the two encoders mentioned before. Afterwards, it encodes the nodes (in case no particular encoder shall be used it applies dropout and computes the activation) and applies layer normalization to each node's features.

Then the message passing begins with the number of message passing rounds defined by `env_args.num_layers`: The two policies compute the logit-distribution tensors (receive, no receive) and (broadcast, no broadcast) for each node from which, for a given (learned) temperature, we compute the actions using a hard Gumbel Softmax (the variable names `in_probs` and `out_probs` are somewhat misleading as the value for each edge is either keep or discard, no in between values). The call to `create_edge_weight` now constructs the graph as determined by the policy for that message passing-layer.

The next step is to perform the actual message passing, which is handled by the environment network `self.env_net` and thus in extension by the model chosen in `env_args.load_net()`, to whose output we apply dropout and an activation function. Depending on whether we chose to track the edge ratio, the information is now appended to the statistics, and thereafter we can optionally (hyperparameter) enable a residual connection to keep more of a node's original features.

The final result is then obtained by performing a layer normalization, pooling and running the output through a decoder (usually a linear classifier) and adding this message passing layer's output to the previous message passing layer's output. Then again some statistics can be collected and the forward pass returns the resulting features and statistics.

In [None]:
class CoGNN(torch.nn.Module):
    """Cooperative GNN.

    In every message-passing layer two action networks (one for receiving, one for
    broadcasting) emit per-node logits; a hard Gumbel-Softmax turns them into keep/drop
    decisions that are combined into per-edge weights, and the environment network then
    performs the message passing with those weights.
    """

    def __init__(self, gumbel_args: GumbelArgs, env_args: EnvArgs, action_args: ActionNetArgs, pool: Pool):
        super(CoGNN, self).__init__()
        self.env_args = env_args

        # Gumbel-Softmax temperature: predicted by a dedicated model when `learn_temp`
        # is set, otherwise the fixed value `gumbel_args.temp` is used.
        self.learn_temp = gumbel_args.learn_temp
        if self.learn_temp:
            self.temp_model = TempSoftPlus(gumbel_args=gumbel_args, env_dim=env_args.env_dim)
        self.temp = gumbel_args.temp

        # Environment network; `forward` indexes it as [encoder, layer_1, ..., layer_L, decoder].
        self.num_layers = env_args.num_layers
        self.env_net = env_args.load_net()
        self.use_encoders = env_args.dataset_encoders.use_encoders()

        # One normalisation module shared by all layers (Identity when layer norm is off).
        if env_args.layer_norm:
            self.hidden_layer_norm = torch.nn.LayerNorm(env_args.env_dim)
        else:
            self.hidden_layer_norm = torch.nn.Identity(env_args.env_dim)
        self.skip = env_args.skip
        self.dropout = torch.nn.Dropout(p=env_args.dropout)
        self.drop_ratio = env_args.dropout
        self.act = env_args.act_type.get()

        # Two separately parameterised policies: receive ("in") and broadcast ("out").
        self.in_act_net = ActionNet(action_args=action_args)
        self.out_act_net = ActionNet(action_args=action_args)

        # Edge-attribute encoders, one per network family (environment / action).
        self.dataset_encoder = env_args.dataset_encoders
        self.env_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=env_args.env_dim,
                                                                  model_type=env_args.model_type)
        self.act_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=action_args.hidden_dim,
                                                                  model_type=action_args.model_type)

        # Pooling function to generate whole-graph embeddings
        self.pooling = pool.get()

    @staticmethod
    def _embed_edges(encoder, edge_attr):
        """Return `encoder(edge_attr)`, or None when the attributes or the encoder are missing."""
        if edge_attr is not None and encoder is not None:
            return encoder(edge_attr)
        return None

    def forward(self, x: Tensor, edge_index: Adj, pestat, edge_attr: OptTensor = None, batch: OptTensor = None,
                edge_ratio_node_mask: OptTensor = None) -> Tuple[Tensor, Tensor]:
        """Run the encoder, `num_layers` action/environment rounds, pooling and the decoder.

        Args:
            x: node features, passed to the encoder `env_net[0]` together with `pestat`.
            edge_index: edge list; row 0 / row 1 are used as the two endpoints of each edge.
            pestat: positional-encoding input forwarded unchanged to the node encoder.
            edge_attr: optional raw edge attributes for the two edge encoders.
            batch: forwarded to the pooling function.
            edge_ratio_node_mask: optional boolean node mask; when given, the ratio of kept
                edges among edges whose both endpoints are in the mask is recorded per layer.

        Returns:
            The decoder output and a tensor with one edge ratio per layer
            (all entries are -1 when no `edge_ratio_node_mask` was supplied).
        """
        track_ratio = edge_ratio_node_mask is not None
        edge_ratio_list = []
        if track_ratio:
            # An edge counts for the statistic only if both of its endpoints are selected.
            edge_ratio_edge_mask = edge_ratio_node_mask[edge_index[0]] & edge_ratio_node_mask[edge_index[1]]

        # bond encode
        env_edge_embedding = self._embed_edges(self.env_bond_encoder, edge_attr)
        act_edge_embedding = self._embed_edges(self.act_bond_encoder, edge_attr)

        # node encode
        x = self.env_net[0](x, pestat)  # (N, F) encoder
        if not self.use_encoders:
            x = self.act(self.dropout(x))

        gumbel_softmax = torch.nn.functional.gumbel_softmax
        for layer_idx in range(self.num_layers):
            x = self.hidden_layer_norm(x)

            # action: per-node logits of both policies
            in_logits = self.in_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding,
                                        act_edge_attr=act_edge_embedding)  # (N, 2)
            out_logits = self.out_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding,
                                          act_edge_attr=act_edge_embedding)  # (N, 2)

            if self.learn_temp:
                temp = self.temp_model(x=x, edge_index=edge_index, edge_attr=env_edge_embedding)
            else:
                temp = self.temp
            # hard=True yields one-hot samples; column 0 is used as the "keep" decision.
            in_probs = gumbel_softmax(logits=in_logits, tau=temp, hard=True)
            out_probs = gumbel_softmax(logits=out_logits, tau=temp, hard=True)
            edge_weight = self.create_edge_weight(edge_index=edge_index, keep_in_prob=in_probs[:, 0],
                                                  keep_out_prob=out_probs[:, 0])

            # environment: message passing over the weighted edges
            message_layer = self.env_net[1 + layer_idx]
            out = message_layer(x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=env_edge_embedding)
            out = self.act(self.dropout(out))

            if track_ratio:
                kept = edge_weight[edge_ratio_edge_mask]
                edge_ratio_list.append((kept.sum() / kept.shape[0]).item())

            # optional residual connection
            x = x + out if self.skip else out

        x = self.hidden_layer_norm(x)
        x = self.pooling(x, batch=batch)
        prediction = self.env_net[-1](x)  # decoder

        if track_ratio:
            edge_ratio_tensor = torch.tensor(edge_ratio_list, device=prediction.device)
        else:
            edge_ratio_tensor = -1 * torch.ones(size=(self.num_layers,), device=prediction.device)
        return prediction, edge_ratio_tensor

    def create_edge_weight(self, edge_index: Adj, keep_in_prob: Tensor, keep_out_prob: Tensor) -> Tensor:
        """Combine per-node decisions into per-edge weights.

        For every edge the weight is `keep_out_prob` of its first endpoint (`edge_index[0]`)
        multiplied by `keep_in_prob` of its second endpoint (`edge_index[1]`).
        """
        first, second = edge_index
        return keep_in_prob[second] * keep_out_prob[first]

---

## Setting up the experiment

The `Experiment` class is designed to handle all the functionality required for the experiment. It contains the training and testing of our model as defined by the arguments `ARGS`. Compared to the authors' code, the only change is that we pass the arguments as a dictionary instead of using an argument parser.

During instantiation of an `Experiment` object, we set all the attributes derived from `ARGS` and define the device on which the experiment shall be run. Moreover, we define the performance metric, the decimal precision, and the loss function to be used. Some assertions are also verified.

The `Experiment` class implements several methods:

- `single_fold()` instantiates a `CoGNN` as well as the optimizer and scheduler, then it calls `train_and_test()` to perform training and testing of the CoGNN. Thereafter it returns the performance metrics.

- `train_and_test()` essentially contains the training and testing structure. It sets up the dataloaders and initializes the running best losses and metrics via `get_worst_losses_n_metrics()`. The training loop operates by the following logic per epoch: (1) `train()` on the training set, (2) `test()` on the training set, (3) `test()` on the validation set and on the test set, (4) store the metrics and, if a scheduler is given, perform a scheduler step. At the very end, for non-synthetic datasets, we `test()` once more on the test set to collect the edge ratios.

- `train()` performs a batch-based training step with backpropagation and an optimizer step.

- `test()` evaluates the model by computing the outputs of a dataset and applying the loss function to determine the losses and performance metric scores which it then returns.

- `run()` is, at a high level, the core script of the experiment. It loads the dataset into memory, splits it into folds, and runs an n-fold training loop which calls `single_fold()`, where the model training and evaluation happens; it then computes and outputs the experiment results.

In [None]:
class Experiment(object):
    """End-to-end driver that trains and evaluates a `CoGNN` over the dataset folds.

    Every key/value pair of `args` is copied onto the instance as an attribute; the
    methods below read their configuration (e.g. `seed`, `dataset`, `pos_enc`, `fold`,
    `batch_size`, `max_epochs`, `lr`, and the `env_*` / `act_*` model settings) from there.
    """

    def __init__(self, args: dict):  # args type used to be Namespace
        super().__init__()
        for key, value in args.items():  # Adjusted loop
            setattr(self, key, value)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        set_seed(seed=self.seed)

        # parameters
        self.metric_type = self.dataset.get_metric_type()
        self.decimal = self.dataset.num_after_decimal()
        self.task_loss = self.metric_type.get_task_loss()

        # asserts
        self.dataset.asserts(args)

    def run(self) -> Tuple[List[float], OptTensor]:
        """Load the dataset, train/evaluate one model per fold and print the aggregate.

        Returns:
            `metrics_mean`: the per-fold metrics averaged over folds, as a list that is
            printed as (train, val, test); and the edge ratios averaged over folds, or
            None when no fold produced edge ratios.
        """
        dataset = self.dataset.load(seed=self.seed, pos_enc=self.pos_enc)
        if self.metric_type.is_multilabel():
            dataset.data.y = dataset.data.y.to(dtype=torch.float)

        folds = self.dataset.get_folds(fold=self.fold)

        # locally used parameters
        out_dim = self.metric_type.get_out_dim(dataset=dataset)
        gin_mlp_func = self.dataset.gin_mlp_func()
        env_act_type = self.dataset.env_activation_type()

        # named tuples
        gumbel_args = GumbelArgs(learn_temp=self.learn_temp, temp_model_type=self.temp_model_type, tau0=self.tau0,
                                 temp=self.temp, gin_mlp_func=gin_mlp_func)
        env_args = EnvArgs(model_type=self.env_model_type, num_layers=self.env_num_layers, env_dim=self.env_dim,
                           layer_norm=self.layer_norm, skip=self.skip, batch_norm=self.batch_norm, dropout=self.dropout,
                           act_type=env_act_type, metric_type=self.metric_type, in_dim=dataset[0].x.shape[1], out_dim=out_dim,
                           gin_mlp_func=gin_mlp_func, dec_num_layers=self.dec_num_layers, pos_enc=self.pos_enc,
                           dataset_encoders=self.dataset.get_dataset_encoders())
        action_args = ActionNetArgs(model_type=self.act_model_type, num_layers=self.act_num_layers,
                                    hidden_dim=self.act_dim, dropout=self.dropout, act_type=ActivationType.RELU,
                                    env_dim=self.env_dim, gin_mlp_func=gin_mlp_func)

        # folds
        metrics_list = []
        edge_ratios_list = []
        for num_fold in folds:
            # Re-seed before every fold so each fold starts from the same RNG state.
            set_seed(seed=self.seed)
            dataset_by_split = self.dataset.select_fold_and_split(num_fold=num_fold, dataset=dataset)
            best_losses_n_metrics, edge_ratios =\
                self.single_fold(dataset_by_split=dataset_by_split, gumbel_args=gumbel_args, env_args=env_args,
                                 action_args=action_args, num_fold=num_fold)

            # print final
            print_str = f'Fold {num_fold}/{len(folds)}'
            for name in best_losses_n_metrics._fields:
                print_str += f",{name}={round(getattr(best_losses_n_metrics, name), self.decimal)}"
            # The summary line was previously assembled but never emitted.
            print(print_str)
            metrics_list.append(best_losses_n_metrics.get_fold_metrics())

            if edge_ratios is not None:
                edge_ratios_list.append(edge_ratios)

        metrics_matrix = torch.stack(metrics_list, dim=0)  # (F, 3)
        metrics_mean = torch.mean(metrics_matrix, dim=0).tolist()  # (3,)
        if len(edge_ratios_list) > 0:
            edge_ratios = torch.mean(torch.stack(edge_ratios_list, dim=0), dim=0)
        else:
            edge_ratios = None

        # prints
        print(f'Final Rewired train={round(metrics_mean[0], self.decimal)},'
              f'val={round(metrics_mean[1], self.decimal)},'
              f'test={round(metrics_mean[2], self.decimal)}')
        if len(folds) > 1:
            # A standard deviation is only reported when there is more than one fold.
            metrics_std = torch.std(metrics_matrix, dim=0).tolist()  # (3,)
            print(f'Final Rewired train={round(metrics_mean[0], self.decimal)}+-{round(metrics_std[0], self.decimal)},'
                  f'val={round(metrics_mean[1], self.decimal)}+-{round(metrics_std[1], self.decimal)},'
                  f'test={round(metrics_mean[2], self.decimal)}+-{round(metrics_std[2], self.decimal)}')

        return metrics_mean, edge_ratios

    def single_fold(self, dataset_by_split: DatasetBySplit, gumbel_args: GumbelArgs, env_args: EnvArgs,
                    action_args: ActionNetArgs, num_fold: int) -> Tuple[LossesAndMetrics, OptTensor]:
        """Build a fresh `CoGNN`, its optimizer and scheduler, and train/evaluate it on one fold.

        Returns whatever `train_and_test` returns: the best losses/metrics and the edge ratios.
        """
        model = CoGNN(gumbel_args=gumbel_args, env_args=env_args, action_args=action_args,
                      pool=self.pool).to(device=self.device)

        optimizer = self.dataset.optimizer(model=model, lr=self.lr, weight_decay=self.weight_decay)
        scheduler = self.dataset.scheduler(optimizer=optimizer, step_size=self.step_size, gamma=self.gamma,
                                           num_warmup_epochs=self.num_warmup_epochs, max_epochs=self.max_epochs)

        with tqdm.tqdm(total=self.max_epochs, file=sys.stdout) as pbar:
            best_losses_n_metrics, edge_ratios = self.train_and_test(dataset_by_split=dataset_by_split, model=model, optimizer=optimizer,
                                                                     scheduler=scheduler, pbar=pbar, num_fold=num_fold)
        return best_losses_n_metrics, edge_ratios

    def train_and_test(self, dataset_by_split: DatasetBySplit, model, optimizer, scheduler, pbar, num_fold: int)\
            -> Tuple[LossesAndMetrics, OptTensor]:
        """Run `max_epochs` epochs of training with per-epoch evaluation on all splits.

        The epoch whose validation metric is best (according to
        `metric_type.src_better_than_other`) determines the returned losses and metrics.
        For non-synthetic datasets the test split is evaluated once more after training
        to obtain the edge ratios; otherwise None is returned for them.
        """
        # NOTE(review): the validation/test loaders are shuffled as well. The aggregated
        # metric is computed over the concatenated predictions, so ordering should not
        # matter; kept unchanged to avoid altering random-number consumption - confirm.
        train_loader = DataLoader(dataset_by_split.train, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(dataset_by_split.val, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(dataset_by_split.test, batch_size=self.batch_size, shuffle=True)

        best_losses_n_metrics = self.metric_type.get_worst_losses_n_metrics()
        for epoch in range(self.max_epochs):
            self.train(train_loader=train_loader, model=model, optimizer=optimizer)
            train_loss, train_metric, _ = self.test(loader=train_loader, model=model, split_mask_name='train_mask', calc_edge_ratio=False)
            if self.dataset.is_expressivity():
                # Expressivity datasets reuse the training results for validation and test.
                val_loss, val_metric = train_loss, train_metric
                test_loss, test_metric = train_loss, train_metric
            else:
                val_loss, val_metric, _ = self.test(loader=val_loader, model=model, split_mask_name='val_mask', calc_edge_ratio=False)
                test_loss, test_metric, _ = self.test(loader=test_loader, model=model, split_mask_name='test_mask', calc_edge_ratio=False)

            losses_n_metrics = LossesAndMetrics(train_loss=train_loss, val_loss=val_loss, test_loss=test_loss,
                                                train_metric=train_metric, val_metric=val_metric, test_metric=test_metric)
            if scheduler is not None:
                # NOTE(review): the validation metric is passed to `step`; this presumably
                # targets a plateau-style scheduler - verify for the scheduler in use.
                scheduler.step(losses_n_metrics.val_metric)

            # best metrics
            if self.metric_type.src_better_than_other(src=losses_n_metrics.val_metric, other=best_losses_n_metrics.val_metric):
                best_losses_n_metrics = losses_n_metrics

            # prints
            log_str = f'Split: {num_fold}, epoch: {epoch}'
            for name in losses_n_metrics._fields:
                log_str += f",{name}={round(getattr(losses_n_metrics, name), self.decimal)}"
            log_str += f"({round(best_losses_n_metrics.test_metric, self.decimal)})"
            pbar.set_description(log_str)
            pbar.update(n=1)

        edge_ratios = None
        if self.dataset.not_synthetic():
            _, _, edge_ratios = self.test(loader=test_loader, model=model, split_mask_name='test_mask', calc_edge_ratio=True)

        return best_losses_n_metrics, edge_ratios

    def train(self, train_loader, model, optimizer):
        """Perform one pass over `train_loader` with a backward pass and optimizer step per batch."""
        model.train()

        for data in train_loader:
            # Skip batches holding a single node or a single graph when batch norm is enabled.
            if self.batch_norm and (data.x.shape[0] == 1 or data.num_graphs == 1):
                continue
            optimizer.zero_grad()
            node_mask = self.dataset.get_split_mask(data=data, batch_size=data.num_graphs,
                                                    split_mask_name='train_mask').to(self.device)
            edge_attr = data.edge_attr
            if data.edge_attr is not None:
                edge_attr = edge_attr.to(device=self.device)

            # forward
            scores, _ =\
                model(data.x.to(device=self.device), edge_index=data.edge_index.to(device=self.device),
                      batch=data.batch.to(device=self.device), edge_attr=edge_attr, edge_ratio_node_mask=None,
                      pestat=self.pos_enc.get_pe(data=data, device=self.device))
            # The loss only covers the entries selected by the training mask.
            train_loss = self.task_loss(scores[node_mask], data.y.to(device=self.device)[node_mask])

            # backward
            train_loss.backward()
            if self.dataset.clip_grad():
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

    def test(self, loader, model, split_mask_name: str, calc_edge_ratio: bool)\
            -> Tuple[float, Any, Tensor]:
        """Evaluate `model` on `loader`, restricted to the split named by `split_mask_name`.

        Both the loss and the metric are computed on the entries selected by the split
        mask (the same restriction `train` applies to its loss). Gradients are disabled
        for the whole evaluation.

        Args:
            loader: iterable of batches; `len(loader.dataset)` is used for averaging.
            model: the model to evaluate; it is switched to eval mode.
            split_mask_name: name of the split mask requested from the dataset.
            calc_edge_ratio: whether to request per-layer edge ratios from the model.

        Returns:
            The loss (batch losses weighted by `num_graphs`, divided by the dataset
            length), the metric over all collected masked predictions, and the edge
            ratios averaged the same way as the loss.
        """
        model.eval()

        total_loss, total_edge_ratios = 0, 0
        total_scores = np.empty(shape=(0, model.env_args.out_dim))
        total_y = None
        # Evaluation never back-propagates, so do not build an autograd graph.
        with torch.no_grad():
            for data in loader:
                if self.batch_norm and (data.x.shape[0] == 1 or data.num_graphs == 1):
                    continue
                node_mask = self.dataset.get_split_mask(data=data, batch_size=data.num_graphs,
                                                        split_mask_name=split_mask_name).to(device=self.device)
                if calc_edge_ratio:
                    edge_ratio_node_mask = self.dataset.get_edge_ratio_node_mask(
                        data=data, split_mask_name=split_mask_name).to(device=self.device)
                else:
                    edge_ratio_node_mask = None
                edge_attr = data.edge_attr
                if data.edge_attr is not None:
                    edge_attr = edge_attr.to(device=self.device)

                # forward
                scores, edge_ratios =\
                    model(data.x.to(device=self.device), edge_index=data.edge_index.to(device=self.device),
                          edge_attr=edge_attr, batch=data.batch.to(device=self.device),
                          edge_ratio_node_mask=edge_ratio_node_mask,
                          pestat=self.pos_enc.get_pe(data=data, device=self.device))

                # Restrict loss AND metric to the requested split; previously the loss
                # was taken over every entry, irrespective of the split mask.
                masked_scores = scores[node_mask]
                masked_y = data.y.to(device=self.device)[node_mask]
                eval_loss = self.task_loss(masked_scores, masked_y)

                # analytics
                total_scores = np.concatenate((total_scores, masked_scores.detach().cpu().numpy()))
                masked_y_np = masked_y.detach().cpu().numpy()
                if total_y is None:
                    total_y = masked_y_np
                else:
                    total_y = np.concatenate((total_y, masked_y_np))

                total_loss += eval_loss.item() * data.num_graphs
                total_edge_ratios += edge_ratios * data.num_graphs

        metric = self.metric_type.apply_metric(scores=total_scores, target=total_y)

        loss = total_loss / len(loader.dataset)
        edge_ratios = total_edge_ratios / len(loader.dataset)
        return loss, metric, edge_ratios

The last thing we have to do now is execute the experiment.

In [None]:
# Device selection: the former `ARGS['gpu']`-based `set_device` call is disabled;
# instead the first CUDA device is made the default whenever CUDA is available.
GPU_AVAILABLE = torch.cuda.is_available()
if GPU_AVAILABLE:
    device = torch.device("cuda", 0)
    torch.cuda.set_device(device)

# Build the experiment from the argument dictionary and run all folds.
Experiment(args=ARGS).run()