In [1]:
import torch
import numpy as np
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
from scipy.stats import multivariate_normal as mv
import matplotlib.pyplot as plt

In [2]:
# n_IW_samples =10
# m1 = 2
# m2 = 5
# var=5
# x,y= sample_proposal(m1, m2, var, n_IW_samples)

In [3]:
# x = torch.ones(64,32)
# y = torch.ones(10,32)
# (x@y.T).size()

In [4]:
## Input:  G, mu1, var1, mu2, var2
## Output: z, W, KL

In [5]:
class importance_sampler():
    """Self-normalised importance sampler over two latent blocks ``z1`` and ``z2``.

    Samples are drawn from a zero-mean isotropic Gaussian proposal with
    covariance ``var * I`` and re-weighted towards the unnormalised target

        t(z1, z2) = exp(-|z1|^2 - |z2|^2 + h(z1, z2; G) + d(z1, z2; mu, var))

    where ``h`` couples the two blocks through ``G`` and ``d`` is the
    per-batch-element term built from ``mu1, var1, mu2, var2``.

    Shape conventions used throughout:
        z1: [n_IW_samples, latent_dim1]      z2: [n_IW_samples, latent_dim2]
        mu1, var1: [batch_size, latent_dim1] mu2, var2: [batch_size, latent_dim2]
        G: sliced into four [latent_dim1, latent_dim2] blocks, i.e. it is
           expected to be [2 * latent_dim1, 2 * latent_dim2].
    """

    def __init__(self, latent_dim1, latent_dim2, batch_size):
        """Store the two latent sizes and the batch size used to tile the proposal."""
        self.latent_dim1 = latent_dim1
        self.latent_dim2 = latent_dim2
        self.batch_size = batch_size

    def sample_proposal(self, var, n_IW_samples, device=None):
        """Draw ``n_IW_samples`` points per latent block from N(0, var * I).

        Args:
            var: scalar variance of the isotropic proposal.
            n_IW_samples: number of importance samples to draw.
            device: where to place the samples. ``None`` (default) resolves to
                the notebook-level ``_device`` at call time.

        Returns:
            ``[z1, z2]`` with shapes ``[n_IW_samples, latent_dim1]`` and
            ``[n_IW_samples, latent_dim2]``.
        """
        if device is None:
            # Looked up lazily so the class does not capture a stale value at
            # definition time.
            device = _device
        mn1 = torch.distributions.MultivariateNormal(
            torch.zeros(self.latent_dim1), var * torch.eye(self.latent_dim1))
        mn2 = torch.distributions.MultivariateNormal(
            torch.zeros(self.latent_dim2), var * torch.eye(self.latent_dim2))
        return [mn1.sample([n_IW_samples,]).to(device),
                mn2.sample([n_IW_samples,]).to(device)]

    def _log_proposal_dist(self, z1, z2, var):
        """Log-density of the N(0, var * I) proposal at (z1, z2); shape [n_IW_samples]."""
        dim = self.latent_dim1 + self.latent_dim2
        z_sqd = -(z1 ** 2).sum(-1) - (z2 ** 2).sum(-1)
        # Plain Python float so the constant never drags the tensor through numpy.
        log_norm = -0.5 * dim * float(np.log(2 * np.pi * float(var)))
        # The Gaussian exponent is -|z|^2 / (2 * var); the factor 2 matters because
        # the samples really are drawn with covariance var * I.
        return log_norm + z_sqd / (2 * var)

    def proposal_dist(self, z1, z2, var):
        """Density of the N(0, var * I) proposal, tiled over the batch.

        Returns:
            Tensor of shape ``[batch_size, n_IW_samples]`` (every row identical).
        """
        p_x = torch.exp(self._log_proposal_dist(z1, z2, var))
        return p_x.repeat(self.batch_size, 1)

    def _log_target_dist(self, G, z1, z2, mu1, var1, mu2, var2):
        """Log of the unnormalised target; shape [batch_size, n_IW_samples]."""
        # mu1: [batch_size,latent_dim1], z1: [n_IW_samples,latent_dim1]
        d1, d2 = self.latent_dim1, self.latent_dim2
        g11 = G[:d1, :d2]  # [latent_dim1, latent_dim2]
        g12 = G[:d1, d2:]  # [latent_dim1, latent_dim2]
        g21 = G[d1:, :d2]  # [latent_dim1, latent_dim2]
        g22 = G[d1:, d2:]  # [latent_dim1, latent_dim2]
        z1_sq = z1 ** 2
        z2_sq = z2 ** 2
        z_sqd = -z1_sq.sum(-1) - z2_sq.sum(-1)            # [n_IW_samples]
        # Cross terms between (z1, z1^2) and (z2, z2^2), one G block each.
        h = ((z1 @ g11 * z2).sum(-1)
             + (z1 @ g12 * z2_sq).sum(-1)
             + (z1_sq @ g21 * z2).sum(-1)
             + (z1_sq @ g22 * z2_sq).sum(-1))             # [n_IW_samples]
        # Per-batch-element terms, linear in z and z^2.
        d = (mu1 @ z1.T + var1 @ z1_sq.T) + (mu2 @ z2.T + var2 @ z2_sq.T)
        return z_sqd + h + d                              # [batch_size, n_IW_samples]

    def target_dist(self, G, z1, z2, mu1, var1, mu2, var2):
        """Unnormalised target density at (z1, z2) for every batch element.

        Returns:
            Tensor of shape ``[batch_size, n_IW_samples]``. May overflow to
            ``inf`` for large exponents; ``calc`` works in log-space instead.
        """
        return torch.exp(self._log_target_dist(G, z1, z2, mu1, var1, mu2, var2))

    def KL_calculator(self, weights, p_x, t_x):
        """Placeholder KL term: returns the constant tensor ``[1]``.

        The arguments are not used yet; the result is placed on the device of
        ``weights`` (or on ``_device`` when ``weights`` is not a tensor).
        """
        device = weights.device if torch.is_tensor(weights) else _device
        KLD = torch.tensor([1]).to(device)
        return KLD

    @staticmethod
    def _normalise_log_weights(log_w):
        """Turn log-weights [batch, n] into weights that sum to 1 along dim 1."""
        return torch.exp(log_w - torch.logsumexp(log_w, 1, keepdim=True))

    def calc(self, G, mu1, var1, mu2, var2, n_IW_samples):
        """Sample the proposal and compute self-normalised importance weights.

        Two independent sample sets are drawn: one weighted towards the
        "prior" target (all of ``mu*``/``var*`` replaced by zeros) and one
        towards the "posterior" target (the given ``mu*``/``var*``).

        Returns:
            ``(z1_prior, z2_prior, z1_posterior, z2_posterior,
            IS_weights_prior, IS_weights_post, KLD)`` where the samples are
            ``[n_IW_samples, latent_dim*]`` and the weights are
            ``[batch_size, n_IW_samples]`` with rows summing to 1.
        """
        proposal_var = 0.1
        # Put the samples where the inputs live so the matmuls below never mix devices.
        device = mu1.device
        z1_prior, z2_prior = self.sample_proposal(proposal_var, n_IW_samples, device)
        z1_posterior, z2_posterior = self.sample_proposal(proposal_var, n_IW_samples, device)

        # Everything stays in log-space: exp() followed by log() overflows or
        # underflows in float32 and turns the weights into NaN.
        log_t_prior = self._log_target_dist(
            G, z1_prior, z2_prior,
            torch.zeros_like(mu1), torch.zeros_like(var1),
            torch.zeros_like(mu2), torch.zeros_like(var2))
        log_t_post = self._log_target_dist(G, z1_posterior, z2_posterior, mu1, var1, mu2, var2)
        log_p_prior = self._log_proposal_dist(z1_prior, z2_prior, proposal_var)        # [n_IW_samples]
        log_p_post = self._log_proposal_dist(z1_posterior, z2_posterior, proposal_var)  # [n_IW_samples]

        # [batch_size, n_IW_samples] - [n_IW_samples] broadcasts over the batch.
        IS_weights_prior = self._normalise_log_weights(log_t_prior - log_p_prior)
        IS_weights_post = self._normalise_log_weights(log_t_post - log_p_post)

        # KL_calculator keeps receiving plain densities, shaped as before.
        p_x_post = torch.exp(log_p_post).repeat(self.batch_size, 1)
        t_x_post = torch.exp(log_t_post)
        KLD = self.KL_calculator(IS_weights_post, p_x_post, t_x_post)
        return z1_prior, z2_prior, z1_posterior, z2_posterior, IS_weights_prior, IS_weights_post, KLD

In [33]:
# x = torch.randn(15)
# x = x.repeat(10, 1)
# x.size()

torch.Size([10, 15])