In [2]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import pickle
import os

from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.preprocessing import image as keras_image
from tensorflow.keras.models import Model


In [3]:
# Restore the tokenizer that was pickled alongside the trained model.
# NOTE: pickle.load can execute arbitrary code, so only point this at a
# tokenizer file you produced yourself or otherwise trust.
with open('../data/tokenizer.pkl', 'rb') as f:
    tokenizer = pickle.load(f)

# Hyperparameters (same as training)
embedding_dim = 256
units = 512

# Size of the embedding / output vocabulary: every entry in the word index
# plus one extra slot (presumably index 0 reserved for padding -- confirm
# against the training notebook).
vocab_size = len(tokenizer.word_index) + 1

# Upper bound on the number of decoding steps; use value from training.
max_length = 50



In [4]:
class BahdanauAttention(tf.keras.Model):
    """Additive attention over encoded image features.

    Scores are computed as ``V(tanh(W1(features) + W2(hidden)))`` and
    normalised with a softmax along axis 1.
    """

    def __init__(self, units):
        """Create the three dense projections used to score the features.

        Args:
            units: Width of the ``W1`` / ``W2`` projections.
        """
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, features, hidden):
        """Compute the context vector and the attention weights.

        Args:
            features: Encoded image features.
                # assumes (batch, locations, embedding_dim) -- TODO confirm
            hidden: Decoder state; an axis is inserted at position 1 so it
                broadcasts against the projected features.

        Returns:
            ``(context_vector, attention_weights)``: the attention-weighted
            features summed over axis 1, and the softmax weights themselves.
        """
        # Insert an axis at position 1 so the state broadcasts over it.
        query = tf.expand_dims(hidden, 1)

        energy = tf.nn.tanh(self.W1(features) + self.W2(query))
        attention_weights = tf.nn.softmax(self.V(energy), axis=1)

        # Weight the features and collapse axis 1 into a single vector.
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)
        return context_vector, attention_weights

class CNN_Encoder(tf.keras.Model):
    """Projects pre-extracted image features with a single dense layer."""

    def __init__(self, embedding_dim):
        """Create the projection layer.

        Args:
            embedding_dim: Output width of the dense projection.
        """
        super().__init__()
        self.fc = tf.keras.layers.Dense(embedding_dim)

    def call(self, x):
        """Apply the dense projection to ``x`` and return the result."""
        projected = self.fc(x)
        return projected

class RNN_Decoder(tf.keras.Model):
    """Attention-based LSTM decoder that predicts one token per call."""

    def __init__(self, embedding_dim, units, vocab_size):
        """Build the embedding, LSTM, output projections and attention.

        Args:
            embedding_dim: Width of the token embeddings.
            units: Width of the LSTM, of ``fc1`` and of the attention layers.
            vocab_size: Number of embedding rows and of output logits.
        """
        super().__init__()
        self.units = units
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.lstm = tf.keras.layers.LSTM(
            units,
            return_sequences=True,
            return_state=True,
            recurrent_initializer='glorot_uniform',
        )
        self.fc1 = tf.keras.layers.Dense(units)
        self.fc2 = tf.keras.layers.Dense(vocab_size)
        self.attention = BahdanauAttention(units)

    def call(self, x, features, hidden):
        """Run one decoding step.

        Args:
            x: Token ids fed to the embedding layer.
            features: Encoded image features handed to the attention module.
            hidden: State used as the attention query.

        Returns:
            ``(logits, state, attention_weights)``: the output of ``fc2``,
            the first state tensor returned by the LSTM, and the attention
            weights for this step.
        """
        context_vector, attention_weights = self.attention(features, hidden)

        # Prepend the context vector to the embedded tokens on the last axis.
        embedded = self.embedding(x)
        lstm_input = tf.concat(
            [tf.expand_dims(context_vector, 1), embedded], axis=-1
        )

        # NOTE(review): the LSTM is invoked without ``initial_state``, so
        # ``hidden`` only influences the attention query and the recurrent
        # state is not carried between calls -- confirm this matches training.
        output, state, _ = self.lstm(lstm_input)

        # Flatten the leading axes so fc2 sees one row per position.
        projected = self.fc1(output)
        flattened = tf.reshape(projected, (-1, projected.shape[2]))
        logits = self.fc2(flattened)
        return logits, state, attention_weights

    def reset_state(self, batch_size):
        """Return an all-zero state of shape ``(batch_size, units)``."""
        zero_state = tf.zeros((batch_size, self.units))
        return zero_state


In [5]:
# Build the encoder/decoder with the same hyperparameters as training.
encoder = CNN_Encoder(embedding_dim=embedding_dim)
decoder = RNN_Decoder(
    embedding_dim=embedding_dim,
    units=units,
    vocab_size=vocab_size,
)

# NOTE: the weight-loading calls below are commented out, so as written both
# models run with freshly initialized weights. Uncomment them once the
# checkpoint files exist.
# encoder.load_weights('../checkpoints/best_model/encoder.h5')
# decoder.load_weights('../checkpoints/best_model/decoder.h5')


In [6]:
def preprocess_image(img_path):
    """Load an image file and prepare it as a single-image InceptionV3 batch.

    Args:
        img_path: Path of the image file to load.

    Returns:
        The image resized to 299x299, converted to an array, given a leading
        batch axis and passed through InceptionV3's ``preprocess_input``.
    """
    pil_img = keras_image.load_img(img_path, target_size=(299, 299))
    pixels = keras_image.img_to_array(pil_img)
    # Add the leading batch axis expected by the model.
    batch = pixels[np.newaxis, ...]
    return preprocess_input(batch)

def _get_image_model():
    """Return the shared InceptionV3 feature extractor, building it on first use."""
    model = getattr(_get_image_model, '_model', None)
    if model is None:
        model = InceptionV3(include_top=False, weights='imagenet', pooling='avg')
        # Cache on the function object so later calls reuse the loaded weights.
        _get_image_model._model = model
    return model


def extract_image_features(img_path):
    """Extract InceptionV3 features for the image at ``img_path``.

    The InceptionV3 network (ImageNet weights, no classification head,
    average pooling) is constructed once and reused across calls instead of
    being rebuilt for every image.

    Args:
        img_path: Path of the image file to featurize.

    Returns:
        The output of ``predict`` on the preprocessed single-image batch.
    """
    img_tensor = preprocess_image(img_path)
    features = _get_image_model().predict(img_tensor)
    return features


In [7]:
def generate_caption(img_path):
    """Greedily decode a caption for the image at ``img_path``.

    Uses the module-level ``encoder``, ``decoder``, ``tokenizer`` and
    ``max_length``. Decoding starts from the ``'<start>'`` token and stops
    when ``'<end>'`` is predicted or after ``max_length`` steps.

    Args:
        img_path: Path of the image file to caption.

    Returns:
        ``(caption, attention_plot)``: the predicted words joined by single
        spaces, and an array holding one flattened attention-weight row per
        word in the caption (row ``i`` belongs to word ``i``).
    """
    features = extract_image_features(img_path)
    features = encoder(features)

    hidden = decoder.reset_state(batch_size=1)
    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)

    result = []
    attention_plot = []

    for _ in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input, features, hidden)

        predicted_id = tf.argmax(predictions[0]).numpy()
        predicted_word = tokenizer.index_word.get(predicted_id, '')

        if predicted_word == '<end>':
            break

        # Record the word and its attention together so rows stay aligned
        # with the words. Ids without a vocabulary entry map to '' and are
        # not emitted: an empty word would vanish when the caption is split
        # on whitespace while its attention row would remain.
        if predicted_word:
            result.append(predicted_word)
            attention_plot.append(attention_weights.numpy().reshape(-1))

        # The predicted id is fed back as the next decoder input either way.
        dec_input = tf.expand_dims([predicted_id], 0)

    return ' '.join(result), np.array(attention_plot)


In [8]:
def plot_attention(img_path, attention_plot, result):
    """Draw the image once per caption word, titled with that word.

    When the attention row for a word can be arranged as a square map larger
    than 1x1, it is overlaid on the image; otherwise (for example a single
    scalar weight per word) only the image is shown.

    Args:
        img_path: Path of the image file to display.
        attention_plot: Sequence of per-word attention-weight rows; may be
            shorter than ``result``, in which case the extra words get no
            overlay.
        result: List of caption words, one subplot per word.
    """
    # Read the pixels inside a context manager so the file handle is released.
    with Image.open(img_path) as img:
        temp_img = np.array(img)

    fig = plt.figure(figsize=(15, 15))

    len_result = len(result)
    # Two panels per row; ceil-divide so even word counts do not leave an
    # empty trailing row.
    n_rows = max(1, (len_result + 1) // 2)

    for idx, word in enumerate(result):
        ax = fig.add_subplot(n_rows, 2, idx + 1)
        ax.set_title(word)
        shown = ax.imshow(temp_img)

        if idx < len(attention_plot):
            weights = np.asarray(attention_plot[idx]).reshape(-1)
            side = int(round(np.sqrt(weights.size)))
            # Only a square grid with more than one cell carries spatial
            # information worth overlaying.
            if side > 1 and side * side == weights.size:
                ax.imshow(weights.reshape(side, side), cmap='gray', alpha=0.6,
                          extent=shown.get_extent())

        ax.axis('off')

    plt.tight_layout()
    plt.show()


In [None]:
# Image to caption.
test_image_path = '../data/val2017/000000565153.jpg'  # Change this path

# Decode a caption, then draw one panel per generated word.
caption, attention = generate_caption(img_path=test_image_path)

plot_attention(
    img_path=test_image_path,
    attention_plot=attention,
    result=caption.split(),
)
