# Vision Transformer (JAX + FLAX)

In [6]:
import os
import sys
import logging
from PIL import Image
from pathlib import Path
try:
    import flax
except ImportError:
    %pip install flax
    import flax
from flax import nnx
try:
    import jax
except ImportError:
    %pip install jax
    import jax
import jax.numpy as jnp
try:
    import optax
except ImportError:
    %pip install optax
    import optax
from jax.tree_util import tree_map
import tqdm
import torch
from torchvision.transforms import transforms
from torch.utils.data import Dataset, DataLoader, default_collate 
from sklearn.model_selection import train_test_split

## Model Definitions

In [7]:
# Embedding Class
class Embed(nnx.Module):

    # Embedding = Patch Embedding + Cls Token + Pos Embedding
    def __init__(self, config: dict, rngs: nnx.Rngs):
        self.patch_size = config["patch_size"]
        self.embed_dim = config["embed_dim"]
        self.image_height = config["image_size"]
        self.image_width = config["image_size"]
        self.batch_size = config["batch"]
        self.patch_count = (self.image_height * self.image_width) // self.patch_size ** 2
        self.rng = rngs

        self.proj_layer = nnx.Conv(
            in_features=3,
            out_features=self.embed_dim,
            kernel_size=(self.patch_size, self.patch_size),
            strides=self.patch_size,
            rngs=self.rng
        )

        self.cls_token = nnx.Param(
            jax.random.normal(self.rng.params(), (1, 1, self.embed_dim))
        )
        
        self.pos_embedding = nnx.Param(
            jax.random.normal(self.rng.params(), (self.batch_size, self.patch_count + 1, self.embed_dim))
        )

    def __call__(self, x: jnp.array):
        x = self.proj_layer(x)
        x = jnp.reshape(x, (x.shape[0], (x.shape[1] * x.shape[2]), x.shape[3]))
        self.cls_tokens = jnp.tile(self.cls_token, [self.batch_size, 1, 1])
        x = jnp.concatenate([x, self.cls_tokens], axis=1)

        x = x + self.pos_embedding
        return x

# Attention Head
class AttentionHead(nnx.Module):
    def __init__(self, embed_dim: int, attention_head_size: int, bias=True):
        self.embed_dim= embed_dim
        self.attention_head_size = attention_head_size

        # Query, Key and Value Weight matrices
        self.q_w = nnx.Linear(in_features=self.embed_dim, out_features=self.attention_head_size, use_bias=True, rngs=nnx.Rngs(0))
        self.k_w = nnx.Linear(in_features=self.embed_dim, out_features=self.attention_head_size, use_bias=True, rngs=nnx.Rngs(0))
        self.v_w = nnx.Linear(in_features=self.embed_dim, out_features=self.attention_head_size, use_bias=True, rngs=nnx.Rngs(0))

    def __call__(self, x: jnp.array):
        q_x, k_x, v_x = self.q_w(x), self.k_w(x), self.v_w(x)

        # Calculate QK^T/sqrt(dk)
        attn_out = jnp.matmul(q_x, jnp.matrix_transpose(k_x)) / jnp.sqrt(self.attention_head_size)
        # Apply softmax
        softmax_out = nnx.softmax(attn_out)
        # Obtain the attention value with value
        attn_value = jnp.matmul(softmax_out, v_x)

        return attn_value

# Multiheadattention
class MultiHeadAttention(nnx.Module):
    def __init__(self, config: dict):
        self.embed_dim = config["embed_dim"]
        self.num_of_heads = config["num_of_heads"]

        self.attn_head_size = self.embed_dim // self.num_of_heads
        self.all_head_size = self.attn_head_size * self.num_of_heads

        self.heads = nnx.List()

        for _ in range(self.num_of_heads):
            self.attn_head = AttentionHead(
                embed_dim=self.embed_dim,
                attention_head_size=self.attn_head_size
            )
            self.heads.append(self.attn_head)
        
        self.linear_proj = nnx.Linear(in_features=self.all_head_size, out_features=self.embed_dim, use_bias=True, rngs=nnx.Rngs(0))
        self.dropout = nnx.Dropout(0.3, rngs=nnx.Rngs(0))

    def __call__(self, x: jnp.array):
        attn_outputs = [head(x) for head in self.heads]

        concat_output = jnp.concatenate(attn_outputs, axis=-1)
        proj_output = self.linear_proj(concat_output)
        proj_output = self.dropout(x)

        return proj_output

# MLP class
class MLP(nnx.Module):
    def __init__(self, config: dict):
        self.embed_dim = config["embed_dim"]
        self.intermediate_size = config["intermediate_size"]

        self.linear1 = nnx.Linear(in_features=self.embed_dim, out_features=self.intermediate_size, use_bias=True, rngs=nnx.Rngs(0))
        self.linear2= nnx.Linear(in_features=self.intermediate_size, out_features=self.embed_dim, use_bias=True, rngs=nnx.Rngs(0))
        self.dropout = nnx.Dropout(0.3, rngs=nnx.Rngs(0))

    def __call__(self, x: jnp.array):
        x = self.linear1(x)
        x = nnx.relu(x)
        x = self.linear2(x)
        x = self.dropout(x)

        return x

# Single block
class Block(nnx.Module):
    def __init__(self, config: dict):
        self.config = config
        self.embed_dim = config["embed_dim"]
        self.norm = nnx.BatchNorm(num_features=self.embed_dim, rngs=nnx.Rngs(0))
        self.mha = MultiHeadAttention(self.config)
        self.mlp = MLP(self.config)

    def __call__(self, x: jnp.array):
        norm_out = self.norm(x)
        attn_out = self.mha(x)
        attn_out = x + attn_out
        norm_out = self.norm(attn_out)
        mlp_out = self.mlp(norm_out)
        block_out = mlp_out + norm_out

        return block_out

# Encoder
class Encoder(nnx.Module):
    def __init__(self, config: dict):
        self.config = config
        self.num_of_blocks = config["num_of_blocks"]

        self.blocks = nnx.List()

        for _ in range(self.num_of_blocks):
            block = Block(self.config)
            self.blocks.append(block)


    def __call__(self, x: jnp.array):
        all_attns = []

        for block in self.blocks:
            x = block(x)
        return x

# ViT Classifier
class ViT(nnx.Module):
    def __init__(self, config: dict):
        self.embed_layer = Embed(config, nnx.Rngs(0))
        self.encoder = Encoder(config)
        self.classifier = nnx.Linear(in_features=config["embed_dim"], out_features=config["no_of_classes"], use_bias=True, rngs=nnx.Rngs(0))

    def __call__(self, x: jnp.array):
        x = self.embed_layer(x)
        x = self.encoder(x)
        x = self.classifier(x[:, 0, :])

        return x

## PyTorch Dataset

In [8]:
class FolderDataset(Dataset):
    """
        A Custom dataset class inheriting the PyTorch Dataset Class.
        Receives the the master directory as input which contains the images of different classes in different folders.
    """
    def __init__(self, parent_dir: str, image_size: tuple, resize: bool, train: bool, train_size: float):
        self.parent_dir = parent_dir
        self.paths = list(Path(parent_dir).glob("*/*.jpg"))
        self.train = train
        self.train_size = train_size
        self.resize = resize
        self.image_size = image_size

        self.transform = transforms.Compose([
            transforms.PILToTensor(),
            transforms.Lambda(lambda x: x.permute(1, 2, 0))
        ])

        if self.train:
            self.paths = self.paths[:int(self.train_size * len(self.paths))]
        else:
            self.paths = self.paths[int(self.train_size * len(self.paths)):]

        self.classes, self.class_to_idx = self.get_class_class_idx()
    

    def get_class_class_idx(self):
        """
            Takes in the master directory that contains the subdirectories with the images.
        """
        classes = os.listdir(self.parent_dir)
        class_idx = {}

        for i, c in enumerate(classes):
            class_idx[c] = i
        
        return classes, class_idx

    def load_image(self, index: int):
        image_path = self.paths[index]
        image = Image.open(image_path)

        if self.resize:
            image = image.resize((self.image_size))
    
        return self.transform(image)

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, index: int):
        image = self.load_image(index)
        class_name = self.paths[index].parent.name
        class_idx = torch.tensor(self.class_to_idx[class_name], dtype=torch.int32)

        return image, class_idx



## Training functions

In [11]:
def loss_fn(model: nnx.Module, images: jnp.array, labels: jnp.array):
    logits = model(images)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, logits

@nnx.jit
def train_step(model: nnx.Module, images: jnp.array, labels: jnp.array, optim: nnx.Optimizer):
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    
    optim.update(model, grads)
    
    return loss

@nnx.jit
def eval_step(model: nnx.Module, images: jnp.array, labels: jnp.array, eval_metrics: nnx.MultiMetric):
    loss, logits = loss_fn(model, images, labels)
    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels
    )

def train_one_epoch(model: nnx.Module, train_dataloader: DataLoader, epoch: int, epochs: int, optim: nnx.Optimizer, train_history: dict):
    tqdm_bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"
    total_steps = len(train_dataloader)
    
    model.train()

    with tqdm.tqdm(
        desc=f"[Train] epoch: {epoch} / {epochs}, ",
        total=total_steps,
        bar_format=tqdm_bar_format,
        leave=True
    ) as pbar:
        for images, labels in train_dataloader:
            loss = train_step(model, images, labels, optim)
            train_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)

def eval_model(model: nnx.Module, val_dataloader: DataLoader, epoch: int, eval_metrics: nnx.MultiMetric, val_history: dict):
    model.eval()
    eval_metrics.reset() 

    for val_images, val_labels in val_dataloader:
        loss = eval_step(model, val_images, val_labels, eval_metrics)

    for metric, value in eval_metrics.compute().items():
        val_history[f"val_{metric}"].append(value)
    
    print(f"[Val] epoch: {epoch + 1} / {epochs}")
    print(f"Loss: {val_history['val_loss'][-1]: .4f}")
    print(f"Accuracy: {val_history['val_accuracy'][-1]: .4f}")


In [12]:
if __name__ == "__main__":
    
    # Custom collate function
    def numpy_collate(batch):
        return tree_map(jnp.asarray, default_collate(batch))

    # Create Dataset
    train_data = FolderDataset(
        parent_dir="C:/Users/Kumar/Desktop/csiro-biomass/dataset",
        train=True,
        train_size=0.8,
        resize=True,
        image_size=(224,224),
    )
    val_data = FolderDataset(
        parent_dir="C:/Users/Kumar/Desktop/csiro-biomass/dataset",
        train=False,
        train_size=0.8,
        resize=True,
        image_size=(224,224),
    )

    # Create Dataloaders
    train_dataloader = DataLoader(
        train_data,
        batch_size=16,
        shuffle=True,
        collate_fn=numpy_collate
    )

    val_dataloader = DataLoader(
        val_data,
        batch_size=16,
        shuffle=False,
        collate_fn=numpy_collate
    )

    # Misc Settings
    

    lr_0 = 0.001
    lr_f = 1e-5
    momemtum = 0.8
    epochs = 100
    total_steps = len(train_data)
    
    config_dict = {
        "patch_size": 16,
        "embed_dim": 128,
        "image_size": 224,
        "batch": 16,
        "no_of_classes": 10,
        "num_of_blocks": 6,
        "num_of_heads": 12,
        "intermediate_size": 4 * 128,
    }

    train_history = {
        "train_loss": [], 
    }

    val_history = {
        "val_loss": [],
        "val_accuracy": []
    }
    
    # Model Settings
    model = ViT(config_dict)

    lr_schedule = optax.linear_schedule(lr_0, lr_f, epochs * total_steps)

    optim = nnx.Optimizer(
        model, optax.adam(lr_schedule), wrt=nnx.Param
    )

    eval_metrics = nnx.MultiMetric(
        loss = nnx.metrics.Average("loss"),
        accuracy = nnx.metrics.Accuracy(),
    )
    
    # Training loop
    for epoch in range(epochs):
        train_one_epoch(model, train_dataloader, epoch, epochs, optim, train_history)
        eval_model(model, val_dataloader, epoch, eval_metrics, val_history)

[Train] epoch: 0 / 100, [12/18], loss=0.000819 [00:23<00:11]


KeyboardInterrupt: 

## Legacy code

In [None]:
# Define the loss function
def loss_fn(model: nnx.Module, images: jnp.array, labels: jnp.array):
    logits = model(images)
    loss = optax.losses.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
    return loss, logits

# Define Training and validation step
def train_step(model: nnx.Module, optim: nnx.Optimizer, images: jax.Array, labels: jax.Array):

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, images, labels)
    
    optim.update(model, grads)
    return loss

def eval_step(model: nnx.Module, images: jax.Array, labels: jax.Array, eval_metrics: nnx.MultiMetric):
    
    loss, logits = loss_fn(model, images, labels)
    eval_metrics.update(
        loss=loss,
        logits=logits,
        labels=labels
    )

train_step = nnx.jit(train_step)
eval_step = nnx.jit(eval_step)

In [6]:
eval_metrics = nnx.MultiMetric(
    loss = nnx.metrics.Average("loss"),
    accuracy = nnx.metrics.Accuracy(),
)

train_history = {
    "train_loss": [], 
}

val_history = {
    "val_loss": [],
    "val_accuracy": []
}

In [None]:
# Configure optimizer, loss, metrics

tqdm_bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]"

lr_0 = 0.001
lr_f = 1e-5
momemtum = 0.8
epochs = 100
total_steps = x_train.shape[0] // 16

config_dict = {
    "patch_size": 4,
    "embed_dim": 128,
    "image_size": 32,
    "batch": 16,
    "no_of_classes": 10,
    "num_of_blocks": 6,
    "num_of_heads": 12,
    "intermediate_size": 4 * 128,
}
train_history = {
    "train_loss": [], 
}


model = ViT(config_dict)

lr_schedule = optax.linear_schedule(lr_0, lr_f, epochs * total_steps)

optim = nnx.Optimizer(
    model, optax.adam(lr_schedule), wrt=nnx.Param
)

def train_one_epoch(epoch: int):
    model.train()

    with tqdm.tqdm(
        desc=f"[Train] epoch: {epoch} / {epochs}, ",
        total=total_steps,
        bar_format=tqdm_bar_format,
        leave=True
    ) as pbar:
        for images, labels in next(train_dataloader):
            loss = train_step(model, optim, images, labels)
            train_history["train_loss"].append(loss.item())
            pbar.set_postfix({"loss": loss.item()})
            pbar.update(1)
    
def eval_model(epoch: int):
    model.eval()
    eval_metrics.reset()

    for val_images, val_labels in next(val_dataloader):
        loss = eval_step(model, val_images, val_labels, eval_metrics)
    
    for metric, value in eval_metrics.compute().items():
        val_history[f"val_{metric}"].append(value)

    print(f"[Val] epoch: {epoch + 1} / {epochs}")
    print(f"Loss: {val_history['val_loss'][-1]:0.4f}")
    print(f"Accuracy: {val_history['val_accuracy'][-1]:0.4f}")
        

In [10]:
len(train_dataloader)

625

In [None]:
for epoch in range(epochs):
    train_one_epoch(epoch)
    eval_model(epoch)