In [1]:
# ---- Experiment configuration ----
M = 1_000              # number of particles per trial (also the size of the target sample)
pretrain_factor = 9    # covariance scale of the initial distribution rho_0 = N(0, pretrain_factor * I)
total_steps = 100      # Langevin steps per trial
step_size = 0.01       # Langevin step size h
total_trails = 10      # number of independent trials ("trails" = trials; name kept for downstream cells)
set_seed = 114_540     # base RNG seed; trial k is seeded with set_seed + k


In [2]:
# Import required packages
import torch
import numpy as np
import os
import math
from matplotlib import pyplot as plt
from tqdm import tqdm
from IPython.display import clear_output
import torch.nn.functional as F
import torch.distributions as TD
import pandas as pd
import seaborn as sb
import torch.nn as nn
import shutil
import gc
import copy
import math
from scipy.optimize import linear_sum_assignment

In [3]:
enable_cuda = True
# Run on the GPU only when it is both requested and actually present.
device = torch.device('cuda') if (enable_cuda and torch.cuda.is_available()) else torch.device('cpu')

# rho_0: isotropic 2-D Gaussian N(0, pretrain_factor * I) used to initialise the particles.
target_pretrain = TD.MultivariateNormal(
    loc=torch.zeros(2, device=device),
    covariance_matrix=pretrain_factor * torch.eye(2, device=device),
)
# Standard 2-D normal N(0, I); source of the Langevin noise.
std_normal = TD.MultivariateNormal(
    loc=torch.zeros(2, device=device),
    covariance_matrix=torch.eye(2, device=device),
)

In [4]:
class Target(nn.Module):
    """
    Base class for target distributions that are sampled by rejection sampling.

    Proposals are drawn uniformly from the box
    ``[prop_shift, prop_shift + prop_scale)`` in every coordinate and accepted
    with probability ``min(1, exp(log_prob(z) - max_log_prob))``.

    Subclasses must set ``self.n_dims`` (number of coordinates) and
    ``self.max_log_prob`` (an upper bound on ``log_prob`` over the proposal box,
    required for the accept test to be unbiased), and implement ``log_prob``.
    """

    def __init__(self, prop_scale=torch.tensor(6.0), prop_shift=torch.tensor(-3.0)):
        """Constructor

        Args:
          prop_scale: Side length of the uniform proposal box (tensor or float)
          prop_shift: Lower corner offset of the uniform proposal box (tensor or float)
        """
        super().__init__()
        # Clone so instances never share the default-argument tensors:
        # register_buffer stores the object itself, so an in-place update on one
        # instance's buffer (e.g. buf.mul_() or load_state_dict's in-place copy)
        # would otherwise leak into every other default-constructed instance.
        self.register_buffer("prop_scale", torch.as_tensor(prop_scale).clone())
        self.register_buffer("prop_shift", torch.as_tensor(prop_shift).clone())

    def log_prob(self, z):
        """
        Args:
          z: value or batch of latent variable

        Returns:
          log probability of the distribution for z

        Raises:
          NotImplementedError: always; subclasses must override this method
        """
        raise NotImplementedError("The log probability is not implemented yet.")

    def rejection_sampling(self, num_steps=1):
        """Run one batch of rejection sampling.

        Args:
          num_steps: Number of uniform proposals to draw in this batch

        Returns:
          Tensor of shape (n_accepted, n_dims) holding the accepted proposals;
          n_accepted can be anywhere from 0 to num_steps
        """
        opts = dict(dtype=self.prop_scale.dtype, device=self.prop_scale.device)
        # Uniform proposals in [prop_shift, prop_shift + prop_scale) per coordinate.
        eps = torch.rand((num_steps, self.n_dims), **opts)
        z_ = self.prop_scale * eps + self.prop_shift
        # Accept z with probability exp(log_prob(z) - max_log_prob).
        prob = torch.rand(num_steps, **opts)
        prob_ = torch.exp(self.log_prob(z_) - self.max_log_prob)
        accept = prob_ > prob
        return z_[accept, :]

    def sample(self, num_samples=1):
        """Draw exactly ``num_samples`` points by repeated rejection sampling.

        Note: loops until enough proposals are accepted, so it never returns if
        the acceptance probability is zero everywhere on the proposal box.

        Args:
          num_samples: Number of samples to draw

        Returns:
          Tensor of shape (num_samples, n_dims) (shape (0, n_dims) when num_samples <= 0)
        """
        empty = torch.zeros(
            (0, self.n_dims), dtype=self.prop_scale.dtype, device=self.prop_scale.device
        )
        chunks = [empty]
        n_collected = 0
        while n_collected < num_samples:
            z_ = self.rejection_sampling(num_samples)
            take = min(len(z_), num_samples - n_collected)
            chunks.append(z_[:take, :])
            n_collected += take
        # Concatenate once instead of growing the tensor every iteration.
        return torch.cat(chunks, 0)

In [5]:
class TargetDist(Target):
    """
    Two-dimensional, radially symmetric (unimodal) target distribution

        p(z) = exp(-|z|^4 / 4) / Z,    Z = pi^(3/2),

    where |z| is the Euclidean norm.  The normaliser follows from polar
    coordinates: Z = 2*pi * int_0^inf r exp(-r^4/4) dr = pi * int_0^inf exp(-u^2/4) du
    = pi * sqrt(pi).  The potential V(z) = |z|^4 / 4 has gradient |z|^2 * z.
    """

    # log Z = log(pi^(3/2)); equal (to double precision) to the previously
    # hard-coded np.log(5.5683279968317078...).
    _LOG_NORMALIZER = 1.5 * math.log(math.pi)

    def __init__(self):
        super().__init__()
        self.n_dims = 2
        # log_prob <= -log Z < 0 everywhere, so 0.0 is a valid (if loose) upper
        # bound for rejection sampling.  Using -_LOG_NORMALIZER would be tight
        # (~5.6x higher acceptance) but would change the samples drawn per seed.
        self.max_log_prob = 0.0

    def log_prob(self, z):
        """
        ```
        log(p) = - 1/4 * norm(z) ** 4 - 3/2 * log(pi)
        ```

        Args:
          z: tensor of points with coordinates along the last dimension,
             e.g. a batch of shape (N, 2) or a single point of shape (2,)

        Returns:
          log density at each point, shape z.shape[:-1]
        """
        norm4 = torch.linalg.vector_norm(z, ord=2, dim=-1) ** 4
        return -1/4 * norm4 - self._LOG_NORMALIZER


In [None]:
# Run `total_trails` independent trials of the unadjusted Langevin algorithm
# targeting p(z) ∝ exp(-|z|^4 / 4).  After every step, record the mean cost of
# the optimal one-to-one matching (L1 ground metric) between the particle cloud
# and an independent sample from p.
#
# Globals kept for downstream cells:
#   objective_fun_loss_list_2d : np.ndarray (total_trails, total_steps), per-trial losses
#   x_all_data                 : torch.Tensor (total_steps, M, 2) on CPU, trajectory of the LAST trial
#   p, target_sample           : target distribution and the last trial's sample from it
objective_fun_loss_list_2d = np.zeros((total_trails, total_steps))
x_all_data = torch.zeros((total_steps, M, 2))

p = TargetDist()
for trail_num in range(total_trails):
  torch.manual_seed(set_seed + trail_num)

  target_sample = p.sample(M)
  # initial x sampled from rho_0
  x = target_pretrain.sample((M,)).to(device)

  # Langevin update: x <- x - h * grad V(x) + sqrt(2h) * xi, with V(x) = |x|^4 / 4,
  # so grad V(x) = |x|^2 * x.  The (M, 1) norm broadcasts over both coordinates.
  noise_scale = math.sqrt(2 * step_size)
  for t in tqdm(range(total_steps)):
    l2_norm = torch.linalg.vector_norm(x, dim=1, keepdim=True)
    x = x - l2_norm ** 2 * x * step_size + noise_scale * std_normal.sample((M,)).to(device)
    x_all_data[t] = x  # copies into the CPU trajectory buffer

  # Loss after each step: mean matched cost of the optimal assignment.
  x_final = target_sample
  losses = []
  for t in tqdm(range(total_steps)):
    x = x_all_data[t]
    # cost[i, j] = ||x_i - x_final_j||_1
    cost = torch.cdist(x, x_final, p=1).cpu().numpy()
    row_ind, col_ind = linear_sum_assignment(cost)
    losses.append(cost[row_ind, col_ind].sum() / M)

  objective_fun_loss_list = np.asarray(losses, dtype=float)
  objective_fun_loss_list_2d[trail_num, :] = objective_fun_loss_list