## Import all the dependencies

In [None]:
import os

# Reduce TensorFlow logging. The variable is set before TensorFlow is imported
# so that it is already in place for messages emitted while the library loads.
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import json
import time
import warnings
from PIL import Image
import tensorflow as tf
import matplotlib.pyplot as plt
import tensorflow_datasets as tfds
from tensorflow.keras import Input
from datetime import datetime, timedelta
from tensorflow.keras.layers import (
    GRU, Add, Attention, Dense, Embedding,
    LayerNormalization, Reshape, StringLookup, TextVectorization,
)

warnings.filterwarnings("ignore", category=DeprecationWarning)

# Enable eager execution
tf.config.run_functions_eagerly(True)

print("=============Tensorflow Version: ", tf.version.VERSION, "=============")

## Read & Prepare Dataset

In [None]:
# ---- Text / model dimensions ----
VOCAB_SIZE = 20000  # upper bound on the tokenizer vocabulary
ATTENTION_DIM = 512
WORD_EMBEDDING_DIM = 128

# ---- Image geometry ----
# InceptionResNetV2 takes (299, 299, 3) images as input
# and returns features in (8, 8, 1536) shape.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 299, 299, 3
FEATURES_SHAPE = (8, 8, 1536)

# Pretrained backbone without its classification head.
FEATURE_EXTRACTOR = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
    weights="imagenet", include_top=False
)

# ---- Dataset ----
GCS_DIR = "gs://asl-public/data/tensorflow_datasets/"
BUFFER_SIZE = 1000  # shuffle buffer size

## Preprocessing

In [None]:
def get_image_label(example):
    """Turn a raw dataset example into a model-ready image and one caption.

    Args:
        example: Dataset element providing ``example["image"]`` and
            ``example["captions"]["text"]``.

    Returns:
        Dict with ``"image_tensor"`` (the image resized to
        ``(IMG_HEIGHT, IMG_WIDTH)`` and divided by 255) and ``"caption"``
        (the first caption of the example).
    """
    resized = tf.image.resize(example["image"], (IMG_HEIGHT, IMG_WIDTH))
    first_caption = example["captions"]["text"][0]  # only the first caption per image
    return {"image_tensor": resized / 255, "caption": first_caption}

# Load the training split of COCO captions from GCS_DIR.
trainds = tfds.load("coco_captions", split="train", data_dir=GCS_DIR)

# Preprocess in parallel, shuffle with a bounded buffer, then prefetch.
trainds = trainds.map(get_image_label, num_parallel_calls=tf.data.AUTOTUNE)
trainds = trainds.shuffle(BUFFER_SIZE)
trainds = trainds.prefetch(buffer_size=tf.data.AUTOTUNE)

def add_start_end_token(data):
    """Wrap the caption in ``<start>`` / ``<end>`` markers.

    Args:
        data: Dict holding a string tensor under ``"caption"``.

    Returns:
        The same dict, with ``data["caption"]`` replaced by
        ``"<start> " + caption + " <end>"``.
    """
    opening, closing = (
        tf.convert_to_tensor("<start>"),
        tf.convert_to_tensor("<end>"),
    )
    pieces = [opening, data["caption"], closing]
    data["caption"] = tf.strings.join(pieces, separator=" ")
    return data

# Surround every caption with the <start> / <end> markers.
trainds = trainds.map(add_start_end_token)

## Tokenize the captions

In [None]:
# This cell can take up to 10 min to run.
# Length (in tokens) of every tokenized caption.
MAX_CAPTION_LEN = 64

def standardize(inputs):
    """Lower-case the text and strip the punctuation listed in the pattern.

    ``<`` and ``>`` are not part of the character class, so the
    ``<start>`` / ``<end>`` markers pass through unchanged.

    Args:
        inputs: String tensor to normalise.

    Returns:
        String tensor, lower-cased and with the matched characters removed.
    """
    punctuation_pattern = r"[!\"#$%&\(\)\*\+.,-/:;=?@\[\\\]^_`{|}~]?"
    lowered = tf.strings.lower(inputs)
    return tf.strings.regex_replace(lowered, punctuation_pattern, "")

# Caption tokenizer: custom standardisation, a vocabulary capped at
# VOCAB_SIZE, and MAX_CAPTION_LEN tokens per output sequence.
tokenizer = TextVectorization(
    standardize=standardize,
    max_tokens=VOCAB_SIZE,
    output_sequence_length=MAX_CAPTION_LEN,
)

# Build the vocabulary from the caption strings only.
caption_only_ds = trainds.map(lambda example: example["caption"])
tokenizer.adapt(caption_only_ds)

In [None]:
# Sanity check: tokenize one example sentence and display the resulting ids.
tokenizer(["<start> This is a sentence <end>"])

In [None]:
# Pull a handful of captions to eyeball the tokenizer output.
sample_captions = [d["caption"].numpy() for d in trainds.take(5)]

print("=================Sample Captions=================")
for sample_caption in sample_captions:
    print(sample_caption)
print("==================================")
print(tokenizer(sample_captions))

# Map the ids of the first caption back to words. The vocabulary list is
# fetched once instead of being rebuilt for every token.
vocabulary = tokenizer.get_vocabulary()
for wordid in tokenizer([sample_captions[0]])[0]:
    print(vocabulary[wordid], end=" ")

In [None]:
BATCH_SIZE = 32

# Both lookup tables are backed by the tokenizer's vocabulary.
_lookup_vocab = tokenizer.get_vocabulary()

# Lookup table: Word -> Index
word_to_index = StringLookup(vocabulary=_lookup_vocab, mask_token="")

# Lookup table: Index -> Word
index_to_word = StringLookup(
    vocabulary=_lookup_vocab, mask_token="", invert=True
)

def create_ds_fn(data):
    """Build one ``((image, caption_tokens), target_tokens)`` training example.

    The target is the tokenized caption shifted left by one position, with a
    single 0 appended so that it keeps the same length as the input.

    Args:
        data: Dict with ``"image_tensor"`` and a string ``"caption"``.

    Returns:
        Tuple ``((image_tensor, tokens), target)``.
    """
    tokens = tokenizer(data["caption"])

    shifted = tf.roll(tokens, -1, 0)
    padding = tf.zeros([1], dtype=tf.int64)
    # Drop the wrapped-around first token and put a 0 in its place.
    target = tf.concat((shifted[:-1], padding), axis=-1)

    return (data["image_tensor"], tokens), target

# Tokenize, batch (dropping the last incomplete batch) and prefetch.
batched_ds = trainds.map(create_ds_fn)
batched_ds = batched_ds.batch(BATCH_SIZE, drop_remainder=True)
batched_ds = batched_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

# Inspect the first two batches.
for (img, caption), label in batched_ds.take(2):
    print(f"Image shape: {img.shape}")
    print(f"Caption shape: {caption.shape}")
    print(f"Label shape: {label.shape}")
    # First example of the batch: decoder input, then its shifted target.
    print(caption[0])
    print(label[0])

## Image Encoder

In [None]:
#================== Encoder =====================
# Freeze the pretrained feature extractor.
FEATURE_EXTRACTOR.trainable = False

image_input = Input(shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS))
image_features = FEATURE_EXTRACTOR(image_input)

# Flatten the spatial grid of the feature map:
# (FEATURES_SHAPE[0], FEATURES_SHAPE[1], C) -> (FEATURES_SHAPE[0] * FEATURES_SHAPE[1], C).
x = Reshape((FEATURES_SHAPE[0] * FEATURES_SHAPE[1], FEATURES_SHAPE[2]))(
    image_features
)
# Project every position to ATTENTION_DIM channels.
encoder_output = Dense(ATTENTION_DIM, activation="relu")(x)

encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)
encoder.summary()

## Caption Decoder

In [None]:
#================== Caption Decoder =====================
# Caption token ids, MAX_CAPTION_LEN per example.
word_input = Input(shape=(MAX_CAPTION_LEN,), name="words")
# NOTE: the embedding width is ATTENTION_DIM; WORD_EMBEDDING_DIM is not used here.
embed_x = Embedding(VOCAB_SIZE, ATTENTION_DIM)(word_input)

# The GRU returns its output at every step as well as its final state.
decoder_gru = GRU(
    ATTENTION_DIM,
    return_sequences=True,
    return_state=True,
)
gru_output, gru_state = decoder_gru(embed_x)

# Attention over the encoded image, queried by the GRU outputs.
decoder_attention = Attention()
context_vector = decoder_attention([gru_output, encoder_output])

# Sum of the GRU outputs and the attention context.
addition = Add()([gru_output, context_vector])

layer_norm = LayerNormalization(axis=-1)
layer_norm_out = layer_norm(addition)

# No activation: the layer emits one raw score (logit) per vocabulary entry.
decoder_output_dense = Dense(VOCAB_SIZE
)
decoder_output = decoder_output_dense(layer_norm_out)

# The decoder model takes the encoder's output tensor directly as an input.
decoder = tf.keras.Model(
    inputs=[word_input, encoder_output], outputs=decoder_output
)
decoder.summary()

## Training Model

In [None]:
#================== Training Model =====================
# End-to-end model: image and caption tokens in, decoder output out.
image_caption_train_model = tf.keras.Model(
    inputs=[image_input, word_input], outputs=decoder_output
)

# from_logits=True: the decoder's final Dense layer has no activation.
# reduction="none": keep the individual loss values so that loss_function
# can post-process them itself.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction="none"
)

def loss_function(real, pred):
    """Cross-entropy averaged over the non-padding tokens of each caption.

    The previous implementation sliced ``loss_[:sentence_len]`` along the
    first (batch) axis using the batch-wide count of non-padding tokens, so
    padding positions were never actually removed from the loss. Here the
    padding positions are zeroed out and each caption's loss is divided by
    its own number of real tokens.

    Args:
        real: Target token ids; id 0 marks padding. Assumed to have the
            sequence on axis 1 (e.g. ``(batch, seq_len)``).
        pred: Logits matching ``real`` with a trailing vocabulary axis.

    Returns:
        One loss value per caption (reduction over axis 1). A caption made
        only of padding yields 0 rather than NaN.
    """
    # Per-token losses (loss_object uses reduction="none").
    loss_ = loss_object(real, pred)

    # 1.0 for real tokens, 0.0 for padding (e.g. [1,1,1,0,0,...,0]).
    mask = tf.cast(tf.math.not_equal(real, 0), dtype=loss_.dtype)
    masked_loss = loss_ * mask

    tokens_per_caption = tf.reduce_sum(mask, axis=1)
    return tf.math.divide_no_nan(
        tf.reduce_sum(masked_loss, axis=1), tokens_per_caption
    )

# Train with Adam and the custom caption loss.
image_caption_train_model.compile(loss=loss_function, optimizer="adam")

In [None]:
#================== Training Loop =====================
# Record start time
start_time = time.time()
history = image_caption_train_model.fit(batched_ds, epochs=1)

# Record end time and calculate duration
end_time = time.time()
total_time = end_time - start_time

# Convert to hours, minutes, seconds. total_seconds() is used because
# timedelta.seconds leaves out whole days, which would under-report any run
# lasting 24 hours or more.
duration = timedelta(seconds=total_time)
hours, remainder = divmod(int(duration.total_seconds()), 3600)
minutes, seconds = divmod(remainder, 60)

print(f"\nTraining completed in: {hours} hours, {minutes} minutes, {seconds} seconds")

In [None]:
# Inference-time decoder: same layers as the training decoder, but the GRU
# state is passed in and returned so it can be carried between calls.
gru_state_input = Input(shape=(ATTENTION_DIM,), name="gru_state_input")

# Reuse trained GRU, but update it so that it can receive states.
gru_output, gru_state = decoder_gru(embed_x, initial_state=gru_state_input)

# Reuse other layers as well
context_vector = decoder_attention([gru_output, encoder_output])
# Add() is instantiated anew here; the other layers are the existing objects.
addition_output = Add()([gru_output, context_vector])
layer_norm_output = layer_norm(addition_output)

decoder_output = decoder_output_dense(layer_norm_output)

# Define prediction Model with state input and output
decoder_pred_model = tf.keras.Model(
    inputs=[word_input, gru_state_input, encoder_output],
    outputs=[decoder_output, gru_state],
)

In [None]:
# Persist weights, tokenizer vocabulary, hyper-parameters and training state
# into one directory; any failure is reported and then re-raised.
try:
    final_save_dir = './Epochs10'
    os.makedirs(final_save_dir, exist_ok=True)

    print("Saving model weights...")
    for model_to_save, weights_name in (
        (encoder, 'encoder.weights.h5'),
        (decoder, 'decoder.weights.h5'),
        (decoder_pred_model, 'decoder_pred.weights.h5'),
        (image_caption_train_model, 'train_model.weights.h5'),
    ):
        model_to_save.save_weights(os.path.join(final_save_dir, weights_name))

    print("Saving tokenizer vocabulary...")
    tokenizer_vocab = tokenizer.get_vocabulary()
    with open(os.path.join(final_save_dir, 'tokenizer_vocab.json'), 'w') as f:
        json.dump(tokenizer_vocab, f)

    print("Saving model configuration...")
    # Hyper-parameters needed to rebuild the models later.
    model_config = {
        'IMG_HEIGHT': IMG_HEIGHT,
        'IMG_WIDTH': IMG_WIDTH,
        'IMG_CHANNELS': IMG_CHANNELS,
        'ATTENTION_DIM': ATTENTION_DIM,
        'VOCAB_SIZE': VOCAB_SIZE,
        'MAX_CAPTION_LEN': MAX_CAPTION_LEN,
        'WORD_EMBEDDING_DIM': WORD_EMBEDDING_DIM,
        'FEATURES_SHAPE': FEATURES_SHAPE,
        'BUFFER_SIZE': BUFFER_SIZE,
        'BATCH_SIZE': BATCH_SIZE,
    }

    # Save training state
    epoch_losses = history.history['loss']
    training_state = {
        'last_epoch': len(epoch_losses),  # Number of epochs trained
        'last_loss': epoch_losses[-1],    # Last loss value
    }

    for json_name, payload in (
        ('model_config.json', model_config),
        ('training_state.json', training_state),
    ):
        with open(os.path.join(final_save_dir, json_name), 'w') as f:
            json.dump(payload, f)

    print("\nModel saving completed successfully!")
except Exception as e:
    print(f"Error in model saving: {str(e)}")
    raise e

## Test the model by loading saved weights and configuration files

In [None]:
# Load configuration files and create models
def load_models(model_dir='./Epochs10'):
    """Rebuild the encoder and the step-wise decoder and load saved weights.

    Reads ``model_config.json`` and ``tokenizer_vocab.json`` from
    ``model_dir``, recreates both networks from the stored hyper-parameters
    and loads ``encoder.weights.h5`` / ``decoder.weights.h5`` into them.

    Args:
        model_dir: Directory containing the saved JSON and weight files.

    Returns:
        Tuple ``(encoder, decoder, tokenizer, config)``. ``decoder`` takes
        ``[words, gru_state, encoder_features]`` and returns
        ``[logits, new_gru_state]``; ``config`` is the loaded dict.
    """
    # Load configuration
    with open(os.path.join(model_dir, 'model_config.json'), 'r') as f:
        config = json.load(f)

    # Load vocabulary
    with open(os.path.join(model_dir, 'tokenizer_vocab.json'), 'r') as f:
        vocab = json.load(f)

    # Create tokenizer
    # NOTE(review): this tokenizer is built without the custom `standardize`
    # callable used for the training tokenizer, so it exposes the same
    # vocabulary but may split raw text differently (e.g. "<start>") --
    # confirm before using it to tokenize new captions.
    tokenizer = TextVectorization(
        max_tokens=config['VOCAB_SIZE'],
        output_sequence_length=config['MAX_CAPTION_LEN']
    )
    tokenizer.set_vocabulary(vocab)

    # Create feature extractor
    feature_extractor = tf.keras.applications.inception_resnet_v2.InceptionResNetV2(
        include_top=False, weights="imagenet"
    )
    feature_extractor.trainable = False

    # Create encoder. The feature-map shape comes from the saved configuration
    # (falling back to the previous hard-coded value) so that it stays in sync
    # with the shape used at training time.
    feat_h, feat_w, feat_c = config.get('FEATURES_SHAPE', (8, 8, 1536))
    image_input = Input(shape=(config['IMG_HEIGHT'], config['IMG_WIDTH'], config['IMG_CHANNELS']))
    image_features = feature_extractor(image_input)
    x = Reshape((feat_h * feat_w, feat_c))(image_features)
    encoder_output = Dense(config['ATTENTION_DIM'], activation="relu")(x)
    encoder = tf.keras.Model(inputs=image_input, outputs=encoder_output)

    # Create decoder. The sequence length is left open (None) so the model can
    # be called with a single token at a time.
    word_input = Input(shape=(None,), name="words")
    gru_state_input = Input(shape=(config['ATTENTION_DIM'],), name="gru_state_input")
    encoder_output_input = Input(shape=(None, config['ATTENTION_DIM']))

    embed_x = Embedding(config['VOCAB_SIZE'], config['ATTENTION_DIM'])(word_input)
    decoder_gru = GRU(
        config['ATTENTION_DIM'],
        return_sequences=True,
        return_state=True
    )
    decoder_attention = Attention()
    layer_norm = LayerNormalization(axis=-1)
    decoder_output_dense = Dense(config['VOCAB_SIZE'])

    gru_output, gru_state = decoder_gru(embed_x, initial_state=gru_state_input)
    context_vector = decoder_attention([gru_output, encoder_output_input])
    addition_output = Add()([gru_output, context_vector])
    layer_norm_output = layer_norm(addition_output)
    decoder_output = decoder_output_dense(layer_norm_output)

    decoder = tf.keras.Model(
        inputs=[word_input, gru_state_input, encoder_output_input],
        outputs=[decoder_output, gru_state]
    )

    # Load weights
    encoder.load_weights(os.path.join(model_dir, 'encoder.weights.h5'))
    decoder.load_weights(os.path.join(model_dir, 'decoder.weights.h5'))

    return encoder, decoder, tokenizer, config

In [None]:
# Define beam search prediction
def beam_search(image_features, decoder, tokenizer, config, beam_width=3):
    """Generate a caption for encoded image features with beam search.

    Every beam holds a token sequence, its cumulative negative
    log-likelihood (lower is better) and the GRU state reached so far. At
    each step the unfinished beams are extended by their ``beam_width`` most
    likely next tokens and the best ``beam_width`` candidates are kept.

    Args:
        image_features: Encoder output for one image (batch size 1).
        decoder: Model called as ``decoder([token, state, image_features])``
            and returning ``(logits, new_state)``.
        tokenizer: Fitted tokenizer; only its vocabulary is used.
        config: Dict providing ``'ATTENTION_DIM'`` and ``'MAX_CAPTION_LEN'``.
        beam_width: Number of beams kept after each step.

    Returns:
        List of vocabulary words for the best sequence, including the
        ``<start>`` marker and, when the sequence finished, ``<end>``.
    """
    vocab = tokenizer.get_vocabulary()
    word_to_index = StringLookup(mask_token="", vocabulary=vocab)

    # Resolve the marker ids once, as plain Python ints, instead of running
    # the lookup layer again for every beam at every step.
    start_token = int(word_to_index("<start>").numpy())
    end_token = int(word_to_index("<end>").numpy())

    # Initialize with proper shape (1, ATTENTION_DIM)
    gru_state = tf.zeros((1, config['ATTENTION_DIM']))

    beams = [([start_token], 0.0, gru_state)]
    completed_sequences = []

    for _ in range(config['MAX_CAPTION_LEN']):
        candidates = []

        for seq, score, curr_state in beams:
            if seq[-1] == end_token:
                completed_sequences.append((seq, score))
                continue

            # Ensure proper shapes for decoder input
            dec_input = tf.expand_dims([seq[-1]], 0)  # Shape: (1, 1)
            curr_state = tf.reshape(curr_state, (1, config['ATTENTION_DIM']))

            predictions, new_state = decoder(
                [dec_input, curr_state, image_features]
            )

            # One log-softmax per beam; also avoids log(0) for tiny probabilities.
            log_probs = tf.nn.log_softmax(predictions[0, 0])
            top_log_probs, top_indices = tf.math.top_k(log_probs, k=beam_width)

            for log_prob, token_id in zip(top_log_probs.numpy(), top_indices.numpy()):
                candidates.append(
                    (seq + [int(token_id)], score - float(log_prob), new_state)
                )

        if not candidates:
            # Every surviving beam has already produced <end>.
            break

        beams = sorted(candidates, key=lambda x: x[1])[:beam_width]
    else:
        # Step budget exhausted: beams that emitted <end> on the very last
        # step have not been recorded as completed yet.
        completed_sequences.extend(
            (seq, score) for seq, score, _ in beams if seq[-1] == end_token
        )

    if completed_sequences:
        best_seq = min(completed_sequences, key=lambda x: x[1])[0]
    else:
        best_seq = beams[0][0]

    return [vocab[idx] for idx in best_seq]

In [None]:
# Define caption prediction and display function
def predict_and_display_caption(image_path, encoder, decoder, tokenizer, config):
    """Caption the image at ``image_path`` and show it with the caption.

    Args:
        image_path: Path to a JPEG image.
        encoder: Image encoder model.
        decoder: Step-wise decoder model, as expected by ``beam_search``.
        tokenizer: Tokenizer providing the vocabulary.
        config: Dict with ``'IMG_HEIGHT'``, ``'IMG_WIDTH'``, ``'IMG_CHANNELS'``
            and the keys required by ``beam_search``.

    Raises:
        Exception: Any error raised while predicting or plotting is printed
            and then re-raised unchanged.
    """
    try:
        # Process image
        img = tf.image.decode_jpeg(tf.io.read_file(image_path), channels=config['IMG_CHANNELS'])
        img = tf.image.resize(img, (config['IMG_HEIGHT'], config['IMG_WIDTH']))
        # Scale to [0, 1], matching the training pipeline (get_image_label
        # divides by 255); a different input range would not match what the
        # trained encoder has seen.
        img = img / 255
        features = encoder(tf.expand_dims(img, axis=0))

        # Generate caption using beam search
        caption_tokens = beam_search(features, decoder, tokenizer, config)
        # Remove only the markers. Slicing off the last token would drop a
        # real word whenever the caption ended without <end>.
        caption = " ".join(
            word for word in caption_tokens if word not in ("<start>", "<end>")
        )
        caption = caption + "."

        # Display image and caption
        plt.figure(figsize=(12, 8))
        with Image.open(image_path) as image:  # close the file handle afterwards
            plt.imshow(image)
        plt.axis('off')
        plt.title(f"Generated Caption:\n{caption}", pad=20, wrap=True, fontsize=12)
        plt.tight_layout()
        plt.show()

    except Exception as e:
        print(f"Error in prediction: {str(e)}")
        raise

In [None]:
# Test model: rebuild the networks from the files in the default save directory.
encoder, decoder, tokenizer, config = load_models()

# Test on an image (path is relative to the current working directory)
image_path = "images/dog1.jpeg"
predict_and_display_caption(image_path, encoder, decoder, tokenizer, config)