# Real NVP

In [None]:
# Import required packages
import torch
import numpy as np
import normflow as nf

from matplotlib import pyplot as plt
from tqdm import tqdm

In [None]:
# Set up model

# Define flows
K = 8  # number of MaskedAffineFlow (coupling) layers
# torch.manual_seed(0)  # uncomment for a reproducible initialization

# Binary mask selecting which coordinate each coupling layer conditions on.
# Use a floating-point mask: Module.double() only casts floating-point tensors,
# so an integer mask would stay int64 and rely on int/float type promotion when
# combined with the double-precision samples (presumably MaskedAffineFlow keeps
# it as a buffer -- NOTE(review): confirm against the normflow version in use).
b = torch.tensor([0., 1.])
flows = []
for i in range(K):
    s = nf.nets.MLP([2, 8, 2])
    t = nf.nets.MLP([2, 8, 2])
    if i % 2 == 0:
        flows += [nf.flows.MaskedAffineFlow(b, s, t)]
    else:
        # Alternate the mask so both coordinates get transformed, and add a
        # BatchNorm after every second coupling layer.
        flows += [nf.flows.MaskedAffineFlow(1 - b, s, t), nf.flows.BatchNorm()]
# Remove a trailing Batch Norm layer to allow arbitrary output. Check the type
# rather than blindly dropping the last element: for odd K the final layer is a
# MaskedAffineFlow, which must be kept.
if flows and isinstance(flows[-1], nf.flows.BatchNorm):
    flows.pop()

# Set prior and q0
# `prior` acts as the target density: the training loss below is
# mean(log_q) - mean(log_p). An alternative target is
# nf.distributions.Sinusoidal(0.2, 4).
prior = nf.distributions.TwoModes(2, 0.1)
# Base distribution: 2-D diagonal Gaussian built from zeros and ones
# (presumably loc=0, scale=1 per dimension).
q0 = nf.distributions.ConstDiagGaussian(np.zeros(2, dtype=np.float32), np.ones(2, dtype=np.float32))

# Construct flow model
nfm = nf.NormalizingFlow(prior=prior, q0=q0, flows=flows)

# Move model on GPU if available
enable_cuda = True
device = torch.device('cuda' if torch.cuda.is_available() and enable_cuda else 'cpu')
nfm = nfm.to(device)
nfm = nfm.double()  # run the whole model in double precision

In [None]:
# Plot prior distribution
grid_size = 200
xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size))
# Flatten the grid into an (N, 2) batch of points so log_prob only has to handle
# a single batch dimension, then reshape the densities back onto the grid.
# NOTE(review): assumes log_prob returns one value per row of its input.
z = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).reshape(-1, 2)
log_prob = prior.log_prob(z).reshape(xx.shape)
prob = torch.exp(log_prob)

plt.figure(figsize=(10, 10))
plt.pcolormesh(xx.numpy(), yy.numpy(), prob.detach().cpu().numpy())
plt.show()

# Plot initial posterior distribution
# Sampling only: no gradients needed for these 512 x 512 samples.
with torch.no_grad():
    z, _, _ = nfm(torch.zeros(512, device=device), num_samples=512)
# z is indexed as (batch, sample, coordinate) below.
z_np = z.cpu().numpy()
plt.figure(figsize=(10, 10))
plt.hist2d(z_np[:, :, 0].flatten(), z_np[:, :, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])
plt.show()

In [None]:
# Train model
max_iter = 20000
batch_size = 128
num_samples = 128
anneal_iter = 10000  # iterations over which the log_p weight ramps up to 1
annealing = True
show_iter = 500  # plot and record history every show_iter iterations


# Loss / log-density histories, recorded once every show_iter iterations.
loss_hist = np.array([])
log_q_hist = np.array([])
log_p_hist = np.array([])
# Dummy input passed to nfm; presumably only its length (batch size) matters.
# NOTE(review): confirm against NormalizingFlow.forward.
x = torch.zeros(batch_size, device=device)

optimizer = torch.optim.Adam(nfm.parameters(), lr=1e-3, weight_decay=1e-3)
for it in tqdm(range(max_iter)):
    optimizer.zero_grad()
    _, log_q, log_p = nfm(x, num_samples)
    mean_log_q = torch.mean(log_q)
    mean_log_p = torch.mean(log_p)
    if annealing:
        # Weight of the log_p term rises linearly from 0.01 and is capped at 1
        # once it reaches anneal_iter.
        beta = min(1., 0.01 + it / anneal_iter)
        loss = mean_log_q - beta * mean_log_p
    else:
        loss = mean_log_q - mean_log_p
    loss.backward()
    optimizer.step()

    # Plot learned posterior
    if (it + 1) % show_iter == 0:
        # Draw the plot samples with a fixed seed so successive plots are
        # comparable, but inside fork_rng so the global CPU/CUDA RNG state is
        # restored afterwards. Reseeding the global generator directly would
        # make training replay the same base samples after every plot.
        with torch.random.fork_rng(), torch.no_grad():
            torch.manual_seed(0)
            z, _, _ = nfm(torch.zeros(512, device=device), num_samples=512)
        z_np = z.cpu().numpy()

        plt.figure(figsize=(10, 10))
        plt.hist2d(z_np[:, :, 0].flatten(), z_np[:, :, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])
        plt.show()

        loss_hist = np.append(loss_hist, loss.item())
        log_q_hist = np.append(log_q_hist, mean_log_q.item())
        log_p_hist = np.append(log_p_hist, mean_log_p.item())

plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.plot(log_q_hist, label='log_q')
plt.plot(log_p_hist, label='log_p')
plt.legend()
plt.show()

In [None]:
# Plot learned posterior distribution
# Sampling only: skip autograd bookkeeping for the 512 x 512 samples.
with torch.no_grad():
    z, _, _ = nfm(torch.zeros(512, device=device), num_samples=512)
# z is indexed as (batch, sample, coordinate) below.
z_np = z.cpu().numpy()
plt.figure(figsize=(10, 10))
plt.hist2d(z_np[:, :, 0].flatten(), z_np[:, :, 1].flatten(), (grid_size, grid_size), range=[[-3, 3], [-3, 3]])
plt.show()