In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from lightning.pytorch.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
)
from lightning.pytorch.loggers import CSVLogger
from mango import Tuner
from sklearn.metrics import r2_score, root_mean_squared_error
from torch.nn import (
    BatchNorm1d,
    L1Loss,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch_geometric.nn import (
    GINEConv,
    GPSConv,
    GraphNorm,
    SAGPooling,
    SetTransformerAggregation,
)
from rdkit import Chem
from sklearn.feature_selection import r_regression

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
%load_ext kedro.ipython

In [48]:
df = catalog.load("kp_clean_database")
train_dataloader = catalog.load("kp_train_dataloader")
test_dataloader = catalog.load("kp_test_dataloader")

In [49]:
df

Unnamed: 0.1,Unnamed: 0,MP,smiles
0,0,0.252916,CCCCCCCCCCCCCC[N+](C)(C)Cc1ccccc1.COc1cc(ccc1O...
1,1,0.106805,C[N+](C)(C)C.CC(C)[C@H](N)C([O-])=O
2,2,-0.193424,CCCC[N+](CCCC)(CCCC)CCCC.CC(C)[C@H](N)C([O-])=O
3,3,-0.179413,CCCC[P+](CCCC)(CCCC)CCCC.CC(C)[C@H](N)C([O-])=O
4,4,-0.373561,CCCCC[N+](CCCCC)(CCCCC)CCCCC.FC(F)(F)C(=O)[N-]...
...,...,...,...
894,948,-0.994034,CCCC[NH3+].CCCCCCCC([O-])=O
895,949,-0.033302,[NH3+]CCO.CCCCCCCC([O-])=O
896,950,-1.494415,C1CC[NH2+]C1.CCCCCCCC([O-])=O
897,951,0.817346,[NH3+]CCO.CCCCCCCCCCCCCCCCCC([O-])=O


In [83]:
def add_mol_descriptors(df):
    from rdkit.Chem.Descriptors import CalcMolDescriptors

    values_list = []
    for _, row in df.iterrows():
        smiles = row["smiles"]
        mol = Chem.MolFromSmiles(smiles)
        desc = CalcMolDescriptors(mol)
        values = list(desc.values())[:-85]
        values_list.append(values)
    features = pd.DataFrame(values_list)
    features = features.dropna(axis=1)
    corelation = r_regression(features, df["MP"])
    corelation = [x if abs(x)>0.15 else np.NaN for x in corelation]
    features.loc[len(df)] = corelation
    features = features.dropna(axis=1)
    df = pd.concat([df, features], axis=1)
    df = df.drop(len(df)-1)
    df = df.drop(columns=[0])
    return df

In [84]:
df_f = add_mol_descriptors(df)



In [88]:
list(row)[3:]


[1m[[0m
    [1;36m10.332078609221465[0m,
    [1;36m10.332078609221465[0m,
    [1;36m2.8529759152452776[0m,
    [1;36m0.0[0m,
    [1;36m0.0[0m,
    [1;36m9.843390348640755[0m,
    [1;36m6.544756405912575[0m,
    [1;36m0.0[0m,
    [1;36m17.062475158264807[0m,
    [1;36m5.0[0m,
    [1;36m4.0[0m
[1m][0m

In [110]:
import typing as t

import pandas as pd
import torch_geometric
import torch_geometric.transforms as T
from gnn_mp_model.pipelines.data_featurization.utils import from_smiles
from torch_geometric.loader import DataLoader

transform = T.AddRandomWalkPE(walk_length=30, attr_name='pe')
def _generate_graph_list(df: pd.DataFrame) -> t.List:
    data_list = []
    for _, row in df.iterrows():
        smiles = row["smiles"]
        label = row["MP"]
        graph = from_smiles(smiles)
        graph.y = label
        graph = transform(graph)
        global_featuers = torch.tensor(np.array(list(row)[3:]))
        graph.gf = global_featuers
        data_list.append(graph)
    return data_list

def generate_graph_loader(df: pd.DataFrame) -> torch_geometric.loader.DataLoader:
    data_list = _generate_graph_list(df)
    graph_loader = DataLoader(
        data_list,
        batch_size=16,
        shuffle=True,
        drop_last=True,
    )
    return graph_loader

In [111]:
train_dataloader = generate_graph_loader(df_f)

In [112]:
batch = next(iter(train_dataloader))

In [128]:
class GNN(torch.nn.Module):
    def __init__(self, hidden_size: int, dense_size: int):  # noqa: PLR0913
        # Loading params
        super().__init__()
        self.num_layers = 2
        node_dim = 9
        edge_dim = 4
        pe_dim = 8
        # Initial embeddings
        self.node_emb = Linear(pe_dim+node_dim, hidden_size)
        self.pe_lin = Linear(30, pe_dim)
        self.pe_norm = BatchNorm1d(30)
        self.edge_emb = Linear(edge_dim, hidden_size)
        self.aggr = SetTransformerAggregation(hidden_size)
        # PNA
        self.gps_list = ModuleList([])
        self.gn_list = ModuleList([])
        self.aggr_list = ModuleList([])
        self.pool_list = ModuleList([])
        # Initial layers
        for _ in range(self.num_layers):
            nn = Sequential(
                Linear(hidden_size, hidden_size),
                ReLU(),
                Linear(hidden_size, hidden_size))
            self.gps_list.append(GPSConv(hidden_size, GINEConv(nn), heads=4, dropout=0.2))
            self.gn_list.append(GraphNorm(hidden_size))
            self.aggr_list.append(SetTransformerAggregation(hidden_size))

        # Linear layers
        self.linear1 = Linear(hidden_size+10, dense_size)
        self.linear2 = Linear(dense_size, int(dense_size / 2))
        self.linear3 = Linear(int(dense_size / 2), 1)

    def forward(self, x, pe, edge_attr, edge_index, batch_index, gf):  # noqa: PLR0913
        # Initial embeddings
        x_pe = self.pe_norm(pe)
        x = torch.cat((x, self.pe_lin(x_pe)), 1)
        x = self.node_emb(x)
        edge_attr = self.edge_emb(edge_attr)
        ###Initial convolution
        global_representation = []
        global_representation.append(self.aggr(x, batch_index))
        ### Internal convolutions
        for i in range(self.num_layers):
            x = self.gps_list[i](x, edge_index, batch_index, edge_attr=edge_attr)
            x = self.gn_list[i](x, batch_index)
            global_representation.append(self.aggr_list[i](x, batch_index))
        ### Output block
        x = sum(global_representation)
        x = torch.cat((x, gf.unsqueeze(0)), 1)
        x = torch.relu(self.linear1(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.relu(self.linear2(x))
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.linear3(x)
        return x

In [130]:
model = GNN(32,32)
x = model(batch.x.float(), batch.pe, batch.edge_attr.float(), batch.edge_index, batch.batch, batch.gf.float())

In [131]:
x

[1;35mtensor[0m[1m([0m[1m[[0m[1m[[0m[1;36m-0.4903[0m[1m][0m[1m][0m, [33mgrad_fn[0m=[1m<[0m[1;95mAddmmBackward0[0m[1m>[0m[1m)[0m

In [8]:
import lightning as L


class GNN_L(L.LightningModule):
    def __init__(self, model, lr: float, weight_decay: float, gamma: float):
        super().__init__()
        self.model = model
        self.lr = lr
        self.weight_decay = weight_decay
        self.loss_fn = L1Loss()
        self.gamma = gamma

    def forward(self, x, pe, edge_attr, edge_index, batch_index):  # noqa: PLR0913
        return self.model(x.float(), pe, edge_attr.float(), edge_index, batch_index)

    def training_step(self, batch, batch_index):
        preds = self(batch.x.float(), batch.pe, batch.edge_attr.float(), batch.edge_index, batch.batch).squeeze()
        target = batch.y.float()
        loss = self.loss_fn(preds, target)
        r2 = r2_score(target.numpy(), preds.detach().numpy())
        rmse = root_mean_squared_error(target.numpy(), preds.detach().numpy())
        self.log("r2", r2)
        self.log("rmse", rmse)
        self.log("loss", loss)
        return loss

    def validation_step(self, batch, batch_index):
        preds = self(batch.x.float(), batch.pe, batch.edge_attr.float(), batch.edge_index, batch.batch).squeeze()
        target = batch.y.float()
        val_r2 = r2_score(target.numpy(), preds.detach().numpy())
        val_rmse = root_mean_squared_error(target.numpy(), preds.detach().numpy())
        self.log("val_r2", val_r2)
        self.log("val_rmse", val_rmse)
    def test_step(self, batch, batch_index):
        preds = self(batch.x.float(), batch.pe, batch.edge_attr.float(), batch.edge_index, batch.batch).squeeze()
        target = batch.y.float()
        test_rmse = root_mean_squared_error(target.numpy(), preds.detach().numpy())
        self.log("test_rmse", test_rmse)
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.lr,
            weight_decay = self.weight_decay)
        scheduler = torch.optim.lr_scheduler.ExponentialLR(
            optimizer, gamma=self.gamma
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "epoch"}]

In [9]:
# Trainer
def train_model(params):
    lr = params["lr"]
    weight_decay = params["weight_decay"]
    gamma = params["gamma"]
    filename = "GPS_Main_Model_1"
    L.seed_everything(42)
    model = GNN_L(GNN(64,64), lr, weight_decay, gamma)
    early_stopping = EarlyStopping("val_rmse", patience=10, mode="min")
    lr_monitor = LearningRateMonitor(logging_interval="epoch")
    checkpoint_callback = ModelCheckpoint(filename="{epoch}-{loss:.2f}-{rmse:.2f}",
                                          monitor="val_rmse",
                                          save_top_k=2,
                                          mode="min")

    logger = CSVLogger(save_dir="logs", name=filename)
    trainer = L.Trainer(
        max_epochs=50,
        callbacks=[early_stopping, lr_monitor, checkpoint_callback],
        log_every_n_steps=5,
        logger=logger,
        deterministic=True,
        accumulate_grad_batches=1
    )
    # Model pretraining
    trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=test_dataloader)
    result = trainer.test(model, test_dataloader)
    result = result[0]['test_rmse']
    return model

In [10]:
params = {
    "lr": 0.001,
    "weight_decay": 1e-05,
    "gamma": 0.95,
}
model = train_model(params)

Seed set to 42


Missing logger folder: logs\GPS_Main_Model_1

  | Name    | Type   | Params | Mode 
-------------------------------------------
0 | model   | GNN    | 292 K  | train
1 | loss_fn | L1Loss | 0      | train
-------------------------------------------
292 K     Trainable params
0         Non-trainable params
292 K     Total params
1.169     Total estimated model params size (MB)


Epoch 44: 100%|██████████| 50/50 [00:06<00:00,  8.21it/s, v_num=0]         
Testing DataLoader 0: 100%|██████████| 11/11 [00:00<00:00, 39.60it/s]
