In [1]:
# ---- Data -------------------------------------------------------------------
given_data_size = 5_000           # observations drawn per trial

# ---- Outer (mirror) loop ----------------------------------------------------
total_trial_num = 10              # independent trials
n_max_mirror_iterations = 50      # mirror steps per trial
eta = 5                           # base step size; the KL term is weighted by 1/eta
outer_eta_factor = 1              # eta is multiplied by this factor ** mirror step
outer_lr_factor = 0.912           # NOTE(review): not read by the training cell shown here

# ---- Inner optimisation -----------------------------------------------------
optimizer_lr = 1e-4               # base Adam learning rate for the inner loop
n_max_iterations = 1_000          # cap on inner steps per mirror step
patient_max = 200                 # inner steps without a new minimum gradient norm before stopping
stopping_norm = 1e-4              # stop the inner loop once the gradient norm drops below this
num_samples = 2401                # number of theta particles generated per step

# ---- Pretraining / bookkeeping ----------------------------------------------
pretrain_factor = 4               # scales the identity covariance of the pretraining target
set_seed = 114530                 # base RNG seed; trials start from 114530
verbose = True                    # show progress bars and periodic loss prints


In [2]:
!pip install normflows

Collecting normflows
  Downloading normflows-1.7.2.tar.gz (64 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/64.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.8/64.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: normflows
  Building wheel for normflows (setup.py) ... [?25l[?25hdone
  Created wheel for normflows: filename=normflows-1.7.2-py2.py3-none-any.whl size=86917 sha256=9b13e842b0f62bcfeca812c5d26eb624420899237cb5a1e880d1d9d5b4709caf
  Stored in directory: /root/.cache/pip/wheels/8a/a4/89/3e09f53a561355c45eccfebeffc07a0e34d36a3f41e3ef68a3
Successfully built normflows
Installing collected packages: normflows
Successfully installed normflows-1.7.2


In [3]:
# Import required packages
import torch
import numpy as np
import normflows as nf
import os
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import torch.nn.functional as F
import torch.distributions as TD
import pandas as pd
import seaborn as sb
import torch.nn as nn
import shutil
import gc
import copy
import time

In [4]:
# Seed PyTorch's global RNG so notebook-level random draws are repeatable.
torch.manual_seed(114514)

# Use the GPU only when one is present AND `enable_cuda` allows it; otherwise
# everything runs on the CPU.
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

In [5]:
def find_L_n_faster_LS(x, theta_list):
  """Negative mean log-likelihood of `x` under an equal-weight Gaussian mixture.

  Every row of `theta_list` defines one mixture component: columns 0:2 are the
  component mean and columns 2:4 are squared to form the diagonal of its 2x2
  covariance matrix.

  Args:
    x: data tensor of shape (n_data, 2).
    theta_list: parameter tensor of shape (n_components, 4); moved to `x.device`.

  Returns:
    0-dim tensor holding  -mean_i log( mean_j N(x_i; mu_j, diag(var_j)) ).
  """
  num_samples = theta_list.shape[0]

  # Evaluate on the data's device so the function needs no notebook-level global.
  theta_list = theta_list.to(x.device)
  mu = theta_list[:, 0:2]
  var = theta_list[:, 2:4] ** 2

  components = TD.MultivariateNormal(
            mu,
            covariance_matrix=torch.diag_embed(var),
            validate_args=False)

  # x: (n_data, 1, 2) broadcasts against the (n_components,) batch of Gaussians,
  # giving a (n_data, n_components) matrix of log densities.
  log_prob_mx = components.log_prob(x.unsqueeze(1))

  # log(mean_j p_ij) computed with logsumexp: exponentiating first underflows to 0
  # for points far from every component, which turns the loss into inf.
  log_num_samples = torch.log(torch.tensor(
      float(num_samples), dtype=log_prob_mx.dtype, device=log_prob_mx.device))
  log_mixture = torch.logsumexp(log_prob_mx, dim=1) - log_num_samples
  return -torch.mean(log_mixture)

def find_first_variation_var_faster_LS(x, theta_list):
  """Variance, across components, of the first variation of the mixture loss.

  Every row of `theta_list` defines one Gaussian component: columns 0:2 are the
  mean and columns 2:4 are squared to form the diagonal of its 2x2 covariance.
  For component j the first variation is
      -mean_i [ p_ij / mean_k p_ik ],
  where p_ij is the density of data point i under component j.

  Args:
    x: data tensor of shape (n_data, 2).
    theta_list: parameter tensor of shape (n_components, 4); moved to `x.device`.

  Returns:
    0-dim tensor: `torch.var` (unbiased) of the per-component first variation.
  """
  num_samples = theta_list.shape[0]

  # Evaluate on the data's device so the function needs no notebook-level global.
  theta_list = theta_list.to(x.device)
  mu = theta_list[:, 0:2]
  var = theta_list[:, 2:4] ** 2

  components = TD.MultivariateNormal(
            mu,
            covariance_matrix=torch.diag_embed(var),
            validate_args=False)

  # (n_data, n_components) matrix of log densities via broadcasting.
  log_prob_mx = components.log_prob(x.unsqueeze(1))

  # Form p_ij / mean_k p_ik in log space. Dividing raw densities yields 0/0 = NaN
  # whenever a data point underflows under every component.
  log_num_samples = torch.log(torch.tensor(
      float(num_samples), dtype=log_prob_mx.dtype, device=log_prob_mx.device))
  log_row_mean = torch.logsumexp(log_prob_mx, dim=1, keepdim=True) - log_num_samples
  density_ratio = torch.exp(log_prob_mx - log_row_mean)

  first_variation = -torch.mean(density_ratio, dim=0)
  first_variation_var = torch.var(first_variation)
  return first_variation_var

In [None]:
def _build_flow_layers(num_layers):
  """Build the layer list for the 4-D coupling flow.

  Args:
    num_layers: number of (AffineCouplingBlock, Permute) pairs to create.

  Returns:
    List of 2 * num_layers flow modules, alternating coupling block and swap.
  """
  layers = []
  for _ in range(num_layers):
    # Neural network with two hidden layers having 64 units each.
    # Last layer is initialized by zeros making training more stable.
    param_map = nf.nets.MLP([2, 64, 64, 4], init_zeros=True)
    layers.append(nf.flows.AffineCouplingBlock(param_map))
    # Swap dimensions so the next coupling block transforms the other half.
    layers.append(nf.flows.Permute(4, mode='swap'))
  return layers


# Result tables: one row per trial, one column per outer (mirror) step.
Ln_rho_k_list = np.zeros((total_trial_num, n_max_mirror_iterations))
first_variation_k_2d = np.zeros((total_trial_num, n_max_mirror_iterations))
times_2d = np.zeros((total_trial_num, n_max_mirror_iterations))

dim1_normal = TD.MultivariateNormal(
    torch.zeros(1).to(device), 1 * torch.eye(1).to(device))

std_normal2 = TD.MultivariateNormal(
    torch.zeros(2).to(device),
    torch.eye(2).to(device),
    validate_args=False)
sampled_mu = nf.distributions.TwoMoons().to(device)

for trail_num in range(total_trial_num):

  torch.manual_seed(set_seed + trail_num)
  mirror_loss_hist = np.array([])
  first_variation_k_1d = np.array([])
  times = np.array([])

  # Synthetic data: each observation is drawn from its own Gaussian whose mean is
  # a TwoMoons sample and whose diagonal variances are squared standard normals.
  mean = sampled_mu.sample(given_data_size)
  var = std_normal2.sample((given_data_size,)) ** 2

  normal_temp = TD.MultivariateNormal(
      mean.to(device),
      torch.diag_embed(var.to(device)),
      validate_args=False)

  given_data = normal_temp.sample().detach().to(device)

  # Uniform weights used to draw a minibatch of observations at every inner step.
  # The minibatch is capped at the data-set size so small data sets do not make
  # multinomial(replacement=False) fail.
  unif = torch.ones(given_data.shape[0]).to(device)
  subsample_size = min(500, given_data.shape[0])

  # Set up model
  # Pretrain the flow towards a zero-mean Gaussian with covariance
  # pretrain_factor * I in 4 dimensions.
  target_pretrain = TD.MultivariateNormal(
      torch.zeros(4).to(device), pretrain_factor * torch.eye(4).to(device))
  # 4-D diagonal Gaussian base distribution with fixed parameters.
  base = nf.distributions.DiagGaussian(4, trainable=False)
  num_layers = 30
  flows = _build_flow_layers(num_layers)

  # Construct flow model
  model = nf.NormalizingFlow(base, flows).to(device)
  optimizer1 = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

  for it in tqdm(range(1000)):
    optimizer1.zero_grad()

    # Get training samples
    x = target_pretrain.sample((num_samples,)).to(device)

    # Compute loss
    loss = model.forward_kld(x)

    # Do backprop and optimizer step, skipping non-finite losses.
    if torch.isfinite(loss):
      loss.backward()
      optimizer1.step()

    if it % 100 == 99:
      clear_output(wait=True)
      print('Loss:', loss.item())

  for mirror_itr in range(n_max_mirror_iterations):
    # model2 is a fresh flow initialised from the previous iterate `model`.
    flows = _build_flow_layers(num_layers)
    model2 = nf.NormalizingFlow(base, flows).to(device)
    model2.load_state_dict(model.state_dict())

    # Learning rate decays as a / (mirror_itr + b) across outer steps.
    a = 250/9
    b = 250/9
    optimizer_lr_input = optimizer_lr * (a)/(mirror_itr + b)
    optimizer = torch.optim.Adam(model2.parameters(), lr=optimizer_lr_input, weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=1)
    current_patient = 0
    epcoh_min = 200
    # Start from an infinite gradient norm so the early-stopping test below is
    # well defined (and not stale from the previous outer step) even when the
    # first inner step is skipped because of a non-finite loss.
    norm = torch.tensor(float('inf'), device=device)

    model2.train()
    # Only synchronise when actually running on CUDA; the call raises on
    # CPU-only machines.
    if device.type == 'cuda':
      torch.cuda.synchronize()
    start_epoch = time.time()

    input_eta = (eta*(outer_eta_factor**mirror_itr))
    for it in tqdm(range(n_max_iterations), disable=not verbose):
      optimizer.zero_grad()
      # Re-seeding with the same value makes z identical at every inner step.
      torch.manual_seed(set_seed)
      z = base.sample(num_samples)

      # The minibatch indices depend only on the inner step index.
      torch.manual_seed(set_seed + it)
      idx = unif.multinomial(subsample_size, replacement=False).to(device)
      given_data_sub = given_data[idx].to(device)

      log_prob_rho_0 = base.log_prob(z)
      sampled_theta, log_det_model = model2.forward_and_log_det(z)
      log_prob_model = log_prob_rho_0 - log_det_model
      log_prob_prev = model.log_prob(sampled_theta)

      L_n_loss = find_L_n_faster_LS(x=given_data_sub.to(torch.double), theta_list=sampled_theta.to(torch.double))
      # Monte Carlo estimate of KL(model2 || model) from samples of model2.
      kld_loss = torch.mean(log_prob_model).to(torch.double) - torch.mean(log_prob_prev).to(torch.double)
      # Keep the KL term inside [0, 5]; values outside (NaN counts as below 0)
      # are replaced by a constant, which also cuts their gradient. The constant
      # keeps kld_loss's dtype, device and 0-dim shape so `loss` stays a double
      # scalar instead of being promoted to a 1-element float32 tensor.
      kld_value = kld_loss.item()
      if not kld_value >= 0:
        kld_loss = kld_loss.new_zeros(())
      elif kld_value > 5:
        kld_loss = kld_loss.new_full((), 5.0)
      loss = L_n_loss + (1/input_eta)*kld_loss

      # Do backprop and optimizer step, skipping non-finite losses.
      if torch.isfinite(loss):
        loss.backward()
        grads = [param.grad.detach().flatten()
                 for param in model2.parameters()
                 if param.grad is not None]
        norm = torch.cat(grads).norm()

        optimizer.step()
        optimizer.zero_grad()
        scheduler.step()

      if verbose:
        if it % 100 == 99:
          print('trail:',trail_num,'m step:',mirror_itr,'Loss:', loss.item(), ' L_n:', L_n_loss.item(), ' kld:',kld_loss.item())

      # Early stopping on the gradient norm: count steps without a new minimum
      # and stop once patience runs out or the norm is small enough.
      grad_norm = norm.item()
      if grad_norm > epcoh_min:
        current_patient = current_patient + 1

      if grad_norm < epcoh_min:
        epcoh_min = grad_norm
        current_patient = 0

      if (current_patient >= patient_max) or (grad_norm < stopping_norm):
        break

    # After finishing the inner loop
    if device.type == 'cuda':
      torch.cuda.synchronize()
    end_epoch = time.time()
    elapsed = end_epoch - start_epoch
    times = np.append(times, elapsed)

    # The optimised model2 becomes the reference iterate for the next outer step.
    model.load_state_dict(model2.state_dict())

    # Evaluate the new iterate on the full data set with the fixed particles z.
    torch.manual_seed(set_seed)
    with torch.no_grad():
      z = base.sample(num_samples)
      log_prob_rho_0 = base.log_prob(z)
      generated1, log_porb = model.forward_and_log_det(z)

    first_variation_k_temp = find_first_variation_var_faster_LS(x=given_data, theta_list=generated1)
    first_variation_k_1d = np.append(first_variation_k_1d, first_variation_k_temp.item())
    L_n_loss_temp = find_L_n_faster_LS(given_data, generated1)
    mirror_loss_hist = np.append(mirror_loss_hist, L_n_loss_temp.item())

  # Store this trial's histories in its own row only. Slicing with
  # `[trail_num:]` would also overwrite every later row.
  Ln_rho_k_list[trail_num, :] = mirror_loss_hist
  first_variation_k_2d[trail_num, :] = first_variation_k_1d
  times_2d[trail_num, :] = times