In [None]:
%%capture
# Installs
# NOTE: the `%%capture` cell magic above silences every print in this cell.

print("\n... PIP INSTALLS COMPLETE ...\n")

# Optional installs, currently disabled:
#!pip install -q --upgrade pip
#!pip install -q git+https://github.com/qubvel/efficientnet.git
#import efficientnet.keras as efn
print("\n... IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")
# Machine Learning and Data Science Imports
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_addons as tfa
# Also disables pandas' chained-assignment (SettingWithCopy) warning globally.
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np

# Built In Imports (plus the Kaggle dataset helper)
from kaggle_datasets import KaggleDatasets
from collections import Counter
from glob import glob
import imageio
import random
import math
import time
import io
import os
import gc
import re

# Plotting / progress bars / image handling
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import cv2

# Shorthand passed to tf.data calls (num_parallel_calls / prefetch) further down.
AUTO = tf.data.experimental.AUTOTUNE
def seed_it_all(seed=7):
    """Seed every random number generator the notebook uses.

    Exports PYTHONHASHSEED and then seeds Python's `random`, NumPy and
    TensorFlow with the same value, in that order.

    Args:
        seed: Integer seed shared by all generators (default 7).
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, tf.random.set_seed):
        seeder(seed)

print("\n\n... IMPORTS COMPLETE ...\n")
    
print("\n... SEEDING FOR DETERMINISTIC BEHAVIOUR ...\n")
# Uses the default seed (7) declared on seed_it_all.
seed_it_all()

# Make the EfficientNetV2 sources shipped in the Kaggle input dataset importable.
# These path entries must be appended before the brain_automl imports below.
import sys
sys.path.append("/kaggle/input/experimental-efficientnetv2/automl")
sys.path.append("/kaggle/input/experimental-efficientnetv2/automl/brain_automl")
sys.path.append("/kaggle/input/experimental-efficientnetv2/automl/brain_automl/efficientnetv2")

# Google brain Imports

# EfficientNet Module Imports
import brain_automl
from brain_automl import efficientnetv2
from efficientnetv2 import effnetv2_model
from efficientnetv2 import effnetv2_configs

In [None]:
# Connect to a TPU when one is available; otherwise fall back to the default
# distribution strategy. `TPU` doubles as the "running on TPU" flag for the
# rest of the notebook (it is None when no TPU could be initialised).
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16' if TPU else 'float32')
    tf.config.optimizer.set_jit(True)
except Exception:
    # Any failure while resolving/initialising the TPU means "no TPU".
    # `Exception` (rather than a bare `except`) keeps Ctrl-C / SystemExit working.
    TPU = None
    strategy = tf.distribute.get_strategy()
# Number of replicas the global batch is split across.
N_REPLICAS = strategy.num_replicas_in_sync



In [None]:
class DataModule:
    """Static configuration for the dataset: paths, vocabulary, batch sizes and shapes.

    Everything is evaluated once, at class-definition time, and read elsewhere as
    `DataModule.<NAME>`. Defining the class therefore lists the TFRecord
    directories and reads the two competition CSV files.
    """
    base_dir = 'bmsdatasetlong'
    # On TPU the records are read from the dataset's GCS bucket; otherwise from the
    # local Kaggle input mount. The GCS lookup is only performed when it is needed.
    DATA_DIR = KaggleDatasets().get_gcs_path(base_dir) if TPU else f"/kaggle/input/{base_dir}"
    TARGET_DTYPE = tf.bfloat16 if TPU else tf.float32
    # Vocabulary; a token's id is its index in this list. The punctuation entries are
    # written regex-escaped (raw strings, so the backslash is literal) and un-escaped
    # by the loop right below.
    TOKEN_LIST = ["<PAD>", "InChI=1S/", "<END>", "/c", "/h", "/m", "/t", "/b", "/s", "/i"] +\
             ['Si', 'Br', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', 'C', 'H', 'B', ] +\
             [str(i) for i in range(167,-1,-1)] +\
             [r"\+", r"\(", r"\)", r"\-", ",", "D", "T"]
    # Get rid of backslashes
    for token_idx in range(len(TOKEN_LIST)):
        TOKEN_LIST[token_idx] = TOKEN_LIST[token_idx].strip('\\')

    # The start/end/pad tokens will be removed from the string when computing the Levenshtein distance
    # We want them as tf.constant's so they will operate properly within the @tf.function context
    VOCAB_LEN = len(TOKEN_LIST)
    START_TOKEN = tf.constant(TOKEN_LIST.index("InChI=1S/"), dtype=tf.uint8)
    END_TOKEN = tf.constant(TOKEN_LIST.index("<END>"), dtype=tf.uint8)
    PAD_TOKEN = tf.constant(TOKEN_LIST.index("<PAD>"), dtype=tf.uint8)

    # Prefixes and Their Respective Ordering/Format
    #      -- ORDERING --> {c}{h/None}{b/None}{t/None}{m/None}{s/None}{i/None}{h/None}{t/None}{m/None}
    PREFIX_ORDERING = "chbtmsihtm"
    print(f"\n... PREFIX ORDERING IS {PREFIX_ORDERING} ...")

    # Paths to Respective TFRecord Directories
    TRAIN_DIR = os.path.join(DATA_DIR, "train_records")
    VAL_DIR = os.path.join(DATA_DIR, "val_records")
    TEST_DIR = os.path.join(DATA_DIR, "test_records")

    # Get the Full Paths to The Individual TFRecord Files.
    # Files are ordered by the integer found between the last two underscores of the path.
    TRAIN_TFREC_PATHS = sorted(
        tf.io.gfile.glob(os.path.join(TRAIN_DIR, "*.tfrec")),
        key=lambda x: int(x.rsplit("_", 2)[1]))
    VAL_TFREC_PATHS = sorted(
        tf.io.gfile.glob(os.path.join(VAL_DIR, "*.tfrec")),
        key=lambda x: int(x.rsplit("_", 2)[1]))
    TEST_TFREC_PATHS = sorted(
        tf.io.gfile.glob(os.path.join(TEST_DIR, "*.tfrec")),
        key=lambda x: int(x.rsplit("_", 2)[1]))

    # Paths to relevant CSV files containing training and submission information
    TRAIN_CSV_PATH = os.path.join("/kaggle/input", "bms-molecular-translation", "train_labels.csv")
    SS_CSV_PATH    = os.path.join("/kaggle/input", "bms-molecular-translation", "sample_submission.csv")
    # When debug is true we use a smaller batch size and smaller model
    DEBUG=False

    train_df = pd.read_csv(TRAIN_CSV_PATH)
    ss_df    = pd.read_csv(SS_CSV_PATH)

    # --- Distribution Information ---
    N_EX    = len(train_df)
    N_TEST  = len(ss_df)
    N_VAL   = 80000 # Fixed from dataset creation information
    N_TRAIN = N_EX-N_VAL

    BATCH_SIZE_DEBUG   = 2
    REPLICA_BATCH_SIZE = 32 # Lower Batch Size -> Increase Model Capacity.
    if DEBUG:
        REPLICA_BATCH_SIZE = BATCH_SIZE_DEBUG
    # Global batch = per-replica batch times the number of replicas in the strategy.
    OVERALL_BATCH_SIZE = REPLICA_BATCH_SIZE*N_REPLICAS

    # --- Input Image Information ---
    IMAGE_SIZE = 384
    IMG_SHAPE = (192,384,3)  # shape every decoded image is reshaped to in load_image

    # --- Autocalculate Training/Validation Information ---
    TRAIN_STEPS = (N_TRAIN // OVERALL_BATCH_SIZE) + 1
    VAL_STEPS   = (N_VAL   // OVERALL_BATCH_SIZE) + 1

    # --- Modelling Information --
    FEEDFORWARD_DIM = 2560 # hidden width of the decoder feed-forward block (old note said "Increase to 256")
    DECODER_DIM   = 640 # increase to 1024 later

    INPUT_LEN = 277 # token sequence length stored in the TFRecords (old value: 140)
    MAX_LEN = 140   # sequences are truncated to this length by cut_off
    seed = 7

    # Per-channel mean/std broadcastable over (H, W, 3), packed as `stats` for normalize().
    # The values in the trailing comments are the alternatives previously tried.
    mean_train = np.array([0.485, 0.456, 0.406])#np.array([0.9871, 0.9871, 0.9871])
    mean_train = np.expand_dims(np.expand_dims(mean_train, axis = 0), axis = 0)
    std_train = np.array([0.229, 0.224, 0.225])#np.array([0.0888,0.0888,0.0888])
    std_train = np.expand_dims(np.expand_dims(std_train, axis = 0), axis = 0)
    stats = (mean_train, std_train)

In [None]:
class TrainingConfig:
    """Hyper-parameters for the training loop and the learning-rate schedule."""
    NUM_EPOCHS = 6
    PRINT_EVERY = 100
    
 
    label_smoothing = 0.1
    weight_decay = 1e-6
    
    clip_grad = 10.
    
    # Step counts derived from the dataset configuration.
    TRAIN_STEPS = DataModule.TRAIN_STEPS
    TOTAL_STEPS = TRAIN_STEPS * NUM_EPOCHS
    # Both values below are fractions of the schedule length, not step counts:
    # they are forwarded as warm_steps / peak_steps to the LR scheduler config.
    WARM_STEPS = 0.1 # Increase to 0.25 Later.
    PEAK_STEPS = 0.2 # 0.3 is ramp up, 0.7 is slowly ramp down. # Decrease to 0.1 later
    
    # --------First Run -------------------(Half, 6 EPOCHS)
    WARM_START_LR = 1e-6
    PEAK_START_LR = 5e-4 # Manual Scaling when loading and reloading models.
    FINAL_LR = 1e-5
    # ---------Second Run ----------------(Half, 6 EPOCHS)
    #WARM_START_LR = 1e-5
    #PEAK_START_LR = 3e-5
    #FINAL_LR = 1e-6
    # ---------Third Run -----------------(Full Length, 4 EPOCHS)
    #WARM_START_LR = 1e-7
    #PEAK_START_LR = 5e-6
    #FINAL_LR = 1e-7#  Absolute minimum LR before nothing learns.
    # ---------Fourth Run ----------------(Full Length, 4 EPOCHS)
    #WARM_START_LR = 1e-8
    #PEAK_START_LR = 1e-7
    #FINAL_LR = 1e-8
    GRAD_ACCUMULATION = 1
    

In [None]:
def dropout(image, CT = 10, CT_WIDTH = 0.01):
    """Blank out CT randomly placed square patches of a single image.

    Args:
        image: One image (not a batch); the final reshape requires it to hold
            DIM * DIM * 3 values, where DIM = cfg.IMAGE_SIZE.
        CT: Number of patches to zero out.
        CT_WIDTH: Patch side length as a fraction of DIM.

    Returns:
        The image with the patches replaced by zeros, reshaped to [DIM, DIM, 3].

    NOTE(review): `cfg` is not defined in the visible part of this notebook --
    confirm where it comes from before re-enabling this augmentation.
    NOTE(review): patch centres are drawn with Python's `random`, not `tf.random`;
    if this runs inside a traced tf.function / Dataset.map the positions are
    presumably frozen at trace time -- confirm.
    """
    side = cfg.IMAGE_SIZE
    half_patch = (side * CT_WIDTH) // 2
    for _ in range(CT):
        # Random patch centre (x is drawn first, then y).
        centre_x = random.randint(0, side - 1)
        centre_y = random.randint(0, side - 1)

        # Patch bounds, clamped to the image.
        top = int(max(0, centre_y - half_patch))
        bottom = int(min(side, centre_y + half_patch))
        left = int(max(0, centre_x - half_patch))
        right = int(min(side, centre_x + half_patch))

        # Rebuild the affected rows as [left part | zeros | right part] ...
        band = tf.concat(
            [
                image[top:bottom, 0:left, :],
                tf.zeros([bottom - top, right - left, 3]),
                image[top:bottom, right:side, :],
            ],
            axis=1,
        )
        # ... and stitch them back between the untouched rows above and below.
        image = tf.concat([image[0:top, :, :], band, image[bottom:side, :, :]], axis=0)

    # Reshape so the TPU compiler knows the static output shape.
    return tf.reshape(image, [side, side, 3])
def transform(image):
    """Randomly zoom, rotate and shift a single square image.

    Args:
        image: One image (not a batch), indexed as [row, col, channel]; the code
            assumes it is [DIM, DIM, 3] with DIM = cfg.IMAGE_SIZE.

    Returns:
        The augmented image, same [DIM, DIM, 3] shape.

    NOTE(review): neither `cfg` nor `K` is defined or imported in the visible part
    of this notebook (`K` is presumably the Keras backend) -- confirm before use.
    """
    # input image - is one image of size [dim,dim,3] not a batch of [b,dim,dim,3]
    # output - image randomly zoomed, rotated and shifted
    DIM = cfg.IMAGE_SIZE
    XDIM = cfg.IMAGE_SIZE%2 #fix for size 331
    # Random (dx, dy) shift of up to +-6.25% of the image side.
    shift = (0.0625 * DIM) * tf.random.uniform([2],dtype='float32', minval = -1, maxval = 1)
    
    # NOTE(review): drawn in [-360, 360) and handed to tfa.image.rotate as-is;
    # check the angle unit that function expects.
    rot = 360. * tf.random.uniform([1],dtype='float32', minval = -1, maxval = 1)
    # Independent zoom factors in [1.0, 1.05) per axis.
    h_zoom = 1.0 + tf.random.uniform([1],dtype='float32', minval = 0, maxval = 0.5)/10.
    w_zoom = 1.0 + tf.random.uniform([1],dtype='float32', minval = 0, maxval = 0.5)/10.
    
    # GET TRANSFORMATION MATRIX (a pure zoom: diag(1/h_zoom, 1/w_zoom, 1))
    one = tf.constant([1],dtype='float32')
    zero = tf.constant([0],dtype='float32')
    
    m = tf.reshape( tf.concat([one/h_zoom, zero, zero, zero, one/w_zoom, zero, zero, zero, one],axis=0),[3,3])

    # LIST DESTINATION PIXEL INDICES (homogeneous coordinates centred on the image)
    x = tf.repeat( tf.range(DIM//2,-DIM//2,-1), DIM)
    y = tf.tile( tf.range(-DIM//2, DIM//2),[DIM] )
    z = tf.ones([DIM*DIM],dtype='int32')
    idx = tf.stack( [x,y,z] )
    
    # MAP DESTINATION PIXELS ONTO ORIGIN PIXELS (apply the zoom matrix, then clip to the image)
    idx2 = K.dot(m,tf.cast(idx,dtype='float32'))
    idx2 = K.cast(idx2,dtype='int32')
    idx2 = K.clip(idx2,-DIM//2+XDIM+1,DIM//2)
    
    # FIND ORIGIN PIXEL VALUES (convert back to array indices and gather)
    idx3 = tf.stack( [DIM//2-idx2[0,], DIM//2-1+idx2[1,]] )
    trans = tf.transpose(idx3)
    d = tf.gather_nd(image, trans)
    image = tf.reshape(d,[DIM,DIM,3])
    # Rotation and translation are delegated to TensorFlow Addons.
    image = tfa.image.rotate(image, rot, fill_mode = 'reflect')
    image = tfa.image.translate(image, shift)
    return image
@tf.function
def shift_scale_rotate(img):
    """Graph-compiled entry point for `transform`: returns the augmented image."""
    augmented = transform(img)
    return augmented
def normalize(image):
    """Standardise an image with the per-channel statistics in `DataModule.stats`.

    Args:
        image: Image whose last axis broadcasts against the stored mean/std arrays.

    Returns:
        (image - mean) / std.
    """
    channel_mean, channel_std = DataModule.stats
    centred = image - channel_mean
    return centred / channel_std



In [None]:
def cut_off(x, y):
    """Truncate the label sequence `y` to `DataModule.MAX_LEN` entries; `x` passes through.

    Returns:
        Tuple (x, y[:DataModule.MAX_LEN]).
    """
    truncated = y[:DataModule.MAX_LEN]
    return x, truncated
def load_image(one_sample, augment = True):
    """Parse one serialized TFRecord example into an (image, tokens) pair.

    Args:
        one_sample: Serialized example holding a PNG under 'image' and
            DataModule.INPUT_LEN int64 token ids under 'inchi'.
        augment: Reserved switch for train-time augmentation. It currently has
            no effect: every augmentation is disabled (see the note below).

    Returns:
        Tuple (image, tokens): the image reshaped to DataModule.IMG_SHAPE, scaled
        to [0, 1] and cast to DataModule.TARGET_DTYPE; the token ids cast to uint8.
    """
    schema = {
        'image': tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value=''),
        'inchi': tf.io.FixedLenFeature(shape=[DataModule.INPUT_LEN], dtype=tf.int64, default_value=[0]*DataModule.INPUT_LEN)
    }
    parsed = tf.io.parse_single_example(one_sample, features=schema)

    pixels = tf.io.decode_png(parsed['image'], channels = 3)
    tokens = tf.cast(parsed['inchi'], tf.uint8)
    pixels = tf.reshape(pixels, DataModule.IMG_SHAPE)
    pixels = tf.cast(pixels, tf.float32) / 255.0

    # Augmentation is intentionally a no-op for now. Candidates that were tried and
    # switched off (each applied with its own random probability):
    #   tf.image.flip_left_right / flip_up_down / transpose / rot90,
    #   shift_scale_rotate(image),
    #   tfa.image.gaussian_filter2d(image, filter_shape=(3, 7), sigma=(0.8, 1.4)),
    #   dropout(image).
    # Mean/std normalisation via normalize(image) is disabled as well.
    return tf.cast(pixels, DataModule.TARGET_DTYPE), tokens

In [None]:
def load_dataset():
    """Build the training and validation tf.data pipelines.

    Both pipelines read TFRecords in parallel with non-deterministic ordering,
    decode each record with `load_image`, truncate the labels with `cut_off`,
    batch to DataModule.OVERALL_BATCH_SIZE (dropping the remainder) and prefetch.
    The training pipeline additionally repeats forever and then shuffles with a
    128-element buffer seeded by DataModule.seed.

    Returns:
        Tuple (train_dataset, val_dataset).
    """
    options = tf.data.Options()
    options.experimental_deterministic = False

    def _decoded_records(tfrec_paths, augment):
        # Shared front of both pipelines: read -> options -> decode -> truncate labels.
        records = tf.data.TFRecordDataset(tfrec_paths, num_parallel_reads=AUTO)
        records = records.with_options(options)
        records = records.map(
            lambda record: load_image(record, augment=augment),
            num_parallel_calls=AUTO,
            deterministic=False,
        )
        return records.map(cut_off, num_parallel_calls=AUTO, deterministic=False)

    # Pseudo Labelled Data - 1616107
    # (augmentation is requested here, but load_image currently applies none)
    train_dataset = (
        _decoded_records(DataModule.TRAIN_TFREC_PATHS, augment=True)
        .repeat()
        # Small buffer: a larger one trades memory for better shuffling.
        .shuffle(128, seed=DataModule.seed)
        .batch(DataModule.OVERALL_BATCH_SIZE, drop_remainder=True)
        .prefetch(AUTO)
    )

    val_dataset = (
        _decoded_records(DataModule.VAL_TFREC_PATHS, augment=False)
        .batch(DataModule.OVERALL_BATCH_SIZE, drop_remainder=True)
        .prefetch(AUTO)
    )

    return train_dataset, val_dataset

In [None]:
class eff_net_v2_model(keras.Model):
    """Thin Keras wrapper around `effnetv2_model.EffNetV2Model`.

    `enc_dim` is accepted for interface compatibility but is unused: the 1x1
    projection (+ BatchNorm) head it used to configure is disabled, so the raw
    backbone output is returned.
    """

    def __init__(self, model_name, enc_dim):
        super().__init__()
        self.eff_net_v2_model = effnetv2_model.EffNetV2Model(
            model_name=model_name,
            name='eff_net_v2_model',
        )

    def call(self, x, training, features_only = True):
        """Run the backbone on `x`, forwarding `training` and `features_only` unchanged."""
        return self.eff_net_v2_model(x, training=training, features_only=features_only)

In [None]:
def get_efficientnetv2_backbone(model_name, weight_path, enc_dim, include_top=False, input_shape=(192,384,3), pooling=None, weights=None):
    """Build the EfficientNetV2 encoder backbone and optionally restore pretrained weights.

    Args:
        model_name: EfficientNetV2 variant name passed to `eff_net_v2_model`.
        weight_path: Directory searched with `tf.train.latest_checkpoint`, or None
            to skip weight loading.
        enc_dim: Forwarded to `eff_net_v2_model`.
        include_top, input_shape, pooling, weights: Accepted for call-site
            compatibility but ignored; the model is always built for
            `DataModule.IMG_SHAPE`.

    Returns:
        The built `eff_net_v2_model` instance.
    """
    model = eff_net_v2_model(model_name, enc_dim)
    model.build((None, *DataModule.IMG_SHAPE))

    # Drop the second-to-last tracked layer and re-append the last one.
    # NOTE(review): presumably done so the object graph lines up with the released
    # checkpoint -- confirm against the checkpoint layout before changing.
    info = model.eff_net_v2_model._layers[-1]
    model.eff_net_v2_model._layers.pop(-2)
    model.eff_net_v2_model._layers += [info]

    if weight_path is not None:
        latest = tf.train.latest_checkpoint(weight_path)
        if latest is None:
            # Previously this case restored nothing yet still reported success.
            print(f"WARNING: no checkpoint found under {weight_path!r}; encoder keeps its initial weights")
        else:
            tf.train.Checkpoint(net = model).restore(latest)
            print("Loaded Encoder Weights")
    #status.assert_existing_objects_matched() # Just Check that everything worked.

    return model
    
class ModelConfig:
    """Static model configuration, evaluated once at class-definition time.

    Defining this class builds a temporary backbone (`tmp_model`) and runs a dummy
    batch through it to discover the encoder output shape, then makes sure the
    checkpoint directory exists.
    """
    # ENCODER_CONFIG
    enc_dim = 640 # Controls the Dimension of the Enc Dim
    enc_dec_drop = 0.0 # Dropout Between Encoder and Decoder
    transformer_drop = 0.2
    final_drop = 0.5
    num_encoder_layers = 4
    num_decoder_layers = 6 # 6 layers is the max (official Transformer uses 6) - becomes much slower beyond that.
    num_att_heads = 12 # Standard Transformer Config.
    # CHANGE TO EFFICIENTNET V2. - Medium or Large Model(XL is too many params)
    model_name = 'efficientnetv2-l' # Mostly chosen for model capacity rather than for the pretrained weights.
    base_dir = 'efficientnet-v2-weights'
    # Weights come from the dataset's GCS bucket on TPU, from the local input mount otherwise.
    GCS_PATH = KaggleDatasets().get_gcs_path(base_dir) if TPU else '../input/efficientnet-v2-weights'
    print(GCS_PATH)
    weight_path = f'{GCS_PATH}/efficientnetv2-l-21k/efficientnetv2-l-21k/'
    BB_FN = get_efficientnetv2_backbone
    PREPROCESSING_FN = tf.keras.applications.efficientnet.preprocess_input
    FREEZE_ENCODER = False

    # Probe the backbone with a dummy batch to get its (H, W, C) feature-map shape,
    # then flatten the spatial axes: IMG_EMB_DIM = (H * W, C).
    tmp_model = BB_FN(model_name, weight_path,enc_dim, include_top=False, input_shape=DataModule.IMG_SHAPE)
    IMG_EMB_DIM = tmp_model(tf.ones((DataModule.BATCH_SIZE_DEBUG, *DataModule.IMG_SHAPE)), features_only = True).shape[1:]
    IMG_EMB_DIM = (IMG_EMB_DIM[0]*IMG_EMB_DIM[1], IMG_EMB_DIM[2])

    BaseSavePath = './save/'
    PrevModelPath = None#'../input/halftrainedtransformer/save/'
    # Create the save directory if needed. An existing directory is fine; any other
    # failure is reported instead of being silently swallowed (still non-fatal here).
    try:
        os.makedirs(BaseSavePath, exist_ok=True)
    except OSError as err:
        print(f"WARNING: could not create save directory {BaseSavePath!r}: {err}")

Encoder Model

In [None]:
print("\n... ENCODER MODEL CREATION STARTING ...\n")

# SAMPLE IMAGES
# Build both pipelines, then re-batch validation to the small debug batch size
# and pull a single batch to use as example input for the models below.
train_ds, val_ds = load_dataset()
val_ds = val_ds.unbatch().batch(DataModule.BATCH_SIZE_DEBUG)
SAMPLE_IMGS, SAMPLE_LBLS = next(iter(val_ds))
    
class CNNEncoder(tf.keras.Model):
    """CNN feature extractor: preprocessing -> EfficientNetV2 backbone -> spatial dropout -> flatten.

    All settings are read from ModelConfig / DataModule at construction time.
    """

    def __init__(self):
        super().__init__()

        # Configuration snapshot.
        self.image_embedding_dim = ModelConfig.IMG_EMB_DIM
        self.preprocessing_fn = ModelConfig.PREPROCESSING_FN
        self.model_name = ModelConfig.model_name
        self.weight_path = ModelConfig.weight_path
        self.enc_dim = ModelConfig.enc_dim
        self.backbone_fn = ModelConfig.BB_FN
        self.img_shape = DataModule.IMG_SHAPE

        # Layers (creation order kept stable so weight ordering does not change).
        self.encoder_backbone = self.backbone_fn(
            self.model_name,
            self.weight_path,
            self.enc_dim,
            include_top=False,
            weights=None,
            input_shape=self.img_shape,
        )
        self.dropout = ModelConfig.enc_dec_drop
        self.spat_drop = keras.layers.SpatialDropout2D(rate = self.dropout)
        self.reshape = tf.keras.layers.Reshape(self.image_embedding_dim, name='image_embedding')

    def call(self, x, training):
        """Encode a batch of images into a sequence of feature vectors.

        Args:
            x: Batch of images, passed through `preprocessing_fn` before the backbone.
            training: Forwarded to the backbone, the dropout and the reshape layer.

        Returns:
            The backbone feature map reshaped to (batch, *ModelConfig.IMG_EMB_DIM).
        """
        feature_map = self.encoder_backbone(self.preprocessing_fn(x), training=training)
        feature_map = self.spat_drop(feature_map, training = training)
        return self.reshape(feature_map, training=training)
    
    
# Example encoder output: instantiate the CNN encoder on the CPU and run the
# sample batch through it once to check the output shape.
with tf.device('/CPU:0'):
    encoder = CNNEncoder()
    img_embedding_batch = encoder(tf.cast(SAMPLE_IMGS, tf.float32))
print(f'\n... Encoder Output Shape  :  (batch_size, embedding_length, embedding_depth)  :  {img_embedding_batch.shape} ...\n')

print("\n... ENCODER MODEL CREATION FINISHED ...\n")

# ViT Transformer Head Encoder:
- i.e. a hybrid CNN -> ViT model
- EffNetV2-L -> 4-layer ViT

In [None]:
class EncoderMultiHeadAttention(keras.layers.Layer):
    """One encoder block: pre-norm self-attention and a pre-norm Dense layer, each with a residual.

    Note that the layer overrides `__call__` directly (not `call`), and that
    `drop_prob` is stored but no dropout layer is created from it.
    """

    def __init__(self, encoder_dim, num_heads, drop_prob = 0.1):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.num_heads = num_heads
        self.drop_prob = drop_prob

        # Per-head key size is the model width split evenly (floor) across heads.
        head_dim = self.encoder_dim // self.num_heads
        self.layer_norm1 = keras.layers.LayerNormalization(epsilon = 1e-6)
        self.MAH = keras.layers.MultiHeadAttention(self.num_heads, head_dim)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon = 1e-6)
        self.Linear = keras.layers.Dense(self.encoder_dim)

    def __call__(self, x, training):
        """Apply the block to `x` and return a tensor of the same shape."""
        # Self-attention sub-block with residual connection.
        attn_input = self.layer_norm1(x, training = training)
        attended = self.MAH(query = attn_input, value = attn_input, key = attn_input, training = training) + x

        # Dense sub-block with residual connection.
        dense_input = self.layer_norm2(attended, training = training)
        return self.Linear(dense_input, training = training) + attended
class TransformerEncoder(keras.Model):
    """Stack of `EncoderMultiHeadAttention` blocks with a learned positional encoding.

    Sizes come from ModelConfig: sequence length and feature width from
    IMG_EMB_DIM, depth from num_encoder_layers, heads from num_att_heads.
    """

    def __init__(self):
        super().__init__()
        self.num_layers = ModelConfig.num_encoder_layers
        self.num_heads = ModelConfig.num_att_heads
        self.enc_dim = ModelConfig.IMG_EMB_DIM[-1]  # feature width produced by the CNN encoder
        self.drop_prob = ModelConfig.enc_dec_drop
        self.length_input = ModelConfig.IMG_EMB_DIM[0]  # number of positions in the input sequence

        # Trainable positional encoding, one vector per position, shared across the batch.
        self.initializer = keras.initializers.glorot_uniform(seed = 42)
        pos_shape = (1, self.length_input, self.enc_dim)
        self.pos_enc = tf.Variable(self.initializer(shape = pos_shape), trainable = True)

        self.encoders = [
            EncoderMultiHeadAttention(self.enc_dim, self.num_heads, drop_prob = self.drop_prob)
            for _ in range(self.num_layers)
        ]

    def call(self, x, training):
        """Add the positional encoding (broadcast over the batch) and run every block in order."""
        hidden = x + tf.cast(self.pos_enc, DataModule.TARGET_DTYPE)
        for block in self.encoders:
            hidden = block(hidden, training = training)
        return hidden
class Encoder(keras.Model):
    """Full image encoder: `CNNEncoder` followed by `TransformerEncoder`."""

    def __init__(self):
        super().__init__()
        self.CNN = CNNEncoder()
        self.transformer = TransformerEncoder()

    def call(self, x, training):
        """Return the transformer-refined CNN features for the image batch `x`."""
        cnn_features = self.CNN(x, training = training)
        refined = self.transformer(cnn_features, training = training)
        return refined
# Smoke test: build the full encoder on the CPU and run the sample batch through it.
# This rebinds the global `encoder` / `img_embedding_batch` created for the CNN-only check.
# NOTE(review): the device string here contains a space ("CPU: 0"), unlike the
# '/CPU:0' and "CPU:0" spellings used elsewhere in this notebook -- confirm it is accepted.
with tf.device("CPU: 0"):
    encoder = Encoder()
    img_embedding_batch = encoder(tf.cast(SAMPLE_IMGS, tf.float32))
    print(img_embedding_batch.shape)

Transformer Decoder

In [None]:
class DecoderMultiHeadAttention(keras.layers.Layer):
    """Attention sub-block used by the decoder: attention -> dropout -> residual -> LayerNorm.

    Queries come from the decoder stream; keys and values come from the tensor
    passed as `encoder_features` (pass the decoder stream there for self-attention).
    The layer overrides `__call__` directly rather than `call`.
    """

    def __init__(self, encoder_features, decoder_features, num_heads, drop_prob = 0.1):
        super().__init__()
        self.encoder_features = encoder_features
        self.decoder_features = decoder_features
        self.drop_prob = drop_prob
        self.num_heads = num_heads

        # Per-head sizes: keys follow the encoder width, values the decoder width.
        key_dim = self.encoder_features // self.num_heads
        value_dim = self.decoder_features // self.num_heads
        self.MAH = keras.layers.MultiHeadAttention(self.num_heads, key_dim, value_dim = value_dim)
        self.Dropout = keras.layers.Dropout(rate = self.drop_prob)
        self.layernorm = keras.layers.LayerNormalization(epsilon = 1e-6)

    def __call__(self, encoder_features, decoder_features, attention_mask, training):
        """Attend from `decoder_features` over `encoder_features` under `attention_mask`.

        Returns:
            LayerNorm(dropout(attention) + decoder_features).
        """
        attended = self.MAH(
            query = decoder_features,    # (B, T, Dim)
            key = encoder_features,
            value = encoder_features,    # (B, S, Dim)
            attention_mask = attention_mask,
            training = training,
        )
        residual = self.Dropout(attended, training = training) + decoder_features
        return self.layernorm(residual, training = training)

In [None]:
class TransformerDecoder(keras.layers.Layer):
    """One decoder layer: masked self-attention, cross-attention over the encoder, feed-forward.

    The layer overrides `__call__` directly rather than `call`.
    """

    def __init__(self, encoder_features, decoder_features, num_heads, feedforward_dim, drop_prob = 0.1):
        super().__init__()
        self.encoder_features = encoder_features
        self.decoder_features = decoder_features
        self.feedforward_dim = feedforward_dim

        self.num_heads = num_heads
        self.drop_prob = drop_prob

        # MAH1: decoder self-attention; MAH2: decoder-to-encoder cross-attention.
        self.MAH1 = DecoderMultiHeadAttention(self.decoder_features, self.decoder_features, self.num_heads, drop_prob = self.drop_prob)
        self.MAH2 = DecoderMultiHeadAttention(self.encoder_features, self.decoder_features, self.num_heads, drop_prob = self.drop_prob)

        # Position-wise feed-forward network: expand with ReLU, project back.
        expand = keras.layers.Dense(self.feedforward_dim, activation = 'relu')
        project = keras.layers.Dense(self.decoder_features)
        self.FFN = keras.Sequential([expand, project])
        self.Dropout = keras.layers.Dropout(rate = self.drop_prob)
        self.layernorm = keras.layers.LayerNormalization(epsilon = 1e-6)

    def __call__(self, encoder_features, decoder_features, attention_mask, padding_mask, training):
        """Run one decoder layer.

        Args:
            encoder_features: Encoder output attended to by the cross-attention.
            decoder_features: Current decoder stream.
            attention_mask: Mask for the self-attention step.
            padding_mask: Mask for the cross-attention step.
            training: Forwarded to every sub-layer.

        Returns:
            The updated decoder stream.
        """
        self_attended = self.MAH1(decoder_features, decoder_features, attention_mask, training = training)
        cross_attended = self.MAH2(encoder_features, self_attended, padding_mask, training = training)

        ffn_out = self.FFN(cross_attended, training = training)
        ffn_out = self.Dropout(ffn_out, training = training)
        return self.layernorm(ffn_out + cross_attended, training = training)
class FullTransformerDecoder(keras.layers.Layer):
    """Stack of `TransformerDecoder` layers configured from ModelConfig / DataModule.

    The layer overrides `__call__` directly rather than `call`.
    """

    def __init__(self):
        super().__init__()
        # Pull the settings from the global configuration classes.
        self.encoder_dim = ModelConfig.IMG_EMB_DIM[-1]
        self.decoder_dim = DataModule.DECODER_DIM
        self.feedforward_dim = DataModule.FEEDFORWARD_DIM
        self.drop_prob = ModelConfig.transformer_drop
        self.num_layers = ModelConfig.num_decoder_layers
        self.num_heads = ModelConfig.num_att_heads

        self.DecoderLayer = [
            TransformerDecoder(
                self.encoder_dim,
                self.decoder_dim,
                self.num_heads,
                self.feedforward_dim,
                drop_prob = self.drop_prob,
            )
            for _ in range(self.num_layers)
        ]

    def __call__(self, encoder_features, decoder_features, attention_mask, padding_mask, training):
        """Feed the decoder stream through every layer in order and return the result."""
        stream = decoder_features
        for layer in self.DecoderLayer:
            stream = layer(encoder_features, stream, attention_mask, padding_mask, training = training)
        return stream
class Decoder(keras.Model):
    """Full decoder: token embedding + fixed positional encoding -> transformer stack -> vocabulary logits."""
    # Full Decoder Model
    def positional_embeddings(self, max_len, dim):
        """Build the fixed sinusoidal positional-encoding table.

        Even feature indices i hold sin(L / 10000**(i/dim)); odd indices i+1 hold
        cos(L / 10000**((i+1)/dim)), for every position L.

        Returns:
            Tensor of shape (1, max_len, dim) cast to DataModule.TARGET_DTYPE.
        """
        pos_enc = np.zeros((1, max_len, dim), dtype = np.float32)
        for L in range(max_len):
            for i in range(0, dim + 2, 2):
                # The range overshoots `dim`; these guards stop at the last valid index.
                if i >= dim:
                    break
                pos_enc[:, L, i] = math.sin(L / 10000 ** (i / dim))
                if i + 1 >= dim:
                    break
                pos_enc[:, L, i + 1] = math.cos(L / 10000 ** ((i + 1) / dim))

        return tf.cast(tf.identity(pos_enc), DataModule.TARGET_DTYPE)
    
    def __init__(self):
        super().__init__()
        # Extract Features from Config 
        self.vocab_len = DataModule.VOCAB_LEN
        self.decoder_dim = DataModule.DECODER_DIM
        self.final_drop = ModelConfig.final_drop
        self.max_len = DataModule.MAX_LEN
        
        self.embedding = keras.layers.Embedding(self.vocab_len, self.decoder_dim)
        self.decoder_transformer = FullTransformerDecoder()
        
        self.Dropout = keras.layers.Dropout(rate = self.final_drop)
        self.Linear = keras.layers.Dense(self.vocab_len)
        # Plain pass-through layer pinned to float32; applied to the final logits.
        # NOTE(review): presumably here to force float32 outputs under mixed precision -- confirm.
        self.layer = keras.layers.Layer(dtype = tf.float32)
        
        # Precomputed once; sliced to the actual sequence length in call().
        self.pos_enc = self.positional_embeddings(self.max_len, self.decoder_dim)
    
    def mask_pad(self, input_tokens, dtype):
        """Return a (B, L, 1) mask that is 1 for real tokens and 0 for PAD tokens, cast to `dtype`."""
        # From padding tokens, generates a padding mask (to mimic inference time)
        # Input_tokens: Tensor(B, L)
        pad_tok = DataModule.PAD_TOKEN
        is_pad = tf.equal(input_tokens, pad_tok) # (B, L)
        return tf.cast(1 - tf.cast(tf.expand_dims(is_pad, axis = -1), dtype = tf.uint8),dtype) # (B, L, 1)
    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Masks the upper half of the dot product matrix in self attention.

        This prevents flow of information from future tokens to current token.
        1's in the lower triangle, counting from the lower right corner.

        Returns:
            Tensor of shape (batch_size, n_dest, n_src) with the given dtype.
        """
        i = tf.range(n_dest)[:, None]
        j = tf.range(n_src)
        m = i >= j - n_src + n_dest
        mask = tf.cast(m, dtype)
        mask = tf.reshape(mask, [1, n_dest, n_src])
        # Tile the single mask along the batch axis only.
        mult = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
        )
        return tf.tile(mask, mult)
    def call(self, encoder_features, decoder_token, training):
        """Predict per-position vocabulary logits for a batch of token sequences.

        Args:
            encoder_features: Encoder output, unpacked here as (B, Enc_L, C).
            decoder_token: Token ids, unpacked here as (B, L).
            training: Forwarded to the embedding, transformer, dropout and Dense layers.

        Returns:
            Output of the final Dense layer (vocab_len units) passed through the
            float32 pass-through layer.
        """
        # Extract Embeddings
        # Encoder features: Tensor(B, L, C)
        # Decoder Token: tensor(B, L)
        B, L = decoder_token.shape
        C = self.decoder_dim
        _, Enc_L, _ = encoder_features.shape
        
        attention_mask = self.causal_attention_mask(B, L, L, tf.uint8) # (B, L, L)
        padding_mask = self.mask_pad(decoder_token, tf.uint8) # (B, L, 1) 
        
        # Combine the causal mask with the pad mask: the (B, L, 1) pad mask broadcasts
        # along the last axis, so it zeroes entire rows whose token is PAD.
        attention_mask = attention_mask * padding_mask
        decoder_embeddings = self.embedding(decoder_token, training = training) # (B, L, C)
        # Stretch the Padding Mask across encoder positions -> (B, L, Enc_L)
        padding_mask = tf.repeat(padding_mask, Enc_L, axis = -1)
        # Stretch Positional Encodings Across the Batch
        pos_enc = tf.repeat(self.pos_enc, B, axis = 0)[:, :L, :]
        # Add Decoder Pos Enc
        decoder_embeddings = decoder_embeddings + pos_enc
        # Convert Masks to Bool
        
        attention_mask = tf.cast(attention_mask, tf.bool)
        padding_mask = tf.cast(padding_mask, tf.bool) 
    
        # Decode
        decoded = self.decoder_transformer(encoder_features, decoder_embeddings, attention_mask, padding_mask, training = training) # (B, L, C)
        # Get Element wise Predictions
        dropped = self.Dropout(decoded, training = training)
        return self.layer(self.Linear(dropped, training = training)) # (B, L, NumClasses)
# Create Sample Decoder: build it on the CPU, feed it an all-PAD token batch together
# with the sample encoder output, and print the output shape and the elapsed seconds.
with tf.device("CPU:0"):
    decoder = Decoder()
    input_val = tf.ones((DataModule.BATCH_SIZE_DEBUG, DataModule.MAX_LEN), dtype = tf.uint8) * DataModule.PAD_TOKEN
    import time  # already imported in the first cell; repeated here
    cur_time = time.time()
    pred_output = decoder(img_embedding_batch, input_val, training = False)
    print(pred_output.shape)
    print(time.time() - cur_time)

# Gradient Accumulation Adam

In [None]:
class GradAccAdam():
    """AdamW wrapper that averages gradients over `grad_acc_steps` calls before applying them.

    When `ModelConfig.PrevModelPath` is set, the optimizer state saved in
    `optimizer_last.npy` under that path is restored at construction time.
    """

    def __init__(self, encoder, decoder, learning_rate, grad_acc_steps):
        self.learning_rate = learning_rate
        self.grad_acc_steps = grad_acc_steps

        self.weight_decay = TrainingConfig.weight_decay
        self.optimizer = tfa.optimizers.AdamW(learning_rate = self.learning_rate, weight_decay = self.weight_decay)

        self.PrevModelPath = ModelConfig.PrevModelPath
        if self.PrevModelPath:
            self._restore_optimizer_state(encoder, decoder)

        # Running accumulation state.
        self.gradients = None
        self.cur_grad_acc = 0

    def _restore_optimizer_state(self, encoder, decoder):
        # Load the saved optimizer weights and install them on the fresh optimizer.
        # SECURITY: allow_pickle=True executes pickled objects; only load trusted files.
        self.opt_weights = np.load(f'{self.PrevModelPath}optimizer_last.npy', allow_pickle = True)

        model_weights = encoder.trainable_weights + decoder.trainable_weights
        zero_grads = [tf.zeros_like(w) for w in model_weights]

        # Apply one all-zero gradient step first so the optimizer creates its
        # variables; set_weights below then overwrites them.
        @tf.function
        def _prime_optimizer():
            self.optimizer.apply_gradients(zip(zero_grads, model_weights))

        strategy.run(_prime_optimizer)
        self.optimizer.set_weights(self.opt_weights)
        print("Loaded Weights")

    def apply_gradients(self, gradients, variables):
        """Accumulate `gradients`; apply the running average to `variables` every `grad_acc_steps` calls.

        Each incoming gradient is pre-divided by `grad_acc_steps`, so the sum held
        in `self.gradients` is the mean over the accumulation window.
        """
        if self.gradients is None:
            # First call of the window: start the running sum.
            self.gradients = [g / tf.constant(float(self.grad_acc_steps)) for g in gradients]
        else:
            for position, grad in enumerate(gradients):
                self.gradients[position] += grad / tf.constant(float(self.grad_acc_steps))
        self.cur_grad_acc += 1

        if self.cur_grad_acc == self.grad_acc_steps:
            # Window complete: take the optimizer step and reset the accumulator.
            self.optimizer.apply_gradients(zip(self.gradients, variables))
            self.gradients = None
            self.cur_grad_acc = 0

# Lr Scheduler


In [None]:
class ParamScheduler:
    """Base class for schedules that move a value from `start` to `end` over `num_iter` steps.

    Subclasses implement `func(start, end, pct)`, where `pct` is the fraction of
    the schedule completed so far (0.0 on the first call to `step`).
    """

    def __init__(self, start, end, num_iter):
        self.start = start
        self.end = end
        self.num_iter = num_iter
        self.idx = -1  # index of the last step taken; -1 means "not started"

    def func(self, start_val, end_val, pct):
        """Return the scheduled value at progress `pct`; must be provided by subclasses."""
        raise NotImplementedError("ParamScheduler subclasses must implement func()")

    def step(self):
        """Advance by one step and return the scheduled value for it.

        A schedule with `num_iter <= 0` has no duration: it is treated as already
        finished (pct = 1.0) instead of dividing by zero.
        """
        self.idx += 1
        if self.num_iter <= 0:
            pct = 1.0
        else:
            pct = self.idx / self.num_iter
        return self.func(self.start, self.end, pct)

    def reset(self):
        """Rewind the schedule to its not-started state."""
        self.idx = -1

    def is_complete(self):
        """Return True once `num_iter` steps (or more) have been taken."""
        return self.idx >= self.num_iter

class CosineScheduler(ParamScheduler):
    """Half-cosine interpolation: yields `start_val` at pct=0 and `end_val` at pct=1."""

    def func(self, start_val, end_val, pct):
        half_span = (start_val - end_val)/2
        return end_val + half_span * (np.cos(np.pi * pct) + 1)
class ConstantScheduler(ParamScheduler):
    """Schedule that returns the same value on every step for `num_steps` steps.

    It keeps its own counter (`steps`) and does not call the base-class constructor.
    """

    def __init__(self, init_lr, num_steps):
        self.init_lr, self.num_steps = init_lr, num_steps
        self.steps = -1

    def step(self):
        """Count one step and return the constant value."""
        self.steps += 1
        return self.init_lr

    def reset(self):
        """Rewind the step counter to its not-started state."""
        self.steps = -1

    def is_complete(self):
        """Return True once `num_steps` steps (or more) have been taken."""
        return self.num_steps <= self.steps
class OneCycleScheduler(keras.callbacks.Callback):
    """Three-phase learning-rate schedule: cosine warm-up, constant peak, cosine decay.

    The schedule is driven manually through `step()`; each call writes the next
    learning rate into the global `optimizer.optimizer.learning_rate`.

    Args:
        init_lr: Stored on the instance. NOTE(review): the warm-up actually starts
            at max_lr / 25, not at this value -- confirm that is intended.
        max_lr: Peak learning rate.
        min_lr: Lower bound applied to every emitted learning rate.
        warm_steps: Fraction of the schedule spent warming up.
        peak_steps: Fraction of the schedule spent at the peak.
        total_steps: Planned number of training steps; the schedule itself is
            built 20% longer so the step counter never runs past the last phase.
    """

    def __init__(self, init_lr, max_lr, min_lr, warm_steps, peak_steps, total_steps):
        start_div = 25.

        self.pct_climax = peak_steps  # fraction of the schedule held at the peak
        self.max_lr = max_lr
        self.momentums = (0.95,0.85)
        self.start_div = start_div
        self.pct_start = warm_steps
        self.verbose = True
        self.sched = CosineScheduler
        self.end_div = start_div * 1e4
        self.logs = {}
        self.min_lr = min_lr
        self.init_lr = init_lr

        self.start_lr = self.max_lr/self.start_div
        self.end_lr = self.max_lr/self.end_div
        self.num_iter = int(total_steps * 1.2) # Pad the steps a bit to make sure there is no overflow.
        self.num_iter_1 = int(self.pct_start*self.num_iter)
        self.num_iter_2 = int(self.pct_climax * self.num_iter)
        self.num_iter_3 = self.num_iter - self.num_iter_1 - self.num_iter_2

        warm_up = self.sched(self.start_lr, self.max_lr, self.num_iter_1)
        plateau = ConstantScheduler(self.max_lr, self.num_iter_2)
        decay = self.sched(self.max_lr, self.end_lr, self.num_iter_3)
        self.lr_scheds = (warm_up, plateau, decay)
        self.sched_idx = 0  # index of the currently active phase

    def optimizer_params_step(self):
        """Take one step of the active phase and push the resulting LR (floored at min_lr) to the optimizer."""
        active_phase = self.lr_scheds[self.sched_idx]
        next_lr = tf.cast(tf.maximum(active_phase.step(), self.min_lr), tf.float32)
        # update optimizer params
        optimizer.optimizer.learning_rate.assign(next_lr)

    def step(self):
        """Update the learning rate, then move on to the next phase if the active one has finished."""
        self.optimizer_params_step()
        if self.lr_scheds[self.sched_idx].is_complete():
            self.sched_idx += 1

# Model Config

In [None]:
class Config:
    """Holds the keyword arguments that are later unpacked into the LR scheduler."""

    def __init__(self):
        # Stays empty until initialize_lr_config() is called.
        self.lr_config = {}

    def initialize_lr_config(self, total_steps, init_lr, warm_steps, max_lr, min_lr, peak_steps):
        """Replace `lr_config` with a dict of the given learning-rate settings.

        Each argument is stored unchanged under a key equal to its own name.
        """
        self.lr_config = {
            "total_steps": total_steps,
            "init_lr": init_lr,
            "warm_steps": warm_steps,
            "max_lr": max_lr,
            "min_lr": min_lr,
            "peak_steps": peak_steps,
        }
        
# Collect the learning-rate hyper-parameters from TrainingConfig into a plain
# dict; it is unpacked into the LR scheduler's constructor later on.
training_config = Config()
training_config.initialize_lr_config(
    total_steps=TrainingConfig.TOTAL_STEPS,
    init_lr=TrainingConfig.WARM_START_LR,
    warm_steps=TrainingConfig.WARM_STEPS,
    peak_steps=TrainingConfig.PEAK_STEPS,
    max_lr=TrainingConfig.PEAK_START_LR,
    min_lr=TrainingConfig.FINAL_LR,
)

print(f"TRAINING LEARNING RATE CONFIG:\n\t--> {training_config.lr_config}\n")

In [None]:
def prepare_for_training():
    """ Create every object the custom training loop needs, inside `strategy.scope()`.

    Takes no arguments: all settings are read from the module-level
    `TrainingConfig`, `DataModule`, `ModelConfig`, `training_config` and
    `strategy` objects.

    Returns:
        loss_fn (callable): `loss_fn(real, pred)` -> scalar; label-smoothed
            cross-entropy on logits in which positions where `real == 0`
            contribute zero loss.
        metrics (dict): Keras metrics keyed by 'train_loss', 'train_acc',
            'val_loss', 'val_acc' and 'val_lsd'.
        optimizer (GradAccAdam): Built over the encoder and decoder with
            learning rate `TrainingConfig.WARM_START_LR`.
        lr_scheduler (OneCycleScheduler): Built from `training_config.lr_config`.
        encoder (Encoder): Called once on a dummy batch; weights loaded from
            `ModelConfig.PrevModelPath` when that is not None.
        decoder (Decoder): Called once on a dummy batch; weights loaded from
            `ModelConfig.PrevModelPath` when that is not None.

    """
    

    # Everything must be declared within the scope when leveraging the TPU strategy
    #     - This will still function properly if scope is set to another type of accelerator
    with strategy.scope():
        
        print("\t--> CREATING LOSS FUNCTION ...")
        # Declare the loss object
        #     - Categorical cross entropy on logits is the root loss; loss_fn
        #       one-hot encodes the integer targets before calling it.
        #     - Reduction.NONE keeps one loss value per target position so the
        #       mask can be applied before averaging.
        label_smoothing = TrainingConfig.label_smoothing
        loss_object = tf.keras.losses.CategoricalCrossentropy(
            from_logits=True, reduction=tf.keras.losses.Reduction.NONE, label_smoothing = label_smoothing
        )
        
        def loss_fn(real, pred):
            # real: integer token ids (one-hot encoded below).
            # pred: logits with one extra trailing class axis compared to `real`.
            # The callers in this notebook pass rank-2 ids (train_step) and rank-1 ids (val_step).
            
            num_classes = DataModule.VOCAB_LEN
            # Positions whose target id is 0 are excluded from the loss.
            mask = tf.math.not_equal(real, 0)
            loss_ = loss_object(tf.one_hot(real, num_classes), pred)
            loss_ *= tf.cast(mask, dtype=loss_.dtype)

            # https://www.tensorflow.org/tutorials/distribute/custom_training#define_the_loss_function
            # NOTE(review): the linked tutorial divides by the *global* batch size; the
            # per-replica batch size is passed here - confirm this is intended given how
            # GradAccAdam aggregates gradients across replicas.
            loss_ = tf.nn.compute_average_loss(loss_, global_batch_size= DataModule.REPLICA_BATCH_SIZE)
            return loss_
        
        
        print("\t--> CREATING METRICS ...")
        metrics = {
            'train_loss': tf.keras.metrics.Mean(),
            'train_acc': tf.keras.metrics.SparseCategoricalAccuracy(),
            'val_loss': tf.keras.metrics.Mean(),
            'val_acc': tf.keras.metrics.SparseCategoricalAccuracy(),
            'val_lsd': tf.keras.metrics.Mean(), 
        }
        
        
        # Instantiate the encoder and call it once on an all-ones image batch
        print("\t--> CREATING ENCODER MODEL ARCHITECTURE ...")
        encoder = Encoder()
        initialization_batch = encoder(
            tf.ones(((DataModule.REPLICA_BATCH_SIZE,)+DataModule.IMG_SHAPE), dtype=DataModule.TARGET_DTYPE), 
            training=False,
        )
        # Instantiate the decoder and call it once on the encoder output plus random token ids
        print("\t--> CREATING DECODER MODEL ARCHITECTURE...")
        decoder = Decoder()
        # `pred_output` is local to this function and is not returned.
        pred_output = decoder(
            initialization_batch,
            tf.identity(np.random.randint(0, DataModule.VOCAB_LEN, size = (DataModule.REPLICA_BATCH_SIZE, DataModule.MAX_LEN), dtype = np.uint8)), 
            training=False
        )
    
        print("\t--> CREATING OPTIMIZER ...")
        optimizer = GradAccAdam(encoder, decoder, learning_rate = TrainingConfig.WARM_START_LR, grad_acc_steps = TrainingConfig.GRAD_ACCUMULATION)
        print("\t--> CREATING LEARNING RATE SCHEDULER ...")
        # Declare the learning rate schedule from the kwargs collected in training_config
        lr_scheduler = OneCycleScheduler(**training_config.lr_config)
        
        # Optionally resume the encoder from a previous run
        
        
        PrevModelPath = ModelConfig.PrevModelPath
        if PrevModelPath is not None:
            encoder.load_weights(f"{PrevModelPath}enc_last.h5")
            # NOTE(review): the encoder is frozen after GradAccAdam was constructed with
            # it - confirm the optimizer does not cache the trainable-variable list.
            if ModelConfig.FREEZE_ENCODER:
                encoder.trainable = False
            print("Loaded Encoder")
        # Optionally resume the decoder from the same previous run
        
        
        if PrevModelPath is not None:
            decoder.load_weights(f"{PrevModelPath}dec_last.h5")
            print("Loaded Decoder")
       
  
    return loss_fn, metrics, optimizer, lr_scheduler, encoder, decoder
    
    
print("\n... GENERATING THE FOLLOWING:")
# Build the loss, metrics, optimizer, LR scheduler and both models; the
# function itself opens the distribution-strategy scope.
(loss_fn, metrics, optimizer, lr_scheduler, encoder, decoder) = prepare_for_training()

print("\n... TRAINING PREPERATION FINISHED ...\n")

In [None]:
# Sanity check: evaluate the loss once on a sample batch before training starts.
# NOTE(review): in the visible part of the notebook `pred_output` is only assigned
# as a local inside prepare_for_training(); this cell relies on module-level
# `pred_output` and `SAMPLE_LBLS` from an earlier cell - confirm they exist.
print("-"*100) # Visual separator. Author's note: the summed loss should be about 14 with a batch size of 64

loss = loss_fn(SAMPLE_LBLS[:, :], pred_output)
print(f"\n\n... AVERAGE LOSS ACROSS BATCH:\n\t--> {loss}")
print("-"*100) # Visual separator

# Custom Training

In [None]:
def train_step(_image_batch, _inchi_batch):
    """ Run one teacher-forced optimisation step on a single replica.

    The token batch is split into decoder inputs (every column except the last)
    and targets (every column except the first). The loss returned by `loss_fn`
    is differentiated with respect to the trainable variables of both models,
    each gradient is clipped by its own norm, and the result is handed to the
    module-level `optimizer`.

    Args:
        _image_batch (): Images fed to the encoder.
        _inchi_batch (): Integer token ids, indexed as (batch, position).

    Returns:
        None. Updates metrics["train_acc"] and metrics["train_loss"] and
        applies the gradients as side effects.
    """
    with tf.GradientTape() as tape:
        image_features = encoder(_image_batch, training=True)

        # Shift by one position: the decoder sees tokens [0, L-1) and is scored on [1, L).
        decoder_inputs = _inchi_batch[:, :-1]
        targets = _inchi_batch[:, 1:]

        logits = decoder(image_features, decoder_inputs, training = True)
        step_loss = loss_fn(targets, logits)

        # Accuracy ignores padded target positions.
        non_pad_weights = tf.where(tf.not_equal(targets, DataModule.PAD_TOKEN), 1.0, 0.0)
        metrics["train_acc"].update_state(targets, logits, sample_weight=non_pad_weights)

    # A single variable list keeps gradients and variables aligned.
    #    - split this into two separate optimizers/lrs/etc in the future
    trainable_vars = encoder.trainable_variables + decoder.trainable_variables
    grads = tape.gradient(step_loss, trainable_vars)
    # Each gradient tensor is clipped independently (not by global norm).
    clipped_grads = [tf.clip_by_norm(grad, TrainingConfig.clip_grad) for grad in grads]

    # The logged loss is divided by the number of predicted positions.
    metrics["train_loss"].update_state(step_loss / (DataModule.MAX_LEN - 1))

    optimizer.apply_gradients(clipped_grads, trainable_vars)
@tf.function
def dist_train_step(iterator, num_steps):
    """Run `num_steps` consecutive distributed training steps in one compiled call.

    Args:
        iterator: Iterator over the distributed training dataset; each item is
            unpacked as the arguments of `train_step`.
        num_steps: Number of batches to pull from `iterator` and train on.
    """
    for _ in tf.range(num_steps):
        strategy.run(train_step, args=next(iterator))
def val_step(_image_batch, _inchi_batch):
    """ Greedily decode one per-replica validation batch and score it.

    No gradients are computed and there is no teacher forcing: the sequence
    starts from a column of ones and, at every position, the decoder is re-run
    on everything predicted so far; its arg-max token is appended and becomes
    part of the next input.

    Args:
        _image_batch (): Images fed to the encoder.
        _inchi_batch (): Ground-truth integer token ids, indexed as (batch, position).

    Returns:
        The predicted token ids (uint8), one column per position, including the
        initial column of ones. Updates metrics["val_acc"] and
        metrics["val_loss"] as side effects.
    """
    # Running sum of the per-position losses.
    total_loss = tf.constant(0.0, tf.float32)

    image_features = encoder(_image_batch, training=False)

    # Every sequence starts with token id 1.
    decoded_tokens = tf.ones((DataModule.REPLICA_BATCH_SIZE, 1), dtype=tf.uint8)

    for position in range(1, DataModule.MAX_LEN):
        targets = _inchi_batch[:, position]

        # Only the logits of the newest position are scored.
        step_logits = decoder(image_features, decoded_tokens, training=False)[:, -1]
        total_loss += loss_fn(targets, step_logits)

        # Accuracy ignores padded target positions.
        non_pad_weights = tf.where(tf.not_equal(targets, DataModule.PAD_TOKEN), 1.0, 0.0)
        metrics["val_acc"].update_state(targets, step_logits, sample_weight=non_pad_weights)

        # Greedy choice: the most probable token is appended to the running sequence.
        probabilities = keras.activations.softmax(step_logits, axis = -1)
        next_token = tf.math.argmax(probabilities, axis=1, output_type=tf.int32)
        next_token = tf.expand_dims(tf.cast(next_token, tf.uint8), axis=1)
        decoded_tokens = tf.concat([decoded_tokens, next_token], axis=1)

    # The logged loss is the per-position average.
    metrics["val_loss"].update_state(total_loss / (DataModule.MAX_LEN - 1))

    return decoded_tokens

    
@tf.function
def dist_val_step(_val_image_batch, _val_inchi_batch):
    """Run `val_step` on every replica and gather the results along the batch axis.

    Returns:
        (predictions, labels): the per-replica predictions and the input
        labels, each gathered across replicas on axis 0.
    """
    per_replica_preds = strategy.run(val_step, args=(_val_image_batch, _val_inchi_batch))
    gathered_preds = strategy.gather(per_replica_preds, axis=0)
    gathered_labels = strategy.gather(_val_inchi_batch, axis=0)
    return gathered_preds, gathered_labels

# Visualize LR

In [None]:
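# Uncomment only for testing: every lr_scheduler.step() call advances the scheduler's internal state (and the optimizer's learning rate), so re-create the scheduler before real training.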
#import matplotlib.pyplot as plt
#import tqdm.notebook as tqdm
#lrs = []

#for i in tqdm.tqdm(range(TrainingConfig.NUM_EPOCHS * DataModule.TRAIN_STEPS)):
#    lrs += [optimizer.optimizer.learning_rate.numpy()]
#    lr_scheduler.step()
#plt.plot(lrs)

# Stat Logger

In [None]:
class StatLogger:
    """Prints training/validation statistics and checkpoints the best models.

    Relies on the module-level `optimizer`, `encoder`, `decoder`,
    `TrainingConfig` and `ModelConfig` objects.
    """

    def __init__(self, verbose_frequency=100):
        # `verbose_frequency` is accepted but not read anywhere in this class.
        self.epochs = 0
        self.steps = 0

        # Best validation values seen so far.
        self.best_loss = float('inf')
        self.best_lsd = float('inf')
        self.best_acc = 0

    @staticmethod
    def _rounded(metric):
        """Current metric value as a Python number rounded to 3 decimals."""
        return round(metric.result().numpy().item(), 3)

    @staticmethod
    def _save_weights(encoder_path, decoder_path):
        """Write the encoder weights, then the decoder weights."""
        encoder.save_weights(encoder_path)
        decoder.save_weights(decoder_path)

    def print_train(self, metrics):
        """Print the running training stats, advance the step counter, reset train metrics."""
        train_acc = self._rounded(metrics['train_acc'])
        train_loss = self._rounded(metrics['train_loss'])
        # Learning rate currently set on the wrapped optimizer.
        cur_lr = optimizer.optimizer.learning_rate.numpy().item()
        print(f'E: {self.epochs}, S: {self.steps}, BL: {self.best_loss}, BLSD: {self.best_lsd}, BA: {self.best_acc}, TA: {train_acc}, TL: {train_loss}, LR: {cur_lr}')
        self.steps += TrainingConfig.PRINT_EVERY

        # Only the training metrics start over; validation metrics are untouched.
        for name in ('train_acc', 'train_loss'):
            metrics[name].reset_states()

    def print_val(self, metrics):
        """Checkpoint on every improved validation stat, print a summary, reset all metrics."""
        val_acc = self._rounded(metrics['val_acc'])
        val_loss = self._rounded(metrics['val_loss'])
        val_lsd = self._rounded(metrics['val_lsd'])

        save_prefix = ModelConfig.BaseSavePath

        # The three criteria are independent, so one epoch can write up to three checkpoint pairs.
        if val_lsd < self.best_lsd:
            self.best_lsd = val_lsd
            self._save_weights(f"{save_prefix}enc_lsd.h5", f"{save_prefix}dec_lsd.h5")
        if val_acc > self.best_acc:
            self.best_acc = val_acc
            self._save_weights(f"{save_prefix}enc_acc.h5", f'{save_prefix}dec_acc.h5')
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self._save_weights(f'{save_prefix}enc_loss.h5', f'{save_prefix}dec_loss.h5')

        print(f"E: {self.epochs}, BA: {self.best_acc}, BL: {self.best_loss}, BLSD: {self.best_lsd}, VA: {val_acc} VL:{val_loss} VLSD: {val_lsd}")
        self.epochs += 1
        self.steps = 0

        # Start the next epoch with every metric cleared.
        for tracked in metrics.values():
            tracked.reset_states()



In [None]:
# sparse tensors are required to compute the Levenshtein distance
def dense_to_sparse(dense):
    """ Convert a dense tensor to a sparse tensor that stores every element.

    Zeros are kept as explicit values: the index list covers all positions of
    `dense`, not only the non-zero ones.

    The values are flattened with an inferred length, so the function works for
    any batch size and sequence length rather than only for
    (DataModule.OVERALL_BATCH_SIZE, DataModule.MAX_LEN).

    Args:
        dense (Tensor): Dense tensor to convert.

    Returns:
        A tf.SparseTensor with the same shape and element values as `dense`.
    """
    # tf.where on an all-ones tensor lists every index in row-major order ...
    indices = tf.where(tf.ones_like(dense))
    # ... which is the same order a flattening reshape produces.
    values = tf.reshape(dense, [-1])
    # Take the shape at run time (as int64) so it is defined even when the static shape is not.
    dense_shape = tf.shape(dense, out_type=tf.int64)
    return tf.SparseTensor(indices, values, dense_shape)

def levenshtein_distance(preds, lbls):
    """ Mean unnormalised edit distance between predicted and label sequences.

    Before comparing, predictions are zeroed wherever the label is the end or
    pad token, and end tokens in the labels are zeroed too. Because
    `dense_to_sparse` keeps zeros as explicit values, the sequences that are
    compared keep their full length.

    Args:
        preds (tensor): Batch of predictions
        lbls (tensor): Batch of labels

    Returns:
        The mean Levenshtein distance calculated across the batch
    """
    label_not_end = tf.not_equal(lbls, DataModule.END_TOKEN)
    label_not_pad = tf.not_equal(lbls, DataModule.PAD_TOKEN)

    # Keep a prediction only where the label holds a real (non-end, non-pad) token.
    masked_preds = tf.where(tf.logical_and(label_not_end, label_not_pad), preds, 0)
    masked_lbls = tf.where(label_not_end, lbls, 0)

    sparse_preds = dense_to_sparse(tf.cast(masked_preds, tf.uint8))
    sparse_lbls = dense_to_sparse(tf.cast(masked_lbls, tf.uint8))

    per_example_distance = tf.edit_distance(sparse_preds, sparse_lbls, normalize=False)
    return tf.math.reduce_mean(per_example_distance)

# Training Loop

In [None]:
# Instantiate our tool for logging
stat_logger = StatLogger()
# Number of training steps run per compiled dist_train_step call (and per log line)
print_every = TrainingConfig.PRINT_EVERY
for epoch in range(TrainingConfig.NUM_EPOCHS):
    print(f"epoch: {epoch}")

    # Rebuild the datasets and their distributed iterators at the start of every epoch
    train_ds, val_ds = load_dataset()
    train_dist_ds = iter(strategy.experimental_distribute_dataset(train_ds))
    val_dist_ds = iter(strategy.experimental_distribute_dataset(val_ds))

    # Train in chunks of `print_every` steps; the "+ 1" rounds the chunk count up
    for i in tqdm(range((TrainingConfig.TRAIN_STEPS // print_every) + 1)):
        dist_train_step(train_dist_ds, print_every)
        # The scheduler cannot be stepped inside the tf.function, so it is caught up here:
        # one step() per training step that just ran. The learning rate is therefore
        # constant within a chunk and moves between chunks.
        for j in range(print_every):
            lr_scheduler.step()
        stat_logger.print_train(metrics)

    print("\n... VALIDATION DATASET STATISTICS ... \n")
    for images, inchi in val_dist_ds:
        preds, lbls = dist_val_step(images, inchi)
        # val_loss / val_acc are updated inside val_step; the edit distance is added here
        metrics["val_lsd"].update_state(levenshtein_distance(preds, lbls))
    stat_logger.print_val(metrics)

In [None]:
# Save Final Weights of the Model
def _pack_ragged(arrays):
    """Pack arrays of differing shapes into a 1-D object array, one element per array."""
    packed = np.empty(len(arrays), dtype=object)
    for idx, array in enumerate(arrays):
        packed[idx] = array
    return packed


basePath = ModelConfig.BaseSavePath
with strategy.scope():
    encoder.save_weights(f"{basePath}enc_last.h5") # Save Last Encoder Weights
    decoder.save_weights(f"{basePath}dec_last.h5") # Save Last Decoder Weights
    # The optimizer weights have differing shapes, so they are packed into an explicit
    # object array: recent NumPy versions refuse to build one implicitly from a ragged
    # list. Load the file back with np.load(..., allow_pickle=True).
    np.save(
        f"{basePath}optimizer_last.npy",
        _pack_ragged(optimizer.optimizer.get_weights()),
        allow_pickle=True,
    )