In [7]:
import torch
import numpy as np
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
from scipy.stats import multivariate_normal as mv
import matplotlib.pyplot as plt

In [8]:
# n_IW_samples =10
# m1 = 2
# m2 = 5
# var=5
# x,y= sample_proposal(m1, m2, var, n_IW_samples)

In [9]:
# x = torch.ones(64,32)
# y = torch.ones(10,32)
# (x@y.T).size()

In [10]:
## Inputs: G, mu1, var1, mu2, var2
## Outputs: z, W, KL

In [11]:
# def sample_proposal(self,mean1,mean2,var1,var2, n_IW_samples, device=_device):
#     diag1 = torch.eye(self.latent_dim1)
#     diag2 = torch.eye(self.latent_dim2)
#     mn1 = torch.distributions.MultivariateNormal(mean1, diag1)
#     mn2 = torch.distributions.MultivariateNormal(mean2, diag2)
 
#     return [mn1.sample([n_IW_samples,self.batch_size]).to(device), mn2.sample([n_IW_samples,self.batch_size]).to(device)]

In [12]:
# mean1= torch.randn(10).to(_device)
# diag1 = torch.eye(10).to(_device)
# mn1 = torch.distributions.MultivariateNormal(mean1, diag1)
# mn1.sample([5,5]).to(_device)

In [13]:
class importance_sampler():
    """Self-normalised importance sampler for a pair of latent vectors (z1, z2).

    Samples are drawn from a factorised diagonal-Gaussian proposal
    (``sample_proposal``), scored under an unnormalised log-target
    (``target_dist``) and under the proposal log-density (``proposal_dist``),
    and the resulting log-weights are normalised over the importance-sample
    axis (dim 0).

    Args:
        latent_dim1: dimensionality of z1.
        latent_dim2: dimensionality of z2.
        batch_size: size of the second sample dimension drawn by
            ``sample_proposal``.
    """

    def __init__(self, latent_dim1, latent_dim2, batch_size):
        self.latent_dim1 = latent_dim1
        self.latent_dim2 = latent_dim2
        self.batch_size  = batch_size

    def sample_proposal(self, mean1, mean2, var1, var2, n_IW_samples, device=_device):
        """Draw z1 ~ N(mean1, var1 * I) and z2 ~ N(mean2, var2 * I).

        Args:
            mean1, mean2: proposal means (assumed 1-D of length latent_dim1 /
                latent_dim2 -- a batched mean adds an extra sample dimension).
            var1, var2: scalar or per-dimension variances; ``eye * var``
                places them on the covariance diagonal.
            n_IW_samples: number of importance samples.
            device: device the returned samples are moved to.

        Returns:
            [z1, z2] of shapes [n_IW_samples, batch_size, latent_dim1] and
            [n_IW_samples, batch_size, latent_dim2] for 1-D means.
        """
        diag1 = torch.eye(self.latent_dim1).to(device)
        diag2 = torch.eye(self.latent_dim2).to(device)
        mn1 = torch.distributions.MultivariateNormal(mean1, diag1 * var1)
        mn2 = torch.distributions.MultivariateNormal(mean2, diag2 * var2)
        sample_shape = [n_IW_samples, self.batch_size]
        return [mn1.sample(sample_shape).to(device), mn2.sample(sample_shape).to(device)]

    def proposal_dist(self, mean1, mean2, var1, var2, z1, z2, device=_device):
        """Log-density of the diagonal-Gaussian proposal used by ``sample_proposal``.

        Returns ``-0.5 * sum_j (z_j - mean_j)**2 / var_j`` summed over both latent
        blocks, i.e. log N(z; mean, diag(var)) without the log-determinant and
        2*pi terms. Those terms do not depend on z, so when mean/var do not vary
        along the importance-sample axis they cancel in the self-normalisation
        performed by ``calc``.

        Args:
            mean1, mean2: proposal means, broadcast against z1 / z2.
            var1, var2: scalar or per-dimension variances, broadcast against
                z1 / z2.
            z1, z2: samples, e.g. [n_IW_samples, batch_size, latent_dim{1,2}].
            device: unused; kept for backward compatibility. The computation
                runs on the device of the inputs.

        Returns:
            Tensor of shape z1.shape[:-1], e.g. [n_IW_samples, batch_size].
        """
        # For a diagonal covariance the Mahalanobis form is an element-wise
        # divide-and-sum; no dense inverse-covariance matrix is needed. The
        # residual is squared once (not to the 4th power) and scaled by 1/2 so
        # this matches the density the samples are actually drawn from.
        # Handling each block separately lets scalar and vector variances
        # broadcast alike.
        quad1 = ((z1 - mean1) ** 2 / var1).sum(-1)
        quad2 = ((z2 - mean2) ** 2 / var2).sum(-1)
        log_p_x = -0.5 * (quad1 + quad2)
        return log_p_x

    def target_dist(self, G, z1, z2, mu1, var1, mu2, var2):
        """Unnormalised log-target log t(z1, z2).

        With element-wise squares and G split into four blocks G11..G22::

            log t = -|z1|^2 - |z2|^2
                    + z1^T G11 z2 + z1^T G12 z2^2
                    + (z1^2)^T G21 z2 + (z1^2)^T G22 z2^2
                    + mu1 . z1 + var1 . z1^2 + mu2 . z2 + var2 . z2^2

        Args:
            G: coupling matrix; for the block products below to be
                shape-compatible it must be [2*latent_dim1, 2*latent_dim2].
            z1, z2: samples, e.g. [n_IW_samples, batch_size, latent_dim{1,2}].
            mu1, var1, mu2, var2: linear and quadratic coefficients broadcast
                against z1 / z2 (mu1 noted as [batch_size, latent_dim1]).

        Returns:
            Tensor of shape z1.shape[:-1], e.g. [n_IW_samples, batch_size].
        """
        g11 = G[:self.latent_dim1, :self.latent_dim2]   # [latent_dim1, latent_dim2]
        g12 = G[:self.latent_dim1, self.latent_dim2:]   # [latent_dim1, latent_dim2]
        g21 = G[self.latent_dim1:, :self.latent_dim2]   # [latent_dim1, latent_dim2]
        g22 = G[self.latent_dim1:, self.latent_dim2:]   # [latent_dim1, latent_dim2]
        z1_sq = z1 ** 2
        z2_sq = z2 ** 2
        z_sqd = -z1_sq.sum(-1) - z2_sq.sum(-1)          # [n_IW_samples, batch_size]
        # Couplings between (z1, z1^2) and (z2, z2^2); each is summed over the
        # latent axis, giving [n_IW_samples, batch_size].
        h1 = (z1 @ g11 * z2).sum(-1)
        h2 = (z1 @ g12 * z2_sq).sum(-1)
        h3 = (z1_sq @ g21 * z2).sum(-1)
        h4 = (z1_sq @ g22 * z2_sq).sum(-1)
        h = h1 + h2 + h3 + h4
        # Data-dependent linear + quadratic terms, [n_IW_samples, batch_size].
        lin1 = (mu1 * z1 + var1 * z1_sq).sum(-1)
        lin2 = (mu2 * z2 + var2 * z2_sq).sum(-1)
        log_t_x = z_sqd + h + (lin1 + lin2)
        return log_t_x

    @staticmethod
    def _self_normalise(log_w):
        """Exponentiate log-weights after normalising them to sum to 1 over dim 0."""
        return torch.exp(log_w - torch.logsumexp(log_w, 0))

    def calc(self, G, mu1, var1, mu2, var2, n_IW_samples, mu3, var3, mu4, var4):
        """Draw prior/posterior importance samples and their normalised weights.

        Both sample sets are drawn independently from the same proposal
        N(mu3, var3 * I) x N(mu4, var4 * I). The "prior" target uses zero
        linear/quadratic coefficients; the "posterior" target uses
        (mu1, var1, mu2, var2).

        Returns:
            (z1_prior, z2_prior, z1_posterior, z2_posterior,
             IS_weights_prior, IS_weights_post). The weights have shape
            [n_IW_samples, batch_size] for 1-D proposal means and sum to 1
            over dim 0 (the importance-sample axis).
        """
        # Prior samples are drawn before posterior samples (fixes RNG order).
        z1_prior, z2_prior = self.sample_proposal(mu3, mu4, var3, var4, n_IW_samples)
        z1_posterior, z2_posterior = self.sample_proposal(mu3, mu4, var3, var4, n_IW_samples)

        t_x_prior = self.target_dist(
            G, z1_prior, z2_prior,
            torch.zeros_like(mu1), torch.zeros_like(var1),
            torch.zeros_like(mu2), torch.zeros_like(var2),
        )
        t_x_post = self.target_dist(G, z1_posterior, z2_posterior, mu1, var1, mu2, var2)
        p_x_prior = self.proposal_dist(mu3, mu4, var3, var4, z1_prior, z2_prior)
        p_x_post = self.proposal_dist(mu3, mu4, var3, var4, z1_posterior, z2_posterior)

        # log w = log t - log q, normalised over the importance-sample axis.
        IS_weights_prior = self._self_normalise(t_x_prior - p_x_prior)
        IS_weights_post = self._self_normalise(t_x_post - p_x_post)

        return z1_prior, z2_prior, z1_posterior, z2_posterior, IS_weights_prior, IS_weights_post

In [31]:
# x= torch.zeros(5)
# y = torch.ones(5,)
# (x*y).size()