In [None]:
# Detect whether the notebook runs inside Google Colab; only the import's success matters.
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except ImportError:
    # Catch only ImportError: a bare `except:` would also swallow KeyboardInterrupt
    # and unrelated errors raised while importing.
    IN_COLAB = False

In [None]:
if IN_COLAB:
    # Clone the repo.
    !git clone "https://github.com/ash0ts/keras-image-captioning-wandb.git"

    # Change the working directory to the repo root.
    %cd keras-image-captioning-wandb

    # Add the repo root to the Python path.
    import sys, os
    sys.path.append(os.getcwd())
    
    !pip install wandb
    !pip install PIL
    
    import wandb
    
    wandb.login()

In [None]:
from dotenv import load_dotenv, find_dotenv

# Load environment variables from the .env file located by find_dotenv().
_env_file = find_dotenv()
load_dotenv(_env_file)

In [None]:
from model.CNNEncoder import CNN_Encoder
from model.RNNDecoder import RNN_Decoder
import tensorflow as tf
import time
from image_utils import load_image
import numpy as np
from config import WANDB_PROJECT, WANDB_ENTITY, max_length, vocabulary_size, attention_features_shape
import wandb
from utils import load_from_pickle
from text_utils import generate_text_artifacts
import os

In [None]:
# Default hyper-parameters. They are registered on the W&B run and read back
# from run.config in the next cell.
default_config = dict(
    EPOCHS=20,
    BATCH_SIZE=64,
    BUFFER_SIZE=1000,
    embedding_dim=256,
    units=512,
    features_shape=2048,
    attention_features_shape=64,
)

run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name="train-coco2014-attention-model",
    job_type="train",
    config=default_config,
)

In [None]:
# Read hyper-parameters back from the run config (.get returns None for a missing key).
cfg = run.config
EPOCHS = cfg.get("EPOCHS")
BATCH_SIZE = cfg.get("BATCH_SIZE")
BUFFER_SIZE = cfg.get("BUFFER_SIZE")
embedding_dim = cfg.get("embedding_dim")
units = cfg.get("units")
features_shape = cfg.get("features_shape")
attention_features_shape = cfg.get("attention_features_shape")

# Training split: pickled image-name and caption lists stored in the "split" artifact.
split_art = run.use_artifact("split:latest")
img_name_train_path = split_art.get_path("img_name_train.pkl").download()
cap_train_path = split_art.get_path("cap_train.pkl").download()

# Each image name is resolved inside the downloaded "inception_v3" artifact directory;
# map_func later appends ".npy" to these paths.
# NOTE(review): load_from_pickle presumably unpickles the file - only use trusted artifacts.
image_feat_art = run.use_artifact("inception_v3:latest")
image_feat_path = image_feat_art.download()
img_name_train = [
    os.path.join(image_feat_path, name)
    for name in load_from_pickle(img_name_train_path)
]
cap_train = load_from_pickle(cap_train_path)

# TODO: the tokenizer is rebuilt from the same captions a previous pipeline step
# already used; this stays consistent only as long as that data does not change.
img_cap_table = run.use_artifact("image_caption_table:latest").get("img_cap_table")
train_captions = img_cap_table.get_column("caption")

In [None]:
import math

caption_dataset = tf.data.Dataset.from_tensor_slices(train_captions)
# Rebuild the tokenizer and the word<->index lookups from the raw captions.
# With return_vector=False the second return value is not needed.
tokenizer, _, word_to_index, index_to_word = generate_text_artifacts(
    caption_dataset, max_length=max_length, vocabulary_size=vocabulary_size, return_vector=False, return_mapping=True)

# Batches per epoch. `Dataset.batch` keeps the final partial batch by default
# (drop_remainder=False), so use ceiling division. Floor division under-counted that
# batch and gave 0 (a ZeroDivisionError when averaging the epoch loss) for fewer
# than BATCH_SIZE images. max(1, ...) keeps an empty split from dividing by zero.
num_steps = max(1, math.ceil(len(img_name_train) / BATCH_SIZE))
# TODO: Load the numpy files from inception artifact

In [None]:
def map_func(img_name, cap):
    """Load the cached feature array for one image and pair it with its caption.

    Used through ``tf.numpy_function``, so ``img_name`` arrives as UTF-8 bytes:
    the path without the ``.npy`` suffix, which is appended here.

    Returns:
        (features, cap): the loaded NumPy array and the caption, unchanged.
    """
    feature_file = f"{img_name.decode('utf-8')}.npy"
    features = np.load(feature_file)
    return features, cap

In [None]:
def _load_pair(img_path, caption):
    # Wrap map_func so the .npy loading runs as a NumPy op inside the tf.data graph.
    return tf.numpy_function(map_func, [img_path, caption], [tf.float32, tf.int64])


# (image path, caption) pairs -> (features, caption), loaded in parallel,
# then shuffled, batched and prefetched.
dataset = (
    tf.data.Dataset.from_tensor_slices((img_name_train, cap_train))
    .map(_load_pair, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .prefetch(buffer_size=tf.data.AUTOTUNE)
)

print(dataset)

In [None]:
# Encoder over image features, and a decoder whose output size is the tokenizer's vocabulary.
encoder = CNN_Encoder(embedding_dim)
decoder = RNN_Decoder(embedding_dim, units, tokenizer.vocabulary_size())
optimizer = tf.keras.optimizers.Adam()

# reduction='none' keeps one loss per example so loss_function can mask padding;
# from_logits=True means predictions are treated as unnormalised logits.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction='none')

In [None]:
def loss_function(real, pred):
    """Cross-entropy for one decoding step, with token id 0 masked out.

    Masked positions contribute zero, but the mean is still taken over every
    position in ``real``, masked or not.
    """
    per_token = loss_object(real, pred)
    keep = tf.cast(tf.not_equal(real, 0), dtype=per_token.dtype)
    return tf.reduce_mean(per_token * keep)

In [None]:
# Track encoder, decoder and optimizer state under ./checkpoints/train, keeping at most 5.
checkpoint_path = "./checkpoints/train"
ckpt = tf.train.Checkpoint(encoder=encoder, decoder=decoder, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)

# Resume from the newest checkpoint, if any.
# NOTE(review): the numeric suffix of the checkpoint name is used as the epoch to
# resume from. CheckpointManager numbers checkpoints by its save counter unless
# save() is given `checkpoint_number`, so this is an epoch count only if the
# training loop saves with checkpoint_number=<completed epochs> - verify.
latest_ckpt = ckpt_manager.latest_checkpoint
if latest_ckpt:
    start_epoch = int(latest_ckpt.split('-')[-1])
    ckpt.restore(latest_ckpt)
else:
    start_epoch = 0

In [None]:
# One teacher-forced optimisation step, compiled into a graph with tf.function.

@tf.function
def train_step(img_tensor, target):
    """Run a forward/backward pass on one batch and update the weights.

    Uses the static shape ``target.shape``, so batches of a different shape
    (e.g. a final partial batch) are likely traced separately by tf.function.

    Returns:
        (loss, total_loss): the sum of ``loss_function`` over decoding steps
        1..target.shape[1]-1 (used for the gradients), and that sum divided by
        ``target.shape[1]``.
    """
    batch_size = target.shape[0]
    seq_len = target.shape[1]

    # Fresh decoder state per batch: captions are independent between images.
    hidden = decoder.reset_state(batch_size=batch_size)
    start_tokens = [word_to_index('<start>')] * batch_size
    dec_input = tf.expand_dims(start_tokens, 1)

    loss = 0
    with tf.GradientTape() as tape:
        features = encoder(img_tensor)
        for step in range(1, seq_len):
            predictions, hidden, _ = decoder(dec_input, features, hidden)
            loss += loss_function(target[:, step], predictions)
            # Teacher forcing: the next input is the ground-truth token.
            dec_input = tf.expand_dims(target[:, step], 1)

    total_loss = (loss / int(seq_len))

    variables = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(grads, variables))

    return loss, total_loss

In [None]:
def grab_gradients(model, layer_types=None):
    """Collect the first weight tensor of selected layers of ``model``.

    Despite the name, this returns layer *weights* (``get_weights()[0]``,
    typically the kernel), not gradients.

    Args:
        model: object exposing ``.layers``; only its direct layers are inspected.
        layer_types: class or tuple of classes to include. Defaults to
            ``(tf.keras.layers.Conv2D, tf.keras.layers.Dense)``.

    Returns:
        dict mapping ``layer.name`` to that layer's first weight array. Layers
        that have no weights yet (e.g. not built) are skipped instead of raising
        IndexError.
    """
    if layer_types is None:
        layer_types = (tf.keras.layers.Conv2D, tf.keras.layers.Dense)
    kernels = {}
    for layer in model.layers:
        if not isinstance(layer, layer_types):
            continue
        weights = layer.get_weights()
        if weights:
            kernels[layer.name] = weights[0]
    return kernels

In [None]:
# Per-epoch losses appended by the training loop. Kept in its own cell so that
# re-running the training cell does not reset the history.
loss_plot = list()

In [None]:
# Training loop: starts at `start_epoch` (set by the checkpoint-restore cell) and logs to W&B.
for epoch in range(start_epoch, EPOCHS):
    start = time.time()
    total_loss = 0
    batches_seen = 0

    for (batch, (img_tensor, target)) in enumerate(dataset):
        batch_loss, t_loss = train_step(img_tensor, target)
        # Log a plain Python float rather than an EagerTensor.
        run.log({"batch_loss": float(batch_loss)})
        total_loss += t_loss
        batches_seen += 1

        if batch % 100 == 0:
            # t_loss is batch_loss / target.shape[1], computed inside train_step.
            average_batch_loss = float(t_loss)
            print(
                f'Epoch {epoch+1} Batch {batch} Loss {average_batch_loss:.4f}')

    # Average over the batches actually produced (including a final partial batch);
    # max(..., 1) avoids dividing by zero if the dataset is empty.
    epoch_loss = float(total_loss) / max(batches_seen, 1)
    # A single log call keeps all epoch metrics on the same W&B step.
    # NOTE: grab_gradients returns layer weights, not gradients; the keys are kept
    # unchanged so existing dashboards keep working.
    run.log({
        "epoch_loss": epoch_loss,
        "encoder_epoch_gradient": grab_gradients(encoder),
        "decoder_epoch_gradient": grab_gradients(decoder),
    })
    # storing the epoch end loss value to plot later
    loss_plot.append(epoch_loss)

    # Number checkpoints by completed epochs, so that the restore cell's
    # `int(latest_checkpoint.split('-')[-1])` yields the correct start_epoch.
    # Without checkpoint_number, CheckpointManager numbers saves 1, 2, 3, ...
    # Also checkpoint the final epoch.
    completed_epochs = epoch + 1
    if completed_epochs % 5 == 0 or completed_epochs == EPOCHS:
        ckpt_manager.save(checkpoint_number=completed_epochs)

    print(f'Epoch {completed_epochs} Loss {epoch_loss:.6f}')
    print(f'Time taken for 1 epoch {time.time()-start:.2f} sec\n')

In [None]:
# Export each trained model to a directory named after it (added to W&B artifacts below).
for _model, _dirname in ((encoder, "encoder"), (decoder, "decoder")):
    _model.save(_dirname)

In [None]:
# One "model" artifact per exported model directory, plus the raw training checkpoints.
encoder_model = wandb.Artifact(name="encoder", type="model")
encoder_model.add_dir("encoder")

decoder_model = wandb.Artifact(name="decoder", type="model")
decoder_model.add_dir("decoder")

checkpoints_art = wandb.Artifact(name="checkpoints", type="training")
checkpoints_art.add_dir("checkpoints")

In [None]:
# Upload the model and checkpoint artifacts to the training run.
for _artifact in (encoder_model, decoder_model, checkpoints_art):
    run.log_artifact(_artifact)

In [None]:
# End the training W&B run; the Evaluate section below starts a separate run.
run.finish()

## Evaluate

In [None]:
# The Evaluate section re-loads the .env settings so it can run on its own.
from dotenv import find_dotenv, load_dotenv
load_dotenv(dotenv_path=find_dotenv())

In [None]:
import tensorflow as tf
from model.CNNEncoder import CNN_Encoder
from model.RNNDecoder import RNN_Decoder
from text_utils import generate_text_artifacts
from image_utils import load_image
from utils import load_from_pickle
import os
import numpy as np
from PIL import Image

In [None]:
from config import WANDB_PROJECT, WANDB_ENTITY, max_length, vocabulary_size, attention_features_shape
import wandb

In [None]:
# Separate W&B run for evaluation (job_type "evaluate").
run = wandb.init(
    project=WANDB_PROJECT,
    entity=WANDB_ENTITY,
    name="evaluate-coco2014-attention-model",
    job_type="evaluate",
)

In [None]:
# Same caption table as in training; its "caption" column feeds the tokenizer below.
# TODO: the tokenizer is rebuilt from the same captions a previous pipeline step
# already used; this stays consistent only as long as that data does not change.
caption_table_art = run.use_artifact("image_caption_table:latest")
img_cap_table = caption_table_art.get("img_cap_table")
train_captions = img_cap_table.get_column("caption")

In [None]:
# Rebuild the tokenizer and word<->index lookups with the same call as in training.
# They match the training ones only if the captions and generate_text_artifacts
# are unchanged and deterministic.
caption_dataset = tf.data.Dataset.from_tensor_slices(train_captions)
text_artifacts = generate_text_artifacts(
    caption_dataset,
    max_length=max_length,
    vocabulary_size=vocabulary_size,
    return_vector=False,
    return_mapping=True,
)
tokenizer, _, word_to_index, index_to_word = text_artifacts

In [None]:
# Keras feature-extraction model logged as an artifact; evaluate() applies it to raw images.
feature_extractor_art = run.use_artifact("feature_extractor:latest")
image_features_extract_model_path = feature_extractor_art.download()
image_features_extract_model = tf.keras.models.load_model(image_features_extract_model_path)

In [None]:
# Download the exported encoder directory.
encoder_artifact = run.use_artifact("encoder:latest")
encoder_model_path = encoder_artifact.download()
# NOTE(review): the encoder is never loaded from this path, so evaluate() only works
# if `encoder` still exists from the training section in the same kernel. The
# previously commented-out `tf.saved_model.load(encoder_model_path)` would return
# a generic SavedModel object; confirm it is callable the way evaluate() expects.

In [None]:
# Download the exported decoder directory.
decoder_artifact = run.use_artifact("decoder:latest")
decoder_model_path = decoder_artifact.download()
# NOTE(review): the decoder is never loaded from this path; evaluate() relies on a
# `decoder` left over from training. If RNN_Decoder.reset_state is a plain Python
# method (not a tf.function), it will likely not survive SavedModel loading, so
# rebuilding RNN_Decoder and restoring weights may be required - verify.

In [None]:
def evaluate(image, greedy=False):
    """Caption one image with the trained encoder/decoder.

    Args:
        image: path of the image file, passed to ``load_image``.
        greedy: if True, take the arg-max token at each step (deterministic).
            By default, sample a token from the logits with
            ``tf.random.categorical``, as before.

    Returns:
        (result, attention_plot): the list of predicted words (ending with
        '<end>' if it was produced within ``max_length`` steps), and an array of
        shape ``(len(result), attention_features_shape)`` with one row of
        attention weights per predicted word. Both return paths now trim this
        array the same way.
    """
    attention_plot = np.zeros((max_length, attention_features_shape))

    # Fresh decoder state for a single image.
    hidden = decoder.reset_state(batch_size=1)

    temp_input = tf.expand_dims(load_image(image)[0], 0)
    img_tensor_val = image_features_extract_model(temp_input)
    # Merge the two middle axes of the rank-4 extractor output: (b, h, w, c) -> (b, h*w, c).
    img_tensor_val = tf.reshape(img_tensor_val, (img_tensor_val.shape[0],
                                                 -1,
                                                 img_tensor_val.shape[3]))

    features = encoder(img_tensor_val)

    dec_input = tf.expand_dims([word_to_index('<start>')], 0)
    result = []

    for i in range(max_length):
        predictions, hidden, attention_weights = decoder(dec_input,
                                                         features,
                                                         hidden)

        attention_plot[i] = tf.reshape(attention_weights, (-1, )).numpy()

        if greedy:
            predicted_id = tf.argmax(predictions, axis=-1)[0].numpy()
        else:
            predicted_id = tf.random.categorical(predictions, 1)[0][0].numpy()
        predicted_word = tf.compat.as_text(index_to_word(predicted_id).numpy())
        result.append(predicted_word)

        if predicted_word == '<end>':
            break

        dec_input = tf.expand_dims([predicted_id], 0)

    # Drop the unused zero rows so there is exactly one attention row per word.
    attention_plot = attention_plot[:len(result), :]
    return result, attention_plot

In [None]:
def plot_attention(image, result, attention_plot):
    """Overlay each word's attention map on the source image.

    Args:
        image: path to the image file.
        result: predicted words; one subplot is drawn per word.
        attention_plot: array with at least ``len(result)`` rows. Each row is
            reshaped with ``np.resize`` to an 8x8 map (repeated or truncated if
            its length is not 64).

    Returns:
        The ``matplotlib.pyplot`` module, whose current figure holds the plot,
        so callers can ``savefig`` it.

    Side effect: closes all open matplotlib figures before drawing, so repeated
    calls in a loop do not accumulate figures.
    """
    import matplotlib.pyplot as plt
    from PIL import Image

    plt.close('all')

    # The context manager releases the image file once the pixels are copied.
    with Image.open(image) as pil_image:
        temp_image = np.array(pil_image)

    fig = plt.figure(figsize=(10, 10))

    len_result = len(result)
    # Loop-invariant: square grid with room for every word (at least 2x2).
    grid_size = max(int(np.ceil(len_result / 2)), 2)
    for i in range(len_result):
        temp_att = np.resize(attention_plot[i], (8, 8))
        ax = fig.add_subplot(grid_size, grid_size, i + 1)
        ax.set_title(result[i])
        img = ax.imshow(temp_image)
        ax.imshow(temp_att, cmap='gray', alpha=0.6, extent=img.get_extent())

    return plt

In [None]:
# Raw images: evaluate() runs the feature extractor on the original image files.
images_art = run.use_artifact("images:latest")
images_path = images_art.download()

In [None]:
# Validation split: pickled image names and caption lists from the "split" artifact.
split_art = run.use_artifact("split:latest")
img_name_val_path, cap_val_path = (
    split_art.get_path(fname).download()
    for fname in ("img_name_val.pkl", "cap_val.pkl")
)

In [None]:
# Resolve each validation image name inside the downloaded images directory.
# NOTE(review): load_from_pickle presumably unpickles - only use trusted artifacts.
img_name_val = [os.path.join(images_path, rel_name) for rel_name in load_from_pickle(img_name_val_path)]
cap_val = load_from_pickle(cap_val_path)

In [None]:
from tqdm.notebook import tqdm

In [None]:
# Directory for the per-image attention figures saved during evaluation.
plot_dir = os.path.join(".", "attention_plots")
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(plot_dir, exist_ok=True)

In [None]:
eval_table = wandb.Table(columns = ["name", "image", "real_caption", "predicted_caption", "attention_plot"])
# Caption every validation image and record the results in a W&B table.
# Wrapping the list (not enumerate) lets tqdm show the total.
for rid, image in enumerate(tqdm(img_name_val)):
    # Id 0 is skipped (it is masked as padding in the training loss).
    real_caption = ' '.join(tf.compat.as_text(index_to_word(i).numpy())
                            for i in cap_val[rid] if i != 0)
    result, attention_plot = evaluate(image)

    print('Real Caption:', real_caption)
    predicted_caption = ' '.join(result)
    print('Prediction Caption:', predicted_caption)
    plt = plot_attention(image, result, attention_plot)
    plt_path = os.path.join(plot_dir, f'attention_{os.path.basename(image)}.png')
    plt.savefig(plt_path)
    # Release the figure; one is created per image.
    plt.close()
    eval_table.add_data(os.path.basename(image), wandb.Image(image), real_caption, predicted_caption, wandb.Image(plt_path))

In [None]:
# Version the evaluation table as a W&B artifact of type "eval".
eval_table_art = wandb.Artifact(type="eval", name="eval_table")
eval_table_art.add(eval_table, "eval_table")

In [None]:
# Show the table on the run page and upload the versioned artifact.
run.log(dict(eval_table=eval_table))
run.log_artifact(eval_table_art)

In [None]:
# End the evaluation W&B run.
run.finish()