# Setting Up

## Installing requirements

In [None]:
# List the datasets attached to this Kaggle session; the notebook expects the caption
# dataset (DATASET_NAME) and, optionally, the pre-computed features dataset (FEATURES_NAME).
!ls /kaggle/input/

In [None]:
!pip install tensorflow keras tensorboard scikit-learn ipywidgets

## Imports

In [None]:
from pathlib import Path
import os
import sys

import random
from PIL import Image 

from datetime import datetime

import numpy as np
import pandas as pd

import pickle

import matplotlib.pyplot as plt

import tensorflow as tf
import keras
from keras.applications.vgg19 import VGG19, preprocess_input
from keras.layers import Input, Dense, Dropout, Embedding, LSTM, add, Flatten
from keras.preprocessing.image import load_img, img_to_array
from keras.utils import to_categorical, plot_model
from keras import Model

from sklearn.model_selection import train_test_split

# notebook specific
from IPython.display import display
# from tqdm.notebook import tqdm # for some reason this doesn't work
from tqdm import tqdm

# Some Parameters

In [None]:
# constants
TRAIN_CNN: bool = False  # NOTE(review): not referenced elsewhere in this notebook
NUM_OUTPUT_CAPTIONS: int = 1  # NOTE(review): not referenced elsewhere in this notebook
IMAGE_INPUT_SHAPE: tuple[int, int, int] = (224, 224, 3)  # (height, width, channel)

DATASET_NAME: str = "flickr30k"
FEATURES_NAME: str = "flickr-31k-features-all"  # Kaggle input dataset containing features.pkl

FILTER_NON_ALPHA_NUMERIC_STRINGS: bool = True  # drop short (< 3 chars) non-alphabetic tokens

EVALUATE_AFTER_TRAIN: bool = False
LOAD_PRETRAINED: bool = False  # make it possible to continue training from a saved model
TAKE_FEATURES_FROM_INPUT: bool = True  # load features from an already saved file, as inferring them takes a long time
CONTINUE_FROM_WANDB_RUN: bool = False  # if needed to load pretrained model from wandb
USE_MULTIPLE_GPUS: bool = True

# Hyperparameters
N_EPOCHS: int = 35
BATCH_SIZE: int = 1024
DEBUG: bool = False
TRAIN_TEST_VAL_SPLIT: tuple[int, int, int] = (70, 20, 10)  # (train, test, validation) proportions
SEED: int = 42  # seed for the random number generators below

# paths
WORKING_DIR = Path("/kaggle/working")
INPUT_DIR = Path("/kaggle/input/")
DATASET_INPUT_DIR = INPUT_DIR.joinpath(DATASET_NAME)
TEMP_DIR = Path("/tmp")

# Seed Python, NumPy and TensorFlow so the train/test/validation split and the data
# shuffling are repeatable between runs (GPU kernels may still be non-deterministic).
random.seed(SEED)
np.random.seed(SEED)
tf.random.set_seed(SEED)


In [None]:
import logging

# Configure the root logger. force=True replaces any handlers that are already installed
# (basicConfig is otherwise a no-op when the root logger has handlers).
_log_level = logging.DEBUG if DEBUG else logging.INFO
logging.basicConfig(force=True, level=_log_level)
logger = logging.getLogger("-^-")

In [None]:
# Verify that every directory the notebook relies on exists; log each missing one.
all_good = True
for required_dir in (WORKING_DIR, DATASET_INPUT_DIR, TEMP_DIR):
    if os.path.exists(required_dir):
        continue
    logger.error(f"The directory {required_dir} doesn't exist.")
    all_good = False

if all_good:
    logger.info("All the directories are valid.")

# Dataset Preparation

In [None]:
images_dir = DATASET_INPUT_DIR / "Images"
captions_file = DATASET_INPUT_DIR / "captions.txt"

# Pre-computed features are read from a Kaggle input dataset; otherwise they are
# computed into a writable directory under the working dir, which must exist first.
features_dir = INPUT_DIR / FEATURES_NAME if TAKE_FEATURES_FROM_INPUT else WORKING_DIR / "features"
if not TAKE_FEATURES_FROM_INPUT:
    features_dir.mkdir(exist_ok=True)

## Creating the vocabulary

In [None]:
# Load the (image, caption) pairs. Rows with a missing value are dropped *before* the string
# conversion; converting first would turn NaN into the literal text "nan". (Previously
# `captions_data.astype(str)` was called without assigning the result, so it did nothing.)
captions_data = pd.read_csv(captions_file).dropna().astype(str)
logger.info(captions_data.head())

In [None]:
# Keep only the rows whose caption has between lower_limit and upper_limit words (inclusive).
# Very short captions carry little signal, and very long ones force every sequence to be
# padded to a large maximum length, so the network would mostly be trained on padding.
upper_limit = 30
lower_limit = 6

_caption_word_counts = captions_data["caption"].apply(lambda text: len(str(text).split()))
captions_data = captions_data[_caption_word_counts.between(lower_limit, upper_limit)]

captions_data["image"].nunique()

In [None]:
# Build the vocabulary from the lower-cased, whitespace-tokenised captions.
# Sorting makes the word -> id assignment identical across sessions: the iteration order
# of a set of strings depends on the per-process hash seed.
unfiltered_vocabulary = sorted(set(" ".join(captions_data["caption"].to_list()).lower().split()))

removed_items = []

if FILTER_NON_ALPHA_NUMERIC_STRINGS:
    # Tokens shorter than 3 characters that are not purely alphabetic (e.g. "2", "a,") are dropped.
    vocabulary = [word for word in unfiltered_vocabulary if len(word) >= 3 or word.isalpha()]
    removed_items += [word for word in unfiltered_vocabulary if len(word) < 3 and not word.isalpha()]
else:
    vocabulary = unfiltered_vocabulary

# Sets give O(1) membership tests; scanning lists here was quadratic in the vocabulary size.
_removed_set = set(removed_items)
filtered_captions_data = captions_data.copy()
filtered_captions_data["caption"] = filtered_captions_data["caption"].apply(
    lambda caption: " ".join(word for word in caption.lower().split() if word not in _removed_set)
)

vocabulary_from_filtered_captions_data = sorted(set(" ".join(filtered_captions_data["caption"].to_list()).lower().split()))

# Sanity check: every word left in the captions must have an id in the vocabulary.
_missing_words = set(vocabulary_from_filtered_captions_data).difference(vocabulary)
if _missing_words:
    logger.error(f"Word: '{min(_missing_words)}' not in vocabulary")
    raise Exception("Found a word that is not in the vocabulary")

captions_data = filtered_captions_data

logger.info(f"Removed items: {removed_items}")

logger.info(f"Total unique words: {len(vocabulary)}")

In [None]:
class VocabHandler():
    """
    Handles vocabulary. Indices start from 1 since 0 is reserved for padding.
    """
    word_to_id_dict: dict[str, int] = {}
    id_to_word_dict: dict[int, str] = {}
    vocab_size = 0
        
    def __init__(self, vocabulary: list[str], start_word: str = "<start>", stop_word: str = "<stop>", count_padding_as_separate_word: bool = True, padding: str = "<padding>"):
        """
        ID 0 is used for padding
        
        Parameters
        ----------
        vocabulary: The list of words
        start_word: special word for denoting the start of generation
        stop_word: special word for denoting the stop of generation
        count_padding_as_separate_word: if false, there won't be a entry for ID 0
        padding: the special word place where id=0
        """
        
        self.start_word: str = start_word
        self.stop_word: str = stop_word
        
        if count_padding_as_separate_word:
            self.padding = padding
            self.word_to_id_dict[self.padding] = 0
            self.id_to_word_dict[0] = self.padding
        
        # adding start word in the vocabulary
        self.word_to_id_dict[self.start_word] = 1
        self.id_to_word_dict[1] = self.start_word
        
        last_index = 0
        for idx, word in enumerate(vocabulary):
            self.word_to_id_dict[word] = idx + 2
            self.id_to_word_dict[idx + 2] = word 
            last_index = idx + 2
            
        # adding start word in the vocabulary
        self.word_to_id_dict[self.stop_word] = last_index + 1
        self.id_to_word_dict[last_index + 1] = self.stop_word
        
        assert len(self.word_to_id_dict) == len(self.id_to_word_dict)
        
        self.vocab_size = len(self.word_to_id_dict)

    def id_of(self, word: str) -> int | None:
        return self.word_to_id_dict[word]

    def word_of(self, idx: int) -> str | None:
        return self.id_to_word_dict[idx]
    
    def text_to_sequence(self, text: str, max_length: int = 0, padding: bool = False, pad_with: int = 0) -> np.ndarray:
        
        words = text.split()
        
        if not padding or max_length < 1:
            if max_length < 1:
                logger.error(f"The provided maximum length {max_length} is invalid.")
            return np.array(list(map(lambda x: self.id_of(x), words)))
        
        len_words = len(words)
        
        padded_sequence = np.full(max_length, pad_with)
        padded_sequence[:len_words] = np.array(list(map(lambda x: self.id_of(x), words)))
        
        
        return padded_sequence
    
    def sequence_to_text(self, sequence: np.ndarray, padded: bool = False, padded_with: int = 0):
        if not padded:
            return " ".join(map(lambda x: self.word_of(x), sequence))
        
        return " ".join(filter(lambda y: y!="", map(lambda x: self.word_of(x) if x!= padded_with else "", sequence)))
    
    def save(self, file_location: Path):
        pickle.dump(self.word_to_id_dict, open(file_location.joinpath("word-to-id-dict.pkl"), 'wb'))
        pickle.dump(self.id_to_word_dict, open(file_location.joinpath("id-to-word-dict.pkl"), 'wb'))

    def load(self, file_location: Path):
        self.word_to_id_dict = pickle.load(open(file_location.joinpath("word-to-id-dict.pkl"), 'rb'))
        self.id_to_word_dict= pickle.load(open(file_location.joinpath("id-to-word-dict.pkl"), 'rb'))
    

In [None]:
# Vocabulary used throughout the notebook (default special words: padding id 0, start id 1).
default_vocab_handler = VocabHandler(vocabulary=vocabulary)


In [None]:
# Sanity-check the freshly built vocabulary, then save it for later sessions.
# A fixed id such as 12414 may not exist for a smaller vocabulary; the middle id of the
# 0 .. vocab_size - 1 range always does for this handler.
_sample_id = default_vocab_handler.vocab_size // 2

logger.info(f"Vocab Size: {default_vocab_handler.vocab_size}")
logger.info(f"Id of young is {default_vocab_handler.id_of('young')}")
logger.info(f"The word corresponding to id {_sample_id} is {default_vocab_handler.word_of(_sample_id)}")

logger.info(f"Id of {default_vocab_handler.stop_word} is {default_vocab_handler.id_of(default_vocab_handler.stop_word)}")

# saving the vocab handler for later use
VOCAB_HANDLER_SAVE_PATH = WORKING_DIR.joinpath("vocab-handler")
VOCAB_HANDLER_SAVE_PATH.mkdir(exist_ok=True)
default_vocab_handler.save(VOCAB_HANDLER_SAVE_PATH)

In [None]:
# Round-trip check: reload the files just written and confirm the mapping is unchanged.
_word_to_id_before_reload = dict(default_vocab_handler.word_to_id_dict)
default_vocab_handler.load(VOCAB_HANDLER_SAVE_PATH)
assert default_vocab_handler.word_to_id_dict == _word_to_id_before_reload, "Reloaded vocabulary differs from the saved one"

_sample_id = default_vocab_handler.vocab_size // 2

logger.info(f"Vocab Size: {default_vocab_handler.vocab_size}")
logger.info(f"Id of young is {default_vocab_handler.id_of('young')}")
logger.info(f"The word corresponding to id {_sample_id} is {default_vocab_handler.word_of(_sample_id)}")

logger.info(f"Id of {default_vocab_handler.stop_word} is {default_vocab_handler.id_of(default_vocab_handler.stop_word)}")


In [None]:
# Longest caption, in words. Text inputs are padded to a fixed length, so this bounds the
# sequence length the model must handle.
_words_per_caption = captions_data["caption"].map(lambda text: len(text.split()))
maximum_length = max(_words_per_caption)
logger.info(f"The maximum number of words is {maximum_length}")

# Leave room for the <start> and <stop> words that wrap every caption.
absolute_max_length = maximum_length + 2


In [None]:
# Round-trip a known caption through the vocabulary: text -> padded ids -> text.
sample_text = "<start> Two young guys with shaggy hair look at their <stop>".lower()
output_sequence = default_vocab_handler.text_to_sequence(sample_text, absolute_max_length, padding = True)
output_text = default_vocab_handler.sequence_to_text(output_sequence, padded=True)
logger.info(f"Input Text: {sample_text}")
logger.info(f"Output Sequence: {output_sequence}.")
logger.info(f"Output Text: {output_text}")

# Dropping the padding ids must give back exactly the input text.
assert output_text == sample_text, "Vocabulary round trip changed the text"


## Pre-computing image features

In [None]:
def get_base_model():
    """
    Build the image encoder: VGG19 with ImageNet weights, truncated so it outputs the
    activations of its second-to-last layer instead of the class probabilities.
    """
    vgg19 = VGG19(
        include_top=True,
        weights="imagenet",
        input_tensor=None,
        input_shape=IMAGE_INPUT_SHAPE,
        # `pooling` is only used by Keras when include_top=False, so it is not passed here.
    )
    base_model = Model(inputs=vgg19.inputs, outputs=vgg19.layers[-2].output)
    # Freeze the encoder. `trainable` is the Keras attribute for this; the previous
    # `training = False` set an attribute that is not part of the Keras Model API.
    base_model.trainable = False

    return base_model

In [None]:
from contextlib import nullcontext

# Optionally replicate models across all visible GPUs. When USE_MULTIPLE_GPUS is False,
# `strategy` is never defined, so every later use is guarded by the same flag.
if USE_MULTIPLE_GPUS:
    logger.info("Using multiple GPUs with mirrored strategy.")
    strategy = tf.distribute.MirroredStrategy()
    logger.info(f"Number of devices: {strategy.num_replicas_in_sync}")

with strategy.scope() if USE_MULTIPLE_GPUS else nullcontext():
    base_model = get_base_model()

    # Output shape including the batch dimension; FEATURE_SHAPE[1] is the feature width.
    FEATURE_SHAPE = base_model.output.shape
    logger.info(f"Feature shape: {FEATURE_SHAPE}")
    base_model.summary()

In [None]:
def extract_image_features_and_save(images, batch_size=BATCH_SIZE, save_to_file=False) -> dict | None:
    """
    Run every image through `base_model` and collect the resulting feature vectors.

    Parameters
    ----------
    images: image file names, relative to `images_dir`
    batch_size: number of images sent to the encoder per `predict` call
    save_to_file: if True, also pickle the dict to `features_dir / "features.pkl"`

    Returns
    -------
    dict mapping image file name -> feature vector (returned whether or not it is saved)
    """
    features = {}
    target_size = IMAGE_INPUT_SHAPE[:2]  # load_img expects (height, width)

    with strategy.scope() if USE_MULTIPLE_GPUS else nullcontext():
        for start in tqdm(range(0, len(images), batch_size)):
            batch_names = images[start: start + batch_size]
            batch_pixels = np.array([
                img_to_array(load_img(images_dir.joinpath(name), target_size=target_size))
                for name in batch_names
            ])
            predictions = base_model.predict(preprocess_input(batch_pixels), verbose=0)
            features.update(zip(batch_names, predictions))

    if save_to_file:
        # `with` closes the file so the pickle is fully flushed to disk.
        with open(features_dir.joinpath('features.pkl'), 'wb') as output_file:
            pickle.dump(features, output_file)

    return features

images = captions_data["image"].unique().tolist()

if TAKE_FEATURES_FROM_INPUT:
    # NOTE: pickle.load can execute arbitrary code; only load feature files you trust.
    with open(features_dir.joinpath('features.pkl'), 'rb') as f:
        features = pickle.load(f)
else:
    features = extract_image_features_and_save(images, save_to_file=False)

In [None]:
# Every captioned image needs a feature vector; otherwise the data generator would fail
# with a KeyError in the middle of training.
_images_without_features = set(images) - set(features)
assert not _images_without_features, (
    f"{len(_images_without_features)} images have no features, e.g. {next(iter(_images_without_features))}"
)
len(features)

## Data Generator

In [None]:
def data_generator(training_ids: list[str], vocab_handler: VocabHandler, max_length: int, batch_size: int):
    """
    Generate an infinite stream of training batches for the captioning model.

    Every caption is wrapped in start/stop words and expanded into one sample per position i:
    the input is (image features, padded prefix words[:i]) and the target is the one-hot
    encoded word words[i].

    Parameters
    ----------
    training_ids: image file names to draw samples from (the caller's list is not modified)
    vocab_handler: maps words to ids
    max_length: padded length of the text input
    batch_size: samples per batch; the last batch of each pass over the data may be smaller

    Yields
    ------
    ((image_features, text_inputs), one_hot_targets) as numpy arrays

    Raises
    ------
    ValueError: the given images produce no samples at all (the generator would otherwise
        spin forever without yielding)
    """
    vocab_size = vocab_handler.vocab_size
    # Shuffle a private copy; shuffling the argument would reorder e.g. the global `test` list.
    image_ids = list(training_ids)

    # Group the captions by image once, instead of scanning the whole frame for every image.
    captions_by_image = (
        captions_data[captions_data["image"].isin(image_ids)]
        .groupby("image")["caption"]
        .apply(list)
        .to_dict()
    )

    while True:
        img_features_input, text_inputs, text_outputs = [], [], []
        samples_in_pass = 0

        random.shuffle(image_ids)
        for img_id in image_ids:
            # Flatten so that features stored as (1, n) or (n,) both give an (n,) vector;
            # the former `feature[0]` only worked for the (1, n) layout.
            feature = np.reshape(features[img_id], -1)

            for caption in captions_by_image.get(img_id, []):
                words = [vocab_handler.start_word, *caption.split(), vocab_handler.stop_word]

                for i in range(1, len(words)):
                    img_features_input.append(feature)
                    text_inputs.append(vocab_handler.text_to_sequence(" ".join(words[:i]), max_length, True))
                    text_outputs.append(to_categorical([vocab_handler.id_of(words[i])], num_classes=vocab_size)[0])
                    samples_in_pass += 1

                    if len(text_outputs) == batch_size:
                        yield (np.array(img_features_input), np.array(text_inputs)), np.array(text_outputs)
                        img_features_input, text_inputs, text_outputs = [], [], []

        if samples_in_pass == 0:
            raise ValueError("data_generator: the given image ids produce no training samples")

        # Yield any remaining samples that didn't make a full batch
        if text_outputs:
            yield (np.array(img_features_input), np.array(text_inputs)), np.array(text_outputs)


In [None]:
import math

# Visual sanity check of the generator: decode a few samples back to text and show each
# image feature vector reshaped into a square.
n_samples = 2
_test_images = ["1000092795.jpg", "10002456.jpg", "1000268201.jpg", "1000344755.jpg", "1000366164.jpg", "1000523639.jpg", "1000919630.jpg", "10010052.jpg", "1001465944.jpg", "1001545525.jpg", "1001573224.jpg", "1001633352.jpg", "1001773457.jpg", "1001896054.jpg", "100197432.jpg", "100207720.jpg", "1002674143.jpg", "1003163366.jpg", "1003420127.jpg"]
random.shuffle(_test_images)
_data_generator = data_generator(_test_images, default_vocab_handler, absolute_max_length, 2)

for _ in range(n_samples):
    (img_features, txt_input), text_outputs = next(_data_generator)

    for img_feature, txt_i, txt_o in zip(img_features, txt_input, text_outputs):
        print(f"Input: {default_vocab_handler.sequence_to_text(txt_i, padded=True)}")
        print(f"To Predict: {default_vocab_handler.word_of(np.argmax(txt_o))}")
        # Derive the square side from the feature length instead of hard-coding 64 x 64.
        side = math.isqrt(img_feature.size)
        plt.imshow(img_feature[: side * side].reshape(side, side))
        plt.show()

# Model definition

In [None]:
def get_image_captioning_model(
    feature_dim: int | None = None,
    max_length: int | None = None,
    vocab_size: int | None = None,
    units: int = 256,
    dropout_rate: float = 0.4,
):
    """
    Build the merge-style captioning model: the image features and the partial caption are
    encoded separately, added together, and decoded into a softmax over the vocabulary.

    Parameters (all optional; the defaults reproduce the original configuration)
    ----------
    feature_dim: length of an image feature vector (default: FEATURE_SHAPE[1])
    max_length: padded length of the text input (default: absolute_max_length)
    vocab_size: number of token ids (default: default_vocab_handler.vocab_size)
    units: width of the embedding, LSTM and dense layers
    dropout_rate: dropout rate applied to the image features and to the word embeddings

    Returns
    -------
    An uncompiled keras Model with inputs [features, token_ids] and a softmax output.
    """
    if feature_dim is None:
        feature_dim = FEATURE_SHAPE[1]
    if max_length is None:
        max_length = absolute_max_length
    if vocab_size is None:
        vocab_size = default_vocab_handler.vocab_size

    # image branch: some trainable layers before merging
    feature_input_layer = Input(shape=(feature_dim,))
    dropout_1 = Dropout(dropout_rate)(feature_input_layer)
    image_feature_output = Dense(units, activation='relu')(dropout_1)

    # text branch; mask_zero=True marks id 0 (padding) as masked for the LSTM
    text_input_layer = Input(shape=(max_length, ))
    embed = Embedding(vocab_size, units, mask_zero=True)(text_input_layer)
    dropout_2 = Dropout(dropout_rate)(embed)
    text_feature_output = LSTM(units)(dropout_2)

    # decoding: both branches have width `units`, so they can be added element-wise
    combine = add([image_feature_output, text_feature_output])
    dense_decoder = Dense(units, activation='relu')(combine)
    outputs = Dense(vocab_size, activation='softmax')(dense_decoder)

    model = Model(inputs=[feature_input_layer, text_input_layer], outputs=outputs)

    return model


In [None]:

# Location of a previously trained model (only used when LOAD_PRETRAINED is True).
PRETRAINED_MODEL_PATH = '/kaggle/input/1024_batch_size_35_epoch/keras/default/1/models.keras'

# Build (or load) the model inside the distribution strategy scope when using several GPUs.
with strategy.scope() if USE_MULTIPLE_GPUS else nullcontext():

    if LOAD_PRETRAINED:
        logger.info("Loading the pretrained model.")
        model = tf.keras.models.load_model(PRETRAINED_MODEL_PATH)
    else:
        model = get_image_captioning_model()

        # compiling
        model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

plot_model(model, to_file="model.png", show_shapes=True, show_layer_names=True)

# Model Training

## Setting Up Tensorboard

In [None]:
# One TensorBoard log directory per run, named after the start time, under the working dir.
_run_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
logdir = str(WORKING_DIR / "logs" / "scalars" / _run_timestamp)
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

In [None]:
# doesn't work on kaggle

# # Load the TensorBoard notebook extension.
# %load_ext tensorboard

# %tensorboard --logdir /kaggle/working/logs/scalars

## Setting Up Wandb

In [None]:
import wandb
from wandb.integration.keras import WandbMetricsLogger, WandbModelCheckpoint, WandbCallback

# SECURITY: never hard-code the API key in the notebook; anyone who can read the notebook
# can use the account. Provide it via the WANDB_API_KEY environment variable (e.g. from a
# Kaggle secret). The key previously written here must be treated as leaked and revoked.
wandb.login(key=os.environ.get("WANDB_API_KEY"))

# Initialize a new W&B run, recording the hyperparameters that are actually used
# (the config previously logged a batch size of 12, not BATCH_SIZE).
run = wandb.init(
    config={"bs": BATCH_SIZE, "epochs": N_EPOCHS, "train_test_val_split": TRAIN_TEST_VAL_SPLIT},
    project="Image Caption Generator-Multi-GPU Pre-Infer Images",
)

metric_logger_callback = WandbMetricsLogger(log_freq="batch")
model_save_callback = WandbModelCheckpoint(filepath="models.keras", save_freq="epoch")
# Upload the saved vocabulary files so that a later run can restore them with the model.
wandb.save(str(VOCAB_HANDLER_SAVE_PATH)+"/*", base_path=str(WORKING_DIR))

In [None]:
# Optionally resume from a model, and its vocabulary, logged by a previous wandb run.
if CONTINUE_FROM_WANDB_RUN:
    entity = "kunepal"
    project = "Image Caption Generator-Multi-GPU Pre-Infer Images"
    alias = "latest"  # semantic nickname or identifier for the model version
    model_artifact_name = "run_cq1wzywj_model"

    # Access and download model. Returns path to downloaded artifact
    downloaded_model_path = run.use_model(name=f"{entity}/{project}/{model_artifact_name}:{alias}")

    # The artifact is named "run_<run id>_model"; the vocabulary files belong to that run.
    run_id = model_artifact_name.removeprefix("run_").removesuffix("_model")
    run_path = f"/{entity}/{project}/{run_id}"
    id_to_word_dict_save_path = wandb.restore("vocab-handler/id-to-word-dict.pkl", run_path=run_path)
    word_to_id_dict_save_path = wandb.restore("vocab-handler/word-to-id-dict.pkl", run_path=run_path)

    loading_path = Path(id_to_word_dict_save_path.name).parent
    logger.info(f"Loading from the path {loading_path}.")

    default_vocab_handler.load(loading_path)

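### Callback to run garbage collection at the end of each epoch
Keras memory use can keep growing from epoch to epoch. The callback below calls `gc.collect()` after every epoch; it does not restart or clear the Keras backend. Background: https://stackoverflow.com/questions/53683164/keras-occupies-an-indefinitely-increasing-amount-of-memory-for-each-epoch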

In [None]:
import gc
from tensorflow.keras import backend as k
from tensorflow.keras.callbacks import Callback

class ClearMemory(Callback):
    """Keras callback that runs Python's garbage collector once every epoch has finished."""

    def on_epoch_end(self, epoch, logs=None):
        # Only a gc pass; keras.backend.clear_session() is not called here.
        gc.collect()

In [None]:
# Stop training once the training loss has not improved for 3 consecutive epochs.
early_stopping_callback = keras.callbacks.EarlyStopping(patience=3, monitor='loss')

## Training

In [None]:
# Each unique image file is one unit to be split between train / test / validation.
all_img_ids = list(captions_data["image"].drop_duplicates())

logger.info(f"Total samples: {len(all_img_ids)}")

In [None]:
def calculate_total_samples(data):
    # Calculate the length of captions
    length_of_caption_df = captions_data.copy()
    length_of_caption_df["length_of_caption"] = captions_data["caption"].apply(lambda x: len(str(x).split()))
    
    # Filter the dataframe to only include images in the given data
    filtered_df = length_of_caption_df[length_of_caption_df["image"].isin(data)]
    
    # Sum the lengths directly
    total_samples = filtered_df["length_of_caption"].sum()
    
    return total_samples

calculate_total_samples(["1000092795.jpg", "1000366164.jpg"])

In [None]:
import math

# Split the unique images (not individual captions) so no image appears in two splits.
# TRAIN_TEST_VAL_SPLIT is (train, test, validation).
train, test_and_validation = train_test_split(all_img_ids, test_size=((TRAIN_TEST_VAL_SPLIT[1] + TRAIN_TEST_VAL_SPLIT[2]) / sum(TRAIN_TEST_VAL_SPLIT)))
test, validation = train_test_split(test_and_validation, test_size=TRAIN_TEST_VAL_SPLIT[2] / (TRAIN_TEST_VAL_SPLIT[1] + TRAIN_TEST_VAL_SPLIT[2]))

total_training_samples = calculate_total_samples(train)
total_validation_samples = calculate_total_samples(validation)

# data_generator ends each pass over the data with a possibly partial batch, so round up
# to include it; use at least one step when a split is smaller than one batch.
steps = max(1, math.ceil(total_training_samples / BATCH_SIZE))
val_steps = max(1, math.ceil(total_validation_samples / BATCH_SIZE))

logger.info(f"Total batches in an epoch: {steps}")
logger.info(f"Total batches in validation set: {val_steps}")

generator = data_generator(train, default_vocab_handler, absolute_max_length, BATCH_SIZE)
val_generator = data_generator(validation, default_vocab_handler, absolute_max_length, BATCH_SIZE)
test_generator = data_generator(test, default_vocab_handler, absolute_max_length, BATCH_SIZE)



In [None]:
import math

# Train (optionally resuming from the wandb model) inside the distribution scope.
with strategy.scope() if USE_MULTIPLE_GPUS else nullcontext():
    if CONTINUE_FROM_WANDB_RUN:
        model = tf.keras.models.load_model(downloaded_model_path)
        logger.info("Continuing from the previous training sample from wandb")
    history = model.fit(
        generator,
        epochs=N_EPOCHS,
        steps_per_epoch=steps,
        verbose=1, validation_data=val_generator,
        callbacks=[tensorboard_callback, metric_logger_callback, model_save_callback, ClearMemory(), early_stopping_callback],
        validation_freq=1,
        validation_steps=val_steps)
    if EVALUATE_AFTER_TRAIN:
        # test_generator never ends, so evaluate() needs an explicit number of steps;
        # without `steps` it would keep drawing batches forever.
        test_steps = max(1, math.ceil(calculate_total_samples(test) / BATCH_SIZE))
        model.evaluate(test_generator, steps=test_steps)

# Evaluation

## Saving and reloading the model to remove the multi-GPU dependency

In [None]:
# A model trained with several GPUs expects batches that can be split evenly across the
# replicas, which makes single image/text inference awkward. As a workaround, the model is
# saved and re-loaded here outside the strategy scope.
_temp_model_file = 'temp_model.keras'
model.save(_temp_model_file)

model = tf.keras.models.load_model(_temp_model_file)

## Testing and visualization on a few images

In [None]:
def test_on_image_having_feature(image_id: str):
    """
    Greedily generate a caption for an image whose features are already in `features`.

    Starting from "<start>", the most probable next word is appended one step at a time
    until the stop word is predicted or absolute_max_length steps have run.

    Returns the generated words (start/stop words excluded), each preceded by a space.
    """
    feature_batch = features[image_id].reshape((1, FEATURE_SHAPE[1]))

    prefix = "<start>"
    generated = ""
    for _ in range(absolute_max_length):
        padded_prefix = default_vocab_handler.text_to_sequence(prefix, absolute_max_length, True)
        probabilities = model.predict(
            [feature_batch, padded_prefix.reshape((1, absolute_max_length))], verbose=0
        )
        next_word = default_vocab_handler.word_of(np.argmax(probabilities[0]))

        if next_word == default_vocab_handler.stop_word:
            break
        generated += " " + next_word
        prefix += " " + next_word

    return generated

In [None]:
# Caption a few random test images from their cached features and compare with the references.
n_images_to_test = 2
test_images = random.sample(test, n_images_to_test)

for image_id in test_images:
    whole_text_output = test_on_image_having_feature(image_id)
    reference_captions = captions_data.loc[captions_data['image'] == image_id, 'caption'].tolist()
    img_path = images_dir.joinpath(image_id)

    print(f"Generated: {whole_text_output}")
    print(f"Actual: {reference_captions}")
    plt.imshow(load_img(img_path, target_size=IMAGE_INPUT_SHAPE))
    plt.show()
    

In [None]:
def test_on_new_image(image: Path | np.ndarray | str):
    """
    Greedily generate a caption for an arbitrary image (features computed on the fly).

    Parameters
    ----------
    image: path to an image file (Path or str), or an array that can be reshaped to
        IMAGE_INPUT_SHAPE; the array itself is not modified

    Returns
    -------
    The generated words (start/stop words excluded), each preceded by a space.
    """
    if isinstance(image, (str, Path)):
        image = img_to_array(load_img(image, target_size=IMAGE_INPUT_SHAPE[:2]))

    # Work on a float copy: Keras' VGG preprocess_input may modify a NumPy input in place,
    # which would silently change an array passed in by the caller.
    image_batch = np.array(image, dtype="float32").reshape(1, *IMAGE_INPUT_SHAPE)
    image_input = preprocess_input(image_batch)

    feature_batch = base_model.predict(image_input, verbose=0)[0].reshape((1, FEATURE_SHAPE[1]))

    text_input = "<start>"
    whole_text_output = ""
    for _ in range(absolute_max_length):
        sequence_input = default_vocab_handler.text_to_sequence(text_input, absolute_max_length, True)
        model_input = [feature_batch, sequence_input.reshape((1, absolute_max_length))]

        sequence_output = np.argmax(model.predict(model_input, verbose=0)[0])
        text_output = default_vocab_handler.word_of(sequence_output)

        if text_output == default_vocab_handler.stop_word:
            break
        whole_text_output += " " + text_output
        text_input += " " + text_output

    return whole_text_output


In [None]:
# Caption a few random test images through the full pipeline
# (image file -> encoder features -> caption) and compare with the references.
n_images_to_test = 2
test_images = random.sample(test, n_images_to_test)

for image_id in test_images:
    img_path = images_dir.joinpath(image_id)
    whole_text_output = test_on_new_image(img_path)
    reference_captions = captions_data.loc[captions_data['image'] == image_id, 'caption'].tolist()

    print(f"Generated: {whole_text_output}")
    print(f"Actual: {reference_captions}")
    plt.imshow(load_img(img_path, target_size=IMAGE_INPUT_SHAPE))
    plt.show()

## Testing BLEU Score

In [None]:
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

# Quick BLEU check on the first three test images, decoding one image at a time.
actual = []
predicted = []
for image_id in tqdm(test[:3]):
    whole_text_output = test_on_image_having_feature(image_id)

    reference_captions = captions_data.loc[captions_data['image'] == image_id]['caption'].tolist()
    _actual = [caption.split() for caption in reference_captions]
    _predicted = whole_text_output.split()

    actual.append(_actual)
    predicted.append(_predicted)

    logger.info(f"\n\nActual: {_actual} \n --- \nPredicted: {_predicted}")

score = corpus_bleu(actual, predicted)

logger.info(f"Score: {score}")

### Parallelized

In [None]:
# Re-load the saved model inside the distribution scope (when enabled) for batched inference.
_inference_scope = strategy.scope() if USE_MULTIPLE_GPUS else nullcontext()
with _inference_scope:
    model = tf.keras.models.load_model('temp_model.keras')

In [None]:
def test_parallelized(images, batch_size):
    """
    Greedy caption generation for many images at once, `batch_size` images per model call.

    Decoding for a batch stops once every image in it has produced the stop word, or after
    absolute_max_length steps. An output string can still contain the stop word followed
    by further words, so callers must cut each caption at the first stop word.

    Parameters
    ----------
    images: image file names whose features are in `features`
    batch_size: number of images decoded together

    Returns
    -------
    One string per image, in input order, with every word preceded by a space.
    """
    stop_word = default_vocab_handler.stop_word
    predicted: list[str] = []

    for i in tqdm(range(0, len(images), batch_size)):
        images_batch = images[i: i + batch_size]
        num_images_in_batch = len(images_batch)

        text_inputs = ["<start>"] * num_images_in_batch
        text_outputs = [""] * num_images_in_batch
        finished = [False] * num_images_in_batch

        # The image features do not change between decoding steps, so fill them only once.
        # Flattening accepts features stored either as (n,) or as (1, n).
        feature_inputs = np.zeros(shape=(num_images_in_batch, FEATURE_SHAPE[1]))
        for k, image in enumerate(images_batch):
            feature_inputs[k] = np.reshape(features[image], -1)
        sequence_inputs = np.zeros(shape=(num_images_in_batch, absolute_max_length))

        for _ in range(absolute_max_length):
            for k in range(num_images_in_batch):
                sequence_inputs[k] = default_vocab_handler.text_to_sequence(text_inputs[k], absolute_max_length, True)

            probabilities = model.predict([feature_inputs, sequence_inputs], verbose=0)
            next_words = [default_vocab_handler.word_of(idx) for idx in np.argmax(probabilities, axis=-1)]

            for l, word in enumerate(next_words):
                text_outputs[l] += " " + word
                text_inputs[l] += " " + word
                finished[l] = finished[l] or word == stop_word

            # Nothing after the first stop word is used, so stop once every caption has one.
            if all(finished):
                break

        predicted.extend(text_outputs)

    return predicted


In [None]:

from nltk.translate.bleu_score import sentence_bleu
from tqdm import tqdm

batch_size = 32
with strategy.scope() if USE_MULTIPLE_GPUS else nullcontext():
    raw_predictions = test_parallelized(test, batch_size)

# Reference captions per image, grouped once instead of filtering the frame for every image.
_references_by_image = captions_data.groupby("image")["caption"].apply(list)
_stop_word = default_vocab_handler.stop_word

predicted = []
actual = []
for image_id, raw_text in tqdm(zip(test, raw_predictions), total=len(test)):
    # Batched decoding can continue past <stop>, so keep only the words before the first
    # stop word. (The previous str.find() version sliced with -1 when no stop word was
    # generated and silently dropped the last character.)
    tokens = raw_text.split()
    if _stop_word in tokens:
        tokens = tokens[:tokens.index(_stop_word)]
    predicted.append(tokens)

    actual.append([caption.split() for caption in _references_by_image.get(image_id, [])])
    


In [None]:
# Corpus-level BLEU over the whole test split: every generated caption is scored against
# all reference captions of its image.
score = corpus_bleu(list_of_references=actual, hypotheses=predicted)
logger.info(f"Bleu score: {score}")

In [None]:
# Peek at one test batch. `test_generator.__next__()[]` was a syntax error; show the array
# shapes instead of dumping the (very large) arrays themselves.
(_batch_features, _batch_text_inputs), _batch_targets = next(test_generator)
_batch_features.shape, _batch_text_inputs.shape, _batch_targets.shape