In [1]:
%cd ..

/home/nikita/edu/competitions/admet


In [2]:
from collections import deque

import numpy as np
import pandas as pd

import torch
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
from torch.nn import functional as F

from lightning import pytorch as pl

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import RandomOverSampler

from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors

from chemprop import data, featurizers, models, nn

from tqdm import tqdm

In [3]:
df_train = pd.read_csv("data/train_admet.csv", index_col=0)
df_test = pd.read_csv("data/test_data.csv", index_col=0)
sample = pd.read_csv("data/sample.csv")

In [4]:
RDLogger.DisableLog("rdApp.*")

def get_decsriptors_df(smiles_list):
    descriptors_list = []

    for smiles in tqdm(smiles_list):
        descriptors_list.append(
            Descriptors.CalcMolDescriptors(Chem.MolFromSmiles(smiles), 0)
        )
    return pd.DataFrame(descriptors_list).fillna(0)

train_descriptors = get_decsriptors_df(df_train["Drug"])
test_descriptors = get_decsriptors_df(df_test["Drug"])

100%|██████████| 7939/7939 [00:50<00:00, 158.02it/s]
100%|██████████| 1221/1221 [00:08<00:00, 144.02it/s]


In [5]:
train_descriptors["Ipc"] = np.log(train_descriptors["Ipc"])
test_descriptors["Ipc"] = np.log(test_descriptors["Ipc"])

In [6]:
scaler = StandardScaler()
scaler.fit(train_descriptors)
train_descriptors = pd.DataFrame(scaler.transform(train_descriptors), columns=train_descriptors.columns)
test_descriptors = pd.DataFrame(scaler.transform(test_descriptors), columns=test_descriptors.columns)

In [7]:
df_train = pd.concat([df_train, train_descriptors], axis=1)
df_test = pd.concat([df_test, test_descriptors], axis=1)

In [8]:
descriptors = train_descriptors.columns
def get_descriptors_features(df):
    return df[descriptors]

In [9]:
df_trains = []
df_vals = []
df_tests = []
properties = df_train.property.unique()

for prop in properties:
    subset_train = df_train[df_train.property == prop]
    subset_train, subset_val = train_test_split(
        subset_train, test_size=0.2, random_state=75, stratify=subset_train.Y
    )
    sampler = RandomOverSampler()
    subset_train = sampler.fit_resample(subset_train, subset_train.Y)[0]
    df_trains.append(subset_train)
    df_vals.append(subset_val)
    df_tests.append(df_test[df_test.property == prop])

In [10]:
train_data_total = []
val_data_total = []
test_data_total = []

train_data_descriptors = []
val_data_descriptors = []
test_data_descriptors = []
for i in range(len(df_trains)):
    train_data_total.append(
        [
            data.MoleculeDatapoint.from_smi(smi, [y])
            for smi, y in zip(df_trains[i]["Drug"], df_trains[i]["Y"])
        ]
    )
    val_data_total.append(
        [
            data.MoleculeDatapoint.from_smi(smi, [y])
            for smi, y in zip(df_vals[i]["Drug"], df_vals[i]["Y"])
        ]
    )

    test_data_total.append(
        [data.MoleculeDatapoint.from_smi(smi) for smi in df_tests[i]["Drug"]]
    )

    train_data_descriptors.append(get_descriptors_features(df_trains[i]).to_numpy().astype(np.float32))
    val_data_descriptors.append(get_descriptors_features(df_vals[i]).to_numpy().astype(np.float32))
    test_data_descriptors.append(get_descriptors_features(df_tests[i]).to_numpy().astype(np.float32))

In [11]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_datasets = [
    data.MoleculeDataset(train_data, featurizer) for train_data in train_data_total
]
val_datasets = [
    data.MoleculeDataset(val_data, featurizer) for val_data in val_data_total
]
test_datasets = [
    data.MoleculeDataset(test_data, featurizer) for test_data in test_data_total
]

train_loaders = [
    data.build_dataloader(train_dataset, shuffle=False, batch_size=32)
    for train_dataset in train_datasets
]
val_loaders = [
    data.build_dataloader(val_dataset, shuffle=False, batch_size=32)
    for val_dataset in val_datasets
]
test_loaders = [
    data.build_dataloader(test_dataset, shuffle=False, batch_size=32)
    for test_dataset in test_datasets
]

train_feature_loaders = [
    torch.utils.data.DataLoader(train_data_descriptors[i], batch_size=32, shuffle=False)
    for i in range(len(train_datasets))
]
val_feature_loaders = [
    torch.utils.data.DataLoader(val_data_descriptors[i], batch_size=32, shuffle=False)
    for i in range(len(val_datasets))
]
test_feature_loaders = [
    torch.utils.data.DataLoader(test_data_descriptors[i], batch_size=32, shuffle=False)
    for i in range(len(test_datasets))
]

In [12]:
class CombinedLoader(torch.utils.data.DataLoader):
    def __init__(self, *loaders):
        self.loaders = loaders

    def __iter__(self):
        for item in zip(*self.loaders):
            yield item

    def __len__(self):
        return min(len(loader) for loader in self.loaders)


train_combined_loaders = [
    CombinedLoader(train_loaders[i], train_feature_loaders[i])
    for i in range(len(train_loaders))
]
val_combined_loaders = [
    CombinedLoader(val_loaders[i], val_feature_loaders[i])
    for i in range(len(val_loaders))
]
test_combined_loaders = [
    CombinedLoader(test_loaders[i], test_feature_loaders[i])
    for i in range(len(test_loaders))
]

In [13]:
class MoleculeCrusher(torch.nn.Module):
    def __init__(self, mpnn, input_dim, embedding_dim=256):
        super().__init__()
        self.mpnn = mpnn
        self.embedding_dim = embedding_dim

        self.embedder = Sequential(
            Linear(input_dim, embedding_dim),
            ReLU(),
            Linear(embedding_dim, embedding_dim),
            ReLU(),
            Linear(embedding_dim, embedding_dim),
        )

        prev_in_features = mpnn.predictor.ffn[0][0].in_features
        prev_out_features = mpnn.predictor.ffn[0][0].out_features
        self.mpnn.predictor.ffn[0][0] = Linear(
            prev_in_features + embedding_dim, prev_out_features
        )

        self.bn = BatchNorm1d(embedding_dim + prev_in_features)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=2e-5)

        # self.initialize_weights()

    @property
    def device(self):
        return next(self.parameters()).device

    def loss(self, pred, target):
        return F.binary_cross_entropy(pred, target, reduction="mean")

    def metric(self, pred, target):
        return roc_auc_score(target, pred)

    def initialize_weights(self):
        for m in self.modules():
            torch.nn.init.xavier_uniform_(m.weight)
            m.bias.data.fill_(0.0)

    def forward(self, bmg: data.collate.BatchMolGraph, features: torch.Tensor):
        features_embedding = self.embedder(features)
        mol_embedding = self.mpnn.agg(self.mpnn.message_passing(bmg), bmg.batch)

        embedding = self.bn(torch.cat([mol_embedding, features_embedding], dim=1))

        return self.mpnn.predictor(embedding)

In [14]:
mpnns = []
for i in range(len(train_datasets)):
    mp = nn.BondMessagePassing()
    agg = nn.MeanAggregation()
    ffn = nn.BinaryClassificationFFN()
    batch_norm = True
    metric_list = [
        nn.metrics.BinaryAUROCMetric(),
        nn.metrics.BinaryAccuracyMetric(),
        nn.metrics.BCEMetric(),
    ]

    mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)
    mpnns.append(mpnn)

mc_models = [MoleculeCrusher(mpnns[i], 210, 256) for i in range(len(mpnns))]

In [15]:
def train(model, train_loader, val_loader, epochs, model_name):
    best_score = -float("inf")
    best_it = 0
    for epoch in range(epochs):
        model.train()
        train_loop = tqdm(train_loader, desc=f"Train epoch {epoch}")
        losses = []
        for batch in train_loop:
            batch_mol, batch_feat = batch
            bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol

            bmg.V = bmg.V.to(model.device)
            bmg.E = bmg.E.to(model.device)
            bmg.edge_index = bmg.edge_index.to(model.device)
            bmg.rev_edge_index = bmg.rev_edge_index.to(model.device)
            bmg.batch = bmg.batch.to(model.device)
            target = target.to(model.device)
            batch_feat = batch_feat.to(model.device)
            pred = model.forward(bmg, batch_feat)

            model.optimizer.zero_grad()
            loss = model.loss(pred, target)
            loss.backward()
            model.optimizer.step()
            losses.append(loss.item())
            train_loop.set_postfix(loss=np.mean(losses))

        model.eval()
        val_loop = tqdm(val_loader, desc=f"Val epoch {epoch}")
        all_preds = []
        all_targets = []
        for batch in val_loop:
            batch_mol, batch_feat = batch
            bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol

            bmg.V = bmg.V.to(model.device)
            bmg.E = bmg.E.to(model.device)
            bmg.edge_index = bmg.edge_index.to(model.device)
            bmg.rev_edge_index = bmg.rev_edge_index.to(model.device)
            bmg.batch = bmg.batch.to(model.device)
            target = target.to(model.device)
            batch_feat = batch_feat.to(model.device)

            with torch.no_grad():
                pred = model.forward(bmg, batch_feat)

            all_preds.extend(pred.view(-1).tolist())
            all_targets.extend(target.view(-1).tolist())

        roc_auc = roc_auc_score(all_targets, all_preds)
        if roc_auc > best_score:
            best_score = roc_auc
            best_it = epoch

        print(f"Validation ROC AUC: {roc_auc}")
        torch.save(model.state_dict(), f"checkpoints/{model_name}_{epoch}.pt")
    print(f"Best score: {best_score} at iteration {best_it}")

In [217]:
for i in range(len(mc_models)):
    train(
        mc_models[i], train_combined_loaders[i], val_combined_loaders[i], 25, f"mpnn{i}"
    )

Train epoch 0: 100%|██████████| 138/138 [00:06<00:00, 22.97it/s, loss=0.572]
Val epoch 0: 100%|██████████| 32/32 [00:00<00:00, 61.34it/s]


Validation ROC AUC: 0.8344893399124849


Train epoch 1: 100%|██████████| 138/138 [00:04<00:00, 29.45it/s, loss=0.493]
Val epoch 1: 100%|██████████| 32/32 [00:00<00:00, 64.40it/s]


Validation ROC AUC: 0.8631645098221767


Train epoch 2: 100%|██████████| 138/138 [00:05<00:00, 27.54it/s, loss=0.44] 
Val epoch 2: 100%|██████████| 32/32 [00:00<00:00, 62.59it/s]


Validation ROC AUC: 0.8689600595847687


Train epoch 3: 100%|██████████| 138/138 [00:05<00:00, 27.53it/s, loss=0.393]
Val epoch 3: 100%|██████████| 32/32 [00:00<00:00, 54.63it/s]


Validation ROC AUC: 0.873626757285169


Train epoch 4: 100%|██████████| 138/138 [00:04<00:00, 28.58it/s, loss=0.347]
Val epoch 4: 100%|██████████| 32/32 [00:00<00:00, 57.86it/s]


Validation ROC AUC: 0.8717841293486019


Train epoch 5: 100%|██████████| 138/138 [00:04<00:00, 28.49it/s, loss=0.302]
Val epoch 5: 100%|██████████| 32/32 [00:00<00:00, 46.09it/s]


Validation ROC AUC: 0.8694992707072589


Train epoch 6: 100%|██████████| 138/138 [00:04<00:00, 27.70it/s, loss=0.258]
Val epoch 6: 100%|██████████| 32/32 [00:00<00:00, 56.00it/s]


Validation ROC AUC: 0.8693634981224592


Train epoch 7: 100%|██████████| 138/138 [00:04<00:00, 30.22it/s, loss=0.218]
Val epoch 7: 100%|██████████| 32/32 [00:00<00:00, 48.61it/s]


Validation ROC AUC: 0.8607826707631197


Train epoch 8: 100%|██████████| 138/138 [00:04<00:00, 29.43it/s, loss=0.182]
Val epoch 8: 100%|██████████| 32/32 [00:00<00:00, 66.05it/s]


Validation ROC AUC: 0.858881854575924


Train epoch 9: 100%|██████████| 138/138 [00:04<00:00, 30.84it/s, loss=0.15] 
Val epoch 9: 100%|██████████| 32/32 [00:00<00:00, 52.15it/s]


Validation ROC AUC: 0.8512165223598052


Train epoch 10: 100%|██████████| 138/138 [00:05<00:00, 26.59it/s, loss=0.128]
Val epoch 10: 100%|██████████| 32/32 [00:00<00:00, 57.71it/s]


Validation ROC AUC: 0.8359576234366757


Train epoch 11: 100%|██████████| 138/138 [00:05<00:00, 23.81it/s, loss=0.11] 
Val epoch 11: 100%|██████████| 32/32 [00:00<00:00, 47.81it/s]


Validation ROC AUC: 0.8240115755826584


Train epoch 12: 100%|██████████| 138/138 [00:05<00:00, 25.81it/s, loss=0.0996]
Val epoch 12: 100%|██████████| 32/32 [00:00<00:00, 48.53it/s]


Validation ROC AUC: 0.839167675263011


Train epoch 13: 100%|██████████| 138/138 [00:05<00:00, 26.06it/s, loss=0.0984]
Val epoch 13: 100%|██████████| 32/32 [00:00<00:00, 59.34it/s]


Validation ROC AUC: 0.8575512832448873


Train epoch 14: 100%|██████████| 138/138 [00:05<00:00, 27.04it/s, loss=0.0931]
Val epoch 14: 100%|██████████| 32/32 [00:00<00:00, 51.44it/s]


Validation ROC AUC: 0.850615243769978


Train epoch 15: 100%|██████████| 138/138 [00:05<00:00, 26.04it/s, loss=0.0658]
Val epoch 15: 100%|██████████| 32/32 [00:00<00:00, 53.67it/s]


Validation ROC AUC: 0.8388321230177203


Train epoch 16: 100%|██████████| 138/138 [00:05<00:00, 25.70it/s, loss=0.052] 
Val epoch 16: 100%|██████████| 32/32 [00:00<00:00, 50.44it/s]


Validation ROC AUC: 0.8284920708810476


Train epoch 17: 100%|██████████| 138/138 [00:05<00:00, 26.99it/s, loss=0.0457]
Val epoch 17: 100%|██████████| 32/32 [00:00<00:00, 62.99it/s]


Validation ROC AUC: 0.8302086242745865


Train epoch 18: 100%|██████████| 138/138 [00:04<00:00, 28.15it/s, loss=0.0433]
Val epoch 18: 100%|██████████| 32/32 [00:00<00:00, 48.74it/s]


Validation ROC AUC: 0.8209974242001056


Train epoch 19: 100%|██████████| 138/138 [00:04<00:00, 28.04it/s, loss=0.0444]
Val epoch 19: 100%|██████████| 32/32 [00:00<00:00, 56.79it/s]


Validation ROC AUC: 0.8303288799925518


Train epoch 20: 100%|██████████| 138/138 [00:05<00:00, 27.20it/s, loss=0.0492]
Val epoch 20: 100%|██████████| 32/32 [00:00<00:00, 57.65it/s]


Validation ROC AUC: 0.834855925891444


Train epoch 21: 100%|██████████| 138/138 [00:05<00:00, 25.53it/s, loss=0.0531]
Val epoch 21: 100%|██████████| 32/32 [00:00<00:00, 53.31it/s]


Validation ROC AUC: 0.8292252428389661


Train epoch 22: 100%|██████████| 138/138 [00:05<00:00, 25.40it/s, loss=0.0426]
Val epoch 22: 100%|██████████| 32/32 [00:00<00:00, 62.16it/s]


Validation ROC AUC: 0.8367567420786395


Train epoch 23: 100%|██████████| 138/138 [00:05<00:00, 26.39it/s, loss=0.0405]
Val epoch 23: 100%|██████████| 32/32 [00:00<00:00, 60.23it/s]


Validation ROC AUC: 0.8467321478447073


Train epoch 24: 100%|██████████| 138/138 [00:05<00:00, 25.05it/s, loss=0.0516]
Val epoch 24: 100%|██████████| 32/32 [00:00<00:00, 54.44it/s]


Validation ROC AUC: 0.8711246625081464
Best score: 0.873626757285169 at iteration 3


Train epoch 0: 100%|██████████| 62/62 [00:03<00:00, 16.20it/s, loss=0.691]
Val epoch 0: 100%|██████████| 9/9 [00:00<00:00, 30.50it/s]


Validation ROC AUC: 0.5652040816326531


Train epoch 1: 100%|██████████| 62/62 [00:03<00:00, 16.52it/s, loss=0.659]
Val epoch 1: 100%|██████████| 9/9 [00:00<00:00, 36.64it/s]


Validation ROC AUC: 0.6100000000000001


Train epoch 2:  60%|█████▉    | 37/62 [00:02<00:01, 18.99it/s, loss=0.58] 

In [218]:
mc_models[0].load_state_dict(torch.load("checkpoints/mpnn0_18.pt"))
mc_models[1].load_state_dict(torch.load("checkpoints/mpnn1_18.pt"))
mc_models[2].load_state_dict(torch.load("checkpoints/mpnn2_10.pt"))

  mc_models[0].load_state_dict(torch.load("checkpoints/mpnn0_18.pt"))
  mc_models[1].load_state_dict(torch.load("checkpoints/mpnn1_18.pt"))
  mc_models[2].load_state_dict(torch.load("checkpoints/mpnn2_10.pt"))


<All keys matched successfully>

In [226]:
for model in mc_models:
    model.to("cpu")

In [233]:
test_preds = []
for i in range(len(test_datasets)):
    property_preds = []
    for batch in test_combined_loaders[i]:
        batch_mol, batch_feat = batch
        bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol

        with torch.no_grad():
            pred = mc_models[i].forward(bmg, batch_feat)
        property_preds.extend(pred.view(-1).tolist())
    test_preds.append(property_preds)

test_preds = sum(test_preds, [])

In [239]:
sample["Y"] = test_preds
sample.to_csv("submissions/chemprop_multi.csv", index=False)