In [1]:
import torch

The KL divergence is estimated from a single posterior sample $(z^{(i)}, z'^{(i)})$ by:
<span class="math display">
    \begin{aligned}
    D_{KL}(q(\mathbf{z, z'|x, x'}) \Vert p(\mathbf{z, z'})) &= \mathrm{log}\left[\frac{\tilde{q}_{\phi}(z^{(i)},z'^{(i)}|x,x')}{\tilde{p}_{\theta}(z^{(i)},z'^{(i)})}\right]-\langle \mathbf{sg}( \mathbb{E}_q[T]),\lambda\rangle-\langle \mathbf{sg}( \mathbb{E}_q[T']),\lambda'\rangle+(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle
    \end{aligned}
</span>

<span class="math display">
    \begin{aligned}
    T_{prior}&=[z_{prior},z^2_{prior}]\\
    T'_{prior}&=[z'_{prior},z'^2_{prior}]\\
    T_{posterior}&=[z_{posterior},z^2_{posterior}]\\
    T'_{posterior}&=[z'_{posterior},z'^2_{posterior}]\\
    \lambda&=[\lambda_1,\lambda_2]\\
    \lambda'&=[\lambda'_1,\lambda'_2]\\
    T_{prior}^2&=(z^2_{prior}+z'^2_{prior})\\
    T_{posterior}^2&=(z^2_{posterior}+z'^2_{posterior})
    \end{aligned}
</span>
We also define,
<span class="math display">
    \begin{aligned}
    \mathrm{log}p_{prior}&=-T_{posterior}^2+T_{posterior}*G*T'_{posterior}\\
    \mathrm{log}q_{posterior}&=-T_{posterior}^2+T_{posterior}*G*T'_{posterior}+\lambda*T_{posterior}+\lambda'*T'_{posterior}\\
    \end{aligned}
</span>
The KL divergence splits into three terms, $part_0$, $part_1$ and $part_2$ (returned as `part_fun0`, `part_fun1` and `part_fun2` by the code below):
<span class="math display">
    \begin{aligned}
    part_0&=\sum \mathrm{log}q_{posterior}-\mathrm{log}p_{prior}=\langle \lambda,T_{posterior}\rangle +\langle \lambda',T'_{posterior}\rangle\\
    part_1&=-\langle \lambda,\mathbf{sg}(T_{posterior})\rangle -\langle \lambda',\mathbf{sg}(T'_{posterior})\rangle\\
    part_2&=(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle\\
    D_{KL}(q(\mathbf{z, z'|x, x'}) \Vert p(\mathbf{z, z'})) &=part_0+part_1+part_2
    \end{aligned}
</span>

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


class kl_divergence():
    """Compute the three terms of the KL divergence between a coupled
    posterior q(z, z' | x, x') and a coupled prior p(z, z').

    Each latent set uses the sufficient statistics T = [z, z**2] and
    T' = [z', z'**2]. The two sets interact through the matrix ``G``, which
    ``calc`` contracts against the outer product T ⊗ T' of shape
    (2*latent_dim1, 2*latent_dim2).
    """

    def __init__(self, latent_dim1, latent_dim2, batch_size):
        """Store the latent sizes and the nominal batch size.

        Args:
            latent_dim1: number of columns of the first latent set z.
            latent_dim2: number of columns of the second latent set z'.
            batch_size: nominal batch size. Kept for backward compatibility;
                ``calc`` takes the batch size from its inputs, so a smaller
                final batch also works.
        """
        self.latent_dim1 = latent_dim1
        self.latent_dim2 = latent_dim2
        self.batch_size = batch_size

    @staticmethod
    def _sufficient_statistics(z):
        """Return T = [z, z**2], concatenated along dim 1."""
        return torch.cat((z, torch.square(z)), 1)

    @staticmethod
    def _batched_outer(a, b):
        """Per-row outer product: out[n, i, j] = a[n, i] * b[n, j]."""
        return a.unsqueeze(2) * b.unsqueeze(1)

    def calc(self, G, z1, z2, z1_prior, z2_prior, mu1, log_var1, mu2, log_var2):
        """Return the KL terms ``(part_fun0, part_fun1, part_fun2)``.

        Args:
            G: interaction matrix of shape (2*latent_dim1, 2*latent_dim2).
            z1, z2: posterior samples with latent_dim1 / latent_dim2 columns
                and a shared leading batch dimension B.
            z1_prior, z2_prior: prior samples, same shapes as z1 / z2.
            mu1, log_var1: encoder outputs for set 1; concatenated into
                lambda1 = [mu1, log_var1], which multiplies T1 elementwise.
            mu2, log_var2: encoder outputs for set 2 (lambda2 likewise).

        Returns:
            part_fun0: lambda1*T1_post + lambda2*T2_post. This is elementwise
                and not reduced. The addition requires latent_dim1 ==
                latent_dim2.
            part_fun1: -(lambda1*sg[T1_post]) - (lambda2*sg[T2_post]), same
                shape as part_fun0. The statistics are detached, so only the
                lambdas receive gradient through this term.
            part_fun2: shape (B,), <sg[T1⊗T2]_prior - sg[T1⊗T2]_post, G> per
                sample. The statistics are detached, so only G receives
                gradient through this term.
        """
        # Sufficient statistics T = [z, z^2] for prior and posterior samples.
        T1_prior = self._sufficient_statistics(z1_prior)
        T2_prior = self._sufficient_statistics(z2_prior)
        T1_post = self._sufficient_statistics(z1)
        T2_post = self._sufficient_statistics(z2)
        # Natural parameters produced by the encoders.
        lambda1 = torch.cat((mu1, log_var1), 1)
        lambda2 = torch.cat((mu2, log_var2), 1)

        part_fun0 = torch.mul(lambda1, T1_post) + torch.mul(lambda2, T2_post)
        part_fun1 = -torch.mul(lambda1, T1_post.detach()) - torch.mul(lambda2, T2_post.detach())

        # Prior and posterior must use the SAME (T, T') orientation so that
        # both are contracted against G rather than one against G^T.
        # Vectorised over the batch, so every row is filled, including the last.
        prior_outer = self._batched_outer(T1_prior.detach(), T2_prior.detach())
        post_outer = self._batched_outer(T1_post.detach(), T2_post.detach())
        part_fun2 = torch.tensordot(prior_outer, G, dims=2) - torch.tensordot(post_outer, G, dims=2)

        return part_fun0, part_fun1, part_fun2

In [3]:
# NOTE(review): stale scratch cell, kept for reference only. As written it would
# not run. `batch_size` is not defined in any visible cell. Also, `calc` is an
# instance method but is called here on the class with 10 positional arguments,
# so `b` would be bound to `self`. Construct `kl_divergence(d1, d2, batch)` and
# call `.calc(G, z1, z2, z1_prior, z2_prior, mu1, log_var1, mu2, log_var2)`.
# a=torch.randn(2,2).to(device)
# b=torch.randn(batch_size,1).to(device)
# c=torch.randn(batch_size,1).to(device)
# d=torch.randn(batch_size,1).to(device)
# e=torch.randn(batch_size,1).to(device)
# f=torch.randn(batch_size,1).to(device)
# g=torch.randn(batch_size,1).to(device)
# h=torch.randn(batch_size,1).to(device)
# qw=torch.randn(batch_size,1).to(device)
# kl_divergence.calc(b,a,c,d,e,f,g,b,h,qw)

In [4]:
# mat1 = torch.ones(3,2)
# mat2 = torch.ones(3,2)
#x=torch.kron(mat1, mat2)

In [5]:
# print(x.size())

In [6]:
# Tprior_kron=torch.zeros(5,2,2)
# Tprior_kron[1,:]