In [2]:
import json, os, torch
from transformers import (
                        T5Tokenizer,
                        T5ForConditionalGeneration,
                        TrainingArguments, 
                        Trainer
                        )

from datasets import load_dataset

  from .autonotebook import tqdm as notebook_tqdm





In [3]:
ROOTS = [
        'data/fce/json/fce.train.json',
        'data/fce/json/fce.dev.json',
        'data/fce/json/fce.test.json'
        ]

SPLITS = [
        'train', 
        'valid', 
        'test'
        ]

save_dir = 'data/fce/final'

In [4]:
def replace_multiple_substrings(
                                original_string, 
                                replacements
                                ):
    replacements.sort(key=lambda x: x[0])
    result = original_string
    offset = 0  
    for start_index, end_index, new_substring in replacements:
        adjusted_start = start_index + offset
        adjusted_end = end_index + offset
        if adjusted_start < 0 or adjusted_end > len(result) or adjusted_start > adjusted_end:
            print(f"Error: Invalid indices for replacement '{new_substring}'. Skipping.")
            continue

        result = result[:adjusted_start] + str(new_substring) + result[adjusted_end:]
        offset += len(str(new_substring)) - (end_index - start_index)
    return result

if not os.path.exists(save_dir):
    for root, split in zip(ROOTS, SPLITS):
        data = []
        data_points = []
        with open(root, 'r') as f:
            for line in f:
                data.append(json.loads(line))
        
        for i in range(len(data)):
            str_data = data[i]['text']
            re_data = data[i]['edits'][0][1]
            modified_string = replace_multiple_substrings(str_data, [data[:3] for data in re_data])
            
            data_point = {
                'original': str_data,
                'corrected': modified_string
            }
            data_points.append(data_point)
        
        with open(os.path.join(save_dir, split+'.json'), 'w') as f:
            json.dump(data_points, f, indent=4)

In [5]:
dataset_train = load_dataset(
                            'json', 
                            data_files='data/fce/final/train.json', 
                            split='train'
                            )
dataset_valid = load_dataset(
                            'json', 
                            data_files='data/fce/final/valid.json', 
                            split='train'
                            )
dataset_test = load_dataset(
                            'json', 
                            data_files='data/fce/final/test.json', 
                            split='train'
                            )

Generating train split: 2116 examples [00:00, 37634.62 examples/s]
Generating train split: 159 examples [00:00, 15324.92 examples/s]
Generating train split: 194 examples [00:00, 14464.15 examples/s]


In [5]:
OUT_DIR = 'results/grammar-error-correction'
MODEL = 't5-small'
MAX_LENGTH = 256
BATCH_SIZE = 16
NUM_WORKERS = 8
EPOCHS = 70

In [6]:
tokenizer = T5Tokenizer.from_pretrained(MODEL)
model = T5ForConditionalGeneration.from_pretrained(MODEL)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=512, bias=False)
              (k): Linear(in_features=512, out_features=512, bias=False)
              (v): Linear(in_features=512, out_features=512, bias=False)
              (o): Linear(in_features=512, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 8)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseActDense(
              (wi): Linear(in_features=512, out_features=2048, bias=False)
              (wo): Linear(in_features=2048, out_features=512, bias=False)
              (dropout): Drop

In [7]:
def preprocess_function(
                        examples, 
                        tokenizer=tokenizer,
                        MAX_LENGTH=MAX_LENGTH
                        ):
    inputs = [f"rectify: {inc}" for inc in examples['original']]
    model_inputs = tokenizer(
                            inputs, 
                            max_length=MAX_LENGTH, 
                            truncation=True,
                            padding='max_length'
                            )

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(
                        examples['corrected'], 
                        max_length=MAX_LENGTH, 
                        truncation=True,
                        padding='max_length'
                        )
    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
tokenized_train = dataset_train.map(
                                    preprocess_function, 
                                    batched=True,
                                    num_proc=8
                                    )

tokenized_valid = dataset_valid.map(
                                    preprocess_function, 
                                    batched=True,
                                    num_proc=8
                                    )
tokenized_test = dataset_test.map(
                                preprocess_function, 
                                batched=True,
                                num_proc=8
                                )

In [8]:
total_params = sum(p.numel() for p in model.parameters())
print(f"{total_params:,} total parameters.")

total_trainable_params = sum(
    p.numel() for p in model.parameters() if p.requires_grad)
print(f"{total_trainable_params:,} training parameters.")

60,506,624 total parameters.
60,506,624 training parameters.


In [10]:
training_args = TrainingArguments(
                                output_dir=OUT_DIR,          
                                num_train_epochs=EPOCHS,
                                per_device_train_batch_size=BATCH_SIZE,
                                per_device_eval_batch_size=BATCH_SIZE*2,
                                warmup_steps=500,
                                weight_decay=0.01,
                                logging_dir=OUT_DIR,
                                evaluation_strategy='steps',
                                save_steps=500,
                                eval_steps=500,
                                load_best_model_at_end=True,
                                save_total_limit=2,
                                report_to='tensorboard',
                                dataloader_num_workers=NUM_WORKERS
                                )

trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=tokenized_train,
                eval_dataset=tokenized_valid,
                )

history = trainer.train()

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False)
  5%|▌         | 500/9310 [21:02<4:43:24,  1.93s/it] 

{'loss': 2.119, 'learning_rate': 5e-05, 'epoch': 3.76}


                                                    
  5%|▌         | 500/9310 [21:23<4:43:24,  1.93s/it]

{'eval_loss': 0.5771849155426025, 'eval_runtime': 21.0081, 'eval_samples_per_second': 7.569, 'eval_steps_per_second': 0.238, 'epoch': 3.76}


 11%|█         | 1000/9310 [38:26<4:26:12,  1.92s/it]

{'loss': 0.548, 'learning_rate': 4.7162315550510784e-05, 'epoch': 7.52}


                                                     
 11%|█         | 1000/9310 [38:46<4:26:12,  1.92s/it]

{'eval_loss': 0.49688783288002014, 'eval_runtime': 20.4095, 'eval_samples_per_second': 7.79, 'eval_steps_per_second': 0.245, 'epoch': 7.52}


 16%|█▌        | 1500/9310 [55:49<4:13:17,  1.95s/it] 

{'loss': 0.4821, 'learning_rate': 4.432463110102157e-05, 'epoch': 11.28}


                                                     
 16%|█▌        | 1500/9310 [56:09<4:13:17,  1.95s/it]

{'eval_loss': 0.46675047278404236, 'eval_runtime': 20.3717, 'eval_samples_per_second': 7.805, 'eval_steps_per_second': 0.245, 'epoch': 11.28}


 21%|██▏       | 2000/9310 [1:13:12<5:43:53,  2.82s/it] 

{'loss': 0.4444, 'learning_rate': 4.1486946651532346e-05, 'epoch': 15.04}


                                                       
 21%|██▏       | 2000/9310 [1:13:32<5:43:53,  2.82s/it]

{'eval_loss': 0.44291672110557556, 'eval_runtime': 20.1024, 'eval_samples_per_second': 7.909, 'eval_steps_per_second': 0.249, 'epoch': 15.04}


 27%|██▋       | 2500/9310 [1:30:24<3:39:15,  1.93s/it] 

{'loss': 0.4181, 'learning_rate': 3.8649262202043134e-05, 'epoch': 18.8}


                                                       
 27%|██▋       | 2500/9310 [1:30:44<3:39:15,  1.93s/it]

{'eval_loss': 0.4322322607040405, 'eval_runtime': 20.3395, 'eval_samples_per_second': 7.817, 'eval_steps_per_second': 0.246, 'epoch': 18.8}


 32%|███▏      | 3000/9310 [1:47:45<3:24:29,  1.94s/it] 

{'loss': 0.3974, 'learning_rate': 3.5811577752553915e-05, 'epoch': 22.56}


                                                       
 32%|███▏      | 3000/9310 [1:48:06<3:24:29,  1.94s/it]

{'eval_loss': 0.42515939474105835, 'eval_runtime': 20.453, 'eval_samples_per_second': 7.774, 'eval_steps_per_second': 0.244, 'epoch': 22.56}


 38%|███▊      | 3500/9310 [2:05:07<3:05:49,  1.92s/it] 

{'loss': 0.3805, 'learning_rate': 3.29738933030647e-05, 'epoch': 26.32}


                                                       
 38%|███▊      | 3500/9310 [2:05:27<3:05:49,  1.92s/it]

{'eval_loss': 0.41905856132507324, 'eval_runtime': 20.3572, 'eval_samples_per_second': 7.81, 'eval_steps_per_second': 0.246, 'epoch': 26.32}


 43%|████▎     | 4000/9310 [2:22:29<3:03:42,  2.08s/it] 

{'loss': 0.366, 'learning_rate': 3.013620885357548e-05, 'epoch': 30.08}


                                                       
 43%|████▎     | 4000/9310 [2:22:49<3:03:42,  2.08s/it]

{'eval_loss': 0.4152158200740814, 'eval_runtime': 20.1988, 'eval_samples_per_second': 7.872, 'eval_steps_per_second': 0.248, 'epoch': 30.08}


 48%|████▊     | 4500/9310 [2:39:38<2:37:06,  1.96s/it] 

{'loss': 0.3549, 'learning_rate': 2.7298524404086268e-05, 'epoch': 33.83}


                                                       
 48%|████▊     | 4500/9310 [2:39:58<2:37:06,  1.96s/it]

{'eval_loss': 0.4128681421279907, 'eval_runtime': 20.1272, 'eval_samples_per_second': 7.9, 'eval_steps_per_second': 0.248, 'epoch': 33.83}


 54%|█████▎    | 5000/9310 [2:56:59<2:19:12,  1.94s/it] 

{'loss': 0.3446, 'learning_rate': 2.446083995459705e-05, 'epoch': 37.59}


                                                       
 54%|█████▎    | 5000/9310 [2:57:19<2:19:12,  1.94s/it]

{'eval_loss': 0.4111403822898865, 'eval_runtime': 19.9924, 'eval_samples_per_second': 7.953, 'eval_steps_per_second': 0.25, 'epoch': 37.59}


 59%|█████▉    | 5500/9310 [3:14:22<2:03:04,  1.94s/it]

{'loss': 0.338, 'learning_rate': 2.1623155505107834e-05, 'epoch': 41.35}


                                                       
 59%|█████▉    | 5500/9310 [3:14:42<2:03:04,  1.94s/it]

{'eval_loss': 0.4124952256679535, 'eval_runtime': 20.2133, 'eval_samples_per_second': 7.866, 'eval_steps_per_second': 0.247, 'epoch': 41.35}


 64%|██████▍   | 6000/9310 [3:31:44<1:50:56,  2.01s/it]

{'loss': 0.33, 'learning_rate': 1.878547105561862e-05, 'epoch': 45.11}


                                                       
 64%|██████▍   | 6000/9310 [3:32:05<1:50:56,  2.01s/it]

{'eval_loss': 0.4117662310600281, 'eval_runtime': 20.5584, 'eval_samples_per_second': 7.734, 'eval_steps_per_second': 0.243, 'epoch': 45.11}


 70%|██████▉   | 6500/9310 [3:48:54<1:31:49,  1.96s/it]

{'loss': 0.3241, 'learning_rate': 1.59477866061294e-05, 'epoch': 48.87}


                                                       
 70%|██████▉   | 6500/9310 [3:49:15<1:31:49,  1.96s/it]

{'eval_loss': 0.4102106988430023, 'eval_runtime': 20.3813, 'eval_samples_per_second': 7.801, 'eval_steps_per_second': 0.245, 'epoch': 48.87}


 75%|███████▌  | 7000/9310 [4:06:17<1:13:47,  1.92s/it]

{'loss': 0.3194, 'learning_rate': 1.3110102156640184e-05, 'epoch': 52.63}


                                                       
 75%|███████▌  | 7000/9310 [4:06:38<1:13:47,  1.92s/it]

{'eval_loss': 0.41092225909233093, 'eval_runtime': 20.7185, 'eval_samples_per_second': 7.674, 'eval_steps_per_second': 0.241, 'epoch': 52.63}


 81%|████████  | 7500/9310 [4:23:39<59:21,  1.97s/it]  

{'loss': 0.3155, 'learning_rate': 1.0272417707150965e-05, 'epoch': 56.39}


                                                     
 81%|████████  | 7500/9310 [4:24:00<59:21,  1.97s/it]

{'eval_loss': 0.4105730950832367, 'eval_runtime': 20.6932, 'eval_samples_per_second': 7.684, 'eval_steps_per_second': 0.242, 'epoch': 56.39}


 86%|████████▌ | 8000/9310 [4:41:04<43:11,  1.98s/it]  

{'loss': 0.3129, 'learning_rate': 7.434733257661748e-06, 'epoch': 60.15}


                                                     
 86%|████████▌ | 8000/9310 [4:41:24<43:11,  1.98s/it]

{'eval_loss': 0.4103817343711853, 'eval_runtime': 20.4779, 'eval_samples_per_second': 7.764, 'eval_steps_per_second': 0.244, 'epoch': 60.15}


 91%|█████████▏| 8500/9310 [4:58:15<26:07,  1.94s/it]  

{'loss': 0.3107, 'learning_rate': 4.5970488081725315e-06, 'epoch': 63.91}


                                                     
 91%|█████████▏| 8500/9310 [4:58:36<26:07,  1.94s/it]

{'eval_loss': 0.4108417332172394, 'eval_runtime': 20.5321, 'eval_samples_per_second': 7.744, 'eval_steps_per_second': 0.244, 'epoch': 63.91}


 97%|█████████▋| 9000/9310 [5:15:38<10:08,  1.96s/it]  

{'loss': 0.3086, 'learning_rate': 1.7593643586833145e-06, 'epoch': 67.67}


                                                     
 97%|█████████▋| 9000/9310 [5:15:59<10:08,  1.96s/it]

{'eval_loss': 0.41025298833847046, 'eval_runtime': 20.6367, 'eval_samples_per_second': 7.705, 'eval_steps_per_second': 0.242, 'epoch': 67.67}


100%|██████████| 9310/9310 [5:26:27<00:00,  1.95s/it]There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight', 'lm_head.weight'].
100%|██████████| 9310/9310 [5:26:27<00:00,  2.10s/it]

{'train_runtime': 19587.4875, 'train_samples_per_second': 7.562, 'train_steps_per_second': 0.475, 'train_loss': 0.46214260425014475, 'epoch': 70.0}





In [11]:
tokenizer.save_pretrained('models/grammar_error_detection')
model.save_pretrained('models/grammar_error_detection')

In [12]:
trainer.evaluate(tokenized_test)

100%|██████████| 7/7 [00:16<00:00,  2.30s/it]


{'eval_loss': 0.47226184606552124,
 'eval_runtime': 29.7839,
 'eval_samples_per_second': 6.514,
 'eval_steps_per_second': 0.235,
 'epoch': 70.0}

## Inference

In [9]:
model_path = 'models/grammar_error_detection'
model = T5ForConditionalGeneration.from_pretrained(model_path)
tokenizer = T5Tokenizer.from_pretrained(model_path)

In [10]:
def do_correction(text):
    input_text = f"rectify: {text}"
    inputs = tokenizer.encode(
                            input_text,
                            return_tensors='pt',
                            max_length=256,
                            padding='max_length',
                            truncation=True
                            )

    corrected_ids = model.generate(
                                inputs,
                                max_length=384,
                                num_beams=5,
                                early_stopping=True
                                )

    corrected_sentence = tokenizer.decode(
                                        corrected_ids[0],
                                        skip_special_tokens=True
                                        )
    return corrected_sentence

In [11]:
sentences = [
    "He don't like to eat vegetables.",
    "They was going to the store yesterday.",
    "She don't sings very well.",
    "Between you and I, the decision not well received.",
    "The book I borrowed from the library, it was really interesting.",
    "Despite of the rain, they went for a picnic."
]
for sentence in sentences:
    corrected_sentence = do_correction(sentence)
    print(f"ORIG: {sentence}\nCORRECT: {corrected_sentence}")

ORIG: He don't like to eat vegetables.
CORRECT: He doesn't like to eat vegetables.
ORIG: They was going to the store yesterday.
CORRECT: They were going to the store yesterday.
ORIG: She don't sings very well.
CORRECT: She doesn't sing very well.
ORIG: Between you and I, the decision not well received.
CORRECT: Between you and I, the decision is not well received.
ORIG: The book I borrowed from the library, it was really interesting.
CORRECT: The book I borrowed from the library was really interesting.
ORIG: Despite of the rain, they went for a picnic.
CORRECT: Despite the rain, they went for a picnic.
