In [None]:
%load_ext autoreload
%autoreload 2

from guided_flow.distributions.checkerboard import CheckerboardDistribution
from guided_flow.distributions.circle import CircleDistribution, ConcentricCircleDistribution
from guided_flow.distributions.moon import MoonDistribution
from guided_flow.distributions.uniform import UniformDistribution
from guided_flow.distributions.gaussian import EightGaussiansDistribution, GaussianDistribution
from guided_flow.distributions.s_curve import SCurveDistribution
from guided_flow.distributions.laplace import LaplaceDistribution
from guided_flow.distributions.spiral import SpiralDistribution
from sklearn.neighbors import KernelDensity as KDE
import matplotlib.pyplot as plt
import numpy as np
import torch

# Gallery of the 2-D toy distributions: one subplot per distribution, showing
# 500 samples over a log-density heatmap. Distributions whose str() appears in
# `analytic_density` are shaded from their own `log_prob`; all others use a
# KDE fitted to the drawn samples.
cols = 3
distributions = [
    LaplaceDistribution, CheckerboardDistribution, CircleDistribution,
    ConcentricCircleDistribution, MoonDistribution, UniformDistribution,
    EightGaussiansDistribution, GaussianDistribution, SCurveDistribution,
    SpiralDistribution,
]
# Enough rows to hold every distribution (4 for the 10 listed above).
rows = (len(distributions) + cols - 1) // cols
# Fall back to CPU so the cell also runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
analytic_density = ("LaplaceDistribution", "GaussianDistribution")

fig, ax = plt.subplots(rows, cols)
for i, dist_cls in enumerate(distributions):
    print(dist_cls.__name__)

    dist = dist_cls()
    x = dist.sample(500, device=device).cpu()

    # ax.flat walks the subplot grid row-major, i.e. ax[i // cols, i % cols].
    axis = ax.flat[i]
    axis.scatter(x[:, 0], x[:, 1], s=2)
    axis.set_title(str(dist))

    if str(dist) in analytic_density:
        # Heatmap of the exact log-density from the distribution's log_prob.
        lin = torch.linspace(-5, 5, 100)
        # indexing='xy' makes the row index follow the second coordinate,
        # which imshow(origin='lower') puts on the vertical axis; this matches
        # np.meshgrid in the KDE branch. torch's default 'ij' drew it transposed.
        gx, gy = torch.meshgrid(lin, lin, indexing='xy')
        pts = torch.cat([gx.reshape(-1, 1), gy.reshape(-1, 1)], dim=1)
        # detach()/cpu() are no-ops for a plain CPU tensor, but keep .numpy()
        # from failing if log_prob returns a grad-tracking or device tensor.
        log_density = dist.log_prob(pts).detach().cpu().numpy().reshape(100, 100)
        axis.imshow(log_density, extent=[-5, 5, -5, 5], origin='lower', cmap='viridis')

        axis.set_xlim(-5, 5)
        axis.set_ylim(-5, 5)
    else:
        # Heatmap of a log-density estimated by a KDE fitted to the samples.
        kde = KDE(bandwidth=0.1, kernel='exponential')
        kde.fit(x)
        lin = np.linspace(-3, 3, 100)
        gx, gy = np.meshgrid(lin, lin)  # default 'xy' indexing
        pts = np.concatenate([gx.reshape(-1, 1), gy.reshape(-1, 1)], axis=1)
        axis.imshow(kde.score_samples(pts).reshape(100, 100), extent=[-3, 3, -3, 3], origin='lower', cmap='viridis')

        axis.set_xlim(-1.5, 1.5)
        axis.set_ylim(-1.5, 1.5)
    axis.set_aspect('equal')

# Hide the grid cells left over after the last distribution.
for axis in ax.flat[i + 1:]:
    axis.set_visible(False)
    axis.set_axis_off()

# Extra vertical space between rows so titles do not overlap the plot above.
fig.subplots_adjust(hspace=0.5)
fig.set_size_inches(6 * rows / 3, 10)

In [None]:
from guided_flow.distributions.base import get_distribution

def plot_J_weighted_log_prob(x1_dist: str, scale: float = 1.0, lim: float = 3.0, n: int = 100):
    """Plot ``log_prob(x) - scale * get_J(x)`` for a named distribution on a square grid.

    Args:
        x1_dist: Name handed to ``get_distribution`` to build the distribution.
        scale: Weight applied to ``get_J`` before subtracting it from the
            log-density.
        lim: The grid spans ``[-lim, lim]`` on both axes (default 3).
        n: Number of grid points per axis (default 100).

    If the distribution's ``log_prob`` raises, the error is printed, the
    log-density term is taken as 0, and 1000 samples are scattered instead so
    the support is still visible. The figure is displayed with ``plt.show()``.
    """
    # Fall back to CPU so the function also runs on machines without a GPU.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    dist = get_distribution(x1_dist)
    # NOTE(review): the result of this call was never used; kept only in case
    # get_J has first-call side effects -- confirm and remove.
    dist.get_J(torch.randn(1000, 2, device=device))

    axis_pts = torch.linspace(-lim, lim, n)
    X, Y = torch.meshgrid(axis_pts, axis_pts, indexing="ij")
    XY = torch.stack([X, Y], dim=-1).reshape(-1, 2).to(device)
    try:
        log_prob = dist.log_prob(XY)
    except Exception as e:
        # Deliberate best-effort: some distributions expose no usable log_prob.
        print("No log_prob for", dist, e)
        log_prob = 0.
        sample = dist.sample(1000, device=device).cpu()
        plt.scatter(sample[:, 0], sample[:, 1], s=2)
    p = -dist.get_J(XY) * scale + log_prob
    # With 'ij' indexing the rows follow x, so transpose to put x horizontally
    # and y vertically under origin='lower'. detach() lets .numpy() succeed
    # even if get_J/log_prob return grad-tracking tensors.
    img = p.reshape(n, n).t().detach().cpu().numpy()
    plt.imshow(img, cmap='jet', extent=[-lim, lim, -lim, lim], origin='lower')
    plt.colorbar()
    plt.show()


plot_J_weighted_log_prob('spiral')