# In [1]:
import torch
import torch.nn as nn

# In [2]:
# Widths of the two coupled latent variables z (dim 1) and z' (dim 2), and the
# number of parallel chains gibbs_sampler.calc draws per call.
latent_dim1, latent_dim2 = 1, 1
batch_size = 128
# Use the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

class gibbs_sampler():
    """Gibbs sampler for a pair of coupled Gaussian latent variables z and z'.

    Each conditional is Gaussian: z | z' ~ N(beta, alpha) with
    alpha = f(z', var, g22) and beta = g(alpha, z', mu, g11), and symmetrically
    z' | z ~ N(betap, alphap) using (varp, mup).  ``calc`` runs two chains side
    by side: one conditioned on the supplied statistics (mu, var, mup, varp)
    and one with all of those statistics set to zero (the prior chain).

    Attributes:
        g11: learnable (latent_dim1, latent_dim1) coupling used in the
            conditional means.
        g22: learnable (latent_dim2, latent_dim2) coupling whose exponential
            enters the conditional precisions.
        g21: fixed all-zero (latent_dim2, latent_dim1) off-diagonal block of G.

    NOTE(review): ``f`` multiplies z (latent_dim1 wide) by the latent_dim2-sized
    ``g22`` and ``g`` multiplies z' (latent_dim2 wide) by the latent_dim1-sized
    ``g11``, so the sampler as written only works when
    latent_dim1 == latent_dim2.
    """

    def __init__(self):
        # Create each tensor on the target device first and only then wrap it
        # in nn.Parameter.  Calling .to(device) on an existing Parameter returns
        # a plain non-leaf copy whenever the device actually changes (e.g. to
        # CUDA); an optimizer rejects such a tensor and its .grad never fills.
        # Values are still drawn with the CPU generator, so seeded
        # initialisation is unchanged.
        self.g11 = nn.Parameter(torch.randn(latent_dim1, latent_dim1).to(device))
        self.g22 = nn.Parameter(torch.randn(latent_dim2, latent_dim2).to(device))
        # Off-diagonal coupling block: held at zero and not trained.
        self.g21 = torch.zeros(latent_dim2, latent_dim1, device=device)

    def calc(self, mu, var, mup, varp, n_iter=1, n_samples=None):
        """Sample (z, z') from the conditioned chain and from the prior chain.

        Args:
            mu: added to the linear term of the z | z' conditional mean
                (see ``g``); must broadcast against (n_samples, latent_dim1).
            var: added inside the z | z' precision (see ``f``); same
                broadcasting requirement.
            mup, varp: the same roles for the z' | z conditional.
            n_iter: number of Gibbs sweeps to run (default 1). With 0 the
                initial N(0, I) draws are returned unchanged.
            n_samples: number of parallel chains. Defaults to the module-level
                ``batch_size``; pass e.g. ``mu.shape[0]`` for a final batch
                that is smaller than ``batch_size``.

        Returns:
            Tuple ``(G, z0, zp0, z_prior, zp_prior)``:
            - G is the (latent_dim1 + latent_dim2) square coupling matrix.
            - z0 / zp0 are the conditioned-chain samples of z / z'.
            - z_prior / zp_prior are the prior-chain samples of z / z'.
        """
        if n_samples is None:
            n_samples = batch_size

        # Initialise every chain from N(0, I).  z0 and z_prior are overwritten
        # by the first sweep before being read.  They are still drawn so that
        # the RNG stream (and the n_iter=0 result) matches the original code.
        z0 = torch.randn(n_samples, latent_dim1).to(device)        # z  in q(z | z', x)
        zp0 = torch.randn(n_samples, latent_dim2).to(device)       # z' in q(z' | z, x)
        z_prior = torch.randn(n_samples, latent_dim1).to(device)   # z  in p(z | z')
        zp_prior = torch.randn(n_samples, latent_dim2).to(device)  # z' in p(z' | z)

        for _ in range(n_iter):
            z0, zp0 = self._gibbs_sweep(zp0, mu, var, mup, varp)
            # The prior chain uses the same conditionals with every statistic zeroed.
            z_prior, zp_prior = self._gibbs_sweep(
                zp_prior,
                torch.zeros_like(mu), torch.zeros_like(var),
                torch.zeros_like(mup), torch.zeros_like(varp),
            )

        # Assemble the symmetric block matrix [[g11, g21^T], [g21, g22]].
        # g21.t() gives the upper-right block its required
        # (latent_dim1, latent_dim2) shape.  For equal latent sizes and the
        # all-zero g21 the result equals the untransposed construction.
        G1 = torch.cat((self.g11, self.g21), 0)
        G2 = torch.cat((self.g21.t(), self.g22), 0)
        G = torch.cat((G1, G2), 1)
        return G, z0, zp0, z_prior, zp_prior

    def _gibbs_sweep(self, zp, mu, var, mup, varp):
        """Run one sweep: sample z | z', then z' | z given the fresh z."""
        alpha = self.f(zp, var, self.g22)
        beta = self.g(alpha, zp, mu, self.g11)
        z = beta + torch.sqrt(alpha.float()) * torch.randn_like(beta)      # z  ~ N(beta, alpha)
        alphap = self.f(z, varp, self.g22)
        betap = self.g(alphap, z, mup, self.g11)
        zp = betap + torch.sqrt(alphap.float()) * torch.randn_like(betap)  # z' ~ N(betap, alphap)
        return z, zp

    def f(self, zp, var, g22):
        """Return the conditional variance alpha of one latent given the other.

        alpha = 1 / (2 * (1 + zp**2 @ exp(g22) + var))

        Args:
            zp: the conditioning latent sample, shape (n_samples, latent dim).
            var: extra term added inside the precision, broadcast against the
                result (the original notebook describes it as the MNIST
                variance; NOTE(review): confirm with the caller).
            g22: square coupling matrix; it is exponentiated, so its
                contribution zp**2 @ exp(g22) is always non-negative.
        """
        val = 2 * (1 + torch.matmul(torch.square(zp), torch.exp(g22)) + var)
        alpha = torch.reciprocal(val)
        return alpha

    def g(self, alpha, zp, mu, g11):
        """Return the conditional mean beta = alpha * (zp @ g11 + mu).

        Args:
            alpha: conditional variance returned by ``f``.
            zp: the conditioning latent sample, shape (n_samples, latent dim).
            mu: additive term in the linear part (the original notebook
                describes it as the MNIST mean; NOTE(review): confirm).
            g11: square coupling matrix.
        """
        beta = torch.matmul(zp, g11) + mu
        beta = alpha * beta
        return beta