In [None]:
import os
import random
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from image_captioning.constants import *
from image_captioning.data_pipeline import input_dataset, utils
from image_captioning.model import text_vectorization, encoder, decoder, checkpoint_manager
from image_captioning.scripts.train import create_models

# Load data

In [None]:
# Load the annotations, then split them into the train_* / val_* sequences
# that the following cells consume.
all_captions, all_imgpaths, imgpath_to_caption = input_dataset.load_annotations()
(
    train_featurepaths,
    train_captions,
    val_featurepaths,
    val_captions,
) = input_dataset.split_dataset(all_imgpaths, imgpath_to_caption)

In [None]:
# Convert every feature path to an image path with utils.featurepath_to_imgpath.
train_imgpaths = [utils.featurepath_to_imgpath(path) for path in train_featurepaths]
val_imgpaths = [utils.featurepath_to_imgpath(path) for path in val_featurepaths]

# Load model

In [None]:
config = utils.load_json_file(
    os.path.join(PROJECT_PATH, 'config.json')
)

# Width of one attention row: evaluate() allocates its attention buffer with
# this many columns, and plot_attention() displays a 64-value row as 8 x 8.
attention_features_shape = 64

# Upper bound on the number of decoding steps performed by evaluate().
max_length = config['max_text_length']

In [None]:
tokenizer = text_vectorization.load_text_vectorizer(TOKENIZER_PATH)

# Two lookups over the tokenizer's vocabulary: the first is the forward
# StringLookup, the second is the same lookup built with invert=True.
word_to_index = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
)
index_to_word = tf.keras.layers.StringLookup(
    mask_token="",
    vocabulary=tokenizer.get_vocabulary(),
    invert=True,
)

In [None]:
# Image feature extractor applied to the preprocessed image in evaluate().
inceptionV3 = encoder.create_inception_v3()

# Encoder / decoder / optimizer triple, sized by the configured vocabulary.
cnn_encoder, rnn_decoder, optimizer = create_models(config['vocabulary_size'])

# restore_latest=True is passed through to the project's checkpoint helper.
ckpt_manager = checkpoint_manager.create_checkpoint_manager(
    cnn_encoder,
    rnn_decoder,
    optimizer,
    restore_latest=True,
)

# Evaluate

In [None]:
def evaluate(image):
    """Generate a caption for one image by sampling from the decoder.

    The image is loaded with ``input_dataset.load_and_preprocess_image``, passed
    through ``inceptionV3`` and ``cnn_encoder``, and then decoded one token at a
    time starting from ``'<start>'``. Each token is drawn with
    ``tf.random.categorical``, so repeated calls on the same image may return
    different captions.

    Relies on notebook globals: ``max_length``, ``attention_features_shape``,
    ``inceptionV3``, ``cnn_encoder``, ``rnn_decoder``, ``word_to_index`` and
    ``index_to_word``.

    Args:
        image: Value forwarded to ``input_dataset.load_and_preprocess_image``
            (the notebook passes an image path).

    Returns:
        A ``(result, attention_plot)`` tuple. ``result`` is the list of
        predicted words; it ends with ``'<end>'`` when the decoder emits that
        token within ``max_length`` steps. ``attention_plot`` is a NumPy array
        of shape ``(len(result), attention_features_shape)`` whose row ``i``
        holds the flattened attention weights produced for ``result[i]``.
    """
    attention_plot = np.zeros((max_length, attention_features_shape))

    hidden = rnn_decoder.reset_state(batch_size=1)

    # Only element 0 of the loader's return value is used; a leading axis of
    # size 1 is added before the tensor is passed to inceptionV3.
    temp_input = tf.expand_dims(input_dataset.load_and_preprocess_image(image)[0], 0)
    img_tensor_val = inceptionV3(temp_input)
    # Keep the first and fourth dimensions, merging everything in between.
    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0], -1, img_tensor_val.shape[3]))

    features = cnn_encoder(img_tensor_val)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for i in range(max_length):
        predictions, hidden, attention_weights = rnn_decoder(dec_input, features, hidden)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        # Sample the next token id instead of taking the arg-max.
        predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        if predicted_word == '<end>':
            break

        # Feed the sampled token back in as the next decoder input.
        dec_input = tf.expand_dims([predicted_id], 0)

    # Trim the unused (all-zero) rows on every exit path so the attention
    # array always has exactly one row per returned word.
    return result, attention_plot[:len(result), :]

In [None]:
def plot_attention(image, sentence, attention_weights, n_cols=4):
    """Overlay the per-word attention maps on the image, one subplot per word.

    Args:
        image: Value forwarded to ``input_dataset.load_image`` (the notebook
            passes an image path).
        sentence: Sequence of words; ``sentence[i]`` becomes the title of
            subplot ``i``.
        attention_weights: Indexable collection where ``attention_weights[i]``
            is the flat attention vector for ``sentence[i]``.
        n_cols: Number of subplot columns (default 4, the previously
            hard-coded value).

    Returns:
        None. The figure is created through ``matplotlib.pyplot``. Nothing is
        drawn when ``sentence`` is empty.
    """
    n_words = len(sentence)
    if n_words == 0:
        # A zero-row grid makes plt.subplots raise; there is nothing to plot.
        return

    temp_image = input_dataset.load_image(image).numpy()
    # Height / width of the image, used to keep each subplot's proportions.
    ratio = temp_image.shape[0] / temp_image.shape[1]

    n_rows = int(np.ceil(n_words / n_cols))

    # squeeze=False always yields a 2-D axes array, so ravel() works for any
    # combination of n_rows and n_cols (including a 1 x 1 grid).
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(5 * n_cols, 5 * n_rows * ratio),
        squeeze=False,
    )
    axes_raveled = axes.ravel()

    for i in range(n_words):
        ax = axes_raveled[i]
        weights = np.asarray(attention_weights[i])
        side = int(round(np.sqrt(weights.size)))
        if side > 0 and side * side == weights.size:
            # Square feature grid: 64 values -> 8 x 8, exactly as before.
            temp_att = weights.reshape(side, side)
        else:
            # Non-square length: keep the original 8 x 8 np.resize behaviour.
            temp_att = np.resize(weights, (8, 8))
        ax.set_title(sentence[i])
        img = ax.imshow(temp_image)
        # Stretch the coarse attention grid over the full image extent.
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())
    # Hide the unused trailing cells of the grid.
    for ax in axes_raveled[n_words:]:
        ax.set_visible(False)

    fig.tight_layout()

In [None]:
# Caption one randomly chosen validation image and show it next to the
# reference caption stored at the same index.
rid = np.random.randint(len(val_imgpaths))
random_image_path = val_imgpaths[rid]
real_caption = val_captions[rid]

result, attention_plot = evaluate(random_image_path)

print('Real Caption:', real_caption)
print('Prediction Caption:', ' '.join(result))

# One subplot per predicted word, with its attention map over the image.
plot_attention(random_image_path, result, attention_plot)