In [1]:
import argparse
import gzip
import os
import pickle as pkl
import random

import chemprop
import numpy as np
import pandas as pd
import rdkit
import torch as T
import torch.nn.functional as F
from torch import nn
from tqdm.notebook import tqdm

In [2]:
import covid
from covid.datasets import *
from covid.modules import *
from covid.data import *
from covid.model import *
from covid.schedulers import LinearWarmupScheduler
from covid.reporting import get_performance_plots

from covid.modules.chemistry import MPNEncoder

In [3]:
import matplotlib.pyplot as plt

from ipywidgets import widgets
from IPython.display import display

In [4]:
import covid.training

In [5]:
# Build the default training configuration from covid.training and ask its
# (private) helper for two loaders, unpacked as `dl` and `vdl`.
# NOTE(review): presumably `dl` is the training loader and `vdl` the validation
# loader -- confirm against covid.training._create_dataloaders.
config = covid.training.CovidTrainingConfiguration()
dl, vdl = covid.training._create_dataloaders(config)

In [6]:
# Length of `dl` (presumably the number of batches in one pass -- confirm the loader type).
len(dl)

8545

In [7]:
# Take a manual iterator over `dl` so items can be pulled one at a time with next().
it = iter(dl)

In [10]:
# Pull and discard 8500 items from `it`, leaving the iterator positioned close
# to the end of `dl` (its reported length is 8545).
# NOTE(review): looks like a debugging aid for inspecting the tail of an epoch --
# confirm it is still needed.
for _ in range (8500):
    next(it)

In [4]:
# Seed every RNG the notebook relies on (NumPy, Python's `random`, PyTorch)
# with the same value so that reruns are repeatable.
# 4 -- chosen by fair die roll.  Guaranteed to be random.  https://xkcd.com/221/
for seed_rng in (np.random.seed, random.seed, T.manual_seed):
    seed_rng(4)

In [5]:
# Dropout value passed to the chemistry encoder, the protein model and the combined CovidModel.
DROPOUT_RATE = 0.4
# Batch size handed to create_dataloader; also used to advance the per-epoch progress counter in the training loop.
BATCH_SIZE = 16

# Run validation whenever more than this fraction of an epoch has elapsed since the previous validation pass.
VALIDATION_FREQUENCY = 0.2

In [6]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(DEVICE)`.
DEVICE = 'cuda' if T.cuda.is_available() else 'cpu'

In [7]:
# Molecule branch: message-passing encoder (layers_per_message=2, hidden_size=300).
chem_model = MPNEncoder(layers_per_message=2, hidden_size=300, dropout=DROPOUT_RATE)

# Protein branch, built by the project's factory helper with the same dropout rate.
protein_model = create_protein_model(dropout=DROPOUT_RATE)

# Combined model wrapping both branches.
model = CovidModel(
    chem_model,
    protein_model,
    dropout=DROPOUT_RATE,
)

In [8]:
# Materialise the model's parameters into a list so they can be iterated over later.
params = list(model.parameters())

In [17]:
# Re-initialise every parameter in place: tensors with two or more dimensions
# get Kaiming-normal initialisation, all others a standard normal.
for tensor in params:
    initialise = T.nn.init.kaiming_normal_ if len(tensor.shape) >= 2 else T.nn.init.normal_
    initialise(tensor)

In [8]:
# Move the model to DEVICE; the trailing `;` suppresses the cell's echo of the returned module.
model.to(DEVICE);

In [9]:
# One-time dataset preparation, skipped entirely once ./data/training exists.
if not os.path.exists('./data/training'):
    # First carve the final hold-out set off the full data ...
    create_data_split('./data', './data/training', './data/final_holdout')

    # ... then derive ten train/validation splits from the remaining training data.
    for i in range(10):
        create_data_split(
            './data/training',
            f'./data/train_{i:02}',
            f'./data/valid_{i:02}',
        )

In [10]:
# Seed all three RNGs again: the dataset-creation step may or may not have run
# just now, so the RNG state at this point would otherwise differ between runs.
for seed_rng in (np.random.seed, random.seed, T.manual_seed):
    seed_rng(4)

In [11]:
# Training data for fold 0: dataset read from ./data/train_00, wrapped in a loader using BATCH_SIZE.
data = covid.datasets.StitchDataset('./data/train_00')
dataloader = create_dataloader(data, BATCH_SIZE)

In [12]:
# Validation data for fold 0, built the same way as the training loader but with neg_rate=0.0.
# NOTE(review): presumably neg_rate controls negative-example sampling and 0.0
# disables it -- confirm in StitchDataset.
validation_data = covid.datasets.StitchDataset('./data/valid_00', neg_rate=0.0)
validation_dataloader = create_dataloader(validation_data, BATCH_SIZE)

In [13]:
# Adam over all model parameters, with a non-default beta pair (PyTorch's defaults are (0.9, 0.999)).
# NOTE(review): the 2000 passed to LinearWarmupScheduler is presumably the
# number of warmup steps -- confirm in covid.schedulers.
optim = T.optim.Adam(model.parameters(), lr=1e-4, betas=(0.95, 0.99))
warmup = LinearWarmupScheduler(optim, 2000)

In [14]:
# Running history, appended to by the training loop and stored in the saved training state:
#   losses           -- (epoch position, training loss) tuples, one per batch
#   validation_stats -- [epoch position, vloss, vacc, v_conf] records, one per validation pass
losses = []
validation_stats = []

In [65]:
# Open a gzip-compressed checkpoint for inspection.
# map_location matches the resume cell's T.load call, so a checkpoint saved
# from a GPU also loads when DEVICE is the CPU.
# NOTE: T.load unpickles the file -- only open checkpoints from a trusted source.
with gzip.open('./checkpoints/model_train_fold00_000.pkl.gz', 'rb') as f:
    state = T.load(f, map_location=DEVICE)

In [69]:
# Transpose the checkpoint's validation_stats records into four parallel tuples, one per field.
# NOTE(review): if these records follow the [epoch position, vloss, vacc, v_conf]
# layout written by this notebook's training loop, a/b/c/d are those columns --
# but this checkpoint uses a different filename pattern, so confirm its schema.
a, b, c, d = zip(*state['validation_stats'])

In [72]:
# Display the third field of every validation record.
c

(array([0.5757798 , 0.24077486, 0.3670774 , 0.32965682, 0.28705316]),
 array([0.66705973, 0.7942977 , 0.87277217, 0.79145701, 0.82455177]),
 array([0.67934792, 0.80454201, 0.88669814, 0.79894614, 0.83416545]),
 array([0.80140947, 0.77208418, 0.89145058, 0.82033474, 0.88650104]),
 array([0.84964117, 0.78305617, 0.90101303, 0.82478764, 0.91807779]),
 array([0.84856255, 0.819456  , 0.87960643, 0.83847463, 0.91785839]))

In [74]:
from collections import namedtuple

In [75]:
# Echo the factory itself to inspect its signature. Calling it with only a
# type name raises TypeError because `field_names` is a required argument.
namedtuple

<function collections.namedtuple(typename, field_names, *, rename=False, defaults=None, module=None)>

In [78]:
from collections import namedtuple

# Lightweight immutable record with four positional fields, in the order
# tp, fp, fn, tn.
ConfusionMatrix = namedtuple(
    typename='ConfusionMatrix',
    field_names=('tp', 'fp', 'fn', 'tn'),
)

In [17]:
from functools import partial

# Zero-argument validation helper: pre-binds the model, the validation loader
# and the device, so callers can simply invoke get_validation_loss().
get_validation_loss = partial(calculate_average_loss_and_accuracy, model, validation_dataloader, DEVICE)

In [18]:
# Start from scratch unless a saved training state is found on disk.
epoch = 0
last_validation = epoch

if os.path.exists("training_state_00.pkl"):
    # NOTE: T.load unpickles the file -- only resume from a trusted state file.
    state = T.load("./training_state_00.pkl", map_location=DEVICE)

    # The training loop saves the index of the epoch it has just *finished*, so
    # resume at the following one. Restarting at the stored index would train
    # that epoch a second time, duplicate its entries in `losses` and overwrite
    # its checkpoint file.
    if 'epoch' in state:
        epoch = state['epoch'] + 1
    losses = state.get('losses', losses)
    validation_stats = state.get('validation_stats', validation_stats)
    last_validation = state.get('last_validation', last_validation)
    model.load_state_dict(state['model'])
    optim.load_state_dict(state['optim'])
    warmup.load_state_dict(state['warmup'])


# Record a baseline validation measurement before any training has happened.
# Only done for a fresh run: a resumed history already contains this entry.
if epoch == 0:
    vloss, vacc, v_conf = get_validation_loss()
    validation_stats.append([0, vloss, vacc, v_conf])

In [19]:
# Make sure the checkpoint directory exists. makedirs(..., exist_ok=True) is a
# no-op when the directory is already there and avoids the check-then-create race.
os.makedirs("./checkpoints", exist_ok=True)

In [20]:
# Live chart area: the performance figure is re-rendered here after every
# validation pass.
chart_area = widgets.Output()
display(chart_area)

# Train from the current value of `epoch` up to (but not including) epoch 100.
# At the end of every epoch the full training state is written to disk.
for epoch in tqdm(range(epoch, 100)):
    idx = 0  # advanced by BATCH_SIZE per batch; approximates examples seen this epoch

    model.train()
    pct_epoch = 0

    for batch in tqdm(dataloader, leave=False):

        # One optimisation step: clear gradients, forward, backward, update,
        # then advance the learning-rate warmup schedule.
        model.zero_grad()
        _, _, loss, _ = run_model(model, batch, DEVICE)

        loss.backward()

        optim.step()
        warmup.step()

        # Fraction of the epoch completed so far, clamped to 1.0 because idx
        # can overshoot len(data) on the final batch.
        idx += BATCH_SIZE
        pct_epoch = min(1.0, idx/len(data))

        losses.append((epoch + pct_epoch, loss.item()))

        # Validate at the end of the epoch, and also whenever more than
        # VALIDATION_FREQUENCY of an epoch has passed since the last validation.
        if pct_epoch == 1.0 or epoch + pct_epoch - last_validation > VALIDATION_FREQUENCY:
            vloss, vacc, v_conf = get_validation_loss()
            validation_stats.append([epoch+pct_epoch, vloss, vacc, v_conf])

            # NOTE(review): the validation helper is not visible here. If it
            # leaves the model in eval mode, the rest of this epoch would train
            # with dropout disabled. Re-entering train mode is a no-op when the
            # helper already restores it.
            model.train()

            # Redraw the performance chart in place, then close the figure so
            # matplotlib does not accumulate open figures.
            chart_area.clear_output()
            fig = get_performance_plots(losses, validation_stats)
            with chart_area:
                display(fig)
            plt.close(fig)

            last_validation = epoch + pct_epoch

    # Persist everything needed to resume. `epoch` is the index of the epoch
    # that has just finished. The state is written twice: once as a per-epoch
    # checkpoint, once as the rolling resume file.
    state = {
        'epoch': epoch,
        'losses': losses,
        'validation_stats': validation_stats,
        'last_validation': last_validation,
        'model': model.state_dict(),
        'optim': optim.state_dict(),
        'warmup': warmup.state_dict(),
    }
    T.save(state, f'./checkpoints/model_00_{epoch:03}.pkl')
    T.save(state, "./training_state_00.pkl")

Output()

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

KeyboardInterrupt: 

In [None]:
# Plot the recorded training losses and validation statistics. The third
# argument grows with the loss history (len(losses)//2000) but is never below 10.
# NOTE(review): presumably a smoothing window. `plot_losses` is not defined in
# this notebook, so it must come from one of the covid star-imports -- confirm.
plot_losses(losses, validation_stats, max(10, len(losses)//2000))