# Implementation and fine-tuning of GCViT: Global Context Vision Transformer for image classification.

**Author:** [Ashaduzzaman Sarker](https://github.com/ashaduzzaman-sarker/)
<br>
**Date created:** 01/07/2024
**Reference:**

 - [Global Context Vision Transformers](https://doi.org/10.48550/arXiv.2206.09959)

 - [Keras](https://keras.io/examples/vision/image_classification_using_global_context_vision_transformer/)

## Introduction

- This notebook will implement the GCViT (Global Context Vision Transformer) paper presented at ICML 2023 by A Hatamizadeh et al. using multi-backend Keras 3.0.
- We will fine-tune the model on the Flower dataset for an image classification task, utilizing official ImageNet pre-trained weights.
- A key feature of this notebook is its compatibility with multiple backends: TensorFlow, PyTorch, and JAX, highlighting the true potential of multi-backend Keras.

## Setup

In [1]:
!pip install --upgrade keras_cv tensorflow
!pip install --upgrade keras


In [2]:
import keras
from keras_cv.layers import DropPath
from keras import ops
from keras import layers

import tensorflow as tf
import tensorflow_datasets as tfds

from skimage.data import chelsea
import matplotlib.pyplot as plt
import numpy as np

## Motivation

- **Note**: In this section, we will explore the background of GCViT and understand the rationale behind its proposal.
- **Transformers in NLP**: Recently, Transformers have become dominant in Natural Language Processing (NLP) tasks due to their self-attention mechanism, which captures both long and short-range information.
- **Vision Transformer (ViT)**: Inspired by this trend, Vision Transformer (ViT) proposed using image patches as tokens in a large architecture similar to the original Transformer's encoder.
- **ViT vs. CNN**: Despite the historical dominance of Convolutional Neural Networks (CNN) in computer vision, ViT-based models have demonstrated state-of-the-art (SOTA) or competitive performance in various computer vision tasks.
<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/vit_gif.gif)

<br>

- The quadratic computational complexity of self-attention, O(n^2) in the number of tokens n, and the lack of multi-scale information make ViT less suitable as a general-purpose architecture for computer vision tasks such as segmentation and object detection, which require dense pixel-level predictions.
- The Swin Transformer addresses some of ViT's issues by introducing multi-resolution/hierarchical architectures where self-attention is computed in local windows, and cross-window connections like window shifting model interactions across regions.
- However, the limited receptive field of local windows in the Swin Transformer fails to capture long-range information, and cross-window connection schemes like window shifting only cover small neighborhoods near each window.
- Additionally, the Swin Transformer lacks the inductive bias that encourages translation invariance. This property is desirable for general-purpose visual modeling, especially for dense prediction tasks like object detection and semantic segmentation.
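
To make the cost argument concrete, the sketch below counts the query-key pairs scored by global self-attention and by window-restricted self-attention. The 56x56 token grid and the 7x7 window are illustrative assumptions, not values stated above, and the pair count is only a rough proxy for compute.

```python
# Count query-key pairs scored by attention (a rough proxy for its cost).
# Assumed, illustrative sizes: 56x56 token grid, 7x7 local windows.

def global_attention_pairs(height, width):
    """Every token attends to every token: n^2 pairs."""
    n = height * width
    return n * n

def window_attention_pairs(height, width, window):
    """Tokens attend only to tokens inside their own window."""
    num_windows = (height // window) * (width // window)
    tokens_per_window = window * window
    return num_windows * tokens_per_window ** 2

full = global_attention_pairs(56, 56)
local = window_attention_pairs(56, 56, 7)
print(full, local, full // local)  # 9834496 153664 64
```

Under these assumptions windowing is 64 times cheaper, but each token only sees the 49 tokens of its own window, which is the limited receptive field discussed above.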

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/swin_vs_vit.JPG)

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/shifted_window.JPG)

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/swin_arch.JPG)

<br>

- To address the above limitations, the **Global Context (GC) ViT** network was proposed.

## Architecture

Let's have a quick overview of our key components:

1. **Stem/PatchEmbed**:
    - Processes images at the network’s beginning.
    - Creates patches/tokens and converts them into embeddings.

2. **Level**:
    - Repetitive building block that extracts features using different blocks.

3. **Global Token Generation/Feature Extraction**:
    - Generates global tokens/patches using Depthwise-CNN, SqueezeAndExcitation (Squeeze-Excitation), CNN, and MaxPooling.
    - Essentially acts as a feature extractor.

4. **Block**:
    - Repetitive module that applies attention to the features and projects them to a certain dimension.
        1. **Local-MSA**: Local Multi-Head Self-Attention.
        2. **Global-MSA**: Global Multi-Head Self-Attention.
        3. **MLP**: Linear layer that projects a vector to another dimension.

5. **Downsample/ReduceSize**:
    - Similar to the Global Token Generation module but uses CNN instead of MaxPooling to downsample, with additional Layer Normalization modules.

6. **Head**:
    - Responsible for the classification task.
        1. **Pooling**: Converts N x 2D features to N x 1D features.
        2. **Classifier**: Processes N x 1D features to make a decision about class.

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/arch_annot.png)



### Unit Blocks

**1. SqueezeAndExcitation:**

- **Squeeze-Excitation (SE)**, also known as the Bottleneck module, functions as a form of channel attention.
- It consists of the following components:
    - **AvgPooling**: Averages the spatial dimensions of the input.
    - **Dense/FullyConnected (FC)/Linear**: Applies a fully connected layer to the pooled output.
    - **GELU**: Uses the Gaussian Error Linear Unit activation function.
    - **Sigmoid**: Applies the sigmoid activation function to produce the final output.
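
The components listed above can be sketched in plain NumPy to show how the data flows. This is a minimal illustration, not the notebook's Keras layer: the weight names `w_reduce` and `w_expand`, the random inputs, and the tanh approximation of GELU are all assumptions made for the example.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (an assumption; exact GELU uses erf)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def squeeze_excite(x, w_reduce, w_expand):
    """x: (B, H, W, C); w_reduce: (C, C_r); w_expand: (C_r, C)."""
    s = x.mean(axis=(1, 2), keepdims=True)  # squeeze: (B, 1, 1, C)
    s = gelu(s @ w_reduce)                  # bottleneck: (B, 1, 1, C_r)
    s = sigmoid(s @ w_expand)               # gates in (0, 1): (B, 1, 1, C)
    return x * s                            # rescale every channel

rng = np.random.default_rng(0)
x = rng.normal(size=(2, 8, 8, 16))
w_reduce = rng.normal(size=(16, 4)) * 0.1   # hypothetical weights, ratio 0.25
w_expand = rng.normal(size=(4, 16)) * 0.1
y = squeeze_excite(x, w_reduce, w_expand)
print(y.shape)  # (2, 8, 8, 16)
```

Because the gates lie strictly between 0 and 1, the module can only attenuate channels relative to each other, which is why it acts as a form of channel attention.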

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/se_annot.png)

<br>

**2. Fused-MBConv:**

- **Fused-MBConv**: Similar to the one used in EfficientNetV2.
    - Utilizes **Depthwise-Conv** for depthwise convolution.
    - **GELU** for activation.
    - **SqueezeAndExcitation** for channel attention.
    - **Conv** for regular convolution.
    - Includes a **residual connection** to retain input information.
- Note: No new modules are declared for this; the corresponding existing modules are applied directly.
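
As a rough illustration of how these pieces chain together, here is a NumPy sketch of the sequence depthwise conv, GELU, squeeze-excitation, 1x1 conv, and residual. All weights are random placeholders, the helper names are hypothetical, and GELU uses the tanh approximation, so treat this as a sketch of the data flow rather than the notebook's implementation.

```python
import numpy as np

def gelu(x):
    # tanh approximation of GELU (an assumption; exact GELU uses erf)
    return 0.5 * x * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * x**3)))

def depthwise_conv3x3(x, kernel):
    """x: (B, H, W, C); kernel: (3, 3, C). Zero padding keeps H and W."""
    _, h, w, _ = x.shape
    xp = np.pad(x, ((0, 0), (1, 1), (1, 1), (0, 0)))
    out = np.zeros_like(x)
    for di in range(3):
        for dj in range(3):
            # each channel is filtered independently of the others
            out += xp[:, di:di + h, dj:dj + w, :] * kernel[di, dj]
    return out

def squeeze_excite(x, w_reduce, w_expand):
    s = gelu(x.mean(axis=(1, 2), keepdims=True) @ w_reduce)
    return x * (1.0 / (1.0 + np.exp(-(s @ w_expand))))

def fused_mbconv(x, dw_kernel, w_reduce, w_expand, w_point):
    y = depthwise_conv3x3(x, dw_kernel)
    y = gelu(y)
    y = squeeze_excite(y, w_reduce, w_expand)
    y = y @ w_point   # 1x1 convolution == per-pixel matmul over channels
    return x + y      # residual connection keeps the input information

rng = np.random.default_rng(0)
c = 8
x = rng.normal(size=(1, 6, 6, c))
y = fused_mbconv(
    x,
    dw_kernel=rng.normal(size=(3, 3, c)) * 0.1,
    w_reduce=rng.normal(size=(c, c // 4)) * 0.1,
    w_expand=rng.normal(size=(c // 4, c)) * 0.1,
    w_point=rng.normal(size=(c, c)) * 0.1,
)
print(y.shape)  # (1, 6, 6, 8)
```

The spatial size and channel count are unchanged, which is what allows the residual addition at the end.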

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/fmb_annot.png)

<br>

**3. ReduceSize:**

- **ReduceSize**: A CNN-based downsample module, referred to as the downsample module in the paper/figure.
    - **Fused-MBConv**: Extracts features.
    - **Strided Conv**: Simultaneously reduces spatial dimensions and increases channel-wise dimensions of the features.
    - **LayerNormalization**: Normalizes features.
- Note: SwinTransformer instead uses the PatchMerging module, which relies on fully-connected (dense/linear) layers rather than convolutions.
- According to the GCViT paper, the purpose of using ReduceSize is to introduce inductive bias through the CNN module.
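
The shape change produced by the strided convolution follows from standard convolution arithmetic. The sketch below assumes one pixel of zero padding, a 3x3 kernel, and stride 2, which are the values used by the `ReduceSize` code in this notebook. The 56x56x64 input and the helper names are illustrative assumptions.

```python
def conv_output_size(size, kernel, stride, padding):
    """Spatial output size of a convolution with symmetric zero padding."""
    return (size + 2 * padding - kernel) // stride + 1

def reduce_size_shape(height, width, channels, keepdims=False):
    """Output shape after the strided 3x3 convolution (pad 1, stride 2)."""
    out_h = conv_output_size(height, kernel=3, stride=2, padding=1)
    out_w = conv_output_size(width, kernel=3, stride=2, padding=1)
    out_c = channels if keepdims else 2 * channels
    return out_h, out_w, out_c

print(reduce_size_shape(56, 56, 64))                 # (28, 28, 128)
print(reduce_size_shape(56, 56, 64, keepdims=True))  # (28, 28, 64)
```

Note that the spatial size is halved in both cases; `keepdims` only decides whether the channel count is doubled.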

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/down_annot.png)

<br>

**4. MLP**

- **MLP (Multi-Layer Perceptron)**:
    - A feed-forward/fully-connected/linear module.
    - Projects input to an arbitrary dimension.



In [3]:
# Squeeze and excitation block
class SqueezeAndExcitation(layers.Layer):
  '''
  Args:
      output_dim: output features dimension, if 'None' use same dim as input
      expansion: expansion ratio
  '''

  def __init__(self, output_dim=None, expansion=0.25, **kwargs):
    super().__init__(**kwargs)
    self.expansion = expansion
    self.output_dim = output_dim

  def build(self, input_shape):
    inp = input_shape[-1]
    self.output_dim = self.output_dim or inp
    self.avg_pool = layers.GlobalAveragePooling2D(keepdims=True, name='avg_pool')
    self.fc = [
        layers.Dense(int(inp * self.expansion), use_bias=False, name='fc_0'),
        layers.Activation('gelu', name='fc_1'),
        layers.Dense(self.output_dim, use_bias=False, name='fc_2'),
        layers.Activation('sigmoid', name='fc_3'),
    ]
    super().build(input_shape)

  def call(self, inputs, **kwargs):
    x = self.avg_pool(inputs)
    for layer in self.fc:
      x = layer(x)
    return x * inputs

# Down-sampling block
class ReduceSize(layers.Layer):
  '''
  Args:
      keepdims: if False, the channel dim is doubled; if True, it is kept.
        The spatial dim is halved in both cases by the stride-2 convolution.
  '''

  def __init__(self, keepdims=False, **kwargs):
    super().__init__(**kwargs)
    self.keepdims = keepdims

  def build(self, input_shape):
    embed_dim = input_shape[-1]
    dim_out = embed_dim if self.keepdims else 2 * embed_dim
    # Attribute names must match the `self.pad1` / `self.pad2` used in call()
    self.pad1 = layers.ZeroPadding2D(1, name='pad1')
    self.pad2 = layers.ZeroPadding2D(1, name='pad2')
    self.conv = [
        layers.DepthwiseConv2D(
            kernel_size=3, strides=1, padding='valid', use_bias=False, name='conv_0'
        ),
        layers.Activation('gelu', name='conv_1'),
        SqueezeAndExcitation(name='conv_2'),
        layers.Conv2D(
            embed_dim,
            kernel_size=1,
            strides=1,
            padding='valid',
            use_bias=False,
            name='conv_3',
        ),
    ]
    self.reduction = layers.Conv2D(
        dim_out,
        kernel_size=3,
        strides=2,
        padding='valid',
        use_bias=False,
        name='reduction',
    )
    self.norm1 = layers.LayerNormalization(-1, 1e-05, name='norm1')
    self.norm2 = layers.LayerNormalization(-1, 1e-05, name='norm2')

  def call(self, inputs, **kwargs):
    x = self.norm1(inputs)
    xr = self.pad1(x)
    for layer in self.conv:
      xr = layer(xr)
    x = x + xr
    x = self.pad2(x)
    x = self.reduction(x)
    x = self.norm2(x)
    return x

# Multi-Layer Perceptron (MLP) block
class MLP(layers.Layer):
  '''
  Args:
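      hidden_features: hidden layer dimension; if None, use the input dimension
      out_features: output dimension; if None, use the input dimension
      activation: activation applied after the first dense layer
      dropout: dropout rate applied after the activation and after the output layer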
  '''
  def __init__(
      self,
      hidden_features=None,
      out_features=None,
      activation='gelu',
      dropout=0.0,
      **kwargs,
  ):
    super().__init__(**kwargs)
    self.hidden_features = hidden_features
    self.out_features = out_features
    self.activation = activation
    self.dropout = dropout

  def build(self, input_shape):
    self.in_features = input_shape[-1]
    self.hidden_features = self.hidden_features or self.in_features
    self.out_features = self.out_features or self.in_features
    self.fc1 = layers.Dense(self.hidden_features, name='fc1')
    self.act = layers.Activation(self.activation, name='act')
    self.fc2 = layers.Dense(self.out_features, name='fc2')
    self.drop1 = layers.Dropout(self.dropout, name='drop1')
    self.drop2 = layers.Dropout(self.dropout, name='drop2')

  def call(self, inputs, **kwargs):
    x = self.fc1(inputs)
    x = self.act(x)
    x = self.drop1(x)
    x = self.fc2(x)
    x = self.drop2(x)
    return x

### Stem

**Notes**: In the code, this module is referred to as PatchEmbed, but in the paper, it is called Stem.

- **PatchEmbed Module**:
    - **Padding**: The module first pads the input.
    - **Convolutions**: Uses convolutions to extract patches with embeddings.
    - **ReduceSize Module**: Uses this module with `keepdims=True` to extract features with convolution while leaving the channel dimension unchanged; its stride-2 convolution still halves the spatial dimension.
    - **Overlapping Patches**: Unlike ViT or SwinTransformer, GCViT creates overlapping patches. This is indicated by `Conv2D(self.embed_dim, kernel_size=3, strides=2, name='proj')`. Non-overlapping patches would have used the same kernel_size and stride.
    - **Spatial Dimension Reduction**: This module reduces the spatial dimension of the input by 4x.

**Summary**:

image → padding → convolution → (feature extraction + downsample)
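
The overlap claim and the 4x reduction can be checked with a little index arithmetic. This is a one-dimensional sketch with a hypothetical helper `patch_spans`; the 8-pixel row and the 224-pixel input size are assumptions chosen for illustration.

```python
def patch_spans(size, kernel, stride, padding):
    """Input index range [start, end] seen by each output position (1D)."""
    padded = size + 2 * padding
    spans = []
    for start in range(0, padded - kernel + 1, stride):
        # shift back by `padding` to express the span in unpadded coordinates
        spans.append((start - padding, start - padding + kernel - 1))
    return spans

# kernel 3, stride 2, pad 1 (as in `proj`): neighbouring patches share a pixel.
# Index -1 refers to the zero-padding pixel.
overlapping = patch_spans(8, kernel=3, stride=2, padding=1)
# kernel == stride: patches tile the input without overlap.
disjoint = patch_spans(8, kernel=2, stride=2, padding=0)
print(overlapping)  # [(-1, 1), (1, 3), (3, 5), (5, 7)]
print(disjoint)     # [(0, 1), (2, 3), (4, 5), (6, 7)]

# Two stride-2 stages (proj, then ReduceSize) give the 4x reduction.
after_proj = len(patch_spans(224, kernel=3, stride=2, padding=1))
after_reduce = len(patch_spans(after_proj, kernel=3, stride=2, padding=1))
print(after_proj, after_reduce)  # 112 56
```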

In [4]:
# Patch embedding block
class PatchEmbed(layers.Layer):
  '''
  Args:
      embed_dim: output features dimension
  '''

  def __init__(self, embed_dim, **kwargs):
    super().__init__(**kwargs)
    self.embed_dim = embed_dim

  def build(self, input_shape):
    self.pad = layers.ZeroPadding2D(1, name='pad')
    self.proj = layers.Conv2D(self.embed_dim, 3, 2, name='proj')
    self.conv_down = ReduceSize(keepdims=True, name='conv_down')

  def call(self, inputs, **kwargs):
    x = self.pad(inputs)
    x = self.proj(x)
    x = self.conv_down(x)
    return x

### Global Token Gen.

**Notes**: This is one of the two CNN modules used to impose inductive bias.

- **Global Token Gen./Feature Extraction**:
    - **Purpose**: In the level, this module is used to convert the input into global tokens for global-context-attention.
    - **Repetition**: According to the paper, this module should be repeated K times, where \( K = \log_2(\frac{H}{h}) \). Here, \( H \) and \( W \) are the height and width of the feature map, and \( h \) and \( w \) are the reduced dimensions.
    - **FeatureExtraction**: Similar to the ReduceSize module but with key differences:
        - **MaxPooling**: Used to reduce the spatial dimensions.
        - **No Channel Increase**: Does not increase the feature dimension (channel-wise).
        - **No LayerNormalization**: Does not use LayerNormalization.
    - **Global Tokens**: Shared across the entire image, using only one global window for all local tokens in an image, making computation efficient.
    - **Shape Transformation**:
        - For input feature map with shape (B, H, W, C), the output shape will be (B, h, w, C).
        - If global tokens are copied for a total of M local windows in an image, where \( M = \frac{H \times W}{h \times w} \) (num_window), the output shape will be (B * M, h, w, C).

**Summary**: This module resizes the image to fit the window, creating global tokens for efficient computation.
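
The two formulas above are easy to evaluate for concrete sizes. The sketch below assumes square feature maps of side 56, 28, 14, and 7 and a fixed 7x7 window purely for illustration; actual configurations may use different window sizes per level, and the helper names are hypothetical.

```python
import math

def num_extraction_steps(feature_size, window_size):
    """K = log2(H / h): stride-2 reductions needed to shrink H down to h."""
    ratio = feature_size // window_size
    return int(math.log2(ratio))

def num_windows(height, width, window_h, window_w):
    """M = (H * W) / (h * w): local windows that share the global tokens."""
    return (height * width) // (window_h * window_w)

def global_token_shapes(batch, height, width, channels, window_h, window_w):
    """Shapes before and after copying global tokens to every window."""
    m = num_windows(height, width, window_h, window_w)
    return (batch, window_h, window_w, channels), (batch * m, window_h, window_w, channels)

for size in (56, 28, 14, 7):
    k = num_extraction_steps(size, 7)
    m = num_windows(size, size, 7, 7)
    print(f"H=W={size}: K={k}, M={m}")

print(global_token_shapes(2, 56, 56, 64, 7, 7))
```

For the 56x56 case this gives K = 3 and M = 64, so a batch of 2 images yields global tokens of shape (2, 7, 7, 64) that expand to (128, 7, 7, 64) once copied per window.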

<br>

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/global_token_annot.png)



In [5]:
# Feature Extraction Block
class FeatureExtraction(layers.Layer):
  '''
  Args:
      keepdims: if False, max pooling halves the spatial resolution;
        if True, the resolution is kept
  '''
  def __init__(self, keepdims=False, **kwargs):
    super().__init__(**kwargs)
    self.keepdims = keepdims

  def build(self, input_shape):
    embed_dim = input_shape[-1]
    self.pad1 = layers.ZeroPadding2D(1, name='pad1')
    self.pad2 = layers.ZeroPadding2D(1, name='pad2')
    self.conv = [
        layers.DepthwiseConv2D(3, 1, use_bias=False, name='conv_0'),
        layers.Activation('gelu', name='conv_1'),
        SqueezeAndExcitation(name='conv_2'),
        layers.Conv2D(embed_dim, 1, 1, use_bias=False, name='conv_3'),
    ]
    if not self.keepdims:
      self.pool = layers.MaxPool2D(3, 2, name='pool')
    super().build(input_shape)

  def call(self, inputs, **kwargs):
    x = inputs
    xr = self.pad1(x)
    for layer in self.conv:
      xr = layer(xr)
    x = x + xr
    if not self.keepdims:
      x = self.pool(self.pad2(x))
    return x

# Global query generator
class GlobalQueryGenerator(layers.Layer):
  '''
  Args:
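    keepdims: sequence of bools, one per FeatureExtraction layer; each entry
      is passed to that layer as its `keepdims` argument. For instance,
      going from a 56x56 feature map to a 7x7 window at down-sampling
      ratio 2 takes log2(56/7) = 3 layers with keepdims=False.
      The default `False` is not iterable, so a sequence must be passed.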
  '''

  def __init__(self, keepdims=False, **kwargs):
    super().__init__(**kwargs)
    self.keepdims = keepdims

  def build(self, input_shape):
    self.to_q_global = [
        FeatureExtraction(keepdims, name=f'to_q_global_{i}')
        for i, keepdims in enumerate(self.keepdims)
    ]
    super().build(input_shape)

  def call(self, inputs, **kwargs):
    x = inputs
    for layer in self.to_q_global:
      x = layer(x)
    return x




### Attention

**Notes**: This is the core contribution of the paper.

- **WindowAttention Module**:
    - **Local and Global Attention**: Applies either local or global window attention depending on the `global_query` parameter.
    - **Query, Key, Value Creation**:
        1. Converts input features into query, key, and value for local attention.
        2. Converts input features into key and value for global attention.
        3. Global query is taken from the Global Token Gen.

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_msa.PNG)

  - **Computation Reduction**: Features (embed_dim) are divided among all heads of the Transformer to reduce computation.
       
  - **Global Token Process**:
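      - The same global tokens are reused by every local window, so they are computed once per image and then copied for each window.
      - Here, `B_ // B` is the number of windows per image, where `B_` is the batch size after window partitioning and `B` is the image batch size.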
  - **Attention Application**: Applies either local-window-self-attention or global-window-attention based on the `global_query` parameter.

![](https://raw.githubusercontent.com/awsaf49/gcvit-tf/main/image/lvg_arch.PNG)

  - **Relative Positional Embedding**: Adds a relative positional embedding to the attention scores rather than to the patch embeddings.
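
The global token process and the split across heads amount to a repeat followed by a reshape. The NumPy sketch below uses illustrative sizes (2 images, 4 windows per image, 49 tokens, 32 channels, 4 heads); the variable names are assumptions chosen to mirror the `B_ // B` notation above.

```python
import numpy as np

batch, windows_per_image = 2, 4            # B and M (illustrative values)
tokens, embed_dim, num_heads = 49, 32, 4   # 7x7 window tokens, C, heads
head_dim = embed_dim // num_heads

# One set of global query tokens per image: (B, N, C)
q_global = np.random.default_rng(0).normal(size=(batch, tokens, embed_dim))

# Windowed batch size B_ = B * M, so B_ // B recovers the window count.
windowed_batch = batch * windows_per_image
repeats = windowed_batch // batch

# Copy each image's global tokens once per window: (B_, N, C).
q = np.repeat(q_global, repeats, axis=0)

# Split channels across heads: (B_, heads, N, head_dim).
q = q.reshape(windowed_batch, tokens, num_heads, head_dim).transpose(0, 2, 1, 3)
print(q.shape)  # (8, 4, 49, 8)
```

`np.repeat` places the copies for one image next to each other, which matches a window layout where all windows of the first image come before those of the second.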
        

**Explanation**:
- **Local Attention**: The query is local, limited to the local window (red square border), hence no access to long-range information.
- **Global Attention**: With the global query, it is not limited to local windows (blue square border), allowing access to long-range information.
- **Comparative Attention**:
    - **ViT**: Compares image-tokens with image-tokens.
    - **SwinTransformer**: Compares window-tokens with window-tokens.
    - **GCViT**: Compares image-tokens with window-tokens by resizing image-tokens to fit window-tokens using the Global Token Gen./FeatureExtraction CNN module.
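
The difference between the two attention modes comes down to where the query comes from. The following is a simplified single-head NumPy sketch under stated assumptions: the relative positional embedding, dropout, and output projection are omitted, and all weights and sizes are random placeholders rather than the notebook's values.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)  # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def window_attention(windows, w_q, w_k, w_v, q_global=None):
    """Single-head attention inside each window.

    windows:  (B * M, N, C) tokens already partitioned into M windows per image
    q_global: (B, N, C) global tokens, or None for local self-attention
    """
    k = windows @ w_k
    v = windows @ w_v
    if q_global is None:
        q = windows @ w_q                         # local query from the window
    else:
        repeats = windows.shape[0] // q_global.shape[0]
        q = np.repeat(q_global, repeats, axis=0)  # shared global query
    scale = q.shape[-1] ** -0.5
    attn = softmax((q * scale) @ k.transpose(0, 2, 1))  # (B * M, N, N)
    return attn @ v                                     # (B * M, N, C)

rng = np.random.default_rng(0)
b, m, n, c = 2, 4, 49, 16
windows = rng.normal(size=(b * m, n, c))
w_q, w_k, w_v = (rng.normal(size=(c, c)) * 0.1 for _ in range(3))
q_global = rng.normal(size=(b, n, c))

local_out = window_attention(windows, w_q, w_k, w_v)
global_out = window_attention(windows, w_q, w_k, w_v, q_global=q_global)
print(local_out.shape, global_out.shape)  # (8, 49, 16) (8, 49, 16)
```

Keys and values are always local to the window; only the query changes. With a global query, every window is compared against tokens that summarize the whole image.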

**Comparison Table**:

| Model               | Query Tokens  | Key-Value Tokens | Attention Type         | Attention Coverage |
|---------------------|---------------|------------------|------------------------|--------------------|
| **ViT**             | image         | image            | self-attention         | global             |
| **SwinTransformer** | window        | window           | self-attention         | local              |
| **GCViT**           | resized-image | window           | image-window attention | global             |

**Summary**:
- The Global Token Gen./FeatureExtraction CNN module resizes image-tokens to fit window-tokens, enabling GCViT to perform efficient global attention by comparing resized image-tokens with window-tokens.