In [None]:
import torch
import matplotlib.pyplot as plt
import os
import yaml
from pathlib import Path
from torch.utils.data import DataLoader
import tqdm

from pinf.datasets.datasets import get_EMNIST_datasets
from pinf.models.construct_INN_EMNIST import set_up_sequence_INN_MNIST_like
from pinf.models.histogram import HistogramDist

Set the parameters of the evaluation (experiment and folder of the trained model)

---

In [None]:
experiment_name = "EMNIST_digits"

if experiment_name == "EMNIST_digits":
    # Replace <Your experiment name> with the name of the trained run to evaluate.
    model_folder = "../../results/runs_EMNIST_digits/<Your experiment name>/lightning_logs/version_0/"
    data_dim = 28 * 28  # 28 x 28 pixel images
    n_samples_plot_per_class = 15
    n_classes = 10
else:
    # Fail loudly with a hint instead of a bare ValueError.
    raise ValueError(
        f"Unknown experiment_name '{experiment_name}'; supported experiments: 'EMNIST_digits'"
    )

Initialize the INN

---

In [None]:
# Hyperparameters stored by the training run.
config = yaml.safe_load(Path(model_folder, "hparams.yaml").read_text())

state_dict_folder = os.path.join(model_folder, "checkpoints/")
state_dict_files = sorted(os.listdir(state_dict_folder))

# Exactly one checkpoint is required so that the evaluated weights are unambiguous.
# An empty folder used to surface as an IndexError below; report both cases clearly.
if len(state_dict_files) != 1:
    raise ValueError(
        f"expected exactly one state dict in {state_dict_folder}, "
        f"found {len(state_dict_files)}: {state_dict_files}"
    )

state_dict_file = os.path.join(state_dict_folder, state_dict_files[0])
print(state_dict_file)

In [None]:
# Folder for the results of this experiment (saved figures and cached energy tensors).
folder = f"../../results/{experiment_name}/"

# exist_ok=True avoids the check-then-create race and keeps the cell idempotent.
os.makedirs(folder, exist_ok=True)

In [None]:
# Build the INN from the saved hyperparameters and load the trained checkpoint.
with torch.no_grad():
    INN = set_up_sequence_INN_MNIST_like(config=config)
    # NOTE(review): `path=` is not the signature of torch.nn.Module.load_state_dict;
    # presumably the INN class provides its own loader that reads the checkpoint file — confirm.
    INN.load_state_dict(path = state_dict_file)
    # Disable training mode (for a torch.nn.Module this is equivalent to .eval()).
    INN.train(False)

Load the training and validation data

---

In [None]:
# The normalization has to match the training run, so it is read from the saved hyperparameters.
# sigma_dequantization is set to 0.0 here; the noisy evaluation further below adds noise explicitly.
init_data_set_params = config["config_data"]["init_data_set_params"]

DS_training, DS_validation = get_EMNIST_datasets(
    data_folder="../../data/",
    mean_normalization=init_data_set_params["mean_normalization"],
    scale_normalization=init_data_set_params["scale_normalization"],
    sigma_dequantization=0.0,
    split="digits",
)

# Both loaders use the DataLoader defaults apart from the batch size (in particular, no shuffling).
train_DL = DataLoader(dataset=DS_training, batch_size=512)
val_DL = DataLoader(dataset=DS_validation, batch_size=512)

Plot states

---

In [None]:
def plot_states(samples_list: list, n_samples_per_class: int, n_classes: int, name: str = None, rotate: bool = False) -> None:
    """Plot a grid of images with one row per class and optionally save it.

    Args:
        samples_list: Sequence with one entry per class; ``samples_list[i][j]`` is the
            ``j``-th image of class ``i`` as a 2-D tensor.
        n_samples_per_class: Number of columns (images shown per class).
        n_classes: Number of rows.
        name: File name, relative to the global results ``folder``, under which the
            figure is saved. If None, the figure is not saved.
        rotate: If True, every image is transposed (``permute(1, 0)``) before plotting
            (presumably to undo the transposed storage of EMNIST images — confirm).

    Returns:
        None. The figure is closed at the end, so it is not displayed inline.
    """
    # squeeze=False keeps `axes` 2-D even when there is a single row or column.
    fig, axes = plt.subplots(
        n_classes,
        n_samples_per_class,
        figsize=(n_samples_per_class * 5, n_classes * 5),
        squeeze=False,
    )

    for i in range(n_classes):
        for j in range(n_samples_per_class):
            ax = axes[i][j]

            # Leave the panel blank if fewer than n_samples_per_class images exist for class i.
            if j < len(samples_list[i]):
                im = samples_list[i][j]
                if rotate:
                    im = im.permute(1, 0)
                ax.imshow(im, cmap="Grays")

            ax.axis("off")

    fig.tight_layout()

    if name is not None:
        fig.savefig(os.path.join(folder, name), bbox_inches="tight")

    # Release the figure; it is only written to disk, not shown.
    plt.close(fig)

# NOTE: this overrides the value of n_samples_plot_per_class (15) set in the parameter cell;
# every sample grid below uses 6 images per class.
n_samples_plot_per_class = 6

Get validation samples

In [None]:
# Collect the first `n_samples_plot_per_class` validation images of every class.
# Images are gathered across batches, so a class with too few examples in the first
# batch still gets its full quota, and the loader is read only once for all classes.
val_chunks = [[] for _ in range(n_classes)]
val_counts = [0] * n_classes

for batch in val_DL:
    for i in range(n_classes):
        n_missing = n_samples_plot_per_class - val_counts[i]
        if n_missing <= 0:
            continue
        im_i = batch[0][batch[1] == i][:n_missing]
        val_chunks[i].append(im_i)
        val_counts[i] += len(im_i)

    # Stop reading the loader as soon as every class is complete.
    if all(count >= n_samples_plot_per_class for count in val_counts):
        break

im_list_validation = []
for i in range(n_classes):
    if val_counts[i] < n_samples_plot_per_class:
        raise ValueError(
            f"found only {val_counts[i]} validation samples of class {i}, "
            f"need {n_samples_plot_per_class}"
        )
    # squeeze() drops singleton dimensions (presumably a channel axis) so that each entry
    # can be indexed into 2-D images, as plot_states expects.
    im_list_validation.append(torch.cat(val_chunks[i], 0).squeeze().detach().cpu())

In [None]:
# Grid of observed validation images, one row per class.
plot_states(
    samples_list=im_list_validation,
    n_samples_per_class=n_samples_plot_per_class,
    n_classes=n_classes,
    name="val_samples.pdf",
    rotate=True,
)

In [None]:
# Side-by-side view of two validation images of class 2 (indices 3 and 1), saved as a PDF.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

im_1, im_2 = (im_list_validation[2][k].permute(1, 0) for k in (3, 1))
for ax, im in zip(axes, (im_1, im_2)):
    ax.imshow(im, cmap="Grays")
    ax.axis("off")

fig.savefig(os.path.join(folder, "Different_types_letter_2.pdf"))

In [None]:
# Side-by-side view of two validation images of class 7 (indices 2 and 3), saved as a PDF.
fig, axes = plt.subplots(1, 2, figsize=(20, 10))

im_1, im_2 = (im_list_validation[7][k].permute(1, 0) for k in (2, 3))
for ax, im in zip(axes, (im_1, im_2)):
    ax.imshow(im, cmap="Grays")
    ax.axis("off")

fig.savefig(os.path.join(folder, "Different_types_letter_7.pdf"))

Get INN samples

In [None]:
# Class labels for the generated samples: n_samples_plot_per_class copies of each class,
# grouped by class, as an (n_classes * n_samples_plot_per_class, 1) long tensor.
c_tensor = (
    torch.arange(n_classes, device=config["device"])
    .repeat_interleave(n_samples_plot_per_class)
    .reshape(-1, 1)
)

# Sampling needs no gradients; no_grad avoids building an autograd graph for the batch.
with torch.no_grad():
    x_INN = INN.sample(n_samples=len(c_tensor), beta_tensor=c_tensor).detach().cpu()

# Split the samples back into one chunk per class (consecutive blocks share a label).
im_list = [chunk.squeeze() for chunk in x_INN.split(n_samples_plot_per_class)]

In [None]:
# Grid of images generated by the INN, one row per class.
plot_states(
    samples_list=im_list,
    n_samples_per_class=n_samples_plot_per_class,
    n_classes=n_classes,
    name="model_samples.pdf",
    rotate=True,
)

Compute the pseudo-energies (negative log-likelihoods) of samples drawn from the learned distribution

---

In [None]:
n_samples_energy = int(1e6)
bs_energy = int(1e3)

# Number of sampling batches per class (any remainder of n_samples_energy is dropped).
n_batches = n_samples_energy // bs_energy

INN.eval()

os.makedirs(folder, exist_ok=True)

with torch.no_grad():
    for c in tqdm.tqdm(range(n_classes)):
        energy_file = os.path.join(folder, f"energies_INN_c_{c}.pt")

        # Results are cached on disk; delete the file to recompute this class.
        if os.path.exists(energy_file):
            continue

        # The class label is identical for every batch, so build it once per class.
        c_tensor = torch.full((bs_energy, 1), c, dtype=torch.long, device=config["device"])

        energy_chunks = []
        for _ in tqdm.tqdm(range(n_batches)):
            x_i = INN.sample(n_samples=bs_energy, beta_tensor=c_tensor)
            # Pseudo-energy = negative log-likelihood of the sample under the INN.
            energy_chunks.append(- INN.log_prob(x_i, c_tensor).detach().cpu())

        # Concatenate once instead of growing a tensor every batch (quadratic copying).
        energies_c = torch.cat(energy_chunks, 0) if energy_chunks else torch.zeros([0])

        # Save the recorded energies
        torch.save(energies_c, energy_file)

Get the pseudo-energies of the observed data (validation and training set)

---

In [None]:
with torch.no_grad():
    for c in tqdm.tqdm(range(n_classes)):
        energy_file = os.path.join(folder, f"energies_data_c_{c}.pt")

        # Results are cached on disk; delete the file to recompute this class.
        if os.path.exists(energy_file):
            continue

        energy_chunks = []

        # Pool the validation and the training set (in this order) for class c.
        for data_loader in (val_DL, train_DL):
            for batch in data_loader:
                x_i = batch[0][batch[1] == c].to(config["device"])

                # Batches without samples of class c contribute nothing; skip evaluating the INN.
                if len(x_i) == 0:
                    continue

                c_tensor = torch.full((len(x_i), 1), c, dtype=torch.long, device=config["device"])

                # Pseudo-energy = negative log-likelihood of the observed sample under the INN.
                energy_chunks.append(- INN.log_prob(x_i, c_tensor).detach().cpu())

        # Concatenate once instead of growing a tensor every batch (quadratic copying).
        energies_c = torch.cat(energy_chunks, 0) if energy_chunks else torch.zeros([0])

        # Save the recorded energies
        torch.save(energies_c, energy_file)

Get the pseudo-energies of the observed data (validation and training set) with added Gaussian noise

---

In [None]:
with torch.no_grad():
    for c in tqdm.tqdm(range(n_classes)):
        energy_file = os.path.join(folder, f"energies_data_noisy_c_{c}.pt")

        # Results are cached on disk; delete the file to recompute this class.
        if os.path.exists(energy_file):
            continue

        # NOTE(review): the dataset cell reads its settings from
        # config["config_data"]["init_data_set_params"], while the noise level is read from
        # config["config_data"]["data_set_config"]; confirm that this key exists in hparams.yaml.
        sigma_noise = config["config_data"]["data_set_config"]["sigma_dequantization"]

        energy_chunks = []

        # Pool the validation and the training set (in this order) for class c.
        for data_loader in (val_DL, train_DL):
            for batch in data_loader:
                x_i = batch[0][batch[1] == c].to(config["device"])

                # Batches without samples of class c contribute nothing; skip evaluating the INN.
                if len(x_i) == 0:
                    continue

                # Perturb the observed samples with Gaussian noise of scale sigma_noise
                # (out of place, so the loaded batch is never modified).
                x_i = x_i + torch.randn_like(x_i) * sigma_noise

                c_tensor = torch.full((len(x_i), 1), c, dtype=torch.long, device=config["device"])

                # Pseudo-energy = negative log-likelihood of the noisy sample under the INN.
                energy_chunks.append(- INN.log_prob(x_i, c_tensor).detach().cpu())

        # Concatenate once instead of growing a tensor every batch (quadratic copying).
        energies_c = torch.cat(energy_chunks, 0) if energy_chunks else torch.zeros([0])

        # Save the recorded energies
        torch.save(energies_c, energy_file)

Load the stored pseudo-energies and compute the empirical distribution

---

In [None]:
p_e_INN_list = []

# Number of histogram bins, shared by all energy distributions below.
n_bins = 500

# Energy window kept for the INN samples (experiment-specific values).
min_e = -4000
max_e = -2500

for c in tqdm.tqdm(range(n_classes)):
    energies_c = torch.load(os.path.join(folder, f"energies_INN_c_{c}.pt"))

    # Keep only finite energies inside [min_e, max_e].
    keep = torch.isfinite(energies_c) & (energies_c >= min_e) & (energies_c <= max_e)
    energies_c = energies_c[keep]

    if energies_c.numel() == 0:
        raise ValueError(
            f"no finite INN energies of class {c} inside [{min_e}, {max_e}]; adjust min_e / max_e"
        )

    p_e_INN_list.append(HistogramDist(data=energies_c, n_bins=n_bins))


In [None]:
p_e_val_list = []

for c in tqdm.tqdm(range(n_classes)):
    energies_c = torch.load(os.path.join(folder, f"energies_data_c_{c}.pt"))

    # Drop non-finite energies, as is done for the INN samples
    # (HistogramDist is assumed to need finite data — confirm).
    energies_c = energies_c[torch.isfinite(energies_c)]

    p_e_val_list.append(HistogramDist(data=energies_c, n_bins=n_bins))

In [None]:
p_e_val_noisy_list = []

for c in tqdm.tqdm(range(n_classes)):
    energies_c = torch.load(os.path.join(folder, f"energies_data_noisy_c_{c}.pt"))

    # Drop non-finite energies, as is done for the INN samples
    # (HistogramDist is assumed to need finite data — confirm).
    energies_c = energies_c[torch.isfinite(energies_c)]

    p_e_val_noisy_list.append(HistogramDist(data=energies_c, n_bins=n_bins))

Plot the distributions of the pseudo-energies:

---

In [None]:
fs = 15
n_cols = 2

# One panel per class on a grid with n_cols columns (5 x 2 for the 10 digit classes).
n_rows = (n_classes + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(13, 2.4 * n_rows), squeeze=False)

# Evaluate the densities slightly beyond the energy window [min_e, max_e].
e_eval = torch.linspace(min_e - 100, max_e + 100, 1000)

for i, ax in enumerate(axes.flatten()):
    # Hide the unused panel when n_classes is not a multiple of n_cols.
    if i >= n_classes:
        ax.axis("off")
        continue

    ax.set_title(f"class '{i}'", fontsize=fs)
    ax.plot(e_eval, p_e_INN_list[i](e_eval), c="k", lw=3, label="INN samples")
    ax.plot(e_eval, p_e_val_list[i](e_eval), c="orange", lw=3, label="observed samples")
    ax.plot(e_eval, p_e_val_noisy_list[i](e_eval), c="b", lw=3, label="observed samples + noise")
    ax.tick_params(axis="x", labelsize=fs)
    ax.tick_params(axis="y", labelsize=fs)
    ax.set_xlabel(r"$e$", fontsize=fs)
    ax.set_ylabel(r"$p(e)$", fontsize=fs)

# Every panel carries the same three curves, so a single shared legend below the grid suffices.
handles, labels = axes.flatten()[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="upper center", bbox_to_anchor=(0.5, 0.0), ncol=3, fontsize=fs)

fig.tight_layout()

fig.savefig(
    os.path.join(folder, "energy_distributions.pdf"),
    bbox_inches="tight",
)
plt.close(fig)