# Swin Transformer for Time Series Classification
This notebook implements a Swin Transformer adapted for time series input.

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import os
from sklearn.preprocessing import LabelEncoder

# If utils are defined externally, import them
# from utils.metrics import calculate_metrics, print_metrics_summary
# from utils.visualization import save_visualizations

In [None]:
class WindowAttention(layers.Layer):
    """Multi-head self-attention restricted to non-overlapping square windows.

    The input feature map ``(batch, H, W, dim)`` is split into
    ``window_size x window_size`` windows, scaled dot-product attention is
    computed independently inside every window, and the windows are stitched
    back into a ``(batch, H, W, dim)`` map before a final linear projection.

    When ``H`` or ``W`` is not a multiple of ``window_size`` the map is
    zero-padded on the bottom/right, the padded positions are masked out as
    attention keys, and the padding is cropped again afterwards. For inputs
    that are already a multiple of the window size this is a no-op.

    Args:
        dim: Number of input (and output) channels.
        num_heads: Number of attention heads; must divide ``dim``.
        window_size: Side length of each attention window.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        if dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by num_heads ({num_heads})"
            )
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.scale = (dim // num_heads) ** -0.5
        # Sub-layers are created here so Keras tracks their weights; they are
        # built lazily on first use.
        self.qkv = layers.Dense(dim * 3)
        self.proj = layers.Dense(dim)

    @staticmethod
    def _axis_size(x, axis):
        """Return the static size of ``axis`` if known, else a dynamic tensor."""
        static = x.shape[axis]
        return static if static is not None else tf.shape(x)[axis]

    def _to_windows(self, t, batch, height, width, channels):
        """Split ``(batch, height, width, channels)`` into flattened windows.

        Returns a tensor of shape ``(batch * num_windows, ws * ws, channels)``.
        """
        ws = self.window_size
        t = tf.reshape(t, [batch, height // ws, ws, width // ws, ws, channels])
        t = tf.transpose(t, [0, 1, 3, 2, 4, 5])
        return tf.reshape(t, [-1, ws * ws, channels])

    def call(self, x):
        """Apply windowed self-attention to ``x`` of shape ``(B, H, W, dim)``."""
        ws = self.window_size
        tokens = ws * ws
        head_dim = self.dim // self.num_heads

        batch = tf.shape(x)[0]
        height = self._axis_size(x, 1)
        width = self._axis_size(x, 2)

        # Pad bottom/right so both spatial axes are multiples of the window.
        pad_h = (ws - height % ws) % ws
        pad_w = (ws - width % ws) % ws
        paddings = [[0, 0], [0, pad_h], [0, pad_w], [0, 0]]
        padded_h = height + pad_h
        padded_w = width + pad_w

        # 1 for real positions, 0 for padding. Because pad < ws on each axis,
        # every window keeps at least one real token, so softmax is defined.
        valid = tf.ones([batch, height, width, 1], dtype=x.dtype)
        valid = self._to_windows(
            tf.pad(valid, paddings), batch, padded_h, padded_w, 1
        )
        key_bias = (1.0 - valid) * tf.as_dtype(x.dtype).min
        key_bias = tf.reshape(key_bias, [-1, 1, 1, tokens])

        windows = self._to_windows(
            tf.pad(x, paddings), batch, padded_h, padded_w, self.dim
        )

        def split_heads(t):
            # (B*nW, tokens, dim) -> (B*nW, heads, tokens, head_dim)
            t = tf.reshape(t, [-1, tokens, self.num_heads, head_dim])
            return tf.transpose(t, [0, 2, 1, 3])

        q, k, v = (split_heads(t) for t in tf.split(self.qkv(windows), 3, axis=-1))

        attn = tf.matmul(q, k, transpose_b=True) * self.scale
        attn = tf.nn.softmax(attn + key_bias, axis=-1)

        out = tf.matmul(attn, v)
        out = tf.transpose(out, [0, 2, 1, 3])  # (B*nW, tokens, heads, head_dim)

        # Reverse the window partition, then drop the padding.
        out = tf.reshape(
            out, [batch, padded_h // ws, padded_w // ws, ws, ws, self.dim]
        )
        out = tf.transpose(out, [0, 1, 3, 2, 4, 5])
        out = tf.reshape(out, [batch, padded_h, padded_w, self.dim])
        out = out[:, :height, :width, :]

        return self.proj(out)

In [None]:
class SwinTransformerBlock(layers.Layer):
    """Pre-norm Swin block: (shifted) window attention followed by an MLP.

    Computes ``x = x + Attn(LN(x))`` and then ``x = x + MLP(LN(x))``. When
    ``shift_size > 0`` the normalized features are cyclically rolled by
    ``-shift_size`` along both spatial axes before attention and rolled back
    by ``+shift_size`` afterwards, so the attention output lines up with the
    residual branch.

    NOTE: no attention mask is applied for tokens that wrap around the border
    during the cyclic shift; wrapped tokens can attend to each other.

    Args:
        dim: Channel dimension of the input and output.
        num_heads: Number of attention heads.
        window_size: Side length of the attention windows.
        shift_size: Cyclic shift applied before attention (0 disables it).
    """

    def __init__(self, dim, num_heads, window_size, shift_size=0):
        super().__init__()
        self.dim = dim
        self.norm1 = layers.LayerNormalization(epsilon=1e-5)
        self.attn = WindowAttention(dim, num_heads, window_size)
        self.norm2 = layers.LayerNormalization(epsilon=1e-5)
        self.mlp = tf.keras.Sequential([
            layers.Dense(dim * 4, activation='gelu'),
            layers.Dense(dim)
        ])
        self.shift_size = shift_size
        self.window_size = window_size

    def call(self, x):
        """Apply the block to ``x`` of shape ``(batch, H, W, dim)``."""
        shift = self.shift_size
        attended = self.norm1(x)

        if shift > 0:
            attended = tf.roll(attended, shift=[-shift, -shift], axis=[1, 2])

        attended = self.attn(attended)

        if shift > 0:
            # Undo the cyclic shift; without this the attention output would
            # be added to the residual at the wrong spatial positions.
            attended = tf.roll(attended, shift=[shift, shift], axis=[1, 2])

        x = x + attended
        x = x + self.mlp(self.norm2(x))
        return x

In [None]:
class PatchMerging(layers.Layer):
    """Downsample a feature map 2x by merging each 2x2 neighbourhood.

    Every 2x2 group of positions is concatenated along the channel axis
    (``C -> 4C``), layer-normalized, and linearly projected to ``2C``.
    The result has shape ``(batch, ceil(H/2), ceil(W/2), 2C)``; odd spatial
    sizes are zero-padded on the bottom/right before merging.

    Raises:
        ValueError: On build, if the channel dimension of the input is unknown.
    """

    def __init__(self):
        super().__init__()
        self.norm = layers.LayerNormalization(epsilon=1e-5)
        self.reduction = None  # created in build(), once the channel count is known

    def build(self, input_shape):
        channels = input_shape[-1]
        if channels is None:
            raise ValueError("PatchMerging requires a known channel dimension")
        # Create the projection exactly once so its weights are owned (and
        # trained/saved) by this layer instead of being re-created per call.
        self.reduction = layers.Dense(2 * int(channels))

    @staticmethod
    def _axis_size(x, axis):
        """Return the static size of ``axis`` if known, else a dynamic tensor."""
        static = x.shape[axis]
        return static if static is not None else tf.shape(x)[axis]

    def call(self, x):
        """Merge 2x2 patches of ``x`` with shape ``(batch, H, W, C)``."""
        batch = tf.shape(x)[0]
        height = self._axis_size(x, 1)
        width = self._axis_size(x, 2)
        channels = x.shape[-1]

        # Make both spatial axes even so they split cleanly into 2x2 groups.
        pad_h = height % 2
        pad_w = width % 2
        x = tf.pad(x, [[0, 0], [0, pad_h], [0, pad_w], [0, 0]])
        half_h = (height + pad_h) // 2
        half_w = (width + pad_w) // 2

        x = tf.reshape(x, [batch, half_h, 2, half_w, 2, channels])
        x = tf.transpose(x, [0, 1, 3, 2, 4, 5])
        x = tf.reshape(x, [batch, half_h, half_w, 4 * channels])
        x = self.norm(x)
        return self.reduction(x)

In [None]:
class SwinTransformer:
    """Builds a small Swin-style Keras classifier for time-series windows.

    The ``(timesteps, features)`` input is treated as a single-channel image,
    embedded with a strided convolution along the time axis, passed through
    ``num_stages`` stages of (regular + shifted) window-attention blocks each
    followed by patch merging, then globally pooled and classified.

    Constructing an instance creates ``model_dir`` on disk if it is missing.

    Args:
        input_shape: ``(timesteps, features)`` of one sample.
        num_classes: Number of output classes (softmax units).
        model_dir: Directory intended for saved models; created if missing.
        embed_dim: Channel width of the first stage; doubles after each
            patch merging.
        num_heads: Attention heads per block; must divide ``embed_dim``.
        window_size: Side length of the attention windows.
        num_stages: Number of (block, shifted block, merge) stages.

    Attributes:
        model: The constructed ``tf.keras.Model``.
        encoder: An unfitted ``LabelEncoder`` for mapping labels to indices.
    """

    def __init__(self, input_shape=(23, 4), num_classes=7, model_dir="saved_models/swin",
                 embed_dim=64, num_heads=4, window_size=7, num_stages=2):
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model_dir = model_dir
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.num_stages = num_stages
        os.makedirs(self.model_dir, exist_ok=True)
        self.model = self._build_model()
        self.encoder = LabelEncoder()

    def _build_model(self):
        """Assemble and return the (uncompiled) Keras functional model."""
        timesteps, features = self.input_shape[0], self.input_shape[1]

        inputs = layers.Input(shape=self.input_shape)
        x = layers.Reshape((timesteps, features, 1))(inputs)
        # Patch embedding: 4 consecutive timesteps per token, features kept.
        # With padding='same' the map is ceil(timesteps / 4) x features.
        # NOTE: that map need not be a multiple of window_size (6x4 for the
        # default input), so the attention blocks must cope with such sizes.
        x = layers.Conv2D(self.embed_dim, kernel_size=(4, 1), strides=(4, 1), padding='same')(x)
        x = layers.LayerNormalization(epsilon=1e-5)(x)

        dim = self.embed_dim
        shift = self.window_size // 2
        for _ in range(self.num_stages):
            x = SwinTransformerBlock(dim, num_heads=self.num_heads,
                                     window_size=self.window_size)(x)
            x = SwinTransformerBlock(dim, num_heads=self.num_heads,
                                     window_size=self.window_size, shift_size=shift)(x)
            x = PatchMerging()(x)
            # PatchMerging doubles the channel count; the next stage's blocks
            # must be built for that width or their residual adds mismatch.
            dim *= 2

        x = layers.GlobalAveragePooling2D()(x)
        x = layers.Dense(128, activation='gelu')(x)
        outputs = layers.Dense(self.num_classes, activation='softmax')(x)
        return Model(inputs=inputs, outputs=outputs)

## Training, Evaluation and Save Methods
Add your training, evaluation, and model-saving pipeline here, adapted to your data.