# Social Media Caption Generator

##### Objective:
- Generate a suggested caption for a social media post.
- Take an image as input and use a transformer decoder to generate the caption text.
- The point of this model is to explore other use cases for transformer models with non-standard inputs.

##### Inspired by:
- AN IMAGE IS WORTH 16X16 WORDS: TRANSFORMERS FOR IMAGE RECOGNITION AT SCALE
- Perceiver: General Perception with Iterative Attention

- Originally, the model was planned as a vision transformer (ViT) encoder feeding a language transformer decoder.
- In place of that encoder, I implement the Perceiver transformer architecture and feed its outputs into the language transformer decoder.

In [None]:
from tensorflow.python.client import device_lib
import os
# Number CUDA devices by PCI bus order and expose only GPU 0
# (the alternative value noted by the author is "CPU").
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "0",  # "CPU"
})

# List the devices TensorFlow can see, to check that the GPU is present.
print(device_lib.list_local_devices())


In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Embedding, LayerNormalization, Input
from tensorflow.keras.layers import Layer
from tensorflow.keras import Sequential
from tensorflow_addons.activations import gelu
from tensorflow_addons.optimizers import LAMB, RectifiedAdam
import numpy as np
%matplotlib inline
import matplotlib.pyplot as plt
import time
import os
from tqdm import tqdm
import pickle
import imageio
from cv2 import resize
from imgaug import augmenters as iaa
from glob import glob

from math import pi

In [None]:
# Select the precision policy for the hardware. The rest of the notebook reads
# policy.variable_dtype / policy.compute_dtype, so changing it here propagates.
policy = tf.keras.mixed_precision.Policy("mixed_float16")#alternatives: float32, float16
tf.keras.mixed_precision.set_global_policy(policy)

All data is preprocessed and stored on disk by the data preprocessing notebook.
- The preprocessing lives in a separate notebook because of memory limits, particularly VRAM.

Structure of the data is as follows:
- InceptionV3 image features for each post within the image_encodings folder (this notebook does not use them; it reads the raw JPEGs from ./images/ and resizes them itself)
- Captions, tokenized with the vocabulary defined during preprocessing, within the tokenized_caption folder
- Both files related to a post share the same name (without extension), and filename lists are used to split and load the data

In [None]:
# Load the token -> index vocabulary built by the preprocessing notebook, plus
# its inverse (index -> token) for turning predictions back into words.
# NOTE: pickle.load can execute arbitrary code; only load a trusted vocab file.
with open("./other/vocab.pkl", 'rb') as f:
    vocab = pickle.load(f)
vocab_inverse = {idx: w for w, idx in vocab.items()}

In [None]:
# Special token indices in `vocab` (must match the preprocessing notebook).
pad_idx = 1  # padding token: fills short captions and is excluded from loss/accuracy
unk_idx = 3  # presumably the index of the unknown-word token "#UNK#" — TODO confirm

In [None]:
IMAGE_PATH = "./images/"
# glob() returns paths in arbitrary, filesystem-dependent order. Sort them so
# the train/validation split below is reproducible across runs; otherwise a
# resumed run could train on images that were previously held out.
filename_list = sorted(os.path.basename(img) for img in tqdm(glob(IMAGE_PATH + "*.jpg")))

# The last 10,000 posts (in sorted order) form the validation split.
train_filenames = filename_list[:-10000]
val_filenames = filename_list[-10000:]
del filename_list

In [None]:
# Sanity check: number of posts in each split.
print(len(train_filenames))
print(len(val_filenames))

In [None]:
BATCH_SIZE = 16
N_EPOCHS = 12  # NOTE(review): not referenced in this notebook section; train() is called with explicit epoch counts
MAX_LEN = 50   # NOTE(review): not referenced in this section; SEQUENCE_LENGTH is used for caption length instead

In [None]:
IMG_SIZE = 128        # images are resized to IMG_SIZE x IMG_SIZE pixels
N_CHANNELS = 3        # colour channels per pixel
SEQUENCE_LENGTH = 50  # padded caption length in tokens (second dim of each caption batch)

In [None]:
# Fourier positional-encoding settings (consumed by fourier_encode / make_fourier_grid).
NUM_FREQ_BANDS = 16  # number of sin/cos frequency bands per spatial axis
MAX_FREQ = 64        # the highest scale used is MAX_FREQ / 2
FREQ_BASE = 2        # scales are log-spaced in this base
INPUT_AXIS = 2       # number of spatial axes (height, width)
# Per-pixel channels fed to the Perceiver: the image channels plus, for each
# spatial axis, a sin and a cos per band and the raw coordinate.
DATA_CHANNELS = N_CHANNELS + (INPUT_AXIS * ((NUM_FREQ_BANDS * 2) + 1))

In [None]:
def _load_split(filenames):
    """Load resized images and tokenized captions for a list of post filenames.

    Returns:
        images: (N, IMG_SIZE, IMG_SIZE, N_CHANNELS) float16 array scaled to [0, 1].
        captions: list of N token-index arrays loaded from ./tokenized_caption/.
    """
    images = np.zeros((len(filenames), IMG_SIZE, IMG_SIZE, N_CHANNELS), dtype=np.float16)
    captions = []
    for i, filename in enumerate(tqdm(filenames)):
        stem = os.path.splitext(filename)[0]
        img = imageio.imread(IMAGE_PATH + stem + ".jpg")
        # Normalise the channel layout so the assignment below cannot fail on
        # grayscale (H, W) / (H, W, 1) or RGBA (H, W, 4) JPEGs.
        if img.ndim == 2:
            img = img[..., np.newaxis]
        if img.shape[-1] == 1:
            img = np.repeat(img, N_CHANNELS, axis=-1)
        elif img.shape[-1] > N_CHANNELS:
            img = np.ascontiguousarray(img[..., :N_CHANNELS])
        images[i] = resize(img, (IMG_SIZE, IMG_SIZE)) / 255
        captions.append(np.load("./tokenized_caption/" + stem + ".npy"))
    return images, captions


train_images, train_captions = _load_split(train_filenames)
val_images, val_captions = _load_split(val_filenames)

In [None]:
# RandAugment policy applied to training batches inside DataGenerator
# (imgaug parameters: n = number of ops per image, m = magnitude).
rand_aug = iaa.RandAugment(n=3, m=7)

#Credit for fourier_encode function from lucidrains/perceiver
def fourier_encode(x, max_freq, num_bands, base):
    """Fourier-feature encode every element of ``x``.

    Appends a trailing axis of size ``2 * num_bands + 1`` holding
    ``sin(x * s * pi)`` for each scale ``s``, then ``cos(x * s * pi)`` for each
    scale, then the original value. The ``num_bands`` scales are log-spaced in
    ``base`` from 1 to ``max_freq / 2`` (both ends included).
    """
    x = tf.expand_dims(x, -1)
    x_orig = x
    # base**0 = 1 up to base**log_base(max_freq/2) = max_freq/2.
    scales = tf.experimental.numpy.logspace(start=0,
                                            stop=tf.math.log(tf.cast(max_freq/2, x.dtype))/tf.math.log(tf.cast(base, x.dtype)),
                                            num=num_bands,
                                            endpoint=True,
                                            base=base,
                                            dtype=x.dtype)
    # Prepend singleton axes so the 1-D scales broadcast along the new last axis.
    scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
    
    x = x * scales * pi
    x = tf.concat([tf.math.sin(x), tf.math.cos(x)], axis=-1)
    x = tf.concat((x, x_orig), axis=-1)
    return x
    
def make_fourier_grid(num_freq_bands, max_freq, freq_base, dims):
    """Build Fourier positional features for a 2-D pixel grid.

    Args:
        num_freq_bands, max_freq, freq_base: forwarded to ``fourier_encode``.
        dims: ``[height, width]`` of the grid. Height and width may differ; the
            previous ``tf.map_fn`` implementation required them to be equal.

    Returns:
        Tensor of shape ``(height, width, 2 * (2 * num_freq_bands + 1))`` in
        ``policy.variable_dtype``. The first half of the channels encodes the
        coordinate along the width axis and the second half the coordinate
        along the height axis, each in [-1, 1]. For square grids this is
        exactly the layout produced before.
    """
    height, width = int(dims[0]), int(dims[1])
    # Float64 grids, as the previous linspace/map_fn construction produced.
    y_axis = tf.linspace(tf.constant(-1.0, tf.float64), tf.constant(1.0, tf.float64), height)
    x_axis = tf.linspace(tf.constant(-1.0, tf.float64), tf.constant(1.0, tf.float64), width)

    # Default 'xy' indexing -> (height, width, 2) with pos[i, j] = (x_axis[j], y_axis[i]).
    pos = tf.cast(tf.stack(tf.meshgrid(x_axis, y_axis), axis=-1), policy.variable_dtype)
    enc_pos = fourier_encode(pos, max_freq, num_freq_bands, freq_base)  # (H, W, 2, 2*bands+1)
    # Merge the per-axis feature blocks into one channel axis.
    return tf.reshape(enc_pos, [height, width, -1])

def batch_captions_to_matrix(batch_captions, max_len=50, pad_value=None):
    """Pad/truncate tokenized captions into a ``(len(batch_captions), max_len)`` matrix.

    Args:
        batch_captions: sequence of 1-D token-index sequences (lists or arrays).
        max_len: number of columns; longer captions are truncated to this length.
        pad_value: fill value for positions past the end of a caption. Defaults
            to the module-level ``pad_idx`` (looked up at call time).

    Returns:
        Integer numpy array of shape ``(len(batch_captions), max_len)``.
    """
    if pad_value is None:
        pad_value = pad_idx
    batch = np.full((len(batch_captions), max_len), pad_value)
    for row, caption in zip(batch, batch_captions):
        clipped = caption[:max_len]
        row[:len(clipped)] = clipped  # `row` is a view, so this writes into `batch`
    return batch

class DataGenerator(tf.keras.utils.Sequence):
    """Serve (Perceiver input, padded caption) batches from in-memory arrays.

    Each item is a tuple:
      * data: ``(batch_size, image_size * image_size, data_channels)`` in
        ``policy.variable_dtype``. Pixels are rescaled to [-1, 1] and
        concatenated with the Fourier position features from
        ``make_fourier_grid``, then flattened over the spatial axes.
      * captions: ``(batch_size, sequence_length)`` int32 token matrix padded
        with ``pad_idx``.

    The object is also its own infinite iterator (``__next__`` wraps around and
    reshuffles), so it can be handed to ``tf.data.Dataset.from_generator``.
    """

    def __init__(self, img_tensor, caption_list, batch_size=64, sequence_length=50, shuffle=True, image_size=300, n_channels=3,
                 num_frequency_bands=64, max_freq=200, freq_base=2, input_axis=2, augment=True):
        self.image_size = image_size
        self.n_channels = n_channels
        self.img_tensor = img_tensor  # assumed (N, image_size, image_size, n_channels) in [0, 1]
        self.caption_list = caption_list
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        self.shuffle = shuffle
        self.augment = augment
        self.on_epoch_end()
        self.n = 0
        self.n_total = len(self)
        # The position features are identical for every image, so build them once.
        self.fourier_grid = make_fourier_grid(num_frequency_bands, max_freq, freq_base, [image_size, image_size])
        self.data_channels = n_channels + (input_axis * ((num_frequency_bands * 2) + 1))

    def __len__(self):
        # The incomplete trailing batch is dropped so every batch has a fixed shape.
        return len(self.caption_list) // self.batch_size

    def __getitem__(self, index):
        indices = self.indices[index * self.batch_size:(index + 1) * self.batch_size]

        # Gather the batch; rows past len(indices) stay zero (out-of-range index only).
        batch_images = np.zeros((self.batch_size, self.image_size, self.image_size, self.n_channels))
        batch_images[:len(indices)] = self.img_tensor[indices]
        captions_tokenized = [self.caption_list[ind] for ind in indices]

        # Pad/truncate to the configured caption length (previously this always used 50).
        batch_captions = batch_captions_to_matrix(captions_tokenized, max_len=self.sequence_length)

        batch_images = batch_images * 255
        if self.augment:
            # imgaug operates on uint8. Round before casting so values such as
            # 254.99 (float16 round-off of x/255*255) are not truncated to 254.
            pixels = np.clip(np.rint(batch_images), 0, 255).astype(np.uint8)
            batch_images = rand_aug(images=pixels)
        batch_images = (batch_images.astype(np.float64) / (255 / 2)) - 1  # -> [-1, 1]

        # Same position grid for every sample: (H, W, C) -> (batch, H, W, C).
        enc_pos = tf.tile(self.fourier_grid, [self.batch_size, 1, 1])
        enc_pos = tf.reshape(enc_pos, tf.concat([[self.batch_size], tf.shape(self.fourier_grid)], axis=0))
        data = tf.concat((batch_images, enc_pos), axis=-1)

        # Flatten (batch, height, width, channels) -> (batch, height * width, channels).
        data_shape = tf.shape(data)
        data = tf.reshape(data, [data_shape[0], data_shape[1] * data_shape[2], data_shape[3]])

        return tf.cast(data, policy.variable_dtype), tf.cast(batch_captions, tf.int32)

    def __next__(self):
        # Infinite iteration: wrap around (reshuffling if enabled) after the last batch.
        if self.n >= self.n_total:
            self.n = 0
            if self.shuffle:
                self.on_epoch_end()
        result = self[self.n]
        self.n += 1
        return result

    def __iter__(self):
        return self

    def on_epoch_end(self):
        """Reset the sample order, shuffling it when ``self.shuffle`` is set."""
        self.indices = np.arange(len(self.caption_list))
        if self.shuffle:
            np.random.shuffle(self.indices)

In [None]:
# Training batches are augmented with RandAugment; validation batches are not.
# Caption length comes from SEQUENCE_LENGTH so it always matches the TensorSpec
# used for the tf.data pipelines.
train_data_generator = DataGenerator(
            train_images,
            train_captions,
            batch_size = BATCH_SIZE,
            sequence_length=SEQUENCE_LENGTH,
            image_size=IMG_SIZE,
            num_frequency_bands=NUM_FREQ_BANDS,
            max_freq=MAX_FREQ,
            freq_base=FREQ_BASE)

validation_data_generator = DataGenerator(
            val_images,
            val_captions,
            batch_size = BATCH_SIZE,
            sequence_length=SEQUENCE_LENGTH,
            image_size=IMG_SIZE,
            num_frequency_bands=NUM_FREQ_BANDS,
            max_freq=MAX_FREQ,
            freq_base=FREQ_BASE,
            augment=False)

In [None]:
# Time how long one training batch takes to build, and inspect it.
t = time.time()
ex_img_sequences, ex_cap = train_data_generator[1]
e = time.time()
print(f"elapsed_time: {e - t}")
print("ex_img_encoding")
print(ex_img_sequences.shape)
print(ex_img_sequences)

# Free the example batch.
del ex_img_sequences, ex_cap

In [None]:
def _dataset_from(generator):
    """Wrap an infinite DataGenerator in a tf.data.Dataset that prefetches 8 batches."""
    signature = (
        tf.TensorSpec(shape=(BATCH_SIZE, IMG_SIZE * IMG_SIZE, DATA_CHANNELS), dtype=policy.variable_dtype),
        tf.TensorSpec(shape=(BATCH_SIZE, SEQUENCE_LENGTH), dtype=tf.int32),
    )
    dataset = tf.data.Dataset.from_generator(lambda: generator, output_signature=signature)
    return dataset.prefetch(8)


# Pin the input-pipeline ops to the CPU.
with tf.device('/device:CPU:0'):
    train_dataset = _dataset_from(train_data_generator)
    val_dataset = _dataset_from(validation_data_generator)


The timing test above suggests that building batches from the in-memory arrays this way should not cause a significant slowdown.

In [None]:
@tf.function(jit_compile=True)
def attention(query, key, value, mask=None):
        """Scaled dot-product attention.

        Args:
            query, key, value: tensors whose last two axes are (seq_len, depth);
                leading axes (e.g. batch, heads) must broadcast.
            mask: optional tensor broadcastable to the score matrix; entries equal
                to 1 are blocked, entries equal to 0 are kept.

        Returns:
            (output, weights): ``weights @ value`` and the softmax weights.
        """
        score = tf.matmul(query, key, transpose_b=True)
        dim_key = tf.cast(tf.shape(key)[-1], score.dtype)
        scaled_score = tf.math.divide(score, tf.math.sqrt(dim_key))
        if mask is not None:
            # Blocked positions get a large negative bias: dtype.min/32 instead of
            # dtype.min so the addition does not overflow (matters in float16).
            scaled_score += (mask * (scaled_score.dtype.min/32)) #to avoid numerical underflow/overflow
        weights = tf.nn.softmax(scaled_score, axis=-1) # (..., seq_len_q, seq_len_k)
        output = tf.matmul(weights, value) # (..., seq_len_q, depth_v)
        return output, weights

class MultiHeadSelfAttention(Layer):
    """Multi-head self-attention: queries, keys and values are all projected from ``x``.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of heads, each of width ``embed_dim // num_heads``.
        has_bias: NOTE(review): accepted but unused; the Dense projections are
            built with their default arguments.
    """
    def __init__(self, embed_dim, num_heads=8, has_bias=True):#, is_causal=False):
        super(MultiHeadSelfAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}")
        self.projection_dim = embed_dim // num_heads

        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)

        self.combine_heads = Dense(embed_dim)

    @tf.function(jit_compile=True)
    def separate_heads(self, x, batch_size):
        # (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, projection_dim)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    @tf.function
    def call(self, x, padding_mask=None):
        """Attend ``x`` to itself.

        ``padding_mask`` is forwarded to ``attention`` (1 = blocked). Despite its
        name, the decoder passes its combined look-ahead/padding mask here.
        Returns a tensor of shape (batch_size, seq_len, embed_dim).
        """
        batch_size = tf.shape(x)[0]
        
        query = self.query_dense(x)  # (batch_size, seq_len, embed_dim)
        key =   self.key_dense(x)    # (batch_size, seq_len, embed_dim)
        value = self.value_dense(x)  # (batch_size, seq_len, embed_dim)
        
        query = self.separate_heads(query, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
        key =   self.separate_heads(key, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
        value = self.separate_heads(value, batch_size)  # (batch_size, num_heads, seq_len, projection_dim)
        
        attention_out, _ = attention(query, key, value, padding_mask)
        
        attention_out = tf.transpose(attention_out, perm=[0, 2, 1, 3])  # (batch_size, seq_len, num_heads, projection_dim)
        concat_attention = tf.reshape(attention_out, (batch_size, -1, self.embed_dim))  # (batch_size, seq_len, embed_dim)
        
        output = self.combine_heads(concat_attention)  # (batch_size, seq_len, embed_dim)

        return output
        
class MultiHeadAttention(Layer):
    """Multi-head attention with separate query and key/value inputs (cross-attention).

    Queries are projected from ``q``; keys and values are both projected from ``kv``.

    Args:
        embed_dim: model width; must be divisible by ``num_heads``.
        num_heads: number of heads, each of width ``embed_dim // num_heads``.
        has_bias: NOTE(review): accepted but unused.
    """
    def __init__(self, embed_dim, num_heads=8, has_bias=True):#, is_causal=False):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embedding dimension = {embed_dim} should be divisible by number of heads = {num_heads}")
        self.projection_dim = embed_dim // num_heads

        self.query_dense = Dense(embed_dim)
        self.key_dense = Dense(embed_dim)
        self.value_dense = Dense(embed_dim)
        
        self.combine_heads = Dense(embed_dim)

    @tf.function(jit_compile=True)
    def separate_heads(self, x, batch_size):
        # (batch_size, seq_len, embed_dim) -> (batch_size, num_heads, seq_len, projection_dim)
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.projection_dim))
        return tf.transpose(x, perm=[0, 2, 1, 3])
    
    @tf.function
    def call(self, q, kv, padding_mask=None):
        """Attend ``q`` over ``kv``.

        Returns:
            (output, weights): output is (batch_size, q_len, embed_dim); weights
            are the per-head attention weights (batch_size, num_heads, q_len, kv_len).
        """
        batch_size = tf.shape(q)[0]

        query = self.query_dense(q)  # (batch_size, q_len, embed_dim)
        key =   self.key_dense(kv)    # (batch_size, kv_len, embed_dim)
        value = self.value_dense(kv)  # (batch_size, kv_len, embed_dim)
        
        query = self.separate_heads(query, batch_size)  # (batch_size, num_heads, q_len, projection_dim)
        key =   self.separate_heads(key, batch_size)  # (batch_size, num_heads, kv_len, projection_dim)
        value = self.separate_heads(value, batch_size)  # (batch_size, num_heads, kv_len, projection_dim)
        
        attention_out, weights = attention(query, key, value, padding_mask)
        attention_out = tf.transpose(attention_out, perm=[0, 2, 1, 3])  # (batch_size, q_len, num_heads, projection_dim)
        concat_attention = tf.reshape(attention_out, (batch_size, -1, self.embed_dim))  # (batch_size, q_len, embed_dim)
        output = self.combine_heads(concat_attention)  # (batch_size, q_len, embed_dim)
        return output, weights

class PerceiverTransformerBlock(Layer):
    """One Perceiver step: cross-attention from latents to data, then latent self-attention.

    Computation (pre-norm residuals):
        x += CrossAttn(LN(x), data);  x += FFN(x)          # FFN starts with its own LN
        repeated transformers_per_attend times:
            x += SelfAttn(LN(x));     x += FFN(x)

    NOTE(review): ``num_latents`` is accepted but not used by this block.
    """
    def __init__(self,
                 embed_dim,
                 num_latents = 512,
                 cross_heads = 8,
                 latent_heads = 8,
                 transformers_per_attend = 1):
        super(PerceiverTransformerBlock, self).__init__()
        
        #cross attention block
        self.cross_attn_layernorm_x = LayerNormalization(epsilon=1e-3)
        self.cross_attn = MultiHeadAttention(embed_dim, cross_heads)
        self.cross_ffn = Sequential([LayerNormalization(epsilon=1e-3), Dense(embed_dim, activation=gelu), Dense(embed_dim)])

        #self attention block
        self.transformers_per_attend = transformers_per_attend
        self.self_attention_transformers = [Sequential([LayerNormalization(epsilon=1e-3),
                                                        MultiHeadSelfAttention(embed_dim, latent_heads)])
                                                        for i in range(transformers_per_attend)]
        self.self_attention_transformer_ffns = [Sequential([LayerNormalization(epsilon=1e-3),
                                                            Dense(embed_dim, activation=gelu),
                                                            Dense(embed_dim)]) for i in range(transformers_per_attend)]

    @tf.function
    def call(self, x, data):
        """Update latents ``x`` (batch, num_latents, embed_dim) using ``data``."""
        shortcut1, _ = self.cross_attn(self.cross_attn_layernorm_x(x), data) # args are (q, kv); no mask
        x = x + shortcut1
        shortcut2 = self.cross_ffn(x)
        x = x + shortcut2
        for i in range(self.transformers_per_attend):
            shortcut1 = self.self_attention_transformers[i](x)
            x = x + shortcut1
            
            shortcut2 = self.self_attention_transformer_ffns[i](x)
            x = x + shortcut2
        return x

class Perceiver(Layer):
    """Perceiver encoder: a learned latent array repeatedly cross-attends to the input.

    Args:
        num_layers: total number of Perceiver blocks applied.
        latent_dim: width of each latent vector (and of the attention layers).
        num_latents: number of learned latent vectors.
        cross_heads / latent_heads: heads for the cross- / self-attention.
        weight_tie: if truthy, applications 2..num_layers all reuse one shared
            block; otherwise each gets its own weights. The first block always
            has its own weights.
        self_attn_transformers_per_attend: latent self-attention blocks after
            each cross-attention.

    ``call(data)`` maps (batch, n_inputs, channels) -> (batch, num_latents, latent_dim).
    """
    def __init__(self, num_layers, latent_dim, num_latents=512, cross_heads=8,
                 latent_heads=8, weight_tie=True, self_attn_transformers_per_attend=1):
        super(Perceiver, self).__init__()
        self.context_norm = LayerNormalization(epsilon=1e-3)
        self.num_layers = num_layers
        self.first_layer = PerceiverTransformerBlock(latent_dim,
                                                     num_latents,
                                                     cross_heads,
                                                     latent_heads,
                                                     self_attn_transformers_per_attend)
        # Use truthiness in both __init__ and call. The previous `is False` / `is True`
        # checks disagreed for non-bool values (e.g. 1), which crashed at call time.
        self.weight_tie_layers = bool(weight_tie)
        if not self.weight_tie_layers:
            self.perceiver_layers = [PerceiverTransformerBlock(latent_dim,
                                                               num_latents,
                                                               cross_heads,
                                                               latent_heads,
                                                               self_attn_transformers_per_attend) for i in range(num_layers-1)]
        else:
            self.perceiver_layers = PerceiverTransformerBlock(latent_dim,
                                                              num_latents,
                                                              cross_heads,
                                                              latent_heads,
                                                              self_attn_transformers_per_attend)

        # Learned latent array, initialised N(0, 0.02) and clipped to [-2, 2].
        latents = tf.clip_by_value(tf.random.normal((num_latents, latent_dim), mean=0, stddev=0.02), -2, 2)
        self.latents = tf.Variable(initial_value=latents, trainable=True)

    @tf.function
    def call(self, data):
        batch_size = tf.shape(data)[0]

        # Replicate the latent array for every sample: (batch, num_latents, latent_dim).
        x = tf.tile(self.latents, [batch_size, 1])
        x = tf.reshape(x, tf.concat([[batch_size], tf.shape(self.latents)], axis=0))
        data = self.context_norm(data)

        x = self.first_layer(x, data)
        for i in range(self.num_layers - 1):
            layer = self.perceiver_layers if self.weight_tie_layers else self.perceiver_layers[i]
            x = layer(x, data)
        return x

class DecoderTransformerBlock(Layer):
    """One decoder block: masked self-attention, FFN, cross-attention to the encoder, FFN.

    NOTE(review): the following look unintentional. They are left unchanged because
    fixing them would change the output of already-trained checkpoints:
      * the self-attention residual is added to ``layernorm1(inputs)`` rather
        than to ``inputs``, so the un-normalised input never reaches the output;
      * ``layernorm2`` is created but never called, so cross-attention receives
        ``x`` un-normalised;
      * ``rate``, ``training`` and ``padding_mask`` are accepted but unused
        (cross-attention is called with mask ``None``).
    """
    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1):
        super(DecoderTransformerBlock, self).__init__()
        
        self.layernorm1 = LayerNormalization(epsilon=1e-6)#1e-6
        self.att1 = MultiHeadSelfAttention(embed_dim, num_heads) #(batch_size, seq_len, embed_dim)
        self.ffn1 = Sequential([LayerNormalization(epsilon=1e-6), Dense(ff_dim, activation=gelu), Dense(embed_dim)]) #(batch_size, seq_len, embed_dim)
        
        self.layernorm2 = LayerNormalization(epsilon=1e-6)#1e-6
        self.att2 = MultiHeadAttention(embed_dim, num_heads) #(batch_size, seq_len, embed_dim)
        self.ffn2 = Sequential([LayerNormalization(epsilon=1e-6), Dense(ff_dim, activation=gelu), Dense(embed_dim)])

    @tf.function(jit_compile=True)
    def call(self, inputs, ENCODER_OUTPUT, training, look_ahead_mask, padding_mask):
        """``look_ahead_mask`` masks the self-attention (1 = blocked)."""
        x = self.layernorm1(inputs)
        attn1_output = self.att1(x, look_ahead_mask) #(batch_size, input_seq_len, embed_dim)
        x = x + attn1_output
        shortcut = self.ffn1(x)
        x = x + shortcut
        attn2_output, _ = self.att2(x, ENCODER_OUTPUT, None)  # cross-attention to encoder output, unmasked
        x = x + attn2_output
        shortcut = self.ffn2(x)
        x = x + shortcut
        return x

class Decoder(Layer):
    """Token + learned position embeddings followed by a stack of DecoderTransformerBlocks.

    Args:
        num_layers: number of decoder blocks.
        embed_dim: model width.
        num_heads, ff_dim: per-block attention heads / feed-forward width.
        sequence_length: full caption length. Position embeddings are learned for
            ``sequence_length`` positions.
        target_vocab_size: vocabulary size of the token embedding.
        rate: forwarded to each block (NOTE(review): unused there).
    """
    def __init__(self, num_layers, embed_dim, num_heads, ff_dim, sequence_length, target_vocab_size, rate=0.1):
        super(Decoder, self).__init__()
        self.embed_dim = embed_dim
        self.sequence_length = sequence_length -1  # training input length (captions[:, :-1])
        self.num_layers = num_layers
        self.embedding = Embedding(input_dim=target_vocab_size, output_dim=embed_dim)
        self.pos_emb = Embedding(input_dim=sequence_length, output_dim=embed_dim)
        self.dec_layers = [DecoderTransformerBlock(embed_dim, num_heads, ff_dim, rate) for _ in range(num_layers)]

    @tf.function
    def call(self, x, enc_output, training, look_ahead_mask, padding_mask):
        # Position ids follow the actual input length. Prefixes shorter than
        # self.sequence_length then work too; previously any length other than
        # sequence_length - 1 failed to broadcast. For the training length this
        # is identical to range(self.sequence_length).
        seq_len = tf.shape(x)[1]
        x = self.embedding(x)  # (batch_size, input_seq_len, d_model)
        x += self.pos_emb(tf.range(start=0, limit=seq_len, delta=1))
        for layer in self.dec_layers:
            x = layer(x, enc_output, training, look_ahead_mask, padding_mask)
        return x  # (batch_size, input_seq_len, d_model)
    
class Transformer(tf.keras.Model):
    """Image-captioning model: Perceiver encoder -> transformer decoder -> vocabulary logits.

    ``call(image_sequence, target, training, look_ahead_mask, padding_mask)``:
        image_sequence: Perceiver input (batch, n_pixels, channels), as produced
            by DataGenerator.
        target: decoder input token ids (batch, seq_len).
        Returns float32 logits of shape (batch, seq_len, target_vocab_size).

    NOTE(review): ``dropout_rate`` is forwarded to the decoder, but no layer applies dropout.
    """
    def __init__(self, n_perceiver_layers,
                 n_dec_layers, embed_dim, dec_heads, dec_ff_dim, sequence_length, target_vocab_size,
                 num_latents=512, cross_heads=8, weight_tie_perceiver_layers=True,
                 latent_heads=8, dropout_rate=0.1, self_attn_transformers_per_attend=1):
        super(Transformer, self).__init__()

        self.embed_dim = embed_dim
        self.sequence_length = sequence_length

        self.encoder = Perceiver(n_perceiver_layers,
                                 embed_dim,
                                 num_latents,
                                 cross_heads,
                                 latent_heads,
                                 weight_tie_perceiver_layers,
                                 self_attn_transformers_per_attend)
        
        self.decoder = Decoder(n_dec_layers,
                               embed_dim,
                               dec_heads,
                               dec_ff_dim,
                               sequence_length,
                               target_vocab_size,
                               dropout_rate)
        # float32 output layer so logits/softmax stay in full precision under mixed precision.
        self.final_layer = Dense(target_vocab_size, dtype="float32")
    
    @tf.function
    def call(self, image_sequence, target, training, look_ahead_mask, padding_mask):
        enc_output = self.encoder(image_sequence)  # (batch_size, num_latents, embed_dim)
        decoder_output = self.decoder(target, enc_output, training, look_ahead_mask, padding_mask) # (batch_size, tar_seq_len, embed_dim)
        final_output = self.final_layer(decoder_output)  # (batch_size, tar_seq_len, target_vocab_size)
        return final_output

In [None]:
# Model hyper-parameters; also reused below to build the standalone encoder/decoder.
transformer_params = {
    "n_perceiver_layers": 2,   # total Perceiver blocks applied
    "n_dec_layers": 2,
    "embed_dim": 512,          # shared width of latents, decoder and attention layers
    "dec_heads": 8,
    "dec_ff_dim": 512,
    "sequence_length": 50,     # caption length; the decoder sees sequence_length - 1 tokens
    "target_vocab_size": len(vocab),
    "num_latents": 512,
    "cross_heads": 1,
    "weight_tie_perceiver_layers":True,  # Perceiver blocks after the first share weights
    "latent_heads": 4,
    "dropout_rate": 0,         # NOTE(review): passed down but no layer applies dropout
    "self_attn_transformers_per_attend": 2
    }
Caption_generator = Transformer(**transformer_params)

In [None]:
#to build model before going into tf.function
# One forward pass (no masks) so every layer creates its variables before the
# summary, checkpoint restore and training cells below.
temp1, temp2 = validation_data_generator.__getitem__(0)
_ = Caption_generator(temp1, temp2[:, :-1], False, None, None)

In [None]:
# Print the layer / parameter summary of the built model.
Caption_generator.summary()

In [None]:
@tf.function(jit_compile=True)
def create_padding_mask(batch):
    """Return 1.0 where ``batch`` holds the padding token ``pad_idx``, else 0.0.

    Shape (batch, 1, 1, seq_len) so it broadcasts over heads and query positions.
    Compares against ``pad_idx``, the token ``batch_captions_to_matrix`` pads with
    and ``loss_function`` ignores. The previous version compared against 0, which
    is not the padding index, so it masked a real token and left padding visible.
    """
    mask = tf.cast(tf.math.equal(batch, pad_idx), policy.compute_dtype)
    return mask[:, tf.newaxis, tf.newaxis, :]

@tf.function(jit_compile=True)
def create_look_ahead_mask(batch):
    """Return a (seq_len, seq_len) mask with 1.0 strictly above the diagonal (future positions)."""
    seq_len = tf.shape(batch)[1]
    mask = 1 - tf.linalg.band_part(tf.ones((seq_len, seq_len)), -1, 0)
    return tf.cast(mask, policy.compute_dtype)

@tf.function(jit_compile=True)
def create_masks(batch):
    """Build decoder masks for a batch of input token ids (batch, seq_len).

    Returns:
        combined_mask: (batch, 1, seq_len, seq_len), 1.0 where a query may not
            attend to a key because the key is in the future or is padding.
        padding_mask: (batch, 1, 1, seq_len) padding-only mask.
    """
    padding_mask = create_padding_mask(batch)
    look_ahead_mask = create_look_ahead_mask(batch)
    combined_mask = tf.maximum(padding_mask, look_ahead_mask)
    return combined_mask, padding_mask

In [None]:
# Learning rates tried (author's observations):
#learning_rate = 0.004 #DIVERGES
learning_rate = 0.0004 #GOOD
#learning_rate = 0.00004 #NAN (underflow)
optim = LAMB(learning_rate)
#optimizer = RectifiedAdam(learning_rate=1e-4, total_steps=10000, warmup_proportion=0.1, min_lr=1e-6)
# Wrap in a LossScaleOptimizer for the mixed_float16 policy; train_step scales
# the loss and unscales the gradients through it.
optim = tf.keras.mixed_precision.LossScaleOptimizer(optim)

In [None]:
loss_obj = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction="none")


@tf.function
def loss_function(batch_captions, predictions):
    """Mean cross-entropy over the target tokens that are not padding (``pad_idx``)."""
    per_token = loss_obj(batch_captions, predictions)
    weights = tf.cast(tf.math.not_equal(batch_captions, pad_idx), dtype=per_token.dtype)
    return tf.reduce_sum(per_token * weights) / tf.reduce_sum(weights)

@tf.function(jit_compile=True)
def accuracy_function(batch_captions, predictions):
    """Token accuracy of argmax ``predictions`` over non-padding target positions."""
    predicted_ids = tf.cast(tf.argmax(predictions, axis=2), tf.int32)
    is_token = tf.math.not_equal(batch_captions, pad_idx)
    hits = tf.math.logical_and(is_token, tf.equal(tf.cast(batch_captions, tf.int32), predicted_ids))
    return tf.reduce_sum(tf.cast(hits, dtype=tf.float32)) / tf.reduce_sum(tf.cast(is_token, dtype=tf.float32))

In [None]:
checkpoint_path = "./checkpoints/Perceiver/train"

# Save and restore the model together with the optimizer state.
ckpt = tf.train.Checkpoint(transformer=Caption_generator, optimizer=optim)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=5)  # keep the 5 most recent saves

# Resume from the most recent checkpoint, if one exists.
_latest_checkpoint = ckpt_manager.latest_checkpoint
if _latest_checkpoint:
    ckpt.restore(_latest_checkpoint)
    print('Latest checkpoint restored!!')

In [None]:
# Running means reported by train(): per-batch losses and per-batch token accuracies.
train_loss = tf.keras.metrics.Mean(name="train_loss")
train_accuracy = tf.keras.metrics.Mean(name="train_accuracy")
validation_loss = tf.keras.metrics.Mean(name="validation_loss")
validation_accuracy = tf.keras.metrics.Mean(name="validation_accuracy")

In [None]:
@tf.function
def train_step(image_encodings_sequence, batch_captions):
    """Run one teacher-forced optimisation step and update the training metrics.

    The decoder input is ``captions[:, :-1]`` and the target is ``captions[:, 1:]``.
    Gradients come from the loss-scaled loss (mixed precision). When the loss is
    NaN the whole step is skipped, including the metric update; previously the
    NaN still reached ``train_loss`` and poisoned its running mean.
    """
    batch_captions_input = batch_captions[:, :-1]
    batch_captions_real = batch_captions[:, 1:]

    combined_mask, padding_mask = create_masks(batch_captions_input)

    with tf.GradientTape() as tape:
        predictions = Caption_generator(image_encodings_sequence, batch_captions_input,
                                        True, combined_mask, padding_mask)
        loss = loss_function(batch_captions_real, predictions)
        scaled_loss = optim.get_scaled_loss(loss)

    if not tf.math.is_nan(scaled_loss):
        scaled_gradients = tape.gradient(scaled_loss, Caption_generator.trainable_variables)
        gradients = optim.get_unscaled_gradients(scaled_gradients)
        optim.apply_gradients(zip(gradients, Caption_generator.trainable_variables))
        train_loss(loss)
        train_accuracy(accuracy_function(batch_captions_real, predictions))

In [None]:
@tf.function
def validation_step(image_encodings_sequence, batch_captions):
    """Evaluate one teacher-forced batch and update the validation metrics.

    The model is called with ``training=False``; the previous version passed True.
    No layer currently branches on the flag, so results are unchanged today, but
    evaluation must not run in training mode once e.g. dropout is added.
    """
    batch_captions_input = batch_captions[:, :-1]
    batch_captions_real = batch_captions[:, 1:]

    combined_mask, padding_mask = create_masks(batch_captions_input)

    predictions = Caption_generator(image_encodings_sequence, batch_captions_input,
                                    False, combined_mask, padding_mask)
    loss = loss_function(batch_captions_real, predictions)

    validation_loss(loss)
    validation_accuracy(accuracy_function(batch_captions_real, predictions))

In [None]:
#to build model before going into tf.function
# NOTE(review): repeats the earlier build call. `lha` / `pm` are computed here,
# but the model is still called with None masks.
temp1, temp2 = validation_data_generator.__getitem__(0)
lha, pm = create_masks(temp2[:, :-1])
_ = Caption_generator(temp1, temp2[:, :-1], False, None, None)

In [None]:
def train(epochs):
    """Train for ``epochs`` passes over the training generator.

    Every 100 batches, one validation batch is evaluated, the running metrics are
    printed, and a checkpoint is saved if the running validation loss is the best
    seen during this call. Metrics are reset at the start of each epoch, so the
    printed values describe the current epoch. Previously they were never reset
    and averaged every step since the notebook started.
    """
    train_iterator = iter(train_dataset)
    validation_iterator = iter(val_dataset)
    n_batches = len(train_data_generator)
    best_val_loss = float("inf")  # any finite loss is an improvement on the first check
    for epoch in range(epochs):
        for metric in (train_loss, train_accuracy, validation_loss, validation_accuracy):
            metric.reset_states()

        for batch_number in range(n_batches):
            image_encoding_sequence, captions = next(train_iterator)
            train_step(image_encoding_sequence, captions)

            if batch_number % 100 == 0:
                val_imgs, val_batch_captions = next(validation_iterator)
                validation_step(val_imgs, val_batch_captions)

                print(f"Epoch {epoch} Batch [{batch_number}/{n_batches}]: train Loss: {train_loss.result():.5f}, train accuracy: {train_accuracy.result():.5f}, val Loss: {validation_loss.result():.5f}, val accuracy: {validation_accuracy.result():.5f}")

                if validation_loss.result() < best_val_loss:
                    best_val_loss = validation_loss.result()
                    ckpt_manager.save()

In [None]:
#LAMB lr=0.0004
# First training run: 16 passes over the training generator.
train(16)

In [None]:
#lr changed from 0.0004 to 0.00004
# NOTE(review): no code in this cell changes the learning rate; `optim` keeps
# the value it was constructed with unless that cell was edited and re-run.
train(8)

# Splitting model into encoder and decoder
- I split the model into two parts (the Perceiver encoder and the transformer decoder) so that each image is encoded only once, instead of being re-processed every time a new word is generated.

In [None]:
#same as defined as above but inherits from tf.keras.Model instead of layer

class Perceiver_encoder(tf.keras.Model):
    """Keras ``Model`` wrapper around ``Perceiver`` so the encoder can be saved on its own.

    Takes the same constructor arguments as ``Perceiver``. ``call(data)`` returns
    the wrapped Perceiver's output (the latent array after all blocks).
    """
    def __init__(self, num_layers, latent_dim, num_latents=512, cross_heads=8,
                 latent_heads=8, weight_tie=True, self_attn_transformers_per_attend=1):
        super(Perceiver_encoder, self).__init__()
        self.encoder = Perceiver(num_layers, latent_dim, num_latents, cross_heads,
                                 latent_heads, weight_tie, self_attn_transformers_per_attend)
    @tf.function(experimental_follow_type_hints=True)
    def call(self, data):
        x = self.encoder(data)
        return x

#initialize_model
Standalone_Perceiver_encoder = Perceiver_encoder(num_layers=transformer_params["n_perceiver_layers"],
                              latent_dim=transformer_params["embed_dim"],
                              num_latents=transformer_params["num_latents"],
                              cross_heads=transformer_params["cross_heads"],
                              latent_heads=transformer_params["latent_heads"],
                              weight_tie=transformer_params["weight_tie_perceiver_layers"],
                              self_attn_transformers_per_attend=transformer_params["self_attn_transformers_per_attend"])

# Build the model once so its variables exist before copying weights in.
# temp1 is then reused as the encoder output when building the decoder.
temp1, temp2 = validation_data_generator.__getitem__(0)
temp1 = Standalone_Perceiver_encoder(temp1)
# Copy the trained encoder weights. The latent array is a tf.Variable attribute
# of Perceiver, which Keras tracks as a layer weight, so set_weights already
# copies it. The assign makes that explicit. The previous
# `Standalone_Perceiver_encoder.latents = ...` instead attached the training
# model's variable to the wrapper, where the inner Perceiver never reads it.
Standalone_Perceiver_encoder.encoder.set_weights(Caption_generator.encoder.get_weights())
Standalone_Perceiver_encoder.encoder.latents.assign(Caption_generator.encoder.latents)

In [None]:
# Sanity check: the first weight tensor of the copy and of the trained encoder should match in shape.
print(Standalone_Perceiver_encoder.get_weights()[0].shape)
print(Caption_generator.layers[0].get_weights()[0].shape)

In [None]:
# Export the standalone encoder with Keras Model.save.
Standalone_Perceiver_encoder.save("final_model/Perceiver_encoder")

In [None]:
# Also export it with the low-level tf.saved_model API, to a separate directory.
tf.saved_model.save(Standalone_Perceiver_encoder, "final_model/saved_model_rest/Perceiver_encoder")

In [None]:
#define decoder structure
class Caption_generator_decoder(tf.keras.Model):
    """Standalone decoder half of the captioning model (Decoder + float32 output layer).

    ``call(inputs)`` takes one sequence
    ``(enc_output, target, training, look_ahead_mask, padding_mask)`` and returns
    logits of shape (batch, tar_seq_len, target_vocab_size).
    """
    def __init__(self, n_dec_layers, embed_dim, dec_heads, dec_ff_dim, sequence_length, target_vocab_size, rate=0.1):
        super(Caption_generator_decoder, self).__init__()
        self.decoder = Decoder(n_dec_layers,
                               embed_dim,
                               dec_heads,
                               dec_ff_dim,
                               sequence_length,
                               target_vocab_size,
                               rate)
        self.final_layer = Dense(target_vocab_size, dtype="float32")
    
    @tf.function(experimental_follow_type_hints=True)
    def call(self, inputs):
        enc_output, target, training, look_ahead_mask, padding_mask = inputs
        decoder_output = self.decoder(target, enc_output, training, look_ahead_mask, padding_mask) # (batch_size, tar_seq_len, d_model)
        final_output = self.final_layer(decoder_output)  # (batch_size, tar_seq_len, target_vocab_size)
        return final_output
        
# Instantiate the standalone decoder from transformer_params. Every keyword
# argument shares its name with its transformer_params key, except `rate`,
# which is read from "dropout_rate".
Standalone_Transformer_decoder = Caption_generator_decoder(
    **{
        hyperparam: transformer_params[hyperparam]
        for hyperparam in (
            "n_dec_layers",
            "embed_dim",
            "dec_heads",
            "dec_ff_dim",
            "sequence_length",
            "target_vocab_size",
        )
    },
    rate=transformer_params["dropout_rate"],
)


# Call the standalone decoder once on a real batch so it is built (its
# variables exist) before trained weights are copied in. temp1 is the
# encoder output; temp2[:, :-1] is passed as the decoder `target`.
lha, pm = create_masks(temp2[:,:-1])  # look-ahead mask, padding mask
_ = Standalone_Transformer_decoder([temp1, temp2[:, :-1], False, lha, pm])
# Copy trained weights: standalone layers[0] and layers[1] receive the weights
# of Caption_generator.layers[1] and layers[2] respectively.
# NOTE(review): assumes standalone layers[0]/[1] are `decoder`/`final_layer`
# and that the trained model's layers[1]/[2] are its decoder and output
# projection — confirm the layer order.
Standalone_Transformer_decoder.layers[0].set_weights(Caption_generator.layers[1].get_weights())
Standalone_Transformer_decoder.layers[1].set_weights(Caption_generator.layers[2].get_weights())

In [None]:
# Export the standalone decoder with an explicit serving signature. `call`
# takes one packed list: [enc_output, target, training, look_ahead_mask,
# padding_mask].
# `training` is a single boolean flag, so its spec is a scalar (shape []).
# A rank-1 spec would let callers pass a vector, and a conditional on a
# non-scalar, non-empty tensor evaluates as True, which could silently turn
# dropout on at inference time.
# NOTE(review): the fixed sizes (encoder output 512 x 512, 49 target tokens)
# and the float16 mask dtype must match the trained model and the output of
# create_masks under the mixed_float16 policy — confirm if hyperparameters change.
Standalone_Transformer_decoder.save("final_model/Transformer_decoder", signatures=Standalone_Transformer_decoder.call.get_concrete_function(
    [
        tf.TensorSpec(shape=[None, 512, 512],  dtype=tf.float32, name="enc_output"),
        tf.TensorSpec(shape=[None, 49],        dtype=tf.int32,   name="target"),
        tf.TensorSpec(shape=[],                dtype=tf.bool,    name="training"),
        tf.TensorSpec(shape=[None, 1, 49, 49], dtype=tf.float16, name="look_ahead_mask"),
        tf.TensorSpec(shape=[None, 1, 49, 49], dtype=tf.float16, name="padding_mask")
    ]))

# Caption Generation

In [None]:
START = "#START#"
END = "#END#"
PAD = "#PAD#"
UNK = "#UNK#"
PUNCTUATION_EXCLAMATION = "PUNCTUATION_EXCLAMATION"
PUNCTUATION_PERIOD = "PUNCTUATION_PERIOD"
PUNCTUATION_QUESTION_MARK = "PUNCTUATION_QUESTION_MARK"
EXCLAMATION = "EXCLAMATION"
def indices_to_caption(caption, max_repeats=2):
    """Convert a sequence of predicted token indices into a caption string.

    Tokens are processed left to right:

    - once a token has appeared ``max_repeats`` times in a row, further
      consecutive copies are dropped;
    - START and PAD tokens are skipped;
    - the first END token stops decoding;
    - UNK is rendered as ``"___"`` and EXCLAMATION as ``"!"``;
    - every other index is looked up in the module-level ``vocab_inverse``.

    Every emitted word is followed by one space, so a non-empty result ends
    with a trailing space.

    Args:
        caption: iterable of token indices (e.g. one row of a numpy array).
        max_repeats: maximum number of consecutive identical tokens kept
            (default 2, the original fixed limit).

    Returns:
        The decoded caption string.
    """
    start_token = vocab[START]
    end_token = vocab[END]
    pad_token = vocab[PAD]
    # Special tokens rendered as a fixed string instead of via vocab_inverse.
    fixed_renderings = {vocab[UNK]: "___", vocab[EXCLAMATION]: "!"}

    words = []
    prev_word = -1
    run_length = 0
    for word_ind in caption:
        # Length of the current run of identical tokens, including this one.
        if word_ind == prev_word:
            run_length += 1
        else:
            prev_word = word_ind
            run_length = 1
        if run_length > max_repeats:
            continue
        if word_ind == start_token:
            continue
        if word_ind == end_token:
            break
        if word_ind == pad_token:
            continue
        if word_ind in fixed_renderings:
            words.append(fixed_renderings[word_ind])
        else:
            words.append(vocab_inverse[word_ind])
    # join is linear; repeated `+=` on a string is quadratic in the worst case.
    return "".join(word + " " for word in words)

In [None]:
# Enable TensorFlow's NumPy-compatibility mode for the rest of the process:
# tf.Tensor objects gain NumPy-style methods/properties such as `.T` and
# `.reshape`. This is a global switch that affects all later cells.
# NOTE(review): `tensorflow.python...` is a private import path; newer TF
# releases expose tf.experimental.numpy.experimental_enable_numpy_behavior —
# confirm against the installed TF version before switching.
from tensorflow.python.ops.numpy_ops import np_config
np_config.enable_numpy_behavior()

In [None]:
def _sample_top_k(step_logits, top_k):
    '''Sample one token index per row from the ``top_k`` highest logits.

    Args:
        step_logits: 2-D array of decoder logits for one position,
            shape (batch, vocab_size).
        top_k: number of candidate tokens kept per row.

    Returns:
        1-D numpy array with one sampled token index per row.
    '''
    top_logits, top_indices = tf.math.top_k(step_logits, top_k)
    # tf.random.categorical expects unnormalized log-probabilities, so the raw
    # top-k logits are passed directly. Applying softmax first would feed
    # probabilities in [0, 1] as logits and flatten sampling towards uniform.
    chosen = tf.random.categorical(top_logits, 1).numpy()  # (batch, 1) positions within the top-k
    return np.take_along_axis(top_indices.numpy(), chosen, axis=1)[:, 0]


def generate_caption(encoded_image, top_k=5):
    '''Generate captions for encoded images by top-k sampling from the decoder.

    ``encoded_image`` must have a leading batch dimension. A leading dimension
    of exactly 1 is treated as a single image: it is tiled to BATCH_SIZE
    copies, and only the caption of the first copy is returned.

    Decoding starts from START followed by PAD tokens. At each step the decoder
    is rerun on the whole sequence, the UNK token is masked out, and the next
    token is sampled from the ``top_k`` highest-scoring tokens in proportion to
    their softmax probability.

    Args:
        encoded_image: array of encoded image features, batch dimension first.
            Single images must be in the shape [1, encoded_image_size].
        top_k: number of candidate tokens sampled from at each step
            (default 5, the original fixed value).

    Returns:
        A caption string for a single image, or a list of caption strings
        (one per row) for a batch.
    '''
    # Only a leading batch dimension of 1 marks a single image; comparing the
    # whole shape tuple to 1 is never true.
    is_batch = encoded_image.shape[0] != 1
    if not is_batch:
        # Tile along the existing batch axis -> (BATCH_SIZE, ...); np.stack
        # would instead add a new axis -> (BATCH_SIZE, 1, ...).
        encoded_image = np.repeat(encoded_image, BATCH_SIZE, axis=0)
    n_rows = encoded_image.shape[0]

    # Decoder input: START then PAD, MAX_LEN - 1 tokens long (a full MAX_LEN
    # row with its last column dropped).
    captions = np.full((n_rows, MAX_LEN - 1), pad_idx)
    captions[:, 0] = vocab[START]
    unk_token = vocab[UNK]

    encoder_out = Standalone_Perceiver_encoder(encoded_image)
    # The logits at position i choose token i + 1, so positions
    # 1 .. captions.shape[1] - 1 are filled one per step.
    for i in range(captions.shape[1] - 1):
        look_ahead_mask, padding_mask = create_masks(captions)
        preds = Standalone_Transformer_decoder(
            [encoder_out, captions, False, look_ahead_mask, padding_mask]).numpy()
        step_logits = preds[:, i]
        # UNK tokens are masked out for caption generation
        step_logits[:, unk_token] = -3e4
        captions[:, i + 1] = _sample_top_k(step_logits, top_k)

    if not is_batch:
        return indices_to_caption(captions[0])
    return [indices_to_caption(caption) for caption in captions]

In [None]:
# Demo: caption the first validation batch. The variable is named `val_im`,
# and the validation generator is the one used to build the models above.
val_im, _ = validation_data_generator[0]
generate_caption(val_im)