# Implementation of ViT

This implementation is deliberately kept simple so that I can experiment with its parameters. Adapted from [Dive into Deep Learning, chapter 11](https://d2l.ai/chapter_attention-mechanisms-and-transformers/vision-transformer.html).

In [1]:
# imports
import jax
from flax import linen as nn
from jax import numpy as jnp

# Root PRNG key for the notebook; later cells split a fresh subkey off it
# with jax.random.split before each use.
key = jax.random.key(0)

2024-10-14 21:31:58.140432: W external/xla/xla/service/gpu/nvptx_compiler.cc:893] The NVIDIA driver's CUDA version is 12.5 which is older than the PTX compiler version 12.6.68. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [2]:
def check_shape(a, b):
    """Return True if the shapes ``a`` and ``b`` are identical.

    Both arguments may be any sequence of ints (tuple, list, array ``.shape``).
    They are normalised to tuples before comparing, so ``[3, 4]`` matches
    ``(3, 4)``. Shapes of different rank never match.
    """
    # Comparing whole tuples checks the length and every dimension at once.
    # It avoids indexing a shape by its own element values.
    return tuple(a) == tuple(b)

In [43]:
def _masked_softmax(scores, valid_lens):
    """Softmax over the last (key) axis, ignoring keys past ``valid_lens``.

    Args:
        scores: attention logits. The first axis is the batch and the last two
            axes are (queries, keys). Any axes in between (e.g. heads) are
            broadcast over.
        valid_lens: ``None`` (no masking); an array of shape ``(batch,)`` with
            one valid key count per example; or ``(batch, queries)`` with one
            count per query.

    Returns:
        Array with the same shape as ``scores``. Masked key positions receive
        (numerically) zero weight.

    Raises:
        ValueError: if ``valid_lens`` is neither 1-D nor 2-D.
    """
    if valid_lens is None:
        return nn.softmax(scores, axis=-1)
    valid_lens = jnp.asarray(valid_lens)
    if valid_lens.ndim == 1:
        # (batch,) -> (batch, 1, ..., 1) so it broadcasts over every query.
        lens = valid_lens.reshape(valid_lens.shape + (1,) * (scores.ndim - 1))
    elif valid_lens.ndim == 2:
        # (batch, queries) -> (batch, 1, ..., queries, 1).
        batch, num_queries = valid_lens.shape
        lens = valid_lens.reshape(
            (batch,) + (1,) * (scores.ndim - 3) + (num_queries, 1))
    else:
        raise ValueError(
            f"valid_lens must be 1-D or 2-D, got shape {valid_lens.shape}")
    mask = jnp.arange(scores.shape[-1]) < lens
    # A large negative fill (rather than -inf) keeps rows with no valid key
    # finite, as in d2l's masked_softmax.
    return nn.softmax(jnp.where(mask, scores, -1e6), axis=-1)


class DotProductAttention(nn.Module):
    """Scaled dot product attention with optional key-length masking.
    https://d2l.ai/chapter_attention-mechanisms-and-transformers/
    attention-scoring-functions.html

    Attributes:
        dropout: dropout rate applied to the attention weights.
    """
    dropout: float

    # Shape of queries: (batch_size, [heads,] no. of queries, d)
    # Shape of keys: (batch_size, [heads,] no. of key-value pairs, d)
    # Shape of values: (batch_size, [heads,] no. of key-value pairs, value dimension)
    # Shape of valid_lens: (batch_size,) or (batch_size, no. of queries)
    @nn.compact
    def __call__(self, queries, keys, values, valid_lens=None,
                 training=False):
        """Return ``(attended_values, attention_weights)``.

        ``training=True`` enables dropout on the weights, which requires a
        'dropout' PRNG stream. The returned weights are the pre-dropout ones.
        """
        d = queries.shape[-1]
        # Transpose only the last two axes of keys, so both 3-D and 4-D
        # (multi-head) inputs work.
        scores = (queries @ jnp.swapaxes(keys, -1, -2)) / jnp.sqrt(d)
        # Normalise over the key axis only, masking keys beyond valid_lens.
        attention_weights = _masked_softmax(scores, valid_lens)
        dropout_layer = nn.Dropout(self.dropout, deterministic=not training)
        return dropout_layer(attention_weights) @ values, attention_weights

In [44]:
class MultiHeadAttention(nn.Module):
    """Multi-head self-attention with a single fused QKV projection.

    Borrowed from
    https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks
    /JAX/tutorial6/Transformers_and_MHAttention.html

    and 

    https://github.com/d2l-ai/d2l-en
    /blob/23d7a5aecceee57d1292c56e90cce307f183bb0a/d2l/jax.py

    Attributes:
        embed_dim: total model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        dropout: dropout rate applied to the attention weights.
        use_bias: whether the output projection has a bias.
            NOTE(review): the fused QKV projection always has a bias,
            regardless of this flag.
    """
    embed_dim: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        # The QKV split below needs an equal integer width per head; fail
        # early with a clear message instead of a reshape/matmul error.
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by "
                f"num_heads ({self.num_heads})")
        self.weight_projection = nn.Dense(
            3 * self.embed_dim,
            # see https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros
        )
        self.offset_projection = nn.Dense(
            self.embed_dim,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros,
            use_bias=self.use_bias
        )
        self.attention = DotProductAttention(self.dropout)

    def __call__(self, x, mask=None, training=False):
        """Apply self-attention to ``x``.

        Args:
            x: input of shape (batch, sequence_length, features).
            mask: optional valid lengths, forwarded unchanged to
                ``DotProductAttention`` as ``valid_lens``. Shape (batch,) or
                (batch, sequence_length).
            training: enables attention dropout, which needs a 'dropout'
                PRNG stream.

        Returns:
            ``(output, attention)``: output has shape
            (batch, sequence_length, embed_dim); attention has shape
            (batch, num_heads, sequence_length, sequence_length).
        """
        batch_size, sequence_length, _ = x.shape
        stacked_weights = self.weight_projection(x)

        # Separate the linear outputs into per-head chunks: weights (QKV)
        stacked_weights = stacked_weights.reshape(batch_size, sequence_length, self.num_heads, -1)
        # transpose to [batch, head, sequence_length, dimensions]
        stacked_weights = stacked_weights.transpose(0, 2, 1, 3)
        q, k, v = jnp.array_split(stacked_weights, 3, axis=-1)

        # Determine the outputs; forward the mask and the training flag so
        # masking and attention dropout actually take effect.
        values, attention = self.attention(q, k, v, mask, training=training)
        # transpose to [batch, sequence_length, head, dimensions]
        values = values.transpose(0, 2, 1, 3)
        # Heads concatenate back to embed_dim, which may differ from the
        # input feature size.
        values = values.reshape(batch_size, sequence_length, self.embed_dim)
        o = self.offset_projection(values)

        return o, attention


# Sanity-check MultiHeadAttention: the output keeps the input's shape, and the
# attention weights are (batch, heads, queries, keys).
this_key, key = jax.random.split(key)
x = jax.random.normal(this_key, (3, 16, 128))
MHA = MultiHeadAttention(embed_dim=128, num_heads=4, dropout=0.5)
this_key, key = jax.random.split(key)
# training defaults to False, so dropout is deterministic and apply() needs
# no 'dropout' RNG stream.
params = MHA.init(this_key, x)['params']
out, attention = MHA.apply({'params': params}, x)

assert check_shape(out.shape, x.shape)
assert check_shape(attention.shape, (3, 4, 16, 16))

## Patch Embedding

Split the image into non-overlapping patches, then linearly project each flattened patch. This is equivalent to a convolution whose kernel size and stride both equal the patch size.

In [45]:
class PatchEmbedding(nn.Module):
    """Split an image into non-overlapping patches and embed each patch.

    A convolution whose kernel size and stride both equal the patch size is
    the same as flattening each patch and applying one shared linear layer.

    Attributes:
        img_size: image height/width (int) or an (H, W) pair.
        patch_size: patch height/width (int) or an (h, w) pair.
        num_hiddens: embedding dimension of each patch.
    """
    img_size: int = 96
    patch_size: int = 16
    num_hiddens: int = 512

    def setup(self):
        def _make_tuple(x):
            # Accept either a square size or an explicit (H, W) pair.
            if not isinstance(x, (list, tuple)):
                return (x, x)
            return x

        img_size, patch_size = _make_tuple(self.img_size), _make_tuple(self.patch_size)
        self.num_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])

        # 'VALID' drops partial border patches, so the number of conv output
        # positions always equals num_patches. 'SAME' would pad and produce
        # ceil(img / patch) patches whenever the size is not divisible.
        self.conv = nn.Conv(self.num_hiddens, kernel_size=patch_size,
                            strides=patch_size, padding='VALID')

    def __call__(self, X):
        """Map channels-last images (batch, H, W, C) to
        (batch, num_patches, num_hiddens)."""
        X = self.conv(X)
        # Flatten the patch grid into a token sequence.
        return X.reshape((X.shape[0], -1, X.shape[-1]))


# Sanity-check PatchEmbedding: a 96x96 image cut into 16x16 patches gives
# (96 // 16) ** 2 = 36 patch embeddings per image.
img_size, patch_size, num_hiddens, batch_size = 96, 16, 512, 4
patch_emb = PatchEmbedding(img_size, patch_size, num_hiddens)
# Channels-last input: (batch, height, width, channels).
X = jnp.zeros((batch_size, img_size, img_size, 3))
this_key, key = jax.random.split(key)
# init_with_output returns (output, variables); only the output is checked.
output, _ = patch_emb.init_with_output(this_key, X)

assert check_shape((batch_size, (img_size//patch_size)**2, num_hiddens), output.shape)

## ViT Encoder Stage

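Layer normalization is applied before the multi-head attention and before the MLP (pre-norm). Each of these two sublayers is wrapped in a residual connection.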

In [51]:
class ViTMLP(nn.Module):
    """Two-layer feed-forward network of a ViT encoder block.

    Dense(mlp_num_hiddens) -> GELU -> Dropout -> Dense(mlp_num_outputs)
    -> Dropout.
    """
    mlp_num_hiddens: int
    mlp_num_outputs: int
    dropout: float = 0.5

    @nn.compact
    def __call__(self, x, training=False):
        """Apply the MLP; dropout is active only when ``training`` is True."""
        deterministic = not training
        # Submodule creation order fixes the auto-generated names
        # (Dense_0, Dropout_0, Dense_1, Dropout_1); keep it stable so
        # parameter trees and dropout RNG paths stay compatible.
        hidden = nn.gelu(nn.Dense(self.mlp_num_hiddens)(x))
        hidden = nn.Dropout(self.dropout, deterministic=deterministic)(hidden)
        out = nn.Dense(self.mlp_num_outputs)(hidden)
        return nn.Dropout(self.dropout, deterministic=deterministic)(out)


class ViTBlock(nn.Module):
    """Pre-norm Transformer encoder block used by the ViT.

    Computes ``x = x + MHA(LayerNorm(x))`` and then
    ``x + MLP(LayerNorm(x))``.

    Attributes:
        num_hiddens: token width (attention embed_dim and MLP output size).
        mlp_num_hiddens: hidden width of the MLP.
        num_heads: number of attention heads.
        dropout: dropout rate shared by the attention and the MLP.
        use_bias: passed through to MultiHeadAttention.
    """
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    dropout: float
    use_bias: bool = False

    def setup(self):
        self.attention = MultiHeadAttention(self.num_hiddens, self.num_heads,
                                            self.dropout, self.use_bias)
        self.mlp = ViTMLP(self.mlp_num_hiddens, self.num_hiddens, self.dropout)

    @nn.compact
    def __call__(self, x, valid_lens=None, training=False):
        """Apply the block. ``training`` enables dropout in both sublayers."""
        # Forward `training` so attention dropout is active during training,
        # not only the MLP dropout. The first LayerNorm is created before the
        # second, which keeps the LayerNorm_0 / LayerNorm_1 naming.
        attn_out, _ = self.attention(nn.LayerNorm()(x), valid_lens,
                                     training=training)
        x = x + attn_out
        return x + self.mlp(nn.LayerNorm()(x), training)


# Sanity-check ViTBlock: a residual encoder block must preserve the input
# shape (batch=2, 100 tokens, width 24 split across 8 heads).
x = jnp.ones((2, 100, 24))
encoder_blk = ViTBlock(24, 48, 8, 0.5)
this_key, key = jax.random.split(key)
# init_with_output(...)[0] is the block's output on `x`.
assert check_shape(encoder_blk.init_with_output(this_key, x)[0].shape, x.shape)

## Complete ViT implementation

In [53]:
class ViT(nn.Module):
    """
    Vision Transformer.

    Embeds the image as patch tokens and prepends a learnable [CLS] token. It
    then adds learned positional embeddings, applies ``num_blks`` ViTBlocks,
    and classifies from the [CLS] position with LayerNorm + Dense.

    Expects channels-last images (batch, H, W, C), as consumed by nn.Conv
    in PatchEmbedding.
    """
    image_size: int
    patch_size: int
    num_hiddens: int
    mlp_num_hiddens: int
    num_heads: int
    num_blks: int
    emb_dropout: float
    blk_dropout: float
    lr: float = 0.1  # not read by this module (presumably used by training code)
    use_bias: bool = False
    num_classes: int = 10
    training: bool = False  # enables every dropout layer (needs a 'dropout' RNG)

    def setup(self):
        self.patch_embedding = PatchEmbedding(self.image_size, self.patch_size,
                                              self.num_hiddens)
        # Flax initializers are called as init(rng, shape), so the shape must
        # be passed as one tuple.
        self.cls_token = self.param('cls_token', nn.initializers.zeros,
                                    (1, 1, self.num_hiddens))

        # One position per patch plus one for the [CLS] token.
        num_steps = self.patch_embedding.num_patches + 1

        # Positional Embeddings
        self.pos_embedding = self.param('pos_embed', nn.initializers.normal(),
                                        (1, num_steps, self.num_hiddens))
        self.blks = [ViTBlock(self.num_hiddens, self.mlp_num_hiddens, self.num_heads,
                              self.blk_dropout, self.use_bias)
                     for _ in range(self.num_blks)]
        self.head = nn.Sequential([nn.LayerNorm(), nn.Dense(self.num_classes)])

    @nn.compact
    def __call__(self, x):
        """Return class logits of shape (batch, num_classes)."""
        x = self.patch_embedding(x)
        # Broadcast the single [CLS] token across the batch and prepend it.
        cls_tokens = jnp.tile(self.cls_token, (x.shape[0], 1, 1))
        x = jnp.concatenate((cls_tokens, x), axis=1)
        x = nn.Dropout(self.emb_dropout, deterministic=not self.training)(
            x + self.pos_embedding)
        for blk in self.blks:
            x = blk(x, training=self.training)
        # Classify from the [CLS] position only.
        return self.head(x[:, 0])

## JIT

## Train

In [None]:
img_size = 96
# Integer division: nn.Conv needs integer kernel sizes/strides, and the
# parameter shapes derived from num_patches must be ints. `img_size / 6`
# would give the float 16.0.
patch_size = img_size // 6
num_hiddens, mlp_num_hiddens, num_heads, num_blks = 512, 2048, 8, 2
emb_dropout, blk_dropout, lr = 0.1, 0.1, 0.1
model = ViT(img_size, patch_size, num_hiddens, mlp_num_hiddens,
            num_heads, num_blks, emb_dropout, blk_dropout, lr)

In [None]:
import torch
import torchvision
import torchvision.transforms as transforms

# PyTorch TensorBoard support
from torch.utils.tensorboard import SummaryWriter
from datetime import datetime


# FashionMNIST images are 28x28. Resize them to the ViT input size
# (`img_size`) so the conv patch grid matches the model's num_patches and
# positional embeddings.
# NOTE(review): ToTensor yields (C, H, W); the Flax model expects
# channels-last (H, W, C), so batches must be transposed before model.apply.
transform = transforms.Compose(
    [transforms.Resize((img_size, img_size)),
     transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])

# Create datasets for training & validation, download if necessary
training_set = torchvision.datasets.FashionMNIST('./data', train=True, transform=transform, download=True)
validation_set = torchvision.datasets.FashionMNIST('./data', train=False, transform=transform, download=True)

# Create data loaders for our datasets; shuffle for training, not for validation
training_loader = torch.utils.data.DataLoader(training_set, batch_size=4, shuffle=True)
validation_loader = torch.utils.data.DataLoader(validation_set, batch_size=4, shuffle=False)

# Class labels
classes = ('T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
        'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle Boot')

## Inference