In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import numpy as np
import matplotlib.pyplot as plt

# Data pipeline: turn SVHN images into tensors and standardise each RGB channel
# with SVHN-specific per-channel mean / std.
SVHN_MEAN = (0.4377, 0.4438, 0.4728)
SVHN_STD = (0.1980, 0.2010, 0.1970)

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(SVHN_MEAN, SVHN_STD)])

# Both splits share the same preprocessing; 'train' is fetched before 'test'.
train_dataset, test_dataset = (
    datasets.SVHN(root='./data', split=split, transform=transform, download=True)
    for split in ('train', 'test')
)

# Only the training loader reshuffles each epoch; evaluation order stays fixed.
train_loader, test_loader = (
    torch.utils.data.DataLoader(dataset, batch_size=128, shuffle=shuffle)
    for dataset, shuffle in ((train_dataset, True), (test_dataset, False))
)

class FlipoutLinear(nn.Module):
    """Fully connected layer with a mean-field Gaussian weight posterior, sampled via Flipout.

    Every weight and bias entry has an independent posterior N(mu, softplus(rho)^2).
    The prior is N(0, prior_sigma^2).

    The forward pass uses the Flipout estimator. A single weight perturbation
    dW is drawn per call and shared by the batch. Each example gets its own
    random +/-1 input and output sign vectors, which makes the per-example
    perturbations pseudo-independent.

    Args:
        in_features: size of each input sample.
        out_features: size of each output sample.
        prior_sigma: std of the zero-mean Gaussian prior on weights and biases.
    """

    def __init__(self, in_features, out_features, prior_sigma=1.0):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.prior_sigma = prior_sigma

        # rho ~ -5 gives softplus(rho) ~ 6.7e-3, so the posterior starts nearly deterministic.
        self.weight_mu = nn.Parameter(torch.zeros(out_features, in_features).normal_(0, 0.01))
        self.weight_rho = nn.Parameter(torch.zeros(out_features, in_features).normal_(-5, 0.01))
        self.bias_mu = nn.Parameter(torch.zeros(out_features).normal_(0, 0.01))
        self.bias_rho = nn.Parameter(torch.zeros(out_features).normal_(-5, 0.01))

        # Most recently sampled Flipout signs, kept only for inspection.
        # They are plain attributes, not registered buffers, for two reasons:
        # - they are resampled on every forward pass, so they are not model state;
        # - a batch-sized buffer in state_dict() breaks strict load_state_dict
        #   into a freshly constructed layer.
        self.input_sign = None
        self.output_sign = None

    def get_random_signs(self, batch_size, shape, device, dtype=None):
        """Return a (batch_size, *shape) tensor of independent, equiprobable +/-1 values."""
        signs = torch.empty(batch_size, *shape, device=device, dtype=dtype)
        return signs.bernoulli_(0.5).mul_(2).sub_(1)

    def forward(self, x):
        """Return one stochastic Flipout sample of ``F.linear(x, W, b)``.

        Args:
            x: input of shape (batch, in_features). The sign vectors are drawn
               with shape (batch, features), so only 2-D input broadcasts correctly.
        """
        batch_size = x.size(0)
        # Flipout needs signs that are freshly drawn on every call, independent
        # of dW and of previous batches, so they are never cached.
        self.input_sign = self.get_random_signs(batch_size, (self.in_features,), x.device, x.dtype)
        self.output_sign = self.get_random_signs(batch_size, (self.out_features,), x.device, x.dtype)

        weight_sigma = F.softplus(self.weight_rho)
        bias_sigma = F.softplus(self.bias_rho)

        mean_output = F.linear(x, self.weight_mu, self.bias_mu)

        # One shared perturbation dW ~ N(0, sigma^2), decorrelated per example by the sign flips.
        delta_weight = weight_sigma * torch.randn_like(self.weight_mu)
        perturbation = F.linear(x * self.input_sign, delta_weight) * self.output_sign

        # The bias noise is a single draw shared by the whole batch (it is not sign-flipped).
        bias_perturbation = bias_sigma * torch.randn_like(self.bias_mu)

        return mean_output + perturbation + bias_perturbation

    def _gaussian_kl(self, mu, rho):
        """Summed KL(N(mu, softplus(rho)^2) || N(0, prior_sigma^2)) over all entries."""
        std_ratio = F.softplus(rho) / self.prior_sigma
        # Closed form: 0.5 * (s^2/p^2 - log(s^2/p^2) + mu^2/p^2 - 1).
        # The log is taken of the std ratio rather than the variance ratio so that
        # very small sigmas do not underflow to log(0).
        return 0.5 * torch.sum(
            std_ratio ** 2 - 2.0 * torch.log(std_ratio) + (mu / self.prior_sigma) ** 2 - 1.0
        )

    def kl_loss(self):
        """Return the weight + bias KL to the prior, divided by the number of weights.

        The result is non-negative and is zero only when the posterior equals the prior.
        """
        kl = self._gaussian_kl(self.weight_mu, self.weight_rho) + self._gaussian_kl(self.bias_mu, self.bias_rho)
        # NOTE(review): normalising by the weight count (rather than by the training-set
        # size, as in the ELBO) is kept from the original; confirm this scaling is intended.
        return kl / self.weight_mu.numel()

class BayesianNNFlipout(nn.Module):
    """CNN for 3-channel 32x32 images with a Bayesian (Flipout) classifier head.

    Architecture:
    - three deterministic conv -> ReLU -> 2x2 max-pool stages;
    - dropout;
    - two FlipoutLinear layers producing 10 logits.

    Only ``fc1`` and ``fc2`` carry a weight posterior, so ``kl_loss`` sums over them.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 64, 3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        # Three 2x2 poolings reduce a 32x32 input to 4x4 spatially, with 64 channels.
        self.fc1 = FlipoutLinear(64 * 4 * 4, 512, prior_sigma=0.1)
        self.fc2 = FlipoutLinear(512, 10, prior_sigma=0.1)

    def forward(self, x):
        """Return logits of shape (batch, 10) for input of shape (batch, 3, 32, 32)."""
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # flatten(1) preserves the batch dimension. view(-1, 1024) would instead fold
        # any other spatial size into extra "samples" rather than failing at fc1.
        x = torch.flatten(x, 1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

    def kl_loss(self):
        """Return the sum of the two Flipout layers' KL terms."""
        return self.fc1.kl_loss() + self.fc2.kl_loss()

def train_epoch(model, train_loader, optimizer, device, epoch, total_epochs=100, *, beta=None):
    """Train ``model`` for one epoch on the objective ``cross_entropy + beta * model.kl_loss()``.

    Args:
        model: module returning logits and exposing ``kl_loss()``.
        train_loader: iterable of ``(data, target)`` batches that supports ``len()``.
        optimizer: optimizer over ``model``'s parameters.
        device: device the batches are moved to.
        epoch: 0-based index of the current epoch; it drives the KL warm-up.
        total_epochs: planned number of epochs. The warm-up spans 20% of it.
        beta: optional explicit KL weight. When given, it overrides the warm-up schedule.

    Returns:
        ``(mean loss per batch, training accuracy in [0, 1])``.
        Both are 0.0 for an empty loader.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    if beta is None:
        # Linear warm-up of the KL weight from 0 to 0.1 over the first 20% of training.
        # The max(1.0, ...) only matters for total_epochs < 5 and avoids dividing by zero.
        warmup_epochs = max(1.0, 0.2 * total_epochs)
        beta = 0.1 * min(1.0, epoch / warmup_epochs)

    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        output = model(data)
        nll_loss = F.cross_entropy(output, target)
        kl_loss = model.kl_loss()
        loss = nll_loss + beta * kl_loss

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        pred = output.argmax(dim=1)
        correct += pred.eq(target).sum().item()
        total += target.size(0)

        if batch_idx % 100 == 0:
            print(f'Train Batch: {batch_idx}/{len(train_loader)}, '
                  f'Loss: {loss.item():.4f}, '
                  f'Acc: {100. * correct/total:.2f}%')

    n_batches = len(train_loader)
    mean_loss = total_loss / n_batches if n_batches else 0.0
    accuracy = correct / total if total else 0.0
    return mean_loss, accuracy

def evaluate(model, test_loader, device, n_samples=1):
    """Evaluate ``model``, print a summary, and return per-sample NLL and accuracy.

    The predictive distribution is the mean of the softmax outputs over
    ``n_samples`` forward passes. It is computed in log space, and the mean
    only matters when the model is stochastic in eval mode. ``n_samples=1``
    is equivalent to ``F.cross_entropy`` on a single pass.

    Args:
        model: module returning logits.
        test_loader: iterable of ``(data, target)`` batches.
        device: device the batches are moved to.
        n_samples: number of forward passes averaged per batch (>= 1).

    Returns:
        ``(mean NLL per sample, accuracy in [0, 1])``.
        Both are 0.0 for an empty loader.

    Raises:
        ValueError: if ``n_samples < 1``.
    """
    import math

    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    model.eval()
    test_loss = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            # log(mean_k softmax_k) = logsumexp_k(log_softmax_k) - log(K), computed stably.
            log_probs = torch.stack(
                [F.log_softmax(model(data), dim=1) for _ in range(n_samples)]
            ).logsumexp(dim=0) - math.log(n_samples)

            test_loss += F.nll_loss(log_probs, target, reduction='sum').item()
            correct += log_probs.argmax(dim=1).eq(target).sum().item()
            seen += target.size(0)

    # Normalise by samples, not batches, so a short final batch is not over-weighted.
    test_loss = test_loss / seen if seen else 0.0
    accuracy = correct / seen if seen else 0.0

    print(f'Test set: Average loss: {test_loss:.4f}, '
          f'Accuracy: {100. * accuracy:.2f}%')

    return test_loss, accuracy

# Training setup
# Use CUDA when present, otherwise Apple MPS, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

model = BayesianNNFlipout().to(device)
# NOTE(review): weight_decay applies an L2 penalty to every parameter, including the
# variational mu/rho parameters of the Flipout layers, on top of their KL term.
# Consider excluding those parameters from decay.
optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
# NOTE(review): the plateau schedule is driven by the test-set loss, so the test split
# also serves as the validation set here.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

n_epochs = 100

for epoch in range(n_epochs):
    # Pass the 0-based epoch index, not a KL weight, in the `epoch` slot.
    # train_epoch derives the KL warm-up itself from (epoch, total_epochs).
    train_loss, train_acc = train_epoch(
        model, train_loader, optimizer, device, epoch, total_epochs=n_epochs
    )
    test_loss, test_acc = evaluate(model, test_loader, device)

    scheduler.step(test_loss)

    print(f'Epoch: {epoch+1}/{n_epochs}')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {100*train_acc:.2f}%')
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {100*test_acc:.2f}%')
    print('-' * 60)

# Save model
# NOTE(review): the filename says "4000_samples", but train_loader iterates the whole
# SVHN 'train' split. Confirm the name.
torch.save(model.state_dict(), 'svhn_results_4000_samples_VI.pth')

Using downloaded and verified file: ./data/train_32x32.mat
Using downloaded and verified file: ./data/test_32x32.mat
Train Batch: 0/573, Loss: 2.3040, Acc: 10.16%
Train Batch: 100/573, Loss: 1.6860, Acc: 22.92%
Train Batch: 200/573, Loss: 0.9147, Acc: 40.51%
Train Batch: 300/573, Loss: 0.6271, Acc: 52.50%
Train Batch: 400/573, Loss: 0.5617, Acc: 59.24%
Train Batch: 500/573, Loss: 0.6370, Acc: 63.78%
Test set: Average loss: 0.4827, Accuracy: 85.91%
Epoch: 1/100
Train Loss: 1.0164, Train Acc: 66.17%
Test Loss: 0.4827, Test Acc: 85.91%
------------------------------------------------------------
Train Batch: 0/573, Loss: 0.4961, Acc: 86.72%
Train Batch: 100/573, Loss: 0.4564, Acc: 84.70%
Train Batch: 200/573, Loss: 0.3535, Acc: 85.05%
Train Batch: 300/573, Loss: 0.5641, Acc: 85.34%
Train Batch: 400/573, Loss: 0.3676, Acc: 85.44%
Train Batch: 500/573, Loss: 0.4007, Acc: 85.55%
Test set: Average loss: 0.3898, Accuracy: 88.44%
Epoch: 2/100
Train Loss: 0.4702, Train Acc: 85.69%
Test Loss: 0.3