In [None]:
# ======================================
# 📦 CELL 1–3: LOAD, AUGMENT, FINETUNE AND SAVE
# ======================================
from datasets import load_dataset, DatasetDict, concatenate_datasets
from transformers import T5Tokenizer, T5ForConditionalGeneration, TrainingArguments, Trainer, DataCollatorForSeq2Seq
import nltk
import nlpaug.augmenter.word as naw
import torch

# Download the NLTK data used by nlpaug's WordNet-based SynonymAug:
# - the POS tagger, under both its legacy name and the '_eng' name that
#   NLTK >= 3.9 looks up (with only the legacy package present, tagging
#   raises LookupError on every augment() call);
# - the sentence/word tokenizer models (legacy 'punkt' and newer 'punkt_tab');
# - the WordNet corpus that supplies the synonyms, plus 'omw-1.4', which some
#   nltk/nlpaug versions request alongside it.
# nltk.download() skips packages that are already up to date.
for _nltk_pkg in (
    "averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng",
    "punkt",
    "punkt_tab",
    "wordnet",
    "omw-1.4",
):
    nltk.download(_nltk_pkg)

# 1. Load CoEdIT and keep only the task families used for fine-tuning.
raw_dataset = load_dataset("grammarly/coedit")
selected_tasks = ["gec", "clarity", "simplification", "paraphrase"]


def _is_selected_task(row):
    """Keep predicate for ``Dataset.filter``: True if the row's task is selected."""
    return row["task"] in selected_tasks


raw_dataset = raw_dataset.filter(_is_selected_task)

def add_prefix(example):
    """Attach model-ready ``input``/``output`` fields to one dataset row.

    ``input`` is the source text tagged with its task, as ``"<task>: <src>"``;
    ``output`` is a copy of the target text ``tgt``.

    The row dict is updated in place and returned, as ``Dataset.map`` expects.
    """
    example.update(
        input="{}: {}".format(example["task"], example["src"]),
        output=example["tgt"],
    )
    return example

raw_dataset = raw_dataset.map(add_prefix)

# 2. Shuffle each split (seed 42) and keep the first DATA_FRACTION of its rows.
#    1.0 keeps every row; lower it for quick experiments.
DATA_FRACTION = 1.0


def _shuffled_subset(split, fraction=DATA_FRACTION, seed=42):
    """Return ``split`` shuffled with ``seed`` and truncated to ``fraction`` of its rows.

    Raises:
        ValueError: if ``fraction`` is not in (0, 1].
    """
    if not 0.0 < fraction <= 1.0:
        raise ValueError(f"fraction must be in (0, 1], got {fraction!r}")
    shuffled = split.shuffle(seed=seed)
    if fraction == 1.0:
        # Full split: selecting range(len(split)) would be a no-op that only
        # rebuilds the indices mapping, so skip it.
        return shuffled
    return shuffled.select(range(int(len(split) * fraction)))


dataset = DatasetDict({
    "train": _shuffled_subset(raw_dataset["train"]),
    "validation": _shuffled_subset(raw_dataset["validation"]),
})

# 3. Data augmentation on source
# Word-level synonym replacement with WordNet as the synonym source; used by
# augment_data below to create extra training inputs.
# NOTE(review): this augmenter relies on NLTK data (WordNet corpus, POS tagger)
# being installed; if a resource is missing, augment() may raise at call time
# rather than here — confirm the required NLTK packages are downloaded.
syn_aug = naw.SynonymAug(aug_src='wordnet')

def augment_data(example, augmenter=None):
    """Replace some words of ``example["input"]`` with synonyms.

    Only the text after the ``"<task>: "`` prefix is augmented, so the task tag
    the model conditions on is never rewritten. If there is no such prefix, the
    whole string is augmented. ``example["output"]`` is left unchanged.

    NOTE(review): for the "gec" task, swapping source words while keeping the
    target fixed may teach the model to "undo" synonyms; confirm this is
    intended for every selected task.

    Args:
        example: One dataset row with a string ``"input"`` field. It is
            modified in place and returned, as ``Dataset.map`` expects.
        augmenter: Object with an nlpaug-style ``augment(text)`` method.
            Defaults to the module-level ``syn_aug``.

    Returns:
        The same ``example`` dict. If augmentation fails or yields nothing,
        ``"input"`` is left as-is. On failure a RuntimeWarning is emitted
        instead of the error being silently swallowed.
    """
    import warnings

    if augmenter is None:
        augmenter = syn_aug

    text = example["input"]
    task, sep, body = text.partition(": ")
    if not sep:
        # No "<task>: " prefix: augment the whole string.
        task, body = "", text

    try:
        augmented = augmenter.augment(body)
    except Exception as exc:  # best-effort augmentation: keep row, but report
        warnings.warn(
            f"augment_data: augmentation failed, keeping original input "
            f"({exc!r:.300})",
            RuntimeWarning,
            stacklevel=2,
        )
        return example

    # nlpaug >= 1.1.11 returns a list of augmented strings; older versions a str.
    if isinstance(augmented, (list, tuple)):
        augmented = augmented[0] if augmented else body
    if not isinstance(augmented, str) or not augmented:
        return example

    example["input"] = f"{task}{sep}{augmented}"
    return example

# Augment up to N_AUGMENT rows from the front of the (already shuffled) train
# split and append only the copies whose input actually changed.
N_AUGMENT = 1000
_aug_base = dataset["train"].select(range(min(N_AUGMENT, len(dataset["train"]))))
augmented = _aug_base.map(augment_data)

# Rows that augment_data left untouched are exact duplicates of existing
# training rows; appending them would only over-weight those examples.
_changed = [
    i
    for i, (before, after) in enumerate(zip(_aug_base["input"], augmented["input"]))
    if before != after
]
print(f"Augmentation changed {len(_changed)}/{len(_aug_base)} sampled inputs.")
augmented = augmented.select(_changed)
dataset["train"] = concatenate_datasets([dataset["train"], augmented])

# 4. Load the pretrained checkpoint and its (SentencePiece) T5 tokenizer,
#    placing the model on the GPU when one is available.
model_name = "vennify/t5-base-grammar-correction"
tokenizer = T5Tokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)

def preprocess(example, tok=None, max_source_length=128, max_target_length=128):
    """Tokenize ``input``/``output`` pairs for seq2seq training.

    Intended for ``Dataset.map(..., batched=True)``. ``example`` is then a
    batch dict whose ``"input"`` and ``"output"`` entries are lists of
    strings. A single un-batched row (plain strings) is also accepted.

    Both sides are padded to a fixed length. Padding positions in the labels
    are replaced with -100, so the cross-entropy loss ignores them. Without
    this, the model is trained and evaluated on predicting pad tokens.

    Args:
        example: Batch (or single row) with ``"input"`` and ``"output"`` text.
        tok: Tokenizer to use; defaults to the module-level ``tokenizer``.
        max_source_length: Max tokens kept for the input text.
        max_target_length: Max tokens kept for the target text.

    Returns:
        The tokenizer's encoding of ``example["input"]`` plus a ``"labels"``
        entry holding the masked target ids.
    """
    if tok is None:
        tok = tokenizer

    model_inputs = tok(
        example["input"],
        max_length=max_source_length,
        padding="max_length",
        truncation=True,
    )
    # ``text_target=`` replaces the deprecated ``as_target_tokenizer()`` context.
    labels = tok(
        text_target=example["output"],
        max_length=max_target_length,
        padding="max_length",
        truncation=True,
    )

    pad_id = tok.pad_token_id

    def _mask(ids):
        # -100 is the ignore_index of the model's cross-entropy loss.
        return [-100 if token == pad_id else token for token in ids]

    label_ids = labels["input_ids"]
    if isinstance(example["output"], str):  # un-batched call: one id list
        model_inputs["labels"] = _mask(label_ids)
    else:
        model_inputs["labels"] = [_mask(ids) for ids in label_ids]
    return model_inputs

# Tokenize the train and validation splits with the same batched preprocessing.
tokenized_train, tokenized_val = (
    dataset[split_name].map(preprocess, batched=True)
    for split_name in ("train", "validation")
)

# 5. Training setup
# T5 activations can overflow in fp16 (inf/NaN losses). Use bf16 on GPUs that
# support it, and keep fp16 only as the fallback for older CUDA devices.
_use_cuda = torch.cuda.is_available()
_use_bf16 = _use_cuda and torch.cuda.is_bf16_supported()

args = TrainingArguments(
    output_dir="./multitask-gec-finetuned",
    eval_strategy="epoch",  # ``evaluation_strategy`` is the deprecated spelling
    learning_rate=3e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    save_strategy="epoch",  # must match eval_strategy for load_best_model_at_end
    save_total_limit=1,
    load_best_model_at_end=True,  # "best" = lowest eval loss (default metric)
    bf16=_use_bf16,
    fp16=_use_cuda and not _use_bf16,
    logging_steps=50,
    report_to="none",
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_val,
    # NOTE(review): recent transformers releases deprecate ``tokenizer=`` in
    # favour of ``processing_class=``; switch once the installed version is
    # confirmed to support it.
    tokenizer=tokenizer,
    # Pads batches and, given the model, builds decoder_input_ids from labels.
    data_collator=DataCollatorForSeq2Seq(tokenizer, model),
)

# 6. Fine-tune, then write the final model and the tokenizer to disk.
# NOTE(review): reloading the best checkpoint at the end reported missing
# embed_tokens/lm_head keys; for T5 these are presumably tied to the shared
# embedding — verify the saved model generates correctly.
OUTPUT_DIR = "./multitask-gec-finetuned"
trainer.train()
trainer.save_model(OUTPUT_DIR)
tokenizer.save_pretrained(OUTPUT_DIR)




[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


Filter:   0%|          | 0/69071 [00:00<?, ? examples/s]

Filter:   0%|          | 0/1712 [00:00<?, ? examples/s]

Map:   0%|          | 0/47885 [00:00<?, ? examples/s]

Map:   0%|          | 0/1012 [00:00<?, ? examples/s]

Map:   0%|          | 0/1000 [00:00<?, ? examples/s]

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Roaming\nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!
[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     C:\Users\VUONGLOCTRUONG\AppData\Ro

Map:   0%|          | 0/3394 [00:00<?, ? examples/s]



Map:   0%|          | 0/50 [00:00<?, ? examples/s]

  trainer = Trainer(
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Epoch,Training Loss,Validation Loss
1,0.1813,0.285653
2,0.1829,0.275821
3,0.1655,0.270408


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].


('./multitask-gec-finetuned\\tokenizer_config.json',
 './multitask-gec-finetuned\\special_tokens_map.json',
 './multitask-gec-finetuned\\spiece.model',
 './multitask-gec-finetuned\\added_tokens.json')