In [None]:
# NOTE(review): IPython automagic `cd` to a hardcoded, user-specific directory.
# Later cells write/read true_*.npz, I_js_*.npz, mige_*.npz and toy.pdf relative to the
# working directory, and the `from models ...` / `from utils ...` imports presumably rely
# on it being the repository root — TODO make this path configurable (e.g. `%cd {REPO_DIR}`).
cd ~/gpu/jtan/github/latent_separation

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt

import os

In [None]:
def GenerateSamples(d, rho, num_sample, device='cuda'):
    """Draw correlated Gaussian pairs (x, y) with per-coordinate correlation rho.

    x ~ N(0, I_d) and y = rho * x + sqrt(1 - rho^2) * eps with eps ~ N(0, I_d), so every
    coordinate pair (x_i, y_i) is a standard bivariate normal with correlation rho.

    Args:
        d: dimensionality of x and of y.
        rho: correlation coefficient; a Python float or a tensor. A tensor with
            ``requires_grad=True`` lets gradients w.r.t. rho flow through ``ys``.
        num_sample: number of samples (rows).
        device: torch device for the returned tensors. Defaults to ``'cuda'`` (the
            previously hard-coded behaviour); pass ``'cpu'`` to run without a GPU.

    Returns:
        (xs_ys, xs, ys): ``xs`` and ``ys`` have shape (num_sample, d); ``xs_ys`` is their
        concatenation along the last axis, shape (num_sample, 2 * d).
    """
    # Draw on the default generator and then move (as the original `.cuda()` did), so a
    # given torch seed produces the same samples as before.
    xs = torch.randn(num_sample, d).to(device)
    noise = torch.randn(num_sample, d).to(device)
    ys = rho * xs + (1 - rho ** 2) ** 0.5 * noise
    xs_ys = torch.cat([xs, ys], dim=-1)
    return xs_ys, xs, ys

In [None]:
def ComputeTrueValues(d, axis_min, axis_max, num_point=256):
    """Compute and save the closed-form MI and dMI/drho for the correlated-Gaussian task.

    For d independent coordinate pairs, each with correlation rho:
        I(X; Y)  = -d/2 * log(1 - rho^2)
        dI/drho  = d * rho / (1 - rho^2)

    Args:
        d: number of correlated coordinate pairs.
        axis_min, axis_max: inclusive endpoints of the rho grid; they must lie strictly
            inside (-1, 1) for the values to be finite.
        num_point: number of grid points (default 256, the previously hard-coded value).

    Side effects:
        Writes ``true_{d}.npz`` (arrays ``rho``, ``mi``, ``gradient``) to the current
        working directory and prints a status message.

    Returns:
        (rho, mi, gradient): 1-D numpy arrays of length ``num_point``.
    """
    rho = np.linspace(axis_min, axis_max, num_point)
    one_minus_rho_sq = 1 - rho ** 2
    mi = -np.log(one_minus_rho_sq) * d / 2
    gradient = (rho * d) / one_minus_rho_sq

    filename = 'true_{}.npz'.format(d)
    np.savez(filename, rho=rho, mi=mi, gradient=gradient)

    print('Complete: {}-d true values.'.format(d))
    print('  Save as: {}'.format(filename))
    return rho, mi, gradient
    
def _permute_dims_last_axes(joint, axes):
    """
    Randomly permutes the sample along last specified axes
    """
    perm = torch.clone(joint)
    batch_size, dim_z = perm.size()

    for axis in axes:
        pi = torch.randperm(batch_size).to(perm.device)
        perm[:, axis] = joint[pi, axis]

    return perm

def _shuffle_batch(z):
    
    batch_size, dim_z = z.size()
    pi = torch.randperm(batch_size).to(z.device)
    shuffled_z = z[pi]
    return shuffled_z

In [None]:
import math
import numpy as np

import torch
import torch.nn as nn
from torch.nn import functional as F

from models import network
from utils import distributions
from utils.math import (log_density_gaussian, log_importance_weight_matrix, 
    matrix_log_density_gaussian, gaussian_entropy)

In [None]:
class Critic(nn.Module):
    """
    MI critic: an MLP mapping a concatenated (latent, auxiliary) vector to one logit.

    Args:
        n_layers: number of hidden Linear layers, i.e. ``dense_in`` followed by
            ``n_layers - 1`` layers named ``dense1 ... dense{n_layers-1}``. The default of 4
            reproduces the original fixed architecture (previously this argument was
            accepted but ignored). Must be >= 1.
        latent_dim: size of the latent part of the input.
        aux_dim: size of the auxiliary part of the input.
        n_units: width of every hidden layer.
        activation: ``'leaky_relu'`` (negative slope 0.2) or the name of a function in
            ``torch.nn.functional`` (e.g. ``'relu'``).

    Raises:
        ValueError: if ``n_layers < 1``.
    """
    def __init__(self, n_layers=4, latent_dim=10, aux_dim=1, n_units=128, activation='leaky_relu'):
        super(Critic, self).__init__()
        if n_layers < 1:
            raise ValueError('n_layers must be >= 1, got {}'.format(n_layers))

        self.latent_dim = latent_dim
        self.aux_dim = aux_dim
        self.n_layers = n_layers
        self.n_units = n_units

        if activation == 'leaky_relu':
            self.activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)
        else:
            self.activation = getattr(F, activation)

        self.dense_in = nn.Linear(latent_dim + aux_dim, n_units)
        # Hidden layers keep the original attribute names (dense1, dense2, ...) and
        # registration order, so default-configured state_dicts stay compatible.
        for i in range(1, n_layers):
            setattr(self, 'dense{}'.format(i), nn.Linear(n_units, n_units))

        # theoretically 1 with sigmoid but apparently bad results
        # => use 2 and softmax
        # out_units = 2
        # self.dense_out = nn.Linear(n_units, out_units)

        out_units = 1
        self.dense_out = nn.Linear(n_units, out_units)

    def forward(self, z):
        """Return logits of shape (..., 1) for input z of shape (..., latent_dim + aux_dim)."""
        x = self.activation(self.dense_in(z))
        for i in range(1, self.n_layers):
            x = self.activation(getattr(self, 'dense{}'.format(i))(x))
        return self.dense_out(x)

In [None]:
class MI_matching(nn.Module):
    """
    Train critic to estimate density ratio and use density ratio estimate to construct a critic
    for the KL lower bound.

    ``call_optimize`` evaluates ``I_JS = 1 + E_joint[V] - E_marginal[exp(V)]`` (V is the
    critic output) and takes one Adam step on a binary cross-entropy loss that trains the
    critic to separate joint pairs from row-shuffled pairs.

    Args:
        device: torch device the critic is moved to; also used for the label tensors.
        sensitive_latent_idx: unused here; kept for backward compatibility.
        critic_kwargs: keyword arguments for ``Critic`` (default: none).
        optim_kwargs: keyword arguments for the critic's ``torch.optim.Adam``
            (default ``dict(lr=5e-5)``).
        **kwargs: forwarded to ``nn.Module.__init__``.
    """
    def __init__(self, device, sensitive_latent_idx=None, critic_kwargs=None,
                 optim_kwargs=None, **kwargs):
        super().__init__(**kwargs)
        # None sentinels replace the previous mutable default arguments.
        if critic_kwargs is None:
            critic_kwargs = {}
        if optim_kwargs is None:
            optim_kwargs = dict(lr=5e-5)

        self.device = device
        self.critic = Critic(**critic_kwargs).to(self.device)
        print('Critic model for MI estimation:')
        # print(self.critic)
        self.CE_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
        self.opt_C = torch.optim.Adam(self.critic.parameters(), **optim_kwargs)

    def call_optimize(self, latent_factors, generative_factors, sensitive_latent_idx, latent_stats,
        storage=None, **kwargs):
        """
        Take one critic optimisation step and return the MI lower-bound estimate.

        The first half of the batch forms joint pairs (latent, generative). The second half
        forms "marginal" pairs, with the generative factors shuffled across rows.

        Args:
            latent_factors: (batch, latent_dim) tensor.
            generative_factors: (batch, aux_dim) tensor, row-aligned with latent_factors.
            sensitive_latent_idx: unused; kept for interface compatibility.
            latent_stats: only ``latent_stats.size(0)`` (the batch size) is read.
            storage: optional dict holding a list under ``'I_JS'``; the scalar estimate
                is appended to it.

        Returns:
            0-d tensor ``I_JS``, still attached to the autograd graph of the inputs.

        Raises:
            ValueError: if the batch has fewer than 2 rows.
        """
        batch_size = latent_stats.size(0)
        half_batch_size = batch_size // 2
        if half_batch_size == 0:
            raise ValueError('call_optimize needs a batch of at least 2 samples, got {}'.format(batch_size))

        # Explicit slicing instead of torch.split: with an odd batch size torch.split
        # returns three chunks and the two-way unpacking fails; here the extra row is dropped.
        latent_joint = latent_factors[:half_batch_size]
        latent_marginal = latent_factors[half_batch_size:2 * half_batch_size]
        generative_joint = generative_factors[:half_batch_size]
        generative_marginal = generative_factors[half_batch_size:2 * half_batch_size]

        joint_zm = torch.cat([latent_joint, generative_joint], dim=1)
        marginal_zm = torch.cat([latent_marginal, _shuffle_batch(generative_marginal)], dim=1)

        # squeeze(-1) rather than a full squeeze keeps shape (half_batch_size,) even when
        # half_batch_size == 1, so it always matches the label tensors below.
        V_z_joint = self.critic(joint_zm).squeeze(-1)
        V_z_marginal = self.critic(marginal_zm).squeeze(-1)

        # Lower bound on MI
        I_JS = 1. + V_z_joint.mean() - torch.exp(V_z_marginal).mean()

        # Optimize critic
        ones = torch.ones(half_batch_size, dtype=torch.float, device=self.device)
        zeros = torch.zeros_like(ones)
        # NOTE(review): with these targets the logit V is trained towards
        # log(p_marginal / p_joint), but the bound above is tight for
        # V = log(p_joint / p_marginal). The labels look swapped (joint should be 1).
        # toy_MI_calc negates the returned value, which does not undo this. Left as-is
        # because the caller depends on the current sign — TODO fix together with it.
        critic_loss = 0.5 * (self.CE_loss(V_z_joint, zeros) + self.CE_loss(V_z_marginal, ones))
        # backward() also accumulates gradients on any leaf tensors the inputs depend on
        # (toy_MI_calc reads such a gradient from rho).
        critic_loss.backward()
        self.opt_C.step()
        self.opt_C.zero_grad()

        if storage is not None:
            # Previously appended the undefined name `supervised_term` (NameError).
            storage['I_JS'].append(I_JS.item())
        print('critic_loss', critic_loss.item())
        return I_JS

In [None]:
def toy_MI_calc(d, range_rho, num_sample, num_epoch, n_units, GenerateData, device='cuda'):
    """
    Train a fresh ``MI_matching`` critic for every rho in ``range_rho``. Record the
    final-epoch MI estimate and the gradient accumulated on rho in that epoch.

    Args:
        d: dimensionality of each of x and y.
        range_rho: iterable of correlation values.
        num_sample: samples drawn per epoch (split in halves inside ``call_optimize``).
        num_epoch: critic training steps per rho; must be >= 1.
        n_units: hidden width of the critic.
        GenerateData: callable ``(d, rho, num_sample) -> (xs_ys, xs, ys)``.
        device: torch device for the critic and for rho (default ``'cuda'``).

    Returns:
        (gradient_approximation, MI_approximation): 1-D numpy arrays, one entry per rho.

    Raises:
        ValueError: if ``num_epoch < 1``.

    NOTE(review): ``rho.grad`` is populated by ``critic_loss.backward()`` inside
    ``call_optimize``, because ``ys`` depends on rho. ``gradient_approximation`` is therefore
    minus the gradient of the critic loss, not the gradient of the MI estimate.
    TODO confirm before comparing it with the true MI gradient.
    """
    if num_epoch < 1:
        raise ValueError('num_epoch must be >= 1, got {}'.format(num_epoch))

    MI_approximation = []
    gradient_approximation = []
    for rho_value in range_rho:
        critic_model = MI_matching(device, [0], critic_kwargs={'n_units': n_units, 'latent_dim': d, 'aux_dim': d})

        rho = torch.FloatTensor([rho_value]).to(device)
        rho.requires_grad = True

        MI = None
        for epoch in range(num_epoch):
            # Clear before each step so only the final epoch's gradient survives the loop.
            if rho.grad is not None:
                rho.grad.zero_()
            xs_ys, xs, ys = GenerateData(d, rho, num_sample)

            MI = -critic_model.call_optimize(xs, ys, [0], xs)
            print('mi', MI.item())

        # Detach so the stored estimates do not keep autograd graphs (and critic
        # activations) alive for every rho.
        MI_approximation.append(MI.detach())
        gradient_approximation.append(-rho.grad.detach())

    gradient_approximation = torch.stack(gradient_approximation).view(-1).cpu().numpy()
    MI_approximation = torch.stack(MI_approximation).view(-1).cpu().numpy()
    return gradient_approximation, MI_approximation

In [None]:
def ComputeMINE_F(d, axis_min, axis_max, num_lines, num_point=30, num_sample=256,
                  num_epoch=200, num_units=512):
    """
    Repeat the critic-based MINE-f (NWJ) estimation ``num_lines`` times over a rho grid.

    Args:
        d: dimensionality of each of x and y.
        axis_min, axis_max: inclusive endpoints of the rho grid.
        num_lines: number of independent repetitions; must be >= 1.
        num_point: number of rho grid points (default 30).
        num_sample: samples drawn per critic training step (default 256).
        num_epoch: critic training steps per rho (default 200).
        num_units: hidden width of the critic (default 512).

    Side effects:
        Writes ``I_js_{d}.npz`` to the current directory. It holds ``rho`` with shape
        (num_point,), and ``mi`` and ``gradient``, each with shape (num_lines, num_point).

    Returns:
        (rho, mi, gradient): the arrays that were saved.

    Raises:
        ValueError: if ``num_lines < 1``.
    """
    if num_lines < 1:
        raise ValueError('num_lines must be >= 1, got {}'.format(num_lines))
    print('Compute MINE-f(NWJ) estimation, d={}'.format(d))

    rho = np.linspace(axis_min, axis_max, num_point)
    mi, gradient = [], []

    for _ in tqdm(range(num_lines)):
        g, m = toy_MI_calc(d, rho, num_sample, num_epoch, num_units, GenerateSamples)
        mi.append(m)
        gradient.append(g)

    mi = np.stack(mi, axis=0)
    gradient = np.stack(gradient, axis=0)

    filename = 'I_js_{}.npz'.format(d)
    np.savez(filename, rho=rho, mi=mi, gradient=gradient)
    print('Complete: {} times {}-d mine-f estimation.'.format(num_lines, d))
    print('  Save as: {}'.format(filename))
    return rho, mi, gradient

In [None]:
class ScoreEstimator:
    """
    Base class for RBF-kernel score (grad log-density) estimators.

    Subclasses implement ``compute_gradients``. This class supplies the RBF kernel, its
    Gram matrix and gradients, and a pairwise-distance heuristic for the kernel width.
    Kernel widths are tensors, indexed with ``[..., None, None]`` to broadcast over
    (n, m) Gram matrices.
    """

    def __init__(self):
        pass

    def rbf_kernel(self, x1, x2, kernel_width):
        """exp(-||x1 - x2||^2 / (2 * w^2)), reducing the squared distance over the last axis."""
        delta = x1 - x2
        sq_dist = (delta * delta).sum(dim=-1)
        return torch.exp(-sq_dist / (2 * (kernel_width * kernel_width)))

    def gram(self, x1, x2, kernel_width):
        """Gram matrix K[..., i, j] = k(x1[..., i, :], x2[..., j, :]) of shape (..., n, m)."""
        return self.rbf_kernel(x1.unsqueeze(-2), x2.unsqueeze(-3), kernel_width[..., None, None])

    def grad_gram(self, x1, x2, kernel_width):
        """
        Gram matrix together with its gradients with respect to each argument.

        Returns:
            (G, grad_x1, grad_x2): G has shape (..., n, m). The gradients have shape
            (..., n, m, d): grad_x1[..., i, j, :] = dG_ij / dx1_i and
            grad_x2[..., i, j, :] = dG_ij / dx2_j.
        """
        rows = x1.unsqueeze(-2)
        cols = x2.unsqueeze(-3)
        width = kernel_width[..., None, None]
        G = self.rbf_kernel(rows, cols, width)
        scaled_diff = (rows - cols) / (width[..., None] ** 2)
        weighted = G.unsqueeze(-1)
        return G, weighted * (-scaled_diff), weighted * scaled_diff

    def heuristic_kernel_width(self, x_samples, x_basis):
        """
        Kernel width from pairwise Euclidean distances between rows of ``x_samples`` and
        ``x_basis``. It is the ``(n_samples * n_basis) // 2``-th largest distance (roughly
        the median), taken per leading batch index and detached from the graph.
        """
        n_samples = x_samples.size()[-2]
        n_basis = x_basis.size()[-2]
        n_pairs = n_samples * n_basis
        delta = x_samples.unsqueeze(-2) - x_basis.unsqueeze(-3)
        pairwise_dist = torch.sqrt((delta * delta).sum(dim=-1))
        # topk returns values sorted in descending order, so the last column is the k-th largest.
        largest = torch.topk(pairwise_dist.reshape(-1, n_pairs), k=n_pairs // 2)[0]
        return largest[:, -1].reshape(x_samples.size()[:-2]).detach()

    def compute_gradients(self, samples, x=None):
        """Estimate grad log q at ``x`` from ``samples``; implemented by subclasses."""
        raise NotImplementedError()
        

In [None]:
class SpectralScoreEstimator(ScoreEstimator):
    """
    Spectral (Nystrom-based) estimator of the score grad_x log q(x) from samples of q.

    The RBF Gram matrix of the samples is eigendecomposed. The leading eigenvectors are
    extended to query points with the Nystrom formula and combined with the averaged kernel
    gradients to give the score estimate.

    Args:
        n_eigen: number of largest eigenpairs to keep. ``None`` keeps all of them unless
            ``n_eigen_threshold`` is set. The value is now honoured on every call;
            previously it was reset to ``None`` after the first call.
        eta: optional value added to the Gram-matrix diagonal before eigendecomposition.
        n_eigen_threshold: used only when ``n_eigen`` is ``None``. Keeps the largest
            eigenpairs whose cumulative share of the batch-averaged spectrum stays below
            this value; at least one eigenpair is always kept.
    """
    def __init__(self, n_eigen=None, eta=None, n_eigen_threshold=None):
        self._n_eigen = n_eigen
        self._eta = eta
        self._n_eigen_threshold = n_eigen_threshold
        super().__init__()

    @staticmethod
    def _eigh_upper(K):
        """Eigenvalues (ascending) and eigenvectors of symmetric K, read from its upper triangle."""
        linalg = getattr(torch, 'linalg', None)
        if linalg is not None and hasattr(linalg, 'eigh'):
            return linalg.eigh(K, UPLO='U')
        # Fallback for old PyTorch releases without torch.linalg.eigh
        # (torch.symeig was removed in recent releases).
        return torch.symeig(K, eigenvectors=True, upper=True)

    def nystrom_ext(self, samples, x, eigen_vectors, eigen_values, kernel_width):
        """Nystrom extension to query points x: sqrt(M) * K(x, samples) @ U / eigen_values."""
        M = torch.tensor(samples.size()[-2]).to(samples.device)
        Kxq = self.gram(x, samples, kernel_width)
        ret = torch.sqrt(M.float()) * torch.matmul(Kxq, eigen_vectors)
        ret *= 1. / torch.unsqueeze(eigen_values, dim=-2)
        return ret

    def compute_gradients(self, samples, x=None):
        """
        Estimate grad log q at ``x`` (default: ``samples``) from ``samples`` drawn from q.

        Args:
            samples: (..., M, d) tensor of samples.
            x: optional (..., N, d) query points. When given, the kernel width is
                computed on samples and x jointly.

        Returns:
            (..., N, d) tensor of score estimates (N = M when ``x`` is None).
        """
        if x is None:
            kernel_width = self.heuristic_kernel_width(samples, samples)
            x = samples
        else:
            _samples = torch.cat([samples, x], dim=-2)
            kernel_width = self.heuristic_kernel_width(_samples, _samples)

        M = samples.size()[-2]
        Kq, grad_K1, grad_K2 = self.grad_gram(samples, samples, kernel_width)
        if self._eta is not None:
            # Build the ridge on Kq's device/dtype; a bare torch.eye(M) lives on the CPU
            # and fails for CUDA inputs.
            Kq = Kq + self._eta * torch.eye(M, device=Kq.device, dtype=Kq.dtype)

        eigen_values, eigen_vectors = self._eigh_upper(Kq)

        # Local copy: the configured n_eigen must survive across calls.
        n_eigen = self._n_eigen
        if n_eigen is None and self._n_eigen_threshold is not None:
            # Batch-averaged spectrum, largest first, normalised to shares.
            eigen_arr = torch.mean(torch.reshape(eigen_values, [-1, M]), dim=0)
            eigen_arr = torch.flip(eigen_arr, [-1])
            eigen_arr = eigen_arr / torch.sum(eigen_arr)
            eigen_cum = torch.cumsum(eigen_arr, dim=-1)
            # Keep at least one pair: a count of 0 would slice `-0:` and keep everything.
            n_eigen = max(int(torch.sum(torch.lt(eigen_cum, self._n_eigen_threshold))), 1)
        if n_eigen is not None:
            # Eigenvalues are ascending, so the largest ones are at the end.
            eigen_values = eigen_values[..., -n_eigen:]
            eigen_vectors = eigen_vectors[..., -n_eigen:]

        eigen_ext = self.nystrom_ext(samples, x, eigen_vectors, eigen_values, kernel_width)
        grad_K1_avg = torch.mean(grad_K1, dim=-3)
        M = torch.tensor(M).to(samples.device)
        beta = -torch.sqrt(M.float()) * torch.matmul(torch.transpose(eigen_vectors, -1, -2),
                                                     grad_K1_avg) / torch.unsqueeze(eigen_values, -1)
        grads = torch.matmul(eigen_ext, beta)
        return grads

In [None]:
def entropy_surrogate(estimator, samples):
    """
    Scalar surrogate whose gradient is E[score(x) . dx/dtheta], with the score
    grad_x log q estimated by ``estimator``.

    The score is computed on detached samples and detached again, so gradients flow only
    through the explicit ``samples`` factor. Sign convention: this is a surrogate for
    E[log q], i.e. minus the entropy.
    """
    score = estimator.compute_gradients(samples.detach(), None).detach()
    per_sample = (score * samples).sum(dim=-1)
    return per_sample.mean()


def MIGE(d, range_rho, num_sample, GenerateData, threshold=None, n_eigen=None, device='cuda'):
    """
    Estimate dI(X; Y)/drho with MIGE for every rho in ``range_rho``.

    I(X; Y) = H(X) + H(Y) - H(X, Y). When ``xs`` does not depend on rho (as in
    ``GenerateSamples``), dI/drho = dH(Y)/drho - dH(X, Y)/drho. ``entropy_surrogate``
    is a surrogate for minus the entropy, hence the ``joint - marginal`` order below.

    Args:
        d: dimensionality of each of x and y.
        range_rho: iterable of correlation values.
        num_sample: samples drawn per rho.
        GenerateData: callable ``(d, rho, num_sample) -> (xs_ys, xs, ys)``.
        threshold: spectral-share threshold passed to ``SpectralScoreEstimator``.
        n_eigen: fixed number of eigenpairs passed to ``SpectralScoreEstimator``.
        device: torch device for rho (default ``'cuda'``, the previous hard-coded value).

    Returns:
        1-D numpy array with one gradient estimate per rho.
    """
    spectral_j = SpectralScoreEstimator(n_eigen=n_eigen, n_eigen_threshold=threshold)
    spectral_m = SpectralScoreEstimator(n_eigen=n_eigen, n_eigen_threshold=threshold)
    approximations = []
    for rho_value in range_rho:
        # Fresh leaf tensor per grid point, so its .grad holds only this point's gradient.
        rho = torch.FloatTensor([rho_value]).to(device)
        rho.requires_grad = True
        xs_ys, _, ys = GenerateData(d, rho, num_sample)

        surrogate = entropy_surrogate(spectral_j, xs_ys) - entropy_surrogate(spectral_m, ys)
        surrogate.backward()
        approximations.append(rho.grad.detach())

    return torch.stack(approximations).view(-1).cpu().numpy()

In [None]:
def ComputeMIGE(d, axis_min, axis_max, num_lines, num_point=32, num_sample=256, threshold=0.99):
    """
    Repeat the MIGE gradient estimation ``num_lines`` times over a rho grid and save it.

    Args:
        d: dimensionality of each of x and y.
        axis_min, axis_max: inclusive endpoints of the rho grid.
        num_lines: number of independent repetitions; must be >= 1.
        num_point: number of rho grid points (default 32; ComputeMINE_F defaults to 30,
            and each result file stores its own ``rho`` grid).
        num_sample: samples drawn per rho (default 256).
        threshold: spectral-share threshold for the score estimators (default 0.99).

    Side effects:
        Writes ``mige_{d}.npz`` to the current directory. It holds ``rho`` with shape
        (num_point,) and ``gradient`` with shape (num_lines, num_point).

    Returns:
        (rho, gradient): the arrays that were saved.

    Raises:
        ValueError: if ``num_lines < 1``.
    """
    if num_lines < 1:
        raise ValueError('num_lines must be >= 1, got {}'.format(num_lines))
    print('Compute MIGE estimation, d={}'.format(d))

    rho = np.linspace(axis_min, axis_max, num_point)
    gradient = np.stack(
        [MIGE(d, rho, num_sample, GenerateSamples, threshold=threshold)
         for _ in tqdm(range(num_lines))],
        axis=0)

    filename = 'mige_{}.npz'.format(d)
    np.savez(filename, rho=rho, gradient=gradient)
    print('Complete: {} times {}-d mige estimation.'.format(num_lines, d))
    print('  Save as: {}'.format(filename))
    return rho, gradient

In [None]:
# Experiment configuration: dimensions, repetitions per estimator, rho-grid endpoints.
ds = [8, 16, 32]
num_lines = 1
axis_min, axis_max = -0.8, 0.8
SEED = 0

# Seed torch (this seeds the CPU and CUDA generators) and numpy, so the sampled data and
# critic initialisations are reproducible under Restart & Run All.
torch.manual_seed(SEED)
np.random.seed(SEED)

for d in ds:
    ComputeTrueValues(d, axis_min, axis_max)

# Keep the estimator order fixed: it determines how the RNG stream is consumed, and
# therefore the results obtained for a given SEED.
for d in ds:
    ComputeMIGE(d, axis_min, axis_max, num_lines)

for d in ds:
    ComputeMINE_F(d, axis_min, axis_max, num_lines)

In [None]:
def _load_npz_arrays(path, *keys):
    """Read the named arrays from an .npz file and close the file handle."""
    with np.load(path) as archive:
        return tuple(archive[key] for key in keys)


def Draw(ds, num_lines, save_path='./toy.pdf'):
    """
    Plot true vs. estimated MI (top row) and dMI/drho (bottom row), one column per d.

    Reads ``true_{d}.npz``, ``I_js_{d}.npz`` (the MINE-f / NWJ runs) and ``mige_{d}.npz``
    from the current directory. Runs ``0 .. num_lines-2`` are drawn as faint lines and run
    ``num_lines-1`` as the labelled line, so ``num_lines`` must not exceed the number of
    runs stored in the files.

    Args:
        ds: list of dimensions.
        num_lines: number of repeated estimator runs to draw.
        save_path: output PDF path (default ``'./toy.pdf'``).
    """
    n_cols = len(ds)
    fig, axes = plt.subplots(2, n_cols, figsize=(3.8 * n_cols, 3.6 * 2), squeeze=False)

    for i, d in enumerate(ds):
        true_r, true_mi, true_g = _load_npz_arrays('true_{}.npz'.format(d), 'rho', 'mi', 'gradient')
        mine_r, mine_mi, mine_g = _load_npz_arrays('I_js_{}.npz'.format(d), 'rho', 'mi', 'gradient')
        mige_r, mige_g = _load_npz_arrays('mige_{}.npz'.format(d), 'rho', 'gradient')
        # The former extra loops over `mine_f_r` / `mine_f_mi` / `mine_f_g` referenced
        # undefined names (NameError whenever num_lines > 1). The MINE-f runs are the
        # I_js arrays already plotted here.

        # Top row: mutual information.
        ax = axes[0, i]
        for j in range(num_lines - 1):
            ax.plot(mine_r, mine_mi[j], color='C1', alpha=0.1, linewidth=1.5)
        ax.plot(true_r, true_mi, label=r'True MI', linewidth=1.5)
        ax.plot(mine_r, mine_mi[num_lines - 1], label=r'MINE', color='C1', linewidth=1.5)
        ax.set_title('d = {}'.format(d), fontsize=16)
        ax.grid(True, linestyle='-.')
        ax.tick_params(labelbottom=False)
        if i == 0:
            ax.set_ylabel('Mutual Information', fontsize=16)
            ax.legend(loc='upper left', prop={'size': 12})

        # Bottom row: gradient with respect to rho.
        ax = axes[1, i]
        for j in range(num_lines - 1):
            ax.plot(mine_r, mine_g[j], color='C1', alpha=0.1, linewidth=1.5)
        for j in range(num_lines - 1):
            ax.plot(mige_r, mige_g[j], color='C3', alpha=0.1, linewidth=1.5)
        ax.plot(true_r, true_g, label='True Gradient', linewidth=1.5)
        ax.plot(mine_r, mine_g[num_lines - 1], label=r'$\nabla_{\rho}$ MINE', color='C1', linewidth=1.5)
        ax.plot(mige_r, mige_g[num_lines - 1], label='MIGE (ours)', color='C3', linewidth=1.5)
        ax.set_xlabel(r'$\rho$', fontsize=16)
        ax.grid(True, linestyle='-.')
        if i == 0:
            ax.set_ylabel('Gradient', fontsize=16)
            ax.legend(loc='upper left', prop={'size': 12})

    fig.tight_layout()
    fig.savefig(save_path, format='pdf')
    plt.show()

In [None]:
# Render the comparison figure from the .npz files written by the experiment cell
# (Draw also saves it as ./toy.pdf).
Draw(ds, num_lines)