# Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torchvision import datasets, models
from torchvision import transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np

# Model

## 1. Simple CNN Model Structure

In [None]:
# Define custom CNN class
class SimpleCNN(nn.Module):
    """Two-stage convolutional classifier for 3x32x32 images (e.g. CIFAR-10).

    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, so a 32x32
    input shrinks to 16x16 and then 8x8; the resulting 64 * 8 * 8 = 4096
    features go through two fully connected layers that emit raw logits.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        # Zero-argument super(): a later notebook cell rebinds the global name
        # ``SimpleCNN``, which would make ``super(SimpleCNN, self)`` fail for
        # instances of this class.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, num_classes) logits."""
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch dimension explicit. With
        # view(-1, 4096) an input that is not 32x32 would be silently folded
        # into extra "samples"; here it fails loudly in fc1 instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

**Output Shape Explanation**

1. **Input Shape**: If the input is a single image of size 32x32 with 3 color channels, the shape would be (1, 3, 32, 32).

2. **After conv1 and pool**:
    - **Convolution**: The output from conv1 will have the shape (1, 32, 32, 32) (1 batch, 32 channels, 32 height, 32 width) because we have 32 filters and the padding is set to 1, which keeps the size the same.
    
    - **Pooling**: After applying `MaxPool2d`(2, 2), the height and width are halved, so the output shape becomes (1, 32, 16, 16).

3. **After conv2 and pool**:
    - **Convolution**: The output from conv2 will have the shape (1, 64, 16, 16) (1 batch, 64 channels).
    
    - **Pooling**: After pooling, the size will again be halved, resulting in (1, 64, 8, 8).

4. **After flatten**: The output is reshaped to (-1, 64 * 8 * 8) which is (-1, 4096). The -1 allows PyTorch to infer the batch size, so it becomes (1, 4096).

5. **After fc1**: The output shape is (1, 128) since the first fully connected layer has 128 outputs.
6. **After fc2**: Finally, the output shape is (1, num_classes), which in this case will be (1, 10) if you keep the default num_classes as 10.

## 2. Simple CNN Model with prints after each layer

In [None]:
class SimpleCNN(nn.Module):
    """SimpleCNN variant that prints the tensor shape after every layer.

    For a 3x32x32 input the two 2x2 max-pools shrink the feature maps
    32 -> 16 -> 8, so the flattened vector has 64 * 8 * 8 = 4096 elements.

    This cell redefines ``SimpleCNN``. When it runs after the plain
    definition, later code that instantiates ``SimpleCNN`` (e.g. the model
    setup cell) gets this class.

    Args:
        num_classes: number of output logits (default 10).
        verbose: keyword-only; print the shape after each layer (default True).
            Set it to False to keep the model quiet during training.
        num_classses: keyword-only, deprecated misspelled alias of
            ``num_classes`` kept for backward compatibility. It takes
            precedence when given.
    """

    def __init__(self, num_classes=10, *, verbose=True, num_classses=None):
        super().__init__()
        if num_classses is not None:
            num_classes = num_classses
        self.verbose = verbose
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def _log(self, label, x):
        """Print ``label`` followed by the shape of ``x`` when verbose is on."""
        if self.verbose:
            print(label, x.shape)

    def forward(self, x):
        """Return raw logits of shape (batch, num_classes), logging each stage."""
        self._log("Input shape:", x)

        x = torch.relu(self.conv1(x))
        self._log("After conv1:", x)

        x = self.pool(x)
        self._log("After Pool1:", x)

        x = torch.relu(self.conv2(x))
        self._log("After conv2:", x)

        x = self.pool(x)
        self._log("After Pool2:", x)

        # Keep the batch dimension explicit. view(-1, 4096) would silently
        # fold a wrongly sized input into extra rows instead of failing.
        x = torch.flatten(x, 1)
        self._log('After Flatten:', x)

        x = torch.relu(self.fc1(x))
        self._log('After fc1:', x)

        # No ReLU on the output layer: these are logits for CrossEntropyLoss
        # and must be able to take negative values.
        x = self.fc2(x)
        self._log('After fc2:', x)

        return x

Input shape: torch.Size([1, 3, 32, 32])
After conv1 and pool: torch.Size([1, 32, 16, 16])
After conv2 and pool: torch.Size([1, 64, 8, 8])
After flatten: torch.Size([1, 4096])
After fc1: torch.Size([1, 128])
After fc2: torch.Size([1, 10])


## 3. CNN with BatchNorm

### **What is BatchNorm**?
Batch Normalization (often abbreviated as BatchNorm) is a technique widely used in deep learning to improve the training of neural networks, particularly convolutional neural networks (CNNs). Below is a comprehensive overview of Batch Normalization, covering why it should be used, its formula, when to use it, and other important points.

#### **Why Use Batch Normalization**?
- **1. Stabilizes Learning**: Batch Normalization reduces internal covariate shift, i.e. the change in the distribution of a layer's inputs caused by updates to the weights of the preceding layers during training.   
By normalizing the inputs to each layer, BatchNorm stabilizes learning dynamics.

- **2. Accelerates Training**: By normalizing each mini-batch, BatchNorm allows for larger learning rates, which can lead to faster convergence. This means you can train models more quickly and efficiently.

- **3. Reduces Sensitivity to Initialization**: Networks with BatchNorm become less sensitive to the initial values of the weights, making it easier to train neural networks successfully.

- **4. Acts as a Regularizer**: Batch Normalization can have a slight regularizing effect, reducing the need for other regularization techniques like dropout in some cases.

- **5. Improves Performance**: Many researchers have found that using Batch Normalization leads to higher overall performance in terms of accuracy, especially in deeper networks.

#### **Formula**
The Batch Normalization layer normalizes the input of each layer as follows:

For a given mini-batch $\mathcal{B} = \{x_1, \dots, x_m\}$ containing $m$ examples:
- **1. Compute the mean $\mu_\mathcal{B}$ and variance $\sigma^2_\mathcal{B}$ of the mini-batch**:

![mean and variance in BN](<../../pics/Batchnorm.png>)

- **2. Normalize the batch**:

![Normalize the batch](<../../pics/normalize_batch.png>)

where ϵ is a small constant added for numerical stability.

- **3. Scale and shift the normalized output**:

![scale_shift](<../../pics/scale_shift.png>)

where `γ` and `β` are learnable parameters (**scale and shift**) that enable the network to represent the original input data distribution if needed.

#### **When to Use Batch Normalization**
- **1. After Convolutional Layers**: It is commonly placed after convolutional and fully connected layers but before the activation function (e.g., ReLU).

- **2. In Deep Networks**: Especially useful in very deep networks where gradients can vanish or explode.

- **3. In Transfer Learning**: When using pre-trained models, BatchNorm can help in fine-tuning the model effectively.

- **4. In CNNs for Vision Tasks**: Typically beneficial in tasks like image classification, object detection, and segmentation tasks involving CNNs.

The `ImprovedCNN` below extends `SimpleCNN` with Batch Normalization and two different pooling strategies:

- **1. Batch Normalization**: Adding batch normalization after each convolutional layer helps stabilize and accelerate training.
- **2. Pooling Variants**: Using both Max Pooling and Average Pooling can help the model capture different aspects of spatial features.


In [None]:
import torch
import torch.nn as nn

# Define custom CNN class
class ImprovedCNN(nn.Module):
    """SimpleCNN variant with BatchNorm and mixed max/average pooling.

    Each stage is conv(3x3, padding=1) -> BatchNorm -> ReLU -> 2x2 pool
    (max pooling after the first stage, average pooling after the second).
    For a 3x32x32 input the spatial size goes 32 -> 16 -> 8, so the
    flattened feature vector has 64 * 8 * 8 = 4096 elements.

    Args:
        num_classes: number of output logits (default 10).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)

        self.maxpool = nn.MaxPool2d(2, 2)
        self.avgpool = nn.AvgPool2d(2, 2)

        # Two 2x2 pools halve 32x32 twice -> 8x8 maps with 64 channels.
        # The previous 64 * 4 * 4 made view() quadruple the batch dimension.
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map a (batch, 3, 32, 32) tensor to (batch, num_classes) logits."""
        # First conv layer with batchnorm and max pooling
        x = self.maxpool(torch.relu(self.bn1(self.conv1(x))))

        # Second conv layer with batchnorm and average pooling
        x = self.avgpool(torch.relu(self.bn2(self.conv2(x))))

        # Flatten per sample. A shape mismatch now errors in fc1 instead of
        # being silently absorbed into the batch dimension.
        x = torch.flatten(x, 1)

        # Fully connected layers
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

## 4. Loading pre-trained model (ResNet-18)

In [None]:
# Define function for loading pre-trained model (ResNet-18)
def load_resnet18(num_classes=10, pretrained=True):
    """Build a torchvision ResNet-18 whose final layer outputs ``num_classes`` logits.

    Args:
        num_classes: size of the new classification head (default 10).
        pretrained: load the default ImageNet weights for the backbone
            (default True, matching the previous behaviour).

    Returns:
        The ResNet-18 model. Its ``fc`` layer is replaced by a freshly
        initialised ``nn.Linear``.
    """
    if hasattr(models, "ResNet18_Weights"):
        # torchvision >= 0.13: ``pretrained=`` is deprecated in favour of ``weights=``.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        model = models.resnet18(weights=weights)
    else:  # older torchvision only understands ``pretrained=``
        model = models.resnet18(pretrained=pretrained)
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

# Utils

## 1. Helper Classes: `AverageMeter` class for tracking loss and other metrics


In [10]:
class AverageMeter:
    """Running, sample-weighted mean of a scalar such as a per-batch loss.

    ``update(value, n)`` records ``value`` as the mean over ``n`` samples, so
    ``avg`` is the per-sample mean across all updates since the last reset.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Discard all accumulated values."""
        self.sum, self.count = 0, 0

    def update(self, value, n=1):
        """Accumulate ``n`` samples whose mean is ``value``."""
        self.count += n
        self.sum += value * n

    @property
    def avg(self):
        """Weighted mean so far, or 0 when nothing positive has been counted."""
        if self.count <= 0:
            return 0
        return self.sum / self.count

## 2. Function to visualize a batch of data

In [None]:
def normalize_image(image):
    """Return a min-max rescaled copy of ``image`` with values in [0, 1) for display.

    The input tensor is not modified, so callers (e.g. a plotting helper given
    a slice of a batch) can keep using the original data.

    Args:
        image: tensor of any shape; min and max are taken over all elements.

    Returns:
        ``(image - min) / (max - min + 1e-5)`` as a new tensor.
    """
    image_min = image.min()
    image_max = image.max()
    # 1e-5 avoids division by zero for constant images (they map to all zeros).
    return (image - image_min) / (image_max - image_min + 1e-5)

def plot_images(images, labels, classes, normalize=True):
    """Show the first k*k images of a batch in a k x k grid, titled by class name.

    k is floor(sqrt(len(images))), so images beyond the largest square that
    fits are not shown.

    Args:
        images: batch of channel-first (C, H, W) image tensors.
        labels: class indices aligned with ``images``.
        classes: sequence mapping a class index to its display name.
        normalize: min-max rescale each image with ``normalize_image`` first.
    """
    grid = int(np.sqrt(len(images)))
    fig = plt.figure(figsize=(10, 10))

    for idx in range(grid * grid):
        ax = fig.add_subplot(grid, grid, idx + 1)
        img = normalize_image(images[idx]) if normalize else images[idx]
        # imshow expects height x width x channels
        ax.imshow(img.permute(1, 2, 0).cpu().numpy())
        ax.set_title(classes[labels[idx]])
        ax.axis('off')

# Transform and load data

In [5]:
# Advanced preprocessing and data augmentation.
# Per-channel mean/std shared by both pipelines; with 0.5 / 0.5, Normalize
# maps ToTensor's [0, 1] output to [-1, 1].
_MEAN = (0.5, 0.5, 0.5)
_STD = (0.5, 0.5, 0.5)

# Training: random flips, rotations of up to 10 degrees and scale crops back to 32x32.
transform_train = T.Compose([
    T.RandomHorizontalFlip(),
    T.RandomRotation(10),
    T.RandomResizedCrop(32, scale=(0.8, 1.0)),
    T.ToTensor(),
    T.Normalize(_MEAN, _STD),
])

# Validation: no augmentation, only tensor conversion and the same normalisation.
transform_val = T.Compose([
    T.ToTensor(),
    T.Normalize(_MEAN, _STD),
])

In [None]:
# Load CIFAR-10 train/test splits (download=True lets torchvision fetch them into ./data)
_DATA_ROOT = "./data"
_BATCH_SIZE = 64

train_dataset = datasets.CIFAR10(
    root=_DATA_ROOT, train=True, download=True, transform=transform_train
)
val_dataset = datasets.CIFAR10(
    root=_DATA_ROOT, train=False, download=True, transform=transform_val
)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=_BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=_BATCH_SIZE, shuffle=False)

## Visualize a batch

In [None]:
# Preview one (augmented) training batch, titled with CIFAR-10 class names.
classes = train_dataset.classes
batch = next(iter(train_loader))
images, labels = batch
plot_images(images, labels, classes)

# Model & Loss_fn & Optimizer

In [None]:
# Choose model (custom CNN or ResNet-18)
model_type = "custom"  # change to "resnet" for ResNet-18
num_classes = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if model_type == "custom":
    model = SimpleCNN(num_classes=num_classes)
elif model_type == "resnet":
    model = load_resnet18(num_classes=num_classes)
else:
    # Fail fast instead of silently falling back to ResNet on a typo.
    raise ValueError(f"Unknown model_type {model_type!r}; use 'custom' or 'resnet'")
model = model.to(device)

# torchsummary builds its dummy input with device="cuda" unless told
# otherwise, which fails on CPU-only machines; pass the device in use.
summary(model, (3, 32, 32), device=device.type)


# Define criterion, optimizer, and metrics
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)

# Training and Validation Functions


In [None]:
# Training and validation functions
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch, accuracy_metric):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train (switched to train mode).
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: loss function applied to ``(logits, targets)``.
        optimizer: stepped once per batch.
        device: where each batch is moved before the forward pass.
        epoch: zero-based epoch index, used only for the progress-bar label.
        accuracy_metric: stateful metric; it is reset here, then fed argmax predictions.

    Returns:
        ``(avg_loss, avg_accuracy)``: sample-weighted mean loss and the
        accuracy accumulated over the whole epoch.
    """
    model.train()
    accuracy_metric.reset()
    running_loss = AverageMeter()
    bar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Training]", leave=False)

    for inputs, targets in bar:
        inputs, targets = inputs.to(device), targets.to(device)

        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so a short last batch counts less.
        running_loss.update(batch_loss.item(), inputs.size(0))
        accuracy_metric.update(logits.argmax(dim=1), targets)
        bar.set_postfix(loss=running_loss.avg, accuracy=accuracy_metric.compute().item())

    return running_loss.avg, accuracy_metric.compute().item()

In [None]:
def validate(model, dataloader, criterion, device, epoch, accuracy_metric):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: network to evaluate (switched to eval mode).
        dataloader: yields ``(inputs, targets)`` batches.
        criterion: loss function applied to ``(logits, targets)``.
        device: where each batch is moved before the forward pass.
        epoch: zero-based epoch index, used only for the progress-bar label.
        accuracy_metric: stateful metric; it is reset here, then fed argmax predictions.

    Returns:
        ``(avg_loss, avg_accuracy)`` over the whole loader.
    """
    model.eval()
    accuracy_metric.reset()
    running_loss = AverageMeter()
    bar = tqdm(dataloader, desc=f"Epoch {epoch+1} [Validation]", leave=False)

    with torch.no_grad():
        for inputs, targets in bar:
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)
            running_loss.update(criterion(logits, targets).item(), inputs.size(0))
            accuracy_metric.update(logits.argmax(dim=1), targets)
            bar.set_postfix(loss=running_loss.avg)

    return running_loss.avg, accuracy_metric.compute().item()

# Training Script

In [None]:
# Initialize TensorBoard
writer = SummaryWriter()

# Training loop
num_epochs = 20
best_val_acc = 0.0
try:
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, train_accuracy)
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch, val_accuracy)

        # Log metrics to TensorBoard
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Accuracy/Train', train_acc, epoch)
        writer.add_scalar('Loss/Validation', val_loss, epoch)
        writer.add_scalar('Accuracy/Validation', val_acc, epoch)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

        # Save best model (checkpoint only when validation accuracy improves)
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"Best model saved at epoch {epoch+1} with validation accuracy: {best_val_acc:.4f}")
finally:
    # Close the TensorBoard writer even if training is interrupted
    # (e.g. a KeyboardInterrupt in the notebook).
    writer.close()

In [None]:
# Plot training and validation loss and accuracy
def plot_metrics(metric_values, title, xlabel, ylabel):
    """Draw ``metric_values`` as one labelled line chart in a new 10x5 figure.

    Args:
        metric_values: sequence of y-values (e.g. one entry per epoch).
        title: figure title.
        xlabel: x-axis label.
        ylabel: y-axis label, also used as the legend entry.
    """
    _, ax = plt.subplots(figsize=(10, 5))
    ax.plot(metric_values, label=f"{ylabel}")
    ax.set(title=title, xlabel=xlabel, ylabel=ylabel)
    ax.legend()
    plt.show()

# Data visualization function
def show_sample_predictions(model, dataloader, class_names, device, num_images=8):
    """Plot images from the first batch of ``dataloader`` with true vs. predicted labels.

    Args:
        model: classifier whose output is argmax-ed over dim 1 to get predictions.
        dataloader: yields ``(images, labels)`` batches; only the first batch is used.
        class_names: sequence mapping a class index to a display name.
        device: device the model lives on; the batch is moved there.
        num_images: maximum number of images to show (default 8, four per row).
            It is capped at the batch size, so a small batch no longer raises
            IndexError.
    """
    model.eval()
    with torch.no_grad():
        X_batch, y_batch = next(iter(dataloader))
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        preds = model(X_batch).argmax(dim=1)

    n_show = min(num_images, X_batch.size(0))
    n_cols = 4
    n_rows = max(1, (n_show + n_cols - 1) // n_cols)
    # Copy the label indices to the host once instead of indexing device tensors per image.
    true_idx = y_batch[:n_show].tolist()
    pred_idx = preds[:n_show].tolist()

    plt.figure(figsize=(12, 4 * n_rows))
    for i in range(n_show):
        plt.subplot(n_rows, n_cols, i + 1)
        # Un-normalize for display: assumes Normalize(mean=0.5, std=0.5) as in
        # transform_val, so x * 0.5 + 0.5 maps [-1, 1] back to [0, 1].
        plt.imshow(np.transpose(X_batch[i].cpu().numpy(), (1, 2, 0)) * 0.5 + 0.5)
        plt.title(f"True: {class_names[true_idx[i]]}, Pred: {class_names[pred_idx[i]]}")
        plt.axis("off")
    plt.show()

# Example usage: preview predictions on the first validation batch, titled
# with the class names exposed by the CIFAR-10 training dataset.
class_names = train_dataset.classes
show_sample_predictions(model, val_loader, class_names=class_names, device=device)

In [None]:
# Load the TensorBoard notebook extension and open the dashboard inline.
# NOTE(review): assumes SummaryWriter() above used its default log location
# under ./runs — confirm if a log_dir is ever passed explicitly.
%load_ext tensorboard 
%tensorboard --logdir runs