## Setup Environment

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import pathlib
import yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim

from sklearn.model_selection import train_test_split
from time import time
from tqdm import tqdm

## Declare Constants

In [3]:
MAX_LEN = 120
NLATENT = 196
DECODER_HIDDEN_SIZE = 488
DECODER_NUM_LAYERS = 3

DATA_DIR = './data'
DATA_FILE_NAME = '250k_rndm_zinc_drugs_clean_3.csv'
CHAR_FILE_NAME = 'zinc.json'
TEST_IDX_FILE_NAME = 'test_idx.npy'

ALL_LETTERS = yaml.safe_load(open(pathlib.Path(DATA_DIR, CHAR_FILE_NAME))) + ['SOS']
N_LETTERS = len(ALL_LETTERS) 

MODELS_DIR = './models'

## Create a Utility Class

In [4]:
class VAEUtils:
    '''
    This purpose of this class is to help with various aspects
    of data processing
    '''
    
    def __init__(self, data_dir=DATA_DIR, 
                 data_file_name=DATA_FILE_NAME, 
                 test_idx_file_name=TEST_IDX_FILE_NAME,
                 max_len=MAX_LEN,
                 all_letters=ALL_LETTERS):
        
        self.data_dir = pathlib.Path(data_dir)
        self.data_file = self.data_dir / pathlib.Path(data_file_name)
        self.test_file = self.data_dir / pathlib.Path(test_idx_file_name)
        
        self.max_len = max_len
        
        self.all_letters = all_letters
        self.n_letters = len(all_letters)
        self.letters_to_indices_dict = dict((l, i) for i, l in enumerate(all_letters))
        self.indices_to_letters_dict = dict((i, l) for i, l in enumerate(all_letters))
               
    def get_data_df(self):
        df = pd.read_csv(self.data_file)
        df = df[df.smiles.str.len() <= self.max_len].reset_index(drop=True)
        
        # preprocess input smile to remove the newline character and add padding
        df.loc[:, 'smiles'] = df.loc[:, 'smiles'].str.strip()\
                    .str.pad(width=self.max_len, side='right', fillchar=" ")
        
        return df
        
    # One-hot matrix of first to last letters (not including EOS) for input
    def get_input_tensor(self, smile):
        tensor = torch.zeros(1, len(smile), self.n_letters) # batch_size * seq_length * num_features
        for i, letter in enumerate(smile):
            tensor[0][i][self.letters_to_indices_dict[letter]] = 1
        return tensor

    # LongTensor of first letter to end (EOS) for target
    def get_target_tensor(self, smile):
        letter_indexes = [self.letters_to_indices_dict[l] for l in smile]
        # letter_indexes.append(self.n_letters - 1) # EOS
        return torch.LongTensor(letter_indexes)
    
    def get_train_valid_test_splits(self, reg_col, valid_pct=.1):
        df = self.get_data_df()[['smiles', reg_col]]
        df = df.rename(columns={reg_col: 'reg_col'})
        
        test_idx = np.load(self.test_file)
        non_test_idx = np.array(df[~df.index.isin(test_idx)].index)
        train_idx, valid_idx = train_test_split(non_test_idx, test_size=valid_pct, 
                                                random_state=42, shuffle=True)
        
        assert len(df) == len(test_idx) + len(train_idx) + len(valid_idx)
        
        return df, train_idx, valid_idx, test_idx
         
    def get_dl(self, df, idx, bs, shuffle=False):
        
        df = df.iloc[idx]
        
        input_tensors = torch.zeros(len(df), self.max_len, self.n_letters)
        target_tensors = torch.zeros(len(df), self.max_len, dtype=torch.long)
        for i, smile in enumerate(tqdm(df.smiles)):
            input_tensors[i] = self.get_input_tensor(smile)
            target_tensors[i] = self.get_target_tensor(smile)
        
        input_tensors = input_tensors
        target_tensors = target_tensors
        
        # original_lengths = torch.tensor(df.smiles.str.strip().str.len().to_numpy())

        property_values = torch.tensor(df.reg_col.to_numpy()).type(torch.float32)
         
        ds = TensorDataset(input_tensors, target_tensors, property_values)
        dl = DataLoader(ds, shuffle=shuffle, batch_size=bs)
        
        return dl
    
vae_utils = VAEUtils()

## Create the Model Networks

In [5]:
class Lambda(nn.Module):
    '''
    This class simplifies layers from 
    custom functions
    '''
    def __init__(self, func):
        super().__init__()
        self.func = func
    
    def forward(self, x):
        return self.func(x)

In [6]:
class Encoder(nn.Module):
    def __init__(self, n_letters, nlatent, decoder_hidden_size):
        super().__init__()
        self.n_letters = n_letters
        self.nlatent = nlatent
        self.decoder_hidden_size = decoder_hidden_size
        self.encoder = nn.Sequential(
            Lambda(lambda x: x.permute(0, 2, 1)), # the features are in the channels dimension
            nn.Conv1d(in_channels=n_letters, out_channels=9, kernel_size=9),
            nn.Tanh(), # the authors of the paper used tanh
            # nn.ReLU(),
            nn.BatchNorm1d(9), # the authors of the paper did batch normalization
            nn.Conv1d(in_channels=9, out_channels=9, kernel_size=9),
            nn.Tanh(), # the authors of the paper used tanh
            # nn.ReLU(),
            nn.BatchNorm1d(9), # the authors of the paper did batch normalization
            nn.Conv1d(in_channels=9, out_channels=11, kernel_size=10),
            nn.Tanh(), # the authors of the paper used tanh
            #nn.ReLU(),
            nn.BatchNorm1d(11), # the authors of the paper did batch normalization
            nn.Flatten()
        )
        
        self.mean = nn.Linear(1045, nlatent)
        self.log_var = nn.Linear(1045, nlatent)
        self.dec_init_hidden = nn.Linear(nlatent, decoder_hidden_size)
        
    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var) 
        eps = torch.randn_like(std)
        sample = mean + (eps * std) 
        return sample
        
    def forward(self, x):
        x = self.encoder(x)
        mean = self.mean(x)
        log_var = self.log_var(x)
        z = self.reparameterize(mean, log_var)
        dec_init_hidden = self.dec_init_hidden(z)
        
        return z, mean, log_var, dec_init_hidden

In [7]:
def preprocess_decoder_input(input_tensors):
    '''
    Adjust for SOS and make batch the second demension
    '''
    sos_tensor = torch.zeros(1, N_LETTERS)
    sos_tensor[0][vae_utils.letters_to_indices_dict['SOS']] = 1
    new_tensor = torch.zeros(input_tensors.shape[0], MAX_LEN, N_LETTERS)
    new_tensor[:][0] = sos_tensor
    new_tensor[:, 1:MAX_LEN, :] = input_tensors[:, 0:MAX_LEN-1, :]
    new_tensor = new_tensor.permute(1, 0, 2).to(input_tensors.device)
    
    return new_tensor

In [8]:
class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.preprocess = Lambda(preprocess_decoder_input)
        self.rnn = nn.GRU(input_size, hidden_size, num_layers)
        self.out = nn.Linear(hidden_size, input_size)
        self.softmax = nn.Softmax(dim=2)
        
    def forward(self, input, hidden):
        output, hidden = self.rnn(self.preprocess(input), )
        output = output.permute(1, 0, 2)
        output = self.softmax(self.out(output))
        return output, hidden

In [9]:
class PropertyPredictor(nn.Module):
    def __init__(self, nlatent):
        super().__init__()
        self.predictor = nn.Sequential(
            nn.Linear(nlatent, 1000),
            nn.Tanh(), # the authors of the paper used tanh
            # nn.ReLU(),
            nn.Dropout(.2),
            nn.Linear(1000, 1000),
            nn.Tanh(), # the authors of the paper used tanh
            # nn.ReLU(),
            nn.Dropout(.2),
            nn.Linear(1000, 1)
        )
        
    def forward(self, x):
        return self.predictor(x) 

## Write Helper Classes & Functions for Training and Testing

In [10]:
class Trainer:
    def __init__(self, reg_col):
        
        models_dir = pathlib.Path(MODELS_DIR)
        models_dir.mkdir(exist_ok=True, parents=True)
        self.encoder_file = models_dir / pathlib.Path(f'{reg_col}_encoder.pth')
        self.decoder_file = models_dir / pathlib.Path(f'{reg_col}_decoder.pth')
        self.property_predictor_file = models_dir / pathlib.Path(f'{reg_col}_property_predictor.pth')
        self.train_losses_file = models_dir / pathlib.Path(f'{reg_col}_train_losses.csv')
        self.val_losses_file = models_dir / pathlib.Path(f'{reg_col}_valid_losses.csv')

        self.loss_columns = ['total_loss', 'reconstruction_loss', 'kl_divergence', 'regression_loss']
        self.train_losses_df = pd.DataFrame(columns=self.loss_columns)
        self.val_losses_df = pd.DataFrame(columns=self.loss_columns)

    def intitialize_networks(self):
        self.encoder = Encoder(vae_utils.n_letters, NLATENT, DECODER_HIDDEN_SIZE)
        self.decoder = Decoder(vae_utils.n_letters, DECODER_HIDDEN_SIZE, DECODER_NUM_LAYERS)
        self.property_predictor = PropertyPredictor(NLATENT)
        
    def __time_since(self, since):
        now = time()
        s = now - since
        m = math.floor(s / 60)
        s -= m * 60
        return f'{m}m {s:.0f}s'     
        
    def kl_anneal_function(self, epoch, anneal_start, k=1):
        return 1 / (1 + np.exp(- k * (epoch - anneal_start)))
    
    def get_losses(self, epoch, input_tensors, target_tensors, property_values):
                
        recontruction_loss_func = nn.NLLLoss()
        reg_loss_func = nn.MSELoss()
        
        input_tensors = input_tensors.to(device)
        target_tensors = target_tensors.to(device)
        property_values = property_values.to(device)

        z, mean, log_var, dec_init_hidden = self.encoder(input_tensors)
        output, hidden = self.decoder(input_tensors, 
                                 dec_init_hidden.unsqueeze(0).repeat(DECODER_NUM_LAYERS, 1, 1))
        reg_pred = self.property_predictor(z)

        kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        kl_weight = self.kl_anneal_function(epoch, anneal_start)

        reg_loss = reg_loss_func(reg_pred.flatten(), property_values)
        reg_loss = reg_loss.type(torch.float32) 

        reconstruction_loss = 0
        for i in range(input_tensors.shape[0]):
            reconstruction_loss += recontruction_loss_func(output[i], target_tensors[i])
        reconstruction_loss /= input_tensors.shape[0]

        loss = reconstruction_loss + kl_divergence * kl_weight + reg_loss
        
        return loss, reconstruction_loss, kl_divergence, reg_loss
    
    def process_dl(self, epoch, anneal_start, device, dl, train, 
                   enc_opt=None, dec_opt=None, pp_opt=None):
        
        num_loaders = len(dl)
        loader_loss = 0
        recon_loss_epoch = 0
        kld_epoch = 0
        reg_loss_epoch = 0
        
        for input_tensors, target_tensors, property_values in dl:
            loss, reconstruction_loss, kl_divergence, reg_loss = \
                self.get_losses(epoch, input_tensors, target_tensors, property_values)
            
            if train:
                enc_opt.zero_grad()
                dec_opt.zero_grad()
                pp_opt.zero_grad()
                
                loss.backward()
                
                enc_opt.step()
                dec_opt.step()
                pp_opt.step()
                
            loader_loss += loss.item()
            recon_loss_epoch += reconstruction_loss.item()
            kld_epoch += kl_divergence.item()
            reg_loss_epoch += reg_loss.item()
        
        loader_loss /= num_loaders
        recon_loss_epoch /= num_loaders
        kld_epoch /= num_loaders
        reg_loss_epoch /= num_loaders
        
        return (loader_loss, recon_loss_epoch, kld_epoch, reg_loss_epoch)
    
    def save_parameters_and_losses(self, train_losses:list, val_losses:list):
        
        torch.save(self.encoder.state_dict(), self.encoder_file)
        torch.save(self.decoder.state_dict(), self.decoder_file)
        torch.save(self.property_predictor.state_dict(), self.property_predictor_file)
        
        temp_df = pd.DataFrame(data=train_losses, columns=self.loss_columns)
        self.train_losses_df = self.train_losses_df.append(temp_df, ignore_index=True)
        self.train_losses_df.to_csv(self.train_losses_file, index=False)
        
        temp_df = pd.DataFrame(data=val_losses, columns=self.loss_columns)
        self.val_losses_df = self.val_losses_df.append(temp_df, ignore_index=True)
        self.val_losses_df.to_csv(self.val_losses_file, index=False)
        
    def load_parameters_and_losses(self):
        self.encoder.load_state_dict(torch.load(self.encoder_file))
        self.decoder.load_state_dict(torch.load(self.decoder_file))
        self.property_predictor.load_state_dict(torch.load(self.property_predictor_file))
        self.train_losses_df = pd.read_csv(self.train_losses_file)
        self.val_losses_df = pd.read_csv(self.val_losses_file)
        
    
    def fit(self, epochs, save_every, anneal_start, lr, train_dl, valid_dl, device, load_previous=False):
        
        self.intitialize_networks()
    
        if load_previous:
            self.load_parameters_and_losses()
            
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)
        self.property_predictor = self.property_predictor.to(device)

        enc_opt = optim.Adam(self.encoder.parameters(), lr=lr)
        dec_opt = optim.Adam(self.decoder.parameters(), lr=lr)
        pp_opt = optim.Adam(self.property_predictor.parameters(), lr=lr)

        train_losses = []
        val_losses = []
        prev_n_epochs = len(self.train_losses_df) 

        start_time = time()
        for epoch in tqdm(range(epochs)):
            self.encoder.train()
            self.decoder.train()
            self.property_predictor.train()
            train_loss = self.process_dl(prev_n_epochs + epoch, anneal_start, device, 
                                           train_dl, True, enc_opt, dec_opt, pp_opt)
            train_losses.append(train_loss)

            self.encoder.eval()
            self.decoder.eval()
            self.property_predictor.eval()
            with torch.no_grad():
                val_loss = self.process_dl(prev_n_epochs + epoch, anneal_start, device,
                                             valid_dl, False)
                val_losses.append(val_loss)
                
            # if epoch % 5 == 4:
            print(f"Epoch: {prev_n_epochs + epoch + 1:3d} | " + \
                  f"Train Loss: {train_loss[0]:10.5f} | Val Loss: {val_loss[0]:10.5f} | " + \
                  f"Time Taken: {self.__time_since(start_time)}")
            print(f"Train Recon Loss: {train_loss[1]:10.5f} | Val Recon Loss: {val_loss[1]:10.5f}")
            print(f"Train KLD: {train_loss[2]:10.5f} | Val KLD: {val_loss[2]:10.5f}")
            print(f"Train Reg Loss: {train_loss[3]:10.5f} | Val Reg Loss: {val_loss[3]:10.5f}\n")
            
            if epoch % save_every == save_every - 1:
                self.save_parameters_and_losses(train_losses, val_losses)
                train_losses = []
                val_losses = []
            
        self.save_parameters_and_losses(train_losses, val_losses)
        
        del self.encoder
        del self.decoder
        del self.property_predictor
        
    def plot_losses(self):
        
        train_losses_df = pd.read_csv(self.train_losses_file)
        val_losses_df = pd.read_csv(self.val_losses_file)
        
        plt.rcParams.update({'font.size': 15})
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20,10))
        
                
        ax1.plot(train_losses_df['total_loss'], label='Train')
        ax1.plot(val_losses_df['total_loss'], label='Validation')
        ax1.set_ylabel('Total Loss')
        ax1.legend(loc='upper right')
        
        ax2.plot(train_losses_df['reconstruction_loss'], label='Train')
        ax2.plot(val_losses_df['reconstruction_loss'], label='Validation')
        ax2.set_ylabel('Reconstruction Loss')
        ax2.legend(loc='upper right')
        
        ax3.plot(train_losses_df['kl_divergence'], label='Train')
        ax3.plot(val_losses_df['kl_divergence'], label='Validation')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('KL Divergence')
        ax3.legend(loc='upper right')
        
        
        ax4.plot(train_losses_df['regression_loss'], label='Train')
        ax4.plot(val_losses_df['regression_loss'], label='Validation')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Regression Loss')
        ax4.legend(loc='upper right')
        
        plt.show()

In [11]:
def get_dataloaders(df, train_idx, valid_idx, test_idx, max_samples):
    start_time = time()
    train_dl = vae_utils.get_dl(df, train_idx[:max_samples], bs, shuffle=True)
    valid_dl = vae_utils.get_dl(df, valid_idx[:max_samples], bs)
    test_dl = vae_utils.get_dl(df, test_idx[:max_samples], bs)
    print(f'Time taken to get dataloaders: {time() - start_time:.2f}s')
    
    return train_dl, valid_dl, test_dl

In [12]:
def get_pred_mae(reg_col, dl):
    encoder = Encoder(vae_utils.n_letters, NLATENT, DECODER_HIDDEN_SIZE)
    property_predictor = PropertyPredictor(NLATENT)

    encoder.load_state_dict(torch.load(pathlib.Path(MODELS_DIR, f'{reg_col}_encoder.pth')))
    property_predictor.load_state_dict(torch.load(pathlib.Path(MODELS_DIR, f'{reg_col}_property_predictor.pth')))
    
    encoder.eval()
    property_predictor.eval()
    
    with torch.no_grad():
        abs_errors = []
        for input_tensors, target_tensors, property_values in dl:
            z, mean, log_var, dec_init_hidden = encoder(input_tensors)
            reg_pred = property_predictor(z)
            abs_error = torch.abs(property_values - reg_pred.flatten())
            abs_errors.append(abs_error)
            
        abs_errors = torch.cat(abs_errors)
        mae_error = abs_errors.mean()
        
    return mae_error

## Define Parameters for Training

In [13]:
valid_pct = .1
bs = 2048
max_samples = 250000
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")
epochs = 120
save_every = 5
anneal_start = 29
lr = 0.0005
load_previous = False

## Train and Test the Network for _logP_

In [14]:
errors = {}

In [None]:
reg_col = 'logP'
df, train_idx, valid_idx, test_idx = vae_utils.get_train_valid_test_splits(reg_col, valid_pct)
train_dl, valid_dl, test_dl = get_dataloaders(df, train_idx, valid_idx, test_idx, max_samples)
trainer = Trainer(reg_col)
trainer.fit(epochs, save_every, anneal_start, lr, train_dl, valid_dl, device, load_previous)

100%|██████████| 224018/224018 [04:14<00:00, 881.62it/s]
100%|██████████| 24891/24891 [00:28<00:00, 883.56it/s]
100%|██████████| 546/546 [00:00<00:00, 805.07it/s]


Time taken to get dataloaders: 283.51s


  1%|          | 1/120 [04:30<8:56:05, 270.30s/it]

Epoch:   1 | Train Loss:    0.55326 | Val Loss:   -0.24018 | Time Taken: 4m 30s
Train Recon Loss:   -0.60448 | Val Recon Loss:   -0.65918
Train KLD: 243816.15256 | Val KLD: 409009.96755
Train Reg Loss:    1.15774 | Val Reg Loss:    0.41900



  2%|▏         | 2/120 [09:02<8:53:58, 271.51s/it]

Epoch:   2 | Train Loss:   -0.38025 | Val Loss:   -0.45850 | Time Taken: 9m 3s
Train Recon Loss:   -0.67541 | Val Recon Loss:   -0.67114
Train KLD: 470523.34077 | Val KLD: 482249.31490
Train Reg Loss:    0.29516 | Val Reg Loss:    0.21264



  2%|▎         | 3/120 [13:34<8:49:50, 271.72s/it]

Epoch:   3 | Train Loss:   -0.51106 | Val Loss:   -0.53335 | Time Taken: 13m 35s
Train Recon Loss:   -0.68364 | Val Recon Loss:   -0.68876
Train KLD: 567402.15611 | Val KLD: 586440.30168
Train Reg Loss:    0.17258 | Val Reg Loss:    0.15540



  3%|▎         | 4/120 [18:07<8:46:32, 272.35s/it]

Epoch:   4 | Train Loss:   -0.54727 | Val Loss:   -0.58458 | Time Taken: 18m 8s
Train Recon Loss:   -0.68642 | Val Recon Loss:   -0.69375
Train KLD: 636262.99915 | Val KLD: 646872.89183
Train Reg Loss:    0.13915 | Val Reg Loss:    0.10916

Epoch:   5 | Train Loss:   -0.58122 | Val Loss:   -0.60482 | Time Taken: 22m 41s
Train Recon Loss:   -0.69613 | Val Recon Loss:   -0.69736
Train KLD: 691638.52727 | Val KLD: 695744.13642
Train Reg Loss:    0.11490 | Val Reg Loss:    0.09253



  5%|▌         | 6/120 [27:14<8:38:13, 272.75s/it]

Epoch:   6 | Train Loss:   -0.58601 | Val Loss:   -0.61686 | Time Taken: 27m 14s
Train Recon Loss:   -0.69775 | Val Recon Loss:   -0.69825
Train KLD: 732138.30852 | Val KLD: 717072.05529
Train Reg Loss:    0.11171 | Val Reg Loss:    0.08136



  6%|▌         | 7/120 [31:46<8:33:23, 272.60s/it]

Epoch:   7 | Train Loss:   -0.61172 | Val Loss:   -0.61219 | Time Taken: 31m 47s
Train Recon Loss:   -0.69825 | Val Recon Loss:   -0.69850
Train KLD: 775134.49119 | Val KLD: 765895.37200
Train Reg Loss:    0.08645 | Val Reg Loss:    0.08624



  7%|▋         | 8/120 [36:19<8:28:41, 272.51s/it]

Epoch:   8 | Train Loss:   -0.61123 | Val Loss:   -0.61925 | Time Taken: 36m 19s
Train Recon Loss:   -0.69838 | Val Recon Loss:   -0.69858
Train KLD: 810054.30909 | Val KLD: 794592.79327
Train Reg Loss:    0.08693 | Val Reg Loss:    0.07911



  8%|▊         | 9/120 [40:51<8:24:02, 272.46s/it]

Epoch:   9 | Train Loss:   -0.61576 | Val Loss:   -0.52274 | Time Taken: 40m 51s
Train Recon Loss:   -0.69845 | Val Recon Loss:   -0.69868
Train KLD: 835736.32955 | Val KLD: 810605.09495
Train Reg Loss:    0.08205 | Val Reg Loss:    0.17533

Epoch:  10 | Train Loss:   -0.60687 | Val Loss:   -0.57229 | Time Taken: 45m 23s
Train Recon Loss:   -0.69861 | Val Recon Loss:   -0.69886
Train KLD: 818840.50057 | Val KLD: 762199.26322
Train Reg Loss:    0.09005 | Val Reg Loss:    0.12499



  9%|▉         | 11/120 [49:57<8:15:44, 272.89s/it]

Epoch:  11 | Train Loss:   -0.62288 | Val Loss:   -0.59688 | Time Taken: 49m 58s
Train Recon Loss:   -0.69868 | Val Recon Loss:   -0.69892
Train KLD: 777436.83608 | Val KLD: 674657.83714
Train Reg Loss:    0.07145 | Val Reg Loss:    0.09826



 10%|█         | 12/120 [54:30<8:11:18, 272.95s/it]

Epoch:  12 | Train Loss:   -0.60922 | Val Loss:   -0.59350 | Time Taken: 54m 31s
Train Recon Loss:   -0.69875 | Val Recon Loss:   -0.69869
Train KLD: 610105.70341 | Val KLD: 459722.60817
Train Reg Loss:    0.08024 | Val Reg Loss:    0.09820



 11%|█         | 13/120 [59:03<8:06:37, 272.87s/it]

Epoch:  13 | Train Loss:   -0.61190 | Val Loss:   -0.63406 | Time Taken: 59m 4s
Train Recon Loss:   -0.69876 | Val Recon Loss:   -0.69901
Train KLD: 348253.92273 | Val KLD: 268159.36298
Train Reg Loss:    0.07244 | Val Reg Loss:    0.05385



 12%|█▏        | 14/120 [1:03:36<8:02:07, 272.90s/it]

Epoch:  14 | Train Loss:   -0.59510 | Val Loss:   -0.60880 | Time Taken: 63m 37s
Train Recon Loss:   -0.69883 | Val Recon Loss:   -0.69905
Train KLD: 204980.79609 | Val KLD: 172585.67668
Train Reg Loss:    0.08066 | Val Reg Loss:    0.07083

Epoch:  15 | Train Loss:   -0.57279 | Val Loss:   -0.59595 | Time Taken: 68m 10s
Train Recon Loss:   -0.69883 | Val Recon Loss:   -0.69879
Train KLD: 128582.64347 | Val KLD: 112146.04026
Train Reg Loss:    0.08671 | Val Reg Loss:    0.06854



 13%|█▎        | 16/120 [1:12:46<7:54:48, 273.93s/it]

Epoch:  16 | Train Loss:   -0.54011 | Val Loss:   -0.57983 | Time Taken: 72m 46s
Train Recon Loss:   -0.69875 | Val Recon Loss:   -0.69902
Train KLD: 74264.54148 | Val KLD: 62489.05018
Train Reg Loss:    0.09688 | Val Reg Loss:    0.06723



 14%|█▍        | 17/120 [1:17:20<7:50:34, 274.12s/it]

Epoch:  17 | Train Loss:   -0.49673 | Val Loss:   -0.53078 | Time Taken: 77m 21s
Train Recon Loss:   -0.69890 | Val Recon Loss:   -0.69878
Train KLD: 39359.36250 | Val KLD: 32843.22288
Train Reg Loss:    0.11320 | Val Reg Loss:    0.09376



 15%|█▌        | 18/120 [1:21:54<7:45:32, 273.85s/it]

Epoch:  18 | Train Loss:   -0.43769 | Val Loss:   -0.48843 | Time Taken: 81m 54s
Train Recon Loss:   -0.69888 | Val Recon Loss:   -0.69911
Train KLD: 21115.75059 | Val KLD: 20744.68397
Train Reg Loss:    0.13145 | Val Reg Loss:    0.08321



 16%|█▌        | 19/120 [1:26:27<7:40:37, 273.63s/it]

Epoch:  19 | Train Loss:   -0.36368 | Val Loss:   -0.39067 | Time Taken: 86m 27s
Train Recon Loss:   -0.69890 | Val Recon Loss:   -0.69913
Train KLD: 11200.41780 | Val KLD: 10246.39739
Train Reg Loss:    0.14816 | Val Reg Loss:    0.13734

Epoch:  20 | Train Loss:   -0.19665 | Val Loss:   -0.24198 | Time Taken: 91m 1s
Train Recon Loss:   -0.69895 | Val Recon Loss:   -0.69912
Train KLD: 6514.37769 | Val KLD: 7063.35840
Train Reg Loss:    0.20656 | Val Reg Loss:    0.13648



 18%|█▊        | 21/120 [1:35:37<7:32:54, 274.49s/it]

Epoch:  21 | Train Loss:    0.19674 | Val Loss:   -0.01616 | Time Taken: 95m 38s
Train Recon Loss:   -0.69895 | Val Recon Loss:   -0.69911
Train KLD: 4591.25550 | Val KLD: 3684.12725
Train Reg Loss:    0.32915 | Val Reg Loss:    0.22835



 18%|█▊        | 22/120 [1:40:13<7:28:45, 274.75s/it]

Epoch:  22 | Train Loss:    0.76870 | Val Loss:    0.36085 | Time Taken: 100m 13s
Train Recon Loss:   -0.69870 | Val Recon Loss:   -0.69901
Train KLD: 2867.98708 | Val KLD: 2083.51270
Train Reg Loss:    0.50562 | Val Reg Loss:    0.36115



 19%|█▉        | 23/120 [1:44:49<7:24:52, 275.18s/it]

Epoch:  23 | Train Loss:    2.23923 | Val Loss:    1.08551 | Time Taken: 104m 49s
Train Recon Loss:   -0.69891 | Val Recon Loss:   -0.69902
Train KLD: 1958.64925 | Val KLD:  954.43869
Train Reg Loss:    1.15371 | Val Reg Loss:    0.91499



 20%|██        | 24/120 [1:49:25<7:20:29, 275.31s/it]

Epoch:  24 | Train Loss:    3.58729 | Val Loss:    1.70017 | Time Taken: 109m 25s
Train Recon Loss:   -0.69895 | Val Recon Loss:   -0.69918
Train KLD:  900.05398 | Val KLD:  129.21039
Train Reg Loss:    2.06075 | Val Reg Loss:    2.07986

Epoch:  25 | Train Loss:    2.37546 | Val Loss:    1.66393 | Time Taken: 114m 0s
Train Recon Loss:   -0.69903 | Val Recon Loss:   -0.69922
Train KLD:  146.86378 | Val KLD:   41.56446
Train Reg Loss:    2.09155 | Val Reg Loss:    2.08496



 22%|██▏       | 26/120 [1:58:37<7:12:07, 275.82s/it]

Epoch:  26 | Train Loss:   12.98066 | Val Loss:    1.88748 | Time Taken: 118m 38s
Train Recon Loss:   -0.69901 | Val Recon Loss:   -0.69923
Train KLD:  644.28412 | Val KLD:   27.11551
Train Reg Loss:    2.09144 | Val Reg Loss:    2.09901



 22%|██▎       | 27/120 [2:03:13<7:07:30, 275.81s/it]

Epoch:  27 | Train Loss:    4.62977 | Val Loss:    2.11704 | Time Taken: 123m 14s
Train Recon Loss:   -0.69909 | Val Recon Loss:   -0.69910
Train KLD:   68.29951 | Val KLD:   15.55939
Train Reg Loss:    2.08969 | Val Reg Loss:    2.07823



 23%|██▎       | 28/120 [2:07:49<7:02:59, 275.86s/it]

Epoch:  28 | Train Loss:   43.30083 | Val Loss:    3.12163 | Time Taken: 127m 50s
Train Recon Loss:   -0.69912 | Val Recon Loss:   -0.69931
Train KLD:  351.59661 | Val KLD:   14.59354
Train Reg Loss:    2.08861 | Val Reg Loss:    2.08135



 24%|██▍       | 29/120 [2:12:25<6:58:21, 275.85s/it]

Epoch:  29 | Train Loss:    7.15266 | Val Loss:    3.23558 | Time Taken: 132m 25s
Train Recon Loss:   -0.69904 | Val Recon Loss:   -0.69927
Train KLD:   21.42856 | Val KLD:    6.90213
Train Reg Loss:    2.08868 | Val Reg Loss:    2.07858

Epoch:  30 | Train Loss:   23.89623 | Val Loss:   10.87087 | Time Taken: 137m 1s
Train Recon Loss:   -0.69919 | Val Recon Loss:   -0.69929
Train KLD:   45.01454 | Val KLD:   18.98640
Train Reg Loss:    2.08815 | Val Reg Loss:    2.07695



 26%|██▌       | 31/120 [2:21:38<6:49:46, 276.25s/it]

Epoch:  31 | Train Loss:   11.29414 | Val Loss:    6.29753 | Time Taken: 141m 39s
Train Recon Loss:   -0.69918 | Val Recon Loss:   -0.69940
Train KLD:   13.55424 | Val KLD:    6.70531
Train Reg Loss:    2.08438 | Val Reg Loss:    2.09495



 27%|██▋       | 32/120 [2:26:14<6:44:46, 275.98s/it]

Epoch:  32 | Train Loss:    6.62494 | Val Loss:   15.29438 | Time Taken: 146m 14s
Train Recon Loss:   -0.69918 | Val Recon Loss:   -0.69941
Train KLD:    5.94651 | Val KLD:   15.80308
Train Reg Loss:    2.08645 | Val Reg Loss:    2.07448



 28%|██▊       | 33/120 [2:30:50<6:40:08, 275.97s/it]

Epoch:  33 | Train Loss:    9.44371 | Val Loss:   15.85528 | Time Taken: 150m 50s
Train Recon Loss:   -0.69927 | Val Recon Loss:   -0.69943
Train KLD:    8.45884 | Val KLD:   15.19821
Train Reg Loss:    2.08531 | Val Reg Loss:    2.07729



 28%|██▊       | 34/120 [2:35:25<6:35:27, 275.90s/it]

Epoch:  34 | Train Loss:    6.25732 | Val Loss:    6.42814 | Time Taken: 155m 26s
Train Recon Loss:   -0.69928 | Val Recon Loss:   -0.69946
Train KLD:    4.96325 | Val KLD:    5.14438
Train Reg Loss:    2.08263 | Val Reg Loss:    2.07575

Epoch:  35 | Train Loss:    3.40825 | Val Loss:    5.33338 | Time Taken: 160m 2s
Train Recon Loss:   -0.69936 | Val Recon Loss:   -0.69953
Train KLD:    2.03862 | Val KLD:    3.97755
Train Reg Loss:    2.08264 | Val Reg Loss:    2.08198



 30%|███       | 36/120 [2:44:39<6:26:39, 276.18s/it]

Epoch:  36 | Train Loss:    3.14087 | Val Loss:    3.48231 | Time Taken: 164m 39s
Train Recon Loss:   -0.69941 | Val Recon Loss:   -0.69959
Train KLD:    1.75955 | Val KLD:    2.10918
Train Reg Loss:    2.08508 | Val Reg Loss:    2.07792



 31%|███       | 37/120 [2:49:14<6:21:56, 276.10s/it]

Epoch:  37 | Train Loss:    3.12916 | Val Loss:    3.14009 | Time Taken: 169m 15s
Train Recon Loss:   -0.69942 | Val Recon Loss:   -0.69956
Train KLD:    1.74147 | Val KLD:    1.76381
Train Reg Loss:    2.08870 | Val Reg Loss:    2.07745



 32%|███▏      | 38/120 [2:53:50<6:17:05, 275.92s/it]

Epoch:  38 | Train Loss:    2.79485 | Val Loss:    6.14128 | Time Taken: 173m 50s
Train Recon Loss:   -0.69940 | Val Recon Loss:   -0.69964
Train KLD:    1.41147 | Val KLD:    4.75962
Train Reg Loss:    2.08325 | Val Reg Loss:    2.08289



 32%|███▎      | 39/120 [2:58:26<6:12:30, 275.94s/it]

Epoch:  39 | Train Loss:    3.67216 | Val Loss:   12.60929 | Time Taken: 178m 26s
Train Recon Loss:   -0.69942 | Val Recon Loss:   -0.69966
Train KLD:    2.29162 | Val KLD:   11.23170
Train Reg Loss:    2.08025 | Val Reg Loss:    2.07864

Epoch:  40 | Train Loss:    6.13809 | Val Loss:    4.99268 | Time Taken: 183m 2s
Train Recon Loss:   -0.69949 | Val Recon Loss:   -0.69962
Train KLD:    4.75560 | Val KLD:    3.61696
Train Reg Loss:    2.08220 | Val Reg Loss:    2.07551



 34%|███▍      | 41/120 [3:07:39<6:03:34, 276.14s/it]

Epoch:  41 | Train Loss:    2.96787 | Val Loss:    4.83292 | Time Taken: 187m 39s
Train Recon Loss:   -0.69947 | Val Recon Loss:   -0.69930
Train KLD:    1.58336 | Val KLD:    3.44902
Train Reg Loss:    2.08401 | Val Reg Loss:    2.08326



 35%|███▌      | 42/120 [3:12:15<5:58:51, 276.04s/it]

Epoch:  42 | Train Loss:    2.57697 | Val Loss:    2.21074 | Time Taken: 192m 15s
Train Recon Loss:   -0.69945 | Val Recon Loss:   -0.69969
Train KLD:    1.19452 | Val KLD:    0.82669
Train Reg Loss:    2.08190 | Val Reg Loss:    2.08375



 36%|███▌      | 43/120 [3:16:50<5:54:06, 275.93s/it]

Epoch:  43 | Train Loss:    2.26776 | Val Loss:    2.38529 | Time Taken: 196m 51s
Train Recon Loss:   -0.69954 | Val Recon Loss:   -0.69969
Train KLD:    0.88711 | Val KLD:    1.01535
Train Reg Loss:    2.08019 | Val Reg Loss:    2.06964



 37%|███▋      | 44/120 [3:21:26<5:49:31, 275.95s/it]

Epoch:  44 | Train Loss:    2.21898 | Val Loss:    3.48729 | Time Taken: 201m 27s
Train Recon Loss:   -0.69958 | Val Recon Loss:   -0.69976
Train KLD:    0.83519 | Val KLD:    2.11449
Train Reg Loss:    2.08336 | Val Reg Loss:    2.07256

Epoch:  45 | Train Loss:    2.74975 | Val Loss:    3.59016 | Time Taken: 206m 3s
Train Recon Loss:   -0.69950 | Val Recon Loss:   -0.69974
Train KLD:    1.36978 | Val KLD:    2.19744
Train Reg Loss:    2.07947 | Val Reg Loss:    2.09246



 38%|███▊      | 46/120 [3:30:39<5:40:40, 276.22s/it]

Epoch:  46 | Train Loss:    4.49666 | Val Loss:    2.20356 | Time Taken: 210m 40s
Train Recon Loss:   -0.69959 | Val Recon Loss:   -0.69975
Train KLD:    3.11710 | Val KLD:    0.83487
Train Reg Loss:    2.07915 | Val Reg Loss:    2.06844



 39%|███▉      | 47/120 [3:35:15<5:35:48, 276.01s/it]

Epoch:  47 | Train Loss:    1.96091 | Val Loss:    1.87470 | Time Taken: 215m 15s
Train Recon Loss:   -0.69959 | Val Recon Loss:   -0.69981
Train KLD:    0.58071 | Val KLD:    0.49395
Train Reg Loss:    2.07979 | Val Reg Loss:    2.08057



 40%|████      | 48/120 [3:39:51<5:31:15, 276.04s/it]

Epoch:  48 | Train Loss:    2.32801 | Val Loss:    2.74491 | Time Taken: 219m 52s
Train Recon Loss:   -0.69963 | Val Recon Loss:   -0.69985
Train KLD:    0.94604 | Val KLD:    1.35272
Train Reg Loss:    2.08160 | Val Reg Loss:    2.09205



 41%|████      | 49/120 [3:44:27<5:26:38, 276.04s/it]

Epoch:  49 | Train Loss:    3.34095 | Val Loss:    5.32350 | Time Taken: 224m 28s
Train Recon Loss:   -0.69965 | Val Recon Loss:   -0.69983
Train KLD:    1.95754 | Val KLD:    3.95419
Train Reg Loss:    2.08306 | Val Reg Loss:    2.06914

Epoch:  50 | Train Loss:    4.69276 | Val Loss:    3.10708 | Time Taken: 229m 3s
Train Recon Loss:   -0.69969 | Val Recon Loss:   -0.69988
Train KLD:    3.31341 | Val KLD:    1.73411
Train Reg Loss:    2.07904 | Val Reg Loss:    2.07285



 42%|████▎     | 51/120 [3:53:40<5:17:41, 276.25s/it]

Epoch:  51 | Train Loss:    3.04993 | Val Loss:    4.01468 | Time Taken: 233m 41s
Train Recon Loss:   -0.69973 | Val Recon Loss:   -0.69987
Train KLD:    1.66772 | Val KLD:    2.64443
Train Reg Loss:    2.08195 | Val Reg Loss:    2.07012



 43%|████▎     | 52/120 [3:58:16<5:12:59, 276.17s/it]

Epoch:  52 | Train Loss:    3.57507 | Val Loss:    2.62809 | Time Taken: 238m 17s
Train Recon Loss:   -0.69971 | Val Recon Loss:   -0.69988
Train KLD:    2.19378 | Val KLD:    1.25096
Train Reg Loss:    2.08100 | Val Reg Loss:    2.07701



 44%|████▍     | 53/120 [4:02:52<5:08:17, 276.08s/it]

Epoch:  53 | Train Loss:    2.57537 | Val Loss:    2.26668 | Time Taken: 242m 53s
Train Recon Loss:   -0.69973 | Val Recon Loss:   -0.69989
Train KLD:    1.19586 | Val KLD:    0.89602
Train Reg Loss:    2.07924 | Val Reg Loss:    2.07055



 45%|████▌     | 54/120 [4:07:28<5:03:39, 276.05s/it]

Epoch:  54 | Train Loss:    2.65262 | Val Loss:    8.92662 | Time Taken: 247m 29s
Train Recon Loss:   -0.69975 | Val Recon Loss:   -0.69990
Train KLD:    1.27319 | Val KLD:    7.54652
Train Reg Loss:    2.07918 | Val Reg Loss:    2.08000

Epoch:  55 | Train Loss:    9.27399 | Val Loss:    3.61259 | Time Taken: 252m 4s
Train Recon Loss:   -0.69975 | Val Recon Loss:   -0.69991
Train KLD:    7.89569 | Val KLD:    2.23806
Train Reg Loss:    2.07805 | Val Reg Loss:    2.07445



 47%|████▋     | 56/120 [4:16:41<4:54:34, 276.16s/it]

Epoch:  56 | Train Loss:    3.25199 | Val Loss:    3.54051 | Time Taken: 256m 41s
Train Recon Loss:   -0.69977 | Val Recon Loss:   -0.69990
Train KLD:    1.87297 | Val KLD:    2.16703
Train Reg Loss:    2.07880 | Val Reg Loss:    2.07337



 48%|████▊     | 57/120 [4:21:17<4:49:52, 276.07s/it]

Epoch:  57 | Train Loss:    2.23251 | Val Loss:    2.14504 | Time Taken: 261m 17s
Train Recon Loss:   -0.69975 | Val Recon Loss:   -0.69990
Train KLD:    0.85200 | Val KLD:    0.77550
Train Reg Loss:    2.08026 | Val Reg Loss:    2.06944



 48%|████▊     | 58/120 [4:25:53<4:45:16, 276.06s/it]

Epoch:  58 | Train Loss:    1.92801 | Val Loss:    2.00291 | Time Taken: 265m 53s
Train Recon Loss:   -0.69978 | Val Recon Loss:   -0.69993
Train KLD:    0.55154 | Val KLD:    0.63178
Train Reg Loss:    2.07625 | Val Reg Loss:    2.07106



 49%|████▉     | 59/120 [4:30:28<4:40:27, 275.87s/it]

Epoch:  59 | Train Loss:    2.07376 | Val Loss:    2.26360 | Time Taken: 270m 29s
Train Recon Loss:   -0.69978 | Val Recon Loss:   -0.69994
Train KLD:    0.69269 | Val KLD:    0.89552
Train Reg Loss:    2.08085 | Val Reg Loss:    2.06803

Epoch:  60 | Train Loss:    2.41239 | Val Loss:    3.74458 | Time Taken: 275m 4s
Train Recon Loss:   -0.69977 | Val Recon Loss:   -0.69994
Train KLD:    1.03365 | Val KLD:    2.37759
Train Reg Loss:    2.07851 | Val Reg Loss:    2.06693



 51%|█████     | 61/120 [4:39:41<4:31:37, 276.22s/it]

Epoch:  61 | Train Loss:    2.83055 | Val Loss:    2.21838 | Time Taken: 279m 42s
Train Recon Loss:   -0.69980 | Val Recon Loss:   -0.69995
Train KLD:    1.45225 | Val KLD:    0.83289
Train Reg Loss:    2.07811 | Val Reg Loss:    2.08544



 52%|█████▏    | 62/120 [4:44:17<4:26:51, 276.06s/it]

Epoch:  62 | Train Loss:    2.27471 | Val Loss:    2.39496 | Time Taken: 284m 18s
Train Recon Loss:   -0.69978 | Val Recon Loss:   -0.69994
Train KLD:    0.89795 | Val KLD:    1.02569
Train Reg Loss:    2.07654 | Val Reg Loss:    2.06921



 52%|█████▎    | 63/120 [4:48:53<4:22:14, 276.05s/it]

Epoch:  63 | Train Loss:    2.77864 | Val Loss:    2.59754 | Time Taken: 288m 54s
Train Recon Loss:   -0.69978 | Val Recon Loss:   -0.69996
Train KLD:    1.40051 | Val KLD:    1.22107
Train Reg Loss:    2.07792 | Val Reg Loss:    2.07644



 53%|█████▎    | 64/120 [4:53:29<4:17:40, 276.08s/it]

Epoch:  64 | Train Loss:    3.97804 | Val Loss:    4.59395 | Time Taken: 293m 30s
Train Recon Loss:   -0.69980 | Val Recon Loss:   -0.69991
Train KLD:    2.60028 | Val KLD:    3.22732
Train Reg Loss:    2.07756 | Val Reg Loss:    2.06654

Epoch:  65 | Train Loss:    3.60266 | Val Loss:    3.19245 | Time Taken: 298m 6s
Train Recon Loss:   -0.69979 | Val Recon Loss:   -0.69995
Train KLD:    2.22725 | Val KLD:    1.82345
Train Reg Loss:    2.07521 | Val Reg Loss:    2.06894



 55%|█████▌    | 66/120 [5:02:43<4:08:40, 276.30s/it]

Epoch:  66 | Train Loss:    3.00033 | Val Loss:    2.50007 | Time Taken: 302m 43s
Train Recon Loss:   -0.69982 | Val Recon Loss:   -0.69993
Train KLD:    1.62246 | Val KLD:    1.12658
Train Reg Loss:    2.07769 | Val Reg Loss:    2.07342



 56%|█████▌    | 67/120 [5:07:19<4:03:58, 276.20s/it]

Epoch:  67 | Train Loss:    2.29410 | Val Loss:    6.68804 | Time Taken: 307m 19s
Train Recon Loss:   -0.69982 | Val Recon Loss:   -0.69998
Train KLD:    0.91812 | Val KLD:    5.31319
Train Reg Loss:    2.07581 | Val Reg Loss:    2.07482



 57%|█████▋    | 68/120 [5:11:55<3:59:23, 276.23s/it]

Epoch:  68 | Train Loss:    3.88353 | Val Loss:    2.78107 | Time Taken: 311m 55s
Train Recon Loss:   -0.69983 | Val Recon Loss:   -0.69998
Train KLD:    2.50593 | Val KLD:    1.41053
Train Reg Loss:    2.07744 | Val Reg Loss:    2.07052



 57%|█████▊    | 69/120 [5:16:31<3:54:42, 276.12s/it]

Epoch:  69 | Train Loss:    3.09906 | Val Loss:    4.91268 | Time Taken: 316m 31s
Train Recon Loss:   -0.69984 | Val Recon Loss:   -0.69997
Train KLD:    1.71709 | Val KLD:    3.50483
Train Reg Loss:    2.08181 | Val Reg Loss:    2.10781

Epoch:  70 | Train Loss:    3.74360 | Val Loss:    5.86419 | Time Taken: 321m 7s
Train Recon Loss:   -0.69981 | Val Recon Loss:   -0.69999
Train KLD:    2.36787 | Val KLD:    4.49230
Train Reg Loss:    2.07554 | Val Reg Loss:    2.07188



 59%|█████▉    | 71/120 [5:25:43<3:45:25, 276.04s/it]

Epoch:  71 | Train Loss:    4.25462 | Val Loss:    3.60540 | Time Taken: 325m 43s
Train Recon Loss:   -0.69983 | Val Recon Loss:   -0.69996
Train KLD:    2.87830 | Val KLD:    2.24013
Train Reg Loss:    2.07616 | Val Reg Loss:    2.06524



In [None]:
trainer.plot_losses()

In [None]:
reg_mean_train = df.iloc[train_idx].reg_col.mean()
mean_mae_test = np.mean(np.abs(df.iloc[test_idx].reg_col - reg_mean_train))
vae_mae_test = get_pred_mae(reg_col, test_dl).item()
errors[reg_col] = {'mean_mae_test': mean_mae_test, 'vae_mae_test': vae_mae_test}

## Train and Test the Network for _QED_

In [None]:
reg_col = 'qed'
df, train_idx, valid_idx, test_idx = vae_utils.get_train_valid_test_splits(reg_col, valid_pct)
train_dl, valid_dl, test_dl = get_dataloaders(df, train_idx, valid_idx, test_idx, max_samples)
trainer = Trainer(reg_col)
trainer.fit(epochs, save_every, anneal_start, lr, train_dl, valid_dl, device, load_previous)

In [None]:
trainer.plot_losses()

In [None]:
reg_mean_train = df.iloc[train_idx].reg_col.mean()
mean_mae_test = np.mean(np.abs(df.iloc[test_idx].reg_col - reg_mean_train))
vae_mae_test = get_pred_mae(reg_col, test_dl).item()
errors[reg_col] = {'mean_mae_test': mean_mae_test, 'vae_mae_test': vae_mae_test}

## Report the Results

In [None]:
pd.DataFrame(errors).transpose()

## Remarks

The model was trained in 4 stages. Except for the 1st stage, the other two stages were trained after loading the saving parameters from the previous stage.
The stages are summarized below:

1. max_exmaples = 20,000, epochs = 50, anneal_start = 15, load_previous = False, bs = 200
2. max_examples = 100,000, epochs = 15, anneal_start = 5, load_previous = True, bs = 200
3. max_examples = all, epochs = 10, anneal_start = 3, load_previous = True, bs = 200
4. max_examples = all, epochs = 15, anneal_start = 5, load_previous = True, bs = 1000
