# Neural Spline Flow on a Circular and a Normal Coordinate

In [None]:
# Import packages
import torch
import numpy as np

import normflows as nf

from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm

from tqdm import tqdm

In [None]:
# Set up target
class GaussianVonMises(nf.distributions.Target):
    """Two-dimensional target on one normal and one circular coordinate.

    The log-density evaluated by :meth:`log_prob` is

        log p(x) = -x_0**2 / 2 + cos(x_1 - 3 * x_0) + log_const

    i.e. a standard Gaussian factor in ``x_0`` times a von Mises factor in
    ``x_1`` with unit concentration whose mean direction is ``3 * x_0``.
    """

    def __init__(self):
        two_pi = 2 * np.pi
        # Forwarded to nf.distributions.Target; presumably these define the
        # proposal range scale * u + shift = [-pi, pi] used when sampling
        # from the target -- confirm against the normflows documentation.
        super().__init__(
            prop_scale=torch.tensor(two_pi),
            prop_shift=torch.tensor(-np.pi),
        )
        self.n_dims = 2
        # Upper bound on log_prob: its maximum is 1 + log_const ~= -1.993
        self.max_log_prob = -1.99
        # Log-normalizers of both factors:
        # -0.5 * log(2 pi) for the Gaussian, -log(2 pi * I0(1)) for the von Mises
        self.log_const = -1.5 * np.log(two_pi) - np.log(np.i0(1))

    def log_prob(self, x):
        """Return the log-density for each row of ``x``.

        Column 0 of ``x`` is read as the Gaussian coordinate and column 1 as
        the angle; the result has one entry per row.
        """
        gauss_coord = x[:, 0]
        angle_coord = x[:, 1]
        gaussian_term = -0.5 * gauss_coord ** 2
        von_mises_term = torch.cos(angle_coord - 3 * gauss_coord)
        return gaussian_term + von_mises_term + self.log_const

In [None]:
# Instantiate the target density; it is evaluated on a grid for plotting and
# handed to the flow model as its training target.
target = GaussianVonMises()

In [None]:
# Plot target
# Evaluation grid: the first axis spans the Gaussian coordinate, the second
# axis the angle in [-pi, pi]. `grid_size`, `xx`, `yy` and `zz` are reused by
# the later plotting cells.
grid_size = 300
xx, yy = torch.meshgrid(
    torch.linspace(-2.5, 2.5, grid_size),
    torch.linspace(-np.pi, np.pi, grid_size),
    indexing='ij',  # explicit: same layout as the legacy default, no deprecation warning
)
# Flatten the grid into a (grid_size ** 2, 2) batch of points. It stays on the
# CPU here; cells that evaluate the model move it with `zz.to(device)`.
zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)

log_prob = target.log_prob(zz).view(*xx.shape)
prob = torch.exp(log_prob)
# Zero out undefined values so they do not disturb the color scale
prob[torch.isnan(prob)] = 0

plt.figure(figsize=(15, 15))
# Angle on the horizontal axis, Gaussian coordinate on the vertical axis
plt.pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()

In [None]:
# Base distribution over two coordinates; index 1 is passed as the second
# argument and the per-coordinate scales are 1 and 2 * pi.
base = nf.distributions.UniformGaussian(2, [1], torch.tensor([1., 2 * np.pi]))

# Number of stacked spline layers
K = 20

# Build the flow as K identically configured circular spline layers,
# with coordinate index 1 declared circular.
flow_layers = []
for _ in range(K):
    spline_layer = nf.flows.CircularAutoregressiveRationalQuadraticSpline(
        2, 2, 128, [1],
        num_bins=20,
        tail_bound=torch.tensor([5., np.pi]),
        permute_mask=True,
    )
    flow_layers.append(spline_layer)

model = nf.NormalizingFlow(base, flow_layers, target)

# Move model on GPU if available
enable_cuda = True
use_gpu = torch.cuda.is_available() and enable_cuda
device = torch.device('cuda' if use_gpu else 'cpu')
model = model.to(device)

In [None]:
# Plot model
# Evaluate the (still untrained) flow on the plotting grid. This is pure
# inference over grid_size ** 2 points, so run it in eval mode without
# building an autograd graph, exactly as the in-training plots do.
model.eval()
with torch.no_grad():
    log_prob = model.log_prob(zz.to(device)).to('cpu').view(*xx.shape)
model.train()
prob = torch.exp(log_prob)
# Zero out undefined values so they do not disturb the color scale
prob[torch.isnan(prob)] = 0

plt.figure(figsize=(15, 15))
# Angle on the horizontal axis, Gaussian coordinate on the vertical axis
plt.pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')
plt.gca().set_aspect('equal', 'box')
plt.show()

In [None]:
# Train model
max_iter = 500          # number of optimizer iterations
num_samples = 2 ** 10   # samples passed to the loss in each iteration
show_iter = 100         # plot the learned density every `show_iter` iterations

# Collected as a plain list (O(1) append) and converted to an array once
# training is done; the name stays bound even if the loop is interrupted.
loss_hist = []

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-4)
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()

    # Compute loss
    loss = model.reverse_kld(num_samples)

    # Do backprop and optimizer step; skip the update when the loss is
    # NaN or infinite so a bad batch cannot corrupt the parameters
    if torch.isfinite(loss):
        loss.backward()
        optimizer.step()

    # Log loss (non-finite values are recorded too, so skipped steps stay visible)
    loss_hist.append(loss.item())

    # Plot learned model
    if (it + 1) % show_iter == 0:
        model.eval()
        with torch.no_grad():
            log_prob = model.log_prob(zz.to(device)).to('cpu').view(*xx.shape)
        model.train()
        prob = torch.exp(log_prob)
        # Zero out undefined values so they do not disturb the color scale
        prob[torch.isnan(prob)] = 0

        plt.figure(figsize=(15, 15))
        plt.pcolormesh(yy, xx, prob.data.numpy(), cmap='coolwarm')
        plt.gca().set_aspect('equal', 'box')
        plt.show()

loss_hist = np.array(loss_hist)

# Plot loss
plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()

In [None]:
# Show the most recently computed density `prob` wrapped around a cylinder:
# the angle runs around the circumference and the Gaussian coordinate runs
# along the cylinder axis.
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')

phi = np.linspace(-np.pi, np.pi, grid_size)
height = np.linspace(-2.5, 2.5, grid_size)

# Unit-radius cylinder surface. Rows index the height and columns index the
# angle, which matches the (Gaussian coordinate, angle) layout of `prob`.
x = np.outer(np.ones(grid_size), np.cos(phi))
y = np.outer(np.ones(grid_size), np.sin(phi))
z = np.outer(height, np.ones(grid_size))

# Rescale the density to [0, 1] so it can be mapped through the colormap
prob_vis = prob / torch.max(prob)
myheatmap = prob_vis.data.numpy()

# Hide the 3D axes through the public API instead of the private `_axis3don` flag
ax.set_axis_off()
ax.plot_surface(x, y, z, cstride=1, rstride=1, facecolors=cm.coolwarm(myheatmap), shade=False)

plt.show()