In [1]:
# pytorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
from pyro.infer.util import torch_item
from torch.distributions.uniform import Uniform
from torch.distributions.normal import Normal as Normal_torch

# python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
import os
from PIL import Image
from torch.utils.data.dataset import Dataset
from scipy.misc import imread
import math
import pandas as pd

# pyro
import pyro
from pyro.distributions import Normal, Categorical, MultivariateNormal
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam, SGD
import pyro.poutine as poutine
from pyro.contrib.autoguide import AutoDiagonalNormal

In [2]:
# Seed torch's RNGs (CPU and CUDA) so data shuffling, weight initialisation and
# the variational weight samples are reproducible under Restart & Run All.
# pyro.set_rng_seed(seed) would additionally seed Python's `random` and NumPy.
seed = 0
torch.manual_seed(seed)

# Training hyper-parameters.
batch_size = 32        # images per minibatch drawn from the DataLoader
resize = 32            # images are resized to resize x resize; LeNet's fc1 (16*5*5 inputs) relies on 32
epoch = 200            # number of training epochs
lr = 0.0001            # base learning rate, step-decayed per epoch by learning_rate()
weight_decay = 0.0005  # weight decay passed to the Pyro Adam optimizer in train()
num_samples = 10       # train() repeats each minibatch this many times per SVI step

In [3]:
def _mnist_preprocessing():
    """Build a fresh resize -> tensor -> normalize pipeline for MNIST images."""
    return transforms.Compose([
        transforms.Resize((resize, resize)),
        transforms.ToTensor(),
        # Presumably the MNIST training-set pixel mean/std — confirm.
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ])


# Train and test images go through identical, but separate, pipeline objects.
transform_train = _mnist_preprocessing()
transform_test = _mnist_preprocessing()

In [4]:
# Restrict this process to a single GPU. The variable only takes effect if it is
# set before CUDA is initialised, so set it before anything else in this cell.
os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# MNIST is downloaded into mnist-data/ on first use.
train_loader = DataLoader(
    datasets.MNIST('mnist-data/', train=True, download=True, transform=transform_train),
    batch_size=batch_size, shuffle=True)

# Evaluation metrics do not depend on sample order, so the test set is not
# shuffled. download=True lets this loader work even if the train loader
# above was never constructed.
test_loader = DataLoader(
    datasets.MNIST('mnist-data/', train=False, download=True, transform=transform_test),
    batch_size=batch_size, shuffle=False)

In [5]:
def learning_rate(init, epoch, milestones=(60, 120, 160), gamma=0.2):
    """Step-decayed learning rate for a given epoch.

    The rate starts at `init` and is multiplied by `gamma` once for every
    milestone that `epoch` has strictly passed. With the defaults:
    epochs 0-60 -> init, 61-120 -> 0.2*init, 121-160 -> 0.04*init,
    161+ -> 0.008*init.

    Args:
        init: base learning rate.
        epoch: epoch index being scheduled.
        milestones: epochs after which the rate decays; order does not matter.
        gamma: multiplicative decay applied once per passed milestone.

    Returns:
        The learning rate for `epoch`.
    """
    passed = sum(1 for boundary in milestones if epoch > boundary)
    return init * math.pow(gamma, passed)

In [6]:
class LeNet(nn.Module):
    """Bias-free LeNet-style CNN with softplus activations.

    Args:
        num_classes: length of the output score vector.
        inputs: number of input image channels.

    The flattened feature size 16*5*5 fixes the spatial input to 32x32
    (32 -> conv 28 -> pool 14 -> conv 10 -> pool 5).
    """

    def __init__(self, num_classes, inputs=1):
        super().__init__()
        # Feature extractor: two 5x5 convolutions without bias.
        self.conv1 = nn.Conv2d(inputs, 6, kernel_size=5, stride=1, bias=False)
        self.conv2 = nn.Conv2d(6, 16, kernel_size=5, stride=1, bias=False)
        # Classifier head: three bias-free fully connected layers.
        self.fc1 = nn.Linear(16 * 5 * 5, 120, bias=False)
        self.fc2 = nn.Linear(120, 84, bias=False)
        self.fc3 = nn.Linear(84, num_classes, bias=False)

    def forward(self, x):
        """Return unnormalised class scores of shape (batch, num_classes)."""
        features = x
        for conv in (self.conv1, self.conv2):
            features = F.max_pool2d(F.softplus(conv(features)), 2)
        hidden = features.view(features.size(0), -1)
        for linear in (self.fc1, self.fc2):
            hidden = F.softplus(linear(hidden))
        # The final layer has no activation: callers apply (log-)softmax.
        return self.fc3(hidden)

In [13]:
class Bayesian(nn.Module):
    """Pyro model/guide pair over the weights of a 10-class, 1-channel LeNet.

    `model` puts a fixed N(0, 1) prior on every LeNet weight tensor and scores
    the labels with a Categorical over the sampled network's log-softmax
    outputs. `guide` is a mean-field Normal variational posterior over the same
    weight tensors, with learnable parameters kept in the Pyro param store.
    """

    # LeNet submodules whose `.weight` tensors are lifted to random variables.
    _WEIGHTED_LAYERS = ('conv1', 'conv2', 'fc1', 'fc2', 'fc3')

    def __init__(self):
        super().__init__()
        self.net = LeNet(10, 1)
        self.log_softmax = nn.LogSoftmax(dim=1)

    def normal_prior(self, name, params):
        """Learnable Normal over a tensor shaped like `params`.

        Registers `<name>_mu` and `<name>_sigma` (both randn-initialised) in the
        Pyro param store and uses softplus(`<name>_sigma`) as the scale. It is
        not referenced by `model` or `guide` below.
        """
        mu = pyro.param('{}_mu'.format(name), torch.randn_like(params))
        raw_sigma = pyro.param('{}_sigma'.format(name), torch.randn_like(params))
        return Normal(loc=mu, scale=F.softplus(raw_sigma))

    def mean_field_norm_prior(self, name, params, eps=10e-7):
        """Mean-field Normal variational factor for a tensor shaped like `params`.

        Registers `<name>_mu` (initialised around 0 with std 0.1) and
        `<name>_sigma` (initialised around -3 with std 0.1) in the Pyro param
        store. The scale is eps + softplus(`<name>_sigma`), which keeps it
        strictly positive. Note that the default eps=10e-7 equals 1e-6.
        """
        init_spread = torch.full_like(params, 0.1)
        loc = pyro.param(
            '{}_mu'.format(name),
            torch.normal(mean=torch.zeros_like(params), std=init_spread))
        raw_scale = pyro.param(
            '{}_sigma'.format(name),
            torch.normal(mean=torch.full_like(params, -3.0), std=init_spread))
        return Normal(loc=loc, scale=eps + F.softplus(raw_scale))

    def fixed_normal_prior(self, params):
        """Non-learnable standard Normal shaped like `params`."""
        return Normal(loc=torch.zeros_like(params), scale=torch.ones_like(params))

    def model(self, x, y):
        """Generative model: sample LeNet weights from N(0, 1), then observe `y`."""
        priors = {}
        for layer in self._WEIGHTED_LAYERS:
            weight = getattr(self.net, layer).weight
            priors['{}.weight'.format(layer)] = self.fixed_normal_prior(weight)

        # Lift the module parameters to random variables and draw one network.
        sampled_net = pyro.random_module("module", self.net, priors)()
        log_probs = self.log_softmax(sampled_net(x))

        with pyro.plate('observe_data'):
            pyro.sample("obs", Categorical(logits=log_probs), obs=y)

    def guide(self, x, y):
        """Variational guide: draw one LeNet from the mean-field posterior.

        `x` and `y` are unused; they are accepted so the signature matches
        `model`. Returns the sampled network.
        """
        posteriors = {
            '{}.weight'.format(layer): self.mean_field_norm_prior(
                layer + 'w', getattr(self.net, layer).weight)
            for layer in self._WEIGHTED_LAYERS
        }
        return pyro.random_module("module", self.net, posteriors)()

In [14]:
# Instantiate the Bayesian LeNet. nn.Module.cuda() moves the parameters to the
# GPU and returns the same module, so it chains onto the constructor.
net = Bayesian().cuda()
net

Bayesian(
  (net): LeNet(
    (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1), bias=False)
    (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1), bias=False)
    (fc1): Linear(in_features=400, out_features=120, bias=False)
    (fc2): Linear(in_features=120, out_features=84, bias=False)
    (fc3): Linear(in_features=84, out_features=10, bias=False)
  )
  (log_softmax): LogSoftmax()
)

In [15]:
def simple_elbo_kl_annealing(model, guide, *args, **kwargs):
    """Negative ELBO with an annealed weight on selected latent sites.

    Intended as a custom ``loss`` for ``pyro.infer.SVI``. The guide is traced,
    the model is replayed against the guide's samples, and the function returns
    ``-(sum of model log-probs - sum of guide log-probs)``.

    Keyword options (popped before ``model``/``guide`` are called):
        annealing_factor: multiplier applied to the log-probs of annealed sites
            in both traces. Defaults to 1.0, i.e. the plain ELBO.
        latents_to_anneal: site-name prefixes to anneal. A site matches when the
            part of its name before the first ``'$$$'`` is listed. Defaults to
            ``('module',)``, the name given to ``pyro.random_module`` for the
            lifted LeNet weights.

    All remaining positional and keyword arguments are forwarded to both
    ``model`` and ``guide``.

    Returns:
        The scalar loss to minimise.
    """
    annealing_factor = kwargs.pop('annealing_factor', 1.0)
    latents_to_anneal = kwargs.pop('latents_to_anneal', ('module',))
    if isinstance(latents_to_anneal, str):
        # Avoid treating a bare prefix string as a collection of characters.
        latents_to_anneal = (latents_to_anneal,)
    latents_to_anneal = frozenset(latents_to_anneal)

    # Run the guide, then replay the model so it reuses the guide's samples.
    guide_trace = poutine.trace(guide).get_trace(*args, **kwargs)
    model_trace = poutine.trace(
        poutine.replay(model, trace=guide_trace)).get_trace(*args, **kwargs)

    def _weighted_log_prob(trace):
        # Sum of log-probs over the trace's sample sites. Annealed sites are
        # scaled by annealing_factor, and every site is also scaled by its
        # site["scale"] (set by poutine.scale / subsampled plates, 1.0 otherwise).
        total = 0.0
        for site in trace.nodes.values():
            if site["type"] != "sample":
                continue
            prefix = site["name"].split('$$$')[0]
            factor = annealing_factor if prefix in latents_to_anneal else 1.0
            log_p = site["fn"].log_prob(site["value"]).sum()
            total = total + factor * site.get("scale", 1.0) * log_p
        return total

    elbo = _weighted_log_prob(model_trace) - _weighted_log_prob(guide_trace)
    return -elbo

In [16]:
# Start from an empty Pyro param store so re-running this cell does not resume
# from previously learned variational parameters.
pyro.clear_param_store()
# Named `svi_optimizer` (not `optim`) so it does not shadow the `torch.optim`
# module imported in the first cell. train() replaces svi.optim every epoch with
# the scheduled learning rate; these initial settings use the same configured
# lr/weight_decay, for consistency.
svi_optimizer = Adam({"lr": lr, "weight_decay": weight_decay})
svi = SVI(net.model, net.guide, svi_optimizer, loss=simple_elbo_kl_annealing)

In [17]:
def predict(x, net):
    """Draw one network from the variational guide and score `x` with it.

    Args:
        x: input batch for the sampled network.
        net: any object with a `guide(x, y)` method that returns a callable
            network. train() passes the SVI object here, which relies on it
            exposing the guide callable as `.guide`.

    Returns:
        A detached tensor of the sampled network's raw outputs (unnormalised
        class scores; argmax gives the predicted class).
    """
    # Prediction needs no gradients; skipping graph construction saves memory.
    with torch.no_grad():
        sampled_model = net.guide(None, None)
        return sampled_model(x).detach()

def train(e, svi, loader):
    """Run one epoch of SVI over `loader` and print the mean loss and accuracy.

    Args:
        e: zero-based epoch index; it selects the learning rate through
            learning_rate(lr, e).
        svi: a Pyro SVI object whose loss accepts an `annealing_factor` keyword.
        loader: a DataLoader yielding (images, labels) batches.

    Each minibatch is repeated `num_samples` times before the SVI step.
    NOTE(review): pyro.random_module presumably draws a single weight sample per
    model/guide call, so all replicas in a step share one set of weights and the
    repetition scales the likelihood rather than averaging over weight samples.
    Confirm this is intended.
    """
    # Number of minibatches M; equals ceil(len(dataset) / batch_size) when
    # the loader does not drop the last partial batch.
    num_batches = len(loader)
    # NOTE(review): rebuilding the optimizer every epoch also discards Adam's
    # moment estimates at each epoch boundary.
    svi.optim = Adam({"lr": learning_rate(lr, e), 'weight_decay': weight_decay})

    train_loss = 0.0
    correct = 0
    total = 0
    for batch_idx, (inputs_value, targets) in enumerate(loader):
        x = inputs_value.view(-1, 1, resize, resize).repeat(num_samples, 1, 1, 1).cuda()
        y = targets.repeat(num_samples).cuda()

        # Blundell et al. minibatch KL weight pi_i = 2^(M-i) / (2^M - 1), with
        # i = batch_idx + 1. Exact Python-int powers avoid the float overflow
        # that 2.0 ** M would raise for large M.
        beta = 2 ** (num_batches - (batch_idx + 1)) / (2 ** num_batches - 1)

        train_loss += svi.step(x, y, annealing_factor=beta)

        predicted = torch.argmax(predict(x, svi), dim=1)
        correct += predicted.eq(y).sum().item()
        total += targets.size(0)

    # `correct` counts all num_samples replicas, while `total` counts the
    # original examples, so accuracy is divided by num_samples.
    num_examples = len(loader.dataset)
    print('================>Epoch: ', e,
          'Loss: ', train_loss / (num_examples * num_samples),
          'Acc: ', (100 * correct / total) / num_samples)

In [18]:
# Train for `epoch` epochs. `e` is the zero-based epoch index that train()
# forwards to learning_rate() for the step-decay schedule; each call prints that
# epoch's mean loss and accuracy.
# NOTE(review): the saved output shows this cell was interrupted
# (KeyboardInterrupt), so no completed run is recorded in this notebook.
for e in range(epoch):
    train(e, svi, train_loader)



KeyboardInterrupt: 