# Cooperative Graph Neural Networks ([CoGNN](https://doi.org/10.48550/arXiv.2310.01267))

This part was adapted by Tobias Erbacher from the [authors' GitHub repository](https://github.com/benfinkelshtein/CoGNN/tree/main). We recommend running it on a GPU service such as [Google Colab](https://colab.research.google.com/). The goal is to adapt the original authors' code so that it works with our codebase.

## Installation

To make sure we use the same package versions as the authors, install the following (uncomment the lines as needed):

In [1]:
#%pip install torch==2.0.0 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
#%pip install torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.0.0+cu118.html
#%pip install torch-geometric==2.3.0
#%pip install torchmetrics ogb rdkit
#%pip install matplotlib
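After installing, you can confirm that the pinned versions are the ones actually loaded. The following is a minimal sketch that uses only the standard library's `importlib.metadata`. The `PINNED` dictionary and the `check_versions` helper are our own illustration. Distribution-name normalization (`-` vs `_`) can differ between Python versions, so if a package that is installed shows up as missing, try the underscore spelling.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions pinned by the (commented) pip commands above; adjust if you deviate.
PINNED = {"torch": "2.0.0", "torch-geometric": "2.3.0"}


def check_versions(pinned: dict) -> dict:
    """Map each distribution name to (installed_version_or_None, matches_pin)."""
    report = {}
    for dist, wanted in pinned.items():
        try:
            installed = version(dist)
        except PackageNotFoundError:
            installed = None
        # CUDA builds report e.g. "2.0.0+cu118", so compare only the part before "+"
        matches = installed is not None and installed.split("+")[0] == wanted
        report[dist] = (installed, matches)
    return report


for name, (installed, ok) in check_versions(PINNED).items():
    print(f"{name}: installed={installed}, matches pin={ok}")
```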

To run this notebook, we assume that the datasets have already been downloaded. If you cloned our GitHub repository, you will find them in [this folder](https://github.com/TobiasErbacher/gdl/tree/main/replication/data).
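A quick way to verify that the dataset files are in place is to look for one `<name>.npz` file per dataset. The sketch below takes the six dataset names from the docstring of `load_dataset_custom` further down. The helper `missing_datasets` is hypothetical, and you pass in whichever data directory you use.

```python
from pathlib import Path

# Dataset names as listed in the docstring of load_dataset_custom
EXPECTED = ["A.Computer", "A.Photo", "Citeseer", "Cora-ML", "MS-Academic", "PubMed"]


def missing_datasets(data_dir, names=EXPECTED) -> list:
    """Return the dataset names whose <name>.npz file is absent from data_dir."""
    data_dir = Path(data_dir)
    return [n for n in names if not (data_dir / f"{n}.npz").is_file()]
```

For example, `missing_datasets("../replication/data")` should return an empty list when the notebook runs from a sibling folder of `replication/` (this relative path is an assumption about your layout).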

---

## Initialization

### Custom Files

To keep this notebook uncluttered, some class definitions have been moved to separate files. The lists below give an overview of their library dependencies. The file `architectures.py` performs the following imports:

- from `torch` import `Tensor, cat`
- from `torch.nn` import `Linear, Parameter`
- from `typing` import `Callable, List`
- from `torch_geometric.typing` import `OptTensor, Adj`
- from `torch_geometric.nn.conv` import `MessagePassing`
- from `torch_geometric.nn.conv.gcn_conv` import `gcn_norm`
- from `torch_geometric.utils` import `remove_self_loops, add_remaining_self_loops`

The type classes are bundled in the file `type_classes.py`, which performs the following imports:

- from `torch` import `from_numpy, tensor`
- from `torch.nn` import `Module, ReLU, GELU, CrossEntropyLoss`
- import `torch.nn.functional` as `F`
- from `torchmetrics` import `Accuracy`
- from `torch_geometric.data` import `Data`
- from `torch_geometric.nn.pool` import `global_mean_pool, global_add_pool`
- import `numpy` as `np`
- from `math` import `inf`
- from `enum` import `Enum, auto`
- from `typing` import `Callable, List, NamedTuple`
- from `architectures` import `WeightedGCNConv, WeightedGINConv, WeightedGNNConv, GraphLinear, BatchIdentity`

Note that the last import refers to our own file `architectures.py`, not to an external library.

Finally, the encoders are bundled in the file `encoder_classes.py`, which performs the following imports:

- from `enum` import `Enum, auto`
- from `torch` import `cat, rand, isnan, sum, Tensor`
- from `torch.nn` import `Module, ModuleList, Linear, Sequential, BatchNorm1d, ReLU, TransformerEncoder, TransformerEncoderLayer, Embedding`
- from `torch.nn.init` import `xavier_uniform_`
- from `ogb.utils.features` import `get_atom_feature_dims, get_bond_feature_dims`
- from `torch_geometric.data` import `Data`
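The custom imports only succeed if the three helper files can be found on `sys.path`, which normally means they sit in the notebook's working directory. A minimal check with the standard library's `importlib.util.find_spec` is sketched below. The helper name `find_missing_modules` is our own.

```python
import importlib.util

HELPER_MODULES = ["architectures", "type_classes", "encoder_classes"]


def find_missing_modules(names) -> list:
    """Return the top-level module names that cannot be located on sys.path."""
    return [n for n in names if importlib.util.find_spec(n) is None]


missing = find_missing_modules(HELPER_MODULES)
if missing:
    print("Not importable from the current working directory:", missing)
```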

### Importing the Custom Files

In [2]:
from type_classes import ModelType, LossesAndMetrics, Pool, ActivationType, Metric
from encoder_classes import PosEncoder, DataSetEncoders

### Importing the Official Libraries

We first import the required libraries:

In [3]:
import torch
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import from_scipy_sparse_matrix
from torch_geometric.loader import DataLoader
from torch_geometric.typing import OptTensor, Adj
from torch_geometric.transforms import NormalizeFeatures
from scipy.sparse import csr_matrix

import os
import sys
import tqdm
import numpy as np
import random
from typing import Union, List, Tuple, NamedTuple, Callable

Next, we define the root directory `ROOT_DIR`:

In [4]:
from os import getcwd

try:
    # Running as a .py script: __file__ is defined
    ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
except NameError:
    # Running as a .ipynb: __file__ is undefined, so use the parent of the working directory
    ROOT_DIR = os.path.dirname(os.path.abspath(getcwd()))
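In the notebook branch, `ROOT_DIR` is the parent of the working directory. `load_dataset_custom` below then appends `replication/data/<name>.npz` to it. The sketch below reproduces that path arithmetic with `posixpath`, so the result does not depend on the host OS. The folder `/home/user/gdl/cognn` is a hypothetical location for the notebook.

```python
import posixpath


def resolve_paths(cwd: str, dataset: str) -> tuple:
    """Mirror the notebook branch: ROOT_DIR is the parent of the working directory."""
    root_dir = posixpath.dirname(posixpath.abspath(cwd))
    npz_path = posixpath.join(root_dir, "replication/data/" + dataset + ".npz")
    return root_dir, npz_path


root, path = resolve_paths("/home/user/gdl/cognn", "Cora-ML")
print(root)  # /home/user/gdl
print(path)  # /home/user/gdl/replication/data/Cora-ML.npz
```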

To ensure reproducibility, we define the `set_seed()` function:

In [5]:
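def set_seed(seed: int) -> None:
    """Seed Python, NumPy and PyTorch (CPU and all GPUs) for reproducible runs."""
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    if torch.cuda.is_available():
        # Trade speed for determinism: no cuDNN autotuning, deterministic kernels,
        # and cuDNN disabled altogether
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = False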
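Python's `random`, NumPy and PyTorch each keep their own random generator. That is why `set_seed` has to seed all of them. The principle can be shown with the standard library alone: reseeding restores the exact same sequence. The helper `draw` is only for illustration.

```python
import random


def draw(seed: int, n: int = 5) -> list:
    """Reseed the global RNG, as set_seed() does, then draw n floats."""
    random.seed(seed)
    return [random.random() for _ in range(n)]


assert draw(0) == draw(0)  # same seed, same sequence
assert draw(0) != draw(1)  # different seed, different sequence
```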

The device used for training (GPU or CPU) is set automatically in the `Experiment` class (see below).

### Dataset Loader

The following functions are copied from the `replication/data_loading/` folder.

In [6]:
from torch_geometric.data import InMemoryDataset
import scipy.sparse as sp
import scipy.sparse.linalg as spla
from pathlib import Path
from typing import Any, Dict
import warnings

data_dir = os.path.join(Path(getcwd()).parent, "replication/data")
development_seed = 4143496719

sparse_graph_properties = [
        'adj_matrix', 'attr_matrix', 'labels',
        'node_names', 'attr_names', 'class_names',
        'metadata']

class SparseGraph:
    """Attributed labeled graph stored in sparse matrix form.

    Parameters
    ----------
    adj_matrix
        Adjacency matrix in CSR format. Shape [num_nodes, num_nodes]
    attr_matrix
        Attribute matrix in CSR or numpy format. Shape [num_nodes, num_attr]
    labels
        Array, where each entry represents respective node's label(s). Shape [num_nodes]
        Alternatively, CSR matrix with labels in one-hot format. Shape [num_nodes, num_classes]
    node_names
        Names of nodes (as strings). Shape [num_nodes]
    attr_names
        Names of the attributes (as strings). Shape [num_attr]
    class_names
        Names of the class labels (as strings). Shape [num_classes]
    metadata
        Additional metadata such as text.

    """
    def __init__(
            self, adj_matrix: sp.spmatrix,
            attr_matrix: Union[np.ndarray, sp.spmatrix] = None,
            labels: Union[np.ndarray, sp.spmatrix] = None,
            node_names: np.ndarray = None,
            attr_names: np.ndarray = None,
            class_names: np.ndarray = None,
            metadata: Any = None):
        # Make sure that the dimensions of matrices / arrays all agree
        if sp.isspmatrix(adj_matrix):
            adj_matrix = adj_matrix.tocsr().astype(np.float32)
        else:
            raise ValueError("Adjacency matrix must be in sparse format (got {0} instead)."
                             .format(type(adj_matrix)))

        if adj_matrix.shape[0] != adj_matrix.shape[1]:
            raise ValueError("Dimensions of the adjacency matrix don't agree.")

        if attr_matrix is not None:
            if sp.isspmatrix(attr_matrix):
                attr_matrix = attr_matrix.tocsr().astype(np.float32)
            elif isinstance(attr_matrix, np.ndarray):
                attr_matrix = attr_matrix.astype(np.float32)
            else:
                raise ValueError("Attribute matrix must be a sp.spmatrix or a np.ndarray (got {0} instead)."
                                 .format(type(attr_matrix)))

            if attr_matrix.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency and attribute matrices don't agree.")

        if labels is not None:
            if labels.shape[0] != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the label vector don't agree.")

        if node_names is not None:
            if len(node_names) != adj_matrix.shape[0]:
                raise ValueError("Dimensions of the adjacency matrix and the node names don't agree.")

        if attr_names is not None:
            if len(attr_names) != attr_matrix.shape[1]:
                raise ValueError("Dimensions of the attribute matrix and the attribute names don't agree.")

        self.adj_matrix = adj_matrix
        self.attr_matrix = attr_matrix
        self.labels = labels
        self.node_names = node_names
        self.attr_names = attr_names
        self.class_names = class_names
        self.metadata = metadata

    def num_nodes(self) -> int:
        """Get the number of nodes in the graph.
        """
        return self.adj_matrix.shape[0]

    def num_edges(self) -> int:
        """Get the number of edges in the graph.

        For undirected graphs, (i, j) and (j, i) are counted as _two_ edges.

        """
        return self.adj_matrix.nnz

    def get_neighbors(self, idx: int) -> np.ndarray:
        """Get the indices of neighbors of a given node.

        Parameters
        ----------
        idx
            Index of the node whose neighbors are of interest.

        """
        return self.adj_matrix[idx].indices

    def get_edgeid_to_idx_array(self) -> np.ndarray:
        """Return a Numpy Array that maps edgeids to the indices in the adjacency matrix.

        Returns
        -------
        np.ndarray
            The i'th entry contains the x- and y-coordinates of edge i in the adjacency matrix.
            Shape [num_edges, 2]

        """
        return np.transpose(self.adj_matrix.nonzero())

    def is_directed(self) -> bool:
        """Check if the graph is directed (adjacency matrix is not symmetric).
        """
        return (self.adj_matrix != self.adj_matrix.T).sum() != 0

    def to_undirected(self) -> 'SparseGraph':
        """Convert to an undirected graph (make adjacency matrix symmetric).
        """
        idx = self.get_edgeid_to_idx_array().T
        ridx = np.ravel_multi_index(idx, self.adj_matrix.shape)
        ridx_rev = np.ravel_multi_index(idx[::-1], self.adj_matrix.shape)

        # Get duplicate edges (self-loops and opposing edges)
        dup_ridx = ridx[np.isin(ridx, ridx_rev)]
        dup_idx = np.unravel_index(dup_ridx, self.adj_matrix.shape)

        # Check if the adjacency matrix weights are symmetric (if nonzero)
        if len(dup_ridx) > 0 and not np.allclose(self.adj_matrix[dup_idx], self.adj_matrix[dup_idx[::-1]]):
            raise ValueError("Adjacency matrix weights of opposing edges differ.")

        # Create symmetric matrix
        new_adj_matrix = self.adj_matrix + self.adj_matrix.T
        if len(dup_ridx) > 0:
            new_adj_matrix[dup_idx] = (new_adj_matrix[dup_idx] - self.adj_matrix[dup_idx]).A1

        self.adj_matrix = new_adj_matrix
        return self

    def is_weighted(self) -> bool:
        """Check if the graph is weighted (edge weights other than 1).
        """
        return np.any(np.unique(self.adj_matrix[self.adj_matrix.nonzero()].A1) != 1)

    def to_unweighted(self) -> 'SparseGraph':
        """Convert to an unweighted graph (set all edge weights to 1).
        """
        self.adj_matrix.data = np.ones_like(self.adj_matrix.data)
        return self

    def is_connected(self) -> bool:
        """Check if the graph is connected.
        """
        return sp.csgraph.connected_components(self.adj_matrix, return_labels=False) == 1

    def has_self_loops(self) -> bool:
        """Check if the graph has self-loops.
        """
        return not np.allclose(self.adj_matrix.diagonal(), 0)

    def __repr__(self) -> str:
        props = []
        for prop_name in sparse_graph_properties:
            prop = getattr(self, prop_name)
            if prop is not None:
                if prop_name == 'metadata':
                    props.append(prop_name)
                else:
                    shape_string = 'x'.join([str(x) for x in prop.shape])
                    props.append("{} ({})".format(prop_name, shape_string))
        dir_string = 'Directed' if self.is_directed() else 'Undirected'
        weight_string = 'weighted' if self.is_weighted() else 'unweighted'
        conn_string = 'connected' if self.is_connected() else 'disconnected'
        loop_string = 'has self-loops' if self.has_self_loops() else 'no self-loops'
        return ("<{}, {} and {} SparseGraph with {} edges ({}). Data: {}>"
                .format(dir_string, weight_string, conn_string,
                        self.num_edges(), loop_string,
                        ', '.join(props)))

    # Quality of life (shortcuts)
    def standardize(
            self, make_unweighted: bool = True,
            make_undirected: bool = True,
            no_self_loops: bool = True,
            select_lcc: bool = True
            ) -> 'SparseGraph':
        """Perform common preprocessing steps: remove self-loops, make unweighted/undirected, select LCC.

        All changes are done inplace.

        Parameters
        ----------
        make_unweighted
            Whether to set all edge weights to 1.
        make_undirected
            Whether to make the adjacency matrix symmetric. Can only be used if make_unweighted is True.
        no_self_loops
            Whether to remove self loops.
        select_lcc
            Whether to select the largest connected component of the graph.

        """
        G = self
        if make_unweighted and G.is_weighted():
            G = G.to_unweighted()
        if make_undirected and G.is_directed():
            G = G.to_undirected()
        if no_self_loops and G.has_self_loops():
            G = remove_self_loops(G)
        if select_lcc and not G.is_connected():
            G = largest_connected_components(G, 1)
        return G

    def unpack(self) -> Tuple[sp.csr_matrix,
                              Union[np.ndarray, sp.csr_matrix],
                              Union[np.ndarray, sp.csr_matrix]]:
        """Return the (A, X, E, z) quadruplet.
        """
        return self.adj_matrix, self.attr_matrix, self.labels

    def to_flat_dict(self) -> Dict[str, Any]:
        """Return flat dictionary containing all SparseGraph properties.
        """
        data_dict = {}
        for key in sparse_graph_properties:
            val = getattr(self, key)
            if sp.isspmatrix(val):
                data_dict['{}.data'.format(key)] = val.data
                data_dict['{}.indices'.format(key)] = val.indices
                data_dict['{}.indptr'.format(key)] = val.indptr
                data_dict['{}.shape'.format(key)] = val.shape
            else:
                data_dict[key] = val
        return data_dict

    @staticmethod
    def from_flat_dict(data_dict: Dict[str, Any]) -> 'SparseGraph':
        """Initialize SparseGraph from a flat dictionary.
        """
        init_dict = {}
        del_entries = []

        # Construct sparse matrices
        for key in data_dict.keys():
            if key.endswith('_data') or key.endswith('.data'):
                if key.endswith('_data'):
                    sep = '_'
                    warnings.warn(
                            "The separator used for sparse matrices during export (for .npz files) "
                            "is now '.' instead of '_'. Please update (re-save) your stored graphs.",
                            DeprecationWarning, stacklevel=2)
                else:
                    sep = '.'
                matrix_name = key[:-5]
                mat_data = key
                mat_indices = '{}{}indices'.format(matrix_name, sep)
                mat_indptr = '{}{}indptr'.format(matrix_name, sep)
                mat_shape = '{}{}shape'.format(matrix_name, sep)
                if matrix_name == 'adj' or matrix_name == 'attr':
                    warnings.warn(
                            "Matrices are exported (for .npz files) with full names now. "
                            "Please update (re-save) your stored graphs.",
                            DeprecationWarning, stacklevel=2)
                    matrix_name += '_matrix'
                init_dict[matrix_name] = sp.csr_matrix(
                        (data_dict[mat_data],
                         data_dict[mat_indices],
                         data_dict[mat_indptr]),
                        shape=data_dict[mat_shape])
                del_entries.extend([mat_data, mat_indices, mat_indptr, mat_shape])

        # Delete sparse matrix entries
        for del_entry in del_entries:
            del data_dict[del_entry]

        # Load everything else
        for key, val in data_dict.items():
            if ((val is not None) and (None not in val)):
                init_dict[key] = val

        # Check if the dictionary contains only entries in sparse_graph_properties
        unknown_keys = [key for key in init_dict.keys() if key not in sparse_graph_properties]
        if len(unknown_keys) > 0:
            raise ValueError("Input dictionary contains keys that are not SparseGraph properties ({})."
                             .format(unknown_keys))

        return SparseGraph(**init_dict)

def create_subgraph(
        sparse_graph: SparseGraph,
        _sentinel: None = None,
        nodes_to_remove: np.ndarray = None,
        nodes_to_keep: np.ndarray = None
        ) -> SparseGraph:
    """Create a graph with the specified subset of nodes.

    Exactly one of (nodes_to_remove, nodes_to_keep) should be provided, while the other stays None.
    Note that to avoid confusion, it is required to pass node indices as named arguments to this function.

    The subgraph partially points to the old graph's data.

    Parameters
    ----------
    sparse_graph
        Input graph.
    _sentinel
        Internal, to prevent passing positional arguments. Do not use.
    nodes_to_remove
        Indices of nodes that have to be removed.
    nodes_to_keep
        Indices of nodes that have to be kept.

    Returns
    -------
    SparseGraph
        Graph with specified nodes removed.

    """
    # Check that arguments are passed correctly
    if _sentinel is not None:
        raise ValueError("Only call `create_subgraph` with named arguments',"
                         " (nodes_to_remove=...) or (nodes_to_keep=...).")
    if nodes_to_remove is None and nodes_to_keep is None:
        raise ValueError("Either nodes_to_remove or nodes_to_keep must be provided.")
    elif nodes_to_remove is not None and nodes_to_keep is not None:
        raise ValueError("Only one of nodes_to_remove or nodes_to_keep must be provided.")
    elif nodes_to_remove is not None:
        nodes_to_keep = [i for i in range(sparse_graph.num_nodes()) if i not in nodes_to_remove]
    elif nodes_to_keep is not None:
        nodes_to_keep = sorted(nodes_to_keep)
    else:
        raise RuntimeError("This should never happen.")

    sparse_graph.adj_matrix = sparse_graph.adj_matrix[nodes_to_keep][:, nodes_to_keep]
    if sparse_graph.attr_matrix is not None:
        sparse_graph.attr_matrix = sparse_graph.attr_matrix[nodes_to_keep]
    if sparse_graph.labels is not None:
        sparse_graph.labels = sparse_graph.labels[nodes_to_keep]
    if sparse_graph.node_names is not None:
        sparse_graph.node_names = sparse_graph.node_names[nodes_to_keep]
    return sparse_graph


def largest_connected_components(sparse_graph: SparseGraph, n_components: int = 1) -> SparseGraph:
    """Select the largest connected components in the graph.

    Changes are returned in a partially new SparseGraph.

    Parameters
    ----------
    sparse_graph
        Input graph.
    n_components
        Number of largest connected components to keep.

    Returns
    -------
    SparseGraph
        Subgraph of the input graph where only the nodes in largest n_components are kept.

    """
    _, component_indices = sp.csgraph.connected_components(sparse_graph.adj_matrix)
    component_sizes = np.bincount(component_indices)
    components_to_keep = np.argsort(component_sizes)[::-1][:n_components]  # reverse order to sort descending
    nodes_to_keep = [
        idx for (idx, component) in enumerate(component_indices) if component in components_to_keep
    ]
    return create_subgraph(sparse_graph, nodes_to_keep=nodes_to_keep)


def remove_self_loops(sparse_graph: SparseGraph) -> SparseGraph:
    """Remove self loops (diagonal entries in the adjacency matrix).

    Changes are returned in a partially new SparseGraph.

    """
    num_self_loops = (~np.isclose(sparse_graph.adj_matrix.diagonal(), 0)).sum()
    if num_self_loops > 0:
        sparse_graph.adj_matrix = sparse_graph.adj_matrix.tolil()
        sparse_graph.adj_matrix.setdiag(0)
        sparse_graph.adj_matrix = sparse_graph.adj_matrix.tocsr()
        warnings.warn("{0} self loops removed".format(num_self_loops))

    return sparse_graph

def load_from_npz(file_name: str) -> SparseGraph:
    """Load a SparseGraph from a Numpy binary file.

    Parameters
    ----------
    file_name
        Name of the file to load.

    Returns
    -------
    SparseGraph
        Graph in sparse matrix format.

    """
    with np.load(file_name, allow_pickle=True) as loader:
        loader = dict(loader)
        dataset = SparseGraph.from_flat_dict(loader)
    return dataset

def load_dataset(name: str,
                 directory: Union[Path, str] = data_dir
                 ) -> SparseGraph:
    """Load a dataset.

    Parameters
    ----------
    name
        Name of the dataset to load.
    directory
        Path to the directory where the datasets are stored.

    Returns
    -------
    SparseGraph
        The requested dataset in sparse format.

    """
    if isinstance(directory, str):
        directory = Path(directory)
    if not name.endswith('.npz'):
        name += '.npz'
    path_to_file = directory / name
    if path_to_file.exists():
        return load_from_npz(path_to_file)
    else:
        raise ValueError("{} doesn't exist.".format(path_to_file))

def normalize_attributes(attr_matrix):
    epsilon = 1e-12
    if isinstance(attr_matrix, sp.csr_matrix):
        attr_norms = spla.norm(attr_matrix, ord=1, axis=1)
        attr_invnorms = 1 / np.maximum(attr_norms, epsilon)
        attr_mat_norm = attr_matrix.multiply(attr_invnorms[:, np.newaxis])
    else:
        attr_norms = np.linalg.norm(attr_matrix, ord=1, axis=1)
        attr_invnorms = 1 / np.maximum(attr_norms, epsilon)
        attr_mat_norm = attr_matrix * attr_invnorms[:, np.newaxis]
    return attr_mat_norm

def get_dataset(name: str, use_lcc: bool = True) -> InMemoryDataset:
    """
    :param name: The name of the dataset
    :param use_lcc: Largest Connected Component
    :return:
    """
    dataset = InMemoryDataset
    graph = load_dataset(name)
    graph.standardize(select_lcc=use_lcc)  # This was changed from =True to use_lcc by JK
    new_y = torch.LongTensor(graph.labels)
    data = Data(
        x=torch.FloatTensor(normalize_attributes(graph.attr_matrix).toarray()),
        edge_index=torch.LongTensor(graph.get_edgeid_to_idx_array().T),
        y=new_y,
        train_mask=torch.zeros(new_y.size(0), dtype=torch.bool),
        test_mask=torch.zeros(new_y.size(0), dtype=torch.bool),
        val_mask=torch.zeros(new_y.size(0), dtype=torch.bool)
    )
    dataset.data = data
    dataset.name = name
    dataset.num_classes = len(np.unique(new_y))
    return dataset

def set_train_val_test_split(
        seed: int,
        data: Data,
        num_development: int = 1500,
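        num_per_class: int = 20) -> Data:
    """Assign train/val/test masks to `data`.

    A fixed development set of `num_development` nodes is drawn with
    `development_seed`, independently of `seed`; all remaining nodes form the
    test set. Within the development set, `seed` draws `num_per_class` training
    nodes per class and 500 validation nodes from the rest.
    """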
    rnd_state = np.random.RandomState(development_seed)
    num_nodes = data.y.shape[0]
    development_idx = rnd_state.choice(num_nodes, num_development, replace=False)
    test_idx = [i for i in np.arange(num_nodes) if i not in development_idx]

    train_idx = []
    rnd_state = np.random.RandomState(seed)
    for c in range(data.y.max() + 1):
        class_idx = development_idx[np.where(data.y[development_idx].cpu() == c)[0]]
        train_idx.extend(rnd_state.choice(class_idx, num_per_class, replace=False))

    val_idx_tmp = [i for i in development_idx if i not in train_idx]

    val_idx = rnd_state.choice(val_idx_tmp, 500, replace=False)

    def get_mask(idx):
        mask = torch.zeros(num_nodes, dtype=torch.bool)
        mask[idx] = 1
        return mask

    data.train_mask = get_mask(train_idx)
    data.val_mask = get_mask(val_idx)
    data.test_mask = get_mask(test_idx)

    return data
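The splitting scheme in `set_train_val_test_split` can be illustrated with Python's `random` module instead of NumPy. The sketch below mirrors the logic on a toy label list. The function name `split_indices` is hypothetical. Because `random.Random` and `np.random.RandomState` produce different streams, it will not reproduce the notebook's exact indices. What it does preserve is the structure: the test set depends only on `development_seed`, while `seed` changes only the train/val split.

```python
import random


def split_indices(labels, num_development, num_per_class, num_val,
                  development_seed, seed):
    """Sketch of set_train_val_test_split on a plain list of integer labels."""
    num_nodes = len(labels)
    # Fixed development set, independent of `seed`; everything else is test
    dev_rng = random.Random(development_seed)
    development = dev_rng.sample(range(num_nodes), num_development)
    dev_set = set(development)
    test = [i for i in range(num_nodes) if i not in dev_set]

    # Per-seed split inside the development set: num_per_class per class for training
    rng = random.Random(seed)
    train = []
    for c in range(max(labels) + 1):
        class_idx = [i for i in development if labels[i] == c]
        train.extend(rng.sample(class_idx, num_per_class))

    # Validation nodes come from the development nodes not used for training
    train_set = set(train)
    remaining = [i for i in development if i not in train_set]
    val = rng.sample(remaining, num_val)
    return train, val, test
```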

The next function is a custom loader that we initially used to load the data.

In [7]:
def load_dataset_custom(
    dataset_name : str
) -> Tuple[Data, int]:
    """
        Loads a dataset from its .npz file into a PyTorch Geometric Data object and
        returns it together with the number of classes. The Data object contains
        -   x:              Feature matrix,
        -   y:              Labels,
        -   edge_index:     List of edges,
        -   edge_attr:      List of edge weights.

        The function is currently implemented for the following datasets:
        -   A.Computer
        -   A.Photo
        -   Citeseer
        -   Cora-ML
        -   MS-Academic
        -   PubMed
    """
    path = os.path.join(ROOT_DIR, "replication/data/" + dataset_name + ".npz")
    npz = np.load(path)
    assert isinstance(npz, np.lib.npyio.NpzFile)

    variant1 = [(key in ["adj_data",
                         "adj_indices",
                         "adj_indptr",
                         "adj_shape",
                         "attr_data",
                         "attr_indices",
                         "attr_indptr",
                         "attr_shape",
                         "labels",
                         "class_names"])
                for key in npz.keys()]
    isvar1 = False not in variant1
    
    variant2 = [(key in ["adj_data",
                         "adj_indices",
                         "adj_indptr",
                         "adj_shape",
                         "attr_data",
                         "attr_indices",
                         "attr_indptr",
                         "attr_shape",
                         "labels",
                         "node_names",
                         "attr_names",
                         "class_names"])
                for key in npz.keys()]
    isvar2 = False not in variant2
    
    variant3 = [(key in ["adj_matrix.data",
                         "adj_matrix.indices",
                         "adj_matrix.indptr",
                         "adj_matrix.shape",
                         "attr_matrix.data",
                         "attr_matrix.indices",
                         "attr_matrix.indptr",
                         "attr_matrix.shape",
                         "edge_attr_matrix",
                         "labels",
                         "node_names",
                         "attr_names",
                         "edge_attr_names",
                         "class_names",
                         "metadata"])
                for key in npz.keys()]
    isvar3 = False not in variant3

    # The match statement requires Python 3.10 or newer
    match dataset_name:
        case name if name in ["A.Computer", "A.Photo", "MS-Academic"]:
            assert isvar1 or isvar2
            features_csr = csr_matrix((npz["attr_data"], npz["attr_indices"], npz["attr_indptr"]), shape=npz["attr_shape"])
            adjacency_csr = csr_matrix((npz["adj_data"], npz["adj_indices"], npz["adj_indptr"]), shape=npz["adj_shape"])
        case name if name in name in ["Citeseer", "Cora-ML", "PubMed"]:
            assert isvar3
            features_csr = csr_matrix((npz["attr_matrix.data"], npz["attr_matrix.indices"], npz["attr_matrix.indptr"]), shape=npz["attr_matrix.shape"])
            adjacency_csr = csr_matrix((npz["adj_matrix.data"], npz["adj_matrix.indices"], npz["adj_matrix.indptr"]), shape=npz["adj_matrix.shape"])
        case _:
            raise NotImplementedError(f"The dataset cannot be loaded as it contains unexpected keys. The keys obtained are \n{list(npz.keys())}")
    
    x = torch.FloatTensor(features_csr.toarray())
    edge_index, edge_weight = from_scipy_sparse_matrix(adjacency_csr)
    y = torch.LongTensor(npz["labels"])

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_weight, y=y)
    if len(np.unique(x)) > 2:
        transform = NormalizeFeatures()
        data = transform(data)

    num_classes = len(set(npz["labels"]))
    
    return data, num_classes
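The `*_data`, `*_indices` and `*_indptr` keys read above are the three arrays of a CSR (compressed sparse row) matrix. In this format, the nonzeros of row `i` are `indices[indptr[i]:indptr[i+1]]`, and their values are the matching slice of `data`. To show what `csr_matrix` followed by `from_scipy_sparse_matrix` computes, the sketch below expands those arrays into an edge list in pure Python. The helper `csr_to_edge_index` is our own. The output should follow the same row-major order that a CSR-to-COO conversion yields.

```python
def csr_to_edge_index(indptr, indices, data):
    """Expand CSR arrays into ([rows, cols], weights): row i owns indices[indptr[i]:indptr[i+1]]."""
    rows, cols, weights = [], [], []
    for row in range(len(indptr) - 1):
        for k in range(indptr[row], indptr[row + 1]):
            rows.append(row)
            cols.append(indices[k])
            weights.append(data[k])
    return [rows, cols], weights


# 3-node path graph 0-1-2, stored symmetrically (both directions of each edge)
edge_index, edge_weight = csr_to_edge_index(
    indptr=[0, 1, 3, 4], indices=[1, 0, 2, 1], data=[1.0, 1.0, 1.0, 1.0])
print(edge_index)  # [[0, 1, 1, 2], [1, 0, 2, 1]]
```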

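## Parameters

In this section, we define the hyperparameters. `create_config()` collects them in a plain dictionary; every argument except `dataset` has a default value.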

In [8]:
def create_config(
    dataset: str,                                                                           # Only the dataset file name, without the .npz ending.
    device : torch.device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),     # Can override this manually as torch.device object.
    pool: Pool = Pool.NONE,

    # gumbel
    learn_temp=False,
    temp_model_type: ModelType=ModelType.LIN,
    tau0: float=0.5,
    temp: float=0.01,

    # optimization
    num_epochs: int=1000,
    batch_size: int=32,
    lr: float=1e-3,
    dropout: float=0.2,

    # env cls parameters
    env_model_type: ModelType=ModelType.MEAN_GNN,
    env_num_layers: int=3,
    env_dim: int=128,
    skip=False,
    batch_norm=False,
    layer_norm=False,
    dec_num_layers: int=1,
    pos_enc: PosEncoder=PosEncoder.NONE,

    # policy cls parameters
    act_model_type: ModelType=ModelType.MEAN_GNN,
    act_num_layers: int=1,
    act_dim: int=16,

    # reproduce
    seed: int=0,
    gpu: int=None, # In the original argument parser, there is no default value defined here

    # dataset dependent parameters
    fold: int=None,

    # optimizer and scheduler
    weight_decay: float=0.0,
    ## for steplr scheduler only
    step_size: int=None,
    gamma: float=None,
    ## for cosine with warmup scheduler only
    num_warmup_epochs: int=None,

    decimal: int=2
):
    return {
        'dataset' : dataset,
        'device' : device,
        'pool' : pool,
        'learn_temp' : learn_temp,
        'temp_model_type' : temp_model_type,
        'tau0' : tau0,
        'temp' : temp,
        'num_epochs' : num_epochs,
        'batch_size' : batch_size,
        'lr' : lr,
        'dropout' : dropout,
        'env_model_type' : env_model_type,
        'env_num_layers' : env_num_layers,
        'env_dim' : env_dim,
        'skip' : skip,
        'batch_norm' : batch_norm,
        'layer_norm' : layer_norm,
        'dec_num_layers' : dec_num_layers,
        'pos_enc' : pos_enc,
        'act_model_type' : act_model_type,
        'act_num_layers' : act_num_layers,
        'act_dim' : act_dim,
        'seed' : seed,
        'gpu' : gpu,
        'fold' : fold,
        'weight_decay' : weight_decay,
        'step_size' : step_size,
        'gamma' : gamma,
        'num_warmup_epochs' : num_warmup_epochs,
        "decimal" : decimal
    }
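The `# gumbel` parameters above (`learn_temp`, `temp_model_type`, `tau0`, `temp`) concern the Gumbel-softmax, which CoGNN uses to sample discrete per-node actions from the action network's logits. We assume, without this notebook confirming it, that `temp` is the fixed temperature used when `learn_temp=False`. The plain-Python sketch below shows the standard Gumbel-softmax relaxation and why a small temperature such as `0.01` gives nearly one-hot samples. The function names are our own. PyTorch itself provides `torch.nn.functional.gumbel_softmax(logits, tau=..., hard=...)`.

```python
import math
import random


def softmax(values, temp):
    """Temperature-scaled softmax; subtracting the max keeps exp() from overflowing."""
    scaled = [v / temp for v in values]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def gumbel_softmax_sample(logits, temp, rng):
    """Relaxed one-hot sample: add Gumbel(0, 1) noise to the logits, then softmax."""
    noisy = []
    for logit in logits:
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep u strictly inside (0, 1)
        noisy.append(logit - math.log(-math.log(u)))    # Gumbel noise: -log(-log(u))
    return softmax(noisy, temp)


print(softmax([1.0, 0.0], temp=0.01))  # first entry is ~1.0: low temperature is near one-hot
sample = gumbel_softmax_sample([0.3, -0.1], temp=0.01, rng=random.Random(0))
print(sample)
```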

## Dataset Class

In [9]:
class DataSet:
    def __init__(self, data: Data):
        self.data = data
        self.family = True
        self.is_node_based = True
        self.not_synthetic = True
        self.is_expressivity = False
        self.clip_grad = False
        self.dataset_encoders = DataSetEncoders.NONE
        self.num_after_decimal = 2
        self.env_activation_type = ActivationType.RELU

    def gin_mlp_func(self) -> Callable:
        def mlp_func(in_channels: int, out_channels: int, bias: bool):
            return torch.nn.Sequential(
                torch.nn.Linear(in_channels, 2 * in_channels, bias=bias),
                torch.nn.BatchNorm1d(2 * in_channels),
                torch.nn.ReLU(), torch.nn.Linear(2 * in_channels, out_channels, bias=bias)
            )
        return mlp_func
    
    def get_split_mask(self, batch_size: int, split_mask_name: str) -> Tensor:
        if hasattr(self.data, split_mask_name):
            return getattr(self.data, split_mask_name)
        elif self.is_node_based:
            return torch.ones(size=(self.data.x.shape[0],), dtype=torch.bool)
        else:
            return torch.ones(size=(batch_size,), dtype=torch.bool)
    
    def get_edge_ratio_node_mask(self, split_mask_name: str) -> Tensor:
        if hasattr(self.data, split_mask_name):
            return getattr(self.data, split_mask_name)
        else:
            return torch.ones(size=(self.data.x.shape[0],), dtype=torch.bool)

## Network Classes

The action network (policy), which outputs two logits per node (`out_dim=2`):

In [10]:
def load_act_net(action_args : dict) -> torch.nn.ModuleList:
    model_type = action_args["model_type"]
    env_dim = action_args["env_dim"]
    hidden_dim = action_args["hidden_dim"]
    num_layers = action_args["num_layers"]
    gin_mlp_func = action_args["gin_mlp_func"]

    net = model_type.get_component_list(in_dim=env_dim,
                                        hidden_dim=hidden_dim,
                                        out_dim=2,
                                        num_layers=num_layers,
                                        bias=True,
                                        edges_required=False,
                                        gin_mlp_func=gin_mlp_func)

    return torch.nn.ModuleList(net)

In [11]:
class ActionNet(torch.nn.Module):
    def __init__(self, action_args: dict):
        """
        Create a model which represents the agent's policy.
        """
        super().__init__()
        self.num_layers = action_args["num_layers"]
        self.net = load_act_net(action_args=action_args)
        self.dropout = torch.nn.Dropout(action_args["dropout"])
        self.act = action_args["act_type"].get()

    def forward(self, x: Tensor, edge_index: Adj, env_edge_attr: OptTensor, act_edge_attr: OptTensor) -> Tensor:
        edge_attrs = [env_edge_attr] + (self.num_layers - 1) * [act_edge_attr]
        for idx, (edge_attr, layer) in enumerate(zip(edge_attrs[:-1], self.net[:-1])):
            x = layer(x=x, edge_index=edge_index, edge_attr=edge_attr)
            x = self.dropout(x)
            x = self.act(x)
        x = self.net[-1](x=x, edge_index=edge_index, edge_attr=edge_attrs[-1])
        return x

Optional learned, per-node temperature for the Gumbel-Softmax (used when `learn_temp` is set):

In [12]:
class TempSoftPlus(torch.nn.Module):
    def __init__(self, gumbel_args: dict, env_dim: int):
        super(TempSoftPlus, self).__init__()
        model_list = gumbel_args["temp_model_type"].get_component_list(in_dim=env_dim, hidden_dim=env_dim, out_dim=1, num_layers=1,
                                                                       bias=False, edges_required=False,
                                                                       gin_mlp_func=gumbel_args["gin_mlp_func"])
        self.linear_model = torch.nn.ModuleList(model_list)
        self.softplus = torch.nn.Softplus(beta=1)
        self.tau0 = gumbel_args["tau0"]

    def forward(self, x: Tensor, edge_index: Adj, edge_attr: Tensor):
        x = self.linear_model[0](x=x, edge_index=edge_index, edge_attr=edge_attr)
        x = self.softplus(x) + self.tau0
        temp = x.pow_(-1)
        return temp.masked_fill_(temp == float("inf"), 0.)
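
A minimal sketch of what `TempSoftPlus` computes. A plain `torch.nn.Linear` without bias stands in for the one-layer component built by `temp_model_type.get_component_list` (an assumption made so the example runs without the custom files), and `tau0 = 0.5` is an illustrative value. Each node gets its own value `1 / (softplus(w·x) + tau0)`, which is passed as `tau` to `gumbel_softmax`; the `(N, 1)` tensor broadcasts against the `(N, 2)` action logits.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

num_nodes, env_dim, tau0 = 5, 8, 0.5  # illustrative sizes; tau0 is a hypothetical value

# Stand-in (assumption) for the one-layer component built by temp_model_type
linear = torch.nn.Linear(env_dim, 1, bias=False)
softplus = torch.nn.Softplus(beta=1)

x = torch.randn(num_nodes, env_dim)
# Per-node value 1 / (softplus(w^T x) + tau0), shape (N, 1)
temp = (softplus(linear(x)) + tau0).pow(-1)
# Same guard as in TempSoftPlus: replace infinite values by 0
temp = temp.masked_fill(temp == float("inf"), 0.0)

logits = torch.randn(num_nodes, 2)  # e.g. the (N, 2) output of an ActionNet
# tau of shape (N, 1) broadcasts over the two logits of each node
actions = F.gumbel_softmax(logits=logits, tau=temp, hard=True)
print(actions)  # one one-hot row per node (column 0 is used as "keep")
```

Because the softplus output is positive, this value lies in (0, 1/tau0] for any `tau0 > 0`, so the `inf` guard can only trigger when `tau0` is 0.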

## CoGNN

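Here we implement the CoGNN architecture: an environment network updates the node states, while two action networks decide, per node and per layer, whether the node listens (keeps its incoming edges) and whether it broadcasts (keeps its outgoing edges). The discrete choices are sampled with a straight-through Gumbel-Softmax.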

In [13]:
def load_env_net(env_args : dict) -> torch.nn.ModuleList:
    in_dim = env_args["in_dim"]
    env_dim = env_args["env_dim"]
    num_layers = env_args["num_layers"]
    gin_mlp_func = env_args["gin_mlp_func"]
    dec_num_layers = env_args["dec_num_layers"]
    dropout = env_args["dropout"]
    act_type = env_args["act_type"]
    out_dim = env_args["out_dim"]

    enc_list = [env_args["dataset_encoders"].node_encoder(in_dim=in_dim, emb_dim=env_dim)]

    component_list = env_args["model_type"].get_component_list(in_dim=env_dim, hidden_dim=env_dim, out_dim=env_dim,
                                                               num_layers=num_layers, bias=True, edges_required=True,
                                                               gin_mlp_func=gin_mlp_func)

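    if dec_num_layers > 1:
        # Build fresh modules per block: list multiplication would reuse the same Linear (shared weights)
        mlp_list = []
        for _ in range(dec_num_layers - 1):
            mlp_list += [torch.nn.Linear(env_dim, env_dim), torch.nn.Dropout(dropout), act_type.nn()]
        mlp_list = mlp_list + [torch.nn.Linear(env_dim, out_dim)]
        dec_list = [torch.nn.Sequential(*mlp_list)]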
    else:
        dec_list = [torch.nn.Linear(env_dim, out_dim)]

    return torch.nn.ModuleList(enc_list + component_list + dec_list)
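
When a stack of repeated blocks is built from a Python list, it matters how the list is created: `k * [module]` repeats references to one module object instead of creating `k` modules, so layers built that way share their weights. The check below illustrates this in isolation; the sizes are illustrative and `num_params` is a hypothetical helper.

```python
import torch

env_dim, dec_num_layers = 4, 3  # illustrative sizes

# List multiplication repeats *references*: every block reuses the same Linear object
shared = (dec_num_layers - 1) * [torch.nn.Linear(env_dim, env_dim), torch.nn.ReLU()]

# A comprehension creates fresh modules for every block, hence independent weights
independent = [m for _ in range(dec_num_layers - 1)
               for m in (torch.nn.Linear(env_dim, env_dim), torch.nn.ReLU())]


def num_params(module: torch.nn.Module) -> int:
    # parameters() yields each shared parameter only once
    return sum(p.numel() for p in module.parameters())


print(shared[0] is shared[2])            # True: one Linear appears twice
print(independent[0] is independent[2])  # False
print(num_params(torch.nn.Sequential(*shared)),
      num_params(torch.nn.Sequential(*independent)))
```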

In [14]:
class CoGNN(torch.nn.Module):
    def __init__(
        self,
        gumbel_args: dict,
        env_args: dict,
        action_args: dict,
        pool: Pool
    ):
        super(CoGNN, self).__init__()
        self.env_args = env_args
        self.learn_temp = gumbel_args["learn_temp"]
        if self.learn_temp:
            self.temp_model = TempSoftPlus(gumbel_args=gumbel_args, env_dim=env_args["env_dim"])
        self.temp = gumbel_args["temp"]

        self.num_layers = env_args["num_layers"]
        self.env_net = load_env_net(env_args=env_args)
        self.use_encoders = env_args["dataset_encoders"].use_encoders()

        layer_norm_cls = torch.nn.LayerNorm if env_args["layer_norm"] else torch.nn.Identity
        self.hidden_layer_norm = layer_norm_cls(env_args["env_dim"])
        self.skip = env_args["skip"]
        self.drop_ratio = env_args["dropout"]
        self.dropout = torch.nn.Dropout(p=self.drop_ratio)
        self.act = env_args["act_type"].get()
        self.in_act_net = ActionNet(action_args=action_args)
        self.out_act_net = ActionNet(action_args=action_args)

        # Encoder types
        self.dataset_encoder = env_args["dataset_encoders"]
        self.env_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=env_args["env_dim"], model_type=env_args["model_type"])
        self.act_bond_encoder = self.dataset_encoder.edge_encoder(emb_dim=action_args["hidden_dim"], model_type=action_args["model_type"])

        # Pooling function to generate whole-graph embeddings
        self.pooling = pool.get()

    def forward(
        self,
        x: Tensor,
        edge_index: Adj,
        pestat,
        edge_attr: OptTensor = None,
        batch: OptTensor = None,
        edge_ratio_node_mask: OptTensor = None
    ) -> Tuple[Tensor, Tensor]:
        result = 0

        calc_stats = edge_ratio_node_mask is not None
        if calc_stats:
            edge_ratio_edge_mask = edge_ratio_node_mask[edge_index[0]] & edge_ratio_node_mask[edge_index[1]]
            edge_ratio_list = []

        # bond encode
        if edge_attr is None or self.env_bond_encoder is None:
            env_edge_embedding = None
        else:
            env_edge_embedding = self.env_bond_encoder(edge_attr)
        if edge_attr is None or self.act_bond_encoder is None:
            act_edge_embedding = None
        else:
            act_edge_embedding = self.act_bond_encoder(edge_attr)

        # node encode  
        x = self.env_net[0](x, pestat) # (N, F) encoder
        if not self.use_encoders:
            x = self.dropout(x)
            x = self.act(x)

        for gnn_idx in range(self.num_layers):
            x = self.hidden_layer_norm(x)

            # action
            in_logits = self.in_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding, act_edge_attr=act_edge_embedding) # (N, 2)
            out_logits = self.out_act_net(x=x, edge_index=edge_index, env_edge_attr=env_edge_embedding, act_edge_attr=act_edge_embedding) # (N, 2)

            temp = self.temp_model(x=x, edge_index=edge_index, edge_attr=env_edge_embedding) if self.learn_temp else self.temp
            in_probs = torch.nn.functional.gumbel_softmax(logits=in_logits, tau=temp, hard=True)
            out_probs = torch.nn.functional.gumbel_softmax(logits=out_logits, tau=temp, hard=True)
            edge_weight = self.create_edge_weight(edge_index=edge_index, keep_in_prob=in_probs[:, 0], keep_out_prob=out_probs[:, 0])

            # environment
            out = self.env_net[1 + gnn_idx](x=x, edge_index=edge_index, edge_weight=edge_weight, edge_attr=env_edge_embedding)
            out = self.dropout(out)
            out = self.act(out)

            if calc_stats:
                edge_ratio = edge_weight[edge_ratio_edge_mask].sum() / edge_weight[edge_ratio_edge_mask].shape[0]
                edge_ratio_list.append(edge_ratio.item())

            if self.skip:
                x = x + out
            else:
                x = out

        x = self.hidden_layer_norm(x)
        x = self.pooling(x, batch=batch)
        x = self.env_net[-1](x) # decoder
        result = result + x

        if calc_stats:
            edge_ratio_tensor = torch.tensor(edge_ratio_list, device=x.device)
        else:
            edge_ratio_tensor = -1 * torch.ones(size=(self.num_layers,), device=x.device)

        return result, edge_ratio_tensor

    def create_edge_weight(
        self,
        edge_index: Adj,
        keep_in_prob: Tensor,
        keep_out_prob: Tensor
    ) -> Tensor:
        u, v = edge_index
        edge_in_prob = keep_in_prob[v]
        edge_out_prob = keep_out_prob[u]
        return edge_in_prob * edge_out_prob
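
To make the gating in `create_edge_weight` concrete: each node makes two binary choices, whether to keep its incoming edges (listen) and whether to keep its outgoing edges (broadcast), which together give the four CoGNN actions STANDARD (both), LISTEN, BROADCAST and ISOLATE (neither). The sketch below hand-picks the one-hot choices for a three-node path instead of sampling them, and uses the same convention as the class above: column 0 means keep, and the message `u -> v` of an edge `(u, v)` in `edge_index` survives only if `u` broadcasts and `v` listens.

```python
import torch


def create_edge_weight(edge_index, keep_in_prob, keep_out_prob):
    # Same rule as CoGNN.create_edge_weight: message u -> v needs u to broadcast and v to listen
    u, v = edge_index
    return keep_in_prob[v] * keep_out_prob[u]


# Undirected path 0-1-2, stored as directed edges in both directions
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])

# Hand-picked one-hot choices (column 0 = keep, column 1 = drop) instead of Gumbel samples
in_choice = torch.tensor([[1., 0.],    # node 0 listens
                          [0., 1.],    # node 1 does not listen
                          [1., 0.]])   # node 2 listens
out_choice = torch.tensor([[0., 1.],   # node 0 does not broadcast -> LISTEN
                           [1., 0.],   # node 1 broadcasts         -> BROADCAST
                           [1., 0.]])  # node 2 broadcasts         -> STANDARD

edge_weight = create_edge_weight(edge_index, in_choice[:, 0], out_choice[:, 0])
print(edge_weight)  # tensor([0., 1., 1., 0.])
```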

## Experiment Definition

The functions below implement the experiment: they load the dataset, run training and evaluation over several folds, and report the performance metrics.

In [15]:
def train_CoGNN(
    model : torch.nn.Module,
    task_loss : torch.nn.CrossEntropyLoss,
    dataset : DataSet,
    optimizer : torch.optim.Optimizer,
    config : dict
):
    """
        Training loop for one epoch.
    """
    device = config["device"]

    model.train()
    optimizer.zero_grad()
    data = dataset.data.to(device)

    train_mask = data.train_mask

    predictions, edge_ratio_tensor = model(x=data.x,
                                           edge_index=data.edge_index,
                                           batch = torch.zeros(data.x.size(0), dtype=torch.long, device=device), # dummy batch for pooling in CoGNN
                                           edge_attr=data.edge_attr if 'edge_attr' in data else None,
                                           edge_ratio_node_mask=None,
                                           pestat=config["pos_enc"].get_pe(data=data, device=device))

    train_loss = task_loss(predictions[train_mask], data.y.to(device=device)[train_mask])
    train_loss.backward()

    if config["clip_grad"]:
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
    optimizer.step()

    return model, train_loss.item(), edge_ratio_tensor
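
Both training and evaluation pass an all-zero `batch` vector, which tells the pooling step that every node belongs to one graph. For this node-classification setup, pooling therefore has to keep one row per node; a graph-level mean pool would collapse the `(N, F)` node embeddings to a single row before the decoder, and `predictions[train_mask]` would then fail. The sketch contrasts the two with hand-written functions (`mean_pool` and `identity_pool` are illustrative stand-ins, not the `Pool` type used by `CoGNN`).

```python
import torch


def mean_pool(x, batch):
    # Graph-level mean pooling: one output row per distinct graph id in `batch`
    num_graphs = int(batch.max()) + 1
    sums = torch.zeros(num_graphs, x.size(1)).index_add_(0, batch, x)
    counts = torch.bincount(batch, minlength=num_graphs).clamp(min=1).unsqueeze(1)
    return sums / counts


def identity_pool(x, batch):
    # Node-level tasks: keep one row per node
    return x


x = torch.randn(6, 3)                             # 6 nodes, 3 features
batch = torch.zeros(x.size(0), dtype=torch.long)  # dummy batch: all nodes in graph 0

print(identity_pool(x, batch).shape)  # torch.Size([6, 3])
print(mean_pool(x, batch).shape)      # torch.Size([1, 3])
```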

In [None]:
def evaluate_CoGNN(
    model : torch.nn.Module,
    metric : Metric,
    dataset : DataSet,
    config : dict,
    split_mask_name : str=None,
    calc_edge_ratio : bool=True,
    node_mask : np.ndarray=None
): 
    """
        Evaluation function of the model.
    """
    device = config["device"]

    model.eval()
    data = dataset.data.to(device)

    mask = dataset.get_split_mask(batch_size=None, split_mask_name=split_mask_name).to(device=device) if node_mask is None else node_mask
    edge_ratio_node_mask = dataset.get_edge_ratio_node_mask(split_mask_name=split_mask_name).to(device) if calc_edge_ratio else None
    edge_attr = data.edge_attr.to(device) if data.edge_attr is not None else None

    with torch.no_grad():
        predictions, edge_ratio_tensor = model(x=data.x,
                                               edge_index=data.edge_index,
                                               batch = torch.zeros(data.x.size(0), dtype=torch.long, device=device), # dummy batch for pooling in CoGNN
                                               edge_attr=edge_attr,
                                               edge_ratio_node_mask=edge_ratio_node_mask,
                                               pestat=config["pos_enc"].get_pe(data=data, device=device))
        
        eval_loss = metric.task_loss(predictions[mask], data.y[mask])
    
    scores_np = predictions[mask].detach().cpu().numpy()
    labels_np = data.y[mask].detach().cpu().numpy()

    accuracy = metric.apply_metric(scores=scores_np, target=labels_np)
    loss = eval_loss.item()
    edge_ratios = edge_ratio_tensor

    return accuracy, loss, edge_ratios
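
When `edge_ratio_node_mask` is given, `CoGNN.forward` reports, per layer, the fraction of kept edges among the edges whose two endpoints both lie in the mask. A standalone sketch of that computation on hand-picked gate values:

```python
import torch

edge_index = torch.tensor([[0, 1, 1, 2, 2, 3],
                           [1, 0, 2, 1, 3, 2]])
edge_weight = torch.tensor([1., 0., 1., 1., 0., 0.])  # hand-picked 0/1 gates
node_mask = torch.tensor([True, True, True, False])   # e.g. a test mask

# Keep only edges with both endpoints inside the mask (same rule as in CoGNN.forward)
edge_mask = node_mask[edge_index[0]] & node_mask[edge_index[1]]
edge_ratio = edge_weight[edge_mask].sum() / edge_weight[edge_mask].shape[0]
print(edge_ratio.item())  # 0.75
```

The ratio is undefined (NaN) when no edge has both endpoints inside the mask.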

In [17]:
def train_fold_CoGNN(
    config : dict,
    dataset: DataSet,
    model : torch.nn.Module,
    optimizer : torch.optim.Optimizer,
    pbar : tqdm.std.tqdm,
    num_fold: int
) -> Tuple[LossesAndMetrics, OptTensor]:
    """
        Training loop for one fold.
    """

    task_loss = config["metric"].task_loss

    best_losses_n_metrics = config["metric"].get_worst_losses_n_metrics
    for epoch in range(config["num_epochs"]):
        model, epoch_loss, epoch_edge_ratio_tensor = train_CoGNN(model=model, task_loss=task_loss, dataset=dataset, optimizer=optimizer, config=config)

        epoch_train_accuracy, epoch_train_loss, epoch_edge_ratios = evaluate_CoGNN(model=model,
                                                                                   metric=config["metric"],
                                                                                   dataset=dataset,
                                                                                   config=config,
                                                                                   split_mask_name="train_mask",
                                                                                   calc_edge_ratio=False)

        # evaluate_CoGNN returns (metric, loss, edge_ratios)
        val_metric, val_loss, _ = evaluate_CoGNN(model=model,
                                                 metric=config["metric"],
                                                 dataset=dataset,
                                                 config=config,
                                                 split_mask_name="val_mask",
                                                 calc_edge_ratio=False)

        test_metric, test_loss, _ = evaluate_CoGNN(model=model,
                                                   metric=config["metric"],
                                                   dataset=dataset,
                                                   config=config,
                                                   split_mask_name="test_mask",
                                                   calc_edge_ratio=False)

        losses_n_metrics = LossesAndMetrics(train_loss=epoch_loss,
                                            val_loss=val_loss,
                                            test_loss=test_loss,
                                            train_metric=epoch_train_accuracy,
                                            val_metric=val_metric,
                                            test_metric=test_metric)
        
        if config["metric"].src_better_than_other(src=losses_n_metrics.val_metric,
                                                  other=best_losses_n_metrics.val_metric):
            best_losses_n_metrics = losses_n_metrics

        decimal = config["decimal"]
        log_str = f"Split: {num_fold}, epoch: {epoch}"
        for name in losses_n_metrics._fields:
            log_str += f",{name}={round(getattr(losses_n_metrics, name), decimal)}"
        log_str += f"({round(best_losses_n_metrics.test_metric, decimal)})"
        pbar.set_description(log_str)
        pbar.update(n=1)
    
    # Edge ratios on the test split, computed once after the final epoch
    _, _, edge_ratios = evaluate_CoGNN(model=model,
                                       metric=config["metric"],
                                       dataset=dataset,
                                       config=config,
                                       split_mask_name="test_mask",
                                       calc_edge_ratio=True)

    return best_losses_n_metrics, edge_ratios
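
`train_fold_CoGNN` keeps the losses and metrics of the epoch with the best validation metric, so the reported test metric is the one at the best-validation epoch, not the last or the highest test value. A plain-Python sketch of that selection rule, assuming a higher-is-better metric such as accuracy and strict comparison on ties; `EpochResult`, `is_better` and `select_best` are hypothetical stand-ins (the field names mirror `LossesAndMetrics`), and the numbers are made up.

```python
from collections import namedtuple

# Hypothetical stand-in with the same field names as LossesAndMetrics
EpochResult = namedtuple("EpochResult", ["train_loss", "val_loss", "test_loss",
                                         "train_metric", "val_metric", "test_metric"])


def is_better(src, other, higher_is_better=True):
    # Strict comparison: on ties the earlier epoch is kept
    return src > other if higher_is_better else src < other


def select_best(history, higher_is_better=True):
    # Keep the epoch with the best *validation* metric; its test metric is what gets reported
    best = history[0]
    for result in history[1:]:
        if is_better(result.val_metric, best.val_metric, higher_is_better):
            best = result
    return best


# Illustrative numbers only (not results of this notebook)
history = [EpochResult(1.2, 1.1, 1.1, 0.60, 0.58, 0.57),
           EpochResult(0.8, 0.9, 0.9, 0.75, 0.71, 0.74),
           EpochResult(0.5, 1.0, 1.0, 0.85, 0.69, 0.76)]
best = select_best(history)
print(best.val_metric, best.test_metric)  # 0.71 0.74
```

The third epoch has the higher test metric but is not selected, because its validation metric is lower.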


In [18]:
def train_model_CoGNN(
    data : Data,
    config : dict,
    num_folds : int
) -> Tuple[torch.nn.Module, dict]:
    """
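        Trains one model per fold (fold n uses seed config["seed"] + n), prints the best
        metrics of each fold and their mean (plus standard deviation if num_folds > 1).
        Returns the model of the last fold; the history is not tracked yet (None).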
    """
    #task_loss = config["metric"].task_loss
    decimal = config["decimal"]
    seeds = [config["seed"] + n for n in range(num_folds)]

    dataset = DataSet(data=data)
    config["clip_grad"] = dataset.clip_grad

    gin_mlp_func = dataset.gin_mlp_func()
    env_act_type = dataset.env_activation_type
    dataset_encoder = dataset.dataset_encoders
    out_dim = config["metric"].get_out_dim(data=data)

    gumbel = ["learn_temp", "temp_model_type", "tau0", "temp"]
    gumbel_add = {"gin_mlp_func" :  gin_mlp_func}
    gumbel_args = {key: config.get(key) for key in gumbel}
    gumbel_args.update(gumbel_add)

    env = ["env_dim", "layer_norm", "skip", "batch_norm", "dropout", "metric", "dec_num_layers", "pos_enc"]
    env_add = {"model_type" :       config["env_model_type"],
               "num_layers" :       config["env_num_layers"],
               "act_type" :         env_act_type,
               "in_dim" :           data.x.shape[1],
               "out_dim" :          out_dim,
               "gin_mlp_func" :     gin_mlp_func,
               "dataset_encoders" : dataset_encoder}
    env_args = {key: config.get(key) for key in env}
    env_args.update(env_add)

    action = ["dropout", "env_dim"]
    action_add = {"model_type" :    config["act_model_type"],
                  "num_layers" :    config["act_num_layers"],
                  "hidden_dim" :    config["act_dim"],
                  "act_type" :      ActivationType.RELU,
                  "gin_mlp_func" :  gin_mlp_func}
    action_args = {key: config.get(key) for key in action}
    action_args.update(action_add)

    metrics_list = []
    edge_ratios_list = []
    for n in range(num_folds):
        set_seed(seed=seeds[n])

        model = CoGNN(gumbel_args=gumbel_args, env_args=env_args, action_args=action_args, pool=config["pool"]).to(device=config["device"])
        optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

        with tqdm.tqdm(total=config["num_epochs"], file=sys.stdout) as pbar:
            best_losses_n_metrics, edge_ratios = train_fold_CoGNN(config=config,
                                                                  dataset=dataset,
                                                                  model=model,
                                                                  optimizer=optimizer,
                                                                  pbar=pbar,
                                                                  num_fold=n)
        
        print_str = f"Fold {n}/{num_folds}"
        for name in best_losses_n_metrics._fields:
            print_str += f",{name}={round(getattr(best_losses_n_metrics, name), decimal)}"
        print(print_str)
        print()
        metrics_list.append(best_losses_n_metrics.get_fold_metrics())

        if edge_ratios is not None:
            edge_ratios_list.append(edge_ratios)
    
    metrics_matrix = torch.stack(metrics_list, dim=0)  # (F, 3)
    metrics_mean = torch.mean(metrics_matrix, dim=0).tolist()  # (3,)
    if len(edge_ratios_list) > 0:
        edge_ratios = torch.mean(torch.stack(edge_ratios_list, dim=0), dim=0)
    else:
        edge_ratios = None
    
    print(f"Final Rewired train={round(metrics_mean[0], decimal)},"
          f"val={round(metrics_mean[1], decimal)},"
          f"test={round(metrics_mean[2], decimal)}")
    
    if num_folds > 1:
        metrics_std = torch.std(metrics_matrix, dim=0).tolist()  # (3,)
        print(f"Final Rewired train={round(metrics_mean[0], decimal)}+-{round(metrics_std[0], decimal)},"
              f"val={round(metrics_mean[1], decimal)}+-{round(metrics_std[1], decimal)},"
              f"test={round(metrics_mean[2], decimal)}+-{round(metrics_std[2], decimal)}")

    # return metrics_mean, edge_ratios
    history = None
    return model, history
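
The summary stacks the per-fold `(train, val, test)` metrics into an `(F, 3)` matrix and reduces over the fold dimension. `torch.std` applies Bessel's correction by default (dividing by F - 1), so with a single fold it returns NaN, which the `num_folds > 1` guard avoids printing. A sketch with illustrative per-fold values:

```python
import torch

# Illustrative per-fold (train, val, test) metrics, not results of this notebook
metrics_list = [torch.tensor([0.90, 0.80, 0.78]),
                torch.tensor([0.92, 0.82, 0.80]),
                torch.tensor([0.94, 0.84, 0.82])]

metrics_matrix = torch.stack(metrics_list, dim=0)                  # (F, 3)
metrics_mean = torch.mean(metrics_matrix, dim=0)                   # (3,)
metrics_std = torch.std(metrics_matrix, dim=0)                     # (3,), divides by F - 1
population_std = torch.std(metrics_matrix, dim=0, unbiased=False)  # divides by F

# With a single fold the corrected standard deviation is undefined (NaN)
single_fold_std = torch.std(metrics_matrix[:1], dim=0)

print(metrics_mean, metrics_std, population_std, single_fold_std)
```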

In [19]:
def main_CoGNN(
    config : dict,
    num_folds : int=10,
    run_grid_search : bool=False
) -> Tuple[torch.nn.Module, Data, dict]:
    """
        Main function with optional grid search.
    """
    device = config["device"]

    dataset_name = config["dataset"]
    print(f"Loading {dataset_name} dataset...")

    # This is where we load the data
    raw_data = get_dataset(name=dataset_name)
    data = set_train_val_test_split(seed=42, data=raw_data.data)
    data = data.to(device)
    
    config["metric"] = Metric(task="multiclass", num_classes=raw_data.num_classes)

    if run_grid_search:
        print("Running grid search...")
        raise NotImplementedError("Grid search hasn't been implemented yet!")
    
        # List of the parameters and the values to be searched across
        param_list_1 = []
        param_list_2 = []
        # ...

        results = []

        # Adjust the nested loops for the various parameters
        for param_1 in param_list_1:
            for param_2 in param_list_2:
                # ...
                config_copy = config.copy()
                config_copy["param_1_name"] = param_1
                config_copy["param_2_name"] = param_2

                print(f"\nTrying param_1_name={param_1}, param_2_name={param_2} ...")
                # train_model_CoGNN builds the model itself and returns the model of the last fold
                model, _ = train_model_CoGNN(data=data, config=config_copy, num_folds=num_folds)

                test_acc, test_loss, _ = evaluate_CoGNN(model=model,
                                                        metric=config_copy["metric"],
                                                        dataset=DataSet(data=data),
                                                        config=config_copy,
                                                        split_mask_name="test_mask",
                                                        calc_edge_ratio=False)

                results.append({
                    "param_1_name": param_1,
                    "param_2_name": param_2,
                    # ...
                    "test_acc": test_acc,
                    "test_loss": test_loss
                })

                print(f"Result: Test Acc: {test_acc:.4f}, Test Loss: {test_loss:.4f}")

        best_result = max(results, key=lambda x: x["test_acc"])

        print("\nGrid Search Results:")
        for r in results:
            p1, p2, acc, loss = r["param_1_name"], r["param_2_name"], r["test_acc"], r["test_loss"]
            print(f"param_1_name={p1}, param_2_name={p2}: " # ...
                  f"Acc={acc:.4f}, Loss={loss:.4f}")

        best_p1, best_p2, best_acc = best_result["param_1_name"], best_result["param_2_name"], best_result["test_acc"]
        print(f"\nBest parameters: param_1_name={best_p1}, " # ...
              f"param_2_name={best_p2}, "
              f"Accuracy={best_acc:.4f}")

        # config with best parameters
        config["param_1_name"] = best_result["param_1_name"]
        config["param_2_name"] = best_result["param_2_name"]
        # ...    

    # Instantiating the model with obtained config
    print("\nTraining final model...")
    #model = CoGNN(gumbel_args, env_args, action_args, config["pool"]).to(device)

    # Training the model
    model, history = train_model_CoGNN(data=data, config=config, num_folds=num_folds)

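    # Test evaluation of the model returned by train_model_CoGNN (model of the last fold)
    test_acc, test_loss, _ = evaluate_CoGNN(model=model,
                                            metric=config["metric"],
                                            dataset=DataSet(data=data),
                                            config=config,
                                            split_mask_name="test_mask",
                                            calc_edge_ratio=False)
    print(f"Test Accuracy: {test_acc:.4f}, Test Loss: {test_loss:.4f}")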

    # visualizations
    # TO-DO!

    # optional: impact of various parameters

    return model, data, history
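
The grid-search branch in `main_CoGNN` is still a template. One way to enumerate a grid without hand-nesting one loop per parameter is `itertools.product` over a dict of candidate values, as sketched below. The keys reuse existing config entries (`lr`, `env_num_layers`), but the candidate values, the `grid_search` helper and the `run_fn` callback (which would train a model and return a validation score) are hypothetical placeholders, not tuned or recommended settings.

```python
import itertools
from typing import Callable, Dict, List


def grid_search(config: dict, grid: Dict[str, List], run_fn: Callable[[dict], float]):
    # Try every combination of candidate values and keep the best validation score
    keys = list(grid.keys())
    results = []
    for values in itertools.product(*(grid[k] for k in keys)):
        params = dict(zip(keys, values))
        config_copy = config.copy()   # shallow copy, as in the template above
        config_copy.update(params)
        score = run_fn(config_copy)   # hypothetical: train and return a validation metric
        results.append({**params, "val_score": score})
    best = max(results, key=lambda r: r["val_score"])
    return best, results


def dummy_score(cfg: dict) -> float:
    # Stand-in scorer, only to exercise the control flow
    return -abs(cfg["lr"] - 1e-2) - 0.1 * cfg["env_num_layers"]


grid = {"lr": [1e-3, 1e-2], "env_num_layers": [2, 3]}
best, results = grid_search({"lr": 1e-3, "env_num_layers": 2}, grid, dummy_score)
print(len(results), best)
```

Selecting on a validation score, rather than on test accuracy, keeps the test split out of model selection.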

In [None]:
if __name__ == "__main__":
    CONFIG = create_config(dataset="A.Computer.npz", seed=42)
    model, data, history = main_CoGNN(config=CONFIG, num_folds=10, run_grid_search=False)

$\textcolor{red}{\texttt{TO - DO!}}$

- `train_fold_CoGNN` and the functions below it need to be adjusted.
- Grid search in the `main_CoGNN` function: which parameters should be searched, and over which values?
- Check that all string quotes are `"` rather than `'` (inside f-strings, nested `"` requires Python 3.12 or later).
- Check whether the dataset classification (homophilic, etc.) is correct.