# **Bristol-Myers Squibb – Molecular Translation**

<h2 style="text-align: center; font-family: Verdana; font-size: 24px; font-style: normal; font-weight: bold; text-decoration: underline; text-transform: none; letter-spacing: 2px; color: navy; background-color: #ffffff;">Image Captioning - End-to-End Pipeline -<br><font color="green">Vision Transformer</font> + <font color="blue">Transformer</font></h2><br>


---


<h5 style="text-align: center; font-family: Verdana; font-size: 12px; font-style: normal; font-weight: bold; text-decoration: None; text-transform: none; letter-spacing: 1px; color: black; background-color: #ffffff;">CREATED BY: DARIEN SCHETTLER</h5><br>

Based on ➡️ [this notebook](https://www.kaggle.com/dschettler8845/bms-visiontransformer-transformer-vit), with extra comments added.

# TABLE OF CONTENTS
```

0. BACKGROUND INFORMATION
1. IMPORTS
2. SETUP
    2.1 ACCELERATOR DETECTION
    2.2 COMPETITION DATA ACCESS
    2.3 LEVERAGING MIXED PRECISION
    2.4 LEVERAGING XLA OPTIMIZATIONS
    2.5 BASIC DATA DEFINITIONS & INITIALIZATIONS
    2.6 INITIAL DATAFRAME INSTANTIATION
    2.7 USER INPUT VARIABLES
3. HELPER FUNCTION & CLASSES
    3.1 GENERAL HELPER FUNCTIONS
4. PREPARE THE DATASET
    4.1 READ TFRECORD FILES - CREATE THE RAW DATASET(S)
    4.2 WHAT TO DO IF YOU DON'T KNOW THE FEATURE DESCRIPTIONS OF THE DATASET?
    4.3 PARSE THE RAW DATASET(S)
    4.4 WORKING WITH TF.DATA DATASET OBJECTS
5. MODEL PREPARATION
    5.1 UNDERSTANDING THE MODELS - ENCODER
    5.2 UNDERSTANDING THE MODELS - DECODER
    5.3 CREATE A LEARNING RATE SCHEDULER
    5.4 WRAP THE CONFIGURATION DETAILS IN A CLASS OBJECT FOR EASY ACCESS
    5.5 HOW TPU IMPACTS MODELS, METRICS, AND OPTIMIZERS
    5.6 LOSS CLASSES AND REDUCTION
    5.7 DISTRIBUTE THE DATASETS ACROSS REPLICAS
    5.8 DISTRIBUTED COMPUTATION & OPTIMIZING LOOPS
6. MODEL TRAINING
    6.1 INDIVIDUAL TRAIN STEP
    6.2 INDIVIDUAL VAL STEP
    6.3 INITIALIZE LOGGER
    6.4 CUSTOM TRAIN LOOP
    6.5 JUST-IN-CASE SAVE
    6.6 VIEW PREDICTIONS & DISTRIBUTION OF LEVENSHTEIN DISTANCE FOR VAL DATASET
7. INFER ON TEST DATA
    7.1 INDIVIDUAL TEST STEP (AND DISTRIBUTED)
    7.2 RAW INFERENCE LOOP
    7.3 TEST PRED POST-PROCESSING
    7.4 SAVE SUBMISSION.CSV
```



# 0. BACKGROUND INFORMATION    


<br><b style="text-decoration: underline; font-family: Verdana; text-transform: uppercase;">PRIMARY TASK DESCRIPTION</b>


**Given an image, our goal is to generate a caption. In this case, that image is of a single molecule and the description/caption is the InChI string for that molecule.**

---

<br>

<b style="text-decoration: underline; font-family: Verdana; text-transform: uppercase;">SECONDARY TASK DESCRIPTION</b>

In this notebook, we will go through, step by step, training models with TPUs in a custom way. The following steps will be covered:
* Use **`tf.data.Dataset`** as input pipeline
* Perform a custom training loop
* Correctly define loss function
* Gradient accumulation with TPUs<br>

<br>

<b style="text-decoration: underline; font-family: Verdana; text-transform: uppercase;">MORE DETAIL ON IMAGE CAPTIONING</b>


<b><sub><a href="https://machinelearningmastery.com/develop-a-deep-learning-caption-generation-model-in-python/">Description From a Tutorial I Used As Reference</a></sub></b>

>Caption generation is a challenging artificial intelligence problem where a textual description must be generated for a given photograph.
>
>It requires both methods from computer vision to understand the content of the image and a language model from the field of natural language processing to turn the understanding of the image into words in the right order. Recently, deep learning methods have achieved state-of-the-art results on examples of this problem.
>
>Deep learning methods have demonstrated state-of-the-art results on caption generation problems. What is most impressive about these methods is a single end-to-end model can be defined to predict a caption, given a photo, instead of requiring sophisticated data preparation or a pipeline of specifically designed models.


# 1. IMPORTS 

In [None]:
# @title Installing necessary packages
# Installs
print("\n... PIP/APT INSTALLS STARTING ...\n")
# Pips
!pip install -q --upgrade pip
!pip install -q pydot
!pip install -q pydotplus
!pip install tensorflow-addons
!pip install levenshtein
!pip install kaggledatasets
# Apt-get
!apt-get install -q graphviz
print("\n... PIP/APT INSTALLS COMPLETE ...\n")

In [None]:
# @title Imports

print("\n... IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")
# Machine Learning and Data Science Imports
import tensorflow as tf; print(f"\t\t– TENSORFLOW VERSION: {tf.__version__}");
import tensorflow_addons as tfa; print(f"\t\t– TENSORFLOW ADDONS VERSION: {tfa.__version__}");
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np; print(f"\t\t– NUMPY VERSION: {np.__version__}");

# Library used to easily calculate Levenshtein Distance
import Levenshtein

# Built In Imports
from kaggle_datasets import KaggleDatasets
from collections import Counter
from datetime import datetime
from glob import glob
import warnings
import requests
import imageio
import IPython
import urllib
import zipfile
import pickle
import random
import shutil
import string
import math
import time
import gzip
import ast
import sys
import io
import os
import gc
import re

# Visualization Imports
from matplotlib.colors import ListedColormap
import matplotlib.patches as patches
import plotly.graph_objects as go
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm; tqdm.pandas();
import plotly.express as px
import seaborn as sns
from PIL import Image
import matplotlib; print(f"\t\t– MATPLOTLIB VERSION: {matplotlib.__version__}");
import plotly
import PIL
import cv2


def seed_it_all(seed=7):
    """ Attempt to be Reproducible """
    os.environ['PYTHONHASHSEED'] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    tf.random.set_seed(seed)

    
print("\n\n... IMPORTS COMPLETE ...\n")
    
print("\n... SEEDING FOR DETERMINISTIC BEHAVIOUR ...\n")
seed_it_all()


# 2. SETUP

## 2.1 ACCELERATOR DETECTION

---

To use a **`TPU`**, we first create a **`TPUClusterResolver`**, which is needed to connect to the remote cluster and initialize the Cloud TPUs. Two points are worth noting:

1. When using TPU on Kaggle, you don't need to specify arguments for **`TPUClusterResolver`**
2. However, on **G**oogle **C**ompute **E**ngine (**GCE**), you will need to do the following:

<br>

```python
# The name you gave to the TPU to use
TPU_WORKER = 'my-tpu-name'

# or you can also specify the grpc path directly
# TPU_WORKER = 'grpc://xxx.xxx.xxx.xxx:8470'

# The zone you chose when you created the TPU to use on GCP.
ZONE = 'us-east1-b'

# The name of the GCP project where you created the TPU to use on GCP.
PROJECT = 'my-tpu-project'

tpu = tf.distribute.cluster_resolver.TPUClusterResolver(tpu=TPU_WORKER, zone=ZONE, project=PROJECT)
```

<div class="alert alert-block alert-danger" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">🛑 &nbsp; WARNING:</b><br><br>- Although the Tensorflow documentation says it is the <b>project name</b> that should be provided for the argument <b><code>`project`</code></b>, it is actually the <b>Project ID</b>, that you should provide. This can be found on the GCP project dashboard page.<br>
</div>
<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCES:</b><br><br>
    - <a href="https://www.tensorflow.org/guide/tpu#tpu_initialization"><b>Guide - Use TPUs</b></a><br>
    - <a href="https://www.tensorflow.org/api_docs/python/tf/distribute/cluster_resolver/TPUClusterResolver"><b>Doc - TPUClusterResolver</b></a><br>

</div>
<br>

In [None]:
print(f"\n... ACCELERATOR SETUP STARTING ...\n")

# Detect hardware and return appropriate distribution strategy
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable 
    # is set. On Kaggle this is always the case.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()  
except ValueError:
    TPU = None

if TPU:
    print(f"\n... RUNNING ON TPU - {TPU.master()}...")
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    # `tf.distribute.experimental.TPUStrategy` is deprecated in favour of `tf.distribute.TPUStrategy`
    strategy = tf.distribute.TPUStrategy(TPU)
else:
    print(f"\n... RUNNING ON CPU/GPU ...")
    # Yield the default distribution strategy in Tensorflow
    #   --> Works on CPU and single GPU.
    strategy = tf.distribute.get_strategy() 

# What Is a Replica?
#    --> A single Cloud TPU device consists of FOUR chips, each of which has 
#        TWO TPU cores. 
#    --> Therefore, for efficient utilization of Cloud TPU, a program should 
#        make use of each of the EIGHT (4x2) cores. 
#    --> Each replica is essentially a copy of the training graph that is run
#        on each core and 
#        trains a mini-batch containing 1/8th of the overall batch size
N_REPLICAS = strategy.num_replicas_in_sync
    
print(f"... # OF REPLICAS: {N_REPLICAS} ...\n")

print(f"\n... ACCELERATOR SETUP COMPLETED ...\n")

## 2.2 COMPETITION DATA ACCESS
---

TPUs must read data directly from **G**oogle **C**loud **S**torage **(GCS)**. Kaggle provides a utility library – **`KaggleDatasets`** – which has a utility function **`.get_gcs_path`** that will allow us to access the location of our input datasets within **GCS**.<br><br>

<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📌 &nbsp; TIPS:</b><br><br>- If you have multiple datasets attached to the notebook, you should pass the name of a specific dataset to the <b><code>`get_gcs_path()`</code></b> function. <i>In our case, the name of the dataset is the name of the directory the dataset is mounted within.</i><br><br>
</div>

In [None]:
print("\n... DATA ACCESS SETUP STARTED ...\n")

if TPU:
    # Google Cloud Dataset path to training and validation images
    DATA_DIR = KaggleDatasets().get_gcs_path('bms-train-tfrecords-half-length')
    TEST_DATA_DIR = KaggleDatasets().get_gcs_path('bms-test-dataset-192x384')
#     DATA_DIR = "/kaggle/input/bms-train-tfrecords-half-length"
#     TEST_DATA_DIR = "/kaggle/input/bms-test-dataset-192x384"
else:
    # Local path to training and validation images
    DATA_DIR = "/kaggle/input/bms-train-tfrecords-half-length"
    TEST_DATA_DIR = "/kaggle/input/bms-test-dataset-192x384"
    
print(f"\n... DATA DIRECTORY PATH IS:\n\t--> {DATA_DIR}")
print(f"... TEST DATA DIRECTORY PATH IS:\n\t--> {TEST_DATA_DIR}")

print(f"\n... IMMEDIATE CONTENTS OF DATA DIRECTORY IS:")
for file in tf.io.gfile.glob(os.path.join(DATA_DIR, "*")): 
  print(f"\t--> {file}")

print(f"... IMMEDIATE CONTENTS OF TEST DATA DIRECTORY IS:")
for file in tf.io.gfile.glob(os.path.join(TEST_DATA_DIR, "*")): 
  print(f"\t--> {file}")

    
print("\n\n... DATA ACCESS SETUP COMPLETED ...\n")

## 2.3 LEVERAGING MIXED PRECISION

---

Mixed precision is the use of both **`16-bit`** and **`32-bit`** floating-point types in a model during training to make it run faster and use less memory. By keeping certain parts of the model in the **`32-bit`** types for numeric stability, the model will have a lower step time and train equally well in terms of evaluation metrics such as accuracy. 

Today, most models use the **`float32`** dtype, which takes **`32`** bits of memory. However, there are two lower-precision dtypes, **`float16`** and **`bfloat16`**, each of which takes **`16`** bits of memory instead. Modern accelerators can run operations faster in the **`16-bit`** dtypes, as they have specialized hardware to run **`16-bit`** computations and **`16-bit`** dtypes can be read from memory faster.<br><br>

**NVIDIA GPUs** can run operations in **`float16`** faster than in **`float32`**<br>
**TPUs** can run operations in **`bfloat16`** faster than in **`float32`**<br><br>

Therefore, these lower-precision dtypes should be used whenever possible on those devices. However, variables and a few computations should still be in **`float32`** for numeric reasons so that the model trains to the same quality. 

The Keras mixed precision API allows you to use a mix of either **`float16`** or **`bfloat16`** with **`float32`**, to get the performance benefits from **`float16/bfloat16`** and the numeric stability benefits from **`float32`**.<br><br>

<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; DEFINITION:</b><br><br>- The term <b>"numeric stability"</b> refers to how a model's quality is affected by the use of a lower-precision dtype instead of a higher precision dtype. We say an operation is "numerically unstable" in float16 or bfloat16 if running it in one of those dtypes causes the model to have worse evaluation accuracy or other metrics compared to running the operation in float32.<br>
</div>
<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCE:</b><br><br>    - <a href="https://www.tensorflow.org/guide/mixed_precision"><b>TF Mixed Precision Overview</b></a><br>
</div>
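
To make the difference between the two 16-bit formats more concrete, here is a small standard-library sketch (it is not part of the original pipeline). The helper names `to_bfloat16` and `fits_in_float16` are hypothetical, and `to_bfloat16` only approximates the real format: it truncates a `float32` to its top 16 bits, whereas actual hardware typically rounds to nearest.

```python
import struct

def to_bfloat16(x: float) -> float:
    """Emulate bfloat16 by keeping only the top 16 bits of the float32 encoding.

    Assumption: plain truncation is used here; real accelerators round instead.
    """
    bits = struct.unpack(">I", struct.pack(">f", x))[0]
    return struct.unpack(">f", struct.pack(">I", bits & 0xFFFF0000))[0]

def fits_in_float16(x: float) -> bool:
    """Return True if x can be packed as an IEEE half-precision (float16) value."""
    try:
        struct.pack(">e", x)
        return True
    except OverflowError:
        return False

# bfloat16 keeps the 8 exponent bits of float32, so large values stay finite ...
big = 3.0e38
print(to_bfloat16(big))      # still finite and within 1% of 3e38
print(fits_in_float16(big))  # False: float16 tops out at 65504

# ... but it has only 7 explicit mantissa bits, so small differences vanish.
print(to_bfloat16(1.001))    # 1.0
```

This range-versus-precision trade-off is one reason variables are kept in `float32` under the `mixed_bfloat16` policy while computations run in `bfloat16`.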

In [None]:
print(f"\n... MIXED PRECISION SETUP STARTING ...\n")
print("\n... SET TF TO OPERATE IN MIXED PRECISION – `bfloat16` – IF ON TPU ...")

# Set Mixed Precision Global Policy
#     ---> To use mixed precision in Keras, you need to create a 
#          `tf.keras.mixed_precision.Policy` typically referred to as a dtype policy. 
#     ---> Dtype policies specify the dtypes layers will run in
tf.keras.mixed_precision.set_global_policy('mixed_bfloat16' if TPU else 'float32')

# target data type, bfloat16 when using TPU to improve throughput
TARGET_DTYPE = tf.bfloat16 if TPU else tf.float32
print(f"\t--> THE TARGET DTYPE HAS BEEN SET TO {TARGET_DTYPE} ...")

# The policy specifies two important aspects of a layer: 
#     1. The dtype the layer's computations are done in
#     2. The dtype of a layer's variables. 
print(f"\n... TWO IMPORTANT ASPECTS OF THE GLOBAL MIXED PRECISION POLICY:")
print(f'\t--> COMPUTE DTYPE  : {tf.keras.mixed_precision.global_policy().compute_dtype}')
print(f'\t--> VARIABLE DTYPE : {tf.keras.mixed_precision.global_policy().variable_dtype}')

print(f"\n\n... MIXED PRECISION SETUP COMPLETED ...\n")

## 2.4 LEVERAGING XLA OPTIMIZATIONS


---


**XLA** (Accelerated Linear Algebra) is a domain-specific compiler for linear algebra that can accelerate TensorFlow models with potentially no source code changes. **The results are improvements in speed and memory usage**.

<br>

When a TensorFlow program is run, all of the operations are executed individually by the TensorFlow executor. Each TensorFlow operation has a precompiled GPU/TPU kernel implementation that the executor dispatches to.

XLA provides us with an alternative mode of running models: it compiles the TensorFlow graph into a sequence of computation kernels generated specifically for the given model. Because these kernels are unique to the model, they can exploit model-specific information for optimization.<br><br>

<div class="alert alert-block alert-danger" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">🛑 &nbsp; WARNING:</b><br><br>- XLA cannot currently compile functions where dimensions are not inferrable: that is, if it's not possible to infer the dimensions of all tensors without running the entire computation.<br>
</div>
<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📌 &nbsp; NOTE:</b><br><br>- XLA compilation is only applied to code that is compiled into a graph (in <b>TF2</b>, that means only code inside a <b><code>tf.function</code></b>).<br>- The <b><code>jit_compile</code></b> API has must-compile semantics, i.e. either the entire function is compiled with XLA, or an <b><code>errors.InvalidArgumentError</code></b> exception is thrown.
</div>
<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCE:</b><br><br>    - <a href="https://www.tensorflow.org/xla"><b>XLA: Optimizing Compiler for Machine Learning</b></a><br>
</div>

In [None]:
print(f"\n... XLA OPTIMIZATIONS STARTING ...\n")

print(f"\n... CONFIGURE JIT (JUST IN TIME) COMPILATION ...\n")
# Enable XLA optimizations (10% speedup when using @tf.function calls)
tf.config.optimizer.set_jit(True)

print(f"\n... XLA OPTIMIZATIONS COMPLETED ...\n")

## 2.5 BASIC DATA DEFINITIONS & INITIALIZATIONS

---


In [None]:
print("\n... BASIC DATA SETUP STARTING ...\n")

# All the possible tokens in our InChI 'language'
TOKEN_LIST = ["<PAD>", "InChI=1S/", "<END>", "/c", "/h", "/m", "/t", "/b", "/s", "/i"] +\
             ['Si', 'Br', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', 'C', 'H', 'B', ] +\
             [str(i) for i in range(167,-1,-1)] +\
             [r"\+", r"\(", r"\)", r"\-", ",", "D", "T"]  # raw strings avoid invalid-escape warnings
print(f"\n... TOKEN LIST:")
for i, tok in enumerate(TOKEN_LIST): print(f"\t--> INTEGER-IDX = {i:<3}  –––  STRING = {tok}")

# The start/end/pad tokens will be removed from the string when computing the Levenshtein distance
# We want them as tf.constant's so they will operate properly within the @tf.function context
START_TOKEN = tf.constant(TOKEN_LIST.index("InChI=1S/"), dtype=tf.uint8)
END_TOKEN = tf.constant(TOKEN_LIST.index("<END>"), dtype=tf.uint8)
PAD_TOKEN = tf.constant(TOKEN_LIST.index("<PAD>"), dtype=tf.uint8)

# Prefixes and Their Respective Ordering/Format
#      -- ORDERING --> {c}{h/None}{b/None}{t/None}{m/None}{s/None}{i/None}{h/None}{t/None}{m/None}
PREFIX_ORDERING = "chbtmsihtm"
print(f"\n... PREFIX ORDERING IS {PREFIX_ORDERING} ...")

# Paths to Respective Image Directories
TRAIN_DIR = os.path.join(DATA_DIR, "train_records")
VAL_DIR = os.path.join(DATA_DIR, "val_records")
TEST_DIR = os.path.join(TEST_DATA_DIR, "test_records")

# Get the Full Paths to The Individual TFRecord Files
TRAIN_TFREC_PATHS = sorted(
    tf.io.gfile.glob(os.path.join(TRAIN_DIR, "*.tfrec")), 
    key=lambda x: int(x.rsplit("_", 2)[1]))
VAL_TFREC_PATHS = sorted(
    tf.io.gfile.glob(os.path.join(VAL_DIR, "*.tfrec")), 
    key=lambda x: int(x.rsplit("_", 2)[1]))
TEST_TFREC_PATHS = sorted(
    tf.io.gfile.glob(os.path.join(TEST_DIR, "*.tfrec")), 
    key=lambda x: int(x.rsplit("_", 2)[1]))

print(f"\n... TFRECORD INFORMATION:")
for SPLIT, TFREC_PATHS in zip(["TRAIN", "VAL", "TEST"], [TRAIN_TFREC_PATHS, 
                                                        VAL_TFREC_PATHS, 
                                                        TEST_TFREC_PATHS]):
    print(f"\t--> {len(TFREC_PATHS):<3} {SPLIT:<5} TFRECORDS")

# Paths to relevant CSV files containing training and submission information
TRAIN_CSV_PATH = os.path.join("/kaggle/input", "bms-csvs-w-extra-metadata", "train_labels_w_extra.csv")
SS_CSV_PATH    = os.path.join("/kaggle/input", "bms-csvs-w-extra-metadata", "sample_submission_w_extra.csv")
print(f"\n... PATHS TO CSVS:")
print(f"\t--> TRAIN CSV: {TRAIN_CSV_PATH}")
print(f"\t--> SS CSV   : {SS_CSV_PATH}")

# When debug is true we use a smaller batch size and smaller model
DEBUG=False

print("\n\n... BASIC DATA SETUP COMPLETED ...\n")
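
The `sorted(..., key=...)` calls above order the TFRecord shards numerically rather than alphabetically. The sketch below illustrates the idea with hypothetical shard names (the real file names in the dataset may follow a different pattern); `shard_index` is just an illustrative name for the same lambda.

```python
# Hypothetical shard names; the real files in the dataset may be named differently.
paths = [
    "train_records/train_10_1024.tfrec",
    "train_records/train_2_1024.tfrec",
    "train_records/train_1_1024.tfrec",
]

def shard_index(path: str) -> int:
    """Same key as above: the second-to-last underscore-separated field as an int."""
    return int(path.rsplit("_", 2)[1])

print(sorted(paths))                   # lexicographic: train_10 sorts before train_2
print(sorted(paths, key=shard_index))  # numeric order: shards 1, 2, 10
```

Sorting on the integer index keeps the shard order stable from run to run, which helps when records must line up with rows of a dataframe.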

## 2.6 INITIAL DATAFRAME INSTANTIATION

---


In [None]:
print("\n... INITIAL DATAFRAME INSTANTIATION STARTING ...\n")

# Load the train and submission dataframes
train_df = pd.read_csv(TRAIN_CSV_PATH)
ss_df    = pd.read_csv(SS_CSV_PATH)

# --- Distribution Information ---
N_EX    = len(train_df)
N_TEST  = len(ss_df)
N_VAL   = 80_000 # Fixed from dataset creation information
N_TRAIN = N_EX-N_VAL

# --- Batching Information ---
DEBUG=False
BATCH_SIZE_DEBUG   = 2
REPLICA_BATCH_SIZE = 128 # Could probably be 128

if DEBUG:
    REPLICA_BATCH_SIZE = BATCH_SIZE_DEBUG
OVERALL_BATCH_SIZE = REPLICA_BATCH_SIZE*N_REPLICAS


# --- Input Image Information ---
IMG_SHAPE = (192,384,3)

# --- Autocalculate Training/Validation/Testing Information ---
TRAIN_STEPS = N_TRAIN  // OVERALL_BATCH_SIZE
VAL_STEPS   = N_VAL    // OVERALL_BATCH_SIZE
TEST_STEPS  = int(np.ceil(N_TEST/OVERALL_BATCH_SIZE))

# This is for padding our test dataset so we only have whole batches
REQUIRED_DATASET_PAD = OVERALL_BATCH_SIZE-N_TEST%OVERALL_BATCH_SIZE

# --- Modelling Information ---
ATTN_EMB_DIM  = 192
N_RNN_UNITS   = 512

print(f"\n... # OF TRAIN+VAL EXAMPLES  : {N_EX:<7} ...")
print(f"... # OF TRAIN EXAMPLES      : {N_TRAIN:<7} ...")
print(f"... # OF VALIDATION EXAMPLES : {N_VAL:<7} ...")
print(f"... # OF TEST EXAMPLES       : {N_TEST:<7} ...\n")

print(f"\n... REPLICA BATCH SIZE    : {REPLICA_BATCH_SIZE} ...")
print(f"... OVERALL BATCH SIZE    : {OVERALL_BATCH_SIZE} ...\n")

print(f"\n... IMAGE SHAPE           : {IMG_SHAPE} ...\n")

print(f"\n... TRAIN STEPS PER EPOCH : {TRAIN_STEPS:<5} ...")
print(f"... VAL STEPS PER EPOCH   : {VAL_STEPS:<5} ...")
print(f"... TEST STEPS PER EPOCH  : {TEST_STEPS:<5} ...\n")

print("\n... TRAIN DATAFRAME ...\n")
display(train_df.head(3))

print("\n... SUBMISSION DATAFRAME ...\n")
display(ss_df.head(3))

print("\n... INITIAL DATAFRAME INSTANTIATION COMPLETED...\n")
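
As a worked example of the batching arithmetic above, the sketch below plugs in made-up dataset sizes (they are not the real competition numbers) and follows the same formulas for the step counts and the test-set padding.

```python
import math

# Hypothetical numbers chosen for illustration only (not the real dataset sizes).
n_replicas = 8             # e.g. the eight cores of a single TPU device
replica_batch_size = 128
n_train, n_test = 100_000, 10_000

overall_batch_size = replica_batch_size * n_replicas     # 1024 examples per step

# Training uses floor division, so a trailing partial batch is not counted.
train_steps = n_train // overall_batch_size              # 97

# Inference must cover every example, so the step count is rounded up ...
test_steps = math.ceil(n_test / overall_batch_size)      # 10

# ... and the test set is padded so that the final batch is full.
pad = overall_batch_size - n_test % overall_batch_size   # 240
print(n_test + pad == test_steps * overall_batch_size)   # True
```

Note that when the test-set size is an exact multiple of the batch size, this padding expression evaluates to one full batch rather than zero; `(-n_test) % overall_batch_size` would give zero in that case.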

## 2.7 USER INPUT VARIABLES

---


In [None]:
print("\n... SPECIAL VARIABLE SETUP STARTING ...\n")


# Whether to start training using previously checkpointed model
LOAD_MODEL        = False
ENCODER_CKPT_PATH = ""
TRANSFORMER_CKPT_PATH = ""

if LOAD_MODEL:
    if TRANSFORMER_CKPT_PATH != "":
        print(f"... TRANSFORMER MODEL TRAINING WILL RESUME FROM PREVIOUS CHECKPOINT:\n\t-->{TRANSFORMER_CKPT_PATH}\n")
    elif ENCODER_CKPT_PATH != "":
        print(f"\n... ENCODER MODEL TRAINING WILL RESUME FROM PREVIOUS CHECKPOINT:\n\t-->{ENCODER_CKPT_PATH}\n")    
    else:
        print(f"\n... MODEL TRAINING WILL START FROM SCRATCH ...\n")
else:
    print(f"\n... MODEL TRAINING WILL START FROM SCRATCH ...\n")

    
print("\n... SPECIAL VARIABLE SETUP COMPLETED ...\n")

# 3. HELPER FUNCTION & CLASSES

## 3.1 GENERAL HELPER FUNCTIONS

---

In [None]:
def flatten_l_o_l(nested_list):
    """ Function to flatten a list of lists """
    return [item for sublist in nested_list for item in sublist]


def tf_load_image(path, img_size=(192,384,3), invert=False):
    """ 
    Function to load an image with the desired shape.
    
    Args:
        path (tf.string): Path to the image to be loaded
        img_size (tuple, optional): Shape to reshape the image to (required for TPU)
        invert (bool, optional): Whether or not to invert the background/foreground
    
    Returns:
        img: Tensor containing the image, ready for training/inference
    """
    # `decode_image` (defined below) decodes the PNG bytes and fixes the static shape
    img = decode_image(tf.io.read_file(path), resize_to=img_size)
    if invert:
        # Pixel values lie in [0, 255], so inversion is a subtraction from 255
        img = tf.cast(255, img.dtype) - img
    return img
    
    
def decode_image(image_data, resize_to=(192,384,3)):
    """ 
    Function to decode the tf.string containing image information 
    
    Args:
        image_data (tf.string): String containing encoded image data from tf.Example
        resize_to (tuple, optional): Size that we will reshape the tensor to (required for TPU)
    
    Returns:
        Tensor containing the reshaped three-channel image in the appropriate dtype
    """
    image = tf.image.decode_png(image_data, channels=3)
    image = tf.reshape(image, resize_to)
    return tf.cast(image, TARGET_DTYPE)
    
    
# sparse tensors are required to compute the Levenshtein distance
def dense_to_sparse(dense):
    """
    Function to convert a dense tensor to a sparse tensor 
    
    Args:
        dense (Tensor): A dense tensor
        
    Returns:
        sparse (Tensor): A sparse tensor     
    """
    indices = tf.where(tf.ones_like(dense))
    values = tf.reshape(dense, (MAX_LEN*OVERALL_BATCH_SIZE,))
    sparse = tf.SparseTensor(indices, values, dense.shape)
    return sparse

def get_levenshtein_distance(preds, lbls):
    """ 
    Function to compute the Levenshtein distance between the predictions and labels. 
    
    Args:
        preds (tensor): Batch of predictions
        lbls (tensor): Batch of labels
        
    Returns:
        mean_distance (float): The mean Levenshtein distance calculated across the batch
    """
    preds = tf.where(tf.not_equal(lbls, END_TOKEN) & tf.not_equal(lbls, PAD_TOKEN), preds, 0)
    lbls = tf.where(tf.not_equal(lbls, END_TOKEN), lbls, 0)

    preds_sparse = dense_to_sparse(preds)
    lbls_sparse = dense_to_sparse(lbls)

    batch_distance = tf.edit_distance(preds_sparse, lbls_sparse, normalize=False)
    mean_distance = tf.math.reduce_mean(batch_distance)
    
    return mean_distance
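
For intuition about what `tf.edit_distance(..., normalize=False)` measures, here is a plain-Python reference sketch of the Levenshtein distance. It is only an illustration (the function name `levenshtein` and the toy token sequences are assumptions, not part of the pipeline) and it operates on token sequences, as the helper above does.

```python
def levenshtein(a, b):
    """Minimum number of insertions, deletions and substitutions turning a into b."""
    prev = list(range(len(b) + 1))        # distances from the empty prefix of a
    for i, x in enumerate(a, start=1):
        curr = [i]                        # cost of deleting the first i items of a
        for j, y in enumerate(b, start=1):
            curr.append(min(
                prev[j] + 1,              # delete x
                curr[j - 1] + 1,          # insert y
                prev[j - 1] + (x != y),   # substitute (free when the items match)
            ))
        prev = curr
    return prev[-1]

# Works on any sequences, e.g. lists of token strings.
label = ["C", "2", "H", "6", "O"]
pred = ["C", "2", "H", "5", "O", "H"]
print(levenshtein(pred, label))  # 2: one substitution (5 -> 6) and one deletion (H)
```

The TensorFlow version computes the same quantity for every row of the batch at once, which is why the dense predictions and labels are first converted to sparse tensors.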

# 4. PREPARE THE DATASET  

In this section we prepare the **`tf.data.Dataset`** objects we will use for training and validation.

In [None]:
print("\n\n... STARTING PREPARING VARIABLES FOR DATASET ...\n")

tok_2_int = {c.strip("\\"):i for i,c in enumerate(TOKEN_LIST)}
int_2_tok = {v:k for k,v in tok_2_int.items()}

# Max Length Was previously determined using-
#     >>> MAX_LEN = train_df.InChI.progress_apply(lambda x: len(re.findall("|".join(TOKEN_LIST), x))).max()+1
MAX_LEN = ((train_df.inchi_token_len.max()+1)//2) # //2 yields 138... which is half of max length (speeds up training)
VOCAB_LEN = len(int_2_tok)

print(f"\t--> TOKEN TO INTEGER MAP     : {tok_2_int}")
print(f"\t--> INTEGER TO TOKEN MAP     : {int_2_tok}")
print(f"\t--> MAX # OF TOKENS IN INCHI : {MAX_LEN}")
print(f"\t--> LENGTH OF VOCAB          : {VOCAB_LEN}")

print(f"\n\n\t--> CONVERTED INCHI STRINGS  :")
for i, row in train_df.iloc[:N_VAL].sample(3).iterrows():
    print(f"\n\t\t--> EXAMPLE #{i} FROM THE VALIDATION DATASET")
    print("\t\t\t--> RAW INCHI : ", row["InChI"])

print("\n\n... PREPARING VARIABLES FOR DATASET COMPLETED ...\n")
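
The sketch below shows how the token list and the `re.findall` call mentioned in the comment above can turn an InChI string into integer IDs and back. It rebuilds the vocabulary locally so that it runs on its own. The helper names `encode` and `decode_ids`, the example molecule, and the exact layout (one `<END>` token followed by `<PAD>` tokens) are assumptions made for illustration; the TFRecords themselves were tokenized ahead of time.

```python
import re

# Same vocabulary as TOKEN_LIST above, with raw strings for the regex escapes.
tokens = ["<PAD>", "InChI=1S/", "<END>", "/c", "/h", "/m", "/t", "/b", "/s", "/i"] + \
         ["Si", "Br", "Cl", "F", "I", "N", "O", "P", "S", "C", "H", "B"] + \
         [str(i) for i in range(167, -1, -1)] + \
         [r"\+", r"\(", r"\)", r"\-", ",", "D", "T"]
tok_to_id = {t.strip("\\"): i for i, t in enumerate(tokens)}
id_to_tok = {i: t for t, i in tok_to_id.items()}

# Alternation tries patterns left to right, so "Cl" wins over "C" and "12" over "1".
pattern = re.compile("|".join(tokens))

def encode(inchi, max_len):
    """Tokenize, append <END>, then right-pad with <PAD> (assumed layout)."""
    ids = [tok_to_id[t] for t in pattern.findall(inchi)] + [tok_to_id["<END>"]]
    return ids + [tok_to_id["<PAD>"]] * (max_len - len(ids))

def decode_ids(ids):
    """Drop <PAD>/<END> and join the remaining tokens back into a string."""
    skip = {tok_to_id["<PAD>"], tok_to_id["<END>"]}
    return "".join(id_to_tok[i] for i in ids if i not in skip)

inchi = "InChI=1S/C2H6O/c1-2-3/h3H,2H2,1H3"  # ethanol
ids = encode(inchi, max_len=32)
print(ids[:6])                    # IDs of the first six tokens
print(decode_ids(ids) == inchi)   # True: the round trip is lossless
```

The ordering of the token list matters: multi-character tokens such as `Si` and `Cl`, and the larger numbers, must come before their single-character prefixes for the alternation to pick the longest match.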

## 4.1 READ TFRECORD FILES - CREATE THE RAW DATASET(S)

---

Here we will leverage **`tf.data.TFRecordDataset`** to read the TFRecord files.
* The simplest way is to specify a list of filenames (paths) of TFRecord files.
* It is a subclass of **`tf.data.Dataset`**.

This newly created raw dataset contains **`tf.train.Example`** messages, and when iterated over it, we get scalar string tensors.

In [None]:
print("\n... CREATE TFRECORD RAW DATASETS STARTING ...\n")

# Create tf.data.Dataset from filepaths for conversion later
raw_train_ds = tf.data.TFRecordDataset(TRAIN_TFREC_PATHS, num_parallel_reads=None)
raw_val_ds = tf.data.TFRecordDataset(VAL_TFREC_PATHS, num_parallel_reads=None)
raw_test_ds = tf.data.TFRecordDataset(TEST_TFREC_PATHS, num_parallel_reads=None)


print(f"\n... THE RAW TF.DATA.TFRECORDDATASET OBJECT:\n\t--> {raw_train_ds}\n")

print("\n... CREATE TFRECORD RAW DATASETS COMPLETED ...\n")

## 4.2 WHAT TO DO IF YOU DON'T KNOW THE FEATURE DESCRIPTIONS OF THE DATASET?

---

If you are the author who created the TFRecord files, you definitely know how to define the feature description to parse the raw dataset.

Otherwise, you can use something like

```python
example = tf.train.Example()
example.ParseFromString(serialized_example.numpy())
```

to inspect a record. You will get something like

```python
features {
    feature {
        key: "class"
        value {
            int64_list {
                value: 57
            }
        }
    }
    feature {
        key: "id"
        value {
            bytes_list {
                value: "338ab7bac"
            }
        }
    }
    feature {
        key: "image"
        value {
            bytes_list {
                value: ...
            }
        }
    }
    ...
}
```

This should give you enough information to define the feature description.

In [None]:
print("\n... RAW TFRECORD INVESTIGATION TO DETERMINE FEATURE DESCRIPTIONS STARTED ...\n")

print("\n... EXAMPLE OF TRUNCATED RAW TFRECORD/TFEXAMPLE FROM TRAINING DATASET TO SHOW HOW TO FIND FEATURE DESCRIPTIONS:\n")

# Example
for raw in raw_train_ds.take(1):
    example = tf.train.Example()
    example.ParseFromString(raw.numpy())
    for i, (k,v) in enumerate(example.features.feature.items()):
        print(f"\tFEATURE #{i+1}")
        print(f"\t\t--> KEY = {k}")
        # Each feature stores exactly one of bytes_list / float_list / int64_list
        kind = v.WhichOneof("kind")
        if kind == "int64_list":
            print(f"\t\t\t--> TRUNCATED-VALUE = {v.int64_list.value[:15]} ...\n")
        elif kind == "bytes_list":
            print(f"\t\t\t--> TRUNCATED-VALUE = {v.bytes_list.value[0][:25]} ...\n")
        else:
            print(f"\t\t\t--> TRUNCATED-VALUE = {v.float_list.value[:15]} ...\n")

print("\n... RAW TFRECORD INVESTIGATION TO DETERMINE FEATURE DESCRIPTIONS COMPLETED ...\n")

## 4.3 PARSE THE RAW DATASET(S)

---


The general recipe to parse the string tensors in the raw dataset looks something like this:

<br>

**STEP 1.**  Create a description of the features. For example:

```python
feature_description = {    
    'feature0': tf.io.FixedLenFeature([], tf.int64),
    'feature1': tf.io.FixedLenFeature([], tf.string),
    'feature2': tf.io.FixedLenFeature([], tf.float32),
    ...
}
```

<br>

**STEP 2.**  Define a parsing function by using `tf.io.parse_single_example` and the defined feature description.
```python
def _parse_function(example):
    """
    Args:
        example: A string tensor representing a `tf.train.Example`.
    """

    # Parse `example`.
    parsed_example = tf.io.parse_single_example(example, feature_description)

    return parsed_example
```

<br>

**STEP 3.**  Map the raw dataset by `_parse_function`.
```python
dataset = raw_dataset.map(_parse_function)
```

<br>

---

<br>

**In the following cell, we apply the above recipe to our BMS tfrecord dataset.**

<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📌 &nbsp; NOTE:</b><br><br>- The parsed images are <code><b>`tf.string`</b></code>, which are then decoded with <code><b>`tf.image.decode_png`</b></code> which is an alias for <code><b>`tf.io.decode_png`</b></code><br>- The InChI strings and Image IDs will just be left as byte string tensors.
</div>
<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCE:</b><br><br>
    - <a href="https://www.tensorflow.org/tutorials/load_data/tfrecord"><b>Tutorial - TFRecord and tf.Example</b></a><br>
    - <a href="https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset"><b>TFRecordDataset Documentation</b></a><br>
    - <a href="https://www.tensorflow.org/api_docs/python/tf/io/decode_png"><b>Decoding PNGs Documentation</b></a><br>
</div>


In [None]:
def decode(serialized_example, is_test=False, tokenized_inchi=True):
    """ 
    Function to parse a set of features and a label from the given `serialized_example`.
    It is used as a map function for `dataset.map`

    Args:
        serialized_example (tf.Example): A serialized example containing the
            following features:
                – 'image'
                – 'image_id'
                – 'inchi'
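        is_test (bool, optional): Whether the example comes from the test set
            (the image ID is returned instead of the InChI target)
        tokenized_inchi (bool, optional): Whether the InChI is stored as a
            fixed-length sequence of integer tokens rather than a raw string
        
    Returns:
        A tuple of (image, target) for train/val examples or
        (image, image_id) for test examples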
    """
    feature_dict = {
        'image': tf.io.FixedLenFeature(shape=[],
                                       dtype=tf.string, 
                                       default_value=''),
    }
    
    if not is_test:
        if tokenized_inchi:
            feature_dict["inchi"] = tf.io.FixedLenFeature(shape=[MAX_LEN], 
                                                        dtype=tf.int64, 
                                                        default_value=[0]*MAX_LEN)
        else:
            feature_dict["inchi"] = tf.io.FixedLenFeature(shape=[], 
                                                        dtype=tf.string, 
                                                        default_value='')
    else:
        feature_dict["image_id"] = tf.io.FixedLenFeature(shape=[], 
                                                        dtype=tf.string, 
                                                        default_value='')
    
    # Define a parser
    features = tf.io.parse_single_example(serialized_example, features=feature_dict)
    
    # Decode the tf.string
    image = decode_image(features['image'], resize_to=IMG_SHAPE)
    
    # Figure out the correct information to return
    if is_test:
        image_id = features["image_id"] 
        return image, image_id
    else:
        if tokenized_inchi:
            target = tf.cast(features["inchi"], tf.uint8)
        else:
            target = features["inchi"]
        return image, target

In [None]:
print("\n... DECODING RAW TFRECORD DATASETS STARTING ...\n")

# Decode the tfrecords completely –– decode is our `_parse_function` 
train_ds = raw_train_ds.map(lambda x: decode(x, is_test=False))
val_ds = raw_val_ds.map(lambda x: decode(x, is_test=False))
test_ds = raw_test_ds.map(lambda x: decode(x, is_test=True))

print(f"\n... THE DECODED TF.DATA.TFRECORDDATASET OBJECT:" \
      f"\n\t--> ((image), (image_id - optional), (inchi))" \
      f"\n\t--> {train_ds}\n")

print("\n... 2 EXAMPLES OF IMAGES AND LABELS AFTER DECODING ...")
for i, (img, inchi) in enumerate(train_ds.take(2)):
    print(f"\nIMAGE SHAPE : {img.shape}")
    print(f"IMAGE INCHI : {[int_2_tok[x] for x in inchi.numpy()]}\n")
    plt.figure(figsize=(10,10))
    plt.imshow(img.numpy().astype(np.int64), cmap="gray")
    plt.title(f"{''.join([int_2_tok[x] for x in inchi.numpy() if x!=0][:50])} ... [truncated]")
    plt.show()

print("\n... DECODING RAW TFRECORD DATASETS COMPLETED ...\n")

## 4.4 WORKING WITH `TF.DATA.DATASET` OBJECTS

---

With the parsing functions above defined, we can now control how the dataset is loaded and apply shuffling, batching, etc. The following methods and attributes are of particular interest to us:
* Use **`num_parallel_reads`** in **`tf.data.TFRecordDataset`** to read files in parallel.
* Set **`tf.data.Options.experimental_deterministic=False`** and apply it with **`with_options()`** to get a new dataset that ignores the order of elements.
* Use **`num_parallel_calls`** in the **`tf.data.Dataset.map()`** method to parse elements in parallel.
* Use **`tf.data.Dataset.prefetch()`** to allow later batches to be prepared while the current batch is being processed.
* Use **`tf.data.AUTOTUNE`** to automatically determine parallelization argument values.

Parallel processing and prefetching are particularly important when working with a TPU:
* A TPU can process batches very quickly.
* The dataset pipeline must supply data fast enough to keep up; otherwise the TPU will sit idle.

**In the cell below we create the functions and the configuration template that will later be used to create our respective datasets.**

<br>
<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCE:</b><br><br>
    - <a href="https://www.tensorflow.org/guide/data"><b>Guide - tf.data: Build TensorFlow Input Pipelines</b></a><br>
    - <a href="https://www.tensorflow.org/guide/data_performance"><b>Guide - Better Performance With the tf.data API</b></a><br>
    - <a href="https://www.tensorflow.org/api_docs/python/tf/data/Dataset"><b>tf.data.Dataset Documentation</b></a><br>
</div>

In [None]:
def load_dataset(filenames, is_test=False, ordered=False, tokenized_inchi=True):
    """
    Function to read the dataset from TFRecords.
    For optimal performance, it reads from multiple files at once and 
    disregards data order (if `ordered=False`).
        - If pulling InChI from TFRecords then order does not matter since we  
          will be shuffling the data anyway (for training dataset).
          
    Args:
        filenames (list of strings): List of paths that point to the 
                                    respective TFRecord files
        is_test (bool, optional): Whether to return the image ID (True) or 
                                the label (False) alongside each image
        ordered (bool, optional): Whether to ensure ordered results or 
                                maximize parallelization
        tokenized_inchi (bool, optional): Whether our dataset includes the 
                                        tokenized inchi or we will be creating 
                                        it from the caption numpy array
        
    Returns:
        dataset: Decoded tf.data.Dataset object
    """

    options = tf.data.Options()
    if not ordered:
        # disable order, increase speed
        options.experimental_deterministic = False
        N_PARALLEL=tf.data.AUTOTUNE
    else:
        N_PARALLEL=None
        
    # If not-ordered, this will read in by automatically interleaving multiple tfrecord files.
    dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=N_PARALLEL)
    
    # If not-ordered, this will ensure that we use data as soon as it 
    # streams in, rather than in its original order.
    dataset = dataset.with_options(options) 
    
    # parse and return a dataset w/ the appropriate configuration
    dataset = dataset.map(
        lambda x: decode(x, is_test, tokenized_inchi),
        num_parallel_calls=N_PARALLEL,
    )
    
    return dataset

def get_dataset(filenames, batch_size, is_test=False, shuffle_buffer_size=1, 
                repeat_dataset=True, preserve_file_order=False, 
                drop_remainder=True, tokenized_inchi=True,
                external_inchi_dataset=None, test_padding=0):
    """ 
    Function to get a tf.data.Dataset with the appropriate configuration.
    
    Args:
        filenames (list of strings): List of paths that point to the respective TFRecord files
        batch_size (int): Batch size to be used during batching the dataset
        is_test (bool, optional): Whether to return the image ID (True) or the label (False) alongside each image
        shuffle_buffer_size (int, optional): Number of elements from which the new dataset will be sampled
        repeat_dataset (bool, optional): Whether the dataset is to be repeated 
        preserve_file_order (bool, optional): Whether to ensure ordered results or maximize parallelization
        drop_remainder (bool, optional): Whether the last batch should be dropped 
                                         in case its size is smaller than desired
        tokenized_inchi (bool, optional): Whether our dataset includes the 
                                          tokenized inchi or we will be creating 
                                          it from the caption numpy array
        external_inchi_dataset (obj or None): None if no external inchi dataset is uploaded 
        test_padding (int, optional): Amount required to pad dataset to have only full batches
        
    Returns:
        dataset: new tf.data.Dataset object
    """
    # Load the dataset
    dataset = load_dataset(filenames, is_test, preserve_file_order, tokenized_inchi)
    
    if test_padding!=0:
        pad_dataset = tf.data.Dataset.from_tensor_slices((
            tf.zeros((test_padding, *IMG_SHAPE), dtype=TARGET_DTYPE),       # Fake Images
            tf.constant(["000000000000",]*test_padding, dtype=tf.string))   # Fake IDs
        )
        dataset = dataset.concatenate(pad_dataset)
    
    # If we are training then we will want to repeat the dataset. 
    # We will determine the number of steps (or updates) later for 1 training epoch.
    if repeat_dataset:
        dataset = dataset.repeat()
    
    # If we need to add on manually the inchi
    if external_inchi_dataset is not None:
        # Zip the datasets and tile the 1 channel image to 3 channels & drop the old inchi value
        dataset = tf.data.Dataset.zip((dataset, external_inchi_dataset))
        dataset = dataset.map(lambda x,y: (tf.tile(tf.expand_dims(x[0], -1), tf.constant([1,1,3], tf.int32)), y))
                              
    # Shuffling
    if shuffle_buffer_size!=1:
        dataset = dataset.shuffle(shuffle_buffer_size)
    
    # Batching
    dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
    
    # prefetch next batch while training (autotune prefetch buffer size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)
    
    return dataset

In [None]:
# Template Configuration
DS_TEMPLATE_CONFIG = dict(
    filenames=[],
    batch_size=1,
    is_test=False, 
    shuffle_buffer_size=1, 
    repeat_dataset=True, 
    preserve_file_order=False, 
    drop_remainder=True,
    tokenized_inchi=True,
    external_inchi_dataset=None,
    test_padding=0
)

# Individual Respective Configurations
TRAIN_DS_CONFIG = DS_TEMPLATE_CONFIG.copy()
TRAIN_DS_CONFIG.update(dict(
    filenames=TRAIN_TFREC_PATHS,
    batch_size=OVERALL_BATCH_SIZE,
    shuffle_buffer_size=OVERALL_BATCH_SIZE*6,
))

VAL_DS_CONFIG = DS_TEMPLATE_CONFIG.copy()
VAL_DS_CONFIG.update(dict(
    filenames=VAL_TFREC_PATHS,
    batch_size=OVERALL_BATCH_SIZE,
))

TEST_DS_CONFIG = DS_TEMPLATE_CONFIG.copy()
TEST_DS_CONFIG.update(dict(
    filenames=TEST_TFREC_PATHS,
    batch_size=OVERALL_BATCH_SIZE,
    is_test=True,
    repeat_dataset=False,
    drop_remainder=True,
    test_padding=REQUIRED_DATASET_PAD,
))

####### ####### ####### ####### ####### ####### ####### #######

train_ds = get_dataset(**TRAIN_DS_CONFIG)
val_ds = get_dataset(**VAL_DS_CONFIG)
test_ds = get_dataset(**TEST_DS_CONFIG)

for SPLIT, CONFIG in zip(["TRAINING", "VALIDATION", "TESTING"], [TRAIN_DS_CONFIG, 
                                                                VAL_DS_CONFIG, 
                                                                TEST_DS_CONFIG]): 
    print(f"\n... {SPLIT} CONFIGURATION:")
    for k,v in CONFIG.items():
        if k=="filenames":
            print(f"\t--> {k:<23}: {[path.split('/', 4)[-1] for path in v[:2]]+['...']}")
        else:
            print(f"\t--> {k:<23}: {v}")

print(f"\n\n... TRAINING DATASET   : {train_ds} ...")
print(f"... VALIDATION DATASET : {val_ds} ...")
print(f"... TESTING DATASET    : {test_ds}    ...\n")

print("\n\n ... SOME VALIDATION EXAMPLES ... \n\n")
for x,y in val_ds.take(1):
    for i in range(2):
        plt.figure(figsize=(12,12))
        plt.imshow(x[i].numpy().astype(np.int64))
        plt.title(f"IMAGE INCHI : {''.join([int_2_tok[z] for z in y[i].numpy() if z not in [0,1,2]])}\n")
        plt.show()
        
print("\n\n ... SOME TESTING EXAMPLES ... \n\n")
for x,y in test_ds.take(1):
    for i in range(2):
        plt.figure(figsize=(12,12))
        plt.imshow(x[i].numpy().astype(np.int64))
        plt.title(f"{y[i].numpy().decode()}")
        plt.show()

# 5.  MODEL PREPARATION


In this section we prepare the models for training. More information on Vision Transformers can be found in the [Google Research Vision Transformer implementation](https://github.com/google-research/vision_transformer).

<br>

<center><img src="https://github.com/google-research/vision_transformer/raw/main/vit_figure.png" width=50%></center>

<br>

## 5.1 UNDERSTANDING THE MODELS - ViT


### 5.1.1 ViT - Implement Patch Creation as a Layer

In [None]:
class PatchCreator(tf.keras.layers.Layer):
    ''' Creates Patches for input images '''
    def __init__(self, patch_size):
        '''
        Args:
            patch_size (int): Size of the patches to be extracted from the 
                              input images
        
        Returns:
            None, this is initialization
        '''
        super(PatchCreator, self).__init__()
        self.patch_size = patch_size

    def call(self, images):
        ''' Calling function for patch creation class  '''
        batch_size = tf.shape(images)[0]
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        patch_dims = patches.shape[-1]
        patches = tf.reshape(patches, [batch_size, -1, patch_dims])
        return patches


In [None]:
# Grab a demo image and label and define an arbitrary patch_size
demo_img, demo_lbl = next(iter(train_ds.unbatch().batch(1)))
demo_patch_size=16

# Instantiate the PatchCreator layer and call on the demo image
with tf.device('/CPU:0'):
    patch_creator = PatchCreator(demo_patch_size)
    patches = patch_creator(demo_img)

print(f"Image size: {IMG_SHAPE}")
print(f"Patch size: {(demo_patch_size, demo_patch_size)}")
print(f"Patches shape: {patches.shape}")
print(f"Patches per image: {patches.shape[1]}")
print(f"Elements per patch: {patches.shape[-1]}")

In [None]:
# PATCH LAYER FUNCTIONALITY

# 1. Plot the original image
print("\n\n... ORIGINAL IMAGE ...\n")
plt.figure(figsize=(18,9))
plt.imshow(demo_img[0].numpy().astype("float32")/255.)
plt.axis('off')
plt.tight_layout()
plt.show()

# 2. Plot the patches in the same shape/order as the original image
print("\n\n\n... IMAGE PATCHES ...\n")
plt.figure(figsize=(18,9))
for i, patch in enumerate(patches[0]):
    ax = plt.subplot(int(np.ceil(IMG_SHAPE[0]/demo_patch_size)), 
                     int(np.ceil(IMG_SHAPE[1]/demo_patch_size)), i + 1)
    patch_img = tf.reshape(patch, (demo_patch_size, demo_patch_size, 3))
    plt.imshow(patch_img.numpy().astype("float32")/255.)
    plt.axis("off")
plt.tight_layout()
plt.subplots_adjust(wspace=0.03, hspace=0.06)
plt.show()

### 5.1.2 ViT - Implement Patch Encoder as a Layer


---

The **`PatchEncoder`** layer will linearly transform a patch by projecting it into a vector of size **`projection_dim`**. In addition, it adds a learnable position embedding to the projected vector.
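
The projection-plus-position step can be illustrated with a small NumPy sketch. This is not the Keras layer used in this notebook: the toy sizes and the randomly initialised `w`, `b` and `position_table` arrays are hypothetical stand-ins for the learnable `Dense` and `Embedding` parameters, chosen only to show the shapes involved.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical toy sizes, chosen only for illustration
batch_size, num_patches, patch_dims, projection_dim = 2, 6, 12, 4

# A batch of flattened patches: (batch, num_patches, patch_dims)
patches = rng.normal(size=(batch_size, num_patches, patch_dims))

# Stand-ins for the learnable parameters (assumed, not trained)
w = rng.normal(size=(patch_dims, projection_dim))                # Dense kernel
b = np.zeros(projection_dim)                                     # Dense bias
position_table = rng.normal(size=(num_patches, projection_dim))  # one row per patch index

# Linear projection of every patch, then add the position vector of its index
projected = patches @ w + b           # (batch, num_patches, projection_dim)
encoded = projected + position_table  # the table broadcasts over the batch axis

print(encoded.shape)  # (2, 6, 4)
```

Because the position table has one row per patch index, the same offset is added to patch `i` of every image in the batch.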

In [None]:
class PatchEncoder(tf.keras.layers.Layer):
    def __init__(self, num_patches, projection_dim):
        '''
        Implements the PatchEncoder Block

        Args:
            num_patches (int): Number of input image patches
            projection_dim (int): Dimension of the output projection vector
        
        Returns:
            None; this is initialization
        '''
        super(PatchEncoder, self).__init__()
        
        self.num_patches = num_patches
        self.projection_dim = projection_dim
        
        self.dense_projection = tf.keras.layers.Dense(units=self.projection_dim)
        self.positions = tf.reshape(tf.range(start=0, limit=self.num_patches, delta=1), (self.num_patches,))
        self.position_embedding = tf.keras.layers.Embedding(input_dim=self.num_patches, 
                                                            output_dim=self.projection_dim)

    def call(self, patch):
        ''' Returns the encoded patches of image '''
        encoded = self.dense_projection(patch) + self.position_embedding(self.positions)
        return encoded
    

In [None]:
# Define an arbitrary projection_dim
demo_projection_dim = 128

# Instantiate the PatchEncoder layer and call on the demo patches
with tf.device('/CPU:0'):
    patch_encoder = PatchEncoder(num_patches=patches.shape[1], projection_dim=demo_projection_dim)
    encoded_patches = patch_encoder(patches)

print(f"Number of Patches: {patches.shape[1]}")    
print(f"Patch size: {(demo_patch_size, demo_patch_size)}")
print(f"Encoded patches shape: {encoded_patches.shape}")

# PATCH LAYER FUNCTIONALITY

# 1. Plot the original image
print("\n\n... ORIGINAL IMAGE ...\n")
plt.figure(figsize=(18,9))
plt.imshow(demo_img[0].numpy().astype("float32")/255.)
plt.axis('off')
plt.tight_layout()
plt.show()

# 2. Plot the patches in the same shape/order as the original image
print("\n\n\n... IMAGE PATCHES ...\n")
plt.figure(figsize=(18,9))
for i, patch in enumerate(patches[0]):
    w,h = int(np.ceil(IMG_SHAPE[0]/demo_patch_size)), int(np.ceil(IMG_SHAPE[1]/demo_patch_size))
    ax = plt.subplot(w,h,i+1)
    patch_img = tf.reshape(patch, (demo_patch_size, demo_patch_size, 3))
    plt.imshow(patch_img.numpy().astype("float32")/255.)
    plt.axis("off")
plt.tight_layout()
plt.subplots_adjust(wspace=0.03, hspace=0.06)
plt.show()

with tf.device('/CPU:0'):
    enc_patches_rescaled = patch_creator(tf.expand_dims(tf.image.resize(tf.reshape(tf.reduce_mean(encoded_patches[0], axis=-1), (w, h, 1)), (IMG_SHAPE[0], IMG_SHAPE[1])), axis=0))

# 3. Plot the image embedding patches
print("\n\n\n... IMAGE EMBEDDING AS PATCH VISUALIZATION ...\n")
plt.figure(figsize=(18,9))
for i, enc_patch in enumerate(enc_patches_rescaled[0]):
    ax = plt.subplot(w,h,i+1)
    enc_patch_img = tf.reshape(enc_patch, (demo_patch_size, demo_patch_size))
    plt.imshow(enc_patch_img.numpy().astype("float32")/255., cmap="jet")
    plt.axis("off")
plt.tight_layout()
plt.subplots_adjust(wspace=0.03, hspace=0.06)
plt.show()


### 5.1.3 Build the ViT Model

---

The ViT model consists of multiple Transformer blocks, which use the **`tf.keras.layers.MultiHeadAttention`** layer as a self-attention mechanism applied to the sequence of patches. The Transformer blocks produce a **[batch_size, num_patches, projection_dim]** tensor. In a ViT classifier, this tensor is processed by a classifier head with softmax to produce the final class probabilities. In this notebook, the **`ViTEncoder`** returns the tensor as-is, so that it can serve as the image 'sequence' consumed by the Transformer in section 5.2.

For reference, the ViT paper prepends a learnable embedding to the sequence of encoded patches to serve as the image representation. A common alternative for classification is to reshape all the outputs of the final Transformer block with **`tf.keras.layers.Flatten()`** and use them as the image representation input to the classifier head. The **`tf.keras.layers.GlobalAveragePooling1D`** layer could also be used to aggregate the outputs of the Transformer block, especially when the number of patches and the projection dimension are large.
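
To make the difference between the two aggregation options concrete, here is a minimal NumPy sketch. It assumes hypothetical toy sizes and uses `reshape` and `mean` as stand-ins for `Flatten` and `GlobalAveragePooling1D`; it only illustrates the resulting shapes and is not part of the model built in this notebook.

```python
import numpy as np

# Hypothetical toy sizes, chosen only for illustration
batch_size, num_patches, projection_dim = 2, 6, 4

# Fake encoder output: (batch, num_patches, projection_dim)
encoder_output = np.arange(batch_size * num_patches * projection_dim, dtype=np.float32)
encoder_output = encoder_output.reshape(batch_size, num_patches, projection_dim)

# Flatten-style aggregation keeps every value: size grows with num_patches
flattened = encoder_output.reshape(batch_size, -1)

# Average-pooling-style aggregation averages over the patch axis
pooled = encoder_output.mean(axis=1)

print(flattened.shape)  # (2, 24)
print(pooled.shape)     # (2, 4)
```

The flattened representation grows linearly with the number of patches, while the pooled one always has `projection_dim` features.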

In [None]:
class ViTEncoder(tf.keras.Model):
    ''' Creates the stack of encoders '''
    def __init__(self, patch_size=16, projection_dim=256, n_transformer_layers=8,
                n_heads=4, dropout=0.1, img_shape=IMG_SHAPE):
        """
        Initializes the variables needed to create the model
        
        Args:
            patch_size (int): Size of the patches to be extracted from the input images
            projection_dim (int): Dimension of the output projection vector
            n_transformer_layers (int): Number of transformer layers in the architecture
            n_heads (int): Number of attention heads to be used
            dropout (float): Dropout rate used in the attention and MLP dropout layers
            img_shape (tuple of ints): Shape of the input image

        Returns:
            None, this is an initialization     
        """
        super(ViTEncoder, self).__init__()
        
        # Layer Arguments
        self.patch_size = patch_size
        self.n_patches = tf.cast(tf.round((img_shape[0]/self.patch_size)*(img_shape[1]/self.patch_size)), tf.int32)
        self.img_shape = img_shape
        self.projection_dim = projection_dim
        self.n_transformer_layers = n_transformer_layers
        self.mlp_intermediate_units = [self.projection_dim*2, self.projection_dim,]
        self.n_heads = n_heads
        self.dropout = dropout
        
        # Layers
        self.patch_creator = PatchCreator(self.patch_size)
        self.patch_encoder = PatchEncoder(self.n_patches, self.projection_dim)
        self.ln_1_layer = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.ln_2_layer = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.add_1_layer = tf.keras.layers.Add()
        self.add_2_layer = tf.keras.layers.Add()
        self.mha_layer = tf.keras.layers.MultiHeadAttention(num_heads=self.n_heads, 
                                                            key_dim=self.projection_dim, 
                                                            dropout=self.dropout)
        self.intermediate_mlp_dense_layers   = [tf.keras.layers.Dense(transformer_units, activation=tf.nn.gelu) \
                                                for transformer_units in self.mlp_intermediate_units]
        self.intermediate_mlp_dropout_layers = [tf.keras.layers.Dropout(self.dropout) \
                                                for transformer_units in self.mlp_intermediate_units]                
        
    def call(self, x, training):
        """Runs the stack of encoder blocks on a batch of images
        
        Args:
            x (tensor): Batch of input images
            training (bool): Whether the call is in training mode (enables dropout)
        
        Returns:
            Encoded patches with shape [batch_size, num_patches, projection_dim]
        """
        patches = self.patch_creator(x, training=training)
        encoded = self.patch_encoder(patches, training=training)
        
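        # NOTE: The layers are created once in __init__, so every pass
        # through this loop reuses the same weights.
        for _ in range(self.n_transformer_layers):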
            
            # Layer Norm 1
            x1 = self.ln_1_layer(encoded, training=training)
            
            # Create a multi-head attention layer.
            attention_output = self.mha_layer(x1, x1, training=training)
            
            # Skip Connection 1
            x2 = self.add_1_layer([attention_output, encoded], training=training)
            
            # Layer Norm 2
            x3 = self.ln_2_layer(x2, training=training)
            
            # Intermediate MLP
            for i in range(len(self.mlp_intermediate_units)):
                x3 = self.intermediate_mlp_dense_layers[i](x3, training=training)
                x3 = self.intermediate_mlp_dropout_layers[i](x3, training=training)
            
            # Skip Connection 2
            encoded = self.add_2_layer([x3, x2], training=training)
            
        return encoded

In [None]:
with tf.device('/CPU:0'):
    ViT = ViTEncoder(patch_size=demo_patch_size, projection_dim=demo_projection_dim)
    demo_encoder_output= ViT(tf.ones((1, *IMG_SHAPE)))
    IMG_SEQ_LEN, IMG_EMB_DEPTH = demo_encoder_output.shape[1], demo_encoder_output.shape[2]
    
print(f"Encoder Output Shape: {demo_encoder_output.shape}")    
print(f"Output 'Sequence' Length: {IMG_SEQ_LEN}")
print(f"Output 'Sequence' Feature Depth: {IMG_EMB_DEPTH}")

# PATCH LAYER FUNCTIONALITY

# 1. Plot the original image
print("\n\n... ORIGINAL IMAGE ...\n")
plt.figure(figsize=(18,9))
plt.imshow(demo_img[0].numpy().astype("float32")/255.)
plt.axis('off')
plt.tight_layout()
plt.show()

# 2. Plot the patches in the same shape/order as the original image
print("\n\n\n... IMAGE PATCHES ...\n")
plt.figure(figsize=(18,9))
for i, patch in enumerate(patches[0]):
    w,h = int(np.ceil(IMG_SHAPE[0]/demo_patch_size)), int(np.ceil(IMG_SHAPE[1]/demo_patch_size))
    ax = plt.subplot(w,h,i+1)
    patch_img = tf.reshape(patch, (demo_patch_size, demo_patch_size, 3))
    plt.imshow(patch_img.numpy().astype("float32")/255.)
    plt.axis("off")
plt.tight_layout()
plt.subplots_adjust(wspace=0.03, hspace=0.06)
plt.show()

with tf.device('/CPU:0'):
    enc_patches_rescaled = patch_creator(tf.expand_dims(tf.image.resize(tf.reshape(tf.reduce_mean(encoded_patches[0], axis=-1), (w, h, 1)), (IMG_SHAPE[0], IMG_SHAPE[1])), axis=0))

# 3. Plot the image embedding patches
print("\n\n\n... IMAGE EMBEDDING AS PATCH VISUALIZATION ...\n")
plt.figure(figsize=(18,9))
for i, enc_patch in enumerate(enc_patches_rescaled[0]):
    ax = plt.subplot(w,h,i+1)
    enc_patch_img = tf.reshape(enc_patch, (demo_patch_size, demo_patch_size))
    plt.imshow(enc_patch_img.numpy().astype("float32")/255., cmap="jet")
    plt.axis("off")
plt.tight_layout()
plt.subplots_adjust(wspace=0.03, hspace=0.06)
plt.show()

with tf.device('/CPU:0'):
    tmp_img = tf.image.resize(tf.reshape(tf.reduce_mean(demo_encoder_output[0], axis=-1), (w, h, 1)), (IMG_SHAPE[0], IMG_SHAPE[1]))
    tmp_img = tf.cast(255*((tmp_img-tf.math.reduce_min(tmp_img))/(tf.math.reduce_max(tmp_img)-tf.math.reduce_min(tmp_img))), dtype=tf.uint8)
    enc_output_patches_rescaled = patch_creator(tf.expand_dims(tmp_img, axis=0))

# 4. Plot the encoder output as patches
print("\n\n\n... ENCODER OUTPUT AS PATCH VISUALIZATION ...\n")
plt.figure(figsize=(18,9))
for i, enc_out_patch in enumerate(enc_output_patches_rescaled[0]):
    ax = plt.subplot(w,h,i+1)
    enc_out_patch_img = tf.reshape(enc_out_patch, (demo_patch_size, demo_patch_size))
    plt.imshow(enc_out_patch_img.numpy().astype("float32")/255., cmap="jet")
    plt.axis("off")
plt.tight_layout()
plt.subplots_adjust(wspace=0.03, hspace=0.06)
plt.show()

print("\n\n\n\n\tSUMMARY\n")
ViT.summary()

## 5.2 UNDERSTANDING THE MODELS - TRANSFORMER

### 5.2.0 TRANSFORMER - HYPERPARAMETERS

In [None]:
D_MODEL = IMG_EMB_DEPTH
N_PE_POS = 72
D_FF = 1024

print(f"\n... THE INPUT 'SEQUENCE LENGTH'                  IS {IMG_SEQ_LEN}  (output of image encoder - shape flattened) ...")
print(f"... THE INPUT 'EMBEDDING DEPTH'                  IS {IMG_EMB_DEPTH}  (output of image encoder - # of channels) ...")
print(f"... THE NUMBER OF POSITIONAL ENCODING POSITIONS  IS {N_PE_POS}   (arbitrary) ...\n")

### 5.2.1 TRANSFORMER - POSITIONAL ENCODING

---

Since this model doesn't contain any recurrence or convolution, positional encoding is added to give the model some information about the relative position of the words in the sentence. 

The positional encoding vector is added to the embedding vector. 
* Embeddings represent a token in a **`d-dimensional`** space where tokens (encoded vectors) with similar meaning (feature representation) will be closer to each other. 

But the embeddings do not encode the relative position of words in a sentence (or in our case the localization of features as encoded by our **ViT encoder model**).
* So after adding the positional encoding, words (feature representations) will be closer to each other based on the ***similarity of their meaning and their position in the sentence (feature vector)***, in the **`d-dimensional`** space.

See the TensorFlow tutorial section on **[positional encoding](https://www.tensorflow.org/tutorials/text/transformer#positional_encoding)** to learn more about it. The formula for calculating the positional encoding is as follows:

---

$$\Large{PE_{(pos, 2i)} = \sin(pos / 10000^{2i / d_{model}})} $$
$$\Large{PE_{(pos, 2i+1)} = \cos(pos / 10000^{2i / d_{model}})} $$

---
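
As a worked example of the formula, here is a minimal NumPy sketch. The function name `sinusoidal_positional_encoding` and the toy sizes are hypothetical; the sketch mirrors the formula rather than the TensorFlow helpers used in this notebook.

```python
import numpy as np

def sinusoidal_positional_encoding(n_positions, d_model):
    """Return an (n_positions, d_model) array of sinusoidal encodings."""
    pos = np.arange(n_positions)[:, np.newaxis]  # (n_positions, 1)
    col = np.arange(d_model)[np.newaxis, :]      # (1, d_model)

    # Columns 2i and 2i+1 share the same rate: 1 / 10000^(2i / d_model)
    angle_rates = 1.0 / np.power(10000.0, (2 * (col // 2)) / d_model)
    angles = pos * angle_rates

    pe = np.zeros((n_positions, d_model))
    pe[:, 0::2] = np.sin(angles[:, 0::2])  # even columns: sin
    pe[:, 1::2] = np.cos(angles[:, 1::2])  # odd columns: cos
    return pe

pe = sinusoidal_positional_encoding(n_positions=4, d_model=6)
print(pe.shape)  # (4, 6)
print(pe[0])     # position 0: sin(0) = 0 on even columns, cos(0) = 1 on odd columns
```

Every position gets a distinct vector, and the low-index columns vary fastest with position, which is what produces the banded pattern in the usual heatmap visualization.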

In [None]:
def get_angles(pos, i, d_model):
    ''' Function to calculate the angles (position * rate) fed to sin/cos in the positional encoding ''' 
    angle_rates = tf.constant(1, TARGET_DTYPE) / tf.math.pow(tf.constant(10000, TARGET_DTYPE), 
                                                            (tf.constant(2, dtype=TARGET_DTYPE) * tf.cast((i//2), 
                                                            TARGET_DTYPE))/d_model)
    return pos * angle_rates

def do_interleave(arr_a, arr_b):
    ''' Function to perform interleaving '''
    a_arr_tf_column = tf.range(arr_a.shape[1])*2 # [0 2 4 ...]
    b_arr_tf_column = tf.range(arr_b.shape[1])*2+1 # [1 3 5 ...]
    column_indices = tf.argsort(tf.concat([a_arr_tf_column,b_arr_tf_column],axis=-1))
    column, row = tf.meshgrid(column_indices,tf.range(arr_a.shape[0]))
    combine_indices = tf.stack([row,column],axis=-1)
    combine_value = tf.concat([arr_a,arr_b],axis=1)
    return tf.gather_nd(combine_value,combine_indices)

def positional_encoding_1d(position, d_model):
    ''' Function to calculate the positional encodings for 1-D data '''
    angle_rads = get_angles(tf.cast(tf.range(position)[:, tf.newaxis], TARGET_DTYPE),
                            tf.cast(tf.range(d_model)[tf.newaxis, :], TARGET_DTYPE),
                            d_model)
    
    # apply sin to even indices in the array; 2i
    sin_angle_rads = tf.math.sin(angle_rads[:, ::2])
    # apply cos to odd indices in the array; 2i+1
    cos_angle_rads = tf.math.cos(angle_rads[:, 1::2])
    angle_rads = do_interleave(sin_angle_rads, cos_angle_rads)
    pos_encoding = angle_rads[tf.newaxis, ...]
    return pos_encoding

def np_positional_encoding_2d(row, col, d_model):
    assert d_model % 2 == 0
    row_pos = np.repeat(np.arange(row),col)[:,np.newaxis]
    col_pos = np.repeat(np.expand_dims(np.arange(col),0),row,axis=0).reshape(-1,1)

    angle_rads_row = get_angles(row_pos,np.arange(d_model//2)[np.newaxis,:],d_model//2).numpy()
    angle_rads_col = get_angles(col_pos,np.arange(d_model//2)[np.newaxis,:],d_model//2).numpy()
    
    angle_rads_row[:, 0::2] = np.sin(angle_rads_row[:, 0::2])
    angle_rads_row[:, 1::2] = np.cos(angle_rads_row[:, 1::2])
    angle_rads_col[:, 0::2] = np.sin(angle_rads_col[:, 0::2])
    angle_rads_col[:, 1::2] = np.cos(angle_rads_col[:, 1::2])
    pos_encoding = np.concatenate([angle_rads_row,angle_rads_col],axis=1)[np.newaxis, ...]
    return tf.cast(pos_encoding, dtype=TARGET_DTYPE)

def positional_encoding_2d(row, col, d_model):
    row_pos = tf.repeat(tf.range(row), col)[:, tf.newaxis]
    col_pos = tf.reshape(tf.repeat(tf.expand_dims(tf.range(col),0), row, axis=0), (-1, 1))

    angle_rads_row = get_angles(tf.cast(row_pos, tf.float32), tf.range(d_model//2)[tf.newaxis,:], d_model//2)
    angle_rads_col = get_angles(tf.cast(col_pos, tf.float32), tf.range(d_model//2)[tf.newaxis,:], d_model//2)

    sin_angle_rads_row = tf.math.sin(angle_rads_row[:, ::2])
    cos_angle_rads_row = tf.math.cos(angle_rads_row[:, 1::2])
    angle_rads_row = do_interleave(sin_angle_rads_row, cos_angle_rads_row)

    sin_angle_rads_col = tf.math.sin(angle_rads_col[:, ::2])
    cos_angle_rads_col = tf.math.cos(angle_rads_col[:, 1::2])
    angle_rads_col = do_interleave(sin_angle_rads_col, cos_angle_rads_col)
    
    pos_encoding = tf.concat([angle_rads_row,angle_rads_col],axis=1)[tf.newaxis, ...]
    return pos_encoding

pos_encoding = positional_encoding_1d(256, 512)

print(pos_encoding.shape)

plt.figure(figsize=(6,4))
plt.pcolormesh(tf.cast(pos_encoding[0], tf.float32), cmap='RdBu')
plt.xlim((0, 512))
plt.ylim((0, 256))
plt.xlabel('Depth', fontweight="bold")
plt.ylabel('Position', fontweight="bold")
plt.title("Visualization of Positional Encoding", fontweight="bold")
plt.colorbar()
plt.show()

### 5.2.2 TRANSFORMER - MASKING

---


Mask all the pad tokens in each batch of sequences. This ensures that the model does not treat padding as input. 

The mask indicates where pad value **`0`** is present: 
* it outputs a **`1`** at those locations
* it outputs a **`0`** otherwise.

---

The **look-ahead mask** is used to mask the future tokens in a sequence. 

In other words, the mask indicates which entries should not be used.
* This means that to predict the third token, only the first and second tokens will be used. 
* Similarly to predict the fourth token, only the first, second and the third tokens will be used and so on.
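
The two masks and their combination can be sketched in NumPy as follows. The helper names `padding_mask` and `look_ahead_mask` and the toy sequence are hypothetical; the sketch assumes, as described above, that the pad value is `0` and that a `1` marks an entry that should not be used.

```python
import numpy as np

def padding_mask(seq, pad_id=0):
    """1.0 where the token is padding; shape (batch, 1, 1, seq_len)."""
    return (seq == pad_id).astype(np.float32)[:, np.newaxis, np.newaxis, :]

def look_ahead_mask(size):
    """1.0 strictly above the diagonal, i.e. on future positions."""
    return np.triu(np.ones((size, size), dtype=np.float32), k=1)

# One toy sequence of length 4 whose last token is padding
seq = np.array([[7, 6, 3, 0]])

# An entry is masked if it is a future token OR a pad token
combined = np.maximum(padding_mask(seq), look_ahead_mask(seq.shape[1]))

print(combined.shape)  # (1, 1, 4, 4)
print(combined[0, 0])
# [[0. 1. 1. 1.]
#  [0. 0. 1. 1.]
#  [0. 0. 0. 1.]
#  [0. 0. 0. 1.]]
```

Row `i` of the combined mask lists which positions are hidden when predicting from position `i`; the last column is masked in every row because that token is padding.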

In [None]:
def create_padding_mask(seq):
    ''' 
    Function to mark pad tokens and add extra dimensions so the mask broadcasts against the attention logits.
       - (batch_size, 1, 1, seq_len)
    '''
    seq = tf.cast(tf.math.equal(seq, 0), TARGET_DTYPE)
    return seq[:, tf.newaxis, tf.newaxis, :]

def create_look_ahead_mask(size):
    ''' Function to create a look ahead mask '''
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    # (seq_len, seq_len)
    return tf.cast(mask, TARGET_DTYPE)

def create_mask(inp, tar):
    '''
    Function to combine the look-ahead and padding masks to be used in the 1st attention block in the decoder.
    It is used to pad and mask future tokens in the input received by the decoder.
    '''
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)
    return tf.cast(combined_mask, TARGET_DTYPE)

x = tf.constant([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
print(f"x --\n{x}\n\n\nPADDING MASK\n{create_padding_mask(x)}\n")
print("")

x = tf.random.uniform((1, 5))
print(f"x.shape[1] -- {x.shape[1]}")
print(f"\n\nLOOK-AHEAD MASK\n{create_look_ahead_mask(x.shape[1])}\n")

### 5.2.3 TRANSFORMER - SCALED DOT-PRODUCT ATTENTION



---

Scaled dot-product attention is an attention mechanism where the dot products are scaled down by $\sqrt{d_k}$. 

---

<center><img src="https://www.tensorflow.org/images/tutorials/transformer/scaled_attention.png" width="500" alt="scaled_dot_product_attention"></center>

---

The attention function used by the transformer takes three inputs: 
* **`Q` (query)**
* **`K` (key)**
* **`V` (value)**
---

The equation used to calculate the attention weights is:

$$\Large{\text{Attention}(Q, K, V) = \text{softmax}_k\left(\frac{QK^T}{\sqrt{d_k}}\right) V} $$

---

The dot-product attention is scaled by a factor of square root of the depth. 
* This is done because for large values of depth, the dot product grows large in magnitude, pushing the softmax function into regions where it has small gradients, which results in a very hard softmax. 

For example, consider that the entries of **`Q`** and **`K`** are independent with a mean of **`0`** and variance of **`1`**. 
* Each entry of their matrix multiplication (a dot product over **`dk`** terms) will have a mean of **`0`** and variance of **`dk`**. 
* Hence, ***square root of `dk`*** is used for scaling (and not any other number): after scaling, the matmul of **`Q`** and **`K`** has a mean of **`0`** and variance of **`1`**, and you get a gentler softmax.
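
The variance argument can be checked empirically. The sketch below assumes standard-normal entries and uses NumPy with an arbitrary seed and sample size; none of these values come from the notebook:

```python
import numpy as np

rng = np.random.default_rng(0)
d_k = 64
q = rng.standard_normal((20_000, d_k))
k = rng.standard_normal((20_000, d_k))

dots = np.sum(q * k, axis=-1)   # one q.k dot product per row
scaled = dots / np.sqrt(d_k)    # the scaling used in the attention formula

print(f"var(q.k)            ~ {dots.var():.1f}")    # close to d_k
print(f"var(q.k / sqrt(dk)) ~ {scaled.var():.2f}")  # close to 1
```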

The mask is multiplied by **`-1e9`** (close to negative infinity). 
* This is done because the mask is summed with the scaled matrix multiplication of **`Q`** and **`K`** and is applied immediately before a softmax. 
* The goal is to zero out these cells, and large negative inputs to softmax are near zero in the output.
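
A small NumPy sketch of this masking trick follows. The logits and mask values are made up for illustration, and `softmax` is a local helper rather than the TensorFlow op:

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max(axis=-1, keepdims=True))  # subtract max for stability
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([2.0, 1.0, 0.5, 3.0])
mask = np.array([0.0, 0.0, 1.0, 1.0])  # 1 marks positions to hide

weights = softmax(logits + mask * -1e9)
print(weights)  # the two masked positions get weight 0; the rest sum to 1
```

Note that the largest raw logit (`3.0`) sits at a masked position and still receives no weight.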

---

As the softmax normalization is done along the **`K`** (key) axis, the weights for each query sum to **`1`**, and they decide how much importance each key position receives for that query in **`Q`**.

The output represents the multiplication of the attention weights and the **`V` (value) vector.**
* This ensures that the words you want to focus on are kept as-is and the irrelevant words are flushed out.

In [None]:
def scaled_dot_product_attention(q, k, v, mask):
    """
    Function to calculate the attention weights.
    q, k, v must have matching leading dimensions.
    k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
    The mask has different shapes depending on its type (padding or look ahead) 
    but it must be broadcastable for addition.

    Args:
        q: query shape == (..., seq_len_q, depth)
        k: key shape == (..., seq_len_k, depth)
        v: value shape == (..., seq_len_v, depth_v)
        mask: Float tensor with shape broadcastable 
            to (..., seq_len_q, seq_len_k). Pass None to skip masking.

    Returns:
        output, attention_weights
    """

    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)
    dk = tf.cast(tf.shape(k)[-1], TARGET_DTYPE)

    # Calculate scaled attention logits
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # Add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)  

    # Softmax is normalized on the last axis (seq_len_k) so that the scores add up to 1.
    #   - shape --> (..., seq_len_q, seq_len_k)
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  

    #   - shape --> (..., seq_len_q, depth_v)
    output = tf.matmul(attention_weights, v)
    
    return output, attention_weights


def print_out(q, k, v):
    ''' Function to print the output and attention weights from the function above '''
    temp_out, temp_attn = scaled_dot_product_attention(q, k, v, None)
    print(f'Attention weights are:\n\t-->{temp_attn}')
    print(f'\nOutput is:\n\t-->{temp_out}')

# Set print options
np.set_printoptions(suppress=True)

In [None]:
# Demo inputs
temp_k = tf.constant([[10,0,0],
                      [0,10,0],
                      [0,0,10],
                      [0,0,10]], dtype=TARGET_DTYPE)  # (4, 3)

temp_v = tf.constant([[   1,0],
                      [  10,0],
                      [ 100,5],
                      [1000,6]], dtype=TARGET_DTYPE)  # (4, 2)

print(f"\n-----------------------\n\nTEMP K:\n\n{temp_k}\n")
print(f"\n-----------------------\n\nTEMP V:\n\n{temp_v}\n")

# This `query` aligns with the second `key`, so the second `value` is returned.
temp_q = tf.constant([[0, 10, 0]], dtype=TARGET_DTYPE)  # (1, 3)
print(f"\n-----------------------\n\nTEMP Q:\n\n{temp_q} \n")
print_out(temp_q, temp_k, temp_v)

# This query aligns with a repeated key (third and fourth), 
# so all associated values get averaged.
temp_q = tf.constant([[0, 0, 10]], dtype=TARGET_DTYPE)  # (1, 3)
print(f"\n-----------------------\n\nTEMP Q:\n\n{temp_q} \n")
print_out(temp_q, temp_k, temp_v)

# This query aligns equally with the first and second key, 
# so their values get averaged.
temp_q = tf.constant([[10, 10, 0]], dtype=TARGET_DTYPE)  # (1, 3)
print(f"\n-----------------------\n\nTEMP Q:\n\n{temp_q} \n")
print_out(temp_q, temp_k, temp_v)

temp_q = tf.constant([[0, 0, 10], [0, 10, 0], [10, 10, 0]], dtype=TARGET_DTYPE)  # (3, 3)
print(f"\n-----------------------\n\nTEMP Q:\n\n{temp_q} \n")
print_out(temp_q, temp_k, temp_v)

### 5.2.4 TRANSFORMER - MULTI-HEAD ATTENTION

---

This is an implementation of multi-headed attention based on [**"Attention is all you Need"**](https://arxiv.org/abs/1706.03762). 
* If **`query`**, **`key`**, **`value`** are the same, then this is **self-attention**. 
* Each timestep in query attends to the corresponding sequence in key, and returns a fixed-width vector.

<br>

This layer (the MHA layer) first projects **`query`**, **`key`** and **`value`**. 
* These are (effectively) a list of tensors of length **`num_attention_heads`**, where the corresponding shapes are: 
    * **`[batch_size, seq_len_q, key_dim]`**, **`[batch_size, seq_len_k, key_dim]`**, **`[batch_size, seq_len_v, value_dim]`**

Then, the **`query`** and **`key`** tensors are **dot-producted and scaled** (see previous section). These values are softmaxed to obtain attention probabilities. The value tensors are then interpolated by these probabilities and concatenated back into a single tensor.

Finally, the result tensor, with value_dim as its last dimension, can take a linear projection and is returned.

---

<br>

<center><img src="https://www.tensorflow.org/images/tutorials/transformer/multi_head_attention.png" width="500" alt="multi-head attention"></center>

---

<br>


**Multi-head attention consists of four parts:**
*    Linear layers and split into heads.
*    Scaled dot-product attention.
*    Concatenation of heads.
*    Final linear layer.

---

Each multi-head attention block gets three inputs:
* **`Q` (query)**
* **`K` (key)**
* **`V` (value)**

These are put through linear (**`Dense`**) layers and split up into multiple heads. 

The **`scaled_dot_product_attention`** defined above is applied to each head (broadcasted for efficiency). 
* An appropriate mask must be used in the attention step.  
* The attention output for each head is then concatenated (using **`tf.transpose`** and **`tf.reshape`**) and put through a final **`Dense`** layer.

Instead of one single attention head, **`Q`**, **`K`**, and **`V`** are split into multiple heads because it allows the model to jointly attend to information at different positions from different representational spaces. 

After the split, each head has a reduced dimensionality, so the total computation cost is the same as single-head attention with full dimensionality.
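
The split-and-merge bookkeeping can be sketched in NumPy as below. The toy sizes are arbitrary assumptions, and the reshape/transpose pattern is intended to mirror `split_heads` and the later concatenation step:

```python
import numpy as np

batch, seq_len, d_model, num_heads = 2, 5, 16, 4
depth = d_model // num_heads  # per-head dimensionality

x = np.arange(batch * seq_len * d_model, dtype=np.float32)
x = x.reshape(batch, seq_len, d_model)

# split: (batch, seq_len, d_model) -> (batch, num_heads, seq_len, depth)
heads = x.reshape(batch, seq_len, num_heads, depth).transpose(0, 2, 1, 3)

# merge: move the head axis back next to depth, then flatten both into d_model
merged = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, d_model)

print(heads.shape)                # (2, 4, 5, 4)
print(np.array_equal(merged, x))  # True
```

Because `num_heads * depth == d_model`, the heads together process the same number of features as a single full-width head.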

---

Let's create a **`MultiHeadAttention`** layer to try out. 
* At each location in the sequence, **`y`**, the **`MultiHeadAttention`** runs all **`8`** attention heads across all other locations in the sequence, returning a new vector of the same length at each location.

In [None]:
class MultiHeadAttention(tf.keras.layers.Layer):
    ''' MultiHead Attention Layer Component of the Transformer '''
    def __init__(self, d_model, num_heads):
        '''
        Args:
            d_model (int): Depth of the d-dimensional space used for positional encoding
            num_heads (int): The number of heads to use in the multi-head-attention block

        Returns:
            None, this is initialization    
        '''
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % self.num_heads == 0
        
        self.depth = d_model // self.num_heads
        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)
        
        self.dense = tf.keras.layers.Dense(d_model)
        
    def split_heads(self, x, batch_size):
        """ 
        Function to split the last dimension into (num_heads, depth).
        Then we transpose the result such that the shape is 
                - (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])
        
    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]
    
        # (batch_size, seq_len, d_model)
        q = self.wq(q)  
        # (batch_size, seq_len, d_model)
        k = self.wk(k)  
        # (batch_size, seq_len, d_model)
        v = self.wv(v)  

        # (batch_size, num_heads, seq_len_q, depth)
        q = self.split_heads(q, batch_size)  
        # (batch_size, num_heads, seq_len_k, depth)
        k = self.split_heads(k, batch_size)  
        # (batch_size, num_heads, seq_len_v, depth)
        v = self.split_heads(v, batch_size)  
        
        # scaled_attention.shape – (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape – (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(q, k, v, mask)
    
        # (batch_size, seq_len_q, num_heads, depth)
        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  

        # (batch_size, seq_len_q, d_model)
        concat_attention = tf.reshape(scaled_attention, 
                                      (batch_size, -1, self.d_model))  

        # (batch_size, seq_len_q, d_model)
        output = self.dense(concat_attention)  
            
        return output, attention_weights

# CUSTOM
temp_mha = MultiHeadAttention(d_model=512, num_heads=8)
y = tf.random.uniform((1, 60, 512), dtype=TARGET_DTYPE)  # (batch_size, encoder_sequence, d_model)
out, attn = temp_mha(y, y, y, mask=None)
print(f"Custom MHA Layer:\n\t-->{out[0,:2]}\n\t-->{(out.shape, attn.shape)}\n")

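# TF NATIVE
#   - note: `key_dim` is the size of EACH head, so key_dim=512 gives every head 512 dims,
#     whereas the custom layer above uses d_model // num_heads = 64 dims per head
temp_mha = tf.keras.layers.MultiHeadAttention(8, 512)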
out, attn = temp_mha(y, y, y, attention_mask=None, return_attention_scores=True)
print(f"TF MHA Layer:\n\t-->{out[0,:2]}\n\t-->{(out.shape, attn.shape)}\n")

del temp_mha

### 5.2.5 TRANSFORMER - POINT-WISE FEED FORWARD NEURAL NETWORK

---

A point-wise feed-forward network consists of two fully-connected layers with a ReLU activation in between.
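
"Point-wise" means the same two layers are applied to every sequence position independently. The NumPy sketch below illustrates this with random weights; the sizes, weights, and the `ffn` helper are illustrative assumptions, not the Keras layers used in the notebook:

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, dff, seq_len = 8, 32, 5
w1, b1 = rng.standard_normal((d_model, dff)), np.zeros(dff)
w2, b2 = rng.standard_normal((dff, d_model)), np.zeros(d_model)

def ffn(x):
    # Dense(dff, relu) followed by Dense(d_model), both acting on the last axis
    return np.maximum(x @ w1 + b1, 0.0) @ w2 + b2

x = rng.standard_normal((seq_len, d_model))
whole = ffn(x)                                     # all positions at once
per_position = np.stack([ffn(row) for row in x])   # one position at a time

print(whole.shape)                       # (5, 8)
print(np.allclose(whole, per_position))  # True
```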

In [None]:
def point_wise_feed_forward_network(d_model, dff):
    '''
    Args:
        d_model (int): Depth of the d-dimensional space used for positional encoding
        dff (int): Number of units to use in the feed-forward neural network  
    
    Returns:
        Feedforward neural network 
    '''
    return tf.keras.Sequential([
        
        # INNER LAYER
        #   – (batch_size, seq_len, dff)
        tf.keras.layers.Dense(dff, activation='relu'),  
        
        # OUTPUT 
        #   – (batch_size, seq_len, d_model)
        tf.keras.layers.Dense(d_model)  
    ])

sample_ffn = point_wise_feed_forward_network(512, D_FF)
print("\nFFN INPUT & OUTPUT SHAPE: " \
      f"{sample_ffn(tf.random.uniform((64, 50, 512), dtype=TARGET_DTYPE)).shape}" \
      "\n\nFFN SUMMARY:")
# summary() prints the table itself and returns None, so it is not wrapped in print()
sample_ffn.summary()

del sample_ffn

### 5.2.6 TRANSFORMER - ENCODER-DECODER NETWORK ARCHITECTURE OVERVIEW
---

The transformer model follows the same general pattern as a standard [sequence to sequence with attention model](nmt_with_attention.ipynb). 

* The input sequence ***(image embedding sequence in our case)*** is passed through **`N` encoder layers** that generate an output for each word/token in the sequence.
* The **decoder** attends to the encoder's output and its own input (self-attention) to predict the next word/token. 

---

<center><img src="https://www.tensorflow.org/images/tutorials/transformer/transformer.png" width="600" alt="transformer"></center>

---

<br>


### 5.2.7 TRANSFORMER - ENCODER

---

**Each transformer encoder layer consists of sublayers:**

1. **Multi-Head Attention (with padding mask)**
2. **Point-Wise Feed Forward Neural Networks**

Each of these sublayers has a **residual connection** around it followed by a **layer normalization**. 
* Residual connections help in avoiding the vanishing gradient problem in deep networks.

The output of each sublayer is **`LayerNorm(x + Sublayer(x))`**. 
* The normalization is done on the **`d_model`** (last) axis. 
* There are **`N encoder` layers** in the ***transformer***.
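
A minimal NumPy sketch of the `LayerNorm(x + Sublayer(x))` pattern follows. It omits the learned scale and shift that `tf.keras.layers.LayerNormalization` also applies, and the random tensors are stand-ins for real activations:

```python
import numpy as np

def layer_norm(x, eps=1e-6):
    # normalize each position over the last (d_model) axis
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)

rng = np.random.default_rng(0)
x = rng.standard_normal((2, 4, 8))             # (batch, seq_len, d_model)
sublayer_out = rng.standard_normal((2, 4, 8))  # stand-in for attention or FFN output

out = layer_norm(x + sublayer_out)             # residual connection, then normalization
print(out.shape)                                       # (2, 4, 8)
print(np.allclose(out.mean(axis=-1), 0.0, atol=1e-6))  # True
```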

In [None]:
class TransformerEncoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        """ 
        Encoder Layer Component Of Transformer Encoder Block 
        
        Args:
            d_model (int): Depth of the d-dimensional space used for positional encoding
            num_heads (int): The number of heads to use in the multi-head-attention block
            dff (int): Number of units to use in the feed-forward neural network
            dropout_rate (float): Percentage of nodes to drop in a given layer
        
        Returns:
            None; This is an initialization
        """
        super(TransformerEncoderLayer, self).__init__()

        self.mha = tf.keras.layers.MultiHeadAttention(num_heads, key_dim=d_model,)
        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x, training, mask=None):
        """ 
        Call function for our encoder layer
        
        Args:
            x (array): Input token embedding
            training (bool): Whether or not to apply certain operations (i.e. disable/enable dropout)
            mask (tensor): None if no masks are used in the MultiHead Attention layer
            
        Returns:
            The encoded input sequence
               - shape --> (batch_size, input_seq_len, d_model)
        """
        
        # returns --> (batch_size, input_seq_len, d_model)
        attn_output, _ = self.mha(x, x, x, mask, return_attention_scores=True) 

        # Potentially unnecessary if dropout is instead passed to tf.keras.layers.MultiHeadAttention (if using tf MHA)
        attn_output = self.dropout1(attn_output, training=training)
        
        # Residual connection followed by layer normalization
        #   – returns --> (batch_size, input_seq_len, d_model)
        out1 = self.layernorm1(x + attn_output, training=training)  
        
        # Point-wise Feed Forward Step
        #   – returns --> (batch_size, input_seq_len, d_model)
        ffn_output = self.ffn(out1, training=training)  
        ffn_output = self.dropout2(ffn_output, training=training)
        

        # Residual connection followed by layer normalization
        #   – returns --> (batch_size, input_seq_len, d_model)
        out2 = self.layernorm2(out1 + ffn_output, training=training)  
        
        return out2

sample_encoder_layer = TransformerEncoderLayer(D_MODEL, 8, D_FF)
sample_encoder_layer_output = sample_encoder_layer(demo_encoder_output, training=False, mask=None)
del sample_encoder_layer

# (batch_size, input_seq_len, d_model)
sample_encoder_layer_output.shape  

### 5.2.8 TRANSFORMER - DECODER LAYER COMPONENT

---

**Each transformer decoder layer consists of sublayers:**

1. **Masked Multi-Head Attention (with look ahead mask and padding mask)**
2. **Multi-Head Attention (with padding mask)** 
    * **`V`** (value) and **`K`** (key) receive the ***encoder output*** as inputs. 
    * **`Q`** (query) receives the ***output from the masked multi-head attention sublayer.***
3. **Point-Wise Feed Forward Networks**

Each of these sublayers has a **residual connection** around it followed by a **layer normalization**.
* The output of each sublayer is **`LayerNorm(x + Sublayer(x))`**.
* The normalization is done on the **`d_model`** (last) axis.

There are **`N` decoder layers** in the ***transformer***.

As **`Q`** receives the output from decoder's first attention block, and **`K`** receives the encoder output, the attention weights represent the importance given to the decoder's input based on the encoder's output. 
* In other words, the decoder predicts the next word/token by looking at the encoder output and self-attending to its own output. 
* See the demonstration above in the scaled dot product attention section.
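
To make the shapes in the second (encoder-decoder) attention block concrete, here is a NumPy sketch with one head and no mask. The sequence lengths and the `attention` helper are illustrative assumptions:

```python
import numpy as np

def attention(q, k, v):
    # scaled dot-product attention without a mask
    logits = q @ k.swapaxes(-1, -2) / np.sqrt(k.shape[-1])
    logits -= logits.max(axis=-1, keepdims=True)
    weights = np.exp(logits)
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ v, weights

rng = np.random.default_rng(0)
target_seq_len, input_seq_len, depth = 3, 6, 8
dec_hidden = rng.standard_normal((1, target_seq_len, depth))  # source of Q
enc_output = rng.standard_normal((1, input_seq_len, depth))   # source of K and V

out, weights = attention(dec_hidden, enc_output, enc_output)
print(out.shape)      # (1, 3, 8)
print(weights.shape)  # (1, 3, 6)
```

Each of the 3 target positions gets one weight per encoder position, and the output keeps the target sequence length.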

In [None]:
class TransformerDecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, dropout_rate=0.1):
        """ 
        Decoder Layer Component Of Transformer Block 
        
        Args:
            d_model (int): Depth of the d-dimensional space used for positional encoding of image embedding
            num_heads (int): Number of heads to use in the multi-head-attention block
            dff (int): Number of units to use in the feed-forward neural network
            dropout_rate (float): Percentage of nodes to drop in a given layer
        
        Returns:
            None; This is an initialization
        """
        super(TransformerDecoderLayer, self).__init__()

        # WE USE THE CUSTOM MHA LAYER DEFINED ABOVE; THE tf.keras ALTERNATIVE IS LEFT COMMENTED OUT BELOW
        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)
        #
        # # Multi Head Attention Layers
        # self.mha1 = tf.keras.layers.MultiHeadAttention(num_heads, key_dim=d_model,)
        # self.mha2 = tf.keras.layers.MultiHeadAttention(num_heads, key_dim=d_model,)

        # Feed Forward NN
        self.ffn = point_wise_feed_forward_network(d_model, dff)
    
        # Layer Normalization Layers
        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        
        # Dropout Layers
        self.dropout1 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout2 = tf.keras.layers.Dropout(dropout_rate)
        self.dropout3 = tf.keras.layers.Dropout(dropout_rate)
    
    # enc_output.shape == (batch_size, input_seq_len, d_model)
    def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None):
        """ 
        Call function for our decoder layer
        
        Args:
            x (array): Target token embedding (batch_size, target_seq_len, d_model)
            enc_output (array): Encoder output (batch_size, input_seq_len, d_model)
            training (bool): Whether or not to apply certain operations (i.e. disable/enable dropout)
            look_ahead_mask (array): None if no look ahead mask is used
            padding_mask (array): None if no padding mask is used 
            
        Returns:
            The decoded sequence and the attention weights of both attention blocks
               - output shape --> (batch_size, target_seq_len, d_model)
        """
        
        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)
        attn1 = self.dropout1(attn1, training=training)

        # Residual connection followed by layer normalization
        #   – (batch_size, target_seq_len, d_model)
        out1 = self.layernorm1(attn1 + x, training=training)
    
        # Merging connection between encoder and decoder (MHA)
        #   – (batch_size, target_seq_len, d_model)
        attn2, attn_weights_block2 = self.mha2(enc_output, enc_output, out1, padding_mask) 
        attn2 = self.dropout2(attn2, training=training)
        
        # Residual connection followed by layer normalization
        #   – (batch_size, target_seq_len, d_model)
        out2 = self.layernorm2(attn2 + out1, training=training)  
        
        # (batch_size, target_seq_len, d_model)
        ffn_output = self.ffn(out2, training=training)  
        ffn_output = self.dropout3(ffn_output, training=training)

        # Residual connection followed by layer normalization
        #   – (batch_size, target_seq_len, d_model)
        out3 = self.layernorm3(ffn_output + out2, training=training)  
        
        return out3, attn_weights_block1, attn_weights_block2

sample_decoder_layer = TransformerDecoderLayer(D_MODEL, 8, D_FF)
sample_decoder_layer_output, _, _ = sample_decoder_layer(tf.random.uniform((BATCH_SIZE_DEBUG, MAX_LEN, D_MODEL), dtype=TARGET_DTYPE), sample_encoder_layer_output, False, None, None)
del sample_decoder_layer

sample_decoder_layer_output.shape  # (batch_size, target_seq_len, d_model)

### 5.2.9 TRANSFORMER - ENCODER COMPONENT

---

The **`TransformerEncoder`** consists of:
1.   Input Embedding
2.   Positional Encoding
3.   **`N`** encoder layers

<br>

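The original transformer puts the input through an embedding, which is summed with the positional encoding. In this notebook the input is already an image embedding sequence, so the positional encoding is added to it directly. 
* The output of this summation is the input to the encoder layers. The output of the encoder is the input to the decoder.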

In [None]:
class TransformerEncoder(tf.keras.layers.Layer):
    ''' Encoder Component of the Transformer Block '''
    def __init__(self, num_layers, d_model, num_heads, dff,
                 maximum_position_encoding, dropout_rate=0.1):
        '''
        Args:
            num_layers (int): Number of encoder layers to be used
            d_model (int): Depth of the d-dimensional space used for positional encoding of image embedding
            num_heads (int): Number of heads to use in the multi-head-attention block
            dff (int): Number of units to use in the feed-forward neural network
            maximum_position_encoding (int): Maximum limit for positional encoding
            dropout_rate (float): Percentage of nodes to drop in a given layer
        
        Returns:
            None; This is an initialization
        '''
        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.pos_encoding = positional_encoding_1d(maximum_position_encoding, self.d_model)
        self.enc_layers = [TransformerEncoderLayer(d_model, num_heads, dff, dropout_rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
        
    def call(self, x, training, mask=None):
        """
        Sequence of Operations:
            1.  Add the position encoding to the input, which is already
                    an embedded sequence (no embedding or scaling step here)
            2.  Perform some amount of dropout
            3.  Pass our preprocessed input data into a stack of encoding layers
                    along with the input mask
        """

        # adding embedding and position encoding.
        #   – (batch_size, input_seq_len, d_model)
        x += self.pos_encoding
        x = self.dropout(x, training=training)
        
        for i in range(self.num_layers):
            x = self.enc_layers[i](x, training, mask)
    
        #   – (batch_size, input_seq_len, d_model)
        return x  

sample_encoder = TransformerEncoder(num_layers=2, 
                         d_model=D_MODEL, 
                         num_heads=8, 
                         dff=D_FF,
                         maximum_position_encoding=IMG_SEQ_LEN)

sample_encoder_output = sample_encoder(demo_encoder_output, training=False, mask=None)
del sample_encoder

print(sample_encoder_output.shape)  # (batch_size, input_seq_len, d_model)

### 5.2.10 TRANSFORMER - DECODER COMPONENT

---

The **`TransformerDecoder`** consists of:
1.   Output Embedding
2.   Positional Encoding
3.   **`N`** decoder layers

<br>

The target is put through an **embedding** which is **summed with the positional encoding.**
* The output of this summation is the input to the decoder layers. 
* The output of the decoder is the input to the final linear layer.
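
The embed, scale, and add-position steps can be sketched in NumPy as below. `sinusoidal_encoding` is a hypothetical stand-in for the notebook's `positional_encoding_1d` (whose exact definition is not shown in this section), and the embedding table is random rather than learned:

```python
import numpy as np

def sinusoidal_encoding(max_len, d_model):
    # hypothetical stand-in for positional_encoding_1d -> (1, max_len, d_model)
    pos = np.arange(max_len)[:, None]
    i = np.arange(d_model)[None, :]
    angles = pos / np.power(10000.0, (2 * (i // 2)) / d_model)
    angles[:, 0::2] = np.sin(angles[:, 0::2])  # even feature indices
    angles[:, 1::2] = np.cos(angles[:, 1::2])  # odd feature indices
    return angles[None, ...]

vocab_size, d_model, max_len = 10, 8, 20
rng = np.random.default_rng(0)
embedding_table = rng.standard_normal((vocab_size, d_model))

tokens = np.array([[1, 4, 2]])                   # (batch, seq_len)
seq_len = tokens.shape[1]
x = embedding_table[tokens] * np.sqrt(d_model)   # embed, then scale by sqrt(d_model)
x = x + sinusoidal_encoding(max_len, d_model)[:, :seq_len, :]  # slice to seq_len
print(x.shape)  # (1, 3, 8)
```

Slicing the encoding to `seq_len` is what lets the decoder handle targets shorter than the maximum length.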

In [None]:
class TransformerDecoder(tf.keras.layers.Layer):
    ''' Decoder Layer Component for the Transformer '''
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, 
                 maximum_position_encoding, dropout_rate=0.1):
        '''
        Args:
            num_layers (int): Number of decoder layers to be used
            d_model (int): Depth of the d-dimensional space used for positional encoding of image embedding
            num_heads (int): Number of heads to use in the multi-head-attention block
            dff (int): Number of units to use in the feed-forward neural network
            target_vocab_size (int): Vocabulary size of the target output
            maximum_position_encoding (int): Maximum limit for positional encoding
            dropout_rate (float): Percentage of nodes to drop in a given layer
        
        Returns:
            None; This is an initialization
        '''
        super(TransformerDecoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers
        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding_1d(maximum_position_encoding, d_model)
        self.dec_layers = [TransformerDecoderLayer(d_model, num_heads, dff, dropout_rate) for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(dropout_rate)
    
    def call(self, x, enc_output, training, look_ahead_mask=None, padding_mask=None):
        ''' Calling function for the decoder '''
        seq_len = tf.shape(x)[1]
        attention_weights = {}
        
        # adding embedding and position encoding.
        #   – (batch_size, target_seq_len, d_model)
        x = self.embedding(x)  
        x *= tf.math.sqrt(tf.cast(self.d_model, TARGET_DTYPE))
        x += self.pos_encoding[:, :seq_len, :]
        
        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training, look_ahead_mask, padding_mask)
            attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i+1)] = block2
    
        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights


sample_decoder = TransformerDecoder(num_layers=2, d_model=D_MODEL, num_heads=8, 
                         dff=D_FF, target_vocab_size=VOCAB_LEN,
                         maximum_position_encoding=MAX_LEN)
temp_input = tf.random.uniform((1, MAX_LEN), dtype=tf.int64, minval=0, maxval=VOCAB_LEN)
output, attn = sample_decoder(temp_input, 
                              enc_output=demo_encoder_output, 
                              training=False,
                              look_ahead_mask=None, 
                              padding_mask=None)
del sample_decoder
output.shape, attn['decoder_layer2_block2'].shape

### 5.2.11 TRANSFORMER - PUT IT ALL TOGETHER

---


Our Transformer consists of the **transformer encoder**, **transformer decoder** and a **final linear layer**. 
* The input to the encoder is the output of our image encoder (i.e. output of EfficientNetV2)
* The output of the decoder is the input to the linear layer and its output is returned.

In [None]:
class Transformer(tf.keras.Model):
    """ Final Transformer Architecture Block """
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size, 
                 pe_input, pe_target, dropout_rate=0.1):
        '''
        Args:
            num_layers (int): Number of encoder layers and of decoder layers to be used
            d_model (int): Depth of the d-dimensional space used for positional encoding of image embedding
            num_heads (int): Number of heads to use in the multi-head-attention block
            dff (int): Number of units to use in the feed-forward neural network
            target_vocab_size (int): Vocabulary size of the target output
            pe_input (int): Input positional encodings
            pe_target (int): Target positional encoding
            dropout_rate (float): Percentage of nodes to drop in a given layer
        
        Returns:
            None; This is an initialization
        '''
        
        super(Transformer, self).__init__()
        
        self.t_encoder = TransformerEncoder(num_layers, d_model, num_heads, 
                                            dff, pe_input, dropout_rate)
        self.t_decoder = TransformerDecoder(num_layers, d_model, num_heads, dff, 
                                            target_vocab_size, pe_target, dropout_rate)
        self.t_final_layer = tf.keras.layers.Dense(target_vocab_size)
    
    def call(self, t_inp, t_tar, training, enc_padding_mask=None, 
             look_ahead_mask=None, dec_padding_mask=None):
        '''
        Args:
            t_inp: Input sequences 
            t_tar: Target sequences 
            training (bool): Whether training is to be done
            enc_padding_mask (array or None): Padding masks from the encoder
            look_ahead_mask: None if no look ahead masks to be used
            dec_padding_mask (array or None): Padding masks from the decoder

        Returns:
            final_output and attention_weights of the Transformer model
        '''
        # (batch_size, inp_seq_len, d_model)
        enc_output = self.t_encoder(t_inp, training, enc_padding_mask)  
        
        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.t_decoder(t_tar, enc_output, 
                                                       training, look_ahead_mask, 
                                                       dec_padding_mask)

        # (batch_size, tar_seq_len, target_vocab_size)
        final_output = self.t_final_layer(dec_output)  
    
        return final_output, attention_weights


# sample_transformer = Transformer(num_layers=2, d_model=D_MODEL, 
#                                  num_heads=8, dff=1024,  
#                                  target_vocab_size=VOCAB_LEN, 
#                                  pe_input=IMG_SEQ_LEN, pe_target=MAX_LEN)
# fn_out, _ = sample_transformer(demo_encoder_output, SAMPLE_LBLS, training=False, 
#                                enc_padding_mask=None, 
#                                look_ahead_mask=None,
#                                dec_padding_mask=None)

# # (batch_size, tar_seq_len, target_vocab_size)
# print(fn_out.shape)

# del sample_transformer

## 5.3 CREATE A LEARNING RATE SCHEDULER
---

We utilize the learning rate scheduler from the "Attention Is All You Need" paper with some minor tweaks.
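
Reading the schedule off the `CustomSchedule` code in this section, the learning rate is

$$lr = (1.75 \cdot d_{model})^{-0.5} \cdot \min\left(step^{-0.5},\ step \cdot warmup\_steps^{-1.5}\right)$$

The plain-Python sketch below evaluates that formula. The helper name `lr_at` and the values `d_model = 512` and `warmup = 4000` are illustrative assumptions, not the notebook's settings:

```python
import math

def lr_at(step, d_model, warmup_steps):
    # mirrors rsqrt(d_model * 1.75) * min(rsqrt(step), step * warmup_steps ** -1.5)
    if step == 0:
        return 0.0  # in TensorFlow, min(inf, 0) evaluates to 0 at step 0
    return (d_model * 1.75) ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)

d_model, warmup = 512, 4000  # illustrative values only
peak = lr_at(warmup, d_model, warmup)

print(peak)                                        # peak rate, reached at step == warmup
print(lr_at(warmup // 2, d_model, warmup) < peak)  # True: still warming up
print(lr_at(warmup * 4, d_model, warmup) < peak)   # True: decaying
```

The rate rises linearly until `step == warmup_steps`, where both arguments of `min` are equal, and then decays with the inverse square root of the step.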

In [None]:
print("\n... LEARNING RATE SCHEDULE CREATION STARTING ...\n")

# Part of the Training Configuration
EPOCHS = 30
TOTAL_STEPS = TRAIN_STEPS*EPOCHS

# Learning Rate Scheduler Configuration
WARM_STEPS = (TRAIN_STEPS-1)*4 # Suuuuuper long ramp-up

class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    ''' Creates a custom Learning Rate Scheduler '''
    def __init__(self, d_model, warmup_steps=4000):
        '''
        Args:
            d_model (int): Depth of the d-dimensional space used for positional encoding of image embedding
            warmup_steps (int): Number of steps for which the learning rate will ramp up to the desired peak learning rate value
        '''
        super(CustomSchedule, self).__init__()

        self.d_model = tf.cast(d_model, tf.float32)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        ''' Returns learning rate at different steps '''
        # Optimizers may pass an integer step; cast so rsqrt receives a float
        step = tf.cast(step, tf.float32)
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model*1.75) * tf.math.minimum(arg1, arg2)
    
temp_learning_rate_schedule = CustomSchedule(D_MODEL, WARM_STEPS)
plt.plot(temp_learning_rate_schedule(tf.range(TRAIN_STEPS*EPOCHS, dtype=tf.float32)))
plt.ylabel("Learning Rate")
plt.xlabel("Train Step")
plt.show()
    
# def lr_schedule_fn(step, total_steps, warm_lr_start, warm_steps, peak_lr_start, lr_final, n_epochs):
#     """ Function to generate the learning rate for a given step based on parameters
    
#     Args:
#         step (int): The current step for which to calculate the respective learning rate
#         total_steps (int): The total number of steps for the entire training regime
#         warm_lr_start (float): The starting learning rate prior to warmup
#         warm_steps (int): The number of steps for which the learning rate will ramp up
#             to the desired peak learning rate value (more steps will result in less
#             dramatic changes to existing weights... better for pretrained models)
#         peak_lr_start (float): The starting learning rate after warmup (peak value)
#         lr_final (float): The final learning rate to step down to by the end of training
#         n_epochs (int): The total number of epochs for the training regime
    
#     Returns:
#         The learning rate (float) to be used for a given step
#     """
    
#     # quadratic warmup
#     if step < warm_steps:
#         warmup_factor = (step / warm_steps) ** 2
#         lr_rate = warm_lr_start + (peak_lr_start - warm_lr_start) * warmup_factor    
    
#     # staircase decay
#     else:
#         power = (step - warm_steps) // ((total_steps - warm_steps) / (n_epochs + 1))
#         decay_factor =  ((peak_lr_start / lr_final) ** (1 / n_epochs)) ** power
#         lr_rate = peak_lr_start / decay_factor
        
#     return round(lr_rate, 8)


# def plot_lr_schedule(lr_schedule, name=""):
#     """ Plot the learning rate schedule over the course of training
    
#     Args:
#         lr_schedule (list of floats): The values to use for the LR over the
#             course of training
#         name (str, optional): A name for the LR schedule
    
#     Returns:
#         None; A plot of the how the learning rate changes over time will be displayed
    
#     """
#     schedule_info = f'start: {lr_schedule[0]:.6f}, max: {max(lr_schedule):.6f}, final: {lr_schedule[-1]:.6f}'
#     plt.figure(figsize=(18,6))
#     plt.plot(lr_schedule)
#     plt.title(f"Step Learning Rate Schedule {name+', ' if name else name}{schedule_info}", size=16, fontweight="bold")
#     plt.grid()
#     plt.show()
    
# class LRS():
#     """ LEARNING RATE SCHEDULER OBJECT"""
#     def __init__(self, optimizer, lr_schedule):
#         self.opt = optimizer
#         self.lr_schedule = lr_schedule
        
#         # assign initial learning rate
#         self.lr = lr_schedule[0]
#         self.opt.learning_rate.assign(self.lr)
        
#     def step(self, step):
#         self.lr = self.lr_schedule[step]
#         # assign learning rate to optimizer
#         self.opt.learning_rate.assign(self.lr)
        
#     def get_counter(self):
#         return self.c
    
#     def get_lr(self):
#         return self.lr

# # Create the Schedule and Plot
# lr_schedule = [
#     lr_schedule_fn(step, TOTAL_STEPS, WARM_START_LR, WARM_STEPS, PEAK_START_LR, FINAL_LR, EPOCHS) \
#     for step in range(TOTAL_STEPS)
# ]
# plot_lr_schedule(lr_schedule)

print("\n... LEARNING RATE SCHEDULE CREATION FINISHED ...\n")

## 5.4 WRAP THE CONFIGURATION DETAILS IN A CLASS OBJECT FOR EASY ACCESS

---


In [None]:
# Hyperparameters For ViT
VIT_PATCH_SIZE=16
VIT_PROJECTION_DIM=128
VIT_N_TRANSFORMER_LAYERS=8
VIT_N_HEADS=4
VIT_DROPOUT=0.1

# Hyperparameters For Transformer
N_LAYERS = 4
D_MODEL = IMG_EMB_DEPTH
D_FF = 256
N_HEADS = 4
DROPOUT_RATE = 0.1
PE_INPUT =  IMG_SEQ_LEN
PE_OUTPUT = MAX_LEN
TARGET_V_SIZE = VOCAB_LEN

class Config():
    ''' Class to initialize the Encoder, Decoder and Learning Rate configurations '''
    def __init__(self,):
        self.encoder_config = {}
        self.transformer_config = {}
        self.lr_config = {}
        
    def initialize_encoder_config(self, patch_size, projection_dim, 
                                n_transformer_layers, n_heads, dropout, img_shape):
        self.encoder_config = dict(
            patch_size=patch_size, 
            projection_dim=projection_dim, 
            n_transformer_layers=n_transformer_layers, 
            n_heads=n_heads, 
            dropout=dropout, 
            img_shape=img_shape)
        
    def initialize_transformer_config(self, vocab_len, n_transformer_layers, 
                                      transformer_d_dff, transformer_n_heads, 
                                      encoder_out_seq_len, encoder_out_depth, 
                                      dropout_rate=0.1):
        self.transformer_config = dict(num_layers=n_transformer_layers, 
                                       d_model=encoder_out_depth, 
                                       num_heads=transformer_n_heads, 
                                       dff=transformer_d_dff,
                                       target_vocab_size=vocab_len, 
                                       pe_input=encoder_out_seq_len, 
                                       pe_target=MAX_LEN, dropout_rate=dropout_rate)
        
    def initialize_lr_config(self, warm_steps, n_epochs):
        self.lr_config = dict(
            warm_steps=warm_steps, 
            n_epochs=n_epochs,
        )
        
training_config = Config()
training_config.initialize_transformer_config(vocab_len=VOCAB_LEN,  
                                              n_transformer_layers=N_LAYERS,
                                              transformer_d_dff=D_FF,
                                              transformer_n_heads=N_HEADS,
                                              encoder_out_seq_len=IMG_SEQ_LEN, 
                                              encoder_out_depth=IMG_EMB_DEPTH)

training_config.initialize_encoder_config(patch_size=VIT_PATCH_SIZE, 
                                          projection_dim=VIT_PROJECTION_DIM, 
                                          n_transformer_layers=VIT_N_TRANSFORMER_LAYERS, 
                                          n_heads=VIT_N_HEADS, 
                                          dropout=VIT_DROPOUT, 
                                          img_shape=IMG_SHAPE)

training_config.initialize_lr_config(warm_steps=WARM_STEPS, n_epochs=EPOCHS,)

print(f"\nTRAINING ENCODER CONFIG:\n\t--> {training_config.encoder_config}\n")
print(f"\nTRAINING TRANSFORMER CONFIG:\n\t--> {training_config.transformer_config}\n")
print(f"TRAINING LEARNING RATE CONFIG:\n\t--> {training_config.lr_config}\n")

## 5.5 HOW TPU IMPACTS MODELS, METRICS, AND OPTIMIZERS

In order to use a TPU, or a [**TensorFlow distribution strategy**](https://www.tensorflow.org/api_docs/python/tf/distribute) in general, certain objects have to be created inside the **strategy's scope**.

---

Here is the rule of thumb:

---

* Anything that creates variables that will be used in a distributed way must be created inside **`strategy.scope()`**.
* This includes, but is not limited to:
  - model creation
  - optimizer
  - metrics
  - sometimes, checkpoint restore
  - any custom code that creates distributed variables
* Once a variable is created inside a strategy's scope, it captures the strategy's information, and **you can use it outside the strategy's scope.**
* Unless using a high level API like **`model.fit()`**, defining something within the strategy's scope **WILL NOT automatically distribute the computation**. This will be discussed more in the section on training further down.

---

Inside the scope, everything is defined in the same way it would be outside the distribution strategy. There is, however, a particularity about the loss function which we will discuss further down as well.

**In the next cell, we instantiate the learning rate function, the loss object, and the model(s) inside the scope**

<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCE:</b><br><br>
    - <a href="https://www.tensorflow.org/api_docs/python/tf/distribute/experimental/TPUStrategy#scope"><b>TPUStrategy - Scope</b></a><br>
    - <a href="https://colab.research.google.com/github/tensorflow/tpu/blob/master/tools/colab/custom_training.ipynb#scrollTo=s_suB7CZNw5W"><b>Tutorial - Custom Training With TPUs</b></a><br>
</div>

In [None]:
print("\n... TRAINING PREPARATION STARTING ...\n")

def prepare_for_training(lr_config, encoder_config, transformer_config, encoder_wts=None, transformer_wts=None, verbose=0):
    """ 
    Declare required objects under TPU session scope and return ready for training
    
    Args:
        lr_config (dict): Keyword arguments mapped to desired values for lr schedule function
        encoder_config (dict): Keyword arguments mapped to desired values for encoder model instantiation    
        transformer_config (dict): Keyword arguments mapped to desired values for transformer model instantiation    
        encoder_wts (str, optional): Path to pretrained model weights for encoder
        transformer_wts (str, optional): Path to pretrained model weights for transformer
        verbose (int, optional): If non-zero, print the model summaries
        
    Returns:
        loss_fn - Loss function used during training
        metrics - Loss and Accuracy metrics for training and validation data
        optimizer - Optimizer used for training
        lr_scheduler - Learning rate scheduler  
        encoder - Encoder model
        transformer - Transformer (decoder) model
    """

    # Everything must be declared within the scope when leveraging the TPU strategy
    #     - This will still function properly if scope is set to another type of accelerator
    with strategy.scope():
        
        print("\t--> CREATING LOSS FUNCTION ...")
        # Declare the loss object
        #     - Sparse categorical cross entropy loss is used as root loss
        loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE
        )
        
        def loss_fn(real, pred):
            # Mask out padding positions (token id 0) so they do not contribute to the loss
            mask = tf.math.not_equal(real, 0)
            loss_ = loss_object(real, pred)
            loss_ *= tf.cast(mask, dtype=loss_.dtype)

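            # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
            # NOTE: if REPLICA_BATCH_SIZE is the per-replica batch size, this divisor is smaller than
            #       the global batch size, so the gradient summed across replicas is scaled up by
            #       the number of replicas relative to the global-batch average
            loss_ = tf.nn.compute_average_loss(loss_, global_batch_size=REPLICA_BATCH_SIZE)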
            return loss_
        
        
        # def loss_fn(real, pred):
        #     per_example_loss = loss_object(real, pred)
        #     return tf.nn.compute_average_loss(per_example_loss, global_batch_size=OVERALL_BATCH_SIZE)
        
        # Declare the metrics
        #    - Loss and sparse categorical accuracy for train and val, plus Levenshtein distance for val
        print("\t--> CREATING METRICS ...")
        metrics = {
            'batch_loss':tf.keras.metrics.Mean(),
            'train_loss': tf.keras.metrics.Mean(),
            'train_acc': tf.keras.metrics.SparseCategoricalAccuracy(),
            'val_loss': tf.keras.metrics.Mean(),
            'val_acc': tf.keras.metrics.SparseCategoricalAccuracy(),
            'val_lsd': tf.keras.metrics.Mean(), 
        }
        
        
        print("\t--> CREATING LEARNING RATE SCHEDULER ...")
        # Declare the learning rate schedule (try this as actual lr schedule and list...)
        lr_scheduler = CustomSchedule(transformer_config["d_model"], lr_config["warm_steps"])
        
        print("\t--> CREATING OPTIMIZER ...")
        # Instantiate an optimizer
        optimizer = tf.keras.optimizers.Adam(lr_scheduler)
        
        # Instantiate the encoder model 
        print("\t--> CREATING ENCODER MODEL ARCHITECTURE ...")
        encoder = ViTEncoder(**encoder_config)
        initialization_batch = encoder(
            tf.ones(((REPLICA_BATCH_SIZE,)+encoder_config["img_shape"]), dtype=TARGET_DTYPE), 
            training=False,
        )
                
        # Instantiate the decoder model
        print("\t--> CREATING TRANSFORMER MODEL ARCHITECTURE...")
        transformer = Transformer(**transformer_config)
        transformer(initialization_batch, tf.random.uniform((REPLICA_BATCH_SIZE, 1)), training=False)
        
        if encoder_wts is not None:
            print("\t--> LOADING ENCODER MODEL WEIGHTS ...")
            encoder.load_weights(encoder_wts)

        
        if transformer_wts is not None:
            print("\t--> LOADING TRANSFORMER MODEL WEIGHTS (WILL OVERWRITE ENCODER)...")
            transformer.load_weights(transformer_wts)
        
    # Show the model architectures and plot the learning rate
    if verbose:
        print("\n\n... ENCODER MODEL SUMMARY...\n")
        encoder.summary()  # summary() prints directly and returns None

        print("\n\n... TRANSFORMER MODEL SUMMARY...\n")
        transformer.summary()

        # print("\n\n... LR SCHEDULE PLOT...\n")
        # plot_lr_schedule(lr_schedule)
  
    return loss_fn, metrics, optimizer, lr_scheduler, encoder, transformer
    
    
print("\n... GENERATING THE FOLLOWING:")
# Instantiate our required training components in the correct scope
loss_fn, metrics, optimizer, lr_scheduler, encoder, transformer = \
    prepare_for_training(lr_config=training_config.lr_config,
                         encoder_config=training_config.encoder_config,
                         transformer_config=training_config.transformer_config,
                         encoder_wts=(ENCODER_CKPT_PATH if ENCODER_CKPT_PATH!="" else None),
                         transformer_wts=(TRANSFORMER_CKPT_PATH if TRANSFORMER_CKPT_PATH!="" else None),
                         verbose=1,)

print("\n... TRAINING PREPARATION FINISHED ...\n")

## 5.6 LOSS CLASSES AND REDUCTION

In order to accurately calculate loss when leveraging a TPU, we have to accumulate the losses calculated across the individual replicas. This limits us to a **`reduction`** value of **`SUM`** or **`NONE`**, because the default value and some of the other options will not work with TPU.

---

During training, when a batch is [**distributed to the replicas**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run), each replica receives a part of the batch and calculates the loss values separately. We **SHOULD NOT** calculate the average of the per-example losses on the (partial) batch the replica receives.

**The intuition behind this is as follows:**

---
* The gradients calculated on each replica will be synced across the replicas
    * Therefore, they are summed before the optimizer applies the gradients to update the model's parameters
* If we use the averaged per-example loss to compute the gradient on each replica, the final gradient applied by the optimizer will correspond to the sum of these averaged per-example losses over the respective replicas.
    * This is incorrect. The optimizer should apply the gradient obtained from the averaged per-example loss **over the whole distributed batch**.
    * It's worth noting that each replica may in fact receive a different number of examples.
    * Therefore it is impossible, in general, to obtain the averaged per-example loss over the whole distributed batch by simply dividing that sum by the number of replicas.

---

**Therefore, on each replica we calculate the sum of the per-example losses divided by the batch size of the whole distributed batch, which gives the optimizer the correct gradients to apply.**

**EDIT**
* **In this notebook, we have the option to use [*gradient accumulation*](https://arxiv.org/pdf/1710.02368)**
* In ***gradient accumulation***, each replica receives several batches before the optimizer applies the gradients
    * We divide the sum of per-example losses by the update size (i.e. the number of examples used for one parameter update) rather than by the size of a single distributed batch.

**In the following cell we will demonstrate, using dummy values and pretending we are distributing them, how to deal with the accumulation of the loss values across replicas.**
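A minimal sketch of this accumulation in plain Python, using made-up per-example losses and pretending they were split unevenly across two replicas. All names and values (`replica_losses`, `second_batch`, `global_batch_size`, `update_size`) are hypothetical; in the real pipeline the summation across replicas happens when the gradients are synced.

```python
# Hypothetical per-example losses on two replicas with uneven shard sizes
replica_losses = [
    [1.0, 2.0, 3.0],              # replica 0 received 3 examples
    [4.0, 4.0, 4.0, 4.0, 4.0],    # replica 1 received 5 examples
]
global_batch_size = sum(len(shard) for shard in replica_losses)

# Target: the average per-example loss over the whole distributed batch
true_mean = sum(sum(shard) for shard in replica_losses) / global_batch_size

# Wrong: average on each replica, then sum across replicas
sum_of_means = sum(sum(shard) / len(shard) for shard in replica_losses)

# Still wrong for uneven shards: divide that sum by the number of replicas
mean_of_means = sum_of_means / len(replica_losses)

# Right: each replica sums its losses and divides by the GLOBAL batch size
sum_of_scaled = sum(sum(shard) / global_batch_size for shard in replica_losses)

print(true_mean, sum_of_means, mean_of_means, sum_of_scaled)

# Gradient accumulation: two distributed batches feed one parameter update,
# so every replica divides by the update size instead
second_batch = [[2.0, 2.0, 2.0, 2.0], [6.0, 6.0, 6.0, 6.0]]
update_size = global_batch_size + sum(len(shard) for shard in second_batch)
accumulated = sum(
    sum(shard) / update_size
    for batch in (replica_losses, second_batch)
    for shard in batch
)
update_mean = (
    sum(map(sum, replica_losses)) + sum(map(sum, second_batch))
) / update_size
```

Only `sum_of_scaled` matches `true_mean` (3.25); the sum of per-replica means gives 6.0 and the mean of means gives 3.0. With accumulation, dividing by `update_size` recovers the mean over all examples that contribute to the update.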

## 5.7 DISTRIBUTE THE DATASETS ACROSS REPLICAS

With an input pipeline written using the [**tf.data.Dataset**](https://www.tensorflow.org/api_docs/python/tf/data/Dataset) API, we can use [**strategy.experimental_distribute_dataset**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#experimental_distribute_dataset) to turn it into a ***distributed dataset***, which produces **`per-replica`** values (which are objects of type [**PerReplica**](https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/distribute/values.py#L361)) when iterating over it. 

For example, 

```python
    ds = (... something that is a `tf.data.Dataset` ...)
    dist_ds = strategy.experimental_distribute_dataset(ds)
```

**`dist_ds`** will now be distributed across all replicas.

---

The distributed datasets (when working with TPU) contain objects of type [**tensorflow.python.distribute.values.PerReplica**](https://github.com/tensorflow/tensorflow/blob/v2.3.0/tensorflow/python/distribute/values.py#L361), which is a subclass of [**tf.distribute.DistributedValues**](https://www.tensorflow.org/api_docs/python/tf/distribute/DistributedValues) that is the base class for representing distributed values.

When iterating over the dataset we will still get a tuple containing two values. However, the tuple now contains **`PerReplica`** objects, whereas before it contained tensors representing the image and the label/id respectively.

## 5.8 DISTRIBUTED COMPUTATION & OPTIMIZING LOOPS

For each distributed batch (which contains **`PerReplica`** objects as discussed previously) produced by a distributed dataset, we use [**`strategy.run`**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run) to perform a distributed computation across the TPU replicas, with each replica processing a part of the batch.


---

To understand how **`strategy.run`** will execute across the replicas, we can look at an example:

```python
    @tf.function
    def dist_step(dist_batch):
        strategy.run(replica_fn, args=dist_batch)
        
    for dist_batch in dist_ds:
        dist_step(dist_batch)
```

Here **`replica_fn`** is a function that is going to be run on each replica, and it should work with **tensors**, not with **`PerReplica`** objects.
* You define the operations to perform (for example, forward pass, computing loss values and gradients, etc.) just as you would without a TPU.

---

When working with **`TPU`**, either [**`strategy.run`**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run) has to be called inside [**`tf.function`**](https://www.tensorflow.org/api_docs/python/tf/function) or the replica function has to be annotated with [**`tf.function`**](https://www.tensorflow.org/api_docs/python/tf/function). 

For example:

```python
    @tf.function
    def replica_fn(batch):
        
        model(batch)
        ...
        
    for dist_batch in dist_ds:
        strategy.run(replica_fn, args=dist_batch)
```

The above code snippet is a high-level concept, and **`replica_fn`** doesn't necessarily receive a single argument.
* In our case, the original dataset yields tuples of tensors
* A distributed batch is also a tuple of **`PerReplica`** objects and the **`replica_fn`** is actually receiving the unpacked version of a tuple of tensors as arguments.

---

If a dataset yields a single tensor, you can do things like

```python
    @tf.function
    def replica_fn(batch):
        
        tensor0 (, ... tensorN) = batch
        model(tensor0, ... tensorN)

    strategy.run(replica_fn, args=(dist_batch,))
```

where **`replica_fn`** expects a single tensor as argument. Even if a dataset yields tuples of tensors, the above code still works, but **`replica_fn`** then expects a single tuple of tensors as its argument.
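The difference between the two calling conventions comes down to ordinary Python argument unpacking. The sketch below uses a hypothetical `fake_run` helper that only mimics how `args` is unpacked into the replica function; it is not `strategy.run` and performs no distribution. The strings stand in for tensors.

```python
def fake_run(fn, args=()):
    """Hypothetical stand-in: call fn once with the args tuple unpacked."""
    return fn(*args)


def replica_fn_two_args(images, labels):
    # Receives the unpacked elements of the batch tuple
    return ("two-args", images, labels)


def replica_fn_one_arg(batch):
    # Receives the whole batch tuple and unpacks it itself
    images, labels = batch
    return ("one-arg", images, labels)


batch = ("images", "labels")  # placeholder for a tuple of tensors

# args=batch: the tuple elements become separate positional arguments
result_a = fake_run(replica_fn_two_args, args=batch)

# args=(batch,): the whole tuple is passed as one positional argument
result_b = fake_run(replica_fn_one_arg, args=(batch,))
```

Both calls see the same data; the choice only affects the signature the replica function needs.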

---

We also have to discuss how to collect the returned values from [**`strategy.run`**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run).

The results of [**`strategy.run`**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run) are also 
distributed values, just like the distributed batches it takes as inputs. 
* For each return value, we can use [strategy.experimental_local_results](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#experimental_local_results) to obtain a tuple of tensors from all replicas, and we can use [**`tf.concat`**](https://www.tensorflow.org/api_docs/python/tf/concat) to aggregate them into a single tensor.
* We will use this method to collect the labels and model predictions

---

We will need to iterate over the dataset to perform inference/train on the whole (distributed) dataset. When leveraging a TPU this is a non-trivial task. An example of iterating over a distributed dataset is:

```python
    for dist_batch in dist_ds:
        dist_step(dist_batch)
```

Every step in the loop, which calls [**`strategy.run`**](https://www.tensorflow.org/api_docs/python/tf/distribute/Strategy#run), will have a communication between the local VM (in our case, the Kaggle VM) and the remote TPU worker(s). 

**This is obviously not ideal.**

However, you can iterate the distributed dataset inside a `tf.function` as shown by:

```python
    @tf.function
    def dist_run_on_dataset(dist_ds):
    
        for dist_batch in dist_ds:
            dist_step(dist_batch)
            
    dist_run_on_dataset(dist_ds)
```

This way, all the operations conducted on the dataset are compiled into a graph which is sent to the remote TPU worker(s) for execution. This will vastly reduce the running time and limit the time TPUs will sit idle waiting for data from the local VM. See [**TPU: extreme optimizations**](https://www.kaggle.com/c/flower-classification-with-tpus/discussion/135443) for a good benchmark by [**Martin Görner**](https://www.kaggle.com/mgornergoogle).

In this notebook, we use a fixed number of training steps, so we can also use

```python    
    @tf.function
    def dist_process_dataset(dist_ds_iter):
    
        for _ in tf.range(n_steps):
            dist_step(next(dist_ds_iter))
            
    dist_ds_iter = iter(dist_ds)
    dist_process_dataset(dist_ds_iter)
```

---

**With the above discussions, we are ready to define the routines used for training, validation and prediction. Let's get started!**

<div class="alert alert-block alert-info" style="margin: 2em; line-height: 1.7em; font-family: Verdana;">
    <b style="font-size: 16px;">📖 &nbsp; REFERENCE:</b><br><br>
    - <a href="https://www.tensorflow.org/tutorials/distribute/custom_training#using_iterators"><b>Tutorial - Using Iterators</b></a><br>
    - <a href="https://www.tensorflow.org/tutorials/distribute/custom_training#iterating_inside_a_tffunction"><b>Tutorial - Iterating Inside a <code>tf.function</code></b></a><br>
    - <a href="https://www.kaggle.com/c/flower-classification-with-tpus/discussion/135443"><b>Kaggle Discussion - TPU: Extreme Optimizations</b></a><br>
    - <a href="https://www.kaggle.com/mgornergoogle/custom-training-loop-with-100-flowers-on-tpu#Optimized-custom-training-loop"><b>Kaggle Notebook - Custom Training Loop With 100+ Flowers on TPU</b></a><br>
</div>


# 6. MODEL TRAINING
---


In this section we will define the training and validation routines as well as the final custom training loop that will execute everything we have worked on up until this point.

## 6.1 INDIVIDUAL TRAIN STEP

---

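The train step runs on each replica: it computes the masked loss under a gradient tape, clips the gradients by global norm, and applies them to the encoder and transformer with a single optimizer.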
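The train step in the cell below feeds the decoder every token except the last and asks it to predict every token except the first. A small sketch of that shift on a single sequence, using hypothetical token ids: padding as 0 and start as 1 are assumed from the masks and initial values used elsewhere in this notebook, while the end token id 2 and the remaining ids are made up.

```python
# Hypothetical tokenized sequence: <start>=1, <end>=2, <pad>=0 (assumed ids)
sequence = [1, 7, 9, 4, 2, 0, 0]

decoder_input = sequence[:-1]    # what the decoder sees
decoder_target = sequence[1:]    # what it must predict at each position

# Padding positions get zero weight so they do not count towards accuracy
pad_weights = [0.0 if token == 0 else 1.0 for token in decoder_target]

for seen, expected, weight in zip(decoder_input, decoder_target, pad_weights):
    print(f"input={seen} -> target={expected} (weight={weight})")
```

Each position is trained to predict the token one step ahead of it, which is why the look-ahead mask is needed to hide later tokens from the decoder.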

In [None]:
def train_step(_image_batch, _inchi_batch):
    """ 
    Single-replica training step: forward pass, loss, and gradient update.
    
    Args:
        _image_batch: Per-replica batch of images
        _inchi_batch: Per-replica batch of tokenized InChI strings
    
    Returns:
        None; applies the gradients and updates the loss and accuracy metrics
    """
    _inchi_batch_input  = _inchi_batch[:, :-1]
    _inchi_batch_target = _inchi_batch[:, 1:]
    combined_mask = create_mask(_inchi_batch_input, _inchi_batch_target)

    with tf.GradientTape() as tape:
        _image_embedding = encoder(_image_batch, training=True)
        prediction_batch, _ = transformer(_image_embedding, _inchi_batch_input, training=True, look_ahead_mask=combined_mask)
        
        # Compute the masked loss, normalized by the target sequence length
        batch_loss = loss_fn(_inchi_batch_target, prediction_batch)/(MAX_LEN-1)

        # Update Accuracy Metric
        metrics["train_acc"].update_state(_inchi_batch_target, prediction_batch, 
                                          sample_weight=tf.where(tf.not_equal(_inchi_batch_target, PAD_TOKEN), 1.0, 0.0))


    # backpropagation using variables, gradients and loss
    #    - split this into two separate optimizers/lrs/etc in the future
    #    - we use the batch loss accumulation to update gradients
    gradients = tape.gradient(batch_loss, encoder.trainable_variables + transformer.trainable_variables)
    gradients, _ = tf.clip_by_global_norm(gradients, 10.0)
    optimizer.apply_gradients(zip(gradients, encoder.trainable_variables+transformer.trainable_variables))
    
    metrics["batch_loss"].update_state(batch_loss)
    metrics["train_loss"].update_state(batch_loss)

@tf.function
def dist_train_step(_image_batch, _inchi_batch):
    strategy.run(train_step, args=(_image_batch, _inchi_batch))

## 6.2 INDIVIDUAL VAL STEP

---

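The validation step runs the encoder once and then decodes the InChI string token by token, accumulating the loss and accuracy against the ground truth at each position.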
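Greedy decoding of this kind can be sketched without TensorFlow. The sketch assumes a hypothetical `toy_next_token_scores` function in place of the transformer; at each step the highest-scoring token is appended to the sequence and fed back in as input, with no teacher forcing.

```python
def toy_next_token_scores(prefix, vocab_size=5):
    """Hypothetical stand-in for the model: favours (last token + 1) % vocab_size."""
    favoured = (prefix[-1] + 1) % vocab_size
    return [1.0 if token == favoured else 0.0 for token in range(vocab_size)]


def greedy_decode(start_token, max_len):
    """Build a sequence of max_len tokens by repeatedly taking the argmax."""
    sequence = [start_token]
    for _ in range(1, max_len):
        scores = toy_next_token_scores(sequence)
        next_token = max(range(len(scores)), key=scores.__getitem__)  # argmax
        sequence.append(next_token)  # the prediction becomes the next input
    return sequence


decoded = greedy_decode(start_token=1, max_len=6)
```

Because every prediction is conditioned on earlier predictions, an early mistake can propagate through the rest of the sequence, which is why validation scores from this loop are usually worse than teacher-forced training scores.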

In [None]:
def val_step(_image_batch, _inchi_batch):
    """ 
    Forward pass of the validation step
    
    Args:
        _image_batch: Per-replica batch of images
        _inchi_batch: Per-replica batch of tokenized InChI strings
    
    Returns:
        predictions_seq_batch: Predictions on the validation batch
    """
    
    # Initialize batch_loss
    batch_loss = tf.constant(0.0, TARGET_DTYPE)       
    transformer_pred_batch = tf.ones((REPLICA_BATCH_SIZE, 1), dtype=tf.uint8)
    
    # Get image embedding (once)
    _image_embedding = encoder(_image_batch, training=False)
    
    # Greedy autoregressive decoding - the ground truth is only used for loss and accuracy
    for c_idx in range(1, MAX_LEN):
        gt_batch_id = _inchi_batch[:, c_idx]
        combined_mask = create_mask(_inchi_batch, transformer_pred_batch)
        
        # predictions.shape == (batch_size, seq_len, vocab_size)
        prediction_batch, attention_weights = transformer(_image_embedding, transformer_pred_batch, training=False, look_ahead_mask=combined_mask)
        predicted_batch_id = prediction_batch[:, -1:, :]
        
        # Update Loss Accumulator
        batch_loss += loss_fn(gt_batch_id, predicted_batch_id[:, -1])
    
        # Update Accuracy Metric
        metrics["val_acc"].update_state(gt_batch_id, predicted_batch_id[:, -1],
                                        sample_weight=tf.where(tf.not_equal(gt_batch_id, PAD_TOKEN), 1.0, 0.0))

        # no teacher forcing, predicted char is next transformer input
        transformer_pred_batch = tf.concat([transformer_pred_batch, tf.cast(tf.argmax(predicted_batch_id, axis=-1), tf.uint8)], axis=-1)
        
    # Update Loss Metric
    metrics["val_loss"].update_state(batch_loss)
    return transformer_pred_batch    

    
@tf.function
def dist_val_step(_val_dist_ds):
    _val_image_batch, _val_inchi_batch = next(_val_dist_ds)
    predictions_seq_batch_per_replica = strategy.run(val_step, args=(_val_image_batch, _val_inchi_batch))
    predictions_seq_batch_accum = strategy.gather(predictions_seq_batch_per_replica, axis=0)
    _val_inchi_batch_accum = strategy.gather(_val_inchi_batch, axis=0)
    return predictions_seq_batch_accum, _val_inchi_batch_accum

## 6.3 INITIALIZE LOGGER

---

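The logger stores the recorded training and validation statistics and prints progress in either a compact ("tight") or a banner-style format.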

In [None]:
class StatLogger():
    ''' This class initializes the Logger '''
    def __init__(self, verbose_frequency=100, print_style="tight"):
        self.train_loss = []
        self.train_acc = []
        self.val_loss = []
        self.val_acc = []
        self.val_lsd = []
        self.step = []
        self.epoch = []
        self.lr = []
        
        self.current_step = 0
        self.epoch_start_time = 0
        self.batch_start_time = 0
        self.verbose_frequency = verbose_frequency
        self.print_style = print_style
        
    def print_last_val(self, current_time):
        if self.print_style=="tight":
            print(f"| VAL DATA |  STEP {VAL_STEPS:>4}/{VAL_STEPS} |  " \
                  f"ACC: {str(self.val_acc[-1]*100)[:5]:<5} – " \
                  f"LOSS: {str(self.val_loss[-1])[:5]:<5} – " \
                  f"LSD: {str(self.val_lsd[-1]):<3} |")
        else:
            print(f'\n\n{"-"*100}\n{"-"*100}\n' \
                  f'{"-"*25:<25}{"VALIDATION ACCURACY : "+str(self.val_acc[-1]*100): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"VALIDATION LOSS     : "+str(self.val_loss[-1]): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"VALIDATION LSD      : "+str(self.val_lsd[-1]): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n{"-"*100}\n\n')
    
    def print_current_train(self, step, train_acc, train_loss, batch_loss, current_time, current_lr):
        if self.print_style=="tight":
            print(f"| TRAIN DATA |  STEP {self.current_step:>4}/{TRAIN_STEPS} | " \
                  f"ACC: {str(train_acc*100)[:5]:<5} – " \
                  f"LOSS: {str(train_loss)[:5]:<5} – " \
                  f"LR: {current_lr:.2e} " \
                  f"|   | TIME |  EPOCH: {str(round((current_time-self.epoch_start_time)/3600,1))+'h':<5} – " \
                  f"SUBSET: {str(round((current_time-self.batch_start_time)*self.verbose_frequency,1))+'s':<6} – " \
                  f"BATCH: {str(round(current_time-self.batch_start_time,1))+'s':<5} |")
        else:
            print(f'\n\n{"-"*100}\n{"-"*100}\n' \
                  f'{"-"*25:<25}{"CURRENT STEP : "+str(step)+" OF "+str(TRAIN_STEPS): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"CURRENT TRAIN ACCURACY : "+str(train_acc*100): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"CURRENT TRAIN LOSS     : "+str(train_loss): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"LAST BATCH LOSS        : "+str(batch_loss): ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n{"-"*100}\n' \
                  f'{"-"*25:<25}{"EPOCH ELAPSED TIME  : "+str(round(current_time-self.epoch_start_time,1))+" SECONDS": ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"LAST SET OF BATCHES TOOK  : ~"+str(round((current_time-self.batch_start_time)*self.verbose_frequency,1))+" SECONDS": ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n' \
                  f'{"-"*25:<25}{"LAST SINGLE BATCH TOOK  : "+str(round(current_time-self.batch_start_time,1))+" SECONDS": ^50}{"-"*25:>25}\n' \
                  f'{"-"*100}\n{"-"*100}\n\n')

## 6.4 CUSTOM TRAIN LOOP

---

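The loop below runs for `EPOCHS` epochs, prints training statistics every `verbose_frequency` steps, and runs validation at the end of every second epoch.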

In [None]:
# Instantiate our tool for logging
stat_logger = StatLogger()
    
for epoch in range(1,EPOCHS+1):
    print(f'\n\n{"="*100}\n{"="*25:<25}{"EPOCH #"+str(epoch): ^50}{"="*25:>25}\n{"="*100}\n')
    
    stat_logger.current_step=0
    stat_logger.epoch_start_time = time.time() # to compute epoch duration
    
    # create distributed versions of dataset to run on TPU with 8 computation units
    train_dist_ds = strategy.experimental_distribute_dataset(train_ds)
    val_dist_ds = iter(strategy.experimental_distribute_dataset(val_ds))
    
    for image_batch, inchi_batch in train_dist_ds:
                
        # Record the batch start time
        stat_logger.batch_start_time = time.time()
        
        # Update the current step
        stat_logger.current_step += 1
        
        # Calculate training step
        dist_train_step(image_batch, inchi_batch)
        
        # validation step at the end of every second epoch
        if stat_logger.current_step == TRAIN_STEPS and epoch%2==0:
            print("\n... VALIDATION DATASET STATISTICS ... \n")
            for _ in range(VAL_STEPS):
                preds, lbls = dist_val_step(val_dist_ds)
                metrics["val_lsd"].update_state(get_levenshtein_distance(preds, lbls))
                
            # Record this epoch's statistics
            stat_logger.train_loss.append(metrics["train_loss"].result().numpy())
            stat_logger.train_acc.append(metrics["train_acc"].result().numpy())
            stat_logger.val_loss.append(metrics["val_loss"].result().numpy())
            stat_logger.val_acc.append(metrics["val_acc"].result().numpy())
            stat_logger.val_lsd.append(metrics["val_lsd"].result().numpy())
            stat_logger.step.append(stat_logger.current_step)
            stat_logger.epoch.append(epoch)
            stat_logger.lr.append(lr_scheduler(tf.cast(stat_logger.current_step+TRAIN_STEPS*(epoch-1), tf.float32)))
            
            # Reset all metrics (validation and training) as one epoch should not affect the next
            metrics["val_lsd"].reset_states()
            metrics["val_acc"].reset_states()
            metrics["val_loss"].reset_states()
            metrics["train_acc"].reset_states()
            metrics["train_loss"].reset_states()
            metrics["batch_loss"].reset_states()
            
            # Print validation scores
            stat_logger.print_last_val(current_time=time.time())
        
        # verbose logging step
        if stat_logger.current_step % stat_logger.verbose_frequency == 0:    
            stat_logger.print_current_train(
                stat_logger.current_step,
                metrics["train_acc"].result().numpy(), 
                metrics["train_loss"].result().numpy(), 
                metrics["batch_loss"].result().numpy(), 
                current_time=time.time(),
                current_lr=lr_scheduler(tf.cast(stat_logger.current_step+TRAIN_STEPS*(epoch-1), tf.float32))
            )
            metrics["train_acc"].reset_states()
            metrics["train_loss"].reset_states()
            metrics["batch_loss"].reset_states()

        # end the epoch once TRAIN_STEPS batches have been processed
        if stat_logger.current_step == TRAIN_STEPS:
            break
            
        # update learning rate
        # lr_scheduler.step(stat_logger.current_step+((epoch-1)*TRAIN_STEPS))
        
    # Save weights after every epoch.
    # To save only every other epoch (starting with the first) and after the last one,
    # guard the two save calls below with:
    # if epoch%2==1 or epoch==EPOCHS:
    print("\n...SAVING MODELS TO DISK ... \n")
    transformer.save_weights(f'./transformer_epoch_{epoch}.h5')
    encoder.save_weights(f'./encoder_epoch_{epoch}.h5')
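
The validation loop above scores predictions with `get_levenshtein_distance`, which is defined earlier in the notebook and is not shown in this section. As a reminder of what the metric measures, here is a minimal pure-Python sketch of the Levenshtein (edit) distance between two strings. The names `levenshtein` and `mean_levenshtein` are illustrative only; they are not the notebook's helper, which presumably works on batched tensors.

```python
def levenshtein(a: str, b: str) -> int:
    """Minimum number of single-character insertions, deletions,
    or substitutions needed to turn `a` into `b`."""
    if len(a) < len(b):
        a, b = b, a  # iterate the shorter string in the inner loop to save memory
    previous = list(range(len(b) + 1))  # distances from the empty prefix of `a`
    for i, ch_a in enumerate(a, start=1):
        current = [i]
        for j, ch_b in enumerate(b, start=1):
            cost = 0 if ch_a == ch_b else 1
            current.append(min(
                previous[j] + 1,         # delete ch_a
                current[j - 1] + 1,      # insert ch_b
                previous[j - 1] + cost,  # substitute (or keep a match)
            ))
        previous = current
    return previous[-1]


def mean_levenshtein(preds, labels):
    """Average edit distance over paired predictions and labels."""
    distances = [levenshtein(p, l) for p, l in zip(preds, labels)]
    return sum(distances) / len(distances)


if __name__ == "__main__":
    print(levenshtein("kitten", "sitting"))
    print(mean_levenshtein(["InChI=1S/CH4/h1H4", "abc"],
                           ["InChI=1S/CH4/h1H4", "abd"]))
```

A lower value is better: a distance of 0 means the predicted string matches the label exactly.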

## 6.5 JUST-IN-CASE SAVE

---

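This cell is a manual fallback: uncomment the two lines to save the current weights outside the regular per-epoch checkpoints.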

In [None]:
# Manual safety save. After a crash, training was resumed from the weights of the
# last stable epoch; uncomment the lines below to save the current weights by hand.

# transformer.save_weights(f'./transformer_epoch_safety_save.h5')
# encoder.save_weights(f'./encoder_epoch_safety_save.h5')
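
If training does crash, one way to decide where to resume is to look for the most recent epoch for which both weight files exist. The sketch below assumes the file naming used by the training loop (`transformer_epoch_{epoch}.h5` and `encoder_epoch_{epoch}.h5`); the helper names `checkpoint_epochs` and `latest_common_epoch` are hypothetical and not part of the original notebook.

```python
import re
from pathlib import Path


def checkpoint_epochs(directory, prefix):
    """Return the set of epoch numbers N for which `<prefix>_epoch_N.h5` exists."""
    pattern = re.compile(rf"^{re.escape(prefix)}_epoch_(\d+)\.h5$")
    epochs = set()
    for path in Path(directory).iterdir():
        match = pattern.match(path.name)
        if match:
            epochs.add(int(match.group(1)))
    return epochs


def latest_common_epoch(directory=".", prefixes=("transformer", "encoder")):
    """Return the highest epoch saved for every prefix, or None if there is none."""
    per_prefix = [checkpoint_epochs(directory, prefix) for prefix in prefixes]
    common = set.intersection(*per_prefix)
    return max(common) if common else None


if __name__ == "__main__":
    print(latest_common_epoch("."))
```

With the epoch number `n` in hand, the weights could then be restored with the Keras `load_weights` method on the already-built models, for example `transformer.load_weights(f'./transformer_epoch_{n}.h5')`, before continuing the loop from epoch `n + 1`.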


# 7. INFER ON TEST DATA

In this section we use the trained model to generate the predictions that we will submit to the competition.

## 7.1 INDIVIDUAL TEST STEP (AND DISTRIBUTED)

---

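`test_step` decodes a batch of images greedily, one token at a time, until each sequence holds `MAX_LEN` tokens. The two distributed wrappers run it on every replica and gather the predictions and image ids.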

In [None]:
def test_step(_image_batch):
    """ 
    Forward pass of the testing step (greedy autoregressive decoding)
    
    Args:
        _image_batch: Per-replica batch of test images
    
    Returns:
        transformer_pred_batch: Predicted token ids for the batch,
            shape (REPLICA_BATCH_SIZE, MAX_LEN)
    """
    
    transformer_pred_batch = tf.ones((REPLICA_BATCH_SIZE, 1), dtype=tf.uint8)
    
    # Get image embedding (once)
    _image_embedding = encoder(_image_batch, training=False)
    
    # Autoregressive decoding - feed the tokens predicted so far back in as the decoder input
    for c_idx in range(1, MAX_LEN):
        
        combined_mask = create_mask(None, transformer_pred_batch)
        
        # prediction_batch.shape == (batch_size, seq_len, vocab_size)
        prediction_batch, attention_weights = transformer(_image_embedding, transformer_pred_batch, training=False, look_ahead_mask=combined_mask)
        predicted_batch_id = prediction_batch[:, -1:, :]
        
        # no teacher forcing, predicted char is next transformer input
        transformer_pred_batch = tf.concat([transformer_pred_batch, tf.cast(tf.argmax(predicted_batch_id, axis=-1), tf.uint8)], axis=-1)
        
    return transformer_pred_batch 

    
@tf.function
def distributed_test_step(_img_batch, _img_ids):
    per_replica_seqs = strategy.run(test_step, args=(_img_batch,))
    predictions = strategy.gather(per_replica_seqs, axis=0)
    pred_ids = strategy.gather(_img_ids, axis=0)
    return predictions, pred_ids


@tf.function
def distributed_test_step_v2(_test_dist_ds):
    _test_image_batch, _test_id_batch = next(_test_dist_ds)
    predictions_seq_batch_per_replica = strategy.run(test_step, args=(_test_image_batch,))
    predictions_seq_batch_accum = strategy.gather(predictions_seq_batch_per_replica, axis=0)
    _test_id_batch_accum = strategy.gather(_test_id_batch, axis=0)
    return predictions_seq_batch_accum, _test_id_batch_accum
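
The decoding pattern in `test_step` can be illustrated without TensorFlow. The sketch below is a toy, single-sequence version of greedy decoding: start from a start token, ask a model for scores over the vocabulary, append the argmax, and repeat until the sequence has `max_len` tokens. Like `test_step`, it always decodes to the full length and leaves end-token handling to post-processing. The scoring function `toy_scores` and the assumption that id 1 is the start token (mirroring the `tf.ones` seed) are for illustration only.

```python
START_ID = 1  # assumption: mirrors the tf.ones seed used in test_step


def greedy_decode(next_token_scores, max_len, start_id=START_ID):
    """Decode one sequence greedily.

    next_token_scores: callable taking the tokens decoded so far and
        returning a list of scores, one per vocabulary entry.
    """
    sequence = [start_id]
    for _ in range(1, max_len):
        scores = next_token_scores(sequence)
        # argmax over the vocabulary: no teacher forcing, no sampling
        next_id = max(range(len(scores)), key=scores.__getitem__)
        sequence.append(next_id)
    return sequence


def toy_scores(prefix, vocab_size=5):
    """Hypothetical model: always prefers (last token + 1) modulo vocab_size."""
    target = (prefix[-1] + 1) % vocab_size
    return [1.0 if i == target else 0.0 for i in range(vocab_size)]


if __name__ == "__main__":
    print(greedy_decode(toy_scores, max_len=6))  # [1, 2, 3, 4, 0, 1]
```

Decoding to a fixed length keeps tensor shapes static, which is generally convenient for compiled TPU code, at the cost of computing tokens past the end of shorter sequences.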

## 7.2 RAW INFERENCE LOOP

---

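Run the distributed test step for `TEST_STEPS` batches, appending each batch of predictions and image ids to a running tensor.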

In [None]:
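# Tensors to store the preds and ids. Each starts with one placeholder row so that
# tf.concat has something to append to; that row is dropped during post-processing.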
all_pred_arr = tf.zeros((1, MAX_LEN), dtype=tf.uint8)
all_pred_ids = tf.zeros((1, 1), dtype=tf.string)

# Create an iterator
dist_test_ds = iter(strategy.experimental_distribute_dataset(test_ds))
for i in tqdm(range(TEST_STEPS), total=TEST_STEPS): 
    img_batch, id_batch = next(dist_test_ds)
    preds, pred_ids = distributed_test_step(img_batch, id_batch)
    all_pred_arr = tf.concat([all_pred_arr, preds], axis=0)
    all_pred_ids = tf.concat([all_pred_ids, tf.expand_dims(pred_ids, axis=-1)], axis=0)

## 7.3 TEST PRED POST-PROCESSING

---

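Convert each predicted token sequence to an InChI string, drop the placeholder first row and the trailing `REQUIRED_DATASET_PAD` rows, and sort the result by `image_id`.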

In [None]:
def arr_2_inchi(arr):
    ''' Convert an array of token ids to an InChI string, stopping at the first <END> token '''
    inchi_str = ''
    for i in arr:
        c = int_2_tok.get(i)
        if c=="<END>":
            break
        inchi_str += c
    return inchi_str

pred_df = pd.DataFrame({
    "image_id":[x[0].decode() for x in tqdm(all_pred_ids[1:-REQUIRED_DATASET_PAD].numpy(), total=N_TEST)], 
    "InChI":[arr_2_inchi(pred_arr) for pred_arr in tqdm(all_pred_arr[1:-REQUIRED_DATASET_PAD].numpy(), total=N_TEST)]
})

pred_df = pred_df.sort_values(by="image_id").reset_index(drop=True)
pred_df
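
One detail of the trimming above is worth a small illustration. Trimming with a negative stop index, as in `[1:-REQUIRED_DATASET_PAD]`, works only when the pad count is at least 1: in Python (and NumPy), `x[1:-0]` is the same as `x[1:0]` and returns an empty sequence. The sketch below shows the pitfall and a variant that also handles a pad count of zero. The helper name `trim_predictions` is hypothetical and plain lists stand in for the prediction arrays.

```python
def trim_predictions(rows, n_pad):
    """Drop the leading placeholder row and the last `n_pad` padded rows.

    Computing the stop index explicitly avoids the `[1:-0]` pitfall.
    """
    stop = len(rows) - n_pad
    return rows[1:stop]


if __name__ == "__main__":
    rows_with_padding = ["placeholder", "pred_a", "pred_b", "pad"]
    rows_without_padding = ["placeholder", "pred_a", "pred_b"]

    print(trim_predictions(rows_with_padding, n_pad=1))     # ['pred_a', 'pred_b']
    print(trim_predictions(rows_without_padding, n_pad=0))  # ['pred_a', 'pred_b']
    print(rows_without_padding[1:-0])                       # [] (negative-zero stop)
```

If `REQUIRED_DATASET_PAD` is always positive in this notebook, the original slice is fine as written; the explicit stop index only matters when the test set divides evenly into batches.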

## 7.4 SAVE SUBMISSION.CSV

In [None]:
pred_df.to_csv("submission.csv", index=False)
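
Before uploading, it can be worth a quick sanity check of the file that was just written. The sketch below uses only the standard library and checks the things that can be verified from this notebook alone: the two column names used in `pred_df`, unique image ids, non-empty predictions, and (optionally) the row count, for which `N_TEST` would be the natural value. The function name `check_submission` is hypothetical.

```python
import csv


def check_submission(path, expected_rows=None):
    """Return a list of human-readable problems found in a submission CSV."""
    problems = []
    with open(path, newline="", encoding="utf-8") as handle:
        reader = csv.DictReader(handle)
        if reader.fieldnames != ["image_id", "InChI"]:
            problems.append(f"unexpected columns: {reader.fieldnames}")
            return problems
        rows = list(reader)

    ids = [row["image_id"] for row in rows]
    if len(set(ids)) != len(ids):
        problems.append("duplicate image_id values")
    if any(not row["InChI"] for row in rows):
        problems.append("empty InChI predictions")
    if expected_rows is not None and len(rows) != expected_rows:
        problems.append(f"expected {expected_rows} rows, found {len(rows)}")
    return problems


if __name__ == "__main__":
    import os
    if os.path.exists("submission.csv"):
        print(check_submission("submission.csv") or "no problems found")
```

An empty list means none of these checks failed; it says nothing about the quality of the predictions themselves.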