# Assignment 1: MLP, CNN and Vision Transformers in PyTorch

## Kaggle setup (Colab)

1) Open Kaggle → **Click Your PFP (top-right)** → **Settings** → **Account** → **API** → **Create Legacy API Key**  
   This downloads `kaggle.json`.

2) In Colab, run the next code cell. It will:
- ask you to upload a file: click `Choose Files` and select `kaggle.json`
- place it in the correct folder
- set permissions
- download + unzip the dataset


In [None]:
!pip -q install kaggle

from google.colab import files
import os
from pathlib import Path

# Upload kaggle.json through the Colab file picker.
# `uploaded` is keyed by filename, so the assert fails if the file was renamed.
uploaded = files.upload()
assert "kaggle.json" in uploaded, "Please upload kaggle.json"

# Move the credentials to the folder the Kaggle CLI reads from.
# The move below relies on the upload having been saved as ./kaggle.json
# in the current working directory.
kaggle_dir = Path("/root/.kaggle")
kaggle_dir.mkdir(parents=True, exist_ok=True)
(Path("kaggle.json")).replace(kaggle_dir / "kaggle.json")

# Fix permissions (required by Kaggle): owner read/write only
!chmod 600 /root/.kaggle/kaggle.json

# Download + unzip dataset into /content/data
!kaggle datasets download -d datamunge/sign-language-mnist -p /content/data --unzip

# List the download folder to confirm the CSVs are there
!ls -lah /content/data


## Load dataset

This dataset is provided as CSVs:
- `sign_mnist_train.csv`
- `sign_mnist_test.csv`

Each row contains:
- `label` (class id)
- 784 pixel columns (28×28 flattened)

We’ll split the Kaggle train CSV into train/val so we can compare models fairly.


In [None]:
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split

# CSV locations inside /content/data (the folder the Kaggle download cell unzips into)
train_path = "/content/data/sign_mnist_train.csv"
test_path  = "/content/data/sign_mnist_test.csv"

# Load both splits as DataFrames; the Dataset class below reads a `label` column
# and treats every other column as pixel data.
train_df = pd.read_csv(train_path)
test_df  = pd.read_csv(test_path)

class SignLanguageMNIST(Dataset):
    """
    Map-style dataset over a Sign Language MNIST DataFrame.

    The DataFrame must have a `label` column; all remaining columns are taken
    as the flattened pixels of one 28x28 image (784 values per row).
    Indexing yields `(image, label)` where `image` is a float32 tensor of
    shape (1, 28, 28) scaled by 1/255 and `label` is an int64 scalar tensor.
    """

    def __init__(self, df):
        pixel_frame = df.drop(columns=["label"])
        # Raw pixel values are stored as float32; scaling happens per item.
        self.x = pixel_frame.astype(np.float32).values
        self.y = df["label"].astype(np.int64).values

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        scaled_row = self.x[idx] / 255.0
        image = torch.tensor(scaled_row, dtype=torch.float32).view(1, 28, 28)
        label = torch.tensor(self.y[idx], dtype=torch.long)
        return image, label

# Output size used by every model in this notebook, plus matching display names
num_classes = 26
CLASS_NAMES = [chr(ord("A") + i) for i in range(26)]  # A..Z


# Wrap both DataFrames; `full_train` is split into train/val below
full_train = SignLanguageMNIST(train_df)
test_ds    = SignLanguageMNIST(test_df)

# split train into train/val (15% of the rows go to validation)
val_frac = 0.15
val_size = int(len(full_train) * val_frac)
train_size = len(full_train) - val_size

# Fixed seed so the same rows land in train/val on every run
g = torch.Generator().manual_seed(42)
train_ds, val_ds = random_split(full_train, [train_size, val_size], generator=g)

# Only the training loader shuffles; val/test keep a fixed order
batch_size = 256
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,  num_workers=2, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)
test_loader  = DataLoader(test_ds,  batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

# Global device string read by the training / evaluation helpers in later cells
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


## Visualize Sign Language MNIST

We’ll:
- map numeric labels → letters (the dataset excludes **J** and **Z** because they require motion)
- show a small image grid
- show class distribution (counts per label)


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Label → Letter mapping (A..Z): label id i is displayed as LETTERS[i]
LETTERS = list("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
# The plots below index a 26-slot space, so fail early if num_classes was changed
assert num_classes == 26, f"For A–Z mapping, set num_classes=26 (got {num_classes})"

def label_to_letter(y: int) -> str:
    """Return the entry of LETTERS at position `int(y)`."""
    position = int(y)
    return LETTERS[position]

# --- Show class counts (force all 26 bars, even if some are zero)
# value_counts() only reports labels that occur, so the counts are copied into a
# zero-filled 26-slot array; labels that never occur keep a count of 0.
counts = train_df["label"].value_counts().sort_index()
all_counts = np.zeros(26, dtype=int)
for k, v in counts.items():
    all_counts[int(k)] = int(v)

# One bar per label id, with the x-axis ticks relabelled as letters
plt.figure()
plt.bar(np.arange(26), all_counts)
plt.title("Train class distribution (A–Z index space)")
plt.xlabel("label id")
plt.ylabel("count")
plt.xticks(np.arange(26), LETTERS, rotation=0)
plt.show()

# --- Show a grid of examples
def show_grid(ds, n=25):
    """
    Plot up to `n` randomly chosen samples from `ds` in a roughly square grid.

    Each item of `ds` must be an `(image, label)` pair where `image` is a
    tensor with a leading size-1 dimension (it is squeezed away before
    plotting) and `label` converts to int. Every tile is titled
    "<label id>:<letter>".

    Args:
        ds: indexable dataset supporting `len()` and integer indexing.
        n: number of samples requested. It is clamped to `len(ds)`; if the
            result is zero or negative, nothing is drawn.

    Returns:
        None.
    """
    # Sampling without replacement cannot return more items than exist, and
    # n == 0 would make the column count zero (division by zero below).
    n = min(int(n), len(ds))
    if n <= 0:
        return

    idxs = np.random.choice(len(ds), size=n, replace=False)
    cols = int(np.sqrt(n))
    rows = int(np.ceil(n / cols))

    plt.figure(figsize=(cols*2.0, rows*2.0))
    for i, idx in enumerate(idxs, 1):
        # Plain int so datasets that do not accept numpy integers still work
        x, y = ds[int(idx)]
        img = x.squeeze(0).numpy()

        ax = plt.subplot(rows, cols, i)
        ax.imshow(img, cmap="gray")
        ax.set_title(f"{int(y)}:{label_to_letter(int(y))}")
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Preview 25 random samples drawn from the full (pre-split) training set
show_grid(full_train, n=25)

## Training utilities

These functions work for MLP, CNN, and ViT.  
Your job is to implement the model classes only.


In [None]:
import torch
from torch import nn
from time import time

def acc(logits, y):
    """Return, as a Python float, the fraction of rows whose arg-max over dim 1 equals `y`."""
    predicted = logits.argmax(1)
    hits = predicted == y
    return hits.float().mean().item()

def run_epoch(model, loader, optimizer=None):
    """
    Run one full pass over `loader` and return sample-weighted averages.

    When `optimizer` is given the model is put in train mode and updated
    after every batch. When it is None the model is put in eval mode and
    only the metrics are computed, without tracking gradients.

    Batches are moved to the module-level `device`; the model is expected to
    already be on that device.

    Args:
        model: module mapping an input batch to logits of shape (B, num_classes).
        loader: iterable yielding `(x, y)` batches.
        optimizer: optimizer to step, or None for an evaluation-only pass.

    Returns:
        (mean_loss, mean_accuracy), each averaged over all samples seen.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    train = optimizer is not None
    model.train(train)

    criterion = nn.CrossEntropyLoss()
    total_loss, total_acc, total_n = 0.0, 0.0, 0

    for x, y in loader:
        x, y = x.to(device), y.to(device)

        if train:
            optimizer.zero_grad()

        # Build the autograd graph only when it will be used for backward();
        # this keeps direct evaluation calls from holding activations in memory.
        with torch.set_grad_enabled(train):
            logits = model(x)
            loss = criterion(logits, y)

        if train:
            loss.backward()
            optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean
        bs = x.size(0)
        total_loss += loss.item() * bs
        total_acc  += acc(logits, y) * bs
        total_n    += bs

    if total_n == 0:
        raise ValueError("loader yielded no samples; cannot compute epoch averages")

    return total_loss / total_n, total_acc / total_n

def evaluate(model, loader):
    """
    Score `model` on `loader` without updating it.

    Puts the model in eval mode and runs one optimizer-free epoch with
    gradient tracking disabled. Returns whatever `run_epoch` returns.
    """
    with torch.no_grad():
        model.eval()
        return run_epoch(model, loader, optimizer=None)

def fit(model, train_loader, val_loader, epochs=5, lr=1e-3, wd=0.0):
    """
    Train `model` with AdamW for `epochs` epochs, printing one summary line per epoch.

    The model is moved to the module-level `device` first. Each epoch runs a
    training pass over `train_loader`, then an evaluation pass over
    `val_loader`, and prints loss/accuracy for both plus the elapsed seconds.

    Args:
        model: module to train.
        train_loader: batches used for the optimisation pass.
        val_loader: batches used for the evaluation pass.
        epochs: number of epochs to run.
        lr: AdamW learning rate.
        wd: AdamW weight decay.
    """
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)

    for epoch in range(1, epochs + 1):
        started = time()
        train_loss, train_acc = run_epoch(model, train_loader, optimizer)
        val_loss, val_acc = evaluate(model, val_loader)
        elapsed = time() - started

        train_part = f"train {train_loss:.4f}/{train_acc:.4f}"
        val_part = f"val {val_loss:.4f}/{val_acc:.4f}"
        print(f"ep {epoch:02d} | {train_part} | {val_part} | {elapsed:.1f}s")


## Confusion matrix

After training a model, use this to see which letters it confuses most often.

In [None]:
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
import torch

@torch.no_grad()
def get_preds(model, loader):
    """
    Collect true labels and arg-max predictions for every batch in `loader`.

    The model is switched to eval mode and inputs are moved to the
    module-level `device`. Returns `(y_true, y_pred)` as two concatenated
    numpy arrays.
    """
    model.eval()
    true_batches = []
    pred_batches = []
    for inputs, targets in loader:
        inputs = inputs.to(device, non_blocking=True)
        scores = model(inputs)
        pred_batches.append(scores.argmax(dim=1).cpu().numpy())
        true_batches.append(targets.numpy())
    return np.concatenate(true_batches), np.concatenate(pred_batches)

def plot_confusion_matrix(model, loader, class_names, normalize=True, title="Confusion Matrix"):
    """
    Plot the confusion matrix of `model` over `loader`.

    Args:
        model: trained model; predictions are gathered with `get_preds`.
        loader: batches to evaluate.
        class_names: tick labels; their count fixes the matrix size.
        normalize: if True, divide each row by its total so rows show
            per-true-class fractions. Rows with no samples stay all zero.
        title: figure title.
    """
    y_true, y_pred = get_preds(model, loader)

    n_classes = len(class_names)
    class_ids = np.arange(n_classes)
    # Explicit label ids keep one row/column per class name, including
    # classes that never occur in this loader.
    matrix = confusion_matrix(y_true, y_pred, labels=class_ids)

    if normalize:
        per_true_total = matrix.sum(axis=1, keepdims=True)
        has_samples = per_true_total != 0
        # Divide only where the row has samples; empty rows (e.g. J/Z) keep
        # the zeros from `out` instead of becoming NaN.
        matrix = np.divide(
            matrix.astype(np.float32),
            per_true_total,
            out=np.zeros_like(matrix, dtype=np.float32),
            where=has_samples,
        )

    plt.figure(figsize=(8, 8))
    plt.imshow(matrix, interpolation="nearest")
    plt.title(title)
    plt.colorbar()
    plt.xticks(class_ids, class_names, rotation=90)
    plt.yticks(class_ids, class_names)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.show()


# Example usage after training:
# plot_confusion_matrix(model_cnn.to(device), val_loader, LETTERS, normalize=True, title="CNN (val) confusion matrix")


## Model 1 — Simple MLP (baseline)

**Goal:** Build a baseline that flattens the image and uses fully-connected layers.

Hints:
- Flatten 1×28×28 → 784
- 2–4 hidden layers
- Use `nn.ReLU()` or `nn.GELU()`
- Add `nn.Dropout(p)` if you want
- Final layer must output `num_classes`


In [None]:
import torch
from torch import nn

class SimpleMLP(nn.Module):
    """
    Baseline classifier scaffold: flatten the image, then fully-connected layers.

    The layers are expected in `self.net`. While `self.net` is still the
    `None` placeholder, calling the model raises NotImplementedError.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # TODO: create layers (recommended: nn.Sequential)
        # self.net = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(28*28, 512),
        #     nn.ReLU(),
        #     ...
        #     nn.Linear(..., num_classes)
        # )
        #
        # Do not apply a Softmax activation in your model's final layer;
        # it is applied internally within the nn.CrossEntropyLoss function for numerical stability.

        # Comment out this line after implementing your network
        self.net = None

    def forward(self, x):
        network = self.net
        if network is not None:
            return network(x)
        raise NotImplementedError("Define self.net in __init__")

# Create the MLP instance that the training cell below fits
model_simple_mlp = SimpleMLP(num_classes=num_classes)


### Run Training for MLP

In [None]:
# Train the MLP for 10 epochs (AdamW, lr=1e-3, weight decay=1e-4)
fit(model_simple_mlp, train_loader, val_loader, epochs=10, lr=1e-3, wd=1e-4)

# Re-score the trained model on the validation split and report both metrics
val_loss, val_acc = evaluate(model_simple_mlp.to(device), val_loader)
print(f"MLP val_loss: {val_loss:.4f}")
print(f"MLP val_acc:  {val_acc:.4f}")

### Plot Confusion Matrix

In [None]:
# Row-normalised confusion matrix of the MLP on the validation split
plot_confusion_matrix(model_simple_mlp.to(device), val_loader, LETTERS, normalize=True, title="MLP (val) confusion matrix")

## Model 2 — CNN

**Goal:** Build a small CNN that should outperform the MLP.

Hints (keep it simple):
- Use 2–4 conv blocks
- A conv block can be: `Conv2d -> ReLU -> (BatchNorm2d optional) -> MaxPool2d`
- Increase channels: 32 → 64 → 128
- Use `AdaptiveAvgPool2d((1,1))` before the classifier so you don't fight tensor sizes.

In [None]:
import torch
from torch import nn

class ConvBlock(nn.Module):
    """
    ConvBlock = a reusable CNN building block (scaffold).

    REQUIREMENT (for the assignment):
    - Build this block using nn.Sequential
    - You must decide which layers to include and in what order.

    Suggested ingredients (pick what you want):
    - nn.Conv2d(...)
    - activation: nn.ReLU() or nn.GELU()
    - optional: nn.BatchNorm2d(...)
    - optional: nn.MaxPool2d(...)
    - optional: nn.Dropout2d(...)

    Hint: store the Sequential in self.block and call it in forward().
    While self.block is still None, forward() raises NotImplementedError.
    """

    def __init__(self, in_ch: int, out_ch: int, use_bn: bool = True, use_pool: bool = True):
        super().__init__()
        # TODO: create a Python list called `layers`
        # TODO: append your Conv2d
        # TODO: append your activation
        # TODO: if use_bn: append BatchNorm2d
        # TODO: if use_pool: append MaxPool2d
        # TODO: wrap it into nn.Sequential and assign to self.block
        self.block = None

    def forward(self, x):
        layers = self.block
        if layers is not None:
            return layers(x)
        raise NotImplementedError("Build self.block (nn.Sequential) inside ConvBlock.__init__")


class SimpleCNN(nn.Module):
    """
    SimpleCNN (scaffold) splits the network into:
    - feature_extractor: everything BEFORE flatten (outputs feature maps)
    - classifier: maps flattened features -> num_classes

    REQUIREMENT (for the assignment):
    - feature_extractor MUST be an nn.Sequential of multiple ConvBlocks (and optionally pooling / adaptive pooling)
    - classifier MUST be an nn.Sequential of Linear layers (and optional dropout / activation)

    Tips:
    - Use increasing channels (example idea: 1->32->64->128), but you choose.
    - If you add nn.AdaptiveAvgPool2d((1,1)) at the end of feature_extractor,
      flatten becomes easy because shape becomes (B, C, 1, 1) -> (B, C).
    - Do NOT apply Softmax at the end. nn.CrossEntropyLoss applies it internally.

    While either part is still None, forward() raises NotImplementedError.
    """

    def __init__(self, num_classes: int):
        super().__init__()
        # TODO: build self.feature_extractor (nn.Sequential)
        #   It should stack ConvBlocks like:
        #     ConvBlock(...), ConvBlock(...), ConvBlock(...), ...
        #   Optionally end with nn.AdaptiveAvgPool2d((1,1)).

        # TODO: build self.classifier (nn.Sequential)
        #   It should accept the flattened output of feature_extractor
        #   and end with a Linear(..., num_classes)

        self.feature_extractor = None
        self.classifier = None

    def forward(self, x):
        both_defined = self.feature_extractor is not None and self.classifier is not None
        if not both_defined:
            raise NotImplementedError("Define self.feature_extractor and self.classifier in __init__")

        feature_maps = self.feature_extractor(x)       # (B, C, H, W)
        # (B, C*H*W), or (B, C) if using AdaptiveAvgPool2d((1,1))
        flat = torch.flatten(feature_maps, 1)
        return self.classifier(flat)                   # (B, num_classes)

# Create the CNN instance that the training cell below fits
model_cnn = SimpleCNN(num_classes=num_classes)

### Run Training for CNN

In [None]:
# Train the CNN for 8 epochs (AdamW, lr=2e-3, weight decay=1e-4)
fit(model_cnn, train_loader, val_loader, epochs=8, lr=2e-3, wd=1e-4)
# Prints the (val_loss, val_acc) tuple returned by evaluate
print("CNN val:", evaluate(model_cnn.to(device), val_loader))

### Plot Confusion Matrix

In [None]:
# Row-normalised confusion matrix of the CNN on the validation split
plot_confusion_matrix(model_cnn.to(device), val_loader, LETTERS, normalize=True, title="CNN (val) confusion matrix")

## Compare models on validation

Run this after you have trained both models.

In [None]:
# Model name -> (val_loss, val_acc) as returned by evaluate
results = {}

results["mlp"] = evaluate(model_simple_mlp.to(device), val_loader)
results["cnn"] = evaluate(model_cnn.to(device), val_loader)

# One aligned summary line per model
for name, (loss, a) in results.items():
    print(f"{name:>4} | val loss {loss:.4f} | val acc {a:.4f}")

## Inference Helpers

In [None]:
import numpy as np
import torch
from PIL import Image, ImageOps, ImageEnhance

def preprocess_to_sign_mnist(img, out_size=28, invert="auto", enhance_contrast=True):
    """
    Convert an image into a (1, 1, out_size, out_size) float32 tensor in [0, 1],
    laid out like a Sign Language MNIST input.

    Steps, in order:
    - load (path / PIL image / numpy array)
    - convert to grayscale
    - optional contrast boost
    - resize to fit inside out_size x out_size, preserving aspect ratio
    - pad to a square (white background)
    - optional inversion
    - scale pixel values to [0, 1]

    Args:
        img: file path (str or os.PathLike), PIL.Image, or numpy array.
        out_size: side length of the square output in pixels.
        invert: "never" keeps the image as is; "always" inverts it; any other
            value (the default "auto") keeps whichever of the two versions
            has the lower mean intensity over the whole image, i.e. the
            darker one.
        enhance_contrast: if True, scale contrast by 1.5 before resizing.

    Returns:
        (x, pil_img): the tensor, and the out_size x out_size PIL image that
        tensor was built from (inversion included), for display.

    Raises:
        TypeError: if `img` is none of the supported input types.
    """
    import os  # local import: only needed for the PathLike check below

    # --- load
    if isinstance(img, (str, os.PathLike)):
        img = Image.open(img)
    elif isinstance(img, np.ndarray):
        img = Image.fromarray(img)
    elif not isinstance(img, Image.Image):
        raise TypeError("img must be a path (str), PIL.Image, or numpy array")

    # --- grayscale
    img = img.convert("L")

    # --- optional contrast boost (helps phone photos)
    if enhance_contrast:
        img = ImageEnhance.Contrast(img).enhance(1.5)

    # --- resize keeping aspect ratio, then pad to square
    img = ImageOps.contain(img, (out_size, out_size))
    w, h = img.size
    pad_l = (out_size - w) // 2
    pad_t = (out_size - h) // 2
    pad_r = out_size - w - pad_l
    pad_b = out_size - h - pad_t
    img = ImageOps.expand(img, border=(pad_l, pad_t, pad_r, pad_b), fill=255)  # white background

    # ensure exact size (safeguard; the padded image should already match)
    img = img.resize((out_size, out_size), Image.BILINEAR)

    def to_tensor(pil_img):
        arr = np.array(pil_img).astype(np.float32)
        arr = arr / 255.0
        return torch.tensor(arr).view(1, 1, out_size, out_size)

    # Pick the PIL image first and derive the tensor from it, so the image
    # handed back for display is always the one the model receives.
    if invert == "never":
        chosen = img
    elif invert == "always":
        chosen = ImageOps.invert(img)
    else:
        # auto: compare the mean intensity of both versions over the whole
        # image and keep the darker one; a tie keeps the inverted version.
        inverted = ImageOps.invert(img)
        plain_mean = to_tensor(img).mean().item()
        inverted_mean = to_tensor(inverted).mean().item()
        chosen = img if plain_mean < inverted_mean else inverted

    return to_tensor(chosen), chosen


@torch.no_grad()
def predict_single_image(model, img, class_names=None, device=None, topk=3, invert="auto"):
    """
    Classify one image and return its top-k classes.

    Args:
        model: your trained model (switched to eval mode here).
        img: path / PIL / np array, passed to preprocess_to_sign_mnist.
        class_names: list of class labels to display (length must match model
            outputs); when None the class index is used as the label.
        device: "cuda" or "cpu" (defaults to whatever the model is on).
        topk: number of predictions to return, capped at the number of classes.
        invert: forwarded to preprocess_to_sign_mnist.

    Returns:
        (results, img28): `results` is a list of (label, class_index,
        probability) tuples in the order given by torch.topk; `img28` is the
        PIL image returned by the preprocessing step.
    """
    model.eval()

    target_device = next(model.parameters()).device if device is None else device

    batch, img28 = preprocess_to_sign_mnist(img, out_size=28, invert=invert)
    scores = model(batch.to(target_device))                # (1, C)
    probs = torch.softmax(scores, dim=1).squeeze(0)        # (C,)

    k = min(int(topk), probs.numel())
    top_probs, top_idxs = torch.topk(probs, k=k)

    def display_name(class_idx):
        if class_names is None:
            return str(class_idx)
        return class_names[class_idx]

    results = [
        (display_name(i), i, p)
        for p, i in zip(top_probs.tolist(), top_idxs.tolist())
    ]
    return results, img28


## Single-image inference (upload a photo)

Upload any hand-sign image (photo/screenshot).  
This cell will:
1) upload the image
2) convert it into a **28×28 grayscale MNIST-like** tensor
3) run your trained model on it
4) show the processed 28×28 image + top predictions

**Note:** `CLASS_NAMES` must match your model output size:
- if `num_classes = 26`: `CLASS_NAMES = ["A", "B", ..., "Z"]`
- if `num_classes = 25`: `CLASS_NAMES = ["A", ..., "Y"]`


In [None]:
from google.colab import files
import matplotlib.pyplot as plt

# --- Make sure these exist in your notebook already:
# - preprocess_to_sign_mnist(...)
# - predict_single_image(...)
# - a trained model, e.g. model_cnn or model_simple_mlp
# - device

# 1) Upload image (only the first uploaded file is used)
uploaded = files.upload()
img_path = next(iter(uploaded.keys()))  # first uploaded filename
print("Uploaded:", img_path)

# 2) Class names (adjust if your num_classes differs)
# If you used num_classes = 26:
CLASS_NAMES = [chr(ord("A") + i) for i in range(26)]

# 3) Pick a model to use (change this to whichever you trained)
model_for_inference = model_cnn  # or model_simple_mlp, model_vit, etc.

# 4) Predict
# NOTE: this reassigns `results`, which the model-comparison cell also uses
results, img28 = predict_single_image(
    model_for_inference.to(device),
    img_path,
    class_names=CLASS_NAMES,
    topk=5,
    invert="auto"
)

# Each entry is (label, class index, probability)
print("\nTop predictions:")
for label, idx, prob in results:
    print(f"- {label} (class {idx}): {prob:.3f}")

# 5) Show the 28x28 PIL image returned by predict_single_image
plt.figure(figsize=(3,3))
plt.imshow(img28, cmap="gray")
plt.title("Processed 28×28 input")
plt.axis("off")
plt.show()

# Bonus (Optional)

## Tiny ViT for 28×28 grayscale (Step-by-step)

Idea:
1) Split image into patches (like “tokens” for vision)
2) Embed each patch into a vector of size `dim`
3) Add a `[CLS]` token + positional embeddings
4) Pass tokens through a Transformer Encoder
5) Use `[CLS]` output to classify into `num_classes`

Recommended beginner-friendly setup:
- `patch_size = 7` → (28/7 = 4) → **16 tokens** (smaller sequence, easier)
- `dim = 128`, `depth = 3`, `heads = 4`

(You *can* use patch_size=4 too, but that makes 49 tokens and can feel harder.)

## Step 1 — Patch Embedding

We want:
- Input:  x of shape (B, 1, 28, 28)
- Output: tokens of shape (B, N, dim)

Using Conv2d:
- `Conv2d(1, dim, kernel_size=patch, stride=patch)`
gives output (B, dim, H', W') where:
- H' = 28/patch
- W' = 28/patch
- N = H'*W'

Then we flatten:
- (B, dim, H', W') → (B, dim, N) → (B, N, dim)

In [None]:
import torch
from torch import nn

class PatchEmbed(nn.Module):
    """
    Scaffold for the patch-embedding step of the ViT.

    Intended contract once implemented: input (B, 1, 28, 28) -> tokens
    (B, N, dim). As written, `self.proj` is a None placeholder and forward()
    always raises NotImplementedError.
    """
    def __init__(self, dim=128, patch_size=7):
        super().__init__()
        self.dim = dim
        self.patch_size = patch_size

        # TODO: conv that produces patch embeddings
        # self.proj = nn.Conv2d(1, dim, kernel_size=patch_size, stride=patch_size)
        self.proj = None

    def forward(self, x):
        # Separate message so it is clear whether __init__ or forward is unfinished
        if self.proj is None:
            raise NotImplementedError("TODO: define self.proj in __init__")

        # TODO:
        # x = self.proj(x)            # (B, dim, H', W')
        # x = x.flatten(2)            # (B, dim, N)
        # x = x.transpose(1, 2)       # (B, N, dim)
        # return x
        # Remove this raise once the steps above are implemented
        raise NotImplementedError("TODO: implement PatchEmbed.forward")


## Step 2 — CLS token + Positional embeddings

Transformers need positional info because attention alone doesn’t know order.

We’ll create:
- `cls_token`: (1, 1, dim) learnable
- `pos_embed`: (1, 1+N, dim) learnable

Forward:
- tokens: (B, N, dim)
- cls:    expand → (B, 1, dim)
- concat: (B, 1+N, dim)
- add pos: (B, 1+N, dim)

In [None]:
class AddClsPos(nn.Module):
    """
    Scaffold for prepending a CLS token and adding positional embeddings.

    Intended contract once implemented: tokens (B, N, dim) -> (B, 1+N, dim).
    As written, both parameters are None placeholders and forward() always
    raises NotImplementedError.
    """
    def __init__(self, num_patches: int, dim=128):
        super().__init__()
        self.num_patches = num_patches
        self.dim = dim

        # TODO: define cls token + pos embedding as nn.Parameter
        # self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        # self.pos_embed = nn.Parameter(torch.zeros(1, 1 + num_patches, dim))
        self.cls_token = None
        self.pos_embed = None

    def forward(self, x):
        # x: (B, N, dim)
        if self.cls_token is None or self.pos_embed is None:
            raise NotImplementedError("TODO: define cls_token and pos_embed")

        # TODO:
        # B = x.size(0)
        # cls = self.cls_token.expand(B, -1, -1)   # (B, 1, dim)
        # x = torch.cat([cls, x], dim=1)           # (B, 1+N, dim)
        # x = x + self.pos_embed                   # (B, 1+N, dim)
        # return x
        # Remove this raise once the steps above are implemented
        raise NotImplementedError("TODO: implement AddClsPos.forward")


## Step 3 — Transformer Encoder

PyTorch gives you:
- `nn.TransformerEncoderLayer`
- `nn.TransformerEncoder`

Key detail: set `batch_first=True` so we can keep (B, S, E).

Input to encoder:
- (B, 1+N, dim)
Output:
- (B, 1+N, dim)

In [None]:
class TokenEncoder(nn.Module):
    """
    Scaffold for the Transformer encoder stack.

    Intended contract once implemented: (B, 1+N, dim) -> (B, 1+N, dim).
    While `self.encoder` is still None, forward() raises NotImplementedError.
    """

    def __init__(self, dim=128, depth=3, heads=4, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        # TODO: build encoder stack
        # ff_dim = int(dim * mlp_ratio)
        # layer = nn.TransformerEncoderLayer(
        #     d_model=dim, nhead=heads,
        #     dim_feedforward=ff_dim,
        #     dropout=dropout,
        #     activation="gelu",
        #     batch_first=True
        # )
        # self.encoder = nn.TransformerEncoder(layer, num_layers=depth)
        self.encoder = None

    def forward(self, x):
        stack = self.encoder
        if stack is not None:
            return stack(x)
        raise NotImplementedError("TODO: define self.encoder in __init__")


## Step 4 — Full TinyViT

Pipeline:
1) PatchEmbed:   (B, 1, 28, 28) → (B, N, dim)
2) AddClsPos:    (B, N, dim)    → (B, 1+N, dim)
3) TokenEncoder: (B, 1+N, dim)  → (B, 1+N, dim)
4) Head on CLS:  use x[:,0]     → (B, dim) → (B, num_classes)

In [None]:
class TinyViT(nn.Module):
    """
    Scaffold that chains PatchEmbed -> AddClsPos -> TokenEncoder -> linear head.

    Assumes 28x28 inputs: `patch_size` must divide 28, and the number of
    patches is (28 // patch_size) squared. As written, `self.head` is a None
    placeholder and forward() always raises NotImplementedError.
    """
    def __init__(self, num_classes: int, dim=128, depth=3, heads=4, patch_size=7, dropout=0.1):
        super().__init__()
        assert 28 % patch_size == 0, "patch_size must divide 28 cleanly"
        # Patches per side, squared
        num_patches = (28 // patch_size) * (28 // patch_size)

        self.patch = PatchEmbed(dim=dim, patch_size=patch_size)
        self.add_cp = AddClsPos(num_patches=num_patches, dim=dim)
        self.enc = TokenEncoder(dim=dim, depth=depth, heads=heads, dropout=dropout)

        # TODO: classifier head
        # self.head = nn.Linear(dim, num_classes)
        self.head = None

    def forward(self, x):
        if self.head is None:
            raise NotImplementedError("TODO: define self.head in __init__")

        x = self.patch(x)     # (B, N, dim)
        x = self.add_cp(x)    # (B, 1+N, dim)
        x = self.enc(x)       # (B, 1+N, dim)

        # TODO: take CLS token and classify
        # cls = x[:, 0]        # (B, dim)
        # return self.head(cls)
        # Remove this raise once the steps above are implemented
        raise NotImplementedError("TODO: finish TinyViT.forward")


## Debug: check shapes with one batch

Run this after you implement the TODOs in PatchEmbed / AddClsPos / TokenEncoder / TinyViT.
It helps you catch shape mistakes early.

In [None]:
def debug_vit_shapes(model, loader):
    """
    Push the first batch of `loader` through `model` and check the output shape.

    Prints the input and output shapes, then asserts the output is 2-D with
    `num_classes` columns. Uses the module-level `device` and `num_classes`.
    """
    model = model.to(device).eval()
    batch, _ = next(iter(loader))
    batch = batch.to(device)

    with torch.no_grad():
        out = model(batch)

    print("Input:", tuple(batch.shape))
    print("Output:", tuple(out.shape))

    assert out.ndim == 2, "Model output should be (B, num_classes)"
    assert out.shape[1] == num_classes, "Second dim must be num_classes"

# Example:
# model_vit = TinyViT(num_classes=num_classes, dim=128, depth=3, heads=4, patch_size=7)
# debug_vit_shapes(model_vit, train_loader)


## Train TinyViT (once implemented)

ViT often likes:
- lower LR (e.g. 3e-4)
- higher weight decay (e.g. 1e-2 to 5e-2)

Start small and only increase depth/dim if it’s stable.

In [None]:
# Build the ViT with the beginner-friendly settings recommended above
model_vit = TinyViT(num_classes=num_classes, dim=128, depth=3, heads=4, patch_size=7, dropout=0.1)
# Train for 12 epochs with a lower learning rate and higher weight decay than the MLP/CNN cells
fit(model_vit, train_loader, val_loader, epochs=12, lr=3e-4, wd=1e-2)
# Prints the (val_loss, val_acc) tuple returned by evaluate
print("ViT val:", evaluate(model_vit.to(device), val_loader))

## Plot Confusion Matrix



In [None]:
# Row-normalised confusion matrix of the ViT on the validation split
plot_confusion_matrix(model_vit.to(device), val_loader, LETTERS, normalize=True, title="ViT confusion matrix")