# Image Caption  
I'll use pretrained coco dataset to train this image caption model  
the dataset is in [http://www.cs.toronto.edu/~vendrov/order/coco.zip](http://www.cs.toronto.edu/~vendrov/order/coco.zip)

In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt

%matplotlib inline

## Load pretrained vgg-19 image data   
The data is in 4096 dimension

In [2]:
path = './data/coco/images/10crop'
train_data = np.load( os.path.join(path, 'train.npy'))
val_data = np.load( os.path.join(path, 'val.npy'))
test_data = np.load( os.path.join(path, 'test.npy'))

In [3]:
print('train data shape: {}'.format(train_data.shape))
print('validation data shape: {}'.format(val_data.shape))
print('test data shape: {}'.format(test_data.shape))

train data shape: (113287, 4096)
validation data shape: (5000, 4096)
test data shape: (5000, 4096)


## Load the caption

In [4]:
def read_text(file_path):
    content = []
    max_length = 0
    with open(file_path, 'rb') as f:
        for line in f:
            content.append(line.strip())
            if len(line.strip().split()) > max_length:
                max_length = len(line.strip().split())
    return np.array(content), max_length

In [5]:
text_path = './data/coco'
train_cap, train_max_length = read_text(os.path.join(text_path, 'train.txt'))
val_cap, val_max_length = read_text(os.path.join(text_path, 'val.txt'))
test_cap, test_max_length = read_text(os.path.join(text_path, 'test.txt'))

In [6]:
print('train caption shape: {0}, max length: {1}'.format(train_cap.shape, train_max_length))
print('validation caption shape: {0}, max length: {1}'.format(val_cap.shape, val_max_length))
print('test caption shape: {0}, max length: {1}'.format(test_cap.shape, test_max_length))

train caption shape: (566435,), max length: 49
validation caption shape: (25000,), max length: 47
test caption shape: (25000,), max length: 43


In [7]:
train_cap[:10]

array([b'a woman wearing a net on her head cutting a cake',
       b'a woman cutting a large white sheet cake',
       b'a woman wearing a hair net cutting a large sheet cake',
       b'there is a woman that is cutting a white cake',
       b'a woman marking a cake with the back of a chefs knife',
       b'a young boy standing in front of a computer keyboard',
       b'a little boy wearing headphones and looking at a computer monitor',
       b'he is listening intently to the computer at school',
       b'a young boy stares up at the computer monitor',
       b'a young kid with head phones on using a computer'],
      dtype='|S246')

each image related to 5 captions

## Prepare caption

In [9]:
from collections import Counter
def get_vocab_int(text):
    text = text.lower()
    vocab = sorted(set(text.split()))
    vocab_counter = Counter(vocab)

    vocab = ['<PAD>','<EOS>','<UNK>','<GO>'] + vocab
    vocab_to_int = {word: index for index, word in enumerate(vocab)}
    int_to_vocab = {index: word for word, index in vocab_to_int.items()}
    return vocab_to_int, int_to_vocab, vocab_counter

In [70]:
# the input should be a list of words of a sentence
# and the output is the corresponding int value with the length as the max length of all the sentence
# the output starts with the <GO> and ends with <EOS>, with <PAD> to fill the empty places
def get_cap_id(sentence, vocab_to_int, max_length):
    output = []
    for word in sentence:
        output.append(vocab_to_int.get(word, vocab_to_int['<UNK>']))
    output.insert(0, vocab_to_int['<GO>'])
    output.append(vocab_to_int['<EOS>'])
    output = output + [vocab_to_int['<PAD>']] * (max_length - len(sentence))
    return np.array(output)

In [11]:
# get all the text to generate the dictionary
total_text = ''
for sentence in train_cap:
    total_text += sentence.decode('utf-8')
for sentence in val_cap:
    total_text += sentence.decode('utf-8')
for sentence in test_cap:
    total_text += sentence.decode('utf-8')

In [12]:
vocab_to_int, int_to_vocab, vocab_counter = get_vocab_int(total_text)

## Build the generator

In [95]:
# img_data: preprocessed image data, has the shape [None, 4096]
# cap_data: the caption of corresponding image, one image to 5 captions, has the shape [image_num*5, ...]

def data_generator(img_data, cap_data, vocab_to_int, max_length, batch_size = 32):
    partial_caps = []
    next_words = []
    images = []

    count = 0
    
    for img_index, image in enumerate(img_data):
        for cap_index in range(5):
            current_cap_list = cap_data[img_index + cap_index].decode('utf-8').split()
            for word_index in range( len(current_cap_list)-1 ):
                images.append(image)
                # for the max_length includes the <GO> and <EOS>, so when passing to get_cap_id, it should minus 2
                partial_caps.append(get_cap_id(current_cap_list[:word_index+1], vocab_to_int, max_length-2))
                
                next_word = np.zeros(len(vocab_to_int))
                next_word[vocab_to_int[current_cap_list[word_index+1]]] = 1
                next_words.append(next_word)
                count += 1
                
                if count >= batch_size:
                    yield [np.array(images), np.array(partial_caps)], np.array(next_words)
                    count = 0
                    partial_caps = []
                    next_words = []
                    images = []

## Build the model

In [34]:
import keras
from keras.layers import LSTM, Embedding, TimeDistributed, Dense, RepeatVector, concatenate, Input
from keras.models import Sequential, Model

### Some parameters
the embedding dimension of image and caption should be the same

In [97]:
img_input_dim = 4096
img_embed_dim = 4096
max_cap_length = max(train_max_length, val_max_length, test_max_length) + 2
vocab_size = len(vocab_to_int)
cap_embed_dim = 4096

### Image model

In [85]:
# image_input = Input(shape=(1, 4096), name='image_input')
img_input = Input(shape=(img_input_dim,), name='image_input')
img_dense = Dense(img_embed_dim, activation='relu')(img_input)
img_output = RepeatVector(max_cap_length)(img_dense)

### Caption model

In [86]:
# cap_input = Input(shape=(max_length,), dtype='int32', name='cap_input')
# embed = Embedding(output_dim=cap_embed_dim, input_dim=vocab_size, input_length=max_cap_length)(cap_input)
cap_input = Input(shape=(max_cap_length,), name='cap_input')
cap_embed = Embedding(vocab_size, 256)(cap_input)
cap_lstm = LSTM(256,return_sequences=True)(cap_embed)
cap_output = TimeDistributed(Dense(cap_embed_dim))(cap_lstm)

### concatenate the image and caption model

In [87]:
merge = concatenate([img_output, cap_output])
merge = LSTM(1000,return_sequences=False)(merge)
merge_output = Dense(vocab_size, activation='softmax')(merge)

model = Model(inputs=[img_input, cap_input], outputs=merge_output)

### compile

In [88]:
model.compile(loss='categorical_crossentropy', optimizer='rmsprop', metrics=['accuracy'])

In [89]:
model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
cap_input (InputLayer)           (None, 53)            0                                            
____________________________________________________________________________________________________
image_input (InputLayer)         (None, 4096)          0                                            
____________________________________________________________________________________________________
embedding_4 (Embedding)          (None, 53, 256)       23459328    cap_input[0][0]                  
____________________________________________________________________________________________________
dense_16 (Dense)                 (None, 4096)          16781312    image_input[0][0]                
___________________________________________________________________________________________

### Train the model

In [90]:
batch_size = 32
epoch = 1

In [96]:
model.fit_generator(data_generator(train_data, train_cap, vocab_to_int, max_cap_length, batch_size = batch_size), 
                    steps_per_epoch = train_cap.shape[0] // batch_size,
                    epochs = 1)

Epoch 1/1
    1/17701 [..............................] - ETA: 121389s - loss: 10.4502 - acc: 0.1875

KeyboardInterrupt: 

In [99]:
model.save('img_cap.h5')