In [1]:
"""
Utils for neural networks
"""

from torch import nn


def init_weights_(m):
    """
    Initializes the parameters of a single module in place.

    - ``nn.Conv2d`` / ``nn.Linear``: weight drawn with Xavier (Glorot) normal
      initialization, bias (if present) set to 0.
    - ``nn.BatchNorm1d`` / ``nn.BatchNorm2d``: weight set to 1 and bias set to 0.
      BatchNorm layers built with ``affine=False`` have no weight/bias and are
      left untouched instead of crashing.
    - Any other module is ignored.

    Typically applied to every submodule, e.g. ``for m in model.modules(): init_weights_(m)``.

    :param m: module to initialize (modified in place)
    :return: None
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.xavier_normal_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
        # With affine=False the layer registers weight/bias as None.
        if m.weight is not None:
            nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [2]:
"""
Data loaders
"""
import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import torchvision.datasets as datasets



class PartialDataset(Dataset):
    """Read-only view exposing only the first ``n_items`` samples of another dataset.

    ``len()`` is ``min(n_items, len(dataset))``. Indices (negative ones counted
    from the end of the view) are forwarded to the wrapped dataset after a
    bounds check against the truncated length.

    :param dataset: map-style dataset supporting ``len()`` and integer indexing
    :param n_items: maximum number of leading samples to expose
    """

    def __init__(self, dataset, n_items=10):
        self.dataset = dataset
        self.n_items = n_items

    def __getitem__(self, index):
        size = len(self)
        if index < 0:
            index += size
        # Reject indices past n_items so the view never serves samples outside itself.
        if not 0 <= index < size:
            raise IndexError(f'index out of range for PartialDataset of length {size}')
        return self.dataset[index]

    def __len__(self):
        return min(self.n_items, len(self.dataset))


def get_cifar_loader(root='../data/', batch_size=128, train=True, shuffle=True, num_workers=4, n_items=-1):
    """Build a DataLoader over CIFAR-10 (downloaded into ``root`` if missing).

    Each image goes through ``ToTensor`` and then a per-channel ``Normalize`` with
    mean 0.5 and std 0.5.

    :param root: directory holding (or receiving) the CIFAR-10 files
    :param batch_size: samples per batch
    :param train: True for the training split, False for the test split
    :param shuffle: forwarded to ``DataLoader``
    :param num_workers: forwarded to ``DataLoader``
    :param n_items: if positive, only the first ``n_items`` samples are used
    :return: a ``DataLoader`` over the (possibly truncated) dataset
    """
    pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])

    cifar = datasets.CIFAR10(root=root, train=train, download=True, transform=pipeline)
    source = PartialDataset(cifar, n_items) if n_items > 0 else cifar

    return DataLoader(source, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)

if __name__ == '__main__':
    # Smoke test: fetch one training batch, print its first sample and save it as an image.
    train_loader = get_cifar_loader()
    X, y = next(iter(train_loader))
    print(X[0])
    print(y[0])
    print(X[0].shape)
    # CHW tensor -> HWC array for matplotlib; *0.5 + 0.5 undoes Normalize(mean=0.5, std=0.5).
    img = X[0].permute(1, 2, 0).numpy()
    plt.imshow(img*0.5 + 0.5)
    plt.savefig('sample.png')
    plt.close()  # release the figure; later cells create many more
    print(X[0].max())
    print(X[0].min())

100%|██████████| 170M/170M [00:02<00:00, 79.4MB/s] 


tensor([[[-0.3961, -0.5059, -0.5765,  ..., -0.9608, -0.9059, -0.9294],
         [-0.6784, -0.7882, -0.5765,  ..., -0.8510, -0.9216, -0.8902],
         [-0.7882, -0.9373, -0.8431,  ..., -0.7333, -0.9373, -0.8118],
         ...,
         [-0.8431, -0.8039, -0.7882,  ..., -0.6314, -0.7020, -0.7255],
         [-0.8510, -0.8275, -0.8196,  ..., -0.6627, -0.7176, -0.7569],
         [-0.8667, -0.8431, -0.8353,  ..., -0.6941, -0.7333, -0.7725]],

        [[-0.3333, -0.4902, -0.5922,  ..., -0.9373, -0.9216, -0.9373],
         [-0.6235, -0.7490, -0.5608,  ..., -0.8275, -0.9373, -0.8902],
         [-0.8196, -0.9451, -0.8353,  ..., -0.7490, -0.9451, -0.8118],
         ...,
         [-0.8275, -0.7882, -0.7725,  ..., -0.5922, -0.6549, -0.6863],
         [-0.8353, -0.8118, -0.8039,  ..., -0.6314, -0.6863, -0.7255],
         [-0.8510, -0.8275, -0.8196,  ..., -0.6627, -0.7020, -0.7490]],

        [[-0.3647, -0.6235, -0.7020,  ..., -0.9529, -0.9373, -0.9216],
         [-0.6471, -0.8039, -0.6627,  ..., -0

In [3]:
"""
VGG
"""
import numpy as np
from torch import nn

# ## Models implementation
def get_number_of_parameters(model):
    """Return the total number of scalar parameters in ``model`` as an int.

    Every tensor yielded by ``model.parameters()`` is counted, whether or not it
    requires grad. ``numel()`` also handles 0-dim parameters, for which the previous
    ``np.prod(shape)`` returned a float.

    :param model: any ``nn.Module``
    :return: int, the sum of ``numel()`` over all parameters (0 if there are none)
    """
    return sum(parameter.numel() for parameter in model.parameters())

class VGG_A_BatchNorm(nn.Module):
    """VGG-A with BatchNorm, sized for 32x32 inputs (e.g. CIFAR-10).

    Each 3x3 convolution is followed by ``BatchNorm2d`` and ReLU. The two hidden
    classifier layers are followed by ``BatchNorm1d`` and ReLU. Five 2x2 max-pools
    reduce a 32x32 input to 1x1, so the classifier expects exactly ``512 * 1 * 1``
    features. Other spatial sizes raise a shape error in the first Linear layer
    instead of being silently reshaped into a different batch size.

    Layers are created in the same order as the original hand-written
    ``nn.Sequential``. State_dict keys (``features.0.weight`` ... ``classifier.6.bias``)
    and seeded initialization are therefore unchanged.

    :param inp_ch: number of input channels
    :param num_classes: number of output logits
    :param init_weights: if True, apply ``init_weights_`` to every submodule
    """

    # Output channels of each 3x3 conv; 'M' is a 2x2 max-pool closing a stage.
    _CFG = (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')

    def __init__(self, inp_ch=3, num_classes=10, init_weights=True):
        super().__init__()

        self.features = self._make_features(inp_ch)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Linear(512, num_classes))

        if init_weights:
            self._init_weights()

    @classmethod
    def _make_features(cls, inp_ch):
        """Build the conv -> BN -> ReLU (+ max-pool) stack described by ``_CFG``."""
        layers = []
        in_ch = inp_ch
        for spec in cls._CFG:
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels=in_ch, out_channels=spec, kernel_size=3, padding=1),
                    nn.BatchNorm2d(spec),
                    nn.ReLU(True),
                ]
                in_ch = spec
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        # flatten(1) keeps the batch dimension; view(-1, 512) would fold extra
        # spatial positions into the batch for non-32x32 inputs.
        return self.classifier(x.flatten(1))

    def _init_weights(self):
        for m in self.modules():
            init_weights_(m)

class VGG_A(nn.Module):
    """VGG_A model

    Linear layers are smaller than in the paper because the input is assumed to be
    32x32x3 instead of 224x224x3. Five 2x2 max-pools reduce 32x32 to 1x1, so the
    classifier expects exactly ``512 * 1 * 1`` features. Other spatial sizes raise a
    shape error instead of being silently reshaped into a different batch size.

    Layers are created in the same order as the original hand-written
    ``nn.Sequential``. State_dict keys (``features.0.weight`` ... ``classifier.4.bias``)
    and seeded initialization are therefore unchanged.

    :param inp_ch: number of input channels
    :param num_classes: number of output logits
    :param init_weights: if True, apply ``init_weights_`` to every submodule
    """

    # Output channels of each 3x3 conv; 'M' is a 2x2 max-pool closing a stage.
    _CFG = (64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M')

    def __init__(self, inp_ch=3, num_classes=10, init_weights=True):
        super().__init__()

        self.features = self._make_features(inp_ch)

        self.classifier = nn.Sequential(
            nn.Linear(512 * 1 * 1, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes))

        if init_weights:
            self._init_weights()

    @classmethod
    def _make_features(cls, inp_ch):
        """Build the conv -> ReLU (+ max-pool) stack described by ``_CFG``."""
        layers = []
        in_ch = inp_ch
        for spec in cls._CFG:
            if spec == 'M':
                layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            else:
                layers += [
                    nn.Conv2d(in_channels=in_ch, out_channels=spec, kernel_size=3, padding=1),
                    nn.ReLU(True),
                ]
                in_ch = spec
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.features(x)
        # flatten(1) keeps the batch dimension; view(-1, 512) would fold extra
        # spatial positions into the batch for non-32x32 inputs.
        return self.classifier(x.flatten(1))

    def _init_weights(self):
        for m in self.modules():
            init_weights_(m)


class VGG_A_Light(nn.Module):
    """Reduced VGG-A: only two narrow conv stages followed by a small MLP.

    Stage 1 maps ``inp_ch -> 16`` channels and stage 2 maps ``16 -> 32``. Each stage
    is a single 3x3 conv + ReLU + 2x2 max-pool, so a 32x32 input reaches the
    classifier as ``32 * 8 * 8`` features. Other spatial sizes raise a shape error
    instead of being silently reshaped into a different batch size.

    :param inp_ch: number of input channels
    :param num_classes: number of output logits
    """

    def __init__(self, inp_ch=3, num_classes=10):
        super().__init__()

        self.stage1 = nn.Sequential(
            nn.Conv2d(in_channels=inp_ch, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.stage2 = nn.Sequential(
            nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2))

        self.classifier = nn.Sequential(
            nn.Linear(32 * 8 * 8, 128),
            nn.ReLU(),
            nn.Linear(128, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes))

    def forward(self, x):
        x = self.stage1(x)
        x = self.stage2(x)
        # flatten(1) keeps the batch dimension; view(-1, 2048) would fold extra
        # spatial positions into the batch for non-32x32 inputs.
        return self.classifier(x.flatten(1))


class VGG_A_Dropout(nn.Module):
    """VGG-A for 32x32 inputs with Dropout before the first two classifier layers.

    The convolutional layout matches ``VGG_A``, but stages are stored as separate
    attributes ``stage1`` ... ``stage5``. Each stage is (3x3 conv + ReLU) repeated,
    then a 2x2 max-pool. Weights use PyTorch's default initialization; no custom
    init is applied. A 32x32 input reaches the classifier as ``512 * 1 * 1``
    features. Other spatial sizes raise a shape error instead of being silently
    reshaped into a different batch size.

    :param inp_ch: number of input channels
    :param num_classes: number of output logits
    """

    def __init__(self, inp_ch=3, num_classes=10):
        super().__init__()

        self.stage1 = self._conv_stage(inp_ch, 64, n_convs=1)
        self.stage2 = self._conv_stage(64, 128, n_convs=1)
        self.stage3 = self._conv_stage(128, 256, n_convs=2)
        self.stage4 = self._conv_stage(256, 512, n_convs=2)
        self.stage5 = self._conv_stage(512, 512, n_convs=2)

        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(512 * 1 * 1, 512),
            nn.ReLU(True),
            nn.Dropout(),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, num_classes))

    @staticmethod
    def _conv_stage(in_ch, out_ch, n_convs):
        """(Conv3x3 + ReLU) * n_convs followed by a 2x2 max-pool, as one Sequential."""
        layers = []
        for _ in range(n_convs):
            layers += [
                nn.Conv2d(in_channels=in_ch, out_channels=out_ch, kernel_size=3, padding=1),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        return nn.Sequential(*layers)

    def forward(self, x):
        for stage in (self.stage1, self.stage2, self.stage3, self.stage4, self.stage5):
            x = stage(x)
        # flatten(1) keeps the batch dimension; view(-1, 512) would fold extra
        # spatial positions into the batch for non-32x32 inputs.
        return self.classifier(x.flatten(1))


if __name__ == '__main__':
    # Report the parameter count of each VGG-A variant, one per line.
    for model_cls in (VGG_A, VGG_A_Light, VGG_A_Dropout):
        print(get_number_of_parameters(model_cls()))

9750922
285162
9750922


In [5]:
import matplotlib as mpl
mpl.use('Agg')  # Use non-interactive backend for matplotlib
import matplotlib.pyplot as plt
from torch import nn
import numpy as np
import torch
import os
import random
from tqdm import tqdm as tqdm
import copy
import time
from collections import defaultdict

# Import your custom modules
# from models.vgg import VGG_A
# from models.vgg import VGG_A_BatchNorm
# from data.loaders import get_cifar_loader

# Initialize parameters
device_id = [0, 1]  # GPU ids handed to nn.DataParallel when more than one GPU is visible
num_workers = 4
batch_size = 128

# Kaggle output paths
figures_path = '/kaggle/working/figures'
models_path = '/kaggle/working/models'
os.makedirs(figures_path, exist_ok=True)
os.makedirs(models_path, exist_ok=True)

# Device configuration.
# CUDA_VISIBLE_DEVICES only takes effect before CUDA is initialised; setdefault
# keeps a value the user already exported instead of silently overriding it.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
print(f"GPU count: {torch.cuda.device_count()}")

# Data loaders. Shuffling the validation set changes neither loss nor accuracy,
# so it is evaluated in a fixed order.
train_loader = get_cifar_loader(train=True, batch_size=batch_size, num_workers=num_workers)
val_loader = get_cifar_loader(train=False, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Accuracy calculation
def get_accuracy(model, loader):
    """Return the top-1 accuracy of ``model`` over every batch in ``loader``.

    Switches the model to eval mode (and leaves it there), disables autograd, and
    moves each ``(inputs, targets)`` batch to the global ``device``.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            predictions = model(inputs).argmax(dim=1)
            n_seen += targets.size(0)
            n_correct += (predictions == targets).sum().item()
    return n_correct / n_seen

# Loss calculation
def get_loss(model, loader, criterion):
    """Return the sample-weighted mean of ``criterion`` over ``loader``.

    Each batch loss is scaled by its batch size before averaging. This recovers the
    per-sample mean when ``criterion`` averages within a batch (the default reduction
    of ``nn.CrossEntropyLoss``). Runs in eval mode without autograd, on the global
    ``device``.
    """
    model.eval()
    loss_sum = 0
    n_seen = 0
    with torch.no_grad():
        for inputs, targets in loader:
            inputs, targets = inputs.to(device), targets.to(device)
            n_batch = inputs.size(0)
            loss_sum += criterion(model(inputs), targets).item() * n_batch
            n_seen += n_batch
    return loss_sum / n_seen

# Seed setup
def set_random_seeds(seed_value=0):
    """Seed Python's ``random``, NumPy and PyTorch with ``seed_value``.

    When CUDA is available, also seeds every GPU and forces deterministic cuDNN
    kernels (disabling the autotuner) for reproducible runs.
    """
    random.seed(seed_value)
    np.random.seed(seed_value)
    torch.manual_seed(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed_value)
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Training function with full metrics tracking
def _unwrap_model(model):
    """Return the wrapped module when ``model`` is an ``nn.DataParallel``, else ``model``."""
    return model.module if isinstance(model, nn.DataParallel) else model


def _train_one_epoch(model, optimizer, criterion, loader, batch_losses):
    """One optimisation pass over ``loader``.

    Appends every batch loss to ``batch_losses``. Returns (sample-weighted mean
    training loss, training accuracy) for the epoch. Accuracy is measured on the
    fly while weights change, so it is a running estimate rather than a clean
    evaluation.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        outputs = model(x)
        loss = criterion(outputs, y)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_losses.append(batch_loss)
        loss_sum += batch_loss * x.size(0)
        n_correct += (outputs.detach().argmax(dim=1) == y).sum().item()
        n_seen += y.size(0)
    if n_seen == 0:
        raise ValueError('train_loader yielded no samples')
    return loss_sum / n_seen, n_correct / n_seen


def _evaluate(model, loader, criterion):
    """Single no-grad, eval-mode pass returning (sample-weighted mean loss, top-1 accuracy).

    Computes both validation metrics in one iteration of ``loader`` instead of two.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            outputs = model(x)
            loss_sum += criterion(outputs, y).item() * x.size(0)
            n_correct += (outputs.argmax(dim=1) == y).sum().item()
            n_seen += y.size(0)
    if n_seen == 0:
        raise ValueError('val_loader yielded no samples')
    return loss_sum / n_seen, n_correct / n_seen


def train(model, optimizer, criterion, train_loader, val_loader, scheduler=None,
          epochs_n=100, model_save_path=None, lr=None, model_type=None):
    """Train ``model`` for ``epochs_n`` epochs and collect per-epoch / per-batch metrics.

    The model is moved to the global ``device``. When more than one CUDA device is
    visible, it is wrapped in ``nn.DataParallel`` over ``device_id``. After every
    epoch it is evaluated on ``val_loader``, and a CPU copy of the weights from the
    epoch with the highest validation accuracy is kept. If ``model_save_path`` is
    given, that copy is written with ``torch.save`` at the end. The saved keys never
    carry DataParallel's ``module.`` prefix, so the checkpoint loads into a bare model.

    :param scheduler: optional LR scheduler, stepped once per epoch after validation
    :param epochs_n: number of epochs
    :param model_save_path: where to save the best ``state_dict`` (skipped if falsy)
    :param lr: kept for backward compatibility; the logged LR is read from
        ``optimizer.param_groups[0]['lr']``
    :param model_type: label used in progress messages
    :return: dict with per-epoch ``train_losses``, ``train_accs``, ``val_losses``,
        ``val_accs``, ``learning_rates`` (LR used during each epoch), per-batch
        ``batch_losses``, and ``best_val_acc``
    :raises ValueError: if either loader yields no samples
    """
    model.to(device)
    if torch.cuda.device_count() > 1:
        model = nn.DataParallel(model, device_ids=device_id)

    train_losses, train_accs = [], []
    val_losses, val_accs = [], []
    batch_losses = []
    learning_rates = []
    best_val_acc = 0.0
    best_model_state = None

    for epoch in range(epochs_n):
        epoch_start = time.time()
        # Read the LR before the scheduler advances it, so we log the rate this epoch used.
        current_lr = optimizer.param_groups[0]['lr']

        avg_train_loss, train_acc = _train_one_epoch(model, optimizer, criterion,
                                                     train_loader, batch_losses)
        train_losses.append(avg_train_loss)
        train_accs.append(train_acc)

        val_loss, val_acc = _evaluate(model, val_loader, criterion)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        if scheduler is not None:
            scheduler.step()
        learning_rates.append(current_lr)

        # `is None` guarantees a checkpoint exists even if accuracy never leaves 0.
        if best_model_state is None or val_acc > best_val_acc:
            best_val_acc = val_acc
            best_model_state = {k: v.detach().cpu().clone()
                                for k, v in _unwrap_model(model).state_dict().items()}

        epoch_time = time.time() - epoch_start
        print(f'[{model_type}] Epoch {epoch+1}/{epochs_n} (LR: {current_lr:.0e}) | '
              f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.4f} | '
              f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f} | '
              f'Time: {epoch_time:.1f}s')

    if model_save_path and best_model_state is not None:
        torch.save(best_model_state, model_save_path)
        print(f'Saved best model to {model_save_path} with val acc {best_val_acc:.4f}')

    return {
        'train_losses': train_losses,      # Per-epoch
        'train_accs': train_accs,          # Per-epoch
        'val_losses': val_losses,          # Per-epoch
        'val_accs': val_accs,              # Per-epoch
        'batch_losses': batch_losses,      # Per-batch
        'learning_rates': learning_rates,  # Per-epoch
        'best_val_acc': best_val_acc
    }

# Main analysis function
def _train_lr_sweep(model_cls, learning_rates, epochs, banner, file_prefix, model_type):
    """Train a freshly seeded ``model_cls`` with Adam for each LR; return ``{lr: metrics}``."""
    results = {}
    for lr in learning_rates:
        print(f"\n{'='*20} {banner} (LR: {lr:.0e}) {'='*20}")
        # Same seed for every run so only the LR (and architecture) differ.
        set_random_seeds(2020)
        model = model_cls()
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        criterion = nn.CrossEntropyLoss()
        model_path = os.path.join(models_path, f'{file_prefix}_lr_{lr}.pth')
        results[lr] = train(model, optimizer, criterion, train_loader, val_loader,
                            epochs_n=epochs, model_save_path=model_path,
                            lr=lr, model_type=model_type)
    return results


def analyze_bn_effect(learning_rates=None, epochs=20):
    """Compare VGG_A with and without BatchNorm across several Adam learning rates.

    Trains one model per (architecture, learning rate) pair, saving checkpoints under
    ``models_path``. Then writes the comparison and loss-landscape figures under
    ``figures_path``.

    :param learning_rates: LRs to sweep; defaults to ``[1e-3, 2e-3, 1e-4, 5e-4]``
    :param epochs: epochs per training run
    """
    if learning_rates is None:
        learning_rates = [1e-3, 2e-3, 1e-4, 5e-4]

    print("\n" + "="*50)
    print("Training non-BatchNorm models")
    print("="*50)
    results_non_bn = _train_lr_sweep(VGG_A, learning_rates, epochs,
                                     banner="Training without BN",
                                     file_prefix='non_bn_model', model_type="Without BN")

    print("\n" + "="*50)
    print("Training BatchNorm models")
    print("="*50)
    results_bn = _train_lr_sweep(VGG_A_BatchNorm, learning_rates, epochs,
                                 banner="Training with BN",
                                 file_prefix='bn_model', model_type="With BN")

    print("\nCreating performance comparison plots...")
    create_comparison_plots(results_non_bn, results_bn, learning_rates, epochs)

    print("\nCreating loss landscape plots...")
    plot_loss_landscape(results_non_bn, results_bn, learning_rates)

    print(f"\nAnalysis complete! All outputs saved to {os.path.dirname(figures_path)}")

# Create comprehensive comparison plots
# (metrics key, label) for the four panels of every 2x2 figure, in row-major order.
_COMPARISON_PANELS = (
    ('train_losses', 'Training Loss'),
    ('train_accs', 'Training Accuracy'),
    ('val_losses', 'Validation Loss'),
    ('val_accs', 'Validation Accuracy'),
)


def _epoch_axis(values):
    """1-based epoch numbers matching the length of a per-epoch metric list."""
    return range(1, len(values) + 1)


def _finish_panel(ax, ylabel, title):
    """Apply the shared axis labels, title, grid and legend to one panel."""
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.grid(True)
    ax.legend()


def _plot_lr_sweep(results, learning_rates, title, filename):
    """Save a 2x2 figure with one curve per learning rate in each metric panel."""
    fig, axes = plt.subplots(2, 2, figsize=(18, 12))
    fig.suptitle(title, fontsize=16)
    for ax, (key, label) in zip(axes.flat, _COMPARISON_PANELS):
        for lr in learning_rates:
            values = results[lr][key]
            ax.plot(_epoch_axis(values), values, 'o-', label=f'LR={lr:.0e}')
            if key == 'val_accs':
                # Annotate the final validation accuracy next to the last point.
                ax.annotate(f'{values[-1]:.4f}', (len(values), values[-1]),
                            textcoords="offset points", xytext=(10, 0), ha='left')
        _finish_panel(ax, label, label)
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(os.path.join(figures_path, filename), dpi=300)
    plt.close(fig)


def _plot_bn_vs_non_bn(results_non_bn, results_bn, lr):
    """Save a 2x2 head-to-head figure (without vs. with BN) at a single learning rate."""
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    fig.suptitle(f'BatchNorm vs No BatchNorm Comparison (LR={lr:.0e})', fontsize=16)
    for ax, (key, label) in zip(axes.flat, _COMPARISON_PANELS):
        for results, name in ((results_non_bn, 'Without BN'), (results_bn, 'With BN')):
            values = results[lr][key]
            ax.plot(_epoch_axis(values), values, 'o-', label=name)
        _finish_panel(ax, label, f'{label} Comparison')
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    fig.savefig(os.path.join(figures_path, 'bn_vs_non_bn_comparison.png'), dpi=300)
    plt.close(fig)


def create_comparison_plots(results_non_bn, results_bn, learning_rates, epochs, compare_lr=None):
    """Write the three performance figures into ``figures_path``.

    Writes ``non_bn_performance_comparison.png`` and ``bn_performance_comparison.png``
    (one curve per LR), and ``bn_vs_non_bn_comparison.png`` (both architectures at a
    single LR).

    :param results_non_bn: ``{lr: metrics}`` for the model without BatchNorm
    :param results_bn: ``{lr: metrics}`` for the model with BatchNorm
    :param learning_rates: LRs to plot (keys into both result dicts)
    :param epochs: kept for backward compatibility; x-axes are derived from the
        length of each recorded metric list, so mismatched lengths no longer crash
    :param compare_lr: LR for the head-to-head figure. Defaults to
        ``max(learning_rates)``, which is the largest rate, not necessarily the best one.
    """
    _plot_lr_sweep(results_non_bn, learning_rates,
                   'VGG Without BatchNorm: Performance Across Learning Rates',
                   'non_bn_performance_comparison.png')
    _plot_lr_sweep(results_bn, learning_rates,
                   'VGG With BatchNorm: Performance Across Learning Rates',
                   'bn_performance_comparison.png')
    if compare_lr is None:
        compare_lr = max(learning_rates)
    _plot_bn_vs_non_bn(results_non_bn, results_bn, compare_lr)

# Create loss landscape visualization as per project requirements
def _loss_envelope(runs):
    """Per-step (min, max) across equally long runs of batch losses, as numpy arrays."""
    stacked = np.asarray(runs, dtype=float)  # (n_runs, n_steps)
    return stacked.min(axis=0), stacked.max(axis=0)


def _save_single_landscape(steps, low, high, style, color, title, filename):
    """Save one min/max band figure into ``figures_path``."""
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.fill_between(steps, low, high, color=color, alpha=0.3)
    ax.plot(steps, low, f'{style}-', linewidth=1, label='Min Loss')
    ax.plot(steps, high, f'{style}--', linewidth=1, label='Max Loss')
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Batch Loss')
    ax.set_title(title)
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()
    fig.tight_layout()
    fig.savefig(os.path.join(figures_path, filename), dpi=300)
    plt.close(fig)


def plot_loss_landscape(results_non_bn, results_bn, learning_rates):
    """Plot the per-step min/max band of batch losses across learning rates.

    All runs are truncated to the shortest one. Writes ``non_bn_loss_landscape.png``,
    ``bn_loss_landscape.png`` and ``bn_loss_landscape_comparison.png`` into
    ``figures_path``.

    :param results_non_bn: ``{lr: metrics}`` with a ``'batch_losses'`` list per LR
    :param results_bn: same structure for the BatchNorm model
    :param learning_rates: LRs whose runs form the band
    :raises ValueError: if ``learning_rates`` is empty or no batch losses were recorded
    """
    if not learning_rates:
        raise ValueError('learning_rates must be non-empty')

    non_bn_runs = [results_non_bn[lr]['batch_losses'] for lr in learning_rates]
    bn_runs = [results_bn[lr]['batch_losses'] for lr in learning_rates]
    min_length = min(len(run) for run in non_bn_runs + bn_runs)
    if min_length == 0:
        raise ValueError('no batch losses recorded; nothing to plot')

    min_non_bn, max_non_bn = _loss_envelope([run[:min_length] for run in non_bn_runs])
    min_bn, max_bn = _loss_envelope([run[:min_length] for run in bn_runs])
    steps = np.arange(min_length)

    _save_single_landscape(steps, min_non_bn, max_non_bn, 'r', 'red',
                           'Loss Landscape Without BatchNorm', 'non_bn_loss_landscape.png')
    _save_single_landscape(steps, min_bn, max_bn, 'b', 'blue',
                           'Loss Landscape With BatchNorm', 'bn_loss_landscape.png')

    # Combined comparison
    fig, ax = plt.subplots(figsize=(12, 8))
    ax.fill_between(steps, min_non_bn, max_non_bn, color='red', alpha=0.3, label='Without BN')
    ax.plot(steps, min_non_bn, 'r-', linewidth=1, alpha=0.7)
    ax.plot(steps, max_non_bn, 'r--', linewidth=1, alpha=0.7)
    ax.fill_between(steps, min_bn, max_bn, color='blue', alpha=0.3, label='With BN')
    ax.plot(steps, min_bn, 'b-', linewidth=1, alpha=0.7)
    ax.plot(steps, max_bn, 'b--', linewidth=1, alpha=0.7)
    ax.set_xlabel('Training Steps')
    ax.set_ylabel('Batch Loss')
    ax.set_title('Batch Normalization Effect on Loss Landscape')
    ax.grid(True, linestyle='--', alpha=0.7)
    ax.legend()

    # The arrow points at the middle of the non-BN band, taking x and y from the same step.
    # NOTE: the caption is fixed text; it is not computed from the curves.
    anchor = min(int(min_length * 0.6), min_length - 1)
    ax.annotate('BN stabilizes loss landscape\nreducing variance across learning rates',
                xy=(anchor, (max_non_bn[anchor] + min_non_bn[anchor]) / 2),
                xytext=(min_length * 0.4, max_non_bn[0] * 0.8),
                arrowprops=dict(facecolor='black', shrink=0.05),
                fontsize=12)

    fig.tight_layout()
    fig.savefig(os.path.join(figures_path, 'bn_loss_landscape_comparison.png'), dpi=300)
    plt.close(fig)

# Run the analysis
if __name__ == '__main__':
    # Trains every model / learning-rate combination configured in analyze_bn_effect,
    # then writes checkpoints to models_path and figures to figures_path.
    analyze_bn_effect()

Using device: cuda
GPU count: 2

Training non-BatchNorm models

[Without BN] Epoch 1/20 (LR: 1e-03) | Train Loss: 1.7739, Train Acc: 0.3063 | Val Loss: 1.4425, Val Acc: 0.4477 | Time: 21.7s
[Without BN] Epoch 2/20 (LR: 1e-03) | Train Loss: 1.2262, Train Acc: 0.5512 | Val Loss: 1.0553, Val Acc: 0.6202 | Time: 20.2s
[Without BN] Epoch 3/20 (LR: 1e-03) | Train Loss: 0.9344, Train Acc: 0.6686 | Val Loss: 0.9420, Val Acc: 0.6737 | Time: 20.3s
[Without BN] Epoch 4/20 (LR: 1e-03) | Train Loss: 0.7700, Train Acc: 0.7320 | Val Loss: 0.8052, Val Acc: 0.7195 | Time: 19.8s
[Without BN] Epoch 5/20 (LR: 1e-03) | Train Loss: 0.6459, Train Acc: 0.7749 | Val Loss: 0.8381, Val Acc: 0.7257 | Time: 20.1s
[Without BN] Epoch 6/20 (LR: 1e-03) | Train Loss: 0.5408, Train Acc: 0.8118 | Val Loss: 0.8160, Val Acc: 0.7435 | Time: 20.3s
[Without BN] Epoch 7/20 (LR: 1e-03) | Train Loss: 0.4560, Train Acc: 0.8434 | Val Loss: 0.8076, Val Acc: 0.7421 | Time: 20.1s
[Without BN] Epoch 8/20 (LR: 1e-03) | Train Loss: 0.37