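Fine-tuning `google-t5/t5-small` for grammatical error correction (GEC), following the Hugging Face guides at https://huggingface.co/docs/transformers/en/training#train-with-pytorch-trainer and https://huggingface.co/docs/transformers/v4.39.3/en/tasks/summarization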

In [None]:
!pip install -U transformers datasets evaluate rouge_score transformers[torch] pyspellchecker spacy

In [None]:
!python -m spacy download en_core_web_sm

In [1]:
from datasets import load_dataset, Dataset, DatasetDict
from transformers import AutoTokenizer, TrainingArguments, Trainer, AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
import numpy as np
import evaluate
import torch

# GEC task preprocessing

In [2]:
# Mount drive only when using google colab
# from google.colab import drive
# drive.mount('/content/drive')

In [3]:
import pandas as pd

# code taken from https://www.cl.cam.ac.uk/research/nl/bea2019st/data/corr_from_m2.py
def m2_to_df(m2, id):
    """Apply one annotator's M2 edits and return original/corrected sentence pairs.

    Parameters
    ----------
    m2 : iterable of str
        One string per sentence block: an ``S <tokens>`` line followed by zero or
        more ``A <start> <end>|||<type>|||<correction>|||...|||<coder>`` lines.
    id : int
        Coder (annotator) id whose edits are applied; edits by other coders are ignored.

    Returns
    -------
    pandas.DataFrame
        Columns ``original`` and ``corrected``, one row per non-empty block,
        tokens joined with single spaces.

    Blank lines inside a block are ignored and blocks containing no text are
    skipped, so stray empty lines in the input neither raise nor produce empty rows.
    """
    # Do not apply edits with these error types
    skip = {"noop", "UNK", "Um"}
    ori_sentences = []
    corrected_sentences = []
    for sent in m2:
        # Drop blank lines so a leading/trailing/extra newline cannot be mistaken
        # for the sentence line or for an edit line.
        lines = [line for line in sent.split("\n") if line.strip()]
        if not lines:
            continue  # whitespace-only block: no sentence to record
        ori_sent = lines[0].split()[1:]  # Ignore "S "
        cor_sent = ori_sent.copy()
        # Running difference between token positions in ori_sent and cor_sent
        # caused by the edits applied so far.
        offset = 0
        for edit in lines[1:]:
            fields = edit.split("|||")
            if fields[1] in skip:
                continue  # Ignore certain edits
            if int(fields[-1]) != id:
                continue  # Ignore other coders
            span = fields[0].split()[1:]  # Ignore "A "
            start = int(span[0])
            end = int(span[1])
            cor = fields[2].split()
            cor_sent[start + offset:end + offset] = cor
            offset += len(cor) - (end - start)
        ori_sentences.append(" ".join(ori_sent))
        corrected_sentences.append(" ".join(cor_sent))
    df = pd.DataFrame(list(zip(ori_sentences, corrected_sentences)), columns=['original', 'corrected'])
    return df

In [4]:
# Training split: read the M2 blocks, apply coder 0's edits and drop duplicate sentence pairs.
# The encoding is given explicitly (as in the other file-reading cells) so the result
# does not depend on the platform's default encoding; the file is closed before processing.
with open('wi+locness/m2/ABC.train.gold.bea19.m2', encoding="utf-8") as f:
    m2_train = f.read().strip().split("\n\n")
print("Initial num of sentences:", len(m2_train))
train_df = m2_to_df(m2_train, 0).drop_duplicates().reset_index(drop=True)
print("After removing dup num of sentences:", len(train_df))

# Dev split: same conversion, but duplicates are kept.
with open('wi+locness/m2/ABCN.dev.gold.bea19.m2', encoding="utf-8") as f:
    m2_dev = f.read().strip().split("\n\n")
dev_df = m2_to_df(m2_dev, 0)

Initial num of sentences: 34308
After removing dup num of sentences: 33493


In [5]:
# Wrap both frames as Hugging Face datasets; the dev frame is exposed under the "test" key.
dataset = DatasetDict({
    split: Dataset.from_pandas(frame)
    for split, frame in (("train", train_df), ("test", dev_df))
})
print(dataset)
# Show the first training pair as a sanity check.
dataset['train'][0]

DatasetDict({
    train: Dataset({
        features: ['original', 'corrected'],
        num_rows: 33493
    })
    test: Dataset({
        features: ['original', 'corrected'],
        num_rows: 4384
    })
})


{'original': 'My town is a medium size city with eighty thousand inhabitants .',
 'corrected': 'My town is a medium - sized city with eighty thousand inhabitants .'}

In [6]:
from transformers import AutoTokenizer, DataCollatorForSeq2Seq

# Name of the pretrained checkpoint that the tokenizer and the models are loaded from.
checkpoint = "google-t5/t5-small"
# checkpoint = "saved_models/fined_tuned_gec_model/checkpoint-4500"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
# The collator's optional `model` argument is documented as a model object; the
# checkpoint *name* (a str) was passed before, which the collator cannot use.
# It is omitted here, and each trainer's model derives decoder inputs from the labels.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [7]:
from spellchecker import SpellChecker

# Shared spell checker used by spellcheck_doc.
# NOTE(review): distance=1 presumably restricts candidates to one edit away — confirm against the pyspellchecker docs.
spell = SpellChecker(distance=1)
# Task prefix prepended to every source sentence before tokenization.
prefix = "Correct the grammar: "

def spellcheck_doc(doc):
    """Spell-correct the lowercase alphabetic tokens of a space-separated sentence.

    Tokens that are not purely alphabetic or not entirely lowercase (numbers,
    punctuation, capitalised words, empty strings) are kept unchanged, as is any
    token for which the spell checker returns a falsy result.
    """
    fixed = []
    for word in doc.split(" "):
        if word.isalpha() and word.islower():
            # Fall back to the input token when no correction is returned.
            word = spell.correction(word) or word
        fixed.append(word)
    return " ".join(fixed)

def preprocess_function(examples, spellcheck=True):
    """Tokenize a batch of sentence pairs for seq2seq training.

    Each entry of ``examples["original"]`` is optionally passed through
    ``spellcheck_doc`` and prefixed with ``prefix`` to form the model input;
    ``examples["corrected"]`` supplies the targets. Both sides are truncated to
    128 tokens. Returns the tokenizer output for the inputs with the target
    token ids stored under ``"labels"``.
    """
    sources = []
    for doc in examples["original"]:
        if spellcheck:
            doc = spellcheck_doc(doc)
        sources.append(prefix + doc)

    encoded = tokenizer(sources, max_length=128, truncation=True)
    targets = tokenizer(text_target=examples["corrected"], max_length=128, truncation=True)
    encoded["labels"] = targets["input_ids"]
    return encoded

# Quick sanity check of the spell checker on a sentence with two misspelled words.
spellcheck_doc("This is an exemple sentence . My uncle 's hause is in Paris .")

"This is an example sentence . My uncle 's house is in Paris ."

In [8]:
tokenized_dataset = dataset.map(lambda x: preprocess_function(x, False), batched=True)
tokenized_spellchecked_dataset = dataset.map(lambda x: preprocess_function(x, True), batched=True)

Map:   0%|          | 0/33493 [00:00<?, ? examples/s]

Map:   0%|          | 0/4384 [00:00<?, ? examples/s]

Map:   0%|          | 0/33493 [00:00<?, ? examples/s]

Map:   0%|          | 0/4384 [00:00<?, ? examples/s]

# Evaluation metrics

In [9]:
import evaluate

# Metric object used by compute_metrics to score decoded predictions against decoded labels.
google_bleu = evaluate.load("google_bleu")

In [31]:
import numpy as np

def compute_metrics(eval_pred):
    """Score generated token ids against the label token ids.

    ``eval_pred`` unpacks into ``(predictions, labels)`` arrays of token ids.
    Entries equal to -100 are replaced with the tokenizer's pad id before
    decoding. Returns the Google BLEU result plus ``gen_len`` (mean number of
    non-pad tokens per prediction), every value rounded to 4 decimals.
    """
    pad_id = tokenizer.pad_token_id
    raw_preds, raw_labels = eval_pred

    # -100 is not a valid token id, so swap it for the pad id before decoding.
    pred_ids = np.where(raw_preds != -100, raw_preds, pad_id)
    label_ids = np.where(raw_labels != -100, raw_labels, pad_id)

    pred_texts = tokenizer.batch_decode(pred_ids, skip_special_tokens=True)
    label_texts = tokenizer.batch_decode(label_ids, skip_special_tokens=True)

    scores = google_bleu.compute(predictions=pred_texts, references=label_texts)
    scores["gen_len"] = np.mean([np.count_nonzero(row != pad_id) for row in pred_ids])

    rounded = {}
    for metric_name, value in scores.items():
        rounded[metric_name] = round(value, 4)
    return rounded

In [24]:
# Generate output text file for a prediction
import spacy

# Load English tokenizer, tagger, parser and NER
nlp = spacy.load("en_core_web_sm")

def write_output_file(filename: str, predictions: list[str]):
    """Write predictions to ``filename``, one per line, as space-separated spaCy tokens.

    Only token boundaries are needed, so each prediction goes through
    ``nlp.make_doc`` (tokenizer only) rather than the full pipeline; this avoids
    running the tagger, parser and NER on every sentence.
    NOTE(review): assumes no later pipeline component merges or splits tokens.
    """
    with open(filename, mode="w", encoding="utf-8") as file:
        for prediction in predictions:
            tokens = nlp.make_doc(prediction)
            file.write(" ".join(token.text for token in tokens))
            file.write("\n")

# Training/Fine tuning GEC

In [11]:
# checkpoint = "saved_models/fined_tuned_gec_model/checkpoint-xxx"
# NOTE(review): the trainers in this notebook load their own models via from_pretrained,
# so this instance appears unused in the visible cells — confirm before removing.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

In [32]:
def build_training_args(output_dir):
    """Return the fine-tuning configuration shared by both runs.

    Only ``output_dir`` (where checkpoints are written) differs between the
    no-spellcheck and spellcheck experiments.
    """
    return Seq2SeqTrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="epoch",
        learning_rate=2e-5, # can change
        per_device_train_batch_size=64, # can change depending on how powerful the gpu is
        per_device_eval_batch_size=64,
        weight_decay=0.01, # should we change?
        save_total_limit=1,
        num_train_epochs=10,
        predict_with_generate=True,
        fp16=True,
        push_to_hub=False, # can change to true and save the model in hub
        generation_max_length=128
    )


def build_trainer(args, tokenized_splits):
    """Return a Seq2SeqTrainer over a freshly loaded ``checkpoint`` model.

    ``tokenized_splits`` must provide "train" and "test" splits. Each call loads
    its own model, so the two experiments do not share weights.
    """
    return Seq2SeqTrainer(
        model=AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        args=args,
        train_dataset=tokenized_splits["train"],
        eval_dataset=tokenized_splits["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
        compute_metrics=compute_metrics
    )


# Experiment 1: raw (not spell-checked) inputs.
training_args = build_training_args("saved_models/fined_tuned_nospell")
trainer = build_trainer(training_args, tokenized_dataset)

# Experiment 2: spell-checked inputs.
training_args_spell = build_training_args("saved_models/fined_tuned_spell")
trainer_spell = build_trainer(training_args_spell, tokenized_spellchecked_dataset)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [33]:
# Evaluate before fine tuning
# Baseline for the no-spellcheck setup on the dev split; the trailing expression displays the result.
outputs = trainer.predict(tokenized_dataset['test'])
outputs

PredictionOutput(predictions=array([[    0,  9510,  3380, ...,     0,     0,     0],
       [    0, 32099,     3, ...,     0,     0,     0],
       [    0,  9510,    60, ...,     0,     0,     0],
       ...,
       [    0,  9510,    60, ...,     0,     0,     0],
       [    0, 32099,    19, ...,     0,     0,     0],
       [    0,  9510,  3380, ...,     0,     0,     0]]), label_ids=array([[   94,     3,    31, ...,  -100,  -100,  -100],
       [  366,    27,    47, ...,  -100,  -100,  -100],
       [   27,   133,   114, ...,  -100,  -100,  -100],
       ...,
       [  282,    21,     8, ...,  -100,  -100,  -100],
       [   37,  1053,    21, ...,  -100,  -100,  -100],
       [   37,     3, 25360, ...,  -100,  -100,  -100]]), metrics={'test_loss': 0.805658757686615, 'test_google_bleu': 0.356, 'test_gen_len': 33.8688, 'test_runtime': 80.321, 'test_samples_per_second': 54.581, 'test_steps_per_second': 0.859})

In [38]:
# Replace -100 entries in the predictions with the pad token id so they can be decoded,
# then write the decoded baseline (no fine-tuning, no spellcheck) predictions to disk.
lines = tokenizer.batch_decode(np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id), skip_special_tokens=True)
write_output_file("eval_nofinetune_nospell.txt", lines)

In [39]:
# Baseline for the spell-checked setup on the dev split (rebinds `outputs`).
outputs = trainer_spell.predict(tokenized_spellchecked_dataset['test'])
outputs

PredictionOutput(predictions=array([[    0,  9510,  3380, ...,     0,     0,     0],
       [    0, 32099,     3, ...,     0,     0,     0],
       [    0,  9510,    60, ...,     0,     0,     0],
       ...,
       [    0,  9510,    60, ...,     0,     0,     0],
       [    0, 32099,    19, ...,     0,     0,     0],
       [    0,  9510,  3380, ...,     0,     0,     0]]), label_ids=array([[   94,     3,    31, ...,  -100,  -100,  -100],
       [  366,    27,    47, ...,  -100,  -100,  -100],
       [   27,   133,   114, ...,  -100,  -100,  -100],
       ...,
       [  282,    21,     8, ...,  -100,  -100,  -100],
       [   37,  1053,    21, ...,  -100,  -100,  -100],
       [   37,     3, 25360, ...,  -100,  -100,  -100]]), metrics={'test_loss': 0.8107499480247498, 'test_google_bleu': 0.3575, 'test_gen_len': 33.8714, 'test_runtime': 80.9427, 'test_samples_per_second': 54.162, 'test_steps_per_second': 0.852})

In [40]:
# Decode the spell-checked baseline predictions (-100 mapped to the pad id first) and save them.
lines = tokenizer.batch_decode(np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id), skip_special_tokens=True)
write_output_file("eval_nofinetune_spell.txt", lines)

In [41]:
# Fine-tune on the raw (not spell-checked) training inputs.
trainer.train()


Epoch,Training Loss,Validation Loss,Google Bleu,Gen Len
1,0.5317,0.352965,0.8049,25.3919
2,0.4173,0.323551,0.8137,25.5315
3,0.3923,0.310672,0.8174,25.5285
4,0.3773,0.30259,0.8201,25.5347
5,0.3675,0.297122,0.8209,25.5436
6,0.3582,0.293801,0.8221,25.5671
7,0.3559,0.290907,0.8228,25.5737
8,0.348,0.289429,0.8231,25.563
9,0.3463,0.288686,0.8233,25.5705
10,0.3481,0.28819,0.8234,25.5739


TrainOutput(global_step=5240, training_loss=0.3824155297898154, metrics={'train_runtime': 974.8562, 'train_samples_per_second': 343.569, 'train_steps_per_second': 5.375, 'total_flos': 7210768254468096.0, 'train_loss': 0.3824155297898154, 'epoch': 10.0})

In [42]:
# Fine-tune on the spell-checked training inputs.
trainer_spell.train()

Epoch,Training Loss,Validation Loss,Google Bleu,Gen Len
1,0.536,0.358873,0.8017,25.3458
2,0.4244,0.329997,0.8134,25.3814
3,0.3989,0.316959,0.8172,25.3755
4,0.3833,0.310818,0.8205,25.3828
5,0.3727,0.30365,0.8211,25.3978
6,0.3638,0.300568,0.8226,25.4245
7,0.3613,0.297306,0.8231,25.458
8,0.3545,0.295757,0.8235,25.4352
9,0.3515,0.294793,0.8238,25.4364
10,0.3539,0.294578,0.8239,25.4391


TrainOutput(global_step=5240, training_loss=0.3881325248543543, metrics={'train_runtime': 972.0713, 'train_samples_per_second': 344.553, 'train_steps_per_second': 5.391, 'total_flos': 7169043593035776.0, 'train_loss': 0.3881325248543543, 'epoch': 10.0})

In [43]:
# Dev-split predictions from the no-spellcheck trainer, run after its train() call (rebinds `outputs`).
outputs = trainer.predict(tokenized_dataset['test'])
outputs

PredictionOutput(predictions=array([[   0,   94,    3, ..., -100, -100, -100],
       [   0,  366,   27, ..., -100, -100, -100],
       [   0,   27,  133, ..., -100, -100, -100],
       ...,
       [   0,  282,   21, ..., -100, -100, -100],
       [   0,   37, 1137, ..., -100, -100, -100],
       [   0,   37,    3, ..., -100, -100, -100]]), label_ids=array([[   94,     3,    31, ...,  -100,  -100,  -100],
       [  366,    27,    47, ...,  -100,  -100,  -100],
       [   27,   133,   114, ...,  -100,  -100,  -100],
       ...,
       [  282,    21,     8, ...,  -100,  -100,  -100],
       [   37,  1053,    21, ...,  -100,  -100,  -100],
       [   37,     3, 25360, ...,  -100,  -100,  -100]]), metrics={'test_loss': 0.28819018602371216, 'test_google_bleu': 0.8234, 'test_gen_len': 25.5739, 'test_runtime': 56.8397, 'test_samples_per_second': 77.129, 'test_steps_per_second': 1.214})

In [44]:
# Decode the fine-tuned no-spellcheck predictions (-100 mapped to the pad id first) and save them.
lines = tokenizer.batch_decode(np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id), skip_special_tokens=True)
write_output_file("eval_finetune_nospell.txt", lines)

In [45]:
# Dev-split predictions from the spellcheck trainer, run after its train() call (rebinds `outputs`).
outputs = trainer_spell.predict(tokenized_spellchecked_dataset['test'])
outputs

PredictionOutput(predictions=array([[   0,   94,    3, ..., -100, -100, -100],
       [   0,  366,   27, ..., -100, -100, -100],
       [   0,   27,  133, ..., -100, -100, -100],
       ...,
       [   0,  282,   21, ..., -100, -100, -100],
       [   0,   37, 1137, ..., -100, -100, -100],
       [   0,   37,    3, ..., -100, -100, -100]]), label_ids=array([[   94,     3,    31, ...,  -100,  -100,  -100],
       [  366,    27,    47, ...,  -100,  -100,  -100],
       [   27,   133,   114, ...,  -100,  -100,  -100],
       ...,
       [  282,    21,     8, ...,  -100,  -100,  -100],
       [   37,  1053,    21, ...,  -100,  -100,  -100],
       [   37,     3, 25360, ...,  -100,  -100,  -100]]), metrics={'test_loss': 0.29457783699035645, 'test_google_bleu': 0.8239, 'test_gen_len': 25.4391, 'test_runtime': 56.3609, 'test_samples_per_second': 77.784, 'test_steps_per_second': 1.224})

In [47]:
# Decode the fine-tuned spellcheck predictions (-100 mapped to the pad id first) and save them.
lines = tokenizer.batch_decode(np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id), skip_special_tokens=True)
write_output_file("eval_finetune_spell.txt", lines)

# Inference step of GEC

In [56]:
# One-sentence smoke test; "corrected" is a placeholder because preprocess_function expects that column.
text = {"original": ["The main aim of ERRANT is to automatic anotate parallel English sentences with error type information , specifically , given an original and corrected sentence pair , ERRANT will extract an edits that transform the former to the latter and classify them according to a rule - based error type framework ."], "corrected": [""]}

inference_dataset = Dataset.from_dict(text).map(lambda x: preprocess_function(x, False), batched=True)
outputs = trainer.predict(inference_dataset)
# Map any -100 entries to the pad token id before decoding, as the other decode cells do:
# -100 is not a valid token id and would break decoding if it appears in the predictions.
tokenizer.decode(np.where(outputs.predictions[0] != -100, outputs.predictions[0], tokenizer.pad_token_id), skip_special_tokens=True)

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

'The main aim of ERRANT is to automate parallel English sentences with error type information, specifically, given an original and corrected sentence pair, ERRANT will extract an edit that transforms the latter to the latter and classifies them according to a rule - based error type framework.'

In [58]:
# Read the competition test sentences, one per line.
with open("wi+locness/test/ABCN.test.bea19.orig", encoding="utf-8") as file:
    test_lines = file.read().strip().split("\n")
# No references are available here, so every "corrected" entry is an empty placeholder.
test_dict = {"original": test_lines, "corrected": [""] * len(test_lines)}
# preprocess_function spell-checks by default; the second variant turns that off.
test_dataset_spellcheck = Dataset.from_dict(test_dict).map(preprocess_function, batched=True)
test_dataset_nospell = Dataset.from_dict(test_dict).map(lambda x: preprocess_function(x, False), batched=True)

Map:   0%|          | 0/4477 [00:00<?, ? examples/s]

Map:   0%|          | 0/4477 [00:00<?, ? examples/s]

In [59]:
# Predict on the competition test set with the no-spellcheck trainer, decode
# (-100 mapped to the pad id first) and write the submission file.
outputs = trainer.predict(test_dataset_nospell)
lines = tokenizer.batch_decode(np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id), skip_special_tokens=True)
write_output_file("competition_finetune_nospell.txt", lines)

In [60]:
# Predict on the spell-checked competition test set with the spellcheck trainer, decode
# (-100 mapped to the pad id first) and write the submission file.
outputs = trainer_spell.predict(test_dataset_spellcheck)
lines = tokenizer.batch_decode(np.where(outputs.predictions != -100, outputs.predictions, tokenizer.pad_token_id), skip_special_tokens=True)
write_output_file("competition_finetune_spell.txt", lines)

In [70]:
# Dump the dev sentences and their gold corrections, one per line, next to the prediction files.
for out_name, column in (("eval_orig.txt", "original"), ("eval_corr.txt", "corrected")):
    with open(out_name, mode="w", encoding="utf-8") as file:
        file.writelines(f"{line}\n" for line in dev_df[column])
