## Test flow-based sampling using reverse KL

In [None]:
import math
import time
from typing import Callable

import numpy as np
import matplotlib.pyplot as plt
import torch
import zuko

In [None]:
# Apply the matplotlib style sheet "style.mplstyle" to all figures below.
# NOTE(review): presumably a file in the working directory next to this
# notebook -- confirm it is shipped alongside it.
plt.style.use("style.mplstyle")

In [None]:
def make_flow(ndim: int, transforms: int = 3, depth: int = 2, width: int = 64) -> zuko.flows.Flow:
    """Build a neural spline flow with its transform direction inverted.

    Args:
        ndim: Number of features handed to ``zuko.flows.NSF``.
        transforms: Number of transforms handed to ``zuko.flows.NSF``.
        depth: Length of the ``hidden_features`` list.
        width: Value repeated ``depth`` times in the ``hidden_features`` list.

    Returns:
        A ``zuko.flows.Flow`` built from the inverse of the NSF transform
        and the NSF base distribution.
    """
    nsf = zuko.flows.NSF(
        features=ndim,
        transforms=transforms,
        hidden_features=[width] * depth,
    )
    # Wrap the inverted transform so that sampling is the fast direction.
    return zuko.flows.Flow(nsf.transform.inv, nsf.base)

In [None]:
class CovNormalizer:
    """Linear maps between normalized and unnormalized coordinates.

    The unnormalizing matrix is the lower Cholesky factor ``L`` of the
    covariance matrix (``cov = L @ L.T``); the normalizing matrix is its
    inverse. Both are applied to row vectors, i.e. ``x @ M.T``.
    """

    def __init__(self, cov_matrix: torch.Tensor) -> None:
        """Factorize ``cov_matrix`` and cache the factor and its inverse."""
        cholesky_factor = torch.linalg.cholesky(cov_matrix)
        self.cov_matrix = cov_matrix
        self.unnorm_matrix = cholesky_factor
        self.norm_matrix = torch.linalg.inv(cholesky_factor)

    def unnormalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map normalized points (last axis = coordinates) to unnormalized ones."""
        return x @ self.unnorm_matrix.T

    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Map unnormalized points (last axis = coordinates) to normalized ones."""
        return x @ self.norm_matrix.T

In [None]:
class Sampler:
    """Base class for samplers that draw points from a probability function.

    Subclasses implement ``__call__``; this base only stores the shared
    configuration.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, ndim: int, verbose: int = 0) -> None:
        """Store the dimensionality and verbosity.

        Args:
            ndim: Number of dimensions of the sampled space.
            verbose: Verbosity level; stored but not used by this base class.
        """
        self.ndim = ndim
        self.verbose = verbose
        # Probability function most recently associated with the sampler;
        # ``None`` until a subclass assigns one.
        self.prob_func = None

    def __call__(self, prob_func: Callable) -> torch.Tensor:
        """Draw samples from ``prob_func``; must be overridden.

        Args:
            prob_func: Callable evaluating the (possibly unnormalized)
                probability density at a batch of points.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

In [None]:
class FlowSampler(Sampler):
    """Sampler that fits a normalizing flow to a probability function.

    The flow is trained by minimizing a Monte Carlo estimate of the reverse
    KL divergence: samples are drawn from the flow, pushed through a fixed
    linear "unnormalizing" map, and scored against ``prob_func``.
    """

    def __init__(self, ndim: int, flow: zuko.flows.Flow, unnorm_matrix: torch.Tensor = None, train_kws: dict = None) -> None:
        """Store the flow and the training configuration.

        Args:
            ndim: Number of dimensions of the sampled space.
            flow: Flow to train and sample from.
            unnorm_matrix: Matrix ``M`` applied to flow samples as
                ``z @ M.T``. Defaults to the identity.
            train_kws: Training options. Recognized keys and defaults:
                ``batch_size`` (512), ``iters`` (1000), ``lr`` (0.001),
                ``lr_min`` (0.001), ``lr_decay`` (0.99), ``print_freq``
                (100), ``verbose`` (0). The dict is copied, not modified.
        """
        super().__init__(ndim=ndim)

        self.flow = flow
        self.trained = False

        self.unnorm_matrix = torch.eye(ndim) if unnorm_matrix is None else unnorm_matrix

        # Copy so that filling in defaults never mutates the caller's dict.
        self.train_kws = {} if train_kws is None else dict(train_kws)

        self.train_kws.setdefault("batch_size", 512)
        self.train_kws.setdefault("iters", 1000)
        self.train_kws.setdefault("lr", 0.001)
        # NOTE(review): the default lr_min equals the default lr, so with the
        # defaults the decay below is clamped and the rate stays constant.
        self.train_kws.setdefault("lr_min", 0.001)
        self.train_kws.setdefault("lr_decay", 0.99)
        self.train_kws.setdefault("print_freq", 100)
        self.train_kws.setdefault("verbose", 0)

        self.train_history = self._new_history()

    @staticmethod
    def _new_history() -> dict:
        """Return an empty training-history record."""
        return {"loss": [], "time": []}

    def unnormalize(self, z: torch.Tensor) -> torch.Tensor:
        """Apply the unnormalizing matrix to points (last axis = coordinates)."""
        return torch.matmul(z, self.unnorm_matrix.T)

    def train(self, prob_func: Callable) -> dict:
        """Train the flow against ``prob_func`` and return the history.

        Args:
            prob_func: Callable mapping a batch of points to probability
                density values (not log values); it may be unnormalized.

        Returns:
            Dict with ``"loss"`` (detached loss tensor per iteration) and
            ``"time"`` (seconds elapsed since training started, per
            iteration). The history is reset at the start of every call.
        """
        self.prob_func = prob_func
        self.train_history = self._new_history()

        # Only flag the sampler as trained once the loop has finished, so an
        # interrupted or failed run is retried on the next call.
        self.trained = False

        iters = self.train_kws["iters"]
        batch_size = self.train_kws["batch_size"]
        lr = self.train_kws["lr"]
        lr_min = self.train_kws["lr_min"]
        lr_decay = self.train_kws["lr_decay"]
        print_freq = self.train_kws["print_freq"]
        verbose = self.train_kws["verbose"]

        start_time = time.time()

        optimizer = torch.optim.Adam(self.flow.parameters(), lr=lr)
        for iteration in range(iters):
            x, log_prob = self.flow().rsample_and_log_prob((batch_size,))
            # log_prob is the flow density before the linear map; the map is
            # fixed, so its log-determinant is a constant offset of the loss
            # and does not affect the gradients.
            x = self.unnormalize(x)

            # Reverse KL estimate E_q[log q - log p]; the small constant
            # guards against log(0).
            loss = torch.mean(log_prob) - torch.mean(torch.log(prob_func(x) + 1.00e-15))
            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

            # Geometric learning-rate decay, floored at lr_min.
            for param_group in optimizer.param_groups:
                param_group["lr"] = max(lr_min, lr_decay * param_group["lr"])

            self.train_history["loss"].append(loss.detach())
            self.train_history["time"].append(time.time() - start_time)

            if verbose and (iteration % print_freq == 0):
                print(iteration, loss)

        self.trained = True
        return self.train_history

    def __call__(self, prob_func: Callable, size: int) -> torch.Tensor:
        """Return ``size`` samples, training first if needed.

        Training runs when the sampler is untrained or when ``prob_func`` is
        not the same object as the one used for the last training run.

        Args:
            prob_func: Probability function to sample from.
            size: Number of samples to draw.

        Returns:
            Unnormalized samples drawn from the flow without gradient tracking.
        """
        if prob_func is not self.prob_func:
            self.trained = False

        if not self.trained:
            self.train(prob_func)

        with torch.no_grad():
            return self.unnormalize(self.flow().sample((size,)))

In [None]:
def prob_func(x: torch.Tensor) -> torch.Tensor:
    """Evaluate an unnormalized 2D ring-shaped density.

    The log-density is ``sin(pi * x1) - 2 * (x1^2 + x2^2 - 2)^2``, where
    ``x1`` and ``x2`` are the first two entries along the last axis of ``x``.

    Args:
        x: Points with coordinates along the last axis.

    Returns:
        Density values with the last axis of ``x`` removed.
    """
    first, second = x[..., 0], x[..., 1]
    ring_penalty = (first**2 + second**2 - 2.0) ** 2
    return torch.exp(torch.sin(torch.pi * first) - 2.0 * ring_penalty)

In [None]:
# Dimension of the sampled space (prob_func reads two coordinates).
ndim = 2

# Isotropic covariance sigma^2 * I with sigma = 1.0.
cov_matrix = (1.0 ** 2) * torch.eye(ndim)

# The Cholesky factor of the covariance maps flow samples to target space.
normalizer = CovNormalizer(cov_matrix)
unnorm_matrix = normalizer.unnorm_matrix

In [None]:
# Build a fresh flow and train it on prob_func for 500 iterations; the
# remaining training options keep the FlowSampler defaults.
flow = make_flow(ndim=ndim)
sampler = FlowSampler(
    ndim=ndim,
    flow=flow,
    unnorm_matrix=unnorm_matrix,
    train_kws={"iters": 500},
)
# Trailing semicolon suppresses the notebook echo of the returned history.
sampler.train(prob_func);

In [None]:
# Training curve: loss recorded at every iteration of the last train() call.
fig, ax = plt.subplots(figsize=(3, 2))
ax.plot(sampler.train_history["loss"])
ax.set(xlabel="Iteration", ylabel="Loss")
plt.show()

In [None]:
# Draw samples; the sampler only retrains if prob_func is a different object
# from the one it was last trained on.
x = sampler(prob_func, 100_000)

bins = 64
xmax = 3.0

# `bins` cells need `bins + 1` edges. The density is evaluated at the cell
# centers so that both panels use exactly the same cells.
grid_edges = 2 * [torch.linspace(-xmax, xmax, bins + 1)]
grid_centers = [0.5 * (edges[:-1] + edges[1:]) for edges in grid_edges]
grid_points = torch.stack(torch.meshgrid(*grid_centers, indexing="ij"), dim=-1)
grid_values = prob_func(grid_points)  # shape (bins, bins), indexed [x1, x2]

# Left: histogram of flow samples. Right: target density on the same grid.
fig, axs = plt.subplots(ncols=2, sharex=True, sharey=True, figsize=(4.5, 2))
axs[0].hist2d(x[:, 0], x[:, 1], bins=grid_edges)
# pcolormesh expects C[row=y, col=x], hence the transpose.
axs[1].pcolormesh(grid_edges[0], grid_edges[1], grid_values.T)
plt.show()