In [1]:
import jax
from jax import lax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import flax

import optax
import einops

In [2]:
# Show the devices JAX can see, to confirm which backend the notebook runs on.
jax.devices()

[StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)]

In [3]:
from ViT import *
from transformer_attention import MSALayerConfig

In [4]:
# Dimensions of the synthetic inputs: B is used later for the batched input,
# and (H, W, C) is the shape of a single sample.
B = 50
H, W = 16, 16
C = 3

# One seed, split into two independent keys: key1 samples the input below and
# key2 is kept for model initialisation.
key1, key2 = random.split(random.PRNGKey(0))

# A single un-batched input with values drawn uniformly from [0, 1).
x = random.uniform(key1, shape=(H, W, C))
x.shape


(16, 16, 3)

In [5]:
# Hyperparameters passed to the ViT constructor.

# Patching / embedding.
patch_size = 4
embedding_dim = 256
# H // patch_size is the patch count along one image side (16 // 4 = 4).
# NOTE(review): the first shape printed during model.init is (16, 48), i.e. 16
# rows -- confirm whether ViT expects the per-side count or the total count.
num_patches = H // patch_size

# Encoder stack.
n_layer = 4
n_heads = 4
feedforward_dim = 1024

# Output size.
n_classes = 10

In [6]:
# Build the model from the hyperparameters defined above. Arguments are passed
# positionally, in the same order as before.
model = ViT(
    patch_size,
    embedding_dim,
    num_patches,
    n_classes,
    n_heads,
    n_layer,
    feedforward_dim,
)

In [7]:
# Initialise the model variables, using key2 as the RNG and the single
# un-batched sample x as the example input.
params = model.init(key2, x)
# Display the names of the top-level parameter collections.
params["params"].keys()

(16, 48)
(16, 256)
(17, 256)
(17, 256)
(17, 256)
(17, 256)
(10,)


frozen_dict_keys(['cls_token', 'position_encoding', 'patch_projection', 'transformer_encoder', 'norm', 'head'])

In [8]:
# Forward pass on the single un-batched sample.
y = model.apply(params, x)
# The output is expected to hold one value per class.
assert y.shape == (n_classes,)

(16, 48)
(16, 256)
(17, 256)
(17, 256)
(17, 256)
(17, 256)
(10,)


In [9]:
# Un-jitted batched forward pass: the first argument (params) is shared across
# the batch (in_axes None) and the second is mapped over its leading axis.
batched_apply = jax.vmap(model.apply, in_axes=(None, 0))

@jax.jit
def batched_apply2(params, v_batched):
    """Jit-compiled batched forward pass.

    Args:
        params: Model variables, passed unchanged to every example.
        v_batched: Inputs stacked along a leading batch axis.

    Returns:
        The result of ``model.apply`` for each example, stacked along the
        leading axis.
    """
    vmapped_apply = jax.vmap(model.apply, in_axes=(None, 0))
    return vmapped_apply(params, v_batched)

In [10]:
# Draw the batch from its own PRNG key. key1 was already consumed to sample
# `x`, and JAX keys should not be reused for a second sampling call.
batch_key = random.PRNGKey(1)
# B samples of shape (H, W, C), values uniform in [0, 1).
batched_x = random.uniform(batch_key, (B, H, W, C))

In [11]:
# Run the vmapped (un-jitted) forward pass over the whole batch.
batched_y = batched_apply(params, batched_x)
# One row of class outputs is expected per batch element.
assert batched_y.shape == (B, n_classes)

(16, 48)
(16, 256)
(17, 256)
(17, 256)
(17, 256)
(17, 256)
(10,)
