# Ablation studies
Feed the model with wrong/incomplete sentences.
Look at the model's ability to spot errors and misspellings, and measure it.

In [None]:
import os
import sys
import random
import torch
import warnings
from tokenizers import Tokenizer

from tqdm import tqdm
from pathlib import Path
from torch.utils.data import DataLoader
from IPython.display import SVG, display

### 0. Add packages to python path

In [None]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell execution.
%load_ext autoreload
%autoreload 2
# Suppress every warning for the rest of the session.
warnings.filterwarnings('ignore')

In [None]:
b_paths = [os.path.abspath(os.path.join('..', '..', '..')), os.path.abspath(os.path.join('..', '..')), os.path.abspath(os.path.join('..', '..', 'scripts'))]
for b_path in b_paths:
    if b_path not in sys.path:
        sys.path.append(b_path)

BASE_DIR = Path(os.getcwd()).parent.parent.parent.resolve()
%cd $BASE_DIR

## 1. Load model and test set

In [None]:
from models.scripts.generate_dataset import WordDatasetGenerator
from models.scripts.transformer.printable_models import ExplainableTransformer
from models.scripts.transformer.utils import preprocess_dataset, seed_all, strokes_to_svg, load_json_hypeparameters, pad_collate_fn
from models.scripts.utils import Levenshtein_Normalized_distance

In [None]:
# Run configuration shared by the cells below.
# VERSION names the model: it is passed to load_json_hypeparameters and to
# ExplainableTransformer(name=...).
VERSION = "en-fr12"
# Seed handed to seed_all() at the end of this cell.
SEED = 2021
# Batch size of every DataLoader built in this notebook.
BATCH_SIZE = 256
# Forwarded as expr_mode to the WordDatasetGenerator of the misspelling study.
EXPR_MODE="all"
seed_all(SEED)

In [None]:
# Load the tokenizer that is used as the model vocabulary.
# The path is relative, so it depends on the current working directory.
TOKENIZER_FILE = os.path.join("word_sources","tokenizer-big_fr-normalized.json")
VOCAB = Tokenizer.from_file(TOKENIZER_FILE)

# Ids of the special tokens; they are used below when rendering strokes and
# when rewriting ablated targets.
BOS_IDX = VOCAB.token_to_id('<bos>')
EOS_IDX = VOCAB.token_to_id('<eos>')
PAD_IDX = VOCAB.token_to_id('<pad>')

N_TOKENS = VOCAB.get_vocab_size() # len(VOCAB)
print(f"Number of Tokens: {N_TOKENS}\n")
# Print the token strings in sorted order for inspection.
print(sorted(VOCAB.get_vocab()))

In [None]:
# Dataset generator for the French stroke corpus; the test split is obtained
# through generate_from_cache.
d_gen = WordDatasetGenerator(vocab=VOCAB, fname="words_stroke_fr_full")
test = d_gen.generate_from_cache(mode="test")

# Ordered (non-shuffled) loader over the preprocessed test split.
# NOTE(review): the "<fname>_fr/test.pt" path is presumably where the
# preprocessed tensors are stored - confirm against preprocess_dataset.
test_set = DataLoader(
    preprocess_dataset(
        test,
        VOCAB,
        os.path.join(d_gen.fname + "_fr", "test.pt"),
        total_len=d_gen.get_learning_set_length("test"),
    ),
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=pad_collate_fn,
)

In [None]:
# Build the model from the stored hyper-parameters of VERSION. A "vocab"
# entry, if present, is dropped because the vocabulary is passed explicitly.
hp = load_json_hypeparameters(VERSION)
hp.pop("vocab", None)
model = ExplainableTransformer(name=VERSION, vocab=VOCAB, **hp)
model.count_parameters()

# Trainable-parameter counts of the three sub-modules.
conv_params = sum(p.numel() for p in model.preencoder.parameters() if p.requires_grad)
encoder_params = sum(p.numel() for p in model.encoder.parameters() if p.requires_grad)
decoder_params = sum(p.numel() for p in model.decoder.parameters() if p.requires_grad)
print(f"Conv trainable parameters: {conv_params:,}.")
print(f"Encoder trainable parameters: {encoder_params:,}.")
print(f"Decoder trainable parameters: {decoder_params:,}.")

# Restore the best checkpoint and move the model to its own device.
model.load_best_version()
model.to(model.device)

## 2. Ablation studies

### Ablation functions

In [None]:
def find_sentence_ending_with_symbol(t_set, symbol, required_strokes=1):
    """Pick a sample whose target contains ``symbol`` and locate where to cut it.

    Batches of ``t_set`` are visited in order; inside each batch the samples
    are visited in a random order (``torch.randperm``), and the first sample
    whose target contains the token id of ``symbol`` is returned.

    Reads the notebook globals ``VOCAB`` (token lookup) and ``d_gen``
    (``eos_idx`` / ``padding_value`` used to recognise filler input rows).

    NOTE(review): ``where_y`` is the FIRST occurrence of the token in the
    target, while ``where_x`` is counted back from the END of the input, so
    the two only agree when the symbol really is the final token - confirm
    this holds for the data.

    Args:
        t_set: iterable of ``(x_batch, y_batch)`` pairs.
        symbol: token string to look for in the targets.
        required_strokes: number of input rows the symbol occupies; the cut
            position is moved back by ``required_strokes - 1`` rows.

    Returns:
        ``(x, y, where_x, where_y)`` for the selected sample, where
        ``where_x`` is a non-negative row index into ``x`` and ``where_y`` an
        index into ``y``; ``(None, None, None, None)`` when no sample
        qualifies.
    """
    id_symbol = VOCAB.token_to_id(symbol)
    for x_batch, y_batch in t_set:
        # Shuffle the batch so that repeated calls can return different samples.
        idx = torch.randperm(x_batch.shape[0])
        x_batch = x_batch[idx]
        y_batch = y_batch[idx]
        for x, y in zip(x_batch, y_batch):
            try:
                where_y = y.tolist().index(id_symbol)
            except ValueError:
                # The symbol does not occur in this target.
                continue

            eos_tensor = torch.zeros(x.size(-1)) + d_gen.eos_idx
            pad_tensor = torch.zeros(x.size(-1)) + d_gen.padding_value
            n_rows = x.size(0)

            # Walk back from the last row over eos/padding rows. The explicit
            # bound replaces relying on an IndexError to stop the scan.
            where = -1
            while -where <= n_rows and (
                torch.all(x[where, :].eq(eos_tensor))
                or torch.all(x[where, :].eq(pad_tensor))
            ):
                where -= 1

            # A symbol drawn with several strokes starts earlier in the input.
            where -= (required_strokes - 1)
            where_x = n_rows + where
            if where_x < 0:
                # Not enough real rows to hold the symbol; a negative index
                # would silently wrap around in the caller.
                continue
            return x, y, where_x, where_y

    return None, None, None, None

In [None]:
def replace_symbol(x, y, where_x, where_y):
    """Truncate a sample at the given positions, in place.

    Row ``where_x`` of ``x`` becomes an end-of-sequence row and every later
    row becomes padding; ``y[where_y]`` becomes ``EOS_IDX`` and every later
    entry becomes ``PAD_IDX``. Reads the notebook globals ``d_gen``
    (``eos_idx`` / ``padding_value``), ``EOS_IDX`` and ``PAD_IDX``.

    The extent of the padding is taken from ``x`` itself (all rows after
    ``where_x``). It must not depend on any notebook-global tensor, nor on
    the feature axis, otherwise rows after the cut can be left untouched.

    Args:
        x: 2-D input tensor indexed as ``x[row, feature]``; modified in place.
        y: 1-D target tensor of token ids; modified in place.
        where_x: row of ``x`` that receives the end-of-sequence marker.
        where_y: index of ``y`` that receives ``EOS_IDX``.

    Returns:
        The same ``(x, y)`` objects, after modification.
    """
    eos_tensor = torch.zeros(x.size(-1)) + d_gen.eos_idx
    pad_tensor = torch.zeros(x.size(-1)) + d_gen.padding_value
    # Replace the first stroke of the symbol with the end-of-sequence row.
    x[where_x, :] = eos_tensor
    y[where_y] = EOS_IDX
    # Everything after the cut becomes padding (pad_tensor broadcasts per row).
    x[where_x + 1:, :] = pad_tensor
    y[where_y + 1:] = PAD_IDX
    return x, y

### Remove final dot

In [None]:
# Demo: take one test sample whose target contains ".", show it, cut it at the
# symbol, and print the model's prediction for the ablated input.
# NOTE(review): if no sample is found the helper returns None values and the
# rendering call below will fail - confirm the test set always has a match.
x_pred, y_pred, position_x, position_y = find_sentence_ending_with_symbol(test_set, symbol=".")

print("Original input:")
svg_str = strokes_to_svg(x_pred, {'height':100, 'width':100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data = svg_str))


# Decode the target ids into text; "Ġ" is mapped back to a plain space.
gt = VOCAB.decode(y_pred.tolist()).replace(" ", '').replace("Ġ", " ")
# NOTE(review): the literal 1 is presumably the padding token id - confirm
# it matches PAD_IDX.
gt_list = [i for i in y_pred.tolist() if i != 1]
# NOTE(review): the -2 presumably discounts two non-text characters of the
# decoded string - confirm.
gt_length = len(gt)-2

print("Ground truth => ", gt_list , '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs
# (gt is already a str, so ''.join(gt) leaves it unchanged)
print(f"Ground Truth : {''.join(gt)} (len={gt_length})")

# replace_symbol edits the tensors in place and returns the same objects.
x_pred, y_pred = replace_symbol(x_pred, y_pred, position_x, position_y)

print("Ablated input:")
svg_str = strokes_to_svg(x_pred, {'height':100, 'width':100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data = svg_str))

# Predict on a batch of one; the attention maps and token ids are kept in
# notebook globals but not used in this cell.
prediction, (cross_att, dec_att, enc_att), token_ids = model.predict(x_pred.unsqueeze(0).to(model.device))
prediction = prediction.replace(" ", '').replace("Ġ", " ")

# Same decoding as above, now for the ablated target.
gt = VOCAB.decode(y_pred.tolist()).replace(" ", '').replace("Ġ", " ")
gt_list = [i for i in y_pred.tolist() if i != 1]
gt_length = len(gt)-2

print("Ablated Indices => ", gt_list , '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs
print(f"Ablated new string: {''.join(gt)} (len={gt_length})")
print(f" -- New Prediction: {prediction} (len={len(prediction)-2})")

#print(f"Normalized Levenshtein distance is: {Levenshtein_Normalized_distance(gt, prediction)}")

In [None]:
def evaluate_ablation_metrics(symbol, req_strokes=1):
    """Count how often the model's prediction ends with ``symbol``.

    Every sample of the notebook-global ``test_set`` is checked on its own
    with ``find_sentence_ending_with_symbol``:

    * if the symbol is found, the sample is cut with ``replace_symbol`` and
      the prediction on the cut input counts as TP when it ends with
      ``symbol`` and as FN otherwise;
    * if it is not found, the prediction on the untouched input counts as FP
      when it ends with ``symbol`` and as TN otherwise.

    Args:
        symbol: token string whose removal is being studied.
        req_strokes: forwarded as ``required_strokes`` to the finder.

    Returns:
        The four counters in the order ``(TP, TN, FP, FN)``.
    """
    tally = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}

    for x_batch, y_batch in tqdm(test_set):
        for row in range(y_batch.shape[0]):
            sample = x_batch[row]
            target = y_batch[row]
            # Present the sample as a one-element "dataset" of one batch.
            single = [(sample.unsqueeze(0), target.unsqueeze(0))]
            x_f, y_f, pos_x, pos_y = find_sentence_ending_with_symbol(single, symbol, req_strokes)

            had_symbol = x_f is not None
            if had_symbol:
                sample, target = replace_symbol(x_f, y_f, pos_x, pos_y)

            predict, _, _ = model.predict(sample.unsqueeze(0).to(model.device))
            predict = predict.replace(" ", '').replace("Ġ", " ")
            ends_with_symbol = predict.endswith(symbol)

            if had_symbol:
                outcome = "TP" if ends_with_symbol else "FN"
            else:
                outcome = "FP" if ends_with_symbol else "TN"
            tally[outcome] += 1

    return tally["TP"], tally["TN"], tally["FP"], tally["FN"]

In [None]:
# Run the ablation evaluation over the whole test set for the full stop,
# using the default of one stroke for the symbol.
SYMBOL="."
TP, TN, FP, FN = evaluate_ablation_metrics(SYMBOL)

In [None]:
# Report the raw confusion counts of the last evaluation run.
print("True positives: ",TP)
print("True negatives: ",TN)
print("False positives: ",FP)
print("False negatives: ",FN)
print()
# Each ratio is reported as NaN when its denominator is zero (for example when
# the symbol never occurs in the test set) instead of raising
# ZeroDivisionError after a long evaluation.
n_evaluated = TP+TN+FP+FN
print("Accuracy: ", (TP+TN)/n_evaluated if n_evaluated else float("nan"))
print("Precision: ", TP/(TP+FP) if (TP+FP) else float("nan"))
print("Recall: ", TP/(TP+FN) if (TP+FN) else float("nan"))

### Remove final question_mark

In [None]:
# Demo: same procedure as for the full stop, but for "?", which is searched
# with required_strokes=2 so the cut starts one input row earlier.
# NOTE(review): if no sample is found the helper returns None values and the
# rendering call below will fail - confirm the test set always has a match.
x_pred, y_pred, position_x, position_y = find_sentence_ending_with_symbol(test_set, symbol="?", required_strokes=2)

print("Original input:")
svg_str = strokes_to_svg(x_pred, {'height':100, 'width':100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data = svg_str))


# Decode the target ids into text; "Ġ" is mapped back to a plain space.
gt = VOCAB.decode(y_pred.tolist()).replace(" ", '').replace("Ġ", " ")
# NOTE(review): the literal 1 is presumably the padding token id - confirm
# it matches PAD_IDX.
gt_list = [i for i in y_pred.tolist() if i != 1]
# NOTE(review): the -2 presumably discounts two non-text characters of the
# decoded string - confirm.
gt_length = len(gt)-2

print("Ground truth => ", gt_list , '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs
# (gt is already a str, so ''.join(gt) leaves it unchanged)
print(f"Ground Truth : {''.join(gt)} (len={gt_length})")

# replace_symbol edits the tensors in place and returns the same objects.
x_pred, y_pred = replace_symbol(x_pred, y_pred, position_x, position_y)

print("Ablated input:")
svg_str = strokes_to_svg(x_pred, {'height':100, 'width':100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data = svg_str))

# Predict on a batch of one; only the decoded text is kept here.
prediction, _, _ = model.predict(x_pred.unsqueeze(0).to(model.device))
prediction = prediction.replace(" ", '').replace("Ġ", " ")

# Same decoding as above, now for the ablated target.
gt = VOCAB.decode(y_pred.tolist()).replace(" ", '').replace("Ġ", " ")
gt_list = [i for i in y_pred.tolist() if i != 1]
gt_length = len(gt)-2

print("Ablated Indices => ", gt_list , '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs
print(f"Ablated new string: {''.join(gt)} (len={gt_length})")
print(f" -- New Prediction: {prediction} (len={len(prediction)-2})")

#print(f"Normalized Levenshtein distance is: {Levenshtein_Normalized_distance(gt, prediction)}")

In [None]:
# Run the ablation evaluation over the whole test set for the question mark,
# treated as a two-stroke symbol. This overwrites the TP/TN/FP/FN globals of
# the previous run.
SYMBOL="?"
TP, TN, FP, FN = evaluate_ablation_metrics(SYMBOL, req_strokes=2)

In [None]:
# Report the raw confusion counts of the last evaluation run.
print("True positives: ",TP)
print("True negatives: ",TN)
print("False positives: ",FP)
print("False negatives: ",FN)
print()
# Each ratio is reported as NaN when its denominator is zero (for example when
# the symbol never occurs in the test set) instead of raising
# ZeroDivisionError after a long evaluation.
n_evaluated = TP+TN+FP+FN
print("Accuracy: ", (TP+TN)/n_evaluated if n_evaluated else float("nan"))
print("Precision: ", TP/(TP+FP) if (TP+FP) else float("nan"))
print("Recall: ", TP/(TP+FN) if (TP+FN) else float("nan"))

## Misspell

In [None]:
from models.scripts.generate_dataset import WordGenerator

# Build a dataset from the file of misspelled French sentences.
misspelled_path = os.path.join(BASE_DIR, "word_sources", "misspell_test", "french_misspelled.txt")
words = WordGenerator().generate_from_file(misspelled_path, words_only=False)

# NOTE: this rebinds the notebook-global d_gen, which
# find_sentence_ending_with_symbol and replace_symbol also read.
# NOTE(review): train_split=0.0 and valid_split=0.0 presumably route every
# sentence to the test split - confirm against WordDatasetGenerator.
d_gen = WordDatasetGenerator(vocab = VOCAB,
                             expr_mode=EXPR_MODE,
                             words=words,
                             train_split=0.0,
                             valid_split=0.0,
                             extended_dataset=False,
                             fname="misspelled_fr")
d_gen.generate()
test = d_gen.generate_from_cache(mode="test")

In [None]:
# Ordered (non-shuffled) loader over the misspelled sentences.
# NOTE(review): the name says "en" although the data comes from
# french_misspelled.txt.
miss_en = DataLoader(preprocess_dataset(test, VOCAB,  os.path.join(d_gen.fname, "test.pt"), total_len=d_gen.get_learning_set_length("test")), batch_size=BATCH_SIZE, shuffle=False, collate_fn=pad_collate_fn)
# Keep the first batch only; the demo cell below samples from it.
x_pred, y_pred = next(iter(miss_en))

In [None]:
# Demo: pick a random sample of the first misspelled batch, predict on it and
# compare prediction, misspelled target and corrected sentence.
ind = random.choice(range(0, y_pred.shape[0]))
print("Index:", ind, "\n")

print("Ablated input:")
svg_str = strokes_to_svg(x_pred[ind], {'height':100, 'width':100}, d_gen.padding_value, BOS_IDX, EOS_IDX)
display(SVG(data = svg_str))


prediction, _ , _ = model.predict(x_pred[ind].unsqueeze(0).to(model.device))
prediction = prediction.replace(" ", '').replace("Ġ", " ")

# Decode the misspelled target; "Ġ" is mapped back to a plain space.
abl = VOCAB.decode(y_pred[ind].tolist()).replace(" ", '').replace("Ġ", " ")
# NOTE(review): the literal 1 is presumably the padding token id - confirm
# it matches PAD_IDX.
abl_list = [i for i in y_pred[ind].tolist() if i != 1]
abl_length = len(abl)-2

print("Ablated input indices => ", abl_list , '\n')

# Show ground truth and prediction along with the lengths of the words/glyphs
print(f"Ablated input: {''.join(abl)} (len={abl_length})")
# Read the corrected file up to line number `ind` (0-based).
# NOTE(review): this assumes sample `ind` of the first batch corresponds to
# line `ind` of the corrected file - confirm the dataset keeps file order.
# If the file has fewer lines, `line` silently keeps the last line read.
with open(os.path.join(BASE_DIR, "word_sources", "misspell_test", "french_corrected.txt"), "r", encoding="utf-8") as gtf:
    for i, line in enumerate(gtf):
        if i == ind:
            break

# The corrected sentence is cleaned, stripped of apostrophes and prefixed
# with a space before being compared.
gt = ' '+ WordGenerator().clean_sentence(line).replace("'","")
print(f"Corr sentence: {''.join(gt)} (len={len(gt)-1})")

print(f"- Prediction : {''.join(prediction)} (len={len(prediction)-2})")

print()
# NOTE(review): the labels say "LA" but the printed values are the raw
# results of Levenshtein_Normalized_distance - confirm which is intended.
print(f"- Gt - Ablated LA is: {Levenshtein_Normalized_distance(gt, abl)}")
print(f"Gt -Prediction LA is: {Levenshtein_Normalized_distance(gt, prediction)}")
print(f"Ablated - Pred LA is: {Levenshtein_Normalized_distance(abl, prediction)}")

In [None]:
def compute_correction_metrics(t_set, corr_path):
    """Average the three pairwise similarities over a misspelling test set.

    For every sample the model prediction, the decoded (misspelled) target and
    the corresponding line of the corrected file are compared pairwise with
    ``Levenshtein_Normalized_distance``. The means are kept as running
    averages whose divisor is the number of samples seen so far across ALL
    batches, so the result does not depend on the batch size.

    Reads the notebook globals ``model``, ``VOCAB`` and ``WordGenerator``.

    Args:
        t_set: iterable of ``(x_batch, y_batch)`` pairs, in the same order as
            the lines of the corrected file.
        corr_path: path of the UTF-8 text file holding one corrected sentence
            per line; one line is consumed per sample.

    Returns:
        ``(1 - mean d(gt, ablated), 1 - mean d(gt, prediction),
        1 - mean d(ablated, prediction))``.
    """
    gt_abl_LD = 0
    gt_pred_LD = 0
    abl_pred_LD = 0
    n_seen = 0  # samples processed so far, across all batches
    cleaner = WordGenerator()
    with open(corr_path, "r", encoding="utf-8") as gtf:
        for b_x, b_y in t_set:
            assert len(b_x) == len(b_y), "Mismatch in test dimensions"
            b_x = b_x.to(model.device)
            for x, y in zip(b_x, b_y):
                n_seen += 1
                prediction, _, _ = model.predict(x.unsqueeze(0))
                prediction = prediction.replace(" ", '').replace("Ġ", " ")
                ablation = VOCAB.decode(y.tolist()).replace(" ", '').replace("Ġ", " ")
                gt = ' ' + cleaner.clean_sentence(gtf.readline()).replace("'", "")

                gt_abl = Levenshtein_Normalized_distance(gt, ablation)
                gt_pred = Levenshtein_Normalized_distance(gt, prediction)
                abl_pred = Levenshtein_Normalized_distance(ablation, prediction)

                # Incremental mean: m_n = m_(n-1) + (value - m_(n-1)) / n
                gt_abl_LD += (gt_abl - gt_abl_LD) / n_seen
                gt_pred_LD += (gt_pred - gt_pred_LD) / n_seen
                abl_pred_LD += (abl_pred - abl_pred_LD) / n_seen

    return 1-gt_abl_LD, 1-gt_pred_LD, 1-abl_pred_LD

In [None]:
# Mean similarities (1 - mean normalized Levenshtein distance) over the
# misspelled set, compared against the corrected sentences file.
gt_abl_LA, gt_pred_LA, abl_pred_LA = compute_correction_metrics(miss_en, os.path.join(BASE_DIR, "word_sources", "misspell_test", "french_corrected.txt"))

In [None]:
import math

# Percentages are rounded AFTER scaling to percent, so binary floating-point
# noise (e.g. 56.769999999999996) does not leak into the report.
print(f"Similarity between ground truth and ablated input is:  {gt_abl_LA} ({round(gt_abl_LA*100, 2)}%).")
print(f"This means that we artificially added a degree of error in the db of {1-gt_abl_LA} ({round((1-gt_abl_LA)*100, 2)}%).")
print()
print(f"Model's predictions when input is ablated are on average accurate at {gt_pred_LA} ({round(gt_pred_LA*100, 2)}%).")
print(f"So the total error of the model is {1-gt_pred_LA} ({round((1-gt_pred_LA)*100, 2)}%).")
print()
print(f"Model's degree of similarity with the ablated input is {abl_pred_LA} ({round(abl_pred_LA*100, 2)}%).")
print()

# The three mean distances are treated as the sides of a triangle:
#   base  = ground truth <-> ablated input   (1 - gt_abl_LA)
#   side  = ground truth <-> prediction      (1 - gt_pred_LA)
#   side  = ablated input <-> prediction     (1 - abl_pred_LA)
# s is the semi-perimeter and Area follows from Heron's formula.
s = ((1-gt_pred_LA) +  (1-gt_abl_LA) + (1- abl_pred_LA))/2
# Averaged distances need not satisfy the triangle inequality, so the product
# can be slightly (or genuinely) negative; clamp it to keep sqrt defined.
Area = math.sqrt(max(0.0, s*(s-1+gt_abl_LA)*(s-1+gt_pred_LA)*(s-1+abl_pred_LA)))
# Height of the triangle over the base = error component not explained by the
# ablation. With a zero-length base the whole model error is its own.
if (1-gt_abl_LA) != 0:
    only_pred_error = 2*Area/(1-gt_abl_LA)
else:
    only_pred_error = 1-gt_pred_LA
# Projection of the ground-truth/prediction side onto the base; clamped
# against rounding noise for the same reason as above.
abl_and_pred_error = math.sqrt(max(0.0, (1-gt_pred_LA)**2-(only_pred_error**2)))
abl_non_pred_error = (1-gt_abl_LA)-abl_and_pred_error

if gt_pred_LA > gt_abl_LA:
    print(f"This means that on average the model learned to recognize incorrect input, reducing the total amount of error to {1-gt_pred_LA}. {gt_pred_LA-gt_abl_LA} lower than the one that we artificially introduced.")

else:
    print(f"The model is actually adding more error than the one it is able to recognize and correct.")
print()
print(f"{only_pred_error} ({round(only_pred_error*100, 2)}%) is the additional recognition error introduced by the model.")
print(f"{abl_and_pred_error} ({round(abl_and_pred_error*100, 2)}%) is the error artificially introduced through ablation and that was resembled by the model.")
print(f"{abl_non_pred_error} ({round(abl_non_pred_error*100, 2)}%) is the error that was artificially introduced but it was corrected by the model.")
