# Grammar Error Correction using FLAN-T5
The code below mainly follows Hugging Face's translation [tutorial](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation.ipynb#scrollTo=MOsHUjgdIrIW).

### install packages, imports

In [1]:
pip install transformers[sentencepiece] datasets evaluate peft accelerate flash-attn bitsandbytes

Note: you may need to restart the kernel to use updated packages.


## Preprocessing

In [2]:
from datasets import load_dataset  # load_metric was unused here and is gone from newer datasets releases
raw_datasets = load_dataset("wi_locness", 'wi')

# def make_correct_sentence_column(example):
#     corrected_text = example["text"]
#     for start, end, correction in reversed(list(zip(example["edits"]["start"], example["edits"]["end"], example["edits"]["text"]))):
#         if correction == None:
#             continue # TODO: what to do with None?
#         corrected_text = corrected_text[:start] + correction + corrected_text[end:]
#     return {'correct_text': corrected_text}

# raw_datasets = raw_datasets.map(make_correct_sentence_column, remove_columns=["id", "userid", "cefr", "edits"])
# print(raw_datasets['train'][0])

### Tokenization
The inputs and labels for the transformer are as below

- **Inputs** = prefix (see note at end) + grammatically incorrect sentence
- **Labels** = corrected sentence

I'm using the pretrained `sentencepiece` tokenizer (loaded with `AutoTokenizer.from_pretrained(...)`) to tokenize the inputs and labels, and then drop the original columns.

**Note on prefix**: T5 models are pretrained to follow instructions given in a task prefix before each input. I abridged some grammar-related prompts from [here](https://github.com/google-research/FLAN/blob/main/flan/v2/flan_templates_branched.py) to come up with `Correct spelling, punctuation and grammatical errors in this text:`. The prompt was chosen fairly unscientifically; more work could be done to craft a better one.
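To make the input/label construction concrete, here is a minimal sketch that leaves the tokenizer out. It assumes the edits arrive as parallel `start`/`end`/`text` lists of character offsets, as in the preprocessing code in this notebook. The example sentence, its edits, and the `missing` placeholder are made up for illustration.

```python
def apply_edits(text, edits, missing="<unk>"):
    """Return `text` with character-offset edits applied.

    `edits` holds parallel lists under "start", "end" and "text".
    `missing` is a hypothetical placeholder for edits with no replacement (None).
    """
    spans = list(zip(edits["start"], edits["end"], edits["text"]))
    # Apply from the last edit to the first so earlier offsets stay valid.
    for start, end, replacement in sorted(spans, key=lambda s: s[0], reverse=True):
        if replacement is None:
            replacement = missing
        text = text[:start] + replacement + text[end:]
    return text


prefix = "Correct spelling, punctuation and grammatical errors in this text:"
source = "She are baking cake."  # toy example, not taken from the dataset
edits = {"start": [4, 15], "end": [7, 15], "text": ["is", "a "]}

model_input = prefix + source       # what the encoder sees
label = apply_edits(source, edits)  # what the decoder should produce
print(model_input)
print(label)  # She is baking a cake.
```

Going through the edits back to front matters: replacing an early span first would shift every later offset.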

In [3]:
from transformers import AutoTokenizer
model_checkpoint = "google/flan-t5-large"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
prefix = "Correct spelling, punctuation and grammatical errors in this text:"
len_tokenized_prefix = len(tokenizer(prefix, add_special_tokens=False)["input_ids"])
def preprocess_function(examples):
    # print([x for x in examples])
    inputs = [prefix + ex for ex in examples['text']]
    model_inputs = tokenizer(inputs, max_length=512, truncation=True, return_offsets_mapping=True)
    labels_out = []
    offset_mapping = model_inputs.pop("offset_mapping")
    for i in range(len(model_inputs["input_ids"])):
        example_idx = i

        start_idx = offset_mapping[i][len_tokenized_prefix][0]
        # last token is <eos>, so we care about the second-last token's end offset
        end_idx = offset_mapping[i][-2][1]

        edits = examples["edits"][example_idx]
        
        corrected_text = inputs[example_idx][start_idx:end_idx]
        start_idx -= len(prefix)
        end_idx -= len(prefix)

        for start, end, correction in reversed(
            list(zip(edits["start"], edits["end"], edits["text"]))
        ):
            if start < start_idx or end > end_idx:
                continue
            start_offset = start - start_idx  # >= 0
            end_offset = end - start_idx
            if correction is None:
                correction = tokenizer.unk_token
            corrected_text = (
                corrected_text[:start_offset] + correction + corrected_text[end_offset:]
            )
        labels_out.append(corrected_text)
    labels_out = tokenizer(labels_out, max_length=512, truncation=True)
    model_inputs["labels"] = labels_out["input_ids"]
    return model_inputs
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=raw_datasets['train'].column_names)
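The preprocessing function above uses the tokenizer's offset mapping to find which characters of the text survive truncation, and skips any edit outside that window. Here is a toy sketch of the same idea. It swaps in a whitespace tokenizer (`toy_tokenize`, a hypothetical helper, not SentencePiece) and uses a made-up sentence.

```python
import re


def toy_tokenize(text, max_tokens):
    """Whitespace tokenizer returning (token, (start, end)) pairs, truncated."""
    pairs = [(m.group(), m.span()) for m in re.finditer(r"\S+", text)]
    return pairs[:max_tokens]


def label_for_window(text, edits, max_tokens):
    """Build the corrected label for the part of `text` that survives truncation."""
    tokens = toy_tokenize(text, max_tokens)
    win_start = tokens[0][1][0]  # first character kept
    win_end = tokens[-1][1][1]   # one past the last character kept
    corrected = text[win_start:win_end]
    spans = list(zip(edits["start"], edits["end"], edits["text"]))
    for start, end, replacement in reversed(spans):
        if start < win_start or end > win_end:
            continue  # this edit touches text that was truncated away
        corrected = (
            corrected[: start - win_start] + replacement + corrected[end - win_start:]
        )
    return corrected


text = "She are baking cake and he go home."
edits = {"start": [4, 15, 27], "end": [7, 15, 29], "text": ["is", "a ", "goes"]}
print(label_for_window(text, edits, max_tokens=4))  # She is baking a cake
print(label_for_window(text, edits, max_tokens=8))  # She is baking a cake and he goes home.
```

With `max_tokens=4` the edit on "go" is dropped because that word no longer appears in the input, so the label never asks the model to correct text it cannot see.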


In [4]:
print(tokenizer.batch_decode(tokenized_datasets["train"]['input_ids'][512]))
print(tokenizer.batch_decode(tokenized_datasets["train"]['labels'][512]))

['Correct', 'spelling', ',', 'punct', 'u', 'ation', 'and', '', 'gram', 'matic', 'al', 'errors', 'in', 'this', 'text', ':', 'My', 'favourite', 'sport', 'is', 'cricket', '.', 'I', 'love', 'cricket', 'very', 'much', 'since', 'from', 'my', 'school', 'time', '.', 'cricket', 'is', '', 'a', 'game', 'of', 'bat', 'and', 'ball', 'in', 'which', 'there', 'are', 'two', 'teams', 'which', 'have', 'eleven', 'players', 'on', 'each', 'side', '.', 'generally', 'we', 'are', 'using', '', 'cri', 'ket', 'ground', 'as', '', 'a', 'oval', 'shape', '.', '', '</s>']
['My', 'favourite', 'sport', 'is', 'cricket', '.', 'I', 'have', 'loved', 'cricket', 'very', 'much', 'since', 'from', 'my', 'school', 'days', '.', 'Cricket', 'is', '', 'a', 'game', 'of', 'bat', 'and', 'ball', 'in', 'which', 'there', 'are', 'two', 'teams', 'which', 'have', 'eleven', 'players', 'on', 'each', 'side', '.', '', 'Generally', ',', 'we', 'use', '', 'a', 'cricket', 'ground', 'which', 'has', 'an', 'oval', 'shape', '.', '', '</s>']


NB: The attention mask **indicates the position of the padded indices** so that the model does not attend to them.
More details can be found [here](https://huggingface.co/docs/transformers/en/glossary).

### Batching
The training dataset has sentences of different lengths, so `DataCollatorForSeq2Seq` is used to pad the inputs (and labels) to the same length within each batch during training.

This [API reference](https://huggingface.co/docs/transformers/en/main_classes/data_collator) contains further info about how it works.
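As a rough picture of what per-batch padding does, here is a pure-Python sketch, not the library's implementation. It assumes the pad token id is 0 (as for T5) and that labels are padded with -100, the index the cross-entropy loss ignores. The token ids in the example are made up.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad every feature to the longest sequence in this batch only."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 marks a real token, 0 marks padding the model should not attend to.
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        n_label_pad = max_label - len(f["labels"])
        # Padded label positions do not contribute to the loss.
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch


features = [
    {"input_ids": [5, 6, 7, 1], "labels": [8, 1]},
    {"input_ids": [9, 1], "labels": [10, 11, 12, 1]},
]
batch = pad_batch(features)
print(batch["input_ids"])       # [[5, 6, 7, 1], [9, 1, 0, 0]]
print(batch["attention_mask"])  # [[1, 1, 1, 1], [1, 1, 0, 0]]
print(batch["labels"])          # [[8, 1, -100, -100], [10, 11, 12, 1]]
```

Padding to the longest sequence in the batch, rather than to a fixed 512, keeps short batches cheap.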

In [5]:
from transformers import DataCollatorForSeq2Seq
# `model` expects a model object rather than a checkpoint name, so it is omitted here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

## Model
The [model docs on T5](https://huggingface.co/docs/transformers/model_doc/t5) and [model card](https://huggingface.co/google-t5/t5-small) provide more info on the model.
### Zero-shot inference
I'm loading a pretrained Flan-T5 and running zero-shot inference with the task prompt `Correct spelling, punctuation and grammatical errors in this text:`. Without finetuning, the example sentence comes back uncorrected.

In [6]:
from transformers import AutoModelForSeq2SeqLM
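# Load the weights in int8 and let accelerate place them on the available device(s).
# Newer transformers versions expect quantization_config=BitsAndBytesConfig(load_in_8bit=True) instead.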
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint, load_in_8bit=True, device_map="auto")
input_text = f"{prefix}She are baking cake."
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


<pad> She are baking cake.</s>


### LoRA setup

Given that our dataset is quite small and GPU resources are limited, it probably doesn't make sense to train all the weights of T5. I'm trying a parameter-efficient finetuning (PEFT) method called **LoRA (Low Rank Adaptation)**. I found [this article](https://towardsdatascience.com/understanding-lora-low-rank-adaptation-for-finetuning-large-models-936bce1a07c6) useful in understanding how it works.

The LoRA creation and training code below follows mainly from this [tutorial](https://www.philschmid.de/fine-tune-flan-t5-peft) and the [docs](https://huggingface.co/docs/diffusers/v0.19.3/en/training/lora).
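For intuition, here is a tiny pure-Python sketch of the LoRA forward pass on a single weight matrix. The frozen weight `W` is left untouched and a low-rank product `B A`, scaled by `alpha / r`, is added on top. The matrices and sizes below are toy values, not anything from Flan-T5.

```python
def matvec(matrix, vector):
    """Multiply a matrix (list of rows) by a vector."""
    return [sum(w * x for w, x in zip(row, vector)) for row in matrix]


def lora_forward(W, A, B, x, r, alpha):
    """h = W x + (alpha / r) * B (A x); only A and B would be trained."""
    base = matvec(W, x)
    update = matvec(B, matvec(A, x))
    scale = alpha / r
    return [b + scale * u for b, u in zip(base, update)]


# Toy sizes: a 3x3 frozen weight with a rank-1 adapter.
W = [[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]]
A = [[0.5, 0.5, 0.5]]               # shape (r, d_in)
B_zero = [[0.0], [0.0], [0.0]]      # shape (d_out, r), the usual zero initialisation
B_trained = [[1.0], [0.0], [-1.0]]  # pretend values after some training
x = [1.0, 1.0, 1.0]

print(lora_forward(W, A, B_zero, x, r=1, alpha=2))     # [1.0, 2.0, 3.0]
print(lora_forward(W, A, B_trained, x, r=1, alpha=2))  # [4.0, 2.0, 0.0]
```

Because `B` starts at zero, the adapted model initially behaves exactly like the pretrained one. With `r=64` and `lora_alpha=16`, as configured in this notebook, the scale works out to 0.25.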

Also, I'm using a quantized Flan-T5 (`int8` weights instead of `fp16`), mainly to reduce memory usage. Whether it also speeds up inference depends on the hardware and kernels, so that part is not guaranteed.
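A back-of-the-envelope estimate of the weight memory at different precisions is below. The base parameter count is the "all params" figure minus the trainable LoRA parameters, both reported by `print_trainable_parameters()` further down in this notebook. It is only a rough sketch: it ignores activations, optimizer state, and any tensors the int8 loader keeps in higher precision.

```python
BASE_PARAMS = 802_024_448 - 18_874_368  # all params minus LoRA params = 783,150,080


def weight_memory_gib(n_params, bytes_per_param):
    """Memory needed just to store the weights, in GiB."""
    return n_params * bytes_per_param / 2**30


for name, nbytes in [("fp32", 4), ("fp16", 2), ("int8", 1)]:
    print(f"{name}: {weight_memory_gib(BASE_PARAMS, nbytes):.2f} GiB")
```

On these assumptions the weights drop from roughly 2.9 GiB in fp32 to about 0.7 GiB in int8.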

In [7]:
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=64,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM
)
lora_model = get_peft_model(model, lora_config)

lora_model.print_trainable_parameters()

trainable params: 18,874,368 || all params: 802,024,448 || trainable%: 2.3533407300820834


Note that the share of trainable parameters with LoRA is small: about 2.4% of all parameters at `r=64`, as printed above. This scheme will hopefully help us train the model faster and without overfitting to our small dataset.
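The trainable parameter count can be reproduced by hand. The sketch below assumes a few things that are not stated in this notebook: that flan-t5-large has a hidden size of 1024 with 24 encoder and 24 decoder layers, and that PEFT's default for T5 adapts the query and value projections of every attention block.

```python
def lora_param_count(r, d_in, d_out, n_matrices):
    """Each adapted matrix gains A (r x d_in) and B (d_out x r)."""
    return n_matrices * r * (d_in + d_out)


# Assumed architecture of flan-t5-large (not stated in this notebook).
d_model = 1024
encoder_layers, decoder_layers = 24, 24
# Assumed PEFT default for T5: adapt the query and value projections.
matrices_per_attention = 2
# Each decoder layer has both a self-attention and a cross-attention block.
attention_blocks = encoder_layers + 2 * decoder_layers

trainable = lora_param_count(64, d_model, d_model, matrices_per_attention * attention_blocks)
total = 802_024_448  # "all params" from the output above
print(trainable)                          # 18874368
print(f"{100 * trainable / total:.2f}%")  # 2.35%
```

Under those assumptions the result matches the printed `trainable params` exactly, and it scales linearly with `r`: halving the rank would halve the count.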

### Training

In [8]:
# %pip install wandb
import wandb
import os
wandb.login()
os.environ["WANDB_ENTITY"] = "ay2324s2-cs4248-team-47"
os.environ["WANDB_PROJECT"]="finetune-pretrained-transformer"
os.environ["WANDB_LOG_MODEL"] = "end"

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mjayanth-b[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

output_dir="lora-flan-t5"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=2e-4, # higher learning rate
    num_train_epochs=5,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=1,
    save_strategy="steps",
    save_steps=0.1,
    evaluation_strategy="steps",
    eval_steps=0.1,
    report_to="wandb",
    load_best_model_at_end=True
)

# Create Trainer instance
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"]
)
model.config.use_cache = False

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)
Detected kernel version 5.4.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


**ETA for training**: the initial estimate showed ~5 hours (tested on a rented RTX 3070 instance). The run logged below finished all 5 epochs in about 55 minutes (`train_runtime` ≈ 3318 s).

In [10]:
trainer.train()



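| Step | Training Loss | Validation Loss |
|------|---------------|-----------------|
| 188  | 0.2267 | 0.312162 |
| 376  | 0.284  | 0.292979 |
| 564  | 0.3364 | 0.278271 |
| 752  | 0.2454 | 0.271645 |
| 940  | 0.3769 | 0.274238 |
| 1128 | 0.3118 | 0.271458 |
| 1316 | 0.2493 | 0.266688 |
| 1504 | 0.2866 | 0.267584 |
| 1692 | 0.2129 | 0.268464 |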




TrainOutput(global_step=1875, training_loss=0.29663439728021623, metrics={'train_runtime': 3317.9959, 'train_samples_per_second': 4.521, 'train_steps_per_second': 0.565, 'total_flos': 2.8902671751020544e+16, 'train_loss': 0.29663439728021623, 'epoch': 5.0})
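The evaluation steps in the table (188, 376, ...) follow from `eval_steps=0.1` and `save_steps=0.1` being read as a fraction of the total step count. The sketch below reproduces them on the assumption that the fraction is rounded up, which matches the logged steps. It also cross-checks the total of 1875 steps from the throughput numbers in `TrainOutput`. The batch size and dataset size are inferred from those numbers, not read from the config.

```python
import math


def eval_schedule(total_steps, ratio):
    """Steps at which evaluation/saving fires when eval_steps is a fraction."""
    interval = math.ceil(total_steps * ratio)
    return list(range(interval, total_steps + 1, interval))


total_steps = 1875  # global_step from TrainOutput
print(eval_schedule(total_steps, 0.1))
# [188, 376, 564, 752, 940, 1128, 1316, 1504, 1692]

# Cross-check the step count from the logged throughput.
runtime_s, samples_per_s, steps_per_s, epochs = 3317.9959, 4.521, 0.565, 5
batch_size = round(samples_per_s / steps_per_s)             # inferred
train_examples = round(samples_per_s * runtime_s / epochs)  # inferred
steps = math.ceil(train_examples / batch_size) * epochs
print(batch_size, train_examples, steps)  # 8 3000 1875
```

Since 10 × 188 exceeds 1875, only nine evaluations happen during training, which is why the table has nine rows.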

In [11]:
wandb.finish()


Run history:

| Metric | Trend |
|--------|-------|
| eval/loss | █▅▃▂▂▂▁▁▁ |
| eval/runtime | ▇▂▂▂▄▄▁█▂ |
| eval/samples_per_second | ▂▇▇▇▅▅█▁▇ |
| eval/steps_per_second | ▂▇▇▇▅▅█▁▇ |
| train/epoch | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
| train/global_step | ▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███ |
| train/grad_norm | ▂▁▁▂▁▂▅▅▄▁█▂▄▂▄▄▅▃▄▄▄▃▃▇▂▂▃▆▆▆▆▇▄▁▆█▂▂▄▃ |
| train/learning_rate | ████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁ |
| train/loss | ▇█▆▄▄▄▅▃▅▂▇▅▅▃▄▄▄▄▄▃▄▃▂▅▁▃▂▃▅▃▅▄▃▂▃▅▃▃▂▃ |
Run summary:

| Metric | Value |
|--------|-------|
| eval/loss | 0.26846 |
| eval/runtime | 15.8914 |
| eval/samples_per_second | 18.878 |
| eval/steps_per_second | 2.391 |
| total_flos | 2.8902671751020544e+16 |
| train/epoch | 5.0 |
| train/global_step | 1875.0 |
| train/grad_norm | 0.05636 |
| train/learning_rate | 0.0 |
| train/loss | 0.287 |


In [13]:
from transformers import pipeline, AutoModelForSeq2SeqLM
from peft import PeftModel, PeftConfig
# Step 1316 had the lowest validation loss (0.266688) in the training log above.
ckpt_path = "./lora-flan-t5/checkpoint-1316"
config = PeftConfig.from_pretrained(ckpt_path)
model = AutoModelForSeq2SeqLM.from_pretrained(config.base_model_name_or_path, load_in_8bit=True)
model = PeftModel.from_pretrained(model, ckpt_path)
model = model.merge_and_unload(safe_merge=True)

generator = pipeline("text2text-generation", model=model, device_map="auto", tokenizer=tokenizer)
INSTRUCTION = prefix




The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.
`low_cpu_mem_usage` was None, now set to True since model is quantized.
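The cell above loads `checkpoint-1316`. A small sketch of how that choice falls out of the validation losses logged during training is below. The `(step, validation loss)` pairs are copied from the training table, and the path format mirrors the `output_dir` used in this notebook.

```python
# (step, validation loss) pairs copied from the training log.
log = [
    (188, 0.312162), (376, 0.292979), (564, 0.278271),
    (752, 0.271645), (940, 0.274238), (1128, 0.271458),
    (1316, 0.266688), (1504, 0.267584), (1692, 0.268464),
]

# Pick the evaluation step with the lowest validation loss.
best_step, best_loss = min(log, key=lambda row: row[1])
ckpt_path = f"./lora-flan-t5/checkpoint-{best_step}"
print(ckpt_path, best_loss)  # ./lora-flan-t5/checkpoint-1316 0.266688
```

The validation loss creeps back up slightly after step 1316, so the later checkpoints are not obviously better even though training continued.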


In [15]:
original_text = input("Enter text to correct: ")
prompt = f'{INSTRUCTION}{original_text}'
outputs = generator(prompt, do_sample=True, max_new_tokens=512, num_beams=5, num_return_sequences=1)
print(outputs[0])

Enter text to correct:  My favourite sport is volleyball because I love plays with my friends. Volleyball is a sport play every place, when I travel on the beach I like plays with my sister in the sand and after we are going to the sea. It is very funny. when I was young I like plays with the ball in the playground and my friend and I played using the soccer goals as a network of volleyball


{'generated_text': 'My favourite sport is volleyball because I love playing with my friends. Volleyball is a sport played everywhere. When I travel on the beach, I like playing with my sister in the sand and after we go to the sea. It is very funny. When I was young, I liked playing with the ball in the playground and my friend and I played using soccer goals as a volleyball net.'}
