# Vision Transformers: From Theory to Practice

This notebook explores **Vision Transformers (ViT)** through computational analysis and empirical experiments. We'll investigate:

1. **Computational Complexity** - How patch size affects FLOPs, parameters, and memory
2. **Ablation Studies** - The impact of architectural choices (positional embeddings, CLS token)
3. **CIFAR-10 Training** - Practical implementation and performance analysis

**Paper Reference**: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2021)

---

## Part 1: Computational Complexity Analysis

### Understanding ViT Computational Trade-offs

Vision Transformers process images as sequences of patches. The patch size directly impacts:
- **Sequence length**: Smaller patches → more tokens → quadratic attention cost
- **Information granularity**: Larger patches → coarser features
- **Model capacity**: Patch embedding parameters vs. transformer parameters

Let's analyze how different patch sizes affect computational complexity.
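To make the patch-to-token pipeline concrete, here is a minimal PyTorch sketch of a ViT-style patch embedding. It assumes the common formulation of a `Conv2d` whose kernel size and stride both equal the patch size, which is consistent with the `patch_embedding.proj.weight` shape `(512, 3, 4, 4)` in the model summaries below. The class `PatchEmbeddingSketch` is hypothetical and is not the `neural_stack` implementation.

```python
import torch
import torch.nn as nn


class PatchEmbeddingSketch(nn.Module):
    """Hypothetical patch embedding: image -> (CLS + patch tokens) with learned 1D positions."""

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=512):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A conv with kernel_size == stride == patch_size applies the same linear
        # projection to every non-overlapping patch.
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embedding = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

    def forward(self, x):
        x = self.proj(x)                  # (B, D, H/p, W/p)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D), patches in row-major order
        cls = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls, x], dim=1)    # (B, N + 1, D)
        return x + self.pos_embedding     # one learned position per token


tokens = PatchEmbeddingSketch(patch_size=4)(torch.randn(2, 3, 32, 32))
print(tokens.shape)  # torch.Size([2, 65, 512])
```

Halving the patch side quadruples the number of tokens the transformer has to process, which is the main driver of the compute differences measured below.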

#### Imports

In [None]:
import os
import torch
import neural_stack

from neural_stack.utils import model_summary
from neural_stack.models.vision_transformer import VisionTransformer

#### Hyperparameters

In [2]:
IMAGE_SIZE = (32, 32)
NUM_CHANNELS = 3

NUM_LAYERS = 6
NUM_HEADS = 8
EMBED_DIM = 512
PATCH_SIZE = 4
MLP_RATIO = 4

#### Model Configuration

In [3]:
dummy_input = torch.randn((1, NUM_CHANNELS, *IMAGE_SIZE))

for patch_size in [4, 8, 16]:
    vit_model = VisionTransformer(
        img_size=IMAGE_SIZE,
        patch_size=patch_size,
        in_channels=NUM_CHANNELS,
        embed_dim=EMBED_DIM,
        num_heads=NUM_HEADS,
        num_layers=NUM_LAYERS,
        mlp_ratio=MLP_RATIO,
        dropout=0.1,
        num_classes=2
    )

    num_params, num_flops, num_acts, summary = model_summary(vit_model, dummy_input)
    num_patches = IMAGE_SIZE[0] * IMAGE_SIZE[1] // (patch_size ** 2)
    print(f"Patch Size {patch_size}x{patch_size}, Total Patches={num_patches}; #params={num_params}; #flops={num_flops}; #activations={num_acts}")
    print(summary)

    del vit_model

Unsupported operator aten::repeat encountered 1 time(s)
Unsupported operator aten::add encountered 13 time(s)
Unsupported operator aten::layer_norm encountered 13 time(s)
Unsupported operator aten::div encountered 6 time(s)
Unsupported operator aten::softmax encountered 6 time(s)
Unsupported operator aten::gelu encountered 6 time(s)


Patch Size 4x4, Total Patches=64; #params=18975234; #flops=1256365568; #activations=2232370
| module                                 | #parameters or shape   | #flops     | #activations   |
|:---------------------------------------|:-----------------------|:-----------|:---------------|
| model                                  | 18.975M                | 1.256G     | 2.232M         |
|  patch_embedding                       |  58.88K                |  1.573M    |  32.768K       |
|   patch_embedding.pos_embedding        |   (1, 65, 512)         |            |                |
|   patch_embedding.cls_token            |   (1, 1, 512)          |            |                |
|   patch_embedding.proj                 |   25.088K              |   1.573M   |   32.768K      |
|    patch_embedding.proj.weight         |    (512, 3, 4, 4)      |            |                |
|    patch_embedding.proj.bias           |    (512,)              |            |                |
|  transformer_stack      



Patch Size 8x8, Total Patches=16; #params=19024386; #flops=324738560; #activations=544306
| module                                 | #parameters or shape   | #flops     | #activations   |
|:---------------------------------------|:-----------------------|:-----------|:---------------|
| model                                  | 19.024M                | 0.325G     | 0.544M         |
|  patch_embedding                       |  0.108M                |  1.573M    |  8.192K        |
|   patch_embedding.pos_embedding        |   (1, 17, 512)         |            |                |
|   patch_embedding.cls_token            |   (1, 1, 512)          |            |                |
|   patch_embedding.proj                 |   98.816K              |   1.573M   |   8.192K       |
|    patch_embedding.proj.weight         |    (512, 3, 8, 8)      |            |                |
|    patch_embedding.proj.bias           |    (512,)              |            |                |
|  transformer_stack        



Patch Size 16x16, Total Patches=4; #params=19313154; #flops=96255488; #activations=156850
| module                                 | #parameters or shape   | #flops     | #activations   |
|:---------------------------------------|:-----------------------|:-----------|:---------------|
| model                                  | 19.313M                | 96.255M    | 0.157M         |
|  patch_embedding                       |  0.397M                |  1.573M    |  2.048K        |
|   patch_embedding.pos_embedding        |   (1, 5, 512)          |            |                |
|   patch_embedding.cls_token            |   (1, 1, 512)          |            |                |
|   patch_embedding.proj                 |   0.394M               |   1.573M   |   2.048K       |
|    patch_embedding.proj.weight         |    (512, 3, 16, 16)    |            |                |
|    patch_embedding.proj.bias           |    (512,)              |            |                |
|  transformer_stack        

### Complexity Report

The table below summarizes the results of the complexity analysis using the above measurements.

| Patch Size | Num Patches | Tokens (w/ CLS) | Attention FLOPs/Layer | Total Params |
|------------|-------------|-----------------|-----------------------|--------------|
| 4×4        | 64          | 65              | 72.4 M                | 18.98 M      |
| 8×8        | 16          | 17              | 18.1 M                | 19.02 M      |
| 16×16      | 4           | 5               | 5.2 M                 | 19.31 M      |
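The Attention FLOPs/Layer column can be approximated analytically. The sketch below assumes that one FLOP is one multiply-accumulate and that an attention block consists of Q/K/V and output projections of size D × D plus the two attention matmuls (QK^T and AV), ignoring softmax, scaling and biases. `attention_macs_per_layer` is a hypothetical helper, not part of `neural_stack`. Under these assumptions the estimates (72.5 M, 18.1 M and 5.3 M) land close to the measured values.

```python
def attention_macs_per_layer(num_tokens: int, embed_dim: int) -> dict:
    """Approximate multiply-accumulates for one multi-head self-attention block.

    Assumes Q/K/V and output projections of size D x D; softmax, scaling and
    biases are ignored. The number of heads does not change the total.
    """
    n, d = num_tokens, embed_dim
    projections = 4 * n * d * d        # Q, K, V and output projections: linear in N
    attention_matmuls = 2 * n * n * d  # Q @ K^T and softmax(...) @ V: quadratic in N
    return {"projections": projections, "attention": attention_matmuls,
            "total": projections + attention_matmuls}


for patch_size in [4, 8, 16]:
    num_tokens = (32 // patch_size) ** 2 + 1  # patches + CLS token
    macs = attention_macs_per_layer(num_tokens, embed_dim=512)
    share = macs["attention"] / macs["total"]
    print(f"{patch_size}x{patch_size}: tokens={num_tokens}, "
          f"total={macs['total'] / 1e6:.1f}M, quadratic share={share:.1%}")
```

The split explains the roughly 4x reduction per doubling of the patch size: with D = 512, the linear projection term 4ND² dominates, and the quadratic 2N²D term only takes over once N exceeds 2D (1,024 tokens here). At 65 tokens it accounts for only about 6% of the block's FLOPs.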

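The effect of the patch size on the attention cost is clearly visible: doubling the patch size quarters the number of patches, which reduces the attention FLOPs per layer by roughly a factor of 4 (slightly less between 8×8 and 16×16, because the CLS token is always present). At these short sequence lengths, most of the per-layer cost comes from the Q/K/V and output projections, which scale linearly with the number of tokens; the quadratic attention-matrix term only dominates once the sequence is longer than about twice the embedding dimension. The total number of parameters stays almost constant, because the transformer weights do not depend on the sequence length: the small increase comes from the patch projection, whose weights grow with the patch area, partially offset by the shorter positional embedding.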
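The parameter differences between the three models can be traced back to the patch embedding alone. The sketch below assumes the layout shown in the model summaries (a convolutional projection with bias, a learned CLS token, and a learned positional embedding with one row per token); `patch_embedding_params` is a hypothetical helper.

```python
def patch_embedding_params(img_size: int, patch_size: int, in_channels: int, embed_dim: int) -> int:
    """Parameters of a ViT patch embedding: projection + CLS token + positional embedding."""
    num_patches = (img_size // patch_size) ** 2
    proj = embed_dim * in_channels * patch_size ** 2 + embed_dim  # conv weight + bias
    cls_token = embed_dim
    pos_embedding = (num_patches + 1) * embed_dim                 # one row per token incl. CLS
    return proj + cls_token + pos_embedding


counts = {p: patch_embedding_params(32, p, 3, 512) for p in [4, 8, 16]}
print(counts)                                         # {4: 58880, 8: 108032, 16: 396800}
print(counts[8] - counts[4], counts[16] - counts[8])  # 49152 288768
```

These differences match the gaps between the reported totals exactly (19,024,386 − 18,975,234 = 49,152 and 19,313,154 − 19,024,386 = 288,768), so the transformer stack itself is identical across patch sizes.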

While larger patches make the model less computationally intensive, this comes at the cost of reduced expressivity and level of detail. Depending on the size of the objects of interest in the images, the patch size should be small enough to capture the desired level of detail. On the other hand, the number of patches grows with the image area, and attention compute grows quadratically with the number of patches, so increasing the input resolution at a fixed patch size quickly becomes expensive. As with most design choices, a balance must be found.

This trade-off is reflected in the original ViT paper, which uses 16x16 patches when training on ImageNet. Models trained on ImageNet commonly use an input resolution of 224x224, which yields 14 × 14 = 196 patches. CIFAR-10 images, on the other hand, are only 32x32: 7x smaller per side (49x fewer pixels) than ImageNet inputs, which allows for a smaller patch size. Matching ImageNet's 196 patches would require patches of roughly 2.3x2.3 pixels (16 / 7); for practical reasons, 4x4 patches are a reasonable compromise.
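The arithmetic behind this argument is short enough to check directly: keeping the ImageNet token count on 32×32 images means shrinking the patch side by the same factor as the image side.

```python
import math

imagenet_side, imagenet_patch = 224, 16
cifar_side = 32

imagenet_tokens = (imagenet_side // imagenet_patch) ** 2    # 14 x 14 = 196 patches
equivalent_patch = cifar_side / math.sqrt(imagenet_tokens)  # 32 / 14 ~ 2.29 pixels
print(imagenet_tokens, round(equivalent_patch, 2))          # 196 2.29

# Token counts for the integer patch sizes that divide 32
for p in [2, 4, 8, 16]:
    print(f"{p}x{p}: {(cifar_side // p) ** 2} patches")
```

A 2×2 patch would already exceed the ImageNet sequence length (256 patches), while 4×4 patches give 64.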

---

## Part 2: Ablation Study

### Experimental Setup

We'll conduct controlled experiments to understand which components are essential for ViT performance:

**Ablation Experiments**:
1. **Baseline** - Full ViT with positional embeddings and CLS token
2. **No Positional Embedding** - Remove spatial information encoding
3. **No CLS Token** - Use mean pooling instead of learned class token

All models are trained on **CIFAR-10** with identical hyperparameters for fair comparison.

#### Imports

In [3]:
import plotly
import torchvision
import wandb

from tqdm import tqdm
from torchvision import transforms
from torchvision.datasets import CIFAR10

from neural_stack.utils import get_project_root

In [4]:
# Get project root directory
ROOT_DIR = get_project_root()
DATASET_PATH = f"{ROOT_DIR}/data/cifar10"
WANDB_PATH = f"{ROOT_DIR}/data/.wandb"

wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mandreifurdui[0m ([33mandreifurdui-team[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

#### Dataset Preparation

Load **CIFAR-10** dataset with standard data augmentation:
- Random horizontal flips
- Random cropping with slight scale/ratio variation
- Normalization with dataset statistics

In [5]:
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomResizedCrop((32, 32), scale=(0.8, 1.0), ratio=(0.9, 1.1)),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = CIFAR10(
    root=DATASET_PATH,
    train=True,
    transform=train_transform,
    download=True
)

test_dataset = CIFAR10(
    root=DATASET_PATH,
    train=False,
    transform=test_transform,
    download=True
)

train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
    num_workers=4,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=128,
    shuffle=False,
    num_workers=4
)

#### Data Visualization

Sample images from CIFAR-10 to understand the task complexity.

In [6]:
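import plotly.graph_objects as go

# Visualize some examples
NUM_IMAGES_VIZ = 9
CIFAR_MEAN = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
CIFAR_STD = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

CIFAR_images = torch.stack([test_dataset[idx][0] for idx in range(NUM_IMAGES_VIZ)], dim=0)
# Undo the Normalize transform so pixel values are back in [0, 1] for display
CIFAR_images = (CIFAR_images * CIFAR_STD + CIFAR_MEAN).clamp(0, 1)
img_grid = torchvision.utils.make_grid(CIFAR_images, nrow=3, padding=2)
img_grid = img_grid.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) for plotting

# plot the image grid with plotly
fig = go.Figure()
fig.add_trace(go.Image(z=img_grid.numpy(), zmin=[0, 0, 0, 0], zmax=[1, 1, 1, 1], colormodel="rgb"))
fig.update_layout(width=600, height=600, title="CIFAR-10 Sample Images")
fig.show()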

#### Training Configuration

Hyperparameters for all experiments:
- **Architecture**: 4 layers, 8 heads, 256-dim embeddings, MLP ratio 2, dropout 0.1
- **Patch size**: 8×8 (16 patches from 32×32 images)
- **Optimizer**: AdamW with learning rate 3e-4 and batch size 128
- **Training**: 25 epochs with step learning rate decay (learning rate multiplied by 0.1 every 12 epochs)
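The step decay corresponds to `StepLR(optimizer, step_size=12)` with PyTorch's default `gamma=0.1`, stepped once per epoch. A quick way to inspect the resulting per-epoch learning rates, using a single dummy parameter so that no model is needed:

```python
import torch

dummy_param = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.AdamW([dummy_param], lr=3e-4)
demo_scheduler = torch.optim.lr_scheduler.StepLR(demo_optimizer, step_size=12)  # gamma defaults to 0.1

lrs = []
for epoch in range(25):
    lrs.append(demo_optimizer.param_groups[0]["lr"])  # learning rate used during this epoch
    demo_optimizer.step()   # stands in for a full epoch of updates
    demo_scheduler.step()   # called once per epoch, matching the training loop

print(lrs[0], lrs[12], lrs[24])  # ~3e-4, ~3e-5, ~3e-6
```

The learning rate is divided by 10 at the start of epochs 13 and 25 (1-indexed), which lines up with the jump in validation accuracy at epoch 13 visible in all of the training logs below.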

In [7]:
IMAGE_SIZE = (32, 32)
NUM_CHANNELS = 3

NUM_LAYERS = 4
NUM_HEADS = 8
EMBED_DIM = 256
PATCH_SIZE = 8
MLP_RATIO = 2

LR = 3e-4
NUM_EPOCHS = 25

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")

#### Training & Validation Functions

Define training loop with W&B logging and validation metrics.

In [8]:
def train_epoch(model: torch.nn.Module, dataloader, criterion, optimizer, lr_scheduler, epoch, device, wandb_logger: wandb.Run, print_freq = 50):
    model.train()

    for idx, (img, target) in enumerate(dataloader):
        img = img.to(device)
        target = target.to(device)

        pred = model(img)
        loss = criterion(pred, target)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        accuracy = (pred.argmax(dim=1) == target).float().mean().item()

        wandb_logger.log({
            "train/loss": loss.item(),
            "train/accuracy": accuracy,
            "train/lr": optimizer.param_groups[0]['lr'],
        }, step=epoch * len(dataloader) + idx)
            
    lr_scheduler.step()

def validate(model: torch.nn.Module, dataloader, criterion, device):
    model.eval()

    loss_avg = 0.0
    accuracy_avg = 0.0

    for img, target in dataloader:
        with torch.no_grad():
            img = img.to(device)
            target = target.to(device)

            pred = model(img)
            loss = criterion(pred, target)

            accuracy = (pred.argmax(dim=1) == target).float().mean().item()

            loss_avg += loss.item()
            accuracy_avg += accuracy
    
    loss_avg /= len(dataloader)
    accuracy_avg /= len(dataloader)

    return loss_avg, accuracy_avg

In [11]:
def train(model, train_loader, test_loader, criterion, optimizer, lr_scheduler, device, wandb_logger, num_epochs):
    best_val_accuracy = 0.0

    for epoch in tqdm(range(num_epochs), total=num_epochs):
        train_epoch(
            model=model,
            dataloader=train_loader,
            criterion=criterion,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            epoch=epoch,
            device=device,
            wandb_logger=wandb_logger
        )

        val_loss, val_accuracy = validate(
            model=model,
            dataloader=test_loader,
            criterion=criterion,
            device=device
        )

        print(f"Epoch {epoch+1}/{num_epochs}: Val Loss={val_loss:.4f}, Val Acc={val_accuracy:.4f}")

        if val_accuracy > best_val_accuracy:
            best_val_accuracy = val_accuracy
            torch.save(model.state_dict(), os.path.join(wandb_logger.dir, "best.pth"))
            print(f"New best model saved with Val Acc={best_val_accuracy:.4f}")

        wandb_logger.log({
            "val/loss": val_loss,
            "val/accuracy": val_accuracy,
        }, step=(epoch + 1) * len(train_loader))
    
    wandb_logger.save(os.path.join(wandb_logger.dir, "best.pth"))

---

### Experiment 1: Baseline Model

Train the full ViT model with:
- ✓ Learned positional embeddings
- ✓ CLS token for classification

In [12]:
model = VisionTransformer(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=NUM_CHANNELS,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    mlp_ratio=MLP_RATIO,
    dropout=0.1,
    num_classes=10
).to(device)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12)
criterion = torch.nn.CrossEntropyLoss().to(device)

with wandb.init(project="neural-stack", dir=WANDB_PATH, name="baseline.v2",
     notes="Vision Transformer on CIFAR-10 baseline run",
     group="vision-transformer-cifar10-ablations",
     tags=["vision-transformer", "cifar10"],
     config={
          "num_layers": NUM_LAYERS,
          "num_heads": NUM_HEADS,
          "embed_dim": EMBED_DIM,
          "patch_size": PATCH_SIZE,
          "mlp_ratio": MLP_RATIO,
          "lr": LR,
          }) as wandb_logger:

     train(model, train_loader, test_loader, criterion, optimizer, lr_scheduler, device, wandb_logger, NUM_EPOCHS)
          

  4%|▍         | 1/25 [00:28<11:17, 28.25s/it]

Epoch 1/25: Val Loss=1.4131, Val Acc=0.4939
New best model saved with Val Acc=0.4939


  8%|▊         | 2/25 [00:58<11:18, 29.50s/it]

Epoch 2/25: Val Loss=1.3173, Val Acc=0.5134
New best model saved with Val Acc=0.5134


 12%|█▏        | 3/25 [01:29<11:03, 30.18s/it]

Epoch 3/25: Val Loss=1.2258, Val Acc=0.5580
New best model saved with Val Acc=0.5580


 16%|█▌        | 4/25 [02:01<10:44, 30.67s/it]

Epoch 4/25: Val Loss=1.1714, Val Acc=0.5723
New best model saved with Val Acc=0.5723


 20%|██        | 5/25 [02:32<10:18, 30.93s/it]

Epoch 5/25: Val Loss=1.1245, Val Acc=0.5934
New best model saved with Val Acc=0.5934


 24%|██▍       | 6/25 [03:04<09:56, 31.41s/it]

Epoch 6/25: Val Loss=1.1116, Val Acc=0.6015
New best model saved with Val Acc=0.6015


 28%|██▊       | 7/25 [03:36<09:26, 31.49s/it]

Epoch 7/25: Val Loss=1.0811, Val Acc=0.6163
New best model saved with Val Acc=0.6163


 32%|███▏      | 8/25 [04:08<08:55, 31.53s/it]

Epoch 8/25: Val Loss=1.0709, Val Acc=0.6189
New best model saved with Val Acc=0.6189


 36%|███▌      | 9/25 [04:36<08:09, 30.58s/it]

Epoch 9/25: Val Loss=1.0427, Val Acc=0.6318
New best model saved with Val Acc=0.6318


 40%|████      | 10/25 [05:08<07:46, 31.08s/it]

Epoch 10/25: Val Loss=1.0179, Val Acc=0.6396
New best model saved with Val Acc=0.6396


 44%|████▍     | 11/25 [05:41<07:20, 31.48s/it]

Epoch 11/25: Val Loss=1.0151, Val Acc=0.6404
New best model saved with Val Acc=0.6404


 48%|████▊     | 12/25 [06:13<06:51, 31.64s/it]

Epoch 12/25: Val Loss=0.9997, Val Acc=0.6493
New best model saved with Val Acc=0.6493


 52%|█████▏    | 13/25 [06:44<06:19, 31.66s/it]

Epoch 13/25: Val Loss=0.9269, Val Acc=0.6734
New best model saved with Val Acc=0.6734


 56%|█████▌    | 14/25 [07:15<05:45, 31.42s/it]

Epoch 14/25: Val Loss=0.9190, Val Acc=0.6725


 60%|██████    | 15/25 [07:46<05:12, 31.28s/it]

Epoch 15/25: Val Loss=0.9171, Val Acc=0.6745
New best model saved with Val Acc=0.6745


 64%|██████▍   | 16/25 [08:18<04:42, 31.38s/it]

Epoch 16/25: Val Loss=0.9157, Val Acc=0.6803
New best model saved with Val Acc=0.6803


 68%|██████▊   | 17/25 [08:49<04:09, 31.20s/it]

Epoch 17/25: Val Loss=0.9101, Val Acc=0.6801


 72%|███████▏  | 18/25 [09:20<03:38, 31.25s/it]

Epoch 18/25: Val Loss=0.9070, Val Acc=0.6844
New best model saved with Val Acc=0.6844


 76%|███████▌  | 19/25 [09:51<03:07, 31.23s/it]

Epoch 19/25: Val Loss=0.9061, Val Acc=0.6840


 80%|████████  | 20/25 [10:23<02:36, 31.32s/it]

Epoch 20/25: Val Loss=0.8997, Val Acc=0.6841


 84%|████████▍ | 21/25 [10:54<02:04, 31.22s/it]

Epoch 21/25: Val Loss=0.9059, Val Acc=0.6833


 88%|████████▊ | 22/25 [11:29<01:37, 32.53s/it]

Epoch 22/25: Val Loss=0.8983, Val Acc=0.6852
New best model saved with Val Acc=0.6852


 92%|█████████▏| 23/25 [12:05<01:07, 33.61s/it]

Epoch 23/25: Val Loss=0.8989, Val Acc=0.6912
New best model saved with Val Acc=0.6912


 96%|█████████▌| 24/25 [12:41<00:34, 34.16s/it]

Epoch 24/25: Val Loss=0.8994, Val Acc=0.6945
New best model saved with Val Acc=0.6945


100%|██████████| 25/25 [13:16<00:00, 31.85s/it]


Epoch 25/25: Val Loss=0.8970, Val Acc=0.6917


Run history:

| Metric | History |
|---|---|
| train/accuracy | ▁▁▁▃▂▃▅▅▅▆▅▅▅▅▅▅▆▄▇▇▇▆▇▇▆▅▆▇█▆▇▇▇▇█▅▆▇▇▇ |
| train/loss | █▆▅▅▆▄▅▄▅▄▄▄▄▄▄▅▃▄▄▄▃▃▃▃▃▂▃▂▂▂▂▂▂▂▂▃▂▁▂▁ |
| train/lr | ██████████████████████▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁ |
| val/accuracy | ▁▂▃▄▄▅▅▅▆▆▆▆▇▇▇█▇████████ |
| val/loss | █▇▅▅▄▄▃▃▃▃▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁ |

Run summary:

| Metric | Value |
|---|---|
| train/accuracy | 0.8125 |
| train/loss | 0.54036 |
| train/lr | 0.0 |
| val/accuracy | 0.69165 |
| val/loss | 0.89699 |


---

### Experiment 2: No Positional Embedding

**Hypothesis**: Without positional embeddings, the model loses spatial awareness and treats patches as an unordered set.

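This tests how much of ViT's performance depends on knowing where each patch is located, as opposed to what it contains: without positional information, nothing in the architecture distinguishes one patch position from another.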
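The reason is that self-attention and the MLP treat every token identically, so without positional information the encoder is permutation-equivariant and a CLS-token readout is invariant to the order of the patches. The sketch below checks this numerically, using PyTorch's generic `nn.TransformerEncoder` as a stand-in for the `neural_stack` model (an assumption; dropout is disabled so the comparison is deterministic).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embed_dim, num_patches = 64, 16

layer = nn.TransformerEncoderLayer(d_model=embed_dim, nhead=4, dim_feedforward=128,
                                   dropout=0.0, batch_first=True)
encoder = nn.TransformerEncoder(layer, num_layers=2).eval()
cls_token = torch.randn(1, 1, embed_dim)

patch_tokens = torch.randn(1, num_patches, embed_dim)    # patch embeddings, no positions added
shuffled = patch_tokens[:, torch.randperm(num_patches)]  # same patches, scrambled layout

with torch.no_grad():
    out = encoder(torch.cat([cls_token, patch_tokens], dim=1))[:, 0]
    out_shuffled = encoder(torch.cat([cls_token, shuffled], dim=1))[:, 0]

    # Adding a (random stand-in for a learned) positional embedding breaks the symmetry
    pos = torch.randn(1, num_patches + 1, embed_dim)
    out_pos = encoder(torch.cat([cls_token, patch_tokens], dim=1) + pos)[:, 0]
    out_pos_shuffled = encoder(torch.cat([cls_token, shuffled], dim=1) + pos)[:, 0]

print(torch.allclose(out, out_shuffled, atol=1e-5))          # True: patch order is ignored
print(torch.allclose(out_pos, out_pos_shuffled, atol=1e-5))  # expected False once positions are encoded
```

Without a positional embedding, the model therefore sees each image as an unordered bag of patches; only the patch contents can be used for classification.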

In [9]:
model = VisionTransformer(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=NUM_CHANNELS,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    mlp_ratio=MLP_RATIO,
    dropout=0.1,
    num_classes=10,
    positional_embedding='none'
).to(device)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12)
criterion = torch.nn.CrossEntropyLoss().to(device)

with wandb.init(project="neural-stack", dir=WANDB_PATH, name="no-pos-emb",
     notes="Vision Transformer on CIFAR-10 w/o positional embedding run",
     group="vision-transformer-cifar10-ablations",
     tags=["vision-transformer", "cifar10"],
     config={
          "num_layers": NUM_LAYERS,
          "num_heads": NUM_HEADS,
          "embed_dim": EMBED_DIM,
          "patch_size": PATCH_SIZE,
          "mlp_ratio": MLP_RATIO,
          "lr": LR,
          }) as wandb_logger:

     train(model, train_loader, test_loader, criterion, optimizer, lr_scheduler, device, wandb_logger, NUM_EPOCHS)
          

  4%|▍         | 1/25 [00:24<09:53, 24.75s/it]

Epoch 1/25: Val Loss=1.5560, Val Acc=0.4358


  8%|▊         | 2/25 [00:49<09:27, 24.68s/it]

Epoch 2/25: Val Loss=1.4534, Val Acc=0.4717


 12%|█▏        | 3/25 [01:14<09:11, 25.06s/it]

Epoch 3/25: Val Loss=1.4510, Val Acc=0.4831


 16%|█▌        | 4/25 [01:40<08:50, 25.27s/it]

Epoch 4/25: Val Loss=1.2977, Val Acc=0.5312


 20%|██        | 5/25 [02:05<08:25, 25.26s/it]

Epoch 5/25: Val Loss=1.2691, Val Acc=0.5453


 24%|██▍       | 6/25 [02:31<08:03, 25.47s/it]

Epoch 6/25: Val Loss=1.2421, Val Acc=0.5559


 28%|██▊       | 7/25 [02:57<07:38, 25.49s/it]

Epoch 7/25: Val Loss=1.2038, Val Acc=0.5692


 32%|███▏      | 8/25 [03:22<07:15, 25.59s/it]

Epoch 8/25: Val Loss=1.1960, Val Acc=0.5732


 36%|███▌      | 9/25 [03:48<06:50, 25.66s/it]

Epoch 9/25: Val Loss=1.1668, Val Acc=0.5831


 40%|████      | 10/25 [04:14<06:26, 25.74s/it]

Epoch 10/25: Val Loss=1.1418, Val Acc=0.5901


 44%|████▍     | 11/25 [04:40<05:59, 25.68s/it]

Epoch 11/25: Val Loss=1.1126, Val Acc=0.5969


 48%|████▊     | 12/25 [05:05<05:34, 25.70s/it]

Epoch 12/25: Val Loss=1.1126, Val Acc=0.5990


 52%|█████▏    | 13/25 [05:31<05:06, 25.53s/it]

Epoch 13/25: Val Loss=1.0434, Val Acc=0.6239


 56%|█████▌    | 14/25 [05:55<04:38, 25.31s/it]

Epoch 14/25: Val Loss=1.0319, Val Acc=0.6322


 60%|██████    | 15/25 [06:20<04:12, 25.21s/it]

Epoch 15/25: Val Loss=1.0329, Val Acc=0.6331


 64%|██████▍   | 16/25 [06:46<03:47, 25.27s/it]

Epoch 16/25: Val Loss=1.0273, Val Acc=0.6362


 68%|██████▊   | 17/25 [07:12<03:23, 25.43s/it]

Epoch 17/25: Val Loss=1.0290, Val Acc=0.6368


 72%|███████▏  | 18/25 [07:37<02:56, 25.27s/it]

Epoch 18/25: Val Loss=1.0247, Val Acc=0.6411


 76%|███████▌  | 19/25 [08:03<02:32, 25.50s/it]

Epoch 19/25: Val Loss=1.0169, Val Acc=0.6433


 80%|████████  | 20/25 [08:29<02:08, 25.70s/it]

Epoch 20/25: Val Loss=1.0124, Val Acc=0.6459


 84%|████████▍ | 21/25 [08:54<01:42, 25.58s/it]

Epoch 21/25: Val Loss=1.0239, Val Acc=0.6410


 88%|████████▊ | 22/25 [09:20<01:16, 25.57s/it]

Epoch 22/25: Val Loss=1.0220, Val Acc=0.6448


 92%|█████████▏| 23/25 [09:45<00:51, 25.57s/it]

Epoch 23/25: Val Loss=1.0166, Val Acc=0.6450


 96%|█████████▌| 24/25 [10:11<00:25, 25.64s/it]

Epoch 24/25: Val Loss=1.0211, Val Acc=0.6473


100%|██████████| 25/25 [10:37<00:00, 25.49s/it]

Epoch 25/25: Val Loss=1.0067, Val Acc=0.6517





Run history:

| Metric | History |
|---|---|
| train/accuracy | ▁▁▃▃▃▃▃▃▄▃▄▄▅▃▄▄▄▃▅▅▃▅▅▅▆▆▅▆▅▆▆▇▇▇▆▇▇▇█▇ |
| train/loss | ▇▆█▆▆█▆▆▅▆▄▅▄▄▄▅▂▄▃▃▃▃▂▃▃▁▂▃▂▂▄▄▂▂▁▂▁▂▂▁ |
| train/lr | █████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ |
| val/accuracy | ▁▂▃▄▅▅▅▅▆▆▆▆▇▇▇▇█████████ |
| val/loss | █▇▇▅▄▄▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ |

Run summary:

| Metric | Value |
|---|---|
| train/accuracy | 0.7125 |
| train/loss | 0.80988 |
| train/lr | 0.0 |
| val/accuracy | 0.6517 |
| val/loss | 1.0067 |


---

### Experiment 3: No CLS Token (Mean Pooling)

**Hypothesis**: The CLS token is a learned aggregation mechanism. Mean pooling provides a simpler alternative.

This tests whether a learnable aggregation token provides advantages over simple averaging.
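The two readout strategies differ only in the final pooling step before the classification head. A minimal sketch, assuming the encoder output is a `(batch, tokens, dim)` tensor and, in the CLS case, that the CLS token sits at index 0; the helper `pool_tokens` is hypothetical and not part of `neural_stack`.

```python
import torch


def pool_tokens(encoded: torch.Tensor, use_cls_token: bool) -> torch.Tensor:
    """Reduce encoder output of shape (B, N, D) to one vector per image, shape (B, D)."""
    if use_cls_token:
        return encoded[:, 0]    # learned aggregation: the CLS token has attended to all patches
    return encoded.mean(dim=1)  # parameter-free aggregation: average over all patch tokens


encoded_with_cls = torch.randn(4, 17, 256)  # CLS + 16 patches (8x8 patches on 32x32 input)
encoded_no_cls = torch.randn(4, 16, 256)    # 16 patch tokens only
head = torch.nn.Linear(256, 10)             # CIFAR-10 classification head

logits_cls = head(pool_tokens(encoded_with_cls, use_cls_token=True))
logits_mean = head(pool_tokens(encoded_no_cls, use_cls_token=False))
print(logits_cls.shape, logits_mean.shape)  # torch.Size([4, 10]) torch.Size([4, 10])
```

Mean pooling also saves one token (and its learned embedding), so the sequence is one element shorter.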

In [10]:
model = VisionTransformer(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=NUM_CHANNELS,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    mlp_ratio=MLP_RATIO,
    dropout=0.1,
    num_classes=10,
    use_cls_token=False
).to(device)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12)
criterion = torch.nn.CrossEntropyLoss().to(device)

with wandb.init(project="neural-stack", dir=WANDB_PATH, name="no-cls-tok",
     notes="Vision Transformer on CIFAR-10 w/o CLS token (mean pooling) run",
     group="vision-transformer-cifar10-ablations",
     tags=["vision-transformer", "cifar10"],
     config={
          "num_layers": NUM_LAYERS,
          "num_heads": NUM_HEADS,
          "embed_dim": EMBED_DIM,
          "patch_size": PATCH_SIZE,
          "mlp_ratio": MLP_RATIO,
          "lr": LR,
          }) as wandb_logger:

     train(model, train_loader, test_loader, criterion, optimizer, lr_scheduler, device, wandb_logger, NUM_EPOCHS)
          

  4%|▍         | 1/25 [00:26<10:33, 26.39s/it]

Epoch 1/25: Val Loss=1.4229, Val Acc=0.4864


  8%|▊         | 2/25 [00:51<09:55, 25.90s/it]

Epoch 2/25: Val Loss=1.2819, Val Acc=0.5403


 12%|█▏        | 3/25 [01:18<09:32, 26.00s/it]

Epoch 3/25: Val Loss=1.2216, Val Acc=0.5586


 16%|█▌        | 4/25 [01:43<08:59, 25.67s/it]

Epoch 4/25: Val Loss=1.1592, Val Acc=0.5845


 20%|██        | 5/25 [02:09<08:36, 25.82s/it]

Epoch 5/25: Val Loss=1.1261, Val Acc=0.5938


 24%|██▍       | 6/25 [02:34<08:03, 25.45s/it]

Epoch 6/25: Val Loss=1.1045, Val Acc=0.6105


 28%|██▊       | 7/25 [02:59<07:38, 25.47s/it]

Epoch 7/25: Val Loss=1.0852, Val Acc=0.6102


 32%|███▏      | 8/25 [03:25<07:14, 25.58s/it]

Epoch 8/25: Val Loss=1.0587, Val Acc=0.6205


 36%|███▌      | 9/25 [03:51<06:52, 25.75s/it]

Epoch 9/25: Val Loss=1.0268, Val Acc=0.6362


 40%|████      | 10/25 [04:17<06:27, 25.85s/it]

Epoch 10/25: Val Loss=1.0452, Val Acc=0.6283


 44%|████▍     | 11/25 [04:41<05:54, 25.32s/it]

Epoch 11/25: Val Loss=0.9888, Val Acc=0.6479


 48%|████▊     | 12/25 [05:06<05:26, 25.15s/it]

Epoch 12/25: Val Loss=0.9805, Val Acc=0.6572


 52%|█████▏    | 13/25 [05:32<05:03, 25.29s/it]

Epoch 13/25: Val Loss=0.9178, Val Acc=0.6788


 56%|█████▌    | 14/25 [05:57<04:38, 25.29s/it]

Epoch 14/25: Val Loss=0.9130, Val Acc=0.6846


 60%|██████    | 15/25 [06:22<04:13, 25.31s/it]

Epoch 15/25: Val Loss=0.9038, Val Acc=0.6853


 64%|██████▍   | 16/25 [06:47<03:45, 25.10s/it]

Epoch 16/25: Val Loss=0.8977, Val Acc=0.6906


 68%|██████▊   | 17/25 [07:11<03:19, 24.91s/it]

Epoch 17/25: Val Loss=0.8961, Val Acc=0.6899


 72%|███████▏  | 18/25 [07:36<02:53, 24.85s/it]

Epoch 18/25: Val Loss=0.8945, Val Acc=0.6906


 76%|███████▌  | 19/25 [08:01<02:29, 24.94s/it]

Epoch 19/25: Val Loss=0.8874, Val Acc=0.6948


 80%|████████  | 20/25 [08:25<02:03, 24.76s/it]

Epoch 20/25: Val Loss=0.8892, Val Acc=0.6947


 84%|████████▍ | 21/25 [08:50<01:39, 24.79s/it]

Epoch 21/25: Val Loss=0.8956, Val Acc=0.6925


 88%|████████▊ | 22/25 [09:15<01:14, 24.82s/it]

Epoch 22/25: Val Loss=0.8861, Val Acc=0.6971


 92%|█████████▏| 23/25 [09:40<00:49, 24.93s/it]

Epoch 23/25: Val Loss=0.8955, Val Acc=0.6969


 96%|█████████▌| 24/25 [10:06<00:25, 25.03s/it]

Epoch 24/25: Val Loss=0.8903, Val Acc=0.6978


100%|██████████| 25/25 [10:31<00:00, 25.25s/it]

Epoch 25/25: Val Loss=0.8863, Val Acc=0.6984





Run history:

| Metric | History |
|---|---|
| train/accuracy | ▁▃▄▅▄▆▅▅▅▅▆▅▆▆▆▆▇▆▅▇▆▇▆▇█▇▇███▇▇█▇▇█▇▇█▇ |
| train/loss | ▇█▇▆▆▄▄▅▅▄▄▄▄▃▄▂▃▄▂▁▂▂▁▂▃▂▂▂▂▂▃▂▂▁▁▁▂▂▁▁ |
| train/lr | ███████████████████▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁ |
| val/accuracy | ▁▃▃▄▅▅▅▅▆▆▆▇▇████████████ |
| val/loss | █▆▅▅▄▄▄▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ |

Run summary:

| Metric | Value |
|---|---|
| train/accuracy | 0.8625 |
| train/loss | 0.61381 |
| train/lr | 0.0 |
| val/accuracy | 0.69838 |
| val/loss | 0.88635 |


### Report & Analysis

The ablation results confirm the importance of positional embeddings in the ViT model: the best test accuracy drops from 69.45% to 65.17% when the positional embedding is removed.

Using mean pooling over all tokens instead of the CLS token, however, does not noticeably impact model performance (69.84% vs. 69.45%), resulting in similar training and test metrics. Since each variant was trained only once, a difference of this size should not be over-interpreted.

See the GitHub Issue for loss and metrics curves: https://github.com/andreifurdui/neural-stack/issues/6

Ablation Results Summary Table (best per-epoch accuracy on the CIFAR-10 test split, which is also used for checkpoint selection):

| Model Variant                 | Best Test Accuracy |
|-------------------------------|--------------------|
| Baseline (with CLS & Pos Emb) | 69.45%             |
| No Positional Embedding       | 65.17%             |
| No CLS Token (Mean Pooling)   | 69.84%             |
| Random Guessing               | 10.00%             |

#### The Importance of Positional Embedding

Removing the positional embedding essentially removes the model's compass: the patches are treated as a bag of tokens, and since self-attention and the MLP process every token identically, no information about the spatial arrangement of the patches ever enters the model. In effect, the ViT can no longer exploit the spatial relationships between patches that a CNN gets from its architecture. This degrades performance noticeably (from 69.45% to 65.17%), but not completely: the content of each patch still carries useful information, even though its position does not. A larger patch size would most likely suffer less from removing the positional embedding, since larger patches individually capture more of the information needed to recognize the image content.

#### 4x4 vs. 16x16

The choice of patch size is a compromise between several factors:
- Input size
- Available compute
- Level of detail

A larger input size pushes the patch size up, to keep the quadratic growth of attention compute in check.
The required level of detail pushes the patch size down, so that the model can capture more precise features and understand the image at a finer granularity.
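A small calculation makes the first point concrete: at a fixed patch size, the token count grows with the image area, and each attention matrix grows with the square of the token count. Keeping the sequence length constant therefore requires scaling the patch side with the image side. The helper below is a hypothetical illustration and ignores the CLS token.

```python
def tokens_and_attention_entries(image_side: int, patch_side: int) -> tuple[int, int]:
    """Number of patch tokens and entries in one (per-head) attention matrix, ignoring CLS."""
    tokens = (image_side // patch_side) ** 2
    return tokens, tokens ** 2


for image_side, patch_side in [(224, 16), (448, 16), (448, 32)]:
    tokens, entries = tokens_and_attention_entries(image_side, patch_side)
    print(f"{image_side}px, {patch_side}x{patch_side} patches: "
          f"{tokens} tokens, {entries} attention entries")
```

Doubling the image side at a fixed 16×16 patch size quadruples the tokens (196 to 784) and makes each attention matrix 16x larger, while doubling the patch size as well restores the original 196 tokens.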

In general, one might assume that the relevant level of detail scales with the image size, i.e. that larger images contain proportionally larger objects, so the patch size can grow with the resolution. This does not always hold: medical or satellite images, for example, have very high resolutions, yet the structures of interest are often small, so fine-grained patches are still needed. Domain knowledge serves as the best guide for this choice.

A hard cap is always set by the available compute, which puts an upper bound on the sequence length the model can afford and thus guides the choice of patch size.

#### CNN vs. ViT Spatial Understanding

A CNN's spatial modelling is built into its architecture: local, overlapping convolutions stacked layer after layer build up a hierarchical understanding of the image. A ViT's spatial modelling, on the other hand, must be learned during training. Apart from cutting the image into patches, there is no architectural 'help' the ViT can rely on (unless a hand-crafted, fixed positional embedding is used), so the transformer must learn spatial relationships from the data distribution itself.
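One way to quantify the difference: for a stack of stride-1 k×k convolutions without pooling, the receptive field grows by k − 1 pixels per layer, so covering a whole 32×32 CIFAR-10 image takes many layers, whereas a single self-attention layer already connects every patch to every other patch. A sketch under these simplifying assumptions:

```python
def conv_receptive_field(num_layers: int, kernel_size: int = 3) -> int:
    """Receptive field (pixels per side) of stacked stride-1 convolutions without pooling."""
    return 1 + num_layers * (kernel_size - 1)


def layers_to_cover(image_side: int, kernel_size: int = 3) -> int:
    """Smallest number of stride-1 conv layers whose receptive field spans the image."""
    layers = 0
    while conv_receptive_field(layers, kernel_size) < image_side:
        layers += 1
    return layers


print(conv_receptive_field(4))  # 9: four 3x3 layers see a 9x9 neighbourhood
print(layers_to_cover(32))      # 16 layers of 3x3 convs to span 32 pixels
# A ViT with 8x8 patches: each of the 16 patch tokens attends to all others in its first layer.
```

Practical CNNs use strides and pooling to grow the receptive field much faster, so this only illustrates the locality bias, not any particular architecture.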

---

### Additional Experiments

---

### Experiment 4: 2D Positional Embedding

**Hypothesis**: Explicitly modelling the two-dimensional spatial structure through a 2D-aware positional embedding should help the model learn spatial relationships.

This tests whether introducing an inductive bias, by structuring the positional embedding around the 2D patch grid, aids the model in learning.
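For reference, the 2D variant described in the ViT paper's appendix learns separate row and column embeddings of size D/2 each and concatenates them for every patch. The sketch below follows that scheme; it is an assumption about what `positional_embedding='learned-2d'` does, not the `neural_stack` implementation, and the CLS token simply gets its own learned position.

```python
import torch
import torch.nn as nn


class Learned2DPositionalEmbedding(nn.Module):
    """Hypothetical 2D positional embedding: concat(row_embed[i], col_embed[j]) per patch."""

    def __init__(self, grid_size: int, embed_dim: int):
        super().__init__()
        assert embed_dim % 2 == 0, "embed_dim must be even to split between rows and columns"
        self.grid_size = grid_size
        self.row_embed = nn.Parameter(torch.randn(grid_size, embed_dim // 2) * 0.02)
        self.col_embed = nn.Parameter(torch.randn(grid_size, embed_dim // 2) * 0.02)
        self.cls_pos = nn.Parameter(torch.zeros(1, embed_dim))  # position of the CLS token

    def forward(self) -> torch.Tensor:
        g = self.grid_size
        rows = self.row_embed[:, None, :].expand(g, g, -1)  # (g, g, D/2), varies with row i
        cols = self.col_embed[None, :, :].expand(g, g, -1)  # (g, g, D/2), varies with column j
        grid = torch.cat([rows, cols], dim=-1).reshape(g * g, -1)  # row-major patch order
        return torch.cat([self.cls_pos, grid], dim=0).unsqueeze(0)  # (1, 1 + g*g, D)


pos = Learned2DPositionalEmbedding(grid_size=4, embed_dim=256)()  # 8x8 patches on 32x32 input
print(pos.shape)  # torch.Size([1, 17, 256])
```

Patches in the same row share the first half of their embedding and patches in the same column share the second half. For this configuration the 2D variant also has fewer parameters than a 1D table (1,280 vs. 4,352).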

In [10]:
model = VisionTransformer(
    img_size=IMAGE_SIZE,
    patch_size=PATCH_SIZE,
    in_channels=NUM_CHANNELS,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    num_layers=NUM_LAYERS,
    mlp_ratio=MLP_RATIO,
    dropout=0.1,
    num_classes=10,
    positional_embedding='learned-2d'
).to(device)

optimizer = torch.optim.AdamW(
    params=model.parameters(),
    lr=LR
)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=12)
criterion = torch.nn.CrossEntropyLoss().to(device)

with wandb.init(project="neural-stack", dir=WANDB_PATH, name="2d-pos-emb",
     notes="Vision Transformer on CIFAR-10 with 2D positional embedding run",
     group="vision-transformer-cifar10-ablations",
     tags=["vision-transformer", "cifar10"],
     config={
          "num_layers": NUM_LAYERS,
          "num_heads": NUM_HEADS,
          "embed_dim": EMBED_DIM,
          "patch_size": PATCH_SIZE,
          "mlp_ratio": MLP_RATIO,
          "lr": LR,
          }) as wandb_logger:

     train(model, train_loader, test_loader, criterion, optimizer, lr_scheduler, device, wandb_logger, NUM_EPOCHS)
          

  4%|▍         | 1/25 [00:25<10:13, 25.56s/it]

Epoch 1/25: Val Loss=1.3991, Val Acc=0.4986


  8%|▊         | 2/25 [00:50<09:45, 25.44s/it]

Epoch 2/25: Val Loss=1.2725, Val Acc=0.5442


 12%|█▏        | 3/25 [01:16<09:25, 25.71s/it]

Epoch 3/25: Val Loss=1.2282, Val Acc=0.5557


 16%|█▌        | 4/25 [01:42<08:56, 25.53s/it]

Epoch 4/25: Val Loss=1.1919, Val Acc=0.5679


 20%|██        | 5/25 [02:07<08:27, 25.36s/it]

Epoch 5/25: Val Loss=1.1251, Val Acc=0.5961


 24%|██▍       | 6/25 [02:32<08:03, 25.46s/it]

Epoch 6/25: Val Loss=1.1034, Val Acc=0.6067


 28%|██▊       | 7/25 [02:58<07:38, 25.45s/it]

Epoch 7/25: Val Loss=1.0650, Val Acc=0.6134


 32%|███▏      | 8/25 [03:23<07:10, 25.31s/it]

Epoch 8/25: Val Loss=1.0658, Val Acc=0.6163


 36%|███▌      | 9/25 [03:48<06:44, 25.26s/it]

Epoch 9/25: Val Loss=1.0120, Val Acc=0.6423


 40%|████      | 10/25 [04:12<06:14, 24.99s/it]

Epoch 10/25: Val Loss=0.9984, Val Acc=0.6422


 44%|████▍     | 11/25 [04:37<05:49, 24.93s/it]

Epoch 11/25: Val Loss=0.9960, Val Acc=0.6450


 48%|████▊     | 12/25 [05:02<05:25, 25.03s/it]

Epoch 12/25: Val Loss=0.9943, Val Acc=0.6401


 52%|█████▏    | 13/25 [05:27<04:58, 24.91s/it]

Epoch 13/25: Val Loss=0.9152, Val Acc=0.6755


 56%|█████▌    | 14/25 [05:51<04:32, 24.74s/it]

Epoch 14/25: Val Loss=0.9050, Val Acc=0.6796


 60%|██████    | 15/25 [06:16<04:08, 24.84s/it]

Epoch 15/25: Val Loss=0.8978, Val Acc=0.6838


 64%|██████▍   | 16/25 [06:42<03:44, 24.97s/it]

Epoch 16/25: Val Loss=0.8996, Val Acc=0.6842


 68%|██████▊   | 17/25 [07:07<03:19, 24.91s/it]

Epoch 17/25: Val Loss=0.8969, Val Acc=0.6853


 72%|███████▏  | 18/25 [07:31<02:53, 24.84s/it]

Epoch 18/25: Val Loss=0.8932, Val Acc=0.6861


 76%|███████▌  | 19/25 [07:56<02:28, 24.73s/it]

Epoch 19/25: Val Loss=0.8903, Val Acc=0.6887


 80%|████████  | 20/25 [08:21<02:04, 24.96s/it]

Epoch 20/25: Val Loss=0.8893, Val Acc=0.6898


 84%|████████▍ | 21/25 [08:47<01:41, 25.28s/it]

Epoch 21/25: Val Loss=0.8938, Val Acc=0.6897


 88%|████████▊ | 22/25 [09:12<01:15, 25.21s/it]

Epoch 22/25: Val Loss=0.8864, Val Acc=0.6931


 92%|█████████▏| 23/25 [09:37<00:50, 25.22s/it]

Epoch 23/25: Val Loss=0.8882, Val Acc=0.6927


 96%|█████████▌| 24/25 [10:02<00:25, 25.13s/it]

Epoch 24/25: Val Loss=0.8926, Val Acc=0.6943


100%|██████████| 25/25 [10:28<00:00, 25.14s/it]

Epoch 25/25: Val Loss=0.8861, Val Acc=0.6945





Run history:

| Metric | History |
|---|---|
| train/accuracy | ▁▃▄▂▂▄▃▄▄▅▄▃▅▅▅▆▆▃▄▄▇▇▆▆▆▆▆▇▇▇▇▇▆██▇▆▇██ |
| train/loss | █▅▆▅▅▅▄▆▅▄▃▃▅▃▃▃▂▄▃▂▂▁▂▁▂▂▂▂▁▂▂▂▂▁▂▂▁▂▁▂ |
| train/lr | ████████████████████▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁ |
| val/accuracy | ▁▃▃▃▄▅▅▅▆▆▆▆▇▇███████████ |
| val/loss | █▆▆▅▄▄▃▃▃▃▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁ |

Run summary:

| Metric | Value |
|---|---|
| train/accuracy | 0.7125 |
| train/loss | 0.67633 |
| train/lr | 0.0 |
| val/accuracy | 0.69452 |
| val/loss | 0.88611 |


#### 2D Positional Embedding Results

Training the ViT with a 2D positional embedding gives practically identical results to the 1D baseline under our configuration (69.45% best accuracy in both runs), in line with the findings of the original paper. This supports their conclusion that explicitly encoding the 2D structure in the positional embedding brings no measurable improvement: at patch-level resolution (here a 4×4 grid of patches), the ViT can learn 2D spatial relationships from a 1D positional embedding just as well.

Ablation Results Summary Table (best per-epoch accuracy on the CIFAR-10 test split, which is also used for checkpoint selection):

| Model Variant                 | Best Test Accuracy |
|-------------------------------|--------------------|
| Baseline (with CLS & Pos Emb) | 69.45%             |
| 2D Positional Embedding       | 69.45%             |
| No Positional Embedding       | 65.17%             |
| No CLS Token (Mean Pooling)   | 69.84%             |
| Random Guessing               | 10.00%             |