# Complete Code: Bayesian PASA vs Baselines on CIFAR-10-C (Colab Ready)

# This notebook implements a comparison between the proposed Bayesian PASA activation and standard baselines (ReLU, LeakyReLU, GELU, Swish, Mish) on CIFAR-10-C-style corruptions of CIFAR-10. All baselines are available through the model class, but the experiment cell below trains only the ReLU baseline by default. It integrates the Bayesian R-LayerNorm normalization and the new Bayesian PASA activation, providing:

  Extraction of CIFAR-10-C from a zip file on Google Drive (or fallback to on-the-fly corruption). Note: this version of the notebook skips the Drive step and always uses the on-the-fly fallback.

  Definitions of all activations (ReLU, LeakyReLU, GELU, Swish, Mish, original PASA, Bayesian PASA).

  Definitions of normalization layers (Standard LayerNorm, R-LayerNorm, Bayesian R-LayerNorm).

  A flexible CNN (EfficientCNN) that can combine any activation and normalization.

  Training and evaluation on four corruption types (Gaussian, Shot, Blur, Contrast).

  Collection of softmax weights from Bayesian PASA for stability analysis.

  Generation of comparison plots and final tables.

The code is optimized for Google Colab with a T4 GPU and 12GB RAM. All memory is managed to avoid OOM errors.

🚀 Setup and Imports

In [5]:
# -*- coding: utf-8 -*-
"""Bayesian_PASA_vs_Baselines_CIFAR10C.ipynb

Automatically generated for Colab T4 GPU.
"""

# Install required packages (if not already installed)
!pip install torch torchvision tqdm matplotlib seaborn fpdf --quiet

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, ConcatDataset
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import os
import random
import zipfile
from PIL import Image, ImageFilter
import gc
from copy import deepcopy

# For PDF report generation (optional)
from fpdf import FPDF
import matplotlib
matplotlib.use('Agg')  # Non-interactive backend to save memory

warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU plus the current and all CUDA devices), then switches cuDNN to
    deterministic mode with benchmarking disabled.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(seed)

    # Trade cuDNN autotuning speed for run-to-run repeatability.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Seed all RNGs once, before any model or dataset is created. Nothing below
# re-seeds, so later cells consume the RNG streams in execution order.
set_seed(42)

# Pick the compute device used by every later cell (`device` is a notebook global).
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    # Report the name and total memory (bytes / 1e9) of CUDA device 0.
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Skip Drive mounting entirely - use on-the-fly corruption
print("Using on-the-fly corruption generation (no Drive mounting required)")
# NOTE(review): no code visible in this notebook reads this flag; it appears to be
# informational only - confirm before relying on it.
use_precomputed_cifar10c = False

Using device: cuda
GPU: Tesla T4
VRAM: 15.6 GB
Using on-the-fly corruption generation (no Drive mounting required)


🧠 Activation Functions

Original PASA (simplified stable version with running stats)

In [6]:
class PASA(nn.Module):
    """
    Original PASA - stable version with running averages and learnable temperature.

    Every input element is passed through three branches:

      * ``S`` - sigmoid whose slope adapts to the running mean of ``|x|``;
      * ``L`` - soft-saturating linear map ``x / (1 + |x| / tau)``;
      * ``N`` - ``tanh`` of the input scaled by the running standard deviation.

    The branch outputs are blended per element with softmax weights computed
    from three evidence scores divided by the learnable temperature ``tau_mix``.

    Args:
        alpha0, alpha1, kappa: Initial values of the learnable slope parameters
            (``alpha = alpha0 + alpha1 * tanh(kappa * running_absmean)``).
        tau: Saturation scale of the linear branch.
        beta: Extra scale applied to the standard deviation in the noise branch.
        lambda1, mu1: Curvature and centre of the quadratic evidence ``E1``.
        tau_lin: Scale of the absolute-value evidence ``E2``.
        eps: Lower bound applied to ``tau_mix`` before it is used as a divisor.
        momentum: Weight of the previous value in the running averages.
        temperature_init: Initial value of ``tau_mix``.
    """
    def __init__(self,
                 alpha0=1.0, alpha1=0.5, kappa=0.1,
                 tau=5.0, beta=1.0,
                 lambda1=0.5, mu1=0.0, tau_lin=2.0,
                 eps=1e-6, momentum=0.99,
                 temperature_init=1.0):
        super().__init__()
        # Learnable scalars.
        self.alpha0 = nn.Parameter(torch.tensor(alpha0))
        self.alpha1 = nn.Parameter(torch.tensor(alpha1))
        self.kappa  = nn.Parameter(torch.tensor(kappa))
        self.tau_mix = nn.Parameter(torch.tensor(temperature_init))

        # Fixed hyper-parameters.
        self.tau = tau
        self.beta = beta
        self.lambda1 = lambda1
        self.mu1 = mu1
        self.tau_lin = tau_lin
        self.eps = eps
        self.momentum = momentum

        # Running statistics (scalars, saved in the state dict but not trained).
        self.register_buffer('running_absmean', torch.tensor(0.5))
        self.register_buffer('running_noise_var', torch.tensor(1.0))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def forward(self, x, return_weights=False):
        """Apply the activation element-wise.

        Args:
            x: Input tensor of any shape.
            return_weights: If True, also return the softmax mixing weights.

        Returns:
            ``out`` with the shape of ``x``; when ``return_weights`` is True a
            tuple ``(out, w)`` where ``w`` has shape ``x.shape + (3,)`` and its
            last axis is ordered (sigmoid, linear, noise).
        """
        if self.training:
            with torch.no_grad():
                batch_absmean = x.abs().mean()
                batch_var = x.var(unbiased=False)
                if self.num_batches_tracked == 0:
                    # First batch: start the averages from the batch statistics
                    # rather than from the constructor placeholders.
                    self.running_absmean = batch_absmean
                    self.running_noise_var = batch_var
                else:
                    self.running_absmean = (self.momentum * self.running_absmean +
                                            (1 - self.momentum) * batch_absmean)
                    self.running_noise_var = (self.momentum * self.running_noise_var +
                                              (1 - self.momentum) * batch_var)
                self.num_batches_tracked += 1

        # The running averages (not the batch statistics) drive the forward pass.
        mu_abs = self.running_absmean
        sigma2 = self.running_noise_var.clamp(min=1e-2, max=1e2)
        log_sigma2 = torch.log(sigma2).clamp(min=-10, max=10)

        # Component functions
        alpha = self.alpha0 + self.alpha1 * torch.tanh(self.kappa * mu_abs)
        S = torch.sigmoid(alpha * x)
        L = x / (1 + x.abs() / self.tau)
        sigma = torch.sqrt(sigma2)
        z = x / (self.beta * sigma * 1.41421356237)  # 1.414... = sqrt(2)
        N = torch.tanh(1.4 * z)

        # Evidence scores (a uniform 1/3 prior is added to every branch).
        log_prior = torch.log(torch.tensor(1/3.0, device=x.device))
        E1 = -self.lambda1 * (x - self.mu1)**2 + log_prior
        E2 = -x.abs() / self.tau_lin + log_prior
        E3 = - (x**2) / (2 * sigma2) - log_sigma2 + log_prior

        # Softmax mixing with temperature. The temperature is learnable, so the
        # optimiser can drive it to zero or below; dividing by that would yield
        # NaN weights (0) or an inverted mixture (negative). Bound it below.
        temperature = self.tau_mix.clamp(min=self.eps)
        w = torch.softmax(torch.stack([E1, E2, E3], dim=-1) / temperature, dim=-1)

        out = w[..., 0] * S + w[..., 1] * L + w[..., 2] * N
        if return_weights:
            return out, w
        return out

Bayesian PASA (new formulation with ψ-function and variational evidence)

In [7]:
class BayesianPASA(nn.Module):
    """
    Bayesian PASA - incorporates the psi-function and variational evidence.

    Same three-branch mixture as ``PASA`` (sigmoid ``S``, soft-saturating
    linear ``L``, noise-aware ``tanh`` ``N``), but the sigmoid slope, the noise
    branch scale and the noise evidence are modulated by
    ``psi(lambda3 * local_E)``, where ``local_E`` is a running average of the
    mean 3x3 local variance of the input.

    Args:
        alpha0, alpha1, kappa: Initial values of the learnable slope parameters.
        tau: Saturation scale of the linear branch.
        beta: Extra scale applied to the effective sigma in the noise branch.
        lambda1, mu1: Curvature and centre of the quadratic evidence ``E1``.
        tau_lin: Scale of the absolute-value evidence ``E2``.
        lambda3: Multiplier on the local-variance statistic inside ``psi``.
        eps: Lower bound applied to ``tau_mix`` before it is used as a divisor.
        momentum: Weight of the previous value in the running averages.
        temperature_init: Initial value of the learnable temperature ``tau_mix``.
    """
    def __init__(self,
                 alpha0=1.0, alpha1=0.5, kappa=0.1,
                 tau=5.0, beta=1.0,
                 lambda1=0.5, mu1=0.0, tau_lin=2.0,
                 lambda3=0.1,                # noise branch regularization
                 eps=1e-6, momentum=0.99,
                 temperature_init=1.0):
        super().__init__()
        # Learnable scalars.
        self.alpha0 = nn.Parameter(torch.tensor(alpha0))
        self.alpha1 = nn.Parameter(torch.tensor(alpha1))
        self.kappa  = nn.Parameter(torch.tensor(kappa))
        self.tau_mix = nn.Parameter(torch.tensor(temperature_init))

        # Fixed hyper-parameters.
        self.tau = tau
        self.beta = beta
        self.lambda1 = lambda1
        self.mu1 = mu1
        self.tau_lin = tau_lin
        self.lambda3 = lambda3
        self.eps = eps
        self.momentum = momentum

        # Running statistics (scalars, saved in the state dict but not trained).
        self.register_buffer('running_absmean', torch.tensor(0.5))
        self.register_buffer('running_noise_var', torch.tensor(1.0))
        self.register_buffer('running_local_entropy', torch.tensor(1.0))
        self.register_buffer('num_batches_tracked', torch.tensor(0, dtype=torch.long))

    def psi(self, t):
        """Stable psi(t) = log(1+t) - t/(1+t)"""
        return torch.log1p(t) - t / (1 + t)

    def forward(self, x, return_weights=False):
        """Apply the activation element-wise.

        Args:
            x: Input tensor. In training mode the running statistics are
                updated with 3x3 reflect-padded pooling, which is written for
                (B, C, H, W) feature maps; in eval mode only element-wise
                operations are used.
            return_weights: If True, also return the softmax mixing weights.

        Returns:
            ``out`` with the shape of ``x``; when ``return_weights`` is True a
            tuple ``(out, w)`` where ``w`` has shape ``x.shape + (3,)`` and its
            last axis is ordered (sigmoid, linear, noise).
        """
        # Update running statistics during training
        if self.training:
            with torch.no_grad():
                # Global stats
                batch_absmean = x.abs().mean()
                batch_var = x.var(unbiased=False)

                # Local variance estimate (using 3x3 patches). E[x^2] - E[x]^2
                # can dip slightly below zero through floating-point
                # cancellation, so clamp it the same way the R-LayerNorm
                # layers in this notebook do.
                x_padded = F.pad(x, (1,1,1,1), mode='reflect')
                local_mean = F.avg_pool2d(x_padded, kernel_size=3, stride=1)
                local_var = (F.avg_pool2d(x_padded**2, kernel_size=3, stride=1)
                             - local_mean**2).clamp(min=0)
                local_entropy = local_var.mean(dim=[1,2,3]).mean()  # scalar per batch

                if self.num_batches_tracked == 0:
                    # First batch: start the averages from the batch statistics.
                    self.running_absmean = batch_absmean
                    self.running_noise_var = batch_var
                    self.running_local_entropy = local_entropy
                else:
                    self.running_absmean = (self.momentum * self.running_absmean +
                                            (1 - self.momentum) * batch_absmean)
                    self.running_noise_var = (self.momentum * self.running_noise_var +
                                              (1 - self.momentum) * batch_var)
                    self.running_local_entropy = (self.momentum * self.running_local_entropy +
                                                  (1 - self.momentum) * local_entropy)
                self.num_batches_tracked += 1

        # Use running averages
        mu_abs = self.running_absmean  # tracked, but not used by the formulas below
        sigma2 = self.running_noise_var.clamp(min=1e-2, max=1e2)
        log_sigma2 = torch.log(sigma2).clamp(min=-10, max=10)
        local_E = self.running_local_entropy.clamp(min=1e-2)

        # The psi term is shared by the slope, the effective sigma and E3.
        psi_local = self.psi(self.lambda3 * local_E)

        # Component functions with psi modulation
        # Adaptive sigmoid slope
        alpha_slope = self.alpha0 + self.alpha1 * torch.tanh(self.kappa * psi_local)
        S = torch.sigmoid(alpha_slope * x)

        # Moderate linear (unchanged)
        L = x / (1 + x.abs() / self.tau)

        # Noise-aware branch with effective sigma
        sigma_eff = torch.sqrt(sigma2) * torch.exp(0.5 * psi_local)
        z = x / (self.beta * sigma_eff * 1.41421356237)  # 1.414... = sqrt(2)
        N = torch.tanh(1.4 * z)

        # Evidence scores derived from variational approximation
        log_prior = torch.log(torch.tensor(1/3.0, device=x.device))
        E1 = -0.5 * self.lambda1 * (x - self.mu1)**2 + log_prior
        E2 = -x.abs() / self.tau_lin + log_prior
        E3 = -0.5 * (x**2) / sigma2 - 0.5 * log_sigma2 - 0.5 * psi_local + log_prior

        # Softmax with learnable temperature. The optimiser can drive the
        # temperature to zero or below; dividing by that would yield NaN
        # weights (0) or an inverted mixture (negative). Bound it below.
        temperature = self.tau_mix.clamp(min=self.eps)
        w = torch.softmax(torch.stack([E1, E2, E3], dim=-1) / temperature, dim=-1)

        out = w[..., 0] * S + w[..., 1] * L + w[..., 2] * N
        if return_weights:
            return out, w
        return out

Baseline Activations

In [8]:
class Mish(nn.Module):
    """Mish activation: ``x * tanh(softplus(x))``, applied element-wise."""
    def forward(self, x):
        gate = torch.tanh(F.softplus(x))
        return x * gate

class Swish(nn.Module):
    """Swish activation ``x * sigmoid(beta * x)`` with a fixed (non-learned) ``beta``."""
    def __init__(self, beta=1.0):
        super().__init__()
        # Plain attribute, not an nn.Parameter: beta is not trained.
        self.beta = beta

    def forward(self, x):
        gate = torch.sigmoid(self.beta * x)
        return x * gate

🧪 Normalization Layers (from Bayesian R-LayerNorm)

In [9]:
class StandardLayerNorm(nn.Module):
    """Normalise each channel of each sample over its spatial axes, then apply an affine map.

    Statistics are taken over dims 2 and 3 only, i.e. separately for every
    (sample, channel) pair of a 4-D input; channels are not pooled together.

    Args:
        num_features: Number of channels; sizes the per-channel weight and bias.
        eps: Constant added to the variance before the square root.
    """
    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        affine_shape = (1, num_features, 1, 1)  # broadcasts over batch and space
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

    def forward(self, x):
        spatial = [2, 3]
        mu = torch.mean(x, dim=spatial, keepdim=True)
        sigma2 = torch.var(x, dim=spatial, keepdim=True, unbiased=False)
        centred = x - mu
        scaled = centred / torch.sqrt(sigma2 + self.eps)
        return self.weight * scaled + self.bias

class RLayerNorm(nn.Module):
    """Spatial normalisation whose denominator grows with a local-variance noise estimate.

    The per-(sample, channel) standard deviation is multiplied by
    ``1 + lambda * noise / (var + eps)``, where ``noise`` is the spatial mean
    of the 3x3 local variance and ``lambda`` is a learnable scalar clamped to
    ``[1e-3, 1.0]``.

    Args:
        num_features: Number of channels; sizes the per-channel weight and bias.
        lambda_init: Initial value of the learnable noise coefficient.
        eps: Stabiliser added to variances and to the final denominator.
    """
    def __init__(self, num_features, lambda_init=0.01, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.lambda_param = nn.Parameter(torch.tensor(lambda_init))
        affine_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

    def forward(self, x):
        batch, channels, height, width = x.shape  # unpacking doubles as a 4-D check
        spatial = [2, 3]
        mu = x.mean(dim=spatial, keepdim=True)
        sigma2 = x.var(dim=spatial, keepdim=True, unbiased=False)
        sigma = torch.sqrt(sigma2 + self.eps)

        # Noise estimate: spatial mean of the 3x3 local variance E[x^2] - E[x]^2,
        # clamped at zero to absorb floating-point cancellation.
        padded = F.pad(x, (1, 1, 1, 1), mode='reflect')
        patch_mean = F.avg_pool2d(padded, kernel_size=3, stride=1)
        patch_sq_mean = F.avg_pool2d(padded**2, kernel_size=3, stride=1)
        patch_var = (patch_sq_mean - patch_mean**2).clamp(min=0)
        noise = patch_var.mean(dim=spatial, keepdim=True)

        lam = self.lambda_param.clamp(1e-3, 1.0)
        inflation = 1 + lam * noise / (sigma2 + self.eps)
        scaled = (x - mu) / (sigma * inflation + self.eps)
        return self.weight * scaled + self.bias

class BayesianRLayerNorm(nn.Module):
    """Spatial normalisation with a psi-modulated effective standard deviation.

    The per-(sample, channel) standard deviation is multiplied by
    ``exp(0.5 * psi(lambda * noise / (var + eps)))`` with
    ``psi(t) = log(1 + t) - t / (1 + t)``; ``noise`` is the spatial mean of the
    3x3 local variance and ``lambda`` is a learnable scalar clamped to
    ``[1e-3, 1.0]``.

    Args:
        num_features: Number of channels; sizes the per-channel weight and bias.
        lambda_init: Initial value of the learnable noise coefficient.
        eps: Stabiliser added to variances and denominators.
    """
    def __init__(self, num_features, lambda_init=0.01, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.lambda_param = nn.Parameter(torch.tensor(lambda_init))
        affine_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

    def psi(self, t):
        """Return ``log(1 + t) - t / (1 + t)`` element-wise."""
        ratio = t / (1 + t)
        return torch.log1p(t) - ratio

    def forward(self, x, return_uncertainty=False):
        """Normalise ``x``; optionally also return ``1 / (effective_std**2 + eps)``."""
        batch, channels, height, width = x.shape  # unpacking doubles as a 4-D check
        spatial = [2, 3]
        mu = x.mean(dim=spatial, keepdim=True)
        sigma2 = x.var(dim=spatial, keepdim=True, unbiased=False)
        sigma = torch.sqrt(sigma2 + self.eps)

        # Noise estimate: spatial mean of the clamped 3x3 local variance.
        padded = F.pad(x, (1, 1, 1, 1), mode='reflect')
        patch_mean = F.avg_pool2d(padded, kernel_size=3, stride=1)
        patch_sq_mean = F.avg_pool2d(padded**2, kernel_size=3, stride=1)
        patch_var = (patch_sq_mean - patch_mean**2).clamp(min=0)
        noise = patch_var.mean(dim=spatial, keepdim=True)

        lam = self.lambda_param.clamp(1e-3, 1.0)
        log_inflation = self.psi(lam * noise / (sigma2 + self.eps))
        sigma_eff = sigma * torch.exp(0.5 * log_inflation)

        scaled = (x - mu) / (sigma_eff + self.eps)
        result = self.weight * scaled + self.bias

        if not return_uncertainty:
            return result
        return result, 1.0 / (sigma_eff**2 + self.eps)

🏗️ Flexible CNN Model

In [10]:
class EfficientCNN(nn.Module):
    """
    Small three-stage CNN whose normalisation and activation are chosen by name.

    Each stage is conv(3x3) -> norm -> activation -> 2x2 max-pool, with channel
    widths 3 -> 16 -> 32 -> 64, followed by global average pooling and a
    linear classifier.

    Args:
        norm_type: 'layer', 'r_layer' or 'bayesian_r_layer'.
        activation: 'relu', 'leakyrelu', 'gelu', 'swish', 'mish', 'pasa' or
            'bayesian_pasa'.
        num_classes: Size of the output logits vector.

    Raises:
        ValueError: If ``norm_type`` or ``activation`` is not one of the names above.
    """
    def __init__(self, norm_type='layer', activation='relu', num_classes=10):
        super().__init__()
        self.norm_type = norm_type
        self.activation = activation

        # Feature extractor: register conv/norm/act/pool for stages 1..3 in
        # the same order and under the same attribute names as a hand-written
        # layout would (conv1, norm1, act1, pool1, conv2, ...).
        widths = (3, 16, 32, 64)
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f'conv{stage}', nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f'norm{stage}', self._create_norm_layer(c_out))
            setattr(self, f'act{stage}', self._create_activation())
            setattr(self, f'pool{stage}', nn.MaxPool2d(2))

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

        self.return_weights = False  # used for PASA analysis

    def _create_norm_layer(self, num_features):
        """Build the normalisation layer selected by ``self.norm_type``."""
        factories = {
            'layer': StandardLayerNorm,
            'r_layer': RLayerNorm,
            'bayesian_r_layer': BayesianRLayerNorm,
        }
        factory = factories.get(self.norm_type)
        if factory is None:
            raise ValueError(f"Unknown norm_type: {self.norm_type}")
        return factory(num_features)

    def _create_activation(self):
        """Build a fresh activation module selected by ``self.activation``."""
        factories = {
            'relu': nn.ReLU,
            'leakyrelu': lambda: nn.LeakyReLU(0.01),
            'gelu': nn.GELU,
            'swish': Swish,
            'mish': Mish,
            'pasa': PASA,
            'bayesian_pasa': BayesianPASA,
        }
        factory = factories.get(self.activation)
        if factory is None:
            raise ValueError(f"Unknown activation: {self.activation}")
        return factory()

    def forward(self, x, return_weights=False):
        """Compute class logits.

        Args:
            x: Image batch fed to the first 3-channel convolution.
            return_weights: When True and the third activation is a PASA
                variant, also return that activation's mixing weights.

        Returns:
            The logits, or ``(logits, weights)`` when weights were collected.
            For non-PASA activations only the logits are returned even if
            ``return_weights`` is True.
        """
        self.return_weights = return_weights

        # Blocks 1 and 2
        x = self.pool1(self.act1(self.norm1(self.conv1(x))))
        x = self.pool2(self.act2(self.norm2(self.conv2(x))))

        # Block 3 - only this activation can report its mixing weights.
        x = self.norm3(self.conv3(x))
        wants_weights = self.return_weights and isinstance(self.act3, (PASA, BayesianPASA))
        if wants_weights:
            x, weights = self.act3(x, return_weights=True)
        else:
            x, weights = self.act3(x), None

        x = self.global_pool(self.pool3(x))
        logits = self.fc(x.view(x.size(0), -1))

        return logits if weights is None else (logits, weights)

📦 Dataset: CIFAR-10-C (on-the-fly corruption)

We use the same CachedNoisyCIFAR10 as in the original code. An option to load pre-corrupted images when use_precomputed_cifar10c is True was planned, but the class below does not read that flag. For simplicity, we keep the on-the-fly generation, which is faster and doesn't require a specific folder structure.

In [11]:
class CachedNoisyCIFAR10(Dataset):
    """
    Dataset with in-memory caching for speed.
    Applies various noise corruptions to CIFAR-10 images.

    The first ``num_samples`` images of the chosen CIFAR-10 split are corrupted
    once at construction time and kept in memory as normalised tensors.

    Supported ``noise_type`` values:
      * 'gaussian'   - additive Gaussian noise with std ``0.1 * severity``;
      * 'shot_noise' - despite the name, this is salt-and-pepper noise: each
                       value is replaced by 0 or 1 with probability
                       ``0.05 * severity``;
      * 'blur'       - PIL Gaussian blur with radius ``0.5 * severity``;
      * 'contrast'   - contrast around the image mean scaled by
                       ``max(0.5, 1 - 0.2 * severity)``.
    Any other value leaves the image uncorrupted (it is still normalised).

    Args:
        noise_type: Name of the corruption (see above).
        severity: Corruption strength multiplier.
        num_samples: Maximum number of images to cache.
        train: Use the CIFAR-10 training split if True, else the test split.
        seed: Optional seed for a private NumPy generator. ``None`` (default)
            draws from the global ``np.random`` state, so the corruption
            depends on whatever consumed that state before.
    """
    # CIFAR-10 per-channel normalisation statistics, shaped to broadcast over (C, H, W).
    _MEAN = torch.tensor((0.4914, 0.4822, 0.4465)).view(3, 1, 1)
    _STD = torch.tensor((0.2023, 0.1994, 0.2010)).view(3, 1, 1)

    def __init__(self, noise_type='gaussian', severity=3, num_samples=1000, train=True, seed=None):
        self.noise_type = noise_type
        self.severity = severity
        self.train = train
        self.num_samples = num_samples
        # None means "use the global np.random stream". The module itself is not
        # stored so the dataset stays picklable for DataLoader worker processes.
        self._rng = None if seed is None else np.random.RandomState(seed)

        # Download clean CIFAR-10 (PIL images; corruption and tensor conversion happen below)
        self.clean_dataset = CIFAR10(root='./data', train=train, download=True, transform=None)

        # Cache samples
        self.cached_samples = []
        self.cached_labels = []

        indices = range(min(num_samples, len(self.clean_dataset)))
        print(f"Caching {noise_type} samples...")
        for idx in tqdm(indices, leave=False):
            img, label = self.clean_dataset[idx]
            self.cached_samples.append(self.apply_noise_and_transform(img))
            self.cached_labels.append(label)

    def apply_noise_and_transform(self, img_pil):
        """Corrupt one image and return it as a normalised float (C, H, W) tensor.

        Args:
            img_pil: Image convertible with ``np.array`` to an (H, W, 3) array
                of 0-255 values. The 'blur' branch additionally calls its
                ``filter`` method (PIL image).

        Returns:
            ``torch.float32`` tensor of shape (3, H, W), clipped to [0, 1]
            before normalisation with the CIFAR-10 channel statistics.
        """
        rng = self._rng if self._rng is not None else np.random
        img_np = np.array(img_pil).astype(np.float32) / 255.0

        if self.noise_type == 'gaussian':
            noise = rng.randn(*img_np.shape) * 0.1 * self.severity
            img_np = img_np + noise
        elif self.noise_type == 'shot_noise':
            # Salt-and-pepper: pick the affected values, then set each to 0 or 1.
            mask = rng.random_sample(img_np.shape) < 0.05 * self.severity
            salt = rng.random_sample(mask.sum()) > 0.5
            img_np[mask] = salt.astype(np.float32)
        elif self.noise_type == 'blur':
            img_pil = img_pil.filter(ImageFilter.GaussianBlur(radius=self.severity*0.5))
            img_np = np.array(img_pil).astype(np.float32) / 255.0
        elif self.noise_type == 'contrast':
            mean = img_np.mean()
            contrast_factor = max(0.5, 1.0 - 0.2 * self.severity)
            img_np = contrast_factor * (img_np - mean) + mean

        img_np = np.clip(img_np, 0, 1)
        img_tensor = torch.from_numpy(img_np.transpose(2, 0, 1)).float()

        # Normalize with CIFAR-10 stats
        return (img_tensor - self._MEAN) / self._STD

    def __len__(self):
        return len(self.cached_samples)

    def __getitem__(self, idx):
        return self.cached_samples[idx], self.cached_labels[idx]

🏋️ Training and Evaluation Functions

We'll create a function to train a model for a given combination of normalization and activation, and return test accuracies per noise type.

In [12]:
def train_and_evaluate(norm_type, activation, num_epochs=5, lr=0.001,
                       train_samples=500, test_samples=200):
    """
    Train one EfficientCNN on a mix of corruptions, then score it per corruption.

    A fresh model is trained with Adam and cross-entropy on the concatenation
    of four corrupted training sets (gaussian, shot_noise, blur, contrast; all
    at severity 3), then evaluated on a freshly generated corrupted test set
    for each noise type.

    Args:
        norm_type: Normalisation name passed to EfficientCNN.
        activation: Activation name passed to EfficientCNN.
        num_epochs: Number of passes over the mixed training set.
        lr: Adam learning rate.
        train_samples: Images cached per noise type for training.
        test_samples: Images cached per noise type for evaluation.

    Returns:
        ``(test_results, model)``: a dict mapping noise type to test accuracy
        in percent, and the trained model (left in eval mode on ``device``).
    """
    print(f"\n>>> Training: norm={norm_type}, act={activation}")

    net = EfficientCNN(norm_type=norm_type, activation=activation).to(device)

    # Training data: one cached dataset per corruption, concatenated and shuffled.
    corruption_names = ['gaussian', 'shot_noise', 'blur', 'contrast']
    train_sets = [
        CachedNoisyCIFAR10(noise_type=name, severity=3,
                           num_samples=train_samples, train=True)
        for name in corruption_names
    ]
    train_loader = DataLoader(ConcatDataset(train_sets), batch_size=32,
                              shuffle=True, num_workers=2, pin_memory=True)

    optimizer = optim.Adam(net.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    # Training loop
    net.train()
    for epoch_no in range(1, num_epochs + 1):
        loss_sum, n_correct, n_seen = 0.0, 0, 0
        progress = tqdm(train_loader, desc=f'Epoch {epoch_no}/{num_epochs}', leave=False)
        for inputs, targets in progress:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            logits = net(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()

            # Running totals: sample-weighted loss and top-1 hits.
            loss_sum += batch_loss.item() * inputs.size(0)
            n_seen += targets.size(0)
            n_correct += logits.max(1)[1].eq(targets).sum().item()
            progress.set_postfix({'loss': batch_loss.item(), 'acc': 100.*n_correct/n_seen})

        mean_loss = loss_sum / n_seen
        train_acc = 100. * n_correct / n_seen
        print(f"  Epoch {epoch_no}: Loss={mean_loss:.4f}, Acc={train_acc:.2f}%")

    # Evaluation on each noise type (the test corruption is re-drawn on every call).
    net.eval()
    accuracy_by_noise = {}
    for name in corruption_names:
        held_out = CachedNoisyCIFAR10(noise_type=name, severity=3,
                                      num_samples=test_samples, train=False)
        eval_loader = DataLoader(held_out, batch_size=32, shuffle=False,
                                 num_workers=2, pin_memory=True)
        hits, seen = 0, 0
        with torch.no_grad():
            for inputs, targets in eval_loader:
                inputs = inputs.to(device)
                targets = targets.to(device)
                predicted = net(inputs).max(1)[1]
                seen += targets.size(0)
                hits += predicted.eq(targets).sum().item()
        accuracy = 100. * hits / seen
        accuracy_by_noise[name] = accuracy
        print(f"    {name}: {accuracy:.2f}%")

    return accuracy_by_noise, net

🔬 Experiments: Comparing Combinations

We will run the following combinations:

    Baseline: ReLU + LayerNorm

    Original PASA: PASA + LayerNorm

    Bayesian PASA: BayesianPASA + LayerNorm

    Bayesian PASA + Bayesian R-LayerNorm: BayesianPASA + BayesianRLayerNorm

    Bayesian R-LayerNorm alone: ReLU + BayesianRLayerNorm (to see its standalone effect)

We'll store results and also collect softmax weights from the PASA models for analysis.

In [13]:
# Define experiment configurations as (norm_type, activation, display label).
# The label is the key used in `results`, `models` and `weights_data`, and the
# text shown in the plots and the summary table.
experiments = [
    ('layer', 'relu', 'ReLU + LayerNorm'),
    ('layer', 'pasa', 'PASA + LayerNorm'),
    ('layer', 'bayesian_pasa', 'BayesianPASA + LayerNorm'),
    ('bayesian_r_layer', 'bayesian_pasa', 'BayesianPASA + Bayesian R‚ÄëLayerNorm'),
    ('bayesian_r_layer', 'relu', 'ReLU + Bayesian R‚ÄëLayerNorm'),
]

results = {}       # label -> {noise type: test accuracy in percent}
models = {}        # label -> trained model (kept on CPU)
weights_data = {}  # label -> softmax mixing weights, PASA models only

# Run each experiment (reduce samples for quick demo; increase for final results).
# Experiments run back to back without re-seeding, so each one starts from
# whatever RNG state the previous one left behind.
for norm, act, label in experiments:
    test_acc, model = train_and_evaluate(norm, act,
                                         num_epochs=5,          # increase to 10-20 for better accuracy
                                         train_samples=500,    # reduce to 200 for speed
                                         test_samples=200)
    results[label] = test_acc
    # nn.Module.cpu() moves the module in place and returns it, so `model` and
    # `models[label]` are the same object from here on.
    models[label] = model.cpu()  # move to CPU to save GPU memory
    torch.cuda.empty_cache()

    # If this is a PASA model ('pasa' or 'bayesian_pasa'), collect the mixing
    # weights of its third activation on a small Gaussian-noise test set.
    if 'pasa' in act:
        # Create a small test loader for Gaussian noise (noise is drawn anew here)
        ds = CachedNoisyCIFAR10(noise_type='gaussian', severity=3,
                                 num_samples=100, train=False)
        loader = DataLoader(ds, batch_size=32, shuffle=False)
        model.eval()
        model.to(device)  # temporarily back on the device for inference
        all_weights = []
        with torch.no_grad():
            for inputs, _ in loader:
                inputs = inputs.to(device)
                # EfficientCNN.forward returns (logits, weights) for PASA activations.
                _, w = model(inputs, return_weights=True)
                all_weights.append(w.cpu())
        model.cpu()  # return the stored model to CPU
        torch.cuda.empty_cache()
        # Concatenate the per-batch weights along the sample axis.
        weights_data[label] = torch.cat(all_weights, dim=0)


>>> Training: norm=layer, act=relu


100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 170M/170M [00:03<00:00, 43.7MB/s]


Caching gaussian samples...




Caching shot_noise samples...




Caching blur samples...




Caching contrast samples...




  Epoch 1: Loss=2.2793, Acc=14.85%




  Epoch 2: Loss=2.2071, Acc=21.65%




  Epoch 3: Loss=2.1544, Acc=27.75%




  Epoch 4: Loss=2.1045, Acc=31.70%




  Epoch 5: Loss=2.0549, Acc=34.80%
Caching gaussian samples...




    gaussian: 20.00%
Caching shot_noise samples...




    shot_noise: 16.50%
Caching blur samples...




    blur: 29.50%
Caching contrast samples...




    contrast: 23.50%

>>> Training: norm=layer, act=pasa
Caching gaussian samples...




Caching shot_noise samples...




Caching blur samples...




Caching contrast samples...




  Epoch 1: Loss=2.2617, Acc=16.60%




  Epoch 2: Loss=2.1991, Acc=22.50%




  Epoch 3: Loss=2.1524, Acc=27.30%




  Epoch 4: Loss=2.1046, Acc=29.40%




  Epoch 5: Loss=2.0576, Acc=33.55%
Caching gaussian samples...




    gaussian: 17.50%
Caching shot_noise samples...




    shot_noise: 13.50%
Caching blur samples...




    blur: 26.50%
Caching contrast samples...




    contrast: 26.50%
Caching gaussian samples...


                                       


>>> Training: norm=layer, act=bayesian_pasa




Caching gaussian samples...




Caching shot_noise samples...




Caching blur samples...




Caching contrast samples...




  Epoch 1: Loss=2.2754, Acc=14.65%




  Epoch 2: Loss=2.2170, Acc=19.90%




  Epoch 3: Loss=2.1781, Acc=25.60%




  Epoch 4: Loss=2.1343, Acc=29.05%




  Epoch 5: Loss=2.0912, Acc=30.00%
Caching gaussian samples...




    gaussian: 21.00%
Caching shot_noise samples...




    shot_noise: 21.50%
Caching blur samples...




    blur: 25.00%
Caching contrast samples...




    contrast: 26.50%
Caching gaussian samples...


                                       


>>> Training: norm=bayesian_r_layer, act=bayesian_pasa




Caching gaussian samples...




Caching shot_noise samples...




Caching blur samples...




Caching contrast samples...




  Epoch 1: Loss=2.2682, Acc=15.60%




  Epoch 2: Loss=2.2070, Acc=21.35%




  Epoch 3: Loss=2.1616, Acc=26.40%




  Epoch 4: Loss=2.1202, Acc=30.00%




  Epoch 5: Loss=2.0731, Acc=32.70%
Caching gaussian samples...




    gaussian: 21.50%
Caching shot_noise samples...




    shot_noise: 18.50%
Caching blur samples...




    blur: 27.00%
Caching contrast samples...




    contrast: 27.50%
Caching gaussian samples...


                                       


>>> Training: norm=bayesian_r_layer, act=relu




Caching gaussian samples...




Caching shot_noise samples...




Caching blur samples...




Caching contrast samples...




  Epoch 1: Loss=2.2788, Acc=15.50%




  Epoch 2: Loss=2.2054, Acc=22.45%




  Epoch 3: Loss=2.1546, Acc=26.50%




  Epoch 4: Loss=2.1065, Acc=36.50%




  Epoch 5: Loss=2.0577, Acc=36.60%
Caching gaussian samples...




    gaussian: 18.00%
Caching shot_noise samples...




    shot_noise: 19.00%
Caching blur samples...




    blur: 31.00%
Caching contrast samples...




    contrast: 27.00%


📊 Visualization

Test Accuracy Comparison

In [14]:
# Prepare data for plotting: one row per (model, noise type) pair.
noise_types = ['gaussian', 'shot_noise', 'blur', 'contrast']
df_rows = []
for label, acc_dict in results.items():
    for noise in noise_types:
        df_rows.append({'Model': label, 'Noise': noise, 'Accuracy': acc_dict[noise]})
df = pd.DataFrame(df_rows)

# Grouped bar plot: one group per noise type, one bar per experiment.
plt.figure(figsize=(14, 6))
model_labels = [e[2] for e in experiments]
n_models = max(len(model_labels), 1)
x = np.arange(len(noise_types))
# 0.15 is the bar width for the default five experiments; shrink it when more
# models are compared so a group never spills into its neighbours.
width = min(0.15, 0.8 / n_models)
colors = plt.cm.tab10(np.linspace(0, 1, n_models))

for i, label in enumerate(model_labels):
    accs = [df[(df.Model == label) & (df.Noise == noise)]['Accuracy'].values[0] for noise in noise_types]
    # Centre the group on the tick for any number of models (for five models
    # this equals the previous fixed shift of two bar widths).
    offset = (i - (n_models - 1) / 2) * width
    plt.bar(x + offset, accs, width, label=label, color=colors[i])

plt.xlabel('Noise Type')
plt.ylabel('Test Accuracy (%)')
plt.title('Activation + Normalization Comparison on CIFAR-10-C')
plt.xticks(x, noise_types)
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('comparison_bar.png', dpi=150)
plt.show()

Softmax Weight Distribution (for PASA models)

In [17]:
# Histogram of the softmax mixing weights: one row per PASA model, one column
# per branch (the last axis of each weights tensor is ordered Sigmoid, Linear, Noise).
if weights_data:
    branch_names = ['Sigmoid', 'Linear', 'Noise']
    n_rows = len(weights_data)
    # Size the grid from the data instead of assuming exactly three PASA models.
    # squeeze=False keeps `axes` two-dimensional even when there is a single row.
    fig, axes = plt.subplots(n_rows, len(branch_names),
                             figsize=(15, 4 * n_rows), squeeze=False)
    for row, (label, weights) in enumerate(weights_data.items()):
        for b in range(len(branch_names)):
            ax = axes[row, b]
            ax.hist(weights[..., b].numpy().flatten(), bins=50, alpha=0.7)
            ax.set_title(f'{label} ‚Äì {branch_names[b]}')
            ax.set_xlim(0, 1)  # softmax weights lie in [0, 1]
    plt.tight_layout()
    plt.savefig('pasa_weights.png', dpi=150)
    plt.show()



Summary Table


In [18]:
# Plain-text summary: one row per experiment, one column per noise type.
print("\n" + "="*70)
print("Final Test Accuracies (%)")
print("="*70)
print(f"{'Model':<40} {'Gaussian':>8} {'Shot':>8} {'Blur':>8} {'Contrast':>8}")
for label in [e[2] for e in experiments]:
    # Pad the label to the 40-character header column so the numbers line up
    # under their headings.
    row = [f"{label:<40}"]
    for noise in noise_types:
        row.append(f"{results[label][noise]:>8.2f}")
    print(" ".join(row))


Final Test Accuracies (%)
Model                                    Gaussian     Shot     Blur Contrast
ReLU + LayerNorm    20.00    16.50    29.50    23.50
PASA + LayerNorm    17.50    13.50    26.50    26.50
BayesianPASA + LayerNorm    21.00    21.50    25.00    26.50
BayesianPASA + Bayesian R‚ÄëLayerNorm    21.50    18.50    27.00    27.50
ReLU + Bayesian R‚ÄëLayerNorm    18.00    19.00    31.00    27.00


🧹 Memory Cleanup

In [19]:
# Clear all remaining GPU memory.
# Run the garbage collector first: tensors kept alive only by unreachable
# reference cycles are released by gc.collect(), and only then can
# empty_cache() hand their cached blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()
print("\n‚úÖ All experiments completed. GPU memory freed.")


‚úÖ All experiments completed. GPU memory freed.
