In [1]:
%cd ..

/home/nikita/edu/competitions/admet


In [2]:
from collections import deque

import numpy as np
import pandas as pd

import torch
from torch.nn import Sequential, Linear, ReLU, BatchNorm1d
from torch.nn import functional as F

from lightning import pytorch as pl

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import RandomOverSampler

from rdkit import Chem, RDLogger
from rdkit.Chem import Descriptors

from chemprop import data, featurizers, models, nn

from tqdm import tqdm

In [3]:
df_train = pd.read_csv("data/shuffled_final_extended_train_data.csv", index_col=0)
df_test = pd.read_csv("data/test_data.csv", index_col=0)
sample = pd.read_csv("data/sample.csv")

In [22]:
RDLogger.DisableLog("rdApp.*")

def get_decsriptors_df(smiles_list):
    descriptors_list = []

    for smiles in tqdm(smiles_list):
        descriptors_list.append(
            Descriptors.CalcMolDescriptors(Chem.MolFromSmiles(smiles), 0)
        )
    return pd.DataFrame(descriptors_list).fillna(0)

train_descriptors = get_decsriptors_df(df_train["Drug"])
test_descriptors = get_decsriptors_df(df_test["Drug"])

100%|██████████| 26214/26214 [04:19<00:00, 100.94it/s]
100%|██████████| 1221/1221 [00:08<00:00, 146.17it/s]


In [24]:
train_descriptors["Ipc"] = np.log(train_descriptors["Ipc"] + 1)
test_descriptors["Ipc"] = np.log(test_descriptors["Ipc"] + 1)

In [26]:
scaler = StandardScaler()
scaler.fit(train_descriptors.astype(pd.Float32Dtype))
train_descriptors = pd.DataFrame(
    scaler.transform(train_descriptors.astype(pd.Float32Dtype)),
    columns=train_descriptors.columns,
)
test_descriptors = pd.DataFrame(
    scaler.transform(test_descriptors.astype(pd.Float32Dtype)),
    columns=test_descriptors.columns,
)

  scaler.fit(train_descriptors.astype(pd.Float32Dtype))
  scaler.transform(train_descriptors.astype(pd.Float32Dtype)),
  scaler.transform(test_descriptors.astype(pd.Float32Dtype)),


In [30]:
df_train = pd.concat([df_train.reset_index(), train_descriptors], axis=1)
df_test = pd.concat([df_test, test_descriptors], axis=1)

In [33]:
descriptors = train_descriptors.columns
def get_descriptors_features(df):
    return df[descriptors]

In [34]:
df_trains = []
df_vals = []
df_tests = []
properties = df_train.property.unique()

for prop in properties:
    subset_train = df_train[df_train.property == prop]
    subset_train, subset_val = train_test_split(
        subset_train, test_size=0.2, random_state=75, stratify=subset_train.Y
    )
    sampler = RandomOverSampler()
    subset_train = sampler.fit_resample(subset_train, subset_train.Y)[0]
    df_trains.append(subset_train)
    df_vals.append(subset_val)
    df_tests.append(df_test[df_test.property == prop])

In [35]:
df_train_total = pd.concat(df_trains, axis=0)
df_val_total = pd.concat(df_vals, axis=0)
df_test_total = pd.concat(df_tests, axis=0)


df_train_total = pd.concat(
    [
        df_train_total,
        pd.get_dummies(df_train_total["property"], prefix="property").astype(
            np.float32
        ),
    ],
    axis=1,
)
df_val_total = pd.concat(
    [
        df_val_total,
        pd.get_dummies(df_val_total["property"], prefix="property").astype(np.float32),
    ],
    axis=1,
)
df_train_total["property_1"] = df_train_total["property_1"] * df_train_total["Y"]
df_train_total["property_2"] = df_train_total["property_2"] * df_train_total["Y"]
df_train_total["property_3"] = df_train_total["property_3"] * df_train_total["Y"]

df_val_total["property_1"] = df_val_total["property_1"] * df_val_total["Y"]
df_val_total["property_2"] = df_val_total["property_2"] * df_val_total["Y"]
df_val_total["property_3"] = df_val_total["property_3"] * df_val_total["Y"]

In [36]:
df_train_total = df_train_total.sample(frac=1)

In [68]:
train_data_total = []
val_data_total = []
test_data_total = []

train_data_descriptors = []
val_data_descriptors = []
test_data_descriptors = []
train_data_total = []

for smi, y, features in zip(
    df_train_total["Drug"],
    df_train_total[["property", "property_1", "property_2", "property_3"]].values,
    get_descriptors_features(df_train_total).to_numpy().astype(np.float32)
):
    try:
        train_data_total.append(data.MoleculeDatapoint.from_smi(smi, y))
        train_data_descriptors.append(features)
    except Exception:
        continue

val_data_total = []

for smi, y, features in zip(
    df_val_total["Drug"],
    df_val_total[["property", "property_1", "property_2", "property_3"]].values,
    get_descriptors_features(df_val_total).to_numpy().astype(np.float32)
):
    try:
        val_data_total.append(data.MoleculeDatapoint.from_smi(smi, y))
        val_data_descriptors.append(features)
    except Exception:
        continue

test_data_total = [
    data.MoleculeDatapoint.from_smi(smi, y)
    for smi, y in zip(df_test_total["Drug"], df_test_total[["property"]].values)
]

train_data_descriptors = np.array(train_data_descriptors)
val_data_descriptors = np.array(val_data_descriptors)
test_data_descriptors = (
    get_descriptors_features(df_test_total).to_numpy().astype(np.float32)
)

In [71]:
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()
train_dataset = data.MoleculeDataset(train_data_total, featurizer)
val_dataset = data.MoleculeDataset(val_data_total, featurizer)
test_dataset = data.MoleculeDataset(test_data_total, featurizer)

train_loader = data.build_dataloader(train_dataset, shuffle=False, batch_size=64)

val_loader = data.build_dataloader(val_dataset, shuffle=False, batch_size=64)

test_loader = data.build_dataloader(test_dataset, shuffle=False, batch_size=64)

train_feature_loader = torch.utils.data.DataLoader(
    train_data_descriptors, batch_size=64, shuffle=False
)
val_feature_loader = torch.utils.data.DataLoader(
    val_data_descriptors, batch_size=64, shuffle=False
)
test_feature_loader = torch.utils.data.DataLoader(
    test_data_descriptors, batch_size=64, shuffle=False
)

In [72]:
class CombinedLoader(torch.utils.data.DataLoader):
    def __init__(self, *loaders):
        self.loaders = loaders

    def __iter__(self):
        for item in zip(*self.loaders):
            yield item

    def __len__(self):
        return min(len(loader) for loader in self.loaders)


train_combined_loader = CombinedLoader(train_loader, train_feature_loader)

val_combined_loader = CombinedLoader(val_loader, val_feature_loader)

test_combined_loader = CombinedLoader(test_loader, test_feature_loader)

In [76]:
class MoleculeCrusher(torch.nn.Module):
    def __init__(self, mpnn, input_dim, embedding_dim=256):
        super().__init__()
        self.mpnn = mpnn
        self.embedding_dim = embedding_dim

        self.embedder = Sequential(
            Linear(input_dim, embedding_dim),
            ReLU(),
            Linear(embedding_dim, embedding_dim),
            ReLU(),
            Linear(embedding_dim, embedding_dim),
        )

        prev_in_features = mpnn.predictor.ffn[0][0].in_features
        prev_out_features = mpnn.predictor.ffn[0][0].out_features
        self.mpnn.predictor.ffn[0][0] = Linear(
            prev_in_features + embedding_dim, prev_out_features
        )
        self.mpnn.predictor.ffn[1][2] = Linear(prev_out_features, 3)

        self.bn = BatchNorm1d(embedding_dim + prev_in_features)

        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=3e-4)

    @property
    def device(self):
        return next(self.parameters()).device

    def loss(self, pred, target):
        idx = target[:, 0].to(torch.long).unsqueeze(1)
        target = torch.gather(target, 1, idx)
        pred = torch.gather(pred, 1, idx - 1)
        return F.binary_cross_entropy(pred, target, reduction="mean")

    def metric(self, pred, target):
        idx = target[:, 0].to(torch.long).unsqueeze(1)
        target = torch.gather(target, 1, idx)
        pred = torch.gather(pred, 1, idx - 1)
        return roc_auc_score(target, pred)

    def forward(self, bmg: data.collate.BatchMolGraph, features: torch.Tensor):
        features_embedding = self.embedder(features)
        mol_embedding = self.mpnn.agg(self.mpnn.message_passing(bmg), bmg.batch)

        embedding = self.bn(torch.cat([mol_embedding, features_embedding], dim=1))

        return self.mpnn.predictor(embedding)

In [77]:
mp = nn.BondMessagePassing()
agg = nn.MeanAggregation()
ffn = nn.BinaryClassificationFFN()
batch_norm = True
metric_list = [
    nn.metrics.BinaryAUROCMetric(),
    nn.metrics.BinaryAccuracyMetric(),
    nn.metrics.BCEMetric(),
]

mpnn = models.MPNN(mp, agg, ffn, batch_norm, metric_list)


mc_model = MoleculeCrusher(mpnn, 210, 256)

In [78]:
def train(model, train_loader, val_loader, epochs, model_name):
    best_score = -float("inf")
    best_it = 0
    for epoch in range(epochs):
        model.train()
        train_loop = tqdm(train_loader, desc=f"Train epoch {epoch}")
        losses = []
        for batch in train_loop:
            batch_mol, batch_feat = batch
            bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol

            bmg.V = bmg.V.to(model.device)
            bmg.E = bmg.E.to(model.device)
            bmg.edge_index = bmg.edge_index.to(model.device)
            bmg.rev_edge_index = bmg.rev_edge_index.to(model.device)
            bmg.batch = bmg.batch.to(model.device)
            target = target.to(model.device)
            batch_feat = batch_feat.to(model.device)
            pred = model.forward(bmg, batch_feat)

            model.optimizer.zero_grad()
            loss = model.loss(pred, target)
            loss.backward()
            model.optimizer.step()
            losses.append(loss.item())
            train_loop.set_postfix(loss=np.mean(losses))

        model.eval()
        val_loop = tqdm(val_loader, desc=f"Val epoch {epoch}")
        all_preds = []
        all_targets = []
        for batch in val_loop:
            batch_mol, batch_feat = batch
            bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol

            bmg.V = bmg.V.to(model.device)
            bmg.E = bmg.E.to(model.device)
            bmg.edge_index = bmg.edge_index.to(model.device)
            bmg.rev_edge_index = bmg.rev_edge_index.to(model.device)
            bmg.batch = bmg.batch.to(model.device)
            target = target.to(model.device)
            batch_feat = batch_feat.to(model.device)

            with torch.no_grad():
                pred = model.forward(bmg, batch_feat)

            all_preds.extend(pred.tolist())
            all_targets.extend(target.tolist())

        roc_auc = model.metric(torch.tensor(all_preds), torch.tensor(all_targets))
        if roc_auc > best_score:
            best_score = roc_auc
            best_it = epoch

        print(f"Validation ROC AUC: {roc_auc}")
        torch.save(model.state_dict(), f"checkpoints/{model_name}_{epoch}.pt")
    print(f"Best score: {best_score} at iteration {best_it}")

In [79]:
train(mc_model, train_combined_loader, val_combined_loader, 25, "mpnn_general")

Train epoch 0: 100%|██████████| 317/317 [00:27<00:00, 11.35it/s, loss=0.472]
Val epoch 0: 100%|██████████| 66/66 [00:02<00:00, 24.20it/s]


Validation ROC AUC: 0.9200653257065586


Train epoch 1: 100%|██████████| 317/317 [00:28<00:00, 11.27it/s, loss=0.311]
Val epoch 1: 100%|██████████| 66/66 [00:02<00:00, 23.50it/s]


Validation ROC AUC: 0.9492390832504196


Train epoch 2: 100%|██████████| 317/317 [00:30<00:00, 10.45it/s, loss=0.214]
Val epoch 2: 100%|██████████| 66/66 [00:03<00:00, 18.59it/s]


Validation ROC AUC: 0.9658040046329198


Train epoch 3: 100%|██████████| 317/317 [00:27<00:00, 11.49it/s, loss=0.156]
Val epoch 3: 100%|██████████| 66/66 [00:02<00:00, 22.81it/s]


Validation ROC AUC: 0.9755073018427958


Train epoch 4: 100%|██████████| 317/317 [00:28<00:00, 11.11it/s, loss=0.118]
Val epoch 4: 100%|██████████| 66/66 [00:02<00:00, 24.55it/s]


Validation ROC AUC: 0.9803839205796144


Train epoch 5: 100%|██████████| 317/317 [00:32<00:00,  9.88it/s, loss=0.0931]
Val epoch 5: 100%|██████████| 66/66 [00:03<00:00, 21.47it/s]


Validation ROC AUC: 0.9838131120083284


Train epoch 6: 100%|██████████| 317/317 [00:31<00:00,  9.96it/s, loss=0.0758]
Val epoch 6: 100%|██████████| 66/66 [00:02<00:00, 23.50it/s]


Validation ROC AUC: 0.9852766258437023


Train epoch 7: 100%|██████████| 317/317 [00:31<00:00, 10.14it/s, loss=0.0628]
Val epoch 7: 100%|██████████| 66/66 [00:02<00:00, 22.83it/s]


Validation ROC AUC: 0.9863310136721077


Train epoch 8: 100%|██████████| 317/317 [00:30<00:00, 10.47it/s, loss=0.0541]
Val epoch 8: 100%|██████████| 66/66 [00:02<00:00, 23.47it/s]


Validation ROC AUC: 0.9876893648366761


Train epoch 9:  44%|████▍     | 140/317 [00:13<00:17, 10.28it/s, loss=0.0527]


KeyboardInterrupt: 

In [80]:
mc_model.load_state_dict(torch.load("checkpoints/mpnn_general_5.pt"))
mc_model.to("cpu")
mc_model.eval()

  mc_model.load_state_dict(torch.load("checkpoints/mpnn_general_5.pt"))


MoleculeCrusher(
  (mpnn): MPNN(
    (message_passing): BondMessagePassing(
      (W_i): Linear(in_features=86, out_features=300, bias=False)
      (W_h): Linear(in_features=300, out_features=300, bias=False)
      (W_o): Linear(in_features=372, out_features=300, bias=True)
      (dropout): Dropout(p=0.0, inplace=False)
      (tau): ReLU()
      (V_d_transform): Identity()
      (graph_transform): Identity()
    )
    (agg): MeanAggregation()
    (bn): BatchNorm1d(300, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (predictor): BinaryClassificationFFN(
      (ffn): MLP(
        (0): Sequential(
          (0): Linear(in_features=556, out_features=300, bias=True)
        )
        (1): Sequential(
          (0): ReLU()
          (1): Dropout(p=0.0, inplace=False)
          (2): Linear(in_features=300, out_features=3, bias=True)
        )
      )
      (criterion): BCELoss(task_weights=[[1.0]])
      (output_transform): Identity()
    )
    (X_d_transform): Identity()

In [81]:
probs = []
props = []
targs = []

for batch in val_combined_loader:
    batch_mol, batch_feat = batch
    bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol
    with torch.no_grad():
        preds = mc_model.forward(bmg, batch_feat)
    idx = target[:, 0].to(torch.long)

    probs.append(preds.gather(1, idx.unsqueeze(1) - 1).squeeze(1))
    props.append(idx)
    targs.append(target.gather(1, idx.unsqueeze(1)).squeeze(1))

probs = torch.cat(probs, dim=0)
targs = torch.cat(targs, dim=0)
props = torch.cat(props, dim=0)

In [82]:
print(roc_auc_score(targs[props == 1], probs[props == 1]))
print(roc_auc_score(targs[props == 2], probs[props == 2]))
print(roc_auc_score(targs[props == 3], probs[props == 3]))

0.976649510607356
0.9801587301587301
0.9963877963509232


In [83]:
probs = []

for batch in test_combined_loader:
    batch_mol, batch_feat = batch
    bmg, V_d, X_d, target, weights, lt_mask, gt_mask = batch_mol
    with torch.no_grad():
        preds = mc_model.forward(bmg, batch_feat)
    idx = target[:, 0].to(torch.long)

    probs.append(preds.gather(1, idx.unsqueeze(1) - 1).squeeze(1))

probs = torch.cat(probs, dim=0)

In [84]:
sample["Y"] = probs
sample.to_csv("submissions/gnn_dirty.csv", index=False)