In [20]:
from typing import Iterator

import jax
import jax.numpy as jnp
from flax import nnx
from datasets import load_dataset
from jaxtyping import Array, Float, Int, jaxtyped
from beartype import beartype

In [21]:
# monkey_patch() is called purely for its side effect on the running kernel.
# NOTE(review): presumably it replaces the default JAX array repr with a
# compact summary — confirm against the lovely_jax documentation.
import lovely_jax
lovely_jax.monkey_patch()

In [22]:
def typed(fn):
    """Decorate ``fn`` with jaxtyping's ``jaxtyped``, using beartype as the checker.

    Args:
        fn: The callable whose jaxtyping annotations should be checked.

    Returns:
        Whatever ``jaxtyped`` returns for ``fn`` with ``typechecker=beartype``.
    """
    # Configure the decorator first, then apply it to the target callable.
    checked = jaxtyped(typechecker=beartype)
    return checked(fn)

In [23]:
from dataclasses import dataclass


@dataclass
class VITConfig:
    """Hyperparameters and RNG container shared by the ViT modules.

    Every hyperparameter is a real dataclass field, so it can be overridden
    per instance (``VITConfig(rngs=..., embed_dim=128)``) and shows up in the
    generated ``repr``/``eq``. ``rngs`` is the only required field and comes
    first so both ``VITConfig(rngs)`` and ``VITConfig(rngs=...)`` keep working.
    """

    # Source of PRNG keys handed to every parameterised layer at construction.
    rngs: nnx.Rngs
    # Input image shape; index 2 is used as the channel count by Patchify.
    in_feature_shape: tuple[int, int, int] = (32, 32, 3)
    # NOTE(review): not read by any module in this notebook yet — presumably
    # the number of output classes for a future classification head.
    out_features: int = 10
    # Side length of the square patches (conv kernel size and stride).
    patch_size: int = 4
    # Number of encoder blocks stacked in the Encoder.
    num_layers: int = 8
    # Number of attention heads.
    num_heads: int = 8
    # Width of the token embeddings used throughout the model.
    embed_dim: int = 256

In [24]:
class Residual(nnx.Module):
    """Pre-norm residual wrapper computing ``x + module(LayerNorm(x))``.

    Args:
        module: The sub-module to wrap. Its output must have the same shape
            as its input so it can be added back onto the skip path.
        config: Optional configuration supplying ``embed_dim`` and ``rngs``
            for the LayerNorm. When omitted, ``module.config`` is used, so the
            wrapper no longer depends on a module-level ``config`` variable
            happening to exist when it is constructed.

    Raises:
        ValueError: If no ``config`` is passed and ``module`` has no
            ``config`` attribute.
    """

    def __init__(self, module: nnx.Module, *, config: "VITConfig | None" = None):
        if config is None:
            config = getattr(module, "config", None)
        if config is None:
            raise ValueError(
                "Residual needs a config: pass `config=` or wrap a module "
                "that exposes a `config` attribute"
            )
        self.norm = nnx.LayerNorm(
            num_features=config.embed_dim,
            rngs=config.rngs,
        )
        self.module = module

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch ..."]) -> Float[Array, "batch ..."]:
        # The skip connection must carry the *un-normalised* input; only the
        # branch fed to the wrapped module is normalised.
        return x + self.module(self.norm(x))

In [25]:
class Patchify(nnx.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A convolution whose kernel size and stride both equal ``patch_size``
    maps the input to ``embed_dim`` channels; the spatial grid is then
    flattened into a single token axis.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        patch_hw = (config.patch_size, config.patch_size)
        in_channels = config.in_feature_shape[2]
        self.conv = nnx.Conv(
            in_features=in_channels,
            out_features=config.embed_dim,
            kernel_size=patch_hw,
            strides=patch_hw,
            rngs=config.rngs,
        )

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch h w c"]) -> Float[Array, "batch patches emb"]:
        """Embed ``x`` and flatten all non-batch, non-channel axes into one."""
        feature_map = self.conv(x)
        batch_size = feature_map.shape[0]
        # -1 lets reshape infer the number of patches from the conv output.
        return feature_map.reshape(batch_size, -1, self.config.embed_dim)

In [26]:
@typed
@nnx.jit()
def apply_rope(
    q: Float[Array, "batch n d"],
    k: Float[Array, "batch n d"],
) -> tuple[Float[Array, "batch n d"], Float[Array, "batch n d"]]:
    """Placeholder for rotary position embeddings.

    Currently an identity: the query and key tensors are handed back
    untouched, so no positional information is injected here.

    Args:
        q: Query tensor.
        k: Key tensor.

    Returns:
        The ``(q, k)`` pair, unchanged.
    """
    # TODO: implement rope
    return (q, k)

In [27]:
class AttnBlock(nnx.Module):
    """Multi-head self-attention over the patch axis.

    A single linear layer produces the concatenated query/key/value
    projections. They are split into ``num_heads`` heads, attended with
    ``nnx.dot_product_attention``, and merged back to ``embed_dim``.
    There is no output projection after the heads are merged.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, *, config: VITConfig):
        if config.embed_dim % config.num_heads != 0:
            raise ValueError(
                f"embed_dim ({config.embed_dim}) must be divisible by "
                f"num_heads ({config.num_heads})"
            )
        self.config = config
        self.qkv = nnx.Linear(
            in_features=config.embed_dim,
            out_features=config.embed_dim * 3,
            rngs=config.rngs,
        )

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch patches emb"]) -> Float[Array, "batch patches emb"]:
        """Apply self-attention to ``x`` and return a tensor of the same shape."""
        batch, patches, _ = x.shape
        num_heads = self.config.num_heads
        head_dim = self.config.embed_dim // num_heads

        # JAX arrays expose no public `.split` method; use the jnp function.
        q, k, v = jnp.split(self.qkv(x), 3, axis=-1)
        q, k = apply_rope(q, k)

        # dot_product_attention expects [batch, length, num_heads, head_dim].
        # Feeding it rank-3 tensors would make it treat the batch axis as the
        # sequence axis and attend across unrelated images.
        heads_shape = (batch, patches, num_heads, head_dim)
        q = q.reshape(heads_shape)
        k = k.reshape(heads_shape)
        v = v.reshape(heads_shape)

        a = nnx.dot_product_attention(q, k, v)
        # Merge the heads back into a single embedding axis.
        return a.reshape(batch, patches, self.config.embed_dim)

In [28]:
class MLP(nnx.Module):
    """Position-wise feed-forward block: LayerNorm -> Linear -> GELU -> Linear.

    The hidden layer is four times wider than ``embed_dim``; the output is
    projected back to ``embed_dim`` so the block preserves its input shape.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        self.norm = nnx.LayerNorm(
            num_features=config.embed_dim,
            rngs=config.rngs,
        )
        self.linear1 = nnx.Linear(
            in_features=config.embed_dim,
            out_features=config.embed_dim * 4,
            rngs=config.rngs,
        )
        self.linear2 = nnx.Linear(
            in_features=config.embed_dim * 4,
            out_features=config.embed_dim,
            rngs=config.rngs,
        )

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch patches emb"]) -> Float[Array, "batch patches emb"]:
        """Run the feed-forward stack on ``x``."""
        x = self.norm(x)
        x = self.linear1(x)
        # The module never defined a `gelu` attribute, so `self.gelu` raised
        # AttributeError; call the JAX activation function directly.
        x = jax.nn.gelu(x)
        x = self.linear2(x)
        return x

In [29]:
class EncoderBlock(nnx.Module):
    """One transformer encoder layer: attention followed by an MLP.

    Both sub-blocks are wrapped in ``Residual``.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        # Construction order is kept (attention, its wrapper, MLP, its wrapper)
        # because every layer draws its init keys from the shared `config.rngs`.
        attention = AttnBlock(config=config)
        self.mha = Residual(attention)
        feed_forward = MLP(config=config)
        self.mlp = Residual(feed_forward)

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch patches emb"]) -> Float[Array, "batch patches emb"]:
        """Apply the attention sub-block, then the MLP sub-block."""
        attended = self.mha(x)
        return self.mlp(attended)

In [30]:
class Encoder(nnx.Module):
    """Stack of ``config.num_layers`` encoder blocks applied in sequence."""

    def __init__(self, *, config: VITConfig):
        self.config = config
        self.layers = nnx.Sequential(
            *[EncoderBlock(config=config) for _ in range(config.num_layers)]
        )

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch patches emb"]) -> Float[Array, "batch patches emb"]:
        """Pass ``x`` through every encoder block and return the result."""
        # Previously this returned `x` untouched, so the constructed layers
        # were never applied and the encoder was a no-op.
        return self.layers(x)

In [31]:
class VIT(nnx.Module):
    """Vision Transformer backbone: patch embedding followed by the encoder.

    Returns per-patch embeddings; no pooling or classification head is
    applied here.
    """

    def __init__(self, *, config: VITConfig):
        self.config = config
        self.patchify = Patchify(config=config)
        self.encoder = Encoder(config=config)

    @typed
    @nnx.jit()
    def __call__(self, x: Float[Array, "batch h w c"]) -> Float[Array, "batch patches emb"]:
        """Embed the image batch into patch tokens and encode them."""
        tokens = self.patchify(x)
        return self.encoder(tokens)

In [32]:
def dataloader(
    X: Float[Array, "n h w c"], y: Int[Array, "n c"], batch_size: int = 64
) -> Iterator[tuple[Float[Array, "batch h w c"], Int[Array, "batch c"]]]:
    """Yield consecutive ``(X, y)`` mini-batches in their original order.

    Both inputs are sliced with the same window, stepping by ``batch_size``
    over ``len(X)``. No shuffling is performed, and the final batch is
    whatever remains, so it can be shorter than ``batch_size``.

    Args:
        X: Inputs, indexed along the first axis.
        y: Targets, sliced with the same indices as ``X``.
        batch_size: Step between the start of successive batches.

    Yields:
        Tuples ``(X[start:stop], y[start:stop])``.
    """
    for start in range(0, len(X), batch_size):
        window = slice(start, start + batch_size)
        yield X[window], y[window]

In [33]:
dataset = load_dataset("cifar10")

# Materialise the first 1024 training examples as arrays.
X_labels = jnp.array([dataset["train"][idx]["label"] for idx in range(1024)])
X_images = jnp.array([dataset["train"][idx]["img"] for idx in range(1024)])

# Divide pixel values by 255 and store them as float32.
X_images = (X_images / 255.0).astype(jnp.float32)

In [34]:
# Build the configuration around an RNG container seeded with 0, then
# construct the model from it.
config = VITConfig(rngs=nnx.Rngs(0))
vit = VIT(config=config)