# Load packages

In [1]:
%load_ext autoreload
%autoreload 2

import os
os.chdir('..')

from liver_microsome_prediction.utils import *

In [20]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepchem as dc
import lightning as L

from torch.optim import AdamW
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks import ModelCheckpoint
from deepchem.models.torch_models import MPNNModel
from sklearn.metrics import mean_squared_error
from sklearn.preprocessing import StandardScaler, MinMaxScaler

In [3]:
def smiles_split(df, smiles, seed=42, k_fold=5, splitter='scaffold'):
    import deepchem as dc
    Xs, ys = np.arange(len(smiles)), np.ones(len(smiles))
    dataset = dc.data.DiskDataset.from_numpy(X=Xs,y=ys,w=np.zeros(len(smiles)),ids=smiles)
    if splitter == 'random':
        splitter = dc.splits.RandomSplitter()
    elif splitter == 'scaffold':
        splitter = dc.splits.ScaffoldSplitter()
    elif splitter == 'fingerprints':
        splitter = dc.splits.FingerprintSplitter()
    folds = splitter.k_fold_split(dataset, k=k_fold, seed=seed)
    dfs = []
    for fold in folds:
        train_indices = fold[0].X
        val_indices = fold[1].X
        train_df = df.iloc[train_indices].reset_index(drop=True)
        val_df = df.iloc[val_indices].reset_index(drop=True)
        dfs.append((train_df, val_df))
    return dfs

In [4]:
BATCH_SIZE=32
SEED=42
K_FOLD=5

In [5]:
L.seed_everything(SEED)

Global seed set to 42


42

In [6]:
train_df = pd.read_csv('./data/train.csv')
test_df = pd.read_csv('./data/test.csv')

In [7]:
test_df['MLM']=0
test_df['HLM']=0

In [8]:
train_smiles = train_df['SMILES'].tolist()

In [9]:
for train_fold, val_fold in smiles_split(train_df, train_smiles, seed=SEED, k_fold=K_FOLD, splitter='fingerprints'):
    break

In [12]:
for col in ['AlogP','Molecular_Weight','Num_H_Acceptors','Num_H_Donors','Num_RotatableBonds', 'LogD', 'Molecular_PolarSurfaceArea']:
    scaler = MinMaxScaler()
    train_fold[col] = scaler.fit_transform(train_fold[[col]].values).reshape(-1)
    val_fold[col] = scaler.transform(val_fold[[col]].values).reshape(-1)
    test_df[col] = scaler.transform(test_df[[col]].values).reshape(-1)
    
    train_nan_value = train_fold[col].mean()
    train_fold[col] = train_fold[col].fillna(train_nan_value)
    val_fold[col] = val_fold[col].fillna(train_nan_value)
    test_df[col] = test_df[col].fillna(train_nan_value)

In [13]:
featurizer = dc.feat.MolGraphConvFeaturizer(use_edges=True)

In [14]:
train_X = featurizer.featurize(train_fold['SMILES'].tolist())
train_w = train_fold[['AlogP','Molecular_Weight','Num_H_Acceptors','Num_H_Donors','Num_RotatableBonds', 'LogD', 'Molecular_PolarSurfaceArea']].values
train_dataset = dc.data.NumpyDataset(X=train_X, y=train_fold[['MLM','HLM']].values, w=train_w)

val_X = featurizer.featurize(val_fold['SMILES'].tolist())
val_w = val_fold[['AlogP','Molecular_Weight','Num_H_Acceptors','Num_H_Donors','Num_RotatableBonds', 'LogD', 'Molecular_PolarSurfaceArea']].values
val_dataset = dc.data.NumpyDataset(X=val_X, y=val_fold[['MLM','HLM']].values, w=val_w)

test_X = featurizer.featurize(test_df['SMILES'].tolist())
test_w = test_df[['AlogP','Molecular_Weight','Num_H_Acceptors','Num_H_Donors','Num_RotatableBonds', 'LogD', 'Molecular_PolarSurfaceArea']].values
test_dataset = dc.data.NumpyDataset(X=test_X, y=test_df[['MLM','HLM']].values, w=test_w)

In [15]:
def collate_fn(samples):
    X = [sample[0] for sample in samples]
    y = torch.Tensor([sample[1] for sample in samples])
    w = torch.Tensor([sample[2] for sample in samples])
    return ([X],y,w)

In [16]:
train_datas = []
val_datas = []
test_datas = []

for x,y,w in zip(train_dataset.X, train_dataset.y, train_dataset.w):
    train_datas.append((x,y,w))
    
for x,y,w in zip(val_dataset.X, val_dataset.y, val_dataset.w):
    val_datas.append((x,y,w))
    
for x,y,w in zip(test_dataset.X, test_dataset.y, test_dataset.w):
    test_datas.append((x,y,w))
    
train_dataloader = DataLoader(train_datas, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_dataloader = DataLoader(val_datas, batch_size=BATCH_SIZE*2, collate_fn=collate_fn)
test_dataloader = DataLoader(test_datas, batch_size=BATCH_SIZE*2, collate_fn=collate_fn)

In [17]:
class LitMPNNSMILESClassification(L.LightningModule):
    def __init__(self, batch_size, node_out_feats=64, n_tasks=2):
        super().__init__()
        model = MPNNModel(
            mode='regression',
            n_tasks=n_tasks,
            node_out_feats=node_out_feats,
            batch_size=batch_size,
        )
        self._prepare_batch = lambda batch : model._prepare_batch(batch)
        self.model = model.model
        self.model.model.predict = nn.Identity()
        self.batch_size=batch_size
        self.classifier = nn.Sequential(
            nn.LazyLinear(node_out_feats//2),
            nn.ReLU(),
            nn.LazyLinear(n_tasks)
        )
        
        self.validation_step_outputs = []
        
    def forward(self, x, xp):
        x = self.model(x)
        x = torch.cat([x, xp], dim=-1)
        x_out = self.classifier(x)
        return x_out
    
    def training_step(self, batch, batch_idx):
        x, *_ = self._prepare_batch(batch)
        y_true, xp = batch[1]/100, batch[2]
        y_pred = self(x, xp)
        loss1 = F.mse_loss(y_pred[:,0].flatten(), y_true[:,0].flatten())
        loss2 = F.mse_loss(y_pred[:,1].flatten(), y_true[:,1].flatten())
        loss = (loss1**0.5 + loss2**0.5)/2
        self.log_dict({"train_loss": loss}, on_step=True, prog_bar=True, batch_size=self.batch_size)
        return loss
    
    def validation_step(self, batch, batch_idx):
        x, *_ = self._prepare_batch(batch)
        y_true, xp = batch[1], batch[2]
        y_pred = self(x, xp)*100
        loss1 = F.mse_loss(y_pred[:,0].flatten(), y_true[:,0].flatten())
        loss2 = F.mse_loss(y_pred[:,1].flatten(), y_true[:,1].flatten())
        loss = (loss1, loss2)
        self.validation_step_outputs.append(loss)
        return loss
    
    def on_validation_epoch_end(self):
        loss = torch.Tensor(self.validation_step_outputs)
        loss1, loss2 = loss[:, 0], loss[:, 1]
        loss = ((loss1.mean())**0.5 + (loss2.mean())**0.5)/2
        self.log("val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.validation_step_outputs.clear()
        
    def predict_step(self, batch, batch_idx):
        x, *_ = self._prepare_batch(batch)
        y_true, xp = batch[1], batch[2]
        y_pred = self(x, xp)*100
        return y_pred
        
    def configure_optimizers(self): 
        optimizer = AdamW(self.model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)
        return optimizer

In [18]:
lit_model = LitMPNNSMILESClassification(
    batch_size=BATCH_SIZE
)

In [21]:
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='./checkpoint/',
    filename='MPNN-{epoch:02d}-{train_loss:.2f}-{val_loss:.2f}',
    save_top_k=1,
)

In [22]:
trainer = L.Trainer(
    accelerator='gpu',
#     precision='bf16-mixed',
    max_epochs=100,
    callbacks=[checkpoint_callback],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [23]:
trainer.fit(lit_model, train_dataloader, val_dataloader)

You are using a CUDA device ('NVIDIA GeForce RTX 3080') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type       | Params
------------------------------------------
0 | model      | MPNN       | 673 K 
1 | classifier | Sequential | 0     
------------------------------------------
673 K     Trainable params
0         Non-trainable params
673 K     Total params
2.693     Total estimated model params size (MB)


Epoch 0: 100%|██████████| 88/88 [00:03<00:00, 29.07it/s, v_num=40, train_loss=0.382]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 47.12it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 48.16it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 48.29it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 43.87it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 44.79it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 45.05it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 45.42it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 42.82it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 42.06it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 42.31it/s][A
Epoch 0: 

Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 46.80it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 46.34it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 46.25it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 45.76it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 44.54it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 45.36it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 46.16it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 46.47it/s][A
Epoch 7: 100%|██████████| 88/88 [00:02<00:00, 33.19it/s, v_num=40, train_loss=0.352, val_loss=34.10]
Epoch 8: 100%|██████████| 88/88 [00:02<00:00, 35.72it/s, v_num=40, train_loss=0.340, val_loss=34.10]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 

Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 49.49it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 49.92it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 49.99it/s][A
Epoch 14: 100%|██████████| 88/88 [00:02<00:00, 30.67it/s, v_num=40, train_loss=0.329, val_loss=33.90]
Epoch 15: 100%|██████████| 88/88 [00:02<00:00, 35.69it/s, v_num=40, train_loss=0.320, val_loss=33.90]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 49.77it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 44.32it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 45.35it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 45.51it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 45.92it/s][A
Validation DataLoader 0:  55%|█████▍    

Epoch 22: 100%|██████████| 88/88 [00:02<00:00, 34.11it/s, v_num=40, train_loss=0.291, val_loss=33.50]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 48.92it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 48.43it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 48.46it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 48.38it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 47.49it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 44.41it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 44.64it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 44.62it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 44.65it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 44.98i

Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 42.88it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 44.13it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 45.05it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 45.52it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 45.01it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 44.02it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 44.13it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 44.42it/s][A
Epoch 29: 100%|██████████| 88/88 [00:02<00:00, 31.06it/s, v_num=40, train_loss=0.282, val_loss=34.10]
Epoch 30: 100%|██████████| 88/88 [00:02<00:00, 34.31it/s, v_num=40, train_loss=0.284, val_loss=34.10]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         

Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 42.85it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 43.24it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 43.19it/s][A
Epoch 36: 100%|██████████| 88/88 [00:02<00:00, 31.02it/s, v_num=40, train_loss=0.280, val_loss=34.10]
Epoch 37: 100%|██████████| 88/88 [00:02<00:00, 33.44it/s, v_num=40, train_loss=0.290, val_loss=34.10]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 47.71it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 48.78it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 48.92it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 48.73it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 46.78it/s][A
Validation DataLoader 0:  55%|█████▍    

Epoch 44: 100%|██████████| 88/88 [00:02<00:00, 34.03it/s, v_num=40, train_loss=0.287, val_loss=35.00]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 47.29it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 47.61it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 43.73it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 42.54it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 42.87it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 43.25it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 43.93it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 44.28it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 43.99it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 42.61i

Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 46.25it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 47.11it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 47.39it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 46.91it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 46.63it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 46.16it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 30.10it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 30.62it/s][A
Epoch 51: 100%|██████████| 88/88 [00:02<00:00, 30.97it/s, v_num=40, train_loss=0.295, val_loss=35.50]
Epoch 52: 100%|██████████| 88/88 [00:02<00:00, 33.70it/s, v_num=40, train_loss=0.271, val_loss=35.50]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         

Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 41.39it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 40.71it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 40.00it/s][A
Epoch 58: 100%|██████████| 88/88 [00:02<00:00, 30.19it/s, v_num=40, train_loss=0.297, val_loss=35.60]
Epoch 59: 100%|██████████| 88/88 [00:02<00:00, 34.03it/s, v_num=40, train_loss=0.238, val_loss=35.60]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 47.95it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 44.70it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 44.55it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 44.64it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 45.24it/s][A
Validation DataLoader 0:  55%|█████▍    

Epoch 66: 100%|██████████| 88/88 [00:02<00:00, 34.46it/s, v_num=40, train_loss=0.241, val_loss=35.30]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 47.22it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 48.54it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 45.47it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 46.36it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 45.53it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 45.74it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 45.99it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 46.13it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 46.36it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 46.33i

Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 48.26it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 48.04it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 45.26it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 45.10it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 45.36it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 44.09it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 44.34it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 44.06it/s][A
Epoch 73: 100%|██████████| 88/88 [00:02<00:00, 30.40it/s, v_num=40, train_loss=0.227, val_loss=36.40]
Epoch 74: 100%|██████████| 88/88 [00:02<00:00, 34.21it/s, v_num=40, train_loss=0.253, val_loss=36.40]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         

Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 44.04it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 44.27it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 44.40it/s][A
Epoch 80: 100%|██████████| 88/88 [00:03<00:00, 29.05it/s, v_num=40, train_loss=0.220, val_loss=36.00]
Epoch 81: 100%|██████████| 88/88 [00:02<00:00, 33.84it/s, v_num=40, train_loss=0.249, val_loss=36.00]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 46.28it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 47.89it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 44.43it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 43.77it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 44.18it/s][A
Validation DataLoader 0:  55%|█████▍    

Epoch 88: 100%|██████████| 88/88 [00:02<00:00, 33.90it/s, v_num=40, train_loss=0.209, val_loss=36.40]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         | 1/11 [00:00<00:00, 46.47it/s][A
Validation DataLoader 0:  18%|█▊        | 2/11 [00:00<00:00, 44.72it/s][A
Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 43.55it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 41.28it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 42.34it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 43.08it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 43.37it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 42.99it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 43.10it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 43.52i

Validation DataLoader 0:  27%|██▋       | 3/11 [00:00<00:00, 49.29it/s][A
Validation DataLoader 0:  36%|███▋      | 4/11 [00:00<00:00, 48.98it/s][A
Validation DataLoader 0:  45%|████▌     | 5/11 [00:00<00:00, 47.77it/s][A
Validation DataLoader 0:  55%|█████▍    | 6/11 [00:00<00:00, 47.77it/s][A
Validation DataLoader 0:  64%|██████▎   | 7/11 [00:00<00:00, 46.91it/s][A
Validation DataLoader 0:  73%|███████▎  | 8/11 [00:00<00:00, 46.67it/s][A
Validation DataLoader 0:  82%|████████▏ | 9/11 [00:00<00:00, 46.96it/s][A
Validation DataLoader 0:  91%|█████████ | 10/11 [00:00<00:00, 46.54it/s][A
Epoch 95: 100%|██████████| 88/88 [00:02<00:00, 30.49it/s, v_num=40, train_loss=0.213, val_loss=36.10]
Epoch 96: 100%|██████████| 88/88 [00:02<00:00, 33.61it/s, v_num=40, train_loss=0.232, val_loss=36.10]
Validation: 0it [00:00, ?it/s][A
Validation:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   0%|          | 0/11 [00:00<?, ?it/s][A
Validation DataLoader 0:   9%|▉         

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 88/88 [00:02<00:00, 29.89it/s, v_num=40, train_loss=0.206, val_loss=37.90]


In [None]:
lit_model = LitMPNNSMILESClassification.load_from_checkpoint(
    'checkpoint/MPNN-epoch=19-train_loss=0.29-val_loss=32.97.ckpt',
    batch_size=BATCH_SIZE
)

In [None]:
trainer = L.Trainer(
    accelerator='gpu',
)

In [None]:
preds = trainer.predict(lit_model, test_dataloader)

In [None]:
submission = pd.read_csv('data/sample_submission.csv')

In [None]:
submission[['MLM','HLM']] = torch.cat(preds).numpy()

In [None]:
submission.to_csv('submission/mpnn.csv',index=False)