<a href="https://colab.research.google.com/github/1ucky40nc3/ml4me/blob/main/text/grammatical-error-correction/simple-recipe-for-multilingual-gec/run_gec_fine-tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# MIT License

Copyright (c) 2023 Louis Wendler

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

# Set up Notebook

In [None]:
# Show the GPU(s) visible to this runtime before starting any work.
!nvidia-smi

In [None]:
# @title Install Dependencies

!pip install pip install git+https://github.com/huggingface/transformers
!pip install datasets
!pip install evaluate
!pip install accelerate
!pip install sentencepiece
!pip install rouge_score
!pip install haikunator

In [None]:
#@title Clone the `huggingface/transformers` Repository

!git clone https://github.com/huggingface/transformers.git

# Prepare data

In [None]:
# @title Download the FALKO-MERLIN GEC Corpus

!gdown http://www.sfs.uni-tuebingen.de/~adriane/download/wnut2018/data.tar.gz
!tar -xf data.tar.gz

In [None]:
from typing import List

import os
import re
import pandas as pd


# Split name -> list of (source file, target file) pairs. Every pair of a split
# is concatenated into one CSV named after the split ("eval" is built from the
# corpus' "dev" files). Further corpora can be appended to a split's list.
data_files = {}
data_files["train"] = [
    ("/content/data/fm-train.src", "/content/data/fm-train.trg"),
    # ("/content/data/wiki.1M.src", "/content/data/wiki.1M.trg")
]
data_files["test"] = [("/content/data/fm-test.src", "/content/data/fm-test.trg")]
data_files["eval"] = [("/content/data/fm-dev.src", "/content/data/fm-dev.trg")]


def read_lines(data_file: str) -> List[str]:
    r"""Read a UTF-8 text file and return its lines without line terminators.

    Source and target files are consumed line-by-line in parallel, so lines are
    split exactly like file iteration does (universal newlines: ``\n``, ``\r``,
    ``\r\n``). ``str.splitlines`` is deliberately not used: it also splits on
    rarer separators such as ``\x0b`` or ``\u2028`` and could misalign pairs.
    Stripping the terminator keeps a stray trailing newline out of every CSV cell.

    Args:
        data_file: Path of the file to read.

    Returns:
        The file's lines, each without its trailing ``\n``.
    """
    with open(data_file, "r", encoding="utf-8") as f:
        return [line.rstrip("\n") for line in f]


def delete_char_at_index(string: str, index: int) -> str:
    """Return ``string`` with the character at position ``index`` removed.

    ``index`` follows list-indexing rules: negative values count from the end,
    and an out-of-range index raises ``IndexError``.
    """
    characters = list(string)
    characters.pop(index)
    return "".join(characters)


def clean_whitespace_before(string: str, regex: str = r"\s[,;\.\:\!\?\)\]\}\/]") -> str:
    """Delete the whitespace character that precedes closing punctuation.

    Every non-overlapping match of ``regex`` loses its first character. The
    default pattern is one whitespace character followed by one of
    ``, ; . : ! ? ) ] } /``, e.g. ``"Haus ) ."`` -> ``"Haus)."``.

    Args:
        string: Text to clean.
        regex: Pattern whose first matched character is removed; callers pass
            their own pattern to clean before other characters (e.g. quotes).

    Returns:
        The cleaned text.
    """
    # re.sub visits every match of the original string in one pass, so directly
    # adjacent matches ("Was ? !" -> "Was?!") are all handled. A manual search
    # loop would have to shift its resume offset by one after each deletion.
    return re.sub(regex, lambda match: match.group(0)[1:], string)


def clean_whitespace_after(string: str, regex: str = r"[\(\[\{\/]\s") -> str:
    """Delete the whitespace character that follows opening brackets.

    Every non-overlapping match of ``regex`` loses its second character. The
    default pattern is one of ``( [ { /`` followed by one whitespace character,
    e.g. ``"( ( a"`` -> ``"((a"``.

    Args:
        string: Text to clean.
        regex: Two-character pattern whose second character is removed;
            callers pass their own pattern to clean after other characters.

    Returns:
        The cleaned text.
    """
    # Single re.sub pass: adjacent matches are not skipped, unlike a search loop
    # that resumes at the pre-deletion match end after the string has shrunk.
    return re.sub(
        regex, lambda match: match.group(0)[:1] + match.group(0)[2:], string
    )


def clean_whitespace_inside_quotes(string: str) -> str:
    """Remove padding whitespace just inside quoted spans.

    Each ``'...'`` or ``"..."`` span (opening and closing quote of the same
    kind) is cleaned with ``clean_whitespace_after`` (whitespace following a
    quote character) and ``clean_whitespace_before`` (whitespace preceding a
    quote character), e.g. ``'er sagt " ja " .'`` -> ``'er sagt "ja" .'``.
    Text outside the quoted spans is left unchanged.

    NOTE(review): apostrophes (e.g. "geht's") also match the single-quote
    pattern, so text between two apostrophes may be treated as a quoted span.
    """
    single_quotes = r"'[^']*'"
    double_quotes = r'"[^"]*"'
    tokens = '(?:' + '|'.join([single_quotes, double_quotes]) + ')'
    pattern = re.compile(tokens, re.DOTALL)

    def clean_span(match: re.Match) -> str:
        span = clean_whitespace_after(match.group(0), r"[\"\']\s")
        return clean_whitespace_before(span, r"\s[\"\']")

    # Substitute every span found in the original string in a single pass.
    # Cleaning never adds or removes quote characters, so the span boundaries
    # stay valid. Re-searching a string that shrinks after every cleanup would
    # need offset bookkeeping: without it, the search can jump past the next
    # opening quote and pair quotes incorrectly. Each span is also replaced
    # exactly where it matched, not at the first identical text elsewhere.
    return pattern.sub(clean_span, string)


def clean(string: str) -> str:
    """Normalize quotes and punctuation spacing of one corpus line.

    Steps, in order:
    1. map the typographic double quotes „ “ ” to ASCII ``"``;
    2. drop whitespace after opening brackets / slash;
    3. drop whitespace before closing punctuation / slash;
    4. drop padding whitespace just inside quoted spans.

    Args:
        string: One line of a source or target file.

    Returns:
        The normalized line.
    """
    # In German typography „ is closed by “ (English style uses “ ”). If only „
    # were mapped, the closing quote would stay typographic and the ASCII quote
    # pairing in clean_whitespace_inside_quotes would match the wrong quotes.
    string = string.replace('„', '"').replace('“', '"').replace('”', '"')
    string = clean_whitespace_after(string)
    string = clean_whitespace_before(string)
    string = clean_whitespace_inside_quotes(string)
    return string


# Write one CSV per split (fm_train.csv, fm_test.csv, fm_eval.csv) with a
# "source" and a "target" column. The training cell reads these files and
# selects the columns by name via --text_column / --summary_column.
for split, src_trg_pairs in data_files.items():
    data = {"source": [], "target": []}
    for src_file, trg_file in src_trg_pairs:
        src_examples = [clean(e) for e in read_lines(src_file)]
        trg_examples = [clean(e) for e in read_lines(trg_file)]

        # Sources and targets are paired purely by line number. A length
        # mismatch would silently misalign every following pair, or fail only
        # later in pd.DataFrame if the split totals happen to differ.
        if len(src_examples) != len(trg_examples):
            raise ValueError(
                f"{src_file} has {len(src_examples)} lines but "
                f"{trg_file} has {len(trg_examples)}"
            )

        data["source"].extend(src_examples)
        data["target"].extend(trg_examples)

    df = pd.DataFrame(data)
    # index=False: the row index carries no information and would otherwise
    # appear as an extra unnamed column when the CSV is loaded.
    df.to_csv(os.path.join("/content/data", f"fm_{split}.csv"), index=False)

# Train!

In [None]:
#@title Create the local `runs` directory

# Parent directory for all training runs; the training cell creates one
# subdirectory per run beneath it. `-p` makes the mkdir safe to re-run.
BASE_DIR = "/content/runs"
!mkdir -p $BASE_DIR

In [None]:
#@title Implement util functions

import os
from haikunator import Haikunator


def generate_name():
    """Return a random, human-readable run name produced by ``Haikunator``."""
    return Haikunator().haikunate()

In [None]:
#@title Start a Training Run

# Fine-tunes google/mt5-small with the transformers summarization example as a
# text-to-text task: each "source" column value (prefixed with
# "correct errors: ") is mapped to its "target" column value, using the CSVs
# written by the data-preparation cell. Checkpoints and TensorBoard logs go to a
# fresh, randomly named directory under BASE_DIR.
# NOTE(review): --test_file is passed but --do_predict is not, so presumably the
# test split is not evaluated by this run — confirm in run_summarization.py.
# NOTE(review): the `%cd` below changes the working directory for every cell
# executed afterwards, including re-runs of earlier cells.

# Initialize constants
RUN_NAME = generate_name()
OUTPUT_DIR = os.path.join(BASE_DIR, RUN_NAME)
LOGGING_DIR = os.path.join(OUTPUT_DIR, 'logs')
!mkdir -p $LOGGING_DIR

# Load the local tensorboard
%reload_ext tensorboard
%tensorboard --logdir $LOGGING_DIR

# Start the training script
%cd /content/transformers/examples/pytorch/summarization
!python run_summarization.py \
    --model_name_or_path google/mt5-small \
    --do_train \
    --do_eval \
    --lang "de" \
    --train_file "/content/data/fm_train.csv" \
    --validation_file "/content/data/fm_eval.csv" \
    --test_file "/content/data/fm_test.csv" \
    --text_column "source" \
    --summary_column "target" \
    --source_prefix "correct errors: " \
    --max_source_length 256 \
    --max_target_length 256 \
    --per_device_train_batch_size=4 \
    --auto_find_batch_size \
    --predict_with_generate \
    --output_dir $OUTPUT_DIR \
    --overwrite_output_dir \
    --logging_dir $LOGGING_DIR \
    --report_to tensorboard \
    --run_name $RUN_NAME