In [None]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import sys
import os
import yaml
from pathlib import Path
from tbparse import SummaryReader
from torch.utils.data import DataLoader
import json
import tqdm

from pinf.datasets.gradients import (
    dS_2D_ToyExample_two_parameters_dalpha,
    dS_2D_ToyExample_two_parameters_dbeta
)

from pinf.datasets.datasets import DataSet_2D_ToyExample_external_two_parameters

from pinf.datasets.log_likelihoods import log_p_2D_ToyExample_two_parameters

from pinf.plot.utils import (
    eval_pdf_on_grid_2D,
    plot_pdf_2D
)

from pinf.models.construct_INN_2D_GMM_two_parameters import set_up_sequence_INN_2D_ToyExample_two_parameters

from pinf.models.INN import INN_Model__MultipleExternalParameters

# Device used below for INN sampling / log-prob evaluation and ground-truth density evaluation.
# NOTE(review): load_INN has its own default device ("cuda:0") and is called without an explicit
# device argument, so changing this value alone does not move the models — keep the two in sync.
device = "cuda:0"

In [None]:
# Parameter grids: alpha in descending order, beta spanning 0.2 .. 5.0 (not log-spaced).
alpha_list = [0.9,0.8,0.7,0.6,0.5,0.4,0.3,0.2,0.1]
beta_list = [0.2,0.25,1/3,0.5,1.0,2.0,3.0,4.0,5.0]

Load the model

---

In [None]:
def load_INN(
    base_path:str,
    use_last:bool = False,
    device:str = "cuda:0",
    data_base_path:str = "../../data/2D_Toy_two_external_parameters/",
)->"tuple[INN_Model__MultipleExternalParameters, dict]":
    """Rebuild a trained INN from a run folder and load its checkpoint.

    Args:
        base_path: Run folder (e.g. ``.../lightning_logs/version_0``) that contains
            ``hparams.yaml`` and a ``checkpoints/`` sub-folder.
        use_last: If True, load ``checkpoints/last.ckpt``. Otherwise load the first
            file (in sorted order) whose name starts with ``checkpoint_epoch``.
        device: Device string written into the config as ``config["device"]``
            before the model is constructed.
        data_base_path: Data folder passed to the training data set, which is
            needed to set up the model.

    Returns:
        ``(INN, config)``: the model with ``train(False)`` applied and the loaded
        hyperparameter dict.

    Raises:
        FileNotFoundError: If the checkpoint folder is missing or holds no matching file.
    """
    state_dict_folder_i = os.path.join(base_path, "checkpoints")

    # Sorted so the choice is deterministic; os.listdir order is arbitrary.
    files = sorted(os.listdir(state_dict_folder_i))

    if use_last:
        # The last recorded state dict
        matches = [f for f in files if f == "last.ckpt"]
    else:
        # Presumably the best-performing state dict (assumes only the top checkpoint is kept
        # under the "checkpoint_epoch" prefix — TODO confirm against the training callback).
        matches = [f for f in files if f.startswith("checkpoint_epoch")]

    if not matches:
        wanted = "last.ckpt" if use_last else "checkpoint_epoch*"
        raise FileNotFoundError(f"No checkpoint matching '{wanted}' found in {state_dict_folder_i}")

    state_dict_path_i = os.path.join(state_dict_folder_i, matches[0])

    config_i = yaml.safe_load(Path(base_path, "hparams.yaml").read_text())
    config_i["device"] = device

    DS_training = DataSet_2D_ToyExample_external_two_parameters(**config_i["config_data"]["init_data_set_params"],base_path=data_base_path)

    INN_i = set_up_sequence_INN_2D_ToyExample_two_parameters(config=config_i,training_set=DS_training)
    # The project INN's load_state_dict is called with a checkpoint path, not a dict.
    INN_i.load_state_dict(state_dict_path_i)
    INN_i.train(False)

    print(state_dict_path_i)

    return INN_i,config_i

Get the best run

---

In [None]:
def find_best_run(experiment_folder_list:list[str])->"str | None":
    """Find the run with the lowest logged ``model_performance/mean_validation_KL``.

    Every entry of every folder in ``experiment_folder_list`` is treated as a run and read with
    tbparse's ``SummaryReader``. Runs without finite values for that tag are skipped.

    Args:
        experiment_folder_list: Folders (e.g. ``lightning_logs``) whose entries are runs.

    Returns:
        Path of the best run, or None if no run logged a finite value. The result is
        also printed.
    """
    best_folder = None
    best_mean_validation_KL = None

    for experiment_folder in experiment_folder_list:

        # Sorted for a deterministic tie-break; os.listdir order is arbitrary.
        for subfolder in sorted(os.listdir(experiment_folder)):

            full_path = os.path.join(experiment_folder,subfolder)

            print(full_path)

            df_k = SummaryReader(full_path,extra_columns={"wall_time"}).scalars
            if df_k.empty or "tag" not in df_k.columns:
                print("  no scalar logs found, skipping")
                continue

            kl_k = df_k.loc[df_k["tag"] == "model_performance/mean_validation_KL","value"].to_numpy(dtype=float)

            # Drop NaN/inf entries: a NaN minimum would never compare smaller and would lock
            # the result onto a diverged run.
            kl_k = kl_k[np.isfinite(kl_k)]
            if kl_k.size == 0:
                print("  no finite mean_validation_KL values, skipping")
                continue

            min_kl_k = kl_k.min()
            if (best_mean_validation_KL is None) or (min_kl_k < best_mean_validation_KL):
                best_folder = full_path
                best_mean_validation_KL = min_kl_k

    if best_folder is None:
        print("\n\nNo run with finite model_performance/mean_validation_KL values found")
        return None

    print("\n\nBest folder")
    print(best_folder)

    return best_folder

In [None]:
# Reverse KL: report the run with the lowest mean validation KL.
get_best_models = True

if get_best_models:
    # "..." is a placeholder — point this at the experiment's lightning_logs folder.
    experiment_folder = "results/.../lightning_logs/"
    folder_list = [experiment_folder]
    find_best_run(folder_list)

In [None]:
# Reverse KL + NLL: report the run with the lowest mean validation KL.
if get_best_models:
    # "..." is a placeholder — point this at the experiment's lightning_logs folder.
    experiment_folder = "results/.../lightning_logs/"
    folder_list = [experiment_folder]
    find_best_run(folder_list)

In [None]:
# Run folder for each training objective; replace "<Your experiment name>" per entry.
path_dict = dict(
    TRADE_no_grid="../../results/runs_2D_ToyExample_two_external_parameters/<Your experiment name>/lightning_logs/version_0",
    NLL_only="../../results/runs_2D_ToyExample_two_external_parameters/<Your experiment name>/lightning_logs/version_0",
    reverse_KL="../../results/runs_2D_ToyExample_two_external_parameters/<Your experiment name>/lightning_logs/version_0",
    reverse_KL_NLL="../../results/runs_2D_ToyExample_two_external_parameters/<Your experiment name>/lightning_logs/version_0",
)

Load the models

---

In [None]:
# Restore one model (and its hparams config) per entry of path_dict.
model_dict = {}
config_dict = {}

for k,run_path_k in path_dict.items():
    INN_k,config_k = load_INN(base_path = run_path_k)
    model_dict[k],config_dict[k] = INN_k,config_k

Plot the relative effective sample size (ESS) as a function of the two external parameters

---

In [None]:
def get_ESS_r(log_p_theta_INN:torch.Tensor,log_p_target_INN:torch.Tensor)->float:
    """Relative Kish effective sample size of importance weights w = p_target / p_theta.

    Computes (sum w)^2 / (sum w^2) / N entirely in log space (via logsumexp) so that
    large or tiny weights do not overflow or underflow.

    Args:
        log_p_theta_INN: Model log-densities of the N samples (1-D).
        log_p_target_INN: Target log-densities of the same N samples (1-D).

    Returns:
        ESS / N as a Python float.
    """
    log_w = log_p_target_INN - log_p_theta_INN
    n_samples = log_w.shape[0]

    log_ratio = 2 * torch.logsumexp(log_w,0) - torch.logsumexp(2 * log_w,0)

    return (torch.exp(log_ratio) / n_samples).item()

In [None]:
# Parameter grid for the ESS maps (beta log-spaced between 0.25 and 4).
alpha_ESS = np.linspace(0.2,0.8,100)
beta_ESS = np.exp(np.linspace(np.log(0.25),np.log(4.0),100))

bs_ESS = 10000
samples_ESS = 2000

# Ceiling division: draw at least samples_ESS points in whole batches of bs_ESS.
# Truncating int(samples_ESS / bs_ESS) gives 0 here, which leaves the log-prob tensors
# empty and makes every ESS estimate NaN.
# TODO(review): with these values a single batch of 10000 samples is drawn — confirm that
# bs_ESS / samples_ESS are not swapped.
n_batches_ESS = -(-samples_ESS // bs_ESS)
assert n_batches_ESS > 0, f"n_batches_ESS must be positive (bs_ESS={bs_ESS}, samples_ESS={samples_ESS})"

In [None]:
# For each model, estimate the relative ESS on the (alpha_ESS x beta_ESS) grid by sampling from
# the INN and comparing its log-density with the ground-truth log-density. Each map is cached to
# ./ESS_map_{key}.pt.
ESS_map_dict = {}

# With zero batches the log-prob tensors stay empty and every ESS entry silently becomes NaN.
assert n_batches_ESS > 0, f"n_batches_ESS must be positive (bs_ESS={bs_ESS}, samples_ESS={samples_ESS})"

for key in model_dict:

    print(key)

    with torch.no_grad():

        ESS_map = torch.zeros([len(alpha_ESS),len(beta_ESS)])

        for i,alpha_i in enumerate(alpha_ESS):
            for j,beta_j in enumerate(beta_ESS):

                # Collect per-batch results and concatenate once instead of growing a tensor.
                log_p_theta_batches = []
                log_p_target_batches = []

                for _ in range(n_batches_ESS):

                    # INN samples
                    x_INN_l = model_dict[key].sample(bs_ESS,[alpha_i,beta_j])

                    # INN density
                    log_p_theta_batches.append(model_dict[key].log_prob(x_INN_l,[alpha_i,beta_j]).detach().cpu())

                    # Ground-truth density
                    log_p_target_batches.append(log_p_2D_ToyExample_two_parameters(x_INN_l,[alpha_i,beta_j],device=device).detach().cpu())

                ESS_map[i,j] = get_ESS_r(
                    log_p_theta_INN=torch.cat(log_p_theta_batches,0),
                    log_p_target_INN=torch.cat(log_p_target_batches,0)
                )

        ESS_map_dict[key] = ESS_map

        torch.save(ESS_map,f"./ESS_map_{key}.pt")

In [None]:
# Plot the cached ESS maps (one panel per model, shared colour scale) and save to ./Figures/ESS_r.jpeg.
os.makedirs("./Figures",exist_ok=True)

# Reload from disk so the figure can be regenerated without recomputing the ESS maps.
ESS_map_dict = {key: torch.load(f"./ESS_map_{key}.pt") for key in model_dict}

# Plot the ESS
fs = 15
label_dict = {
    "TRADE_no_grid":"TRADE (no grid)",
    "NLL_only":"NLL",
    "reverse_KL":"Reverse KL",
    "reverse_KL_NLL": "Reverse KL + NLL"
}
fig_ESS,axs_ESS = plt.subplots(2,2,figsize=(12,10))
axes_flat = axs_ESS.reshape(-1)
assert len(ESS_map_dict) <= len(axes_flat), "more models than panels in the 2x2 ESS figure"

x_grid,y_grid = np.meshgrid(beta_ESS,alpha_ESS)

# Shared colour limits so panels are directly comparable; nan-aware so a single failed grid
# point does not turn the limits into NaN.
min_ESS = min(np.nanmin(ESS_map.numpy()) for ESS_map in ESS_map_dict.values())
max_ESS = max(np.nanmax(ESS_map.numpy()) for ESS_map in ESS_map_dict.values())

for ax,(key,ESS_map_i) in zip(axes_flat,ESS_map_dict.items()):

    s = ax.pcolormesh(x_grid,y_grid,ESS_map_i.numpy(),cmap = "jet",vmin = min_ESS, vmax = max_ESS)
    ax.set_title(label_dict[key],fontsize=fs)
    ax.set_xscale("log")
    ax.set_xticks([0.25,0.5,1.0,2.0,4.0],[0.25,0.5,1.0,2.0,4.0])
    ax.tick_params(axis='both', which='major', labelsize=fs)
    ax.set_xlabel(r"$\beta$",fontsize = fs)
    ax.set_ylabel(r"$\alpha$",fontsize = fs)

    cbar = fig_ESS.colorbar(s, ax=ax)
    cbar.ax.tick_params(labelsize=fs)

fig_ESS.tight_layout()

fig_ESS.savefig(
        os.path.join("./Figures","ESS_r.jpeg"),
        bbox_inches='tight',
        dpi = 300
)

plt.close(fig_ESS)

Plot the ground truth density:

---

In [None]:
# Ground-truth densities on a grid of (alpha, beta) panels: alpha decreases down the rows,
# beta increases along the columns. Saved to ./Figures/densities_gt.jpeg.
alpha_list_plots = [0.2,0.35,0.5,0.65,0.8]
alpha_list_plots.reverse()
beta_list_plots = [0.2,0.5,1.0,2.0,5.0]

# Evaluation window and grid resolution (also used by the INN density plots below).
x_lims = [-15,15]
y_lims =  [-15,15]

x_res = 250
y_res = 250

os.makedirs("./Figures",exist_ok=True)

fig_densities,axes_densities = plt.subplots(len(alpha_list_plots),len(beta_list_plots),figsize = [4 * len(beta_list_plots),4 * len(alpha_list_plots)])

with torch.no_grad():
    for i,alpha_i in enumerate(alpha_list_plots):
        for j,beta_j in enumerate(beta_list_plots):

            density_cij,x_grid_cij,y_grid_cij = eval_pdf_on_grid_2D(
                pdf = log_p_2D_ToyExample_two_parameters,
                x_lims = x_lims,
                y_lims=y_lims,
                x_res = x_res,
                y_res=y_res,
                device=device,
                kwargs_pdf={"parameter_list":[alpha_i,beta_j],"device":device}
            )

            plot_pdf_2D(
                pdf_grid=density_cij.detach().cpu().exp(),
                x_grid=x_grid_cij,
                y_grid=y_grid_cij,
                cmap = "jet",
                title = r"$\alpha = $" + f"{alpha_i}"+r" $\beta = $" + f"{round(beta_j,5)}",
                ax = axes_densities[i][j]
            )

# Lay out and save once, after every panel has been drawn.
fig_densities.tight_layout()

fig_densities.savefig(
        os.path.join("./Figures","densities_gt.jpeg"),
        bbox_inches='tight',
        dpi = 300
)

Plot the densities of the INNs

---

In [None]:
# Same (alpha, beta) panel grid as the ground-truth figure (alpha descending down the rows),
# evaluated with each trained INN and saved to ./Figures/densities_{key}.jpeg.
alpha_list_plots = [0.8,0.65,0.5,0.35,0.2]
beta_list_plots = [0.2,0.5,1.0,2.0,5.0]

n_rows_plots = len(alpha_list_plots)
n_cols_plots = len(beta_list_plots)

for c,key in enumerate(model_dict):

    fig_densities,axes_densities = plt.subplots(n_rows_plots,n_cols_plots,figsize = [4 * n_cols_plots,4 * n_rows_plots])

    for i,alpha_i in enumerate(alpha_list_plots):
        for j,beta_j in enumerate(beta_list_plots):

            panel_title = r"$\alpha = $" + f"{alpha_i}"+r" $\beta = $" + f"{round(beta_j,5)}"

            with torch.no_grad():
                density_cij,x_grid_cij,y_grid_cij = eval_pdf_on_grid_2D(
                    pdf = model_dict[key].log_prob,
                    x_lims = x_lims,
                    y_lims = y_lims,
                    x_res = x_res,
                    y_res = y_res,
                    device = device,
                    args_pdf = [[alpha_i,beta_j]]
                )

                plot_pdf_2D(
                    pdf_grid = density_cij.detach().cpu().exp(),
                    x_grid = x_grid_cij,
                    y_grid = y_grid_cij,
                    cmap = "jet",
                    title = panel_title,
                    ax = axes_densities[i][j]
                )

    plt.tight_layout()

    figure_path = os.path.join("./Figures",f"densities_{key}.jpeg")
    plt.savefig(figure_path,bbox_inches='tight',dpi = 300)



Output the validation KL divergence (KLD) of each model at different points in parameter space:

---

In [None]:
# (alpha, beta) points at which the validation KLD is reported.
alpha_KLD = [0.2,0.4,0.5,0.6,0.8]
beta_KLD = [0.25,0.5,1.0,2.0,4.0]

n_validation_samples = 5000
bs_validation = 2500

# One validation DataLoader per grid point, keyed "alpha_{alpha}_beta_{beta}" (same keys as Z_dict).
validation_data_loader_dict = {}

for alpha in alpha_KLD:
    for beta in beta_KLD:

        validation_set_ab = DataSet_2D_ToyExample_external_two_parameters(
            d = 2,
            parameter_coordinates = [[alpha,beta]],
            mode = "validation",
            n_samples = n_validation_samples,
            base_path = "../../data/2D_Toy_two_external_parameters/"
        )

        validation_data_loader_dict[f"alpha_{alpha}_beta_{beta}"] = DataLoader(
            validation_set_ab,
            batch_size = bs_validation,
            shuffle = True,
            num_workers = 11,
        )

In [None]:
# Values passed as `Z` to log_p_2D_ToyExample_two_parameters, keyed "alpha_{alpha}_beta_{beta}"
# (presumably the normalisation constants of the ground-truth densities — TODO confirm).
# The `with` block closes the file; JSON is UTF-8 by specification.
with open("../../data/2D_Toy_two_external_parameters/Z_dict.json","r",encoding="utf-8") as f:
    Z_dict = json.load(f)

In [None]:
# Validation KLD per (alpha, beta) point and model: the mean of log p_target - log p_theta over
# the validation samples, with a bootstrap standard error.
# Fixed seed so the bootstrap error bars are reproducible on Restart & Run All.
bootstrap_seed = 0

with torch.no_grad():

    KLD_dict = {}
    error_KLD_dict = {}

    n_samples_bootstrap = 50
    rng_bootstrap = np.random.default_rng(bootstrap_seed)

    for alpha in tqdm.tqdm(alpha_KLD):
        for beta in beta_KLD:

            point_key = f"alpha_{alpha}_beta_{beta}"
            KLD_dict[point_key] = {}
            error_KLD_dict[point_key] = {}

            # One pass over the loader per grid point: the target log-density does not depend on
            # the model, so compute it once and score every model on the same batches.
            log_p_target_batches = []
            log_p_theta_batches = {key: [] for key in model_dict}

            for _alpha_batch,_beta_batch,x_batch in validation_data_loader_dict[point_key]:

                log_p_target_batches.append(log_p_2D_ToyExample_two_parameters(
                    x = x_batch,
                    parameter_list=[alpha,beta],
                    device = "cpu",
                    Z = Z_dict[point_key]
                ))

                for key in model_dict:
                    log_p_theta_batches[key].append(
                        model_dict[key].log_prob(x = x_batch.to(device),parameter_list = [alpha,beta]).detach().cpu()
                    )

            log_p_target_val = torch.cat(log_p_target_batches,0)

            for key in model_dict:

                log_p_theta_val = torch.cat(log_p_theta_batches[key],0)
                assert log_p_target_val.shape == log_p_theta_val.shape

                # Per-sample log-density ratio; its mean is the value reported as the validation KLD.
                log_ratio_val = log_p_target_val - log_p_theta_val
                n_val = len(log_ratio_val)

                # All bootstrap resamples (with replacement) drawn at once: shape (n_samples_bootstrap, n_val).
                indices = rng_bootstrap.integers(0,n_val,size=(n_samples_bootstrap,n_val))
                samples_KLD = log_ratio_val.double().numpy()[indices].mean(axis=1)

                KLD_dict[point_key][key] = log_ratio_val.mean()
                # Sample standard deviation (ddof=1) of the bootstrap means.
                error_KLD_dict[point_key][key] = samples_KLD.std(ddof=1)

In [None]:
# Model key -> column header in the LaTeX KLD table (insertion order = column order).
col_key_to_col_label_dict = dict(
    TRADE_no_grid="TRADE",
    NLL_only="NLL",
    reverse_KL="rev. KL",
    reverse_KL_NLL="rev. KL + NLL",
)

In [None]:
# LaTeX table of validation KLD (value $\pm$ bootstrap error): rows are (alpha, beta) points,
# columns are models, and the lowest KLD in each row is set in bold.

# One column order drives the header, the best-value search and the body, so labels always match.
col_keys = list(col_key_to_col_label_dict.keys())

# Decimal places used when the bootstrap error is zero or non-finite (log10 undefined).
default_decimals = 4

def _format_kld_cell(value:float,error:float)->str:
    """Format value and error as ``value$\\pm$error``, rounded to three significant figures of the error."""
    abs_error = abs(error)
    if np.isfinite(abs_error) and abs_error > 0:
        # May be negative for errors >= 1000, which round() handles (rounds to tens, hundreds, ...).
        decimals = int(2 - np.floor(np.log10(abs_error)))
    else:
        decimals = default_decimals
    return f"{round(value,decimals)}$\\pm${round(error,decimals)}"

# Header; (alpha, beta) must be in math mode for LaTeX to accept \alpha / \beta.
table_str = "\\begin{tabularx}{\\textwidth}{|c|" + "c|" * len(col_keys) + "}\n\\hline\n$(\\alpha,\\beta)$"

# Column names
for col_key in col_keys:
    table_str += "&\\textbf{" + col_key_to_col_label_dict[col_key] + "}"
table_str += "\\\\\n\\hline\n"

# Best (lowest) KLD in each row; min() keeps the first column on ties.
is_best_dict = {}
for row_key,row_KLD in KLD_dict.items():
    best_col_key = min(col_keys,key=lambda col_key: row_KLD[col_key].item())
    is_best_dict[row_key] = {col_key: col_key == best_col_key for col_key in col_keys}

# Fill the table; row keys have the form "alpha_{alpha}_beta_{beta}".
for row_key in KLD_dict:

    alpha_i = row_key.split("_")[1]
    beta_i = row_key.split("_")[3]

    table_str += f"({alpha_i}, {beta_i})"

    for col_key in col_keys:
        cell = _format_kld_cell(KLD_dict[row_key][col_key].item(),error_KLD_dict[row_key][col_key].item())
        if is_best_dict[row_key][col_key]:
            cell = "\\textbf{" + cell + "}"
        table_str += "&" + cell

    table_str += "\\\\\n"

table_str += "\\hline\n\\end{tabularx}"

print(table_str)