In [122]:
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from transformers import AdamWeightDecay
from tensorflow.keras.preprocessing.sequence import pad_sequences
import re
from tqdm import tqdm
import textwrap
from difflib import Differ
from sklearn.metrics import precision_recall_fscore_support

In [123]:
with open('silver_test/Multatuli_MaxHavelaar_silver.txt', 'r') as s1:
    silver_data1 = s1.readlines()
    
with open('silver_test/Nescio_Titaantjes_silver.txt', 'r') as s2:
    silver_data2 = s2.readlines()
    
with open('silver_test/ConanDoyle_SherlockHolmesDeAgraSchat_silver.txt', 'r') as s3:
    silver_data3 = s3.readlines()
    
with open('new_gold/Multatuli_MaxHavelaar_gold.txt', 'r') as g1:
    gold_data1 = g1.readlines()
    
with open('new_gold/Nescio_Titaantjes_gold.txt', 'r') as g2:
    gold_data2 = g2.readlines()
    
with open('new_gold/ConanDoyle_SherlockHolmesDeAgraSchat_gold.txt', 'r') as g3:
    gold_data3 = g3.readlines()

In [124]:
def model_predict(test_set, model, tok):
    tokenized = tok(test_set, max_length=155, padding=True, truncation=True, return_tensors='tf')
    out = model.generate(**tokenized, max_length=250)

    pred = []
    for n in out:
        pred.append(tok.decode(n, text_target=True, skip_special_tokens=True))
    return pred

In [125]:
def split_sent(data, max_length):
    short_sent = []
    long_sent = []
    for n in data:
        n = n.split('|')
        if len(n[1]) <= max_length:
            short_sent.append(n[1])
        elif len(n[1]) > max_length:
            n[1] = re.sub(r'(\s)+(?=[^[]*?\])', '@@', n[1])
            n[1] = n[1].replace("] [", "]##[")
            lines = textwrap.wrap(n[1], max_length, break_long_words=False)
            long_sent.append(lines)

    new_data = []
    for s in long_sent:
        for s1 in s:
            s1 = s1.replace(']##[', '] [')
            s1 = re.sub(r'(@@)+(?=[^[]*?\])', ' ', s1)
            new_data.append(s1)

    for x in short_sent:
        new_data.append(x)
    return new_data

In [158]:
def create_data(data):
    source_text = []
    target_text = []
    for d in data:
        d = d.split('|')
        x = d[1]
        source = []
        target = []
        spel = re.findall(r'\[.*?\]', x)
        if spel:
            for s in spel:
                s = s.split()
                if s[1] == '@alt':
                    target.append(''.join(s[2:3]))
                    source.append(''.join(s[3:-1]))
                elif s[1] == '@mwu_alt':
                    target.append(''.join(s[2:3]))
                    source.append(''.join(s[3:-1]).replace('-', ''))
                elif s[1] == '@mwu':
                    target.append(''.join(s[2:-1]))
                    source.append(' '.join(s[2:-1]))
                elif s[1] == '@postag':
                    target.append(''.join(s[-2]))
                    source.append(''.join(s[-2]))
                elif s[1] == '@phantom':
                    target.append(''.join(s[2]))
                    source.append('')

        target2 = []
        for t in target:
            if t[0] == '~':
                t = t.split('~')
                target2.append(t[1])
            else:
                target2.append(t)

        sent = re.sub(r'\[.*?\]', 'EMPTY', x)
        word_c = 0
        src = []
        trg = []
        for word in sent.split():
            if word == 'EMPTY':
                src.append(source[word_c])
                trg.append(target2[word_c])
                word_c += 1
            else:
                src.append(word)
                trg.append(word)
        source_text.append(' '.join(src))
        target_text.append(' '.join(trg))
    return source_text, target_text

In [141]:
def align_sent(pred, gold):
    predv1 = []
    goldv1 = []

    d = Differ()
    diff = list(d.compare(pred, gold))
    for l, n in enumerate(diff):
        n = n.strip()
        if n[0] == '-' and len(n) > 1:
            predv1.append(n)
        elif n[0] == '+':
            goldv1.append(n)
        elif n[0] != '?':
            goldv1.append(n)
            predv1.append(n)
    
    for lp, p in enumerate(predv1):
        p = p.strip()
        next_p = predv1[(lp+1)%(len(predv1))]
        next_pp = predv1[(lp+2)%(len(predv1))]
        if p[0] == '-' and next_p[0] == '-' and next_pp[0] == '-':
            predv1.remove(p)
            predv1.remove(next_p)
            predv1.remove(next_pp)
            predv1.insert(lp, p + ' ' + next_p + ' ' + next_pp)
                
    for lp, p in enumerate(predv1):
        p = p.strip()
        next_p = predv1[(lp+1)%(len(predv1))]
        if p[0] == '-' and next_p[0] == '-':
            predv1.remove(p)
            predv1.remove(next_p)
            predv1.insert(lp, p + ' ' + next_p)
        
    for lg, g in enumerate(goldv1):
        g = g.strip()
        next_g = goldv1[(lg+1)%(len(goldv1))]
        next_gg = goldv1[(lg+2)%(len(goldv1))]
        if g[0] == '+' and next_g[0] == '+' and next_gg[0] == '+':
            goldv1.remove(g)
            goldv1.remove(next_g)
            goldv1.remove(next_gg)
            goldv1.insert(lg, g + ' ' + next_g + ' ' + next_gg)
            
    for lg, g in enumerate(goldv1):
        g = g.strip()
        next_g = goldv1[(lg+1)%(len(goldv1))]
        if g[0] == '+' and next_g[0] == '+':
            goldv1.remove(g)
            goldv1.remove(next_g)
            goldv1.insert(lg, g + ' ' + next_g)
            
    predv2 = []
    for pred1 in predv1:
        pred1 = pred1.replace('-', ' ').strip()
        predv2.append(pred1)
        
    goldv2 = []
    for gold1 in goldv1:
        gold1 = gold1.replace('+', ' ').strip()
        goldv2.append(gold1)
            
    return predv2, goldv2

In [153]:
def evaluate_err(raw, gold, pred):
    cor = 0
    changed = 0
    total = 0

    if len(gold) != len(pred):
        return 'Error: gold normalization contains a different numer of sentences(' + str(len(gold)) + ') compared to system output(' + str(len(pred)) + ')'
       
    for sentRaw, sentGold, sentPred in zip(raw, gold, pred):
        if len(sentGold) != len(sentPred):
            return 'Error: a sentence has a different length in you output, check the order of the sentences'
        for wordRaw, wordGold, wordPred in zip(sentRaw, sentGold, sentPred):
            if wordRaw != wordGold:
                changed += 1
            if wordGold == wordPred:
                cor += 1
            total += 1

    accuracy = float(cor) / total
    lai = float(total - changed) / total
    err = float(accuracy - lai) / (1-lai)
    return 'Baseline Accuracy: {:.2f}%\nAccuracy: {:.2f}%\nError Reduction Rate: {:.2f}%'.format((lai * 100), (accuracy * 100), (err * 100)) 

In [154]:
def evaluate_pre_rec(gold, pred):
    precision = []
    recall = []
    if len(gold) != len(pred):
        return 'Error: gold normalization contains a different numer of sentences(' + str(len(gold)) + ') compared to system output(' + str(len(pred)) + ')'
       
    for sentGold, sentPred in zip(gold, pred):
        if len(sentGold) == len(sentPred):
            if len(sentGold) > 0:
                pre_rec = precision_recall_fscore_support(sentGold, sentPred, average='macro', zero_division=True)
                precision.append(pre_rec[0])
                recall.append(pre_rec[1])
    
    avg_pre = round(sum(precision) / len(precision), 4)
    avg_rec = round(sum(recall) / len(recall), 4)
    return 'Avg Precision: {:.2f}%\nAvg Recall: {:.2f}%'.format((avg_pre * 100), (avg_rec * 100)) 

In [151]:
def evaluate_rulebased(gold_data, silver_data, verbose=False):
    source_gold, target_gold = create_data(gold_data)
    source_silver, target_silver = create_data(silver_data)
    
    pred = []
    gold = []
    raw = []
    for r, p, g in zip(source_gold, target_silver, target_gold):
        r1 = r.split()
        p1 = p.split() 
        g1 = g.split() 
        if len(p1) == len(g1) and len(r1) == len(g1):
            raw.append(r1)
            pred.append(p1)
            gold.append(g1)
        elif len(p1) != len(g1):
            p1, g1 = align_sent(p1, g1)
            if len(p1) == len(g1):
                raw.append(r1)
                pred.append(p1)
                gold.append(g1)
            else:
                if verbose:
                    print('ORIG:\t', len(r1), r1)
                    print('PRED:\t', len(p1), p1)
                    print('GOLD:\t', len(g1), g1)
                    print('\n')  
    print(evaluate_err(raw, gold, pred))
    print(evaluate_pre_rec(gold, pred))

In [152]:
def evaluate_T5(original, predictions, gold_set, verbose=False):
    test_pred = []
    test_gold = []
    test_orig = []
    for orig, pred, gold in zip(original, predictions, gold_set):
        pred = pred.replace('!', ' !').replace('?', ' ?')
        pred = re.sub(r'(?<=[a-zA-Z])([,])', ' ,', pred)
        gold = gold.replace(" 'm", "'m").replace(" 's", "'s")
        orig = orig.replace(" 'm", "'m").replace(" 's", "'s")
        
        pred = pred.split()
        gold = gold.split()
        orig = orig.split()
        
        if len(orig) != len(gold):
            orig, gold = align_sent(orig, gold)
    
        word_exc = ['No.', 'enz.', 'Mr.', 'S.', 'D.', 'A.', 'P.', 'I.', 'X.']
        predv1 = []
        for p in pred:
            p = p.strip()
            if p not in word_exc:
                if p[-1] == '.' and p[-2:-1] != '.':
                    p = p[:-1] + ' ' + '.'
                    predv1.append(p)
                elif p[-1] == '.' and p[-2:-1] == '.' and p[-3:-2] == '.':
                    p = p[:-3] + ' ' + '...'
                    predv1.append(p)
                elif p[-1] == '.' and p[-2:-1] == '.':
                    p = p[:-2] + ' ' + '..'
                    predv1.append(p)
                else:
                    predv1.append(p)
            else:
                predv1.append(p)
            
        predv2 = ' '.join(predv1).split()
    
        if len(predv2) == len(gold) and len(orig) == len(gold):
            test_pred.append(predv2)
            test_gold.append(gold)
            test_orig.append(orig)
        elif len(predv2) != len(gold):
            if len(gold) > 1 and len(predv2) > 1:
                predv2, gold = align_sent(predv2, gold)
                if len(predv2) == len(gold):
                    test_pred.append(predv2)
                    test_gold.append(gold)
                    test_orig.append(orig)
                else:
                    if verbose:
                        print(len(orig), orig)
                        print(len(predv2), predv2)
                        print(len(gold), gold)
                        print('\n')
            else: 
                if verbose:
                    print(len(orig), orig)
                    print(len(predv2), predv2)
                    print(len(gold), gold)
                    print('\n')
                    
    print('{}/{} ({:.2f}%) sentences are evaluated from prediction set'
          .format(len(test_pred), len(predictions), (len(test_pred) / len(predictions) * 100)))
    print(evaluate_err(test_orig, test_gold, test_pred))
    print(evaluate_pre_rec(test_gold, test_pred))

## Testing ByT5

In [None]:
tokenizer = AutoTokenizer.from_pretrained("google/byt5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")
model.load_weights('ByT5_small_weights.h5')

In [71]:
print('Multatuli MaxHavelaar (ByT5)')
gold1_split = split_sent(gold_data1, 150)
source_data1, target_data1 = create_data(gold1_split) 
predictions1 = model_predict(source_data1, model, tokenizer)
evaluate_T5(source_data1, predictions1, target_data1, verbose=False)

Multatuli MaxHavelaar (ByT5)
755/764 (98.82%) sentences are evaluated from prediction set
Baseline Accuracy: 96.02%
Accuracy: 98.67%
Error Reduction Rate: 66.50%
Avg Precision: 98.47%
Avg Recall: 98.47%


In [67]:
print('Nescio Titaantjes (ByT5)')
gold2_split = split_sent(gold_data2, 150)
source_data2, target_data2 = create_data(gold2_split) 
predictions2 = model_predict(source_data2, model, tokenizer)
evaluate_T5(source_data2, predictions2, target_data2, verbose=False)

Nescio Titaantjes (ByT5)
878/883 (99.43%) sentences are evaluated from prediction set
Baseline Accuracy: 95.87%
Accuracy: 99.11%
Error Reduction Rate: 78.38%
Avg Precision: 98.87%
Avg Recall: 98.87%


In [65]:
print('ConanDoyle Sherlock Holmes De Agra Schat (ByT5)')
gold3_split = split_sent(gold_data3, 150)
source_data3, target_data3 = create_data(gold3_split) 
predictions3 = model_predict(source_data3, model, tokenizer)
evaluate_T5(source_data3, predictions3, target_data3, verbose=False)

ConanDoyle Sherlock Holmes De Agra Schat (ByT5)
688/690 (99.71%) sentences are evaluated from prediction set
Baseline Accuracy: 94.29%
Accuracy: 98.35%
Error Reduction Rate: 71.08%
Avg Precision: 98.22%
Avg Recall: 98.24%


## Testing Flan-T5

In [None]:
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = TFAutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model.load_weights('FlanT5_small_weights.h5')

In [24]:
print('Multatuli MaxHavelaar (ByT5)')
gold1_split = split_sent(gold_data1, 150)
source_data1, target_data1 = create_data(gold1_split) 
predictions1 = model_predict(source_data1, model, tokenizer)
evaluate_T5(source_data1, predictions1, target_data1, verbose=False)

Multatuli MaxHavelaar (ByT5)
762/764 (99.74%) sentences are evaluated from prediction set
Baseline Accuracy: 96.00%
Accuracy: 97.96%
Error Reduction Rate: 48.87%
Avg Precision: 97.91%
Avg Recall: 97.91%


In [22]:
print('Nescio Titaantjes (ByT5)')
gold2_split = split_sent(gold_data2, 150)
source_data2, target_data2 = create_data(gold2_split) 
predictions2 = model_predict(source_data2, model, tokenizer)
evaluate_T5(source_data2, predictions2, target_data2, verbose=False)

Nescio Titaantjes (ByT5)
872/883 (98.75%) sentences are evaluated from prediction set
Baseline Accuracy: 95.62%
Accuracy: 97.93%
Error Reduction Rate: 52.77%
Avg Precision: 98.04%
Avg Recall: 98.05%


In [23]:
print('ConanDoyle Sherlock Holmes De Agra Schat (ByT5)')
gold3_split = split_sent(gold_data3, 150)
source_data3, target_data3 = create_data(gold3_split) 
predictions3 = model_predict(source_data3, model, tokenizer)
evaluate_T5(source_data3, predictions3, target_data3, verbose=False)

ConanDoyle Sherlock Holmes De Agra Schat (ByT5)
686/690 (99.42%) sentences are evaluated from prediction set
Baseline Accuracy: 94.31%
Accuracy: 97.10%
Error Reduction Rate: 48.93%
Avg Precision: 96.87%
Avg Recall: 96.89%


## Testing Rule-Based

In [155]:
print('Multatuli MaxHavelaar (Rule-Based)')
evaluate_rulebased(gold_data1, silver_data1, verbose=False)

Multatuli MaxHavelaar (Rule-Based)
Baseline Accuracy: 96.04%
Accuracy: 98.81%
Error Reduction Rate: 69.87%
Avg Precision: 98.93%
Avg Recall: 98.92%


In [156]:
print('Nescio Titaantjes (Rule-Based)')
evaluate_rulebased(gold_data2, silver_data2, verbose=False)

Nescio Titaantjes (Rule-Based)
Baseline Accuracy: 96.30%
Accuracy: 98.68%
Error Reduction Rate: 64.45%
Avg Precision: 98.88%
Avg Recall: 98.87%


In [157]:
print('ConanDoyle Sherlock Holmes De Agra Schat (Rule-Based)')
evaluate_rulebased(gold_data3, silver_data3, verbose=False)

ConanDoyle Sherlock Holmes De Agra Schat (Rule-Based)
Baseline Accuracy: 94.83%
Accuracy: 98.38%
Error Reduction Rate: 68.60%
Avg Precision: 98.32%
Avg Recall: 98.35%
