In [4]:
import tensorflow as tf
from keras import layers
from swin_transformer import SwinTransformer, PatchExtract, PatchEmbedding, PatchMerging

In [11]:
# Hyper-parameters for the Swin-style backbone built in the next cell.
# They loosely mirror a reference Hugging Face configuration:
#   SwinConfig(image_size=84, patch_size=3, num_channels=4, embed_dim=96,
#              depths=[2, 3, 2], num_heads=[3, 3, 6], window_size=7,
#              mlp_ratio=4.0, drop_path_rate=0.1)
# NOTE(review): `depths` and `drop_path_rate` from that reference are not
# reproduced by the hand-built model (one SwinTransformer block per entry of
# `num_heads`, plain `dropout_rate`) — confirm this simplification is intended.

input_shape = (84, 84, 4)  # (rows, cols, channels); 4 channels matches num_channels above

embed_dim = 96
patch_size = (3, 3)
num_heads = [3, 3, 6]  # one SwinTransformer block is built per entry
window_size = 7
mlp_ratio = 4.0
num_mlp = int(embed_dim * mlp_ratio)  # passed as `num_mlp` (presumably the MLP hidden width); = 384
qkv_bias = True
dropout_rate = 0.1
# NOTE(review): the Swin paper shifts windows by window_size // 2 (= 3 here);
# a shift of 1 is much smaller — confirm this is intentional.
shift_size = 1

num_patch_x = input_shape[0] // patch_size[0]
num_patch_y = input_shape[1] // patch_size[1]

# Fail fast on inconsistent geometry instead of deep inside a layer call:
# the `//` above would silently drop border pixels, and window partitioning is
# assumed to need the patch grid to tile evenly into window_size x window_size windows.
assert input_shape[0] % patch_size[0] == 0 and input_shape[1] % patch_size[1] == 0, \
    "input_shape must be divisible by patch_size"
assert num_patch_x % window_size == 0 and num_patch_y % window_size == 0, \
    "patch grid must be divisible by window_size"
# Multi-head attention is assumed to split embed_dim evenly across heads.
assert all(embed_dim % heads == 0 for heads in num_heads), \
    "embed_dim must be divisible by every entry of num_heads"
assert 0 <= shift_size < window_size, "shift_size must lie in [0, window_size)"

In [13]:
# Build the model with the Keras functional API:
# rescale -> extract patches -> embed -> stacked SwinTransformer blocks
# -> patch merging -> global average pool -> linear head.
num_outputs = 6  # width of the linear output head; NOTE(review): document what these 6 outputs are

# NOTE(review): `input` shadows the builtin input(); kept because later cells may reference it.
input = layers.Input(input_shape)
x = layers.Rescaling(scale=1.0 / 255)(input)  # assumes 0..255 pixel values — confirm
x = PatchExtract(patch_size)(x)
x = PatchEmbedding(num_patch_x * num_patch_y, embed_dim)(x)

# One SwinTransformer block per entry of `num_heads`, built in order.
# The first block uses regular windows (shift 0); every later block uses
# `shift_size`. This is the same pattern as the previous hand-unrolled cell.
# NOTE(review): Swin normally alternates regular/shifted blocks; with three
# entries the last block here is also shifted — confirm that is intended.
for block_idx, block_heads in enumerate(num_heads):
    x = SwinTransformer(
        dim=embed_dim,
        num_patch=(num_patch_x, num_patch_y),
        num_heads=block_heads,
        window_size=window_size,
        shift_size=0 if block_idx == 0 else shift_size,
        num_mlp=num_mlp,
        qkv_bias=qkv_bias,
        dropout_rate=dropout_rate,
    )(x)

x = PatchMerging((num_patch_x, num_patch_y), embed_dim=embed_dim)(x)
x = layers.GlobalAveragePooling1D()(x)
output = layers.Dense(num_outputs, activation="linear")(x)
model = tf.keras.Model(input, output)

In [15]:
# Print the per-layer table (layer name/type, output shape, parameter count) of the assembled model.
model.summary()

Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 84, 84, 4)]       0         
                                                                 
 rescaling_3 (Rescaling)     (None, 84, 84, 4)         0         
                                                                 
 patch_extract_3 (PatchExtra  (None, 784, 36)          0         
 ct)                                                             
                                                                 
 patch_embedding_1 (PatchEmb  (None, 784, 96)          78816     
 edding)                                                         
                                                                 
 swin_transformer_1 (SwinTra  (None, 784, 96)          114748    
 nsformer)                                                       
                                                             