# Cooperative Graph Neural Networks ([CoGNN](https://doi.org/10.48550/arXiv.2310.01267))

This part was adapted by Tobias Erbacher from the [authors' GitHub repository](https://github.com/benfinkelshtein/CoGNN/tree/main). We recommend running it on a GPU service such as [Google Colab](https://colab.research.google.com/). The goal is to modify the original authors' code so that it works with our codebase.

## Installation

To ensure we are using the same modules as the authors, we need to install the following:

In [1]:
#%pip install torch==2.0.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
#%pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
#%pip install torch-geometric==2.3.0
#%pip install torchmetrics ogb rdkit
#%pip install matplotlib

To execute this notebook, we assume that the datasets are already available locally. If you cloned our GitHub repository, you will find them in [this folder](https://github.com/TobiasErbacher/gdl/tree/main/replication/data).

---

## Initialization

### Custom Files

To make this notebook less crowded, some class definitions have been moved to other files. The overview below gives an idea of the library dependencies. The file `architectures.py` performs the following imports:

- from `torch` import `Tensor, cat`
- from `torch.nn` import `Linear, Parameter`
- from `typing` import `Callable, List`
- from `torch_geometric.typing` import `OptTensor, Adj`
- from `torch_geometric.nn.conv` import `MessagePassing`
- from `torch_geometric.nn.conv.gcn_conv` import `gcn_norm`
- from `torch_geometric.utils` import `remove_self_loops, add_remaining_self_loops`

The type classes are bundled in the file `type_classes.py`, which performs the following imports:

- from `torch` import `from_numpy, tensor`
- from `torch.nn` import `Module, ReLU, GELU, CrossEntropyLoss`
- import `torch.nn.functional` as `F`
- from `torchmetrics` import `Accuracy`
- from `torch_geometric.data` import `Data`
- from `torch_geometric.nn.pool` import `global_mean_pool, global_add_pool`
- import `numpy` as `np`
- from `math` import `inf`
- from `enum` import `Enum, auto`
- from `typing` import `Callable, List, NamedTuple`
- from `architectures` import `WeightedGCNConv, WeightedGINConv, WeightedGNNConv, GraphLinear, BatchIdentity`

Note that `architectures` here refers to the local file `architectures.py`.

The encoders are bundled in the file `encoder_classes.py`, which performs the following imports:

- from `enum` import `Enum, auto`
- from `torch` import `cat, rand, isnan, sum, Tensor`
- from `torch.nn` import `Module, ModuleList, Linear, Sequential, BatchNorm1d, ReLU, TransformerEncoder, TransformerEncoderLayer, Embedding`
- from `torch.nn.init` import `xavier_uniform_`
- from `ogb.utils.features` import `get_atom_feature_dims, get_bond_feature_dims`
- from `torch_geometric.data` import `Data`

### Importing the Custom Files

In [2]:
from type_classes import ModelType, LossesAndMetrics, Pool, ActivationType, Metric
from encoder_classes import PosEncoder, DataSetEncoders

### Importing the Official Libraries

We first import the required libraries:

In [3]:
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.loader import DataLoader
from torch_geometric.typing import OptTensor, Adj
from torch_geometric.transforms import NormalizeFeatures
from scipy.sparse import csr_matrix

import os
import sys
import tqdm
import numpy as np
import random
from typing import Union, List, Tuple, NamedTuple, Callable

Next, we need to define the root directory `ROOT_DIR`:

In [4]:
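try:
    # If running as a .py script, __file__ is defined
    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # If running as a .ipynb notebook, __file__ is undefined,
    # so we fall back to the parent of the current working directory
    ROOT_DIR = os.path.dirname(os.path.abspath(os.getcwd()))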

To ensure reproducibility, we define the `set_seed()` function:

In [5]:
def set_seed(seed):
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = False

The device to be used for training (GPU, ...) will be set automatically in the `Experiment` class (see below).

### Dataset Loader

The `.npz` files stored in the `data` folder all have the following keys or similar ones (see the list in the code):

- `adj_data`
- `adj_indices`
- `adj_indptr`
- `adj_shape`
- `attr_data`
- `attr_indices`
- `attr_indptr`
- `attr_shape`
- `labels`
- `class_names`

As an example, the dataset `A.Computer` contains the following class names:

- `Desktops`
- `Data Storage`
- `Laptops`
- `Monitors`
- `Computer Components`
- `Video Projectors`
- `Routers`
- `Tablets`
- `Networking Products`
- `Webcams`

They are represented in the data labels `labels` as numbers in `[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]`.

This dataset has `13752` nodes, as can be inferred from the `adj_shape` key, which gives the dimensions `[13752 13752]`. Each element of `labels` provides the label of the corresponding node. The `adj_*` keys let us reconstruct the adjacency matrix and the `attr_*` keys let us reconstruct the node features.
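To make the role of the `*_data`, `*_indices` and `*_indptr` keys more concrete, here is a minimal pure-Python sketch of the compressed sparse row (CSR) layout that `scipy.sparse.csr_matrix` decodes. The toy arrays and the helper `csr_to_dense` are made up for illustration and are not part of the datasets or of our codebase.

```python
def csr_to_dense(data, indices, indptr, shape):
    """Expand CSR arrays into a dense list-of-lists matrix."""
    n_rows, n_cols = shape
    dense = [[0] * n_cols for _ in range(n_rows)]
    for row in range(n_rows):
        # The non-zero entries of `row` live in data[indptr[row]:indptr[row + 1]],
        # and indices[k] is the column of the k-th stored entry.
        for k in range(indptr[row], indptr[row + 1]):
            dense[row][indices[k]] = data[k]
    return dense


# Hypothetical 3-node graph with the edges 0->1, 1->0, 1->2 and 2->1.
adj_data = [1, 1, 1, 1]
adj_indices = [1, 0, 2, 1]
adj_indptr = [0, 1, 3, 4]
adj_shape = (3, 3)

adjacency = csr_to_dense(adj_data, adj_indices, adj_indptr, adj_shape)
for row in adjacency:
    print(row)
```

In the notebook itself, this decoding is delegated to `csr_matrix`, which avoids materialising the dense matrix for the adjacency.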

We can load the dataset from `.npz` files with the `load_dataset` function:

In [6]:
def load_dataset(
    dataset_name : str
) -> Tuple[Data, int]:
    """
        Loads a dataset from an .npz file and returns a tuple (data, num_classes),
        where data is a PyTorch Geometric Data object with
        -   x:              Feature matrix,
        -   y:              Labels,
        -   edge_index:     List of edges,
        -   edge_attr:      List of edge weights.

        The function is currently implemented for the following datasets:
        -   A.Computer
        -   A.Photo
        -   Citeseer
        -   Cora-ML
        -   MS-Academic
        -   PubMed
    """
    path = os.path.join(ROOT_DIR, "replication/data/" + dataset_name + ".npz")
    npz = np.load(path)
    assert isinstance(npz, np.lib.npyio.NpzFile)

    variant1 = [(key in ["adj_data",
                         "adj_indices",
                         "adj_indptr",
                         "adj_shape",
                         "attr_data",
                         "attr_indices",
                         "attr_indptr",
                         "attr_shape",
                         "labels",
                         "class_names"])
                for key in npz.keys()]
    isvar1 = False not in variant1
    
    variant2 = [(key in ["adj_data",
                         "adj_indices",
                         "adj_indptr",
                         "adj_shape",
                         "attr_data",
                         "attr_indices",
                         "attr_indptr",
                         "attr_shape",
                         "labels",
                         "node_names",
                         "attr_names",
                         "class_names"])
                for key in npz.keys()]
    isvar2 = False not in variant2
    
    variant3 = [(key in ["adj_matrix.data",
                         "adj_matrix.indices",
                         "adj_matrix.indptr",
                         "adj_matrix.shape",
                         "attr_matrix.data",
                         "attr_matrix.indices",
                         "attr_matrix.indptr",
                         "attr_matrix.shape",
                         "edge_attr_matrix",
                         "labels",
                         "node_names",
                         "attr_names",
                         "edge_attr_names",
                         "class_names",
                         "metadata"])
                for key in npz.keys()]
    isvar3 = False not in variant3

    match dataset_name:
        case name if name in ["A.Computer", "A.Photo", "MS-Academic"]:
            assert isvar1 or isvar2
            features_csr = csr_matrix((npz["attr_data"], npz["attr_indices"], npz["attr_indptr"]), shape=npz["attr_shape"])
            adjacency_csr = csr_matrix((npz["adj_data"], npz["adj_indices"], npz["adj_indptr"]), shape=npz["adj_shape"])
        case name if name in ["Citeseer", "Cora-ML", "PubMed"]:
            assert isvar3
            features_csr = csr_matrix((npz["attr_matrix.data"], npz["attr_matrix.indices"], npz["attr_matrix.indptr"]), shape=npz["attr_matrix.shape"])
            adjacency_csr = csr_matrix((npz["adj_matrix.data"], npz["adj_matrix.indices"], npz["adj_matrix.indptr"]), shape=npz["adj_matrix.shape"])
        case _:
            raise NotImplementedError(f"The dataset '{dataset_name}' is not supported. The keys found in the file are \n{list(npz.keys())}")
    
    x = torch.FloatTensor(features_csr.toarray())
    edge_index, edge_weight = from_scipy_sparse_matrix(adjacency_csr)
    y = torch.LongTensor(npz["labels"])

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight, y=y)
    if len(np.unique(x)) > 2:
        transform = NormalizeFeatures()
        data = transform(data)

    num_classes = len(set(npz["labels"]))
    
    return data, num_classes
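The `variant1`, `variant2` and `variant3` checks in `load_dataset` each amount to a subset test: they pass as long as the file contains no key outside the allowed list, and they do not verify that every listed key is present. A minimal sketch of the same logic with Python sets is shown below; the helper name `only_known_keys` is an assumption introduced for this illustration.

```python
# Allowed keys of the first variant, copied from load_dataset above.
VARIANT1_KEYS = {
    "adj_data", "adj_indices", "adj_indptr", "adj_shape",
    "attr_data", "attr_indices", "attr_indptr", "attr_shape",
    "labels", "class_names",
}


def only_known_keys(file_keys, allowed_keys):
    """True if every key found in the file is one of the allowed keys."""
    return set(file_keys) <= set(allowed_keys)


# A file with a subset of the allowed keys passes the check ...
print(only_known_keys(["adj_data", "labels"], VARIANT1_KEYS))
# ... while a file with an unexpected key does not.
print(only_known_keys(["adj_data", "metadata"], VARIANT1_KEYS))
```

If missing keys should also be rejected, comparing the two sets for equality instead of using `<=` would be one option.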

## Parameters

In this section we will define the parameters to use.

In [7]:
def create_config(
    dataset: str,                                                                           # Only the dataset file name, without the .npz ending.
    device : torch.device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),     # Can override this manually as torch.device object.
    pool: Pool = Pool.NONE,

    # gumbel
    learn_temp=False,
    temp_model_type: ModelType.from_string=ModelType.LIN,
    tau0: float=0.5,
    temp: float=0.01,

    # optimization
    num_epochs: int=1000,
    batch_size: int=32,
    lr: float=1e-3,
    dropout: float=0.2,

    # env cls parameters
    env_model_type: ModelType.from_string=ModelType.MEAN_GNN,
    env_num_layers: int=3,
    env_dim: int=128,
    skip=False,
    batch_norm=False,
    layer_norm=False,
    dec_num_layers: int=1,
    pos_enc: PosEncoder.from_string=PosEncoder.NONE,

    # policy cls parameters
    act_model_type: ModelType=ModelType.MEAN_GNN,
    act_num_layers: int=1,
    act_dim: int=16,

    # reproduce
    seed: int=0,
    gpu: int=None, # In the original argument parser, there is no default value defined here

    # dataset dependent parameters
    fold: int=None,

    # optimizer and scheduler
    weight_decay: float=0.0,
    ## for steplr scheduler only
    step_size: int=None,
    gamma: float=None,
    ## for cosine with warmup scheduler only
    num_warmup_epochs: int=None,

    decimal: int=2
):
    return {
        'dataset' : dataset,
        'device' : device,
        'pool' : pool,
        'learn_temp' : learn_temp,
        'temp_model_type' : temp_model_type,
        'tau0' : tau0,
        'temp' : temp,
        'num_epochs' : num_epochs,
        'batch_size' : batch_size,
        'lr' : lr,
        'dropout' : dropout,
        'env_model_type' : env_model_type,
        'env_num_layers' : env_num_layers,
        'env_dim' : env_dim,
        'skip' : skip,
        'batch_norm' : batch_norm,
        'layer_norm' : layer_norm,
        'dec_num_layers' : dec_num_layers,
        'pos_enc' : pos_enc,
        'act_model_type' : act_model_type,
        'act_num_layers' : act_num_layers,
        'act_dim' : act_dim,
        'seed' : seed,
        'gpu' : gpu,
        'fold' : fold,
        'weight_decay' : weight_decay,
        'step_size' : step_size,
        'gamma' : gamma,
        'num_warmup_epochs' : num_warmup_epochs,
        "decimal" : decimal
    }

## Dataset Class

In [8]:
class DatasetBySplit(NamedTuple):
    train: Union[Data, List[Data]]
    val: Union[Data, List[Data]]
    test: Union[Data, List[Data]]

In [9]:
class DataSet:
    def __init__(self, data: Data, train_mask: np.ndarray, val_mask: np.ndarray, test_mask: np.ndarray):
        self.data = data
        self.train_mask = train_mask
        self.val_mask = val_mask
        self.test_mask = test_mask
        self.family = True
        self.is_node_based = True
        self.not_synthetic = True
        self.is_expressivity = False
        self.clip_grad = False
        self.dataset_encoders = DataSetEncoders.NONE
        self.num_after_decimal = 2
        self.env_activation_type = ActivationType.RELU
    
    def select_fold_and_split(self, dataset: Data) -> DatasetBySplit:
        device = dataset.x.device
        train_mask = torch.tensor(self.train_mask, dtype=torch.bool, device=device)
        val_mask = torch.tensor(self.val_mask, dtype=torch.bool, device=device)
        test_mask = torch.tensor(self.test_mask, dtype=torch.bool, device=device)

        setattr(dataset, "train_mask", train_mask)
        setattr(dataset, "val_mask", val_mask)
        setattr(dataset, "test_mask", test_mask)

        dataset.train_mask[dataset.non_valid_samples] = False
        dataset.test_mask[dataset.non_valid_samples] = False
        dataset.val_mask[dataset.non_valid_samples] = False
        
        return DatasetBySplit(train=dataset, val=dataset, test=dataset)

    def gin_mlp_func(self) -> Callable:
        def mlp_func(in_channels: int, out_channels: int, bias: bool):
            return torch.nn.Sequential(
                torch.nn.Linear(in_channels, 2 * in_channels, bias=bias),
                torch.nn.BatchNorm1d(2 * in_channels),
                torch.nn.ReLU(), torch.nn.Linear(2 * in_channels, out_channels, bias=bias)
            )
        return mlp_func
    
    def get_split_mask(self, data: Data, batch_size: int, split_mask_name: str) -> Tensor:
        if hasattr(data, split_mask_name):
            return getattr(data, split_mask_name)
        elif self.is_node_based:  # is_node_based is a boolean attribute, not a method
            return torch.ones(size=(data.x.shape[0],), dtype=torch.bool)
        else:
            return torch.ones(size=(batch_size,), dtype=torch.bool)
    
    def get_edge_ratio_node_mask(self, data: Data, split_mask_name: str) -> Tensor:
        if hasattr(data, split_mask_name):
            return getattr(data, split_mask_name)
        else:
            return torch.ones(size=(data.x.shape[0],), dtype=torch.bool)

## Network Classes

The Action Network (Policy):

In [10]:
def load_act_net(action_args : dict) -> torch.nn.ModuleList:
    model_type = action_args["model_type"]
    env_dim = action_args["env_dim"]
    hidden_dim = action_args["hidden_dim"]
    num_layers = action_args["num_layers"]
    gin_mlp_func = action_args["gin_mlp_func"]

    net = model_type.get_component_list(in_dim=env_dim,
                                        hidden_dim=hidden_dim,
                                        out_dim=2,
                                        num_layers=num_layers,
                                        bias=True,
                                        edges_required=False,
                                        gin_mlp_func=gin_mlp_func)

    return torch.nn.ModuleList(net)

In [11]:
class ActionNet(torch.nn.Module):
    def __init__(self, action_args: dict):
        """
        Create a model which represents the agent's policy.
        """
        super().__init__()
        self.num_layers = action_args["num_layers"]
        self.net = load_act_net(action_args=action_args)
        self.dropout = torch.nn.Dropout(action_args["dropout"])
        self.act = action_args["act_type"].get()

    def forward(self, x: Tensor, edge_index: Adj, env_edge_attr: OptTensor, act_edge_attr: OptTensor) -> Tensor:
        edge_attrs = [env_edge_attr] + (self.num_layers - 1) * [act_edge_attr]
        for idx, (edge_attr, layer) in enumerate(zip(edge_attrs[:-1], self.net[:-1])):
            x = layer(x=x, edge_index=edge_index, edge_attr=edge_attr)
            x = self.dropout(x)
            x = self.act(x)
        x = self.net[-1](x=x, edge_index=edge_index, edge_attr=edge_attrs[-1])
        return x

Optional dynamic temperature calculation for the Gumbel Softmax:

In [12]:
class TempSoftPlus(torch.nn.Module):
    def __init__(self, gumbel_args: dict, env_dim: int):
        super(TempSoftPlus, self).__init__()
        model_list = gumbel_args["temp_model_type"].get_component_list(in_dim=env_dim, hidden_dim=env_dim, out_dim=1, num_layers=1,
                                                                       bias=False, edges_required=False,
                                                                       gin_mlp_func=gumbel_args["gin_mlp_func"])
        self.linear_model = torch.nn.ModuleList(model_list)
        self.softplus = torch.nn.Softplus(beta=1)
        self.tau0 = gumbel_args["tau0"]

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor):
        x = self.linear_model[0](x=x, edge_index=edge_index, edge_attr=edge_attr)
        x = self.softplus(x) + self.tau0
        temp = x.pow_(-1)
        return temp.masked_fill_(temp == float("inf"), 0.)
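Read as a formula, `TempSoftPlus` appears to compute one value per node of the form 1 / (softplus(z) + `tau0`), where z is the output of the single-layer model, and the result is passed as `tau` to the Gumbel softmax in `CoGNN.forward`. Since softplus is positive, this value stays between 0 and 1 / `tau0`. The scalar sketch below illustrates this with the standard library only; the helper names `softplus` and `learned_temperature` are assumptions made for this example.

```python
import math


def softplus(z: float) -> float:
    """Numerically stable softplus, i.e. log(1 + exp(z))."""
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def learned_temperature(z: float, tau0: float = 0.5) -> float:
    """Scalar version of the TempSoftPlus output for a model output z."""
    return 1.0 / (softplus(z) + tau0)


# Larger model outputs give smaller values; all of them stay below 1 / tau0 = 2.
for z in (-5.0, 0.0, 5.0):
    print(z, learned_temperature(z))
```

The default `tau0=0.5` mirrors the default in `create_config`.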

## CoGNN

Here we will implement the CoGNN architecture.

In [13]:
def load_env_net(env_args : dict) -> torch.nn.ModuleList:
    in_dim = env_args["in_dim"]
    env_dim = env_args["env_dim"]
    num_layers = env_args["num_layers"]
    gin_mlp_func = env_args["gin_mlp_func"]
    dec_num_layers = env_args["dec_num_layers"]
    dropout = env_args["dropout"]
    act_type = env_args["act_type"]
    out_dim = env_args["out_dim"]

    enc_list = [env_args["dataset_encoders"].node_encoder(in_dim=in_dim, emb_dim=env_dim)]

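    # The encoder above maps the node features to env_dim, so the first GNN layer takes env_dim inputs.
    component_list = env_args["model_type"].get_component_list(in_dim=env_dim, hidden_dim=env_dim, out_dim=env_dim,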
                                                               num_layers=num_layers, bias=True, edges_required=True,
                                                               gin_mlp_func=gin_mlp_func)

    if dec_num_layers > 1:
        mlp_list = (dec_num_layers - 1) * [torch.nn.Linear(env_dim, env_dim), torch.nn.Dropout(dropout), act_type.nn()]
        mlp_list = mlp_list + [torch.nn.Linear(env_dim, out_dim)]
        dec_list = [torch.nn.Sequential(*mlp_list)]
    else:
        dec_list = [torch.nn.Linear(env_dim, out_dim)]

    return torch.nn.ModuleList(enc_list + component_list + dec_list)
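One detail of the decoder construction in `load_env_net` is worth keeping in mind: multiplying a Python list repeats references to the same objects rather than creating copies. For `dec_num_layers > 2`, the repeated `Linear` entries in `mlp_list` would therefore be the same module instance and share their parameters. Whether this sharing is intended is not stated here. The sketch below shows the underlying Python behaviour with a stand-in class (`Layer` is a hypothetical placeholder, not a `torch.nn.Module`).

```python
class Layer:
    """Stand-in for a module holding one mutable 'weight'."""

    def __init__(self):
        self.weight = [0.0]


# List multiplication repeats the *same* object ...
repeated = 2 * [Layer()]
# ... while a list comprehension creates a new object per entry.
independent = [Layer() for _ in range(2)]

repeated[0].weight[0] = 1.0
independent[0].weight[0] = 1.0

print(repeated[0] is repeated[1], repeated[1].weight[0])
print(independent[0] is independent[1], independent[1].weight[0])
```

If separate parameters per decoder layer are wanted, building the list with a comprehension would be one alternative.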

In [14]:
class CoGNN(torch.nn.Module):
    def __init__(
        self,
        gumbel_args: dict,
        env_args: dict,
        action_args: dict,
        pool: Pool
    ):
        super(CoGNN, self).__init__()
        self.env_args = env_args
        self.learn_temp = gumbel_args["learn_temp"]
        if self.learn_temp:
            self.temp_model = TempSoftPlus(gumbel_args=gumbel_args, env_dim=env_args["env_dim"])
        self.temp = gumbel_args["temp"]

        self.num_layers = env_args["num_layers"]
        self.env_net = load_env_net(env_args=env_args)
        self.use_encoders = env_args["dataset_encoders"].use_encoders()

        layer_norm_cls = torch.nn.LayerNorm if env_args["layer_norm"] else torch.nn.Identity
        self.hidden_layer_norm = layer_norm_cls(env_args["env_dim"])
        self.skip = env_args["skip"]
        self.drop_ratio = env_args["dropout"]
        self.dropout = torch.nn.Dropout(p=self.drop_ratio)
        self.act = env_args["act_type"].get()
        self.in_act_net = ActionNet(action_args=action_args)
        self.out_act_net = ActionNet(action_args=action_args)

        # Encoder types
        self.dataset_encoder = env_args["dataset_encoders"]
        self.env_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=env_args["env_dim"], model_type=env_args["model_type"])
        self.act_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=action_args["hidden_dim"], model_type=action_args["model_type"])

        # Pooling function to generate whole-graph embeddings
        self.pooling = pool.get()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        pestat,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        edge_ratio_node_mask: OptTensor = None
    ) -> Tuple[Tensor, Tensor]:
        result = 0

        calc_stats = edge_ratio_node_mask is not None
        if calc_stats:
            edge_ratio_edge_mask = edge_ratio_node_mask[edge_index[0]] & edge_ratio_node_mask[edge_index[1]]
            edge_ratio_list = []

        # bond encode
        if edge_attr is None or self.env_bond_encoder is None:
            env_edge_embedding = None
        else:
            env_edge_embedding = self.env_bond_encoder(edge_attr)
        if edge_attr is None or self.act_bond_encoder is None:
            act_edge_embedding = None
        else:
            act_edge_embedding = self.act_bond_encoder(edge_attr)

        # node encode  
        x = self.env_net[0](x, pestat) # (N, F) encoder
        if not self.use_encoders:
            x = self.dropout(x)
            x = self.act(x)

        for gnn_idx in range(self.num_layers):
            x = self.hidden_layer_norm(x)

            # action
            in_logits = self.in_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding, act_edge_attr=act_edge_embedding) # (N, 2)
            out_logits = self.out_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding, act_edge_attr=act_edge_embedding) # (N, 2)

            temp = self.temp_model(x=x, edge_index=edge_index, edge_attr=env_edge_embedding) if self.learn_temp else self.temp
            in_probs = torch.nn.functional.gumbel_softmax(logits=in_logits, tau=temp, hard=True)
            out_probs = torch.nn.functional.gumbel_softmax(logits=out_logits, tau=temp, hard=True)
            edge_weight = self.create_edge_weight(edge_index=edge_index, keep_in_prob=in_probs[:, 0], keep_out_prob=out_probs[:, 0])

            # environment
            out = self.env_net[1 + gnn_idx](x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=env_edge_embedding)
            out = self.dropout(out)
            out = self.act(out)

            if calc_stats:
                edge_ratio = edge_weight[edge_ratio_edge_mask].sum() / edge_weight[edge_ratio_edge_mask].shape[0]
                edge_ratio_list.append(edge_ratio.item())

            if self.skip:
                x = x + out
            else:
                x = out

        x = self.hidden_layer_norm(x)
        x = self.pooling(x, batch=batch)
        x = self.env_net[-1](x) # decoder
        result = result + x

        if calc_stats:
            edge_ratio_tensor = torch.tensor(edge_ratio_list, device=x.device)
        else:
            edge_ratio_tensor = -1 * torch.ones(size=(self.num_layers,), device=x.device)
        return result, edge_ratio_tensor

    def create_edge_weight(
        self,
        edge_index: Adj,
        keep_in_prob: Tensor,
        keep_out_prob: Tensor
    ) -> Tensor:
        u, v = edge_index
        edge_in_prob = keep_in_prob[v]
        edge_out_prob = keep_out_prob[u]
        return edge_in_prob * edge_out_prob
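The interplay of the two action networks and `create_edge_weight` can be sketched without PyTorch. With `hard=True`, the forward pass of the Gumbel softmax yields a one-hot vector per node, which corresponds to taking the argmax of the logits perturbed by Gumbel noise. Column 0 of the in-action is then used as "listen" and column 0 of the out-action as "broadcast", so an edge u -> v keeps weight 1 only if u broadcasts and v listens. The following is a simplified illustration under these assumptions; `sample_hard_action` and the toy graph are made up for this example, and the gradient (straight-through) part is omitted.

```python
import math
import random


def sample_hard_action(logits, rng):
    """Gumbel-max sample: one-hot vector of argmax(logits + Gumbel noise)."""
    noisy = []
    for logit in logits:
        u = max(rng.random(), 1e-12)  # avoid log(0)
        noisy.append(logit - math.log(-math.log(u)))
    best = max(range(len(logits)), key=noisy.__getitem__)
    return [1.0 if i == best else 0.0 for i in range(len(logits))]


def create_edge_weight(edge_index, keep_in_prob, keep_out_prob):
    """Weight of edge u -> v: (v listens) * (u broadcasts)."""
    sources, targets = edge_index
    return [keep_in_prob[v] * keep_out_prob[u] for u, v in zip(sources, targets)]


rng = random.Random(0)
edge_index = ([0, 1, 1, 2], [1, 0, 2, 1])  # edges 0->1, 1->0, 1->2, 2->1
in_logits = [[2.0, -2.0], [0.0, 0.0], [-2.0, 2.0]]   # one (keep, drop) pair per node
out_logits = [[0.0, 0.0], [2.0, -2.0], [0.0, 0.0]]

in_actions = [sample_hard_action(logits, rng) for logits in in_logits]
out_actions = [sample_hard_action(logits, rng) for logits in out_logits]
keep_in = [action[0] for action in in_actions]
keep_out = [action[0] for action in out_actions]

edge_weight = create_edge_weight(edge_index, keep_in, keep_out)
edge_ratio = sum(edge_weight) / len(edge_weight)  # fraction of kept edges
print(edge_weight, edge_ratio)
```

The `edge_ratio` computed at the end mirrors the statistic collected in `forward` when `edge_ratio_node_mask` is given, restricted there to edges whose endpoints both lie in the mask.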

## Experiment Definition

The `Experiment` class contains the procedure to load a dataset, run the training and evaluation, as well as return the performance metrics.

In [15]:
def train_CoGNN(
    model : torch.nn.Module,
    task_loss : torch.nn.CrossEntropyLoss,
    dataset : Data,
    optimizer : torch.optim.Optimizer,
    train_loader : DataLoader,
    config : dict
):
    """
        Training loop for one epoch.
    """
    model.train()

    for batch_data in train_loader:
        if config["batch_norm"] and (batch_data.x.shape[0] == 1 or batch_data.num_graphs == 1):
            continue
        optimizer.zero_grad()
        node_mask = dataset.get_split_mask(data=batch_data, batch_size=batch_data.num_graphs, split_mask_name='train_mask').to(config["device"])
        edge_attr = batch_data.edge_attr
        if batch_data.edge_attr is not None:
            edge_attr = edge_attr.to(device=config["device"])

        predictions, edge_ratio_tensor = model(x=batch_data.x.to(device=config["device"]),
                                               edge_index=batch_data.edge_index.to(device=config["device"]),
                                               batch=batch_data.batch.to(device=config["device"]),
                                               edge_attr=edge_attr,
                                               edge_ratio_node_mask=None,
                                               pestat=config["pos_enc"].get_pe(data=batch_data, device=config["device"]))

        train_loss = task_loss(predictions[node_mask], batch_data.y.to(device=config["device"])[node_mask])
        train_loss.backward()

        if dataset.clip_grad:  # clip_grad is a boolean attribute, not a method
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
        optimizer.step()

    return model, train_loss.item(), edge_ratio_tensor

In [16]:
def evaluate_CoGNN(
    model : torch.nn.Module,
    dataset : Data,
    config,
    task_loss,
    metric_type,
    split_mask_name=None,
    calc_edge_ratio=True
):
    model.eval()

    total_loss, total_metric, total_edge_ratios = 0, 0, 0
    total_scores = np.empty(shape=(0, model.env_args["out_dim"]))
    total_y = None

    loader = DataLoader(dataset, batch_size=config["batch_size"], shuffle=True)
    for batch_data in loader:
        if config["batch_norm"] and (batch_data.x.shape[0] == 1 or batch_data.num_graphs == 1):
            continue
        node_mask = dataset.get_split_mask(data=batch_data, batch_size=batch_data.num_graphs, split_mask_name=split_mask_name).to(device=config["device"])
        if calc_edge_ratio:
            edge_ratio_node_mask = dataset.get_edge_ratio_node_mask(data=batch_data, split_mask_name=split_mask_name).to(device=config["device"])
        else:
            edge_ratio_node_mask = None
        edge_attr = batch_data.edge_attr
        if batch_data.edge_attr is not None:
            edge_attr = edge_attr.to(device=config["device"])
        predictions, edge_ratio_tensor = model(batch_data.x.to(device=config["device"]),
                                               edge_index=batch_data.edge_index.to(device=config["device"]),
                                               edge_attr=edge_attr, batch=batch_data.batch.to(device=config["device"]),
                                               edge_ratio_node_mask=edge_ratio_node_mask,
                                               pestat=config["pos_enc"].get_pe(data=batch_data, device=config["device"]))

        eval_loss = task_loss(predictions, batch_data.y.to(device=config["device"]))

        total_scores = np.concatenate((total_scores, predictions[node_mask].detach().cpu().numpy()))
        if total_y is None:
            total_y = batch_data.y.to(device=config["device"])[node_mask].detach().cpu().numpy()
        else:
            total_y = np.concatenate((total_y, batch_data.y.to(device=config["device"])[node_mask].detach().cpu().numpy()))
        total_loss += eval_loss.item() * batch_data.num_graphs
        total_edge_ratios += edge_ratio_tensor * batch_data.num_graphs
    
    accuracy = metric_type.apply_metric(scores=total_scores, target=total_y)
    loss = total_loss / len(loader.dataset)
    edge_ratios = total_edge_ratios / len(loader.dataset)

    return accuracy, loss, edge_ratios

In [17]:
def train_fold_CoGNN(
    config : dict,
    dataset_by_split: DatasetBySplit,
    model : torch.nn.Module,
    optimizer : torch.optim.Optimizer,
    pbar : tqdm.std.tqdm,
    num_fold: int
) -> Tuple[LossesAndMetrics, OptTensor]:
    """
        Training loop for one fold.
    """
    train_loader = DataLoader(dataset_by_split.train, batch_size=config["batch_size"], shuffle=True)

    metric_type = config["metric"]
    task_loss = metric_type.task_loss
    decimal = config["decimal"]

    best_losses_n_metrics = metric_type.get_worst_losses_n_metrics
    for epoch in range(config["num_epochs"]):
        model, epoch_loss, epoch_edge_ratio_tensor = train_CoGNN(model=model,
                                                                 task_loss=task_loss,
                                                                 dataset=dataset_by_split.train,
                                                                 optimizer=optimizer,
                                                                 train_loader=train_loader,
                                                                 config=config)

        # evaluate_CoGNN builds its own DataLoader and returns (metric, loss, edge_ratios)
        train_metric, _, _ = evaluate_CoGNN(model=model, dataset=dataset_by_split.train, config=config,
                                            task_loss=task_loss, metric_type=metric_type,
                                            split_mask_name="train_mask", calc_edge_ratio=False)
        val_metric, val_loss, _ = evaluate_CoGNN(model=model, dataset=dataset_by_split.val, config=config,
                                                 task_loss=task_loss, metric_type=metric_type,
                                                 split_mask_name="val_mask", calc_edge_ratio=False)
        test_metric, test_loss, _ = evaluate_CoGNN(model=model, dataset=dataset_by_split.test, config=config,
                                                   task_loss=task_loss, metric_type=metric_type,
                                                   split_mask_name="test_mask", calc_edge_ratio=False)

        losses_n_metrics = LossesAndMetrics(train_loss=epoch_loss,
                                            val_loss=val_loss,
                                            test_loss=test_loss,
                                            train_metric=train_metric,
                                            val_metric=val_metric,
                                            test_metric=test_metric)

        # Keep the epoch with the best validation metric
        if metric_type.src_better_than_other(src=losses_n_metrics.val_metric,
                                             other=best_losses_n_metrics.val_metric):
            best_losses_n_metrics = losses_n_metrics

        log_str = f"Split: {num_fold}, epoch: {epoch}"
        for name in losses_n_metrics._fields:
            log_str += f",{name}={round(getattr(losses_n_metrics, name), decimal)}"
        log_str += f"({round(best_losses_n_metrics.test_metric, decimal)})"
        pbar.set_description(log_str)
        pbar.update(n=1)

    # Edge ratios are only computed once, after the last epoch
    _, _, edge_ratios = evaluate_CoGNN(model=model, dataset=dataset_by_split.test, config=config,
                                       task_loss=task_loss, metric_type=metric_type,
                                       split_mask_name="test_mask", calc_edge_ratio=True)
    
    return best_losses_n_metrics, edge_ratios
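The selection rule used in `train_fold_CoGNN` is to keep the epoch with the best validation metric and report the test metric of that same epoch, rather than the best test metric seen. A minimal sketch of this rule follows; `EpochResult`, `select_best` and the numbers are hypothetical and only serve the illustration, while the notebook itself relies on `LossesAndMetrics` and `src_better_than_other`.

```python
from typing import List, NamedTuple


class EpochResult(NamedTuple):
    val_metric: float
    test_metric: float


def select_best(results: List[EpochResult], higher_is_better: bool = True) -> EpochResult:
    """Return the epoch whose validation metric is best."""
    best = results[0]
    for result in results[1:]:
        if higher_is_better:
            improved = result.val_metric > best.val_metric
        else:
            improved = result.val_metric < best.val_metric
        if improved:
            best = result
    return best


# Made-up accuracies for three epochs.
history = [EpochResult(0.70, 0.68), EpochResult(0.80, 0.75), EpochResult(0.78, 0.79)]
best = select_best(history)
print(best.val_metric, best.test_metric)
```

Here the second epoch is selected, even though the third epoch has a higher test metric, because the test split must not influence model selection.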


In [18]:
def train_model_CoGNN(
    data : Data,
    config : dict,
    num_folds : int
) -> Tuple[torch.nn.Module, dict]:
    """
        Training loop for n-fold with multiple epochs.
    """
    #task_loss = config["metric"].task_loss
    decimal = config["decimal"]
    seeds = [config["seed"] + n for n in range(num_folds)]

    dataset = DataSet(data=data, train_mask=None, val_mask=None, test_mask=None)

    gin_mlp_func = dataset.gin_mlp_func()
    env_act_type = dataset.env_activation_type
    dataset_encoder = dataset.dataset_encoders
    out_dim = config["metric"].get_out_dim(data=data)

    gumbel = ["learn_temp", "temp_model_type", "tau0", "temp"]
    gumbel_add = {"gin_mlp_func" :  gin_mlp_func}
    gumbel_args = {key: config.get(key) for key in gumbel}
    gumbel_args.update(gumbel_add)

    env = ["env_dim", "layer_norm", "skip", "batch_norm", "dropout", "metric", "dec_num_layers", "pos_enc"]
    env_add = {"model_type" :       config["env_model_type"],
               "num_layers" :       config["env_num_layers"],
               "act_type" :         env_act_type,
               "in_dim" :           data.x.shape[1],
               "out_dim" :          out_dim,
               "gin_mlp_func" :     gin_mlp_func,
               "dataset_encoders" : dataset_encoder}
    env_args = {key: config.get(key) for key in env}
    env_args.update(env_add)

    action = ["dropout", "env_dim"]
    action_add = {"model_type" :    config["act_model_type"],
                  "num_layers" :    config["act_num_layers"],
                  "hidden_dim" :    config["act_dim"],
                  "act_type" :      ActivationType.RELU,
                  "gin_mlp_func" :  gin_mlp_func}
    action_args = {key: config.get(key) for key in action}
    action_args.update(action_add)

    metrics_list = []
    edge_ratios_list = []
    for n in range(num_folds):
        set_seed(seed=seeds[n])
        dataset_by_split = dataset.select_fold_and_split(dataset=dataset)

        model = CoGNN(gumbel_args=gumbel_args, env_args=env_args, action_args=action_args, pool=config["pool"]).to(device=config["device"])
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

        with tqdm.tqdm(total=config["num_epochs"], file=sys.stdout) as pbar:
            best_losses_n_metrics, edge_ratios = train_fold_CoGNN(config=config,
                                                                  dataset_by_split=dataset_by_split,
                                                                  model=model,
                                                                  optimizer=optimizer,
                                                                  pbar=pbar,
                                                                  num_fold=n)
        
        print_str = f"Fold {n}/{num_folds}"
        for name in best_losses_n_metrics._fields:
            print_str += f",{name}={round(getattr(best_losses_n_metrics, name), decimal)}"
        print(print_str)
        print()
        metrics_list.append(best_losses_n_metrics.get_fold_metrics())

        if edge_ratios is not None:
            edge_ratios_list.append(edge_ratios)
    
    metrics_matrix = torch.stack(metrics_list, dim=0)  # (F, 3)
    metrics_mean = torch.mean(metrics_matrix, dim=0).tolist()  # (3,)
    if len(edge_ratios_list) > 0:
        edge_ratios = torch.mean(torch.stack(edge_ratios_list, dim=0), dim=0)
    else:
        edge_ratios = None
    
    print(f"Final Rewired train={round(metrics_mean[0], decimal)},"
          f"val={round(metrics_mean[1], decimal)},"
          f"test={round(metrics_mean[2], decimal)}")
    
    if num_folds > 1:
        metrics_std = torch.std(metrics_matrix, dim=0).tolist()  # (3,)
        print(f"Final Rewired train={round(metrics_mean[0], decimal)}+-{round(metrics_std[0], decimal)},"
              f"val={round(metrics_mean[1], decimal)}+-{round(metrics_std[1], decimal)},"
              f"test={round(metrics_mean[2], decimal)}+-{round(metrics_std[2], decimal)}")

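    # NOTE: `model` is the model trained on the last fold, and no per-epoch history is recorded yet.
    # The fold-averaged `metrics_mean` and `edge_ratios` computed above are currently not returned.
    history = None
    return model, history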
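
The fold aggregation at the end of `train_model_CoGNN` can be illustrated in isolation. The following is a minimal sketch using only the standard library rather than `torch`; the helper name `summarise_folds` and the example numbers are hypothetical and not part of the notebook. Like `torch.std` with its default settings, `statistics.stdev` computes the sample standard deviation (divisor `n - 1`), which is undefined for a single fold. That is why the standard deviation is only reported when there is more than one fold.

```python
from statistics import mean, stdev


def summarise_folds(fold_metrics, decimal=4):
    """Return per-split (mean, std) over folds.

    fold_metrics: list of (train, val, test) tuples, one per fold.
    std is None when only one fold is available, since the sample
    standard deviation is undefined for a single value.
    """
    columns = list(zip(*fold_metrics))  # transpose: one tuple per split
    summary = {}
    for name, values in zip(("train", "val", "test"), columns):
        std = round(stdev(values), decimal) if len(values) > 1 else None
        summary[name] = (round(mean(values), decimal), std)
    return summary


# Hypothetical per-fold metrics, for illustration only
example = [(0.90, 0.80, 0.70), (0.92, 0.82, 0.74), (0.94, 0.84, 0.78)]
summary = summarise_folds(example)
for split, (m, s) in summary.items():
    print(f"{split}={m}+-{s}")
```

Returning such a summary instead of only printing it would make the averaged metrics available to the caller.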

In [19]:
def main_CoGNN(
    config : dict,
    num_folds : int=10,
    run_grid_search : bool=False
) -> Tuple[torch.nn.Module, Data, dict]:
    """
        Main function with optional grid search.
    """
    device = config["device"]

    print(f"Loading {config["dataset"]} dataset...")
    data, num_classes = load_dataset(config["dataset"])
    config["metric"] = Metric(task="multiclass", num_classes=num_classes)
    data = data.to(device)

    if run_grid_search:
        print("Running grid search...")
        # Everything below in this branch is an unreachable template until this guard is removed
        raise NotImplementedError("Grid search hasn't been implemented yet!")
    
        # List of the parameters and the values to be searched across
        param_list_1 = []
        param_list_2 = []
        # ...

        results = []

        # Adjust the nested loops for the various parameters
        for param_1 in param_list_1:
            for param_2 in param_list_2:
                # ...
                config_copy = config.copy()
                config_copy["param_1_name"] = param_1
                config_copy["param_2_name"] = param_2

                print(f"\nTrying param_1_name={param_1}, param_2_name={param_2} ...")
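                # train_model_CoGNN builds a fresh CoGNN per fold, so no model is instantiated here
                model, _ = train_model_CoGNN(data=data, config=config_copy, num_folds=num_folds)

                # NOTE: placeholder call; it does not match the keyword interface used in train_fold_CoGNN
                # (loader, model, split_mask_name, calc_edge_ratio), which returns three values
                test_acc, test_steps, _, _ = evaluate_CoGNN(model, data, data.test_mask)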

                results.append({
                    "param_1_name": param_1,
                    "param_2_name": param_2,
                    # ...
                    "test_acc": test_acc,
                    "test_steps": test_steps
                })

                print(f"Result: Test Acc: {test_acc:.4f}, Test Steps: {test_steps:.2f}")

        best_result = max(results, key=lambda x: x["test_acc"])

        print("\nGrid Search Results:")
        for r in results:
            print(f"param_1_name={r["param_1_name"]}, param_2_name={r["param_2_name"]}: " # ...
                  f"Acc={r["test_acc"]:.4f}, Steps={r["test_steps"]:.2f}")

        print(f"\nBest parameters: param_1_name={best_result["param_1_name"]}, " # ...
              f"param_2_name={best_result["param_2_name"]}, "
              f"Accuracy={best_result["test_acc"]:.4f}")

        # config with best parameters
        config["param_1_name"] = best_result["param_1_name"]
        config["param_2_name"] = best_result["param_2_name"]
        # ...    

    # Instantiating the model with obtained config
    print("\nTraining final model...")
    #model = CoGNN(gumbel_args, env_args, action_args, config["pool"]).to(device)

    # Training the model
    model, history = train_model_CoGNN(data=data, config=config, num_folds=num_folds)

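    # Evaluation on the test split
    # NOTE: this call does not match the keyword interface used in train_fold_CoGNN
    # (loader, model, split_mask_name, calc_edge_ratio), which returns three values
    test_acc, test_steps, _, _ = evaluate_CoGNN(model, data, data.test_mask)
    print(f"Test Accuracy: {test_acc:.4f}, Average Steps: {test_steps:.2f}")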

    # visualizations
    # TO-DO!

    # optional: impact of various parameters

    return model, data, history

In [20]:
if __name__ == "__main__":
    CONFIG = create_config(dataset="A.Computer", seed=42)
    model, data, history = main_CoGNN(config=CONFIG, num_folds=10, run_grid_search=False)

Loading A.Computer dataset...

Training final model...


AttributeError: 'DataSet' object has no attribute 'x'

$\textcolor{red}{\texttt{TO - DO!}}$

- Implement the grid search in the `main_CoGNN` function. Which parameters should be searched, and over which values?
- Check that all quotation marks are `"` rather than `'`
- Check whether the dataset classification (homophilic etc.) is correct
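
For the grid-search item, one way to avoid hand-written nested loops is `itertools.product`. The following is a minimal sketch under stated assumptions: `grid_search`, `toy_score`, and the grid values are hypothetical, and `toy_score` merely stands in for "train a model with this config and return a score". The keys `lr` and `dropout` are taken from the config used above, but the listed values are placeholders, not recommendations.

```python
from itertools import product


def grid_search(base_config, param_grid, score_fn):
    """Evaluate every combination in param_grid and return (best, results).

    base_config: dict that is copied for each candidate, never mutated.
    param_grid:  dict mapping a config key to the list of values to try.
    score_fn:    callable taking a config dict and returning a score,
                 where higher is better.
    """
    names = list(param_grid)
    results = []
    for values in product(*(param_grid[name] for name in names)):
        params = dict(zip(names, values))
        candidate = {**base_config, **params}  # shallow copy with overrides
        results.append({"params": params, "score": score_fn(candidate)})
    best = max(results, key=lambda r: r["score"])
    return best, results


# Hypothetical stand-in for "train with this config and return a score"
def toy_score(config):
    return -abs(config["lr"] - 0.01) - abs(config["dropout"] - 0.5)


grid = {"lr": [0.001, 0.01, 0.1], "dropout": [0.2, 0.5]}  # placeholder values
best, results = grid_search({"seed": 42}, grid, toy_score)
print(best["params"], len(results))
```

In the notebook, `score_fn` would wrap `train_model_CoGNN`. Selecting on a validation score rather than the test score is the more common choice, since it keeps the test split untouched for the final report.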