In [None]:
%%capture
print("\n... PIP INSTALLS COMPLETE ...\n")

print("\n... IMPORTS STARTING ...\n")
print("\n\tVERSION INFORMATION")
# Machine Learning and Data Science Imports
import tensorflow as tf
import tensorflow.keras as keras
import tensorflow_addons as tfa
import pandas as pd; pd.options.mode.chained_assignment = None;
import numpy as np

# Built In Import
from kaggle_datasets import KaggleDatasets
from collections import Counter
from glob import glob
import random
import math
from tqdm.notebook import tqdm
import os

AUTO = tf.data.experimental.AUTOTUNE
def seed_it_all(seed=7):
    """Seed every random number generator used by this notebook.

    Args:
        seed (int): Value applied to Python hashing, ``random``, NumPy and TensorFlow.
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Each of these seeders accepts the seed as its first positional argument.
    for apply_seed in (random.seed, np.random.seed, tf.random.set_seed):
        apply_seed(seed)
seed_it_all()

import sys
sys.path.append("/kaggle/input/experimental-efficientnetv2/automl")
sys.path.append("/kaggle/input/experimental-efficientnetv2/automl/brain_automl")
sys.path.append("/kaggle/input/experimental-efficientnetv2/automl/brain_automl/efficientnetv2")

# Google brain Imports

# EfficientNet Module Imports
import brain_automl
from brain_automl import efficientnetv2
from efficientnetv2 import effnetv2_model
from efficientnetv2 import effnetv2_configs


In [None]:
#!conda install -y -c rdkit rdkit # For Normalization

In [None]:
# Connect to a TPU when one is available; otherwise fall back to the default
# (CPU / single-GPU) distribution strategy.
try:
    # TPU detection. No parameters necessary if TPU_NAME environment variable is set. On Kaggle this is always the case.
    TPU = tf.distribute.cluster_resolver.TPUClusterResolver()
    tf.config.experimental_connect_to_cluster(TPU)
    tf.tpu.experimental.initialize_tpu_system(TPU)
    strategy = tf.distribute.experimental.TPUStrategy(TPU)
    tf.keras.mixed_precision.set_global_policy('mixed_bfloat16' if TPU else 'float32')
    tf.config.optimizer.set_jit(True)
except Exception as tpu_error:
    # `except Exception` (instead of a bare `except:`) lets KeyboardInterrupt /
    # SystemExit propagate, and the message records why the fallback was taken.
    print(f"... TPU NOT AVAILABLE ({type(tpu_error).__name__}: {tpu_error}) - USING DEFAULT STRATEGY ...")
    TPU = None
    strategy = tf.distribute.get_strategy()
# Number of replicas the global batch is split across.
N_REPLICAS = strategy.num_replicas_in_sync

In [None]:
class DataModule:
    """Static data / inference configuration, evaluated once when the class body runs.

    Defining this class has side effects: it globs the test TFRecord files, reads
    the competition CSV files and prints the prefix ordering. It reads the
    module-level ``TPU`` and ``N_REPLICAS`` values.
    """
    base_dir = 'bmsdatatest'
    # Resolve the GCS bucket only when a TPU is attached (same pattern as
    # ModelConfig.GCS_PATH); otherwise use the locally mounted dataset and skip
    # the remote lookup entirely.
    DATA_DIR = KaggleDatasets().get_gcs_path(base_dir) if TPU else f"/kaggle/input/{base_dir}"
    TARGET_DTYPE = tf.bfloat16 if TPU else tf.float32
    # Raw strings keep the regex-style escaped tokens free of invalid escape
    # sequences; the backslashes are stripped again just below.
    TOKEN_LIST = ["<PAD>", "InChI=1S/", "<END>", "/c", "/h", "/m", "/t", "/b", "/s", "/i"] +\
             ['Si', 'Br', 'Cl', 'F', 'I', 'N', 'O', 'P', 'S', 'C', 'H', 'B', ] +\
             [str(i) for i in range(167,-1,-1)] +\
             [r"\+", r"\(", r"\)", r"\-", ",", "D", "T"]
    # Get rid of backslashes
    for token_idx in range(len(TOKEN_LIST)):
        TOKEN_LIST[token_idx] = TOKEN_LIST[token_idx].strip('\\')

    # The start/end/pad tokens will be removed from the string when computing the Levenshtein distance
    # We want them as tf.constant's so they will operate properly within the @tf.function context
    VOCAB_LEN = len(TOKEN_LIST)
    START_TOKEN = tf.constant(TOKEN_LIST.index("InChI=1S/"), dtype=tf.uint8)
    END_TOKEN = tf.constant(TOKEN_LIST.index("<END>"), dtype=tf.uint8)
    PAD_TOKEN = tf.constant(TOKEN_LIST.index("<PAD>"), dtype=tf.uint8)

    # Prefixes and Their Respective Ordering/Format
    #      -- ORDERING --> {c}{h/None}{b/None}{t/None}{m/None}{s/None}{i/None}{h/None}{t/None}{m/None}
    PREFIX_ORDERING = "chbtmsihtm"
    print(f"\n... PREFIX ORDERING IS {PREFIX_ORDERING} ...")

    # Path to the directory holding the test TFRecords
    TEST_DIR = os.path.join(DATA_DIR, "test_records")

    # Full paths to the individual TFRecord files, ordered by the integer that
    # sits between the last two underscores of each file name.
    TEST_TFREC_PATHS = sorted(
        tf.io.gfile.glob(os.path.join(TEST_DIR, "*.tfrec")),
        key=lambda x: int(x.rsplit("_", 2)[1]))

    # Paths to relevant CSV files containing training and submission information
    TRAIN_CSV_PATH = os.path.join("/kaggle/input", "bms-molecular-translation", "train_labels.csv")
    SS_CSV_PATH    = os.path.join("/kaggle/input", "bms-molecular-translation", "sample_submission.csv")
    # When DEBUG is true the per-replica batch size drops to BATCH_SIZE_DEBUG
    DEBUG=False

    train_df = pd.read_csv(TRAIN_CSV_PATH)
    ss_df    = pd.read_csv(SS_CSV_PATH)

    # --- Distribution Information ---
    N_EX    = len(train_df)
    N_TEST  = len(ss_df)
    N_VAL   = 100_000 # 80000 # Fixed from dataset creation information
    N_TRAIN = N_EX-N_VAL

    BATCH_SIZE_DEBUG   = 2
    REPLICA_BATCH_SIZE = 16 # Larger batch size at inference
    if DEBUG:
        REPLICA_BATCH_SIZE = BATCH_SIZE_DEBUG
    OVERALL_BATCH_SIZE = REPLICA_BATCH_SIZE*N_REPLICAS

    # --- Input Image Information ---
    IMAGE_SIZE = 384
    IMG_SHAPE = (192,384,3)

    # --- Autocalculate Test Information ---
    # Slice [START_TF:END_TF] of TEST_TFREC_PATHS is processed by this run.
    START_TF = 0
    END_TF = 21#len(TEST_TFREC_PATHS)
    COMPLETED_TF = 0
    NUM_TF = (END_TF - START_TF) # 16107 Examples in the last TFRec.
    ALL_TF = False # True on Second Half
    # Despite its name, TEST_STEPS is a count of examples, not of batches.
    # The arithmetic assumes 40000 examples per TFRecord file - TODO confirm.
    if ALL_TF:
        # Compute Number of TF's already completed
        COMPLETED_TF = START_TF
        TEST_STEPS = len(ss_df) - COMPLETED_TF * 40000
    else:
        TEST_STEPS = NUM_TF * 40000
    TEST_BATCHES = (TEST_STEPS // OVERALL_BATCH_SIZE) + 1
    # Pad to avoid dropping examples. The value lies in 1..OVERALL_BATCH_SIZE
    # (never 0): an exactly divisible count gets one whole extra batch of padding.
    REQUIRED_DATASET_PAD = OVERALL_BATCH_SIZE-TEST_STEPS%OVERALL_BATCH_SIZE
    # --- Modelling Information --
    FEEDFORWARD_DIM = 2048 # Width of the decoder feed-forward layers
    DECODER_DIM   = 512 # increase to 1024 later

    MAX_LEN = 300 # Max length allowed at inference
    seed = 7

    # Per-channel mean / std, reshaped to (1, 1, 3).
    mean_train = np.array([0.485, 0.456, 0.406])#np.array([0.9871, 0.9871, 0.9871])
    mean_train = np.expand_dims(np.expand_dims(mean_train, axis = 0), axis = 0)
    std_train = np.array([0.229, 0.224, 0.225])#np.array([0.0888,0.0888,0.0888])
    std_train = np.expand_dims(np.expand_dims(std_train, axis = 0), axis = 0)
    stats = (mean_train, std_train)

In [None]:
class TrainingConfig:
    """Optimisation constants (loop bookkeeping, regularisation, learning-rate schedule)."""

    # --- Loop bookkeeping ---
    NUM_EPOCHS = 16
    PRINT_EVERY = 100
    GRAD_ACCUMULATION = 1

    # --- Regularisation ---
    label_smoothing = 0.05
    weight_decay = 1e-6
    clip_grad = 20.0

    # --- Step counts (derived from DataModule) ---
    TEST_STEPS = DataModule.TEST_STEPS
    TOTAL_STEPS = TEST_STEPS * NUM_EPOCHS

    # --- Learning-rate schedule ---
    # WARM_STEPS / PEAK_STEPS are presumably fractions of TOTAL_STEPS - TODO confirm
    # against the schedule function that consumes them.
    WARM_STEPS = 0.1
    PEAK_STEPS = 0.2
    WARM_START_LR = 1e-6
    PEAK_START_LR = 5e-4
    FINAL_LR = 1e-5

In [None]:
def filter_ex(x, y):
    """Dataset predicate: keep an example only when its label fits in MAX_LEN tokens.

    Args:
        x: Image tensor; unused, present so the signature matches an
            ``(image, label)`` dataset element.
        y: Token-id tensor - assumed 1-D and right-padded with PAD_TOKEN;
            TODO confirm against the caller.

    Returns:
        Scalar boolean tensor: True when the number of tokens before the first
        PAD_TOKEN (or the whole length if there is no PAD_TOKEN) is <= MAX_LEN.
    """
    is_pad = tf.equal(y, DataModule.PAD_TOKEN)
    # argmax over the 0/1 mask yields the index of the first pad token, which
    # equals the unpadded length of the label.
    first_pad = tf.argmax(tf.cast(is_pad, tf.int32), output_type=tf.int32)
    # Without any PAD token argmax returns 0, which used to let arbitrarily
    # long labels through; fall back to the full length in that case.
    label_len = tf.where(tf.reduce_any(is_pad), first_pad, tf.shape(y)[0])
    return label_len <= DataModule.MAX_LEN
def cut_off(x, y):
    """Return ``x`` unchanged together with ``y`` truncated to its first MAX_LEN entries."""
    truncated = y[:DataModule.MAX_LEN]
    return x, truncated
def load_image(one_sample):
    """Parse one serialized example into ``(image, image_id)``.

    Args:
        one_sample: Serialized ``tf.train.Example`` carrying the string features
            ``image`` (PNG bytes) and ``image_id``.

    Returns:
        tuple: The decoded 3-channel image reshaped to ``DataModule.IMG_SHAPE``,
        scaled by 1/255 and cast to ``DataModule.TARGET_DTYPE``, and the image id.
    """
    # Both features are scalar strings that default to the empty string.
    schema = {
        key: tf.io.FixedLenFeature(shape=[], dtype=tf.string, default_value='')
        for key in ('image', 'image_id')
    }
    parsed = tf.io.parse_single_example(one_sample, features=schema)

    pixels = tf.io.decode_png(parsed['image'], channels = 3)
    pixels = tf.reshape(pixels, DataModule.IMG_SHAPE)
    # Only a 1/255 rescale happens here; no mean/std normalisation is applied.
    pixels = tf.cast(pixels, tf.float32) / 255.0
    return tf.cast(pixels, DataModule.TARGET_DTYPE), parsed['image_id']

In [None]:
def load_dataset():
    """Build the batched test dataset of ``(images, image_ids)``.

    Reads the TFRecord shards ``TEST_TFREC_PATHS[START_TF:END_TF]`` without
    enforcing element order, decodes them with ``load_image``, appends
    ``REQUIRED_DATASET_PAD`` all-zero filler examples (id ``"000000000000"``)
    and batches by ``OVERALL_BATCH_SIZE``.

    Returns:
        tf.data.Dataset: Batched and prefetched dataset.
    """
    read_options = tf.data.Options()
    read_options.experimental_deterministic = False

    shard_paths = DataModule.TEST_TFREC_PATHS[DataModule.START_TF:DataModule.END_TF]
    dataset = (
        tf.data.TFRecordDataset(shard_paths, num_parallel_reads = AUTO)
        .with_options(read_options)
        .map(load_image, num_parallel_calls = AUTO, deterministic = False)
    )

    n_filler = DataModule.REQUIRED_DATASET_PAD
    if n_filler != 0:
        # The filler examples are appended after all real examples.
        filler_images = tf.zeros((n_filler, *DataModule.IMG_SHAPE), dtype=DataModule.TARGET_DTYPE)
        filler_ids = tf.constant(["000000000000",]*n_filler, dtype=tf.string)
        dataset = dataset.concatenate(
            tf.data.Dataset.from_tensor_slices((filler_images, filler_ids))
        )

    return dataset.batch(DataModule.OVERALL_BATCH_SIZE).prefetch(AUTO)


Model Config

In [None]:
class eff_net_v2_model(keras.Model):
    """Thin Keras wrapper around ``effnetv2_model.EffNetV2Model``.

    ``enc_dim`` is accepted but unused: the 1x1 projection + batch-norm head it
    was meant for is disabled.
    """

    def __init__(self, model_name, enc_dim):
        super().__init__()
        self.eff_net_v2_model = effnetv2_model.EffNetV2Model(
            model_name=model_name,
            name = 'eff_net_v2_model',
        )
        # Disabled projection head (EffNetV2 has no projection layer of its own):
        #self.proj = keras.layers.Conv2D(enc_dim, 1, use_bias = False, kernel_initializer = 'he_uniform', activation = 'relu')
        #self.bn = keras.layers.BatchNormalization()

    def call(self, x, training, features_only = True):
        """Run the backbone and return its output unchanged."""
        return self.eff_net_v2_model(x, training = training, features_only = features_only)

In [None]:
def get_efficientnetv2_backbone(model_name, weight_path, enc_dim, include_top=False, input_shape=(192,384,3), pooling=None, weights=None):
    """Build the EfficientNetV2 backbone and optionally restore checkpoint weights.

    Args:
        model_name (str): EfficientNetV2 variant name passed to ``eff_net_v2_model``.
        weight_path (str or None): Directory searched with
            ``tf.train.latest_checkpoint``; ``None`` skips weight loading.
        enc_dim (int): Forwarded to ``eff_net_v2_model``.
        include_top, input_shape, pooling, weights: Accepted for signature
            compatibility but ignored; the model is always built for
            ``DataModule.IMG_SHAPE``.

    Returns:
        eff_net_v2_model: The built (and, when a checkpoint was found, restored) model.
    """
    model = eff_net_v2_model(model_name, enc_dim)
    model.build((None, *DataModule.IMG_SHAPE))
    # Layer-list surgery: drop the second-to-last tracked layer and append a
    # second reference to the last one. NOTE(review): presumably this makes the
    # object graph match the checkpoint layout - confirm before touching.
    info = model.eff_net_v2_model._layers[-1]
    model.eff_net_v2_model._layers.pop(-2)
    model.eff_net_v2_model._layers += [info]
    if weight_path is not None:
        latest_ckpt = tf.train.latest_checkpoint(weight_path)
        if latest_ckpt is None:
            # latest_checkpoint returns None when nothing is found; do not
            # claim success in that case.
            print(f"WARNING: no checkpoint found under {weight_path}; encoder keeps its initial weights")
        else:
            checkpoint = tf.train.Checkpoint(net = model)
            checkpoint.restore(latest_ckpt)
            print("Loaded Encoder Weights")

    return model
    
class ModelConfig:
    # Architecture hyper-parameters. The class body has side effects: it may
    # resolve a GCS path, prints it, and builds a throw-away backbone to
    # measure the encoder's output shape.
    # ENCODER_CONFIG
    enc_dim = 640 # Controls the Dimension of the Enc Dim
    enc_dec_drop = 0.0 # Dropout Between Encoder and Decoder
    transformer_drop = 0.2
    final_drop = 0.5
    num_encoder_layers = 4
    num_decoder_layers = 6 # 6 layers is the max (the official Transformer uses 6); more becomes very slow.
    num_att_heads = 12 # Standard Transformer Config.
    # Backbone variant: the Large EfficientNetV2 model.
    model_name = 'efficientnetv2-l' # Chosen mostly for model capacity, not for the pretrained weights.
    base_dir = 'efficientnet-v2-weights'
    # Only resolve the GCS bucket when a TPU is attached; otherwise use the local input directory.
    GCS_PATH = KaggleDatasets().get_gcs_path(base_dir) if TPU else '../input/efficientnet-v2-weights'
    print(GCS_PATH)
    weight_path = f'{GCS_PATH}/efficientnetv2-l-21k/efficientnetv2-l-21k/'
    # Factory used to build the backbone (a plain function when read via the class).
    BB_FN = get_efficientnetv2_backbone
    PREPROCESSING_FN = tf.keras.applications.efficientnet.preprocess_input
    FREEZE_ENCODER = False
    
    # Build a temporary backbone and push a dummy batch through it to read the
    # feature-map shape; IMG_EMB_DIM then becomes (height * width, channels).
    tmp_model = BB_FN(model_name, weight_path,enc_dim, include_top=False, input_shape=DataModule.IMG_SHAPE)
    IMG_EMB_DIM = tmp_model(tf.ones((DataModule.BATCH_SIZE_DEBUG, *DataModule.IMG_SHAPE)), features_only = True).shape[1:]
    IMG_EMB_DIM = (IMG_EMB_DIM[0]*IMG_EMB_DIM[1], IMG_EMB_DIM[2])
    


In [None]:
print("\n... ENCODER MODEL CREATION STARTING ...\n")
    
class CNNEncoder(tf.keras.Model):
    """CNN backbone that turns an image batch into a sequence of feature vectors."""

    def __init__(self):
        super().__init__()

        # Settings are copied from the config classes; the assignment order is
        # kept stable because Keras tracks sub-layers in creation order.
        self.image_embedding_dim = ModelConfig.IMG_EMB_DIM
        self.preprocessing_fn = ModelConfig.PREPROCESSING_FN
        self.model_name = ModelConfig.model_name
        self.weight_path = ModelConfig.weight_path
        self.enc_dim = ModelConfig.enc_dim
        self.backbone_fn = ModelConfig.BB_FN
        self.img_shape = DataModule.IMG_SHAPE
        self.encoder_backbone = self.backbone_fn(
            self.model_name,
            self.weight_path,
            self.enc_dim,
            include_top=False,
            weights=None,
            input_shape=self.img_shape,
        )
        self.dropout = ModelConfig.enc_dec_drop
        self.spat_drop = keras.layers.SpatialDropout2D(rate = self.dropout)
        self.reshape = tf.keras.layers.Reshape(self.image_embedding_dim, name='image_embedding')

    def call(self, x, training):
        """Encode a batch of images.

        Args:
            x: Image batch fed to the preprocessing function and the backbone.
            training: Forwarded to the backbone, the spatial dropout and the reshape layer.

        Returns:
            The backbone feature map after spatial dropout, reshaped per example
            to ``ModelConfig.IMG_EMB_DIM``.
        """
        preprocessed = self.preprocessing_fn(x)
        feature_map = self.encoder_backbone(preprocessed, training=training)
        feature_map = self.spat_drop(feature_map, training = training)
        return self.reshape(feature_map, training=training)


In [None]:
class EncoderMultiHeadAttention(keras.layers.Layer):
    """Pre-norm self-attention block followed by a pre-norm dense block, both with residuals.

    NOTE(review): this class overrides ``__call__`` rather than ``call``, so the
    standard Keras ``Layer.__call__`` wrapper is bypassed. ``drop_prob`` is
    stored but no dropout layer is created.
    """

    def __init__(self, encoder_dim, num_heads, drop_prob = 0.1):
        super().__init__()
        self.encoder_dim = encoder_dim
        self.num_heads = num_heads
        self.drop_prob = drop_prob

        self.layer_norm1 = keras.layers.LayerNormalization(epsilon = 1e-6)
        # Per-head key size is encoder_dim // num_heads.
        self.MAH = keras.layers.MultiHeadAttention(self.num_heads, self.encoder_dim // self.num_heads)
        self.layer_norm2 = keras.layers.LayerNormalization(epsilon = 1e-6)
        self.Linear = keras.layers.Dense(self.encoder_dim)

    def __call__(self, x, training):
        # Self-attention on the normalised input, plus residual.
        normed_input = self.layer_norm1(x, training = training)
        attended = self.MAH(query = normed_input, key = normed_input, value = normed_input, training = training) + x

        # Dense projection on the normalised attention output, plus residual.
        normed_attended = self.layer_norm2(attended, training = training)
        return self.Linear(normed_attended, training = training) + attended
class TransformerEncoder(keras.Model):
    """Stack of ``EncoderMultiHeadAttention`` blocks with a learned positional embedding."""

    def __init__(self):
        super().__init__()
        self.num_layers = ModelConfig.num_encoder_layers
        self.num_heads = ModelConfig.num_att_heads
        self.enc_dim = ModelConfig.IMG_EMB_DIM[-1]  # Feature width coming from the CNN encoder
        self.drop_prob = ModelConfig.enc_dec_drop
        self.length_input = ModelConfig.IMG_EMB_DIM[0]  # Sequence length of the input

        self.initializer = keras.initializers.glorot_uniform(seed = 42)
        # Trainable positional embedding of shape (1, length_input, enc_dim).
        self.pos_enc = tf.Variable(
            self.initializer(shape = (1, self.length_input, self.enc_dim)),
            trainable = True,
        )
        self.encoders = [
            EncoderMultiHeadAttention(self.enc_dim, self.num_heads, drop_prob = self.drop_prob)
            for _ in range(self.num_layers)
        ]

    def call(self, x, training):
        """Add the positional embedding (broadcast over the batch) and apply every block in order."""
        hidden = x + tf.cast(self.pos_enc, DataModule.TARGET_DTYPE)
        for block in self.encoders:
            hidden = block(hidden, training = training)
        return hidden
class Encoder(keras.Model):
    """Full image encoder: ``CNNEncoder`` followed by ``TransformerEncoder``."""

    def __init__(self):
        super().__init__()
        self.CNN = CNNEncoder()
        self.transformer = TransformerEncoder()

    def call(self, x, training):
        """Return the transformer-encoded CNN features of ``x``."""
        cnn_features = self.CNN(x, training = training)
        return self.transformer(cnn_features, training = training)
# Instantiate one Encoder with its variables placed on the CPU device.
# NOTE(review): prepare_for_training() creates its own Encoder instances in a
# local variable, so this module-level `encoder` is not referenced by the code
# visible in this notebook - confirm whether it is still needed.
with tf.device("CPU: 0"):
    encoder = Encoder()


# Transformer Decoder

In [None]:
class MultiHeadAttention(keras.layers.Layer):
    """Attention sub-layer: attention -> dropout -> residual add -> layer norm.

    The query comes from ``decoder_features``; key and value both come from
    ``encoder_features``. NOTE(review): this class overrides ``__call__`` rather
    than ``call``.
    """

    def __init__(self, encoder_features, decoder_features, num_heads, drop_prob = 0.1):
        super().__init__()
        self.encoder_features = encoder_features
        self.decoder_features = decoder_features
        self.drop_prob = drop_prob
        self.num_heads = num_heads

        self.MAH = keras.layers.MultiHeadAttention(
            self.num_heads,
            self.encoder_features // self.num_heads,
            value_dim = self.decoder_features // self.num_heads,
        )
        self.Dropout = keras.layers.Dropout(rate = self.drop_prob)
        self.layernorm = keras.layers.LayerNormalization(epsilon = 1e-6)

    def __call__(self, encoder_features, decoder_features, attention_mask, training):
        attended = self.MAH(
            query = decoder_features,   # (B, T, Dim)
            key = encoder_features,
            value = encoder_features,   # (B, S, Dim)
            attention_mask = attention_mask,
            training = training,
        )
        residual = self.Dropout(attended, training = training) + decoder_features
        return self.layernorm(residual, training = training)

In [None]:
class TransformerDecoder(keras.layers.Layer):
    """One decoder layer: self-attention, attention over the encoder output, feed-forward.

    NOTE(review): this class overrides ``__call__`` rather than ``call``.
    """

    def __init__(self, encoder_features, decoder_features, num_heads, feedforward_dim, drop_prob = 0.1):
        super().__init__()
        self.encoder_features = encoder_features
        self.decoder_features = decoder_features
        self.feedforward_dim = feedforward_dim

        self.num_heads = num_heads
        self.drop_prob = drop_prob

        # MAH1 attends over the decoder sequence itself, MAH2 over the encoder output.
        self.MAH1 = MultiHeadAttention(self.decoder_features, self.decoder_features, self.num_heads, drop_prob = self.drop_prob)
        self.MAH2 = MultiHeadAttention(self.encoder_features, self.decoder_features, self.num_heads, drop_prob = self.drop_prob)

        self.FFN = keras.Sequential([
            keras.layers.Dense(self.feedforward_dim, activation = 'relu'),
            keras.layers.Dense(self.decoder_features)
        ])
        self.Dropout = keras.layers.Dropout(rate = self.drop_prob)
        self.layernorm = keras.layers.LayerNormalization(epsilon = 1e-6)

    def __call__(self, encoder_features, decoder_features, attention_mask, padding_mask, training):
        self_attended = self.MAH1(decoder_features, decoder_features, attention_mask, training = training)
        cross_attended = self.MAH2(encoder_features, self_attended, padding_mask, training = training)

        # Feed-forward with dropout, residual connection and layer norm.
        ffn_out = self.FFN(cross_attended, training = training)
        ffn_out = self.Dropout(ffn_out, training = training)
        return self.layernorm(ffn_out + cross_attended, training = training)
class FullTransformerDecoder(keras.layers.Layer):
    """Stack of ``TransformerDecoder`` layers configured from ModelConfig / DataModule.

    NOTE(review): this class overrides ``__call__`` rather than ``call``.
    """

    def __init__(self):
        super().__init__()
        # Extract Settings from Config
        self.encoder_dim = ModelConfig.IMG_EMB_DIM[-1]
        self.decoder_dim = DataModule.DECODER_DIM
        self.feedforward_dim = DataModule.FEEDFORWARD_DIM
        self.drop_prob = ModelConfig.transformer_drop
        self.num_layers = ModelConfig.num_decoder_layers
        self.num_heads = ModelConfig.num_att_heads

        self.DecoderLayer = [
            TransformerDecoder(
                self.encoder_dim,
                self.decoder_dim,
                self.num_heads,
                self.feedforward_dim,
                drop_prob = self.drop_prob,
            )
            for _ in range(self.num_layers)
        ]

    def __call__(self, encoder_features, decoder_features, attention_mask, padding_mask, training):
        # Thread the decoder state through every layer in order.
        hidden = decoder_features
        for layer in self.DecoderLayer:
            hidden = layer(encoder_features, hidden, attention_mask, padding_mask, training = training)
        return hidden
class Decoder(keras.Model):
    """Full decoder: token embedding + fixed positional table, decoder stack, vocabulary projection."""

    def positional_embeddings(self, max_len, dim):
        """Build the fixed (1, max_len, dim) positional table.

        Even channels hold ``sin(pos / 10000 ** (channel / dim))`` and odd channels
        hold ``cos(pos / 10000 ** (channel / dim))``; the result is cast to
        ``DataModule.TARGET_DTYPE``.
        """
        table = np.zeros((1, max_len, dim), dtype = np.float32)
        for position in range(max_len):
            for channel in range(dim):
                angle = position / 10000 ** (channel / dim)
                if channel % 2 == 0:
                    table[:, position, channel] = math.sin(angle)
                else:
                    table[:, position, channel] = math.cos(angle)

        return tf.cast(tf.identity(table), DataModule.TARGET_DTYPE)

    def __init__(self):
        super().__init__()
        # Extract Features from Config
        self.vocab_len = DataModule.VOCAB_LEN
        self.decoder_dim = DataModule.DECODER_DIM
        self.final_drop = ModelConfig.final_drop
        self.max_len = DataModule.MAX_LEN

        self.embedding = keras.layers.Embedding(self.vocab_len, self.decoder_dim)
        self.decoder_transformer = FullTransformerDecoder()

        self.Dropout = keras.layers.Dropout(rate = self.final_drop)
        self.Linear = keras.layers.Dense(self.vocab_len)
        # Plain float32 layer that the final output is passed through.
        self.layer = keras.layers.Layer(dtype = tf.float32)

        self.pos_enc = self.positional_embeddings(self.max_len, self.decoder_dim)

    def mask_pad(self, input_tokens, dtype):
        """Return a (B, L, 1) mask that is 1 for real tokens and 0 for PAD tokens.

        Args:
            input_tokens: Tensor (B, L) of token ids.
            dtype: dtype of the returned mask.
        """
        is_pad = tf.expand_dims(tf.equal(input_tokens, DataModule.PAD_TOKEN), axis = -1)  # (B, L, 1)
        keep = 1 - tf.cast(is_pad, dtype = tf.uint8)
        return tf.cast(keep, dtype)

    def causal_attention_mask(self, batch_size, n_dest, n_src, dtype):
        """Masks the upper half of the dot product matrix in self attention.

        This prevents flow of information from future tokens to current token.
        1's in the lower triangle, counting from the lower right corner.
        """
        dest_idx = tf.range(n_dest)[:, None]
        src_idx = tf.range(n_src)
        visible = dest_idx >= src_idx - n_src + n_dest
        mask = tf.reshape(tf.cast(visible, dtype), [1, n_dest, n_src])
        # Tile the single mask across the batch dimension.
        tile_shape = tf.concat(
            [tf.expand_dims(batch_size, -1), tf.constant([1, 1], dtype=tf.int32)], 0
        )
        return tf.tile(mask, tile_shape)

    def call(self, encoder_features, decoder_token, training):
        """Predict vocabulary scores for every position of ``decoder_token``.

        Args:
            encoder_features: Tensor (B, Enc_L, C) produced by the encoder.
            decoder_token: Tensor (B, L) of token ids.
            training: Forwarded to every sub-layer.

        Returns:
            Output of the final float32 layer applied to the vocabulary projection.
        """
        batch, seq_len = decoder_token.shape
        _, enc_len, _ = encoder_features.shape

        causal = self.causal_attention_mask(batch, seq_len, seq_len, tf.uint8)  # (B, L, L)
        not_pad = self.mask_pad(decoder_token, tf.uint8)  # (B, L, 1)

        # Self-attention mask: causal mask combined with the padding mask.
        self_attn_mask = causal * not_pad
        token_emb = self.embedding(decoder_token, training = training)  # (B, L, C)
        # Stretch the padding mask across the encoder positions.
        cross_attn_mask = tf.repeat(not_pad, enc_len, axis = -1)
        # Stretch the positional table across the batch and trim it to L.
        positions = tf.repeat(self.pos_enc, batch, axis = 0)[:, :seq_len, :]
        token_emb = token_emb + positions

        # The attention layers receive boolean masks.
        self_attn_mask = tf.cast(self_attn_mask, tf.bool)
        cross_attn_mask = tf.cast(cross_attn_mask, tf.bool)

        decoded = self.decoder_transformer(
            encoder_features, token_emb, self_attn_mask, cross_attn_mask, training = training
        )  # (B, L, C)
        dropped = self.Dropout(decoded, training = training)
        return self.layer(self.Linear(dropped, training = training))


Pretrained Config

In [None]:
class PretrainedConfig:
    """Paths of the pretrained weight files; entry ``i`` of each list belongs to ensemble member ``i``."""
    encoder_checkpoints = [
        '../input/cptrmodel/save/enc_lsd.h5',
    ]
    decoder_checkpoints = [
        '../input/cptrmodel/save/dec_lsd.h5',
    ]

In [None]:
def prepare_for_training(verbose=0):
    """Build every (encoder, decoder) pair of the ensemble and load its weights.

    All models are created inside ``strategy.scope()``. Each model is called
    once on dummy inputs before ``load_weights`` is invoked on it.

    Args:
        verbose (int, optional): Accepted but not used.

    Returns:
        list: ``(encoder, decoder)`` tuples, one per checkpoint pair listed in
        ``PretrainedConfig``. When no checkpoints are listed, a single pair
        without loaded weights is returned.
    """
    # Everything must be declared within the scope when leveraging the TPU strategy
    #     - This will still function properly if scope is set to another type of accelerator
    with strategy.scope():
        encoder_paths = PretrainedConfig.encoder_checkpoints
        decoder_paths = PretrainedConfig.decoder_checkpoints
        assert len(encoder_paths) == len(decoder_paths)

        n_models = len(encoder_paths) or 1
        dummy_image_shape = (DataModule.REPLICA_BATCH_SIZE,) + DataModule.IMG_SHAPE
        dummy_token_shape = (DataModule.REPLICA_BATCH_SIZE, DataModule.MAX_LEN)

        models = []
        for _ in range(n_models):
            # Instantiate the encoder and run it once on an all-ones batch.
            print("\t--> CREATING ENCODER MODEL ARCHITECTURE ...")
            encoder = Encoder()
            encoded_dummy = encoder(
                tf.ones(dummy_image_shape, dtype=DataModule.TARGET_DTYPE),
                training=False,
            )

            # Instantiate the decoder and run it once on random token ids.
            print("\t--> CREATING DECODER MODEL ARCHITECTURE...")
            decoder = Decoder()
            dummy_tokens = np.random.randint(0, DataModule.VOCAB_LEN, dummy_token_shape, dtype = np.uint8)
            decoder(encoded_dummy, tf.identity(dummy_tokens), training=False)

            models.append((encoder, decoder))

        # Second pass: restore the weights of every pair that has checkpoints.
        for (encoder, decoder), encoder_path, decoder_path in zip(models, encoder_paths, decoder_paths):
            encoder.load_weights(encoder_path)
            decoder.load_weights(decoder_path)

    return models
    
    
print("\n... GENERATING THE FOLLOWING:")
# Build the ensemble of (encoder, decoder) pairs under the distribution
# strategy scope and load their pretrained weights; `models` is the list that
# test_step() reads from module scope.
models = prepare_for_training(verbose=0)

print("\n... TRAINING PREPERATION FINISHED ...\n")

In [None]:
def batch_encoder(encoder, beam_size):
    """Repeat every encoder output ``beam_size`` times along the batch axis.

    Args:
        encoder: Iterable of encoder outputs, each (B, L, C) - one per ensemble member.
        beam_size: Number of consecutive copies made of each batch row.

    Returns:
        list: One (B * beam_size, L, C) tensor per input, in the same order.
    """
    return [tf.repeat(features, beam_size, axis = 0) for features in encoder]

def batch_decoder(decoder, beam_size):
    """Repeat each row of the (BK, T) token tensor ``beam_size`` times along axis 0."""
    return tf.repeat(decoder, beam_size, axis = 0)
def repeat_logits(logits, beam_size):
    """Repeat each row of the score tensor ``beam_size`` times along axis 0."""
    return tf.repeat(logits, beam_size, axis = 0)
def test_step(images):
    """Beam-search decode one batch of images with the model ensemble.

    Reads the module-level ``models`` list of (encoder, decoder) pairs. Every
    image is encoded by each encoder; at each step the decoders' last-position
    outputs are averaged and soft-maxed, and the BEAM_SIZE highest cumulative
    scores per image are kept.

    NOTE(review): the running scores are sums of softmax probabilities (not
    log-probabilities), and the final selection keeps the LOWEST
    length-normalised score per image while top_k keeps the highest sums -
    confirm this is intended.

    Args:
        images: Image batch of shape (B, H, W, C).

    Returns:
        list: B uint8 tensors of length MAX_LEN, one predicted token sequence per image.
    """
    # BATCHED BEAM SEARCH PREDICTIONS(semi Batched Beam Search) --------TRANSFORMER ENSEMBLE.
    
    B, H, W, C = images.shape
    def predict(models, encoded_images, predictions_seq_batch, training = False):
        # Ensemble prediction for the last position: average the decoder outputs, then softmax.
        pred = []
        
        for decoder_idx in range(len(models)):
            pred += [models[decoder_idx][1](encoded_images[decoder_idx], predictions_seq_batch, training = training)[:, -1]]
        # Average the predictions
        pred = tf.stack(pred, axis = -1)
        pred = tf.reduce_mean(pred, axis = -1)
        # Softmax
        pred = keras.activations.softmax(pred) # (B, C)
        return pred
    classes = DataModule.VOCAB_LEN
    BEAM_SIZE = 16 # No Free Will said that min Beam Size 16 is needed for any results.
    training = False
    MAX_LEN = DataModule.MAX_LEN
    # Initial work: encode the images once with every encoder of the ensemble
    # image_batch_embedding has shape --> (REPLICA_BATCH_SIZE, IMG_EMB_DIM)
    training = False
    encoded_images = []
    for encoder_idx in range(len(models)):
        model_pred = models[encoder_idx][0](images, training = training)
        _, L, C = model_pred.shape
        encoded_images += [model_pred]
    # For Simplicity: Let's assume a <START><END> never exists(it doesn't in this dataset)
    # Every sequence starts with token id 1, which is "InChI=1S/" in TOKEN_LIST.
    predictions_seq_batch = tf.ones((DataModule.REPLICA_BATCH_SIZE, 1), dtype=tf.uint8)
    
    preds = predict(models, encoded_images, predictions_seq_batch, training = training)
    # Top k: the BEAM_SIZE most probable first tokens seed the beams
    top_k = tf.math.top_k(preds, k = BEAM_SIZE)
    indices = tf.cast(tf.reshape(top_k.indices, (-1, 1)), tf.uint8)
    cur_logits = tf.reshape(top_k.values, (-1, 1))
    
    # From here on the batch axis holds BEAM_SIZE consecutive rows per image.
    encoded_images = batch_encoder(encoded_images, BEAM_SIZE)
    predictions_seq_batch = tf.ones((DataModule.REPLICA_BATCH_SIZE * BEAM_SIZE, 1), dtype=tf.uint8)
    predictions_seq_batch = tf.concat([predictions_seq_batch, indices], axis = -1)
    
    # Best finished sequence and its score per image; 999. is the "nothing found yet" sentinel.
    final_predictions = [tf.ones((MAX_LEN), dtype = tf.uint8) for i in range(B)]
    final_scores = [tf.ones((1), dtype = tf.float32) * 999. for i in range(B)]
    # 1 = beam still active, 0 = beam emitted <END> in the previous step.
    beams_completed = np.ones((B, BEAM_SIZE))
   

    def pad_tensor(value): 
        # Pads an output sentence to MAX_LEN with PAD_TOKEN
        L = value.shape[0]
        pad_len = MAX_LEN - L
        padded = tf.ones((pad_len), dtype = value.dtype) * tf.constant(DataModule.PAD_TOKEN, dtype = value.dtype)
    
        return tf.concat([value, padded], axis = 0)
    def batch_wise_index(values, indices):
        # Batch wise Indexing. Slow, but it works.
        # For every batch row b, collects values[b, idx] for each idx in indices[b];
        # the leading zero row is only a concat seed and is sliced off again.
        B, L, C = values.shape
        _, L_prime = indices.shape
        all_batches = tf.zeros((1, C), dtype = values.dtype)
        for b in range(B):
            all_lengths = tf.zeros((1, C), dtype = values.dtype)
            index = indices[b]
            
            for i in range(index.shape[0]):
                idx = index[i]
                all_lengths = tf.concat([all_lengths, tf.expand_dims(values[b, idx], axis = 0)], axis = 0)
            all_lengths = all_lengths[1:]
            all_batches = tf.concat([all_batches, all_lengths], axis = 0)
        return all_batches[1:]
        
    # Autoregressive decoding: each step feeds the beams' own previous predictions back in
    for IDX in range(1, MAX_LEN - 1):
        if IDX % 10 == 0:
            print(IDX)
       
        # Decoded logits should already be (BK, T)
        pred = predict(models, encoded_images, predictions_seq_batch, training = training) # (BK, C) 

        # Repeat the Tokens K times
        repeated_tokens = batch_decoder(predictions_seq_batch, BEAM_SIZE) # (BK^2, T)
        repeated_logits = repeat_logits(cur_logits, BEAM_SIZE) # (BK^2)
        # Take the Top K Predictions
        top_k_1 = tf.math.top_k(pred, k = BEAM_SIZE) # (BK, K)
        indices = tf.cast(top_k_1.indices, tf.uint8) # (BK, K) - top K Classes
        logits = top_k_1.values # (BK,K)
        # Mask the logits: candidate scores of beams that already finished are zeroed
        mask = tf.reshape(beams_completed, (-1, 1)) # (BK)
        logits = logits * tf.cast(mask, logits.dtype) # (BK, K)
        
        indices = tf.reshape(indices, (-1, 1)) # (BK^2, 1)
        logits = tf.reshape(logits, (-1, BEAM_SIZE)) # (BK, K)
        # Append the Indices and Add the logits to get all predictions
        # (cur_logits broadcasts over the K candidates of each beam)
        logits = logits + cur_logits
        repeated_logits = tf.reshape(logits, (B, -1)) # (B, K^2)
        # Append
        repeated_tokens = tf.concat([repeated_tokens, indices], axis = -1) # (BK^2, T + 1) 
        # Take the top K best Predictions
        top_k_2 = tf.math.top_k(repeated_logits, k = BEAM_SIZE) 
        extraction_indices = tf.reshape(top_k_2.indices, (B, BEAM_SIZE)) # (B, K)
        indices = tf.reshape(indices, (B, -1)) # (B, K ** 2)
        
        extracted_tokens = tf.gather(indices, extraction_indices, batch_dims = 1) # (B(K))
        extracted_tokens = tf.squeeze(tf.reshape(extracted_tokens, (-1, 1)))

        # Extract the top K best Sentences
        sentence_scores =tf.squeeze(tf.reshape(top_k_2.values, (-1, 1)))# (BK)
        repeated_tokens = tf.reshape(repeated_tokens, (B, BEAM_SIZE ** 2, -1))  #(B, K ** 2, L)
        extracted_sentences = batch_wise_index(repeated_tokens, extraction_indices)
   
        
        # Reshape to separate out batch size and K 
        extracted_tokens = tf.reshape(extracted_tokens, (B, BEAM_SIZE)) # (B, K)
        sentence_scores = tf.reshape(sentence_scores, (B, BEAM_SIZE)) # (B, K)
        extracted_sentences = tf.reshape(extracted_sentences, (B, BEAM_SIZE, -1)) # (B, K, T + 1)
        # Check for equality to <END>
        new_mask = np.ones((B, BEAM_SIZE))
        for b in range(B):
            # Zero seed rows for the concats below; sliced off again after the beam loop.
            all_completed_sentences = tf.zeros((1, len(extracted_sentences[0, 0])), dtype = extracted_sentences[0, 0].dtype)
            all_completed_scores = tf.zeros((1, 1), dtype = sentence_scores[0, 0].dtype)
            
            for beam_idx in range(BEAM_SIZE):
                
                if extracted_tokens[b, beam_idx] == tf.constant(DataModule.END_TOKEN): # (B, K) 
                    all_completed_sentences = tf.concat([all_completed_sentences, tf.expand_dims(extracted_sentences[b, beam_idx], axis = 0)], axis = 0)
                    all_completed_scores = tf.concat([all_completed_scores, tf.expand_dims(tf.expand_dims(sentence_scores[b, beam_idx], axis = 0), axis = 0)], axis = 0)
                    new_mask[b, beam_idx] = 0
                else:
                    pass
            # ISSUE WITH NESTED LOOPS, ONLY HAPPENS WHEN LOOPS ARE NESTED.        
            all_completed_sentences = all_completed_sentences[1:]
            all_completed_scores = all_completed_scores[1:]
            # Find Completed sentences 
            completed_sentences = all_completed_sentences 
            completed_scores = all_completed_scores # (N)
            
            # Record finished sentences: only those whose single <END> is the last token
            for i in range(len(completed_scores)):
            
                if tf.reduce_sum(tf.cast(tf.equal(completed_sentences[i], DataModule.END_TOKEN), tf.int32)) == 1 and tf.reduce_sum(tf.cast(tf.equal(completed_sentences[i, -1], DataModule.END_TOKEN), tf.int32)) == 1:
                    
                    padded_sent = pad_tensor(completed_sentences[i])
                    # Length-normalise the cumulative score
                    scores = completed_scores[i] / (len(completed_sentences[i]) - 1)
                    # NOTE(review): keeps the sentence with the LOWEST normalised score - confirm.
                    if scores < final_scores[b]:
                        final_scores[b] = scores
                        final_predictions[b] = padded_sent
                    
                else:
                    pass
            
        
        # Carry the surviving beams, their scores and the completion mask into the next step
        beams_completed = new_mask
        cur_logits = tf.reshape(sentence_scores, (-1, 1)) # (BK, 1)
        predictions_seq_batch = tf.reshape(extracted_sentences, (B * BEAM_SIZE, -1))
        
    # After the last step, offer every remaining beam as a candidate as well
    predictions_seq_batch = tf.reshape(predictions_seq_batch, (B, BEAM_SIZE, -1))
    cur_logits = tf.reshape(cur_logits, (B, BEAM_SIZE))
   
    for b in range(predictions_seq_batch.shape[0]):
        for i in range(len(predictions_seq_batch[b])):
            padded_sent = pad_tensor(predictions_seq_batch[b, i])
            scores = cur_logits[b, i] / (MAX_LEN - 1)
        
            if scores < final_scores[b]:
                final_scores[b] = scores
                final_predictions[b] = padded_sent
    
    return final_predictions

@tf.function
def dist_test_step(images, ids):
    """Run `test_step` through `strategy.run` and gather the results.

    Args:
        images: image batch handed to `test_step` as its only argument.
        ids: ids accompanying `images`; they are gathered but not passed
            to `test_step`.

    Returns:
        A `(predictions, pred_ids)` tuple, each produced by
        `strategy.gather(..., axis=0)`.
    """
    per_replica_out = strategy.run(test_step, args=(images,))
    return (
        strategy.gather(per_replica_out, axis=0),
        strategy.gather(ids, axis=0),
    )

In [None]:
def arr_2_inchi(arr):
    """Translate a sequence of token indices into an InChI string.

    Every index is looked up in `DataModule.TOKEN_LIST`. Decoding stops at
    the first "<END>" token, which is not part of the returned string.
    """
    pieces = []
    for idx in arr:
        tok = DataModule.TOKEN_LIST[idx]
        if tok == "<END>":
            break
        pieces.append(tok)
    return ''.join(pieces)

def decode(sentences):
    """Decode a batch of token-index rows into a list of strings.

    Args:
        sentences: 2-D array-like (its `.shape` is unpacked into two
            values); each row holds indices into `DataModule.TOKEN_LIST`.

    Returns:
        One string per row, truncated at (and excluding) the first
        '<END>' token of that row.
    """
    n_rows, _ = sentences.shape  # unpacking fails unless the input is 2-D
    decoded = []
    for row_idx in range(n_rows):
        row_tokens = []
        for token_id in sentences[row_idx]:
            tok = DataModule.TOKEN_LIST[token_id]
            if tok == '<END>':
                break
            row_tokens.append(tok)
        decoded.append(''.join(row_tokens))
    return decoded
def post_process(all_pred_arr, all_pred_ids):
    """Decode the accumulated predictions and write `submission.csv`.

    Row 0 of both inputs is skipped and the trailing
    `DataModule.REQUIRED_DATASET_PAD` rows are dropped before decoding.

    Args:
        all_pred_arr: tensor-like of token indices, one row per prediction.
            Must support `.shape`, slicing and `.numpy()`.
        all_pred_ids: tensor-like whose rows each hold one byte-string id
            (the first element of every row is `.decode()`d).

    Side effects:
        Writes `submission.csv` (columns `image_id` and `InChI`, sorted by
        `image_id`) to the current working directory. Returns nothing.
    """
    pad = int(DataModule.REQUIRED_DATASET_PAD)

    def _payload_rows(rows):
        # Use an explicit stop index: `rows[1:-pad]` collapses to the empty
        # slice `rows[1:0]` when pad == 0, which silently drops every row.
        # Clamping at 0 keeps the pad >= len(rows) case empty, as before.
        stop = max(int(rows.shape[0]) - pad, 0)
        return rows[1:stop].numpy()

    id_rows = _payload_rows(all_pred_ids)
    image_id = [x[0].decode() for x in tqdm(id_rows, total=len(id_rows))]

    pred_rows = _payload_rows(all_pred_arr)
    inchi = [arr_2_inchi(pred_arr) for pred_arr in tqdm(pred_rows, total=len(pred_rows))]

    # create prediction DF
    pred_df = pd.DataFrame({
        "image_id": image_id,
        "InChI": inchi,
    })
    pred_df = pred_df.sort_values(by="image_id").reset_index(drop=True)
    pred_df.to_csv("submission.csv", index=False)


In [None]:
# NOTE(review): the inference cell below calls load_dataset() again and
# rebinds test_ds, so this call looks redundant — confirm before removing.
test_ds = load_dataset()


In [None]:
# Inference Loop
# Runs dist_test_step over the distributed test set and collects the gathered
# predictions and ids for post_process().
test_ds = load_dataset()
test_ds = iter(strategy.experimental_distribute_dataset(test_ds))

# Each list starts with one placeholder row so the concatenated tensors keep
# the layout post_process() expects (it skips row 0).
pred_sentence_chunks = [tf.ones((1, DataModule.MAX_LEN), dtype = tf.uint8)]
pred_id_chunks = [tf.zeros((1, 1), dtype = tf.string)]
count = 0
cur_val = 10000  # next progress milestone to print
for images, ids in test_ds:
    sentences, ids = dist_test_step(images, ids)
    # Collect per-step results and concatenate once after the loop; calling
    # tf.concat on the running tensor every step re-copies all earlier rows.
    pred_sentence_chunks.append(sentences)
    pred_id_chunks.append(tf.expand_dims(ids, axis = -1))
    count += DataModule.OVERALL_BATCH_SIZE
    if count > cur_val:
        print(cur_val)
        cur_val += 10000

all_pred_sentences = tf.concat(pred_sentence_chunks, axis = 0)
all_pred_ids = tf.concat(pred_id_chunks, axis = 0)

In [None]:
# Decode the accumulated predictions and write submission.csv.
post_process(all_pred_sentences, all_pred_ids)

# Normalize Predictions (Post-Processing)

In [None]:
'''
%%writefile normalize_inchis.py

from tqdm import tqdm
from rdkit import Chem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')
from pathlib import Path

def normalize_inchi(inchi):
    # Round-trip the string through RDKit; on any failure return it unchanged.
    try:
        mol = Chem.MolFromInchi(inchi)
    except Exception:
        # Bind mol explicitly: otherwise it is undefined when parsing raises
        # and the check below fails instead of falling back to the input.
        mol = None
    if mol is None:
        return inchi
    try:
        return Chem.MolToInchi(mol)
    except Exception:
        return inchi
        
submission_name = '../input/bmsmt-ds-model/submission_3.06.csv'
norm_path = Path('submission_norm.csv')

# Do the job
N = norm_path.read_text().count('\n') if norm_path.exists() else 0
print(N, 'number of predictions already normalized')

r = open(submission_name, 'r')
write_mode = 'w' if N == 0 else 'a'
w = open(str(norm_path), write_mode, buffering=1)

for _ in range(N):
    r.readline()
line = r.readline()  # this line is the header or is where it died last time
w.write(line)

pbar = tqdm()
while True:
    line = r.readline()
    if not line:
        break  # done
    image_id = line.split(',')[0]
    inchi = ','.join(line[:-1].split(',')[1:]).replace('"','')
    inchi_norm = normalize_inchi(inchi)
    w.write(f'{image_id},"{inchi_norm}"\n')
    pbar.update(1)

r.close()
w.close()
'''
# NOTE(review): this cell is disabled — the script above sits inside a string
# literal, so `%%writefile` does not run and normalize_inchis.py is not written.
# Remove the surrounding triple quotes (and this note) to re-enable it.

In [None]:
# !while [ 1 ]; do python normalize_inchis.py && break; done

# Post-Process submission.csv Files

# Note: about 1 hour of startup time, but only about 20 minutes of actual inference.