# Vision Transformer

- skip_showdoc: true
- skip_exec: true


In [None]:
import jax
from jax import numpy as jnp, random as jrand, tree_util as jt
import optax
import numpy as np
import einops

# nn specific
import flax.linen as nn

# data specific
import jax_dataloader as jdl
import torchvision

# utils
import functools as ft
import matplotlib.pyplot as plt
from dataclasses import dataclass
from tqdm.auto import tqdm

In [None]:
transforms = torchvision.transforms

class ToNumpy:
    """Callable transform: convert an image-like object to a NumPy array
    and divide every value by 255.
    """

    def __call__(self, x):  # (H, W, C)
        # np.array makes a fresh array; true division by a Python float
        # then yields a floating-point result.
        pixels = np.array(x)
        return pixels / 255.0
    
@dataclass
class Normalize:
    """Standardise a channels-last image array as ``(x - mean) / std``.

    One-dimensional ``mean`` / ``std`` are reshaped to ``(1, 1, C)`` so that
    they broadcast along the last axis of ``x``.

    Calling the instance:
        x: array with at least three dimensions, laid out ``(..., H, W, C)``.

    Returns:
        A new array holding the normalised values; ``x`` is never modified.

    Raises:
        ValueError: if ``x`` has fewer than three dimensions, or if any
            entry of ``std`` is zero once converted to ``x.dtype``.
    """
    # Per-channel value subtracted from the input.
    mean: list[float]
    # Per-channel divisor applied after the subtraction.
    std: list[float]
    # Accepted for torchvision-style signature parity. The arithmetic below
    # always allocates its result, so the flag has no observable effect.
    inplace: bool = False

    def __call__(self, x: np.ndarray):
        if x.ndim < 3:
            raise ValueError(
                "Expected tensor to be a tensor image of size (..., H, W, C). "
                f"Got x.shape = {x.shape}"
            )

        # Convert the statistics to the input dtype first: for integer inputs
        # a fractional std truncates to 0, which the check below must see.
        dtype = x.dtype
        mean = np.asarray(self.mean, dtype=dtype)
        std = np.asarray(self.std, dtype=dtype)
        if np.any(std == 0):
            raise ValueError(f"std evaluated to zero after conversion to {dtype}, "
                             f"leading to division by zero.")

        # Shape (C,) -> (1, 1, C) so the statistics line up with the last axis.
        if mean.ndim == 1:
            mean = mean.reshape(1, 1, -1)
        if std.ndim == 1:
            std = std.reshape(1, 1, -1)

        # The expression creates a new array, so no defensive copy is needed.
        return (x - mean) / std

In [None]:
class PositionalEmbedding(nn.Module):
    """Add a learned, per-position embedding to a batch of token sequences.

    The parameter ``positional_embedding`` has shape ``(1, seq_len, emb_dim)``
    taken from the input, and is drawn from a normal initializer with
    ``stddev=0.02``.
    """
    dtype = jnp.float32  # unannotated class attribute; not read in this module

    @nn.compact
    def __call__(self, x):
        # The three-way unpack requires a rank-3 input; the batch size itself
        # is not needed because the table's leading axis of 1 broadcasts.
        _, num_tokens, width = x.shape
        table = self.param(
            'positional_embedding',
            nn.initializers.normal(stddev=0.02),
            (1, num_tokens, width),
        )
        return x + table

In [None]:
class PatchEmbedding(nn.Module):
    """Split a channels-last image into square patches and project each one.

    NOTE: the patch edge length is read from the module-level name
    ``patch_size`` when the module is called; it is not a field of this
    module.
    """
    num_hiddens: int  # output width of the linear projection
    dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jax.Array):
        # Height and width are each factored into (grid index, offset inside
        # the patch); every patch is then flattened into a single vector.
        patches = einops.rearrange(
            x, "... (H PH) (W PW) C -> ... (H W) (PH PW C)",
            PH=patch_size, PW=patch_size
        )
        projection = nn.Dense(self.num_hiddens, dtype=self.dtype)
        return projection(patches)

In [None]:
class MLP(nn.Module):
    """Two-layer feed-forward block: Dense -> GELU -> Dropout -> Dense -> Dropout.

    The second Dense maps back to the size of the input's last axis.
    ``train`` switches dropout on (``True``) or off (``False``).
    """
    hidden_dim: int      # width of the first Dense layer
    dropout_rate: float  # rate used by both Dropout layers
    dtype = jnp.float32

    @nn.compact
    def __call__(self, x, train):
        width = x.shape[-1]
        deterministic = not train
        xavier = nn.initializers.xavier_uniform()

        # Expansion to hidden_dim followed by the non-linearity.
        hidden = nn.Dense(self.hidden_dim, kernel_init=xavier)(x)
        hidden = nn.gelu(hidden)
        hidden = nn.Dropout(self.dropout_rate, deterministic=deterministic)(hidden)

        # Projection back to the input width.
        out = nn.Dense(width, dtype=self.dtype, kernel_init=xavier)(hidden)
        return nn.Dropout(self.dropout_rate, deterministic=deterministic)(out)

In [None]:
class EncoderBlock(nn.Module):
    """One encoder layer: an attention sub-layer and an MLP sub-layer.

    LayerNorm is applied before each sub-layer and the sub-layer's input is
    added back to its output. ``train`` switches every dropout on or off.
    """
    mlp_dim: int                         # hidden width handed to MLP
    num_heads: int                       # number of attention heads
    dropout_rate: float = 0.1            # dropout after attention and inside MLP
    attention_dropout_rate: float = 0.1  # dropout inside the attention module
    dtype = jnp.float32

    @nn.compact
    def __call__(self, inputs, train):
        deterministic = not train

        # attention (the module is called with a single input)
        attended = nn.LayerNorm()(inputs)
        attention = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dropout_rate=self.attention_dropout_rate,
            deterministic=deterministic,
            kernel_init=nn.initializers.xavier_uniform()
        )
        attended = attention(attended)
        attended = nn.Dropout(self.dropout_rate, deterministic=deterministic)(attended)
        residual = attended + inputs

        # mlp
        normed = nn.LayerNorm()(residual)
        mlp_out = MLP(self.mlp_dim, self.dropout_rate)(normed, train)
        return residual + mlp_out


In [None]:
class ViT(nn.Module):
    """Vision Transformer classifier.

    Pipeline: patch embedding -> prepend a learned class token -> add
    positional embedding -> dropout -> ``num_layers`` encoder blocks ->
    LayerNorm -> Dense head applied to the class-token position.

    Returns logits with ``num_classes`` entries per example.
    """
    num_classes: int               # size of the output layer
    num_layers: int                # number of EncoderBlock instances
    hidden_dim: int                # token width (patch projection, class token)
    num_heads: int                 # attention heads per encoder block
    mlp_dim: int                   # hidden width of each block's MLP
    dropout_rate: float            # dropout on embeddings and inside blocks
    attention_dropout_rate: float  # dropout inside attention
    dtype = jnp.float32

    @nn.compact
    def __call__(self, x, train):
        # The four-way unpack enforces a rank-4 input; only the batch size
        # is used afterwards.
        batch, _, _, _ = x.shape

        tokens = PatchEmbedding(self.hidden_dim)(x)

        # A single learned token, repeated for every example and placed at
        # sequence position 0.
        cls_token = self.param(
            'cls_token',
            nn.initializers.normal(stddev=0.02),
            (1, 1, self.hidden_dim),
        )
        cls_tokens = jnp.tile(cls_token, (batch, 1, 1))  # (B, 1, hidden_dim)
        tokens = jnp.concatenate([cls_tokens, tokens], axis=1)

        tokens = PositionalEmbedding()(tokens)
        tokens = nn.Dropout(self.dropout_rate, deterministic=not train)(tokens)

        for _ in range(self.num_layers):
            block = EncoderBlock(
                self.mlp_dim,
                self.num_heads,
                self.dropout_rate,
                self.attention_dropout_rate,
            )
            tokens = block(tokens, train)

        tokens = nn.LayerNorm()(tokens)
        # Classify from the class-token position only.
        return nn.Dense(self.num_classes)(tokens[:, 0])

In [None]:
@jax.value_and_grad
def compute_grad(
    params,
    model: nn.Module,
    batch: tuple[jnp.ndarray, jnp.ndarray],
    key: jrand.PRNGKey,
):
    """Mean softmax cross-entropy of ``model`` on ``batch`` in training mode.

    Because of the ``jax.value_and_grad`` decorator, calling this name
    returns ``(loss, grads)``, with gradients taken with respect to the first
    argument, ``params``.

    Args:
        params: variables passed to ``model.apply``.
        model: module whose ``apply`` produces the logits.
        batch: ``(images, integer labels)`` pair.
        key: PRNG key supplied to the ``'dropout'`` rng stream.
    """
    images, labels = batch
    logits = model.apply(params, images, rngs={'dropout': key}, train=True)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return per_example.mean()

@ft.partial(jax.jit, static_argnums=(1, 2))
def step(
    params,
    model: nn.Module,
    opt: optax.GradientTransformation,
    opt_state: optax.OptState,
    batch: tuple[jnp.ndarray, jnp.ndarray],
    key: jrand.PRNGKey,
):
    """Run one jitted optimisation step.

    ``model`` and ``opt`` (positions 1 and 2) are marked static for
    ``jax.jit``.

    Returns:
        ``(new_params, new_opt_state, loss)`` where ``loss`` is the value
        computed on ``batch`` before the update.
    """
    loss, grads = compute_grad(params, model, batch, key)
    # params is passed to update() as well, alongside the gradients and state.
    updates, next_opt_state = opt.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), next_opt_state, loss

def train(
    model: nn.Module,
    optimizer: optax.GradientTransformation,
    data_loader: jdl.DataLoader,
    epochs: int,
    rng_key: jrand.PRNGKey = jrand.PRNGKey(0),
):
    """Initialise ``model`` and train it for ``epochs`` passes over ``data_loader``.

    Args:
        model: module to initialise and optimise.
        optimizer: gradient transformation used by ``step``.
        data_loader: iterable of ``(inputs, labels)`` batches; its first
            batch also provides the example input for ``model.init``.
        epochs: number of full passes over ``data_loader``.
        rng_key: root PRNG key; split once for initialisation and once per
            step for dropout.

    Returns:
        ``(params, losses)`` — the final parameters and the list of per-step
        loss values in the order they were produced.
    """
    rng_key, init_key = jrand.split(rng_key)

    # Only the inputs of the first batch are needed to trace parameter shapes.
    xs, _ = next(iter(data_loader))
    params = model.init(init_key, xs, train=False)
    opt_state = optimizer.init(params)

    losses = []
    for epoch in range(epochs):
        for batch in data_loader:
            rng_key, dropout_key = jrand.split(rng_key)
            params, opt_state, loss = step(
                params, model, optimizer, opt_state, batch, dropout_key
            )
            losses.append(loss)

            # One loss is stored per step, so the list length is the global
            # step count across epochs.
            steps = len(losses)
            if steps % 500 != 0:
                continue
            print(f"Epoch: {epoch}, Step: {steps}, Loss: {loss}")
    return params, losses

In [None]:
# Hyperparameters

# -- optimisation --
lr = 3e-4
beta1 = 0.9
beta2 = 0.99
dropout_rate = 0.1
batch_size = 256  # 64 * 2 * 2
num_steps = 100_000

# -- data --
image_size = (32, 32, 3)
height, width, channels = image_size
patch_size = 4    # edge length read by PatchEmbedding
num_patches = 64  # (32 // 4) ** 2
num_classes = 10

# -- model --
embedding_dim = 512  # handed to ViT as mlp_dim further down
hidden_dim = 256     # handed to ViT as hidden_dim (token width)
num_heads = 8
num_layers = 4

In [None]:
# Per-channel statistics shared by the training and test pipelines.
norm_mean = (0.4914, 0.4822, 0.4465)
norm_std = (0.2023, 0.1994, 0.2010)

# Training pipeline: random crop and horizontal flip for augmentation, then
# conversion to a scaled NumPy array and per-channel normalisation.
train_steps = [
    transforms.RandomCrop(32, padding=4),
    transforms.Resize((height, width)),
    transforms.RandomHorizontalFlip(),
    ToNumpy(),
    Normalize(norm_mean, norm_std),
]
transform_train = transforms.Compose(train_steps)

# Test pipeline: no augmentation, same scaling and normalisation.
test_steps = [
    transforms.Resize((height, width)),
    ToNumpy(),
    Normalize(norm_mean, norm_std),
]
transform_test = transforms.Compose(test_steps)

# Both splits live under the same root directory and are downloaded on demand.
train_dataset = torchvision.datasets.CIFAR10(
    root="/tmp/CIFAR/",
    train=True,
    download=True,
    transform=transform_train,
)

test_dataset = torchvision.datasets.CIFAR10(
    root="/tmp/CIFAR/",
    train=False,
    download=True,
    transform=transform_test,
)


Downloading https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz to /tmp/CIFAR/cifar-10-python.tar.gz


100%|██████████| 170498071/170498071 [00:02<00:00, 63361930.39it/s]


Extracting /tmp/CIFAR/cifar-10-python.tar.gz to /tmp/CIFAR/
Files already downloaded and verified


In [None]:
# Model: token width `hidden_dim`, MLP width `embedding_dim`, and the same
# dropout rate for the regular and the attention dropout.
vit = ViT(
    num_classes=num_classes,
    num_layers=num_layers,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    mlp_dim=embedding_dim,
    dropout_rate=dropout_rate,
    attention_dropout_rate=dropout_rate,
)

# Learning-rate schedule: linear warm-up from 0 to `lr`, then cosine decay.
# `decay_steps` is the total schedule length; once it is exceeded the rate
# stays at the end value (0.0 by default). It is therefore tied to the planned
# number of optimisation steps (`num_steps`) rather than a shorter constant,
# so that the later part of a long run is not spent at a zero learning rate.
schedule_fn = optax.warmup_cosine_decay_schedule(
    init_value=0.0, peak_value=lr, warmup_steps=500, decay_steps=num_steps
)

# Optimizer: clip gradients to a global norm of 1.0, then AdamW.
opt = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adamw(learning_rate=schedule_fn, b1=beta1, b2=beta2),
)

# Shuffled training batches.
dl = jdl.DataLoader(train_dataset, 'pytorch', batch_size=batch_size, shuffle=True)

In [None]:
# Train for 500 passes over the training loader; keeps the final parameters
# and the per-step loss history.
params, losses = train(model=vit, optimizer=opt, data_loader=dl, epochs=500)

In [None]:
# Evaluate on the test split: collect a per-example "prediction == label"
# mask for every batch, then average over all examples.
corrects = []

dl = jdl.DataLoader(test_dataset, 'pytorch', batch_size=batch_size * 4, shuffle=True)
for img, label in dl:
    # train=False switches the model's dropout layers to deterministic mode.
    logits = vit.apply(params, img, rngs={'dropout': jrand.PRNGKey(0)}, train=False)
    preds = logits.argmax(axis=-1)
    corrects.append(preds == label)

print(f"Accuracy: {np.concatenate(corrects).mean()}")

Accuracy: 0.8149
