In [1]:
import torch
import glob
import pandas as pd
import numpy as np
import re
from peft import get_peft_model, PeftConfig, PeftModel, LoraConfig, prepare_model_for_kbit_training
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TrainingArguments, GenerationConfig
from trl import SFTTrainer
from datasets import Dataset

In [8]:
# load multi30k
from datasets import load_dataset
multi30k = load_dataset("bentrevett/multi30k")
multi30k

DatasetDict({
    train: Dataset({
        features: ['en', 'de'],
        num_rows: 29000
    })
    validation: Dataset({
        features: ['en', 'de'],
        num_rows: 1014
    })
    test: Dataset({
        features: ['en', 'de'],
        num_rows: 1000
    })
})

In [6]:
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [2]:
model_name = 'TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T'
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,           
    bnb_4bit_quant_type="nf4",    
    bnb_4bit_use_double_quant=True, 
    bnb_4bit_compute_dtype=torch.bfloat16, 
)

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config, 
    device_map="auto",  
    trust_remote_code=True, 
)

In [9]:
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) 
tokenizer.pad_token = tokenizer.eos_token

In [11]:
# Setting arguments for low-rank adaptation 

model = prepare_model_for_kbit_training(model)

lora_alpha = 32 # The weight matrix is scaled by lora_alpha/lora_rank, so I set lora_alpha = lora_rank to remove scaling
lora_dropout = 0.05 
lora_rank = 32 

peft_config = LoraConfig(
    lora_alpha=lora_alpha,
    lora_dropout=lora_dropout,
    r=lora_rank,
    bias="none",  # setting to 'none' for only training weight params instead of biases
    task_type="CAUSAL_LM")

peft_model = get_peft_model(model, peft_config)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [55]:
output_dir = "HamadJassem/tinyllama2k" # Model repo on your hugging face account where you want to save your model
per_device_train_batch_size = 3
gradient_accumulation_steps = 2  
optim = "paged_adamw_32bit" 
save_strategy="steps" 
save_steps = 10 
logging_steps = 10  
learning_rate = 2e-3  
max_grad_norm = 0.3 # Sets limit for gradient clipping
max_steps = 2000     # Number of training steps
warmup_ratio = 0.03 # Portion of steps used for learning_rate to warmup from 0
lr_scheduler_type = "cosine" # I chose cosine to avoid learning plateaus

training_arguments = TrainingArguments(
    output_dir=output_dir,
    per_device_train_batch_size=per_device_train_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    optim=optim,
    save_steps=save_steps,
    logging_steps=logging_steps,
    learning_rate=learning_rate,
    max_grad_norm=max_grad_norm,
    max_steps=max_steps,
    warmup_ratio=warmup_ratio,
    lr_scheduler_type=lr_scheduler_type,
    push_to_hub=True,
    report_to='none'
)

In [13]:
def formatted_train(input, response):
    return f"<|user|>\n{input}</s>\n<|assistant|>\n{response}</s>"

def prepare_train_data(data_id):
    data = load_dataset(data_id)
    data_df = pd.DataFrame(data['train'])
    data_df['formatted'] = data_df.apply(lambda x: formatted_train(x['en'], x['de']), axis=1)
    return data_df['formatted'].tolist()


In [15]:
train_data = prepare_train_data("bentrevett/multi30k")

In [17]:
data = Dataset.from_dict({"text": train_data})


In [56]:
trainer = SFTTrainer(
    model=peft_model,
    train_dataset=data,
    peft_config=peft_config,
    max_seq_length=500,
    dataset_text_field='text',
    tokenizer=tokenizer,
    args=training_arguments
)
peft_model.config.use_cache = False

Map:   0%|          | 0/29000 [00:00<?, ? examples/s]

In [57]:
trainer.train()

  0%|          | 0/2000 [00:00<?, ?it/s]



{'loss': 1.2449, 'grad_norm': 0.7951179146766663, 'learning_rate': 0.0003333333333333333, 'epoch': 0.0}




{'loss': 1.1993, 'grad_norm': 0.9199689626693726, 'learning_rate': 0.0006666666666666666, 'epoch': 0.0}




{'loss': 1.2023, 'grad_norm': 1.319110631942749, 'learning_rate': 0.001, 'epoch': 0.01}




{'loss': 1.207, 'grad_norm': 1.0385454893112183, 'learning_rate': 0.0013333333333333333, 'epoch': 0.01}




{'loss': 1.1531, 'grad_norm': 2.355604887008667, 'learning_rate': 0.0016666666666666668, 'epoch': 0.01}




{'loss': 1.1424, 'grad_norm': 0.9894558787345886, 'learning_rate': 0.002, 'epoch': 0.01}




{'loss': 1.1534, 'grad_norm': 1.4532840251922607, 'learning_rate': 0.001999868883665632, 'epoch': 0.01}




{'loss': 1.2159, 'grad_norm': 1.9860273599624634, 'learning_rate': 0.0019994755690455153, 'epoch': 0.02}




{'loss': 1.1991, 'grad_norm': 1.673425555229187, 'learning_rate': 0.0019988201592795905, 'epoch': 0.02}




{'loss': 1.1449, 'grad_norm': 1.316955804824829, 'learning_rate': 0.0019979028262377117, 'epoch': 0.02}




{'loss': 1.3165, 'grad_norm': 1.8808751106262207, 'learning_rate': 0.0019967238104745696, 'epoch': 0.02}




{'loss': 1.3474, 'grad_norm': 1.7279906272888184, 'learning_rate': 0.0019952834211666138, 'epoch': 0.02}




{'loss': 1.3314, 'grad_norm': 1.298763632774353, 'learning_rate': 0.001993582036030978, 'epoch': 0.03}




{'loss': 1.3654, 'grad_norm': 1.8918811082839966, 'learning_rate': 0.001991620101226425, 'epoch': 0.03}




{'loss': 1.4248, 'grad_norm': 1.5480859279632568, 'learning_rate': 0.001989398131236356, 'epoch': 0.03}




{'loss': 1.3882, 'grad_norm': 1.576898217201233, 'learning_rate': 0.0019869167087338906, 'epoch': 0.03}




{'loss': 1.375, 'grad_norm': 1.4029897451400757, 'learning_rate': 0.0019841764844290744, 'epoch': 0.04}




{'loss': 1.3815, 'grad_norm': 1.789203405380249, 'learning_rate': 0.0019811781768982392, 'epoch': 0.04}




{'loss': 1.4313, 'grad_norm': 1.9586189985275269, 'learning_rate': 0.0019779225723955706, 'epoch': 0.04}




{'loss': 1.4009, 'grad_norm': 2.2698206901550293, 'learning_rate': 0.001974410524646926, 'epoch': 0.04}




{'loss': 1.4199, 'grad_norm': 2.350048542022705, 'learning_rate': 0.0019706429546259593, 'epoch': 0.04}




{'loss': 1.3828, 'grad_norm': 1.9082087278366089, 'learning_rate': 0.001966620850312611, 'epoch': 0.05}




{'loss': 1.3048, 'grad_norm': 2.2600371837615967, 'learning_rate': 0.0019623452664340305, 'epoch': 0.05}




{'loss': 1.3531, 'grad_norm': 1.7225388288497925, 'learning_rate': 0.001957817324187987, 'epoch': 0.05}




{'loss': 1.366, 'grad_norm': 1.9740359783172607, 'learning_rate': 0.0019530382109488608, 'epoch': 0.05}




{'loss': 1.4342, 'grad_norm': 1.7084492444992065, 'learning_rate': 0.0019480091799562705, 'epoch': 0.05}




{'loss': 1.4003, 'grad_norm': 2.17181658744812, 'learning_rate': 0.0019427315499864343, 'epoch': 0.06}




{'loss': 1.4203, 'grad_norm': 1.7350876331329346, 'learning_rate': 0.0019372067050063438, 'epoch': 0.06}




{'loss': 1.3506, 'grad_norm': 2.7233171463012695, 'learning_rate': 0.0019314360938108425, 'epoch': 0.06}




{'loss': 1.2747, 'grad_norm': 1.5963140726089478, 'learning_rate': 0.0019254212296427042, 'epoch': 0.06}




{'loss': 1.3522, 'grad_norm': 1.672876000404358, 'learning_rate': 0.0019191636897958123, 'epoch': 0.06}




{'loss': 1.3926, 'grad_norm': 1.5802855491638184, 'learning_rate': 0.0019126651152015402, 'epoch': 0.07}




{'loss': 1.3348, 'grad_norm': 1.952505350112915, 'learning_rate': 0.0019059272099984468, 'epoch': 0.07}




{'loss': 1.3315, 'grad_norm': 2.091609239578247, 'learning_rate': 0.0018989517410853954, 'epoch': 0.07}




{'loss': 1.4062, 'grad_norm': 1.9274786710739136, 'learning_rate': 0.0018917405376582144, 'epoch': 0.07}




{'loss': 1.3646, 'grad_norm': 2.3886921405792236, 'learning_rate': 0.0018842954907300237, 'epoch': 0.07}




{'loss': 1.401, 'grad_norm': 1.861738681793213, 'learning_rate': 0.001876618552635348, 'epoch': 0.08}




{'loss': 1.3834, 'grad_norm': 1.2812858819961548, 'learning_rate': 0.0018687117365181513, 'epoch': 0.08}




{'loss': 1.3764, 'grad_norm': 2.1207404136657715, 'learning_rate': 0.0018605771158039252, 'epoch': 0.08}




{'loss': 1.3556, 'grad_norm': 1.912116289138794, 'learning_rate': 0.0018522168236559692, 'epoch': 0.08}




{'loss': 1.3563, 'grad_norm': 2.5215275287628174, 'learning_rate': 0.0018436330524160046, 'epoch': 0.08}




{'loss': 1.3837, 'grad_norm': 2.0447325706481934, 'learning_rate': 0.0018348280530292712, 'epoch': 0.09}




{'loss': 1.4622, 'grad_norm': 2.477651834487915, 'learning_rate': 0.0018258041344542564, 'epoch': 0.09}




{'loss': 1.3595, 'grad_norm': 2.7910008430480957, 'learning_rate': 0.001816563663057211, 'epoch': 0.09}




{'loss': 1.4155, 'grad_norm': 1.8973078727722168, 'learning_rate': 0.0018071090619916092, 'epoch': 0.09}




{'loss': 1.384, 'grad_norm': 1.4867877960205078, 'learning_rate': 0.0017974428105627207, 'epoch': 0.1}




{'loss': 1.3666, 'grad_norm': 2.003140449523926, 'learning_rate': 0.0017875674435774544, 'epoch': 0.1}




{'loss': 1.3114, 'grad_norm': 1.6805721521377563, 'learning_rate': 0.0017774855506796495, 'epoch': 0.1}




{'loss': 1.3441, 'grad_norm': 1.605634331703186, 'learning_rate': 0.001767199775670986, 'epoch': 0.1}




{'loss': 1.4243, 'grad_norm': 1.8011895418167114, 'learning_rate': 0.0017567128158176952, 'epoch': 0.1}




{'loss': 1.4099, 'grad_norm': 2.1648190021514893, 'learning_rate': 0.0017460274211432462, 'epoch': 0.11}




{'loss': 1.3107, 'grad_norm': 1.8194637298583984, 'learning_rate': 0.0017351463937072004, 'epoch': 0.11}




{'loss': 1.2999, 'grad_norm': 2.1159987449645996, 'learning_rate': 0.0017240725868704217, 'epoch': 0.11}




{'loss': 1.3794, 'grad_norm': 1.9560562372207642, 'learning_rate': 0.0017128089045468293, 'epoch': 0.11}




{'loss': 1.3407, 'grad_norm': 1.5670764446258545, 'learning_rate': 0.0017013583004418993, 'epoch': 0.11}




{'loss': 1.3801, 'grad_norm': 1.4321688413619995, 'learning_rate': 0.0016897237772781045, 'epoch': 0.12}




{'loss': 1.4476, 'grad_norm': 2.3535895347595215, 'learning_rate': 0.0016779083860075032, 'epoch': 0.12}




{'loss': 1.4001, 'grad_norm': 1.7069334983825684, 'learning_rate': 0.0016659152250116812, 'epoch': 0.12}




{'loss': 1.4101, 'grad_norm': 2.14029860496521, 'learning_rate': 0.0016537474392892527, 'epoch': 0.12}




{'loss': 1.3598, 'grad_norm': 1.6677004098892212, 'learning_rate': 0.00164140821963114, 'epoch': 0.12}




{'loss': 1.4159, 'grad_norm': 1.5774400234222412, 'learning_rate': 0.0016289008017838446, 'epoch': 0.13}




{'loss': 1.3251, 'grad_norm': 1.727860689163208, 'learning_rate': 0.0016162284656009273, 'epoch': 0.13}




{'loss': 1.3245, 'grad_norm': 2.1560733318328857, 'learning_rate': 0.0016033945341829248, 'epoch': 0.13}




{'loss': 1.2865, 'grad_norm': 1.9356287717819214, 'learning_rate': 0.0015904023730059227, 'epoch': 0.13}




{'loss': 1.3856, 'grad_norm': 1.7546733617782593, 'learning_rate': 0.0015772553890390196, 'epoch': 0.13}




{'loss': 1.3631, 'grad_norm': 2.2954840660095215, 'learning_rate': 0.0015639570298509064, 'epoch': 0.14}




{'loss': 1.3624, 'grad_norm': 1.2395082712173462, 'learning_rate': 0.0015505107827058036, 'epoch': 0.14}




{'loss': 1.2993, 'grad_norm': 5.917083740234375, 'learning_rate': 0.0015369201736489839, 'epoch': 0.14}




{'loss': 1.3596, 'grad_norm': 1.4701038599014282, 'learning_rate': 0.00152318876658213, 'epoch': 0.14}




{'loss': 1.426, 'grad_norm': 1.4517688751220703, 'learning_rate': 0.001509320162328763, 'epoch': 0.14}




{'loss': 1.4313, 'grad_norm': 2.3441741466522217, 'learning_rate': 0.0014953179976899878, 'epoch': 0.15}




{'loss': 1.3003, 'grad_norm': 2.1034417152404785, 'learning_rate': 0.001481185944490805, 'epoch': 0.15}




{'loss': 1.2696, 'grad_norm': 1.6347455978393555, 'learning_rate': 0.0014669277086172407, 'epoch': 0.15}




{'loss': 1.4324, 'grad_norm': 1.6337964534759521, 'learning_rate': 0.0014525470290445391, 'epoch': 0.15}




{'loss': 1.25, 'grad_norm': 1.9949686527252197, 'learning_rate': 0.0014380476768566823, 'epoch': 0.16}




{'loss': 1.3137, 'grad_norm': 1.3360998630523682, 'learning_rate': 0.0014234334542574906, 'epoch': 0.16}




{'loss': 1.2822, 'grad_norm': 1.5174576044082642, 'learning_rate': 0.0014087081935735563, 'epoch': 0.16}




{'loss': 1.3623, 'grad_norm': 1.9854929447174072, 'learning_rate': 0.0013938757562492873, 'epoch': 0.16}




{'loss': 1.3729, 'grad_norm': 1.4631831645965576, 'learning_rate': 0.0013789400318343068, 'epoch': 0.16}




{'loss': 1.3517, 'grad_norm': 1.8151986598968506, 'learning_rate': 0.0013639049369634877, 'epoch': 0.17}




{'loss': 1.368, 'grad_norm': 2.078228712081909, 'learning_rate': 0.001348774414329882, 'epoch': 0.17}




{'loss': 1.3768, 'grad_norm': 1.5143951177597046, 'learning_rate': 0.0013335524316508208, 'epoch': 0.17}




{'loss': 1.3349, 'grad_norm': 1.7548283338546753, 'learning_rate': 0.0013182429806274441, 'epoch': 0.17}




{'loss': 1.3378, 'grad_norm': 1.8395391702651978, 'learning_rate': 0.0013028500758979506, 'epoch': 0.17}




{'loss': 1.3031, 'grad_norm': 1.8354792594909668, 'learning_rate': 0.0012873777539848283, 'epoch': 0.18}




{'loss': 1.3625, 'grad_norm': 1.37184739112854, 'learning_rate': 0.001271830072236343, 'epoch': 0.18}




{'loss': 1.3397, 'grad_norm': 1.8135080337524414, 'learning_rate': 0.0012562111077625722, 'epoch': 0.18}




{'loss': 1.2005, 'grad_norm': 1.3517221212387085, 'learning_rate': 0.0012405249563662538, 'epoch': 0.18}




{'loss': 1.2202, 'grad_norm': 1.251674771308899, 'learning_rate': 0.0012247757314687295, 'epoch': 0.18}




{'loss': 1.2549, 'grad_norm': 1.4959944486618042, 'learning_rate': 0.0012089675630312753, 'epoch': 0.19}




{'loss': 1.2888, 'grad_norm': 1.22725248336792, 'learning_rate': 0.001193104596472088, 'epoch': 0.19}




{'loss': 1.311, 'grad_norm': 1.9193379878997803, 'learning_rate': 0.0011771909915792229, 'epoch': 0.19}




{'loss': 1.3235, 'grad_norm': 1.573407530784607, 'learning_rate': 0.0011612309214197598, 'epoch': 0.19}




{'loss': 1.3096, 'grad_norm': 1.2790882587432861, 'learning_rate': 0.0011452285712454905, 'epoch': 0.19}




{'loss': 1.3261, 'grad_norm': 1.664119005203247, 'learning_rate': 0.0011291881373954064, 'epoch': 0.2}




{'loss': 1.173, 'grad_norm': 1.1240872144699097, 'learning_rate': 0.0011131138261952845, 'epoch': 0.2}




{'loss': 1.3029, 'grad_norm': 1.4965968132019043, 'learning_rate': 0.0010970098528546481, 'epoch': 0.2}




{'loss': 1.3066, 'grad_norm': 1.2815098762512207, 'learning_rate': 0.0010808804403614042, 'epoch': 0.2}




{'loss': 1.3018, 'grad_norm': 1.6038103103637695, 'learning_rate': 0.0010647298183744359, 'epoch': 0.2}




{'loss': 1.2803, 'grad_norm': 1.3743510246276855, 'learning_rate': 0.0010485622221144484, 'epoch': 0.21}




{'loss': 1.3631, 'grad_norm': 1.5975850820541382, 'learning_rate': 0.001032381891253356, 'epoch': 0.21}




{'loss': 1.2592, 'grad_norm': 1.6555545330047607, 'learning_rate': 0.0010161930688025015, 'epoch': 0.21}




{'loss': 1.2786, 'grad_norm': 1.4284824132919312, 'learning_rate': 0.001, 'epoch': 0.21}




{'loss': 1.2963, 'grad_norm': 1.7123064994812012, 'learning_rate': 0.0009838069311974985, 'epoch': 0.22}




{'loss': 1.2897, 'grad_norm': 1.3351553678512573, 'learning_rate': 0.0009676181087466443, 'epoch': 0.22}




{'loss': 1.3101, 'grad_norm': 1.2764580249786377, 'learning_rate': 0.000951437777885552, 'epoch': 0.22}




{'loss': 1.2351, 'grad_norm': 1.3756550550460815, 'learning_rate': 0.0009352701816255643, 'epoch': 0.22}




{'loss': 1.2861, 'grad_norm': 1.209535002708435, 'learning_rate': 0.0009191195596385959, 'epoch': 0.22}




{'loss': 1.262, 'grad_norm': 1.3619756698608398, 'learning_rate': 0.000902990147145352, 'epoch': 0.23}




{'loss': 1.233, 'grad_norm': 1.7408370971679688, 'learning_rate': 0.0008868861738047158, 'epoch': 0.23}




{'loss': 1.2751, 'grad_norm': 0.9142020344734192, 'learning_rate': 0.0008708118626045939, 'epoch': 0.23}




{'loss': 1.2343, 'grad_norm': 1.490129828453064, 'learning_rate': 0.00085477142875451, 'epoch': 0.23}




{'loss': 1.2249, 'grad_norm': 1.238423466682434, 'learning_rate': 0.0008387690785802402, 'epoch': 0.23}




{'loss': 1.21, 'grad_norm': 0.9772652983665466, 'learning_rate': 0.0008228090084207773, 'epoch': 0.24}




{'loss': 1.2844, 'grad_norm': 1.4546443223953247, 'learning_rate': 0.0008068954035279121, 'epoch': 0.24}




{'loss': 1.2194, 'grad_norm': 1.2997456789016724, 'learning_rate': 0.000791032436968725, 'epoch': 0.24}




{'loss': 1.2083, 'grad_norm': 1.9079747200012207, 'learning_rate': 0.0007752242685312709, 'epoch': 0.24}




{'loss': 1.2306, 'grad_norm': 1.2147644758224487, 'learning_rate': 0.0007594750436337467, 'epoch': 0.24}




{'loss': 1.2893, 'grad_norm': 1.4435302019119263, 'learning_rate': 0.0007437888922374276, 'epoch': 0.25}




{'loss': 1.2489, 'grad_norm': 1.3550947904586792, 'learning_rate': 0.0007281699277636571, 'epoch': 0.25}




{'loss': 1.2136, 'grad_norm': 1.2309287786483765, 'learning_rate': 0.0007126222460151719, 'epoch': 0.25}




{'loss': 1.1553, 'grad_norm': 1.4029550552368164, 'learning_rate': 0.0006971499241020494, 'epoch': 0.25}




{'loss': 1.3305, 'grad_norm': 0.936649739742279, 'learning_rate': 0.0006817570193725564, 'epoch': 0.25}




{'loss': 1.2361, 'grad_norm': 1.6900100708007812, 'learning_rate': 0.0006664475683491796, 'epoch': 0.26}




{'loss': 1.1373, 'grad_norm': 1.4432871341705322, 'learning_rate': 0.0006512255856701177, 'epoch': 0.26}




{'loss': 1.2614, 'grad_norm': 1.4110153913497925, 'learning_rate': 0.0006360950630365126, 'epoch': 0.26}




{'loss': 1.2106, 'grad_norm': 1.833466649055481, 'learning_rate': 0.0006210599681656932, 'epoch': 0.26}




{'loss': 1.2128, 'grad_norm': 1.7234034538269043, 'learning_rate': 0.0006061242437507131, 'epoch': 0.26}




{'loss': 1.1634, 'grad_norm': 1.3333450555801392, 'learning_rate': 0.000591291806426444, 'epoch': 0.27}




{'loss': 1.1737, 'grad_norm': 1.0637725591659546, 'learning_rate': 0.0005765665457425102, 'epoch': 0.27}




{'loss': 1.2479, 'grad_norm': 1.445026159286499, 'learning_rate': 0.0005619523231433177, 'epoch': 0.27}




{'loss': 1.1636, 'grad_norm': 1.106514573097229, 'learning_rate': 0.0005474529709554612, 'epoch': 0.27}




{'loss': 1.2446, 'grad_norm': 1.3041290044784546, 'learning_rate': 0.0005330722913827594, 'epoch': 0.28}




{'loss': 1.1832, 'grad_norm': 1.6564666032791138, 'learning_rate': 0.0005188140555091949, 'epoch': 0.28}




{'loss': 1.2049, 'grad_norm': 1.3603057861328125, 'learning_rate': 0.0005046820023100129, 'epoch': 0.28}




{'loss': 1.1295, 'grad_norm': 0.8950020670890808, 'learning_rate': 0.0004906798376712373, 'epoch': 0.28}




{'loss': 1.2735, 'grad_norm': 1.2325656414031982, 'learning_rate': 0.00047681123341786994, 'epoch': 0.28}




{'loss': 1.1663, 'grad_norm': 1.0435501337051392, 'learning_rate': 0.0004630798263510162, 'epoch': 0.29}




{'loss': 1.1079, 'grad_norm': 1.575168251991272, 'learning_rate': 0.00044948921729419644, 'epoch': 0.29}




{'loss': 1.2027, 'grad_norm': 1.3369417190551758, 'learning_rate': 0.0004360429701490934, 'epoch': 0.29}




{'loss': 1.1588, 'grad_norm': 2.9825758934020996, 'learning_rate': 0.0004227446109609808, 'epoch': 0.29}




{'loss': 1.2215, 'grad_norm': 2.1022071838378906, 'learning_rate': 0.00040959762699407763, 'epoch': 0.29}




{'loss': 1.1935, 'grad_norm': 1.2092372179031372, 'learning_rate': 0.00039660546581707536, 'epoch': 0.3}




{'loss': 1.1984, 'grad_norm': 1.2401303052902222, 'learning_rate': 0.00038377153439907266, 'epoch': 0.3}




{'loss': 1.2061, 'grad_norm': 1.3339202404022217, 'learning_rate': 0.00037109919821615546, 'epoch': 0.3}




{'loss': 1.1581, 'grad_norm': 1.6069828271865845, 'learning_rate': 0.0003585917803688603, 'epoch': 0.3}




{'loss': 1.191, 'grad_norm': 1.215423583984375, 'learning_rate': 0.0003462525607107477, 'epoch': 0.3}




{'loss': 1.169, 'grad_norm': 0.95818030834198, 'learning_rate': 0.00033408477498831913, 'epoch': 0.31}




{'loss': 1.1822, 'grad_norm': 1.501001000404358, 'learning_rate': 0.00032209161399249675, 'epoch': 0.31}




{'loss': 1.1222, 'grad_norm': 1.3030200004577637, 'learning_rate': 0.00031027622272189573, 'epoch': 0.31}




{'loss': 1.2376, 'grad_norm': 1.0448280572891235, 'learning_rate': 0.00029864169955810084, 'epoch': 0.31}




{'loss': 1.182, 'grad_norm': 1.0665391683578491, 'learning_rate': 0.00028719109545317104, 'epoch': 0.31}




{'loss': 1.1612, 'grad_norm': 1.0376957654953003, 'learning_rate': 0.0002759274131295787, 'epoch': 0.32}




{'loss': 1.1151, 'grad_norm': 0.920687198638916, 'learning_rate': 0.0002648536062927999, 'epoch': 0.32}




{'loss': 1.1543, 'grad_norm': 0.872800886631012, 'learning_rate': 0.00025397257885675397, 'epoch': 0.32}




{'loss': 1.0633, 'grad_norm': 1.0265511274337769, 'learning_rate': 0.00024328718418230468, 'epoch': 0.32}




{'loss': 1.1705, 'grad_norm': 1.1426221132278442, 'learning_rate': 0.00023280022432901383, 'epoch': 0.32}




{'loss': 1.1971, 'grad_norm': 1.3941009044647217, 'learning_rate': 0.0002225144493203509, 'epoch': 0.33}




{'loss': 1.1458, 'grad_norm': 1.212072491645813, 'learning_rate': 0.00021243255642254576, 'epoch': 0.33}




{'loss': 1.1758, 'grad_norm': 1.5925945043563843, 'learning_rate': 0.0002025571894372794, 'epoch': 0.33}




{'loss': 1.1755, 'grad_norm': 1.1809017658233643, 'learning_rate': 0.00019289093800839064, 'epoch': 0.33}




{'loss': 1.137, 'grad_norm': 1.1642966270446777, 'learning_rate': 0.00018343633694278895, 'epoch': 0.34}




{'loss': 1.1469, 'grad_norm': 0.9544137716293335, 'learning_rate': 0.00017419586554574362, 'epoch': 0.34}




{'loss': 1.1536, 'grad_norm': 0.9863898158073425, 'learning_rate': 0.00016517194697072903, 'epoch': 0.34}




{'loss': 1.162, 'grad_norm': 0.9222826361656189, 'learning_rate': 0.0001563669475839956, 'epoch': 0.34}




{'loss': 1.1388, 'grad_norm': 1.1527345180511475, 'learning_rate': 0.0001477831763440308, 'epoch': 0.34}




{'loss': 1.164, 'grad_norm': 1.08338463306427, 'learning_rate': 0.00013942288419607475, 'epoch': 0.35}




{'loss': 1.0924, 'grad_norm': 0.8170047998428345, 'learning_rate': 0.00013128826348184885, 'epoch': 0.35}




{'loss': 1.0811, 'grad_norm': 1.0327695608139038, 'learning_rate': 0.0001233814473646524, 'epoch': 0.35}




{'loss': 1.1422, 'grad_norm': 1.2194979190826416, 'learning_rate': 0.00011570450926997656, 'epoch': 0.35}




{'loss': 1.1099, 'grad_norm': 1.187402367591858, 'learning_rate': 0.00010825946234178574, 'epoch': 0.35}




{'loss': 1.1378, 'grad_norm': 1.2684168815612793, 'learning_rate': 0.00010104825891460479, 'epoch': 0.36}




{'loss': 1.2639, 'grad_norm': 1.2378133535385132, 'learning_rate': 9.407279000155311e-05, 'epoch': 0.36}




{'loss': 1.1016, 'grad_norm': 1.4303408861160278, 'learning_rate': 8.733488479845996e-05, 'epoch': 0.36}




{'loss': 1.1046, 'grad_norm': 1.1888500452041626, 'learning_rate': 8.083631020418791e-05, 'epoch': 0.36}




{'loss': 1.0615, 'grad_norm': 1.1826225519180298, 'learning_rate': 7.457877035729587e-05, 'epoch': 0.36}




{'loss': 1.1731, 'grad_norm': 1.2951655387878418, 'learning_rate': 6.856390618915776e-05, 'epoch': 0.37}




{'loss': 1.1871, 'grad_norm': 1.2795650959014893, 'learning_rate': 6.279329499365649e-05, 'epoch': 0.37}




{'loss': 1.085, 'grad_norm': 1.0524147748947144, 'learning_rate': 5.726845001356573e-05, 'epoch': 0.37}




{'loss': 1.1745, 'grad_norm': 1.1112385988235474, 'learning_rate': 5.199082004372957e-05, 'epoch': 0.37}




{'loss': 1.0861, 'grad_norm': 1.0111092329025269, 'learning_rate': 4.6961789051139124e-05, 'epoch': 0.37}




{'loss': 1.2061, 'grad_norm': 1.151517629623413, 'learning_rate': 4.218267581201296e-05, 'epoch': 0.38}




{'loss': 1.1612, 'grad_norm': 0.8623828887939453, 'learning_rate': 3.7654733565969825e-05, 'epoch': 0.38}




{'loss': 1.1552, 'grad_norm': 1.4482641220092773, 'learning_rate': 3.337914968738887e-05, 'epoch': 0.38}




{'loss': 1.1344, 'grad_norm': 1.7752220630645752, 'learning_rate': 2.9357045374040825e-05, 'epoch': 0.38}




{'loss': 1.1664, 'grad_norm': 1.6485670804977417, 'learning_rate': 2.5589475353073986e-05, 'epoch': 0.38}




{'loss': 1.1216, 'grad_norm': 0.9247493743896484, 'learning_rate': 2.2077427604429433e-05, 'epoch': 0.39}




{'loss': 1.1278, 'grad_norm': 0.9902240037918091, 'learning_rate': 1.882182310176095e-05, 'epoch': 0.39}




{'loss': 1.0465, 'grad_norm': 1.434756875038147, 'learning_rate': 1.5823515570925763e-05, 'epoch': 0.39}




{'loss': 1.109, 'grad_norm': 1.8376632928848267, 'learning_rate': 1.3083291266109298e-05, 'epoch': 0.39}




{'loss': 1.1307, 'grad_norm': 1.2433574199676514, 'learning_rate': 1.0601868763643995e-05, 'epoch': 0.4}




{'loss': 1.0966, 'grad_norm': 1.1007815599441528, 'learning_rate': 8.379898773574923e-06, 'epoch': 0.4}




{'loss': 1.1598, 'grad_norm': 0.923400342464447, 'learning_rate': 6.417963969022389e-06, 'epoch': 0.4}




{'loss': 1.1293, 'grad_norm': 1.7794280052185059, 'learning_rate': 4.7165788333860535e-06, 'epoch': 0.4}




{'loss': 1.1662, 'grad_norm': 1.6357388496398926, 'learning_rate': 3.2761895254306282e-06, 'epoch': 0.4}




{'loss': 1.1877, 'grad_norm': 1.4677106142044067, 'learning_rate': 2.0971737622883515e-06, 'epoch': 0.41}




{'loss': 1.1779, 'grad_norm': 1.119966745376587, 'learning_rate': 1.1798407204093308e-06, 'epoch': 0.41}




{'loss': 1.1374, 'grad_norm': 1.0557212829589844, 'learning_rate': 5.244309544850667e-07, 'epoch': 0.41}




{'loss': 1.193, 'grad_norm': 1.1305638551712036, 'learning_rate': 1.3111633436779792e-07, 'epoch': 0.41}




{'loss': 1.0554, 'grad_norm': 1.6062161922454834, 'learning_rate': 0.0, 'epoch': 0.41}
{'train_runtime': 947.0058, 'train_samples_per_second': 12.672, 'train_steps_per_second': 2.112, 'train_loss': 1.2585658359527587, 'epoch': 0.41}


TrainOutput(global_step=2000, training_loss=1.2585658359527587, metrics={'train_runtime': 947.0058, 'train_samples_per_second': 12.672, 'train_steps_per_second': 2.112, 'train_loss': 1.2585658359527587, 'epoch': 0.41})

In [5]:
def formatted_prompt(question)-> str:
    return f"<|user|>\n{question}</s>\n<|assistant|>:"


In [6]:
from transformers import GenerationConfig
from time import perf_counter

def generate_response(user_input):

  prompt = formatted_prompt(user_input)

  inputs = tokenizer([prompt], return_tensors="pt")
  generation_config = GenerationConfig(penalty_alpha=0.6,do_sample = True,
      top_k=5,temperature=0.5,repetition_penalty=1.5,
      max_new_tokens=20,pad_token_id=tokenizer.eos_token_id
  )
  start_time = perf_counter()

  inputs = tokenizer(prompt, return_tensors="pt").to('cuda')

  outputs = model.generate(**inputs, generation_config=generation_config)
  output=tokenizer.decode(outputs[0], skip_special_tokens=True)
  #print(output)
  output_time = perf_counter() - start_time
  #print(f"Time taken for inference: {round(output_time,2)} seconds")
  return output

In [10]:
test_data = load_dataset("bentrevett/multi30k", split="test")
test_data_df = pd.DataFrame(test_data)


In [13]:
generate_response(test_data_df.iloc[16]['en'])

'<|user|>\nA blond holding hands with a guy in the sand. \n<|assistant|>:\nEin Blonder hält Händchen mit einem Mann im Sand herum. \n'

In [134]:
print(test_data_df.iloc[16]['de'])

Eine Blondine hält mit einem Mann im Sand Händchen.


Loading The model

In [3]:
from peft import AutoPeftModelForCausalLM, PeftModel
from transformers import AutoModelForCausalLM
import torch
import os

model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=bnb_config, 
    device_map="auto",  
    trust_remote_code=True, 
)

model_path = "HamadJassem/tinyllama2k"
peft_model = PeftModel.from_pretrained(model, model_path, from_transformers=True, device_map="auto")
model = peft_model.merge_and_unload()



In [14]:
# for each test data, generate response and save it in a list

responses = []
for i in range(0, len(test_data_df)):
    response = generate_response(test_data_df.iloc[i]['en'])
    responses.append(response)

In [16]:
# save the responses in a new dataframe
new_df = test_data_df.copy()
new_df['generated_de'] = responses
new_df.to_csv('generated_responses_tinyllama2k.csv', index=False)
