# Cooperative Graph Neural Networks ([CoGNN](https://doi.org/10.48550/arXiv.2310.01267))

This part was adapted by Tobias Erbacher from the [authors' GitHub repository](https://github.com/benfinkelshtein/CoGNN/tree/main). We recommend running it on a GPU service such as [Google Colab](https://colab.research.google.com/).

## Installation

To ensure we use the same package versions as the authors, we install the following:

In [None]:
%pip install torch==2.0.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
%pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
%pip install torch-geometric==2.3.0
%pip install torchmetrics ogb rdkit
%pip install matplotlib

To execute this notebook, we assume that the datasets have already been downloaded. If you cloned our GitHub repository, you will find them in [this folder](https://github.com/TobiasErbacher/gdl/tree/main/replication/data). In particular, we use:

- Computers <font color='red'>(Link?)</font>
- Photo <font color='red'>(Link?)</font>
- [CiteSeer](https://linqs.org/datasets/) <font color='red'>(Which one?)</font>
- CoraML <font color='red'>(Link?)</font>
- MS-Academic <font color='red'>(Link?)</font>
- PubMed <font color='red'>(Link?)</font>

---

## Initialization

We first check whether a GPU is available and, if so, set it as the default device:

In [None]:
import torch
GPU_AVAILABLE = torch.cuda.is_available()

# Use the first GPU when one is present. Otherwise fall back to the CPU instead of
# aborting: Experiment.__init__ makes the same cuda-or-cpu choice on its own, and
# an `assert` would be stripped under `python -O` anyway.
if GPU_AVAILABLE:
    device = torch.device("cuda:0")
    torch.cuda.set_device(device)
else:
    device = torch.device("cpu")
    print("No GPU detected: falling back to the CPU (expect slow training).")

Now we will import the libraries:

In [2]:
import os
import sys
import tqdm
import random
import torch
from torch_geometric.typing import OptTensor
from torch_geometric.loader import DataLoader
from torch_geometric.data import Data
import numpy as np
from enum import Enum, auto
from argparse import ArgumentParser, Namespace
from typing import NamedTuple, Tuple, Any, Callable
from ogb.utils.features import get_atom_feature_dims, get_bond_feature_dims

Next, we import the custom classes:

In [None]:
from helper.class_encoder_laplace import LapPENodeEncoder, LAP_DIM_PE
from helper.class_encoder_kernel import RWSENodeEncoder, KER_DIM_PE
from helper.class_encoder import PosEncoder
from helper.class_metric import MetricType
from helper.class_concat2node import Concat2NodeEncoder
from helper.class_actionnetargs import ActionNetArgs
from helper.class_activationtype import ActivationType
from helper.class_datasetbysplit import DatasetBySplit
from helper.class_metric_lam import LossesAndMetrics
from helper.class_cognn import CoGNN
from helper.class_dataset import DataSet
from helper.class_model import ModelType
from helper.class_pool import Pool
from helper.class_encoder import PosEncoder

Let us set up an argument parser:

In [None]:
def parse_arguments(argv=None):
    """Build the CoGNN command-line parser and parse the experiment configuration.

    Args:
        argv: optional list of argument strings. When ``None`` (the default),
            ``sys.argv[1:]`` is parsed, exactly as before.

    Returns:
        argparse.Namespace with one attribute per option defined below.

    Unrecognised arguments are reported and ignored instead of aborting. Inside a
    Jupyter kernel, ``sys.argv`` contains the launcher's own flags (e.g.
    ``-f <connection-file>``), which ``parse_args`` would reject with SystemExit.
    """
    parser = ArgumentParser()
    parser.add_argument("--dataset", dest="dataset", default=DataSet.roman_empire, type=DataSet.from_string,
                        choices=list(DataSet), required=False)
    parser.add_argument("--pool", dest="pool", default=Pool.NONE, type=Pool.from_string,
                        choices=list(Pool), required=False)

    # gumbel
    parser.add_argument("--learn_temp", dest="learn_temp", default=False, action='store_true', required=False)
    parser.add_argument("--temp_model_type", dest="temp_model_type", default=ModelType.LIN,
                        type=ModelType.from_string, choices=list(ModelType), required=False)
    parser.add_argument("--tau0", dest="tau0", default=0.5, type=float, required=False)
    parser.add_argument("--temp", dest="temp", default=0.01, type=float, required=False)

    # optimization
    parser.add_argument("--max_epochs", dest="max_epochs", default=3000, type=int, required=False)
    parser.add_argument("--batch_size", dest="batch_size", default=32, type=int, required=False)
    parser.add_argument("--lr", dest="lr", default=1e-3, type=float, required=False)
    parser.add_argument("--dropout", dest="dropout", default=0.2, type=float, required=False)

    # env cls parameters
    parser.add_argument("--env_model_type", dest="env_model_type", default=ModelType.MEAN_GNN,
                        type=ModelType.from_string, choices=list(ModelType), required=False)
    parser.add_argument("--env_num_layers", dest="env_num_layers", default=3, type=int, required=False)
    parser.add_argument("--env_dim", dest="env_dim", default=128, type=int, required=False)
    parser.add_argument("--skip", dest="skip", default=False, action='store_true', required=False)
    parser.add_argument("--batch_norm", dest="batch_norm", default=False, action='store_true', required=False)
    parser.add_argument("--layer_norm", dest="layer_norm", default=False, action='store_true', required=False)
    parser.add_argument("--dec_num_layers", dest="dec_num_layers", default=1, type=int, required=False)
    parser.add_argument("--pos_enc", dest="pos_enc", default=PosEncoder.NONE,
                        type=PosEncoder.from_string, choices=list(PosEncoder), required=False)

    # policy cls parameters
    parser.add_argument("--act_model_type", dest="act_model_type", default=ModelType.MEAN_GNN,
                        type=ModelType.from_string, choices=list(ModelType), required=False)
    parser.add_argument("--act_num_layers", dest="act_num_layers", default=1, type=int, required=False)
    parser.add_argument("--act_dim", dest="act_dim", default=16, type=int, required=False)

    # reproduce
    parser.add_argument("--seed", dest="seed", type=int, default=0, required=False)

    # dataset-dependent parameters
    parser.add_argument("--fold", dest="fold", default=None, type=int, required=False)

    # optimizer and scheduler
    parser.add_argument("--weight_decay", dest="weight_decay", default=0, type=float, required=False)
    ## for steplr scheduler only
    parser.add_argument("--step_size", dest="step_size", default=None, type=int, required=False)
    parser.add_argument("--gamma", dest="gamma", default=None, type=float, required=False)
    ## for cosine with warmup scheduler only
    parser.add_argument("--num_warmup_epochs", dest="num_warmup_epochs", default=None, type=int, required=False)

    parsed, unknown = parser.parse_known_args(argv)
    if unknown:
        print(f"Ignoring unrecognized arguments: {unknown}")
    return parsed

args = parse_arguments()

---

To make the experiments reproducible, we define a helper that seeds all random number generators:

In [None]:
def set_seed(seed):
    """Seed every random number generator the experiments rely on.

    Sets ``PYTHONHASHSEED`` (per the Python docs, this only affects interpreters
    started afterwards) and seeds ``random``, NumPy, and PyTorch on the CPU and on
    every CUDA device. When CUDA is available, cuDNN is additionally made
    deterministic, its autotuner is disabled, and cuDNN is then switched off
    altogether.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)

    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)

    if not torch.cuda.is_available():
        return

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    cudnn.enabled = False

For the $\textcolor{red}{\text{chemical datasets}}$, we need an atom encoder and a bond encoder:

In [None]:
class AtomEncoder(torch.nn.Module):
    """Embed categorical atom features by summing one embedding table per feature column.

    The table sizes come from OGB's ``get_atom_feature_dims()``. Every table is
    Xavier-uniform initialised.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.atom_embedding_list = torch.nn.ModuleList()
        for num_categories in get_atom_feature_dims():
            table = torch.nn.Embedding(num_categories, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.atom_embedding_list.append(table)

    def forward(self, x, pestat):
        # `pestat` is accepted only so the call signature matches the other node
        # encoders; it is not used. Column `col` of `x` is looked up in table `col`,
        # and the lookups are summed (the result is the int 0 when x has no columns).
        return sum(self.atom_embedding_list[col](x[:, col]) for col in range(x.shape[1]))

class BondEncoder(torch.nn.Module):
    """Embed categorical bond (edge) features by summing one embedding table per column.

    The table sizes come from OGB's ``get_bond_feature_dims()``. Every table is
    Xavier-uniform initialised.
    """

    def __init__(self, emb_dim):
        super().__init__()

        self.bond_embedding_list = torch.nn.ModuleList()
        for num_categories in get_bond_feature_dims():
            table = torch.nn.Embedding(num_categories, emb_dim)
            torch.nn.init.xavier_uniform_(table.weight.data)
            self.bond_embedding_list.append(table)

    def forward(self, edge_attr):
        # Column `col` of `edge_attr` is looked up in table `col`, and the lookups are
        # summed (the result is the int 0 when edge_attr has no columns).
        return sum(self.bond_embedding_list[col](edge_attr[:, col]) for col in range(edge_attr.shape[1]))

Now, let us define some encoder models:

In [None]:
class EncoderLinear(torch.nn.Linear):
    """Plain linear node encoder with the shared ``(x, pestat)`` encoder call signature.

    ``pestat`` is accepted and ignored, so this encoder can be swapped with the
    positional-encoding-aware encoders without changing the call site.
    """

    def forward(self, x: torch.Tensor, pestat=None) -> torch.Tensor:
        return super().forward(x)

class DataSetEncoders(Enum):
    """
        Choice of node/edge feature encoders.

        NONE uses a linear node encoder and no edge encoder; MOL uses the
        AtomEncoder/BondEncoder embeddings for categorical molecular features.
    """
    NONE = auto()
    MOL = auto()

    @staticmethod
    def from_string(s: str):
        """Return the member named ``s``; raise ValueError (naming the options) otherwise."""
        try:
            return DataSetEncoders[s]
        except KeyError:
            # `from None`: the KeyError adds nothing beyond this message.
            raise ValueError(f"unknown DataSetEncoders '{s}'; "
                             f"expected one of {[member.name for member in DataSetEncoders]}") from None

    def node_encoder(self, in_dim: int, emb_dim: int):
        """Build the node encoder mapping ``in_dim`` features to ``emb_dim``."""
        if self is DataSetEncoders.NONE:
            return EncoderLinear(in_features=in_dim, out_features=emb_dim)
        elif self is DataSetEncoders.MOL:
            return AtomEncoder(emb_dim)
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def edge_encoder(self, emb_dim: int, model_type):
        """Build the edge encoder, or return None when edge features are not encoded.

        For MOL, no edge encoder is built when ``model_type.is_gcn()`` is true.
        """
        if self is DataSetEncoders.NONE:
            return None
        elif self is DataSetEncoders.MOL:
            if model_type.is_gcn():
                return None
            else:
                return BondEncoder(emb_dim)
        else:
            raise ValueError(f'DataSetEncoders {self.name} not supported')

    def use_encoders(self) -> bool:
        """True for every choice other than NONE."""
        return self is not DataSetEncoders.NONE

class PosEncoder(Enum):
    """
        Optional positional/structural encoding attached to the node features.

        LAP uses LapPENodeEncoder (input read from ``data.EigVals``/``data.EigVecs``);
        RWSE uses RWSENodeEncoder (input read from ``data.pestat_RWSE``).

        NOTE(review): this notebook also imports a ``PosEncoder`` from
        ``helper.class_encoder`` and this definition shadows it. Members of the two
        classes are distinct objects, so ``is`` comparisons across them are False.
    """
    NONE = auto()
    LAP = auto()
    RWSE = auto()

    @staticmethod
    def from_string(s: str):
        """Return the member named ``s``; raise ValueError (naming the options) otherwise."""
        try:
            return PosEncoder[s]
        except KeyError:
            # `from None`: the KeyError adds nothing beyond this message.
            raise ValueError(f"unknown PosEncoder '{s}'; "
                             f"expected one of {[member.name for member in PosEncoder]}") from None

    def get(self, in_dim: int, emb_dim: int, expand_x: bool):
        """Build the node encoder for this encoding, or return None for NONE."""
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LapPENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        elif self is PosEncoder.RWSE:
            return RWSENodeEncoder(dim_in=in_dim, dim_emb=emb_dim, expand_x=expand_x)
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

    def DIM_PE(self):
        """Width of the encoding as exported by its helper module (None for NONE)."""
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return LAP_DIM_PE
        elif self is PosEncoder.RWSE:
            return KER_DIM_PE
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

    def get_pe(self, data: Data, device):
        """Move this encoding's precomputed inputs from ``data`` to ``device``.

        Returns None for NONE, ``[EigVals, EigVecs]`` for LAP and the
        ``pestat_RWSE`` tensor for RWSE.
        """
        if self is PosEncoder.NONE:
            return None
        elif self is PosEncoder.LAP:
            return [data.EigVals.to(device), data.EigVecs.to(device)]
        elif self is PosEncoder.RWSE:
            return data.pestat_RWSE.to(device)
        else:
            raise ValueError(f'PosEncoder {self.name} not supported')

For the Gumbel-softmax sampling, we need a few parameters, so we define a container class for them:

In [None]:
class GumbelArgs(NamedTuple):
    """Hyper-parameters for the Gumbel sampling step, filled from the ``# gumbel`` CLI options.

    Built in ``Experiment.run`` and passed to ``CoGNN``. How the fields are
    consumed is defined in the helper package, which is not visible in this notebook.
    """
    learn_temp: bool  # --learn_temp flag; presumably learn the temperature instead of fixing it — TODO confirm in CoGNN
    temp_model_type: ModelType  # --temp_model_type (default ModelType.LIN); presumably the model predicting a learned temperature — confirm
    tau0: float  # --tau0 (default 0.5); exact role defined by the helper — TODO confirm
    temp: float  # --temp (default 0.01); presumably the fixed temperature when learn_temp is off — confirm
    gin_mlp_func: Callable  # MLP factory from DataSet.gin_mlp_func(), also passed to EnvArgs and ActionNetArgs

Similarly, we will define a container class for the arguments required to set up the environment:

In [None]:
class EnvArgs(NamedTuple):
    """Configuration of the environment network, plus a factory for its layers."""
    model_type: ModelType
    num_layers: int
    env_dim: int

    layer_norm: bool
    skip: bool
    batch_norm: bool
    dropout: float
    act_type: ActivationType
    dec_num_layers: int
    pos_enc: PosEncoder
    dataset_encoders: DataSetEncoders

    metric_type: MetricType
    in_dim: int
    out_dim: int

    gin_mlp_func: Callable

    def load_net(self) -> torch.nn.ModuleList:
        """Build the environment network's layers.

        Returns:
            ModuleList laid out as ``[node encoder, *message-passing components,
            decoder]``. The components come from ``model_type.get_component_list``
            with width ``env_dim``. The decoder maps ``env_dim`` to ``out_dim``; it is
            a single Linear layer, or an MLP when ``dec_num_layers > 1``.
        """
        # Compare against the NONE member of the value's *own* enum class. The
        # values may come from helper-module copies of these enums (this notebook
        # both imports and redefines PosEncoder), and members of different enum
        # classes are never `is`-identical.
        no_pos_enc = self.pos_enc is type(self.pos_enc).NONE
        no_dataset_enc = self.dataset_encoders is type(self.dataset_encoders).NONE

        if no_pos_enc:
            enc_list = [self.dataset_encoders.node_encoder(in_dim=self.in_dim, emb_dim=self.env_dim)]
        elif no_dataset_enc:
            # `get` requires `expand_x`. Raw features must be projected up to env_dim
            # alongside the PE here; presumably this is expand_x=True in the
            # GraphGPS-style helper encoders — confirm.
            enc_list = [self.pos_enc.get(in_dim=self.in_dim, emb_dim=self.env_dim, expand_x=True)]
        else:
            enc_list = [Concat2NodeEncoder(enc1_cls=self.dataset_encoders.node_encoder,
                                           enc2_cls=self.pos_enc.get,
                                           in_dim=self.in_dim, emb_dim=self.env_dim,
                                           enc2_dim_pe=self.pos_enc.DIM_PE())]

        component_list =\
            self.model_type.get_component_list(in_dim=self.env_dim, hidden_dim=self.env_dim, out_dim=self.env_dim,
                                               num_layers=self.num_layers, bias=True, edges_required=True,
                                               gin_mlp_func=self.gin_mlp_func)

        if self.dec_num_layers > 1:
            # Create fresh modules for every hidden layer. `n * [Linear(...)]` would
            # repeat the *same* Linear instance, silently tying the weights of the
            # hidden layers.
            mlp_list = []
            for _ in range(self.dec_num_layers - 1):
                mlp_list += [torch.nn.Linear(self.env_dim, self.env_dim),
                             torch.nn.Dropout(self.dropout), self.act_type.nn()]
            mlp_list.append(torch.nn.Linear(self.env_dim, self.out_dim))
            dec_list = [torch.nn.Sequential(*mlp_list)]
        else:
            dec_list = [torch.nn.Linear(self.env_dim, self.out_dim)]

        return torch.nn.ModuleList(enc_list + component_list + dec_list)

Now we will set up the experiment:

In [None]:
class Experiment(object):
    """Train and evaluate CoGNN on one dataset, over every requested fold.

    Every parsed command-line option is copied onto the instance (``self.lr``,
    ``self.max_epochs``, ``self.pos_enc``, ...), so the methods read their
    configuration from ``self``.
    """

    def __init__(self, args: Namespace):
        super().__init__()
        # Mirror every CLI option as an attribute and echo it to the log.
        for arg in vars(args):
            value_arg = getattr(args, arg)
            print(f"{arg}: {value_arg}")
            self.__setattr__(arg, value_arg)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        set_seed(seed=self.seed)

        # parameters derived from the dataset
        self.metric_type = self.dataset.get_metric_type()
        self.decimal = self.dataset.num_after_decimal()
        self.task_loss = self.metric_type.get_task_loss()

        # dataset-specific validation of the configuration
        self.dataset.asserts(args)

    def run(self) -> Tuple[list, OptTensor]:
        """Train one model per fold and print the fold-averaged results.

        Returns:
            (metrics_mean, edge_ratios). ``metrics_mean`` is the fold-averaged
            ``[train, val, test]`` metric list. ``edge_ratios`` is the fold-averaged
            edge-ratio tensor, or None when no fold produced one.
        """
        dataset = self.dataset.load(seed=self.seed, pos_enc=self.pos_enc)
        if self.metric_type.is_multilabel():
            # presumably the multilabel loss expects float targets — confirm in MetricType
            dataset.data.y = dataset.data.y.to(dtype=torch.float)

        folds = self.dataset.get_folds(fold=self.fold)

        # locally used parameters
        out_dim = self.metric_type.get_out_dim(dataset=dataset)
        gin_mlp_func = self.dataset.gin_mlp_func()
        env_act_type = self.dataset.env_activation_type()

        # named tuples
        gumbel_args = GumbelArgs(learn_temp=self.learn_temp, temp_model_type=self.temp_model_type, tau0=self.tau0,
                                 temp=self.temp, gin_mlp_func=gin_mlp_func)
        env_args = \
            EnvArgs(model_type=self.env_model_type, num_layers=self.env_num_layers, env_dim=self.env_dim,
                    layer_norm=self.layer_norm, skip=self.skip, batch_norm=self.batch_norm, dropout=self.dropout,
                    act_type=env_act_type, metric_type=self.metric_type, in_dim=dataset[0].x.shape[1], out_dim=out_dim,
                    gin_mlp_func=gin_mlp_func, dec_num_layers=self.dec_num_layers, pos_enc=self.pos_enc,
                    dataset_encoders=self.dataset.get_dataset_encoders())
        action_args = \
            ActionNetArgs(model_type=self.act_model_type, num_layers=self.act_num_layers,
                          hidden_dim=self.act_dim, dropout=self.dropout, act_type=ActivationType.RELU,
                          env_dim=self.env_dim, gin_mlp_func=gin_mlp_func)

        # folds
        metrics_list = []
        edge_ratios_list = []
        for num_fold in folds:
            # re-seed per fold so each fold starts from the same RNG state
            set_seed(seed=self.seed)
            dataset_by_split = self.dataset.select_fold_and_split(num_fold=num_fold, dataset=dataset)
            best_losses_n_metrics, edge_ratios =\
                self.single_fold(dataset_by_split=dataset_by_split, gumbel_args=gumbel_args, env_args=env_args,
                                 action_args=action_args, num_fold=num_fold)

            # print final
            print_str = f'Fold {num_fold}/{len(folds)}'
            for name in best_losses_n_metrics._fields:
                print_str += f",{name}={round(getattr(best_losses_n_metrics, name), self.decimal)}"
            print(print_str)
            print()
            metrics_list.append(best_losses_n_metrics.get_fold_metrics())

            if edge_ratios is not None:
                edge_ratios_list.append(edge_ratios)

        metrics_matrix = torch.stack(metrics_list, dim=0)  # (F, 3)
        metrics_mean = torch.mean(metrics_matrix, dim=0).tolist()  # (3,)
        if len(edge_ratios_list) > 0:
            edge_ratios = torch.mean(torch.stack(edge_ratios_list, dim=0), dim=0)
        else:
            edge_ratios = None

        # prints
        print(f'Final Rewired train={round(metrics_mean[0], self.decimal)},'
              f'val={round(metrics_mean[1], self.decimal)},'
              f'test={round(metrics_mean[2], self.decimal)}')
        if len(folds) > 1:
            metrics_std = torch.std(metrics_matrix, dim=0).tolist()  # (3,)
            print(f'Final Rewired train={round(metrics_mean[0], self.decimal)}+-{round(metrics_std[0], self.decimal)},'
                  f'val={round(metrics_mean[1], self.decimal)}+-{round(metrics_std[1], self.decimal)},'
                  f'test={round(metrics_mean[2], self.decimal)}+-{round(metrics_std[2], self.decimal)}')

        return metrics_mean, edge_ratios

    def single_fold(self, dataset_by_split: DatasetBySplit, gumbel_args: GumbelArgs, env_args: EnvArgs,
                    action_args: ActionNetArgs, num_fold: int) -> Tuple[LossesAndMetrics, OptTensor]:
        """Build a fresh CoGNN with its optimizer/scheduler and train it on one fold."""
        model = CoGNN(gumbel_args=gumbel_args, env_args=env_args, action_args=action_args,
                      pool=self.pool).to(device=self.device)

        optimizer = self.dataset.optimizer(model=model, lr=self.lr, weight_decay=self.weight_decay)
        scheduler = self.dataset.scheduler(optimizer=optimizer, step_size=self.step_size, gamma=self.gamma,
                                           num_warmup_epochs=self.num_warmup_epochs, max_epochs=self.max_epochs)

        with tqdm.tqdm(total=self.max_epochs, file=sys.stdout) as pbar:
            best_losses_n_metrics, edge_ratios =\
                self.train_and_test(dataset_by_split=dataset_by_split, model=model, optimizer=optimizer,
                                    scheduler=scheduler, pbar=pbar, num_fold=num_fold)
        return best_losses_n_metrics, edge_ratios

    def train_and_test(self, dataset_by_split: DatasetBySplit, model, optimizer, scheduler, pbar, num_fold: int)\
            -> Tuple[LossesAndMetrics, OptTensor]:
        """Train for ``max_epochs`` epochs and keep the epoch with the best validation metric.

        Returns:
            The LossesAndMetrics of the best-validation epoch. The edge ratios are
            measured once, after training, on the test split when
            ``dataset.not_synthetic()``; otherwise they are None.
        """
        train_loader = DataLoader(dataset_by_split.train, batch_size=self.batch_size, shuffle=True)
        val_loader = DataLoader(dataset_by_split.val, batch_size=self.batch_size, shuffle=True)
        test_loader = DataLoader(dataset_by_split.test, batch_size=self.batch_size, shuffle=True)

        best_losses_n_metrics = self.metric_type.get_worst_losses_n_metrics()
        for epoch in range(self.max_epochs):
            self.train(train_loader=train_loader, model=model, optimizer=optimizer)
            train_loss, train_metric, _ =\
                self.test(loader=train_loader, model=model, split_mask_name='train_mask', calc_edge_ratio=False)
            if self.dataset.is_expressivity():
                # expressivity datasets report the train numbers for every split
                val_loss, val_metric = train_loss, train_metric
                test_loss, test_metric = train_loss, train_metric
            else:
                val_loss, val_metric, _ =\
                    self.test(loader=val_loader, model=model, split_mask_name='val_mask', calc_edge_ratio=False)
                test_loss, test_metric, _ =\
                    self.test(loader=test_loader, model=model, split_mask_name='test_mask', calc_edge_ratio=False)

            losses_n_metrics = \
                LossesAndMetrics(train_loss=train_loss, val_loss=val_loss, test_loss=test_loss,
                                 train_metric=train_metric, val_metric=val_metric, test_metric=test_metric)
            if scheduler is not None:
                scheduler.step(losses_n_metrics.val_metric)

            # best metrics (model selection on the validation metric)
            if self.metric_type.src_better_than_other(src=losses_n_metrics.val_metric,
                                                      other=best_losses_n_metrics.val_metric):
                best_losses_n_metrics = losses_n_metrics

            # prints
            log_str = f'Split: {num_fold}, epoch: {epoch}'
            for name in losses_n_metrics._fields:
                log_str += f",{name}={round(getattr(losses_n_metrics, name), self.decimal)}"
            log_str += f"({round(best_losses_n_metrics.test_metric, self.decimal)})"
            pbar.set_description(log_str)
            pbar.update(n=1)

        edge_ratios = None
        if self.dataset.not_synthetic():
            _, _, edge_ratios =\
                self.test(loader=test_loader, model=model, split_mask_name='test_mask', calc_edge_ratio=True)

        return best_losses_n_metrics, edge_ratios

    def train(self, train_loader, model, optimizer):
        """Run one optimisation epoch over ``train_loader``, using only the training split's targets."""
        model.train()

        for data in train_loader:
            # skip batches that batch norm cannot handle (a single node or a single graph)
            if self.batch_norm and (data.x.shape[0] == 1 or data.num_graphs == 1):
                continue
            optimizer.zero_grad()
            node_mask = self.dataset.get_split_mask(data=data, batch_size=data.num_graphs,
                                                    split_mask_name='train_mask').to(self.device)
            edge_attr = data.edge_attr
            if data.edge_attr is not None:
                edge_attr = edge_attr.to(device=self.device)

            # forward
            scores, _ =\
                model(data.x.to(device=self.device), edge_index=data.edge_index.to(device=self.device),
                      batch=data.batch.to(device=self.device), edge_attr=edge_attr, edge_ratio_node_mask=None,
                      pestat=self.pos_enc.get_pe(data=data, device=self.device))
            train_loss = self.task_loss(scores[node_mask], data.y.to(device=self.device)[node_mask])

            # backward
            train_loss.backward()
            if self.dataset.clip_grad():
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
            optimizer.step()

    @torch.no_grad()  # evaluation only: do not build an autograd graph
    def test(self, loader, model, split_mask_name: str, calc_edge_ratio: bool)\
            -> Tuple[float, Any, torch.Tensor]:
        """Evaluate ``model`` on the ``split_mask_name`` split of ``loader``.

        Returns:
            A tuple ``(loss, metric, edge_ratios)``:

            - ``loss``: the task loss over the split's entries, weighted by
              ``num_graphs`` per batch and divided by ``len(loader.dataset)``.
            - ``metric``: ``metric_type.apply_metric`` over all of the split's scores.
            - ``edge_ratios``: the model's edge ratios, averaged the same way as the loss.
        """
        model.eval()

        total_loss, total_edge_ratios = 0, 0
        # Seed with an empty float64 array so the concatenated scores keep the same
        # dtype as the previous incremental concatenation produced.
        score_chunks = [np.empty(shape=(0, model.env_args.out_dim))]
        target_chunks = []
        for data in loader:
            if self.batch_norm and (data.x.shape[0] == 1 or data.num_graphs == 1):
                continue
            node_mask = self.dataset.get_split_mask(data=data, batch_size=data.num_graphs,
                                                    split_mask_name=split_mask_name).to(device=self.device)
            if calc_edge_ratio:
                edge_ratio_node_mask =\
                    self.dataset.get_edge_ratio_node_mask(data=data, split_mask_name=split_mask_name).to(device=self.device)
            else:
                edge_ratio_node_mask = None
            edge_attr = data.edge_attr
            if data.edge_attr is not None:
                edge_attr = edge_attr.to(device=self.device)

            # forward
            scores, edge_ratios =\
                model(data.x.to(device=self.device), edge_index=data.edge_index.to(device=self.device),
                      edge_attr=edge_attr, batch=data.batch.to(device=self.device),
                      edge_ratio_node_mask=edge_ratio_node_mask,
                      pestat=self.pos_enc.get_pe(data=data, device=self.device))

            split_scores = scores[node_mask]
            split_y = data.y.to(device=self.device)[node_mask]
            # Restrict the loss to the requested split, exactly like train() and the
            # metric below. Otherwise val/test losses on transductive datasets would
            # include every node of the graph.
            eval_loss = self.task_loss(split_scores, split_y)

            # analytics
            score_chunks.append(split_scores.detach().cpu().numpy())
            target_chunks.append(split_y.detach().cpu().numpy())

            total_loss += eval_loss.item() * data.num_graphs
            total_edge_ratios += edge_ratios * data.num_graphs

        total_scores = np.concatenate(score_chunks)
        total_y = np.concatenate(target_chunks) if target_chunks else None
        metric = self.metric_type.apply_metric(scores=total_scores, target=total_y)

        loss = total_loss / len(loader.dataset)
        edge_ratios = total_edge_ratios / len(loader.dataset)
        return loss, metric, edge_ratios

Next, we run the experiment:

In [None]:
# Build the experiment from the parsed CLI options and run every fold. run() returns
# (fold-averaged [train, val, test] metrics, fold-averaged edge ratios or None).
Experiment(args=args).run()