In [1]:
import os
import time
import json
from glob import glob
from PIL import Image
import pickle
import tensorflow as tf

from libs.configs import cfgs
from data.dataset_pipeline import dataset_batch, read_from_pickle, tokenize, split_dataset, load_dataset
from libs.nets.model import CNNEencoder, RNNDecoder, BahdanauAttention

++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--++--
/home/alex/python_code/Image-Caption


In [2]:
# Locations of the 2017 training images (`train2017`) and their caption annotations.
train_image_path = os.path.join(cfgs.DATASET_PATH, 'train2017')
train_annotation_path = os.path.join(
    cfgs.DATASET_PATH, 'annotations', 'captions_train2017.json'
)

# Load a num_examples=50000 subset of (image, caption) pairs; how the subset is
# chosen is decided inside load_dataset.
train_images, train_captions = load_dataset(
    train_image_path, train_annotation_path, num_examples=50000
)
print(*map(len, (train_images, train_captions)))

# Batches per epoch over everything loaded above.
# NOTE(review): this counts the images *before* the train/val split in the next cell.
step_per_epoch = int(len(train_images) / cfgs.BATCH_SIZE)

118287 570281
50000 50000


In [3]:
# Turn the raw captions into token sequences, then split images/captions into
# training and validation parts according to cfgs.SPLIT_RATIO.
train_sequence = tokenize(train_captions)
(img_name_train, img_name_val,
 cap_train, cap_val) = split_dataset(train_images, train_sequence, split_ratio=cfgs.SPLIT_RATIO)
print(*(len(part) for part in (img_name_train, img_name_val, cap_train, cap_val)))

40000 10000 40000 10000


In [4]:
# Vocabulary lookups: word -> id is loaded from cfgs.WORD_INDEX, id -> word is its inverse.
# NOTE(review): read_from_pickle presumably unpickles this file; only load trusted files.
word_index = read_from_pickle(cfgs.WORD_INDEX)
index_word = dict(zip(word_index.values(), word_index.keys()))
vocab_size = len(word_index)

# Batched training pipeline over the training split, plus one sample batch for inspection.
train_dataset = dataset_batch(img_name_train, cap_train, batch_size=cfgs.BATCH_SIZE)
example_image_batch, example_cap_batch = next(iter(train_dataset))

In [5]:
# Build the image encoder and the caption decoder
# embedding_dim is shared by encoder and decoder; units is passed to RNNDecoder
# (presumably its recurrent state size — confirm in libs.nets.model).
embedding_dim, units = 256, 512

encoder = CNNEencoder(embedding_dim=embedding_dim)
decoder = RNNDecoder(embedding_dim, units, vocab_size)

In [6]:
# Adam with the configured learning rate.
optimizer = tf.keras.optimizers.Adam(learning_rate=cfgs.LEARNING_RATE)
# Per-example cross-entropy on logits; reduction='none' keeps one value per
# example so loss_function can zero out padded positions before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')


def loss_function(target, pred):
    """Masked cross-entropy for a single decoding step.

    :param target: integer token ids for this step; id 0 is treated as padding.
    :param pred: decoder logits for this step, fed to ``loss_object``.
    :return: scalar mean of the per-example loss after padded positions are
        zeroed (they still count in the denominator of the mean).
    """
    per_example = loss_object(target, pred)
    keep = tf.cast(tf.not_equal(target, 0), dtype=per_example.dtype)
    return tf.reduce_mean(per_example * keep)

In [7]:
# Track encoder, decoder and optimizer state together; keep the 5 newest checkpoints.
ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, directory=cfgs.TRAINED_CKPT, max_to_keep=5)

# --------------------------train start with latest checkpoint----------------------------
# Resume from the newest checkpoint when there is one, otherwise start at epoch 0.
# The number after the last '-' in the checkpoint path is used as the epoch to
# resume at; that is only right if the training loop numbers checkpoints by epoch.
latest_ckpt = ckpt_manager.latest_checkpoint
start_epoch = int(latest_ckpt.split('-')[-1]) if latest_ckpt else 0
if latest_ckpt:
    ckpt.restore(latest_ckpt)

In [9]:
# --------------------------------- train_step---------------------------------------------
@tf.function
def train_step(image_feature, target):
    """Run one teacher-forced optimisation step on a batch.

    :param image_feature: batch passed straight to ``encoder``; its first
        dimension is presumably the batch size (TODO confirm against dataset_batch).
    :param target: integer caption ids indexed as ``target[:, i]``, i.e.
        (batch, seq_len). Id 0 is padding (masked in ``loss_function``);
        column 0 is never predicted, presumably because it holds ``<start>``.
    :return: ``(loss, total_loss)`` — the summed per-step loss used for the
        gradients, and that sum divided by ``seq_len``.
    """
    loss = 0
    batch_size = target.shape[0]
    # fresh decoder hidden state for this batch
    hidden_states = decoder.reset_state(batch_size=batch_size)
    # every sequence starts decoding from the <start> token
    decoder_input = tf.expand_dims([word_index['<start>']] * batch_size, axis=1)

    with tf.GradientTape() as tape:
        feature = encoder(image_feature)

        for i in range(1, target.shape[1]):
            predictions, hidden_states, _ = decoder(x=decoder_input, feature=feature, hidden=hidden_states)
            loss += loss_function(target[:, i], predictions)
            # teacher forcing: feed the ground-truth token, not the prediction, as next input
            decoder_input = tf.expand_dims(target[:, i], axis=1)

    # NOTE(review): the loop accumulates seq_len - 1 step losses but divides by
    # seq_len; kept as-is so it matches the per-batch loss printed by the training loop.
    total_loss = loss / int(target.shape[1])

    # Optimise encoder AND decoder. The original listed the encoder's variables
    # twice, so the decoder weights were never updated.
    trainable_variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, trainable_variables)
    optimizer.apply_gradients(zip(gradients, trainable_variables))

    return loss, total_loss

In [11]:
summary_writer = tf.summary.create_file_writer(cfgs.SUMMARY_PATH)
loss_plot = []  # mean training loss of each epoch, filled in below

# Epoch length must come from the *training* split that train_dataset is built
# from. The earlier `step_per_epoch` counted all loaded images (train + val).
train_steps_per_epoch = len(img_name_train) // cfgs.BATCH_SIZE
if train_steps_per_epoch == 0:
    raise ValueError('BATCH_SIZE ({}) is larger than the training split ({} examples)'.format(
        cfgs.BATCH_SIZE, len(img_name_train)))

CKPT_EVERY_N_EPOCHS = 5

for epoch in range(start_epoch, cfgs.NUM_EPOCH):

    start_time = time.time()
    num_steps = 0
    total_loss = 0

    for batch, (image_feature, image_caption) in enumerate(train_dataset.take(train_steps_per_epoch)):

        batch_loss, t_loss = train_step(image_feature, image_caption)
        total_loss += t_loss
        num_steps += 1

        if batch % cfgs.SHOW_TRAIN_INFO_INTE == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(
                epoch + 1, batch, batch_loss / int(image_caption.shape[1])))

    # Number the checkpoint after the next epoch to run (ckpt-<epoch + 1>) rather
    # than CheckpointManager's save counter. Resuming via
    # int(latest_checkpoint.split('-')[-1]) then continues at the correct epoch.
    # Always save after the final epoch so its progress is not lost.
    if epoch % CKPT_EVERY_N_EPOCHS == 0 or epoch + 1 == cfgs.NUM_EPOCH:
        ckpt_manager.save(checkpoint_number=epoch + 1)

    epoch_loss = float(total_loss / num_steps)
    loss_plot.append(epoch_loss)

    with summary_writer.as_default():
        tf.summary.scalar('loss', epoch_loss, step=epoch)

    print('Epoch {} Loss {:.6f}'.format(epoch + 1, epoch_loss))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start_time))

Epoch 2 Batch 0 Loss 1.7998


KeyboardInterrupt: 