In [None]:
# Run the local bootstrap script in this kernel's namespace. Presumably it makes the
# project's `src` package importable for the `from src... import` cell below — TODO confirm.
%run ./load_src_module.py

In [None]:
import os
import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp
from tqdm import tqdm

import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid

In [None]:
from src.datasets import PersonalityCaptions, DatasetManager
from src.utils import MultiCheckpointManager
from src.evaluate import _seq_to_text

from src.image_encoders import InceptionResNetEncoder
from src.transformer import TransformerGenerator

In [None]:
# Root directory of the Personality-Captions data. Defaults to the original author's
# local checkout; set the PC_DATA_DIR environment variable to run on another machine
# instead of editing this hardcoded absolute path.
PC_DATA_DIR = os.environ.get(
    "PC_DATA_DIR",
    "/Users/akshaykurmi/NEU/Image-Captioning/stylized-captions/data/personality_captions_data",
)
pc = PersonalityCaptions(PC_DATA_DIR)

In [None]:
# Dataset helper wrapping `pc`: provides the tokenizer, the style encoder and image loading
# used by the cells below.
# NOTE(review): the meaning of the positional `20` is not visible here; it may be a maximum
# caption length (captions are decoded with sequence_length=20 further down) — confirm.
dm = DatasetManager(pc, 20)

In [None]:
# Checkpoint directory of the training run being inspected. Defaults to the original
# author's local path; set the CHECKPOINT_DIR environment variable to use another location.
CHECKPOINT_DIR = os.environ.get(
    "CHECKPOINT_DIR",
    "/Users/akshaykurmi/NEU/Image-Captioning/stylized-captions/results/run_1/checkpoints",
)
# Identifier of the saved generator checkpoint to restore (presumably a step/epoch number
# as understood by MultiCheckpointManager.restore — confirm).
GENERATOR_CHECKPOINT = 231

encoder = InceptionResNetEncoder()
# These hyperparameters presumably have to match the ones the checkpoint was trained with;
# otherwise restoring the weights is likely to fail.
generator = TransformerGenerator(token_vocab_size=dm.tokenizer.vocab_size,
                                 style_vocab_size=dm.style_encoder.num_classes,
                                 model_dim=512, style_dim=64, pffn_dim=2048, z_dim=512,
                                 encoder_blocks=2, decoder_blocks=6, num_attention_heads=8, max_pe=64,
                                 dropout=0.1, stylize=True)
checkpoint_manager = MultiCheckpointManager(CHECKPOINT_DIR, {
    "generator": {"generator": generator}
})
checkpoint_manager.restore({"generator": GENERATOR_CHECKPOINT})

In [None]:
# Number of personality styles the generator is conditioned on (used as style_vocab_size
# above and as the range of styles iterated over when captioning).
dm.style_encoder.num_classes

In [None]:
# Load the "test" split. Entries are accessed by integer position and expose an
# "image_path" key, which the cells below rely on.
ds = pc.load("test")

In [None]:
def show_image_grid(dataset, manager, start=1500, columns=10, rows=10, figsize=(27, 27)):
    """Plot a rows x columns grid of consecutive dataset images for browsing.

    Exploration helper: it is only defined in this cell, not called, so
    Restart & Run All is unaffected. Call it manually (e.g. ``show_image_grid(ds, dm)``)
    to look for an interesting example index. Unlike an inline loop, it does not
    overwrite the notebook-level ``i``.

    Args:
        dataset: indexable split (e.g. ``ds``) whose items have an "image_path" key.
        manager: object whose ``load_image(path)`` returns a tensor (e.g. ``dm``).
        start: dataset index of the top-left image.
        columns: number of grid columns.
        rows: number of grid rows.
        figsize: matplotlib figure size in inches.
    """
    fig, axes = plt.subplots(rows, columns, figsize=figsize, squeeze=False)
    for offset, ax in enumerate(axes.flat):
        img = manager.load_image(dataset[start + offset]["image_path"]).numpy().astype(np.int32)
        ax.imshow(img)
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()

In [None]:
# Select one test example by position and render its image without axis ticks.
i = 9444
image_path = ds[i]["image_path"]
image = dm.load_image(image_path)

fig, ax = plt.subplots(figsize=(5, 5))
# Integer dtype makes imshow treat pixel values as 0-255 RGB (presumably what
# load_image yields — confirm).
ax.imshow(image.numpy().astype(np.int32))
ax.set_xticks([])
ax.set_yticks([])
plt.show()

In [None]:
# Display the raw dataset record for the selected example (all fields it stores,
# including the "image_path" used above).
ds[i]

In [None]:
# Caption the selected image once per personality style. For each style, print the
# beam-search candidates with their scores, then a few stochastic samples.
MAX_CAPTION_LEN = 20  # sequence_length passed to both decoders
BEAM_SIZE = 5
N_SAMPLES = 3
SAMPLING_SEED = 42  # fixes the stochastic samples so re-running the cell reproduces them

tf.random.set_seed(SAMPLING_SEED)
encoder_output = encoder(tf.expand_dims(image, axis=0))  # add a batch axis of size 1
# The sampler's start-token prefix does not depend on the style, so build it once.
initial_sequence = tf.ones((1, 1), dtype=tf.int64) * dm.tokenizer.start_id
for s in range(dm.style_encoder.num_classes):
    style_label = dm.style_encoder.index_to_label[s]
    print(f"\033[1m===== {s} : {style_label} =====\033[0m")  # ANSI bold header
    style = tf.constant(s, dtype=tf.int32, shape=(1,))

    sequences, sequences_logits = generator.beam_search(encoder_output, style,
                                                        sequence_length=MAX_CAPTION_LEN,
                                                        beam_size=BEAM_SIZE, sos=dm.tokenizer.start_id,
                                                        eos=dm.tokenizer.end_id)
    # [0] presumably selects the single image in the batch, leaving one row per beam.
    for seq, logit in zip(sequences.numpy()[0], sequences_logits.numpy()[0]):
        print(f"{logit:0.5f} | {_seq_to_text(dm, seq)}")

    # NOTE(review): sample() appears to return a tuple whose first element holds the
    # sampled sequences — confirm against TransformerGenerator.sample.
    sequences = generator.sample(encoder_output, initial_sequence, style,
                                 sequence_length=MAX_CAPTION_LEN, mode="stochastic",
                                 n_samples=N_SAMPLES, training=False, sos=dm.tokenizer.start_id,
                                 eos=dm.tokenizer.end_id)[0]
    for seq in sequences:
        print(f"{_seq_to_text(dm, seq.numpy()[0])}")
    print()