In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import json
import tensorflow as tf
import numpy as np
from models import *
import pickle

gpu = tf.config.experimental.list_physical_devices('GPU')[0]
tf.config.experimental.set_memory_growth(gpu, True)

In [3]:
with open('annotations/captions_train2014.json', 'r') as f:
    train_instances = json.loads(f.read())

with open('annotations/captions_val2014.json', 'r') as f:
    val_instances = json.loads(f.read())

train_filenames = {
    image['id']: image['file_name']+'.npy'
    for image in train_instances['images']
}

with open('annotations/captions_val2014.json') as f:
    val_images = json.loads(f.read())['images']

val_filenames = {
    image['id']: image['file_name']+'.npy'
    for image in val_instances['images']
}

In [4]:
train_captions = {
    annotation['image_id']: '<start> ' + annotation['caption'] + ' <end>'
    for annotation in train_instances['annotations']
}
    
val_captions = {
    annotation['image_id']: '<start> ' + annotation['caption'] + ' <end>'
    for annotation in val_instances['annotations']
}

In [5]:
tokenizer = tf.keras.preprocessing.text.Tokenizer(
    num_words = 5000,
    oov_token = '<unk>',
    filters='!"#$%&()*+.,-/:;=?@[\]^_`{|}~ '
)
tokenizer.fit_on_texts(train_captions.values())

In [6]:
tokenizer.word_index['<pad>'] = 0
tokenizer.index_word[0] = '<pad>'

pkl = pickle.dumps(tokenizer)
with open('tokenizer.pickle', 'wb') as f:
    f.write(pkl)

In [7]:
train_seqs = tokenizer.texts_to_sequences(train_captions.values())
train_seqs = {k: seq for k, seq in zip(train_captions.keys(), train_seqs)}

val_seqs = tokenizer.texts_to_sequences(val_captions.values())
val_seqs = {k: seq for k, seq in zip(val_captions.keys(), val_seqs)}

In [8]:
max_len = max(len(x) for x in train_seqs.values())

In [9]:
def map_func(img_id, dataset='train'):
    filenames = {
        'train': train_filenames,
        'val': val_filenames
    }[dataset]
    captions = {
        'train': train_seqs,
        'val': val_seqs
    }[dataset]
    img_id = img_id.numpy()
    img = np.load('image_embeddings/' + filenames[img_id])
    caption = captions[img_id]
    return img, np.array(caption, dtype=np.int32)

In [10]:
from functools import partial
SHUFFLE_BUFFER_SIZE = 1000
BATCH_SIZE = 256

train_ds = tf.data.Dataset.from_tensor_slices(list(train_filenames.keys()))
val_ds = tf.data.Dataset.from_tensor_slices(list(val_filenames.keys()))

train_ds = train_ds.map(lambda x: tf.py_function(map_func, [x], [tf.float16, tf.int32]), num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.map(lambda x: tf.py_function(partial(map_func, dataset='val'), [x], [tf.float16, tf.int32]), num_parallel_calls=tf.data.AUTOTUNE)

train_ds = train_ds.shuffle(SHUFFLE_BUFFER_SIZE).padded_batch(BATCH_SIZE, padded_shapes=([64, 2048], [None]))
val_ds = val_ds.padded_batch(BATCH_SIZE, padded_shapes=([64, 2048], [None]))

def shift_tokens(x, y):
    # return [x, y[:, :-1]], y[1:]
    return (x, y[:, :-1]), y[:, 1:]

train_ds = train_ds.map(shift_tokens, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)
val_ds = val_ds.map(shift_tokens, num_parallel_calls=tf.data.AUTOTUNE).prefetch(tf.data.AUTOTUNE)

In [11]:
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')
def loss(y_true, y_pred):
    mask = tf.math.logical_not(tf.math.equal(y_true, 0))
    loss_ = loss_fn(y_true, y_pred)
    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask
    return tf.reduce_mean(loss_)

In [12]:
# model = ImageCaptioning(attention_hidden_units=1024, hidden_state_size=512, embedding_dim=512, vocab_size=5000, input_shape=(64, 2048))
model = ImageCaptioning(attention_hidden_units=128, hidden_state_size=128, embedding_dim=1024, vocab_size=5000, input_shape=(64, 2048))
# initial_context = tf.zeros(shape=(64, 512))

# for x, y in train_ds:
#     print('x:')
#     for row in x[1]:
#         print([f'{col}' for col in row])
#     print('target: ')
#     for row in y:
#         print([f'{col}' for col in row])
#     print(y)
#     pred = model(x)
#     print(pred)
#     break
model.compile(loss=loss, optimizer='adam')

In [13]:
ckpt = tf.keras.callbacks.ModelCheckpoint(
    'models/weights.{epoch:02d}-{val_loss:.4f}.hdf5', monitor='val_loss', verbose=0, save_best_only=False, mode='min'
)

model.fit(train_ds, validation_data=val_ds, epochs=1000, callbacks=[ckpt])

Epoch 1/1000
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000

KeyboardInterrupt: 

In [20]:
list(val_filenames.keys())[0]

391895

In [22]:
val_seqs[391895]

KeyError: 391895