<a href="https://colab.research.google.com/github/MariiaSaltykova/Saltykova_effective_python/blob/master/12_gnn_mpnn.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Graph Neural Networks (GNNs)
## Part I: Message Passing Neural Networks (MPNNs)
We are going to implement few MPNNs for molecular property prediciton. It's recommended that you're familiar with the recent lectures on GNNs.

# Packages for GNNs
There two very popular packages for GNNs that uses pytorch as a backend:
1. [PyTorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html).
2. [Deep Graph Library](https://www.dgl.ai/pages/start.html) (along with [dgl-lifesci](https://lifesci.dgl.ai/install/index.html).
The former is more stable, the latter has a convenient extension [dgl-lifesci](https://lifesci.dgl.ai/generated/dgllife.utils.CanonicalAtomFeaturizer.html) for molecular data and is generally much more user-friendly. For convenience, we are going to use all three packages, so install appropriate versions of them, please (I recommend installing with pip). If you have issues with installing rdkit (required by dgl-lifesci), you can install rdkit using pip (pip install rdkit).

Some additional packages that we are going to use:

In [2]:

!pip install torchmetrics # or conda install -c conda-forge torchmetrics
!pip install wandb # or conda install -c conda-forge wandb

Collecting torchmetrics
  Downloading torchmetrics-1.2.1-py3-none-any.whl (806 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m806.1/806.1 kB[0m [31m7.7 MB/s[0m eta [36m0:00:00[0m
Collecting lightning-utilities>=0.8.0 (from torchmetrics)
  Downloading lightning_utilities-0.10.0-py3-none-any.whl (24 kB)
Installing collected packages: lightning-utilities, torchmetrics
Successfully installed lightning-utilities-0.10.0 torchmetrics-1.2.1
Collecting wandb
  Downloading wandb-0.16.1-py3-none-any.whl (2.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.1/2.1 MB[0m [31m13.2 MB/s[0m eta [36m0:00:00[0m
Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)
  Downloading GitPython-3.1.40-py3-none-any.whl (190 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m190.6/190.6 kB[0m [31m14.8 MB/s[0m eta [36m0:00:00[0m
Collecting sentry-sdk>=1.0.0 (from wandb)
  Downloading sentry_sdk-1.39.1-py2.py3-none-any.whl (254 kB)
[2K     [90m

In [3]:
!pip install dgllife # or conda install -c conda-forge torchmetrics
!pip install dgl # or conda install -c conda-forge wandb

Collecting dgllife
  Downloading dgllife-0.3.2-py3-none-any.whl (226 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m226.1/226.1 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgllife
Successfully installed dgllife-0.3.2
Collecting dgl
  Downloading dgl-1.1.3-cp310-cp310-manylinux1_x86_64.whl (6.5 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.5/6.5 MB[0m [31m19.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.1.3


In [4]:
!pip install rdkit

Collecting rdkit
  Downloading rdkit-2023.9.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (34.3 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m34.3/34.3 MB[0m [31m25.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.9.3


# Molecular graphs
(Copied from [mldd23 repository](https://github.com/gmum/mldd23/blob/main/labs/L3-graph-neural-networks/laboratory.ipynb))
In mathematics, a graph is an object that consists of a set of vertices (nodes) connected with edges, i.e. $\mathcal{G} = (V, E)$, where $V = \{ v_i: i \in \{1, 2, \dots, N \} \}$ and $E \subseteq \{ (v_i, v_j):\, v_i,v_j \in V \}$.

Molecular graphs are a special class of graphs, where besides nodes (denoting atoms) and edges (denoting chemical bonds), we have an additional information about atom types and sometimes also bond types. We can assume that we have an additional set of node/atom features encoded as a matrix $X$, where $X_{ij}$ is the $j$-th feature of the $i$-th atom. As atomic features, we can have one-hot encoded atom symbols (a vector containing zeros on all positions besides the position that corresponds to the atom symbol), the number of implicit hydrogens bonded with this atom, or the number of heavy neighbors (atoms other than hydrogens bonded to the given atom).

Egdes/bonds can be encoded in two different ways. One method is to use an adjacency matrix $A$, where $A_{ij}=1$ if nodes/atoms $v_i$ nad $v_j$ are connected ($A_{ij}=0$ otherwise). In the case of sparse matrices, a more useful encoding is a list of pairs of connected atoms (a list of index pairs). This latter enocding is used by the PyTorch-Geometric library.

In practice, a molecular graph can be described by two matrices: $X \in \mathbb{R}^{N \times F}$ and $E \in \{0, 1,\dots,N-1\}^{2 \times N}$, where $N$ is the number of atoms, and $F$ is the number of atomic features.
<img src="https://github.com/MariiaSaltykova/machine_learning/blob/main/resources/mol_graph.png?raw=1" height="500" />

# Dataset

We are going to use FreeSolv dataset that contains 642 hydration free energy values for small molecules. The goal is to predict the [hydration free energy](https://en.wikipedia.org/wiki/Hydration_energy) of a given molecule. It's a very commonly used dataset for benchmarking molecular property prediction models. It's small, so we can minimize our co2 footprint and time spent on training.

Molecules in most chemical datasets are represented with SMILES. SMILES is a linearization of the molecular graph, it's pretty convenient and can even be used as an input to text-based models. Fortunately, dgllife provides a fancy FreeSolv dataset wrapper that will 1) transform the SMILES into a molecular graph, and 2) encode the nodes and edges with some sensible chemical features (like atom types, bond type etc.) with node and edge features, so we don't really need to care about it.

In [5]:
from dgllife.utils import CanonicalAtomFeaturizer, CanonicalBondFeaturizer, SMILESToBigraph
from dgllife.data import FreeSolv
import torch
import dgl

node_featurizer = CanonicalAtomFeaturizer()
edge_featurizer = CanonicalBondFeaturizer(self_loop=True)
dataset = FreeSolv(
    smiles_to_graph=SMILESToBigraph(
        node_featurizer=node_featurizer,
        edge_featurizer=edge_featurizer,
        add_self_loop=True,
    ),
)

DGL backend not selected or invalid.  Assuming PyTorch for now.


Setting the default backend to "pytorch". You can change it in the ~/.dgl/config.json file or export the DGLBACKEND environment variable.  Valid options are: pytorch, mxnet, tensorflow (all lowercase)
Downloading /root/.dgl/FreeSolv.zip from https://data.dgl.ai/dataset/FreeSolv.zip...
Extracting file to /root/.dgl/FreeSolv
Processing dgl graphs from scratch...


## Playground

In [6]:
smiles, graph, label = dataset[0]
smiles, graph, label

('CN(C)C(=O)c1ccc(cc1)OC',
 Graph(num_nodes=13, num_edges=39,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 tensor([-11.0100]))

We see that the dataset item consist of a SMILES string, a graph, and a label. The graph is a [DGLGraph](https://docs.dgl.ai/en/0.8.x/api/python/dgl.DGLGraph.html) object that contains node and edge features. We can access them with the following code:

In [7]:
graph.ndata['h'].shape  # node features

torch.Size([13, 74])

In [8]:
graph.edata['e'].shape  # edge features

torch.Size([39, 13])

In [9]:
start_nodes, end_nodes = graph.edges()  # edges. Note that edges are directed, so we have two edges for each bond. Moreover, we have self-loops, to easily handle molecules with only one atom.
edges = torch.stack([start_nodes, end_nodes], dim=1)
edges

tensor([[12,  0],
        [ 0, 12],
        [ 0,  2],
        [ 2,  0],
        [ 0,  4],
        [ 4,  0],
        [ 4,  7],
        [ 7,  4],
        [ 4,  9],
        [ 9,  4],
        [ 9,  6],
        [ 6,  9],
        [ 6, 10],
        [10,  6],
        [10, 11],
        [11, 10],
        [11,  3],
        [ 3, 11],
        [ 3,  8],
        [ 8,  3],
        [11,  5],
        [ 5, 11],
        [ 5,  1],
        [ 1,  5],
        [ 8,  9],
        [ 9,  8],
        [ 0,  0],
        [ 1,  1],
        [ 2,  2],
        [ 3,  3],
        [ 4,  4],
        [ 5,  5],
        [ 6,  6],
        [ 7,  7],
        [ 8,  8],
        [ 9,  9],
        [10, 10],
        [11, 11],
        [12, 12]], dtype=torch.int32)

Importantly, if we want to create a batch of graphs, we can simply treat the graphs as... a single graph with many disconnected components. The reason is that MPNN cannot pass the message between disconnected compontents, so the graphs in a batch won't influence each other. To make a batch from two graphs, we can simply run:

In [10]:
_, graph_1, _ = dataset[0]
_, graph_2, _ = dataset[1]
collated_graph = dgl.batch([graph_1, graph_2])
graph_1, graph_2, collated_graph

(Graph(num_nodes=13, num_edges=39,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=5, num_edges=13,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}),
 Graph(num_nodes=18, num_edges=52,
       ndata_schemes={'h': Scheme(shape=(74,), dtype=torch.float32)}
       edata_schemes={'e': Scheme(shape=(13,), dtype=torch.float32)}))

In [11]:
collated_graph.batch_num_nodes()

tensor([13,  5])

In the collated_graph, the ids corresponding to the nodes of graph_2 are shifted by the size of graph_1:

In [12]:
graph_1.nodes(), graph_2.nodes(), collated_graph.nodes()

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12], dtype=torch.int32),
 tensor([0, 1, 2, 3, 4], dtype=torch.int32),
 tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17],
        dtype=torch.int32))

## Split
We are going to make our split slightly harder by using [scaffold](https://hub.knime.com/infocom/extensions/jp.co.infocom.cheminfo.jchem.feature/latest/jp.co.infocom.cheminfo.jchem.bemismurckoclustering.BemisMurckoClusteringNodeFactory) (scaffold is the largest cycle in a molecule) splitting that puts molecules with similar scaffolds to the same split.

In [13]:
from dgllife.utils import ScaffoldSplitter

splitter = ScaffoldSplitter()
train, valid, test = splitter.train_val_test_split(dataset)

Start initializing RDKit molecule instances...
Start computing Bemis-Murcko scaffolds.


In [14]:
!pip install utils

Collecting utils
  Downloading utils-1.0.1-py2.py3-none-any.whl (21 kB)
Installing collected packages: utils
Successfully installed utils-1.0.1


In [15]:
import os
from typing import Callable, Tuple, List, Type

import numpy as np
import torch

import utils
from types import SimpleNamespace
from torch.optim import SGD
from torch.optim import Adagrad as torch_adagrad
from torch.optim import RMSprop as torch_rmsprop
from torch.optim import Adadelta as torch_adadelta
from torch.optim import Adam as torch_adam


def check_closest(fn: Callable) -> None:
    inputs = [
        (6, np.array([5, 3, 4])),
        (10, np.array([12, 2, 8, 9, 13, 14])),
        (-2, np.array([-5, 12, 6, 0, -14, 3])),
    ]
    assert np.isclose(fn(*inputs[0]), 5), "Jest błąd w funkcji closest!"
    assert np.isclose(fn(*inputs[1]), 9), "Jest błąd w funkcji closest!"
    assert np.isclose(fn(*inputs[2]), 0), "Jest błąd w funkcji closest!"


def check_poly(fn: Callable) -> None:
    inputs = [
        (6, np.array([5.5, 3, 4])),
        (10, np.array([12, 2, 8, 9, 13, 14])),
        (-5, np.array([6, 3, -12, 9, -15])),
    ]
    assert np.isclose(fn(*inputs[0]), 167.5), "Jest błąd w funkcji poly!"
    assert np.isclose(fn(*inputs[1]), 1539832), "Jest błąd w funkcji poly!"
    assert np.isclose(fn(*inputs[2]), -10809), "Jest błąd w funkcji poly!"


def check_multiplication_table(fn: Callable) -> None:
    inputs = [3, 5]
    assert np.all(
        fn(inputs[0]) == np.array([[1, 2, 3], [2, 4, 6], [3, 6, 9]])
    ), "Jest błąd w funkcji multiplication_table!"
    assert np.all(
        fn(inputs[1])
        == np.array(
            [
                [1, 2, 3, 4, 5],
                [2, 4, 6, 8, 10],
                [3, 6, 9, 12, 15],
                [4, 8, 12, 16, 20],
                [5, 10, 15, 20, 25],
            ]
        )
    ), "Jest błąd w funkcji multiplication_table!"


def check_1_1(
        mean_error: Callable,
        mean_squared_error: Callable,
        max_error: Callable,
        train_sets: List[np.ndarray],
) -> None:
    train_set_1d, train_set_2d, train_set_10d = train_sets
    assert np.isclose(mean_error(train_set_1d, np.array([8])), 8.897352)
    assert np.isclose(mean_error(train_set_2d, np.array([2.5, 5.2])), 7.89366)
    assert np.isclose(mean_error(train_set_10d, np.array(np.arange(10))), 14.16922)

    assert np.isclose(mean_squared_error(train_set_1d, np.array([3])), 23.03568)
    assert np.isclose(mean_squared_error(train_set_2d, np.array([2.4, 8.9])), 124.9397)
    assert np.isclose(mean_squared_error(train_set_10d, -np.arange(10)), 519.1699)

    assert np.isclose(max_error(train_set_1d, np.array([3])), 7.89418)
    assert np.isclose(max_error(train_set_2d, np.array([2.4, 8.9])), 14.8628)
    assert np.isclose(max_error(train_set_10d, -np.linspace(0, 5, num=10)), 23.1727)


def check_1_2(
        minimize_me: Callable, minimize_mse: Callable, minimize_max: Callable, train_set_1d: np.ndarray
) -> None:
    assert np.isclose(minimize_mse(train_set_1d), -0.89735)
    assert np.isclose(minimize_mse(train_set_1d * 2), -1.79470584)
    assert np.isclose(minimize_me(train_set_1d), -1.62603)
    assert np.isclose(minimize_me(train_set_1d ** 2), 3.965143)
    assert np.isclose(minimize_max(train_set_1d), 0.0152038)
    assert np.isclose(minimize_max(train_set_1d / 2), 0.007601903895526174)


def check_1_3(
        me_grad: Callable, mse_grad: Callable, max_grad: Callable, train_sets: List[np.ndarray]
) -> None:
    train_set_1d, train_set_2d, train_set_10d = train_sets
    assert all(np.isclose(me_grad(train_set_1d, np.array([0.99])), [0.46666667]))
    assert all(np.isclose(me_grad(train_set_2d, np.array([0.99, 8.44])), [0.21458924, 0.89772834]))
    assert all(
        np.isclose(
            me_grad(train_set_10d, np.linspace(0, 10, num=10)),
            [
                -0.14131273,
                -0.031631,
                0.04742431,
                0.0353542,
                0.16364242,
                0.23353252,
                0.30958123,
                0.35552034,
                0.4747464,
                0.55116738,
            ],
        )
    )

    assert all(np.isclose(mse_grad(train_set_1d, np.array([1.24])), [4.27470585]))
    assert all(
        np.isclose(mse_grad(train_set_2d, np.array([-8.44, 10.24])), [-14.25378235, 21.80373175])
    )
    assert all(np.isclose(max_grad(train_set_1d, np.array([5.25])), [1.0]))
    assert all(
        np.isclose(max_grad(train_set_2d, np.array([-6.28, -4.45])), [-0.77818704, -0.62803259])
    )


def check_02_linear_regression(lr_cls: Type) -> None:
    from sklearn import datasets

    np.random.seed(54)

    input_dataset = datasets.load_diabetes()
    lr = lr_cls()
    lr.fit(input_dataset.data, input_dataset.target)
    returned = lr.predict(input_dataset.data)
    expected = np.load(".checker/05/lr_diabetes.out.npz")["data"]
    assert np.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    loss = lr.loss(input_dataset.data, input_dataset.target)
    assert np.isclose(
        loss, 26004.287402, rtol=1e-03, atol=1e-06
    ), "Wrong value of the loss function!"


def check_02_regularized_linear_regression(lr_cls: Type) -> None:
    from sklearn import datasets

    np.random.seed(54)

    input_dataset = datasets.load_diabetes()
    lr = lr_cls(lr=1e-2, alpha=1e-4)
    lr.fit(input_dataset.data, input_dataset.target)
    returned = lr.predict(input_dataset.data)
    # np.savez_compressed(".checker/05/rlr_diabetes.out.npz", data=returned)
    expected = np.load(".checker/05/rlr_diabetes.out.npz")["data"]
    assert np.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    loss = lr.loss(input_dataset.data, input_dataset.target)
    assert np.isclose(
        loss, 26111.08336411, rtol=1e-03, atol=1e-06
    ), "Wrong value of the loss function!"


def check_4_1_mse(fn: Callable, datasets: List[Tuple[np.ndarray, np.ndarray]]) -> None:
    results = [torch.tensor(13.8520), torch.tensor(31.6952)]
    for (data, param), loss in zip(datasets, results):
        result = fn(data, param)
        assert torch.allclose(fn(data, param), loss, atol=1e-3), "Wrong loss returned!"


def check_4_1_me(fn: Callable, datasets: List[Tuple[np.ndarray, np.ndarray]]) -> None:
    results = [torch.tensor(3.6090), torch.tensor(5.5731)]
    for (data, param), loss in zip(datasets, results):
        assert torch.allclose(fn(data, param), loss, atol=1e-3), "Wrong loss returned!"


def check_4_1_max(fn: Callable, datasets: List[Tuple[np.ndarray, np.ndarray]]) -> None:
    results = [torch.tensor(7.1878), torch.tensor(7.5150)]
    for (data, param), loss in zip(datasets, results):
        assert torch.allclose(fn(data, param), loss, atol=1e-3), "Wrong loss returned!"


def check_4_1_lin_reg(fn: Callable, data: List[np.ndarray]) -> None:
    X, y, w = data
    assert torch.allclose(fn(X, w, y), torch.tensor(29071.6699), atol=1e-3), "Wrong loss returned!"


def check_4_1_reg_reg(fn: Callable, data: List[np.ndarray]) -> None:
    X, y, w = data
    assert torch.allclose(fn(X, w, y), torch.tensor(29073.4551)), "Wrong loss returned!"


def check_04_logistic_reg(lr_cls: Type) -> None:
    np.random.seed(10)
    torch.manual_seed(10)

    # **** First dataset ****
    input_dataset = utils.get_classification_dataset_1d()
    lr = lr_cls(1)
    lr.fit(input_dataset.data, input_dataset.target, lr=1e-3, num_steps=int(1e4))
    returned = lr.predict(input_dataset.data)
    save_path = ".checker/04/lr_dataset_1d.out.torch"
    # torch.save(returned, save_path)
    expected = torch.load(save_path)
    assert torch.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    returned = lr.predict_proba(input_dataset.data)
    save_path = ".checker/04/lr_dataset_1d_proba.out.torch"
    # torch.save(returned, save_path)
    expected = torch.load(save_path)
    assert torch.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    returned = lr.predict(input_dataset.data)
    save_path = ".checker/04/lr_dataset_1d_preds.out.torch"
    # torch.save(returned, save_path)
    expected = torch.load(save_path)
    assert torch.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    # **** Second dataset ****
    input_dataset = utils.get_classification_dataset_2d()
    lr = lr_cls(2)
    lr.fit(input_dataset.data, input_dataset.target, lr=1e-2, num_steps=int(1e4))
    returned = lr.predict(input_dataset.data)
    save_path = ".checker/04/lr_dataset_2d.out.torch"
    # torch.save(returned, save_path)
    expected = torch.load(save_path)
    assert torch.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    returned = lr.predict_proba(input_dataset.data)
    save_path = ".checker/04/lr_dataset_2d_proba.out.torch"
    # torch.save(returned, save_path)
    expected = torch.load(save_path)
    assert torch.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"

    returned = lr.predict(input_dataset.data)
    save_path = ".checker/04/lr_dataset_2d_preds.out.torch"
    # torch.save(returned, save_path)
    expected = torch.load(save_path)
    assert torch.allclose(expected, returned, rtol=1e-03, atol=1e-06), "Wrong prediction returned!"


def optim_f(w: torch.Tensor) -> torch.Tensor:
    x = torch.tensor([0.2, 2], dtype=torch.float)
    return torch.sum(x * w ** 2)


def optim_g(w: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    x = torch.tensor([0.2, 2], dtype=torch.float)
    return torch.sum(x * w + b)


opt_checker_1 = SimpleNamespace(
    f=optim_f, params=[torch.tensor([-6, 2], dtype=torch.float, requires_grad=True)]
)
opt_checker_2 = SimpleNamespace(
    f=optim_g,
    params=[
        torch.tensor([-6, 2], dtype=torch.float, requires_grad=True),
        torch.tensor([1, -1], dtype=torch.float, requires_grad=True),
    ],
)

test_params = {
    "Momentum": {
        "torch_cls": SGD,
        "torch_params": {"lr": 0.1, "momentum": 0.9},
        "params": {"learning_rate": 0.1, "gamma": 0.9},
    },
    "Adagrad": {
        "torch_cls": torch_adagrad,
        "torch_params": {"lr": 0.5, "eps": 1e-8},
        "params": {"learning_rate": 0.5, "epsilon": 1e-8},
    },
    "RMSProp": {
        "torch_cls": torch_rmsprop,
        "torch_params": {
            "lr": 0.5,
            "alpha": 0.9,
            "eps": 1e-08,
        },
        "params": {"learning_rate": 0.5, "gamma": 0.9, "epsilon": 1e-8},
    },
    "Adadelta": {
        "torch_cls": torch_adadelta,
        "torch_params": {"rho": 0.9, "eps": 1e-1},
        "params": {"gamma": 0.9, "epsilon": 1e-1},
    },
    "Adam": {
        "torch_cls": torch_adam,
        "torch_params": {"lr": 0.5, "betas": (0.9, 0.999), "eps": 1e-08},
        "params": {"learning_rate": 0.5, "beta1": 0.9, "beta2": 0.999, "epsilon": 1e-8},
    },
}


def test_optimizer(optim_cls: Type, num_steps: int = 10) -> None:
    test_dict = test_params[optim_cls.__name__]

    for ns in [opt_checker_1, opt_checker_2]:
        torch_params = [p.clone().detach().requires_grad_(True) for p in ns.params]
        torch_opt = test_dict["torch_cls"](torch_params, **test_dict["torch_params"])
        for _ in range(num_steps):
            torch_opt.zero_grad()

            loss = ns.f(*torch_params)
            loss.backward()
            torch_opt.step()

        params = [p.clone().detach().requires_grad_(True) for p in ns.params]
        opt = optim_cls(params, **test_dict["params"])

        for _ in range(num_steps):
            opt.zero_grad()

            loss = ns.f(*params)
            loss.backward()
            opt.step()

        for p, tp in zip(params, torch_params):
            assert torch.allclose(p, tp)


def test_droput(dropout_cls: Type) -> None:
    drop = dropout_cls(0.5)
    drop.train()
    x = torch.randn(10, 30)
    out = drop(x)

    for row, orig_row in zip(out, x):
        zeros_in_row = torch.where(row == 0.0)[0]
        non_zeros_in_row = torch.where(row != 0.0)[0]
        non_zeros_scaled = (row[non_zeros_in_row] == 2 * orig_row[non_zeros_in_row]).all()
        assert len(zeros_in_row) > 0 and len(zeros_in_row) < len(row) and non_zeros_scaled

    drop_eval = dropout_cls(0.5)
    drop_eval.eval()
    x = torch.randn(10, 30)
    out_eval = drop_eval(x)

    for row in out_eval:
        zeros_in_row = len(torch.where(row == 0.0)[0])
        assert zeros_in_row == 0


def test_bn(bn_cls: Type) -> None:
    torch.manual_seed(42)
    bn = bn_cls(num_features=100)

    opt = torch.optim.SGD(bn.parameters(), lr=0.1)

    bn.train()
    x = torch.rand(20, 100)
    out = bn(x)

    assert out.mean().abs().item() < 1e-4
    assert abs(out.var().item() - 1) < 1e-1

    assert (bn.sigma != 1).all()
    assert (bn.mu != 1).all()

    loss = 1 - out.mean()
    loss.backward()
    opt.step()

    assert (bn.beta != 0).all()

    n_steps = 10

    for i in range(n_steps):
        x = torch.rand(20, 100)
        out = bn(x)
        loss = 1 - out.mean()
        loss.backward()
        opt.step()

    torch.manual_seed(43)
    test_x = torch.randn(20, 100)
    bn.eval()
    test_out = bn(test_x)

    assert abs(test_out.mean() + 0.5) < 1e-1


expected_mean_readout = torch.tensor(
    [[-0.0035, 0.0505, -0.2221, 0.1404, 0.1922, -0.3736, -0.0672, 0.0752,
      -0.0613, 0.0439, -0.1307, -0.0752, -0.0310, 0.0081, -0.0553, -0.1734],
     [-0.0054, -0.0144, -0.3113, 0.1665, 0.0738, -0.3303, 0.0420, 0.0668,
      0.0494, 0.2648, -0.0478, 0.0550, -0.1923, -0.0157, 0.0508, 0.0148],
     [-0.1912, 0.0309, -0.1512, 0.1283, 0.1120, -0.4540, -0.0644, 0.1378,
      -0.0194, 0.0103, -0.1713, 0.0175, -0.0604, -0.0193, -0.0208, -0.0822]]
)
expected_attention_readout = torch.Tensor(
    [[-0.0083, 0.0499, -0.2197, 0.1380, 0.1921, -0.3753, -0.0669, 0.0771,
      -0.0592, 0.0411, -0.1317, -0.0769, -0.0299, 0.0074, -0.0568, -0.1741],
     [-0.0068, -0.0131, -0.3102, 0.1656, 0.0736, -0.3312, 0.0410, 0.0670,
      0.0485, 0.2635, -0.0479, 0.0544, -0.1933, -0.0162, 0.0508, 0.0150],
     [-0.1911, 0.0308, -0.1514, 0.1271, 0.1100, -0.4542, -0.0658, 0.1376,
      -0.0215, 0.0099, -0.1723, 0.0164, -0.0618, -0.0209, -0.0217, -0.0817]],
)
expected_sage_layer_output = torch.tensor(
    [[-5.0965e-01, -4.5482e-01, -8.1451e-01, 5.4286e-03],
     [-5.6737e-01, -5.9137e-01, -7.9304e-01, 7.5955e-02],
     [-4.6768e-01, -5.0346e-01, -7.2765e-01, 5.0357e-02],
     [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
     [-5.0465e-01, -3.5816e-01, -8.7864e-01, -3.1902e-02],
     [-5.6591e-01, -4.2403e-01, -8.7506e-01, 2.9357e-02],
     [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
     [-5.7196e-01, -3.5674e-01, -9.4769e-01, -4.9931e-03],
     [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
     [-5.2655e-01, -5.1094e-01, -8.3806e-01, -1.8521e-02],
     [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
     [-5.7628e-01, -5.5394e-01, -8.7300e-01, -7.6976e-03],
     [-4.6768e-01, -5.0346e-01, -7.2765e-01, 5.0357e-02],
     [-5.4808e-01, -5.3204e-01, -7.8906e-01, 4.2878e-02],
     [-5.3417e-01, -3.5912e-01, -9.5030e-01, 2.3648e-05],
     [-6.2538e-01, -2.9249e-01, -1.1233e+00, 1.0970e-01],
     [-6.5214e-01, -3.8342e-01, -1.0136e+00, -1.6424e-02],
     [-6.5214e-01, -3.8342e-01, -1.0136e+00, -1.6424e-02]],
)
expected_gin_layer_output = torch.tensor(
    [[-0.4516, -0.3673, -0.5313, 0.3170],
     [-0.4524, -0.3760, -0.5243, 0.3249],
     [-0.4570, -0.3747, -0.5313, 0.3221],
     [-0.4763, -0.4030, -0.5390, 0.3335],
     [-0.4481, -0.3855, -0.5187, 0.3295],
     [-0.4545, -0.3838, -0.5245, 0.3276],
     [-0.4763, -0.4030, -0.5390, 0.3335],
     [-0.4390, -0.4001, -0.4973, 0.3446],
     [-0.4763, -0.4030, -0.5390, 0.3335],
     [-0.4683, -0.3882, -0.5400, 0.3248],
     [-0.4763, -0.4030, -0.5390, 0.3335],
     [-0.4682, -0.3921, -0.5374, 0.3277],
     [-0.4570, -0.3747, -0.5313, 0.3221],
     [-0.4225, -0.3671, -0.4928, 0.3295],
     [-0.3760, -0.3700, -0.4407, 0.3489],
     [-0.2646, -0.3342, -0.3357, 0.3683],
     [-0.3859, -0.3950, -0.4392, 0.3624],
     [-0.3859, -0.3950, -0.4392, 0.3624]],
)
expected_simple_mpnn_output = torch.tensor(
    [[-0.1990, -0.2007, -0.7749, -0.2355],
     [-0.5297, -0.4750, -0.8783, -0.0762],
     [-0.3664, -0.4155, -0.7463, -0.0573],
     [-0.5217, -0.3488, -0.9198, -0.1840],
     [0.1237, -0.0524, -0.5546, -0.1867],
     [-0.3597, -0.2378, -0.8626, -0.1551],
     [-0.5217, -0.3488, -0.9198, -0.1840],
     [-0.3358, -0.2634, -0.8318, -0.0586],
     [-0.5217, -0.3488, -0.9198, -0.1840],
     [-0.2175, -0.2724, -0.7910, -0.2460],
     [-0.5217, -0.3488, -0.9198, -0.1840],
     [-0.3758, -0.3293, -0.9195, -0.2665],
     [-0.3664, -0.4155, -0.7463, -0.0573],
     [-0.3907, -0.4223, -0.7682, -0.0586],
     [-0.2049, -0.2482, -0.7605, -0.0309],
     [-0.1718, 0.0814, -1.0231, -0.2095],
     [-0.3551, -0.2676, -0.8502, -0.0614],
     [-0.3551, -0.2676, -0.8502, -0.0614]]
)
expected_sum_readout = torch.tensor(
    [[-0.0451, 0.6570, -2.8874, 1.8256, 2.4987, -4.8573, -0.8733, 0.9780,
      -0.7967, 0.5701, -1.6988, -0.9777, -0.4033, 0.1053, -0.7191, -2.2545],
     [-0.0268, -0.0720, -1.5565, 0.8324, 0.3692, -1.6515, 0.2101, 0.3342,
      0.2468, 1.3238, -0.2389, 0.2752, -0.9615, -0.0785, 0.2541, 0.0741],
     [-0.9559, 0.1545, -0.7560, 0.6414, 0.5598, -2.2701, -0.3222, 0.6888,
      -0.0969, 0.0516, -0.8565, 0.0875, -0.3022, -0.0964, -0.1039, -0.4109]],
)
expected_gine_layer_output = torch.tensor(
    [[-0.4519, -0.3654, -0.5197, 0.3193],
     [-0.4577, -0.3681, -0.5309, 0.3200],
     [-0.4617, -0.3697, -0.5356, 0.3193],
     [-0.4318, -0.3586, -0.5039, 0.3215],
     [-0.3675, -0.3206, -0.4476, 0.3215],
     [-0.4474, -0.3725, -0.5134, 0.3252],
     [-0.4318, -0.3586, -0.5039, 0.3215],
     [-0.4617, -0.3816, -0.5311, 0.3244],
     [-0.4318, -0.3586, -0.5039, 0.3215],
     [-0.3174, -0.2810, -0.4102, 0.3140],
     [-0.4318, -0.3586, -0.5039, 0.3215],
     [-0.3173, -0.2847, -0.4078, 0.3168],
     [-0.4617, -0.3697, -0.5356, 0.3193],
     [-0.4367, -0.3529, -0.5122, 0.3167],
     [-0.4103, -0.3570, -0.4806, 0.3282],
     [-0.4105, -0.3539, -0.4767, 0.3282],
     [-0.4575, -0.3899, -0.5207, 0.3318],
     [-0.4575, -0.3899, -0.5207, 0.3318]]
)
expected_gat_output = torch.tensor(
    [[0.2640, 0.0480, 0.0950, -0.0174, -0.2840, 0.0064, 0.0522, -0.1773,
      0.1720, 0.1878, -0.1340, 0.0229],
     [0.1955, 0.0230, 0.0520, 0.0308, -0.2525, 0.0519, 0.0259, -0.1553,
      0.1808, 0.1965, -0.1323, 0.0663],
     [0.2423, 0.0486, 0.1118, -0.0467, -0.2726, 0.0444, 0.0325, -0.1617,
      0.1654, 0.1770, -0.1465, 0.0071],
     [0.2717, 0.0307, 0.0516, 0.1657, -0.2802, -0.1184, 0.1700, -0.1849,
      0.2089, 0.2373, -0.1915, -0.0212],
     [0.2887, -0.0457, 0.2075, 0.0216, -0.2877, -0.0890, 0.1351, -0.1585,
      0.2169, 0.1446, -0.0779, 0.0065],
     [0.2594, -0.0098, 0.0917, 0.0416, -0.2764, -0.0409, 0.1162, -0.1622,
      0.1887, 0.1710, -0.1145, 0.0457],
     [0.2717, 0.0307, 0.0516, 0.1657, -0.2802, -0.1184, 0.1700, -0.1849,
      0.2089, 0.2373, -0.1915, -0.0212],
     [0.2488, -0.0431, 0.1990, 0.0435, -0.2735, -0.0590, 0.0793, -0.1624,
      0.2314, 0.1686, -0.0642, 0.0281],
     [0.2717, 0.0307, 0.0516, 0.1657, -0.2802, -0.1184, 0.1700, -0.1849,
      0.2089, 0.2373, -0.1915, -0.0212],
     [0.2924, 0.0171, 0.0866, 0.1376, -0.2914, -0.1295, 0.1688, -0.1878,
      0.2136, 0.2186, -0.1574, -0.0087],
     [0.2717, 0.0307, 0.0516, 0.1657, -0.2802, -0.1184, 0.1700, -0.1849,
      0.2089, 0.2373, -0.1915, -0.0212],
     [0.2485, 0.0049, 0.0737, 0.1471, -0.2700, -0.0880, 0.1455, -0.1686,
      0.2142, 0.2180, -0.1624, 0.0050],
     [0.2423, 0.0486, 0.1118, -0.0467, -0.2726, 0.0444, 0.0325, -0.1617,
      0.1654, 0.1770, -0.1465, 0.0071],
     [0.1620, 0.0681, 0.0655, -0.0755, -0.2404, 0.0517, -0.0479, -0.1210,
      0.1310, 0.2535, -0.1107, 0.0330],
     [0.1185, -0.0203, 0.1807, -0.1225, -0.2394, 0.0383, -0.0468, -0.0771,
      0.1557, 0.2144, -0.0754, 0.0079],
     [0.1485, -0.0095, 0.1458, -0.0414, -0.2376, 0.0539, -0.0255, -0.1200,
      0.1828, 0.2043, -0.0969, 0.0238],
     [0.1268, -0.0224, 0.1846, -0.0438, -0.2185, 0.0215, -0.0412, -0.0883,
      0.1823, 0.2223, -0.0525, 0.0223],
     [0.1268, -0.0224, 0.1846, -0.0438, -0.2185, 0.0215, -0.0412, -0.0883,
      0.1823, 0.2223, -0.0525, 0.0223]]
)
expected_dot_attention_output = torch.tensor(
    [[[0.247395, 0.028085, 0.077167, 0.078323, -0.272530, -0.039541,
       0.097621, -0.173803, 0.194980, 0.212308, -0.155935, 0.009429],
      [0.247197, 0.028359, 0.076629, 0.078957, -0.272447, -0.039580,
       0.097599, -0.173896, 0.195020, 0.212726, -0.156340, 0.009328],
      [0.247197, 0.028359, 0.076629, 0.078957, -0.272447, -0.039580,
       0.097599, -0.173896, 0.195020, 0.212726, -0.156340, 0.009328],
      [0.247205, 0.028425, 0.076465, 0.079172, -0.272451, -0.039678,
       0.097692, -0.173931, 0.195030, 0.212829, -0.156458, 0.009298],
      [0.247366, 0.028181, 0.077058, 0.078077, -0.272529, -0.039312,
       0.097431, -0.173821, 0.194880, 0.212299, -0.155915, 0.009523],
      [0.247294, 0.028266, 0.076776, 0.078823, -0.272488, -0.039640,
       0.097682, -0.173873, 0.195015, 0.212599, -0.156228, 0.009356],
      [0.247205, 0.028425, 0.076465, 0.079172, -0.272451, -0.039678,
       0.097692, -0.173931, 0.195030, 0.212829, -0.156458, 0.009298],
      [0.247267, 0.028328, 0.076711, 0.078774, -0.272479, -0.039554,
       0.097609, -0.173882, 0.194982, 0.212625, -0.156263, 0.009363],
      [0.247205, 0.028425, 0.076465, 0.079172, -0.272451, -0.039678,
       0.097692, -0.173931, 0.195030, 0.212829, -0.156458, 0.009298],
      [0.247439, 0.028163, 0.077091, 0.078058, -0.272561, -0.039385,
       0.097505, -0.173831, 0.194878, 0.212268, -0.155892, 0.009516],
      [0.247205, 0.028425, 0.076465, 0.079172, -0.272451, -0.039678,
       0.097692, -0.173931, 0.195030, 0.212829, -0.156458, 0.009298],
      [0.247439, 0.028163, 0.077091, 0.078058, -0.272561, -0.039385,
       0.097505, -0.173831, 0.194878, 0.212268, -0.155892, 0.009516],
      [0.247197, 0.028359, 0.076629, 0.078957, -0.272447, -0.039580,
       0.097599, -0.173896, 0.195020, 0.212726, -0.156340, 0.009328]],

     [[0.149018, -0.009261, 0.146324, -0.040262, -0.237542, 0.053946,
       -0.025190, -0.120178, 0.182961, 0.204468, -0.097441, 0.023849],
      [0.148795, -0.009539, 0.146771, -0.040582, -0.237508, 0.053813,
       -0.025267, -0.119906, 0.182963, 0.204425, -0.097220, 0.023744],
      [0.148841, -0.009706, 0.146875, -0.040337, -0.237519, 0.053888,
       -0.025151, -0.120024, 0.183147, 0.204257, -0.097275, 0.023751],
      [0.148969, -0.009118, 0.146230, -0.040569, -0.237560, 0.053904,
       -0.025295, -0.120064, 0.182772, 0.204599, -0.097426, 0.023824],
      [0.148969, -0.009118, 0.146230, -0.040569, -0.237560, 0.053904,
       -0.025295, -0.120064, 0.182772, 0.204599, -0.097426, 0.023824],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000]]]
)
sub_optimal_multihead_attention_output = torch.tensor(
    [[[-0.047262, 0.158078, -0.034781, -0.059588, -0.184203, 0.316856,
       -0.051797, -0.067320, 0.136281, 0.157510, -0.516869, -0.168178],
      [-0.047338, 0.158081, -0.034847, -0.059696, -0.184200, 0.316798,
       -0.051846, -0.067277, 0.136283, 0.157549, -0.516819, -0.168149],
      [-0.047338, 0.158081, -0.034847, -0.059696, -0.184200, 0.316798,
       -0.051846, -0.067277, 0.136283, 0.157549, -0.516819, -0.168149],
      [-0.047268, 0.158293, -0.035019, -0.059606, -0.184036, 0.316881,
       -0.051766, -0.067249, 0.136248, 0.157471, -0.516841, -0.168100],
      [-0.047252, 0.158166, -0.034837, -0.059572, -0.184190, 0.316919,
       -0.051814, -0.067332, 0.136266, 0.157541, -0.516920, -0.168195],
      [-0.047393, 0.158178, -0.035001, -0.059744, -0.184167, 0.316858,
       -0.051927, -0.067270, 0.136275, 0.157615, -0.516819, -0.168113],
      [-0.047268, 0.158293, -0.035019, -0.059606, -0.184036, 0.316881,
       -0.051766, -0.067249, 0.136248, 0.157471, -0.516841, -0.168100],
      [-0.047321, 0.158228, -0.034976, -0.059684, -0.184134, 0.316867,
       -0.051861, -0.067273, 0.136249, 0.157590, -0.516864, -0.168169],
      [-0.047268, 0.158293, -0.035019, -0.059606, -0.184036, 0.316881,
       -0.051766, -0.067249, 0.136248, 0.157471, -0.516841, -0.168100],
      [-0.047309, 0.158097, -0.034834, -0.059547, -0.184185, 0.316923,
       -0.051868, -0.067345, 0.136356, 0.157493, -0.516793, -0.168035],
      [-0.047268, 0.158293, -0.035019, -0.059606, -0.184036, 0.316881,
       -0.051766, -0.067249, 0.136248, 0.157471, -0.516841, -0.168100],
      [-0.047309, 0.158097, -0.034834, -0.059547, -0.184185, 0.316923,
       -0.051868, -0.067345, 0.136356, 0.157493, -0.516793, -0.168035],
      [-0.047338, 0.158081, -0.034847, -0.059696, -0.184200, 0.316798,
       -0.051846, -0.067277, 0.136283, 0.157549, -0.516819, -0.168149]],

     [[-0.065048, 0.168032, -0.084588, -0.057781, -0.217642, 0.305161,
       -0.096480, -0.093513, 0.154069, 0.215230, -0.510000, -0.149824],
      [-0.065092, 0.168042, -0.084506, -0.057821, -0.217695, 0.305299,
       -0.096637, -0.093606, 0.154123, 0.215237, -0.510100, -0.149899],
      [-0.065017, 0.168049, -0.084602, -0.057950, -0.217689, 0.305254,
       -0.096547, -0.093616, 0.154084, 0.215219, -0.510125, -0.149970],
      [-0.065047, 0.168035, -0.084817, -0.057796, -0.217683, 0.305158,
       -0.096428, -0.093492, 0.153998, 0.215284, -0.509987, -0.149702],
      [-0.065047, 0.168035, -0.084817, -0.057796, -0.217683, 0.305158,
       -0.096428, -0.093492, 0.153998, 0.215284, -0.509987, -0.149702],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000]]]
)
expected_multihead_attention_output = torch.tensor(
    [[[-0.200755, -0.260959, 0.024996, -0.159938, -0.117561, -0.155180,
       -0.118139, -0.045920, 0.259167, 0.172350, -0.112809, 0.000357],
      [-0.095821, -0.200517, -0.044585, -0.245520, -0.105037, -0.087925,
       -0.114011, -0.138533, 0.175087, 0.122279, -0.243832, -0.062268],
      [-0.095821, -0.200517, -0.044585, -0.245520, -0.105037, -0.087925,
       -0.114011, -0.138533, 0.175087, 0.122279, -0.243832, -0.062268],
      [-0.052662, -0.286963, -0.057362, -0.221366, -0.107536, -0.107964,
       -0.254695, -0.186610, 0.313417, 0.165387, -0.220242, -0.075458],
      [-0.162727, -0.270804, 0.014429, -0.124464, -0.043585, -0.150952,
       -0.160673, -0.101641, 0.258712, 0.126690, -0.086178, 0.005115],
      [-0.142106, -0.245964, -0.061551, -0.185489, -0.082154, -0.078331,
       -0.108871, -0.160405, 0.274722, 0.203109, -0.137307, -0.037420],
      [-0.052662, -0.286963, -0.057362, -0.221366, -0.107536, -0.107964,
       -0.254695, -0.186610, 0.313417, 0.165387, -0.220242, -0.075458],
      [-0.163393, -0.247003, -0.065559, -0.170543, -0.109216, -0.102695,
       -0.102142, -0.125409, 0.295476, 0.250740, -0.140760, -0.024756],
      [-0.052662, -0.286963, -0.057362, -0.221366, -0.107536, -0.107964,
       -0.254695, -0.186610, 0.313417, 0.165387, -0.220242, -0.075458],
      [-0.177342, -0.279195, 0.035479, -0.132536, -0.025988, -0.143513,
       -0.184524, -0.091127, 0.269926, 0.090618, -0.080655, 0.007161],
      [-0.052662, -0.286963, -0.057362, -0.221366, -0.107536, -0.107964,
       -0.254695, -0.186610, 0.313417, 0.165387, -0.220242, -0.075458],
      [-0.177342, -0.279195, 0.035479, -0.132536, -0.025988, -0.143513,
       -0.184524, -0.091127, 0.269926, 0.090618, -0.080655, 0.007161],
      [-0.095821, -0.200517, -0.044585, -0.245520, -0.105037, -0.087925,
       -0.114011, -0.138533, 0.175087, 0.122279, -0.243832, -0.062268]],

     [[-0.095821, -0.200517, -0.044585, -0.245520, -0.105037, -0.087925,
       -0.114011, -0.138533, 0.175087, 0.122279, -0.243832, -0.062268],
      [-0.178757, -0.240222, -0.038666, -0.193246, -0.126989, -0.105094,
       -0.065300, -0.063201, 0.250495, 0.208277, -0.178849, -0.007677],
      [-0.150638, -0.238185, -0.068680, -0.149003, -0.092437, -0.140781,
       -0.076493, -0.075186, 0.247676, 0.200479, -0.141489, 0.012234],
      [-0.163393, -0.247003, -0.065559, -0.170543, -0.109216, -0.102695,
       -0.102142, -0.125409, 0.295476, 0.250740, -0.140760, -0.024756],
      [-0.163393, -0.247003, -0.065559, -0.170543, -0.109216, -0.102695,
       -0.102142, -0.125409, 0.295476, 0.250740, -0.140760, -0.024756],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000],
      [0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000,
       0.000000, 0.000000, 0.000000, 0.000000, 0.000000, 0.000000]]]
)

# Code

## Trainer

In [16]:
import copy
import numpy as np
from tqdm.autonotebook import tqdm
from dgl.dataloading import GraphDataLoader
from torchmetrics import Metric
from dgl.data import Subset
from torch import nn
from typing import Type
from typing import Dict, Any
from pathlib import Path
from abc import ABC, abstractmethod
# from lab.checker import expected_mean_readout, expected_gin_layer_output, expected_sage_layer_output, \
#     expected_attention_readout, expected_gine_layer_output, expected_sum_readout, expected_simple_mpnn_output


class LoggerBase(ABC):
    def __init__(self, logdir: str | Path):
        self.logdir = Path(logdir)
        self.logdir.mkdir(parents=True, exist_ok=True)

    @abstractmethod
    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        ...

    @abstractmethod
    def close(self):
        ...


class DummyLogger(LoggerBase):  # If you don't want to use any logger, you can use this one
    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        pass

    def close(self):
        pass

    def restart(self):
        pass


class MetricList:
    def __init__(self, metrics: Dict[str, Metric]):
        self.metrics = copy.deepcopy(metrics)

    def update(self, preds: torch.Tensor, targets: torch.Tensor) -> None:
        for name, metric in self.metrics.items():
            metric.update(preds.detach().cpu(), targets.cpu())

    def compute(self) -> Dict[str, float]:
        metrics = {}
        for name, metric_fn in self.metrics.items():
            metrics[name] = metric_fn.compute().item()
            metric_fn.reset()
        return metrics


class Trainer:
    def __init__(
            self,
            *,
            run_dir: str | Path,
            train_dataset: Subset,
            valid_dataset: Subset,
            train_metrics: Dict[str, Metric],
            valid_metrics: Dict[str, Metric],
            model: nn.Module,
            logger: LoggerBase,
            optimizer_kwargs: Dict[str, Any],
            optimizer_cls: Type[torch.optim.Optimizer] = torch.optim.Adam,
            n_epochs: int,
            train_batch_size: int = 32,
            valid_batch_size: int = 16,
            device: str = "cuda",
            valid_every_n_epochs: int = 1,
            loss_fn=nn.MSELoss()
    ):
        self.run_dir = Path(run_dir)
        self.train_loader = GraphDataLoader(
            dataset=train_dataset,
            batch_size=train_batch_size,
            shuffle=True,
        )
        self.valid_loader = GraphDataLoader(
            dataset=valid_dataset,
            batch_size=valid_batch_size,
            shuffle=True,
        )
        self.train_metrics = MetricList(train_metrics)
        self.valid_metrics = MetricList(valid_metrics)
        self.logger = logger
        self.model = model
        self.optimizer = optimizer_cls(model.parameters(), **optimizer_kwargs)
        self.n_epochs = n_epochs
        self.device = device
        self.valid_every_n_epochs = valid_every_n_epochs
        self.loss_fn = loss_fn
        self.model.to(device)

    @torch.no_grad()
    def validate(self, dataloader: GraphDataLoader, prefix: str) -> Dict[str, float]:
        previous_mode = self.model.training
        self.model.eval()
        losses = []
        for _, graphs, labels in dataloader:
            graphs = graphs.to(self.device)
            labels = labels.to(self.device)
            preds = self.model(graphs)
            loss = self.loss_fn(preds, labels)
            losses.append(loss.item())
            self.valid_metrics.update(preds, labels)
        self.model.train(mode=previous_mode)
        metrics = {"loss": np.mean(losses)} | self.valid_metrics.compute()
        self.logger.log_metrics(metrics=metrics, prefix=prefix)
        return metrics

    def train(self) -> Dict[str, float]:
        self.model.train()
        valid_metrics = {}
        for epoch in tqdm(range(self.n_epochs), total=self.n_epochs):
            for _, graphs, labels in self.train_loader:
                self.optimizer.zero_grad()
                graphs = graphs.to(self.device)
                labels = labels.to(self.device)
                preds = self.model(graphs)
                loss = self.loss_fn(preds, labels)
                loss.backward()
                self.optimizer.step()

                self.train_metrics.update(preds, labels)
                train_metrics = {"loss": loss.item()} | self.train_metrics.compute()
                self.logger.log_metrics(metrics=train_metrics, prefix="train")

                if epoch % self.valid_every_n_epochs == 0 or epoch == self.n_epochs - 1:
                    valid_metrics = self.validate(self.valid_loader, prefix="valid")

        return valid_metrics

    def test(self, dataset: Subset) -> Dict[str, float]:
        dataloader = GraphDataLoader(
            dataset=dataset,
            batch_size=16,
            shuffle=False,
        )
        return self.validate(dataloader, prefix="test")

    def close(self):  # close the logger, not really required for wandb
        self.logger.close()

  from tqdm.autonotebook import tqdm


# Graph Neural Networks (GNNs)
The high-level Graph Neural Network architecture we are going to use looks roughly like this:

<img src="https://github.com/MariiaSaltykova/machine_learning/blob/main/resources/gnn.png?raw=1" width="1200" />

- The Featurizer takes a molecule and transforms it to a graph with node and edge features (it happens at the level of dataset, so we don't really need to worry about that).
- In our case, we will linearly embed the node and edge features to the hidden size before applying first MPNN layer which is not captured in the diagram.
- The MPNN layer takes node (and possibly edge embeddings) and the graph structure and returns updated node embeddings. It happens in a loop.
- Then the node embeddings are aggregated by the Readout layer to obtain a graph embeddings.
- Finally, the graph embeddings are passed to the MLP to obtain the final prediction.

In [17]:
class MPNNLayerBase(ABC, nn.Module):
    def _init(self, hidden_size: int):
        """
        Attributes:
            hidden_size: the size of node (and edges) embeddings
        """
        super().__init__()
        self.hidden_size = hidden_size

    @abstractmethod
    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        ...


class ReadoutBase(nn.Module):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size

    @abstractmethod
    def forward(self,
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Attributes:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape.[batch_size, hidden_size]
        """
        ...


class GNN(nn.Module):
    def __init__(self,
                 node_features_size: int,
                 edge_features_size: int,
                 hidden_size: int,
                 output_size: int,
                 mpnn_layer_cls: Type[MPNNLayerBase],
                 mpnn_layer_kwargs: Dict[str, Any],
                 mpnn_n_layers: int,
                 readout_cls: Type[ReadoutBase]):
        """
        Arguments:
            node_features_size: the size of node features
            edge_features_size: the size of edge features
            hidden_size: the size of node (and edge) embeddings
            output_size: the size of the final prediction
            mpnn_layer_cls: the class of MPNN layer
            mpnn_layer_kwargs: the kwargs for the MPNN layer
            mpnn_n_layers: the number of MPNN layers
            readout_cls: the class of Readout layer
        """
        super().__init__()
        self.linear_node = nn.Linear(node_features_size, hidden_size)
        self.linear_edge = nn.Linear(edge_features_size, hidden_size)
        self.mpnn_layers = nn.ModuleList([
            mpnn_layer_cls(hidden_size=hidden_size, **mpnn_layer_kwargs)
            for _ in range(mpnn_n_layers)
        ])
        self.readout = readout_cls(hidden_size=hidden_size)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, output_size),
        )

    def forward(self, graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            graph: a DGLGraph that contains the graph structure and node/edge features in a sparse format
        Returns:
            predictions: the final predictions
        """
        node_embeddings, edge_embeddings = graph.ndata['h'], graph.edata['e']
        node_embeddings = self.linear_node(node_embeddings)
        edge_embeddings = self.linear_edge(
            edge_embeddings)  # some of the models does not use edge features, but we won't use if-clauses for convenience.
        for layer in self.mpnn_layers:
            node_embeddings = layer(node_embeddings=node_embeddings, edge_embeddings=edge_embeddings, graph=graph)
        graph_embedding = self.readout(node_embeddings, graph)
        predictions = self.mlp(graph_embedding)
        return predictions

## Readout
Readout operation is used to aggregate node embeddings to obtain a graph embedding. There are many different readout operations, but the most popular are: sum, mean, attention, and max. We are going to implement the first three of them. Summing over nodes' embeddings seems trivial, but they're stored in a sparse format, meaning that all the nodes form all the graphs in a batch are stored in a one tensor of size `[num_nodes_1 + num_nodes_2 + ... + num_nodes_N, hidden_size]':   

In [18]:
batched_graph = dgl.batch([dataset[0][1], dataset[1][1], dataset[2][1]])
linear = nn.Linear(node_featurizer.feat_size(), 16)
node_embeddings = linear(batched_graph.ndata['h'])
node_embeddings.shape, batched_graph.batch_num_nodes()

(torch.Size([23, 16]), tensor([13,  5,  5]))

In [19]:
!pip install torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.4.0-py3-none-any.whl (1.0 MB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/1.0 MB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/1.0 MB[0m [31m4.4 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━[0m [32m0.7/1.0 MB[0m [31m10.6 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.0/1.0 MB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: torch_geometric
Successfully installed torch_geometric-2.4.0


For simplicity, we will convert the sparse node embeddings to a dense format with padding. Then the shape of the node embeddings will be `[batch_size, max_num_nodes, hidden_size]`. We can use the `to_dense_batch` function from `torch_geometric` for that:

In [20]:
from typing import Tuple
from torch_geometric.utils import to_dense_batch


def to_dense_embeddings(node_embeddings: torch.Tensor,
                        graph: dgl.DGLGraph,
                        fill_value: float = 0.0) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Converts sparse node embeddings to dense node embeddings with padding.
    Arguments:
        node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        graph: a batch of graphs
        fill_value: a value to fill the padding with
    Returns:
        node_embeddings: node embeddings in a dense format, i.e. [batch_size, max_num_nodes, hidden_size]
        mask: a mask indicating which nodes are real and which are padding, i.e. [batch_size, max_num_nodes]
    """
    num_nodes = graph.batch_num_nodes() # e.g. [2, 3, 3]
    # print(num_nodes)
    indices = torch.arange(len(num_nodes), device=num_nodes.device)
    batch = torch.repeat_interleave(indices, num_nodes).long() # e.g. [0, 0, 1, 1, 1, 2, 2, 2]
    # print(batch)
    return to_dense_batch(node_embeddings, batch,
                          fill_value=fill_value)  # that's the only reason we have torch_geometric in the requirements


def to_sparse_embeddings(node_embeddings: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Converts dense node embeddings to sparse node embeddings.
    Arguments:
        node_embeddings: node embeddings in a dense format, i.e. [batch_size, max_num_nodes, hidden_size]
        mask: a mask indicating which nodes are real and which are padding, i.e. [batch_size, max_num_nodes]
    Returns:
        node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
    """
    return node_embeddings[mask]

Now, we can simply convert the node embeddings to a dense format and sum them $x = \sum_i^n x_i$:

In [21]:
class SumReadout(ReadoutBase):
    def forward(self,
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Attributes:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape.[batch_size, hidden_size]
        """
        # We can also use dgl.sum_nodes function, but let assume it's forbidden in that notebook ;)
        node_embeddings, _ = to_dense_embeddings(node_embeddings, graph)
        return node_embeddings.sum(dim=1)

In [22]:
def test_readout(readout_cls: Type[ReadoutBase], expected_output: torch.Tensor):
    torch.manual_seed(0)
    graph = dgl.batch([dataset[0][1], dataset[1][1], dataset[2][1]])

    linear = nn.Linear(node_featurizer.feat_size(), 16)
    node_embeddings = linear(graph.ndata['h'])
    readout = readout_cls(hidden_size=16)
    result = readout(node_embeddings, graph)
    print(torch.allclose(result, expected_output, atol=1e-3))

In [23]:

test_readout(SumReadout, expected_sum_readout)

True


In [24]:
from torch_geometric.nn import global_mean_pool

### Task 1. Implement mean readout (1 point).
Implement the mean readout given by formula $x = \frac{1}{n}\sum_i^n x_i$:

In [25]:
class MeanReadout(ReadoutBase):
    def forward(self,
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Attributes:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape.[batch_size, hidden_size]
        """
        # Don't use any dlg functions here


        node_embeddings, mask = to_dense_embeddings(node_embeddings, graph)
        # res_custom = node_embeddings.mean(dim=1)
        # e.g. [0, 0, 1, 1, 1, 2, 2, 2]        # res_torch = global_mean_pool(node_embeddings, batch=None)

        # print(node_embeddings.size())#3,13,16
        masked_tensor = node_embeddings * mask.unsqueeze(2)

# Compute the sum along the specified axis and divide by the number of True values in the mask
        masked_sum = torch.sum(masked_tensor, dim=1)
        mask_count = torch.sum(mask, dim=1, keepdim=True)
        masked_mean = masked_sum / mask_count

        # return global_mean_pool(node_embeddings, batch=None)
        return masked_mean

test_readout(MeanReadout, expected_mean_readout)


True


### Task 2. Implement attention readout (2 points).
Implement the attention readout given by formula $x = \sum_i^n \frac{\exp(score_i))}{\sum_j^n \exp(score_j)}x_i$, where $score_i=score\_mlp(x_i)$:

In [26]:
class AttentionReadout(ReadoutBase):
    def __init__(self, hidden_size: int):
        super().__init__(hidden_size)
        self.score_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, 1),
        )

    def forward(self,
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Attributes:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape.[batch_size, hidden_size]
        """
        node_embeddings, mask = to_dense_embeddings(node_embeddings, graph)
        masked_node_embeddings = node_embeddings * mask.unsqueeze(2)

# Calculate the exponential scores only for True values in the mask
        exp_scores = torch.exp(self.score_mlp(masked_node_embeddings))

        # Create a new mask to handle the False values in the original mask

        # Calculate the softmax probabilities using the new mask
        softmax_probs = exp_scores / torch.sum(torch.exp(self.score_mlp(masked_node_embeddings)), dim=1, keepdim=True)

        # Calculate the masked mean using softmax probabilities and the new mask
        masked_mean = torch.sum(softmax_probs * masked_node_embeddings, dim=1)

        # result = torch.sum(((torch.exp(self.score_mlp(node_embeddings),)) / (torch.sum(torch.exp(self.score_mlp(node_embeddings)), dim=1, keepdim=True))) * node_embeddings, dim=1)
        return masked_mean

test_readout(AttentionReadout, expected_attention_readout)
# expected_attention_readout = torch.Tensor(
#     [[-0.0083, 0.0499, -0.2197, 0.1380, 0.1921, -0.3753, -0.0669, 0.0771,
#       -0.0592, 0.0411, -0.1317, -0.0769, -0.0299, 0.0074, -0.0568, -0.1741],
#      [-0.0068, -0.0131, -0.3102, 0.1656, 0.0736, -0.3312, 0.0410, 0.0670,
#       0.0485, 0.2635, -0.0479, 0.0544, -0.1933, -0.0162, 0.0508, 0.0150],
#      [-0.1911, 0.0308, -0.1514, 0.1271, 0.1100, -0.4542, -0.0658, 0.1376,
#       -0.0215, 0.0099, -0.1723, 0.0164, -0.0618, -0.0209, -0.0217, -0.0817]],
# )

#my:
# tensor([[-0.0083,  0.0499, -0.2197,  0.1380,  0.1921, -0.3753, -0.0669,  0.0771,
#          -0.0592,  0.0411, -0.1317, -0.0769, -0.0299,  0.0074, -0.0568, -0.1741],
#         [-0.0025, -0.0049, -0.1165,  0.0622,  0.0276, -0.1244,  0.0154,  0.0252,
#           0.0182,  0.0990, -0.0180,  0.0204, -0.0726, -0.0061,  0.0191,  0.0056],
#         [-0.0728,  0.0117, -0.0576,  0.0484,  0.0419, -0.1729, -0.0250,  0.0524,
#          -0.0082,  0.0038, -0.0656,  0.0062, -0.0235, -0.0080, -0.0083, -0.0311]],
#        grad_fn=<SumBackward1>)

False


## Message Passing Neural Networks (MPNNs)
Message Passing is given by formula:
$$
x'_i=\rho(x_i, \square_{j\in N(i)} \psi(x_j, x_i, e_{ji})),
$$
where $\psi$ is learnable message function, $\rho$ is learnable update, and $\square$ is aggregation function.

### Simple MPNN
For instance, we can define a very simple MPNN layer by the following formula:
$$
x'_i=W_1x_i + W_2\sum_{j\in N(i)} W_3x_j,
$$
where W_i are linear layers with implicit bias term (we will make the bias implicit in every formula in that notebook). Let us implement this simple MPNN:

In [27]:
class SimpleMPNNLayer(MPNNLayerBase):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        # graph is bi-directed, so we can freely swap the "start" and "end" meanings
        start_nodes, end_nodes = graph.edges(order='srcdst') # using this `order` value sorts the `start_nodes`
        messages = self.linear_3(node_embeddings[end_nodes]) # W_3x_j
        message_dense, _ = to_dense_batch(messages, start_nodes.long(), fill_value=0.0) # to make the life easier, we convert the node embeddings to dense representation
        aggregated_message = message_dense.sum(dim=1) # \sum_{j\in N(i)} W_3x_j
        aggregated_message = self.linear_2(aggregated_message) # W_2\sum_{j\in N(i)} W_3x_j
        node_embeddings = self.linear_1(node_embeddings) + aggregated_message # W_1x_i + W_2\sum_{j\in N(i)} W_3x_j
        return node_embeddings

In [28]:
def test_mpnn_layer(mpnn_layer_cls: Type[MPNNLayerBase], expected_output: torch.Tensor):
    torch.manual_seed(0)
    graph = dgl.batch([dataset[0][1], dataset[1][1]])
    linear_nodes = nn.Linear(node_featurizer.feat_size(), 4)
    linear_edges = nn.Linear(edge_featurizer.feat_size(), 4)
    node_embeddings = linear_nodes(graph.ndata['h'])
    edge_embeddings = linear_edges(graph.edata['e'])
    layer = mpnn_layer_cls(hidden_size=4)
    result = layer(node_embeddings, edge_embeddings, graph)
    assert torch.allclose(result, expected_output, atol=1e-3)

In [29]:
test_mpnn_layer(SimpleMPNNLayer, expected_simple_mpnn_output)

In [30]:
import dgl.function as fn

### Task 3. Implement GraphSAGE layer (2 points).
Implement a GraphSAGE given by the following formula:
$$
x'_i=W_1x_i + W_2\frac{1}{deg(i)}\sum_{j\in N(i)} x_j,
$$
where $deg(i) = #N(i)$ is the number of neighbors of node $i$.

In [31]:
class SAGELayer(MPNNLayerBase):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        with graph.local_scope():
            # Message Passing: Aggregate messages from neighbors
            graph.ndata['h'] = node_embeddings
            graph.update_all(message_func=self.message_func, reduce_func=self.reduce_func)

            # Calculate the degree of each node
            deg = graph.in_degrees().float().clamp(min=1)

            # Normalize aggregated messages by the degree of each node
            graph.ndata['h_neigh'] = graph.ndata['h_neigh'] / deg.unsqueeze(-1)

            # Perform linear transformations
            h_linear_1 = self.linear_1(node_embeddings)
            h_neigh_linear_2 = self.linear_2(graph.ndata['h_neigh'])

            # Sum the two sets of transformed embeddings
            h_combined = h_linear_1 + h_neigh_linear_2
            h_combined = self.dropout(h_combined)
            return h_combined
    def message_func(self, edges):
        return {'m': edges.src['h']}

    def reduce_func(self, nodes):
        return {'h_neigh': torch.sum(nodes.mailbox['m'], dim=1)}

test_mpnn_layer(SAGELayer, expected_sage_layer_output)

    # [[-5.0965e-01, -4.5482e-01, -8.1451e-01, 5.4286e-03],
    #  [-5.6737e-01, -5.9137e-01, -7.9304e-01, 7.5955e-02],
    #  [-4.6768e-01, -5.0346e-01, -7.2765e-01, 5.0357e-02],
    #  [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
    #  [-5.0465e-01, -3.5816e-01, -8.7864e-01, -3.1902e-02],
    #  [-5.6591e-01, -4.2403e-01, -8.7506e-01, 2.9357e-02],
    #  [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
    #  [-5.7196e-01, -3.5674e-01, -9.4769e-01, -4.9931e-03],
    #  [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
    #  [-5.2655e-01, -5.1094e-01, -8.3806e-01, -1.8521e-02],
    #  [-6.4185e-01, -5.0983e-01, -8.6305e-01, 1.3008e-02],
    #  [-5.7628e-01, -5.5394e-01, -8.7300e-01, -7.6976e-03],
    #  [-4.6768e-01, -5.0346e-01, -7.2765e-01, 5.0357e-02],
    #  [-5.4808e-01, -5.3204e-01, -7.8906e-01, 4.2878e-02],
    #  [-5.3417e-01, -3.5912e-01, -9.5030e-01, 2.3648e-05],
    #  [-6.2538e-01, -2.9249e-01, -1.1233e+00, 1.0970e-01],
    #  [-6.5214e-01, -3.8342e-01, -1.0136e+00, -1.6424e-02],
    #  [-6.5214e-01, -3.8342e-01, -1.0136e+00, -1.6424e-02]],

AssertionError: ignored

### Task 4. Implement GIN layer (2 points).
Implement a GIN layer given by the following formula:
$$
x'_i=mlp((1 + \epsilon)x_i + \sum_{j\in N(i)} x_j).
$$

In [None]:
class GINLayer(MPNNLayerBase):
    def __init__(self, hidden_size: int, eps: float = 0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )
    def message_func(self, edges):
      return {'m': edges.src['h']}

    def reduce_func(self, nodes):
      return {'h_neigh': torch.sum(nodes.mailbox['m'], dim=1)}

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """

        graph.ndata['h'] = node_embeddings
        graph.update_all(message_func=self.message_func, reduce_func=self.reduce_func)
        # Calculate the degree of each node
        # Normalize aggregated messages by the degree of each node
        graph.ndata['h_neigh'] = graph.ndata['h_neigh']


        result = self.mlp((1+self.eps) * node_embeddings + graph.ndata['h_neigh'])

        return result
    # [[-0.4516, -0.3673, -0.5313, 0.3170],
    #  [-0.4524, -0.3760, -0.5243, 0.3249],
    #  [-0.4570, -0.3747, -0.5313, 0.3221],
    #  [-0.4763, -0.4030, -0.5390, 0.3335],
    #  [-0.4481, -0.3855, -0.5187, 0.3295],
    #  [-0.4545, -0.3838, -0.5245, 0.3276],
    #  [-0.4763, -0.4030, -0.5390, 0.3335],
    #  [-0.4390, -0.4001, -0.4973, 0.3446],
    #  [-0.4763, -0.4030, -0.5390, 0.3335],
    #  [-0.4683, -0.3882, -0.5400, 0.3248],
    #  [-0.4763, -0.4030, -0.5390, 0.3335],
    #  [-0.4682, -0.3921, -0.5374, 0.3277],
    #  [-0.4570, -0.3747, -0.5313, 0.3221],
    #  [-0.4225, -0.3671, -0.4928, 0.3295],
    #  [-0.3760, -0.3700, -0.4407, 0.3489],
    #  [-0.2646, -0.3342, -0.3357, 0.3683],
    #  [-0.3859, -0.3950, -0.4392, 0.3624],
    #  [-0.3859, -0.3950, -0.4392, 0.3624]],
test_mpnn_layer(GINLayer, expected_gin_layer_output)

### Task 5. Implement GINE layer (2 points).
Implement a GINE layer given by the following formula:
$$
x'_i=mlp((1 + \epsilon)x_i + \sum_{j\in N(i)} ReLU(x_j + e_{ji})).
$$

In [None]:
class GINELayer(MPNNLayerBase):
    def __init__(self, hidden_size: int, eps: float = 0.0):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.relu = nn.ReLU()
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        start_nodes, end_nodes, edge_ids = graph.edges(order='srcdst', form='all')

        end_nodes = end_nodes.long()
        start_nodes = start_nodes.long()
        # Compute neighbors for edge embeddings

        neighbor_edge_embeddings = edge_embeddings[edge_ids]

        # Compute the sum of neighbor edge embeddings and node embeddings
        sum_neighbor_embeddings = torch.zeros_like(node_embeddings)
        sum_neighbor_embeddings.scatter_add_(0, end_nodes.unsqueeze(1).expand_as(neighbor_edge_embeddings),
                                            self.relu(node_embeddings[end_nodes] + neighbor_edge_embeddings[end_nodes]))
        # sum_neighbor_embeddings.scatter_add_(0, start_nodes.unsqueeze(1).expand_as(neighbor_edge_embeddings),
                                            # self.relu(node_embeddings[start_nodes] + neighbor_edge_embeddings[start_nodes]))
        gine_update = self.mlp((1 + self.eps) * node_embeddings + sum_neighbor_embeddings)

        return gine_update

test_mpnn_layer(GINELayer, expected_gine_layer_output)
    # [[-0.4519, -0.3654, -0.5197, 0.3193],
    #  [-0.4577, -0.3681, -0.5309, 0.3200],
    #  [-0.4617, -0.3697, -0.5356, 0.3193],
    #  [-0.4318, -0.3586, -0.5039, 0.3215],
    #  [-0.3675, -0.3206, -0.4476, 0.3215],
    #  [-0.4474, -0.3725, -0.5134, 0.3252],
    #  [-0.4318, -0.3586, -0.5039, 0.3215],
    #  [-0.4617, -0.3816, -0.5311, 0.3244],
    #  [-0.4318, -0.3586, -0.5039, 0.3215],
    #  [-0.3174, -0.2810, -0.4102, 0.3140],
    #  [-0.4318, -0.3586, -0.5039, 0.3215],
    #  [-0.3173, -0.2847, -0.4078, 0.3168],
    #  [-0.4617, -0.3697, -0.5356, 0.3193],
    #  [-0.4367, -0.3529, -0.5122, 0.3167],
    #  [-0.4103, -0.3570, -0.4806, 0.3282],
    #  [-0.4105, -0.3539, -0.4767, 0.3282],
    #  [-0.4575, -0.3899, -0.5207, 0.3318],
    #  [-0.4575, -0.3899, -0.5207, 0.3318]]

# Experiments

## Logger
We are going to use [wandb](https://wandb.ai/site) for logging. It's a very convenient tool for logging and visualizing the training process. It's free for academic use, so you can create an account and use it for your projects. If you don't want to use wandb, you can use any other online logger (like [comet.ml](https://www.comet.ml/site/)), but you need to implement the appropriate LoggerBase subclass on your own. To setup and use wandb, you need to do the following:
1. [Setup the wandb](https://docs.wandb.ai/quickstart) (or any other online logger).
2. Give your supervisor access to your project (ask him/her about the username.
3. Use the logger for all your trainings and provide the links to the final runs.

In [None]:
class WandbLogger(LoggerBase):
    def __init__(
            self, logdir: str | Path, project_name: str, experiment_name: str, **kwargs: Dict[str, Any]
    ):
        super().__init__(logdir)
        import wandb
        self.project_name = project_name
        self.experiment_name = experiment_name
        self.kwargs = kwargs
        self.run = wandb.init(
            dir=self.logdir,
            project=self.project_name,
            name=self.experiment_name,
            **self.kwargs,
        )

    def log_metrics(self, metrics: Dict[str, Any], prefix: str):
        metrics = {f"{prefix}/{k}": v for k, v in metrics.items()}
        self.run.log(metrics)

    def close(self):
        self.run.finish()

## Task 6. Train GraphSAGE (2 points).
1. Tune hyperparameters of a GNN with `SAGELayer` as MPNN layer to obtain at most 2.0 MAE on the validation set. You can modify the GNN/MPNN architecture, so it uses some regularization tricks like dropout or batch norm. Don't change the validation batch size. If your validation MAE is in (2.0, 2.5], you can obtain 1 point.
2. Report the obtained MAE on the validation and test set (only the former need to be lower than 2.0 MAE).
3. Provide the link to the final run: https://wandb.ai/mariia-saltykova-work/mldd23/runs/u56z6kqq?workspace=user-mariia-saltykova-work

In [None]:
### Example code for training. You can modify it for easier grid-searching.

In [57]:
from datetime import datetime


def get_time_stamp() -> str:
    return datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

In [60]:
from torchmetrics import MeanAbsoluteError as MAE
from torchmetrics import MeanSquaredError as MSE
from torchmetrics import PearsonCorrCoef as PCC

metrics = {
    "mae": MAE(),
    "mse": MSE(),
    "pcc": PCC(),
}

model = GNN(
    node_features_size=node_featurizer.feat_size(),
    edge_features_size=edge_featurizer.feat_size(),
    hidden_size=256,
    output_size=1,
    mpnn_layer_cls=SAGELayer,
    mpnn_n_layers=6,
    readout_cls=AttentionReadout,
    mpnn_layer_kwargs={}
)

trainer = Trainer(
    run_dir="experiments",
    train_dataset=train,
    valid_dataset=valid,
    train_metrics=metrics,
    valid_metrics=metrics,
    train_batch_size=32,
    model=model,
    logger=WandbLogger(
        logdir="runs/mpnn",
        project_name="mldd23",
        experiment_name=f"sage_{get_time_stamp()}",
    ),
    optimizer_kwargs={"lr": 1e-4},
    n_epochs=50,
    device="cpu",
    valid_every_n_epochs=1,
)

valid_metrics = trainer.train()
test_metrics = trainer.test(test)
trainer.close()
print(f"Validation metrics: {valid_metrics}")
print(f"Test metrics: {test_metrics}")

  0%|          | 0/50 [00:00<?, ?it/s]



VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\r'), FloatProgress(value=0.1089862920968015, max=1.0…

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▆▃▃▃▆▁▂▃▃▃▂▃▃▂▃▂▂▂▂▅▂▃▂▂▁▁▂▂▂▂▂▂▃▂▂▂▂▃▂
train/mae,█▆▅▄▅▅▃▄▄▄▄▃▄▄▄▅▃▃▃▃▄▃▄▃▃▁▂▃▃▃▂▃▃▄▃▃▃▃▄▂
train/mse,█▆▃▃▃▆▁▂▃▃▃▂▃▃▂▃▂▂▂▂▅▂▃▂▂▁▁▂▂▂▂▂▂▃▂▂▂▂▃▂
train/pcc,▁▂▃▅▅▃ ▅▆▄▆█▄▇▆▅▇██▆▇▇▆▆▆ █▇▆██▇▇▇▇▇▇█▆▇
valid/loss,█▃▃▃▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▃▂▂▂▂▁▂▂▂▁▁▂▂▁▁▁▂▁
valid/mae,█▄▄▄▄▄▃▃▃▃▃▃▃▂▂▂▂▃▂▂▁▂▃▂▁▁▂▁▂▁▂▂▁▂▂▁▁▁▂▁

0,1
test/loss,6.82746
test/mae,2.08767
test/mse,8.16321
test/pcc,0.73023
train/loss,2.86545
train/mae,1.69276
train/mse,2.86545
train/pcc,
valid/loss,20.65364
valid/mae,2.84491


Validation metrics: {'loss': 20.653637647628784, 'mae': 2.84490704536438, 'mse': 20.653636932373047, 'pcc': 0.803337574005127}
Test metrics: {'loss': 6.8274595737457275, 'mae': 2.0876688957214355, 'mse': 8.163209915161133, 'pcc': 0.7302309274673462}


In [142]:
from torchmetrics import MeanAbsoluteError as MAE
from torchmetrics import MeanSquaredError as MSE
from torchmetrics import PearsonCorrCoef as PCC
from itertools import product

# Define the hyperparameter grid
grid_params = {
    "hidden_size": [128, 256, 512],
    "mpnn_n_layers": [4, 6, 8],
    "learning_rate": [1e-3, 5e-5, 1e-5],
}

# Define metrics
metrics = {
    "mae": MAE(),
    "mse": MSE(),
    "pcc": PCC(),
}

# Perform grid search
for params in product(*grid_params.values()):
    hidden_size, mpnn_n_layers, learning_rate = params

    # Create model with current hyperparameters
    model = GNN(
        node_features_size=node_featurizer.feat_size(),
        edge_features_size=edge_featurizer.feat_size(),
        hidden_size=hidden_size,
        output_size=1,
        mpnn_layer_cls=SAGELayer,
        mpnn_n_layers=mpnn_n_layers,
        readout_cls=AttentionReadout,
        mpnn_layer_kwargs={}
    )

    # Create Trainer with current hyperparameters
    trainer = Trainer(
        run_dir=f"experiments/hidden_{hidden_size}_layers_{mpnn_n_layers}_lr_{learning_rate}",
        train_dataset=train,
        valid_dataset=valid,
        train_metrics=metrics,
        valid_metrics=metrics,
        train_batch_size=32,
        model=model,
        logger=WandbLogger(
            logdir=f"runs/mpnn/hidden_{hidden_size}_layers_{mpnn_n_layers}_lr_{learning_rate}",
            project_name="mldd23",
            experiment_name=f"sage_{get_time_stamp()}",
        ),
        optimizer_kwargs={"lr": learning_rate},
        n_epochs=50,
        device="cpu",
        valid_every_n_epochs=1,
    )

    # Train and test the model
    valid_metrics = trainer.train()
    test_metrics = trainer.test(test)

    # Print metrics for each hyperparameter combination
    print(f"Hyperparameters: hidden_size={hidden_size}, mpnn_n_layers={mpnn_n_layers}, learning_rate={learning_rate}")
    print(f"Validation metrics: {valid_metrics}")
    print(f"Test metrics: {test_metrics}")

    # Close the trainer
    trainer.close()

VBox(children=(Label(value='0.001 MB of 0.011 MB uploaded\r'), FloatProgress(value=0.11381108067509058, max=1.…

0,1
train/loss,▅▄▄▅▅▅▅▆▅▄▆▁▇▆▅▅▆▅▅▅▅▅▄▆▄▃▄▃▄█▄▄▃▆▄▃▃▂▃▅
train/mae,▆▅▅▆▅▆▆▇▆▅▇▁▆▇▆▆▇▆▆▆▅▆▆▇▅▄▅▄▅▇▅▅▄█▅▄▄▃▄▅
train/mse,▅▄▄▅▅▅▅▆▅▄▆▁▇▆▅▅▆▅▅▅▅▅▄▆▄▃▄▃▄█▄▄▃▆▄▃▃▂▃▅
train/pcc,▆▅▄▄█▅▄▆▆▇▄ ▅▅▄▆▇▄▅▄▇█ ▆▆▇▄▆▆▅▆▅▄ ▁▅▅▄▆▆
valid/loss,██████████████████▇▇▇▇▇▇▇▇▆▆▆▅▅▄▄▃▂▃▁▂▁▁
valid/mae,██████████████████▇▇▇▇▇▇▇▇▆▆▆▅▅▄▃▃▂▂▁▁▁▁
valid/mse,██████████████████▇▇▇▇▇▇▇▇▆▆▆▅▅▄▄▃▂▃▁▂▁▁
valid/pcc,▁▃▄▄▃▃▂▄▃▃▃▄▃▂▃▂▃▃▂▃▃▂▃▄▅▅▄▅▃▆▅▆▅▆█▅█▇▆▆

0,1
train/loss,20.82832
train/mae,3.11571
train/mse,20.82832
train/pcc,0.31257
valid/loss,34.71943
valid/mae,4.13385
valid/mse,34.71943
valid/pcc,0.38845


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=4, learning_rate=0.001
Validation metrics: {'loss': 9.983070433139801, 'mae': 1.859006404876709, 'mse': 9.983070373535156, 'pcc': 0.8783034682273865}
Test metrics: {'loss': 4.382921511679887, 'mae': 1.7279000282287598, 'mse': 5.38176965713501, 'pcc': 0.7907462120056152}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▄█▃▃▂▂▃▃▂▂▂▂▃▂▂▂▁▁▂▂▂▁▂▂▂▁▂▂▂▂▃▂▂▁▁▁▁▂▃▂
train/mae,██▆▆▅▄▇▅▅▅▄▄▆▅▄▅▃▃▄▄▄▃▄▄▄▁▄▄▅▄▅▄▃▃▃▃▃▃▄▃
train/mse,▄█▃▃▂▂▃▃▂▂▂▂▃▂▂▂▁▁▂▂▂▁▂▂▂▁▂▂▂▂▃▂▂▁▁▁▁▂▃▂
train/pcc,▁▂▅▆▅▇ ▇▆▇▇▆▆▇▇▇██▇▇▇█▇▇▇ █▇▇▇▇███████▇█
valid/loss,██▇▅▅▅▅▄▅▄▄▅▃▃▄▃▃▂▅▂▃▂▃▁▄▂▁▂▁▃▃▂▂▃▂▂▂▂▁▁
valid/mae,██▇▅▅▄▄▄▄▄▃▄▄▃▃▃▂▂▄▂▂▂▃▂▄▁▁▂▁▃▄▂▂▃▂▁▂▂▁▁

0,1
test/loss,4.38292
test/mae,1.7279
test/mse,5.38177
test/pcc,0.79075
train/loss,7.38923
train/mae,2.71831
train/mse,7.38923
train/pcc,
valid/loss,9.98307
valid/mae,1.85901


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=4, learning_rate=5e-05
Validation metrics: {'loss': 26.6795015335083, 'mae': 3.1615848541259766, 'mse': 26.679500579833984, 'pcc': 0.7990860342979431}
Test metrics: {'loss': 10.711921215057373, 'mae': 2.7622146606445312, 'mse': 13.10132884979248, 'pcc': 0.6576178073883057}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▇▇▃▅▄▁▃▃▃▄▃▅▃▄▄▄▄▃▃▃▃▃▃▂▂▄▄▂▂▃▃▄▂▂▂▃▂▂▃
train/mae,█▇█▅▆▆▁▅▅▅▆▅▇▅▆▅▅▆▅▅▅▅▅▄▄▅▅▅▄▄▅▄▅▄▃▄▄▄▄▄
train/mse,█▇▇▃▅▄▁▃▃▃▄▃▅▃▄▄▄▄▃▃▃▃▃▃▂▂▄▄▂▂▃▃▄▂▂▂▃▂▂▃
train/pcc,▆▄▁▄▃▅ ▅▅▆▅▇▄▅▄▆▆▆▆▅▇▆▇▅▆ ▆▆▇▇▇▇▇▇▇█▇▇▇▇
valid/loss,██▇▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁
valid/mae,██▇▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
test/loss,10.71192
test/mae,2.76221
test/mse,13.10133
test/pcc,0.65762
train/loss,2.53066
train/mae,1.59081
train/mse,2.53066
train/pcc,
valid/loss,26.6795
valid/mae,3.16158


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=4, learning_rate=1e-05
Validation metrics: {'loss': 35.9278130531311, 'mae': 4.158862590789795, 'mse': 35.92781448364258, 'pcc': 0.4741687774658203}
Test metrics: {'loss': 15.588681602478028, 'mae': 3.293412446975708, 'mse': 17.163393020629883, 'pcc': 0.2675379514694214}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆▆▆▃█▃▅▅▆▅▅▇▃▃▂▃▂▃▂▃▂▅▂▁▃▃▂▁▃▂▂▁▁▂▁▂▁▂▂▂
train/mae,▇█▇▅▇▃█▆▆▆▆█▄▄▃▃▃▄▃▃▂▄▃▁▃▅▃▂▃▃▂▁▁▂▂▂▁▃▃▃
train/mse,▆▆▆▃█▃▅▅▆▅▅▇▃▃▂▃▂▃▂▃▂▅▂▁▃▃▂▁▃▂▂▁▁▂▁▂▁▂▂▂
train/pcc,▆▂▇▅▄▂ ▅▂▅▅▁▅▆▁▆▄▅▁▃▆▆▄▇▅ ▃▇▅█▃██▆▅▆▃▇█▁
valid/loss,███████▇▇▇▇▆▅▄▃▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/mae,███████▇▇▇▆▆▄▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,15.58868
test/mae,3.29341
test/mse,17.16339
test/pcc,0.26754
train/loss,27.29477
train/mae,5.22444
train/mse,27.29477
train/pcc,
valid/loss,35.92781
valid/mae,4.15886


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=6, learning_rate=0.001
Validation metrics: {'loss': 11.156673192977905, 'mae': 2.0489845275878906, 'mse': 11.156673431396484, 'pcc': 0.8877020478248596}
Test metrics: {'loss': 4.658344233036042, 'mae': 1.8460578918457031, 'mse': 5.656925678253174, 'pcc': 0.7675876617431641}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▅▄▄▅▄▃▂██▁▂▅▇▁▅▁▂▂▃▁▂▁▃▂▁▁▁▂▂▃▁▂▁▁▂▁▅▁▅▁
train/mae,▇▆▆█▆▄▅▅▆▂▃█▆▂▄▂▃▂▄▂▃▂▄▃▁▃▂▂▃▅▁▃▂▂▃▂▅▂▃▂
train/mse,▅▄▄▅▄▃▂██▁▂▅▇▁▅▁▂▂▃▁▂▁▃▂▁▁▁▂▂▃▁▂▁▁▂▁▅▁▅▁
train/pcc,▁▁▂▄▃▆ ▄▅▆▅▆▄▆▆▆▆▇▄▇▇▇▅▆█ █▇▇▆▇▆█▇▇▇▇▇▆▇
valid/loss,█▆▄▅▄▄▄▅▄▄▄▂▃▃▂▄▃▃▅▃▂▂▂▂▃▁▁▁▁▅▃▂▂▂▁▁▂▁▂▁
valid/mae,█▇▆▅▅▄▄▄▄▄▃▃▃▃▂▄▄▃▆▂▂▂▂▂▃▁▁▁▂▅▃▂▃▃▂▁▂▁▂▁

0,1
test/loss,4.65834
test/mae,1.84606
test/mse,5.65693
test/pcc,0.76759
train/loss,0.1961
train/mae,0.44283
train/mse,0.1961
train/pcc,
valid/loss,11.15667
valid/mae,2.04898


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=6, learning_rate=5e-05
Validation metrics: {'loss': 27.941579818725586, 'mae': 3.39616060256958, 'mse': 27.941577911376953, 'pcc': 0.8774385452270508}
Test metrics: {'loss': 12.089959192276002, 'mae': 2.9861743450164795, 'mse': 14.448640823364258, 'pcc': 0.6419483423233032}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▇▅▃▄▂▁▂▃▃▃▃▃▃▂▃▃▂▅▂▂▂▂▂▃▂▂▂▂▂▄▂▂▂▃▃▂▃▂▅
train/mae,▇█▇▅▅▄▁▄▅▅▅▅▄▅▄▅▅▄▅▄▄▄▄▄▅▅▄▄▄▄▅▄▄▄▄▄▃▄▄▅
train/mse,█▇▅▃▄▂▁▂▃▃▃▃▃▃▂▃▃▂▅▂▂▂▂▂▃▂▂▂▂▂▄▂▂▂▃▃▂▃▂▅
train/pcc,▃▃▂▁▂▅ ▆▄▁▆▆▆▅▆▆▆▅▄▆▇▆▅█▄ ▇█▇▆▇▇█▆▆▆▇▆▇▆
valid/loss,██▆▃▃▃▃▂▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁
valid/mae,██▆▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁

0,1
test/loss,12.08996
test/mae,2.98617
test/mse,14.44864
test/pcc,0.64195
train/loss,30.92012
train/mae,5.56059
train/mse,30.92012
train/pcc,
valid/loss,27.94158
valid/mae,3.39616


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=6, learning_rate=1e-05
Validation metrics: {'loss': 35.57434916496277, 'mae': 4.143280506134033, 'mse': 35.57434844970703, 'pcc': 0.47397857904434204}
Test metrics: {'loss': 15.641404724121093, 'mae': 3.314692258834839, 'mse': 17.108966827392578, 'pcc': 0.2835009694099426}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▇▆▆▅█▆▄▇▇▄▄▃▃▂▂▄▃▃▄▅▄▃▃▃▃▁▂▃▃▃▂▄▃▃▂▃▅▄▂▃
train/mae,▇▇▇▅█▇▇▆▇▄▄▄▄▃▃▅▄▄▄▅▄▄▄▄▃▁▂▄▄▄▂▄▃▄▃▄▅▅▃▄
train/mse,▇▆▆▅█▆▄▇▇▄▄▃▃▂▂▄▃▃▄▅▄▃▃▃▃▁▂▃▃▃▂▄▃▃▂▃▅▄▂▃
train/pcc,▁▅▆▄▄▅ ▃▆▄▅▅▃▅▅▄▅▆▇▄▅▅▇▅▆ ▅▆▅▅▃▃▇▅▆█▅▅▇▅
valid/loss,██████▇▇▇▆▅▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/mae,██████▇▇▇▅▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,15.6414
test/mae,3.31469
test/mse,17.10897
test/pcc,0.2835
train/loss,2.53848
train/mae,1.59326
train/mse,2.53848
train/pcc,
valid/loss,35.57435
valid/mae,4.14328


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=8, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,██▃▄▃▃▁▂▂▂▂▃▄▂▂▂▂▂▃
train/mae,▆█▅▆▄▅▁▄▄▃▄▄▄▃▄▄▃▃▅
train/mse,██▃▄▃▃▁▂▂▂▂▃▄▂▂▂▂▂▃
train/pcc,▁▁▆▃▅▄ ▆▆▇▆▆▆█▇▇█▇▆
valid/loss,█▄▇▃▃▅▂▂▂▂▃▅▁▂▃▁▁▂▃
valid/mae,█▅▇▄▃▅▂▂▂▂▃▅▂▂▃▁▁▂▃
valid/mse,█▄▇▃▃▅▂▂▂▂▃▅▁▂▃▁▁▂▃
valid/pcc,▁▃▅▅██████▇▇█▇▇███▇

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=8, learning_rate=5e-05
Validation metrics: {'loss': 24.858879566192627, 'mae': 3.2458415031433105, 'mse': 24.8588809967041, 'pcc': 0.8689024448394775}
Test metrics: {'loss': 10.925103783607483, 'mae': 2.8592495918273926, 'mse': 13.100049018859863, 'pcc': 0.6526756882667542}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▇▃▃▃▃▅▃▂▂▇▃▂▂▁▂▂▂▁▂▂▂▂▂▂▄▁▂▂▁▂▂▁▂▁▂▂▂▂▁
train/mae,█▇▄▄▄▄█▄▃▃▆▄▃▂▂▃▃▃▂▂▃▃▂▃▃▇▁▃▃▂▂▂▂▂▂▂▂▃▂▁
train/mse,█▇▃▃▃▃▅▃▂▂▇▃▂▂▁▂▂▂▁▂▂▂▂▂▂▄▁▂▂▁▂▂▁▂▁▂▂▂▂▁
train/pcc,▃▁▄▆▁▄ ▂▅▃▂▃▄▅▄▂▆▅▆▄▄▄▅▄▅ ▇▆▇▇█▄█▇▇▅▇▆▆▇
valid/loss,██▅▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▁▂▂▂▂▂▁▁▁▁▁▂▁▁▁▁▁▁
valid/mae,██▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,10.9251
test/mae,2.85925
test/mse,13.10005
test/pcc,0.65268
train/loss,11.83854
train/mae,3.44072
train/mse,11.83854
train/pcc,
valid/loss,24.85888
valid/mae,3.24584


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=8, learning_rate=1e-05
Validation metrics: {'loss': 32.48310208320618, 'mae': 4.063084602355957, 'mse': 32.48310089111328, 'pcc': 0.49595946073532104}
Test metrics: {'loss': 14.111568832397461, 'mae': 3.0809781551361084, 'mse': 14.83796215057373, 'pcc': 0.3361538052558899}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▇▆▆▇██▂▃▃▃▃▂▃▂▃▃▂▃▃▂▃▂▃▃▃▁▆▃▂▃▃▅▂▄▂▂▂▃▃▂
train/mae,█▇▆▇██▄▅▅▄▄▃▄▄▄▄▄▄▄▃▄▃▄▄▄▁▅▄▃▄▄▅▃▅▄▃▃▄▄▃
train/mse,▇▆▆▇██▂▃▃▃▃▂▃▂▃▃▂▃▃▂▃▂▃▃▃▁▆▃▂▃▃▅▂▄▂▂▂▃▃▂
train/pcc,▁▆▅▄▅▅ ▄▅▅▃▆▇▆▅▆▅▆▇▇▇▇▆▅▅ ▆▅█▆▆▅▇▇▇▇▅▅▇▇
valid/loss,█████▇▇▆▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▁▁▁▁▁▁▁▁
valid/mae,█████▇▇▆▄▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,14.11157
test/mae,3.08098
test/mse,14.83796
test/pcc,0.33615
train/loss,5.61942
train/mae,2.37053
train/mse,5.61942
train/pcc,
valid/loss,32.4831
valid/mae,4.06308


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=4, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▅▆▃▃▂▁▃██▃▃▂▃▃▂▁▂▄▂▁▂▂▂▃▁▃▂▃▂
train/mae,█▆▆▄▄▂▁▃▆▅▂▄▂▄▄▃▁▃▄▁▁▂▁▁▃▁▃▃▄▃
train/mse,█▅▆▃▃▂▁▃██▃▃▂▃▃▂▁▂▄▂▁▂▂▂▃▁▃▂▃▂
train/pcc,▁▃▅▆▅▆ ▆▆▄▆▆▇▇▇▇█▆▆▇██▇█▇ ▆▆▆▆
valid/loss,█▄▄▄▄▄▃▃▅▂▃▄▃▃▃▄▂▂▃▂▂▁▂▂▂▄▄▃▂▃
valid/mae,█▅▅▄▄▄▄▄▆▂▃▅▃▄▃▄▂▃▄▃▂▁▂▂▂▆▅▃▃▄
valid/mse,█▄▄▄▄▄▃▃▅▂▃▄▃▃▃▄▂▂▃▂▂▁▂▂▂▄▄▃▂▃
valid/pcc,▁▄▆▆▅▆▆▇▆▇▇▇▇▇▇▇▇▇▇▇█▇███▇▇▇▇▇

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=4, learning_rate=5e-05
Validation metrics: {'loss': 25.837578892707825, 'mae': 3.0757880210876465, 'mse': 25.83757781982422, 'pcc': 0.7785635590553284}
Test metrics: {'loss': 10.295427074935287, 'mae': 2.746832847595215, 'mse': 12.67025089263916, 'pcc': 0.6826815605163574}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▃▃▂▅▃▆▂▂▃▂▃▂▂▂▂▂▂▂▃▂▂▂▂▁▁▂▂▁▁▂▂▁▁▂▂▁▁▂▁
train/mae,▆▄▃▃▄▃█▂▃▄▂▃▂▃▂▂▃▂▂▂▂▂▂▃▂▃▂▂▂▂▂▂▁▁▂▃▁▁▂▁
train/mse,█▃▃▂▅▃▆▂▂▃▂▃▂▂▂▂▂▂▂▃▂▂▂▂▁▁▂▂▁▁▂▂▁▁▂▂▁▁▂▁
train/pcc,▃▃▂▅▁▃ ▄▂▄▄▄▅▅▆▆▆▆▆▆▆▆▆▆█ ▆▇▇▇▆▇█▇▆▇▇█▇▇
valid/loss,█▇▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▁▂▂▂▂▁▂▁▁▂▁▁▂▂▂▁▂▂▂▂
valid/mae,█▇▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▁▂▂▂▂▁▂▁▁▂▁▁▁▁▂▁▁▁▁▁

0,1
test/loss,10.29543
test/mae,2.74683
test/mse,12.67025
test/pcc,0.68268
train/loss,1.14666
train/mae,1.07082
train/mse,1.14666
train/pcc,
valid/loss,25.83758
valid/mae,3.07579


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=4, learning_rate=1e-05
Validation metrics: {'loss': 31.91602873802185, 'mae': 3.985349655151367, 'mse': 31.916027069091797, 'pcc': 0.5483509302139282}
Test metrics: {'loss': 13.641582679748534, 'mae': 3.0319154262542725, 'mse': 14.84697437286377, 'pcc': 0.3986424207687378}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▇▇▆▇▄▆▁▄▄▃▄▄▃▄▃▃▃▃▃▃▂▂▂▃▃▁▃▃█▆▃▆▃▃▃▃▂▂▃▃
train/mae,█▇▇▇▅▇▂▅▅▄▅▅▄▅▄▄▄▄▄▄▄▃▃▄▄▁▄▅▇▅▅▆▄▃▃▄▃▃▄▄
train/mse,▇▇▆▇▄▆▁▄▄▃▄▄▃▄▃▃▃▃▃▃▂▂▂▃▃▁▃▃█▆▃▆▃▃▃▃▂▂▃▃
train/pcc,▃▇▃▇▂▃ ▄▃▇▁▃▁▅▅▄▂▃▄▄▃▅▅▅▇ ▄▅▁▄▃▄▇▆▆█▇▇▅█
valid/loss,███▇▇▆▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁
valid/mae,███▇▇▆▅▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,13.64158
test/mae,3.03192
test/mse,14.84697
test/pcc,0.39864
train/loss,10.39607
train/mae,3.22429
train/mse,10.39607
train/pcc,
valid/loss,31.91603
valid/mae,3.98535


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=6, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▄▄▆▄▃▁▂▅▃▃▂▂▂▂▂▂▂▃▃▃▂▂▃▃▁▂▂▃▂▃▂▃▃▃
train/mae,█▅▆▇▆▅▂▄▇▄▄▄▄▃▃▃▃▂▄▄▄▃▃▅▄▁▄▃▄▃▄▃▄▄▄
train/mse,█▄▄▆▄▃▁▂▅▃▃▂▂▂▂▂▂▂▃▃▃▂▂▃▃▁▂▂▃▂▃▂▃▃▃
train/pcc,▂▁▄▄▃▃ ▆▇▆▅▇▆▇▇▇█▇▇▅▅█▇▇▇ █▆▆█▆█▇▇▅
valid/loss,█▅▄▃▃▃▃▃▅▃▃▃▂▂▂▂▂▂▃▃▁▂▂▁▂▂▃▁▁▁▂▁▂▂
valid/mae,█▅▄▄▄▃▃▃▆▃▃▃▃▂▂▂▂▂▃▃▂▂▂▂▁▂▄▂▂▁▂▁▂▂
valid/mse,█▅▄▃▃▃▃▃▅▃▃▃▂▂▂▂▂▂▃▃▁▂▂▁▂▂▃▁▁▁▂▁▂▂
valid/pcc,▁▂▃▄▅▇▇▇▇▇▆▆▆▆▇▇▇▇▇▅▇████▇██████▇▇

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=6, learning_rate=5e-05
Validation metrics: {'loss': 27.31190013885498, 'mae': 3.2377779483795166, 'mse': 27.311899185180664, 'pcc': 0.8182260394096375}
Test metrics: {'loss': 10.452423058450222, 'mae': 2.740644931793213, 'mse': 12.836944580078125, 'pcc': 0.7385573983192444}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆█▄▃▄▃▂▃▃▂▂▂▃▃▃▁▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▁▃▆▁▂▁▁▂
train/mae,▆█▅▄▅▄▃▃▄▃▃▃▄▄▄▂▃▃▂▃▃▃▂▂▃▂▂▂▂▂▃▂▂▄▄▁▃▁▁▂
train/mse,▆█▄▃▄▃▂▃▃▂▂▂▃▃▃▁▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▁▃▆▁▂▁▁▂
train/pcc,▂▄▃▅▂▁ ▅▄▆▆▅▅▅▇▇▆▇█▇▇▅▇▇▇ ▆▇▇▆▇██▅▆█▇███
valid/loss,█▆▃▂▃▃▃▃▂▂▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▂▁▂▁▁▂▁▂
valid/mae,█▆▄▃▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂

0,1
test/loss,10.45242
test/mae,2.74064
test/mse,12.83694
test/pcc,0.73856
train/loss,0.01905
train/mae,0.13801
train/mse,0.01905
train/pcc,
valid/loss,27.3119
valid/mae,3.23778


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=6, learning_rate=1e-05
Validation metrics: {'loss': 30.33459186553955, 'mae': 3.8178863525390625, 'mse': 30.334590911865234, 'pcc': 0.6725152730941772}
Test metrics: {'loss': 12.701448917388916, 'mae': 2.9681267738342285, 'mse': 14.339152336120605, 'pcc': 0.49347081780433655}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆▅█▂▅▃▂▁▂▂▂▂▂▂▃▁▄▁▂▂▁▁▃▂▂▂▁▂▁▄▂▁▂▂▁▂▂▁▁▁
train/mae,█▆▇▃▆▃▅▁▂▃▂▂▂▂▄▂▃▂▃▂▂▂▄▃▃▄▁▂▂▂▂▁▃▂▁▃▂▁▁▂
train/mse,▆▅█▂▅▃▂▁▂▂▂▂▂▂▃▁▄▁▂▂▁▁▃▂▂▂▁▂▁▄▂▁▂▂▁▂▂▁▁▁
train/pcc,▄▄▁▅▃▂ ▅▄▆▃▇▅▂▆▆▅▄▇█▆▆▃▆▅ █▇▇▄▆█▆█▇▆▆▆▇█
valid/loss,███▇▆▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▂▁▁▂▁▁▁▁▁▁▁▁▁
valid/mae,███▇▆▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
test/loss,12.70145
test/mae,2.96813
test/mse,14.33915
test/pcc,0.49347
train/loss,17.18408
train/mae,4.14537
train/mse,17.18408
train/pcc,
valid/loss,30.33459
valid/mae,3.81789


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=8, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▄▃▂▃▃▆▁▁▂▃▃█
train/mae,▄▃▂▂▄▄▁▁▂▃▃█
train/mse,▄▃▂▃▃▆▁▁▂▃▃█
train/pcc,▅▄▄▆▂▄ █▆█▆▁
valid/loss,█▅▃▃▅▃▁▃▂▁▂█
valid/mae,█▄▃▂▅▃▁▂▁▁▁█
valid/mse,█▅▃▃▅▃▁▃▂▁▂█
valid/pcc,▃▄▆▆▃▆▇▇██▇▁

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=8, learning_rate=5e-05
Validation metrics: {'loss': 20.266576766967773, 'mae': 2.959890365600586, 'mse': 20.266576766967773, 'pcc': 0.8457738161087036}
Test metrics: {'loss': 8.747647953033447, 'mae': 2.4525482654571533, 'mse': 10.527332305908203, 'pcc': 0.6731082797050476}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▄▂▂▂▃▆▂▃▂▃▂▂▂▂▂▂▂▂▄▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂
train/mae,█▅▃▂▂▃█▂▃▃▃▂▂▂▃▂▂▂▂▃▂▂▂▂▂▁▂▁▂▂▁▂▂▁▃▂▂▂▁▂
train/mse,█▄▂▂▂▃▆▂▃▂▃▂▂▂▂▂▂▂▂▄▂▂▂▂▂▁▂▂▂▂▂▂▂▂▂▂▂▂▁▂
train/pcc,▁▁▂▅▃▅ ▄▂▅▅▅▅▂▆▅▅▅▅▇▅▆▆▇▇ ▆▅▅▆▅▅▆▆▇▆▃▄█▆
valid/loss,█▆▃▃▄▃▃▃▂▂▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▂▂▂▁▁▂▁▂▁▂▁▁▁▁
valid/mae,█▅▄▃▄▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁

0,1
test/loss,8.74765
test/mae,2.45255
test/mse,10.52733
test/pcc,0.67311
train/loss,4.22558
train/mae,2.05562
train/mse,4.22558
train/pcc,
valid/loss,20.26658
valid/mae,2.95989


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=8, learning_rate=1e-05
Validation metrics: {'loss': 30.64444661140442, 'mae': 3.8716399669647217, 'mse': 30.644447326660156, 'pcc': 0.6406511068344116}
Test metrics: {'loss': 13.15852975845337, 'mae': 3.0250561237335205, 'mse': 14.5305814743042, 'pcc': 0.4454938471317291}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▅█▇▅▃▃▂▃▃▂▃▃▃▃▃▃▃▃▆▃▃▃▃▃▃▁▃▂▃▂▃▄▃▂▂▃▃▄▃▂
train/mae,▆█▇▆▅▄▄▄▅▄▄▄▄▅▄▅▅▅▅▄▄▅▄▅▅▁▄▄▄▄▅▅▅▃▄▄▄▅▄▄
train/mse,▅█▇▅▃▃▂▃▃▂▃▃▃▃▃▃▃▃▆▃▃▃▃▃▃▁▃▂▃▂▃▄▃▂▂▃▃▄▃▂
train/pcc,▅▄▅▆▃▄ ▂▅▄▄▆▅▅▁▂▂▃▂▃▄▅▄▅▃ ▅▅▆▃▅▄▄▄▆▅▄▆█▇
valid/loss,███▆▅▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▁▁▁▁▁▁▂▁
valid/mae,███▆▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,13.15853
test/mae,3.02506
test/mse,14.53058
test/pcc,0.44549
train/loss,16.45237
train/mae,4.05615
train/mse,16.45237
train/pcc,
valid/loss,30.64445
valid/mae,3.87164


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=4, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▄▁▃▂█▁█▅
train/mae,▃▁▃▁▃▁█▄
train/mse,▄▁▃▂█▁█▅
train/pcc,▅▁▃█▁█ ▂
valid/loss,▂▂▃▃▆▁█▃
valid/mae,▂▂▂▂▆▁█▂
valid/mse,▂▂▃▃▆▁█▃
valid/pcc,▆▅▆█▁▇▂▆

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=4, learning_rate=5e-05
Validation metrics: {'loss': 19.638887524604797, 'mae': 2.7539782524108887, 'mse': 19.638887405395508, 'pcc': 0.8186450600624084}
Test metrics: {'loss': 7.898186048492789, 'mae': 2.405573844909668, 'mse': 9.715441703796387, 'pcc': 0.742785632610321}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆▅█▂▃▃▁▃▃▁▅▂▁▂▂▁▁▁▂▁▁▂▂▂▃▃▁▁▁▁▄▁▂▁▂▄▁▂▁▁
train/mae,█▆▆▃▅▄▂▅▅▂▄▄▂▂▃▂▂▂▃▁▂▃▃▂▄▅▂▂▁▁▃▂▃▂▃▄▁▂▁▂
train/mse,▆▅█▂▃▃▁▃▃▁▅▂▁▂▂▁▁▁▂▁▁▂▂▂▃▃▁▁▁▁▄▁▂▁▂▄▁▂▁▁
train/pcc,▁▁▁▆▄▂ ▄▄▇▅▅▇▆▆▇▇▇▆█▇▆▆▆▆ ███▇▅▇▆▇█▆█▇▇█
valid/loss,█▃▄▃▃▃▂▃▃▂▃▃▂▃▂▂▂▂▂▂▂▂▂▂▃▂▂▁▂▂▂▂▁▂▁▁▁▁▁▁
valid/mae,█▄▄▄▄▄▃▃▃▃▃▃▂▃▂▂▂▂▂▂▂▁▂▂▂▁▁▁▂▁▂▂▁▂▂▁▁▁▁▁

0,1
test/loss,7.89819
test/mae,2.40557
test/mse,9.71544
test/pcc,0.74279
train/loss,0.0167
train/mae,0.12922
train/mse,0.0167
train/pcc,
valid/loss,19.63889
valid/mae,2.75398


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=4, learning_rate=1e-05
Validation metrics: {'loss': 26.486700534820557, 'mae': 3.2930476665496826, 'mse': 26.486698150634766, 'pcc': 0.8046021461486816}
Test metrics: {'loss': 12.421636819839478, 'mae': 2.989426851272583, 'mse': 14.706385612487793, 'pcc': 0.55301433801651}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆▅▆█▃▃▃▃▃▃▃▅▃▃▂▃▂▃▂▂▃▃▅▃▃▁▂▂▃▃▃▂▅▂▂▂▂▂▄▂
train/mae,▇▇█▇▅▅▆▄▅▅▄▅▅▅▄▄▄▄▃▄▅▄▅▅▄▁▄▄▅▅▄▃▅▄▃▃▃▃▄▄
train/mse,▆▅▆█▃▃▃▃▃▃▃▅▃▃▂▃▂▃▂▂▃▃▅▃▃▁▂▂▃▃▃▂▅▂▂▂▂▂▄▂
train/pcc,▁▄▄▂▅▅ ▄▃▄▆▄▆▅▆▅▃▃▅▇▇▅▃▆▅ ▅▅▇▆▇▄▅▆█▇██▆▇
valid/loss,██▇▅▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▂▁▁▁▁▁▁▁▁
valid/mae,██▇▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁

0,1
test/loss,12.42164
test/mae,2.98943
test/mse,14.70639
test/pcc,0.55301
train/loss,1.14434
train/mae,1.06974
train/mse,1.14434
train/pcc,
valid/loss,26.4867
valid/mae,3.29305


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=6, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▆▆█▁
train/mae,▄█▅▁
train/mse,▆▆█▁
train/pcc,█▂▁▆
valid/loss,▁█▂▁
valid/mae,▂█▁▂
valid/mse,▁█▂▁
valid/pcc,█▁▂▃

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=6, learning_rate=5e-05
Validation metrics: {'loss': 16.102939128875732, 'mae': 2.838378429412842, 'mse': 16.10293960571289, 'pcc': 0.8217169046401978}
Test metrics: {'loss': 6.715177154541015, 'mae': 2.060702323913574, 'mse': 7.741240501403809, 'pcc': 0.7574983835220337}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆▃▄█▃▃▃▃▃▄▂▄▃▃▃▃▂▂▃▂▃▃▂▃▃▁▁▁▂▂▂▂▂▁▂▁▁▁▁▂
train/mae,█▅▆█▅▆▆▅▆▆▃▆▄▆▅▄▄▃▅▄▄▅▃▄▃▃▂▂▂▄▄▄▃▁▄▂▂▂▂▃
train/mse,▆▃▄█▃▃▃▃▃▄▂▄▃▃▃▃▂▂▃▂▃▃▂▃▃▁▁▁▂▂▂▂▂▁▂▁▁▁▁▂
train/pcc,▃▄▂▁▃▃ ▆▅▆▆▂▃▃▅▄▆▇▅▇▅▆▆▇▇ ▇▇▇██▇▇██▇███▇
valid/loss,█▄▃▄▃▃▃▃▃▃▃▂▃▂▂▃▂▂▂▃▄▂▂▂▂▁▂▂▂▁▂▁▁▂▂▁▁▂▁▂
valid/mae,█▄▄▄▄▄▄▃▃▃▃▃▃▂▃▃▂▃▂▃▄▃▂▂▂▂▂▂▂▁▂▁▁▂▂▂▂▂▂▂

0,1
test/loss,6.71518
test/mae,2.0607
test/mse,7.74124
test/pcc,0.7575
train/loss,18.37227
train/mae,4.28629
train/mse,18.37227
train/pcc,
valid/loss,16.10294
valid/mae,2.83838


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=6, learning_rate=1e-05
Validation metrics: {'loss': 26.35999631881714, 'mae': 3.3003995418548584, 'mse': 26.359996795654297, 'pcc': 0.8746218681335449}
Test metrics: {'loss': 12.14793803691864, 'mae': 3.0188043117523193, 'mse': 14.552294731140137, 'pcc': 0.6064296960830688}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,██▅▃▄▃▁▃▄▃▇▃▃▄▃▃▆▃▂▃▄▄▃▃▄▁▃▃▃▃▃▂▆▃▃▃▃▂▂▃
train/mae,▇█▆▅▅▅▁▄▅▅▅▅▅▅▅▄▅▄▄▄▆▅▅▄▅▂▄▅▄▅▅▄▅▄▄▄▄▃▄▄
train/mse,██▅▃▄▃▁▃▄▃▇▃▃▄▃▃▆▃▂▃▄▄▃▃▄▁▃▃▃▃▃▂▆▃▃▃▃▂▂▃
train/pcc,▃▅▄▆▁▅ ▃▅▃▂▅█▆▅▄▄█▇▆▄▄▅▆▄ ▆▆▇█▇▇▆▇▇▅▇█▇▆
valid/loss,██▆▃▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▁▁▁▁▁▁
valid/mae,██▆▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
test/loss,12.14794
test/mae,3.0188
test/mse,14.55229
test/pcc,0.60643
train/loss,0.15306
train/mae,0.39123
train/mse,0.15306
train/pcc,
valid/loss,26.36
valid/mae,3.3004


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=8, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▁█
train/mae,▁█
train/mse,▁█
train/pcc,█▁
valid/loss,▁█
valid/mae,▁█
valid/mse,▁█
valid/pcc,█▁

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=8, learning_rate=5e-05
Validation metrics: {'loss': 28.530992031097412, 'mae': 3.2421507835388184, 'mse': 28.53099250793457, 'pcc': 0.765211820602417}
Test metrics: {'loss': 11.803349660336972, 'mae': 2.922029972076416, 'mse': 14.523759841918945, 'pcc': 0.6899458765983582}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▇▄▄▅▅▁▄▃▃▄▃▃▅▅▃▃▄▃▃▄▃▃▃▄▂▅▃▃▃▄▄▃▂▄▂█▄▃▄
train/mae,█▇▅▅▇▆▁▅▅▅▅▅▄▆▅▄▅▅▄▄▅▄▄▄▅▄▅▄▄▄▆▅▄▃▅▄▅▅▄▅
train/mse,█▇▄▄▅▅▁▄▃▃▄▃▃▅▅▃▃▄▃▃▄▃▃▃▄▂▅▃▃▃▄▄▃▂▄▂█▄▃▄
train/pcc,▁▄▄▂▂▃ ▅▅▆▅█▆▆▄▆▇▅▇▅▅▇▅▆▇ ▅█▆▇▅▆▇▇▆█▄▆▇▇
valid/loss,█▅▃▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▄▃▂▂▂▂▂▃▂▃▁▂▂▂▂▂▁▁▃▁▂▃
valid/mae,█▄▄▄▃▃▃▃▂▃▂▂▂▃▂▁▂▁▄▃▂▂▂▂▁▂▂▂▁▂▂▁▁▁▁▁▂▁▁▂

0,1
test/loss,11.80335
test/mae,2.92203
test/mse,14.52376
test/pcc,0.68995
train/loss,0.09561
train/mae,0.30921
train/mse,0.09561
train/pcc,
valid/loss,28.53099
valid/mae,3.24215


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=8, learning_rate=1e-05
Validation metrics: {'loss': 23.70927381515503, 'mae': 3.379948377609253, 'mse': 23.709272384643555, 'pcc': 0.8847461938858032}
Test metrics: {'loss': 10.772088718414306, 'mae': 2.731513023376465, 'mse': 12.207980155944824, 'pcc': 0.603986918926239}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▇█▆▅▃▅▁▃▄▇▂▄▇█▄▄▅▄▄▃▃▄▄▃▃▁▃▃▃▃▃▂▃▃▃▄▂▂▂▃
train/mae,██▇▆▄▅▁▄▄▅▃▅▅▆▅▅▆▅▅▃▄▅▄▄▃▁▄▄▄▄▃▃▄▄▄▄▂▂▂▃
train/mse,▇█▆▅▃▅▁▃▄▇▂▄▇█▄▄▅▄▄▃▃▄▄▃▃▁▃▃▃▃▃▂▃▃▃▄▂▂▂▃
train/pcc,▁▁▃▃▄▂ ▄▁▂▄▃▄▁▆▅▄▂▄▆▅▅▄▄▄ ▆▅▆█▆▇▃▆▅▃▇▆▅▆
valid/loss,██▅▄▃▃▃▃▃▃▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▁▁▁▁▁▁▁▁
valid/mae,██▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,10.77209
test/mae,2.73151
test/mse,12.20798
test/pcc,0.60399
train/loss,6.83884
train/mae,2.61512
train/mse,6.83884
train/pcc,
valid/loss,23.70927
valid/mae,3.37995


## Task 7. Train GIN (2 points).
1. Tune hyperparameters of a GNN with `GINLayer` as MPNN layer to obtain at most 2.0 MAE on the validation set. You can modify the GNN/MPNN architecture, so it uses some regularization tricks like dropout or batch norm. Don't change the validation batch size. If your validation MAE is in (2.0, 2.5], you can obtain 1 point.
2. Report the obtained MAE on the validation and test set (only the former need to be lower than 2.0 MAE).
3. Provide the link to the final run: [your link]

In [143]:
from torchmetrics import MeanAbsoluteError as MAE
from torchmetrics import MeanSquaredError as MSE
from torchmetrics import PearsonCorrCoef as PCC
from itertools import product

# Define the hyperparameter grid
grid_params = {
    "hidden_size": [128, 256, 512],
    "mpnn_n_layers": [4, 6, 8],
    "learning_rate": [1e-3, 5e-5, 1e-5],
}

# Define metrics
metrics = {
    "mae": MAE(),
    "mse": MSE(),
    "pcc": PCC(),
}

# Perform grid search
for params in product(*grid_params.values()):
    hidden_size, mpnn_n_layers, learning_rate = params

    # Create model with current hyperparameters
    model = GNN(
        node_features_size=node_featurizer.feat_size(),
        edge_features_size=edge_featurizer.feat_size(),
        hidden_size=hidden_size,
        output_size=1,
        mpnn_layer_cls=GINLayer,
        mpnn_n_layers=mpnn_n_layers,
        readout_cls=AttentionReadout,
        mpnn_layer_kwargs={}
    )

    # Create Trainer with current hyperparameters
    trainer = Trainer(
        run_dir=f"experiments/hidden_{hidden_size}_layers_{mpnn_n_layers}_lr_{learning_rate}",
        train_dataset=train,
        valid_dataset=valid,
        train_metrics=metrics,
        valid_metrics=metrics,
        train_batch_size=32,
        model=model,
        logger=WandbLogger(
            logdir=f"runs/mpnn/hidden_{hidden_size}_layers_{mpnn_n_layers}_lr_{learning_rate}",
            project_name="mldd23",
            experiment_name=f"sage_{get_time_stamp()}",
        ),
        optimizer_kwargs={"lr": learning_rate},
        n_epochs=50,
        device="cpu",
        valid_every_n_epochs=1,
    )

    # Train and test the model
    valid_metrics = trainer.train()
    test_metrics = trainer.test(test)

    # Print metrics for each hyperparameter combination
    print(f"Hyperparameters: hidden_size={hidden_size}, mpnn_n_layers={mpnn_n_layers}, learning_rate={learning_rate}")
    print(f"Validation metrics: {valid_metrics}")
    print(f"Test metrics: {test_metrics}")

    # Close the trainer
    trainer.close()

  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=4, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▅▇▅▄▄▁▃▃▁▁▂▅▂▂▂▂▂▂▂▂
train/mae,█▇█▆▅▄▁▃▄▂▁▂▆▂▃▃▂▂▂▃▃
train/mse,█▅▇▅▄▄▁▃▃▁▁▂▅▂▂▂▂▂▂▂▂
train/pcc,▂▁▅▃▆▇ ▇▇▇█▇▇█▇▇▇▇███
valid/loss,███▆▅▁▃▂▂▁▂▂▅▃▁▄▁▂▄▁
valid/mae,██▇▆▆▂▃▂▂▁▂▂▅▃▂▂▁▁▃▂
valid/mse,███▆▅▁▃▂▂▁▂▂▅▃▁▄▁▂▄▁
valid/pcc,▁▃▄▆▆█▇▇▇█▇▇▇▇█▆██▇█

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=4, learning_rate=5e-05
Validation metrics: {'loss': 16.617879152297974, 'mae': 2.867065906524658, 'mse': 16.61787986755371, 'pcc': 0.8142910599708557}
Test metrics: {'loss': 8.402321529388427, 'mae': 2.4386613368988037, 'mse': 9.405458450317383, 'pcc': 0.7092682123184204}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▅▆▃▆▃▁▃▂▃▃▃▃▂▃▃▃▂▃▃▃▃▂▃▃▁▅▂▂▂▂▂▃▂▂▂▂▂▂▂
train/mae,█▆▇▄▅▅▁▄▃▅▄▅▅▄▅▅▄▄▄▅▅▅▃▄▅▂▅▃▃▃▃▃▄▃▃▃▃▃▂▃
train/mse,█▅▆▃▆▃▁▃▂▃▃▃▃▂▃▃▃▂▃▃▃▃▂▃▃▁▅▂▂▂▂▂▃▂▂▂▂▂▂▂
train/pcc,▄▁▃▅▄▂ ▆▆▄▅▆▅▇▆▆▄▆▆▆▆▆▇▇▇ ▆▇▇█▇▇▆▇▇▇█▇██
valid/loss,██▇▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▃▂▂▂▂▂▁▁▂▁▂▁▁▁
valid/mae,██▇▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▃▂▂▂▂▂▁▁▂▁▁▁▁▁

0,1
test/loss,8.40232
test/mae,2.43866
test/mse,9.40546
test/pcc,0.70927
train/loss,7.46729
train/mae,2.73263
train/mse,7.46729
train/pcc,
valid/loss,16.61788
valid/mae,2.86707


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=4, learning_rate=1e-05
Validation metrics: {'loss': 33.570942640304565, 'mae': 4.1702423095703125, 'mse': 33.57094192504883, 'pcc': 0.4034169912338257}
Test metrics: {'loss': 14.440587997436523, 'mae': 3.1208508014678955, 'mse': 14.947635650634766, 'pcc': 0.2132442593574524}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▅█▄▅▆█▂▆▃▂▂▃▂▃▃▂▂▃▃▂▂▂▂▂▃▁▂▃▃▂▃▃▂▂▂▂▂▂▂▄
train/mae,▇█▆▆█▇▄█▄▃▃▄▃▄▅▃▃▄▄▃▃▄▃▃▄▁▃▄▄▃▄▅▂▃▃▃▄▂▃▄
train/mse,▅█▄▅▆█▂▆▃▂▂▃▂▃▃▂▂▃▃▂▂▂▂▂▃▁▂▃▃▂▃▃▂▂▂▂▂▂▂▄
train/pcc,▂▄▂▄▅▄ ▃▄▃▂▄▃▂▃▅▃▂▃▅▄▅▄▆▁ ▄▃▅▂▂▃▃▅▃▅▆█▃▄
valid/loss,█████▇▇▆▄▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/mae,█████▇▆▅▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,14.44059
test/mae,3.12085
test/mse,14.94764
test/pcc,0.21324
train/loss,0.62892
train/mae,0.79305
train/mse,0.62892
train/pcc,
valid/loss,33.57094
valid/mae,4.17024


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=6, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▃▂▂█▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/mae,▇▅▄█▆▅▆▄▅▃▄▄▃▃▃▃▂▂▃▃▃▂▂▃▂▁▃▂▂▂▃
train/mse,▃▂▂█▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train/pcc,▃▄▄▁▃▆ ▆▅▆▆▆▆▆▇███▇▇▇█▇██ █████
valid/loss,█▅▄▇▆▃▃▄▆▄▄▅▂▂▂▃▁▁▁▁▂▂▁▂▁▂▁▂▂▂
valid/mae,█▆▅▇▆▄▄▄▇▅▅▅▃▃▂▃▂▂▁▁▂▂▁▂▁▂▁▃▂▂
valid/mse,█▅▄▇▆▃▃▄▆▄▄▅▂▂▂▃▁▁▁▁▂▂▁▂▁▂▁▂▂▂
valid/pcc,▁▂▄▄▄▇▆▇▇▇▆▅▇▇▇▇█████▇████████

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=6, learning_rate=5e-05
Validation metrics: {'loss': 15.634022235870361, 'mae': 2.7438805103302, 'mse': 15.634021759033203, 'pcc': 0.8072391748428345}
Test metrics: {'loss': 7.367557613551616, 'mae': 2.30387020111084, 'mse': 9.059057235717773, 'pcc': 0.6427496671676636}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▆▇▄▃▃▃▁▇▅▄█▂▃▃▃▄▄▃▂▃▂▂▃▃▂▁▂▃▂▂▂▂▂▂▂▂▂▁▂▂
train/mae,██▆▅▅▄▁▆▆▆▇▄▅▅▅▅▆▅▃▅▄▄▅▅▄▂▄▄▃▃▃▃▄▄▄▃▄▃▃▃
train/mse,▆▇▄▃▃▃▁▇▅▄█▂▃▃▃▄▄▃▂▃▂▂▃▃▂▁▂▃▂▂▂▂▂▂▂▂▂▁▂▂
train/pcc,▅▁▃▃▅▃ ▂▂▁▁▅▅▃▃▄▄▄▆▆▇▅▄▅▅ ▆▅▇█▇▇▇▅▆█▆█▇▇
valid/loss,█▇▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁
valid/mae,█▇▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▁▁▁▁▂▁

0,1
test/loss,7.36756
test/mae,2.30387
test/mse,9.05906
test/pcc,0.64275
train/loss,0.52105
train/mae,0.72184
train/mse,0.52105
train/pcc,
valid/loss,15.63402
valid/mae,2.74388


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=6, learning_rate=1e-05
Validation metrics: {'loss': 34.660544872283936, 'mae': 4.221195220947266, 'mse': 34.660545349121094, 'pcc': 0.3352803587913513}
Test metrics: {'loss': 14.875128173828125, 'mae': 3.157849073410034, 'mse': 15.358074188232422, 'pcc': 0.17335934937000275}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▅▄▃▄▄▆▂▂▃▂▃▃▂▂▃▃▂▂▂▂▂▃▂▂▁▂▂▂▂▃▂▂▂▂▂▂▂▃▃
train/mae,▆▆▅▄▅▅█▃▃▃▃▄▃▃▃▃▃▄▃▃▄▃▄▃▃▁▃▃▃▃▃▃▃▃▃▂▃▃▃▄
train/mse,█▅▄▃▄▄▆▂▂▃▂▃▃▂▂▃▃▂▂▂▂▂▃▂▂▁▂▂▂▂▃▂▂▂▂▂▂▂▃▃
train/pcc,▁▆▅▃▅▅ █▄▆▅▆▅▇▅▅▁▇▆▆▅▆▆▅▄ ▅▇▅█▂▆▄▆▇▇▇▅▄▆
valid/loss,███▇▇▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid/mae,███▇▆▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

0,1
test/loss,14.87513
test/mae,3.15785
test/mse,15.35807
test/pcc,0.17336
train/loss,0.12996
train/mae,0.3605
train/mse,0.12996
train/pcc,
valid/loss,34.66054
valid/mae,4.2212


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=8, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▆▅▂▃▄▇█▁▃▁▁▂
train/mae,▅▄▂▃▄▄█▂▃▁▁▂
train/mse,▆▅▂▃▄▇█▁▃▁▁▂
train/pcc,▃▄▆▁▃▅ ▇▆██▆
valid/loss,█▄▃▄▄▃▃▃▃▁▂▂
valid/mae,█▅▃▅▄▃▃▃▂▁▁▃
valid/mse,█▄▃▄▄▃▃▃▃▁▂▂
valid/pcc,▁▂▅▂▄▆▆▆▄▇█▇

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112498733321747, max=1.0…

  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=8, learning_rate=5e-05
Validation metrics: {'loss': 9.356805205345154, 'mae': 2.0383760929107666, 'mse': 9.356804847717285, 'pcc': 0.8891998529434204}
Test metrics: {'loss': 5.4694278001785275, 'mae': 1.8806800842285156, 'mse': 6.484004497528076, 'pcc': 0.7527417540550232}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▆▅▃█▆▅▄▃▄▄▃▆▃▂▃▃▃▂▂▂▂▂▃▂▃▃▂▁▂▃▂▂▁▁▂▂▁▁▂
train/mae,█▆▆▄▆▆▇▅▄▅▄▄▅▄▃▄▃▄▃▃▃▃▂▃▃▅▃▃▂▂▃▃▂▁▁▂▂▁▂▂
train/mse,█▆▅▃█▆▅▄▃▄▄▃▆▃▂▃▃▃▂▂▂▂▂▃▂▃▃▂▁▂▃▂▂▁▁▂▂▁▁▂
train/pcc,▁▃▃▃▄▂ ▆▅▄▄▄▅▅▆▅▆▅▆▆▆▆▆▆▇ ▅▆▆▇█▆▇█▇▇▇█▇█
valid/loss,█▇▄▄▄▄▄▅▄▄▄▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▁▁▂▁▁▁▁▁▁▁▁
valid/mae,█▇▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▃▂▂▁▁▁▁▁▁

0,1
test/loss,5.46943
test/mae,1.88068
test/mse,6.484
test/pcc,0.75274
train/loss,0.05964
train/mae,0.24421
train/mse,0.05964
train/pcc,
valid/loss,9.35681
valid/mae,2.03838


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=128, mpnn_n_layers=8, learning_rate=1e-05
Validation metrics: {'loss': 30.618992805480957, 'mae': 4.1203227043151855, 'mse': 30.618993759155273, 'pcc': 0.4775939881801605}
Test metrics: {'loss': 11.766705417633057, 'mae': 2.6818795204162598, 'mse': 11.777313232421875, 'pcc': 0.4455341398715973}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▆▄▆▅▃▄▂▂▃▄▂▂▂▂▂▂▂▅▃▂▂▃▂▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂
train/mae,█▅▄▇▅▃▆▂▃▃▄▁▂▃▃▃▂▂▄▃▂▂▃▁▃▁▂▂▁▂▂▁▁▂▂▃▂▁▂▁
train/mse,█▆▄▆▅▃▄▂▂▃▄▂▂▂▂▂▂▂▅▃▂▂▃▂▂▁▂▂▂▂▂▁▂▂▂▂▂▂▂▂
train/pcc,▅▁▃▁▃▄ ▆▃▂▁▄▄▃▆▄▃▄▄▄▅▄▃▆▃ ▅▃▆▇▄█▅▇▇▄▅▆▃▆
valid/loss,███▇▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▂▂▁▂▁▁▁▁▁
valid/mae,███▇▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▂▁▁▂▁▁▁

0,1
test/loss,11.76671
test/mae,2.68188
test/mse,11.77731
test/pcc,0.44553
train/loss,6.02823
train/mae,2.45525
train/mse,6.02823
train/pcc,
valid/loss,30.61899
valid/mae,4.12032


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=4, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▄▃▂▁
train/mae,█▅▄▃▁
train/mse,█▄▃▂▁
train/pcc,▁▁▁▄█
valid/loss,█▄▁▃▁
valid/mae,█▃▁▂▁
valid/mse,█▄▁▃▁
valid/pcc,▁▂▇▆█

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=4, learning_rate=5e-05
Validation metrics: {'loss': 10.32167112827301, 'mae': 2.0466599464416504, 'mse': 10.321671485900879, 'pcc': 0.8946718573570251}
Test metrics: {'loss': 6.081819152832031, 'mae': 2.046905040740967, 'mse': 7.300798416137695, 'pcc': 0.735763430595398}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▇▃▄▄▆▅▄▃▃▂▃▂▂▂▂▂▂▂▂▁▂▂▂▂▅▁▁▂▁▁▂▁▂▁▁▂▁▁▁
train/mae,█▇▄▄▅▅▇▅▃▄▃▄▃▃▃▃▃▂▃▃▂▂▂▂▃▇▂▂▂▂▂▂▂▂▁▂▃▁▂▁
train/mse,█▇▃▄▄▆▅▄▃▃▂▃▂▂▂▂▂▂▂▂▁▂▂▂▂▅▁▁▂▁▁▂▁▂▁▁▂▁▁▁
train/pcc,▃▂▄▅▁▂ ▆▅▆▆▄▅▆▆▆▆▆▆▆▇▇▇▇█ ▆█▇▇█▇▇▇█▇█▇▇█
valid/loss,█▆▄▄▄▄▄▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▁▂▂▁▂▁▁▂▁▂▁▁▁▁▁▁▁▁
valid/mae,█▆▅▅▅▄▄▄▄▄▄▄▄▄▃▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▃▁▁▁▁▁▂▁▁

0,1
test/loss,6.08182
test/mae,2.04691
test/mse,7.3008
test/pcc,0.73576
train/loss,1.39332
train/mae,1.18039
train/mse,1.39332
train/pcc,
valid/loss,10.32167
valid/mae,2.04666


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=4, learning_rate=1e-05
Validation metrics: {'loss': 27.474535942077637, 'mae': 3.805910110473633, 'mse': 27.474536895751953, 'pcc': 0.7366822361946106}
Test metrics: {'loss': 10.19689121246338, 'mae': 2.6009645462036133, 'mse': 11.234374046325684, 'pcc': 0.5637915730476379}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▅▅▇▄▃▁▄▂▂▂▃▃▃▂▃▄▄▂▃▂▂▂▂▃▁▂▂▃▂▂▃▃▂▂▂▂▂▂▂
train/mae,█▆▆▇▅▅▁▅▄▄▄▄▅▅▄▅▄▅▄▄▄▄▄▄▄▁▃▃▄▄▄▄▄▃▃▄▄▃▄▄
train/mse,█▅▅▇▄▃▁▄▂▂▂▃▃▃▂▃▄▄▂▃▂▂▂▂▃▁▂▂▃▂▂▃▃▂▂▂▂▂▂▂
train/pcc,▁▅▄▅▅▂ ▅▇▆▄▆▃▄▆▄▆▃▆▇▄▆▅▇▄ ▆█▇█▅▅▅██▆▇▇▅▇
valid/loss,███▆▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▁
valid/mae,███▆▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
test/loss,10.19689
test/mae,2.60096
test/mse,11.23437
test/pcc,0.56379
train/loss,2.44585
train/mae,1.56392
train/mse,2.44585
train/pcc,
valid/loss,27.47454
valid/mae,3.80591


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=6, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▄▂▁▁▁
train/mae,█▆▄▁▃▃
train/mse,█▄▂▁▁▁
train/pcc,▁▄▄▆▇█
valid/loss,█▃▂▂▂▁
valid/mae,█▃▃▂▁▁
valid/mse,█▃▂▂▂▁
valid/pcc,▁▂▃▅█▇

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=6, learning_rate=5e-05
Validation metrics: {'loss': 13.361243486404419, 'mae': 2.289992332458496, 'mse': 13.36124324798584, 'pcc': 0.8321922421455383}
Test metrics: {'loss': 7.051892971992492, 'mae': 2.099661350250244, 'mse': 8.48674201965332, 'pcc': 0.6667726635932922}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▂▄▃▃▄▃▂▂▃▃▂▂▁▁▁▁▁▁▁▁▂▂▂▁▂▁▁▂▁▂▁▁▁▁▁▁▁▁▂
train/mae,█▃▅▄▄▅▅▃▃▄▄▂▂▂▂▂▁▂▁▂▁▂▂▃▂▃▁▂▂▂▂▂▁▂▂▁▂▂▂▂
train/mse,█▂▄▃▃▄▃▂▂▃▃▂▂▁▁▁▁▁▁▁▁▂▂▂▁▂▁▁▂▁▂▁▁▁▁▁▁▁▁▂
train/pcc,▄▅▁▃▂▅ ▅▅▄▂▆▆▇▇▇█▇▇▇▇▇▇▇▇ █▇▆▇▆▇█▇▇▇██▇█
valid/loss,█▄▄▄▄▃▃▃▃▂▂▃▂▂▂▃▂▂▂▂▂▂▁▂▁▁▁▂▂▁▂▂▁▁▁▁▁▁▁▂
valid/mae,█▅▅▄▄▄▄▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁

0,1
test/loss,7.05189
test/mae,2.09966
test/mse,8.48674
test/pcc,0.66677
train/loss,0.95912
train/mae,0.97935
train/mse,0.95912
train/pcc,
valid/loss,13.36124
valid/mae,2.28999


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=6, learning_rate=1e-05
Validation metrics: {'loss': 25.09736967086792, 'mae': 3.605632781982422, 'mse': 25.097370147705078, 'pcc': 0.7610130906105042}
Test metrics: {'loss': 11.683052760583815, 'mae': 3.121006727218628, 'mse': 14.379036903381348, 'pcc': 0.586065948009491}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▃▅▄▄▂▂▃▄▂▂▂▃▃▃▂▂▃▄▄▂▃▂▃▂▁▁▂▂▁▁▂▂▁▂▂▃▁▁▁
train/mae,█▄▆▄▄▃▃▃▅▂▂▂▃▄▂▃▂▂▄▄▁▃▃▃▂▂▁▃▃▁▂▂▂▂▂▃▃▁▁▁
train/mse,█▃▅▄▄▂▂▃▄▂▂▂▃▃▃▂▂▃▄▄▂▃▂▃▂▁▁▂▂▁▁▂▂▁▂▂▃▁▁▁
train/pcc,▅▃▁▂▃▄ ▄▂▄▅▅▄▁▃▃▃▁▂▁▆▅▅▄▅ ▅▅▁█▇▇█▇▆▅▅▆██
valid/loss,██▇▄▃▃▃▃▃▃▃▂▃▃▃▃▃▂▃▂▃▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁
valid/mae,██▇▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁

0,1
test/loss,11.68305
test/mae,3.12101
test/mse,14.37904
test/pcc,0.58607
train/loss,6.93915
train/mae,2.63423
train/mse,6.93915
train/pcc,
valid/loss,25.09737
valid/mae,3.60563


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=8, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▄▄▄▄▃▁
train/mae,█▅▅▅▅▃▁
train/mse,█▄▄▄▄▃▁
train/pcc,▄▄▁█▂▃
valid/loss,█▄▂▄▂▁▆
valid/mae,█▄▂▃▂▁▆
valid/mse,█▄▂▄▂▁▆
valid/pcc,▄▆▆█▆█▁

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=8, learning_rate=5e-05
Validation metrics: {'loss': 16.61906623840332, 'mae': 2.568714141845703, 'mse': 16.61906623840332, 'pcc': 0.7870736718177795}
Test metrics: {'loss': 7.697455058991909, 'mae': 2.1538259983062744, 'mse': 9.462502479553223, 'pcc': 0.6288226842880249}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▃▃▂▄▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁▂▂▁▁▂▂▁
train/mae,█▅▅▄▆▅▆▄▄▃▄▄▄▄▃▃▃▃▄▄▃▃▃▃▃▁▃▄▂▃▂▂▂▃▃▃▂▃▃▂
train/mse,█▃▃▂▄▃▃▂▂▂▂▂▂▂▂▂▂▁▂▂▂▁▂▁▁▁▁▂▁▁▁▁▁▂▂▁▁▂▂▁
train/pcc,▂▄▃▄▁▂ ▅▆▆▆▇▇▆█▇▆▇▇▆▇█▇██ ▇▇█████▇▇▇██▇█
valid/loss,█▄▄▄▄▃▃▃▃▂▂▂▁▁▁▁▁▂▂▁▂▁▁▁▁▁▂▂▂▁▁▁▁▁▂▂▁▂▂▁
valid/mae,█▅▄▄▅▄▄▄▄▃▃▃▂▂▂▂▂▂▃▂▂▂▂▁▂▂▂▂▂▁▁▁▂▁▂▂▁▃▂▁

0,1
test/loss,7.69746
test/mae,2.15383
test/mse,9.4625
test/pcc,0.62882
train/loss,2.84263
train/mae,1.68601
train/mse,2.84263
train/pcc,
valid/loss,16.61907
valid/mae,2.56871


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=256, mpnn_n_layers=8, learning_rate=1e-05
Validation metrics: {'loss': 18.78222346305847, 'mae': 3.341937303543091, 'mse': 18.782222747802734, 'pcc': 0.8007034063339233}
Test metrics: {'loss': 8.961331194639206, 'mae': 2.6405954360961914, 'mse': 10.965785026550293, 'pcc': 0.5542622208595276}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▄▄▃▃▂▃▁▃▂▂▂▂▃▂▄▂▂▃▂▂▂▂▃▂▂█▂▂▂▂▂▂▂▂▂▂▂▂▂▂
train/mae,▅▄▄▄▃▄▁▄▃▃▃▂▃▃▄▃▃▄▃▃▃▃▄▃▃█▃▃▃▃▃▃▃▃▃▃▃▃▃▃
train/mse,▄▄▃▃▂▃▁▃▂▂▂▂▃▂▄▂▂▃▂▂▂▂▃▂▂█▂▂▂▂▂▂▂▂▂▂▂▂▂▂
train/pcc,▄▃▄▃▅▁ ▃▅▄▇▆▅▄▅▆▅▅▆▅▆▆▃▇▄ ▄▇▆▄▇▇▆▆▆▇▆▅█▇
valid/loss,██▇▄▃▃▃▃▃▃▃▃▄▃▃▃▃▃▃▂▂▂▃▂▂▂▂▂▁▁▁▁▂▁▁▁▁▁▁▁
valid/mae,██▇▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▂▃▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁

0,1
test/loss,8.96133
test/mae,2.6406
test/mse,10.96579
test/pcc,0.55426
train/loss,10.98756
train/mae,3.31475
train/mse,10.98756
train/pcc,
valid/loss,18.78222
valid/mae,3.34194


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=4, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,▄▂▂▃▃▁█▅▁▁▁▃▂▆
train/mae,▄▂▂▃▃▁█▂▁▁▂▂▂▅
train/mse,▄▂▂▃▃▁█▅▁▁▁▃▂▆
train/pcc,▆▄▄▃▁▆ ▄▆█▆▅▄█
valid/loss,▇▂▄▃▄▃▂▃▂▁▂▁▃█
valid/mae,▇▂▄▂▄▃▂▂▂▁▂▁▃█
valid/mse,▇▂▄▃▄▃▂▃▂▁▂▁▃█
valid/pcc,▂▂▂▄▃▃▄▅▅█▃▅▄▁

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=4, learning_rate=5e-05
Validation metrics: {'loss': 8.904665350914001, 'mae': 2.0316872596740723, 'mse': 8.904664993286133, 'pcc': 0.8822529911994934}
Test metrics: {'loss': 5.205345094203949, 'mae': 1.8463512659072876, 'mse': 6.3753743171691895, 'pcc': 0.7227172255516052}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▃▆▃▄▃▄█▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▁▂▂▁▁▂▁▁▁▁▂▂▂▁▁▁▁▂
train/mae,▄▄▄▄▃▄█▃▃▃▃▂▂▃▂▂▂▂▂▂▂▂▁▂▂▂▂▂▂▁▁▁▂▂▂▁▂▂▂▂
train/mse,▃▆▃▄▃▄█▃▂▂▂▂▁▂▂▂▂▂▂▂▁▂▁▂▂▁▁▂▁▁▁▁▂▂▂▁▁▁▁▂
train/pcc,▁▁▁▁▂▂ ▅▅▅▅▇▇▇▇▇▇▇▇██▇▇▇█ █▇████▇▇▇▇█▇██
valid/loss,█▄▄▄▅▄▄▄▃▃▃▂▂▂▂▂▂▂▂▁▂▁▂▁▁▁▁▁▁▁▂▁▂▁▂▂▁▂▁▁
valid/mae,█▅▅▅▅▅▅▅▄▄▄▃▃▃▂▂▂▂▂▂▂▁▂▁▁▁▁▂▁▁▁▁▂▂▃▁▁▂▁▁

0,1
test/loss,5.20535
test/mae,1.84635
test/mse,6.37537
test/pcc,0.72272
train/loss,0.04633
train/mae,0.21524
train/mse,0.04633
train/pcc,
valid/loss,8.90467
valid/mae,2.03169


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=4, learning_rate=1e-05
Validation metrics: {'loss': 16.222658395767212, 'mae': 2.836919069290161, 'mse': 16.222658157348633, 'pcc': 0.8375843167304993}
Test metrics: {'loss': 6.800465941429138, 'mae': 2.154733896255493, 'mse': 7.918953895568848, 'pcc': 0.7285481691360474}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,█▆▃▃▆▃▁▃▃▄▂▃▃▂▂▂▂▃▂▃▂▂▃▄▂▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂
train/mae,█▇▅▄▆▅▁▅▄▆▄▅▅▃▄▃▄▄▄▄▄▄▄▄▃▆▄▄▄▄▃▄▃▄▃▄▃▃▃▄
train/mse,█▆▃▃▆▃▁▃▃▄▂▃▃▂▂▂▂▃▂▃▂▂▃▄▂▃▂▂▃▂▂▂▂▂▂▂▂▂▂▂
train/pcc,▁▅▅▄▂▄ ▂▃▄▅▄▄▆▆▇▅▄▇▆▆▆▇▇▆ █▆▅▄█▅▇▇██████
valid/loss,██▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▂▃▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
valid/mae,██▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁

0,1
test/loss,6.80047
test/mae,2.15473
test/mse,7.91895
test/pcc,0.72855
train/loss,8.23386
train/mae,2.86947
train/mse,8.23386
train/pcc,
valid/loss,16.22266
valid/mae,2.83692


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=6, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=6, learning_rate=5e-05
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
train/loss,█▃▄▃▃▃▁▃▂▃▃▂▂▂▂▂▂▁▂▂▂▁▂▂▂▁▁▁▁
train/mae,█▄▅▄▄▄▁▄▃▄▃▃▃▃▂▂▃▂▃▂▂▂▃▃▂▁▂▂▂
train/mse,█▃▄▃▃▃▁▃▂▃▃▂▂▂▂▂▂▁▂▂▂▁▂▂▂▁▁▁▁
train/pcc,▂▃▁▄▄▅ ▅▆▆█▆█▇▇▆▇█▆▇▇▇▇▇▇ █▇█
valid/loss,█▄▄▄▃▂▂▂▂▃▂▁▂▁▂▂▂▂▁▁▁▁▁▂▁▂▁▁▂
valid/mae,█▅▅▅▅▄▄▃▃▄▂▂▂▂▂▃▂▁▁▁▁▁▁▂▁▂▂▁▁
valid/mse,█▄▄▄▃▂▂▂▂▃▂▁▂▁▂▂▂▂▁▁▁▁▁▂▁▂▁▁▂
valid/pcc,▁▂▃▄▅▇▇▇▇▇████▇█████████████▇

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=6, learning_rate=1e-05
Validation metrics: {'loss': 14.545965909957886, 'mae': 2.526681423187256, 'mse': 14.545966148376465, 'pcc': 0.8208537101745605}
Test metrics: {'loss': 8.79493522644043, 'mae': 2.3661155700683594, 'mse': 10.302852630615234, 'pcc': 0.6499218344688416}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▅█▄▆▃▃▁▇▃▂▄▃▂▂▃▂▂▂▃▃▂▂▂▂▂▁▃▂▂▂▂▂▂▂▂▂▂▂▂▂
train/mae,██▆▆▅▅▃▇▅▄▆▅▅▄▅▄▄▅▅▅▄▄▅▅▅▁▅▄▅▃▄▄▄▃▃▃▃▄▃▃
train/mse,▅█▄▆▃▃▁▇▃▂▄▃▂▂▃▂▂▂▃▃▂▂▂▂▂▁▃▂▂▂▂▂▂▂▂▂▂▂▂▂
train/pcc,▄▂▂▁▂▅ ▂▃▄▂▂▅▅▅▇▆▃▄▄▃▅▅▅▅ ▄▇▆▆▇▅▆▆█▇█▇▇▆
valid/loss,██▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▂▁▁
valid/mae,██▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▁▁▁▂▁▁

0,1
test/loss,8.79494
test/mae,2.36612
test/mse,10.30285
test/pcc,0.64992
train/loss,0.98853
train/mae,0.99425
train/mse,0.98853
train/pcc,
valid/loss,14.54597
valid/mae,2.52668


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=8, learning_rate=0.001
Validation metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}
Test metrics: {'loss': nan, 'mae': nan, 'mse': nan, 'pcc': nan}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,
test/mae,
test/mse,
test/pcc,
train/loss,
train/mae,
train/mse,
train/pcc,
valid/loss,
valid/mae,


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=8, learning_rate=5e-05
Validation metrics: {'loss': 15.385756492614746, 'mae': 2.396526336669922, 'mse': 15.385757446289062, 'pcc': 0.7924265265464783}
Test metrics: {'loss': 7.830407118797302, 'mae': 2.1791698932647705, 'mse': 9.519826889038086, 'pcc': 0.6747744083404541}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▇▅▄█▄▄▁▃▃▃▃▄▃▃▂▂▂▂▂▂▂▃▂▂▃▁▂▂▂▂▂▂▄▄▂▂▃▂▂▁
train/mae,█▇▅▇▅▆▁▄▄▄▅▅▄▄▃▃▃▄▃▄▂▄▃▃▄▂▂▃▂▂▃▃▆▅▃▄▄▃▃▂
train/mse,▇▅▄█▄▄▁▃▃▃▃▄▃▃▂▂▂▂▂▂▂▃▂▂▃▁▂▂▂▂▂▂▄▄▂▂▃▂▂▁
train/pcc,▅▂▁▂▃▅ ▆▆▆▆▆▄▅▅▇▇▇▇▇█▆▇▇█ ▆▇▇▇█▇▂▄▆▆▇▇▇█
valid/loss,█▇▇▅▄▃▄▂▂▃▂▂▄▁▂▂▂▂▂▁▂▃▁▂▂▂▂▁▂▂▂▂▄▅▃▂▂▂▂▂
valid/mae,██▇▇▆▅▆▄▄▅▃▄▅▂▃▃▂▂▂▁▂▄▂▂▂▁▂▂▁▁▁▁▇▆▅▃▃▃▂▁

0,1
test/loss,7.83041
test/mae,2.17917
test/mse,9.51983
test/pcc,0.67477
train/loss,0.00144
train/mae,0.03791
train/mse,0.00144
train/pcc,
valid/loss,15.38576
valid/mae,2.39653


  0%|          | 0/50 [00:00<?, ?it/s]

Hyperparameters: hidden_size=512, mpnn_n_layers=8, learning_rate=1e-05
Validation metrics: {'loss': 13.170033693313599, 'mae': 2.3513410091400146, 'mse': 13.170034408569336, 'pcc': 0.8291530013084412}
Test metrics: {'loss': 7.997323608398437, 'mae': 2.202730178833008, 'mse': 9.168558120727539, 'pcc': 0.681717038154602}


VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
test/loss,▁
test/mae,▁
test/mse,▁
test/pcc,▁
train/loss,▄▄▃▂▃▂█▃▂▂▃▂▂▂▂▂▂▁▁▂▂▂▂▂▃▁▁▁▁▁▁▁▁▂▁▁▁▁▂▁
train/mae,▄▄▃▃▃▃█▃▃▃▄▃▂▃▃▂▂▂▂▂▂▂▂▂▃▂▂▂▂▂▂▂▁▂▁▁▁▂▂▂
train/mse,▄▄▃▂▃▂█▃▂▂▃▂▂▂▂▂▂▁▁▂▂▂▂▂▃▁▁▁▁▁▁▁▁▂▁▁▁▁▂▁
train/pcc,▂▁▂▄▃▂ ▄▄▄▂▃▄▅▄▃▅▆▆▅▆▇▆▅█ ▄▅▆▇▅▆█▆▇▇▇▇▇▆
valid/loss,█▇▄▄▄▄▃▄▄▃▃▃▃▃▃▂▂▂▃▂▂▂▂▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁
valid/mae,█▆▄▄▄▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▂▂▂▂▁▂▁▁▁▁▁▁▁▁▁

0,1
test/loss,7.99732
test/mae,2.20273
test/mse,9.16856
test/pcc,0.68172
train/loss,0.18404
train/mae,0.429
train/mse,0.18404
train/pcc,
valid/loss,13.17003
valid/mae,2.35134


## Task 8. Train GINE (2 points).
1. Tune hyperparameters of a GNN with `GINELayer` as MPNN layer to obtain at most 2.0 MAE on the validation set. You can modify the GNN/MPNN architecture, so it uses some regularization tricks like dropout or batch norm. Don't change the validation batch size. If your validation MAE is in (2.0, 2.5], you can obtain 1 point.
2. Report the obtained MAE on the validation and test set (only the former need to be lower than 2.0 MAE).
3. Provide the link to the final run: [your link]

In [1]:
from torchmetrics import MeanAbsoluteError as MAE
from torchmetrics import MeanSquaredError as MSE
from torchmetrics import PearsonCorrCoef as PCC
from itertools import product

# Define the hyperparameter grid
grid_params = {
    "hidden_size": [128, 256, 512],
    "mpnn_n_layers": [4, 6, 8],
    "learning_rate": [1e-3, 5e-5, 1e-5],
}

# Define metrics
metrics = {
    "mae": MAE(),
    "mse": MSE(),
    "pcc": PCC(),
}

# Perform grid search
for params in product(*grid_params.values()):
    hidden_size, mpnn_n_layers, learning_rate = params

    # Create model with current hyperparameters
    model = GNN(
        node_features_size=node_featurizer.feat_size(),
        edge_features_size=edge_featurizer.feat_size(),
        hidden_size=hidden_size,
        output_size=1,
        mpnn_layer_cls=GINELayer,
        mpnn_n_layers=mpnn_n_layers,
        readout_cls=AttentionReadout,
        mpnn_layer_kwargs={}
    )

    # Create Trainer with current hyperparameters
    trainer = Trainer(
        run_dir=f"experiments/hidden_{hidden_size}_layers_{mpnn_n_layers}_lr_{learning_rate}",
        train_dataset=train,
        valid_dataset=valid,
        train_metrics=metrics,
        valid_metrics=metrics,
        train_batch_size=32,
        model=model,
        logger=WandbLogger(
            logdir=f"runs/mpnn/hidden_{hidden_size}_layers_{mpnn_n_layers}_lr_{learning_rate}",
            project_name="mldd23",
            experiment_name=f"sage_{get_time_stamp()}",
        ),
        optimizer_kwargs={"lr": learning_rate},
        n_epochs=50,
        device="cpu",
        valid_every_n_epochs=1,
    )

    # Train and test the model
    valid_metrics = trainer.train()
    test_metrics = trainer.test(test)

    # Print metrics for each hyperparameter combination
    print(f"Hyperparameters: hidden_size={hidden_size}, mpnn_n_layers={mpnn_n_layers}, learning_rate={learning_rate}")
    print(f"Validation metrics: {valid_metrics}")
    print(f"Test metrics: {test_metrics}")

    # Close the trainer
    trainer.close()

ModuleNotFoundError: ignored

# Code optimization
Some pieces of code were written suboptimally. Your task is to slightly optimize them.

## Task 9. Optimize SumReadout (1 point).
`SumReadout` was written using `to_dense_embeddings` function which does some unecessary memory allocations and computations. Your task is to rewrite the method using code from a bare torch library. Hint: `torch.index_add`.

In [None]:
class OptimizedSumReadout(ReadoutBase):
    def forward(self,
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Attributes:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape.[batch_size, hidden_size]
        """
        adj_matrix = graph.adjacency_matrix().to(node_embeddings.device)

        # Calculate the sum of node embeddings using torch.index_add
        graph_embeddings = torch.zeros(graph.batch_size, node_embeddings.size(1), device=node_embeddings.device)

        # Cumulative sum of batch sizes to obtain indices for torch.index_add_
        cum_batch_nodes = torch.cat([torch.tensor([0], device=node_embeddings.device), graph.batch_num_nodes().cumsum(0)])

        # Perform the index_add operation using a while loop
        i = 0
        while i < len(cum_batch_nodes) - 1:
            start_idx, end_idx = cum_batch_nodes[i], cum_batch_nodes[i + 1]
            graph_embeddings[i] = torch.sum(node_embeddings[start_idx:end_idx], dim=0)
            i += 1

        return graph_embeddings


test_readout(OptimizedSumReadout, expected_sum_readout)


## Task 10. Optimize MeanReadout (1 point).
Your task is to rewrite the method using code from a bare torch library.

In [None]:
class OptimizedMeanReadout(ReadoutBase):
    def forward(self,
                node_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Attributes:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            graph_embeddings: graph embeddings of shape.[batch_size, hidden_size]
        """
        dj_matrix = graph.adjacency_matrix().to(node_embeddings.device)

        # Calculate the mean of node embeddings using torch.index_add
        graph_embeddings = torch.zeros(graph.batch_size, node_embeddings.size(1), device=node_embeddings.device)

        # Cumulative sum of batch sizes to obtain indices for torch.index_add_
        cum_batch_nodes = torch.cat([torch.tensor([0], device=node_embeddings.device), graph.batch_num_nodes().cumsum(0)])

        # Perform the index_add operation using a while loop
        i = 0
        while i < len(cum_batch_nodes) - 1:
            start_idx, end_idx = cum_batch_nodes[i], cum_batch_nodes[i + 1]
            num_elements = end_idx - start_idx
            graph_embeddings[i] = torch.sum(node_embeddings[start_idx:end_idx], dim=0) / num_elements
            i += 1

        return graph_embeddings

        # Create an index tensor to be used with torch.index_add()


test_readout(OptimizedMeanReadout, expected_mean_readout)

## Task 11. Optimize SimpleMPNNLayer (1 point).
We can make our implementations of `SimpleMPNNLayer` layer (and basically any other MPNN layer) slightly faster by:
- reducing the costs of the message embedding (in the case of `SimpleMPNNLayer`, it's application of `self.linear_3`) from $O(m)$ to $O(n)$, where $m$ is the number of edges in the graph and $n$ is the number of nodes.
- removing quite expensive `to_dense_batch` call.

Your task is to apply the above optimizations.

In [None]:
class OptimizedSimpleMPNNLayer(MPNNLayerBase):
    def __init__(self, hidden_size: int):
        super().__init__()
        self.hidden_size = hidden_size
        self.linear_1 = nn.Linear(hidden_size, hidden_size)
        self.linear_2 = nn.Linear(hidden_size, hidden_size)
        self.linear_3 = nn.Linear(hidden_size, hidden_size)

    def forward(self,
                node_embeddings: torch.Tensor,
                edge_embeddings: torch.Tensor,
                graph: dgl.DGLGraph) -> torch.Tensor:
        """
        Arguments:
            node_embeddings: node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
            edge_embeddings: edge embeddings in a sparse format, i.e. [total_num_edges, hidden_size]
            graph: a DGLGraph that contains the graph structure
        Returns:
            node_embeddings: updated node embeddings in a sparse format, i.e. [total_num_nodes, hidden_size]
        """
        start_nodes, end_nodes = graph.edges(order='srcdst')  # using this `order` value sorts the `start_nodes`
        start_nodes = start_nodes.to(torch.int64)

        messages = self.linear_3(node_embeddings[end_nodes])  # W_3x_j

        # Compute aggregated messages without using to_dense_batch
        aggregated_message = torch.zeros_like(node_embeddings)
        aggregated_message.scatter_add_(0, start_nodes.unsqueeze(1).expand_as(messages), messages)

        aggregated_message = self.linear_2(aggregated_message)  # W_2\sum_{j\in N(i)} W_3x_j
        node_embeddings = self.linear_1(node_embeddings) + aggregated_message  # W_1x_i + W_2\sum_{j\in N(i)} W_3x_j
        return node_embeddings

test_mpnn_layer(OptimizedSimpleMPNNLayer, expected_simple_mpnn_output)