In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Normal
import numpy as np

In [7]:
# BNN层，类似于BP网络的Linear层，与BP网络类似，一层BNN层由weight和bias组成，weight和bias都具有均值和方差
class Linear_BBB(nn.Module):
    def __init__(self, input_features, output_features, prior_var=1.):
        """
            初始化 : 先验是一个以0为中心，方差为20的正态分布
        """
        # 初始化神经网络各层
        super().__init__()
        # 设置输入输出维度
        self.input_features = input_features
        self.output_features = output_features

        # initialize mu and rho parameters for the weights of the layer
        self.w_mu = nn.Parameter(torch.zeros(output_features, input_features))
        self.w_rho = nn.Parameter(torch.zeros(output_features, input_features))

        #initialize mu and rho parameters for the layer's bias
        self.b_mu =  nn.Parameter(torch.zeros(output_features))
        self.b_rho = nn.Parameter(torch.zeros(output_features))

        #initialize weight samples (these will be calculated whenever the layer makes a prediction)
        self.w = None
        self.b = None

        # initialize prior distribution for all of the weights and biases
        self.prior = torch.distributions.Normal(0,prior_var)

    def forward(self, input):
        """
          Optimization process
        """
        # sample weights
        # 从标准正态分布中采样权重
        w_epsilon = Normal(0,1).sample(self.w_mu.shape)
        # 获得服从均值为mu，方差为delta的正态分布的样本
        self.w = self.w_mu + torch.log(1+torch.exp(self.w_rho)) * w_epsilon

        # sample bias
        # 与sample weights同理
        b_epsilon = Normal(0,1).sample(self.b_mu.shape)
        self.b = self.b_mu + torch.log(1+torch.exp(self.b_rho)) * b_epsilon

        # record log prior by evaluating log pdf of prior at sampled weight and bias
        # 计算log p(w)，用于后续计算loss
        w_log_prior = self.prior.log_prob(self.w)
        b_log_prior = self.prior.log_prob(self.b)
        self.log_prior = torch.sum(w_log_prior) + torch.sum(b_log_prior)

        # record log variational posterior by evaluating log pdf of normal distribution defined by parameters with respect at the sampled values
        # 计算 log p(w|\theta)，用于后续计算loss
        self.w_post = Normal(self.w_mu.data, torch.log(1+torch.exp(self.w_rho)))
        self.b_post = Normal(self.b_mu.data, torch.log(1+torch.exp(self.b_rho)))
        self.log_post = self.w_post.log_prob(self.w).sum() + self.b_post.log_prob(self.b).sum()

        # 权重确定后，和BP网络层一样使用
        return F.linear(input, self.w, self.b)

In [8]:
class MLP_BBB(nn.Module):
    def __init__(self, hidden_units, noise_tol=.1,  prior_var=1.):

        # initialize the network like you would with a standard multilayer perceptron, but using the BBB layer
        super().__init__()
        # 输入为1，输出为1，只含有一层隐藏层的BNN
        self.hidden = Linear_BBB(1,hidden_units, prior_var=prior_var)
        self.out = Linear_BBB(hidden_units, 1, prior_var=prior_var)
        self.noise_tol = noise_tol # we will use the noise tolerance to calculate our likelihood

    def forward(self, x):
        # again, this is equivalent to a standard multilayer perceptron
        # 激活函数选用sigmoid
        x = torch.sigmoid(self.hidden(x))
        x = self.out(x)
        return x

    def log_prior(self):
        # calculate the log prior over all the layers
        return self.hidden.log_prior + self.out.log_prior

    def log_post(self):
        # calculate the log posterior over all the layers
        return self.hidden.log_post + self.out.log_post

    # 计算loss
    def sample_elbo(self, input, target, samples):
        # we calculate the negative elbo, which will be our loss function
        #initialize tensors
        outputs = torch.zeros(samples, target.shape[0])
        log_priors = torch.zeros(samples)
        log_posts = torch.zeros(samples)
        log_likes = torch.zeros(samples)
        # make predictions and calculate prior, posterior, and likelihood for a given number of samples

        # 蒙特卡罗近似
        for i in range(samples):
            outputs[i] = self(input).reshape(-1) # make predictions
            log_priors[i] = self.log_prior() # get log prior
            log_posts[i] = self.log_post() # get log variational posterior
            log_likes[i] = Normal(outputs[i], self.noise_tol).log_prob(target.reshape(-1)).sum() # calculate the log likelihood
        # calculate monte carlo estimate of prior posterior and likelihood
        log_prior = log_priors.mean()
        log_post = log_posts.mean()
        log_like = log_likes.mean()
        # calculate the negative elbo (which is our loss function)
        loss = log_post - log_prior - log_like
        return loss

In [9]:
def toy_function(x):
    return -x**4 + 3*x**2 + 1

In [12]:
# toy dataset we can start with
x = torch.tensor([-2, -1.8, -1, 1, 1.8, 2]).reshape(-1,1)
print(x)
y = toy_function(x)

tensor([[-2.0000],
        [-1.8000],
        [-1.0000],
        [ 1.0000],
        [ 1.8000],
        [ 2.0000]])


In [4]:
samples = 500
x_tmp = torch.linspace(-2,2,100).reshape(-1,1)
y_samp = np.zeros((samples,100))
print(x_tmp.shape)
print(y_samp.shape)

torch.Size([100, 1])
(500, 100)


In [11]:
net = MLP_BBB(32, prior_var=10)
optimizer = optim.Adam(net.parameters(), lr=.1)
epochs = 2000
for epoch in range(epochs):  # loop over the dataset multiple times
    optimizer.zero_grad()
    # forward + backward + optimize
    loss = net.sample_elbo(x, y, 1)
    loss.backward()
    optimizer.step()
    if epoch % 10 == 0:
        print('epoch: {}/{}'.format(epoch+1,epochs))
        print('Loss:', loss.item())
print('Finished Training')

samples = 100
x_tmp = torch.linspace(-5,5,100).reshape(-1,1)
y_samp = np.zeros((samples,100))
for s in range(samples):
    y_tmp = net(x_tmp).detach().numpy()
    y_samp[s] = y_tmp.reshape(-1)

print("test result:",np.mean(y_samp, axis = 0))

epoch: 1/2000
Loss: 2168.072265625
epoch: 11/2000
Loss: 2648.31494140625
epoch: 21/2000
Loss: 3481.820556640625
epoch: 31/2000
Loss: 3500.812744140625
epoch: 41/2000
Loss: 2767.3134765625
epoch: 51/2000
Loss: 2900.45654296875
epoch: 61/2000
Loss: 9668.7900390625
epoch: 71/2000
Loss: 2263.67724609375
epoch: 81/2000
Loss: 2259.599853515625
epoch: 91/2000
Loss: 2051.795166015625
epoch: 101/2000
Loss: 2169.40673828125
epoch: 111/2000
Loss: 1948.0478515625
epoch: 121/2000
Loss: 1738.3046875
epoch: 131/2000
Loss: 1786.688720703125
epoch: 141/2000
Loss: 1605.324951171875
epoch: 151/2000
Loss: 1666.7508544921875
epoch: 161/2000
Loss: 1609.324462890625
epoch: 171/2000
Loss: 1625.7210693359375
epoch: 181/2000
Loss: 1114.5604248046875
epoch: 191/2000
Loss: 1207.525390625
epoch: 201/2000
Loss: 1026.197509765625
epoch: 211/2000
Loss: 868.3079223632812
epoch: 221/2000
Loss: 884.6836547851562
epoch: 231/2000
Loss: 920.8575439453125
epoch: 241/2000
Loss: 734.8651123046875
epoch: 251/2000
Loss: 1097.72