In [1]:
%config InlineBackend.figure_format='retina'
%matplotlib inline
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
import arviz as az

import numpy as np
from tabulate import tabulate
import pandas as pd
import pickle
import torch
import numpy as np
from sklearn import preprocessing

import numpyro
from numpyro.infer import MCMC, NUTS, Predictive
import numpyro.distributions as dist
from jax import random

NUM_CHAINS = 3
numpyro.set_host_device_count(NUM_CHAINS)

print(f"Running on NumPryo v{numpyro.__version__}")

Running on NumPryo v0.8.0


In [2]:
CODE_DIR = "/home/cbarkhof/fall-2021"
import sys; sys.path.append(CODE_DIR)
from analysis.analysis_steps import make_run_overview_df

prefixes = ["(mdr-vae-exp 8 oct)", "(fb-vae-exp 8 oct) ", "(beta-vae-exp 6 oct) ", "(inf-vae-exp 5 oct) "]
run_df = make_run_overview_df(prefixes=prefixes, add_data_group=False)
run_df.drop("run_name", axis=1)

Unnamed: 0,objective,l_rate,beta_beta,free_bits,mdr_value,l_mmd,decoder
MDR-VAE 40 dec: CNN.T,MDR-VAE,0,0,0,40,0,basic_deconv_decoder
MDR-VAE 32 dec: CNN.T,MDR-VAE,0,0,0,32,0,basic_deconv_decoder
MDR-VAE 24 dec: CNN.T,MDR-VAE,0,0,0,24,0,basic_deconv_decoder
MDR-VAE 16 dec: CNN.T,MDR-VAE,0,0,0,16,0,basic_deconv_decoder
MDR-VAE 8 dec: CNN.T,MDR-VAE,0,0,0,8,0,basic_deconv_decoder
...,...,...,...,...,...,...,...
INFO-VAE l_Rate 100 l_MMD 100 dec: CNN.T,INFO-VAE,100,0,0,0,100,basic_deconv_decoder
INFO-VAE l_Rate 1000 l_MMD 1 dec: CNN.T,INFO-VAE,1000,0,0,0,1,basic_deconv_decoder
INFO-VAE l_Rate 100 l_MMD 1000 dec: PixelCNN++,INFO-VAE,100,0,0,0,1000,cond_pixel_cnn_pp
INFO-VAE l_Rate 1 l_MMD 1000 dec: CNN.T,INFO-VAE,1,0,0,0,1000,basic_deconv_decoder


In [3]:
ANALYSIS_DIR = f"{CODE_DIR}/analysis/analysis-files"
TEST_VALID_EVAL_FILE = "test-valid-results.pt"
KNN_PREDICT_STATS_FILE = "knn-preds-stats.pickle"
DATA_SPACE_STATS = "data_space_stats.pickle"
REDO_MMD_RESULT_FILE = "redo_mmd.pt"

In [4]:
Y_COLS_ALL = ["y_L0_sample", "y_L2_sample", "y_KL_sample"]
X_COLS_ALL = ["distortion mean", "kl_prior_post mean", "elbo mean"]

# Predictors X
X = []

# Some statistics on samples Y that we want to model
y_L0_sample = [] # data_space_stats: L_0_all
y_KL_sample = [] # knn_predict_stats kl_instance_marg_pred
y_L2_sample = [] # data_space_stats: L2_all_data

run_names, clean_names = [], []

# How many Y variables per model type
N_samples_per_model = 100

for row_index, row in run_df.iterrows():
    run_name = row["run_name"]
    clean_name = row_index
    clean_names.append(clean_name)
    run_names.append(run_name)
    save_dir = f"{ANALYSIS_DIR}/{run_name}"
    
    # Get X features
    test_valid = torch.load(f"{save_dir}/{TEST_VALID_EVAL_FILE}")
    stats = []
    for m in X_COLS_ALL:
        stats.append(test_valid["test"][m])
    
    # get MMD from somewhere else
    mmd_redo = torch.load(f"{save_dir}/{REDO_MMD_RESULT_FILE}")
    stats.append(mmd_redo["mmd_redo_test"])
    X.append(stats)
    
    # Get y features
    knn_stats = pickle.load(open(f"{save_dir}/{KNN_PREDICT_STATS_FILE}", "rb"))
    data_stats = pickle.load(open(f"{save_dir}/{DATA_SPACE_STATS}", "rb"))
    
    y_KL_sample.append(knn_stats["samples"]["kl_instance_marg_pred"][:N_samples_per_model]) # <- max N samples per model
    y_L0_sample.append(data_stats["samples"]["L0_all"][:N_samples_per_model]) # <- max N samples per model
    y_L2_sample.append(data_stats["samples"]["L2_all"][:N_samples_per_model]) # <- max N samples per model

X_COLS_ALL.append("MMD test")

# Construct data DF
X = np.array(X)
run_names = np.array(run_names)[:, None]
clean_names = np.array(clean_names)[:, None]
columns = ["run_name", "clean_name"] + X_COLS_ALL
data = np.concatenate([run_names, clean_names, X], axis=1)
df_data = pd.DataFrame(data, columns=columns)
df_data["y_L0_sample"] = y_L0_sample
df_data["y_L2_sample"] = y_L2_sample
df_data["y_KL_sample"] = y_KL_sample

# Merge with config DF
all_df = run_df.merge(df_data, left_on='run_name', right_on='run_name')

# Explode with y values (copy X and configs) -> len becomes N_samples times N_models in run_df
all_df = all_df.apply(lambda x: x.explode() if x.name in Y_COLS_ALL else x)
assert len(all_df) == len(run_df) * N_samples_per_model, "len(all_df) should be equal to len(run_df) * N_samples_per_model"
all_df

Unnamed: 0,objective,l_rate,beta_beta,free_bits,mdr_value,l_mmd,decoder,run_name,clean_name,distortion mean,kl_prior_post mean,elbo mean,MMD test,y_L0_sample,y_L2_sample,y_KL_sample
0,MDR-VAE,0,0,0,40,0,basic_deconv_decoder,(mdr-vae-exp 8 oct) MDR-VAE[R>=40.0] | q(z|x) ...,MDR-VAE 40 dec: CNN.T,67.41467658996582,39.4882266998291,-106.90290313720703,0.14857995510101318,69.000000,65.980026,1.505243
0,MDR-VAE,0,0,0,40,0,basic_deconv_decoder,(mdr-vae-exp 8 oct) MDR-VAE[R>=40.0] | q(z|x) ...,MDR-VAE 40 dec: CNN.T,67.41467658996582,39.4882266998291,-106.90290313720703,0.14857995510101318,7.000000,59.059410,0.359951
0,MDR-VAE,0,0,0,40,0,basic_deconv_decoder,(mdr-vae-exp 8 oct) MDR-VAE[R>=40.0] | q(z|x) ...,MDR-VAE 40 dec: CNN.T,67.41467658996582,39.4882266998291,-106.90290313720703,0.14857995510101318,19.000000,75.636040,0.333165
0,MDR-VAE,0,0,0,40,0,basic_deconv_decoder,(mdr-vae-exp 8 oct) MDR-VAE[R>=40.0] | q(z|x) ...,MDR-VAE 40 dec: CNN.T,67.41467658996582,39.4882266998291,-106.90290313720703,0.14857995510101318,11.000000,67.817955,0.160689
0,MDR-VAE,0,0,0,40,0,basic_deconv_decoder,(mdr-vae-exp 8 oct) MDR-VAE[R>=40.0] | q(z|x) ...,MDR-VAE 40 dec: CNN.T,67.41467658996582,39.4882266998291,-106.90290313720703,0.14857995510101318,11.000000,71.190613,0.592991
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
75,INFO-VAE,1000,0,0,0,100,cond_pixel_cnn_pp,"(inf-vae-exp 5 oct) INFO-VAE[l_1_rate=1000.0, ...",INFO-VAE l_Rate 1000 l_MMD 100 dec: PixelCNN++,80.36876289367676,4.050742158142384e-05,-80.36880111694336,9.417533874511719e-05,56.000000,58.664700,1.878116
75,INFO-VAE,1000,0,0,0,100,cond_pixel_cnn_pp,"(inf-vae-exp 5 oct) INFO-VAE[l_1_rate=1000.0, ...",INFO-VAE l_Rate 1000 l_MMD 100 dec: PixelCNN++,80.36876289367676,4.050742158142384e-05,-80.36880111694336,9.417533874511719e-05,89.000000,56.172970,2.260882
75,INFO-VAE,1000,0,0,0,100,cond_pixel_cnn_pp,"(inf-vae-exp 5 oct) INFO-VAE[l_1_rate=1000.0, ...",INFO-VAE l_Rate 1000 l_MMD 100 dec: PixelCNN++,80.36876289367676,4.050742158142384e-05,-80.36880111694336,9.417533874511719e-05,52.000000,62.154228,1.878065
75,INFO-VAE,1000,0,0,0,100,cond_pixel_cnn_pp,"(inf-vae-exp 5 oct) INFO-VAE[l_1_rate=1000.0, ...",INFO-VAE l_Rate 1000 l_MMD 100 dec: PixelCNN++,80.36876289367676,4.050742158142384e-05,-80.36880111694336,9.417533874511719e-05,115.000000,64.045074,2.086671


In [5]:
all_df.to_csv("BDA-analysis-data.csv")