# Grammar Error Correction using FLAN-T5
The code below mainly follows Hugging Face's translation [tutorial](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation.ipynb#scrollTo=MOsHUjgdIrIW).

### Install packages and imports

In [1]:
# %pip install transformers[sentencepiece] datasets evaluate peft accelerate protobuf bitsandbytes

In [3]:
import transformers
import datasets
print(transformers.__version__)

4.38.2


## Preprocessing

### Create `correct_text` column
The dataset fields we care about are structured as follows:

- **text**: grammatically incorrect text (input to the model)
- **edits**: a set of edits, each a span `[start:end]` of the original text that should be replaced by the corresponding `edits.text`
  - **start**: start index of each edit, as a list of integers
  - **end**: end index of each edit, as a list of integers
  - **text**: replacement text for each edit, as a list of strings

Below I'm creating a field `correct_text` by applying the edits to `text`; it will be used as the labels during training.
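To see why the edits are applied from last to first, here is a toy example with hand-written offsets (hypothetical data, not taken from `wi_locness`). The helper `apply_edits` is a hypothetical name used only for this illustration; it sorts by start index instead of assuming the edits are already ordered.

```python
def apply_edits(text, starts, ends, corrections):
    """Replace text[start:end] with each correction, working right to left.

    Going from the end of the string backwards means a replacement never
    shifts the character offsets of the edits that still have to be applied.
    """
    edits = sorted(zip(starts, ends, corrections), key=lambda e: e[0], reverse=True)
    for start, end, correction in edits:
        if correction is None:  # skip edits without replacement text, as the notebook does
            continue
        text = text[:start] + correction + text[end:]
    return text


sentence = "She are baking cake."
# Hypothetical edits: one with no correction text, "are" -> "is", and insert "a " before "cake".
starts = [0, 4, 15]
ends = [3, 7, 15]
corrections = [None, "is", "a "]

print(apply_edits(sentence, starts, ends, corrections))  # She is baking a cake.

# Applying the same edits left to right corrupts the text: the first replacement
# shortens the string by one character, so the later offset no longer lines up.
forward = sentence
for start, end, correction in zip(starts, ends, corrections):
    if correction is not None:
        forward = forward[:start] + correction + forward[end:]
print(forward)  # She is baking ca ake.
```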


In [4]:
from datasets import load_dataset
raw_datasets = load_dataset("wi_locness", 'wi')

def make_correct_sentence_column(example):
    corrected_text = example["text"]
    edits = zip(example["edits"]["start"], example["edits"]["end"], example["edits"]["text"])
    # Apply edits from last to first so that replacing one span does not shift the
    # character offsets of the spans still to be applied (assumes edits are listed in order).
    for start, end, correction in reversed(list(edits)):
        if correction is None:
            continue  # TODO: what to do with None?
        corrected_text = corrected_text[:start] + correction + corrected_text[end:]
    return {'correct_text': corrected_text}

raw_datasets = raw_datasets.map(make_correct_sentence_column, remove_columns=["id", "userid", "cefr", "edits"])
print(raw_datasets['train'][0])



{'text': 'My town is a medium size city with eighty thousand inhabitants. It has a high density population because its small territory. Despite of it is an industrial city, there are many shops and department stores.  I recommend visiting the artificial lake in the certer of the city which is surrounded by a park. Pasteries are very common and most of them offer the special dessert from the city. There are a comercial zone along the widest street of the city where you can find all kind of establishments: banks, bars, chemists, cinemas, pet shops, restaurants, fast food restaurants, groceries, travel agencies, supermarkets and others. Most of the shops have sales and offers at least three months of the year: January, June and August. The quality of the products and services are quite good, because there are a huge competition, however I suggest you taking care about some fakes or cheats.', 'correct_text': 'My town is a medium-sized city with eighty thousand inhabitants. It has a high-de

### Tokenization
The inputs and labels for the model are as follows:

- **Inputs** = prefix (see note at end) + grammatically incorrect sentence
- **Labels** = corrected sentence

I'm using the pretrained `sentencepiece` tokenizer (loaded with `AutoTokenizer.from_pretrained(...)`) to tokenize the inputs and labels, and then drop the original columns.

**Note on prefix**: T5 models are trained with a task prefix prepended to each input, and FLAN-T5 is additionally instruction-tuned on a large collection of prompted tasks. I abridged some grammar-related prompts from [here](https://github.com/google-research/FLAN/blob/main/flan/v2/flan_templates_branched.py) to come up with `Produce a grammatically correct version of this text:`. The prompt was chosen fairly unscientifically; more work could go into crafting a good one.

In [19]:
from transformers import AutoTokenizer
model_checkpoint = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
prefix = "Produce a grammatically correct version of this text: "

def preprocess_function(examples):
    # Prepend the task prompt to each grammatically incorrect text
    inputs = [prefix + ex for ex in examples['text']]
    targets = examples['correct_text']
    # Tokenize inputs and targets in one call; the targets are returned as "labels".
    # Both are truncated to 128 tokens, which cuts off longer essays.
    model_inputs = tokenizer(inputs, text_target=targets, max_length=128, truncation=True)
    return model_inputs

preprocess_function(raw_datasets['train'][:2])


{'input_ids': [[22966, 15, 3, 9, 3, 16582, 144, 6402, 2024, 988, 13, 48, 1499, 10, 499, 1511, 19, 3, 9, 2768, 812, 690, 28, 2641, 63, 7863, 21155, 5, 94, 65, 3, 9, 306, 11048, 2074, 250, 165, 422, 9964, 5, 3, 4868, 13, 34, 19, 46, 2913, 690, 6, 132, 33, 186, 5391, 11, 3066, 3253, 5, 27, 1568, 3644, 8, 7353, 6957, 16, 8, 12276, 49, 13, 8, 690, 84, 19, 3, 8623, 57, 3, 9, 2447, 5, 10180, 4074, 7, 33, 182, 1017, 11, 167, 13, 135, 462, 8, 534, 7737, 45, 8, 690, 5, 290, 33, 3, 9, 15173, 2901, 590, 8, 1148, 7, 17, 2815, 13, 8, 690, 213, 25, 54, 253, 66, 773, 13, 8752, 7, 10, 5028, 6, 6448, 6, 3, 1], [22966, 15, 3, 9, 3, 16582, 144, 6402, 2024, 988, 13, 48, 1499, 10, 6656, 65, 112, 293, 1390, 5, 2449, 241, 12, 36, 3, 9, 2472, 6, 119, 241, 12, 36, 3, 9, 3145, 5, 27, 43, 82, 293, 515, 396, 68, 27, 278, 31, 17, 337, 12, 135, 6, 27, 241, 12, 582, 3, 9, 3559, 343, 5, 7301, 38, 3, 9, 3559, 343, 19, 182, 2949, 250, 186, 2081, 5, 1485, 6, 27, 56, 43, 46, 1004, 12, 619, 7979, 5, 5212, 6, 27, 54, 1111, 

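NB: The attention mask **marks which positions are padding** (0) and which are real tokens (1), so that the model does not attend to the padded positions. No padding is applied at this stage, so the masks here are all 1s; padding is added later, per batch.
More details can be found [here](https://huggingface.co/docs/transformers/en/glossary).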
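Conceptually, "not attending" means the masked scores are pushed to a very large negative value (here literally `-inf`) before the softmax, so those positions get zero attention weight. The sketch below is a plain-Python illustration of that idea with a hypothetical `masked_softmax` helper and made-up scores; the actual T5 implementation works on tensors and uses the dtype's minimum value rather than `-inf`.

```python
import math


def masked_softmax(scores, attention_mask):
    """Softmax over attention scores, giving zero weight where the mask is 0."""
    # exp(-inf) == 0.0, so masked positions contribute nothing to the softmax.
    masked = [s if m == 1 else -math.inf for s, m in zip(scores, attention_mask)]
    max_score = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(s - max_score) for s in masked]
    total = sum(exps)
    return [e / total for e in exps]


# Hypothetical padded sequence: 3 real tokens followed by 2 pad tokens (T5's pad id is 0).
input_ids = [27, 43, 1, 0, 0]
attention_mask = [1 if tok != 0 else 0 for tok in input_ids]
print(attention_mask)  # [1, 1, 1, 0, 0]

# Even though the padded positions have the highest raw scores, they get zero weight.
weights = masked_softmax([0.5, 1.0, -0.2, 2.0, 3.0], attention_mask)
print([round(w, 3) for w in weights])
```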

In [13]:
tokenized_datasets = raw_datasets.map(preprocess_function, batched=True, remove_columns=["text", "correct_text"])
tokenized_datasets["train"][0].keys()


dict_keys(['input_ids', 'attention_mask', 'labels'])

### Batching
The training examples have different lengths, so
`DataCollatorForSeq2Seq` is used to pad them dynamically to the longest example within each batch: inputs are padded with the tokenizer's pad token and labels with -100, which the loss ignores.

This [api reference](https://huggingface.co/docs/transformers/en/main_classes/data_collator) contains further info about how it works.
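A simplified sketch of that dynamic padding on plain Python lists, assuming right-side padding, T5's pad id of 0 and the collator's default label pad value of -100. `pad_batch` is a hypothetical helper for illustration only; the real collator delegates input padding to the tokenizer and returns tensors.

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Pad tokenized examples to the longest inputs/labels in the batch."""
    max_input = max(len(f["input_ids"]) for f in features)
    max_label = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_input - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_pad)
        # 1 for real tokens, 0 for the padding just added
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_pad)
        # Labels are padded with -100 so the padded positions are ignored by the loss
        n_label_pad = max_label - len(f["labels"])
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch


# Hypothetical token ids (1 is T5's end-of-sequence token)
features = [
    {"input_ids": [10, 11, 1], "labels": [20, 1]},
    {"input_ids": [12, 1], "labels": [21, 22, 23, 1]},
]
batch = pad_batch(features)
print(batch)
```

Padding per batch rather than to a global maximum keeps short batches short, which saves computation.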

In [14]:
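from transformers import DataCollatorForSeq2Seq
# Pads inputs with the tokenizer's pad token and labels with -100 (the default label_pad_token_id).
# The `model` argument expects a model object, not a checkpoint name; T5 builds its
# decoder inputs from the labels itself, so it is left out here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)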

## Model
The [model docs on T5](https://huggingface.co/docs/transformers/model_doc/t5) and [model card](https://huggingface.co/google-t5/t5-small) provide more info on the model.
### Zero-shot inference
I load a pretrained Flan-T5 and run zero-shot inference with the task prompt `Produce a grammatically correct version of this text:`.

In [9]:
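from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig

# Load the model with 8-bit weights (requires bitsandbytes and a CUDA GPU);
# device_map="auto" places the weights on the available device(s)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint,
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
    device_map="auto",
)
input_text = f"{prefix}She are baking cake."
# Move the input ids to the same device as the model
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

# max_new_tokens overrides the short default generation length (max_length=20)
outputs = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(outputs[0]))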

<pad> She is baking cake.</s>


The model corrects the subject-verb agreement zero-shot, though it leaves the missing article in "baking cake" untouched, and this is only a single, very simple example. As an extension, we could measure its zero-shot performance on the whole `WI_LOCNESS` dataset, with some prompt engineering.

### LoRA setup

Given that our dataset is quite small and GPU resources are limited, it probably doesn't make sense to train all the weights of T5. Instead, I'm trying a parameter-efficient fine-tuning (PEFT) method called **LoRA (Low-Rank Adaptation)**, which freezes the pretrained weights and trains small low-rank update matrices added to selected layers. I found [this article](https://towardsdatascience.com/understanding-lora-low-rank-adaptation-for-finetuning-large-models-936bce1a07c6) useful for understanding how it works.
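A minimal NumPy sketch of a LoRA-adapted linear layer, assuming PEFT's default scaling of `lora_alpha / r` and a zero-initialized `B` (dropout is omitted). The sizes are chosen to mirror a flan-t5-small query/value projection (512 in, 384 out); this illustrates the idea and is not PEFT's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_out, r, alpha = 512, 384, 16, 32  # assumed sizes and hyperparameters

W0 = rng.normal(size=(d_out, d_in))         # frozen pretrained weight
A = rng.normal(scale=0.01, size=(r, d_in))  # trainable, small random init
B = np.zeros((d_out, r))                    # trainable, zero init so the update starts at 0


def lora_forward(x):
    """h = W0 x + (alpha / r) * B A x; only A and B would receive gradients."""
    return W0 @ x + (alpha / r) * (B @ (A @ x))


x = rng.normal(size=d_in)
# With B = 0 the adapted layer initially behaves exactly like the pretrained one.
assert np.allclose(lora_forward(x), W0 @ x)

full = d_out * d_in      # parameters in the frozen matrix
lora = r * (d_in + d_out)  # parameters in A and B together
print(f"full matrix: {full:,} params, LoRA factors: {lora:,} params")
```

Only the two thin factors are trained, which is where the large reduction in trainable parameters comes from.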

The LoRA setup and training code below mainly follows this [tutorial](https://www.philschmid.de/fine-tune-flan-t5-peft) and the [docs](https://huggingface.co/docs/diffusers/v0.19.3/en/training/lora).

Also, I'm loading Flan-T5 with 8-bit quantized weights (`int8` instead of the checkpoint's `fp32`) to reduce memory usage. Note that bitsandbytes int8 quantization mainly saves memory; it is generally not faster than fp16 inference.
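The core idea of int8 weight quantization can be sketched with symmetric absmax scaling: map the largest absolute value to 127 and round everything else onto that grid. This is a simplified illustration with hypothetical helper names and made-up weights; bitsandbytes' LLM.int8() additionally uses vector-wise scales and keeps outlier features in fp16.

```python
import numpy as np


def absmax_quantize(w):
    """Symmetric absmax quantization of a float array to int8."""
    scale = 127.0 / np.max(np.abs(w))  # largest magnitude maps to +/-127
    q = np.round(w * scale).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    """Map int8 codes back to approximate float values."""
    return q.astype(np.float32) / scale


w = np.array([0.12, -0.5, 0.33, 0.9, -0.07], dtype=np.float32)
q, scale = absmax_quantize(w)
w_hat = dequantize(q, scale)
print(q)                          # int8 codes: 1 byte each instead of 4 for fp32
print(np.max(np.abs(w - w_hat)))  # rounding error, at most 0.5 / scale
```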

In [17]:
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
model = prepare_model_for_kbit_training(model)
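lora_config = LoraConfig(
    r=16,                       # rank of the low-rank update (matches the trainable-parameter count below)
    lora_alpha=32,              # the update is scaled by lora_alpha / r
    target_modules=["q", "v"],  # T5 attention query and value projections
    lora_dropout=0.05,
    bias="none",                # don't train bias terms
    task_type=TaskType.SEQ_2_SEQ_LM,
)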
lora_model = get_peft_model(model, lora_config)

lora_model.print_trainable_parameters()

trainable params: 688,128 || all params: 77,649,280 || trainable%: 0.8862001038515747


Note that LoRA trains only a very small fraction of the parameters, under 1% of the original model size. This will hopefully let us train the model with less memory and without overfitting to our small dataset.
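The printed count can be checked by hand. The sketch assumes flan-t5-small's config (`d_model=512`, `num_heads=6`, `d_kv=64`, 8 encoder and 8 decoder blocks) and that LoRA targets the `q` and `v` projections of every attention module: 8 encoder self-attention, 8 decoder self-attention and 8 decoder cross-attention. `lora_params` is a hypothetical helper.

```python
# flan-t5-small dimensions (assumed from its config)
d_model, num_heads, d_kv = 512, 6, 64
inner_dim = num_heads * d_kv  # 384: output size of the q and v projections

# Attention modules with "q" and "v": encoder self-attn, decoder self-attn, decoder cross-attn
n_attention_blocks = 8 + 8 + 8


def lora_params(r, targets_per_block=2):
    """Each adapted projection adds A (r x d_model) and B (inner_dim x r)."""
    per_projection = r * d_model + inner_dim * r
    return n_attention_blocks * targets_per_block * per_projection


# Base model size: all params (77,649,280) minus trainable (688,128) from the printed output
base_params = 76_961_152
for r in (16, 32):
    trainable = lora_params(r)
    total = base_params + trainable
    print(f"r={r}: trainable={trainable:,}, trainable%={100 * trainable / total:.4f}")
```

The printed 688,128 trainable parameters correspond to `r=16`; doubling the rank to 32 doubles the adapter size.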

In [16]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

output_dir = "lora-flan-t5"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,  # higher than usual for full fine-tuning, since only the LoRA weights are trained
    num_train_epochs=5,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=500,
    save_strategy="no",
    report_to="tensorboard",
)

# Disable the KV cache during training: it is only useful for generation and conflicts with
# the gradient checkpointing enabled by prepare_model_for_kbit_training (re-enable for inference)
model.config.use_cache = False

# Create Trainer instance with the PEFT-wrapped model so that only the LoRA weights are trained
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)


In [18]:
trainer.train()

Step,Training Loss


KeyboardInterrupt: 