[![Fixel Algorithms](https://i.imgur.com/AqKHVZ0.png)](https://fixelalgorithms.gitlab.io)

# AI Program

## Deep Learning - Vision Transformer (ViT) - Classification of Fashion MNIST

> Notebook by:
> - Royi Avital RoyiAvital@fixelalgorithms.com

## Revision History

| Version | Date       | User        |Content / Changes                                                   |
|---------|------------|-------------|--------------------------------------------------------------------|
| 1.0.000 | 21/09/2025 | Royi Avital | First version                                                      |

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/FixelAlgorithmsTeam/FixelCourses/blob/master/AIProgram/2024_02/0089DeepLearningPyTorchSchedulers.ipynb)

In [None]:
# Import Packages

# General Tools
import numpy as np
import scipy as sp
import pandas as pd

# Machine Learning
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

# Deep Learning
import torch
import torch.nn            as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import torchinfo
from torchmetrics.classification import MulticlassAccuracy
import torchvision
import torchvista

# Miscellaneous
from platform import python_version
import random

# Typing
from typing import Callable, Dict, Generator, List, Literal, Optional, Self, Set, Tuple, Union
from numpy.typing import NDArray

# Visualization
import matplotlib.pyplot as plt

# Jupyter
from IPython import get_ipython

## Notations

* <font color='red'>(**?**)</font> Question to answer interactively.
* <font color='blue'>(**!**)</font> Simple task to add code for the notebook.
* <font color='green'>(**@**)</font> Optional / Extra self practice.
* <font color='brown'>(**#**)</font> Note / Useful resource / Food for thought.

Code Notations:

```python
someVar    = 2; #<! Notation for a variable
vVector    = np.random.rand(4) #<! Notation for 1D array
mMatrix    = np.random.rand(4, 3) #<! Notation for 2D array
tTensor    = np.random.rand(4, 3, 2, 3) #<! Notation for nD array (Tensor)
tuTuple    = (1, 2, 3) #<! Notation for a tuple
lList      = [1, 2, 3] #<! Notation for a list
dDict      = {1: 3, 2: 2, 3: 1} #<! Notation for a dictionary
oObj       = MyClass() #<! Notation for an object
dfData     = pd.DataFrame() #<! Notation for a data frame
dsData     = pd.Series() #<! Notation for a series
hObj       = plt.Axes() #<! Notation for an object / handler / function handler
```

### Code Exercise

 - Single line fill

```python
valToFill = ???
```

 - Multi Line to Fill (At least one)

```python
# You need to start writing
?????
```

 - Section to Fill

```python
#===========================Fill This===========================#
# 1. Explanation about what to do.
# !! Remarks to follow / take under consideration.
mX = ???

?????
#===============================================================#
```

In [None]:
# Configuration
# %matplotlib inline

seedNum = 512
np.random.seed(seedNum)
random.seed(seedNum)

# Matplotlib default color palette
lMatPltLibclr = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', '#8c564b', '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']
# sns.set_theme() #>! Apply SeaBorn theme

runInGoogleColab = 'google.colab' in str(get_ipython())

# Improve performance by benchmarking
torch.backends.cudnn.benchmark = True

# Reproducibility (Per PyTorch Version on the same device)
# torch.manual_seed(seedNum)
# torch.backends.cudnn.deterministic = True
# torch.backends.cudnn.benchmark     = False #<! Makes things slower

In [None]:
# Constants

FIG_SIZE_DEF    = (8, 8)
ELM_SIZE_DEF    = 50
CLASS_COLOR     = ('b', 'r')
EDGE_COLOR      = 'k'
MARKER_SIZE_DEF = 10
LINE_WIDTH_DEF  = 2

D_CLASSES_MNIST = {0: 'T-Shirt', 1: 'Trouser', 2: 'Pullover', 3: 'Dress', 4: 'Coat', 5: 'Sandal', 6: 'Shirt', 7: 'Sneaker', 8: 'Bag', 9: 'Boots'}
L_CLASSES_MNIST = ['T-Shirt', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Boots']

T_IMG_SIZE_MNIST = (28, 28)

DATA_FOLDER_PATH  = 'Data'
TENSOR_BOARD_BASE = 'TB'
WANDB_API_KEY     = 'WANDB_API_KEY'

In [None]:
# Download Auxiliary Modules for Google Colab
if runInGoogleColab:
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DataManipulation.py
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DataVisualization.py
    !wget https://raw.githubusercontent.com/FixelAlgorithmsTeam/FixelCourses/master/AIProgram/2024_02/DeepLearningPyTorch.py

In [None]:
# Courses Packages

from DataVisualization import PlotConfusionMatrix, PlotLabelsHistogram, PlotMnistImages
from DeepLearningPyTorch import TrainModel

In [None]:
# General Auxiliary Functions


## The Vision Transformer (ViT)

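The _Vision Transformer_ (ViT) applies a Transformer Encoder to images by treating each image as a sequence of patches:

 - The image is split into non overlapping patches of size `patchSize x patchSize`.
 - Each patch is flattened and linearly projected into a vector of size `embedDim` (Patch embedding).
 - A learnable _Classification Token_ (CLS) is prepended to the sequence of patches and a learnable positional embedding is added to every token.
 - The sequence is processed by a stack of _Transformer Encoder_ blocks.
 - The output of the CLS token is fed into a classification head.

</br>

* <font color='brown'>(**#**)</font> The image dimensions must be divisible by `patchSize`.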

In [None]:
# Parameters

# Data
numSamplesTrain = 65_000
numSamplesVal   = 5_000

# Model
patchSize = 7
embedDim  = 96
hiddenDim = 160
numHeads  = 12 #<! `embedDim` must be divisible by `numHeads`
numLayers = 6
dropP     = 0.025 #<! Dropout Layer

# Training
batchSize = 256
numWork   = 2 #<! Number of workers
numEpochs = 80

# Visualization
numImg = 3

## Generate / Load Data

Load the [Fashion MNIST Data Set](https://github.com/zalandoresearch/fashion-mnist).  

The _Fashion MNIST Data Set_ is considerably more challenging than the original MNIST, though it is still no match for Deep Learning models.

* <font color='brown'>(**#**)</font> The data set is available at [OpenML - Fashion MNIST](https://www.openml.org/search?type=data&id=40996).  
  Yet it is not separated into the original _test_ and _train_ sets.

In [None]:
# Load Data

mX, vY = fetch_openml('Fashion-MNIST', version = 1, return_X_y = True, as_frame = False, parser = 'auto')
# mX, vY = fetch_openml('mnist_784', version = 1, return_X_y = True, as_frame = False, parser = 'auto') #<! For debugging (Gets 99.1%)
vY = vY.astype(np.int_) #<! The labels are strings, convert to integer

print(f'The features data shape: {mX.shape}')
print(f'The labels data shape: {vY.shape}')
print(f'The unique values of the labels: {np.unique(vY)}')

* <font color='brown'>(**#**)</font> The images are grayscale with size `28x28`.

In [None]:
# Pre Process Data

mX = mX / 255.0
mX = mX.astype(np.float32) #<! Convert to `Float32` for PyTorch

### Plot the Data

In [None]:
# Plot the Data

hF = PlotMnistImages(mX, vY, numImg)
plt.show()

In [None]:
# Histogram of Labels

vHa = PlotLabelsHistogram(vY, lClass = L_CLASSES_MNIST)
plt.show()

In [None]:
# Train & Validation Split

numClass = len(np.unique(vY))

mXTrain, mXVal, vYTrain, vYVal = train_test_split(mX, vY, test_size = numSamplesVal, train_size = numSamplesTrain, shuffle = True, stratify = vY)

print(f'The training features data shape  : {mXTrain.shape}')
print(f'The training labels data shape    : {vYTrain.shape}')
print(f'The validation features data shape: {mXVal.shape}')
print(f'The validation labels data shape  : {vYVal.shape}')
print(f'The unique values of the labels   : {np.unique(vY)}')

In [None]:
# PyTorch DataSet

dsTrain  = torch.utils.data.TensorDataset(torch.tensor(np.reshape(mXTrain, (-1, 1, *T_IMG_SIZE_MNIST))), torch.tensor(vYTrain)) #<! -1 -> Infer
dsVal    = torch.utils.data.TensorDataset(torch.tensor(np.reshape(mXVal, (-1, 1, *T_IMG_SIZE_MNIST))), torch.tensor(vYVal))

print(f'The training data set data shape  : {dsTrain.tensors[0].shape}')
print(f'The validation data set data shape: {dsVal.tensors[0].shape}')
print(f'The unique values of the labels   : {np.unique(dsTrain.tensors[1])}')

* <font color='brown'>(**#**)</font> The dataset is indexable (Subscriptable). It returns a tuple of the features and the label.

In [None]:
# Element of the Data Set

tX, valY = dsTrain[0]

print(f'The features shape: {tX.shape}')
print(f'The label value: {valY}')

### Data Loaders


In [None]:
# Data Loader

dlTrain = torch.utils.data.DataLoader(dsTrain, shuffle = True, batch_size = 1 * batchSize, num_workers = numWork, persistent_workers = True)
dlVal   = torch.utils.data.DataLoader(dsVal, shuffle = False, batch_size = 2 * batchSize, num_workers = numWork, persistent_workers = True)

* <font color='red'>(**?**)</font> Why is the size of the batch twice as big for the validation dataset?

In [None]:
# Iterate on the Loader
# The first batch.
tX, vY = next(iter(dlTrain)) #<! PyTorch Tensors

print(f'The batch features dimensions: {tX.shape}')
print(f'The batch labels dimensions: {vY.shape}')

## ViT Building Blocks

### Image to Patches

In [None]:
# Patch embedding using a convolution (Kernel size = stride = patch size)
class PatchEmbedding(nn.Module):
    def __init__(self, tuImgSize: Tuple[int, int], patchSize: int, inChans: int = 1, embedDim: int = 768) -> None:
        super().__init__()
        
        self.imgHeight, self.imgWidth = tuImgSize
        self.patchSize = patchSize
        self.inChans = inChans
        self.embedDim = embedDim
        self.numPatches = (self.imgHeight // self.patchSize) * (self.imgWidth // self.patchSize)

        # Per patch linear projection (Non overlapping patches since kernel size = stride)
        self.projLayer = nn.Conv2d(inChans, embedDim, kernel_size = patchSize, stride = patchSize)

    def forward(self, tX: torch.Tensor) -> torch.Tensor:

        # Only for Debugging
        # batchSize, numChannel, imgHeight, imgWidth = tX.shape
        # assert (imgHeight == self.imgHeight and imgWidth == self.imgWidth), f'Input image size ({imgHeight}, {imgWidth}) does not match model ({self.imgHeight}, {self.imgWidth}).'
        # assert (imgHeight % self.patchSize == 0 and imgWidth % self.patchSize == 0), f'Image dimensions must be divisible by the patch size ({self.patchSize}, {self.patchSize}).'
        # assert (numChannel == self.inChans), f'Input image channels ({numChannel}) does not match the expected number of channels.'

        tX = self.projLayer(tX) #<! Shape: (B, embedDim, H // patchSize, W // patchSize)
        tX = tX.flatten(2)      #<! Shape: (B, embedDim, numPatches)
        tX = tX.transpose(1, 2) #<! Shape: (B, numPatches, embedDim)

        return tX

In [None]:
# Image to Patches Function
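def ImgToPatches(tX: torch.Tensor, patchSize: int, /, *, flattenChannels: bool = True) -> torch.Tensor:
    # Splits a batch of images (B, C, H, W) into a sequence of non overlapping patches per image
    # Assumes `H` and `W` are divisible by `patchSize`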

    batchSize, numChannel, imgHeight, imgWidth = tX.shape

    tX = tX.reshape(batchSize, numChannel, imgHeight // patchSize, patchSize, imgWidth // patchSize, patchSize)
    tX = tX.permute(0, 2, 4, 1, 3, 5) #<! (B, H', W', C, p_H, p_W)
    tX = tX.flatten(1, 2)             #<! (B, H' * W', C, p_H, p_W)
    if flattenChannels:
        tX = tX.flatten(2, 4)         #<! (B, H' * W', C * p_H * p_W)

    return tX
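The convolution based class above and the `ImgToPatches()` function followed by an `nn.Linear` layer compute the same patch embedding: a `Conv2d` with `kernel_size = stride = patchSize` is a linear map applied independently to each non overlapping patch. The sketch below checks this numerically on random data, assuming single channel `28x28` images, `patchSize = 7` and `embedDim = 96`. The helper `PatchifyFlat()` is a hypothetical standalone copy of the reshape logic of `ImgToPatches()`.

```python
import torch
import torch.nn as nn

def PatchifyFlat(tX: torch.Tensor, patchSize: int) -> torch.Tensor:
    # Same reshape / permute steps as `ImgToPatches(tX, patchSize, flattenChannels = True)`
    numSamples, numChannel, imgHeight, imgWidth = tX.shape
    tX = tX.reshape(numSamples, numChannel, imgHeight // patchSize, patchSize, imgWidth // patchSize, patchSize)
    tX = tX.permute(0, 2, 4, 1, 3, 5)   #<! (B, H', W', C, p_H, p_W)
    tX = tX.flatten(1, 2).flatten(2, 4) #<! (B, H' * W', C * p_H * p_W)
    return tX

pSize, numChan, embDim = 7, 1, 96
tImg = torch.randn(8, numChan, 28, 28) #<! Random batch of 8 "images"

# Patch embedding as a strided convolution (Kernel size = stride = patch size)
oConv = nn.Conv2d(numChan, embDim, kernel_size = pSize, stride = pSize)
tConv = oConv(tImg).flatten(2).transpose(1, 2) #<! (B, numPatches, embDim)

# Patch embedding as patchify + `nn.Linear`, reusing the convolution weights
oLin = nn.Linear(numChan * pSize * pSize, embDim)
with torch.no_grad():
    oLin.weight.copy_(oConv.weight.reshape(embDim, -1)) #<! (E, C, p, p) -> (E, C * p * p)
    oLin.bias.copy_(oConv.bias)
tLin = oLin(PatchifyFlat(tImg, pSize)) #<! (B, numPatches, embDim)

print(f'Shapes: {tuple(tConv.shape)}, {tuple(tLin.shape)}')
print(f'Max absolute difference: {(tConv - tLin).abs().max().item():.2e}')
```

The difference should be at the level of floating point round off. Reshaping the convolution weight to `(embedDim, -1)` works because both layouts order the pixels of each patch as `(C, p_H, p_W)` and the patches in row major order.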

In [None]:
# Plot Patches

tP = ImgToPatches(tX, patchSize, flattenChannels = False) #<! Shape: (B, numPatches, C, patchSize, patchSize)
tP = tP[:4] #<! First 4 images of the batch
print(f'The patches dimensions: {tP.shape}')

hF, vHa = plt.subplots(tP.shape[0], 1, figsize = (14, 3))
vHa     = vHa.flat
hF.suptitle(f'Batch Images as Sequences of Patches ({patchSize}, {patchSize})')
for ii, hA in enumerate(vHa):
    tImgGrid = torchvision.utils.make_grid(tP[ii], nrow = 64, normalize = True, pad_value = 0.75) #<! All patches in a single row
    tImgGrid = tImgGrid.permute(1, 2, 0) #<! (C, H, W) -> (H, W, C) for Matplotlib
    hA.imshow(tImgGrid)
    hA.axis('off')

plt.show()

### Transformer Encoder

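Throughout the Transformer layers, the CLS token interacts with all the image patch embeddings via the self attention mechanism.  
This allows it to gather information and learn a global representation which encapsulates the context of the entire image.  
After processing through the _Transformer Encoder_, the final hidden state corresponding to the CLS token is passed through a classification head (In this notebook, a `LayerNorm` followed by a `Linear` layer) to predict the image's class.

Each `TransformerEncoder` block below uses the _Pre Norm_ arrangement: a `LayerNorm` is applied before the Multi Head Attention and before the feed forward network, and each sub layer is wrapped by a residual connection.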

In [None]:
class TransformerEncoder(nn.Module):
    def __init__(self: Self, embedDim: int, hiddenDim: int, numHeads: int, /, *, dropoutProb: float = 0.0) -> None:
        super().__init__()

        self.layerNorm01 = nn.LayerNorm(embedDim)
        self.multiAttn = nn.MultiheadAttention(embedDim, numHeads)
        self.layerNorm02 = nn.LayerNorm(embedDim)
        self.ffNet = nn.Sequential(
            nn.Linear(embedDim, hiddenDim),
            nn.GELU(),
            nn.Dropout(dropoutProb),
            nn.Linear(hiddenDim, embedDim),
            nn.Dropout(dropoutProb),
        )

    def forward(self: Self, tX: torch.Tensor) -> torch.Tensor:
        tZ = self.layerNorm01(tX)
        tX = tX + self.multiAttn(tZ, tZ, tZ, need_weights = False)[0] #<! Index 0: the attention output (Index 1: the attention weights)
        tX = tX + self.ffNet(self.layerNorm02(tX))
        return tX
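`nn.MultiheadAttention` returns a tuple of the attention output and the attention weights, and by default (`batch_first = False`) it expects a `(sequence, batch, embedDim)` layout, which is why the model transposes the sequence before the encoder. A minimal sketch with illustrative sizes (All names are local to the example):

```python
import torch
import torch.nn as nn

dimEmbed, numAttnHeads = 96, 12
seqLen, numSamples     = 10, 4 #<! Illustrative sequence length and batch size

oAttn = nn.MultiheadAttention(dimEmbed, numAttnHeads) #<! Default: `batch_first = False`
tZ    = torch.randn(seqLen, numSamples, dimEmbed)     #<! (L, B, E) layout

# The module returns a tuple: (attention output, attention weights)
tOut, tWeights = oAttn(tZ, tZ, tZ)
print(tOut.shape)     #<! Same shape as the input
print(tWeights.shape) #<! (B, L, L), averaged over the heads by default

# With `need_weights = False` the weights are not returned (`None`)
tOutFast, tNoWeights = oAttn(tZ, tZ, tZ, need_weights = False)
print(tNoWeights)
```

Skipping the weights does not change the attention output, it only avoids computing and returning the `(B, L, L)` matrix.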

## Define the Model

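The model is a custom `nn.Module`: the patch embedding, the CLS token and the positional embedding are followed by a stack (`nn.Sequential`) of `TransformerEncoder` blocks and a classification head.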

In [None]:
class VisionTransformer(nn.Module):
    def __init__(
        self: Self,
        patchSize: int,
        numPatches: int,
        numChannels: int,
        numClasses: int,
        /, *,
        embedDim: int = 256,
        hiddenDim: int = 512,
        numHeads: int = 8,
        numLayers: int = 6,
        dropoutProb: float = 0.0,
    ) -> None:
        """
        Simplified Vision Transformer (ViT) model.

        Input:
            patchSize   - Number of pixels that the patches have per dimension.
            numPatches  - Maximum number of patches an image can have.
            numChannels - Number of channels of the input image.
            numClasses  - Number of classes to predict.
            embedDim    - Dimensionality of the input feature vectors to the Transformer.
            hiddenDim   - Dimensionality of the hidden layer in the feed forward networks within the Transformer.
            numHeads    - Number of heads to use in the Multi Head Attention block.
            numLayers   - Number of layers to use in the Transformer.
            dropoutProb - Probability of dropout to apply in the feed forward network and on the input encoding.

        """
        super().__init__()

        self.patchSize = patchSize

        # Layers / Networks
        self.vitEmbedder = nn.Linear(numChannels * (patchSize * patchSize), embedDim) #<! Patch embedding
        self.vitEncoder = nn.Sequential(
            *(TransformerEncoder(embedDim, hiddenDim, numHeads, dropoutProb = dropoutProb) for _ in range(numLayers))
        )
        self.mlpHead = nn.Sequential(nn.LayerNorm(embedDim), nn.Linear(embedDim, numClasses))
        self.dropout = nn.Dropout(dropoutProb)

        # Parameters / Embeddings
        self.clsToken = nn.Parameter(torch.randn(1, 1, embedDim)) #<! Learnable CLS Token
        self.posEmbedding = nn.Parameter(torch.randn(1, 1 + numPatches, embedDim)) #<! Learnable Positional Encoding

    def forward(self: Self, tX: torch.Tensor) -> torch.Tensor:
        # Split the images into flattened patches: (B, C, H, W) -> (B, T, C * p * p)
        tX = ImgToPatches(tX, self.patchSize)
        numSamples, numTokens, _ = tX.shape
        tX = self.vitEmbedder(tX) #<! (B, T, embedDim)

        # Prepend the CLS token and add the positional encoding
        tCls = self.clsToken.expand(numSamples, -1, -1) #<! (B, 1, embedDim)
        tX   = torch.cat([tCls, tX], dim = 1)            #<! (B, T + 1, embedDim)
        tX   = tX + self.posEmbedding[:, :(numTokens + 1)]

        # Apply the Transformer (`nn.MultiheadAttention` expects (T + 1, B, embedDim) by default)
        tX = self.dropout(tX)
        tX = tX.transpose(0, 1)
        tX = self.vitEncoder(tX)

        # Classification prediction from the CLS token
        tCls = tX[0]              #<! (B, embedDim)
        tY   = self.mlpHead(tCls) #<! (B, numClasses)

        return tY

In [None]:
# Set the Model
oModel = VisionTransformer(
    patchSize,
    (T_IMG_SIZE_MNIST[0] // patchSize) * (T_IMG_SIZE_MNIST[1] // patchSize),
    1,
    numClass,
    embedDim    = embedDim,
    hiddenDim   = hiddenDim,
    numHeads    = numHeads,
    numLayers   = numLayers,
    dropoutProb = dropP,
)

In [None]:
# Model Summary
torchinfo.summary(oModel, tX.shape, col_names = ['kernel_size', 'input_size', 'output_size', 'num_params'], device = 'cpu') #<! Added `kernel_size`
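The parameter count can be verified by hand. The sketch below, assuming the parameters of this notebook (`embedDim = 96`, `hiddenDim = 160`, `numHeads = 12`, `numLayers = 6`, `patchSize = 7`, single channel `28x28` images and 10 classes), counts the parameters of each component analytically. `CountParams()` and `CountVitParams()` are hypothetical helpers for illustration.

```python
import torch.nn as nn

def CountParams(oModule: nn.Module) -> int:
    # Number of trainable scalars in a module
    return sum(p.numel() for p in oModule.parameters() if p.requires_grad)

def CountVitParams(embedDim: int, hiddenDim: int, numHeads: int, numLayers: int, patchSize: int, numChannels: int, numClasses: int, imgSize: int = 28) -> int:
    numPatches = (imgSize // patchSize) ** 2

    # One pre norm encoder block: 2 x LayerNorm, attention (Input + output projections), 2 layers FFN
    numNorm  = 2 * (2 * embedDim)
    numAttn  = (3 * embedDim * embedDim + 3 * embedDim) + (embedDim * embedDim + embedDim)
    numFfn   = (embedDim * hiddenDim + hiddenDim) + (hiddenDim * embedDim + embedDim)
    numBlock = numNorm + numAttn + numFfn

    # Cross check the analytic block count against the same PyTorch layers
    lLayers = nn.ModuleList([nn.LayerNorm(embedDim), nn.MultiheadAttention(embedDim, numHeads), nn.LayerNorm(embedDim), nn.Linear(embedDim, hiddenDim), nn.Linear(hiddenDim, embedDim)])
    assert CountParams(lLayers) == numBlock

    numEmbed   = numChannels * patchSize * patchSize * embedDim + embedDim #<! Patch embedding (`nn.Linear`)
    numTokens  = embedDim + (1 + numPatches) * embedDim                     #<! CLS token + positional embedding
    numMlpHead = 2 * embedDim + embedDim * numClasses + numClasses          #<! `LayerNorm` + `nn.Linear`

    return numEmbed + numTokens + numLayers * numBlock + numMlpHead

numParamsVit = CountVitParams(96, 160, 12, 6, 7, 1, 10)
print(f'Analytic ViT parameter count: {numParamsVit:,}')
```

The result can be compared with `sum(p.numel() for p in oModel.parameters())`. Most of the parameters are in the encoder blocks, mainly in the attention projections (`4 * embedDim * embedDim` weights per block).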

In [None]:
# View Model

torchvista.trace_model(oModel, tX)

* <font color='red'>(**?**)</font> The Encoder output has shape of `(17, batchSize, embedDim)`. Explain the `17`.

<!-- The input sequence is of length 17: since `patchSize = 7` and the images are `(28, 28)`, there are `4 * 4 = 16` patches plus the `CLS` token, each of dimension `embedDim = 96`. -->

## Train the Model

This section trains the model:

 - Uses the `AdamW` optimizer with a `OneCycleLR` learning rate scheduler.
 - Trains the model using `TrainModel()` from the course's `DeepLearningPyTorch` module.
 - Analyzes the results using the confusion matrix of the train and validation sets.

In [None]:
# Run Device

runDevice = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') #<! The 1st CUDA device
oModel = oModel.to(runDevice) #<! Transfer model to device

In [None]:
# Loss and Score Function

hL = nn.CrossEntropyLoss()
hS = MulticlassAccuracy(num_classes = len(L_CLASSES_MNIST), average = 'micro')
hL = hL.to(runDevice) #<! Not required!
hS = hS.to(runDevice)

In [None]:
# Define Optimizer & Scheduler

oOpt = torch.optim.AdamW(oModel.parameters(), lr = 1e-3, betas = (0.9, 0.99), weight_decay = 1e-3) #<! Define optimizer
oSch = torch.optim.lr_scheduler.OneCycleLR(oOpt, max_lr = 9.5e-3, total_steps = numEpochs)
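With the default arguments, `OneCycleLR` starts at `max_lr / div_factor` (`div_factor = 25`), increases to `max_lr` over the first `pct_start = 0.3` of the steps and then anneals (Cosine) down to the initial rate divided by `final_div_factor = 1e4`. Setting `total_steps = numEpochs` assumes `TrainModel()` calls `oSch.step()` once per epoch. The sketch below steps an identically configured scheduler on a dummy parameter (`vW` and the `...Demo` names are illustrative) to record the learning rate per epoch:

```python
import torch

maxLr, totalSteps = 9.5e-3, 80 #<! Same as `max_lr` and `total_steps = numEpochs` above

vW       = torch.nn.Parameter(torch.zeros(1)) #<! Dummy parameter, only to build an optimizer
oOptDemo = torch.optim.AdamW([vW], lr = 1e-3, betas = (0.9, 0.99), weight_decay = 1e-3)
oSchDemo = torch.optim.lr_scheduler.OneCycleLR(oOptDemo, max_lr = maxLr, total_steps = totalSteps)

lLrDemo = []
for _ in range(totalSteps):
    lLrDemo.append(oSchDemo.get_last_lr()[0]) #<! Learning rate used during this epoch
    oOptDemo.step()                            #<! No gradients, the parameter is left unchanged
    oSchDemo.step()

peakIdx = lLrDemo.index(max(lLrDemo))
print(f'First: {lLrDemo[0]:.2e}, Peak: {lLrDemo[peakIdx]:.2e} (Epoch {peakIdx}), Last: {lLrDemo[-1]:.2e}')
```

The recorded curve should look like the _Learn Rate Scheduler_ panel of the training plot, provided the assumption on the stepping frequency holds.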

In [None]:
# Train the Model

oRunModel, lTrainLoss, lTrainScore, lValLoss, lValScore, lLearnRate = TrainModel(oModel, dlTrain, dlVal, oOpt, numEpochs, hL, hS, oSch = oSch)

In [None]:
# Plot Training Phase

hF, vHa = plt.subplots(nrows = 1, ncols = 3, figsize = (12, 5))
vHa = vHa.flat

hA = vHa[0]
hA.plot(lTrainLoss, lw = 2, label = 'Train')
hA.plot(lValLoss, lw = 2, label = 'Validation')
hA.set_title('Cross Entropy Loss')
hA.set_xlabel('Epoch')
hA.set_ylabel('Loss')
hA.legend()

hA = vHa[1]
hA.plot(lTrainScore, lw = 2, label = 'Train')
hA.plot(lValScore, lw = 2, label = 'Validation')
hA.set_title('Accuracy Score')
hA.set_xlabel('Epoch')
hA.set_ylabel('Score')
hA.legend()

hA = vHa[2]
hA.plot(lLearnRate, lw = 2)
hA.set_title('Learn Rate Scheduler')
hA.set_xlabel('Epoch')
hA.set_ylabel('Learn Rate')

In [None]:
# Analysis

# Aggregate results for Train Set

lYPred = []
lY     = []

oModel.eval() #<! Disable the Dropout layers for inference

for ii, (tX, vY) in enumerate(dlTrain):
    # Move Data to Model's device
    tX = tX.to(runDevice) #<! Lazy
    vY = vY.to(runDevice) #<! Lazy

    with torch.inference_mode():
        mZ = oModel(tX) #<! Model output
        vYPred = torch.argmax(mZ, dim = 1)

    lYPred.append(vYPred.detach().cpu().numpy())
    lY.append(vY.detach().cpu().numpy())

vYPredTrain  = np.concatenate(lYPred, axis = 0)
vYTruthTrain = np.concatenate(lY, axis = 0)

In [None]:
# Analysis

# Aggregate results for Validation Set

lYPred = []
lY     = []

oModel.eval() #<! Disable the Dropout layers for inference

for ii, (tX, vY) in enumerate(dlVal):
    # Move Data to Model's device
    tX = tX.to(runDevice) #<! Lazy
    vY = vY.to(runDevice) #<! Lazy

    with torch.inference_mode():
        mZ = oModel(tX) #<! Model output
        vYPred = torch.argmax(mZ, dim = 1)

    lYPred.append(vYPred.detach().cpu().numpy())
    lY.append(vY.detach().cpu().numpy())

vYPredTest  = np.concatenate(lYPred, axis = 0)
vYTruthTest = np.concatenate(lY, axis = 0)

In [None]:
# Analysis
# Confusion Matrix

hF, vHa = plt.subplots(nrows = 1, ncols = 2, figsize = (14, 6))

hA, _ = PlotConfusionMatrix(vYTruthTrain, vYPredTrain, hA = vHa[0], lLabels = L_CLASSES_MNIST, xLabelRot = 45)
hA.set_title(f'Train Data, Accuracy {np.mean(vYTruthTrain == vYPredTrain): 0.2%}')

hA, _ = PlotConfusionMatrix(vYTruthTest, vYPredTest, hA = vHa[1], lLabels = L_CLASSES_MNIST, xLabelRot = 45)
hA.set_title(f'Validation Data, Accuracy {np.mean(vYTruthTest == vYPredTest): 0.2%}');

* <font color='green'>(**@**)</font> Create a CNN model with ~400,000 parameters (Similar to the ViT model above) and compare the results. Explain.