# Implementing Vision Transformers (ViT) using JAX

### By Addison Kalanther, Anirudh Rengarajan, Jake Bringetto, and Naasir Farooqi

## Introduction

Vision Transformers (ViT) are a group of models that use a pure transformer architecture to perform vision tasks such as image classification and object detection. In this homework, we are going to compare a classical CNN model to a ViT model and examine how ViTs have been able to obtain state-of-the-art results in image classification. This will be done using JAX, a machine learning framework from Google with increasing usage in research environments, and Equinox, an easy-to-use JAX library for implementing PyTorch-like neural networks.

The objective of this homework is to allow students to become familiar with Vision Transformers, as well as their advantages and disadvantages compared to the more common CNN models. In addition, students will become familiar with JAX, better preparing them for research and academia if they wish to pursue it.

## Vision Transformers

ViTs are a class of models that rely only on attention to perform their target tasks, without the convolutions that CNNs use to exploit spatial locality and translation invariance. Despite lacking these built-in assumptions, vision transformers have been found to outperform CNNs when pretrained on sufficiently large datasets.

ViTs work by splitting an input image into fixed-size patches, linearly embedding each patch, prepending a learnable classification ([CLS]) token, and adding positional embeddings. The resulting sequence of patch embeddings is then processed much like the sequence of word embeddings in a standard transformer. The diagram below visualizes this process.
<br><br>
<img src="https://github.com/google-research/vision_transformer/blob/main/vit_figure.png?raw=true">
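As a minimal sketch of the patch-splitting step, assuming a single channels-first image and plain `jax.numpy` reshapes in place of einops, the hypothetical `patchify` function below turns an image of shape `(C, H, W)` into a sequence of flattened patches, one row per patch. It uses the same `(s1 s2 c)` ordering within each patch as the einops pattern `'b c (h s1) (w s2) -> b (h w) (s1 s2 c)'` used later in this notebook.

```python
import jax
import jax.numpy as jnp


def patchify(img: jax.Array, patch_size: int) -> jax.Array:
    """Split a (C, H, W) image into flattened, non-overlapping patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    ordered row-major over the grid of patches.
    """
    c, h, w = img.shape
    assert h % patch_size == 0 and w % patch_size == 0
    gh, gw = h // patch_size, w // patch_size
    # (C, gh, s1, gw, s2): split each spatial axis into (grid, within-patch)
    x = img.reshape(c, gh, patch_size, gw, patch_size)
    # (gh, gw, s1, s2, C): grid axes first, channels last
    x = x.transpose(1, 3, 2, 4, 0)
    # (gh * gw, s1 * s2 * C): one flattened vector per patch
    return x.reshape(gh * gw, patch_size * patch_size * c)


# A CIFAR-10-sized image with 16x16 patches gives a 2x2 grid of patches.
img = jnp.arange(3 * 32 * 32, dtype=jnp.float32).reshape(3, 32, 32)
patches = patchify(img, patch_size=16)
print(patches.shape)  # (4, 768)
```

Each 768-dimensional row is then mapped by a learned linear layer to the embedding size, the [CLS] token is prepended, and a positional embedding is added, giving `(32 // 16) ** 2 + 1 = 5` tokens per image.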

## Setup Code

In [None]:
%pip install einops equinox optax

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
Collecting equinox
  Downloading equinox-0.9.2-py3-none-any.whl (88 kB)
Collecting optax
  Downloading optax-0.1.4-py3-none-any.whl (154 kB)
Collecting jaxtyping>=0.2.5
  Downloading jaxtyping-0.2.8-py3-none-any.whl (17 kB)
Collecting typeguard>=2.13.3
  Downloading typeguard-2.13.3-py3-none-any.whl (17 kB)
Collecting chex>=0.1.5
  Downloading chex-0.1.5-py3-none-any.whl (85 kB)
Installing collected packages: typeguard, jaxtyping, chex, optax, equinox, einops
  Attempting uninstall: typeguard
    Found existing installation: typeguard 2.7.1
    Uninstalling typeguard-2

In [None]:
### Import necessary libraries and prepare dataset ###

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import einops

from PIL import Image
from torch import nn, Tensor, optim
from einops.layers.torch import Rearrange, Reduce

import equinox as eqx
import equinox.nn as enn

import jax
import jax.nn as jnn
import jax.numpy as jnp
import optax

from typing import Optional, Any

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

training_set = torchvision.datasets.CIFAR10('./', train=True, download=True, transform=transform)
training_dataloader = torch.utils.data.DataLoader(training_set, batch_size=128, shuffle=True, num_workers=2)

test_set = torchvision.datasets.CIFAR10('./', train=False, download=True, transform=transform)
test_dataloader = torch.utils.data.DataLoader(test_set, batch_size=128, shuffle=False, num_workers=2)

Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to ./cifar-10-python.tar.gz



Extracting ./cifar-10-python.tar.gz to ./
Files already downloaded and verified
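`transforms.ToTensor()` scales pixel values to `[0, 1]`, and `transforms.Normalize(mean, std)` then computes `(x - mean) / std` per channel, so with a mean and standard deviation of 0.5 every channel ends up in `[-1, 1]`. A quick NumPy check of that arithmetic follows; the `normalize` helper is hypothetical and does not call torchvision.

```python
import numpy as np


def normalize(x: np.ndarray, mean, std) -> np.ndarray:
    """Per-channel (x - mean) / std for a channels-first (C, H, W) array."""
    mean = np.asarray(mean, dtype=x.dtype).reshape(-1, 1, 1)
    std = np.asarray(std, dtype=x.dtype).reshape(-1, 1, 1)
    return (x - mean) / std


# A 3-channel, 1x3 "image" holding black, mid-gray, and white after ToTensor()
x = np.array([0.0, 0.5, 1.0], dtype=np.float32).reshape(1, 1, 3).repeat(3, axis=0)
out = normalize(x, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
print(out[0, 0])  # [-1.  0.  1.]
```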


## Creating a ViT via PyTorch

Before implementing a ViT using JAX, students will become familiar with ViTs by implementing one in the familiar PyTorch framework together with einops, a library for readable tensor manipulation.

As a performance baseline and a reference for later parts of this homework, we have implemented LeNet-5 in PyTorch for you.

In [None]:
### LeNet via PyTorch ###

class LeNet(nn.Module):
  def __init__(self):
    super().__init__()

    self.conv1 = nn.Conv2d(3, 6, 5)
    self.pool = nn.MaxPool2d(2, 2)
    self.conv2 = nn.Conv2d(6, 16, 5)
    self.fc1 = nn.Linear(16 * 5 * 5, 120)
    self.fc2 = nn.Linear(120, 84)
    self.fc3 = nn.Linear(84, 10)

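  def forward(self, x: Tensor) -> Tensor:
    x = self.pool(F.relu(self.conv1(x)))
    x = self.pool(F.relu(self.conv2(x)))
    x = x.view(-1, 16 * 5 * 5) # flatten
    x = F.relu(self.fc1(x))
    x = F.relu(self.fc2(x))
    # Return raw logits: nn.CrossEntropyLoss applies log-softmax internally,
    # so an extra softmax here would be applied twice and slow down learning
    x = self.fc3(x)
    return x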

def test_pt(model):
  correct = 0
  total = 0

  model.eval()
  with torch.no_grad(): # gradients are not needed for evaluation
    for data in test_dataloader:
      inputs, labels = data
      inputs = inputs.to(device)
      labels = labels.to(device)

      outputs = model(inputs)
      _, predicted = torch.max(outputs, 1)
      total += labels.size(0)
      correct += (predicted == labels).sum().item()

  return correct / total

def train_pt(model, learning_rate, momentum, epochs, verbose=False):
  criterion = nn.CrossEntropyLoss()
  optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum)

  for epoch in range(epochs):
    model.train() # test_pt() puts the model in eval mode, so switch back before training
    for i, data in enumerate(training_dataloader, 0):
      inputs, labels = data
      inputs = inputs.to(device)
      labels = labels.to(device)

      optimizer.zero_grad()

      outputs = model(inputs)
      loss = criterion(outputs, labels)
      loss.backward()
      optimizer.step()

      if verbose and i % 20 == 0:
        print('Batch: {}, Loss: {:.3f}'.format(i, loss.item()))

    acc = test_pt(model)
    print('Epoch: {}, Loss: {:.3f}, Test Accuracy: {:.3f}'.format(epoch, loss.item(), acc))
  return acc

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [None]:
### TODO: Tune hyperparameters such that LeNet-5 achieves an accuracy of over 45% ###
learning_rate_lenet = 0.02 
momentum_lenet = 0.96
epochs_lenet = 5

lenet = LeNet().to(device)
train_pt(lenet, learning_rate_lenet, momentum_lenet, epochs_lenet)
print('Final Accuracy: {:.3f}'.format(test_pt(lenet)))

Epoch: 0, Loss: 2.195, Test Accuracy: 0.242
Epoch: 1, Loss: 2.105, Test Accuracy: 0.337
Epoch: 2, Loss: 2.061, Test Accuracy: 0.409
Epoch: 3, Loss: 2.037, Test Accuracy: 0.439
Epoch: 4, Loss: 2.027, Test Accuracy: 0.455
Final Accuracy: 0.455
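The `16 * 5 * 5` input size of `fc1` follows from the usual output-size arithmetic for a convolution or pooling layer, `out = (in + 2 * padding - kernel) // stride + 1`. The small hypothetical helper below traces a 32x32 CIFAR-10 image through `LeNet`; the same numbers apply to the Equinox version later in this notebook.

```python
def out_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a conv or pooling layer (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 32                                  # CIFAR-10 images are 32x32
size = out_size(size, kernel=5)            # conv1, 5x5 kernel -> 28
size = out_size(size, kernel=2, stride=2)  # pool, 2x2 stride 2 -> 14
size = out_size(size, kernel=5)            # conv2, 5x5 kernel -> 10
size = out_size(size, kernel=2, stride=2)  # pool, 2x2 stride 2 -> 5
print(16 * size * size)  # 400, i.e. 16 * 5 * 5
```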


In [None]:
### TODO: Implement a ViT in PyTorch ###
class MultiHeadAttention(nn.Module):
  def __init__(self, emb_size=768, num_heads=8):
    super().__init__()
    self.emb_size = emb_size
    self.num_heads = num_heads

    self.qkv = nn.Linear(emb_size, emb_size * 3)
    self.fc1 = nn.Linear(emb_size, emb_size)

  def forward(self, x: Tensor, mask: Tensor = None) -> Tensor:
    qkv = einops.rearrange(self.qkv(x), 'b n (h d qkv) -> (qkv) b h n d', h = self.num_heads, qkv = 3)
    queries, keys, values = qkv[0], qkv[1], qkv[2]

    # Scale the scores by sqrt(head_dim) *before* the softmax (scaled dot-product attention)
    scaling = (self.emb_size // self.num_heads) ** (1/2)
    dot_product = torch.einsum('bhqd, bhkd -> bhqk', queries, keys) / scaling
    if mask is not None:
      # mask is True where attention is allowed; masked_fill is not in-place
      fill_value = torch.finfo(dot_product.dtype).min
      dot_product = dot_product.masked_fill(~mask, fill_value)

    att = F.softmax(dot_product, dim=-1)
    out = torch.einsum('bhal, bhlv -> bhav', att, values)
    out = einops.rearrange(out, "b h n d -> b n (h d)")
    out = self.fc1(out)
    return out


class Embedding(nn.Module):
  def __init__(self, patch_size=16, in_channels=3, emb_size=768, img_size=224):
    super().__init__()
    self.patch_size = patch_size

    # Patch Embedding
    self.patch_embedding = nn.Sequential(
        Rearrange('b c (h s1) (w s2) -> b (h w) (s1 s2 c)', s1=patch_size, s2=patch_size),
        nn.Linear(patch_size * patch_size * in_channels, emb_size)
    )

    # CLS Token
    self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))

    # Position Embedding
    self.positions = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

  def forward(self, x: Tensor) -> Tensor:
    # Patch embedding
    x = self.patch_embedding(x)
    # Prepend CLS token (Hint: Einops may be useful here)
    x = torch.cat([einops.repeat(self.cls_token, '() n e -> b n e', b=x.shape[0]), x], dim=1)
    # Add position embeddings
    x += self.positions

    return x

class Encoder(nn.Module):
  def __init__(self, emb_size=768, mlp_exp=4):
    super().__init__()

    # Layernorms (one per residual block, as in the original ViT)
    self.norm1 = nn.LayerNorm(emb_size)
    self.norm2 = nn.LayerNorm(emb_size)

    # Multi-head Attention
    self.mha = MultiHeadAttention(emb_size)

    # MLP (Hint: Use Sequential Layers with a GELU)
    self.mlp = nn.Sequential(
        nn.Linear(emb_size, mlp_exp * emb_size),
        nn.GELU(),
        nn.Linear(mlp_exp * emb_size, emb_size)
    )

  def forward(self, x: Tensor) -> Tensor:
    # Residual block 1
    res = x
    # Apply Layernorm
    x = self.norm1(x)
    # Apply Multi-head attention
    x = self.mha(x)
    # Add residual
    x = res + x

    # Residual block 2
    res = x
    # Apply Layernorm
    x = self.norm2(x)
    # Apply MLP
    x = self.mlp(x)
    # Add residual
    x = res + x

    return x

class ClassificationHead(nn.Sequential):
  def __init__(self, emb_size=768, n_classes = 10):
    super().__init__(
        Reduce('b n e -> b e', reduction='mean'),
        nn.LayerNorm(emb_size),
        nn.Linear(emb_size, n_classes)
    )

class ViT(nn.Sequential):
  def __init__(self, img_size=32, in_channels=3, n_classes=10, patch_size=16, emb_size=768, depth=1, enc_mlp_exp=4, **kwargs):
    super().__init__(
        Embedding(patch_size, in_channels, emb_size, img_size), 
        *[Encoder(emb_size, enc_mlp_exp) for _ in range(depth)],
        ClassificationHead(emb_size, n_classes)
    )
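For reference, each attention head computes $\mathrm{softmax}(QK^\top / \sqrt{d})\,V$, where $d$ is the per-head dimension (`emb_size // num_heads`, i.e. 96 with the defaults above). Below is a minimal single-head sketch in JAX, assuming unbatched `(n, d)` inputs and a boolean `mask` that is `True` where attention is allowed (the same convention as `masked_fill(~mask, ...)` in the PyTorch class). The function name and the causal example mask are illustrative only; ViT itself normally runs without a mask. Because the scaling is applied to the scores before the softmax, each row of attention weights still sums to 1.

```python
from typing import Optional

import jax
import jax.numpy as jnp


def scaled_dot_product_attention(q, k, v, mask: Optional[jax.Array] = None):
    """Single-head attention for unbatched inputs.

    q: (n_q, d), k: (n_k, d), v: (n_k, d_v); mask: (n_q, n_k) booleans,
    True where a query may attend to a key.
    """
    d = q.shape[-1]
    scores = q @ k.T / jnp.sqrt(d)  # (n_q, n_k), scaled before the softmax
    if mask is not None:
        scores = jnp.where(mask, scores, jnp.finfo(scores.dtype).min)
    weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1
    return weights @ v, weights


key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (5, 96))  # 5 tokens, head_dim = 768 // 8
k = jax.random.normal(key_k, (5, 96))
v = jax.random.normal(key_v, (5, 96))
causal = jnp.tril(jnp.ones((5, 5), dtype=bool))  # example mask: attend to earlier tokens only
out, weights = scaled_dot_product_attention(q, k, v, mask=causal)
print(out.shape)  # (5, 96)
```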


In [None]:
learning_rate_vit = 0.01
momentum_vit = 0.9 
epochs_vit = 5
encoder_depth = 5

vit = ViT(depth=encoder_depth).to(device)
acc = train_pt(vit, learning_rate_vit, momentum_vit, epochs_vit, verbose=False)

torch.Size([1, 1, 768])
torch.Size([1, 1, 768])


KeyboardInterrupt: ignored

## Getting familiar with JAX by writing a CNN


Before implementing a ViT in JAX, we will get familiar with the framework by first implementing a CNN. Please read the JAX 101 tutorial, which is also a useful reference throughout this notebook: https://jax.readthedocs.io/en/latest/jax-101/index.html
<br><br>
#### Quick TLDR
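JAX is a library for high-performance numerical computing. Its `jax.numpy` module mirrors the NumPy API, and composable function transformations such as `jax.jit` (just-in-time compilation via XLA), `jax.grad` (automatic differentiation), and `jax.vmap` (automatic vectorization) let the same code run efficiently on CPU, GPU, and TPU. `jax.numpy` can stand in for NumPy in most code, with a few important differences; for example, JAX arrays are immutable.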
<br><br>
For practice, we will implement LeNet-5 in JAX via Equinox. 
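A minimal illustration of the transformations mentioned in the TLDR, plus the immutable-array difference from NumPy. The toy `loss` function and its inputs are made up for this example.

```python
import jax
import jax.numpy as jnp


def loss(w: jax.Array, x: jax.Array) -> jax.Array:
    """A toy scalar function of w: mean of (x_i . w)^2 over the rows x_i of x."""
    return jnp.mean((x @ w) ** 2)


w = jnp.array([1.0, -2.0])
x = jnp.array([[1.0, 0.0], [1.0, 1.0]])

fast_loss = jax.jit(loss)                    # compiled with XLA on the first call
grad_loss = jax.jit(jax.grad(loss))          # gradient w.r.t. the first argument, w
row_dots = jax.vmap(lambda row: row @ w)(x)  # maps over the leading axis of x

print(fast_loss(w, x))  # 1.0
print(grad_loss(w, x))  # [ 0. -1.]
print(row_dots)         # [ 1. -1.]

# JAX arrays are immutable: `w[0] = 5.0` raises a TypeError,
# so an "update" returns a new array instead.
w2 = w.at[0].set(5.0)
print(w, w2)            # [ 1. -2.] [ 5. -2.]
```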

In [None]:
### Implementing LeNet in JAX via Equinox ###

class LeNetEqx(eqx.Module):
  features: eqx.Module
  classifier: eqx.Module

  def __init__(self, key: Optional["jax.random.PRNGKey"] = None):
    super().__init__()

    if key is None:
      key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, 5)

    self.features = enn.Sequential([
        enn.Conv2d(3, 6, kernel_size=5, key=keys[0]),
        enn.Lambda(jnn.relu),
        enn.MaxPool2d(kernel_size=2, stride=2),
        enn.Conv2d(6, 16, kernel_size=5, key=keys[1]),
        enn.Lambda(jnn.relu),
        enn.MaxPool2d(kernel_size=2, stride=2),
    ])

    self.classifier = enn.Sequential([
        enn.Linear(16 * 5 * 5, 120, key=keys[2]),
        enn.Lambda(jnn.relu),
        enn.Linear(120, 84, key=keys[3]),
        enn.Lambda(jnn.relu),
        enn.Linear(84, 10, key=keys[4]),
        
    ])

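  def __call__(self, x: jax.Array) -> jax.Array:
    # x is a single unbatched image of shape (3, 32, 32); batch with jax.vmap
    x = self.features(x)
    x = jnp.ravel(x) # flatten (16, 5, 5) -> (400,)
    x = self.classifier(x)
    # Return raw logits: optax.softmax_cross_entropy applies log-softmax itself
    return x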

  def __iter__(self):
    for attr, value in self.__dict__.items():
      yield attr, value
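Equinox layers such as `enn.Conv2d` and `enn.Linear` operate on a single, unbatched example, which is why `LeNetEqx` flattens with `jnp.ravel` and why the training code calls `jax.vmap(model)(x)`. The sketch below shows the same pattern with a hypothetical hand-written linear layer in plain JAX (no Equinox), so the batching behavior can be checked directly.

```python
import jax
import jax.numpy as jnp


def linear(params, x: jax.Array) -> jax.Array:
    """A hypothetical per-example linear layer: x has shape (in_features,)."""
    w, b = params
    return w @ x + b


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
# 400 inputs, matching LeNet's flattened 16 * 5 * 5 feature map
params = (jax.random.normal(key_w, (10, 400)), jnp.zeros(10))
batch = jax.random.normal(key_x, (128, 400))  # a batch of 128 flattened examples

# in_axes=(None, 0): share params across the batch, map over axis 0 of batch
batched_linear = jax.vmap(linear, in_axes=(None, 0))
logits = batched_linear(params, batch)
print(logits.shape)  # (128, 10)
```

The per-example function never mentions a batch dimension; `jax.vmap` adds it, which keeps model code simple and matches how the Equinox layers are written.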

In [None]:
### Building training and testing functions ###

def train_eqx(model, learning_rate, momentum, epochs, verbose=False):
  optimizer = optax.sgd(learning_rate, momentum=momentum)
  opt_state = optimizer.init(eqx.filter(model, eqx.is_array))

  @eqx.filter_value_and_grad # allows for (original value, gradient of function) to be returned
  def compute_loss(model, x, y):
    pred_y = jax.vmap(model)(x)
    y = jnn.one_hot(y, 10)
    softmax_ce = optax.softmax_cross_entropy(pred_y, y)
    return jnp.mean(softmax_ce)

  @eqx.filter_jit # allows for Just-in-Time compilation of Python function
  def make_step(model, x, y, opt_state):
    loss, grads = compute_loss(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss

  for epoch in range(epochs):
    for i, data in enumerate(training_dataloader, 0):
      inputs, labels = data

      # Convert the PyTorch CPU tensors to JAX arrays
      inputs = jnp.asarray(inputs.numpy())
      labels = jnp.asarray(labels.numpy())

      # JAX is functional: make_step returns a new model instead of mutating it
      model, opt_state, loss = make_step(model, inputs, labels, opt_state)

      if verbose and i % 20 == 0:
        print('Batch: {}, Loss: {:.3f}'.format(i, float(loss)))

    acc = test_eqx(model)
    print('Epoch: {}, Loss: {:.3f}, Test Accuracy: {:.3f}'.format(epoch, float(loss), acc))
  return acc

@eqx.filter_jit
def get_correct_total(model, inputs, labels):
  # labels must be an explicit argument: a closed-over variable would be
  # baked into the compiled function as a constant at trace time
  outputs = jax.vmap(model)(inputs)
  predicted = jnp.argmax(outputs, axis=1)
  total = labels.shape[0]
  correct = (predicted == labels).sum()
  return correct, total

def test_eqx(model):
  correct = 0
  total = 0

  for data in test_dataloader:
    inputs, labels = data

    inputs = jnp.asarray(inputs.numpy())
    labels = jnp.asarray(labels.numpy())

    c, t = get_correct_total(model, inputs, labels)
    correct += int(c)
    total += t

  return correct / total

In [None]:
learning_rate_ln_eqx = 0.05
momentum_ln_eqx = 0.9
epochs_ln_eqx = 5

seed = 123
rng = jax.random.PRNGKey(seed)

lenet_eqx = LeNetEqx(rng)
train_eqx(lenet_eqx, learning_rate_ln_eqx, momentum_ln_eqx, epochs_ln_eqx)
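Functions compiled with `jax.jit` (or `eqx.filter_jit`) should receive every value that changes between calls as an argument. A value the function only closes over, such as a loop variable like `labels` read inside the `get_correct_total` helper used by `test_eqx`, is captured as a constant when the function is first traced, and later calls with same-shaped inputs reuse that compiled version. A minimal sketch with hypothetical function names:

```python
import jax
import jax.numpy as jnp

labels = jnp.array([0, 1, 2, 3])


@jax.jit
def correct_closure(predicted):
    # Reads the global `labels`: its value is frozen in at trace time
    return (predicted == labels).sum()


@jax.jit
def correct_explicit(predicted, labels):
    # labels is an argument, so each call sees the current batch
    return (predicted == labels).sum()


predicted = jnp.array([0, 1, 2, 3])
print(correct_closure(predicted), correct_explicit(predicted, labels))  # 4 4

labels = jnp.array([3, 2, 1, 0])  # the "next batch": same shape, new values
print(correct_closure(predicted), correct_explicit(predicted, labels))  # 4 0 (closure is stale)
```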

## Creating a ViT using JAX

Now that we are familiar with JAX from writing a model with CNN layers, we will move on to creating a ViT. Wherever you can, use JAX's just-in-time (jit) compilation to speed up training and prediction.
<br><br>


In [None]:
### Helper classes (feel free to use or not use them) ###

class RearrangeEqx(eqx.Module):
    """
    Equinox Module to act as a Rearrange layer (from einops)
    """

    pattern: str
    kwargs: dict

    def __init__(self, pattern, **kwargs):
      self.pattern = pattern
      self.kwargs = kwargs

    def __call__(self, input, **kwargs):
        return einops.rearrange(input, self.pattern, **self.kwargs)

class ReduceEqx(eqx.Module):
    """
    Equinox Module to act as a Reduce layer (from einops)
    """

    pattern: str
    kwargs: dict

    def __init__(self, pattern, **kwargs):
      self.pattern = pattern
      self.kwargs = kwargs

    def __call__(self, input, **kwargs):
        return einops.reduce(input, self.pattern, **self.kwargs)

In [None]:
class MultiHeadAttentionEqx(eqx.Module):
  emb_size: int
  num_heads: int
  qkv: eqx.Module
  fc: eqx.Module

  def __init__(self, emb_size=768, num_heads=8, key: Optional["jax.random.PRNGKey"] = None):
    super().__init__()
    self.emb_size = emb_size
    self.num_heads = num_heads

    if key is None:
      key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, 2)

    self.qkv = enn.Linear(emb_size, emb_size * 3, key=keys[0])
    self.fc = enn.Linear(emb_size, emb_size, key=keys[1])

  def __call__(self, x: jax.Array, mask: jax.Array = None, **kwargs) -> jax.Array:
    # x has shape (n, emb_size): a single, unbatched sequence of tokens
    qkv = einops.rearrange(jax.vmap(self.qkv)(x), 'n (h d qkv) -> (qkv) h n d', h = self.num_heads, qkv = 3)
    queries, keys, values = qkv[0], qkv[1], qkv[2]

    # Scale the scores by sqrt(head_dim) before the softmax
    scaling = (self.emb_size // self.num_heads) ** (1/2)
    dot_product = jnp.einsum('hqd, hkd -> hqk', queries, keys) / scaling
    if mask is not None:
      # Same convention as the PyTorch version: mask is True where attention is allowed
      fill_value = jnp.finfo(dot_product.dtype).min
      dot_product = jnp.where(mask, dot_product, fill_value)

    att = jnn.softmax(dot_product, axis=-1)
    out = jnp.einsum('hal, hlv -> hav', att, values)
    out = einops.rearrange(out, 'h n d -> n (h d)')
    out = jax.vmap(self.fc)(out)
    return out

class EmbeddingEqx(eqx.Module):
  patch_size: int
  rearrange: eqx.Module
  embedding_mlp: eqx.Module
  cls_token: jnp.ndarray
  positions: jnp.ndarray

  def __init__(self, patch_size=16, in_channels=3, emb_size=768, img_size=224, key: Optional["jax.random.PRNGKey"] = None):
    super().__init__()
    self.patch_size = patch_size

    if key is None:
      key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, 3)

    # Patch Embedding
    self.rearrange = RearrangeEqx('c (h s1) (w s2) -> (h w) (s1 s2 c)', s1=patch_size, s2=patch_size)
    self.embedding_mlp = enn.Linear(patch_size * patch_size * in_channels, emb_size, key=keys[0])

    # CLS Token
    self.cls_token = jax.random.uniform(keys[1], shape=(1, emb_size))

    # Position Embedding
    self.positions = jax.random.uniform(keys[2], shape=((img_size // patch_size) ** 2 + 1, emb_size))

  def __call__(self, x: jax.Array, **kwargs) -> jax.Array:
    x = self.rearrange(x)
    x = jax.vmap(self.embedding_mlp)(x)
    x = jnp.concatenate([self.cls_token, x], axis=0)
    x += self.positions

    return x

class EncoderEqx(eqx.Module):
  mha: eqx.Module
  mlp: eqx.Module
  norm1: eqx.Module
  norm2: eqx.Module

  def __init__(self, emb_size=768, mlp_exp=4, key: Optional["jax.random.PRNGKey"] = None):
    super().__init__()

    if key is None:
      key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, 3)

    # LayerNorms (one per residual block, as in the original ViT)
    self.norm1 = enn.LayerNorm(emb_size)
    self.norm2 = enn.LayerNorm(emb_size)

    # Multi-head Attention
    self.mha = MultiHeadAttentionEqx(emb_size, key=keys[0])

    # MLP
    self.mlp = enn.Sequential([
        enn.Linear(emb_size, mlp_exp * emb_size, key=keys[1]),
        enn.Lambda(jnn.gelu),
        enn.Linear(mlp_exp * emb_size, emb_size, key=keys[2])
    ])

  def __call__(self, x: jax.Array, **kwargs) -> jax.Array:
    # Residual block 1
    res = x
    # Apply Layernorm per token: enn.LayerNorm normalizes its whole input,
    # so vmap it over the n tokens of the (n, emb_size) sequence
    x = jax.vmap(self.norm1)(x)
    # Apply Multi-head attention
    x = self.mha(x)
    # Add residual
    x = res + x

    # Residual block 2
    res = x
    # Apply Layernorm per token
    x = jax.vmap(self.norm2)(x)
    # Apply MLP
    x = jax.vmap(self.mlp)(x)
    # Add residual
    x = res + x

    return x

class ClassificationHeadEqx(enn.Sequential):
  def __init__(self, emb_size=768, n_classes=10, key: Optional["jax.random.PRNGKey"] = None):
    if key is None:
      key = jax.random.PRNGKey(0)

    super().__init__([
        ReduceEqx('n e -> e', reduction='mean'),
        enn.LayerNorm(emb_size),
        enn.Linear(emb_size, n_classes, key=key)
    ])

class ViTEqx(enn.Sequential):
  def __init__(self, img_size=32, in_channels=3, n_classes=10, patch_size=16, emb_size=768, depth=1, enc_mlp_exp=4, key: Optional["jax.random.PRNGKey"] = None, **kwargs):
    if key is None:
      key = jax.random.PRNGKey(0)
    keys = jax.random.split(key, 2 + depth)

    super().__init__(
        [EmbeddingEqx(patch_size, in_channels, emb_size, img_size, key=keys[0]), 
        *[EncoderEqx(emb_size, enc_mlp_exp, key=keys[i + 1]) for i in range(depth)],
        ClassificationHeadEqx(emb_size, n_classes, key=keys[depth + 1])]
    )
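To put the two models' sizes in perspective, the parameter count of one encoder block follows directly from the layer shapes: the `qkv` projection (`emb_size -> 3 * emb_size`), the output projection (`emb_size -> emb_size`), and the two MLP layers (`emb_size -> mlp_exp * emb_size -> emb_size`), each with biases, plus a scale and a shift vector per LayerNorm. The helpers below are a hypothetical back-of-the-envelope calculation; `n_layernorms` is 1 if a block shares a single LayerNorm between its residual branches and 2 if each branch has its own.

```python
def linear_params(n_in: int, n_out: int) -> int:
    """Weights plus biases of a dense layer."""
    return n_in * n_out + n_out


def encoder_params(emb_size: int = 768, mlp_exp: int = 4, n_layernorms: int = 2) -> int:
    """Parameters in one ViT encoder block with the layer shapes used in this notebook."""
    attention = linear_params(emb_size, 3 * emb_size) + linear_params(emb_size, emb_size)
    mlp = linear_params(emb_size, mlp_exp * emb_size) + linear_params(mlp_exp * emb_size, emb_size)
    norms = n_layernorms * 2 * emb_size  # scale and shift per LayerNorm
    return attention + mlp + norms


lenet_params = (
    (3 * 6 * 5 * 5 + 6)        # conv1
    + (6 * 16 * 5 * 5 + 16)    # conv2
    + linear_params(16 * 5 * 5, 120)
    + linear_params(120, 84)
    + linear_params(84, 10)
)
print(lenet_params)          # 62006
print(encoder_params())      # 7087872
print(5 * encoder_params())  # 35439360, for an encoder depth of 5
```

Even before counting the embedding and classification head, a 5-block encoder has several hundred times more parameters than LeNet-5, which is one reason the ViT trains much more slowly in this notebook.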

In [None]:
learning_rate_vit_eqx = 0.01 # [0.005, 0.01, 0.03, 0.05] BEST LR=0.01
momentum_vit_eqx = 0.9 # [0.8, 0.9, 0.95, 0.99] BEST Momentum=0.9
epochs_vit_eqx = 5
encoder_depth_eqx = 5

seed = 456 # edit the seed to change pseudo-random behavior
rng = jax.random.PRNGKey(seed)

vit_eqx = ViTEqx(depth=encoder_depth_eqx, key=rng)
acc = train_eqx(vit_eqx, learning_rate_vit_eqx, momentum_vit_eqx, epochs_vit_eqx, verbose=True)
print(acc)

Batch: 0, Loss: 2.432
Batch: 20, Loss: 2.289
Batch: 40, Loss: 2.109
Batch: 60, Loss: 1.911
Batch: 80, Loss: 2.162
Batch: 100, Loss: 2.005


KeyboardInterrupt: ignored