In [26]:
import re
import string
import random
import torch as t
import torch.nn as nn
from torch.distributions import Normal
from torch.autograd import Variable

from functools import reduce
import numpy as np

In [27]:
# Example 
GROUND_PRIOR_MEAN = 0.9
GROUND_PRIOR_VAR = 1.0
GROUND_SIGMA = 5.5

In [36]:
def log_norm(x, mu, std):    
    norm_constant = -0.5 * t.log(2*np.pi*std**2)
    
    return norm_constant - (0.5 * (1/(std**2))* (x-mu)**2)


class Elbo(nn.Module):
    def __init__(self):
        super(Elbo, self).__init__()
        self.n_latent = 100 # Number of latent samples
        self.softplus = nn.Softplus()
        
        # adaptive variational params
        self.qm = nn.Parameter(t.randn(1,1), requires_grad=True)
        self.qs = nn.Parameter(t.randn(1,1), requires_grad=True)
        
        #create holders for prior mean and std, and likelihood std.
        self.prior_m = Variable(t.randn(1,1), requires_grad=False)
        #self.prior_m = nn.Parameter(t.randn(1,1), requires_grad=False)
        self.prior_s = Variable(t.randn(1,1), requires_grad=False)
        # self.prior_s = nn.Parameter(t.randn(1,1), requires_grad=False)
        self.likelihood_s = Variable(t.FloatTensor((1)), requires_grad=False)
        # self.likelihood_s = nn.Parameter(t.FloatTensor((1)), requires_grad=False)
        
        #Set the prior and likelihood moments.
        self.prior_s.data.fill_(GROUND_PRIOR_VAR)
        self.prior_m.data.fill_(GROUND_PRIOR_MEAN)
        self.likelihood_s.data.fill_(GROUND_SIGMA)
     
        
    def generate_rand(self):
        return np.random.normal(size=(self.n_latent,1))
    
    
    def get_mean(self) :
        return self.qm.data.numpy()
    
    
    def get_var(self) :
        torch_var = self.softplus(self.qs).data**2
        return torch_var.numpy()
    
    
    def reparam(self, eps):
        eps = Variable(t.FloatTensor(eps))
        # eps = nn.Parameter(t.FloatTensor(eps))
        
        return eps.mul(self.softplus(self.qs)) \
                .add(self.qm)
    
    
    def compute_elbo(self, x, y):
        eps = self.generate_rand()
        z = self.reparam(eps)
        
        q_likelihood = t.mean(log_norm(z, self.qm, self.softplus(self.qs)))
        prior = t.mean(log_norm(z, self.prior_m, self.prior_s))
        
        xz = x * z.transpose(0,1)
        sum_log_prob = t.sum(log_norm(y, xz, self.likelihood_s), 0)
        likelihood = t.mean(sum_log_prob)
        
        kl_div_mc = q_likelihood - prior
        loss = likelihood - kl_div_mc
        
        return loss

# Gen example data

In [29]:
N = 200
w = 3.2

X = np.random.uniform(low=-50, high=50, size=(N, 1))
Y = w *X + np.random.normal(size=(N, 1), scale=GROUND_SIGMA)

# Optimise it

In [37]:
q = Elbo()
optimizer = t.optim.Adam(q.parameters(), lr=0.2)
x = Variable(t.Tensor(X), requires_grad=False) 
y = Variable(t.Tensor(Y), requires_grad=False)

for i in range(3501):
    loss = - q.compute_elbo(x, y)
    optimizer.zero_grad()
    loss.backward(retain_graph=True)
    optimizer.step()
    
    if i % 250 ==0:
        print(q.get_mean(), q.get_var())

[[0.77676654]] [[0.15058462]]
[[3.1996546]] [[0.00028702]]
[[3.1990616]] [[0.00018747]]
[[3.1990438]] [[0.00016577]]
[[3.2004526]] [[0.0001654]]
[[3.2011387]] [[0.00016327]]
[[3.1997194]] [[0.0001637]]
[[3.2006679]] [[0.00015809]]
[[3.2003458]] [[0.00015999]]
[[3.2033272]] [[0.0001635]]
[[3.1951232]] [[0.00016368]]
[[3.20158]] [[0.00016144]]
[[3.1981337]] [[0.00016061]]
[[3.201482]] [[0.00016215]]
[[3.1979587]] [[0.00016075]]


# Eval

In [39]:
def analytical_posterior_var(var, X) :
    scaled_prec = (1/var**2) * X.T @ X +1
    
    return scaled_prec**-1


def analytical_posterior_mean(prior_mean, var, X, Y) :
    scaled_cov = (1/var**2) * X.T @ Y
    post_var = analytical_posterior_var(var, X)
    
    return post_var * (prior_mean + scaled_cov)



TRUE_POST_MEAN = analytical_posterior_mean(GROUND_PRIOR_MEAN, GROUND_PRIOR_VAR, X, Y)
TRUE_POST_VAR = analytical_posterior_var(GROUND_PRIOR_VAR, X)

q.get_mean() - TRUE_POST_MEAN, \
q.get_var() - TRUE_POST_VAR


(array([[-0.00242217]]), array([[0.00015538]]))