# Module 3: Vision Transformers in Keras

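This notebook works through five tasks for building a hybrid CNN + Vision Transformer model in Keras: loading a pre-trained CNN, selecting its feature layer, defining the hybrid architecture, compiling the model, and setting the training configuration.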

## Task 1: Load pre-trained CNN model

In [None]:

from tensorflow.keras.applications import ResNet50
from tensorflow.keras.models import Model

# Instantiate ResNet50 with ImageNet weights, keeping its classification head
cnn_model = ResNet50(include_top=True, weights="imagenet")

# Show the layer-by-layer summary so a feature layer can be picked by name
cnn_model.summary()


## Task 2: Get the feature layer name

In [None]:

# Feature source: the final average-pooling layer, referenced by its name in the summary
feature_layer_name = "avg_pool"
print("Selected feature layer:", feature_layer_name)


## Task 3: Define hybrid model architecture

In [None]:

import tensorflow as tf
from tensorflow.keras import layers, models

def _transformer_encoder_block(tokens, num_heads, key_dim, ff_dim, dropout_rate):
    """Apply one post-norm Transformer encoder block to a (batch, tokens, dim) tensor."""
    embed_dim = tokens.shape[-1]

    # Self-attention sub-layer with a residual connection.
    attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=key_dim)(tokens, tokens)
    attention = layers.Dropout(dropout_rate)(attention)
    tokens = layers.LayerNormalization(epsilon=1e-6)(layers.Add()([tokens, attention]))

    # Feed-forward sub-layer; projected back to embed_dim so the residual add is shape-compatible.
    hidden = layers.Dense(ff_dim, activation="relu")(tokens)
    hidden = layers.Dense(embed_dim)(hidden)
    hidden = layers.Dropout(dropout_rate)(hidden)
    return layers.LayerNormalization(epsilon=1e-6)(layers.Add()([tokens, hidden]))


def build_cnn_vit_hybrid(cnn_model, feature_layer_name, num_classes=10,
                         input_shape=(224, 224, 3), num_heads=4, key_dim=64,
                         ff_dim=128, dropout_rate=0.1):
    """Build a classifier that feeds CNN features through a Transformer encoder block.

    Keras has no built-in ``layers.TransformerBlock``; the encoder block is
    assembled here from ``MultiHeadAttention``, ``LayerNormalization`` and
    ``Dense`` layers.

    Args:
        cnn_model: Keras model used as the convolutional backbone.
        feature_layer_name: Name of the layer in ``cnn_model`` whose output is
            used as the feature tensor.
        num_classes: Number of units in the final softmax layer.
        input_shape: Shape of one input image (without the batch dimension).
        num_heads: Number of attention heads in the encoder block.
        key_dim: Size of each attention head for query and key.
        ff_dim: Hidden width of the feed-forward sub-layer.
        dropout_rate: Dropout rate applied after attention and feed-forward.

    Returns:
        An uncompiled Keras model named ``"cnn_vit_hybrid"``.

    Raises:
        ValueError: If the selected feature layer's output is not rank 2, 3 or 4
            (batch dimension included).
    """
    # Sub-model that maps the backbone's input to the chosen intermediate output.
    feature_extractor = models.Model(
        inputs=cnn_model.input,
        outputs=cnn_model.get_layer(feature_layer_name).output,
    )

    inputs = layers.Input(shape=input_shape)
    # training=False runs the backbone in inference mode (e.g. BatchNorm uses its
    # moving statistics); it does not freeze the backbone's weights.
    features = feature_extractor(inputs, training=False)

    # Turn the feature tensor into a (batch, tokens, channels) sequence.
    rank = len(features.shape)
    channels = features.shape[-1]
    if rank == 2:
        # Pooled vector -> a single token. Self-attention over one token has nothing
        # to attend across; a spatial feature map (rank 4) yields a real sequence.
        tokens = layers.Reshape((1, channels))(features)
    elif rank == 4:
        # Spatial feature map -> one token per spatial position.
        tokens = layers.Reshape((-1, channels))(features)
    elif rank == 3:
        tokens = features
    else:
        raise ValueError(
            f"Layer {feature_layer_name!r} has output rank {rank}; expected 2, 3 or 4."
        )

    encoded = _transformer_encoder_block(tokens, num_heads, key_dim, ff_dim, dropout_rate)
    pooled = layers.GlobalAveragePooling1D()(encoded)

    outputs = layers.Dense(num_classes, activation="softmax")(pooled)

    return models.Model(inputs, outputs, name="cnn_vit_hybrid")

# Assemble the hybrid network from the backbone and the selected feature layer, then inspect it
hybrid_model = build_cnn_vit_hybrid(
    cnn_model=cnn_model,
    feature_layer_name=feature_layer_name,
    num_classes=10,
)
hybrid_model.summary()


## Task 4: Compile hybrid model

In [None]:

# Configure training: categorical cross-entropy loss (targets are expected to be
# one-hot encoded), accuracy metric, Adam optimizer with a 1e-4 learning rate.
hybrid_model.compile(
    loss="categorical_crossentropy",
    metrics=["accuracy"],
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
)


## Task 5: Define training configuration

In [None]:

# Example training settings (no training is run in this cell)
batch_size, epochs = 32, 5

# Stop when validation loss has not improved for 3 epochs and roll back to the best weights
callbacks = [
    tf.keras.callbacks.EarlyStopping(
        restore_best_weights=True,
        patience=3,
        monitor="val_loss",
    ),
]

print("Training config set:")
for label, value in (
    ("Batch size:", batch_size),
    ("Epochs:", epochs),
    ("Callbacks:", callbacks),
):
    print(label, value)
