In [1]:
%load_ext jupyter_black

import notebooks_path

notebooks_path.include_packages()
import jax
import jax.numpy as jnp
import jax.random as jr
import optax
import equinox as eqx
import functools
import tqdm

from vit import dataloader, util, model, train

In [2]:
batch_size = 8

In [3]:
train_dataloader = dataloader.get_train_dataloader(batch_size)
test_dataloder = dataloader.get_test_dataloader(batch_size)

Files already downloaded and verified


Files already downloaded and verified


In [4]:
image = next(iter(test_dataloder))

In [5]:
jax_image = jnp.array(image[0][0].numpy())
print(jax_image.shape)

2024-01-26 11:11:31.298036: W external/xla/xla/service/gpu/nvptx_compiler.cc:698] The NVIDIA driver's CUDA version is 12.1 which is older than the ptxas CUDA version (12.3.107). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


(3, 32, 32)


In [6]:
patches = util.img_to_patches(jax_image, patch_size=16, flatten_channel=False)

In [7]:
util.plot_patches(patches)









In [8]:
patches = util.img_to_patches(jax_image, patch_size=16, flatten_channel=False)

In [9]:
# Hyperparameters
lr = 0.0001
dropout_rate = 0.1
beta1 = 0.9
beta2 = 0.999
patch_size = 4
num_patches = 64
num_steps = 100000
image_size = (32, 32, 3)
embedding_dim = 512
hidden_dim = 256
num_heads = 8
num_layers = 6
height, width, channels = image_size
num_classes = 10

In [10]:
key = jr.PRNGKey(2003)

model_obj = model.VisionTransformer(
    embedding_dim=embedding_dim,
    channels=channels,
    hidden_dim=hidden_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout_rate=dropout_rate,
    patch_size=patch_size,
    num_patches=num_patches,
    num_classes=num_classes,
    key=key,
)

optimizer = optax.adamw(
    learning_rate=lr,
    b1=beta1,
    b2=beta2,
)

state = optimizer.init(eqx.filter(model_obj, eqx.is_inexact_array))

model_obj, state, losses = train.train(
    model_obj, optimizer, state, train_dataloader, batch_size, num_steps, key=key
)

Step: 0/100000, Loss: 2.3389408588409424.


Step: 1000/100000, Loss: 1.543238639831543.


Step: 2000/100000, Loss: 1.8728426694869995.


Step: 3000/100000, Loss: 2.5119857788085938.


Step: 4000/100000, Loss: 1.6348769664764404.


Step: 5000/100000, Loss: 1.7819663286209106.


Step: 6000/100000, Loss: 1.6219452619552612.


Step: 7000/100000, Loss: 1.36859929561615.


Step: 8000/100000, Loss: 1.1760574579238892.


Step: 9000/100000, Loss: 1.4488664865493774.


Step: 10000/100000, Loss: 1.802775263786316.


Step: 11000/100000, Loss: 1.7637073993682861.


Step: 12000/100000, Loss: 0.7833076119422913.


Step: 13000/100000, Loss: 1.1811528205871582.


Step: 14000/100000, Loss: 1.1001811027526855.


Step: 15000/100000, Loss: 0.9761098623275757.


Step: 16000/100000, Loss: 0.8831843137741089.


Step: 17000/100000, Loss: 1.0121620893478394.


Step: 18000/100000, Loss: 1.0539034605026245.


Step: 19000/100000, Loss: 1.107746958732605.


Step: 20000/100000, Loss: 0.9343699216842651.


Step: 21000/100000, Loss: 1.1023443937301636.


Step: 22000/100000, Loss: 1.090574026107788.


Step: 23000/100000, Loss: 0.7163845300674438.


Step: 24000/100000, Loss: 0.4474385380744934.


Step: 25000/100000, Loss: 0.8346785306930542.


Step: 26000/100000, Loss: 1.422865867614746.


Step: 27000/100000, Loss: 0.8890442848205566.


Step: 28000/100000, Loss: 1.8160719871520996.


Step: 29000/100000, Loss: 1.2586593627929688.


Step: 30000/100000, Loss: 0.8509153127670288.


Step: 31000/100000, Loss: 1.455114483833313.


Step: 32000/100000, Loss: 1.3152594566345215.


Step: 33000/100000, Loss: 1.8295204639434814.


Step: 34000/100000, Loss: 0.7769181728363037.


Step: 35000/100000, Loss: 1.777035117149353.


Step: 36000/100000, Loss: 1.3431330919265747.


Step: 37000/100000, Loss: 0.7172893285751343.


Step: 38000/100000, Loss: 0.9415724277496338.


Step: 39000/100000, Loss: 0.6352332234382629.


Step: 40000/100000, Loss: 0.9800623655319214.


Step: 41000/100000, Loss: 1.166698694229126.


Step: 42000/100000, Loss: 0.6654778122901917.


Step: 43000/100000, Loss: 1.3018194437026978.


Step: 44000/100000, Loss: 0.6915130019187927.


Step: 45000/100000, Loss: 1.5542782545089722.


Step: 46000/100000, Loss: 1.0379728078842163.


Step: 47000/100000, Loss: 0.9970062971115112.


Step: 48000/100000, Loss: 1.2466797828674316.


Step: 49000/100000, Loss: 0.4997975528240204.


Step: 50000/100000, Loss: 0.5509274005889893.


Step: 51000/100000, Loss: 0.8984620571136475.


Step: 52000/100000, Loss: 1.6150636672973633.


Step: 53000/100000, Loss: 0.9482676982879639.


Step: 54000/100000, Loss: 1.465075969696045.


Step: 55000/100000, Loss: 1.289531946182251.


Step: 56000/100000, Loss: 0.7533168196678162.


Step: 57000/100000, Loss: 0.9306734800338745.


Step: 58000/100000, Loss: 0.7913195490837097.


Step: 59000/100000, Loss: 1.4608873128890991.


Step: 60000/100000, Loss: 0.9099599123001099.


Step: 61000/100000, Loss: 0.8783217072486877.


Step: 62000/100000, Loss: 0.90946364402771.


Step: 63000/100000, Loss: 0.8942668437957764.


Step: 64000/100000, Loss: 0.8570293188095093.


Step: 65000/100000, Loss: 1.262819766998291.


Step: 66000/100000, Loss: 0.8843784332275391.


Step: 67000/100000, Loss: 1.5559430122375488.


Step: 68000/100000, Loss: 0.700046181678772.


Step: 69000/100000, Loss: 0.8254873752593994.


Step: 70000/100000, Loss: 1.2056628465652466.


Step: 71000/100000, Loss: 0.7090159058570862.


Step: 72000/100000, Loss: 0.901639997959137.


Step: 73000/100000, Loss: 0.9501442313194275.


Step: 74000/100000, Loss: 1.6348774433135986.


Step: 75000/100000, Loss: 1.447645664215088.


Step: 76000/100000, Loss: 0.8848025798797607.


Step: 77000/100000, Loss: 0.9465682506561279.


Step: 78000/100000, Loss: 0.9715703725814819.


Step: 79000/100000, Loss: 0.44195184111595154.


Step: 80000/100000, Loss: 0.6943346261978149.


Step: 81000/100000, Loss: 0.4777899384498596.


Step: 82000/100000, Loss: 0.3369857668876648.


Step: 83000/100000, Loss: 1.2335765361785889.


Step: 84000/100000, Loss: 1.2700443267822266.


Step: 85000/100000, Loss: 0.24581551551818848.


Step: 86000/100000, Loss: 1.2456016540527344.


Step: 87000/100000, Loss: 0.6545699834823608.


Step: 88000/100000, Loss: 1.6082074642181396.


Step: 89000/100000, Loss: 0.5465951561927795.


Step: 90000/100000, Loss: 0.5897265672683716.


Step: 91000/100000, Loss: 0.6407677531242371.


Step: 92000/100000, Loss: 0.3814378082752228.


Step: 93000/100000, Loss: 0.45877283811569214.


Step: 94000/100000, Loss: 0.9836975932121277.


Step: 95000/100000, Loss: 0.40443313121795654.


Step: 96000/100000, Loss: 0.5107324123382568.


Step: 97000/100000, Loss: 0.8225865364074707.


Step: 98000/100000, Loss: 0.234283447265625.


Step: 99000/100000, Loss: 0.6071439385414124.


Step: 99999/100000, Loss: 0.4844452440738678.


In [11]:
accuracies = []

for batch in range(len(test_dataloder.dataset) // batch_size):
    images, labels = next(iter(test_dataloder))

    logits = jax.vmap(functools.partial(model_obj, enable_dropout=False))(
        images.numpy(), key=jax.random.split(key, num=batch_size)
    )

    predictions = jnp.argmax(logits, axis=-1)

    accuracy = jnp.mean(predictions == labels.numpy())

    accuracies.append(accuracy)

print(f"Accuracy: {jnp.sum(jnp.array(accuracies)) / len(accuracies) * 100}%")

Accuracy: 81.89000701904297%
