# vt1 evaluation

In [None]:
import os
import sys
import random
import torch

from pathlib import Path
from torch.utils.data import DataLoader
from IPython.display import SVG, display

## 0. Add packages to the Python path

In [None]:
b_paths = [os.path.abspath(os.path.join('..', '..', '..')), os.path.abspath(os.path.join('..', '..')), os.path.abspath(os.path.join('..', '..', 'scripts'))]
for b_path in b_paths:
    if b_path not in sys.path:
        sys.path.append(b_path)

BASE_DIR = Path(os.getcwd()).parent.parent.parent.resolve()
%cd $BASE_DIR

## 1. Load model and test set

In [None]:
from models.scripts.generate_dataset import WordDatasetGenerator
from models.scripts.transformer.Transformer import Transformer
from models.scripts.transformer.utils import preprocess_dataset, seed_all, build_vocab, strokes_to_svg, load_json_hypeparameters, pad_collate_fn, tensor_to_word
from models.scripts.utils import Levenshtein_Normalized_distance

In [None]:
# Model version tag; reused below to load hyperparameters, name the model and
# label exported traces.
VERSION = "vt1"
# Batch size handed to the test DataLoader.
BATCH_SIZE = 256

# Seed passed to the project's seed_all helper.
SEED = 2021
seed_all(SEED)

In [None]:
# Hyperparameters stored for this model version.
hp = load_json_hypeparameters(VERSION)

# Rebuild the vocabulary from the stored tokens, leaving the special tokens out
# of the iterable handed to build_vocab. The index lookups below expect
# build_vocab to provide '<pad>', '<bos>' and '<eos>' itself.
_special_tokens = ['<pad>', '<bos>', '<eos>', '<unk>']
_stored_vocab = dict(hp['vocab'])
VOCAB = build_vocab(
    token for token in _stored_vocab.keys() if token not in _special_tokens
)
hp['vocab'] = VOCAB  # the model is constructed from hp, so it receives this vocab

N_TOKENS = len(VOCAB)
PAD_IDX, BOS_IDX, EOS_IDX = (VOCAB[tok] for tok in ('<pad>', '<bos>', '<eos>'))

print(f"Number of Tokens: {N_TOKENS}\n")
token_order = {VOCAB.itos[idx]: idx for idx in range(N_TOKENS)}
print(token_order)  # Token order

In [None]:
# Dataset generator configured with the vocabulary and the dataset name.
d_gen = WordDatasetGenerator(vocab=VOCAB, fname="words_stroke_100_155805")
test = d_gen.generate_from_cache(mode='test')

# Preprocess the test split, then wrap it in a DataLoader that keeps the
# sample order (shuffle=False) and batches through the project's pad_collate_fn.
_test_pt_path = os.path.join(d_gen.fname, "test.pt")
_test_dataset = preprocess_dataset(
    test,
    VOCAB,
    _test_pt_path,
    total_len=d_gen.get_learning_set_length("test"),
)
test_set = DataLoader(
    _test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=pad_collate_fn,
)

In [None]:
# Build the Transformer from the loaded hyperparameters (hp is unpacked into
# keyword arguments, including the rebuilt vocabulary).
model = Transformer(name=VERSION, **hp)
# Presumably restores the best saved weights for this version -- see
# Transformer.load_best_version to confirm what "best" refers to.
model.load_best_version()
# Move the model to the device it reports. As the last expression of the cell,
# its result is echoed by the notebook.
model.to(model.device)

## 2. Test on a single expression (run sections 0 and 1 first)

In [None]:
# Take the first batch of the test loader and move both of its items to the
# model's device while unpacking them.
test_set_iter = iter(test_set)
x_pred, y_pred = (item.to(model.device) for item in next(test_set_iter))

In [None]:
def _strip_special_tokens(tokens, specials=('<bos>', '<eos>', '<pad>')):
    """Join *tokens* into one string and remove every marker in *specials*.

    ``str.strip('<eos>')`` treats its argument as a *set of characters*, so it
    also eats leading/trailing letters that belong to the word itself (for
    example ``'hello<eos>'.strip('<eos>')`` yields ``'hell'``). The markers are
    therefore removed as whole substrings instead.

    Args:
        tokens: iterable of string tokens (or an already joined string).
        specials: marker strings to remove.

    Returns:
        The joined text without any of the markers.
    """
    text = ''.join(tokens)
    for marker in specials:
        text = text.replace(marker, '')
    return text


# Pick one random sample of the current batch.
ind = random.randrange(y_pred.shape[0])
print("Index:", ind, "\n")

# Render the input strokes of the chosen sample as an inline SVG.
svg_str = strokes_to_svg(x_pred[ind], {'height': 100, 'width': 100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data=svg_str))

# NOTE(review): eos_tensor is not used by any cell visible in this notebook;
# kept in case something outside this view relies on it.
eos_tensor = torch.zeros(x_pred[ind].size(-1)) + EOS_IDX

# predict() receives a batch of one sample; the attention maps are kept for the
# visualization section at the end of the notebook.
prediction, (cross_att, dec_att, enc_att), _ = model.predict(x_pred[ind].unsqueeze(0))

gt = tensor_to_word(y_pred[ind], VOCAB)
target_ids = y_pred[ind].tolist()  # convert once instead of iterating the tensor element-wise
# NOTE(review): the literal 1 is presumably PAD_IDX -- confirm before replacing it.
gt_list = [i for i in target_ids if i != 1]
gt_length = sum(1 for i in target_ids if i not in (PAD_IDX, BOS_IDX, EOS_IDX))

print("Ground truth => ", gt_list, '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs.
# The prediction length subtracts 2, i.e. it assumes the prediction carries
# both a <bos> and an <eos> token.
print(f"Ground Truth: {''.join(gt)} (len={gt_length})")
print(f"- Prediction: {''.join(prediction)} (len={len(prediction)-2})")

# Compare the bare words only: special markers are removed from both sides.
lev_distance = Levenshtein_Normalized_distance(
    a=_strip_special_tokens(gt),
    b=_strip_special_tokens(prediction),
)
print(f"Normalized Levenshtein distance is: {lev_distance}")


In [None]:
# Trace/export the model on the single sample selected above; unsqueeze(0) adds
# a leading batch dimension to source and target, and the export is labelled
# "<VERSION>_single_test".
model.trace_and_export(src=x_pred[ind].unsqueeze(0), trg=y_pred[ind].unsqueeze(0), version=f"{VERSION}_single_test")

## 3. Evaluate on test set (run sections 0 and 1 first)

### Compute average test set cross-entropy loss (XEL)

In [None]:
# Run the model's own evaluation routine over the whole test DataLoader; per
# the section heading the returned value is the average cross-entropy loss.
test_loss = model.evaluate_f(test_set)

# Report the loss with three decimals.
print(f'Test Loss: {test_loss:.3f}')

In [None]:
# WARNING: exporting on the full test set would require loading the whole test
# set into memory. Note that the DataLoader itself is passed as both `src` and
# `trg` here.

model.trace_and_export(src=test_set, trg=test_set, version=f"{VERSION}_test_set")

### Compute Normalized Levenshtein accuracy, Character Error Rate and Word Error Rate

In [None]:
# (printed label, metric key) pairs, in the order they are requested and reported.
_reported_metrics = (
    ("Normalized Levenshtein accuracy", "Lev_acc"),
    ("Character Error Rate", "CER"),
    ("Word Error Rate", "WER"),
)

# Start a fresh pass over the test loader and compute all metrics in one call.
test_set_iter = iter(test_set)
metrics = model.evaluate_multiple(test_set_iter, [key for _, key in _reported_metrics])

for _label, _key in _reported_metrics:
    print(f"\n{_label} of test set is: {metrics[_key]}")

In [None]:
# Word Error Rate computed on its own, using a fresh iterator so the pass
# starts at the first batch of the test loader.
test_set_iter = iter(test_set)
wer = model.evaluate_WER(test_set_iter)

In [None]:
# Report the stand-alone WER computed in the previous cell.
print(f"\nWord Error Rate of test set is: {wer}")

## 4. Visualization (run sections 0, 1 and 2 first)

### Attention visualization (encoder self-attention, decoder self-attention, cross-attention)

In [None]:
print("Index:", ind)
# Materialise the predicted tokens once; the decoder self-attention and
# cross-attention cells below pass this list to the plotting helpers.
ind_list = list(prediction)

# Encoder self-attention: the sample's input strokes are passed twice.
_sample_strokes = x_pred[ind]
model.display_encoder_self_attention(_sample_strokes, _sample_strokes, enc_att)

In [None]:
# Decoder self-attention: the predicted tokens are passed twice, together with
# the decoder attention returned by model.predict for the selected sample.
model.display_decoder_self_attention(ind_list, ind_list, dec_att)

In [None]:
# Cross-attention: the sample's input strokes and the predicted tokens are
# passed together with the cross-attention returned by model.predict.
model.display_cross_attention(x_pred[ind], ind_list, cross_att)