In [23]:
from typing import Iterator

import jax
import jax.numpy as jnp
from flax import nnx
from datasets import load_dataset
from jaxtyping import Array, Float, Int, jaxtyped
from beartype import beartype
from tqdm import tqdm
import optax

In [24]:
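# Debugging toggle: uncommenting the next line replaces nnx.jit with a no-op.
# Decorators resolve nnx.jit when a definition runs, so re-run the cells below
# afterwards to get eagerly executed (print/breakpoint-friendly) layers.
# nnx.jit = lambda fn: fn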

In [25]:
import lovely_jax

# Debugging aid: install lovely_jax's monkey patch (presumably display/repr-only
# for JAX arrays; should not affect numerics — confirm if in doubt).
lovely_jax.monkey_patch()

In [26]:
def typed(fn):
    """Return ``fn`` wrapped with jaxtyping runtime checks, using beartype as the checker.

    Shape/dtype annotations such as ``Float[Array, "batch n d"]`` on the wrapped
    callable are validated on each call.
    """
    checked_fn = jaxtyped(fn, typechecker=beartype)
    return checked_fn

In [27]:
from dataclasses import dataclass


@dataclass
class VITConfig:
    """Hyper-parameters for the ViT model plus the RNG streams used to initialize it.

    Every hyper-parameter is a real dataclass field (it carries a type
    annotation), so it can be overridden at construction, e.g.
    ``VITConfig(rngs=nnx.Rngs(0), num_layers=4)``. Unannotated class attributes
    are not dataclass fields, and the generated ``__init__`` would accept only
    ``rngs``.

    ``rngs`` is declared first because it has no default, and dataclass fields
    without defaults must precede those with defaults. It therefore remains the
    only required (and first positional) argument.

    NOTE(review): modules keep this object as ``self.config``. ``@dataclass``
    with the default ``eq=True`` makes instances unhashable; confirm that nnx
    handles it correctly as static module data under ``nnx.jit``.
    """

    rngs: nnx.Rngs  # RNG streams passed to every submodule initializer
    in_feature_shape: tuple[int, int, int] = (32, 32, 3)  # (h, w, ch); Patchify reads ch = [2]
    out_features: int = 10  # NOTE(review): not read by the visible code; VITClassifier takes num_classes
    patch_size: int = 4  # square patch side; used as conv kernel size and stride
    num_layers: int = 8  # number of EncoderBlocks stacked in Encoder
    num_heads: int = 8  # intended attention head count — confirm AttnBlock consumes it
    embed_dim: int = 256  # token embedding width used throughout the model

In [28]:
class Residual(nnx.Module):
    """Pre-norm residual wrapper computing ``x + module(LayerNorm(x))``.

    Args:
        module: Sub-module applied to the normalized input. Its output must have
            the same shape as its input so it can be added back onto ``x``.
        config: Model config supplying ``embed_dim`` (LayerNorm width) and
            ``rngs``. Optional for backward compatibility: when omitted, it is
            read from ``module.config`` (AttnBlock and MLP both store one).

    Raises:
        ValueError: if no config is given and ``module`` has no ``config`` attribute.
    """

    def __init__(self, module: nnx.Module, *, config: "VITConfig | None" = None):
        if config is None:
            # Take the config from the wrapped module instead of silently
            # depending on a notebook-global `config` variable.
            config = getattr(module, "config", None)
        if config is None:
            raise ValueError(
                "Residual needs a config: pass config=... or wrap a module with a .config"
            )
        self.norm = nnx.LayerNorm(
            num_features=config.embed_dim,
            rngs=config.rngs,
        )
        self.module = module

    @typed
    @nnx.jit
    def __call__(self, x: Float[Array, "batch ..."]) -> Float[Array, "batch ..."]:
        # Normalize only the branch input. The skip path must carry the raw x;
        # otherwise each block replaces its input with LayerNorm(x) and the
        # residual stream is lost.
        return x + self.module(self.norm(x))

In [29]:
class Patchify(nnx.Module):
    """Embed non-overlapping image patches and prepend a learnable [CLS] token.

    A convolution whose kernel size and stride both equal ``patch_size`` maps
    each patch to ``embed_dim`` features. The spatial grid is flattened into a
    token sequence, and the [CLS] token is placed at position 0, which
    ``VITClassifier`` reads out for classification.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        self.conv = nnx.Conv(
            in_features=config.in_feature_shape[2],
            out_features=config.embed_dim,
            kernel_size=(config.patch_size, config.patch_size),
            strides=(config.patch_size, config.patch_size),
            rngs=config.rngs,
        )
        # Learnable [CLS] token shared across the batch. A parameter (rather
        # than re-sampling a constant from a fixed key on every call) so that
        # the optimizer can actually update it.
        self.cls_token = nnx.Param(
            jax.nn.initializers.truncated_normal(stddev=0.02)(
                config.rngs.params(),
                (1, 1, config.embed_dim),
                jnp.float32,
            )
        )

    @typed
    @nnx.jit
    def __call__(
        self, x: Float[Array, "batch h w ch"]
    ) -> Float[Array, "batch patches emb"]:
        x = self.conv(x)
        batch = x.shape[0]
        # (batch, h', w', emb) -> (batch, h' * w', emb)
        x = x.reshape(batch, -1, self.config.embed_dim)
        cls_token = jnp.broadcast_to(
            self.cls_token.value, (batch, 1, self.config.embed_dim)
        ).astype(x.dtype)
        return jnp.concatenate([cls_token, x], axis=1)

In [30]:
@typed
@nnx.jit
def apply_rope(
    q: Float[Array, "batch n d"],
    k: Float[Array, "batch n d"],
) -> tuple[Float[Array, "batch n d"], Float[Array, "batch n d"]]:
    """Placeholder for rotary position embeddings (RoPE) on queries and keys.

    Currently the identity: both inputs are returned unchanged, so the model has
    no positional information yet.
    """
    # TODO: implement rope
    rotated_q, rotated_k = q, k
    return rotated_q, rotated_k

In [31]:
class AttnBlock(nnx.Module):
    """Multi-head self-attention over the patch sequence.

    Projects tokens to q/k/v, splits them into ``num_heads`` heads of width
    ``embed_dim // num_heads``, attends over the patch axis, then merges the
    heads and applies an output projection.

    Raises:
        ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, *, config: VITConfig):
        if config.embed_dim % config.num_heads != 0:
            raise ValueError(
                f"embed_dim ({config.embed_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.config = config
        self.qkv = nnx.Linear(
            in_features=config.embed_dim,
            out_features=config.embed_dim * 3,
            rngs=config.rngs,
        )
        # Mixes the concatenated per-head outputs back into the embedding space.
        self.proj = nnx.Linear(
            in_features=config.embed_dim,
            out_features=config.embed_dim,
            rngs=config.rngs,
        )

    @typed
    @nnx.jit
    def __call__(
        self, x: Float[Array, "batch patches emb"]
    ) -> Float[Array, "batch patches emb"]:
        batch, patches, _ = x.shape
        num_heads = self.config.num_heads
        head_dim = self.config.embed_dim // num_heads

        q, k, v = jnp.split(self.qkv(x), 3, axis=-1)
        q, k = apply_rope(q, k)

        # nnx.dot_product_attention expects [batch..., length, num_heads, head_dim].
        # Passing 3-D (batch, patches, emb) would treat batch as the sequence
        # axis and patches as heads, i.e. attend across samples.
        q, k, v = (t.reshape(batch, patches, num_heads, head_dim) for t in (q, k, v))
        a = nnx.dot_product_attention(q, k, v)
        a = a.reshape(batch, patches, self.config.embed_dim)
        return self.proj(a)

In [32]:
class MLP(nnx.Module):
    """Token-wise feed-forward block: LayerNorm -> Linear(4x) -> GELU -> Linear.

    Expands ``embed_dim`` to ``4 * embed_dim`` and projects back, so the output
    shape equals the input shape.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        # NOTE(review): when wrapped in Residual, which also applies a LayerNorm
        # before calling this module, the input is normalized twice.
        self.norm = nnx.LayerNorm(
            num_features=config.embed_dim,
            rngs=config.rngs,
        )
        self.linear1 = nnx.Linear(
            in_features=config.embed_dim,
            out_features=config.embed_dim * 4,
            rngs=config.rngs,
        )
        self.linear2 = nnx.Linear(
            in_features=config.embed_dim * 4,
            out_features=config.embed_dim,
            rngs=config.rngs,
        )

    @typed
    @nnx.jit
    def __call__(
        self, x: Float[Array, "batch patches emb"]
    ) -> Float[Array, "batch patches emb"]:
        x = self.norm(x)
        x = self.linear1(x)
        # The module defines no `gelu` attribute; use the jax activation directly.
        x = jax.nn.gelu(x)
        x = self.linear2(x)
        return x

In [33]:
class EncoderBlock(nnx.Module):
    """One transformer encoder layer: residual self-attention, then a residual MLP."""

    def __init__(self, *, config: VITConfig):
        self.config = config
        attention = AttnBlock(config=config)
        self.mha = Residual(attention)
        feed_forward = MLP(config=config)
        self.mlp = Residual(feed_forward)

    @typed
    @nnx.jit
    def __call__(
        self, x: Float[Array, "batch patches emb"]
    ) -> Float[Array, "batch patches emb"]:
        # Attention sub-layer first, its result feeds the MLP sub-layer.
        return self.mlp(self.mha(x))

In [34]:
class Encoder(nnx.Module):
    """Stack of ``config.num_layers`` EncoderBlocks applied in sequence."""

    def __init__(self, *, config: VITConfig):
        self.config = config
        self.layers = nnx.Sequential(
            *[EncoderBlock(config=config) for _ in range(config.num_layers)]
        )

    @typed
    @nnx.jit
    def __call__(
        self, x: Float[Array, "batch patches emb"]
    ) -> Float[Array, "batch patches emb"]:
        # Run the tokens through every encoder block; returning x directly
        # would skip the whole transformer.
        return self.layers(x)

In [35]:
class VIT(nnx.Module):
    """Vision Transformer backbone: patch embedding followed by the encoder stack.

    Maps a batch of images to a sequence of token embeddings.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        self.patchify = Patchify(config=config)
        self.encoder = Encoder(config=config)

    @typed
    @nnx.jit
    def __call__(
        self, x: Float[Array, "batch h w ch"]
    ) -> Float[Array, "batch patches emb"]:
        tokens = self.patchify(x)
        return self.encoder(tokens)

In [36]:
class VITClassifier(nnx.Module):
    """Image classifier: ViT backbone plus a linear head on the first token.

    Args:
        config: Model hyper-parameters and RNG streams.
        num_classes: Number of output logits per image.
    """

    def __init__(self, *, config: VITConfig, num_classes: int):
        self.config = config
        self.vit = VIT(config=config)
        self.linear_probe = nnx.Linear(
            in_features=config.embed_dim,
            out_features=num_classes,
            rngs=config.rngs,
        )

    @typed
    @nnx.jit
    def __call__(self, x: Float[Array, "batch h w ch"]) -> Float[Array, "batch c"]:
        tokens = self.vit(x)  # (batch, patches, emb)
        first_token = tokens[:, 0]  # (batch, emb): token at sequence position 0
        return self.linear_probe(first_token)  # (batch, num_classes)

In [37]:
def dataloader(
    X: Float[Array, "n h w ch"], y: Int[Array, "n c"], batch_size: int = 64
) -> Iterator[tuple[Float[Array, "batch h w ch"], Int[Array, "batch c"]]]:
    """Yield aligned ``(X, y)`` mini-batches in order, without shuffling.

    The final batch is smaller when ``len(X)`` is not a multiple of ``batch_size``.
    """
    yield from (
        (X[lo : lo + batch_size], y[lo : lo + batch_size])
        for lo in range(0, len(X), batch_size)
    )

In [38]:
# Keep only 1/SUBSET_DIVISOR of each split so experiments iterate quickly.
SUBSET_DIVISOR = 10

# Reuse an already-loaded dataset when this cell is re-run in the same kernel.
if globals().get("dataset") is None:
    dataset = load_dataset("cifar10")

train_size = len(dataset["train"]) // SUBSET_DIVISOR
val_size = len(dataset["test"]) // SUBSET_DIVISOR

# Derive the class list first so the one-hot width below is not hard-coded.
idx2cls = {i: cls for i, cls in enumerate(dataset["train"].features["label"].names)}
cls2idx = {cls: i for i, cls in idx2cls.items()}
num_classes = len(idx2cls)


def _load_split(split: str, n: int):
    """Return the first ``n`` rows of ``split`` as (images in [0, 1] float32, one-hot int32 labels).

    Slicing the dataset once returns whole columns; this avoids indexing
    row-by-row, which decodes every row separately for each column read.
    """
    rows = dataset[split][:n]
    images = jnp.array(rows["img"]).astype(jnp.float32) / 255.0
    labels = jax.nn.one_hot(
        jnp.array(rows["label"]), num_classes=num_classes
    ).astype(jnp.int32)
    return images, labels


X_train, y_train = _load_split("train", train_size)
X_val, y_val = _load_split("test", val_size)

In [39]:
# Seed for parameter initialization and AdamW learning rate.
SEED = 0
LEARNING_RATE = 3e-4

config = VITConfig(rngs=nnx.Rngs(SEED))
model = VITClassifier(num_classes=num_classes, config=config)
optimizer = nnx.Optimizer(model, optax.adamw(learning_rate=LEARNING_RATE))

In [40]:
def loss_fn(
    model: VITClassifier, X: Float[Array, "batch h w ch"], y: Int[Array, "batch c"]
) -> tuple[Float[Array, ""], Float[Array, "batch c"]]:
    """Mean softmax cross-entropy of ``model(X)`` against one-hot targets ``y``.

    Returns:
        ``(loss, logits)``: the scalar mean loss and the raw logits. The logits
        are auxiliary output, consumed via ``nnx.value_and_grad(..., has_aux=True)``.
    """
    logits = model(X)
    loss = optax.softmax_cross_entropy(logits, y).mean()
    return loss, logits

In [41]:
def accuracy(
    logits: Float[Array, "batch c"], y: Int[Array, "batch c"]
) -> Float[Array, ""]:
    """Fraction of rows whose argmax class in ``logits`` matches the argmax of one-hot ``y``."""
    predicted = jnp.argmax(logits, axis=-1)
    target = jnp.argmax(y, axis=-1)
    return jnp.mean(predicted == target)

In [42]:
@typed
@nnx.jit
def train_step(
    model: VITClassifier,
    X: Float[Array, "batch h w ch"],
    y: Int[Array, "batch c"],
    optimizer: nnx.Optimizer,
    metrics: nnx.Metric,
) -> None:
    """Run one optimization step on a batch.

    Computes loss, logits and gradients, records loss/accuracy in ``metrics``,
    then applies the gradients through ``optimizer`` (mutating model state).
    """
    loss_and_grad = nnx.value_and_grad(loss_fn, has_aux=True)
    (batch_loss, batch_logits), param_grads = loss_and_grad(model, X, y)
    metrics.update(loss=batch_loss, accuracy=accuracy(batch_logits, y))
    optimizer.update(param_grads)

In [43]:
@typed
@nnx.jit
def val_step(
    model: VITClassifier,
    X: Float[Array, "batch h w ch"],
    y: Int[Array, "batch c"],
    metrics: nnx.Metric,
) -> None:
    """Evaluate one batch and accumulate its loss/accuracy into ``metrics``.

    Forward pass only. Evaluation needs no gradients, so none are computed
    (building ``value_and_grad`` here would run an unused backward pass).
    """
    loss, logits = loss_fn(model, X, y)
    metrics.update(loss=loss, accuracy=accuracy(logits, y))

In [None]:
EPOCHS = 100
BATCH_SIZE = 64
SHUFFLE_SEED = 0

metrics = nnx.metrics.MultiMetric(
    accuracy=nnx.metrics.Average("accuracy"), loss=nnx.metrics.Average("loss")
)
shuffle_key = jax.random.key(SHUFFLE_SEED)
pbar = tqdm(range(EPOCHS))
for epoch in pbar:
    # Reshuffle the training set every epoch (deterministically via the seed)
    # so the model does not see identical batches in the same order each time.
    shuffle_key, perm_key = jax.random.split(shuffle_key)
    perm = jax.random.permutation(perm_key, len(X_train))

    metrics.reset()
    for X, y in dataloader(X_train[perm], y_train[perm], batch_size=BATCH_SIZE):
        train_step(model, X, y, optimizer, metrics)
    m = metrics.compute()
    train_loss = float(m["loss"])
    train_acc = float(m["accuracy"])

    # Validation order does not matter for the averaged metrics; no shuffle needed.
    metrics.reset()
    for X, y in dataloader(X_val, y_val, batch_size=BATCH_SIZE):
        val_step(model, X, y, metrics)
    m = metrics.compute()

    val_loss = float(m["loss"])
    val_acc = float(m["accuracy"])

    pbar.set_description(
        f"Epoch {epoch} train_loss: {train_loss:.4f} train_acc: {train_acc:.4f} val_loss: {val_loss:.4f} val_acc: {val_acc:.4f}"
    )