In [1]:
# Boilerplate to allow running in Colab; can be ignored when running locally.
# If `tailnflows` is not importable, install it from GitHub, configure the
# Colab environment, and make torch default to the GPU when one is attached.
import subprocess
import sys

try:
    import tailnflows
except ModuleNotFoundError:
    # need to build environment
    print('installing tailnflows environment...')
    # `sys.executable -m pip` installs into the interpreter running this kernel
    # (a bare `pip` on PATH may belong to a different environment), and
    # check=True raises CalledProcessError on a failed install instead of
    # failing later with a confusing ModuleNotFoundError on the next line.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install',
         'git+https://github.com/Tennessee-Wallaceh/tailnflows'],
        check=True,
    )
    import tailnflows
    from tailnflows.utils import configure_colab_env
    configure_colab_env()  # can take a while
    import torch
    # Only switch the default device to CUDA when a GPU is actually present,
    # so a CPU-only runtime does not fail on the first tensor allocation.
    if torch.cuda.is_available():
        torch.set_default_device('cuda')

In [3]:
import torch
from tailnflows.models.flows import (
    ModelUse, TTF_m, gTAF, mTAF, _df_to_tailp
)

def _base_kwargs(**overrides):
    """Return the ``model_kwargs`` shared by every flow variant.

    Every variant uses ``tail_bound=5.``, ``rotation=False`` and ``num_bins=8``;
    ``overrides`` supplies the per-variant extras. A new dict is built on each
    call, so no two models share the same kwargs object.
    """
    shared = dict(tail_bound=5., rotation=False, num_bins=8)
    shared.update(overrides)
    return shared


# Each builder takes (dim, dfs) so the experiment loop can call them uniformly
# as `model_fcn(dim, dfs)`; builders that do not use the per-dimension
# degrees of freedom accept and ignore the second argument.
def _ttf_m_fixed(dim, dfs):
    """TTF_m with both tails initialised from ``dfs`` via ``_df_to_tailp``.

    Passes ``fix_tails=True`` (presumably keeps the tail parameters fixed
    during training — confirm in tailnflows).
    """
    tail_kwargs = _base_kwargs(
        pos_tail_init=_df_to_tailp(dfs),
        neg_tail_init=_df_to_tailp(dfs),
        fix_tails=True,
    )
    return TTF_m(dim, ModelUse.variational_inference, model_kwargs=tail_kwargs)


def _ttf_m(dim, _):
    """TTF_m with default tail initialisation (``dfs`` unused)."""
    return TTF_m(dim, ModelUse.variational_inference, model_kwargs=_base_kwargs())


def _mtaf(dim, dfs):
    """mTAF whose ``tail_init`` is the ``dfs`` list converted to a tensor."""
    tail_kwargs = _base_kwargs(tail_init=torch.tensor(dfs))
    return mTAF(dim, ModelUse.variational_inference, model_kwargs=tail_kwargs)


def _gtaf(dim, _):
    """gTAF with only the shared kwargs (``dfs`` unused)."""
    return gTAF(dim, ModelUse.variational_inference, model_kwargs=_base_kwargs())


# Label -> builder; the labels are what get recorded with each result.
models = {
    'TTF_m_fixed': _ttf_m_fixed,
    'TTF_m': _ttf_m,
    'mTAF': _mtaf,
    'gTAF': _gtaf,
}

In [4]:
import functools
import itertools

from tailnflows.train.variational_fit import train
from tailnflows.targets.heavy_tailed_nuisance import log_density
from tailnflows.experiments.utils import add_raw_data

# experiment parameters
target_dims = [5]
nuisance_dfs = [1., 2., 30.]
repeats = 5
# Materialise the grid: a bare itertools.product iterator is exhausted after
# one pass, so any later cell reusing it would silently see nothing.
experiment_params = list(itertools.product(range(repeats), target_dims, nuisance_dfs))

# training configuration shared by every model / grid point
learning_rate = 1e-3
num_epochs = 300
batch_size = 100
# label passed to add_raw_data (presumably names the stored results set — confirm)
results_label = 'vi_newrun'

# Fit every model in `models` at every (repeat, dim, nuisance df) grid point
# and record (ess, psis_k, dim, nuisance_df) for each fit.
for repeat, dim, nuisance_df in experiment_params:
    print('=' * 20, f'v={nuisance_df}, dim={dim}', '=' * 20)
    dfs = [nuisance_df] * dim  # for shift data set all dfs are the same
    # Bind this grid point's df explicitly instead of closing over the loop
    # variable, so the target cannot pick up a later iteration's value.
    target = functools.partial(log_density, heavy_df=nuisance_df)
    for label, model_fcn in models.items():
        # Re-seed before building each model so every variant within a
        # repeat starts from the same RNG state.
        torch.manual_seed(repeat)

        model = model_fcn(dim, dfs).to(torch.float32)

        losses, metrics = train(
            model,
            target,
            lr=learning_rate,
            num_epochs=num_epochs,
            batch_size=batch_size,
            label=label,
            seed=repeat,
        )

        add_raw_data(results_label, label, (metrics.ess, metrics.psis_k, dim, nuisance_df))





100%|██████████| 300/300 [00:04<00:00, 63.05it/s, elbo=-125.500, tst_ess=0.013, tst_psis=0.939, model=TTF_m_fixed]
100%|██████████| 300/300 [00:04<00:00, 64.01it/s, elbo=-0.480, tst_ess=0.053, tst_psis=0.730, model=TTF_m]
100%|██████████| 300/300 [00:04<00:00, 63.90it/s, elbo=-31.329, tst_ess=0.010, tst_psis=0.930, model=mTAF]
100%|██████████| 300/300 [00:05<00:00, 56.21it/s, elbo=-0.788, tst_ess=0.001, tst_psis=0.821, model=gTAF]




100%|██████████| 300/300 [00:04<00:00, 63.45it/s, elbo=-1.935, tst_ess=0.040, tst_psis=0.768, model=TTF_m_fixed]
100%|██████████| 300/300 [00:05<00:00, 54.37it/s, elbo=-0.332, tst_ess=0.178, tst_psis=0.697, model=TTF_m]
100%|██████████| 300/300 [00:05<00:00, 58.22it/s, elbo=-2.372, tst_ess=0.087, tst_psis=0.818, model=mTAF]
100%|██████████| 300/300 [00:05<00:00, 54.05it/s, elbo=-0.276, tst_ess=0.213, tst_psis=0.837, model=gTAF]




100%|██████████| 300/300 [00:05<00:00, 59.37it/s, elbo=-0.170, tst_ess=0.822, tst_psis=0.365, model=TTF_m_fixed]
100%|██████████| 300/300 [00:05<00:00, 58.27it/s, elbo=-0.338, tst_ess=0.542, tst_psis=0.550, model=TTF_m]
100%|██████████| 300/300 [00:04<00:00, 60.64it/s, elbo=-0.122, tst_ess=0.353, tst_psis=0.794, model=mTAF]
100%|██████████| 300/300 [00:05<00:00, 57.99it/s, elbo=-0.242, tst_ess=0.733, tst_psis=0.487, model=gTAF]




100%|██████████| 300/300 [00:04<00:00, 61.73it/s, elbo=-2.346, tst_ess=0.001, tst_psis=0.713, model=TTF_m_fixed]
100%|██████████| 300/300 [00:04<00:00, 62.16it/s, elbo=-0.318, tst_ess=0.005, tst_psis=0.713, model=TTF_m]
100%|██████████| 300/300 [00:05<00:00, 52.51it/s, elbo=-79.737, tst_ess=0.023, tst_psis=0.878, model=mTAF]
100%|██████████| 300/300 [00:06<00:00, 48.56it/s, elbo=-1.721, tst_ess=0.002, tst_psis=0.839, model=gTAF]




100%|██████████| 300/300 [00:04<00:00, 62.54it/s, elbo=-1.182, tst_ess=0.001, tst_psis=0.714, model=TTF_m_fixed]
 36%|███▌      | 108/300 [00:02<00:03, 52.99it/s, neg_elbo=0.94, model=TTF_m]


KeyboardInterrupt: 