In [None]:
import torch
import matplotlib.pyplot as plt
from functools import partial
import os
import json

from pinf.models.GMM import GMM
from pinf.plot.utils import eval_pdf_on_grid_2D
from pinf.datasets.log_likelihoods import log_p_2D_ToyExample_two_parameters

Settings

---

In [None]:
# Device identifier forwarded to the target density and to the proposal GMMs below.
device = "cpu"
# When True, the partition-function table (Z_dict.json) and the sample files are
# recomputed and written to disk; when False, those cells are skipped and later
# cells only read the files that already exist under `folder`.
generate_new = True

Target distribution

---

In [None]:
def p_alpha_beta(x,alpha,beta,device,Z = None):
    """Evaluate the two-parameter 2D toy density at ``x``.

    The log-density is obtained from ``log_p_2D_ToyExample_two_parameters``
    with ``parameter_list = [alpha, beta]`` and then exponentiated.

    Args:
        x: Points at which the density is evaluated.
        alpha: First external parameter of the toy distribution.
        beta: Second external parameter of the toy distribution.
        device: Device identifier forwarded to the log-density function.
        Z: Forwarded unchanged to the log-density function; presumably the
            partition function used for normalization (confirm against
            ``pinf.datasets.log_likelihoods``).

    Returns:
        The exponential of the log-density returned for ``x``.
    """
    # Evaluate in log-space first, then map back to the density.
    log_density = log_p_2D_ToyExample_two_parameters(
        x = x,
        parameter_list = [alpha, beta],
        device = device,
        Z = Z,
    )

    return log_density.exp()

Plot the distributions and approximate the partition function

---

In [None]:
# Grid of alpha values; every (alpha, beta) combination of the two lists is processed.
# alpha is also used as the weight of the first proposal component (weights = [alpha, 1 - alpha]).
alpha_list = [0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9]
# Grid of beta values. 1/3 has no short decimal representation, which is why the
# cells below use round(beta, 5) for dictionary keys, file names and titles.
beta_list = [0.2,0.25,1/3,0.5,1.0,2.0,3.0,4.0,5.0]

In [None]:
# Output directory shared by the partition-function table and the sample files.
folder = "../data/2D_Toy_two_external_parameters"

if generate_new:

    # One panel per (alpha, beta) pair: rows follow alpha_list, columns follow beta_list.
    fig, axes = plt.subplots(
        len(alpha_list),
        len(beta_list),
        figsize = (5 * len(beta_list), 5 * len(alpha_list)),
    )

    # Maps "alpha_<alpha>_beta_<beta>" to the numerically estimated partition function.
    Z_dict = {}

    for i, alpha in enumerate(alpha_list):
        for j, beta in enumerate(beta_list):

            # Round once so the density, the dictionary key and the title all use the same value.
            beta_r = round(beta, 5)

            # Z is left at its default (None) here because it is the quantity being estimated.
            p_ij = partial(p_alpha_beta, alpha = alpha, beta = beta_r, device = device)

            pdf_grid, x_grid, y_grid = eval_pdf_on_grid_2D(
                pdf = p_ij,
                x_lims = [-17, 17],
                y_lims = [-17, 17],
                x_res = 10000,
                y_res = 10000,
            )

            # Area of one grid cell.
            # NOTE(review): assumes x varies along axis 1 and y along axis 0 of the returned grids - confirm.
            dA = (x_grid[0, 1] - x_grid[0, 0]) * (y_grid[1, 0] - y_grid[0, 0])

            # Riemann-sum estimate of the partition function.
            Z_ij = pdf_grid.sum() * dA

            Z_dict[f"alpha_{alpha}_beta_{beta_r}"] = Z_ij.item()

            axes[i][j].imshow(
                pdf_grid,
                extent = [x_grid.min(), x_grid.max(), y_grid.min(), y_grid.max()],
                origin = 'lower',
                cmap = "jet",
            )
            axes[i][j].set_title(f'a = {alpha}, b = {beta_r}')
            axes[i][j].axis('off')

    # exist_ok avoids the separate existence check (and the race it implies).
    os.makedirs(folder, exist_ok = True)

    # The context manager closes the file; no explicit close() is needed afterwards.
    with open(os.path.join(folder, 'Z_dict.json'), 'w') as f:
        json.dump(Z_dict, f)

Perform rejection sampling

---

In [None]:
class ConcatenatedGMM(GMM):
    """GMM whose component covariances are widened by isotropic noise.

    Every covariance ``S`` is replaced by ``S + sigma_noise**2 * I`` before the
    parameters are handed to ``GMM.__init__``.

    Args:
        means: Component means (sequence of tensors, or a stacked tensor).
        covs: Component covariance matrices (sequence of square tensors, or a
            stacked tensor). The caller's container is not modified.
        sigma_noise: Standard deviation of the isotropic noise added to each
            component.
        weights: Optional mixture weights, forwarded to ``GMM``.
        device: Optional device, forwarded to ``GMM``.
    """
    def __init__(self,means,covs,sigma_noise:float,weights = None,device = None)->None:

        def _inflate(cov):
            # Identity matching the covariance's own size, dtype and device,
            # instead of a hard-coded 2x2 matrix on the default device.
            eye = torch.eye(cov.shape[-1], dtype = cov.dtype, device = cov.device)
            return cov + eye * sigma_noise**2

        # Build new covariances rather than writing back into the caller's container.
        noisy_covs = [_inflate(cov) for cov in covs]

        # Preserve the container type if a stacked tensor was passed in.
        if isinstance(covs, torch.Tensor):
            noisy_covs = torch.stack(noisy_covs)

        super().__init__(means = means,covs = noisy_covs,weights = weights,device=device)

In [None]:
# Covariances and means of the two Gaussian components used by the proposal.
S_1 = torch.tensor([[1.0,-0.5],[-0.5,7.0]])
S_2 = torch.tensor([[1.0,0.5],[0.5,7.0]])

m_1 = torch.tensor([-4.0,0.0])
m_2 = torch.tensor([4.0,0.0])

# Partition functions estimated on the grid, keyed by "alpha_<alpha>_beta_<beta>".
# The context manager closes the file; no explicit close() is needed afterwards.
with open(os.path.join(folder,'Z_dict.json'), 'r') as f:
    Z_dict = json.load(f)

if generate_new:
    n_samples = 500000
    bs = 10000

    # First 80% of the accepted samples go to training, the remainder to validation.
    n_train = int(0.8 * n_samples)

    # Envelope constant M of the rejection sampler: a proposal x is accepted with
    # probability p(x) / (M * q(x)).
    # NOTE(review): the sampler is only exact if p(x) <= M * q(x) everywhere; this is not checked here.
    envelope_scale = 100

    # Create the output directories once instead of re-checking inside the loop.
    train_dir = os.path.join(folder,"training_data/")
    val_dir = os.path.join(folder,"validation_data/")
    os.makedirs(train_dir, exist_ok = True)
    os.makedirs(val_dir, exist_ok = True)

    for alpha in alpha_list:
        for beta in beta_list:

            beta_r = round(beta,5)

            # Get the proposal distribution; use the noise-widened GMM in case of
            # beta < 1.0 for better tail sampling.
            if beta < 1.0:
                p_prop = ConcatenatedGMM(
                    means = [m_1,m_2],
                    covs = [S_1,S_2],
                    device = device,
                    weights = torch.tensor([alpha,1.0 - alpha]),
                    sigma_noise = 3.0
                )

            else:
                p_prop = GMM(
                    means = [m_1,m_2],
                    covs = [S_1,S_2],
                    device = device,
                    weights = torch.tensor([alpha,1.0 - alpha])
                )

            p_eval = partial(p_alpha_beta,alpha = alpha,beta = beta_r,Z = Z_dict[f"alpha_{alpha}_beta_{beta_r}"],device = device)

            # Collect accepted batches in a list and concatenate once at the end;
            # concatenating inside the loop re-copies all previous samples every batch.
            accepted_batches = []
            n_accepted = 0

            while n_accepted < n_samples:

                # Get proposals
                x_prop = p_prop.sample(bs)

                # Acceptance ratio p(x) / (M * q(x))
                r = p_eval(x_prop) / (p_prop(x_prop) * envelope_scale)

                # Uniform thresholds on the same device as the ratios
                u = torch.rand(bs, device = r.device)

                accept = u < r

                # Keep the accepted proposals on the CPU so the saved files can be
                # loaded and converted to numpy regardless of `device`.
                batch = x_prop[accept].detach().cpu()
                accepted_batches.append(batch)
                n_accepted += len(batch)

            # Drop the overshoot of the last batch so exactly n_samples are kept and
            # the validation split has the intended size.
            samples_i = torch.cat(accepted_batches,dim = 0)[:n_samples]

            # Save the results
            torch.save(samples_i[:n_train],os.path.join(folder,f'training_data/alpha_{alpha}_beta_{beta_r}_dim_2.pt'))
            torch.save(samples_i[n_train:],os.path.join(folder,f'validation_data/alpha_{alpha}_beta_{beta_r}_dim_2.pt'))

Plot samples

---

In [None]:
# Scatter plot of the validation samples, one panel per (alpha, beta) pair:
# rows follow alpha_list, columns follow beta_list.
fig,axes = plt.subplots(len(alpha_list),len(beta_list),figsize = (5 * len(beta_list),5 * len(alpha_list)))

for i,alpha in enumerate(alpha_list):
    for j,beta in enumerate(beta_list):

        # Same rounding as used when the files were written.
        beta_r = round(beta,5)

        data_val_i = torch.load(os.path.join(folder,f'validation_data/alpha_{alpha}_beta_{beta_r}_dim_2.pt'))

        axes[i,j].scatter(data_val_i[:,0],data_val_i[:,1],s = 0.1)
        # The title previously rendered as "alpha = <a> = beta_<b>"; show both parameters properly.
        axes[i,j].set_title(f'alpha = {alpha}, beta = {beta_r}')
        axes[i,j].axis('off')
        # Same window as the grid used for the partition-function estimate.
        axes[i,j].set_xlim(-17,17)
        axes[i,j].set_ylim(-17,17)

Plot Empirical distribution of the samples

---

In [None]:
# Normalized 2D histogram of the training samples, one panel per (alpha, beta) pair:
# rows follow alpha_list, columns follow beta_list.
fig,axes = plt.subplots(len(alpha_list),len(beta_list),figsize = (5 * len(beta_list),5 * len(alpha_list)))

for i,alpha in enumerate(alpha_list):
    for j,beta in enumerate(beta_list):

        # Same rounding as used when the files were written.
        beta_r = round(beta,5)

        data_train_i = torch.load(os.path.join(folder,f'training_data/alpha_{alpha}_beta_{beta_r}_dim_2.pt'))

        # Return value (counts, edges, image) is not needed; bind it to keep the cell output clean.
        _ = axes[i,j].hist2d(data_train_i[:,0].numpy(),data_train_i[:,1].numpy(),bins = 150, density = True,range = [[-15,15],[-15,15]],cmap = "jet")
        # The title previously rendered as "alpha = <a> = beta_<b>"; show both parameters properly.
        axes[i,j].set_title(f'alpha = {alpha}, beta = {beta_r}')
        axes[i,j].axis('off')
        axes[i,j].set_xlim(-17,17)
        axes[i,j].set_ylim(-17,17)