In [1]:
import torch
import os
from dem.energies.gmm_energy import GMM
import matplotlib.pyplot as plt
import umap
import numpy as np
from scipy.stats import gaussian_kde
import warnings

warnings.filterwarnings("ignore", category=FutureWarning)

  from pkg_resources import DistributionNotFound, get_distribution


In [2]:
# Common prefix of the per-dimension output folders; the dimension number
# is appended directly to it (e.g. "outputs/gmm_8modes_dim2").
base_path = "outputs/gmm_8modes_dim"

# Dimensionalities to visualise.
# For the full sweep use: [2, 4, 8, 16, 32, 64, 128]
dims = [2]

# One generated-sample file per requested dimensionality.
gen_files = [
    f"{base_path}{dim_value}/train/samples_10000.pt"
    for dim_value in dims
]

In [3]:
def _save_contour_scatter(X, Y, Z, points, label, title, out_path):
    """Draw density contours of one sample set, scatter the other set on top, and save the figure."""
    fig, ax = plt.subplots(figsize=(7, 6))
    ax.contour(X, Y, Z, levels=8, cmap="inferno", linewidths=1.1)
    ax.scatter(points[:, 0], points[:, 1], s=8, alpha=0.6, label=label)
    ax.set_title(title)
    ax.legend()
    fig.savefig(out_path, dpi=300, bbox_inches="tight")
    # Close explicitly: this is called in a loop over dimensions.
    plt.close(fig)


def create_and_save_plots(gen_samples_unnorm, train_samples_unnorm, dim, save_dir, n_subsample=10000):
    """Save two 2-D comparison figures of generated vs. reference samples.

    Figure 1 (``gen_vs_realContours_dim{dim}_.png``): KDE contours of the
    reference (training) samples with the generated samples scattered on top.
    Figure 2 (``real_vs_genContours_dim{dim}_.png``): KDE contours of the
    generated samples with the reference samples scattered on top.

    Args:
        gen_samples_unnorm: torch tensor of generated samples, one row per sample.
        train_samples_unnorm: torch tensor of reference samples, one row per sample.
        dim: dimensionality of the samples. When ``dim > 2`` both sets are
            projected to 2-D with a UMAP fitted on the reference subsample;
            otherwise the first two columns are plotted directly.
        save_dir: directory the PNG files are written to (created if missing).
        n_subsample: maximum number of reference samples used for UMAP, KDE and
            plotting. Clamped to the number of available reference samples.

    Returns:
        None. The function's only effect is writing the two PNG files.
    """
    # Subsample the reference set without replacement; clamp the size so a
    # small reference set does not make np.random.choice raise.
    n_train = train_samples_unnorm.shape[0]
    idx = np.random.choice(n_train, min(n_subsample, n_train), replace=False)
    train_samples_sub = train_samples_unnorm[idx]

    # detach()/cpu() make the conversion safe for tensors that track gradients
    # or live on an accelerator; both are no-ops for plain CPU tensors.
    train_np = train_samples_sub.detach().cpu().numpy()
    gen_np = gen_samples_unnorm.detach().cpu().numpy()

    # --- UMAP reduction if needed ---
    if dim > 2:
        # Fit on the reference samples only, then map the generated samples
        # into the same embedding so both sets share one coordinate system.
        reducer = umap.UMAP(n_components=2, n_jobs=1)
        train_samples_2d = reducer.fit_transform(train_np)
        gen_samples_2d = reducer.transform(gen_np)
    else:
        train_samples_2d = train_np
        gen_samples_2d = gen_np

    # --- Build a shared grid covering both sample sets ---
    x = np.linspace(min(train_samples_2d[:, 0].min(), gen_samples_2d[:, 0].min()),
                    max(train_samples_2d[:, 0].max(), gen_samples_2d[:, 0].max()), 200)
    y = np.linspace(min(train_samples_2d[:, 1].min(), gen_samples_2d[:, 1].min()),
                    max(train_samples_2d[:, 1].max(), gen_samples_2d[:, 1].max()), 200)
    X, Y = np.meshgrid(x, y)
    grid_T = np.vstack([X.ravel(), Y.ravel()])  # shape (2, n_grid_points), as gaussian_kde expects

    # --- KDE of each sample set, evaluated on the shared grid ---
    Z_real = gaussian_kde(train_samples_2d.T, bw_method=0.2)(grid_T).reshape(X.shape)
    Z_gen = gaussian_kde(gen_samples_2d.T, bw_method=0.2)(grid_T).reshape(X.shape)

    # os.makedirs("") raises, so only create a non-empty directory path.
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    # --- Plot 1: Real contours + Generated scatter ---
    _save_contour_scatter(
        X, Y, Z_real, gen_samples_2d, "Generated",
        f"Generated vs Real Contours (Dim {dim})",
        os.path.join(save_dir, f"gen_vs_realContours_dim{dim}_.png"),
    )

    # --- Plot 2: Generated contours + Real scatter ---
    _save_contour_scatter(
        X, Y, Z_gen, train_samples_2d, "Real",
        f"Real vs Generated Contours (Dim {dim})",
        os.path.join(save_dir, f"real_vs_genContours_dim{dim}_.png"),
    )

In [4]:
for gen_path in gen_files:
    # The dimension is encoded in the path right after "dim",
    # e.g. "outputs/gmm_8modes_dim2/train/..." -> 2.
    dim = int(gen_path.split("dim")[1].split("/")[0])

    # Build the GMM energy for this dimensionality.
    # NOTE(review): these settings presumably have to match the ones used when
    # the samples were generated — confirm against the training config.
    gmm_obj = GMM(
        dimensionality=dim,
        n_mixes=8,
        loc_scaling=8,
        log_var_scaling=1.0,
        data_normalization_factor=11,
        device="cpu",
        train_set_size=100000,
    )

    # Reference (training) samples, mapped back to the unnormalized space.
    train_samples = gmm_obj.setup_train_set().cpu()
    train_samples_unnorm = gmm_obj.unnormalize(train_samples)

    # Generated samples. map_location="cpu" lets files that were saved from a
    # CUDA device load on a CPU-only machine; a plain torch.load(...).cpu()
    # would fail during deserialization before .cpu() is ever reached.
    # NOTE(review): torch.load unpickles the file — only load trusted outputs.
    gen_samples = torch.load(gen_path, map_location="cpu")
    gen_samples_unnorm = gmm_obj.unnormalize(gen_samples)

    # Figures are written next to the sample file they were produced from.
    save_dir = os.path.dirname(gen_path)

    # Create and save plots
    create_and_save_plots(gen_samples_unnorm, train_samples_unnorm, dim, save_dir)