# Vision Transformer (ViT) Implementation in TensorFlow
This notebook provides a modular and clear implementation of a Vision Transformer (ViT) adapted for time-series data.

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np
import os
from utils.metrics import calculate_metrics, print_metrics_summary
from utils.visualization import save_visualizations
from sklearn.preprocessing import LabelEncoder

## Patches Layer
This layer splits the input into non-overlapping patches of `patch_size` consecutive time steps and flattens each patch into a single vector.

In [None]:
class Patches(layers.Layer):
    """Split an input tensor into non-overlapping, flattened patches.

    Every patch groups ``patch_size`` consecutive positions together with all
    values along the last axis, so the output has shape
    ``(batch, num_patches, patch_size * last_dim)``.

    Args:
        patch_size: Number of consecutive positions grouped into one patch.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
            They are forwarded to ``layers.Layer`` so the layer can be
            re-created from its config when a saved model is loaded.
    """

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, x):
        """Return ``x`` regrouped as ``(batch, num_patches, patch_size * last_dim)``."""
        # A rank-3 input whose static length is not a multiple of patch_size
        # cannot be reshaped into whole patches (the reshape would fail at run
        # time). Zero-pad the end of axis 1 so the trailing steps form a final,
        # partially padded patch instead.
        if len(x.shape) == 3 and x.shape[1] is not None:
            pad = (-x.shape[1]) % self.patch_size
            if pad:
                x = tf.pad(x, [[0, 0], [0, pad], [0, 0]])
        # The batch size is only known at run time, hence tf.shape.
        batch_size = tf.shape(x)[0]
        return tf.reshape(x, [batch_size, -1, self.patch_size * x.shape[-1]])

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update({"patch_size": self.patch_size})
        return config

## Patch Encoder
This layer encodes each patch with a linear projection and adds a learned positional embedding.

In [None]:
class PatchEncoder(layers.Layer):
    """Project patches linearly and add a learned position embedding.

    Args:
        num_patches: Number of patches per sample; one embedding vector is
            learned for each patch index in ``[0, num_patches)``.
        projection_dim: Size of the vector each patch is projected to.
        **kwargs: Standard Keras layer arguments (``name``, ``dtype``, ...).
            They are forwarded to ``layers.Layer`` so the layer can be
            re-created from its config when a saved model is loaded.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        # Kept so that get_config() can report it.
        self.projection_dim = projection_dim
        self.projection = layers.Dense(projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        """Return the projected patches with position embeddings added."""
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # The embedding lookup has no batch axis; the addition broadcasts it
        # over the batch.
        return self.projection(patch) + self.position_embedding(positions)

    def get_config(self):
        """Return the constructor arguments needed to rebuild this layer."""
        config = super().get_config()
        config.update(
            {
                "num_patches": self.num_patches,
                "projection_dim": self.projection_dim,
            }
        )
        return config

## Vision Transformer (ViT) Class
Main model class defining the Vision Transformer architecture, training, and evaluation methods.

In [None]:
class ViT:
    """Wrapper that builds a Vision-Transformer-style Keras classifier.

    Args:
        input_shape: Shape of one sample without the batch axis.
        num_classes: Number of output units (one logit per class).
        model_dir: Directory for saved models; created if it does not exist.
        patch_size: Number of consecutive positions grouped into one patch.
        projection_dim: Width of the patch embedding and of every transformer
            block.
        num_layers: Number of transformer blocks.
        num_heads: Number of attention heads per block.

    Raises:
        ValueError: If ``patch_size`` is smaller than 1, if ``input_shape``
            has fewer than two axes, or if an input with more than two axes
            cannot be split into whole patches.
    """

    def __init__(self, input_shape=(23, 4), num_classes=7, model_dir="saved_models/vit",
                 patch_size=4, projection_dim=64, num_layers=6, num_heads=4):
        if patch_size < 1:
            raise ValueError(f"patch_size must be >= 1, got {patch_size}")
        if len(input_shape) < 2:
            raise ValueError(
                f"input_shape needs at least two axes, got {tuple(input_shape)}"
            )
        self.input_shape = input_shape
        self.num_classes = num_classes
        self.model_dir = model_dir
        self.patch_size = patch_size
        self.projection_dim = projection_dim
        self.num_layers = num_layers
        self.num_heads = num_heads
        os.makedirs(self.model_dir, exist_ok=True)
        self.model = self._build_model()
        self.encoder = LabelEncoder()

    def _mlp(self, x, hidden_units, dropout_rate):
        """Apply a stack of GELU Dense layers, each followed by Dropout."""
        for units in hidden_units:
            x = layers.Dense(units, activation=tf.nn.gelu)(x)
            x = layers.Dropout(dropout_rate)(x)
        return x

    def _build_model(self):
        """Assemble the functional Keras model and return it (uncompiled)."""
        inputs = layers.Input(shape=self.input_shape)
        x = inputs

        # Patches groups `patch_size` positions (every axis except the last)
        # into one patch, so the patch count is positions / patch_size.
        positions = int(np.prod(self.input_shape[:-1]))
        remainder = positions % self.patch_size
        if remainder:
            if len(self.input_shape) != 2:
                raise ValueError(
                    f"input_shape {tuple(self.input_shape)} cannot be split into "
                    f"whole patches of size {self.patch_size}"
                )
            # Zero-pad the end of the sequence so the trailing steps form one
            # last patch instead of making the reshape fail.
            pad = self.patch_size - remainder
            x = layers.ZeroPadding1D(padding=(0, pad))(x)
            positions += pad

        patches = Patches(patch_size=self.patch_size)(x)
        num_patches = positions // self.patch_size
        encoded_patches = PatchEncoder(
            num_patches, projection_dim=self.projection_dim
        )(patches)

        for _ in range(self.num_layers):
            # Attention sub-block: normalize, self-attend, add the skip path.
            x1 = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
            attention_output = layers.MultiHeadAttention(
                num_heads=self.num_heads, key_dim=self.projection_dim, dropout=0.1
            )(x1, x1)
            x2 = layers.Add()([attention_output, encoded_patches])
            # Feed-forward sub-block; its last width must equal projection_dim
            # so the residual addition below is shape-compatible.
            x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
            x3 = self._mlp(
                x3,
                hidden_units=[2 * self.projection_dim, self.projection_dim],
                dropout_rate=0.1,
            )
            encoded_patches = layers.Add()([x3, x2])

        representation = layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        representation = layers.Flatten()(representation)
        features = self._mlp(representation, hidden_units=[128, 64], dropout_rate=0.5)
        # No activation: the model outputs raw logits.
        logits = layers.Dense(self.num_classes)(features)
        return Model(inputs=inputs, outputs=logits)

## Training, Evaluation and Model I/O

In [None]:
    def train(self, x_train, y_train, x_val, y_val, epochs=100, batch_size=64):
        """Fit the model and return the Keras training history.

        The label encoder is fitted on ``y_train`` and reused for ``y_val``.
        The best model by validation accuracy is written to
        ``<model_dir>/best_model.h5``, and training stops early after 10
        epochs without improvement, restoring the best weights.

        Args:
            x_train: Training inputs.
            y_train: Training labels (any values the label encoder accepts).
            x_val: Validation inputs.
            y_val: Validation labels; must only contain labels seen in
                ``y_train``.
            epochs: Maximum number of epochs.
            batch_size: Mini-batch size.

        Returns:
            The history object returned by ``Model.fit``.
        """
        y_train_enc = self.encoder.fit_transform(y_train)
        y_val_enc = self.encoder.transform(y_val)

        callbacks = [
            tf.keras.callbacks.ModelCheckpoint(
                os.path.join(self.model_dir, "best_model.h5"), monitor="val_accuracy", save_best_only=True
            ),
            tf.keras.callbacks.EarlyStopping(
                monitor="val_accuracy", patience=10, restore_best_weights=True
            )
        ]

        self.model.compile(
            optimizer=tf.keras.optimizers.AdamW(learning_rate=1e-4, weight_decay=1e-3),
            # The network ends in a Dense layer without activation, i.e. it
            # emits logits. The string loss would treat them as probabilities,
            # so the softmax has to be applied inside the loss.
            loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            # Named "accuracy" so the "val_accuracy" monitors above still match.
            metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
        )

        history = self.model.fit(
            x=x_train,
            y=y_train_enc,
            validation_data=(x_val, y_val_enc),
            batch_size=batch_size,
            epochs=epochs,
            callbacks=callbacks
        )
        return history

    def evaluate(self, x_test, y_test):
        """Score the model on a test set and emit the metric report.

        Labels are mapped to integer ids with the wrapper's label encoder and
        predictions are the arg-max over the model output. Both are handed to
        the project's metric and visualization helpers under the name "ViT".

        Args:
            x_test: Test inputs.
            y_test: Test labels in their original (un-encoded) form.

        Returns:
            Whatever ``calculate_metrics`` returns for this test set.
        """
        encoded_labels = self.encoder.transform(y_test)
        scores = self.model.predict(x_test)
        predicted_classes = np.argmax(scores, axis=1)

        report = calculate_metrics(encoded_labels, predicted_classes, "ViT")
        print_metrics_summary(report)
        save_visualizations(
            self.model, x_test, encoded_labels, predicted_classes, model_name="ViT"
        )
        return report

    def save(self, model_name="vit_model"):
        """Write the Keras model to ``<model_dir>/<model_name>.h5``.

        Args:
            model_name: File name without the ``.h5`` extension.
        """
        filename = f"{model_name}.h5"
        destination = os.path.join(self.model_dir, filename)
        self.model.save(destination)
        print(f"Model saved to {destination}")

    @classmethod
    def load(cls, model_path):
        """Create a wrapper around a model previously written to disk.

        Args:
            model_path: Path of the saved Keras model file.

        Returns:
            A new instance whose ``model`` is the loaded network and whose
            ``input_shape`` / ``num_classes`` are read from that network.

        Note:
            Only the Keras model is restored. ``save`` does not persist the
            label encoder, so the returned instance carries a fresh, unfitted
            encoder; it must be fitted again before ``evaluate`` can encode
            labels.
        """
        # The network contains custom layers; Keras needs to be told which
        # classes the names stored in the file refer to.
        model = tf.keras.models.load_model(
            model_path,
            custom_objects={"Patches": Patches, "PatchEncoder": PatchEncoder},
        )
        # The constructor builds a fresh network, which is replaced right away
        # by the loaded one.
        vit = cls(input_shape=model.input_shape[1:], num_classes=model.output_shape[-1])
        vit.model = model
        return vit