<a href="https://colab.research.google.com/github/Matan-Vinkler/vit-paper-implementation/blob/main/vit_implementation_tiny.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer (ViT) - From-Scratch Implementation

This notebook contains a **from-scratch reimplementation** of the paper  
[*An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale*](https://arxiv.org/pdf/2010.11929) (Dosovitskiy et al., 2020).

**🎯 Project Goal**
- Read and understand a research paper,
- Translate the architecture into working **PyTorch code** without relying on prebuilt models,
- Validate correctness with unit tests, overfit-one-batch experiments, and CIFAR-10 training.

**🏗️ Components Implemented**
- Patchify / Unpatchify
- Patch projection + CLS token + learnable positional embeddings
- Multi-Head Self Attention (MSA)
- MLP feedforward block
- Transformer Encoder Block (Pre-LN, residuals, dropout)
- Vision Transformer backbone (stacked encoder blocks)
- Classification head (Linear or MLP)

**🧪 Validation**
- Gradient flow and shape checks
- CIFAR-10 training with AdamW, warmup+cosine LR, AMP

**📊 Results**
- On CIFAR-10 resized to 224×224: ~58–60% validation accuracy after 30 epochs with a small ViT.
- Performance is **consistent with the paper’s claim**: ViTs require large-scale data (e.g. ImageNet-21k, JFT) to truly outperform CNNs.


### Table of Contents

>[Setup and Imports](#scrollTo=MCZsjq8YsD_7)

>[Define Model Architecture and Algorithms](#scrollTo=cZzKglKCrJvu)

>[Data Load and Preprocessing](#scrollTo=BpFyUIO8rT88)

>[Training Loop and Metrics](#scrollTo=QXzWp1VJr42L)

### Setup and Imports

Importing necessary libraries:

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.cuda.amp import GradScaler, autocast

import torchvision
from torchvision import transforms

import matplotlib.pyplot as plt

import numpy as np
import random
import math
import time
from typing import Tuple

Set the random seed and check for an available CUDA device:

In [None]:
def set_seed(seed=1337):
    """
    Seed every RNG the notebook touches: Python, NumPy, PyTorch CPU and all CUDA devices.

    cuDNN is intentionally left non-deterministic with autotuning (benchmark) on,
    trading bit-exact reproducibility for speed.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False

# Prefer the GPU when one is visible.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)
set_seed(1337)

### Define Model Architecture and Algorithms

Create our `patchify` and `unpatchify` functions.

- `patchify` - split images into non-overlapping patches and flatten them.
- `unpatchify` - reconstruct images from flattened patches.

In [None]:
def patchify(x: torch.Tensor, patch_size: int) -> torch.Tensor:
    """
    Cut a batch of images into non-overlapping square patches, each flattened to a vector.

    Patches are ordered row-major over the (H//P, W//P) grid; inside each flattened
    patch the layout is (row-in-patch, col-in-patch, channel).

    Args:
        x: Tensor [B, C, H, W]
        patch_size: P, must divide both H and W exactly.

    Returns:
        Tensor [B, N, P*P*C] with N = (H//P) * (W//P)

    Raises:
        ValueError: if x is not 4D or P does not divide H and W.
    """
    if x.dim() != 4:
        raise ValueError(f"Expected 4D tensor [B, C, H, W], got shape {tuple(x.shape)}")

    batch, chans, height, width = x.shape
    p = patch_size
    if height % p or width % p:
        raise ValueError(f"patch_size={p} must divide H={height} and W={width} exactly.")

    grid_h, grid_w = height // p, width // p
    # [B, C, gh, P, gw, P] -> [B, gh, gw, P, P, C] -> [B, gh*gw, P*P*C]
    return (
        x.reshape(batch, chans, grid_h, p, grid_w, p)
        .permute(0, 2, 4, 3, 5, 1)
        .reshape(batch, grid_h * grid_w, p * p * chans)
    )

def unpatchify(patches: torch.Tensor, patch_size: int, img_size: Tuple[int, int], channels: int) -> torch.Tensor:
    """
    Inverse of `patchify`: reassemble flattened patches into full images.

    Args:
        patches: Tensor [B, N, P*P*C], laid out as produced by `patchify`
        patch_size: P used when the patches were cut
        img_size: (H, W) of the original image
        channels: C

    Returns:
        Tensor [B, C, H, W]

    Raises:
        ValueError: for a non-3D input, a last dim that is not P*P*C, a patch size that
            does not divide (H, W), or a patch count that does not match the grid.
    """
    if patches.dim() != 3:
        raise ValueError(f"Expected 3D tensor [B, N, P*P*C], got shape {tuple(patches.shape)}")

    batch, num_patches, flat_dim = patches.shape
    p, c = patch_size, channels
    height, width = img_size

    expected_flat = p * p * c
    if flat_dim != expected_flat:
        raise ValueError(f"Last dim {flat_dim} != P*P*C = {expected_flat}")
    if height % p or width % p:
        raise ValueError(f"patch_size={p} must divide H={height} and W={width} exactly.")

    grid_h, grid_w = height // p, width // p
    if num_patches != grid_h * grid_w:
        raise ValueError(f"Num patches N={num_patches} != (H//P)*(W//P) = {grid_h * grid_w}")

    grid = patches.reshape(batch, grid_h, grid_w, p, p, c)   # [B, gh, gw, P, P, C]
    grid = grid.permute(0, 5, 1, 3, 2, 4)                    # [B, C, gh, P, gw, P]
    return grid.reshape(batch, c, height, width)             # [B, C, H, W]

Now create the linear projection module:

In [None]:
class PatchProjection(nn.Module):
    """
    Patch embedding: flatten non-overlapping P x P patches and map each linearly to D dims.

    Only square inputs of side `img_size` with `in_chans` channels are accepted.

    Input:  [B, C, H, W]  (H == W == img_size, C == in_chans)
    Output: [B, N, D]     (N = (img_size // patch_size) ** 2)
    """
    def __init__(self, img_size: int = 224, patch_size: int = 16,
                 in_chans: int = 3, embed_dim: int = 192):
        super().__init__()
        if img_size % patch_size:
            raise ValueError("img_size must be divisible by patch_size.")
        self.img_size, self.patch_size = img_size, patch_size
        self.in_chans, self.embed_dim = in_chans, embed_dim
        grid = img_size // patch_size
        self.num_patches = grid * grid
        patch_dim = patch_size * patch_size * in_chans
        self.proj = nn.Linear(patch_dim, embed_dim)

        # Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(self.proj.weight)
        if self.proj.bias is not None:
            nn.init.zeros_(self.proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        _, chans, height, width = x.shape
        if chans != self.in_chans:
            raise ValueError(f"in_chans={self.in_chans} but got C={chans}")
        if (height, width) != (self.img_size, self.img_size):
            raise ValueError(f"Expected H=W={self.img_size}, got ({height},{width})")
        # [B, C, H, W] -> [B, N, P*P*C] -> [B, N, D]
        return self.proj(patchify(x, self.patch_size))

This module prepends a learnable `[CLS]` token to the sequence of patch embeddings:

In [None]:
class AddCLSToken(nn.Module):
    """
    Prepend one learnable [CLS] token, shared by every sample in the batch.

    Input:  [B, N, D]
    Output: [B, N+1, D]  (index 0 along the token axis is the [CLS] token)
    """
    def __init__(self, embed_dim: int):
        super().__init__()
        self.embed_dim = embed_dim
        token = torch.zeros(1, 1, embed_dim)
        self.cls_token = nn.Parameter(token)
        nn.init.normal_(self.cls_token, std=0.02)   # small random init

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        batch, _, dim = z.shape
        if dim != self.embed_dim:
            raise ValueError(f"Expected embed_dim={self.embed_dim}, got {dim}")
        # [1, 1, D] -> [B, 1, D] (a view, no copy), then prepend along the token axis.
        return torch.cat([self.cls_token.expand(batch, -1, -1), z], dim=1)

This module adds learnable positional embeddings to every token, including the `[CLS]` token:

In [None]:
class AddPositionalEmbedding(nn.Module):
    """
    Add a learnable positional embedding to every token, the [CLS] slot at index 0 included.

    Input:  [B, N+1, D]
    Output: [B, N+1, D]
    """
    def __init__(self, num_patches: int, embed_dim: int):
        super().__init__()
        self.num_tokens = num_patches + 1   # one extra position for the [CLS] token
        self.embed_dim = embed_dim
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_tokens, embed_dim))
        nn.init.normal_(self.pos_embed, std=0.02)

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        _, tokens, dim = z.shape
        if (tokens, dim) != (self.num_tokens, self.embed_dim):
            raise ValueError(f"Expected [B, {self.num_tokens}, {self.embed_dim}], got {tuple(z.shape)}")
        # pos_embed is [1, T, D] and broadcasts across the batch dimension.
        return z + self.pos_embed

Combine `PatchProjection`, `AddCLSToken` and `AddPositionalEmbedding` into the full tokenizer:

In [None]:
class ViTTokenizer(nn.Module):
    """
    Build the ViT input sequence from images in three stages:
    patch embedding -> prepend [CLS] -> add positional embeddings.

    Input:  [B, C, H, W]
    Output: [B, N+1, D]
    """
    def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, embed_dim: int = 192):
        super().__init__()
        self.patch_proj = PatchProjection(img_size, patch_size, in_chans, embed_dim)
        self.add_cls = AddCLSToken(embed_dim)
        # Sized from the projection so the positional table always matches N + 1 tokens.
        self.add_pos = AddPositionalEmbedding(self.patch_proj.num_patches, embed_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # [B, C, H, W] -> [B, N, D] -> [B, N+1, D] -> [B, N+1, D]
        return self.add_pos(self.add_cls(self.patch_proj(x)))

Define our multi-head self-attention module:

In [None]:
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head scaled dot-product self-attention with separate Q/K/V projections.

    Input:  [B, T, D]
    Output: [B, T, D]
    """

    def __init__(self, embed_dim: int, num_heads: int, attn_dropout_value: float = 0.0, proj_dropout_value: float = 0.0):
        super().__init__()

        if embed_dim % num_heads:
            raise ValueError("embed_dim must be divisible by num_heads")

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5          # 1 / sqrt(head_dim)

        # Layer creation order is kept stable: each nn.Linear draws from the RNG.
        self.w_q = nn.Linear(embed_dim, embed_dim)
        self.w_k = nn.Linear(embed_dim, embed_dim)
        self.w_v = nn.Linear(embed_dim, embed_dim)
        self.attn_drop = nn.Dropout(attn_dropout_value)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(proj_dropout_value)

        # Xavier-uniform weights, zero biases.
        for layer in (self.w_q, self.w_k, self.w_v, self.proj):
            nn.init.xavier_uniform_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, z: torch.Tensor, return_attn: bool = False):
        """
        Args:
            z: [B, T, D]
            return_attn: when True, also return the attention weights [B, h, T, T]
                (taken after attention dropout).
        Returns:
            [B, T, D], or the tuple ([B, T, D], [B, h, T, T]) when return_attn is True.
        """
        batch, tokens, dim = z.shape
        if dim != self.embed_dim:
            raise ValueError(f"Expected embed_dim={self.embed_dim}, got {dim}")

        def to_heads(t: torch.Tensor) -> torch.Tensor:
            # [B, T, D] -> [B, T, h, head_dim] -> [B, h, T, head_dim]
            return t.view(batch, tokens, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = (to_heads(layer(z)) for layer in (self.w_q, self.w_k, self.w_v))

        attn = (q @ k.transpose(-2, -1) * self.scale).softmax(dim=-1)   # [B, h, T, T]
        attn = self.attn_drop(attn)

        # Merge heads back: [B, h, T, head_dim] -> [B, T, h, head_dim] -> [B, T, D]
        out = (attn @ v).transpose(1, 2).reshape(batch, tokens, dim)
        out = self.proj_drop(self.proj(out))

        return (out, attn) if return_attn else out

Define our multi-layer perceptron (MLP):

In [None]:
class MLP(nn.Module):
    """
    Token-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The hidden width is embed_dim * mlp_hidden_mult.
    Input:  [B, T, D]
    Output: [B, T, D]
    """

    def __init__(self, embed_dim: int, mlp_hidden_mult: int = 4, dropout_value: float = 0.0):
        super().__init__()

        hidden_dim = embed_dim * mlp_hidden_mult
        self.w1 = nn.Linear(embed_dim, hidden_dim)
        self.act = nn.GELU()
        self.drop1 = nn.Dropout(dropout_value)
        self.w2 = nn.Linear(hidden_dim, embed_dim)
        self.drop2 = nn.Dropout(dropout_value)

        # Xavier-uniform weights, zero biases (same order as the layers).
        for layer in (self.w1, self.w2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, z: torch.Tensor):
        hidden = self.drop1(self.act(self.w1(z)))
        return self.drop2(self.w2(hidden))

Combine `MultiHeadSelfAttention` and `MLP` into a `TransformerEncoderBlock`:

In [None]:
class TransformerEncoderBlock(nn.Module):
    """
    Pre-LN Transformer encoder block:
        y = x + MSA(LN(x))
        z = y + MLP(LN(y))
    Input:  [B, T, D]
    Output: [B, T, D]

    Args:
        embed_dim: token width D.
        num_heads: number of attention heads (must divide embed_dim).
        mlp_hidden_mult: MLP hidden width as a multiple of embed_dim.
        attn_dropout_value: dropout applied to the attention weights.
        proj_dropout_value: dropout after the attention output projection.
        mlp_dropout_value: dropout inside the MLP. Defaults to 0.0 like the other
            rates; note that a rate of 1.0 would make nn.Dropout zero the whole MLP
            branch in training mode.

    Raises:
        ValueError (from forward): if the input is not 3D or its last dim != embed_dim.
    """

    def __init__(self, embed_dim: int, num_heads: int, mlp_hidden_mult: int = 4, attn_dropout_value: float = 0.0, proj_dropout_value: float = 0.0, mlp_dropout_value: float = 0.0):
        super().__init__()

        self.embed_dim = embed_dim

        self.norm1 = nn.LayerNorm(embed_dim)
        self.msa = MultiHeadSelfAttention(embed_dim, num_heads, attn_dropout_value, proj_dropout_value)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = MLP(embed_dim, mlp_hidden_mult, mlp_dropout_value)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 3:
            raise ValueError(f"Expected 3D tensor [B, T, D], got shape {tuple(x.shape)}")
        dim = x.shape[-1]
        if dim != self.embed_dim:
            raise ValueError(f"Expected embed_dim={self.embed_dim}, got {dim}")

        y = x + self.msa(self.norm1(x))      # attention sub-layer, pre-LN residual
        return y + self.mlp(self.norm2(y))   # MLP sub-layer, pre-LN residual

Create classification head - linear:

In [None]:
class ClassificationHeadLinear(nn.Module):
    """
    Single linear layer that maps a pooled representation to class logits.

    Input:  [B, D]
    Output: [B, num_classes]
    """

    def __init__(self, embed_dim: int, num_classes: int):
        super().__init__()
        self.embed_dim, self.num_classes = embed_dim, num_classes
        self.w1 = nn.Linear(embed_dim, num_classes)
        # Xavier-uniform weights, zero bias.
        nn.init.xavier_uniform_(self.w1.weight)
        nn.init.zeros_(self.w1.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 2:
            raise ValueError(f"Expected 2D tensor [B, D], got shape {tuple(x.shape)}")
        feat_dim = x.shape[1]
        if feat_dim != self.embed_dim:
            raise ValueError(f"Expected embed_dim={self.embed_dim}, got {feat_dim}")
        return self.w1(x)

Create classification head - MLP:

In [None]:
class ClassificationHeadMLP(nn.Module):
    """
    Two-layer classification head: Linear -> Tanh -> Linear.

    The hidden width is embed_dim * mlp_hidden_mult.
    Input:  [B, D]
    Output: [B, num_classes]
    """

    def __init__(self, embed_dim: int, num_classes: int, mlp_hidden_mult: int = 4):
        super().__init__()

        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.hidden = embed_dim * mlp_hidden_mult

        self.w1 = nn.Linear(embed_dim, self.hidden)
        self.act = nn.Tanh()
        self.w2 = nn.Linear(self.hidden, num_classes)

        # Xavier-uniform weights, zero biases (same order as the layers).
        for layer in (self.w1, self.w2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.dim() != 2:
            raise ValueError(f"Expected 2D tensor [B, D], got shape {tuple(x.shape)}")
        _, feat_dim = x.shape
        if feat_dim != self.embed_dim:
            raise ValueError(f"Expected embed_dim={self.embed_dim}, got {feat_dim}")
        hidden = self.act(self.w1(x))
        return self.w2(hidden)

Combine everything into the full Vision Transformer model:

In [None]:
class VisionTransformer(nn.Module):
    """
    Vision Transformer (ViT)
    Paper-faithful ViT:
      - Tokenizer: patch projection + CLS + learnable positional embeddings
      - Stack of `depth` TransformerEncoderBlock modules (Pre-LN inside each block)
      - Final: y = LN(z_L^0)  (LayerNorm on the CLS token only, Eq. 4)
      - Classification head on top of y (linear by default; MLP optional)

    Input:  [B, C, H, W]
    Output: [B, num_classes]

    Args:
        mlp_dropout_value: dropout inside every block's MLP; defaults to 0.0 like the
            other dropout rates (a rate of 1.0 would zero the MLP branch in training).
        head_type: "linear" or "mlp".

    Raises:
        ValueError: if head_type is not "linear" or "mlp" (checked before any
            submodule is allocated).
    """

    def __init__(self, img_size: int = 224, patch_size: int = 16, in_chans: int = 3, num_classes: int = 1000, embed_dim: int = 192, depth: int = 6, num_heads: int = 3, mlp_hidden_mult: int = 4, attn_dropout_value: float = 0.0, proj_dropout_value: float = 0.0, mlp_dropout_value: float = 0.0, head_type: str = "linear"):
        super().__init__()

        # Fail fast on a bad head type instead of after building the whole encoder.
        if head_type not in ("linear", "mlp"):
            raise ValueError(f"Unknown head_type: {head_type!r}. Use 'linear' or 'mlp'.")

        self.embed_dim = embed_dim
        self.num_classes = num_classes
        self.head_type = head_type

        self.tokenizer = ViTTokenizer(img_size, patch_size, in_chans, embed_dim)

        self.blocks = nn.ModuleList([
            TransformerEncoderBlock(embed_dim, num_heads, mlp_hidden_mult, attn_dropout_value, proj_dropout_value, mlp_dropout_value)
            for _ in range(depth)
        ])

        self.final_ln = nn.LayerNorm(embed_dim)

        if head_type == "linear":
            self.head = ClassificationHeadLinear(embed_dim, num_classes)
        else:
            self.head = ClassificationHeadMLP(embed_dim, num_classes, mlp_hidden_mult)

    def _encode(self, x: torch.Tensor) -> torch.Tensor:
        """Run tokenizer + encoder stack and return y = LN(z_L^0) with shape [B, D]."""
        z = self.tokenizer(x)
        for block in self.blocks:
            z = block(z)
        return self.final_ln(z[:, 0, :])   # keep only the CLS token (index 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(self._encode(x))

    @torch.no_grad()
    def extract_cls(self, x: torch.Tensor) -> torch.Tensor:
        """
        Returns the normalized CLS representation y = LN(z_L^0) [B, D] without the head.

        Runs under torch.no_grad() but does not switch train/eval mode: dropout stays
        active if the model is in training mode, so call .eval() first for
        deterministic features.
        """
        return self._encode(x)

### Data Load and Preprocessing

Load the **CIFAR-10** dataset, apply image preprocessing and augmentation (using `torchvision.transforms`), and wrap the train and validation splits in `DataLoader` objects:

In [None]:
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD  = (0.2470, 0.2435, 0.2616)

def get_cifar10_dataloaders(batch_size=128, num_workers=2, img_size=224, root='./data/cifar10'):
    """
    Build CIFAR-10 train/validation DataLoaders with images resized to img_size.

    Training: RandomResizedCrop + horizontal flip + RandAugment(2, 9) + normalize.
    Validation: Resize + CenterCrop + normalize (deterministic).

    Args:
        batch_size: samples per batch for both loaders.
        num_workers: DataLoader worker processes (0 loads in the main process).
        img_size: output side length in pixels.
        root: dataset directory; CIFAR-10 is downloaded there if missing.

    Returns:
        (train_dataloader, val_dataloader); only the training loader shuffles.
    """
    normalize = transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)

    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        normalize,
    ])

    val_transform = transforms.Compose([
        transforms.Resize(img_size),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        normalize,
    ])

    train_ds = torchvision.datasets.CIFAR10(root=root, train=True, download=True, transform=train_transform)
    val_ds = torchvision.datasets.CIFAR10(root=root, train=False, download=True, transform=val_transform)

    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        # Pinned host memory is what lets `.to(device, non_blocking=True)` copies
        # overlap with compute; it only matters when a CUDA device is used.
        pin_memory=torch.cuda.is_available(),
        # Reuse worker processes across epochs instead of re-spawning them each epoch
        # (only valid when workers exist).
        persistent_workers=num_workers > 0,
    )
    train_dataloader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
    val_dataloader = DataLoader(val_ds, shuffle=False, **loader_kwargs)

    return train_dataloader, val_dataloader

train_dataloader, val_dataloader = get_cifar10_dataloaders()
# Report how many batches each split yields with the default batch size.
for _split, _loader in (("training", train_dataloader), ("validation", val_dataloader)):
    print(f"Number of {_split} batches: {len(_loader)}")
del _split, _loader

### Training Loop and Metrics

Create an instance of the ViT model for training:

In [None]:
num_classes = 10
# Small ViT: 224 / 16 = 14 -> 14 x 14 = 196 patches, D=128, 6 blocks,
# 4 heads (head_dim = 128 / 4 = 32), MLP dropout 0.1, linear head.
vit_model = VisionTransformer(
    img_size=224, patch_size=16, in_chans=3,
    num_classes=num_classes,
    embed_dim=128, depth=6, num_heads=4, mlp_hidden_mult=4,
    attn_dropout_value=0.0, proj_dropout_value=0.0, mlp_dropout_value=0.1,
    head_type="linear",
).to(device)

print(f"Total number of parameters: {sum(param.numel() for param in vit_model.parameters())}")

Set up an optimizer (AdamW, i.e. Adam with decoupled weight decay) and a learning-rate schedule with linear `warmup` followed by `cosine` decay; the warmup keeps training stable during the first steps.

In [None]:
def build_warnup_cosine(optimizer, warup_steps, total_steps, min_lr=0.0):
    """
    Linear-warmup + cosine-decay learning-rate schedule, wrapped in a LambdaLR.

    (The `warnup` / `warup_steps` spellings are kept so existing callers keep working.)

    For step s (number of scheduler.step() calls so far) the LR multiplier is:
      - s <  warup_steps: (s + 1) / warup_steps                   (linear warmup)
      - s >= warup_steps: m + (1 - m) * 0.5 * (1 + cos(pi * p))    (cosine decay)
        with p = (s - warup_steps) / (total_steps - warup_steps) clamped to [0, 1]
        and m = min_lr / base_lr, so the LR lands exactly on min_lr and stays there.

    Every param group gets its own multiplier computed from its own base LR: the custom
    "lr_initial" key if present, else LambdaLR's "initial_lr", else the current "lr".

    Args:
        optimizer: optimizer whose param groups are scheduled.
        warup_steps: number of warmup steps (0 disables warmup).
        total_steps: step at which the cosine reaches min_lr; later steps hold min_lr.
        min_lr: absolute final learning rate for every group.

    Returns:
        torch.optim.lr_scheduler.LambdaLR
    """
    def make_lr_lambda(base_lr):
        # Constant for the whole run, so compute it once instead of every step.
        min_factor = min_lr / max(1e-8, base_lr)

        def lr_lambda(step):
            if step < warup_steps:
                return max(1e-8, float(step + 1) / float(max(1, warup_steps)))
            progress = float(step - warup_steps) / float(max(1, total_steps - warup_steps))
            # Clamp: past total_steps the cosine would otherwise climb back up.
            progress = min(1.0, max(0.0, progress))
            cosine = 0.5 * (1.0 + math.cos(math.pi * progress))
            return min_factor + (1.0 - min_factor) * cosine

        return lr_lambda

    lr_lambdas = [
        make_lr_lambda(group.get("lr_initial", group.get("initial_lr", group["lr"])))
        for group in optimizer.param_groups
    ]
    return LambdaLR(optimizer, lr_lambdas)

# Training hyper-parameters (the `epoches` spelling is kept: the final training cell reads it).
epoches = 30
base_lr = 5e-4
weight_decay = 0.05
warmup_epochs = 1

optimizer = AdamW(vit_model.parameters(), lr=base_lr, weight_decay=weight_decay)
# Record each group's starting LR under the custom "lr_initial" key (read by build_warnup_cosine).
for group in optimizer.param_groups:
    group["lr_initial"] = group["lr"]

# Schedule lengths are expressed in optimizer steps, i.e. batches.
steps_per_epoch = len(train_dataloader)
total_steps = epoches * steps_per_epoch
warmup_steps = warmup_epochs * steps_per_epoch
scheduler = build_warnup_cosine(optimizer, warmup_steps, total_steps, min_lr=1e-5)

Set up the accuracy metric function:

In [None]:
def accuracy(logits, targets, topk=(1,)):
    """
    Top-k classification accuracy for a single batch.

    Args:
        logits: class scores, ranked along dim 1.
        targets: integer labels, one per row of `logits`.
        topk: k values to evaluate.

    Returns:
        list of floats, one per k in `topk` order; each is a fraction in [0, 1]
        (not a percentage).
    """
    with torch.no_grad():
        _, best = logits.topk(max(topk), dim=1)      # indices of the maxk best classes per row
        ranked = best.t()                            # [maxk, B]
        hits = ranked.eq(targets.view(1, -1).expand_as(ranked))
        n = targets.size(0)
        return [(hits[:k].reshape(-1).float().sum(0) / n).item() for k in topk]

Define the training loop for a single epoch, plus the evaluation loop:

In [None]:
def train_one_epoch(model, data_loader, optimizer, scheduler, scaler, epoch, log_interval=50):
    """
    Train `model` for one pass over `data_loader`.

    Each step: move the batch to the module-level `device`, forward under autocast
    (enabled only when CUDA is available), label-smoothed (0.1) cross-entropy, scaled
    backward, unscale + clip the global grad norm to 1.0, optimizer step through the
    GradScaler, then one scheduler step.

    Args:
        model: network mapping a batch of images to logits.
        data_loader: iterable of (images, targets) batches.
        optimizer: optimizer over the model's parameters.
        scheduler: LR scheduler, stepped once per batch.
        scaler: GradScaler (may be disabled, e.g. on CPU).
        epoch: epoch number, used only in log lines.
        log_interval: print running metrics every `log_interval` steps.

    Returns:
        (mean_loss, mean_acc, seconds): loss and top-1 accuracy (fraction in [0, 1])
        averaged over samples, and the wall-clock duration of the epoch. Loss and
        accuracy are 0.0 for an empty loader.
    """
    model.train()
    t0 = time.time()
    loss_sum, correct_sum, seen = 0.0, 0.0, 0

    for step, (images, targets) in enumerate(data_loader):
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)

        with autocast(enabled=torch.cuda.is_available()):
            logits = model(images)
            loss = F.cross_entropy(logits, targets, label_smoothing=0.1)

        scaler.scale(loss).backward()
        # The gradients are still multiplied by the loss scale at this point; unscale
        # them so the max-norm threshold applies to the true gradients.
        # (unscale_ is a no-op when the scaler is disabled.)
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # Weight by batch size so a short final batch does not skew the averages.
        batch_size = targets.size(0)
        loss_sum += loss.item() * batch_size
        correct_sum += accuracy(logits, targets, topk=(1,))[0] * batch_size
        seen += batch_size

        if (step + 1) % log_interval == 0:
            avg_loss = loss_sum / max(1, seen)
            avg_acc = correct_sum / max(1, seen)
            lr_now = optimizer.param_groups[0]["lr"]
            print(f"Epoch {epoch} | Step {step+1}/{len(data_loader)} | lr={lr_now:.6f} | Loss={avg_loss:.4f} | Acc={avg_acc:.4f}")

    dt = time.time() - t0
    denom = max(1, seen)
    return loss_sum / denom, correct_sum / denom, dt

@torch.no_grad()
def evaluate(model, data_loader):
    """
    Evaluate `model` on `data_loader` in eval mode, without gradient tracking.

    Batches are moved to the module-level `device`. The loss is plain cross-entropy
    (no label smoothing), so it is not directly comparable with the label-smoothed
    training loss.

    Returns:
        (mean_loss, mean_acc): averaged over samples rather than batches, so a short
        final batch is weighted correctly; accuracy is a top-1 fraction in [0, 1].
        Both are 0.0 for an empty loader.
    """
    model.eval()
    loss_sum, correct_sum, seen = 0.0, 0.0, 0

    for images, targets in data_loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)
        logits = model(images)

        batch_size = targets.size(0)
        # reduction="sum" so the final division gives a per-sample mean.
        loss_sum += F.cross_entropy(logits, targets, reduction="sum").item()
        correct_sum += accuracy(logits, targets)[0] * batch_size
        seen += batch_size

    denom = max(1, seen)
    return loss_sum / denom, correct_sum / denom

Define the full training loop:

In [None]:
def train_full_loop(model, train_dataloader, val_dataloader, optimizer, num_epochs, lr_scheduler=None):
    """
    Run `num_epochs` epochs of training followed by validation, and return the history.

    Args:
        model: network to train.
        train_dataloader: training batches, passed to train_one_epoch.
        val_dataloader: validation batches, passed to evaluate.
        optimizer: optimizer over the model's parameters.
        num_epochs: number of epochs to run.
        lr_scheduler: scheduler stepped once per training batch. When None, the
            module-level `scheduler` is used (this function's original behavior).

    Returns:
        dict with lists "train_loss", "train_acc", "val_loss", "val_acc", one entry
        per epoch; accuracies are fractions in [0, 1].
    """
    sched = scheduler if lr_scheduler is None else lr_scheduler
    scaler = GradScaler(enabled=torch.cuda.is_available())

    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    best_val_acc = 0.0

    for epoch in range(1, num_epochs + 1):
        train_loss, train_acc, train_dt = train_one_epoch(model, train_dataloader, optimizer, sched, scaler, epoch)
        val_loss, val_acc = evaluate(model, val_dataloader)

        best_val_acc = max(best_val_acc, val_acc)

        # Accuracies are fractions; the `%` format spec multiplies by 100 and appends "%".
        print(f"[Epoch {epoch}] train: loss={train_loss:.4f}, acc={train_acc:.2%} | val: loss={val_loss:.4f}, acc={val_acc:.2%} | time={train_dt:.1f}s")

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

    print(f"Best val acc: {best_val_acc:.2%}")

    return history

In [None]:
# Launch training; returns per-epoch loss/accuracy lists for both splits.
history = train_full_loop(
    vit_model,
    train_dataloader,
    val_dataloader,
    optimizer,
    num_epochs=epoches,
)