# Fitting 2D Toy Datasets with the Forward KL Divergence (FKLD)

In [None]:
# Import packages
import numpy as np
import torch
import normflows as nf
import larsflow as lf

from matplotlib import pyplot as plt

from tqdm import tqdm

In [None]:
# Select the torch device: CUDA only when it is both enabled and available
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

In [None]:
# Function for model creation

def create_model(p, resampled=True):
    """Build a residual-flow model for the target distribution ``p``.

    The flow stacks 16 pairs of (``nf.flows.Residual``, ``nf.flows.ActNorm``)
    on a 2-dimensional latent space. The torch RNG is seeded with 0 before
    any layer is constructed.

    Args:
        p: Target distribution, handed to ``lf.NormalizingFlow`` as ``p``.
        resampled: If True the base distribution is a
            ``lf.distributions.ResampledGaussian`` whose acceptance network
            is an MLP with a sigmoid output; otherwise it is a
            ``nf.distributions.DiagGaussian``. Both are created with
            ``trainable=False``.

    Returns:
        The ``lf.NormalizingFlow`` model, moved to the module-level ``device``.
    """
    # Architecture settings
    latent_size = 2
    hidden_units = 32
    hidden_layers = 3
    K = 16

    # Seed before any module is constructed so initialisation is repeatable
    torch.manual_seed(0)

    # Layer widths of each residual block's network: input, hidden..., output
    widths = [latent_size] + [hidden_units] * (hidden_layers - 1) + [latent_size]

    flows = []
    for _ in range(K):
        residual_net = nf.nets.LipschitzMLP(widths, init_zeros=True, lipschitz_const=0.9)
        flows.extend([
            nf.flows.Residual(residual_net, reduce_memory=True),
            nf.flows.ActNorm(latent_size),
        ])

    # Base distribution q0
    if not resampled:
        q0 = nf.distributions.DiagGaussian(latent_size, trainable=False)
    else:
        acceptance_net = nf.nets.MLP([latent_size, 256, 256, 1], output_fn="sigmoid")
        # NOTE(review): 100 and 0.1 are passed positionally; see the
        # ResampledGaussian signature for their exact meaning.
        q0 = lf.distributions.ResampledGaussian(latent_size, acceptance_net, 100, 0.1,
                                                trainable=False)

    # Assemble the flow and place it on the selected device
    flow_model = lf.NormalizingFlow(q0=q0, flows=flows, p=p)
    return flow_model.to(device)


# Function to train model

def train(model, max_iter=20000, num_samples=2 ** 10, lr=1e-3, weight_decay=1e-5, 
          q0_weight_decay=1e-4):
    """Train ``model`` by minimising the forward KLD on samples from ``model.p``.

    Each iteration draws ``num_samples`` fresh samples from the target
    ``model.p``, evaluates ``model.forward_kld`` on them and takes one Adam
    step. After every iteration the Lipschitz constraint of the layers is
    re-imposed via ``nf.utils.update_lipschitz`` and gradients are cleared.

    Args:
        model: Flow model exposing ``p.sample``, ``forward_kld`` and
            ``parameters``.
        max_iter: Number of training iterations.
        num_samples: Number of target samples drawn per iteration.
        lr: Adam learning rate.
        weight_decay: Adam weight decay, applied to all model parameters.
        q0_weight_decay: Accepted for interface compatibility but currently
            NOT used; every parameter (including those of ``q0``) gets
            ``weight_decay``.

    Returns:
        None. The model is updated in place and left in training mode.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    model.train()

    for _ in tqdm(range(max_iter)):
        x = model.p.sample(num_samples)
        loss = model.forward_kld(x)

        # A NaN/Inf loss would propagate non-finite gradients into the
        # parameters and Adam's moment estimates, ruining the rest of the
        # run; skip the update for such a batch instead.
        if torch.isfinite(loss).all():
            loss.backward()
            optimizer.step()

        # Make layers Lipschitz continuous
        nf.utils.update_lipschitz(model, 5)

        # Clear gradients
        nf.utils.clear_grad(model)

In [None]:
# Plot function
def plot_results(model, target=True, a=False, save=False, prefix=''):
    """Plot target, model and base densities on a grid and print the KLD.

    All quantities are evaluated on a 300 x 300 grid covering [-3, 3]^2.
    The figures are shown in this order: target density (optional), model
    density, base density ``model.q0`` and, optionally, ``model.q0.a``.
    Finally a grid-based estimate of KL(target || model) is printed.

    Side effects: calls ``nf.utils.update_lipschitz(model, 200)`` and puts
    the model into eval mode (it is not switched back).

    Args:
        model: Flow model exposing ``p.log_prob``, ``log_prob`` and
            ``q0.log_prob``.
        target: Whether to plot the target density ``model.p``.
        a: Whether to also plot ``model.q0.a`` evaluated on the grid;
            requires ``model.q0`` to have a callable attribute ``a``.
        save: Whether to save each figure as a PNG (dpi 300).
        prefix: Path prefix of the saved files; the suffixes ``target.png``,
            ``model.png``, ``base.png`` and ``a.png`` are appended. A missing
            directory part is created when ``save`` is True.

    Returns:
        None.
    """
    import os

    # Prepare z grid for evaluation
    grid_size = 300
    xx, yy = torch.meshgrid(torch.linspace(-3, 3, grid_size), torch.linspace(-3, 3, grid_size))
    zz = torch.cat([xx.unsqueeze(2), yy.unsqueeze(2)], 2).view(-1, 2)
    zz = zz.to(device)

    if save:
        # savefig does not create directories; make sure the target exists
        out_dir = os.path.dirname(prefix)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)

    # Target density (always needed for the KLD, plotted only on request)
    with torch.no_grad():
        prob_target = _grid_density(model.p.log_prob(zz), xx.shape)
    if target:
        _show_heatmap(xx, yy, prob_target, prefix + 'target.png' if save else None)

    # Model density, evaluated in chunks of 1024 grid points
    nf.utils.update_lipschitz(model, 200)
    model.eval()
    with torch.no_grad():
        chunks = [model.log_prob(zz_).to('cpu') for zz_ in torch.split(zz, 1024)]
        prob_model = _grid_density(torch.cat(chunks), xx.shape)
    _show_heatmap(xx, yy, prob_model, prefix + 'model.png' if save else None)

    # Base distribution density; no gradients are needed for plotting
    with torch.no_grad():
        prob_base = _grid_density(model.q0.log_prob(zz), xx.shape)
    _show_heatmap(xx, yy, prob_base, prefix + 'base.png' if save else None)

    if a:
        # Output of q0.a on the grid (plotted directly, not exponentiated)
        with torch.no_grad():
            acc = model.q0.a(zz).to('cpu').view(*xx.shape)
        acc[torch.isnan(acc)] = 0
        _show_heatmap(xx, yy, acc.detach().numpy(), prefix + 'a.png' if save else None)

    # Compute KLD as a sum over grid cells; eps guards against log(0) and
    # division by zero, and 6 ** 2 / grid_size ** 2 is the area per cell
    eps = 1e-10
    kld = np.sum(prob_target * np.log((prob_target + eps) / (prob_model + eps)) * 6 ** 2 / grid_size ** 2)
    print(kld)


def _grid_density(log_prob, shape):
    """Turn flat log-probabilities into a NumPy density grid with NaNs set to 0."""
    prob = torch.exp(log_prob.to('cpu').view(*shape))
    prob[torch.isnan(prob)] = 0
    return prob.detach().numpy()


def _show_heatmap(xx, yy, values, save_path=None):
    """Show ``values`` as an axis-free, equal-aspect heatmap; save it if a path is given."""
    plt.figure(figsize=(15, 15))
    plt.pcolormesh(xx, yy, values)
    plt.gca().set_aspect('equal', 'box')
    plt.axis('off')
    if save_path is not None:
        plt.savefig(save_path, dpi=300)
    plt.show()

In [None]:
# Train models: for every target, fit one flow with a Gaussian base and one
# with a resampled base, then plot and save the results of each
p = [
    nf.distributions.TwoMoons(),
    nf.distributions.CircularGaussianMixture(),
    nf.distributions.RingMixture(),
]
name = ['moons', 'circle', 'rings']

for i, target_dist in enumerate(p):
    # Common path prefix for all figures of this target
    result_prefix = 'results/2d_toy_experiments/fkld/resflow/' + name[i]

    # Train model with Gaussian base distribution
    model = create_model(target_dist, False)
    train(model)
    # Plot and save results
    plot_results(model, save=True, prefix=result_prefix + '_gauss_')

    # Train model with resampled base distribution
    model = create_model(target_dist, True)
    train(model, weight_decay=1e-3)
    # Plot and save results, including the q0.a plot
    plot_results(model, save=True, a=True, prefix=result_prefix + '_resampled_')