In [8]:
import random
from tqdm import tqdm

import numpy as np
import pandas as pd

from rdkit import Chem

from atom3d.util.voxelize import dotdict, get_center, gen_rot_matrix, get_grid

from sklearn.model_selection import KFold, train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR

from torch.utils.data import Dataset, DataLoader

In [9]:
class VoxelDataset(Dataset):
    def __init__(
        self,
        df: pd.DataFrame
    ):
        super().__init__()
        self.df = df
        self.grid_config =  dotdict({
            # Mapping from elements to position in channel dimension.
            'element_mapping': {
                'H': 0,
                'C': 1,
                'O': 2,
                'N': 3,
                'F': 4,
                'P': 5,
                'S': 6,
                'I': 7,
                'B': 8,
                'Br': 9,
                'Cl': 10,
                'Si': 11
            },
            # Radius of the grids to generate, in angstroms.
            'radius': 7.5,
            # Resolution of each voxel, in angstroms.
            'resolution': 1.0,
            # Number of directions to apply for data augmentation.
            'num_directions': 20,
            # Number of rolls to apply for data augmentation.
            'num_rolls': 20,
        })
        
        self.what_set = "train_set" if self.df.iloc[0][0].startswith("train") else "test_set"
        
    def _voxelize(self, atoms):
        # Use center of molecule as subgrid center
        pos = atoms[['x', 'y', 'z']].astype(np.float32)
        center = get_center(pos)
        # Generate random rotation matrix
        rot_mat = gen_rot_matrix(self.grid_config)
        # Transform protein/ligand into voxel grids and rotate
        grid = get_grid(atoms, center, config=self.grid_config, rot_mat=rot_mat)
        # Last dimension is atom channel, so we need to move it to the front
        # per pytroch style
        grid = np.moveaxis(grid, -1, 0)
        
        return torch.tensor(grid, dtype=torch.float)
    
    def _mol2df(self, mol):
        conf = mol.GetConformer()

        atoms = pd.DataFrame({"element": [atom.GetSymbol() for atom in mol.GetAtoms()]})
        pos = pd.DataFrame(conf.GetPositions(), columns=["x", "y", "z"])

        df = pd.concat([atoms, pos], axis=1)

        return df
        
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        ex = Chem.MolFromMolFile(f"../data/mol_files/{self.what_set}/{row.prefix}_ex.mol", removeHs=False)
        g = Chem.MolFromMolFile(f"../data/mol_files/{self.what_set}/{row.prefix}_g.mol", removeHs=False)

        ex_df = self._mol2df(ex)
        g_df = self._mol2df(g)

        X_ex = self._voxelize(ex_df)
        X_g = self._voxelize(g_df)
        y = torch.tensor([row.Reorg_g, row.Reorg_ex], dtype=torch.float)
        
        return X_g, X_ex, y


In [10]:
train_df = pd.read_csv("../data/train_set.ReorgE.csv", names=["prefix", "SMILES", "Reorg_g", "Reorg_ex"], header=0)
train_df, val_df = train_test_split(train_df, test_size=0.2, random_state=42)

test_df = pd.read_csv("../data/test_set.csv", names=["prefix", "SMILES", "Reorg_g", "Reorg_ex"], header=0)

train_data = VoxelDataset(train_df)
val_data = VoxelDataset(val_df)
test_data = VoxelDataset(test_df)

In [12]:
batch_size = 32

train_dataloader = DataLoader(train_data, batch_size=batch_size, num_workers=4)
val_dataloader = DataLoader(val_data, batch_size=batch_size, num_workers=4)
test_dataloader = DataLoader(test_data, batch_size=batch_size, num_workers=4)

In [13]:
class ConvBlock(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
        self.main = nn.Sequential(
            nn.Conv3d(in_channels=12, out_channels=64, kernel_size=3),
            nn.ReLU(),
            nn.Dropout3d(0.2),
            nn.MaxPool3d(2),
            nn.Conv3d(in_channels=64, out_channels=128, kernel_size=3),
            nn.ReLU(),
            nn.Dropout3d(0.2),
            nn.MaxPool3d(2)
        )
    
    def forward(self, x):
        x = self.main(x)
        
        return x.view(x.size(0), -1)
    
class MLP(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
        self.main = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2)
        )
        
    def forward(self, x):
        return self.main(x)

        
class SimpleCNN(nn.Module):
    def __init__(
        self
    ):
        super().__init__()
        self.conv = ConvBlock()
        self.mlp = MLP()
        
    def forward(self, g, ex):
        g = self.conv(g)
        ex = self.conv(ex)
        
        x = torch.cat([g, ex], dim=1)
        
        return self.mlp(x)

In [14]:
num_epochs = 10

model = SimpleCNN()
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs * len(train_dataloader))
device = torch.device("cuda:1")

model.to(device)


for epoch in range(num_epochs):
    print(f"Epoch {epoch}")
    train_loss, val_loss = 0., 0.
    
    # train
    model.train()

    for (g, ex, y) in tqdm(train_dataloader):
        optimizer.zero_grad()

        g = g.to(device)
        ex = ex.to(device)
        y = y.to(device)
        
        pred = model(g, ex)
        loss = criterion(pred, y)
        
        loss.backward()
        optimizer.step()
        scheduler.step()
        train_loss += loss.item()
    
    train_loss /= len(train_dataloader)
    
    # validation
    model.eval()
    
    for (g, ex, y) in val_dataloader:
        g = g.to(device)
        ex = ex.to(device)
        y = y.to(device)
        
        pred = model(g, ex)
        loss = criterion(pred, y)
        val_loss += loss * len(y)
    
    val_loss /= len(val_data)
    
    print(f"Train Loss: {train_loss}")
    print(f"Val Loss: {val_loss}")

Epoch 0


100%|██████████| 454/454 [01:45<00:00,  4.29it/s]


RuntimeError: CUDA out of memory. Tried to allocate 22.00 MiB (GPU 1; 23.70 GiB total capacity; 7.73 GiB already allocated; 25.06 MiB free; 8.05 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
preds = []

model.eval()
for (g, ex, y) in tqdm(test_dataloader):
    g = g.to(device)
    ex = ex.to(device)
    y = y.to(device)
        
    pred = model(g, ex)
    preds.append(pred)

    preds = torch.cat(preds).detach().cpu().numpy()

sub_df = pd.read_csv("../data/sample_submission.csv")
sub_df["Reorg_g"] = preds[:, 0]
sub_df["Reorg_ex"] = preds[:, 1]
sub_df.to_csv("submission.csv", sep=",", index=False)