In [13]:
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    HfArgumentParser,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import (
    LoraConfig,
    PeftModel,
    prepare_model_for_kbit_training,
    get_peft_model,
)
import os, torch, wandb
from datasets import load_dataset
from trl import SFTTrainer, setup_chat_format

In [15]:
run = wandb.init(
    project='Fine-tune Llama 3 on CC Dataset', 
    job_type="training", 
    anonymous="allow"
)

In [16]:
base_model = "meta-llama/Llama-3.2-3B-Instruct"
new_model = "llama-3.2-3b-CC"


In [18]:
torch_dtype = torch.float16
attn_implementation = "eager"

In [19]:
# QLoRA config
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch_dtype,
    bnb_4bit_use_double_quant=True,
)
# Load model
model = AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=bnb_config,
    device_map="auto",
    attn_implementation=attn_implementation
)

# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)

Loading checkpoint shards: 100%|██████████| 2/2 [00:08<00:00,  4.31s/it]


In [26]:
df = pd.read_csv('augmented.csv')
df

Unnamed: 0,Gender,Age,Debt,Married,BankCustomer,Industry,Ethnicity,YearsEmployed,PriorDefault,Employed,CreditScore,DriversLicense,Citizen,ZipCode,Income,Approved,Reason
0,1,30.83,0.000,1,1,Industrials,White,1.25,1,1,1,0,ByBirth,202,0,1,"This application was approved due to Income, Y..."
1,0,58.67,4.460,1,1,Materials,Black,3.04,1,1,6,0,ByBirth,43,560,1,"This application was approved due to Income, Y..."
2,0,24.50,0.500,1,1,Materials,Black,1.50,1,0,0,0,ByBirth,280,824,1,"This application was approved due to Income, Y..."
3,1,27.83,1.540,1,1,Industrials,White,3.75,1,1,5,1,ByBirth,100,3,1,This application was approved due to YearsEmpl...
4,1,20.17,5.625,1,1,Industrials,White,1.71,1,0,0,0,ByOtherMeans,120,0,1,"This application was approved due to Income, Y..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
685,1,21.08,10.085,0,0,Education,Black,1.25,0,0,0,0,ByBirth,260,0,0,"This application was denied due to Employed, Z..."
686,0,22.67,0.750,1,1,Energy,White,2.00,0,1,2,1,ByBirth,200,394,0,"This application was denied due to Income, Zip..."
687,0,25.25,13.500,0,0,Healthcare,Latino,2.00,0,1,1,1,ByBirth,200,1,0,"This application was denied due to Income, Zip..."
688,1,17.92,0.205,1,1,ConsumerStaples,White,0.04,0,0,0,0,ByBirth,280,750,0,This application was denied due to YearsEmploy...


In [27]:
import pandas as pd
import random

# Function to randomly drop certain labels with a specified probability
def preprocess_data_generalized(row, drop_prob=0.2):
    # Randomly generate different text for input
    gender_text = f"Gender: {'Male' if row['Gender'] == 1 else 'Female'}" if random.random() > drop_prob else ""
    age_text = f"Age: {row['Age']}" if random.random() > drop_prob else ""
    debt_text = f"Debt: {row['Debt']}" if random.random() > drop_prob else ""
    married_text = f"Married: {'Yes' if row['Married'] == 1 else 'No'}" if random.random() > drop_prob else ""
    bank_customer_text = f"BankCustomer: {'Yes' if row['BankCustomer'] == 1 else 'No'}" if random.random() > drop_prob else ""
    industry_text = f"Industry: {row['Industry']}" if random.random() > drop_prob else ""
    ethnicity_text = f"Ethnicity: {row['Ethnicity']}" if random.random() > drop_prob else ""
    years_employed_text = f"YearsEmployed: {row['YearsEmployed']}" if random.random() > drop_prob else ""
    prior_default_text = f"PriorDefault: {'Yes' if row['PriorDefault'] == 1 else 'No'}" if random.random() > drop_prob else ""
    employed_text = f"Employed: {'Yes' if row['Employed'] == 1 else 'No'}" if random.random() > drop_prob else ""
    credit_score_text = f"CreditScore: {row['CreditScore']}" if random.random() > drop_prob else ""
    drivers_license_text = f"DriversLicense: {'Yes' if row['DriversLicense'] == 1 else 'No'}" if random.random() > drop_prob else ""
    citizen_text = f"Citizen: {row['Citizen']}" if random.random() > drop_prob else ""
    zip_code_text = f"ZipCode: {row['ZipCode']}" if random.random() > drop_prob else ""
    income_text = f"Income: {row['Income']}" if random.random() > drop_prob else ""

    # Combine input text with random omissions
    input_text = ", ".join(filter(None, [
        gender_text, age_text, debt_text, married_text, bank_customer_text, 
        industry_text, ethnicity_text, years_employed_text, prior_default_text, 
        employed_text, credit_score_text, drivers_license_text, citizen_text, 
        zip_code_text, income_text
    ]))
    
    # Simplified output format (Approved status and reason)
    output_text = f"{'Yes' if row['Approved'] == 1 else 'No'}, {row['Reason']}"
    
    return {"text": input_text, "label": output_text}

# Apply the generalized preprocessing to the dataframe
df_processed = df.apply(preprocess_data_generalized, axis=1)
df_final = pd.DataFrame(df_processed.tolist())

# Save the processed data to a new CSV file
df_final.to_csv('generalized_transformed_text_data.csv', index=False)

# Display the first few rows of the processed data
print(df_final.head())


Unnamed: 0,text,label
0,"Age: 30.83, CreditScore: 1, Income: 0, YearsEm...","Yes, This application was approved due to Inco..."
1,"Age: 58.67, CreditScore: 6, Income: 560, Years...","Yes, This application was approved due to Inco..."
2,"Age: 24.5, CreditScore: 0, Income: 824, YearsE...","Yes, This application was approved due to Inco..."
3,"Age: 27.83, CreditScore: 5, Income: 3, YearsEm...","Yes, This application was approved due to Year..."
4,"Age: 20.17, CreditScore: 0, Income: 0, YearsEm...","Yes, This application was approved due to Inco..."
...,...,...
685,"Age: 21.08, CreditScore: 0, Income: 0, YearsEm...","No, This application was denied due to Employe..."
686,"Age: 22.67, CreditScore: 2, Income: 394, Years...","No, This application was denied due to Income,..."
687,"Age: 25.25, CreditScore: 1, Income: 1, YearsEm...","No, This application was denied due to Income,..."
688,"Age: 17.92, CreditScore: 0, Income: 750, Years...","No, This application was denied due to YearsEm..."


In [28]:
dataset = Dataset.from_pandas(df_final)
dataset['text'][3]


'Age: 27.83, CreditScore: 5, Income: 3, YearsEmployed: 3.75, Gender: Male, Married: Yes, Industry: Industrials, Ethnicity: White, PriorDefault: Yes, Employed: Yes'

In [29]:
import bitsandbytes as bnb

def find_all_linear_names(model):
    cls = bnb.nn.Linear4bit
    lora_module_names = set()
    for name, module in model.named_modules():
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16 bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)

modules = find_all_linear_names(model)

In [30]:
# LoRA config
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
    target_modules=modules
)
model, tokenizer = setup_chat_format(model, tokenizer)
model = get_peft_model(model, peft_config)

In [31]:
#Hyperparamter
training_arguments = TrainingArguments(
    output_dir=new_model,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=2,
    optim="paged_adamw_32bit",
    num_train_epochs=1,
    eval_strategy="steps",
    eval_steps=0.2,
    logging_steps=1,
    warmup_steps=10,
    logging_strategy="steps",
    learning_rate=2e-4,
    fp16=False,
    bf16=False,
    group_by_length=True,
    report_to="wandb"
)

In [34]:
from datasets import Dataset

# Assuming dataset is a Dataset object with columns 'text' and 'label'
train_test_split = dataset.train_test_split(test_size=0.2)  # Split into 80% train, 20% test
train_dataset = train_test_split['train']
eval_dataset = train_test_split['test']

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    peft_config=peft_config,
    max_seq_length=512,
    dataset_text_field="text",  # The column containing the input text
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
)


Map: 100%|██████████| 552/552 [00:00<00:00, 6702.75 examples/s]
Map: 100%|██████████| 138/138 [00:00<00:00, 6819.77 examples/s]
  lambda data: self._console_raw_callback("stderr", data),
  0%|          | 0/241 [07:16<?, ?it/s]


In [None]:
trainer.train()

  0%|          | 1/276 [00:04<18:29,  4.04s/it]

{'loss': 10.2969, 'grad_norm': 26.70298194885254, 'learning_rate': 2e-05, 'epoch': 0.0}


  1%|          | 2/276 [00:05<12:16,  2.69s/it]

{'loss': 10.3047, 'grad_norm': 26.68596839904785, 'learning_rate': 4e-05, 'epoch': 0.01}


  1%|          | 3/276 [00:07<09:24,  2.07s/it]

{'loss': 9.8516, 'grad_norm': 25.80520248413086, 'learning_rate': 6e-05, 'epoch': 0.01}


  1%|▏         | 4/276 [00:08<08:41,  1.92s/it]

{'loss': 8.7188, 'grad_norm': 21.52022361755371, 'learning_rate': 8e-05, 'epoch': 0.01}


  2%|▏         | 5/276 [00:10<08:07,  1.80s/it]

{'loss': 7.6016, 'grad_norm': 12.182042121887207, 'learning_rate': 0.0001, 'epoch': 0.02}


  2%|▏         | 6/276 [00:11<07:05,  1.58s/it]

{'loss': 7.1602, 'grad_norm': 23.240028381347656, 'learning_rate': 0.00012, 'epoch': 0.02}


  3%|▎         | 7/276 [00:12<06:34,  1.46s/it]

{'loss': 6.8047, 'grad_norm': 14.437237739562988, 'learning_rate': 0.00014, 'epoch': 0.03}


  3%|▎         | 8/276 [00:14<06:32,  1.46s/it]

{'loss': 6.2305, 'grad_norm': 10.429811477661133, 'learning_rate': 0.00016, 'epoch': 0.03}


  3%|▎         | 9/276 [00:15<06:37,  1.49s/it]

{'loss': 5.8398, 'grad_norm': 8.53739070892334, 'learning_rate': 0.00018, 'epoch': 0.03}


  4%|▎         | 10/276 [00:18<07:36,  1.72s/it]

{'loss': 5.2891, 'grad_norm': 5.546876907348633, 'learning_rate': 0.0002, 'epoch': 0.04}


  4%|▍         | 11/276 [00:19<07:14,  1.64s/it]

{'loss': 6.3359, 'grad_norm': 15.156576156616211, 'learning_rate': 0.00019924812030075188, 'epoch': 0.04}


  4%|▍         | 12/276 [00:20<06:38,  1.51s/it]

{'loss': 4.9844, 'grad_norm': 9.44710922241211, 'learning_rate': 0.00019849624060150375, 'epoch': 0.04}


  5%|▍         | 13/276 [00:21<06:17,  1.44s/it]

{'loss': 4.5508, 'grad_norm': 4.848315238952637, 'learning_rate': 0.00019774436090225567, 'epoch': 0.05}


  5%|▌         | 14/276 [00:23<06:02,  1.38s/it]

{'loss': 4.207, 'grad_norm': 6.452585220336914, 'learning_rate': 0.00019699248120300754, 'epoch': 0.05}


  5%|▌         | 15/276 [00:24<05:49,  1.34s/it]

{'loss': 4.0156, 'grad_norm': 5.607563018798828, 'learning_rate': 0.0001962406015037594, 'epoch': 0.05}


  6%|▌         | 16/276 [00:25<05:42,  1.32s/it]

{'loss': 3.6172, 'grad_norm': 4.719135761260986, 'learning_rate': 0.00019548872180451127, 'epoch': 0.06}


  6%|▌         | 17/276 [00:27<05:52,  1.36s/it]

{'loss': 3.5, 'grad_norm': 5.756117343902588, 'learning_rate': 0.00019473684210526317, 'epoch': 0.06}


  7%|▋         | 18/276 [00:28<06:18,  1.47s/it]

{'loss': 3.1934, 'grad_norm': 4.544558048248291, 'learning_rate': 0.00019398496240601503, 'epoch': 0.07}


  7%|▋         | 19/276 [00:31<07:32,  1.76s/it]

{'loss': 3.0918, 'grad_norm': 4.251826763153076, 'learning_rate': 0.00019323308270676693, 'epoch': 0.07}


  7%|▋         | 20/276 [00:33<07:29,  1.76s/it]

{'loss': 2.9297, 'grad_norm': 5.528449535369873, 'learning_rate': 0.0001924812030075188, 'epoch': 0.07}


  8%|▊         | 21/276 [00:34<06:53,  1.62s/it]

{'loss': 2.5332, 'grad_norm': 3.6929728984832764, 'learning_rate': 0.0001917293233082707, 'epoch': 0.08}


  8%|▊         | 22/276 [00:35<06:06,  1.44s/it]

{'loss': 2.5059, 'grad_norm': 4.111555099487305, 'learning_rate': 0.00019097744360902256, 'epoch': 0.08}


  8%|▊         | 23/276 [00:36<05:26,  1.29s/it]

{'loss': 2.3906, 'grad_norm': 4.453573226928711, 'learning_rate': 0.00019022556390977443, 'epoch': 0.08}


  9%|▊         | 24/276 [00:37<05:20,  1.27s/it]

{'loss': 2.3008, 'grad_norm': 3.8797948360443115, 'learning_rate': 0.00018947368421052632, 'epoch': 0.09}


  9%|▉         | 25/276 [00:38<05:02,  1.21s/it]

{'loss': 1.8359, 'grad_norm': 5.211111068725586, 'learning_rate': 0.00018872180451127822, 'epoch': 0.09}


  9%|▉         | 26/276 [00:39<04:55,  1.18s/it]

{'loss': 1.6094, 'grad_norm': 7.205620765686035, 'learning_rate': 0.00018796992481203009, 'epoch': 0.09}


 10%|▉         | 27/276 [00:40<04:44,  1.14s/it]

{'loss': 1.3584, 'grad_norm': 6.098331451416016, 'learning_rate': 0.00018721804511278195, 'epoch': 0.1}


 10%|█         | 28/276 [00:41<04:30,  1.09s/it]

{'loss': 1.459, 'grad_norm': 5.732055187225342, 'learning_rate': 0.00018646616541353382, 'epoch': 0.1}


 11%|█         | 29/276 [00:43<05:40,  1.38s/it]

{'loss': 1.0889, 'grad_norm': 5.559864044189453, 'learning_rate': 0.00018571428571428572, 'epoch': 0.11}


 11%|█         | 30/276 [00:45<05:54,  1.44s/it]

{'loss': 0.7651, 'grad_norm': 4.676792144775391, 'learning_rate': 0.0001849624060150376, 'epoch': 0.11}


 11%|█         | 31/276 [00:46<05:36,  1.37s/it]

{'loss': 1.042, 'grad_norm': 4.23452091217041, 'learning_rate': 0.00018421052631578948, 'epoch': 0.11}


 12%|█▏        | 32/276 [00:48<05:36,  1.38s/it]

{'loss': 0.6743, 'grad_norm': 3.8252575397491455, 'learning_rate': 0.00018345864661654135, 'epoch': 0.12}


 12%|█▏        | 33/276 [00:49<05:26,  1.34s/it]

{'loss': 0.5781, 'grad_norm': 3.3544938564300537, 'learning_rate': 0.00018270676691729324, 'epoch': 0.12}


 12%|█▏        | 34/276 [00:50<05:18,  1.32s/it]

{'loss': 0.8247, 'grad_norm': 4.921255588531494, 'learning_rate': 0.0001819548872180451, 'epoch': 0.12}


 13%|█▎        | 35/276 [00:51<04:54,  1.22s/it]

{'loss': 0.6196, 'grad_norm': 3.311687469482422, 'learning_rate': 0.000181203007518797, 'epoch': 0.13}


 13%|█▎        | 36/276 [00:52<04:41,  1.17s/it]

{'loss': 0.7314, 'grad_norm': 2.4906251430511475, 'learning_rate': 0.00018045112781954887, 'epoch': 0.13}


 13%|█▎        | 37/276 [00:53<04:28,  1.12s/it]

{'loss': 0.6152, 'grad_norm': 2.0345935821533203, 'learning_rate': 0.00017969924812030077, 'epoch': 0.13}


 14%|█▍        | 38/276 [00:54<04:16,  1.08s/it]

{'loss': 0.4407, 'grad_norm': 1.597603678703308, 'learning_rate': 0.00017894736842105264, 'epoch': 0.14}


 14%|█▍        | 39/276 [00:55<04:05,  1.04s/it]

{'loss': 0.8037, 'grad_norm': 4.05405330657959, 'learning_rate': 0.0001781954887218045, 'epoch': 0.14}


 14%|█▍        | 40/276 [00:56<04:06,  1.04s/it]

{'loss': 0.7764, 'grad_norm': 3.4125139713287354, 'learning_rate': 0.0001774436090225564, 'epoch': 0.14}


 15%|█▍        | 41/276 [00:57<03:57,  1.01s/it]

{'loss': 0.6143, 'grad_norm': 3.5511066913604736, 'learning_rate': 0.0001766917293233083, 'epoch': 0.15}


 15%|█▌        | 42/276 [00:58<03:54,  1.00s/it]

{'loss': 0.6748, 'grad_norm': 2.148223876953125, 'learning_rate': 0.00017593984962406016, 'epoch': 0.15}


 16%|█▌        | 43/276 [00:59<03:49,  1.01it/s]

{'loss': 0.4673, 'grad_norm': 2.034865379333496, 'learning_rate': 0.00017518796992481203, 'epoch': 0.16}


 16%|█▌        | 44/276 [01:00<03:49,  1.01it/s]

{'loss': 0.6016, 'grad_norm': 2.496068239212036, 'learning_rate': 0.0001744360902255639, 'epoch': 0.16}


 16%|█▋        | 45/276 [01:01<03:58,  1.03s/it]

{'loss': 0.4551, 'grad_norm': 1.6024969816207886, 'learning_rate': 0.0001736842105263158, 'epoch': 0.16}


 17%|█▋        | 46/276 [01:02<03:56,  1.03s/it]

{'loss': 0.7964, 'grad_norm': 2.4085323810577393, 'learning_rate': 0.0001729323308270677, 'epoch': 0.17}


 17%|█▋        | 47/276 [01:04<04:28,  1.17s/it]

{'loss': 0.4712, 'grad_norm': 1.8362751007080078, 'learning_rate': 0.00017218045112781956, 'epoch': 0.17}


 17%|█▋        | 48/276 [01:07<07:15,  1.91s/it]

{'loss': 0.5601, 'grad_norm': 1.912221908569336, 'learning_rate': 0.00017142857142857143, 'epoch': 0.17}


 18%|█▊        | 49/276 [01:10<08:27,  2.23s/it]

{'loss': 0.4333, 'grad_norm': 1.5183923244476318, 'learning_rate': 0.00017067669172932332, 'epoch': 0.18}


 18%|█▊        | 50/276 [01:13<08:28,  2.25s/it]

{'loss': 0.6445, 'grad_norm': 1.8218071460723877, 'learning_rate': 0.0001699248120300752, 'epoch': 0.18}


 18%|█▊        | 51/276 [01:15<09:00,  2.40s/it]

{'loss': 1.7822, 'grad_norm': 6.8697509765625, 'learning_rate': 0.00016917293233082708, 'epoch': 0.18}


 19%|█▉        | 52/276 [01:18<08:48,  2.36s/it]

{'loss': 1.4014, 'grad_norm': 7.13614559173584, 'learning_rate': 0.00016842105263157895, 'epoch': 0.19}


 19%|█▉        | 53/276 [01:21<09:52,  2.66s/it]

{'loss': 1.1357, 'grad_norm': 3.38922381401062, 'learning_rate': 0.00016766917293233085, 'epoch': 0.19}


 20%|█▉        | 54/276 [01:25<11:31,  3.12s/it]

{'loss': 1.2373, 'grad_norm': 7.06620454788208, 'learning_rate': 0.00016691729323308271, 'epoch': 0.2}


 20%|█▉        | 55/276 [01:28<11:25,  3.10s/it]

{'loss': 1.1104, 'grad_norm': 7.737163066864014, 'learning_rate': 0.00016616541353383458, 'epoch': 0.2}


 20%|██        | 56/276 [01:30<10:31,  2.87s/it]

{'loss': 0.8159, 'grad_norm': 4.985777854919434, 'learning_rate': 0.00016541353383458648, 'epoch': 0.2}



 20%|██        | 56/276 [02:08<10:31,  2.87s/it] 

{'eval_loss': 0.7902008891105652, 'eval_runtime': 37.2745, 'eval_samples_per_second': 3.702, 'eval_steps_per_second': 3.702, 'epoch': 0.2}


 21%|██        | 57/276 [02:09<49:47, 13.64s/it]

{'loss': 0.668, 'grad_norm': 1.8529971837997437, 'learning_rate': 0.00016466165413533837, 'epoch': 0.21}


 21%|██        | 58/276 [02:10<35:45,  9.84s/it]

{'loss': 0.8179, 'grad_norm': 2.2227694988250732, 'learning_rate': 0.00016390977443609024, 'epoch': 0.21}


 21%|██▏       | 59/276 [02:11<26:03,  7.20s/it]

{'loss': 0.6816, 'grad_norm': 1.8843238353729248, 'learning_rate': 0.0001631578947368421, 'epoch': 0.21}


 22%|██▏       | 60/276 [02:12<19:13,  5.34s/it]

{'loss': 0.8438, 'grad_norm': 2.6790804862976074, 'learning_rate': 0.00016240601503759398, 'epoch': 0.22}


 22%|██▏       | 61/276 [02:13<14:26,  4.03s/it]

{'loss': 0.428, 'grad_norm': 1.8881653547286987, 'learning_rate': 0.00016165413533834587, 'epoch': 0.22}


 22%|██▏       | 62/276 [02:14<11:09,  3.13s/it]

{'loss': 0.7769, 'grad_norm': 2.6173083782196045, 'learning_rate': 0.00016090225563909777, 'epoch': 0.22}


 23%|██▎       | 63/276 [02:15<08:47,  2.48s/it]

{'loss': 1.3691, 'grad_norm': 5.553521156311035, 'learning_rate': 0.00016015037593984963, 'epoch': 0.23}


 23%|██▎       | 64/276 [02:16<07:07,  2.02s/it]

{'loss': 1.123, 'grad_norm': 6.541497230529785, 'learning_rate': 0.0001593984962406015, 'epoch': 0.23}


 24%|██▎       | 65/276 [02:17<06:00,  1.71s/it]

{'loss': 0.9961, 'grad_norm': 4.409235000610352, 'learning_rate': 0.0001586466165413534, 'epoch': 0.24}


 24%|██▍       | 66/276 [02:18<05:26,  1.55s/it]

{'loss': 0.9917, 'grad_norm': 3.3555874824523926, 'learning_rate': 0.00015789473684210527, 'epoch': 0.24}


 24%|██▍       | 67/276 [02:19<04:48,  1.38s/it]

{'loss': 0.6289, 'grad_norm': 2.6924326419830322, 'learning_rate': 0.00015714285714285716, 'epoch': 0.24}


 25%|██▍       | 68/276 [02:20<04:18,  1.24s/it]

{'loss': 0.7642, 'grad_norm': 2.225796699523926, 'learning_rate': 0.00015639097744360903, 'epoch': 0.25}


 25%|██▌       | 69/276 [02:21<03:58,  1.15s/it]

{'loss': 0.8145, 'grad_norm': 2.0125582218170166, 'learning_rate': 0.00015563909774436092, 'epoch': 0.25}


 25%|██▌       | 70/276 [02:22<03:48,  1.11s/it]

{'loss': 0.71, 'grad_norm': 2.2030112743377686, 'learning_rate': 0.0001548872180451128, 'epoch': 0.25}


 26%|██▌       | 71/276 [02:24<04:11,  1.22s/it]

{'loss': 0.6016, 'grad_norm': 2.2031803131103516, 'learning_rate': 0.00015413533834586466, 'epoch': 0.26}


 26%|██▌       | 72/276 [02:25<04:43,  1.39s/it]

{'loss': 0.4829, 'grad_norm': 2.152784585952759, 'learning_rate': 0.00015338345864661653, 'epoch': 0.26}


 26%|██▋       | 73/276 [02:27<04:33,  1.35s/it]

{'loss': 0.5649, 'grad_norm': 1.2033601999282837, 'learning_rate': 0.00015263157894736845, 'epoch': 0.26}


 27%|██▋       | 74/276 [02:28<04:09,  1.24s/it]

{'loss': 0.6284, 'grad_norm': 1.8394187688827515, 'learning_rate': 0.00015187969924812032, 'epoch': 0.27}


 27%|██▋       | 75/276 [02:29<03:52,  1.16s/it]

{'loss': 0.6792, 'grad_norm': 2.3845109939575195, 'learning_rate': 0.00015112781954887218, 'epoch': 0.27}


 28%|██▊       | 76/276 [02:29<03:30,  1.05s/it]

{'loss': 0.5308, 'grad_norm': 1.4314661026000977, 'learning_rate': 0.00015037593984962405, 'epoch': 0.28}


 28%|██▊       | 77/276 [02:30<03:17,  1.01it/s]

{'loss': 0.5117, 'grad_norm': 1.5793501138687134, 'learning_rate': 0.00014962406015037595, 'epoch': 0.28}


 28%|██▊       | 78/276 [02:31<03:11,  1.04it/s]

{'loss': 0.5742, 'grad_norm': 1.6310536861419678, 'learning_rate': 0.00014887218045112784, 'epoch': 0.28}


 29%|██▊       | 79/276 [02:32<03:02,  1.08it/s]

{'loss': 0.6221, 'grad_norm': 1.635337471961975, 'learning_rate': 0.0001481203007518797, 'epoch': 0.29}


 29%|██▉       | 80/276 [02:33<02:53,  1.13it/s]

{'loss': 0.624, 'grad_norm': 1.3487088680267334, 'learning_rate': 0.00014736842105263158, 'epoch': 0.29}


 29%|██▉       | 81/276 [02:34<02:59,  1.09it/s]

{'loss': 0.7231, 'grad_norm': 1.6932178735733032, 'learning_rate': 0.00014661654135338347, 'epoch': 0.29}


 30%|██▉       | 82/276 [02:35<03:11,  1.01it/s]

{'loss': 0.4539, 'grad_norm': 1.2928904294967651, 'learning_rate': 0.00014586466165413534, 'epoch': 0.3}


 30%|███       | 83/276 [02:36<03:13,  1.00s/it]

{'loss': 0.4744, 'grad_norm': 1.2033175230026245, 'learning_rate': 0.0001451127819548872, 'epoch': 0.3}


 30%|███       | 84/276 [02:37<03:16,  1.03s/it]

{'loss': 0.6162, 'grad_norm': 2.0493946075439453, 'learning_rate': 0.0001443609022556391, 'epoch': 0.3}


 31%|███       | 85/276 [02:38<03:12,  1.01s/it]

{'loss': 0.6865, 'grad_norm': 1.976599931716919, 'learning_rate': 0.000143609022556391, 'epoch': 0.31}


 31%|███       | 86/276 [02:39<03:28,  1.10s/it]

{'loss': 0.3672, 'grad_norm': 1.174584984779358, 'learning_rate': 0.00014285714285714287, 'epoch': 0.31}


 32%|███▏      | 87/276 [02:40<03:27,  1.10s/it]

{'loss': 0.4805, 'grad_norm': 1.1954286098480225, 'learning_rate': 0.00014210526315789474, 'epoch': 0.32}


 32%|███▏      | 88/276 [02:41<03:22,  1.08s/it]

{'loss': 0.6548, 'grad_norm': 1.3249880075454712, 'learning_rate': 0.0001413533834586466, 'epoch': 0.32}


 32%|███▏      | 89/276 [02:42<03:17,  1.06s/it]

{'loss': 0.5371, 'grad_norm': 1.3420549631118774, 'learning_rate': 0.0001406015037593985, 'epoch': 0.32}


 33%|███▎      | 90/276 [02:43<03:07,  1.01s/it]

{'loss': 0.3843, 'grad_norm': 1.3617827892303467, 'learning_rate': 0.0001398496240601504, 'epoch': 0.33}


 33%|███▎      | 91/276 [02:44<02:57,  1.04it/s]

{'loss': 0.5088, 'grad_norm': 1.54584538936615, 'learning_rate': 0.00013909774436090226, 'epoch': 0.33}


 33%|███▎      | 92/276 [02:45<02:50,  1.08it/s]

{'loss': 0.4485, 'grad_norm': 1.2385166883468628, 'learning_rate': 0.00013834586466165413, 'epoch': 0.33}


 34%|███▎      | 93/276 [02:46<02:49,  1.08it/s]

{'loss': 0.5415, 'grad_norm': 1.674709439277649, 'learning_rate': 0.00013759398496240602, 'epoch': 0.34}


 34%|███▍      | 94/276 [02:48<03:21,  1.11s/it]

{'loss': 0.364, 'grad_norm': 0.7663993239402771, 'learning_rate': 0.0001368421052631579, 'epoch': 0.34}


 34%|███▍      | 95/276 [02:50<04:09,  1.38s/it]

{'loss': 0.4209, 'grad_norm': 1.3148678541183472, 'learning_rate': 0.0001360902255639098, 'epoch': 0.34}


 35%|███▍      | 96/276 [02:52<05:02,  1.68s/it]

{'loss': 0.4944, 'grad_norm': 1.118783950805664, 'learning_rate': 0.00013533834586466166, 'epoch': 0.35}


 35%|███▌      | 97/276 [02:53<04:41,  1.57s/it]

{'loss': 0.4285, 'grad_norm': 1.3597559928894043, 'learning_rate': 0.00013458646616541355, 'epoch': 0.35}


 36%|███▌      | 98/276 [02:55<04:37,  1.56s/it]

{'loss': 0.626, 'grad_norm': 1.5662529468536377, 'learning_rate': 0.00013383458646616542, 'epoch': 0.36}


 36%|███▌      | 99/276 [02:57<04:51,  1.65s/it]

{'loss': 0.4858, 'grad_norm': 1.2841732501983643, 'learning_rate': 0.0001330827067669173, 'epoch': 0.36}


 36%|███▌      | 100/276 [02:59<05:40,  1.94s/it]

{'loss': 0.3635, 'grad_norm': 1.0784813165664673, 'learning_rate': 0.00013233082706766918, 'epoch': 0.36}


 37%|███▋      | 101/276 [03:02<06:02,  2.07s/it]

{'loss': 1.0342, 'grad_norm': 5.8145365715026855, 'learning_rate': 0.00013157894736842108, 'epoch': 0.37}


 37%|███▋      | 102/276 [03:04<06:15,  2.16s/it]

{'loss': 0.814, 'grad_norm': 3.0278897285461426, 'learning_rate': 0.00013082706766917294, 'epoch': 0.37}


 37%|███▋      | 103/276 [03:06<06:11,  2.15s/it]

{'loss': 0.6074, 'grad_norm': 1.795637607574463, 'learning_rate': 0.0001300751879699248, 'epoch': 0.37}


 38%|███▊      | 104/276 [03:08<05:39,  1.98s/it]

{'loss': 0.5898, 'grad_norm': 1.5427641868591309, 'learning_rate': 0.00012932330827067668, 'epoch': 0.38}


 38%|███▊      | 105/276 [03:09<05:22,  1.88s/it]

{'loss': 0.3862, 'grad_norm': 1.3580750226974487, 'learning_rate': 0.00012857142857142858, 'epoch': 0.38}


 38%|███▊      | 106/276 [03:11<04:59,  1.76s/it]

{'loss': 0.5576, 'grad_norm': 1.4930124282836914, 'learning_rate': 0.00012781954887218047, 'epoch': 0.38}


 39%|███▉      | 107/276 [03:13<04:59,  1.77s/it]

{'loss': 0.5352, 'grad_norm': 1.7733781337738037, 'learning_rate': 0.00012706766917293234, 'epoch': 0.39}


 39%|███▉      | 108/276 [03:14<04:59,  1.78s/it]

{'loss': 0.3972, 'grad_norm': 2.455991506576538, 'learning_rate': 0.0001263157894736842, 'epoch': 0.39}


 39%|███▉      | 109/276 [03:16<04:43,  1.70s/it]

{'loss': 0.4309, 'grad_norm': 1.3094679117202759, 'learning_rate': 0.0001255639097744361, 'epoch': 0.39}


 40%|███▉      | 110/276 [03:17<04:18,  1.55s/it]

{'loss': 0.5933, 'grad_norm': 1.821753740310669, 'learning_rate': 0.00012481203007518797, 'epoch': 0.4}


 40%|████      | 111/276 [03:19<04:16,  1.55s/it]

{'loss': 0.7905, 'grad_norm': 3.589935541152954, 'learning_rate': 0.00012406015037593986, 'epoch': 0.4}


 41%|████      | 112/276 [03:20<04:04,  1.49s/it]

{'loss': 0.8652, 'grad_norm': 2.4594533443450928, 'learning_rate': 0.00012330827067669173, 'epoch': 0.41}



 41%|████      | 112/276 [03:48<04:04,  1.49s/it]

{'eval_loss': 0.629747211933136, 'eval_runtime': 28.1555, 'eval_samples_per_second': 4.901, 'eval_steps_per_second': 4.901, 'epoch': 0.41}


 41%|████      | 113/276 [03:49<26:41,  9.82s/it]

{'loss': 0.7285, 'grad_norm': 1.7024028301239014, 'learning_rate': 0.00012255639097744363, 'epoch': 0.41}


 41%|████▏     | 114/276 [03:50<19:24,  7.19s/it]

{'loss': 0.3652, 'grad_norm': 1.4405171871185303, 'learning_rate': 0.0001218045112781955, 'epoch': 0.41}


 42%|████▏     | 115/276 [03:52<14:25,  5.38s/it]

{'loss': 0.3723, 'grad_norm': 1.1425399780273438, 'learning_rate': 0.00012105263157894738, 'epoch': 0.42}


 42%|████▏     | 116/276 [03:53<11:03,  4.14s/it]

{'loss': 0.5488, 'grad_norm': 1.535598874092102, 'learning_rate': 0.00012030075187969925, 'epoch': 0.42}


 42%|████▏     | 117/276 [03:54<08:35,  3.24s/it]

{'loss': 0.4717, 'grad_norm': 1.7443166971206665, 'learning_rate': 0.00011954887218045114, 'epoch': 0.42}


 43%|████▎     | 118/276 [03:55<06:50,  2.60s/it]

{'loss': 0.4468, 'grad_norm': 1.4785562753677368, 'learning_rate': 0.00011879699248120302, 'epoch': 0.43}


 43%|████▎     | 119/276 [03:56<05:34,  2.13s/it]

{'loss': 0.5791, 'grad_norm': 1.4641586542129517, 'learning_rate': 0.00011804511278195489, 'epoch': 0.43}


 43%|████▎     | 120/276 [03:57<04:43,  1.82s/it]

{'loss': 0.6055, 'grad_norm': 1.5679218769073486, 'learning_rate': 0.00011729323308270677, 'epoch': 0.43}


 44%|████▍     | 121/276 [03:58<04:13,  1.64s/it]

{'loss': 0.4956, 'grad_norm': 1.7748253345489502, 'learning_rate': 0.00011654135338345867, 'epoch': 0.44}


 44%|████▍     | 122/276 [04:00<03:55,  1.53s/it]

{'loss': 0.5112, 'grad_norm': 1.6214032173156738, 'learning_rate': 0.00011578947368421053, 'epoch': 0.44}


 45%|████▍     | 123/276 [04:01<03:51,  1.51s/it]

{'loss': 0.4688, 'grad_norm': 1.4648406505584717, 'learning_rate': 0.00011503759398496242, 'epoch': 0.45}


 45%|████▍     | 124/276 [04:02<03:35,  1.42s/it]

{'loss': 0.3989, 'grad_norm': 1.1768051385879517, 'learning_rate': 0.00011428571428571428, 'epoch': 0.45}


 45%|████▌     | 125/276 [04:03<03:16,  1.30s/it]

{'loss': 0.4666, 'grad_norm': 1.3791494369506836, 'learning_rate': 0.00011353383458646618, 'epoch': 0.45}


 46%|████▌     | 126/276 [04:04<03:04,  1.23s/it]

{'loss': 0.4392, 'grad_norm': 1.8941147327423096, 'learning_rate': 0.00011278195488721806, 'epoch': 0.46}


 46%|████▌     | 127/276 [04:05<02:56,  1.19s/it]

{'loss': 0.6641, 'grad_norm': 1.8704129457473755, 'learning_rate': 0.00011203007518796993, 'epoch': 0.46}


 46%|████▋     | 128/276 [04:06<02:46,  1.12s/it]

{'loss': 0.3989, 'grad_norm': 1.6363885402679443, 'learning_rate': 0.00011127819548872181, 'epoch': 0.46}


 47%|████▋     | 129/276 [04:07<02:40,  1.09s/it]

{'loss': 0.4827, 'grad_norm': 1.2373268604278564, 'learning_rate': 0.0001105263157894737, 'epoch': 0.47}


 47%|████▋     | 130/276 [04:08<02:31,  1.04s/it]

{'loss': 0.5884, 'grad_norm': 1.7582745552062988, 'learning_rate': 0.00010977443609022557, 'epoch': 0.47}


 47%|████▋     | 131/276 [04:09<02:21,  1.02it/s]

{'loss': 0.4556, 'grad_norm': 1.338306188583374, 'learning_rate': 0.00010902255639097745, 'epoch': 0.47}


 48%|████▊     | 132/276 [04:10<02:13,  1.08it/s]

{'loss': 0.3083, 'grad_norm': 1.0309653282165527, 'learning_rate': 0.00010827067669172932, 'epoch': 0.48}


 48%|████▊     | 133/276 [04:11<02:08,  1.11it/s]

{'loss': 0.5664, 'grad_norm': 1.3391379117965698, 'learning_rate': 0.00010751879699248122, 'epoch': 0.48}


 49%|████▊     | 134/276 [04:12<02:07,  1.11it/s]

{'loss': 0.3953, 'grad_norm': 1.2880429029464722, 'learning_rate': 0.0001067669172932331, 'epoch': 0.49}


 49%|████▉     | 135/276 [04:13<02:03,  1.14it/s]

{'loss': 0.4917, 'grad_norm': 1.1637654304504395, 'learning_rate': 0.00010601503759398497, 'epoch': 0.49}


 49%|████▉     | 136/276 [04:13<02:00,  1.16it/s]

{'loss': 0.272, 'grad_norm': 0.9838435053825378, 'learning_rate': 0.00010526315789473685, 'epoch': 0.49}


 50%|████▉     | 137/276 [04:14<01:57,  1.18it/s]

{'loss': 0.5264, 'grad_norm': 1.0869580507278442, 'learning_rate': 0.00010451127819548874, 'epoch': 0.5}


 50%|█████     | 138/276 [04:15<02:00,  1.14it/s]

{'loss': 0.4861, 'grad_norm': 1.2618166208267212, 'learning_rate': 0.00010375939849624061, 'epoch': 0.5}


 50%|█████     | 139/276 [04:16<02:05,  1.09it/s]

{'loss': 0.4797, 'grad_norm': 1.4908971786499023, 'learning_rate': 0.00010300751879699249, 'epoch': 0.5}


 51%|█████     | 140/276 [04:17<02:06,  1.07it/s]

{'loss': 0.584, 'grad_norm': 1.349509596824646, 'learning_rate': 0.00010225563909774436, 'epoch': 0.51}


 51%|█████     | 141/276 [04:18<02:04,  1.08it/s]

{'loss': 0.4028, 'grad_norm': 1.0188051462173462, 'learning_rate': 0.00010150375939849626, 'epoch': 0.51}


 51%|█████▏    | 142/276 [04:19<02:01,  1.11it/s]

{'loss': 0.3931, 'grad_norm': 1.2702834606170654, 'learning_rate': 0.00010075187969924814, 'epoch': 0.51}


 52%|█████▏    | 143/276 [04:20<02:05,  1.06it/s]

{'loss': 0.4011, 'grad_norm': 1.3328301906585693, 'learning_rate': 0.0001, 'epoch': 0.52}


 52%|█████▏    | 144/276 [04:21<02:11,  1.01it/s]

{'loss': 0.4153, 'grad_norm': 1.0483112335205078, 'learning_rate': 9.924812030075187e-05, 'epoch': 0.52}


 53%|█████▎    | 145/276 [04:22<02:13,  1.02s/it]

{'loss': 0.4495, 'grad_norm': 1.2352298498153687, 'learning_rate': 9.849624060150377e-05, 'epoch': 0.53}


 53%|█████▎    | 146/276 [04:23<02:09,  1.00it/s]

{'loss': 0.6089, 'grad_norm': 1.1940081119537354, 'learning_rate': 9.774436090225564e-05, 'epoch': 0.53}


 53%|█████▎    | 147/276 [04:24<02:06,  1.02it/s]

{'loss': 0.3765, 'grad_norm': 0.9796791672706604, 'learning_rate': 9.699248120300752e-05, 'epoch': 0.53}


 54%|█████▎    | 148/276 [04:25<02:09,  1.01s/it]

{'loss': 0.3367, 'grad_norm': 1.0151063203811646, 'learning_rate': 9.62406015037594e-05, 'epoch': 0.54}


 54%|█████▍    | 149/276 [04:26<02:08,  1.01s/it]

{'loss': 0.4426, 'grad_norm': 1.1767115592956543, 'learning_rate': 9.548872180451128e-05, 'epoch': 0.54}


 54%|█████▍    | 150/276 [04:27<02:08,  1.02s/it]

{'loss': 0.4136, 'grad_norm': 0.9587398171424866, 'learning_rate': 9.473684210526316e-05, 'epoch': 0.54}


 55%|█████▍    | 151/276 [04:28<02:05,  1.01s/it]

{'loss': 1.3086, 'grad_norm': 5.914578437805176, 'learning_rate': 9.398496240601504e-05, 'epoch': 0.55}


 55%|█████▌    | 152/276 [04:29<02:03,  1.00it/s]

{'loss': 0.7881, 'grad_norm': 4.638031005859375, 'learning_rate': 9.323308270676691e-05, 'epoch': 0.55}


 55%|█████▌    | 153/276 [04:30<02:02,  1.00it/s]

{'loss': 0.7656, 'grad_norm': 2.267179012298584, 'learning_rate': 9.24812030075188e-05, 'epoch': 0.55}


 56%|█████▌    | 154/276 [04:31<02:06,  1.04s/it]

{'loss': 0.666, 'grad_norm': 4.3957743644714355, 'learning_rate': 9.172932330827067e-05, 'epoch': 0.56}


 56%|█████▌    | 155/276 [04:32<02:09,  1.07s/it]

{'loss': 0.6489, 'grad_norm': 1.7613804340362549, 'learning_rate': 9.097744360902256e-05, 'epoch': 0.56}


 57%|█████▋    | 156/276 [04:34<02:18,  1.15s/it]

{'loss': 0.5444, 'grad_norm': 1.6505614519119263, 'learning_rate': 9.022556390977444e-05, 'epoch': 0.57}


 57%|█████▋    | 157/276 [04:35<02:23,  1.20s/it]

{'loss': 0.5205, 'grad_norm': 0.9608510732650757, 'learning_rate': 8.947368421052632e-05, 'epoch': 0.57}


 57%|█████▋    | 158/276 [04:36<02:16,  1.16s/it]

{'loss': 0.5654, 'grad_norm': 2.409430503845215, 'learning_rate': 8.87218045112782e-05, 'epoch': 0.57}


 58%|█████▊    | 159/276 [04:37<02:08,  1.10s/it]

{'loss': 0.5918, 'grad_norm': 1.905513882637024, 'learning_rate': 8.796992481203008e-05, 'epoch': 0.58}


 58%|█████▊    | 160/276 [04:38<02:10,  1.13s/it]

{'loss': 0.728, 'grad_norm': 2.3244152069091797, 'learning_rate': 8.721804511278195e-05, 'epoch': 0.58}


 58%|█████▊    | 161/276 [04:39<02:09,  1.13s/it]

{'loss': 0.5376, 'grad_norm': 1.2378532886505127, 'learning_rate': 8.646616541353384e-05, 'epoch': 0.58}


 59%|█████▊    | 162/276 [04:40<02:04,  1.10s/it]

{'loss': 0.5693, 'grad_norm': 1.333698034286499, 'learning_rate': 8.571428571428571e-05, 'epoch': 0.59}


 59%|█████▉    | 163/276 [04:42<02:06,  1.12s/it]

{'loss': 0.6313, 'grad_norm': 1.3205636739730835, 'learning_rate': 8.49624060150376e-05, 'epoch': 0.59}


 59%|█████▉    | 164/276 [04:43<02:07,  1.13s/it]

{'loss': 0.5493, 'grad_norm': 1.2547458410263062, 'learning_rate': 8.421052631578948e-05, 'epoch': 0.59}


 60%|█████▉    | 165/276 [04:44<02:01,  1.09s/it]

{'loss': 0.5254, 'grad_norm': 1.3122588396072388, 'learning_rate': 8.345864661654136e-05, 'epoch': 0.6}


 60%|██████    | 166/276 [04:45<01:59,  1.08s/it]

{'loss': 0.5615, 'grad_norm': 1.112001895904541, 'learning_rate': 8.270676691729324e-05, 'epoch': 0.6}


 61%|██████    | 167/276 [04:46<01:59,  1.09s/it]

{'loss': 0.4817, 'grad_norm': 1.2292543649673462, 'learning_rate': 8.195488721804512e-05, 'epoch': 0.61}


 61%|██████    | 168/276 [04:47<01:57,  1.08s/it]

{'loss': 0.5996, 'grad_norm': 1.0559923648834229, 'learning_rate': 8.120300751879699e-05, 'epoch': 0.61}



 61%|██████    | 168/276 [05:46<01:57,  1.08s/it]

{'eval_loss': 0.49787360429763794, 'eval_runtime': 58.541, 'eval_samples_per_second': 2.357, 'eval_steps_per_second': 2.357, 'epoch': 0.61}


 61%|██████    | 169/276 [05:47<33:33, 18.82s/it]

{'loss': 0.541, 'grad_norm': 1.0026582479476929, 'learning_rate': 8.045112781954888e-05, 'epoch': 0.61}


 62%|██████▏   | 170/276 [05:49<24:19, 13.77s/it]

{'loss': 0.5713, 'grad_norm': 1.3466386795043945, 'learning_rate': 7.969924812030075e-05, 'epoch': 0.62}


 62%|██████▏   | 171/276 [05:51<17:39, 10.09s/it]

{'loss': 0.5088, 'grad_norm': 1.1714128255844116, 'learning_rate': 7.894736842105263e-05, 'epoch': 0.62}


 62%|██████▏   | 172/276 [05:53<13:30,  7.80s/it]

{'loss': 0.436, 'grad_norm': 1.1897695064544678, 'learning_rate': 7.819548872180451e-05, 'epoch': 0.62}


 63%|██████▎   | 173/276 [05:56<10:43,  6.24s/it]

{'loss': 0.4009, 'grad_norm': 1.4270508289337158, 'learning_rate': 7.74436090225564e-05, 'epoch': 0.63}


 63%|██████▎   | 174/276 [05:58<08:38,  5.08s/it]

{'loss': 0.5122, 'grad_norm': 1.131439208984375, 'learning_rate': 7.669172932330826e-05, 'epoch': 0.63}


 63%|██████▎   | 175/276 [06:00<06:53,  4.09s/it]

{'loss': 0.4526, 'grad_norm': 1.0898324251174927, 'learning_rate': 7.593984962406016e-05, 'epoch': 0.63}


 64%|██████▍   | 176/276 [06:01<05:33,  3.33s/it]

{'loss': 0.4431, 'grad_norm': 1.0577200651168823, 'learning_rate': 7.518796992481203e-05, 'epoch': 0.64}


 64%|██████▍   | 177/276 [06:03<04:49,  2.92s/it]

{'loss': 0.3564, 'grad_norm': 0.9710974097251892, 'learning_rate': 7.443609022556392e-05, 'epoch': 0.64}


 64%|██████▍   | 178/276 [06:05<03:52,  2.38s/it]

{'loss': 0.436, 'grad_norm': 1.6010806560516357, 'learning_rate': 7.368421052631579e-05, 'epoch': 0.64}


 65%|██████▍   | 179/276 [06:06<03:10,  1.97s/it]

{'loss': 0.5269, 'grad_norm': 1.532763957977295, 'learning_rate': 7.293233082706767e-05, 'epoch': 0.65}


 65%|██████▌   | 180/276 [06:07<02:50,  1.78s/it]

{'loss': 0.4192, 'grad_norm': 0.9425287246704102, 'learning_rate': 7.218045112781955e-05, 'epoch': 0.65}


 66%|██████▌   | 181/276 [06:09<03:11,  2.01s/it]

{'loss': 0.5415, 'grad_norm': 1.2171251773834229, 'learning_rate': 7.142857142857143e-05, 'epoch': 0.66}


 66%|██████▌   | 182/276 [06:12<03:29,  2.23s/it]

{'loss': 0.6318, 'grad_norm': 1.4027799367904663, 'learning_rate': 7.06766917293233e-05, 'epoch': 0.66}


 66%|██████▋   | 183/276 [06:14<03:23,  2.19s/it]

{'loss': 0.4082, 'grad_norm': 0.980003833770752, 'learning_rate': 6.99248120300752e-05, 'epoch': 0.66}


 67%|██████▋   | 184/276 [06:17<03:40,  2.40s/it]

{'loss': 0.4058, 'grad_norm': 1.0560064315795898, 'learning_rate': 6.917293233082706e-05, 'epoch': 0.67}


 67%|██████▋   | 185/276 [06:19<03:36,  2.38s/it]

{'loss': 0.5151, 'grad_norm': 1.5396226644515991, 'learning_rate': 6.842105263157895e-05, 'epoch': 0.67}


 67%|██████▋   | 186/276 [06:21<03:04,  2.05s/it]

{'loss': 0.353, 'grad_norm': 1.305216670036316, 'learning_rate': 6.766917293233083e-05, 'epoch': 0.67}


 68%|██████▊   | 187/276 [06:22<02:43,  1.84s/it]

{'loss': 0.6133, 'grad_norm': 2.442368268966675, 'learning_rate': 6.691729323308271e-05, 'epoch': 0.68}


 68%|██████▊   | 188/276 [06:24<02:34,  1.75s/it]

{'loss': 0.4375, 'grad_norm': 1.1237207651138306, 'learning_rate': 6.616541353383459e-05, 'epoch': 0.68}


 68%|██████▊   | 189/276 [06:25<02:32,  1.75s/it]

{'loss': 0.3777, 'grad_norm': 1.2549282312393188, 'learning_rate': 6.541353383458647e-05, 'epoch': 0.68}


 69%|██████▉   | 190/276 [06:28<02:49,  1.97s/it]

{'loss': 0.3877, 'grad_norm': 1.084622859954834, 'learning_rate': 6.466165413533834e-05, 'epoch': 0.69}


 69%|██████▉   | 191/276 [06:30<03:01,  2.13s/it]

{'loss': 0.4739, 'grad_norm': 1.2032749652862549, 'learning_rate': 6.390977443609024e-05, 'epoch': 0.69}


 70%|██████▉   | 192/276 [06:33<03:01,  2.16s/it]

{'loss': 0.29, 'grad_norm': 0.9543437957763672, 'learning_rate': 6.31578947368421e-05, 'epoch': 0.7}


 70%|██████▉   | 193/276 [06:34<02:51,  2.06s/it]

{'loss': 0.4197, 'grad_norm': 1.2533046007156372, 'learning_rate': 6.240601503759398e-05, 'epoch': 0.7}


 70%|███████   | 194/276 [06:36<02:39,  1.94s/it]

{'loss': 0.5083, 'grad_norm': 1.1142679452896118, 'learning_rate': 6.165413533834587e-05, 'epoch': 0.7}


 71%|███████   | 195/276 [06:38<02:36,  1.93s/it]

In [None]:
wandb.finish()

In [None]:
# Instruction tailored to credit card approval context
instruction = """You are a highly knowledgeable financial advisor specializing in credit card approvals. 
    Be informative, polite, and provide clear responses to any queries regarding credit approval decisions.
    """

# Example message (user asking about credit card approval)
messages = [
    {"role": "system", "content": instruction},
    {"role": "user", "content": "Can I know why my credit card application was rejected? My age is 30, income is $40,000, and credit score is 580."}
]

# Generate the prompt using the chat template (assuming tokenizer.apply_chat_template is a custom method for your setup)
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# Tokenize the prompt
inputs = tokenizer(prompt, return_tensors='pt', padding=True, truncation=True).to("cuda")

# Generate model outputs (adjusting parameters if necessary)
outputs = model.generate(**inputs, max_new_tokens=150, num_return_sequences=1)

# Decode the model's response
text = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the assistant's response (assuming the response begins after the 'assistant' token)
print(text.split("assistant")[1])

# Save the fine-tuned model and tokenizer for future use
model.save_pretrained(new_model)
tokenizer.save_pretrained(new_model)
