<a href="https://colab.research.google.com/github/BoraShruti/VisualTransformersMNIST/blob/main/visual_transformers_experiment.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Vision Transformer Experiments with WandB Integration

This notebook demonstrates how to integrate [Weights & Biases (WandB)](https://wandb.ai/) with a Vision Transformer (ViT) training pipeline. We will experiment with the following hyperparameters:

- **Patch Size:** Fixed at 4×4
- **Number of ViT Layers:** 4 to 8
- **Number of Attention Heads:** 2 to 4

For each configuration, the model will be trained and the test accuracy will be recorded and logged to WandB. Detailed explanations are provided in each section so that you can understand the purpose of each block of code.
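The patch size is fixed, so the sweep is the Cartesian product of the layer and head ranges. The minimal sketch below enumerates that grid with `itertools.product`, using the same `range` bounds as the experiment cell later in the notebook. It shows how many separate training runs, and therefore WandB runs, the sweep launches.

```python
import itertools

patch_size = 4
layer_options = range(4, 9)  # 4 to 8 ViT layers (range() excludes its upper bound)
head_options = range(2, 5)   # 2 to 4 attention heads

# Every (num_layers, num_heads) pair becomes one training run
grid = list(itertools.product(layer_options, head_options))

print(f"{len(grid)} configurations, patch size {patch_size}x{patch_size}")
for num_layers, num_heads in grid[:3]:
    print(f"num_layers={num_layers}, num_heads={num_heads}")
```

The first line printed is `15 configurations, patch size 4x4`: 5 layer counts times 3 head counts.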

In [None]:
# Install wandb, or upgrade it to the latest release if it is already installed
!pip install --upgrade wandb




## Importing Libraries

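In this cell we import the libraries used throughout the notebook. TensorFlow/Keras builds and trains the model, and WandB tracks the experiments. `itertools` builds the hyperparameter grid, pandas summarizes the results, and NumPy handles array manipulation. The cell also fixes the random seeds for reproducibility.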

In [None]:
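import itertools

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models

import wandb
# WandbCallback is wandb's legacy Keras integration; recent wandb releases
# recommend WandbMetricsLogger for metric logging instead.
from wandb.integration.keras import WandbCallback

# For reproducibility: seed Python's `random`, NumPy, and TensorFlow in one call
tf.keras.utils.set_random_seed(42)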

## Dataset Preparation

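In this section we load and preprocess the dataset. For demonstration purposes we use MNIST. In your own project, replace this section with your actual data loading and preprocessing steps. Pixel values are scaled to [0, 1], and a channel dimension is added. The first 10% of the MNIST training set (6,000 images) is held out for validation, which leaves 54,000 images for training. The standard 10,000-image test set is reserved for the final evaluation.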
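The split and batch size used in the next cell determine the step counts that appear in the training logs: `844/844` per training epoch and `157/157` for evaluation. The quick check below reproduces that arithmetic. It assumes MNIST's standard 60,000/10,000 train/test sizes together with the `val_split = 0.1` and `batch_size = 64` values from the data-preparation cell.

```python
import math

# Standard MNIST sizes and the settings used in the data-preparation cell
n_train_full, n_test = 60_000, 10_000
val_split, batch_size = 0.1, 64

val_size = int(n_train_full * val_split)  # images held out for validation
n_train = n_train_full - val_size         # images left for training

# tf.data keeps the final partial batch by default, so round up
train_steps = math.ceil(n_train / batch_size)
val_steps = math.ceil(val_size / batch_size)
test_steps = math.ceil(n_test / batch_size)

print(val_size, n_train, train_steps, val_steps, test_steps)  # 6000 54000 844 94 157
```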

In [None]:
# Load MNIST: 60,000 training and 10,000 test images of 28x28 grayscale digits
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# Scale pixel values from [0, 255] to [0, 1]
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Add a channel dimension: (N, 28, 28) -> (N, 28, 28, 1)
x_train = np.expand_dims(x_train, -1)
x_test = np.expand_dims(x_test, -1)

# Hold out the first 10% of the training data as a validation set
val_split = 0.1
val_size = int(len(x_train) * val_split)
x_val, y_val = x_train[:val_size], y_train[:val_size]
x_train, y_train = x_train[val_size:], y_train[val_size:]

# Create TensorFlow datasets; only the training data is shuffled (reshuffled every epoch)
batch_size = 64
train_dataset = (tf.data.Dataset.from_tensor_slices((x_train, y_train))
                 .shuffle(len(x_train), seed=42)
                 .batch(batch_size)
                 .prefetch(tf.data.AUTOTUNE))
val_dataset = tf.data.Dataset.from_tensor_slices((x_val, y_val)).batch(batch_size)
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)).batch(batch_size)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step


## Model Definition

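Below is the definition of a helper function, `create_vit_model`, that constructs a simple Vision Transformer (ViT). A strided convolution splits each image into non-overlapping patches and projects them, and positional embeddings are then added to the patch embeddings. The resulting sequence passes through `num_layers` pre-norm transformer blocks before a dense classification head. The function takes the patch size, the number of transformer layers, and the number of attention heads as parameters; the image size and number of classes default to MNIST's 28 and 10. It returns an uncompiled model. You can customize the architecture as needed.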

In [None]:
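class PositionEmbedding(layers.Layer):
    """Adds a learnable positional embedding to a sequence of patch embeddings.

    Wrapping the embedding in a layer makes its weights part of the model, so
    they are updated during training. Calling layers.Embedding directly on a
    tf.range tensor while building a functional model bakes in a fixed,
    untrained embedding instead.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.position_embedding = layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch_embeddings):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        # (num_patches, projection_dim) broadcasts across the batch dimension
        return patch_embeddings + self.position_embedding(positions)


def create_vit_model(patch_size, num_layers, num_heads, image_size=28, num_classes=10):
    """
    Creates a simple Vision Transformer (ViT) model.

    Parameters:
        patch_size (int): Size of each patch (patch_size x patch_size).
        num_layers (int): Number of transformer encoder blocks.
        num_heads (int): Number of attention heads in each transformer block.
        image_size (int): Size of the (square, single-channel) input image.
        num_classes (int): Number of classes for classification.

    Returns:
        model (tf.keras.Model): An uncompiled Vision Transformer model.
    """
    projection_dim = 64  # Embedding size of each patch token

    # Number of non-overlapping patches, e.g. (28 // 4) ** 2 = 49
    num_patches = (image_size // patch_size) ** 2

    # Input layer
    inputs = layers.Input(shape=(image_size, image_size, 1))

    # Split the image into patches and project each one with a strided convolution
    patches = layers.Conv2D(filters=projection_dim, kernel_size=patch_size,
                            strides=patch_size, padding='valid')(inputs)
    x = layers.Reshape((num_patches, projection_dim))(patches)

    # Add learnable positional embeddings
    x = PositionEmbedding(num_patches, projection_dim)(x)

    # Transformer encoder blocks (layer normalization before each sub-block)
    for _ in range(num_layers):
        # Multi-head self-attention with a residual connection
        x1 = layers.LayerNormalization(epsilon=1e-6)(x)
        attention_output = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(x1, x1)
        x2 = layers.Add()([x, attention_output])

        # MLP block with a residual connection
        x3 = layers.LayerNormalization(epsilon=1e-6)(x2)
        mlp_output = layers.Dense(128, activation='relu')(x3)
        mlp_output = layers.Dense(projection_dim)(mlp_output)
        x = layers.Add()([x2, mlp_output])

    # Classification head
    x = layers.LayerNormalization(epsilon=1e-6)(x)
    x = layers.Flatten()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    model = tf.keras.Model(inputs=inputs, outputs=outputs)
    return model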
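Two consequences of the hard-coded sizes in `create_vit_model` are easy to check by hand. First, with 28×28 inputs and 4×4 patches, the encoder sees a sequence of 49 tokens, each of them 64-dimensional. Second, because `key_dim=64` is set per head, adding heads widens the query, key, and value projections rather than splitting the 64 dimensions among the heads. That fits the longer step times for 3 and 4 heads in the training logs. The sketch below is a back-of-the-envelope count that assumes Keras `MultiHeadAttention` defaults (bias terms enabled, `value_dim` equal to `key_dim`). Treat `model.summary()` as the authoritative source for the actual numbers.

```python
def num_patches(image_size: int, patch_size: int) -> int:
    """Number of non-overlapping patches from a stride-patch_size 'valid' convolution."""
    return (image_size // patch_size) ** 2


def mha_params(d_model: int, num_heads: int, key_dim: int) -> int:
    """Trainable parameters in one Keras-style MultiHeadAttention layer.

    Assumes bias terms are used and value_dim == key_dim:
      query, key, value projections: 3 * (d_model * heads * key_dim + heads * key_dim)
      output projection:             heads * key_dim * d_model + d_model
    """
    inner = num_heads * key_dim
    qkv = 3 * (d_model * inner + inner)
    out = inner * d_model + d_model
    return qkv + out


print(num_patches(28, 4))  # 49 tokens per image
for heads in (2, 3, 4):
    print(heads, mha_params(d_model=64, num_heads=heads, key_dim=64))
# 2 33216
# 3 49792
# 4 66368
```

Each additional head adds 64 × 259 = 16,576 parameters per attention layer, so the attention cost grows linearly with `num_heads` at a fixed `key_dim`.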

## Experiment Setup with WandB

In this section we configure our hyperparameter experiments using WandB. The experiment will use the following settings:

- **Patch Size:** Fixed at 4×4
- **Number of ViT Layers:** 4 to 8
- **Number of Attention Heads:** 2 to 4

For each of the 15 combinations (5 layer counts × 3 head counts), a new WandB run is started, and the model is created, compiled, and trained. After training, the model is evaluated on the test set. The test accuracy is logged to WandB and also appended to a local list, which is summarized as a pandas DataFrame at the end.

In [None]:
import time

# Set the fixed patch size and define the range for the number of layers and heads
patch_size = 4               # 4x4 patches
layer_options = range(4, 9)  # 4 to 8 layers (range() excludes its upper bound)
head_options = range(2, 5)   # 2 to 4 heads

# Number of training epochs per configuration (the run logged below used 3)
EPOCHS = 3

# List to store experiment results
experiment_results = []

for num_layers, num_heads in itertools.product(layer_options, head_options):
    # Release the previous model's graph and memory before building a new one
    tf.keras.backend.clear_session()

    # Initialize a new WandB run for the current experiment
    wandb.init(project="ViT_experiment",
               config={
                   "patch_size": f"{patch_size}x{patch_size}",
                   "num_layers": num_layers,
                   "num_heads": num_heads
               }, reinit=True)
    config = wandb.config

    print(f"Running experiment with patch_size: {config.patch_size}, num_layers: {config.num_layers}, num_heads: {config.num_heads}")

    # Create the Vision Transformer model with the current configuration
    model = create_vit_model(patch_size=patch_size,
                             num_layers=num_layers,
                             num_heads=num_heads,
                             image_size=28,
                             num_classes=10)

    # Labels are integers, so use sparse categorical cross-entropy
    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])

    # Train the model; WandbCallback logs the per-epoch training and validation metrics
    history = model.fit(
        train_dataset,
        validation_data=val_dataset,
        epochs=EPOCHS,
        # Disable graph logging and model checkpointing
        callbacks=[WandbCallback(save_graph=False, save_model=False, data_type="image")],
        verbose=1
    )

    # Evaluate the trained model on the held-out test set
    test_loss, test_acc = model.evaluate(test_dataset)
    print(f"Test Accuracy: {test_acc}")

    # Log the test accuracy to the current WandB run
    wandb.log({"test_accuracy": test_acc})

    # Store the experiment results locally for the summary table
    experiment_results.append({
        "patch_size": config.patch_size,
        "num_layers": config.num_layers,
        "num_heads": config.num_heads,
        "test_accuracy": test_acc
    })

    # Finish the current WandB run
    wandb.finish()

    # Optional: pause briefly between experiments
    time.sleep(1)

# Convert the results into a DataFrame and display the summary
df_results = pd.DataFrame(experiment_results)
print("Summary of experiments:")
print(df_results)

Running experiment with patch_size: 4x4, num_layers: 4, num_heads: 2
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 270s 283ms/step - accuracy: 0.8120 - loss: 0.6055 - val_accuracy: 0.9295 - val_loss: 0.2304
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 260s 281ms/step - accuracy: 0.9618 - loss: 0.1252 - val_accuracy: 0.9418 - val_loss: 0.1932
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 265s 284ms/step - accuracy: 0.9711 - loss: 0.0940 - val_accuracy: 0.9558 - val_loss: 0.1500
157/157 ━━━━━━━━━━━━━━━━━━━━ 14s 90ms/step - accuracy: 0.9424 - loss: 0.1829
Test Accuracy: 0.953499972820282

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▄█
val_loss        █▅▁

Run summary:
accuracy        0.97304
best_epoch      2.0
best_val_loss   0.14998
epoch           2.0
loss            0.08656
test_accuracy   0.9535
val_accuracy    0.95583
val_loss        0.14998


Running experiment with patch_size: 4x4, num_layers: 4, num_heads: 3
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 321s 362ms/step - accuracy: 0.8054 - loss: 0.6429 - val_accuracy: 0.9137 - val_loss: 0.2806
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 304s 360ms/step - accuracy: 0.9591 - loss: 0.1329 - val_accuracy: 0.9520 - val_loss: 0.1529
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 308s 365ms/step - accuracy: 0.9697 - loss: 0.0985 - val_accuracy: 0.9510 - val_loss: 0.1588
157/157 ━━━━━━━━━━━━━━━━━━━━ 19s 119ms/step - accuracy: 0.9386 - loss: 0.1944
Test Accuracy: 0.9508000016212463

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁██
val_loss        █▁▁

Run summary:
accuracy        0.97252
best_epoch      1.0
best_val_loss   0.15289
epoch           2.0
loss            0.09054
test_accuracy   0.9508
val_accuracy    0.951
val_loss        0.15877


Running experiment with patch_size: 4x4, num_layers: 4, num_heads: 4
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 409s 465ms/step - accuracy: 0.8059 - loss: 0.6145 - val_accuracy: 0.9190 - val_loss: 0.2704
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 451s 476ms/step - accuracy: 0.9590 - loss: 0.1336 - val_accuracy: 0.9515 - val_loss: 0.1634
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 391s 463ms/step - accuracy: 0.9703 - loss: 0.0978 - val_accuracy: 0.9612 - val_loss: 0.1208
157/157 ━━━━━━━━━━━━━━━━━━━━ 21s 137ms/step - accuracy: 0.9533 - loss: 0.1349
Test Accuracy: 0.9625999927520752

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▆█
val_loss        █▃▁

Run summary:
accuracy        0.97317
best_epoch      2.0
best_val_loss   0.12084
epoch           2.0
loss            0.08874
test_accuracy   0.9626
val_accuracy    0.96117
val_loss        0.12084


Running experiment with patch_size: 4x4, num_layers: 5, num_heads: 2
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 311s 348ms/step - accuracy: 0.7880 - loss: 0.7023 - val_accuracy: 0.9270 - val_loss: 0.2334
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 298s 353ms/step - accuracy: 0.9610 - loss: 0.1275 - val_accuracy: 0.9587 - val_loss: 0.1235
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 326s 357ms/step - accuracy: 0.9706 - loss: 0.0974 - val_accuracy: 0.9618 - val_loss: 0.1189
157/157 ━━━━━━━━━━━━━━━━━━━━ 17s 109ms/step - accuracy: 0.9530 - loss: 0.1389
Test Accuracy: 0.963100016117096

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▇█
val_loss        █▁▁

Run summary:
accuracy        0.97252
best_epoch      2.0
best_val_loss   0.11892
epoch           2.0
loss            0.09066
test_accuracy   0.9631
val_accuracy    0.96183
val_loss        0.11892


Running experiment with patch_size: 4x4, num_layers: 5, num_heads: 3
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 412s 466ms/step - accuracy: 0.7851 - loss: 0.7195 - val_accuracy: 0.9452 - val_loss: 0.1770
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 395s 468ms/step - accuracy: 0.9604 - loss: 0.1322 - val_accuracy: 0.9537 - val_loss: 0.1484
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 386s 458ms/step - accuracy: 0.9709 - loss: 0.0970 - val_accuracy: 0.9600 - val_loss: 0.1165
157/157 ━━━━━━━━━━━━━━━━━━━━ 22s 142ms/step - accuracy: 0.9507 - loss: 0.1412
Test Accuracy: 0.9620000123977661

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▅█
val_loss        █▅▁

Run summary:
accuracy        0.97257
best_epoch      2.0
best_val_loss   0.11648
epoch           2.0
loss            0.09158
test_accuracy   0.962
val_accuracy    0.96
val_loss        0.11648


Running experiment with patch_size: 4x4, num_layers: 5, num_heads: 4
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 508s 582ms/step - accuracy: 0.7921 - loss: 0.6794 - val_accuracy: 0.9370 - val_loss: 0.2015
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 492s 583ms/step - accuracy: 0.9587 - loss: 0.1335 - val_accuracy: 0.9470 - val_loss: 0.1742
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 491s 582ms/step - accuracy: 0.9680 - loss: 0.1027 - val_accuracy: 0.9597 - val_loss: 0.1302
157/157 ━━━━━━━━━━━━━━━━━━━━ 30s 188ms/step - accuracy: 0.9479 - loss: 0.1606
Test Accuracy: 0.9592000246047974

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▄█
val_loss        █▅▁

Run summary:
accuracy        0.97093
best_epoch      2.0
best_val_loss   0.13021
epoch           2.0
loss            0.09324
test_accuracy   0.9592
val_accuracy    0.95967
val_loss        0.13021


Running experiment with patch_size: 4x4, num_layers: 6, num_heads: 2
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 388s 433ms/step - accuracy: 0.8094 - loss: 0.6206 - val_accuracy: 0.9308 - val_loss: 0.2219
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 386s 438ms/step - accuracy: 0.9613 - loss: 0.1288 - val_accuracy: 0.9465 - val_loss: 0.1712
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 379s 449ms/step - accuracy: 0.9697 - loss: 0.1000 - val_accuracy: 0.9633 - val_loss: 0.1113
157/157 ━━━━━━━━━━━━━━━━━━━━ 22s 143ms/step - accuracy: 0.9549 - loss: 0.1296
Test Accuracy: 0.965499997138977

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▄█
val_loss        █▅▁

Run summary:
accuracy        0.97193
best_epoch      2.0
best_val_loss   0.11128
epoch           2.0
loss            0.09157
test_accuracy   0.9655
val_accuracy    0.96333
val_loss        0.11128


Running experiment with patch_size: 4x4, num_layers: 6, num_heads: 3
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 485s 548ms/step - accuracy: 0.8094 - loss: 0.6440 - val_accuracy: 0.9372 - val_loss: 0.2029
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 466s 552ms/step - accuracy: 0.9596 - loss: 0.1314 - val_accuracy: 0.9588 - val_loss: 0.1317
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 509s 562ms/step - accuracy: 0.9711 - loss: 0.0948 - val_accuracy: 0.9603 - val_loss: 0.1245
157/157 ━━━━━━━━━━━━━━━━━━━━ 27s 173ms/step - accuracy: 0.9565 - loss: 0.1338
Test Accuracy: 0.9652000069618225

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁██
val_loss        █▂▁

Run summary:
accuracy        0.973
best_epoch      2.0
best_val_loss   0.12453
epoch           2.0
loss            0.0883
test_accuracy   0.9652
val_accuracy    0.96033
val_loss        0.12453


Running experiment with patch_size: 4x4, num_layers: 6, num_heads: 4
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 640s 732ms/step - accuracy: 0.8025 - loss: 0.6559 - val_accuracy: 0.9477 - val_loss: 0.1672
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 656s 773ms/step - accuracy: 0.9612 - loss: 0.1241 - val_accuracy: 0.9497 - val_loss: 0.1615
Epoch 3/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 661s 748ms/step - accuracy: 0.9702 - loss: 0.0995 - val_accuracy: 0.9498 - val_loss: 0.1623
157/157 ━━━━━━━━━━━━━━━━━━━━ 36s 232ms/step - accuracy: 0.9422 - loss: 0.1738
Test Accuracy: 0.9514999985694885

Run history:
accuracy        ▁▇█
epoch           ▁▅█
loss            █▂▁
test_accuracy   ▁
val_accuracy    ▁▇█
val_loss        █▁▂

Run summary:
accuracy        0.97207
best_epoch      1.0
best_val_loss   0.16153
epoch           2.0
loss            0.09006
test_accuracy   0.9515
val_accuracy    0.94983
val_loss        0.16226


Running experiment with patch_size: 4x4, num_layers: 7, num_heads: 2
Epoch 1/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 459s 513ms/step - accuracy: 0.7939 - loss: 0.6685 - val_accuracy: 0.9425 - val_loss: 0.1846
Epoch 2/3
844/844 ━━━━━━━━━━━━━━━━━━━━ 428s 507ms/step - accuracy: 0.9614 - loss: 0.1253 - val_accuracy: 0.9552 - val_loss: 0.1397
Epoch 3/3
204/844 ━━━━━━━━━━━━━━━━━━━━ 5:08 482ms/step - accuracy: 0.9641 - loss: 0.1081





## Summary

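Each completed configuration was logged to WandB as a separate run, and its test accuracy was printed after evaluation. In the recorded execution above, the sweep was interrupted during the third epoch of the 7-layer, 2-head configuration. As a result, the 7- and 8-layer configurations never finished, and the final summary DataFrame was not printed. Across the nine completed runs, test accuracy ranged from 0.9508 (4 layers, 3 heads) to 0.9655 (6 layers, 2 heads). The time per training step grew with both the number of layers and the number of heads. You can inspect the detailed metrics and charts for every run on your WandB project page.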
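The recorded sweep stopped partway through the 7-layer runs, before `df_results` was built. The comparison can still be reconstructed from the WandB run summaries of the nine configurations that finished. The minimal plain-Python sketch below does this; its values are copied from the `test_accuracy` entries in the logs above.

```python
# Test accuracies from the WandB run summaries of the nine completed runs,
# keyed by (num_layers, num_heads)
test_accuracy = {
    (4, 2): 0.9535, (4, 3): 0.9508, (4, 4): 0.9626,
    (5, 2): 0.9631, (5, 3): 0.9620, (5, 4): 0.9592,
    (6, 2): 0.9655, (6, 3): 0.9652, (6, 4): 0.9515,
}

head_counts = sorted({nh for _, nh in test_accuracy})
layer_counts = sorted({nl for nl, _ in test_accuracy})

# Print a layers x heads grid of test accuracies
print("layers/heads" + "".join(f"{nh:>8}" for nh in head_counts))
for nl in layer_counts:
    row = "".join(f"{test_accuracy[(nl, nh)]:>8.4f}" for nh in head_counts)
    print(f"{nl:>12}{row}")

# Configuration with the highest test accuracy
best = max(test_accuracy, key=test_accuracy.get)
print(f"best: num_layers={best[0]}, num_heads={best[1]}, test_accuracy={test_accuracy[best]:.4f}")
# best: num_layers=6, num_heads=2, test_accuracy=0.9655
```

Once a sweep finishes, `df_results.pivot(index="num_layers", columns="num_heads", values="test_accuracy")` should produce the same grid from the DataFrame built in the experiment cell. Each configuration was trained once, with a single seed, for three epochs. Differences of a few tenths of a percentage point may therefore reflect run-to-run noise rather than a real effect of depth or head count.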