In [1]:
# Load IPython's autoreload extension and switch it to mode 2, which reloads
# imported modules before each cell is executed, so edits to the project
# sources imported in this notebook take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from functools import partial

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt

import ray
from ray import tune
from ray.tune import CLIReporter
from ray.tune.schedulers import ASHAScheduler

from NormalizingFlows.src.train import train_forward_with_tuning, train_forward
from NormalizingFlows.src.utils import update_device, load_best_model, load_checkpoint_model
from NormalizingFlows.src.flows import create_flows

from NormalizingFlows.src.structure.ar import AR
from NormalizingFlows.src.structure.iar import IAR
from NormalizingFlows.src.structure.twoblock import TwoBlock

from NormalizingFlows.src.transforms.affine import Affine
from NormalizingFlows.src.transforms.piecewise import PiecewiseAffine
from NormalizingFlows.src.transforms.piecewise_additive import PiecewiseAffineAdditive
from NormalizingFlows.src.transforms.piecewise_affine import PiecewiseAffineAffine


In [None]:
# Device handles used by the rest of the notebook: `device_cpu` always refers
# to the CPU, while `device` is the CUDA device when one is available and the
# CPU otherwise.
device_cpu = torch.device("cpu")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [None]:
# Settings shared by every flow built in the construction cell; all three
# values are forwarded to create_flows.
# NOTE(review): `dataset` is not defined in any cell visible here - it must
# already exist in the kernel when this cell runs; confirm where it is loaded.
perm_type = "random"
num_trans = 5
dim_input = dataset.dim_input

In [None]:
# Hidden-layer widths handed to create_flows for every flow built in this cell.
dim_hidden = [126, 126, 105, 105, 20, 10]

flows = []
names = []


def _add_flow(label, structure, transformation, n_trans):
    """Build one flow with create_flows and register it under `label`.

    The flow is created with the shared dim_input, dim_hidden and perm_type
    settings and flow_forward=False, then appended to `flows`; `label` is
    appended to `names` at the matching position.
    """
    built = create_flows(dim_input, dim_hidden, n_trans, perm_type, flow_forward=False,
                         structure=structure, transformation=transformation)
    flows.append(built)
    names.append(label)


# Candidate models: uncomment a line to include that variant in the run.
# Autoregressive (AR) structure:
#_add_flow('PAF', AR, PiecewiseAffine, num_trans)
#_add_flow('PAFAd', AR, PiecewiseAffineAdditive, num_trans)
#_add_flow('PAFAf', AR, PiecewiseAffineAffine, num_trans)
_add_flow('MAF', AR, Affine, num_trans)
#_add_flow('MAF-double', AR, Affine, 2*num_trans)

# TwoBlock structure:
#_add_flow('Real NVP', TwoBlock, Affine, num_trans)
#_add_flow('Real NVP-double', TwoBlock, Affine, 2*num_trans)
#_add_flow('TwoBlock-PAF', TwoBlock, PiecewiseAffine, num_trans)
#_add_flow('TwoBlock-PAFAd', TwoBlock, PiecewiseAffineAdditive, num_trans)
#_add_flow('TwoBlock-PAFAf', TwoBlock, PiecewiseAffineAffine, num_trans)

# Tag each flow object with its label so it can be identified later.
for ind, flow in enumerate(flows):
    flow.name = names[ind]

In [None]:
# Optional hyper-parameter search with Ray Tune. The whole cell is a no-op
# unless `tuning` is switched to True.
tuning = False
if tuning:
    losses = []
    optimizers = []

    epochs = 200
    batch_size = 16
    num_hyperparam_samples = 4

    # Search space: learning rate and weight decay, both sampled log-uniformly.
    config = {
        'lr': tune.loguniform(1e-4, 1e-1),
        'weight_decay': tune.loguniform(1e-5, 1e-1)
    }

    for ind, flow in enumerate(flows):
        # The scheduler and reporter are stateful objects, so fresh ones are
        # built for every flow. Reusing a single ASHAScheduler across several
        # tune.run calls would let the losses recorded for one flow's trials
        # influence the early-stopping cutoffs of the next flow.
        scheduler = ASHAScheduler(
            time_attr='training_iteration',
            metric="loss",
            mode='min',
            max_t=epochs,
            grace_period=100,
            reduction_factor=2
        )
        reporter = CLIReporter(
            metric_columns=['loss', 'training_iteration']
        )

        # Move the flow and dataset to the CPU before handing them to Ray.
        update_device(device_cpu, flow, dataset)

        # `partial` (functools, imported in the notebook's import cell) binds
        # everything except the sampled hyper-parameters.
        # NOTE(review): print_n=epochs+1 presumably silences per-epoch
        # printing inside the trainable - confirm against train_forward_with_tuning.
        # NOTE(review): `result` is overwritten on every iteration, so only the
        # last flow's result object is still bound after the loop.
        result = tune.run(
            partial(train_forward_with_tuning, model=flow, dataset=dataset, epochs=epochs,
                    batch_size=batch_size, print_n=epochs+1, name=names[ind]),
            config=config,
            num_samples=num_hyperparam_samples,
            scheduler=scheduler,
            progress_reporter=reporter,
            verbose=0
        )

        # NOTE(review): the flow is moved to device_cpu again here and `device`
        # is never used in this cell - confirm whether this second call was
        # meant to target `device` instead.
        update_device(device_cpu, flow, dataset)
