<a href="https://colab.research.google.com/github/Sundragon1993/pyramid_vision_transformer/blob/main/%5BSO%5DPyramid_Vision_Transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U tensorflow-addons

Collecting tensorflow-addons
[?25l  Downloading https://files.pythonhosted.org/packages/74/e3/56d2fe76f0bb7c88ed9b2a6a557e25e83e252aec08f13de34369cd850a0b/tensorflow_addons-0.12.1-cp37-cp37m-manylinux2010_x86_64.whl (703kB)
[K     |▌                               | 10kB 23.7MB/s eta 0:00:01[K     |█                               | 20kB 17.8MB/s eta 0:00:01[K     |█▍                              | 30kB 15.0MB/s eta 0:00:01[K     |█▉                              | 40kB 13.9MB/s eta 0:00:01[K     |██▎                             | 51kB 7.5MB/s eta 0:00:01[K     |██▉                             | 61kB 7.2MB/s eta 0:00:01[K     |███▎                            | 71kB 8.1MB/s eta 0:00:01[K     |███▊                            | 81kB 8.8MB/s eta 0:00:01[K     |████▏                           | 92kB 9.1MB/s eta 0:00:01[K     |████▋                           | 102kB 7.5MB/s eta 0:00:01[K     |█████▏                          | 112kB 7.5MB/s eta 0:00:01[K     |█████▋     

# Data loading and preprocessing

In [None]:
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import tensorflow_addons as tfa

In [None]:
# Label count and per-image shape used when building the model.
num_classes = 100
input_shape = (32, 32, 3)

# load_data() returns a (train, test) pair of (images, labels) tuples.
train_split, test_split = keras.datasets.cifar100.load_data()
xtrain, ytrain = train_split
xtest, ytest = test_split

print(f"x_train shape: {xtrain.shape} - y_train shape: {ytrain.shape}")
print(f"x_test shape: {xtest.shape} - y_test shape: {ytest.shape}")

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
x_train shape: (50000, 32, 32, 3) - y_train shape: (50000, 1)
x_test shape: (10000, 32, 32, 3) - y_test shape: (10000, 1)


In [None]:
# Patch size of each pyramid stage, expressed relative to the resized input
# image (the model code divides `image_size` by these values).
patch_size_1, patch_size_2, patch_size_3, patch_size_4 = 4, 8, 16, 32

# Optimisation settings.
learning_rate = 0.003
weight_decay = 0.1
batch_size = 512
num_epochs = 100

# Model geometry.
image_size = 72  # side length the input images are resized to
projection_dim = 64
num_heads = 5
transformer_units = [projection_dim * 2, projection_dim]  # Size of the transformer layers
transformer_layers = 10
mlp_head_units = [3072, 768]  # Size of the dense layers of the final classifier


In [None]:
# Short alias for the experimental preprocessing namespace.
preprocessing = layers.experimental.preprocessing

# The normalisation layer is kept in its own variable so it can be adapted
# to the training data below.
normalizer = preprocessing.Normalization()
augmentation_stages = [
    normalizer,
    preprocessing.Resizing(image_size, image_size),
    preprocessing.RandomFlip("horizontal"),
    preprocessing.RandomRotation(factor=0.02),
    preprocessing.RandomZoom(height_factor=0.2, width_factor=0.2),
]
data_augmentation = keras.Sequential(augmentation_stages, name="data_augmentation")

# Compute the mean and the variance of the training data for normalization.
normalizer.adapt(xtrain)


# Training pipeline

### Patch Layer

In [None]:
class Patch(layers.Layer):
  """Splits a batch of images into flattened, non-overlapping square patches.

  Args:
    patch_size: Side length of each square patch. The same value is used as
      the stride, so neighbouring patches do not overlap.
    **kwargs: Standard `keras.layers.Layer` keyword arguments (e.g. `name`);
      accepting them lets the layer be rebuilt from its config.
  """

  def __init__(self, patch_size, **kwargs):
    super(Patch, self).__init__(**kwargs)
    self.patch_size = patch_size

  def call(self, images):
    """Returns the patches reshaped to `[batch, -1, patch_dims]`."""
    batch_size = tf.shape(images)[0]
    patches = tf.image.extract_patches(
        images = images,
        sizes = [1, self.patch_size, self.patch_size, 1],
        strides = [1, self.patch_size, self.patch_size, 1],
        rates = [1,1,1,1],
        padding = 'VALID'
    )
    # The last axis of `patches` holds one flattened patch; collapse the
    # spatial grid of patches into a single sequence axis.
    patch_dims = patches.shape[-1]
    patches = tf.reshape(patches, [batch_size, -1, patch_dims])
    return patches

  def get_config(self):
    """Returns the constructor arguments so the layer can be serialised."""
    config = super(Patch, self).get_config()
    config.update({"patch_size": self.patch_size})
    return config



### Patch Encoder Layer

In [None]:
class PatchEncoder(layers.Layer):
  """Projects flattened patches and adds a learned position embedding.

  Args:
    num_patches: Number of patches per image; sizes the position-embedding
      table and the range of position indices.
    projection_dim: Output width of the linear projection and of each
      position embedding vector.
    **kwargs: Standard `keras.layers.Layer` keyword arguments (e.g. `name`);
      accepting them lets the layer be rebuilt from its config.
  """

  def __init__(self, num_patches, projection_dim, **kwargs):
    super(PatchEncoder, self).__init__(**kwargs)
    self.num_patches = num_patches
    # Stored only so that get_config() can report it.
    self.projection_dim = projection_dim
    self.projection = layers.Dense(units=projection_dim)
    self.position_embedding = layers.Embedding(
        input_dim=num_patches, output_dim=projection_dim
    )

  def call(self, patch):
    """Returns `projection(patch)` plus the embedding of positions 0..num_patches-1."""
    positions = tf.range(start=0, limit=self.num_patches, delta=1)
    # The position embeddings carry no batch axis and are broadcast over it,
    # so the patch axis of `patch` must have length `num_patches`.
    encoded = self.projection(patch) + self.position_embedding(positions)
    return encoded

  def get_config(self):
    """Returns the constructor arguments so the layer can be serialised."""
    config = super(PatchEncoder, self).get_config()
    config.update({
        "num_patches": self.num_patches,
        "projection_dim": self.projection_dim,
    })
    return config
  
  

In [None]:
def mlp(x, hidden_units, dropout_rate):
  """Applies one GELU Dense layer followed by Dropout per entry of `hidden_units`.

  Args:
    x: Input passed to the first Dense layer.
    hidden_units: Iterable of Dense widths, applied in order.
    dropout_rate: Rate given to every Dropout layer.

  Returns:
    The output of the last Dropout layer, or `x` unchanged when
    `hidden_units` is empty.
  """
  features = x
  for width in hidden_units:
    dense = layers.Dense(width, activation=tf.nn.gelu)
    features = dense(features)
    dropout = layers.Dropout(dropout_rate)
    features = dropout(features)
  return features

### Transformer Layer

In [None]:
class transformer(layers.Layer):
  """Stack of Transformer encoder blocks followed by a reshape to a 2-D feature map.

  Each block applies LayerNormalization -> self-attention -> residual add,
  then LayerNormalization -> two-layer GELU MLP -> residual add. After the
  last block the sequence is normalised once more and reshaped to
  `[image_size // patch_size, image_size // patch_size, projection_dim]`
  (plus the batch axis), where `image_size` is the module-level constant.

  All sub-layers are created once in `__init__` so that Keras tracks their
  weights and reuses them on every call; creating them inside `call` would
  build fresh, untracked layers on each invocation.

  Args:
    num_heads: Number of attention heads in every block.
    transformer_layers: Number of encoder blocks in the stack.
    patch_size: Patch size of this stage relative to the resized input image;
      determines the side of the output feature map.
    projection_dim: Width of the token embeddings entering and leaving the
      stack; also used as the per-head `key_dim`.
    **kwargs: Standard `keras.layers.Layer` keyword arguments (e.g. `name`).
  """

  def __init__(self, num_heads, transformer_layers, patch_size, projection_dim, **kwargs):
    super(transformer, self).__init__(**kwargs)
    self.num_heads = num_heads
    self.transformer_layers = transformer_layers
    self.patch_size = patch_size
    self.projection_dim = projection_dim

    block_range = range(transformer_layers)
    self.attention_norms = [
        layers.LayerNormalization(epsilon=1e-6) for _ in block_range
    ]
    self.attention_layers = [
        layers.MultiHeadAttention(
            num_heads = num_heads, key_dim = projection_dim, dropout = 0.1
        )
        for _ in block_range
    ]
    self.mlp_norms = [
        layers.LayerNormalization(epsilon=1e-6) for _ in block_range
    ]
    # The MLP must end at `projection_dim` so its output can be added to the
    # residual branch; a fixed global width only matches when
    # projection_dim happens to equal it.
    self.mlp_blocks = [
        keras.Sequential([
            layers.Dense(projection_dim * 2, activation=tf.nn.gelu),
            layers.Dropout(0.2),
            layers.Dense(projection_dim, activation=tf.nn.gelu),
            layers.Dropout(0.2),
        ])
        for _ in block_range
    ]
    self.final_norm = layers.LayerNormalization(epsilon=1e-6)
    grid_size = image_size // patch_size
    self.to_feature_map = layers.Reshape([grid_size, grid_size, projection_dim])

  def call(self, encoded_patches, training=None):
    """Runs the encoder stack and returns the reshaped feature map.

    Args:
      encoded_patches: Token sequence whose last axis has size
        `projection_dim`.
      training: Optional flag forwarded to the attention and MLP sub-layers,
        which contain dropout.
    """
    blocks = zip(
        self.attention_norms, self.attention_layers, self.mlp_norms, self.mlp_blocks
    )
    for attention_norm, attention, mlp_norm, mlp_block in blocks:
        x1 = attention_norm(encoded_patches)
        attention_output = attention(x1, x1, training=training)
        x2 = attention_output + encoded_patches
        x3 = mlp_block(mlp_norm(x2), training=training)
        encoded_patches = x3 + x2

    representation = self.final_norm(encoded_patches)
    return self.to_feature_map(representation)

  def get_config(self):
    """Returns the constructor arguments so the layer can be serialised."""
    config = super(transformer, self).get_config()
    config.update({
        "num_heads": self.num_heads,
        "transformer_layers": self.transformer_layers,
        "patch_size": self.patch_size,
        "projection_dim": self.projection_dim,
    })
    return config

    


# Model definition

In [None]:
def PVT():
  """Builds the two-stage pyramid vision transformer classifier.

  Reads the module-level settings `input_shape`, `image_size`,
  `patch_size_1`, `patch_size_2`, `transformer_layers`, `num_classes` and the
  `data_augmentation` model.

  The `patch_size_k` constants are relative to the resized input image. Each
  stage after the first operates on the previous stage's feature map, whose
  side is already `image_size // patch_size_{k-1}`, so it must patch that map
  with the ratio `patch_size_k // patch_size_{k-1}` to obtain the
  `(image_size // patch_size_k) ** 2` tokens its PatchEncoder expects.

  Returns:
    A `keras.Model` mapping images of shape `input_shape` to `num_classes`
    logits.
  """

  # Inputs (named `inputs` to avoid shadowing the `input` builtin).
  inputs = layers.Input(shape=input_shape)
  augment = data_augmentation(inputs)

  # Stage 1: patches are cut directly from the resized image.
  patches_1 = Patch(patch_size_1)(augment)
  patches_1 = PatchEncoder(num_patches=(image_size // patch_size_1) ** 2, projection_dim=64)(patches_1)
  input_2 = transformer(1, transformer_layers, patch_size_1, 64)(patches_1) #Output 1

  # Stage 2: patches are cut from the stage-1 feature map, hence the ratio.
  patches_2 = Patch(patch_size_2 // patch_size_1)(input_2)
  patches_2 = PatchEncoder(num_patches=(image_size // patch_size_2) ** 2, projection_dim=128)(patches_2)
  input_3 = transformer(2, transformer_layers, patch_size_2, 128)(patches_2) #Output 2

  # Stages 3 and 4 are disabled. Enabling them requires `image_size` to be
  # divisible by patch_size_3 / patch_size_4 (72 is not divisible by 16 or 32).
  # # Stage 3
  # patches_3 = Patch(patch_size_3 // patch_size_2)(input_3)
  # patches_3 = PatchEncoder(num_patches=(image_size // patch_size_3) ** 2, projection_dim=320)(patches_3)
  # input_4 = transformer(5, transformer_layers, patch_size_3, 320)(patches_3) #Output 3

  # # Stage 4
  # patches_4 = Patch(patch_size_4 // patch_size_3)(input_4)
  # patches_4 = PatchEncoder(num_patches=(image_size // patch_size_4) ** 2, projection_dim=512)(patches_4)
  # input_5 = transformer(8, transformer_layers, patch_size_4, 512)(patches_4) #Output 4

  representation = layers.Flatten()(input_3)
  representation = layers.Dropout(0.5)(representation)
  # Classify outputs.
  logits = layers.Dense(num_classes)(representation)
  # Create the Keras model.
  model = keras.Model(inputs=inputs, outputs=logits)

  return model

In [None]:
# Build the model and print its summary as a quick check that it constructs.
pvt = PVT()
pvt.summary()

ValueError: ignored

# Training

In [None]:
def run_experiment(model, epochs=5, save_path='model-5.h5'):
    """Compiles `model`, trains it on the CIFAR-100 training split and saves it.

    Reads the module-level `xtrain`, `ytrain`, `batch_size`, `learning_rate`
    and `weight_decay`. One tenth of the training data is held out for
    validation via `validation_split`.

    Args:
        model: Keras model that outputs logits (the loss is built with
            `from_logits=True`).
        epochs: Number of training epochs. Defaults to 5.
        save_path: Where to save the trained model; pass `None` to skip
            saving. Defaults to 'model-5.h5'.

    Returns:
        The value returned by `model.fit`.
    """
    optimizer = tfa.optimizers.AdamW(
        learning_rate=learning_rate, weight_decay=weight_decay, beta_1=0.9, beta_2=0.999
    )

    model.compile(
        optimizer=optimizer,
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    history = model.fit(
        x=xtrain,
        y=ytrain,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1
    )

    if save_path is not None:
        model.save(save_path)

    return history

# Build a fresh model and train it; `history` holds what run_experiment returns.
pvt = PVT()
history = run_experiment(pvt)


Epoch 1/5


ValueError: ignored