Import the necessary components. Some of them are new this week and will be discussed later in the notebook.

In [None]:
from transformers import AutoModelForSeq2SeqLM, T5Tokenizer, EarlyStoppingCallback, Seq2SeqTrainingArguments,Seq2SeqTrainer, DataCollatorForSeq2Seq
import torch
import time
import pandas as pd

import numpy as np
import re
from datasets import  load_from_disk,load_dataset


<a name='1.2'></a>
### 1.2 - Load Dataset and LLM


In [None]:
# Alternative source (disabled): fetch the preprocessed corpus from the Hugging Face Hub.
# dataset = load_dataset("Ahmadsameh8/qalbPreprocessedAndMerged")

# Load the dataset from the local "qalb" directory. Later cells read its
# "train" / "validation" / "test" splits and the "incorrect" / "correct" text columns.
dataset = load_from_disk("qalb")
# dataset


In [None]:
import gc
# Run a full garbage-collection pass; the returned count of unreachable objects
# becomes the cell output.
gc.collect()

24

In [None]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# `torch.cuda.set_device` only accepts CUDA devices, so skip it on the CPU fallback
# (calling it with a CPU device raises instead of silently doing nothing).
if device.type == "cuda":
    torch.cuda.set_device(device)

In [None]:
def print_number_of_trainable_model_parameters(model):
    """Summarise how many of a model's parameters are trainable.

    Despite the name, nothing is printed: the summary is returned as a string so
    the caller can print it or let the notebook display it.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, parameter)`` pairs, where each parameter provides
            ``numel()`` and ``requires_grad`` (e.g. a ``torch.nn.Module``).

    Returns:
        A three-line string with the trainable parameter count, the total
        parameter count, and the trainable percentage (two decimals). A model
        without any parameters reports ``0.00%``.
    """
    trainable_model_params = 0
    all_model_params = 0
    for _, param in model.named_parameters():
        param_count = param.numel()
        all_model_params += param_count
        if param.requires_grad:
            trainable_model_params += param_count

    # Guard the division so a parameter-less model does not raise ZeroDivisionError.
    if all_model_params:
        trainable_percentage = 100 * trainable_model_params / all_model_params
    else:
        trainable_percentage = 0.0

    return f"trainable model parameters: {trainable_model_params}\nall model parameters: {all_model_params}\npercentage of trainable model parameters: {trainable_percentage:.2f}%"



In [None]:
# Optional 4-bit NF4 quantisation (currently disabled). To try it, uncomment the
# block below and pass `quantization_config=nf4_config` to `from_pretrained`.
#
# from transformers import BitsAndBytesConfig
#
# nf4_config = BitsAndBytesConfig(
#    load_in_4bit=True,
#    bnb_4bit_use_double_quant=True,
#    bnb_4bit_quant_type="nf4",
#    bnb_4bit_compute_dtype=torch.bfloat16
# )
# quantization_config=nf4_config,

# Checkpoint directory / identifier shared by the model and its tokenizer.
model_name = "model_new"

# NOTE(review): device_map={"": 0} pins the whole model to device index 0, so this
# cell presumably needs a GPU even though `device` has a CPU fallback — confirm.
original_model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    use_safetensors=True,
    device_map={"": 0},
)
tokenizer = T5Tokenizer.from_pretrained(model_name)

# Collator used by the trainer to assemble batches from tokenized examples.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [None]:
print_number_of_trainable_model_parameters(original_model)

'trainable model parameters: 367508736\nall model parameters: 367508736\npercentage of trainable model parameters: 100.00%'

In [None]:
def preprocess_function(example, padding="max_length", max_length=512):
    """Tokenize a batch of (incorrect, correct) text pairs for seq2seq training.

    Relies on the notebook-level ``tokenizer``.

    Args:
        example: Batch mapping with an "incorrect" entry (source texts) and a
            "correct" entry (target texts), as passed by
            ``dataset.map(..., batched=True)``.
        padding: Padding strategy forwarded to the tokenizer. With the default
            "max_length", padded label positions are replaced by -100.
        max_length: Maximum number of tokens kept for both inputs and targets;
            longer sequences are truncated. Defaults to 512.

    Returns:
        The tokenizer output for the source texts, with an added "labels" entry
        holding the target token ids.
    """
    # The source text is tokenized as-is; no task prefix is prepended.
    model_inputs = tokenizer(
        example["incorrect"], max_length=max_length, padding=padding, truncation=True
    )

    # Tokenize targets with the `text_target` keyword argument.
    labels = tokenizer(
        text_target=example["correct"], max_length=max_length, padding=padding, truncation=True
    )

    # When padding here, swap pad token ids in the labels for -100 so that padded
    # positions are ignored by the loss.
    if padding == "max_length":
        pad_token_id = tokenizer.pad_token_id
        labels["input_ids"] = [
            [(token_id if token_id != pad_token_id else -100) for token_id in label_ids]
            for label_ids in labels["input_ids"]
        ]

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

# Tokenize every split in batches, then drop the raw text columns so that only
# the model-ready features remain.
tokenized_datasets = dataset.map(preprocess_function, batched=True).remove_columns(
    ['correct', 'incorrect']
)
print(f"Keys of tokenized dataset: {list(tokenized_datasets['train'].features)}")


                                                                   

Keys of tokenized dataset: ['input_ids', 'attention_mask', 'labels']




Check the shapes of all three parts of the dataset:

In [None]:
# Report the shape of each split, followed by the full DatasetDict summary.
print("Shapes of the datasets:")
for split_label, split_key in (
    ("Training", "train"),
    ("Validation", "validation"),
    ("Test", "test"),
):
    print(f"{split_label}: {tokenized_datasets[split_key].shape}")

print(tokenized_datasets)

Shapes of the datasets:
Training: (18350, 3)
Validation: (2293, 3)
Test: (2295, 3)
DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 18350
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2293
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 2295
    })
})


<a name='2.2'></a>
### 2.2 - Fine-Tune the Model with the Preprocessed Dataset



In [None]:
# Directory that receives checkpoints written during training.
output_dir = "./textcorrection_model"

# Full fine-tuning run (every model parameter is trainable): evaluate and
# checkpoint once per epoch, keep at most two checkpoints, and reload the best
# checkpoint when training ends.
training_args = Seq2SeqTrainingArguments(
    output_dir=output_dir,
    learning_rate=5e-5,
    num_train_epochs=4,
    # max_grad_norm=0.1,
    logging_steps=500,
    # `eval_steps` / `save_steps` are intentionally not set: they only apply to
    # the "steps" strategies, and both strategies below are "epoch".
    save_strategy="epoch",
    evaluation_strategy="epoch",
    do_train=True,
    do_eval=True,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    warmup_steps=512,
    load_best_model_at_end=True,
    # Mixed-precision training needs a CUDA device; disable it on the CPU fallback
    # instead of failing while building the arguments.
    fp16=torch.cuda.is_available(),
    save_total_limit=2,
    # report_to="tensorboard"
)

# NOTE(review): evaluation runs once per epoch, so a 4-epoch run yields only four
# evaluations; a patience of 4 therefore looks unable to trigger before training
# finishes on its own. Lower the patience (or raise the epoch count) if early
# stopping is meant to be active — confirm the intended behaviour.
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=4)

trainer = Seq2SeqTrainer(
    model=original_model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    callbacks=[early_stopping_callback],
    data_collator=data_collator,
    # tokenizer=tokenizer,
)

In [None]:
# Run fine-tuning; this is the long-running step of the notebook.
trainer.train()
# Save the trainer's model and the tokenizer into the same directory so they can
# be reloaded together with `from_pretrained("textcorrection")`.
model_path= "textcorrection"
trainer.model.save_pretrained(model_path)
# Last expression of the cell: its return value (the written tokenizer files) is displayed.
tokenizer.save_pretrained(model_path)

  1%|          | 500/73400 [02:37<6:24:33,  3.16it/s]

{'loss': 0.2258, 'grad_norm': 8.554731369018555, 'learning_rate': 4.8632812500000004e-05, 'epoch': 0.03}


  1%|▏         | 1000/73400 [05:14<6:20:33,  3.17it/s]

{'loss': 0.1818, 'grad_norm': 0.9594561457633972, 'learning_rate': 4.9667297771924045e-05, 'epoch': 0.05}


  2%|▏         | 1500/73400 [07:51<6:16:38,  3.18it/s]

{'loss': 0.1894, 'grad_norm': 2.1392178535461426, 'learning_rate': 4.9324991768192294e-05, 'epoch': 0.08}


  3%|▎         | 2000/73400 [10:29<6:14:33,  3.18it/s]

{'loss': 0.1947, 'grad_norm': 2.7186851501464844, 'learning_rate': 4.898199978048513e-05, 'epoch': 0.11}


  3%|▎         | 2500/73400 [13:06<6:11:23,  3.18it/s]

{'loss': 0.1834, 'grad_norm': 4.714346408843994, 'learning_rate': 4.863969377675338e-05, 'epoch': 0.14}


  4%|▍         | 3000/73400 [15:44<6:09:37,  3.17it/s]

{'loss': 0.1853, 'grad_norm': 1.7235218286514282, 'learning_rate': 4.8296701789046214e-05, 'epoch': 0.16}


  5%|▍         | 3500/73400 [18:21<6:07:08,  3.17it/s]

{'loss': 0.1737, 'grad_norm': 3.028968095779419, 'learning_rate': 4.795370980133904e-05, 'epoch': 0.19}


  5%|▌         | 4000/73400 [20:59<6:06:44,  3.15it/s]

{'loss': 0.1896, 'grad_norm': 3.2034566402435303, 'learning_rate': 4.761071781363188e-05, 'epoch': 0.22}


  6%|▌         | 4500/73400 [23:36<6:01:40,  3.18it/s]

{'loss': 0.192, 'grad_norm': 3.303408622741699, 'learning_rate': 4.7267725825924706e-05, 'epoch': 0.25}


  7%|▋         | 5000/73400 [26:14<5:58:53,  3.18it/s]

{'loss': 0.1824, 'grad_norm': 2.6341094970703125, 'learning_rate': 4.692473383821754e-05, 'epoch': 0.27}


  7%|▋         | 5500/73400 [28:51<5:56:18,  3.18it/s]

{'loss': 0.1819, 'grad_norm': 1.0057538747787476, 'learning_rate': 4.658174185051038e-05, 'epoch': 0.3}


  8%|▊         | 6000/73400 [31:28<5:53:00,  3.18it/s]

{'loss': 0.1694, 'grad_norm': 2.0833628177642822, 'learning_rate': 4.623943584677862e-05, 'epoch': 0.33}


  9%|▉         | 6500/73400 [34:06<5:51:43,  3.17it/s]

{'loss': 0.1802, 'grad_norm': 2.0929925441741943, 'learning_rate': 4.5896443859071455e-05, 'epoch': 0.35}


 10%|▉         | 7000/73400 [36:43<5:46:05,  3.20it/s]

{'loss': 0.1787, 'grad_norm': 4.308587551116943, 'learning_rate': 4.555345187136429e-05, 'epoch': 0.38}


 10%|█         | 7500/73400 [39:21<5:43:45,  3.20it/s]

{'loss': 0.157, 'grad_norm': 1.754582405090332, 'learning_rate': 4.521045988365712e-05, 'epoch': 0.41}


 11%|█         | 8000/73400 [41:59<5:43:54,  3.17it/s]

{'loss': 0.1695, 'grad_norm': 2.2226955890655518, 'learning_rate': 4.4867467895949954e-05, 'epoch': 0.44}


 12%|█▏        | 8500/73400 [44:36<5:38:35,  3.19it/s]

{'loss': 0.1622, 'grad_norm': 2.5653390884399414, 'learning_rate': 4.452447590824278e-05, 'epoch': 0.46}


 12%|█▏        | 9000/73400 [47:13<5:37:17,  3.18it/s]

{'loss': 0.1567, 'grad_norm': 4.757228851318359, 'learning_rate': 4.418148392053562e-05, 'epoch': 0.49}


 13%|█▎        | 9500/73400 [49:51<5:33:05,  3.20it/s]

{'loss': 0.1612, 'grad_norm': 0.6222451329231262, 'learning_rate': 4.383849193282845e-05, 'epoch': 0.52}


 14%|█▎        | 10000/73400 [52:28<5:31:08,  3.19it/s]

{'loss': 0.1779, 'grad_norm': 9.35268497467041, 'learning_rate': 4.349549994512128e-05, 'epoch': 0.54}


 14%|█▍        | 10500/73400 [55:06<5:28:29,  3.19it/s]

{'loss': 0.1718, 'grad_norm': 1.013614296913147, 'learning_rate': 4.315250795741412e-05, 'epoch': 0.57}


 15%|█▍        | 11000/73400 [57:43<5:26:15,  3.19it/s]

{'loss': 0.156, 'grad_norm': 0.08654537051916122, 'learning_rate': 4.2809515969706945e-05, 'epoch': 0.6}


 16%|█▌        | 11500/73400 [1:00:21<5:25:59,  3.16it/s]

{'loss': 0.1645, 'grad_norm': 3.6042351722717285, 'learning_rate': 4.2467209965975194e-05, 'epoch': 0.63}


 16%|█▋        | 12000/73400 [1:02:58<5:25:33,  3.14it/s]

{'loss': 0.1543, 'grad_norm': 0.6567481160163879, 'learning_rate': 4.212421797826803e-05, 'epoch': 0.65}


 17%|█▋        | 12500/73400 [1:05:36<5:21:12,  3.16it/s]

{'loss': 0.1657, 'grad_norm': 0.674049437046051, 'learning_rate': 4.1781225990560865e-05, 'epoch': 0.68}


 18%|█▊        | 13000/73400 [1:08:13<5:17:24,  3.17it/s]

{'loss': 0.1546, 'grad_norm': 1.8236823081970215, 'learning_rate': 4.1438234002853693e-05, 'epoch': 0.71}


 18%|█▊        | 13500/73400 [1:10:51<5:14:43,  3.17it/s]

{'loss': 0.1828, 'grad_norm': 5.183828830718994, 'learning_rate': 4.109592799912194e-05, 'epoch': 0.74}


 19%|█▉        | 14000/73400 [1:13:28<5:11:06,  3.18it/s]

{'loss': 0.1698, 'grad_norm': 4.096951484680176, 'learning_rate': 4.075293601141478e-05, 'epoch': 0.76}


 20%|█▉        | 14500/73400 [1:16:06<5:10:00,  3.17it/s]

{'loss': 0.1738, 'grad_norm': 5.652063846588135, 'learning_rate': 4.041063000768302e-05, 'epoch': 0.79}


 20%|██        | 15000/73400 [1:18:43<5:07:25,  3.17it/s]

{'loss': 0.1817, 'grad_norm': 2.6641340255737305, 'learning_rate': 4.0067638019975855e-05, 'epoch': 0.82}


 21%|██        | 15500/73400 [1:21:21<5:06:13,  3.15it/s]

{'loss': 0.1712, 'grad_norm': 5.526939392089844, 'learning_rate': 3.9724646032268684e-05, 'epoch': 0.84}


 22%|██▏       | 16000/73400 [1:23:58<5:01:36,  3.17it/s]

{'loss': 0.1581, 'grad_norm': 6.694446563720703, 'learning_rate': 3.938165404456152e-05, 'epoch': 0.87}


 22%|██▏       | 16500/73400 [1:26:36<4:59:12,  3.17it/s]

{'loss': 0.1719, 'grad_norm': 0.8760740756988525, 'learning_rate': 3.9038662056854355e-05, 'epoch': 0.9}


 23%|██▎       | 17000/73400 [1:29:13<4:56:59,  3.17it/s]

{'loss': 0.1649, 'grad_norm': 4.48449182510376, 'learning_rate': 3.869567006914719e-05, 'epoch': 0.93}


 24%|██▍       | 17500/73400 [1:31:51<4:54:05,  3.17it/s]

{'loss': 0.1544, 'grad_norm': 1.3211297988891602, 'learning_rate': 3.835267808144002e-05, 'epoch': 0.95}


 25%|██▍       | 18000/73400 [1:34:28<4:51:47,  3.16it/s]

{'loss': 0.1535, 'grad_norm': 1.2606381177902222, 'learning_rate': 3.8009686093732854e-05, 'epoch': 0.98}


                                                         
 25%|██▌       | 18350/73400 [1:37:47<4:48:33,  3.18it/s]

{'eval_loss': 0.4939463138580322, 'eval_runtime': 88.1946, 'eval_samples_per_second': 25.999, 'eval_steps_per_second': 25.999, 'epoch': 1.0}


 25%|██▌       | 18500/73400 [1:38:39<4:48:35,  3.17it/s]  

{'loss': 0.145, 'grad_norm': 1.5432958602905273, 'learning_rate': 3.766669410602568e-05, 'epoch': 1.01}


 26%|██▌       | 19000/73400 [1:41:16<4:45:57,  3.17it/s]

{'loss': 0.1138, 'grad_norm': 0.7905131578445435, 'learning_rate': 3.732370211831852e-05, 'epoch': 1.04}


 27%|██▋       | 19500/73400 [1:43:54<4:42:27,  3.18it/s]

{'loss': 0.1276, 'grad_norm': 2.0539684295654297, 'learning_rate': 3.6981396114586767e-05, 'epoch': 1.06}


 27%|██▋       | 20000/73400 [1:46:31<4:39:50,  3.18it/s]

{'loss': 0.1219, 'grad_norm': 5.758062839508057, 'learning_rate': 3.66384041268796e-05, 'epoch': 1.09}


 28%|██▊       | 20500/73400 [1:49:09<4:38:22,  3.17it/s]

{'loss': 0.1278, 'grad_norm': 9.148983001708984, 'learning_rate': 3.629541213917243e-05, 'epoch': 1.12}


 29%|██▊       | 21000/73400 [1:51:46<4:33:59,  3.19it/s]

{'loss': 0.1314, 'grad_norm': 6.540215492248535, 'learning_rate': 3.5952420151465266e-05, 'epoch': 1.14}


 29%|██▉       | 21500/73400 [1:54:24<4:32:45,  3.17it/s]

{'loss': 0.1237, 'grad_norm': 0.9002007246017456, 'learning_rate': 3.5609428163758094e-05, 'epoch': 1.17}


 30%|██▉       | 22000/73400 [1:57:01<4:29:22,  3.18it/s]

{'loss': 0.1191, 'grad_norm': 3.804367780685425, 'learning_rate': 3.526643617605093e-05, 'epoch': 1.2}


 31%|███       | 22500/73400 [1:59:39<4:26:35,  3.18it/s]

{'loss': 0.1277, 'grad_norm': 10.984539031982422, 'learning_rate': 3.4923444188343765e-05, 'epoch': 1.23}


 31%|███▏      | 23000/73400 [2:02:16<4:24:20,  3.18it/s]

{'loss': 0.1271, 'grad_norm': 0.17762981355190277, 'learning_rate': 3.458045220063659e-05, 'epoch': 1.25}


 32%|███▏      | 23500/73400 [2:04:54<4:22:07,  3.17it/s]

{'loss': 0.1394, 'grad_norm': 2.514493703842163, 'learning_rate': 3.423814619690484e-05, 'epoch': 1.28}


 33%|███▎      | 24000/73400 [2:07:31<4:17:52,  3.19it/s]

{'loss': 0.1379, 'grad_norm': 2.3498055934906006, 'learning_rate': 3.389515420919768e-05, 'epoch': 1.31}


 33%|███▎      | 24500/73400 [2:10:09<4:15:27,  3.19it/s]

{'loss': 0.1375, 'grad_norm': 0.8526722192764282, 'learning_rate': 3.3552162221490506e-05, 'epoch': 1.34}


 34%|███▍      | 25000/73400 [2:12:46<4:11:58,  3.20it/s]

{'loss': 0.1354, 'grad_norm': 3.7062721252441406, 'learning_rate': 3.3209170233783335e-05, 'epoch': 1.36}


 35%|███▍      | 25500/73400 [2:15:24<4:10:44,  3.18it/s]

{'loss': 0.1276, 'grad_norm': 4.206404685974121, 'learning_rate': 3.286617824607618e-05, 'epoch': 1.39}


 35%|███▌      | 26000/73400 [2:18:01<4:09:35,  3.17it/s]

{'loss': 0.1237, 'grad_norm': 7.410590171813965, 'learning_rate': 3.252387224234442e-05, 'epoch': 1.42}


 36%|███▌      | 26500/73400 [2:20:39<4:09:25,  3.13it/s]

{'loss': 0.1363, 'grad_norm': 1.7822672128677368, 'learning_rate': 3.2180880254637254e-05, 'epoch': 1.44}


 37%|███▋      | 27000/73400 [2:23:16<4:04:19,  3.17it/s]

{'loss': 0.131, 'grad_norm': 2.514401912689209, 'learning_rate': 3.183788826693009e-05, 'epoch': 1.47}


 37%|███▋      | 27500/73400 [2:25:54<4:03:36,  3.14it/s]

{'loss': 0.124, 'grad_norm': 2.887645721435547, 'learning_rate': 3.149489627922292e-05, 'epoch': 1.5}


 38%|███▊      | 28000/73400 [2:28:31<3:59:07,  3.16it/s]

{'loss': 0.1334, 'grad_norm': 1.6971276998519897, 'learning_rate': 3.1151904291515754e-05, 'epoch': 1.53}


 39%|███▉      | 28500/73400 [2:31:09<3:56:22,  3.17it/s]

{'loss': 0.1302, 'grad_norm': 0.931185781955719, 'learning_rate': 3.080891230380858e-05, 'epoch': 1.55}


 40%|███▉      | 29000/73400 [2:33:46<3:53:16,  3.17it/s]

{'loss': 0.1326, 'grad_norm': 0.5743779540061951, 'learning_rate': 3.0466606300076835e-05, 'epoch': 1.58}


 40%|████      | 29500/73400 [2:36:23<3:50:26,  3.18it/s]

{'loss': 0.1333, 'grad_norm': 3.480675458908081, 'learning_rate': 3.0123614312369663e-05, 'epoch': 1.61}


 41%|████      | 30000/73400 [2:39:00<3:47:50,  3.17it/s]

{'loss': 0.1365, 'grad_norm': 4.312435626983643, 'learning_rate': 2.9781308308637912e-05, 'epoch': 1.63}


 42%|████▏     | 30500/73400 [2:41:38<3:46:36,  3.16it/s]

{'loss': 0.135, 'grad_norm': 1.2027876377105713, 'learning_rate': 2.943831632093074e-05, 'epoch': 1.66}


 42%|████▏     | 31000/73400 [2:44:15<3:43:03,  3.17it/s]

{'loss': 0.139, 'grad_norm': 1.5135436058044434, 'learning_rate': 2.909532433322358e-05, 'epoch': 1.69}


 43%|████▎     | 31500/73400 [2:46:53<3:40:48,  3.16it/s]

{'loss': 0.1381, 'grad_norm': 2.4609556198120117, 'learning_rate': 2.8752332345516408e-05, 'epoch': 1.72}


 44%|████▎     | 32000/73400 [2:49:31<3:38:22,  3.16it/s]

{'loss': 0.132, 'grad_norm': 0.12567336857318878, 'learning_rate': 2.8409340357809243e-05, 'epoch': 1.74}


 44%|████▍     | 32500/73400 [2:52:09<3:36:00,  3.16it/s]

{'loss': 0.1291, 'grad_norm': 2.130763530731201, 'learning_rate': 2.8066348370102075e-05, 'epoch': 1.77}


 45%|████▍     | 33000/73400 [2:54:47<3:31:43,  3.18it/s]

{'loss': 0.125, 'grad_norm': 2.070603370666504, 'learning_rate': 2.772335638239491e-05, 'epoch': 1.8}


 46%|████▌     | 33500/73400 [2:57:25<3:28:44,  3.19it/s]

{'loss': 0.1357, 'grad_norm': 17.386510848999023, 'learning_rate': 2.7381050378663153e-05, 'epoch': 1.83}


 46%|████▋     | 34000/73400 [3:00:02<3:26:22,  3.18it/s]

{'loss': 0.1466, 'grad_norm': 1.0934052467346191, 'learning_rate': 2.7038058390955988e-05, 'epoch': 1.85}


 47%|████▋     | 34500/73400 [3:02:40<3:23:56,  3.18it/s]

{'loss': 0.1481, 'grad_norm': 2.916527271270752, 'learning_rate': 2.669506640324882e-05, 'epoch': 1.88}


 48%|████▊     | 35000/73400 [3:05:17<3:21:55,  3.17it/s]

{'loss': 0.1367, 'grad_norm': 3.9788753986358643, 'learning_rate': 2.635276039951707e-05, 'epoch': 1.91}


 48%|████▊     | 35500/73400 [3:07:55<3:19:20,  3.17it/s]

{'loss': 0.1438, 'grad_norm': 2.9018588066101074, 'learning_rate': 2.6009768411809904e-05, 'epoch': 1.93}


 49%|████▉     | 36000/73400 [3:10:32<3:16:19,  3.17it/s]

{'loss': 0.1369, 'grad_norm': 4.471598148345947, 'learning_rate': 2.5666776424102733e-05, 'epoch': 1.96}


 50%|████▉     | 36500/73400 [3:13:10<3:14:30,  3.16it/s]

{'loss': 0.1285, 'grad_norm': 2.8968074321746826, 'learning_rate': 2.5323784436395565e-05, 'epoch': 1.99}


                                                         
 50%|█████     | 36700/73400 [3:15:41<3:13:34,  3.16it/s]

{'eval_loss': 0.5325896143913269, 'eval_runtime': 88.0046, 'eval_samples_per_second': 26.055, 'eval_steps_per_second': 26.055, 'epoch': 2.0}


 50%|█████     | 37000/73400 [3:17:20<3:11:02,  3.18it/s]  

{'loss': 0.1127, 'grad_norm': 0.6899814009666443, 'learning_rate': 2.49807924486884e-05, 'epoch': 2.02}


 51%|█████     | 37500/73400 [3:19:58<3:09:31,  3.16it/s]

{'loss': 0.1076, 'grad_norm': 0.777041494846344, 'learning_rate': 2.4637800460981232e-05, 'epoch': 2.04}


 52%|█████▏    | 38000/73400 [3:22:36<3:05:53,  3.17it/s]

{'loss': 0.1117, 'grad_norm': 15.062843322753906, 'learning_rate': 2.4294808473274067e-05, 'epoch': 2.07}


 52%|█████▏    | 38500/73400 [3:25:13<3:03:11,  3.18it/s]

{'loss': 0.1234, 'grad_norm': 1.6767762899398804, 'learning_rate': 2.3951816485566896e-05, 'epoch': 2.1}


 53%|█████▎    | 39000/73400 [3:27:51<3:00:43,  3.17it/s]

{'loss': 0.1094, 'grad_norm': 3.196225881576538, 'learning_rate': 2.360882449785973e-05, 'epoch': 2.13}


 54%|█████▍    | 39500/73400 [3:30:28<2:58:20,  3.17it/s]

{'loss': 0.1117, 'grad_norm': 1.6075090169906616, 'learning_rate': 2.3265832510152563e-05, 'epoch': 2.15}


 54%|█████▍    | 40000/73400 [3:33:05<2:55:20,  3.17it/s]

{'loss': 0.1119, 'grad_norm': 3.619677782058716, 'learning_rate': 2.2923526506420812e-05, 'epoch': 2.18}


 55%|█████▌    | 40500/73400 [3:35:43<2:52:30,  3.18it/s]

{'loss': 0.1143, 'grad_norm': 46.59795379638672, 'learning_rate': 2.2581220502689058e-05, 'epoch': 2.21}


 56%|█████▌    | 41000/73400 [3:38:20<2:50:05,  3.17it/s]

{'loss': 0.1055, 'grad_norm': 1.9492827653884888, 'learning_rate': 2.223822851498189e-05, 'epoch': 2.23}


 57%|█████▋    | 41500/73400 [3:40:58<2:47:19,  3.18it/s]

{'loss': 0.1113, 'grad_norm': 1.266723394393921, 'learning_rate': 2.1895236527274725e-05, 'epoch': 2.26}


 57%|█████▋    | 42000/73400 [3:43:35<2:44:52,  3.17it/s]

{'loss': 0.1088, 'grad_norm': 1.2613245248794556, 'learning_rate': 2.1552244539567557e-05, 'epoch': 2.29}


 58%|█████▊    | 42500/73400 [3:46:13<2:41:40,  3.19it/s]

{'loss': 0.1051, 'grad_norm': 0.9253221750259399, 'learning_rate': 2.120925255186039e-05, 'epoch': 2.32}


 59%|█████▊    | 43000/73400 [3:48:50<2:39:32,  3.18it/s]

{'loss': 0.112, 'grad_norm': 6.188117027282715, 'learning_rate': 2.0866946548128638e-05, 'epoch': 2.34}


 59%|█████▉    | 43500/73400 [3:51:27<2:37:00,  3.17it/s]

{'loss': 0.1137, 'grad_norm': 0.4861711859703064, 'learning_rate': 2.052395456042147e-05, 'epoch': 2.37}


 60%|█████▉    | 44000/73400 [3:54:05<2:34:48,  3.17it/s]

{'loss': 0.1141, 'grad_norm': 1.8754411935806274, 'learning_rate': 2.0180962572714302e-05, 'epoch': 2.4}


 61%|██████    | 44500/73400 [3:56:42<2:31:42,  3.17it/s]

{'loss': 0.1309, 'grad_norm': 3.298957586288452, 'learning_rate': 1.983865656898255e-05, 'epoch': 2.43}


 61%|██████▏   | 45000/73400 [3:59:20<2:29:24,  3.17it/s]

{'loss': 0.1125, 'grad_norm': 4.983321189880371, 'learning_rate': 1.9495664581275383e-05, 'epoch': 2.45}


 62%|██████▏   | 45500/73400 [4:01:57<2:26:27,  3.17it/s]

{'loss': 0.1134, 'grad_norm': 1.225714087486267, 'learning_rate': 1.9152672593568215e-05, 'epoch': 2.48}


 63%|██████▎   | 46000/73400 [4:04:35<2:24:01,  3.17it/s]

{'loss': 0.1219, 'grad_norm': 5.772304058074951, 'learning_rate': 1.880968060586105e-05, 'epoch': 2.51}


 63%|██████▎   | 46500/73400 [4:07:13<2:21:15,  3.17it/s]

{'loss': 0.1187, 'grad_norm': 2.342487096786499, 'learning_rate': 1.846668861815388e-05, 'epoch': 2.53}


 64%|██████▍   | 47000/73400 [4:09:50<2:19:56,  3.14it/s]

{'loss': 0.1226, 'grad_norm': 3.2726950645446777, 'learning_rate': 1.8123696630446714e-05, 'epoch': 2.56}


 65%|██████▍   | 47500/73400 [4:12:28<2:16:49,  3.15it/s]

{'loss': 0.1143, 'grad_norm': 0.16208617389202118, 'learning_rate': 1.7780704642739546e-05, 'epoch': 2.59}


 65%|██████▌   | 48000/73400 [4:15:05<2:13:49,  3.16it/s]

{'loss': 0.1305, 'grad_norm': 1.4027739763259888, 'learning_rate': 1.7437712655032378e-05, 'epoch': 2.62}


 66%|██████▌   | 48500/73400 [4:17:42<2:10:31,  3.18it/s]

{'loss': 0.1252, 'grad_norm': 3.3315649032592773, 'learning_rate': 1.7094720667325213e-05, 'epoch': 2.64}


 67%|██████▋   | 49000/73400 [4:20:20<2:09:12,  3.15it/s]

{'loss': 0.1303, 'grad_norm': 4.482729434967041, 'learning_rate': 1.6751728679618045e-05, 'epoch': 2.67}


 67%|██████▋   | 49500/73400 [4:22:58<2:06:07,  3.16it/s]

{'loss': 0.1167, 'grad_norm': 1.2627469301223755, 'learning_rate': 1.6408736691910877e-05, 'epoch': 2.7}


 68%|██████▊   | 50000/73400 [4:25:35<2:03:10,  3.17it/s]

{'loss': 0.1376, 'grad_norm': 1.1637539863586426, 'learning_rate': 1.6065744704203712e-05, 'epoch': 2.72}


 69%|██████▉   | 50500/73400 [4:28:12<2:00:16,  3.17it/s]

{'loss': 0.1241, 'grad_norm': 1.1986063718795776, 'learning_rate': 1.5722752716496544e-05, 'epoch': 2.75}


 69%|██████▉   | 51000/73400 [4:30:50<1:57:38,  3.17it/s]

{'loss': 0.131, 'grad_norm': 4.478774070739746, 'learning_rate': 1.5379760728789376e-05, 'epoch': 2.78}


 70%|███████   | 51500/73400 [4:33:28<1:55:00,  3.17it/s]

{'loss': 0.1227, 'grad_norm': 0.32853710651397705, 'learning_rate': 1.5036768741082208e-05, 'epoch': 2.81}


 71%|███████   | 52000/73400 [4:36:05<1:52:28,  3.17it/s]

{'loss': 0.1158, 'grad_norm': 1.3078657388687134, 'learning_rate': 1.4694462737350457e-05, 'epoch': 2.83}


 72%|███████▏  | 52500/73400 [4:38:42<1:49:27,  3.18it/s]

{'loss': 0.1424, 'grad_norm': 2.639411211013794, 'learning_rate': 1.4351470749643289e-05, 'epoch': 2.86}


 72%|███████▏  | 53000/73400 [4:41:20<1:47:16,  3.17it/s]

{'loss': 0.1297, 'grad_norm': 0.13371294736862183, 'learning_rate': 1.4008478761936122e-05, 'epoch': 2.89}


 73%|███████▎  | 53500/73400 [4:43:58<1:44:04,  3.19it/s]

{'loss': 0.1245, 'grad_norm': 1.3698604106903076, 'learning_rate': 1.3665486774228956e-05, 'epoch': 2.92}


 74%|███████▎  | 54000/73400 [4:46:35<1:41:40,  3.18it/s]

{'loss': 0.1144, 'grad_norm': 1.6520432233810425, 'learning_rate': 1.3322494786521788e-05, 'epoch': 2.94}


 74%|███████▍  | 54500/73400 [4:49:13<1:39:12,  3.18it/s]

{'loss': 0.1448, 'grad_norm': 3.8290839195251465, 'learning_rate': 1.297950279881462e-05, 'epoch': 2.97}


 75%|███████▍  | 55000/73400 [4:51:50<1:36:20,  3.18it/s]

{'loss': 0.1291, 'grad_norm': 1.2048864364624023, 'learning_rate': 1.2636510811107452e-05, 'epoch': 3.0}


                                                         
 75%|███████▌  | 55050/73400 [4:53:34<1:36:14,  3.18it/s]

{'eval_loss': 0.5498557686805725, 'eval_runtime': 87.9867, 'eval_samples_per_second': 26.061, 'eval_steps_per_second': 26.061, 'epoch': 3.0}


 76%|███████▌  | 55500/73400 [4:56:01<1:33:25,  3.19it/s]  

{'loss': 0.0985, 'grad_norm': 3.0766146183013916, 'learning_rate': 1.2293518823400287e-05, 'epoch': 3.02}


 76%|███████▋  | 56000/73400 [4:58:39<1:31:10,  3.18it/s]

{'loss': 0.1036, 'grad_norm': 4.650710105895996, 'learning_rate': 1.1950526835693119e-05, 'epoch': 3.05}


 77%|███████▋  | 56500/73400 [5:01:16<1:28:31,  3.18it/s]

{'loss': 0.1031, 'grad_norm': 3.1904141902923584, 'learning_rate': 1.160753484798595e-05, 'epoch': 3.08}


 78%|███████▊  | 57000/73400 [5:03:54<1:25:42,  3.19it/s]

{'loss': 0.1038, 'grad_norm': 3.9037091732025146, 'learning_rate': 1.1264542860278784e-05, 'epoch': 3.11}


 78%|███████▊  | 57500/73400 [5:06:56<1:38:46,  2.68it/s]

{'loss': 0.1105, 'grad_norm': 1.0892443656921387, 'learning_rate': 1.0921550872571618e-05, 'epoch': 3.13}


 79%|███████▉  | 58000/73400 [5:09:52<1:20:42,  3.18it/s]

{'loss': 0.1135, 'grad_norm': 2.1632211208343506, 'learning_rate': 1.0579244868839865e-05, 'epoch': 3.16}


 80%|███████▉  | 58500/73400 [5:12:30<1:18:11,  3.18it/s]

{'loss': 0.122, 'grad_norm': 1.3015050888061523, 'learning_rate': 1.0236938865108111e-05, 'epoch': 3.19}


 80%|████████  | 59000/73400 [5:15:07<1:15:06,  3.20it/s]

{'loss': 0.1087, 'grad_norm': 1.8607341051101685, 'learning_rate': 9.894632861376359e-06, 'epoch': 3.22}


 81%|████████  | 59500/73400 [5:17:45<1:12:28,  3.20it/s]

{'loss': 0.117, 'grad_norm': 2.637951374053955, 'learning_rate': 9.551640873669192e-06, 'epoch': 3.24}


 82%|████████▏ | 60000/73400 [5:20:22<1:09:59,  3.19it/s]

{'loss': 0.1231, 'grad_norm': 0.27874550223350525, 'learning_rate': 9.208648885962024e-06, 'epoch': 3.27}


 82%|████████▏ | 60500/73400 [5:23:00<1:07:37,  3.18it/s]

{'loss': 0.1099, 'grad_norm': 2.9637770652770996, 'learning_rate': 8.865656898254856e-06, 'epoch': 3.3}


 83%|████████▎ | 61000/73400 [5:25:37<1:05:18,  3.16it/s]

{'loss': 0.1249, 'grad_norm': 2.340684413909912, 'learning_rate': 8.52266491054769e-06, 'epoch': 3.32}


 84%|████████▍ | 61500/73400 [5:28:15<1:02:01,  3.20it/s]

{'loss': 0.119, 'grad_norm': 0.5112693309783936, 'learning_rate': 8.179672922840523e-06, 'epoch': 3.35}


 84%|████████▍ | 62000/73400 [5:30:52<59:24,  3.20it/s]  

{'loss': 0.1349, 'grad_norm': 3.7277016639709473, 'learning_rate': 7.836680935133357e-06, 'epoch': 3.38}


 85%|████████▌ | 62500/73400 [5:33:30<56:49,  3.20it/s]  

{'loss': 0.1185, 'grad_norm': 4.4205756187438965, 'learning_rate': 7.493688947426189e-06, 'epoch': 3.41}


 86%|████████▌ | 63000/73400 [5:36:07<54:30,  3.18it/s]  

{'loss': 0.1291, 'grad_norm': 1.3241143226623535, 'learning_rate': 7.1506969597190205e-06, 'epoch': 3.43}


 87%|████████▋ | 63500/73400 [5:38:45<51:37,  3.20it/s]

{'loss': 0.1183, 'grad_norm': 3.0212466716766357, 'learning_rate': 6.807704972011854e-06, 'epoch': 3.46}


 87%|████████▋ | 64000/73400 [5:41:22<49:15,  3.18it/s]

{'loss': 0.128, 'grad_norm': 1.9016361236572266, 'learning_rate': 6.465398968280101e-06, 'epoch': 3.49}


 88%|████████▊ | 64500/73400 [5:44:00<46:37,  3.18it/s]

{'loss': 0.1395, 'grad_norm': 1.122450351715088, 'learning_rate': 6.122406980572934e-06, 'epoch': 3.51}


 89%|████████▊ | 65000/73400 [5:46:37<44:00,  3.18it/s]

{'loss': 0.1197, 'grad_norm': 3.7238266468048096, 'learning_rate': 5.779414992865767e-06, 'epoch': 3.54}


 89%|████████▉ | 65500/73400 [5:49:15<41:26,  3.18it/s]

{'loss': 0.1404, 'grad_norm': 3.74060320854187, 'learning_rate': 5.4364230051586e-06, 'epoch': 3.57}


 90%|████████▉ | 66000/73400 [5:51:52<38:40,  3.19it/s]

{'loss': 0.1331, 'grad_norm': 0.47699445486068726, 'learning_rate': 5.0934310174514325e-06, 'epoch': 3.6}


 91%|█████████ | 66500/73400 [5:54:30<36:04,  3.19it/s]

{'loss': 0.1413, 'grad_norm': 1.3354064226150513, 'learning_rate': 4.750439029744265e-06, 'epoch': 3.62}


 91%|█████████▏| 67000/73400 [5:57:07<33:26,  3.19it/s]

{'loss': 0.1382, 'grad_norm': 3.992628812789917, 'learning_rate': 4.407447042037098e-06, 'epoch': 3.65}


 92%|█████████▏| 67500/73400 [5:59:45<30:57,  3.18it/s]

{'loss': 0.1374, 'grad_norm': 0.9266499280929565, 'learning_rate': 4.064455054329931e-06, 'epoch': 3.68}


 93%|█████████▎| 68000/73400 [6:02:22<28:11,  3.19it/s]

{'loss': 0.1452, 'grad_norm': 4.521307468414307, 'learning_rate': 3.7214630666227635e-06, 'epoch': 3.71}


 93%|█████████▎| 68500/73400 [6:05:00<25:38,  3.19it/s]

{'loss': 0.1522, 'grad_norm': 2.5177605152130127, 'learning_rate': 3.3784710789155967e-06, 'epoch': 3.73}


 94%|█████████▍| 69000/73400 [6:07:37<23:05,  3.17it/s]

{'loss': 0.1525, 'grad_norm': 4.637007236480713, 'learning_rate': 3.0368510591592584e-06, 'epoch': 3.76}


 95%|█████████▍| 69500/73400 [6:10:14<20:27,  3.18it/s]

{'loss': 0.1542, 'grad_norm': 0.4649243652820587, 'learning_rate': 2.693859071452091e-06, 'epoch': 3.79}


 95%|█████████▌| 70000/73400 [6:12:52<17:56,  3.16it/s]

{'loss': 0.1511, 'grad_norm': 2.7363669872283936, 'learning_rate': 2.350867083744924e-06, 'epoch': 3.81}


 96%|█████████▌| 70500/73400 [6:15:30<15:22,  3.14it/s]

{'loss': 0.1646, 'grad_norm': 2.456928014755249, 'learning_rate': 2.0078750960377567e-06, 'epoch': 3.84}


 97%|█████████▋| 71000/73400 [6:18:07<12:36,  3.17it/s]

{'loss': 0.1542, 'grad_norm': 1.5736595392227173, 'learning_rate': 1.6648831083305894e-06, 'epoch': 3.87}


 97%|█████████▋| 71500/73400 [6:20:45<10:03,  3.15it/s]

{'loss': 0.1686, 'grad_norm': 0.19977720081806183, 'learning_rate': 1.3218911206234224e-06, 'epoch': 3.9}


 98%|█████████▊| 72000/73400 [6:23:22<07:22,  3.17it/s]

{'loss': 0.1668, 'grad_norm': 2.0384328365325928, 'learning_rate': 9.788991329162551e-07, 'epoch': 3.92}


 99%|█████████▉| 72500/73400 [6:26:00<04:43,  3.18it/s]

{'loss': 0.1683, 'grad_norm': 4.563311576843262, 'learning_rate': 6.365931291845022e-07, 'epoch': 3.95}


 99%|█████████▉| 73000/73400 [6:28:37<02:06,  3.17it/s]

{'loss': 0.1604, 'grad_norm': 4.361780643463135, 'learning_rate': 2.936011414773351e-07, 'epoch': 3.98}


                                                       
100%|██████████| 73400/73400 [6:32:11<00:00,  3.18it/s]

{'eval_loss': 0.48001521825790405, 'eval_runtime': 88.1994, 'eval_samples_per_second': 25.998, 'eval_steps_per_second': 25.998, 'epoch': 4.0}


There were missing keys in the checkpoint model loaded: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight'].
100%|██████████| 73400/73400 [6:32:16<00:00,  3.12it/s]


{'train_runtime': 23536.7797, 'train_samples_per_second': 3.119, 'train_steps_per_second': 3.119, 'train_loss': 0.13951776972258773, 'epoch': 4.0}


('textcorrection/tokenizer_config.json',
 'textcorrection/special_tokens_map.json',
 'textcorrection/spiece.model',
 'textcorrection/added_tokens.json')

In [None]:
# `gc` was already imported in an earlier cell; repeating the import is harmless.
import gc
# Run a garbage-collection pass after training; the returned count of unreachable
# objects becomes the cell output.
gc.collect()

835

In [None]:
# Load the TensorBoard notebook extension and point it at the `runs` folder inside
# `output_dir` (IPython expands `{output_dir}` from the notebook namespace).
# NOTE(review): logs only appear here if the trainer actually reported to
# TensorBoard; `report_to="tensorboard"` is commented out in the training
# arguments — confirm the default reporting backend in the installed version.
%load_ext tensorboard
%tensorboard --logdir '{output_dir}'/runs

In [None]:
from transformers import T5Tokenizer, AutoModelForSeq2SeqLM
import torch

# Reload the fine-tuned model and tokenizer saved under "textcorrection" and
# correct a single sentence from the test split.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# `torch.cuda.set_device` only accepts CUDA devices, so skip it on the CPU fallback.
if device.type == "cuda":
    torch.cuda.set_device(device)
tokenizer1 = T5Tokenizer.from_pretrained("textcorrection")
model1 = AutoModelForSeq2SeqLM.from_pretrained("textcorrection").to(device)

# ar_prompt="عاصمة ألمانيا هي <extra_id_0> "
# Use ONE sample for both generation and the debug print, so the displayed
# tokenized input always corresponds to the decoded output. Row access also
# avoids materialising the whole "incorrect" column just to read one entry.
sample_index = 2
sample_text = dataset["test"][sample_index]["incorrect"]

input_ids = tokenizer1(sample_text, return_tensors="pt").input_ids.to(device)
outputs = model1.generate(input_ids, max_length=512)
print("Tokenized input:", tokenizer1.tokenize(sample_text))
print("Decoded output:", tokenizer1.decode(outputs[0], skip_special_tokens=True))

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Tokenized input: ['▁علماء', '▁مصر', '▁ال', '▁ذ', 'ين', '▁ملء', 'وا', '▁الفضاء', 'حي', '▁ات', '▁لا', '▁يت', 'كل', '▁مون', '▁الا', '▁عن', '▁ن', 'واق', 'ض', '▁الوضوء', '▁و', '▁احكام', '▁دم', '▁الحي', 'ض', '▁و', '▁الن', '▁فاس', '▁ام', '▁ا', '▁دماء', '▁المسلمين', '▁الت', '▁ي', '▁يهد', 'رها', '▁الحك', '▁ام', '▁الظ', '▁لمه', '▁فلا', '▁احكام', '▁لها', 'في', '▁فقه', 'هم', '▁.', '▁وكيف', '▁نريد', 'هم', '▁ان', '▁ين', 'قد', 'وا', '▁من', '▁يسم', '▁ونه', '▁ولي', '▁امر', '▁و', '▁امير', '▁م', 'ء', 'منين', '▁و', '▁يحمد', 'ون', '▁الله', '▁علي', '▁الهواء', '▁عندما', '▁اجري', '▁عمليه', '▁جرا', 'حي', '▁ه', '▁ناج', 'حه', '▁في', '▁المانيا']
Decoded output: الكيل بمكيالين : الجزر للمصريين والعصا لاخواننا في افغانستان والعراق... امريكا اعلنت حربا صليبيه علي المسلمين علي لسان الجزار بوش ، وقتلت ، ودمرت ، وحاصرت ، وهي من ساندت الطاغيه لا مبارك ، لماذا ساندته ومولته ؟ عودوا الي المعتقلات ، والتصفيات ، والاختطافات فستجدون ان امريكا تخطط ولامبارك ينفذ هم العدو فاتلهم الله اني يءفكون.
