In [1]:
from transformers import AutoTokenizer, TFAutoModelForSeq2SeqLM
from transformers import AdamWeightDecay
import tensorflow as tf
from tensorflow.keras.preprocessing.sequence import pad_sequences
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import Counter
import re
import textwrap
import warnings
warnings.filterwarnings("ignore")

In [2]:
def read_data_gold():
    with open('Gold/Nescio_Titaantjes_gold.txt') as g1:
        data1_gold = g1.readlines()

    with open('Gold/ConanDoyle_SherlockHolmesDeAgraSchat_gold.txt') as g2:
        data2_gold = g2.readlines()

    data_gold = []
    for d1_gold in data1_gold:
        data_gold.append(d1_gold)
    for d2_gold in data2_gold:
        data_gold.append(d2_gold)
    return data_gold

In [3]:
def read_data_silver():
    with open('Silver/concept_data/Couperus_ElineVere.txt') as f1:
        data1 = f1.readlines()

    with open('Silver/concept_data/Hugo_DeEllendigen.txt') as f2:
        data2 = f2.readlines()

    with open('Silver/concept_data/Nescio_DeUitvreter.txt') as f3:
        data3 = f3.readlines()

    with open('Silver/concept_data/Nescio_Dichtertje.txt') as f4:
        data4 = f4.readlines()

    with open('Silver/concept_data/Tolstoy_AnnaKarenina.txt') as f5:
        data5 = f5.readlines()

    with open('Silver/concept_data/Multatuli_MaxHavelaar.txt') as f6:
        data6 = f6.readlines()

    with open('Silver/concept_data/Verne_ReisOmDeWereld.txt') as f7:
        data7 = f7.readlines()

    data_silver = []
    for d1 in data1:
        data_silver.append(d1)
    for d2 in data2:
        data_silver.append(d2)
    for d3 in data3:
        data_silver.append(d3)
    for d4 in data4:
        data_silver.append(d4)
    for d5 in data5:
        data_silver.append(d5)
    for d6 in data6:
        data_silver.append(d6)
    for d7 in data7:
        data_silver.append(d7)
    return data_silver

In [4]:
def create_data(data):
    sentence = []
    spel_old = []
    spel_new = []

    for x in data:
        spelling = re.findall(r'\[.*?\]', x)
        sent = re.sub(r'\[.*?\]', 'EMPTY', x)
        sentence.append(sent)
        for s in spelling:
            s = s.split()
            spel_old.append(s[3])
            spel_new.append(s[2])

    spelling_new = []
    for n in spel_new:
        n = n.split('~')
        if len(n) > 1:
            spelling_new.append(n[1])
        else:
            spelling_new.append(n[0])

    source_text = []
    target_text = []
    old_sent = []
    new_sent = []
    c = 0
    for sent in sentence:
        sent = sent.split()
        for word in sent:
            if word == 'EMPTY':
                old_sent.append(spel_old[c])
                new_sent.append(spelling_new[c])
                c += 1
            else:
                old_sent.append(word)
                new_sent.append(word)
        source_text.append(' '.join(old_sent))
        target_text.append(' '.join(new_sent))
        old_sent = []
        new_sent = []
    return source_text, target_text

In [5]:
def split_sent(data, max_length):
    short_sent = []
    long_sent = []
    for n in data:
        n = n.split('|')
        if len(n[1]) <= max_length:
            short_sent.append(n[1])
        elif len(n[1]) > max_length:
            n[1] = re.sub(r'(\s)+(?=[^[]*?\])', 'EMPTY', n[1])
            lines = textwrap.wrap(n[1], max_length, break_long_words=False)
            long_sent.append(lines)
    
    new_data = []
    for s in long_sent:
        for s1 in s:
            s1 = re.sub(r'(EMPTY)+(?=[^[]*?\])', ' ', s1)
            new_data.append(s1)
            
    for x in short_sent:
        new_data.append(x)
    return new_data

In [6]:
def spell_mistakes(source, target):
    spell_mis = []
    for x, y in zip(source, target):
        x = x.rstrip()
        y = y.rstrip()
        x = x.split()
        y = y.split()
        for x1, y1 in zip(x, y):
            if x1 != y1:
                spell_mis.append(x1 + ' ' + y1)
    return spell_mis

In [7]:
data_gold = read_data_gold()
data_silver = read_data_silver()

In [8]:
s_data = split_sent(data_silver, 150)
g_data = split_sent(data_gold, 150)

In [9]:
dev, test = train_test_split(g_data, test_size=0.3, random_state=10)

In [10]:
source_train, target_train = create_data(s_data)
source_dev, target_dev = create_data(dev)
source_test, target_test = create_data(test)

In [11]:
train_spell_mis = spell_mistakes(source_train, target_train)
train_spell_mis = Counter(train_spell_mis).most_common()

test_spell_mis = spell_mistakes(source_test, target_test)
test_spell_mis = Counter(test_spell_mis).most_common()

train_mis = []
for x in train_spell_mis:
    x = x[0].split()
    train_mis.append(x[0])

unseen_mistakes = []
for mis1 in test_spell_mis:
    m1 = mis1[0].split()
    if m1[0] not in train_mis:
        unseen_mistakes.append(m1)

In [12]:
print('train size: {}'.format(len(s_data)))
print('dev size: {}'.format(len(dev)))
print('test size: {}'.format(len(test)))

train size: 5802
dev size: 1126
test size: 483


In [19]:
tokenizer1 = AutoTokenizer.from_pretrained("google/byt5-small")
model1 = TFAutoModelForSeq2SeqLM.from_pretrained("google/byt5-small")
model1.load_weights('byt5_weights.h5')

tokenizer2 = AutoTokenizer.from_pretrained("google/flan-t5-small")
model2 = TFAutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
model2.load_weights('byt5_weights1.h5')

All model checkpoint layers were used when initializing TFT5ForConditionalGeneration.

All the layers of TFT5ForConditionalGeneration were initialized from the model checkpoint at google/byt5-small.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFT5ForConditionalGeneration for predictions without further training.
Some layers from the model checkpoint at google/flan-t5-small were not used when initializing TFT5ForConditionalGeneration: ['shared/embeddings:0']
- This IS expected if you are initializing TFT5ForConditionalGeneration from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFT5ForConditionalGeneration from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification

In [21]:
def model_predict(test_set, model, tok):
    tokenized = tok(test_set, max_length=155, padding=True, return_tensors='tf')
    out = model.generate(**tokenized, max_length=155)

    pred = []
    for n in out:
        pred.append(tok.decode(n, text_target=True, skip_special_tokens=True))

    return pred

In [22]:
print('old spelled words found in the test set,\nwhich are not in the train set:\n')
for mis in unseen_mistakes:
    print('Orig: {}\nPred: {}\nGold: {}\n'.format(mis[0], ' '.join(model_predict(mis[0], model1, tokenizer1)), mis[1]))

old spelled words found in the test set,
which are not in the train set:

Orig: Neen
Pred: Neen
Gold: Nee

Orig: te
Pred: te
Gold: tezamen

Orig: Indischen
Pred: Indischen
Gold: Indische

Orig: Oostersche
Pred: Oosterse
Gold: Oosterse

Orig: genoegelijk
Pred: genoegelijk
Gold: genoeglijk

Orig: ongelukkigen
Pred: ongelukkigen
Gold: ongelukkige

Orig: uwen
Pred: uwen
Gold: uw

Orig: noodigde
Pred: nodigde
Gold: nodigde

Orig: aantoonen
Pred: aantonen
Gold: aantonen

Orig: teekeningetjes
Pred: tekeningetjes
Gold: tekeningetjes

Orig: luchtkasteelen
Pred: luchtkastelen
Gold: luchtkastelen

Orig: vroolijken
Pred: vrolijken
Gold: vrolijke

Orig: effe
Pred: fe
Gold: effen

Orig: onderhuidsche
Pred: onderhuidse
Gold: onderhuidse

Orig: rood-steenen
Pred: rood-stenen
Gold: rood-stenen

Orig: vieren
Pred: vieren
Gold: vier

Orig: er
Pred: er
Gold: ervoor

Orig: wijs
Pred: wijs
Gold: wijze

Orig: bedaard
Pred: bedaard
Gold: bedaarde

Orig: recommandeeren
Pred: recommanderen
Gold: recommanderen



In [23]:
print('old spelled Words found in the test set,\nwhich are not in the train set:\n')
for mis in unseen_mistakes:
    print('Orig: {}\nPred: {}\nGold: {}\n'.format(mis[0], ' '.join(model_predict(mis[0], model2, tokenizer2)), mis[1]))

old spelled Words found in the test set,
which are not in the train set:

Orig: Neen
Pred: Neen
Gold: Nee

Orig: te
Pred: te
Gold: tezamen

Orig: Indischen
Pred: Indien
Gold: Indische

Orig: Oostersche
Pred: Oosterse
Gold: Oosterse

Orig: genoegelijk
Pred: genoegelijk
Gold: genoeglijk

Orig: ongelukkigen
Pred: ongelukkigen
Gold: ongelukkige

Orig: uwen
Pred: uwen
Gold: uw

Orig: noodigde
Pred: nodigde
Gold: nodigde

Orig: aantoonen
Pred: aantoonen
Gold: aantonen

Orig: teekeningetjes
Pred: teekeningetjes
Gold: tekeningetjes

Orig: luchtkasteelen
Pred: luchtkasteelen
Gold: luchtkastelen

Orig: vroolijken
Pred: vroeg
Gold: vrolijke

Orig: effe
Pred: Ve
Gold: effen

Orig: onderhuidsche
Pred: onderhuidsche
Gold: onderhuidse

Orig: rood-steenen
Pred: rood-steenen
Gold: rood-stenen

Orig: vieren
Pred: vier
Gold: vier

Orig: er
Pred: er
Gold: ervoor

Orig: wijs
Pred: wijs
Gold: wijze

Orig: bedaard
Pred: bedaard
Gold: bedaarde

Orig: recommandeeren
Pred: reuseeren
Gold: recommanderen

Orig: P

In [24]:
def check_occ_old(word):
    for train in train_spell_mis:
        t = train[0].split()
        if word in t[0]:
            print('Found in train: {}'.format(t[0]))
        
    for test in test_spell_mis:
        t1 = test[0].split()
        if word in t1[0]:
            print('Found in test: {}'.format(t1[0]))

In [27]:
check_occ_old('luchtkasteelen')

Found in test: luchtkasteelen
