In [None]:
import csv
import sys
import math

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from torch.autograd import Variable
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision import datasets

torch.manual_seed(42)z


<torch._C.Generator at 0x7cb6b1cc45d0>

In [None]:
class BBB_HyperParameters(object):

    def __init__(self, ):
        self.lr = 1e-4 #1e-3, 1e-4, 1e-5
        self.momentum = 0.95
        self.hidden_units = 1200
        self.pi = 0.75 # 0.75, 0.5, 0.25
        self.s1 = float(np.exp(-1)) # exp(0), exp(-1), exp(-2)
        self.s2 = float(np.exp(-8)) # exp(-6), exp(-7), exp(-8)
        self.max_epoch = 200
        self.n_test_samples = 10
        self.batch_size = 128
    


def gaussian(x, mu, sigma):
    return (1. / (torch.sqrt(torch.tensor(2. * np.pi)) * sigma)) * torch.exp(- (x - mu) ** 2 / (2. * sigma ** 2))


def mixture_prior(input, pi, s1, s2):
    p1 = pi * gaussian(input, 0., s1)
    p2 = (1. - pi) * gaussian(input, 0., s2)
    return torch.log(p1 + p2)


def log_gaussian_rho(x, mu, rho):
    return float(-0.5 * np.log(2 * np.pi)) - rho - (x - mu) ** 2 / (2 * torch.exp(rho) ** 2)



In [None]:
class BBBLayer(nn.Module):
    def __init__(self, n_input, n_output, hyper):
        super(BBBLayer, self).__init__()
        self.n_input = n_input
        self.n_output = n_output

        self.s1 = hyper.s1
        self.s2 = hyper.s2
        self.pi = hyper.pi

        
        self.weight_mu = nn.Parameter(torch.Tensor(n_output, n_input))
        self.bias_mu = nn.Parameter(torch.Tensor(n_output))

        torch.nn.init.trunc_normal_(self.weight_mu, std=0.05)
        torch.nn.init.constant_(self.bias_mu, 0.)

        #torch.nn.init.kaiming_uniform_(self.weight_mu, nonlinearity='relu')
        #fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(self.weight_mu)
        #bound = 1 / math.sqrt(fan_in)
        #torch.nn.init.uniform_(self.bias_mu, -bound, bound)
        
        # rho parameters
        self.weight_rho = nn.Parameter(torch.Tensor(n_output, n_input).normal_(-6.5, 0.1))
        self.bias_rho = nn.Parameter(torch.Tensor(n_output).normal_(-6.5, 0.1))

        #self.bias_rho = nn.Parameter(torch.Tensor(n_output).normal_(-8.0, .05))
        #self.weight_rho = nn.Parameter(torch.Tensor(n_output, n_input).normal_(-8.0, .05))


        self.log_prior = 0. 
        self.log_varpost = 0. 

    def forward(self, data, infer=False):
        if infer:
            output = F.linear(data, self.weight_mu, self.bias_mu)
            return output

        epsilon_W = Variable(torch.Tensor(self.n_output, self.n_input).normal_(0, 1).cuda())
        epsilon_b = Variable(torch.Tensor(self.n_output).normal_(0, 1).cuda())
        W = self.weight_mu + torch.log(1+torch.exp(self.weight_rho)) * epsilon_W
        b = self.bias_mu + torch.log(1+torch.exp(self.bias_rho)) * epsilon_b

        self.log_varpost = log_gaussian_rho(W, self.weight_mu, self.weight_rho).sum() + log_gaussian_rho(b, self.bias_mu, self.bias_rho).sum()
        self.log_prior = mixture_prior(W, self.pi, self.s2, self.s1).sum() + mixture_prior(b, self.pi, self.s2, self.s1).sum()

        output = F.linear(data, W, b)

        return output

In [22]:
class BBB(nn.Module):
    def __init__(self, n_input, n_output, hyper):
        super(BBB, self).__init__()

        self.n_input = n_input
        self.layers = nn.ModuleList([])
        self.layers.append(BBBLayer(n_input, hyper.hidden_units, hyper))
        self.layers.append(BBBLayer(hyper.hidden_units, hyper.hidden_units, hyper))
        self.layers.append(BBBLayer(hyper.hidden_units, n_output, hyper))

    def forward(self, data, infer=False):
        output = F.relu(self.layers[0](data.view(-1, self.n_input), infer))
        output = F.relu(self.layers[1](output, infer))
        output = F.softmax(self.layers[2](output, infer), dim=1)
        return output

    def get_prior_varpost(self):
        log_prior = self.layers[0].log_prior + self.layers[1].log_prior + self.layers[2].log_prior
        log_varpost = self.layers[0].log_varpost + self.layers[1].log_varpost + self.layers[2].log_varpost
        return log_prior, log_varpost

In [None]:
def MonteCarloSampling(model, data, target):
    s_log_prior, s_log_varpost, s_log_likelihood = 0., 0., 0.

    #print(model(data)[0])
    
    output = torch.log(model(data))

    #print(f"Log-Output: {output}")

    sample_log_prior, sample_log_varpost = model.get_prior_varpost()

    #print(f"log_prior: {sample_log_prior}, log_varpost: {sample_log_varpost}")

    
    sample_log_likelihood = -F.nll_loss(output, target, reduction='sum')

    #print(f"log_likelihood: {sample_log_likelihood}")

    s_log_prior += sample_log_prior 
    s_log_varpost += sample_log_varpost 
    s_log_likelihood += sample_log_likelihood

    return s_log_prior, s_log_varpost, s_log_likelihood


def ELBO(log_prior, log_varpost, l_likelihood, m):
    kl = (1/m) * (log_varpost - log_prior)
    return kl - l_likelihood

In [None]:
def train(model, optimizer, loader, train=True):
    loss_sum = 0
    kl_sum = 0
    m = len(loader)

    for batch_id, (data, target) in enumerate(loader):
        data, target = data.cuda(), target.cuda()
        model.zero_grad()
        
        log_prior, log_varpost, l_likelihood = MonteCarloSampling(model, data, target)
        loss = ELBO(log_prior, log_varpost, l_likelihood, m)
        loss_sum += loss / m

        if train:
            loss.backward()
            optimizer.step()
        else:
            kl_sum += (1. / m) * (log_varpost - log_prior)
    if train:
        return loss_sum
    else:
        return kl_sum

def evaluate(model, loader, infer=True, samples=1):
    acc_sum = 0
    for idx, (data, target) in enumerate(loader):
        data, target = data.cuda(), target.cuda()

        if samples == 1:
            output = model(data, infer=infer)
        else:
            output = model(data)
            for i in range(samples - 1):
                output += model(data)

        predict = output.data.max(1)[1]
        acc = predict.eq(target.data).cpu().sum().item()
        acc_sum += acc
    return acc_sum / len(loader)

In [None]:
def BBB_run(hyper, train_loader, valid_loader, test_loader, n_input, n_output):
    
    model = BBB(n_input, n_output, hyper).cuda()
    optimizer = torch.optim.SGD(model.parameters(), lr=hyper.lr, momentum=hyper.momentum)

    train_losses = np.zeros(hyper.max_epoch)
    valid_accs = np.zeros(hyper.max_epoch)
    test_accs = np.zeros(hyper.max_epoch)

    for epoch in range(hyper.max_epoch):
        train_loss = train(model, optimizer, train_loader)
        valid_acc = evaluate(model, valid_loader)
        test_acc = evaluate(model, test_loader)

        print('Epoch', epoch + 1, 'Loss', float(train_loss),
              'Valid Error', round(100 * (1 - valid_acc / hyper.batch_size), 3), '%',
              'Test Error',  round(100 * (1 - test_acc / hyper.batch_size), 3), '%')

        valid_accs[epoch] = valid_acc
        test_accs[epoch] = test_acc
        train_losses[epoch] = train_loss


    return model

In [None]:
transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x  / 126.),  
        ])


train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)

test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)


n_input = 28 * 28
n_ouput = 10

split_size = 1/6

generator = torch.Generator().manual_seed(42)
train_set, validaton_set= torch.utils.data.random_split(train_data, [1 - split_size, split_size], generator=generator)


hyper = BBB_HyperParameters()

train_loader = DataLoader(train_set, batch_size=hyper.batch_size, shuffle=True, num_workers=1)
valid_loader = DataLoader(validaton_set, batch_size=hyper.batch_size, shuffle=True, num_workers=1)
test_loader = DataLoader(test_data, batch_size=hyper.batch_size, shuffle=True, num_workers=1)


model = BBB_run(hyper, train_loader, valid_loader, test_loader, n_input, n_ouput)


Epoch 1 Loss 39185.60546875 Valid Error 4.658 % Test Error 4.47 %
Epoch 2 Loss 39175.44140625 Valid Error 4.015 % Test Error 4.015 %
Epoch 3 Loss 39158.37109375 Valid Error 3.807 % Test Error 3.323 %
Epoch 4 Loss 39142.76953125 Valid Error 3.847 % Test Error 3.619 %
Epoch 5 Loss 39127.765625 Valid Error 3.343 % Test Error 2.957 %
Epoch 6 Loss 39113.66015625 Valid Error 3.323 % Test Error 3.006 %
Epoch 7 Loss 39099.45703125 Valid Error 3.333 % Test Error 2.967 %
Epoch 8 Loss 39085.40625 Valid Error 3.303 % Test Error 2.907 %
Epoch 9 Loss 39071.20703125 Valid Error 3.293 % Test Error 3.006 %
Epoch 10 Loss 39057.35546875 Valid Error 3.313 % Test Error 2.878 %
Epoch 11 Loss 39043.5625 Valid Error 3.214 % Test Error 2.878 %
Epoch 12 Loss 39029.8515625 Valid Error 3.244 % Test Error 2.917 %


KeyboardInterrupt: 