### [IN PROGRESS]

Sources used in this notebook:\
    1. [A Google Colab notebook that implements a simpler version](https://colab.research.google.com/github/dzlab/notebooks/blob/master/_notebooks/2022-02-27-Swin_Transfomer.ipynb#scrollTo=-UnAWaONhf9J)\
    2. [The model outlined in the original research paper](https://arxiv.org/pdf/2103.14030.pdf)\
    3. [To load the weights of the pretrained model](https://tfhub.dev/sayakpaul/swin_s3_tiny_224/1)

In [3]:
import os
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
import tensorflow.keras.layers as tfl

from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
    Add, Dense, Dropout, Embedding, GlobalAveragePooling1D, Input, Layer, LayerNormalization, MultiHeadAttention,
    Softmax
)
from tensorflow.keras.initializers import TruncatedNormal

In [4]:
notebook_path = os.path.dirname(os.path.abspath('brain_MRI_classification.ipynb'))
datasets_combined = os.path.join(notebook_path, 'brainMRI_data')

train_directory = os.path.join(datasets_combined, 'Training')
test_directory = os.path.join(datasets_combined, 'Testing')

In [5]:
BATCH_SIZE = 64
IMG_SIZE = (224, 224)

# using 'int' to use sparse_categorical_crossentropy for loss
train_dataset = image_dataset_from_directory(train_directory,
                                             batch_size = BATCH_SIZE,
                                             image_size = IMG_SIZE,
                                             shuffle = True,
                                             validation_split = 0.2,
                                             subset = 'training',
                                             seed = 42,
                                             label_mode='int')

validation_dataset = image_dataset_from_directory(train_directory,
                                                  batch_size = BATCH_SIZE,
                                                  image_size = IMG_SIZE,
                                                  shuffle = True,
                                                  validation_split = 0.2,
                                                  subset = 'validation',
                                                  seed = 42,
                                                  label_mode='int')

test_dataset = image_dataset_from_directory(test_directory,
                                            shuffle = False,
                                            image_size = IMG_SIZE,
                                            label_mode='int')

Found 2870 files belonging to 4 classes.
Using 2296 files for training.
Found 2870 files belonging to 4 classes.
Using 574 files for validation.
Found 394 files belonging to 4 classes.


In [6]:
class PatchPartition(Layer):
    def __init__(self, window_size = 4, channels = 3):
        super(PatchPartition, self).__init__()
        self.window_size = window_size
        
    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images = images,
            # 4x4 patches with stride of 4 for non-overlapping patches
            sizes = [1, self.window_size, self.window_size, 1],
            strides = [1, self.window_size, self.window_size, 1],
            rates = [1, 1, 1, 1],
            padding = 'VALID',
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches

In [7]:
class LinearEmbedding(Layer):
    def __init__(self, num_patches, projection_dim, **kwargs):
        super(LinearEmbedding, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = Dense(projection_dim)
        self.position_embedding = Embedding(input_dim = self.num_patches, output_dim = projection_dim)
        
    def call(self, patch):
        patches_embed = self.projection(patch)
        positions = tf.range(start = 0, limit = self.num_patches, delta = 1)
        positions_embed = self.position_embedding(positions)
        encoded = patches_embed + positions_embed
        return encoded

In [8]:
class PatchMerging(Layer):
    def __init__(self, input_resolution, channels):
        super(PatchMerging, self).__init__()
        self.input_resolution = input_resolution
        self.channels = channels
        self.linear_trans = Dense(2 * channels, use_bias = False)
        
    def call(self, x):
        height, width = self.input_resolution
        _, _, C = x.get_shape().as_list()
        x = tf.reshape(x, shape = (-1, height, width, C))
        x0 = x[:, 0::2, 0::2, :]
        x1 = x[:, 1::2, 0::2, :]
        x2 = x[:, 0::2, 1::2, :]
        x3 = x[:, 1::2, 1::2, :]
        x = tf.concat((x0, x1, x2, x3), axis = -1)
        x = tf.reshape(x, shape = (-1, (height // 2) * (width // 2), 4 * C))
        return self.linear_trans(x)

In [9]:
# Multilayer Perceptron (MLP) -- 'feedforward NN'
class MLP(Layer):
    def __init__(self, hidden_features, out_features, dropout_rate = 0.1):
        super(MLP, self).__init__()
        self.dense1 = Dense(hidden_features, activation = tf.nn.gelu)
        self.dense2 = Dense(out_features)
        self.dropout = Dropout(dropout_rate)
        
    def call(self, x):
        x = self.dense1(x)
        x = self.dropout(x)
        x = self.dense2(x)
        y = self.dropout(x)
        return y

In [10]:
class WindowAttention(Layer):
    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        initializer = TruncatedNormal(mean=0., stddev=.02)
        # position table shape is: (2*Wh-1 * 2*Ww-1, nH)
        table_shape = ((2*self.window_size[0]-1) * (2*self.window_size[1]-1), num_heads)
        self.relative_position_bias_table = tf.Variable(initializer(shape=table_shape))

        # get pair-wise relative position index for each token inside the window
        coords_h = tf.range(self.window_size[0])
        coords_w = tf.range(self.window_size[1])
        coords = tf.stack(tf.meshgrid(coords_h, coords_w))  # 2, Wh, Ww
        coords_flatten = tf.reshape(coords, [2, -1])  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = tf.transpose(relative_coords, perm=[1,2,0]) # Wh*Ww, Wh*Ww, 2
        relative_coords = relative_coords + [self.window_size[0] - 1, self.window_size[1] - 1]  # shift to start from 0
        relative_coords = relative_coords * [2*self.window_size[1] - 1, 1]
        self.relative_position_index = tf.math.reduce_sum(relative_coords,-1)  # Wh*Ww, Wh*Ww

        self.qkv = Dense(dim * 3, use_bias=qkv_bias, kernel_initializer=initializer)
        self.attn_drop = Dropout(attn_drop)
        self.proj = Dense(dim, kernel_initializer=initializer)
        self.proj_drop = Dropout(proj_drop)
        self.softmax = Softmax(axis=-1)

    def call(self, x, mask=None):
        _, L, N, C = x.shape
        qkv = tf.transpose(tf.reshape(self.qkv(x), [-1, N, 3, self.num_heads, C // self.num_heads]), perm=[2, 0, 3, 1, 4]) # [3, B_, num_head, Ww*Wh, C//num_head]
        q, k, v = tf.unstack(qkv)  # split along axis 0 into query, key, value
        q = q * self.scale
        attn = tf.einsum('...ij,...kj->...ik', q, k)
        relative_position_bias = tf.reshape(tf.gather(self.relative_position_bias_table, tf.reshape(self.relative_position_index, [-1])),
            [self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1])  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = tf.transpose(relative_position_bias, perm=[2, 0, 1])  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias

        if mask is not None:
            nW = mask.shape[0] # every window has different mask [nW, N, N]
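            # `-1 // nW` evaluates to -1 in Python, so write -1 directly and let TF infer the batch dimension
            attn = tf.reshape(attn, [-1, nW, self.num_heads, N, N]) + mask[:, None, :, :]  # add mask: large negative where attention is blocked, 0 elsewhere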
            attn = tf.reshape(attn, [-1, self.num_heads, N, N])
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = tf.reshape(tf.transpose(attn @ v, perm=[0, 2, 1, 3]), [-1, L, N, C])
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

In [11]:
def window_partition(x, window_size):
    _, H, W, C = x.shape
    num_patch_y = H // window_size
    num_patch_x = W // window_size
    x = tf.reshape(x, [-1, num_patch_y, window_size, num_patch_x, window_size, C])
    x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
    windows = tf.reshape(x, [-1, num_patch_x * num_patch_y, window_size, window_size, C])
    return windows

In [12]:
def window_reverse(windows, window_size, H, W):
    C = windows.shape[-1]
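    # windows has shape [B, nW, window_size, window_size, C]; keep the batch dimension dynamic (-1)
    # so the reshape also works when the batch size is larger than 1 or unknown at trace time
    x = tf.reshape(windows, [-1, H // window_size, W // window_size, window_size, window_size, C])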
    x = tf.reshape(tf.transpose(x, perm=[0, 1, 3, 2, 4, 5]), [-1, H, W, C])
    return x

In [13]:
class DropPath(Layer):
    def __init__(self, prob):
        super().__init__()
        self.drop_prob = prob

    def call(self, x, training=None):
        if self.drop_prob == 0. or not training:
            return x
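        keep_prob = 1 - self.drop_prob
        # one random draw per sample; use the dynamic batch size, since the static one can be None
        shape = [tf.shape(x)[0]] + [1] * (len(x.shape) - 1)
        random_tensor = tf.random.uniform(shape=shape, dtype=x.dtype)
        # cast the keep mask to x's dtype so the product below does not mix float and int tensors
        keep_mask = tf.cast(random_tensor < keep_prob, x.dtype)
        output = x / keep_prob * keep_mask
        return output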

In [14]:
class SwinTransformerBlock(Layer):

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"

        self.norm1 = LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(
            dim, window_size=(self.window_size, self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else tf.identity
        self.norm2 = LayerNormalization(epsilon=1e-5)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MLP(mlp_hidden_dim, dim, dropout_rate=drop)

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = np.zeros([1, H, W, 1])  # 1 H W 1
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1
            img_mask = tf.constant(img_mask)
            mask_windows = window_partition(img_mask, self.window_size)  # 1, nW, window_size, window_size, 1
            mask_windows = tf.reshape(mask_windows, [-1, self.window_size * self.window_size])
            attn_mask = mask_windows[:, None, :] - mask_windows[:, :, None]
            # tokens from the same region (difference 0) may attend to each other; all other pairs are suppressed
            self.attn_mask = tf.where(attn_mask == 0, 0., -100.)
        else:
            self.attn_mask = None

    def call(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, [-1, H, W, C])

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(x, shift=[-self.shift_size, -self.shift_size], axis=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # B, nW, window_size, window_size, C
        x_windows = tf.reshape(x_windows, [-1, x_windows.shape[1], self.window_size * self.window_size, C])  # B, nW, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # B, nW, window_size*window_size, C

        # merge windows
        attn_windows = tf.reshape(attn_windows, [-1, x_windows.shape[1], self.window_size, self.window_size, C])
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(shifted_x, shift=[self.shift_size, self.shift_size], axis=(1, 2))
        else:
            x = shifted_x
        x = tf.reshape(x, [-1, H * W, C])

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

In [15]:
def create_SwinTransformer(num_classes, input_shape=(224, 224, 3), window_size=7, embed_dim=96, num_heads=3):
    num_patch_x = input_shape[0] // 4
    num_patch_y = input_shape[1] // 4
    inputs = Input(shape=input_shape)
    
    # Patch Partition
    patches = PatchPartition(window_size = 4)(inputs)
    
    ## Stage 1
    # Linear Embedding
    patches_embed = LinearEmbedding(num_patches = num_patch_x * num_patch_y, projection_dim = embed_dim)(patches)

    # Swin Transformer block (first)
    out_stage_1 = SwinTransformerBlock(
        dim = embed_dim,
        input_resolution = (num_patch_x, num_patch_y),
        num_heads = num_heads,
        window_size = window_size,
        shift_size = 0
    )(patches_embed)
    # Swin Transformer block (second)
    out_stage_1 = SwinTransformerBlock(
        dim = embed_dim,
        input_resolution = (num_patch_x, num_patch_y),
        num_heads = num_heads,
        window_size = window_size,
        shift_size = 1
    )(out_stage_1)
    
    ## Stage 2
    # Patch Merging
    pm_stage_2 = PatchMerging((num_patch_x, num_patch_y), channels=embed_dim)(out_stage_1)
    
    factor = [2 ** (i + 1) for i in range(3)]
    
    # Swin Transformer block (first)
    out_stage_2 = SwinTransformerBlock(
        dim = factor[0] * embed_dim,
        input_resolution = (num_patch_x // factor[0], num_patch_y // factor[0]),
        num_heads = factor[0] * num_heads,
        window_size = window_size,
        shift_size = 0
    )(pm_stage_2)
    # Swin Transformer block (second)
    out_stage_2 = SwinTransformerBlock(
        dim = factor[0] * embed_dim,
        input_resolution = (num_patch_x // factor[0], num_patch_y // factor[0]),
        num_heads = factor[0] * num_heads,
        window_size = window_size,
        shift_size = 1
    )(out_stage_2)
    
    ## Stage 3
    # Patch Merging
    pm_stage_3 = PatchMerging((num_patch_x // factor[0], num_patch_y // factor[0]), channels = factor[0] * embed_dim)(out_stage_2)
    
    out_stage_3 = pm_stage_3
    for _ in range(3):
        # Swin Transformer block (1)
        out_stage_3 = SwinTransformerBlock(
            dim = factor[1] * embed_dim,
            input_resolution = (num_patch_x // factor[1], num_patch_y // factor[1]),
            num_heads = factor[1] * num_heads,
            window_size = window_size,
            shift_size = 0
        )(out_stage_3)
        # Swin Transformer block (2)
        out_stage_3 = SwinTransformerBlock(
            dim = factor[1] * embed_dim,
            input_resolution = (num_patch_x // factor[1], num_patch_y // factor[1]),
            num_heads = factor[1] * num_heads,
            window_size = window_size,
            shift_size = 1
        )(out_stage_3)
    
    ## Stage 4
    # Patch Merging
    pm_stage_4 = PatchMerging((num_patch_x // factor[1], num_patch_y // factor[1]), channels = factor[1] * embed_dim)(out_stage_3)
     
    # Swin Transformer block (first)
    out_stage_4 = SwinTransformerBlock(
        dim = factor[2] * embed_dim,
        input_resolution = (num_patch_x // factor[2], num_patch_y // factor[2]),
        num_heads = factor[2] * num_heads,
        window_size = window_size,
        shift_size = 0
    )(pm_stage_4)
    # Swin Transformer block (second)
    out_stage_4 = SwinTransformerBlock(
        dim = factor[2] * embed_dim,
        input_resolution = (num_patch_x // factor[2], num_patch_y // factor[2]),
        num_heads = factor[2] * num_heads,
        window_size = window_size,
        shift_size = 1
    )(out_stage_4)
    
    # pooling
    representation = GlobalAveragePooling1D()(out_stage_4)
    # class probabilities (softmax output, not raw logits)
    output = Dense(num_classes, activation="softmax")(representation)
    # Create model
    model = Model(inputs=inputs, outputs=output)
    return model

In [16]:
model = create_SwinTransformer(4)



In [17]:
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 224, 224, 3)]     0         
                                                                 
 patch_partition (PatchPart  (None, None, 48)          0         
 ition)                                                          
                                                                 
 linear_embedding (LinearEm  (None, 3136, 96)          305760    
 bedding)                                                        
                                                                 
 swin_transformer_block (Sw  (None, 3136, 96)          112347    
 inTransformerBlock)                                             
                                                                 
 swin_transformer_block_1 (  (None, 3136, 96)          112347    
 SwinTransformerBlock)                                       
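
The parameter counts in the summary can be reproduced by hand from the layer definitions above. The following is a minimal sketch, assuming the default arguments of `create_SwinTransformer` (4x4 patches on a 224x224x3 input, `embed_dim=96`, `num_heads=3`, `window_size=7`, `mlp_ratio=4`); the variable names are illustrative and not part of the notebook.

```python
# Hand-computed parameter counts for the first two layer types in the summary.
# All names below are illustrative; the values mirror the notebook's defaults.
patch_dim = 4 * 4 * 3            # values per flattened 4x4 RGB patch -> 48
num_patches = (224 // 4) ** 2    # 56 * 56 patches -> 3136
embed_dim = 96
num_heads = 3
window_size = 7
mlp_hidden = 4 * embed_dim       # mlp_ratio = 4

# LinearEmbedding: Dense projection (kernel + bias) plus the position-embedding table
linear_embedding_params = (
    patch_dim * embed_dim + embed_dim
    + num_patches * embed_dim
)

layer_norm = 2 * embed_dim       # gamma and beta

# SwinTransformerBlock in stage 1
block_params = (
    layer_norm                                     # norm1
    + embed_dim * 3 * embed_dim + 3 * embed_dim    # qkv Dense
    + (2 * window_size - 1) ** 2 * num_heads       # relative position bias table
    + embed_dim * embed_dim + embed_dim            # output projection
    + layer_norm                                   # norm2
    + embed_dim * mlp_hidden + mlp_hidden          # MLP dense1
    + mlp_hidden * embed_dim + embed_dim           # MLP dense2
)

print(linear_embedding_params, block_params)  # 305760 112347
```

Both numbers agree with the `Param #` column, which also suggests that the relative position bias table (a plain `tf.Variable`) is being tracked as a layer weight.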

In [18]:
for layer in model.layers:
    print(layer.__class__.__name__)

InputLayer
PatchPartition
LinearEmbedding
SwinTransformerBlock
SwinTransformerBlock
PatchMerging
SwinTransformerBlock
SwinTransformerBlock
PatchMerging
SwinTransformerBlock
SwinTransformerBlock
SwinTransformerBlock
SwinTransformerBlock
SwinTransformerBlock
SwinTransformerBlock
PatchMerging
SwinTransformerBlock
SwinTransformerBlock
GlobalAveragePooling1D
Dense


###### Ideally, we would load the weights of the pretrained model, which should improve accuracy.

In [21]:
pretrained_model = tf.keras.models.load_model(notebook_path + '/swin_tiny', compile = False)

In [24]:
# Extract the weights
pretrained_weights = pretrained_model.get_weights()

In [25]:
# Set the initial weights to the pre-trained weights
model.set_weights(pretrained_weights)

ValueError: You called `set_weights(weights)` on layer "model" with a weight list of length 173, but the layer was expecting 164 weights. Provided weights: [array([[[[ 2.49835104e-02,  2.71712639e-03, -5.77...

Loading fails because the two models hold a different number of weight tensors: the pretrained model provides 173, but our model expects 164.
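
One way to find where two weight lists start to diverge is to compare their shapes position by position. The helper below is a minimal sketch in plain Python; `first_shape_mismatch` and the toy shape lists are hypothetical and only illustrate the idea. In the notebook, the inputs could be built with something like `[w.shape for w in model.get_weights()]`.

```python
from itertools import zip_longest


def first_shape_mismatch(shapes_a, shapes_b):
    """Return (index, shape_a, shape_b) for the first position where the two
    shape lists differ, or None if they match everywhere.
    A missing entry (one list is shorter) is reported as None."""
    for i, (a, b) in enumerate(zip_longest(shapes_a, shapes_b)):
        if a is None or b is None or tuple(a) != tuple(b):
            return i, a, b
    return None


# Toy shape lists (hypothetical, not the real models' weights)
ours = [(48, 96), (96,), (3136, 96), (96,), (96,)]
theirs = [(48, 96), (96,), (96,), (96,), (96,), (96,)]

mismatch = first_shape_mismatch(ours, theirs)
print(mismatch)  # (2, (3136, 96), (96,))
```

This only reports the first divergence; after one extra or missing tensor, every later position is shifted, so the name-based comparison in the following cells is still needed to see the full picture.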

In [217]:
# for var in model.trainable_variables:
#     print(var.name)

In [218]:
# for var in pretrained_model.trainable_variables:
#     print(var.name)

In [216]:
model_vars = model.trainable_variables
pretrained_model_vars = pretrained_model.trainable_variables
print(f"{'model' :<55}{'pretrained_model' :>56} \n")

##
short_model_var = '/'.join(model_vars[2].name.split('/')[-3:])
short_pmodel_var = '/'.join(pretrained_model_vars[2].name.split('/')[-3:])
print('{:<55}{:>56}'.format(short_model_var, short_pmodel_var))
print('{:>111}\n'.format('/'.join(pretrained_model_vars[3].name.split('/')[-3:])))

##   
print('\n{:>111}'.format('/'.join(pretrained_model_vars[30].name.split('/')[-3:])))
print('{:>111}'.format('/'.join(pretrained_model_vars[31].name.split('/')[-3:])))
short_model_var = '/'.join(model_vars[29].name.split('/')[-3:])
short_pmodel_var = '/'.join(pretrained_model_vars[32].name.split('/')[-3:])
print('{:<55}{:>56}\n'.format(short_model_var, short_pmodel_var))

##
print('\n{:>111}'.format('/'.join(pretrained_model_vars[140].name.split('/')[-3:])))
print('{:>111}'.format('/'.join(pretrained_model_vars[141].name.split('/')[-3:])))
short_model_var = '/'.join(model_vars[135].name.split('/')[-3:])
short_pmodel_var = '/'.join(pretrained_model_vars[142].name.split('/')[-3:])
print('{:<55}{:>56}\n'.format(short_model_var, short_pmodel_var))

##
for i in range(169, 171):
    short_model_var = '/'.join(model_vars[i-7].name.split('/')[-3:])
    short_pmodel_var = '/'.join(pretrained_model_vars[i].name.split('/')[-3:])
    print('{:<50}{:>61}'.format(short_model_var, short_pmodel_var))
print('{:>111}'.format('/'.join(pretrained_model_vars[171].name.split('/')[-3:])))
print('{:>111}'.format('/'.join(pretrained_model_vars[172].name.split('/')[-3:])))

model                                                                                          pretrained_model 

linear_embedding/embedding/embeddings:0                                             layer_normalization/gamma:0
                                                                                     layer_normalization/beta:0


                                                                    patch_merging/layer_normalization_5/gamma:0
                                                                     patch_merging/layer_normalization_5/beta:0
patch_merging/dense_9/kernel:0                                                   patch_merging/dense_4/kernel:0


                                                                 patch_merging_2/layer_normalization_23/gamma:0
                                                                  patch_merging_2/layer_normalization_23/beta:0
patch_merging_2/dense_43/kernel:0                                             patch_merging_2/dens

The differences I can spot:\
    - Our model's linear embedding layer has a position-embedding table of shape (3136, 96), whereas the pretrained model has the weights and biases (gamma and beta) of a normalization layer at that position, each a tensor of shape (96,).\
    - For 'Patch Merging', the pretrained model has the weights and biases of a normalization layer in addition to the dense layer, whereas our model only has the dense layer. Since there are 3 patch merging layers, the pretrained model has 3(2) = 6 additional normalization tensors.\
    - Instead of only the weights and biases of the final dense layer in our model, the pretrained model has the weights and biases of both a normalization layer and a classification layer.

So 1 + 6 + (2(2) - 2) = 9 accounts for all the additional weight tensors in the pretrained model (173 - 164 = 9). To load the weights, we would have to either change our model or rearrange the pretrained weights so that the two lists of tensors match in number and shape. Starting from pretrained weights would likely make a significant difference to our classification results.
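
The tally can be written out explicitly. This is a minimal sketch of the arithmetic only; the dictionary keys are descriptive labels, not variable names from either model.

```python
# Extra weight tensors in the pretrained model relative to ours, per difference.
extra_tensors = {
    # normalization gamma + beta (2) where our model has one position-embedding table (1)
    "patch embedding": 2 - 1,
    # normalization gamma + beta in each of the 3 patch merging layers
    "patch merging": 3 * 2,
    # final normalization (2) + classifier kernel and bias (2) vs our Dense kernel and bias (2)
    "head": (2 + 2) - 2,
}

total_extra = sum(extra_tensors.values())
print(total_extra)        # 9
print(164 + total_extra)  # 173
```

The total matches the gap reported in the `set_weights` error (164 expected, 173 provided).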

Alternatively, we can get the model by installing an additional library such as PyTorch, or tfswin, a Keras implementation of the PyTorch model. Either of these may be more efficient. Still, the exercise above was helpful for learning what the model is composed of, its architecture, what each step does, and what to expect of our inputs and outputs.