In [1]:
import torch
from scipy.stats import multivariate_normal as mv
import matplotlib.pyplot as plt
# Select the compute device once: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
## In this code, I will try to sample from a 1-dimensional Gaussian distribution using a single Gaussian distribution.
## My phase space is 2-dimensional: (Z, P).
## Function 1: Leapfrog integrator for SHM: dx/dt = p/m, dp/dt = -k*x
## Input: G, mu1, var1, mu2, var2
## Output: z, KL
## Function 2: MCMC step
## U(z) = -log p. Here, p = exp(...), -log p = z^2+z'^2-g11*z*z'-g12*z*z'^2-g21*z^2*z'-g22*z^2*z'^2-mu1*z-var1*z^2-mu2*z'-var2*z'^2

In [3]:
class hamiltonian_sampler():
    """Hamiltonian Monte Carlo sampler for a two-block latent state (z1, z2).

    The potential energy (negative log density up to a constant) is

        U(z1, z2) = |z1|^2 + |z2|^2
                    - z1.g11.z2 - z1.g12.z2^2 - z1^2.g21.z2 - z1^2.g22.z2^2
                    - mu1.z1 - var1.z1^2 - mu2.z2 - var2.z2^2

    where squares are element-wise and g11, g12, g21, g22 are the four
    [latent_dim1, latent_dim2] blocks of the [2*latent_dim1, 2*latent_dim2]
    coupling matrix G.  The Hamiltonian is H = |p|^2 / 2 + U (unit mass).

    All tensors are created on the device/dtype of G, so the sampler works
    on CPU or GPU as long as the caller keeps its inputs on one device.
    """

    def __init__(self, latent_dim1, latent_dim2, batch_size, T=10.0, epsilon=0.01):
        """
        Args:
            latent_dim1: dimension of the first latent block z1.
            latent_dim2: dimension of the second latent block z2.
            batch_size: stored on the instance; not read by the sampler itself.
            T: total integration time of one trajectory.
            epsilon: leapfrog step size.
        """
        self.latent_dim1 = latent_dim1
        self.latent_dim2 = latent_dim2
        self.batch_size = batch_size
        self.T = T
        self.epsilon = epsilon

    def _split_coupling(self, G):
        """Slice G into its four [latent_dim1, latent_dim2] blocks (g11, g12, g21, g22)."""
        d1, d2 = self.latent_dim1, self.latent_dim2
        return G[:d1, :d2], G[:d1, d2:], G[d1:, :d2], G[d1:, d2:]

    @staticmethod
    def _start_vector(value, dim, ref):
        """Broadcast an initial condition (e.g. shape [dim, 1]) to a 1-D vector of length dim."""
        column = torch.zeros(dim, 1, dtype=ref.dtype, device=ref.device)
        return (column + torch.as_tensor(value, dtype=ref.dtype, device=ref.device))[:, 0]

    @staticmethod
    def _force(blocks, z1, z2, mu1, mu2, var1, var2):
        """Return (-dU/dz1, -dU/dz2) for 1-D position vectors z1 and z2."""
        g11, g12, g21, g22 = blocks
        z1_sq, z2_sq = z1 ** 2, z2 ** 2
        a1 = (-2 * z1
              + g11 @ z2 + g12 @ z2_sq
              + 2 * z1 * (g21 @ z2) + 2 * z1 * (g22 @ z2_sq)
              + mu1 + 2 * var1 * z1)
        a2 = (-2 * z2
              + g11.T @ z1 + g21.T @ z1_sq
              + 2 * z2 * (g12.T @ z1) + 2 * z2 * (g22.T @ z1_sq)
              + mu2 + 2 * var2 * z2)
        return a1, a2

    def leap_frog(self,G,Z1_0,Z2_0,P1_0,P2_0,mu1,mu2,var1,var2):
        """Integrate Hamilton's equations with the leapfrog (velocity Verlet) scheme.

        Each step is: half momentum kick, full position drift, force
        re-evaluation at the NEW position, half momentum kick.  Positions and
        momenta are therefore stored at the same time points, which is what
        the Metropolis test in ``MCMC`` requires.

        Args:
            G: coupling matrix of shape [2*latent_dim1, 2*latent_dim2].
            Z1_0, Z2_0: initial positions, broadcastable to [latent_dim, 1].
            P1_0, P2_0: initial momenta, broadcastable to [latent_dim, 1].
            mu1, mu2, var1, var2: linear / quadratic coefficients of the
                potential, one vector per latent block.

        Returns:
            (Z1_t, Z2_t, P1_t, P2_t), each of shape
            [latent_dim, int(T / epsilon)]; column 0 is the initial state and
            consecutive columns are one step of size epsilon apart.

        Raises:
            ValueError: if T is shorter than a single step.
        """
        epsilon = self.epsilon
        n_cols = int(self.T / epsilon)
        if n_cols < 1:
            raise ValueError("T must be at least one leapfrog step (epsilon) long")

        blocks = self._split_coupling(G)
        # 1-D coefficient vectors on the same device/dtype as G.
        mu1, mu2, var1, var2 = (
            torch.as_tensor(c, dtype=G.dtype, device=G.device).reshape(-1)
            for c in (mu1, mu2, var1, var2)
        )
        z1 = self._start_vector(Z1_0, self.latent_dim1, G)
        z2 = self._start_vector(Z2_0, self.latent_dim2, G)
        p1 = self._start_vector(P1_0, self.latent_dim1, G)
        p2 = self._start_vector(P2_0, self.latent_dim2, G)
        a1, a2 = self._force(blocks, z1, z2, mu1, mu2, var1, var2)

        Z1_cols, Z2_cols, P1_cols, P2_cols = [z1], [z2], [p1], [p2]
        half = 0.5 * epsilon
        for _ in range(n_cols - 1):
            # Half kick with the force at the current position.
            p1 = p1 + half * a1
            p2 = p2 + half * a2
            # Full drift.
            z1 = z1 + epsilon * p1
            z2 = z2 + epsilon * p2
            # Force at the updated position, then the second half kick.
            a1, a2 = self._force(blocks, z1, z2, mu1, mu2, var1, var2)
            p1 = p1 + half * a1
            p2 = p2 + half * a2
            Z1_cols.append(z1)
            Z2_cols.append(z2)
            P1_cols.append(p1)
            P2_cols.append(p2)
        return (torch.stack(Z1_cols, dim=1), torch.stack(Z2_cols, dim=1),
                torch.stack(P1_cols, dim=1), torch.stack(P2_cols, dim=1))

    def hamiltonian(self,G,z1,z2,p1,p2,mu1,mu2,var1,var2):
        """Evaluate H = |p|^2 / 2 + U(z1, z2).

        z1, z2, p1, p2 carry the latent dimension on their last axis; mu and
        var are 1-D coefficient vectors.  The reduction is over the last axis,
        so 1-D inputs give a scalar tensor.
        """
        g11, g12, g21, g22 = self._split_coupling(G)
        z1_sq, z2_sq = z1 ** 2, z2 ** 2
        kinetic = ((p1 ** 2).sum(-1) + (p2 ** 2).sum(-1)) / 2
        quadratic = z1_sq.sum(-1) + z2_sq.sum(-1)
        coupling = (((z1 @ g11) * z2).sum(-1)
                    + ((z1 @ g12) * z2_sq).sum(-1)
                    + ((z1_sq @ g21) * z2).sum(-1)
                    + ((z1_sq @ g22) * z2_sq).sum(-1))
        # Element-wise product + sum instead of `mu @ z.T`: `.T` on 1-D
        # tensors is deprecated in PyTorch.
        linear = ((mu1 * z1).sum(-1) + (var1 * z1_sq).sum(-1)
                  + (mu2 * z2).sum(-1) + (var2 * z2_sq).sum(-1))
        U_z = quadratic - coupling - linear
        return kinetic + U_z

    def MCMC(self,G,Z1_t,Z2_t,P1_t,P2_t,mu1,mu2,var1,var2):
        """Metropolis accept/reject test between the first and last trajectory columns.

        Returns:
            (flag, val): flag is 1 when the proposal is accepted and 0
            otherwise; val is the acceptance ratio exp(H_init - H_final), so
            proposals that lower the energy are always accepted.
        """
        H_init = self.hamiltonian(G,Z1_t[:,0],Z2_t[:,0],P1_t[:,0],P2_t[:,0],mu1,mu2,var1,var2)
        H_finl = self.hamiltonian(G,Z1_t[:,-1],Z2_t[:,-1],P1_t[:,-1],P2_t[:,-1],mu1,mu2,var1,var2)
        val = torch.exp(H_init - H_finl)
        if val >= 1:
            flag = 1  # accept
        else:
            # Draw the uniform on val's device so the comparison also works on GPU.
            u = torch.rand((), dtype=val.dtype, device=val.device)
            flag = 1 if val > u else 0
        return flag, val

    def calc(self,G,mu1,var1,mu2,var2):
        """Run one HMC transition starting from z = 1 with standard-normal momenta.

        Returns:
            (HMC_Z1, HMC_Z2): tensors of shape [1, latent_dim1] and
            [1, latent_dim2] holding the proposed state when it is accepted,
            or empty tensors when it is rejected.
        """
        dev, dtype = G.device, G.dtype
        Z1_0 = torch.ones(self.latent_dim1, 1, dtype=dtype, device=dev)
        Z2_0 = torch.ones(self.latent_dim2, 1, dtype=dtype, device=dev)
        # Unit mass: momenta are drawn from N(0, I), matching the |p|^2 / 2 kinetic term.
        P1_0 = torch.randn(self.latent_dim1, 1, dtype=dtype, device=dev)
        P2_0 = torch.randn(self.latent_dim2, 1, dtype=dtype, device=dev)
        Z1_t,Z2_t,P1_t,P2_t = self.leap_frog(G,Z1_0,Z2_0,P1_0,P2_0,mu1,mu2,var1,var2)
        flag, _ = self.MCMC(G,Z1_t,Z2_t,P1_t,P2_t,mu1,mu2,var1,var2)
        if flag == 1:
            return Z1_t[:,-1].unsqueeze(0), Z2_t[:,-1].unsqueeze(0)
        return (torch.empty(0, dtype=dtype, device=dev),
                torch.empty(0, dtype=dtype, device=dev))


In [44]:


# G = 0*torch.randn(2*latent_dim1,2*latent_dim2)
# mu1 = 0*torch.randn(latent_dim1,)
# mu2 = 0*torch.randn(latent_dim2,)
# var1 = 0*torch.randn(latent_dim1,)
# var2 = 0*torch.randn(latent_dim2,)

