# In [1]:
import tensorflow as tf
from tensorflow.keras import layers

class SwinBlock(layers.Layer):
    """Residual convolutional block with a cyclically shifted window conv.

    The block applies, in order: LayerNorm -> GELU -> 1x1 conv expanding to
    ``2 * dim`` channels -> LayerNorm -> GELU -> grouped ``window_size`` x
    ``window_size`` conv -> LayerNorm -> GELU -> 1x1 conv projecting back to
    ``dim`` channels, and finally adds the block input (residual connection).

    The input must be a channels-last 4-D tensor whose last dimension is
    ``dim``; the output has the same shape as the input.

    Args:
        dim: Number of input/output channels.
        num_heads: Number of groups of the window convolution. ``2 * dim``
            must be divisible by it.
        window_size: Kernel size of the window convolution (>= 1).
        shift_size: Number of positions the feature map is cyclically rolled
            along height and width before the window convolution (and rolled
            back afterwards). ``0`` disables the shift.
        **kwargs: Forwarded to ``layers.Layer`` (e.g. ``name``).

    Raises:
        ValueError: If ``2 * dim`` is not divisible by ``num_heads``, if
            ``window_size < 1`` or if ``shift_size < 0``.
    """

    def __init__(self, dim, num_heads, window_size, shift_size, **kwargs):
        super().__init__(**kwargs)
        hidden_dim = dim * 2
        if num_heads < 1 or hidden_dim % num_heads != 0:
            raise ValueError(
                f"2 * dim ({hidden_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        if shift_size < 0:
            raise ValueError(f"shift_size must be >= 0, got {shift_size}")

        self.window_size = window_size
        self.shift_size = shift_size
        self.num_heads = num_heads
        self.dim = dim

        self.norm1 = layers.LayerNormalization(epsilon=1e-6)
        self.conv1 = layers.Conv2D(hidden_dim, kernel_size=1)
        self.norm2 = layers.LayerNormalization(epsilon=1e-6)
        # Stride 1 + "same" padding keep the spatial size unchanged so the
        # residual addition in call() is always shape-compatible. The shift
        # itself is applied with tf.roll in call() rather than via strides,
        # which would shrink the feature map (and reject shift_size == 0).
        self.shift = layers.Conv2D(
            hidden_dim,
            kernel_size=window_size,
            strides=1,
            padding="same",
            groups=num_heads,
            use_bias=False,
        )
        self.norm3 = layers.LayerNormalization(epsilon=1e-6)
        self.conv2 = layers.Conv2D(dim, kernel_size=1)

    def call(self, x):
        """Run the block on ``x`` and return a tensor of the same shape."""
        res = x
        x = self.norm1(x)
        x = tf.nn.gelu(x)
        x = self.conv1(x)
        x = self.norm2(x)
        x = tf.nn.gelu(x)

        shift = self.shift_size
        if shift:
            # Cyclic shift along height and width (axes 1 and 2).
            x = tf.roll(x, shift=[-shift, -shift], axis=[1, 2])
        x = self.shift(x)
        if shift:
            # Undo the shift so features line up with the residual again.
            x = tf.roll(x, shift=[shift, shift], axis=[1, 2])

        x = self.norm3(x)
        x = tf.nn.gelu(x)
        x = self.conv2(x)
        return x + res

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "dim": self.dim,
                "num_heads": self.num_heads,
                "window_size": self.window_size,
                "shift_size": self.shift_size,
            }
        )
        return config

class SwinTransformer(tf.keras.Model):
    """Image classifier: patch embedding, a stack of ``SwinBlock``s, and a head.

    The forward pass embeds non-overlapping patches with a strided
    convolution, adds a learned positional embedding, runs ``depth``
    ``SwinBlock`` layers on the 2-D patch grid, applies layer normalization,
    global-average-pools over the grid and projects to ``num_classes`` logits.

    Args:
        img_size: Height/width of the (square) input images. Must be
            divisible by ``patch_size``.
        patch_size: Side length of each square patch.
        in_channels: Number of input image channels. Accepted for interface
            compatibility; it is not otherwise used because the patch
            embedding convolution infers the channel count from its input.
        num_classes: Number of output logits.
        embed_dim: Channel dimension of the patch embeddings.
        depth: Number of ``SwinBlock`` layers.
        num_heads: Passed to every ``SwinBlock``.
        window_size: Passed to every ``SwinBlock``.
        shift_size: Passed to every ``SwinBlock``.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).

    Raises:
        ValueError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size, patch_size, in_channels, num_classes, embed_dim, depth, num_heads, window_size, shift_size, **kwargs):
        super().__init__(**kwargs)
        # Explicit raise instead of assert: asserts are stripped under -O.
        if img_size % patch_size != 0:
            raise ValueError("image size must be divisible by the patch size")

        self.patch_size = patch_size
        self.embed_dim = embed_dim
        self.grid_size = img_size // patch_size
        num_patches = self.grid_size ** 2

        self.patch_embed = layers.Conv2D(embed_dim, kernel_size=patch_size, strides=patch_size)
        # Registered through add_weight so the variable is explicitly part of
        # the model's weights. The flat (1, num_patches, embed_dim) layout of
        # the stored weight is kept; call() views it as a 2-D grid.
        self.pos_embed = self.add_weight(
            name="pos_embed",
            shape=(1, num_patches, embed_dim),
            initializer="zeros",
            trainable=True,
        )
        self.blocks = [
            SwinBlock(embed_dim, num_heads, window_size, shift_size)
            for _ in range(depth)
        ]
        self.norm = layers.LayerNormalization(epsilon=1e-6)
        self.avg_pool = layers.GlobalAveragePooling2D()
        self.fc = layers.Dense(num_classes)

    def call(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for images ``x``.

        ``x`` is expected to be a channels-last image batch whose height and
        width equal the ``img_size`` given at construction, so that the patch
        grid matches the positional embedding.
        """
        x = self.patch_embed(x)
        # Keep the 4-D (batch, grid, grid, embed_dim) layout: the blocks and
        # the 2-D global pooling below operate on a spatial grid, so the
        # positional embedding is reshaped to the grid instead of flattening x.
        pos = tf.reshape(
            self.pos_embed, (1, self.grid_size, self.grid_size, self.embed_dim)
        )
        x = x + pos
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        x = self.avg_pool(x)
        x = self.fc(x)
        return x
