## Vision Transformer Model

source code: https://drive.google.com/file/d/1ZU1u3NnPduGHmIwiu8nSBoj5gYRPSjEq/view?source=post_page-----13fc4ce253d7--------------------------------

Original Article: https://dohyeongkim.medium.com/image-to-latex-using-vision-transformer-13fc4ce253d7

In [1]:
import concurrent.futures
import collections
import dataclasses
import hashlib
import itertools
import json
import math
import os
import pathlib
import random
import re
import string
import time
import urllib.request

!pip install einops
import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from PIL import Image
import requests
import tqdm
import numpy as np

import tensorflow as tf

Defaulting to user installation because normal site-packages is not writeable


2024-12-05 22:40:43.847500: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-05 22:40:43.850740: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2024-12-05 22:40:43.857015: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1733434843.868738    3608 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1733434843.872273    3608 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-05 22:40:43.886385: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU ins

### Choose a dataset

This tutorial is set up to give a choice of datasets. Either [Flickr8k](https://www.ijcai.org/Proceedings/15/Papers/593.pdf) or a small slice of the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/) dataset. These two are downloaded and converted from scratch, but it wouldn't be hard to convert the tutorial to use the caption datasets available in [TensorFlow Datasets](https://www.tensorflow.org/datasets): [Coco Captions](https://www.tensorflow.org/datasets/catalog/coco_captions) and the full [Conceptual Captions](https://www.tensorflow.org/datasets/community_catalog/huggingface/conceptual_captions).


In [2]:
image_size = 160

data_root = 'latex_data'
vocab = open(os.path.join(data_root, "latex_vocab.txt")).readlines()
formulae = open(os.path.join(data_root, "formulas.norm.lst"), 'r').readlines()

char_to_idx = {x.split('\n')[0]:i for i, x in enumerate(vocab)}
char_to_idx['#UNK'] = len(char_to_idx)
char_to_idx['#START'] = len(char_to_idx)
char_to_idx['#END'] = len(char_to_idx)
idx_to_char = {y:x for x, y in char_to_idx.items()}

file_name = 'train.lst'
data_path = os.path.join(data_root, file_name)
file_list = open(data_path, 'r')

image_dir = os.path.join(data_root, 'images_processed')

set_list = []
missing = {}
for i, line in enumerate(file_list):
    form = formulae[int(line.split()[1])].strip().split()

    out_form = [char_to_idx['#START']]
    for c in form:
        try:
            out_form += [char_to_idx[c]]
        except:
            if c not in missing:
                print(c, " not found!")
                missing[c] = 1
            else:
                missing[c] += 1

            out_form += [char_to_idx['#UNK']]

    out_form += [char_to_idx['#END']]
    set_list.append([line.split()[0], out_form])

    image_file_name = line.split()[0]
    label = out_form

    if i == 5:
        break

In [3]:
class LatexDataset:
    def __init__(self, set='train', batch_size=32):
        self.data_root = 'latex_data'

        self.set = 'train'
        self.batch_size = batch_size

        self.train_dict = np.load(os.path.join(data_root, set + '_buckets.npy'), allow_pickle=True).tolist()

        self.data_length = np.sum([len(self.train_dict[x]) for x in self.train_dict.keys()])
        print("Length of %s data: " % set, self.data_length)

    def __len__(self):
        return self.data_length

    def __iter__(self):
        for keys in self.train_dict.keys():
            train_list = self.train_dict[keys]
            N_FILES = (len(train_list) // self.batch_size) * self.batch_size
            for batch_idx in range(0, N_FILES, self.batch_size):
                train_sublist = train_list[batch_idx:batch_idx + self.batch_size]
                imgs = []
                input_tokens = []
                label_tokens = []
                for x, y in train_sublist:
                    img = Image.open(os.path.join(self.data_root, 'images_processed/') + x)
                    img = img.resize((image_size, image_size))

                    img = np.asarray(img)[:,:,0][:,:,None] / 255.0

                    imgs.append(img)
                    input_tokens.append(y[:-1])
                    label_tokens.append(y[1:])

                imgs = np.asarray(imgs, dtype=np.float32).transpose(0, 1, 2, 3)
                lens = [len(x) for x in input_tokens]

                Y_input_tokens = np.zeros((self.batch_size, max(lens)), dtype=np.int32)
                for i, input_token in enumerate(input_tokens):
                    Y_input_tokens[i, :len(input_token)] = input_token

                Y_label_tokens = np.zeros((self.batch_size, max(lens)), dtype=np.int32)
                for i, label_token in enumerate(label_tokens):
                    Y_label_tokens[i, :len(label_token)] = label_token

                yield imgs, Y_input_tokens, Y_label_tokens

    __call__ = __iter__

In [4]:
train_ds_gen = LatexDataset(set='train', batch_size=32)
train_ds = tf.data.Dataset.from_generator(train_ds_gen, (tf.float32, tf.int32, tf.int32))

test_ds_gen = LatexDataset(set='test', batch_size=32)
test_ds = tf.data.Dataset.from_generator(test_ds_gen, (tf.float32, tf.int32, tf.int32))

Length of train data:  76511
Instructions for updating:
Use output_signature instead
Length of test data:  10355


W0000 00:00:1733434845.548611    3608 gpu_device.cc:2344] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [5]:
import matplotlib.pyplot as plt
from IPython.display import display, Math, Latex
from IPython.display import Image as ipythonImage
from io import StringIO
import IPython.display
import numpy as np

!pip install cv2
import cv2

displayPreds = lambda Y: display(Math(Y.split('#END')[0]))

properties = np.load(os.path.join(data_root, 'properties.npy'), allow_pickle=True).tolist()

vocab = open(os.path.join(data_root, "latex_vocab.txt")).readlines()

word_to_index = {x.split('\n')[0]:i for i, x in enumerate(vocab)}
word_to_index['#UNK'] = len(word_to_index)
word_to_index['#START'] = len(word_to_index)
word_to_index['#END'] = len(word_to_index)
index_to_word = {y:x for x, y in word_to_index.items()}

index_to_words = lambda Y: ' '.join(map(lambda x: properties['idx_to_char'][x], Y))

n = int(32)
plt.figure(figsize=(40, 40))

x = None
for idx, train_data in enumerate(train_ds):
    imgs, Y_input_tokens, Y_label_tokens = train_data

    print("imgs.shape: ", imgs.shape)
    print("Y_input_tokens.shape: ", Y_input_tokens.shape)
    print("Y_label_tokens.shape: ", Y_label_tokens.shape)

    sub_idx = -1

    img = imgs[sub_idx]
    Y_input_token = Y_input_tokens[sub_idx]
    Y_label_token = Y_label_tokens[sub_idx]
    print("Y_input_token: ", Y_input_token)
    print("Y_label_token: ", Y_label_token)

    #ax = plt.subplot(n, n, sub_idx + 1)
    #patch_img = tf.reshape(patch, (patch_size, patch_size, 1))
    plt.imshow(img.numpy(), cmap="gray")
    #plt.imshow(img.numpy())

    #print("img.numpy().shape: ", img.numpy().shape)

    #cv2.imshow(img.numpy())
    #plt.axis("off")
    #print("img.numpy().shape: ", img.numpy().shape)

    #preds_chars = index_to_words(Y[1:].numpy()).replace('$','')
    #preds_chars = preds_chars.split('#END')[0]
    #print("preds_chars: ", preds_chars)
    #print("")
    if idx == 0:
        break

Defaulting to user installation because normal site-packages is not writeable
[31mERROR: Could not find a version that satisfies the requirement cv2 (from versions: none)[0m[31m
[0m[31mERROR: No matching distribution found for cv2[0m[31m
[0m

ModuleNotFoundError: No module named 'cv2'

In [None]:
len(word_to_index)

### Image feature extractor

You will use an image model (pretrained on imagenet) to extract the features from each image. The model was trained as an image classifier, but setting `include_top=False` returns the model without the final classification layer, so you can use the last layer of feature-maps:  


In [None]:
'''
IMAGE_SHAPE = (image_size, image_size * 4, 1)
mobilenet = tf.keras.Sequential(
    [
        tf.keras.layers.InputLayer(input_shape=IMAGE_SHAPE),
        tf.keras.layers.Conv2D(filters=64, kernel_size=[3, 3], padding='same', activation='relu', use_bias=False),
        tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=[2, 2]),
        tf.keras.layers.Conv2D(filters=128, kernel_size=[3, 3], padding='same', activation='relu', use_bias=False),
        tf.keras.layers.MaxPool2D(pool_size=[2, 2], strides=[2, 2]),
        tf.keras.layers.Conv2D(filters=256, kernel_size=[3, 3], padding='same', activation='relu', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.Conv2D(filters=256, kernel_size=[3, 3], padding='same', activation='relu', use_bias=False),
        tf.keras.layers.MaxPool2D(pool_size=[1, 2], strides=[1, 2]),
        tf.keras.layers.Conv2D(filters=512, kernel_size=[3, 3], padding='same', activation='relu', use_bias=False),
        tf.keras.layers.BatchNormalization(),
        tf.keras.layers.MaxPool2D(pool_size=[2, 1], strides=[2, 1]),
        tf.keras.layers.Conv2D(filters=512, kernel_size=[3, 3], padding='same', activation='relu', use_bias=False),
        tf.keras.layers.BatchNormalization()
    ]
)
'''

Here's a function to load an image and resize it for the model:

In [None]:
def load_image(image_path):
    img = tf.io.read_file(image_path)
    img = tf.io.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, IMAGE_SHAPE[:-1])
    return img

The model returns a feature map for each image in the input batch:

In [None]:
data_root = 'latex_data'
set = 'test'
test_dict = np.load(os.path.join(data_root, set + '_buckets.npy'), allow_pickle=True).tolist()
test_dict.keys()

In [None]:
ex_path = test_dict[(400, 160)][0][0]
image_dir = os.path.join(data_root, 'images_processed')
image_dir = os.path.join(image_dir, ex_path)

img = Image.open(image_dir).convert('YCbCr')
img = img.resize((image_size * 4, image_size))
img = np.asarray(img)[:,:,0][:,:,None]

In [None]:
#test_img_batch = load_image(ex_path)[tf.newaxis, :]
#print(mobilenet(np.expand_dims(img, 0)).shape)

In [None]:
image_size = 160
patch_size = 6
num_patches = (image_size // patch_size) ** 2
projection_dim = 512
num_heads = 4
transformer_units = [projection_dim * 2, projection_dim]
transformer_layers = 8
mlp_head_units = [2048, 1024]


def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = tf.keras.layers.Dense(units, activation=tf.keras.activations.gelu)(x)
        x = tf.keras.layers.Dropout(dropout_rate)(x)

    return x


class Patches(tf.keras.layers.Layer):
    def __init__(self, patch_size):
        super().__init__()
        self.patch_size = patch_size

    def call(self, images):
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(images=images, sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1], rates=[1, 1, 1, 1], padding="VALID")

        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])

        return patches


class PatchEncoder(tf.keras.layers.Layer):
    def __init__(self, num_patches, projection_dim):
        super().__init__()
        self.num_patches = num_patches
        self.projection = tf.keras.layers.Dense(units=projection_dim)
        self.position_embedding = tf.keras.layers.Embedding(input_dim=num_patches, output_dim=projection_dim)

    def call(self, patch):
        positions = tf.range(start=0, limit=self.num_patches, delta=1)
        encoded = self.projection(patch) + self.position_embedding(positions)

        return encoded


plt.figure(figsize=(4, 4))
image = img
plt.imshow(image.astype("uint8"))
plt.axis("off")

print("image.shape: ", image.shape)
resized_image = tf.image.resize(tf.convert_to_tensor([image]), size=(image_size, image_size))
print("resized_image.shape: ", resized_image.shape)

patches = Patches(patch_size)(resized_image)
print(f"Image size: {image_size} X {image_size}")
print(f"Patch size: {patch_size} X {patch_size}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

n = int(np.sqrt(patches.shape[1]))
print("n: ", n)
'''
plt.figure(figsize=(4, 4))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(n, n, i + 1)
    patch_img = tf.reshape(patch, (patch_size, patch_size, 1))

    plt.imshow(patch_img.numpy().astype("uint8"))
    plt.axis("off")
'''

In [None]:
def create_vit_classifier():
    inputs = tf.keras.Input(shape=(image_size, image_size, 1))

    patches = Patches(patch_size)(inputs)
    encoded_patches = PatchEncoder(num_patches, projection_dim)(patches)

    for _ in range(transformer_layers):
        x1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
        attention_output = tf.keras.layers.MultiHeadAttention(num_heads=num_heads, key_dim=projection_dim, dropout=0.1)(x1, x1)
        x2 = tf.keras.layers.Add()([attention_output, encoded_patches])
        x3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)(x2)
        x3 = mlp(x3, hidden_units=transformer_units, dropout_rate=0.1)
        encoded_patches = tf.keras.layers.Add()([x3, x2])

    representation = tf.keras.layers.LayerNormalization(epsilon=1e-6)(encoded_patches)
    #representation = tf.keras.layers.Flatten()(representation)
    #representation = tf.keras.layers.Dropout(0.5)(representation)
    #features = mlp(representation, hidden_units=mlp_head_units, dropout_rate=0.5)
    model = tf.keras.Model(inputs=inputs, outputs=representation)

    return model

In [None]:
vit_classifier = create_vit_classifier()

In [None]:
ex_path = test_dict[(400, 160)][0][0]
image_dir = os.path.join(data_root, 'images_processed')
image_dir = os.path.join(image_dir, ex_path)

img = Image.open(image_dir).convert('YCbCr')
img = img.resize((image_size, image_size))
img = np.asarray(img)[:,:,0][:,:,None]

print(vit_classifier(np.expand_dims(img, 0)).shape)

### Prepare the datasets

In [None]:
'''
step = 0
for train_batch in train_ds:
    print("train_batch[0][0].shape: ", train_batch[0][0].shape)
    print("train_batch[1][0]: ", train_batch[1][0])
    img = train_batch[0][0]
    result = mobilenet(np.expand_dims(img, 0))
    print(result.shape)

    step += 1
    break
'''

In [None]:
step = 0
test_data = []
for test_batch in test_ds.random(seed=4).take(5):

    print("test_batch: ", test_batch)
    #print("test_batch[1][0]: ", test_batch[1][0])

    #img = test_batch[0][0]
    #result = mobilenet(np.expand_dims(img, 0))
    #print(result.shape)

    step += 1
    break

print("step: ", step)

## A Transformer decoder model

In [None]:
class SeqEmbedding(tf.keras.layers.Layer):
  def __init__(self, vocab_size, max_length, depth):
    super().__init__()

    self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)
    self.token_embedding = tf.keras.layers.Embedding(input_dim=vocab_size, output_dim=depth, mask_zero=True)
    self.add = tf.keras.layers.Add()

  def call(self, seq):
    seq = self.token_embedding(seq) # (batch, seq, depth)

    x = tf.range(tf.shape(seq)[1])  # (seq)
    x = x[tf.newaxis, :]  # (1, seq)
    x = self.pos_embedding(x)  # (1, seq, depth)

    return self.add([seq, x])

In [None]:
class CausalSelfAttention(tf.keras.layers.Layer):
  def __init__(self, **kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    attn = self.mha(query=x, value=x, use_causal_mask=True)
    x = self.add([x, attn])

    return self.layernorm(x)


class CrossAttention(tf.keras.layers.Layer):
  def __init__(self,**kwargs):
    super().__init__()
    self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)
    self.add = tf.keras.layers.Add()
    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x, y, **kwargs):
    # x.shape:  TensorShape([1, 29, 256])
    # y.shape:  TensorShape([1, 1600, 512])
    #tf.print("x.shape: ", x.shape)
    #tf.print("y.shape: ", y.shape)

    attn, attention_scores = self.mha(query=x, value=y, return_attention_scores=True)
    self.last_attention_scores = attention_scores
    x = self.add([x, attn])

    return self.layernorm(x)


class FeedForward(tf.keras.layers.Layer):
  def __init__(self, units, dropout_rate=0.1):
    super().__init__()
    self.seq = tf.keras.Sequential([
        tf.keras.layers.Dense(units=2*units, activation='relu'),
        tf.keras.layers.Dense(units=units),
        tf.keras.layers.Dropout(rate=dropout_rate),
    ])

    self.layernorm = tf.keras.layers.LayerNormalization()

  def call(self, x):
    x = x + self.seq(x)
    return self.layernorm(x)


class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, units, num_heads=1, dropout_rate=0.1):
    super().__init__()

    self.self_attention = CausalSelfAttention(num_heads=num_heads, key_dim=units, dropout=dropout_rate)
    self.cross_attention = CrossAttention(num_heads=num_heads,key_dim=units, dropout=dropout_rate)
    self.ff = FeedForward(units=units, dropout_rate=dropout_rate)

  def call(self, inputs, training=False):
    in_seq, out_seq = inputs

    out_seq = self.self_attention(out_seq)
    out_seq = self.cross_attention(out_seq, in_seq)

    self.last_attention_scores = self.cross_attention.last_attention_scores
    out_seq = self.ff(out_seq)

    return out_seq

In [None]:
class TokenOutput(tf.keras.layers.Layer):
  def __init__(self, vocabulary_size, banned_tokens=('', '[UNK]', '[START]'), **kwargs):
    super().__init__()

    self.dense = tf.keras.layers.Dense(units=vocabulary_size, **kwargs)
    self.banned_tokens = banned_tokens

  def call(self, x):
    x = self.dense(x)
    return x

In [None]:
vocabulary_size = len(word_to_index)
output_layer = TokenOutput(vocabulary_size, banned_tokens=('', '#UNK', '#START'))

In [None]:
class Captioner(tf.keras.Model):
  def __init__(self, vocabulary_size, feature_extractor, output_layer, num_layers=1,
               units=256, max_length=200, num_heads=1, dropout_rate=0.1):
    super().__init__()
    self.feature_extractor = feature_extractor

    vocab = open(os.path.join(data_root, "latex_vocab.txt")).readlines()
    self.word_to_index = {x.split('\n')[0]:i for i, x in enumerate(vocab)}
    self.word_to_index['#UNK'] = len(self.word_to_index)
    self.word_to_index['#START'] = len(self.word_to_index)
    self.word_to_index['#END'] = len(self.word_to_index)
    self.index_to_word = {y:x for x, y in self.word_to_index.items()}

    self.seq_embedding = SeqEmbedding(vocab_size=vocabulary_size, depth=units, max_length=max_length)

    self.decoder_layers = [
        DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate) for n in range(num_layers)]

    self.output_layer = output_layer

  def call(self, image, txt):
    #image.shape 1:  TensorShape([1, 160, 640, 1])

    image = self.feature_extractor(image)
    #image.shape 2:  TensorShape([1, 20, 80, 512])

    #image = einops.rearrange(image, 'b h w c -> b (h w) c')
    #image.shape 3:  TensorShape([1, 1600, 512])

    txt = self.seq_embedding(txt)

    for dec_layer in self.decoder_layers:
      txt = dec_layer(inputs=(image, txt))

    txt = self.output_layer(txt)

    return txt

  def simple_gen(self, image, temperature=1):
    initial = self.word_to_index['#START'] # (batch, sequence)
    initial = tf.expand_dims(initial, 0)
    initial = tf.expand_dims(initial, 0)

    tokens = initial # (batch, sequence)

    for n in range(50):
      preds = self(image[tf.newaxis, ...], tokens).numpy()  # (batch, sequence, vocab)
      preds = preds[:,-1, :]  #(batch, vocab)
      if temperature == 0:
          next = tf.argmax(preds, axis=-1)[:, tf.newaxis]  # (batch, 1)
      else:
          next = tf.random.categorical(preds / temperature, num_samples=1)  # (batch, 1)

      next = tf.cast(next, tf.int32)

      tokens = tf.concat([tokens, next], axis=1) # (batch, sequence)

      if next[0] == self.word_to_index['#END']:
        break

    words = []
    for token in tokens[0, 1:-1]:
        word = index_to_word[token.numpy()]
        words.append(word)

    result = tf.strings.reduce_join(words, axis=-1, separator=' ')

    return result.numpy().decode()

model = Captioner(vocabulary_size, feature_extractor=create_vit_classifier(), output_layer=output_layer,
                  units=256, dropout_rate=0.5, num_layers=4, num_heads=8)

In [None]:
ex_path = test_dict[(400, 160)][0][0]
image_dir = os.path.join(data_root, 'images_processed')
image_dir = os.path.join(image_dir, ex_path)

image = Image.open(image_dir).convert('YCbCr')
image = image.resize((image_size * 4, image_size))
image = np.asarray(image)[:,:,0][:,:,None]

#for t in (0.0, 0.5, 1.0):
#  result = model.simple_gen(image, temperature=t)
#  print(result)

In [None]:
#for ds in train_ds:
#    logits = model(ds[0], ds[1])
#    print("logits.shape: ", logits.shape)
#    break

## Train

To train the model you'll need several additional components:

- The Loss and metrics
- The Optimizer
- Optional Callbacks

### Losses and metrics

Here's an implementation of a masked loss and accuracy:

When calculating the mask for the loss, note the `loss < 1e8`. This term discards the artificial, impossibly high losses for the `banned_tokens`.

In [None]:
def masked_loss(labels, preds):
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)

  mask = (labels != 0) & (loss < 1e8)
  mask = tf.cast(mask, loss.dtype)

  loss = loss * mask
  loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)

  return loss


def masked_acc(labels, preds):
  mask = tf.cast(labels != 0, tf.float32)
  preds = tf.argmax(preds, axis=-1)
  labels = tf.cast(labels, tf.int64)
  match = tf.cast(preds == labels, mask.dtype)
  acc = tf.reduce_sum(match * mask) / tf.reduce_sum(mask)

  return acc

### Callbacks

For feedback during training setup a `keras.callbacks.Callback` to generate some captions for the surfer image at the end of each epoch.

In [None]:
class GenerateText(tf.keras.callbacks.Callback):
  def __init__(self):
    ex_path = test_dict[(400, 160)][0][0]
    image_dir = os.path.join(data_root, 'images_processed')
    image_dir = os.path.join(image_dir, ex_path)

    self.image = Image.open(image_dir).convert('YCbCr')
    self.image = self.image.resize((image_size, image_size))
    self.image = np.asarray(self.image)[:,:,0][:,:,None]

  def on_epoch_end(self, epochs=None, logs=None):
    print()
    print()
    for t in (0.0, 0.5, 1.0):
      result = self.model.simple_gen(self.image, temperature=t)
      print(result)

    print()

It generates three output strings, like the earlier example, like before the first is "greedy", choosing the argmax of the logits at each step.

In [None]:
g = GenerateText()
g.model = model
g.on_epoch_end(0)

Also use `callbacks.EarlyStopping` to terminate training when the model starts to overfit.

In [None]:
callbacks = [GenerateText(), tf.keras.callbacks.EarlyStopping(patience=5, restore_best_weights=True)]

In [None]:
learning_rate = tf.keras.optimizers.schedules.PolynomialDecay(1e-4, int(100000 / 32.0 * 1000), 1e-6)
optimizer = tf.keras.optimizers.AdamW(learning_rate=learning_rate, weight_decay=0.0001)

@tf.function
def train(ds):
    labels = ds[2]

    with tf.GradientTape() as tape:
        logits = model(ds[0], ds[1])

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits)

        mask = (labels != 0) & (loss < 1e8)
        mask = tf.cast(mask, loss.dtype)

        loss = loss * mask
        loss = tf.reduce_sum(loss) / tf.reduce_sum(mask)

    grads = tape.gradient(loss, model.trainable_variables)
    #(grads, _) = tf.clip_by_global_norm(grads, clip_norm=1.0)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    return loss


@tf.function
def test(ds):
    labels = ds[2]

    with tf.GradientTape() as tape:
        preds = model(ds[0], ds[1])

        mask = tf.cast(labels != 0, tf.float32)
        preds = tf.argmax(preds, axis=-1)
        labels = tf.cast(labels, tf.int64)
        match = tf.cast(preds == labels, mask.dtype)
        acc = tf.reduce_sum(match * mask) / tf.reduce_sum(mask)

    return acc

### Train

In [None]:
writer = tf.summary.create_file_writer("tensorboard")

for epoch in range(0, 1000000):
    print("epoch: ", epoch)

    train_losses = []
    step = 0
    for ds_train in train_ds:
        loss = train(ds_train)
        train_losses.append(loss)
        step += 1

    model.save_weights('model/model_' + str(epoch))

    mean_loss_train = np.mean(train_losses)
    mean_perp_train = np.mean(list(map(lambda x: np.power(np.e, x), train_losses)))

    test_accuracies = []
    for ds_test in test_ds:
        accuracy = test(ds_test)
        test_accuracies.append(accuracy)

    mean_accuracy_test = np.mean(test_accuracies)

    print("Mean train loss:", mean_loss_train, ", Mean train perplexity:", mean_perp_train, "Mean test Accuracy:", mean_accuracy_test)
    with writer.as_default():
        tf.summary.scalar("mean_loss_train", mean_loss_train, step=epoch)
        tf.summary.scalar("mean_perp_train", mean_perp_train, step=epoch)
        tf.summary.scalar("mean_accuracy_test", mean_accuracy_test, step=epoch)
        writer.flush()

## Evaluate

In [None]:
import random

model.load_weights('model/model_464')

plt.figure(figsize=(40, 40))

data_root = 'latex_data'
set = 'test'
test_dict = np.load(os.path.join(data_root, set + '_buckets.npy'), allow_pickle=True).tolist()
data_length = np.sum([len(test_dict[x]) for x in test_dict.keys()])
print("Length of %s data: " % set, data_length)

key_list = list(test_dict.keys())
key = random.choice(key_list)
test_list = test_dict[key]
test_image_info = random.choice(test_list)

img = Image.open(os.path.join(data_root, 'images_processed/') + test_image_info[0])
img = img.resize((image_size * 4, image_size))
img = np.asarray(img)[:,:,0][:,:,None] / 255.0
plt.imshow(img, cmap="gray")

Y = np.array(test_image_info[1])

preds_chars = index_to_words(Y[1:]).replace('$','')
preds_chars = preds_chars.split('#END')[0]

print("Label: ")
displayPreds(preds_chars)

result = model.simple_gen(img, temperature=0.0)
print("Prediction: ")
displayPreds(result)

## Attention plots

Now, using the trained model,  run that `simple_gen` method on the image:

In [None]:
ex_path = test_dict[(400, 160)][0][0]
image_dir = os.path.join(data_root, 'images_processed')
image_dir = os.path.join(image_dir, ex_path)

image = Image.open(image_dir).convert('YCbCr')
image = image.resize((256, 256))
image = np.asarray(img)[:,:,0][:,:,None] / 255.0

In [None]:
result = model.simple_gen(image, temperature=0.0)
result

In [None]:
from IPython.display import display, Math, Latex
displayPreds = lambda Y: display(Math(Y))

step = 0

fig = plt.figure(figsize=(100, 100))

for ds in test_ds.shuffle(1000):
    images = ds[0]

    num_image = images.shape[0]
    for i in range(8):
      result = model.simple_gen(images[i].numpy(), temperature=0.0)
      displayPreds(result)

      grid_size = max(int(np.ceil(num_image / 9)), 9)
      ax = fig.add_subplot(6, grid_size, i + 1)
      img = ax.imshow(images[i].numpy(), cmap='gray')

    step += 1
    break
    #if step == 20:
    #    break

Split the output back into tokens:

In [None]:
str_tokens = result.split()
str_tokens.append('[END]')

The `DecoderLayers` each cache the attention scores for their `CrossAttention` layer. The shape of each attention map is `(batch=1, heads, sequence, image)`:

In [None]:
attn_maps = [layer.last_attention_scores for layer in model.decoder_layers]
[map.shape for map in attn_maps]

So stack the maps along the `batch` axis, then average over the `(batch, heads)` axes, while splitting the `image` axis back into `height, width`:


In [None]:
attention_maps = tf.concat(attn_maps, axis=0)
attention_maps = einops.reduce(
    attention_maps,
    'batch heads sequence (height width) -> sequence height width',
    height=7, width=7,
    reduction='mean')

Now you have a single attention map, for each sequence prediction. The values in each map should sum to `1.`

In [None]:
einops.reduce(attention_maps, 'sequence height width -> sequence', reduction='sum')

So here is where the model was focusing attention while generating each token of the output:

In [None]:
def plot_attention_maps(image, str_tokens, attention_map):
    fig = plt.figure(figsize=(16, 9))

    len_result = len(str_tokens)

    titles = []
    for i in range(len_result):
      map = attention_map[i]
      grid_size = max(int(np.ceil(len_result/2)), 2)
      ax = fig.add_subplot(3, grid_size, i+1)
      titles.append(ax.set_title(str_tokens[i]))
      img = ax.imshow(image)
      ax.imshow(map, cmap='gray', alpha=0.6, extent=img.get_extent(),
                clim=[0.0, np.max(map)])

    plt.tight_layout()

In [None]:
plot_attention_maps(image / 255, str_tokens, attention_maps)

Now put that together into a more usable function:

In [None]:
@Captioner.add_method
def run_and_show_attention(self, image, temperature=0.0):
  result_txt = self.simple_gen(image, temperature)
  str_tokens = result_txt.split()
  str_tokens.append('[END]')

  attention_maps = [layer.last_attention_scores for layer in self.decoder_layers]
  attention_maps = tf.concat(attention_maps, axis=0)
  attention_maps = einops.reduce(
      attention_maps,
      'batch heads sequence (height width) -> sequence height width',
      height=7, width=7,
      reduction='mean')

  plot_attention_maps(image/255, str_tokens, attention_maps)
  t = plt.suptitle(result_txt)
  t.set_y(1.05)

In [None]:
run_and_show_attention(model, image)

## Try it on your own images

For fun, below you're provided a method you can use to caption your own images with the model you've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for strange results!)


In [None]:
image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'
image_path = tf.keras.utils.get_file(origin=image_url)
image = load_image(image_path)

run_and_show_attention(model, image)