# Q7 — Vision Transformers in Keras (Hybrid)
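Loads a stand-in "pre-trained" CNN (built and saved earlier in this notebook for the demo), picks a feature layer, builds a CNN+ViT hybrid on top of it, then compiles the hybrid and prints an example training configuration.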

In [None]:
from tensorflow.keras import layers, models
from tensorflow.keras.models import load_model
import numpy as np

# Demo stand-in for a real checkpoint: assemble a tiny CNN, write it to disk
# and read it back, so step 1 goes through load_model() exactly as it would
# for a genuinely pre-trained network.

# 2) The layer the hybrid will tap. Bound up front so the demo CNN can give
#    that layer this exact name; it is reported after the summary below.
feature_layer_name = 'feat_layer'

demo = models.Sequential()
demo.add(layers.Input(shape=(64, 64, 3)))
demo.add(layers.Conv2D(filters=16, kernel_size=3, activation='relu'))
demo.add(layers.MaxPooling2D())
demo.add(layers.Conv2D(filters=32, kernel_size=3, activation='relu', name=feature_layer_name))
demo.add(layers.GlobalAveragePooling2D())
demo.add(layers.Dense(units=2, activation='softmax'))
demo.save('demo_cnn.h5')

# 1) Load the (stand-in) pre-trained CNN and print its architecture.
cnn_model = load_model('demo_cnn.h5')
print('Loaded cnn_model summary:')
cnn_model.summary()

# Report which layer's output will be tokenized for the transformer.
print('feature_layer_name:', feature_layer_name)

# 3) Define build_cnn_vit_hybrid
def build_cnn_vit_hybrid(cnn, feature_layer_name, num_classes=2, proj_dim=32, num_heads=2, depth=1):
    """Build a CNN + Vision-Transformer hybrid classifier on top of ``cnn``.

    The CNN is truncated at ``feature_layer_name``. That layer's output is
    reshaped into a sequence of tokens (all non-channel axes are flattened,
    so a (batch, H, W, C) map becomes (batch, H*W, C)), linearly projected to
    ``proj_dim``, and passed through ``depth`` post-norm transformer encoder
    blocks. The result is averaged over the token axis and classified by a
    softmax head.

    Args:
        cnn: A built Keras model used as the backbone. NOTE: it is frozen in
            place (``cnn.trainable = False``) as a side effect of this call.
        feature_layer_name: Name of the layer inside ``cnn`` whose output is
            tokenized. Its last (channel) dimension must be statically known.
        num_classes: Number of softmax output units.
        proj_dim: Token embedding width used throughout the encoder.
        num_heads: Attention heads per MultiHeadAttention layer. Each head
            uses ``key_dim=proj_dim`` (not ``proj_dim // num_heads``).
        depth: Number of encoder blocks to stack (0 gives pooling + head only).

    Returns:
        An uncompiled Keras ``Model`` named ``'cnn_vit_hybrid'`` that takes
        the same input tensor(s) as ``cnn``.

    Raises:
        Propagates the error from ``cnn.get_layer`` (a ValueError in Keras)
        if no layer is called ``feature_layer_name``.
    """
    # Freeze the backbone so only the transformer head is trained.
    cnn.trainable = False
    feat_extractor = models.Model(
        inputs=cnn.input, outputs=cnn.get_layer(feature_layer_name).output
    )
    inputs = cnn.input
    # training=False keeps BatchNormalization/Dropout layers in the backbone
    # in inference mode, including after a later unfreeze for fine-tuning.
    x = feat_extractor(inputs, training=False)  # e.g. (B, H, W, C)
    # Flatten every non-channel axis into a token axis -> (B, tokens, C).
    x = layers.Reshape((-1, int(x.shape[-1])))(x)
    x = layers.Dense(proj_dim)(x)
    # NOTE(review): no positional embedding is added, so the encoder + average
    # pooling are order-agnostic over tokens. Spatial layout reaches the head
    # only through the CNN features themselves.
    for _ in range(depth):
        # Sublayer 1: self-attention, residual, then LayerNorm (post-norm).
        attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=proj_dim)(x, x)
        x = layers.LayerNormalization()(layers.Add()([x, attn]))
        # Sublayer 2: position-wise MLP, residual, then LayerNorm. Normalizing
        # here too keeps both sublayers on the same post-norm pattern, so the
        # next block (or the pooling head) always sees normalized activations.
        mlp = layers.Dense(proj_dim * 2, activation='gelu')(x)
        mlp = layers.Dense(proj_dim)(mlp)
        x = layers.LayerNormalization()(layers.Add()([x, mlp]))
    x = layers.GlobalAveragePooling1D()(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)
    return models.Model(inputs, outputs, name='cnn_vit_hybrid')

# 4) Build the hybrid on top of the loaded CNN and compile it.
# Example training hyper-parameters, kept as named values so a later
# hybrid_model.fit(..., epochs=EPOCHS, batch_size=BATCH_SIZE) uses the same
# numbers that are printed here.
EPOCHS = 2
BATCH_SIZE = 8

hybrid_model = build_cnn_vit_hybrid(cnn_model, feature_layer_name, num_classes=2)
# categorical_crossentropy expects one-hot labels of shape (batch, num_classes);
# switch to sparse_categorical_crossentropy if the labels are integer class ids.
hybrid_model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
print(f'Hybrid compiled. Training config: epochs={EPOCHS}, batch_size={BATCH_SIZE} (example)')