In [1]:
!pip install k3im medmnist -qq --upgrade

import os
os.environ['KERAS_BACKEND'] = 'jax'

In [6]:
import keras
import medmnist
import numpy as np
import tensorflow as tf # For data processes only

In [7]:
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
num_classes = 11

def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset.

    Arguments:
        data_info (dict): Dataset metadata.
    """
    data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )


# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]

# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]

In [8]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE,
):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(BATCH_SIZE * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")

In [9]:
batch_size = 32
epochs = 2
def train_model(model):
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer="adam", metrics=["accuracy"])
    model.fit(trainloader, epochs=epochs, validation_data=validloader)
    score = model.evaluate(testloader)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])

In [11]:
from k3im.cait_3d import CAiT3DModel # fixed jax ✅
model = CAiT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=128,
    depth=3,
    cls_depth=3,
    heads=8,
    mlp_dim=84,
    channels=1,
    dim_head=64,
)

In [12]:
model.summary()

In [13]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m29s[0m 596ms/step - accuracy: 0.4146 - loss: 1.8926 - val_accuracy: 0.9255 - val_loss: 0.2574
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 22ms/step - accuracy: 0.7658 - loss: 0.7132 - val_accuracy: 0.9068 - val_loss: 0.2550
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 268ms/step - accuracy: 0.7239 - loss: 1.0073
Test loss: 0.9272577166557312
Test accuracy: 0.7180327773094177


In [14]:
from k3im.cct_3d import CCT3DModel # jax ✅, tensorflow ✅, torch ✅
model = CCT3DModel(input_shape=(28, 28, 28, 1),
    num_heads=4,
    projection_dim=64,
    kernel_size=4,
    stride=4,
    padding=2,
    transformer_units=[16, 64],
    stochastic_depth_rate=0.6,
    transformer_layers=2,
    num_classes=num_classes,
    positional_emb=False,)

In [15]:
model.summary()

In [16]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m18s[0m 311ms/step - accuracy: 0.1076 - loss: 2.4914 - val_accuracy: 0.1801 - val_loss: 2.2091
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.2319 - loss: 2.0499 - val_accuracy: 0.4099 - val_loss: 1.4763
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 134ms/step - accuracy: 0.3311 - loss: 1.6405
Test loss: 1.6255520582199097
Test accuracy: 0.34590163826942444


In [17]:
from k3im.convmixer_3d import ConvMixer3DModel # jax ✅, # tensorflow ✅,  #torch fail
model = ConvMixer3DModel(image_size=28,
    num_frames=28,
    filters=32,
    depth=2,
    kernel_size=4,
    kernel_depth=3,
    patch_size=3,
    patch_depth=3,
    num_classes=10,
    num_channels=1)

In [18]:
model.summary()

In [19]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m8s[0m 151ms/step - accuracy: 0.2606 - loss: 1.7970 - val_accuracy: 0.0994 - val_loss: 2.1118
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 21ms/step - accuracy: 0.4494 - loss: 1.4487 - val_accuracy: 0.0994 - val_loss: 2.1147
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 66ms/step - accuracy: 0.0999 - loss: 2.0564
Test loss: 2.065363645553589
Test accuracy: 0.11311475187540054


In [20]:
from k3im.eanet3d import EANet3DModel # jax ✅, tensorflow ✅, torch ✅

In [21]:
model = EANet3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=64,
    depth=2,
    heads=4,
    mlp_dim=32,
    channels=1,
    dim_coefficient=4,
    projection_dropout=0.0,
    attention_dropout=0,
)

In [22]:
model.summary()

In [23]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m12s[0m 248ms/step - accuracy: 0.3390 - loss: 2.0201 - val_accuracy: 0.9130 - val_loss: 0.5298
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 22ms/step - accuracy: 0.7829 - loss: 0.8192 - val_accuracy: 0.9441 - val_loss: 0.2836
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 133ms/step - accuracy: 0.7109 - loss: 0.9908
Test loss: 0.9404497146606445
Test accuracy: 0.7180327773094177


In [24]:
from k3im.fnet_3d import FNet3DModel # jax ✅ , tensorflow ✅, torch ✅
model = FNet3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=256,
    depth=5,
    hidden_units=256,
    dropout_rate=0.6,
    channels=1,
)

In [25]:
model.summary()

In [26]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m26s[0m 681ms/step - accuracy: 0.1565 - loss: 2.3726 - val_accuracy: 0.6149 - val_loss: 1.6513
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 15ms/step - accuracy: 0.6299 - loss: 1.4855 - val_accuracy: 0.7391 - val_loss: 0.7485
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 134ms/step - accuracy: 0.5232 - loss: 1.4321
Test loss: 1.3722052574157715
Test accuracy: 0.5278688669204712


In [27]:
from k3im.gmlp_3d import gMLP3DModel # jax ✅ , tensorflow ✅, torch ✅
model = gMLP3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=32,
    depth=4,
    hidden_units=32,
    dropout_rate=0.4,
    channels=1,
)

In [28]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m23s[0m 552ms/step - accuracy: 0.2921 - loss: 2.1528 - val_accuracy: 0.7329 - val_loss: 0.9436
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 8ms/step - accuracy: 0.7036 - loss: 1.1000 - val_accuracy: 0.9193 - val_loss: 0.4642
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m3s[0m 134ms/step - accuracy: 0.6815 - loss: 1.1396
Test loss: 1.0670228004455566
Test accuracy: 0.685245931148529


In [29]:
from k3im.mlp_mixer_3d import MLPMixer3DModel # jax ✅, tensorflow ✅, torch ✅

model = MLPMixer3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=4,
    hidden_units=32,
    dropout_rate=0.4,
    channels=1,
)
model.summary()

In [30]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m24s[0m 655ms/step - accuracy: 0.3077 - loss: 2.1705 - val_accuracy: 0.8944 - val_loss: 0.3976
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.7351 - loss: 0.8639 - val_accuracy: 0.9317 - val_loss: 0.2207
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 66ms/step - accuracy: 0.7119 - loss: 1.0056
Test loss: 0.9401963353157043
Test accuracy: 0.7147541046142578


In [31]:
from k3im.simple_vit_3d import SimpleViT3DModel # jax ✅, tensorflow ✅, torch ✅

model = SimpleViT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=32,
    channels=1,
    dim_head=64,
)

In [32]:
model.summary()

In [33]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 296ms/step - accuracy: 0.3410 - loss: 2.0619 - val_accuracy: 0.8571 - val_loss: 0.7134
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 9ms/step - accuracy: 0.7661 - loss: 0.9615 - val_accuracy: 0.9006 - val_loss: 0.4115
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 66ms/step - accuracy: 0.6486 - loss: 1.1380
Test loss: 1.0664023160934448
Test accuracy: 0.6672131419181824


In [34]:
from k3im.vit_3d import ViT3DModel # jax ✅, tensorflow ✅, torch ✅

In [35]:
model = ViT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=32,
    pool='cls',
    channels=1,
    dim_head=64,
)

In [36]:
model.summary()

In [37]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 262ms/step - accuracy: 0.3574 - loss: 1.9998 - val_accuracy: 0.8696 - val_loss: 0.5885
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 11ms/step - accuracy: 0.7141 - loss: 0.9330 - val_accuracy: 0.9255 - val_loss: 0.3551
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 63ms/step - accuracy: 0.6813 - loss: 1.0153
Test loss: 0.963051438331604
Test accuracy: 0.6983606815338135


In [None]:
from k3im.video_eanet import VideoEANet # fixed jax ✅
model = VideoEANet(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=64,
    spatial_depth=2,
    temporal_depth=2,
    heads=4,
    mlp_dim=128,
    pool="cls",
    channels=1,
    dim_coefficient=4,
    projection_dropout=0.0,
    attention_dropout=0,
    emb_dropout=0.0,
)

In [None]:
model.summary()

In [None]:
train_model(model)

In [None]:
from k3im.vivit import ViViT

model = ViViT(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=11,
    dim=64,
    spatial_depth=1,
    temporal_depth=1,
    heads=4,
    mlp_dim=128,
    pool="cls",
    channels=1,
    dim_head=64,
    dropout=0.0,
    emb_dropout=0.0,
)

In [None]:
model.summary()

In [None]:
train_model(model)