# Setup of the Notebook

In [90]:
import torch
from torch import nn
from torch_geometric.datasets import MoleculeNet
from torch_geometric.data import DataLoader
import torch.nn.functional as F

from torch_geometric.nn import GCNConv
from torch_geometric.nn import GINConv
from torch_geometric import nn as gnn

from sklearn.metrics import roc_auc_score
import numpy as np
from time import time

from torch.utils.tensorboard import SummaryWriter

In [91]:
# TODO: Modify for the server
# Run on the GPU whenever CUDA is usable, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

The code below is copy-pasted from the PyTorch Geometric source code. It is needed because the installed PyTorch Geometric version does not ship these model classes, even though they appear in the online documentation.

In [92]:
from typing import Optional, Callable, List
from torch_geometric.typing import Adj

import copy

import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import ModuleList, Sequential, Linear, BatchNorm1d, ReLU

from torch_geometric.nn.conv import GCNConv, SAGEConv, GINConv, GATConv
from torch_geometric.nn.models.jumping_knowledge import JumpingKnowledge


class BasicGNN(torch.nn.Module):
    r"""Common scaffolding for stacked message-passing GNN models.

    This base class creates an empty :attr:`convs` list and applies, per
    layer, ``conv -> norm -> act -> dropout``. Subclasses are responsible
    for appending one convolution per layer to :attr:`convs`.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Size of each hidden and output sample.
        num_layers (int): Number of message passing layers.
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (Callable, optional): The non-linear activation function to use.
            (default: :meth:`torch.nn.ReLU(inplace=True)`)
        norm (torch.nn.Module, optional): The normalization operator to use;
            it is deep-copied once per layer. (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode. :obj:`"last"` returns
            the final layer's output unchanged; any other value is handed to
            :class:`JumpingKnowledge` (which documents :obj:`"cat"`,
            :obj:`"max"` and :obj:`"lstm"`). (default: :obj:`"last"`)
    """
    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0,
                 act: Optional[Callable] = ReLU(inplace=True),
                 norm: Optional[torch.nn.Module] = None, jk: str = 'last'):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # In "cat" mode the output width is num_layers * hidden_channels.
        self.out_channels = (num_layers * hidden_channels
                             if jk == 'cat' else hidden_channels)
        self.num_layers = num_layers
        self.dropout = dropout
        self.act = act

        # Filled in by the concrete subclass.
        self.convs = ModuleList()

        self.jk = (JumpingKnowledge(jk, hidden_channels, num_layers)
                   if jk != 'last' else None)

        # One independent copy of the normalization module per layer.
        self.norms = (
            ModuleList([copy.deepcopy(norm) for _ in range(num_layers)])
            if norm is not None else None)

    def reset_parameters(self):
        """Reset the parameters of all convolutions, norms and the JK module."""
        resettable = list(self.convs)
        if self.norms:
            resettable.extend(self.norms)
        if self.jk is not None:
            resettable.append(self.jk)
        for module in resettable:
            module.reset_parameters()

    def forward(self, x: Tensor, edge_index: Adj, *args, **kwargs) -> Tensor:
        """Run all layers; extra positional/keyword args go to every conv."""
        collect = self.jk is not None
        layer_outputs: List[Tensor] = []
        for layer in range(self.num_layers):
            x = self.convs[layer](x, edge_index, *args, **kwargs)
            if self.norms is not None:
                x = self.norms[layer](x)
            if self.act is not None:
                x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            if collect:
                layer_outputs.append(x)
        if collect:
            return self.jk(layer_outputs)
        return x

    def __repr__(self) -> str:
        cls_name = self.__class__.__name__
        return (f'{cls_name}({self.in_channels}, '
                f'{self.out_channels}, num_layers={self.num_layers})')


class GCN(BasicGNN):
    r"""The Graph Neural Network from the `"Semi-supervised
    Classification with Graph Convolutional Networks"
    <https://arxiv.org/abs/1609.02907>`_ paper, built from stacked
    :class:`~torch_geometric.nn.conv.GCNConv` layers.

    .. note::
        A later cell of this notebook defines another class that is also
        named ``GCN``; after that cell runs, the name refers to that class.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        num_layers (int): Number of GNN layers.
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (Callable, optional): The non-linear activation function to use.
            (default: :meth:`torch.nn.ReLU(inplace=True)`)
        norm (torch.nn.Module, optional): The normalization operator to use.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`"last"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GCNConv`.
    """
    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0,
                 act: Optional[Callable] = ReLU(inplace=True),
                 norm: Optional[torch.nn.Module] = None, jk: str = 'last',
                 **kwargs):
        super().__init__(in_channels, hidden_channels, num_layers, dropout,
                         act, norm, jk)

        # The first layer consumes the raw features; all later layers map
        # hidden -> hidden. A first layer is always created.
        input_widths = [in_channels] + [hidden_channels] * (num_layers - 1)
        for width in input_widths:
            self.convs.append(GCNConv(width, hidden_channels, **kwargs))



class GraphSAGE(BasicGNN):
    r"""The Graph Neural Network from the `"Inductive Representation Learning
    on Large Graphs" <https://arxiv.org/abs/1706.02216>`_ paper, built from
    stacked :class:`~torch_geometric.nn.SAGEConv` layers.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        num_layers (int): Number of GNN layers.
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (Callable, optional): The non-linear activation function to use.
            (default: :meth:`torch.nn.ReLU(inplace=True)`)
        norm (torch.nn.Module, optional): The normalization operator to use.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`"last"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.SAGEConv`.
    """
    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0,
                 act: Optional[Callable] = ReLU(inplace=True),
                 norm: Optional[torch.nn.Module] = None, jk: str = 'last',
                 **kwargs):
        super().__init__(in_channels, hidden_channels, num_layers, dropout,
                         act, norm, jk)

        # The first layer consumes the raw features; all later layers map
        # hidden -> hidden. A first layer is always created.
        input_widths = [in_channels] + [hidden_channels] * (num_layers - 1)
        for width in input_widths:
            self.convs.append(SAGEConv(width, hidden_channels, **kwargs))



class GIN(BasicGNN):
    r"""The Graph Neural Network from the `"How Powerful are Graph Neural
    Networks?" <https://arxiv.org/abs/1810.00826>`_ paper, built from stacked
    :class:`~torch_geometric.nn.GINConv` layers, each wrapping the small MLP
    returned by :meth:`MLP`.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality.
        num_layers (int): Number of GNN layers.
        dropout (float, optional): Dropout probability. (default: :obj:`0.`)
        act (Callable, optional): The non-linear activation function to use.
            (default: :meth:`torch.nn.ReLU(inplace=True)`)
        norm (torch.nn.Module, optional): The normalization operator to use.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`"last"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GINConv`.
    """
    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0,
                 act: Optional[Callable] = ReLU(inplace=True),
                 norm: Optional[torch.nn.Module] = None, jk: str = 'last',
                 **kwargs):
        super().__init__(in_channels, hidden_channels, num_layers, dropout,
                         act, norm, jk)

        # The first layer consumes the raw features; all later layers map
        # hidden -> hidden. A first layer is always created.
        input_widths = [in_channels] + [hidden_channels] * (num_layers - 1)
        for width in input_widths:
            update_net = GIN.MLP(width, hidden_channels)
            self.convs.append(GINConv(update_net, **kwargs))

    @staticmethod
    def MLP(in_channels: int, out_channels: int) -> torch.nn.Module:
        """Build the ``Linear -> BatchNorm1d -> ReLU -> Linear`` update net."""
        stages = [
            Linear(in_channels, out_channels),
            BatchNorm1d(out_channels),
            ReLU(inplace=True),
            Linear(out_channels, out_channels),
        ]
        return Sequential(*stages)



class GAT(BasicGNN):
    r"""The Graph Neural Network from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper, using the
    :class:`~torch_geometric.nn.GATConv` operator for message passing.

    Args:
        in_channels (int): Size of each input sample.
        hidden_channels (int): Hidden node feature dimensionality. Must be
            divisible by ``heads`` when that keyword argument is given.
        num_layers (int): Number of GNN layers.
        dropout (float, optional): Dropout probability. It is used both for
            the feature dropout applied after every layer and as the
            ``dropout`` argument of every :class:`GATConv`.
            (default: :obj:`0.`)
        act (Callable, optional): The non-linear activation function to use.
            (default: :meth:`torch.nn.ReLU(inplace=True)`)
        norm (torch.nn.Module, optional): The normalization operator to use.
            (default: :obj:`None`)
        jk (str, optional): The Jumping Knowledge mode
            (:obj:`"last"`, :obj:`"cat"`, :obj:`"max"`, :obj:`"lstm"`).
            (default: :obj:`"last"`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.GATConv`. A ``concat`` entry is
            discarded so that the layer's own default is always used.
    """
    def __init__(self, in_channels: int, hidden_channels: int, num_layers: int,
                 dropout: float = 0.0,
                 act: Optional[Callable] = ReLU(inplace=True),
                 norm: Optional[torch.nn.Module] = None, jk: str = 'last',
                 **kwargs):
        super().__init__(in_channels, hidden_channels, num_layers, dropout,
                         act, norm, jk)

        # A caller-supplied `concat` is dropped on purpose: the per-head width
        # computed below assumes the heads are concatenated.
        kwargs.pop('concat', None)

        # Each head produces hidden_channels // heads features so that the
        # concatenated output is exactly hidden_channels wide.
        heads = kwargs.get('heads', 1)
        assert hidden_channels % heads == 0
        out_channels = hidden_channels // heads

        # Fix: the original passed `dropout` only to the first GATConv, so all
        # deeper layers ran with the layer's default dropout instead of the
        # requested one. Every layer now receives it.
        self.convs.append(
            GATConv(in_channels, out_channels, dropout=dropout, **kwargs))
        for _ in range(1, num_layers):
            self.convs.append(
                GATConv(hidden_channels, out_channels, dropout=dropout,
                        **kwargs))

# Data, Model and Training

In [93]:
dataset = MoleculeNet(root="data/", name="Tox21")

# Assumed (not verified) to match the column order of the labels: it is the
# order found in the raw Tox21 file that gets downloaded automatically.
dataset_label_names = [
    "NR-AR",
    "NR-AR-LBD",
    "NR-AhR",
    "NR-Aromatase",
    "NR-ER",
    "NR-ER-LBD",
    "NR-PPAR-gamma",
    "SR-ARE",
    "SR-ATAD5",
    "SR-HSE",
    "SR-MMP",
    "SR-p53",
]

# TODO: Do filtering

In [94]:
# Fixed seed so that the shuffle below (and hence the split) is reproducible.
torch.manual_seed(1337)

# Number of graphs that go into the training split; the rest is held out.
N_TRAIN_GRAPHS = 6500

# TODO: Do more involved splitting (Cluster Cross Validation etc... -> CHEMBL like)
# The shuffled view gets its own name instead of rebinding `dataset`:
# re-running this cell then always permutes the original dataset and
# reproduces the same split, rather than shuffling an already shuffled copy.
shuffled_dataset = dataset.shuffle()
train_dataset = shuffled_dataset[:N_TRAIN_GRAPHS]
test_dataset = shuffled_dataset[N_TRAIN_GRAPHS:]

In [95]:
# Mini-batches are reshuffled every epoch for training only; the held-out
# loader keeps a fixed order.
train_loader = DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=64,
    shuffle=False,
)

In [96]:
class MLP(nn.Module):
    """Fully connected head that outputs per-class probabilities.

    The network is ``n_layers`` blocks of ``Dropout -> Linear``, with
    ``activation`` inserted between consecutive blocks. The final linear
    output is passed through a sigmoid in :meth:`forward`.

    Args:
        n_layers (int): Number of linear layers; must be at least 1.
        input_dim (int): Width of the input features.
        hidden_dim (int): Width of every intermediate layer (unused when
            ``n_layers == 1``).
        output_dim (int): Number of outputs.
        dropout (float): Dropout probability applied before every linear layer.
        activation (nn.Module): Activation placed between linear layers. The
            same instance is reused at every position.

    Raises:
        ValueError: If ``n_layers`` is smaller than 1.
    """
    def __init__(self, n_layers, input_dim, hidden_dim, output_dim, dropout, activation):
        super().__init__()

        # Previously n_layers <= 0 silently produced a two-layer network.
        if n_layers < 1:
            raise ValueError(f"n_layers must be at least 1, got {n_layers}")

        # Layer widths: input -> (n_layers - 1) x hidden -> output.
        widths = [input_dim] + [hidden_dim] * (n_layers - 1) + [output_dim]

        modules = []
        for position, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            if position > 0:
                # No activation in front of the very first block.
                modules.append(activation)
            modules.append(nn.Dropout(p=dropout))
            modules.append(nn.Linear(fan_in, fan_out))

        self.net = nn.Sequential(*modules)

    def forward(self, x):
        """Return sigmoid probabilities, one column per output."""
        return torch.sigmoid(self.net(x))





class GCN(nn.Module):
    """Graph-level multi-label classifier: GIN encoder, mean pooling, MLP head.

    NOTE(review): despite its name this model is built on ``GIN``, and the
    name ``GCN`` is also used by the ``BasicGNN`` subclass defined earlier in
    this notebook, which this definition replaces. The name is kept because
    ``train_config`` instantiates ``GCN``.

    Args:
        n_hidden_channels (int): Width of the node embeddings and of the
            head's hidden layers.
        n_graph_layers (int): Number of GIN layers.
        n_graph_dropout (float): Dropout probability inside the GIN encoder.
        n_linear_layers (int): Number of linear layers in the MLP head.
        n_linear_dropout (float): Dropout probability inside the MLP head.
        in_channels (int, optional): Node feature width. Defaults to
            ``dataset.num_node_features`` of the global ``dataset``.
        out_channels (int, optional): Number of outputs. Defaults to
            ``dataset.num_classes`` of the global ``dataset``.
    """
    def __init__(self, n_hidden_channels, n_graph_layers, n_graph_dropout, n_linear_layers, n_linear_dropout,
                 in_channels=None, out_channels=None):
        # Zero-argument super(): does not depend on which class the name
        # `GCN` currently refers to in the notebook namespace.
        super().__init__()

        # Fall back to the global dataset only when no explicit size is given,
        # so the model can also be built for other datasets.
        if in_channels is None:
            in_channels = dataset.num_node_features
        if out_channels is None:
            out_channels = dataset.num_classes

        self.gin = GIN(in_channels, n_hidden_channels, n_graph_layers, dropout=n_graph_dropout)

        self.head = MLP(n_linear_layers, n_hidden_channels, n_hidden_channels, out_channels, n_linear_dropout, nn.ReLU())

    def forward(self, x, edge_index, batch):
        """Return one row of per-class probabilities for every graph in the batch."""
        # 1. Obtain node embeddings
        x = self.gin(x, edge_index)

        # 2. Readout layer
        x = gnn.global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier (the head already applies the sigmoid)
        x = self.head(x)

        return x

In [97]:
def train(model, optimizer, loader, epoch, logger: SummaryWriter):
    """Train ``model`` for one epoch and log the loss of every mini-batch.

    Args:
        model: Module called as ``model(x, edge_index, batch)``; its output is
            fed to binary cross entropy, so it must contain probabilities.
        optimizer: Optimizer that updates the parameters of ``model``.
        loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.
        epoch: 1-based epoch number, used to compute the global logging step.
        logger: TensorBoard writer that receives the per-batch loss.
    """
    model.train()

    steps_per_epoch = len(loader)
    # Global step of the first mini-batch of this epoch (steps start at 1).
    first_step = steps_per_epoch * (epoch - 1) + 1

    for step_offset, data in enumerate(loader):
        x = data.x.float().to(device)
        edge_index = data.edge_index.to(device)
        batch = data.batch.to(device)

        predictions = model(x, edge_index, batch)

        targets = data.y.to(device)
        # Missing labels are NaN; remember where the real labels are and put
        # a dummy value in the gaps so that the loss stays finite there.
        labelled = ~targets.isnan()
        targets = torch.nan_to_num(targets, 0.5)

        # Entries without a label are zeroed out; the mean is still taken over
        # all entries (the original author notes this follows the DeepTox paper).
        per_entry_loss = F.binary_cross_entropy(predictions, targets, reduction="none")
        loss = (per_entry_loss * labelled).mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        logger.add_scalar(
            "BCR_MultiTask",
            loss.detach().cpu().numpy(),
            global_step=first_step + step_offset,
        )

        

def test(model, loader, epoch:int, logger: SummaryWriter, run_type = "test"):
    """Evaluate ``model`` on ``loader`` and log one ROC-AUC value per task.

    Predictions and labels are gathered over the whole loader; entries whose
    label is NaN are ignored for the corresponding task. Each AUC is written
    under the tag ``AUC-ROC/<run_type>/<task name>``.

    Args:
        model: Module called as ``model(x, edge_index, batch)``.
        loader: Iterable of mini-batches exposing ``x``, ``edge_index``,
            ``batch`` and ``y``.
        epoch: Value used as the TensorBoard ``global_step``.
        logger: TensorBoard writer that receives the AUC values.
        run_type: Name of the split, used inside the TensorBoard tag.

    Returns:
        The constant tuple ``(0.5, 0.5)``. This is a placeholder and not a
        computed metric; the results are only available through ``logger``.
    """
    model.eval()

    n_tasks = dataset.num_classes
    labels_per_task = [[] for _ in range(n_tasks)]
    probs_per_task = [[] for _ in range(n_tasks)]

    # Evaluation needs no gradients; tracking them only costs memory and time.
    with torch.no_grad():
        for data in loader:  # Iterate in batches over the training/test dataset.
            x, edge_index, batch = data.x.float().to(device), data.edge_index.to(device), data.batch.to(device)

            # Copy predictions and labels to the host once per batch instead
            # of once per task.
            probs_batch = model(x, edge_index, batch).cpu().numpy()
            targets_batch = data.y.cpu().numpy()

            for task in range(n_tasks):
                labelled = ~np.isnan(targets_batch[:, task])
                labels_per_task[task].append(targets_batch[labelled, task].astype(np.int64))
                probs_per_task[task].append(probs_batch[labelled, task])

    for task, (label_chunks, prob_chunks) in enumerate(zip(labels_per_task, probs_per_task)):
        if not label_chunks:
            # Empty loader: there is nothing to score.
            continue
        labels = np.concatenate(label_chunks)
        probs = np.concatenate(prob_chunks)
        if np.unique(labels).size < 2:
            # ROC-AUC is undefined unless both classes occur; skip this task
            # rather than aborting the whole run.
            continue
        logger.add_scalar(f"AUC-ROC/{run_type}/{dataset_label_names[task]}", roc_auc_score(labels, probs), global_step=epoch)

    return 0.5, 0.5  # Placeholder values; callers currently ignore them.

In [98]:
def train_config(
    hidden_channels = 128,
    head_depth = 3, 
    base_depth = 5, 
    base_dropout = 0.5, 
    head_dropout = 0.5, 
    lr = 1e-2, 
    epochs = 100, 
    config_comment = ""
    ):
    """Train and evaluate one model configuration, logging to TensorBoard.

    The model is evaluated once before training (logged as epoch 0) and again
    on both loaders after every epoch.

    Args:
        hidden_channels: Width of the graph encoder and of the head's hidden layers.
        head_depth: Number of linear layers in the classification head.
        base_depth: Number of graph layers in the encoder.
        base_dropout: Dropout probability inside the graph encoder.
        head_dropout: Dropout probability inside the classification head.
        lr: Learning rate of the Adam optimizer.
        epochs: Number of training epochs.
        config_comment: Passed to ``SummaryWriter`` as its ``comment``.
    """
    model = GCN(
        n_hidden_channels = hidden_channels,
        n_graph_layers = base_depth,
        n_graph_dropout = base_dropout,
        n_linear_layers = head_depth,
        n_linear_dropout = head_dropout
    ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    logger = SummaryWriter(comment = config_comment)

    # The writer is closed even if training fails. Without this, every
    # configuration of a grid search leaves an open writer behind.
    try:
        # Baseline of the untrained model, logged as epoch 0.
        test(model, train_loader, 0, logger, run_type="train")
        test(model, test_loader, 0, logger, run_type="validation")

        for epoch in range(1, epochs + 1):

            train(model, optimizer, train_loader, epoch, logger)
            test(model, train_loader, epoch, logger, run_type="train")
            test(model, test_loader, epoch, logger, run_type="validation")

            logger.flush()
    finally:
        logger.close()

# Setup Simple GridSearch

In [99]:
import itertools

def dict_product(dicts):
    """Yield every combination of a ``{key: list_of_values}`` mapping.

    Each yielded item is a dict that maps every key of ``dicts`` to one of
    its candidate values; the combinations follow ``itertools.product`` order.
    """
    keys = tuple(dicts)
    value_combinations = itertools.product(*dicts.values())
    return (dict(zip(keys, combination)) for combination in value_combinations)

In [100]:
# Hyper-parameter grid for the search: every key is read by search_configs and
# passed to train_config under the same name; every list holds the candidate
# values. The full grid has 3 * 4 * 3 * 2 * 2 * 2 = 288 combinations.
search_grid = {
    "hidden_channels": [64, 256, 1028],  # NOTE(review): 1028 may be a typo for 1024 — confirm
    "head_depth": [1,2,3,4],
    "base_depth": [3,5,10],
    "base_dropout": [0.5, 0.2],
    "head_dropout": [0.5, 0.2],
    "lr": [1e-2, 1e-3]
}

In [103]:
def search_configs(search_grid, randomly_try_n = -1):
    """Train the configurations of a hyper-parameter grid, fully or sampled.

    Args:
        search_grid: Mapping from ``train_config`` keyword names
            (``hidden_channels``, ``head_depth``, ``base_depth``,
            ``base_dropout``, ``head_dropout``, ``lr``) to lists of candidates.
        randomly_try_n: ``-1`` trains every configuration of the grid. Any
            other value trains that many distinct, randomly chosen
            configurations (capped at the size of the grid).
    """
    configurations = list(dict_product(search_grid))
    print(f"Total number of Grid-Search configurations: {len(configurations)}")

    if randomly_try_n == -1:
        do_indices = range(len(configurations))
    else:
        # Sample WITHOUT replacement: np.random.choice draws with replacement
        # by default, which could train the same configuration several times.
        # The cap keeps the request valid when it exceeds the grid size.
        n_samples = min(randomly_try_n, len(configurations))
        do_indices = np.random.choice(len(configurations), size=n_samples, replace=False)

    print(f"Number of configurations now being trained {len(do_indices)}")
    print("--------------------------------------------------------------------------------------------\n")

    for idx in do_indices:

        config = configurations[idx]

        # Compact form of the config, used as the TensorBoard run comment.
        config_str = str(config).replace("'","").replace(":", "-").replace(" ", "").replace("}", "").replace("_","").replace(",", "_").replace("{","_")

        print(f"Training config {config_str} ... ", end="")
        dt = time()

        train_config(
            hidden_channels = config["hidden_channels"],
            head_depth = config["head_depth"],
            base_depth =  config["base_depth"],
            base_dropout =  config["base_dropout"],
            head_dropout =  config["head_dropout"],
            lr =  config["lr"],
            config_comment = config_str
            )

        print(f"done (took {time() - dt:.2f}s)")

In [104]:
# Train 10 randomly chosen configurations of the grid.
search_configs(search_grid, randomly_try_n=10)

Total number of Grid-Search configurations: 288
done (took 308.88s)
Doing config _hiddenchannels-64_headdepth-2_basedepth-3_basedropout-0.5_headdropout-0.2_lr-0.001 ... 