In [7]:
# --- Experiment configuration -------------------------------------------
# Data and repetition counts
given_data_size = 5_000   # number of observations drawn for each trial
total_trial_num = 10      # number of independent trials (seeded separately)

# Particle approximation and its update schedule
m = 3_000                 # number of particles (locations + weights)
eta = 1.0                 # step size applied to both location and weight updates
t_max = 25                # update iterations per trial

# Reproducibility and initialisation
set_seed = 114_530        # base seed; the trial index is added to it
pretrain_factor = 4       # scales the identity covariance of the initial particle distribution

In [8]:
!pip install normflows



In [9]:
# Import required packages
import torch
import numpy as np
import normflows as nf
import os
import numpy as np
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import torch.nn.functional as F
import torch.distributions as TD
import pandas as pd
import seaborn as sb
import torch.nn as nn
import shutil
import gc

In [10]:
# Set to False to force CPU execution even when a GPU is visible.
enable_cuda = True
if enable_cuda and torch.cuda.is_available():
  device = torch.device('cuda')
else:
  device = torch.device('cpu')

# TwoMoons distribution object from normflows, moved to the selected device.
sampled_theta_m = nf.distributions.TwoMoons().to(device)

# 2x2 identity covariance in double precision, shared by every Gaussian below.
cov_mx = torch.eye(2, dtype=torch.double).to(device)


In [11]:
def get_p1_bottom_faster(wgt, x, theta_list):
  """Weighted Gaussian-kernel average over the particles, one value per data row.

  For data row ``x[i]`` and particle ``theta_list[j]`` the kernel value is the
  density of ``theta_list[j] - x[i]`` under a zero-mean normal with covariance
  ``cov_mx``. Each kernel value is multiplied by ``wgt[j]``, the products are
  nan-averaged along the particle axis, and the result is multiplied by the
  number of rows of ``x``.

  Args:
    wgt: particle weights, broadcast across the data rows.
    x: data tensor with one observation per row (the kernel is built for
      2 columns). Assumed to be 2-D.
    theta_list: particle locations, one row per particle.

  Returns:
    Tensor with one entry per row of ``x``.
  """
  n_particles = theta_list.shape[0]
  n_rows = x.shape[0]

  # One zero-mean kernel per particle, all sharing the covariance cov_mx.
  kernel = TD.MultivariateNormal(
      torch.zeros(n_particles, 2, device=device),
      cov_mx.repeat(n_particles, 1, 1),
      validate_args=False)

  # View of the data with a particle axis: x_tiled[i, j] == x[i] (no copy).
  x_tiled = x.unsqueeze(1).expand(-1, n_particles, -1)
  kernel_vals = torch.exp(kernel.log_prob(theta_list - x_tiled))

  # wgt broadcasts over the data rows.
  weighted = kernel_vals * wgt

  # NOTE(review): the nan-mean runs over the particle axis but the rescaling
  # uses the number of data rows, not the number of particles — confirm this
  # factor is intended.
  return weighted.nanmean(dim=1) * n_rows

def find_L_n_ver4(x, wgt, theta_list):
  """Average negative log-likelihood of the data under the weighted mixture.

  Row ``x[i]`` is scored by ``log(sum_j wgt[j] * N(x[i]; theta_list[j], cov_mx))``
  and the negated mean over the rows is returned.

  The sum is evaluated with a per-row max shift. Exponentiating the raw
  log-densities underflows to zero for a row that is far from every particle,
  which turns the row into ``log(0) = -inf`` and the whole loss into ``+inf``.

  Args:
    x: data tensor with one observation per row. Assumed to be 2-D.
    wgt: mixture weights, one per particle (broadcast across the data rows).
    theta_list: component means, one row per particle.

  Returns:
    Scalar tensor holding the negated mean log mixture density.
  """
  num_samples = theta_list.shape[0]

  # One Gaussian component per particle, all sharing the covariance cov_mx.
  components = TD.MultivariateNormal(
      theta_list.to(device),
      cov_mx.to(device).unsqueeze(0).repeat(num_samples, 1, 1),
      validate_args=False)

  # View of the data with a particle axis: x_view[i, j] == x[i] (no copy).
  x_view = x.unsqueeze(1).expand(-1, num_samples, -1)
  log_prob_mx = components.log_prob(x_view)

  # Subtract the largest log-density of each row before exponentiating so that
  # at least one term of the sum is exp(0) = 1.
  shift = log_prob_mx.amax(dim=1, keepdim=True)
  # A non-finite shift (NaN or infinite log-density) would poison the
  # subtraction; fall back to the unshifted computation for those rows.
  shift = torch.where(torch.isfinite(shift), shift, torch.zeros_like(shift))

  mixture = torch.sum(torch.exp(log_prob_mx - shift) * wgt, dim=1)
  log_mixture = torch.log(mixture) + shift.squeeze(1)
  return -torch.mean(log_mixture)

In [None]:
# Loss after every update: one row per trial, one column per iteration.
Ln_rho_k_list = np.zeros((total_trial_num, t_max))

# Initial particle locations are drawn from N(0, pretrain_factor * I).
target_pretrain = TD.MultivariateNormal(
    torch.zeros(2).to(device), pretrain_factor * torch.eye(2).to(device))

# Zero-mean Gaussian kernel with covariance cov_mx, batched over the particles.
# It depends only on m and cov_mx, so it is built once rather than once per
# iteration (construction draws no random numbers).
num_samples = m
std_normal = TD.MultivariateNormal(
    torch.zeros(num_samples, 2).to(device),
    cov_mx.unsqueeze(0).repeat(num_samples, 1, 1),
    validate_args=False)

for trail_num in range(total_trial_num):

  # Distinct but reproducible random stream for every trial.
  torch.manual_seed(set_seed + trail_num)

  # Synthetic data: means sampled from the TwoMoons object, then Gaussian
  # noise with covariance cov_mx added around each mean.
  sampled_mean = sampled_theta_m.sample(given_data_size).to(device).to(torch.float64)
  normal_temp = TD.MultivariateNormal(
      sampled_mean,
      cov_mx.unsqueeze(0).repeat(given_data_size, 1, 1),
      validate_args=False)
  given_data = normal_temp.sample().detach().to(device)

  # Particles start at draws from target_pretrain with uniform weights.
  # The draw itself is unchanged; it is promoted to float64 afterwards so the
  # float64 update steps are no longer truncated to float32 when written back
  # (data, weights and covariance are all float64 already).
  mu = target_pretrain.sample((m,)).to(torch.float64)
  wgt = torch.full((m,), 1/m, dtype=torch.float64, device=device)
  L_n_loss_list = np.zeros(t_max)

  # Not used in this cell; kept in case other cells read it.
  folder_name = 'trial' + str(trail_num)

  # Data with a singleton particle axis, shape (n, 1, 2); broadcasting against
  # mu (m, 2) yields the (n, m, 2) differences without materialising copies.
  x_data_rep = given_data.unsqueeze(1)

  for t in tqdm(range(t_max)):

    # ---- particle location update ----
    # Per-row normaliser from the helper, shaped (n, 1) to broadcast over particles.
    old_bottom = get_p1_bottom_faster(wgt, given_data, mu).unsqueeze(1)
    prob_mx = torch.exp(std_normal.log_prob(mu - x_data_rep))        # (n, m)
    second_term = x_data_rep - mu                                    # (n, m, 2)
    # nan-mean over the data rows of kernel * (x_i - mu_j) / normaliser_i;
    # NaN entries (e.g. 0/0 after underflow) are skipped by nanmean.
    mu_update = torch.nanmean(
        prob_mx.unsqueeze(-1) * second_term / old_bottom.unsqueeze(-1), 0)
    mu = mu + eta*mu_update

    # ---- weight update (uses the freshly moved particles) ----
    prob_mx_new = torch.transpose(
        torch.exp(std_normal.log_prob(mu - x_data_rep)), 0, 1)       # (m, n)
    # Shape (n,): broadcasts along the last axis of prob_mx_new.
    new_bottom = get_p1_bottom_faster(wgt, given_data, mu)
    wgt_update = torch.nanmean(prob_mx_new/new_bottom, 1) - 1
    wgt = wgt + eta*wgt_update*wgt
    wgt = wgt/torch.sum(wgt)   # renormalise so the weights sum to one

    # ---- bookkeeping ----
    L_n_loss = find_L_n_ver4(given_data, wgt, mu)
    L_n_loss_list[t] = L_n_loss.item()

    if t % 5 == 4:
      print('trail:',trail_num, 'L_n :', L_n_loss.item() )

  Ln_rho_k_list[trail_num,:] = L_n_loss_list