In [None]:
import numpy as np
import string
import glob
import os
import warnings
import tensorflow as tf
warnings.filterwarnings('ignore')
from tensorflow.keras.applications.inception_v3 import InceptionV3, preprocess_input
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, add, RepeatVector, Reshape, concatenate
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.text import Tokenizer 
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tqdm import tqdm
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
os.environ['TQDM_DISABLE'] = '1'

In [None]:
# Mapping image IDs to their descriptions

def load_doc(filename, encoding='utf-8'):
  """Read a whole text file and return its contents as a single string.

  Args:
    filename: path of the file to read.
    encoding: text encoding of the file. Defaults to UTF-8 so the result does
      not depend on the platform's locale encoding.

  Returns:
    The full file contents as one string.
  """
  with open(filename, 'r', encoding=encoding) as f:
    return f.read()

filename = 'Flickr8k.token.txt'
doc = load_doc(filename)

# Translation table that deletes every ASCII punctuation character.
_PUNCT_TABLE = str.maketrans('', '', string.punctuation)

desc = {}
for line in doc.strip().split('\n'):
  tokens = line.split('\t')
  if len(tokens) < 2:
    # Skip blank or malformed lines instead of failing with an IndexError.
    continue
  img_id, img_desc = tokens[0], tokens[1]
  # Drop everything from '#' on (presumably the per-image caption index).
  img_id = img_id.split('#')[0]
  img_desc = img_desc.lower().translate(_PUNCT_TABLE)
  # Collapse whitespace runs left behind where punctuation was removed.
  img_desc = ' '.join(img_desc.split())
  # The separating spaces make 'startseq' / 'endseq' standalone tokens; without
  # them the markers fuse with the first/last word and never enter the vocabulary.
  desc.setdefault(img_id, []).append('startseq ' + img_desc + ' endseq')

print(f"Loaded {len(desc)} image descriptions.")

In [None]:
# Extracting features from images using InceptionV3 pre-trained model

# Image encoder: InceptionV3 with its final layer removed; the penultimate
# layer's output is used as the image embedding.
encoder = InceptionV3(weights='imagenet')
encoder = Model(encoder.input, encoder.layers[-2].output)

features = {}
# sorted() makes the insertion order of `features` (and therefore anything that
# indexes list(features)) independent of the filesystem's listing order.
images = sorted(glob.glob('Flicker8k_Dataset/*.jpg'))

for img_path in tqdm(images, disable=False):
  img_id = os.path.basename(img_path).split('.')[0]
  img = image.load_img(img_path, target_size=(299, 299))
  img = image.img_to_array(img)
  img = np.expand_dims(img, axis=0)  # add a batch axis of size 1
  img = preprocess_input(img)
  # The stored array keeps the leading batch axis; later cells index [0] or pass it
  # straight to predict().
  features[img_id] = encoder.predict(img, verbose=0)

print(f"Extracted features for {len(features)} images.")

In [None]:
# Preparing the tokenizer and vocabulary size

# Flatten every caption of every image into one list for the tokenizer.
descs = []
for captions in desc.values():
  descs.extend(captions)

tokenizer = Tokenizer()
tokenizer.fit_on_texts(descs)
# +1 because Tokenizer word ids start at 1; id 0 is what pad_sequences pads with.
vocab_size = len(tokenizer.word_index) + 1
# Longest stored caption, counted in whitespace-separated words.
maxLen = max(len(d.split()) for d in descs)
print(f"Vocabulary Size: {vocab_size}")
print(f"Max Caption Length: {maxLen}")

In [None]:
# Function to create a TensorFlow dataset from the image descriptions and features

def create_tf_dataset(desc, features, tokenizer, maxLen, vocab_size, batch_size, shuffle_buffer=0):
    """Build an endless, batched tf.data pipeline of (image, partial caption) -> next word.

    A caption that tokenizes to ids [w0, w1, ..., wk] expands into k training
    pairs: input (image feature, w0..w{i-1} left-padded to ``maxLen``) and
    target one-hot(w_i), for i = 1..k.

    Args:
        desc: mapping image id -> list of caption strings.
        features: mapping image id -> feature array; ``features[key][0]`` is the
            vector used, so each value is expected to carry a leading axis of size 1.
        tokenizer: fitted tokenizer providing ``texts_to_sequences``.
        maxLen: length each input sequence is padded/truncated to.
        vocab_size: length of the one-hot target vector.
        batch_size: number of pairs per batch.
        shuffle_buffer: if > 0, pairs pass through a shuffle buffer of this size
            before batching. The default 0 keeps the sequential order.

    Returns:
        A prefetched ``tf.data.Dataset`` that repeats forever; bound each epoch
        with ``steps_per_epoch``.

    Raises:
        ValueError: if ``features`` is empty or no (partial caption, next word)
            pair can be produced (the generator would otherwise loop forever).
    """
    if not features:
        raise ValueError("features is empty; cannot infer the image feature size")

    # Tokenize every caption once up front rather than on every pass of the
    # infinite loop below.
    encoded = [
        (features[key][0], tokenizer.texts_to_sequences(captions))
        for key, captions in desc.items()
        if key in features
    ]
    if not any(len(seq) > 1 for _, seqs in encoded for seq in seqs):
        raise ValueError("desc yields no training pairs for the images in features")

    def data_generator():
        # Repeats forever; model.fit limits each epoch via steps_per_epoch.
        while True:
            for feature, seqs in encoded:
                for seq in seqs:
                    for i in range(1, len(seq)):
                        # Input: first i ids, left-padded with 0; target: id i.
                        in_seq = pad_sequences([seq[:i]], maxlen=maxLen)[0]
                        out_seq_onehot = to_categorical(seq[i], num_classes=vocab_size)
                        yield (feature, in_seq), out_seq_onehot

    feature_dim = next(iter(features.values()))[0].shape[0]
    output_signature = (
        (
            tf.TensorSpec(shape=(feature_dim,), dtype=tf.float32),
            tf.TensorSpec(shape=(maxLen,), dtype=tf.int32)
        ),
        tf.TensorSpec(shape=(vocab_size,), dtype=tf.float32)
    )
    dataset = tf.data.Dataset.from_generator(data_generator, output_signature=output_signature)
    if shuffle_buffer > 0:
        # Consecutive pairs come from the same caption/image; shuffling decorrelates batches.
        dataset = dataset.shuffle(shuffle_buffer)
    return dataset.batch(batch_size).prefetch(tf.data.AUTOTUNE)

In [None]:
# Architecture loosely follows "Show and Tell: A Neural Image Caption Generator" -> https://arxiv.org/pdf/1411.4555.pdf
# The projected image vector is fed to the LSTM as the first step of the sequence (as in the paper) and is
# additionally added to the LSTM output before the decoder head (not part of the paper's model).

EMBED_DIM = 256  # shared width of the image projection, word embeddings and LSTM state
# Length of one image feature vector, read from the extracted features instead of hard-coded.
feature_dim = next(iter(features.values())).shape[-1]

inputs1 = Input(shape=(feature_dim,))
feature1 = Dense(EMBED_DIM, activation='relu')(inputs1)
# (batch, EMBED_DIM) -> (batch, 1, EMBED_DIM) so it can be prepended to the word embeddings.
feature1_reshaped = Reshape((1, EMBED_DIM))(feature1)

inputs2 = Input(shape=(maxLen,))
# mask_zero=False: padding id 0 is embedded and fed to the LSTM like any other token.
emb1 = Embedding(vocab_size, EMBED_DIM, mask_zero=False)(inputs2)
merged = concatenate([feature1_reshaped, emb1], axis=1)
emb2 = LSTM(EMBED_DIM)(merged)
emb3 = Dropout(0.5)(emb2)
# Skip connection from the image projection into the decoder head.
combined = add([emb3, feature1])
x = Dense(EMBED_DIM, activation='relu')(combined)
x = Dropout(0.5)(x)
outputs = Dense(vocab_size, activation='softmax')(x)

model = Model(inputs=[inputs1, inputs2], outputs=outputs)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()

In [None]:
# Training hyper-parameters: pairs per batch and maximum number of passes.
batch_size, epochs = 32, 30

def normalize_desc_keys(desc):
    """Return a new dict keyed like `desc` but with every '.jpg' substring removed.

    The caption lists are shared with `desc`, not copied. If two keys collapse
    to the same normalized key, the one seen last wins.
    """
    return {key.replace('.jpg', ''): captions for key, captions in desc.items()}

desc_normalized = normalize_desc_keys(desc)

# Keep only captions whose image has extracted features. Iterating the dict
# (rather than a set of keys, whose order is hash-randomized per process) keeps
# the training stream identical across kernel restarts.
filtered_desc = {key: captions for key, captions in desc_normalized.items() if key in features}
common_keys = set(filtered_desc)

def calculate_steps(desc, batch_size, text_tokenizer=None):
    """Number of full batches in one pass over all (partial caption -> next word) pairs.

    A caption that tokenizes to n ids yields n - 1 training pairs in
    ``create_tf_dataset`` (and none when n < 2), so the count here matches
    what the generator actually produces.

    Args:
        desc: mapping image id -> list of caption strings.
        batch_size: pairs per batch.
        text_tokenizer: object with ``texts_to_sequences``. Defaults to the
            notebook-level ``tokenizer`` for backward compatibility.

    Returns:
        ``total_pairs // batch_size`` (a trailing partial batch is not counted).
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer  # notebook-level tokenizer fitted earlier
    captions = [caption for value in desc.values() for caption in value]
    # One batched call instead of one tokenizer call per caption.
    sequences = text_tokenizer.texts_to_sequences(captions)
    total_seq = sum(max(0, len(seq) - 1) for seq in sequences)
    return total_seq // batch_size

steps = calculate_steps(filtered_desc, batch_size)
print(f"Total Steps per Epoch: {steps}")
print(f"Number of images: {len(filtered_desc)}")
print(f"Total sequences: {steps * batch_size}")

# fit() below receives no validation data, so every callback must watch the
# training loss. ReduceLROnPlateau defaults to monitor='val_loss', which would
# never be available, so the learning rate would never be reduced.
callbacks = [
    ModelCheckpoint('model.h5', save_best_only=True, monitor='loss'),
    EarlyStopping(patience=5, monitor='loss', restore_best_weights=True),
    ReduceLROnPlateau(monitor='loss', factor=0.5, patience=3, min_lr=1e-7)
]

dataset = create_tf_dataset(filtered_desc, features, tokenizer, maxLen, vocab_size, batch_size)

# The dataset repeats forever, so steps_per_epoch defines what an "epoch" is.
history = model.fit(
    dataset,
    epochs=epochs,
    steps_per_epoch=steps,
    callbacks=callbacks,
    verbose=1
)

In [None]:
def generate_desc(model, tokenizer, photo, max_length):
    """Greedily decode a caption for one image.

    Args:
        model: captioning model called as ``model.predict([photo, padded_ids])``
            that returns scores over the vocabulary.
        tokenizer: fitted tokenizer whose ``word_index`` maps word -> id.
        photo: image feature input passed unchanged to ``model.predict``.
        max_length: maximum number of decoding steps; also the padded input length.

    Returns:
        The decoded text, beginning with 'startseq' and ending with 'endseq'
        if that token was produced before ``max_length`` steps.
    """
    # Invert the vocabulary once instead of scanning word_index at every step.
    index_to_word = {index: w for w, index in tokenizer.word_index.items()}
    in_text = 'startseq'
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)
        yhat = model.predict([photo, sequence], verbose=0)
        word = index_to_word.get(int(np.argmax(yhat)))
        if word is None:
            # The argmax id has no word (e.g. padding id 0): stop decoding.
            break
        in_text += ' ' + word
        if word == 'endseq':
            break
    return in_text

SAMPLE_INDEX = 19  # which extracted image to caption as a quick smoke test

keys = list(features.keys())
sample_key = keys[SAMPLE_INDEX]
image_file_path = 'Flicker8k_Dataset/' + sample_key + '.jpg'
img = mpimg.imread(image_file_path)
plt.imshow(img)
plt.axis('off')
plt.show()
photo = features[sample_key]
print(generate_desc(model, tokenizer, photo, maxLen))


In [None]:
# Training-loss curve. fit() received no validation data, so only 'loss' is
# recorded; a validation curve is drawn only if 'val_loss' exists.
fig, ax = plt.subplots(figsize=(20, 8))
ax.plot(history.history['loss'], label='train')
if 'val_loss' in history.history:
    ax.plot(history.history['val_loss'], label='val')
ax.set_title('model loss')
ax.set_ylabel('loss')
ax.set_xlabel('epoch')
ax.legend(loc='upper left')
plt.show()

In [None]:
def idx_to_word(integer, tokenizer):
    """Look up the word whose id in ``tokenizer.word_index`` equals ``integer``.

    Returns the first matching word in ``word_index`` iteration order, or None
    when no word has that id.
    """
    matches = (word for word, index in tokenizer.word_index.items() if index == integer)
    return next(matches, None)

def predict_caption(model, image, tokenizer, max_length, features):
    """Greedily decode a caption for the image whose id is ``image``.

    Args:
        model: captioning model called as ``model.predict([feature, padded_ids])``.
        image: key into ``features`` (an image id, not pixel data).
        tokenizer: fitted tokenizer used to encode the running caption.
        max_length: maximum number of decoding steps; also the padded input length.
        features: mapping image id -> feature input for ``model.predict``.

    Returns:
        The decoded text, beginning with "startseq".

    Raises:
        KeyError: if ``image`` is not a key of ``features``.
    """
    feature = features[image]
    in_text = "startseq"
    for _ in range(max_length):
        sequence = tokenizer.texts_to_sequences([in_text])[0]
        sequence = pad_sequences([sequence], maxlen=max_length)

        # verbose=0: otherwise Keras prints a progress bar for every generated word.
        y_pred = model.predict([feature, sequence], verbose=0)
        word = idx_to_word(np.argmax(y_pred), tokenizer)

        if word is None:
            break

        in_text += " " + word

        if word == 'endseq':
            break

    return in_text

In [None]:
samples = [
    'Flicker8k_Dataset/1000268201_693b08cb0e.jpg',
    'Flicker8k_Dataset/1001773457_577c3a7d70.jpg',
    'Flicker8k_Dataset/1002674143_1b742ab4b8.jpg',
]

# For each sample: show the picture, then print the caption decoded for it.
for sample in samples:
    img_id = os.path.basename(sample).split('.')[0]

    img = mpimg.imread(sample)
    plt.imshow(img)
    plt.axis('off')
    plt.show()

    caption = predict_caption(model, img_id, tokenizer, maxLen, features)
    print(f"Predicted Caption: {caption}")
    


In [None]:
# Sanity check: Euclidean norm of each sample image's extracted feature array.
sample_ids = [os.path.basename(path).split('.')[0] for path in samples]
for img_id in sample_ids:
    feature_norm = np.linalg.norm(features[img_id])
    print(f"{img_id} feature norm: {feature_norm}")
