In [2]:
import torch
import torch.nn as nn
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
## Gibbs sampling from the prior

To sample from the prior we consider its conditional distributions,
<span class="math display">
    \begin{aligned}
    p(z|z')&\sim \mathrm{exp}\left(-(1-g_{22}z'^2)z^2+(g_{11}z')z\right) \\
    p(z'|z)&\sim \mathrm{exp}\left(-(1-g_{22}z^2) {z'}^2+(g_{11}z){z'}\right)
    \end{aligned}
</span>
The variance and mean of p(z|z') are given by,
<span class="math display">
    \begin{aligned}
    \sigma^2&=\frac{1}{2(1-g_{22}z'^2)}\\
    \mu &=\sigma^2*(g_{11}z')
    \end{aligned}
</span>
The variance and mean of p(z'|z) are given by,
<span class="math display">
    \begin{aligned}
    \sigma'^2&=\frac{1}{2(1-g_{22}z^2)}\\
    \mu' &=\sigma'^2*(g_{11}z)
    \end{aligned}
</span>
To implement Gibbs sampling on the prior, define the following functions:
* var_calc_prior:
<span class="math display">
    \begin{aligned}
    \mathrm{Input}&:  z\, (\mathrm{or}\, z'),  g_{22} \\
    \mathrm{Output}&: \sigma^2\, (\mathrm{or}\, \sigma'^2)
    \end{aligned}
</span>
* mean_calc_prior:  
<span class="math display">
    \begin{aligned}
    \mathrm{Input}&: z\, (\mathrm{or}\, z'), \sigma^2\, (\mathrm{or}\, \sigma'^2), g_{11},  g_{22}\\
    \mathrm{Output}&: \mu\,(\mathrm{or}\, \mu')
    \end{aligned}
</span>
The algorithm is as follows:
* Initialize z and z'
* Iterate 5 times and on each iteration calculate
<span class="math display">
    \begin{aligned}
    \sigma^2&= \mathrm{var\_calc\_prior}(z',g_{22})\\
    \mu   &= \mathrm{mean\_calc\_prior}(z',\sigma^2, g_{11},  g_{22})\\
    z    &= \mu+\sqrt{\sigma^2}\odot \epsilon\\
    \sigma'^2&= \mathrm{var\_calc\_prior}(z,g_{22})\\
    \mu'   &= \mathrm{mean\_calc\_prior}(z,\sigma'^2, g_{11},  g_{22})\\
    z' &= \mu'+\sqrt{\sigma'^2}\odot \epsilon
    \end{aligned}
</span>

In [3]:
class gibbs_sampler():
    """Block Gibbs sampler alternating between two Gaussian conditionals.

    Each conditional is Gaussian with

        variance = 1 / (2 * (1 - z_other**2 @ g22 - lambda_2))
        mean     = variance * (z_other @ g11 + lambda_1)

    where ``z_other`` is the block being conditioned on. With the lambda
    terms equal to zero these expressions reduce to the prior conditionals
    described in the notebook text.

    Expected shapes (B = batch_size, d1 = latent_dim1, d2 = latent_dim2):
        z1:                   [B, d1]
        z2:                   [B, d2]
        g11, g22:             [d1, d2]
        lambda_1, lambda_2:   [B, d1]
        lambdap_1, lambdap_2: [B, d2]
    """

    def __init__(self,latent_dim1, latent_dim2, batch_size):
        """Store the sizes used when the chain is initialised from noise."""
        self.latent_dim1 = latent_dim1
        self.latent_dim2 = latent_dim2
        self.batch_size = batch_size

    def var_calc(self,z,g22,lambda_2):
        """Return the conditional variance 1 / (2 * (1 - z**2 @ g22 - lambda_2))."""
        val   = 2*(1-torch.matmul(torch.square(z),g22)-lambda_2)
        return torch.reciprocal(val)

    def mean_calc(self,z,var,g11,lambda_1):
        """Return the conditional mean var * (z @ g11 + lambda_1)."""
        beta = torch.matmul(z,g11)+lambda_1
        return var*beta

    def value_calc(self,z,g11,g22,lambda_1,lambda_2):
        """Draw one sample from the Gaussian conditional given the block ``z``.

        The sample is ``mean + sqrt(var) * eps`` with ``eps`` standard normal
        noise of the same shape as the variance.
        """
        var1          = self.var_calc(z,g22,lambda_2)
        mean1         = self.mean_calc(z,var1,g11,lambda_1)
        # NOTE(review): if 1 - z**2 @ g22 - lambda_2 is not strictly positive,
        # var1 is negative or infinite and the sample becomes NaN/inf; callers
        # presumably constrain g22 / lambda_2 to prevent this -- confirm.
        out           = mean1+torch.sqrt(var1.float())*torch.randn_like(var1)
        return out

    def gibbs_sample(self,flag,z1,z2,g11,g22,lambda_1,lambda_2,lambdap_1,lambdap_2,n_iterations, return_mean_and_variance=False):
        """Run ``n_iterations`` sweeps of the two-block Gibbs chain.

        Args:
            flag: when equal to 1, the passed ``z1``/``z2`` are ignored and the
                chain starts from standard normal noise of shapes
                ``[batch_size, latent_dim1]`` and ``[batch_size, latent_dim2]``.
            z1, z2: initial states, used when ``flag != 1``.
            g11, g22: coupling matrices; their transposes are used when
                sampling ``z1`` from ``z2``.
            lambda_1, lambda_2: linear / quadratic terms for the ``z1`` update.
            lambdap_1, lambdap_2: linear / quadratic terms for the ``z2`` update.
            n_iterations: number of (z1, z2) sweeps; 0 returns the initial state.
            return_mean_and_variance: reserved; not implemented yet.

        Returns:
            Tuple ``(z1, z2)`` holding the final state of the chain.

        Raises:
            NotImplementedError: if ``return_mean_and_variance`` is truthy.
        """
        if return_mean_and_variance:
            # TODO: decide what to return here. The branch used to be empty,
            # which made the whole class definition a syntax error.
            raise NotImplementedError(
                "return_mean_and_variance=True is not implemented yet")

        if flag == 1:
            # Draw the initial state on the device of the coupling matrix it
            # will be multiplied with, instead of relying on a global `device`.
            z1 = torch.randn(self.batch_size,self.latent_dim1,device=g11.device)
            z2 = torch.randn(self.batch_size,self.latent_dim2,device=g11.device)       ## For estimating z' in  q(z'|z,x)

        # The transposes are loop-invariant views, so build them once.
        g11_t = torch.transpose(g11,0,1)
        g22_t = torch.transpose(g22,0,1)

        for _ in range(n_iterations):
            z1  = self.value_calc(z2,g11_t,g22_t,lambda_1,lambda_2)
            z2  = self.value_calc(z1,g11,g22,lambdap_1,lambdap_2)

        return z1,z2
    
"""
Size of z1:        [batch_size, latent_dim1]
Size of z2:        [batch_size, latent_dim2]
Size of g11:       [latent_dim1, latent_dim2]
Size of g22:       [latent_dim1, latent_dim2]
Size of lambda_1:  [batch_size, latent_dim1]
Size of lambda_2:  [batch_size, latent_dim1]
Size of lambdap_1: [batch_size, latent_dim2]
Size of lambdap_2: [batch_size, latent_dim2]

"""

To sample from the posterior we consider its conditional distributions,
<span class="math display">
    \begin{aligned}
    q(z|z',x)&\sim \mathrm{exp}\left(-(1-g_{22}z'^2-\lambda_2(x))z^2+(g_{11}z'+\lambda_1(x))z\right) \\
    q(z'|z,x')&\sim \mathrm{exp}\left(-(1-g_{22}z^2 -\lambda'_2(x')) {z'}^2+(g_{11}z+\lambda'_1(x')){z'}\right)
    \end{aligned}
</span>
The variance and mean of q(z|z',x) are given by,
<span class="math display">
    \begin{aligned}
    \sigma^2&=\frac{1}{2(1-g_{22}z'^2-\lambda_2(x))}\\
    \mu &=\sigma^2*(g_{11}z'+\lambda_1(x))
    \end{aligned}
</span>
The variance and mean of q(z'|z,x') are given by,
<span class="math display">
    \begin{aligned}
    \sigma'^2&=\frac{1}{2(1-g_{22}z^2-\lambda'_2(x'))}\\
    \mu' &=\sigma'^2*(g_{11}z+\lambda'_1(x'))
    \end{aligned}
</span>
To implement Gibbs sampling on the posterior, define the following functions:
* var_calc_posterior:
<span class="math display">
    \begin{aligned}
    \mathrm{Input}&:  z\, (\mathrm{or}\, z'), g_{22}, \lambda_2(x) (\mathrm{or}\, \lambda'_2(x')) \\
    \mathrm{Output}&: \sigma^2\, (\mathrm{or}\, \sigma'^2)
    \end{aligned}
</span>
* mean_calc_posterior:  
<span class="math display">
    \begin{aligned}
    \mathrm{Input}&: z\, (\mathrm{or}\, z'), \sigma^2\, (\mathrm{or}\, \sigma'^2), g_{11}, g_{22},\lambda_1(x) (\mathrm{or}\, \lambda'_1(x'))\\
    \mathrm{Output}&: \mu\,(\mathrm{or}\, \mu')
    \end{aligned}
</span>
The algorithm is as follows:
* Initialize z and z'
* Iterate 5 times and on each iteration calculate
<span class="math display">
    \begin{aligned}
    \sigma^2&= \mathrm{var\_calc\_posterior}(z',g_{22},\lambda_2(x))\\
    \mu   &= \mathrm{mean\_calc\_posterior}(z',\sigma^2, g_{11},  g_{22},\lambda_1(x))\\
    z    &= \mu+\sqrt{\sigma^2}\odot \epsilon\\
    \sigma'^2&= \mathrm{var\_calc\_posterior}(z,g_{22},\lambda'_2(x'))\\
    \mu'   &= \mathrm{mean\_calc\_posterior}(z,\sigma'^2, g_{11},  g_{22},\lambda'_1(x'))\\
    z' &= \mu'+\sqrt{\sigma'^2}\odot \epsilon
    \end{aligned}
</span>