In [None]:
!pip install k3im medmnist -qq --upgrade

import sys
import os

os.environ['KERAS_BACKEND'] = 'jax'

In [4]:
import keras
import numpy as np

# Seed Python, NumPy and backend RNGs so weight initialisation, dropout and
# shuffling (and therefore the model comparison below) are reproducible.
SEED = 42
keras.utils.set_random_seed(SEED)

In [5]:
# Model / data parameters
num_classes = 10
input_shape = (28, 28, 1)

# Load MNIST: images of shape (N, 28, 28) plus integer class labels.
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Rescale pixels to [0, 1] and append a channel axis -> (N, 28, 28, 1).
x_train = np.expand_dims(x_train.astype("float32") / 255, -1)
x_test = np.expand_dims(x_test.astype("float32") / 255, -1)

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# One-hot encode the labels to match the CategoricalCrossentropy loss used
# by train_model.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [6]:
batch_size = 128
epochs = 2


def train_model(model, from_logits=True):
    """Compile, train and evaluate ``model`` on the MNIST arrays loaded above.

    Reads the notebook-level ``x_train``/``y_train`` (10% held out via
    ``validation_split``), ``x_test``/``y_test``, ``batch_size`` and ``epochs``.

    Args:
        model: a Keras model mapping images to ``num_classes`` scores.
        from_logits: ``True`` (the default, as before) when the model outputs
            raw logits; pass ``False`` for a model whose head already applies
            softmax, otherwise the loss double-applies it.

    Returns:
        dict with the test-set ``"loss"`` and ``"accuracy"``, so results can be
        collected and compared across models rather than only printed.
    """
    model.compile(
        loss=keras.losses.CategoricalCrossentropy(from_logits=from_logits),
        optimizer="adam",
        metrics=["accuracy"],
    )
    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
    )
    # With a single compiled metric, evaluate returns [loss, accuracy].
    score = model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])
    return {"loss": score[0], "accuracy": score[1]}

In [7]:
from k3im.vit import ViT

# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = ViT(
    image_size=input_shape[:2],
    patch_size=(7, 7),
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=8,
    # NOTE(review): 65 looks like a typo for 64 (the value SimpleViTFFT and
    # SimpleViT_RT use in this notebook); kept so recorded results stay valid.
    mlp_dim=65,
    channels=input_shape[-1],
    dim_head=32,
    pool="mean",
)

In [8]:
# Print the ViT's layers and parameter counts.
model.summary()

In [9]:
# Compile, fit and evaluate the ViT; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - accuracy: 0.5419 - loss: 1.3609 - val_accuracy: 0.9145 - val_loss: 0.2954
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - accuracy: 0.9110 - loss: 0.3006 - val_accuracy: 0.9457 - val_loss: 0.1844
Test loss: 0.22440513968467712
Test accuracy: 0.9340999722480774


In [10]:
from k3im.simple_vit import SimpleViT

# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = SimpleViT(
    image_size=input_shape[:2],
    patch_size=(7, 7),
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=8,
    # NOTE(review): 65 looks like a typo for 64 (the value SimpleViTFFT and
    # SimpleViT_RT use in this notebook); kept so recorded results stay valid.
    mlp_dim=65,
    channels=input_shape[-1],
    dim_head=32,
    pool="mean",
)

In [11]:
# Print the SimpleViT's layers and parameter counts.
model.summary()

In [12]:
# Compile, fit and evaluate the SimpleViT; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 9ms/step - accuracy: 0.5103 - loss: 1.4526 - val_accuracy: 0.9015 - val_loss: 0.3366
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 9ms/step - accuracy: 0.8941 - loss: 0.3440 - val_accuracy: 0.9435 - val_loss: 0.1977
Test loss: 0.23583565652370453
Test accuracy: 0.9265999794006348


In [13]:
from k3im.cct import CCT

model = CCT(
    # Data-dependent settings shared with the rest of the notebook.
    input_shape=input_shape,
    num_classes=num_classes,
    # Convolution hyper-parameters.
    kernel_size=3,
    stride=3,
    padding=2,
    # Transformer hyper-parameters.
    transformer_layers=2,
    num_heads=8,
    projection_dim=32,
    transformer_units=[16, 32],
    stochastic_depth_rate=0.6,
    positional_emb=False,
)

In [14]:
# Print the CCT's layers and parameter counts.
model.summary()

In [15]:
# Compile, fit and evaluate the CCT; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 13ms/step - accuracy: 0.6436 - loss: 1.0642 - val_accuracy: 0.9290 - val_loss: 0.2381
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 13ms/step - accuracy: 0.9302 - loss: 0.2362 - val_accuracy: 0.9552 - val_loss: 0.1508
Test loss: 0.1577761322259903
Test accuracy: 0.9517999887466431


In [16]:
from k3im.convmixer import ConvMixer

# TODO(review): in the recorded run this model reached ~0.92 training accuracy
# but only ~0.10-0.11 validation/test accuracy with a very large validation
# loss, i.e. it collapses at evaluation time. Possibly a train/inference-mode
# mismatch (e.g. normalisation statistics) - unverified; investigate before
# trusting its numbers.
model = ConvMixer(
    image_size=input_shape[0],  # assumes square images
    filters=64,
    depth=8,
    kernel_size=3,
    patch_size=2,
    num_classes=num_classes,
    num_channels=input_shape[-1],
)
model.summary()

In [17]:
# NOTE(review): the recorded output shows validation/test accuracy near chance
# (~0.10-0.11) despite high training accuracy; see the TODO on the model cell.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 19ms/step - accuracy: 0.6886 - loss: 1.1870 - val_accuracy: 0.1045 - val_loss: 174.8289
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 19ms/step - accuracy: 0.9161 - loss: 0.3115 - val_accuracy: 0.1050 - val_loss: 90.8182
Test loss: 90.73978424072266
Test accuracy: 0.11349999904632568


In [18]:
from k3im.eanet import EANet

model = EANet(
    input_shape=input_shape,
    patch_size=7,
    embedding_dim=64,
    num_transformer_blocks=2,
    mlp_dim=32,
    num_heads=16,
    dim_coefficient=2,
    attention_dropout=0.5,
    projection_dropout=0.5,
    # Use the shared constant rather than a literal so it matches the one-hot
    # labels built with `num_classes`.
    num_classes=num_classes,
)

In [19]:
# Print the EANet's layers and parameter counts.
model.summary()

In [20]:
# Compile, fit and evaluate the EANet; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 9ms/step - accuracy: 0.4249 - loss: 1.6590 - val_accuracy: 0.8365 - val_loss: 0.5266
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 9ms/step - accuracy: 0.7781 - loss: 0.6857 - val_accuracy: 0.9113 - val_loss: 0.3001
Test loss: 0.350763738155365
Test accuracy: 0.891700029373169


In [21]:
from k3im.gmlp import gMLPModel

# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = gMLPModel(
    image_size=input_shape[0],  # assumes square images
    patch_size=7,
    embedding_dim=32,
    num_blocks=4,
    dropout_rate=0.5,
    num_classes=num_classes,
    positional_encoding=False,
    num_channels=input_shape[-1],
)

In [22]:
# Print the gMLP model's layers and parameter counts.
model.summary()

In [23]:
# Compile, fit and evaluate the gMLP model; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - accuracy: 0.2462 - loss: 2.0940 - val_accuracy: 0.7595 - val_loss: 0.7761
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - accuracy: 0.6231 - loss: 1.1040 - val_accuracy: 0.9082 - val_loss: 0.3132
Test loss: 0.375752717256546
Test accuracy: 0.8859999775886536


In [24]:
from k3im.mlp_mixer import MixerModel

# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = MixerModel(
    image_size=input_shape[0],  # assumes square images
    patch_size=7,
    embedding_dim=32,
    num_blocks=4,
    dropout_rate=0.5,
    num_classes=num_classes,
    positional_encoding=True,
    num_channels=input_shape[-1],
)

model.summary()

In [25]:
# Compile, fit and evaluate the MLP-Mixer; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 11ms/step - accuracy: 0.2454 - loss: 2.1532 - val_accuracy: 0.7982 - val_loss: 0.6211
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 11ms/step - accuracy: 0.7099 - loss: 0.8873 - val_accuracy: 0.8825 - val_loss: 0.3738
Test loss: 0.42113006114959717
Test accuracy: 0.8712000250816345


In [26]:
from k3im.simple_vit_with_fft import SimpleViTFFT

# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = SimpleViTFFT(
    image_size=input_shape[0],  # assumes square images
    patch_size=7,
    freq_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=8,
    mlp_dim=64,
    channels=input_shape[-1],
    dim_head=16,
)
model.summary()

In [27]:
# Compile, fit and evaluate the SimpleViTFFT; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.5062 - loss: 1.4556 - val_accuracy: 0.8795 - val_loss: 0.4012
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 12ms/step - accuracy: 0.8776 - loss: 0.3968 - val_accuracy: 0.9382 - val_loss: 0.2112
Test loss: 0.24840669333934784
Test accuracy: 0.9253000020980835


In [28]:
from k3im.simple_vit_with_register_tokens import SimpleViT_RT

# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = SimpleViT_RT(
    image_size=input_shape[0],  # assumes square images
    patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=64,
    num_register_tokens=4,
    channels=input_shape[-1],
    dim_head=64,
)

In [29]:
# Print the SimpleViT_RT's layers and parameter counts.
model.summary()

In [30]:
# Compile, fit and evaluate the SimpleViT_RT; prints test loss and accuracy.
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - accuracy: 0.5092 - loss: 1.4363 - val_accuracy: 0.8662 - val_loss: 0.4330
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m4s[0m 10ms/step - accuracy: 0.8685 - loss: 0.4252 - val_accuracy: 0.9273 - val_loss: 0.2530
Test loss: 0.2882605493068695
Test accuracy: 0.9120000004768372


In [31]:
# TODO(review): this cell was previously marked "PROBLEM" without explanation.
# The SwinT training cell at the end of the notebook was never executed, so
# whether this model builds and trains correctly is unverified.
from k3im.swint import SwinTModel

In [32]:
# Image size and channel count come from `input_shape` so the model stays in
# sync with the data pipeline.
model = SwinTModel(
    img_size=input_shape[0],  # assumes square images
    patch_size=7,
    embed_dim=32,
    num_heads=4,
    window_size=4,
    num_mlp=4,
    qkv_bias=True,
    dropout_rate=0.2,
    shift_size=2,
    num_classes=num_classes,
    in_channels=input_shape[-1],
)

In [33]:
# Print the Swin Transformer's layers and parameter counts.
model.summary()

In [None]:
# NOTE(review): this cell was not executed in the saved notebook (`In [None]`),
# so there are no recorded SwinT results to compare against.
train_model(model)