# Vision Transformer Image Captioning with Flash Attention

This notebook implements Approach 2 (Vision Transformer) for image captioning on the ArtEmis dataset. It is designed to run on Google Colab with a GPU runtime (specifically targeting T4 and Ampere-class GPUs, with Flash Attention).

## Setup
1.  Mount Google Drive.
2.  Ensure your dataset (CSV and images) is accessible in Drive, and update `CSV_PATH`, `IMG_DIR` and `OUTPUT_DIR` in the Configuration cell to match its location.
3.  Run the cells in order.

In [None]:
# Mount Google Drive at /content/drive; the dataset and output paths in the
# Configuration cell all live under this mount point.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
import os
import re
import string
import pickle
import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import layers, models, mixed_precision
from tensorflow.keras.layers import TextVectorization
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.decomposition import TruncatedSVD
import matplotlib.pyplot as plt

In [None]:
# --- Configuration ---
# UPDATE THESE PATHS TO MATCH YOUR DRIVE STRUCTURE
# CSV_PATH: dataset CSV read by load_and_clean_data.
# IMG_DIR: image root, used to build per-image paths when the CSV has no
#          'image_file' column.
# OUTPUT_DIR: directory the training cell creates and writes model weights to.
CSV_PATH = "/content/drive/MyDrive/artemis_dataset_release_v0.csv"
IMG_DIR = "/content/drive/MyDrive/wikiart"
OUTPUT_DIR = "/content/drive/MyDrive/iml_a3_models/approach_2"

# Hyperparameters
BATCH_SIZE = 64 # Increased batch size for T4
# Number of epochs passed to model.fit.
EPOCHS = 20
# (height, width) images are resized to; also the spatial part of the model input shape.
IMAGE_SIZE = (256, 256)
# Upper bound on the caption vocabulary (TextVectorization max_tokens).
VOCAB_SIZE = 5000
# Captions are vectorized to this many tokens; the model itself is built with
# MAX_LENGTH - 1 positions because input and target are the sequence shifted by one.
MAX_LENGTH = 50
# Width of the TF-IDF word embeddings and of the transformer projection.
EMBEDDING_DIM = 256
# When True the training cell switches Keras to the 'mixed_float16' policy.
USE_MIXED_PRECISION = True

## 1. Preprocessing

In [None]:
def load_and_clean_data(csv_path, image_dir, sample_size=None, stratify_col='art_style', check_files=False):
    """
    Loads the dataset CSV, resolves image paths and optionally sub-samples it.

    Args:
        csv_path: Path of the CSV file to read.
        image_dir: Root image directory. Only used when the CSV has no
            'image_file' column, in which case paths are built as
            <image_dir>/<art_style>/<painting>.jpg.
        sample_size: If set and smaller than the dataset, number of rows to keep.
        stratify_col: Column used to stratify the sub-sample.
        check_files: When True, rows whose image file does not exist on disk
            are dropped. Off by default because testing every path one by one
            can be slow for large datasets.

    Returns:
        The resulting pandas DataFrame.
    """
    print(f"Loading dataset from {csv_path}...")
    df = pd.read_csv(csv_path)

    if 'image_file' not in df.columns:
        # Construct path if not present. A plain comprehension avoids the
        # row-wise DataFrame.apply and also behaves sensibly on an empty frame.
        df['image_file'] = [
            os.path.join(image_dir, style, painting + '.jpg')
            for style, painting in zip(df['art_style'], df['painting'])
        ]

    if check_files:
        print("Checking for missing files...")
        # Non-string entries (e.g. NaN) can never be valid paths.
        valid_mask = df['image_file'].map(
            lambda path: isinstance(path, str) and os.path.exists(path)
        ).astype(bool)
        missing_count = int((~valid_mask).sum())
        if missing_count > 0:
            print(f"Warning: {missing_count} images not found. Removing them.")
            df = df[valid_mask]

    # Stratified Sampling
    if sample_size and sample_size < len(df):
        print(f"Performing stratified sampling to reduce size to {sample_size}...")
        try:
            df, _ = train_test_split(
                df,
                train_size=sample_size,
                stratify=df[stratify_col],
                random_state=42
            )
        except ValueError as e:
            # Stratification is not always possible (e.g. a class with a
            # single member); a plain random sample is used instead.
            print(f"Stratified sampling failed: {e}. Falling back to random sampling.")
            df = df.sample(n=sample_size, random_state=42)

    print(f"Final dataset size: {len(df)}")
    return df

def custom_standardization(input_string):
    """
    Standardizes text for vectorization: lowercases it, then deletes every
    character listed in ``string.punctuation``.

    NOTE: ``string.punctuation`` includes '<' and '>', so those characters
    are removed as well.
    """
    punctuation_class = "[" + re.escape(string.punctuation) + "]"
    lowered = tf.strings.lower(input_string)
    stripped = tf.strings.regex_replace(lowered, punctuation_class, "")
    return stripped

def get_augmentation_layer():
    """Builds the image-augmentation pipeline as a Keras Sequential model."""
    augmentation_steps = [
        tf.keras.layers.RandomFlip("horizontal"),
        tf.keras.layers.RandomRotation(0.1),
        tf.keras.layers.RandomContrast(0.1),
    ]
    return tf.keras.Sequential(augmentation_steps)

def create_tf_dataset(
    df,
    image_size=(256, 256),
    batch_size=32,
    vocab_size=5000,
    max_length=50,
    validation_split=0.2,
    augment=False,
    vectorizer=None
):
    """
    Creates the training and validation tf.data.Dataset pipelines.

    Args:
        df: DataFrame with an 'image_file' column (image paths) and an
            'utterance' column (captions).
        image_size: (height, width) every image is resized to.
        batch_size: Batch size of both datasets.
        vocab_size: max_tokens of the TextVectorization layer created when no
            vectorizer is supplied.
        max_length: Token length captions are padded/truncated to before the
            one-step shift, so model inputs and targets have max_length - 1
            tokens.
        validation_split: Fraction of rows held out for validation.
        augment: Apply random image augmentation to the training set.
        vectorizer: Optional already-adapted TextVectorization layer. When
            None, a new one is created and adapted on the training captions.

    Returns:
        (train_ds, val_ds, vectorizer). Each dataset yields
        ((image, caption_in), caption_out) batches.
    """

    # Prepare paths and captions, wrapping each caption in start/end tokens
    image_paths = df['image_file'].values
    captions = [f"<start> {cap} <end>" for cap in df['utterance'].values]

    # Split data
    train_paths, val_paths, train_caps, val_caps = train_test_split(
        image_paths, captions, test_size=validation_split, random_state=42
    )

    # Text Vectorization
    if vectorizer is None:
        vectorizer = TextVectorization(
            max_tokens=vocab_size,
            output_mode='int',
            output_sequence_length=max_length,
            standardize=custom_standardization
        )
        print("Adapting text vectorizer...")
        vectorizer.adapt(train_caps)
    else:
        print("Using provided vectorizer.")

    def read_image(image_path):
        img = tf.io.read_file(image_path)
        img = tf.image.decode_jpeg(img, channels=3)
        # Convert while the pixels are still uint8: convert_image_dtype only
        # rescales integer inputs, so calling it after resize (which already
        # returns float32) would leave the values in [0, 255].
        img = tf.image.convert_image_dtype(img, tf.float32)
        img = tf.image.resize(img, image_size)
        return img

    aug_layer = get_augmentation_layer() if augment else None

    def process_data(image_path, caption, training=False):
        img = read_image(image_path)
        if training and aug_layer is not None:
            # Be explicit: the random layers are only active in training mode.
            img = aug_layer(img, training=True)
        cap = vectorizer(caption)
        # Teacher forcing: the model sees tokens [0, n-1) and predicts [1, n).
        cap_in = cap[:-1]
        cap_out = cap[1:]
        return (img, cap_in), cap_out

    def make_dataset(paths, caps, is_training=False, shuffle=False):
        dataset = tf.data.Dataset.from_tensor_slices((paths, caps))
        if shuffle:
            # Shuffle the (path, caption) strings before any image is decoded
            # so a full-size buffer stays small, and draw a new order each epoch.
            dataset = dataset.shuffle(len(paths), reshuffle_each_iteration=True)
        dataset = dataset.map(
            lambda p, c: process_data(p, c, training=is_training),
            num_parallel_calls=tf.data.AUTOTUNE
        )
        dataset = dataset.batch(batch_size)
        dataset = dataset.prefetch(tf.data.AUTOTUNE)
        return dataset

    print("Creating training dataset...")
    train_ds = make_dataset(train_paths, train_caps, is_training=augment, shuffle=True)

    print("Creating validation dataset...")
    val_ds = make_dataset(val_paths, val_caps, is_training=False)

    return train_ds, val_ds, vectorizer

## 2. Embeddings

In [None]:
def get_tfidf_embeddings(vectorizer, captions, embedding_dim=256):
    """
    Builds the TF-IDF based embedding matrix for the vectorizer's vocabulary.

    Raises:
        ValueError: If ``captions`` is None.
    """
    print("Generating TF-IDF embeddings...")
    if captions is not None:
        return compute_tfidf_matrix(
            captions, vectorizer.get_vocabulary(), embedding_dim=embedding_dim
        )
    raise ValueError("Captions list is required for TF-IDF embeddings.")

def compute_tfidf_matrix(captions, vocab, embedding_dim=256):
    """
    Computes one dense embedding per vocabulary entry from TF-IDF statistics.

    A TF-IDF matrix restricted to ``vocab`` is fitted on ``captions``; its
    transpose (one row per word) is reduced to ``embedding_dim`` columns with
    truncated SVD.

    Args:
        captions: Iterable of caption strings.
        vocab: Vocabulary list; row i of the result belongs to vocab[i].
        embedding_dim: Number of SVD components, i.e. the embedding width.

    Returns:
        float32 array of shape (len(vocab), embedding_dim). The '' and
        '[UNK]' entries have no TF-IDF column and receive small random vectors.
    """
    print(f"Computing TF-IDF on {len(captions)} captions...")
    special_tokens = {'', '[UNK]'}
    clean_vocab = [w for w in vocab if w not in special_tokens]

    # Mirror the caption standardization used for the vocabulary (lowercase,
    # drop string.punctuation, split on whitespace). With the default
    # tokenizer "don't" would be counted as "don" and "t", leaving the
    # vocabulary word "dont" with an all-zero column and hence a zero embedding.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    tfidf = TfidfVectorizer(
        vocabulary=clean_vocab,
        preprocessor=lambda text: str(text).lower().translate(strip_punctuation),
        token_pattern=r"(?u)\S+",
    )
    tfidf_matrix = tfidf.fit_transform(captions)

    print(f"Reducing dimension to {embedding_dim} using SVD...")
    svd = TruncatedSVD(n_components=embedding_dim, random_state=42)
    word_features = svd.fit_transform(tfidf_matrix.T)

    final_matrix = np.zeros((len(vocab), embedding_dim), dtype="float32")
    feature_index_map = {word: i for i, word in enumerate(clean_vocab)}

    # Seeded so the fallback rows are reproducible, like the SVD above.
    rng = np.random.default_rng(42)
    for i, word in enumerate(vocab):
        feature_row = feature_index_map.get(word)
        if feature_row is not None:
            final_matrix[i] = word_features[feature_row]
        else:
            final_matrix[i] = rng.normal(scale=0.1, size=embedding_dim)
    return final_matrix

## 3. Model Architecture (with Flash Attention)

In [None]:
class PatchCreation(layers.Layer):
    """Cuts images into non-overlapping square patches and flattens each patch."""

    def __init__(self, patch_size, **kwargs):
        super().__init__(**kwargs)
        self.patch_size = patch_size

    def call(self, images):
        # Window and stride are identical, so patches tile the image without overlap.
        window = [1, self.patch_size, self.patch_size, 1]
        patches = tf.image.extract_patches(
            images=images,
            sizes=window,
            strides=window,
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        # Collapse the patch grid into a single sequence axis per image.
        flat_shape = [tf.shape(images)[0], -1, patches.shape[-1]]
        return tf.reshape(patches, flat_shape)

    def get_config(self):
        return {**super().get_config(), "patch_size": self.patch_size}

class PatchEncoder(layers.Layer):
    """Projects flattened patches linearly and adds a learned position embedding."""

    def __init__(self, num_patches, projection_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_patches = num_patches
        self.projection = layers.Dense(units=projection_dim)
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )

    def call(self, patch):
        position_ids = tf.range(start=0, limit=self.num_patches, delta=1)
        projected = self.projection(patch)
        # The position table is added to every sample via broadcasting.
        return projected + self.position_embedding(position_ids)

    def get_config(self):
        return {
            **super().get_config(),
            "num_patches": self.num_patches,
            "projection_dim": self.projection.units,
        }

class TransformerEncoderBlock(layers.Layer):
    """
    Transformer encoder block: self-attention and a two-layer feed-forward
    network, each followed by dropout, a residual connection and layer
    normalization.

    Args:
        embed_dim: Width of the block's input/output; also passed as the
            attention ``key_dim`` (the per-head size).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config does not depend on
        # attributes of the sub-layers.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        # layers.MultiHeadAttention has no `use_flash_attention` constructor
        # argument, so passing it makes layer construction fail. The layer is
        # built with its default attention implementation instead.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = models.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=False):
        # Self-attention sub-layer (query and value are both `inputs`).
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        # Position-wise feed-forward sub-layer.
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate
        })
        return config

class TransformerDecoderBlock(layers.Layer):
    """
    Transformer decoder block: self-attention over the caption tokens,
    cross-attention over the encoder outputs and a two-layer feed-forward
    network. Each sub-layer is followed by dropout, a residual connection
    and layer normalization.

    Args:
        embed_dim: Width of the block's input/output; also passed as the
            attention ``key_dim`` (the per-head size).
        num_heads: Number of attention heads.
        ff_dim: Hidden width of the feed-forward network.
        rate: Dropout rate.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments so get_config does not depend on
        # attributes of the sub-layers.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate
        # layers.MultiHeadAttention has no `use_flash_attention` constructor
        # argument, so passing it makes layer construction fail. Both layers
        # are built with the default attention implementation instead.
        self.att1 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.att2 = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = models.Sequential(
            [layers.Dense(ff_dim, activation="relu"), layers.Dense(embed_dim),]
        )
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)
        self.dropout3 = layers.Dropout(rate)

    def call(self, inputs, encoder_outputs, training=False, use_causal_mask=False):
        # Self-attention over the caption; optionally causal so a position
        # cannot attend to later positions.
        attn1 = self.att1(inputs, inputs, use_causal_mask=use_causal_mask)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(inputs + attn1)

        # Cross-attention: caption positions query the encoder outputs.
        attn2 = self.att2(out1, encoder_outputs)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(out1 + attn2)

        # Position-wise feed-forward sub-layer.
        ffn_output = self.ffn(out2)
        ffn_output = self.dropout3(ffn_output, training=training)
        return self.layernorm3(out2 + ffn_output)

    def get_config(self):
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate
        })
        return config

class TokenAndPositionEmbedding(layers.Layer):
    """
    Sums a token embedding and a learned position embedding.

    When ``embedding_matrix`` is given, the token embedding is initialised
    from it and frozen; otherwise it is a regular trainable embedding.
    """

    def __init__(self, maxlen, vocab_size, embed_dim, embedding_matrix=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim

        token_kwargs = {"input_dim": vocab_size, "output_dim": embed_dim}
        if embedding_matrix is not None:
            # Pre-computed vectors are loaded as-is and excluded from training.
            token_kwargs.update(weights=[embedding_matrix], trainable=False)
        self.token_emb = layers.Embedding(**token_kwargs)

        self.pos_emb = layers.Embedding(input_dim=maxlen, output_dim=embed_dim)

    def call(self, x):
        seq_len = tf.shape(x)[-1]
        position_ids = tf.range(start=0, limit=seq_len, delta=1)
        position_vectors = self.pos_emb(position_ids)
        token_vectors = self.token_emb(x)
        return token_vectors + position_vectors

    def get_config(self):
        return {
            **super().get_config(),
            "maxlen": self.pos_emb.input_dim,
            "vocab_size": self.vocab_size,
            "embed_dim": self.embed_dim,
        }

def build_vit_caption_model(
    input_shape=(256, 256, 3),
    patch_size=16,
    num_patches=None,
    projection_dim=256,
    num_heads=4,
    transformer_layers=4,
    vocab_size=5000,
    max_length=50,
    ff_dim=512,
    dropout_rate=0.1,
    embedding_matrix=None
):
    """
    Builds the ViT-encoder / transformer-decoder captioning model.

    Args:
        input_shape: (height, width, channels) of the input images.
        patch_size: Side length of the square image patches.
        num_patches: Number of patches per image. When None it is derived
            from ``input_shape`` and ``patch_size``; an explicit value must
            match that number.
        projection_dim: Width of patch and token embeddings.
        num_heads: Attention heads per transformer block.
        transformer_layers: Number of encoder blocks and of decoder blocks.
        vocab_size: Size of the output vocabulary.
        max_length: Length of the caption input sequence.
        ff_dim: Hidden width of the feed-forward networks.
        dropout_rate: Dropout rate inside the transformer blocks.
        embedding_matrix: Optional pre-computed token embedding matrix.

    Returns:
        A Keras Model mapping [images, caption_tokens] to per-position
        vocabulary logits.

    Raises:
        ValueError: If ``num_patches`` does not match the patch grid implied
            by ``input_shape`` and ``patch_size``.
    """
    # "VALID" patch extraction with stride == patch size yields
    # floor(H / p) * floor(W / p) patches.
    expected_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)
    if num_patches is None:
        num_patches = expected_patches
    elif num_patches != expected_patches:
        raise ValueError(
            f"num_patches={num_patches} does not match the {expected_patches} patches "
            f"produced by input_shape={input_shape} with patch_size={patch_size}."
        )

    # --- Encoder (ViT) ---
    inputs = layers.Input(shape=input_shape)
    patches = PatchCreation(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        encoded_patches = TransformerEncoderBlock(
            projection_dim, num_heads, ff_dim, dropout_rate
        )(encoded_patches)

    # --- Decoder ---
    caption_inputs = layers.Input(shape=(max_length,), dtype="int64")
    x = TokenAndPositionEmbedding(
        max_length, vocab_size, projection_dim, embedding_matrix=embedding_matrix
    )(caption_inputs)

    for _ in range(transformer_layers):
        x = TransformerDecoderBlock(
            projection_dim, num_heads, ff_dim, dropout_rate
        )(x, encoded_patches, use_causal_mask=True)

    # Output logits are kept in float32 even under a mixed-precision policy
    # so the loss is computed on full-precision values.
    outputs = layers.Dense(vocab_size, dtype="float32")(x)

    model = models.Model(inputs=[inputs, caption_inputs], outputs=outputs)
    return model

def masked_loss(y_true, y_pred):
    """
    Sparse categorical cross-entropy (from logits) averaged over the
    positions whose target token id is not 0.
    """
    per_token_loss = tf.keras.losses.SparseCategoricalCrossentropy(
        from_logits=True, reduction='none'
    )(y_true, y_pred)
    # 1.0 where the target is a real token, 0.0 where it is id 0.
    keep = tf.cast(tf.math.not_equal(y_true, 0), dtype=per_token_loss.dtype)
    masked_total = tf.reduce_sum(per_token_loss * keep)
    return masked_total / tf.reduce_sum(keep)

def masked_acc_percent(y_true, y_pred):
    """
    Percentage of positions with a non-zero target id whose arg-max
    prediction equals the target.
    """
    keep = tf.cast(tf.math.not_equal(y_true, 0), tf.float32)
    predicted_ids = tf.argmax(y_pred, axis=-1)
    target_ids = tf.cast(y_true, predicted_ids.dtype)
    hits = tf.cast(target_ids == predicted_ids, tf.float32)
    return 100.0 * tf.reduce_sum(hits * keep) / tf.reduce_sum(keep)

## 4. Training Loop

In [None]:
# Create output directory
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Enable Mixed Precision
if USE_MIXED_PRECISION:
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_global_policy(policy)
    print("Mixed precision enabled.")

# 1. Load and Preprocess Data
df = load_and_clean_data(CSV_PATH, IMG_DIR, sample_size=None) # Set sample_size for testing if needed

train_ds, val_ds, vectorizer = create_tf_dataset(
    df,
    image_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    vocab_size=VOCAB_SIZE,
    max_length=MAX_LENGTH,
    augment=True
)

# VOCAB_SIZE is only an upper bound: on a small sample the adapted vectorizer
# can hold fewer tokens. The embedding matrix has one row per vocabulary
# entry, so the model is sized from the actual vocabulary to keep them in sync.
vocab_size = len(vectorizer.get_vocabulary())
print(f"Vocabulary size: {vocab_size}")

# 2. Prepare Embeddings (TF-IDF)
captions = df['utterance'].tolist()
embedding_matrix = get_tfidf_embeddings(vectorizer, captions, embedding_dim=EMBEDDING_DIM)

# 3. Build Model
print("Building Vision Transformer Model with Flash Attention...")
model = build_vit_caption_model(
    input_shape=IMAGE_SIZE + (3,),
    vocab_size=vocab_size,
    # Inputs/targets are the caption shifted by one token, hence one shorter.
    max_length=MAX_LENGTH - 1,
    transformer_layers=4,
    num_heads=4,
    projection_dim=EMBEDDING_DIM,
    ff_dim=512,
    dropout_rate=0.1,
    embedding_matrix=embedding_matrix
)

# 4. Compile Model
optimizer = tf.keras.optimizers.Adam(learning_rate=1e-4)
if USE_MIXED_PRECISION:
    optimizer = mixed_precision.LossScaleOptimizer(optimizer)

model.compile(
    optimizer=optimizer,
    loss=masked_loss,
    metrics=[masked_acc_percent]
)

model.summary()

# 5. Callbacks: keep the best weights on disk, stop when validation loss
# stalls, and lower the learning rate on plateaus.
callbacks = [
    ModelCheckpoint(
        filepath=os.path.join(OUTPUT_DIR, "best_model.weights.h5"),
        save_best_only=True,
        save_weights_only=True,
        monitor="val_loss",
        mode="min",
        verbose=1
    ),
    EarlyStopping(
        monitor="val_loss",
        patience=5,
        restore_best_weights=True,
        verbose=1
    ),
    ReduceLROnPlateau(
        monitor="val_loss",
        factor=0.2,
        patience=3,
        verbose=1
    )
]

# 6. Train
print(f"Starting training for {EPOCHS} epochs...")
history = model.fit(
    train_ds,
    epochs=EPOCHS,
    validation_data=val_ds,
    callbacks=callbacks
)

# Save final model
model.save_weights(os.path.join(OUTPUT_DIR, "final_model.weights.h5"))

In [None]:
# Plot History: loss on the left panel, masked accuracy on the right.
loss = history.history['loss']
val_loss = history.history['val_loss']
acc = history.history['masked_acc_percent']
val_acc = history.history['val_masked_acc_percent']

epochs_range = range(1, len(loss) + 1)

plt.figure(figsize=(12, 5))

# (subplot index, training curve, validation curve, labels, title)
panels = [
    (1, loss, val_loss, 'Training loss', 'Validation loss',
     'Training and validation loss'),
    (2, acc, val_acc, 'Training accuracy', 'Validation accuracy',
     'Training and validation accuracy'),
]
for panel_index, train_curve, val_curve, train_label, val_label, panel_title in panels:
    plt.subplot(1, 2, panel_index)
    plt.plot(epochs_range, train_curve, 'bo-', label=train_label)
    plt.plot(epochs_range, val_curve, 'ro-', label=val_label)
    plt.title(panel_title)
    plt.legend()

plt.show()