In [None]:
# Low-probability estimation with a Gaussian example distribution and a differentiable cost function:
# estimate Z = P_prior(cost(x) < cost_threshold) via Langevin sampling followed by importance sampling.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import matplotlib.pyplot as plt
import numpy as np

import tqdm

from langevin import sample_langevin_conditional, SampleConfig, AnimateConfig

# Rare-event threshold: the event of interest is cost_function(x) < cost_threshold.
cost_threshold = -7.5

def cost_function(xs):
    """Differentiable cost c(x, y) = y**2 + x for 2-D points.

    Args:
        xs: tensor whose last dimension holds the point coordinates; component 0
            is x and component 1 is y. Any leading batch shape is accepted,
            e.g. (n, 2) or (h, w, 2).

    Returns:
        Tensor of costs with shape ``xs.shape[:-1]``.
    """
    # Index the last axis with `...` so whole grids can be evaluated without reshaping;
    # for (n, 2) input this is identical to indexing with xs[:, k].
    return xs[..., 1].pow(2) + xs[..., 0]

def prior_density(xs):
    """Standard-normal density N(0, I) evaluated along the last axis of ``xs``.

    Returns a tensor of shape ``xs.shape[:-1]``.
    """
    dim = xs.shape[-1]
    squared_norm = xs.pow(2).sum(dim=-1)
    normalizer = (2 * np.pi) ** (dim / 2)
    return torch.exp(-squared_norm / 2) / normalizer

def log_prior_density(xs):
    """Log of the standard-normal density N(0, I) along the last axis of ``xs``."""
    half_dim = xs.shape[-1] / 2
    log_norm = np.log(2 * np.pi) * half_dim
    return -0.5 * xs.pow(2).sum(dim=-1) - log_norm

# Grid settings for brute-force integration and plotting.
img_size = 500  # points per axis

full_lims = ((-10, 10), (-10, 10))  # ((x_min, x_max), (y_min, y_max))
focus_lims = ((-10, 0), (-5, 5))    # zoomed window around the low-cost region


def _make_grid(lims, n, dtype=torch.float64):
    """Return (xs, ys) coordinate tensors of shape (n, n) covering ``lims`` ("ij" indexing).

    float64 by default so the tiny tail probabilities survive the grid sums.
    """
    (x_lo, x_hi), (y_lo, y_hi) = lims
    return torch.meshgrid(
        torch.linspace(x_lo, x_hi, n, dtype=dtype),
        torch.linspace(y_lo, y_hi, n, dtype=dtype),
        indexing="ij",
    )


def _grid_cell_area(lims, n):
    """Area dx * dy of one cell of the n-by-n grid produced by ``_make_grid(lims, n)``.

    ``torch.linspace(lo, hi, n)`` spaces its n points by (hi - lo) / (n - 1),
    so that is the spacing a Riemann sum over the grid must use.
    """
    (x_lo, x_hi), (y_lo, y_hi) = lims
    return ((x_hi - x_lo) / (n - 1)) * ((y_hi - y_lo) / (n - 1))


xs_full, ys_full = _make_grid(full_lims, img_size)
xs_focus, ys_focus = _make_grid(focus_lims, img_size)
full_cell_area = _grid_cell_area(full_lims, img_size)
focus_cell_area = _grid_cell_area(focus_lims, img_size)

# Prior density and cost over the full window.
points_full = torch.stack([xs_full, ys_full], dim=-1)
density_full = prior_density(points_full)
cost_full = cost_function(points_full.reshape(-1, 2)).reshape(points_full.shape[:-1])

plt.figure(figsize=(10, 8))
plt.contourf(xs_full.numpy(), ys_full.numpy(), cost_full.numpy(), levels=20, cmap='inferno')
plt.colorbar(label='Cost')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Cost Function')
plt.show()

plt.figure(figsize=(10, 8))
plt.contourf(xs_full.numpy(), ys_full.numpy(), density_full.numpy(), levels=20, cmap='inferno')
plt.colorbar(label='Density')
# Red line: the boundary cost(x) == cost_threshold of the rare event.
plt.contour(xs_full.numpy(), ys_full.numpy(), cost_full.numpy(), levels=[cost_threshold], colors='red')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Prior Density')
plt.show()

# Prior restricted to the event {cost < cost_threshold}, renormalized over the focus window.
points_focus = torch.stack([xs_focus, ys_focus], dim=-1)
density_focus = prior_density(points_focus)
cost_focus = cost_function(points_focus.reshape(-1, 2)).reshape(points_focus.shape[:-1])

conditional_density_focus = density_focus.clone()
conditional_density_focus[cost_focus >= cost_threshold] = 0
# Grid integral = sum of values times the full 2-D cell area (dx * dy).
conditional_density_focus /= conditional_density_focus.sum() * focus_cell_area

plt.figure(figsize=(10, 8))
plt.contourf(xs_focus.numpy(), ys_focus.numpy(), conditional_density_focus.numpy(), levels=20, cmap='inferno')
plt.colorbar(label='Density')
plt.xlabel('x')
plt.ylabel('y')
plt.title('Conditional Density')
plt.show()

# Sanity check: the prior should integrate to ~1 over the full window.
integral = density_full.sum() * full_cell_area
print(f"Integral of density: {integral.item():.6f}")

# Grid estimate of the rare-event probability Z = P_prior(cost(x) < cost_threshold).
conditional_density_full = density_full.clone()
conditional_density_full[cost_full >= cost_threshold] = 0
normalizing_constant = conditional_density_full.sum() * full_cell_area

print(f"Grid normalizing constant: {normalizing_constant.item():.6e}")

In [None]:
# Seed torch's global RNG so the Langevin run here and the proposal draws in later
# cells are reproducible under Restart & Run All.
seed = 0
torch.manual_seed(seed)

n_initial = 10_000  # number of starting points, drawn from the standard-normal prior

sample_config = SampleConfig(
    steps=10_000,
    start_beta=1,
    end_beta=100,
    animate=AnimateConfig(
        output_dir="langevin_samples",
        capture_every=100,
        duration=10000,
        size=200,
        d_lims=(0, 5),
        # Animate the same zoomed window used for the conditional-density plot.
        xlims=focus_lims[0],
        ylims=focus_lims[1],
    ),
    progress=True,
)

# Starting points are N(0, I) draws (the prior), in float64.
initial_samples = torch.randn(n_initial, 2, dtype=torch.float64)

samples = sample_langevin_conditional(
    initial_samples,
    log_prior_density,
    cost_function,
    cost_threshold,
    sample_config,
)

plt.figure(figsize=(10, 8))
plt.hist2d(samples[:, 0].numpy(), samples[:, 1].numpy(), bins=100, cmap='inferno', range=[full_lims[0], full_lims[1]], density=True)
plt.colorbar(label='Density')
plt.title('Samples')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

In [3]:
def kernel_density(xs, ps, bandwidth=1):
    """Gaussian kernel density estimate built from samples ``ps``, evaluated at ``xs``.

    Computes q(x) = (1/m) * sum_j N(x; p_j, bandwidth**2 * I).

    Args:
        xs: (n, d) query points.
        ps: (m, d) samples defining the estimate; must be non-empty.
        bandwidth: kernel standard deviation, shared by every dimension.

    Returns:
        (n,) tensor of density values, one per query point.

    Raises:
        ValueError: if ``ps`` contains no samples (the estimate is undefined).

    Note: materializes an (n, m, d) difference tensor, so memory grows as n * m * d.
    """
    num_samples, dim = ps.shape
    if num_samples == 0:
        raise ValueError("kernel_density needs at least one sample in ps")

    # Pairwise squared distances between queries and samples, shape (n, m).
    sq_dists = (xs[:, None, :] - ps[None, :, :]).pow(2).sum(dim=-1)
    kernel_sums = torch.exp(-(sq_dists / (2 * bandwidth**2))).sum(dim=-1)

    # Normalizer of an isotropic d-dimensional Gaussian, times the sample count (the 1/m average).
    normalizer = (2 * np.pi * bandwidth**2) ** (dim / 2) * num_samples
    return kernel_sums / normalizer

In [4]:
# xs, ys = torch.meshgrid(torch.linspace(full_lims[0][0], full_lims[0][1], 100), torch.linspace(full_lims[1][0], full_lims[1][1], 100))

# grid_points = torch.stack([xs.flatten(), ys.flatten()], dim=1)

# print(grid_points.shape)
# print(samples.shape)

# density = kernel_density(grid_points, samples[len(samples)//2:], bandwidth=0.1)

# plt.figure(figsize=(8, 8))
# plt.contourf(xs.numpy(), ys.numpy(), density.reshape(100, 100).numpy())
# plt.colorbar()

# costs = cost_function(grid_points)
# plt.contour(xs.numpy(), ys.numpy(), costs.reshape(100, 100).numpy(), levels=[cost_threshold], colors='red')

# plt.title('Kernel Density Estimate')
# plt.xlabel('x')
# plt.ylabel('y')
# plt.show()

# # Integrate the density over the grid to verify it sums to approximately 1
# dx = (full_lims[0][1] - full_lims[0][0]) / (xs.shape[0] - 1)
# dy = (full_lims[1][1] - full_lims[1][0]) / (ys.shape[0] - 1)

# integral = density.sum() * dx * dy
# print(f"Integral of density: {integral.item():.4f}")  # Should be close to 1.0

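# # NOTE: duplicate of prior_density from the first cell -- reuse that one instead.
# # (This copy divided by sqrt(2*pi); a 2-D standard normal is normalized by 2*pi.)
# def prior_density(xs):
#     return torch.exp(-(xs**2).sum(dim=1) / 2) / (2 * np.pi)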

# q_values = kernel_density(samples[:len(samples)//2], samples[len(samples)//2:], bandwidth=0.1)
# p_values = prior_density(samples[:len(samples)//2])

# print((p_values / q_values).mean())

In [None]:
# Fit a Gaussian proposal q(x) = N(mean, cov) to the Langevin samples.
# cov_scale > 1 widens the proposal. A proposal with lighter tails than the target
# can give high-variance importance weights, so widening is a common safeguard;
# 1.0 uses the empirical covariance unchanged.
cov_scale = 1.0

mean = samples.mean(dim=0)
cov = torch.cov(samples.T) * cov_scale

print("Fitted Gaussian parameters:")
print(f"Mean: {mean}")
print(f"Covariance matrix:\n{cov}")

# Visualize the fitted Gaussian on a coarse grid over the full window.
n_grid = 100
xs, ys = torch.meshgrid(
    torch.linspace(full_lims[0][0], full_lims[0][1], n_grid),
    torch.linspace(full_lims[1][0], full_lims[1][1], n_grid),
    indexing="ij",
)
grid_points = torch.stack([xs.flatten(), ys.flatten()], dim=1)

# Gaussian density via the Mahalanobis distance.
diff = grid_points - mean.unsqueeze(0)
inv_cov = torch.inverse(cov)
mahalanobis = torch.sum(diff @ inv_cov * diff, dim=1)
n_dim = mean.shape[0]
density = torch.exp(-0.5 * mahalanobis) / ((2 * np.pi) ** (n_dim / 2) * torch.sqrt(torch.det(cov)))

plt.figure(figsize=(8, 8))
plt.contourf(xs.numpy(), ys.numpy(), density.reshape(n_grid, n_grid).numpy())
plt.colorbar(label='Density')
# Overlay the rare-event boundary to check that the proposal covers it.
grid_costs = cost_function(grid_points).reshape(n_grid, n_grid)
plt.contour(xs.numpy(), ys.numpy(), grid_costs.numpy(), levels=[cost_threshold], colors='red')
plt.title('Fitted Gaussian Distribution')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

In [None]:
# Importance sampling with the fitted Gaussian proposal q(x):
#   Z = P_prior(cost(x) < cost_threshold) = E_q[ 1{cost(x) < cost_threshold} * p(x) / q(x) ].
n_samples = 100_000
# Build the proposal in float64 so sampling and log_prob both work in float64.
proposal = torch.distributions.MultivariateNormal(mean.to(torch.float64), cov.to(torch.float64))
proposal_samples = proposal.sample((n_samples,))

plt.figure(figsize=(8, 8))
plt.hist2d(proposal_samples[:, 0].numpy(), proposal_samples[:, 1].numpy(), bins=100, cmap='inferno', range=[full_lims[0], full_lims[1]], density=True)
plt.colorbar(label='Density')
plt.title('Proposal Samples')
plt.xlabel('x')
plt.ylabel('y')
plt.show()


def log_proposal_density(x):
    """Log-density of the proposal N(mean, cov) at each row of ``x``; returns shape (n,)."""
    return proposal.log_prob(x)


mask = cost_function(proposal_samples) < cost_threshold
log_weights = log_prior_density(proposal_samples[mask]) - log_proposal_density(proposal_samples[mask])
# Out-of-event samples have weight 0 but still count towards n_samples, hence the
# subtraction of log(n_samples). logsumexp is numerically stable and yields -inf
# (an estimate of 0) when no sample lands in the event, instead of raising.
log_Z = torch.logsumexp(log_weights, dim=0) - np.log(n_samples)

print(f"Proposal samples inside the event: {int(mask.sum())} / {n_samples}")
print(f"Estimated log normalizing constant: {log_Z.item():.2f}")
print(f"Estimated normalizing constant: {torch.exp(log_Z).item():.2e}")
print(f"Grid-integrated normalizing constant: {normalizing_constant.item():.2e}")