# Imports

In [None]:
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
import numpy as np
from bpemb import BPEmb
import time
import datetime
import os
from model.create_mask import *
from training.printing import *

# 1. Constants

## 1.1. Paths

In [None]:
# Root locations default to the original author's machine. Set the
# THESIS_DATASET_PATH / THESIS_MODEL_DIR environment variables to run the
# notebook elsewhere without editing this cell.
DATASET_PATH = os.environ.get(
    "THESIS_DATASET_PATH",
    "/run/media/ishrak/Ishrak/IUT/Thesis/dataset/tfrecords/",
)
TRAIN_DATASET_PATH = os.path.join(DATASET_PATH, "train")
TEST_DATASET_PATH = os.path.join(DATASET_PATH, "test")

MODEL_DIR = os.environ.get(
    "THESIS_MODEL_DIR",
    "/run/media/ishrak/Ishrak/IUT/Thesis/model_dir",
)
CKPT_PATH = os.path.join(MODEL_DIR, "ckpt")  # checkpoint directory
LOG_PATH = os.path.join(MODEL_DIR, "logs")  # parent of every log artefact below
SCORE_PATH = os.path.join(LOG_PATH, "score.csv")
LOG_TEXT_PATH = os.path.join(LOG_PATH, "log.txt")
TENSORBOARD_PATH = os.path.join(LOG_PATH, "tensorboard_logs")

## 1.2. Constant values

In [None]:
# --- Input pipeline -------------------------------------------------------
BUFFER_SIZE = 8192   # 2 ** 13; handed to the dataset builder as cache_buffer_size
BATCH_SIZE = 32
TEXT_LENGTH = 512    # length of the encoder input (article) in tokens
SUMMARY_LENGTH = 16  # length of the decoder output (title) in tokens

# --- Vocabulary / tokens --------------------------------------------------
VOCAB_SIZE = 10000   # BPEmb vocabulary size (passed as `vs`)
VOCAB_DIM = 100      # BPEmb embedding dimension (passed as `dim`)
ENCODER_VOCAB_SIZE = DECODER_VOCAB_SIZE = VOCAB_SIZE
START_TOKEN = 1      # seeds the decoder during inference
END_TOKEN = 2        # stops greedy decoding when predicted

# --- Transformer hyperparameters -----------------------------------------
NUM_LAYERS = 4
D_MODEL = 128
DFF = 512
NUM_HEADS = 8

# --- Training -------------------------------------------------------------
EPOCHS = 1000
CKPT_TO_KEEP = 50    # max_to_keep for the checkpoint manager

# 2. Loading Dataset
The dataset consists of articles scraped from the Prothom Alo news site. It contains the titles, contents, and tags of many articles.


In [None]:
def _list_files_sorted(directory):
    """Return the full path of every entry in `directory`, sorted by name.

    `os.listdir` returns entries in arbitrary order, so sorting keeps the
    file order (and therefore the record order fed to the dataset builder)
    the same from one run to the next.
    """
    return [
        os.path.join(directory, file_name)
        for file_name in sorted(os.listdir(directory))
    ]


train_tfrecord_files = _list_files_sorted(TRAIN_DATASET_PATH)
test_tfrecord_files = _list_files_sorted(TEST_DATASET_PATH)

In [None]:
from data_manipulation.create_tfrecord_dataset import create_tfrecord_dataset

In [None]:
# Training split: article token ids of length TEXT_LENGTH paired with title
# token ids of length SUMMARY_LENGTH, batched and prefetched.
train_dataset = create_tfrecord_dataset(
    tfrecord_files=train_tfrecord_files,
    input_feature_length=TEXT_LENGTH,
    output_feature_length=SUMMARY_LENGTH,
    batch_size=BATCH_SIZE,
    cache_buffer_size=BUFFER_SIZE,
    prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
)

In [None]:
# Held-out split, built with exactly the same settings as the training split.
test_dataset = create_tfrecord_dataset(
    tfrecord_files=test_tfrecord_files,
    input_feature_length=TEXT_LENGTH,
    output_feature_length=SUMMARY_LENGTH,
    batch_size=BATCH_SIZE,
    cache_buffer_size=BUFFER_SIZE,
    prefetch_buffer_size=tf.data.experimental.AUTOTUNE,
)

# 2. Model

## 2.1. Model Architecture

In [None]:
from model.transformer import Transformer
# Build the model from the hyperparameters defined in the constants section.
# The first six arguments are passed positionally, so their order must match
# the Transformer constructor's signature (defined in model/transformer.py).
transformer = Transformer(
    NUM_LAYERS, 
    D_MODEL, 
    NUM_HEADS, 
    DFF,
    ENCODER_VOCAB_SIZE, 
    DECODER_VOCAB_SIZE, 
    # NOTE(review): pe_input / pe_target are presumably the maximum positional-
    # encoding lengths; they are set to the vocabulary size here rather than to
    # TEXT_LENGTH / SUMMARY_LENGTH — confirm this is intended.
    pe_input=ENCODER_VOCAB_SIZE, 
    pe_target=DECODER_VOCAB_SIZE,
)

## 2.2. Adam optimizer
Uses the Adam optimizer with a custom learning-rate schedule.

In [None]:
from training.custom_scheduler import CustomSchedule

# The learning rate is a schedule object parameterised by the model width,
# not a fixed float.
learning_rate = CustomSchedule(D_MODEL)
optimizer = tf.keras.optimizers.Adam(
    learning_rate=learning_rate,
    beta_1=0.9,
    beta_2=0.98,
    epsilon=1e-9,
)

## 2.3. Checkpoints

In [None]:
# Track both the model weights and the optimizer state so that training can
# resume exactly where it stopped.
ckpt = tf.train.Checkpoint(transformer=transformer, optimizer=optimizer)
ckpt_manager = tf.train.CheckpointManager(
    ckpt,
    CKPT_PATH,
    max_to_keep=CKPT_TO_KEEP,
)

# Resume from the newest checkpoint in CKPT_PATH when one exists.
latest_checkpoint = ckpt_manager.latest_checkpoint
if latest_checkpoint:
    ckpt.restore(latest_checkpoint)
    print('Latest checkpoint restored!!')

# 3. Training

## 3.1. Defining losses and other metrics 

In [None]:
# Per-token cross-entropy on raw logits. reduction='none' keeps one loss value
# per target position so that callers can apply their own mask before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    reduction='none',
    from_logits=True,
)

In [None]:
def loss_function(real, pred):
    """Cross-entropy averaged over the non-padding target positions only.

    Args:
        real: integer target token ids; positions equal to 0 are excluded
            from the loss.
        pred: model output, passed to `loss_object` as logits.

    Returns:
        Scalar tensor: the per-token loss summed over positions where
        `real != 0`, divided by the number of such positions. Returns 0
        (instead of NaN) when the batch has no such position.
    """
    per_token_loss = loss_object(real, pred)

    # 1.0 where the target is a real token, 0.0 where it is padding.
    mask = tf.cast(tf.math.not_equal(real, 0), dtype=per_token_loss.dtype)
    per_token_loss *= mask

    # divide_no_nan guards against a batch made up entirely of padding, which
    # would otherwise produce 0/0 = NaN and poison the running mean.
    return tf.math.divide_no_nan(tf.reduce_sum(per_token_loss), tf.reduce_sum(mask))

In [None]:
# Running means of the per-batch loss; both are reset at the start of every epoch.
train_loss = tf.keras.metrics.Mean(name='train_loss')
test_loss = tf.keras.metrics.Mean(name='test_loss')  # was mislabelled 'train_loss'

### Training step

In [None]:
@tf.function
def train_step(inp, tar):
    """Run one optimisation step on a batch and record its loss in `train_loss`.

    Args:
        inp: batch of encoder input token ids.
        tar: batch of target token ids; all but the last token are fed to
            the decoder and all but the first are the prediction targets.
    """
    # Teacher forcing: the decoder sees the target shifted right by one token.
    decoder_input, decoder_target = tar[:, :-1], tar[:, 1:]

    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, decoder_input)

    with tf.GradientTape() as tape:
        # Third positional argument True puts the model in training mode.
        predictions, _ = transformer(
            inp, decoder_input, True, enc_padding_mask, combined_mask, dec_padding_mask
        )
        loss = loss_function(decoder_target, predictions)

    variables = transformer.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    train_loss(loss)

In [None]:
@tf.function
def test_step(model, input, target):
    """Evaluate one batch without updating weights; record its loss in `test_loss`.

    Args:
        model: the transformer to evaluate; called with the same positional
            arguments as in `train_step`, but with the training flag False.
        input: batch of encoder input token ids.
        target: batch of target token ids; all but the last token are fed to
            the decoder and all but the first are the prediction targets.
    """
    tar_inp = target[:, :-1]
    tar_real = target[:, 1:]

    # Masks must be built from this call's own `input`, not from a global
    # that merely happens to hold the current batch.
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(input, tar_inp)
    predictions, _ = model(
        input,
        tar_inp,
        False,
        enc_padding_mask,
        combined_mask,
        dec_padding_mask
    )
    # Use the same padding-masked loss as training so the train and test
    # curves are directly comparable.
    loss = loss_function(tar_real, predictions)

    test_loss(loss)

## 3.2. Inference function
The decoder predicts one token at a time and appends it to the output. The whole sequence generated so far is then fed back into the decoder, and this repeats until `SUMMARY_LENGTH` tokens have been generated or the end token is predicted.

In [None]:
def evaluate(input_document):
    """Greedily decode a summary for a single raw-text document.

    Args:
        input_document: the article text as a plain string.

    Returns:
        A tuple `(token_ids, attention_weights)`: the generated token ids as
        a 1-D tensor beginning with START_TOKEN (the end token is never
        appended), and the attention weights from the last decoder call.
    """
    # Tokenise, then pad/truncate at the end to the fixed encoder length.
    token_ids = bpemb_bn.encode_ids(input_document)
    padded = pad_sequences([token_ids], maxlen=TEXT_LENGTH, padding='post', truncating='post')
    encoder_input = tf.expand_dims(padded[0], 0)

    # The decoder starts from a batch of one sequence holding only START_TOKEN.
    output = tf.expand_dims([START_TOKEN], 0)

    for _ in range(SUMMARY_LENGTH):
        masks = create_masks(encoder_input, output)
        enc_padding_mask, combined_mask, dec_padding_mask = masks

        predictions, attention_weights = transformer(
            encoder_input, output, False, enc_padding_mask, combined_mask, dec_padding_mask
        )

        # Only the newest position matters; pick its most likely token.
        last_position = predictions[:, -1:, :]
        predicted_id = tf.cast(tf.argmax(last_position, axis=-1), tf.int32)

        if predicted_id == END_TOKEN:
            break

        output = tf.concat([output, predicted_id], axis=-1)

    return tf.squeeze(output, axis=0), attention_weights


In [None]:
def summarize(input_document):
    """Return the decoded summary text that the model generates for `input_document`."""
    # Attention weights are ignored for now; they could be used later to
    # plot attention heatmaps.
    token_ids, _attention_weights = evaluate(input_document=input_document)
    generated = token_ids.numpy()[1:]  # drop the leading <go> token
    return bpemb_bn.decode_ids(generated)

### Test inference values

In [None]:
# Bengali byte-pair-encoding tokenizer, used by evaluate() and summarize()
# to convert between text and token ids.
bpemb_bn = BPEmb(
    lang="bn",
    vs=VOCAB_SIZE,
    dim=VOCAB_DIM,
)

In [None]:
# Fixed sample article used to eyeball the model's output during training.
# TEST_SUMMARY is its reference headline; it is kept for manual comparison and
# is not read by any code in this notebook.
TEST_SUMMARY = "ট্রাকের ধাক্কায় সড়কে ছিটকে পড়ে প্রবাসীর মৃত্যু"
# TEST_CONTENT is one string built from adjacent literals joined by line continuations.
TEST_CONTENT = "যশোরের বাঘারপাড়া উপজেলায় ট্রাকের ধাক্কায় এক প্রবাসীর মৃত্যু হয়েছে। তাঁর নাম আবু সাঈদ (৪০)। রোববার রাত আটটার দিকে উপজেলার খাজুরায় বাঘারপাড়া-কালীগঞ্জ সড়কে আরিফ ব্রিকসের সামনে এ দুর্ঘটনা ঘটে।"\
" নিহত আবু সাঈদ বাঘারপাড়া উপজেলার বন্দবিলা ইউনিয়নের দাঁতপুর গ্রামের সদর আলী দফাদারের ছেলে। তিনি মালয়েশিয়াপ্রবাসী ছিলেন। প্রত্যক্ষদর্শীর বরাত দিয়ে পুলিশ জানায়, রোববার রাতে খাজুরা বাজার থেকে বাঘারপাড়া-কালীগঞ্জ সড়ক দিয়ে"\
" বাইসাইকেলে করে বাড়ি ফিরছিলেন আবু সাঈদ। রাত আটটার দিকে তিনি আরিফ ব্রিকসের সামনে পৌঁছান। এ সময় পেছন দিক থেকে আসা বাঘারপাড়াগামী একটি দ্রুতগামী ট্রাক সাইকেলটিকে ধাক্কা দেয়। বাইসাইকেল থেকে সড়কের ওপর ছিটকে পড়ে"\
" সেখানেই সাঈদের মৃত্যু হয়। আবু সাঈদের ভাই রেজাউল দফাদার বলেন, আবু সাঈদ মালয়েশিয়ায় চাকরি করতেন। সম্প্রতি ছুটিতে তিনি দেশে এসেছিলেন।"\
" বাঘারপাড়া থানা ভারপ্রাপ্ত কর্মকর্তা (ওসি) সৈয়দ আল মামুন বলেন, ট্রাকের ধাক্কায় আবু সাঈদ নামের এক বাইসাইকেলচালক নিহত হয়েছেন।"

In [None]:
# Smoke test: run inference once on the sample article before training starts.
summary = summarize(TEST_CONTENT)
# The bare name as the cell's last expression makes the notebook display the result.
summary

## 3.3. Tensorboard Setup

In [None]:
# Each launch logs into its own timestamped directory so that separate runs
# show up as separate curves in TensorBoard.
current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")

train_log_dir = os.path.join(TENSORBOARD_PATH, current_time + '/train')
train_summary_writer = tf.summary.create_file_writer(train_log_dir)

test_log_dir = os.path.join(TENSORBOARD_PATH, current_time + '/test')
test_summary_writer = tf.summary.create_file_writer(test_log_dir)

## 3.4. Training

In [None]:
start = time.time()
print_header(LOG_TEXT_PATH)
step = 1  # x-axis for the "per_100_batch_loss" scalar; advances once per logged point

for epoch in range(1, EPOCHS+1):
    train_loss.reset_states()
    test_loss.reset_states()

    # training loop
    for (batch, (inp, tar)) in enumerate(train_dataset):
        train_step(inp, tar)

        # Every 100 batches: log the running loss and a sample summary.
        if batch % 100 == 0:
            loss = train_loss.result()
            summary = summarize(TEST_CONTENT)
            print_info(epoch, batch, loss, summary, LOG_TEXT_PATH, start)
            save_score(SCORE_PATH, epoch, batch, loss)

            with train_summary_writer.as_default():
                tf.summary.scalar("per_100_batch_loss", loss, step=step)
            # Advance here so that each logged point gets its own step;
            # advancing once per epoch stacked every point of an epoch on
            # the same x value.
            step += 1

    # Epoch-level training loss: read the metric now so it covers the whole
    # epoch rather than the stale value from the last 100-batch log.
    with train_summary_writer.as_default():
        tf.summary.scalar("loss", train_loss.result(), step=epoch)

    # testing loop
    for (inp, tar) in test_dataset:
        # test_step takes the model as its first argument.
        test_step(transformer, inp, tar)

    # Write the test loss with the test writer so that it appears as its own
    # run in TensorBoard instead of overwriting the training "loss" curve.
    with test_summary_writer.as_default():
        tf.summary.scalar("loss", test_loss.result(), step=epoch)

    # Save a checkpoint at the end of every epoch.
    ckpt_save_path = ckpt_manager.save()