In [1]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
)
from peft import (
    LoraConfig,
    get_peft_model,
)
import torch
from trl import SFTTrainer, setup_chat_format

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
base_model = "OpenLLM-Ro/RoLlama3-8b-Instruct"
new_model = "llama-3-8b-chat-aromanian_v2"

In [3]:
torch_dtype = torch.bfloat16
attn_implementation = "eager"

In [4]:
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

Loading checkpoint shards: 100%|██████████| 4/4 [00:27<00:00,  6.90s/it]


In [5]:
tokenizer = AutoTokenizer.from_pretrained(base_model)
model, tokenizer = setup_chat_format(model, tokenizer)

In [6]:
peft_config = LoraConfig(
    r=32,
    lora_alpha=32,
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=['up_proj', 'down_proj', 'gate_proj', 'k_proj', 'q_proj', 'v_proj', 'o_proj']
)
model = get_peft_model(model, peft_config)

In [7]:
import pandas as pd
from datasets import Dataset
import unicodedata
df = pd.read_csv("../dataset/nllb_corpus_train.csv")
df_transformed = df.applymap(lambda x: ''.join([c for c in unicodedata.normalize('NFKD', x)  if unicodedata.category(c) != 'Mn']) if type(x) == str else x)
# There are some words that have a -mi at the end, we will eliminate them also
df_transformed.replace(r'\s*-\s*mi\b', '', regex=True, inplace=True)
df_transformed.replace(r'\(i\)', 'i', regex=True, inplace=True)
df_transformed.replace('γ', 'y', regex=True, inplace=True)
df_transformed.replace(r'’', '', regex=True, inplace=True)
df_transformed.replace(r'“', '', regex=True, inplace=True)
df_transformed.replace(r'„', '', regex=True, inplace=True)
df_transformed.columns = [str(q).strip() for q in df_transformed.columns]
# df_transformed.drop(columns=['ro', 'rup', 'translations'], inplace=True)
dataset = Dataset.from_pandas(df_transformed)

dataset = dataset.shuffle(seed=42) # Only use 1000 samples for quick demo


# I'm not sure if apply_chat_template works with llama3, a new jinja template should be created
# def format_chat_template(row, tokenizer=tokenizer):
#     row_json = [{"role": "user", "content": f"Traduce din aromana in romana: {row['rup']}"},
#                {"role": "assistant", "content": row["ro"]}]
#     row["text"] = tokenizer.apply_chat_template(row_json, tokenize=False, add_generation_prompt=True)
#     return row

# dataset = dataset.map(
#     format_chat_template,
#     num_proc=4,
# )


def generate_prompt(data_point):
    begin_of_text = "<|begin_of_text|>"
    end_of_text = "<|end_of_text|>"
    start_header_id = "<|start_header_id|>"
    end_header_id = "<|end_header_id|>"
    eot_id = "<|eot_id|>"
    return {"text":
    f"""{begin_of_text}
        {start_header_id}system{end_header_id} Tradu urmatorul text din aromana in romana:
        {start_header_id}user{end_header_id} {data_point["rup"]}{eot_id}
        {start_header_id} assistant{end_header_id} {data_point["ro"]}{eot_id}
        {end_of_text}"""}

dataset = dataset.map(generate_prompt)

dataset['text'][3]

  df_transformed = df.applymap(lambda x: ''.join([c for c in unicodedata.normalize('NFKD', x)  if unicodedata.category(c) != 'Mn']) if type(x) == str else x)
Map: 100%|██████████| 27033/27033 [00:00<00:00, 30431.02 examples/s]


'<|begin_of_text|>\n        <|start_header_id|>system<|end_header_id|> Tradu această propoziție din aromână în română.\n        <|start_header_id|>user<|end_header_id|> pazitu<|eot_id|>\n        <|start_header_id|> assistant<|end_header_id|> pazit<|eot_id|>\n        <|end_of_text|>'

In [8]:
dataset = dataset.train_test_split(test_size=0.01)
# model.gradient_checkpointing_enable()
# model.gradient_checkpointing_disable()


In [9]:
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    evaluation_strategy="steps",
    eval_steps=0.2,
    logging_steps=70,
    warmup_steps=15,
    logging_strategy="steps",
    learning_rate=7e-6,
    fp16=False,
    bf16=False,
    group_by_length=True,
    disable_tqdm=False,
    report_to="none",
)



In [10]:
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    peft_config=peft_config,
    max_seq_length=216,
    dataset_text_field="text",
    tokenizer=tokenizer,
    args=training_arguments,
    packing= False,
)


Deprecated positional argument(s) used in SFTTrainer, please use the SFTConfig to set these arguments instead.
Map: 100%|██████████| 26762/26762 [00:00<00:00, 37720.44 examples/s]
Map: 100%|██████████| 271/271 [00:00<00:00, 30104.52 examples/s]


In [11]:
trainer.train()

  1%|          | 70/13381 [00:54<2:42:30,  1.37it/s]

{'loss': 4.4699, 'grad_norm': 3.831062078475952, 'learning_rate': 6.971195570851414e-06, 'epoch': 0.01}


  1%|          | 140/13381 [01:45<2:41:01,  1.37it/s]

{'loss': 2.3875, 'grad_norm': 2.2177364826202393, 'learning_rate': 6.934535388298668e-06, 'epoch': 0.01}


  2%|▏         | 210/13381 [02:37<2:41:00,  1.36it/s]

{'loss': 2.1302, 'grad_norm': 3.847557783126831, 'learning_rate': 6.897875205745923e-06, 'epoch': 0.02}


  2%|▏         | 280/13381 [03:28<2:37:37,  1.39it/s]

{'loss': 1.9491, 'grad_norm': 4.579961776733398, 'learning_rate': 6.8612150231931765e-06, 'epoch': 0.02}


  3%|▎         | 350/13381 [04:18<2:36:29,  1.39it/s]

{'loss': 1.7907, 'grad_norm': 3.3574726581573486, 'learning_rate': 6.824554840640431e-06, 'epoch': 0.03}


  3%|▎         | 420/13381 [05:09<2:35:50,  1.39it/s]

{'loss': 1.8775, 'grad_norm': 3.4469032287597656, 'learning_rate': 6.787894658087685e-06, 'epoch': 0.03}


  4%|▎         | 490/13381 [06:00<2:35:40,  1.38it/s]

{'loss': 1.6482, 'grad_norm': 2.2302167415618896, 'learning_rate': 6.75123447553494e-06, 'epoch': 0.04}


  4%|▍         | 560/13381 [06:58<2:34:48,  1.38it/s] 

{'loss': 1.6603, 'grad_norm': 3.9294652938842773, 'learning_rate': 6.7145742929821935e-06, 'epoch': 0.04}


  5%|▍         | 630/13381 [07:49<2:33:34,  1.38it/s]

{'loss': 1.6454, 'grad_norm': 2.2656431198120117, 'learning_rate': 6.677914110429448e-06, 'epoch': 0.05}


  5%|▌         | 700/13381 [08:39<2:32:29,  1.39it/s]

{'loss': 1.5895, 'grad_norm': 4.86131477355957, 'learning_rate': 6.641253927876702e-06, 'epoch': 0.05}


  6%|▌         | 770/13381 [09:30<2:32:22,  1.38it/s]

{'loss': 1.6735, 'grad_norm': 2.0407073497772217, 'learning_rate': 6.604593745323956e-06, 'epoch': 0.06}


  6%|▋         | 840/13381 [10:21<2:30:58,  1.38it/s]

{'loss': 1.571, 'grad_norm': 2.1862173080444336, 'learning_rate': 6.567933562771211e-06, 'epoch': 0.06}


  7%|▋         | 910/13381 [11:12<2:30:54,  1.38it/s]

{'loss': 1.6918, 'grad_norm': 5.35400915145874, 'learning_rate': 6.531273380218465e-06, 'epoch': 0.07}


  7%|▋         | 980/13381 [12:03<2:32:38,  1.35it/s]

{'loss': 1.6108, 'grad_norm': 2.868340015411377, 'learning_rate': 6.494613197665719e-06, 'epoch': 0.07}


  8%|▊         | 1050/13381 [12:57<2:27:55,  1.39it/s]

{'loss': 1.5871, 'grad_norm': 2.2455058097839355, 'learning_rate': 6.457953015112973e-06, 'epoch': 0.08}


  8%|▊         | 1120/13381 [13:48<2:28:10,  1.38it/s]

{'loss': 1.6704, 'grad_norm': 3.054772138595581, 'learning_rate': 6.421292832560227e-06, 'epoch': 0.08}


  9%|▉         | 1190/13381 [14:39<2:27:00,  1.38it/s]

{'loss': 1.5898, 'grad_norm': 1.9471360445022583, 'learning_rate': 6.384632650007482e-06, 'epoch': 0.09}


  9%|▉         | 1260/13381 [15:30<2:26:51,  1.38it/s]

{'loss': 1.6136, 'grad_norm': 3.676391124725342, 'learning_rate': 6.347972467454736e-06, 'epoch': 0.09}


 10%|▉         | 1330/13381 [16:21<2:25:06,  1.38it/s]

{'loss': 1.5678, 'grad_norm': 2.673454761505127, 'learning_rate': 6.31131228490199e-06, 'epoch': 0.1}


 10%|█         | 1400/13381 [17:12<2:23:34,  1.39it/s]

{'loss': 1.5776, 'grad_norm': 1.999883770942688, 'learning_rate': 6.274652102349244e-06, 'epoch': 0.1}


 11%|█         | 1470/13381 [18:03<2:24:46,  1.37it/s]

{'loss': 1.6542, 'grad_norm': 2.5991885662078857, 'learning_rate': 6.237991919796498e-06, 'epoch': 0.11}


 12%|█▏        | 1540/13381 [18:57<2:22:18,  1.39it/s]

{'loss': 1.5655, 'grad_norm': 1.9180128574371338, 'learning_rate': 6.201331737243753e-06, 'epoch': 0.12}


 12%|█▏        | 1610/13381 [19:48<2:22:55,  1.37it/s]

{'loss': 1.5798, 'grad_norm': 3.230391025543213, 'learning_rate': 6.164671554691007e-06, 'epoch': 0.12}


 13%|█▎        | 1680/13381 [20:39<2:21:37,  1.38it/s]

{'loss': 1.5296, 'grad_norm': 2.3884551525115967, 'learning_rate': 6.128011372138261e-06, 'epoch': 0.13}


 13%|█▎        | 1750/13381 [21:30<2:22:52,  1.36it/s]

{'loss': 1.5541, 'grad_norm': 1.9845858812332153, 'learning_rate': 6.0913511895855154e-06, 'epoch': 0.13}


 14%|█▎        | 1820/13381 [22:21<2:19:13,  1.38it/s]

{'loss': 1.7074, 'grad_norm': 2.6682469844818115, 'learning_rate': 6.054691007032769e-06, 'epoch': 0.14}


 14%|█▍        | 1890/13381 [23:12<2:17:57,  1.39it/s]

{'loss': 1.5441, 'grad_norm': 2.391983985900879, 'learning_rate': 6.018030824480024e-06, 'epoch': 0.14}


 15%|█▍        | 1960/13381 [24:02<2:18:08,  1.38it/s]

{'loss': 1.5571, 'grad_norm': 3.4322457313537598, 'learning_rate': 5.981370641927278e-06, 'epoch': 0.15}


 15%|█▌        | 2030/13381 [24:57<2:16:34,  1.39it/s]

{'loss': 1.5567, 'grad_norm': 2.7812633514404297, 'learning_rate': 5.9447104593745325e-06, 'epoch': 0.15}


 16%|█▌        | 2100/13381 [25:51<3:19:54,  1.06s/it]

{'loss': 1.4998, 'grad_norm': 2.089855194091797, 'learning_rate': 5.908050276821786e-06, 'epoch': 0.16}


 16%|█▌        | 2170/13381 [27:05<3:15:33,  1.05s/it]

{'loss': 1.6605, 'grad_norm': 3.107151508331299, 'learning_rate': 5.871390094269041e-06, 'epoch': 0.16}


 17%|█▋        | 2240/13381 [28:19<3:13:36,  1.04s/it]

{'loss': 1.5408, 'grad_norm': 2.594787836074829, 'learning_rate': 5.834729911716295e-06, 'epoch': 0.17}


 17%|█▋        | 2310/13381 [29:33<3:21:50,  1.09s/it]

{'loss': 1.5917, 'grad_norm': 4.870392799377441, 'learning_rate': 5.7980697291635496e-06, 'epoch': 0.17}


 18%|█▊        | 2380/13381 [30:48<3:17:05,  1.07s/it]

{'loss': 1.5299, 'grad_norm': 2.5011847019195557, 'learning_rate': 5.761409546610803e-06, 'epoch': 0.18}


 18%|█▊        | 2450/13381 [32:04<3:15:48,  1.07s/it]

{'loss': 1.5111, 'grad_norm': 3.9632532596588135, 'learning_rate': 5.724749364058058e-06, 'epoch': 0.18}


 19%|█▉        | 2520/13381 [33:27<3:21:27,  1.11s/it]

{'loss': 1.5909, 'grad_norm': 2.3252291679382324, 'learning_rate': 5.688089181505311e-06, 'epoch': 0.19}


 19%|█▉        | 2590/13381 [34:44<3:19:09,  1.11s/it]

{'loss': 1.5217, 'grad_norm': 3.0725691318511963, 'learning_rate': 5.651428998952567e-06, 'epoch': 0.19}


 20%|█▉        | 2660/13381 [36:02<3:28:24,  1.17s/it]

{'loss': 1.5668, 'grad_norm': 4.003012180328369, 'learning_rate': 5.61476881639982e-06, 'epoch': 0.2}


                                                      
 20%|██        | 2677/13381 [37:18<3:15:09,  1.09s/it]

{'eval_loss': 1.5272678136825562, 'eval_runtime': 57.3232, 'eval_samples_per_second': 4.728, 'eval_steps_per_second': 4.728, 'epoch': 0.2}


 20%|██        | 2730/13381 [38:15<3:10:48,  1.07s/it] 

{'loss': 1.5313, 'grad_norm': 2.2613368034362793, 'learning_rate': 5.578108633847075e-06, 'epoch': 0.2}


 21%|██        | 2800/13381 [39:32<3:18:09,  1.12s/it]

{'loss': 1.4981, 'grad_norm': 1.5644930601119995, 'learning_rate': 5.541448451294328e-06, 'epoch': 0.21}


 21%|██▏       | 2870/13381 [40:50<3:49:37,  1.31s/it]

{'loss': 1.611, 'grad_norm': 3.115896224975586, 'learning_rate': 5.504788268741583e-06, 'epoch': 0.21}


 22%|██▏       | 2940/13381 [42:08<3:08:29,  1.08s/it]

{'loss': 1.4747, 'grad_norm': 2.6373565196990967, 'learning_rate': 5.468128086188837e-06, 'epoch': 0.22}


 22%|██▏       | 3010/13381 [43:30<3:14:58,  1.13s/it]

{'loss': 1.5884, 'grad_norm': 3.380664348602295, 'learning_rate': 5.431467903636092e-06, 'epoch': 0.22}


 23%|██▎       | 3080/13381 [44:47<3:10:59,  1.11s/it]

{'loss': 1.504, 'grad_norm': 5.2270941734313965, 'learning_rate': 5.3948077210833455e-06, 'epoch': 0.23}


 24%|██▎       | 3150/13381 [46:05<3:02:34,  1.07s/it]

{'loss': 1.4857, 'grad_norm': 1.9489716291427612, 'learning_rate': 5.3581475385306e-06, 'epoch': 0.24}


 24%|██▍       | 3220/13381 [47:21<3:05:27,  1.10s/it]

{'loss': 1.5864, 'grad_norm': 3.133960723876953, 'learning_rate': 5.3214873559778536e-06, 'epoch': 0.24}


 25%|██▍       | 3290/13381 [48:38<3:05:13,  1.10s/it]

{'loss': 1.4871, 'grad_norm': 3.3830742835998535, 'learning_rate': 5.284827173425109e-06, 'epoch': 0.25}


 25%|██▌       | 3360/13381 [49:40<2:01:24,  1.38it/s]

{'loss': 1.5637, 'grad_norm': 3.5503275394439697, 'learning_rate': 5.2481669908723625e-06, 'epoch': 0.25}


 26%|██▌       | 3430/13381 [50:31<2:00:07,  1.38it/s]

{'loss': 1.5376, 'grad_norm': 2.4744014739990234, 'learning_rate': 5.211506808319617e-06, 'epoch': 0.26}


 26%|██▌       | 3500/13381 [51:22<1:58:33,  1.39it/s]

{'loss': 1.5068, 'grad_norm': 1.9578207731246948, 'learning_rate': 5.174846625766871e-06, 'epoch': 0.26}


 27%|██▋       | 3570/13381 [52:16<1:58:11,  1.38it/s]

{'loss': 1.5847, 'grad_norm': 2.648228168487549, 'learning_rate': 5.138186443214125e-06, 'epoch': 0.27}


 27%|██▋       | 3640/13381 [53:07<1:56:38,  1.39it/s]

{'loss': 1.5261, 'grad_norm': 2.857973098754883, 'learning_rate': 5.10152626066138e-06, 'epoch': 0.27}


 28%|██▊       | 3710/13381 [53:58<1:58:00,  1.37it/s]

{'loss': 1.5812, 'grad_norm': 3.685159206390381, 'learning_rate': 5.064866078108634e-06, 'epoch': 0.28}


 28%|██▊       | 3780/13381 [54:49<1:55:27,  1.39it/s]

{'loss': 1.5245, 'grad_norm': 3.8962695598602295, 'learning_rate': 5.028205895555888e-06, 'epoch': 0.28}


 29%|██▉       | 3850/13381 [55:39<1:54:20,  1.39it/s]

{'loss': 1.4887, 'grad_norm': 3.8138046264648438, 'learning_rate': 4.991545713003142e-06, 'epoch': 0.29}


 29%|██▉       | 3920/13381 [56:31<1:53:35,  1.39it/s]

{'loss': 1.5814, 'grad_norm': 2.782113552093506, 'learning_rate': 4.954885530450396e-06, 'epoch': 0.29}


 30%|██▉       | 3990/13381 [57:21<1:53:00,  1.38it/s]

{'loss': 1.5077, 'grad_norm': 3.2795307636260986, 'learning_rate': 4.91822534789765e-06, 'epoch': 0.3}


 30%|███       | 4060/13381 [58:16<1:52:45,  1.38it/s]

{'loss': 1.563, 'grad_norm': 4.703070640563965, 'learning_rate': 4.881565165344906e-06, 'epoch': 0.3}


 31%|███       | 4130/13381 [59:07<1:51:12,  1.39it/s]

{'loss': 1.5346, 'grad_norm': 3.02072811126709, 'learning_rate': 4.844904982792159e-06, 'epoch': 0.31}


 31%|███▏      | 4200/13381 [59:58<1:50:56,  1.38it/s]

{'loss': 1.4996, 'grad_norm': 3.7623887062072754, 'learning_rate': 4.808244800239414e-06, 'epoch': 0.31}


 32%|███▏      | 4270/13381 [1:00:49<1:49:43,  1.38it/s]

{'loss': 1.6174, 'grad_norm': 3.8152248859405518, 'learning_rate': 4.771584617686667e-06, 'epoch': 0.32}


 32%|███▏      | 4340/13381 [1:01:40<1:48:53,  1.38it/s]

{'loss': 1.4873, 'grad_norm': 2.9360506534576416, 'learning_rate': 4.734924435133922e-06, 'epoch': 0.32}


 33%|███▎      | 4410/13381 [1:02:31<1:48:40,  1.38it/s]

{'loss': 1.5656, 'grad_norm': 4.764707088470459, 'learning_rate': 4.698264252581176e-06, 'epoch': 0.33}


 33%|███▎      | 4480/13381 [1:03:21<1:47:15,  1.38it/s]

{'loss': 1.5185, 'grad_norm': 4.571254253387451, 'learning_rate': 4.661604070028431e-06, 'epoch': 0.33}


 34%|███▍      | 4550/13381 [1:04:15<1:45:55,  1.39it/s]

{'loss': 1.4839, 'grad_norm': 4.118692874908447, 'learning_rate': 4.624943887475684e-06, 'epoch': 0.34}


 35%|███▍      | 4620/13381 [1:05:06<1:45:19,  1.39it/s]

{'loss': 1.5183, 'grad_norm': 3.9295566082000732, 'learning_rate': 4.588283704922939e-06, 'epoch': 0.35}


 35%|███▌      | 4690/13381 [1:05:57<1:44:54,  1.38it/s]

{'loss': 1.5254, 'grad_norm': 2.950314521789551, 'learning_rate': 4.5516235223701925e-06, 'epoch': 0.35}


 36%|███▌      | 4760/13381 [1:06:48<1:45:12,  1.37it/s]

{'loss': 1.596, 'grad_norm': 5.272146701812744, 'learning_rate': 4.514963339817448e-06, 'epoch': 0.36}


 36%|███▌      | 4830/13381 [1:07:39<1:43:27,  1.38it/s]

{'loss': 1.511, 'grad_norm': 3.3266663551330566, 'learning_rate': 4.4783031572647015e-06, 'epoch': 0.36}


 37%|███▋      | 4900/13381 [1:08:30<1:41:54,  1.39it/s]

{'loss': 1.4701, 'grad_norm': 2.48479962348938, 'learning_rate': 4.441642974711956e-06, 'epoch': 0.37}


 37%|███▋      | 4970/13381 [1:09:21<1:40:57,  1.39it/s]

{'loss': 1.5726, 'grad_norm': 3.2525229454040527, 'learning_rate': 4.40498279215921e-06, 'epoch': 0.37}


 38%|███▊      | 5040/13381 [1:10:16<1:40:21,  1.39it/s]

{'loss': 1.5229, 'grad_norm': 3.211667060852051, 'learning_rate': 4.368322609606464e-06, 'epoch': 0.38}


 38%|███▊      | 5110/13381 [1:11:07<1:40:22,  1.37it/s]

{'loss': 1.586, 'grad_norm': 3.366373300552368, 'learning_rate': 4.3316624270537186e-06, 'epoch': 0.38}


 39%|███▊      | 5180/13381 [1:11:58<1:39:03,  1.38it/s]

{'loss': 1.557, 'grad_norm': 3.2799644470214844, 'learning_rate': 4.295002244500973e-06, 'epoch': 0.39}


 39%|███▉      | 5250/13381 [1:12:49<1:37:37,  1.39it/s]

{'loss': 1.4811, 'grad_norm': 3.293226718902588, 'learning_rate': 4.258342061948227e-06, 'epoch': 0.39}


 40%|███▉      | 5320/13381 [1:13:40<1:37:14,  1.38it/s]

{'loss': 1.5843, 'grad_norm': 3.3000261783599854, 'learning_rate': 4.221681879395481e-06, 'epoch': 0.4}


                                                        
 40%|████      | 5354/13381 [1:14:50<1:46:03,  1.26it/s]

{'eval_loss': 1.502929925918579, 'eval_runtime': 45.0819, 'eval_samples_per_second': 6.011, 'eval_steps_per_second': 6.011, 'epoch': 0.4}


 40%|████      | 5390/13381 [1:15:16<1:35:50,  1.39it/s] 

{'loss': 1.5501, 'grad_norm': 2.1481857299804688, 'learning_rate': 4.185021696842735e-06, 'epoch': 0.4}


 41%|████      | 5460/13381 [1:16:07<1:35:23,  1.38it/s]

{'loss': 1.4882, 'grad_norm': 3.831087589263916, 'learning_rate': 4.14836151428999e-06, 'epoch': 0.41}


 41%|████▏     | 5530/13381 [1:17:01<1:34:18,  1.39it/s]

{'loss': 1.5143, 'grad_norm': 2.627490520477295, 'learning_rate': 4.111701331737244e-06, 'epoch': 0.41}


 42%|████▏     | 5600/13381 [1:17:52<1:33:42,  1.38it/s]

{'loss': 1.4645, 'grad_norm': 2.2558135986328125, 'learning_rate': 4.075041149184498e-06, 'epoch': 0.42}


 42%|████▏     | 5670/13381 [1:18:43<1:33:00,  1.38it/s]

{'loss': 1.6055, 'grad_norm': 2.766329765319824, 'learning_rate': 4.038380966631752e-06, 'epoch': 0.42}


 43%|████▎     | 5740/13381 [1:19:34<1:31:55,  1.39it/s]

{'loss': 1.4612, 'grad_norm': 2.302356243133545, 'learning_rate': 4.001720784079006e-06, 'epoch': 0.43}


 43%|████▎     | 5810/13381 [1:20:24<1:31:30,  1.38it/s]

{'loss': 1.5157, 'grad_norm': 3.779472589492798, 'learning_rate': 3.965060601526261e-06, 'epoch': 0.43}


 44%|████▍     | 5880/13381 [1:21:15<1:30:20,  1.38it/s]

{'loss': 1.5083, 'grad_norm': 3.826599359512329, 'learning_rate': 3.928400418973515e-06, 'epoch': 0.44}


 44%|████▍     | 5950/13381 [1:22:06<1:29:19,  1.39it/s]

{'loss': 1.4574, 'grad_norm': 2.6590099334716797, 'learning_rate': 3.891740236420769e-06, 'epoch': 0.44}


 45%|████▍     | 6020/13381 [1:23:00<1:29:02,  1.38it/s]

{'loss': 1.5707, 'grad_norm': 3.8977246284484863, 'learning_rate': 3.855080053868023e-06, 'epoch': 0.45}


 46%|████▌     | 6090/13381 [1:23:51<1:28:11,  1.38it/s]

{'loss': 1.5216, 'grad_norm': 3.4568490982055664, 'learning_rate': 3.818419871315277e-06, 'epoch': 0.46}


 46%|████▌     | 6160/13381 [1:24:42<1:27:45,  1.37it/s]

{'loss': 1.5098, 'grad_norm': 3.7829809188842773, 'learning_rate': 3.781759688762532e-06, 'epoch': 0.46}


 47%|████▋     | 6230/13381 [1:25:33<1:26:09,  1.38it/s]

{'loss': 1.4659, 'grad_norm': 3.674182891845703, 'learning_rate': 3.745099506209786e-06, 'epoch': 0.47}


 47%|████▋     | 6300/13381 [1:26:24<1:25:16,  1.38it/s]

{'loss': 1.4913, 'grad_norm': 2.1266837120056152, 'learning_rate': 3.7084393236570405e-06, 'epoch': 0.47}


 48%|████▊     | 6370/13381 [1:27:15<1:24:17,  1.39it/s]

{'loss': 1.5389, 'grad_norm': 3.5406129360198975, 'learning_rate': 3.671779141104294e-06, 'epoch': 0.48}


 48%|████▊     | 6440/13381 [1:28:06<1:23:24,  1.39it/s]

{'loss': 1.5184, 'grad_norm': 2.780592441558838, 'learning_rate': 3.635118958551549e-06, 'epoch': 0.48}


 49%|████▊     | 6510/13381 [1:29:00<1:27:46,  1.30it/s]

{'loss': 1.5258, 'grad_norm': 3.473909378051758, 'learning_rate': 3.5984587759988026e-06, 'epoch': 0.49}


 49%|████▉     | 6580/13381 [1:29:51<1:22:10,  1.38it/s]

{'loss': 1.5148, 'grad_norm': 4.514517784118652, 'learning_rate': 3.561798593446057e-06, 'epoch': 0.49}


 50%|████▉     | 6650/13381 [1:30:42<1:21:03,  1.38it/s]

{'loss': 1.496, 'grad_norm': 2.6352527141571045, 'learning_rate': 3.525138410893311e-06, 'epoch': 0.5}


 50%|█████     | 6720/13381 [1:31:33<1:20:05,  1.39it/s]

{'loss': 1.5701, 'grad_norm': 3.8466854095458984, 'learning_rate': 3.4884782283405656e-06, 'epoch': 0.5}


 51%|█████     | 6790/13381 [1:32:24<1:19:38,  1.38it/s]

{'loss': 1.4856, 'grad_norm': 2.371412754058838, 'learning_rate': 3.45181804578782e-06, 'epoch': 0.51}


 51%|█████▏    | 6860/13381 [1:33:15<1:19:00,  1.38it/s]

{'loss': 1.5351, 'grad_norm': 4.164725303649902, 'learning_rate': 3.415157863235074e-06, 'epoch': 0.51}


 52%|█████▏    | 6930/13381 [1:34:05<1:17:36,  1.39it/s]

{'loss': 1.4773, 'grad_norm': 3.146613121032715, 'learning_rate': 3.3784976806823282e-06, 'epoch': 0.52}


 52%|█████▏    | 7000/13381 [1:34:56<1:16:41,  1.39it/s]

{'loss': 1.482, 'grad_norm': 2.6565420627593994, 'learning_rate': 3.3418374981295827e-06, 'epoch': 0.52}


 53%|█████▎    | 7070/13381 [1:35:51<1:16:45,  1.37it/s]

{'loss': 1.5255, 'grad_norm': 3.897552251815796, 'learning_rate': 3.3051773155768368e-06, 'epoch': 0.53}


 53%|█████▎    | 7140/13381 [1:36:42<1:15:25,  1.38it/s]

{'loss': 1.469, 'grad_norm': 2.7544968128204346, 'learning_rate': 3.2685171330240912e-06, 'epoch': 0.53}


 54%|█████▍    | 7210/13381 [1:37:33<1:14:53,  1.37it/s]

{'loss': 1.5408, 'grad_norm': 6.686241626739502, 'learning_rate': 3.2318569504713453e-06, 'epoch': 0.54}


 54%|█████▍    | 7280/13381 [1:38:23<1:13:43,  1.38it/s]

{'loss': 1.474, 'grad_norm': 3.5797760486602783, 'learning_rate': 3.1951967679185993e-06, 'epoch': 0.54}


 55%|█████▍    | 7350/13381 [1:39:14<1:12:30,  1.39it/s]

{'loss': 1.4694, 'grad_norm': 2.861029863357544, 'learning_rate': 3.158536585365854e-06, 'epoch': 0.55}


 55%|█████▌    | 7420/13381 [1:40:05<1:11:57,  1.38it/s]

{'loss': 1.5332, 'grad_norm': 3.635829448699951, 'learning_rate': 3.121876402813108e-06, 'epoch': 0.55}


 56%|█████▌    | 7490/13381 [1:40:56<1:10:42,  1.39it/s]

{'loss': 1.479, 'grad_norm': 2.9299302101135254, 'learning_rate': 3.085216220260362e-06, 'epoch': 0.56}


 56%|█████▋    | 7560/13381 [1:41:51<1:10:51,  1.37it/s]

{'loss': 1.5527, 'grad_norm': 5.61636209487915, 'learning_rate': 3.0485560377076164e-06, 'epoch': 0.56}


 57%|█████▋    | 7630/13381 [1:42:42<1:09:14,  1.38it/s]

{'loss': 1.476, 'grad_norm': 3.368586301803589, 'learning_rate': 3.0118958551548705e-06, 'epoch': 0.57}


 58%|█████▊    | 7700/13381 [1:43:33<1:08:19,  1.39it/s]

{'loss': 1.4979, 'grad_norm': 2.552743673324585, 'learning_rate': 2.975235672602125e-06, 'epoch': 0.58}


 58%|█████▊    | 7770/13381 [1:44:24<1:07:44,  1.38it/s]

{'loss': 1.6001, 'grad_norm': 3.9531490802764893, 'learning_rate': 2.938575490049379e-06, 'epoch': 0.58}


 59%|█████▊    | 7840/13381 [1:45:15<1:06:34,  1.39it/s]

{'loss': 1.4683, 'grad_norm': 3.0169870853424072, 'learning_rate': 2.901915307496633e-06, 'epoch': 0.59}


 59%|█████▉    | 7910/13381 [1:46:06<1:06:14,  1.38it/s]

{'loss': 1.5471, 'grad_norm': 4.209919452667236, 'learning_rate': 2.8652551249438875e-06, 'epoch': 0.59}


 60%|█████▉    | 7980/13381 [1:46:56<1:04:45,  1.39it/s]

{'loss': 1.4619, 'grad_norm': 4.054385662078857, 'learning_rate': 2.8285949423911416e-06, 'epoch': 0.6}


                                                        
 60%|██████    | 8031/13381 [1:48:22<1:04:08,  1.39it/s]

{'eval_loss': 1.4825297594070435, 'eval_runtime': 45.0707, 'eval_samples_per_second': 6.013, 'eval_steps_per_second': 6.013, 'epoch': 0.6}


 60%|██████    | 8050/13381 [1:48:36<1:06:13,  1.34it/s] 

{'loss': 1.4548, 'grad_norm': 3.072955846786499, 'learning_rate': 2.791934759838396e-06, 'epoch': 0.6}


 61%|██████    | 8120/13381 [1:49:27<1:03:39,  1.38it/s]

{'loss': 1.5512, 'grad_norm': 3.3526339530944824, 'learning_rate': 2.75527457728565e-06, 'epoch': 0.61}


 61%|██████    | 8190/13381 [1:50:18<1:02:31,  1.38it/s]

{'loss': 1.4829, 'grad_norm': 4.1179962158203125, 'learning_rate': 2.718614394732904e-06, 'epoch': 0.61}


 62%|██████▏   | 8260/13381 [1:51:09<1:01:46,  1.38it/s]

{'loss': 1.4935, 'grad_norm': 5.381463527679443, 'learning_rate': 2.6819542121801587e-06, 'epoch': 0.62}


 62%|██████▏   | 8330/13381 [1:52:00<1:00:38,  1.39it/s]

{'loss': 1.4475, 'grad_norm': 4.071767807006836, 'learning_rate': 2.6452940296274127e-06, 'epoch': 0.62}


 63%|██████▎   | 8400/13381 [1:52:50<59:40,  1.39it/s]  

{'loss': 1.4454, 'grad_norm': 3.676591396331787, 'learning_rate': 2.608633847074667e-06, 'epoch': 0.63}


 63%|██████▎   | 8470/13381 [1:53:41<59:08,  1.38it/s]  

{'loss': 1.5511, 'grad_norm': 3.3915627002716064, 'learning_rate': 2.5719736645219212e-06, 'epoch': 0.63}


 64%|██████▍   | 8540/13381 [1:54:36<58:13,  1.39it/s]  

{'loss': 1.4548, 'grad_norm': 5.008771896362305, 'learning_rate': 2.5353134819691753e-06, 'epoch': 0.64}


 64%|██████▍   | 8610/13381 [1:55:26<57:38,  1.38it/s]  

{'loss': 1.4798, 'grad_norm': 3.536423683166504, 'learning_rate': 2.4986532994164298e-06, 'epoch': 0.64}


 65%|██████▍   | 8680/13381 [1:56:17<56:39,  1.38it/s]  

{'loss': 1.4774, 'grad_norm': 4.033646106719971, 'learning_rate': 2.461993116863684e-06, 'epoch': 0.65}


 65%|██████▌   | 8750/13381 [1:57:08<55:45,  1.38it/s]  

{'loss': 1.4943, 'grad_norm': 2.588831663131714, 'learning_rate': 2.4253329343109383e-06, 'epoch': 0.65}


 66%|██████▌   | 8820/13381 [1:57:59<54:55,  1.38it/s]  

{'loss': 1.5519, 'grad_norm': 3.635460138320923, 'learning_rate': 2.3886727517581924e-06, 'epoch': 0.66}


 66%|██████▋   | 8890/13381 [1:58:50<53:57,  1.39it/s]

{'loss': 1.4611, 'grad_norm': 3.540724277496338, 'learning_rate': 2.3520125692054464e-06, 'epoch': 0.66}


 67%|██████▋   | 8960/13381 [1:59:41<53:36,  1.37it/s]

{'loss': 1.545, 'grad_norm': 5.781626224517822, 'learning_rate': 2.315352386652701e-06, 'epoch': 0.67}


 67%|██████▋   | 9030/13381 [2:00:36<52:17,  1.39it/s]  

{'loss': 1.4869, 'grad_norm': 4.693935394287109, 'learning_rate': 2.278692204099955e-06, 'epoch': 0.67}


 68%|██████▊   | 9100/13381 [2:01:26<51:28,  1.39it/s]

{'loss': 1.4622, 'grad_norm': 3.2687556743621826, 'learning_rate': 2.2420320215472094e-06, 'epoch': 0.68}


 69%|██████▊   | 9170/13381 [2:02:18<50:59,  1.38it/s]

{'loss': 1.5624, 'grad_norm': 4.64288330078125, 'learning_rate': 2.2053718389944635e-06, 'epoch': 0.69}


 69%|██████▉   | 9240/13381 [2:03:09<49:51,  1.38it/s]

{'loss': 1.501, 'grad_norm': 3.7023046016693115, 'learning_rate': 2.1687116564417175e-06, 'epoch': 0.69}


 70%|██████▉   | 9310/13381 [2:03:59<49:22,  1.37it/s]

{'loss': 1.5187, 'grad_norm': 6.196258068084717, 'learning_rate': 2.132051473888972e-06, 'epoch': 0.7}


 70%|███████   | 9380/13381 [2:04:50<48:01,  1.39it/s]

{'loss': 1.4949, 'grad_norm': 3.6101911067962646, 'learning_rate': 2.095391291336226e-06, 'epoch': 0.7}


 71%|███████   | 9450/13381 [2:05:41<47:53,  1.37it/s]

{'loss': 1.4977, 'grad_norm': 2.7750089168548584, 'learning_rate': 2.05873110878348e-06, 'epoch': 0.71}


 71%|███████   | 9520/13381 [2:06:36<46:40,  1.38it/s]  

{'loss': 1.5546, 'grad_norm': 4.290271282196045, 'learning_rate': 2.0220709262307346e-06, 'epoch': 0.71}


 72%|███████▏  | 9590/13381 [2:07:26<45:41,  1.38it/s]

{'loss': 1.4867, 'grad_norm': 3.6754684448242188, 'learning_rate': 1.9854107436779887e-06, 'epoch': 0.72}


 72%|███████▏  | 9660/13381 [2:08:18<45:17,  1.37it/s]

{'loss': 1.5324, 'grad_norm': 5.059299945831299, 'learning_rate': 1.948750561125243e-06, 'epoch': 0.72}


 73%|███████▎  | 9730/13381 [2:09:08<43:51,  1.39it/s]

{'loss': 1.4509, 'grad_norm': 3.856311559677124, 'learning_rate': 1.912090378572497e-06, 'epoch': 0.73}


 73%|███████▎  | 9800/13381 [2:09:59<43:00,  1.39it/s]

{'loss': 1.5083, 'grad_norm': 3.633253812789917, 'learning_rate': 1.8754301960197515e-06, 'epoch': 0.73}


 74%|███████▍  | 9870/13381 [2:10:50<42:10,  1.39it/s]

{'loss': 1.5193, 'grad_norm': 4.8516645431518555, 'learning_rate': 1.8387700134670057e-06, 'epoch': 0.74}


 74%|███████▍  | 9940/13381 [2:11:41<41:29,  1.38it/s]

{'loss': 1.4579, 'grad_norm': 5.015903949737549, 'learning_rate': 1.8021098309142598e-06, 'epoch': 0.74}


 75%|███████▍  | 10010/13381 [2:12:35<43:15,  1.30it/s]  

{'loss': 1.4922, 'grad_norm': 5.800303936004639, 'learning_rate': 1.765449648361514e-06, 'epoch': 0.75}


 75%|███████▌  | 10080/13381 [2:13:26<39:43,  1.38it/s]

{'loss': 1.4667, 'grad_norm': 3.4278616905212402, 'learning_rate': 1.7287894658087685e-06, 'epoch': 0.75}


 76%|███████▌  | 10150/13381 [2:14:17<39:00,  1.38it/s]

{'loss': 1.43, 'grad_norm': 3.5497477054595947, 'learning_rate': 1.6921292832560228e-06, 'epoch': 0.76}


 76%|███████▋  | 10220/13381 [2:15:08<38:11,  1.38it/s]

{'loss': 1.5454, 'grad_norm': 3.6079094409942627, 'learning_rate': 1.655469100703277e-06, 'epoch': 0.76}


 77%|███████▋  | 10290/13381 [2:15:59<37:12,  1.38it/s]

{'loss': 1.4302, 'grad_norm': 3.4630298614501953, 'learning_rate': 1.6188089181505311e-06, 'epoch': 0.77}


 77%|███████▋  | 10360/13381 [2:16:50<36:40,  1.37it/s]

{'loss': 1.5067, 'grad_norm': 4.574410915374756, 'learning_rate': 1.5821487355977854e-06, 'epoch': 0.77}


 78%|███████▊  | 10430/13381 [2:17:41<35:38,  1.38it/s]

{'loss': 1.5017, 'grad_norm': 3.1026723384857178, 'learning_rate': 1.5454885530450396e-06, 'epoch': 0.78}


 78%|███████▊  | 10500/13381 [2:18:32<34:42,  1.38it/s]

{'loss': 1.4191, 'grad_norm': 1.973806619644165, 'learning_rate': 1.508828370492294e-06, 'epoch': 0.78}


 79%|███████▉  | 10570/13381 [2:19:26<33:52,  1.38it/s]  

{'loss': 1.5201, 'grad_norm': 4.310762405395508, 'learning_rate': 1.472168187939548e-06, 'epoch': 0.79}


 80%|███████▉  | 10640/13381 [2:20:17<32:58,  1.39it/s]

{'loss': 1.4471, 'grad_norm': 3.6396679878234863, 'learning_rate': 1.4355080053868022e-06, 'epoch': 0.8}


                                                       
 80%|████████  | 10708/13381 [2:21:51<32:41,  1.36it/s]

{'eval_loss': 1.471451759338379, 'eval_runtime': 44.9394, 'eval_samples_per_second': 6.03, 'eval_steps_per_second': 6.03, 'epoch': 0.8}


 80%|████████  | 10710/13381 [2:21:53<7:32:42, 10.17s/it] 

{'loss': 1.5211, 'grad_norm': 4.8758039474487305, 'learning_rate': 1.3988478228340565e-06, 'epoch': 0.8}


 81%|████████  | 10780/13381 [2:22:44<31:19,  1.38it/s]  

{'loss': 1.4668, 'grad_norm': 3.7700326442718506, 'learning_rate': 1.3621876402813108e-06, 'epoch': 0.81}


 81%|████████  | 10850/13381 [2:23:34<30:19,  1.39it/s]

{'loss': 1.4645, 'grad_norm': 2.8656556606292725, 'learning_rate': 1.325527457728565e-06, 'epoch': 0.81}


 82%|████████▏ | 10920/13381 [2:24:25<29:46,  1.38it/s]

{'loss': 1.4941, 'grad_norm': 4.341451168060303, 'learning_rate': 1.288867275175819e-06, 'epoch': 0.82}


 82%|████████▏ | 10990/13381 [2:25:16<28:45,  1.39it/s]

{'loss': 1.4146, 'grad_norm': 3.0740396976470947, 'learning_rate': 1.2522070926230734e-06, 'epoch': 0.82}


 83%|████████▎ | 11060/13381 [2:26:10<28:09,  1.37it/s]  

{'loss': 1.5028, 'grad_norm': 7.543613433837891, 'learning_rate': 1.2155469100703276e-06, 'epoch': 0.83}


 83%|████████▎ | 11130/13381 [2:27:01<27:08,  1.38it/s]

{'loss': 1.4457, 'grad_norm': 3.3555715084075928, 'learning_rate': 1.1788867275175819e-06, 'epoch': 0.83}


 84%|████████▎ | 11200/13381 [2:27:52<26:09,  1.39it/s]

{'loss': 1.4801, 'grad_norm': 5.293652057647705, 'learning_rate': 1.1422265449648362e-06, 'epoch': 0.84}


 84%|████████▍ | 11270/13381 [2:28:44<25:29,  1.38it/s]

{'loss': 1.6061, 'grad_norm': 5.346441268920898, 'learning_rate': 1.1055663624120902e-06, 'epoch': 0.84}


 85%|████████▍ | 11340/13381 [2:29:34<24:30,  1.39it/s]

{'loss': 1.476, 'grad_norm': 3.5781028270721436, 'learning_rate': 1.0689061798593445e-06, 'epoch': 0.85}


 85%|████████▌ | 11410/13381 [2:30:26<24:09,  1.36it/s]

{'loss': 1.515, 'grad_norm': 4.436609745025635, 'learning_rate': 1.0322459973065987e-06, 'epoch': 0.85}


 86%|████████▌ | 11480/13381 [2:31:17<23:01,  1.38it/s]

{'loss': 1.4593, 'grad_norm': 3.861689805984497, 'learning_rate': 9.95585814753853e-07, 'epoch': 0.86}


 86%|████████▋ | 11550/13381 [2:32:11<22:01,  1.39it/s]

{'loss': 1.447, 'grad_norm': 3.764894723892212, 'learning_rate': 9.589256322011073e-07, 'epoch': 0.86}


 87%|████████▋ | 11620/13381 [2:33:02<21:08,  1.39it/s]

{'loss': 1.5711, 'grad_norm': 4.282415866851807, 'learning_rate': 9.222654496483615e-07, 'epoch': 0.87}


 87%|████████▋ | 11690/13381 [2:33:53<20:14,  1.39it/s]

{'loss': 1.4474, 'grad_norm': 2.927554130554199, 'learning_rate': 8.856052670956158e-07, 'epoch': 0.87}


 88%|████████▊ | 11760/13381 [2:34:43<19:29,  1.39it/s]

{'loss': 1.4965, 'grad_norm': 5.826684951782227, 'learning_rate': 8.4894508454287e-07, 'epoch': 0.88}


 88%|████████▊ | 11830/13381 [2:35:34<18:37,  1.39it/s]

{'loss': 1.4539, 'grad_norm': 4.363199234008789, 'learning_rate': 8.122849019901241e-07, 'epoch': 0.88}


 89%|████████▉ | 11900/13381 [2:36:25<17:48,  1.39it/s]

{'loss': 1.4524, 'grad_norm': 4.344451904296875, 'learning_rate': 7.756247194373784e-07, 'epoch': 0.89}


 89%|████████▉ | 11970/13381 [2:37:16<16:58,  1.39it/s]

{'loss': 1.5199, 'grad_norm': 4.459169387817383, 'learning_rate': 7.389645368846326e-07, 'epoch': 0.89}


 90%|████████▉ | 12040/13381 [2:38:10<16:11,  1.38it/s]

{'loss': 1.4208, 'grad_norm': 4.1606926918029785, 'learning_rate': 7.023043543318868e-07, 'epoch': 0.9}


 91%|█████████ | 12110/13381 [2:39:01<15:27,  1.37it/s]

{'loss': 1.5591, 'grad_norm': 4.025750637054443, 'learning_rate': 6.656441717791411e-07, 'epoch': 0.91}


 91%|█████████ | 12180/13381 [2:39:52<14:26,  1.39it/s]

{'loss': 1.4877, 'grad_norm': 4.674910068511963, 'learning_rate': 6.289839892263954e-07, 'epoch': 0.91}


 92%|█████████▏| 12250/13381 [2:40:43<13:37,  1.38it/s]

{'loss': 1.4355, 'grad_norm': 3.145505428314209, 'learning_rate': 5.923238066736496e-07, 'epoch': 0.92}


 92%|█████████▏| 12320/13381 [2:41:34<12:47,  1.38it/s]

{'loss': 1.5478, 'grad_norm': 3.945233106613159, 'learning_rate': 5.556636241209038e-07, 'epoch': 0.92}


 93%|█████████▎| 12390/13381 [2:42:25<11:56,  1.38it/s]

{'loss': 1.4554, 'grad_norm': 4.094316005706787, 'learning_rate': 5.190034415681581e-07, 'epoch': 0.93}


 93%|█████████▎| 12460/13381 [2:43:16<11:09,  1.37it/s]

{'loss': 1.4849, 'grad_norm': 3.9738705158233643, 'learning_rate': 4.823432590154122e-07, 'epoch': 0.93}


 94%|█████████▎| 12530/13381 [2:44:11<10:16,  1.38it/s]

{'loss': 1.4485, 'grad_norm': 3.7084672451019287, 'learning_rate': 4.456830764626665e-07, 'epoch': 0.94}


 94%|█████████▍| 12600/13381 [2:45:02<09:22,  1.39it/s]

{'loss': 1.4864, 'grad_norm': 2.707033634185791, 'learning_rate': 4.090228939099207e-07, 'epoch': 0.94}


 95%|█████████▍| 12670/13381 [2:45:53<08:34,  1.38it/s]

{'loss': 1.5249, 'grad_norm': 3.8965983390808105, 'learning_rate': 3.723627113571749e-07, 'epoch': 0.95}


 95%|█████████▌| 12740/13381 [2:46:44<07:42,  1.39it/s]

{'loss': 1.4595, 'grad_norm': 4.002164840698242, 'learning_rate': 3.3570252880442913e-07, 'epoch': 0.95}


 96%|█████████▌| 12810/13381 [2:47:35<06:53,  1.38it/s]

{'loss': 1.5208, 'grad_norm': 4.374786376953125, 'learning_rate': 2.9904234625168334e-07, 'epoch': 0.96}


 96%|█████████▋| 12880/13381 [2:48:26<06:02,  1.38it/s]

{'loss': 1.4722, 'grad_norm': 4.900125980377197, 'learning_rate': 2.623821636989376e-07, 'epoch': 0.96}


 97%|█████████▋| 12950/13381 [2:49:16<05:11,  1.38it/s]

{'loss': 1.4511, 'grad_norm': 2.8963398933410645, 'learning_rate': 2.2572198114619185e-07, 'epoch': 0.97}


 97%|█████████▋| 13020/13381 [2:50:11<04:22,  1.38it/s]

{'loss': 1.5813, 'grad_norm': 4.974064350128174, 'learning_rate': 1.8906179859344606e-07, 'epoch': 0.97}


 98%|█████████▊| 13090/13381 [2:51:02<03:30,  1.38it/s]

{'loss': 1.5074, 'grad_norm': 4.188724994659424, 'learning_rate': 1.5240161604070028e-07, 'epoch': 0.98}


 98%|█████████▊| 13160/13381 [2:51:53<02:40,  1.38it/s]

{'loss': 1.4943, 'grad_norm': 4.55222225189209, 'learning_rate': 1.157414334879545e-07, 'epoch': 0.98}


 99%|█████████▉| 13230/13381 [2:52:44<01:48,  1.39it/s]

{'loss': 1.4416, 'grad_norm': 3.0845043659210205, 'learning_rate': 7.908125093520875e-08, 'epoch': 0.99}


 99%|█████████▉| 13300/13381 [2:53:35<00:58,  1.39it/s]

{'loss': 1.4335, 'grad_norm': 4.494693279266357, 'learning_rate': 4.2421068382462965e-08, 'epoch': 0.99}


100%|█████████▉| 13370/13381 [2:54:26<00:07,  1.38it/s]

{'loss': 1.5707, 'grad_norm': 6.909970760345459, 'learning_rate': 5.760885829717193e-09, 'epoch': 1.0}


100%|██████████| 13381/13381 [2:54:37<00:00,  1.28it/s]

{'train_runtime': 10477.8373, 'train_samples_per_second': 2.554, 'train_steps_per_second': 1.277, 'train_loss': 1.549549167497072, 'epoch': 1.0}





TrainOutput(global_step=13381, training_loss=1.549549167497072, metrics={'train_runtime': 10477.8373, 'train_samples_per_second': 2.554, 'train_steps_per_second': 1.277, 'total_flos': 6.008832910914355e+16, 'train_loss': 1.549549167497072, 'epoch': 1.0})

In [12]:
begin_of_text = "<|begin_of_text|>"
end_of_text = "<|end_of_text|>"
start_header_id = "<|start_header_id|>"
end_header_id = "<|end_header_id|>"
eot_id = "<|eot_id|>"
prompt =f"""{begin_of_text}
{start_header_id}system{end_header_id} Tradu această propoziție din aromână în română.
{start_header_id}user{end_header_id} Te s-hiba, greaste tata-su al Teatire, - ficiorlu-a meu easte!{eot_id}
{start_header_id}assistant{end_header_id}"""

inputs = tokenizer(prompt, return_tensors='pt', padding=True, 
                   truncation=True).to("cuda")

outputs = model.generate(**inputs, max_length=150, 
                         num_return_sequences=1)

text = tokenizer.decode(outputs[0], skip_special_tokens=True)

print(text.split("assistant")[1])

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


 Se-ncepea, si-l numara tatul lui Teatire - cei 8 copii ai lui!
://
://































In [25]:
def translate(text, src_lang='rup', tgt_lang='ro', max_length=256):
    """Translate a text or list of texts"""
    if isinstance(text, str):
        text = [text]
    
    results = []
    for sentence in text:
        prompt = f"""{begin_of_text}
        {start_header_id}system{end_header_id} Tradu această propoziție din aromână în română.
        {start_header_id}user{end_header_id} {sentence}{eot_id}
        {start_header_id}assistant{end_header_id}"""
        inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True)
        
        with torch.no_grad():
            outputs = model.generate(**inputs, max_length=max_length, num_return_sequences=1, temperature=0.4, top_k=50, top_p=0.95)
        
        result = tokenizer.decode(outputs[0], skip_special_tokens=True)
        translation = result.split("assistant")[-1].strip()
        results.append(translation)
    
    return results

# Example usage
t = 'Ma, a lor la si paru c-amintara na vasilie-ntreaga'
print(translate(t, 'aromanian', 'romanian'))

['Dar, ei, se pare ca se mai amintesc de o vasilie - de totul!\n        ://']


In [None]:
from tqdm import tqdm
import sacrebleu
df_ro_rup_test = pd.read_csv("../dataset/nllb_corpus_test.csv")
df_ro_rup_test['ro_pred'] = ''
df_ro_rup_test['rup_pred'] = ''

test_len = len(df_ro_rup_test)
test_len = 200
for i in tqdm(range(0, test_len)):
    rup_texts = df_ro_rup_test.loc[i, 'rup']
    
    if rup_texts:
        df_ro_rup_test.loc[i, 'ro_pred'] = translate(rup_texts, 'romanian', 'aromanian')



In [13]:
# Save the trained model
trainer.model.save_pretrained(new_model)



In [37]:
import re
bleu_calc = sacrebleu.BLEU()
chrf_calc = sacrebleu.CHRF()
df_ro_rup_test_t = [re.sub(r'[\n:/]', '', el[0]) for el in df_ro_rup_test['ro_pred'].tolist() if el]

# df_ro_rup_test_t = [el[0] for el in df_ro_rup_test['ro_pred'].tolist() if el]

print("Aromanian to Romanian BLEU:", bleu_calc.corpus_score(df_ro_rup_test['ro'][:200].tolist(), [df_ro_rup_test_t]))

Aromanian to Romanian BLEU: BLEU = 1.32 11.1/2.5/0.4/0.3 (BP = 1.000 ratio = 1.343 hyp_len = 478 ref_len = 356)


In [39]:
df_ro_rup_test_v = []
for ref, pred in zip(df_ro_rup_test['ro'][:200].tolist(), df_ro_rup_test_t):
    if pred:
        df_ro_rup_test_v.append(ref)

print(len(df_ro_rup_test_v))
df_ro_rup_test_t = [el for el in df_ro_rup_test_t if el]
print(len(df_ro_rup_test_t))
print("Aromanian to Romanian BLEU:", bleu_calc.corpus_score(df_ro_rup_test_v, [df_ro_rup_test_t]))
print("Aromanian to Romanian CHRF:", chrf_calc.corpus_score(df_ro_rup_test_v, [df_ro_rup_test_t]))

185
185
Aromanian to Romanian BLEU: BLEU = 0.00 0.0/0.0/0.0/0.0 (BP = 1.000 ratio = 1.261 hyp_len = 449 ref_len = 356)
Aromanian to Romanian CHRF: chrF2 = 6.02
