In [2]:
import torch
import torch.nn as nn
# Notebook-wide configuration, read by the sampler defined further down.
# Widths of the two latent blocks z and z'.
latent_dim1, latent_dim2 = 1, 1
# Number of independent chains / samples drawn in parallel.
batch_size = 1
# Run on the GPU when one is visible to PyTorch, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

To sample from the prior we consider its conditional distributions,
<span class="math display">
    \begin{aligned}
    p(z|z')&\propto \exp\left(-(1-g_{22}z'^2)z^2+g_{11}z'z\right) \\
    p(z'|z)&\propto \exp\left(-(1-g_{22}z^2) {z'}^2+g_{11}zz'\right)
    \end{aligned}
</span>
The variance and mean of p(z|z') are given by,
<span class="math display">
    \begin{aligned}
    \sigma^2&=\frac{1}{2(1-g_{22}z'^2)}\\
    \mu &=\sigma^2*(g_{11}z')
    \end{aligned}
</span>
The variance and mean of p(z'|z) are given by,
<span class="math display">
    \begin{aligned}
    \sigma'^2&=\frac{1}{2(1-g_{22}z^2)}\\
    \mu' &=\sigma'^2*(g_{11}z)
    \end{aligned}
</span>
To implement Gibbs sampling on prior, define the following functions:
* var_calc_prior:
<span class="math display">
    \begin{aligned}
    \mathrm{Input}&:  z\, (\mathrm{or}\, z'),  g_{22} \\
    \mathrm{Output}&: \sigma^2\, (\mathrm{or}\, \sigma'^2)
    \end{aligned}
</span>
* mean_calc_prior:  
<span class="math display">
    \begin{aligned}
    \mathrm{Input}&: z\, (\mathrm{or}\, z'), \sigma^2\, (\mathrm{or}\, \sigma'^2), g_{11}\\
    \mathrm{Output}&: \mu\,(\mathrm{or}\, \mu')
    \end{aligned}
</span>
The algorithm is as follows:
* Initialize z and z'
* Iterate 5 times (5000 times when initializing the chain) and on each iteration calculate, drawing fresh standard-normal noise for every update,
<span class="math display">
    \begin{aligned}
    \sigma^2&= \mathrm{var\_calc\_prior}(z',g_{22})\\
    \mu   &= \mathrm{mean\_calc\_prior}(z',\sigma^2, g_{11})\\
    z    &= \mu+\sqrt{\sigma^2}\odot \epsilon\\
    \sigma'^2&= \mathrm{var\_calc\_prior}(z,g_{22})\\
    \mu'   &= \mathrm{mean\_calc\_prior}(z,\sigma'^2, g_{11})\\
    z' &= \mu'+\sqrt{\sigma'^2}\odot \epsilon'
    \end{aligned}
</span>

In [3]:
class gibbs_sampler():
    """Two-block Gibbs sampler for a pair of coupled latent variables z and z'.

    One sweep draws z from its Gaussian conditional given z', then draws z'
    from its Gaussian conditional given the freshly updated z.  Each
    conditional has

        variance = 1 / (2 * (1 - other**2 @ g22 [- lambda_2]))
        mean     = variance * (other @ g11 [+ lambda_1])

    where the bracketed terms are only present in the ``*_posterior`` methods.

    Shapes: the ``initialize_*`` methods create z as (batch_size, latent_dim1)
    and z' as (batch_size, latent_dim2); for the matrix products in a sweep to
    be valid, ``g11`` and ``g22`` must then be (latent_dim1, latent_dim2).
    The lambda terms are combined elementwise, so they must broadcast against
    the matching conditional (lambda_* with z, lambdap_* with z').

    The ``initialize_*`` methods read the module-level names ``batch_size``,
    ``latent_dim1``, ``latent_dim2`` and ``device``.

    No check is made that the precision term is positive; if it is not, the
    reciprocal / square root below yield inf or NaN samples.
    """
    def __init__(self):
        pass

    def _draw(self,mean,var):
        """Return one reparameterised Gaussian draw: mean + sqrt(var) * eps."""
        # The .float() cast before the square root is kept from the original
        # implementation; the noise is drawn with var's own dtype and device.
        return mean+torch.sqrt(var.float())*torch.randn_like(var)

    #Gibbs sampling p(z,z')
    def var_calc_prior(self,z,g22):
        """Conditional variance 1 / (2 * (1 - z**2 @ g22)) of the other block.

        Args:
            z: values of the block being conditioned on.
            g22: coupling matrix, matmul-compatible with ``z``.

        Returns:
            Tensor of variances, shaped like ``z @ g22``.
        """
        val   = 2*(1-torch.matmul(torch.square(z),g22))
        return torch.reciprocal(val)

    def mean_calc_prior(self,z,var,g11):
        """Conditional mean ``var * (z @ g11)`` of the other block.

        Args:
            z: values of the block being conditioned on.
            var: conditional variance, e.g. from ``var_calc_prior``.
            g11: coupling matrix, matmul-compatible with ``z``.

        Returns:
            Tensor of means with the broadcast shape of ``var`` and ``z @ g11``.
        """
        beta = torch.matmul(z,g11)
        return var*beta

    def initialize_prior_sample(self,g11,g22,n_steps=5000):
        """Start a prior chain from standard-normal noise and burn it in.

        Args:
            g11, g22: coupling matrices (see class docstring).
            n_steps: number of burn-in sweeps; defaults to 5000.

        Returns:
            Tuple ``(z_prior, zp_prior)`` after ``n_steps`` sweeps.
        """
        # Create the start state directly on the target device instead of
        # building it on the CPU and copying it over.
        z_prior  = torch.randn(batch_size,latent_dim1,device=device)       ## For estimating z  in  p(z |z')
        zp_prior = torch.randn(batch_size,latent_dim2,device=device)       ## For estimating zp  in p(z'|z)
        return self.prior_sample(z_prior,zp_prior,g11,g22,n_steps=n_steps)

    def prior_sample(self,z_prior,zp_prior,g11,g22,n_steps=5):
        """Advance a prior chain by ``n_steps`` Gibbs sweeps.

        Args:
            z_prior: current z.  It is overwritten by the first sweep, so its
                value only matters when ``n_steps`` is 0.
            zp_prior: current z'; the first sweep conditions on it.
            g11, g22: coupling matrices (see class docstring).
            n_steps: number of sweeps; defaults to 5.

        Returns:
            Tuple ``(z_prior, zp_prior)`` holding the updated state.
        """
        # The transposed couplings do not change between sweeps.
        g11_t = torch.transpose(g11,0,1)
        g22_t = torch.transpose(g22,0,1)
        for _ in range(n_steps):
            # z | z'
            var      = self.var_calc_prior(zp_prior,g22_t)
            mean     = self.mean_calc_prior(zp_prior,var,g11_t)
            z_prior  = self._draw(mean,var)
            # z' | z, using the z just drawn
            varp     = self.var_calc_prior(z_prior,g22)
            meanp    = self.mean_calc_prior(z_prior,varp,g11)
            zp_prior = self._draw(meanp,varp)
        return z_prior,zp_prior

    #Gibbs sampling q(z,z'|x,x')
    def var_calc_posterior(self,z,g22,lambda_2):
        """Conditional variance 1 / (2 * (1 - z**2 @ g22 - lambda_2)).

        Args:
            z: values of the block being conditioned on.
            g22: coupling matrix, matmul-compatible with ``z``.
            lambda_2: term subtracted from the precision; must broadcast
                against ``z @ g22``.

        Returns:
            Tensor of variances.
        """
        val   = 2*(1-torch.matmul(torch.square(z),g22)-lambda_2)
        return torch.reciprocal(val)

    def mean_calc_posterior(self,z,var,g11,lambda_1):
        """Conditional mean ``var * (z @ g11 + lambda_1)``.

        Args:
            z: values of the block being conditioned on.
            var: conditional variance, e.g. from ``var_calc_posterior``.
            g11: coupling matrix, matmul-compatible with ``z``.
            lambda_1: term added to the linear coefficient; must broadcast
                against ``z @ g11``.

        Returns:
            Tensor of means.
        """
        beta = torch.matmul(z,g11)+lambda_1
        return var*beta

    def initialize_posterior_sample(self,g11,g22,lambda_1,lambda_2,lambdap_1,lambdap_2,n_steps=5000):
        """Start a posterior chain from standard-normal noise and burn it in.

        Args:
            g11, g22: coupling matrices (see class docstring).
            lambda_1, lambda_2: extra terms for the conditional of z.
            lambdap_1, lambdap_2: extra terms for the conditional of z'.
            n_steps: number of burn-in sweeps; defaults to 5000.

        Returns:
            Tuple ``(z_posterior, zp_posterior)`` after ``n_steps`` sweeps.
        """
        z_posterior = torch.randn(batch_size,latent_dim1,device=device)    ## For estimating z  in  q(z |z',x)
        zp_posterior= torch.randn(batch_size,latent_dim2,device=device)    ## For estimating z' in  q(z'|z,x)
        return self.posterior_sample(z_posterior,zp_posterior,g11,g22,
                                     lambda_1,lambda_2,lambdap_1,lambdap_2,n_steps=n_steps)

    def posterior_sample(self,z_posterior,zp_posterior,g11,g22,lambda_1,lambda_2,lambdap_1,lambdap_2,n_steps=5):
        """Advance a posterior chain by ``n_steps`` Gibbs sweeps.

        Args:
            z_posterior: current z.  It is overwritten by the first sweep, so
                its value only matters when ``n_steps`` is 0.
            zp_posterior: current z'; the first sweep conditions on it.
            g11, g22: coupling matrices (see class docstring).
            lambda_1, lambda_2: extra terms for the conditional of z.
            lambdap_1, lambdap_2: extra terms for the conditional of z'.
            n_steps: number of sweeps; defaults to 5.

        Returns:
            Tuple ``(z_posterior, zp_posterior)`` holding the updated state.
        """
        # The transposed couplings do not change between sweeps.
        g11_t = torch.transpose(g11,0,1)
        g22_t = torch.transpose(g22,0,1)
        for _ in range(n_steps):
            # z | z'
            var          = self.var_calc_posterior(zp_posterior,g22_t,lambda_2)
            mean         = self.mean_calc_posterior(zp_posterior,var,g11_t,lambda_1)
            z_posterior  = self._draw(mean,var)
            # z' | z, using the z just drawn
            varp         = self.var_calc_posterior(z_posterior,g22,lambdap_2)
            meanp        = self.mean_calc_posterior(z_posterior,varp,g11,lambdap_1)
            zp_posterior = self._draw(meanp,varp)
        return z_posterior,zp_posterior