In [None]:
import os, sys
os.chdir('/root/code/1d_density_estimation')
sys.path.insert(0, os.getcwd())

from utils.notebook_utils import *

## Overview

This notebook explores a few methods for training a flexible 1D density estimator. For the first
experiment, we attempt to match a mixture of Gaussians using another, learned mixture of Gaussians.
The means and variances of the learned mixture are fit by minimizing the reverse KL divergence
directly with stochastic gradient descent (SGD).

In [None]:
"""
Utilities for mixture distributions.
"""

import torch
import torch.nn.functional as F
import torch.distributions as D

def random_gumbel(shape: torch.Size, device: torch.device, eps=1e-12) -> torch.Tensor:
    """
    Draws standard Gumbel noise via the inverse-CDF transform `-log(-log(u))`.

    `u` is a float32 tensor of the requested shape, allocated on `device` and filled with uniform
    variates whose bounds are `eps` and `1 - eps`; `eps` keeps the logarithms away from 0 and 1.
    """
    uniform = torch.empty(shape, dtype=torch.float32, device=device)
    uniform.uniform_(eps, 1 - eps)
    neg_log_u = -torch.log(uniform)
    return -torch.log(neg_log_u)

def gaussian_mixture_pdf(x: torch.Tensor, pi_logits: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
    """
    Evaluates the density of a Gaussian mixture at every entry of `x`.

    The mixing weights are `softmax(pi_logits)`; `mu` and `sigma` hold the per-component means and
    standard deviations. A trailing component axis is added to `x`, the weighted component
    densities are formed, and that axis is summed out again, so the result has the shape of `x`.
    """
    weights = F.softmax(pi_logits, dim=0).unsqueeze(0)
    centers = mu.unsqueeze(0)
    scales = sigma.unsqueeze(0)

    # Standardized residual of every point against every component.
    z = (x.unsqueeze(-1) - centers) / scales
    component_pdf = 1 / (np.sqrt(2 * np.pi) * scales) * torch.exp(-1 / 2 * z ** 2)
    return torch.sum(weights * component_pdf, axis=-1)

def relaxed_gaussian_mixture_pdf(x: torch.Tensor, pi_logits: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, n_samples: int, temp: float = 0.) -> torch.Tensor:
    """
    Monte Carlo estimate of the density of the Gumbel-softmax-relaxed mixture.

    With `temp == 0.` this defers to `gaussian_mixture_pdf`. Otherwise `n_samples` relaxed weight
    vectors are drawn; each one collapses the mixture into a single Gaussian with mean
    `sum(pi * mu)` and variance `sum(pi * sigma ** 2)`, and the densities of those Gaussians are
    averaged over the draws.

    NOTE(review): `gaussian_mixture_sample` returns `sum(pi * x)` over independently drawn
    components, whose variance would be `sum(pi ** 2 * sigma ** 2)` rather than the
    `sum(pi * sigma ** 2)` used here -- confirm which relaxation is intended.
    Assumes `pi_logits`, `mu`, `sigma` and `x` are 1-D -- TODO confirm against callers.
    """
    if temp == 0.:
        return gaussian_mixture_pdf(x, pi_logits, mu, sigma)

    # One row of logits per Monte Carlo draw, perturbed with Gumbel noise and sharpened by 1 / temp.
    tiled_logits = pi_logits.repeat(n_samples).view(n_samples, -1)
    noise = random_gumbel(tiled_logits.shape, tiled_logits.device)
    weights = F.softmax((tiled_logits + noise) / temp, dim=-1)

    # Moments of the collapsed Gaussian for each draw, kept as column vectors for broadcasting.
    mean = (weights * mu.unsqueeze(0)).sum(dim=1).unsqueeze(-1)
    std = torch.sqrt((weights * (sigma ** 2).unsqueeze(0)).sum(dim=1)).unsqueeze(-1)

    z = (x.unsqueeze(0) - mean) / std
    per_draw_pdf = 1 / (np.sqrt(2 * np.pi) * std) * torch.exp(-1 / 2 * z ** 2)
    return per_draw_pdf.mean(dim=0)

def gaussian_mixture_log_likelihood(x: torch.Tensor, pi_logits: torch.Tensor, mu: torch.Tensor,
    sigma: torch.Tensor) -> torch.Tensor:
    """
    Log-density of a Gaussian mixture at every entry of `x`.

    Works in log space throughout: the per-component Gaussian log-densities are added to
    `log_softmax(pi_logits)` and reduced over the trailing component axis with `logsumexp`, so the
    result has the shape of `x`.
    """

    log_weights = F.log_softmax(pi_logits, dim=0).unsqueeze(0)
    centers = mu.unsqueeze(0)
    scales = sigma.unsqueeze(0)

    z = (x.unsqueeze(-1) - centers) / scales
    log_joint = -1 / 2 * z ** 2 - 1 / 2 * np.log(2 * np.pi) - torch.log(scales) + log_weights
    return torch.logsumexp(log_joint, dim=-1)

def relaxed_gaussian_mixture_log_likelihood(x: torch.Tensor, pi_logits: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor, n_samples: int, temp: float = 0.) -> torch.Tensor:
    """
    Log of the Monte Carlo density estimate of the Gumbel-softmax-relaxed mixture.

    With `temp == 0.` this defers to `gaussian_mixture_log_likelihood`. Otherwise `n_samples`
    relaxed weight vectors are drawn, each collapsing the mixture to one Gaussian with mean
    `sum(pi * mu)` and variance `sum(pi * sigma ** 2)`. The `1 / n_samples` averaging factor is
    folded into the per-draw log-density so that a single `logsumexp` over the draw axis yields the
    log of the mean density.

    Assumes `pi_logits`, `mu`, `sigma` and `x` are 1-D -- TODO confirm against callers.
    """
    if temp == 0.:
        return gaussian_mixture_log_likelihood(x, pi_logits, mu, sigma)

    tiled_logits = pi_logits.repeat(n_samples).view(n_samples, -1)
    noise = random_gumbel(tiled_logits.shape, tiled_logits.device)
    weights = F.softmax((tiled_logits + noise) / temp, dim=-1)

    mean = (weights * mu.unsqueeze(0)).sum(dim=1).unsqueeze(-1)
    std = torch.sqrt((weights * (sigma ** 2).unsqueeze(0)).sum(dim=1)).unsqueeze(-1)

    z = (x.unsqueeze(0) - mean) / std
    # log(std * n_samples) = log(std) + log(n_samples): normalizer plus the averaging term.
    per_draw_ll = -1 / 2 * z ** 2 - 1 / 2 * np.log(2 * np.pi) - torch.log(std * n_samples)
    return torch.logsumexp(per_draw_ll, dim=0)

def gaussian_mixture_sample(count: int, pi_logits: torch.Tensor, mu: torch.Tensor, sigma: torch.Tensor,
    temp: float = 0.) -> torch.Tensor:
    """
    Draws `count` samples from a (possibly relaxed) Gaussian mixture.

    One value is drawn from every component for every sample, and the draws are combined with a
    weight vector `pi`:

    * `temp == 0.`: `pi` is a one-hot vector from an exact categorical draw, so each sample comes
      from a single component.
    * `temp > 0.`: `pi` is a Gumbel-softmax relaxation of that draw, so each sample is a convex
      combination of the per-component draws.

    The component draws are reparameterized (`rsample`), so gradients flow from the returned
    samples back into `mu` and `sigma` (and, when relaxed, into `pi_logits` through `pi`). This is
    what a sample-based reverse-KL objective needs; with a detached `sample()` the only gradient
    path left for `mu` and `sigma` is through the log-density evaluated at the samples.

    Raises:
        ValueError: if `temp` is negative.
    """

    if temp < 0:
        raise ValueError(f"bad temp {temp}")

    pi_logits_ = pi_logits.repeat(count).view(count, -1)

    if temp == 0.:
        component = D.Categorical(logits=pi_logits_).sample()
        # Cast the integer one-hot mask explicitly so the product below is a float operation.
        pi = F.one_hot(component, num_classes=pi_logits.shape[0]).to(mu.dtype)
    else:
        pi = F.softmax((pi_logits_ + random_gumbel(pi_logits_.shape, pi_logits_.device)) / temp, dim=-1)

    mu_ = mu.repeat(count).view(count, -1)
    sigma_ = sigma.repeat(count).view(count, -1)
    # rsample() = loc + scale * eps keeps the samples differentiable w.r.t. mu and sigma.
    x = D.Normal(loc=mu_, scale=sigma_).rsample()
    return torch.sum(pi * x, dim=1)

In [None]:
"""
Plotting utilities.
"""

import io
import uuid
import plotly

from PIL             import Image
from tqdm            import tqdm
from IPython.display import Image as IPImage

def make_layout(width, height):
    """
    Builds the plotly layout shared by the figures in this notebook.

    The layout has the requested pixel size, zero margins, a white plot background, light grey
    grid lines on both axes and a horizontal legend.
    """
    grid_color = 'rgb(220,220,220)'
    return go.Layout(
        width=width,
        height=height,
        margin=dict(t=0, r=0, b=0, l=0),
        xaxis=dict(gridcolor=grid_color),
        yaxis=dict(gridcolor=grid_color),
        plot_bgcolor='rgb(255,255,255)',
        legend=dict(orientation='h'),
    )

def _figure_to_image(fig):
    """Rasterizes one plotly figure to PNG in memory and returns it as a fully loaded PIL image."""
    with io.BytesIO() as buf:
        plotly.io.write_image(fig, format='png', file=buf)
        data = buf.getvalue()

    with io.BytesIO(data) as buf:
        img = Image.open(buf)
        # Force decoding now: PIL reads lazily and the buffer is closed when this block exits.
        img.load()

    return img

def figures_to_gif(figures, duration=200):
    """
    Renders a sequence of plotly figures into an animated, endlessly looping GIF.

    Each figure becomes one frame shown for `duration` milliseconds. The GIF is written to the
    current working directory under a random UUID-based name and returned wrapped in an IPython
    `Image` so that it can be displayed inline.

    Raises:
        ValueError: if `figures` is empty (a GIF needs at least one frame).
    """
    if len(figures) == 0:
        raise ValueError("figures_to_gif requires at least one figure")

    images = [_figure_to_image(fig) for fig in tqdm(figures)]

    name = str(uuid.uuid4()) + '.gif'
    images[0].save(name, save_all=True, append_images=images[1:], optimize=False, duration=duration, loop=0)
    return IPImage(name)

In [None]:
# Sanity check: a histogram of mixture samples should match the analytic mixture density.
# (`show`, `go` and `py` are not defined in this cell; presumably they come from the star import
# of utils.notebook_utils -- confirm.)
show("### Visual confirmation that sampling and PDF evaluation work for normal mixtures:")

# Four-component reference mixture. The next cell reads these same module-level names.
mu = torch.Tensor([-12, -4, 4, 12])
sigma = torch.Tensor([1, 1, 1, 1])
pi_logits = torch.Tensor([0, 1, 2, 3])

# Empirical density from samples drawn with the default temp (0., i.e. exact categorical draws).
x = gaussian_mixture_sample(8 * 1024, pi_logits, mu, sigma)
t0 = go.Histogram(x=x, histnorm='probability density')

# Analytic density on a regular grid, overlaid as a line. Note that `x` is rebound here.
x = torch.linspace(-16, 16, 1024)
y = gaussian_mixture_pdf(x, pi_logits, mu, sigma)
t1 = go.Scatter(x=x, y=y, mode='lines')

l = make_layout(width=950, height=600)
f = go.Figure(data=[t0, t1], layout=l)
py.iplot(f)

In [None]:
# Sanity check for the relaxed mixture: one frame per temperature, animated as a GIF.
# Relies on `mu`, `sigma` and `pi_logits` already existing at module level.
show("### Visual confirmation that sampling and PDF evaluation work for relaxed mixtures:")

m = torch.linspace(-16, 16, 1024)
# Temperatures sweep from 1 down to exactly 0; at 0 both helpers take their non-relaxed branch.
temps = np.linspace(1, 0, 256)
figures = []

for t in tqdm(temps):
    # Monte Carlo estimate of the relaxed density on the grid `m`.
    y = relaxed_gaussian_mixture_pdf(m, pi_logits, mu, sigma, n_samples=1024, temp=t)
    t0 = go.Scatter(x=m.numpy(), y=y.numpy(), mode='lines', name=f'pdf (t={t:.2f})', showlegend=True)
    
    # Histogram of samples drawn at the same temperature, for comparison with the curve above.
    x = gaussian_mixture_sample(8 * 1024, pi_logits, mu, sigma, temp=t)
    t1 = go.Histogram(x=x, histnorm='probability density', name=f'hist t={t:.2f}', showlegend=True)
    l = make_layout(width=950, height=150)
    # Fixed axis ranges keep the frames aligned across the animation.
    l.xaxis.update(range=[-16, 16])
    l.yaxis.update(range=[0, 0.28])
    figures.append(go.Figure(data=[t0, t1], layout=l))
    
display(figures_to_gif(figures, duration=20))

In [None]:
# Measures the norm of the loss gradient when the learned mixture already equals the reference
# mixture, as a function of the relaxation temperature. At the optimum an unbiased estimator
# should give a gradient near zero, so the norm is used as a proxy for bias.
# NOTE: every tensor is created with `.cuda()`, so this cell needs a CUDA device.
show("### Bias of Gradient At Optimum, When Fitting a Gaussian Mixture")

from tqdm import tqdm

sample_size = 1024 * 1024
temps = [1 / 1024, 1 / 512, 1 / 256, 1 / 128, 1 / 64, 1 / 32, 1 / 16, 1 / 8, 1 / 4, 1 / 2, 1]

# Gradients collected with the exact mixture log-likelihood (mode 0) ...
pi_logits_grads = []
mu_grads = []
log_sigma_grads = []

# ... and with the relaxed log-likelihood (mode 1). One entry per temperature in each list.
pi_logits_grads_relax = []
mu_grads_relax = []
log_sigma_grads_relax = []

for temp in tqdm(temps):
    for mode in [0, 1]:
        # Reference (target) mixture.
        pi_logits_ref = torch.Tensor([0, 1, 2, 3]).cuda()
        mu_ref        = torch.Tensor([-12, -4, 4, 12]).cuda()
        sigma_ref     = torch.Tensor([1, 1, 1, 1]).cuda()

        # Learned mixture, initialized to the same values as the reference. These are rebuilt on
        # every pass so that `.grad` starts empty. They also rebind the module-level names
        # `pi_logits` and `mu` used by earlier cells.
        pi_logits = torch.Tensor([0, 1, 2, 3]).cuda()
        mu        = torch.Tensor([-12, -4, 4, 12]).cuda()
        log_sigma = torch.zeros_like(sigma_ref).cuda()

        pi_logits.requires_grad = True
        mu.requires_grad        = True
        log_sigma.requires_grad = True

        # Samples are always drawn from the relaxed mixture at the current temperature.
        x = gaussian_mixture_sample(sample_size, pi_logits, mu, torch.exp(log_sigma), temp)

        if mode == 0:
            llq = gaussian_mixture_log_likelihood(x, pi_logits, mu, torch.exp(log_sigma))
        elif mode == 1:
            llq = relaxed_gaussian_mixture_log_likelihood(x, pi_logits, mu, torch.exp(log_sigma), n_samples=1024, temp=temp)

        # Sample estimate of the reverse KL divergence E_q[log q - log p].
        llp = gaussian_mixture_log_likelihood(x, pi_logits_ref, mu_ref, sigma_ref)
        loss = torch.mean(llq - llp)
        loss.backward()

        if mode == 0:
            pi_logits_grads.append(pi_logits.grad.cpu().detach().numpy())
            mu_grads.append(mu.grad.cpu().detach().numpy())
            log_sigma_grads.append(log_sigma.grad.cpu().detach().numpy())
        if mode == 1:
            pi_logits_grads_relax.append(pi_logits.grad.cpu().detach().numpy())
            mu_grads_relax.append(mu.grad.cpu().detach().numpy())
            log_sigma_grads_relax.append(log_sigma.grad.cpu().detach().numpy())
    
# Reduce every gradient vector to its Euclidean norm.
pi_logits_grad_norms = [np.linalg.norm(g) for g in pi_logits_grads]
mu_grad_norms        = [np.linalg.norm(g) for g in mu_grads]
log_sigma_grad_norms = [np.linalg.norm(g) for g in log_sigma_grads]

pi_logits_grad_norms_relax = [np.linalg.norm(g) for g in pi_logits_grads_relax]
mu_grad_norms_relax        = [np.linalg.norm(g) for g in mu_grads_relax]
log_sigma_grad_norms_relax = [np.linalg.norm(g) for g in log_sigma_grads_relax]

# x axis: log2(1 / temp), so lower temperatures sit further to the right.
x = [np.log(1 / t) / np.log(2) for t in temps]
traces = []

for name, norms in zip(
    ['pi logits', 'mu', 'log sigma', 'pi logits (relax)', 'mu (relax)', 'log sigma (relax)'],
    [pi_logits_grad_norms, mu_grad_norms, log_sigma_grad_norms, pi_logits_grad_norms_relax, mu_grad_norms_relax, log_sigma_grad_norms_relax]):
    traces.append(go.Scatter(x=x, y=norms, mode='markers+lines', name=name))

l = make_layout(width=950, height=600)
l.xaxis.update(title='lg(1 / temp)')
l.yaxis.update(title='norm of gradient at optimum')
f = go.Figure(data=traces, layout=l)
py.iplot(f)

In [None]:
# Fits the component means of a Gaussian mixture by minimizing a sample estimate of the reverse KL
# divergence, marginalizing the discrete component choice exactly:
#     KL(q || p) = sum_i pi_i * E_{x ~ N(mu_i, sigma_i)}[log q(x) - log p(x)].
# NOTE: every tensor is created with `.cuda()`, so this cell needs a CUDA device.
show("### Fitting a Gaussian Mixture by Optimizing Reverse KL-Divergence (Exact Marginalization)")

# NOTE(review): the observation below was written while the per-component samples were drawn with
# the non-differentiable `.sample()`; it should be re-validated now that `.rsample()` is used.
show(r"""Observation: the objective is difficult to optimize using a first-order method. The loss does
not always decrease monotonically, and runs are liable to divergence. This happens even with various
settings for the step size decay.""")

import torch.optim

from tqdm import tqdm

n_batch = 1024

# stable up to ~2k iters
init_step_size         = 1e-1
final_step_size        = init_step_size / 32
step_size_anneal_iters = 4 * 1024
total_iters            = 32 * 1024

def cosine_ramp(start: float, end: float, cur: int, total: int) -> float:
    """Half-cosine interpolation from `start` (cur = 0) to `end` (cur >= total)."""
    return start + (end - start) * (1 - 1 / 2 * (1 + np.cos(np.pi * min(cur / total, 1))))

step_size = lambda t: cosine_ramp(init_step_size, final_step_size, t, step_size_anneal_iters)

# Reference (target) mixture.
pi_logits_ref = torch.Tensor([0, 0, 0, 0]).cuda() # TODO: try different logits later
mu_ref        = torch.Tensor([-12, -4, 4, 12]).cuda()
sigma_ref     = torch.Tensor([1, 1, 1, 1]).cuda()

# Learned mixture. Only the means are optimized; weights and scales stay fixed.
pi_logits = torch.Tensor([0, 0, 0, 0]).cuda()
mu        = torch.linspace(-8, 8, 4).cuda() #torch.empty_like(mu_ref).uniform_(-8, 8)
#sigma     = torch.ones_like(sigma_ref)
log_sigma = torch.zeros_like(sigma_ref).cuda()

#pi_logits.requires_grad = True
mu.requires_grad = True
#log_sigma.requires_grad = True

optim = torch.optim.Adam([mu], lr=init_step_size, betas=(0.99, 0.999))
losses = []

for cur_iter in tqdm(range(total_iters)):
    loss = 0
    pi = F.softmax(pi_logits, dim=0)

    for i in range(pi_logits.shape[0]):
        # Reparameterized draw (x = mu_i + sigma_i * eps) so that the gradient of the expectation
        # w.r.t. mu flows through the samples. With a detached `.sample()` the only remaining
        # gradient path is d/dmu log q(x), whose expectation under q is zero, i.e. the optimizer
        # would be driven by zero-mean noise.
        x = D.Normal(loc=mu[i].repeat(n_batch), scale=torch.exp(log_sigma[i]).repeat(n_batch)).rsample()
        llq = gaussian_mixture_log_likelihood(x, pi_logits, mu, torch.exp(log_sigma))
        llp = gaussian_mixture_log_likelihood(x, pi_logits_ref, mu_ref, sigma_ref)
        loss += pi[i] * torch.mean(llq - llp)
    
    losses.append(loss.cpu().detach().numpy())
    
    # Apply the annealed step size for this iteration.
    for g in optim.param_groups:
        g['lr'] = step_size(cur_iter)
        
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if cur_iter % 256 == 0 or not torch.isfinite(loss):
        print(loss)
        print(mu)
        
        if not torch.isfinite(loss):
            print("got nonfinite loss; exiting")
            break

In [None]:
# Loss curve of the preceding training cell: one point per recorded entry of `losses`.
t = go.Scatter(x=np.arange(len(losses)), y=losses, mode='lines')
l = make_layout(width=950, height=600)
f = go.Figure(data=[t], layout=l)
py.iplot(f)

In [None]:
# Fits the component means with the Gumbel-softmax relaxation applied to BOTH the sampler and the
# log-likelihood, while also learning the relaxation temperature (parameterized through softplus).
show("### Fitting a Gaussian Mixture by Optimizing Reverse KL-Divergence (Relaxation, Both Quantities, Learned Temp)")

import torch.optim

from tqdm import tqdm

"""
Tuned config for Adam. Different temperature and step size schedules generally did worse.
"""
n_batch = 1024

# NOTE(review): `init_temp`, `final_temp` and `temp_anneal_iters` are not referenced anywhere in
# this cell (the temperature is learned instead of annealed) -- presumably leftovers.
init_temp              = 1 / 1
final_temp             = 1 / 64
temp_anneal_iters      = 2 * 1024
total_iters            = 32 * 1024

# The smaller the final relaxation temperature, the smaller the final step size to which we
# need to decay.
init_step_size         = 1e-2
final_step_size        = init_step_size / 4
step_size_anneal_iters = 2 * 1024

def cosine_ramp(start: float, end: float, cur: int, total: int) -> float:
    # Half-cosine interpolation from `start` (cur = 0) to `end` (cur >= total).
    return start + (end - start) * (1 - 1 / 2 * (1 + np.cos(np.pi * min(cur / total, 1))))

temp      = 2.3 * torch.ones(1) # Chosen so that softplus(temp / 4) is approximately 1 initially.
step_size = lambda t: cosine_ramp(init_step_size, final_step_size, t, step_size_anneal_iters)

# Reference (target) mixture.
pi_logits_ref = torch.Tensor([0, 0, 0, 0]) # TODO: try different logits later
mu_ref        = torch.Tensor([-12, -4, 4, 12])
sigma_ref     = torch.Tensor([1, 1, 1, 1])

# Learned mixture; `sigma` is not referenced again in this cell (the scales come from `log_sigma`).
pi_logits = torch.Tensor([0, 0, 0, 0])
#mu        = torch.linspace(-8, 8, 4) # This works better, but makes the problem easier.
mu        = torch.empty_like(mu_ref).uniform_(-8, 8)
sigma     = torch.ones_like(sigma_ref)
log_sigma = torch.zeros_like(sigma_ref)

temp.requires_grad = True
#pi_logits.requires_grad = True
mu.requires_grad = True
#log_sigma.requires_grad = True # Learning this is still unstable.

optim = torch.optim.Adam([temp, mu], lr=init_step_size, betas=(0.99, 0.999))
losses = []

for cur_iter in tqdm(range(total_iters)):
    # The effective temperature softplus(temp / 4) is a one-element tensor, which keeps it positive.
    # NOTE(review): whether gradients reach `mu` through `x` depends on how
    # gaussian_mixture_sample draws its component values -- confirm.
    x = gaussian_mixture_sample(n_batch, pi_logits, mu, torch.exp(log_sigma), F.softplus(temp / 4))
    llq = relaxed_gaussian_mixture_log_likelihood(x, pi_logits, mu, torch.exp(log_sigma), n_samples=1024, temp=F.softplus(temp / 4))
    llp = gaussian_mixture_log_likelihood(x, pi_logits_ref, mu_ref, sigma_ref)
    # Sample estimate of the reverse KL divergence E_q[log q - log p].
    loss = torch.mean(llq - llp)
    losses.append(loss.detach().numpy())
    
    # Apply the annealed step size for this iteration.
    for g in optim.param_groups:
        g['lr'] = step_size(cur_iter)
        
    optim.zero_grad()
    loss.backward()
    optim.step()
    
    if cur_iter % 512 == 0 or not torch.isfinite(loss):
        print(loss, F.softplus(temp / 4))
        print(mu)
        
        if not torch.isfinite(loss):
            print("got nonfinite loss; exiting")
            break

In [None]:
# Loss curve of the preceding training cell: one point per recorded entry of `losses`.
t = go.Scatter(x=np.arange(len(losses)), y=losses, mode='lines')
l = make_layout(width=950, height=600)
f = go.Figure(data=[t], layout=l)
py.iplot(f)

In [None]:
"""
Flow utilities.
"""

import attr
import torch
import torch.nn as nn
import torch.nn.functional as F

from collections import OrderedDict
from typing      import Tuple

def inverse_softplus(y: float) -> float:
    """
    Returns `x` such that `softplus(x) = log(1 + exp(x)) = y`, for `y > 0`.

    Uses the identity `log(exp(y) - 1) = y + log(1 - exp(-y))` with `expm1`, which neither
    overflows for large `y` nor cancels catastrophically for `y` close to zero. As with the direct
    formula, `y == 0` yields `-inf` and `y < 0` yields `nan`.
    """
    return y + np.log(-np.expm1(-y))

def _require_slope_above_half(instance, attribute, value) -> None:
    """attrs validator: rejects slopes that the `0.5 + softplus(.)` parameterization cannot reach."""
    if not value > 0.5:
        raise ValueError(f"{attribute.name} must be > 0.5, got {value}")

@attr.s(eq=False, repr=False)
class LocalWarpTransformation(nn.Module):
    """
    Smooth scalar warp `f(x) = x - r * (2 * sigmoid(2 * (1 - s) / r * (x - b)) - 1)`.

    `r = softplus(self.r)` (initialized from `init_radius`), `s = 0.5 + softplus(self.s)`
    (initialized from `init_slope`) and `b = self.b` (initialized from `init_loc`) are learnable
    scalars. The derivative of `f` at `x = b` equals `s`.

    `max_slope` is validated and stored but is not used by any method of this class.
    `requires_grad` controls whether the three parameters are trainable.
    """
    init_radius:   float        = attr.ib(default=1.)
    init_slope:    float        = attr.ib(default=1., validator=_require_slope_above_half)
    init_loc:      float        = attr.ib(default=0.)
    max_slope:     float        = attr.ib(default=10., validator=_require_slope_above_half)
    device:        torch.device = attr.ib(default=None)
    requires_grad: bool         = attr.ib(default=True)
        
    def __attrs_post_init__(self) -> None:
        super().__init__()
        
        # Raw parameters live in pre-softplus space, so the requested initial values are mapped
        # through the inverse of the transform applied in forward().
        r = torch.full((), fill_value=inverse_softplus(self.init_radius), dtype=torch.float32,
            device=self.device)
        s = torch.full((), fill_value=inverse_softplus(self.init_slope - 0.5), dtype=torch.float32,
            device=self.device)
        b = torch.full((), fill_value=self.init_loc, dtype=torch.float32, device=self.device)
        
        # nn.Parameter defaults to requires_grad=True, so the flag has to be passed explicitly.
        self.r = nn.Parameter(r, requires_grad=self.requires_grad)
        self.s = nn.Parameter(s, requires_grad=self.requires_grad)
        self.b = nn.Parameter(b, requires_grad=self.requires_grad)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies the warp elementwise to `x`."""
        r, s = F.softplus(self.r), 0.5 + F.softplus(self.s)
        return x - r * (2 * torch.sigmoid(2 * (1 - s) / r * (x - self.b)) - 1)
    
    def jacobian(self, x: torch.Tensor) -> torch.Tensor:
        """Elementwise derivative of `forward` at `x`: `1 - 4 * (1 - s) * y * (1 - y)`, `y = sigmoid(z)`."""
        r, s = F.softplus(self.r), 0.5 + F.softplus(self.s)
        z = 2 * (1 - s) / r * (x - self.b)
        y = torch.sigmoid(z)
        return 1 - 4 * (1 - s) * y * (1 - y)
    
    def inverse(self, y: torch.Tensor, step_size: float = 1., n_iters: int = 16) -> torch.Tensor:
        """
        Inverts the function using Newton's method.

        Starts from `x = y` and runs `n_iters` damped Newton steps of size `step_size`.
        
        XXX: this can fail when `self.s` is large; the secant method or bisection are more reliable
        choices.
        """
        
        x_rec = y
        
        for _ in range(n_iters):
            # The small constant guards against division by a vanishing derivative.
            x_rec = x_rec - step_size * (self.forward(x_rec) - y) / (self.jacobian(x_rec) + 1e-8)
            
        return x_rec
    
    def ln_abs_det_jacobian(self, x: torch.Tensor) -> torch.Tensor:
        """Elementwise log of the derivative at `x` (no absolute value is taken)."""
        # XXX: this computation is likely not stable.
        return torch.log(self.jacobian(x))

@attr.s(eq=False, repr=False)
class AffineTransformation(nn.Module):
    """
    Learnable scalar affine map `y = w * x + b`, initialized to the identity (`w = 1`, `b = 0`).

    `requires_grad` controls whether `w` and `b` are trainable.
    """
    device:        torch.device = attr.ib(default=None)
    requires_grad: bool         = attr.ib(default=True)
        
    def __attrs_post_init__(self) -> None:
        super().__init__()
        
        w = torch.ones((), dtype=torch.float32, device=self.device)
        b = torch.zeros((), dtype=torch.float32, device=self.device)
        # nn.Parameter defaults to requires_grad=True, so the flag has to be passed explicitly.
        self.w = nn.Parameter(w, requires_grad=self.requires_grad)
        self.b = nn.Parameter(b, requires_grad=self.requires_grad)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Applies `w * x + b` elementwise."""
        return self.w * x + self.b
    
    def inverse(self, y: torch.Tensor, **kwargs) -> torch.Tensor:
        """Exact inverse `(y - b) / w`; extra keyword arguments are accepted and ignored."""
        return (y - self.b) / self.w
    
    def ln_abs_det_jacobian(self, x: torch.Tensor) -> torch.Tensor:
        """Returns `log|w|` broadcast to the shape of `x` (the derivative does not depend on `x`)."""
        return self.w.abs().log() * torch.ones_like(x)
    
    def post_update(self) -> None:
        """No-op hook."""
        pass
    
def _require_at_least_one_block(instance, attribute, value) -> None:
    """attrs validator: rejects block counts below 1."""
    if value < 1:
        raise ValueError(f"{attribute.name} must be >= 1, got {value}")

@attr.s(eq=False, repr=False)
class Flow(nn.Module):
    """
    Scalar normalizing flow: `n_blocks` repetitions of (affine, local warp) followed by one final
    affine layer.

    `init_slope` is forwarded to every warp layer; `device` and `requires_grad` are forwarded to
    every layer.
    """
    n_blocks:          int          = attr.ib(validator=_require_at_least_one_block)
    init_slope:        float        = attr.ib(default=1.)
    device:            torch.device = attr.ib(default=None)
    requires_grad:     bool         = attr.ib(default=True)
        
    def __attrs_post_init__(self) -> None:
        super().__init__()
        
        # Settings shared by every layer, so that the flow's own options actually take effect.
        shared = dict(device=self.device, requires_grad=self.requires_grad)
        modules = []
        
        for k in range(self.n_blocks):
            modules.append((f'affine_{k + 1}', AffineTransformation(**shared)))
            modules.append((f'warp_{k + 1}', LocalWarpTransformation(init_slope=self.init_slope, **shared)))

        modules.append((f'affine_{self.n_blocks + 1}', AffineTransformation(**shared)))
        self.modules_ = nn.Sequential(OrderedDict(modules))
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pushes `x` through all layers in order."""
        return self.modules_(x)
    
    def inverse(self, y: torch.Tensor, **kwargs) -> torch.Tensor:
        """Pulls `y` back through the layers in reverse order; `kwargs` go to each layer's `inverse`."""
        x = y
        
        for _, module in reversed(list(self.modules_.named_children())):
            x = module.inverse(x, **kwargs)
            
        return x
    
    def forward_with_jacobian_factor(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns `(y, f)`: the transformed value and the sum over layers of each layer's
        `ln_abs_det_jacobian`, evaluated at that layer's input.
        """
        y, f = x, torch.zeros_like(x)
        
        for _, module in self.modules_.named_children():
            # The log-derivative must be taken at the layer's input, i.e. before updating y.
            f = f + module.ln_abs_det_jacobian(y)
            y = module.forward(y)
            
        return y, f
    
    def inverse_with_jacobian_factor(self, y: torch.Tensor, **kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns `(x, f)`: the recovered input and the sum over layers of each layer's
        `ln_abs_det_jacobian`, evaluated at that layer's recovered input (so `f` is the forward
        log-derivative, accumulated while walking the layers backwards).
        """
        x, f = y, torch.zeros_like(y)
        
        for _, module in reversed(list(self.modules_.named_children())):
            x = module.inverse(x, **kwargs)
            f = f + module.ln_abs_det_jacobian(x)
            
        return x, f

In [None]:
# Trains the flow so that pushing a standard normal through it matches a Gaussian mixture, by
# minimizing a sample estimate of the reverse KL divergence E_q[log q(y) - log p(y)].
show("## Basic Test: Fitting More Gaussian Mixtures")

import torch.optim

# Successful configuration for three-mode mixture.

n_batch = 32 * 1024

pi_logits_ref = torch.Tensor([0, 0, 0])
mu_ref        = torch.Tensor([0, 2, 4])
sigma_ref     = torch.Tensor([1 / 2, 1 / 2, 1 / 2])

base_dist = D.Normal(loc=torch.zeros(n_batch), scale=torch.ones(n_batch))
f = Flow(n_blocks=16)

init_step_size         = 1e-3
final_step_size        = 1e-3 / 1
step_size_anneal_iters = 8 * 1024
total_iters            = 64 * 1024

"""
# Here we check that the 16-layer flow with affine layers can still fit a 2d mixture.

n_batch = 32 * 1024

pi_logits_ref = torch.Tensor([0, 0])
mu_ref        = torch.Tensor([-2, 2])
sigma_ref     = torch.Tensor([1 / 2, 1 / 2])

base_dist = D.Normal(loc=torch.zeros(n_batch), scale=torch.ones(n_batch))
f = Flow(use_affine_layers=True, n_blocks=16)

init_step_size         = 1e-3
final_step_size        = 1e-3 / 1
step_size_anneal_iters = 8 * 1024
total_iters            = 64 * 1024
"""

def cosine_ramp(start: float, end: float, cur: int, total: int) -> float:
    """Half-cosine interpolation from `start` (cur = 0) to `end` (cur >= total)."""
    return start + (end - start) * (1 - 1 / 2 * (1 + np.cos(np.pi * min(cur / total, 1))))

step_size = lambda t: cosine_ramp(init_step_size, final_step_size, t, step_size_anneal_iters)
optim = torch.optim.Adam(list(f.parameters()), lr=init_step_size, betas=(0.99, 0.999))

# Per-iteration history of the loss and of every (scalar) flow parameter.
losses = []
params = {k : [] for k, _ in f.named_parameters()}

for cur_iter in range(total_iters):
    x = base_dist.sample()
    y, d = f.forward_with_jacobian_factor(x)
    # Change of variables: log q(y) = log N(x; 0, 1) - log|dy/dx|.
    llq = (-1 / 2 * np.log(2 * np.pi) - x ** 2 / 2) - d
    llp = gaussian_mixture_log_likelihood(y, pi_logits_ref, mu_ref, sigma_ref)
    loss = (llq - llp).mean()
    
    # Apply the step-size schedule. With final_step_size == init_step_size, as configured above,
    # this leaves the learning rate unchanged; it takes effect once the two values differ.
    for g in optim.param_groups:
        g['lr'] = step_size(cur_iter)
    
    optim.zero_grad()
    loss.backward()
    
    has_bad_grad = False
    
    for name, param in f.named_parameters():
        # `.all()` keeps the check valid for parameters of any shape, not only scalars.
        if not torch.isfinite(param.grad).all():
            print(f"grad for param {name} has nonfinite values (step {cur_iter})")
            has_bad_grad = True
            
    # Stop before the optimizer can write non-finite values into the parameters.
    if has_bad_grad:
        break
    
    optim.step()
    losses.append(float(loss.detach().numpy()))
    
    for k, v in f.named_parameters():
        params[k].append(float(v.detach().numpy()))
    
    if cur_iter % 512 == 0 or not torch.isfinite(loss):
        print(cur_iter, loss)
        
        if not torch.isfinite(loss):
            print("got nonfinite loss; exiting")
            break

In [None]:
# Animates how the trained flow `f` reshapes standard normal samples, one frame per layer.
show("### Visualization of Transformed Base Density")

import io
import plotly

from PIL import Image

x = D.Normal(loc=torch.zeros(8 * 1024), scale=torch.ones(8 * 1024)).sample()
y = x

# First frame: the untransformed base samples.
t = go.Histogram(x=y.numpy(), histnorm='probability density')
l = make_layout(width=950, height=150)
# Fixed axis ranges keep the frames aligned; the same layout object is shared by all frames.
l.xaxis.update(range=[-3, 6])
l.yaxis.update(range=[0, 0.6])

figures = [go.Figure(data=[t], layout=l)]

# no_grad: the layers are only evaluated here, and `.numpy()` needs tensors detached from autograd.
with torch.no_grad():
    for name, module in f.modules_.named_children():
        y = module.forward(y)
        t = go.Histogram(x=y.numpy(), histnorm='probability density')
        figures.append(go.Figure(data=[t], layout=l))
            
display(figures_to_gif(figures, duration=200))

In [None]:
show("## Basic Test: Fitting Laplace Distribution")

import torch.optim

def laplace_log_likelihood(x: torch.Tensor, mu: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """
    Log-density of the Laplace distribution with location `mu` and scale `b`, evaluated at `x`:
    `-log(2 * b) - |x - mu| / b`.
    """
    log_normalizer = torch.log(2 * b)
    scaled_distance = torch.abs(x - mu) / b
    return -log_normalizer - scaled_distance

"""
# Working configuration for learning Laplace(0, 1). The model was unable to match the target density
# well unless the slope parameters were initialized to values less than 1. Here, we use 0.6.

mu_ref = torch.zeros(1)
b_ref  = torch.ones(1)

n_batch = 32 * 1024

base_dist = D.Normal(loc=torch.zeros(n_batch), scale=torch.ones(n_batch))
f = Flow(n_blocks=16, init_slope=0.6)

init_step_size         = 1e-3
final_step_size        = 1e-3 / 1
step_size_anneal_iters = 8 * 1024
total_iters            = 64 * 1024
"""

# Trains the flow so that pushing a standard normal through it matches a narrow Laplace
# distribution (location 0, scale 1/16), by minimizing a sample estimate of the reverse KL
# divergence E_q[log q(y) - log p(y)].
mu_ref = torch.zeros(1)
b_ref  = torch.full((1,), fill_value=1 / 16)

n_batch = 32 * 1024

base_dist = D.Normal(loc=torch.zeros(n_batch), scale=torch.ones(n_batch))
f = Flow(n_blocks=16, init_slope=0.6)

init_step_size         = 1e-3
final_step_size        = 1e-3 / 1
step_size_anneal_iters = 8 * 1024
total_iters            = 64 * 1024

def cosine_ramp(start: float, end: float, cur: int, total: int) -> float:
    """Half-cosine interpolation from `start` (cur = 0) to `end` (cur >= total)."""
    return start + (end - start) * (1 - 1 / 2 * (1 + np.cos(np.pi * min(cur / total, 1))))

step_size = lambda t: cosine_ramp(init_step_size, final_step_size, t, step_size_anneal_iters)
optim = torch.optim.Adam(list(f.parameters()), lr=init_step_size, betas=(0.99, 0.999))

# Per-iteration history of the loss and of every (scalar) flow parameter.
losses = []
params = {k : [] for k, _ in f.named_parameters()}

for cur_iter in range(total_iters):
    x = base_dist.sample()
    y, d = f.forward_with_jacobian_factor(x)
    # Change of variables: log q(y) = log N(x; 0, 1) - log|dy/dx|.
    llq = (-1 / 2 * np.log(2 * np.pi) - x ** 2 / 2) - d
    llp = laplace_log_likelihood(y, mu_ref, b_ref)
    loss = (llq - llp).mean()
    
    # Apply the step-size schedule. With final_step_size == init_step_size, as configured above,
    # this leaves the learning rate unchanged; it takes effect once the two values differ.
    for g in optim.param_groups:
        g['lr'] = step_size(cur_iter)
    
    optim.zero_grad()
    loss.backward()
    
    has_bad_grad = False
    
    for name, param in f.named_parameters():
        # `.all()` keeps the check valid for parameters of any shape, not only scalars.
        if not torch.isfinite(param.grad).all():
            print(f"grad for param {name} has nonfinite values (step {cur_iter})")
            has_bad_grad = True
            
    # Stop before the optimizer can write non-finite values into the parameters.
    if has_bad_grad:
        break
    
    optim.step()
    losses.append(float(loss.detach().numpy()))
    
    for k, v in f.named_parameters():
        params[k].append(float(v.detach().numpy()))
    
    if cur_iter % 512 == 0 or not torch.isfinite(loss):
        print(cur_iter, loss)
        
        if not torch.isfinite(loss):
            print("got nonfinite loss; exiting")
            break

In [None]:
# Animates how the trained flow `f` reshapes standard normal samples, one frame per layer.
show("### Visualization of Transformed Base Density")

import io
import plotly

from PIL import Image

x = D.Normal(loc=torch.zeros(8 * 1024), scale=torch.ones(8 * 1024)).sample()
y = x

# First frame: the untransformed base samples.
t = go.Histogram(x=y.numpy(), histnorm='probability density')
l = make_layout(width=950, height=150)
# NOTE(review): these axis ranges are identical to the ones used for the mixture experiment. If
# `f` was fit to the Laplace target with scale 1/16 (peak density 1 / (2 * b) = 8), the y range
# [0, 0.6] would clip the later frames -- confirm the intended ranges.
l.xaxis.update(range=[-3, 6])
l.yaxis.update(range=[0, 0.6])

figures = [go.Figure(data=[t], layout=l)]

# no_grad: the layers are only evaluated here, and `.numpy()` needs tensors detached from autograd.
with torch.no_grad():
    for name, module in f.modules_.named_children():
        y = module.forward(y)
        t = go.Histogram(x=y.numpy(), histnorm='probability density')
        figures.append(go.Figure(data=[t], layout=l))
            
display(figures_to_gif(figures, duration=200))