# Application of a Normalizing Flow with a Resampled Base Distribution on a Toy Problem

In [None]:
# Import packages
import numpy as np
import torch
import normflow as nf
import larsflow as lf

from sklearn.datasets import make_moons

from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# Plot a 2D histogram of a large sample from the two-moons target distribution.
grid_size = 300  # bins per axis here; also the grid resolution used for evaluation
noise = .05      # noise level passed to make_moons (reused for training batches)

x_np = make_moons(2 ** 20, noise=noise)[0]
plot_range = [[-3, 3], [-3, 3]]

plt.figure(figsize=(15, 15))
plt.hist2d(x_np[:, 0], x_np[:, 1], (grid_size, grid_size), range=plot_range)
ax = plt.gca()
ax.set_aspect('equal')
ax.tick_params(axis='both', which='major', labelsize=24)
plt.xlabel('$x_1$', fontsize=32)
plt.ylabel('$x_2$', fontsize=32)
#plt.savefig('target.png')
plt.show()

## Train Flow with Gaussian Base Distribution

In [None]:
# Set up an affine-coupling flow on top of a fixed diagonal Gaussian base.

# Flow depth: number of (coupling, permutation, ActNorm) triples
K = 8
torch.manual_seed(0)
np.random.seed(0)

latent_size = 2
# NOTE(review): this alternating mask is never used below; AffineCouplingBlock
# only receives the parameter network.
b = torch.Tensor([(i + 1) % 2 for i in range(latent_size)])

flows = []
for _ in range(K):
    # The conditioner maps latent_size // 2 inputs to latent_size outputs
    coupling_net = nf.nets.MLP([latent_size // 2, 64, 64, latent_size], init_zeros=True)
    # Elements are constructed left to right, so the RNG is consumed in the same order
    flows.extend([
        nf.flows.AffineCouplingBlock(coupling_net),
        nf.flows.Permute(latent_size, mode='swap'),
        nf.flows.ActNorm(latent_size),
    ])

# Non-trainable diagonal Gaussian base distribution
q0 = nf.distributions.DiagGaussian(latent_size, trainable=False)

model = lf.NormalizingFlow(q0=q0, flows=flows)

# Run on the GPU when one is present and enable_cuda is set
enable_cuda = True
use_gpu = enable_cuda and torch.cuda.is_available()
device = torch.device('cuda' if use_gpu else 'cpu')
model = model.to(device)

In [None]:
# Evaluation grid: grid_size x grid_size points covering [-3, 3]^2, flattened
# to shape (grid_size**2, 2) and placed on the model's device.
axis_points = torch.linspace(-3, 3, grid_size)
xx, yy = torch.meshgrid(axis_points, axis_points)
zz = torch.stack([xx, yy], dim=2).view(-1, 2).to(device)

In [None]:
# Train the flow with the Gaussian base by minimizing the forward KL
# (model.forward_kld) on fresh two-moons batches.
max_iter = 10000
num_samples = 2 ** 12
show_iter = 500

model.train()

# Adam with weight decay. The GradScaler only rescales the loss/gradients; no
# autocast region is used, so the forward pass stays in full precision. The
# scaler is enabled only on CUDA (on CPU it would disable itself with a warning).
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-2)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

# Collect into a list: np.append copies the whole array on every call, which
# makes the loop quadratic in max_iter.
losses = []
for it in tqdm(range(max_iter)):
    # Fresh batch from the target distribution each iteration
    x_np, _ = make_moons(num_samples, noise=noise)
    x = torch.tensor(x_np).float().to(device, non_blocking=True)

    loss = model.forward_kld(x)
    loss_value = loss.item()

    # Skip the update on a non-finite loss. With the scaler disabled (CPU),
    # scaler.step would otherwise apply NaN/inf gradients to the parameters.
    if np.isfinite(loss_value):
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    losses.append(loss_value)

    # Clear gradients
    nf.utils.clear_grad(model)

loss_hist = np.array(losses)

plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()

In [None]:
def _plot_on_grid(values, colorbar=False, filename=None):
    """Show `values` as a heat map over the evaluation grid (xx, yy).

    values: 2-D array reshaped to xx.shape.
    colorbar: add a colorbar to the figure.
    filename: if given, the figure is also saved to this path before showing.
    """
    plt.figure(figsize=(15, 15))
    plt.pcolormesh(xx, yy, values)
    plt.gca().set_aspect('equal')
    plt.gca().tick_params(axis='both', which='major', labelsize=24)
    plt.xlabel('$x_1$', fontsize=32)
    plt.ylabel('$x_2$', fontsize=32)
    if colorbar:
        plt.colorbar()
    if filename is not None:
        plt.savefig(filename)
    plt.show()


model.eval()

# Evaluation only: no_grad avoids building an autograd graph over all
# grid_size**2 grid points.
with torch.no_grad():
    # Density of the full flow model; NaNs are shown as zero density
    log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)
    prob = torch.exp(log_prob)
    prob[torch.isnan(prob)] = 0
    _plot_on_grid(prob.numpy())  # pass filename='rnvp.png' to save

    # Density of the base distribution q0 alone
    log_prob = model.q0.log_prob(zz).to('cpu').view(*xx.shape)
    prob = torch.exp(log_prob)
    prob[torch.isnan(prob)] = 0
    _plot_on_grid(prob.numpy())  # pass filename='base.png' to save

## Train Flow with Resampled Base Distribution

In [None]:
# Set up the same flow architecture, now on a resampled Gaussian base.

# Re-seed so the flow layers are built from the same RNG state as in the
# Gaussian run above
torch.manual_seed(0)
np.random.seed(0)

flows = []
for _ in range(K):
    coupling_net = nf.nets.MLP([latent_size // 2, 64, 64, latent_size], init_zeros=True)
    # Elements are constructed left to right, so the RNG is consumed in the same order
    flows.extend([
        nf.flows.AffineCouplingBlock(coupling_net),
        nf.flows.Permute(latent_size, mode='swap'),
        nf.flows.ActNorm(latent_size),
    ])

# Network a: latent_size -> 1 with a sigmoid output, used by the resampled base
a = nf.nets.MLP([latent_size, 64, 64, 1], output_fn="sigmoid")
# NOTE(review): 100 and 0.1 are positional ResampledGaussian arguments
# (presumably the truncation T and eps) -- confirm against larsflow.
q0 = lf.distributions.ResampledGaussian(latent_size, a, 100, 0.1, trainable=False)

# Build the model and move it to the device chosen in the first setup cell
model = lf.NormalizingFlow(q0=q0, flows=flows).to(device)

In [None]:
# Train the flow with the resampled base by minimizing the forward KL
# (model.forward_kld); reuses max_iter, num_samples and noise from above.

model.train()

# Adam with a smaller weight decay (1e-4) than the Gaussian run. The
# GradScaler only rescales the loss/gradients; no autocast region is used, so
# the forward pass stays in full precision. The scaler is enabled only on CUDA.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=device.type == 'cuda')

# Collect into lists: np.append copies the whole array on every call, which
# makes the loop quadratic in max_iter.
losses = []
z_values = []
for it in tqdm(range(max_iter)):
    # Fresh batch from the target distribution each iteration
    x_np, _ = make_moons(num_samples, noise=noise)
    x = torch.tensor(x_np).float().to(device, non_blocking=True)

    loss = model.forward_kld(x)
    loss_value = loss.item()

    # Skip the update on a non-finite loss. With the scaler disabled (CPU),
    # scaler.step would otherwise apply NaN/inf gradients to the parameters.
    if np.isfinite(loss_value):
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

    losses.append(loss_value)
    # Track model.q0.Z (presumably the base's normalization estimate -- confirm in larsflow)
    z_values.append(model.q0.Z.item())

    # Clear gradients
    nf.utils.clear_grad(model)

loss_hist = np.array(losses)
Z_hist = np.array(z_values)

plt.figure(figsize=(10, 10))
plt.plot(loss_hist, label='loss')
plt.legend()
plt.show()

plt.figure(figsize=(10, 10))
plt.plot(Z_hist, label='Z')
plt.legend()
plt.show()

In [None]:
def _plot_on_grid(values, colorbar=False, filename=None):
    """Show `values` as a heat map over the evaluation grid (xx, yy).

    values: 2-D array reshaped to xx.shape.
    colorbar: add a colorbar to the figure.
    filename: if given, the figure is also saved to this path before showing.
    """
    plt.figure(figsize=(15, 15))
    plt.pcolormesh(xx, yy, values)
    plt.gca().set_aspect('equal')
    plt.gca().tick_params(axis='both', which='major', labelsize=24)
    plt.xlabel('$x_1$', fontsize=32)
    plt.ylabel('$x_2$', fontsize=32)
    if colorbar:
        plt.colorbar()
    if filename is not None:
        plt.savefig(filename)
    plt.show()


model.eval()
try:
    # Evaluation only: no_grad avoids building an autograd graph over all
    # grid_size**2 grid points.
    with torch.no_grad():
        # Density of the full flow model; NaNs are shown as zero density
        log_prob = model.log_prob(zz).to('cpu').view(*xx.shape)
        prob = torch.exp(log_prob)
        prob[torch.isnan(prob)] = 0
        _plot_on_grid(prob.numpy())  # pass filename='resampled_rnvp.png' to save

        # Density of the resampled base distribution q0 alone
        log_prob = model.q0.log_prob(zz).to('cpu').view(*xx.shape)
        prob = torch.exp(log_prob)
        prob[torch.isnan(prob)] = 0
        _plot_on_grid(prob.numpy())  # pass filename='resampled_base.png' to save

        # Output of the base distribution's network `a` evaluated on the grid
        prob = model.q0.a(zz).to('cpu').view(*xx.shape)
        prob[torch.isnan(prob)] = 0
        _plot_on_grid(prob.numpy(), colorbar=True)  # filename='resampled_base_a.png' to save
finally:
    # Restore training mode even if evaluation or plotting fails
    model.train()