## Test flow-based sampling using reverse KL

In [None]:
import math
import time
from typing import Callable

import numpy as np
import matplotlib.pyplot as plt
import torch
import zuko

import ment_torch as ment

In [None]:
# Apply the Matplotlib style sheet "style.mplstyle" to all figures below.
# NOTE(review): presumably a file resolved relative to the working directory — confirm it ships with the notebook.
plt.style.use("style.mplstyle")

In [None]:
def make_flow(ndim: int, transforms: int = 3, depth: int = 2, width: int = 64) -> zuko.flows.Flow:
    """Build a neural spline flow whose transform is inverted for fast sampling.

    Args:
        ndim: Number of features passed to ``zuko.flows.NSF``.
        transforms: Number of transforms passed to ``zuko.flows.NSF``.
        depth: Number of hidden layers in each transform's network.
        width: Number of units in every hidden layer.

    Returns:
        A ``zuko.flows.Flow`` built from the inverse of the NSF transform and
        the NSF base.
    """
    nsf = zuko.flows.NSF(
        features=ndim,
        transforms=transforms,
        hidden_features=[width] * depth,
    )
    # Re-wrap with the inverted transform (kept from the original: "fast sampling").
    return zuko.flows.Flow(nsf.transform.inv, nsf.base)

In [None]:
def prob_func(x: torch.Tensor) -> torch.Tensor:
    """Evaluate the target density at the points ``x``.

    Only the first two entries along the last axis are read. No normalization
    constant is applied to the result.

    Args:
        x: Tensor whose last axis holds the coordinates (at least two entries).

    Returns:
        Tensor of density values with the last axis of ``x`` removed.
    """
    x1, x2 = x[..., 0], x[..., 1]
    # Squared deviation of the squared radius from 2.
    ring_term = (x1**2 + x2**2 - 2.0) ** 2
    return torch.exp(torch.sin(torch.pi * x1) - 2.0 * ring_term)

In [None]:
# Problem dimension; prob_func reads exactly the first two coordinates.
ndim = 2
# Identity covariance matrix of size ndim x ndim.
cov_matrix = torch.eye(ndim)
# Lower-triangular Cholesky factor of the covariance (the identity here).
# NOTE(review): presumably the sampler uses this to unnormalize flow samples — confirm against ment.samp.FlowSampler.
unnorm_matrix = torch.linalg.cholesky(cov_matrix)

In [None]:
# Build the flow and hand it to the flow-based sampler together with the
# training settings.
flow = make_flow(ndim=ndim)
sampler = ment.samp.FlowSampler(
    ndim=ndim,
    flow=flow,
    unnorm_matrix=unnorm_matrix,
    train_kws={"iters": 1000, "batch_size": 256},
)
# Train against the target density. The trailing semicolon presumably
# suppresses the notebook's echo of the return value.
sampler.train(prob_func);

In [None]:
# Plot the loss values stored in the sampler's training history.
fig, ax = plt.subplots(figsize=(3, 2))
ax.plot(sampler.train_history["loss"])
ax.set(xlabel="Iteration", ylabel="Loss")
plt.show()

In [None]:
# Draw samples from the trained sampler and compare their 2D histogram with
# the target density evaluated on the same grid.
# NOTE(review): assumes x has shape (n_samples, 2) and converts to NumPy cleanly — confirm.
x = sampler(prob_func, 100_000)

bins = 64
xmax = 3.0

# `bins` cells per axis require `bins + 1` edges. Using `bins` points as edges
# would give only `bins - 1` histogram cells and a density image offset from
# the histogram by half a cell.
grid_edges = 2 * [torch.linspace(-xmax, xmax, bins + 1)]
# Evaluate the density at the cell centres so both panels share the same cells.
grid_centers = [0.5 * (edges[:-1] + edges[1:]) for edges in grid_edges]
grid_points = torch.stack(torch.meshgrid(*grid_centers, indexing="ij"), dim=-1)
grid_values = prob_func(grid_points)  # shape (bins, bins), indexed [ix, iy]

fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(4.5, 2))
# Left: histogram of the samples. Right: target density on the same cell edges.
axs[0].hist2d(x[:, 0], x[:, 1], bins=grid_edges)
# pcolormesh expects values indexed [iy, ix], hence the transpose of the "ij" grid.
axs[1].pcolormesh(grid_edges[0], grid_edges[1], grid_values.T)
plt.show()