## Module

In [1]:
import cv2
import os
import gc
import time
import requests
import datetime
import numpy as np
import pandas as pd
import os.path as pth
from tqdm.auto import tqdm
tqdm.pandas()
import matplotlib.pyplot as plt

import rdkit
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit import DataStructs
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')  

from IPython.display import clear_output

from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow.python.client import device_lib
from tensorflow.keras.preprocessing import image

import tensorflow.keras as keras
import keras.backend as K
from keras.models import Model, Input, load_model
from keras.layers import Conv2D, Dense, MaxPooling2D, AveragePooling2D, GlobalAveragePooling2D
from keras.layers import Activation, BatchNormalization
from keras.layers import Concatenate
from keras.utils import to_categorical
from keras.callbacks import Callback
from keras.optimizers import SGD

from multiprocessing import Pool
from functools import partial

## Device

In [2]:
print(tf.__version__)
print(device_lib.list_local_devices())

2.4.1
[name: "/device:CPU:0"
device_type: "CPU"
memory_limit: 268435456
locality {
}
incarnation: 6925132695339830503
, name: "/device:GPU:0"
device_type: "GPU"
memory_limit: 10563764864
locality {
  bus_id: 1
  links {
  }
}
incarnation: 8933212339793434640
physical_device_desc: "device: 0, name: GeForce RTX 2080 Ti, pci bus id: 0000:3f:00.0, compute capability: 7.5"
, name: "/device:GPU:1"
device_type: "GPU"
memory_limit: 10563764864
locality {
  bus_id: 1
  links {
  }
}
incarnation: 15613777631761335604
physical_device_desc: "device: 1, name: GeForce RTX 2080 Ti, pci bus id: 0000:40:00.0, compute capability: 7.5"
, name: "/device:GPU:2"
device_type: "GPU"
memory_limit: 10563764864
locality {
  bus_id: 1
  links {
  }
}
incarnation: 517385157965408797
physical_device_desc: "device: 2, name: GeForce RTX 2080 Ti, pci bus id: 0000:41:00.0, compute capability: 7.5"
, name: "/device:GPU:3"
device_type: "GPU"
memory_limit: 10563764864
locality {
  bus_id: 1
  links {
  }
}
incarnation: 95

### Multi Device

In [3]:
#tf.compat.v1.disable_eager_execution()
mirrored_strategy = tf.distribute.MirroredStrategy()
#mirrored_strategy = tf.distribute.MirroredStrategy(devices = ["/gpu:4", "/gpu:5", "/gpu:6", "/gpu:7"])

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1', '/job:localhost/replica:0/task:0/device:GPU:2', '/job:localhost/replica:0/task:0/device:GPU:3', '/job:localhost/replica:0/task:0/device:GPU:4', '/job:localhost/replica:0/task:0/device:GPU:5', '/job:localhost/replica:0/task:0/device:GPU:6', '/job:localhost/replica:0/task:0/device:GPU:7')


## Constant

In [46]:
PATH = '../molecular_data/'
TRAIN_DIR = PATH + 'train'
TEST_DIR = PATH + 'test'

OUTPUT_DIR = './'
if not os.path.exists(OUTPUT_DIR):
    os.makedirs(OUTPUT_DIR)

class CFG : 
    debug = False
    size = 300
    data_ut = 2424186 #500000
    seed = 42
    batch_size = 256
    buffer_size = 100
    learning_rate = 1e-4
    base_channel = 8
    model_encoder_name = 'CustomDenseNet-121'
    model_base_path = pth.join('model', 'checkpoint')
    
    epochs = 30
    n_mht = 512
    n_layer = 4
    n_dff = 1024
    n_head = 8
    dropout = 0.1 # 0
    max_length = 100

## Label Load

In [5]:
train = pd.read_csv(PATH + 'train_labels.csv')
train.head()

Unnamed: 0,image_id,InChI
0,000011a64c74,InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12...
1,000019cc0cd2,InChI=1S/C21H30O4/c1-12(22)25-14-6-8-20(2)13(1...
2,0000252b6d2b,InChI=1S/C24H23N5O4/c1-14-13-15(7-8-17(14)28-1...
3,000026b49b7e,InChI=1S/C17H24N2O4S/c1-12(20)18-13(14-7-6-10-...
4,000026fc6c36,InChI=1S/C10H19N3O2S/c1-15-10(14)12-8-4-6-13(7...


In [6]:
def get_path(img_name) : 
    return f"{TRAIN_DIR}/{img_name[0]}/{img_name[1]}/{img_name[2]}/{img_name}.png"

train['path'] = train['image_id'].apply(get_path)

In [7]:
train.head()

Unnamed: 0,image_id,InChI,path
0,000011a64c74,InChI=1S/C13H20OS/c1-9(2)8-15-13-6-5-10(3)7-12...,../molecular_data/train/0/0/0/000011a64c74.png
1,000019cc0cd2,InChI=1S/C21H30O4/c1-12(22)25-14-6-8-20(2)13(1...,../molecular_data/train/0/0/0/000019cc0cd2.png
2,0000252b6d2b,InChI=1S/C24H23N5O4/c1-14-13-15(7-8-17(14)28-1...,../molecular_data/train/0/0/0/0000252b6d2b.png
3,000026b49b7e,InChI=1S/C17H24N2O4S/c1-12(20)18-13(14-7-6-10-...,../molecular_data/train/0/0/0/000026b49b7e.png
4,000026fc6c36,InChI=1S/C10H19N3O2S/c1-15-10(14)12-8-4-6-13(7...,../molecular_data/train/0/0/0/000026fc6c36.png


In [8]:
len(train['InChI'].unique())

2424186

### Vectorize

In [9]:
all_captions = []
all_img_name_vector = []

for idx in tqdm(train.index) : 
    caption = train['InChI'].iloc[idx]
    img_name = train['path'].iloc[idx]
    all_captions.append(caption)
    all_img_name_vector.append(img_name)

  0%|          | 0/2424186 [00:00<?, ?it/s]

### Shuffle and Cut

In [10]:
train_captions, train_img_name_vector = shuffle(all_captions, all_img_name_vector, random_state = CFG.seed)
train_captions = train_captions[:CFG.data_ut]
train_img_name_vector = train_img_name_vector[:CFG.data_ut]
print(len(train_captions))

2424186


In [44]:
img_name_train, img_name_val, caption_train, caption_val = train_test_split(train_img_name_vector, train_captions, test_size = 0.2, random_state = CFG.seed)
len(img_name_train), len(img_name_val), len(caption_train), len(caption_val)

(1939348, 484838, 1939348, 484838)

In [11]:
is_sgd = False

## Model setting

### Encoder

In [21]:
encoder_model_base = CFG.model_encoder_name
encoder_model_name = 'Autoencoder_{}_trts_basech_{:03d}'.format(encoder_model_base, CFG.base_channel)

In [22]:
encoder_model_base_path = CFG.model_base_path
encoder_model_path = pth.join(encoder_model_base_path, encoder_model_name)

In [23]:
os.makedirs(encoder_model_path, exist_ok = True)
target_checkpoint_filename = sorted(os.listdir(encoder_model_path))[-1]
encoder_model_filename = pth.join(encoder_model_path, target_checkpoint_filename)

In [24]:
encoder_model_filename

'model/checkpoint/Autoencoder_CustomDenseNet-121_trts_basech_008/000015-0.001644-0.001633.hdf5'

In [26]:
decoder_model_name = 'trfrm_mht_{}_layer_{}_dff_{}_head_{}_DO_{}'.format(
        CFG.n_mht, CFG.n_layer, CFG.n_dff, CFG.n_head, CFG.dropout
    )

In [28]:
model_name = 'enc-tr_{}_dec-tr_{}_len-100-all-pseudolabel'.format(encoder_model_name, decoder_model_name)

### Max length

In [33]:
def calc_max_length(tensor):
    return max(len(t) for t in tensor)
    
# max_length = calc_max_length(all_captions)

In [35]:
tokenizer = tf.keras.preprocessing.text.Tokenizer(lower=False, char_level=True)
# temp_captions = all_origin_captions + [" ^#%()+-.0123456789=@ABCDEFGHIKLMNOPRSTVXYZ[\\]abcdefgilmnoprstuy$"]
# tokenizer.fit_on_texts(temp_captions)
all_token_list = [
    'c', 'C', '(', ')', '1', 'O', '=', '2', 'N', '<', '>', 'n', '[',
    ']', '3', '@', 'H', 'l', 'S', '-', 'F', '+', '4', 's', 'o', '#',
    'B', 'r', '.', '/', 'P', 'i', 'I', '5', '\\', 'e', 'A', 'a', 'g',
    '6', 'u', 't', 'T', 'M', 'b', 'K', 'Z', '8', 'd', '9', 'R', 'G',
    '7', 'L', 'V', 'h', 'W', 'p', 'm', 'E', 'Y', '0', 'U', 'f', 'D',
    'y', 'k', 'X', ' ', '^', '%', '$'
]
tokenizer.fit_on_texts(all_token_list)
top_k = len(tokenizer.word_index)
train_seqs = tokenizer.texts_to_sequences(train_captions)
cap_vector = tf.keras.preprocessing.sequence.pad_sequences(
    train_seqs, maxlen = CFG.max_length, padding = 'post'
)

In [38]:
train_captions[0]

'InChI=1S/C21H22N6O4/c1-10-19-14(27-21(22)24-10)7-13(26-20(19)28)11-5-16(29-2)17(30-3)6-12(11)15-8-23-9-18(25-15)31-4/h5-6,8-9,13H,7H2,1-4H3,(H,26,28)(H2,22,24,27)/t13-/m1/s1'

### Hyper Parameter

In [51]:
BATCH_SIZE = CFG.batch_size
BUFFER_SIZE = CFG.buffer_size
d_model = CFG.n_mht
num_layers = CFG.n_layer
dff = CFG.n_dff
num_heads = CFG.n_head
vocab_size = top_k # + 1
dropout_rate = CFG.dropout
EPOCHS = CFG.epochs
learning_rate = CFG.learning_rate

In [52]:
train_num_steps = int(np.ceil(len(img_name_train) / BATCH_SIZE))
val_num_steps = int(np.ceil(len(img_name_val) / BATCH_SIZE))

## Dataset

In [56]:
def map_func(img_path, caption) : 
    img = tf.io.read_file(img_path)
    img = tf.image.decode_png(img, channels = 1)
    img = tf.dtypes.cast(img, tf.float32)
    img = img / 255.0
    img = tf.image.resize(img, (CFG.size, CFG.size))
    return img, caption

def prep_func(img, caption) : 
    result_img = tf.keras.applications.inception_v3.preprocess_input(img)
    return result_img, caption

In [54]:
dataset_train = tf.data.Dataset.from_tensor_slices((img_name_train, caption_train))
dataset_train = dataset_train.map(map_func, num_parallel_calls = tf.data.experimental.AUTOTUNE)
dataset_train = dataset_train.shuffle(CFG.buffer_size).batch(CFG.batch_size)
dataset_train = dataset_train.prefetch(buffer_size = tf.data.experimental.AUTOTUNE)

In [55]:
dataset_val = tf.data.Dataset.from_tensor_slices((img_name_val, caption_val))
dataset_val = dataset_val.map(map_func, num_parallel_calls = tf.data.experimental.AUTOTUNE)
dataset_val = dataset_val.batch(CFG.batch_size)
dataset_val = dataset_val.prefetch(buffer_size = tf.data.experimental.AUTOTUNE)

## Decoder

### Helper function

In [61]:
def get_angles(pos, i, d_model):
    angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
    return pos * angle_rates


def positional_encoding(position, d_model):
    angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)
    # apply sin to even indices in the array; 2i
    angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])
    # apply cos to odd indices in the array; 2i+1
    angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])
    pos_encoding = angle_rads[np.newaxis, ...]

    return tf.cast(pos_encoding, dtype=tf.float32)

@tf.function
def create_padding_mask(seq):
    seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

    # add extra dimensions to add the padding
    # to the attention logits.
    return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)

@tf.function
def create_look_ahead_mask(size):
    mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
    return mask  # (seq_len, seq_len)


def scaled_dot_product_attention(q, k, v, mask):
    matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

    # scale matmul_qk
    dk = tf.cast(tf.shape(k)[-1], tf.float32)
    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

    # add the mask to the scaled tensor.
    if mask is not None:
        scaled_attention_logits += (mask * -1e9)  

    # softmax is normalized on the last axis (seq_len_k) so that the scores
    # add up to 1.
    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

    output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

    return output, attention_weights

def point_wise_feed_forward_network(d_model, dff):
    return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
    ])

@tf.function
def create_masks(tar):
    look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
    dec_target_padding_mask = create_padding_mask(tar)
    combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

    return combined_mask

### Attention

In [62]:
class MultiHeadAttention(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = num_heads
        self.d_model = d_model

        assert d_model % self.num_heads == 0

        self.depth = d_model // self.num_heads

        self.wq = tf.keras.layers.Dense(d_model)
        self.wk = tf.keras.layers.Dense(d_model)
        self.wv = tf.keras.layers.Dense(d_model)

        self.dense = tf.keras.layers.Dense(d_model)

    def split_heads(self, x, batch_size):
        """Split the last dimension into (num_heads, depth).
        Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
        """
        x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
        return tf.transpose(x, perm=[0, 2, 1, 3])

    def call(self, v, k, q, mask):
        batch_size = tf.shape(q)[0]

        q = self.wq(q)  # (batch_size, seq_len, d_model)
        k = self.wk(k)  # (batch_size, seq_len, d_model)
        v = self.wv(v)  # (batch_size, seq_len, d_model)

        q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
        k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
        v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

        # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
        # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
        scaled_attention, attention_weights = scaled_dot_product_attention(
            q, k, v, mask)

        scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

        concat_attention = tf.reshape(scaled_attention, 
                                      (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

        output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

        return output, attention_weights

### Decoder

In [65]:
class DecoderLayer(tf.keras.layers.Layer):
    def __init__(self, d_model, num_heads, dff, rate=0.1):
        super(DecoderLayer, self).__init__()

        self.mha1 = MultiHeadAttention(d_model, num_heads)
        self.mha2 = MultiHeadAttention(d_model, num_heads)

        self.ffn = point_wise_feed_forward_network(d_model, dff)

        self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
        self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)

        self.dropout1 = tf.keras.layers.Dropout(rate)
        self.dropout2 = tf.keras.layers.Dropout(rate)
        self.dropout3 = tf.keras.layers.Dropout(rate)


    def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
        # enc_output.shape == (batch_size, input_seq_len, d_model)

        attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
        attn1 = self.dropout1(attn1, training=training)
        out1 = self.layernorm1(attn1 + x)

        attn2, attn_weights_block2 = self.mha2(
            enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
        attn2 = self.dropout2(attn2, training=training)
        out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

        ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
        ffn_output = self.dropout3(ffn_output, training=training)
        out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

        return out3, attn_weights_block1, attn_weights_block2

class Decoder(tf.keras.layers.Layer):
    def __init__(self, num_layers, d_model, num_heads, dff, target_vocab_size,
               maximum_position_encoding, rate=0.1):
        super(Decoder, self).__init__()

        self.d_model = d_model
        self.num_layers = num_layers

        self.embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)
        self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

        self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                           for _ in range(num_layers)]
        self.dropout = tf.keras.layers.Dropout(rate)

    def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
        seq_len = tf.shape(x)[1]
        attention_weights = {}

        x = self.embedding(x)  # (batch_size, target_seq_len, d_model)
        x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
        x += self.pos_encoding[:, :seq_len, :]

        x = self.dropout(x, training=training)

        for i in range(self.num_layers):
            x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                                 look_ahead_mask, padding_mask)

            attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
            attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

        # x.shape == (batch_size, target_seq_len, d_model)
        return x, attention_weights


class CNN_Encoder(tf.keras.Model):
    def __init__(self, embedding_dim):
        super(CNN_Encoder, self).__init__()
        target_checkpoint_filename = sorted(os.listdir(encoder_model_path))[-1]
        image_autoencoder = load_model(pth.join(encoder_model_path, target_checkpoint_filename))
        image_features_extract_model = image_autoencoder.get_layer(encoder_model_base)
#         image_features_extract_model.trainable = False
        self.feature_extract_model = image_features_extract_model
        self.fc = tf.keras.layers.Dense(embedding_dim, activation='relu')
        
    def call(self, x):
        x = self.feature_extract_model(x)
        x = tf.keras.layers.Reshape((-1, x.shape[3]))(x)
        x = self.fc(x)
        return x


class ImageCaptioningTransformer(tf.keras.Model):
    def __init__(self, num_layers, d_model, num_heads, dff,
               target_vocab_size, pe_target, rate=0.1):
        super(ImageCaptioningTransformer, self).__init__()

        self.encoder = CNN_Encoder(d_model)

        self.decoder = Decoder(num_layers, d_model, num_heads, dff, 
                               target_vocab_size, pe_target, rate)

        self.final_layer = tf.keras.layers.Dense(target_vocab_size)

    def call(self, inp, tar, training, look_ahead_mask, dec_padding_mask):
        enc_output = self.encoder(inp)

        # dec_output.shape == (batch_size, tar_seq_len, d_model)
        dec_output, attention_weights = self.decoder(
            tar, enc_output, training, look_ahead_mask, dec_padding_mask)

        final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

        return final_output, attention_weights

## Train

### Model Creation

In [66]:
captioning_transformer = ImageCaptioningTransformer(
    num_layers, d_model, num_heads, dff,
    vocab_size, pe_target=100,
    rate=dropout_rate
)

### Helper Function

In [70]:
class CustomSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, d_model, warmup_steps=5):
        super(CustomSchedule, self).__init__()

        self.d_model = d_model
        self.d_model = tf.cast(self.d_model, tf.float32)

        self.warmup_steps = warmup_steps

    def __call__(self, step):
        arg1 = tf.math.rsqrt(step)
        arg2 = step * (self.warmup_steps ** -1.5)

        return tf.math.rsqrt(self.d_model) * tf.math.minimum(arg1, arg2)

In [72]:
learning_rate = 1e-4

if is_sgd == True:
    optimizer = tf.keras.optimizers.SGD(lr=learning_rate)
else:
    optimizer = tf.keras.optimizers.Adam(lr=learning_rate)

# learning_rate = CustomSchedule(d_model)
# optimizer = tf.keras.optimizers.Adam(
#     learning_rate, beta_1=0.9, beta_2=0.98, 
#     epsilon=1e-9
# )

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none'
)

@tf.function
def loss_function(real, pred):
    mask = tf.math.logical_not(tf.math.equal(real, 0))
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    loss_ *= mask

    return tf.reduce_sum(loss_)/tf.reduce_sum(mask)

### Checkpoint

In [74]:
checkpoint_path = pth.join('model', 'decoder_checkpoint', model_name)
os.makedirs(checkpoint_path, exist_ok=True)
ckpt = tf.train.Checkpoint(
    captioning_transformer = captioning_transformer, 
    optimizer = optimizer
)
ckpt_manager = tf.train.CheckpointManager(ckpt, checkpoint_path, max_to_keep=25)

In [75]:
start_epoch = 0
if ckpt_manager.latest_checkpoint:
    start_epoch = int(ckpt_manager.latest_checkpoint.split('-')[-1])
    ckpt.restore(ckpt_manager.latest_checkpoint)

## Train

In [76]:
def calculate_similarity(real, pred):
#     pred = np.array(list(map(np.array, pred)))
#     pred = np.moveaxis(pred, (0,1,2), (1,0,2))
    pred = np.argmax(pred, axis=-1)
#     print(real[:5], pred[:5])
    real = real.numpy()
    
    score_list = []
    for score_i, (each_pred, each_real) in enumerate(zip(pred, real)): 
        each_pred = ''.join([tokenizer.index_word.get(mol_i, '') for mol_i in each_pred])
        each_pred = each_pred.split('>')[0]
        m_pred = Chem.MolFromInchi(each_pred)
        if m_pred == None:
            score_list.append(0)
            continue
        each_real = ''.join([tokenizer.index_word.get(mol_i, '') for mol_i in each_real])
        each_real = each_real[1:-1]
        m_real = Chem.MolFromInchi(each_real)
        
        fp_pred = Chem.RDKFingerprint(m_pred)
        fp_real = Chem.RDKFingerprint(m_real)
        target_similarity = DataStructs.FingerprintSimilarity(fp_real,fp_pred)
        score_list.append(target_similarity)
        
    return score_list

In [81]:
# @tf.function(input_signature=train_step_signature)
@tf.function
def train_step(img_tensor, target, training=True):
    target_inp = target[:, :-1]
    target_real = target[:, 1:]
    
    combined_mask = create_masks(target_inp)
    
    with tf.GradientTape() as tape:
        predictions, _ = captioning_transformer(
            inp=img_tensor, tar=target_inp, training=training, 
            look_ahead_mask=combined_mask, dec_padding_mask=None
#             look_ahead_mask=None, dec_padding_mask=None
        )
        loss = loss_function(target_real, predictions)
#         total_loss = (loss / int(target_inp.shape[1]))
        if training == True:
            gradients = tape.gradient(loss, captioning_transformer.trainable_variables)    
            optimizer.apply_gradients(zip(gradients, captioning_transformer.trainable_variables))

    train_accuracy(target_real, predictions)
    return loss, predictions

In [82]:
loss_plot, val_loss_plot = [], []
sim_plot, val_sim_plot = [], []
lowest_val_loss = 1e12
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(
    name='train_accuracy')
gc.collect()

143

In [83]:
for epoch in range(start_epoch, EPOCHS):
    total_loss, total_val_loss, total_test_pred_26_loss = 0, 0, 0
    train_accuracy.reset_states()

    tqdm_dataset = tqdm(enumerate(dataset_train), total=train_num_steps, position=0, leave=True)
    total_sim = 0
    for (batch, (img_tensor, target)) in tqdm_dataset:
        valid_cap_mask = (target[:,0] == 10)
        img_tensor = img_tensor[valid_cap_mask]
        target = target[valid_cap_mask]

        batch_loss, pred_list = train_step(img_tensor, target, training=True)
        smilarlity_list = calculate_similarity(target, pred_list)
        smilarlity = np.mean(smilarlity_list)
        total_sim += smilarlity
        total_loss += batch_loss
        if batch % 50 == 0:
            tqdm_dataset.set_postfix({
                'Epoch': epoch + 1,
                'Batch': batch,
                'Loss': '{:06f}'.format(batch_loss.numpy() / int(target.shape[1])),
                'Similarlity': smilarlity,
                'Accuracy':train_accuracy.result().numpy(), 
            })
        if batch % 30 == 0:
            gc.collect()
    loss_plot.append(total_loss / (batch+1))
    sim_plot.append(total_sim / (batch+1))

    tqdm_dataset_val = tqdm(enumerate(dataset_val), total=val_num_steps, position=0, leave=True)
    total_val_sim = 0
    for (batch, (img_tensor, target)) in tqdm_dataset_val:
        valid_cap_mask = (target[:,0] == 10)
        img_tensor = img_tensor[valid_cap_mask]
        target = target[valid_cap_mask]

        batch_val_loss, pred_list = train_step(img_tensor, target, training=False)
        smilarlity_list = calculate_similarity(target, pred_list)
        smilarlity = np.mean(smilarlity_list)
        total_val_sim += smilarlity
        total_val_loss += batch_val_loss
        if batch % 50 == 0:
            tqdm_dataset_val.set_postfix({
                'Epoch': epoch + 1,
                'Batch': batch,
                'Val Loss': '{:06f}'.format(batch_val_loss.numpy() / int(target.shape[1])),
                'Var Similarlity': smilarlity,
                'Var Accuracy':train_accuracy.result().numpy(), 
            })
        if batch % 30 == 0:
            gc.collect()
    val_loss_plot.append(total_val_loss / (batch+1))
    val_sim_plot.append(total_val_sim / (batch+1))

    ckpt_manager.save()

    output.clear()

    plt.figure()
    plt.plot(loss_plot, label='loss')
    plt.plot(val_loss_plot, label='val_loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Loss Plot')
    plt.legend()
    plt.show()

    plt.figure()
    plt.plot(sim_plot, label='similarity')
    plt.plot(val_sim_plot, label='val_Similarity')
    plt.ylim(-0.1,1.1)
    plt.xlabel('Epochs')
    plt.ylabel('Similarity')
    plt.title('Similarity Plot')
    plt.legend()
    plt.show()

    print()
    # print ('Epoch {}, Loss {:.6f}, Similiarity {:.6f}'.format(
    #     epoch + 1, loss_plot[-1], sim_plot[-1]))    
    print ('Epoch {}, Loss {:.6f}, Val loss: {:.6f}, Similiarity {:.6f}, Val similiarity: {:.6f}'.format(
        epoch + 1, loss_plot[-1], val_loss_plot[-1], sim_plot[-1], val_sim_plot[-1]))

  0%|          | 0/7576 [00:00<?, ?it/s]

InvalidArgumentError: Index out of range using input dim 1; input has only 1 dims [Op:StridedSlice] name: strided_slice/