In [1]:
import math

# ---- Problem size -----------------------------------------------------------
given_data_size = 5_000            # number of observed data points drawn per trial
num_samples = 3_000                # number of theta generated
total_trial_num = 5                # num trial

# ---- Mirror-descent (outer) loop --------------------------------------------
eta = 5                            # initial outer step size
n_max_mirror_iterations = 30       # num mirror steps
outer_lr_final_factor = 0.1        # inner lr is scaled by this factor after the last mirror step
outer_eta_final_factor = 1.0       # eta is scaled by this factor after the last mirror step

# ---- Inner optimisation -------------------------------------------------------
optimizer_lr = 8e-5                # initial inner learning rate
n_max_iterations = 1_000           # max inner steps
patient_max = n_max_iterations     # max patient
stopping_norm = 1e-4
verbose = True

# ---- Reproducibility / pre-training ------------------------------------------
set_seed = 114530                  # start from 114530
pretrain_factor = 4                # covariance scale of the Gaussian pre-training target

# ---- Flow architecture --------------------------------------------------------
max_flow_length = 40               # distil once the flow reaches this many coupling layers
min_flow_length = 20               # number of coupling layers in the distilled flow
flow_inc_length = 4                # coupling layers added per mirror step
flow_width = 512                   # hidden width of each coupling MLP

# Per-mirror-step geometric decay factors: factor ** n_max_mirror_iterations
# equals the corresponding *_final_factor.
outer_lr_factor = math.e ** (math.log(outer_lr_final_factor) / n_max_mirror_iterations)
outer_eta_factor = math.e ** (math.log(outer_eta_final_factor) / n_max_mirror_iterations)

In [2]:
!pip install normflows

Collecting normflows
  Downloading normflows-1.7.2.tar.gz (64 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/64.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.8/64.8 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: normflows
  Building wheel for normflows (setup.py) ... [?25l[?25hdone
  Created wheel for normflows: filename=normflows-1.7.2-py2.py3-none-any.whl size=86917 sha256=f334102a1c83c6540b77ada5da1c19fbaf68e4734299e93aea2fba7b8100dc7e
  Stored in directory: /root/.cache/pip/wheels/8a/a4/89/3e09f53a561355c45eccfebeffc07a0e34d36a3f41e3ef68a3
Successfully built normflows
Installing collected packages: normflows
Successfully installed normflows-1.7.2


In [3]:
# Import required packages
import torch
import numpy as np
import normflows as nf
import os
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import torch.nn.functional as F
import torch.distributions as TD
import pandas as pd
import seaborn as sb
import torch.nn as nn
import shutil
import gc
import copy
import time

In [4]:
# Seed PyTorch's global RNG before anything is sampled.
torch.manual_seed(114514)

# Run on the GPU when one is available and CUDA has not been switched off here.
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# Gaussian mixture built with positional arguments (1, 2), zero location and unit scale.
contaminated_data_m = nf.distributions.base.GaussianMixture(
    1, 2,
    loc=[[0.0, 0.0]],
    scale=[[1.0, 1.0]],
).to(device)

# Two-moons distribution; the training cell samples from it.
sampled_theta_m = nf.distributions.TwoMoons().to(device)

# 2x2 identity matrix in double precision.
cov_mx = torch.eye(2, dtype=torch.double).to(device)

In [8]:
def find_L_n_faster(x, theta_list):
  """Average negative log mixture likelihood of `x` under the samples `theta_list`.

  For every row x_i of `x` the standard-normal density phi(theta_j - x_i) is
  averaged over all rows theta_j of `theta_list`; the function returns

      mean_i [ log(m) - logsumexp_j log phi(theta_j - x_i) ]
        = mean_i [ -log( (1/m) * sum_j phi(theta_j - x_i) ) ].

  Args:
    x: tensor of shape (n, d) holding the observations.
    theta_list: tensor of shape (m, d) holding the sampled parameters.
      Gradients flow back through this argument.

  Returns:
    0-dim float64 tensor on `theta_list`'s device.

  Raises:
    ValueError: if `x` or `theta_list` has no rows.
  """
  num_samples = theta_list.shape[0]
  given_data_size = x.shape[0]
  if num_samples == 0 or given_data_size == 0:
    raise ValueError("find_L_n_faster needs at least one data point and one theta sample")

  # Evaluate in float64 on theta_list's device, so the result does not depend
  # on a global `device` or on matching input dtypes.
  x64 = x.to(device=theta_list.device, dtype=torch.double)
  theta64 = theta_list.to(torch.double)

  # Broadcast to all pairwise differences theta_j - x_i: shape (n, m, d).
  exp_mat = theta64.unsqueeze(0) - x64.unsqueeze(1)
  dim = exp_mat.shape[-1]

  # Closed-form log-density of N(0, I_d); identical to what MultivariateNormal
  # with identity covariance computes, without building m identity matrices.
  logphi = -0.5 * (dim * math.log(2.0 * math.pi) + exp_mat.pow(2).sum(dim=-1))

  # math.log keeps log(m) in full double precision.
  return torch.mean(math.log(num_samples) - torch.logsumexp(logphi, dim=1))


def find_first_variation_var_faster(x, theta_list):
  """Variance across theta samples of the (negated) first-variation estimate.

  For each observation x_i a softmax over j of log phi(theta_j - x_i) (phi being
  the standard-normal density) gives weights w_ij. The per-sample value is
  -(1/n) * sum_i m * w_ij, and the function returns `torch.var` of these m
  values.

  Args:
    x: tensor of shape (n, d) holding the observations.
    theta_list: tensor of shape (m, d) holding the sampled parameters.

  Returns:
    0-dim tensor on `theta_list`'s device (dtype follows the promoted input
    dtype). `torch.var` yields NaN when m == 1.

  Raises:
    ValueError: if `x` or `theta_list` has no rows.
  """
  num_samples = theta_list.shape[0]
  given_data_size = x.shape[0]
  if num_samples == 0 or given_data_size == 0:
    raise ValueError("find_first_variation_var_faster needs at least one data point and one theta sample")

  # Pairwise differences theta_j - x_i, shape (n, m, d), on theta_list's device.
  diff = theta_list.unsqueeze(0) - x.to(theta_list.device).unsqueeze(1)

  # Standard-normal log-density up to an additive constant; the constant
  # cancels inside the softmax, so it is omitted.
  log_prob_mx = -0.5 * diff.pow(2).sum(dim=-1)
  weights = torch.softmax(log_prob_mx, dim=1)

  col_mean = -torch.mean(num_samples * weights, dim=0)
  return torch.var(col_mean)

def find_first_variation_inner_var_faster(x, theta_list, tau, model2, model1):
  """Variance across theta samples of the inner-problem first variation.

  Combines the data term used by `find_first_variation_var_faster` with a
  log-density-ratio term between two models:

      value_j = -(1/n) * sum_i m * w_ij
                + (1/tau) * (model2.log_prob(theta_j) - model1.log_prob(theta_j))

  where w_ij is the softmax over j of log phi(theta_j - x_i) and phi is the
  standard-normal density. Returns `torch.var` of the m values.

  Args:
    x: tensor of shape (n, d) holding the observations.
    theta_list: tensor of shape (m, d) holding the sampled parameters.
    tau: scalar that divides the log-density-ratio term.
    model2: object with a `log_prob(theta_list)` method returning m values.
    model1: object with a `log_prob(theta_list)` method returning m values.

  Returns:
    0-dim tensor on `theta_list`'s device.

  Raises:
    ValueError: if `x` or `theta_list` has no rows.
  """
  num_samples = theta_list.shape[0]
  given_data_size = x.shape[0]
  if num_samples == 0 or given_data_size == 0:
    raise ValueError("find_first_variation_inner_var_faster needs at least one data point and one theta sample")

  # Pairwise differences theta_j - x_i, shape (n, m, d), on theta_list's device.
  diff = theta_list.unsqueeze(0) - x.to(theta_list.device).unsqueeze(1)

  # Standard-normal log-density up to an additive constant (cancels in softmax).
  log_prob_mx = -0.5 * diff.pow(2).sum(dim=-1)
  weights = torch.softmax(log_prob_mx, dim=1)
  first_variation = torch.mean(num_samples * weights, dim=0)

  sec_term = (1 / tau) * (model2.log_prob(theta_list) - model1.log_prob(theta_list))
  residual = -1 * first_variation + sec_term
  return torch.var(residual)

In [None]:
# Two-moons distribution that the per-trial data means are drawn from.
sampled_theta_m = nf.distributions.TwoMoons().to(device)

# One row per trial, one column per mirror step.
Ln_rho_k_list, first_variation_k_2d, times_2d = (
    np.zeros((total_trial_num, n_max_mirror_iterations)) for _ in range(3)
)

# One-dimensional standard normal.
dim1_normal = TD.MultivariateNormal(
    loc=torch.zeros(1).to(device),
    covariance_matrix=torch.eye(1).to(device),
)

def _make_coupling_layers(num_blocks):
  """Build `num_blocks` new coupling layers, each followed by a dimension swap.

  Returns a flat list [AffineCouplingBlock, Permute, ...] of length 2 * num_blocks.
  """
  layers = []
  for _ in range(num_blocks):
    # MLP with two hidden layers of `flow_width` units each; the last layer is
    # initialized by zeros, making training more stable.
    param_map = nf.nets.MLP([1, flow_width, flow_width, 2], init_zeros=True)
    layers.append(nf.flows.AffineCouplingBlock(param_map))
    layers.append(nf.flows.Permute(2, mode='swap'))
  return layers


def _freeze_flow_parameters(flow_model):
  """Disable gradients for every parameter whose name does not contain 'q0'."""
  for name, param in flow_model.named_parameters():
    if 'q0' not in name:
      param.requires_grad = False


for trail_num in range(total_trial_num):

  torch.manual_seed(set_seed + trail_num)
  mirror_loss_hist = np.array([])
  first_variation_k_1d = np.array([])
  times = np.array([])

  # Draw the latent means from the two-moons distribution, then add
  # unit-covariance Gaussian noise to obtain this trial's observations.
  sampled_mean = sampled_theta_m.sample(given_data_size).to(torch.float32).to(device)
  normal_temp = TD.MultivariateNormal(
        sampled_mean.to(device),
        torch.eye(2).to(device).unsqueeze(0).repeat(given_data_size, 1, 1).to(device),
        validate_args=False)

  given_data = normal_temp.sample().detach().to(device)

  # ---- Set up and pretrain the initial flow on a wide Gaussian ----
  target_pretrain = TD.MultivariateNormal(
      torch.zeros(2).to(device), pretrain_factor * torch.eye(2).to(device))
  # 2D Gaussian base distribution with fixed (non-trainable) parameters
  base = nf.distributions.DiagGaussian(2, trainable=False)
  num_layers = flow_inc_length
  flows = _make_coupling_layers(num_layers)

  # Construct flow model
  model = nf.NormalizingFlow(base, flows).to(device)
  optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)

  for it in tqdm(range(200)):
    optimizer.zero_grad()
    x = target_pretrain.sample((num_samples,)).to(device)
    loss = model.forward_kld(x)

    # Skip the update when the loss is NaN or infinite
    if torch.isfinite(loss):
      loss.backward()
      optimizer.step()

    if it % 100 == 99:
      clear_output(wait=True)
      print('Loss:', loss.item())

  # The pretrained flow becomes the fixed reference for the first mirror step.
  _freeze_flow_parameters(model)

  for mirror_itr in range(n_max_mirror_iterations):
    # Copy the previous (frozen) flows and append new trainable ones.
    flows = list(copy.deepcopy(model.flows))
    flows.extend(_make_coupling_layers(num_layers))

    model2 = nf.NormalizingFlow(base, flows).to(device)

    optimizer = torch.optim.Adam(model2.parameters(), lr=optimizer_lr * (outer_lr_factor**mirror_itr), weight_decay=1e-5)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.9)

    # Defined up front so the progress print cannot hit an undefined name when
    # no finite-loss step has happened yet.
    norm = torch.tensor(float('nan'))

    model2.train()
    # Only synchronize when actually running on CUDA; the unguarded call fails
    # on the CPU fallback selected when no GPU is available.
    if device.type == 'cuda':
      torch.cuda.synchronize()
    start_epoch = time.time()

    input_eta = (eta*(outer_eta_factor**mirror_itr))
    for it in tqdm(range(n_max_iterations), disable = not verbose):
      optimizer.zero_grad()
      # NOTE(review): reseeding here makes `z` identical on every inner step
      # (and across trials) -- presumably intentional, i.e. a fixed set of
      # base samples; confirm.
      torch.manual_seed(set_seed)
      z = base.sample(num_samples)
      log_prob_rho_0 = base.log_prob(z)
      sampled_theta, log_det_model = model2.forward_and_log_det(z)
      log_prob_model = log_prob_rho_0 - log_det_model
      log_prob_prev = model.log_prob(sampled_theta)

      L_n_loss = find_L_n_faster(x=given_data.to(torch.double), theta_list=sampled_theta.to(torch.double))
      kld_loss = torch.mean(log_prob_model).to(torch.double) - torch.mean(log_prob_prev).to(torch.double)
      # Restrict the KL estimate to [0, 500]. Kept as comparisons on purpose:
      # a NaN estimate fails `>= 0` and is replaced by 0 as well.
      kld_loss = kld_loss if kld_loss.item() >= 0 else torch.tensor([0.0]).to(device)
      kld_loss = kld_loss if kld_loss.item() <= 500 else torch.tensor([500.0]).to(device)
      loss = L_n_loss  + (1/input_eta)*kld_loss

      # Do backprop and optimizer step (skipped when the loss is NaN/inf)
      if torch.isfinite(loss):
        loss.backward()
        grads = [param.grad.detach().flatten()
            for param in model2.parameters()
            if param.grad is not None]
        norm = torch.cat(grads).norm()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

      # Diagnostics are only printed, so compute them only on logging steps
      # and without autograd bookkeeping. This also keeps their cost out of
      # most of the timed region.
      if verbose and it % 100 == 99:
        torch.manual_seed(set_seed)
        with torch.no_grad():
          theta_temp, _ = model2.sample(num_samples)
          FVV2 = find_first_variation_var_faster(x=given_data, theta_list = theta_temp)
          FVV1 = find_first_variation_inner_var_faster(x=given_data, theta_list=theta_temp, tau=input_eta, model2=model2, model1=model)
        print('trail:',trail_num,'m step:',mirror_itr,'Loss:', loss.item(), ' L_n:', L_n_loss.item(), ' kld:',kld_loss.item())
        print('norm:', norm.item(), 'FVV1:', FVV1.item(), 'FVV2:', FVV2.item())

    # After finishing the inner loop
    if device.type == 'cuda':
      torch.cuda.synchronize()
    end_epoch = time.time()
    elapsed = end_epoch - start_epoch
    times = np.append(times, elapsed)

    # The freshly trained flow becomes the frozen reference for the next step.
    model2_flows_copy = copy.deepcopy(model2.flows)

    print("the current flow length is", str(len(model2_flows_copy)/2))
    model = nf.NormalizingFlow(base, model2_flows_copy).to(device)
    model.load_state_dict(model2.state_dict())
    _freeze_flow_parameters(model)

    # Evaluate the updated flow on the same seeded base samples.
    torch.manual_seed(set_seed)
    with torch.no_grad():
      z = base.sample(num_samples)
      generated1, _ = model.forward_and_log_det(z)

      first_variation_k_temp = find_first_variation_var_faster(x=given_data, theta_list = generated1)
      L_n_loss_temp = find_L_n_faster(given_data, generated1)

    first_variation_k_1d = np.append(first_variation_k_1d, first_variation_k_temp.detach().cpu().numpy())
    mirror_loss_hist = np.append(mirror_loss_hist, L_n_loss_temp.detach().cpu().numpy())

    # Student-Teacher Data Distilling: teacher: model; student: model3.
    # `model.flows` holds two modules (coupling + permute) per layer.
    if len(model.flows) >= 2*max_flow_length :
      flows = _make_coupling_layers(min_flow_length)
      model3 = nf.NormalizingFlow(base, flows).to(device)
      optimizer_m3 = torch.optim.Adam(model3.parameters(), lr=1e-5, weight_decay=1e-5)

      for it in tqdm(range(3000)):
        optimizer_m3.zero_grad()

        # Match the student's output to the teacher's on the same base samples.
        z = base.sample(num_samples)
        sampled_theta_short, _ = model3.forward_and_log_det(z)
        with torch.no_grad():
          sampled_theta_long, _ = model.forward_and_log_det(z)
        loss = torch.mean((sampled_theta_short - sampled_theta_long)**2)

        # Do backprop and optimizer step (skipped when the loss is NaN/inf)
        if torch.isfinite(loss):
          loss.backward()
          optimizer_m3.step()

        if it % 100 == 99:
          print('Student Teacher Loss:', loss.item())

        if loss.item() <= 0.0001:
          break

      # Replace the long teacher flow by the distilled student.
      model3_flows_copy = copy.deepcopy(model3.flows)
      model = nf.NormalizingFlow(base, model3_flows_copy).to(device)
      model.load_state_dict(model3.state_dict())
      _freeze_flow_parameters(model)

  # Store this trial's curves in its own row. The original `[trail_num:]`
  # slice also overwrote every later row.
  Ln_rho_k_list[trail_num, :] = mirror_loss_hist
  first_variation_k_2d[trail_num, :] = first_variation_k_1d
  times_2d[trail_num, :] = times