In [3]:
# needs Graphviz installed in OS
!pip install -q pydot

In [4]:
%env TF_CPP}_MIN_LOG_LEVEL=3

env: TF_CPP}_MIN_LOG_LEVEL=3


In [5]:
import os
import pickle
import re

import matplotlib.pyplot as plt
import nltk
import numpy as np
import tensorflow as tf
from nltk.translate.bleu_score import corpus_bleu
from PIL import Image
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input
from tensorflow.keras.layers import Input, Dense, LSTM, Embedding, Dropout, add
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.preprocessing.text import Tokenizer
from tensorflow.keras.utils import to_categorical, plot_model
from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings('ignore')

# **Loading Dataset**

In [None]:
# Colab only: mount Google Drive at /content/drive so the dataset archive stored there can be read
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
#Unzip dataset !unzip <path>
!unzip /content/drive/MyDrive/Colab\ Notebooks/archive.zip

# **Extracting Image Features**

In [None]:
# VGG16 built with its default arguments; the captured output below shows the weights being downloaded
vgg_model = VGG16()

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels.h5
[1m553467096/553467096[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m25s[0m 0us/step


In [None]:
# Remove classification layers: reuse VGG16's inputs but stop at the output of its second-to-last layer.
# NOTE(review): the name `model` is rebound to the captioning model further down.
model = Model(inputs = vgg_model.inputs, outputs = vgg_model.layers[-2].output)

In [None]:
# image id -> feature array predicted for that image (filled by the extraction loop below)
features = {}

In [None]:
# Root directory: holds Images/ and captions.txt, and receives features.pkl and model.h5
base_path = '/content'

In [None]:
# Directory whose files are run through the feature extractor
path = os.path.join(base_path, 'Images')

In [None]:
%%time

for img in tqdm(os.listdir(path)):
  img_path = os.path.join(path, img)

  image = load_img(img_path, target_size = (224, 224))

  image = img_to_array(image)

  # Create Batch dimension
  image = image.reshape((1, image.shape[0], image.shape[1], image.shape[2]))

  image = preprocess_input(image)

  feature = model.predict(image, verbose = 0)

  image_id = img.split('.')[0]

  features[image_id] = feature

  0%|          | 0/8091 [00:00<?, ?it/s]

CPU times: user 10min 56s, sys: 15 s, total: 11min 11s
Wall time: 12min 12s


In [None]:
# Save features (the context manager closes the file even if pickling fails part-way)
with open(os.path.join(base_path, 'features.pkl'), 'wb') as f:
  pickle.dump(features, f)

In [None]:
# Load features back from the pickle written by the previous cell.
# NOTE: pickle.load can execute arbitrary code; only load a features.pkl you produced yourself.
with open(os.path.join(base_path, 'features.pkl'), 'rb') as f:
  features = pickle.load(f)

# **Data Preprocessing**

In [None]:
# Read captions.txt into one string, skipping its first line (presumably a header row — confirm)
with open(os.path.join(base_path, 'captions.txt'), 'r') as f:
  next(f)
  captions_doc = f.read()

In [None]:
# image id -> list of captions for that image (filled by the parsing loop below)
mapping = {}

In [None]:
for line in tqdm(captions_doc.split('\n')):
  tokens = line.split(',')

  # Only lines with at least two characters carry a record
  if len(line) >= 2:
    # First field is the image file name; keep the part before its first dot as the id
    image_id = tokens[0].split('.')[0]

    # Remaining fields form the caption (commas inside it become spaces)
    caption = ' '.join(tokens[1:])

    # Add caption in Image Id Dict, creating the list on first sight of the id
    mapping.setdefault(image_id, []).append(caption)

  0%|          | 0/40456 [00:00<?, ?it/s]

In [None]:
def clean_caption(mapping):
  """Normalise every caption in `mapping` in place and wrap it in start/end markers.

  Each caption is lower-cased, stripped of every character that is not a
  letter or whitespace, whitespace-collapsed, and filtered of one-character
  words. The result is wrapped as '[CLS] <words> [SEP]', with the markers
  separated from the words by single spaces so that a whitespace split and
  the tokenizer see the same number of tokens.

  Args:
    mapping: dict of image id -> list of caption strings. The lists are
      modified in place.

  Returns:
    The same `mapping` object.
  """
  for captions in mapping.values():
    for i, caption in enumerate(captions):
      caption = caption.lower()

      # Remove special characters. str.replace treats its argument literally,
      # so a regex substitution is required; whitespace is kept as the word separator.
      caption = re.sub(r'[^a-z\s]', '', caption)

      # Collapse runs of whitespace into one space
      caption = re.sub(r'\s+', ' ', caption)

      words = [word for word in caption.split() if len(word) > 1]

      captions[i] = ' '.join(['[CLS]', *words, '[SEP]'])

  return mapping

In [None]:
# clean_caption edits the caption lists in place, so re-running this cell wraps the markers a second time
mapping = clean_caption(mapping)

In [None]:
# Flatten the per-image caption lists into one list, in mapping order
all_captions = [caption for captions in mapping.values() for caption in captions]

In [None]:
# Tokenizer with default settings; its vocabulary is built from every caption in all_captions
tokenizer = Tokenizer()

tokenizer.fit_on_texts(all_captions)

In [None]:
# One more than the number of known words, leaving room for index 0 (used as padding)
vocab_size = len(tokenizer.word_index) + 1

# Longest caption measured in tokenizer tokens, the unit data_generator pads to.
# Counting whitespace-separated words can differ from this, because the tokenizer
# applies its own filtering and splitting (e.g. around the bracketed markers).
max_length = max(len(token_ids) for token_ids in tokenizer.texts_to_sequences(all_captions))

In [None]:
# Split data in Train and Test: the first 90% of image ids train, the remainder tests
image_ids = list(mapping)

split = int(.90 * len(image_ids))

train_data, test_data = image_ids[:split], image_ids[split:]

In [None]:
# Prepare data to training
def data_generator(data_keys, mapping, features, tokenizer, max_length, vocab_size, batch_size):
  """Endlessly yield training batches built from `batch_size` images at a time.

  For every caption of every image, each proper prefix of the token sequence
  becomes one sample: (image feature, post-padded prefix) -> one-hot next token.

  Args:
    data_keys: iterable of image ids to cycle over.
    mapping: image id -> list of captions.
    features: image id -> feature array; element [0] is used per sample.
    tokenizer: fitted tokenizer exposing texts_to_sequences.
    max_length: length every prefix is padded to.
    vocab_size: number of classes of the one-hot target.
    batch_size: number of images (not samples) accumulated per yielded batch.

  Yields:
    ((image_features, padded_prefixes), one_hot_targets) as numpy arrays.
  """
  image_feats, text_inputs, targets = [], [], []
  images_seen = 0

  while True:
    for key in data_keys:
      images_seen += 1

      for caption in mapping[key]:
        # Transform Text in Tokens
        token_ids = tokenizer.texts_to_sequences([caption])[0]

        # One (prefix, next token) pair per position after the first token
        for cut in range(1, len(token_ids)):
          prefix = pad_sequences([token_ids[:cut]], maxlen = max_length, padding = 'post')[0]
          next_token = to_categorical([token_ids[cut]], num_classes = vocab_size)[0]

          image_feats.append(features[key][0])
          text_inputs.append(prefix)
          targets.append(next_token)

      # Keep accumulating until batch_size images have been consumed
      if images_seen != batch_size:
        continue

      # Return a Batch of data
      yield (np.array(image_feats), np.array(text_inputs)), np.array(targets)

      # Restore the lists for next batch
      image_feats, text_inputs, targets = [], [], []
      images_seen = 0

# **Building the Model**

In [None]:
# Layer responsible for pixels
# Takes one 4096-value feature vector per sample, applies dropout, then projects it to 256 units
input_1 = Input(shape = (4096,))

drop_1 = Dropout(0.45)(input_1)

dense_1 = Dense(256, activation = 'relu')(drop_1)

In [None]:
# Layer responsible for texts
# Takes max_length token ids per sample
input_2 = Input(shape = (max_length,))

# 256-dimensional embedding; mask_zero = True treats id 0 as padding
seq = Embedding(vocab_size, 256, mask_zero = True)(input_2)

drop_2 = Dropout(0.45)(seq)

lstm = LSTM(256)(drop_2)

# Merge the image and text branches by element-wise addition (both are 256 units wide)
decoder1 = add([dense_1, lstm])

decoder2 = Dense(256, activation = 'relu')(decoder1)

# One probability per vocabulary entry for the next token
outputs = Dense(vocab_size, activation = 'softmax')(decoder2)

In [None]:
# Captioning model: [image features, token ids] -> next-token distribution.
# NOTE(review): this rebinds `model`, which earlier referred to the VGG16 feature extractor.
model = Model(inputs = [input_1, input_2], outputs = outputs)

In [None]:
# Categorical cross-entropy pairs with the one-hot targets produced by data_generator
model.compile(loss = 'categorical_crossentropy', optimizer = 'adam')

In [None]:
# Draw the model graph with tensor shapes on each layer
plot_model(model, show_shapes = True)

# **Training**

In [None]:
# Training configuration
epochs = 35
# Number of images (not samples) the generator consumes per yielded batch
batch_size = 32
# Whole batches per pass over the training ids; a trailing partial batch is not counted
steps = len(train_data) // batch_size

In [None]:
%%time

# Train one epoch at a time, starting each epoch from a fresh generator
for i in range(epochs):
  generator = data_generator(train_data,
                             mapping,
                             features,
                             tokenizer,
                             max_length,
                             vocab_size,
                             batch_size)

  # epochs = 1 because the outer loop does the epoch counting
  model.fit(generator, epochs = 1, steps_per_epoch = steps, verbose = 1)

In [None]:
# Persist the trained captioning model next to the other artifacts in base_path
model.save(os.path.join(base_path, 'model.h5'))

# **Evaluating Model**

In [None]:
def idx_to_word(integer, tokenizer):
  """Return the vocabulary word whose index is `integer`, or None if there is none.

  Args:
    integer: token index to look up.
    tokenizer: fitted tokenizer. Its `index_word` reverse map is used when
      present; otherwise the reverse map is derived from `word_index`.

  Returns:
    The matching word, or None when the index is not in the vocabulary.
  """
  # Dictionary lookup instead of scanning the whole vocabulary on every call
  index_word = getattr(tokenizer, 'index_word', None)

  if index_word is None:
    index_word = {index: word for word, index in tokenizer.word_index.items()}

  return index_word.get(integer)

In [None]:
def predict_caption(model, image, tokenizer, max_length):
  """Greedily decode a caption for one image.

  Args:
    model: captioning model called as model.predict([image, padded_ids]).
    image: feature array for the image, passed to the model unchanged.
    tokenizer: fitted tokenizer used for encoding the text so far and for
      mapping the predicted index back to a word.
    max_length: padded sequence length and maximum number of decoding steps.

  Returns:
    The decoded text. It starts with '[CLS]' and ends with ' [SEP]' when the
    model predicted the end marker before max_length steps were used.
  """
  in_text = '[CLS]'

  # Ids the tokenizer itself assigns to the end marker. Comparing the decoded
  # word with the literal '[SEP]' fails whenever the tokenizer normalises its
  # vocabulary (lower-casing, punctuation filtering), so the stop test is by id.
  end_ids = set(tokenizer.texts_to_sequences(['[SEP]'])[0])

  for _ in range(max_length):
    # Transform Text in Tokens
    sequence = tokenizer.texts_to_sequences([in_text])[0]

    # Fill Sequence with Pad Tokens on the same side as data_generator does for training
    sequence = pad_sequences([sequence], maxlen = max_length, padding = 'post')

    # Predict Next Word and take the index with the highest probability
    y_pred = model.predict([image, sequence], verbose = 0)
    y_pred = int(np.argmax(y_pred))

    # End marker predicted: close the caption with the literal marker and stop
    if y_pred in end_ids:
      in_text += ' [SEP]'
      break

    # Get Word; an index outside the vocabulary (e.g. padding) ends decoding
    word = idx_to_word(y_pred, tokenizer)

    if word is None:
      break

    # Append Word
    in_text += ' ' + word

  return in_text

In [None]:
gt, predicted = [], []

%%time

for key in tqdm(test_data):
  captions = mapping[key]

  y_pred = predict_caption(model, features[key], tokenizer, max_length)

  gt_captions = [caption.split() for caption in captions]

  y_pred = y_pred.split()

  gt.append(gt_captions)

  predicted.append(y_pred)

In [None]:
# Score regarding only unigrams (all weight on 1-grams)
print(f'BLEU-1: {corpus_bleu(gt, predicted, weights = (1.0, 0, 0, 0)):.4f}')

# Score regarding bi-grams (weight shared equally between 1-grams and 2-grams)
print(f'BLEU-2: {corpus_bleu(gt, predicted, weights = (0.5, 0.5, 0, 0)):.4f}')

# **Deploy**

In [None]:
def generate_caption(image_path, features, tokenizer, max_length):
  """Caption an image whose features were already extracted.

  Args:
    image_path: path to the image file; its file name (up to the first dot)
      is the key looked up in `features`.
    features: image id -> feature array.
    tokenizer: fitted tokenizer passed on to predict_caption.
    max_length: padded sequence length passed on to predict_caption.

  Returns:
    (caption, image): the caption with the '[CLS]'/'[SEP]' markers removed,
    and the opened PIL image.

  Raises:
    KeyError: if no features are stored for the image.

  Note:
    Uses the notebook-level `model` for prediction.
  """
  image = Image.open(image_path)

  # Derive the id from the path itself instead of relying on a global left over
  # from an earlier loop; same convention as the feature-extraction loop.
  image_id = os.path.basename(image_path).split('.')[0]

  y_pred = predict_caption(model, features[image_id], tokenizer, max_length)

  y_pred = y_pred.replace('[CLS]', '').replace('[SEP]', '').strip()

  return y_pred, image

In [None]:
# Caption one image from the Images directory, print the text and show the picture
image_name = '1001773457_577c3a7d70.jpg'
image_path = os.path.join(base_path, 'Images', image_name)
result, image = generate_caption(image_path, features, tokenizer, max_length)

print(result)
plt.imshow(image)