# Planar flow

In [None]:
# Import required packages
import torch
import numpy as np
import normflow as nf

from matplotlib import pyplot as plt

In [None]:
# Assemble the model: K planar flow layers stacked on top of a base distribution.
K = 2
flows = [nf.flows.Planar((2,)) for _ in range(K)]

# Target density the flow is fitted to.
# Alternative kept from an earlier version of this notebook (a 2-D standard normal):
#   torch.distributions.MultivariateNormal(torch.zeros(2), torch.eye(2))
# NOTE(review): the meaning of TwoModes' two arguments is defined in normflow -- confirm there.
prior = nf.distributions.TwoModes(2, 0.1)

# Base distribution: constant diagonal Gaussian built from a zero vector and an
# all-ones vector (presumably location and per-dimension scale -- TODO confirm).
# The explicit factor 1 is left in place as an easy knob for rescaling.
q0 = nf.distributions.ConstDiagGaussian(np.zeros(2), 1 * np.ones(2))

# Tie target, base distribution and flow layers together.
nfm = nf.NormalizingFlow(prior=prior, q0=q0, flows=flows)

In [None]:
# Plot prior distribution
grid_size = 200
# Regular grid_size x grid_size lattice over [-3, 3] x [-3, 3]; the two
# coordinate grids are stacked along a new last axis so z[i, j] is one 2-D point.
xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size))
z = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2)

# This cell only evaluates and plots, so autograd tracking is switched off:
# it avoids building a computation graph for values that are never
# back-propagated through.
with torch.no_grad():
    log_prob = prior.log_prob(z)
    prob = torch.exp(log_prob)

plt.figure(figsize=(10, 10))
plt.pcolormesh(xx, yy, prob)
plt.show()

# NOTE(review): log_q / q are not plotted or otherwise used in this cell; they
# are kept because other cells may rely on them -- remove if confirmed unused.
with torch.no_grad():
    log_q = nfm.log_q(z, None)
    q = torch.exp(log_q)

# Plot initial posterior distribution
# The model call returns a triple; only the samples (first element) are needed here.
with torch.no_grad():
    z, _, _ = nfm(torch.zeros(100), num_samples=10000)
# detach() is the supported way to obtain a graph-free tensor (replaces `.data`).
z_np = z.detach().numpy()
plt.figure(figsize=(10, 10))
# Samples are indexed as a 3-D array whose last axis holds the two coordinates;
# flatten the leading axes to get one long list of points for the 2-D histogram.
plt.hist2d(z_np[:, :, 0].flatten(), z_np[:, :, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])
plt.show()

In [None]:
# Train model
max_iter = 10000
batch_size = 100
num_samples = 100

# Histories are accumulated in plain lists (amortised O(1) append) and turned
# into NumPy arrays once after the loop.  np.append allocates a new array and
# copies the whole history on every call, which is quadratic over max_iter.
loss_hist = []
log_q_hist = []
log_p_hist = []
# NOTE(review): x is an all-zeros tensor of length batch_size; it appears to
# serve only as a placeholder input to the model -- confirm against normflow.
x = torch.zeros(batch_size)
optimizer = torch.optim.Adam(nfm.parameters(), lr=0.0001)
for it in range(max_iter):
    optimizer.zero_grad()
    # The model returns (samples, log_q, log_p); the samples are not needed for the loss.
    _, log_q, log_p = nfm(x, num_samples)
    mean_log_q = torch.mean(log_q)
    mean_log_p = torch.mean(log_p)
    # Objective: mean log-density under the flow minus mean log-density under
    # the target, averaged over the drawn samples.
    loss = mean_log_q - mean_log_p
    loss.backward()
    optimizer.step()
    # .item() extracts a Python float without keeping the autograd graph alive.
    loss_hist.append(loss.item())
    log_q_hist.append(mean_log_q.item())
    log_p_hist.append(mean_log_p.item())

# Expose the histories as float64 NumPy arrays, as before.
loss_hist = np.array(loss_hist)
log_q_hist = np.array(log_q_hist)
log_p_hist = np.array(log_p_hist)

plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.plot(log_q_hist, label='log_q')
plt.plot(log_p_hist, label='log_p')
plt.legend()
plt.show()

In [None]:
# Plot learned posterior distribution
# Sampling for visualisation needs no gradients, so autograd tracking is
# disabled to avoid building a graph for the drawn samples.
with torch.no_grad():
    z, _, _ = nfm(x, num_samples=10000)
# detach() is the supported way to obtain a graph-free tensor (replaces `.data`).
z_np = z.detach().numpy()
plt.figure(figsize=(10, 10))
# Samples are indexed as a 3-D array whose last axis holds the two coordinates;
# flatten the leading axes to get one long list of points for the 2-D histogram.
plt.hist2d(z_np[:, :, 0].flatten(), z_np[:, :, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])
plt.show()