In [None]:
import matplotlib.pyplot as plt
import torch 
import tqdm

from pinf.plot.utils import eval_pdf_on_grid_2D
from pinf.models.GMM import GMM
from pinf.models.histogram import HistogramDist

Settings

---

In [None]:
# Pseudo-energy sampling: total number of samples and the batch size used to draw them.
n_samples_pseudo_energies = 1_000_000
bs_pseudo_energy = 100_000

# Number of bins of the histogram estimate of the pseudo-energy distribution.
n_bins_hist = 500

# Font size used for plot labels and tick labels.
fs = 20

# Seed torch's global RNG so the GMM samples (and therefore every figure) are
# reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed);

Initialize the target distribution: an equally weighted mixture of two Gaussians in 2D

---

In [None]:
# Target distribution: an equally weighted mixture of two isotropic 2D Gaussians.
# Means are built as (1, 2) rows and covariances as (1, 2, 2) slabs, then
# stacked along a leading component axis.
m1, m2 = torch.tensor([[-0.5, -0.5]]), torch.tensor([[0.5, 0.5]])
means = torch.cat([m1, m2], dim=0)          # (2, 2): one mean per component

S1 = (0.2 * torch.eye(2)).unsqueeze(0)      # wider component, covariance 0.2 * I
S2 = (0.1 * torch.eye(2)).unsqueeze(0)      # narrower component, covariance 0.1 * I
S = torch.cat([S1, S2], dim=0)              # (2, 2, 2): one covariance per component

p_GMM = GMM(means=means, covs=S, weights=torch.tensor([0.5, 0.5]))

Get the pseudo-energies $E(x) = -\log p(x)$ of samples $x \sim p$ drawn from the GMM

---

In [None]:
# Draw samples from the target GMM in batches (bounding peak memory) and record
# their pseudo-energies E(x) = -log p(x).
# Ceil-divide so every requested sample is drawn even when the total is not a
# multiple of the batch size (plain division + int() silently dropped the remainder).
n_batches = int((n_samples_pseudo_energies + bs_pseudo_energy - 1) // bs_pseudo_energy)

energy_batches = []
# Only sampling and density evaluation happen here, so no autograd graph is needed.
with torch.no_grad():
    for i in tqdm.tqdm(range(n_batches)):
        # The last batch may be smaller than bs_pseudo_energy.
        bs_i = min(bs_pseudo_energy, n_samples_pseudo_energies - i * bs_pseudo_energy)

        # Get samples following the target distribution
        x_i = p_GMM.sample(bs_i)

        # Get the pseudo-energy
        e_i = -p_GMM.log_prob(x_i)

        energy_batches.append(e_i)

# Concatenate once instead of re-allocating a growing tensor on every batch;
# with zero requested samples, fall back to an empty tensor as before.
pseudo_energies = torch.cat(energy_batches) if energy_batches else torch.zeros([0])

Estimate the empirical distribution of the pseudo-energies with a histogram

---

In [None]:
# Histogram-based density estimate of the sampled pseudo-energies; the resulting
# object is evaluated on a grid of energies further below (p_hist(e)).
p_hist = HistogramDist(data=pseudo_energies, n_bins=n_bins_hist)

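Get the energy distribution by integrating the density in data space. For each threshold $E'$ compute $I(E') = \int p(x)\,\mathbb{1}\left[-\log p(x) < E'\right]\,dx$, which is the cumulative distribution of the pseudo-energy; its derivative $dI/dE'$ is therefore the density of the pseudo-energy.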

---

In [None]:
# Evaluate the GMM density on a res x res grid covering the box [-lim, lim]^2.
# Probability mass outside this box is not captured by the grid.
lim_pdf_grid = 2.0
res_pdf_grid = 2000

p_GMM_grid, x_grid, y_grid = eval_pdf_on_grid_2D(
    pdf=p_GMM,
    x_lims=[-lim_pdf_grid, lim_pdf_grid],
    y_lims=[-lim_pdf_grid, lim_pdf_grid],
    x_res=res_pdf_grid,
    y_res=res_pdf_grid,
)

# Area of one grid cell, dA = dx * dy, read off the spacing of neighbouring points.
# NOTE(review): this indexing assumes x varies along the second index and y along
# the first — TODO confirm against eval_pdf_on_grid_2D.
dx = x_grid[0][1] - x_grid[0][0]
dy = y_grid[1][0] - y_grid[0][0]
dA = dx * dy

In [None]:
# Integrate the density over the region of data space whose energy lies below a
# threshold E',
#     I(E') = sum over grid cells of p(x) * 1[-log p(x) < E'] * dA,
# for 1000 thresholds spanning slightly more than the sampled pseudo-energies.
e_integral = torch.linspace(pseudo_energies.min() - 1, pseudo_energies.max() + 1, 1000)
p_GMM_grid_flat = p_GMM_grid.flatten()

# The energy of each grid cell does not depend on the threshold: compute it once
# instead of re-taking the log of every grid value for each E'.
energies_grid = -p_GMM_grid_flat.log()

# Cells with infinite energy (p == 0) or NaN energy can never satisfy
# `energy < E'` for a finite E', so drop them up front.
keep = energies_grid < float("inf")
energies_sorted, order = torch.sort(energies_grid[keep])
p_sorted = p_GMM_grid_flat[keep][order]

# Cumulative mass in order of increasing energy. It is accumulated in float64 to
# limit round-off over millions of cells. cum_mass[k] is the summed density of
# the k lowest-energy cells.
cum_mass = torch.cat((
    torch.zeros(1, dtype=torch.float64),
    torch.cumsum(p_sorted.to(torch.float64), dim=0),
))

# For each threshold, count the cells with energy strictly below it (right=False
# reproduces the strict `<` of the masked-sum formulation). This replaces an
# O(thresholds * cells) loop with a single sort.
n_below = torch.searchsorted(energies_sorted, e_integral.to(energies_sorted.dtype), right=False)

integrals = (cum_mass[n_below] * dA).to(torch.get_default_dtype())

In [None]:
# Plot the cumulative integral I(E') against the energy threshold E' and save it.
fig, ax = plt.subplots(1, 1, figsize=(13, 6))
ax.plot(e_integral, integrals, c="k", lw=3)

for tick_axis in ("x", "y"):
    ax.tick_params(axis=tick_axis, labelsize=fs)
ax.set_xlabel(r"$E'$", fontsize=fs)
ax.set_ylabel(r"$I(E')$", fontsize=fs)

fig.savefig("./energy_distribution_integral.pdf")

In [None]:
# Differentiate I(E') with finite differences: dI/dE' approximates the density
# of the energy, evaluated at the midpoints of the threshold grid.
e_center = 0.5 * (e_integral[:-1] + e_integral[1:])

de = e_integral[1] - e_integral[0]   # uniform spacing of the linspace
grad = torch.diff(integrals) / de

# Rescale so the finite-difference density integrates to one on this grid.
# The data-space grid is a finite box, so I(E') may not reach exactly 1.
Z = grad.sum() * de
grad = grad / Z

Plotting: compare the histogram of the sampled pseudo-energies with the density obtained from the integral

---

In [None]:
# Compare the two estimates of the energy distribution:
#   - histogram of pseudo-energies of GMM samples (orange, solid),
#   - normalized derivative of the data-space integral I(E') (black, dash-dot).
e_plot = torch.linspace(pseudo_energies.min() - 1, pseudo_energies.max() + 1, 1000)
lw = 3

fig, ax = plt.subplots(1, 1, figsize=(13, 6))
ax.plot(e_plot, p_hist(e_plot), c="orange", lw=lw, label="samples")
ax.plot(e_center, grad, c="k", lw=lw, ls="-.", label="integral")

ax.legend(fontsize=fs)
for tick_axis in ("x", "y"):
    ax.tick_params(axis=tick_axis, labelsize=fs)
ax.set_xlabel("e", fontsize=fs)
ax.set_ylabel("p(e)", fontsize=fs)

fig.savefig("./energy_distribution_hist_vs_integral.pdf")