## Setup

In [None]:
import sys, os
sys.path.append(os.path.abspath('../src'))
sys.path.append(os.path.abspath('..'))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda:0')  # target device for the models rebuilt below
torch.set_grad_enabled(False)  # evaluation only: no gradient tracking anywhere

import numpy as np
from scipy import stats
import matplotlib.pyplot as plt
from datetime import datetime
from tqdm import tqdm

# Custom lib
from model import LatentNeuralODEBuilder
from utils import gpu, asnp

In [None]:
class GenericSet(Dataset):
    """Map-style dataset that pairs every sample with one shared time grid.

    ``self[idx]`` returns ``(data[idx], time)``; the same ``time`` object is
    returned for every index.
    """

    def __init__(self, data, time):
        self.data = data
        self.time = time

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.time


class SineSet(GenericSet):
    """Dataset used for the synthetic sine experiments.

    Behaves exactly like ``GenericSet``; it is kept as a separate name so code
    that refers to ``SineSet`` keeps working without a duplicated implementation.
    """

In [None]:
def parse_experiment_files(path_dir):
    """Group the saved model files in ``path_dir`` by dataset and experiment.

    File names are split on ``_`` and read as
    ``{dataset}_{model}_{x}_{params...}_{run}``, where the third token ``x``
    (``lode`` in the names ``load_models`` rebuilds) is dropped. Hidden files
    (leading ``.``) are ignored.

    Returns:
        dict mapping dataset -> {experiment key -> list of run ids}, where the
        experiment key is ``'{model}_{params}'``. Files are visited in sorted
        order so the run lists are reproducible; ``os.listdir`` order is
        arbitrary.
    """
    experiments = {}

    for name in sorted(os.listdir(path_dir)):
        if name.startswith('.'):
            continue

        parts = name.split('_')
        dataset = parts[0]
        # Keep the model token, skip the third token, drop the trailing run id.
        exp_key = '_'.join(parts[1:2] + parts[3:-1])
        exp_run = parts[-1]

        experiments.setdefault(dataset, {}).setdefault(exp_key, []).append(exp_run)

    return experiments

def load_models(experiments, dataset, target, dir_path):
    """Load every saved run of one experiment with ``torch.load``.

    Args:
        experiments: mapping produced by ``parse_experiment_files``.
        dataset: dataset prefix of the file names (e.g. ``'aussign'``).
        target: experiment key ``'{model}_{params}'`` (e.g. ``'miwae_3_8'``).
        dir_path: directory that holds the saved files.

    Returns:
        List with one ``torch.load`` result per run, in the order the runs are
        listed in ``experiments[dataset][target]``.

    Raises:
        KeyError: if ``dataset`` or ``target`` is not present in ``experiments``.
    """
    if dataset not in experiments:
        raise KeyError('unknown dataset {!r}'.format(dataset))
    if target not in experiments[dataset]:
        raise KeyError('unknown experiment {!r} for dataset {!r}'.format(target, dataset))

    # target is '{model}_{params}'; files are named '{dataset}_{model}_lode_{params}_{run}'.
    model_name, _, params = target.partition('_')

    models = []
    for run in experiments[dataset][target]:
        file_name = '{}_{}_lode_{}_{}'.format(dataset, model_name, params, run)
        # torch.load unpickles arbitrary objects: only load trusted files.
        # NOTE(review): torch>=2.6 defaults to weights_only=True, which may
        # reject checkpoints holding custom objects (e.g. 'train_obj') — confirm.
        models.append(torch.load(os.path.join(dir_path, file_name)))

    return models

def load_model_params(model_data, exp_type, inds=None):
    """Rebuild latent ODE models from saved checkpoints.

    Args:
        model_data: list of checkpoint dicts, each holding ``'model_args'``
            (builder kwargs) and ``'model_state_dict'``.
        exp_type: experiment key; its first ``_`` token picks the latent node
            variant (``base`` and ``betavae`` are built as ``iwae``).
        inds: indices into ``model_data`` to rebuild; all of them when ``None``.

    Returns:
        List of models moved to ``device`` with their saved weights loaded.
    """
    indices = range(len(model_data)) if inds is None else inds

    variant = exp_type.split('_')[0]
    # These experiment families use the 'iwae' latent node.
    if variant in ['base', 'betavae']:
        variant = 'iwae'

    rebuilt = []
    for idx in indices:
        entry = model_data[idx]
        node = LatentNeuralODEBuilder(**entry['model_args']).build_latent_node(variant).to(device)
        node.load_state_dict(entry['model_state_dict'])
        rebuilt.append(node)

    return rebuilt

In [None]:
dataset = 'aussign'
model_path = '../models'
exp_type = 'miwae_3_8'

experiments = parse_experiment_files(model_path)
model_data = load_models(experiments, dataset, exp_type, model_path)

# Load the test split referenced by the first run's checkpoint.
if dataset == 'sine':
    # Synthetic data: regenerate the test set from the stored generator.
    generator = torch.load('../' + model_data[0]['data_path'])['generator']
    test_time, test_data = generator.get_test_set()
    # Add a trailing singleton feature axis: (len(test_data), -1, 1).
    test_data = test_data.reshape(len(test_data), -1, 1)
    dataset_cls = SineSet
else:
    data = torch.load('../' + model_data[0]['data_path'])
    test_data = data['test_dataset']
    # Time points are the integer indices 0 .. test_data.shape[1] - 1.
    test_time = list(range(test_data.shape[1]))
    # Real datasets use GenericSet (the previous code reused SineSet here).
    dataset_cls = GenericSet

test_data_tt = gpu(test_data)
test_time_tt = gpu(test_time)

test_dataset = dataset_cls(test_data_tt, test_time_tt)
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset))

## Loss Calculation

In [None]:
def get_loss(exp_type, loaded_models, data, time, args):
    """Compute reconstruction MSE and ELBO of each model on one batch.

    Args:
        exp_type: experiment key; its first ``_`` token selects how the
            reconstructions in ``forward(...)[0]`` are averaged (``base`` and
            ``betavae`` are treated as ``iwae``).
        loaded_models: models exposing ``forward(data, time, args)`` and
            ``get_elbo(data, *forward_outputs, args)``.
        data: batch tensor; ``data.shape[0..2]`` are read for PIWAE reshaping.
        time: time points passed through to ``forward``.
        args: evaluation arguments; ``'M'`` and ``'K'`` are read for PIWAE.

    Returns:
        ``(mses, elbos)``: two lists of Python floats, one entry per model.
    """
    elbo_type = exp_type.split('_')[0]
    if elbo_type in ['base', 'betavae']:
        elbo_type = 'iwae'

    mse_fn = nn.MSELoss()
    mses = []
    elbos = []

    for model in loaded_models:
        out = model.forward(data, time, args)
        elbo = model.get_elbo(data, *out, args)
        recon = out[0]

        if elbo_type in ['miwae', 'ciwae']:
            # Average over the two sample axes (dim 2, then dim 1).
            pred_x = recon.mean(2).mean(1)
        elif elbo_type == 'iwae':
            pred_x = recon.mean(1)
        elif elbo_type == 'piwae':
            # Reshape to (batch, M, K, data.shape[1], data.shape[2]) and average out M and K.
            pred_x = recon.view(data.shape[0], args['M'], args['K'], data.shape[1], data.shape[2])
            pred_x = pred_x.mean(1).mean(1)
            # Report the average of both PIWAE objectives. The original code
            # averaged elbo[0] with itself, ignoring elbo[1].
            # NOTE(review): assumes get_elbo returns a pair here — confirm.
            elbo = (elbo[0] + elbo[1]) / 2
        else:
            pred_x = recon

        mse = mse_fn(data, pred_x)

        elbos.append(elbo.item())
        mses.append(mse.item())

    return mses, elbos

def get_loss_batch(exp_type, loaded_models, data, time, args, batch_size=None):
    """Run ``get_loss`` on all of ``data`` or on equal-sized chunks of it.

    Despite its name, ``batch_size`` is the NUMBER of chunks: ``data`` is
    reshaped to ``(batch_size, len(data) // batch_size, *data.shape[1:])`` and
    ``get_loss`` is evaluated once per chunk.

    Args:
        exp_type, loaded_models, time, args: forwarded to ``get_loss``.
        data: tensor whose first axis indexes samples.
        batch_size: number of equal-sized chunks; ``None``/``0`` evaluates all
            of ``data`` in a single call.

    Returns:
        ``(mses, elbos)``. Without chunking: the per-model lists from
        ``get_loss``. With chunking: numpy arrays with, per model, the mean of
        the per-chunk values.

    Raises:
        ValueError: if ``len(data)`` is not divisible by ``batch_size``.
    """
    if not batch_size:
        return get_loss(exp_type, loaded_models, data, time, args)

    n_samples = data.shape[0]
    if n_samples % batch_size != 0:
        raise ValueError(
            'cannot split {} samples into {} equal chunks'.format(n_samples, batch_size))

    chunks = data.reshape(batch_size, n_samples // batch_size, *data.shape[1:])
    mses = []
    elbos = []

    for chunk in tqdm(chunks):
        chunk_mses, chunk_elbos = get_loss(exp_type, loaded_models, chunk, time, args)
        mses.append(chunk_mses)
        elbos.append(chunk_elbos)

    # Averaging per-chunk values is only sound because every chunk has the
    # same size (enforced above).
    return np.mean(mses, axis=0), np.mean(elbos, axis=0)

In [None]:
def compute_losses(experiments, dataset, model_path, data, time):
    """Score every experiment of ``dataset`` on ``data``.

    Each experiment's runs are loaded and rebuilt. They are then evaluated with
    ``get_loss_batch`` (no chunking), using the stored training arguments with
    ``M=1``, ``K=250`` and ``beta=0`` written over them.

    Returns:
        dict mapping experiment key -> ``[mses, elbos]``.
    """
    eval_overrides = {'M': 1, 'K': 250, 'beta': 0}
    results = {}

    for exp_name in tqdm(experiments[dataset]):
        run_data = load_models(experiments, dataset, exp_name, model_path)
        models = load_model_params(run_data, exp_name)

        eval_args = run_data[0]['train_args']
        eval_args.update(eval_overrides)
        exp_mses, exp_elbos = get_loss_batch(exp_name, models, data, time, eval_args, None)

        # Drop references so the checkpoints/models can be collected before
        # the next experiment is loaded.
        del models
        del run_data

        results[exp_name] = [exp_mses, exp_elbos]

    return results


losses = compute_losses(experiments, dataset, model_path, test_data_tt, test_time_tt)

In [None]:
best_model = 'base_1_1'
best_mses = losses[best_model][0]
best_elbos = losses[best_model][1]

# Runs whose test MSE is at or above this value are treated as outliers and
# left out of the MSE statistics.
mse_outlier_threshold = 3

# Filter the reference model the same way as every compared model, so the
# MSE t-test compares like with like. The original code filtered only the
# compared model.
best_mses_filtered = [m for m in best_mses if m < mse_outlier_threshold]

sorted_losses = dict(sorted(losses.items(), key=lambda item: np.mean(item[1][0])))

for exp_type, (mses, elbos) in sorted_losses.items():
    mses = [m for m in mses if m < mse_outlier_threshold]
    print('Experiment: {}'.format(exp_type))
    print("MSE: {:.4f}±{:.4f}".format(np.mean(mses), np.std(mses)))
    print("log p(x): {:.2f}±{:.2f}".format(np.mean(elbos), np.std(elbos)))

    p_val_elbo = stats.ttest_ind(best_elbos, elbos).pvalue
    p_val_mse = stats.ttest_ind(best_mses_filtered, mses).pvalue
    print('mse p-value vs best: {}'.format(p_val_mse))
    print('log p(x) p-value vs best: {}'.format(p_val_elbo))
    print('---')

## Training Dynamics

In [None]:
def get_train_times(experiments, dataset, dir_path=None):
    """Print and return the average epoch time of each experiment of ``dataset``.

    For every run, the mean of ``train_obj.runtimes`` is taken; those per-run
    means are then averaged over the experiment's runs.

    Args:
        experiments: mapping produced by ``parse_experiment_files``.
        dataset: dataset key into ``experiments``.
        dir_path: directory of saved models. When omitted, the global
            ``model_path`` is looked up at call time, which was the original
            (implicit) behavior.

    Returns:
        dict experiment key -> average epoch time, ordered by ascending time.
    """
    if dir_path is None:
        dir_path = model_path

    runtimes = {}
    for exp in experiments[dataset]:
        runs = load_models(experiments, dataset, exp, dir_path)
        runtimes[exp] = np.mean([np.mean(run['train_obj'].runtimes) for run in runs])

    sorted_runtimes = dict(sorted(runtimes.items(), key=lambda item: item[1]))
    for name, avg_time in sorted_runtimes.items():
        print("Exp: {} Avg Epoch Time: {}".format(name, avg_time))

    return sorted_runtimes

In [None]:
runtimes = get_train_times(experiments=experiments, dataset=dataset)  # prints a sorted timing table

In [None]:
def moving_average(x, w):
    """Return the simple moving average of ``x`` over windows of ``w`` points.

    Only full windows are used, so the result has ``len(x) - w + 1`` entries.
    It is empty when ``x`` is shorter than ``w``. Without this guard,
    ``np.convolve`` in ``'valid'`` mode would swap its arguments and return
    meaningless partial sums.

    Raises:
        ValueError: if ``w`` is not positive.
    """
    if w <= 0:
        raise ValueError('window size must be positive, got {}'.format(w))
    x = np.asarray(x)
    if len(x) < w:
        return np.empty(0)
    return np.convolve(x, np.ones(w), 'valid') / w

def plot_val_loss(ax, data, title, trunc_l=0, trunc_r=400, window=15):
    """Plot the smoothed mean validation loss of several runs with a ±1 std band.

    Args:
        ax: matplotlib Axes to draw on.
        data: loaded checkpoints, each with a ``'train_obj'`` exposing
            ``val_loss_hist``.
        title: legend label; ``'base_1_1'`` is drawn as a thick black line.
        trunc_l, trunc_r: slice ``[trunc_l:trunc_r]`` applied to each history.
        window: moving-average window length (defaults to the former fixed 15).
    """
    smoothed = [
        moving_average(np.array(d['train_obj'].val_loss_hist[trunc_l:trunc_r]), window)
        for d in data
    ]

    mean_val = np.mean(smoothed, 0)
    std_val = np.std(smoothed, 0)
    time = range(len(smoothed[0]))

    if title == 'base_1_1':
        (line,) = ax.plot(time, mean_val, label=title, c='k', lw=3)
    else:
        (line,) = ax.plot(time, mean_val, label=title)

    # Without an explicit color, fill_between picks its own color from the
    # property cycle, which need not match the line. The black reference
    # curve does not use the cycle at all, so pass the line's color.
    ax.fill_between(time, mean_val + std_val, mean_val - std_val,
                    color=line.get_color(), alpha=0.2)

In [None]:
target_models = ['miwae_3_8', 'base_1_1', 'iwae_1_5', 'iwae_1_25', 'piwae_3_8', 'ciwae_6_4_0.5']

# Two stacked panels showing the same curves: early epochs on top, a zoomed
# later window below.
fig, ax = plt.subplots(2, 1, figsize=(15, 10))
top, bottom = ax[0], ax[1]

for exp in target_models:
    data = load_models(experiments, dataset, exp, model_path)
    for panel in (top, bottom):
        plot_val_loss(panel, data, exp)

top.set_xlim(0, 150)
top.legend(loc='right')

bottom.set_ylim(-400, 1000)
bottom.set_xlim(150, 375)

for panel in (top, bottom):
    panel.set_xlabel('Epoch')
    panel.set_ylabel('- ELBO')

plt.show()

## Visualization

In [None]:
dataset = 'aussign'
model_path = '../models'
exp_type = 'iwae_1_5'
# Which saved run to visualise; the same run's training args are used with it.
model_idx = 1

experiments = parse_experiment_files(model_path)
model_data = load_models(experiments, dataset, exp_type, model_path)

# Rebuild only the run that is visualised instead of every run.
loaded_models = load_model_params(model_data, exp_type, inds=[model_idx])
model = loaded_models[0]
# Pair the weights with the arguments saved alongside them. The original code
# used run 0's train_args with run 1's weights.
args = model_data[model_idx]['train_args']

In [None]:
# Synthetic pair of sinusoids on an evenly spaced grid (the random grid that
# was immediately overwritten has been removed).
tp = np.linspace(0, 7, 200)
s1 = torch.Tensor(np.sin(tp * 3) * 5)
s2 = torch.Tensor(np.sin(tp * 3) * -1)
plt.plot(tp, s1)
plt.plot(tp, s2)
plt.plot(tp, s1 + s2, ls='--')

# Column vectors of shape (200, 1) on the model device.
data_1 = s1.view(-1, 1).to(device)
data_2 = s2.view(-1, 1).to(device)
tp = torch.Tensor(tp).to(device)

In [None]:
# Pick two test series and plot them together with their element-wise sum.
data_1 = test_data_tt[0]
data_2 = test_data_tt[25]
tp = test_time_tt

t_np = asnp(test_time_tt)
for series in (data_1, data_2):
    plt.plot(t_np, asnp(series))

plt.plot(t_np, asnp(data_2 + data_1), ls='--')
plt.show()

In [None]:
# Encode both series. forward(...)[1] is averaged over axis 1 to get one
# latent per series.
# NOTE(review): assumes out[1] holds the sampled initial latents z0 — confirm.
out1 = model.forward(data_1.unsqueeze(0), tp, args)
out2 = model.forward(data_2.unsqueeze(0), tp, args)

z0_1 = torch.mean(out1[1], 1)
z0_2 = torch.mean(out2[1], 1)

# Combine the two latents by their mean (an alternative is z0_1 + z0_2).
z0_comb = torch.mean(torch.cat([z0_1, z0_2]), dim=0)

pred_z = model.generate_from_latent(z0_comb, tp, args['model_rtol'],
                                    args['model_atol'], args['method'])
pred_x = model.dec(pred_z)

plt.plot(asnp(tp), asnp(pred_x[0]))
# Reference: element-wise sum of the two inputs. Using data_1 + data_2 works
# for both the sine pair and the test-set pair. s1 + s2 exists only for the
# sine cell and does not match the test time grid.
plt.plot(asnp(tp), asnp(data_1 + data_2), ls='--')
plt.show()

In [None]:
# Reconstruct one test series and compare a single feature dimension.
ind = 120    # index of the test series
d_ind = 2    # feature column to plot
out_1 = model.forward(test_data_tt[ind].unsqueeze(0), test_time_tt, args)

# Average reconstructions over axis 1, as get_loss does for 'iwae'-type models.
# NOTE(review): assumes an IWAE-style model; MIWAE/CIWAE outputs would need
# averaging over dims 2 and 1 instead.
out = torch.mean(out_1[0], 1)

plt.plot(asnp(test_time_tt), asnp(out)[0][:, d_ind], ls='--')
plt.plot(asnp(test_time_tt), asnp(test_data_tt[ind])[:, d_ind])
plt.show()