### Imports and Setup for Image Classification in PyTorch

In [None]:
import os
import torch
import torchvision
import tarfile
import torch.nn as nn
import numpy as np
import torch.nn.functional as F
from torchvision.datasets.utils import download_url
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import torchvision.transforms as tt
from torch.utils.data import random_split
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
matplotlib.rcParams['figure.facecolor']='#ffffff'

### Download the CIFAR-10 dataset
`dataset_url` stores the URL of the CIFAR-10 `.tgz` archive. `download_url` downloads it to the current directory, and the archive is then extracted into `./data`.

In [None]:
dataset_url = "https://s3.amazonaws.com/fast-ai-imageclas/cifar10.tgz"
download_url(dataset_url, '.')

# Extract the archive into ./data. Where the interpreter supports it (Python 3.12+
# and recent 3.8-3.11 security releases) use the 'data' extraction filter, which
# rejects members with absolute paths, '..' components or links that point outside
# the destination, instead of trusting the downloaded archive blindly.
with tarfile.open('./cifar10.tgz', 'r:gz') as tar:
  if hasattr(tarfile, 'data_filter'):
    tar.extractall(path='./data', filter='data')
  else:
    tar.extractall(path='./data')

data_dir = "./data/cifar10"
print(os.listdir(data_dir))

# Sorted so the printed order matches the label indices ImageFolder assigns
# (it numbers class folders in sorted order).
classes = sorted(os.listdir(data_dir + "/train"))
print(classes)

### Define image augmentation and preprocessing for training
Applies standard image augmentation and preprocessing to help the model generalize better:
- `RandomCrop`: Pads the image by 4 pixels using reflection, then randomly crops a `32×32` region.
- `RandomHorizontalFlip`: Randomly flips the image horizontally.
- `RandomRotation`: Randomly rotates the image by up to ±10 degrees.
- `ToTensor`: Converts the image to a PyTorch tensor.
- `Normalize`: Standardizes pixel values using the given channel-wise mean and std.

The test transforms apply only `ToTensor` and `Normalize`.

In [None]:
# Channel-wise (mean, std) pair consumed by `tt.Normalize`.
stats = (
    (0.4914, 0.4822, 0.4465),  # mean
    (0.2023, 0.1994, 0.2010),  # std
)


def _to_normalized_tensor():
  # A fresh ToTensor + Normalize pair, shared recipe for every pipeline.
  return [tt.ToTensor(), tt.Normalize(*stats, inplace=True)]


# Training pipeline: random geometric augmentation, then tensor + normalization.
_augmentations = [
    tt.RandomCrop(32, padding=4, padding_mode='reflect'),
    tt.RandomHorizontalFlip(),
    tt.RandomRotation(10),
]
train_and_val_tfms = tt.Compose(_augmentations + _to_normalized_tensor())

# Test pipeline: deterministic, no augmentation.
test_tfms = tt.Compose(_to_normalized_tensor())

### Create the datasets
Loads the CIFAR-10 images from folders using `ImageFolder`.

Each subfolder name is treated as the class label.

Applies the augmenting transforms to the training images and `test_tfms` to the test set. The training folder is split 80/20 into training and validation subsets.

In [None]:
train_and_val_ds = ImageFolder(data_dir + '/train', train_and_val_tfms)
# Second view of the same training folder using the deterministic test-time
# transforms: validation images must not be randomly cropped/flipped/rotated,
# otherwise validation metrics are noisy and not comparable across epochs.
val_view_ds = ImageFolder(data_dir + '/train', test_tfms)
test_ds = ImageFolder(data_dir + '/test', test_tfms)

train_size = int(0.8 * len(train_and_val_ds))
val_size = len(train_and_val_ds) - train_size

# Seeded generator makes the 80/20 split reproducible across runs.
split_gen = torch.Generator().manual_seed(42)
training_ds, _aug_val_subset = random_split(
    train_and_val_ds, [train_size, val_size], generator=split_gen)
# Reuse the held-out indices on the un-augmented view so the two subsets stay disjoint.
validation_ds = torch.utils.data.Subset(val_view_ds, _aug_val_subset.indices)
del _aug_val_subset

### Create data loaders
Wraps the datasets with `DataLoader` to enable efficient mini-batch processing:

- `batch_size = 128` : Loads data in batches of 128 samples.
- `shuffle=True` : Randomizes the order of training data for better generalization.
- `num_workers=2` : Loads data in parallel using 2 subprocesses to improve speed.
- `pin_memory=True` : Speeds up data transfer to GPU (if using CUDA).

In [None]:
batch_size = 128
# Settings shared by every loader; only the training loader reshuffles each epoch.
_loader_kwargs = dict(batch_size=batch_size, num_workers=2, pin_memory=True)
training_dl = DataLoader(training_ds, shuffle=True, **_loader_kwargs)
validation_dl = DataLoader(validation_ds, **_loader_kwargs)
test_dl = DataLoader(test_ds, **_loader_kwargs)

### Define a denormalization function
Reverses the normalization applied to image tensors so they can be visualized correctly.

In [None]:
def denormalize(images, means, stds):
  """Undo channel-wise normalization, returning ``images * std + mean``.

  Args:
    images: batched image tensor laid out as (N, C, H, W); the statistics are
      broadcast along dimension 1.
    means: per-channel means, one value per channel (any channel count).
    stds: per-channel standard deviations, same length as ``means``.

  The statistics tensors are created on ``images.device``, so batches that
  already live on a GPU can be denormalized without a device mismatch.
  """
  means = torch.tensor(means, device=images.device).reshape(1, -1, 1, 1)
  stds = torch.tensor(stds, device=images.device).reshape(1, -1, 1, 1)
  return images * stds + means

### Visualize a batch of training images

In [None]:
def show_batch(dl, nrow=16):
  """Display the first batch yielded by ``dl`` as one image grid.

  The images are moved to the CPU first, so this also works for loaders that
  yield device tensors (e.g. a device-wrapping loader). They are denormalized
  with the module-level ``stats`` and clamped to [0, 1] for ``imshow``. Does
  nothing if ``dl`` yields no batches.

  Args:
    dl: iterable yielding ``(images, labels)`` batches of normalized images.
    nrow: number of images per grid row (default 16, as before).
  """
  for images, _ in dl:
    images = images.detach().cpu()
    _, ax = plt.subplots(figsize=(16, 8))
    ax.set_xticks([])
    ax.set_yticks([])
    grid = make_grid(denormalize(images, *stats), nrow=nrow)
    # make_grid returns (C, H, W); imshow expects (H, W, C).
    ax.imshow(grid.permute(1, 2, 0).clamp(0, 1))
    break

In [None]:
# Visualize the first training batch (with the training-time augmentations applied).
show_batch(training_dl)

### Load data batches onto GPU (if available)

In [None]:
# True if PyTorch can use a CUDA device; shown as the cell's output.
torch.cuda.is_available()

In [None]:
def default_device():
  """Return ``torch.device("cuda")`` when CUDA is available, else ``torch.device("cpu")``."""
  device_name = "cuda" if torch.cuda.is_available() else "cpu"
  return torch.device(device_name)

In [None]:
# Pick the compute device and display it.
device=default_device()
device

In [None]:
def to_device(data, device):
  """Move a tensor/module, or a nested list/tuple/dict of them, to ``device``.

  Lists and tuples come back as lists (unchanged behaviour); dicts come back
  as dicts with the same keys. Leaves are moved with
  ``.to(device, non_blocking=True)``, which PyTorch may perform asynchronously
  when the source is pinned host memory.
  """
  if isinstance(data, (list, tuple)):
    return [to_device(x, device) for x in data]
  if isinstance(data, dict):
    return {key: to_device(value, device) for key, value in data.items()}
  return data.to(device, non_blocking=True)

In [None]:
# Sanity check on one raw batch: print its shape and device, move it with
# `to_device`, then print the device again.
for i, _ in training_dl:
  print(i.shape)
  print(i.device)
  i=to_device(i, device)
  print(i.device)
  break

In [None]:
class deviceDataLoader():
  """Iterable wrapper that moves each batch from a loader onto ``device``.

  Batches are transferred lazily, one at a time, via ``to_device`` as they
  are iterated; ``len()`` is delegated to the wrapped loader.
  """

  def __init__(self, dl, device):
    self.dl, self.device = dl, device

  def __iter__(self):
    yield from (to_device(batch, self.device) for batch in self.dl)

  def __len__(self):
    return len(self.dl)

In [None]:
# (Re)select the compute device before wrapping the loaders.
device=default_device()

In [None]:
# Wrap the training and validation loaders so each batch is moved to `device` as it is iterated.
training_dl=deviceDataLoader(training_dl, device)
validation_dl=deviceDataLoader(validation_dl, device)

### An example of a simple residual block

In [None]:
class simpleResidualBlock(nn.Module):
  """Toy residual block: conv -> ReLU -> conv -> ReLU, plus an identity skip.

  Both convolutions map 3 channels to 3 channels with a 3x3 kernel, stride 1
  and padding 1, so the output has the input's shape and the input can be
  added back directly.
  """

  def __init__(self):
    super().__init__()
    same_size = dict(kernel_size=3, stride=1, padding=1)
    self.c1 = nn.Conv2d(3, 3, **same_size)
    self.r1 = nn.ReLU()
    self.c2 = nn.Conv2d(3, 3, **same_size)
    self.r2 = nn.ReLU()

  def forward(self, x):
    # The second ReLU is applied before the skip connection is added.
    transformed = self.r2(self.c2(self.r1(self.c1(x))))
    return transformed + x

In [None]:
# Instantiate the toy residual block and move its parameters to `device`.
simpleResnet=to_device(simpleResidualBlock(), device)

In [None]:
# Push one batch through the toy block and compare shapes; the block keeps
# 3 channels and uses padding=1, so input and output shapes should match.
for i, _ in training_dl:
  print(i.shape)
  out=simpleResnet(i)
  print(out.shape)
  break

In [None]:
# Drop the toy block and the sample batch, then release cached, unused GPU
# memory held by PyTorch's CUDA allocator.
del simpleResnet, i, _
torch.cuda.empty_cache()

### Define accuracy metric
Calculates the accuracy of model predictions compared to the true labels.

In [None]:
def accuracy(output, label):
  """Fraction of rows of ``output`` whose highest-scoring index equals ``label``.

  Returns a 0-d float tensor created from a Python float (so it lives on the
  CPU regardless of where ``output`` is).
  """
  _, predicted = output.max(dim=1)
  n_correct = (predicted == label).sum().item()
  n_total = len(predicted)
  return torch.tensor(n_correct / n_total)

### Define a base class for image classification models
This base class extends `nn.Module` and includes standard training and evaluation methods used in image classification tasks.

In [None]:
class imageClassiicationBase(nn.Module):
  """Base class with cross-entropy training/validation steps and epoch reporting.

  Subclasses implement ``forward``. Every step expects a batch that unpacks
  as ``(images, labels)``.
  """

  def training_step(self, batch):
    """Return the cross-entropy loss (with autograd graph) for one batch."""
    image, label = batch
    out = self(image)
    loss = F.cross_entropy(out, label)
    return loss

  def validation_step(self, batch):
    """Return loss, accuracy and sample count for one batch.

    "Batch_size" is recorded so ``validation_epoch_end`` can weight each
    batch by its number of samples (the final batch is usually smaller).
    """
    image, label = batch
    out = self(image)
    loss = F.cross_entropy(out, label)
    acc = accuracy(out, label)
    # Detached: validation losses are only aggregated, never back-propagated.
    return {"Loss": loss.detach(), "Accuracy": acc, "Batch_size": len(label)}

  def validation_epoch_end(self, output):
    """Aggregate per-batch validation results into sample-weighted means.

    Args:
      output: list of dicts as produced by ``validation_step``. Entries
        without a "Batch_size" key get weight 1, so such input reduces to a
        plain mean over batches.

    Returns:
      dict with 0-d tensors "Mean_loss" and "Mean_accuracy".
    """
    all_batch_loss = torch.stack([i['Loss'] for i in output])
    all_batch_acc = torch.stack([i['Accuracy'] for i in output])
    sizes = [i.get('Batch_size', 1) for i in output]
    mean_loss = self._weighted_mean(all_batch_loss, sizes)
    mean_acc = self._weighted_mean(all_batch_acc, sizes)
    return {"Mean_loss": mean_loss, "Mean_accuracy": mean_acc}

  @staticmethod
  def _weighted_mean(values, weights):
    # Weighted average of a 1-d tensor; weights are built on the values' device/dtype.
    w = torch.tensor(weights, dtype=values.dtype, device=values.device)
    return (values * w).sum() / w.sum()

  def epoch_end(self, epoch, result):
    """Print the epoch number with the aggregated validation loss and accuracy."""
    print(f'Epoch : {epoch}, Mean_loss : {result["Mean_loss"]:.4f}, Mean_accuracy : {result["Mean_accuracy"]:.4f}')

### ResNet9 implementation

In [None]:
def convBlock(in_c, out_c, pool=False):
  """Build Conv3x3(padding=1) -> BatchNorm -> ReLU, optionally followed by 2x2 max-pooling.

  Args:
    in_c: input channel count.
    out_c: output channel count.
    pool: when True, append ``nn.MaxPool2d(2)`` to halve the spatial size.
  """
  stages = [
      nn.Conv2d(in_c, out_c, kernel_size=3, padding=1),
      nn.BatchNorm2d(out_c),
      nn.ReLU(inplace=True),
  ]
  stages += [nn.MaxPool2d(2)] if pool else []
  return nn.Sequential(*stages)

class ResNet9(imageClassiicationBase):
  """ResNet9-style classifier: conv stages with two residual blocks and a linear head.

  Args:
    in_c: number of input image channels.
    num_classes: number of output logits.

  Three stages (c2, c3, c4) halve the spatial size, so a 32x32 input reaches
  the head as a 512x4x4 feature map. The head uses adaptive global
  max-pooling, which for 32x32 inputs is identical to a 4x4 max-pool, but also
  accepts other input sizes (at least 8x8) instead of failing at the Linear layer.
  """

  def __init__(self, in_c, num_classes):
    super().__init__()

    self.c1 = convBlock(in_c, 64)
    self.c2 = convBlock(64, 128, pool=True)
    self.r1 = nn.Sequential(convBlock(128, 128), convBlock(128, 128))

    self.c3 = convBlock(128, 256, pool=True)
    self.c4 = convBlock(256, 512, pool=True)
    self.r2 = nn.Sequential(convBlock(512, 512), convBlock(512, 512))

    # Global max-pool to 1x1 regardless of the incoming spatial size; no parameters,
    # so the state_dict layout matches the fixed MaxPool2d(4) version.
    self.classifier = nn.Sequential(nn.AdaptiveMaxPool2d(1),
                                    nn.Flatten(),
                                    nn.Dropout(0.2),
                                    nn.Linear(512, num_classes))

  def forward(self, batch):
    out = self.c1(batch)
    out = self.c2(out)
    out = self.r1(out) + out   # residual connection around r1
    out = self.c3(out)
    out = self.c4(out)
    out = self.r2(out) + out   # residual connection around r2
    out = self.classifier(out)
    return out

In [None]:
# Build ResNet9 for 3-channel input and 10 classes, move it to `device`, and display its layers.
model=to_device(ResNet9(3, 10), device)
model

### Evaluate model on validation set
Sets the model to evaluation mode and disables gradient tracking.

Calls `validation_step` on each batch and aggregates results using `validation_epoch_end`.

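`fit_one_cycle` is the training loop. It schedules the learning rate with the one-cycle policy (`OneCycleLR`), optionally applies weight decay and gradient-value clipping, and evaluates on the validation set after every epoch.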

In [None]:
def evaluate(model, validate_loader):
  """Run ``model`` over ``validate_loader`` in eval mode without gradients.

  Each batch goes through ``model.validation_step``; the list of per-batch
  results is then reduced by ``model.validation_epoch_end``, whose result is
  returned. The model is left in eval mode.
  """
  model.eval()
  with torch.no_grad():
    per_batch = list(map(model.validation_step, validate_loader))
  return model.validation_epoch_end(per_batch)

def get_lr(optimizer):
  """Return the 'lr' of the optimizer's first parameter group, or None if it has no groups."""
  return next((group['lr'] for group in optimizer.param_groups), None)

def fit_one_cycle(epoch, max_lr, model, train_loader, validate_loader, wt_decay=0, grad_clip=None, opt_func=torch.optim.SGD):
  """Train ``model`` for ``epoch`` epochs with a one-cycle learning-rate schedule.

  Args:
    epoch: number of training epochs.
    max_lr: peak learning rate of the ``OneCycleLR`` schedule.
    model: module providing ``training_step``, ``validation_step``,
      ``validation_epoch_end`` and ``epoch_end``.
    train_loader: iterable of training batches; must support ``len()``.
    validate_loader: iterable of validation batches passed to ``evaluate``.
    wt_decay: ``weight_decay`` passed to the optimizer.
    grad_clip: if truthy, gradients are clipped element-wise to
      ``[-grad_clip, grad_clip]`` before each optimizer step.
    opt_func: optimizer class, called as ``opt_func(params, max_lr, weight_decay=wt_decay)``.

  Returns:
    list with one dict per epoch: the ``evaluate`` result plus
    "Training_loss" (mean batch loss, float) and "LRs" (learning rate used
    at each optimizer step of that epoch).
  """
  torch.cuda.empty_cache()
  history = []
  optimizer = opt_func(model.parameters(), max_lr, weight_decay=wt_decay)
  # Size the schedule from the loader actually iterated below: OneCycleLR raises
  # once it is stepped more than epochs * steps_per_epoch times.
  sched = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr, epochs=epoch, steps_per_epoch=len(train_loader))

  for i in range(epoch):

    model.train()
    losses = []
    lrs = []

    for batch in train_loader:
      optimizer.zero_grad()
      loss = model.training_step(batch)
      # Keep a detached copy for reporting so no autograd history is retained.
      losses.append(loss.detach())
      loss.backward()

      if grad_clip:
        nn.utils.clip_grad_value_(model.parameters(), grad_clip)

      optimizer.step()
      lrs.append(get_lr(optimizer))  # LR used for this step, recorded before the scheduler advances
      sched.step()

    result = evaluate(model, validate_loader)
    result["Training_loss"] = torch.stack(losses).mean().item()
    result["LRs"] = lrs
    model.epoch_end(i + 1, result)
    history.append(result)
  return history

### Training phase

In [None]:
# Baseline: validation metrics of the untrained model.
result_0=evaluate(model, validation_dl)

In [None]:
# Display the baseline validation metrics.
result_0

In [None]:
# Training hyperparameters:
#   epoch     - number of training epochs
#   max_lr    - peak learning rate of the one-cycle schedule
#   grad_clip - gradients are clipped element-wise to [-grad_clip, grad_clip]
#   wt_decay  - weight decay passed to the optimizer
#   opt_func  - optimizer class
epoch=10
max_lr=0.01
grad_clip=0.1
wt_decay=10**(-4)
opt_func=torch.optim.Adam

In [None]:
# Train the model; `history` holds one result dict per epoch.
history=fit_one_cycle(epoch, max_lr, model, training_dl, validation_dl, wt_decay, grad_clip, opt_func)

### Evaluation of accuracy

In [None]:
# Validation accuracy per epoch. Epochs are numbered from 1 to match the
# "Epoch : N" lines printed during training; values are converted to floats
# so plotting works whatever device the metric tensors are on.
accuracies = [float(i["Mean_accuracy"]) for i in history]
epoch_numbers = range(1, len(accuracies) + 1)
plt.plot(epoch_numbers, accuracies, "r.-")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.title("Graph of accuracy against epoch")
plt.show()

### Compare the losses

In [None]:
# Validation vs. training loss per epoch, with epochs numbered from 1 to
# match the training log.
mean_losses = [i["Mean_loss"].item() for i in history]
training_losses = [i["Training_loss"] for i in history]
epoch_numbers = range(1, len(history) + 1)
plt.plot(epoch_numbers, mean_losses, "b.-", label="validation_loss")
plt.plot(epoch_numbers, training_losses, "k.-", label="training_loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.title("Graph of loss against epoch")
plt.legend()
plt.show()

### Learning rate scheduling graph

In [None]:
# Learning rate recorded at every optimizer step, all epochs concatenated.
# The x-axis is therefore the step (batch) index, not the epoch, and no
# averaging takes place.
lrs = [lr for epoch_result in history for lr in epoch_result["LRs"]]
plt.plot(lrs, "y-")
plt.xlabel("Batch (optimizer step)")
plt.ylabel("Learning rate")
plt.title("Graph of learning rate against optimizer step")
plt.show()

In [None]:
# Wrap the test loader so its batches are also moved to `device`.
test_dl=deviceDataLoader(test_dl, device)

### Test-set accuracy

In [None]:
# Evaluate the trained model on the held-out test set (mean loss and accuracy).
output=evaluate(model, test_dl)

In [None]:
# Display the test-set metrics.
output

## With this setup the model can reach about 87–90% accuracy on datasets like `CIFAR10` in roughly 5–6 minutes of training (timing depends on the hardware).