In [18]:
import os
import pathlib

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

In [19]:
# Record the TensorFlow version this notebook was executed with.
print(tf.__version__)

2.3.0


In [20]:
# Load the images used to train the model.
# To keep training time short we use Caltech-101 instead of the ImageNet
# data used in the paper.

# The dataset location can be overridden with the CALTECH101_DIR environment
# variable; the default is ~/tensorflow_datasets/caltech101, so the notebook
# does not depend on one particular user's absolute path.
data_dir = pathlib.Path(
    os.environ.get('CALTECH101_DIR',
                   pathlib.Path.home() / 'tensorflow_datasets' / 'caltech101'))

batch_size = 32
img_height = 256
img_width = 256
validation_split = 0.2
split_seed = 123  # shared by both subsets so the train/validation split is consistent


def _load_split(subset):
    """Return one part of the image-folder dataset under `data_dir`.

    Args:
        subset: "training" or "validation".

    Returns:
        A batched `tf.data.Dataset` of (images, one-hot labels), with images
        resized to (img_height, img_width).
    """
    return tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,
        label_mode='categorical',
        validation_split=validation_split,
        subset=subset,
        seed=split_seed,
        image_size=(img_height, img_width),
        batch_size=batch_size)


train_ds = _load_split("training")
val_ds = _load_split("validation")

Found 9144 files belonging to 102 classes.
Using 7316 files for training.
Found 9144 files belonging to 102 classes.
Using 1828 files for validation.


In [21]:
class Patch(tf.keras.layers.Layer):
    """Cuts a batch of images into flattened, non-overlapping square patches."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, inputs):
        return self.convert_to_patches(inputs, self.patch_size)

    def convert_to_patches(self, images, patch_size):
        """Turn a batch of images into a batch of flattened patches.

        Args:
            images: 4-D tensor (batch_size, rows, cols, channels).
            patch_size: side length of each square patch; also used as the
                stride, so patches do not overlap.

        Returns:
            Tensor of shape (batch_size, number of patches, flattened patch
            length).
        """
        n_images = tf.shape(images)[0]
        # The same window is used for size and stride -> a non-overlapping grid.
        window = [1, patch_size, patch_size, 1]
        tiles = tf.image.extract_patches(images=images,
                                         sizes=window,
                                         strides=window,
                                         rates=[1, 1, 1, 1],
                                         padding='VALID')
        tile_length = tf.shape(tiles)[-1]
        # Collapse the 2-D grid of patches into a single sequence axis.
        return tf.reshape(tiles, shape=[n_images, -1, tile_length])

In [22]:
# Smoke-test the Patch layer on a single training batch.
for images, labels in train_ds.take(1):
    print(images.shape, type(images), sep='\n')
    patches = Patch(patch_size=32)(images)
    print(patches.shape, type(patches), sep='\n')

(32, 256, 256, 3)
<class 'tensorflow.python.framework.ops.EagerTensor'>
(32, 64, 3072)
<class 'tensorflow.python.framework.ops.EagerTensor'>


In [55]:
class Projection(tf.keras.layers.Layer):
    """Linearly project flattened patches and prepend a learnable class token.

    Args:
        d_model: width of the token embeddings produced by this layer.

    Call:
        inputs: tensor of shape (batch_size, num_patches, flattened patch length).

    Returns:
        Tensor of shape (batch_size, num_patches + 1, d_model) whose sequence
        position 0 holds the class token.
    """
    def __init__(self, d_model, **kwards):
        super(Projection, self).__init__(**kwards)
        self.d_model = d_model
        # NOTE(review): `another` and `project` are two stacked linear layers
        # with no nonlinearity in between, which compose to a single linear
        # map. Kept to preserve the existing parameterization; consider
        # removing one of them.
        self.another = tf.keras.layers.Dense(units=d_model)
        self.project = tf.keras.layers.Dense(units=d_model)
        self.cls_token = self.add_weight(name='class token',
                                        shape=(1, 1, d_model),
                                        initializer=tf.initializers.RandomNormal(),
                                        trainable=True)
        
    def call(self, inputs):
        # Repeat the single learned token once per example in the batch.
        cls_token = tf.tile(self.cls_token, [tf.shape(inputs)[0], 1, 1])
        inputs = self.another(inputs)
        inputs = self.project(inputs)
        # The class token must come FIRST: the classification head reads
        # sequence position 0. Appending it at the end meant the head was fed
        # the first image patch and the class token was never used.
        return tf.concat([cls_token, inputs], axis=1)

In [37]:
# NOTE: this cell rebinds `patches`, so it is not idempotent. Re-running it
# without re-running the Patch cell feeds already-projected tokens back in.
print(patches.shape)
patches = Projection(d_model=50)(patches)
print(patches.shape)

(32, 65, 50)
(32, 66, 50)


In [38]:
class Pos_embedding(tf.keras.layers.Layer):
    """Adds a learnable 1D positional embedding to a sequence of tokens."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def build(self, input_shape):
        # One trainable vector per sequence position; the leading dimension of
        # 1 lets the same table broadcast over the batch axis.
        seq_len, width = input_shape[1], input_shape[2]
        self.pos_embedding = self.add_weight(
            name='pos_embedding',
            shape=(1, seq_len, width),
            initializer=tf.initializers.RandomNormal(),
            trainable=True)

    def call(self, inputs):
        return inputs + self.pos_embedding

In [39]:
# Adding the positional embedding must leave the shape unchanged.
# NOTE: this cell rebinds `patches`; re-run the preceding cells before re-running it.
print(patches.shape)
patches = Pos_embedding()(patches)
print(patches.shape)

(32, 66, 50)
(32, 66, 50)


In [40]:
class MLP(tf.keras.layers.Layer):
    """Feed-forward sub-layer used inside the transformer encoder block.

    Dense(mlp_dim, relu) -> Dropout -> Dense(d_model) -> Dropout.
    """

    def __init__(self, d_model, mlp_dim, dropout_rate, **kwargs):
        super().__init__(**kwargs)
        layers = tf.keras.layers
        self.net = tf.keras.Sequential()
        # Expand to the hidden width, then project back to the model width.
        self.net.add(layers.Dense(mlp_dim, activation='relu'))
        self.net.add(layers.Dropout(dropout_rate))
        self.net.add(layers.Dense(d_model))
        self.net.add(layers.Dropout(dropout_rate))

    def call(self, inputs):
        return self.net(inputs)

In [41]:
def scaled_dot_product_attention(q, k, v, mask):
    """Compute scaled dot-product attention.

    q, k and v must share their leading dimensions, and k and v must share
    their penultimate dimension (seq_len_k == seq_len_v).

    Args:
        q: queries, shape (..., seq_len_q, depth).
        k: keys, shape (..., seq_len_k, depth).
        v: values, shape (..., seq_len_v, depth_v).
        mask: float tensor broadcastable to (..., seq_len_q, seq_len_k), or
            None to skip masking. Positions where the mask is 1 receive a
            large negative logit.

    Returns:
        (output, attention_weights) with shapes (..., seq_len_q, depth_v) and
        (..., seq_len_q, seq_len_k).
    """
    # Scale the raw similarities by sqrt(depth of the keys).
    key_depth = tf.cast(tf.shape(k)[-1], tf.float32)
    logits = tf.matmul(q, k, transpose_b=True) / tf.math.sqrt(key_depth)

    if mask is not None:
        # Push masked positions towards -inf so softmax gives them ~0 weight.
        logits = logits + (mask * -1e9)

    # Normalize over the key axis so each query's weights sum to 1.
    weights = tf.nn.softmax(logits, axis=-1)

    return tf.matmul(weights, v), weights

In [42]:
class MultiHeadAttention(tf.keras.layers.Layer):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: total embedding width; must be divisible by `num_heads`.
        num_heads: number of attention heads.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        # Every head gets an equal slice of the model width.
        assert d_model % num_heads == 0

        self.num_heads = num_heads
        self.d_model = d_model
        self.depth = d_model // num_heads

        # Input projections for queries, keys and values.
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        # Output projection applied after the heads are merged again.
        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Reshape (batch_size, seq_len, d_model) into
        (batch_size, num_heads, seq_len, depth).
        """
        per_head = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(per_head, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        """Attend from `q` over `k`/`v`; returns (batch_size, seq_len_q, d_model)."""
        batch_size = tf.shape(q)[0]

        # Project, then split: each result is (batch_size, num_heads, seq_len, depth).
        q_heads = self.split_heads(self.wq(q), batch_size)
        k_heads = self.split_heads(self.wk(k), batch_size)
        v_heads = self.split_heads(self.wv(v), batch_size)

        # context: (batch_size, num_heads, seq_len_q, depth); the attention
        # weights are not needed by callers of this layer.
        context, _ = scaled_dot_product_attention(q_heads, k_heads, v_heads, mask)

        # Move the head axis next to depth and merge them back into d_model.
        context = tf.transpose(context, perm=[0, 2, 1, 3])
        merged = tf.reshape(context, (batch_size, -1, self.d_model))

        return self.dense(merged)

In [43]:
class EncoderBlock(tf.keras.layers.Layer):
    """Transformer encoder block.

    Two sub-layers, each applied as LayerNorm -> Dropout -> sub-layer, with a
    residual connection around it: multi-head self-attention, then the MLP.

    Args:
        d_model: width of the token embeddings (input and output).
        mlp_dim: hidden width of the MLP sub-layer.
        num_heads: number of attention heads; must divide `d_model`.
        dropout_rate: rate used by every dropout layer in the block.
        use_bias: accepted for interface compatibility; currently unused.
    """
    def __init__(self, d_model,
                 mlp_dim, num_heads, dropout_rate, use_bias=False,
                 **kwargs):
        super(EncoderBlock, self).__init__(**kwargs)
        self.attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.mlp = MLP(d_model=d_model, mlp_dim=mlp_dim, dropout_rate=dropout_rate)
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)

    def call(self, X, mask=False):
        """Apply the block to `X` of shape (batch_size, seq_len, d_model).

        `mask` is ignored: attention is always computed without a mask.
        Returns a tensor with the same shape as `X`.
        """
        # Attention sub-layer. Attention must consume the normalized
        # activations; previously the LayerNorm/Dropout result was computed
        # and then discarded because raw X was passed as q/k/v. The residual
        # still adds the un-normalized X.
        output = self.layernorm1(X)
        output = self.dropout1(output)
        output = self.attention(v=output, k=output, q=output, mask=None) + X
        
        # MLP sub-layer, same pattern, with the residual around it.
        output2 = self.layernorm2(output)
        output2 = self.dropout2(output2)
        output2 = self.mlp(output2) + output
        
        return output2

In [44]:
# An encoder block must preserve the (batch, sequence, width) shape.
# NOTE: this cell rebinds `patches`; re-run the preceding cells before re-running it.
print(patches.shape)
patches = EncoderBlock(d_model=50, mlp_dim=100, num_heads=10, dropout_rate=0.1)(patches, None)
print(patches.shape)

(32, 66, 50)
(32, 66, 50)


In [74]:
class ViT(tf.keras.Model):
    """Vision Transformer classifier.

    Pipeline: image -> patches -> projection -> positional embedding ->
    `num_layers` encoder blocks -> LayerNorm + Dense head applied to the token
    at sequence position 0. The head has no activation, so the outputs are
    unnormalized class scores of shape (batch_size, num_classes).
    """

    def __init__(self, d_model, mlp_dim,
                 num_heads, dropout_rate, num_layers,
                 patch_size, num_classes, use_bias=False, **kwargs):
        super().__init__(**kwargs)
        self.d_model = d_model
        self.patch = Patch(patch_size)
        self.projection = Projection(d_model)
        self.pos_embedding = Pos_embedding()
        self.blocks = [
            EncoderBlock(d_model, mlp_dim, num_heads, dropout_rate, use_bias)
            for _ in range(num_layers)
        ]
        self.mlp_head = tf.keras.Sequential([
            tf.keras.layers.LayerNormalization(epsilon=1e-6),
            tf.keras.layers.Dense(num_classes),
        ])

    def call(self, X):
        tokens = self.pos_embedding(self.projection(self.patch(X)))
        for encoder in self.blocks:
            tokens = encoder(tokens)
        # Only the token at sequence position 0 is passed to the head.
        return self.mlp_head(tokens[:, 0])


In [75]:
# Smoke-test the full model on one batch: expect (batch_size, num_classes).
for images, labels in train_ds.take(1):
    print(images.shape)
    result = ViT(d_model=50, mlp_dim=100, num_heads=10, dropout_rate=0.1,
                 num_layers=3, patch_size=32, num_classes=102)(images)
    print(result.shape)

(32, 256, 256, 3)
(32, 102)


In [76]:
# Build the model for a fixed input shape so the summary can report parameter counts.
model = ViT(d_model=50, mlp_dim=100, num_heads=10, dropout_rate=0.001,
            num_layers=3, patch_size=32, num_classes=102)
model.build(input_shape=(32, 256, 256, 3))
model.summary()

Model: "vi_t_18"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
patch_27 (Patch)             multiple                  0         
_________________________________________________________________
projection_28 (Projection)   multiple                  156250    
_________________________________________________________________
pos_embedding_22 (Pos_embedd multiple                  3250      
_________________________________________________________________
encoder_block_57 (EncoderBlo multiple                  20550     
_________________________________________________________________
encoder_block_58 (EncoderBlo multiple                  20550     
_________________________________________________________________
encoder_block_59 (EncoderBlo multiple                  20550     
_________________________________________________________________
sequential_85 (Sequential)   (32, 102)                 5302

In [None]:
# Cache decoded images after the first pass and overlap input loading with training.
# NOTE: this rebinds train_ds/val_ds; re-running the cell stacks another cache/prefetch.
train_ds = train_ds.cache().prefetch(tf.data.experimental.AUTOTUNE)
val_ds = val_ds.cache().prefetch(tf.data.experimental.AUTOTUNE)

# Fresh, untrained model for the training loop below.
model = ViT(d_model=50, mlp_dim=100, num_heads=10, dropout_rate=0.001,
            num_layers=3, patch_size=32, num_classes=102)

def loss(model, x, y, training):
    """Categorical cross-entropy between one-hot targets `y` and `model(x)`.

    `training` is forwarded to the model; it matters only for layers that
    behave differently in training and inference (e.g. Dropout). The model
    outputs are treated as logits.
    """
    logits = model(x, training=training)
    criterion = tf.losses.CategoricalCrossentropy(from_logits=True)
    return criterion(y_true=y, y_pred=logits)

def grad(model, inputs, targets):
    """Return (loss value, gradients w.r.t. the model's trainable variables) for one batch."""
    with tf.GradientTape() as tape:
        # Forward pass in training mode, recorded on the tape.
        loss_value = loss(model, inputs, targets, training=True)
    gradients = tape.gradient(loss_value, model.trainable_variables)
    return loss_value, gradients

# Keep results for plotting
train_loss_results = []
train_accuracy_results = []

num_epochs = 3
optimizer = tf.keras.optimizers.SGD(learning_rate=0.01)

for epoch in range(num_epochs):
    # Fresh running metrics every epoch. The loss tracker must be a Mean
    # metric (a loss object cannot accumulate values across batches).
    epoch_loss = tf.keras.metrics.Mean()
    epoch_accuracy = tf.keras.metrics.CategoricalAccuracy()

    # Training loop - using batches of 32
    for x, y in train_ds:
        # Optimize the model
        loss_value, grads = grad(model, x, y)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Track progress. This has to happen inside the batch loop; updating
        # after the loop only ever measured the last batch of the epoch.
        epoch_loss.update_state(loss_value)
        epoch_accuracy.update_state(y, model(x, training=True))

    # End epoch
    train_loss_results.append(epoch_loss.result())
    train_accuracy_results.append(epoch_accuracy.result())

    print("Epoch {:03d}: Loss: {:.3f}, Accuracy: {:.3%}".format(epoch,
                                                                epoch_loss.result(),
                                                                epoch_accuracy.result()))