In [1]:
!apt-get install aria2 > /dev/null

In [2]:
!aria2c -x 16 -s 16 https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip > /dev/null 
!aria2c -x 16 -s 16 https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip > /dev/null

In [3]:
!unzip /content/Flickr8k_Dataset.zip > /dev/null
!unzip /content/Flickr8k_text.zip > /dev/null
!rm *.zip > /dev/null

In [4]:
import os
import re
import gc
import numpy as np
import collections
from PIL import Image
from textwrap import wrap
import matplotlib.pyplot as plt
from wordcloud import WordCloud
from tqdm import tqdm_notebook as tqdm

from keras.utils import to_categorical
from keras.preprocessing.text import Tokenizer
from keras_preprocessing.sequence import pad_sequences
from keras_preprocessing.image import load_img, img_to_array

import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.utils import plot_model
from tensorflow.keras.applications.xception import Xception
from tensorflow.keras.layers import Input, Dropout, Dense, Embedding, LSTM, add

In [5]:
def clean_description(desc, stopwords):
  """Normalise a raw caption into space-separated lowercase words.

  After lower-casing, every character outside a-z becomes a separator. Words
  that appear in `stopwords` or are a single letter long are dropped.

  Args:
    desc: raw caption text.
    stopwords: collection of words to remove.

  Returns:
    The surviving words joined by single spaces ('' if none survive).
  """
  letters_only = re.sub('[^a-z]', ' ', desc.lower())
  kept = [word for word in letters_only.split() if word not in stopwords and len(word) > 1]
  return ' '.join(kept)

In [6]:
def get_vocabulary(dictionary):
  """Collect every distinct word used across all caption lists in `dictionary`.

  Captions are split on single spaces, so consecutive spaces would contribute
  an empty-string entry to the result.

  Args:
    dictionary: mapping whose values are lists of caption strings.

  Returns:
    A set of words.
  """
  return {
      word
      for captions in dictionary.values()
      for caption in captions
      for word in caption.split(' ')
  }

In [7]:
with open('/content/Flickr8k.token.txt', 'r') as f:
  all_desc = f.read().split('\n')

In [None]:
# Some sample data
all_desc[:5]

In [9]:
stopwords = ['is', 'an', 'a', 'the', 'was']

In [10]:
all_dict = dict()

for desc in all_desc:
  if len(desc) < 1:
    continue
  file_name, file_desc = desc.split('\t')[0].split('.')[0], desc.split('\t')[1]
  
  if file_name not in all_dict.keys():
    all_dict[file_name] = []

  cleaned_desc = clean_description(file_desc, stopwords)
  cleaned_desc = 'startseq ' + cleaned_desc + ' endseq'

  all_dict[file_name].append(cleaned_desc)

In [11]:
vocab = get_vocabulary(all_dict)

In [None]:
print('Total images:', len(all_dict))
print('Total vocabulary without stopwords:', len(vocab))

In [None]:
!wget https://logos.flamingtext.com/Name-Logos/Bri-design-china-name.png > /dev/null

In [14]:
mask = np.array(Image.open('Bri-design-china-name.png'))

wordcloud = WordCloud(width = 500, height = 400, 
                  background_color ='black', 
                  min_font_size = 2,
                  mask=mask, random_state=0).generate(' '.join(vocab)) 

In [None]:
plt.figure(figsize = (10, 10), facecolor = 'k', edgecolor = 'k' ) 
plt.imshow(wordcloud) 
plt.axis("off") 
plt.tight_layout(pad = 0)

plt.show()

In [16]:
all_sent_list = [item.strip('startseq').strip('endseq').strip(' ') for sublist in list(all_dict.values()) for item in sublist]

In [17]:
all_sent_len = [len(sent) for sent in all_sent_list]

In [None]:
plt.hist([len(sentence.split()) for sentence in all_sent_list])
plt.xlabel('Number of words')
plt.ylabel('Number of sentenes')
plt.title('Count of number of words in sentences')
plt.show()

In [None]:
avg_sent_len = int(np.mean([len(sentence.split()) for sentence in all_sent_list]))
avg_sent_len

In [20]:
words = [w for a in all_sent_list for w in a.split(' ')]

In [None]:
counts = collections.Counter(words)
most_common = counts.most_common()
most_common[:15]

In [22]:
keys = [tupl[0] for tupl in most_common][:15]
values = [tupl[1] for tupl in most_common][:15]

In [None]:
plt.figure(figsize=(8, 8))
plt.bar(keys, values)
plt.xlabel('Words')
plt.ylabel('Frequency')
plt.title('Frequency of most common words')
plt.show()

In [None]:
# Distinct caption-list sizes across all images.
lengths = {len(cap_list) for cap_list in all_dict.values()}

print('Number of captions for each image: ', lengths)

In [None]:
fig = plt.figure()
fig.suptitle('Sample images', fontsize=20)

zoom = 3
w, h = fig.get_size_inches()
fig.set_size_inches(w * zoom, h * zoom)
fig.tight_layout()

for i in range(1, 26):
  ax = fig.add_subplot(5, 5, i)
  ax.imshow(plt.imread('Flicker8k_Dataset/'+list(all_dict.keys())[i-1]+'.jpg'))
  plt.axis('off')

plt.show()

In [None]:
fig = plt.figure()
fig.suptitle('Sample images with one of the captions', fontsize=20)

zoom = 2
w, h = fig.get_size_inches()
fig.set_size_inches(w * zoom, h * zoom)

for i in range(1, 5):
  ax = fig.add_subplot(2, 2, i)
  ax.imshow(plt.imread('Flicker8k_Dataset/'+list(all_dict.keys())[i-1]+'.jpg'))
  title = ax.set_title('\n'.join(wrap(all_dict.get(list(all_dict.keys())[i-1])[0], 60)))
  fig.tight_layout(h_pad=2)
  title.set_y(1.05)
  plt.axis('off')
  fig.subplots_adjust(top=0.85, hspace=0.3)

plt.show()

In [None]:
xcep = Xception(include_top=False, pooling='avg')

In [None]:
xcep.summary()

In [None]:
predictions = dict()

for dirpath, dirname, files in os.walk('Flicker8k_Dataset'):
  for filename in tqdm(files):
    img_path = os.path.join(dirpath, filename)
    if os.path.isfile(img_path):
      img = Image.open(img_path)
      img = img.resize((299,299))
      img = np.expand_dims(img, axis=0)
      img = img/127.5
      img = img - 1.0

      predictions[filename.split('.')[0]] = xcep.predict(img)

In [None]:
print('Number of extracted features:', len(predictions.get(list(predictions.keys())[0])[0]))

In [None]:
predictions.get(list(predictions.keys())[0])

In [32]:
def create_list(dictionary):
  """Flatten a mapping of id -> list of captions into a single list.

  Dictionary order and the order within each list are preserved.
  """
  return [caption for captions in dictionary.values() for caption in captions]

In [33]:
def fit_tokenizer(dictionary):
  """Build a Keras Tokenizer fitted on every caption in `dictionary`.

  Args:
    dictionary: mapping whose values are lists of caption strings.

  Returns:
    The fitted Tokenizer.
  """
  all_captions = create_list(dictionary)
  fitted = Tokenizer()
  fitted.fit_on_texts(all_captions)
  return fitted

In [34]:
def convert_to_input(tokens, pos, im_name, max_len, vocab_len, tokenizer, img_predictions):
  """Build one training example for predicting the token at index `pos`.

  Args:
    tokens: integer token ids of one caption.
    pos: index of the target token; tokens[:pos] is the input prefix.
    im_name: key into `img_predictions`.
    max_len: length the prefix is padded or truncated to.
    vocab_len: number of classes for the one-hot target.
    tokenizer: accepted for interface compatibility; not used here.
    img_predictions: mapping of image key -> feature array; element [0] is returned.

  Returns:
    (image features, padded prefix, boolean one-hot of tokens[pos]).
  """
  padded_prefix = pad_sequences(sequences=[tokens[:pos]], maxlen=max_len)[0]
  one_hot_target = to_categorical(y=[tokens[pos]], num_classes=vocab_len, dtype='bool')[0]
  image_features = img_predictions.get(im_name)[0]
  return image_features, padded_prefix, one_hot_target

In [35]:
def convert_all_to_input(dictionary, max_len, vocab_len, tokenizer, img_predictions):
  """Expand every caption into next-word training examples.

  Only images that have an entry in `img_predictions` are used. A caption of
  n tokens yields n-1 examples: for each position 1..n-1 the prefix before it
  is the input and the token at it is the target.

  Returns:
    Three NumPy arrays: image features, padded prefixes and one-hot targets.
  """
  image_feats, seq_inputs, targets = [], [], []

  for im_name, descriptions in tqdm(dictionary.items()):
    if im_name not in img_predictions:
      continue
    for desc in descriptions:
      tokens = tokenizer.texts_to_sequences([desc])[0]
      for pos in range(1, len(tokens)):
        feat, seq, target = convert_to_input(tokens, pos, im_name, max_len, vocab_len, tokenizer, img_predictions)
        image_feats.append(feat)
        seq_inputs.append(seq)
        targets.append(target)

  return np.array(image_feats), np.array(seq_inputs), np.array(targets)

In [36]:
tokenizer = fit_tokenizer(all_dict)

In [None]:
vocab_len = len(tokenizer.index_word) + 1
vocab_len

In [None]:
max_len = len(max(create_list(all_dict)))
max_len

In [None]:
cnn_len = predictions[list(predictions.keys())[0]].shape[1]
cnn_len

In [None]:
X_1, X_2, y = convert_all_to_input(all_dict, max_len, vocab_len, tokenizer, predictions)

In [41]:
def shuffle_arrays(arrays, set_seed=-1):
    """Shuffle several equal-length arrays in place with one shared permutation.

    Parameters:
    -----------
    arrays : List of NumPy arrays, all with the same length along axis 0.
    set_seed : Seed value if int >= 0; otherwise a random seed is drawn.
    """
    assert all(len(a) == len(arrays[0]) for a in arrays)
    if set_seed < 0:
        seed = np.random.randint(0, 2**(32 - 1) - 1)
    else:
        seed = set_seed

    for a in arrays:
        # Each array gets a fresh generator with the same seed, so every
        # array receives the identical permutation and rows stay aligned.
        np.random.RandomState(seed).shuffle(a)

In [42]:
# Fixed seed so the row order, and therefore which rows land in each training
# chunk, is reproducible across runs.
SHUFFLE_SEED = 42
shuffle_arrays([X_1, X_2, y], set_seed=SHUFFLE_SEED)

In [43]:
def create_model(cnn_len, max_len, vocab_len):
  """Build the merge-style captioning model.

  One input is an image feature vector of size `cnn_len`; the other is a
  padded token prefix of length `max_len`. Both are projected to 256 units,
  summed, and decoded into a softmax over `vocab_len` words.
  """
  # Image-feature branch.
  image_input = Input(shape=(cnn_len,))
  image_dropped = Dropout(0.5)(image_input)
  image_features = Dense(units=256, activation='relu')(image_dropped)

  # Caption-prefix branch; mask_zero ignores padding id 0.
  caption_input = Input(shape=(max_len,))
  embedded = Embedding(vocab_len, 256, mask_zero=True)(caption_input)
  embedded_dropped = Dropout(0.5)(embedded)
  caption_features = LSTM(256)(embedded_dropped)

  # Merge both branches and predict the next word.
  merged = add([image_features, caption_features])
  hidden = Dense(units=256, activation='relu')(merged)
  next_word = Dense(units=vocab_len, activation='softmax')(hidden)

  return Model(inputs=[image_input, caption_input], outputs=next_word)

In [44]:
model = create_model(cnn_len, max_len, vocab_len)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3), 
            loss=tf.keras.losses.categorical_crossentropy, 
            metrics=[tf.keras.metrics.categorical_accuracy])

In [None]:
model.summary()

In [None]:
# Feed Keras CHUNK_SIZE rows at a time to avoid memory overflow.
# Epochs are the OUTER loop, so each epoch makes one pass over the whole
# dataset. The previous layout ran all epochs on one chunk before moving to the
# next. That over-fit each chunk in turn and biased the weights toward the
# chunks seen last.
EPOCHS = 20
CHUNK_SIZE = 4096
FIT_BATCH_SIZE = 256

for epoch in tqdm(range(EPOCHS)):
  chunk_losses = []
  for start in range(0, len(X_1), CHUNK_SIZE):
    stop = start + CHUNK_SIZE
    history = model.fit(x=[X_1[start:stop], X_2[start:stop]], y=y[start:stop],
                        epochs=1, batch_size=FIT_BATCH_SIZE, verbose=0)
    chunk_losses.append(history.history['loss'][-1])
  print('Epoch %d/%d - mean chunk loss: %.4f' % (epoch + 1, EPOCHS, np.mean(chunk_losses)))

### Inference

In [47]:
def word_for_id(integer, tokenizer):
  """Return the word whose tokenizer index is `integer`, or None if no word has it.

  Uses the tokenizer's reverse map `index_word` for an O(1) lookup instead of
  scanning `word_index` (O(vocab)) on every call. generate_desc calls this once
  per generated word. Objects without `index_word` fall back to the linear
  scan.
  """
  index_word = getattr(tokenizer, 'index_word', None)
  if index_word is None:
    return next((word for word, index in tokenizer.word_index.items() if index == integer), None)
  return index_word.get(integer)

In [48]:
def generate_desc(model, tokenizer, photo, max_len):
  """Greedily decode a caption for one image.

  Starting from 'startseq', the caption so far is tokenised, padded to
  `max_len`, and fed with `photo` to the model. The argmax word is appended.
  Decoding stops after 'endseq' is appended, when the predicted index maps to
  no word, or after `max_len` steps.

  Returns:
    The caption string, including the leading 'startseq' marker.
  """
  caption = 'startseq'

  for _step in range(max_len):
    token_ids = tokenizer.texts_to_sequences([caption])[0]
    padded = pad_sequences([token_ids], maxlen=max_len)
    next_id = np.argmax(model.predict([photo, padded], verbose=0))
    next_word = word_for_id(next_id, tokenizer)

    if next_word is None:
      break
    caption = caption + ' ' + next_word
    if next_word == 'endseq':
      break

  return caption

In [None]:
!wget https://yt3.ggpht.com/9Eiy_C42-Cqiv2-PkhQ0DuElWcZEW6KOivNLDriocMBRe297UMcXwbOwSZedGATnb8IBnkCE0A=s900-c-k-c0x00ffffff-no-rj > /dev/null

In [None]:


img = Image.open('Bri-design-china-name.png')
img2 = img.copy()
img = img.resize((299,299))
img = np.expand_dims(img, axis=0)
img = img/127.5
img = img - 1.0
pred = xcep.predict(img)

plt.figure(figsize=(10, 10))
plt.imshow(img2)
plt.axis('off')
plt.show()

caption = generate_desc(model, tokenizer, pred, vocab_len)
caption = caption.strip('startseq').strip('endseq')
print(caption)