In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

from made import MADE

from train import train_forward
from scores import log_likelihood
from scores import difference_loglik

from utils import update_device

In [None]:
from data.bsds300 import BSDS300
from data.gas import Gas
from data.hepmass import Hepmass
from data.miniboone import Miniboone
from data.power import Power

In [None]:
from flows import create_iaf
from flows import create_maf
from flows import create_paf
from flows import create_realnvp
from flows import create_flows

In [None]:
from structure.ar import AR
from structure.iar import IAR
from structure.twoblock import TwoBlock

from transforms.affine import Affine
from transforms.piecewise import PiecewiseAffine
from transforms.piecewise_additive import PiecewiseAffineAffine

In [None]:
# Train on the GPU when one is available; `device_cpu` is where the flows and
# the dataset are moved back to after each training run.
device_cpu = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Benchmark dataset under study; uncomment exactly one alternative to switch.
# dataset = BSDS300()
# dataset = Gas()
# dataset = Hepmass()
# dataset = Miniboone()
dataset = Power()

# Report split sizes and dimensionality of the chosen dataset.
split_report = ('Training size:', dataset.train_n,
                'Validation size:', dataset.valid_n,
                'Test size:', dataset.test_n)
print(*split_report)
print('Dimension:', dataset.dim_input)

In [None]:
# Flow architecture hyperparameters, all forwarded to the flow constructors below.
dim_input = dataset.dim_input  # dimensionality taken from the chosen dataset
dim_hidden = [10] * 3          # hidden-layer widths (presumably of each MADE conditioner - confirm)
num_trans = 5                  # number of stacked transformations
perm_type = 'random'           # permutation scheme name passed to the flow constructors

In [None]:
# Build the flows to compare; `flows[i]` is displayed under the name `names[i]`.
# Seed the RNGs first so weight initialisation and the 'random' permutations are
# reproducible when this cell is re-run. NOTE(review): assumes the flow
# constructors draw randomness from torch and/or NumPy - confirm.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

flows, names = [], []

# Alternative flows (uncomment to include them in the comparison):
# flows.append(create_iaf(dim_input, dim_hidden, num_trans, perm_type))
# names.append('IAF')
# flows.append(create_maf(dim_input, dim_hidden, num_trans, perm_type))
# names.append('MAF')
# flows.append(create_paf(dim_input, dim_hidden, num_trans, perm_type, structures=AR))
# names.append('PAF')
# flows.append(create_realnvp(dim_input, dim_hidden, num_trans, perm_type))
# names.append('Real NVP')
#
# Mixed transformation stack, alternating Affine / PiecewiseAffine:
# transformations = [Affine, PiecewiseAffine] * 5
# flows.append(create_flows(dim_input, dim_hidden, num_trans, perm_type,
#                           structure=AR, transformation=transformations))
# names.append('PAF/IAF')

# Separate statements (rather than `a.append(...), b.append(...)`) so the cell
# does not end in a tuple expression that Jupyter would display as `(None, None)`.
flows.append(create_flows(dim_input, dim_hidden, num_trans, perm_type,
                          structure=AR, transformation=PiecewiseAffineAffine))
names.append('PAAF')

In [None]:
# Train every flow on the training split and record its loss history.
# NOTE: re-running this cell continues training the *existing* flows (they are
# not re-created here) while resetting `losses` and the optimizers; re-run the
# flow-construction cell first for a clean run.
losses = []
optimizers = []

epochs = 10
batch_size = 800
for flow, name in zip(flows, names):
    update_device(device, flow, dataset)
    try:
        optimizer = torch.optim.AdamW(flow.parameters(), lr=1e-3, weight_decay=1e-2)
        # optimizer = torch.optim.SGD(flow.parameters(), lr=1e-3)
        flow_losses = train_forward(flow, flow.get_base_distr(), dataset.get_training_data(),
                                    optimizer, epochs, batch_size, print_n=1, name=name)
    finally:
        # Always move the flow and data back to the CPU, even if training is
        # interrupted, so the CPU-side evaluation cells see a consistent device.
        update_device(device_cpu, flow, dataset)
    # Append the optimizer and the losses only once training has finished. This
    # keeps `optimizers` and `losses` index-aligned with `flows`, which the
    # continued-training cell relies on.
    optimizers.append(optimizer)
    losses.append(flow_losses)

In [None]:
# Continue training each flow with its existing optimizer, so e.g. AdamW moment
# estimates carry over, and extend its loss history.
# A separate name is used so the main training cell's `epochs` is not clobbered.
extra_epochs = 5
if extra_epochs > 0:
    # The three lists must be index-aligned; a mismatch means the training cell
    # was interrupted or not run.
    assert len(flows) == len(optimizers) == len(losses), \
        "flows, optimizers and losses are out of sync - re-run the training cell"
    for i, (flow, optimizer) in enumerate(zip(flows, optimizers)):
        update_device(device, flow, dataset)
        try:
            extra_losses = train_forward(flow, flow.get_base_distr(), dataset.get_training_data(),
                                         optimizer, extra_epochs, batch_size, print_n=10, name=names[i])
        finally:
            # Restore the CPU placement even if training is interrupted.
            update_device(device_cpu, flow, dataset)
        # Concatenate explicitly: `+=` would add element-wise (or fail on a
        # shape mismatch) if the history were a NumPy array or tensor rather
        # than a list.
        losses[i] = list(losses[i]) + list(extra_losses)

In [None]:
# Plot the recorded loss histories. The top panel shows the full history. The
# bottom panel drops the first `num_epoch_skip` entries, which helps when the
# early losses dwarf the rest.
log_scale = False   # log-scale the lower panel only (cannot show non-positive losses)
num_epoch_skip = 0

fig, (ax_full, ax_zoom) = plt.subplots(2, 1, figsize=(10, 8))
for name, loss in zip(names, losses):
    ax_full.plot(loss, label=name, alpha=0.8)
    # Pass explicit x positions so the trimmed curve keeps its original indices
    # instead of restarting at 0 after the skipped entries.
    ax_zoom.plot(range(num_epoch_skip, len(loss)), loss[num_epoch_skip:], label=name, alpha=0.8)

ax_full.set(title='Training loss (full history)', xlabel='Recorded step', ylabel='Loss')
ax_zoom.set(title=f'Training loss (first {num_epoch_skip} steps skipped)',
            xlabel='Recorded step', ylabel='Loss')
ax_full.legend()
ax_zoom.legend()

if log_scale:
    ax_zoom.set_yscale('log')

fig.tight_layout()
plt.show()

In [None]:
# Report each flow's mean log-likelihood on the training split.
print('Results based on training data:' + '\n')

train_data = dataset.get_training_data()
for i, model in enumerate(flows):
    log_lik, mean = log_likelihood(train_data, model)
    message = "Mean loglikelihood for " + names[i] + ":" + str(mean)
    print(message)

In [None]:
# Report each flow's mean log-likelihood on the held-out validation split.
print('Results based on validation data' + '\n')

valid_data = dataset.get_validation_data()
for i, model in enumerate(flows):
    log_lik, mean = log_likelihood(valid_data, model)
    message = "Mean loglikelihood for " + names[i] + ":" + str(mean)
    print(message)

In [None]:
# Report each flow's mean log-likelihood on the test split.
print('Results based on test data' + '\n')

test_data = dataset.get_test_data()
for i, model in enumerate(flows):
    log_lik, mean = log_likelihood(test_data, model)
    message = "Mean loglikelihood for " + names[i] + ":" + str(mean)
    print(message)