In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger
from pytorch_lightning.utilities.model_summary import ModelSummary
from rdkit import Chem
from rdkit.Chem import AllChem, Draw
from rdkit.Chem.Draw import IPythonConsole
from sklearn.metrics import r2_score as sklearn_r2
from sklearn.metrics import root_mean_squared_error
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler
from torch.nn import BatchNorm1d, HuberLoss, L1Loss, Linear, ModuleList, MSELoss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch_geometric.nn import (
    GATv2Conv,
    GCNConv,
    GINConv,
    GraphNorm,
    PNAConv,
    SAGPooling,
    global_add_pool,
    global_max_pool,
    global_mean_pool,
)
from torcheval.metrics.functional import r2_score
from twinning import twin


In [2]:
%load_ext kedro.ipython

In [3]:
df = catalog.load("merged_database")
train_dataloader = catalog.load("sys_train_dataloader")
test_dataloader = catalog.load("sys_test_dataloader")

In [4]:
# Plot data split

# fig, axs = plt.subplots(1,2,sharey=True, tight_layout=True)
# n_bins = 30
# axs[0].hist(df_strain["MP"], bins=n_bins, color="b")
# axs[0].hist(df_stest["MP"], bins=n_bins, color="r")
# axs[1].hist(df_train["MP"], bins=n_bins, color="b")
# axs[1].hist(df_test["MP"], bins=n_bins, color="r")


In [5]:
# Compute the maximum in-degree in the training data.
from torch_geometric.utils import degree

max_degree = -1
for batch in train_dataloader:
    for i in range(batch.batch_size):
        data = batch[i]
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        max_degree = max(max_degree, int(d.max()))

# Compute the in-degree histogram tensor
deg = torch.zeros(max_degree + 1, dtype=torch.long)
for batch in train_dataloader:
    for i in range(batch.batch_size):
        data = batch[i]
        d = degree(data.edge_index[1], num_nodes=data.num_nodes, dtype=torch.long)
        deg += torch.bincount(d, minlength=deg.numel())

In [6]:
class GNN(torch.nn.Module):
    def __init__(self, model_type: str, num_layers: int, pooling_method:str = "add"):
        # Loading params
        super().__init__()
        self.pooling_method = pooling_method
        self.model_type = model_type
        self.num_layers = num_layers
        node_dim = 31
        edge_dim = 11
        hidden_size = 32
        dense_size = 16
        # PNA
        self.pan_list = ModuleList([])
        self.gn_list = ModuleList([])
        self.pan = PNAConv(
                in_channels=node_dim,
                out_channels = hidden_size,
                edge_dim=edge_dim,
                aggregators=["mean", "min", "max", "std"],
                scalers=["identity", "amplification", "attenuation"],
                deg=deg,
                towers=1)
        self.gn = GraphNorm(hidden_size)
        self.pool = SAGPooling(hidden_size, 0.5)
        for _ in range(self.num_layers):
            self.pan_list.append(PNAConv(
                in_channels=hidden_size,
                out_channels = hidden_size,
                edge_dim=edge_dim,
                aggregators=["mean", "min", "max", "std"],
                scalers=["identity", "amplification", "attenuation"],
                deg=deg,
                towers=1))
            self.gn_list.append(GraphNorm(hidden_size))
        # GAT
        num_heads = 3
        self.gat_list = ModuleList([])
        self.gat = GATv2Conv(
                in_channels=node_dim,
                out_channels = hidden_size,
                edge_dim=edge_dim,
                heads = num_heads,
                concat=False)
        for _ in range(self.num_layers):
            self.gat_list.append(GATv2Conv(
                in_channels=hidden_size,
                out_channels = hidden_size,
                edge_dim=edge_dim,
                heads = num_heads,
                concat=False))
        # GCN
        self.gcn_list = ModuleList([])
        self.gcn = GCNConv(in_channels=node_dim,
                           out_channels=hidden_size,
                           improved=True)
        for _ in range(self.num_layers):
            self.gcn_list.append(GCNConv(in_channels=hidden_size,
                           out_channels=hidden_size,
                           improved=True))
        # GIN
        self.gin_list = ModuleList([])
        self.linear_gin1 = Linear(node_dim, hidden_size)
        self.linear_gin2 = Linear(hidden_size, hidden_size)
        self.gin = GINConv(nn=self.linear_gin1)
        for _ in range(self.num_layers):
            self.gin_list.append(GINConv(nn=self.linear_gin2))
        # Linear layers
        self.linear1 = Linear(hidden_size, dense_size)
        self.linear2 = Linear(dense_size, int(dense_size / 2))
        self.linear3 = Linear(int(dense_size / 2), 1)

    def forward(self, x, edge_attr, edge_index, batch_index):
        # PNA
        if self.model_type == "PNA":
            x = self.pan(x, edge_index, edge_attr)
            x = self.gn(x, batch_index)
            x, edge_index, edge_attr, batch_index, _, _ = self.pool(
            x, edge_index, edge_attr, batch_index
        )
            for i in range(self.num_layers):
                x = self.pan_list[i](x, edge_index, edge_attr)
                x = self.gn_list[i](x, batch_index)
        # GAT
        if self.model_type == "GAT":
            x = self.gat(x, edge_index, edge_attr)
            x = self.gn(x, batch_index)
            for i in range(self.num_layers):
                x = self.gat_list[i](x, edge_index, edge_attr)
                x = self.gn_list[i](x, batch_index)
        # GCN
        if self.model_type == "GCN":
            x = self.gcn(x, edge_index)
            x = self.gn(x, batch_index)
            for i in range(self.num_layers):
                x = self.gcn_list[i](x, edge_index)
                x = self.gn_list[i](x, batch_index)
        # GIN
        if self.model_type == "GIN":
            x = self.gin(x, edge_index)
            x = self.gn(x, batch_index)
            for i in range(self.num_layers):
                x = self.gin_list[i](x, edge_index)
                x = self.gn_list[i](x, batch_index)
        # Output block
        if self.pooling_method == "add":
            x = global_add_pool(x, batch_index)
        elif self.pooling_method =="mean":
            x = global_mean_pool(x, batch_index)
        elif self.pooling_method =="max":
            x = global_max_pool(x, batch_index)
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.2)
        x = torch.relu(self.linear2(x))
        x = self.linear3(x)
        return x

In [7]:
# model = GNN("PNA", num_layers=0, pooling_method="max")
# model(batch.x, batch.edge_attr, batch.edge_index, batch.batch)

In [12]:
import lightning as L


class GNN_L(L.LightningModule):
    def __init__(self, model):
        super().__init__()
        self.model = model
        self.lr = 0.001
        self.loss_fn = MSELoss()

    def forward(self, x, edge_attr, edge_index, batch_index):
        return self.model(
            x.float(), edge_attr.float(), edge_index, batch_index
        )

    def training_step(self, batch, batch_nb):
        preds = self(
            batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch
        ).squeeze()
        target = batch.y.float()
        loss = self.loss_fn(preds, target)
        r2 = r2_score(preds, target)
        rmse = root_mean_squared_error(target.numpy(), preds.detach().numpy())
        self.log("r2", r2)
        self.log("rmse", rmse)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_nb):
        preds = self(
            batch.x.float(), batch.edge_attr.float(), batch.edge_index, batch.batch
        ).squeeze()
        target = batch.y.float()
        val_r2 = r2_score(preds, target)
        val_rmse = root_mean_squared_error(target.numpy(), preds.detach().numpy())
        self.log("val_r2", val_r2)
        self.log("val_rmse", val_rmse)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=0.98
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

In [13]:
# Trainer
def train_model(train_dataloader, test_dataloader, model_type, num_layers, pooling_method):
    L.seed_everything(42)
    model = GNN_L(GNN(model_type, num_layers=num_layers, pooling_method=pooling_method))
    early_stopping = EarlyStopping("val_rmse", patience=5)
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint_callback = ModelCheckpoint(filename="{epoch}-{loss:.2f}-{rmse:.2f}",
                                          monitor="val_rmse",
                                          save_top_k=2,
                                          mode="min")
    logger = CSVLogger(save_dir="logs", name=f"{model_type}-{num_layers}-{pooling_method}")
    trainer = L.Trainer(
        max_epochs=50,
        callbacks=[early_stopping, lr_monitor, checkpoint_callback],
        log_every_n_steps=5,
        logger=logger,
        deterministic=True,
        accumulate_grad_batches=1,
        enable_progress_bar= True
    )
    trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
    return model


In [14]:
def mass_train():
    model_list = ["PNA"]
    layers_range = range(0,4,2)
    pooling_list = ["max"]
    for model_type in model_list:
        for num_layers in layers_range:
            for pooling_method in pooling_list:
                train_model(train_dataloader=train_dataloader,
                            test_dataloader = test_dataloader,
                            model_type=model_type,
                            num_layers=num_layers,
                            pooling_method=pooling_method)


In [15]:
mass_train()

Seed set to 42
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | model   | GNN     | 28.5 K | train
1 | loss_fn | MSELoss | 0      | train
--------------------------------------------
28.5 K    Trainable params
0         Non-trainable params
28.5 K    Total params
0.114     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

                                                                           

Epoch 16: 100%|██████████| 20/20 [00:01<00:00, 10.78it/s, v_num=5]

Seed set to 42





GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs

  | Name    | Type    | Params | Mode 
--------------------------------------------
0 | model   | GNN     | 81.7 K | train
1 | loss_fn | MSELoss | 0      | train
--------------------------------------------
81.7 K    Trainable params
0         Non-trainable params
81.7 K    Total params
0.327     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

                                                                           

Epoch 19: 100%|██████████| 20/20 [00:02<00:00,  6.92it/s, v_num=5]
