<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Bayesian_Neural_Networks_(BNNs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install pyro

In [None]:
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

class BayesianRegression(nn.Module):
    """Bayesian linear regression fitted with stochastic variational inference.

    The deterministic part is a single ``nn.Linear`` layer. ``model`` places
    standard-normal priors on its weight and bias, and ``guide`` defines a
    factorized normal variational posterior over the same two tensors.

    Args:
        input_dim: Number of input features per example.
        output_dim: Number of regression targets per example.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Apply the linear layer to ``x`` and return the result."""
        return self.linear(x)

    def model(self, x, y=None):
        """Generative model: N(0, 1) priors on the layer, unit-variance noise.

        Args:
            x: Inputs; assumed to be shaped ``(N, input_dim)`` because the
                plate size is taken from ``x.shape[0]``.
            y: Optional observations, assumed to be ``(N, output_dim)``.
                When ``None`` the "obs" site is sampled instead of observed.

        Returns:
            The prediction of the sampled regression module for ``x``.
        """
        weight = self.linear.weight
        bias = self.linear.bias
        # to_event(...) declares the parameter dimensions as event dimensions
        # so they are not mistaken for independent batch/plate dimensions.
        w_prior = dist.Normal(
            torch.zeros_like(weight), torch.ones_like(weight)
        ).to_event(2)
        b_prior = dist.Normal(
            torch.zeros_like(bias), torch.ones_like(bias)
        ).to_event(1)
        priors = {'linear.weight': w_prior, 'linear.bias': b_prior}
        # NOTE(review): pyro.random_module is marked deprecated in recent Pyro
        # releases (pyro.nn.PyroModule is the replacement); it is kept here so
        # the sample-site names stay unchanged.
        lifted_module = pyro.random_module("module", self, priors)
        lifted_reg_model = lifted_module()
        with pyro.plate("map", x.shape[0]):
            prediction = lifted_reg_model(x)
            # The trailing output dimension is an event dimension. Without
            # to_event(1) the (N, output_dim) batch shape is broadcast against
            # the size-N plate at dim -1 and the likelihood is over-counted.
            pyro.sample(
                "obs", dist.Normal(prediction, 1.0).to_event(1), obs=y
            )
        return prediction

    def guide(self, x, y=None):
        """Variational posterior: independent normals over weight and bias.

        The means and scales are registered with ``pyro.param`` so the
        optimizer can actually update them; the scales are constrained to be
        positive. Parameters live in Pyro's global parameter store under the
        names ``w_mu``, ``w_sigma``, ``b_mu`` and ``b_sigma``.

        Args:
            x: Unused; accepted so the signature matches ``model``.
            y: Unused; accepted so the signature matches ``model``.

        Returns:
            A regression module whose parameters are sampled from the guide.
        """
        weight = self.linear.weight
        bias = self.linear.bias
        positive = dist.constraints.positive
        # The initial values are only used the first time each name is seen.
        w_mu = pyro.param("w_mu", torch.randn_like(weight))
        w_sigma = pyro.param(
            "w_sigma", torch.full_like(weight, 0.1), constraint=positive
        )
        b_mu = pyro.param("b_mu", torch.randn_like(bias))
        b_sigma = pyro.param(
            "b_sigma", torch.full_like(bias, 0.1), constraint=positive
        )
        # Event shapes mirror the priors declared in model().
        w_dist = dist.Normal(w_mu, w_sigma).to_event(2)
        b_dist = dist.Normal(b_mu, b_sigma).to_event(1)
        dists = {'linear.weight': w_dist, 'linear.bias': b_dist}
        lifted_module = pyro.random_module("module", self, dists)
        return lifted_module()

# Pyro keeps parameters in a global store that outlives this cell; clear it so
# re-running the cell starts training from scratch instead of silently
# continuing from parameters left over from a previous run.
pyro.clear_param_store()

input_dim = 1
output_dim = 1
model = BayesianRegression(input_dim, output_dim)
optimizer = Adam({"lr": 0.01})
svi = SVI(model.model, model.guide, optimizer, loss=Trace_ELBO())

# Example training loop
# Synthetic data: y = 3x plus Gaussian noise with standard deviation 0.5.
x = torch.randn(100, 1)
y = 3 * x + torch.randn(100, 1) * 0.5
for step in range(1000):
    # svi.step takes one optimization step and returns the loss estimate.
    loss = svi.step(x, y)
    if step % 100 == 0:
        print(f"Step {step}, Loss: {loss}")