In [18]:
import tensorflow as tf
from tensorflow.keras import layers
import math

class WindowAttention(layers.Layer):
    """Multi-head self-attention computed independently inside each window.

    A single dense layer produces queries, keys and values, scaled dot-product
    attention is evaluated per head, and the concatenated heads are projected
    back to ``dim`` channels.

    Args:
        dim: Channel width of the input and output tokens.
        window_size: Side length of the attention window. It is stored for
            reference only; the token count is read from the input at call time.
        num_heads: Number of attention heads; must divide ``dim`` evenly.
        qkv_bias: Whether the fused query/key/value projection uses a bias.
        dropout_rate: Dropout rate for the attention weights and for the
            output projection.

    Raises:
        ValueError: If ``dim`` or ``num_heads`` is not positive, or if ``dim``
            is not a multiple of ``num_heads``.
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, dropout_rate=0.0):
        super().__init__()
        # A non-divisible dim would make the per-head reshape in `call` either
        # fail or silently mix tokens across the batch dimension.
        if dim <= 0 or num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be a positive multiple of num_heads ({num_heads})"
            )
        self.dim = dim
        self.window_size = window_size
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.qkv = layers.Dense(dim * 3, use_bias=qkv_bias)
        self.attn_drop = layers.Dropout(dropout_rate)
        self.proj = layers.Dense(dim)
        self.proj_drop = layers.Dropout(dropout_rate)

    def call(self, x, attn_mask=None, training=None):
        """Apply windowed multi-head self-attention.

        Args:
            x: Tensor of shape ``(batch * num_windows, tokens, dim)``.
            attn_mask: Optional additive mask of shape
                ``(num_windows, tokens, tokens)``. The leading axis of ``x``
                must then be laid out batch-major with windows varying
                fastest. Masked pairs should hold a large negative value.
            training: Forwarded to the dropout layers.

        Returns:
            Tensor with the same shape as ``x``.
        """
        num_tokens = tf.shape(x)[1]

        # (B_, N, 3*dim) -> (3, B_, heads, N, head_dim)
        qkv = tf.reshape(self.qkv(x), shape=[-1, num_tokens, 3, self.num_heads, self.head_dim])
        qkv = tf.transpose(qkv, perm=[2, 0, 3, 1, 4])
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Scaled dot-product scores: (B_, heads, N, N)
        attn = tf.matmul(q * self.scale, k, transpose_b=True)

        if attn_mask is not None:
            # Broadcast the per-window mask over batch and heads.
            num_windows = tf.shape(attn_mask)[0]
            attn = tf.reshape(attn, shape=[-1, num_windows, self.num_heads, num_tokens, num_tokens])
            attn = attn + tf.cast(attn_mask, attn.dtype)[tf.newaxis, :, tf.newaxis, :, :]
            attn = tf.reshape(attn, shape=[-1, self.num_heads, num_tokens, num_tokens])

        attn = tf.nn.softmax(attn, axis=-1)
        attn = self.attn_drop(attn, training=training)

        # Merge heads back: (B_, heads, N, head_dim) -> (B_, N, dim)
        x = tf.transpose(tf.matmul(attn, v), perm=[0, 2, 1, 3])
        x = tf.reshape(x, shape=[-1, num_tokens, self.dim])
        x = self.proj(x)
        x = self.proj_drop(x, training=training)
        return x

class SwinTransformerBlock(layers.Layer):
    """Transformer block with (optionally shifted) window attention and an MLP.

    The token sequence is viewed as a square grid, optionally rolled by
    ``shift_size``, split into ``window_size`` x ``window_size`` windows,
    attended within each window, merged back and un-rolled. A two-layer GELU
    MLP follows. Both sub-blocks are pre-normalized and residual.

    NOTE(review): when ``shift_size > 0`` the rolled grid is attended without
    an attention mask, so tokens that wrap around the border share a window
    with tokens from the opposite edge. Consider masking those pairs.

    Args:
        dim: Channel width of the tokens.
        num_heads: Number of attention heads.
        window_size: Side length of each attention window. Clamped to the grid
            side at build time when the grid is not larger than the window.
        shift_size: Cyclic shift applied before window partitioning.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        qkv_bias: Whether the attention's qkv projection uses a bias.
        dropout_rate: Dropout rate forwarded to the attention layer.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0, mlp_ratio=4., qkv_bias=True, dropout_rate=0.0):
        super().__init__()
        if window_size <= 0:
            raise ValueError(f"window_size must be positive, got {window_size}")
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio

        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(dim, window_size, num_heads, qkv_bias, dropout_rate)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = tf.keras.Sequential([
            layers.Dense(int(dim * mlp_ratio)),
            layers.Activation('gelu'),
            layers.Dense(dim),
        ])

    def build(self, input_shape):
        """Derive the square token grid from the static input shape.

        Raises:
            ValueError: If the sequence length is unknown or not a perfect
                square, if the channel count differs from ``dim``, or if the
                grid side is not a multiple of the window size.
        """
        num_tokens, channels = input_shape[1], input_shape[-1]
        if num_tokens is None:
            raise ValueError("SwinTransformerBlock needs a statically known sequence length")
        if channels is not None and channels != self.dim:
            raise ValueError(f"expected {self.dim} input channels, got {channels}")

        # Round instead of truncating so float error cannot shrink the side,
        # then verify the grid really is square.
        side = int(round(math.sqrt(num_tokens)))
        if side < 1 or side * side != num_tokens:
            raise ValueError(f"sequence length {num_tokens} is not a perfect square")

        self.input_resolution = num_tokens
        self.H = self.W = side

        if side <= self.window_size:
            # One window already covers the whole grid: shrink the window to
            # fit and drop the shift, which has no effect on a single window.
            self.window_size = side
            self.shift_size = 0
        elif side % self.window_size != 0:
            raise ValueError(
                f"grid side {side} is not divisible by window_size {self.window_size}"
            )

    def call(self, x):
        """Run the block on tokens of shape ``(batch, H * W, dim)``."""
        C = self.dim
        tokens_per_window = self.window_size * self.window_size

        shortcut = x
        x = self.norm1(x)
        x = tf.reshape(x, shape=[-1, self.H, self.W, C])

        # Cyclic shift
        if self.shift_size > 0:
            shifted_x = tf.roll(x, shift=[-self.shift_size, -self.shift_size], axis=[1, 2])
        else:
            shifted_x = x

        # Partition windows
        x_windows = self.window_partition(shifted_x, self.window_size)
        x_windows = tf.reshape(x_windows, shape=[-1, tokens_per_window, C])

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows)

        # Merge windows
        attn_windows = tf.reshape(attn_windows, shape=[-1, self.window_size, self.window_size, C])
        shifted_x = self.window_reverse(attn_windows, self.window_size, self.H, self.W)

        # Reverse cyclic shift
        if self.shift_size > 0:
            x = tf.roll(shifted_x, shift=[self.shift_size, self.shift_size], axis=[1, 2])
        else:
            x = shifted_x

        x = tf.reshape(x, shape=[-1, self.H * self.W, C])

        # Residual around attention, then residual around the FFN
        x = shortcut + x
        x = x + self.mlp(self.norm2(x))

        return x

    def window_partition(self, x, window_size):
        """Split ``(B, H, W, C)`` into windows of shape ``(B * nW, ws, ws, C)``.

        Windows are emitted batch-major, then row-major over the window grid.
        """
        B, H, W, C = tf.shape(x)[0], tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]
        x = tf.reshape(x, shape=[B, H // window_size, window_size, W // window_size, window_size, C])
        # Bring the two window-grid axes together ahead of the in-window axes.
        windows = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
        windows = tf.reshape(windows, shape=[-1, window_size, window_size, C])
        return windows

    def window_reverse(self, windows, window_size, H, W):
        """Inverse of ``window_partition``: rebuild a ``(B, H, W, C)`` grid."""
        windows_per_image = (H // window_size) * (W // window_size)
        B = tf.shape(windows)[0] // windows_per_image
        x = tf.reshape(windows, shape=[B, H // window_size, W // window_size, window_size, window_size, -1])
        x = tf.transpose(x, perm=[0, 1, 3, 2, 4, 5])
        x = tf.reshape(x, shape=[B, H, W, -1])
        return x

class PatchMerging(layers.Layer):
    """Halve the token grid by merging every 2x2 neighbourhood.

    The four tokens of each 2x2 neighbourhood are concatenated along the
    channel axis, layer-normalized, and linearly projected to ``2 * dim``
    channels.

    Accepts either a token sequence ``(batch, H * W, C)`` laid out on a square
    grid, or a feature map ``(batch, H, W, C)``. The output is always a
    sequence of shape ``(batch, (H / 2) * (W / 2), 2 * dim)``.

    Args:
        dim: Base channel width; the projection outputs ``2 * dim`` channels.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim
        self.reduction = layers.Dense(2 * dim, use_bias=False)
        self.norm = layers.LayerNormalization(epsilon=1e-5)

    def build(self, input_shape):
        """Resolve the grid height and width from the static input shape.

        Raises:
            ValueError: If the rank is not 3 or 4, the spatial size is unknown,
                a rank-3 sequence length is not a perfect square, or the grid
                height or width is odd.
        """
        rank = len(input_shape)
        if rank == 4:
            height, width = input_shape[1], input_shape[2]
            if height is None or width is None:
                raise ValueError("PatchMerging needs statically known height and width")
        elif rank == 3:
            num_tokens = input_shape[1]
            if num_tokens is None:
                raise ValueError("PatchMerging needs a statically known sequence length")
            # Same square-grid convention as SwinTransformerBlock.
            side = int(round(math.sqrt(num_tokens)))
            if side * side != num_tokens:
                raise ValueError(f"sequence length {num_tokens} is not a perfect square")
            height = width = side
        else:
            raise ValueError(f"expected a rank-3 or rank-4 input, got rank {rank}")

        if height % 2 != 0 or width % 2 != 0:
            raise ValueError(f"grid size ({height}, {width}) must be even in both dimensions")
        self.H, self.W = height, width

    def call(self, x):
        """Merge 2x2 neighbourhoods and project to ``2 * dim`` channels."""
        H, W = self.H, self.W
        # Prefer the static channel count so downstream layers keep a fully
        # defined last dimension; fall back to the dynamic shape otherwise.
        C = x.shape[-1]
        if C is None:
            C = tf.shape(x)[-1]
        x = tf.reshape(x, shape=[-1, H, W, C])

        x0 = x[:, 0::2, 0::2, :]  # even rows, even cols: B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # odd rows,  even cols: B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # even rows, odd cols:  B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # odd rows,  odd cols:  B H/2 W/2 C
        x = tf.concat([x0, x1, x2, x3], axis=-1)  # B H/2 W/2 4*C
        x = tf.reshape(x, shape=[-1, (H // 2) * (W // 2), 4 * C])  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

def build_swin_transformer(input_shape, num_classes, num_layers, num_heads, window_size, mlp_dim, dropout_rate=0.0):
    """Assemble a single-stage Swin-style image classifier.

    Args:
        input_shape: Shape of one input image, without the batch axis.
        num_classes: Size of the softmax output.
        num_layers: Number of stacked ``SwinTransformerBlock`` layers.
        num_heads: Attention heads per block.
        window_size: Attention window side length used by every block.
        mlp_dim: Channel width of the patch embedding and of every block.
        dropout_rate: Dropout rate forwarded to every block.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping images to class probabilities.
    """
    image = layers.Input(shape=input_shape)

    # Patch embedding: a stride-4, kernel-4 convolution, then flatten the
    # spatial grid into a token sequence.
    feature_map = layers.Conv2D(mlp_dim, kernel_size=4, strides=4, padding='same')(image)
    tokens = layers.Reshape((-1, feature_map.shape[-1]))(feature_map)

    # Even-indexed blocks use regular windows, odd-indexed blocks shift by
    # half a window.
    half_window = window_size // 2
    for depth in range(num_layers):
        block = SwinTransformerBlock(
            dim=mlp_dim,
            num_heads=num_heads,
            window_size=window_size,
            shift_size=half_window if depth % 2 else 0,
            mlp_ratio=4,
            qkv_bias=True,
            dropout_rate=dropout_rate,
        )
        tokens = block(tokens)

    # Classification head: normalize, average over tokens, project to classes.
    tokens = layers.LayerNormalization(epsilon=1e-5)(tokens)
    pooled = layers.GlobalAveragePooling1D()(tokens)
    probabilities = layers.Dense(num_classes, activation="softmax")(pooled)

    return tf.keras.Model(image, probabilities)

In [19]:
def prepare_data():
    """Load CIFAR-10 with pixels scaled to [0, 1] and one-hot labels.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where the images are
        float32 arrays and the labels are one-hot encoded over 10 classes.
    """
    num_classes = 10
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # Cast before scaling: dividing the integer pixel arrays by a Python float
    # would promote them to float64 and double the memory footprint.
    x_train = x_train.astype("float32") / 255.0
    x_test = x_test.astype("float32") / 255.0

    y_train = tf.keras.utils.to_categorical(y_train, num_classes)
    y_test = tf.keras.utils.to_categorical(y_test, num_classes)
    return (x_train, y_train), (x_test, y_test)

def train_and_evaluate():
    """Train the Swin-style classifier on CIFAR-10 and print its test accuracy.

    The test split is also passed as validation data, so the per-epoch
    validation metrics and the final reported accuracy come from the same
    images.
    """
    train_split, test_split = prepare_data()
    x_train, y_train = train_split
    x_test, y_test = test_split

    model = build_swin_transformer(
        input_shape=(32, 32, 3),
        num_classes=10,
        num_layers=4,
        num_heads=2,
        window_size=4,
        mlp_dim=128,
    )
    model.compile(
        optimizer='adam',
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    model.fit(
        x_train,
        y_train,
        epochs=10,
        validation_data=(x_test, y_test),
    )

    # evaluate() returns [loss, accuracy] for the metrics compiled above.
    _, accuracy = model.evaluate(x_test, y_test)
    print(f"Test accuracy: {accuracy:.2f}")

# Run the full pipeline: load CIFAR-10, train for 10 epochs, print test accuracy.
train_and_evaluate()


Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10
Test accuracy: 0.46
