### **Load data from github repo**

### **Imports**

In [1]:
from numpy import array
from collections import defaultdict
from pickle import load, dump
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical, plot_model
from keras.models import Model, load_model
from keras.layers import *
from keras.layers.merge import add, Concatenate
from keras.callbacks import ModelCheckpoint
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img, img_to_array
from keras.applications.vgg16 import preprocess_input
from nltk.translate.bleu_score import corpus_bleu
from sklearn.model_selection import train_test_split

import random
import numpy as np
import tensorflow as tf

### **Specify dataset and load additional data**

In [2]:
# Selects which dataset the notebook works on. The value picks the download
# branch in the next cell and is used as the filename prefix for the
# `<dataset>_*.txt` / `<dataset>_features.dat` files read later.
dataset = 'coco' # 'coco' or 'flickr8k'

In [None]:
# Fetch the pre-computed data files for the selected dataset into the
# current working directory.
if dataset == 'coco':
  # The COCO files live in a Google Cloud Storage bucket, so the Colab user
  # is authenticated before copying.
  from google.colab import auth
  auth.authenticate_user()
  # NOTE(review): the recorded output of this cell lists
  # `coco_features_ids.txt`, while read_dataset() opens
  # `<dataset>_identifiers.txt` — confirm the file names line up.
  !gsutil cp gs://tpu330/* ./
else:
  # Any other value falls back to the GitHub repository; the previous clone
  # is removed first so every run starts from a fresh checkout.
  !rm -rf LitImageCaptions
  !git clone -q https://github.com/AndrewB330/LitImageCaptions
  !unzip -o -q LitImageCaptions/features.zip
  # NOTE(review): load_clean_descriptions() opens
  # `<dataset>_descriptions.txt`, but this copies to plain
  # `descriptions.txt` — confirm this branch still works as written.
  !cp LitImageCaptions/descriptions.txt descriptions.txt

Copying gs://tpu330/coco_descriptions.txt...
Copying gs://tpu330/coco_features.dat...
Copying gs://tpu330/coco_features_ids.txt...
Omitting prefix "gs://tpu330/downloads/". (Did you mean to do cp -r?)

Operation completed over 3 objects/1.3 GiB.                                      


### **Functions for reading photo features and labels**

In [3]:
def load_clean_descriptions(dataset, identifiers, tokenizer=None):
  """Read `<dataset>_descriptions.txt` and group the captions by image id.

  Every line is split on whitespace; the first token is the image id and the
  remaining tokens are the caption words. Lines with fewer than 2 or more
  than 16 tokens (id included) are skipped.

  Args:
    dataset: filename prefix; the file read is f'{dataset}_descriptions.txt'.
    identifiers: iterable of image ids. It fixes the order and length of the
      returned list.
    tokenizer: optional fitted tokenizer. When given, each caption is passed
      through `texts_to_sequences` / `sequences_to_texts`, and the whole
      caption is dropped if any resulting word has `word_counts` <= 10.

  Returns:
    A list with one entry per identifier. Each entry is the list of captions
    for that id, wrapped as 'startseq ... endseq'. Ids with no surviving
    caption get an empty list.
  """
  # Read everything up front so the file handle is closed deterministically.
  with open(f'{dataset}_descriptions.txt', 'r') as source:
    lines = source.read().split('\n')

  descriptions = defaultdict(list)
  for line in lines:
    tokens = line.split()
    if len(tokens) < 2 or len(tokens) > 16:
      continue
    image_id, words = tokens[0], tokens[1:]
    if tokenizer is not None:
      # Round-trip through the tokenizer so the rarity check below looks at
      # the words exactly as the tokenizer reports them.
      encoded = tokenizer.texts_to_sequences([' '.join(words)])
      words = tokenizer.sequences_to_texts(encoded)[0].split()
      # A single rare word disqualifies the entire caption.
      if any(tokenizer.word_counts[word] <= 10 for word in words):
        continue
    descriptions[image_id].append('startseq ' + ' '.join(words) + ' endseq')
  # Looking up a missing id on the defaultdict yields an empty list.
  return [descriptions[image_id] for image_id in identifiers]

def create_tokenizer(descriptions):
  """Build a Tokenizer fitted on every caption in `descriptions`.

  `descriptions` is an iterable of per-image caption lists; the lists are
  flattened into one list of texts before fitting.
  """
  all_captions = []
  for captions in descriptions:
    all_captions.extend(captions)
  fitted = Tokenizer()
  fitted.fit_on_texts(all_captions)
  return fitted

def max_length_compute(descriptions):
  """Return the word count of the longest caption over all images.

  `descriptions` is an iterable of per-image caption lists. Raises
  ValueError when there is no caption at all.
  """
  word_counts = (
      len(caption.split())
      for captions in descriptions
      for caption in captions
  )
  return max(word_counts)

def read_dataset(dataset, test_size=2000):
  """Load features and cleaned captions for `dataset` and split them.

  Reads `<dataset>_identifiers.txt` (one identifier per line) and
  `<dataset>_features.dat` (loaded with `np.load`), then loads the captions
  with `load_clean_descriptions` so that entry i of the captions belongs to
  row i of the features.

  Args:
    dataset: filename prefix of the data files.
    test_size: forwarded to `train_test_split`.

  Returns:
    The result of `train_test_split(features, descriptions, ...)`, i.e.
    train features, test features, train descriptions, test descriptions.
    The split uses random_state=0, so it is reproducible.
  """
  with open(f'{dataset}_identifiers.txt', 'r') as source:
    id_lines = source.read().split('\n')
  # Keep only the part before the first '.', which drops a file extension.
  identifiers = [line.split('.')[0] for line in id_lines]

  # Passing the path lets numpy open and close the file itself instead of
  # leaving a dangling handle.
  features = np.load(f'{dataset}_features.dat')

  descriptions = load_clean_descriptions(dataset, identifiers)
  # Two filtering passes: after the first pass removes captions with rare
  # words, the tokenizer is re-fit on what is left and the filter is applied
  # once more with the updated word counts.
  for _ in range(2):
    descriptions = load_clean_descriptions(
        dataset, identifiers, create_tokenizer(descriptions))

  assert len(features) == len(descriptions), (
      f'{len(features)} feature rows but {len(descriptions)} identifiers')
  return train_test_split(features, descriptions,
                          test_size=test_size, random_state=0)

### **Data generator**

In [4]:
def create_sequences_gen(tokenizer, max_length, descriptions, 
                         features, vocab_size, *args, 
                         batch_size=2048, infinite=False):
  """Generate training batches of (image feature, partial caption) -> next word.

  Every caption is encoded with `tokenizer` and expanded into one sample per
  prefix: for an encoded caption `seq`, sample i uses `seq[:i]` (padded via
  `pad_sequences(..., maxlen=max_length)`) as input and `seq[i]` (one-hot
  over `vocab_size` classes) as target.

  Args:
    tokenizer: fitted tokenizer providing `texts_to_sequences`.
    max_length: padded length of the input sequences.
    descriptions: per-image caption lists, aligned with `features`.
    features: per-image feature vectors (objects with `.reshape`), aligned
      with `descriptions`.
    vocab_size: number of classes of the one-hot target.
    *args: extra positional arguments are accepted and ignored; this also
      makes `batch_size` and `infinite` keyword-only.
    batch_size: number of samples per yielded batch.
    infinite: when True, restart from the first image after each full pass
      instead of stopping; samples left over from one pass are carried into
      the next.

  Yields:
    `([image_batch, sequence_batch], target_batch)`, each stacked with
    `np.vstack`. In finite mode the last batch may be smaller than
    `batch_size`.
  """
  X1, X2, y = [], [], []
  while True:
    produced = False
    # No copy of `features` is taken: rows are only read, and np.vstack
    # builds a fresh array for every batch, so duplicating the whole feature
    # matrix on each pass would only cost memory.
    for feature, captions in zip(features, descriptions):
      # The same feature row is paired with every prefix of every caption.
      feature_row = feature.reshape(1, -1)
      for caption in captions:
        seq = tokenizer.texts_to_sequences([caption])[0]
        for i in range(1, len(seq)):
          in_seq = pad_sequences([seq[:i]], maxlen=max_length)[0]
          out_seq = to_categorical([seq[i]], num_classes=vocab_size)[0]
          X1.append(feature_row)
          X2.append(in_seq.reshape(1, -1))
          y.append(out_seq.reshape(1, -1))
          produced = True
          if len(X1) >= batch_size:
            yield [np.vstack(X1), np.vstack(X2)], np.vstack(y)
            X1, X2, y = [], [], []
    # Stop after one pass in finite mode. Also stop when a pass produced no
    # sample at all; otherwise infinite mode would spin forever on empty data
    # without ever yielding.
    if not infinite or not produced:
      break
  if len(X1) > 0:
    yield [np.vstack(X1), np.vstack(X2)], np.vstack(y)

### **Read train and test data, initialize tokenizer**

In [5]:
# Load the cached features/captions and split them into train and test sets.
(train_features, test_features,
 train_descriptions, test_descriptions) = read_dataset(dataset)

print(f'Train dataset: {len(train_features)}')
print(f'Test dataset: {len(test_features)}')

# The vocabulary comes from the training captions only.
tokenizer = create_tokenizer(train_descriptions)
# One more than the number of known words; presumably index 0 is left for
# padding (the Embedding layer is built with mask_zero=True).
vocab_size = 1 + len(tokenizer.word_index)
print('Vocabulary Size: %d' % vocab_size)

# Longest training caption, in words, including startseq/endseq.
max_length = max_length_compute(train_descriptions)
print('Description Length: %d' % max_length)

Train dataset: 79920
Test dataset: 2000
Vocabulary Size: 5364
Description Length: 19


# **Model definition**

In [6]:
def define_model(vocab_size, max_length, feature_dim=4096):
  """Build and compile the captioning model.

  The model has two inputs: an image feature vector and a padded word-index
  sequence. The feature vector goes through a Dense layer, the sequence
  through an Embedding and an LSTM; both results are concatenated and fed
  to two Dense layers followed by a softmax over the vocabulary.

  Args:
    vocab_size: size of the Embedding input and of the softmax output.
    max_length: length of the padded input sequences.
    feature_dim: width of the image feature vector. Defaults to 4096, the
      value that was previously hard-coded.

  Returns:
    The compiled model (categorical cross-entropy loss, Adam optimizer).

  Side effects:
    Writes a diagram of the model to 'model.png' via `plot_model`.
  """
  # photo feature branch
  inputs1 = Input(shape=(feature_dim,))
  fe1 = Dense(512, activation='relu')(inputs1)
  fe1 = Dropout(0.2)(fe1)
  # sequence model; mask_zero=True makes index 0 act as padding
  inputs2 = Input(shape=(max_length,))
  se1 = Embedding(vocab_size, 256, mask_zero=True)(inputs2)
  se2 = Dropout(0.5)(se1)
  se3 = LSTM(256)(se2)
  # decoder model
  decoder1 = Concatenate()([fe1, se3])
  decoder2 = Dense(512, activation='relu')(decoder1)
  decoder2 = Dropout(0.2)(decoder2)
  decoder2 = Dense(512, activation='relu')(decoder2)
  outputs = Dense(vocab_size, activation='softmax')(decoder2)
  # tie it together [image, seq] [word]
  model = Model(inputs=[inputs1, inputs2], outputs=outputs)
  model.compile(loss='categorical_crossentropy', optimizer='adam')
  plot_model(model, to_file='model.png', show_shapes=True)
  return model

In [7]:
model = define_model(vocab_size, max_length)
# define_model() has just written the architecture diagram to model.png.
# NOTE(review): this is not the last expression of the cell, so the loaded
# image is discarded rather than rendered — move it to the end of a cell (or
# display it explicitly) if the diagram is meant to be shown.
load_img('model.png')

version = 'v5'
# The {val_loss:.3f} placeholder is left for ModelCheckpoint to fill in when
# it writes a checkpoint file.
filepath = version + '_model-val_loss{val_loss:.3f}.h5'
# Save the tokenizer alongside the checkpoints. The context manager makes
# sure the pickle is flushed and the handle closed right away.
with open(version + '_tokenizer.pickle', 'wb') as tokenizer_file:
  dump(tokenizer, tokenizer_file)
# Only keep checkpoints that improve (lower) the validation loss.
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', 
                             verbose=1, save_best_only=True, mode='min')

# **Model training**

In [None]:
# A single endless stream of training batches, shared by all rounds below so
# each round continues where the previous one stopped.
gen = create_sequences_gen(
    tokenizer, max_length, train_descriptions, train_features,
    vocab_size, infinite=True)

for j in range(20):
  # The validation generator is finite, so a fresh one is built every round.
  gen_validation = create_sequences_gen(
      tokenizer, max_length, test_descriptions, test_features, vocab_size)
  # Each round is one "epoch" of 512 training batches followed by validation.
  model.fit(
      gen,
      epochs=1,
      verbose=1,
      steps_per_epoch=512,
      callbacks=[checkpoint],
      validation_data=gen_validation,
  )

 82/512 [===>..........................] - ETA: 1:04 - loss: 5.8632