https://keras.io/examples/vision/token_learner/

In [2]:
import keras 
from keras import layers
from keras import ops
from tensorflow import data as tf_data

from datetime import datetime
import matplotlib.pyplot as plt
import numpy as np
import math

2024-12-12 10:19:59.776334: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-12-12 10:19:59.824653: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733995199.863678    4507 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733995199.877452    4507 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-12 10:19:59.985586: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instr

## Hyperparameters

In [21]:
# DATA
BATCH_SIZE = 256
AUTO = tf_data.AUTOTUNE
INPUT_SHAPE = (32,32,3)
NUM_CLASSES = 10

# OPTIMIZER
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4

# TRAINING
EPOCHS = 10

# AUGMENTATION
IMAGE_SIZE = 48 # reize input images to this size
PATCH_SIZE = 6 # Size of the patches to be extrachted from the input images. 
NUM_PATCHES = (IMAGE_SIZE // PATCH_SIZE ) ** 2

# VIT ARCHITECTURE
LAYER_NORM_EPS = 1e-6
PROJECTION_DIM = 128
NUM_HEADS = 4
NUM_LAYERS = 4

MLP_UNITS = [
    PROJECTION_DIM * 2,
    PROJECTION_DIM

]

# TOKENLEARNER
NUM_TOKENS = 4

Load and prepare the CIFAR-10 dataset

In [4]:
# Load the CIFAR-10 dataset.
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()
(x_train, y_train), (x_val, y_val) = (
    (x_train[:40000], y_train[:40000]),
    (x_train[40000:], y_train[40000:]),
)
print(f"Training samples: {len(x_train)}")
print(f"Validation samples: {len(x_val)}")
print(f"Testing samples: {len(x_test)}")

# Convert to tf.data.Dataset objects.
train_ds = tf_data.Dataset.from_tensor_slices((x_train, y_train))
train_ds = train_ds.shuffle(BATCH_SIZE * 100).batch(BATCH_SIZE).prefetch(AUTO)

val_ds = tf_data.Dataset.from_tensor_slices((x_val, y_val))
val_ds = val_ds.batch(BATCH_SIZE).prefetch(AUTO)

test_ds = tf_data.Dataset.from_tensor_slices((x_test, y_test))
test_ds = test_ds.batch(BATCH_SIZE).prefetch(AUTO)

Training samples: 40000
Validation samples: 10000
Testing samples: 10000


I0000 00:00:1733995202.202031    4507 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13572 MB memory:  -> device: 0, name: NVIDIA GeForce RTX 4060 Ti, pci bus id: 0000:01:00.0, compute capability: 8.9


Data augmentation
- Rescaling
- REsizing
- Random cropping (fixed-sized or random sized)
- Random horizontal flipping

In [5]:
data_augmentation = keras.Sequential(
    [
        layers.Rescaling(1/255.0),
        layers.Resizing(INPUT_SHAPE[0] + 20, INPUT_SHAPE[0] + 20),
        layers.RandomCrop(IMAGE_SIZE, IMAGE_SIZE),
        layers.RandomFlip("horizontal")
    ],
    name="data_augmentation"
)

Positional embedding module
- multi-head self attention layers
- fully-connected feed forward networks (MLP) 

In [6]:
class PatchEncoder(layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.position_embedding = layers.Embedding(
            input_dim = num_patches, output_dim=projection_dim
        )
    
    def call(self, patch):
        positions = ops.expand_dims(
            ops.arange(start = 0, stop=self.num_patches, step=1), axis=0
        )
        encoded = patch + self.position_embedding(positions)
        return encoded
    
    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches})
        return config   

MLP block for Transformer

In [7]:
def mlp(x, dropout_rate, hidden_units):
    # Iterate over the hidden units and
    # add Dense => Dropout.
    for units in hidden_units:
        x = layers.Dense(units, activation=ops.gelu)(x)
        x = layers.Dropout(dropout_rate)(x)
    return x

In [8]:
def token_learner(inputs, number_of_tokens=NUM_TOKENS):
    # Layer normalize the inpus.
    x = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(inputs)
    # Applying Conv2D => Reshape => Permute
    # The reshape and permute is done to help with the next steps of
    # multiplication and Global Average Pooling.

    attention_maps = keras.Sequential(
        [
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3,3),
                activation=ops.gelu,
                padding="same",
                use_bias=False
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3,3),
                activation=ops.gelu,
                padding="same",
                use_bias=False               
            ),
            layers.Conv2D(
                filters=number_of_tokens,
                kernel_size=(3, 3),
                activation=ops.gelu,
                padding="same",
                use_bias=False,
            ),
            # Reshape and Permute
            layers.Reshape((-1, number_of_tokens)),  # (B, H*W, num_of_tokens)
            layers.Permute((2, 1)),
        ]
    )(
        x
    )  # (B, num_of_tokens, H*W)

    # Reshape the input to align it with the output of the conv block.
    num_filters = inputs.shape[-1]
    inputs = layers.Reshape((1, -1, num_filters))(inputs)  # inputs == (B, 1, H*W, C)

    # Element-Wise multiplication of the attention maps and the inputs
    attended_inputs = (
        ops.expand_dims(attention_maps, axis=-1) * inputs
    )  # (B, num_tokens, H*W, C)

    # Global average pooling the element wise multiplication result.
    outputs = ops.mean(attended_inputs, axis=2)  # (B, num_tokens, C)
    return outputs

In [9]:
def transformer(encoded_patches):
    # Layer normalization 1.
    x1 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)

    # Multi Head Self Attention layer 1.
    attention_output = layers.MultiHeadAttention(
        num_heads=NUM_HEADS, key_dim=PROJECTION_DIM, dropout=0.1
    )(x1, x1)

    # Skip connection 1.
    x2 = layers.Add()([attention_output, encoded_patches])

    # Layer normalization 2.
    x3 = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(x2)

    # MLP layer 1.
    x4 = mlp(x3, hidden_units=MLP_UNITS, dropout_rate=0.1)

    # Skip connection 2.
    encoded_patches = layers.Add()([x4, x2])
    return encoded_patches

In [11]:
def create_vit_classifier(use_token_learner=True, token_learner_units=NUM_TOKENS):
    inputs = layers.Input(shape=INPUT_SHAPE)  # (B, H, W, C)

    # Augment data.
    augmented = data_augmentation(inputs)

    # Create patches and project the pathces.
    projected_patches = layers.Conv2D(
        filters=PROJECTION_DIM,
        kernel_size=(PATCH_SIZE, PATCH_SIZE),
        strides=(PATCH_SIZE, PATCH_SIZE),
        padding="VALID",
    )(augmented)
    _, h, w, c = projected_patches.shape
    projected_patches = layers.Reshape((h * w, c))(
        projected_patches
    )  # (B, number_patches, projection_dim)

    # Add positional embeddings to the projected patches.
    encoded_patches = PatchEncoder(
        num_patches=NUM_PATCHES, projection_dim=PROJECTION_DIM
    )(
        projected_patches
    )  # (B, number_patches, projection_dim)
    encoded_patches = layers.Dropout(0.1)(encoded_patches)

    # Iterate over the number of layers and stack up blocks of
    # Transformer.
    for i in range(NUM_LAYERS):
        # Add a Transformer block.
        encoded_patches = transformer(encoded_patches)

        # Add TokenLearner layer in the middle of the
        # architecture. The paper suggests that anywhere
        # between 1/2 or 3/4 will work well.
        if use_token_learner and i == NUM_LAYERS // 2:
            _, hh, c = encoded_patches.shape
            h = int(math.sqrt(hh))
            encoded_patches = layers.Reshape((h, h, c))(
                encoded_patches
            )  # (B, h, h, projection_dim)
            encoded_patches = token_learner(
                encoded_patches, token_learner_units
            )  # (B, num_tokens, c)

    # Layer normalization and Global average pooling.
    representation = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)(encoded_patches)
    representation = layers.GlobalAvgPool1D()(representation)

    # Classify outputs.
    outputs = layers.Dense(NUM_CLASSES, activation="softmax")(representation)

    # Create the Keras model.
    model = keras.Model(inputs=inputs, outputs=outputs)
    return model

In [22]:
def run_experiment(model):
    # Initialize the AdamW optimizer.
    optimizer = keras.optimizers.AdamW(
        learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
    )

    # Compile the model with the optimizer, loss function
    # and the metrics.
    model.compile(
        optimizer=optimizer,
        loss="sparse_categorical_crossentropy",
        metrics=[
            keras.metrics.SparseCategoricalAccuracy(name="accuracy"),
            keras.metrics.SparseTopKCategoricalAccuracy(5, name="top-5-accuracy"),
        ],
    )

    # Define callbacks
    checkpoint_filepath = "/tmp/checkpoint.weights.h5"
    checkpoint_callback = keras.callbacks.ModelCheckpoint(
        checkpoint_filepath,
        monitor="val_accuracy",
        save_best_only=True,
        save_weights_only=True,
    )

    # Train the model.
    _ = model.fit(
        train_ds,
        epochs=EPOCHS,
        validation_data=val_ds,
        callbacks=[checkpoint_callback],
    )

    model.load_weights(checkpoint_filepath)
    _, accuracy, top_5_accuracy = model.evaluate(test_ds)
    print(f"Test accuracy: {round(accuracy * 100, 2)}%")
    print(f"Test top 5 accuracy: {round(top_5_accuracy * 100, 2)}%")

In [23]:
vit_token_learner = create_vit_classifier()
run_experiment(vit_token_learner)

Epoch 1/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m19s[0m 92ms/step - accuracy: 0.1208 - loss: 2.3749 - top-5-accuracy: 0.5525 - val_accuracy: 0.2679 - val_loss: 1.9079 - val_top-5-accuracy: 0.8177
Epoch 2/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 90ms/step - accuracy: 0.2948 - loss: 1.8678 - top-5-accuracy: 0.8322 - val_accuracy: 0.3332 - val_loss: 1.7752 - val_top-5-accuracy: 0.8566
Epoch 3/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 93ms/step - accuracy: 0.3831 - loss: 1.6548 - top-5-accuracy: 0.8842 - val_accuracy: 0.4350 - val_loss: 1.5397 - val_top-5-accuracy: 0.9088
Epoch 4/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m15s[0m 92ms/step - accuracy: 0.4473 - loss: 1.5095 - top-5-accuracy: 0.9114 - val_accuracy: 0.4684 - val_loss: 1.4571 - val_top-5-accuracy: 0.9229
Epoch 5/10
[1m157/157[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m14s[0m 92ms/step - accuracy: 0.4768 - loss: 1.4403 - top-

In [None]:
import tensorflow as tf

# Beispiel: Laden eines Bildes
def preprocess_image(image_path):
    img = tf.keras.preprocessing.image.load_img(image_path, target_size=(32, 32))
    img_array = tf.keras.preprocessing.image.img_to_array(img)
    img_array = tf.expand_dims(img_array, axis=0)  # Batch-Dimension hinzufügen
    #img_array = img_array / 255.0  # Normalisierung
    return img_array

# Beispiel: Pfad zu Ihrem Bild
image_path = "Download.jpeg"
image = preprocess_image(image_path)


In [42]:
# Eine Vorhersage für ein Bild
predictions = vit_token_learner.predict(image)

# Wahrscheinlichkeiten in Kategorien umwandeln
predicted_class = tf.argmax(predictions, axis=-1).numpy()[0]
print(f"Predicted Class: {predicted_class}")


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 25ms/step
Predicted Class: 5


In [43]:
# Mapping von Klassen-Indices zu Namen
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Beispiel: Vorhersage in eine Klasse umwandeln
predicted_class_name = cifar10_classes[predicted_class]
print(f"Predicted Class Name: {predicted_class_name}")


Predicted Class Name: dog


In [40]:
# Beispiel für ein Bild
predictions = vit_token_learner.predict(image)

# Wahrscheinlichkeiten in Kategorien umwandeln
predicted_class = tf.argmax(predictions, axis=-1).numpy()[0]

# Mapping von Indizes zu Klassen
cifar10_classes = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck"
]

# Umwandeln in Klassenname
predicted_class_name = cifar10_classes[predicted_class]
print(f"Predicted Class Index: {predicted_class}")
print(f"Predicted Class Name: {predicted_class_name}")


[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 27ms/step
Predicted Class Index: 1
Predicted Class Name: automobile
