<a href="https://colab.research.google.com/github/RachelZhou287/542_Final/blob/main/BNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torchvision
print(torchvision.__version__)
from torchvision import transforms
from torch.utils.data import DataLoader,Dataset
from PIL import Image
import matplotlib.pyplot as plt
import pandas as pd
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pandas as pd
import pyro
from pyro.distributions import Normal, Categorical
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam

0.23.0+cu126


In [2]:
# Test integrity first: list the archive's contents and discard the listing, so
# that a corrupt or truncated download shows up as a tar error before extraction.
!tar -tzf /content/CUB_200_2011.tgz > /dev/null


In [3]:
# Extract safely: create the target directory (no error if it already exists),
# then unpack the gzip-compressed archive into it.
!mkdir -p /content/data
!tar -xzf /content/CUB_200_2011.tgz -C /content/data/

# Confirm the extraction by listing the first entries of the dataset folder.
!ls /content/data/CUB_200_2011 | head

attributes
bounding_boxes.txt
classes.txt
image_class_labels.txt
images
images.txt
parts
README
train_test_split.txt


In [16]:
# Confirm the dataset folder contents (first entries only).
!ls /content/data/CUB_200_2011 | head

attributes
bounding_boxes.txt
classes.txt
image_class_labels.txt
images
images.txt
parts
README
train_test_split.txt


# Load dataset

In [17]:

class CUBDataset(Dataset):
    """CUB-200-2011 images restricted to either the training or the test split.

    The constructor reads ``images.txt``, ``image_class_labels.txt`` and
    ``train_test_split.txt`` from ``root``, joins them on the image id and keeps
    the rows whose ``is_training_img`` flag equals ``int(train)``.

    Args:
        root: Dataset directory holding the three index files and an
            ``images/`` sub-directory.
        train: ``True`` keeps rows flagged 1, ``False`` keeps rows flagged 0.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, root, train=True, transform=None):
        self.root = root
        self.transform = transform

        def read_index(filename, column):
            # Every index file is a space-separated "<img_id> <value>" table.
            return pd.read_csv(os.path.join(root, filename), sep=" ", names=["img_id", column])

        img_files = read_index("images.txt", "filepath")
        labels = read_index("image_class_labels.txt", "target")
        split = read_index("train_test_split.txt", "is_training_img")

        df = img_files.merge(labels, on="img_id").merge(split, on="img_id")
        df = df[df["is_training_img"] == int(train)]

        self.paths = df["filepath"].values
        self.targets = df["target"].values - 1  # 0-indexed

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded from ``<root>/images/<filepath>``, converted to RGB
        and passed through ``transform`` when one was given.
        """
        img_path = os.path.join(self.root, "images", self.paths[idx])
        # Close the file deterministically instead of waiting for garbage
        # collection; convert() reads the pixel data and returns a new image,
        # so the result stays valid after the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")
        label = self.targets[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# === Transforms + loaders ===
# Every image is resized to a fixed 128x128 and then converted to a tensor.
transform = transforms.Compose([transforms.Resize((128, 128)), transforms.ToTensor()])

data_root = "/content/data/CUB_200_2011"

# One dataset object per split, both sharing the same preprocessing.
train_data, test_data = (
    CUBDataset(data_root, train=is_train, transform=transform)
    for is_train in (True, False)
)

# 32 images per step, loaded by 2 worker processes; only training is shuffled.
loader_kwargs = dict(batch_size=32, num_workers=2)
train_loader = DataLoader(train_data, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_data, shuffle=False, **loader_kwargs)

print(f"✅ Loaded {len(train_data)} training and {len(test_data)} testing images.")


✅ Loaded 5994 training and 5794 testing images.


# NN Structure

In [13]:
class NN(nn.Module):
  """Fully connected classifier with a single hidden ReLU layer.

  Args:
    input_size: Number of features per sample after flattening.
    hidden_size: Width of the hidden layer.
    output_size: Number of output logits (one per class).
  """

  def __init__(self, input_size, hidden_size, output_size):
    super(NN, self).__init__()
    self.fc1=nn.Linear(input_size,hidden_size) # y=xW^T+b
    self.out=nn.Linear(hidden_size,output_size)

  def forward(self,x):
    """Return raw class logits of shape ``(batch, output_size)``.

    ``x`` is flattened to ``(batch, -1)`` first, so image batches can be passed
    in directly.
    """
    # reshape (rather than view) also accepts non-contiguous batches.
    x = x.reshape(x.size(0), -1)
    hidden=F.relu(self.fc1(x))
    # Logits are returned unnormalised; consumers such as
    # Categorical(logits=...) apply the normalisation themselves.
    return self.out(hidden)
# 3*128*128 flattened input features, 512 hidden units, 200 outputs (200 species in total).
net = NN(3 * 128 * 128, 512, 200)


# Model

In [18]:
def model(x_data,y_data):
  """Generative model: standard-normal priors on ``net`` plus a categorical likelihood.

  Args:
    x_data: Batch of inputs accepted by ``net``; dimension 0 is the batch.
    y_data: Observed class indices, one per batch element.
  """

  fc1w_prior=Normal(loc=torch.zeros_like(net.fc1.weight),scale=torch.ones_like(net.fc1.weight)).to_event(2)
  fc1b_prior=Normal(loc=torch.zeros_like(net.fc1.bias),scale=torch.ones_like(net.fc1.bias)).to_event(1)

  outw_prior=Normal(loc=torch.zeros_like(net.out.weight),scale=torch.ones_like(net.out.weight)).to_event(2)
  outb_prior=Normal(loc=torch.zeros_like(net.out.bias),scale=torch.ones_like(net.out.bias)).to_event(1)

  priors={'fc1.weight':fc1w_prior,'fc1.bias':fc1b_prior,'out.weight':outw_prior,'out.bias':outb_prior}
  lifted_model=pyro.random_module("module",net,priors)
  lifted_reg_model=lifted_model() # one network sampled from the priors

  # Categorical(logits=...) normalises the logits itself, so the raw network
  # output is passed through (there is no `log_prob` helper in this notebook).
  logits=lifted_reg_model(x_data)

  # Declare the batch elements as conditionally independent observations.
  with pyro.plate("data", x_data.size(0)):
    pyro.sample("obs", Categorical(logits=logits), obs=y_data)


# Guide

In [15]:
softplus = torch.nn.Softplus()

def guide(x_data, y_data):
    """Mean-field Normal variational posterior over the parameters of ``net``.

    Each weight/bias tensor gets a learnable location (``<prefix>_mu``) and an
    unconstrained scale parameter (``<prefix>_sigma``) that is mapped to a
    positive scale with softplus.

    Args:
        x_data: Unused; present so the signature matches ``model``.
        y_data: Unused; present so the signature matches ``model``.

    Returns:
        A copy of ``net`` whose parameters were sampled from the guide.
    """

    def variational_normal(prefix, like):
        # Lazy initialisers: pyro.param only evaluates them when the parameter
        # is not in the store yet, so the (large) random tensors are not
        # rebuilt on every call.
        mu_param = pyro.param(f"{prefix}_mu", lambda: torch.randn_like(like))
        sigma_param = softplus(pyro.param(f"{prefix}_sigma", lambda: torch.randn_like(like)))
        # Treat every tensor dimension as an event dimension so the site has
        # the same event shape as the matching prior in `model`
        # (to_event(2) for weights, to_event(1) for biases).
        return Normal(loc=mu_param, scale=sigma_param).to_event(like.dim())

    priors = {
        'fc1.weight': variational_normal("fc1w", net.fc1.weight),  # 1st layer weight
        'fc1.bias': variational_normal("fc1b", net.fc1.bias),      # 1st layer bias
        'out.weight': variational_normal("outw", net.out.weight),  # output layer weight
        'out.bias': variational_normal("outb", net.out.bias),      # output layer bias
    }

    lifted_module = pyro.random_module("module", net, priors)

    return lifted_module()

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import constraints
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro
from pyro.distributions import Normal, Categorical


# ======================================================
# 1️⃣ Define NN architecture
# ======================================================
class NN(nn.Module):
    """Two-layer perceptron: flatten -> linear -> ReLU -> linear.

    ``forward`` returns unnormalised class scores; no softmax is applied.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # The attribute names determine the parameter names
        # ('fc1.weight', 'fc1.bias', 'out.weight', 'out.bias').
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.out = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        # Collapse everything but the batch dimension,
        # e.g. 3x128x128 RGB images -> [batch, 49152].
        batch = x.size(0)
        flat = x.view(batch, -1)
        hidden = self.fc1(flat).relu()
        # Raw logits (no softmax).
        return self.out(hidden)


# Instantiate the network: flattened 128x128 RGB input, 512 hidden units, 200 outputs.
input_size = 128 * 128 * 3
net = NN(input_size, 512, 200)


# ======================================================
# 2️⃣ Define Bayesian model
# ======================================================
def model(x_data, y_data):
    """Bayesian model: Normal(0, 1) priors on ``net`` and a categorical likelihood.

    Args:
        x_data: Batch of inputs for ``net``; dimension 0 is the batch.
        y_data: Observed class indices, one per batch element.
    """
    sites = (
        ('fc1.weight', net.fc1.weight),
        ('fc1.bias', net.fc1.bias),
        ('out.weight', net.out.weight),
        ('out.bias', net.out.bias),
    )

    # Zero-mean, unit-scale prior per tensor; all of a tensor's dimensions are
    # event dimensions (2 for the weight matrices, 1 for the bias vectors).
    priors = {
        key: Normal(torch.zeros_like(tensor), torch.ones_like(tensor)).to_event(tensor.dim())
        for key, tensor in sites
    }

    # Draw one complete set of network parameters from the priors.
    sampled_net = pyro.random_module("module", net, priors)()

    # Forward pass produces one row of logits per image.
    logits = sampled_net(x_data)

    # Each image is an independent observation.
    batch_size = x_data.size(0)
    with pyro.plate("data", batch_size):
        pyro.sample("obs", Categorical(logits=logits), obs=y_data)


# ======================================================
# 3️⃣ Define Guide (Variational Posterior)
# ======================================================
soft_plus = torch.nn.Softplus()

def guide(x_data, y_data):
    """Variational posterior: an independent Normal per parameter tensor of ``net``.

    For every tensor a location ``<prefix>_mu`` (random-normal initialised) and
    an unconstrained scale ``<prefix>_sigma`` (initialised to ones, passed
    through softplus to stay positive) are registered in the Pyro param store.

    Args:
        x_data: Unused; present so the signature matches ``model``.
        y_data: Unused; present so the signature matches ``model``.

    Returns:
        A copy of ``net`` whose parameters were sampled from the guide.
    """
    # (module parameter name, reference tensor, location param, scale param)
    sites = (
        ('fc1.weight', net.fc1.weight, "fc1w_mu", "fc1w_sigma"),  # layer 1 weight
        ('fc1.bias', net.fc1.bias, "fc1b_mu", "fc1b_sigma"),      # layer 1 bias
        ('out.weight', net.out.weight, "outw_mu", "outw_sigma"),  # output layer weight
        ('out.bias', net.out.bias, "outb_mu", "outb_sigma"),      # output layer bias
    )

    dists = {}
    for key, like, mu_name, sigma_name in sites:
        # Lazy initialisers: pyro.param evaluates the callable only when the
        # parameter does not exist yet, so the full-size initial tensors are
        # no longer allocated on every SVI step.
        mu = pyro.param(mu_name, lambda like=like: torch.randn_like(like))
        sigma_param = pyro.param(sigma_name, lambda like=like: torch.ones_like(like))
        sigma = soft_plus(sigma_param)
        # Same event shape as the priors in `model`: 2 event dims for weight
        # matrices, 1 for bias vectors.
        dists[key] = Normal(mu, sigma).to_event(like.dim())

    lifted_module = pyro.random_module("module", net, dists)
    return lifted_module()


# ======================================================
# 4️⃣ Training Loop (SVI)
# ======================================================
# Start from an empty parameter store so the guide's parameters are re-created.
pyro.clear_param_store()

num_epochs = 5
optimizer = Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

for epoch in range(num_epochs):
    total_loss = 0.0
    for batch_id, (x, y) in enumerate(train_loader):
        x = x.view(len(x), -1)   # one flat feature vector per image
        y = y.to(torch.long)     # class indices as int64
        loss = svi.step(x, y)    # one gradient step; returns the loss estimate
        total_loss = total_loss + loss

    # Report the accumulated loss divided by the number of training samples.
    avg_loss = total_loss / len(train_loader.dataset)
    print(f"Epoch {epoch+1} - Loss: {avg_loss:.4f}")


