### Image classification project using the CIFAR-100 Dataset

In [314]:
import torch
from torch import nn
import torchvision
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from PIL import Image

import numpy as np

### Data set
* This project uses the CIFAR-100 dataset.
* It contains 60,000 images in total: 100 classes with 600 images per class, split into 50,000 training images and 10,000 test images.

In [None]:
# Per-channel statistics used to normalize CIFAR-100 images
mean = (0.5071, 0.4865, 0.4409)
std = (0.2673, 0.2564, 0.2762)

# Training pipeline: padded random crop + random horizontal flip as
# augmentation, then tensor conversion and normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Evaluation pipeline: tensor conversion and normalization only.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
)

# Training split (downloaded into ./data if it is not already there)
train_data = datasets.CIFAR100(
    "data",
    train=True,
    transform=transform_train,
    target_transform=None,
    download=True,
)

# Test split, same root directory, no augmentation
test_data = datasets.CIFAR100(
    "data",
    train=False,
    transform=transform_test,
    target_transform=None,
    download=True,
)

In [407]:
def unnormalize_img(img, mean, std):
    """Undo per-channel normalization on a channels-last image tensor.

    Args:
        img: Tensor whose last dimension indexes channels, e.g. ``(H, W, C)``.
        mean: Per-channel means that were subtracted during normalization.
        std: Per-channel standard deviations the image was divided by.

    Returns:
        ``img * std + mean``, with the statistics broadcast over the last
        dimension of ``img``.
    """
    # Create the statistics on the image's own device so that tensors living
    # on an accelerator do not trigger a device-mismatch error.
    mean = torch.as_tensor(mean, device=img.device).view(1, 1, -1)
    std = torch.as_tensor(std, device=img.device).view(1, 1, -1)
    return img * std + mean


In [408]:
# First training example: the transformed image tensor and its class label
image, label = train_data[0]

In [None]:
# Class names of the dataset; indexed by label when titling the plots below
class_names = train_data.classes
class_names

In [410]:
# Dataset-provided lookup; presumably maps class name -> integer label
class_to_idx = train_data.class_to_idx

In [None]:
# Plot a random 4x4 grid of (RGB) training images with their class names
torch.manual_seed(0)
fig = plt.figure(figsize=(9, 9))
rows, cols = 4, 4

for i in range(1, rows * cols + 1):
    rand_idx = torch.randint(0, len(train_data), size=[1]).item()
    img, label = train_data[rand_idx]
    # (C, H, W) -> (H, W, C) for matplotlib, undo the normalization, then
    # clamp: float round-off can push values slightly outside imshow's
    # valid [0, 1] range.
    img = unnormalize_img(img.permute(1, 2, 0), mean, std).clamp(0, 1)
    fig.add_subplot(rows, cols, i)
    # No cmap: the images have three channels, so a colormap would be ignored
    plt.imshow(img)
    plt.title(class_names[label])
    plt.axis(False)


### Preparing the DataLoader
* Wrap the datasets in DataLoaders, which turn them into Python iterables that yield mini-batches

In [412]:
from torch.utils.data import DataLoader
# Number of samples per mini-batch, shared by both loaders
BATCH_SIZE = 32

# Training batches are reshuffled; evaluation keeps the dataset order.
train_dataloader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Length of each loader (number of batches it yields)
print(len(train_dataloader), len(test_dataloader))

In [None]:
# Example: display one random image from a single training batch
train_features_batch, train_labels_batch = next(iter(train_dataloader))

rand_idx = torch.randint(0, len(train_features_batch), size=[1]).item()
img, label = train_features_batch[rand_idx], train_labels_batch[rand_idx]
# Batches come out of the loader normalized and channels-first. Convert to
# (H, W, C), undo the normalization and clamp to [0, 1]; otherwise imshow
# clips the values and the colours are wrong.
img = unnormalize_img(img.permute(1, 2, 0), mean, std).clamp(0, 1)
plt.imshow(img)
plt.title(class_names[label])
plt.axis(False)
print(f"Image size: {img.shape}")
print(f"Label: {label}, label size {label.shape}")

### Setting up the model 

In [None]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
device

#### Custom training and test step functions for later use

In [416]:
def train_step(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               optimizer: torch.optim.Optimizer,
               scheduler: torch.optim.lr_scheduler._LRScheduler,
               accuracy_fn,
               device: torch.device = device):
    """Train ``model`` for one pass over ``data_loader``.

    Args:
        model: Network to optimise; put into training mode by this function.
        data_loader: Iterable of ``(X, y)`` batches that also supports ``len``.
        loss_fn: Callable mapping ``(logits, targets)`` to a scalar loss tensor.
        optimizer: Optimizer stepped once per batch.
        scheduler: Learning-rate scheduler stepped once per call, after all
            batches have been processed.
        accuracy_fn: Callable taking ``y_true`` and ``y_pred`` keyword
            arguments and returning the accuracy of one batch.
        device: Device the batches are moved to before the forward pass.

    Returns:
        Tuple ``(train_loss, train_acc)`` holding the per-batch averages of
        the loss and of ``accuracy_fn``'s result. Both are also printed.
    """
    train_loss, train_acc = 0.0, 0.0

    model.train()

    for X, y in data_loader:
        X, y = X.to(device), y.to(device)

        y_pred = model(X)

        loss = loss_fn(y_pred, y)
        # Accumulate a plain Python float: summing the loss tensor itself
        # would keep autograd history alive across the whole epoch.
        train_loss += loss.item()
        train_acc += accuracy_fn(y_true=y,
                                 y_pred=y_pred.argmax(dim=1))  # logits -> predicted labels

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    scheduler.step()

    # Average the accumulated metrics over the number of batches
    num_batches = len(data_loader)
    train_loss /= num_batches
    train_acc /= num_batches

    print(f"Train loss: {train_loss:.5f} | Train acc: {train_acc:.2f}")
    return train_loss, train_acc

In [417]:
def test_step(model: torch.nn.Module,
              data_loader: torch.utils.data.DataLoader,
              loss_fn: torch.nn.Module,
              accuracy_fn,
              device: torch.device = device):
    """Evaluate ``model`` on ``data_loader`` and print the averaged metrics.

    The model is switched to eval mode and every batch is run under
    ``torch.inference_mode()``. The loss and the value returned by
    ``accuracy_fn`` are summed per batch, divided by the number of batches
    and printed. Nothing is returned.
    """
    running_loss = 0
    running_acc = 0

    model.eval()

    with torch.inference_mode():
        for inputs, targets in data_loader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)

            running_loss += loss_fn(logits, targets)
            running_acc += accuracy_fn(y_true=targets,
                                       y_pred=logits.argmax(dim=1))

        # Per-batch averages
        batch_count = len(data_loader)
        running_loss /= batch_count
        running_acc /= batch_count
        print(f"Test loss: {running_loss:.4f} | Test acc {running_acc:.2f}")

In [425]:
class SqueezeExcitationBlock(nn.Module):
    """Channel-wise gating: rescale each channel by a learned weight in (0, 1).

    Each channel is globally average-pooled to one value, passed through a
    two-layer bottleneck (``channels -> channels // reduction -> channels``)
    ending in a sigmoid, and the resulting per-channel weights multiply the
    input feature map.

    Args:
        channels: Number of channels of the input feature map.
        reduction: Bottleneck reduction factor. The hidden width is
            ``channels // reduction`` but never less than 1.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # Guard against a zero-width bottleneck when channels < reduction
        hidden = max(1, channels // reduction)
        self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(channels, hidden)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden, channels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled per channel; ``x`` is a 4-D ``(N, C, H, W)`` tensor."""
        batch_size, channels, _, _ = x.size()
        # Squeeze: one average value per (sample, channel)
        y = self.global_avg_pool(x).view(batch_size, channels)
        # Excite: bottleneck MLP followed by a sigmoid gate
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y).view(batch_size, channels, 1, 1)
        # Broadcast the gate over the spatial dimensions
        return x * y

class ResidualSEBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm layers, an SE gate, and a shortcut.

    The main path is conv -> BN -> ReLU -> dropout -> conv -> BN -> SE. Its
    output is added to the shortcut and passed through a final ReLU. The
    shortcut is a strided 1x1 convolution when the channel count or the
    stride changes, and the identity otherwise.
    """

    def __init__(self, in_channels, out_channels, stride=1, reduction=16):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(p=0.3)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.se = SqueezeExcitationBlock(out_channels, reduction)

        # A projection is needed only when the main path changes the shape
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Conv2d(
                in_channels, out_channels,
                kernel_size=1, stride=stride, bias=False,
            )
        else:
            self.shortcut = nn.Identity()

    def forward(self, x):
        residual = self.shortcut(x)
        out = torch.relu(self.bn1(self.conv1(x)))
        out = self.conv2(self.dropout(out))
        out = self.se(self.bn2(out))  # channel-wise gating before the merge
        return torch.relu(out + residual)

class CIFAR100Model(nn.Module):
    """Three-stage residual SE network with a linear classification head.

    Args:
        input_shape: Number of input channels fed to the stem convolution.
        width_multiplier: Factor applied to the stage widths (16, 32, 64).
        num_classes: Number of output logits.
    """

    def __init__(self, input_shape: int, width_multiplier: int, num_classes: int):
        super().__init__()

        width1, width2, width3 = (base * width_multiplier for base in (16, 32, 64))

        # Stem: one 3x3 convolution that lifts the input to 16 channels
        self.stem = nn.Sequential(
            nn.Conv2d(input_shape, 16, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(16),
            nn.ReLU(),
        )

        # Stages 2 and 3 open with a stride-2 block; stage 1 keeps stride 1
        self.stage1 = self._make_stage(16, width1, num_blocks=2, stride=1)
        self.stage2 = self._make_stage(width1, width2, num_blocks=2, stride=2)
        self.stage3 = self._make_stage(width2, width3, num_blocks=2, stride=2)

        # Head: global average pool -> dropout -> linear classifier
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(p=0.3)
        self.fc = nn.Linear(width3, num_classes)

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first changes channels/stride."""
        blocks = [
            ResidualSEBlock(
                in_channels=in_channels if idx == 0 else out_channels,
                out_channels=out_channels,
                stride=stride if idx == 0 else 1,
            )
            for idx in range(num_blocks)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        features = self.stage3(self.stage2(self.stage1(self.stem(x))))
        pooled = torch.flatten(self.global_pool(features), 1)
        return self.fc(self.dropout(pooled))

In [None]:
# Seed the RNG before constructing the model so weight init is repeatable
torch.manual_seed(0)
# 3 input channels, stage widths scaled by 10, one output logit per class
model = CIFAR100Model(input_shape=3, 
                      width_multiplier=10,
                      num_classes=len(class_names)).to(device)
model

#### Calculating the number of in_features needed for our classifier layer

In [None]:
import torch.nn.functional as F

# Smoke-test the forward pass with a dummy batch of one 3-channel 32x32 image.
# `rand_image_tensor` is not defined anywhere in this notebook, so the input
# is generated here instead.
dummy_batch = torch.randn(1, 3, 32, 32, device=device)

# Eval mode + inference_mode: a pure shape check that leaves the BatchNorm
# running statistics untouched and builds no autograd graph.
model.eval()
with torch.inference_mode():
    output = model(dummy_batch)
print(output.shape)


### Setting up a loss function and optimizer

In [428]:
def accuracy_fn(y_true, y_pred):
    """Return the percentage of predictions that equal the true labels.

    Args:
        y_true: Tensor of ground-truth labels.
        y_pred: Tensor of predicted labels, comparable element-wise with
            ``y_true``.

    Returns:
        Accuracy as a float in ``[0, 100]``; ``0.0`` for an empty batch.
    """
    total = len(y_pred)
    if total == 0:
        # Nothing to score; avoid dividing by zero
        return 0.0
    correct = torch.eq(y_true, y_pred).sum().item()
    return (correct / total) * 100

In [429]:
# Cross-entropy loss with label smoothing of 0.1
loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
# AdamW over all model parameters
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
# NOTE(review): T_max=50, but the training cell runs 115 epochs and
# train_step() steps the scheduler once per epoch. Past T_max the cosine
# schedule presumably rises again instead of staying at its minimum --
# confirm this is intended or align T_max with the epoch count.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)


In [430]:
def print_train_time(start: float,
                     end: float,
                     device: torch.device = None):
    """Print and return the elapsed time between two timestamps.

    Args:
        start: Timestamp taken before training.
        end: Timestamp taken after training.
        device: Label of the device used; only interpolated into the message.

    Returns:
        ``end - start``.
    """
    elapsed = end - start
    print(f"Train time on {device}: {elapsed:.3f} seconds")
    return elapsed

In [None]:
# Sanity check: push one real training batch through the model
X, y = next(iter(train_dataloader))  # Get a batch of data
X = X.to(device)
output = model(X)  # Pass through the model
print(output.shape)  # Should be [batch_size, 100] for CIFAR-100


In [None]:
# Seed the CPU and CUDA random number generators before training
torch.manual_seed(2)
torch.cuda.manual_seed(2)

from timeit import default_timer as timer
from tqdm.auto import tqdm

# Time the whole training run
train_time_model_start_2 = timer()
epochs = 115

for epoch in tqdm(range(epochs)):
    print(f"Epoch {epoch}")
    # One pass over the training data (train_step also steps the scheduler) ...
    train_step(
        model,
        train_dataloader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        scheduler=scheduler,
        accuracy_fn=accuracy_fn,
        device=device,
    )
    # ... followed by an evaluation pass over the test data
    test_step(
        model,
        test_dataloader,
        loss_fn=loss_fn,
        accuracy_fn=accuracy_fn,
        device=device,
    )

train_time_model_end_2 = timer()

total_train_time_model_2 = print_train_time(
    train_time_model_start_2,
    train_time_model_end_2,
    device,
)


In [450]:
import torch.utils
# Re-seed the RNG before the evaluation cells
torch.manual_seed(0)
def eval_model(model: torch.nn.Module,
               data_loader: torch.utils.data.DataLoader,
               loss_fn: torch.nn.Module,
               accuracy_fn,
               device=device):
    """Evaluate ``model`` on ``data_loader`` and return a summary dict.

    The model is put in eval mode and run under ``torch.inference_mode()``.
    Loss and accuracy are summed per batch and divided by the number of
    batches.

    Returns:
        Dict with keys ``"model_name"`` (class name of ``model``),
        ``"model_loss"`` (averaged loss as a Python number) and
        ``"model_acc"`` (averaged result of ``accuracy_fn``).
    """
    total_loss, total_acc = 0, 0
    model.eval()
    with torch.inference_mode():
        for inputs, targets in data_loader:
            # Move the batch to the target device
            inputs, targets = inputs.to(device), targets.to(device)
            logits = model(inputs)

            # Accumulate per-batch metrics
            total_loss += loss_fn(logits, targets)
            total_acc += accuracy_fn(y_true=targets,
                                     y_pred=logits.argmax(dim=1))

        batch_count = len(data_loader)
        total_loss /= batch_count
        total_acc /= batch_count

    return {
        "model_name": model.__class__.__name__,  # class name of the model object
        "model_loss": total_loss.item(),
        "model_acc": total_acc,
    }

In [None]:
# Imported here because this cell runs before the notebook's "Saving the
# model" cell, which is where Path is otherwise first imported.
from pathlib import Path

# Fresh model with the same architecture as the saved checkpoint
model = CIFAR100Model(input_shape=3,  # number of colour channels
                      width_multiplier=10,
                      num_classes=len(class_names)).to(device)
MODEL_NAME = "CIFAR100_model.pth"
MODEL_PATH = Path("models")
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only
# machine (and the other way round).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))


In [None]:
# Score the model on the test loader; the result is a dict with the model's
# class name, averaged loss and averaged accuracy
model_results = eval_model(model=model,
                             data_loader=test_dataloader,
                             loss_fn=loss_fn,
                             accuracy_fn=accuracy_fn,
                             device=device)
model_results

In [467]:
def make_predictions(model: torch.nn.Module,
                     data: list,
                     device: torch.device = device):
    """Return softmax probabilities for each sample in ``data``.

    Every sample is given a leading batch dimension, moved to ``device`` and
    passed through ``model`` (in eval mode, under ``torch.inference_mode()``)
    one at a time. The squeezed logits are turned into probabilities, moved
    back to the CPU and finally stacked into a single tensor.
    """
    model.eval()
    collected = []
    with torch.inference_mode():
        for item in data:
            # Add the batch dimension the model expects
            batched = item.unsqueeze(dim=0).to(device)
            logits = model(batched)
            # logits -> class probabilities
            probs = torch.softmax(logits.squeeze(), dim=0)
            collected.append(probs.cpu())

    return torch.stack(collected)

In [None]:
import random
random.seed(1)

# Sample 81 distinct dataset positions and fetch only those items, rather than
# loading and transforming the whole test set into a list just to pick 81.
sample_indices = random.sample(range(len(test_data)), k=81)

test_samples = []
test_labels = []

for idx in sample_indices:
    sample, label = test_data[idx]
    test_samples.append(sample)
    test_labels.append(label)

test_samples[0].shape

In [None]:
# Show the first sampled test image: (C, H, W) -> (H, W, C), undo the
# normalization and clamp into imshow's valid [0, 1] range for floats.
sample = unnormalize_img(test_samples[0].permute(1, 2, 0), mean, std).clamp(0, 1)
plt.imshow(sample)
plt.title(class_names[test_labels[0]])
plt.show()


In [470]:
# One row of class probabilities per sampled test image
pred_probs = make_predictions(model=model,
                              data=test_samples)


In [None]:
# Index of the most probable class for every sample
pred_classses = pred_probs.argmax(dim=1)
pred_classses

In [None]:
# Plot the sampled test images in a 9x9 grid, titled with prediction vs. truth
plt.figure(figsize=(81, 81))
nrows = 9
ncols = 9
for i, sample in enumerate(test_samples):
    plt.subplot(nrows, ncols, i + 1)

    # (C, H, W) -> (H, W, C), undo the normalization, then clamp so float
    # round-off does not leave imshow's valid [0, 1] range.
    sample = unnormalize_img(sample.permute(1, 2, 0), mean, std).clamp(0, 1)
    plt.imshow(sample)

    # Predicted and ground-truth labels in text form
    pred_label = class_names[pred_classses[i]]
    truth_label = class_names[test_labels[i]]

    title_text = f"pred: {pred_label} | truth: {truth_label}"

    # Green title for a correct prediction, red for a wrong one
    title_color = "g" if pred_label == truth_label else "r"
    plt.title(title_text, fontsize=30, c=title_color)

    plt.axis(False)

### Saving the model

In [None]:
from pathlib import Path

# Target location: models/CIFAR100_model.pth (the directory is created if needed)
MODEL_NAME = "CIFAR100_model.pth"
MODEL_PATH = Path("models")
MODEL_PATH.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# Persist only the learned parameters (state_dict), not the whole module
print(f"Saving model to: {MODEL_SAVE_PATH}")
torch.save(model.state_dict(), MODEL_SAVE_PATH)


### Loading it back

In [None]:
# Fresh model with the same architecture as the checkpoint to restore
loaded_model = CIFAR100Model(input_shape=3,  # number of colour channels
                             width_multiplier=10,
                             num_classes=len(class_names)).to(device)
MODEL_NAME = "CIFAR100_model_(70_percent_acc).pth"
MODEL_PATH = Path("models")
# Rebuild the full path from the new file name; without this the stale
# MODEL_SAVE_PATH from earlier cells (CIFAR100_model.pth) would be loaded.
MODEL_SAVE_PATH = MODEL_PATH / MODEL_NAME

# map_location lets a checkpoint saved on a GPU be restored on a CPU-only machine
loaded_model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
