### Importing Libraries

In [None]:
import numpy as np
import tarfile
import os

import torch
from torch import nn
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torchvision.models import resnet18

# Seed PyTorch's global random number generator with a fixed value
# (presumably so that repeated runs of the notebook are comparable).
torch.manual_seed(100)

###  Helper Functions

In [None]:
def accuracy(outputs, labels):
    """
    Compute the fraction of samples whose highest-scoring class matches the label.

    Args:
        outputs (torch.Tensor): Model outputs; the maximum is taken along dim 1.
        labels (torch.Tensor): Ground truth labels, compared element-wise
            against the predicted class indices.

    Returns:
        torch.Tensor: Zero-dimensional tensor holding correct / total.
    """
    # Index of the largest score in each row is the predicted class.
    predicted = torch.max(outputs, dim=1).indices
    num_correct = (predicted == labels).sum().item()
    return torch.tensor(num_correct / len(predicted))

def training_step(model, batch):
    """
    Perform a single training step (forward pass and loss computation).

    The batch is moved to the device that holds the model's parameters, so the
    function does not rely on a module-level ``device`` variable being defined.

    Args:
        model (torch.nn.Module): The neural network model.
        batch (tuple): A tuple containing batch of input images and labels.

    Returns:
        torch.Tensor: Cross-entropy loss of the model on the batch. The loss is
        not detached, so the caller can run ``backward()`` on it.
    """
    images, labels = batch
    # Follow the model's own device; a model without parameters leaves the
    # batch where it already is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        images = images.to(first_param.device)
        labels = labels.to(first_param.device)
    out = model(images)
    return F.cross_entropy(out, labels)

def validation_step(model, batch):
    """
    Perform a single validation step.

    The batch is moved to the device that holds the model's parameters, so the
    function does not rely on a module-level ``device`` variable being defined.

    Args:
        model (torch.nn.Module): The neural network model.
        batch (tuple): A tuple containing batch of input images and labels.

    Returns:
        dict: ``{'Loss': ..., 'Acc': ...}`` with the detached cross-entropy
        loss and the accuracy of the model on the batch.
    """
    images, labels = batch
    # Follow the model's own device; a model without parameters leaves the
    # batch where it already is.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        images = images.to(first_param.device)
        labels = labels.to(first_param.device)
    out = model(images)
    loss = F.cross_entropy(out, labels)
    acc = accuracy(out, labels)
    # Detach so that collected batch losses do not reference the autograd graph.
    return {'Loss': loss.detach(), 'Acc': acc}

def validation_epoch_end(model, outputs):
    """
    Average the per-batch validation metrics into epoch-level values.

    Args:
        model (torch.nn.Module): The neural network model (not used here).
        outputs (list): List of dictionaries with 'Loss' and 'Acc' tensors,
            one dictionary per validation batch.

    Returns:
        dict: ``{'Loss': float, 'Acc': float}`` holding the mean of each metric
        over all batches.
    """
    def _mean_of(key):
        # Stack the per-batch scalars into one tensor and average them.
        per_batch = [batch_result[key] for batch_result in outputs]
        return torch.stack(per_batch).mean().item()

    return {'Loss': _mean_of('Loss'), 'Acc': _mean_of('Acc')}

def epoch_end(model, epoch, result):
    """
    Print a one-line summary of an epoch.

    Args:
        model (torch.nn.Module): The neural network model (not used here).
        epoch (int): Current epoch number.
        result (dict): Metrics for the epoch; must provide 'lrs' (a sequence
            whose last element is printed), 'train_loss', 'Loss' and 'Acc'.
    """
    last_lr = result['lrs'][-1]
    train_loss = result['train_loss']
    val_loss = result['Loss']
    val_acc = result['Acc']
    print(
        f"Epoch [{epoch}], last_lr: {last_lr:.5f}, train_loss: {train_loss:.4f}, "
        f"val_loss: {val_loss:.4f}, val_acc: {val_acc:.4f}"
    )

def distance(model, model0):
    """
    Calculate the L2 distance between the parameters of two models.

    Parameters are paired in iteration order. Both the raw distance and the
    normalized distance are printed before the normalized value is returned.

    Args:
        model (torch.nn.Module): The first neural network model; its squared
            parameter norm is used as the normalization term.
        model0 (torch.nn.Module): The second neural network model.

    Returns:
        float: sqrt(sum ||p - p0||^2 / sum ||p||^2) over all paired parameters.

    Raises:
        ZeroDivisionError: If the squared norm of ``model``'s parameters is zero.
    """
    squared_dist = 0.0
    squared_norm = 0.0
    for p, p0 in zip(model.parameters(), model0.parameters()):
        squared_dist += (p.data - p0.data).pow(2).sum().item()
        squared_norm += p.data.pow(2).sum().item()
    print(f'Distance: {np.sqrt(squared_dist)}')
    # Compute the normalized value once and reuse it for printing and returning.
    normalized = np.sqrt(squared_dist / squared_norm)
    print(f'Normalized Distance: {normalized}')
    return normalized


### Evaluate Function

In [None]:
@torch.no_grad()
def evaluate(model, val_loader):
    """
    Run the model over a validation loader and aggregate the metrics.

    Gradient tracking is disabled by the decorator and the model is switched
    to evaluation mode before any batch is processed.

    Args:
        model (torch.nn.Module): The neural network model.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.

    Returns:
        dict: Dictionary containing the average loss and accuracy over the validation dataset.
    """
    model.eval()
    batch_results = []
    for batch in val_loader:
        batch_results.append(validation_step(model, batch))
    return validation_epoch_end(model, batch_results)

### Get Learning Rate Function

In [None]:
def get_lr(optimizer):
    """
    Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer (torch.optim.Optimizer): The optimizer.

    Returns:
        float: The 'lr' entry of the first parameter group, or None when the
        optimizer has no parameter groups.
    """
    return next((group['lr'] for group in optimizer.param_groups), None)


### Fit One Cycle Function


In [None]:
def fit_one_cycle(epochs, max_lr, model, train_loader, val_loader,
                  weight_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
    """
    Train the model and evaluate it after every epoch.

    Despite the function name, the learning rate is not driven by a one-cycle
    schedule: a ``ReduceLROnPlateau`` scheduler is stepped once per epoch with
    the validation loss and multiplies the learning rate by 0.5 when that loss
    stops improving (patience of 3 epochs).

    Args:
        epochs (int): Number of epochs to train.
        max_lr (float): Learning rate the optimizer is created with.
        model (torch.nn.Module): The neural network model.
        train_loader (torch.utils.data.DataLoader): DataLoader for the training dataset.
        val_loader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
        weight_decay (float): Weight decay factor passed to the optimizer (default: 0).
        grad_clip (float or None): If truthy, gradients are clipped by value
            with ``clip_grad_value_`` before each optimizer step (default: None).
        opt_func (torch.optim.Optimizer): The optimizer class (default: torch.optim.SGD).

    Returns:
        list: One dictionary per epoch with the validation 'Loss' and 'Acc',
        the mean 'train_loss', and 'lrs' (learning rate recorded after every
        training batch).
    """
    torch.cuda.empty_cache()
    history = []

    optimizer = opt_func(model.parameters(), max_lr, weight_decay=weight_decay)

    # NOTE(review): `verbose` is reported as deprecated in recent PyTorch
    # releases — confirm against the installed version.
    sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3, verbose=True)

    for epoch in range(epochs):
        model.train()
        train_losses = []
        lrs = []
        for batch in train_loader:
            loss = training_step(model, batch)
            # Keep only the value: storing the attached loss would hold a
            # reference to every batch's autograd graph for the whole epoch.
            train_losses.append(loss.detach())
            loss.backward()

            if grad_clip:
                nn.utils.clip_grad_value_(model.parameters(), grad_clip)

            optimizer.step()
            optimizer.zero_grad()

            lrs.append(get_lr(optimizer))

        # Validation phase
        result = evaluate(model, val_loader)
        result['train_loss'] = torch.stack(train_losses).mean().item()
        result['lrs'] = lrs
        epoch_end(model, epoch, result)
        history.append(result)
        # The scheduler reacts to the validation loss, not the training loss.
        sched.step(result['Loss'])
    return history


## Training & Loading Model


### Downloading The Dataset

In [None]:
# Download the dataset
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
download_url(dataset_url, '.')

# Extract from archive
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
    # The archive was fetched over the network: when the running Python
    # supports extraction filters, use the 'data' filter so that archive
    # members cannot be written outside ./data.
    if hasattr(tarfile, 'data_filter'):
        tar.extractall(path='./data', filter='data')
    else:
        tar.extractall(path='./data')

# Look into the data directory
data_dir = './data/cifar10'
print(os.listdir(data_dir))
# os.listdir returns entries in arbitrary order; sort the class names so the
# list lines up with the sorted class-to-index mapping used by ImageFolder.
classes = sorted(os.listdir(data_dir + "/train"))
print(classes)

"""
CIFAR-10 is a collection of 60,000 32x32 color images in 10 classes, with 6,000 images per class.
The classes are: 'airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', and 'truck'.
"""

### Transforming the DataSet

In [None]:
# Per-channel mean and standard deviation shared by both pipelines
_norm_mean = (0.4914, 0.4822, 0.4465)
_norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: convert the image to a tensor, then normalize each channel
transform_train = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(_norm_mean, _norm_std),
])

# Testing pipeline: same conversion and normalization as for training
transform_test = tt.Compose([
    tt.ToTensor(),
    tt.Normalize(_norm_mean, _norm_std),
])


In [None]:
# Training images live under <data_dir>/train and use the training transform
train_ds = ImageFolder(root=f'{data_dir}/train', transform=transform_train)

# Validation images live under <data_dir>/test and use the testing transform
valid_ds = ImageFolder(root=f'{data_dir}/test', transform=transform_test)


In [None]:
# Number of training samples per batch
batch_size = 256

# Training loader: reshuffled batches, 3 worker processes, pinned host memory
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=batch_size,
    shuffle=True,
    num_workers=3,
    pin_memory=True,
)

# Validation loader: no shuffling, and twice the training batch size
valid_dl = DataLoader(
    dataset=valid_ds,
    batch_size=batch_size * 2,
    num_workers=3,
    pin_memory=True,
)
