## Two-level variational Bayesian inference

In [108]:
import pandas as pd
import torch
from torch.autograd import Variable

import torch.nn as nn
import torch.optim as optim

In [109]:
from torch.distributions.normal import Normal
from torch.distributions.multivariate_normal import MultivariateNormal
from torch.distributions.gamma import Gamma

In [110]:
#mean = torch.Tensor([[1, 2, 3], [4, 5, 6]])
mean = torch.rand(3)
#cov1 = torch.eye(3)
#cov2 = torch.Tensor([[1, 1, 1], [1, 2, 2], [1, 2, 3]])
#cov = torch.stack([cov1, cov2], 0)
cov = torch.diag(torch.exp(torch.rand(3)))
#cov = torch.eye(3)
distrib = MultivariateNormal(loc=mean, covariance_matrix=cov)
distrib.sample().reshape(3,1)
#cov

tensor([[3.0802],
        [3.8075],
        [0.3488]])

In [111]:
t= Normal(0,1).sample(torch.Size([4,1]))
t.shape

torch.Size([4, 1])

In [112]:
mean.sum()

tensor(2.2282)

In [113]:
Gamma(torch.exp(torch.rand(1)), torch.exp(torch.rand(1))).sample()

tensor([0.3725])

In [114]:
1 + torch.nan

nan

In [115]:
class StraightLineLayer(nn.Module):
    def __init__(self, input_dim):

        super().__init__()
        self.input_dim = input_dim
        self.W_mean = nn.Parameter(torch.rand(self.input_dim))
        self.V = nn.Parameter(torch.rand(self.input_dim))
        self.a = nn.Parameter(torch.rand(1)-2.)
        self.b = nn.Parameter(torch.rand(1)-4.)
        # self.a = nn.Parameter(torch.rand(1))
        # self.b = nn.Parameter(torch.rand(1))
        
        self.tau_log = []
        self.a_log = []
        self.b_log = []
        self.V_log = []
        self.W_mean_log = []
        #self.c = nn.Parameter(torch.tensor(1.0))
        #self.d = nn.Parameter(torch.tensor(1.0))

    def forward(self, X):

        #alpha = Gamma(self.c, self.d).sample(torch.Size([1, 1]))
        tau = Gamma(torch.exp(self.a), torch.exp(self.b)).sample() 
        self.tau_log.append(tau)
        
        assert tau > 0
        #tau = torch.exp(log_tau)
        
        #V = torch.diag(torch.exp(self.V))
        
        z = Normal(0, 1).sample(torch.Size([self.input_dim, 1]))
        #cov = 1/tau * torch.diag(torch.exp(self.V))
        #W = MultivariateNormal(loc = self.W_mean, covariance_matrix = cov).sample().reshape(self.input_dim,1) 
        #W = Normal(self.W_mean, 1/tau * V).sample().reshape(self.input_dim, 1)
        W = self.W_mean.reshape(self.input_dim, 1) + z * (torch.exp(self.V).reshape(self.input_dim, 1) / tau)**(0.5)
        
        return torch.matmul(X, W), W, tau

In [116]:
from torch.autograd import Variable

class fullVariationalBayes(nn.Module):

    def __init__(self, input_dim, a0=1e-2, b0=1e-4, c0=1e-2, d0=1e-4, num_samples = 50):
        super().__init__()
        
        self.num_samples = num_samples
        self.c = nn.Parameter(torch.rand(1)-2.)
        self.d = nn.Parameter(torch.rand(1)-4.)
        
        self.a0 = Variable(torch.tensor(a0))
        self.b0 = Variable(torch.tensor(b0))
        self.c0 = Variable(torch.tensor(c0))
        self.d0 = Variable(torch.tensor(d0))
        
        self.f = StraightLineLayer(input_dim)
        
    def forward(self, x, y):
        
        nLogLik = 0.0
        tmp_data_size = x.shape[0]
        #tmp_sigmaN2 = torch.exp(self.f.b - self.f.a)
        
        for i in range(self.num_samples):
            
            pred, W, tau = self.f(x)
            #print(pred.shape)
            alpha = Gamma(torch.exp(self.c), torch.exp(self.d)).sample(torch.Size([1, 1]))
            nLogLik = nLogLik - Normal(pred, 1/tau**(0.5)).log_prob(y).sum() 
            nLogLik = nLogLik - Normal(0, 1/(tau*alpha)**(0.5)).log_prob(W).sum()
            nLogLik = nLogLik - Gamma(self.a0, self.b0).log_prob(tau)
            nLogLik = nLogLik - Gamma(self.c0, self.d0).log_prob(alpha)
                       
        nLogLik = nLogLik / self.num_samples
 
        LogVar_W_tau = (-1.) * tmp_data_size * (self.f.b + self.f.V.sum() + torch.exp(self.f.a - self.f.b) - torch.special.digamma(torch.exp(self.f.a)))
        LogVar_W_tau = LogVar_W_tau - torch.lgamma(torch.exp(self.f.a)) + (torch.exp(self.f.a) - 1) * torch.special.digamma(torch.exp(self.f.a)) + self.f.b - torch.exp(self.f.a)
    
        LogVar_alpha = (-1.) * torch.lgamma(torch.exp(self.c)) + (torch.exp(self.c) - 1) * torch.special.digamma(torch.exp(self.c)) + self.d - torch.exp(self.c)

        self.f.a_log.append(self.f.a.item())
        self.f.b_log.append(self.f.b.item())
        self.f.V_log.append(self.f.V.data)
        self.f.W_mean_log.append(self.f.W_mean.data)
        
        return LogVar_W_tau + LogVar_alpha + nLogLik 

In [117]:
tensorX = torch.hstack([torch.ones(100, 1), torch.rand(100, 3)])

In [118]:
tensorX.shape

torch.Size([100, 4])

In [119]:
#torch.matmul(tensorX, torch.rand(4,1))

In [120]:
tensory = torch.matmul(tensorX, torch.tensor([1., 2., 3., 5.]).reshape(4, 1))

In [121]:
from tqdm import tqdm

epochs = 1000

#model = nLogLikelyhood_v1()

#model = nLogLikelyhood_v2()

#model = maxPosterior_v1()

learning_rate = 0.02

model = fullVariationalBayes(input_dim = tensorX.shape[1])

optimizer = optim.Adam(model.parameters(), lr = learning_rate)

for epoch in tqdm(range(epochs), desc="Training..."):
    
    optimizer.zero_grad()
    
    #nLogLik = model(x_tensor, y_tensor)
    #e = torch.mean(nLogLik)
    
    nLogLik = model(tensorX, tensory)
    #nLogLik.backward(retain_graph=True)
    nLogLik.backward()
    
    #e.backward()
    optimizer.step()

Training...:  60%|█████▉    | 597/1000 [00:43<00:29, 13.81it/s]


ValueError: Expected parameter concentration (Tensor of shape (1,)) of distribution Gamma(concentration: tensor([nan], grad_fn=<ExpBackward0>), rate: tensor([0.2009], grad_fn=<ExpBackward0>)) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([nan], grad_fn=<ExpBackward0>)

In [122]:
model.f.tau_log[-5:]

[tensor([5.9174e-38]),
 tensor([5.9174e-38]),
 tensor([5.9174e-38]),
 tensor([5.9174e-38]),
 tensor([5.9174e-38])]

In [135]:
model.f.a_log[-5:]

[-41.75222396850586,
 -41.849422454833984,
 -41.94668197631836,
 -42.044002532958984,
 -42.14138412475586]

In [126]:
model.f.b_log[-5:]

[-1.660336971282959,
 -1.6493194103240967,
 -1.6382932662963867,
 -1.627258539199829,
 -1.6162152290344238]

In [124]:
model.f.W_mean_log[-5:]

[tensor([1.5988, 0.8867, 1.2007, 1.0883]),
 tensor([1.5988, 0.8867, 1.2007, 1.0883]),
 tensor([1.5988, 0.8867, 1.2007, 1.0883]),
 tensor([1.5988, 0.8867, 1.2007, 1.0883]),
 tensor([1.5988, 0.8867, 1.2007, 1.0883])]

In [132]:
model.f.V_log[-5:]

[tensor([-13.4542, -13.2659, -13.2444, -13.2380]),
 tensor([-13.4542, -13.2659, -13.2444, -13.2380]),
 tensor([-13.4542, -13.2659, -13.2444, -13.2380]),
 tensor([-13.4542, -13.2659, -13.2444, -13.2380]),
 tensor([-13.4542, -13.2659, -13.2444, -13.2380])]

In [128]:
model.d, model.c

(Parameter containing:
 tensor([-15.0960], requires_grad=True),
 Parameter containing:
 tensor([5.0907], requires_grad=True))

In [127]:
torch.exp(model.f.V)

tensor([1.4352e-06, 1.7325e-06, 1.7701e-06, 1.7815e-06],
       grad_fn=<ExpBackward0>)

In [129]:
model.f.W_mean

Parameter containing:
tensor([1.5988, 0.8867, 1.2007, 1.0883], requires_grad=True)

In [130]:
torch.exp(model.f.a)

tensor([nan], grad_fn=<ExpBackward0>)

In [131]:
torch.exp(model.f.b)

tensor([0.2009], grad_fn=<ExpBackward0>)