In [2]:
import numpy as np
import tensorflow.keras as keras
from cffi import model
from keras import layers
from keras.src.callbacks import early_stopping
from keras.src.metrics.accuracy_metrics import accuracy
from sympy import sequence, factor
from tensorflow.python.eager.profiler import start
from tensorflow.python.feature_column.utils import sequence_length_from_sparse_tensor
from tensorflow.python.keras.backend import learning_phase

In [3]:
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)


In [4]:
weight_decay = 0.0001
batch_size = 128
num_epochs = 1
dropout_rate = 0.2
image_size = 64
patch_size = 8
num_patches = (image_size // patch_size) ** 2
embedding_dim = 256
num_blocks = 4

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {image_size ** 2}")
print(f"Patches per images: {num_patches}")
print(f"Elements per patch (3 channels): {(num_patches ** 2) * 2}")


Image size: 64 X 64 = 4096
Patch size: 8 X 8 = 4096
Patches per images: 64
Elements per patch (3 channels): 8192


In [5]:
def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)

    augmented = data_augmentation(inputs)

    patches = Patches(patch_size)(augmented)

    x = layers.Dense(units=embedding_dim)(patches)

    if positional_encoding:
        x = x + PositionalEmbedding(sequence_length=num_patches)(x)

    x = blocks(x)

    representation = layers.GlobalAveragePooling1D()(x)

    representation = layers.Dropout(rate=dropout_rate)(representation)

    logits = layers.Dense(num_classes)(representation)

    return keras.Model(inputs=inputs, outputs=logits)

In [6]:
def run_experiment(model):
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
                 keras.metrics.SparseTopKCategoricalAccuracy(5, name="Top_5")],
    )

    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor='val_loss', factor=0.5, patience=5
    )

    early_stopping = keras.callbacks.EarlyStopping(
        monitor='val_loss', patience=10, restore_best_weights=True
    )

    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
        verbose=0,
    )

    _, acc, top_5 = model.evaluate(x_test, y_test)

    print(f"Accuracy: {acc}")
    print(f"Top 5 accuracy: {top_5}")

    return history


In [7]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ]
)

data_augmentation.layers[0].adapt(x_train)

I0000 00:00:1743692574.427606   38390 gpu_device.cc:2019] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 9711 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 3060, pci bus id: 0000:01:00.0, compute capability: 8.6


In [8]:
class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        patches = keras.ops.image.extract_patches(x, self.patch_size)
        batch = keras.ops.shape(patches)[0]
        num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]
        patch_dim = keras.ops.shape(patches)[3]
        out = keras.ops.reshape(patches, (batch, num_patches, patch_dim))

        return out

In [9]:
class PositionalEmbedding(keras.layers.Layer):
    def __init__(self, sequence_length, initalizer="glorot_uniform", **kwargs):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("Sequence length cannot be None")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initalizer)

    def get_config(self):
        config = self.get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )

        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embedding = self.add_weight(
            name="embedding",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_size = shape[-1]
        sequence_length = shape[-2]

        position_embedding = keras.ops.covert_to_tensor(self.position_embedding)
        position_embedding = keras.ops.slice(
            position_embedding,
            (start_index, 0),
            (self.sequence_length, feature_size)
        )

        return keras.ops.broadcast_to(position_embedding, shape)

    def compute_output_shape(self, input_shape):
        return input_shape


In [10]:
class MLPMixerLayers(layers.Layer):
    def __init__(self, num_classes, hidden_units, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate)
            ]
        )

        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate)
            ]
        )

        self.normalize = layers.LayerNormalization(epsilon=1e-6)


    def build(self, input_shape):
        return super().build(input_shape)

    def call(self, inputs):
        x = self.normalize(inputs)

        x_channels = keras.ops.transpose(x, axes=(0, 2, 1))

        mlp1_outputs = self.mlp1(x_channels)
        mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))

        x = mlp1_outputs + inputs

        x_patches = self.normalize(x)

        mlp2_outputs = self.mlp2(x_patches)

        x = x + mlp2_outputs

        return x



In [11]:
mlpmixer_blocks = keras.Sequential(
    [
        MLPMixerLayers(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)
    ]
)

learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)


I0000 00:00:1743692671.262757   49115 cuda_dnn.cc:529] Loaded cuDNN version 90300


[1m313/313[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2s[0m 6ms/step - Top_5: 0.9252 - accuracy: 0.4860 - loss: 1.4514
Accuracy: 0.4821000099182129
Top 5 accuracy: 0.9247000217437744
