# Module 3: Vision Transformers in Keras (Hybrid CNN + ViT)

In [None]:
import os, glob, numpy as np, matplotlib.pyplot as plt
from PIL import Image, ImageDraw

# Root of the two-class image dataset; each class lives in its own sub-directory.
DATASET_DIR = "./images_dataSAT"
DIR_NON_AGRI, DIR_AGRI = (
    os.path.join(DATASET_DIR, sub_dir) for sub_dir in ("class_0_non_agri", "class_1_agri")
)

def _ensure_dataset(n_per_class=12, size=64):
    """Create both class directories and fill them with synthetic PNGs if needed.

    The directories DIR_NON_AGRI and DIR_AGRI are created if missing. If both
    already contain at least one entry, no images are written. Otherwise
    ``n_per_class`` RGB images of ``size`` x ``size`` pixels are written into
    EACH class directory (overwriting files with the same names):

    * class 0 (non-agri): a white 2-px rectangle outline inset 10 px;
    * class 1 (agri): white 1-px horizontal lines every 10 px starting at y=5.

    Backgrounds are random solid colours drawn from a fixed-seed generator, so
    the generated set is reproducible across runs.

    Args:
        n_per_class: Number of images written per class (default 12).
        size: Side length in pixels of the square images (default 64).
            Assumed to exceed 20 so the 10-px inset rectangle is well-formed.
    """
    os.makedirs(DIR_NON_AGRI, exist_ok=True)
    os.makedirs(DIR_AGRI, exist_ok=True)
    if os.listdir(DIR_NON_AGRI) and os.listdir(DIR_AGRI):
        return
    rng = np.random.default_rng(0)
    for cls_dir, pattern in [(DIR_NON_AGRI, 'rect'), (DIR_AGRI, 'lines')]:
        for i in range(n_per_class):
            # Three separate draws in R, G, B order, converted to plain Python ints.
            background = tuple(int(rng.integers(20, 235)) for _ in range(3))
            img = Image.new("RGB", (size, size), background)
            draw = ImageDraw.Draw(img)
            if pattern == 'rect':
                draw.rectangle([10, 10, size - 10, size - 10], outline=(255, 255, 255), width=2)
            else:
                for y in range(5, size, 10):
                    draw.line([0, y, size, y], fill=(255, 255, 255), width=1)
            # Zero-padded index gives every image its own, sortable file name.
            img.save(os.path.join(cls_dir, f"img_{i:03d}.png"))

# If a pre-staged copy of the dataset is mounted under /mnt/data, seed the local
# DATASET_DIR from it; an already-existing local directory is never overwritten.
if os.path.exists('/mnt/data/images_dataSAT'):
    import shutil
    if not os.path.exists(DATASET_DIR):
        shutil.copytree(src='/mnt/data/images_dataSAT', dst=DATASET_DIR)

# Generate synthetic images unless both class directories already contain files.
_ensure_dataset()
print(f"Dataset ready at {os.path.abspath(DATASET_DIR)}")

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models

def build_cnn():
    """Return an uncompiled functional CNN for 64x64 RGB binary classification.

    Architecture: inputs scaled by 1/255 -> Conv2D(16, 3x3, ReLU, name
    'feature_conv') -> 2x2 max-pool -> Conv2D(32, 3x3, ReLU) ->
    GlobalAveragePooling2D (name 'gap') -> Dense(1, sigmoid).
    """
    inputs = layers.Input((64, 64, 3))
    stack = [
        layers.Rescaling(1. / 255),
        layers.Conv2D(16, 3, activation='relu', name='feature_conv'),
        layers.MaxPooling2D(),
        layers.Conv2D(32, 3, activation='relu'),
        layers.GlobalAveragePooling2D(name='gap'),
        layers.Dense(1, activation='sigmoid'),
    ]
    tensor = inputs
    for layer in stack:
        tensor = layer(tensor)
    return models.Model(inputs, tensor)

# Train the baseline CNN once and cache it on disk; later runs just reload it.
tmp_path = "cnn_base.keras"
if not tf.io.gfile.exists(tmp_path):
    m = build_cnn()
    m.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
    # label_mode='binary' yields float32 labels of shape (batch, 1), matching the
    # single sigmoid output and binary_crossentropy; the default 'int' mode
    # yields int32 labels of shape (batch,) instead.
    ds = tf.keras.utils.image_dataset_from_directory(
        DATASET_DIR, image_size=(64, 64), batch_size=16, label_mode='binary',
        validation_split=0.3, subset='training', seed=123)
    m.fit(ds, epochs=1, verbose=0)
    m.save(tmp_path)

cnn_model = tf.keras.models.load_model(tmp_path)
cnn_model.summary()
feature_layer_name = 'gap'
# Fail fast (get_layer raises ValueError) if a stale cached model lacks this layer.
cnn_model.get_layer(feature_layer_name)
print("Feature layer used:", feature_layer_name)

class _LearnedPositionEmbedding(layers.Layer):
    """Add a trainable vector per sequence position to a fixed-length token sequence.

    Used on tokens of shape ``(batch, num_patches, projection_dim)``; the
    ``(num_patches, projection_dim)`` table is broadcast over the batch axis.
    """

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection_dim = projection_dim

    def build(self, input_shape):
        # 'uniform' is the same default initializer layers.Embedding uses.
        self.position_embedding = self.add_weight(
            name="position_embedding",
            shape=(self.num_patches, self.projection_dim),
            initializer="uniform",
            trainable=True,
        )
        super().build(input_shape)

    def call(self, inputs):
        return inputs + self.position_embedding

    def get_config(self):
        config = super().get_config()
        config.update({"num_patches": self.num_patches, "projection_dim": self.projection_dim})
        return config


def build_cnn_vit_hybrid(cnn_backbone, feature_layer_name, projection_dim=64, transformer_layers=2, num_heads=2,
                         image_size=64, patch_size=8):
    """Build a small ViT-style binary classifier over non-overlapping image patches.

    Pipeline:

    1. Scale pixels by 1/255.
    2. Split the image into ``patch_size`` x ``patch_size`` patches.
    3. Apply a linear projection to ``projection_dim``.
    4. Add a learned position embedding.
    5. Run ``transformer_layers`` post-norm encoder blocks: multi-head
       self-attention plus a 2-layer GELU MLP, each followed by a residual Add
       and LayerNormalization.
    6. Mean-pool over the patches.
    7. Apply one sigmoid unit.

    Args:
        cnn_backbone: Accepted for API compatibility. NOTE(review): not used by
            the graph below, so the returned model is a patch-based ViT only;
            wire the backbone in if a true CNN+ViT hybrid is intended.
        feature_layer_name: Accepted for API compatibility; likewise unused.
        projection_dim: Token embedding width.
        transformer_layers: Number of encoder blocks.
        num_heads: Attention heads per block (each with ``key_dim=projection_dim``).
        image_size: Height and width of the square RGB input (default 64).
        patch_size: Side length of each square patch (default 8).

    Returns:
        An uncompiled Keras Model named ``"cnn_vit_hybrid"`` mapping inputs of
        shape ``(batch, image_size, image_size, 3)`` to ``(batch, 1)`` outputs.

    Raises:
        ValueError: if ``image_size`` is not a positive multiple of ``patch_size``.
    """
    if patch_size <= 0 or image_size <= 0 or image_size % patch_size:
        raise ValueError(
            f"image_size ({image_size}) must be a positive multiple of patch_size ({patch_size})")
    grid = image_size // patch_size            # patches per side
    num_patches = grid * grid
    patch_dim = patch_size * patch_size * 3    # length of one flattened RGB patch

    inp = layers.Input((image_size, image_size, 3))
    x = layers.Rescaling(1./255)(inp)
    # Patchify with Keras layers only: Keras 3 rejects raw TF ops on symbolic
    # tensors. Output matches tf.image.extract_patches(padding='VALID') then a
    # flatten: patches in row-major grid order, each flattened as (row, col, channel).
    patches = layers.Reshape((grid, patch_size, grid, patch_size, 3))(x)
    patches = layers.Permute((1, 3, 2, 4, 5))(patches)
    patches = layers.Reshape((num_patches, patch_dim))(patches)
    encoded = layers.Dense(projection_dim)(patches)
    encoded = _LearnedPositionEmbedding(num_patches, projection_dim)(encoded)
    for _ in range(transformer_layers):
        attn_out = layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim)(encoded, encoded)
        x2 = layers.Add()([encoded, attn_out])
        x2 = layers.LayerNormalization()(x2)
        mlp = layers.Dense(projection_dim * 2, activation='gelu')(x2)
        mlp = layers.Dense(projection_dim)(mlp)
        encoded = layers.Add()([x2, mlp])
        encoded = layers.LayerNormalization()(encoded)
    rep = layers.GlobalAveragePooling1D()(encoded)
    out = layers.Dense(1, activation='sigmoid')(rep)
    return models.Model(inp, out, name="cnn_vit_hybrid")

# Assemble the patch-transformer classifier and compile it with the same
# optimizer / loss / metric setup as the baseline CNN.
hybrid_model = build_cnn_vit_hybrid(cnn_backbone=cnn_model, feature_layer_name=feature_layer_name)
hybrid_model.compile(
    loss='binary_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
print("Hybrid model compiled.")