Simplified setup and test of Langevin sampling for a trivial energy function (an anisotropic multivariate normal distribution).

To run, go to this directory and run
```
source ./setup.sh
jupyter lab ./langevin_step.ipynb
```
and then execute all cells.

If there are dependency issues, try creating a `virtualenv` from `requirements.freeze.txt`.

Context:
 - <https://github.com/google-research/ibc/issues/6>

In [None]:
import dataclasses as dc
import os
import random

import einops
import matplotlib.pyplot as plt
import moviepy
import moviepy.editor as mpy
import numpy as np
import torch
from torch import nn
from torch.autograd.functional import jacobian
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.uniform import Uniform

In [None]:
# Dimensionality of each sample `y`. `uniform_sample` draws initial samples with
# this trailing size, and the 2-D plotting helpers assume it is 2.
DimY = 2


@dc.dataclass()
class Config:
    """Settings for one Langevin-sampling experiment.

    Fields are positional in the order declared below.
    """

    # Human-readable label; also used to build output file names.
    name: str
    # Total number of Langevin iterations to run.
    n_iters: int
    # Number of chains simulated in parallel.
    num_chains: int
    # How many of the most recent iterations' samples to keep; each iteration
    # contributes `num_chains` samples.
    history_count: int
    # Plot every `iteration_interval`-th iteration.
    iteration_interval: int
    # Whether iteration 0 (the uniform random initialization) is plotted.
    show_first: bool

    # --- Langevin sampling parameters ---
    # Lower / upper corners of the box used for the uniform initialization.
    y_min: torch.Tensor
    y_max: torch.Tensor
    # Mean / std of the target normal distribution.
    y_mean: torch.Tensor
    y_std: torch.Tensor
    # Step-size schedule: blends from `init` to `final`, warped by `power`.
    step_size_init: float
    step_size_final: float
    step_size_power: float

In [None]:
def seed(value):
    """Seed Python's, NumPy's and PyTorch's global RNGs with the same `value`."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(value)


def uniform_sample(y_min, y_max, num_chains, batch_size):
    """Draw initial samples uniformly from the box [y_min, y_max).

    Returns a tensor of shape (batch_size, num_chains, DimY).
    """
    shape = (batch_size, num_chains, DimY)
    low = y_min.expand(*shape)
    high = y_max.expand(*shape)
    return Uniform(low, high).sample()


def gradient_wrt_act(ebm_net, x, ys):
    """Return dE/dys, the gradient of the energy with respect to the samples.

    Mirrors the approach in google-research/ibc. Assuming each sample's energy
    depends only on that sample, differentiating the *sum* of the energies with
    respect to ``ys`` yields every per-sample gradient at once; trick adapted
    from:
    https://discuss.pytorch.org/t/computing-batch-jacobian-efficiently/80771/5  # noqa

    Must be called with autograd disabled; it is re-enabled only locally.
    The result has the same shape as ``ys``.
    """
    assert not torch.is_grad_enabled()

    def total_energy(samples):
        return ebm_net(x, samples).sum()

    # WARNING: This may be rather slow.
    with torch.enable_grad():
        grad = jacobian(total_energy, ys)
    assert grad.shape == ys.shape
    return grad

In [None]:
def get_step_size(config, iteration):
    """Return the Langevin step size for the 0-based `iteration`.

    The step size is blended from ``config.step_size_init`` (iteration 0) to
    ``config.step_size_final`` (iteration ``config.n_iters - 1``). The blend
    weight ``iteration / (n_iters - 1)`` is raised to
    ``config.step_size_power`` before interpolating.

    When ``config.n_iters <= 1`` there is no range to interpolate over; the
    blend weight is taken as 0 instead of dividing by zero.
    """
    denom = config.n_iters - 1
    blend = iteration / denom if denom > 0 else 0.0
    blend = blend ** config.step_size_power
    step_size = config.step_size_init + blend * (config.step_size_final - config.step_size_init)
    return step_size


def langevin_step(y_samples, dE_dys, step_size, use_ibc_style=True):
    """Apply one Langevin update to ``y_samples`` and return the new samples.

    Args:
        y_samples: Current samples.
        dE_dys: Energy gradient evaluated at ``y_samples`` (same shape).
        step_size: Scalar step size.
        use_ibc_style: If True, reproduce the update from google-research/ibc,
            where the noise is scaled by ``step_size``. If False, use the
            standard Langevin update (without any Metropolis accept/reject
            step), where the noise is scaled by ``sqrt(step_size)``.
    """
    # Fresh standard-normal noise, drawn independently for every entry.
    noise = torch.normal(mean=0.0, std=1.0, size=y_samples.shape, device=y_samples.device)
    if not use_ibc_style:
        # Correct formulation. See:
        # https://en.wikipedia.org/wiki/Metropolis-adjusted_Langevin_algorithm  # noqa
        # Note that this uses different step size scaling.
        return y_samples + (-step_size * 0.5 * dE_dys + np.sqrt(step_size) * noise)
    # Variant used by google-research/ibc.
    return y_samples + (-step_size * (0.5 * dE_dys + noise))


def langevin_sample(config, ebm_net, x, num_chains, callback=None, *, use_ibc_style=True):
    """Run ``config.n_iters`` Langevin steps starting from a uniform draw.

    Args:
        config: Supplies the init box (``y_min`` / ``y_max``), ``n_iters`` and
            the step-size schedule.
        ebm_net: Energy model, called as ``ebm_net(x, ys)``.
        x: Conditioning input; its first dimension is the batch size.
        num_chains: Number of chains per batch element.
        callback: Optional ``callback(iteration, y_samples)``. It is invoked
            before every step and once more with the final samples, i.e. for
            iterations ``0..config.n_iters`` inclusive.
        use_ibc_style: Forwarded to ``langevin_step``.

    Returns:
        The samples after the last step.
    """
    assert not torch.is_grad_enabled()
    batch_size = x.shape[0]
    ys = uniform_sample(config.y_min, config.y_max, num_chains, batch_size=batch_size)
    for iteration in range(config.n_iters + 1):
        if callback is not None:
            callback(iteration, ys)
        if iteration == config.n_iters:
            # Final samples have been reported; no further step is taken.
            break
        grad = gradient_wrt_act(ebm_net, x, ys)
        step_size = get_step_size(config, iteration)
        # Move the samples toward the typical set.
        ys = langevin_step(ys, grad, step_size, use_ibc_style=use_ibc_style)
    return ys

In [None]:
def torch_log_normal_pdf(x, mu, std):
    """Log-density of a multivariate normal with diagonal covariance.

    Args:
        x: Points to evaluate; the last dimension matches ``mu``.
        mu: Mean vector.
        std: Per-dimension standard deviations; covariance is ``diag(std ** 2)``.
    """
    distribution = MultivariateNormal(loc=mu, covariance_matrix=torch.diag(std ** 2))
    return distribution.log_prob(x)


class NormalEbm(nn.Module):
    """Energy model for a fixed diagonal-covariance normal distribution.

    The energy of a sample is its negative log-density; the conditioning
    input ``x`` is ignored.

    Args:
        y_mu: Mean of the distribution.
        y_std: Per-dimension standard deviation (same shape as ``y_mu``).
    """

    def __init__(self, y_mu, y_std):
        super().__init__()
        assert y_std.shape == y_mu.shape
        self._y_mu = y_mu
        self._y_std = y_std

    def forward(self, x, ys):
        """Return energies of shape (N, K, 1) for samples ``ys`` of shape (N, K, D)."""
        num_batch, num_chains, dim = ys.shape
        flat_ys = ys.reshape(num_batch * num_chains, dim)
        # Work with the log-density directly (not the density) so autograd
        # does not run into numeric problems.
        log_probs = torch_log_normal_pdf(flat_ys, self._y_mu, self._y_std)
        return -log_probs.reshape(num_batch, num_chains, 1)

In [None]:
@torch.no_grad()
def plot_2d_ebm(config, ebm_net, x, grid_size=200, alpha=None):
    """Plot exp(-E(y)) of a 2-D energy model on a regular grid.

    The grid spans ``[config.y_min, config.y_max]`` with ``grid_size`` points
    per axis. Dimension 0 of ``y`` is drawn on the horizontal axis and
    dimension 1 on the vertical axis.

    Args:
        config: Supplies the plotting box ``y_min`` / ``y_max``.
        ebm_net: Energy model called as ``ebm_net(x, ys)`` with ``ys`` of shape
            (1, grid_size**2, 2); its output must rearrange as "1 N 1".
        x: Conditioning input (batch size 1); also sets the dtype/device of ``ys``.
        grid_size: Number of grid points per axis.
        alpha: Forwarded to ``plt.pcolormesh``.

    Returns:
        The mesh object returned by ``plt.pcolormesh``.
    """
    # Unnormalized action coordinates: `u` is y[0] (horizontal), `v` is y[1] (vertical).
    action_u = torch.linspace(config.y_min[0], config.y_max[0], steps=grid_size)
    action_v = torch.linspace(config.y_min[1], config.y_max[1], steps=grid_size)
    # Image layout: rows follow `v`, columns follow `u`. With meshgrid's default
    # "ij" indexing that means passing `v` first. Passing `u` first (and
    # swapping the outputs) would sample y[0] over y[1]'s range whenever the
    # two ranges differ.
    action_v_grid, action_u_grid = torch.meshgrid(action_v, action_u)
    action_us = einops.rearrange(action_u_grid, "H W -> (H W)")
    action_vs = einops.rearrange(action_v_grid, "H W -> (H W)")
    ys = einops.rearrange([action_us, action_vs], "C HW -> () HW C").to(x)
    energies = ebm_net(x, ys)
    energies = einops.rearrange(energies, "1 N 1 -> N")
    # Show probabilities used for sampling: exp(-E).
    Z_grid = torch.exp(-energies)
    Z_grid = einops.rearrange(Z_grid, "(H W) -> H W", H=grid_size)
    mesh = plt.pcolormesh(
        action_u_grid.numpy(),
        action_v_grid.numpy(),
        # `ys` follows `x`'s device; move back to host memory for matplotlib.
        Z_grid.cpu().numpy(),
        cmap="cool",
        alpha=alpha,
        shading="auto",
    )
    return mesh


def plot_2d_ebm_callback(config, ebm_net, iteration, x, y_samples, *, for_pub=False):
    """Render the energy landscape with ``y_samples`` overlaid as markers.

    Args:
        config: Experiment settings (plot box and ``n_iters`` for the title).
        ebm_net: Energy model to visualize.
        iteration: Current iteration, shown in the title.
        x: Conditioning input; must have batch size 1.
        y_samples: Samples with a leading batch axis of size 1.
        for_pub: If True, draw a small figure and return it still open, for
            the caller to save and close. Otherwise add a title, rasterize the
            figure into an RGB array, close it and return the array.
    """
    batch_size, _ = x.shape
    assert batch_size == 1
    figure_kwargs = dict(figsize=(3, 2)) if for_pub else {}
    marker_alpha = 0.5 if for_pub else 1.0
    fig, _ = plt.subplots(**figure_kwargs)

    mesh = plot_2d_ebm(config, ebm_net, x)
    plt.colorbar(mesh, label=r"$p(y)$")
    plt.grid(False)
    plt.axis("scaled")
    points = y_samples.squeeze(0)
    plt.scatter(points[:, 0], points[:, 1], marker="x", s=10, color="blue", alpha=marker_alpha)

    if for_pub:
        return fig
    plt.title(f"Iter {iteration} / {config.n_iters}")
    image = mpl_figure_to_image(fig)
    plt.close(fig)
    return image


def repeat_last(images, *, count):
    """Return ``images`` followed by ``count`` extra copies of its last element.

    Padding a frame list this way keeps the final frame visible at the end of
    a clip. The input list is not modified; it must be non-empty.
    """
    assert len(images) > 0
    final_frame = images[-1]
    return images + [final_frame for _ in range(count)]


def merge_rgba_to_rgb(rgba, *, bg_color=(255, 255, 255)):
    """Alpha-composite an RGBA uint8 image over a solid background color.

    Args:
        rgba: uint8 array whose last axis holds (R, G, B, A).
        bg_color: Background (R, G, B) in 0..255. Defaults to white. A tuple
            is used instead of a shared mutable list; any sequence or array
            is accepted.

    Returns:
        uint8 array with the same leading shape as ``rgba`` and 3 channels.
    """
    assert rgba.dtype == np.uint8
    rgb = rgba[..., :3].astype(np.float64)
    # Index with [3] (not 3) to keep a trailing axis that broadcasts over RGB.
    alpha = rgba[..., [3]] / 255.0
    background = np.asarray(bg_color, dtype=np.float64)
    blended = alpha * rgb + background * (1.0 - alpha)
    # Round to nearest; a plain cast truncates (e.g. 127.9999 -> 127).
    return np.clip(np.rint(blended), 0, 255).astype(np.uint8)


def mpl_figure_to_image(fig):
    """Rasterize a Matplotlib figure into an RGB uint8 array.

    The figure is drawn, its RGBA canvas buffer is read, and the alpha channel
    is flattened with ``merge_rgba_to_rgb`` using that function's default
    background.
    """
    fig.canvas.draw()
    return merge_rgba_to_rgb(np.asarray(fig.canvas.buffer_rgba()))


def jupyter_movie(clip):
    """Display ``clip`` inline in Jupyter via moviepy's ``ipython_display``.

    Uses 10 fps with looping and autoplay disabled. ``logger=None`` is passed
    through ``rd_kwargs``.
    """
    html_tools = moviepy.video.io.html_tools
    return html_tools.ipython_display(
        clip,
        fps=10,
        loop=False,
        autoplay=False,
        rd_kwargs=dict(logger=None),
    )

In [None]:
@torch.no_grad()
def check_langevin_distribution(config, *, use_ibc_style):
    """Sample a `NormalEbm` with Langevin dynamics and report the fit quality.

    Runs `langevin_sample` on the normal distribution given by
    ``config.y_mean`` / ``config.y_std``. A frame is rendered every
    ``config.iteration_interval`` iterations, and the final iteration is always
    rendered, so the final figure exists even when ``n_iters`` is not a
    multiple of the interval.

    The samples of the last ``config.history_count`` iterations (all chains)
    are pooled. The max relative error of their mean and std against the
    target is printed.

    Args:
        config: Experiment settings.
        use_ibc_style: Forwarded to the Langevin update.

    Returns:
        ``(images, final_fig)``: the rendered RGB frames and the
        publication-style figure of the final iteration. The figure is left
        open for the caller to save and close.
    """
    seed(0)
    ebm_net = NormalEbm(config.y_mean, config.y_std)
    ebm_net.eval()
    # Observation `x` isn't really used here.
    x = torch.zeros(size=(1, 0))
    # Callback state.
    images = []
    all_ys = []
    final_fig = None

    def pooled_history():
        # Concatenate the most recent `history_count` snapshots along the chain axis.
        return torch.cat(all_ys[-config.history_count:], dim=1)

    def callback(iteration, ys_latest):
        nonlocal final_fig
        all_ys.append(ys_latest)
        # Only the most recent `history_count` snapshots are ever read; drop
        # older ones to bound memory on long runs.
        del all_ys[:-config.history_count]
        is_last = iteration == config.n_iters
        if not is_last:
            if iteration % config.iteration_interval != 0:
                return
            if iteration == 0 and not config.show_first:
                return
        # Only build the pooled history for iterations that are rendered.
        ys = pooled_history()
        images.append(plot_2d_ebm_callback(config, ebm_net, iteration, x, ys))
        if is_last:
            final_fig = plot_2d_ebm_callback(config, ebm_net, iteration, x, ys, for_pub=True)

    # Do sampling; samples are collected through the callback.
    langevin_sample(
        config,
        ebm_net,
        x,
        config.num_chains,
        callback=callback,
        use_ibc_style=use_ibc_style,
    )
    # Check statistics.
    ys = pooled_history().squeeze(0)
    y_std_actual, y_mean_actual = torch.std_mean(ys, dim=0, unbiased=False)
    # n.b. bad numeric condition if expected is zero, but eh.
    y_mean_rel_error = ((config.y_mean - y_mean_actual) / config.y_mean).abs().max()
    y_std_rel_error = ((config.y_std - y_std_actual) / config.y_std).abs().max()
    print(f"  mean_rel_error: {y_mean_rel_error}")
    print(f"  std_rel_error: {y_std_rel_error}")
    assert final_fig is not None
    return images, final_fig

In [None]:
# Settings shared by both experiments.
common = {
    "y_min": torch.tensor([0.0, 0.0]),
    "y_max": torch.tensor([1.0, 1.0]),
    "y_mean": torch.tensor([0.3, 0.4]),
    "y_std": torch.tensor([0.1, 0.2]),
    "step_size_init": 5e-3,
    "step_size_final": 1e-3,
    "step_size_power": 2,
}

# Number of samples pooled for the statistics check in each experiment.
num_particles = 400

# Many chains in parallel; statistics use only the final iteration.
multi_chain = Config(
    name="multi_chain",
    n_iters=50,
    num_chains=num_particles,
    history_count=1,
    iteration_interval=1,
    show_first=True,
    **common
)
# One long chain; statistics use its last `num_particles` iterations (after burn-in).
single_chain = Config(
    name="single_chain",
    n_iters=1000,
    num_chains=1,
    history_count=num_particles,
    iteration_interval=50,
    show_first=True,
    **common
)

configs = [multi_chain, single_chain]

In [None]:
os.makedirs("/tmp/langevin", exist_ok=True)
# For every experiment, run both update rules, then save and display the results.
for config in configs:
    print(f"[{config.name}]")
    for use_ibc_style, suffix in ((True, "ibc"), (False, "correct")):
        print(suffix)
        base = f"/tmp/langevin/{config.name}-{suffix}"
        images, final_fig = check_langevin_distribution(config, use_ibc_style=use_ibc_style)
        fps = 20
        # Hold the last frame so it stays visible at the end of the clip.
        frames = repeat_last(images, count=30)
        clip = mpy.ImageSequenceClip(frames, fps=fps)
        clip.write_videofile(f"{base}.mp4", fps=fps, logger=None)
        clip.write_gif(f"{base}.gif", fps=fps, logger=None)
        final_fig.savefig(f"{base}-final.png", dpi=300, bbox_inches="tight")
        display(jupyter_movie(clip))
        display(final_fig)
        plt.close(final_fig)