### External library imports

In [None]:
from pathlib import Path
from random import randint
import numpy as np
from tensorflow import Variable, argmax, int32
from tensorflow.keras.optimizers import RMSprop
from tensorflow.keras.losses import SparseCategoricalCrossentropy, BinaryCrossentropy
from tensorflow.keras.callbacks import ModelCheckpoint, LambdaCallback, TerminateOnNaN, TensorBoard, EarlyStopping

### Internal imports

In [None]:
from multilevel_diacritizer.model import MultiLevelDiacritizer
from multilevel_diacritizer.metrics import DiacritizationErrorRate, WordErrorRate
from multilevel_diacritizer.constants import (
DEFAULT_EMBEDDING_SIZE, DEFAULT_LSTM_SIZE, DEFAULT_DROPOUT_RATE, DEFAULT_WINDOW_SIZE, DEFAULT_SLIDING_STEP,
DEFAULT_BATCH_SIZE, DEFAULT_PARAMS_DIR, DEFAULT_MONITOR_METRIC, DEFAULT_EARLY_STOPPING_STEPS, DEFAULT_TRAIN_STEPS
)

### Dataset files

In [None]:
from pathlib import Path

#@title Dataset files
# Each list holds, as strings, the paths matched by the glob pattern inside its data directory.
TRAIN_DATA_FILES = list(map(str, Path('data/tashkeela_train/').glob('tashkeela_train_*.txt'))) #@param {type:"raw"}
VAL_DATA_FILES = list(map(str, Path('data/tashkeela_val/').glob('tashkeela_val_*.txt'))) #@param {type:"raw"}
TEST_DATA_FILES = list(map(str, Path('data/tashkeela_test/').glob('tashkeela_test_*.txt'))) #@param {type:"raw"}

### Construction of the model 

In [None]:
# Build the diacritizer from the package's default hyper-parameters.
model = MultiLevelDiacritizer(
    embedding_size=DEFAULT_EMBEDDING_SIZE,
    lstm_size=DEFAULT_LSTM_SIZE,
    window_size=DEFAULT_WINDOW_SIZE,
    dropout_rate=DEFAULT_DROPOUT_RATE,
)
# Print the layer summary; `positions` gives the column positions of the printed table.
model.summary(positions=[0.45, 0.6, 0.75, 1.0])

### Loading the training data

In [None]:
# Build the training dataset from the training files using the default batch size,
# window size and sliding step. The result is later indexed with 'dataset' (passed to
# `model.fit`) and 'size' (used as the number of steps per epoch).
train_set = MultiLevelDiacritizer.get_processed_window_dataset(
            TRAIN_DATA_FILES, DEFAULT_BATCH_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_SLIDING_STEP
        )

### Loading the validation data

In [None]:
# Build the validation dataset with the same batch/window/step settings as the training
# set. The result is later indexed with 'dataset' (validation data) and 'size'
# (number of validation steps, and the bound for picking a random preview batch).
val_set = MultiLevelDiacritizer.get_processed_window_dataset(
    VAL_DATA_FILES, DEFAULT_BATCH_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_SLIDING_STEP
)

### Compiling the model and loading the weights

In [None]:
# Compile with RMSprop and four losses computed on logits: two sparse categorical
# (primary, secondary) and two binary (shadda, sukoon).
model.compile(
    RMSprop(0.001),
    [
        SparseCategoricalCrossentropy(from_logits=True, name='primary_loss'),
        SparseCategoricalCrossentropy(from_logits=True, name='secondary_loss'),
        BinaryCrossentropy(from_logits=True, name='shadda_loss'),
        BinaryCrossentropy(from_logits=True, name='sukoon_loss'),
    ],
)

# The weights file name encodes the model name and the hyper-parameters it was built with.
model_path = DEFAULT_PARAMS_DIR / Path(
    f'{model.name}-E{DEFAULT_EMBEDDING_SIZE}L{DEFAULT_LSTM_SIZE}W{DEFAULT_WINDOW_SIZE}S{DEFAULT_SLIDING_STEP}.h5'
)

if not model_path.exists():
    # Nothing saved yet: keep the freshly initialized weights.
    print('Initializing random weights for the model %s ...' % model.name)
else:
    print('Loading model weights from %s ...' % model_path)
    # Weights are matched by layer name; layers whose shapes do not match are skipped.
    model.load_weights(str(model_path), by_name=True, skip_mismatch=True)

### Training the model

In [None]:
# Text file in the parameters directory that stores the most recent epoch number (first
# line) and its logs (second line); written by `write_epoch`, read by `get_initial_epoch`.
last_epoch_path = DEFAULT_PARAMS_DIR / Path('last_epoch.txt')

def write_epoch(epoch, logs):
    """Overwrite the last-epoch file with the epoch number and its logs.

    The file at ``last_epoch_path`` ends up with two lines: ``epoch`` on the
    first and the string form of ``logs`` on the second.

    :param epoch: epoch number to record.
    :param logs: logs object to record (its ``str`` form is written).
    """
    with last_epoch_path.open('w') as epoch_file:
        epoch_file.write(f'{epoch}\n{logs}\n')

def get_initial_epoch():
    """Return the epoch index at which training should start or resume.

    The first line of ``last_epoch_path`` holds the index of the last epoch
    that finished (Keras passes a 0-based index to ``on_epoch_end``), so
    training resumes at the following index instead of repeating that epoch.

    :return: ``0`` when the file is missing, empty or unreadable as an
        integer; otherwise the recorded epoch index plus one.
    """
    try:
        with last_epoch_path.open() as f:
            last_finished_epoch = int(f.readline())
    except FileNotFoundError:
        # No previous run recorded: start from scratch.
        return 0
    except ValueError:
        # An interrupted write can leave the file empty or truncated; do not block training on it.
        print('Could not read the last epoch from %s; starting from epoch 0.' % last_epoch_path)
        return 0
    return last_finished_epoch + 1

def get_diacritization_preview(val_set, sliding_step, model, limit):
    """Return the predicted and the reference sentence for one random validation batch.

    :param val_set: mapping with a ``'dataset'`` entry (supporting ``skip``/``take``
        and yielding ``(x, (pri, sec, sh, su))`` batches) and a ``'size'`` entry
        giving the number of batches.
    :param sliding_step: sliding step forwarded to the model's sentence helpers.
    :param model: object providing ``predict_sentence_from_input_batch`` and
        ``generate_real_sentence_from_batch``.
    :param limit: keep only the first ``limit`` entries along the first axis of
        the batch inputs and labels.
    :return: ``(predicted, real)`` where ``predicted`` is the UTF-8 decoded model
        output and ``real`` is whatever ``generate_real_sentence_from_batch`` returns.
    """
    # Valid skip counts are 0 .. size - 1. The lower bound of 0 lets the first batch be
    # chosen and keeps the range non-empty when the set has a single batch.
    batches_to_skip = randint(0, max(val_set['size'] - 1, 0))
    x, (pri, sec, sh, su) = next(iter(
        val_set['dataset'].skip(batches_to_skip).take(1)
    ))
    x, pri, sec, sh, su = x[:limit], pri[:limit], sec[:limit], sh[:limit], su[:limit]
    predicted = model.predict_sentence_from_input_batch(x, sliding_step).numpy().decode('UTF-8')
    real = model.generate_real_sentence_from_batch(
        (x, [pri, sec, sh, su]),
        sliding_step
    )
    return predicted, real


# Callbacks, listed in the order they are handed to Keras.
training_callbacks = [
    # Overwrite the weights file only when the monitored metric improves.
    ModelCheckpoint(str(model_path), monitor=DEFAULT_MONITOR_METRIC,
                    save_best_only=True, save_weights_only=True),
    TerminateOnNaN(),
    EarlyStopping(monitor=DEFAULT_MONITOR_METRIC, patience=DEFAULT_EARLY_STOPPING_STEPS, verbose=1),
    # After every epoch, print a predicted sentence next to its reference.
    LambdaCallback(
        on_epoch_end=lambda epoch, logs: print(
            'Predicted diacritization: %s\nReal diacritization: %s' %
            get_diacritization_preview(val_set, DEFAULT_SLIDING_STEP, model, 100)
        )
    ),
    # After every epoch, record the epoch number and logs on disk.
    LambdaCallback(on_epoch_end=write_epoch),
    TensorBoard(),
]

# Both datasets are repeated indefinitely, so the epoch length is set by the step counts.
model.fit(
    train_set['dataset'].repeat(),
    steps_per_epoch=train_set['size'],
    epochs=DEFAULT_TRAIN_STEPS,
    initial_epoch=get_initial_epoch(),
    validation_data=val_set['dataset'].repeat(),
    validation_steps=val_set['size'],
    callbacks=training_callbacks,
)

### Loading the testing data

In [None]:
# Build the test dataset with the same batch/window/step settings used for training.
# The result is later indexed with 'dataset' (iterated batch by batch) and 'size'
# (total batch count shown in the progress line).
test_set = MultiLevelDiacritizer.get_processed_window_dataset(
    TEST_DATA_FILES, DEFAULT_BATCH_SIZE, DEFAULT_WINDOW_SIZE, DEFAULT_SLIDING_STEP
)

### Testing the model

In [None]:
# Recreate the model with the default hyper-parameters and load its saved weights, if any,
# before evaluation.
model = MultiLevelDiacritizer(
    embedding_size=DEFAULT_EMBEDDING_SIZE,
    lstm_size=DEFAULT_LSTM_SIZE,
    window_size=DEFAULT_WINDOW_SIZE,
    dropout_rate=DEFAULT_DROPOUT_RATE,
)

# Same naming scheme as the training cell: model name plus the hyper-parameters.
model_path = DEFAULT_PARAMS_DIR / Path(
    f'{model.name}-E{DEFAULT_EMBEDDING_SIZE}L{DEFAULT_LSTM_SIZE}W{DEFAULT_WINDOW_SIZE}S{DEFAULT_SLIDING_STEP}.h5'
)

if not model_path.exists():
    # Evaluation still runs, but on randomly initialized weights.
    print('Weights file for the selected model is not found in %s. The model weights are initialized randomly.' % model_path.parent)
else:
    print('Loading model weights from %s ...' % model_path)
    model.load_weights(str(model_path), by_name=True, skip_mismatch=True)

# Running sums of the per-batch error rates and the number of batches seen so far.
der = Variable(0.0)
wer = Variable(0.0)
count = Variable(0.0)
print('Calculating DER and WER...')
for i, (x, diacs) in test_set['dataset'].enumerate(1):
    # Run the model once per batch and take the arg-max class of every output before
    # merging the windows. (The previous extra `model(x)` call only discarded its results.)
    pred_diacs = [
        MultiLevelDiacritizer.combine_windows(argmax(v, axis=2, output_type=int32), DEFAULT_SLIDING_STEP)
        for v in model(x)
    ]
    # Merge the windows of the inputs and of the reference labels the same way.
    x = MultiLevelDiacritizer.combine_windows(x, DEFAULT_SLIDING_STEP)
    diacs = [MultiLevelDiacritizer.combine_windows(v, DEFAULT_SLIDING_STEP) for v in diacs]
    diacritics = MultiLevelDiacritizer.decode_encoded_diacritics(diacs)
    pred_diacritics = MultiLevelDiacritizer.decode_encoded_diacritics(pred_diacs)
    # Error rate = 1 - accuracy, accumulated per batch.
    der.assign_add(1 - DiacritizationErrorRate.char_acc((diacritics, pred_diacritics, x)))
    wer.assign_add(1 - WordErrorRate.word_acc((diacritics, pred_diacritics, x)))
    count.assign_add(1)
    # Report the mean error rates over the batches processed so far.
    print('Batch %d/%d: DER = %f | WER = %f' % (i, test_set['size'], (der / count).numpy(), (wer / count).numpy()))