In [57]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger
from mango import Tuner
from sklearn.metrics import r2_score, root_mean_squared_error
from torch.nn import (
    Embedding,
    BatchNorm1d,
    L1Loss,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch_geometric.nn import (
    GINEConv,
    GPSConv,
    GraphNorm,
    SAGPooling,
    SetTransformerAggregation,
    global_max_pool
)


In [2]:
# Load Kedro's IPython extension. NOTE(review): presumably this is what injects the
# `catalog` object used in the next cell (it is not defined anywhere in this notebook) — confirm.
%load_ext kedro.ipython

In [3]:
# Fetch the train and test loaders from the data catalog, in that order.
train_dataloader, test_dataloader = (
    catalog.load(dataset_name)
    for dataset_name in ("random_train_dataloader", "random_test_dataloader")
)

In [7]:
# Pull the first batch from the training loader; it is displayed below and later
# fed through the model as a quick sanity check.
batch = next(iter(train_dataloader))

In [15]:
# Show the sampled batch (the bare last expression is rendered as the cell output).
batch


[1m[[0m
    [1;35mDataBatch[0m[1m([0m[33mx[0m=[1m[[0m[1;36m313[0m, [1;36m9[0m[1m][0m, [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m640[0m[1m][0m, [33medge_attr[0m=[1m[[0m[1;36m640[0m, [1;36m4[0m[1m][0m, [33msmiles[0m=[1m[[0m[1;36m16[0m[1m][0m, [33my[0m=[1m[[0m[1;36m16[0m[1m][0m, [33mpe[0m=[1m[[0m[1;36m313[0m, [1;36m30[0m[1m][0m, [33mbatch[0m=[1m[[0m[1;36m313[0m[1m][0m, [33mptr[0m=[1m[[0m[1;36m17[0m[1m][0m[1m)[0m,
    [1;35mDataBatch[0m[1m([0m[33mx[0m=[1m[[0m[1;36m97[0m, [1;36m9[0m[1m][0m, [33medge_index[0m=[1m[[0m[1;36m2[0m, [1;36m166[0m[1m][0m, [33medge_attr[0m=[1m[[0m[1;36m166[0m, [1;36m4[0m[1m][0m, [33msmiles[0m=[1m[[0m[1;36m16[0m[1m][0m, [33my[0m=[1m[[0m[1;36m16[0m[1m][0m, [33mpe[0m=[1m[[0m[1;36m97[0m, [1;36m30[0m[1m][0m, [33mbatch[0m=[1m[[0m[1;36m97[0m[1m][0m, [33mptr[0m=[1m[[0m[1;36m17[0m[1m][0m[1m)[0m
[1m][0m

In [84]:
class GNN(torch.nn.Module):
    """Two-branch GPS graph network predicting one scalar per cation/anion pair.

    The cation graph and the anion graph are processed by two separately
    parameterised but structurally identical branches:

    1. the positional encodings (``pe``) are batch-normalised and linearly
       projected to ``pe_dim`` features,
    2. the projection is concatenated with the raw node features and embedded
       to ``hidden_size``; edge attributes are embedded to ``hidden_size`` too,
    3. ``num_layers`` ``GPSConv`` layers (each wrapping a ``GINEConv`` with a
       two-layer MLP) refine the node embeddings.

    Each branch is then pooled to one vector per graph with
    ``SetTransformerAggregation``; the two vectors are concatenated and passed
    through a three-layer dense head with dropout.

    Args:
        hidden_size: Width of node/edge embeddings and of every GPS layer.
        dense_size: Width of the first dense layer; the second layer uses
            ``dense_size // 2``.
        num_layers: Number of GPS layers per branch.
        node_dim: Number of raw node features (``graph.x.shape[1]``).
        edge_dim: Number of raw edge features (``graph.edge_attr.shape[1]``).
        pe_in_dim: Number of raw positional-encoding features
            (``graph.pe.shape[1]``).
        pe_dim: Size the positional encoding is projected to before it is
            concatenated with the node features.
        heads: Number of attention heads in each ``GPSConv``.
        dropout: Dropout probability used inside ``GPSConv`` and in the dense
            head.
        pool_rate: Ratio passed to the ``SAGPooling`` module (see note in
            ``__init__``).

    The defaults reproduce the previously hard-coded configuration.
    """

    def __init__(  # noqa: PLR0913
        self,
        hidden_size: int,
        dense_size: int,
        *,
        num_layers: int = 10,
        node_dim: int = 9,
        edge_dim: int = 4,
        pe_in_dim: int = 30,
        pe_dim: int = 8,
        heads: int = 4,
        dropout: float = 0.2,
        pool_rate: float = 0.5,
    ):
        super().__init__()
        self.num_layers = num_layers
        self.head_dropout = dropout

        # NOTE(review): this pooling module is never called in forward(). It is
        # kept, and still created first, so that state_dict keys and the order
        # in which parameters are randomly initialised stay unchanged.
        self.pool = SAGPooling(hidden_size, pool_rate)

        # Cation branch: input embeddings followed by the GPS stack.
        self.cation_node_emb = Linear(pe_dim + node_dim, hidden_size)
        self.cation_pe_lin = Linear(pe_in_dim, pe_dim)
        self.cation_pe_norm = BatchNorm1d(pe_in_dim)
        self.cation_edge_emb = Linear(edge_dim, hidden_size)
        self.cation_gps_list = self._build_gps_stack(hidden_size, num_layers, heads, dropout)

        # Anion branch: same structure, separate weights.
        self.anion_node_emb = Linear(pe_dim + node_dim, hidden_size)
        self.anion_pe_lin = Linear(pe_in_dim, pe_dim)
        self.anion_pe_norm = BatchNorm1d(pe_in_dim)
        self.anion_edge_emb = Linear(edge_dim, hidden_size)
        self.anion_gps_list = self._build_gps_stack(hidden_size, num_layers, heads, dropout)

        # Graph-level read-out and dense regression head.
        self.cation_aggr = SetTransformerAggregation(hidden_size)
        self.anion_aggr = SetTransformerAggregation(hidden_size)
        self.linear1 = Linear(2 * hidden_size, dense_size)
        self.linear2 = Linear(dense_size, dense_size // 2)
        self.linear3 = Linear(dense_size // 2, 1)

    @staticmethod
    def _build_gps_stack(hidden_size: int, num_layers: int, heads: int, dropout: float) -> ModuleList:
        """Build ``num_layers`` GPSConv layers, each with its own GINEConv MLP."""
        return ModuleList(
            [
                GPSConv(
                    hidden_size,
                    GINEConv(
                        Sequential(
                            Linear(hidden_size, hidden_size),
                            ReLU(),
                            Linear(hidden_size, hidden_size),
                        )
                    ),
                    heads=heads,
                    dropout=dropout,
                )
                for _ in range(num_layers)
            ]
        )

    @staticmethod
    def _encode_nodes(graph, pe_norm, pe_lin, node_emb, edge_emb, gps_layers):
        """Embed one ion graph and run it through its GPS stack; returns node embeddings."""
        node_feats = graph.x.float()
        edge_feats = graph.edge_attr.float()

        # Normalise and project the positional encoding, then append it to the node features.
        pos_enc = pe_lin(pe_norm(graph.pe))
        node_feats = node_emb(torch.cat((node_feats, pos_enc), 1))
        edge_feats = edge_emb(edge_feats)

        for conv in gps_layers:
            node_feats = conv(node_feats, graph.edge_index, graph.batch, edge_attr=edge_feats)
        return node_feats

    def forward(self, cation_graph, anion_graph):
        """Predict one value per cation/anion pair.

        Args:
            cation_graph: Batched graph object exposing ``x``, ``pe``,
                ``edge_attr``, ``edge_index`` and ``batch``.
            anion_graph: Batched graph object with the same attributes; it must
                contain the same number of graphs as ``cation_graph`` because
                the pooled vectors are concatenated feature-wise.

        Returns:
            Tensor with a trailing dimension of size 1 holding the predictions,
            one row per graph pair.
        """
        # The cation branch runs fully before the anion branch; pooling happens afterwards.
        x_c = self._encode_nodes(
            cation_graph,
            self.cation_pe_norm,
            self.cation_pe_lin,
            self.cation_node_emb,
            self.cation_edge_emb,
            self.cation_gps_list,
        )
        x_a = self._encode_nodes(
            anion_graph,
            self.anion_pe_norm,
            self.anion_pe_lin,
            self.anion_node_emb,
            self.anion_edge_emb,
            self.anion_gps_list,
        )

        # Read-out: one vector per graph, concatenated across the two ions.
        x_c = self.cation_aggr(x_c, cation_graph.batch)
        x_a = self.anion_aggr(x_a, anion_graph.batch)
        x = torch.cat((x_c, x_a), 1)

        # Dense head; dropout is only active in training mode.
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=self.head_dropout, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=self.head_dropout, training=self.training)
        return self.linear3(x)

In [85]:
# Sanity check: build a model and run the sampled cation/anion batches through it once.
model = GNN(hidden_size=64, dense_size=32)
x = model(cation_graph=batch[0], anion_graph=batch[1])

In [86]:
import lightning as L


class GNN_L(L.LightningModule):
    """Lightning wrapper that trains a cation/anion graph model with an L1 loss.

    A batch is indexable: ``batch[0]`` is the cation graph batch (it also
    carries the regression target ``y``) and ``batch[1]`` is the anion graph
    batch.

    Args:
        model: Module called as ``model(cation_graph, anion_graph)``.
        lr: Learning rate for Adam.
        weight_decay: Weight decay for Adam.
        gamma: Multiplicative decay factor of the per-epoch ExponentialLR schedule.

    Logged metrics: ``r2``, ``rmse``, ``loss`` (training), ``val_r2``,
    ``val_rmse`` (validation) and ``test_rmse`` (test). RMSE and R2 are computed
    per batch, so any epoch-level value is an aggregate of per-batch scores
    rather than a score over the whole dataset.
    """

    def __init__(self, model, lr: float, weight_decay: float, gamma: float):
        super().__init__()
        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_fn = L1Loss()
        self.gamma = gamma

    def forward(self, cation_graph, anion_graph):
        """Delegate to the wrapped model."""
        return self.model(cation_graph, anion_graph)

    def _predict_with_target(self, batch):
        """Return ``(preds, target)`` as 1-D tensors.

        ``reshape(-1)`` is used instead of a bare ``squeeze()`` so that a batch
        holding a single sample still yields 1-D tensors of length 1 instead of
        0-d tensors that no longer line up with the target.
        """
        preds = self(batch[0], batch[1]).reshape(-1)
        target = batch[0].y.float().reshape(-1)
        return preds, target

    @staticmethod
    def _as_numpy(tensor):
        """Detach and move to CPU first; ``.numpy()`` fails on accelerator or grad-tracking tensors."""
        return tensor.detach().float().cpu().numpy()

    def training_step(self, batch, batch_index):
        """Compute the L1 loss and log ``r2``, ``rmse`` and ``loss``."""
        preds, target = self._predict_with_target(batch)
        loss = self.loss_fn(preds, target)
        y_true, y_pred = self._as_numpy(target), self._as_numpy(preds)
        # Pass the batch size explicitly instead of relying on inference from a list of graph batches.
        n_samples = target.numel()
        if n_samples >= 2:  # R2 is undefined for fewer than two samples
            self.log("r2", r2_score(y_true, y_pred), batch_size=n_samples)
        self.log("rmse", root_mean_squared_error(y_true, y_pred), batch_size=n_samples)
        self.log("loss", loss, batch_size=n_samples)
        return loss

    def validation_step(self, batch, batch_index):
        """Log ``val_r2`` and ``val_rmse`` for one validation batch."""
        preds, target = self._predict_with_target(batch)
        y_true, y_pred = self._as_numpy(target), self._as_numpy(preds)
        n_samples = target.numel()
        if n_samples >= 2:  # R2 is undefined for fewer than two samples
            self.log("val_r2", r2_score(y_true, y_pred), batch_size=n_samples)
        self.log("val_rmse", root_mean_squared_error(y_true, y_pred), batch_size=n_samples)

    def test_step(self, batch, batch_index):
        """Log ``test_rmse`` for one test batch."""
        preds, target = self._predict_with_target(batch)
        test_rmse = root_mean_squared_error(self._as_numpy(target), self._as_numpy(preds))
        self.log("test_rmse", test_rmse, batch_size=target.numel())

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return the predictions for one batch as a 1-D tensor."""
        return self(batch[0], batch[1]).reshape(-1)

    def configure_optimizers(self):
        """Adam with weight decay plus an exponential LR decay stepped once per epoch."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=self.gamma
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

In [89]:
# Trainer
def train_model(
    params,
    *,
    hidden_size: int = 32,
    dense_size: int = 32,
    max_epochs: int = 50,
    patience: int = 10,
    seed: int = 42,
    run_name: str = "GPS_Main_Model_1",
):
    """Train and test one ``GNN_L`` model and return it.

    Args:
        params: Mapping with the keys ``"lr"``, ``"weight_decay"`` and
            ``"gamma"`` (optimizer and LR-schedule hyperparameters).
        hidden_size: Hidden width passed to ``GNN``.
        dense_size: Dense-head width passed to ``GNN``.
        max_epochs: Upper bound on training epochs.
        patience: Early-stopping patience on ``val_rmse``.
        seed: Seed passed to ``L.seed_everything``.
        run_name: Sub-directory name used by the CSV logger under ``logs/``.

    Returns:
        The ``GNN_L`` instance after ``trainer.fit`` and ``trainer.test``.

    Notes:
        * Reads the notebook-level ``train_dataloader`` and ``test_dataloader``;
          both must be defined before this function is called.
        * NOTE(review): the test loader is also used as the validation loader,
          so early stopping and checkpoint selection are driven by the same data
          that is reported as the test score. Consider a separate validation split.
    """
    L.seed_everything(seed)
    model = GNN_L(
        GNN(hidden_size, dense_size),
        params["lr"],
        params["weight_decay"],
        params["gamma"],
    )

    early_stopping = EarlyStopping("val_rmse", patience=patience, mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint_callback = ModelCheckpoint(
        filename="{epoch}-{loss:.2f}-{rmse:.2f}",
        monitor="val_rmse",
        save_top_k=2,
        mode="min",
    )

    logger = CSVLogger(save_dir="logs", name=run_name)
    trainer = L.Trainer(
        max_epochs=max_epochs,
        callbacks=[early_stopping, lr_monitor, checkpoint_callback],
        log_every_n_steps=20,
        logger=logger,
        deterministic=True,
        accumulate_grad_batches=1,
    )

    trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
    # Run the test loop for its logged/printed `test_rmse`; the value itself is not returned.
    trainer.test(model, test_dataloader)
    return model

In [90]:
# Hyperparameters for a single training run: Adam learning rate and weight decay,
# plus the per-epoch decay factor of the exponential LR schedule.
params = dict(lr=0.01, weight_decay=1e-05, gamma=0.95)
model = train_model(params)

Seed set to 42



  | Name    | Type   | Params | Mode 
-------------------------------------------
0 | model   | GNN    | 253 K  | train
1 | loss_fn | L1Loss | 0      | train
-------------------------------------------
253 K     Trainable params
0         Non-trainable params
253 K     Total params
1.012     Total estimated model params size (MB)


Epoch 12:  44%|████▍     | 74/167 [08:09<10:15,  0.15it/s, v_num=0]        
Epoch 21: 100%|██████████| 167/167 [01:01<00:00,  2.70it/s, v_num=2]
Testing DataLoader 0: 100%|██████████| 37/37 [00:03<00:00, 11.70it/s]
