# GEC Baseline - Transformer

## Downloading required packages

In [None]:
!pip install datasets
!pip install wandb

In [None]:
!pip install -U accelerate
!pip install -U transformers

## Importing required packages

In [None]:
import os

import torch

import wandb

import transformers
from transformers import AutoTokenizer
from transformers import DataCollatorForSeq2Seq
from transformers import AutoModelForSeq2SeqLM, AutoConfig
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

import datasets
from datasets import load_dataset, load_metric

## WandB Login

In [None]:
# Log in to Weights & Biases so the training run below can be reported there.
wandb.login()

## Data Preprocessing


### Dataset structure
The fields we care about in the dataset are structured as below
- **text**: grammatically incorrect text (input to model)
- **edits**: Each edit is a span `[start:end]` where the original text should be replaced by `edits.text`
  - **start**: start indexes of each edit as a list of integers
  - **end**: end indexes of each edit as a list of integers
  - **text**: the text content of each edit as a list of strings

### Tokenize + Create labels

The inputs and labels for the transformer are as below

- **Inputs** = grammatically incorrect text
- **Labels** = corrected text

1. I tokenize each input text and divide it into smaller chunks of at most `max_len` tokens. A `stride` is used to create overlaps between consecutive chunks (so that the model can learn grammar errors that occur across chunk boundaries). These chunks are the inputs to the model.
2. Then, I compute the appropriate corrected text for each "chunk" and tokenize these. These become the labels for the model.

- At the moment, I'm replacing corrections that are `None` with `tokenizer.unk_token`. An example where this occurs is printed at the end

In [None]:
# Load the "wi" configuration of the "wi_locness" dataset; the splits used
# later in this notebook are "train" and "validation".
raw_datasets = load_dataset("wi_locness", "wi")

In [None]:
# Checkpoint name shared by the tokenizer here and the model config below.
model_checkpoint = "t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [None]:
## these are hyperparams that can be tuned
# max_len is passed to the tokenizer as `max_length`: upper bound on the
# number of tokens in each input window.
max_len = 256
stride = max_len // 2  # half length of prev window included in next window

def preprocess_function(examples):
    """Split each text into overlapping token windows and build their labels.

    Every text in the batch is tokenized into windows of at most ``max_len``
    tokens, with ``stride`` passed to the tokenizer to make consecutive
    windows overlap.  For each window, the character span of the source text
    it covers is recovered from the token offsets, every edit that lies
    entirely inside that span is applied, and the resulting corrected text is
    tokenized to become the window's label.

    Relies on the module-level ``tokenizer``, ``max_len`` and ``stride``.

    Args:
        examples: A batch (dict of columns) containing
            ``"text"``: list of source strings, and
            ``"edits"``: list of dicts with parallel ``"start"``, ``"end"``
            and ``"text"`` lists (character offsets into the source string
            and the replacement text, which may be ``None``).

    Returns:
        The tokenizer output for the windows (one row per window, so possibly
        more rows than input examples) with an added ``"labels"`` entry that
        holds the token ids of each window's corrected text.
    """
    # how to handle prefix? not handled currently (but only relevant for T5 finetuning)
    inputs = tokenizer(
        text=examples["text"],
        max_length=max_len,
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        stride=stride,
    )

    # These two entries are bookkeeping only and must not reach the model.
    overflow_to_sample_mapping = inputs.pop("overflow_to_sample_mapping")
    offset_mapping = inputs.pop("offset_mapping")

    labels_out = []
    for example_idx, offsets in zip(overflow_to_sample_mapping, offset_mapping):
        source_text = examples["text"][example_idx]
        edits = examples["edits"][example_idx]

        if len(offsets) < 2:
            # The window holds nothing but <eos> (e.g. an empty text), so
            # there is no "second last" token to read the end offset from.
            start_idx = end_idx = 0
        else:
            start_idx = offsets[0][0]
            # last token is <eos>, so we care about second last tok offset
            end_idx = offsets[-2][1]
        corrected_text = source_text[start_idx:end_idx]

        # Apply edits from last to first so that replacing a span does not
        # shift the offsets of the edits still to be applied (assumes the
        # edits are listed in ascending position order -- TODO confirm).
        for start, end, correction in reversed(
            list(zip(edits["start"], edits["end"], edits["text"]))
        ):
            # Edits that are not fully contained in this window are skipped,
            # including those that straddle a window boundary.
            if start < start_idx or end > end_idx:
                continue
            start_offset = start - start_idx  # >= 0
            end_offset = end - start_idx
            if correction is None:
                correction = tokenizer.unk_token
            corrected_text = (
                corrected_text[:start_offset] + correction + corrected_text[end_offset:]
            )

        labels_out.append(corrected_text)

    labels_out = tokenizer(labels_out, max_length=512, truncation=True)
    inputs["labels"] = labels_out["input_ids"]

    return inputs


In [None]:
# Run preprocess_function over every split, in batches, using 4 worker
# processes.  preprocess_function can emit several windows per example, so
# the original per-example columns are removed instead of being carried over.
tokenized_datasets = raw_datasets.map(
    preprocess_function,
    batched=True,
    num_proc=4,
    remove_columns=raw_datasets["train"].column_names,
)

### Batching
`DataCollatorForSeq2Seq` will automatically pad the tokenized inputs and labels in each batch to the same length and batch them for training

In [None]:
# Collator handed to the trainer below; it is built from the tokenizer only.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

## Model

### Creating the model



In [None]:
# Initialize model
# NOTE: only the configuration of `model_checkpoint` is used here.  Per the
# transformers documentation, `from_config` builds the architecture without
# loading the checkpoint's pretrained weights, so training starts from
# freshly initialised parameters.

config = AutoConfig.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_config(config)

### Set WandB params

In [None]:
# Weights & Biases settings, supplied through environment variables.
os.environ.update(
    {
        # project under which this run is logged
        "WANDB_PROJECT": "gec-baseline-transformer",
        # upload the trained model checkpoint to wandb
        "WANDB_LOG_MODEL": "true",
        # disable watch so that logging is faster
        "WANDB_WATCH": "false",
    }
)

### Set training arguments

In [None]:
# Directory that receives checkpoints; logs go to its "logs" sub-directory.
output_dir = "baseline-transformer"

# Training configuration: metrics are logged every 500 steps, evaluation is
# step-based, a checkpoint is saved once per epoch, and everything is
# reported to Weights & Biases.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    auto_find_batch_size=True,
    num_train_epochs=100,
    logging_dir=f"{output_dir}/logs",
    logging_strategy="steps",
    logging_steps=500,
    evaluation_strategy="steps",
    save_strategy="epoch",
    report_to="wandb",
)

# Trainer wiring together the model, the collator and the tokenized splits.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
)

# Switch off the model's cache flag for the training run.
model.config.use_cache = False

### Train the model

In [None]:
# !jupyter notebook --ServerApp.iopub_data_rate_limit=1.0e10

In [None]:
# Start training with the arguments configured above.
trainer.train()

In [None]:
# Close the Weights & Biases run now that training is done.
wandb.finish()

In [None]:
# !rm -rf ./baseline-transformer
# !rm -rf ./wandb

In [None]:
# Move the model to the CPU for the inference cells below.
model.cpu()

In [None]:
def inference(sentence, *, max_new_tokens=256):
    """Generate a correction for ``sentence`` and print it.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        sentence: The (possibly ungrammatical) input text.
        max_new_tokens: Upper bound on the number of generated tokens.  It is
            passed explicitly because the library's default generation length
            is short and would cut longer corrections off.

    Returns:
        None. The decoded text is printed, without special tokens such as
        padding or end-of-sequence markers.
    """
    # Dropout must be inactive for generation; the model may still be in
    # training mode after `trainer.train()`.
    model.eval()

    # Keep the attention mask and place the tensors on the model's device so
    # the call works wherever the model currently lives.
    encoded = tokenizer(sentence, return_tensors="pt").to(model.device)
    outputs = model.generate(**encoded, max_new_tokens=max_new_tokens)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))

In [None]:
# Quick manual check on a deliberately ungrammatical sentence.
sentence = "I is go to the park."
inference(sentence)