# en-en11 evaluation

In [None]:
import os
import sys
import random
import torch
import warnings
from tokenizers import Tokenizer

from pathlib import Path
from torch.utils.data import DataLoader
from IPython.display import SVG, display

## 0. Add packages to the Python path

In [None]:
# Load IPython's autoreload extension; mode 2 reloads all imported modules
# before each cell is executed, so edits to the project sources are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Suppress every Python warning for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
b_paths = [os.path.abspath(os.path.join('..', '..', '..')), os.path.abspath(os.path.join('..', '..')), os.path.abspath(os.path.join('..', '..', 'scripts'))]
for b_path in b_paths:
    if b_path not in sys.path:
        sys.path.append(b_path)

BASE_DIR = Path(os.getcwd()).parent.parent.parent.resolve()
%cd $BASE_DIR

## 1. Load model and test set

In [None]:
from models.scripts.generate_dataset import WordDatasetGenerator
from models.scripts.transformer.PreEncoders import Conv1DTransformer
from models.scripts.transformer.utils import preprocess_dataset, seed_all, strokes_to_svg, load_json_hypeparameters, pad_collate_fn
from models.scripts.utils import Levenshtein_Normalized_distance

In [None]:
# Run identifier: used below to load the hyperparameters, to name the model
# and to tag the exported traces.
VERSION = "en-en11"
# Seed handed to seed_all (presumably seeds the Python/torch RNGs — confirm in
# models.scripts.transformer.utils).
SEED = 2021
# Batch size of the test DataLoader.
BATCH_SIZE = 256
seed_all(SEED)

In [None]:
# Load the tokenizer that serves as the model vocabulary.
TOKENIZER_FILE = os.path.join("word_sources", "tokenizer-big_en-normalized.json")
VOCAB = Tokenizer.from_file(TOKENIZER_FILE)

# Ids of the special tokens, looked up in the order <bos>, <eos>, <pad>.
BOS_IDX, EOS_IDX, PAD_IDX = (
    VOCAB.token_to_id(special) for special in ('<bos>', '<eos>', '<pad>')
)

# Report the vocabulary size and list every token in sorted order.
N_TOKENS = VOCAB.get_vocab_size()
print(f"Number of Tokens: {N_TOKENS}\n")
vocab_tokens = sorted(VOCAB.get_vocab())
print(vocab_tokens)

In [None]:
# Build the generator for the cached word/stroke dataset and fetch its test split.
d_gen = WordDatasetGenerator(vocab=VOCAB, fname="words_stroke_en_full")
test = d_gen.generate_from_cache(mode="test")

# Preprocess the test split (the path and length are computed in the same
# order as the arguments of the original single-expression call).
test_cache_path = os.path.join(d_gen.fname + "_en", "test.pt")
test_total_len = d_gen.get_learning_set_length("test")
test_dataset = preprocess_dataset(test, VOCAB, test_cache_path, total_len=test_total_len)

# Fixed order (no shuffling); batches are assembled by pad_collate_fn.
test_set = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=pad_collate_fn,
)

In [None]:
def _n_trainable(module):
    """Return the number of parameters of ``module`` that require gradients."""
    return sum(p.numel() for p in module.parameters() if p.requires_grad)


# Load the stored hyperparameters for this version; the vocabulary is passed
# explicitly, so any "vocab" entry is dropped to avoid a duplicate keyword.
hp = load_json_hypeparameters(VERSION)
hp.pop("vocab", None)
model = Conv1DTransformer(name=VERSION, vocab=VOCAB, **hp)

# Parameter summary: the model's own report, then one line per sub-module.
model.count_parameters()
print(f"Conv trainable parameters: {_n_trainable(model.preencoder):,}.")
print(f"Encoder trainable parameters: {_n_trainable(model.encoder):,}.")
print(f"Decoder trainable parameters: {_n_trainable(model.decoder):,}.")

# Restore the best saved weights and move the model to its target device.
model.load_best_version()
model.to(model.device)

## 2. Test on a single expression (requires sections 0 and 1)

In [None]:
# Take the first batch of the test loader and move both of its tensors
# (inputs, then targets) onto the model's device.
test_set_iter = iter(test_set)
x_pred, y_pred = (
    tensor.to(model.device) for tensor in next(test_set_iter)
)

In [None]:
def _detokenize(text):
    """Remove the spaces in ``text`` and turn every "Ġ" marker into a space."""
    return text.replace(" ", '').replace("Ġ", " ")


# Pick one random sample from the current batch.
ind = random.randrange(y_pred.shape[0])
print("Index:", ind, "\n")

# Render the input strokes of the chosen sample as an inline SVG.
svg_str = strokes_to_svg(x_pred[ind], {'height':100, 'width':100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data = svg_str))

# Run the model on that sample alone (unsqueeze(0) adds a batch dimension of 1).
# The attention outputs and token ids are reused by the visualization cells.
prediction, (cross_att, dec_att, enc_att), token_ids = model.predict(x_pred[ind].unsqueeze(0))
prediction = _detokenize(prediction)

# Decode the target ids into text with the same clean-up as the prediction.
gt = _detokenize(VOCAB.decode(y_pred[ind].tolist()))
# NOTE(review): the literal 1 is presumably the padding id (PAD_IDX) — confirm
# against the tokenizer before replacing the magic number.
gt_list = [token_id for token_id in y_pred[ind].tolist() if token_id != 1]
# NOTE(review): "- 2" presumably discounts <bos>/<eos>; confirm that the decoded
# strings really contain them, otherwise the reported lengths are off by two.
gt_length = len(gt)-2

print("Ground truth => ", gt_list , '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs
print(f"Ground Truth: {gt} (len={gt_length})")
print(f"- Prediction: {prediction} (len={len(prediction)-2})")

print(f"Normalized Levenshtein distance is: {Levenshtein_Normalized_distance(gt, prediction)}")


In [None]:
# Call trace_and_export on the single sample picked in the previous cell; both
# tensors get a leading batch dimension of 1 and the output is tagged
# "<VERSION>_single_test".
model.trace_and_export(src=x_pred[ind].unsqueeze(0), trg=y_pred[ind].unsqueeze(0), version=f"{VERSION}_single_test")

## 3. Evaluate on test set (requires sections 0 and 1)

### Compute average test set cross-entropy loss (XEL)

In [None]:
# evaluate_f receives the whole test DataLoader; its result is reported as the
# test loss with three decimals (the section heading describes it as the
# average cross-entropy loss).
test_loss = model.evaluate_f(test_set)

print(f'Test Loss: {test_loss:.3f}')

In [None]:
# WARNING: this call needs the whole test set to be loaded in memory.
# NOTE(review): the same DataLoader is passed as both src and trg — confirm
# that trace_and_export accepts a DataLoader for these arguments.

model.trace_and_export(src=test_set, trg=test_set, version=f"{VERSION}_test_set")

### Compute Normalized Levenshtein accuracy, Character Error Rate and Word Error Rate

In [None]:
# Metrics requested from the model, evaluated over a fresh pass of the test loader.
metric_names = ["Lev_acc", "CER", "WER"]
test_set_iter = iter(test_set)
metrics = model.evaluate_multiple(test_set_iter, metric_names)

# One report line per metric, each preceded by a blank line.
for description, key in (
    ("Normalized Levenshtein accuracy", "Lev_acc"),
    ("Character Error Rate", "CER"),
    ("Word Error Rate", "WER"),
):
    print(f"\n{description} of test set is: {metrics[key]}")

## 4. Visualization (requires sections 0, 1 and 2)

### Attention visualization (encoder self-attention, decoder self-attention, cross-attention)

In [None]:
# Encoder self-attention for the sample chosen in section 2: the sample's
# stroke tensor is passed as both of the first two arguments, together with
# the enc_att output returned by model.predict.
print("Index:", ind)
model.display_encoder_self_attention(x_pred[ind], x_pred[ind], enc_att)

In [None]:
# Turn the predicted token ids into display labels: look up each token, drop
# its spaces and turn the "Ġ" marker into a space.
ind_list = []
for token_id in token_ids:
    token_text = VOCAB.id_to_token(token_id)
    ind_list.append(token_text.replace(" ", '').replace("Ġ", " "))

# Decoder self-attention: the predicted tokens label both arguments.
model.display_decoder_self_attention(ind_list, ind_list, dec_att)

In [None]:
# Rebuild the display labels of the predicted tokens so this cell can run on
# its own: look up each token, drop its spaces and turn "Ġ" into a space.
ind_list = []
for token_id in token_ids:
    token_text = VOCAB.id_to_token(token_id)
    ind_list.append(token_text.replace(" ", '').replace("Ġ", " "))

# Cross-attention between the input strokes of the sample and the predicted tokens.
model.display_cross_attention(x_pred[ind], ind_list, cross_att)