# Module 3: Vision Transformers in Keras
---

In [None]:
# Import necessary libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.models import load_model, Model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import (
    Input, Dense, Flatten, Dropout, LayerNormalization,
    MultiHeadAttention, GlobalAveragePooling1D, Reshape,
    Add, Embedding
)
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator

print(f"TensorFlow version: {tf.__version__}")

In [None]:
# Define paths and parameters
dataset_path = './images_dataSAT/'
IMG_SIZE = (64, 64)  # passed as target_size to the data generators
BATCH_SIZE = 32
EPOCHS = 20
NUM_CLASSES = 2
LEARNING_RATE = 0.001
SEED = 42

# Seed Python, NumPy and TensorFlow in one call so weight initialisation and
# dropout are repeatable on a fresh "Restart Kernel -> Run All".
tf.keras.utils.set_random_seed(SEED)

# ViT parameters
PATCH_SIZE = 8
# Patches along each axis are multiplied, so a non-square IMG_SIZE is handled too.
NUM_PATCHES = (IMG_SIZE[0] // PATCH_SIZE) * (IMG_SIZE[1] // PATCH_SIZE)  # (64/8) * (64/8) = 64 patches
PROJECTION_DIM = 64
NUM_HEADS = 4
TRANSFORMER_LAYERS = 4
MLP_DIM = 128

print(f"Number of patches: {NUM_PATCHES}")
print(f"Projection dimension: {PROJECTION_DIM}")

## Task 1: Load the pre-trained CNN model into the `cnn_model` variable using the `load_model()` function, then print the model summary using the `summary()` method.

In [None]:
# Task 1: Load pre-trained CNN model and print summary
CNN_MODEL_PATH = 'best_model.keras'

# Only fall back to a fresh model when the file is genuinely missing. Any other
# failure (corrupt archive, incompatible version, user interrupt) is allowed to
# propagate instead of being silently replaced by an untrained network.
if os.path.exists(CNN_MODEL_PATH):
    cnn_model = load_model(CNN_MODEL_PATH)
    print("Pre-trained CNN model loaded successfully!")
else:
    print("Pre-trained model file not found. Creating a new CNN model for demonstration.")

    # Input shape follows IMG_SIZE so the fallback stays in sync with the data generators.
    fallback_layers = [Input(shape=tuple(IMG_SIZE) + (3,))]

    # Four conv blocks: Conv2D -> BatchNormalization -> 2x2 MaxPooling2D.
    for block_idx, filters in enumerate((32, 64, 128, 256)):
        fallback_layers += [
            Conv2D(filters, (3, 3), activation='relu', padding='same', name=f'conv2d_{block_idx}'),
            BatchNormalization(name=f'bn_{block_idx}'),
            MaxPooling2D((2, 2), name=f'maxpool_{block_idx}'),
        ]

    fallback_layers.append(Flatten(name='flatten'))

    # Dense head: (units, dropout rate) pairs, narrowing towards the output.
    for dense_idx, (units, drop_rate) in enumerate(((512, 0.5), (256, 0.4), (128, 0.3), (64, 0.2))):
        fallback_layers += [
            Dense(units, activation='relu', name=f'dense_{dense_idx}'),
            Dropout(drop_rate, name=f'dropout_{dense_idx}'),
        ]

    # Single sigmoid unit for binary classification.
    fallback_layers.append(Dense(1, activation='sigmoid', name='output'))

    cnn_model = Sequential(fallback_layers)
    cnn_model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Print model summary
print("\nCNN Model Summary:")
cnn_model.summary()

## Task 2: Based on `model.summary()`, store the name of the CNN layer to use for feature extraction in the variable `feature_layer_name`.

In [None]:
# Task 2: Get the feature extraction layer name
# We want the last layer that still outputs a spatial feature map, i.e. a
# rank-4 tensor (batch, height, width, channels). That map is what gets
# reshaped into a token sequence for the ViT.

print("All layer names in the CNN model:")
print("=" * 50)
for i, layer in enumerate(cnn_model.layers):
    # layer.output.shape is used instead of layer.output_shape because the
    # latter is not available on plain layers in Keras 3.
    print(f"Layer {i}: {layer.name} ({layer.__class__.__name__}) -> Output shape: {tuple(layer.output.shape)}")

# Select by output rank rather than by layer class: this works for any CNN
# (e.g. one using AveragePooling2D or Dropout between conv blocks) and never
# picks a BatchNormalization layer that sits after a Dense layer.
spatial_layer_names = [
    layer.name for layer in cnn_model.layers
    if len(layer.output.shape) == 4
]

if not spatial_layer_names:
    raise ValueError(
        "No layer with a 4-D (batch, height, width, channels) output was found "
        "in cnn_model; cannot select a feature extraction layer."
    )

feature_layer_name = spatial_layer_names[-1]

print(f"\nSelected feature extraction layer: '{feature_layer_name}'")

# Create feature extraction model
feature_model = Model(
    inputs=cnn_model.input,
    outputs=cnn_model.get_layer(feature_layer_name).output
)

print(f"Feature model output shape: {feature_model.output_shape}")

In [None]:
# Define Transformer Encoder Block
def transformer_encoder(inputs, num_heads, projection_dim, mlp_dim, dropout_rate=0.1):
    """
    Apply one pre-norm Transformer encoder block to `inputs`.

    The block is: LayerNorm -> multi-head self-attention -> dropout -> residual,
    followed by LayerNorm -> Dense(gelu) -> dropout -> Dense -> dropout -> residual.

    Args:
        inputs: Symbolic sequence tensor; its last dimension must equal
            `projection_dim` so both residual additions are shape-compatible.
        num_heads: Number of attention heads.
        projection_dim: Used as the per-head `key_dim` and as the width of the
            final MLP layer.
        mlp_dim: Width of the hidden MLP layer.
        dropout_rate: Dropout rate used inside attention and after each stage.

    Returns:
        The output tensor of the second residual addition.
    """
    # --- Attention sub-block ---
    normed_inputs = LayerNormalization(epsilon=1e-6)(inputs)
    self_attention = MultiHeadAttention(
        num_heads=num_heads,
        key_dim=projection_dim,
        dropout=dropout_rate,
    )
    # Query and value are the same tensor, which makes this self-attention.
    attended = self_attention(normed_inputs, normed_inputs)
    attended = Dropout(dropout_rate)(attended)
    attention_residual = Add()([attended, inputs])

    # --- MLP sub-block ---
    hidden = LayerNormalization(epsilon=1e-6)(attention_residual)
    mlp_stack = (
        Dense(mlp_dim, activation='gelu'),
        Dropout(dropout_rate),
        Dense(projection_dim),
        Dropout(dropout_rate),
    )
    for mlp_layer in mlp_stack:
        hidden = mlp_layer(hidden)

    return Add()([hidden, attention_residual])

print("Transformer encoder block defined successfully.")

In [None]:
# Define the CNN-ViT Hybrid Model builder
class _AddPositionEmbedding(tf.keras.layers.Layer):
    """Add a learnable positional table to a (batch, seq_length, dim) tensor.

    The table is created with `add_weight`, so it is tracked by the enclosing
    model, updated by the optimizer and saved with the model weights.
    Reloading a saved model that contains this layer requires passing it via
    `custom_objects`.
    """

    def build(self, input_shape):
        # One trainable vector per sequence position, same width as the tokens.
        self.position_table = self.add_weight(
            name="position_table",
            shape=(input_shape[1], input_shape[2]),
            initializer="random_uniform",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        # (seq_length, dim) broadcasts over the batch axis.
        return inputs + self.position_table


def build_cnn_vit_hybrid(feature_model, num_patches, projection_dim, 
                          num_heads, transformer_layers, mlp_dim, 
                          num_classes, dropout_rate=0.1):
    """
    Build a hybrid CNN-ViT model.
    CNN extracts features, ViT processes them with self-attention.

    Args:
        feature_model: Keras model whose output is a 4-D feature map
            (batch, height, width, channels) with static spatial dimensions.
        num_patches: Unused. The sequence length is derived from the feature
            map as height * width; the parameter is kept so existing calls
            keep working.
        projection_dim: Token width after the linear projection.
        num_heads: Attention heads per encoder block.
        transformer_layers: Number of stacked encoder blocks.
        mlp_dim: Hidden width of each encoder block's MLP.
        num_classes: 2 gives a single sigmoid output; any other value gives a
            `num_classes`-way softmax output.
        dropout_rate: Dropout rate for the encoder blocks and the head.

    Returns:
        An uncompiled Keras `Model` mapping `feature_model.input` to the
        class prediction.

    Raises:
        ValueError: If the feature map is not 4-D or has unknown height,
            width or channel dimensions.
    """
    # Input
    inputs = feature_model.input
    
    # CNN Feature Extraction
    cnn_features = feature_model(inputs)
    
    # Reshape CNN features to sequence of patches
    # e.g., (batch, 4, 4, 256) -> (batch, 16, 256)
    feature_shape = tuple(feature_model.output_shape)
    if len(feature_shape) != 4 or any(dim is None for dim in feature_shape[1:]):
        raise ValueError(
            "feature_model must output a 4-D feature map with static "
            f"height, width and channels; got output shape {feature_shape}."
        )
    h, w, c = feature_shape[1], feature_shape[2], feature_shape[3]
    seq_length = h * w
    
    x = Reshape((seq_length, c))(cnn_features)
    
    # Linear projection to projection_dim
    x = Dense(projection_dim)(x)
    
    # Add positional embedding as a real layer in the graph, so its weights
    # are trained rather than frozen in as a constant.
    x = _AddPositionEmbedding(name="position_embedding")(x)
    
    # Transformer Encoder Blocks
    for _ in range(transformer_layers):
        x = transformer_encoder(x, num_heads, projection_dim, mlp_dim, dropout_rate)
    
    # Global Average Pooling
    x = LayerNormalization(epsilon=1e-6)(x)
    x = GlobalAveragePooling1D()(x)
    
    # Classification Head
    x = Dropout(dropout_rate)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(dropout_rate)(x)
    
    if num_classes == 2:
        outputs = Dense(1, activation='sigmoid')(x)
    else:
        outputs = Dense(num_classes, activation='softmax')(x)
    
    model = Model(inputs=inputs, outputs=outputs)
    return model

print("build_cnn_vit_hybrid function defined successfully.")

## Task 3: Define the model architecture in a variable named hybrid_model using the build_cnn_vit_hybrid function.

In [None]:
# Task 3: Build the hybrid CNN-ViT model
# Gather the ViT hyperparameters from the config cell in one place, then
# unpack them into the builder alongside the CNN feature extractor.
vit_hyperparameters = dict(
    num_patches=NUM_PATCHES,
    projection_dim=PROJECTION_DIM,
    num_heads=NUM_HEADS,
    transformer_layers=TRANSFORMER_LAYERS,
    mlp_dim=MLP_DIM,
    num_classes=NUM_CLASSES,
    dropout_rate=0.1,
)
hybrid_model = build_cnn_vit_hybrid(feature_model=feature_model, **vit_hyperparameters)

print("Hybrid CNN-ViT Model Architecture:")
hybrid_model.summary()

print(f"\nTotal parameters: {hybrid_model.count_params():,}")

## Task 4: Compile the model hybrid_model.

In [None]:
# Task 4: Compile the hybrid model
# Adam at the configured learning rate, binary cross-entropy loss, and
# accuracy as the tracked metric.
compile_settings = {
    'optimizer': Adam(learning_rate=LEARNING_RATE),
    'loss': 'binary_crossentropy',
    'metrics': ['accuracy'],
}
hybrid_model.compile(**compile_settings)

for status_line in (
    "Hybrid model compiled successfully!",
    f"  Optimizer: Adam (lr={LEARNING_RATE})",
    f"  Loss: binary_crossentropy",
    f"  Metrics: accuracy",
):
    print(status_line)

## Task 5: Define the training configuration of the hybrid_model.

In [None]:
# Task 5: Define training configuration

# Both generators must use the same validation_split so that the 'training'
# and 'validation' subsets partition the same file listing.
VALIDATION_SPLIT = 0.2

# Data generators
# Training images are rescaled and randomly augmented.
train_datagen = ImageDataGenerator(
    rescale=1.0/255.0,
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
    fill_mode='nearest',
    validation_split=VALIDATION_SPLIT
)

# Validation images are only rescaled. Augmenting them would make the
# validation metrics noisy and not comparable from epoch to epoch, which in
# turn would mislead ModelCheckpoint and EarlyStopping.
validation_datagen = ImageDataGenerator(
    rescale=1.0/255.0,
    validation_split=VALIDATION_SPLIT
)

train_generator = train_datagen.flow_from_directory(
    dataset_path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='training',
    shuffle=True,
    seed=42
)

validation_generator = validation_datagen.flow_from_directory(
    dataset_path,
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='binary',
    subset='validation',
    shuffle=False,
    seed=42
)

# Callbacks
# Keep the weights of the epoch with the highest validation accuracy on disk.
checkpoint_callback = ModelCheckpoint(
    filepath='best_hybrid_model.keras',
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    verbose=1
)

# Stop after 5 epochs without validation-loss improvement and roll back to
# the best weights seen.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True,
    verbose=1
)

print("Training Configuration:")
print(f"  Epochs: {EPOCHS}")
print(f"  Batch size: {BATCH_SIZE}")
print(f"  Training samples: {train_generator.samples}")
print(f"  Validation samples: {validation_generator.samples}")
print(f"  Steps per epoch: {train_generator.samples // BATCH_SIZE}")
print(f"  Validation steps: {validation_generator.samples // BATCH_SIZE}")
print(f"  Callbacks: ModelCheckpoint, EarlyStopping")

In [None]:
# Train the hybrid model
# Whole batches per epoch; a trailing partial batch is not counted.
train_steps = train_generator.samples // BATCH_SIZE
val_steps = validation_generator.samples // BATCH_SIZE
training_callbacks = [checkpoint_callback, early_stopping]

history = hybrid_model.fit(
    train_generator,
    validation_data=validation_generator,
    epochs=EPOCHS,
    steps_per_epoch=train_steps,
    validation_steps=val_steps,
    callbacks=training_callbacks,
    verbose=1,
)

print("\nTraining completed!")

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# One entry per panel:
# (axis, training key, validation key, training label, validation label, title, y-label)
panels = (
    (axes[0], 'loss', 'val_loss',
     'Training Loss', 'Validation Loss', 'Hybrid Model Loss', 'Loss'),
    (axes[1], 'accuracy', 'val_accuracy',
     'Training Accuracy', 'Validation Accuracy', 'Hybrid Model Accuracy', 'Accuracy'),
)

for ax, train_key, val_key, train_label, val_label, panel_title, y_label in panels:
    ax.plot(history.history[train_key], label=train_label, color='blue', linewidth=2)
    ax.plot(history.history[val_key], label=val_label, color='red', linewidth=2)
    ax.set_title(panel_title, fontsize=14, fontweight='bold')
    ax.set_xlabel('Epoch')
    ax.set_ylabel(y_label)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.suptitle('CNN-ViT Hybrid Model Training History', fontsize=16, fontweight='bold')
plt.tight_layout()
plt.show()

# Headline numbers across all recorded epochs.
print(f"Best Validation Accuracy: {max(history.history['val_accuracy']):.4f}")
print(f"Lowest Validation Loss: {min(history.history['val_loss']):.4f}")

---
## All 5 tasks completed successfully.