In [14]:
import torch
import torch.nn as nn
import torch.nn.functional as F 
from torchvision import datasets, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
%matplotlib inline

In [38]:
# -----------------------------
# 1. Load MNIST dataset
# -----------------------------

DATA_DIR = "./data"

# Per-channel mean/std commonly quoted for the MNIST training set.
# TODO(review): confirm by recomputing from the raw training images.
MNIST_MEAN, MNIST_STD = (0.1307,), (0.3081,)

# ToTensor turns each PIL image into a float tensor of shape (1, 28, 28);
# Normalize then standardises it with the statistics above.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(MNIST_MEAN, MNIST_STD),
])

# download=True fetches the files on a fresh machine and skips the download when
# they already exist under DATA_DIR, so Restart & Run All works everywhere.
mnist_train = datasets.MNIST(DATA_DIR, train=True, download=True, transform=transform)

mnist_test = datasets.MNIST(DATA_DIR, train=False, download=True, transform=transform)

In [47]:
# Peek at one sample to confirm the transform output and the label type.
img, label = mnist_train[0]
print(img.shape, type(label), sep="\n")  # e.g. torch.Size([1, 28, 28]) then <class 'int'>

torch.Size([1, 28, 28])
<class 'int'>


In [None]:
def patch(img, patch_size=7):
    """Split one image into non-overlapping square patches.

    Patches are ordered row-major over the patch grid (left-to-right, then
    top-to-bottom). For multi-channel input, all patches of channel 0 come
    first, then channel 1, and so on.

    Args:
        img: image tensor of shape (C, H, W). A 2-D (H, W) tensor is treated
            as a single-channel image.
        patch_size: side length of each square patch (must be positive and
            divide both H and W).

    Returns:
        Tensor of shape (C * (H // patch_size) * (W // patch_size),
        patch_size, patch_size). For an MNIST image (1, 28, 28) with the
        default patch_size=7 this is (16, 7, 7).

    Raises:
        ValueError: if img is not 2-D/3-D, patch_size is not positive, or the
            image size is not divisible by patch_size. Without this check
            `unfold` would silently drop the border pixels.
    """
    if img.dim() == 2:
        img = img.unsqueeze(0)
    if img.dim() != 3:
        raise ValueError(f"expected img of shape (C, H, W) or (H, W), got {tuple(img.shape)}")
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    _, height, width = img.shape
    if height % patch_size or width % patch_size:
        raise ValueError(
            f"image size {height}x{width} is not divisible by patch_size={patch_size}"
        )
    # Unfold rows then columns: (C, H/p, W/p, p, p).
    patches = img.unfold(1, patch_size, patch_size).unfold(2, patch_size, patch_size)
    # Merge channel and grid axes into a single patch axis: (C * H/p * W/p, p, p).
    return patches.reshape(-1, patch_size, patch_size)



In [33]:
# Sanity check: patchify the first training image, shape (1, 28, 28).
img, label = mnist_train[0]
patches = patch(img, 7)
print(patches.size())  # expected: torch.Size([16, 7, 7])

torch.Size([16, 7, 7])


In [26]:
def patchify_dataset(dataset, patch_size=7):
    """Patchify every image in `dataset` and collect its labels in one pass.

    Args:
        dataset: iterable of (image, label) pairs, e.g. the torchvision MNIST
            dataset above, where each image is a (1, 28, 28) tensor.
        patch_size: side length of the square patches passed to `patch`.

    Returns:
        (patches, labels): `patches` stacks the per-image `patch` outputs into
        (num_images, num_patches, patch_size, patch_size). `labels` is a 1-D
        tensor built from the integer labels.
    """
    patch_list, label_list = [], []
    # A single pass means the dataset transform runs once per sample; reading the
    # labels in a separate loop would re-transform every image.
    for image, target in tqdm(dataset, desc="patchifying"):
        patch_list.append(patch(image, patch_size))
        label_list.append(target)
    return torch.stack(patch_list), torch.tensor(label_list)


all_patches, all_labels = patchify_dataset(mnist_train)  # (60000, 16, 7, 7), (60000,)


torch.Size([60000, 16, 7, 7])


In [27]:
# Confirm the stacked tensors line up: one patch set and one label per image.
print(all_patches.shape, all_labels.shape, sep="\n")  # torch.Size([60000, 16, 7, 7]) / torch.Size([60000])


torch.Size([60000, 16, 7, 7])
torch.Size([60000])


In [39]:
# Flatten each square patch into a vector, keeping the image and patch axes.
# The leading sizes come from the tensor itself (not hard-coded 60000/16), so this
# also works for a dataset subset or a different patch size.
flat_patches = all_patches.flatten(start_dim=2)  # (num_images, num_patches, patch_size**2)
print(flat_patches.shape)  # full MNIST train set: torch.Size([60000, 16, 49])

torch.Size([60000, 16, 49])


In [None]:
patch_size = 7
embed_dim = 64
patch_dim = patch_size ** 2  # length of one flattened square patch


class PatchEmbed(nn.Module):
    """Linearly project flattened image patches into the embedding space.

    Args:
        patch_dim: length of each flattened patch (patch_size * patch_size).
        embed_dim: size of the embedding produced for every patch.
    """

    def __init__(self, patch_dim, embed_dim):
        super(PatchEmbed, self).__init__()
        # One linear layer shared by all patches; nn.Linear acts on the last axis.
        self.proj = nn.Linear(in_features=patch_dim, out_features=embed_dim)

    def forward(self, x):
        """Embed patches: x is expected as (B, num_patches, patch_dim) -> (B, num_patches, embed_dim)."""
        embedded = self.proj(x)
        return embedded
    
# Keep a handle on the module so its projection weights stay reachable
# (e.g. for training or inspection) instead of being discarded after the call.
patch_embed = PatchEmbed(patch_dim, embed_dim)
flat_patch_embed = patch_embed(flat_patches)  # shape: (60000, 16, 64)
print(flat_patch_embed.shape)  # Should print: torch.Size([60000, 16, 64])

torch.Size([60000, 16, 64])


In [None]:
# Prepend a learnable CLS token to every sequence of patch embeddings.

B, N, D = flat_patch_embed.size()  # images, patches per image, embedding dim

# A single CLS token shared by the whole batch, initialised from a truncated
# normal (std=0.02) as in the ViT paper.
class_token = nn.Parameter(torch.zeros(1, 1, D))
nn.init.trunc_normal_(class_token, std=0.02)

# expand() broadcasts the one token to (B, 1, D) as a view, without copying.
cls_tokens = class_token.expand(B, -1, -1)

# CLS goes first along the sequence axis -> (B, N + 1, D).
vit_input = torch.cat((cls_tokens, flat_patch_embed), dim=1)

In [None]:
# Add learnable positional embeddings to the CLS + patch token sequence.

# One learnable position vector per token slot (CLS + N patches).
pos_embed = nn.Parameter(torch.zeros(1, N + 1, D))  # (1, 17, 64)
nn.init.trunc_normal_(pos_embed, std=0.02)

# Build from the un-encoded tokens instead of `vit_input = vit_input + pos_embed`:
# the self-referential form adds the positional embedding again every time this
# cell is re-run. pos_embed broadcasts over the batch axis.
vit_input = torch.cat([cls_tokens, flat_patch_embed], dim=1) + pos_embed  # (60000, 17, 64)

In [None]:
class ViTEncoder(nn.Module):
    """Visual Transformer encoder for the MNIST dataset."""
    def __init__(self, embed_dim, num_heads, num_layers):
        super().__init__()
        )