In [None]:
import numpy as np
import torch as torch
from helper import set_seeds
from tqdm import tqdm
from utils.q_model_ens import MultivariateQuantileModel
from torch.utils.data import DataLoader, TensorDataset
from losses import multivariate_qr_loss
from helper import generate_directions
from plot_helper import evaluate_conditional_performance
import argparse
import os
import warnings
from datasets import datasets
from transformations import CVAETransform, ConditionalIdentityTransform
from directories_names import get_cvae_model_save_name, get_save_final_figure_results_dir, get_model_summary_save_dir, \
    get_save_final_results_dir
from main import *
import ast
from argparse import Namespace
import matplotlib
from sys import platform

if platform not in ['win32', 'darwin']:
    matplotlib.use('Agg')

warnings.filterwarnings("ignore")



%load_ext autoreload
%autoreload 2

device_name = 'cuda:0' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)


# Parameter settings

In [None]:
seed = 0

# Experiment configuration. The raw namespace is handed straight to parse_args_utils
# (not imported by name; presumably it arrives through `from main import *`).
# Later cells read args.tau_list and args.conformalization_tau, which are not set
# here, so that call presumably derives them — TODO confirm.
args = parse_args_utils(Namespace(
    # --- run identity / data ---
    seed=seed,
    tau=0.1,
    dataset_name='nonlinear_cond_banana_k_dim_1',
    ds_type='SYN',              # treated as real data only if it contains "real"
    test_ratio=0.2,
    calibration_ratio=0.4,

    # --- quantile model and optimisation ---
    num_ep=10000,               # upper bound on training epochs
    hs="[64, 64, 64]",          # forwarded as the model's hidden_dimensions
    dropout=0.,
    lr=1e-3,
    wd=0,
    bs=256,                     # DataLoader batch size
    wait=100,                   # forwarded as num_wait to the validation update
    num_ens=1,
    num_u=32,                   # forwarded to generate_directions

    # --- hardware / misc (not read by the cells in this notebook) ---
    device=device,
    gpu=1,
    boot=0,
    suppress_plots=0,

    # --- transformation ---
    transform='CVAE',           # 'identity', 'VAE' or 'CVAE'
    vae_loss='KL',
    vae_z_dim=3,
    vae_mode='CVAE',
))

# Seed from the processed namespace, in case parse_args_utils adjusted it.
seed = args.seed
set_seeds(seed)

# Load dataset

In [None]:
dataset_name = args.dataset_name

# Split proportions; the validation share is fixed here rather than taken from args.
test_ratio = args.test_ratio
calibration_ratio = args.calibration_ratio
val_ratio = 0.2

print("dataset_name: ", dataset_name, "transformation: ", args.transform,
      f"tau: {args.tau}, conformalization tau: {args.conformalization_tau}, seed={args.seed}")

# Only datasets flagged as real are scaled.
is_real = 'real' in args.ds_type.lower()
scale = is_real

data = datasets.get_split_data(
    dataset_name, is_real, device,
    test_ratio, val_ratio, calibration_ratio,
    seed, scale,
)

# Unpack the splits pair by pair.
x_train, y_train = data['x_train'], data['y_train']
x_val, y_val = data['x_val'], data['y_val']
x_test, y_te = data['x_test'], data['y_te']

scale_x = data['scale_x']
scale_y = data['scale_y']
x_dim = x_train.shape[1]

# The calibration split is only read when a calibration share was requested.
if calibration_ratio > 0:
    x_cal, y_cal = data['x_cal'], data['y_cal']


# Learn the transformation

## Run `train_vae.py` beforehand to learn a transformation (CVAE) between $\mathcal{Y}$ and $\mathcal{Z}$.

# Transformation

In [None]:
# Choose the map applied to the responses before the quantile model is fit.
if args.transform == 'identity':
    transform = ConditionalIdentityTransform()
elif args.transform in ("VAE", "CVAE"):
    # Both spellings build the same CVAE-backed transform; the model name is derived
    # from the dataset, seed and VAE settings (presumably a checkpoint written by
    # train_vae.py — confirm).
    transform = CVAETransform(
        get_cvae_model_save_name(dataset_name, seed, args.vae_loss, args.vae_z_dim, args.vae_mode), device=device)
else:
    # Raise instead of `assert False`: asserts are stripped under `python -O`, which
    # would leave `transform` undefined and fail later with a confusing NameError.
    raise ValueError(
        f"transform must be one of 'identity', 'VAE', 'CVAE'; got {args.transform!r}")

# Always start from the untouched split stored in `data`, so that re-running this
# cell does not transform already-transformed responses or overwrite
# `untransformed_y_train` with transformed values.
untransformed_y_train = data['y_train']
y_train = transform.cond_transform(untransformed_y_train, x_train)
y_val = transform.cond_transform(data['y_val'], x_val)


# Train the model

In [None]:
# Output width of the model follows the (transformed) training responses.
dim_y = y_train.shape[1]

# A larger grid is used when the original response has three or more dimensions.
if untransformed_y_train.shape[1] >= 3:
    y_grid_size = 3e5
else:
    y_grid_size = 3e3

model_ens = MultivariateQuantileModel(
    input_size=x_dim,
    y_size=dim_y,
    hidden_dimensions=args.hs,
    dropout=args.dropout,
    lr=args.lr,
    wd=args.wd,
    num_ens=args.num_ens,
    device=args.device,
    y_grid_size=y_grid_size,
)

# Data loader: reshuffled mini-batches over the training pairs.
loader = DataLoader(TensorDataset(x_train, y_train), batch_size=args.bs, shuffle=True)

# Loss function
loss_fn = multivariate_qr_loss
batch_loss = True
# The loop below only ever reads tau_list[0], so exactly one tau is supported.
assert len(args.tau_list) == 1
eval_losses = []   # one validation loss per completed epoch
train_losses = []  # one mean training loss per completed epoch

# Moving the validation targets to the device does not depend on the epoch,
# so do it once here instead of on every iteration.
y_val = y_val.to(args.device)

for ep in tqdm(range(args.num_ep)):

    # The model raises this flag itself (presumably from update_va_loss once the
    # validation loss has stalled for `num_wait` epochs — confirm).
    if model_ens.done_training:
        break

    # Take train step
    ep_train_loss = []  # list of losses from each batch, for one epoch
    for batch in loader:
        # Directions are regenerated for every batch; gamma is handed to the loss via args.
        u_list, gamma = generate_directions(dim_y, args.num_u, args.tau_list[0])
        args.gamma = gamma

        (xi, yi) = batch
        loss = model_ens.loss(loss_fn, xi, yi, u_list,
                              batch_q=batch_loss,
                              take_step=True, args=args)

        ep_train_loss.append(loss)

    # Average over batches, ignoring NaN batch losses.
    ep_tr_loss = np.nanmean(np.stack(ep_train_loss, axis=0), axis=0).item()
    train_losses.append(ep_tr_loss)

    # Validation loss (evaluated on a fresh set of directions)
    u_list, gamma = generate_directions(dim_y, args.num_u, args.tau_list[0])
    args.gamma = gamma

    ep_va_loss = model_ens.update_va_loss(
        loss_fn, x_val, y_val, u_list,
        batch_q=batch_loss, curr_ep=ep, num_wait=args.wait,
        args=args)
    eval_losses.append(ep_va_loss.item())


# Keywords that identify this run; all three directory helpers receive the same set.
params = dict(
    dataset_name=dataset_name,
    transformation=transform,
    epoch=model_ens.best_va_ep[0],
    is_real=is_real,
    seed=seed,
    tau=args.conformalization_tau,
    vae_loss=args.vae_loss,
    vae_z_dim=args.vae_z_dim,
    dropout=args.dropout,
    hs=str(args.hs),
    vae_mode=args.vae_mode,
)

base_save_dir = get_save_final_figure_results_dir(**params)
base_results_save_dir = get_save_final_results_dir(**params)
summary_base_save_dir = get_model_summary_save_dir(**params)


# Evaluate the model

In [None]:
# Evaluate the trained model on the test split before any conformal calibration.
evaluate_conditional_performance(
    model_ens,
    x_train, untransformed_y_train, y_train,
    x_test, y_te,
    base_save_dir,
    transform,
    is_conformalized=False,
    args=args,
    dataset_name=dataset_name,
    scale_x=scale_x,
    scale_y=scale_y,
    cache=None,
    summary_base_save_dir=summary_base_save_dir,
    base_results_save_dir=base_results_save_dir,
    is_real=is_real,
)


# Conformalize marginal coverage

In [None]:
# x_cal / y_cal only exist when a calibration split was requested at load time.
assert calibration_ratio > 0
model_ens.conformalize(
    x_cal, y_cal,
    untransformed_y_train, y_train,
    transform,
    args.conformalization_tau,
    args.tau,
)

# Evaluate conformalized model

In [None]:
# Same evaluation as for the raw model, now flagged as conformalized.
evaluate_conditional_performance(
    model_ens,
    x_train, untransformed_y_train, y_train,
    x_test, y_te,
    base_save_dir,
    transform,
    is_conformalized=True,
    args=args,
    dataset_name=dataset_name,
    scale_x=scale_x,
    scale_y=scale_y,
    cache=None,
    summary_base_save_dir=summary_base_save_dir,
    base_results_save_dir=base_results_save_dir,
    is_real=is_real,
)
