# Grammar Error Correction using FLAN-T5
The code below mainly follows Hugging Face's translation [tutorial](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/translation.ipynb#scrollTo=MOsHUjgdIrIW).

### install packages

In [1]:
# %pip install transformers[sentencepiece] datasets evaluate peft accelerate protobuf bitsandbytes

## Preprocessing

### Dataset structure
The fields we care about in the dataset are structured as follows:
- **text**: grammatically incorrect text (input to the model)
- **edits**: Each edit is a span `[start:end]` where the original text should be replaced by `edits.text`
  - **start**: start indexes of each edit as a list of integers
  - **end**: end indexes of each edit as a list of integers
  - **text**: the text content of each edit as a list of strings
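
As a rough illustration of how these fields fit together, here is a minimal sketch that applies a record's edits to its text. The record is made up for illustration (it is not taken from the dataset), and `apply_edits` and its `placeholder` argument are hypothetical names, not part of any library.

```python
def apply_edits(text, edits, placeholder="<unk>"):
    """Return `text` with every span [start:end] replaced by its correction.

    `placeholder` is an assumed stand-in for corrections that are None.
    """
    spans = zip(edits["start"], edits["end"], edits["text"])
    # Apply the rightmost edit first so the offsets of earlier edits stay valid.
    for start, end, correction in sorted(spans, key=lambda s: s[0], reverse=True):
        if correction is None:
            correction = placeholder
        text = text[:start] + correction + text[end:]
    return text


# Hypothetical record in the same shape as the dataset fields described above.
example = {
    "text": "She are baking cake.",
    "edits": {"start": [4, 15], "end": [7, 15], "text": ["is", "a "]},
}
corrected = apply_edits(example["text"], example["edits"])
print(corrected)
```

An edit with `start == end` is a pure insertion, which is why the second edit adds text without removing any.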

### Tokenize + Create labels

The inputs and labels for the transformer are as follows:

- **Inputs** = grammatically incorrect text
- **Labels** = corrected text

1. I tokenize each input text and divide it into smaller chunks of at most `max_len` tokens. A `stride` creates overlaps between consecutive windows, so that the model can learn grammar errors that occur across chunk boundaries. These chunks are the inputs to the model.
2. Then, I compute the corrected text for each chunk by applying the edits that fall entirely inside it (edits that cross a chunk boundary are skipped) and tokenize the result. These become the labels for the model.

- At the moment, I'm replacing corrections that are `None` with `tokenizer.unk_token`. An example where this occurs is printed at the end.
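
To make the windowing in step 1 concrete, here is a minimal pure-Python sketch of overlapping chunks. It is a simplification under stated assumptions: it works on a plain list of token ids, ignores special tokens such as `</s>` (which the real tokenizer reserves room for), and `sliding_windows` is a hypothetical helper, not a tokenizer API.

```python
def sliding_windows(token_ids, max_len, stride):
    """Split `token_ids` into chunks of at most `max_len` tokens.

    Consecutive chunks share `stride` tokens, so each new chunk starts
    `max_len - stride` tokens after the previous one.
    """
    if not 0 <= stride < max_len:
        raise ValueError("stride must satisfy 0 <= stride < max_len")
    step = max_len - stride
    windows = []
    start = 0
    while True:
        windows.append(token_ids[start:start + max_len])
        if start + max_len >= len(token_ids):
            break  # this window already reaches the end of the sequence
        start += step
    return windows


tokens = list(range(10))  # stand-in for real token ids
windows = sliding_windows(tokens, max_len=4, stride=2)
for window in windows:
    print(window)
```

With `stride = max_len // 2`, as in the notebook, every token except those at the very start and end of a text appears in two chunks.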

In [24]:
from datasets import load_dataset
from transformers import AutoTokenizer

raw_datasets = load_dataset("wi_locness", "wi")
model_checkpoint = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
prefix = "Produce a grammatically correct version of this text: "

## these are hyperparams that can be tuned
max_len = 256
stride = max_len // 2  # half length of prev window included in next window


def preprocess_function(examples):
    # how to handle prefix? not handled currently (but only relevant for T5 finetuning)
    inputs = tokenizer(
        text=examples["text"],
        max_length=max_len,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        stride=stride,
    )
    labels_out = []
    overflow_to_sample_mapping = inputs.pop("overflow_to_sample_mapping")
    offset_mapping = inputs.pop("offset_mapping")
    for i in range(len(inputs["input_ids"])):
        example_idx = overflow_to_sample_mapping[i]

        start_idx = offset_mapping[i][0][0]
        # the last token is </s>, so use the end offset of the second-to-last token
        end_idx = offset_mapping[i][-2][1]
        edits = examples["edits"][example_idx]
        corrected_text = examples["text"][example_idx][start_idx:end_idx]
        for start, end, correction in reversed(
            list(zip(edits["start"], edits["end"], edits["text"]))
        ):

            if start < start_idx or end > end_idx:
                continue
            start_offset = start - start_idx  # >= 0
            end_offset = end - start_idx
            if correction is None:
                correction = tokenizer.unk_token
            corrected_text = (
                corrected_text[:start_offset] + correction + corrected_text[end_offset:]
            )
        labels_out.append(corrected_text)
    labels_out = tokenizer(labels_out, max_length=512, truncation=True)
    inputs["labels"] = labels_out["input_ids"]
    return inputs


tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=4, 
    remove_columns=raw_datasets["train"].column_names,
)
print(tokenizer.decode(tokenized_datasets["train"][500]["input_ids"]))
print(tokenizer.decode(tokenized_datasets["train"][500]["labels"]))

depend on the public transport of their journey. As private transport is increasing day by day, it's surprises that there may not be any public transport in up coming days. Secondly, if the private transportation is growing faster, the traffic is also going to create a major problem, which effects the global warming. And the income that is getting through this servicing transport will automatically dropdown, which create a problem to the economy of the government. Another issue we can see here is unemployment. Most illiterate people choose their proficiency as public transportations like,bus service, cab service, auto service....etc, in order to full fill their basic needs.People with more unemployment may also lead to a crime rate. I hope, giving equal priority for both public and private transportation makes healthy. And we can see that controlling traffic is not going out of the hands. </s>
dependent on public transport on their journey. As private transport is increasing day by day

**Note on prefix choice**: T5 models are pretrained to follow instructions given in a task prefix before each input. I abridged some grammar-related prompts from [here](https://github.com/google-research/FLAN/blob/main/flan/v2/flan_templates_branched.py) to come up with `Produce a grammatically correct version of this text:`. The prompt was chosen pretty unscientifically, so more work could be done to craft a good prompt.
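
The preprocessing above does not yet prepend the prefix. One possible way to handle it, sketched here as an assumption rather than what the notebook does, is to prepend the prefix to the raw text and shift the character offsets of every edit by the prefix length so they still point at the same spans. `add_prefix` is a hypothetical helper and the record is made up.

```python
PREFIX = "Produce a grammatically correct version of this text: "


def add_prefix(text, edits, prefix=PREFIX):
    """Prepend `prefix` to `text` and shift edit offsets to match."""
    shift = len(prefix)
    shifted_edits = {
        "start": [s + shift for s in edits["start"]],
        "end": [e + shift for e in edits["end"]],
        "text": list(edits["text"]),
    }
    return prefix + text, shifted_edits


# Hypothetical record, not taken from the dataset.
text = "She are baking cake."
edits = {"start": [4], "end": [7], "text": ["is"]}

new_text, new_edits = add_prefix(text, edits)
start, end = new_edits["start"][0], new_edits["end"][0]
print(new_text[start:end])  # the span the edit refers to
```

This leaves open how the prefix should interact with chunking (for example, whether every chunk or only the first one should carry it), and whether the labels should exclude the prefix.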

The corrected text computed in the preprocessing step is stored in the `labels` field, which is used as the target during training.

### Batching
The training dataset has sentences of different lengths, so `DataCollatorForSeq2Seq` is used to automatically pad the inputs and labels to the same length within each batch for training.

This [API reference](https://huggingface.co/docs/transformers/en/main_classes/data_collator) contains further info about how it works.

In [25]:
from transformers import DataCollatorForSeq2Seq

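# The `model` argument is meant to be a model instance (used to prepare decoder inputs),
# not a checkpoint name, so only the tokenizer is passed here.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)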
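
To illustrate the idea behind the collator, here is a minimal pure-Python sketch of per-batch padding. It is not the library implementation: `pad_batch` is a hypothetical helper, and the defaults assume T5's pad token id of `0` and the collator's default label padding value of `-100` (which the loss ignores).

```python
def pad_batch(features, pad_token_id=0, label_pad_token_id=-100):
    """Right-pad inputs and labels to the longest sequence in the batch."""
    max_input_len = max(len(f["input_ids"]) for f in features)
    max_label_len = max(len(f["labels"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_input_pad = max_input_len - len(f["input_ids"])
        n_label_pad = max_label_len - len(f["labels"])
        batch["input_ids"].append(f["input_ids"] + [pad_token_id] * n_input_pad)
        # 1 marks real tokens, 0 marks padding
        batch["attention_mask"].append([1] * len(f["input_ids"]) + [0] * n_input_pad)
        batch["labels"].append(f["labels"] + [label_pad_token_id] * n_label_pad)
    return batch


# Made-up token ids for two examples of different lengths.
features = [
    {"input_ids": [5, 6, 7, 1], "labels": [8, 9, 1]},
    {"input_ids": [5, 1], "labels": [8, 9, 10, 11, 1]},
]
batch = pad_batch(features)
print(batch)
```

Padding per batch rather than to a global maximum keeps short batches small, which matters when the chunks vary a lot in length.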

## Model
The [model docs on T5](https://huggingface.co/docs/transformers/model_doc/t5) and [model card](https://huggingface.co/google-t5/t5-small) provide more info on the model.

### Zero-shot inference
I'm loading a pretrained Flan-T5 and running zero-shot inference with the task prompt `Produce a grammatically correct version of this text:`.

In [9]:
from transformers import AutoModelForSeq2SeqLM

# load the model in int8 and let accelerate place it on the available device(s)
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_checkpoint, load_in_8bit=True, device_map="auto"
)
input_text = f"{prefix}She are baking cake."
# move the inputs to the model's device before generating
input_ids = tokenizer(input_text, return_tensors="pt").input_ids.to(model.device)

outputs = model.generate(input_ids)
print(tokenizer.decode(outputs[0]))

<pad> She is baking cake.</s>


It seems to be capable of correcting subject-verb agreement zero-shot (though on a very simple example). We could consider measuring how it performs zero-shot on the whole `WI_LOCNESS` dataset with some prompt engineering as an extension.

### LoRA setup

Given that our dataset is quite small and GPU resources are limited, it probably doesn't make sense to train all the weights of T5. I'm trying a parameter-efficient finetuning (PEFT) method called **LoRA (Low Rank Adaptation)**. I found [this article](https://towardsdatascience.com/understanding-lora-low-rank-adaptation-for-finetuning-large-models-936bce1a07c6) useful in understanding how it works.

The LoRA creation and training code below follows mainly from this [tutorial](https://www.philschmid.de/fine-tune-flan-t5-peft) and the [docs](https://huggingface.co/docs/diffusers/v0.19.3/en/training/lora).

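Also, I'm using a quantized Flan-T5 (`int8` params instead of `fp16`) in order to (a) reduce memory usage and (b) hopefully increase inference speed, although `int8` inference through `bitsandbytes` is not guaranteed to be faster than `fp16`.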

In [17]:
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training

model = prepare_model_for_kbit_training(model)
lora_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_2_SEQ_LM,
)
lora_model = get_peft_model(model, lora_config)

lora_model.print_trainable_parameters()

trainable params: 688,128 || all params: 77,649,280 || trainable%: 0.8862001038515747


Note that the share of trainable parameters when using LoRA is very small, under 1% of the model's parameters in the output above. This scheme will hopefully help us train the model faster and without overfitting to our small dataset.
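
As a sanity check on the trainable-parameter count, here is a small arithmetic sketch. LoRA adds two matrices to each targeted linear layer, one of shape `r x d_in` and one of shape `d_out x r`. The dimensions below are my assumptions about `flan-t5-small` (they are not printed anywhere in this notebook): `d_model = 512`, an attention inner dimension of `6 heads x 64 = 384`, and 48 targeted `q`/`v` projections (8 encoder self-attention, 8 decoder self-attention and 8 decoder cross-attention blocks, two projections each).

```python
def lora_trainable_params(r, d_in, d_out, n_modules):
    """Parameters added by LoRA: an (r x d_in) and a (d_out x r) matrix per module."""
    return n_modules * r * (d_in + d_out)


# Assumed flan-t5-small dimensions (see the lead-in above).
D_MODEL = 512
INNER_DIM = 6 * 64  # num_heads * d_kv
N_MODULES = 24 * 2  # 24 attention blocks, q and v in each

counts = {r: lora_trainable_params(r, D_MODEL, INNER_DIM, N_MODULES) for r in (16, 32)}
for r, n in counts.items():
    print(f"r={r}: {n:,} trainable LoRA parameters")
```

If these assumed dimensions are right, the 688,128 figure in the printed output corresponds to `r=16`, while `r=32` as written in the config would give about twice as many trainable parameters. It may be worth re-running the cell to confirm which rank the output came from.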

### Training

In [16]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

output_dir = "lora-flan-t5"

# Define training args
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    learning_rate=1e-3,  # higher learning rate
    num_train_epochs=5,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=500,
    save_strategy="no",
    report_to="tensorboard",
)

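# Create Trainer instance (train the LoRA-wrapped model rather than the bare base model)
trainer = Seq2SeqTrainer(
    model=lora_model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],  # the "wi" config has "train" and "validation" splits
)
# the key/value cache only helps generation, so disable it during training
model.config.use_cache = False

# NOTE: this configuration is created but not passed to the trainer
from accelerate.utils import DataLoaderConfiguration

dataloader_config = DataLoaderConfiguration(
    dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True
)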


**ETA for training**: showed ~5 hours (tested on a rented RTX 3070 instance). I haven't run it to completion yet, just tested that it works. 

In [18]:
trainer.train()

Step,Training Loss


KeyboardInterrupt: 