In [126]:
import tensorflow as tf
import pickle
import numpy as np
from tensorflow.keras.optimizers.schedules import ExponentialDecay
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.utils import Progbar
import random
from sklearn.model_selection import train_test_split

In [127]:

# Load Preprocessed Data

with open("features_inception_512.pkl", "rb") as f:
    image_features = pickle.load(f)  # 512D PCA-reduced features

with open("captions.pkl", "rb") as f:
    captions = pickle.load(f)
with open("word2idx.pkl", "rb") as f:
    word2idx = pickle.load(f)
with open("idx2word.pkl", "rb") as f:
    idx2word = pickle.load(f)


# **Split Data into Training and Validation (80% Train, 20% Validation)**

image_ids = list(captions.keys())
train_ids, val_ids = train_test_split(image_ids, test_size=0.2, random_state=42)

train_captions = {img_id: captions[img_id] for img_id in train_ids}
val_captions = {img_id: captions[img_id] for img_id in val_ids}


# **Define Model Hyperparameters**

embedding_dim = 256
units = 512  
vocab_size = len(word2idx) + 1
batch_size = 64
num_epochs = 25  # Set high, early stopping will stop it automatically
patience = 5  # Stop if no improvement after 5 epochs

# **Track Best Validation Loss**
best_val_loss = float("inf")
early_stop_counter = 0  # Tracks epochs without improvement

In [128]:
# Define CNN Encoder (with L2 Regularization)

class CNN_Encoder(tf.keras.Model):
    def __init__(self, embed_dim):
        super(CNN_Encoder, self).__init__()
        
        #  L2 Regularization (lambda=0.01) to prevent overfitting
        self.fc = tf.keras.layers.Dense(embed_dim, activation="relu", kernel_regularizer=tf.keras.regularizers.l2(0.01))

    def call(self, x):
        x = self.fc(x)
        return tf.expand_dims(x, axis=1)  # Ensure correct shape


# **Define RNN Decoder (with Dropout)

class RNN_Decoder(tf.keras.Model):
    def __init__(self, vocab_size, embed_dim, units):
        super(RNN_Decoder, self).__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embed_dim)
        
        #  Added dropout (30%) to LSTM to prevent overfitting
        self.lstm = tf.keras.layers.LSTM(units, return_sequences=True, return_state=True, dropout=0.3)
        
         #  Apply softmax at the final Dense layer to avoid the warning
        self.fc = tf.keras.layers.Dense(vocab_size, activation="softmax")

        self.attention = tf.keras.layers.AdditiveAttention()

    def call(self, x, hidden, features):
        hidden = tf.expand_dims(hidden, axis=1)  

        # Apply attention
        context_vector, _ = self.attention([hidden, features], return_attention_scores=True)

        x = self.embedding(x)  # Shape: (batch_size, sequence_length, embed_dim)
        context_vector = tf.repeat(context_vector, repeats=x.shape[1], axis=1)  # (batch_size, 30, 512)
        x = tf.concat([context_vector, x], axis=-1)  # (batch_size, 30, 512 + 256)

        output, state_h, state_c = self.lstm(x)  # Now includes Dropout
        
        x = self.fc(output)
        return x, state_h, state_c, _ 
    



# **Initialize Model and Optimizer**

encoder = CNN_Encoder(512)  # Image feature size = 512
decoder = RNN_Decoder(vocab_size, embedding_dim, units)

lr_schedule = ExponentialDecay(initial_learning_rate=0.001, decay_steps=1000, decay_rate=0.96, staircase=True)
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

# Checkpoint manager for saving progress
checkpoint = tf.train.Checkpoint(optimizer=optimizer, encoder=encoder, decoder=decoder)
checkpoint_manager = tf.train.CheckpointManager(checkpoint, "./checkpoints", max_to_keep=5)

# **Initialize Accuracy Metrics**
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
val_accuracy = tf.keras.metrics.SparseCategoricalAccuracy()


In [129]:
def train_step(image_feature, caption_input, caption_target, training=True):
    loss = 0
    with tf.GradientTape() as tape:
        hidden = tf.zeros((caption_input.shape[0], units))  # Ensure batch size consistency
        
        #  Ensure Proper Input Shape to Encoder
        image_feature = tf.reshape(image_feature, (1, 512))  # Fix issue
        image_feature = encoder(image_feature)  # Extract image feature (Now Shape: (1, 1, 512))
        
        #  Corrected unpacking (Now expecting 4 values)
        output, state_h, state_c, _ = decoder(caption_input, hidden, image_feature)

        # Compute loss
        loss = tf.keras.losses.sparse_categorical_crossentropy(caption_target, output, from_logits=True)
        loss = tf.reduce_mean(loss)

        #  Update accuracy metric
        if training:
            train_accuracy.update_state(caption_target, output)
        else:
            val_accuracy.update_state(caption_target, output)

    #  Apply Gradients during training
    if training:
        gradients = tape.gradient(loss, decoder.trainable_variables + encoder.trainable_variables)
        optimizer.apply_gradients(zip(gradients, decoder.trainable_variables + encoder.trainable_variables))
    
    return loss


In [130]:
# **Training Loop with Validation & Early Stopping**
for epoch in range(num_epochs):
    print(f"\n Epoch {epoch+1}/{num_epochs}")  #  Print epoch number

    # **Training Phase**
    total_loss = 0
    batch_count = len(train_captions)
    progress_bar = Progbar(batch_count, stateful_metrics=["loss", "accuracy"])

    train_accuracy.reset_state()

    for i, (image_id, caption_words) in enumerate(train_captions.items()):
        image_feature = image_features.get(image_id)
        if image_feature is None:
            continue

        caption_sequence = [word2idx.get(word, word2idx["<unk>"]) for word in caption_words]
        caption_input = [word2idx['<start>']] + caption_sequence[:-1]
        caption_target = caption_sequence

        caption_input = pad_sequences([caption_input], maxlen=30, padding='post')
        caption_target = pad_sequences([caption_target], maxlen=30, padding='post')

        caption_input = tf.convert_to_tensor(caption_input, dtype=tf.int32)
        caption_target = tf.convert_to_tensor(caption_target, dtype=tf.int32)

        loss = train_step(image_feature, caption_input, caption_target, training=True)
        total_loss += loss

        if i % 10 == 0:  # Update progress bar every 10 steps
            progress_bar.update(i + 1, values=[("loss", loss.numpy()), ("accuracy", train_accuracy.result().numpy())])

    train_loss = total_loss.numpy() / batch_count  # Compute epoch-level loss
    train_acc = train_accuracy.result().numpy()  # Compute epoch-level accuracy

    # **Validation Phase**
    val_loss = 0
    val_accuracy.reset_state()

    for image_id, caption_words in val_captions.items():
        image_feature = image_features.get(image_id)
        if image_feature is None:
            continue

        caption_sequence = [word2idx.get(word, word2idx["<unk>"]) for word in caption_words]
        caption_input = [word2idx['<start>']] + caption_sequence[:-1]
        caption_target = caption_sequence

        caption_input = pad_sequences([caption_input], maxlen=30, padding='post')
        caption_target = pad_sequences([caption_target], maxlen=30, padding='post')

        caption_input = tf.convert_to_tensor(caption_input, dtype=tf.int32)
        caption_target = tf.convert_to_tensor(caption_target, dtype=tf.int32)

        loss = train_step(image_feature, caption_input, caption_target, training=False)
        val_loss += loss.numpy()

    val_loss /= len(val_captions)  # Compute epoch-level validation loss
    val_acc = val_accuracy.result().numpy()  # Compute epoch-level validation accuracy

    # **Display Epoch Summary**
    print(f"\n Epoch {epoch+1}, Loss: {train_loss:.4f}, Accuracy: {train_acc:.4f}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

    checkpoint_manager.save()

    # **Early Stopping Condition**
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        early_stop_counter = 0  
    else:
        early_stop_counter += 1
        if early_stop_counter >= patience:
            print("\n Early stopping triggered.")
            break

print("\n Training complete. Model saved.")



 Epoch 1/25
[1m6471/6473[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 231ms/step - loss: 2.1656e-06 - accuracy: 0.9989
 Epoch 1, Loss: 0.0099, Accuracy: 0.9989, Val Loss: 0.0000, Val Accuracy: 1.0000

 Epoch 2/25
[1m6471/6473[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 242ms/step - loss: 2.2252e-07 - accuracy: 1.0000
 Epoch 2, Loss: 0.0001, Accuracy: 1.0000, Val Loss: 0.0000, Val Accuracy: 1.0000

 Epoch 3/25
[1m6471/6473[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 247ms/step - loss: 2.3842e-08 - accuracy: 1.0000
 Epoch 3, Loss: 0.0000, Accuracy: 1.0000, Val Loss: 0.0000, Val Accuracy: 1.0000

 Epoch 4/25
[1m6471/6473[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 238ms/step - loss: 1.4305e-07 - accuracy: 0.9997
 Epoch 4, Loss: 0.0016, Accuracy: 0.9997, Val Loss: 0.0000, Val Accuracy: 1.0000

 Epoch 5/25
[1m6471/6473[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m0s[0m 240ms/step - loss: 1.1126e-07 - accuracy: 1.0000
 Epoch 5, Loss: 0.0001, Ac