In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np


In [14]:
# Prepare CIFAR-10: ToTensor scales pixels to [0, 1], then Normalize with
# mean = std = 0.5 per RGB channel maps them to [-1, 1].
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Load both splits (downloaded into ./data on first run); the held-out
# split (train=False) serves as the validation set.
train_dataset, val_dataset = (
    datasets.CIFAR10(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

# Dataloaders: reshuffle the training data on every pass, keep validation order fixed.
batch_size = 64
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, shuffle=False, batch_size=batch_size)

Files already downloaded and verified
Files already downloaded and verified


In [15]:
class SimpleCNN(nn.Module):
    """Three-block convolutional classifier for 32x32 RGB images (CIFAR-10 here).

    Each block is Conv3x3(padding=1) -> BatchNorm2d (or identity) -> ReLU ->
    MaxPool2x2, so a 32x32 input shrinks to 4x4 after three blocks. A
    three-layer fully-connected head then produces class logits.

    Args:
        use_batch_norm: if True, insert ``BatchNorm2d`` after every conv;
            otherwise use ``nn.Identity`` so both variants share one forward.
        num_classes: number of output logits (default 10).
    """

    def __init__(self, use_batch_norm=True, num_classes=10):
        super().__init__()
        self.use_batch_norm = use_batch_norm

        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32) if use_batch_norm else nn.Identity()
        # Shared 2x2 pool; halves height and width after each conv block.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64) if use_batch_norm else nn.Identity()

        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(128) if use_batch_norm else nn.Identity()

        # 128 channels * 4 * 4 spatial: valid only for 32x32 inputs.
        self.fc1 = nn.Linear(128 * 4 * 4, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return raw (unnormalized) class logits of shape (batch, num_classes).

        Args:
            x: float tensor of shape (batch, 3, 32, 32). Other spatial sizes
                will not match the input width of ``fc1``.
        """
        x = self.pool(self.relu(self.bn1(self.conv1(x))))
        x = self.pool(self.relu(self.bn2(self.conv2(x))))
        x = self.pool(self.relu(self.bn3(self.conv3(x))))
        x = torch.flatten(x, 1)  # (batch, 128 * 4 * 4)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)


In [19]:
# Defining params
learning_rate = 0.001
epochs = 50
patience = 5  # Early stopping: consecutive non-improving val-loss checks tolerated
seed = 42     # seeds torch's global RNG so weight initialization is reproducible

# Model instantiation. Reseeding before each constructor gives both models the
# same conv/linear starting weights (BatchNorm layers draw no random numbers at
# init), so the only difference between the two runs is the presence of BN.
torch.manual_seed(seed)
model_with_bn = SimpleCNN(use_batch_norm=True)
torch.manual_seed(seed)
model_without_bn = SimpleCNN(use_batch_norm=False)
assert torch.equal(model_with_bn.conv1.weight, model_without_bn.conv1.weight)
assert torch.equal(model_with_bn.fc3.weight, model_without_bn.fc3.weight)

# Loss and optimizer: one shared loss, a separate Adam optimizer per model.
criterion = nn.CrossEntropyLoss()
optimizer_with_bn = optim.Adam(model_with_bn.parameters(), lr=learning_rate)
optimizer_without_bn = optim.Adam(model_without_bn.parameters(), lr=learning_rate)

# Early stopping utility
class EarlyStopping:
    """Track validation loss and signal when training should stop.

    ``step`` returns True while training should continue, and False once the
    loss has failed to improve for ``patience`` consecutive calls.

    Args:
        patience: number of consecutive non-improving calls tolerated before
            ``step`` returns False.
        min_delta: minimum decrease below the best loss that counts as an
            improvement (default 0.0, so any strict decrease counts).
    """

    def __init__(self, patience=5, min_delta=0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float('inf')
        self.counter = 0  # consecutive calls without improvement

    def step(self, val_loss):
        """Record ``val_loss``; return True to keep training, False to stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            return True
        self.counter += 1
        return self.counter < self.patience

