# Complete implementation of R-LayerNorm and Bayesian R-LayerNorm for noisy image data (saves PNG plots and a CSV summary; no PDF report is generated)
# Optimized for Kaggle Notebook P100 GPU with 16GB RAM
# Version: Updated sample sizes for statistical significance

# Bayesian R-LayerNorm: Robust Normalization for Corrupted Inputs

## Motivation
Standard normalization layers (BatchNorm, LayerNorm) assume i.i.d. features and can be brittle under input corruptions. This experiment explores two variants — **R-LayerNorm** and **Bayesian R-LayerNorm** — that adapt normalization statistics using local noise estimates. The goal is to see whether these adaptive layers improve robustness on common image corruptions without sacrificing clean accuracy.

## Methods
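We use a small CNN (three conv layers, each followed by a normalization layer, a Tanh activation and 2×2 max-pooling, then global average pooling and a linear classifier) and use one of three types for all three normalization layers: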
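- **LayerNorm** (baseline): per-sample, per-channel mean/variance normalization over the spatial dimensions (despite the name, this is instance-norm style rather than normalization across all channels), followed by a learnable per-channel scale and shift.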
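- **R-LayerNorm**: same statistics, but the standard deviation is inflated by the factor $1 + \lambda\,\hat{\sigma}^2_{\text{noise}} / \sigma^2$. Here $\hat{\sigma}^2_{\text{noise}}$ is the average 3×3 local variance (the noise estimate), $\sigma^2$ is the per-channel variance, and $\lambda$ is learned (clamped to $[10^{-3}, 1]$).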
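- **Bayesian R-LayerNorm**: replaces that factor with $\exp\!\big(\tfrac{1}{2}\psi(t)\big)$. Here $t = \lambda\,\hat{\sigma}^2_{\text{noise}} / \sigma^2$ and $\psi(t) = \log(1+t) - \frac{t}{1+t}$. It can optionally also return a per-sample, per-channel uncertainty estimate, expressed as the inverse effective variance.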

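We train on CIFAR-10 (50k images, with 10% held out as a clean validation set). The remaining 45k training images are corrupted on the fly with each of four corruptions (Gaussian noise, shot noise, Gaussian blur, contrast) at severity 3, so one epoch covers all four corrupted copies. We evaluate on the same four corruptions at severity 3. Note that the on-the-fly "shot noise" is implemented as salt-and-pepper noise, not the Poisson shot noise used by CIFAR-10-C.  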
For each normalization type, we run three independent seeds (42, 123, 456) for 50 epochs with Adam (lr=0.001).  
Test sets are either generated on‑the‑fly (if CIFAR‑10‑C is unavailable) or loaded from the pre‑generated CIFAR‑10‑C dataset.

## What we compare
- Training curves (loss, validation accuracy) with mean ± std across seeds.
- Final test accuracy on each corruption type, again with mean ± std.
- Improvement over standard LayerNorm to quantify gains.
- (Bayesian R-LayerNorm can optionally return an uncertainty estimate via `return_uncertainty=True`, but this notebook neither records nor analyses it.)

All models and results are saved to `/kaggle/working/` for later inspection.

In [1]:
# ============================================================
# Cell 1: Setup and imports
# ============================================================
import time
import torch
import gc
# Release GPU memory left over from earlier runs of this notebook.
# Run the Python garbage collector first: tensors that were only kept alive by
# reference cycles then give their blocks back to PyTorch's CUDA caching
# allocator, and empty_cache() can return those cached blocks to the driver.
# (Emptying the cache before collecting would leave that memory cached.)
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    # Reset GPU memory statistics so peak/accumulated numbers cover this run only
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()

!pip install seaborn tqdm -q

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset, ConcatDataset
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm import tqdm
import warnings
import os
import random
from PIL import Image, ImageFilter
import gc

# Silence every warning category for the rest of the session (this also hides
# potentially useful deprecation/runtime warnings).
warnings.simplefilter('ignore')

# ============================================================
# Utility: set seed (called at the start of each run)
# ============================================================
def set_seed(seed=42):
    """Seed every RNG the experiment draws from and force deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (CPU, the
    current CUDA device and all CUDA devices). It then turns off cuDNN autotuning
    and requests deterministic cuDNN kernels.

    Args:
        seed: integer seed shared by all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for seeder in seeders:
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

# Select the GPU when PyTorch can see one, otherwise fall back to the CPU,
# and report what was chosen.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda' if _has_cuda else 'cpu')
print(f"Using device: {device}")
if _has_cuda:
    _gpu_name = torch.cuda.get_device_name(0)
    _gpu_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU: {_gpu_name}")
    print(f"GPU Memory: {_gpu_bytes / 1e9:.1f} GB")

# ============================================================
# Normalization layer definitions (unchanged)
# ============================================================
class StandardLayerNorm(nn.Module):
    """Baseline normalization with a learnable per-channel scale and shift.

    Mean and (biased) variance are computed per sample and per channel over
    dims 2 and 3. For an NCHW input that means over the spatial dimensions,
    so the layer behaves like instance normalization rather than
    ``nn.LayerNorm`` over (C, H, W). Assumes input laid out as (B, C, H, W);
    the affine parameters have shape (1, C, 1, 1).

    Args:
        num_features: number of channels C.
        eps: constant added to the variance before the square root.
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        affine_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

    def forward(self, x):
        spatial_dims = [2, 3]
        centered = x - x.mean(dim=spatial_dims, keepdim=True)
        spread = torch.sqrt(x.var(dim=spatial_dims, keepdim=True, unbiased=False) + self.eps)
        return self.weight * (centered / spread) + self.bias

class RLayerNorm(nn.Module):
    """Per-channel normalization whose std is inflated by a local-noise estimate.

    Uses the same per-sample, per-channel spatial statistics as
    ``StandardLayerNorm``. It then divides by ``std * (1 + lam * noise / var)``
    where:

    * ``noise`` is the spatial mean of the 3x3 local variance (computed on a
      reflect-padded input),
    * ``var`` is the channel variance,
    * ``lam`` is the learnable ``lambda_param`` clamped to [1e-3, 1].

    Args:
        num_features: number of channels C.
        lambda_init: initial value of ``lambda_param``.
        eps: numerical-stability constant.
    """

    def __init__(self, num_features, lambda_init=0.01, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.lambda_param = nn.Parameter(torch.tensor(float(lambda_init)))
        affine_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

    @staticmethod
    def _local_noise_estimate(x):
        """Return the spatial mean of the 3x3 local variance, shape (B, C, 1, 1)."""
        padded = F.pad(x, (1, 1, 1, 1), mode='reflect')
        window_mean = F.avg_pool2d(padded, kernel_size=3, stride=1)
        window_sq_mean = F.avg_pool2d(padded**2, kernel_size=3, stride=1)
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding
        window_var = (window_sq_mean - window_mean**2).clamp(min=0)
        return window_var.mean(dim=[2, 3], keepdim=True)

    def forward(self, x):
        _batch, _channels, _height, _width = x.shape  # unpacking enforces 4-D input
        mu = x.mean(dim=[2, 3], keepdim=True)
        sigma_sq = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        sigma = torch.sqrt(sigma_sq + self.eps)

        noise = self._local_noise_estimate(x)
        lam = self.lambda_param.clamp(1e-3, 1.0)
        inflation = 1 + lam * noise / (sigma_sq + self.eps)
        return self.weight * ((x - mu) / (sigma * inflation + self.eps)) + self.bias

class BayesianRLayerNorm(nn.Module):
    """R-LayerNorm variant whose std inflation is ``exp(psi(t) / 2)``.

    ``t = lam * noise / var``, where:

    * ``noise`` is the spatial mean of the 3x3 local variance (reflect-padded
      input),
    * ``var`` is the per-sample, per-channel spatial variance,
    * ``lam`` is ``lambda_param`` clamped to [1e-3, 1].

    With ``return_uncertainty=True`` the forward pass also returns
    ``1 / (effective_std**2 + eps)`` with shape (B, C, 1, 1). This is an
    inverse variance, so larger values correspond to *lower* uncertainty.

    Args:
        num_features: number of channels C.
        lambda_init: initial value of ``lambda_param``.
        eps: numerical-stability constant.
    """

    def __init__(self, num_features, lambda_init=0.01, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.lambda_param = nn.Parameter(torch.tensor(float(lambda_init)))
        affine_shape = (1, num_features, 1, 1)
        self.weight = nn.Parameter(torch.ones(affine_shape))
        self.bias = nn.Parameter(torch.zeros(affine_shape))

    def psi_function(self, t):
        """Return ``log(1 + t) - t / (1 + t)`` elementwise."""
        ratio = t / (1 + t)
        return torch.log1p(t) - ratio

    @staticmethod
    def _local_noise_estimate(x):
        """Return the spatial mean of the 3x3 local variance, shape (B, C, 1, 1)."""
        padded = F.pad(x, (1, 1, 1, 1), mode='reflect')
        window_mean = F.avg_pool2d(padded, kernel_size=3, stride=1)
        window_sq_mean = F.avg_pool2d(padded**2, kernel_size=3, stride=1)
        # E[x^2] - E[x]^2 can dip slightly below zero from rounding
        window_var = (window_sq_mean - window_mean**2).clamp(min=0)
        return window_var.mean(dim=[2, 3], keepdim=True)

    def forward(self, x, return_uncertainty=False):
        _batch, _channels, _height, _width = x.shape  # unpacking enforces 4-D input
        mu = x.mean(dim=[2, 3], keepdim=True)
        sigma_sq = x.var(dim=[2, 3], keepdim=True, unbiased=False)
        sigma = torch.sqrt(sigma_sq + self.eps)

        noise = self._local_noise_estimate(x)
        lam = self.lambda_param.clamp(1e-3, 1.0)
        t = lam * noise / (sigma_sq + self.eps)
        effective_std = sigma * torch.exp(0.5 * self.psi_function(t))

        output = self.weight * ((x - mu) / (effective_std + self.eps)) + self.bias
        if not return_uncertainty:
            return output
        precision = 1.0 / (effective_std**2 + self.eps)
        return output, precision

# ============================================================
# Model definition (same as before)
# ============================================================
class EfficientCNN(nn.Module):
    """Three-stage conv net with a configurable normalization layer.

    Each stage is conv 3x3 -> normalization -> Tanh -> 2x2 max-pool, with
    16, 32 and 64 channels. The stages are followed by global average
    pooling and a linear classifier.

    Args:
        norm_type: ``'layer'`` (StandardLayerNorm), ``'r_layer'``
            (RLayerNorm) or ``'bayesian_r_layer'`` (BayesianRLayerNorm).
        num_classes: number of output logits.

    Raises:
        ValueError: if ``norm_type`` is not one of the names above.
    """

    def __init__(self, norm_type='layer', num_classes=10):
        super().__init__()
        self.norm_type = norm_type
        # Registration order is kept stable so saved state_dict keys match
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.norm1 = self._create_norm(16)
        self.act1 = nn.Tanh()
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.norm2 = self._create_norm(32)
        self.act2 = nn.Tanh()
        self.pool2 = nn.MaxPool2d(2)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.norm3 = self._create_norm(64)
        self.act3 = nn.Tanh()
        self.pool3 = nn.MaxPool2d(2)

        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(64, num_classes)

    def _create_norm(self, num_features):
        """Instantiate the normalization layer selected by ``self.norm_type``."""
        # Builders are lambdas so the layer classes are only looked up when chosen
        builders = (
            ('layer', lambda: StandardLayerNorm(num_features)),
            ('r_layer', lambda: RLayerNorm(num_features)),
            ('bayesian_r_layer', lambda: BayesianRLayerNorm(num_features)),
        )
        for key, build in builders:
            if self.norm_type == key:
                return build()
        raise ValueError(f"Unknown norm_type: {self.norm_type}")

    def forward(self, x, return_uncertainty=False):
        """Return logits, or ``(logits, uncertainty)`` when requested.

        The uncertainty tensor is only produced when ``return_uncertainty`` is
        true *and* ``norm3`` is a ``BayesianRLayerNorm``. Otherwise only the
        logits are returned.
        """
        out = self.pool1(self.act1(self.norm1(self.conv1(x))))
        out = self.pool2(self.act2(self.norm2(self.conv2(out))))
        out = self.conv3(out)

        with_unc = return_uncertainty and isinstance(self.norm3, BayesianRLayerNorm)
        if with_unc:
            out, unc = self.norm3(out, return_uncertainty=True)
        else:
            out = self.norm3(out)

        out = self.pool3(self.act3(out))
        out = self.global_pool(out).view(out.size(0), -1)
        logits = self.fc(out)
        return (logits, unc) if with_unc else logits

# ============================================================
# Dataset with on‑the‑fly corruption (full CIFAR‑10)
# ============================================================
class OnlineNoisyCIFAR10(Dataset):
    """CIFAR-10 with a synthetic corruption applied on the fly to every image.

    Each ``__getitem__`` applies the corruption, clips to [0, 1], converts to
    a CHW float tensor and normalizes with the CIFAR-10 channel mean/std.
    The stochastic corruptions draw from NumPy's global RNG, so every access
    sees fresh noise.

    Corruptions (``noise_type``) and how ``severity`` is used:
      * ``'gaussian'``   - additive Gaussian noise with std ``0.1 * severity``.
      * ``'shot_noise'`` - NOTE: implemented as salt-and-pepper noise. Each
        array element (per pixel *and* channel) is replaced, with probability
        ``0.05 * severity``, by 0 or 1 with equal chance. This differs from
        the Poisson shot noise used by CIFAR-10-C.
      * ``'blur'``       - PIL Gaussian blur with radius ``0.5 * severity``.
      * ``'contrast'``   - scales deviations from the image mean by
        ``max(0.5, 1 - 0.2 * severity)``; all severities >= 3 give 0.5.

    Raises:
        ValueError: if ``noise_type`` is not one of ``NOISE_TYPES``.
    """

    NOISE_TYPES = ('gaussian', 'shot_noise', 'blur', 'contrast')

    def __init__(self, root='./data', train=True, noise_type='gaussian', severity=3,
                 transform=None, download=True):
        # Validate first so a typo fails fast instead of silently yielding clean images
        if noise_type not in self.NOISE_TYPES:
            raise ValueError(
                f"Unknown noise_type: {noise_type!r}; expected one of {self.NOISE_TYPES}")
        self.clean = CIFAR10(root=root, train=train, download=download, transform=None)
        self.noise_type = noise_type
        self.severity = severity
        # Built once instead of on every __getitem__ call
        self._normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        # ToTensor + Normalize, kept as a public attribute for compatibility;
        # __getitem__ does not use it because apply_corruption already returns a tensor.
        self.base_transform = transforms.Compose([transforms.ToTensor(), self._normalize])
        self.transform = transform

    def __len__(self):
        return len(self.clean)

    def apply_corruption(self, img_pil):
        """Corrupt one image and return it as a float CHW tensor in [0, 1].

        Args:
            img_pil: a PIL image (required for ``'blur'``, which calls
                ``.filter``). The other corruptions accept anything
                ``np.array`` converts to an HxWxC array of 0-255 values
                (assumed layout - confirm for non-CIFAR inputs).
        """
        img_np = np.array(img_pil).astype(np.float32) / 255.0

        if self.noise_type == 'gaussian':
            img_np = img_np + np.random.randn(*img_np.shape) * 0.1 * self.severity
        elif self.noise_type == 'shot_noise':
            # Salt-and-pepper: pick elements, then set each to 0 or 1
            mask = np.random.random(img_np.shape) < 0.05 * self.severity
            salt = np.random.random(mask.sum()) > 0.5
            flat = img_np.reshape(-1)
            flat[mask.reshape(-1)] = salt.astype(np.float32)
            img_np = flat.reshape(img_np.shape)
        elif self.noise_type == 'blur':
            blurred = img_pil.filter(ImageFilter.GaussianBlur(radius=self.severity * 0.5))
            img_np = np.array(blurred).astype(np.float32) / 255.0
        elif self.noise_type == 'contrast':
            mean = img_np.mean()
            contrast_factor = max(0.5, 1.0 - 0.2 * self.severity)
            img_np = contrast_factor * (img_np - mean) + mean
        else:
            raise ValueError(f"Unknown noise_type: {self.noise_type!r}")

        img_np = np.clip(img_np, 0, 1)
        return torch.from_numpy(img_np.transpose(2, 0, 1)).float()

    def __getitem__(self, idx):
        img_pil, label = self.clean[idx]
        img_tensor = self._normalize(self.apply_corruption(img_pil))
        if self.transform is not None:
            img_tensor = self.transform(img_tensor)
        return img_tensor, label

# ============================================================
# Dataset for CIFAR-10-C (if available)
# ============================================================
class CIFAR10C_Dataset(Dataset):
    """Load one CIFAR-10-C corruption type at a single severity level.

    Expects ``root`` to contain::

        gaussian_noise.npy
        shot_noise.npy
        gaussian_blur.npy   (maps to 'blur' in our code)
        contrast.npy
        labels.npy

    Each image file holds the five severity levels concatenated along the
    first axis. For the official release this is 5 x 10000 = 50000 images of
    32x32x3 with 0-255 pixel values. ``labels.npy`` holds the matching 50000
    labels. Severity ``s`` (1-5) selects the ``s``-th of the five
    equal-sized blocks.

    Items are ``(image, label)``:

    * ``image`` is a CHW float tensor, scaled to [0, 1] and normalized with
      the CIFAR-10 channel mean/std, then passed through ``transform`` if
      one is given;
    * ``label`` is a Python int.

    Raises:
        ValueError: unknown ``corruption_type``, ``severity`` outside 1-5, or
            image/label files whose lengths do not match or do not split into
            five equal blocks.
        FileNotFoundError: the corruption or labels file is missing.
    """

    NUM_SEVERITIES = 5
    _NAME_MAP = {
        'gaussian': 'gaussian_noise.npy',
        'shot_noise': 'shot_noise.npy',
        'blur': 'gaussian_blur.npy',
        'contrast': 'contrast.npy',
    }
    _MEAN = (0.4914, 0.4822, 0.4465)
    _STD = (0.2023, 0.1994, 0.2010)

    def __init__(self, root, corruption_type, severity=3, transform=None):
        self.root = root
        self.corruption_type = corruption_type
        self.severity = severity
        self.transform = transform

        file_name = self._NAME_MAP.get(corruption_type)
        if file_name is None:
            raise ValueError(f"Corruption type {corruption_type} not available in CIFAR-10-C")
        if not 1 <= severity <= self.NUM_SEVERITIES:
            raise ValueError(
                f"severity must be between 1 and {self.NUM_SEVERITIES}, got {severity}")

        data_path = os.path.join(root, file_name)
        labels_path = os.path.join(root, 'labels.npy')
        if not os.path.exists(data_path) or not os.path.exists(labels_path):
            raise FileNotFoundError(f"CIFAR-10-C files not found in {root}")

        # Memory-map so only the requested severity block is read into RAM
        all_images = np.load(data_path, mmap_mode='r')
        all_labels = np.load(labels_path, mmap_mode='r')
        if len(all_images) != len(all_labels) or len(all_images) % self.NUM_SEVERITIES:
            raise ValueError(
                f"Expected {self.NUM_SEVERITIES} equal severity blocks with matching labels; "
                f"got {len(all_images)} images and {len(all_labels)} labels")

        block = len(all_images) // self.NUM_SEVERITIES
        start = (severity - 1) * block
        end = severity * block
        # np.array copies the slice into ordinary memory (detached from the mmap)
        self.images = np.array(all_images[start:end], dtype=np.float32) / 255.0
        self.labels = np.array(all_labels[start:end])

        # Same arithmetic as torchvision's Normalize: (img - mean) / std per channel
        self._mean = torch.tensor(self._MEAN, dtype=torch.float32).view(-1, 1, 1)
        self._std = torch.tensor(self._STD, dtype=torch.float32).view(-1, 1, 1)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img = torch.from_numpy(self.images[idx].transpose(2, 0, 1)).float()  # HWC -> CHW
        img = (img - self._mean) / self._std
        if self.transform is not None:
            img = self.transform(img)
        # Plain int so the default collate produces an int64 target tensor
        return img, int(self.labels[idx])

# ============================================================
# Trainer with validation on clean set
# ============================================================
class Trainer:
    """Train one ``EfficientCNN`` per call and evaluate it on several test sets.

    Args:
        device: torch device that models and batches are moved to.
        run_seed: seed applied through ``set_seed`` when the trainer is built.
    """

    def __init__(self, device, run_seed):
        self.device = device
        self.seed = run_seed
        set_seed(run_seed)

    def clear_memory(self):
        """Run the garbage collector, then release cached CUDA blocks.

        Collecting first lets tensors that became unreachable free their blocks
        before the cache is emptied. The CUDA step is skipped when no GPU is
        available.
        """
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    @staticmethod
    def _snapshot_state(model):
        """Return a detached copy of ``model.state_dict()``.

        ``state_dict()`` holds references to the live parameter tensors, so a
        shallow ``dict.copy()`` keeps changing as the optimizer updates the
        model. Cloning freezes the values at snapshot time.
        """
        return {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}

    def _train_one_epoch(self, model, loader, criterion, optimizer):
        """Run one optimisation pass over ``loader``; return the mean batch loss."""
        model.train()
        running_loss = 0.0
        for inputs, targets in loader:
            inputs, targets = inputs.to(self.device), targets.to(self.device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), targets)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        return running_loss / len(loader)

    def _validate(self, model, loader, criterion):
        """Return ``(mean batch loss, accuracy in %)`` of ``model`` on ``loader``."""
        model.eval()
        total_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                outputs = model(inputs)
                total_loss += criterion(outputs, targets).item()
                _, pred = outputs.max(1)
                total += targets.size(0)
                correct += pred.eq(targets).sum().item()
        return total_loss / len(loader), 100. * correct / total

    def train_and_evaluate(self, norm_type, train_loader, val_loader, test_loaders,
                           epochs=50, lr=0.001):
        """Train a fresh model and test the weights with the lowest validation loss.

        Args:
            norm_type: passed to ``EfficientCNN``.
            train_loader: batches used for Adam updates.
            val_loader: batches used for per-epoch validation and checkpoint
                selection.
            test_loaders: ``{name: loader}`` evaluated after training.
            epochs: number of training epochs.
            lr: Adam learning rate.

        Returns:
            ``(model, results, train_losses, val_losses, val_accs)``:

            * ``model`` holds the best-validation-loss weights (or the final
              weights if no epoch produced a finite improvement);
            * ``results`` maps each test-loader name to accuracy in %;
            * the three lists contain one value per epoch.
        """
        model = EfficientCNN(norm_type=norm_type).to(self.device)
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.Adam(model.parameters(), lr=lr)

        train_losses, val_losses, val_accs = [], [], []
        best_val_loss = float('inf')
        best_model_state = None

        for epoch in range(1, epochs + 1):
            train_loss = self._train_one_epoch(model, train_loader, criterion, optimizer)
            val_loss, val_acc = self._validate(model, val_loader, criterion)
            train_losses.append(train_loss)
            val_losses.append(val_loss)
            val_accs.append(val_acc)

            print(f"Epoch {epoch:2d}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Must clone: a shallow copy would silently track later updates
                best_model_state = self._snapshot_state(model)

            if epoch % 10 == 0:
                self.clear_memory()

        # No snapshot exists if epochs == 0 or every validation loss was NaN
        if best_model_state is not None:
            model.load_state_dict(best_model_state)

        results = {name: self.evaluate(model, loader) for name, loader in test_loaders.items()}
        return model, results, train_losses, val_losses, val_accs

    def evaluate(self, model, loader):
        """Return the top-1 accuracy (in %) of ``model`` on ``loader``."""
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, targets in loader:
                inputs, targets = inputs.to(self.device), targets.to(self.device)
                _, pred = model(inputs).max(1)
                total += targets.size(0)
                correct += pred.eq(targets).sum().item()
        return 100. * correct / total

# ============================================================
# Prepare data loaders (with optional CIFAR-10-C test set)
# ============================================================
def get_loaders(seed, test_samples_per_noise=10000,
                cifar10_root='/kaggle/working/data',
                cifar10c_path='/kaggle/input/cifar-10/CIFAR-10-C',
                batch_size=64, num_workers=2):
    """Build the corrupted-training, clean-validation and per-corruption test loaders.

    Data setup:

    * A seeded 10% of the CIFAR-10 training set becomes the clean
      validation split.
    * The remaining 90% is used four times, once per corruption type, with
      severity-3 corruptions applied on the fly.
    * Test data comes from CIFAR-10-C (severity 3) when all required files
      exist under ``cifar10c_path``; otherwise the CIFAR-10 test set is
      corrupted on the fly.

    Args:
        seed: seed for the split and for optional test subsampling.
        test_samples_per_noise: if truthy and smaller than a test set, a
            random subset of this size is used.
        cifar10_root: where CIFAR-10 is stored / downloaded.
        cifar10c_path: directory holding the CIFAR-10-C ``.npy`` files.
        batch_size: batch size for every loader.
        num_workers: DataLoader worker processes for every loader.

    Returns:
        ``(train_loader, val_loader, test_loaders)`` where ``test_loaders``
        maps each corruption name to its DataLoader.
    """
    set_seed(seed)

    norm_transform = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    clean_transform = transforms.Compose([transforms.ToTensor(), norm_transform])

    # One clean copy of the training set gives both the split size and the
    # validation samples; no separate ToTensor-only copy is needed.
    clean_train = CIFAR10(root=cifar10_root, train=True, download=True, transform=clean_transform)
    n_train = len(clean_train)
    indices = list(range(n_train))
    np.random.shuffle(indices)
    split = int(0.1 * n_train)
    val_idx, train_idx = indices[:split], indices[split:]
    val_dataset = Subset(clean_train, val_idx)

    # Training set: one on-the-fly corrupted copy of the training split per corruption
    noise_types = ['gaussian', 'shot_noise', 'blur', 'contrast']
    train_dataset = ConcatDataset([
        Subset(OnlineNoisyCIFAR10(root=cifar10_root, train=True, noise_type=nt,
                                  severity=3, download=True), train_idx)
        for nt in noise_types
    ])

    use_cifar10c = os.path.isdir(cifar10c_path) and all(
        os.path.exists(os.path.join(cifar10c_path, f)) for f in
        ['gaussian_noise.npy', 'shot_noise.npy', 'gaussian_blur.npy', 'contrast.npy', 'labels.npy']
    )

    pin_memory = torch.cuda.is_available()

    def _make_loader(dataset, shuffle):
        """Wrap ``dataset`` in a DataLoader with the shared settings."""
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                          num_workers=num_workers, pin_memory=pin_memory)

    if use_cifar10c:
        print(f"Using CIFAR-10-C from {cifar10c_path}")
    else:
        print("CIFAR-10-C not found, using on-the-fly corruption for test set.")

    test_loaders = {}
    for nt in noise_types:
        if use_cifar10c:
            test_ds = CIFAR10C_Dataset(root=cifar10c_path, corruption_type=nt, severity=3)
        else:
            test_ds = OnlineNoisyCIFAR10(root=cifar10_root, train=False, noise_type=nt,
                                         severity=3, download=True)
        if test_samples_per_noise and test_samples_per_noise < len(test_ds):
            subset_idx = np.random.choice(len(test_ds), test_samples_per_noise, replace=False)
            test_ds = Subset(test_ds, subset_idx)
        test_loaders[nt] = _make_loader(test_ds, shuffle=False)

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    return train_loader, val_loader, test_loaders

# ============================================================
# Run multiple seeds and collect results
# ============================================================
def run_experiment(seeds=(42, 123, 456), epochs=50, test_samples=10000):
    """Train every normalization variant for each seed and summarise test accuracy.

    Args:
        seeds: iterable of integer seeds. Each seed gets its own data split
            and loaders, shared by all normalization variants.
        epochs: training epochs per model.
        test_samples: maximum test images per corruption (forwarded to
            ``get_loaders`` as ``test_samples_per_noise``).

    Returns:
        ``(all_results, all_train_metrics, summary, norm_names)`` where:

        * ``all_results[norm][i]`` is the ``{noise: accuracy}`` dict of the
          i-th seed;
        * ``all_train_metrics[norm][i]`` is
          ``(train_losses, val_losses, val_accs)``;
        * ``summary[norm][noise]`` is ``(mean, std)`` across seeds, using the
          population standard deviation (``np.std`` with ``ddof=0``);
        * ``norm_names`` maps each norm type to its display name.

    Side effects: saves each trained model's state dict to
    ``/kaggle/working/models/{norm_type}_seed{seed}.pth``.
    """
    norm_types = ['layer', 'r_layer', 'bayesian_r_layer']
    norm_names = {'layer': 'LayerNorm', 'r_layer': 'R-LayerNorm', 'bayesian_r_layer': 'Bayesian R-LayerNorm'}

    model_dir = '/kaggle/working/models'
    os.makedirs(model_dir, exist_ok=True)

    all_results = {nt: [] for nt in norm_types}
    all_train_metrics = {nt: [] for nt in norm_types}

    for seed in seeds:
        print(f"\n{'='*60}\nRUNNING SEED {seed}\n{'='*60}")
        set_seed(seed)
        trainer = Trainer(device, seed)

        train_loader, val_loader, test_loaders = get_loaders(seed, test_samples_per_noise=test_samples)

        for norm_type in norm_types:
            # Reseed so every variant starts from the same RNG state (model
            # initialisation and DataLoader shuffling draw from it). Without
            # this, later variants inherit a state advanced by earlier runs and
            # the per-seed comparison is no longer paired.
            set_seed(seed)
            print(f"\n--- Training {norm_type} ---")
            model, results, tr_loss, val_loss, val_acc = trainer.train_and_evaluate(
                norm_type, train_loader, val_loader, test_loaders, epochs=epochs, lr=0.001)

            model_path = os.path.join(model_dir, f"{norm_type}_seed{seed}.pth")
            torch.save(model.state_dict(), model_path)
            print(f"Model saved to {model_path}")

            all_results[norm_type].append(results)
            all_train_metrics[norm_type].append((tr_loss, val_loss, val_acc))
            trainer.clear_memory()

    summary = {}
    for norm_type in norm_types:
        # The four standard corruptions are always present; any extra test
        # loader names are picked up as well instead of raising KeyError.
        noise_accs = {noise: [] for noise in ['gaussian', 'shot_noise', 'blur', 'contrast']}
        for seed_results in all_results[norm_type]:
            for noise, acc in seed_results.items():
                noise_accs.setdefault(noise, []).append(acc)
        summary[norm_type] = {
            noise: (np.mean(acc_list), np.std(acc_list))
            for noise, acc_list in noise_accs.items()
        }

    return all_results, all_train_metrics, summary, norm_names

# ============================================================
# Plotting and saving results (no PDF)
# ============================================================
def generate_plots_and_save_results(all_results, all_train_metrics, summary, norm_names, seeds, epochs):
    """Write training-curve and accuracy plots plus a CSV summary to /kaggle/working/.

    Args:
        all_results: per-normalization list of per-seed ``{noise: accuracy}``
            dicts. Not used here; error bars come from ``summary``.
        all_train_metrics: per-normalization list of per-seed
            ``(train_losses, val_losses, val_accs)`` tuples.
        summary: ``{norm_type: {noise: (mean, std)}}`` test accuracies.
        norm_names: ``{norm_type: display name}``. Its order sets the plot
            order, and it must include ``'layer'``, ``'r_layer'`` and
            ``'bayesian_r_layer'``.
        seeds: the seeds that were run. The first ``len(seeds)`` runs of each
            normalization are plotted.
        epochs: number of training epochs requested. Kept for the call
            signature; the x-axis uses the recorded curve length.

    Outputs (under /kaggle/working/):

    * ``visualizations/training_curves.png``
    * ``visualizations/performance_comparison.png``
    * ``visualizations/improvement_chart.png``
    * ``results_summary.csv``
    """
    out_dir = '/kaggle/working/'
    viz_dir = os.path.join(out_dir, 'visualizations')
    os.makedirs(viz_dir, exist_ok=True)

    colors = {'layer': '#3498db', 'r_layer': '#2ecc71', 'bayesian_r_layer': '#e74c3c'}
    line_styles = {'layer': '-', 'r_layer': '--', 'bayesian_r_layer': '-.'}
    noise_types = ['gaussian', 'shot_noise', 'blur', 'contrast']
    noise_display = [n.replace('_', ' ').title() for n in noise_types]
    n_runs = len(seeds)

    def _plot_band(metric_idx, nt):
        """Plot mean ± std over runs of one recorded per-epoch metric for ``nt``."""
        curves = np.array([all_train_metrics[nt][i][metric_idx] for i in range(n_runs)])
        mean, std = curves.mean(axis=0), curves.std(axis=0)
        # x-axis derived from the recorded curve so it always matches its length
        epochs_range = range(1, len(mean) + 1)
        plt.plot(epochs_range, mean, label=norm_names[nt], color=colors[nt], linestyle=line_styles[nt])
        plt.fill_between(epochs_range, mean - std, mean + std, color=colors[nt], alpha=0.2)

    # 1. Training curves (mean ± std)
    plt.figure(figsize=(14, 5))
    plt.subplot(1, 2, 1)
    for nt in norm_names:
        _plot_band(0, nt)
    plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.title('Training Loss (mean ± std)'); plt.legend(); plt.grid(alpha=0.3)

    plt.subplot(1, 2, 2)
    for nt in norm_names:
        _plot_band(2, nt)
    plt.xlabel('Epoch'); plt.ylabel('Accuracy (%)'); plt.title('Validation Accuracy (mean ± std)'); plt.legend(); plt.grid(alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(viz_dir, 'training_curves.png'), dpi=150)
    plt.close()

    # 2. Bar chart with error bars
    plt.figure(figsize=(12, 6))
    x = np.arange(len(noise_types))
    width = 0.25
    n_norms = len(norm_names)
    for i, nt in enumerate(norm_names):
        means = [summary[nt][noise][0] for noise in noise_types]
        stds = [summary[nt][noise][1] for noise in noise_types]
        offset = (i - (n_norms - 1) / 2) * width  # centre the group on each tick
        plt.bar(x + offset, means, width, yerr=stds, capsize=3,
                label=norm_names[nt], color=colors[nt], alpha=0.8)
    plt.xlabel('Noise Type'); plt.ylabel('Accuracy (%)')
    plt.title('Test Accuracy on Corruptions (mean ± std)')
    plt.xticks(x, noise_display, rotation=45, ha='right')
    plt.legend(); plt.grid(alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig(os.path.join(viz_dir, 'performance_comparison.png'), dpi=150)
    plt.close()

    # 3. Improvement over baseline
    plt.figure(figsize=(10, 6))
    baseline_means = [summary['layer'][noise][0] for noise in noise_types]
    baseline_stds = [summary['layer'][noise][1] for noise in noise_types]
    variants = ['r_layer', 'bayesian_r_layer']
    for k, nt in enumerate(variants):
        means = [summary[nt][noise][0] for noise in noise_types]
        stds = [summary[nt][noise][1] for noise in noise_types]
        imp = [m - b for m, b in zip(means, baseline_means)]
        # Error bars treat variant and baseline runs as independent
        imp_std = [np.sqrt(s**2 + bs**2) for s, bs in zip(stds, baseline_stds)]
        offset = (k - (len(variants) - 1) / 2) * width  # bars centred on each tick
        plt.bar(x + offset, imp, width, yerr=imp_std, capsize=3,
                label=norm_names[nt], color=colors[nt], alpha=0.8)
    plt.axhline(0, color='black', linestyle='-', alpha=0.3)
    plt.xlabel('Noise Type'); plt.ylabel('Accuracy Improvement (%)')
    plt.title('Improvement Over Standard LayerNorm')
    plt.xticks(x, noise_display, rotation=45, ha='right')
    plt.legend(); plt.grid(alpha=0.3, axis='y')
    plt.tight_layout()
    plt.savefig(os.path.join(viz_dir, 'improvement_chart.png'), dpi=150)
    plt.close()

    # 4. Save summary as CSV
    rows = [
        {
            'Normalization': nice_name,
            'Noise Type': noise,
            'Mean Accuracy': summary[norm_type][noise][0],
            'Std Accuracy': summary[norm_type][noise][1],
        }
        for norm_type, nice_name in norm_names.items()
        for noise in noise_types
    ]
    df = pd.DataFrame(rows)
    csv_path = os.path.join(out_dir, 'results_summary.csv')
    df.to_csv(csv_path, index=False)
    print(f"Results saved to {csv_path}")

    print(f"\n✅ All plots saved in {viz_dir}")

# ============================================================
# Main execution
# ============================================================
if __name__ == "__main__":
    # Experiment configuration
    SEEDS = [42, 123, 456]      # three independent runs
    EPOCHS = 50                 # can be increased
    TEST_SAMPLES = 10000        # per corruption; 10000 uses every test image
    banner = "=" * 60

    print("\n" + banner)
    print("STARTING EXPERIMENT ON KAGGLE")
    print(f"Seeds: {SEEDS}, Epochs: {EPOCHS}, Test samples per corruption: {TEST_SAMPLES}")
    print(banner)

    all_results, all_train_metrics, summary, norm_names = run_experiment(
        seeds=SEEDS, epochs=EPOCHS, test_samples=TEST_SAMPLES)
    generate_plots_and_save_results(
        all_results, all_train_metrics, summary, norm_names, SEEDS, EPOCHS)

    print("\n" + banner)
    print("EXPERIMENT COMPLETE. Check /kaggle/working/ for outputs.")
    print(banner)

Using device: cuda
GPU: Tesla P100-PCIE-16GB
GPU Memory: 17.1 GB

STARTING EXPERIMENT ON KAGGLE
Seeds: [42, 123, 456], Epochs: 50, Test samples per corruption: 10000

RUNNING SEED 42


100%|██████████| 170M/170M [00:17<00:00, 9.79MB/s] 


Using CIFAR-10-C from /kaggle/input/cifar-10/CIFAR-10-C

--- Training layer ---
Epoch  1/50 | Train Loss: 1.7386 | Val Loss: 1.4678 | Val Acc: 47.28%
Epoch  2/50 | Train Loss: 1.4443 | Val Loss: 1.3541 | Val Acc: 50.90%
Epoch  3/50 | Train Loss: 1.3487 | Val Loss: 1.2699 | Val Acc: 54.26%
Epoch  4/50 | Train Loss: 1.2917 | Val Loss: 1.2131 | Val Acc: 57.24%
Epoch  5/50 | Train Loss: 1.2532 | Val Loss: 1.1716 | Val Acc: 58.02%
Epoch  6/50 | Train Loss: 1.2207 | Val Loss: 1.1611 | Val Acc: 58.68%
Epoch  7/50 | Train Loss: 1.1954 | Val Loss: 1.1319 | Val Acc: 59.60%
Epoch  8/50 | Train Loss: 1.1763 | Val Loss: 1.1219 | Val Acc: 59.90%
Epoch  9/50 | Train Loss: 1.1575 | Val Loss: 1.1225 | Val Acc: 60.20%
Epoch 10/50 | Train Loss: 1.1421 | Val Loss: 1.1115 | Val Acc: 60.56%
Epoch 11/50 | Train Loss: 1.1308 | Val Loss: 1.0960 | Val Acc: 61.34%
Epoch 12/50 | Train Loss: 1.1178 | Val Loss: 1.0865 | Val Acc: 61.70%
Epoch 13/50 | Train Loss: 1.1074 | Val Loss: 1.0889 | Val Acc: 61.74%
Epoch 14/5