In [1]:
%cd ..

/home/nikita/edu/competitions/admet


In [2]:
from collections import deque
from tqdm import tqdm

import numpy as np
import pandas as pd

import torch
from torch import nn
import torch.nn.functional as F

from torch_geometric.data import Data, DataLoader
from torch_geometric.nn import GINConv, GATv2Conv, GCNConv, Sequential
from torch_geometric.nn.aggr import AttentionalAggregation, MeanAggregation

from rdkit import Chem
from rdkit.Chem import AllChem

from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from sklearn.preprocessing import StandardScaler
from imblearn.over_sampling import RandomOverSampler

In [3]:
def smiles_to_graph(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None

    atom_features = []
    for atom in mol.GetAtoms():
        atom_features.append(atom.GetAtomicNum())

    edges = []
    for bond in mol.GetBonds():
        i = bond.GetBeginAtomIdx()
        j = bond.GetEndAtomIdx()
        edges.append((i, j))
        edges.append((j, i))

    # Convert to PyTorch tensors
    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    x = torch.tensor(atom_features, dtype=torch.float).view(-1, 1)

    return Data(x=x, edge_index=edge_index)

In [4]:
df_train = pd.read_csv("data/train_admet.csv", index_col=0)
df_test = pd.read_csv("data/test_data.csv", index_col=0)
sample = pd.read_csv("data/sample.csv")

In [5]:
df_trains = []
df_vals = []
df_tests = []

properties = df_train.property.unique()
for prop in properties:
    segment = df_train[df_train["property"] == prop]
    train, val = train_test_split(
        segment, test_size=0.2, random_state=75, stratify=segment.Y
    )
    df_trains.append(train)
    df_vals.append(val)
    df_tests.append(df_test[df_test["property"] == prop])

sampler = RandomOverSampler(random_state=0)

for i in range(len(df_trains)):
    df_trains[i] = sampler.fit_resample(df_trains[i], df_trains[i].Y)[0]
    df_vals[i] = sampler.fit_resample(df_vals[i], df_vals[i].Y)[0]

In [6]:
train_datasets = []
val_datasets = []

for i in range(len(df_trains)):
    train_data = []
    for j, row in df_trains[i].iterrows():
        features = smiles_to_graph(row["Drug"])
        target = row["Y"]

        features.y = target
        train_data.append(features)
    train_datasets.append(train_data)

    val_data = []
    for j, row in df_vals[i].iterrows():
        features = smiles_to_graph(row["Drug"])
        target = row["Y"]

        features.y = target
        val_data.append(features)
    val_datasets.append(val_data)



In [7]:
train_dataloaders = []
val_dataloaders = []

for i in range(len(train_datasets)):
    train_dataloaders.append(DataLoader(train_datasets[i], batch_size=32, shuffle=True))
    val_dataloaders.append(DataLoader(val_datasets[i], batch_size=32, shuffle=False))



In [17]:
class MoleculeCrusher(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings=100, embedding_dim=128)
        self.gnn = Sequential(
            "x, edge_index, batch",
            [
                (GCNConv(128, 256), "x, edge_index -> x"),
                nn.SiLU(),
                (GCNConv(256, 256), "x, edge_index -> x"),
                nn.SiLU(),
                (GCNConv(256, 512), "x, edge_index -> x"),
                nn.SiLU(),
                (GCNConv(512, 512), "x, edge_index -> x"),
                nn.SiLU(),
                (GCNConv(512, 1024), "x, edge_index -> x"),
                nn.SiLU(),
                (GCNConv(1024, 1024), "x, edge_index -> x"),
                nn.SiLU(),
                (
                    MeanAggregation(),
                    "x, batch -> x",
                ),
                nn.SiLU(),
                (nn.Linear(1024, 1), "x -> x"),
            ],
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)

    @property
    def device(self):
        return next(self.parameters()).device

    def compute_loss(self, pred, target):
        return F.binary_cross_entropy_with_logits(pred, target)
    
    def compute_auc(self, pred, target):
        return roc_auc_score(target, pred)

    def forward(self, data: Data):
        x = data.x.view(-1).to(torch.long)
        idx = data.edge_index
        batch = data.batch

        x = self.embedding(x)
        x = self.gnn(x, idx, batch)

        return x.view(-1)

    @torch.no_grad()
    def predict(self, data: Data):
        return torch.sigmoid(self.forward(data))

    def train(self, epochs, train_dataloader, val_dataloader):
        for epoch in range(epochs):
            train_loss_buffer = deque(maxlen=4)
            train_loop = tqdm(train_dataloader, desc=f"Train epoch {epoch}")
            for data in train_loop:
                data = data.to(self.device)
                target = data.y.to(torch.float32)
                pred = self.forward(data)
                loss = self.compute_loss(pred, target)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                train_loss_buffer.append(loss.item())
                train_loop.set_postfix(loss=np.mean(train_loss_buffer))

            val_loss_buffer = list()
            val_loop = tqdm(val_dataloader, desc=f"Val epoch {epoch}")
            all_targets = []
            all_preds = []
            with torch.no_grad():
                for data in val_loop:
                    data = data.to(self.device)
                    target = data.y.to(torch.float32)
                    pred = self.forward(data)
                    all_targets.append(target.detach().cpu().numpy())
                    all_preds.append(pred.detach().cpu().numpy())
                    loss = self.compute_loss(pred, target)
                    val_loss_buffer.append(loss.item())
                    val_loop.set_postfix(loss=np.mean(val_loss_buffer))

            all_preds = np.concatenate(all_preds, axis=0).reshape(-1)
            all_targets = np.concatenate(all_targets, axis=0).reshape(-1)
            print(self.compute_auc(all_preds, all_targets))
            val_loop.set_postfix(loss=self.compute_auc(all_preds, all_targets))

In [18]:
model = MoleculeCrusher()
model.to("cuda")
model.train(epochs=500, train_dataloader=train_dataloaders[0], val_dataloader=val_dataloaders[0])

Train epoch 0:   0%|          | 0/138 [00:00<?, ?it/s]

Train epoch 0: 100%|██████████| 138/138 [00:01<00:00, 91.94it/s, loss=0.665]
Val epoch 0: 100%|██████████| 35/35 [00:00<00:00, 183.97it/s, loss=0.66] 


0.6998775861163622


Train epoch 1: 100%|██████████| 138/138 [00:01<00:00, 99.21it/s, loss=0.639] 
Val epoch 1: 100%|██████████| 35/35 [00:00<00:00, 182.41it/s, loss=0.632]


0.7090586273892039


Train epoch 2: 100%|██████████| 138/138 [00:01<00:00, 98.88it/s, loss=0.629]
Val epoch 2: 100%|██████████| 35/35 [00:00<00:00, 183.16it/s, loss=0.673]


0.7099168373240916


Train epoch 3: 100%|██████████| 138/138 [00:01<00:00, 99.35it/s, loss=0.649]
Val epoch 3: 100%|██████████| 35/35 [00:00<00:00, 181.50it/s, loss=0.658]


0.6739606306448225


Train epoch 4: 100%|██████████| 138/138 [00:01<00:00, 98.09it/s, loss=0.67] 
Val epoch 4: 100%|██████████| 35/35 [00:00<00:00, 176.41it/s, loss=0.618]


0.7223108328082336


Train epoch 5: 100%|██████████| 138/138 [00:01<00:00, 98.14it/s, loss=0.609]
Val epoch 5: 100%|██████████| 35/35 [00:00<00:00, 181.79it/s, loss=0.623]


0.7239123871035498


Train epoch 6: 100%|██████████| 138/138 [00:01<00:00, 98.43it/s, loss=0.622]
Val epoch 6: 100%|██████████| 35/35 [00:00<00:00, 168.99it/s, loss=0.612]


0.7304006511237134


Train epoch 7: 100%|██████████| 138/138 [00:01<00:00, 97.73it/s, loss=0.688]
Val epoch 7: 100%|██████████| 35/35 [00:00<00:00, 176.84it/s, loss=0.606]


0.7327028854232305


Train epoch 8: 100%|██████████| 138/138 [00:01<00:00, 97.54it/s, loss=0.641]
Val epoch 8: 100%|██████████| 35/35 [00:00<00:00, 168.52it/s, loss=0.633]


0.7230049490653224


Train epoch 9: 100%|██████████| 138/138 [00:01<00:00, 98.15it/s, loss=0.59] 
Val epoch 9: 100%|██████████| 35/35 [00:00<00:00, 156.54it/s, loss=0.614]


0.7319250813904641


Train epoch 10: 100%|██████████| 138/138 [00:01<00:00, 98.75it/s, loss=0.675]
Val epoch 10: 100%|██████████| 35/35 [00:00<00:00, 177.69it/s, loss=0.619]


0.7311160995589161


Train epoch 11: 100%|██████████| 138/138 [00:01<00:00, 97.28it/s, loss=0.802]
Val epoch 11: 100%|██████████| 35/35 [00:00<00:00, 178.37it/s, loss=0.604]


0.7363572516278093


Train epoch 12: 100%|██████████| 138/138 [00:01<00:00, 98.18it/s, loss=0.658]
Val epoch 12: 100%|██████████| 35/35 [00:00<00:00, 173.69it/s, loss=0.615]


0.7389121901911363


Train epoch 13: 100%|██████████| 138/138 [00:01<00:00, 98.17it/s, loss=0.62] 
Val epoch 13: 100%|██████████| 35/35 [00:00<00:00, 175.31it/s, loss=0.615]


0.7393109378281874


Train epoch 14: 100%|██████████| 138/138 [00:01<00:00, 97.13it/s, loss=0.661]
Val epoch 14: 100%|██████████| 35/35 [00:00<00:00, 176.27it/s, loss=0.621]


0.7295473639991599


Train epoch 15: 100%|██████████| 138/138 [00:01<00:00, 97.18it/s, loss=0.66] 
Val epoch 15: 100%|██████████| 35/35 [00:00<00:00, 179.27it/s, loss=0.625]


0.7352299280613316


Train epoch 16: 100%|██████████| 138/138 [00:01<00:00, 97.07it/s, loss=0.663]
Val epoch 16: 100%|██████████| 35/35 [00:00<00:00, 156.10it/s, loss=0.612]


0.7406351738080236


Train epoch 17: 100%|██████████| 138/138 [00:01<00:00, 95.92it/s, loss=0.702]
Val epoch 17: 100%|██████████| 35/35 [00:00<00:00, 176.36it/s, loss=0.603]


0.7405974322621298


Train epoch 18: 100%|██████████| 138/138 [00:01<00:00, 96.84it/s, loss=0.63] 
Val epoch 18: 100%|██████████| 35/35 [00:00<00:00, 180.69it/s, loss=0.616]


0.7225881511237136


Train epoch 19: 100%|██████████| 138/138 [00:01<00:00, 96.10it/s, loss=0.646]
Val epoch 19: 100%|██████████| 35/35 [00:00<00:00, 174.74it/s, loss=0.603]


0.7496980676328502


Train epoch 20:  61%|██████    | 84/138 [00:00<00:00, 97.02it/s, loss=0.648]


KeyboardInterrupt: 

In [None]:
model.train(epochs=4000, train_dataloader=train_dataloaders[0], val_dataloader=val_dataloaders[0])

In [123]:
val_sample = next(iter(val_dataloaders[0])).to(model.device)
pred = model.predict(val_sample)
pred = (pred >= 0.5).to(torch.float32)
print(pred.to(torch.long))
print(val_sample.y)

tensor([0, 0, 1, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0,
        0, 0, 1, 0, 0, 1, 1, 0], device='cuda:0')
tensor([1, 0, 1, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1,
        0, 0, 1, 0, 0, 1, 1, 0], device='cuda:0')
