**Implementation of ViT Model on MNIST Dataset using JAX**

This code defines a Vision Transformer (ViT) model using JAX and Flax. It combines Flax's built-in multi-head self-attention layer (nn.SelfAttention) with a feed-forward network (dense_proj), organized into transformer blocks (TransformerBlock). The model processes input images by first dividing them into patches and embedding them into a higher-dimensional space (patch_embedding). The data then passes through multiple transformer blocks, each containing an attention layer, a feed-forward network, layer normalization, and residual connections. The output is mean-pooled over the patches and passed through a classification layer (ViTForClassification). The model is designed for image classification tasks and is applied here to the MNIST dataset, with data augmentation applied to the training set.

In [None]:
import os
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import optax
import flax.linen as nn
from flax.training import train_state
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from jax.random import PRNGKey

# Hyperparameters
exp_name = 'vit-with-10-epochs-jax'  # Experiment name; also names the checkpoint directory
batch_size = 32
epochs = 10
lr = 1e-5
save_model_every = 0  # Checkpoint interval in epochs; values <= 0 disable checkpointing

# Configuration
config = dict(
    patch_size=7,
    hidden_size=64,
    num_hidden_layers=4,
    num_attention_heads=4,
    intermediate_size=4 * 64,
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    image_size=28,
    num_classes=10,
    num_channels=1,
)

# Sanity checks: heads must split the hidden size evenly, the MLP is 4x wide,
# and the image must tile exactly into patches.
assert not config["hidden_size"] % config["num_attention_heads"]
assert 4 * config['hidden_size'] == config['intermediate_size']
assert not config['image_size'] % config['patch_size']

# Data preparation
def prepare_data(batch_size, horizontal_flip=False):
    """Build the MNIST train and test data loaders.

    The training set is augmented with small random rotations / affine
    rotations; the test set is only converted to a tensor and normalised.
    Both datasets are downloaded to ``./data`` if missing.

    Args:
        batch_size: Number of samples per batch for both loaders.
        horizontal_flip: If True, also apply ``RandomHorizontalFlip`` to the
            training images. Off by default: a mirrored digit is not a valid
            sample of its class (and never occurs in the test set), so
            flipping only makes the task harder for the model.

    Returns:
        A ``(trainloader, testloader)`` tuple; the training loader shuffles,
        the test loader does not.
    """
    # Shared tail of both pipelines: tensor conversion + normalisation with
    # mean 0.5 / std 0.5.
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]

    augmentations = [transforms.RandomRotation(10)]
    if horizontal_flip:
        augmentations.append(transforms.RandomHorizontalFlip())
    augmentations.append(transforms.RandomAffine(5))

    train_transform = transforms.Compose(augmentations + to_normalized_tensor)
    test_transform = transforms.Compose(to_normalized_tensor)

    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=train_transform)
    test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=test_transform)

    trainloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    testloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return trainloader, testloader

# Transformer block
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention then an MLP, each wrapped in a residual.

    Attributes:
        config: dict read for 'num_attention_heads', 'hidden_size',
            'intermediate_size', 'attention_probs_dropout_prob' and
            'hidden_dropout_prob'.
    """
    config: dict

    def setup(self):
        self.attention = nn.SelfAttention(
            num_heads=self.config['num_attention_heads'],
            qkv_features=self.config['hidden_size'],
            dropout_rate=self.config['attention_probs_dropout_prob']
        )
        # Position-wise feed-forward network: expand, GELU, project back.
        self.dense_proj = nn.Sequential([
            nn.Dense(self.config['intermediate_size']),
            nn.gelu,
            nn.Dense(self.config['hidden_size']),
        ])
        self.layernorm1 = nn.LayerNorm()
        self.layernorm2 = nn.LayerNorm()
        self.dropout = nn.Dropout(self.config['hidden_dropout_prob'])

    def __call__(self, x, train):
        """Apply the block.

        Args:
            x: Input sequence; its last dimension must equal
                config['hidden_size'] so the residual additions line up.
            train: True enables dropout, False disables it.

        Returns:
            An array with the same shape as ``x``.
        """
        deterministic = not train
        # The attention layer is configured with its own dropout rate, so it
        # must be told whether dropout is active; without the flag, attention
        # dropout cannot be used once the rate is non-zero.
        attn_output = self.attention(self.layernorm1(x), deterministic=deterministic)
        x = x + self.dropout(attn_output, deterministic=deterministic)
        proj_output = self.dense_proj(self.layernorm2(x))
        return x + self.dropout(proj_output, deterministic=deterministic)

# Vision Transformer (ViT)
def _extract_patches(images, image_size, patch_size, num_channels):
    """Cut a batch of images into flattened, non-overlapping square patches.

    Accepts channels-first (B, C, H, W) or channels-last (B, H, W, C) input
    and returns an array of shape (B, num_patches, patch_size**2 * C), with
    patches ordered row by row over the image grid.
    """
    batch = images.shape[0]
    # torchvision's ToTensor yields channels-first batches; move channels last
    # so each patch keeps its pixels' channels together.
    if images.ndim == 4 and images.shape[1] == num_channels and images.shape[-1] != num_channels:
        images = jnp.transpose(images, (0, 2, 3, 1))
    grid = image_size // patch_size
    # (B, H, W, C) -> (B, grid_row, in_row, grid_col, in_col, C)
    images = images.reshape(batch, grid, patch_size, grid, patch_size, num_channels)
    # Bring the two grid axes together so every patch is contiguous.
    images = jnp.transpose(images, (0, 1, 3, 2, 4, 5))
    return images.reshape(batch, grid * grid, patch_size * patch_size * num_channels)


class ViTForClassification(nn.Module):
    """Vision Transformer classifier: patch embedding, transformer blocks, mean pooling, linear head.

    Attributes:
        config: dict read for 'image_size', 'patch_size', 'num_channels',
            'hidden_size', 'num_hidden_layers', 'num_classes' and
            'hidden_dropout_prob' (and passed on to every TransformerBlock).
    """
    config: dict

    def setup(self):
        self.num_patches = (self.config['image_size'] // self.config['patch_size']) ** 2
        self.patch_embedding = nn.Dense(self.config['hidden_size'])
        # Learned position embeddings: self-attention followed by mean pooling
        # is otherwise invariant to the order of the patches, so the model
        # could not tell where in the image a patch came from.
        self.position_embeddings = self.param(
            'position_embeddings',
            nn.initializers.normal(stddev=0.02),
            (1, self.num_patches, self.config['hidden_size']),
        )
        self.transformer_blocks = [
            TransformerBlock(self.config) for _ in range(self.config['num_hidden_layers'])
        ]
        self.classifier = nn.Dense(self.config['num_classes'])
        self.dropout = nn.Dropout(self.config['hidden_dropout_prob'])

    def __call__(self, x, train=True):
        """Classify a batch of images.

        Args:
            x: Image batch, channels-first (B, C, H, W) or channels-last
                (B, H, W, C), with H == W == config['image_size'].
            train: True enables dropout, False disables it.

        Returns:
            Logits of shape (B, config['num_classes']).
        """
        # Real 2-D patches: a plain reshape of the flattened image would hand
        # the model strips of consecutive pixels instead of square patches.
        x = _extract_patches(
            x,
            self.config['image_size'],
            self.config['patch_size'],
            self.config['num_channels'],
        )
        x = self.patch_embedding(x) + self.position_embeddings
        x = self.dropout(x, deterministic=not train)
        for block in self.transformer_blocks:
            x = block(x, train)
        x = x.mean(axis=1)  # Mean pooling over the patch axis
        logits = self.classifier(x)
        return logits

# Training and evaluation utilities
def create_train_state(rng, model, learning_rate):
    """Initialise the model parameters and wrap them with an AdamW optimiser.

    Args:
        rng: PRNG key used for parameter initialisation.
        model: Flax module to initialise; its ``config`` attribute (falling
            back to the module-level ``config``) provides the input shape.
        learning_rate: Learning rate passed to ``optax.adamw``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial parameters and
        the optimiser state.
    """
    cfg = getattr(model, 'config', config)
    # A single dummy image is enough to trace parameter shapes; the layout is
    # channels-first, like the batches the torch DataLoader yields.
    dummy_images = jnp.ones((1, cfg['num_channels'], cfg['image_size'], cfg['image_size']))
    # Initialise in inference mode so no 'dropout' RNG stream is required,
    # even when the configured dropout probabilities are non-zero.
    params = model.init(rng, dummy_images, train=False)['params']
    tx = optax.adamw(learning_rate)
    return train_state.TrainState.create(apply_fn=model.apply, params=params, tx=tx)

def train_epoch(state, trainloader, loss_fn, rng=None):
    """Run one pass over the training data, updating the parameters after every batch.

    Args:
        state: Train state providing ``apply_fn``, ``params``, ``step`` and
            ``apply_gradients``.
        trainloader: Iterable of ``(images, labels)`` batches whose elements
            expose ``.numpy()``; ``trainloader.dataset`` must support ``len``.
        loss_fn: Callable ``(logits, labels) -> scalar loss``.
        rng: Optional base PRNG key for dropout. Defaults to ``PRNGKey(0)``.

    Returns:
        ``(state, mean_loss)``: the updated state and the summed per-batch
        loss (weighted by batch size) divided by the dataset size.
    """
    if rng is None:
        rng = jax.random.PRNGKey(0)

    train_loss = 0
    for images, labels in trainloader:
        images = jnp.array(images.numpy())
        labels = jnp.array(labels.numpy())

        # Dropout needs its own RNG stream in training mode. Folding in the
        # optimiser step gives every update a distinct key, across epochs too,
        # without the caller having to thread keys through.
        dropout_rng = jax.random.fold_in(rng, state.step)

        def loss_fn_wrapper(params):
            logits = state.apply_fn(
                {'params': params},
                images,
                train=True,
                rngs={'dropout': dropout_rng},
            )
            return loss_fn(logits, labels)

        loss, grads = jax.value_and_grad(loss_fn_wrapper)(state.params)
        state = state.apply_gradients(grads=grads)
        # Weight by batch size so a smaller final batch is averaged correctly.
        train_loss += loss * images.shape[0]

    return state, train_loss / len(trainloader.dataset)

def evaluate(state, testloader, loss_fn):
    """Score the current parameters on the test data with dropout disabled.

    Args:
        state: Train state providing ``apply_fn`` and ``params``.
        testloader: Iterable of ``(images, labels)`` batches whose elements
            expose ``.numpy()``; ``testloader.dataset`` must support ``len``.
        loss_fn: Callable ``(logits, labels) -> scalar loss``.

    Returns:
        ``(accuracy, avg_loss)``, both normalised by the dataset size.
    """
    variables = {'params': state.params}
    loss_sum, num_correct = 0, 0

    for image_batch, label_batch in testloader:
        inputs = jnp.array(image_batch.numpy())
        targets = jnp.array(label_batch.numpy())
        outputs = state.apply_fn(variables, inputs, train=False)
        # Weight the batch loss by its size before summing.
        loss_sum += loss_fn(outputs, targets) * inputs.shape[0]
        predictions = jnp.argmax(outputs, axis=1)
        num_correct += (predictions == targets).sum()

    dataset_size = len(testloader.dataset)
    return num_correct / dataset_size, loss_sum / dataset_size

# Loss function
def cross_entropy_loss(logits, labels):
    """Mean softmax cross-entropy between logits and integer class labels.

    Args:
        logits: Array of shape (batch, num_classes) with unnormalised scores.
        labels: Integer class indices of shape (batch,).

    Returns:
        Scalar mean negative log-likelihood over the batch.
    """
    # Take the class count from the logits themselves rather than from the
    # module-level config, so the loss always matches the model's output width.
    num_classes = logits.shape[-1]
    one_hot_labels = jax.nn.one_hot(labels, num_classes)
    log_probs = jax.nn.log_softmax(logits)
    return -jnp.mean(jnp.sum(one_hot_labels * log_probs, axis=-1))

# Plot metrics
def plot_metrics(train_losses, test_losses, accuracies):
    """Show per-epoch loss curves (left panel) and accuracy (right panel) in one figure.

    Args:
        train_losses: Training loss per epoch.
        test_losses: Test loss per epoch.
        accuracies: Accuracy per epoch.
    """
    # Epochs are numbered from 1 on the x axis.
    epoch_axis = range(1, len(train_losses) + 1)

    plt.figure(figsize=(12, 6))

    # Left panel: both loss curves on shared axes.
    plt.subplot(1, 2, 1)
    loss_curves = (('Train Loss', train_losses), ('Test Loss', test_losses))
    for curve_label, values in loss_curves:
        plt.plot(epoch_axis, values, label=curve_label)
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('Train and Test Loss vs Epoch')
    plt.legend()

    # Right panel: accuracy.
    plt.subplot(1, 2, 2)
    plt.plot(epoch_axis, accuracies, label='Accuracy', color='orange')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.title('Accuracy vs Epoch')
    plt.legend()

    plt.tight_layout()
    plt.show()

# Main function
def main():
    """Train the ViT on MNIST, print metrics each epoch, optionally checkpoint, then plot.

    Reads the module-level hyperparameters (``batch_size``, ``epochs``, ``lr``,
    ``save_model_every``, ``exp_name``) and ``config``.
    """
    # Imported here because the file only binds `flax.linen` (as `nn`); the
    # bare name `flax` is otherwise undefined when a checkpoint is written.
    from flax import serialization

    trainloader, testloader = prepare_data(batch_size=batch_size)
    model = ViTForClassification(config)
    rng = PRNGKey(0)
    state = create_train_state(rng, model, lr)

    train_losses, test_losses, accuracies = [], [], []
    checkpoint_dir = f"./checkpoints/{exp_name}"

    for epoch in range(epochs):
        state, train_loss = train_epoch(state, trainloader, cross_entropy_loss)
        accuracy, test_loss = evaluate(state, testloader, cross_entropy_loss)

        # Store plain Python floats so the history is independent of device
        # arrays when it is plotted.
        train_losses.append(float(train_loss))
        test_losses.append(float(test_loss))
        accuracies.append(float(accuracy))

        print(f"Epoch {epoch + 1}: Train Loss = {train_loss:.4f}, Test Loss = {test_loss:.4f}, Accuracy = {accuracy:.4f}")

        # save_model_every <= 0 disables checkpointing.
        if save_model_every > 0 and (epoch + 1) % save_model_every == 0:
            os.makedirs(checkpoint_dir, exist_ok=True)
            checkpoint_path = os.path.join(checkpoint_dir, f"epoch_{epoch + 1}.ckpt")
            with open(checkpoint_path, 'wb') as f:
                f.write(serialization.to_bytes(state))
            print(f"Checkpoint saved to {checkpoint_path}")

    plot_metrics(train_losses, test_losses, accuracies)

if __name__ == '__main__':
    main()

Epoch 1: Train Loss = 1.9852, Test Loss = 1.5666, Accuracy = 0.4714
Epoch 2: Train Loss = 1.4530, Test Loss = 1.2301, Accuracy = 0.5857
Epoch 3: Train Loss = 1.2106, Test Loss = 1.0352, Accuracy = 0.6551
Epoch 4: Train Loss = 1.0779, Test Loss = 0.9362, Accuracy = 0.6836
Epoch 5: Train Loss = 0.9892, Test Loss = 0.8592, Accuracy = 0.7137
Epoch 6: Train Loss = 0.9288, Test Loss = 0.8083, Accuracy = 0.7261
Epoch 7: Train Loss = 0.8792, Test Loss = 0.7696, Accuracy = 0.7412
Epoch 8: Train Loss = 0.8355, Test Loss = 0.7331, Accuracy = 0.7536
Epoch 9: Train Loss = 0.8005, Test Loss = 0.7008, Accuracy = 0.7664
