In [1]:
import torch
from torchviz import make_dot

The KL divergence is given by,
<span class="math display">
    \begin{aligned}
    D_{KL}(q(\mathbf{z, z'|x, x'}) \Vert p(\mathbf{z, z'})) &= \mathrm{log}\left[\frac{\tilde{q}_{\phi}(z^{(i)},z'^{(i)}|x,x')}{\tilde{p}_{\theta}(z^{(i)},z'^{(i)})}\right]-\langle \mathbf{sg}( \mathbb{E}_q[T]),\lambda\rangle-\langle \mathbf{sg}( \mathbb{E}_q[T']),\lambda'\rangle+(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle
    \end{aligned}
</span>

<span class="math display">
    \begin{aligned}
    T_{prior}&=[z_{prior},z^2_{prior}]\\
    T'_{prior}&=[z'_{prior},z'^2_{prior}]\\
    T_{posterior}&=[z_{posterior},z^2_{posterior}]\\
    T'_{posterior}&=[z'_{posterior},z'^2_{posterior}]\\
    \lambda&=[\lambda_1,\lambda_2]\\
    \lambda'&=[\lambda'_1,\lambda'_2]\\
    T_{prior}^2&=(z^2_{prior}+z'^2_{prior})\\
    T_{posterior}^2&=(z^2_{posterior}+z'^2_{posterior})
    \end{aligned}
</span>
We also define the unnormalized log-densities:
<span class="math display">
    \begin{aligned}
    \mathrm{log}p_{prior}&=-T_{prior}^2+T_{posterior}*G*T'_{posterior}\\
    \mathrm{log}p_{posterior}&=-T_{posterior}^2+T_{posterior}*G*T'_{posterior}+\lambda*T_{posterior}+\lambda'*T'_{posterior}\\
    \end{aligned}
</span>
The three terms of the partition function are given by:
<span class="math display">
    \begin{aligned}
    part_0&=\sum \mathrm{log}p_{posterior}-\mathrm{log}p_{prior}=\langle \lambda,T_{posterior}\rangle +\langle \lambda',T'_{posterior}\rangle \\
    part_1&=- \langle \lambda,\mathbf{sg}(T_{posterior})\rangle -\langle  \lambda',\mathbf{sg}(T'_{posterior})\rangle \\
    part_2&=(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle=\mathbb{E}_p\langle \mathbf{sg}[T_{prior}\otimes T'_{prior}],G\rangle-\mathbb{E}_q\langle \mathbf{sg}[T_{posterior}\otimes T'_{posterior}],G\rangle
    \end{aligned}
</span>

In [2]:
# Latent sizes of the two variable sets, and the batch size used by the notebook.
latent_dim1, latent_dim2 = 1, 32
batch_size = 1
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
class kl_divergence():
    """Compute the three additive terms of the KL-divergence estimate.

    The terms follow the decomposition written in the notes above:
    ``part_0`` (lambda terms with gradients through the statistics),
    ``part_1`` (the same terms with the statistics detached, negated) and
    ``part_2`` (the coupling term with ``G`` on detached statistics).
    """

    def __init__(self):
        pass

    def calc(self, G, z1, z2, z1_prior, z2_prior, mu1, log_var1, mu2, log_var2):
        """Return ``(part_fun0, part_fun1, part_fun2)`` as scalar tensors.

        All shapes are taken from the inputs, so the method works for any
        batch size, latent size, dtype and device (no module-level globals
        are consulted).

        Args:
            G: coupling tensor multiplied element-wise with the per-sample
                outer product of the sufficient statistics; it must
                broadcast against ``(batch, 2 * dim1, 2 * dim2)``.
            z1, z2: posterior samples of the two sets
                (assumed 2-D ``(batch, dim)`` -- TODO confirm with caller).
            z1_prior, z2_prior: prior samples, same layout as ``z1``/``z2``.
            mu1, log_var1: encoder outputs for set 1, concatenated into
                ``lambda1``; same layout as ``z1``.
            mu2, log_var2: encoder outputs for set 2, concatenated into
                ``lambda2``; same layout as ``z2``.

        Returns:
            Tuple of three 0-dim tensors ``(part_fun0, part_fun1, part_fun2)``.
        """
        # Sufficient statistics T = [z, z^2], concatenated along dim 1.
        T1_prior = torch.cat((z1_prior, torch.square(z1_prior)), 1)
        T2_prior = torch.cat((z2_prior, torch.square(z2_prior)), 1)
        T1_post = torch.cat((z1, torch.square(z1)), 1)
        T2_post = torch.cat((z2, torch.square(z2)), 1)

        # Encoder outputs stacked the same way as the statistics.
        lambda1 = torch.cat((mu1, log_var1), 1)
        lambda2 = torch.cat((mu2, log_var2), 1)

        # <lambda, T> + <lambda', T'>, gradients flow through T.
        part_fun0 = torch.sum(lambda1 * T1_post) + torch.sum(lambda2 * T2_post)
        # Same inner products with T detached (stop-gradient), negated.
        part_fun1 = (
            -torch.sum(lambda1 * T1_post.detach())
            - torch.sum(lambda2 * T2_post.detach())
        )

        # Per-sample Kronecker product of a column vector (2*dim1, 1) with a
        # row vector (1, 2*dim2) is the outer product, so broadcasting gives
        # the whole batch at once: (B, 2*dim1, 1) * (B, 1, 2*dim2).
        Tprior_kron = (T1_prior.unsqueeze(2) * T2_prior.unsqueeze(1)).detach()
        Tpost_kron = (T1_post.unsqueeze(2) * T2_post.unsqueeze(1)).detach()

        # Coupling term: only G receives gradients here.
        part_fun2 = torch.sum(torch.mul(Tprior_kron, G) - torch.mul(Tpost_kron, G))

        return part_fun0, part_fun1, part_fun2