In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger
from mango import Tuner
from sklearn.metrics import r2_score, root_mean_squared_error
from torch.nn import (
    BatchNorm1d,
    L1Loss,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch_geometric.nn import (
    GINEConv,
    GPSConv,
    GraphNorm,
    SAGPooling,
    SetTransformerAggregation,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext kedro.ipython

In [3]:
train_dataloader = catalog.load("kp_train_dataloader")
test_dataloader = catalog.load("kp_test_dataloader")

In [5]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_size: int, dense_size: int):  # noqa: PLR0913
        # Loading params
        super().__init__()
        self.num_layers = 2
        node_dim = 9
        edge_dim = 4
        pe_dim = 8
        # Initial embeddings
        self.node_emb = Linear(pe_dim+node_dim, hidden_size)
        self.pe_lin = Linear(30, pe_dim)
        self.pe_norm = BatchNorm1d(30)
        self.edge_emb = Linear(edge_dim, hidden_size)
        self.aggr = SetTransformerAggregation(hidden_size)
        # PNA
        self.gps_list = ModuleList([])
        self.gn_list = ModuleList([])
        self.aggr_list = ModuleList([])
        self.pool_list = ModuleList([])
        # Initial layers
        for _ in range(self.num_layers):
            nn = Sequential(
                Linear(hidden_size, hidden_size),
                ReLU(),
                Linear(hidden_size, hidden_size))
            self.gps_list.append(GPSConv(hidden_size, GINEConv(nn), heads=4, dropout=0.2))
            self.gn_list.append(GraphNorm(hidden_size))
            self.aggr_list.append(SetTransformerAggregation(hidden_size))

        # Linear layers
        self.linear1 = Linear(hidden_size, dense_size)
        self.linear2 = Linear(dense_size, int(dense_size / 2))
        self.linear3 = Linear(int(dense_size / 2), 1)

    def forward(self, x, pe, edge_attr, edge_index, batch_index):  # noqa: PLR0913
        # Initial embeddings
        x_pe = self.pe_norm(pe)
        x = torch.cat((x, self.pe_lin(x_pe)), 1)
        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)
        ###Initial convolution
        global_representation = []
        global_representation.append(self.aggr(x, batch_index))
        ### Internal convolutions
        for i in range(self.num_layers):
            x = self.gps_list[i](x, edge_index, batch_index, edge_attr=edge_attr)
            x = self.gn_list[i](x, batch_index)
            global_representation.append(self.aggr_list[i](x, batch_index))
        ### Output block
        x = sum(global_representation)
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.linear3(x)
        return x

In [17]:
model = GNN(32,32)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=20,
                              min_lr=0.00001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def train():
    model.train()

    total_loss = 0
    for data in train_dataloader:
        data = data.to(device)
        optimizer.zero_grad()
        out = model(data.x.float(), data.pe, data.edge_attr.float(), data.edge_index,
                    data.batch)
        loss = (out.squeeze() - data.y).abs().mean()
        loss.backward()
        total_loss += loss.item() * data.num_graphs
        optimizer.step()
    return total_loss / len(train_dataloader.dataset)

In [18]:
@torch.no_grad()
def test():
    model.eval()

    total_error = 0
    for data in test_dataloader:
        data = data.to(device)
        out = model(data.x.float(), data.pe, data.edge_attr.float(), data.edge_index,
                    data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(test_dataloader.dataset)

In [19]:

for epoch in range(1, 101):
    loss = train()
    val_mae = test()
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {val_mae:.4f}')

Epoch: 01, Loss: 0.7849, Val: 0.7535, Test: 0.7535
Epoch: 02, Loss: 0.7857, Val: 0.7440, Test: 0.7440
Epoch: 03, Loss: 0.7799, Val: 0.7369, Test: 0.7369
Epoch: 04, Loss: 0.7721, Val: 0.7351, Test: 0.7351
Epoch: 05, Loss: 0.7704, Val: 0.7340, Test: 0.7340
Epoch: 06, Loss: 0.7681, Val: 0.7337, Test: 0.7337
Epoch: 07, Loss: 0.7653, Val: 0.7335, Test: 0.7335
Epoch: 08, Loss: 0.7633, Val: 0.7333, Test: 0.7333
Epoch: 09, Loss: 0.7679, Val: 0.7331, Test: 0.7331
