In [1]:
# Experiment configuration read by the cells below.
given_data_size = 5_000   # number of observations drawn in each trial
total_trial_num = 10      # number of independent trials
eta = 1.0                 # step size multiplying the weight update
t_max = 50                # weight-update iterations per trial
set_seed = 114_530        # base RNG seed; each trial adds its own index to it

In [2]:
!pip install normflows

Collecting normflows
  Downloading normflows-1.7.2.tar.gz (64 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 64.8/64.8 kB 1.9 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Building wheels for collected packages: normflows
  Building wheel for normflows (setup.py) ... done
  Created wheel for normflows: filename=normflows-1.7.2-py2.py3-none-any.whl size=86917 sha256=5582a56299cff4aa6325dd449eacea68fda8a1da9d6040694921b36e5ea86ff7
  Stored in directory: /root/.cache/pip/wheels/8a/a4/89/3e09f53a561355c45eccfebeffc07a0e34d36a3f41e3ef68a3
Successfully built normflows
Installing collected packages: normflows
Successfully installed normflows-1.7.2


In [3]:
# Import required packages
import torch
import numpy as np
import normflows as nf
import os
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import torch.nn.functional as F
import torch.distributions as TD
import pandas as pd
import seaborn as sb
import torch.nn as nn
import shutil
import gc

In [4]:
# Pick the compute device: CUDA only when it is available and enabled here, otherwise CPU.
enable_cuda = True
if torch.cuda.is_available() and enable_cuda:
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# Two-moons distribution from normflows, moved to the selected device.
sampled_theta_m = nf.distributions.TwoMoons().to(device)

# 2x2 identity covariance in double precision; read as a global by
# get_p1_bottom_faster and find_L_n_ver2.
cov_mx = torch.eye(2, dtype=torch.double).to(device)

In [5]:
def get_p1_bottom_faster(wgt, x, theta_list):
  """Weighted mixture density of each observation under shared-covariance Gaussians.

  Every row of ``theta_list`` is a component location and all components use
  the module-level covariance ``cov_mx``.  For observation ``x[n]`` the result
  is ``nanmean_s(wgt[s] * N(theta_list[s] - x[n]; 0, cov_mx)) * S`` where
  ``S = theta_list.shape[0]``; without NaN terms this equals the weighted sum
  over components.

  Args:
    wgt: tensor of shape (S,), one weight per component.
    x: tensor of shape (N, 2) holding the observations.
    theta_list: tensor of shape (S, 2) holding the component locations.

  Returns:
    Tensor of shape (N,), one mixture density value per observation.
  """
  num_samples = theta_list.shape[0]

  # All components share cov_mx, so a single zero-mean Gaussian evaluated on
  # the pairwise differences is enough.  loc takes cov_mx's dtype and device
  # so the distribution is not built from mixed dtypes.
  std_normal = TD.MultivariateNormal(
            torch.zeros(2, dtype=cov_mx.dtype, device=cov_mx.device),
            cov_mx,
            validate_args=False)

  # Broadcasting gives theta_list[s] - x[n] for every pair: shape (N, S, 2).
  diff = theta_list.unsqueeze(0) - x.unsqueeze(1)
  prob_mx = torch.exp(std_normal.log_prob(diff))  # (N, S)

  # The mean runs over the S components, so it is rescaled by S (not by the
  # number of observations) to recover the weighted sum; this matches the
  # `* num` scaling used in find_L_n_ver2.
  return torch.nanmean(prob_mx * wgt, 1) * num_samples

def get_p1_bottom_faster2KW(wgt, x, theta_list):
  """Weighted mixture density of each observation under diagonal Gaussians.

  Columns 0-1 of ``theta_list`` are the component means and columns 2-3 the
  diagonal variances.  For observation ``x[n]`` the result is
  ``nanmean_s(wgt[s] * N(x[n]; mu_s, diag(var_s))) * S`` where
  ``S = theta_list.shape[0]``; without NaN terms this equals
  ``sum_s wgt[s] * N(x[n]; mu_s, diag(var_s))``.

  Args:
    wgt: tensor of shape (S,), one weight per component; expected on the same
      device as ``x``.
    x: tensor of shape (N, 2) holding the observations.
    theta_list: tensor of shape (S, 4) with rows (mu_0, mu_1, var_0, var_1);
      moved to ``x``'s device if it lives elsewhere.

  Returns:
    Tensor of shape (N,), one mixture density value per observation.
  """
  num_samples = theta_list.shape[0]

  # Follow the data's device instead of relying on a module-level global.
  theta_list = theta_list.to(x.device)
  mu = theta_list[:, 0:2]
  var = theta_list[:, 2:4]

  # validate_args stays off, as before, so the variances are not checked.
  std_normal = TD.MultivariateNormal(
            mu,
            torch.diag_embed(var),
            validate_args=False)

  # x[:, None, :] broadcasts against the S component means, giving the (N, S)
  # density table without materialising an (N, S, 2) copy of the data.
  prob_mx = torch.exp(std_normal.log_prob(x.unsqueeze(1)))

  # The mean runs over the S components, so it is rescaled by S (not by the
  # number of observations) to recover the weighted sum.
  return torch.nanmean(prob_mx * wgt, 1) * num_samples

def find_L_n_ver2(x, wgt, mu):
  """Negative mean log mixture density of ``x``, computed one observation at a time.

  All components share the module-level covariance ``cov_mx`` and are located
  at the rows of ``mu``.  For each observation ``k`` the weighted component
  densities ``wgt[s] * N(mu[s] - k; 0, cov_mx)`` are averaged with NaN terms
  skipped, rescaled by the number of components and logged; the result is
  minus the NaN-skipping mean of those logs.

  Args:
    x: tensor of shape (N, 2) holding the observations.
    wgt: tensor with one weight per component (S values).
    mu: tensor of shape (S, 2) holding the component locations.

  Returns:
    Scalar tensor with the negative mean log mixture density.
  """
  num = mu.shape[0]
  # Define the multivariate standard normal once: every component shares
  # cov_mx.  loc takes cov_mx's dtype and device so the distribution is not
  # built from mixed dtypes.
  std_normal = TD.MultivariateNormal(
            torch.zeros(2, dtype=cov_mx.dtype, device=cov_mx.device),
            cov_mx,
            validate_args=False)
  t_list = []
  for k in x:
    prob_temp = (wgt*torch.exp(std_normal.log_prob(mu - k))).view(-1)
    # nanmean drops only the NaN terms; an all-NaN point yields NaN here and
    # is then ignored by the final nanmean, so every entry stays a scalar and
    # the stack below always has consistent shapes.
    t_list.append(torch.log(torch.nanmean(prob_temp)*num))
  t_list = torch.stack(t_list)
  return -torch.nanmean(t_list)

def find_L_n_faster2KW(x, wgt, theta_list):
  """Negative mean log-likelihood of ``x`` under an equally weighted Gaussian mixture.

  Columns 0-1 of ``theta_list`` are the component means and columns 2-3 the
  diagonal variances.  ``wgt`` is accepted but never read: the component
  densities are averaged with equal weight.

  Args:
    x: tensor of shape (N, 2) holding the observations.
    wgt: unused.
    theta_list: tensor of shape (S, 4) with rows (mu_0, mu_1, var_0, var_1).

  Returns:
    Scalar tensor: minus the mean over observations of the log of the
    component-averaged density.
  """
  n_components = theta_list.shape[0]
  n_points = x.shape[0]

  centers = theta_list[:, 0:2].to(device)
  variances = theta_list[:, 2:4].to(device)
  # (S, 1, 2) * (2, 2) -> (S, 2, 2): one diagonal covariance per component.
  covariances = variances.unsqueeze(1) * torch.eye(2).to(device)
  components = TD.MultivariateNormal(centers, covariances, validate_args=False)

  # Pair every observation with every component: shape (N, S, 2).
  tiled_x = x.unsqueeze(1).expand(n_points, n_components, x.shape[-1])
  density = components.log_prob(tiled_x).exp()   # (N, S)
  log_mixture = density.mean(dim=1).log()        # (N,)
  return -log_mixture.mean()

def find_L_n_ver3(x, wgt, theta_list):
  """Negative mean log-likelihood of ``x`` under a weighted Gaussian mixture.

  Each row of ``theta_list`` is (mu_0, mu_1, var_0, var_1): a 2-D mean and the
  diagonal of its covariance.  The mixture density of an observation is the
  ``wgt``-weighted sum of the component densities.

  Args:
    x: tensor of shape (N, 2) holding the observations.
    wgt: tensor of shape (S,), one weight per component.
    theta_list: tensor of shape (S, 4) holding the component parameters.

  Returns:
    Scalar tensor: minus the mean over observations of the log mixture density.
  """
  n_obs, n_comp = x.shape[0], theta_list.shape[0]

  eye2 = torch.eye(2).to(device)
  comp_mean = theta_list[:, 0:2].to(device)
  # Broadcasting the variances across the rows of the identity leaves them on
  # the diagonal and zeros elsewhere: shape (S, 2, 2).
  comp_cov = theta_list[:, 2:4].to(device).unsqueeze(1) * eye2
  mixture_parts = TD.MultivariateNormal(comp_mean, comp_cov, validate_args=False)

  # Every observation is evaluated against every component: (N, S, 2) -> (N, S).
  obs_grid = x.unsqueeze(1).expand(n_obs, n_comp, x.shape[-1])
  comp_density = torch.exp(mixture_parts.log_prob(obs_grid))

  # Weight each component column, sum across components, then take the log.
  weighted = comp_density * wgt.expand(n_obs, -1)
  log_density = torch.log(weighted.sum(dim=1))
  return -log_density.mean()

In [None]:
# Loss history: one row per trial, one column per weight-update iteration.
Ln_rho_k_list = np.zeros((total_trial_num, t_max))
# Standard 2-D normal; its squared draws serve as per-observation variances.
std_normal2 = TD.MultivariateNormal(
          torch.zeros(2).to(device),
          torch.eye(2).to(device),
          validate_args=False)
sampled_mu = nf.distributions.TwoMoons().to(device)
for trail_num in range(total_trial_num):

  # ---- Simulate the data for this trial (draw order kept for reproducibility) ----
  torch.manual_seed(set_seed + trail_num)
  mean = sampled_mu.sample(given_data_size)
  var = std_normal2.sample((given_data_size,)) ** 2

  # Each observation has its own mean and its own diagonal covariance diag(var).
  normal_temp = TD.MultivariateNormal(
          mean.to(device),
          torch.diag_embed(var).to(device),
          validate_args=False)

  given_data = normal_temp.sample()
  given_data_size = given_data.shape[0]

  # ---- Fixed grid of candidate components (mu_0, mu_1, var_0, var_1) ----
  grid_size = 7
  # Means span the largest absolute coordinate seen in the data.
  L = torch.max(torch.absolute(given_data)).item()

  loc_axis = torch.linspace(-L, L, grid_size)
  var_axis = torch.linspace(0.01, 4, grid_size)
  # indexing='ij' spells out the matrix-style ordering instead of relying on
  # the implicit default.
  grid_x, grid_y, grid_z, grid_a = torch.meshgrid(loc_axis, loc_axis, var_axis, var_axis, indexing='ij')
  zz = torch.stack([grid_x, grid_y, grid_z, grid_a], dim=4).view(-1, 4)
  zz = zz.to(device)
  num_samples = grid_size ** 4
  theta_list = zz
  # Start from uniform weights over the grid points.
  wgt = torch.full((num_samples,), 1/num_samples, dtype=torch.float64, device=device)

  # One loss value per iteration, filled in place.
  L_n_loss_list = np.zeros(t_max)


  folder_name = 'trial' + str(trail_num)

  # ---- Loop-invariant likelihood table ----
  # The grid and the data do not change across iterations, so the component
  # distribution and the (num_samples, given_data_size) density table are
  # built once per trial instead of once per iteration.
  std_normal = TD.MultivariateNormal(
            theta_list[:,0:2],
            torch.diag_embed(theta_list[:,2:4]),
            validate_args=False)

  # Broadcast view of the data against every grid point (no copy is made).
  x_data_rep = given_data.unsqueeze(1).expand(given_data_size, num_samples, given_data.shape[1])
  prob_mx_new = torch.transpose(torch.exp(std_normal.log_prob(x_data_rep)), 0, 1)

  for t in tqdm(range(t_max)):
    # Per-observation denominator under the current weights, broadcast over
    # the grid dimension.
    new_bottom = get_p1_bottom_faster2KW(wgt, given_data, theta_list).unsqueeze(0).expand(num_samples, given_data_size)
    # Mean over observations of (component density / denominator), minus 1.
    wgt_update = torch.nanmean(prob_mx_new/new_bottom,1) - 1
    # Step of size eta, then renormalise so the weights sum to one.
    wgt = wgt + eta*wgt_update*wgt
    wgt = wgt/torch.sum(wgt)

    L_n_loss = find_L_n_ver3(given_data, wgt, theta_list)
    L_n_loss_list[t] = L_n_loss.item()
    if t % 5 == 4:
      # clear_output(wait=True)
      print('trail:',trail_num, 'L_n :', L_n_loss.item() )


  # at the end of the trial, store its loss curve
  Ln_rho_k_list[trail_num,:] = L_n_loss_list
