In [None]:
from pyro import distributions as dist
import torch
from matplotlib import pyplot as plt
import numpy as np
import seaborn as sns
from glob import glob
from src.usflows.explib.config_parser import from_checkpoint
from src.usflows.explib.eval import RadialFlowEvaluator
import os
from src.usflows.distributions import Chi
from src.usflows.explib.datasets import DistributionDataset
from src.usflows.distributions import GMM
from torch.nn.functional import softplus

# Evaluate All Dimensions

In [None]:
# Configuration: fill these in before running the notebook.
# arch: architecture identifier, used only in plot titles and output file names.
arch = ""
# base_dir: directory whose subdirectories each contain a model checkpoint (.pkl + .pt files).
base_dir = ""

In [None]:
# Evaluate every model checkpoint found under `base_dir` against its GMM reference
# distribution and save two diagnostic scatter plots per model.
if not os.path.isdir(base_dir):
    raise NotADirectoryError(f"base_dir must be an existing directory, got {base_dir!r}")

# Full paths of all model subdirectories. Join with `base_dir` exactly once; joining
# twice produces `base_dir/base_dir/...` whenever `base_dir` is a relative path.
model_dirs = sorted(
    path
    for path in (os.path.join(base_dir, name) for name in os.listdir(base_dir))
    if os.path.isdir(path)
)

data = {}
print(model_dirs)
for i, model_dir in enumerate(model_dirs):
    print(model_dir)
    # Locate model files; the lexicographically last .pkl/.pt pair is used.
    pkl_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".pkl"))
    pt_files = sorted(f for f in os.listdir(model_dir) if f.endswith(".pt"))

    if not pkl_files or not pt_files:
        print(f"Skipping {model_dir} (missing files)")
        continue

    pkl_path = os.path.join(model_dir, pkl_files[-1])
    pt_path = os.path.join(model_dir, pt_files[-1])
    try:
        model = from_checkpoint(pkl_path, pt_path)
    except Exception as err:
        # Best-effort: keep evaluating the remaining models, but report why this one
        # was skipped instead of silently swallowing the error.
        print(f"Skipping {model_dir} (failed to load checkpoint: {err!r})")
        continue

    # Directory names end in "_<dim>D", e.g. "0_gaussian_mixture_2D".
    dim = int(os.path.basename(model_dir).split("_")[-1][:-1])

    print(f"{dim}D GMM")
    # Reference GMM: two components with means -1 and +1 and equal weights. The first
    # component's covariance is diagonal (5 on the first half of the coordinates,
    # 0.5 on the rest); the second is the identity. `dim - hdim` keeps the diagonal
    # at length `dim` even when `dim` is odd.
    hdim = dim // 2
    distribution = GMM(
        loc=torch.stack([-torch.ones(dim), torch.ones(dim)]),
        covariance_matrix=torch.stack([
            torch.diag(torch.tensor([5.] * hdim + [.5] * (dim - hdim))),
            torch.eye(dim),
        ]),
        mixture_weights=torch.ones(2) / 2,
    )
    ref_dist = distribution

    # Draw 10k samples; `[:][0]` keeps only the sample tensor from the dataset items.
    ds = DistributionDataset(
        distribution=distribution,
        num_samples=10000
    )[:][0]

    data[i] = ds

    evaluator = RadialFlowEvaluator(
        model,
        ds,
        p=2.0,
        norm_distribution=Chi(
            df=dim,
            scale=softplus(model.base_distribution.scale_unconstrained),
            validate_args=False
        )
    )

    fig, ax = plt.subplots()
    evaluator.logprob_reference_scatter_plot(ax=ax, ref_distribution=ref_dist)
    ax.set_title(f"Log-Probability Comparison ({dim}D)")
    fig.savefig(f"gmm_eval_logprobs_{dim}D_{arch}.png")

    fig, ax = plt.subplots()
    evaluator.nll_norm_scatter_plot(ax=ax, ref_distribution=ref_dist)
    ax.set_title(f"NLL vs Latent Norm ({dim}D)")
    fig.savefig(f"gmm_eval_nll_vs_latent_norms_{dim}D_{arch}.png")

plt.show()

# Evaluate the 2D Model

In [None]:
from matplotlib.patches import Circle

# Checkpoint directory of the 2D model; the lexicographically last .pkl/.pt pair is loaded.
base_path = os.path.join(base_dir, "0_gaussian_mixture_2D")

pkl_candidates = sorted(glob(os.path.join(base_path, "*.pkl")))
pt_candidates = sorted(glob(os.path.join(base_path, "*.pt")))
if not pkl_candidates or not pt_candidates:
    raise FileNotFoundError(f"No .pkl/.pt checkpoint pair found in {base_path!r}")
pkl_path = pkl_candidates[-1]
pt_path = pt_candidates[-1]

# Same reference GMM as in the all-dims loop, specialised to dim = 2.
dim = 2
distribution = GMM(
    loc=torch.stack([-torch.ones(dim), torch.ones(dim)]),
    covariance_matrix=torch.stack([torch.diag(torch.tensor([5., .5])), torch.eye(dim)]),
    mixture_weights=torch.ones(2) / 2,
)
ref_dist = distribution

model = from_checkpoint(pkl_path, pt_path)

# Seed torch so the sampled points (and hence all plots below) are reproducible.
torch.manual_seed(42)
with torch.no_grad():
    ds = distribution.sample([1000])
    # Push samples through `model.backward` (presumably data -> latent; stored as
    # `latents`) and centre them on the base distribution's location.
    latents = model.backward(ds) - model.base_distribution.loc

## Plot Jacobian Determinant

In [None]:
from matplotlib.patches import Circle

# Radii of the dashed reference circles, as multiples of the base distribution's scale.
CIRCLE_MULTIPLIERS = (1., 1.5, 2., 2.5, 3.)

# ---- Data space: reference-density contours, points coloured by |det J| ----
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_facecolor('white')

x = ds[:, 0].numpy()
y = ds[:, 1].numpy()

# Evaluate the exact reference density (not a KDE) on a 100x100 grid spanning the samples.
x_grid, y_grid = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
grid_points = torch.from_numpy(np.stack([x_grid.ravel(), y_grid.ravel()], axis=1)).float()
with torch.no_grad():
    density = torch.exp(ref_dist.log_prob(grid_points)).numpy().reshape(x_grid.shape)

ax.contour(x_grid, y_grid, density, levels=8, colors='black', linewidths=0.5)

with torch.no_grad():
    abs_det = torch.exp(model.log_abs_det_jacobian(ds))
    # The model may return a single scalar; broadcast it to one value per point.
    if abs_det.dim() == 0:
        abs_det = abs_det.expand(len(ds))
    abs_det = abs_det.numpy()

scatter = ax.scatter(x, y, s=5, c=abs_det, cmap="magma", alpha=1)
fig.colorbar(scatter, ax=ax, label='$|\\det J_{\\phi}(x)|$')

ax.set_title('Data Distribution 2D GMM')
ax.set_aspect('equal')
ax.grid(alpha=0.2)
fig.tight_layout()
fig.savefig(f"contour_2d_gmm_{arch}_log_abs_det_jac.png")

# ---- Latent space: same colouring, with dashed circles at multiples of the base scale ----
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_facecolor('white')

scatter = ax.scatter(latents[:, 0].numpy(), latents[:, 1].numpy(), s=5, c=abs_det, cmap="magma", alpha=1)
fig.colorbar(scatter, ax=ax, label='$|\\det J_{\\phi}(x)|$')

ax.set_title(f'Centered Latent Data Distribution\n2D GMM ({arch})')

# Plain float radius for matplotlib. NOTE(review): assumes the base distribution has a
# single scalar scale parameter — `.item()` fails otherwise; confirm.
scale = softplus(model.base_distribution.scale_unconstrained).detach().item()
for multiplier in CIRCLE_MULTIPLIERS:
    ax.add_patch(Circle((0., 0.), radius=multiplier * scale, fill=False, edgecolor='black', linewidth=.5, linestyle='--'))

ax.set_aspect('equal')
ax.grid(alpha=0.2)
fig.tight_layout()
fig.savefig(f"contour_latent_2d_gmm_{arch}_log_abs_det_jac.png")

plt.show()

## Plot Densities

In [None]:
from matplotlib.patches import Circle

# Radii of the dashed reference circles, as multiples of the base distribution's scale.
CIRCLE_MULTIPLIERS = (1., 1.5, 2., 2.5, 3.)

# ---- Data space: reference-density contours, points coloured by reference density ----
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_facecolor('white')

x = ds[:, 0].numpy()
y = ds[:, 1].numpy()

# Evaluate the exact reference density (not a KDE) on a 100x100 grid spanning the samples.
x_grid, y_grid = np.mgrid[x.min():x.max():100j, y.min():y.max():100j]
grid_points = torch.from_numpy(np.stack([x_grid.ravel(), y_grid.ravel()], axis=1)).float()
with torch.no_grad():
    density = torch.exp(ref_dist.log_prob(grid_points)).numpy().reshape(x_grid.shape)
    point_density = torch.exp(ref_dist.log_prob(ds)).numpy()

ax.contour(x_grid, y_grid, density, levels=8, colors='black', linewidths=0.5)

scatter = ax.scatter(x, y, s=5, c=point_density, cmap="magma", alpha=1)
fig.colorbar(scatter, ax=ax, label='Data Density')

ax.set_title('Data Distribution 2D GMM')
ax.set_aspect('equal')
ax.grid(alpha=0.2)
fig.tight_layout()
fig.savefig(f"contour_2d_gmm_{arch}_density.png")

# ---- Latent space: same colouring (no colorbar), with dashed circles at multiples of the base scale ----
fig, ax = plt.subplots(figsize=(5, 5))
ax.set_facecolor('white')

ax.scatter(latents[:, 0].numpy(), latents[:, 1].numpy(), s=5, c=point_density, cmap="magma", alpha=1)

ax.set_title(f'Centered Latent Data Distribution\n2D GMM ({arch})')

# Recompute the scale here so this cell does not depend on state leaked from the
# previous cell. NOTE(review): assumes a single scalar scale parameter — confirm.
scale = softplus(model.base_distribution.scale_unconstrained).detach().item()
for multiplier in CIRCLE_MULTIPLIERS:
    ax.add_patch(Circle((0., 0.), radius=multiplier * scale, fill=False, edgecolor='black', linewidth=.5, linestyle='--'))

ax.set_aspect('equal')
ax.grid(alpha=0.2)
fig.tight_layout()
fig.savefig(f"contour_latent_2d_gmm_{arch}_density.png")

plt.show()