## Setup Environment

In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import pathlib
import yaml
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import math

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from torch.utils.data import DataLoader
from torch import optim

from sklearn.model_selection import train_test_split
from time import time
from tqdm import tqdm

## Declare Constants

In [3]:
MAX_LEN = 120
NLATENT = 196
DECODER_HIDDEN_SIZE = 488
DECODER_NUM_LAYERS = 3

DATA_DIR = './data'
DATA_FILE_NAME = '250k_rndm_zinc_drugs_clean_3.csv'
CHAR_FILE_NAME = 'zinc.json'
TEST_IDX_FILE_NAME = 'test_idx.npy'

ALL_LETTERS = yaml.safe_load(open(pathlib.Path(DATA_DIR, CHAR_FILE_NAME))) + ['SOS']
N_LETTERS = len(ALL_LETTERS) 

MODELS_DIR = './models'

## Create a Utility Class

In [4]:
class VAEUtils:
    '''
    This purpose of this class is to help with various aspects
    of data processing
    '''
    
    def __init__(self, data_dir=DATA_DIR, 
                 data_file_name=DATA_FILE_NAME, 
                 test_idx_file_name=TEST_IDX_FILE_NAME,
                 max_len=MAX_LEN,
                 all_letters=ALL_LETTERS):
        
        self.data_dir = pathlib.Path(data_dir)
        self.data_file = self.data_dir / pathlib.Path(data_file_name)
        self.test_file = self.data_dir / pathlib.Path(test_idx_file_name)
        
        self.max_len = max_len
        
        self.all_letters = all_letters
        self.n_letters = len(all_letters)
        self.letters_to_indices_dict = dict((l, i) for i, l in enumerate(all_letters))
        self.indices_to_letters_dict = dict((i, l) for i, l in enumerate(all_letters))
               
    def get_data_df(self):
        df = pd.read_csv(self.data_file)
        df = df[df.smiles.str.len() <= self.max_len].reset_index(drop=True)
        
        # preprocess input smile to remove the newline character and add padding
        df.loc[:, 'smiles'] = df.loc[:, 'smiles'].str.strip()\
                    .str.pad(width=self.max_len, side='right', fillchar=" ")
        
        return df
        
    # One-hot matrix of first to last letters (not including EOS) for input
    def get_input_tensor(self, smile):
        tensor = torch.zeros(1, len(smile), self.n_letters) # batch_size * seq_length * num_features
        for i, letter in enumerate(smile):
            tensor[0][i][self.letters_to_indices_dict[letter]] = 1
        return tensor

    # LongTensor of first letter to end (EOS) for target
    def get_target_tensor(self, smile):
        letter_indexes = [self.letters_to_indices_dict[l] for l in smile]
        # letter_indexes.append(self.n_letters - 1) # EOS
        return torch.LongTensor(letter_indexes)
    
    def get_train_valid_test_splits(self, reg_col, valid_pct=.1):
        df = self.get_data_df()[['smiles', reg_col]]
        df = df.rename(columns={reg_col: 'reg_col'})
        
        test_idx = np.load(self.test_file)
        non_test_idx = np.array(df[~df.index.isin(test_idx)].index)
        train_idx, valid_idx = train_test_split(non_test_idx, test_size=valid_pct, 
                                                random_state=42, shuffle=True)
        
        assert len(df) == len(test_idx) + len(train_idx) + len(valid_idx)
        
        return df, train_idx, valid_idx, test_idx
         
    def get_dl(self, df, idx, bs, shuffle=False):
        
        df = df.iloc[idx]
        
        input_tensors = torch.zeros(len(df), self.max_len, self.n_letters)
        target_tensors = torch.zeros(len(df), self.max_len, dtype=torch.long)
        for i, smile in enumerate(tqdm(df.smiles)):
            input_tensors[i] = self.get_input_tensor(smile)
            target_tensors[i] = self.get_target_tensor(smile)
        
        input_tensors = input_tensors
        target_tensors = target_tensors
        
        # original_lengths = torch.tensor(df.smiles.str.strip().str.len().to_numpy())

        property_values = torch.tensor(df.reg_col.to_numpy()).type(torch.float32)
         
        ds = TensorDataset(input_tensors, target_tensors, property_values)
        dl = DataLoader(ds, shuffle=shuffle, batch_size=bs)
        
        return dl
    
vae_utils = VAEUtils()

## Create the Model Networks

In [5]:
class Lambda(nn.Module):
    '''
    This class simplifies layers from 
    custom functions
    '''
    def __init__(self, func):
        super().__init__()
        self.func = func
    
    def forward(self, x):
        return self.func(x)

In [6]:
class Encoder(nn.Module):
    def __init__(self, n_letters, nlatent, decoder_hidden_size):
        super().__init__()
        self.n_letters = n_letters
        self.nlatent = nlatent
        self.decoder_hidden_size = decoder_hidden_size
        self.encoder = nn.Sequential(
            Lambda(lambda x: x.permute(0, 2, 1)), # the features are in the channels dimension
            nn.Conv1d(in_channels=n_letters, out_channels=9, kernel_size=9),
            # nn.Tanh(), # the authors of the paper used tanh
            nn.ReLU(),
            nn.BatchNorm1d(9), # the authors of the paper did batch normalization
            nn.Conv1d(in_channels=9, out_channels=9, kernel_size=9),
            # nn.Tanh(), # the authors of the paper used tanh
            nn.ReLU(),
            nn.BatchNorm1d(9), # the authors of the paper did batch normalization
            nn.Conv1d(in_channels=9, out_channels=11, kernel_size=10),
            # nn.Tanh(), # the authors of the paper used tanh
            nn.ReLU(),
            nn.BatchNorm1d(11), # the authors of the paper did batch normalization
            nn.Flatten()
        )
        
        self.mean = nn.Linear(1045, nlatent)
        self.log_var = nn.Linear(1045, nlatent)
        self.dec_init_hidden = nn.Linear(nlatent, decoder_hidden_size)
        
    def reparameterize(self, mean, log_var):
        std = torch.exp(0.5 * log_var) 
        eps = torch.randn_like(std)
        sample = mean + (eps * std) 
        return sample
        
    def forward(self, x):
        x = self.encoder(x)
        mean = self.mean(x)
        log_var = self.log_var(x)
        z = self.reparameterize(mean, log_var)
        dec_init_hidden = self.dec_init_hidden(z)
        
        return z, mean, log_var, dec_init_hidden

In [7]:
def preprocess_decoder_input(input_tensors):
    '''
    Adjust for SOS and make batch the second demension
    '''
    sos_tensor = torch.zeros(1, N_LETTERS)
    sos_tensor[0][vae_utils.letters_to_indices_dict['SOS']] = 1
    new_tensor = torch.zeros(input_tensors.shape[0], MAX_LEN, N_LETTERS)
    new_tensor[:][0] = sos_tensor
    new_tensor[:, 1:MAX_LEN, :] = input_tensors[:, 0:MAX_LEN-1, :]
    new_tensor = new_tensor.permute(1, 0, 2).to(input_tensors.device)
    
    return new_tensor

In [8]:
class Decoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.preprocess = Lambda(preprocess_decoder_input)
        self.rnn = nn.GRU(input_size, hidden_size, num_layers)
        self.out = nn.Linear(hidden_size, input_size)
        self.softmax = nn.LogSoftmax(dim=2)
        
    def forward(self, input, hidden):
        output, hidden = self.rnn(self.preprocess(input), )
        output = output.permute(1, 0, 2)
        output = self.softmax(self.out(output))
        return output, hidden

In [9]:
class PropertyPredictor(nn.Module):
    def __init__(self, nlatent):
        super().__init__()
        self.predictor = nn.Sequential(
            nn.Linear(nlatent, 1000),
            # nn.Tanh(), # the authors of the paper used tanh
            nn.ReLU(),
            nn.Dropout(.2),
            nn.Linear(1000, 1000),
            # nn.Tanh(), # the authors of the paper used tanh
            nn.ReLU(),
            nn.Dropout(.2),
            nn.Linear(1000, 1)
        )
        
    def forward(self, x):
        return self.predictor(x) 

## Write Helper Classes & Functions for Training and Testing

In [10]:
class Trainer:
    def __init__(self, reg_col):
        
        models_dir = pathlib.Path(MODELS_DIR)
        models_dir.mkdir(exist_ok=True, parents=True)
        self.encoder_file = models_dir / pathlib.Path(f'{reg_col}_encoder.pth')
        self.decoder_file = models_dir / pathlib.Path(f'{reg_col}_decoder.pth')
        self.property_predictor_file = models_dir / pathlib.Path(f'{reg_col}_property_predictor.pth')
        self.train_losses_file = models_dir / pathlib.Path(f'{reg_col}_train_losses.csv')
        self.val_losses_file = models_dir / pathlib.Path(f'{reg_col}_valid_losses.csv')

        self.loss_columns = ['total_loss', 'reconstruction_loss', 'kl_divergence', 'regression_loss']
        self.train_losses_df = pd.DataFrame(columns=self.loss_columns)
        self.val_losses_df = pd.DataFrame(columns=self.loss_columns)

    def intitialize_networks(self):
        self.encoder = Encoder(vae_utils.n_letters, NLATENT, DECODER_HIDDEN_SIZE)
        self.decoder = Decoder(vae_utils.n_letters, DECODER_HIDDEN_SIZE, DECODER_NUM_LAYERS)
        self.property_predictor = PropertyPredictor(NLATENT)
        
    def __time_since(self, since):
        now = time()
        s = now - since
        m = math.floor(s / 60)
        s -= m * 60
        return f'{m}m {s:.0f}s'     
        
    def kl_anneal_function(self, epoch, anneal_start, k=1):
        return 1 / (1 + np.exp(- k * (epoch - anneal_start)))
    
    def get_losses(self, epoch, input_tensors, target_tensors, property_values):
                
        recontruction_loss_func = nn.NLLLoss()
        reg_loss_func = nn.MSELoss()
        
        input_tensors = input_tensors.to(device)
        target_tensors = target_tensors.to(device)
        property_values = property_values.to(device)

        z, mean, log_var, dec_init_hidden = self.encoder(input_tensors)
        output, hidden = self.decoder(input_tensors, 
                                 dec_init_hidden.unsqueeze(0).repeat(DECODER_NUM_LAYERS, 1, 1))
        reg_pred = self.property_predictor(z)

        kl_divergence = -0.5 * torch.sum(1 + log_var - mean.pow(2) - log_var.exp())
        kl_weight = self.kl_anneal_function(epoch, anneal_start)

        reg_loss = reg_loss_func(reg_pred.flatten(), property_values)
        reg_loss = reg_loss.type(torch.float32) 

        reconstruction_loss = 0
        for i in range(input_tensors.shape[0]):
            reconstruction_loss += recontruction_loss_func(output[i], target_tensors[i])
        reconstruction_loss /= input_tensors.shape[0]

        loss = reconstruction_loss + kl_divergence * kl_weight + reg_loss
        
        return loss, reconstruction_loss, kl_divergence, reg_loss
    
    def process_dl(self, epoch, anneal_start, device, dl, train, 
                   enc_opt=None, dec_opt=None, pp_opt=None):
        
        num_loaders = len(dl)
        loader_loss = 0
        recon_loss_epoch = 0
        kld_epoch = 0
        reg_loss_epoch = 0
        
        for input_tensors, target_tensors, property_values in dl:
            loss, reconstruction_loss, kl_divergence, reg_loss = \
                self.get_losses(epoch, input_tensors, target_tensors, property_values)
            
            if train:
                enc_opt.zero_grad()
                dec_opt.zero_grad()
                pp_opt.zero_grad()
                
                loss.backward()
                
                enc_opt.step()
                dec_opt.step()
                pp_opt.step()
                
            loader_loss += loss.item()
            recon_loss_epoch += reconstruction_loss.item()
            kld_epoch += kl_divergence.item()
            reg_loss_epoch += reg_loss.item()
        
        loader_loss /= num_loaders
        recon_loss_epoch /= num_loaders
        kld_epoch /= num_loaders
        reg_loss_epoch /= num_loaders
        
        return (loader_loss, recon_loss_epoch, kld_epoch, reg_loss_epoch)
    
    def save_parameters_and_losses(self, train_losses:list, val_losses:list):
        
        torch.save(self.encoder.state_dict(), self.encoder_file)
        torch.save(self.decoder.state_dict(), self.decoder_file)
        torch.save(self.property_predictor.state_dict(), self.property_predictor_file)
        
        temp_df = pd.DataFrame(data=train_losses, columns=self.loss_columns)
        self.train_losses_df = self.train_losses_df.append(temp_df, ignore_index=True)
        self.train_losses_df.to_csv(self.train_losses_file, index=False)
        
        temp_df = pd.DataFrame(data=val_losses, columns=self.loss_columns)
        self.val_losses_df = self.val_losses_df.append(temp_df, ignore_index=True)
        self.val_losses_df.to_csv(self.val_losses_file, index=False)
        
    def load_parameters_and_losses(self):
        self.encoder.load_state_dict(torch.load(self.encoder_file))
        self.decoder.load_state_dict(torch.load(self.decoder_file))
        self.property_predictor.load_state_dict(torch.load(self.property_predictor_file))
        self.train_losses_df = pd.read_csv(self.train_losses_file)
        self.val_losses_df = pd.read_csv(self.val_losses_file)
        
    
    def fit(self, epochs, save_every, anneal_start, lr, train_dl, valid_dl, device, load_previous=False):
        
        self.intitialize_networks()
    
        if load_previous:
            self.load_parameters_and_losses()
            
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)
        self.property_predictor = self.property_predictor.to(device)

        enc_opt = optim.Adam(self.encoder.parameters(), lr=lr)
        dec_opt = optim.Adam(self.decoder.parameters(), lr=lr)
        pp_opt = optim.Adam(self.property_predictor.parameters(), lr=lr)

        train_losses = []
        val_losses = []
        prev_n_epochs = len(self.train_losses_df) 

        start_time = time()
        for epoch in tqdm(range(epochs)):
            self.encoder.train()
            self.decoder.train()
            self.property_predictor.train()
            train_loss = self.process_dl(prev_n_epochs + epoch, anneal_start, device, 
                                           train_dl, True, enc_opt, dec_opt, pp_opt)
            train_losses.append(train_loss)

            self.encoder.eval()
            self.decoder.eval()
            self.property_predictor.eval()
            with torch.no_grad():
                val_loss = self.process_dl(prev_n_epochs + epoch, anneal_start, device,
                                             valid_dl, False)
                val_losses.append(val_loss)
                
            # if epoch % 5 == 4:
            print(f"Epoch: {prev_n_epochs + epoch + 1:3d} | " + \
                  f"Train Loss: {train_loss[0]:10.5f} | Val Loss: {val_loss[0]:10.5f} | " + \
                  f"Time Taken: {self.__time_since(start_time)}")
            print(f"Train Recon Loss: {train_loss[1]:10.5f} | Val Recon Loss: {val_loss[1]:10.5f}")
            print(f"Train KLD: {train_loss[2]:10.5f} | Val KLD: {val_loss[2]:10.5f}")
            print(f"Train Reg Loss: {train_loss[3]:10.5f} | Val Reg Loss: {val_loss[3]:10.5f}\n")
            
            if epoch % save_every == save_every - 1:
                self.save_parameters_and_losses(train_losses, val_losses)
                train_losses = []
                val_losses = []
            
        self.save_parameters_and_losses(train_losses, val_losses)
        
        del self.encoder
        del self.decoder
        del self.property_predictor
        
    def plot_losses(self):
        
        train_losses_df = pd.read_csv(self.train_losses_file)
        val_losses_df = pd.read_csv(self.val_losses_file)
        
        plt.rcParams.update({'font.size': 15})
        fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(20,10))
        
                
        ax1.plot(train_losses_df['total_loss'], label='Train')
        ax1.plot(val_losses_df['total_loss'], label='Validation')
        ax1.set_ylabel('Total Loss')
        ax1.legend(loc='upper right')
        
        ax2.plot(train_losses_df['reconstruction_loss'], label='Train')
        ax2.plot(val_losses_df['reconstruction_loss'], label='Validation')
        ax2.set_ylabel('Reconstruction Loss')
        ax2.legend(loc='upper right')
        
        ax3.plot(train_losses_df['kl_divergence'], label='Train')
        ax3.plot(val_losses_df['kl_divergence'], label='Validation')
        ax3.set_xlabel('Epoch')
        ax3.set_ylabel('KL Divergence')
        ax3.legend(loc='upper right')
        
        
        ax4.plot(train_losses_df['regression_loss'], label='Train')
        ax4.plot(val_losses_df['regression_loss'], label='Validation')
        ax4.set_xlabel('Epoch')
        ax4.set_ylabel('Regression Loss')
        ax4.legend(loc='upper right')
        
        plt.show()

In [11]:
def get_dataloaders(df, train_idx, valid_idx, max_samples, bs):
    start_time = time()
    train_dl = vae_utils.get_dl(df, train_idx[:max_samples], bs, shuffle=True)
    valid_dl = vae_utils.get_dl(df, valid_idx[:max_samples], bs)
    print(f'Time taken to get dataloaders: {time() - start_time:.2f}s')
    
    return train_dl, valid_dl

In [12]:
def get_pred_mae(reg_col, df, test_idx, bs):
    encoder = Encoder(vae_utils.n_letters, NLATENT, DECODER_HIDDEN_SIZE)
    property_predictor = PropertyPredictor(NLATENT)

    encoder.load_state_dict(torch.load(pathlib.Path(MODELS_DIR, f'{reg_col}_encoder.pth')))
    property_predictor.load_state_dict(torch.load(pathlib.Path(MODELS_DIR, f'{reg_col}_property_predictor.pth')))
    
    dl = vae_utils.get_dl(df, test_idx, bs)
    
    encoder.eval()
    property_predictor.eval()
    
    with torch.no_grad():
        abs_errors = []
        for input_tensors, target_tensors, property_values in dl:
            z, mean, log_var, dec_init_hidden = encoder(input_tensors)
            reg_pred = property_predictor(z)
            abs_error = torch.abs(property_values - reg_pred.flatten())
            abs_errors.append(abs_error)
            
        abs_errors = torch.cat(abs_errors)
        mae_error = abs_errors.mean()
        
    return mae_error

## Define Parameters for Training

In [13]:
valid_pct = .1
bs = 2048
max_samples = 250000
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device("cpu")
# device = torch.device("cpu")
epochs = 120
save_every = 2
anneal_start = 29
lr = 0.0005
load_previous = False

## Train and Test the Network for _logP_

In [14]:
errors = {}

In [None]:
reg_col = 'logP'
df, train_idx, valid_idx, test_idx = vae_utils.get_train_valid_test_splits(reg_col, valid_pct)
train_dl, valid_dl = get_dataloaders(df, train_idx, valid_idx, max_samples, bs)
trainer = Trainer(reg_col)
trainer.fit(epochs, save_every, anneal_start, lr, train_dl, valid_dl, device, load_previous)

100%|██████████| 224018/224018 [04:20<00:00, 859.99it/s]
100%|██████████| 24891/24891 [00:28<00:00, 862.10it/s]
100%|██████████| 546/546 [00:00<00:00, 911.38it/s]


Time taken to get dataloaders: 290.53s


  1%|          | 1/120 [04:30<8:55:52, 270.19s/it]

Epoch:   1 | Train Loss:    2.42376 | Val Loss:    1.38142 | Time Taken: 4m 30s
Train Recon Loss:    1.18179 | Val Recon Loss:    0.89172
Train KLD: 263673.48295 | Val KLD: 397913.70433
Train Reg Loss:    1.24197 | Val Reg Loss:    0.48969



  2%|▏         | 2/120 [08:59<8:50:26, 269.71s/it]

Epoch:   2 | Train Loss:    1.00560 | Val Loss:    0.74065 | Time Taken: 8m 59s
Train Recon Loss:    0.69357 | Val Recon Loss:    0.55042
Train KLD: 477994.25312 | Val KLD: 501746.81130
Train Reg Loss:    0.31203 | Val Reg Loss:    0.19023



  2%|▎         | 3/120 [13:30<8:46:44, 270.12s/it]

Epoch:   3 | Train Loss:    0.65418 | Val Loss:    0.55034 | Time Taken: 13m 30s
Train Recon Loss:    0.48239 | Val Recon Loss:    0.42963
Train KLD: 555387.07116 | Val KLD: 552097.56550
Train Reg Loss:    0.17180 | Val Reg Loss:    0.12071

Epoch:   4 | Train Loss:    0.52617 | Val Loss:    0.47782 | Time Taken: 18m 0s
Train Recon Loss:    0.40147 | Val Recon Loss:    0.37639
Train KLD: 621789.10795 | Val KLD: 622369.07993
Train Reg Loss:    0.12469 | Val Reg Loss:    0.10143



  4%|▍         | 5/120 [22:31<8:38:23, 270.46s/it]

Epoch:   5 | Train Loss:    0.47541 | Val Loss:    0.48934 | Time Taken: 22m 31s
Train Recon Loss:    0.36258 | Val Recon Loss:    0.34613
Train KLD: 666524.28906 | Val KLD: 626949.27524
Train Reg Loss:    0.11282 | Val Reg Loss:    0.14321



  5%|▌         | 6/120 [27:01<8:33:28, 270.25s/it]

Epoch:   6 | Train Loss:    0.43188 | Val Loss:    0.41901 | Time Taken: 27m 1s
Train Recon Loss:    0.33609 | Val Recon Loss:    0.34752
Train KLD: 698316.64830 | Val KLD: 671604.25481
Train Reg Loss:    0.09576 | Val Reg Loss:    0.07147



  6%|▌         | 7/120 [31:32<8:29:19, 270.44s/it]

Epoch:   7 | Train Loss:    0.40787 | Val Loss:    0.38696 | Time Taken: 31m 32s
Train Recon Loss:    0.31935 | Val Recon Loss:    0.30938
Train KLD: 727872.33494 | Val KLD: 699678.96394
Train Reg Loss:    0.08844 | Val Reg Loss:    0.07751

Epoch:   8 | Train Loss:    0.38395 | Val Loss:    0.43490 | Time Taken: 36m 2s
Train Recon Loss:    0.30482 | Val Recon Loss:    0.29961
Train KLD: 760455.57841 | Val KLD: 717516.14663
Train Reg Loss:    0.07892 | Val Reg Loss:    0.13509



  8%|▊         | 9/120 [40:34<8:21:01, 270.83s/it]

Epoch:   9 | Train Loss:    0.37823 | Val Loss:    0.35110 | Time Taken: 40m 34s
Train Recon Loss:    0.29392 | Val Recon Loss:    0.28823
Train KLD: 761147.79034 | Val KLD: 696151.78245
Train Reg Loss:    0.08373 | Val Reg Loss:    0.06234

Epoch:  10 | Train Loss:    0.35688 | Val Loss:    0.33741 | Time Taken: 45m 5s
Train Recon Loss:    0.28476 | Val Recon Loss:    0.28063
Train KLD: 754408.13580 | Val KLD: 713673.07332
Train Reg Loss:    0.07057 | Val Reg Loss:    0.05532



  9%|▉         | 11/120 [49:36<8:12:14, 270.96s/it]

Epoch:  11 | Train Loss:    0.35428 | Val Loss:    0.33398 | Time Taken: 49m 36s
Train Recon Loss:    0.27996 | Val Recon Loss:    0.27508
Train KLD: 714360.29830 | Val KLD: 625623.31611
Train Reg Loss:    0.07032 | Val Reg Loss:    0.05540

Epoch:  12 | Train Loss:    0.35010 | Val Loss:    0.32576 | Time Taken: 54m 6s
Train Recon Loss:    0.27265 | Val Recon Loss:    0.26949
Train KLD: 560804.62670 | Val KLD: 464944.85457
Train Reg Loss:    0.06890 | Val Reg Loss:    0.04919



 11%|█         | 13/120 [58:38<8:03:17, 271.00s/it]

Epoch:  13 | Train Loss:    0.35281 | Val Loss:    0.34637 | Time Taken: 58m 38s
Train Recon Loss:    0.26731 | Val Recon Loss:    0.26553
Train KLD: 353644.58629 | Val KLD: 291051.82843
Train Reg Loss:    0.07086 | Val Reg Loss:    0.06879

Epoch:  14 | Train Loss:    0.36521 | Val Loss:    0.39037 | Time Taken: 63m 9s
Train Recon Loss:    0.26605 | Val Recon Loss:    0.26173
Train KLD: 214898.17173 | Val KLD: 178387.08428
Train Reg Loss:    0.07498 | Val Reg Loss:    0.10856



 12%|█▎        | 15/120 [1:07:40<7:54:27, 271.12s/it]

Epoch:  15 | Train Loss:    0.38132 | Val Loss:    0.35600 | Time Taken: 67m 40s
Train Recon Loss:    0.26018 | Val Recon Loss:    0.25995
Train KLD: 130615.86321 | Val KLD: 106427.67803
Train Reg Loss:    0.08118 | Val Reg Loss:    0.06350

Epoch:  16 | Train Loss:    0.40291 | Val Loss:    0.39077 | Time Taken: 72m 10s
Train Recon Loss:    0.25664 | Val Recon Loss:    0.25484
Train KLD: 72825.77056 | Val KLD: 57597.44404
Train Reg Loss:    0.08572 | Val Reg Loss:    0.08803



 14%|█▍        | 17/120 [1:16:40<7:44:43, 270.71s/it]

Epoch:  17 | Train Loss:    0.43758 | Val Loss:    0.39000 | Time Taken: 76m 41s
Train Recon Loss:    0.25428 | Val Recon Loss:    0.25293
Train KLD: 37024.06325 | Val KLD: 30136.00053
Train Reg Loss:    0.09961 | Val Reg Loss:    0.06895

Epoch:  18 | Train Loss:    0.49727 | Val Loss:    0.57935 | Time Taken: 81m 11s
Train Recon Loss:    0.25746 | Val Recon Loss:    0.25086
Train KLD: 19808.79817 | Val KLD: 18861.79002
Train Reg Loss:    0.11810 | Val Reg Loss:    0.21260



 16%|█▌        | 19/120 [1:25:45<7:37:13, 271.62s/it]

Epoch:  19 | Train Loss:    0.58217 | Val Loss:    0.51458 | Time Taken: 85m 45s
Train Recon Loss:    0.24857 | Val Recon Loss:    0.24845
Train KLD: 11205.97558 | Val KLD: 10033.31150
Train Reg Loss:    0.14645 | Val Reg Loss:    0.09856

Epoch:  20 | Train Loss:    0.74688 | Val Loss:    0.68686 | Time Taken: 90m 16s
Train Recon Loss:    0.24603 | Val Recon Loss:    0.24633
Train KLD: 6410.26089 | Val KLD: 6211.34798
Train Reg Loss:    0.20983 | Val Reg Loss:    0.15854



 18%|█▊        | 21/120 [1:34:49<7:29:00, 272.12s/it]

Epoch:  21 | Train Loss:    0.96316 | Val Loss:    0.81306 | Time Taken: 94m 50s
Train Recon Loss:    0.24435 | Val Recon Loss:    0.24477
Train KLD: 3772.97719 | Val KLD: 3253.62854
Train Reg Loss:    0.25324 | Val Reg Loss:    0.16681

Epoch:  22 | Train Loss:    1.67968 | Val Loss:    1.29435 | Time Taken: 99m 21s
Train Recon Loss:    0.24222 | Val Recon Loss:    0.24237
Train KLD: 2821.88906 | Val KLD: 1987.66743
Train Reg Loss:    0.49114 | Val Reg Loss:    0.38541



 19%|█▉        | 23/120 [1:43:52<7:19:04, 271.59s/it]

Epoch:  23 | Train Loss:    2.29750 | Val Loss:    2.02929 | Time Taken: 103m 52s
Train Recon Loss:    0.24066 | Val Recon Loss:    0.24026
Train KLD: 1085.85870 | Val KLD:  898.92842
Train Reg Loss:    1.06757 | Val Reg Loss:    0.97006

Epoch:  24 | Train Loss:    4.13956 | Val Loss:    2.44916 | Time Taken: 108m 22s
Train Recon Loss:    0.24237 | Val Recon Loss:    0.25369
Train KLD:  751.62386 | Val KLD:   47.88138
Train Reg Loss:    2.03871 | Val Reg Loss:    2.07707



 21%|██        | 25/120 [1:52:53<7:09:13, 271.09s/it]

Epoch:  25 | Train Loss:    2.51757 | Val Loss:    2.49287 | Time Taken: 112m 53s
Train Recon Loss:    0.23832 | Val Recon Loss:    0.23721
Train KLD:   28.81396 | Val KLD:   26.60832
Train Reg Loss:    2.08640 | Val Reg Loss:    2.07757

Epoch:  26 | Train Loss:    5.71003 | Val Loss:    4.33630 | Time Taken: 117m 23s
Train Recon Loss:    0.23529 | Val Recon Loss:    0.23605
Train KLD:  188.13993 | Val KLD:  111.70041
Train Reg Loss:    2.09081 | Val Reg Loss:    2.09119



 22%|██▎       | 27/120 [2:01:54<6:59:42, 270.78s/it]

Epoch:  27 | Train Loss:    7.38696 | Val Loss:   12.92414 | Time Taken: 121m 54s
Train Recon Loss:    0.23385 | Val Recon Loss:    0.23424
Train KLD:  106.78203 | Val KLD:  223.76273
Train Reg Loss:    2.08889 | Val Reg Loss:    2.07775



In [None]:
trainer.plot_losses()

In [None]:
reg_mean_train = df.iloc[train_idx].reg_col.mean()
mean_mae_test = np.mean(np.abs(df.iloc[test_idx].reg_col - reg_mean_train))
vae_mae_test = get_pred_mae(reg_col, df, test_idx, bs).item()
errors[reg_col] = {'mean_mae_test': mean_mae_test, 'vae_mae_test': vae_mae_test}

## Train and Test the Network for _QED_

In [None]:
reg_col = 'qed'
df, train_idx, valid_idx, test_idx = vae_utils.get_train_valid_test_splits(reg_col, valid_pct)
train_dl, valid_dl = get_dataloaders(df, train_idx, valid_idx, max_samples, bs)
trainer = Trainer(reg_col)
trainer.fit(epochs, save_every, anneal_start, lr, train_dl, valid_dl, device, load_previous)

In [None]:
trainer.plot_losses()

In [None]:
reg_mean_train = df.iloc[train_idx].reg_col.mean()
mean_mae_test = np.mean(np.abs(df.iloc[test_idx].reg_col - reg_mean_train))
vae_mae_test = get_pred_mae(reg_col, df, test_idx, bs).item()
errors[reg_col] = {'mean_mae_test': mean_mae_test, 'vae_mae_test': vae_mae_test}

## Report the Results

In [None]:
pd.DataFrame(errors).transpose()

## Remarks

The model was trained in 4 stages. Except for the 1st stage, the other two stages were trained after loading the saving parameters from the previous stage.
The stages are summarized below:

1. max_exmaples = 20,000, epochs = 50, anneal_start = 15, load_previous = False, bs = 200
2. max_examples = 100,000, epochs = 15, anneal_start = 5, load_previous = True, bs = 200
3. max_examples = all, epochs = 10, anneal_start = 3, load_previous = True, bs = 200
4. max_examples = all, epochs = 15, anneal_start = 5, load_previous = True, bs = 1000
