# Study of Image classification with modern MLP Mixer model and CKA

**Author:** [Arturo Flores](https://www.linkedin.com/in/afloresalv/)<br>

## Introduction

This example implements the [MLP-Mixer](https://arxiv.org/abs/2105.01601) model, by Ilya Tolstikhin et al.,
a modern attention-free model for image classification based on two types of multi-layer perceptrons (MLPs),
and demonstrates it on the CIFAR-100 dataset. The activations of the trained Mixer blocks are then compared
using centered kernel alignment (CKA).

The purpose of the example is not to benchmark the model, as it might perform differently on
different datasets with well-tuned hyperparameters. Rather, it is to show a simple implementation of its
main building blocks.

This example requires TensorFlow 2.4 or higher, as well as
[TensorFlow Addons](https://www.tensorflow.org/addons/overview),
which can be installed using the following command:

```shell
pip install -U tensorflow-addons
pip install -U "tensorflow==2.7.0"  # update TensorFlow core to 2.7.0
```

## Setup

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa
import matplotlib.pyplot as plt 
from CKA import linear_CKA, kernel_CKA
import seaborn as sns 

## Prepare the data
CIFAR-100 has 100 classes, each with 600 images.

In [None]:
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

print(f"x_train shape: {x_train.shape} - y_train shape: {y_train.shape}")
print(f"x_test shape: {x_test.shape} - y_test shape: {y_test.shape}")
#plt.imshow(x_train[1])

## Configure the hyperparameters

In [1]:
weight_decay = 0.0001
batch_size = 128 # The paper also fine tunes this to 512
num_epochs = 20
dropout_rate = 0.2
image_size = 64  # We'll resize input images to this size. Square
patch_size = 8  # Size of the patches to be extracted from the input images. Square
num_patches = (image_size // patch_size) ** 2  # Size of the data array, or sequence length (S)
embedding_dim = 256  # Number of hidden units.
num_blocks = 4  # Number of Mixer Layers

print(f"Image size: {image_size} X {image_size} = {image_size ** 2}")
print(f"Patch size: {patch_size} X {patch_size} = {patch_size ** 2} ")
print(f"Patches per image: {num_patches}")
print(f"Elements per patch (3 channels): {(patch_size ** 2) * 3}")

Image size: 64 X 64 = 4096
Patch size: 8 X 8 = 64 
Patches per image: 64
Elements per patch (3 channels): 192


## Build a classification model

We implement a function that builds a classifier given the processing blocks. \
Positional encoding: https://kazemnejad.com/blog/transformer_architecture_positional_encoding/


In [None]:

def build_classifier(blocks, positional_encoding=False):
    inputs = layers.Input(shape=input_shape)
    # Augment data. ## (data_augmentation)
    augmented = data_augmentation(inputs)
    # Create patches. ## (patches_4)
    patches = Patches(patch_size, num_patches)(augmented)
    # Encode patches to generate a [batch_size, num_patches, embedding_dim] tensor. ## (dense_163)
    x = layers.Dense(units=embedding_dim)(patches)
    if positional_encoding:
        positions = tf.range(start=0, limit=num_patches, delta=1)
        position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=embedding_dim
        )(positions)
        x = x + position_embedding
    # Process x using the module blocks. ## (sequential_82)
    x = blocks(x)
    # Apply global average pooling to generate a [batch_size, embedding_dim] representation tensor. ## (Global Average Pooling)
    representation = layers.GlobalAveragePooling1D()(x)
    # Apply dropout.
    representation = layers.Dropout(rate=dropout_rate)(representation) ## (Dropout)
    # Compute logits outputs.
    logits = layers.Dense(num_classes)(representation) ## (dense_164) - output
    # Create the Keras model.
    return keras.Model(inputs=inputs, outputs=logits)


## Define an experiment

We implement a utility function to compile, train, and evaluate a given model. \
Adam Algorithm with Weight Decay: https://www.tensorflow.org/addons/api_docs/python/tfa/optimizers/AdamW \
Losses: https://keras.io/api/losses/ \
Reduce learning rate: https://keras.io/api/callbacks/reduce_lr_on_plateau/ \
Logits: https://www.youtube.com/watch?v=icQaFxKa_J0
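To make the loss and metrics concrete, here is a minimal NumPy sketch of what sparse categorical cross-entropy on logits and top-k accuracy compute. The function names and the toy values are illustrative assumptions, not part of Keras; the actual training uses the Keras classes referenced above.

```python
import numpy as np


def sparse_categorical_crossentropy_from_logits(logits, labels):
    """Mean negative log-likelihood of integer labels given raw logits."""
    # Subtract the row-wise max before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()


def top_k_accuracy(logits, labels, k=5):
    """Fraction of rows whose true label is among the k largest logits."""
    top_k = np.argsort(logits, axis=1)[:, -k:]
    return np.mean([label in row for label, row in zip(labels, top_k)])


# Toy batch: 2 samples, 3 classes (values are illustrative only).
logits = np.array([[2.0, 1.0, 0.1], [0.5, 2.5, 0.3]])
labels = np.array([0, 0])
loss = sparse_categorical_crossentropy_from_logits(logits, labels)
top1 = top_k_accuracy(logits, labels, k=1)
top2 = top_k_accuracy(logits, labels, k=2)
```

The second sample is misclassified at top-1 but its true class has the second-largest logit, so it counts as correct at top-2. This is why the top-5 accuracy reported during training is never lower than the plain accuracy.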


In [None]:

def run_experiment(model):
    # Create Adam optimizer with weight decay: a regularizer that penalizes
    # large weights (scaled by a factor alpha) to reduce overfitting.
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay,
    )
    # Compile the model.
    model.compile(
        optimizer=optimizer,
        #Negative Log Likelihood = Categorical Cross Entropy
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="acc"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top5-acc"),
        ],
    )
    # Create a learning rate scheduler callback.
    reduce_lr = keras.callbacks.ReduceLROnPlateau(
        monitor="val_loss", factor=0.5, patience=5
    )
    # Create an early stopping callback, which also acts as a regularizer.
    # Under simplifying assumptions, its effect resembles that of L2 regularization.
    early_stopping = tf.keras.callbacks.EarlyStopping(
        monitor="val_loss", patience=10, restore_best_weights=True
    )
    # Fit the model.
    history = model.fit(
        x=x_train,
        y=y_train,
        batch_size=batch_size,
        epochs=num_epochs,
        validation_split=0.1,
        callbacks=[early_stopping, reduce_lr],
    )

    _, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

    # Return history to plot learning curves.
    return history


## Use data augmentation
The state of a preprocessing layer such as `Normalization` is not set during training; it must be set before training, either by initializing it from a precomputed constant, or by "adapting" it on data.
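As a rough illustration of what "adapting" means, the NumPy sketch below computes per-channel statistics once and then applies them as fixed constants. The helper names, the `epsilon` value, and the random stand-in data are assumptions made for this example; they are not the Keras implementation.

```python
import numpy as np


def adapt_statistics(images):
    """Per-channel mean and variance over batch, height, and width."""
    mean = images.mean(axis=(0, 1, 2))
    variance = images.var(axis=(0, 1, 2))
    return mean, variance


def normalize(images, mean, variance, epsilon=1e-7):
    """Apply precomputed statistics; nothing here is updated by gradient descent."""
    return (images - mean) / np.maximum(np.sqrt(variance), epsilon)


rng = np.random.default_rng(0)
# Random stand-in for x_train with shape [batch, height, width, channels].
fake_train = rng.uniform(0, 255, size=(16, 32, 32, 3))
mean, variance = adapt_statistics(fake_train)
normalized = normalize(fake_train, mean, variance)
```

After this step each channel of the adapted data has approximately zero mean and unit variance, and the same constants are reused unchanged at training and inference time.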

In [None]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.Resizing(image_size, image_size),
        layers.RandomFlip("horizontal"),
        layers.RandomZoom(
            height_factor=0.2, width_factor=0.2
        ),
    ],
    name="data_augmentation",
)
# Compute the mean and the variance of the training data for normalization.
data_augmentation.layers[0].adapt(x_train)


## Implement patch extraction as a layer
Attributes and inheritance: https://pythones.net/funcion-super-en-python-bien-explicada-ejemplos-oop/ \
Extract Patches: https://www.tensorflow.org/api_docs/python/tf/image/extract_patches \
Reshape: https://www.tensorflow.org/api_docs/python/tf/reshape \
If one component of shape is the special value -1, the size of that dimension is computed so that the total size remains constant. \
Preprocessing data: https://www.tensorflow.org/guide/keras/preprocessing_layers
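The following NumPy sketch shows one way to cut non-overlapping square patches using only `reshape` and `transpose`, including the `-1` shorthand mentioned above. It is meant to mirror the idea of the patch layer, not to reproduce `tf.image.extract_patches` exactly; the function name `extract_patches_np` is hypothetical.

```python
import numpy as np


def extract_patches_np(images, patch_size):
    """Split [batch, H, W, C] images into flattened non-overlapping patches."""
    batch, height, width, channels = images.shape
    rows, cols = height // patch_size, width // patch_size
    # Split height and width into (grid index, offset inside the patch).
    x = images.reshape(batch, rows, patch_size, cols, patch_size, channels)
    # Put the two grid axes first, followed by the pixels of each patch.
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # -1 lets NumPy infer patch_size * patch_size * channels.
    return x.reshape(batch, rows * cols, -1)


images = np.arange(2 * 64 * 64 * 3, dtype=np.float32).reshape(2, 64, 64, 3)
patches = extract_patches_np(images, patch_size=8)
print(patches.shape)  # (2, 64, 192)
```

With 64 x 64 inputs and 8 x 8 patches this gives 64 patches of 192 values each, matching the numbers printed in the hyperparameter cell.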

In [None]:

class Patches(layers.Layer):
    def __init__(self, patch_size, num_patches):
        super(Patches, self).__init__()
        self.patch_size = patch_size
        self.num_patches = num_patches

    def call(self, images):
        # Dimension 0 of the input is the (dynamic) batch size.
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            #Without overlapping, stride horizontally and vertically
            strides=[1, self.patch_size, self.patch_size, 1],
            #Rate: Dilation factor [1 1* 1* 1] controls the spacing between the kernel points.
            rates=[1, 1, 1, 1],
            #Patches contained in the images are considered, no zero padding
            padding="VALID",
        )
        # shape[-1] is the flattened patch length: patch_size * patch_size * channels.
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, self.num_patches, patch_dims])
        return patches


## The MLP-Mixer model

The MLP-Mixer is an architecture based exclusively on
multi-layer perceptrons (MLPs), that contains two types of MLP layers:

1. One applied independently to image patches, which mixes the per-location features.
2. The other applied across patches (along channels), which mixes spatial information.

This is similar to a [depthwise separable convolution based model](https://arxiv.org/pdf/1610.02357.pdf)
such as the Xception model, but with two chained dense transforms, no max pooling, and layer normalization
instead of batch normalization.
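The two kinds of mixing can be sketched in NumPy as follows. This is a simplified illustration under stated assumptions: each MLP is reduced to a single bias-free linear map, and layer normalization, GELU, and dropout are omitted. The weights are random and the names `w_token` and `w_channel` are hypothetical.

```python
import numpy as np

rng = np.random.default_rng(0)
num_patches, embedding_dim = 64, 256


def dense(x, weights):
    """Bias-free linear map over the last axis (activation omitted for brevity)."""
    return x @ weights


# Input layout: [batch, patches, channels].
x = rng.normal(size=(1, num_patches, embedding_dim))

# Mixing across patches (spatial information): transpose so that patches
# are the last axis, mix them, transpose back, and add the skip connection.
w_token = rng.normal(size=(num_patches, num_patches))
token_mixed = np.swapaxes(dense(np.swapaxes(x, 1, 2), w_token), 1, 2) + x

# Mixing per-location features: the map acts on the channel axis of every
# patch independently, again followed by a skip connection.
w_channel = rng.normal(size=(embedding_dim, embedding_dim))
channel_mixed = dense(token_mixed, w_channel) + token_mixed
```

Both steps preserve the `[batch, patches, channels]` shape, which is what allows Mixer blocks to be stacked.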

### Implement the MLP-Mixer module

In [None]:

class MLPMixerLayer(layers.Layer):
    def __init__(self, num_patches, hidden_units, dropout_rate, *args, **kwargs):
        super(MLPMixerLayer, self).__init__(*args, **kwargs)

        self.mlp1 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=num_patches),
                layers.Dropout(rate=dropout_rate),
            ]
        )

        self.mlp2 = keras.Sequential(
            [
                layers.Dense(units=num_patches),
                tfa.layers.GELU(),
                layers.Dense(units=hidden_units),
                layers.Dropout(rate=dropout_rate),
            ]
        )
        self.normalize = layers.LayerNormalization(epsilon=1e-6)

    def call(self, inputs):
        # Apply layer normalization.
        x = self.normalize(inputs)
        # Transpose inputs from [num_batches, num_patches, hidden_units] to [num_batches, hidden_units, num_patches].
        x_channels = tf.linalg.matrix_transpose(x)
        # Apply mlp1 on each channel independently.
        mlp1_outputs = self.mlp1(x_channels)
        # Transpose mlp1_outputs from [num_batches, hidden_units, num_patches] to [num_batches, num_patches, hidden_units].
        mlp1_outputs = tf.linalg.matrix_transpose(mlp1_outputs)
        # Add skip connection.
        x = mlp1_outputs + inputs
        # Apply layer normalization.
        x_patches = self.normalize(x)
        # Apply mlp2 on each patch independently.
        mlp2_outputs = self.mlp2(x_patches)
        # Add skip connection.
        x = x + mlp2_outputs
        return x

### Build, train, and evaluate the MLP-Mixer model

Note that training the model with the current settings on a V100 GPU
takes around 8 seconds per epoch.

In [None]:
mlpmixer_blocks = keras.Sequential(
    [MLPMixerLayer(num_patches, embedding_dim, dropout_rate) for _ in range(num_blocks)]  # One Mixer layer per block.
)

In [None]:
learning_rate = 0.005
mlpmixer_classifier = build_classifier(mlpmixer_blocks) # Returns the model

In [None]:
mlpmixer_classifier.summary()

In [None]:
history = run_experiment(mlpmixer_classifier)

The MLP-Mixer model tends to have far fewer parameters than
convolutional and transformer-based models, which leads to lower training and
serving computational cost.

As mentioned in the [MLP-Mixer](https://arxiv.org/abs/2105.01601) paper,
when pre-trained on large datasets, or with modern regularization schemes,
the MLP-Mixer attains scores competitive with state-of-the-art models.
You can obtain better results by increasing the embedding dimensions,
increasing the number of mixer blocks, and training the model for longer.
You may also try to increase the size of the input images and use different patch sizes.

## Visualization of activations
Here are some useful built-in functions and attributes for inspecting each layer:
- dir()
- type()
- ._name
- .get_input_shape_at(0)

```python
# Block 1
mlpmixer_classifier.layers[4].layers[0].mlp1.layers[0]
mlpmixer_classifier.layers[4].layers[0].mlp1.layers[2]
mlpmixer_classifier.layers[4].layers[0].mlp2.layers[0]
mlpmixer_classifier.layers[4].layers[0].mlp2.layers[2]
# Block 2
mlpmixer_classifier.layers[4].layers[1].mlp1.layers[0]
mlpmixer_classifier.layers[4].layers[1].mlp1.layers[2]
```

In [None]:
def Mixer_Layer_Outputs_base(layer_name,example):
    # Token-mixing MLP (mlp1) of the second Mixer block.
    model = mlpmixer_classifier.layers[4].layers[1].mlp1
    layer_output=model.get_layer(layer_name).output
    intermediate_model=tf.keras.models.Model(inputs=model.input,outputs=layer_output)
    # Slice instead of indexing to keep the batch dimension: [1, 32, 32, 3].
    augmented = data_augmentation(x_train[example:example + 1])
    b = Patches(patch_size, num_patches)(augmented)
    # Reuse the trained patch projection (layers[3]) rather than a new, randomly initialized Dense layer.
    a = mlpmixer_classifier.layers[3](b)
    inp = tf.reshape(a,[1,embedding_dim,num_patches])
    intermediate_prediction=intermediate_model.predict(inp)
    layactivation = intermediate_prediction.reshape((embedding_dim,num_patches))
    visualize_out(layactivation,layer_name,example)
    return layactivation
#test = Mixer_Layer_Outputs_base('dense_4',4)

In [None]:
def visualize_out(result,layer_name,example):
    fig, (ax1, ax2)= plt.subplots(1,2)
    ax1.imshow(x_train[example])
    ax1.set_title('Original_Figure, Class:' + str(y_train[example][0]))
    ax2.imshow(result)
    ax2.set_title('Activations of layer: '+ '"' + layer_name + '"')
    print(np.shape(result))
    print(np.shape(x_train[example]))
    return None

In [None]:
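def Regularization(example):
    # Preprocess one training image and embed its patches.
    # Slice instead of indexing to keep the batch dimension: [1, 32, 32, 3].
    augmented = data_augmentation(x_train[example:example + 1])
    b = Patches(patch_size, num_patches)(augmented)
    # Reuse the trained patch projection (layers[3]) rather than a new, randomly initialized Dense layer.
    a = mlpmixer_classifier.layers[3](b)
    inp = tf.reshape(a,[1,embedding_dim,num_patches])
    return inp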

In [None]:
def Mixer_Activations(example):
    total_activations = list()
    model1 = mlpmixer_classifier.layers[4]
    example = Regularization(example)
    for i in range(num_blocks):
        shape=(embedding_dim,num_patches)
        model = model1.layers[i].mlp1
        int_total_activations = Mixer_Layer_Outputs(model,example,shape)
        total_activations.append(int_total_activations)
        shape=(num_patches,embedding_dim)
        model = model1.layers[i].mlp2
        int_total_activations = Mixer_Layer_Outputs(model,example,shape)
        total_activations.append(int_total_activations)
    return total_activations
#result = Mixer_Activations(2)

In [None]:
def Mixer_Layer_Outputs(model,example,shape):
    intermediate_model=tf.keras.models.Model(inputs=model.input,outputs=model.output)
    example = tf.reshape(example,[1,shape[0],shape[1]])
    intermediate_prediction =intermediate_model.predict(example)
    layactivation = intermediate_prediction.reshape((embedding_dim,num_patches))
    #visualize_out(layactivation,layer_name,example)
    return layactivation

In [None]:
def Heatmap(result,kind,bl):
    # kind selects the similarity index ('kernel' or 'linear'); bl toggles cell annotations.
    dim = len(result)
    heatmap_kernel = np.zeros((dim,dim))
    heatmap_linear = np.zeros((dim,dim))
    axis_labels = list()
    # Pairwise CKA similarity between all collected activations.
    for i in range(dim):
        for j in range(dim):
            heatmap_kernel[i][j] = kernel_CKA(result[i],result[j])
            heatmap_linear[i][j] = linear_CKA(result[i],result[j])
        # Two entries per Mixer layer: B1 is mlp1 and B2 is mlp2.
        axis_labels.append('L%iB%i'%((i//2)+1,(i%2)+1))
    heatmaps = {'kernel': heatmap_kernel, 'linear': heatmap_linear}
    if kind not in heatmaps:
        raise ValueError('No plot type is defined for the selection (2nd argument); use "kernel" or "linear".')
    ax = plt.axes()
    ax.set_xlabel('L = # MixerLayer, B = # Block within the MixerLayer')
    ax.set_title('Similarity Measures - Index: CKA-'+ kind)
    sns.heatmap(heatmaps[kind], xticklabels=axis_labels, yticklabels=axis_labels, ax = ax, annot=bl)
    # Return the matrix that was plotted.
    return heatmaps[kind]

## CKA Kernel

### Setup additional libraries

GitHub1: https://github.com/yuanli2333/CKA-Centered-Kernel-Alignment/blob/master/CKA.ipynb \
GitHub2: https://github.com/jayroxis/CKA-similarity/blob/main/CKA.ipynb \
Paper: https://arxiv.org/pdf/1905.00414.pdf
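For orientation, the linear variant of CKA from the paper above can be sketched in a few lines of NumPy. This is an illustrative sketch and not the code of the imported `CKA` module; the name `linear_cka_sketch` and the random inputs are assumptions. The rows of both inputs are the items being compared and must match in number; the columns are features.

```python
import numpy as np


def linear_cka_sketch(x, y):
    """Linear CKA between two activation matrices with the same number of rows."""
    # Center every feature (column) before comparing the representations.
    x = x - x.mean(axis=0, keepdims=True)
    y = y - y.mean(axis=0, keepdims=True)
    cross = np.linalg.norm(y.T @ x, ord="fro") ** 2
    norm_x = np.linalg.norm(x.T @ x, ord="fro")
    norm_y = np.linalg.norm(y.T @ y, ord="fro")
    return cross / (norm_x * norm_y)


rng = np.random.default_rng(0)
a = rng.normal(size=(256, 64))
b = rng.normal(size=(256, 32))
self_similarity = linear_cka_sketch(a, a)
scaled_similarity = linear_cka_sketch(a, 3.0 * a)
cross_similarity = linear_cka_sketch(a, b)
```

A representation compared with itself, or with a uniformly rescaled copy of itself, scores 1, while unrelated random matrices score lower. This is why the diagonal of the heatmap is expected to be 1.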

In [None]:
plt.imshow(x_train[2])

In [None]:
## Repu_First_Activity
result = Mixer_Activations(2)
plot_heatmap = Heatmap(result,'kernel',False)


In [None]:
# Report: learning curves
def diagnostico(histories):
    for i in range(len(histories)):
        # Plot the loss.
        plt.subplot(2,1,1)
        plt.title("Cross Entropy Loss")
        plt.plot(histories[i].history["loss"], color = 'blue', label = 'Training')
        plt.plot(histories[i].history["val_loss"], color = 'orange', label = 'Validation')
        # Plot the accuracy; the keys follow the metric name "acc" set in run_experiment.
        plt.subplot(2,1,2)
        plt.title('Classification accuracy')
        plt.plot(histories[i].history["acc"], color = 'blue', label = 'Training')
        plt.plot(histories[i].history["val_acc"], color = 'orange', label = 'Validation')
    plt.show()

In [None]:
name_file = "mlpmixer_classifier"  # Assumed base file name; adjust as needed.
mixer_json = mlpmixer_classifier.to_json()
# Save the architecture of the model.
with open(name_file + ".json", 'w') as json_file:
    json_file.write(mixer_json)
# Save the weights in an HDF5 file.
mlpmixer_classifier.save_weights(name_file + ".h5")