In [32]:
import jax
from jax import random
import jax.numpy as jnp
from jax import jit, value_and_grad
from einops import rearrange
from tqdm import tqdm

In [2]:
from torchvision.datasets import CIFAR10, Imagenette
from torchvision import transforms
from torch.utils.data import DataLoader

In [3]:
image_size = 64
patch_size = 4
num_patches = (image_size // patch_size) ** 2

num_layers = 4      # number of transfomer layers
hidden_dim = 192    # hidden dimension of each token
mlp_dim = 192*4     # hidden dimension in the MLP

num_classes = 10    # Imagenette number of classes
num_heads = 4       # attention heads
head_dim = hidden_dim//num_heads

In [4]:
key = random.PRNGKey(42)

In [5]:
# initialize vit parameters
vit_parameters = {
        'patch_embed': None,
        'positional_encoding': None,
        'layers': [],
        'final_layer_norm': None,
        'head': [],
        'cls_token': None
}

In [6]:
# for the class token, we just need a single vector of the same size as a token
cls_token = jnp.zeros((1,hidden_dim))
vit_parameters['cls_token'] = cls_token

In [7]:
# for the patch embedding, we need to consider each patch and 3 channels and project it into the hidden dimension
patch_embed_key, key = random.split(key)
patch_embed = random.normal(patch_embed_key, ((3 * patch_size * patch_size), hidden_dim)) * jnp.sqrt(2.0 / (hidden_dim))
vit_parameters['patch_embed'] = patch_embed

In [8]:
# the positional encoding is just a value we add to each patch in the image
pos_enc_key, key = random.split(key)
pos_enc = random.normal(pos_enc_key, (num_patches,  hidden_dim)) * 0.02
vit_parameters['positional_encoding'] = pos_enc

In [9]:
# The head will consider only the class token and project it into the number of classes
head_key, key = random.split(key)
head_params = random.normal(head_key, (hidden_dim, num_classes)) * jnp.sqrt(6.0 / hidden_dim)
head_bias = jnp.zeros(num_classes)
vit_parameters['head'] = (head_params, head_bias)

In [10]:
def initialize_mlp(hidden_dim, mlp_dim, key):
    w1_key, w2_key = random.split(key)

    # Xavier uniform limit for w1 and w2
    limit = jnp.sqrt(6.0 / (hidden_dim + mlp_dim))

    # Xavier uniform initialization for weights
    w1 = random.uniform(w1_key, (hidden_dim, mlp_dim), minval=-limit, maxval=limit)
    b1 = jnp.zeros(mlp_dim)

    w2 = random.uniform(w2_key, (mlp_dim, hidden_dim), minval=-limit, maxval=limit)
    b2 = jnp.zeros(hidden_dim)

    return w1, b1, w2, b2

In [11]:
def initialize_attention(hidden_dim, num_heads, key):
    q_key, k_key, v_key = random.split(key, 3)

    # Limit for Xavier uniform
    fan_in = hidden_dim
    fan_out = head_dim * num_heads
    limit = jnp.sqrt(6.0 / (fan_in + fan_out))

    # Random weights from uniform distribution
    q_w = random.uniform(q_key, (fan_in, fan_out), minval=-limit, maxval=limit)
    q_b = jnp.zeros(fan_out)
    k_w = random.uniform(k_key, (fan_in, fan_out), minval=-limit, maxval=limit)
    k_b = jnp.zeros(fan_out)
    v_w = random.uniform(v_key, (fan_in, fan_out), minval=-limit, maxval=limit)
    v_b = jnp.zeros(fan_out)

    return q_w, k_w, v_w, q_b, k_b, v_b

In [12]:
def initialize_layer_norm(hidden_dim):
    gamma = jnp.ones(hidden_dim)
    beta = jnp.zeros(hidden_dim)
    return gamma, beta

In [13]:
key, *layer_keys = random.split(key, num_layers+1)

for i in range(num_layers):
    mlp_params = initialize_mlp(hidden_dim, mlp_dim, layer_keys[i])
    attn_params = initialize_attention(hidden_dim, num_heads, layer_keys[i])
    ln1_params = initialize_layer_norm(hidden_dim)
    ln2_params = initialize_layer_norm(hidden_dim)
    vit_parameters['layers'].append((mlp_params, attn_params, ln1_params, ln2_params))



# we also have a final layer norm outside the loop
final_layer_norm_key, key = random.split(key)
final_layer_norm_params = initialize_layer_norm(hidden_dim)
vit_parameters['final_layer_norm'] = final_layer_norm_params



In [14]:
vit_parameters['layers']

[((Array([[-0.06414359,  0.02639334, -0.07137322, ..., -0.0058764 ,
            0.02678854,  0.00356656],
          [ 0.01247443, -0.07389854,  0.03269812, ..., -0.07085859,
           -0.07476959, -0.02115796],
          [ 0.03721791,  0.00462557, -0.06657642, ...,  0.06412043,
           -0.01957211, -0.07491516],
          ...,
          [ 0.02423331, -0.00456774, -0.06062945, ...,  0.03453164,
            0.06428294, -0.00506742],
          [ 0.0270672 , -0.02820023,  0.04839184, ...,  0.0513372 ,
           -0.07723889, -0.00386329],
          [ 0.03489704, -0.00888659,  0.00923372, ..., -0.0507099 ,
           -0.07051051,  0.03970635]], dtype=float32),
   Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0.

In [15]:
# First, some utility functions: a relu and a softmax

def relu(input):
    return jnp.maximum(0, input)


def softmax(x, axis=-1):
    x_max = jnp.max(x, axis=axis, keepdims=True)
    x_shifted = x - x_max
    exp_x = jnp.exp(x_shifted)
    return exp_x / jnp.sum(exp_x, axis=axis, keepdims=True)

In [16]:
def mlp(x, mlp_params):

    # unpack the parameters
    w1, b1, w2, b2 = mlp_params

    # out = (Relu(x*w1 + b1))*w2 + b2
    up_proj = relu(jnp.matmul(x, w1) + b1)
    down_proj = jnp.matmul(up_proj, w2) + b2

    return down_proj

In [17]:
def self_attention(x, attn_params):

    # unpack the parameters
    q_w, k_w, v_w, q_b, k_b, v_b = attn_params

    # n and d_k are the sequence length of the input and the hidden dimension
    n, d_k = x.shape

    # project the input into the query, key and value spaces
    q = jnp.matmul(x, q_w) + q_b
    k = jnp.matmul(x, k_w) + k_b
    v = jnp.matmul(x, v_w) + v_b


    # reshape to have heads
    # n, (num_heads head_dim) ->  (n, num_heads, headim) -> (num_heads, n, head_dim)
    q = q.reshape(n, num_heads, head_dim).swapaxes(0, 1)
    k = k.reshape(n, num_heads, head_dim).swapaxes(0, 1)
    v = v.reshape(n, num_heads, head_dim).swapaxes(0, 1)

    # perform multi-head attention
    attention_weights_heads = jnp.matmul(q, jnp.swapaxes(k, -1, -2)) / jnp.sqrt(head_dim)
    attention_weights_heads = jax.nn.softmax(attention_weights_heads, axis=-1)

    # output projection (num_heads, n, head_dim)
    output = jnp.matmul(attention_weights_heads, v)

    # reshape back (n, num_heads * heam_dim)
    output = output.swapaxes(0,1).reshape(n, d_k)

    return output


In [18]:
def layer_norm(x, layernorm_params):
    # a simple layer norm
    gamma, beta = layernorm_params
    mean = jnp.mean(x, axis=-1, keepdims=True)
    var = jnp.var(x, axis=-1, keepdims=True)
    return gamma * (x - mean) / jnp.sqrt(var + 1e-6) + beta

In [19]:
def transformer_block(inp, block_params):

    # unpack the parameters
    mlp_params, attn_params, ln1_params, ln2_params = block_params

    # attention
    x = layer_norm(inp, ln1_params)
    x = self_attention(x, attn_params)
    skip = x + inp

    # mlp
    x = layer_norm(skip, ln2_params)
    x = mlp(x, mlp_params)
    x = x + skip

    return x

In [20]:
def transformer(patches, vit_parameters):

    # reshape image from c,h,w -> num_patches, patch_size*patch_size
    patches = rearrange (patches, 'c (h p1) (w p2) -> (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)

    # embed the patches
    patches = jnp.matmul(patches, vit_parameters['patch_embed'])

    # add positional encoding
    patches = patches + vit_parameters['positional_encoding']

    # append class token to sequence
    cls_token = vit_parameters['cls_token']
    patches = jnp.concatenate([cls_token, patches], axis=0)


    # forward through all transformer blocks
    for layer, block_params in enumerate(vit_parameters['layers']):
        patches = transformer_block(patches, block_params)

    # final layer norm
    patches = layer_norm(patches, vit_parameters['final_layer_norm'])

    # get the class token and apply the final head
    patches = patches[0, :]
    logits = jnp.matmul(patches, vit_parameters['head'][0]) + vit_parameters['head'][1]
    return logits

In [21]:
sample_image = random.normal(key, (3 ,image_size, image_size))
prediction = transformer(sample_image, vit_parameters)
print("Output shape:", prediction.shape) # should be (num_classes,)


Output shape: (10,)


In [22]:
def cross_entropy_loss(patches, vit_parameters, ground_truth):
    prediction = jax.vmap(transformer, in_axes=(0, None))(patches, vit_parameters)
    logs = jax.nn.log_softmax(prediction)
    l = -jnp.mean(jnp.sum(ground_truth * logs, axis=-1))
    return l

In [25]:

mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]


train_dataset = Imagenette(
        root='data',
        size="160px",
        split='train',
        download=True,
        transform=transforms.Compose([transforms.Resize((image_size,image_size)),  transforms.ToTensor(), transforms.Normalize(mean, std)])
)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True)


test_dataset = Imagenette(
        root='data',
        size="160px",
        split='val',
        download=False,
        transform=transforms.Compose([transforms.Resize((image_size,image_size)), transforms.ToTensor(), transforms.Normalize(mean, std)])
)
test_loader = DataLoader(test_dataset, batch_size=256, shuffle=False)

100%|██████████| 99.0M/99.0M [01:52<00:00, 884kB/s] 


In [26]:
from tqdm import tqdm


def eval(vit_parameters):

    correct = 0

    for(img, target) in tqdm(test_loader, desc="Eval", unit="item"):

        img = jnp.asarray(img, dtype=jnp.float32)
        target = jnp.asarray(target)

        logits = jax.vmap(transformer, in_axes=(0, None))(img, vit_parameters)
        prediction = jnp.argmax(logits, axis=-1)
        correct += jnp.sum(prediction == target).item()


    acc = correct / len(test_dataset)

    return acc

accuracy = eval(vit_parameters)
print("Accuracy before training", accuracy)

Eval: 100%|██████████| 16/16 [00:31<00:00,  1.97s/item]

Accuracy before training 0.06777070063694267





In [29]:
bsize = 5
# fake labels and images
sample_images = random.normal(key, (bsize, 3 ,image_size, image_size))
sample_target = jnp.zeros((bsize, 10)).at[0, 1].set(1)
current_loss, grads = value_and_grad(cross_entropy_loss, argnums=1)(sample_images, vit_parameters, sample_target)

print("Current loss:", current_loss)
print("Gradients:", grads.keys())

Current loss: 0.6300481
Gradients: dict_keys(['cls_token', 'final_layer_norm', 'head', 'layers', 'patch_embed', 'positional_encoding'])


In [31]:
@jit
def train_step(patches, vit_parameters, target):
    # compute gradients
    current_loss, grads = value_and_grad(cross_entropy_loss, argnums=1)(
            patches,
            vit_parameters,
            target)

    # update parameters
    updated_params = jax.tree.map(lambda p, g: p - 0.01 * g, vit_parameters, grads)

    return current_loss, updated_params

In [33]:
num_epochs = 20


for epoch in range(num_epochs):

    progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{num_epochs}")
    #for (data, target) in tqdm(train_loader, desc=f'Train epoch {epoch}'):
    for i, (data, target) in progress_bar:

        # convert to numpy
        data = jnp.asarray(data)
        target = jnp.asarray(target)

        # reshape and get one hot fot loss
        target_one_hot = jax.nn.one_hot(target, num_classes)

        current_loss, vit_parameters = train_step(data, vit_parameters, target_one_hot)

        progress_bar.set_postfix({'loss': current_loss})


    eval_acc = eval(vit_parameters)
    print(f'Epoch: {epoch}, Eval acc: {eval_acc}')

Epoch 1/20: 100%|██████████| 37/37 [04:59<00:00,  8.08s/it, loss=1.9454371]
Eval: 100%|██████████| 16/16 [00:26<00:00,  1.64s/item]


Epoch: 0, Eval acc: 0.3019108280254777


Epoch 2/20: 100%|██████████| 37/37 [04:22<00:00,  7.10s/it, loss=1.9958316]
Eval: 100%|██████████| 16/16 [00:26<00:00,  1.66s/item]


Epoch: 1, Eval acc: 0.3284076433121019


Epoch 3/20: 100%|██████████| 37/37 [05:19<00:00,  8.62s/it, loss=1.8506495]
Eval: 100%|██████████| 16/16 [00:26<00:00,  1.67s/item]


Epoch: 2, Eval acc: 0.36203821656050955


Epoch 4/20: 100%|██████████| 37/37 [05:14<00:00,  8.50s/it, loss=1.9770259]
Eval: 100%|██████████| 16/16 [00:28<00:00,  1.77s/item]


Epoch: 3, Eval acc: 0.3653503184713376


Epoch 5/20:   3%|▎         | 1/37 [00:18<10:54, 18.18s/it, loss=1.9966648]


KeyboardInterrupt: 