# Preliminary steps

## Get the libraries

In [None]:
!pip install transformers==4.27.4

In [None]:
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"
import random
import re
from multiprocessing import Pool

import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import AutoTokenizer

## Get the tokeniser

In [None]:
# Tokenizer identified by "pstroe/roberta-base-latin-cased"; used further down to measure
# the token length of every line.
roberta_tokenizer = AutoTokenizer.from_pretrained("pstroe/roberta-base-latin-cased")

# Prepare the training data

## Import the list of errors typical for your transcription model

In [None]:
# Error table; the cells below read its "context", "correct", "error" and "count" columns.
df_error = pd.read_csv("data/errors_final.csv")

In [None]:
# Display the row at integer position 625 of the error table.
df_error.iloc[625]

## Aggregate it

In [None]:
# Flag the rows whose "correct" / "error" cell is missing.
for source_col, flag_col in (("correct", "is_na_correct"), ("error", "is_na_error")):
    df_error[flag_col] = df_error[source_col].isna()

In [None]:
def get_regexp(x):
    """Build the correct-text pattern for one row of the error table.

    Reads ``context``, ``correct``, ``is_na_error`` and ``is_na_correct`` from ``x``:

    * context equal to ``"––"``: the ``correct`` value is returned unchanged;
    * ``is_na_error`` set: the first ``[...]`` span of the context is replaced by ``correct``;
    * ``is_na_correct`` set: the first ``_`` is removed from the context;
    * otherwise ``None`` is returned.

    Raises ``ValueError`` when the expected ``[``, ``]`` or ``_`` marker is absent.
    """
    context = x["context"]
    if context == "––":
        return x["correct"]

    if x["is_na_error"]:
        open_idx = context.index("[")
        close_idx = context.index("]", open_idx)
        head, tail = context[:open_idx], context[close_idx + 1 :]
        return "".join((str(head), str(x["correct"]), str(tail)))

    if x["is_na_correct"]:
        gap_idx = context.index("_")
        return context[:gap_idx] + context[gap_idx + 1 :]

    return None


def get_error_with_contecst(x):
    """Build the erroneous counterpart of the text produced by ``get_regexp`` for one row.

    Reads ``context``, ``error``, ``is_na_error`` and ``is_na_correct`` from ``x``:

    * context equal to ``"––"``: the ``error`` value is returned unchanged;
    * ``is_na_error`` set: the first ``[...]`` span is removed from the context;
    * ``is_na_correct`` set: the first ``_`` of the context is replaced by ``error``;
    * otherwise ``None`` is returned.

    Raises ``ValueError`` when the expected ``[``, ``]`` or ``_`` marker is absent.
    """
    context = x["context"]
    if context == "––":
        return x["error"]

    if x["is_na_error"]:
        open_idx = context.index("[")
        close_idx = context.index("]", open_idx)
        return "".join((str(context[:open_idx]), str(context[close_idx + 1 :])))

    if x["is_na_correct"]:
        gap_idx = context.index("_")
        return context[:gap_idx] + x["error"] + context[gap_idx + 1 :]

    return None


def get_count_error_symbol(x):
    """Return the length of ``error`` when ``is_na_correct`` is set, else the length of ``correct``."""
    measured = x["error"] if x["is_na_correct"] else x["correct"]
    return len(measured)


# Row-wise derivation of the two text columns used later on: "regexp" is the pattern that is
# compiled and searched for, "context_with_error" is the text substituted for a match.
df_error["regexp"] = df_error.apply(get_regexp, axis=1)
df_error["context_with_error"] = df_error.apply(get_error_with_contecst, axis=1)

In [None]:
# A bare "?" is a regex quantifier and does not compile on its own, so it is stored escaped.
# The raw string holds the same two characters as before, without the invalid-escape warning
# that "\?" triggers in a normal string literal.
df_error["regexp"] = df_error["regexp"].apply(lambda x: r"\?" if x == "?" else x)

In [None]:
# Characters affected per row: number of occurrences times the length returned by
# get_count_error_symbol (len of "error" when "correct" is missing, else len of "correct").
df_error["count_symbol"] = df_error.apply(
    lambda row: row["count"] * get_count_error_symbol(row),
    axis=1,
)

In [None]:
# Display the error table with the derived columns added above.
df_error

## Obtain information about the errors

### Stats calculation to evaluate the error_rate

In [None]:
# NOTE(review): `df` is not defined by any cell visible in this notebook; it must provide a
# "target_text" column of strings before this cell can run — confirm where it is loaded.
text = " ".join(df["target_text"].tolist())

# Count how often every pattern occurs in the joined text.
# The pattern is handed to `re` as is: the former `repr(ex).strip("'")` doubled every
# backslash, so the escaped "\?" turned into the regex `\\?` (an optional backslash), which
# matches the empty string at every position. str() keeps non-string cells (e.g. NaN)
# from raising, as before.
# NOTE(review): other metacharacters such as "." are still interpreted as regex syntax.
unique_regexps = df_error["regexp"].unique()
res_count_in = {}
for ex in tqdm(unique_regexps, total=len(unique_regexps)):
    count_in = len(re.findall(str(ex), text))
    res_count_in[ex] = count_in

In [None]:
# Sum of the "count" column over all rows that share the same pattern.
unique_patterns = df_error["regexp"].unique()
res_count_error_in = {
    pattern: df_error.loc[df_error["regexp"] == pattern, "count"].sum()
    for pattern in tqdm(unique_patterns, total=len(unique_patterns))
}

In [None]:
# Display every row whose pattern is exactly " ceato" (with the leading space).
df_error[df_error["regexp"] == " ceato"]

### Calculate the error_rate

In [None]:
def _capped_error_rate(row):
    """Error count divided by occurrence count of the row's pattern, capped at 0.3.

    Patterns with zero occurrences get 0.3 as well.
    """
    occurrences = res_count_in[row["regexp"]]
    if occurrences == 0:
        return 0.3
    return min(res_count_error_in[row["regexp"]] / occurrences, 0.3)


df_error["error_rate"] = df_error.apply(_capped_error_rate, axis=1)

### Compiling the regexps

In [None]:
# Compile the pattern string directly. The former `repr(x).strip("'")` doubled every backslash,
# so the escaped "\?" compiled to `\\?` (an optional backslash) and matched the empty string at
# every position. str() keeps non-string cells (e.g. NaN) from raising, as before.
# NOTE(review): other metacharacters such as "." are still interpreted as regex syntax.
df_error["comp_regexp"] = df_error["regexp"].apply(lambda x: re.compile(str(x)))

In [None]:
# Display the error table, now including the "error_rate" and "comp_regexp" columns.
df_error

### Get an error bank

In [None]:
# One entry per distinct pattern:
# [pattern string, compiled regex, list of replacement texts, probability of each replacement].
error_bank = []
uniq_cont = df_error["regexp"].unique()
for pattern in tqdm(uniq_cont, total=len(uniq_cont)):
    rows = df_error[df_error["regexp"] == pattern]
    counts = np.array(rows["count"].tolist())
    # Share of each replacement among the rows of this pattern, scaled by the pattern's error rate.
    share = counts / rows["count"].sum()
    prob = share * rows.iloc[0]["error_rate"]
    entry = [
        pattern,
        rows["comp_regexp"].iloc[0],
        rows["context_with_error"].tolist(),
        prob,
    ]
    error_bank.append(entry)

## Get all the separate xmls and parse them

In [None]:
from pathlib import Path

# Collect the XML files of the corpus. Directory listing order is not guaranteed, so the
# paths are sorted to make the order of the extracted lines reproducible between runs.
pathlist = sorted(Path("data/PL").glob("*.xml"))
print(len(pathlist))

In [None]:
from bs4 import BeautifulSoup

# Filled across all files: one entry per paragraph, or several when a long paragraph is cut.
res_new_str = []

# A single progress bar over the files (iterating tqdm(path_t) drew a second, nested bar).
path_t = tqdm(pathlist, total=len(pathlist))

for path in path_t:
    # Binary mode hands the raw bytes to the parser, which can then decode them according to
    # the XML declaration instead of the platform's default text encoding.
    with open(path, "rb") as fp:
        soup = BeautifulSoup(fp, "xml")

    body = soup.find("body")
    if body is None:
        print(f"There is no <body> in file: {path}")
        continue

    paragraphs = body.find_all("p")  # get all the paragraphs
    if not paragraphs:
        print(f"There is no <p> in file: {path}")
        continue

    print(f"I found {len(paragraphs)} paragraphs in file: {path}")

    for tag in paragraphs:
        text = tag.text.strip()
        original_text = text  # kept for the diagnostic message below
        text = text.replace("\n", " ")
        # Keep ASCII letters, digits, spaces and the marks . : ! only, then collapse space runs.
        text = re.sub(r"[^a-zA-Z0-9 .:!]", "", text)
        text = re.sub(r"  +", " ", text)

        if not text:  # nothing survived the filtering
            print(f"Text filtered in: {path}")
            print(f"Original text: {original_text}")
            continue

        l_t = len(text)
        if l_t > 500:
            # Smallest divisor i >= 2 that brings the target piece length down to 500 or less.
            i = 2
            while l_t // i > 500:
                i += 1
            len_msg = l_t // i
            prev_pos = 0
            while l_t - prev_pos > 500:
                # Prefer cutting after a full stop at or beyond the target length ...
                cur_pos = text.find(".", prev_pos + len_msg)
                if cur_pos == -1 or cur_pos - prev_pos > 700:
                    # ... and fall back to the next space when there is no full stop or the
                    # piece would exceed 700 characters.
                    cur_pos = text.find(" ", prev_pos + len_msg)
                    if cur_pos == -1:
                        # No cut point left: stop here and keep the remainder as one piece
                        # (it used to be dropped and replaced by an empty string).
                        print("Text slicing error:", path)
                        break
                res_new_str.append(text[prev_pos : cur_pos + 1])
                prev_pos = cur_pos + 1
            # Skip the tail when the last cut fell on the final character, so that no empty
            # string ends up in the corpus.
            if prev_pos < l_t:
                res_new_str.append(text[prev_pos:])
        else:
            res_new_str.append(text)

print("Number of lines after processing all files:", len(res_new_str))

In [None]:
# Report the total number of extracted lines again, in a cell of its own.
print("Number of lines after processing the whole corpus:", len(res_new_str))

In [None]:
# Debug view of two globals left over from the parsing loop: the last cut position and the
# length of the last processed paragraph.
# NOTE(review): `prev_pos` only exists if at least one paragraph was longer than 500 characters.
prev_pos, len(text)

## Plot texts length distribution over the corpus

In [None]:
import matplotlib.pyplot as plt

# Character length of every extracted line.
dd = [len(s) for s in res_new_str]

# Labelled histogram so the figure can be read without the surrounding code.
fig, ax = plt.subplots()
ax.hist(dd, bins=100)
ax.set(
    title="Distribution of line lengths in the corpus",
    xlabel="Line length (characters)",
    ylabel="Number of lines",
);

## Inject the errors

### Function that replaces some pieces of text according to the error probability

In [None]:
def get_text_with_error(x):
    """Inject errors from an error bank into a batch of correct lines.

    Parameters
    ----------
    x : sequence of two items
        ``(correct_text, error_bank)``. ``correct_text`` is an iterable of strings. Every
        ``error_bank`` entry is read as ``entry[1]`` (compiled regex), ``entry[2]`` (list of
        replacement strings) and ``entry[3]`` (NumPy array with one probability per
        replacement).

    Returns
    -------
    list of str
        One line per input line, in the same order, with some matches replaced.

    Notes
    -----
    For every match one outcome is drawn: one of the replacements, or "leave unchanged".
    All matches of a pattern are located on the line as it is before that pattern is applied,
    and the line is rebuilt from left to right, so a replacement whose length differs from
    the matched text cannot shift the position of the following matches.
    """
    correct_text, error_bank = x
    coef = 3.0  # multiplier applied to every stored replacement probability

    incorrect_text = []
    for cor_txt in correct_text:
        err_txt = cor_txt
        # Visit the error patterns in a fresh random order for every line.
        bank_order = list(range(len(error_bank)))
        random.shuffle(bank_order)
        for bank_idx in bank_order:
            el = error_bank[bank_idx]

            # Replacement probabilities: negatives are clipped to 0 and the vector is scaled
            # down if it sums to more than 1; the last slot is the "leave unchanged" outcome.
            pr = [p if p >= 0 else 0 for p in el[3] * coef]
            total = sum(pr)
            if total > 1:
                pr = [p / total for p in pr]
            pr.append(max(0.0, 1 - sum(pr)))

            pieces = []
            cursor = 0
            for mt in el[1].finditer(err_txt):
                ind_max = np.random.multinomial(1, pr, size=1)
                if ind_max[0][-1] == 1:
                    continue  # "leave unchanged" was drawn
                pieces.append(err_txt[cursor : mt.start()])
                pieces.append(el[2][ind_max.argmax()])
                cursor = mt.end()
            if pieces:
                pieces.append(err_txt[cursor:])
                err_txt = "".join(pieces)

        incorrect_text.append(err_txt)
    return incorrect_text

In [None]:
# Split the corpus into batches of 100 lines; every task carries its batch and the error bank.
batches = [
    [res_new_str[i : i + 100], error_bank] for i in range(0, len(res_new_str), 100)
]

# initializer=np.random.seed reseeds NumPy's global generator in every worker; forked workers
# would otherwise all start from a copy of the parent's generator state and draw the same
# sequence of outcomes.
with Pool(processes=10, initializer=np.random.seed) as pool:
    # imap yields the results in task order as they finish, so the bar reflects real progress
    # (wrapping the input of pool.map only tracked how fast the tasks were submitted).
    res = list(tqdm(pool.imap(get_text_with_error, batches), total=len(batches)))
print(res[0])

In [None]:
# Flatten the per-batch results into one list; a comprehension is linear, whereas
# sum(res, []) copies the growing list once per batch.
corrupted_lines = [line for batch in res for line in batch]

# zip() stops at the shorter input, so a length mismatch would silently drop lines.
assert len(corrupted_lines) == len(res_new_str), "input/target line counts differ"

# "target" holds the clean line, "input" the same line after error injection.
data_new = pd.DataFrame(
    list(zip(res_new_str, corrupted_lines)), columns=["target", "input"]
)

In [None]:
# Preview the first rows of the target/input pairs.
data_new.head()

In [None]:
# Character length of both versions of every line. The target length used to be written to
# "msg_len" and immediately overwritten by the input length; it now has its own column, and
# "msg_len" keeps the value it ended up with before (the input length).
data_new["target_msg_len"] = data_new["target"].apply(len)
data_new["msg_len"] = data_new["input"].apply(len)

In [None]:
# Largest value of "msg_len", i.e. the character length of the longest "input" line.
data_new["msg_len"].max()

In [None]:
# tokenize everything and save to new columns. calculate the length of the input_ids array
data_new["trg_target_token_len"] = data_new["target"].apply(
    lambda x: len(roberta_tokenizer(x)["input_ids"])
)
data_new["trg_input_token_len"] = data_new["input"].apply(
    lambda x: len(roberta_tokenizer(x)["input_ids"])
)

In [None]:
# only save strings shorter than 500 characters
data_new_len_500 = data_new[
    (data_new["trg_target_token_len"] > 500) | (data_new["trg_input_token_len"] > 500)
]

In [None]:
# Keep only the rows in which both versions are at most 500 tokens long.
target_fits = data_new["trg_target_token_len"].le(500)
input_fits = data_new["trg_input_token_len"].le(500)
data_new = data_new[target_fits & input_fits]

In [None]:
# delete consecutive punctuation marks
data_new = data_new[
    data_new["input"].apply(lambda x: re.search(r"[.:!]{3,}", x) is None)
]

In [None]:
# Number of rows and columns left after the filters above.
data_new.shape

In [None]:
import pickle

# Serialise the filtered DataFrame (target/input texts plus the length columns) to disk.
# Despite the file name, the frame stores token counts, not the token ids themselves.
with open("tokenized_data.pkl", "wb") as f:
    pickle.dump(data_new, f)