In [None]:
import yaml
import torch
import matplotlib.pyplot as plt
import json
import numpy as np
import os
from pathlib import Path
import sys

from pinf.models.construct_INN_2D_GMM import set_up_sequence_INN_2D_GMM
from pinf.datasets.log_likelihoods import log_p_2D_GMM
from pinf.models.histogram import HistogramDist
from pinf.models.GMM import GMM
from pinf.datasets.parameters import S_2D_GMM,means_2D_GMM

Settings

---

In [None]:
T_0 = 1.0      # reference temperature; the condition value is c_0 = 1 / T_0
T = 0.54556    # temperature to compare against; condition value c = 1 / T
# Fall back to the CPU when no CUDA device is available so the notebook still runs.
device = "cuda:0" if torch.cuda.is_available() else "cpu"
bins = 250          # number of histogram bins for every energy density
n_samples = 100000  # samples loaded / drawn per condition value

Initialize the target distribution

---

In [None]:
# Ground-truth 2D Gaussian mixture used to score the validation samples.
gmm = GMM(means=means_2D_GMM,covs=S_2D_GMM,device=device)

# The ``with`` statement closes the file on exit; no explicit close() is needed.
# NOTE(review): Z_T_dict is not used by any of the cells shown here - confirm it
# is still needed.
with open("../../data/2D_GMM/Z_T.json","r") as f:
    Z_T_dict = json.load(f)

Approximate the target's energy distributions at $c_0$ and $c$ with histograms

---

In [None]:
# Validation samples at both temperatures, truncated to n_samples and moved to `device`.
data_c0, data_c = (
    torch.load(
        f"../../data/2D_GMM/validation_data/T_{temperature}_dim_{2}.pt"
    )[:n_samples].to(device)
    for temperature in (T_0, T)
)

# Condition values are the inverse temperatures.
c_0 = 1 / T_0
c = 1 / T

print("c_0 = ",c_0)
print("c = ",c)

In [None]:
# Energies (negative log-densities) of the validation samples. Both sample sets
# are scored with the c_0 log-density: the reweighting exp((1 - c / c_0) * e)
# applied in the plotting cell assumes e is measured under the c_0 density.
e_gt_c0, e_gt_c = (
    -log_p_2D_GMM(x=samples, beta_tensor=c_0, device=device, gmm=gmm)
    for samples in (data_c0, data_c)
)

In [None]:
# Histogram estimates of the target energy densities at c_0 and at c, built on the CPU.
p_e_gt_c0, p_e_gt_c = (
    HistogramDist(data=energies.detach().cpu(), n_bins=bins, device="cpu")
    for energies in (e_gt_c0, e_gt_c)
)

Learned distributions: for each trained model, draw samples at $c_0$ and $c$ and histogram their energies

---

TSF

In [None]:
# Run folder of the TSF (volume-preserving) experiment - replace the placeholder.
base_path_volume_preserving = (
    "../../results/2D_GMM/"
    "<Your experiment name>/lightning_logs/version_0"
)

In [None]:
def load_INN_2D_GMM(base_path:str,device:str = "cuda:0"):
    """Build the 2D-GMM INN described by a run folder and load its checkpoint.

    Args:
        base_path: Run folder holding ``hparams.yaml`` and a ``checkpoints/``
            subfolder (e.g. ``.../lightning_logs/version_0``).
        device: Written into the config as ``config["device"]`` before the
            model is built.

    Returns:
        Tuple ``(INN, config)``: the model after ``train(False)`` and the
        hyperparameter dict read from ``hparams.yaml`` (with the overridden
        ``"device"`` entry).

    Raises:
        FileNotFoundError: If ``checkpoints/`` contains no file whose name
            starts with ``"checkpoint_epoch"``.
    """
    config_i = yaml.safe_load(Path(base_path, "hparams.yaml").read_text())
    state_dict_folder_i = os.path.join(base_path, "checkpoints")

    # os.listdir order is arbitrary; sorting makes the choice reproducible when
    # several checkpoints match (the lexicographically first one is used).
    candidates = sorted(
        name for name in os.listdir(state_dict_folder_i)
        if name.startswith("checkpoint_epoch")
    )
    if not candidates:
        # Fail with a clear message instead of hitting an unbound name below.
        raise FileNotFoundError(
            f"No 'checkpoint_epoch*' file found in {state_dict_folder_i}"
        )
    state_dict_path_i = os.path.join(state_dict_folder_i, candidates[0])

    config_i["device"] = device

    INN_i = set_up_sequence_INN_2D_GMM(config=config_i)
    # NOTE(review): a file path (not a state-dict mapping) is passed here;
    # presumably the project's INN overrides load_state_dict to read the file
    # itself - confirm, since torch.nn.Module.load_state_dict expects a mapping.
    INN_i.load_state_dict(state_dict_path_i)
    INN_i.train(False)

    print(state_dict_path_i)

    return INN_i,config_i

In [None]:
# Load the TSF model; the config returned alongside it is not needed here.
INN_TSF, _ = load_INN_2D_GMM(base_path=base_path_volume_preserving, device=device)

INN_TSF.eval()

In [None]:
# Draw TSF samples at c_0 and at c and compute their energies. Both sample sets
# are scored with the c_0 log-density, like the ground-truth energies.
with torch.no_grad():
    x_TSF_c0 = INN_TSF.sample(n_samples, c_0)
    e_TSF_c0 = -INN_TSF.log_prob(x_TSF_c0, c_0).detach().cpu()

    x_TSF_c = INN_TSF.sample(n_samples, c)
    e_TSF_c = -INN_TSF.log_prob(x_TSF_c, c_0).detach().cpu()

# The energies are already detached CPU tensors.
p_e_TSF_c0, p_e_TSF_c = (
    HistogramDist(data=energies, n_bins=bins, device="cpu")
    for energies in (e_TSF_c0, e_TSF_c)
)

TRADE

---

In [None]:
# Run folder of the TRADE experiment - replace the placeholder.
base_path_TRADE = (
    "../../results/2D_GMM/"
    "<Your experiment name>/lightning_logs/version_0"
)

In [None]:
# Load the TRADE model; the config returned alongside it is not needed here.
INN_TRADE, _ = load_INN_2D_GMM(base_path=base_path_TRADE, device=device)

INN_TRADE.eval()

In [None]:
# Draw TRADE samples at c_0 and at c and compute their energies. Both sample sets
# are scored with the c_0 log-density, like the ground-truth energies.
with torch.no_grad():
    x_TRADE_c0 = INN_TRADE.sample(n_samples, c_0)
    e_TRADE_c0 = -INN_TRADE.log_prob(x_TRADE_c0, c_0).detach().cpu()

    x_TRADE_c = INN_TRADE.sample(n_samples, c)
    e_TRADE_c = -INN_TRADE.log_prob(x_TRADE_c, c_0).detach().cpu()

# The energies are already detached CPU tensors.
p_e_TRADE_c0, p_e_TRADE_c = (
    HistogramDist(data=energies, n_bins=bins, device="cpu")
    for energies in (e_TRADE_c0, e_TRADE_c)
)

NLL only

---

In [None]:
# Run folder of the NLL-only experiment - replace the placeholder.
base_path_NLL = (
    "../../results/2D_GMM/"
    "<Your experiment name>/lightning_logs/version_0"
)

In [None]:
# Load the NLL-only model; the config returned alongside it is not needed here.
INN_NLL, _ = load_INN_2D_GMM(base_path=base_path_NLL, device=device)

INN_NLL.eval()

In [None]:
# Draw NLL-only samples at c_0 and at c and compute their energies. Both sample
# sets are scored with the c_0 log-density, like the ground-truth energies.
with torch.no_grad():
    x_NLL_c0 = INN_NLL.sample(n_samples, c_0)
    e_NLL_c0 = -INN_NLL.log_prob(x_NLL_c0, c_0).detach().cpu()

    x_NLL_c = INN_NLL.sample(n_samples, c)
    e_NLL_c = -INN_NLL.log_prob(x_NLL_c, c_0).detach().cpu()

# The energies are already detached CPU tensors.
p_e_NLL_c0, p_e_NLL_c = (
    HistogramDist(data=energies, n_bins=bins, device="cpu")
    for energies in (e_NLL_c0, e_NLL_c)
)

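Plotting: compare each energy density at $c$ with the density at $c_0$ reweighted by $\exp\big((1 - c/c_0)\,e\big)$ and renormalized on the evaluation grid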

---

In [None]:
e_eval_c0 = torch.linspace(0.0,10,1000)
e_eval_c = torch.linspace(0.0,7,1000)

fs = 20  # font size for titles, tick labels and legends
lw = 3   # line width of all curves


def _transform_energy_density(p_e_source, e_grid, c_target, c_source):
    """Reweight an energy density estimated at ``c_source`` to ``c_target``.

    Evaluates ``p_e_source`` on ``e_grid``, multiplies by
    ``exp((1 - c_target / c_source) * e)`` and renormalizes with a Riemann sum
    using the spacing of the first two grid points (so the grid is assumed to
    be uniform). Probability mass outside ``e_grid`` is ignored.

    Args:
        p_e_source: Callable energy density at ``c_source``; called with
            ``e_grid`` and expected to return a tensor of the same length.
        e_grid: 1-D tensor of energy values with at least two entries.
        c_target: Condition value to transform to.
        c_source: Condition value at which ``p_e_source`` was estimated.

    Returns:
        Tensor with the transformed density on ``e_grid``.
    """
    density = p_e_source(e_grid) * torch.exp((1 - c_target / c_source) * e_grid)
    bin_width = e_grid[1] - e_grid[0]
    return density / (density.sum() * bin_width)


# c_0 energy densities reweighted to c, one per distribution.
p_e_trafo_gt = _transform_energy_density(p_e_gt_c0, e_eval_c, c, c_0)
p_e_trafo_TSF = _transform_energy_density(p_e_TSF_c0, e_eval_c, c, c_0)
p_e_trafo_TRADE = _transform_energy_density(p_e_TRADE_c0, e_eval_c, c, c_0)
p_e_trafo_NLL = _transform_energy_density(p_e_NLL_c0, e_eval_c, c, c_0)

# One figure row per distribution:
# (title, density at c_0, density at c, c_0 density transformed to c).
rows = [
    ("Target", p_e_gt_c0, p_e_gt_c, p_e_trafo_gt),
    ("TSF", p_e_TSF_c0, p_e_TSF_c, p_e_trafo_TSF),
    ("TRADE", p_e_TRADE_c0, p_e_TRADE_c, p_e_trafo_TRADE),
    ("NLL", p_e_NLL_c0, p_e_NLL_c, p_e_trafo_NLL),
]
names = [row[0] for row in rows]
c_list = [r"$c_0 = $"+f"{c_0}",r"$c = $"+f"{round(c,5)}"]

fig,axes = plt.subplots(len(rows),2,figsize = (13,15))

for row_axes, (name, p_e_at_c0, p_e_at_c, p_e_trafo) in zip(axes, rows):
    ax_c0, ax_c = row_axes
    # Left column: density at c_0. Right column: density at c vs. transformed c_0 density.
    ax_c0.plot(e_eval_c0, p_e_at_c0(e_eval_c0), lw=lw, c="orange")
    ax_c.plot(e_eval_c, p_e_at_c(e_eval_c), lw=lw, c="orange", label="target")
    ax_c.plot(e_eval_c, p_e_trafo, lw=lw, c="k", ls="--", label="transformed")

    for ax, c_label in zip(row_axes, c_list):
        ax.set_title(name + " " + c_label, fontsize=fs)
        ax.tick_params(axis="both", labelsize=fs)
        ax.set_xlabel("e", fontsize=fs)
        ax.set_ylabel("p(e)", fontsize=fs)

    ax_c.legend(fontsize=fs)

fig.tight_layout()

# The figure is only written to disk; closing it keeps the notebook backend
# from rendering it inline.
fig.savefig("./transformed_energy_dist.pdf", bbox_inches='tight')
plt.close(fig)