##### Copyright 2018 The TensorFlow Authors.


In [None]:
!pip install einops

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting einops
  Downloading einops-0.6.1-py3-none-any.whl (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.6.1


This tutorial uses lots of imports, mostly for loading the dataset(s).

In [None]:
#@title
import concurrent.futures
import collections
import dataclasses
import hashlib
import itertools
import json
import math
import os
import pathlib
import random
import re
import string
import time
import urllib.request

import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import requests
import tqdm


import torch
import torch.nn.functional as F
import torch.utils.data as data
import torchvision.transforms as transforms
# import transformers
# import datasets

#### Flickr8k

In [None]:
import os
import pathlib
import collections
from urllib.request import urlretrieve
import zipfile
from torchvision.datasets.utils import download_and_extract_archive

def flickr8k(path='flickr8k', min_files=16197):
    """Download (if needed) and load the Flickr8k captioning dataset.

    Args:
        path: Directory the archives are extracted into.
        min_files: If `path` contains fewer than this many entries (searched
            recursively), both archives are downloaded and extracted. The
            default is presumably the entry count of a complete extraction.

    Returns:
        `(train_captions, test_captions)`: lists of `(image_path, captions)`
        pairs. `image_path` is a string and `captions` is the list of every
        caption listed for that image in `Flickr8k.token.txt` (empty if none).
    """
    path = pathlib.Path(path)

    if len(list(path.rglob('*'))) < min_files:
        for url in ('https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',
                    'https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip'):
            download_and_extract_archive(url=url, download_root='.', extract_root=path)

    # Each line is "<image>#<n>\t<caption>"; group captions by image file name.
    cap_dict = collections.defaultdict(list)
    for line in (path/'Flickr8k.token.txt').read_text(encoding='utf-8').splitlines():
        if not line.strip():
            continue  # tolerate blank lines instead of failing to unpack
        image_id, caption = line.split('\t', 1)
        cap_dict[image_id.split('#')[0]].append(caption)

    def _pairs(split_file):
        # Pair every image listed in `split_file` with its captions.
        fnames = (path/split_file).read_text(encoding='utf-8').splitlines()
        return [(str(path/'Flicker8k_Dataset'/fname), cap_dict.get(fname, []))
                for fname in fnames if fname.strip()]

    train_captions = _pairs('Flickr_8k.trainImages.txt')
    test_captions = _pairs('Flickr_8k.testImages.txt')
    return train_captions, test_captions

#### Download the dataset

Flickr8k is a good choice because it contains 5 captions per image, which gives more data for a smaller download.

In [None]:
choose = 'flickr8k'

# Only the Flickr8k loader is defined in this notebook. Fail loudly rather than
# leave `train_raw` / `test_raw` undefined for the cells below.
if choose == 'flickr8k':
  train_raw, test_raw = flickr8k()
else:
  raise ValueError(f'Unknown dataset: {choose!r}')


Downloading https://objects.githubusercontent.com/github-production-release-asset-2e65be/124585957/47f52b80-3501-11e9-8f49-4515a2a3339b?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20230508%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20230508T032652Z&X-Amz-Expires=300&X-Amz-Signature=97d1b10a77513c9447d12f2ffdd901c2129ff084141e3e660fe8d99222175b4d&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=124585957&response-content-disposition=attachment%3B%20filename%3DFlickr8k_Dataset.zip&response-content-type=application%2Foctet-stream to ./Flickr8k_Dataset.zip


100%|██████████| 1115419746/1115419746 [00:21<00:00, 52561305.18it/s]


Extracting ./Flickr8k_Dataset.zip to flickr8k
Downloading https://objects.githubusercontent.com/github-production-release-asset-2e65be/124585957/47f52b80-3501-11e9-8d2e-dd69a21a4362?X-Amz-Algorithm=AWS4-HMAC-SHA256&X-Amz-Credential=AKIAIWNJYAX4CSVEH53A%2F20230508%2Fus-east-1%2Fs3%2Faws4_request&X-Amz-Date=20230508T032721Z&X-Amz-Expires=300&X-Amz-Signature=273560cfe9cc953a177c8d8cc1dcfb49a184f777e122fb45d1271ea84baa6098&X-Amz-SignedHeaders=host&actor_id=0&key_id=0&repo_id=124585957&response-content-disposition=attachment%3B%20filename%3DFlickr8k_text.zip&response-content-type=application%2Foctet-stream to ./Flickr8k_text.zip


100%|██████████| 2340801/2340801 [00:00<00:00, 93828542.19it/s]

Extracting ./Flickr8k_text.zip to flickr8k





The `flickr8k` loader above returns two Python lists (train and test) of `(image_path, captions)` pairs. Flickr8k contains 5 captions per image (the original tutorial also offered Conceptual Captions, which has 1 caption per image, but only Flickr8k is loaded here):

In [None]:
# Peek at the first training pair: an image path and its list of captions.
for ex_path, ex_captions in train_raw[:1]:
  print(ex_path, ex_captions, sep='\n')

flickr8k/Flicker8k_Dataset/2513260012_03d33305cf.jpg
['A black dog is running after a white dog in the snow .', 'Black dog chasing brown dog through snow', 'Two dogs chase each other across the snowy ground .', 'Two dogs play together in the snow .', 'Two dogs running through a low lying body of water .']


### Image feature extractor

You will use an image model (pretrained on ImageNet) to extract the features from each image. The model was trained as an image classifier, but keeping only its convolutional trunk (`mobilenet.features`) drops the final pooling and classification layers, so you can use the last layer of feature maps:


In [None]:
import torch
import torchvision.models as models
from torchvision.transforms import Resize, Compose, ToTensor, Normalize

IMAGE_SHAPE = (224, 224, 3)

# Prefer the torchvision>=0.13 `weights=` API; fall back to the deprecated
# `pretrained=True` flag on older releases.
try:
    _mobilenet_weights = models.MobileNet_V3_Small_Weights.DEFAULT
    mobilenet = models.mobilenet_v3_small(weights=_mobilenet_weights)
except AttributeError:
    mobilenet = models.mobilenet_v3_small(pretrained=True)

# Keep only the convolutional trunk so the output is a feature map, not logits.
mobilenet = mobilenet.features
mobilenet.eval()
# Freeze the pretrained extractor. It is embedded in the captioner later, and
# without this its weights would be updated by the captioning optimizer.
# NOTE(review): a parent `model.train()` call switches it back to train mode.
mobilenet.requires_grad_(False)

normalize = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
preprocessing = Compose([Resize(IMAGE_SHAPE[:2]), ToTensor(), normalize])

Downloading: "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v3_small-047dcff4.pth
100%|██████████| 9.83M/9.83M [00:00<00:00, 152MB/s]


Here's a function to load an image and resize it for the model:

In [None]:
from PIL import Image
import torchvision.transforms as transforms

IMAGE_SHAPE = (224, 224, 3)

def load_image(image_path):
    """Load an image file as an RGB PIL image resized to `IMAGE_SHAPE[:2]`.

    Converting to RGB guarantees three channels; grayscale/RGBA/palette files
    would otherwise break the 3-channel `Normalize` in `preprocessing`. The
    `with` block closes the file handle once the pixels have been copied.
    """
    with Image.open(image_path) as img:
        img = img.convert('RGB')
    return transforms.Resize(IMAGE_SHAPE[:-1])(img)


The model returns a feature map for each image in the input batch:

In [None]:
ex_path = 'flickr8k/Flicker8k_Dataset/2513260012_03d33305cf.jpg'  # Replace this with the path to your image

test_img = load_image(ex_path)
test_img_batch = preprocessing(test_img).unsqueeze(0)  # add a batch axis

print(test_img_batch.shape)
# Inference only: don't build an autograd graph for the feature extractor.
with torch.no_grad():
    print(mobilenet(test_img_batch).shape)

torch.Size([1, 3, 224, 224])
torch.Size([1, 576, 7, 7])


### Setup the text tokenizer/vectorizer

You will transform the text captions into integer sequences using the [TextVectorization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer, with the following steps:

* Use [adapt](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization#adapt) to iterate over all captions, split the captions into words, and compute a vocabulary of the top words.
* Tokenize all captions by mapping each word to its index in the vocabulary. All output sequences will be padded to length 50.
* Create word-to-index and index-to-word mappings to display results.

In [None]:
import re
import string

import torch
from torchtext.vocab import build_vocab_from_iterator

# Captions are padded / truncated to this many tokens.
MAX_LENGTH = 50
# At most this many vocabulary entries (special tokens included).
VOCAB_SIZE = 5000
# Index 0 is the padding token '' so zero-padding lines up with `padding_idx=0`;
# '[START]' / '[END]' delimit every caption.
SPECIAL_TOKENS = ['', '[UNK]', '[START]', '[END]']

_PUNCTUATION_RE = re.compile(f'[{re.escape(string.punctuation)}]')


def tokenizer(caption):
    """Lower-case a caption, strip punctuation, split on whitespace, add [START]/[END]."""
    caption = _PUNCTUATION_RE.sub('', caption.lower())
    return ['[START]', *caption.split(), '[END]']


def tokenize_caption(dataset):
    """Yield the token list of every caption in `(image_path, captions)` pairs."""
    for _, captions in dataset:
        for caption in captions:
            yield tokenizer(caption)


# Build the vocabulary from the training captions only. Words seen fewer than
# twice fall back to '[UNK]'.
vocab = build_vocab_from_iterator(
    tokenize_caption(train_raw),
    min_freq=2,
    specials=SPECIAL_TOKENS,
    special_first=True,
    max_tokens=VOCAB_SIZE)
vocab.set_default_index(vocab['[UNK]'])


def text_vectorizer(captions, max_length=MAX_LENGTH):
    """Map a list of caption strings to a zero-padded `(batch, max_length)` LongTensor."""
    ids = torch.zeros(len(captions), max_length, dtype=torch.long)
    for row, caption in enumerate(captions):
        token_ids = vocab(tokenizer(caption))[:max_length]
        ids[row, :len(token_ids)] = torch.tensor(token_ids, dtype=torch.long)
    return ids


# Create word-to-index and index-to-word mappings to display results
vocab_list = vocab.get_itos()
word_to_index = {word: index for index, word in enumerate(vocab_list)}
index_to_word = {index: word for index, word in enumerate(vocab_list)}


AttributeError: ignored

In [None]:
# NOTE(review): this and the other Keras/tf.data cells need
# `import tensorflow as tf`, which this notebook never imports.

def standardize(s):
  """Lower-case, strip punctuation and wrap each caption in [START]/[END]."""
  s = tf.strings.lower(s)
  s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')
  s = tf.strings.join(['[START]', s, '[END]'], separator=' ')
  return s

# Use the top 5000 words for a vocabulary.
vocabulary_size = 5000
tokenizer = tf.keras.layers.TextVectorization(
    max_tokens=vocabulary_size,
    standardize=standardize,
    ragged=True)
# The vocabulary is learned from the caption data with `tokenizer.adapt` (next cell).

In [None]:
# `train_raw` is a Python list of (image_path, [captions]) pairs, not a
# tf.data.Dataset, so flatten the captions before adapting.
train_captions_flat = [caption for _, captions in train_raw for caption in captions]
tokenizer.adapt(tf.data.Dataset.from_tensor_slices(train_captions_flat).batch(1024))

In [None]:
# The first ten vocabulary entries.
vocab_head = tokenizer.get_vocabulary()[:10]
vocab_head

In [None]:
example_batch = [['a cat in a hat'], ['a robot dog']]
t = tokenizer(example_batch)  # token IDs for two example captions
t

In [None]:
# Create mappings for words to indices and indices to words.
word_to_index = vocab.lookup_indices
index_to_word = vocab.lookup_token

# Example usage
example_words = ['[START]', 'a', 'cat']
example_indices = [word_to_index(word) for word in example_words]
print("Words to indices:", example_indices)

example_indices = [0, 1, 2]
example_words = [index_to_word(index) for index in example_indices]
print("Indices to words:", example_words)

In [None]:
w = index_to_word(t)  # same nesting as `t`, with words instead of IDs
w.to_list()

In [None]:
# Join each row's tokens back into one space-separated caption.
tf.strings.reduce_join(w, axis=-1, separator=' ').numpy()

### Prepare the datasets

The `train_raw` and `test_raw` datasets contain 1:many `(image, captions)` pairs. 

This function will replicate the image so there are 1:1 images to captions:

In [None]:
def match_shapes(images, captions):
  """Repeat each image once per caption so images and captions pair 1:1.

  `captions` must be 2-D `(b, c)`; `images` has the batch as its leading axis.
  Returns `(images, captions)`, both with a leading axis of size `b * c`.
  """
  n_captions = einops.parse_shape(captions, 'b c')['c']
  flat_captions = einops.rearrange(captions, 'b c -> (b c)')
  repeated_images = einops.repeat(images, 'b ... -> (b c) ...', c=n_captions)
  return repeated_images, flat_captions

In [None]:
# `train_raw` is a Python list of (image_path, [captions]) pairs; wrap it in a
# tf.data.Dataset so it can be batched. This assumes every image has the same
# number of captions, so they form a dense (batch, captions) tensor.
raw_train_ds = tf.data.Dataset.from_tensor_slices(
    ([path for path, _ in train_raw], [caps for _, caps in train_raw]))
ex_paths, ex_captions = next(iter(raw_train_ds.batch(32)))

print('image paths:', ex_paths.shape)
print('captions:', ex_captions.shape)
print()

ex_paths, ex_captions = match_shapes(images=ex_paths, captions=ex_captions)

print('image_paths:', ex_paths.shape)
print('captions:', ex_captions.shape)


To be compatible with keras training the dataset should contain `(inputs, labels)` pairs. For text generation the tokens are both an input and the labels, shifted by one step. This function will convert an `(images, texts)` pair to an `((images, input_tokens), label_tokens)` pair:

In [None]:
def prepare_txt(imgs, txts):
  """Tokenize captions and split them into shifted (input, label) sequences.

  Returns `((imgs, input_tokens), label_tokens)`. The label at position i is
  the input token at position i + 1.
  """
  tokens = tokenizer(txts)
  return (imgs, tokens[..., :-1]), tokens[..., 1:]

This function adds operations to a dataset. The steps are:

1. Load the images (and ignore images that fail to load).
2. Replicate images to match the number of captions.
3. Shuffle and rebatch the `image, caption` pairs.
4. Tokenize the text, shift the tokens and add `label_tokens`.
5. Convert the text from a `RaggedTensor` representation to padded dense `Tensor` representation.

In [None]:
def prepare_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):
  """Build a pipeline yielding `((images, input_tokens), label_tokens)` batches.

  Args:
    ds: a `tf.data.Dataset` of `(image_path, captions)` pairs, or a Python list
      of such pairs (as returned by `flickr8k`). A list is assumed to have the
      same number of captions per image.
    tokenizer: callable mapping a batch of caption strings to ragged token IDs.
      This argument is the tokenizer used, not a global.
    batch_size: examples per output batch.
    shuffle_buffer: shuffle-buffer size after images are replicated per caption.
  """
  if isinstance(ds, (list, tuple)):
    ds = tf.data.Dataset.from_tensor_slices(
        ([path for path, _ in ds], [caps for _, caps in ds]))

  def _load_image(path):
    # The PIL-based `load_image` can't run inside `Dataset.map` (the function
    # is traced into a graph), so decode and resize with TF ops instead.
    img = tf.io.read_file(path)
    img = tf.io.decode_jpeg(img, channels=3)
    return tf.image.resize(img, IMAGE_SHAPE[:-1])

  def _tokenize(imgs, txts):
    # Shift by one token: inputs drop the last token, labels drop the first.
    tokens = tokenizer(txts)
    return (imgs, tokens[..., :-1]), tokens[..., 1:]

  def _to_tensor(inputs, labels):
    # Ragged -> zero-padded dense tensors.
    (images, in_tok), out_tok = inputs, labels
    return (images, in_tok.to_tensor()), out_tok.to_tensor()

  # Load the images (dropping files that fail to load) and make batches.
  ds = (ds
        .shuffle(10000)
        .map(lambda path, caption: (_load_image(path), caption))
        .apply(tf.data.experimental.ignore_errors())
        .batch(batch_size))

  return (ds
          .map(match_shapes, tf.data.AUTOTUNE)
          .unbatch()
          .shuffle(shuffle_buffer)
          .batch(batch_size)
          .map(_tokenize, tf.data.AUTOTUNE)
          .map(_to_tensor, tf.data.AUTOTUNE))

You could install the feature extractor in your model and train on the datasets like this:

In [None]:
train_ds = prepare_dataset(ds=train_raw, tokenizer=tokenizer)
train_ds.element_spec  # ((images, input_tokens), label_tokens)

In [None]:
test_ds = prepare_dataset(ds=test_raw, tokenizer=tokenizer)
test_ds.element_spec  # ((images, input_tokens), label_tokens)



## Data ready for training



The dataset now returns `(input, label)` pairs suitable for training with Keras. The `inputs` are `(images, input_tokens)` pairs. The `images` are the loaded, resized images; the feature extractor is applied inside the model. At each location in the `input_tokens`, the model looks at the text so far and tries to predict the next token, which is lined up at the same location in the `labels`.

In [None]:
# Grab one batch and unpack it into images, input tokens and labels.
inputs, ex_labels = next(iter(train_ds.take(1)))
(ex_img, ex_in_tok) = inputs

print(ex_img.shape, ex_in_tok.shape, ex_labels.shape, sep='\n')

The input tokens and the labels are the same, just shifted by 1 step:

In [None]:
# The label row is the input row shifted left by one token.
print(ex_in_tok[0].numpy(), ex_labels[0].numpy(), sep='\n')

## A Transformer decoder model

The model will be implemented in three main parts: 

1. Input - The token embedding and positional encoding (`SeqEmbedding`).
1. Decoder - A stack of transformer decoder layers (`DecoderLayer`) where each contains:
   1. A causal self-attention layer (`CausalSelfAttention`), where each output location can attend to the output so far.
   1. A cross attention layer (`CrossAttention`) where each output location can attend to the input image.
   1. A feed forward network (`FeedForward`) layer which further processes each output location independently.
1. Output - A multiclass-classification over the output vocabulary.


### Input

The input text has already been split up into tokens and converted to sequences of IDs. 

Remember that unlike a CNN or RNN the Transformer's attention layers are invariant to the order of the sequence. Without some positional input, it just sees an unordered set not a sequence. So in addition to a simple vector embedding for each token ID, the embedding layer will also include an embedding for each position in the sequence.

The `SeqEmbedding` layer defined below:

- It looks up the embedding vector for each token.
- It looks up an embedding vector for each sequence location.
- It adds the two together.
- It uses `padding_idx=0`, so the padding token (ID 0) gets an embedding that starts at zero and is not updated during training.

Note: This implementation learns the position embeddings instead of using fixed embeddings like in the [Transformer tutorial](https://www.tensorflow.org/text/tutorials/transformer). Learning the embeddings is slightly less code, but doesn't generalize to longer sequences.

In [None]:
import torch
import torch.nn as nn

class SeqEmbedding(nn.Module):
    """Token embedding plus a learned absolute position embedding.

    Args:
        vocab_size: number of token IDs.
        max_length: number of learned positions; sequences must not be longer.
        depth: embedding width.

    `forward` maps integer IDs `(batch, seq)` to `(batch, seq, depth)`. Token
    ID 0 is the padding index: its embedding starts at zero and receives no
    gradient updates.
    """

    def __init__(self, vocab_size, max_length, depth):
        super().__init__()
        # Creation order kept fixed so seeded initialisation is reproducible.
        self.pos_embedding = nn.Embedding(max_length, depth)
        self.token_embedding = nn.Embedding(vocab_size, depth, padding_idx=0)

    def forward(self, seq):
        positions = torch.arange(seq.shape[1], device=seq.device)
        # The (1, seq, depth) position term broadcasts over the batch axis.
        return self.token_embedding(seq) + self.pos_embedding(positions)[None]


### Decoder

The decoder is a standard Transformer decoder: it contains a stack of `DecoderLayers`, each of which contains three sublayers: a `CausalSelfAttention`, a `CrossAttention`, and a `FeedForward`. The implementations are almost identical to the [Transformer tutorial](https://www.tensorflow.org/text/tutorials/transformer); refer to it for more details.

The `CausalSelfAttention` layer is below:

In [None]:
import torch
import torch.nn as nn
from torch.nn.modules.transformer import MultiheadAttention

class CausalSelfAttention(nn.Module):
    """Masked (causal) self-attention, then a residual add and LayerNorm.

    Input and output are batch-major `(batch, seq, embed_dim)`. Extra keyword
    arguments (e.g. `dropout`) go to `MultiheadAttention`, which is used in its
    default sequence-major layout, so don't pass `batch_first=True`.
    """

    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__()
        self.mha = MultiheadAttention(embed_dim, num_heads, **kwargs)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        seq_len = x.shape[1]
        # True strictly above the diagonal: position i may not attend to j > i.
        future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
        seq_major = x.transpose(0, 1)  # (seq, batch, embed_dim)
        attended, _ = self.mha(query=seq_major, key=seq_major, value=seq_major,
                               attn_mask=future)
        return self.layernorm(x + attended.transpose(0, 1))


The `CrossAttention` layer is below. Note that it keeps the attention weights returned by `MultiheadAttention` in `last_attention_scores`, for the visualizations later.

In [None]:
import torch
import torch.nn as nn
from torch.nn.modules.transformer import MultiheadAttention

class CrossAttention(nn.Module):
    """Attention from the text (queries) to the image features (keys/values).

    Followed by a residual add and LayerNorm. Inputs and output are
    batch-major: `x` is `(batch, target_seq, embed_dim)` and `y` is
    `(batch, source_seq, embed_dim)`. Extra keyword arguments (e.g. `dropout`)
    go to `MultiheadAttention`, which is used in its default sequence-major
    layout, so don't pass `batch_first=True`.

    After each call, `last_attention_scores` holds the per-head attention
    weights, detached from the autograd graph, shaped
    `(batch, heads, target_seq, source_seq)`.
    """

    def __init__(self, embed_dim, num_heads, **kwargs):
        super().__init__()
        self.mha = MultiheadAttention(embed_dim, num_heads, **kwargs)
        self.layernorm = nn.LayerNorm(embed_dim)

    def forward(self, x, y, **kwargs):
        # `**kwargs` is accepted for call compatibility and ignored.
        query = x.transpose(0, 1)      # (target_seq, batch, embed_dim)
        key_value = y.transpose(0, 1)  # (source_seq, batch, embed_dim)

        # average_attn_weights=False keeps one map per head; the attention
        # plots reduce over an explicit `heads` axis.
        attn_output, attention_scores = self.mha(
            query=query, key=key_value, value=key_value,
            need_weights=True, average_attn_weights=False)

        # Detach so the cached scores don't keep the autograd graph alive.
        self.last_attention_scores = attention_scores.detach()

        return self.layernorm(x + attn_output.transpose(0, 1))


The `FeedForward` layer is below. Remember that an `nn.Linear` layer is applied to the last axis of the input. The input will have a shape of `(batch, sequence, channels)`, so it automatically applies pointwise across the `batch` and `sequence` axes.

In [None]:
import torch
import torch.nn as nn

class FeedForward(nn.Module):
    """Position-wise MLP (units -> 2*units -> units), then a residual add and LayerNorm.

    `nn.Linear` acts on the last axis, so a `(batch, seq, units)` input is
    transformed independently at every position.
    """

    def __init__(self, units, dropout_rate=0.1):
        super().__init__()
        hidden = 2 * units
        self.seq = nn.Sequential(
            nn.Linear(units, hidden),
            nn.ReLU(),
            nn.Linear(hidden, units),
            nn.Dropout(dropout_rate),
        )
        self.layernorm = nn.LayerNorm(units)

    def forward(self, x):
        residual = self.seq(x)
        return self.layernorm(x + residual)


Next arrange these three layers into a larger `DecoderLayer`. Each decoder layer applies the three smaller layers in sequence. After each sublayer the shape of `out_seq` is `(batch, sequence, channels)`. The decoder layer also returns the `attention_scores` for later visualizations.

In [None]:
import torch
import torch.nn as nn

class DecoderLayer(nn.Module):
    """One decoder block: causal self-attention -> cross-attention -> MLP.

    `forward` takes a pair `(in_seq, out_seq)`: the image features and the
    text embeddings. It returns the updated text sequence and stores the
    cross-attention weights in `last_attention_scores`.
    """

    def __init__(self, units, num_heads=1, dropout_rate=0.1):
        super().__init__()
        self.self_attention = CausalSelfAttention(
            embed_dim=units, num_heads=num_heads, dropout=dropout_rate)
        self.cross_attention = CrossAttention(
            embed_dim=units, num_heads=num_heads, dropout=dropout_rate)
        self.ff = FeedForward(units=units, dropout_rate=dropout_rate)

    def forward(self, inputs):
        image_features, text = inputs
        text = self.self_attention(text)                    # attend to the text prefix
        text = self.cross_attention(text, image_features)   # attend to the image
        # Exposed for the attention visualisations.
        self.last_attention_scores = self.cross_attention.last_attention_scores
        return self.ff(text)


### Output

At minimum, the output layer needs an `nn.Linear` layer to generate logit predictions for each token at each location.

But there are a few other features you can add to make this work a little better:

1. **Handle bad tokens**: The model will be generating text. It should
   never generate a pad, unknown, or start token (`''`, `'[UNK]'`, 
   `'[START]'`). So set the bias for these to a large negative value.

   > Note: You'll need to ignore these tokens in the loss function as well. 

2. **Smart initialization**: The default initialization of a dense layer will
  give a model that initially predicts each token with almost uniform
  likelihood. The actual token distribution is far from uniform. The
  optimal value for the initial bias of the output layer is the log of the
  probability of each token. So include an `adapt` method to count the tokens
  and set the optimal initial bias. This reduces the initial loss from the
  entropy of the uniform distribution ($\log V$, where $V$ is the vocabulary size)
  to the marginal entropy of the token distribution ($-\sum_i p_i \log p_i$).


In [None]:
import torch
import torch.nn as nn
import numpy as np
import collections
from tqdm import tqdm

class TokenOutput(nn.Module):
    """Final projection from decoder features to per-position vocabulary logits.

    A learned linear layer plus a bias buffer. The bias is zero until `adapt`
    is called. `adapt` sets it to the log of each token's observed frequency
    ("smart initialization") and to -1e9 for `banned_tokens` and for tokens
    that never occur.

    Args:
        tokenizer: object whose `get_vocab()` returns a `{token: id}` dict.
        banned_tokens: tokens whose logits are suppressed (ignored if absent
            from the vocabulary).
        units: width of the decoder features fed to `forward`. If None, the
            input width is inferred on the first call (`nn.LazyLinear`).
        **kwargs: forwarded to the linear layer (e.g. `bias`, `device`, `dtype`).
    """

    def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), units=None, **kwargs):
        super().__init__()
        self.tokenizer = tokenizer
        self.banned_tokens = banned_tokens
        # Same vocabulary size the Captioner's SeqEmbedding uses.
        self.vocab_size = len(tokenizer.get_vocab())

        # The input is the decoder's feature width, not the vocabulary size.
        if units is None:
            self.dense = nn.LazyLinear(out_features=self.vocab_size, **kwargs)
        else:
            self.dense = nn.Linear(in_features=units, out_features=self.vocab_size, **kwargs)

        # A buffer, so it follows `.to(device)` and is saved in the state_dict.
        # Zeros keep `forward` usable before `adapt` is called.
        self.register_buffer('bias', torch.zeros(self.vocab_size))

    def adapt(self, ds):
        """Set the output bias from token frequencies observed in `ds`.

        Args:
            ds: iterable of token-ID batches (torch tensors, or anything
                `np.asarray` accepts).

        Raises:
            ValueError: if `ds` contains no tokens other than banned ones.
        """
        counts = collections.Counter()
        for tokens in tqdm(ds):
            if torch.is_tensor(tokens):
                tokens = tokens.detach().cpu().numpy()
            counts.update(np.asarray(tokens).ravel().tolist())

        counts_arr = np.zeros(shape=(self.vocab_size,))
        if counts:
            ids = np.fromiter(counts.keys(), dtype=np.int64)
            counts_arr[ids] = list(counts.values())

        # Use the tokenizer's real {token: id} mapping. Enumerating its keys
        # gives the right ids only if the dict happens to be in id order.
        vocab = self.tokenizer.get_vocab()
        for token in self.banned_tokens:
            if token in vocab:
                counts_arr[vocab[token]] = 0

        total = counts_arr.sum()
        if total == 0:
            raise ValueError('adapt() saw no tokens other than banned ones')
        p = counts_arr / total
        p[counts_arr == 0] = 1.0
        log_p = np.log(p)  # log(1) == 0, so zero-count tokens add no entropy

        entropy = -(log_p * p).sum()

        print()
        print(f"Uniform entropy: {np.log(self.vocab_size):0.2f}")
        print(f"Marginal entropy: {entropy:0.2f}")

        # Effectively forbid tokens that are banned or never observed.
        log_p[counts_arr == 0] = -1e9
        self.bias = torch.as_tensor(log_p, dtype=torch.float32, device=self.bias.device)

    def forward(self, x):
        """Return `dense(x) + bias`: vocabulary logits at each position."""
        return self.dense(x) + self.bias


The smart initialization will significantly reduce the initial loss:

In [None]:
# Padding, unknown and start tokens should never be predicted.
output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))
# Only the label tokens are needed for the counts. This might run a little
# faster if the dataset didn't also have to load the image data.
label_ds = train_ds.map(lambda inputs, labels: labels)
output_layer.adapt(label_ds)

### Build the model

To build the model, you need to combine several parts:

1. The image `feature_extractor` and the text `tokenizer`.
1. The `seq_embedding` layer, to convert batches of token-IDs to 
   vectors `(batch, sequence, channels)`.
3. The stack of `DecoderLayers` layers that will process the text and image data.
4. The `output_layer` which returns a pointwise prediction of what the next word should be.

In [None]:
import torch
import torch.nn as nn

class Captioner(nn.Module):
    """Image-captioning Transformer decoder.

    Components: an image `feature_extractor`, a `SeqEmbedding` for the token
    IDs, a stack of `num_layers` `DecoderLayer`s and the given `output_layer`.
    Further methods (e.g. `forward`) are attached with `Captioner.add_method`.
    """

    @classmethod
    def add_method(cls, fun):
        """Decorator: attach `fun` to the class under its own name, return it unchanged."""
        setattr(cls, fun.__name__, fun)
        return fun

    def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,
                 units=256, max_length=50, num_heads=1, dropout_rate=0.1):
        super().__init__()
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

        vocab = tokenizer.get_vocab()  # {token: id}
        self.word_to_index = vocab
        self.index_to_word = {index: word for word, index in vocab.items()}

        self.seq_embedding = SeqEmbedding(
            vocab_size=len(vocab), depth=units, max_length=max_length)

        self.decoder_layers = nn.ModuleList(
            DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)
            for _ in range(num_layers))

        self.output_layer = output_layer


When you call the model, for training, it receives an `image, txt` pair. To make this function more usable, be flexible about the input:

* If the image has 3 channels, run it through the `feature_extractor`; otherwise assume that has already been done.
* Similarly, if the text is given as strings, run it through the tokenizer.

After that running the model is only a few steps:

1. Flatten the extracted image features, so they can be input to the decoder layers.
2. Look up the token embeddings.
3. Run the stack of `DecoderLayer`s, on the image features and text embeddings.
4. Run the output layer to predict the next token at each position.


In [None]:
import einops

@Captioner.add_method
def forward(self, inputs):
    """Compute next-token logits for an `(image, txt)` pair.

    Args:
        inputs: `(image, txt)`.
            image: RGB images `(batch, 3, H, W)` (channels-first, as produced by
                `preprocessing`), which are run through `feature_extractor`; an
                already-extracted feature map `(batch, channels, h, w)`; or
                already-flattened features `(batch, h*w, channels)`.
            txt: token IDs `(batch, seq)`, or a string / list of strings that
                is passed through `self.tokenizer` first.

    Returns:
        The output of `output_layer`: one logit vector per text position.
    """
    image, txt = inputs

    # PyTorch images are channels-first, so the RGB check is on axis 1.
    if image.dim() == 4 and image.shape[1] == 3:
        image = self.feature_extractor(image)

    # Flatten a (batch, channels, h, w) feature map into h*w feature vectors.
    if image.dim() == 4:
        image = einops.rearrange(image, 'b c h w -> b (h w) c')

    if isinstance(txt, str):
        txt = [txt]
    if isinstance(txt, (list, tuple)) and txt and isinstance(txt[0], str):
        # Apply the tokenizer if you get string inputs.
        # NOTE(review): assumes the tokenizer returns a LongTensor of token IDs.
        txt = self.tokenizer(txt)

    txt = self.seq_embedding(txt)

    # Look at the image.
    # NOTE(review): cross-attention keys/values must have width `units`; the
    # mobilenet feature map has 576 channels, so a projection is needed unless
    # the extractor already outputs `units` channels.
    for dec_layer in self.decoder_layers:
        txt = dec_layer(inputs=(image, txt))

    return self.output_layer(txt)


In [None]:
model = Captioner(
    tokenizer,
    feature_extractor=mobilenet,
    output_layer=output_layer,
    num_layers=2,
    units=256,
    num_heads=2,
    dropout_rate=0.5,
)

### Generate captions

Before getting into training, write a bit of code to generate captions. You'll use this to see how training is progressing.

Start by downloading a test image:

In [None]:
image_url = 'https://tensorflow.org/images/surf.jpg'
# Download once with the standard library. TensorFlow (and so
# `tf.keras.utils.get_file`) is never imported in this notebook.
image_path = pathlib.Path('surf.jpg')
if not image_path.exists():
    urllib.request.urlretrieve(image_url, image_path)
image = load_image(image_path)

To caption an image with this model:

- Extract the `img_features`
- Initialize the list of output tokens with a `[START]` token.
- Pass `img_features` and `tokens` into the model.
  - It returns a list of logits.
  - Choose the next token based on those logits.  
  - Add it to the list of tokens, and continue the loop.
  - If it generates an `'[END]'` token, break out of the loop.

So add a "simple" method to do just that:

In [None]:
@Captioner.add_method
def simple_gen(self, image, temperature=1, max_length=50):
  """Generate a caption for a single image, one token at a time.

  Args:
    image: a PIL image (passed through the notebook's `preprocessing`) or an
      already-preprocessed `(3, H, W)` tensor.
    temperature: 0 for greedy (argmax) decoding; otherwise the logits are
      divided by `temperature` before sampling.
    max_length: maximum number of tokens to generate.

  Returns:
    The caption as a space-separated string, without `[START]`/`[END]`.
  """
  if not torch.is_tensor(image):
    image = preprocessing(image)

  device = next(self.parameters()).device
  start_id = self.word_to_index['[START]']
  end_id = self.word_to_index.get('[END]')

  # Disable dropout while generating, then restore every submodule's mode.
  modes = [(module, module.training) for module in self.modules()]
  self.eval()
  try:
    with torch.no_grad():
      img_features = self.feature_extractor(image.unsqueeze(0).to(device))
      tokens = torch.tensor([[start_id]], dtype=torch.long, device=device)  # (batch, sequence)
      for _ in range(max_length):
        logits = self((img_features, tokens))[:, -1, :]  # (batch, vocab)
        if temperature == 0:
          next_token = logits.argmax(dim=-1, keepdim=True)  # (batch, 1)
        else:
          probs = torch.softmax(logits / temperature, dim=-1)
          next_token = torch.multinomial(probs, num_samples=1)  # (batch, 1)
        tokens = torch.cat([tokens, next_token], dim=1)  # (batch, sequence)
        if next_token.item() == end_id:
          break
  finally:
    for module, mode in modes:
      module.training = mode

  ids = tokens[0, 1:].tolist()  # drop [START]
  if ids and ids[-1] == end_id:
    ids = ids[:-1]  # drop [END] only when it was actually generated
  return ' '.join(self.index_to_word[i] for i in ids)

Here are some generated captions for that image, the model's untrained, so they don't make much sense yet:

In [None]:
for t in (0.0, 0.5, 1.0):  # greedy first, then increasingly random sampling
  result = model.simple_gen(image, temperature=t)
  print(result)

The temperature parameter allows you to interpolate between 3 modes:

1. Greedy decoding (`temperature=0.0`) - Chooses the most likely next token at each step.
2. Random sampling according to the logits (`temperature=1.0`).
3. Uniform random sampling (`temperature >> 1.0`). 

Since the model is untrained, and it used the frequency-based initialization, the "greedy" output (first) usually only contains the most common tokens: `['a', '.', '[END]']`.

## Train

To train the model you'll need several additional components:

- The Loss and metrics
- The Optimizer
- Optional Callbacks

### Losses and metrics

Here's an implementation of a masked loss and accuracy:

When calculating the mask for the loss, note the `loss < 1e8`. This term discards the artificial, impossibly high losses for the `banned_tokens`.

In [None]:
def masked_loss(labels, preds):
  """Mean cross-entropy over non-padding label positions.

  Args:
    labels: integer token IDs `(batch, seq)`; 0 is padding.
    preds: logits `(batch, seq, vocab)`.

  Positions whose loss is >= 1e8 are also dropped. Those are labels the output
  layer's -1e9 bias marks as impossible (banned tokens).
  """
  labels = labels.long()
  per_token = F.cross_entropy(
      preds.reshape(-1, preds.shape[-1]), labels.reshape(-1),
      reduction='none').reshape(labels.shape)

  mask = ((labels != 0) & (per_token < 1e8)).to(per_token.dtype)
  # clamp_min avoids 0/0 = NaN for a batch that is entirely padding.
  return (per_token * mask).sum() / mask.sum().clamp_min(1.0)

def masked_acc(labels, preds):
  """Fraction of non-padding positions where argmax(preds) equals the label."""
  mask = (labels != 0).float()
  predicted = preds.argmax(dim=-1)
  match = (predicted == labels.long()).float()
  return (match * mask).sum() / mask.sum().clamp_min(1.0)

### Callbacks

For feedback during training setup a `keras.callbacks.Callback` to generate some captions for the surfer image at the end of each epoch.

In [None]:
class GenerateText(tf.keras.callbacks.Callback):
  """Keras callback that prints captions for a test image after every epoch.

  Prints one caption per entry of `temperatures` (0.0 = greedy decoding).
  """

  def __init__(self, image_url='https://tensorflow.org/images/surf.jpg',
               temperatures=(0.0, 0.5, 1.0)):
    # Run the base-class initializer so Keras' own callback state is set up.
    super().__init__()
    # Local file name taken from the URL ('surf.jpg' for the default).
    image_path = tf.keras.utils.get_file(image_url.rsplit('/', 1)[-1], origin=image_url)
    self.image = load_image(image_path)
    self.temperatures = tuple(temperatures)

  def on_epoch_end(self, epochs=None, logs=None):
    print()
    print()
    for t in self.temperatures:
      result = self.model.simple_gen(self.image, temperature=t)
      print(result)
    print()


Like the earlier example, it generates three output strings. The first is "greedy", choosing the argmax of the logits at each step.

In [None]:
g = GenerateText()
g.model = model  # attach by hand; during fit() Keras sets this itself
g.on_epoch_end(epochs=0)

Also use `callbacks.EarlyStopping` to terminate training when the model starts to overfit.

In [None]:
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=5, restore_best_weights=True)
callbacks = [GenerateText(), early_stopping]

### Train

Configure and execute the training.

In [None]:
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
    loss=masked_loss,
    metrics=[masked_acc],
)

For more frequent reporting, use the `Dataset.repeat()` method, and set the `steps_per_epoch` and `validation_steps` arguments to `Model.fit`. 

With this setup on `Flickr8k` a full pass over the dataset is 900+ batches, but below the reporting-epochs are 100 steps.

In [None]:
STEPS_PER_EPOCH = 100   # short "reporting epochs" instead of full passes
VALIDATION_STEPS = 20

history = model.fit(
    train_ds.repeat(),
    validation_data=test_ds.repeat(),
    steps_per_epoch=STEPS_PER_EPOCH,
    validation_steps=VALIDATION_STEPS,
    epochs=100,
    callbacks=callbacks)

Plot the loss and accuracy over the training run:

In [None]:
fig, ax = plt.subplots()
ax.plot(history.history['loss'], label='loss')
ax.plot(history.history['val_loss'], label='val_loss')
ax.set_ylim([0, max(ax.get_ylim())])
ax.set_xlabel('Epoch #')
ax.set_ylabel('CE/token')
ax.legend()

In [None]:
plt.plot(history.history['masked_acc'], label='accuracy')
plt.plot(history.history['val_masked_acc'], label='val_accuracy')
plt.ylim([0, max(plt.ylim())])
plt.xlabel('Epoch #')
plt.ylabel('Masked accuracy')  # this plot shows accuracy, not cross-entropy
plt.legend()

## Attention plots

Now, using the trained model, run that `simple_gen` method on the image:

In [None]:
result = model.simple_gen(image, temperature=0.0)  # greedy decoding
result

Split the output back into tokens:

In [None]:
str_tokens = [*result.split(), '[END]']

The `DecoderLayers` each cache the attention scores for their `CrossAttention` layer. The shape of each attention map is `(batch=1, heads, sequence, image)`:

In [None]:
attn_maps = [decoder_layer.last_attention_scores for decoder_layer in model.decoder_layers]
[attn_map.shape for attn_map in attn_maps]

So stack the maps along the `batch` axis, then average over the `(batch, heads)` axes, while splitting the `image` axis back into `height, width`:


In [None]:
# The cached scores are torch tensors, so stack them with torch.cat.
attention_maps = torch.cat([m.detach().cpu() for m in attn_maps], dim=0)
if attention_maps.dim() == 3:
    # Head-averaged weights (batch, sequence, image): add a singleton heads axis.
    attention_maps = attention_maps.unsqueeze(1)
side = math.isqrt(attention_maps.shape[-1])  # assumes a square feature map
attention_maps = einops.reduce(
    attention_maps,
    'batch heads sequence (height width) -> sequence height width',
    height=side, width=side,
    reduction='mean')

Now you have a single attention map, for each sequence prediction. The values in each map should sum to `1.`

In [None]:
# Each token's map should sum to 1 over the image locations.
einops.reduce(attention_maps, 'sequence height width -> sequence', 'sum')

So here is where the model was focusing attention while generating each token of the output:

In [None]:
def plot_attention_maps(image, str_tokens, attention_map):
    """Overlay each token's attention map on the image, one subplot per token.

    Args:
        image: image array accepted by `imshow` (e.g. floats in [0, 1]).
        str_tokens: the generated tokens, used as subplot titles.
        attention_map: indexable by token position; each entry is a 2-D map.
            Torch tensors are converted to NumPy before plotting.
    """
    fig = plt.figure(figsize=(16, 9))

    len_result = len(str_tokens)
    grid_size = max(int(np.ceil(len_result / 2)), 2)  # columns of a 3-row grid

    for i, token in enumerate(str_tokens):
        attn = attention_map[i]
        if hasattr(attn, 'detach'):  # torch tensor -> NumPy for matplotlib
            attn = attn.detach().cpu().numpy()
        ax = fig.add_subplot(3, grid_size, i + 1)
        ax.set_title(token)
        img = ax.imshow(image)
        ax.imshow(attn, cmap='gray', alpha=0.6, extent=img.get_extent(),
                  clim=[0.0, np.max(attn)])

    plt.tight_layout()

In [None]:
# `image` is a PIL image, which doesn't support `/`; scale its pixels to [0, 1].
plot_attention_maps(np.asarray(image) / 255, str_tokens, attention_maps)

Now put that together into a more usable function:

In [None]:
@Captioner.add_method
def run_and_show_attention(self, image, temperature=0.0):
  """Caption `image`, then plot where each generated token attended.

  Args:
    image: PIL image (or array) to caption.
    temperature: passed to `simple_gen`; 0.0 is greedy decoding.
  """
  result_txt = self.simple_gen(image, temperature)
  str_tokens = result_txt.split()
  str_tokens.append('[END]')

  # Cross-attention scores cached by each decoder layer during the last call.
  # They are torch tensors, so stack them with torch.cat.
  attention_maps = torch.cat(
      [layer.last_attention_scores.detach().cpu() for layer in self.decoder_layers],
      dim=0)
  if attention_maps.dim() == 3:
    # Head-averaged weights (batch, sequence, image): add a heads axis.
    attention_maps = attention_maps.unsqueeze(1)
  side = math.isqrt(attention_maps.shape[-1])  # assumes a square feature map
  attention_maps = einops.reduce(
      attention_maps,
      'batch heads sequence (height width) -> sequence height width',
      height=side, width=side,
      reduction='mean')

  # Never label more subplots than there are attention maps.
  str_tokens = str_tokens[:attention_maps.shape[0]]

  plot_attention_maps(np.asarray(image) / 255, str_tokens, attention_maps)
  t = plt.suptitle(result_txt)
  t.set_y(1.05)


In [None]:
model.run_and_show_attention(image)

## Try it on your own images

For fun, below you're provided a method you can use to caption your own images with the model you've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for strange results!)


In [None]:
image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'
# Download with the standard library (TensorFlow is never imported in this
# notebook). The local file name is taken from the URL.
image_path = pathlib.Path(image_url.rsplit('/', 1)[-1])
if not image_path.exists():
    urllib.request.urlretrieve(image_url, image_path)
image = load_image(image_path)

run_and_show_attention(model, image)