<a href="https://colab.research.google.com/github/ashaduzzaman-sarker/Image-Classification/blob/main/Image_classification_with_EANet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Image classification with EANet (External Attention Transformer) on the CIFAR-100 dataset

**Author:** [Ashaduzzaman Piash](
https://github.com/ashaduzzaman-sarker/)
<br>
**Date created:** 19/06/2024
<br>
**Reference:**
[Beyond Self-attention: External Attention using Two Linear Layers for Visual Tasks](
https://doi.org/10.48550/arXiv.2105.02358)

## Introduction

This example demonstrates the implementation of the EANet model for image classification using the CIFAR-100 dataset. The EANet model introduces a novel attention mechanism known as external attention. Unlike traditional self-attention mechanisms, external attention uses two external, small, learnable, and shared memories. These memories are implemented using two cascaded linear layers and two normalization layers, making the mechanism both simple and efficient to implement.

The external attention mechanism provides several advantages:

- **Simplicity**: It can be implemented with minimal modifications to existing architectures.
- **Efficiency**: Its cost grows linearly with the number of patches rather than quadratically as in self-attention. Because the memories are shared across the whole dataset, it also implicitly considers the correlations between all samples.

By leveraging these benefits, EANet can effectively replace self-attention in existing models, offering a more computationally efficient alternative while maintaining or improving performance. This implementation will show how to integrate EANet with a standard image classification pipeline on the CIFAR-100 dataset.

### Key Components of EANet:
1. **External Attention**: Utilizes two small, shared memories to capture attention, leading to linear complexity.
2. **Cascaded Linear Layers**: Two linear layers that facilitate the attention mechanism.
3. **Normalization Layers**: Two normalization layers that ensure stability and improved convergence.
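
The components listed above can be illustrated with a minimal single-head sketch in NumPy. This is not the Keras implementation used later in this notebook: the function name `external_attention_sketch`, the memory size `S = 16`, and the random weights are assumptions chosen only to make the shapes and the two normalization steps visible.

```python
import numpy as np

def softmax(z, axis):
    # Subtract the maximum for numerical stability before exponentiating
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def external_attention_sketch(x, m_k, m_v, eps=1e-9):
    """x: (N, d) tokens; m_k: (d, S) and m_v: (S, d) play the role of the two memories."""
    attn = x @ m_k                                          # (N, S) token-to-memory similarities
    attn = softmax(attn, axis=0)                            # first normalization: over the N tokens
    attn = attn / (eps + attn.sum(axis=1, keepdims=True))   # second normalization: over the S slots
    return attn @ m_v, attn                                 # (N, d) output and (N, S) attention map

rng = np.random.default_rng(0)
N, d, S = 256, 64, 16  # S is a hypothetical memory size, not taken from the notebook
x = rng.normal(size=(N, d))
m_k = rng.normal(size=(d, S)) * 0.1  # stands in for the first linear layer
m_v = rng.normal(size=(S, d)) * 0.1  # stands in for the second linear layer
out, attn = external_attention_sketch(x, m_k, m_v)
print(out.shape, attn.shape)
```

The attention map has shape `(N, S)` instead of `(N, N)`, which is where the linear cost in the number of patches comes from.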


![](https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSRRx7ntV0wgQwDZYFdC39WlZx9wOQCt0EDqA&s)

![](https://user-images.githubusercontent.com/17668390/141291708-7c3cd892-d508-4cca-8306-a8b06a38c158.png)

## Imports

In [None]:
! pip install --upgrade keras

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as pyplot

import keras
from keras import layers
from keras import ops

## Prepare the data

**CIFAR-100 Dataset:**

The CIFAR-100 dataset is a standard image classification benchmark containing 100 classes with 600 images each (500 for training and 100 for testing), for a total of 50,000 training images and 10,000 test images of size 32×32×3. Its small images and large number of classes make it a convenient testbed for EANet.

In [6]:
num_classes = 100
input_shape = (32, 32, 3)

(x_train, y_train), (x_test, y_test) = keras.datasets.cifar100.load_data()

y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

print(f'x_train shape: {x_train.shape} | y_train shape: {y_train.shape}')
print(f'x_test shape: {x_test.shape} | y_test shape: {y_test.shape}')

x_train shape: (50000, 32, 32, 3) | y_train shape: (50000, 100)
x_test shape: (10000, 32, 32, 3) | y_test shape: (10000, 100)


## Configure the hyperparameters

In [7]:
weight_decay = 0.0001
learning_rate = 0.001
label_smoothing = 0.1
validation_split = 0.2
batch_size = 128
num_epochs = 50
patch_size = 2  # Size of the patches to be extracted from input images
num_patches = (input_shape[0] // patch_size) ** 2  # Number of patches per image
embedding_dim = 64
mlp_dim = 64
dim_coefficient = 4
num_heads = 4
attention_dropout = 0.2
projection_dropout = 0.2
num_transformer_blocks = 8  # Number of repetitions of the transformer layer

print(f'Patch size: {patch_size} X {patch_size} = {patch_size ** 2}')
print(f'Patches per image: {num_patches}')

Patch size: 2 X 2 = 4
Patches per image: 256
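
Among the hyperparameters above, `label_smoothing = 0.1` changes the targets the loss sees. The sketch below shows the usual smoothing rule for one-hot labels, in which each target becomes `y * (1 - label_smoothing) + label_smoothing / num_classes`. The helper `smooth_one_hot` and the class index 7 are hypothetical and exist only for illustration.

```python
num_classes = 100
label_smoothing = 0.1

def smooth_one_hot(true_class, num_classes, label_smoothing):
    """Return the smoothed target vector for one example (illustrative helper)."""
    off_value = label_smoothing / num_classes        # mass given to every class
    on_value = 1.0 - label_smoothing + off_value     # remaining mass on the true class
    target = [off_value] * num_classes
    target[true_class] = on_value
    return target

target = smooth_one_hot(true_class=7, num_classes=num_classes, label_smoothing=label_smoothing)
print(f"true class: {target[7]:.3f}, other classes: {target[0]:.3f}, sum: {sum(target):.3f}")
```

With 100 classes the true class keeps about 0.901 of the probability mass and every other class receives 0.001, so the model is discouraged from producing fully saturated predictions.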


## Use data augmentation

In [8]:
data_augmentation = keras.Sequential(
    [
        layers.Normalization(),
        layers.RandomFlip('horizontal'),
        layers.RandomRotation(factor=0.1),
        layers.RandomContrast(factor=0.1),
        layers.RandomZoom(height_factor=0.2, width_factor=0.2),
    ],
    name='data_augmentation',
)

# Compute the mean and the variance of the training data for normalization
data_augmentation.layers[0].adapt(x_train)
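
As a rough illustration of what the `adapt` call computes, the NumPy sketch below derives one mean and one variance per color channel and uses them to standardize a batch. It assumes the layer normalizes over the last axis, which is its default; the random `images` array is a stand-in for `x_train`, and the layer's handling of very small variances is omitted.

```python
import numpy as np

rng = np.random.default_rng(0)
# Stand-in for x_train: a small random batch of 32x32 RGB images with values in [0, 255]
images = rng.integers(0, 256, size=(64, 32, 32, 3)).astype("float64")

# "adapt" step: one mean and one variance per channel, pooled over batch, height and width
mean = images.mean(axis=(0, 1, 2))
var = images.var(axis=(0, 1, 2))

# "call" step: standardize each channel with the statistics computed above
normalized = (images - mean) / np.sqrt(var)
print(normalized.mean(axis=(0, 1, 2)))
print(normalized.var(axis=(0, 1, 2)))
```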

## Implementing the patch extraction and encoding layer

In [16]:
class PatchExtract(layers.Layer):
  def __init__(self, patch_size, **kwargs):
    super().__init__(**kwargs)
    self.patch_size = patch_size

  def call(self, x):
    B, C = ops.shape(x)[0], ops.shape(x)[-1]
    x = ops.image.extract_patches(x, self.patch_size)
    x = ops.reshape(x, (B, -1, self.patch_size * self.patch_size * C))
    return x

class PatchEmbedding(layers.Layer):
  def __init__(self, num_patch, embed_dim, **kwargs):
    super().__init__(**kwargs)
    self.num_patch = num_patch
    self.proj = layers.Dense(embed_dim)
    self.pos_embed = layers.Embedding(input_dim=num_patch, output_dim=embed_dim)

  def call(self, patch):
    pos = ops.arange(start=0, stop=self.num_patch, step=1)
    return self.proj(patch) + self.pos_embed(pos)
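
To make the shapes produced by `PatchExtract` concrete, here is a NumPy sketch of non-overlapping patch extraction. It is an assumption-labeled stand-in for `ops.image.extract_patches`, not that function itself: the helper name `extract_patches_sketch` is hypothetical, and it only covers the case where the stride equals the patch size and the image dimensions are divisible by it.

```python
import numpy as np

def extract_patches_sketch(images, patch_size):
    """Split (B, H, W, C) images into flattened, non-overlapping patches."""
    b, h, w, c = images.shape
    p = patch_size
    # Separate each spatial axis into (number of patches, position inside a patch)
    x = images.reshape(b, h // p, p, w // p, p, c)
    # Bring the two patch-grid axes together: (B, rows, cols, p, p, C)
    x = x.transpose(0, 1, 3, 2, 4, 5)
    # Flatten the grid into a sequence and each patch into a vector
    return x.reshape(b, (h // p) * (w // p), p * p * c)

images = np.arange(2 * 32 * 32 * 3, dtype="float32").reshape(2, 32, 32, 3)
patches = extract_patches_sketch(images, patch_size=2)
print(patches.shape)  # (2, 256, 12)
```

With `patch_size = 2` each 32×32×3 image becomes a sequence of 256 patches of 2 × 2 × 3 = 12 values, which `PatchEmbedding` then projects to `embed_dim` and offsets with a learned position embedding.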


## Implementation of the external attention block

In [11]:
def external_attention(
    x,
    dim,
    num_heads,
    dim_coefficient=4,
    attention_dropout=0,
    projection_dropout=0,
):
    _, num_patch, channel = x.shape
    assert dim % num_heads == 0
    num_heads = num_heads * dim_coefficient

    x = layers.Dense(dim * dim_coefficient)(x)
    # Create tensor [batch_size, num_heads, num_patches, dim*dim_coefficient//num_heads]
    # (reshape to split the heads, then transpose to move the head axis forward)
    x = ops.reshape(x, (-1, num_patch, num_heads, dim * dim_coefficient // num_heads))
    x = ops.transpose(x, axes=[0, 2, 1, 3])
    # A linear layer M_k
    attn = layers.Dense(dim // dim_coefficient)(x)
    # Normalize attention map
    attn = layers.Softmax(axis=2)(attn)
    # Double-normalization
    attn = layers.Lambda(
        lambda attn: ops.divide(
            attn,
            ops.convert_to_tensor(1e-9) + ops.sum(attn, axis=-1, keepdims=True),
        )
    )(attn)
    attn = layers.Dropout(attention_dropout)(attn)
    # A linear layer M_v
    x = layers.Dense(dim * dim_coefficient // num_heads)(attn)
    x = ops.transpose(x, axes=[0, 2, 1, 3])
    x = ops.reshape(x, [-1, num_patch, dim * dim_coefficient])
    # A linear layer to project original dim
    x = layers.Dense(dim)(x)
    x = layers.Dropout(projection_dropout)(x)
    return x
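
The reshapes in `external_attention` are easier to follow with concrete numbers. The plain-Python sketch below traces the tensor shapes through the function for this notebook's hyperparameters; the helper `trace_external_attention_shapes` is hypothetical, builds no layers, and uses `"B"` as a placeholder for the batch size.

```python
def trace_external_attention_shapes(num_patch, dim, num_heads, dim_coefficient):
    """Return (step, shape) pairs mirroring the reshapes in external_attention."""
    assert dim % num_heads == 0
    heads = num_heads * dim_coefficient   # the function multiplies the head count
    widened = dim * dim_coefficient       # width after the first Dense layer
    per_head = widened // heads           # channels handled by each head
    memory = dim // dim_coefficient       # output width of the M_k layer
    return [
        ("input", ("B", num_patch, dim)),
        ("widen", ("B", num_patch, widened)),
        ("split heads", ("B", heads, num_patch, per_head)),
        ("M_k + normalization", ("B", heads, num_patch, memory)),
        ("M_v", ("B", heads, num_patch, per_head)),
        ("merge heads", ("B", num_patch, widened)),
        ("project", ("B", num_patch, dim)),
    ]

shapes = trace_external_attention_shapes(num_patch=256, dim=64, num_heads=4, dim_coefficient=4)
for step, shape in shapes:
    print(f"{step:>20}: {shape}")
```

With `dim = 64`, `num_heads = 4` and `dim_coefficient = 4`, the block works with 16 heads of 16 channels each, and the attention map inside each head has shape `(256, 16)`.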

## Implement the MLP block

In [18]:
def mlp(x, embedding_dim, mlp_dim, drop_rate=0.2):
  x = layers.Dense(mlp_dim, activation=ops.gelu)(x)
  x = layers.Dropout(drop_rate)(x)
  x = layers.Dense(embedding_dim)(x)
  x = layers.Dropout(drop_rate)(x)
  return x

## Implement the Transformer block

In [13]:
def transformer_encoder(
    x,
    embedding_dim,
    mlp_dim,
    num_heads,
    dim_coefficient,
    attention_dropout,
    projection_dropout,
    attention_type='external_attention',
):
    residual_1 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    if attention_type == 'external_attention':
      x = external_attention(
          x,
          embedding_dim,
          num_heads,
          dim_coefficient,
          attention_dropout,
          projection_dropout,
      )
    elif attention_type == 'self_attention':
      x = layers.MultiHeadAttention(
          num_heads=num_heads,
          key_dim=embedding_dim,
          dropout=attention_dropout,
      )(x, x)
    x = layers.add([x, residual_1])
    residual_2 = x
    x = layers.LayerNormalization(epsilon=1e-5)(x)
    x = mlp(x, embedding_dim, mlp_dim)
    x = layers.add([x, residual_2])
    return x

## Implementation of the EANet model

- **EANet Model**: Utilizes external attention instead of traditional self-attention.
- **Traditional Self-Attention Complexity**: O(d * N ** 2), where d is the embedding size and N is the number of patches.
- **Redundancy in Self-Attention**: Most pixels are closely related to a few others, making an N-to-N attention matrix redundant.
- **External Attention Module**: Proposed to address this redundancy.
- **External Attention Complexity**: O(d * S * N), where d and S are hyper-parameters.
- **Linear Complexity**: The algorithm is linear in the number of pixels, as d and S are constants.
- **Equivalent to Drop Patch Operation**: Redundant and unimportant information in image patches is effectively ignored.
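
The two complexity expressions above can be compared directly. The sketch below counts only the dominant multiply-accumulate term of each attention map and ignores constant factors; the value `S = 64` is an assumed memory size used purely for illustration.

```python
def self_attention_cost(n, d):
    """Dominant cost term of self-attention: O(d * N ** 2)."""
    return d * n * n

def external_attention_cost(n, d, s):
    """Dominant cost term of external attention: O(d * S * N)."""
    return d * s * n

d, s = 64, 64  # s is a hypothetical memory size
for n in (256, 1024, 4096):
    sa = self_attention_cost(n, d)
    ea = external_attention_cost(n, d, s)
    print(f"N={n:>5}: self-attention={sa:>13,}  external={ea:>11,}  ratio={sa / ea:.0f}x")
```

The ratio between the two costs is `N / S`, so the gap widens as the number of patches grows while `d` and `S` stay fixed.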


In [14]:
def get_model(attention_type='external_attention'):
  inputs = layers.Input(shape=input_shape)
  # Image augment
  x = data_augmentation(inputs)
  # Extract patches
  x = PatchExtract(patch_size)(x)
  # Create patch embedding
  x = PatchEmbedding(num_patches, embedding_dim)(x)
  # Create Transformer block
  for _ in range(num_transformer_blocks):
    x = transformer_encoder(
        x,
        embedding_dim,
        mlp_dim,
        num_heads,
        dim_coefficient,
        attention_dropout,
        projection_dropout,
        attention_type,
    )

  x = layers.GlobalAveragePooling1D()(x)
  outputs = layers.Dense(num_classes, activation='softmax')(x)
  model = keras.Model(inputs=inputs, outputs=outputs)
  return model

## Train on CIFAR-100 Dataset

In [19]:
model = get_model(attention_type='external_attention')

model.compile(
    loss=keras.losses.CategoricalCrossentropy(label_smoothing=label_smoothing),
    optimizer=keras.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay
    ),
    metrics=[
        keras.metrics.CategoricalAccuracy(name='accuracy'),
        keras.metrics.TopKCategoricalAccuracy(5, name='top-5-accuracy'),
    ]
)

history = model.fit(
    x_train,
    y_train,
    batch_size=batch_size,
    epochs=num_epochs,
    validation_split=validation_split,
)

Epoch 1/50
313/313 ━━━━━━━━━━━━━━━━━━━━ 113s 254ms/step - accuracy: 0.0376 - loss: 4.4732 - top-5-accuracy: 0.1417 - val_accuracy: 0.0569 - val_loss: 4.8607 - val_top-5-accuracy: 0.1957
Epoch 2/50
313/313 ━━━━━━━━━━━━━━━━━━━━ 139s 265ms/step - accuracy: 0.0932 - loss: 4.0584 - top-5-accuracy: 0.2894 - val_accuracy: 0.0766 - val_loss: 4.9957 - val_top-5-accuracy: 0.2509
Epoch 3/50
313/313 ━━━━━━━━━━━━━━━━━━━━ 142s 265ms/step - accuracy: 0.1284 - loss: 3.8808 - top-5-accuracy: 0.3546 - val_accuracy: 0.0841 - val_loss: 5.5828 - val_top-5-accuracy: 0.2561
Epoch 4/50
313/313 ━━━━━━━━━━━━━━━━━━━━ 142s 265ms/step - accuracy: 0.1616 - loss: 3.7415 - top-5-accuracy: 0.4125 - val_accuracy: 0.0870 - val_loss: 5.6483 - val_top-5-accuracy: 0.2623
Epoch 5/50
313/313 ━━━━━━━━━━━━━━━━━━━━ 83s 264ms/step - accuracy: 0.1829 - loss: 3.64

## Evaluate the model on the CIFAR-100 test set

In [20]:
loss, accuracy, top_5_accuracy = model.evaluate(x_test, y_test)
print(f'Test loss: {round(loss, 2)}')
print(f'Test accuracy: {round(accuracy * 100, 2)}%')
print(f'Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%')

313/313 ━━━━━━━━━━━━━━━━━━━━ 6s 20ms/step - accuracy: 0.1925 - loss: 4.9667 - top-5-accuracy: 0.4333
Test loss: 5.02
Test accuracy: 18.63%
Test top 5 accuracy: 42.44%
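
The top-1 and top-5 figures reported above differ only in how many of the highest-scoring classes count as a hit. The NumPy sketch below illustrates that idea on a toy problem with four classes; the helper `top_k_accuracy` and the example scores are hypothetical, and tie-breaking may differ from the Keras metric.

```python
import numpy as np

def top_k_accuracy(y_true, y_pred, k):
    """Fraction of rows whose true class is among the k highest-scoring predictions."""
    true_idx = y_true.argmax(axis=1)              # class index from one-hot labels
    top_k = np.argsort(y_pred, axis=1)[:, -k:]    # indices of the k largest scores per row
    hits = (top_k == true_idx[:, None]).any(axis=1)
    return hits.mean()

y_true = np.eye(4)[[0, 1, 2, 3]]  # one example per class, one-hot encoded
y_pred = np.array([
    [0.7, 0.1, 0.1, 0.1],  # true class 0 ranked first
    [0.4, 0.3, 0.2, 0.1],  # true class 1 ranked second
    [0.4, 0.3, 0.1, 0.2],  # true class 2 ranked last
    [0.1, 0.2, 0.3, 0.4],  # true class 3 ranked first
])
print(top_k_accuracy(y_true, y_pred, k=1))
print(top_k_accuracy(y_true, y_pred, k=2))
```

Raising `k` can only keep or increase the score, which is why the top-5 accuracy is always at least as high as the top-1 accuracy.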


## Results

- **Replacement in ViT**: EANet replaces self-attention in Vision Transformer (ViT) with external attention.

- **ViT Performance**:
  - Test Top-5 Accuracy: ~73%
  - Test Top-1 Accuracy: ~41%
  - Parameters: 0.6M
  - Training Epochs: 50

- **EANet Performance**:
  - Test Top-5 Accuracy: ~73%
  - Test Top-1 Accuracy: ~43%
  - Parameters: 0.3M
  - Training Epochs: 50

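- **Effectiveness**: In this comparison EANet reaches similar or slightly better accuracy with half the parameters, which suggests that external attention is an efficient substitute for self-attention. Note that the run logged in this notebook reached only 18.63% top-1 and 42.44% top-5 test accuracy, and its training log is cut off at epoch 5, so the figures listed above are not reproduced by the output shown here.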