In [10]:
%pip install accelerate
%pip install transformers 
%pip install datasets
%pip install sacrebleu evaluate
%pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
%pip install transformers[sentencepiece]


Note: you may need to restart the kernel to use updated packages.


Note: you may need to restart the kernel to use updated packages.
Looking in indexes: https://download.pytorch.org/whl/cu118
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [11]:
from transformers import pipeline, AutoTokenizer, MT5Model, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq, AutoModelForSeq2SeqLM
from datasets import Dataset
from torch import cuda

import pandas as pd
import numpy as np
import evaluate
import csv
import os

In [12]:
def get_raw_data(path) -> None:
    """
    Get the raw data for all the books.

    Returns:
        None
    """
    raw_data = pd.read_csv(path)
    return raw_data

def prepare_data(raw_data: pd.DataFrame) -> pd.DataFrame:
    # remove the "Mefaresh" column
    raw_data.drop(columns=["Mefaresh"], inplace=True)
    # remove rows with empty values
    raw_data.dropna(inplace=True)
    # renaming the columns for mT5
    raw_data.columns = ['input_text', 'target_text']
    return raw_data


In [13]:
df = get_raw_data('./combined.csv')
df = prepare_data(df)
len(df)

343027

In [14]:
!nvidia-smi

Sun Mar 24 15:02:43 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 546.29                 Driver Version: 546.29       CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                     TCC/WDDM  | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA GeForce RTX 4070      WDDM  | 00000000:01:00.0  On |                  N/A |
|  0%   42C    P8              10W / 215W |    708MiB / 12282MiB |     12%      Default |
|                                         |                      |                  N/A |
+-----------------------------------------+----------------------+----------------------+
                                                                    

In [15]:
device = 'cuda' if cuda.is_available() else 'cpu'
print(device)
model_name = 'google/mt5-small'
tokenizer = AutoTokenizer.from_pretrained(model_name, return_tensors="pt", to_device=device)

cuda




In [16]:
max_length = 40
dataset = Dataset.from_pandas(df)


def preprocess_function(examples):
    inputs = [i for i in examples["input_text"]]
    targets = [i for i in examples["target_text"]]
    model_inputs = tokenizer(
        inputs, text_target=targets, max_length=max_length, truncation=True
    )
    return model_inputs


tokenized_datasets = dataset.map(
    preprocess_function,
    batched=True,
    remove_columns=dataset.column_names,
)

Map: 100%|██████████| 343027/343027 [00:35<00:00, 9601.79 examples/s] 


In [17]:
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [18]:
batch = data_collator([tokenized_datasets[i] for i in range(1, 3)])
batch.keys()

dict_keys(['input_ids', 'attention_mask', 'labels', 'decoder_input_ids'])

In [19]:
metric = evaluate.load("sacrebleu")

def compute_metrics(eval_preds):
    preds, labels = eval_preds
    # In case the model returns more than the prediction logits
    if isinstance(preds, tuple):
        preds = preds[0]

    preds = np.where(preds != -100, preds, tokenizer.pad_token_id)

    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)

    # Replace -100s in the labels as we can't decode them
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    # Some simple post-processing
    decoded_preds = [pred.strip() for pred in decoded_preds]
    decoded_labels = [[label.strip()] for label in decoded_labels]

    result = metric.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": result["score"]}

In [20]:
tokenized_datasets

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 343027
})

In [25]:
args = Seq2SeqTrainingArguments(
    f"{model_name}",
    evaluation_strategy="no",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=64,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=3,
    predict_with_generate=True,
    fp16=False,
)

trainer = Seq2SeqTrainer(
    model,
    args,
    train_dataset=tokenized_datasets,
    eval_dataset=tokenized_datasets,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)

dataloader_config = DataLoaderConfiguration(dispatch_batches=None, split_batches=False, even_batches=True, use_seedable_sampler=True)


In [36]:
#trainer.evaluate(max_length=max_length)

In [26]:
trainer.train()

 10%|▉         | 512/5360 [07:12<1:08:18,  1.18it/s]
  2%|▏         | 501/32160 [01:45<1:51:09,  4.75it/s]

{'loss': 9.567, 'grad_norm': 2.8611743450164795, 'learning_rate': 1.968905472636816e-05, 'epoch': 0.05}


  3%|▎         | 1001/32160 [03:30<1:47:05,  4.85it/s]

{'loss': 5.1, 'grad_norm': 2.200582504272461, 'learning_rate': 1.937810945273632e-05, 'epoch': 0.09}


  5%|▍         | 1501/32160 [05:14<1:45:31,  4.84it/s]

{'loss': 4.7471, 'grad_norm': 1.9834951162338257, 'learning_rate': 1.9067164179104477e-05, 'epoch': 0.14}


  6%|▌         | 2001/32160 [07:07<1:55:44,  4.34it/s]

{'loss': 4.5569, 'grad_norm': 1.6956961154937744, 'learning_rate': 1.8756218905472638e-05, 'epoch': 0.19}


  8%|▊         | 2501/32160 [09:01<1:49:40,  4.51it/s]

{'loss': 4.4542, 'grad_norm': 1.715301752090454, 'learning_rate': 1.8445273631840795e-05, 'epoch': 0.23}


  9%|▉         | 3001/32160 [10:56<1:48:31,  4.48it/s]

{'loss': 4.3988, 'grad_norm': 1.6338735818862915, 'learning_rate': 1.8134328358208956e-05, 'epoch': 0.28}


 11%|█         | 3501/32160 [12:50<1:51:51,  4.27it/s]

{'loss': 4.3258, 'grad_norm': 1.6478290557861328, 'learning_rate': 1.7823383084577114e-05, 'epoch': 0.33}


 12%|█▏        | 4001/32160 [14:44<1:47:45,  4.36it/s]

{'loss': 4.2701, 'grad_norm': 1.8015425205230713, 'learning_rate': 1.7512437810945274e-05, 'epoch': 0.37}


 14%|█▍        | 4501/32160 [16:39<1:46:52,  4.31it/s]

{'loss': 4.2303, 'grad_norm': 1.7219550609588623, 'learning_rate': 1.7201492537313432e-05, 'epoch': 0.42}


 16%|█▌        | 5001/32160 [18:33<1:44:15,  4.34it/s]

{'loss': 4.192, 'grad_norm': 1.3947694301605225, 'learning_rate': 1.6890547263681593e-05, 'epoch': 0.47}


 17%|█▋        | 5501/32160 [20:28<1:40:42,  4.41it/s]

{'loss': 4.1589, 'grad_norm': 1.7538013458251953, 'learning_rate': 1.6579601990049753e-05, 'epoch': 0.51}


 19%|█▊        | 6001/32160 [22:22<1:41:56,  4.28it/s]

{'loss': 4.1304, 'grad_norm': 1.4781137704849243, 'learning_rate': 1.626865671641791e-05, 'epoch': 0.56}


 20%|██        | 6501/32160 [24:16<1:36:49,  4.42it/s]

{'loss': 4.0856, 'grad_norm': 1.7124055624008179, 'learning_rate': 1.595771144278607e-05, 'epoch': 0.61}


 22%|██▏       | 7001/32160 [26:10<1:35:32,  4.39it/s]

{'loss': 4.072, 'grad_norm': 1.7101701498031616, 'learning_rate': 1.564676616915423e-05, 'epoch': 0.65}


 23%|██▎       | 7501/32160 [28:04<1:35:07,  4.32it/s]

{'loss': 4.0583, 'grad_norm': 1.3049976825714111, 'learning_rate': 1.533582089552239e-05, 'epoch': 0.7}


 25%|██▍       | 8001/32160 [29:58<1:33:54,  4.29it/s]

{'loss': 4.0303, 'grad_norm': 1.5469647645950317, 'learning_rate': 1.5024875621890549e-05, 'epoch': 0.75}


 26%|██▋       | 8501/32160 [31:52<1:29:29,  4.41it/s]

{'loss': 4.0139, 'grad_norm': 1.406729817390442, 'learning_rate': 1.4713930348258708e-05, 'epoch': 0.79}


 28%|██▊       | 9001/32160 [33:46<1:26:59,  4.44it/s]

{'loss': 4.0085, 'grad_norm': 1.332188606262207, 'learning_rate': 1.4402985074626867e-05, 'epoch': 0.84}


 30%|██▉       | 9501/32160 [35:40<1:26:13,  4.38it/s]

{'loss': 3.9867, 'grad_norm': 1.456650972366333, 'learning_rate': 1.4092039800995026e-05, 'epoch': 0.89}


 31%|███       | 10001/32160 [37:34<1:23:09,  4.44it/s]

{'loss': 3.9722, 'grad_norm': 1.3517426252365112, 'learning_rate': 1.3781094527363185e-05, 'epoch': 0.93}


 33%|███▎      | 10501/32160 [39:28<1:24:19,  4.28it/s]

{'loss': 3.9664, 'grad_norm': 1.8773671388626099, 'learning_rate': 1.3470149253731344e-05, 'epoch': 0.98}


 34%|███▍      | 11001/32160 [41:32<1:21:21,  4.33it/s] 

{'loss': 3.957, 'grad_norm': 1.6737744808197021, 'learning_rate': 1.3159203980099505e-05, 'epoch': 1.03}


 36%|███▌      | 11501/32160 [43:26<1:19:57,  4.31it/s]

{'loss': 3.9259, 'grad_norm': 1.3659573793411255, 'learning_rate': 1.2848258706467662e-05, 'epoch': 1.07}


 37%|███▋      | 12001/32160 [45:20<1:16:26,  4.40it/s]

{'loss': 3.9247, 'grad_norm': 1.3583663702011108, 'learning_rate': 1.2537313432835823e-05, 'epoch': 1.12}


 39%|███▉      | 12501/32160 [47:14<1:13:47,  4.44it/s]

{'loss': 3.9085, 'grad_norm': 1.503098964691162, 'learning_rate': 1.222636815920398e-05, 'epoch': 1.17}


 40%|████      | 13001/32160 [49:08<1:13:14,  4.36it/s]

{'loss': 3.9092, 'grad_norm': 1.482322335243225, 'learning_rate': 1.1915422885572141e-05, 'epoch': 1.21}


 42%|████▏     | 13501/32160 [51:02<1:11:06,  4.37it/s]

{'loss': 3.886, 'grad_norm': 1.6200906038284302, 'learning_rate': 1.1604477611940299e-05, 'epoch': 1.26}


 44%|████▎     | 14001/32160 [52:56<1:09:12,  4.37it/s]

{'loss': 3.8711, 'grad_norm': 1.3181043863296509, 'learning_rate': 1.129353233830846e-05, 'epoch': 1.31}


 45%|████▌     | 14501/32160 [54:51<1:08:33,  4.29it/s]

{'loss': 3.8812, 'grad_norm': 1.7766988277435303, 'learning_rate': 1.0982587064676617e-05, 'epoch': 1.35}


 47%|████▋     | 15001/32160 [56:44<1:05:37,  4.36it/s]

{'loss': 3.8826, 'grad_norm': 1.3203421831130981, 'learning_rate': 1.0671641791044778e-05, 'epoch': 1.4}


 48%|████▊     | 15501/32160 [58:38<1:06:07,  4.20it/s]

{'loss': 3.8686, 'grad_norm': 1.321150302886963, 'learning_rate': 1.0360696517412937e-05, 'epoch': 1.45}


 50%|████▉     | 16001/32160 [1:00:33<59:54,  4.49it/s]  

{'loss': 3.8678, 'grad_norm': 1.3278621435165405, 'learning_rate': 1.0049751243781096e-05, 'epoch': 1.49}


 51%|█████▏    | 16501/32160 [1:02:26<1:00:43,  4.30it/s]

{'loss': 3.8497, 'grad_norm': 1.2563050985336304, 'learning_rate': 9.738805970149255e-06, 'epoch': 1.54}


 53%|█████▎    | 17001/32160 [1:04:20<56:02,  4.51it/s]  

{'loss': 3.8478, 'grad_norm': 1.3347398042678833, 'learning_rate': 9.427860696517414e-06, 'epoch': 1.59}


 54%|█████▍    | 17501/32160 [1:06:15<55:49,  4.38it/s]  

{'loss': 3.8408, 'grad_norm': 1.6456835269927979, 'learning_rate': 9.116915422885573e-06, 'epoch': 1.63}


 56%|█████▌    | 18001/32160 [1:08:10<55:03,  4.29it/s]

{'loss': 3.8325, 'grad_norm': 1.4305638074874878, 'learning_rate': 8.805970149253732e-06, 'epoch': 1.68}


 58%|█████▊    | 18501/32160 [1:10:05<52:43,  4.32it/s]

{'loss': 3.8385, 'grad_norm': 1.446617603302002, 'learning_rate': 8.495024875621891e-06, 'epoch': 1.73}


 59%|█████▉    | 19001/32160 [1:12:00<51:58,  4.22it/s]

{'loss': 3.8164, 'grad_norm': 1.3500114679336548, 'learning_rate': 8.18407960199005e-06, 'epoch': 1.77}


 61%|██████    | 19501/32160 [1:13:56<49:37,  4.25it/s]

{'loss': 3.8117, 'grad_norm': 1.3277909755706787, 'learning_rate': 7.87313432835821e-06, 'epoch': 1.82}


 62%|██████▏   | 20001/32160 [1:15:51<46:59,  4.31it/s]

{'loss': 3.8158, 'grad_norm': 1.468062162399292, 'learning_rate': 7.5621890547263685e-06, 'epoch': 1.87}


 64%|██████▎   | 20501/32160 [1:17:46<45:32,  4.27it/s]

{'loss': 3.8078, 'grad_norm': 1.2861698865890503, 'learning_rate': 7.251243781094528e-06, 'epoch': 1.91}


 65%|██████▌   | 21001/32160 [1:19:41<42:21,  4.39it/s]

{'loss': 3.8051, 'grad_norm': 1.3736969232559204, 'learning_rate': 6.9402985074626876e-06, 'epoch': 1.96}


 67%|██████▋   | 21501/32160 [1:21:44<41:03,  4.33it/s]  

{'loss': 3.8001, 'grad_norm': 1.5067063570022583, 'learning_rate': 6.629353233830847e-06, 'epoch': 2.01}


 68%|██████▊   | 22001/32160 [1:23:39<39:08,  4.33it/s]

{'loss': 3.7895, 'grad_norm': 1.557795524597168, 'learning_rate': 6.318407960199006e-06, 'epoch': 2.05}


 70%|██████▉   | 22501/32160 [1:25:35<37:43,  4.27it/s]

{'loss': 3.7946, 'grad_norm': 1.2808749675750732, 'learning_rate': 6.007462686567165e-06, 'epoch': 2.1}


 72%|███████▏  | 23001/32160 [1:27:30<35:49,  4.26it/s]

{'loss': 3.7902, 'grad_norm': 1.5786354541778564, 'learning_rate': 5.696517412935324e-06, 'epoch': 2.15}


 73%|███████▎  | 23501/32160 [1:29:25<33:45,  4.27it/s]

{'loss': 3.7923, 'grad_norm': 1.6900386810302734, 'learning_rate': 5.385572139303483e-06, 'epoch': 2.19}


 75%|███████▍  | 24001/32160 [1:31:23<31:44,  4.28it/s]

{'loss': 3.7777, 'grad_norm': 1.594274878501892, 'learning_rate': 5.074626865671642e-06, 'epoch': 2.24}


 76%|███████▌  | 24501/32160 [1:33:18<28:58,  4.41it/s]

{'loss': 3.7866, 'grad_norm': 1.440423846244812, 'learning_rate': 4.763681592039802e-06, 'epoch': 2.29}


 78%|███████▊  | 25001/32160 [1:35:13<27:55,  4.27it/s]

{'loss': 3.7842, 'grad_norm': 1.3393959999084473, 'learning_rate': 4.452736318407961e-06, 'epoch': 2.33}


 79%|███████▉  | 25501/32160 [1:37:08<25:16,  4.39it/s]

{'loss': 3.7794, 'grad_norm': 1.7366305589675903, 'learning_rate': 4.141791044776119e-06, 'epoch': 2.38}


 81%|████████  | 26001/32160 [1:39:04<22:46,  4.51it/s]

{'loss': 3.7718, 'grad_norm': 1.422526478767395, 'learning_rate': 3.8308457711442784e-06, 'epoch': 2.43}


 82%|████████▏ | 26501/32160 [1:40:58<21:24,  4.41it/s]

{'loss': 3.7724, 'grad_norm': 1.2749272584915161, 'learning_rate': 3.519900497512438e-06, 'epoch': 2.47}


 84%|████████▍ | 27001/32160 [1:42:54<20:11,  4.26it/s]

{'loss': 3.7801, 'grad_norm': 1.3431122303009033, 'learning_rate': 3.208955223880597e-06, 'epoch': 2.52}


 86%|████████▌ | 27501/32160 [1:44:49<17:36,  4.41it/s]

{'loss': 3.7778, 'grad_norm': 1.2516902685165405, 'learning_rate': 2.898009950248756e-06, 'epoch': 2.57}


 87%|████████▋ | 28001/32160 [1:46:44<16:14,  4.27it/s]

{'loss': 3.7573, 'grad_norm': 1.2668768167495728, 'learning_rate': 2.5870646766169156e-06, 'epoch': 2.61}


 89%|████████▊ | 28501/32160 [1:48:40<14:25,  4.23it/s]

{'loss': 3.7666, 'grad_norm': 1.1435400247573853, 'learning_rate': 2.2761194029850747e-06, 'epoch': 2.66}


 90%|█████████ | 29001/32160 [1:50:35<12:04,  4.36it/s]

{'loss': 3.7563, 'grad_norm': 1.4683988094329834, 'learning_rate': 1.965174129353234e-06, 'epoch': 2.71}


 92%|█████████▏| 29501/32160 [1:52:30<10:06,  4.38it/s]

{'loss': 3.7604, 'grad_norm': 1.1953444480895996, 'learning_rate': 1.6542288557213931e-06, 'epoch': 2.75}


 93%|█████████▎| 30001/32160 [1:54:25<08:33,  4.20it/s]

{'loss': 3.7601, 'grad_norm': 1.4811553955078125, 'learning_rate': 1.3432835820895524e-06, 'epoch': 2.8}


 95%|█████████▍| 30501/32160 [1:56:20<06:28,  4.27it/s]

{'loss': 3.7796, 'grad_norm': 1.3054877519607544, 'learning_rate': 1.0323383084577115e-06, 'epoch': 2.85}


 96%|█████████▋| 31001/32160 [1:58:16<04:31,  4.27it/s]

{'loss': 3.7692, 'grad_norm': 1.3095703125, 'learning_rate': 7.213930348258707e-07, 'epoch': 2.89}


 98%|█████████▊| 31501/32160 [2:00:11<02:32,  4.31it/s]

{'loss': 3.7589, 'grad_norm': 1.437182903289795, 'learning_rate': 4.104477611940299e-07, 'epoch': 2.94}


100%|█████████▉| 32001/32160 [2:02:06<00:37,  4.27it/s]

{'loss': 3.7701, 'grad_norm': 1.2938390970230103, 'learning_rate': 9.950248756218906e-08, 'epoch': 2.99}


100%|██████████| 32160/32160 [2:02:47<00:00,  4.37it/s]

{'train_runtime': 7367.2769, 'train_samples_per_second': 139.683, 'train_steps_per_second': 4.365, 'train_loss': 4.038430110494889, 'epoch': 3.0}





TrainOutput(global_step=32160, training_loss=4.038430110494889, metrics={'train_runtime': 7367.2769, 'train_samples_per_second': 139.683, 'train_steps_per_second': 4.365, 'train_loss': 4.038430110494889, 'epoch': 3.0})

In [None]:
trainer.evaluate(max_length=max_length)

In [27]:
model.save_pretrained("model-trained")

In [35]:
text = "בראשית ברא אלוהים את השמים ואת הארץ"
inputs = tokenizer(text, return_tensors='pt', ).to(device)
summary_ids = model.generate(**inputs, max_length=200, min_length=100, length_penalty=2.0, num_beams=4, early_stopping=True)
summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print(summary)

וחושך. וכן הוא חלק מן העולם, והוא אחד מן הארץ, והוא אחד מן הארץ, וכן הוא חלק מן הארץ, והוא השם הוא חלק מן העניין, וכן יהיה חלק אחד מן הדבר, וכן יראה האיש אחד, וזהו גם שם, וגם משם דכתיב דכתיב
