https://blog.devgenius.io/sculpting-language-gpt-2-fine-tuning-with-lora-1caf3bfbc3c6

In [1]:
import json
import torch
import numpy as np
import transformers
import pandas as pd
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments

  from .autonotebook import tqdm as notebook_tqdm


# Prepare Dataset

In [2]:
def train_test_split(df: pd.DataFrame, test_size=0.3, random_state: int =None):
    if random_state:
        np.random.seed(random_state)
    
    # Shuffle indices
    shuffled_indices = np.random.permutation(len(df))
    
    # Split into train and test indices
    test_size = int(len(df) * test_size)
    test_indices = shuffled_indices[:test_size]
    train_indices = shuffled_indices[test_size:]
    
    # Return train and test splits
    return df.iloc[train_indices], df.iloc[test_indices]

In [3]:
# Load the dataset
dataset = load_dataset('json', data_files={'train': 'dataset/data_eng.json'})

# Convert the dataset to a pandas DataFrame
df = pd.DataFrame(dataset['train'])

# Display the number of rows before removing duplicates
print(f"Number of rows before deduplication: {df.shape[0]}")

# Remove duplicates
df = df.drop_duplicates()

# Display the number of rows after removing duplicates
print(f"Number of rows after deduplication: {df.shape[0]}")

# Remove \r characters from the text in new_df
df['text'] = "Extract the job title from the provided text\ntext: " + df["input"] + "\njob title: " + df["target"] + " <STOP>"

# Split the dataset into train, validation, and test sets (70% train, 15% validation, 15% test)
train_df, val_df = train_test_split(df[['text']], test_size=0.3, random_state=42)
# val_df, test_df = train_test_split(temp_df, test_size=0.3, random_state=42)


print(f"Train Dataframe {train_df.shape[0]}")
# print(f"Test Dataframe {test_df.shape[0]}")
print(f"Validation Dataframe {val_df.shape[0]}")

# Convert DataFrames to list of dictionaries for JSON format
train_data = train_df.to_dict(orient='records')
val_data = val_df.to_dict(orient='records')
# test_data = test_df.to_dict(orient='records')

# Save the datasets to JSON files
with open('data/train_data.json', 'w') as f:
    json.dump(train_data, f)

with open('data/val_data.json', 'w') as f:
    json.dump(val_data, f)

# with open('data/test_data.json', 'w') as f:
#     json.dump(test_data, f)

print("Train, validation, and test datasets created and saved successfully in JSON format!")

Number of rows before deduplication: 2600
Number of rows after deduplication: 1848
Train Dataframe 1294
Validation Dataframe 554
Train, validation, and test datasets created and saved successfully in JSON format!


# Training

## Configuration Before Training

In [4]:
cache_dir = "models"
modelID = "openai-community/gpt2"
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [6]:
tokenizer = AutoTokenizer.from_pretrained(modelID, cache_dir=cache_dir)

# Set padding token
tokenizer.padding_side = "right"                # Set padding side to left
tokenizer.pad_token = tokenizer.eos_token      # Using eos_token as pad_token

model = AutoModelForCausalLM.from_pretrained(modelID, device_map='auto', cache_dir=cache_dir)

model.config.pad_token_id = tokenizer.pad_token_id  # Set the pad_token_id in the model config



In [7]:
# FREEZE WEIGHTS
for param in model.parameters():
    param.requires_grad = False

In [8]:
lora_config = LoraConfig(
    r=16,                      # Rank
    lora_alpha=32,            # Alpha parameter for LoRA
    lora_dropout=0.05,         # Dropout for LoRA
    bias="none",              # Choose bias (none, all, or lora)
    task_type=TaskType.CAUSAL_LM,  # Set to Causal Language Modeling
)
model = get_peft_model(model, lora_config)



In [9]:
# Load train, validation, and test datasets
dataset = load_dataset('json', data_files={
    'train': 'data/train_data.json',
    'validation': 'data/val_data.json',
    # 'test': 'data/test_data.json'
})

Generating train split: 1293 examples [00:00, 92376.42 examples/s]
Generating validation split: 555 examples [00:00, 50359.96 examples/s]


In [10]:
import math

# Load your training data from train_data.json
with open('data/train_data.json', 'r') as file:
    train_data = json.load(file)


# Get lengths of tokenized texts
lengths = [len(tokenizer(obj['text'])['input_ids']) for obj in train_data]

def next_power_of_2(n):
    if n < 1:
        raise ValueError("Input must be a positive integer.")

    # Calculate the power of 2 using logarithm
    power = math.ceil(math.log2(n))  # Get the smallest integer >= log2(n)
    
    # Return 2 raised to the calculated power
    return 2 ** power

print("Maximum length:", max(lengths))
max_length = next_power_of_2(max(lengths))
print("max_length:",max_length)

Maximum length: 46
max_length: 64


In [11]:
# Tokenize the dataset and create labels
def tokenize_function(examples):
    tokenized = tokenizer(examples['text'], padding='max_length', truncation=True, max_length=max_length)
    # Create labels (shifted input for language modeling)
    tokenized['labels'] = tokenized['input_ids'].copy()
    return tokenized

In [12]:
# Tokenize the dataset
tokenized_datasets = dataset.map(tokenize_function, batched=True)

Map: 100%|██████████| 1293/1293 [00:00<00:00, 21551.05 examples/s]
Map: 100%|██████████| 555/555 [00:00<00:00, 17342.81 examples/s]


In [13]:
# Set format for PyTorch
tokenized_datasets.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [14]:
def print_trainable_parameters(model):
    """
    Prints the number of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    for _, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            trainable_params += param.numel()
    print(
        f"trainable params: {trainable_params} || all params: {all_param} || trainable%: {100 * trainable_params / all_param}"
    )

In [15]:
print_trainable_parameters(model)

trainable params: 589824 || all params: 125029632 || trainable%: 0.4717473694555863


In [16]:
batch = 4
training_args = TrainingArguments(
    per_device_train_batch_size=batch,
    gradient_accumulation_steps=batch,
    warmup_steps=10,
    max_steps=500,
    # num_train_epochs=3, 
    learning_rate=2e-4,
    logging_steps=batch*2,
    output_dir='outputs',
    auto_find_batch_size=True
)

In [17]:
# Initialize Trainer
trainer = Trainer(
    model=model,                        
    args=training_args,                 
    train_dataset=tokenized_datasets['train'],
    eval_dataset=tokenized_datasets['validation'],  # Use validation set here
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)

max_steps is given, it will override any value given in num_train_epochs


In [18]:
model.config.use_cache = False

In [19]:
trainer.train()

  attn_output = torch.nn.functional.scaled_dot_product_attention(
  0%|          | 1/500 [00:00<04:09,  2.00it/s]

{'loss': 5.7891, 'grad_norm': 0.99388587474823, 'learning_rate': 2e-05, 'epoch': 0.01}


  0%|          | 2/500 [00:00<02:52,  2.88it/s]

{'loss': 5.6777, 'grad_norm': 0.8763487339019775, 'learning_rate': 4e-05, 'epoch': 0.02}


  1%|          | 3/500 [00:00<02:27,  3.38it/s]

{'loss': 5.7591, 'grad_norm': 0.9175100326538086, 'learning_rate': 6e-05, 'epoch': 0.04}


  1%|          | 4/500 [00:01<02:15,  3.66it/s]

{'loss': 5.5955, 'grad_norm': 1.0750375986099243, 'learning_rate': 8e-05, 'epoch': 0.05}


  1%|          | 5/500 [00:01<02:09,  3.82it/s]

{'loss': 5.5484, 'grad_norm': 1.0198581218719482, 'learning_rate': 0.0001, 'epoch': 0.06}


  1%|          | 6/500 [00:01<02:05,  3.93it/s]

{'loss': 5.6455, 'grad_norm': 1.0925275087356567, 'learning_rate': 0.00012, 'epoch': 0.07}


  1%|▏         | 7/500 [00:01<02:05,  3.94it/s]

{'loss': 5.5884, 'grad_norm': 1.0281507968902588, 'learning_rate': 0.00014, 'epoch': 0.09}


  2%|▏         | 8/500 [00:02<02:02,  4.01it/s]

{'loss': 5.6945, 'grad_norm': 1.164261817932129, 'learning_rate': 0.00016, 'epoch': 0.1}


  2%|▏         | 9/500 [00:02<02:01,  4.06it/s]

{'loss': 5.5243, 'grad_norm': 1.0436813831329346, 'learning_rate': 0.00018, 'epoch': 0.11}


  2%|▏         | 10/500 [00:02<02:00,  4.06it/s]

{'loss': 5.4563, 'grad_norm': 1.079095721244812, 'learning_rate': 0.0002, 'epoch': 0.12}


  2%|▏         | 11/500 [00:02<02:00,  4.06it/s]

{'loss': 5.5794, 'grad_norm': 1.2673184871673584, 'learning_rate': 0.0001995918367346939, 'epoch': 0.14}


  2%|▏         | 12/500 [00:03<01:59,  4.08it/s]

{'loss': 5.5709, 'grad_norm': 1.3143059015274048, 'learning_rate': 0.00019918367346938775, 'epoch': 0.15}


  3%|▎         | 13/500 [00:03<01:59,  4.09it/s]

{'loss': 5.5407, 'grad_norm': 1.290617823600769, 'learning_rate': 0.00019877551020408164, 'epoch': 0.16}


  3%|▎         | 14/500 [00:03<02:00,  4.05it/s]

{'loss': 5.2846, 'grad_norm': 1.2151299715042114, 'learning_rate': 0.00019836734693877553, 'epoch': 0.17}


  3%|▎         | 15/500 [00:03<02:01,  4.01it/s]

{'loss': 5.0641, 'grad_norm': 1.398290991783142, 'learning_rate': 0.00019795918367346938, 'epoch': 0.19}


  3%|▎         | 16/500 [00:04<01:59,  4.04it/s]

{'loss': 5.2246, 'grad_norm': 1.4275861978530884, 'learning_rate': 0.00019755102040816327, 'epoch': 0.2}


  3%|▎         | 17/500 [00:04<01:58,  4.07it/s]

{'loss': 4.8867, 'grad_norm': 1.462467074394226, 'learning_rate': 0.00019714285714285716, 'epoch': 0.21}


  4%|▎         | 18/500 [00:04<01:57,  4.10it/s]

{'loss': 4.8187, 'grad_norm': 1.491856336593628, 'learning_rate': 0.00019673469387755104, 'epoch': 0.22}


  4%|▍         | 19/500 [00:04<01:56,  4.12it/s]

{'loss': 5.0686, 'grad_norm': 2.115849494934082, 'learning_rate': 0.0001963265306122449, 'epoch': 0.23}


  4%|▍         | 20/500 [00:05<01:55,  4.14it/s]

{'loss': 4.7923, 'grad_norm': 1.7474935054779053, 'learning_rate': 0.0001959183673469388, 'epoch': 0.25}


  4%|▍         | 21/500 [00:05<01:55,  4.15it/s]

{'loss': 4.6636, 'grad_norm': 1.5726242065429688, 'learning_rate': 0.00019551020408163265, 'epoch': 0.26}


  4%|▍         | 22/500 [00:05<01:54,  4.16it/s]

{'loss': 4.7149, 'grad_norm': 1.7293367385864258, 'learning_rate': 0.00019510204081632656, 'epoch': 0.27}


  5%|▍         | 23/500 [00:05<01:54,  4.17it/s]

{'loss': 4.4972, 'grad_norm': 1.603661060333252, 'learning_rate': 0.00019469387755102042, 'epoch': 0.28}


  5%|▍         | 24/500 [00:06<01:54,  4.17it/s]

{'loss': 4.4374, 'grad_norm': 1.7698326110839844, 'learning_rate': 0.0001942857142857143, 'epoch': 0.3}


  5%|▌         | 25/500 [00:06<01:53,  4.17it/s]

{'loss': 4.2689, 'grad_norm': 1.6559691429138184, 'learning_rate': 0.00019387755102040816, 'epoch': 0.31}


  5%|▌         | 26/500 [00:06<01:53,  4.18it/s]

{'loss': 4.0816, 'grad_norm': 1.6458497047424316, 'learning_rate': 0.00019346938775510205, 'epoch': 0.32}


  5%|▌         | 27/500 [00:06<01:52,  4.19it/s]

{'loss': 4.2498, 'grad_norm': 1.6196962594985962, 'learning_rate': 0.00019306122448979593, 'epoch': 0.33}


  6%|▌         | 28/500 [00:07<01:52,  4.18it/s]

{'loss': 3.9193, 'grad_norm': 1.7878202199935913, 'learning_rate': 0.0001926530612244898, 'epoch': 0.35}


  6%|▌         | 29/500 [00:07<01:54,  4.11it/s]

{'loss': 4.1253, 'grad_norm': 1.9710205793380737, 'learning_rate': 0.00019224489795918368, 'epoch': 0.36}


  6%|▌         | 30/500 [00:07<01:54,  4.10it/s]

{'loss': 4.0591, 'grad_norm': 1.6702569723129272, 'learning_rate': 0.00019183673469387756, 'epoch': 0.37}


  6%|▌         | 31/500 [00:07<01:57,  3.98it/s]

{'loss': 3.7585, 'grad_norm': 1.814359188079834, 'learning_rate': 0.00019142857142857145, 'epoch': 0.38}


  6%|▋         | 32/500 [00:08<01:57,  3.99it/s]

{'loss': 3.737, 'grad_norm': 1.5100528001785278, 'learning_rate': 0.0001910204081632653, 'epoch': 0.4}


  7%|▋         | 33/500 [00:08<01:56,  4.02it/s]

{'loss': 3.8887, 'grad_norm': 1.4700597524642944, 'learning_rate': 0.0001906122448979592, 'epoch': 0.41}


  7%|▋         | 34/500 [00:08<01:55,  4.03it/s]

{'loss': 3.3824, 'grad_norm': 1.7927922010421753, 'learning_rate': 0.00019020408163265305, 'epoch': 0.42}


  7%|▋         | 35/500 [00:08<01:55,  4.03it/s]

{'loss': 3.6659, 'grad_norm': 1.5350183248519897, 'learning_rate': 0.00018979591836734697, 'epoch': 0.43}


  7%|▋         | 36/500 [00:09<01:54,  4.05it/s]

{'loss': 3.4378, 'grad_norm': 1.6027231216430664, 'learning_rate': 0.00018938775510204083, 'epoch': 0.44}


  7%|▋         | 37/500 [00:09<01:54,  4.06it/s]

{'loss': 3.287, 'grad_norm': 1.655875563621521, 'learning_rate': 0.0001889795918367347, 'epoch': 0.46}


  8%|▊         | 38/500 [00:09<01:56,  3.95it/s]

{'loss': 3.1557, 'grad_norm': 1.676769733428955, 'learning_rate': 0.00018857142857142857, 'epoch': 0.47}


  8%|▊         | 39/500 [00:09<01:56,  3.96it/s]

{'loss': 3.3966, 'grad_norm': 1.4622732400894165, 'learning_rate': 0.00018816326530612246, 'epoch': 0.48}


  8%|▊         | 40/500 [00:10<01:54,  4.00it/s]

{'loss': 3.1301, 'grad_norm': 1.4458069801330566, 'learning_rate': 0.00018775510204081634, 'epoch': 0.49}


  8%|▊         | 41/500 [00:10<01:55,  3.96it/s]

{'loss': 2.9691, 'grad_norm': 1.621411919593811, 'learning_rate': 0.00018734693877551023, 'epoch': 0.51}


  8%|▊         | 42/500 [00:10<01:55,  3.96it/s]

{'loss': 3.1204, 'grad_norm': 1.4900789260864258, 'learning_rate': 0.0001869387755102041, 'epoch': 0.52}


  9%|▊         | 43/500 [00:10<01:54,  3.98it/s]

{'loss': 2.9102, 'grad_norm': 1.502042293548584, 'learning_rate': 0.00018653061224489797, 'epoch': 0.53}


  9%|▉         | 44/500 [00:11<01:56,  3.92it/s]

{'loss': 2.7484, 'grad_norm': 1.5985779762268066, 'learning_rate': 0.00018612244897959183, 'epoch': 0.54}


  9%|▉         | 45/500 [00:11<01:59,  3.79it/s]

{'loss': 2.4946, 'grad_norm': 1.695258378982544, 'learning_rate': 0.00018571428571428572, 'epoch': 0.56}


  9%|▉         | 46/500 [00:11<01:58,  3.83it/s]

{'loss': 2.5543, 'grad_norm': 1.5039219856262207, 'learning_rate': 0.0001853061224489796, 'epoch': 0.57}


  9%|▉         | 47/500 [00:11<01:57,  3.84it/s]

{'loss': 2.4753, 'grad_norm': 1.4918897151947021, 'learning_rate': 0.00018489795918367346, 'epoch': 0.58}


 10%|▉         | 48/500 [00:12<01:56,  3.89it/s]

{'loss': 2.119, 'grad_norm': 1.4835295677185059, 'learning_rate': 0.00018448979591836735, 'epoch': 0.59}


 10%|▉         | 49/500 [00:12<01:55,  3.92it/s]

{'loss': 2.1371, 'grad_norm': 1.420097827911377, 'learning_rate': 0.00018408163265306123, 'epoch': 0.6}


 10%|█         | 50/500 [00:12<01:54,  3.94it/s]

{'loss': 2.2208, 'grad_norm': 1.3981046676635742, 'learning_rate': 0.00018367346938775512, 'epoch': 0.62}


 10%|█         | 51/500 [00:12<01:53,  3.97it/s]

{'loss': 2.0243, 'grad_norm': 1.414880394935608, 'learning_rate': 0.00018326530612244898, 'epoch': 0.63}


 10%|█         | 52/500 [00:13<01:53,  3.96it/s]

{'loss': 2.2407, 'grad_norm': 1.5511564016342163, 'learning_rate': 0.00018285714285714286, 'epoch': 0.64}


 11%|█         | 53/500 [00:13<01:52,  3.96it/s]

{'loss': 2.0756, 'grad_norm': 1.4143495559692383, 'learning_rate': 0.00018244897959183672, 'epoch': 0.65}


 11%|█         | 54/500 [00:13<01:52,  3.98it/s]

{'loss': 2.0969, 'grad_norm': 1.6097537279129028, 'learning_rate': 0.00018204081632653064, 'epoch': 0.67}


 11%|█         | 55/500 [00:13<01:51,  3.98it/s]

{'loss': 1.9398, 'grad_norm': 1.3168439865112305, 'learning_rate': 0.0001816326530612245, 'epoch': 0.68}


 11%|█         | 56/500 [00:14<01:51,  3.97it/s]

{'loss': 1.8243, 'grad_norm': 1.351765513420105, 'learning_rate': 0.00018122448979591838, 'epoch': 0.69}


 11%|█▏        | 57/500 [00:14<01:51,  3.96it/s]

{'loss': 1.499, 'grad_norm': 1.4088847637176514, 'learning_rate': 0.00018081632653061224, 'epoch': 0.7}


 12%|█▏        | 58/500 [00:14<01:51,  3.98it/s]

{'loss': 1.7839, 'grad_norm': 1.409450650215149, 'learning_rate': 0.00018040816326530615, 'epoch': 0.72}


 12%|█▏        | 59/500 [00:14<01:50,  3.99it/s]

{'loss': 1.8124, 'grad_norm': 1.3987782001495361, 'learning_rate': 0.00018, 'epoch': 0.73}


 12%|█▏        | 60/500 [00:15<01:52,  3.91it/s]

{'loss': 1.9576, 'grad_norm': 1.286457896232605, 'learning_rate': 0.0001795918367346939, 'epoch': 0.74}


 12%|█▏        | 61/500 [00:15<01:51,  3.93it/s]

{'loss': 1.7724, 'grad_norm': 1.1543934345245361, 'learning_rate': 0.00017918367346938776, 'epoch': 0.75}


 12%|█▏        | 62/500 [00:15<01:51,  3.93it/s]

{'loss': 1.7574, 'grad_norm': 1.1368556022644043, 'learning_rate': 0.00017877551020408164, 'epoch': 0.77}


 13%|█▎        | 63/500 [00:15<01:52,  3.89it/s]

{'loss': 1.7345, 'grad_norm': 1.172516107559204, 'learning_rate': 0.00017836734693877553, 'epoch': 0.78}


 13%|█▎        | 64/500 [00:16<01:51,  3.90it/s]

{'loss': 1.5904, 'grad_norm': 1.1089311838150024, 'learning_rate': 0.0001779591836734694, 'epoch': 0.79}


 13%|█▎        | 65/500 [00:16<01:50,  3.93it/s]

{'loss': 1.7384, 'grad_norm': 1.0269614458084106, 'learning_rate': 0.00017755102040816327, 'epoch': 0.8}


 13%|█▎        | 66/500 [00:16<01:49,  3.96it/s]

{'loss': 1.3574, 'grad_norm': 1.0893303155899048, 'learning_rate': 0.00017714285714285713, 'epoch': 0.81}


 13%|█▎        | 67/500 [00:16<01:48,  3.99it/s]

{'loss': 1.4277, 'grad_norm': 1.1960433721542358, 'learning_rate': 0.00017673469387755104, 'epoch': 0.83}


 14%|█▎        | 68/500 [00:17<01:47,  4.01it/s]

{'loss': 1.3742, 'grad_norm': 1.1259870529174805, 'learning_rate': 0.0001763265306122449, 'epoch': 0.84}


 14%|█▍        | 69/500 [00:17<01:46,  4.03it/s]

{'loss': 1.5256, 'grad_norm': 1.0298231840133667, 'learning_rate': 0.0001759183673469388, 'epoch': 0.85}


 14%|█▍        | 70/500 [00:17<01:46,  4.03it/s]

{'loss': 1.3633, 'grad_norm': 0.8990874886512756, 'learning_rate': 0.00017551020408163265, 'epoch': 0.86}


 14%|█▍        | 71/500 [00:17<01:45,  4.05it/s]

{'loss': 1.6452, 'grad_norm': 1.0129883289337158, 'learning_rate': 0.00017510204081632653, 'epoch': 0.88}


 14%|█▍        | 72/500 [00:18<01:45,  4.08it/s]

{'loss': 1.353, 'grad_norm': 0.8342068195343018, 'learning_rate': 0.00017469387755102042, 'epoch': 0.89}


 15%|█▍        | 73/500 [00:18<01:44,  4.08it/s]

{'loss': 1.754, 'grad_norm': 0.8723317980766296, 'learning_rate': 0.0001742857142857143, 'epoch': 0.9}


 15%|█▍        | 74/500 [00:18<01:44,  4.09it/s]

{'loss': 1.4425, 'grad_norm': 0.954399049282074, 'learning_rate': 0.00017387755102040816, 'epoch': 0.91}


 15%|█▌        | 75/500 [00:18<01:43,  4.09it/s]

{'loss': 1.2549, 'grad_norm': 0.962368369102478, 'learning_rate': 0.00017346938775510205, 'epoch': 0.93}


 15%|█▌        | 76/500 [00:19<01:45,  4.02it/s]

{'loss': 1.3718, 'grad_norm': 1.0068851709365845, 'learning_rate': 0.00017306122448979594, 'epoch': 0.94}


 15%|█▌        | 77/500 [00:19<01:49,  3.88it/s]

{'loss': 1.2409, 'grad_norm': 0.7926297783851624, 'learning_rate': 0.00017265306122448982, 'epoch': 0.95}


 16%|█▌        | 78/500 [00:19<01:46,  3.95it/s]

{'loss': 1.2151, 'grad_norm': 0.9876044392585754, 'learning_rate': 0.00017224489795918368, 'epoch': 0.96}


 16%|█▌        | 79/500 [00:19<01:45,  3.98it/s]

{'loss': 1.5372, 'grad_norm': 1.0759302377700806, 'learning_rate': 0.00017183673469387757, 'epoch': 0.98}


 16%|█▌        | 80/500 [00:20<01:44,  4.00it/s]

{'loss': 1.5003, 'grad_norm': 0.9743643403053284, 'learning_rate': 0.00017142857142857143, 'epoch': 0.99}


 16%|█▌        | 81/500 [00:20<01:41,  4.14it/s]

{'loss': 1.1192, 'grad_norm': 1.0909558534622192, 'learning_rate': 0.0001710204081632653, 'epoch': 1.0}


 16%|█▋        | 82/500 [00:20<01:42,  4.09it/s]

{'loss': 1.448, 'grad_norm': 0.7285757660865784, 'learning_rate': 0.0001706122448979592, 'epoch': 1.01}


 17%|█▋        | 83/500 [00:20<01:42,  4.08it/s]

{'loss': 1.1703, 'grad_norm': 0.8657848834991455, 'learning_rate': 0.00017020408163265306, 'epoch': 1.02}


 17%|█▋        | 84/500 [00:21<01:42,  4.07it/s]

{'loss': 1.2329, 'grad_norm': 0.7522028684616089, 'learning_rate': 0.00016979591836734694, 'epoch': 1.04}


 17%|█▋        | 85/500 [00:21<01:42,  4.07it/s]

{'loss': 1.3783, 'grad_norm': 0.8978303074836731, 'learning_rate': 0.00016938775510204083, 'epoch': 1.05}


 17%|█▋        | 86/500 [00:21<01:42,  4.05it/s]

{'loss': 1.1172, 'grad_norm': 1.0662012100219727, 'learning_rate': 0.0001689795918367347, 'epoch': 1.06}


 17%|█▋        | 87/500 [00:21<01:41,  4.06it/s]

{'loss': 1.2022, 'grad_norm': 0.7939725518226624, 'learning_rate': 0.00016857142857142857, 'epoch': 1.07}


 18%|█▊        | 88/500 [00:22<01:40,  4.09it/s]

{'loss': 1.0156, 'grad_norm': 0.8489658832550049, 'learning_rate': 0.00016816326530612246, 'epoch': 1.09}


 18%|█▊        | 89/500 [00:22<01:40,  4.07it/s]

{'loss': 1.0603, 'grad_norm': 0.7325296998023987, 'learning_rate': 0.00016775510204081632, 'epoch': 1.1}


 18%|█▊        | 90/500 [00:22<01:40,  4.09it/s]

{'loss': 1.2964, 'grad_norm': 0.9030181765556335, 'learning_rate': 0.00016734693877551023, 'epoch': 1.11}


 18%|█▊        | 91/500 [00:22<01:39,  4.09it/s]

{'loss': 1.2421, 'grad_norm': 0.7699439525604248, 'learning_rate': 0.0001669387755102041, 'epoch': 1.12}


 18%|█▊        | 92/500 [00:23<01:39,  4.09it/s]

{'loss': 1.1618, 'grad_norm': 0.8091244697570801, 'learning_rate': 0.00016653061224489797, 'epoch': 1.14}


 19%|█▊        | 93/500 [00:23<01:39,  4.08it/s]

{'loss': 0.9774, 'grad_norm': 0.632295548915863, 'learning_rate': 0.00016612244897959183, 'epoch': 1.15}


 19%|█▉        | 94/500 [00:23<01:39,  4.09it/s]

{'loss': 0.9847, 'grad_norm': 0.605323851108551, 'learning_rate': 0.00016571428571428575, 'epoch': 1.16}


 19%|█▉        | 95/500 [00:23<01:38,  4.09it/s]

{'loss': 1.6107, 'grad_norm': 0.8257647156715393, 'learning_rate': 0.0001653061224489796, 'epoch': 1.17}


 19%|█▉        | 96/500 [00:24<01:39,  4.08it/s]

{'loss': 1.2588, 'grad_norm': 0.766649603843689, 'learning_rate': 0.0001648979591836735, 'epoch': 1.19}


 19%|█▉        | 97/500 [00:24<01:38,  4.08it/s]

{'loss': 0.8455, 'grad_norm': 0.7755692601203918, 'learning_rate': 0.00016448979591836735, 'epoch': 1.2}


 20%|█▉        | 98/500 [00:24<01:38,  4.06it/s]

{'loss': 1.3938, 'grad_norm': 0.8010809421539307, 'learning_rate': 0.00016408163265306124, 'epoch': 1.21}


 20%|█▉        | 99/500 [00:24<01:38,  4.06it/s]

{'loss': 1.4563, 'grad_norm': 0.7961347103118896, 'learning_rate': 0.00016367346938775512, 'epoch': 1.22}


 20%|██        | 100/500 [00:25<01:38,  4.04it/s]

{'loss': 1.3862, 'grad_norm': 0.8024153709411621, 'learning_rate': 0.00016326530612244898, 'epoch': 1.23}


 20%|██        | 101/500 [00:25<01:39,  4.02it/s]

{'loss': 1.3657, 'grad_norm': 1.146589756011963, 'learning_rate': 0.00016285714285714287, 'epoch': 1.25}


 20%|██        | 102/500 [00:25<01:39,  3.99it/s]

{'loss': 1.2747, 'grad_norm': 0.8830966949462891, 'learning_rate': 0.00016244897959183672, 'epoch': 1.26}


 21%|██        | 103/500 [00:25<01:40,  3.96it/s]

{'loss': 1.15, 'grad_norm': 0.7373305559158325, 'learning_rate': 0.0001620408163265306, 'epoch': 1.27}


 21%|██        | 104/500 [00:26<01:40,  3.95it/s]

{'loss': 0.8587, 'grad_norm': 0.580769419670105, 'learning_rate': 0.0001616326530612245, 'epoch': 1.28}


 21%|██        | 105/500 [00:26<01:39,  3.99it/s]

{'loss': 1.5454, 'grad_norm': 0.9264122843742371, 'learning_rate': 0.00016122448979591838, 'epoch': 1.3}


 21%|██        | 106/500 [00:26<01:38,  4.00it/s]

{'loss': 1.2303, 'grad_norm': 0.6750348806381226, 'learning_rate': 0.00016081632653061224, 'epoch': 1.31}


 21%|██▏       | 107/500 [00:26<01:37,  4.02it/s]

{'loss': 1.1571, 'grad_norm': 0.7486207485198975, 'learning_rate': 0.00016040816326530613, 'epoch': 1.32}


 22%|██▏       | 108/500 [00:27<01:40,  3.91it/s]

{'loss': 1.0734, 'grad_norm': 0.604182243347168, 'learning_rate': 0.00016, 'epoch': 1.33}


 22%|██▏       | 109/500 [00:27<01:40,  3.88it/s]

{'loss': 1.0586, 'grad_norm': 0.7527517080307007, 'learning_rate': 0.0001595918367346939, 'epoch': 1.35}


 22%|██▏       | 110/500 [00:27<01:41,  3.84it/s]

{'loss': 1.3496, 'grad_norm': 0.9269075393676758, 'learning_rate': 0.00015918367346938776, 'epoch': 1.36}


 22%|██▏       | 111/500 [00:27<01:42,  3.80it/s]

{'loss': 1.2033, 'grad_norm': 0.6613150238990784, 'learning_rate': 0.00015877551020408164, 'epoch': 1.37}


 22%|██▏       | 112/500 [00:28<01:41,  3.84it/s]

{'loss': 0.9542, 'grad_norm': 0.6437546014785767, 'learning_rate': 0.0001583673469387755, 'epoch': 1.38}


 23%|██▎       | 113/500 [00:28<01:39,  3.90it/s]

{'loss': 1.2459, 'grad_norm': 0.6654561161994934, 'learning_rate': 0.00015795918367346942, 'epoch': 1.4}


 23%|██▎       | 114/500 [00:28<01:37,  3.94it/s]

{'loss': 1.3077, 'grad_norm': 0.7464560270309448, 'learning_rate': 0.00015755102040816327, 'epoch': 1.41}


 23%|██▎       | 115/500 [00:28<01:39,  3.87it/s]

{'loss': 1.3534, 'grad_norm': 0.8870065212249756, 'learning_rate': 0.00015714285714285716, 'epoch': 1.42}


 23%|██▎       | 116/500 [00:29<01:38,  3.88it/s]

{'loss': 1.0515, 'grad_norm': 0.594891369342804, 'learning_rate': 0.00015673469387755102, 'epoch': 1.43}


 23%|██▎       | 117/500 [00:29<01:39,  3.85it/s]

{'loss': 1.2852, 'grad_norm': 0.928194522857666, 'learning_rate': 0.0001563265306122449, 'epoch': 1.44}


 24%|██▎       | 118/500 [00:29<01:38,  3.88it/s]

{'loss': 1.376, 'grad_norm': 0.7941483855247498, 'learning_rate': 0.0001559183673469388, 'epoch': 1.46}


 24%|██▍       | 119/500 [00:29<01:38,  3.87it/s]

{'loss': 1.1593, 'grad_norm': 0.6771276593208313, 'learning_rate': 0.00015551020408163265, 'epoch': 1.47}


 24%|██▍       | 120/500 [00:30<01:38,  3.87it/s]

{'loss': 1.5341, 'grad_norm': 0.6947227120399475, 'learning_rate': 0.00015510204081632654, 'epoch': 1.48}


 24%|██▍       | 121/500 [00:30<01:39,  3.83it/s]

{'loss': 1.4319, 'grad_norm': 0.7066975235939026, 'learning_rate': 0.0001546938775510204, 'epoch': 1.49}


 24%|██▍       | 122/500 [00:30<01:43,  3.65it/s]

{'loss': 1.167, 'grad_norm': 0.6661967635154724, 'learning_rate': 0.0001542857142857143, 'epoch': 1.51}


 25%|██▍       | 123/500 [00:30<01:41,  3.73it/s]

{'loss': 1.274, 'grad_norm': 0.6508380174636841, 'learning_rate': 0.00015387755102040817, 'epoch': 1.52}


 25%|██▍       | 124/500 [00:31<01:38,  3.81it/s]

{'loss': 1.0298, 'grad_norm': 0.780297040939331, 'learning_rate': 0.00015346938775510205, 'epoch': 1.53}


 25%|██▌       | 125/500 [00:31<01:39,  3.78it/s]

{'loss': 1.1188, 'grad_norm': 0.7153838276863098, 'learning_rate': 0.0001530612244897959, 'epoch': 1.54}


 25%|██▌       | 126/500 [00:31<01:37,  3.83it/s]

{'loss': 1.0206, 'grad_norm': 0.6576055884361267, 'learning_rate': 0.00015265306122448982, 'epoch': 1.56}


 25%|██▌       | 127/500 [00:32<01:37,  3.82it/s]

{'loss': 1.431, 'grad_norm': 0.6952803134918213, 'learning_rate': 0.00015224489795918368, 'epoch': 1.57}


 26%|██▌       | 128/500 [00:32<01:37,  3.83it/s]

{'loss': 1.0376, 'grad_norm': 0.5500839948654175, 'learning_rate': 0.00015183673469387757, 'epoch': 1.58}


 26%|██▌       | 129/500 [00:32<01:35,  3.90it/s]

{'loss': 1.093, 'grad_norm': 1.0153294801712036, 'learning_rate': 0.00015142857142857143, 'epoch': 1.59}


 26%|██▌       | 130/500 [00:32<01:36,  3.84it/s]

{'loss': 1.1301, 'grad_norm': 0.6470506191253662, 'learning_rate': 0.0001510204081632653, 'epoch': 1.6}


 26%|██▌       | 131/500 [00:33<01:34,  3.89it/s]

{'loss': 1.2523, 'grad_norm': 0.7887449264526367, 'learning_rate': 0.0001506122448979592, 'epoch': 1.62}


 26%|██▋       | 132/500 [00:33<01:33,  3.94it/s]

{'loss': 1.5059, 'grad_norm': 0.8372064828872681, 'learning_rate': 0.00015020408163265306, 'epoch': 1.63}


 27%|██▋       | 133/500 [00:33<01:33,  3.95it/s]

{'loss': 1.2213, 'grad_norm': 0.6451519131660461, 'learning_rate': 0.00014979591836734694, 'epoch': 1.64}


 27%|██▋       | 134/500 [00:33<01:32,  3.97it/s]

{'loss': 1.3843, 'grad_norm': 0.6888554692268372, 'learning_rate': 0.00014938775510204083, 'epoch': 1.65}


 27%|██▋       | 135/500 [00:34<01:32,  3.97it/s]

{'loss': 1.408, 'grad_norm': 0.685683012008667, 'learning_rate': 0.00014897959183673472, 'epoch': 1.67}


 27%|██▋       | 136/500 [00:34<01:31,  3.96it/s]

{'loss': 1.1088, 'grad_norm': 0.550769567489624, 'learning_rate': 0.00014857142857142857, 'epoch': 1.68}


 27%|██▋       | 137/500 [00:34<01:31,  3.96it/s]

{'loss': 0.9492, 'grad_norm': 0.6809455752372742, 'learning_rate': 0.00014816326530612246, 'epoch': 1.69}


 28%|██▊       | 138/500 [00:34<01:30,  3.99it/s]

{'loss': 1.4075, 'grad_norm': 0.7054213881492615, 'learning_rate': 0.00014775510204081632, 'epoch': 1.7}


 28%|██▊       | 139/500 [00:35<01:30,  3.99it/s]

{'loss': 1.1252, 'grad_norm': 0.6513248085975647, 'learning_rate': 0.0001473469387755102, 'epoch': 1.72}


 28%|██▊       | 140/500 [00:35<01:31,  3.95it/s]

{'loss': 1.1913, 'grad_norm': 0.7500342726707458, 'learning_rate': 0.0001469387755102041, 'epoch': 1.73}


 28%|██▊       | 141/500 [00:35<01:30,  3.98it/s]

{'loss': 1.3139, 'grad_norm': 1.1779230833053589, 'learning_rate': 0.00014653061224489798, 'epoch': 1.74}


 28%|██▊       | 142/500 [00:35<01:29,  4.00it/s]

{'loss': 1.3961, 'grad_norm': 0.671419084072113, 'learning_rate': 0.00014612244897959183, 'epoch': 1.75}


 29%|██▊       | 143/500 [00:36<01:30,  3.96it/s]

{'loss': 1.2409, 'grad_norm': 0.677131712436676, 'learning_rate': 0.00014571428571428572, 'epoch': 1.77}


 29%|██▉       | 144/500 [00:36<01:29,  3.99it/s]

{'loss': 1.2961, 'grad_norm': 0.5578007698059082, 'learning_rate': 0.0001453061224489796, 'epoch': 1.78}


 29%|██▉       | 145/500 [00:36<01:28,  4.01it/s]

{'loss': 1.362, 'grad_norm': 0.6503255367279053, 'learning_rate': 0.0001448979591836735, 'epoch': 1.79}


 29%|██▉       | 146/500 [00:36<01:28,  4.00it/s]

{'loss': 0.9206, 'grad_norm': 0.6487815380096436, 'learning_rate': 0.00014448979591836735, 'epoch': 1.8}


 29%|██▉       | 147/500 [00:37<01:27,  4.02it/s]

{'loss': 1.0242, 'grad_norm': 0.6709322929382324, 'learning_rate': 0.00014408163265306124, 'epoch': 1.81}


 30%|██▉       | 148/500 [00:37<01:27,  4.00it/s]

{'loss': 1.305, 'grad_norm': 0.7034499645233154, 'learning_rate': 0.0001436734693877551, 'epoch': 1.83}


 30%|██▉       | 149/500 [00:37<01:27,  4.01it/s]

{'loss': 1.2076, 'grad_norm': 0.6803886890411377, 'learning_rate': 0.00014326530612244898, 'epoch': 1.84}


 30%|███       | 150/500 [00:37<01:27,  4.01it/s]

{'loss': 1.1219, 'grad_norm': 0.7591019868850708, 'learning_rate': 0.00014285714285714287, 'epoch': 1.85}


 30%|███       | 151/500 [00:38<01:29,  3.92it/s]

{'loss': 1.1991, 'grad_norm': 0.7377718687057495, 'learning_rate': 0.00014244897959183673, 'epoch': 1.86}


 30%|███       | 152/500 [00:38<01:28,  3.95it/s]

{'loss': 1.1042, 'grad_norm': 0.5289999842643738, 'learning_rate': 0.0001420408163265306, 'epoch': 1.88}


 31%|███       | 153/500 [00:38<01:30,  3.83it/s]

{'loss': 1.3734, 'grad_norm': 0.6951505541801453, 'learning_rate': 0.0001416326530612245, 'epoch': 1.89}


 31%|███       | 154/500 [00:38<01:30,  3.81it/s]

{'loss': 1.2605, 'grad_norm': 0.6177518367767334, 'learning_rate': 0.00014122448979591838, 'epoch': 1.9}


 31%|███       | 155/500 [00:39<01:31,  3.79it/s]

{'loss': 1.0456, 'grad_norm': 0.7263686060905457, 'learning_rate': 0.00014081632653061224, 'epoch': 1.91}


 31%|███       | 156/500 [00:39<01:30,  3.81it/s]

{'loss': 1.429, 'grad_norm': 0.7825396060943604, 'learning_rate': 0.00014040816326530613, 'epoch': 1.93}


 31%|███▏      | 157/500 [00:39<01:32,  3.70it/s]

{'loss': 1.5208, 'grad_norm': 0.7214285135269165, 'learning_rate': 0.00014, 'epoch': 1.94}


 32%|███▏      | 158/500 [00:39<01:35,  3.59it/s]

{'loss': 1.2384, 'grad_norm': 0.6396978497505188, 'learning_rate': 0.0001395918367346939, 'epoch': 1.95}


 32%|███▏      | 159/500 [00:40<01:36,  3.53it/s]

{'loss': 1.3247, 'grad_norm': 0.7131567597389221, 'learning_rate': 0.00013918367346938776, 'epoch': 1.96}


 32%|███▏      | 160/500 [00:40<01:36,  3.54it/s]

{'loss': 1.1473, 'grad_norm': 0.599135160446167, 'learning_rate': 0.00013877551020408165, 'epoch': 1.98}


 32%|███▏      | 161/500 [00:40<01:33,  3.63it/s]

{'loss': 1.3648, 'grad_norm': 0.7186665534973145, 'learning_rate': 0.0001383673469387755, 'epoch': 1.99}


 32%|███▏      | 162/500 [00:41<01:28,  3.83it/s]

{'loss': 1.2275, 'grad_norm': 0.7274830341339111, 'learning_rate': 0.00013795918367346942, 'epoch': 2.0}


 33%|███▎      | 163/500 [00:41<01:27,  3.87it/s]

{'loss': 1.0618, 'grad_norm': 0.7130605578422546, 'learning_rate': 0.00013755102040816328, 'epoch': 2.01}


 33%|███▎      | 164/500 [00:41<01:26,  3.91it/s]

{'loss': 1.2727, 'grad_norm': 0.5285937786102295, 'learning_rate': 0.00013714285714285716, 'epoch': 2.02}


 33%|███▎      | 165/500 [00:41<01:25,  3.93it/s]

{'loss': 1.2949, 'grad_norm': 0.6412875056266785, 'learning_rate': 0.00013673469387755102, 'epoch': 2.04}


 33%|███▎      | 166/500 [00:42<01:24,  3.94it/s]

{'loss': 0.7951, 'grad_norm': 0.6191409230232239, 'learning_rate': 0.0001363265306122449, 'epoch': 2.05}


 33%|███▎      | 167/500 [00:42<01:24,  3.94it/s]

{'loss': 1.4306, 'grad_norm': 0.8368130922317505, 'learning_rate': 0.0001359183673469388, 'epoch': 2.06}


 34%|███▎      | 168/500 [00:42<01:24,  3.92it/s]

{'loss': 0.9265, 'grad_norm': 0.6485166549682617, 'learning_rate': 0.00013551020408163265, 'epoch': 2.07}


 34%|███▍      | 169/500 [00:42<01:26,  3.83it/s]

{'loss': 0.9935, 'grad_norm': 0.7643015384674072, 'learning_rate': 0.00013510204081632654, 'epoch': 2.09}


 34%|███▍      | 170/500 [00:43<01:35,  3.44it/s]

{'loss': 1.0669, 'grad_norm': 0.7061769962310791, 'learning_rate': 0.0001346938775510204, 'epoch': 2.1}


 34%|███▍      | 171/500 [00:43<01:35,  3.45it/s]

{'loss': 1.1225, 'grad_norm': 0.6664972305297852, 'learning_rate': 0.00013428571428571428, 'epoch': 2.11}


 34%|███▍      | 172/500 [00:43<01:30,  3.62it/s]

{'loss': 1.1662, 'grad_norm': 0.62553471326828, 'learning_rate': 0.00013387755102040817, 'epoch': 2.12}


 35%|███▍      | 173/500 [00:44<01:31,  3.59it/s]

{'loss': 1.1558, 'grad_norm': 0.6945876479148865, 'learning_rate': 0.00013346938775510205, 'epoch': 2.14}


 35%|███▍      | 174/500 [00:44<01:30,  3.59it/s]

{'loss': 1.1899, 'grad_norm': 0.5926531553268433, 'learning_rate': 0.0001330612244897959, 'epoch': 2.15}


 35%|███▌      | 175/500 [00:44<01:30,  3.58it/s]

{'loss': 1.1464, 'grad_norm': 0.7649738192558289, 'learning_rate': 0.0001326530612244898, 'epoch': 2.16}


 35%|███▌      | 176/500 [00:44<01:31,  3.53it/s]

{'loss': 1.3193, 'grad_norm': 0.6857064366340637, 'learning_rate': 0.00013224489795918368, 'epoch': 2.17}


 35%|███▌      | 177/500 [00:45<01:30,  3.56it/s]

{'loss': 0.9011, 'grad_norm': 0.5734919905662537, 'learning_rate': 0.00013183673469387757, 'epoch': 2.19}


 36%|███▌      | 178/500 [00:45<01:30,  3.57it/s]

{'loss': 0.8804, 'grad_norm': 0.6360177993774414, 'learning_rate': 0.00013142857142857143, 'epoch': 2.2}


 36%|███▌      | 179/500 [00:45<01:29,  3.57it/s]

{'loss': 1.0716, 'grad_norm': 0.8265212774276733, 'learning_rate': 0.00013102040816326531, 'epoch': 2.21}


 36%|███▌      | 180/500 [00:45<01:29,  3.57it/s]

{'loss': 1.098, 'grad_norm': 0.6826854944229126, 'learning_rate': 0.00013061224489795917, 'epoch': 2.22}


 36%|███▌      | 181/500 [00:46<01:28,  3.61it/s]

{'loss': 1.3227, 'grad_norm': 0.6153204441070557, 'learning_rate': 0.00013020408163265309, 'epoch': 2.23}


 36%|███▋      | 182/500 [00:46<01:27,  3.62it/s]

{'loss': 1.171, 'grad_norm': 0.5622444152832031, 'learning_rate': 0.00012979591836734695, 'epoch': 2.25}


 37%|███▋      | 183/500 [00:46<01:27,  3.64it/s]

{'loss': 1.1779, 'grad_norm': 0.6894166469573975, 'learning_rate': 0.00012938775510204083, 'epoch': 2.26}


 37%|███▋      | 184/500 [00:47<01:26,  3.66it/s]

{'loss': 1.1127, 'grad_norm': 0.687205970287323, 'learning_rate': 0.0001289795918367347, 'epoch': 2.27}


 37%|███▋      | 185/500 [00:47<01:26,  3.66it/s]

{'loss': 1.1991, 'grad_norm': 0.7003761529922485, 'learning_rate': 0.00012857142857142858, 'epoch': 2.28}


 37%|███▋      | 186/500 [00:47<01:26,  3.64it/s]

{'loss': 1.1983, 'grad_norm': 0.5663744211196899, 'learning_rate': 0.00012816326530612246, 'epoch': 2.3}


 37%|███▋      | 187/500 [00:47<01:25,  3.67it/s]

{'loss': 0.8589, 'grad_norm': 0.4839745759963989, 'learning_rate': 0.00012775510204081632, 'epoch': 2.31}


 38%|███▊      | 188/500 [00:48<01:24,  3.70it/s]

{'loss': 1.3798, 'grad_norm': 0.6429716944694519, 'learning_rate': 0.0001273469387755102, 'epoch': 2.32}


 38%|███▊      | 189/500 [00:48<01:22,  3.78it/s]

{'loss': 1.2436, 'grad_norm': 0.626857578754425, 'learning_rate': 0.00012693877551020406, 'epoch': 2.33}


 38%|███▊      | 190/500 [00:48<01:23,  3.73it/s]

{'loss': 1.0587, 'grad_norm': 0.6010022759437561, 'learning_rate': 0.00012653061224489798, 'epoch': 2.35}


 38%|███▊      | 191/500 [00:48<01:21,  3.79it/s]

{'loss': 1.1026, 'grad_norm': 0.6964883208274841, 'learning_rate': 0.00012612244897959184, 'epoch': 2.36}


 38%|███▊      | 192/500 [00:49<01:20,  3.82it/s]

{'loss': 1.1702, 'grad_norm': 0.6271846890449524, 'learning_rate': 0.00012571428571428572, 'epoch': 2.37}


 39%|███▊      | 193/500 [00:49<01:22,  3.74it/s]

{'loss': 1.3154, 'grad_norm': 0.6430373191833496, 'learning_rate': 0.00012530612244897958, 'epoch': 2.38}


 39%|███▉      | 194/500 [00:49<01:20,  3.80it/s]

{'loss': 1.2058, 'grad_norm': 0.823179304599762, 'learning_rate': 0.0001248979591836735, 'epoch': 2.4}


 39%|███▉      | 195/500 [00:49<01:20,  3.81it/s]

{'loss': 1.2461, 'grad_norm': 0.684354305267334, 'learning_rate': 0.00012448979591836735, 'epoch': 2.41}


 39%|███▉      | 196/500 [00:50<01:19,  3.84it/s]

{'loss': 1.431, 'grad_norm': 0.7661777138710022, 'learning_rate': 0.00012408163265306124, 'epoch': 2.42}


 39%|███▉      | 197/500 [00:50<01:18,  3.85it/s]

{'loss': 1.0198, 'grad_norm': 0.5808024406433105, 'learning_rate': 0.0001236734693877551, 'epoch': 2.43}


 40%|███▉      | 198/500 [00:50<01:17,  3.88it/s]

{'loss': 1.3102, 'grad_norm': 0.5718187689781189, 'learning_rate': 0.00012326530612244898, 'epoch': 2.44}


 40%|███▉      | 199/500 [00:50<01:16,  3.92it/s]

{'loss': 0.8622, 'grad_norm': 0.5235544443130493, 'learning_rate': 0.00012285714285714287, 'epoch': 2.46}


 40%|████      | 200/500 [00:51<01:16,  3.92it/s]

{'loss': 1.0308, 'grad_norm': 0.6931825280189514, 'learning_rate': 0.00012244897959183676, 'epoch': 2.47}


 40%|████      | 201/500 [00:51<01:15,  3.95it/s]

{'loss': 1.1148, 'grad_norm': 0.6058634519577026, 'learning_rate': 0.00012204081632653061, 'epoch': 2.48}


 40%|████      | 202/500 [00:51<01:15,  3.95it/s]

{'loss': 1.2207, 'grad_norm': 0.5833950638771057, 'learning_rate': 0.00012163265306122449, 'epoch': 2.49}


 41%|████      | 203/500 [00:52<01:14,  3.97it/s]

{'loss': 1.1636, 'grad_norm': 0.5631295442581177, 'learning_rate': 0.00012122448979591839, 'epoch': 2.51}


 41%|████      | 204/500 [00:52<01:14,  4.00it/s]

{'loss': 1.2319, 'grad_norm': 0.7381582856178284, 'learning_rate': 0.00012081632653061226, 'epoch': 2.52}


 41%|████      | 205/500 [00:52<01:13,  4.01it/s]

{'loss': 0.8508, 'grad_norm': 0.5262117385864258, 'learning_rate': 0.00012040816326530613, 'epoch': 2.53}


 41%|████      | 206/500 [00:52<01:12,  4.04it/s]

{'loss': 0.9788, 'grad_norm': 0.6325759291648865, 'learning_rate': 0.00012, 'epoch': 2.54}


 41%|████▏     | 207/500 [00:52<01:12,  4.04it/s]

{'loss': 1.0572, 'grad_norm': 0.5312268733978271, 'learning_rate': 0.00011959183673469388, 'epoch': 2.56}


 42%|████▏     | 208/500 [00:53<01:12,  4.04it/s]

{'loss': 1.4436, 'grad_norm': 0.6975216865539551, 'learning_rate': 0.00011918367346938777, 'epoch': 2.57}


 42%|████▏     | 209/500 [00:53<01:13,  3.98it/s]

{'loss': 1.1712, 'grad_norm': 0.629215657711029, 'learning_rate': 0.00011877551020408165, 'epoch': 2.58}


 42%|████▏     | 210/500 [00:53<01:12,  3.99it/s]

{'loss': 1.2746, 'grad_norm': 0.5622393488883972, 'learning_rate': 0.00011836734693877552, 'epoch': 2.59}


 42%|████▏     | 211/500 [00:53<01:12,  4.00it/s]

{'loss': 1.0408, 'grad_norm': 0.6123989224433899, 'learning_rate': 0.00011795918367346939, 'epoch': 2.6}


 42%|████▏     | 212/500 [00:54<01:11,  4.03it/s]

{'loss': 1.0816, 'grad_norm': 0.7064108848571777, 'learning_rate': 0.00011755102040816328, 'epoch': 2.62}


 43%|████▎     | 213/500 [00:54<01:10,  4.04it/s]

{'loss': 1.2625, 'grad_norm': 0.7158111333847046, 'learning_rate': 0.00011714285714285715, 'epoch': 2.63}


 43%|████▎     | 214/500 [00:54<01:10,  4.06it/s]

{'loss': 0.852, 'grad_norm': 0.6517495512962341, 'learning_rate': 0.00011673469387755102, 'epoch': 2.64}


 43%|████▎     | 215/500 [00:54<01:09,  4.08it/s]

{'loss': 1.2798, 'grad_norm': 0.6196041703224182, 'learning_rate': 0.0001163265306122449, 'epoch': 2.65}


 43%|████▎     | 216/500 [00:55<01:09,  4.06it/s]

{'loss': 0.8935, 'grad_norm': 0.6413226127624512, 'learning_rate': 0.00011591836734693877, 'epoch': 2.67}


 43%|████▎     | 217/500 [00:55<01:09,  4.08it/s]

{'loss': 1.2983, 'grad_norm': 0.6691237092018127, 'learning_rate': 0.00011551020408163267, 'epoch': 2.68}


 44%|████▎     | 218/500 [00:55<01:08,  4.09it/s]

{'loss': 1.2373, 'grad_norm': 0.7318903207778931, 'learning_rate': 0.00011510204081632654, 'epoch': 2.69}


 44%|████▍     | 219/500 [00:55<01:08,  4.10it/s]

{'loss': 0.9006, 'grad_norm': 0.4896867871284485, 'learning_rate': 0.00011469387755102041, 'epoch': 2.7}


 44%|████▍     | 220/500 [00:56<01:09,  4.03it/s]

{'loss': 1.2908, 'grad_norm': 0.6394292116165161, 'learning_rate': 0.00011428571428571428, 'epoch': 2.72}


 44%|████▍     | 221/500 [00:56<01:10,  3.96it/s]

{'loss': 1.2429, 'grad_norm': 0.6546691060066223, 'learning_rate': 0.00011387755102040818, 'epoch': 2.73}


 44%|████▍     | 222/500 [00:56<01:10,  3.95it/s]

{'loss': 1.0236, 'grad_norm': 0.6155187487602234, 'learning_rate': 0.00011346938775510206, 'epoch': 2.74}


 45%|████▍     | 223/500 [00:56<01:10,  3.95it/s]

{'loss': 1.2821, 'grad_norm': 0.6078084707260132, 'learning_rate': 0.00011306122448979593, 'epoch': 2.75}


 45%|████▍     | 224/500 [00:57<01:10,  3.90it/s]

{'loss': 1.4669, 'grad_norm': 0.6579426527023315, 'learning_rate': 0.0001126530612244898, 'epoch': 2.77}


 45%|████▌     | 225/500 [00:57<01:10,  3.89it/s]

{'loss': 0.9527, 'grad_norm': 0.47393709421157837, 'learning_rate': 0.00011224489795918367, 'epoch': 2.78}


 45%|████▌     | 226/500 [00:57<01:10,  3.87it/s]

{'loss': 1.2534, 'grad_norm': 0.6504090428352356, 'learning_rate': 0.00011183673469387757, 'epoch': 2.79}


 45%|████▌     | 227/500 [00:58<01:12,  3.76it/s]

{'loss': 1.3074, 'grad_norm': 0.6557174921035767, 'learning_rate': 0.00011142857142857144, 'epoch': 2.8}


 46%|████▌     | 228/500 [00:58<01:13,  3.71it/s]

{'loss': 1.2179, 'grad_norm': 0.7479513883590698, 'learning_rate': 0.00011102040816326532, 'epoch': 2.81}


 46%|████▌     | 229/500 [00:58<01:14,  3.63it/s]

{'loss': 1.0343, 'grad_norm': 0.5722095370292664, 'learning_rate': 0.00011061224489795919, 'epoch': 2.83}


 46%|████▌     | 230/500 [00:58<01:13,  3.66it/s]

{'loss': 1.2951, 'grad_norm': 0.8618522882461548, 'learning_rate': 0.00011020408163265306, 'epoch': 2.84}


 46%|████▌     | 231/500 [00:59<01:13,  3.66it/s]

{'loss': 1.5132, 'grad_norm': 0.6571285724639893, 'learning_rate': 0.00010979591836734695, 'epoch': 2.85}


 46%|████▋     | 232/500 [00:59<01:12,  3.68it/s]

{'loss': 0.907, 'grad_norm': 0.6185920834541321, 'learning_rate': 0.00010938775510204082, 'epoch': 2.86}


 47%|████▋     | 233/500 [00:59<01:12,  3.67it/s]

{'loss': 1.1141, 'grad_norm': 0.6057601571083069, 'learning_rate': 0.00010897959183673469, 'epoch': 2.88}


 47%|████▋     | 234/500 [00:59<01:11,  3.73it/s]

{'loss': 0.939, 'grad_norm': 0.5848249793052673, 'learning_rate': 0.00010857142857142856, 'epoch': 2.89}


 47%|████▋     | 235/500 [01:00<01:09,  3.79it/s]

{'loss': 1.4058, 'grad_norm': 0.6733136773109436, 'learning_rate': 0.00010816326530612246, 'epoch': 2.9}


 47%|████▋     | 236/500 [01:00<01:08,  3.85it/s]

{'loss': 0.9352, 'grad_norm': 0.5835859179496765, 'learning_rate': 0.00010775510204081634, 'epoch': 2.91}


 47%|████▋     | 237/500 [01:00<01:08,  3.84it/s]

{'loss': 0.9787, 'grad_norm': 0.6139756441116333, 'learning_rate': 0.00010734693877551021, 'epoch': 2.93}


 48%|████▊     | 238/500 [01:00<01:07,  3.90it/s]

{'loss': 1.0176, 'grad_norm': 0.5642780661582947, 'learning_rate': 0.00010693877551020408, 'epoch': 2.94}


 48%|████▊     | 239/500 [01:01<01:06,  3.91it/s]

{'loss': 1.0545, 'grad_norm': 0.5839713215827942, 'learning_rate': 0.00010653061224489795, 'epoch': 2.95}


 48%|████▊     | 240/500 [01:01<01:05,  3.96it/s]

{'loss': 0.9959, 'grad_norm': 0.5793209671974182, 'learning_rate': 0.00010612244897959185, 'epoch': 2.96}


 48%|████▊     | 241/500 [01:01<01:05,  3.96it/s]

{'loss': 1.147, 'grad_norm': 0.644212007522583, 'learning_rate': 0.00010571428571428572, 'epoch': 2.98}


 48%|████▊     | 242/500 [01:01<01:04,  3.98it/s]

{'loss': 1.3821, 'grad_norm': 0.5806352496147156, 'learning_rate': 0.0001053061224489796, 'epoch': 2.99}


 49%|████▊     | 243/500 [01:02<01:01,  4.15it/s]

{'loss': 1.2107, 'grad_norm': 0.9430399537086487, 'learning_rate': 0.00010489795918367347, 'epoch': 3.0}


 49%|████▉     | 244/500 [01:02<01:02,  4.11it/s]

{'loss': 1.0996, 'grad_norm': 0.5552637577056885, 'learning_rate': 0.00010448979591836735, 'epoch': 3.01}


 49%|████▉     | 245/500 [01:02<01:03,  4.00it/s]

{'loss': 0.8729, 'grad_norm': 0.5111056566238403, 'learning_rate': 0.00010408163265306123, 'epoch': 3.02}


 49%|████▉     | 246/500 [01:02<01:04,  3.95it/s]

{'loss': 1.1525, 'grad_norm': 0.6518040299415588, 'learning_rate': 0.00010367346938775511, 'epoch': 3.04}


 49%|████▉     | 247/500 [01:03<01:04,  3.90it/s]

{'loss': 1.4717, 'grad_norm': 0.8082369565963745, 'learning_rate': 0.00010326530612244899, 'epoch': 3.05}


 50%|████▉     | 248/500 [01:03<01:04,  3.89it/s]

{'loss': 1.2827, 'grad_norm': 0.7476845383644104, 'learning_rate': 0.00010285714285714286, 'epoch': 3.06}


 50%|████▉     | 249/500 [01:03<01:04,  3.88it/s]

{'loss': 1.1805, 'grad_norm': 0.8767880797386169, 'learning_rate': 0.00010244897959183674, 'epoch': 3.07}


 50%|█████     | 250/500 [01:03<01:03,  3.92it/s]

{'loss': 1.4044, 'grad_norm': 0.5949388742446899, 'learning_rate': 0.00010204081632653062, 'epoch': 3.09}


 50%|█████     | 251/500 [01:04<01:03,  3.94it/s]

{'loss': 1.0466, 'grad_norm': 0.6662126183509827, 'learning_rate': 0.00010163265306122449, 'epoch': 3.1}


 50%|█████     | 252/500 [01:04<01:02,  3.95it/s]

{'loss': 0.8239, 'grad_norm': 0.5564828515052795, 'learning_rate': 0.00010122448979591836, 'epoch': 3.11}


 51%|█████     | 253/500 [01:04<01:02,  3.97it/s]

{'loss': 0.8515, 'grad_norm': 0.6468363404273987, 'learning_rate': 0.00010081632653061226, 'epoch': 3.12}


 51%|█████     | 254/500 [01:04<01:02,  3.94it/s]

{'loss': 1.2232, 'grad_norm': 0.6079509258270264, 'learning_rate': 0.00010040816326530613, 'epoch': 3.14}


 51%|█████     | 255/500 [01:05<01:02,  3.90it/s]

{'loss': 1.2242, 'grad_norm': 0.6056531071662903, 'learning_rate': 0.0001, 'epoch': 3.15}


 51%|█████     | 256/500 [01:05<01:02,  3.89it/s]

{'loss': 1.179, 'grad_norm': 0.6704080104827881, 'learning_rate': 9.959183673469388e-05, 'epoch': 3.16}


 51%|█████▏    | 257/500 [01:05<01:02,  3.88it/s]

{'loss': 1.0834, 'grad_norm': 0.49054843187332153, 'learning_rate': 9.918367346938776e-05, 'epoch': 3.17}


 52%|█████▏    | 258/500 [01:06<01:01,  3.91it/s]

{'loss': 1.2723, 'grad_norm': 0.7951711416244507, 'learning_rate': 9.877551020408164e-05, 'epoch': 3.19}


 52%|█████▏    | 259/500 [01:06<01:02,  3.88it/s]

{'loss': 1.2456, 'grad_norm': 0.5763389468193054, 'learning_rate': 9.836734693877552e-05, 'epoch': 3.2}


 52%|█████▏    | 260/500 [01:06<01:01,  3.90it/s]

{'loss': 0.9181, 'grad_norm': 0.7057625651359558, 'learning_rate': 9.79591836734694e-05, 'epoch': 3.21}


 52%|█████▏    | 261/500 [01:06<01:01,  3.91it/s]

{'loss': 1.1925, 'grad_norm': 0.621731698513031, 'learning_rate': 9.755102040816328e-05, 'epoch': 3.22}


 52%|█████▏    | 262/500 [01:07<01:00,  3.95it/s]

{'loss': 1.1769, 'grad_norm': 0.764331042766571, 'learning_rate': 9.714285714285715e-05, 'epoch': 3.23}


 53%|█████▎    | 263/500 [01:07<01:00,  3.92it/s]

{'loss': 1.1322, 'grad_norm': 0.6433961987495422, 'learning_rate': 9.673469387755102e-05, 'epoch': 3.25}


 53%|█████▎    | 264/500 [01:07<00:59,  3.94it/s]

{'loss': 1.5404, 'grad_norm': 0.8132268786430359, 'learning_rate': 9.63265306122449e-05, 'epoch': 3.26}


 53%|█████▎    | 265/500 [01:07<00:58,  3.98it/s]

{'loss': 1.1449, 'grad_norm': 0.6518386006355286, 'learning_rate': 9.591836734693878e-05, 'epoch': 3.27}


 53%|█████▎    | 266/500 [01:08<00:58,  3.98it/s]

{'loss': 1.2639, 'grad_norm': 0.6741182208061218, 'learning_rate': 9.551020408163265e-05, 'epoch': 3.28}


 53%|█████▎    | 267/500 [01:08<00:58,  3.99it/s]

{'loss': 1.0051, 'grad_norm': 0.617243230342865, 'learning_rate': 9.510204081632653e-05, 'epoch': 3.3}


 54%|█████▎    | 268/500 [01:08<00:57,  4.02it/s]

{'loss': 1.5132, 'grad_norm': 0.6274101734161377, 'learning_rate': 9.469387755102041e-05, 'epoch': 3.31}


 54%|█████▍    | 269/500 [01:08<00:57,  4.03it/s]

{'loss': 1.4021, 'grad_norm': 0.5572466254234314, 'learning_rate': 9.428571428571429e-05, 'epoch': 3.32}


 54%|█████▍    | 270/500 [01:09<00:56,  4.04it/s]

{'loss': 1.0961, 'grad_norm': 0.5398631691932678, 'learning_rate': 9.387755102040817e-05, 'epoch': 3.33}


 54%|█████▍    | 271/500 [01:09<00:56,  4.05it/s]

{'loss': 1.1094, 'grad_norm': 0.6040838360786438, 'learning_rate': 9.346938775510204e-05, 'epoch': 3.35}


 54%|█████▍    | 272/500 [01:09<00:55,  4.07it/s]

{'loss': 1.1551, 'grad_norm': 0.6076251864433289, 'learning_rate': 9.306122448979592e-05, 'epoch': 3.36}


 55%|█████▍    | 273/500 [01:09<00:55,  4.09it/s]

{'loss': 1.2409, 'grad_norm': 0.5619151592254639, 'learning_rate': 9.26530612244898e-05, 'epoch': 3.37}


 55%|█████▍    | 274/500 [01:10<00:55,  4.10it/s]

{'loss': 1.1682, 'grad_norm': 0.5965188145637512, 'learning_rate': 9.224489795918367e-05, 'epoch': 3.38}


 55%|█████▌    | 275/500 [01:10<00:54,  4.11it/s]

{'loss': 1.0889, 'grad_norm': 0.5356046557426453, 'learning_rate': 9.183673469387756e-05, 'epoch': 3.4}


 55%|█████▌    | 276/500 [01:10<00:54,  4.11it/s]

{'loss': 1.2563, 'grad_norm': 0.684173047542572, 'learning_rate': 9.142857142857143e-05, 'epoch': 3.41}


 55%|█████▌    | 277/500 [01:10<00:54,  4.11it/s]

{'loss': 1.0637, 'grad_norm': 0.5329380035400391, 'learning_rate': 9.102040816326532e-05, 'epoch': 3.42}


 56%|█████▌    | 278/500 [01:10<00:54,  4.10it/s]

{'loss': 0.9898, 'grad_norm': 0.6091470718383789, 'learning_rate': 9.061224489795919e-05, 'epoch': 3.43}


 56%|█████▌    | 279/500 [01:11<00:53,  4.11it/s]

{'loss': 1.2327, 'grad_norm': 0.6450108289718628, 'learning_rate': 9.020408163265308e-05, 'epoch': 3.44}


 56%|█████▌    | 280/500 [01:11<00:53,  4.11it/s]

{'loss': 1.0272, 'grad_norm': 0.6068325042724609, 'learning_rate': 8.979591836734695e-05, 'epoch': 3.46}


 56%|█████▌    | 281/500 [01:11<00:53,  4.10it/s]

{'loss': 1.1307, 'grad_norm': 0.5997315645217896, 'learning_rate': 8.938775510204082e-05, 'epoch': 3.47}


 56%|█████▋    | 282/500 [01:11<00:53,  4.09it/s]

{'loss': 0.9739, 'grad_norm': 0.6840189695358276, 'learning_rate': 8.89795918367347e-05, 'epoch': 3.48}


 57%|█████▋    | 283/500 [01:12<00:53,  4.05it/s]

{'loss': 1.0158, 'grad_norm': 0.6370605826377869, 'learning_rate': 8.857142857142857e-05, 'epoch': 3.49}


 57%|█████▋    | 284/500 [01:12<00:53,  4.03it/s]

{'loss': 1.1654, 'grad_norm': 0.6716293692588806, 'learning_rate': 8.816326530612245e-05, 'epoch': 3.51}


 57%|█████▋    | 285/500 [01:12<00:53,  4.04it/s]

{'loss': 1.0694, 'grad_norm': 0.6345279216766357, 'learning_rate': 8.775510204081632e-05, 'epoch': 3.52}


 57%|█████▋    | 286/500 [01:12<00:53,  3.97it/s]

{'loss': 1.2092, 'grad_norm': 0.582961916923523, 'learning_rate': 8.734693877551021e-05, 'epoch': 3.53}


 57%|█████▋    | 287/500 [01:13<00:53,  3.96it/s]

{'loss': 0.92, 'grad_norm': 0.6600664854049683, 'learning_rate': 8.693877551020408e-05, 'epoch': 3.54}


 58%|█████▊    | 288/500 [01:13<00:53,  3.99it/s]

{'loss': 0.895, 'grad_norm': 0.675390362739563, 'learning_rate': 8.653061224489797e-05, 'epoch': 3.56}


 58%|█████▊    | 289/500 [01:13<00:52,  4.01it/s]

{'loss': 1.1281, 'grad_norm': 0.5760377645492554, 'learning_rate': 8.612244897959184e-05, 'epoch': 3.57}


 58%|█████▊    | 290/500 [01:13<00:51,  4.04it/s]

{'loss': 1.2462, 'grad_norm': 0.6896598935127258, 'learning_rate': 8.571428571428571e-05, 'epoch': 3.58}


 58%|█████▊    | 291/500 [01:14<00:51,  4.05it/s]

{'loss': 1.0813, 'grad_norm': 0.7698401808738708, 'learning_rate': 8.53061224489796e-05, 'epoch': 3.59}


 58%|█████▊    | 292/500 [01:14<00:51,  4.07it/s]

{'loss': 0.9303, 'grad_norm': 0.5705361366271973, 'learning_rate': 8.489795918367347e-05, 'epoch': 3.6}


 59%|█████▊    | 293/500 [01:14<00:50,  4.07it/s]

{'loss': 0.9987, 'grad_norm': 0.5575817227363586, 'learning_rate': 8.448979591836736e-05, 'epoch': 3.62}


 59%|█████▉    | 294/500 [01:14<00:50,  4.07it/s]

{'loss': 1.0351, 'grad_norm': 0.546240508556366, 'learning_rate': 8.408163265306123e-05, 'epoch': 3.63}


 59%|█████▉    | 295/500 [01:15<00:50,  4.08it/s]

{'loss': 1.1864, 'grad_norm': 0.5423186421394348, 'learning_rate': 8.367346938775511e-05, 'epoch': 3.64}


 59%|█████▉    | 296/500 [01:15<00:50,  4.08it/s]

{'loss': 1.1666, 'grad_norm': 0.6329001188278198, 'learning_rate': 8.326530612244899e-05, 'epoch': 3.65}


 59%|█████▉    | 297/500 [01:15<00:49,  4.10it/s]

{'loss': 1.3313, 'grad_norm': 0.7804444432258606, 'learning_rate': 8.285714285714287e-05, 'epoch': 3.67}


 60%|█████▉    | 298/500 [01:15<00:49,  4.09it/s]

{'loss': 1.0857, 'grad_norm': 0.5885705947875977, 'learning_rate': 8.244897959183675e-05, 'epoch': 3.68}


 60%|█████▉    | 299/500 [01:16<00:49,  4.10it/s]

{'loss': 1.1665, 'grad_norm': 0.6699228286743164, 'learning_rate': 8.204081632653062e-05, 'epoch': 3.69}


 60%|██████    | 300/500 [01:16<00:50,  3.96it/s]

{'loss': 1.1381, 'grad_norm': 0.643438994884491, 'learning_rate': 8.163265306122449e-05, 'epoch': 3.7}


 60%|██████    | 301/500 [01:16<00:50,  3.95it/s]

{'loss': 0.9308, 'grad_norm': 0.5217877626419067, 'learning_rate': 8.122448979591836e-05, 'epoch': 3.72}


 60%|██████    | 302/500 [01:16<00:50,  3.95it/s]

{'loss': 1.0217, 'grad_norm': 0.6037978529930115, 'learning_rate': 8.081632653061225e-05, 'epoch': 3.73}


 61%|██████    | 303/500 [01:17<00:51,  3.81it/s]

{'loss': 1.0728, 'grad_norm': 1.0281758308410645, 'learning_rate': 8.040816326530612e-05, 'epoch': 3.74}


 61%|██████    | 304/500 [01:17<00:51,  3.80it/s]

{'loss': 0.9272, 'grad_norm': 0.5678081512451172, 'learning_rate': 8e-05, 'epoch': 3.75}


 61%|██████    | 305/500 [01:17<00:51,  3.81it/s]

{'loss': 1.1206, 'grad_norm': 0.554819643497467, 'learning_rate': 7.959183673469388e-05, 'epoch': 3.77}


 61%|██████    | 306/500 [01:18<00:50,  3.84it/s]

{'loss': 1.0777, 'grad_norm': 0.5626429319381714, 'learning_rate': 7.918367346938775e-05, 'epoch': 3.78}


 61%|██████▏   | 307/500 [01:18<00:51,  3.72it/s]

{'loss': 1.0605, 'grad_norm': 0.6743727922439575, 'learning_rate': 7.877551020408164e-05, 'epoch': 3.79}


 62%|██████▏   | 308/500 [01:18<00:56,  3.42it/s]

{'loss': 1.0517, 'grad_norm': 0.6355631351470947, 'learning_rate': 7.836734693877551e-05, 'epoch': 3.8}


 62%|██████▏   | 309/500 [01:18<00:54,  3.49it/s]

{'loss': 0.907, 'grad_norm': 0.6509261131286621, 'learning_rate': 7.79591836734694e-05, 'epoch': 3.81}


 62%|██████▏   | 310/500 [01:19<00:53,  3.56it/s]

{'loss': 0.9273, 'grad_norm': 0.6816343665122986, 'learning_rate': 7.755102040816327e-05, 'epoch': 3.83}


 62%|██████▏   | 311/500 [01:19<00:51,  3.64it/s]

{'loss': 1.2404, 'grad_norm': 0.6753402948379517, 'learning_rate': 7.714285714285715e-05, 'epoch': 3.84}


 62%|██████▏   | 312/500 [01:19<00:50,  3.70it/s]

{'loss': 1.0427, 'grad_norm': 0.5932518243789673, 'learning_rate': 7.673469387755103e-05, 'epoch': 3.85}


 63%|██████▎   | 313/500 [01:19<00:50,  3.70it/s]

{'loss': 1.0502, 'grad_norm': 0.6781144142150879, 'learning_rate': 7.632653061224491e-05, 'epoch': 3.86}


 63%|██████▎   | 314/500 [01:20<00:49,  3.76it/s]

{'loss': 1.0419, 'grad_norm': 0.7356116771697998, 'learning_rate': 7.591836734693878e-05, 'epoch': 3.88}


 63%|██████▎   | 315/500 [01:20<00:48,  3.78it/s]

{'loss': 0.9821, 'grad_norm': 0.6410813927650452, 'learning_rate': 7.551020408163266e-05, 'epoch': 3.89}


 63%|██████▎   | 316/500 [01:20<00:48,  3.82it/s]

{'loss': 1.1099, 'grad_norm': 0.5913413763046265, 'learning_rate': 7.510204081632653e-05, 'epoch': 3.9}


 63%|██████▎   | 317/500 [01:20<00:47,  3.87it/s]

{'loss': 0.9531, 'grad_norm': 0.5249179005622864, 'learning_rate': 7.469387755102041e-05, 'epoch': 3.91}


 64%|██████▎   | 318/500 [01:21<00:46,  3.91it/s]

{'loss': 1.5154, 'grad_norm': 0.7512898445129395, 'learning_rate': 7.428571428571429e-05, 'epoch': 3.93}


 64%|██████▍   | 319/500 [01:21<00:46,  3.92it/s]

{'loss': 1.1313, 'grad_norm': 0.6653596758842468, 'learning_rate': 7.387755102040816e-05, 'epoch': 3.94}


 64%|██████▍   | 320/500 [01:21<00:46,  3.88it/s]

{'loss': 1.2991, 'grad_norm': 0.7155333757400513, 'learning_rate': 7.346938775510205e-05, 'epoch': 3.95}


 64%|██████▍   | 321/500 [01:22<00:47,  3.75it/s]

{'loss': 1.1244, 'grad_norm': 0.5833516120910645, 'learning_rate': 7.306122448979592e-05, 'epoch': 3.96}


 64%|██████▍   | 322/500 [01:22<00:47,  3.75it/s]

{'loss': 1.1033, 'grad_norm': 0.8193710446357727, 'learning_rate': 7.26530612244898e-05, 'epoch': 3.98}


 65%|██████▍   | 323/500 [01:22<00:46,  3.80it/s]

{'loss': 0.9311, 'grad_norm': 0.580195963382721, 'learning_rate': 7.224489795918368e-05, 'epoch': 3.99}


 65%|██████▍   | 324/500 [01:22<00:44,  3.98it/s]

{'loss': 0.7542, 'grad_norm': 0.8297509551048279, 'learning_rate': 7.183673469387755e-05, 'epoch': 4.0}


 65%|██████▌   | 325/500 [01:23<00:44,  3.96it/s]

{'loss': 1.2172, 'grad_norm': 0.835143506526947, 'learning_rate': 7.142857142857143e-05, 'epoch': 4.01}


 65%|██████▌   | 326/500 [01:23<00:44,  3.95it/s]

{'loss': 0.9182, 'grad_norm': 0.5841493010520935, 'learning_rate': 7.10204081632653e-05, 'epoch': 4.02}


 65%|██████▌   | 327/500 [01:23<00:43,  3.96it/s]

{'loss': 1.0707, 'grad_norm': 0.8181084394454956, 'learning_rate': 7.061224489795919e-05, 'epoch': 4.04}


 66%|██████▌   | 328/500 [01:23<00:43,  3.92it/s]

{'loss': 1.0534, 'grad_norm': 0.6037523746490479, 'learning_rate': 7.020408163265306e-05, 'epoch': 4.05}


 66%|██████▌   | 329/500 [01:24<00:44,  3.83it/s]

{'loss': 0.999, 'grad_norm': 0.5796051025390625, 'learning_rate': 6.979591836734695e-05, 'epoch': 4.06}


 66%|██████▌   | 330/500 [01:24<00:44,  3.86it/s]

{'loss': 1.0539, 'grad_norm': 0.7273560762405396, 'learning_rate': 6.938775510204082e-05, 'epoch': 4.07}


 66%|██████▌   | 331/500 [01:24<00:43,  3.87it/s]

{'loss': 1.0089, 'grad_norm': 0.5650511980056763, 'learning_rate': 6.897959183673471e-05, 'epoch': 4.09}


 66%|██████▋   | 332/500 [01:24<00:43,  3.90it/s]

{'loss': 1.0579, 'grad_norm': 0.8580501675605774, 'learning_rate': 6.857142857142858e-05, 'epoch': 4.1}


 67%|██████▋   | 333/500 [01:25<00:43,  3.84it/s]

{'loss': 0.9099, 'grad_norm': 0.5923689007759094, 'learning_rate': 6.816326530612245e-05, 'epoch': 4.11}


 67%|██████▋   | 334/500 [01:25<00:43,  3.84it/s]

{'loss': 1.0975, 'grad_norm': 0.6492199301719666, 'learning_rate': 6.775510204081633e-05, 'epoch': 4.12}


 67%|██████▋   | 335/500 [01:25<00:42,  3.87it/s]

{'loss': 1.0997, 'grad_norm': 0.6104254126548767, 'learning_rate': 6.73469387755102e-05, 'epoch': 4.14}


 67%|██████▋   | 336/500 [01:25<00:42,  3.88it/s]

{'loss': 1.0407, 'grad_norm': 0.6848065257072449, 'learning_rate': 6.693877551020408e-05, 'epoch': 4.15}


 67%|██████▋   | 337/500 [01:26<00:41,  3.91it/s]

{'loss': 1.139, 'grad_norm': 0.6508702039718628, 'learning_rate': 6.653061224489796e-05, 'epoch': 4.16}


 68%|██████▊   | 338/500 [01:26<00:41,  3.91it/s]

{'loss': 1.2635, 'grad_norm': 0.6427459716796875, 'learning_rate': 6.612244897959184e-05, 'epoch': 4.17}


 68%|██████▊   | 339/500 [01:26<00:40,  3.94it/s]

{'loss': 1.0401, 'grad_norm': 0.5982540249824524, 'learning_rate': 6.571428571428571e-05, 'epoch': 4.19}


 68%|██████▊   | 340/500 [01:26<00:40,  3.97it/s]

{'loss': 0.8756, 'grad_norm': 0.7523930668830872, 'learning_rate': 6.530612244897959e-05, 'epoch': 4.2}


 68%|██████▊   | 341/500 [01:27<00:40,  3.95it/s]

{'loss': 1.0784, 'grad_norm': 0.5776974558830261, 'learning_rate': 6.489795918367347e-05, 'epoch': 4.21}


 68%|██████▊   | 342/500 [01:27<00:39,  3.98it/s]

{'loss': 0.8391, 'grad_norm': 0.7260725498199463, 'learning_rate': 6.448979591836734e-05, 'epoch': 4.22}


 69%|██████▊   | 343/500 [01:27<00:39,  4.00it/s]

{'loss': 1.2556, 'grad_norm': 0.9445098638534546, 'learning_rate': 6.408163265306123e-05, 'epoch': 4.23}


 69%|██████▉   | 344/500 [01:27<00:38,  4.02it/s]

{'loss': 1.1352, 'grad_norm': 0.8995423316955566, 'learning_rate': 6.36734693877551e-05, 'epoch': 4.25}


 69%|██████▉   | 345/500 [01:28<00:38,  4.03it/s]

{'loss': 0.8289, 'grad_norm': 0.624947190284729, 'learning_rate': 6.326530612244899e-05, 'epoch': 4.26}


 69%|██████▉   | 346/500 [01:28<00:39,  3.94it/s]

{'loss': 0.986, 'grad_norm': 0.5690768957138062, 'learning_rate': 6.285714285714286e-05, 'epoch': 4.27}


 69%|██████▉   | 347/500 [01:28<00:38,  3.97it/s]

{'loss': 1.0872, 'grad_norm': 0.6304959654808044, 'learning_rate': 6.244897959183675e-05, 'epoch': 4.28}


 70%|██████▉   | 348/500 [01:28<00:38,  3.99it/s]

{'loss': 0.9004, 'grad_norm': 0.6071205139160156, 'learning_rate': 6.204081632653062e-05, 'epoch': 4.3}


 70%|██████▉   | 349/500 [01:29<00:37,  4.01it/s]

{'loss': 1.2005, 'grad_norm': 0.6514768004417419, 'learning_rate': 6.163265306122449e-05, 'epoch': 4.31}


 70%|███████   | 350/500 [01:29<00:37,  4.01it/s]

{'loss': 1.3347, 'grad_norm': 0.9919952750205994, 'learning_rate': 6.122448979591838e-05, 'epoch': 4.32}


 70%|███████   | 351/500 [01:29<00:36,  4.03it/s]

{'loss': 1.5546, 'grad_norm': 0.9472942352294922, 'learning_rate': 6.081632653061224e-05, 'epoch': 4.33}


 70%|███████   | 352/500 [01:29<00:36,  4.05it/s]

{'loss': 1.2632, 'grad_norm': 0.6129977107048035, 'learning_rate': 6.040816326530613e-05, 'epoch': 4.35}


 71%|███████   | 353/500 [01:30<00:36,  4.06it/s]

{'loss': 1.3065, 'grad_norm': 0.6959934830665588, 'learning_rate': 6e-05, 'epoch': 4.36}


 71%|███████   | 354/500 [01:30<00:35,  4.07it/s]

{'loss': 1.0546, 'grad_norm': 0.5473862290382385, 'learning_rate': 5.959183673469389e-05, 'epoch': 4.37}


 71%|███████   | 355/500 [01:30<00:35,  4.07it/s]

{'loss': 1.2956, 'grad_norm': 0.8937044143676758, 'learning_rate': 5.918367346938776e-05, 'epoch': 4.38}


 71%|███████   | 356/500 [01:30<00:35,  4.04it/s]

{'loss': 1.3838, 'grad_norm': 0.7680718898773193, 'learning_rate': 5.877551020408164e-05, 'epoch': 4.4}


 71%|███████▏  | 357/500 [01:31<00:35,  4.01it/s]

{'loss': 1.1118, 'grad_norm': 0.5829837918281555, 'learning_rate': 5.836734693877551e-05, 'epoch': 4.41}


 72%|███████▏  | 358/500 [01:31<00:35,  4.01it/s]

{'loss': 1.1774, 'grad_norm': 0.5573932528495789, 'learning_rate': 5.7959183673469384e-05, 'epoch': 4.42}


 72%|███████▏  | 359/500 [01:31<00:34,  4.04it/s]

{'loss': 1.2614, 'grad_norm': 0.6530594229698181, 'learning_rate': 5.755102040816327e-05, 'epoch': 4.43}


 72%|███████▏  | 360/500 [01:31<00:34,  4.05it/s]

{'loss': 1.1966, 'grad_norm': 0.5110976696014404, 'learning_rate': 5.714285714285714e-05, 'epoch': 4.44}


 72%|███████▏  | 361/500 [01:32<00:34,  4.06it/s]

{'loss': 0.9631, 'grad_norm': 0.5152856707572937, 'learning_rate': 5.673469387755103e-05, 'epoch': 4.46}


 72%|███████▏  | 362/500 [01:32<00:33,  4.09it/s]

{'loss': 1.6943, 'grad_norm': 0.8649474382400513, 'learning_rate': 5.63265306122449e-05, 'epoch': 4.47}


 73%|███████▎  | 363/500 [01:32<00:33,  4.08it/s]

{'loss': 1.1284, 'grad_norm': 0.705816388130188, 'learning_rate': 5.5918367346938786e-05, 'epoch': 4.48}


 73%|███████▎  | 364/500 [01:32<00:33,  4.02it/s]

{'loss': 0.9891, 'grad_norm': 0.6391851902008057, 'learning_rate': 5.551020408163266e-05, 'epoch': 4.49}


 73%|███████▎  | 365/500 [01:33<00:35,  3.78it/s]

{'loss': 1.1575, 'grad_norm': 0.5978396534919739, 'learning_rate': 5.510204081632653e-05, 'epoch': 4.51}


 73%|███████▎  | 366/500 [01:33<00:35,  3.82it/s]

{'loss': 1.1793, 'grad_norm': 0.5487282872200012, 'learning_rate': 5.469387755102041e-05, 'epoch': 4.52}


 73%|███████▎  | 367/500 [01:33<00:34,  3.82it/s]

{'loss': 1.1609, 'grad_norm': 0.6243578791618347, 'learning_rate': 5.428571428571428e-05, 'epoch': 4.53}


 74%|███████▎  | 368/500 [01:33<00:34,  3.85it/s]

{'loss': 0.9219, 'grad_norm': 0.5779898166656494, 'learning_rate': 5.387755102040817e-05, 'epoch': 4.54}


 74%|███████▍  | 369/500 [01:34<00:33,  3.89it/s]

{'loss': 1.0411, 'grad_norm': 0.6071159839630127, 'learning_rate': 5.346938775510204e-05, 'epoch': 4.56}


 74%|███████▍  | 370/500 [01:34<00:33,  3.93it/s]

{'loss': 0.9974, 'grad_norm': 0.5431668758392334, 'learning_rate': 5.3061224489795926e-05, 'epoch': 4.57}


 74%|███████▍  | 371/500 [01:34<00:32,  3.95it/s]

{'loss': 1.3892, 'grad_norm': 0.6137733459472656, 'learning_rate': 5.26530612244898e-05, 'epoch': 4.58}


 74%|███████▍  | 372/500 [01:34<00:32,  3.98it/s]

{'loss': 0.95, 'grad_norm': 0.9558703303337097, 'learning_rate': 5.224489795918368e-05, 'epoch': 4.59}


 75%|███████▍  | 373/500 [01:35<00:31,  3.99it/s]

{'loss': 1.2207, 'grad_norm': 0.6730886101722717, 'learning_rate': 5.1836734693877557e-05, 'epoch': 4.6}


 75%|███████▍  | 374/500 [01:35<00:31,  4.00it/s]

{'loss': 1.0977, 'grad_norm': 0.596878170967102, 'learning_rate': 5.142857142857143e-05, 'epoch': 4.62}


 75%|███████▌  | 375/500 [01:35<00:31,  4.01it/s]

{'loss': 1.0535, 'grad_norm': 0.6390148997306824, 'learning_rate': 5.102040816326531e-05, 'epoch': 4.63}


 75%|███████▌  | 376/500 [01:35<00:30,  4.04it/s]

{'loss': 0.9138, 'grad_norm': 0.5591403841972351, 'learning_rate': 5.061224489795918e-05, 'epoch': 4.64}


 75%|███████▌  | 377/500 [01:36<00:30,  4.01it/s]

{'loss': 1.1098, 'grad_norm': 0.6013767719268799, 'learning_rate': 5.0204081632653066e-05, 'epoch': 4.65}


 76%|███████▌  | 378/500 [01:36<00:30,  4.03it/s]

{'loss': 0.9726, 'grad_norm': 0.7183091640472412, 'learning_rate': 4.979591836734694e-05, 'epoch': 4.67}


 76%|███████▌  | 379/500 [01:36<00:30,  4.00it/s]

{'loss': 1.1753, 'grad_norm': 0.7627601623535156, 'learning_rate': 4.938775510204082e-05, 'epoch': 4.68}


 76%|███████▌  | 380/500 [01:36<00:29,  4.01it/s]

{'loss': 1.1601, 'grad_norm': 0.6940284371376038, 'learning_rate': 4.89795918367347e-05, 'epoch': 4.69}


 76%|███████▌  | 381/500 [01:37<00:29,  4.02it/s]

{'loss': 1.0312, 'grad_norm': 0.5980905890464783, 'learning_rate': 4.8571428571428576e-05, 'epoch': 4.7}


 76%|███████▋  | 382/500 [01:37<00:29,  4.04it/s]

{'loss': 1.0798, 'grad_norm': 0.6721817255020142, 'learning_rate': 4.816326530612245e-05, 'epoch': 4.72}


 77%|███████▋  | 383/500 [01:37<00:29,  4.03it/s]

{'loss': 1.1345, 'grad_norm': 0.7104556560516357, 'learning_rate': 4.775510204081633e-05, 'epoch': 4.73}


 77%|███████▋  | 384/500 [01:37<00:28,  4.06it/s]

{'loss': 0.9797, 'grad_norm': 0.6339846253395081, 'learning_rate': 4.7346938775510206e-05, 'epoch': 4.74}


 77%|███████▋  | 385/500 [01:38<00:28,  4.04it/s]

{'loss': 0.9918, 'grad_norm': 0.6855306625366211, 'learning_rate': 4.6938775510204086e-05, 'epoch': 4.75}


 77%|███████▋  | 386/500 [01:38<00:28,  4.04it/s]

{'loss': 0.9291, 'grad_norm': 0.6190972328186035, 'learning_rate': 4.653061224489796e-05, 'epoch': 4.77}


 77%|███████▋  | 387/500 [01:38<00:27,  4.04it/s]

{'loss': 1.0403, 'grad_norm': 0.653311014175415, 'learning_rate': 4.612244897959184e-05, 'epoch': 4.78}


 78%|███████▊  | 388/500 [01:38<00:27,  4.03it/s]

{'loss': 1.1294, 'grad_norm': 0.6740363836288452, 'learning_rate': 4.5714285714285716e-05, 'epoch': 4.79}


 78%|███████▊  | 389/500 [01:39<00:27,  4.03it/s]

{'loss': 1.0329, 'grad_norm': 0.5649850368499756, 'learning_rate': 4.5306122448979595e-05, 'epoch': 4.8}


 78%|███████▊  | 390/500 [01:39<00:27,  4.05it/s]

{'loss': 1.0463, 'grad_norm': 0.5460544228553772, 'learning_rate': 4.4897959183673474e-05, 'epoch': 4.81}


 78%|███████▊  | 391/500 [01:39<00:26,  4.05it/s]

{'loss': 1.0199, 'grad_norm': 0.6819478273391724, 'learning_rate': 4.448979591836735e-05, 'epoch': 4.83}


 78%|███████▊  | 392/500 [01:39<00:26,  4.04it/s]

{'loss': 1.1267, 'grad_norm': 0.7179959416389465, 'learning_rate': 4.4081632653061226e-05, 'epoch': 4.84}


 79%|███████▊  | 393/500 [01:40<00:26,  4.00it/s]

{'loss': 0.8925, 'grad_norm': 0.8525646328926086, 'learning_rate': 4.3673469387755105e-05, 'epoch': 4.85}


 79%|███████▉  | 394/500 [01:40<00:26,  3.99it/s]

{'loss': 1.0494, 'grad_norm': 0.587584912776947, 'learning_rate': 4.3265306122448984e-05, 'epoch': 4.86}


 79%|███████▉  | 395/500 [01:40<00:26,  3.99it/s]

{'loss': 1.0682, 'grad_norm': 0.6353826522827148, 'learning_rate': 4.2857142857142856e-05, 'epoch': 4.88}


 79%|███████▉  | 396/500 [01:40<00:26,  3.98it/s]

{'loss': 0.7765, 'grad_norm': 0.5984517931938171, 'learning_rate': 4.2448979591836735e-05, 'epoch': 4.89}


 79%|███████▉  | 397/500 [01:41<00:25,  3.99it/s]

{'loss': 1.1762, 'grad_norm': 0.7680466771125793, 'learning_rate': 4.2040816326530615e-05, 'epoch': 4.9}


 80%|███████▉  | 398/500 [01:41<00:25,  4.00it/s]

{'loss': 1.0187, 'grad_norm': 0.6564246416091919, 'learning_rate': 4.1632653061224494e-05, 'epoch': 4.91}


 80%|███████▉  | 399/500 [01:41<00:25,  4.00it/s]

{'loss': 1.1679, 'grad_norm': 0.7582579255104065, 'learning_rate': 4.122448979591837e-05, 'epoch': 4.93}


 80%|████████  | 400/500 [01:41<00:24,  4.01it/s]

{'loss': 1.2066, 'grad_norm': 0.8280940055847168, 'learning_rate': 4.0816326530612245e-05, 'epoch': 4.94}


 80%|████████  | 401/500 [01:42<00:24,  4.01it/s]

{'loss': 0.9262, 'grad_norm': 0.5391770601272583, 'learning_rate': 4.0408163265306124e-05, 'epoch': 4.95}


 80%|████████  | 402/500 [01:42<00:24,  4.00it/s]

{'loss': 1.1094, 'grad_norm': 0.5711162686347961, 'learning_rate': 4e-05, 'epoch': 4.96}


 81%|████████  | 403/500 [01:42<00:24,  3.97it/s]

{'loss': 0.9238, 'grad_norm': 0.6521950960159302, 'learning_rate': 3.9591836734693876e-05, 'epoch': 4.98}


 81%|████████  | 404/500 [01:42<00:24,  3.92it/s]

{'loss': 1.2009, 'grad_norm': 0.5921667814254761, 'learning_rate': 3.9183673469387755e-05, 'epoch': 4.99}


 81%|████████  | 405/500 [01:43<00:23,  4.08it/s]

{'loss': 1.3299, 'grad_norm': 0.8634422421455383, 'learning_rate': 3.8775510204081634e-05, 'epoch': 5.0}


 81%|████████  | 406/500 [01:43<00:23,  4.04it/s]

{'loss': 0.7302, 'grad_norm': 0.6284120082855225, 'learning_rate': 3.836734693877551e-05, 'epoch': 5.01}


 81%|████████▏ | 407/500 [01:43<00:23,  4.02it/s]

{'loss': 0.8773, 'grad_norm': 0.5993466377258301, 'learning_rate': 3.795918367346939e-05, 'epoch': 5.02}


 82%|████████▏ | 408/500 [01:43<00:22,  4.02it/s]

{'loss': 0.9087, 'grad_norm': 0.5995808839797974, 'learning_rate': 3.7551020408163264e-05, 'epoch': 5.04}


 82%|████████▏ | 409/500 [01:44<00:22,  4.03it/s]

{'loss': 1.0421, 'grad_norm': 0.599212110042572, 'learning_rate': 3.7142857142857143e-05, 'epoch': 5.05}


 82%|████████▏ | 410/500 [01:44<00:22,  4.03it/s]

{'loss': 1.0026, 'grad_norm': 0.6687333583831787, 'learning_rate': 3.673469387755102e-05, 'epoch': 5.06}


 82%|████████▏ | 411/500 [01:44<00:22,  4.03it/s]

{'loss': 1.2476, 'grad_norm': 0.697039008140564, 'learning_rate': 3.63265306122449e-05, 'epoch': 5.07}


 82%|████████▏ | 412/500 [01:44<00:21,  4.03it/s]

{'loss': 1.4227, 'grad_norm': 0.6720567345619202, 'learning_rate': 3.5918367346938774e-05, 'epoch': 5.09}


 83%|████████▎ | 413/500 [01:45<00:21,  4.03it/s]

{'loss': 1.211, 'grad_norm': 0.6120622158050537, 'learning_rate': 3.551020408163265e-05, 'epoch': 5.1}


 83%|████████▎ | 414/500 [01:45<00:21,  4.02it/s]

{'loss': 0.9681, 'grad_norm': 0.6107531189918518, 'learning_rate': 3.510204081632653e-05, 'epoch': 5.11}


 83%|████████▎ | 415/500 [01:45<00:21,  3.99it/s]

{'loss': 1.1466, 'grad_norm': 0.7081167101860046, 'learning_rate': 3.469387755102041e-05, 'epoch': 5.12}


 83%|████████▎ | 416/500 [01:45<00:21,  3.99it/s]

{'loss': 0.8053, 'grad_norm': 0.5530222654342651, 'learning_rate': 3.428571428571429e-05, 'epoch': 5.14}


 83%|████████▎ | 417/500 [01:46<00:20,  3.98it/s]

{'loss': 0.9022, 'grad_norm': 0.6405904293060303, 'learning_rate': 3.387755102040816e-05, 'epoch': 5.15}


 84%|████████▎ | 418/500 [01:46<00:20,  4.00it/s]

{'loss': 1.1483, 'grad_norm': 0.9387404322624207, 'learning_rate': 3.346938775510204e-05, 'epoch': 5.16}


 84%|████████▍ | 419/500 [01:46<00:20,  3.99it/s]

{'loss': 1.1012, 'grad_norm': 0.9407128691673279, 'learning_rate': 3.306122448979592e-05, 'epoch': 5.17}


 84%|████████▍ | 420/500 [01:46<00:20,  3.97it/s]

{'loss': 1.2519, 'grad_norm': 0.685187816619873, 'learning_rate': 3.265306122448979e-05, 'epoch': 5.19}


 84%|████████▍ | 421/500 [01:47<00:19,  3.96it/s]

{'loss': 1.2278, 'grad_norm': 0.8421305418014526, 'learning_rate': 3.224489795918367e-05, 'epoch': 5.2}


 84%|████████▍ | 422/500 [01:47<00:20,  3.88it/s]

{'loss': 0.7489, 'grad_norm': 0.613745927810669, 'learning_rate': 3.183673469387755e-05, 'epoch': 5.21}


 85%|████████▍ | 423/500 [01:47<00:20,  3.77it/s]

{'loss': 1.0859, 'grad_norm': 0.8859882354736328, 'learning_rate': 3.142857142857143e-05, 'epoch': 5.22}


 85%|████████▍ | 424/500 [01:47<00:20,  3.76it/s]

{'loss': 1.1073, 'grad_norm': 0.6056309938430786, 'learning_rate': 3.102040816326531e-05, 'epoch': 5.23}


 85%|████████▌ | 425/500 [01:48<00:19,  3.77it/s]

{'loss': 1.0324, 'grad_norm': 0.6032839417457581, 'learning_rate': 3.061224489795919e-05, 'epoch': 5.25}


 85%|████████▌ | 426/500 [01:48<00:19,  3.77it/s]

{'loss': 1.0339, 'grad_norm': 0.8209232091903687, 'learning_rate': 3.0204081632653065e-05, 'epoch': 5.26}


 85%|████████▌ | 427/500 [01:48<00:19,  3.77it/s]

{'loss': 1.1561, 'grad_norm': 0.7172107696533203, 'learning_rate': 2.9795918367346944e-05, 'epoch': 5.27}


 86%|████████▌ | 428/500 [01:49<00:18,  3.80it/s]

{'loss': 0.9376, 'grad_norm': 0.5894507169723511, 'learning_rate': 2.938775510204082e-05, 'epoch': 5.28}


 86%|████████▌ | 429/500 [01:49<00:18,  3.74it/s]

{'loss': 1.061, 'grad_norm': 0.6849045157432556, 'learning_rate': 2.8979591836734692e-05, 'epoch': 5.3}


 86%|████████▌ | 430/500 [01:49<00:18,  3.78it/s]

{'loss': 1.0826, 'grad_norm': 0.5803796648979187, 'learning_rate': 2.857142857142857e-05, 'epoch': 5.31}


 86%|████████▌ | 431/500 [01:49<00:18,  3.81it/s]

{'loss': 0.9978, 'grad_norm': 0.6621502637863159, 'learning_rate': 2.816326530612245e-05, 'epoch': 5.32}


 86%|████████▋ | 432/500 [01:50<00:17,  3.85it/s]

{'loss': 1.0896, 'grad_norm': 0.565870463848114, 'learning_rate': 2.775510204081633e-05, 'epoch': 5.33}


 87%|████████▋ | 433/500 [01:50<00:17,  3.87it/s]

{'loss': 1.0679, 'grad_norm': 0.5517186522483826, 'learning_rate': 2.7346938775510205e-05, 'epoch': 5.35}


 87%|████████▋ | 434/500 [01:50<00:16,  3.89it/s]

{'loss': 0.9151, 'grad_norm': 0.6491119265556335, 'learning_rate': 2.6938775510204084e-05, 'epoch': 5.36}


 87%|████████▋ | 435/500 [01:50<00:16,  3.92it/s]

{'loss': 0.8887, 'grad_norm': 0.5727147459983826, 'learning_rate': 2.6530612244897963e-05, 'epoch': 5.37}


 87%|████████▋ | 436/500 [01:51<00:16,  3.96it/s]

{'loss': 1.4195, 'grad_norm': 0.7633095383644104, 'learning_rate': 2.612244897959184e-05, 'epoch': 5.38}


 87%|████████▋ | 437/500 [01:51<00:15,  3.98it/s]

{'loss': 1.2472, 'grad_norm': 0.6389709115028381, 'learning_rate': 2.5714285714285714e-05, 'epoch': 5.4}


 88%|████████▊ | 438/500 [01:51<00:15,  3.99it/s]

{'loss': 1.0239, 'grad_norm': 0.6012716889381409, 'learning_rate': 2.530612244897959e-05, 'epoch': 5.41}


 88%|████████▊ | 439/500 [01:51<00:15,  3.95it/s]

{'loss': 1.1132, 'grad_norm': 0.744253933429718, 'learning_rate': 2.489795918367347e-05, 'epoch': 5.42}


 88%|████████▊ | 440/500 [01:52<00:15,  3.94it/s]

{'loss': 0.8895, 'grad_norm': 0.9457151293754578, 'learning_rate': 2.448979591836735e-05, 'epoch': 5.43}


 88%|████████▊ | 441/500 [01:52<00:15,  3.84it/s]

{'loss': 1.0301, 'grad_norm': 0.836777925491333, 'learning_rate': 2.4081632653061224e-05, 'epoch': 5.44}


 88%|████████▊ | 442/500 [01:52<00:15,  3.85it/s]

{'loss': 1.0669, 'grad_norm': 0.6379890441894531, 'learning_rate': 2.3673469387755103e-05, 'epoch': 5.46}


 89%|████████▊ | 443/500 [01:52<00:14,  3.84it/s]

{'loss': 0.9887, 'grad_norm': 0.5744550228118896, 'learning_rate': 2.326530612244898e-05, 'epoch': 5.47}


 89%|████████▉ | 444/500 [01:53<00:14,  3.74it/s]

{'loss': 1.6722, 'grad_norm': 0.7606108784675598, 'learning_rate': 2.2857142857142858e-05, 'epoch': 5.48}


 89%|████████▉ | 445/500 [01:53<00:14,  3.78it/s]

{'loss': 0.9609, 'grad_norm': 0.5661510825157166, 'learning_rate': 2.2448979591836737e-05, 'epoch': 5.49}


 89%|████████▉ | 446/500 [01:53<00:14,  3.82it/s]

{'loss': 0.9857, 'grad_norm': 0.9024667739868164, 'learning_rate': 2.2040816326530613e-05, 'epoch': 5.51}


 89%|████████▉ | 447/500 [01:53<00:13,  3.88it/s]

{'loss': 0.8312, 'grad_norm': 0.5743830800056458, 'learning_rate': 2.1632653061224492e-05, 'epoch': 5.52}


 90%|████████▉ | 448/500 [01:54<00:13,  3.91it/s]

{'loss': 1.1485, 'grad_norm': 0.6106956005096436, 'learning_rate': 2.1224489795918368e-05, 'epoch': 5.53}


 90%|████████▉ | 449/500 [01:54<00:13,  3.91it/s]

{'loss': 0.9361, 'grad_norm': 0.5805802941322327, 'learning_rate': 2.0816326530612247e-05, 'epoch': 5.54}


 90%|█████████ | 450/500 [01:54<00:12,  3.93it/s]

{'loss': 1.201, 'grad_norm': 0.6036517024040222, 'learning_rate': 2.0408163265306123e-05, 'epoch': 5.56}


 90%|█████████ | 451/500 [01:54<00:12,  3.96it/s]

{'loss': 1.1315, 'grad_norm': 0.5602678656578064, 'learning_rate': 2e-05, 'epoch': 5.57}


 90%|█████████ | 452/500 [01:55<00:12,  3.98it/s]

{'loss': 1.0104, 'grad_norm': 0.5868871212005615, 'learning_rate': 1.9591836734693877e-05, 'epoch': 5.58}


 91%|█████████ | 453/500 [01:55<00:11,  4.00it/s]

{'loss': 1.2297, 'grad_norm': 0.6645532250404358, 'learning_rate': 1.9183673469387756e-05, 'epoch': 5.59}


 91%|█████████ | 454/500 [01:55<00:11,  4.00it/s]

{'loss': 1.5919, 'grad_norm': 0.9820370674133301, 'learning_rate': 1.8775510204081632e-05, 'epoch': 5.6}


 91%|█████████ | 455/500 [01:55<00:11,  4.04it/s]

{'loss': 0.9968, 'grad_norm': 0.991390585899353, 'learning_rate': 1.836734693877551e-05, 'epoch': 5.62}


 91%|█████████ | 456/500 [01:56<00:11,  3.99it/s]

{'loss': 1.1522, 'grad_norm': 0.6191473007202148, 'learning_rate': 1.7959183673469387e-05, 'epoch': 5.63}


 91%|█████████▏| 457/500 [01:56<00:10,  4.00it/s]

{'loss': 1.0184, 'grad_norm': 0.605698823928833, 'learning_rate': 1.7551020408163266e-05, 'epoch': 5.64}


 92%|█████████▏| 458/500 [01:56<00:10,  4.02it/s]

{'loss': 0.9371, 'grad_norm': 0.5277153253555298, 'learning_rate': 1.7142857142857145e-05, 'epoch': 5.65}


 92%|█████████▏| 459/500 [01:56<00:10,  4.02it/s]

{'loss': 0.9541, 'grad_norm': 0.5089048147201538, 'learning_rate': 1.673469387755102e-05, 'epoch': 5.67}


 92%|█████████▏| 460/500 [01:57<00:09,  4.04it/s]

{'loss': 1.3525, 'grad_norm': 0.6765568256378174, 'learning_rate': 1.6326530612244897e-05, 'epoch': 5.68}


 92%|█████████▏| 461/500 [01:57<00:09,  4.05it/s]

{'loss': 0.9812, 'grad_norm': 0.6300511956214905, 'learning_rate': 1.5918367346938776e-05, 'epoch': 5.69}


 92%|█████████▏| 462/500 [01:57<00:09,  4.02it/s]

{'loss': 0.9249, 'grad_norm': 0.589173436164856, 'learning_rate': 1.5510204081632655e-05, 'epoch': 5.7}


 93%|█████████▎| 463/500 [01:57<00:09,  4.03it/s]

{'loss': 1.1208, 'grad_norm': 0.7077599167823792, 'learning_rate': 1.5102040816326532e-05, 'epoch': 5.72}


 93%|█████████▎| 464/500 [01:58<00:08,  4.00it/s]

{'loss': 1.0823, 'grad_norm': 0.6846060156822205, 'learning_rate': 1.469387755102041e-05, 'epoch': 5.73}


 93%|█████████▎| 465/500 [01:58<00:08,  4.02it/s]

{'loss': 1.0237, 'grad_norm': 0.6358457803726196, 'learning_rate': 1.4285714285714285e-05, 'epoch': 5.74}


 93%|█████████▎| 466/500 [01:58<00:08,  4.01it/s]

{'loss': 0.9581, 'grad_norm': 0.657516360282898, 'learning_rate': 1.3877551020408165e-05, 'epoch': 5.75}


 93%|█████████▎| 467/500 [01:58<00:08,  4.01it/s]

{'loss': 1.0021, 'grad_norm': 0.6042182445526123, 'learning_rate': 1.3469387755102042e-05, 'epoch': 5.77}


 94%|█████████▎| 468/500 [01:59<00:07,  4.02it/s]

{'loss': 1.1247, 'grad_norm': 0.6976016163825989, 'learning_rate': 1.306122448979592e-05, 'epoch': 5.78}


 94%|█████████▍| 469/500 [01:59<00:07,  4.00it/s]

{'loss': 0.9091, 'grad_norm': 0.5635151863098145, 'learning_rate': 1.2653061224489795e-05, 'epoch': 5.79}


 94%|█████████▍| 470/500 [01:59<00:07,  3.95it/s]

{'loss': 0.9711, 'grad_norm': 0.6473380923271179, 'learning_rate': 1.2244897959183674e-05, 'epoch': 5.8}


 94%|█████████▍| 471/500 [01:59<00:07,  3.97it/s]

{'loss': 1.492, 'grad_norm': 0.7519372701644897, 'learning_rate': 1.1836734693877552e-05, 'epoch': 5.81}


 94%|█████████▍| 472/500 [02:00<00:07,  3.95it/s]

{'loss': 1.2675, 'grad_norm': 0.6353920102119446, 'learning_rate': 1.1428571428571429e-05, 'epoch': 5.83}


 95%|█████████▍| 473/500 [02:00<00:06,  3.98it/s]

{'loss': 0.9271, 'grad_norm': 0.5736612677574158, 'learning_rate': 1.1020408163265306e-05, 'epoch': 5.84}


 95%|█████████▍| 474/500 [02:00<00:06,  3.97it/s]

{'loss': 1.2207, 'grad_norm': 0.6310171484947205, 'learning_rate': 1.0612244897959184e-05, 'epoch': 5.85}


 95%|█████████▌| 475/500 [02:00<00:06,  4.01it/s]

{'loss': 1.0533, 'grad_norm': 0.6180223226547241, 'learning_rate': 1.0204081632653061e-05, 'epoch': 5.86}


 95%|█████████▌| 476/500 [02:01<00:05,  4.02it/s]

{'loss': 0.7291, 'grad_norm': 0.5512651205062866, 'learning_rate': 9.795918367346939e-06, 'epoch': 5.88}


 95%|█████████▌| 477/500 [02:01<00:05,  4.00it/s]

{'loss': 1.0294, 'grad_norm': 0.5967519283294678, 'learning_rate': 9.387755102040816e-06, 'epoch': 5.89}


 96%|█████████▌| 478/500 [02:01<00:05,  3.89it/s]

{'loss': 1.1257, 'grad_norm': 0.6464192271232605, 'learning_rate': 8.979591836734694e-06, 'epoch': 5.9}


 96%|█████████▌| 479/500 [02:01<00:05,  3.81it/s]

{'loss': 1.0189, 'grad_norm': 0.6007264852523804, 'learning_rate': 8.571428571428573e-06, 'epoch': 5.91}


 96%|█████████▌| 480/500 [02:02<00:05,  3.86it/s]

{'loss': 1.3914, 'grad_norm': 0.6176788210868835, 'learning_rate': 8.163265306122448e-06, 'epoch': 5.93}


 96%|█████████▌| 481/500 [02:02<00:04,  3.88it/s]

{'loss': 1.2462, 'grad_norm': 0.6428161263465881, 'learning_rate': 7.755102040816327e-06, 'epoch': 5.94}


 96%|█████████▋| 482/500 [02:02<00:04,  3.91it/s]

{'loss': 0.9806, 'grad_norm': 0.6033040881156921, 'learning_rate': 7.346938775510205e-06, 'epoch': 5.95}


 97%|█████████▋| 483/500 [02:02<00:04,  3.95it/s]

{'loss': 0.9316, 'grad_norm': 0.5347486734390259, 'learning_rate': 6.938775510204082e-06, 'epoch': 5.96}


 97%|█████████▋| 484/500 [02:03<00:04,  3.96it/s]

{'loss': 1.393, 'grad_norm': 0.6629508137702942, 'learning_rate': 6.53061224489796e-06, 'epoch': 5.98}


 97%|█████████▋| 485/500 [02:03<00:03,  3.97it/s]

{'loss': 1.0085, 'grad_norm': 0.6137326955795288, 'learning_rate': 6.122448979591837e-06, 'epoch': 5.99}


 97%|█████████▋| 486/500 [02:03<00:03,  4.17it/s]

{'loss': 1.3308, 'grad_norm': 0.9032194018363953, 'learning_rate': 5.7142857142857145e-06, 'epoch': 6.0}


 97%|█████████▋| 487/500 [02:03<00:03,  4.11it/s]

{'loss': 1.1408, 'grad_norm': 0.5805677175521851, 'learning_rate': 5.306122448979592e-06, 'epoch': 6.01}


 98%|█████████▊| 488/500 [02:04<00:02,  4.09it/s]

{'loss': 1.1637, 'grad_norm': 0.6858969926834106, 'learning_rate': 4.897959183673469e-06, 'epoch': 6.02}


 98%|█████████▊| 489/500 [02:04<00:02,  4.07it/s]

{'loss': 1.2317, 'grad_norm': 0.7328004837036133, 'learning_rate': 4.489795918367347e-06, 'epoch': 6.04}


 98%|█████████▊| 490/500 [02:04<00:02,  3.86it/s]

{'loss': 0.932, 'grad_norm': 0.5643309950828552, 'learning_rate': 4.081632653061224e-06, 'epoch': 6.05}


 98%|█████████▊| 491/500 [02:04<00:02,  3.80it/s]

{'loss': 1.1817, 'grad_norm': 0.680317223072052, 'learning_rate': 3.6734693877551024e-06, 'epoch': 6.06}


 98%|█████████▊| 492/500 [02:05<00:02,  3.69it/s]

{'loss': 0.7601, 'grad_norm': 0.6423470377922058, 'learning_rate': 3.26530612244898e-06, 'epoch': 6.07}


 99%|█████████▊| 493/500 [02:05<00:01,  3.71it/s]

{'loss': 1.0771, 'grad_norm': 0.5771646499633789, 'learning_rate': 2.8571428571428573e-06, 'epoch': 6.09}


 99%|█████████▉| 494/500 [02:05<00:01,  3.72it/s]

{'loss': 1.16, 'grad_norm': 0.7037580013275146, 'learning_rate': 2.4489795918367347e-06, 'epoch': 6.1}


 99%|█████████▉| 495/500 [02:06<00:01,  3.67it/s]

{'loss': 1.1229, 'grad_norm': 0.5673518776893616, 'learning_rate': 2.040816326530612e-06, 'epoch': 6.11}


 99%|█████████▉| 496/500 [02:06<00:01,  3.70it/s]

{'loss': 1.13, 'grad_norm': 0.7319287657737732, 'learning_rate': 1.63265306122449e-06, 'epoch': 6.12}


 99%|█████████▉| 497/500 [02:06<00:00,  3.72it/s]

{'loss': 1.4814, 'grad_norm': 0.8812659382820129, 'learning_rate': 1.2244897959183673e-06, 'epoch': 6.14}


100%|█████████▉| 498/500 [02:06<00:00,  3.57it/s]

{'loss': 1.0348, 'grad_norm': 0.7232120633125305, 'learning_rate': 8.16326530612245e-07, 'epoch': 6.15}


100%|█████████▉| 499/500 [02:07<00:00,  3.62it/s]

{'loss': 0.8365, 'grad_norm': 0.6450437903404236, 'learning_rate': 4.081632653061225e-07, 'epoch': 6.16}


100%|██████████| 500/500 [02:07<00:00,  3.50it/s]

{'loss': 1.2415, 'grad_norm': 0.7012266516685486, 'learning_rate': 0.0, 'epoch': 6.17}


100%|██████████| 500/500 [02:11<00:00,  3.79it/s]

{'train_runtime': 131.9597, 'train_samples_per_second': 60.625, 'train_steps_per_second': 3.789, 'train_loss': 1.472257623553276, 'epoch': 6.17}





TrainOutput(global_step=500, training_loss=1.472257623553276, metrics={'train_runtime': 131.9597, 'train_samples_per_second': 60.625, 'train_steps_per_second': 3.789, 'total_flos': 262511987392512.0, 'train_loss': 1.472257623553276, 'epoch': 6.172839506172839})

In [20]:
torch.save(model.state_dict(), 'lora.pt')

In [21]:
xcacs

NameError: name 'xcacs' is not defined

In [24]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import get_peft_model, LoraConfig, TaskType

cache_dir = "models"
modelID = "openai-community/gpt2"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(modelID, cache_dir=cache_dir)

# Set padding token
tokenizer.padding_side = "right"               # Set padding side to left
tokenizer.pad_token = tokenizer.eos_token      # Using eos_token as pad_token

model = AutoModelForCausalLM.from_pretrained(modelID, device_map='auto', cache_dir=cache_dir)

model.config.pad_token_id = tokenizer.pad_token_id  # Set the pad_token_id in the model config

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
model = model.to(device)
model.load_state_dict(torch.load("lora.pt", map_location=device))

# Function for inference
def generate_job_title(input_text):
    # Prepare the input
    input_text = f"{input_text}"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')

    # Create attention mask
    attention_mask = torch.ones(input_ids.shape, dtype=torch.long)  # 1 for real tokens, 0 for padding

    # Move tensors to the device (GPU if available)
    input_ids = input_ids.to(model.device)
    attention_mask = attention_mask.to(model.device)

    # Generate text with the model
    generated_outputs = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=input_ids.shape[-1] + 20,  # Max length is input length + 40
        num_return_sequences=1,
        do_sample=True,                       # Enable sampling
        temperature=0.9,                     # Control randomness
        top_k=50,                            # Top-K sampling
        top_p=0.95,                          # Top-P (nucleus) sampling
        pad_token_id=tokenizer.eos_token_id  # Padding token id
    )

    # Decode and return the generated text
    generated_text = tokenizer.decode(generated_outputs[0], skip_special_tokens=True)
    return generated_text

input_text = """Extract the job title from the provided text\ntext: battery state estimation engineer affirmative for women\njob title:"""
input_text = """Extract the job title from the provided text\ntext: senior devops engineer 2 jobs max in last 5 years candidates from consulting firms will be rejected\njob title:"""
input_text = """Extract the job title from the provided text\ntext: senior lighting product development engineer and product line manager\njob title:"""

# input_text = "Job Title: pharmacist at community pharmacy => Extracted Title: "
predicted_title = generate_job_title(input_text)
print(f"{predicted_title}")


Extract the job title from the provided text
text: senior lighting product development engineer and product line manager
job title: lighting product development engineer <STOP> <STOP> <STOP> <STOP>


In [None]:
xasdas

In [26]:
import torch
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model

cache_dir = "models"
modelID = "openai-community/gpt2"

device = 'cuda' if torch.cuda.is_available() else 'cpu'

tokenizer = AutoTokenizer.from_pretrained(modelID, cache_dir=cache_dir)

# Set padding token
tokenizer.padding_side = "left"                # Set padding side to left
tokenizer.pad_token = tokenizer.eos_token      # Using eos_token as pad_token

model = AutoModelForCausalLM.from_pretrained(modelID, device_map='auto', cache_dir=cache_dir)

model.config.pad_token_id = tokenizer.pad_token_id  # Set the pad_token_id in the model config

config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
model = model.to(device)
model.load_state_dict(torch.load("lora.pt", map_location=device))

with torch.no_grad():
    batch = tokenizer("""Extract the job title from the provided text
text: senior devops engineer 2 jobs max in last 5 years candidates from consulting firms will be rejected
job title:"""
        , return_tensors='pt').to(device)
    output_tokens = model.generate(**batch, max_new_tokens=25)

print('\n\n', tokenizer.decode(output_tokens[0], skip_special_tokens=True))

Setting `pad_token_id` to `eos_token_id`:None for open-end generation.




 Extract the job title from the provided text
text: senior devops engineer 2 jobs max in last 5 years candidates from consulting firms will be rejected
job title: devops engineer <STOP> <STOP> <STOP> <STOP> <STOP> <ST
