In [1]:
from transformers import AutoModelForSeq2SeqLM, AutoModel, AutoTokenizer, DataCollatorForSeq2Seq, Seq2SeqTrainer, Seq2SeqTrainingArguments
from datasets import *
import nltk
import json
import pandas as pd
import numpy as np
import torch

SEED = 42
nltk.download('punkt')

  from .autonotebook import tqdm as notebook_tqdm
[nltk_data] Downloading package punkt to
[nltk_data]     C:\Users\Luka\AppData\Roaming\nltk_data...
[nltk_data]   Package punkt is already up-to-date!


True

In [2]:
model_path = "../../models"

In [3]:
# model to fine tune
model_checkpoint = "cjvt/t5-sl-small"
# model_checkpoint = "cjvt/t5-sl-large"
# model_checkpoint = "google/mt5-small"
# model_checkpoint = "google/mt5-base"
# model_checkpoint = "google/mt5-large"
name = model_checkpoint.split("/")[-1]

epochs = 25

# this batch size fits on a GPU with 12GB of RAM with the model_checkpoint = "cjvt/t5-sl-small"
batch_size = 16
# this batch size fits on a GPU with 12GB of RAM with the model_checkpoint = "google/mt5-small"
# batch_size = 8

tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(model_checkpoint)

### Load the data

In [4]:
data = pd.read_csv("../../data/processed/prompt_reply_pairs.csv", sep=";")
data

Unnamed: 0,prompt,reply,rank,split
0,"V mojo hišo bodo prispeli 3 prijatelji, rad bi...",Kaj pa igranje lažnivca? Za to igro so potrebn...,0,train
1,"V mojo hišo bodo prispeli 3 prijatelji, rad bi...","Uno: Je klasična igra za otroke in odrasle, pr...",1,train
2,"V mojo hišo bodo prispeli 3 prijatelji, rad bi...",Igrate lahko:\n - 8 noro\n - celina\n - Lažnivec,2,train
3,Napišite hvalnico šolskemu taboru v slogu Mela...,Spomnim se vas prijatelji\nIn spet sva se sreč...,0,train
4,Ali kaj bolj veselega?,♪ Spakirajte in se pridružite zabavi!♪\n♪ Šols...,0,train
...,...,...,...,...
47509,"Hvala, bom videl, če lahko naredim te sprememb...",Elektronski seti bobnov so na splošno tišji od...,0,test
47510,"Hvala, bom videl, če lahko naredim te sprememb...","To je odvisno od seta bobnov, ki ga trenutno u...",1,test
47511,"Imam elektronski set bobnov, a vsakič, ko igra...",Pod komplet lahko poskusite postaviti izolacij...,2,test
47512,Kadarkoli poskušam postaviti izolacijski mater...,"Da, obstaja več načinov, kako to popraviti.\n\...",0,test


In [5]:
train_data = data[data["split"] == "train"]
val_data = data[data["split"] == "val"]
test_data = data[data["split"] == "test"]
len(train_data), len(val_data), len(test_data)

(37902, 4737, 4875)

In [6]:
train_data = Dataset.from_pandas(train_data[['prompt', 'reply']])
val_data = Dataset.from_pandas(val_data[['prompt', 'reply']])
test_data = Dataset.from_pandas(test_data[['prompt', 'reply']])
train_data

Dataset({
    features: ['prompt', 'reply', '__index_level_0__'],
    num_rows: 37902
})

In [7]:
def convert_to_features(examples):
    prefix_in = "Uporabnik: "
    # prefix_in = ""
    examples["prompt"] = [prefix_in + prompt for prompt in examples["prompt"]]
    # prefix_out = "Asistent: "
    prefix_out = ""
    examples["reply"] = [prefix_out + reply for reply in examples["reply"]]
    
    model_inputs = tokenizer(examples['prompt'], pad_to_max_length=True, max_length=512, truncation=True, return_tensors='pt')

    with tokenizer.as_target_tokenizer():
        labels = tokenizer(examples['reply'], pad_to_max_length=True, max_length=128, truncation=True, return_tensors='pt')

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [8]:
train_data = train_data.map(convert_to_features, batched=True, load_from_cache_file=False)
val_data = val_data.map(convert_to_features, batched=True, load_from_cache_file=False)
test_data = test_data.map(convert_to_features, batched=True, load_from_cache_file=False)
train_data

                                                                   

Dataset({
    features: ['prompt', 'reply', '__index_level_0__', 'input_ids', 'attention_mask', 'labels'],
    num_rows: 37902
})

### Fine Tune

In [9]:
metric = load_metric("rouge")

  metric = load_metric("rouge")


In [10]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    # Replace -100 in the labels as we can't decode them.
    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    
    # Rouge expects a newline after each sentence
    decoded_preds = ["\n".join(nltk.sent_tokenize(pred.strip())) for pred in decoded_preds]
    decoded_labels = ["\n".join(nltk.sent_tokenize(label.strip())) for label in decoded_labels]
    
    result = metric.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
    # Extract a few results
    result = {key: value.mid.fmeasure * 100 for key, value in result.items()}
    
    # Add mean generated length
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predictions]
    result["gen_len"] = np.mean(prediction_lens)
    
    return {k: round(v, 4) for k, v in result.items()}

In [11]:
training_args = Seq2SeqTrainingArguments(
    output_dir=f"{model_path}/{name}-finetuned-assistant",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=epochs,
    predict_with_generate=True,
    fp16=False, # setting this to true gives loss 0.0 at every step for some reason
    push_to_hub=False, 
    load_best_model_at_end=True # load best model at end so we save the best model instead of the last model
)

data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=val_data,
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics
)

In [12]:
trainer.train()
trainer.save_model(f"{model_path}/{name}-finetuned-assistant")

  0%|          | 0/59225 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  1%|          | 500/59225 [01:53<3:37:54,  4.49it/s]

{'loss': 9.3774, 'learning_rate': 1.9831152384972564e-05, 'epoch': 0.21}


  2%|▏         | 1000/59225 [03:45<3:36:32,  4.48it/s]

{'loss': 3.3558, 'learning_rate': 1.9662304769945127e-05, 'epoch': 0.42}


  3%|▎         | 1500/59225 [05:37<3:34:22,  4.49it/s]

{'loss': 3.2196, 'learning_rate': 1.949345715491769e-05, 'epoch': 0.63}


  3%|▎         | 2000/59225 [07:29<3:33:41,  4.46it/s]

{'loss': 3.1523, 'learning_rate': 1.932460953989025e-05, 'epoch': 0.84}


                                                      
  4%|▍         | 2369/59225 [10:33<3:25:36,  4.61it/s]

{'eval_loss': 2.629715919494629, 'eval_rouge1': 12.011, 'eval_rouge2': 2.3305, 'eval_rougeL': 9.2253, 'eval_rougeLsum': 10.5675, 'eval_gen_len': 18.8035, 'eval_runtime': 100.4999, 'eval_samples_per_second': 47.134, 'eval_steps_per_second': 2.955, 'epoch': 1.0}


  4%|▍         | 2500/59225 [11:03<3:31:36,  4.47it/s]  

{'loss': 3.1039, 'learning_rate': 1.9155761924862814e-05, 'epoch': 1.06}


  5%|▌         | 3000/59225 [12:55<3:29:57,  4.46it/s]

{'loss': 3.0824, 'learning_rate': 1.8986914309835376e-05, 'epoch': 1.27}


  6%|▌         | 3500/59225 [14:47<3:27:28,  4.48it/s]

{'loss': 3.0579, 'learning_rate': 1.8818066694807936e-05, 'epoch': 1.48}


  7%|▋         | 4000/59225 [16:37<3:20:18,  4.60it/s]

{'loss': 3.0134, 'learning_rate': 1.86492190797805e-05, 'epoch': 1.69}


  8%|▊         | 4500/59225 [18:25<3:17:43,  4.61it/s]

{'loss': 3.0032, 'learning_rate': 1.848037146475306e-05, 'epoch': 1.9}


                                                      
  8%|▊         | 4738/59225 [20:55<3:11:19,  4.75it/s]

{'eval_loss': 2.5587046146392822, 'eval_rouge1': 11.7592, 'eval_rouge2': 3.0415, 'eval_rougeL': 9.5133, 'eval_rougeLsum': 10.7109, 'eval_gen_len': 18.5168, 'eval_runtime': 98.0345, 'eval_samples_per_second': 48.32, 'eval_steps_per_second': 3.03, 'epoch': 2.0}


  8%|▊         | 5000/59225 [21:53<3:15:43,  4.62it/s]  

{'loss': 2.9785, 'learning_rate': 1.8311523849725623e-05, 'epoch': 2.11}


  9%|▉         | 5500/59225 [23:41<3:12:53,  4.64it/s]

{'loss': 2.9604, 'learning_rate': 1.8142676234698185e-05, 'epoch': 2.32}


 10%|█         | 6000/59225 [25:30<3:14:34,  4.56it/s]

{'loss': 2.9441, 'learning_rate': 1.7973828619670748e-05, 'epoch': 2.53}


 11%|█         | 6500/59225 [27:18<3:09:48,  4.63it/s]

{'loss': 2.956, 'learning_rate': 1.780498100464331e-05, 'epoch': 2.74}


 12%|█▏        | 7000/59225 [29:07<3:08:18,  4.62it/s]

{'loss': 2.9133, 'learning_rate': 1.7636133389615873e-05, 'epoch': 2.95}


                                                      
 12%|█▏        | 7107/59225 [31:07<3:01:47,  4.78it/s]

{'eval_loss': 2.5168209075927734, 'eval_rouge1': 11.8015, 'eval_rouge2': 3.0852, 'eval_rougeL': 9.6113, 'eval_rougeLsum': 10.769, 'eval_gen_len': 18.6863, 'eval_runtime': 97.1746, 'eval_samples_per_second': 48.747, 'eval_steps_per_second': 3.056, 'epoch': 3.0}


 13%|█▎        | 7500/59225 [32:33<3:06:51,  4.61it/s]  

{'loss': 2.9082, 'learning_rate': 1.7467285774588435e-05, 'epoch': 3.17}


 14%|█▎        | 8001/59225 [34:22<3:04:25,  4.63it/s]

{'loss': 2.9176, 'learning_rate': 1.7298438159560998e-05, 'epoch': 3.38}


 14%|█▍        | 8500/59225 [36:10<3:02:48,  4.62it/s]

{'loss': 2.8492, 'learning_rate': 1.712959054453356e-05, 'epoch': 3.59}


 15%|█▌        | 9000/59225 [37:59<3:01:11,  4.62it/s]

{'loss': 2.9033, 'learning_rate': 1.6960742929506123e-05, 'epoch': 3.8}


                                                      
 16%|█▌        | 9476/59225 [41:20<2:56:55,  4.69it/s]

{'eval_loss': 2.4901955127716064, 'eval_rouge1': 11.7933, 'eval_rouge2': 3.1537, 'eval_rougeL': 9.6646, 'eval_rougeLsum': 10.8236, 'eval_gen_len': 18.8244, 'eval_runtime': 97.9017, 'eval_samples_per_second': 48.385, 'eval_steps_per_second': 3.034, 'epoch': 4.0}


 16%|█▌        | 9500/59225 [41:27<3:06:07,  4.45it/s]  

{'loss': 2.8646, 'learning_rate': 1.6791895314478682e-05, 'epoch': 4.01}


 17%|█▋        | 10000/59225 [43:15<2:57:50,  4.61it/s]

{'loss': 2.8629, 'learning_rate': 1.6623047699451248e-05, 'epoch': 4.22}


 18%|█▊        | 10500/59225 [45:04<2:54:59,  4.64it/s]

{'loss': 2.8583, 'learning_rate': 1.6454200084423807e-05, 'epoch': 4.43}


 19%|█▊        | 11000/59225 [46:52<2:55:30,  4.58it/s]

{'loss': 2.8396, 'learning_rate': 1.6285352469396373e-05, 'epoch': 4.64}


 19%|█▉        | 11500/59225 [48:40<2:52:23,  4.61it/s]

{'loss': 2.854, 'learning_rate': 1.6116504854368932e-05, 'epoch': 4.85}


                                                       
 20%|██        | 11845/59225 [51:34<2:47:27,  4.72it/s]

{'eval_loss': 2.4688475131988525, 'eval_rouge1': 12.0306, 'eval_rouge2': 3.3307, 'eval_rougeL': 9.8763, 'eval_rougeLsum': 11.0319, 'eval_gen_len': 18.8723, 'eval_runtime': 98.9753, 'eval_samples_per_second': 47.86, 'eval_steps_per_second': 3.001, 'epoch': 5.0}


 20%|██        | 12001/59225 [52:09<2:49:47,  4.64it/s]  

{'loss': 2.795, 'learning_rate': 1.5947657239341498e-05, 'epoch': 5.07}


 21%|██        | 12500/59225 [53:57<2:47:52,  4.64it/s]

{'loss': 2.823, 'learning_rate': 1.5778809624314057e-05, 'epoch': 5.28}


 22%|██▏       | 13000/59225 [55:46<2:46:39,  4.62it/s]

{'loss': 2.8139, 'learning_rate': 1.5609962009286623e-05, 'epoch': 5.49}


 23%|██▎       | 13500/59225 [57:34<2:44:45,  4.63it/s]

{'loss': 2.8377, 'learning_rate': 1.544111439425918e-05, 'epoch': 5.7}


 24%|██▎       | 14000/59225 [59:23<2:43:02,  4.62it/s]

{'loss': 2.8129, 'learning_rate': 1.5272266779231744e-05, 'epoch': 5.91}


                                                         
 24%|██▍       | 14214/59225 [1:01:48<2:37:02,  4.78it/s]

{'eval_loss': 2.4567205905914307, 'eval_rouge1': 12.1489, 'eval_rouge2': 3.357, 'eval_rougeL': 9.9491, 'eval_rougeLsum': 11.1134, 'eval_gen_len': 18.8469, 'eval_runtime': 98.6145, 'eval_samples_per_second': 48.036, 'eval_steps_per_second': 3.012, 'epoch': 6.0}


 24%|██▍       | 14500/59225 [1:02:51<2:41:38,  4.61it/s]  

{'loss': 2.7946, 'learning_rate': 1.5103419164204307e-05, 'epoch': 6.12}


 25%|██▌       | 15000/59225 [1:04:39<2:39:45,  4.61it/s]

{'loss': 2.7893, 'learning_rate': 1.4934571549176869e-05, 'epoch': 6.33}


 26%|██▌       | 15501/59225 [1:06:28<2:37:15,  4.63it/s]

{'loss': 2.7815, 'learning_rate': 1.4765723934149432e-05, 'epoch': 6.54}


 27%|██▋       | 16001/59225 [1:08:17<2:36:05,  4.62it/s]

{'loss': 2.8038, 'learning_rate': 1.4596876319121992e-05, 'epoch': 6.75}


 28%|██▊       | 16500/59225 [1:10:05<2:33:54,  4.63it/s]

{'loss': 2.7768, 'learning_rate': 1.4428028704094557e-05, 'epoch': 6.96}


                                                         
 28%|██▊       | 16583/59225 [1:12:00<2:29:43,  4.75it/s]

{'eval_loss': 2.4409737586975098, 'eval_rouge1': 12.1656, 'eval_rouge2': 3.381, 'eval_rougeL': 9.9821, 'eval_rougeLsum': 11.1229, 'eval_gen_len': 18.8581, 'eval_runtime': 96.9213, 'eval_samples_per_second': 48.875, 'eval_steps_per_second': 3.064, 'epoch': 7.0}


 29%|██▊       | 17000/59225 [1:13:32<2:32:48,  4.61it/s]  

{'loss': 2.7497, 'learning_rate': 1.4259181089067117e-05, 'epoch': 7.18}


 30%|██▉       | 17500/59225 [1:15:20<2:31:08,  4.60it/s]

{'loss': 2.7651, 'learning_rate': 1.4090333474039681e-05, 'epoch': 7.39}


 30%|███       | 18000/59225 [1:17:09<2:27:57,  4.64it/s]

{'loss': 2.7585, 'learning_rate': 1.3921485859012242e-05, 'epoch': 7.6}


 31%|███       | 18500/59225 [1:18:57<2:27:38,  4.60it/s]

{'loss': 2.7899, 'learning_rate': 1.3752638243984803e-05, 'epoch': 7.81}


                                                         
 32%|███▏      | 18952/59225 [1:22:13<2:23:07,  4.69it/s]

{'eval_loss': 2.4297287464141846, 'eval_rouge1': 12.2573, 'eval_rouge2': 3.482, 'eval_rougeL': 10.0466, 'eval_rougeLsum': 11.1662, 'eval_gen_len': 18.845, 'eval_runtime': 97.7742, 'eval_samples_per_second': 48.448, 'eval_steps_per_second': 3.038, 'epoch': 8.0}


 32%|███▏      | 19000/59225 [1:22:24<2:24:28,  4.64it/s]  

{'loss': 2.7665, 'learning_rate': 1.3583790628957367e-05, 'epoch': 8.02}


 33%|███▎      | 19500/59225 [1:24:13<2:22:52,  4.63it/s]

{'loss': 2.7191, 'learning_rate': 1.3414943013929928e-05, 'epoch': 8.23}


 34%|███▍      | 20000/59225 [1:26:01<2:22:03,  4.60it/s]

{'loss': 2.7481, 'learning_rate': 1.3246095398902492e-05, 'epoch': 8.44}


 35%|███▍      | 20500/59225 [1:27:49<2:20:08,  4.61it/s]

{'loss': 2.764, 'learning_rate': 1.3077247783875053e-05, 'epoch': 8.65}


 35%|███▌      | 21000/59225 [1:29:38<2:17:48,  4.62it/s]

{'loss': 2.7604, 'learning_rate': 1.2908400168847617e-05, 'epoch': 8.86}


                                                         
 36%|███▌      | 21321/59225 [1:32:27<2:13:30,  4.73it/s]

{'eval_loss': 2.422267198562622, 'eval_rouge1': 12.2087, 'eval_rouge2': 3.4995, 'eval_rougeL': 10.024, 'eval_rougeLsum': 11.1462, 'eval_gen_len': 18.8461, 'eval_runtime': 99.1319, 'eval_samples_per_second': 47.785, 'eval_steps_per_second': 2.996, 'epoch': 9.0}


 36%|███▋      | 21500/59225 [1:33:07<2:15:50,  4.63it/s]  

{'loss': 2.711, 'learning_rate': 1.2739552553820178e-05, 'epoch': 9.08}


 37%|███▋      | 22000/59225 [1:34:55<2:15:07,  4.59it/s]

{'loss': 2.7371, 'learning_rate': 1.2570704938792742e-05, 'epoch': 9.29}


 38%|███▊      | 22500/59225 [1:36:44<2:13:34,  4.58it/s]

{'loss': 2.7275, 'learning_rate': 1.2401857323765303e-05, 'epoch': 9.5}


 39%|███▉      | 23000/59225 [1:38:32<2:09:58,  4.64it/s]

{'loss': 2.721, 'learning_rate': 1.2233009708737864e-05, 'epoch': 9.71}


 40%|███▉      | 23500/59225 [1:40:21<2:08:11,  4.64it/s]

{'loss': 2.7477, 'learning_rate': 1.2064162093710428e-05, 'epoch': 9.92}


                                                         
 40%|████      | 23690/59225 [1:42:38<2:04:56,  4.74it/s]

{'eval_loss': 2.413259744644165, 'eval_rouge1': 12.1915, 'eval_rouge2': 3.4595, 'eval_rougeL': 10.0354, 'eval_rougeLsum': 11.1265, 'eval_gen_len': 18.8153, 'eval_runtime': 96.3769, 'eval_samples_per_second': 49.151, 'eval_steps_per_second': 3.082, 'epoch': 10.0}


 41%|████      | 24000/59225 [1:43:47<2:07:02,  4.62it/s]  

{'loss': 2.7017, 'learning_rate': 1.1895314478682989e-05, 'epoch': 10.13}


 41%|████▏     | 24500/59225 [1:45:35<2:04:49,  4.64it/s]

{'loss': 2.7537, 'learning_rate': 1.1726466863655553e-05, 'epoch': 10.34}


 42%|████▏     | 25000/59225 [1:47:24<2:03:04,  4.63it/s]

{'loss': 2.7013, 'learning_rate': 1.1557619248628114e-05, 'epoch': 10.55}


 43%|████▎     | 25500/59225 [1:49:12<2:01:27,  4.63it/s]

{'loss': 2.713, 'learning_rate': 1.1388771633600678e-05, 'epoch': 10.76}


 44%|████▍     | 26000/59225 [1:51:01<2:01:18,  4.56it/s]

{'loss': 2.6872, 'learning_rate': 1.1219924018573238e-05, 'epoch': 10.98}


                                                         
 44%|████▍     | 26059/59225 [1:52:50<1:56:50,  4.73it/s]

{'eval_loss': 2.407684326171875, 'eval_rouge1': 12.3677, 'eval_rouge2': 3.5926, 'eval_rougeL': 10.1637, 'eval_rougeLsum': 11.2903, 'eval_gen_len': 18.7912, 'eval_runtime': 96.5508, 'eval_samples_per_second': 49.062, 'eval_steps_per_second': 3.076, 'epoch': 11.0}


 45%|████▍     | 26500/59225 [1:54:26<1:58:37,  4.60it/s]  

{'loss': 2.7151, 'learning_rate': 1.1051076403545803e-05, 'epoch': 11.19}


 46%|████▌     | 27000/59225 [1:56:15<1:57:04,  4.59it/s]

{'loss': 2.6982, 'learning_rate': 1.0882228788518363e-05, 'epoch': 11.4}


 46%|████▋     | 27500/59225 [1:58:04<1:54:55,  4.60it/s]

{'loss': 2.7138, 'learning_rate': 1.0713381173490924e-05, 'epoch': 11.61}


 47%|████▋     | 28000/59225 [1:59:53<1:52:36,  4.62it/s]

{'loss': 2.6736, 'learning_rate': 1.0544533558463488e-05, 'epoch': 11.82}


                                                         
 48%|████▊     | 28428/59225 [2:03:05<1:49:14,  4.70it/s]

{'eval_loss': 2.400829315185547, 'eval_rouge1': 12.2403, 'eval_rouge2': 3.5499, 'eval_rougeL': 10.069, 'eval_rougeLsum': 11.1895, 'eval_gen_len': 18.7612, 'eval_runtime': 99.2845, 'eval_samples_per_second': 47.711, 'eval_steps_per_second': 2.991, 'epoch': 12.0}


 48%|████▊     | 28500/59225 [2:03:22<1:50:02,  4.65it/s]  

{'loss': 2.7115, 'learning_rate': 1.0375685943436049e-05, 'epoch': 12.03}


 49%|████▉     | 29000/59225 [2:05:10<1:49:08,  4.62it/s]

{'loss': 2.6939, 'learning_rate': 1.0206838328408613e-05, 'epoch': 12.24}


 50%|████▉     | 29500/59225 [2:06:58<1:48:56,  4.55it/s]

{'loss': 2.6783, 'learning_rate': 1.0037990713381174e-05, 'epoch': 12.45}


 51%|█████     | 30000/59225 [2:08:47<1:47:16,  4.54it/s]

{'loss': 2.6923, 'learning_rate': 9.869143098353737e-06, 'epoch': 12.66}


 51%|█████▏    | 30500/59225 [2:10:35<1:45:42,  4.53it/s]

{'loss': 2.6916, 'learning_rate': 9.700295483326299e-06, 'epoch': 12.87}


                                                         
 52%|█████▏    | 30797/59225 [2:13:18<1:39:01,  4.78it/s]

{'eval_loss': 2.397221565246582, 'eval_rouge1': 12.2623, 'eval_rouge2': 3.5395, 'eval_rougeL': 10.1046, 'eval_rougeLsum': 11.2007, 'eval_gen_len': 18.7636, 'eval_runtime': 97.7456, 'eval_samples_per_second': 48.463, 'eval_steps_per_second': 3.038, 'epoch': 13.0}


 52%|█████▏    | 31000/59225 [2:14:03<1:41:59,  4.61it/s]  

{'loss': 2.675, 'learning_rate': 9.531447868298862e-06, 'epoch': 13.09}


 53%|█████▎    | 31500/59225 [2:15:51<1:40:12,  4.61it/s]

{'loss': 2.6902, 'learning_rate': 9.362600253271424e-06, 'epoch': 13.3}


 54%|█████▍    | 32000/59225 [2:17:39<1:37:32,  4.65it/s]

{'loss': 2.6864, 'learning_rate': 9.193752638243986e-06, 'epoch': 13.51}


 55%|█████▍    | 32500/59225 [2:19:28<1:36:08,  4.63it/s]

{'loss': 2.6846, 'learning_rate': 9.024905023216547e-06, 'epoch': 13.72}


 56%|█████▌    | 33000/59225 [2:21:16<1:34:46,  4.61it/s]

{'loss': 2.656, 'learning_rate': 8.85605740818911e-06, 'epoch': 13.93}


                                                         
 56%|█████▌    | 33166/59225 [2:23:30<1:31:10,  4.76it/s]

{'eval_loss': 2.392380714416504, 'eval_rouge1': 12.2214, 'eval_rouge2': 3.5996, 'eval_rougeL': 10.0848, 'eval_rougeLsum': 11.1798, 'eval_gen_len': 18.7368, 'eval_runtime': 97.9942, 'eval_samples_per_second': 48.34, 'eval_steps_per_second': 3.031, 'epoch': 14.0}


 57%|█████▋    | 33500/59225 [2:24:44<1:34:30,  4.54it/s]  

{'loss': 2.6838, 'learning_rate': 8.687209793161672e-06, 'epoch': 14.14}


 57%|█████▋    | 34000/59225 [2:26:32<1:31:20,  4.60it/s]

{'loss': 2.6621, 'learning_rate': 8.518362178134235e-06, 'epoch': 14.35}


 58%|█████▊    | 34500/59225 [2:28:21<1:29:13,  4.62it/s]

{'loss': 2.6546, 'learning_rate': 8.349514563106797e-06, 'epoch': 14.56}


 59%|█████▉    | 35000/59225 [2:30:09<1:27:38,  4.61it/s]

{'loss': 2.6867, 'learning_rate': 8.18066694807936e-06, 'epoch': 14.77}


 60%|█████▉    | 35500/59225 [2:31:58<1:26:13,  4.59it/s]

{'loss': 2.6608, 'learning_rate': 8.01181933305192e-06, 'epoch': 14.99}


                                                         
 60%|██████    | 35535/59225 [2:33:44<1:22:50,  4.77it/s]

{'eval_loss': 2.3895113468170166, 'eval_rouge1': 12.3468, 'eval_rouge2': 3.6137, 'eval_rougeL': 10.1645, 'eval_rougeLsum': 11.2768, 'eval_gen_len': 18.7739, 'eval_runtime': 98.2553, 'eval_samples_per_second': 48.211, 'eval_steps_per_second': 3.023, 'epoch': 15.0}


 61%|██████    | 36000/59225 [2:35:26<1:23:54,  4.61it/s]  

{'loss': 2.6529, 'learning_rate': 7.842971718024483e-06, 'epoch': 15.2}


 62%|██████▏   | 36501/59225 [2:37:14<1:22:12,  4.61it/s]

{'loss': 2.6771, 'learning_rate': 7.674124102997045e-06, 'epoch': 15.41}


 62%|██████▏   | 37000/59225 [2:39:03<1:20:39,  4.59it/s]

{'loss': 2.6479, 'learning_rate': 7.505276487969609e-06, 'epoch': 15.62}


 63%|██████▎   | 37500/59225 [2:40:51<1:19:23,  4.56it/s]

{'loss': 2.6923, 'learning_rate': 7.33642887294217e-06, 'epoch': 15.83}


                                                         
 64%|██████▍   | 37904/59225 [2:43:58<1:14:28,  4.77it/s]

{'eval_loss': 2.3844408988952637, 'eval_rouge1': 12.3332, 'eval_rouge2': 3.605, 'eval_rougeL': 10.1611, 'eval_rougeLsum': 11.2629, 'eval_gen_len': 18.7365, 'eval_runtime': 99.41, 'eval_samples_per_second': 47.651, 'eval_steps_per_second': 2.988, 'epoch': 16.0}


 64%|██████▍   | 38000/59225 [2:44:20<1:16:10,  4.64it/s]  

{'loss': 2.6397, 'learning_rate': 7.167581257914733e-06, 'epoch': 16.04}


 65%|██████▌   | 38500/59225 [2:46:08<1:14:25,  4.64it/s]

{'loss': 2.6595, 'learning_rate': 6.998733642887294e-06, 'epoch': 16.25}


 66%|██████▌   | 39000/59225 [2:47:57<1:12:46,  4.63it/s]

{'loss': 2.6517, 'learning_rate': 6.829886027859857e-06, 'epoch': 16.46}


 67%|██████▋   | 39500/59225 [2:49:45<1:11:49,  4.58it/s]

{'loss': 2.6581, 'learning_rate': 6.661038412832419e-06, 'epoch': 16.67}


 68%|██████▊   | 40000/59225 [2:51:34<1:09:06,  4.64it/s]

{'loss': 2.6588, 'learning_rate': 6.492190797804982e-06, 'epoch': 16.88}


                                                         
 68%|██████▊   | 40273/59225 [2:54:11<1:05:47,  4.80it/s]

{'eval_loss': 2.3824191093444824, 'eval_rouge1': 12.3387, 'eval_rouge2': 3.6234, 'eval_rougeL': 10.1743, 'eval_rougeLsum': 11.2722, 'eval_gen_len': 18.7634, 'eval_runtime': 97.9776, 'eval_samples_per_second': 48.348, 'eval_steps_per_second': 3.031, 'epoch': 17.0}


 68%|██████▊   | 40500/59225 [2:55:01<1:07:48,  4.60it/s]  

{'loss': 2.6344, 'learning_rate': 6.3233431827775435e-06, 'epoch': 17.1}


 69%|██████▉   | 41000/59225 [2:56:50<1:06:14,  4.59it/s]

{'loss': 2.6496, 'learning_rate': 6.154495567750106e-06, 'epoch': 17.31}


 70%|███████   | 41500/59225 [2:58:38<1:05:18,  4.52it/s]

{'loss': 2.6426, 'learning_rate': 5.985647952722668e-06, 'epoch': 17.52}


 71%|███████   | 42000/59225 [3:00:27<1:01:55,  4.64it/s]

{'loss': 2.66, 'learning_rate': 5.816800337695231e-06, 'epoch': 17.73}


 72%|███████▏  | 42500/59225 [3:02:16<1:00:07,  4.64it/s]

{'loss': 2.6482, 'learning_rate': 5.647952722667793e-06, 'epoch': 17.94}


                                                         
 72%|███████▏  | 42642/59225 [3:04:24<58:15,  4.74it/s]

{'eval_loss': 2.378870964050293, 'eval_rouge1': 12.3815, 'eval_rouge2': 3.6482, 'eval_rougeL': 10.2005, 'eval_rougeLsum': 11.2806, 'eval_gen_len': 18.7401, 'eval_runtime': 97.27, 'eval_samples_per_second': 48.699, 'eval_steps_per_second': 3.053, 'epoch': 18.0}


 73%|███████▎  | 43000/59225 [3:05:43<58:20,  4.63it/s]    

{'loss': 2.6366, 'learning_rate': 5.479105107640354e-06, 'epoch': 18.15}


 73%|███████▎  | 43500/59225 [3:07:31<57:06,  4.59it/s]

{'loss': 2.6689, 'learning_rate': 5.310257492612917e-06, 'epoch': 18.36}


 74%|███████▍  | 44000/59225 [3:09:20<54:47,  4.63it/s]

{'loss': 2.6514, 'learning_rate': 5.141409877585479e-06, 'epoch': 18.57}


 75%|███████▌  | 44500/59225 [3:11:08<53:03,  4.62it/s]

{'loss': 2.6449, 'learning_rate': 4.9725622625580416e-06, 'epoch': 18.78}


 76%|███████▌  | 45000/59225 [3:12:57<51:30,  4.60it/s]

{'loss': 2.6111, 'learning_rate': 4.803714647530604e-06, 'epoch': 19.0}


                                                       
 76%|███████▌  | 45011/59225 [3:14:38<49:50,  4.75it/s]

{'eval_loss': 2.3771345615386963, 'eval_rouge1': 12.384, 'eval_rouge2': 3.6596, 'eval_rougeL': 10.19, 'eval_rougeLsum': 11.2949, 'eval_gen_len': 18.734, 'eval_runtime': 98.3774, 'eval_samples_per_second': 48.151, 'eval_steps_per_second': 3.019, 'epoch': 19.0}


 77%|███████▋  | 45500/59225 [3:16:25<49:27,  4.63it/s]    

{'loss': 2.6759, 'learning_rate': 4.6348670325031665e-06, 'epoch': 19.21}


 78%|███████▊  | 46000/59225 [3:18:14<47:33,  4.63it/s]

{'loss': 2.6245, 'learning_rate': 4.466019417475729e-06, 'epoch': 19.42}


 79%|███████▊  | 46500/59225 [3:20:02<46:42,  4.54it/s]

{'loss': 2.6251, 'learning_rate': 4.297171802448291e-06, 'epoch': 19.63}


 79%|███████▉  | 47000/59225 [3:21:51<43:57,  4.63it/s]

{'loss': 2.6308, 'learning_rate': 4.128324187420853e-06, 'epoch': 19.84}


                                                       
 80%|████████  | 47380/59225 [3:24:50<41:35,  4.75it/s]

{'eval_loss': 2.376089096069336, 'eval_rouge1': 12.4266, 'eval_rouge2': 3.6597, 'eval_rougeL': 10.2388, 'eval_rougeLsum': 11.3311, 'eval_gen_len': 18.7226, 'eval_runtime': 96.9432, 'eval_samples_per_second': 48.864, 'eval_steps_per_second': 3.064, 'epoch': 20.0}


 80%|████████  | 47500/59225 [3:25:18<42:30,  4.60it/s]   

{'loss': 2.6371, 'learning_rate': 3.9594765723934156e-06, 'epoch': 20.05}


 81%|████████  | 48000/59225 [3:27:06<40:27,  4.62it/s]

{'loss': 2.6404, 'learning_rate': 3.7906289573659776e-06, 'epoch': 20.26}


 82%|████████▏ | 48500/59225 [3:28:55<38:49,  4.60it/s]

{'loss': 2.626, 'learning_rate': 3.6217813423385397e-06, 'epoch': 20.47}


 83%|████████▎ | 49000/59225 [3:30:43<37:22,  4.56it/s]

{'loss': 2.6024, 'learning_rate': 3.4529337273111017e-06, 'epoch': 20.68}


 84%|████████▎ | 49500/59225 [3:32:32<35:11,  4.61it/s]

{'loss': 2.6531, 'learning_rate': 3.284086112283664e-06, 'epoch': 20.89}


                                                       
 84%|████████▍ | 49749/59225 [3:35:03<33:08,  4.76it/s]

{'eval_loss': 2.373603343963623, 'eval_rouge1': 12.4487, 'eval_rouge2': 3.6668, 'eval_rougeL': 10.2266, 'eval_rougeLsum': 11.3261, 'eval_gen_len': 18.7361, 'eval_runtime': 96.6933, 'eval_samples_per_second': 48.99, 'eval_steps_per_second': 3.072, 'epoch': 21.0}


 84%|████████▍ | 50000/59225 [3:35:58<33:22,  4.61it/s]   

{'loss': 2.6419, 'learning_rate': 3.1152384972562267e-06, 'epoch': 21.11}


 85%|████████▌ | 50500/59225 [3:37:47<31:25,  4.63it/s]

{'loss': 2.6346, 'learning_rate': 2.946390882228789e-06, 'epoch': 21.32}


 86%|████████▌ | 51000/59225 [3:39:35<29:36,  4.63it/s]

{'loss': 2.6358, 'learning_rate': 2.7775432672013508e-06, 'epoch': 21.53}


 87%|████████▋ | 51501/59225 [3:41:24<27:46,  4.63it/s]

{'loss': 2.6222, 'learning_rate': 2.6086956521739132e-06, 'epoch': 21.74}


 88%|████████▊ | 52000/59225 [3:43:12<26:09,  4.60it/s]

{'loss': 2.6341, 'learning_rate': 2.4398480371464757e-06, 'epoch': 21.95}


                                                       
 88%|████████▊ | 52118/59225 [3:45:16<24:47,  4.78it/s]

{'eval_loss': 2.373302936553955, 'eval_rouge1': 12.4472, 'eval_rouge2': 3.6507, 'eval_rougeL': 10.2487, 'eval_rougeLsum': 11.3399, 'eval_gen_len': 18.7251, 'eval_runtime': 98.3829, 'eval_samples_per_second': 48.149, 'eval_steps_per_second': 3.019, 'epoch': 22.0}


 89%|████████▊ | 52500/59225 [3:46:40<24:41,  4.54it/s]   

{'loss': 2.6261, 'learning_rate': 2.2710004221190378e-06, 'epoch': 22.16}


 89%|████████▉ | 53000/59225 [3:48:28<22:19,  4.65it/s]

{'loss': 2.6176, 'learning_rate': 2.1021528070916e-06, 'epoch': 22.37}


 90%|█████████ | 53500/59225 [3:50:17<20:38,  4.62it/s]

{'loss': 2.6108, 'learning_rate': 1.9333051920641623e-06, 'epoch': 22.58}


 91%|█████████ | 54000/59225 [3:52:05<18:42,  4.66it/s]

{'loss': 2.6485, 'learning_rate': 1.7644575770367246e-06, 'epoch': 22.79}


                                                       
 92%|█████████▏| 54487/59225 [3:55:29<16:32,  4.78it/s]

{'eval_loss': 2.37231183052063, 'eval_rouge1': 12.4365, 'eval_rouge2': 3.6613, 'eval_rougeL': 10.2465, 'eval_rougeLsum': 11.3296, 'eval_gen_len': 18.7205, 'eval_runtime': 97.9087, 'eval_samples_per_second': 48.382, 'eval_steps_per_second': 3.033, 'epoch': 23.0}


 92%|█████████▏| 54500/59225 [3:55:33<49:26,  1.59it/s]   

{'loss': 2.6509, 'learning_rate': 1.5956099620092868e-06, 'epoch': 23.01}


 93%|█████████▎| 55000/59225 [3:57:21<15:13,  4.62it/s]

{'loss': 2.6426, 'learning_rate': 1.4267623469818489e-06, 'epoch': 23.22}


 94%|█████████▎| 55500/59225 [3:59:10<13:23,  4.63it/s]

{'loss': 2.5863, 'learning_rate': 1.2579147319544113e-06, 'epoch': 23.43}


 95%|█████████▍| 56000/59225 [4:00:58<11:41,  4.60it/s]

{'loss': 2.6509, 'learning_rate': 1.0890671169269736e-06, 'epoch': 23.64}


 95%|█████████▌| 56500/59225 [4:02:47<09:54,  4.59it/s]

{'loss': 2.6393, 'learning_rate': 9.202195018995358e-07, 'epoch': 23.85}


                                                       
 96%|█████████▌| 56856/59225 [4:05:42<08:15,  4.78it/s]

{'eval_loss': 2.3718297481536865, 'eval_rouge1': 12.4361, 'eval_rouge2': 3.647, 'eval_rougeL': 10.2355, 'eval_rougeLsum': 11.324, 'eval_gen_len': 18.7224, 'eval_runtime': 98.3658, 'eval_samples_per_second': 48.157, 'eval_steps_per_second': 3.019, 'epoch': 24.0}


 96%|█████████▌| 57001/59225 [4:06:15<07:58,  4.64it/s]   

{'loss': 2.6359, 'learning_rate': 7.51371886872098e-07, 'epoch': 24.06}


 97%|█████████▋| 57500/59225 [4:08:03<06:12,  4.63it/s]

{'loss': 2.6373, 'learning_rate': 5.825242718446603e-07, 'epoch': 24.27}


 98%|█████████▊| 58000/59225 [4:09:52<04:25,  4.62it/s]

{'loss': 2.6286, 'learning_rate': 4.136766568172225e-07, 'epoch': 24.48}


 99%|█████████▉| 58501/59225 [4:11:41<02:36,  4.63it/s]

{'loss': 2.6243, 'learning_rate': 2.448290417897847e-07, 'epoch': 24.69}


100%|█████████▉| 59000/59225 [4:13:29<00:48,  4.63it/s]

{'loss': 2.5903, 'learning_rate': 7.598142676234698e-08, 'epoch': 24.91}


                                                       
100%|██████████| 59225/59225 [4:15:56<00:00,  4.77it/s]

{'eval_loss': 2.37141489982605, 'eval_rouge1': 12.4506, 'eval_rouge2': 3.6566, 'eval_rougeL': 10.2485, 'eval_rougeLsum': 11.3412, 'eval_gen_len': 18.7218, 'eval_runtime': 98.3333, 'eval_samples_per_second': 48.173, 'eval_steps_per_second': 3.02, 'epoch': 25.0}


100%|██████████| 59225/59225 [4:15:57<00:00,  3.86it/s]


{'train_runtime': 15357.9704, 'train_samples_per_second': 61.698, 'train_steps_per_second': 3.856, 'train_loss': 2.796831964675562, 'epoch': 25.0}


### Evaluate

Validation set

In [13]:
val_results = trainer.evaluate()
print('Val results: ', val_results)

100%|██████████| 297/297 [01:37<00:00,  3.05it/s]

Val results:  {'eval_loss': 2.37141489982605, 'eval_rouge1': 12.4506, 'eval_rouge2': 3.6566, 'eval_rougeL': 10.2485, 'eval_rougeLsum': 11.3412, 'eval_gen_len': 18.7218, 'eval_runtime': 97.9626, 'eval_samples_per_second': 48.355, 'eval_steps_per_second': 3.032, 'epoch': 25.0}





Test set

In [14]:
test_results = trainer.predict(test_dataset=test_data)
print('Test results:', test_results.metrics)

100%|██████████| 305/305 [01:40<00:00,  3.03it/s]

Test results: {'test_loss': 2.4149272441864014, 'test_rouge1': 12.288, 'test_rouge2': 3.5705, 'test_rougeL': 10.0529, 'test_rougeLsum': 11.2572, 'test_gen_len': 18.6868, 'test_runtime': 100.9741, 'test_samples_per_second': 48.28, 'test_steps_per_second': 3.021}





In [15]:
input = "Uporabnik: Kdo je France Prešeren?"
input = tokenizer(input, return_tensors="pt").to("cuda")
outputs = trainer.model.generate(**input, max_length=128, no_repeat_ngram_size=2, num_beams=5, num_return_sequences=5)
tokenizer.decode(outputs[0], skip_special_tokens=True)
tokenizer.batch_decode(outputs, skip_special_tokens=True)

['France Prešeren je bil francoski pesnik, ki se je rodil leta 1922 v Parizu.',
 'France Prešeren je bil francoski pesnik, ki se je rodil leta 1922 v Parizu. Bil je francoski pisatelj in pisatelj, znan po svojem delu in delu v različnih delih sveta.',
 'France Prešeren je bil francoski pesnik, ki se je rodil leta 1922 v Parizu. Bil je francoski pisatelj in pisatelj, znan po svojem delu in delu na področju glasbe, glasbe in glasbe.',
 'France Prešeren je bil francoski pesnik, ki se je rodil leta 1922 v Parizu. Bil je francoski pisatelj in pisatelj, znan po svojem delu in delu na področju umetnosti in umetnosti.',
 'France Prešeren je bil francoski pesnik, ki se je rodil leta 1922 v Parizu. Bil je francoski pisatelj in pisatelj, znan po svojem delu in delu na področju glasbe, glasbe in glasbe. Znan je tudi po svojih delih, kot so pesmi, eseji, pesmi in pesmi.']