<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Bayesian_Neural_Networks_(BNNs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pyro-ppl

In [None]:
import copy

import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class BayesianNN(nn.Module):
    """Bayesian linear regression with a mean-field Normal guide, for Pyro SVI.

    Generative model (``model``)::

        w ~ Normal(0, 1)          elementwise, shape (input_dim,)
        b ~ Normal(0, 1)          shape (1,)
        y ~ Normal(x @ w + b, 1)  independently for each row of x

    Variational posterior (``guide``): independent Normals whose locations and
    scales are the Pyro params ``w_loc``, ``w_scale``, ``b_loc``, ``b_scale``.

    ``self.linear`` fixes the layer sizes and provides ``forward`` for the
    sampled modules returned by ``guide``. Its own deterministic weights are
    not trained by SVI.

    Args:
        input_dim: number of input features (columns of ``x``).
    """

    # Sample-site names in the ``<module>$$$<param>`` form that
    # ``pyro.random_module('module', ...)`` produces, so code that inspects
    # traces by site name keeps finding the same sites.
    _WEIGHT_SITE = 'module$$$linear.weight'
    _BIAS_SITE = 'module$$$linear.bias'

    def __init__(self, input_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, 1)

    def forward(self, x):
        """Return ``self.linear(x)``; for input of shape (N, input_dim) this is (N, 1)."""
        return self.linear(x)

    def model(self, x, y=None):
        """Pyro model: Normal(0, 1) priors on weight/bias, unit-variance Normal likelihood.

        Args:
            x: inputs, assumed shape (N, input_dim).
            y: targets of shape (N,), or ``None`` to sample them (prediction).
        """
        in_features = self.linear.in_features
        out_features = self.linear.out_features  # 1, set in __init__
        w = pyro.sample(
            self._WEIGHT_SITE,
            dist.Normal(torch.zeros(in_features), torch.ones(in_features)).to_event(1),
        )
        b = pyro.sample(
            self._BIAS_SITE,
            dist.Normal(torch.zeros(out_features), torch.ones(out_features)).to_event(1),
        )
        # (N, input_dim) @ (input_dim,) -> (N,); the (1,) bias broadcasts.
        prediction_mean = x @ w + b
        # Rows of x are conditionally independent given (w, b).
        with pyro.plate('data', x.shape[0]):
            pyro.sample('obs', dist.Normal(prediction_mean, 1.0), obs=y)

    def guide(self, x, y=None):
        """Mean-field Normal guide with the same sample sites as ``model``.

        ``x`` and ``y`` are unused; they are accepted because SVI passes the
        model's arguments to the guide.

        Returns:
            A deep copy of this module whose ``linear`` weight and bias hold the
            values sampled in this call, so ``guide(x)(x)`` is one
            posterior-sample prediction of shape (N, 1).
        """
        in_features = self.linear.in_features
        out_features = self.linear.out_features
        w_loc = pyro.param('w_loc', torch.randn(in_features))
        w_scale = pyro.param('w_scale', torch.ones(in_features),
                             constraint=dist.constraints.positive)
        b_loc = pyro.param('b_loc', torch.randn(out_features))
        b_scale = pyro.param('b_scale', torch.ones(out_features),
                             constraint=dist.constraints.positive)
        w = pyro.sample(self._WEIGHT_SITE, dist.Normal(w_loc, w_scale).to_event(1))
        b = pyro.sample(self._BIAS_SITE, dist.Normal(b_loc, b_scale).to_event(1))

        # SVI ignores this return value; it lets callers draw posterior networks.
        # The copy is filled without autograd tracking, so it is for inference only.
        sampled = copy.deepcopy(self)
        with torch.no_grad():
            sampled.linear.weight.copy_(w.reshape(sampled.linear.weight.shape))
            sampled.linear.bias.copy_(b)
        return sampled

# Start from an empty param store. pyro.param() ignores its init value when a
# param of that name already exists, so re-running this cell (e.g. after
# changing input_dim) would otherwise resume from stale, possibly mis-shaped
# guide parameters.
pyro.clear_param_store()

bnn = BayesianNN(input_dim=1)
svi = SVI(bnn.model, bnn.guide, Adam({'lr': 0.03}), loss=Trace_ELBO())

# Toy 1-D regression data: inputs of shape (N, 1), targets of shape (N,).
x_data = torch.tensor([[1.0], [2.0], [3.0]])
y_data = torch.tensor([1.5, 3.5, 2.5])

num_steps = 1000
log_every = 100
for step in range(num_steps):
    # One gradient step on the ELBO; returns the loss estimate as a float.
    loss = svi.step(x_data, y_data)
    if step % log_every == 0:
        print(f'Step {step}, Loss: {loss:.4f}')