In [5]:
!pip install k3im --upgrade



In [6]:
import os
# Select the Keras backend; this is done before `keras` is imported further down.
os.environ.update(KERAS_BACKEND="jax")

In [3]:
import keras
import numpy as np

In [5]:
# Shared model / dataset settings
num_classes = 10
input_shape = (28, 28, 1)

# MNIST comes pre-split into train and test partitions
(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()

# Rescale pixel values to the [0, 1] range and append a trailing channel axis
# so every image has shape (28, 28, 1)
x_train = (x_train.astype("float32") / 255)[..., np.newaxis]
x_test = (x_test.astype("float32") / 255)[..., np.newaxis]

print("x_train shape:", x_train.shape)
print(x_train.shape[0], "train samples")
print(x_test.shape[0], "test samples")

# Turn the integer class labels into binary class matrices (one column per class)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

x_train shape: (60000, 28, 28, 1)
60000 train samples
10000 test samples


In [6]:
# Training hyperparameters read by `train_model` for every model in this notebook
batch_size, epochs = 128, 2
def train_model(model):
    """Compile, fit and evaluate ``model`` on the notebook's MNIST arrays.

    Reads the notebook-level ``x_train``/``y_train``, ``x_test``/``y_test``,
    ``batch_size`` and ``epochs``. The model is compiled with the Adam
    optimizer, an accuracy metric and categorical cross-entropy using
    ``from_logits=True``; 10% of the training data is held out for validation
    during fitting. The test loss and accuracy are printed; nothing is returned.

    Args:
        model: a Keras model exposing ``compile``, ``fit`` and ``evaluate``.
    """
    # NOTE(review): from_logits=True assumes the model outputs unnormalized
    # scores (no final softmax) -- confirm this holds for every model passed in.
    loss_fn = keras.losses.CategoricalCrossentropy(from_logits=True)
    model.compile(loss=loss_fn, optimizer="adam", metrics=["accuracy"])

    model.fit(
        x_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
    )

    # evaluate() returns the loss first, followed by the compiled metrics
    results = model.evaluate(x_test, y_test, verbose=0)
    print("Test loss:", results[0])
    print("Test accuracy:", results[1])

In [7]:
from k3im.cait import CaiTModel # jax ✅, tensorflow ✅, torch ✅
# CaiT sized for the 28x28 single-channel MNIST images, using 7x7 patches.
model = CaiTModel(
    image_size=(28, 28),
    patch_size=(7, 7),
    channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    dim=32,
    depth=2,
    cls_depth=2,
    heads=8,
    dim_head=64,
    mlp_dim=64,
)

In [8]:
model.summary()

In [9]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m22s[0m 52ms/step - accuracy: 0.5326 - loss: 1.3825 - val_accuracy: 0.9125 - val_loss: 0.2950
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 45ms/step - accuracy: 0.9129 - loss: 0.2858 - val_accuracy: 0.9593 - val_loss: 0.1494
Test loss: 0.16203546524047852
Test accuracy: 0.9521999955177307


In [10]:
from k3im.cct import CCT  # jax ✅, tensorflow ✅, torch ✅

# CCT driven by the shared `input_shape` / `num_classes` settings;
# positional embeddings are switched off for this run.
model = CCT(
    input_shape=input_shape,
    num_classes=num_classes,
    kernel_size=3,
    stride=3,
    padding=2,
    projection_dim=32,
    num_heads=8,
    transformer_layers=2,
    transformer_units=[16, 32],
    stochastic_depth_rate=0.6,
    positional_emb=False,
)

In [11]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 33ms/step - accuracy: 0.6643 - loss: 1.0071 - val_accuracy: 0.9318 - val_loss: 0.2200
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 32ms/step - accuracy: 0.9292 - loss: 0.2308 - val_accuracy: 0.9532 - val_loss: 0.1575
Test loss: 0.1650298684835434
Test accuracy: 0.947700023651123


In [12]:
from k3im.convmixer import ConvMixer # jax ✅ # tf something not right # something aint right with torch


# ConvMixer sized for the 28x28 single-channel MNIST images, using 2x2 patches.
# NOTE(review): in the run recorded for this model, training accuracy climbs
# while validation/test accuracy stays around 0.10-0.11 and the validation loss
# is in the thousands. That looks like a train/inference mismatch inside the
# model -- investigate before relying on these numbers.
model = ConvMixer(
    image_size=28,
    patch_size=2,
    num_channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    filters=64,
    depth=8,
    kernel_size=3,
)

In [13]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 59ms/step - accuracy: 0.6054 - loss: 1.3247 - val_accuracy: 0.1113 - val_loss: 9747.7539
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m30s[0m 70ms/step - accuracy: 0.8339 - loss: 0.5841 - val_accuracy: 0.1113 - val_loss: 1693.4771
Test loss: 1714.0731201171875
Test accuracy: 0.10279999673366547


In [14]:
from k3im.cross_vit import CrossViT # jax ✅, tensorflow ✅, torch ✅
# CrossViT sized for the 28x28 single-channel MNIST images.
model = CrossViT(
    image_size=28,
    channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    depth=3,
    # `sm_*` settings (4x4 patches)
    sm_dim=32,
    sm_patch_size=4,
    sm_enc_depth=1,
    sm_enc_heads=8,
    sm_enc_mlp_dim=48,
    sm_enc_dim_head=56,
    # `lg_*` settings (7x7 patches)
    lg_dim=42,
    lg_patch_size=7,
    lg_enc_depth=2,
    lg_enc_heads=8,
    lg_enc_mlp_dim=84,
    lg_enc_dim_head=72,
    # `cross_attn_*` settings
    cross_attn_depth=2,
    cross_attn_heads=8,
    cross_attn_dim_head=64,
    # dropout rates
    dropout=0.1,
    emb_dropout=0.1,
)

In [15]:
model.summary()

In [16]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m76s[0m 179ms/step - accuracy: 0.5782 - loss: 1.2357 - val_accuracy: 0.9185 - val_loss: 0.2657
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m73s[0m 173ms/step - accuracy: 0.8943 - loss: 0.3341 - val_accuracy: 0.9500 - val_loss: 0.1754
Test loss: 0.18887896835803986
Test accuracy: 0.9437000155448914


In [17]:
from k3im.deepvit import DeepViT # jax ✅, tensorflow ✅
# DeepViT sized for the 28x28 single-channel MNIST images, using 7x7 patches;
# both dropout rates are disabled (0.0).
model = DeepViT(
    image_size=28,
    patch_size=7,
    channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    dim=64,
    depth=2,
    heads=8,
    dim_head=64,
    mlp_dim=84,
    pool="cls",
    dropout=0.0,
    emb_dropout=0.0,
)

In [18]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 30ms/step - accuracy: 0.6203 - loss: 1.1394 - val_accuracy: 0.9172 - val_loss: 0.2594
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 30ms/step - accuracy: 0.9250 - loss: 0.2398 - val_accuracy: 0.9567 - val_loss: 0.1421
Test loss: 0.16465066373348236
Test accuracy: 0.9492999911308289


In [19]:
from k3im.eanet import EANet # jax ✅, tensorflow ✅, torch ✅
# EANet built from the shared `input_shape` / `num_classes` settings, using 7x7 patches.
model = EANet(
    input_shape=input_shape,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    patch_size=7,
    embedding_dim=64,
    num_transformer_blocks=2,
    num_heads=16,
    mlp_dim=32,
    dim_coefficient=2,
    attention_dropout=0.5,
    projection_dropout=0.5,
)

In [20]:
model.summary()

In [21]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.3873 - loss: 1.7322 - val_accuracy: 0.8285 - val_loss: 0.5884
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.7580 - loss: 0.7342 - val_accuracy: 0.8965 - val_loss: 0.3452
Test loss: 0.4003435969352722
Test accuracy: 0.8769000172615051


In [22]:
from k3im.fnet import FNetModel # jax ✅, tensorflow ✅, torch ✅
# FNet sized for the 28x28 single-channel MNIST images, using 7x7 patches;
# positional encoding is switched off for this run.
model = FNetModel(
    image_size=28,
    patch_size=7,
    num_channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    embedding_dim=64,
    num_blocks=2,
    dropout_rate=0.4,
    positional_encoding=False,
)

In [23]:
model.summary()

In [24]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 16ms/step - accuracy: 0.3982 - loss: 1.7336 - val_accuracy: 0.8240 - val_loss: 0.5880
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 15ms/step - accuracy: 0.7478 - loss: 0.7953 - val_accuracy: 0.8858 - val_loss: 0.3823
Test loss: 0.43270713090896606
Test accuracy: 0.8708000183105469


In [25]:
from k3im.focalnet import focalnet_kid # jax ✅, tensorflow ✅, torch ✅
# `focalnet_kid` preset sized for the 28x28 single-channel MNIST images.
model = focalnet_kid(
    img_size=28,
    in_channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
)
model.summary()

In [26]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 23ms/step - accuracy: 0.5703 - loss: 1.2913 - val_accuracy: 0.9475 - val_loss: 0.1807
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 24ms/step - accuracy: 0.9437 - loss: 0.1907 - val_accuracy: 0.9697 - val_loss: 0.1060
Test loss: 0.1195862889289856
Test accuracy: 0.9613999724388123


In [27]:
from k3im.gmlp import gMLPModel # jax ✅, tensorflow ✅, torch ✅
# gMLP sized for the 28x28 single-channel MNIST images, using 7x7 patches;
# positional encoding is switched off for this run.
model = gMLPModel(
    image_size=28,
    patch_size=7,
    num_channels=1,
    num_classes=num_classes,
    embedding_dim=32,
    num_blocks=4,
    dropout_rate=0.5,
    positional_encoding=False,
)

In [28]:
model.summary()

In [29]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.2239 - loss: 2.1543 - val_accuracy: 0.7037 - val_loss: 0.8542
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.5964 - loss: 1.1702 - val_accuracy: 0.8968 - val_loss: 0.3385
Test loss: 0.3988232910633087
Test accuracy: 0.8780999779701233


In [30]:
from k3im.mlp_mixer import MixerModel # jax ✅, tensorflow ✅, torch ✅
# MLP-Mixer sized for the 28x28 single-channel MNIST images, using 7x7 patches;
# positional encoding is enabled for this run.
model = MixerModel(
    image_size=28,
    patch_size=7,
    num_channels=1,
    num_classes=num_classes,
    embedding_dim=32,
    num_blocks=4,
    dropout_rate=0.5,
    positional_encoding=True,
)

model.summary()

In [31]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 30ms/step - accuracy: 0.2378 - loss: 2.1815 - val_accuracy: 0.7837 - val_loss: 0.6431
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 31ms/step - accuracy: 0.6868 - loss: 0.9226 - val_accuracy: 0.8920 - val_loss: 0.3507
Test loss: 0.4144611358642578
Test accuracy: 0.8709999918937683


In [32]:
from k3im.simple_vit import SimpleViT # jax ✅, tensorflow ✅, torch ✅
# SimpleViT sized for the 28x28 single-channel MNIST images, using 7x7 patches
# and mean pooling.
model = SimpleViT(
    image_size=(28, 28),
    patch_size=(7, 7),
    channels=1,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=8,
    dim_head=32,
    mlp_dim=65,
    pool="mean",
)

In [33]:
model.summary()

In [34]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 24ms/step - accuracy: 0.5580 - loss: 1.3136 - val_accuracy: 0.8928 - val_loss: 0.3520
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 24ms/step - accuracy: 0.8881 - loss: 0.3604 - val_accuracy: 0.9362 - val_loss: 0.2134
Test loss: 0.2578793466091156
Test accuracy: 0.9205999970436096


In [35]:
from k3im.simple_vit_with_fft import SimpleViTFFT # jax ✅, tensorflow ✅, torch ✅
# SimpleViTFFT sized for the 28x28 single-channel MNIST images; the image and
# frequency patch sizes are both 7.
model = SimpleViTFFT(
    image_size=28,
    patch_size=7,
    freq_patch_size=7,
    channels=1,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=8,
    dim_head=16,
    mlp_dim=64,
)
model.summary()

In [36]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 29ms/step - accuracy: 0.5238 - loss: 1.4078 - val_accuracy: 0.9082 - val_loss: 0.3177
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 29ms/step - accuracy: 0.9021 - loss: 0.3226 - val_accuracy: 0.9453 - val_loss: 0.1842
Test loss: 0.2127855271100998
Test accuracy: 0.9369999766349792


In [37]:
from k3im.simple_vit_with_register_tokens import SimpleViT_RT # jax ✅, tensorflow ✅, torch ✅
# SimpleViT_RT sized for the 28x28 single-channel MNIST images, using 7x7
# patches and 4 register tokens.
model = SimpleViT_RT(
    image_size=28,
    patch_size=7,
    channels=1,
    num_classes=num_classes,
    num_register_tokens=4,
    dim=32,
    depth=2,
    heads=4,
    dim_head=64,
    mlp_dim=64,
)

In [38]:
model.summary()

In [39]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.5336 - loss: 1.3876 - val_accuracy: 0.8742 - val_loss: 0.4016
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.8742 - loss: 0.4034 - val_accuracy: 0.9272 - val_loss: 0.2538
Test loss: 0.2974727153778076
Test accuracy: 0.909600019454956


In [40]:
from k3im.swint import SwinTModel # jax ✅, tensorflow ✅, torch ✅

In [41]:
# SwinT sized for the 28x28 single-channel MNIST images, using 7x7 patches.
model = SwinTModel(
    img_size=28,
    patch_size=7,
    in_channels=1,
    num_classes=num_classes,
    embed_dim=32,
    num_heads=4,
    num_mlp=4,
    window_size=4,
    shift_size=2,
    qkv_bias=True,
    dropout_rate=0.2,
)

In [42]:
model.summary()

In [43]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.6364 - loss: 1.1143 - val_accuracy: 0.9118 - val_loss: 0.2969
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m11s[0m 26ms/step - accuracy: 0.8750 - loss: 0.3952 - val_accuracy: 0.9358 - val_loss: 0.2215
Test loss: 0.26952508091926575
Test accuracy: 0.9175999760627747


In [44]:
from k3im.token_learner import ViTokenLearner # jax check with jax ✅, tensorflow ✅, torch ✅
# ViTokenLearner sized for the 28x28 single-channel MNIST images, using 7x7
# patches, with the token learner enabled and dropout disabled.
# NOTE(review): in the run recorded for this model the loss plateaus around
# 1.65 even at ~81% test accuracy. That pattern is what one would expect if the
# model already applies a softmax while `train_model` uses from_logits=True --
# verify the model's output activation.
model = ViTokenLearner(
    image_size=28,
    patch_size=7,
    channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    dim=64,
    depth=4,
    heads=4,
    dim_head=64,
    mlp_dim=32,
    use_token_learner=True,
    token_learner_units=2,
    dropout_rate=0.0,
    pool="mean",
)

In [45]:
model.summary()

In [46]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 56ms/step - accuracy: 0.3959 - loss: 2.0571 - val_accuracy: 0.7192 - val_loss: 1.7431
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 55ms/step - accuracy: 0.7304 - loss: 1.7321 - val_accuracy: 0.8182 - val_loss: 1.6422
Test loss: 1.654205560684204
Test accuracy: 0.8080999851226807


In [47]:
from k3im.vit import ViT # jax ✅, tensorflow ✅, torch ✅
# ViT sized for the 28x28 single-channel MNIST images, using 7x7 patches and
# mean pooling.
model = ViT(
    image_size=(28, 28),
    patch_size=(7, 7),
    channels=1,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=8,
    dim_head=32,
    mlp_dim=65,
    pool="mean",
)

In [48]:
model.summary()

In [49]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 36ms/step - accuracy: 0.5635 - loss: 1.3064 - val_accuracy: 0.8988 - val_loss: 0.3343
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 28ms/step - accuracy: 0.9013 - loss: 0.3197 - val_accuracy: 0.9478 - val_loss: 0.1735
Test loss: 0.18965917825698853
Test accuracy: 0.942799985408783


In [50]:
from k3im.vit_with_patch_dropout import SimpleViTPD # jax ✅, tensorflow ✅, torch ✅
# SimpleViTPD sized for the 28x28 single-channel MNIST images, using 7x7
# patches, mean pooling and a patch dropout of 0.25.
model = SimpleViTPD(
    image_size=28,
    patch_size=7,
    channels=1,
    num_classes=num_classes,  # shared notebook constant instead of a literal 10
    dim=32,
    depth=4,
    heads=8,
    dim_head=16,
    mlp_dim=42,
    patch_dropout=0.25,
    pool="mean",
)

In [51]:
model.summary()

In [52]:
train_model(model)

Epoch 1/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m16s[0m 39ms/step - accuracy: 0.4712 - loss: 1.5234 - val_accuracy: 0.8593 - val_loss: 0.4529
Epoch 2/2
[1m422/422[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m17s[0m 39ms/step - accuracy: 0.8093 - loss: 0.5927 - val_accuracy: 0.9195 - val_loss: 0.2688
Test loss: 0.30448290705680847
Test accuracy: 0.9017000198364258
