<a href="https://colab.research.google.com/github/ashaduzzaman-sarker/Computer-Vision-Projects/blob/main/Image_classification_with_modern_MLP(MLP_Mixer%2C_FNet%2C_and_gMLP)_models_for_CIFAR_100_using_Keras.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image classification with modern MLP (MLP-Mixer, FNet, and gMLP) models for CIFAR-100 using Keras

**Author:** [Ashaduzzaman Piash](https://github.com/ashaduzzaman-sarker/)
<br>
**Date created:** 2024/06/17

## Introduction

### Overview
This project demonstrates the implementation of three modern attention-free, multi-layer perceptron (MLP) based models for image classification, using the CIFAR-100 dataset:

- **[MLP-Mixer Model](https://doi.org/10.48550/arXiv.2105.01601)**: Developed by Ilya Tolstikhin et al., this model utilizes two types of MLPs to process images.
- **[FNet Model](https://doi.org/10.48550/arXiv.2105.03824)**: Proposed by James Lee-Thorp et al., it relies on the unparameterized Fourier Transform for feature extraction.
- **[gMLP Model](https://doi.org/10.48550/arXiv.2105.08050)**: Created by Hanxiao Liu et al., this model incorporates gating mechanisms within MLPs.

### Purpose

The goal of this example is not to compare the performance of these models, as their effectiveness can vary across different datasets and with optimized hyperparameters. Instead, the focus is on providing simple implementations of their core components.

### Key Points

- **MLP-Mixer**: Combines two types of MLPs, one that mixes information across patches (spatial mixing) and one that mixes features across channels within each patch.
- **FNet**: Uses Fourier Transforms instead of attention mechanisms for efficient feature extraction.
- **gMLP**: Incorporates gating mechanisms to improve the expressiveness of standard MLPs.

### Implementation Focus

- Demonstrates the main building blocks of each model.
- Uses the CIFAR-100 dataset for image classification tasks.
- Provides insights into attention-free neural network architectures for image processing.

## Setup

In [1]:
# Upgrade Keras (this notebook targets Keras 3).
!pip install --upgrade keras



In [2]:
import numpy as np
import keras
from keras import layers

## Prepare the data

In [3]:
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")

x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)


## Configure the hyperparameters

In [4]:
weight_decay = 0.0001
batch_size = 128
num_epochs = 10  # Recommended num_epochs = 50
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size.
patch_size = 8  # Size of the patches to be extracted from the input images.
num_patches = (image_size // patch_size) ** 2  # Number of patches per image.
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of blocks.

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")

Image size: 64 X 64 = 4096
Patch size: 8 X 8 = 64 
Patches per image: 64
Elements per patch (3 channels): 192
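
The patch arithmetic printed above can be reused for other image and patch sizes. The following is a minimal sketch; `patch_geometry` is a hypothetical helper that is not part of the notebook, and its divisibility check is an assumption of this example rather than something the notebook enforces.

```python
def patch_geometry(image_size, patch_size, channels=3):
    """Return (patches_per_image, elements_per_patch) for square images.

    Hypothetical helper for illustration only; not part of the notebook.
    """
    # Assumption of this sketch: reject sizes that do not tile the image evenly.
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side ** 2, patch_size ** 2 * channels


num_patches_demo, patch_dim_demo = patch_geometry(image_size=64, patch_size=8)
print(num_patches_demo, patch_dim_demo)  # 64 192
```

With the notebook's settings this reproduces the 64 patches and 192 elements per patch reported above.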


## Build a classification model

In [5]:
def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data.
    augmented = data_augmentation(inputs)
    # Create patches.
    patches = Patches(patch_size)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor.
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        x = x + PositionEmbedding(sequence_length=num_patches)(x)
    # Process x using the module blocks.
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor.
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation)
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)
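
To make the tensor shapes in `build_classifier` easier to follow, the sketch below traces them with plain NumPy. It is not the Keras model: the random weights, the omitted biases, and the omission of the mixing blocks (which keep the shape unchanged) are all assumptions made for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, num_patches, patch_dim = 2, 64, 192
embedding_dim, num_classes = 256, 100

# Flattened patches, as produced by the patch-extraction step.
patches = rng.normal(size=(batch, num_patches, patch_dim))

# Patch embedding: a dense projection applied to the last axis (bias omitted).
w_embed = rng.normal(size=(patch_dim, embedding_dim)) * 0.02
x = patches @ w_embed
print(x.shape)  # (2, 64, 256)

# The mixing blocks would run here; they preserve the shape of x.

# Global average pooling over the patch axis.
representation = x.mean(axis=1)
print(representation.shape)  # (2, 256)

# Classification head producing one logit per class (bias omitted).
w_head = rng.normal(size=(embedding_dim, num_classes)) * 0.02
logits = representation @ w_head
print(logits.shape)  # (2, 100)
```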

## Define an experiment

In [6]:
def run_experiment(model):
    # Create an AdamW optimizer (Adam with decoupled weight decay).
    # Note: `learning_rate` is a global that is set before each experiment.
    optimizer = keras.optimizers.AdamW(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback.
    early_stopping = keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
        verbose=0,
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history
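
The learning-rate callback above uses `factor=0.5` and `patience=5`. The sketch below illustrates the general reduce-on-plateau idea in plain Python. It is a simplified stand-in, not the Keras implementation: `plateau_schedule` is a hypothetical function, and details such as `min_delta` and `cooldown` are deliberately ignored.

```python
def plateau_schedule(val_losses, initial_lr, factor=0.5, patience=5):
    """Return the learning rate in effect after each epoch.

    Simplified rule (an assumption of this sketch): multiply the learning
    rate by `factor` once `patience` consecutive epochs pass without a new
    best validation loss, then start counting again.
    """
    lr = initial_lr
    best = float("inf")
    wait = 0
    rates = []
    for loss in val_losses:
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                lr *= factor
                wait = 0
        rates.append(lr)
    return rates


# Two improving epochs followed by five epochs without improvement.
losses = [1.0, 0.9, 0.95, 0.95, 0.95, 0.95, 0.95]
rates = plateau_schedule(losses, initial_lr=0.005)
print(rates[-1] < rates[0])  # True
```

In this toy run the rate stays at its initial value for six epochs and is halved after the fifth consecutive epoch without improvement.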

## Use data augmentation

In [7]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)
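
The `adapt` call above computes normalization statistics from the training images. A rough NumPy sketch of per-channel standardization is shown below. The random stand-in images and the small floor on the standard deviation are assumptions of this example, and the exact formula used by the Keras layer may differ in detail.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for a small batch of CIFAR-sized images (random values, not real data).
images = rng.integers(0, 256, size=(8, 32, 32, 3)).astype("float64")

# Per-channel statistics: reduce over the batch, height, and width axes.
mean = images.mean(axis=(0, 1, 2))
var = images.var(axis=(0, 1, 2))

# Standardize each channel; the floor avoids division by zero (assumed value).
normalized = (images - mean) / np.maximum(np.sqrt(var), 1e-7)

print(mean.shape, var.shape)  # (3,) (3,)
```

After this step each channel has approximately zero mean and unit variance over the data the statistics were computed from.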

## Implement patch extraction as a layer

In [8]:
class Patches(layers.Layer):
    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        patches = keras.ops.image.extract_patches(x, self.patch_size)
        batch_size = keras.ops.shape(patches)[0]
        num_patches = keras.ops.shape(patches)[1] * keras.ops.shape(patches)[2]
        patch_dim = keras.ops.shape(patches)[3]
        out = keras.ops.reshape(patches, (batch_size, num_patches, patch_dim))
        return out
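
For intuition, non-overlapping patch extraction can also be written with reshapes and a transpose. The NumPy sketch below is an illustration under assumptions: `extract_patches_np` is a hypothetical function, and the ordering of elements inside each flattened patch (rows, then columns, then channels) is assumed rather than verified against `keras.ops.image.extract_patches`.

```python
import numpy as np


def extract_patches_np(images, patch_size):
    """Split [batch, height, width, channels] images into flattened patches.

    Hypothetical helper for illustration; returns an array of shape
    [batch, num_patches, patch_size * patch_size * channels].
    """
    batch, height, width, channels = images.shape
    rows, cols = height // patch_size, width // patch_size
    # Drop any remainder that does not fill a whole patch.
    images = images[:, : rows * patch_size, : cols * patch_size, :]
    # Separate each spatial axis into (patch index, offset within patch).
    x = images.reshape(batch, rows, patch_size, cols, patch_size, channels)
    # Bring the two patch indices together, followed by the in-patch offsets.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    return x.reshape(batch, rows * cols, patch_size * patch_size * channels)


images = np.arange(2 * 64 * 64 * 3, dtype="float32").reshape(2, 64, 64, 3)
patches = extract_patches_np(images, patch_size=8)
print(patches.shape)  # (2, 64, 192)
```

Patches are numbered row by row, so patch 1 is the block immediately to the right of patch 0.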

## Implement position embedding as a layer

In [9]:
class PositionEmbedding(keras.layers.Layer):
    def __init__(
        self,
        sequence_length,
        initializer="glorot_uniform",
        **kwargs,
    ):
        super().__init__(**kwargs)
        if sequence_length is None:
            raise ValueError("`sequence_length` must be an Integer, received `None`.")
        self.sequence_length = int(sequence_length)
        self.initializer = keras.initializers.get(initializer)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "sequence_length": self.sequence_length,
                "initializer": keras.initializers.serialize(self.initializer),
            }
        )
        return config

    def build(self, input_shape):
        feature_size = input_shape[-1]
        self.position_embeddings = self.add_weight(
            name="embeddings",
            shape=[self.sequence_length, feature_size],
            initializer=self.initializer,
            trainable=True,
        )

        super().build(input_shape)

    def call(self, inputs, start_index=0):
        shape = keras.ops.shape(inputs)
        feature_length = shape[-1]
        sequence_length = shape[-2]
        # trim to match the length of the input sequence, which might be less
        # than the sequence_length of the layer.
        position_embeddings = keras.ops.convert_to_tensor(self.position_embeddings)
        position_embeddings = keras.ops.slice(
            position_embeddings,
            (start_index, 0),
            (sequence_length, feature_length),
        )
        return keras.ops.broadcast_to(position_embeddings, shape)

    def compute_output_shape(self, input_shape):
        return input_shape
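
The layer above returns one learned vector per patch position, broadcast over the batch, which the classifier then adds to the patch embeddings. A minimal NumPy sketch of that slice-and-broadcast step follows; the random `table` stands in for the trainable weight and is an assumption of this example.

```python
import numpy as np

rng = np.random.default_rng(0)
batch, seq_len, dim = 2, 64, 256

x = rng.normal(size=(batch, seq_len, dim))       # patch embeddings
table = rng.normal(size=(seq_len, dim)) * 0.02   # stand-in for the learned table

# Slice the table to the input length, then repeat it across the batch.
positions = np.broadcast_to(table[:seq_len, :dim], x.shape)
out = x + positions

print(out.shape)  # (2, 64, 256)
```

Every example in the batch receives the same positional offsets, which is what lets the model tell patch positions apart.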

## The MLP-Mixer Model

The MLP-Mixer is a novel neural network architecture that relies exclusively on multi-layer perceptrons (MLPs). It uses two types of MLP layers to process image data:

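1. **Patch-wise MLPs** (called channel-mixing MLPs in the paper): Applied to each image patch independently, these layers mix the per-location features across channels.
2. **Channel-wise MLPs** (called token-mixing MLPs in the paper): Applied across patches, independently for each channel, these layers mix spatial information.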

This approach is analogous to depthwise separable convolutions used in models like Xception, but with some key differences:

- **Chained Dense Transforms**: Instead of convolutions, the MLP-Mixer chains two dense (fully connected) transformations.
- **No Max Pooling**: The model does not use max pooling for down-sampling.
- **Layer Normalization**: Replaces batch normalization with layer normalization to stabilize and accelerate training.

### Key Points

- **Patch-wise MLPs**: Process each patch of the image independently, mixing the features (channels) within that patch.
- **Channel-wise MLPs**: Process each channel independently, mixing values across different patches to capture spatial relationships.
- **Chained Dense Layers**: Utilizes dense layers in place of convolutions to perform feature mixing.
- **Layer Normalization**: Ensures stability and faster convergence during training.
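
The two mixing steps described above can be sketched in NumPy to show how a transpose turns a dense layer into a cross-patch operation. This is an illustration under assumptions, not the notebook's layer: `mixer_block` is a hypothetical function, biases and dropout are omitted, layer normalization has no learned parameters, the weights are random, and GELU uses the tanh approximation.

```python
import numpy as np


def gelu(x):
    # Tanh approximation of GELU (an assumption of this sketch).
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))


def layer_norm(x, eps=1e-6):
    # Normalize over the last axis; learned scale and offset are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def mixer_block(x, w_tok1, w_tok2, w_ch1, w_ch2):
    # Mix across patches: transpose so the weights act on the patch axis.
    y = layer_norm(x).transpose(0, 2, 1)                 # [batch, channels, patches]
    y = (gelu(y @ w_tok1) @ w_tok2).transpose(0, 2, 1)   # back to [batch, patches, channels]
    x = x + y                                            # skip connection
    # Mix across channels: the weights act on the channel axis of each patch.
    z = gelu(layer_norm(x) @ w_ch1) @ w_ch2
    return x + z                                         # skip connection


rng = np.random.default_rng(0)
batch, num_patches, dim = 2, 64, 256
x = rng.normal(size=(batch, num_patches, dim))
w_tok1 = rng.normal(size=(num_patches, num_patches)) * 0.1
w_tok2 = rng.normal(size=(num_patches, num_patches)) * 0.1
w_ch1 = rng.normal(size=(dim, dim)) * 0.1
w_ch2 = rng.normal(size=(dim, dim)) * 0.1

out = mixer_block(x, w_tok1, w_tok2, w_ch1, w_ch2)
print(out.shape)  # (2, 64, 256)
```

Because the first MLP operates on the patch axis, changing the content of one patch changes the output at every other patch, while the second MLP only ever looks at one patch at a time.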

![](https://camo.githubusercontent.com/20dc92bd1da191765477f203d02535239ffdaad6d12b4fc0a5ddd67b3d3f290e/68747470733a2f2f73746f726167652e676f6f676c65617069732e636f6d2f70726f746f6e782d636c6f75642d73746f726167652f43617074757265332e504e47)

### Implement the MLP-Mixer module

In [10]:
class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches, activation="gelu"),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        return super().build(input_shape)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [batch_size, num_patches, hidden_units] to [batch_size, hidden_units, num_patches].
        x_channels = keras.ops.transpose(x, axes=(0, 2, 1))
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [batch_size, hidden_units, num_patches] to [batch_size, num_patches, hidden_units].
        mlp1_outputs = keras.ops.transpose(mlp1_outputs, axes=(0, 2, 1))
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization (the same layer instance is reused here).
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independently.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x

### Build, train, and evaluate the MLP-Mixer model

In [13]:
mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks)
history = run_experiment(mlpmixer_classifier)

313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 5ms/step - acc: 0.3523 - loss: 2.6181 - top5-acc: 0.6757
Test accuracy: 35.57%
Test top 5 accuracy: 67.04%


## The FNet model

The FNet is a neural network architecture inspired by the Transformer block but replaces the self-attention mechanism with a parameter-free 2D Fourier transformation layer. Here's how it works:

- **Fourier Transform**: Instead of using self-attention to capture dependencies in the input data, FNet applies the Fourier Transform.
  - **1D Fourier Transform Along Patches**: This transform is applied along the sequence of image patches.
  - **1D Fourier Transform Along Channels**: This transform is applied along the feature channels of the image.

### Key Points

- **Transformer-like Structure**: FNet retains a similar structure to the Transformer block but swaps out the self-attention layer for a Fourier Transform layer.
- **Parameter-free**: The Fourier Transform layers do not have learnable parameters, simplifying the model.
- **Two-dimensional Fourier Transform**: Consists of two 1D Fourier Transforms, one along the patches and one along the channels.
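
The claim that the 2D transform consists of two 1D transforms can be checked numerically. The NumPy sketch below uses random stand-in data and keeps only the real part of the result, as the FNet layer in this notebook does. It illustrates the mixing step only, not the full layer.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for embedded patches: [batch, num_patches, embedding_dim].
x = rng.normal(size=(2, 64, 256))

# 2D FFT over the last two axes (patches and channels), real part only.
mixed = np.fft.fft2(x).real

# The same result from two 1D FFTs: along channels, then along patches.
step_by_step = np.fft.fft(np.fft.fft(x, axis=-1), axis=-2).real

print(mixed.shape)                        # (2, 64, 256)
print(np.allclose(mixed, step_by_step))   # True
```

Each output element depends on every patch and every channel of the input, which is how the transform mixes information without any learnable weights.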

![](https://miro.medium.com/v2/resize:fit:1010/1*7ZfynrPBS6jNIu4U49TMCA.png)

### Implement the FNet module

In [23]:
class FNetLayer(layers.Layer):
    def __init__(self, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.ffn = keras.Sequential(
            [
                layers.Dense(units=embedding_dim, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
                layers.Dense(units=embedding_dim),
            ]
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply the 2D Fourier transform and keep only the real part.
        real_part = inputs
        im_part = keras.ops.zeros_like(inputs)
        x = keras.ops.fft2((real_part, im_part))[0]
        # Add skip connection.
        x = x + inputs
        # Apply layer normalization.
        x = self.normalize1(x)
        # Apply the feedforward network.
        x_ffn = self.ffn(x)
        # Add skip connection.
        x = x + x_ffn
        # Apply layer normalization.
        return self.normalize2(x)

### Build, train, and evaluate the FNet model


In [24]:
fnet_blocks = keras.Sequential(
    [FNetLayer(embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.001
fnet_classifier = build_classifier(fnet_blocks, positional_encoding=True)
history = run_experiment(fnet_classifier)

313/313 ━━━━━━━━━━━━━━━━━━━━ 2s 6ms/step - acc: 0.3834 - loss: 2.3661 - top5-acc: 0.7068
Test accuracy: 37.87%
Test top 5 accuracy: 70.16%


## The gMLP Model

The gMLP (Gated MLP) is an MLP architecture built around a Spatial Gating Unit (SGU), which enables interactions across image patches along the spatial (patch) dimension. Here's how it works:

1. **Spatial Transformation**: The input is split in two along the channels, and one half is transformed by a linear projection across patches, applied to each channel independently.
2. **Element-wise Multiplication**: The other half of the input is multiplied element-wise with this spatial transformation.

### Key Points

- **Spatial Gating Unit (SGU)**: This unit facilitates interactions across patches by transforming the input spatially and then gating it.
- **Linear Projection**: A linear transformation is applied across the patch dimension, independently for each channel.
- **Element-wise Multiplication**: The projected half of the input gates the other half through element-wise multiplication, enabling the model to learn cross-patch relationships.
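
The gating step described above can be sketched in NumPy. This is an illustration under assumptions: the function below is a simplified stand-in for the notebook's method, layer normalization has no learned parameters, and the spatial weights are set to zero purely to make the effect of a ones bias visible (the notebook sets only the bias initializer to ones and keeps the default kernel initializer).

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    # Normalize over the last axis; learned scale and offset are omitted.
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def spatial_gating_unit(x, w_spatial, b_spatial):
    # Split the channels into two halves of equal width.
    u, v = np.split(x, 2, axis=-1)
    v = layer_norm(v)
    # Project along the patch axis: [batch, channels, patches] @ [patches, patches].
    v = v.transpose(0, 2, 1) @ w_spatial + b_spatial
    v = v.transpose(0, 2, 1)
    # Gate one half with the spatially projected other half.
    return u * v


rng = np.random.default_rng(0)
batch, num_patches, dim = 2, 64, 256
x = rng.normal(size=(batch, num_patches, 2 * dim))

# Illustrative assumption: zero weights and a ones bias make the gate equal to 1.
w_spatial = np.zeros((num_patches, num_patches))
b_spatial = np.ones(num_patches)

out = spatial_gating_unit(x, w_spatial, b_spatial)
print(out.shape)                          # (2, 64, 256)
print(np.allclose(out, x[..., :dim]))     # True
```

With these illustrative values the unit passes the first half of the input through unchanged. Once the spatial weights are non-zero, each output patch is scaled by a combination of all patches.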

![](https://production-media.paperswithcode.com/methods/641e1c00-a87b-40ce-a0ab-af50ac6aa318.png)

### Implement the gMLP module

In [27]:
class gMLPLayer(layers.Layer):
    def __init__(self, num_patches, embedding_dim, dropout_rate, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.channel_projection1 = keras.Sequential(
            [
                layers.Dense(units=embedding_dim * 2, activation="gelu"),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.channel_projection2 = layers.Dense(units=embedding_dim)

        self.spatial_projection = layers.Dense(
            units=num_patches, bias_initializer="Ones"
        )

        self.normalize1 = layers.LayerNormalization(epsilon=1e-6)
        self.normalize2 = layers.LayerNormalization(epsilon=1e-6)

    def spatial_gating_unit(self, x):
        # Split x along the channel dimension.
        # Tensors u and v will each have shape [batch_size, num_patches, embedding_dim].
        u, v = keras.ops.split(x, indices_or_sections=2, axis=2)
        # Apply layer normalization.
        v = self.normalize2(v)
        # Apply spatial projection.
        v_channels = keras.ops.transpose(v, axes=(0, 2, 1))
        v_projected = self.spatial_projection(v_channels)
        v_projected = keras.ops.transpose(v_projected, axes=(0, 2, 1))
        # Apply element-wise multiplication.
        return u * v_projected

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize1(inputs)
        # Apply the first channel projection. x_projected shape: [batch_size, num_patches, embedding_dim * 2].
        x_projected = self.channel_projection1(x)
        # Apply the spatial gating unit. x_spatial shape: [batch_size, num_patches, embedding_dim].
        x_spatial = self.spatial_gating_unit(x_projected)
        # Apply the second channel projection. x_projected shape: [batch_size, num_patches, embedding_dim].
        x_projected = self.channel_projection2(x_spatial)
        # Add skip connection.
        return x + x_projected

### Build, train, and evaluate the gMLP model

In [28]:
gmlp_blocks = keras.Sequential(
    [gMLPLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]
)
learning_rate = 0.003
gmlp_classifier = build_classifier(gmlp_blocks)
history = run_experiment(gmlp_classifier)

313/313 ━━━━━━━━━━━━━━━━━━━━ 3s 9ms/step - acc: 0.3886 - loss: 2.3994 - top5-acc: 0.6950
Test accuracy: 38.79%
Test top 5 accuracy: 69.64%
