In [1]:
import torch
import numpy as np
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
from scipy.stats import multivariate_normal as mv
import matplotlib.pyplot as plt

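Proposal distribution (a factorised diagonal Gaussian, as sampled in `sample_proposal` and evaluated in `proposal_dist`):

$$p(z_1, z_2) = \mathcal{N}(z_1;\, \mu_1, \Sigma_1)\,\mathcal{N}(z_2;\, \mu_2, \Sigma_2) \propto \exp\!\left[-\tfrac{1}{2}(z_1-\mu_1)^\top \Sigma_1^{-1}(z_1-\mu_1) - \tfrac{1}{2}(z_2-\mu_2)^\top \Sigma_2^{-1}(z_2-\mu_2)\right],$$

with diagonal covariances $\Sigma_i = \mathrm{diag}(\sigma_i^2)$.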

In [2]:
class importance_sampler():
    """Self-normalised importance sampler over a pair of latent vectors (z1, z2).

    Samples come from a factorised diagonal-Gaussian proposal
    q(z1, z2) = N(z1; mean1, diag(var1)) * N(z2; mean2, diag(var2))
    and are weighted by the unnormalised log-target in ``target_dist``.

    ``device`` arguments default to ``None``, meaning the notebook-level
    ``_device``. It is looked up at call time rather than captured when the
    class is defined.
    """

    def __init__(self, latent_dim1, latent_dim2, batch_size):
        self.latent_dim1 = latent_dim1  # size of z1; used to slice G in target_dist
        self.latent_dim2 = latent_dim2  # size of z2; used to slice G in target_dist
        self.batch_size = batch_size    # stored only; no method below reads it

    def sample_proposal(self, mean1, mean2, var1, var2, n_IW_samples, device=None):
        """Draw ``n_IW_samples`` samples from each diagonal-Gaussian proposal factor.

        Args:
            mean1, mean2: proposal means, e.g. [batch_size, latent_dim1] and
                [batch_size, latent_dim2].
            var1, var2: per-dimension variances (the diagonal of each
                covariance), with the same shapes as the corresponding means.
            n_IW_samples: number of importance samples per batch element.
            device: target device for the samples (``None`` -> ``_device``).

        Returns:
            ``[z1, z2]`` with shapes ``[n_IW_samples, *mean1.shape]`` and
            ``[n_IW_samples, *mean2.shape]``. The samples are drawn with
            ``.sample`` (not ``.rsample``), so no gradient flows back to the
            proposal parameters.
        """
        device = _device if device is None else device
        mn1 = torch.distributions.MultivariateNormal(mean1, torch.diag_embed(var1))
        mn2 = torch.distributions.MultivariateNormal(mean2, torch.diag_embed(var2))
        return [mn1.sample([n_IW_samples]).to(device),
                mn2.sample([n_IW_samples]).to(device)]

    def proposal_dist(self, mean1, mean2, var1, var2, z1, z2, device=None):
        """Log-density of the proposal, ``log q(z1) + log q(z2)``.

        Returns the log-density itself, not its negation. Importance
        log-weights are therefore ``target_dist(...) - proposal_dist(...)``.

        Args:
            mean1, mean2, var1, var2: as in ``sample_proposal``.
            z1, z2: points to evaluate, e.g. [n_IW_samples, batch_size, latent_dim].
            device: device on which to evaluate (``None`` -> ``_device``).
                Every input is moved there, so mixed CPU/GPU inputs do not clash.

        Returns:
            Tensor of shape ``z1.shape[:-1]`` (e.g. [n_IW_samples, batch_size]).
        """
        device = _device if device is None else device
        dist1 = torch.distributions.MultivariateNormal(
            mean1.to(device), torch.diag_embed(var1).to(device))
        dist2 = torch.distributions.MultivariateNormal(
            mean2.to(device), torch.diag_embed(var2).to(device))
        log_q = dist1.log_prob(z1.to(device)) + dist2.log_prob(z2.to(device))
        return log_q

    def target_dist(self, G, z1, z2, mu1, var1, mu2, var2):
        """Unnormalised log-target ``log t(z1, z2) = h + d``.

        ``h`` couples z1 and z2 through the four blocks of ``G`` acting on the
        features ``(z, z**2)``. ``d`` adds per-latent terms in which ``mu*``
        multiplies ``z`` and ``var*`` multiplies ``z**2``. Here ``var*`` acts
        as a quadratic coefficient, not as a variance.

        Args:
            G: coupling matrix. The products below expect shape
                [2*latent_dim1, 2*latent_dim2].
            z1: [n_IW_samples, batch_size, latent_dim1]; z2 likewise with latent_dim2.
            mu1, var1: [batch_size, latent_dim1]; mu2, var2: [batch_size, latent_dim2].

        Returns:
            Tensor of shape [n_IW_samples, batch_size].
        """
        n1, n2 = self.latent_dim1, self.latent_dim2
        g11 = G[:n1, :n2]  # couples z1    with z2
        g12 = G[:n1, n2:]  # couples z1    with z2**2
        g21 = G[n1:, :n2]  # couples z1**2 with z2
        g22 = G[n1:, n2:]  # couples z1**2 with z2**2
        z1_sq = z1 ** 2
        z2_sq = z2 ** 2
        h = ((z1 @ g11 * z2).sum(-1)
             + (z1 @ g12 * z2_sq).sum(-1)
             + (z1_sq @ g21 * z2).sum(-1)
             + (z1_sq @ g22 * z2_sq).sum(-1))      # [n_IW_samples, batch_size]
        d = ((mu1 * z1 + var1 * z1_sq).sum(-1)
             + (mu2 * z2 + var2 * z2_sq).sum(-1))  # [n_IW_samples, batch_size]
        return h + d

    def calc(self, G, mu1, var1, mu2, var2, n_IW_samples, mu3, var3, mu4, var4,
             correct_for_proposal=False, device=None):
        """Sample from the proposal and return normalised importance weights.

        Args:
            G, mu1, var1, mu2, var2: target parameters (see ``target_dist``).
            n_IW_samples: number of samples per batch element.
            mu3, var3, mu4, var4: proposal means/variances for z1 and z2.
            correct_for_proposal: if False (the default, matching the previous
                behaviour), the log-weights are ``target_dist`` alone. The
                weighted samples then describe a density proportional to
                ``q * exp(target_dist)``. If True, the log-weights are
                ``target_dist - proposal_dist``, i.e. standard self-normalised
                importance weights t/q.
            device: forwarded to the proposal methods (``None`` -> ``_device``).

        Returns:
            ``(z1, z2, weights)``. ``weights`` has shape [n_IW_samples, batch_size]
            and sums to 1 over the sample axis (dim 0).
        """
        z1_posterior, z2_posterior = self.sample_proposal(
            mu3, mu4, var3, var4, n_IW_samples, device=device)
        log_w = self.target_dist(G, z1_posterior, z2_posterior, mu1, var1, mu2, var2)
        if correct_for_proposal:
            log_w = log_w - self.proposal_dist(
                mu3, mu4, var3, var4, z1_posterior, z2_posterior, device=device)
        # softmax over dim 0 equals exp(log_w - logsumexp(log_w, 0)), computed stably.
        IS_weights_post = torch.softmax(log_w, dim=0)
        return z1_posterior, z2_posterior, IS_weights_post

In [3]:
# x= torch.zeros(5)
# y = torch.ones(5,)
# (x*y).size()

In [4]:
# def sample_proposal(mean1,mean2,var1,var2, n_IW_samples, device=_device):
#     var1_prop = torch.diag_embed(var1) 
#     var2_prop = torch.diag_embed(var2) 
#     mn1 = torch.distributions.MultivariateNormal(mean1, var1_prop )
#     mn2 = torch.distributions.MultivariateNormal(mean2, var2_prop )
#     return [mn1.sample([n_IW_samples,]).to(device), mn2.sample([n_IW_samples,]).to(device)]

In [5]:
# mean1 = torch.zeros(128,2)
# mean2 = torch.zeros(128,2)
# var1  = torch.ones(128,2)
# var2  = torch.ones(128,2)
# n_IW_samples = 5
# z1,z2 =sample_proposal(mean1,mean2,var1,var2,n_IW_samples)

In [6]:
# z1.size()

In [7]:
# z1_numpy = z1.cpu().numpy()

In [8]:
# z11 = np.reshape(z1_numpy,(3000,1))

In [9]:
# import matplotlib.pyplot as plt
# n, bins, patches = plt.hist(z11)

In [7]:
# import torch
# from torch.distributions import MultivariateNormal 

# means = torch.tensor([0.0538,
#         0.0651])
# stds = torch.tensor([[0.7865,0],
#         [0,0.7792]])

# dist = MultivariateNormal(means, stds)
# a = torch.tensor([1.2,3.4])
# d = dist.log_prob(a)
# print(d.size())

torch.Size([])


In [20]:
# z1 = torch.randn(10,128,1)
# mean = torch.zeros(128,1)
# stds = torch.ones(128,1)
# stds_diag = torch.diag_embed(stds)
# dist = MultivariateNormal(mean, stds_diag)
# d = dist.log_prob(z1)
# print(d.size())

torch.Size([10, 128])
