In [None]:
# Uncomment and run this if you haven't installed TensorFlow yet
# !pip install tensorflow

## **IMPORTING LIBRARIES**

In [None]:
import os
import pickle
import numpy as np
import tensorflow as tf
import time

from tqdm.notebook import tqdm
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.utils import to_categorical, plot_model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, add

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

In [None]:
# Mount Google Drive (Colab only); the dataset paths used below live under /content/drive.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Root folder of the Flickr8k dataset on the mounted Drive.
BASE_DIR = '/content/drive/MyDrive/Flickr8k'
# Derived from BASE_DIR so the two paths cannot drift apart if the dataset is moved.
WORKING_DIR = os.path.join(BASE_DIR, 'working')

## **IMAGE FEATURE EXTRACTION**

In [None]:
# Build the feature extractor: InceptionV3 truncated at the 'mixed10' layer,
# whose spatial output is what the attention mechanism consumes later on.
model = InceptionV3()
model = Model(
    inputs=model.inputs,
    outputs=model.get_layer('mixed10').output
)
# summary() prints the table itself and returns None, so wrapping it in
# print() only appended a stray "None" line to the cell output.
model.summary()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/inception_v3/inception_v3_weights_tf_dim_ordering_tf_kernels.h5
[1m96112376/96112376[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m5s[0m 0us/step


None


In [None]:
# extracting features from image (SPATIAL FEATURES for ATTENTION)

features = {}
directory = os.path.join(BASE_DIR, 'Images')

# Only hand real image files to load_img: a stray non-image entry in the
# folder (hidden file, sub-directory, notes) would otherwise abort the whole run.
image_extensions = ('.jpg', '.jpeg', '.png')
image_names = [
    name for name in os.listdir(directory)
    if name.lower().endswith(image_extensions)
]

for img_name in tqdm(image_names):
    # load image at the input resolution used for InceptionV3 here
    img_path = os.path.join(directory, img_name)
    image = load_img(img_path, target_size=(299, 299))

    # convert image to array
    image = img_to_array(image)

    # add a leading batch axis to match the model input
    image = np.expand_dims(image, axis=0)

    # preprocess for InceptionV3
    image = preprocess_input(image)

    # extract features
    feature = model.predict(image, verbose=0)
    # feature shape: (1, 8, 8, 2048)

    # flatten the 8x8 grid to 64 locations -> (64, 2048) for attention
    feature = feature.reshape((64, 2048))

    # image ID is the file name up to the first dot (same rule as the caption parser)
    image_id = img_name.split('.')[0]

    # store feature
    features[image_id] = feature

  0%|          | 0/8115 [00:00<?, ?it/s]

In [None]:
# Store the extracted features in a pickle file so the slow CNN pass can be skipped later.
os.makedirs(WORKING_DIR, exist_ok=True)
# The context manager closes (and flushes) the file even if pickling fails part-way;
# the previous bare open() left the handle to the garbage collector.
with open(os.path.join(WORKING_DIR, 'features.pkl'), 'wb') as f:
    pickle.dump(features, f)

In [None]:
# Spot-check one stored entry: each feature was reshaped to (64, 2048) during extraction.
any_key = list(features.keys())[0]
print(features[any_key].shape)

(64, 2048)


In [None]:
# Reload the cached image features from disk.
# NOTE: pickle.load executes arbitrary code from the file; only load files you wrote yourself.
features_path = os.path.join(WORKING_DIR, 'features.pkl')
with open(features_path, 'rb') as f:
    features = pickle.load(f)

## **CAPTION CLEANING**

### Load the captions file

In [None]:
# path to captions file
caption_file = os.path.join(BASE_DIR, 'Flickr8k.token.txt')

# Load the raw captions text. The encoding is stated explicitly so decoding
# does not depend on the locale of the machine running the notebook.
with open(caption_file, 'r', encoding='utf-8') as f:
    captions = f.read()

### Create image → captions mapping

In [None]:
# Build {image_id: [caption, ...]}. Each line is split on a tab into an image
# reference and a caption; the reference is cut at '#' and at the first '.'
# to obtain the image id.
mapping = {}

for line in captions.split('\n'):
    # maxsplit=1 keeps the caption whole even if it contains a tab itself
    image_caption = line.split('\t', 1)

    # blank, whitespace-only or otherwise tab-less lines carry no caption;
    # previously anything but a zero-length line raised IndexError here
    if len(image_caption) < 2:
        continue

    image_ref, caption = image_caption
    image_id = image_ref.split('#')[0].split('.')[0]

    # create the list on first sight of an image, then store the caption
    mapping.setdefault(image_id, []).append(caption)


### Clean captions

In [None]:
import string

def clean_captions(mapping):
    """Normalise every caption in ``mapping`` in place.

    Each caption is lower-cased, stripped of punctuation and split on
    whitespace. A word is kept only if it

    * is purely alphabetic (this drops numbers and any token mixing letters
      with digits or other symbols),
    * is not the literal word ``start`` or ``end``, and
    * is longer than one character, or is the article ``a``.

    The surviving words are re-joined and wrapped as
    ``'<start> ... <end>'``. Because punctuation removal turns ``<start>`` /
    ``<end>`` into ``start`` / ``end``, which are then dropped, running the
    function again on already-cleaned captions leaves them unchanged.

    Args:
        mapping: dict of image id -> list of caption strings. The lists are
            modified in place.

    Returns:
        None.
    """
    # Built once: the table and the drop-list are identical for every caption.
    strip_punctuation = str.maketrans('', '', string.punctuation)
    dropped_words = ('start', 'end')

    for captions in mapping.values():
        for i, caption in enumerate(captions):
            words = caption.lower().translate(strip_punctuation).split()

            # single pass applying the three per-word filters
            kept = [
                w for w in words
                if w.isalpha()
                and w not in dropped_words
                and (len(w) > 1 or w == 'a')
            ]

            captions[i] = '<start> ' + ' '.join(kept) + ' <end>'


In [None]:
# Clean all captions; clean_captions rewrites the caption lists inside `mapping` in place.
clean_captions(mapping)

### Remove captions for images without features

In [None]:
# Keep only the captions whose image has extracted features.
image_ids = set(features)

# Build the filtered dictionary in one pass, then rebind `mapping` to it.
new_mapping = {
    img: caps
    for img, caps in mapping.items()
    if img in image_ids
}
mapping = new_mapping

### Save cleaned captions

In [None]:
# Define where the cleaned captions will be cached
file_path = os.path.join(WORKING_DIR, 'captions.pkl')

# Store the cleaned captions mapping in a pickle file
with open(file_path, 'wb') as f:
    pickle.dump(mapping, f)



In [None]:
# Load the captions mapping back from the pickle file.
# `file_path` is defined in the cell that saves the captions.
# NOTE: pickle.load runs code embedded in the file; only load files you created yourself.
with open(file_path, 'rb') as f:
    mapping = pickle.load(f)

In [None]:
# Number of images that have captions in the mapping
len(mapping)


8091

In [None]:
# Sanity check: eyeball the cleaned captions of one image.
import random

# Pick a random image id (unseeded, so a different image on every run)
img_id = random.choice(list(mapping.keys()))
print("Image ID:", img_id)

# Print every caption associated with the selected image
for cap in mapping[img_id]:
    print(cap)

Image ID: 186890601_8a6b0f1769
<start> a man in a maroon bathing suit swings on a rope on a lake <end>
<start> a man in red shorts swinging on a rope over a lake <end>
<start> a man is swinging on a rope above the water <end>
<start> a man is swinging on a rope over water <end>
<start> this man is swinging on the rope swing out over the blue water <end>


### **TOKENIZATION**

In [None]:
# Sanity check before tokenization: every caption must begin with '<start> '
# and finish with ' <end>'.
assert all(
    c.startswith('<start> ') and c.endswith(' <end>')
    for caps in mapping.values()
    for c in caps
)

In [None]:
# Flatten the per-image caption lists into one list for tokenizer fitting
all_captions = [caption for caps in mapping.values() for caption in caps]

# filters='' turns off the tokenizer's character filtering, so the angle
# brackets of '<start>' / '<end>' survive and each marker stays one token
tokenizer = Tokenizer(filters='')
tokenizer.fit_on_texts(all_captions)

# Persist the fitted tokenizer so caption generation can reuse the same vocabulary
tokenizer_path = os.path.join(WORKING_DIR, 'tokenizer.pkl')
with open(tokenizer_path, 'wb') as f:
    pickle.dump(tokenizer, f)

print("Tokenizer saved successfully")

# +1 because index 0 is reserved for padding and is not part of word_index
vocab_size = 1 + len(tokenizer.word_index)

Tokenizer saved successfully


In [None]:
# Length of the longest caption in whitespace-separated tokens
# (the <start>/<end> markers are part of each caption and are counted too)
max_length = max(len(caption.split()) for caption in all_captions)
max_length

37

### **TRAIN / VALIDATION SPLIT**

In [None]:
def load_image_list(path):
    """Read a split file and return the set of image ids it lists.

    Every non-blank line is stripped of surrounding whitespace and cut at the
    first '.', which removes the file extension (the same rule used for the
    feature and caption keys).

    Args:
        path: path of a text file with one image file name per line.

    Returns:
        set of image-id strings; empty for an empty file. Blank lines are
        ignored (they used to add '' to the set).
    """
    with open(path, 'r') as f:
        stripped_lines = (line.strip() for line in f)
        return {name.split('.')[0] for name in stripped_lines if name}

# Image ids of the training and validation (dev) splits listed in the dataset's split files
train_images = load_image_list(os.path.join(BASE_DIR, 'Flickr_8k.trainImages.txt'))
val_images   = load_image_list(os.path.join(BASE_DIR, 'Flickr_8k.devImages.txt'))

print("Train list size:", len(train_images))
print("Val list size:", len(val_images))


Train list size: 6000
Val list size: 1000


In [None]:
# Split the caption mapping by the official image lists. An id found in the
# train list goes to train only; the validation list is consulted otherwise.
train_captions = {
    image_id: caps
    for image_id, caps in mapping.items()
    if image_id in train_images
}
val_captions = {
    image_id: caps
    for image_id, caps in mapping.items()
    if image_id not in train_images and image_id in val_images
}

print("Training images:", len(train_captions))
print("Validation images:", len(val_captions))


Training images: 6000
Validation images: 1000


## **Building the Model**

In [None]:
class CNN_Encoder(tf.keras.Model):
    """Projects pre-extracted CNN features into the embedding space.

    The CNN itself is not part of this model: the input is the already
    extracted feature grid, and the encoder is a single Dense layer followed
    by ReLU.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # Dense acts on the last axis: (batch_size, 64, 2048) -> (batch_size, 64, embedding_dim)
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        """Return relu(fc(x)) for a batch of feature grids."""
        return tf.nn.relu(self.fc(x))

In [None]:
class BahdanauAttention(tf.keras.Model):
    """Additive attention over the image regions.

    Scores each region as V(tanh(W1(features) + W2(hidden))), normalises the
    scores across regions with a softmax and returns the weighted sum of the
    region features.
    """

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Compute the context vector for one decoding step.

        Args:
            features: encoded regions, (batch_size, 64, embedding_dim).
            hidden: decoder hidden state, (batch_size, hidden_size).

        Returns:
            (context_vector, attention_weights): the weighted sum of
            ``features`` over axis 1, and the per-region weights produced by
            the softmax (one weight per region, trailing axis of size 1).
        """
        # Insert a region axis so the projected state broadcasts over all regions
        query = tf.expand_dims(hidden, 1)

        # Additive score followed by a softmax across the region axis
        energy = tf.nn.tanh(self.W1(features) + self.W2(query))
        attention_weights = tf.nn.softmax(self.V(energy), axis=1)

        # Weighted sum of the region features
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)

        return context_vector, attention_weights

In [None]:
class RNN_Decoder(tf.keras.Model):
    """Attention decoder: embedding -> two stacked LSTMs -> two Dense layers.

    One call decodes one step: it attends over the image regions using the
    hidden state handed in, concatenates the context vector with the embedded
    input token, and produces vocabulary logits.

    NOTE(review): both LSTMs are called without an ``initial_state`` and their
    cell states are discarded, so every call starts the LSTMs from their
    default state. The only information carried from one step to the next is
    ``state2_h``, and it is used solely as the attention query. Threading the
    LSTM states through would change what ``hidden`` contains for all callers.
    """

    def __init__(self, embedding_dim, units, vocab_size):
        super().__init__()
        self.units = units

        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)

        # Both recurrent layers share the same configuration
        lstm_options = dict(
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )
        # Layer 1 consumes [context vector ; word embedding]
        self.lstm1 = tf.keras.layers.LSTM(self.units, **lstm_options)
        # Layer 2 consumes the output sequence of layer 1
        self.lstm2 = tf.keras.layers.LSTM(self.units, **lstm_options)

        self.fc1 = tf.keras.layers.Dense(self.units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)

        self.attention = BahdanauAttention(self.units)

    def call(self, x, features, hidden):
        """Decode one step.

        Args:
            x: input token ids; the callers in this notebook pass shape (batch_size, 1).
            features: encoder output for the image regions.
            hidden: previous hidden state, used as the attention query.

        Returns:
            (logits, state2_h, attention_weights): vocabulary logits flattened
            to two dimensions, the hidden state of the second LSTM, and the
            attention weights of this step.
        """
        # Attend over the image regions with the previous hidden state
        context_vector, attention_weights = self.attention(features, hidden)

        # Embed the input token and prepend the context vector on the feature axis
        embedded = self.embedding(x)
        lstm_input = tf.concat(
            [tf.expand_dims(context_vector, 1), embedded], axis=-1
        )

        # Stacked LSTMs: the second reads the output sequence of the first
        first_out, _first_h, _first_c = self.lstm1(lstm_input)
        second_out, state2_h, _second_c = self.lstm2(first_out)

        # Project, merge the batch and time axes, then map to vocabulary logits
        projected = self.fc1(second_out)
        logits = self.fc2(tf.reshape(projected, (-1, projected.shape[2])))

        # state2_h becomes the attention query of the next time step
        return logits, state2_h, attention_weights

    def reset_state(self, batch_size):
        """Return an all-zero hidden state of shape (batch_size, units)."""
        return tf.zeros((batch_size, self.units))

In [None]:
# Set Hyperparameters
embedding_dim = 256   # size of the word embeddings and of the projected image features
units = 512           # LSTM / attention width
# +1 for the padding index 0 (same value as computed in the tokenization cell)
vocab_size = len(tokenizer.word_index) + 1

# Instantiate the building blocks
encoder = CNN_Encoder(embedding_dim)
decoder = RNN_Decoder(embedding_dim, units, vocab_size)

# Setup Loss: the decoder outputs raw logits, and reduction='none' keeps one
# loss value per example so loss_function can zero out padded positions
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

## **TRAINING THE MODEL**

### Prepare image feature arrays for training & validation


In [None]:
def create_training_data(captions_dict, tokenizer, max_length):
    """Turn a caption mapping into aligned image ids and padded token sequences.

    Args:
        captions_dict: dict of image id -> list of caption strings.
        tokenizer: fitted tokenizer providing ``texts_to_sequences``.
        max_length: length every sequence is padded to (zeros appended at the end).

    Returns:
        (image_ids, caption_seqs): ``image_ids[k]`` is the image that
        ``caption_seqs[k]`` belongs to; an image with several captions appears
        once per caption. ``caption_seqs`` is the 2-D array returned by
        ``pad_sequences``.
    """
    image_ids = []
    caption_texts = []

    for image_id, captions in captions_dict.items():
        for caption in captions:
            image_ids.append(image_id)
            caption_texts.append(caption)

    # Tokenise and pad the whole corpus in one call each instead of one pair
    # of calls (and one temporary array) per caption.
    sequences = tokenizer.texts_to_sequences(caption_texts)
    caption_seqs = pad_sequences(sequences, maxlen=max_length, padding='post')

    return image_ids, caption_seqs


In [None]:
# Training set as parallel sequences: train_img_ids[k] is the image of caption row y_train_seq[k]
train_img_ids, y_train_seq = create_training_data(
    train_captions, tokenizer, max_length
)

### **LOSS FUNCTION**

In [None]:
def loss_function(real, pred):
    """Cross-entropy for one decoding step with padded targets zeroed out.

    Args:
        real: target token ids; id 0 marks padding.
        pred: logits from the decoder for the same examples.

    Returns:
        Scalar: the mean of the per-example losses after padded positions
        have been set to zero. The mean is taken over all positions, padded
        ones included.
    """
    per_example = loss_object(real, pred)

    # 1.0 where the target is a real token, 0.0 where it is padding
    keep = tf.cast(tf.math.not_equal(real, 0), dtype=per_example.dtype)

    return tf.reduce_mean(per_example * keep)


### One training step (core logic)

In [None]:
@tf.function
def train_step(img_tensor, target):
    """Run one optimisation step on a batch and return its loss.

    Args:
        img_tensor: batch of image features fed to the encoder.
        target: batch of padded caption token ids; column 0 is expected to
            hold the <start> token, which is never used as a prediction target.

    Returns:
        The summed per-step loss divided by the number of decoded steps.

    Relies on the module-level ``encoder``, ``decoder``, ``tokenizer`` and
    ``optimizer``; ``optimizer`` is created in a later cell and must exist
    before the first call.
    """
    batch_size = tf.shape(target)[0]
    hidden = decoder.reset_state(batch_size)

    # Every caption starts decoding from the <start> token
    start_token = tokenizer.word_index['<start>']
    dec_input = tf.fill([batch_size, 1], start_token)

    loss = 0.0

    with tf.GradientTape() as tape:
        features = encoder(img_tensor)

        # Step through the target columns 1..T-1
        for t in tf.range(1, tf.shape(target)[1]):
            predictions, hidden, _ = decoder(dec_input, features, hidden)
            loss += loss_function(target[:, t], predictions)
            # Teacher forcing: the ground-truth token is the next input
            dec_input = tf.expand_dims(target[:, t], 1)

        # Average over the T-1 decoded steps
        total_loss = loss / tf.cast(tf.shape(target)[1] - 1, tf.float32)

    # Encoder and decoder are updated together
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(total_loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return total_loss


### Model Warm Up

In [None]:
# Warm-up call (very important): run both models once on dummy inputs.
# Presumably this makes the layers create their weights before
# trainable_variables is inspected in the next cell; confirm if in doubt.

dummy_img = tf.zeros((1, 64, 2048))            # one image: 64 regions x 2048 features
dummy_seq = tf.zeros((1, 1), dtype=tf.int32)   # one token id

hidden = decoder.reset_state(1)
_ = decoder(dummy_seq, encoder(dummy_img), hidden)

print("✅ Models built successfully")


✅ Models built successfully


In [None]:
# Count the trainable variables of each model after the warm-up call
print("Encoder trainable vars:", len(encoder.trainable_variables))
print("Decoder trainable vars:", len(decoder.trainable_variables))


Encoder trainable vars: 2
Decoder trainable vars: 17


### Training loop

In [None]:
# Optimizer used inside train_step (looked up there as a module-level name)
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)

In [None]:
EPOCHS = 20
batch_size = 64
num_samples = len(y_train_seq)

for epoch in range(EPOCHS):
    total_loss = 0.0
    num_batches = 0
    start = time.time()

    # Shuffle through an index array instead of overwriting train_img_ids /
    # y_train_seq: the source data stays untouched, so re-running this cell
    # (or reusing the arrays elsewhere) always starts from the same state,
    # and ids and sequences cannot get out of step.
    order = np.random.permutation(num_samples)

    for i in range(0, num_samples, batch_size):
        batch_idx = order[i:i + batch_size]
        batch_img_ids = [train_img_ids[j] for j in batch_idx]
        target_batch = y_train_seq[batch_idx]

        # Features are gathered per batch so only one batch of feature
        # arrays is stacked into a new array at a time
        img_batch = np.array([features[img_id] for img_id in batch_img_ids])

        batch_loss = train_step(img_batch, target_batch)
        total_loss += batch_loss
        num_batches += 1

    # max(..., 1) avoids a division by zero when there is no training data
    avg_loss = total_loss / max(num_batches, 1)
    print(f"Epoch {epoch+1}/{EPOCHS} | Loss: {avg_loss:.4f} | Time: {time.time()-start:.2f}s")


Epoch 1/20 | Loss: 1.3277 | Time: 146.88s
Epoch 2/20 | Loss: 1.3008 | Time: 146.60s
Epoch 3/20 | Loss: 1.2737 | Time: 146.62s
Epoch 4/20 | Loss: 1.2436 | Time: 148.41s
Epoch 5/20 | Loss: 1.2128 | Time: 146.97s
Epoch 6/20 | Loss: 1.1843 | Time: 148.41s
Epoch 7/20 | Loss: 1.1596 | Time: 147.02s
Epoch 8/20 | Loss: 1.1385 | Time: 148.52s
Epoch 9/20 | Loss: 1.1200 | Time: 147.09s
Epoch 10/20 | Loss: 1.1043 | Time: 148.60s
Epoch 11/20 | Loss: 1.0897 | Time: 147.42s
Epoch 12/20 | Loss: 1.0765 | Time: 149.18s
Epoch 13/20 | Loss: 1.0641 | Time: 147.40s
Epoch 14/20 | Loss: 1.0527 | Time: 148.51s
Epoch 15/20 | Loss: 1.0416 | Time: 147.57s
Epoch 16/20 | Loss: 1.0314 | Time: 147.44s
Epoch 17/20 | Loss: 1.0210 | Time: 147.13s
Epoch 18/20 | Loss: 1.0117 | Time: 147.46s
Epoch 19/20 | Loss: 1.0027 | Time: 147.71s
Epoch 20/20 | Loss: 0.9935 | Time: 147.70s


### Save Weights

In [None]:
# Persist the trained encoder and decoder weights in the working directory
encoder.save_weights(os.path.join(WORKING_DIR, 'encoder.weights.h5'))
decoder.save_weights(os.path.join(WORKING_DIR, 'decoder.weights.h5'))

# **Image Caption Generation**

### Rebuild model architecture

In [None]:
# Fresh, untrained instances with the same hyperparameters as during training;
# the saved weights are loaded into them further down
encoder = CNN_Encoder(embedding_dim)
decoder = RNN_Decoder(embedding_dim, units, vocab_size)

### Warm-up call

In [None]:
# Run both models once on dummy inputs before loading weights.
# Presumably needed so the layers have created their variables; confirm if in doubt.
dummy_img = tf.zeros((1, 64, 2048))
dummy_seq = tf.zeros((1, 1), dtype=tf.int32)
hidden = decoder.reset_state(1)

_ = decoder(dummy_seq, encoder(dummy_img), hidden)
print("Models initialized for inference")

Models initialized for inference


### Load weights

In [None]:
# Restore the trained weights from the files written by the "Save Weights" cell
encoder.load_weights(os.path.join(WORKING_DIR, 'encoder.weights.h5'))
decoder.load_weights(os.path.join(WORKING_DIR, 'decoder.weights.h5'))

print("Weights loaded successfully")


Weights loaded successfully


### Rebuild CNN for feature extraction

In [None]:
# Feature extractor for inference: InceptionV3 truncated at 'mixed10',
# the same construction as the `model` built in the feature-extraction section
cnn_model = InceptionV3()
cnn_model = Model(
    inputs=cnn_model.inputs,
    outputs=cnn_model.get_layer('mixed10').output
)

### Extract features from a single image

In [None]:
def extract_features(image_path):
    """Return the spatial CNN features of a single image.

    The image is loaded at 299x299, preprocessed for InceptionV3 and passed
    through the module-level ``cnn_model``.

    Args:
        image_path: path of the image file.

    Returns:
        Array of shape (1, 64, 2048): one image, 64 regions, 2048 features each.
    """
    pixels = img_to_array(load_img(image_path, target_size=(299, 299)))
    batch = preprocess_input(pixels).reshape((1, 299, 299, 3))

    grid = cnn_model.predict(batch, verbose=0)
    return grid.reshape((1, 64, 2048))

### Index → word mapping

In [None]:
# Reverse vocabulary lookup: token id -> word (the padding id 0 has no entry)
index_to_word = {v: k for k, v in tokenizer.word_index.items()}

### **CAPTION GENERATION (INFERENCE)**

In [None]:
def generate_caption(image_path, repetition_penalty=1e9):
    """Generate a caption for one image with greedy decoding.

    Args:
        image_path: path of the image file.
        repetition_penalty: amount subtracted from the logit of every word
            already emitted. The default (1e9) is the previously hard-coded
            value and effectively forbids using any word twice, so with the
            default the triple-repeat guard below can never trigger. Pass a
            smaller value to merely discourage repeats; the guard then stops
            runaway repetition.

    Returns:
        The caption as a space-separated string, without <start>/<end>.

    Relies on the module-level ``encoder``, ``decoder``, ``tokenizer``,
    ``index_to_word`` and ``max_length``.
    """
    # 1. Extract image features and project them with the encoder
    feature = encoder(extract_features(image_path))

    # 2. Initialize decoder state
    hidden = decoder.reset_state(1)

    start_token = tokenizer.word_index['<start>']
    end_token   = tokenizer.word_index['<end>']

    dec_input = tf.expand_dims([start_token], 1)

    caption = []
    used_words = set()

    for _step in range(max_length):

        predictions, hidden, _ = decoder(dec_input, feature, hidden)

        # Logits of the single example as a NumPy array we can edit
        logits = predictions[0].numpy()

        # Never allow <start> again
        logits[start_token] = -1e9

        # Repetition penalty on every word emitted so far (ids are unique,
        # so each one is penalised exactly once)
        if used_words:
            logits[list(used_words)] -= repetition_penalty

        # Greedy selection
        predicted_id = int(np.argmax(logits))
        word = index_to_word.get(predicted_id, None)

        # Stop on <end> or on an id without a word (e.g. the padding id)
        if word is None or predicted_id == end_token:
            break

        caption.append(word)
        used_words.add(predicted_id)

        # Hard repetition explosion guard: same word three times in a row
        if len(caption) >= 3 and caption[-1] == caption[-2] == caption[-3]:
            break

        # Prepare next input
        dec_input = tf.expand_dims([predicted_id], 1)

    return ' '.join(caption)


In [None]:
def generate_caption_beam(image_path, beam_width=3, length_penalty=0.7):
    """Generate a caption for one image with beam search.

    Args:
        image_path: path of the image file.
        beam_width: number of hypotheses kept after every step (>= 1).
        length_penalty: exponent of the length normalisation; hypotheses are
            ranked by ``score / len(sequence) ** length_penalty`` where
            ``score`` is the summed log-probability.

    Returns:
        The best caption as a space-separated string, without <start>/<end>
        and without ids that have no vocabulary entry.

    Raises:
        ValueError: if ``beam_width`` is smaller than 1 (this previously
            surfaced as an opaque ValueError from ``max()`` on an empty list).

    Relies on the module-level ``encoder``, ``decoder``, ``tokenizer``,
    ``index_to_word`` and ``max_length``.
    """
    if beam_width < 1:
        raise ValueError("beam_width must be at least 1")

    # Extract image features and project them with the encoder
    feature = encoder(extract_features(image_path))

    start_token = tokenizer.word_index['<start>']
    end_token   = tokenizer.word_index['<end>']

    def normalised(seq, score):
        # Length-normalised score shared by beam pruning and final selection
        return score / (len(seq) ** length_penalty)

    # Each beam: (sequence, hidden_state, score)
    beams = [([start_token], decoder.reset_state(1), 0.0)]

    completed = []

    for _ in range(max_length):
        new_beams = []

        for seq, hidden, score in beams:
            # Finished hypotheses are set aside and not expanded further
            if seq[-1] == end_token:
                completed.append((seq, score))
                continue

            dec_input = tf.expand_dims([seq[-1]], 1)
            predictions, new_hidden, _ = decoder(dec_input, feature, hidden)

            logits = predictions[0].numpy()

            # Never generate <start> again
            logits[start_token] = -1e9

            probs = tf.nn.softmax(logits).numpy()
            top_ids = np.argsort(probs)[-beam_width:]

            for idx in top_ids:
                # int() keeps plain Python ints in the sequence, so the next
                # decoder input has the same dtype as in greedy decoding
                new_seq = seq + [int(idx)]
                new_score = score + np.log(probs[idx] + 1e-9)
                new_beams.append((new_seq, new_hidden, new_score))

        # Keep best beams
        beams = sorted(
            new_beams,
            key=lambda beam: normalised(beam[0], beam[2]),
            reverse=True
        )[:beam_width]

        # Nothing left to expand: every hypothesis has finished
        if not beams:
            break

    # Hypotheses still open when the length limit is reached compete as well
    completed.extend([(seq, score) for seq, _, score in beams])

    best_seq = max(
        completed,
        key=lambda item: normalised(item[0], item[1])
    )[0]

    # Convert to words, skipping the markers and any id without a vocabulary
    # entry (such ids used to become empty strings and doubled spaces)
    caption = [
        index_to_word[idx]
        for idx in best_seq
        if idx not in (start_token, end_token) and idx in index_to_word
    ]

    return ' '.join(caption)


## **EVALUATION**

In [None]:
def evaluate_bleu(val_captions, val_image_ids):
    """Compute corpus BLEU-1..4 of greedy captions against the references.

    Args:
        val_captions: dict of image id -> list of reference captions that
            still carry the '<start>' / '<end>' markers.
        val_image_ids: iterable of image ids to evaluate; each must be a key
            of ``val_captions`` and have ``<id>.jpg`` under BASE_DIR/Images.

    Returns:
        Tuple ``(bleu1, bleu2, bleu3, bleu4)``.
    """
    references = []
    hypotheses = []

    smoothie = SmoothingFunction().method4

    for img_id in val_image_ids:
        # Ground-truth captions as token lists without the boundary markers
        ref_caps = [
            cap.replace('<start>', '').replace('<end>', '').split()
            for cap in val_captions[img_id]
        ]

        image_path = os.path.join(BASE_DIR, 'Images', img_id + '.jpg')
        pred_caption = generate_caption(image_path).split()

        references.append(ref_caps)
        hypotheses.append(pred_caption)

    # Uniform n-gram weights for BLEU-1..4. BLEU-3 uses exact thirds; the
    # previous (0.33, 0.33, 0.33) summed to 0.99 and slightly skewed the score.
    ngram_weights = (
        (1, 0, 0, 0),
        (0.5, 0.5, 0, 0),
        (1 / 3, 1 / 3, 1 / 3, 0),
        (0.25, 0.25, 0.25, 0.25),
    )

    bleu1, bleu2, bleu3, bleu4 = (
        corpus_bleu(references, hypotheses, weights=w, smoothing_function=smoothie)
        for w in ngram_weights
    )

    return bleu1, bleu2, bleu3, bleu4


In [None]:
# Score the greedy decoder on all validation images (ids sorted for a fixed evaluation order)
bleu1, bleu2, bleu3, bleu4 = evaluate_bleu(val_captions, sorted(val_captions.keys()))

print(f"BLEU-1: {bleu1:.4f}")
print(f"BLEU-2: {bleu2:.4f}")
print(f"BLEU-3: {bleu3:.4f}")
print(f"BLEU-4: {bleu4:.4f}")

BLEU-1: 0.5406
BLEU-2: 0.3678
BLEU-3: 0.2383
BLEU-4: 0.1371


### Testing the caption generation

In [None]:
# Build the demo path from BASE_DIR (as evaluate_bleu does) instead of
# repeating the absolute Drive path, so moving the dataset needs one change only.
image_path = os.path.join(BASE_DIR, 'Images', '111497985_38e9f88856.jpg')

# Compare both decoding strategies on the same image
print("GREEDY:")
print(generate_caption(image_path))

print("\nBEAM SEARCH:")
print(generate_caption_beam(image_path))

GREEDY:
a man in the rocks

BEAM SEARCH:
a man in a rock
