In [None]:
!git clone https://github.com/anas-rz/k3im.git

In [1]:
import os
os.environ['KERAS_BACKEND'] = 'jax'

In [2]:
!pip install keras medmnist -qq --upgrade

/bin/bash: /home/anas/miniconda3/envs/keras_core/lib/libtinfo.so.6: no version information available (required by /bin/bash)
Keras version: 3.0.0


In [3]:
import sys
sys.path.append('/content/k3im')

In [4]:
import keras
import medmnist
import numpy as np
import tensorflow as tf # For data processes only



In [5]:
DATASET_NAME = "organmnist3d"
BATCH_SIZE = 32
AUTO = tf.data.AUTOTUNE
num_classes = 11

def download_and_prepare_dataset(data_info: dict):
    """Utility function to download the dataset.

    Arguments:
        data_info (dict): Dataset metadata.
    """
    data_path = keras.utils.get_file(origin=data_info["url"], md5_hash=data_info["MD5"])

    with np.load(data_path) as data:
        # Get videos
        train_videos = data["train_images"]
        valid_videos = data["val_images"]
        test_videos = data["test_images"]

        # Get labels
        train_labels = data["train_labels"].flatten()
        valid_labels = data["val_labels"].flatten()
        test_labels = data["test_labels"].flatten()

    return (
        (train_videos, train_labels),
        (valid_videos, valid_labels),
        (test_videos, test_labels),
    )


# Get the metadata of the dataset
info = medmnist.INFO[DATASET_NAME]

# Get the dataset
prepared_dataset = download_and_prepare_dataset(info)
(train_videos, train_labels) = prepared_dataset[0]
(valid_videos, valid_labels) = prepared_dataset[1]
(test_videos, test_labels) = prepared_dataset[2]

In [6]:
@tf.function
def preprocess(frames: tf.Tensor, label: tf.Tensor):
    """Preprocess the frames tensors and parse the labels."""
    # Preprocess images
    frames = tf.image.convert_image_dtype(
        frames[
            ..., tf.newaxis
        ],  # The new axis is to help for further processing with Conv3D layers
        tf.float32,
    )
    # Parse label
    label = tf.cast(label, tf.float32)
    return frames, label


def prepare_dataloader(
    videos: np.ndarray,
    labels: np.ndarray,
    loader_type: str = "train",
    batch_size: int = BATCH_SIZE,
):
    """Utility function to prepare the dataloader."""
    dataset = tf.data.Dataset.from_tensor_slices((videos, labels))

    if loader_type == "train":
        dataset = dataset.shuffle(BATCH_SIZE * 2)

    dataloader = (
        dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )
    return dataloader


trainloader = prepare_dataloader(train_videos, train_labels, "train")
validloader = prepare_dataloader(valid_videos, valid_labels, "valid")
testloader = prepare_dataloader(test_videos, test_labels, "test")

2023-12-05 14:51:19.335470: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:47] Overriding orig_value setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.


In [7]:
batch_size = 32
epochs = 2
def train_model(model):
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), optimizer="adam", metrics=["accuracy"])
    model.fit(trainloader, epochs=epochs, validation_data=validloader)
    score = model.evaluate(testloader)
    print("Test loss:", score[0])
    print("Test accuracy:", score[1])

In [8]:
from k3im.cct_3d import CCT3DModel
model = CCT3DModel(input_shape=(28, 28, 28, 1),
    num_heads=4,
    projection_dim=64,
    kernel_size=4,
    stride=4,
    padding=2,
    transformer_units=[16, 64],
    stochastic_depth_rate=0.6,
    transformer_layers=2,
    num_classes=num_classes,
    positional_emb=False,)

CUDA backend failed to initialize: Found CUDA version 11070, but JAX was built against version 11080, which is newer. The copy of CUDA that is installed must be at least as new as the version against which JAX was built. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [9]:
model.summary()

In [10]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m10s[0m 229ms/step - accuracy: 0.1104 - loss: 2.5978 - val_accuracy: 0.0870 - val_loss: 2.2054
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 74ms/step - accuracy: 0.2255 - loss: 2.1498 - val_accuracy: 0.3665 - val_loss: 1.5470
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 36ms/step - accuracy: 0.3300 - loss: 1.7090
Test loss: 1.7021818161010742
Test accuracy: 0.32786884903907776


In [11]:
from k3im.convmixer_3d import ConvMixer3DModel
model = ConvMixer3DModel(image_size=28,
    num_frames=28,
    filters=32,
    depth=2,
    kernel_size=4,
    kernel_depth=3,
    patch_size=3,
    patch_depth=3,
    num_classes=10,
    num_channels=1)

In [12]:
model.summary()

In [13]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 114ms/step - accuracy: 0.2658 - loss: 1.7671 - val_accuracy: 0.0932 - val_loss: 2.0907
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 58ms/step - accuracy: 0.4928 - loss: 1.3977 - val_accuracy: 0.0932 - val_loss: 2.0937
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 23ms/step - accuracy: 0.1219 - loss: 2.0012
Test loss: 2.020193576812744
Test accuracy: 0.10819672048091888


In [14]:
from k3im.eanet3d import EANet3DModel

In [15]:
model = EANet3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=64,
    depth=2,
    heads=4,
    mlp_dim=32,
    channels=1,
    dim_coefficient=4,
    projection_dropout=0.0,
    attention_dropout=0,
)

In [16]:
model.summary()

In [17]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 127ms/step - accuracy: 0.3240 - loss: 2.0638 - val_accuracy: 0.9130 - val_loss: 0.5357
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 28ms/step - accuracy: 0.7652 - loss: 0.8392 - val_accuracy: 0.9317 - val_loss: 0.2855
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 31ms/step - accuracy: 0.7249 - loss: 0.9507
Test loss: 0.9184524416923523
Test accuracy: 0.72295081615448


In [18]:
from k3im.gmlp_3d import gMLP3DModel
model = gMLP3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=4,
    hidden_units=32,
    dropout_rate=0.4,
    channels=1,
)

In [19]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m6s[0m 130ms/step - accuracy: 0.3039 - loss: 2.1179 - val_accuracy: 0.7826 - val_loss: 0.9232
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 14ms/step - accuracy: 0.7200 - loss: 1.0776 - val_accuracy: 0.8820 - val_loss: 0.4571
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.6478 - loss: 1.1318
Test loss: 1.0548360347747803
Test accuracy: 0.6770491600036621


In [20]:
from k3im.mlp_mixer_3d import MLPMixer3DModel

model = MLPMixer3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=4,
    hidden_units=32,
    dropout_rate=0.4,
    channels=1,
)
model.summary()

In [21]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 135ms/step - accuracy: 0.2796 - loss: 2.1830 - val_accuracy: 0.8820 - val_loss: 0.4691
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.7445 - loss: 0.8517 - val_accuracy: 0.9193 - val_loss: 0.2541
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step - accuracy: 0.6725 - loss: 1.0839
Test loss: 1.0160973072052002
Test accuracy: 0.6819671988487244


In [None]:
from k3im.simple_vit_3d import SimpleViT3DModel

model = SimpleViT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=32,
    channels=1,
    dim_head=64,
)

In [None]:
model.summary()

In [None]:
train_model(model)

Epoch 1/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 124ms/step - accuracy: 0.7908 - loss: 0.6558 - val_accuracy: 0.9317 - val_loss: 0.2221
Epoch 2/2
[1m31/31[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 13ms/step - accuracy: 0.8924 - loss: 0.3603 - val_accuracy: 0.9317 - val_loss: 0.1682
[1m20/20[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 24ms/step - accuracy: 0.7327 - loss: 1.0193
Test loss: 0.923545241355896
Test accuracy: 0.7344262003898621


In [None]:
from k3im.vit_3d import ViT3DModel

In [None]:
model = ViT3DModel(
    image_size=28,
    image_patch_size=7,
    frames=28,
    frame_patch_size=7,
    num_classes=num_classes,
    dim=32,
    depth=2,
    heads=4,
    mlp_dim=32,
    pool='cls',
    channels=3,
    dim_head=64,
) ############ ERRRRRRRRRRRRRRRRRRRRRRRRRRR

In [None]:
model.summary()

In [None]:
train_model(model)