In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger
from mango import scheduler, Tuner
from sklearn.metrics import r2_score, root_mean_squared_error
from torch.nn import (
    BatchNorm1d,
    L1Loss,
    BCEWithLogitsLoss,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch_geometric.nn import (
    GINEConv,
    GPSConv,
    GraphNorm,
    SAGPooling,
    SetTransformerAggregation,
)
from rdkit import Chem
from sklearn.feature_selection import r_regression

In [None]:
# Load the Kedro IPython extension. Later cells use `catalog` without defining it,
# so it is presumably injected into the namespace by this extension — confirm against the Kedro docs.
%load_ext kedro.ipython

In [None]:
# Load the dataset registered under the name "merged_database" via the `catalog` object.
# Later cells read its "smiles" and "MP" columns.
merged_df = catalog.load("merged_database")

In [None]:
# Display the frame as loaded; the next cell overwrites its MP column with binary labels.
merged_df

In [None]:
# Binarise the melting point into a classification label: 0 when MP <= threshold, 1 otherwise.
# Missing values fail the `<=` comparison and therefore land in class 1, exactly as the
# element-wise loop this replaces did.
# NOTE: this overwrites the raw MP column, so the cell is not idempotent — re-running it
# without reloading `merged_df` maps every (already 0/1) label to 0.
MP_THRESHOLD = 80  # same units as the MP column — presumably degrees Celsius, TODO confirm

# np.where returns a plain array, so the assignment is positional. Wrapping the labels in a
# fresh pd.Series would align on the index and produce NaN / misplaced labels whenever the
# frame does not carry a default 0..n-1 index.
merged_df["MP"] = np.where(merged_df["MP"] <= MP_THRESHOLD, 0, 1)

In [None]:
import typing as t

import torch_geometric
import torch_geometric.transforms as T
from gnn_mp_model.pipelines.data_featurization.utils import from_smiles
from rdkit import Chem
from rdkit.Chem.Descriptors import MolWt
from sklearn.model_selection import train_test_split
from torch_geometric.loader import DataLoader

# Graph transform applied to every molecule in `_generate_graph_list`: adds a random-walk
# positional encoding with walk_length=30, stored under the attribute name "pe".
# The GNN hard-codes the same width of 30 in `pe_lin` and `pe_norm`, so the two must change together.
transform = T.AddRandomWalkPE(walk_length=30, attr_name="pe")
def _generate_graph_list(df: pd.DataFrame) -> t.List:
    """Convert every row of ``df`` into a featurised molecular graph.

    For each row the ``smiles`` string is turned into a graph with ``from_smiles``,
    labelled with the row's ``MP`` value (``graph.y``), passed through the module-level
    ``transform`` and given a ``(1, 6)`` graph-level feature tensor ``graph.gf`` holding
    ``[cation_wt, cation_natoms, cation_nbonds, anion_wt, anion_natoms, anion_nbonds]``.

    Args:
        df: Frame with a ``smiles`` column (ion pair written as ``"<ion>.<ion>"``) and
            an ``MP`` column.

    Returns:
        A list with one graph object per row, in row order.

    Raises:
        ValueError: If a SMILES string has fewer than two ``.``-separated fragments, or
            if RDKit cannot parse one of the two ion fragments.
    """
    data_list = []
    for _, row in df.iterrows():
        smiles = row["smiles"]
        graph = from_smiles(smiles)
        graph.y = row["MP"]
        graph = transform(graph)

        fragments = smiles.split(".")
        if len(fragments) < 2:
            # Previously surfaced as a bare IndexError with no hint of the offending row.
            raise ValueError(
                f"Expected an ion pair written as '<ion>.<ion>', got SMILES {smiles!r}"
            )
        # The cation is identified by a '+' in the first fragment; otherwise the first
        # fragment is taken to be the anion. Only the first two fragments are used.
        if "+" in fragments[0]:
            cation_smiles, anion_smiles = fragments[0], fragments[1]
        else:
            cation_smiles, anion_smiles = fragments[1], fragments[0]

        graph_features = []
        for role, ion_smiles in (("cation", cation_smiles), ("anion", anion_smiles)):
            ion = Chem.MolFromSmiles(ion_smiles)
            if ion is None:
                # MolFromSmiles returns None on failure; fail here with a readable message
                # instead of inside the descriptor call.
                raise ValueError(
                    f"RDKit could not parse the {role} fragment {ion_smiles!r} of SMILES {smiles!r}"
                )
            graph_features.extend(
                [MolWt(ion), len(ion.GetAtoms()), len(ion.GetBonds())]
            )

        # unsqueeze(0) gives a leading dimension of 1 so that batching stacks one row per graph.
        graph.gf = torch.tensor(graph_features).unsqueeze(0)
        data_list.append(graph)
    return data_list


# Nodes
def generate_graph_loader(
    df: pd.DataFrame, train: bool, batch_size: int
) -> torch_geometric.loader.DataLoader:
    """Featurise ``df`` and wrap the resulting graphs in a PyG ``DataLoader``.

    Args:
        df: Frame handed to ``_generate_graph_list``.
        train: Forwarded as ``shuffle``; ``True`` reshuffles the graphs.
        batch_size: Number of graphs per batch.

    Returns:
        A loader over the featurised graphs. ``drop_last=True`` is always set, so a
        final batch smaller than ``batch_size`` is discarded.
    """
    return DataLoader(
        _generate_graph_list(df),
        batch_size=batch_size,
        shuffle=train,
        drop_last=True,
    )


def random_data_split(df: pd.DataFrame, split_ratio: float) -> t.Tuple:
    """Randomly split ``df`` into a train part and a test part.

    Args:
        df: Frame to split.
        split_ratio: Passed to ``train_test_split`` as ``test_size``.

    Returns:
        ``(train_frame, test_frame)``. The split uses the fixed ``random_state=42``.
    """
    train_part, test_part = train_test_split(
        df,
        test_size=split_ratio,
        random_state=42,
    )
    return train_part, test_part


In [None]:
# Split the labelled frame; split_ratio=0.1 is forwarded as `test_size` inside random_data_split.
df_train, df_test = random_data_split(merged_df, split_ratio=0.1)

In [None]:
# Training batches are shuffled (train=True); the other two loaders keep row order.
train_dataloader = generate_graph_loader(df_train, True, 32)
# Validate on the held-out split. This loader used to be built from df_train, which meant
# val_acc — the quantity monitored by early stopping and checkpoint selection in
# train_model — was measured on the training data.
# NOTE: generate_graph_loader always sets drop_last=True, so a trailing partial batch of
# df_test is skipped during validation.
test_dataloader = generate_graph_loader(df_test, False, 32)
# One graph per batch over the whole dataset, used for per-molecule predictions.
predict_dataloader = generate_graph_loader(merged_df, False, 1)

In [None]:
# Grab a single batch from the training loader for interactive inspection.
batch = next(iter(train_dataloader))

In [None]:
class GNN(torch.nn.Module):
    """Graph-level network: GPSConv/GINEConv blocks with set-transformer readouts.

    Node features are concatenated with a projected positional encoding and embedded to
    ``hidden_size``. A readout is taken before the first convolution block and after
    every block; the readouts are summed, concatenated with the graph-level feature
    vector ``gf`` and passed through a three-layer dense head. The head ends in a
    single-unit linear layer with no activation applied afterwards.
    """

    def __init__(self, hidden_size: int, dense_size: int, num_layers:int, pooling: bool):  # noqa: PLR0913
        """Create the embedding layers, the convolution stack and the dense head.

        Args:
            hidden_size: Width of the node/edge embeddings and of every convolution block.
            dense_size: Width of the first dense layer; the second uses ``int(dense_size / 2)``.
            num_layers: Number of GPSConv blocks, each with its own norm, readout and pool module.
            pooling: When ``True``, ``forward`` applies the SAGPooling layer after each block.
        """
        # Loading params
        super().__init__()
        self.num_layers = num_layers
        self.pooling = pooling
        # Input feature widths. node_dim / edge_dim presumably match what `from_smiles`
        # produces — TODO confirm. gf_dim matches the six values written to `graph.gf`.
        node_dim = 9
        edge_dim = 4
        pe_dim = 8
        gf_dim = 6
        pool_rate = 0.65
        # Initial embeddings
        self.node_emb = Linear(pe_dim+node_dim, hidden_size)
        # The literal 30 is the positional-encoding width; it has to equal the
        # walk_length given to AddRandomWalkPE when the graphs are built.
        self.pe_lin = Linear(30, pe_dim)
        self.pe_norm = BatchNorm1d(30)
        self.edge_emb = Linear(edge_dim, hidden_size)
        self.aggr = SetTransformerAggregation(hidden_size)
        # Per-layer module lists: GPS convolutions, graph norms, readouts and pooling layers
        self.gps_list = ModuleList([])
        self.gn_list = ModuleList([])
        self.aggr_list = ModuleList([])
        self.pool_list = ModuleList([])
        # Initial layers
        for _ in range(self.num_layers):
            # MLP used inside the GINEConv message-passing layer.
            nn = Sequential(
                Linear(hidden_size, hidden_size),
                ReLU(),
                Linear(hidden_size, hidden_size))
            self.gps_list.append(GPSConv(hidden_size, GINEConv(nn, edge_dim=hidden_size), heads=4, dropout=0.2))
            self.gn_list.append(GraphNorm(hidden_size))
            self.aggr_list.append(SetTransformerAggregation(hidden_size))
            # Pooling layers are always created, even when `pooling` is False and they go unused.
            self.pool_list.append(SAGPooling(hidden_size, pool_rate))

        # Linear layers
        self.linear1 = Linear(hidden_size+gf_dim, dense_size)
        self.linear2 = Linear(dense_size, int(dense_size / 2))
        self.linear3 = Linear(int(dense_size / 2), 1)

    def forward(self, x, pe, edge_attr, edge_index, batch_index, gf):  # noqa: PLR0913
        """Run the network on a batch of graphs.

        Args:
            x: Node features (assumed ``(num_nodes, 9)`` float — TODO confirm against the featuriser).
            pe: Positional encoding per node, 30 columns wide.
            edge_attr: Edge features (assumed ``(num_edges, 4)`` float — TODO confirm).
            edge_index: Edge connectivity passed to the convolution and pooling layers.
            batch_index: Graph assignment of each node, used by norms, pooling and readouts.
            gf: Graph-level features concatenated to the summed readout along dim 1.

        Returns:
            Output of the final single-unit linear layer (no activation applied).
        """
        # Initial embeddings
        x_pe = self.pe_norm(pe)
        x = torch.cat((x, self.pe_lin(x_pe)), 1)
        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)
        # One readout per stage; the first is taken before any convolution.
        global_representation = []
        global_representation.append(self.aggr(x, batch_index))
        ### Internal convolutions
        for i in range(self.num_layers):
            x = self.gps_list[i](x, edge_index, batch_index, edge_attr=edge_attr)
            x = self.gn_list[i](x, batch_index)
            if self.pooling is True:
                # Pooling also rewrites edge_index / edge_attr / batch_index for later layers;
                # the last two returned values are discarded.
                x, edge_index, edge_attr, batch_index, _, _ = self.pool_list[i](
                x, edge_index, edge_attr, batch_index
            )
            global_representation.append(self.aggr_list[i](x, batch_index))
        ### Output block
        # Element-wise sum of all stage readouts.
        x = sum(global_representation)
        x = torch.cat((x, gf), 1)
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.linear3(x)
        return x

In [None]:
# model = GNN(32,32,3,True)
# x = model(batch.x.float(), batch.pe, batch.edge_attr.float(), batch.edge_index, batch.batch, batch.gf.float()).squeeze()
# train_accuracy = Accuracy(task="binary")

In [None]:
import lightning as L
from torchmetrics import Accuracy


class GNN_L(L.LightningModule):
    """Lightning wrapper that trains ``GNN`` as a binary classifier.

    The wrapped model emits one raw logit per graph. Training minimises
    ``BCEWithLogitsLoss``; accuracy is tracked separately for training and validation.

    Args:
        params: Mapping with the keys ``hidden_size``, ``dense_size``, ``num_layers``,
            ``pooling`` (model) and ``lr``, ``weight_decay``, ``gamma`` (optimisation).
            It is also stored through ``save_hyperparameters``.
    """

    def __init__(self, params):
        super().__init__()
        self.model = GNN(params["hidden_size"], params["dense_size"], params["num_layers"], params["pooling"])
        self.lr = params["lr"]
        self.weight_decay = params["weight_decay"]
        self.gamma = params["gamma"]
        self.loss_fn = BCEWithLogitsLoss()
        self.train_accuracy = Accuracy(task="binary")
        self.val_accuracy = Accuracy(task="binary")
        self.save_hyperparameters(params)

    def forward(self, x, pe, edge_attr, edge_index, batch_index, gf):  # noqa: PLR0913
        """Delegate to the wrapped ``GNN``, casting node and edge features to float."""
        return self.model(x.float(), pe, edge_attr.float(), edge_index, batch_index, gf)

    def _batch_logits(self, batch):
        """Return one logit per graph for a batch, shaped ``(num_graphs,)``.

        Only the trailing singleton dimension is removed, so a batch holding a single
        graph still yields a 1-D tensor that matches ``batch.y``; a bare ``squeeze()``
        would collapse it to 0-D and make the loss reject the shape mismatch.
        """
        logits = self(
            batch.x.float(),
            batch.pe,
            batch.edge_attr.float(),
            batch.edge_index,
            batch.batch,
            batch.gf.float(),
        )
        return logits.squeeze(-1)

    def training_step(self, batch, batch_index):
        """Compute the BCE loss for one batch and log loss and accuracy."""
        logits = self._batch_logits(batch)
        target = batch.y.float()
        loss = self.loss_fn(logits, target)
        # Convert to probabilities explicitly: the metric only treats float input as
        # logits when some value lies outside [0, 1], so a batch of small positive
        # logits would otherwise be thresholded at 0.5 as if it were probabilities.
        acc = self.train_accuracy(torch.sigmoid(logits), target)
        # Pass the number of graphs explicitly so Lightning does not have to guess the
        # batch size from the tensors stored on the batch object.
        n_graphs = target.numel()
        self.log("CEL", loss, batch_size=n_graphs)
        self.log("train_acc", acc, batch_size=n_graphs)
        return loss

    def validation_step(self, batch, batch_index):
        """Log the accuracy of one validation batch under ``val_acc``."""
        logits = self._batch_logits(batch)
        target = batch.y.float()
        acc = self.val_accuracy(torch.sigmoid(logits), target)
        self.log("val_acc", acc, batch_size=target.numel())

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        """Return ``[logits, smiles, targets]`` for a batch.

        The logits are raw model outputs (no sigmoid) with every singleton dimension
        squeezed away, so a single-graph batch gives a 0-D tensor.
        """
        preds = [self._batch_logits(batch).squeeze(), batch.smiles, batch.y]
        return preds

    def configure_optimizers(self):
        """Adam with weight decay plus an exponential learning-rate decay stepped every epoch."""
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay=self.weight_decay,
        )
        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=self.gamma
        )
        return [optimizer], [{"scheduler": lr_scheduler, "interval": "epoch"}]

In [None]:
# Trainer

def train_model(params):
    """Train a fresh ``GNN_L`` on the notebook-level dataloaders.

    Seeds everything with 42, builds the model from ``params`` and fits it for at most
    75 epochs using the global ``train_dataloader`` and ``test_dataloader``. Early
    stopping and checkpointing both monitor ``val_acc`` (higher is better); metrics go
    to a CSV logger under ``logs/GPS_Main_Model_class``.

    Args:
        params: Hyperparameter mapping passed to ``GNN_L``.

    Returns:
        ``(model, trainer)`` after fitting.
    """
    run_name = "GPS_Main_Model_class"
    # Seed before the model is built so that weight initialisation is reproducible.
    L.seed_everything(42)
    model = GNN_L(params)

    callbacks = [
        EarlyStopping("val_acc", patience=10, mode="max", strict=False),
        LearningRateMonitor(logging_interval="epoch"),
        ModelCheckpoint(
            filename="{epoch}-{val_acc:.2f}",
            monitor="val_acc",
            save_top_k=2,
            mode="max",
        ),
    ]
    trainer = L.Trainer(
        max_epochs=75,
        callbacks=callbacks,
        log_every_n_steps=20,
        logger=CSVLogger(save_dir="logs", name=run_name),
        deterministic=True,
        accumulate_grad_batches=1,
    )
    trainer.fit(
        model=model,
        train_dataloaders=train_dataloader,
        val_dataloaders=test_dataloader,
    )
    return model, trainer

In [None]:
# Hyperparameters read by GNN_L: lr / weight_decay configure Adam, gamma the exponential
# LR decay, and the remaining four are forwarded to the GNN constructor.
params = dict(
    lr=0.001,
    weight_decay=1e-05,
    gamma=0.95,
    hidden_size=32,
    dense_size=64,
    num_layers=5,
    pooling=False,
)
model, trainer = train_model(params)

In [None]:
# Restore a GNN_L from a checkpoint file given as a path relative to the working directory.
# NOTE(review): nothing in this notebook writes "all_data_model.ckpt" (train_model saves
# under the "{epoch}-{val_acc:.2f}" pattern), so the file must already exist — confirm its origin.
model_ckpt = GNN_L.load_from_checkpoint("all_data_model.ckpt")

In [None]:
def predict_on_df(model, predict_dataloader):
    """Run ``model`` over ``predict_dataloader`` and tabulate the results.

    Relies on the notebook-level ``trainer`` (created by ``train_model``) to drive
    prediction. Each prediction item is expected to be ``[preds, smiles, targets]`` as
    returned by ``GNN_L.predict_step``.

    Args:
        model: Lightning module to run.
        predict_dataloader: Loader to predict on; any batch size is accepted.

    Returns:
        A DataFrame with one row per graph and the columns ``preds`` (float),
        ``target`` (float), ``smiles`` and ``error`` (``|preds - target|``).
        NOTE(review): ``preds`` are the raw model outputs from ``predict_step`` (no
        sigmoid is applied there), so ``error`` compares logits with 0/1 labels —
        confirm this is intended.
    """
    batch_outputs = trainer.predict(model, predict_dataloader)
    preds_list = []
    target_list = []
    smiles_list = []
    for batch_preds, batch_smiles, batch_target in batch_outputs:
        # reshape(-1) copes with both the 0-D tensor of a single-graph batch and the
        # 1-D tensor of a larger batch.
        preds_list.extend(float(value) for value in batch_preds.reshape(-1).tolist())
        target_list.extend(float(value) for value in batch_target.reshape(-1).tolist())
        smiles_list.extend(batch_smiles)
    # Building the frame column-wise keeps preds/target numeric; transposing a list of
    # mixed lists would leave every column as object dtype.
    df = pd.DataFrame(
        {"preds": preds_list, "target": target_list, "smiles": smiles_list}
    )
    df["error"] = (df["preds"] - df["target"]).abs().astype(float)
    return df

def add_cation_anion_smiles(df):
    """Add ``cation_smiles`` and ``anion_smiles`` columns derived from ``df["smiles"]``.

    Every SMILES string is split on ``.``. If the first fragment contains ``+`` it is
    taken as the cation and the second as the anion; otherwise the roles are swapped.
    Only the first two fragments are considered.

    Args:
        df: Frame with a ``smiles`` column. It is modified in place.

    Returns:
        The same ``df`` object with the two new columns.
    """
    ion_pairs = []
    for smiles in df["smiles"]:
        fragments = smiles.split(".")
        first_is_cation = "+" in fragments[0]
        cation_idx, anion_idx = (0, 1) if first_is_cation else (1, 0)
        ion_pairs.append((fragments[cation_idx], fragments[anion_idx]))

    df["cation_smiles"] = [cation for cation, _ in ion_pairs]
    df["anion_smiles"] = [anion for _, anion in ion_pairs]
    return df

In [None]:
# Predict with the restored checkpoint on every molecule, then split each SMILES into
# its cation and anion parts.
df = add_cation_anion_smiles(predict_on_df(model_ckpt, predict_dataloader))