In [1]:
# Standard library
import argparse
import gzip
import os  # used by the data-split, resume and checkpoint cells; not imported explicitly anywhere else in this notebook
import pickle as pkl
import random

# Third-party
import pandas as pd
import numpy as np
import torch as T
from torch import nn
import torch.nn.functional as F

import chemprop
import rdkit

from tqdm.notebook import tqdm

In [2]:
import covid
from covid.datasets import *
from covid.modules import *
from covid.data import *
from covid.model import *
from covid.schedulers import LinearWarmupScheduler

from covid.modules.chemistry import MPNEncoder

In [3]:
import matplotlib.pyplot as plt

from ipywidgets import widgets
from IPython.display import display

In [4]:
# Seed NumPy, Python's `random` module and PyTorch with the same value so the
# stochastic steps below start from a known state.
# 4 -- chosen by fair die roll.  Guaranteed to be random.  https://xkcd.com/221/
for _seed_rng in (np.random.seed, random.seed, T.manual_seed):
    _seed_rng(4)

In [5]:
DROPOUT_RATE = 0.4  # passed as `dropout=` to every model component built below
BATCH_SIZE = 16  # batch size handed to both the training and validation dataloaders

# Despite the name this is an interval, measured in (fractional) epochs: the
# training loop runs a validation pass once more than this much of an epoch
# has elapsed since the previous one.
VALIDATION_FREQUENCY = 0.2

In [6]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs (slowly)
# on a machine without CUDA instead of failing at `model.to(DEVICE)`.
DEVICE = 'cuda' if T.cuda.is_available() else 'cpu'

In [7]:
# All three components are constructed with the same dropout rate.
chem_model = MPNEncoder(layers_per_message=2, hidden_size=300, dropout=DROPOUT_RATE)
protein_model = create_protein_model(dropout=DROPOUT_RATE)

# CovidModel receives the two sub-models as its first two positional arguments.
model = CovidModel(chem_model, protein_model, dropout=DROPOUT_RATE)

In [8]:
# Move the model to DEVICE. The trailing ';' keeps the call's return value
# out of the cell output.
model.to(DEVICE);

In [9]:
# Build the splits only once: the presence of ./data/training is treated as
# the marker that an earlier run already created them.
splits_already_built = os.path.exists('./data/training')

if not splits_already_built:
    # First split ./data into ./data/training and ./data/final_holdout ...
    # (argument roles are inferred from the paths -- confirm against
    # create_data_split)
    create_data_split('./data', './data/training', './data/final_holdout')

    # ... then split ./data/training ten times into train_XX / valid_XX pairs.
    for i in range(10):
        train_dir = f'./data/train_{i:02}'
        valid_dir = f'./data/valid_{i:02}'
        create_data_split('./data/training', train_dir, valid_dir)

In [10]:
# Re-seed all three RNGs. The split-creation cell above only does work when
# ./data/training is missing, so whether it ran (and possibly consumed random
# numbers) differs between sessions; resetting here gives the rest of the
# notebook the same starting state either way.
for _seed_rng in (np.random.seed, random.seed, T.manual_seed):
    _seed_rng(4)

In [11]:
# Training data: the train_00 split, batched with BATCH_SIZE.
# `len(data)` is also used by the training loop to compute epoch progress.
data = covid.datasets.StitchDataset('./data/train_00')
dataloader = create_dataloader(data, BATCH_SIZE)

In [12]:
# Validation data: the valid_00 split that pairs with train_00.
# NOTE(review): the third argument (BATCH_SIZE * 500 = 8000) is not passed for
# the training loader; presumably it limits how many samples one validation
# pass draws -- confirm against create_dataloader.
validation_data = covid.datasets.StitchDataset('./data/valid_00')
validation_dataloader = create_dataloader(validation_data, BATCH_SIZE, BATCH_SIZE * 500)

In [13]:
# Adam over every model parameter, with explicitly chosen betas (0.95, 0.99).
optim = T.optim.Adam(model.parameters(), lr=1e-4, betas=(0.95, 0.99))
# NOTE(review): 2000 is presumably the warm-up length in optimizer steps (the
# training loop calls warmup.step() once per batch) -- confirm against
# LinearWarmupScheduler.
warmup = LinearWarmupScheduler(optim, 2000)

In [14]:
# Training history, filled in by the training loop (or replaced wholesale by
# the resume cell when a saved state exists):
#   losses           -> one (epoch position, training loss) tuple per batch
#   validation_stats -> one [epoch position, vloss, vacc, v_conf] list per validation pass
# Note: re-running this cell discards any history accumulated so far.
losses, validation_stats = [], []

In [15]:
def plot_losses(losses, validation_stats, period=10, display_to=None, baseline=0.358):
    """Plot the smoothed training loss together with the validation loss.

    Args:
        losses: sequence of ``(epoch_position, loss)`` pairs, one per training
            batch.
        validation_stats: sequence of 4-item records whose first two items are
            ``(epoch_position, validation_loss)``; the other two items are
            ignored here.
        period: number of consecutive training points averaged into a single
            plotted point. A trailing partial window is dropped, so fewer than
            ``period`` points produce an empty training curve. Values below 1
            are treated as 1.
        display_to: optional context manager (the training loop passes an
            ``ipywidgets.Output``). When given, the figure is displayed inside
            it and then closed; when ``None`` the figure is left open.
        baseline: y-value of the dotted "random baseline" reference line.

    Returns:
        None.

    Either history may be empty: the corresponding curve is skipped instead of
    failing on the ``zip(*...)`` unpack, and the baseline is omitted when there
    is no x-range to draw it over.
    """
    period = max(1, int(period))  # guard the integer division / reshape below

    with plt.style.context('bmh'):
        fig, ax = plt.subplots()
        baseline_span = None  # (x_start, x_end) the reference line is drawn over

        if losses:
            x, y = zip(*losses)

            # Keep only whole windows of `period` points, then average each window.
            usable = (len(x) // period) * period
            x = np.reshape(x[:usable], (-1, period)).mean(1)
            y = np.reshape(y[:usable], (-1, period)).mean(1)

            ax.plot(x, y, lw=1, alpha=0.75, label='model (training)', c='C0')
            if len(x):
                baseline_span = (x[0], x[-1])

        if validation_stats:
            x, y, _, _ = zip(*validation_stats)

            ax.plot(x, y, label='model (validation)', c='C1')
            # When validation points exist the baseline spans their range.
            baseline_span = (x[0], x[-1])

        if baseline_span is not None:
            ax.plot(baseline_span, (baseline, baseline), c='g', ls=':', lw=1.5, label='random baseline')

        ax.set_xlabel('epoch')
        ax.set_ylabel('loss')
        # Skip the legend on a completely empty plot to avoid matplotlib's
        # "no artists with labels" warning.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()

        if display_to is not None:
            with display_to:
                display(fig)
            # Closing stops the inline backend from rendering the figure a
            # second time and frees it when this is called repeatedly.
            plt.close(fig)

In [16]:
# Show the current validation history (still empty at this point, before the
# resume and training cells have run).
validation_stats

[]

In [17]:
from functools import partial

# Pre-bind the model, validation loader and device so that a validation pass
# is a zero-argument call: get_validation_loss().
get_validation_loss = partial(
    calculate_average_loss_and_accuracy, model, validation_dataloader, DEVICE
)

In [18]:
# Resume from the rolling training-state file if one exists; otherwise start
# from scratch at epoch 0.
epoch = 0
last_validation = epoch

STATE_PATH = "./training_state_00.pkl"

if os.path.exists(STATE_PATH):
    # SECURITY: T.load unpickles the file, and unpickling can execute
    # arbitrary code. Only resume from state files written by this notebook.
    state = T.load(STATE_PATH, map_location=DEVICE)

    # Bookkeeping entries fall back to the current values when a key is
    # missing ...
    epoch = state.get('epoch', epoch)
    losses = state.get('losses', losses)
    validation_stats = state.get('validation_stats', validation_stats)
    last_validation = state.get('last_validation', last_validation)
    # ... whereas the model / optimizer / scheduler states are required
    # (a missing key raises KeyError).
    model.load_state_dict(state['model'])
    optim.load_state_dict(state['optim'])
    warmup.load_state_dict(state['warmup'])

# Record the pre-training validation baseline exactly once. Checking that the
# history is empty (in addition to `epoch == 0`) prevents a second [0, ...]
# entry from being appended after the real history when a loaded state still
# reports epoch 0, or when this cell is re-run in the same kernel.
if epoch == 0 and not validation_stats:
    vloss, vacc, v_conf = get_validation_loss()
    validation_stats.append([0, vloss, vacc, v_conf])

In [19]:
# Directory the training loop writes its per-epoch checkpoints into.
# makedirs(exist_ok=True) does nothing when the directory already exists and
# replaces the separate exists() check + mkdir() with a single call.
os.makedirs("./checkpoints", exist_ok=True)

In [None]:
# Output widget the loss chart is (re)drawn into while training runs.
chart_area = widgets.Output()
display(chart_area)

NUM_EPOCHS = 100  # total number of epochs to train for

# `epoch` on the right-hand side is the starting epoch set by the resume cell;
# range() captures it before the loop rebinds the name.
for epoch in tqdm(range(epoch, NUM_EPOCHS)):
    idx = 0  # number of samples consumed so far in this epoch

    model.train()
    pct_epoch = 0

    for batch in tqdm(dataloader, leave=False):

        model.zero_grad()
        _, _, loss, _ = run_model(model, batch, DEVICE)

        loss.backward()

        optim.step()
        warmup.step()  # the warm-up scheduler advances once per batch

        idx += BATCH_SIZE
        # Clamp to 1.0: idx grows by a full BATCH_SIZE even if the final batch
        # is shorter, so it can overshoot len(data).
        pct_epoch = min(1.0, idx/len(data))

        losses.append((epoch + pct_epoch, loss.item()))

        # Validate once more than VALIDATION_FREQUENCY epochs have elapsed
        # since the previous validation pass.
        if epoch + pct_epoch - last_validation > VALIDATION_FREQUENCY:
            vloss, vacc, v_conf = get_validation_loss()
            validation_stats.append([epoch+pct_epoch, vloss, vacc, v_conf])

            # NOTE(review): the validation helper is not visible here. If it
            # switches the model to eval mode and does not switch back, the
            # rest of the epoch would train with dropout disabled. Re-asserting
            # train mode is a no-op if the helper already restores it.
            model.train()

            chart_area.clear_output()
            plot_losses(losses, validation_stats, max(10, len(losses)//2000), chart_area)
            last_validation = epoch + pct_epoch

    state = {
        # Epoch to resume FROM. This epoch is complete, so store the next
        # index; storing `epoch` itself made a resumed run repeat the epoch it
        # had just finished and append duplicate x-positions to `losses`.
        'epoch': epoch + 1,
        'losses': losses,
        'validation_stats': validation_stats,
        'last_validation': last_validation,
        'model': model.state_dict(),
        'optim': optim.state_dict(),
        'warmup': warmup.state_dict(),
    }
    # Keep one checkpoint per finished epoch (named after that epoch) and
    # overwrite the rolling state file that the resume cell looks for.
    T.save(state, f'./checkpoints/model_00_{epoch:03}.pkl')
    T.save(state, "./training_state_00.pkl")

Output()

HBox(children=(FloatProgress(value=0.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

HBox(children=(FloatProgress(value=0.0, max=21362.0), HTML(value='')))

In [None]:
# Plot the full history inline. The smoothing window grows with the number of
# recorded batches (len(losses)//2000) but never drops below 10.
plot_losses(losses, validation_stats, max(10, len(losses)//2000))