# Vision Transformer (PyTorch Lightning)

## Image Preprocessing

In [1]:
import os
import sys
import logging
try:
    import flax
except ImportError:
    %pip install flax
    import flax
from flax import nnx
try:
    import jax
except ImportError:
    %pip install jax
    import jax
import jax.numpy as jnp
try:
    import optax
except ImportError:
    %pip install optax
    import optax
import tqdm
from sklearn.model_selection import train_test_split

In [2]:
# Embedding Class
class Embed(nnx.Module):

    # Embedding = Patch Embedding + Cls Token + Pos Embedding
    def __init__(self, config, rngs):
        self.patch_size = config["patch_size"]
        self.embed_dim = config["embed_dim"]
        self.image_height = config["image_size"]
        self.image_width = config["image_size"]
        self.batch_size = config["batch"]
        self.patch_count = (self.image_height * self.image_width) // self.patch_size ** 2
        self.rng = rngs

        self.proj_layer = nnx.Conv(
            in_features=3,
            out_features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=self.patch_size,
            rngs=self.rng
        )

        self.cls_token = nnx.Param(
            jax.random.normal(self.rng.params(), (1, 1, self.embed_dim))
        )
        
        self.pos_embedding = nnx.Param(
            jax.random.normal(self.rng.params(), (self.batch_size, self.patch_count + 1, self.embed_dim))
        )

    def __call__(self, x):
        x = self.proj_layer(x)
        x = jnp.reshape(x, (x.shape[0], (x.shape[1] * x.shape[2]), x.shape[3]))
        self.cls_tokens = jnp.tile(self.cls_token, [self.batch_size, 1, 1])
        x = jnp.concatenate([x, self.cls_tokens], axis=1)

        x = x + self.pos_embedding
        return x

# Attention Head
class AttentionHead(nnx.Module):
    def __init__(self, embed_dim, attention_head_size, bias=True):
        self.embed_dim= embed_dim
        self.attention_head_size = attention_head_size

        # Query, Key and Value Weight matrices
        self.q_w = nnx.Linear(in_features=self.embed_dim, out_features=self.attention_head_size, use_bias=True, rngs=nnx.Rngs(0))
        self.k_w = nnx.Linear(in_features=self.embed_dim, out_features=self.attention_head_size, use_bias=True, rngs=nnx.Rngs(0))
        self.v_w = nnx.Linear(in_features=self.embed_dim, out_features=self.attention_head_size, use_bias=True, rngs=nnx.Rngs(0))

    def __call__(self, x):
        q_x, k_x, v_x = self.q_w(x), self.k_w(x), self.v_w(x)

        # Calculate QK^T/sqrt(dk)
        attn_out = jnp.matmul(q_x, jnp.matrix_transpose(k_x)) / jnp.sqrt(self.attention_head_size)
        # Apply softmax
        softmax_out = nnx.softmax(attn_out)
        # Obtain the attention value with value
        attn_value = jnp.matmul(softmax_out, v_x)

        return attn_value

# Multiheadattention
class MultiHeadAttention(nnx.Module):
    def __init__(self, config):
        self.embed_dim = config["embed_dim"]
        self.num_of_heads = config["num_of_heads"]

        self.attn_head_size = self.embed_dim // self.num_of_heads
        self.all_head_size = self.attn_head_size * self.num_of_heads

        self.heads = []

        for _ in range(self.num_of_heads):
            self.attn_head = AttentionHead(
                embed_dim=self.embed_dim,
                attention_head_size=self.attn_head_size
            )
            self.heads.append(self.attn_head)

        self.linear_proj = nnx.Linear(in_features=self.all_head_size, out_features=self.embed_dim, use_bias=True, rngs=nnx.Rngs(0))
        self.dropout = nnx.Dropout(0.3, rngs=nnx.Rngs(0))

    def __call__(self, x):
        attn_outputs = [head(x) for head in self.heads]

        concat_output = jnp.concatenate(attn_outputs, axis=-1)
        proj_output = self.linear_proj(concat_output)
        proj_output = self.dropout(x)

        return proj_output

# MLP class
class MLP(nnx.Module):
    def __init__(self, config):
        self.embed_dim = config["embed_dim"]
        self.intermediate_size = config["intermediate_size"]

        self.linear1 = nnx.Linear(in_features=self.embed_dim, out_features=self.intermediate_size, use_bias=True, rngs=nnx.Rngs(0))
        self.linear2= nnx.Linear(in_features=self.intermediate_size, out_features=self.embed_dim, use_bias=True, rngs=nnx.Rngs(0))
        self.dropout = nnx.Dropout(0.3, rngs=nnx.Rngs(0))

    def __call__(self, x):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)

        return x

# Single block
class Block(nnx.Module):
    def __init__(self, config):
        self.config = config
        self.embed_dim = config["embed_dim"]
        self.norm = nnx.BatchNorm(num_features=self.embed_dim, rngs=nnx.Rngs(0))
        self.mha = MultiHeadAttention(self.config)
        self.mlp = MLP(self.config)

    def __call__(self, x):
        norm_out = self.norm(x)
        attn_out = self.mha(x)
        attn_out = x + attn_out
        norm_out = self.norm(attn_out)
        mlp_out = self.mlp(norm_out)
        block_out = mlp_out + norm_out

        return block_out

# Encoder
class Encoder(nnx.Module):
    def __init__(self, config):
        self.config = config
        self.num_of_blocks = config["num_of_blocks"]

        self.blocks = []

        for _ in range(self.num_of_blocks):
            block = Block(self.config)
            self.blocks.append(block)


    def __call__(self, x):
        all_attns = []

        for block in self.blocks:
            x = block(x)
        return x

# ViT Classifier
class ViT(nnx.Module):
    def __init__(self, config):
        self.embed_layer = Embed(config, nnx.Rngs(0))
        self.encoder = Encoder(config)
        self.classifier = nnx.Linear(in_features=config["embed_dim"], out_features=config["no_of_classes"], use_bias=True, rngs=nnx.Rngs(0))

    def __call__(self, x):
        x = self.embed_layer(x)
        x = self.encoder(x)
        x = self.classifier(x[:, 0, :])

        return x

### Loading dataset and Creating dataloaders

In [3]:
# Load the NPY files of CIFAR10
base_path = r"D:\cifar-10-python\cifar-10-batches-py"

train_val_images = jnp.load(os.path.join(base_path, "train_val_data.npy"))
train_val_labels = jnp.load(os.path.join(base_path, "train_val_labels.npy"))

train_val_images = train_val_images / 255.0

print(jnp.unique_counts(train_val_labels)[1].shape[0])

x_train, x_val, y_train, y_val = train_test_split(train_val_images, train_val_labels, test_size=0.8, random_state=42)
print(x_train.shape)

class Dataloader:
    def __init__(self, x, y, batch, shuffle):
        self.x = x
        self.y = y
        self.batch = batch
        self.shuffle = shuffle
        self.n_samples = self.x.shape[0]
        self.total_batches = self.n_samples // self.batch
        self.indices = jnp.arange(self.n_samples)
        self.key = jax.random.PRNGKey(0)
        if self.shuffle:
            jax.random.permutation(self.key, self.indices, axis=0)
    
    def __len__(self):
        return self.total_batches
    
    def __iter__(self):
        self.current_batch = 0
        if self.shuffle:
            jax.random.permutation(self.key, self.indices, axis=0)

        return self
    
    def __next__(self):
        if self.current_batch >= self.total_batches:
            raise StopIteration
        
        start_idx = self.current_batch * self.batch
        end_idx = min(start_idx + self.batch, self.n_samples)

        x_batch = self.x[start_idx:end_idx, :, :, :]
        y_batch = self.y[start_idx:end_idx]

        self.current_batch += 1

        return x_batch, y_batch
    
    def get_batch(self):
        pass

    def reset(self):
        self.current_batch = 0
        if self.shuffle:
            jax.random.permutation(self.key, self.indices, axis=0)


train_dataloader = Dataloader(
    x_train, y_train, batch=16, shuffle=False
)

val_dataloader = Dataloader(
    x_val, y_val, batch=16, shuffle=False
)

10
(10000, 32, 32, 3)


In [4]:
# Define the loss function
def loss_fn(model, images, labels):
    logits = model(images)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, logits

In [5]:
# Define Training and validation step
def train_step(model: nnx.Module, optim: nnx.Optimizer, images: jax.Array, labels: jax.Array):

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    
    optim.update(model, grads)
    return loss

def eval_step(model: nnx.Module, images: jax.Array, labels: jax.Array, eval_metrics: nnx.MultiMetric):
    
    loss, logits = loss_fn(model, images, labels)
    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels
    )

train_step = nnx.jit(train_step)
eval_step = nnx.jit(eval_step)

In [6]:
eval_metrics = nnx.MultiMetric(
    loss = nnx.metrics.Average("loss"),
    accuracy = nnx.metrics.Accuracy(),
)

train_history = {
    "train_loss": [], 
}

val_history = {
    "val_loss": [],
    "val_accuracy": []
}

In [7]:
# Configure optimizer, loss, metrics

tqdm_bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"

lr_0 = 0.001
lr_f = 1e-5
momemtum = 0.8
epochs = 100
total_steps = x_train.shape[0] // 16

config_dict = {
    "patch_size": 4,
    "embed_dim": 128,
    "image_size": 32,
    "batch": 16,
    "no_of_classes": 10,
    "num_of_blocks": 6,
    "num_of_heads": 12,
    "intermediate_size": 4 * 128,
}


model = ViT(config_dict)

lr_schedule = optax.linear_schedule(lr_0, lr_f, epochs * total_steps)

optim = nnx.Optimizer(
    model, optax.adam(lr_schedule), wrt=nnx.Param
)

def train_one_epoch(epoch):
    model.train()

    with tqdm.tqdm(
        desc=f"[Train] epoch: {epoch} / {epochs}, ",
        total=total_steps,
        bar_format=tqdm_bar_format,
        leave=True
    ) as pbar:
        for images, labels in train_dataloader:
            loss = train_step(model, optim, images, labels)
            train_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)
    
def eval_model(epoch):
    model.eval()
    eval_metrics.reset()

    for val_images, val_labels in val_dataloader:
        loss = eval_step(model, val_images, val_labels, eval_metrics)
    
    for metric, value in eval_metrics.compute().items():
        val_history[f"val_{metric}"].append(value)

    print(f"[Val] epoch: {epoch + 1} / {epochs}")
    print(f"Loss: {val_history['val_loss'][-1]:0.4f}")
    print(f"Accuracy: {val_history['val_accuracy'][-1]:0.4f}")
        

In [10]:
len(train_dataloader)

625

In [None]:
for epoch in range(epochs):
    train_one_epoch(epoch)
    eval_model(epoch)

[Train] epoch: 0 / 100, [0/625] [00:00<?]

[Train] epoch: 0 / 100, [625/625], loss=2.15 [02:25<00:00]


[Val] epoch: 1 / 100
Loss: 2.1863
Accuracy: 0.1558


[Train] epoch: 1 / 100, [625/625], loss=2.03 [01:48<00:00]


[Val] epoch: 2 / 100
Loss: 2.1759
Accuracy: 0.1653


[Train] epoch: 2 / 100, [625/625], loss=2.11 [01:49<00:00]


[Val] epoch: 3 / 100
Loss: 2.1741
Accuracy: 0.1599


[Train] epoch: 3 / 100, [625/625], loss=2.09 [01:51<00:00]


[Val] epoch: 4 / 100
Loss: 2.1599
Accuracy: 0.1690


[Train] epoch: 4 / 100, [625/625], loss=2.07 [01:52<00:00]


[Val] epoch: 5 / 100
Loss: 2.1661
Accuracy: 0.1788


[Train] epoch: 5 / 100, [625/625], loss=2.15 [01:49<00:00]


[Val] epoch: 6 / 100
Loss: 2.1614
Accuracy: 0.1746


[Train] epoch: 6 / 100, [625/625], loss=2.1 [01:48<00:00] 


[Val] epoch: 7 / 100
Loss: 2.1549
Accuracy: 0.1822


[Train] epoch: 7 / 100, [307/625], loss=2.1 [00:45<00:39] Exception ignored in: <function _xla_gc_callback at 0x00000270BBA87380>
Traceback (most recent call last):
  File "c:\ProgramData\anaconda3\Lib\site-packages\jax\_src\lib\__init__.py", line 112, in _xla_gc_callback
KeyboardInterrupt: 
[Train] epoch: 7 / 100, [625/625], loss=2.07 [01:37<00:00]


[Val] epoch: 8 / 100
Loss: 2.1694
Accuracy: 0.1649


[Train] epoch: 8 / 100, [121/625], loss=2.23 [00:20<01:22]

[Train] epoch: 8 / 100, [168/625], loss=2.3 [00:28<01:07] 