In [1]:
import import_ipynb
import kl_divergence_calculator
import gibbs_sampler_poise
import torch

importing Jupyter notebook from kl_divergence_calculator.ipynb
importing Jupyter notebook from gibbs_sampler_poise.ipynb


In [2]:
gibbs                   = gibbs_sampler_poise.gibbs_sampler()  
kl_div                  = kl_divergence_calculator.kl_divergence()
latent_dim1             = 1
latent_dim2             = 1
batch_size              = 1
device                  = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

The KL divergence is given by,
<span class="math display">
    \begin{aligned}
    D_{KL}(q(\mathbf{z, z'|x, x'}) \Vert p(\mathbf{z, z'})) &= \mathrm{log}\left[\frac{\tilde{q}_{\phi}(z^{(i)},z'^{(i)}|x,x')}{\tilde{p}_{\theta}(z^{(i)},z'^{(i)})}\right]-\langle \mathbf{sg}( \mathbb{E}_q[T]),\lambda\rangle-\langle \mathbf{sg}( \mathbb{E}_q[T']),\lambda'\rangle+(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle
    \end{aligned}
</span>

<span class="math display">
    \begin{aligned}
    T_{prior}&=[z_{prior},z^2_{prior}]\\
    T'_{prior}&=[z'_{prior},z'^2_{prior}]\\
    T_{posterior}&=[z_{posterior},z^2_{posterior}]\\
    T'_{posterior}&=[z'_{posterior},z'^2_{posterior}]\\
    \lambda&=[\lambda_1,\lambda_2]\\
    \lambda'&=[\lambda'_1,\lambda'_2]\\
    T_{prior}^2&=(z^2_{prior}+z'^2_{prior})\\
    T_{posterior}^2&=(z^2_{posterior}+z'^2_{posterior})
    \end{aligned}
</span>
We also define,
<span class="math display">
    \begin{aligned}
    \mathrm{log}p_{prior}&=-T_{posterior}^2+T_{posterior}*G*T'_{posterior}\\
    \mathrm{log}p_{posterior}&=-T_{posterior}^2+T_{posterior}*G*T'_{posterior}+\lambda*T_{posterior}+\lambda'*T'_{posterior}\\
    \end{aligned}
</span>
The three terms of the partition function are given by:
<span class="math display">
    \begin{aligned}
    part_0&=\sum \mathrm{log}p_{posterior}-\mathrm{log}p_{prior}=\langle \lambda,T_{posterior}\rangle +\langle \lambda',T'_{posterior}\rangle \\
    part_1&=- \langle \lambda,sgd(T_{posterior})\rangle -\langle  \lambda',sgd(T'_{posterior})\rangle \\
    part_2&=(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle=\mathbb{E}_p\langle \mathbf{sg}[T_{prior}\otimes T'_{prior}],G\rangle-\mathbb{E}_q\langle \mathbf{sg}[T_{posterior}\otimes T'_{posterior}],G\rangle
    \end{aligned}
</span>
The KL divergence is given by,
<span class="math display">
    \begin{aligned}
    D_{KL}(q(\mathbf{z, z'|x, x'}) \Vert p(\mathbf{z, z'})&=\langle \lambda,T_{posterior}\rangle +\langle \lambda',T'_{posterior}\rangle- \langle \lambda,sgd(T_{posterior})\rangle -\langle  \lambda',sgd(T'_{posterior})\rangle +(\mathbb{E}_p-\mathbb{E}_q)\langle \mathbf{sg}[T\otimes T'],G\rangle
    \end{aligned}
</span>
For G=0, we have,
<span class="math display">
    \begin{aligned}
    D_{KL}(q(\mathbf{z, z'|x, x'}) \Vert p(\mathbf{z, z'}))&=\langle \lambda,T_{posterior}\rangle +\langle \lambda',T'_{posterior}\rangle- \langle \lambda,sgd(T_{posterior})\rangle -\langle  \lambda',sgd(T'_{posterior})\rangle 
    \end{aligned}
</span>
The gradients are:
<span class="math display">
    \begin{aligned}
    \frac{\partial D_{KL}}{\partial \lambda} &= T_{posterior}+\lambda\frac{\partial T_{posterior}}{\partial \lambda} -  sgd(T_{posterior})=\lambda\frac{\partial T_{posterior}}{\partial \lambda} \\
    \frac{\partial D_{KL}}{\partial \lambda'} &= T'_{posterior}+\lambda'\frac{\partial T'_{posterior}}{\partial \lambda'} -  sgd(T'_{posterior})=\lambda'\frac{\partial T'_{posterior}}{\partial \lambda'} \\
    \frac{\partial D_{KL}}{\partial T_{posterior}} &=\lambda\\
    \frac{\partial D_{KL}}{\partial T'_{posterior}} &=\lambda'\\
    \end{aligned}
</span>
For independent Gaussians,
<span class="math display">
    \begin{aligned}
     D_{KL}&=\frac{1}{2}[\mu^2+\sigma^2-\mathrm{log}(\sigma^2)-1]+\frac{1}{2}[\mu'^2+\sigma'^2-\mathrm{log}(\sigma'^2)-1]
    \end{aligned}
</span>
The gradients are:
<span class="math display">
    \begin{aligned}
     \frac{\partial D_{KL}}{\partial \mu}&=\mu
    \end{aligned}
</span>

In [5]:
g11  = torch.zeros(latent_dim1,latent_dim2).to(device)
g22  = torch.zeros(latent_dim1,latent_dim2).to(device)        
g12  = torch.zeros(latent_dim1,latent_dim2).to(device)
G1   = torch.cat((g11,g12),0).to(device)
G2   = torch.cat((g12,g22),0).to(device)
G    = torch.cat((G1,G2),1)
mu1  = torch.randn(latent_dim1,latent_dim2).to(device)
var1 = torch.randn(latent_dim1,latent_dim2).to(device)
mu2  = torch.randn(latent_dim1,latent_dim2).to(device)
var2 = torch.randn(latent_dim1,latent_dim2).to(device)

z1_prior,z2_prior = gibbs.initialize_prior_sample(g11,g22)
z1_posterior,z2_posterior = gibbs.initialize_posterior_sample(g11,g22,mu1, var1,mu2,var2)
                                                                                

part_fun0,part_fun1,part_fun2 = kl_div.calc(G,z1_posterior,z2_posterior,z1_prior,z2_prior,mu1,var1,mu2,var2)


In [8]:
part_fun0
part_fun1
part_fun2

tensor(0., device='cuda:0')

In [13]:
x=part_fun0.backward()

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

In [14]:
part_fun0

tensor(3.4754, device='cuda:0')