In [1]:
# Low-probability (rare-event) estimation, using a standard Gaussian as the example distribution and a differentiable cost function.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.func import grad

import matplotlib.pyplot as plt
import numpy as np

import tqdm

# define the cost function
def cost_function(x, y):
    """Differentiable cost: the square of `x` plus `y` (elementwise for tensors)."""
    squared = x**2
    return squared + y

# Threshold for the rare event whose probability, P(cost < cost_threshold), is estimated below.
cost_threshold = -5

def density(x, y):
    """Standard bivariate normal density N(0, I) evaluated at (x, y).

    Works elementwise on broadcastable tensors `x` and `y`. The normalising
    constant of a 2-D standard normal is 1 / (2*pi), so this integrates to 1
    over the plane (1 / sqrt(2*pi) is only correct in one dimension).
    """
    return torch.exp(-(x**2 + y**2) / 2) / (2 * np.pi)

In [None]:
# xlims = (-10, 10)   
# ylims = (-10, 10)

xlims = (-10, 10)
ylims = (-10, 10)

# Grid layout: rows (dim 0) are labelled by `xs`, columns (dim 1) by `ys`.
# The plots below put `ys` on the horizontal axis and `xs` on the vertical one,
# which is why the columns span xlims and the rows span ylims.
xs, ys = torch.linspace(ylims[0], ylims[1], 500), torch.linspace(xlims[0], xlims[1], 500)
# 'ij' is torch's default layout; passing it explicitly silences the meshgrid deprecation warning.
xs, ys = torch.meshgrid(xs, ys, indexing='ij')

densities = density(xs, ys)

print(densities.shape)

# Calculate cost function values
costs = cost_function(xs, ys)

print(costs.shape)

plt.contour(ys, xs, costs, levels=[cost_threshold], colors='red')
plt.xlim(xlims)
plt.ylim(ylims)
# origin='lower' draws row 0 (the smallest `xs`) at the bottom, so the image agrees
# with `extent` and with the contour; the default origin='upper' flips it vertically.
plt.imshow(densities, extent=xlims + ylims, origin='lower', cmap='plasma')
plt.colorbar()
plt.show()

In [None]:
# Estimate P(cost < threshold) as the integral over the plane of
# 1[cost < threshold] * density, approximated by a Riemann sum on the grid.

# Indicator mask as float32: 1.0 where the cost is below the threshold, 0.0 elsewhere.
indicator = (costs < cost_threshold).to(torch.float32)

print(indicator.shape, densities.shape)

# Density restricted to the rare-event region.
integrand = densities * indicator

# Grid spacing: columns (last dim) span xlims, rows span ylims.
dx = (xlims[1] - xlims[0]) / (costs.shape[-1] - 1)
dy = (ylims[1] - ylims[0]) / (costs.shape[-2] - 1)
area_element = dx * dy

# Riemann sum: total integrand mass times the area of one grid cell.
probability = area_element * integrand.sum()

print(f"Probability that cost is less than {cost_threshold}: {probability:.2e}")

In [None]:
# Plot the conditional density p(x | cost(x) < threshold) on the grid: zero out
# points outside the region, then renormalise so the grid Riemann sum equals 1.
conditional_density = densities.clone()
conditional_density[costs >= cost_threshold] = 0

# Normalise with the Riemann-sum cell area (spacing = range / (points - 1)), the
# same rule used for the probability estimate, so both cells agree on the integral.
cell_area = ((xlims[1] - xlims[0]) / (costs.shape[1] - 1)) * ((ylims[1] - ylims[0]) / (costs.shape[0] - 1))
mass = conditional_density.sum() * cell_area
if mass <= 0:
    raise ValueError("no grid points satisfy cost < cost_threshold; widen xlims/ylims or refine the grid")
conditional_density = conditional_density / mass

# origin='lower' keeps row 0 (smallest `xs`) at the bottom, matching `extent`.
plt.imshow(conditional_density, extent=xlims + ylims, origin='lower', cmap='plasma')
plt.colorbar()
plt.show()

In [None]:
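# To sample (approximately) from the conditional density p(x | cost(x) < threshold), run Langevin dynamics on the relaxed target p'(x) \propto p(x) * exp(-max(0, \beta * (cost(x) - threshold))); as \beta grows, the penalty approaches the hard indicator 1[cost(x) < threshold].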

# def prior_density(xs):
#     # xs: (n, 2)
#     return torch.exp(-(xs**2).sum(dim=1) / 2) / np.sqrt(2 * np.pi)

# cost_threshold = -5

def log_prior_density(xs):
    """Unnormalised log density of a standard normal, evaluated row-wise.

    xs: (n, 2) tensor of points. Returns an (n,) tensor equal to -||x||^2 / 2.
    The -log(2*pi) constant is dropped; it does not affect Langevin gradients.
    """
    squared_norms = xs.pow(2).sum(dim=1)
    return -squared_norms / 2

def log_cost(xs):
    """Cost of every row of `xs`: column 1 squared plus column 0.

    This equals cost_function(x=xs[:, 1], y=xs[:, 0]). Despite the name, no
    logarithm is taken. xs: (n, 2) tensor; returns an (n,) tensor.
    """
    first, second = xs.unbind(dim=1)
    return second**2 + first

def log_conditional_density(xs, beta):
    """Unnormalised log density of the relaxed conditional target.

    log p_beta(x) = log_prior_density(x) - relu(beta * (log_cost(x) - cost_threshold)):
    the prior log density plus a hinge penalty that is zero where
    cost < cost_threshold and grows with slope `beta` outside that region.

    Raises:
        ValueError: if any entry of the result is NaN (inputs and both terms
            are printed first for debugging).
    """
    hinge_term = -torch.relu(beta * (log_cost(xs) - cost_threshold))
    prior_term = log_prior_density(xs)
    log_target = prior_term + hinge_term

    if torch.isnan(log_target).any():
        for diagnostic in (xs, prior_term, hinge_term):
            print(diagnostic)
        raise ValueError("NaN values in conditional density")

    return log_target

def sample_step(xs_t, beta, step_size):
    """Take one unadjusted Langevin step targeting log_conditional_density.

    Update: x' = x + step_size * grad log p_beta(x) + sqrt(2 * step_size) * N(0, I).

    Args:
        xs_t: (n, 2) tensor of current particle positions. It is not modified;
            in particular its requires_grad flag is left untouched.
        beta: penalty strength forwarded to log_conditional_density.
        step_size: Langevin step size.

    Returns:
        (n, 2) tensor of updated positions, detached from the autograd graph.

    Raises:
        ValueError: if the log density or its gradient contains NaNs.
    """
    # Differentiate through a detached leaf copy instead of flipping the caller's tensor.
    xs_t = xs_t.detach().requires_grad_(True)
    log_density = log_conditional_density(xs_t, beta)
    grad_log_density = torch.autograd.grad(log_density.sum(), xs_t)[0]

    nan_rows = torch.isnan(grad_log_density).any(dim=1)
    if nan_rows.any():
        # Recompute the gradient for just the offending points to aid debugging.
        points_with_nan = xs_t[nan_rows].detach().requires_grad_(True)
        print(points_with_nan)
        bad_log_density = log_conditional_density(points_with_nan, beta)
        bad_grads = torch.autograd.grad(bad_log_density.sum(), points_with_nan)[0]
        print(bad_grads)
        raise ValueError("NaN values in gradient of log density")

    # Gradient drift plus Gaussian noise with variance 2 * step_size.
    noise = torch.randn_like(xs_t) * np.sqrt(2 * step_size)
    xs_next = xs_t + step_size * grad_log_density + noise

    return xs_next.detach()

import os

# Fixed seed so the chain and the saved frames are reproducible across runs.
torch.manual_seed(0)

# Start the chain from the prior: 10k independent 2-D standard-normal particles.
xs = torch.randn(10_000, 2, dtype=torch.float64)

print(xs.shape)

# Run chain with decaying parameters
n_steps = 10000
beta_start = 10.0
beta_end = 100.0
step_size_start = 0.001
step_size_end = 0.000001
plot_every = 100  # save one animation frame every `plot_every` steps

# savefig does not create missing directories; the GIF cell also reads frames from here.
os.makedirs('images', exist_ok=True)

image_idx = 0

for i in tqdm.tqdm(range(n_steps)):
    # Fraction of the schedule completed, in [0, 1]; max() guards n_steps == 1.
    progress = i / max(n_steps - 1, 1)
    # Exponentially decay step size
    step_size = step_size_start * (step_size_end / step_size_start)**progress
    # Linearly grow beta
    beta = beta_start + progress * (beta_end - beta_start)

    # Plot the current particle cloud and save it as an animation frame.
    if i % plot_every == 0:
        fig, ax = plt.subplots()
        *_, mesh = ax.hist2d(
            xs[:, 0].numpy(), xs[:, 1].numpy(),
            bins=100, range=[[-10, 0], [-5, 5]], cmap='plasma',
            density=True, vmin=0, vmax=6,
        )
        fig.colorbar(mesh, ax=ax)
        ax.set_title(f'beta: {beta:.2f}, lr: {step_size:.2e}')
        fig.savefig(f'images/xs_{image_idx}.png')
        plt.close(fig)  # avoid accumulating open figures across the loop
        image_idx += 1

    # Take MCMC step
    xs = sample_step(xs, beta, step_size)

samples = xs.detach()

In [6]:
import glob
from PIL import Image

# Create the GIF from saved images
import contextlib

# Stitch the frames saved by the sampling loop into an animated GIF.
filenames = [f"images/xs_{i}.png" for i in range(image_idx)]
if not filenames:
    raise ValueError("no frames to animate; run the sampling cell first")

# ExitStack closes every opened frame once the GIF has been written.
with contextlib.ExitStack() as stack:
    images = [stack.enter_context(Image.open(filename)) for filename in filenames]

    # Save the GIF
    images[0].save(
        'images/animation.gif',
        save_all=True,
        append_images=images[1:],
        duration=50,  # Duration for each frame in milliseconds
        loop=0
    )


In [7]:
def kernel_density(xs, ps, bandwidth=1, chunk_size=1024):
    """Gaussian kernel density estimate built on `ps`, evaluated at `xs`.

    q(x) = sum_j exp(-||x - p_j||^2 / (2 h^2)) / ((2 pi h^2)^(d/2) * m)

    Args:
        xs: (n, d) tensor of evaluation points.
        ps: (m, d) tensor of reference samples (m >= 1).
        bandwidth: kernel standard deviation h.
        chunk_size: number of evaluation points processed at once. This bounds
            the temporary (chunk_size, m, d) difference tensor instead of
            materialising the full (n, m, d) one.

    Returns:
        (n,) tensor of density values, in the promoted dtype of xs and ps.

    Raises:
        ValueError: if `ps` is empty or the point dimensions do not match.
    """
    n_ref, dim = ps.shape
    if n_ref == 0:
        raise ValueError("kernel_density needs at least one reference sample in `ps`")
    if xs.shape[1] != dim:
        raise ValueError(f"dimension mismatch: xs has {xs.shape[1]} columns, ps has {dim}")

    # Gaussian kernel normaliser times the number of reference samples.
    norm = (2 * np.pi * bandwidth**2)**(dim / 2) * n_ref

    chunks = []
    for start in range(0, xs.shape[0], chunk_size):
        block = xs[start:start + chunk_size]
        sq_dist = (block[:, None, :] - ps[None, :, :]).pow(2).sum(dim=-1)  # (chunk, m)
        chunks.append(torch.exp(-(sq_dist / (2 * bandwidth**2))).sum(dim=-1))

    if not chunks:  # no evaluation points
        return torch.zeros(0, dtype=torch.result_type(xs, ps), device=xs.device)
    return torch.cat(chunks) / norm

In [None]:
# Evaluate the KDE of the second half of the chain on a regular grid and overlay
# the cost-threshold boundary.
n_grid = 100
xs, ys = torch.meshgrid(
    torch.linspace(xlims[0], xlims[1], n_grid),
    torch.linspace(ylims[0], ylims[1], n_grid),
    indexing='ij',  # torch's default layout; explicit to silence the deprecation warning
)

grid_points = torch.stack([xs.flatten(), ys.flatten()], dim=1)

print(grid_points.shape)
print(samples.shape)

# NOTE(review): this rebinds `density`, shadowing the Gaussian `density(x, y)` function
# from the first cell, so re-running the grid-plot cell afterwards fails. The
# integration cell below reads this name; rename both together to fix.
density = kernel_density(grid_points, samples[len(samples)//2:], bandwidth=0.1)

plt.figure(figsize=(8, 8))
plt.contourf(xs.numpy(), ys.numpy(), density.reshape(n_grid, n_grid).numpy())
plt.colorbar()

# log_cost treats sample column 1 as cost_function's `x` and column 0 as its `y`,
# so the cost at grid point (xs, ys) is cost_function(ys, xs). A separate name keeps
# the 500x500 `costs` grid used by the earlier integration cells intact.
grid_costs = cost_function(ys, xs)
plt.contour(xs.numpy(), ys.numpy(), grid_costs.numpy(), levels=[cost_threshold], colors='red')

plt.title('Kernel Density Estimate')
plt.xlabel('x')
plt.ylabel('y')
plt.show()

In [None]:
# Integrate the density over the grid to verify it sums to approximately 1
# (only if the KDE mass lies inside xlims x ylims).
# With 'ij' meshgrid indexing, `xs` varies along dim 0 and `ys` along dim 1, so the
# number of distinct y values is ys.shape[1] (not ys.shape[0], which only
# coincides for square grids).
dx = (xlims[1] - xlims[0]) / (xs.shape[0] - 1)
dy = (ylims[1] - ylims[0]) / (ys.shape[1] - 1)

integral = density.sum() * dx * dy
print(f"Integral of density: {integral.item():.4f}")  # Should be close to 1.0


In [10]:
def prior_density(xs):
    """Normalised standard-normal density N(0, I), evaluated row-wise.

    Args:
        xs: (n, d) tensor of points.

    Returns:
        (n,) tensor of exp(-||x||^2 / 2) / (2*pi)^(d/2). For d = 2 the constant
        is 1 / (2*pi); 1 / sqrt(2*pi) is only correct for d = 1 and would scale
        the p / q importance weights by sqrt(2*pi).
    """
    dim = xs.shape[1]
    return torch.exp(-(xs**2).sum(dim=1) / 2) / (2 * np.pi)**(dim / 2)

# Mean importance weight p / q: q is a KDE built on the second half of the chain and
# evaluated at the held-out first half, so this is intended as an importance-sampling
# estimate of the prior mass covered by q, i.e. P(cost < threshold).
q_values = kernel_density(
    samples[:len(samples)//2],  # evaluation points: first half of the chain
    samples[len(samples)//2:],  # KDE reference points: second half
    bandwidth=0.1,
)
p_values = prior_density(samples[:len(samples)//2])

print(torch.mean(p_values / q_values))

KeyboardInterrupt: 