## Setup

In [2]:
import random
import torch
import pandas as pd
from datasets import Dataset
import peft
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)

set_seed()

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
mistral7b = 'mistralai/Mistral-7B-v0.1'
# STEP 1. Check and make sure you're using the right model and notebook here.
model_name = mistral7b

## EDA

In [4]:
df = pd.read_csv("frankenstein_chunks.csv")
df.head()

Unnamed: 0,text
0,﻿The Project Gutenberg eBook of Frankenstein; ...
1,Further corrections by Menno de Leeuw.\n\n\n**...
2,"I am already far north of London, and as I wal..."
3,Its productions and features may be without ex...
4,But supposing all these conjectures to be fals...


In [5]:
print("Dataframe Info:")
print(df.info())
print("\n")
print("Dataframe Description:")
print(df.describe())
print("\n")
print("Number of unique values in each column:")
print(df.nunique())
random_index= random.randint(0, len(df) - 1)
df.loc[random_index, 'text']

Dataframe Info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 481 entries, 0 to 480
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    481 non-null    object
dtypes: object(1)
memory usage: 3.9+ KB
None


Dataframe Description:
                                                     text
count                                                 481
unique                                                481
top     ﻿The Project Gutenberg eBook of Frankenstein; ...
freq                                                    1


Number of unique values in each column:
text    481
dtype: int64


'The thatch had fallen in, the walls were unplastered, and the\ndoor was off its hinges. I ordered it to be repaired, bought some\nfurniture, and took possession, an incident which would doubtless have\noccasioned some surprise had not all the senses of the cottagers been\nbenumbed by want and squalid poverty. As it was, I lived ungazed at\nand unmolested, hardly thanked for the pittance of food and clothes\nwhich I gave, so much does suffering blunt even the coarsest sensations\nof men.\n\nIn this retreat I devoted the morning to labour; but in the evening,\nwhen the weather permitted, I walked on the stony beach of the sea to\nlisten to the waves as they roared and dashed at my feet. It was a\nmonotonous yet ever-changing scene. I thought of Switzerland; it was\nfar different from this desolate and appalling landscape. '

In [6]:
df.isnull().sum()

text    0
dtype: int64

In [7]:
# Now we'll quickly convert this to a train/test split
from sklearn.model_selection import train_test_split
train_df, test_df = train_test_split(df, test_size=0.2)

# STEP 2. Convert the train_df and test_df from Pandas into Hugging Face Datasets
train_dataset = Dataset.from_pandas(train_df)
test_dataset = Dataset.from_pandas(test_df)


## Model Import and Tokenization

In [8]:
import bitsandbytes as bnb

print("\n\nModel is running on:\n")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# Check CUDA availability for torch
print("CUDA available for PyTorch:", torch.cuda.is_available())

# Check CUDA availability for bitsandbytes
print("CUDA available for bitsandbytes:")



Model is running on:

cuda
CUDA available for PyTorch: True
CUDA available for bitsandbytes:


In [9]:
import bitsandbytes
print("\n\nModel is running on:" + "\n")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)



Model is running on:

cuda


In [10]:
quant_config = BitsAndBytesConfig(
  # STEP 3. Pass the appropriate parameters here to 4-bit quantize the model, then instantiate the model and check what it's running on.
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16
)
model = AutoModelForCausalLM.from_pretrained(model_name, quantization_config = quant_config)

`low_cpu_mem_usage` was None, now set to True since model is quantized.
Loading checkpoint shards: 100%|██████████| 2/2 [00:21<00:00, 10.91s/it]


In [12]:
print(model.device)

cuda:0


In [14]:
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# STEP 4. Prepare the model for QLoRA. Configure LoRA for our finetuning run. Then tokenize the data.
model = prepare_model_for_kbit_training(model)
config = LoraConfig(
    r=32,
    lora_alpha=64,
    lora_dropout=0.05,
    task_type="CAUSAL_LM"
)

model = get_peft_model(model, config)
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenized_train_dataset= train_dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)
tokenized_test_dataset = test_dataset.map(lambda examples: tokenizer(examples["text"]), batched=True)

Map: 100%|██████████| 384/384 [00:00<00:00, 6426.49 examples/s]
Map: 100%|██████████| 97/97 [00:00<00:00, 8427.71 examples/s]


## Base Model Evaluation

In [15]:
def generate_text(prompt):
  device = "cuda"
  inputs = tokenizer(prompt, return_tensors="pt").to(device)
  outputs = model.generate(**inputs, max_new_tokens=100)
  output = tokenizer.decode(outputs[0], skip_special_tokens=True)
  return output

In [17]:
# STEP 5. Generate a completion with the base model for informal evaluation.
base_generation = generate_text("I'm afraid I've created a ")
base_generation

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.


"I'm afraid I've created a 2000-level problem with a 100-level solution.\n\nI'm a 2000-level problem.\n\nI'm a 2000-level problem.\n\nI'm a 2000-level problem.\n\nI'm a 2000-level problem.\n\nI'm a 2000-level problem.\n\nI'm a 2"

In [18]:
def calc_perplexity(model):
  total_perplexity = 0
  for row in test_dataset:
    inputs = tokenizer(row['text'], return_tensors="pt")
    input_ids = inputs["input_ids"]
    # Calculate the loss without updating the model
    with torch.no_grad():
        outputs = model(**inputs, labels=input_ids)
    loss = outputs.loss
    # STEP 6. Complete the equation for perplexity.
    perplexity = torch.exp(loss)
    total_perplexity += perplexity

  num_test_rows = len(test_dataset)
  avg_perplexity = total_perplexity / num_test_rows
  return avg_perplexity

base_ppl = calc_perplexity(model)
base_ppl

tensor(8.6594)

## Training

Make sure you can leave your browser open for a while. This took about 15 minutes on a Colab T4 GPU.

In [20]:
import transformers

tokenizer.pad_token = tokenizer.eos_token
model.config.use_cache = False

trainer = transformers.Trainer(
    model=model,
    train_dataset=tokenized_train_dataset,
    args=transformers.TrainingArguments(
        warmup_steps=2,
        fp16=True,
        logging_steps=1,
        save_steps=200,
        output_dir="outputs",
      # STEP 7. Configure the training arguments.
        per_device_train_batch_size=2,
        num_train_epochs=2,
        learning_rate=2e-5,
        optim="paged_adamw_8bit"
    ),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False),
)
# STEP 8. Finetune the model.
trainer.train()

  self.scaler = torch.cuda.amp.GradScaler(**kwargs)
  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
  0%|          | 1/384 [00:02<15:02,  2.36s/it]

{'loss': 2.1985, 'grad_norm': 6.324793338775635, 'learning_rate': 1e-05, 'epoch': 0.01}


  1%|          | 2/384 [00:03<10:58,  1.72s/it]

{'loss': 1.2872, 'grad_norm': 4.1828179359436035, 'learning_rate': 2e-05, 'epoch': 0.01}


  1%|          | 3/384 [00:04<09:33,  1.50s/it]

{'loss': 2.2366, 'grad_norm': 8.60242748260498, 'learning_rate': 1.9947643979057594e-05, 'epoch': 0.02}


  1%|          | 4/384 [00:06<08:51,  1.40s/it]

{'loss': 2.2372, 'grad_norm': 5.317896842956543, 'learning_rate': 1.9895287958115186e-05, 'epoch': 0.02}


  1%|▏         | 5/384 [00:07<08:32,  1.35s/it]

{'loss': 2.401, 'grad_norm': 4.972834587097168, 'learning_rate': 1.9842931937172775e-05, 'epoch': 0.03}


  2%|▏         | 6/384 [00:08<08:16,  1.31s/it]

{'loss': 2.3001, 'grad_norm': 4.754769802093506, 'learning_rate': 1.9790575916230367e-05, 'epoch': 0.03}


  2%|▏         | 7/384 [00:09<08:05,  1.29s/it]

{'loss': 2.255, 'grad_norm': 5.092869758605957, 'learning_rate': 1.973821989528796e-05, 'epoch': 0.04}


  2%|▏         | 8/384 [00:11<08:22,  1.34s/it]

{'loss': 1.9886, 'grad_norm': 8.324347496032715, 'learning_rate': 1.968586387434555e-05, 'epoch': 0.04}


  2%|▏         | 9/384 [00:12<08:10,  1.31s/it]

{'loss': 2.1838, 'grad_norm': 6.041589260101318, 'learning_rate': 1.9633507853403143e-05, 'epoch': 0.05}


  3%|▎         | 10/384 [00:13<08:07,  1.30s/it]

{'loss': 2.1153, 'grad_norm': 3.0792086124420166, 'learning_rate': 1.9581151832460736e-05, 'epoch': 0.05}


  3%|▎         | 11/384 [00:15<08:22,  1.35s/it]

{'loss': 2.0236, 'grad_norm': 5.902219772338867, 'learning_rate': 1.9528795811518328e-05, 'epoch': 0.06}


  3%|▎         | 12/384 [00:16<08:30,  1.37s/it]

{'loss': 2.2039, 'grad_norm': 4.662735462188721, 'learning_rate': 1.947643979057592e-05, 'epoch': 0.06}


  3%|▎         | 13/384 [00:18<08:37,  1.39s/it]

{'loss': 2.0808, 'grad_norm': 3.948413372039795, 'learning_rate': 1.9424083769633512e-05, 'epoch': 0.07}


  4%|▎         | 14/384 [00:19<08:43,  1.41s/it]

{'loss': 2.1154, 'grad_norm': 4.277752876281738, 'learning_rate': 1.93717277486911e-05, 'epoch': 0.07}


  4%|▍         | 15/384 [00:20<08:26,  1.37s/it]

{'loss': 2.2586, 'grad_norm': 5.6784443855285645, 'learning_rate': 1.9319371727748693e-05, 'epoch': 0.08}


  4%|▍         | 16/384 [00:22<08:15,  1.35s/it]

{'loss': 2.1785, 'grad_norm': nan, 'learning_rate': 1.9319371727748693e-05, 'epoch': 0.08}


  4%|▍         | 17/384 [00:23<08:03,  1.32s/it]

{'loss': 1.3108, 'grad_norm': 6.273836135864258, 'learning_rate': 1.9267015706806285e-05, 'epoch': 0.09}


  5%|▍         | 18/384 [00:24<07:57,  1.30s/it]

{'loss': 2.065, 'grad_norm': 5.976648807525635, 'learning_rate': 1.9214659685863877e-05, 'epoch': 0.09}


  5%|▍         | 19/384 [00:26<08:12,  1.35s/it]

{'loss': 2.0812, 'grad_norm': 5.787518501281738, 'learning_rate': 1.9162303664921466e-05, 'epoch': 0.1}


  5%|▌         | 20/384 [00:27<08:04,  1.33s/it]

{'loss': 1.7196, 'grad_norm': 4.83299446105957, 'learning_rate': 1.9109947643979058e-05, 'epoch': 0.1}


  5%|▌         | 21/384 [00:28<07:57,  1.32s/it]

{'loss': 2.1501, 'grad_norm': 4.245680332183838, 'learning_rate': 1.905759162303665e-05, 'epoch': 0.11}


  6%|▌         | 22/384 [00:30<08:09,  1.35s/it]

{'loss': 1.9646, 'grad_norm': 4.7882866859436035, 'learning_rate': 1.9005235602094243e-05, 'epoch': 0.11}


  6%|▌         | 23/384 [00:31<07:58,  1.33s/it]

{'loss': 2.1283, 'grad_norm': 5.1942901611328125, 'learning_rate': 1.895287958115183e-05, 'epoch': 0.12}


  6%|▋         | 24/384 [00:32<07:49,  1.30s/it]

{'loss': 2.1705, 'grad_norm': 4.441123008728027, 'learning_rate': 1.8900523560209423e-05, 'epoch': 0.12}


  7%|▋         | 25/384 [00:33<07:46,  1.30s/it]

{'loss': 2.1482, 'grad_norm': 3.566265821456909, 'learning_rate': 1.8848167539267016e-05, 'epoch': 0.13}


  7%|▋         | 26/384 [00:35<07:42,  1.29s/it]

{'loss': 2.044, 'grad_norm': 6.620034694671631, 'learning_rate': 1.8795811518324608e-05, 'epoch': 0.14}


  7%|▋         | 27/384 [00:36<07:57,  1.34s/it]

{'loss': 2.1528, 'grad_norm': 4.818078994750977, 'learning_rate': 1.87434554973822e-05, 'epoch': 0.14}


  7%|▋         | 28/384 [00:38<08:07,  1.37s/it]

{'loss': 2.2164, 'grad_norm': 5.686012268066406, 'learning_rate': 1.8691099476439792e-05, 'epoch': 0.15}


  8%|▊         | 29/384 [00:39<07:55,  1.34s/it]

{'loss': 2.102, 'grad_norm': 4.8455891609191895, 'learning_rate': 1.8638743455497384e-05, 'epoch': 0.15}


  8%|▊         | 30/384 [00:40<08:03,  1.37s/it]

{'loss': 2.0609, 'grad_norm': 5.550879955291748, 'learning_rate': 1.8586387434554976e-05, 'epoch': 0.16}


  8%|▊         | 31/384 [00:42<07:50,  1.33s/it]

{'loss': 1.8921, 'grad_norm': 4.42840051651001, 'learning_rate': 1.853403141361257e-05, 'epoch': 0.16}


  8%|▊         | 32/384 [00:43<07:44,  1.32s/it]

{'loss': 1.4567, 'grad_norm': 5.085615158081055, 'learning_rate': 1.8481675392670157e-05, 'epoch': 0.17}


  9%|▊         | 33/384 [00:44<07:38,  1.31s/it]

{'loss': 1.5427, 'grad_norm': 7.930014133453369, 'learning_rate': 1.842931937172775e-05, 'epoch': 0.17}


  9%|▉         | 34/384 [00:46<07:49,  1.34s/it]

{'loss': 2.1994, 'grad_norm': 5.270382881164551, 'learning_rate': 1.837696335078534e-05, 'epoch': 0.18}


  9%|▉         | 35/384 [00:47<07:40,  1.32s/it]

{'loss': 2.1613, 'grad_norm': 4.795782566070557, 'learning_rate': 1.8324607329842934e-05, 'epoch': 0.18}


  9%|▉         | 36/384 [00:48<07:31,  1.30s/it]

{'loss': 1.8096, 'grad_norm': 5.824542045593262, 'learning_rate': 1.8272251308900526e-05, 'epoch': 0.19}


 10%|▉         | 37/384 [00:49<07:29,  1.29s/it]

{'loss': 2.1213, 'grad_norm': 6.981709003448486, 'learning_rate': 1.8219895287958115e-05, 'epoch': 0.19}


 10%|▉         | 38/384 [00:51<07:46,  1.35s/it]

{'loss': 1.4474, 'grad_norm': 6.010638236999512, 'learning_rate': 1.8167539267015707e-05, 'epoch': 0.2}


 10%|█         | 39/384 [00:52<07:38,  1.33s/it]

{'loss': 2.0112, 'grad_norm': 4.714521408081055, 'learning_rate': 1.81151832460733e-05, 'epoch': 0.2}


 10%|█         | 40/384 [00:54<07:50,  1.37s/it]

{'loss': 2.1671, 'grad_norm': 4.109230995178223, 'learning_rate': 1.806282722513089e-05, 'epoch': 0.21}


 11%|█         | 41/384 [00:55<07:57,  1.39s/it]

{'loss': 2.1709, 'grad_norm': 5.104012489318848, 'learning_rate': 1.8010471204188483e-05, 'epoch': 0.21}


 11%|█         | 42/384 [00:56<08:01,  1.41s/it]

{'loss': 1.9615, 'grad_norm': 5.769094944000244, 'learning_rate': 1.7958115183246076e-05, 'epoch': 0.22}


 11%|█         | 43/384 [00:58<07:44,  1.36s/it]

{'loss': 1.5914, 'grad_norm': 16.594453811645508, 'learning_rate': 1.7905759162303668e-05, 'epoch': 0.22}


 11%|█▏        | 44/384 [00:59<07:51,  1.39s/it]

{'loss': 2.2444, 'grad_norm': 4.788738250732422, 'learning_rate': 1.785340314136126e-05, 'epoch': 0.23}


 12%|█▏        | 45/384 [01:00<07:38,  1.35s/it]

{'loss': 1.9791, 'grad_norm': 6.28133487701416, 'learning_rate': 1.7801047120418852e-05, 'epoch': 0.23}


 12%|█▏        | 46/384 [01:02<07:29,  1.33s/it]

{'loss': 1.7413, 'grad_norm': 6.157125473022461, 'learning_rate': 1.774869109947644e-05, 'epoch': 0.24}


 12%|█▏        | 47/384 [01:03<07:38,  1.36s/it]

{'loss': 1.3588, 'grad_norm': 5.066906452178955, 'learning_rate': 1.7696335078534033e-05, 'epoch': 0.24}


 12%|█▎        | 48/384 [01:04<07:28,  1.33s/it]

{'loss': 0.7226, 'grad_norm': 4.645292282104492, 'learning_rate': 1.7643979057591625e-05, 'epoch': 0.25}


 13%|█▎        | 49/384 [01:06<07:21,  1.32s/it]

{'loss': 2.2463, 'grad_norm': 5.277648448944092, 'learning_rate': 1.7591623036649217e-05, 'epoch': 0.26}


 13%|█▎        | 50/384 [01:07<07:17,  1.31s/it]

{'loss': 2.001, 'grad_norm': 4.58256721496582, 'learning_rate': 1.7539267015706806e-05, 'epoch': 0.26}


 13%|█▎        | 51/384 [01:08<07:28,  1.35s/it]

{'loss': 2.2299, 'grad_norm': 4.446080684661865, 'learning_rate': 1.7486910994764398e-05, 'epoch': 0.27}


 14%|█▎        | 52/384 [01:10<07:20,  1.33s/it]

{'loss': 1.9652, 'grad_norm': 4.2965087890625, 'learning_rate': 1.743455497382199e-05, 'epoch': 0.27}


 14%|█▍        | 53/384 [01:11<07:34,  1.37s/it]

{'loss': 1.6338, 'grad_norm': 6.280250072479248, 'learning_rate': 1.7382198952879583e-05, 'epoch': 0.28}


 14%|█▍        | 54/384 [01:13<07:38,  1.39s/it]

{'loss': 2.2571, 'grad_norm': 4.260898113250732, 'learning_rate': 1.7329842931937175e-05, 'epoch': 0.28}


 14%|█▍        | 55/384 [01:14<07:41,  1.40s/it]

{'loss': 2.0122, 'grad_norm': 13.4190673828125, 'learning_rate': 1.7277486910994767e-05, 'epoch': 0.29}


 15%|█▍        | 56/384 [01:16<08:14,  1.51s/it]

{'loss': 2.016, 'grad_norm': 5.135417461395264, 'learning_rate': 1.7225130890052356e-05, 'epoch': 0.29}


 15%|█▍        | 57/384 [01:17<08:08,  1.50s/it]

{'loss': 2.0947, 'grad_norm': 5.15279483795166, 'learning_rate': 1.7172774869109948e-05, 'epoch': 0.3}


 15%|█▌        | 58/384 [01:19<08:15,  1.52s/it]

{'loss': 1.2367, 'grad_norm': 5.656447887420654, 'learning_rate': 1.712041884816754e-05, 'epoch': 0.3}


 15%|█▌        | 59/384 [01:20<07:56,  1.47s/it]

{'loss': 2.1177, 'grad_norm': 4.428615570068359, 'learning_rate': 1.7068062827225132e-05, 'epoch': 0.31}


 16%|█▌        | 60/384 [01:22<07:41,  1.42s/it]

{'loss': 1.8503, 'grad_norm': 6.662624835968018, 'learning_rate': 1.7015706806282724e-05, 'epoch': 0.31}


 16%|█▌        | 61/384 [01:23<07:42,  1.43s/it]

{'loss': 1.7515, 'grad_norm': 6.948648452758789, 'learning_rate': 1.6963350785340316e-05, 'epoch': 0.32}


 16%|█▌        | 62/384 [01:24<07:27,  1.39s/it]

{'loss': 1.9245, 'grad_norm': 5.782127380371094, 'learning_rate': 1.691099476439791e-05, 'epoch': 0.32}


 16%|█▋        | 63/384 [01:26<07:30,  1.40s/it]

{'loss': 1.8796, 'grad_norm': 3.7058980464935303, 'learning_rate': 1.6858638743455497e-05, 'epoch': 0.33}


 17%|█▋        | 64/384 [01:27<07:33,  1.42s/it]

{'loss': 2.1566, 'grad_norm': 4.965316295623779, 'learning_rate': 1.680628272251309e-05, 'epoch': 0.33}


 17%|█▋        | 65/384 [01:28<07:18,  1.38s/it]

{'loss': 2.0981, 'grad_norm': 5.203216552734375, 'learning_rate': 1.675392670157068e-05, 'epoch': 0.34}


 17%|█▋        | 66/384 [01:30<07:08,  1.35s/it]

{'loss': 1.9894, 'grad_norm': 5.453594207763672, 'learning_rate': 1.6701570680628274e-05, 'epoch': 0.34}


 17%|█▋        | 67/384 [01:31<07:01,  1.33s/it]

{'loss': 1.3274, 'grad_norm': 5.380238056182861, 'learning_rate': 1.6649214659685866e-05, 'epoch': 0.35}


 18%|█▊        | 68/384 [01:32<06:56,  1.32s/it]

{'loss': 2.0679, 'grad_norm': 4.728832721710205, 'learning_rate': 1.6596858638743455e-05, 'epoch': 0.35}


 18%|█▊        | 69/384 [01:34<06:53,  1.31s/it]

{'loss': 1.6151, 'grad_norm': 6.776437282562256, 'learning_rate': 1.6544502617801047e-05, 'epoch': 0.36}


 18%|█▊        | 70/384 [01:35<06:46,  1.30s/it]

{'loss': 1.5034, 'grad_norm': 6.074085235595703, 'learning_rate': 1.649214659685864e-05, 'epoch': 0.36}


 18%|█▊        | 71/384 [01:36<06:59,  1.34s/it]

{'loss': 1.3531, 'grad_norm': 8.119874000549316, 'learning_rate': 1.643979057591623e-05, 'epoch': 0.37}


 19%|█▉        | 72/384 [01:38<06:51,  1.32s/it]

{'loss': 2.1217, 'grad_norm': 4.894474983215332, 'learning_rate': 1.6387434554973823e-05, 'epoch': 0.38}


 19%|█▉        | 73/384 [01:39<06:44,  1.30s/it]

{'loss': 2.0873, 'grad_norm': 4.933759689331055, 'learning_rate': 1.6335078534031416e-05, 'epoch': 0.38}


 19%|█▉        | 74/384 [01:40<06:40,  1.29s/it]

{'loss': 1.6382, 'grad_norm': 7.60133695602417, 'learning_rate': 1.6282722513089008e-05, 'epoch': 0.39}


 20%|█▉        | 75/384 [01:42<06:52,  1.33s/it]

{'loss': 2.1713, 'grad_norm': 6.205070495605469, 'learning_rate': 1.62303664921466e-05, 'epoch': 0.39}


 20%|█▉        | 76/384 [01:43<07:01,  1.37s/it]

{'loss': 1.8531, 'grad_norm': 4.810950756072998, 'learning_rate': 1.6178010471204192e-05, 'epoch': 0.4}


 20%|██        | 77/384 [01:44<07:07,  1.39s/it]

{'loss': 1.9037, 'grad_norm': 5.1281352043151855, 'learning_rate': 1.612565445026178e-05, 'epoch': 0.4}


 20%|██        | 78/384 [01:46<06:56,  1.36s/it]

{'loss': 1.8772, 'grad_norm': 5.798677921295166, 'learning_rate': 1.6073298429319373e-05, 'epoch': 0.41}


 21%|██        | 79/384 [01:47<06:47,  1.34s/it]

{'loss': 1.7304, 'grad_norm': 5.651444435119629, 'learning_rate': 1.6020942408376965e-05, 'epoch': 0.41}


 21%|██        | 80/384 [01:48<06:41,  1.32s/it]

{'loss': 1.4595, 'grad_norm': 8.950423240661621, 'learning_rate': 1.5968586387434557e-05, 'epoch': 0.42}


 21%|██        | 81/384 [01:50<06:52,  1.36s/it]

{'loss': 1.7093, 'grad_norm': 6.481417179107666, 'learning_rate': 1.5916230366492146e-05, 'epoch': 0.42}


 21%|██▏       | 82/384 [01:51<06:41,  1.33s/it]

{'loss': 1.6214, 'grad_norm': 9.144989013671875, 'learning_rate': 1.5863874345549738e-05, 'epoch': 0.43}


 22%|██▏       | 83/384 [01:52<06:52,  1.37s/it]

{'loss': 1.9072, 'grad_norm': 3.9141485691070557, 'learning_rate': 1.581151832460733e-05, 'epoch': 0.43}


 22%|██▏       | 84/384 [01:54<06:42,  1.34s/it]

{'loss': 2.0626, 'grad_norm': 4.255926132202148, 'learning_rate': 1.5759162303664923e-05, 'epoch': 0.44}


 22%|██▏       | 85/384 [01:55<06:50,  1.37s/it]

{'loss': 2.05, 'grad_norm': 6.577460765838623, 'learning_rate': 1.5706806282722515e-05, 'epoch': 0.44}


 22%|██▏       | 86/384 [01:56<06:41,  1.35s/it]

{'loss': 2.2692, 'grad_norm': 5.977269649505615, 'learning_rate': 1.5654450261780107e-05, 'epoch': 0.45}


 23%|██▎       | 87/384 [01:58<06:36,  1.34s/it]

{'loss': 1.4334, 'grad_norm': 6.545283317565918, 'learning_rate': 1.56020942408377e-05, 'epoch': 0.45}


 23%|██▎       | 88/384 [01:59<06:31,  1.32s/it]

{'loss': 1.6511, 'grad_norm': 8.088180541992188, 'learning_rate': 1.554973821989529e-05, 'epoch': 0.46}


 23%|██▎       | 89/384 [02:00<06:28,  1.32s/it]

{'loss': 1.9694, 'grad_norm': 4.179590702056885, 'learning_rate': 1.5497382198952883e-05, 'epoch': 0.46}


 23%|██▎       | 90/384 [02:02<06:24,  1.31s/it]

{'loss': 1.9089, 'grad_norm': 4.619419097900391, 'learning_rate': 1.5445026178010472e-05, 'epoch': 0.47}


 24%|██▎       | 91/384 [02:03<06:20,  1.30s/it]

{'loss': 2.0008, 'grad_norm': 4.098575592041016, 'learning_rate': 1.5392670157068064e-05, 'epoch': 0.47}


 24%|██▍       | 92/384 [02:04<06:17,  1.29s/it]

{'loss': 2.1195, 'grad_norm': 6.071103572845459, 'learning_rate': 1.5340314136125656e-05, 'epoch': 0.48}


 24%|██▍       | 93/384 [02:06<06:43,  1.39s/it]

{'loss': 1.7122, 'grad_norm': 11.191338539123535, 'learning_rate': 1.528795811518325e-05, 'epoch': 0.48}


 24%|██▍       | 94/384 [02:07<06:33,  1.36s/it]

{'loss': 1.5345, 'grad_norm': 7.632675647735596, 'learning_rate': 1.523560209424084e-05, 'epoch': 0.49}


 25%|██▍       | 95/384 [02:08<06:26,  1.34s/it]

{'loss': 2.4132, 'grad_norm': 5.531858921051025, 'learning_rate': 1.518324607329843e-05, 'epoch': 0.49}


 25%|██▌       | 96/384 [02:10<06:20,  1.32s/it]

{'loss': 1.7674, 'grad_norm': 5.383974075317383, 'learning_rate': 1.5130890052356022e-05, 'epoch': 0.5}


 25%|██▌       | 97/384 [02:11<06:14,  1.31s/it]

{'loss': 2.135, 'grad_norm': 5.348142147064209, 'learning_rate': 1.5078534031413614e-05, 'epoch': 0.51}


 26%|██▌       | 98/384 [02:12<06:10,  1.30s/it]

{'loss': 2.3841, 'grad_norm': 5.595819473266602, 'learning_rate': 1.5026178010471206e-05, 'epoch': 0.51}


 26%|██▌       | 99/384 [02:14<06:22,  1.34s/it]

{'loss': 1.9146, 'grad_norm': 5.46187162399292, 'learning_rate': 1.4973821989528796e-05, 'epoch': 0.52}


 26%|██▌       | 100/384 [02:15<06:31,  1.38s/it]

{'loss': 1.308, 'grad_norm': 6.659438133239746, 'learning_rate': 1.4921465968586389e-05, 'epoch': 0.52}


 26%|██▋       | 101/384 [02:16<06:22,  1.35s/it]

{'loss': 1.9579, 'grad_norm': 4.831873416900635, 'learning_rate': 1.486910994764398e-05, 'epoch': 0.53}


 27%|██▋       | 102/384 [02:18<06:30,  1.38s/it]

{'loss': 1.8895, 'grad_norm': 6.342447280883789, 'learning_rate': 1.4816753926701573e-05, 'epoch': 0.53}


 27%|██▋       | 103/384 [02:19<06:33,  1.40s/it]

{'loss': 1.8016, 'grad_norm': 5.485368251800537, 'learning_rate': 1.4764397905759162e-05, 'epoch': 0.54}


 27%|██▋       | 104/384 [02:21<06:22,  1.37s/it]

{'loss': 1.8206, 'grad_norm': 5.840073585510254, 'learning_rate': 1.4712041884816754e-05, 'epoch': 0.54}


 27%|██▋       | 105/384 [02:22<06:27,  1.39s/it]

{'loss': 2.0841, 'grad_norm': 4.726497650146484, 'learning_rate': 1.4659685863874346e-05, 'epoch': 0.55}


 28%|██▊       | 106/384 [02:23<06:15,  1.35s/it]

{'loss': 2.1482, 'grad_norm': 5.7779741287231445, 'learning_rate': 1.4607329842931938e-05, 'epoch': 0.55}


 28%|██▊       | 107/384 [02:25<06:05,  1.32s/it]

{'loss': 2.2411, 'grad_norm': 5.325514793395996, 'learning_rate': 1.455497382198953e-05, 'epoch': 0.56}


 28%|██▊       | 108/384 [02:26<06:01,  1.31s/it]

{'loss': 0.8779, 'grad_norm': 8.681464195251465, 'learning_rate': 1.450261780104712e-05, 'epoch': 0.56}


 28%|██▊       | 109/384 [02:27<05:58,  1.30s/it]

{'loss': 1.9241, 'grad_norm': 5.524139881134033, 'learning_rate': 1.4450261780104713e-05, 'epoch': 0.57}


 29%|██▊       | 110/384 [02:28<05:55,  1.30s/it]

{'loss': 1.7526, 'grad_norm': 7.29537296295166, 'learning_rate': 1.4397905759162305e-05, 'epoch': 0.57}


 29%|██▉       | 111/384 [02:30<05:50,  1.28s/it]

{'loss': 1.1996, 'grad_norm': 6.076169490814209, 'learning_rate': 1.4345549738219897e-05, 'epoch': 0.58}


 29%|██▉       | 112/384 [02:31<05:48,  1.28s/it]

{'loss': 2.0869, 'grad_norm': 6.092881202697754, 'learning_rate': 1.4293193717277488e-05, 'epoch': 0.58}


 29%|██▉       | 113/384 [02:32<05:46,  1.28s/it]

{'loss': 1.8207, 'grad_norm': 5.230616092681885, 'learning_rate': 1.424083769633508e-05, 'epoch': 0.59}


 30%|██▉       | 114/384 [02:34<05:59,  1.33s/it]

{'loss': 2.0165, 'grad_norm': 4.310360431671143, 'learning_rate': 1.4188481675392672e-05, 'epoch': 0.59}


 30%|██▉       | 115/384 [02:35<05:54,  1.32s/it]

{'loss': 2.2273, 'grad_norm': 6.9682183265686035, 'learning_rate': 1.4136125654450264e-05, 'epoch': 0.6}


 30%|███       | 116/384 [02:37<06:28,  1.45s/it]

{'loss': 2.1981, 'grad_norm': 5.520657062530518, 'learning_rate': 1.4083769633507855e-05, 'epoch': 0.6}


 30%|███       | 117/384 [02:38<06:26,  1.45s/it]

{'loss': 2.1431, 'grad_norm': 5.488803863525391, 'learning_rate': 1.4031413612565445e-05, 'epoch': 0.61}


 31%|███       | 118/384 [02:40<06:21,  1.44s/it]

{'loss': 1.9875, 'grad_norm': 6.113977432250977, 'learning_rate': 1.3979057591623037e-05, 'epoch': 0.61}


 31%|███       | 119/384 [02:41<06:11,  1.40s/it]

{'loss': 2.1218, 'grad_norm': 7.8866286277771, 'learning_rate': 1.392670157068063e-05, 'epoch': 0.62}


 31%|███▏      | 120/384 [02:42<06:02,  1.37s/it]

{'loss': 1.6218, 'grad_norm': 6.795436859130859, 'learning_rate': 1.3874345549738222e-05, 'epoch': 0.62}


 32%|███▏      | 121/384 [02:43<05:51,  1.34s/it]

{'loss': 2.3569, 'grad_norm': 5.125807285308838, 'learning_rate': 1.3821989528795812e-05, 'epoch': 0.63}


 32%|███▏      | 122/384 [02:45<05:45,  1.32s/it]

{'loss': 1.955, 'grad_norm': 4.465867519378662, 'learning_rate': 1.3769633507853404e-05, 'epoch': 0.64}


 32%|███▏      | 123/384 [02:46<05:41,  1.31s/it]

{'loss': 1.9996, 'grad_norm': 4.693528175354004, 'learning_rate': 1.3717277486910996e-05, 'epoch': 0.64}


 32%|███▏      | 124/384 [02:47<05:52,  1.35s/it]

{'loss': 1.2393, 'grad_norm': 4.331855773925781, 'learning_rate': 1.3664921465968589e-05, 'epoch': 0.65}


 33%|███▎      | 125/384 [02:49<05:43,  1.33s/it]

{'loss': 1.458, 'grad_norm': 6.437409400939941, 'learning_rate': 1.361256544502618e-05, 'epoch': 0.65}


 33%|███▎      | 126/384 [02:50<05:38,  1.31s/it]

{'loss': 1.5432, 'grad_norm': 6.595175266265869, 'learning_rate': 1.356020942408377e-05, 'epoch': 0.66}


 33%|███▎      | 127/384 [02:51<05:48,  1.36s/it]

{'loss': 2.0461, 'grad_norm': 4.691247463226318, 'learning_rate': 1.3507853403141362e-05, 'epoch': 0.66}


 33%|███▎      | 128/384 [02:53<05:41,  1.33s/it]

{'loss': 1.445, 'grad_norm': 14.312119483947754, 'learning_rate': 1.3455497382198954e-05, 'epoch': 0.67}


 34%|███▎      | 129/384 [02:54<05:36,  1.32s/it]

{'loss': 2.0032, 'grad_norm': 5.836889266967773, 'learning_rate': 1.3403141361256546e-05, 'epoch': 0.67}


 34%|███▍      | 130/384 [02:55<05:32,  1.31s/it]

{'loss': 1.5517, 'grad_norm': 7.76394510269165, 'learning_rate': 1.3350785340314136e-05, 'epoch': 0.68}


 34%|███▍      | 131/384 [02:57<05:28,  1.30s/it]

{'loss': 2.2141, 'grad_norm': 5.657382011413574, 'learning_rate': 1.3298429319371729e-05, 'epoch': 0.68}


 34%|███▍      | 132/384 [02:58<05:27,  1.30s/it]

{'loss': 1.6534, 'grad_norm': 8.329041481018066, 'learning_rate': 1.324607329842932e-05, 'epoch': 0.69}


 35%|███▍      | 133/384 [02:59<05:37,  1.35s/it]

{'loss': 2.0634, 'grad_norm': 5.842409610748291, 'learning_rate': 1.3193717277486913e-05, 'epoch': 0.69}


 35%|███▍      | 134/384 [03:01<05:33,  1.33s/it]

{'loss': 2.1573, 'grad_norm': 4.7996907234191895, 'learning_rate': 1.3141361256544505e-05, 'epoch': 0.7}


 35%|███▌      | 135/384 [03:02<05:28,  1.32s/it]

{'loss': 1.8295, 'grad_norm': 6.400012969970703, 'learning_rate': 1.3089005235602094e-05, 'epoch': 0.7}


 35%|███▌      | 136/384 [03:03<05:25,  1.31s/it]

{'loss': 1.8031, 'grad_norm': 6.648474216461182, 'learning_rate': 1.3036649214659686e-05, 'epoch': 0.71}


 36%|███▌      | 137/384 [03:05<05:21,  1.30s/it]

{'loss': 1.8157, 'grad_norm': 7.820156097412109, 'learning_rate': 1.2984293193717278e-05, 'epoch': 0.71}


 36%|███▌      | 138/384 [03:06<05:17,  1.29s/it]

{'loss': 1.9818, 'grad_norm': 6.795681476593018, 'learning_rate': 1.293193717277487e-05, 'epoch': 0.72}


 36%|███▌      | 139/384 [03:07<05:15,  1.29s/it]

{'loss': 2.0441, 'grad_norm': 5.03396463394165, 'learning_rate': 1.287958115183246e-05, 'epoch': 0.72}


 36%|███▋      | 140/384 [03:08<05:16,  1.30s/it]

{'loss': 2.1008, 'grad_norm': 5.519693374633789, 'learning_rate': 1.2827225130890053e-05, 'epoch': 0.73}


 37%|███▋      | 141/384 [03:10<05:31,  1.36s/it]

{'loss': 1.0426, 'grad_norm': 4.148629665374756, 'learning_rate': 1.2774869109947645e-05, 'epoch': 0.73}


 37%|███▋      | 142/384 [03:11<05:25,  1.35s/it]

{'loss': 1.71, 'grad_norm': 9.368799209594727, 'learning_rate': 1.2722513089005237e-05, 'epoch': 0.74}


 37%|███▋      | 143/384 [03:13<05:21,  1.34s/it]

{'loss': 1.4305, 'grad_norm': 6.897060871124268, 'learning_rate': 1.2670157068062828e-05, 'epoch': 0.74}


 38%|███▊      | 144/384 [03:14<05:16,  1.32s/it]

{'loss': 2.2293, 'grad_norm': 5.80340576171875, 'learning_rate': 1.261780104712042e-05, 'epoch': 0.75}


 38%|███▊      | 145/384 [03:15<05:26,  1.37s/it]

{'loss': 2.0018, 'grad_norm': 5.19938325881958, 'learning_rate': 1.2565445026178012e-05, 'epoch': 0.76}


 38%|███▊      | 146/384 [03:17<05:18,  1.34s/it]

{'loss': 1.6326, 'grad_norm': 8.52566146850586, 'learning_rate': 1.2513089005235604e-05, 'epoch': 0.76}


 38%|███▊      | 147/384 [03:18<05:13,  1.32s/it]

{'loss': 2.032, 'grad_norm': 5.137462139129639, 'learning_rate': 1.2460732984293196e-05, 'epoch': 0.77}


 39%|███▊      | 148/384 [03:19<05:22,  1.37s/it]

{'loss': 1.8449, 'grad_norm': 5.730991840362549, 'learning_rate': 1.2408376963350785e-05, 'epoch': 0.77}


 39%|███▉      | 149/384 [03:21<05:14,  1.34s/it]

{'loss': 1.2312, 'grad_norm': 9.195182800292969, 'learning_rate': 1.2356020942408377e-05, 'epoch': 0.78}


 39%|███▉      | 150/384 [03:22<05:09,  1.32s/it]

{'loss': 2.0121, 'grad_norm': 5.241406440734863, 'learning_rate': 1.230366492146597e-05, 'epoch': 0.78}


 39%|███▉      | 151/384 [03:23<05:03,  1.30s/it]

{'loss': 1.807, 'grad_norm': 6.0918378829956055, 'learning_rate': 1.2251308900523562e-05, 'epoch': 0.79}


 40%|███▉      | 152/384 [03:24<05:01,  1.30s/it]

{'loss': 1.9411, 'grad_norm': 5.888011932373047, 'learning_rate': 1.2198952879581152e-05, 'epoch': 0.79}


 40%|███▉      | 153/384 [03:26<05:02,  1.31s/it]

{'loss': 2.1025, 'grad_norm': 4.876124382019043, 'learning_rate': 1.2146596858638744e-05, 'epoch': 0.8}


 40%|████      | 154/384 [03:27<05:01,  1.31s/it]

{'loss': 2.0054, 'grad_norm': 5.1048359870910645, 'learning_rate': 1.2094240837696336e-05, 'epoch': 0.8}


 40%|████      | 155/384 [03:28<05:00,  1.31s/it]

{'loss': 2.2064, 'grad_norm': 5.382903575897217, 'learning_rate': 1.2041884816753929e-05, 'epoch': 0.81}


 41%|████      | 156/384 [03:30<04:58,  1.31s/it]

{'loss': 1.8802, 'grad_norm': 4.707157611846924, 'learning_rate': 1.198952879581152e-05, 'epoch': 0.81}


 41%|████      | 157/384 [03:31<05:08,  1.36s/it]

{'loss': 1.5249, 'grad_norm': 5.850284576416016, 'learning_rate': 1.193717277486911e-05, 'epoch': 0.82}


 41%|████      | 158/384 [03:32<05:04,  1.35s/it]

{'loss': 1.9585, 'grad_norm': 5.536887168884277, 'learning_rate': 1.1884816753926702e-05, 'epoch': 0.82}


 41%|████▏     | 159/384 [03:34<05:00,  1.34s/it]

{'loss': 1.6734, 'grad_norm': 7.876927852630615, 'learning_rate': 1.1832460732984294e-05, 'epoch': 0.83}


 42%|████▏     | 160/384 [03:35<04:53,  1.31s/it]

{'loss': 1.9442, 'grad_norm': 6.0363993644714355, 'learning_rate': 1.1780104712041886e-05, 'epoch': 0.83}


 42%|████▏     | 161/384 [03:36<04:51,  1.31s/it]

{'loss': 1.1183, 'grad_norm': 10.847200393676758, 'learning_rate': 1.1727748691099476e-05, 'epoch': 0.84}


 42%|████▏     | 162/384 [03:38<04:48,  1.30s/it]

{'loss': 1.7346, 'grad_norm': 6.472291946411133, 'learning_rate': 1.1675392670157069e-05, 'epoch': 0.84}


 42%|████▏     | 163/384 [03:39<04:45,  1.29s/it]

{'loss': 1.8955, 'grad_norm': 5.350700855255127, 'learning_rate': 1.162303664921466e-05, 'epoch': 0.85}


 43%|████▎     | 164/384 [03:40<04:53,  1.34s/it]

{'loss': 2.03, 'grad_norm': 5.41845703125, 'learning_rate': 1.1570680628272253e-05, 'epoch': 0.85}


 43%|████▎     | 165/384 [03:42<04:50,  1.33s/it]

{'loss': 1.6575, 'grad_norm': 6.0131120681762695, 'learning_rate': 1.1518324607329845e-05, 'epoch': 0.86}


 43%|████▎     | 166/384 [03:43<04:48,  1.32s/it]

{'loss': 2.0406, 'grad_norm': 6.715968132019043, 'learning_rate': 1.1465968586387436e-05, 'epoch': 0.86}


 43%|████▎     | 167/384 [03:44<04:42,  1.30s/it]

{'loss': 2.0977, 'grad_norm': 5.572907447814941, 'learning_rate': 1.1413612565445028e-05, 'epoch': 0.87}


 44%|████▍     | 168/384 [03:45<04:38,  1.29s/it]

{'loss': 2.272, 'grad_norm': 4.8940887451171875, 'learning_rate': 1.1361256544502618e-05, 'epoch': 0.88}


 44%|████▍     | 169/384 [03:47<04:37,  1.29s/it]

{'loss': 2.0946, 'grad_norm': 4.923325061798096, 'learning_rate': 1.130890052356021e-05, 'epoch': 0.88}


 44%|████▍     | 170/384 [03:48<04:34,  1.28s/it]

{'loss': 1.7583, 'grad_norm': 6.142749786376953, 'learning_rate': 1.12565445026178e-05, 'epoch': 0.89}


 45%|████▍     | 171/384 [03:49<04:31,  1.27s/it]

{'loss': 1.9926, 'grad_norm': 5.168422222137451, 'learning_rate': 1.1204188481675393e-05, 'epoch': 0.89}


 45%|████▍     | 172/384 [03:51<04:42,  1.33s/it]

{'loss': 2.1921, 'grad_norm': 6.300307750701904, 'learning_rate': 1.1151832460732985e-05, 'epoch': 0.9}


 45%|████▌     | 173/384 [03:52<04:39,  1.32s/it]

{'loss': 1.8968, 'grad_norm': 6.751431941986084, 'learning_rate': 1.1099476439790577e-05, 'epoch': 0.9}


 45%|████▌     | 174/384 [03:53<04:45,  1.36s/it]

{'loss': 1.7106, 'grad_norm': 6.319056987762451, 'learning_rate': 1.104712041884817e-05, 'epoch': 0.91}


 46%|████▌     | 175/384 [03:55<04:37,  1.33s/it]

{'loss': 1.8463, 'grad_norm': 5.416242599487305, 'learning_rate': 1.099476439790576e-05, 'epoch': 0.91}


 46%|████▌     | 176/384 [03:56<04:43,  1.36s/it]

{'loss': 1.3063, 'grad_norm': 5.115551948547363, 'learning_rate': 1.0942408376963352e-05, 'epoch': 0.92}


 46%|████▌     | 177/384 [03:57<04:36,  1.34s/it]

{'loss': 2.0666, 'grad_norm': 4.969640254974365, 'learning_rate': 1.0890052356020944e-05, 'epoch': 0.92}


 46%|████▋     | 178/384 [03:59<04:32,  1.32s/it]

{'loss': 1.7625, 'grad_norm': 6.759654521942139, 'learning_rate': 1.0837696335078536e-05, 'epoch': 0.93}


 47%|████▋     | 179/384 [04:00<04:38,  1.36s/it]

{'loss': 1.8232, 'grad_norm': 5.572596549987793, 'learning_rate': 1.0785340314136125e-05, 'epoch': 0.93}


 47%|████▋     | 180/384 [04:02<04:42,  1.38s/it]

{'loss': 2.1767, 'grad_norm': 5.463850498199463, 'learning_rate': 1.0732984293193717e-05, 'epoch': 0.94}


 47%|████▋     | 181/384 [04:03<04:35,  1.36s/it]

{'loss': 1.881, 'grad_norm': 5.678812026977539, 'learning_rate': 1.068062827225131e-05, 'epoch': 0.94}


 47%|████▋     | 182/384 [04:04<04:28,  1.33s/it]

{'loss': 1.7606, 'grad_norm': 6.979583740234375, 'learning_rate': 1.0628272251308902e-05, 'epoch': 0.95}


 48%|████▊     | 183/384 [04:06<04:33,  1.36s/it]

{'loss': 2.0284, 'grad_norm': 5.369931697845459, 'learning_rate': 1.0575916230366492e-05, 'epoch': 0.95}


 48%|████▊     | 184/384 [04:07<04:28,  1.34s/it]

{'loss': 2.0162, 'grad_norm': 5.044394016265869, 'learning_rate': 1.0523560209424084e-05, 'epoch': 0.96}


 48%|████▊     | 185/384 [04:08<04:23,  1.32s/it]

{'loss': 0.7392, 'grad_norm': 9.87430191040039, 'learning_rate': 1.0471204188481676e-05, 'epoch': 0.96}


 48%|████▊     | 186/384 [04:10<04:31,  1.37s/it]

{'loss': 1.0443, 'grad_norm': 5.475067615509033, 'learning_rate': 1.0418848167539269e-05, 'epoch': 0.97}


 49%|████▊     | 187/384 [04:11<04:24,  1.34s/it]

{'loss': 1.239, 'grad_norm': 14.32051944732666, 'learning_rate': 1.036649214659686e-05, 'epoch': 0.97}


 49%|████▉     | 188/384 [04:12<04:20,  1.33s/it]

{'loss': 2.1215, 'grad_norm': 6.01576566696167, 'learning_rate': 1.031413612565445e-05, 'epoch': 0.98}


 49%|████▉     | 189/384 [04:14<04:26,  1.37s/it]

{'loss': 1.6857, 'grad_norm': 8.61133861541748, 'learning_rate': 1.0261780104712042e-05, 'epoch': 0.98}


 49%|████▉     | 190/384 [04:15<04:21,  1.35s/it]

{'loss': 1.938, 'grad_norm': 6.195417881011963, 'learning_rate': 1.0209424083769634e-05, 'epoch': 0.99}


 50%|████▉     | 191/384 [04:16<04:19,  1.34s/it]

{'loss': 2.0171, 'grad_norm': 6.05283260345459, 'learning_rate': 1.0157068062827226e-05, 'epoch': 0.99}


 50%|█████     | 192/384 [04:18<04:25,  1.38s/it]

{'loss': 2.3197, 'grad_norm': 4.9442267417907715, 'learning_rate': 1.0104712041884816e-05, 'epoch': 1.0}


 50%|█████     | 193/384 [04:19<04:21,  1.37s/it]

{'loss': 1.8236, 'grad_norm': 4.744423866271973, 'learning_rate': 1.0052356020942409e-05, 'epoch': 1.01}


 51%|█████     | 194/384 [04:21<04:26,  1.40s/it]

{'loss': 1.9502, 'grad_norm': 4.3898468017578125, 'learning_rate': 1e-05, 'epoch': 1.01}


 51%|█████     | 195/384 [04:22<04:37,  1.47s/it]

{'loss': 2.0922, 'grad_norm': 4.629128932952881, 'learning_rate': 9.947643979057593e-06, 'epoch': 1.02}


 51%|█████     | 196/384 [04:24<04:36,  1.47s/it]

{'loss': 1.9411, 'grad_norm': 4.501903057098389, 'learning_rate': 9.895287958115183e-06, 'epoch': 1.02}


 51%|█████▏    | 197/384 [04:25<04:33,  1.46s/it]

{'loss': 2.1189, 'grad_norm': 5.454916000366211, 'learning_rate': 9.842931937172776e-06, 'epoch': 1.03}


 52%|█████▏    | 198/384 [04:26<04:20,  1.40s/it]

{'loss': 1.6893, 'grad_norm': 7.218893527984619, 'learning_rate': 9.790575916230368e-06, 'epoch': 1.03}


 52%|█████▏    | 199/384 [04:28<04:12,  1.36s/it]

{'loss': 1.5823, 'grad_norm': 8.073346138000488, 'learning_rate': 9.73821989528796e-06, 'epoch': 1.04}


 52%|█████▏    | 200/384 [04:29<04:07,  1.34s/it]

{'loss': 1.9848, 'grad_norm': 4.990479946136475, 'learning_rate': 9.68586387434555e-06, 'epoch': 1.04}


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
 52%|█████▏    | 201/384 [04:31<04:33,  1.49s/it]

{'loss': 1.7223, 'grad_norm': 4.71182918548584, 'learning_rate': 9.633507853403143e-06, 'epoch': 1.05}


 53%|█████▎    | 202/384 [04:32<04:39,  1.53s/it]

{'loss': 1.6443, 'grad_norm': 4.393730163574219, 'learning_rate': 9.581151832460733e-06, 'epoch': 1.05}


 53%|█████▎    | 203/384 [04:34<04:32,  1.51s/it]

{'loss': 1.6484, 'grad_norm': 5.437723636627197, 'learning_rate': 9.528795811518325e-06, 'epoch': 1.06}


 53%|█████▎    | 204/384 [04:35<04:18,  1.44s/it]

{'loss': 1.8831, 'grad_norm': 5.042850971221924, 'learning_rate': 9.476439790575916e-06, 'epoch': 1.06}


 53%|█████▎    | 205/384 [04:37<04:17,  1.44s/it]

{'loss': 1.995, 'grad_norm': 4.539870738983154, 'learning_rate': 9.424083769633508e-06, 'epoch': 1.07}


 54%|█████▎    | 206/384 [04:38<04:10,  1.41s/it]

{'loss': 2.0303, 'grad_norm': 5.140021324157715, 'learning_rate': 9.3717277486911e-06, 'epoch': 1.07}


 54%|█████▍    | 207/384 [04:39<04:04,  1.38s/it]

{'loss': 1.4712, 'grad_norm': 6.813496112823486, 'learning_rate': 9.319371727748692e-06, 'epoch': 1.08}


 54%|█████▍    | 208/384 [04:41<04:08,  1.41s/it]

{'loss': 1.7277, 'grad_norm': 5.814014434814453, 'learning_rate': 9.267015706806284e-06, 'epoch': 1.08}


 54%|█████▍    | 209/384 [04:42<04:11,  1.44s/it]

{'loss': 1.4941, 'grad_norm': 7.701446533203125, 'learning_rate': 9.214659685863875e-06, 'epoch': 1.09}


 55%|█████▍    | 210/384 [04:44<04:03,  1.40s/it]

{'loss': 1.8379, 'grad_norm': 5.545687198638916, 'learning_rate': 9.162303664921467e-06, 'epoch': 1.09}


 55%|█████▍    | 211/384 [04:45<04:07,  1.43s/it]

{'loss': 1.9539, 'grad_norm': 4.972479820251465, 'learning_rate': 9.109947643979057e-06, 'epoch': 1.1}


 55%|█████▌    | 212/384 [04:46<03:57,  1.38s/it]

{'loss': 1.6935, 'grad_norm': 6.318912506103516, 'learning_rate': 9.05759162303665e-06, 'epoch': 1.1}


 55%|█████▌    | 213/384 [04:48<03:59,  1.40s/it]

{'loss': 1.9014, 'grad_norm': 5.3524861335754395, 'learning_rate': 9.005235602094242e-06, 'epoch': 1.11}


 56%|█████▌    | 214/384 [04:49<03:50,  1.36s/it]

{'loss': 1.3113, 'grad_norm': 5.940618991851807, 'learning_rate': 8.952879581151834e-06, 'epoch': 1.11}


 56%|█████▌    | 215/384 [04:51<03:53,  1.38s/it]

{'loss': 1.0572, 'grad_norm': 6.6659016609191895, 'learning_rate': 8.900523560209426e-06, 'epoch': 1.12}


 56%|█████▋    | 216/384 [04:52<03:47,  1.36s/it]

{'loss': 1.6538, 'grad_norm': 6.675582408905029, 'learning_rate': 8.848167539267016e-06, 'epoch': 1.12}


 57%|█████▋    | 217/384 [04:53<03:43,  1.34s/it]

{'loss': 2.158, 'grad_norm': 6.017210960388184, 'learning_rate': 8.795811518324609e-06, 'epoch': 1.13}


 57%|█████▋    | 218/384 [04:55<03:47,  1.37s/it]

{'loss': 1.7437, 'grad_norm': 6.349145889282227, 'learning_rate': 8.743455497382199e-06, 'epoch': 1.14}


 57%|█████▋    | 219/384 [04:56<03:49,  1.39s/it]

{'loss': 1.1986, 'grad_norm': 6.400429725646973, 'learning_rate': 8.691099476439791e-06, 'epoch': 1.14}


 57%|█████▋    | 220/384 [04:57<03:44,  1.37s/it]

{'loss': 1.927, 'grad_norm': 6.882306098937988, 'learning_rate': 8.638743455497383e-06, 'epoch': 1.15}


 58%|█████▊    | 221/384 [04:59<03:37,  1.33s/it]

{'loss': 1.3281, 'grad_norm': 9.025094985961914, 'learning_rate': 8.586387434554974e-06, 'epoch': 1.15}


 58%|█████▊    | 222/384 [05:00<03:34,  1.32s/it]

{'loss': 1.9734, 'grad_norm': 5.1772613525390625, 'learning_rate': 8.534031413612566e-06, 'epoch': 1.16}


 58%|█████▊    | 223/384 [05:01<03:30,  1.31s/it]

{'loss': 1.9479, 'grad_norm': 6.092801094055176, 'learning_rate': 8.481675392670158e-06, 'epoch': 1.16}


 58%|█████▊    | 224/384 [05:02<03:27,  1.30s/it]

{'loss': 1.6935, 'grad_norm': 6.401683330535889, 'learning_rate': 8.429319371727749e-06, 'epoch': 1.17}


 59%|█████▊    | 225/384 [05:04<03:25,  1.30s/it]

{'loss': 2.2548, 'grad_norm': 6.222934246063232, 'learning_rate': 8.37696335078534e-06, 'epoch': 1.17}


 59%|█████▉    | 226/384 [05:05<03:22,  1.28s/it]

{'loss': 1.3599, 'grad_norm': 7.9255051612854, 'learning_rate': 8.324607329842933e-06, 'epoch': 1.18}


 59%|█████▉    | 227/384 [05:06<03:20,  1.27s/it]

{'loss': 1.7759, 'grad_norm': 5.868777751922607, 'learning_rate': 8.272251308900523e-06, 'epoch': 1.18}


 59%|█████▉    | 228/384 [05:08<03:19,  1.28s/it]

{'loss': 1.0551, 'grad_norm': 5.757028579711914, 'learning_rate': 8.219895287958116e-06, 'epoch': 1.19}


 60%|█████▉    | 229/384 [05:09<03:18,  1.28s/it]

{'loss': 2.1578, 'grad_norm': 5.316577434539795, 'learning_rate': 8.167539267015708e-06, 'epoch': 1.19}


 60%|█████▉    | 230/384 [05:10<03:17,  1.28s/it]

{'loss': 1.2431, 'grad_norm': 5.30445671081543, 'learning_rate': 8.1151832460733e-06, 'epoch': 1.2}


 60%|██████    | 231/384 [05:11<03:16,  1.29s/it]

{'loss': 1.6914, 'grad_norm': 6.194228172302246, 'learning_rate': 8.06282722513089e-06, 'epoch': 1.2}


 60%|██████    | 232/384 [05:13<03:15,  1.29s/it]

{'loss': 1.9466, 'grad_norm': 5.333029270172119, 'learning_rate': 8.010471204188483e-06, 'epoch': 1.21}


 61%|██████    | 233/384 [05:14<03:14,  1.29s/it]

{'loss': 2.0877, 'grad_norm': 4.875550270080566, 'learning_rate': 7.958115183246073e-06, 'epoch': 1.21}


 61%|██████    | 234/384 [05:15<03:12,  1.28s/it]

{'loss': 1.8757, 'grad_norm': 4.701094150543213, 'learning_rate': 7.905759162303665e-06, 'epoch': 1.22}


 61%|██████    | 235/384 [05:17<03:11,  1.28s/it]

{'loss': 1.1733, 'grad_norm': 8.579570770263672, 'learning_rate': 7.853403141361257e-06, 'epoch': 1.22}


 61%|██████▏   | 236/384 [05:18<03:09,  1.28s/it]

{'loss': 1.3031, 'grad_norm': 6.217527866363525, 'learning_rate': 7.80104712041885e-06, 'epoch': 1.23}


 62%|██████▏   | 237/384 [05:19<03:08,  1.28s/it]

{'loss': 1.7935, 'grad_norm': 5.791785717010498, 'learning_rate': 7.748691099476442e-06, 'epoch': 1.23}


 62%|██████▏   | 238/384 [05:20<03:07,  1.28s/it]

{'loss': 1.8344, 'grad_norm': 5.349313259124756, 'learning_rate': 7.696335078534032e-06, 'epoch': 1.24}


 62%|██████▏   | 239/384 [05:22<03:13,  1.34s/it]

{'loss': 1.4465, 'grad_norm': 7.603414535522461, 'learning_rate': 7.643979057591624e-06, 'epoch': 1.24}


 62%|██████▎   | 240/384 [05:23<03:10,  1.32s/it]

{'loss': 1.8951, 'grad_norm': 5.128584384918213, 'learning_rate': 7.591623036649215e-06, 'epoch': 1.25}


 63%|██████▎   | 241/384 [05:24<03:07,  1.31s/it]

{'loss': 1.8996, 'grad_norm': 5.829614639282227, 'learning_rate': 7.539267015706807e-06, 'epoch': 1.26}


 63%|██████▎   | 242/384 [05:26<03:03,  1.29s/it]

{'loss': 1.775, 'grad_norm': 7.41880989074707, 'learning_rate': 7.486910994764398e-06, 'epoch': 1.26}


 63%|██████▎   | 243/384 [05:27<03:01,  1.29s/it]

{'loss': 1.9413, 'grad_norm': 6.237235069274902, 'learning_rate': 7.43455497382199e-06, 'epoch': 1.27}


 64%|██████▎   | 244/384 [05:28<03:00,  1.29s/it]

{'loss': 1.7396, 'grad_norm': 5.094207763671875, 'learning_rate': 7.382198952879581e-06, 'epoch': 1.27}


 64%|██████▍   | 245/384 [05:29<02:58,  1.29s/it]

{'loss': 1.5076, 'grad_norm': 5.302342891693115, 'learning_rate': 7.329842931937173e-06, 'epoch': 1.28}


 64%|██████▍   | 246/384 [05:31<02:57,  1.29s/it]

{'loss': 2.1132, 'grad_norm': 5.687541961669922, 'learning_rate': 7.277486910994765e-06, 'epoch': 1.28}


 64%|██████▍   | 247/384 [05:32<02:56,  1.29s/it]

{'loss': 0.4023, 'grad_norm': 6.05728006362915, 'learning_rate': 7.2251308900523565e-06, 'epoch': 1.29}


 65%|██████▍   | 248/384 [05:34<03:02,  1.34s/it]

{'loss': 1.9022, 'grad_norm': 5.755108833312988, 'learning_rate': 7.172774869109949e-06, 'epoch': 1.29}


 65%|██████▍   | 249/384 [05:35<02:59,  1.33s/it]

{'loss': 1.6421, 'grad_norm': 6.983827590942383, 'learning_rate': 7.12041884816754e-06, 'epoch': 1.3}


 65%|██████▌   | 250/384 [05:36<03:03,  1.37s/it]

{'loss': 1.0946, 'grad_norm': 5.651775360107422, 'learning_rate': 7.068062827225132e-06, 'epoch': 1.3}


 65%|██████▌   | 251/384 [05:38<02:57,  1.33s/it]

{'loss': 2.038, 'grad_norm': 6.1702656745910645, 'learning_rate': 7.015706806282723e-06, 'epoch': 1.31}


 66%|██████▌   | 252/384 [05:39<02:53,  1.31s/it]

{'loss': 1.8998, 'grad_norm': 7.370166778564453, 'learning_rate': 6.963350785340315e-06, 'epoch': 1.31}


 66%|██████▌   | 253/384 [05:40<02:49,  1.30s/it]

{'loss': 1.376, 'grad_norm': 10.968611717224121, 'learning_rate': 6.910994764397906e-06, 'epoch': 1.32}


 66%|██████▌   | 254/384 [05:42<02:54,  1.34s/it]

{'loss': 1.2734, 'grad_norm': 6.085896015167236, 'learning_rate': 6.858638743455498e-06, 'epoch': 1.32}


 66%|██████▋   | 255/384 [05:43<02:51,  1.33s/it]

{'loss': 1.8546, 'grad_norm': 6.1322197914123535, 'learning_rate': 6.80628272251309e-06, 'epoch': 1.33}


 67%|██████▋   | 256/384 [05:44<02:54,  1.36s/it]

{'loss': 1.5582, 'grad_norm': 8.33713150024414, 'learning_rate': 6.753926701570681e-06, 'epoch': 1.33}


 67%|██████▋   | 257/384 [05:46<02:49,  1.34s/it]

{'loss': 1.6109, 'grad_norm': 9.387441635131836, 'learning_rate': 6.701570680628273e-06, 'epoch': 1.34}


 67%|██████▋   | 258/384 [05:47<02:46,  1.32s/it]

{'loss': 2.2455, 'grad_norm': 5.286557674407959, 'learning_rate': 6.649214659685864e-06, 'epoch': 1.34}


 67%|██████▋   | 259/384 [05:48<02:43,  1.31s/it]

{'loss': 1.9179, 'grad_norm': 4.802183628082275, 'learning_rate': 6.5968586387434565e-06, 'epoch': 1.35}


 68%|██████▊   | 260/384 [05:50<02:47,  1.35s/it]

{'loss': 1.7624, 'grad_norm': 5.298939228057861, 'learning_rate': 6.544502617801047e-06, 'epoch': 1.35}


 68%|██████▊   | 261/384 [05:51<02:43,  1.33s/it]

{'loss': 1.7453, 'grad_norm': 6.833159923553467, 'learning_rate': 6.492146596858639e-06, 'epoch': 1.36}


 68%|██████▊   | 262/384 [05:52<02:47,  1.37s/it]

{'loss': 1.8633, 'grad_norm': 8.636818885803223, 'learning_rate': 6.43979057591623e-06, 'epoch': 1.36}


 68%|██████▊   | 263/384 [05:54<02:42,  1.34s/it]

{'loss': 1.9128, 'grad_norm': 4.9010910987854, 'learning_rate': 6.3874345549738226e-06, 'epoch': 1.37}


 69%|██████▉   | 264/384 [05:55<02:39,  1.33s/it]

{'loss': 2.0354, 'grad_norm': 7.166689395904541, 'learning_rate': 6.335078534031414e-06, 'epoch': 1.38}


 69%|██████▉   | 265/384 [05:56<02:41,  1.36s/it]

{'loss': 1.9757, 'grad_norm': 5.78518533706665, 'learning_rate': 6.282722513089006e-06, 'epoch': 1.38}


 69%|██████▉   | 266/384 [05:58<02:37,  1.34s/it]

{'loss': 1.6676, 'grad_norm': 5.742091655731201, 'learning_rate': 6.230366492146598e-06, 'epoch': 1.39}


 70%|██████▉   | 267/384 [05:59<02:33,  1.31s/it]

{'loss': 1.172, 'grad_norm': 5.390910625457764, 'learning_rate': 6.178010471204189e-06, 'epoch': 1.39}


 70%|██████▉   | 268/384 [06:00<02:31,  1.30s/it]

{'loss': 1.7556, 'grad_norm': 6.734187126159668, 'learning_rate': 6.125654450261781e-06, 'epoch': 1.4}


 70%|███████   | 269/384 [06:01<02:28,  1.29s/it]

{'loss': 2.0556, 'grad_norm': 5.1225738525390625, 'learning_rate': 6.073298429319372e-06, 'epoch': 1.4}


 70%|███████   | 270/384 [06:03<02:27,  1.29s/it]

{'loss': 1.8105, 'grad_norm': 6.10861873626709, 'learning_rate': 6.020942408376964e-06, 'epoch': 1.41}


 71%|███████   | 271/384 [06:04<02:32,  1.35s/it]

{'loss': 2.0944, 'grad_norm': 4.942053318023682, 'learning_rate': 5.968586387434555e-06, 'epoch': 1.41}


 71%|███████   | 272/384 [06:05<02:29,  1.34s/it]

{'loss': 2.0346, 'grad_norm': 5.184123516082764, 'learning_rate': 5.916230366492147e-06, 'epoch': 1.42}


 71%|███████   | 273/384 [06:07<02:25,  1.31s/it]

{'loss': 1.8074, 'grad_norm': 6.558712959289551, 'learning_rate': 5.863874345549738e-06, 'epoch': 1.42}


 71%|███████▏  | 274/384 [06:08<02:28,  1.35s/it]

{'loss': 1.7588, 'grad_norm': 5.152683258056641, 'learning_rate': 5.81151832460733e-06, 'epoch': 1.43}


 72%|███████▏  | 275/384 [06:10<02:31,  1.39s/it]

{'loss': 2.0218, 'grad_norm': 5.077221870422363, 'learning_rate': 5.7591623036649226e-06, 'epoch': 1.43}


 72%|███████▏  | 276/384 [06:11<02:27,  1.37s/it]

{'loss': 1.9079, 'grad_norm': 5.903632164001465, 'learning_rate': 5.706806282722514e-06, 'epoch': 1.44}


 72%|███████▏  | 277/384 [06:12<02:29,  1.40s/it]

{'loss': 1.8177, 'grad_norm': 5.863330364227295, 'learning_rate': 5.654450261780105e-06, 'epoch': 1.44}


 72%|███████▏  | 278/384 [06:14<02:29,  1.41s/it]

{'loss': 1.8661, 'grad_norm': 4.614030838012695, 'learning_rate': 5.6020942408376965e-06, 'epoch': 1.45}


 73%|███████▎  | 279/384 [06:15<02:24,  1.37s/it]

{'loss': 2.1482, 'grad_norm': 5.301157474517822, 'learning_rate': 5.549738219895289e-06, 'epoch': 1.45}


 73%|███████▎  | 280/384 [06:16<02:19,  1.35s/it]

{'loss': 1.9107, 'grad_norm': 6.917619705200195, 'learning_rate': 5.49738219895288e-06, 'epoch': 1.46}


 73%|███████▎  | 281/384 [06:18<02:21,  1.37s/it]

{'loss': 1.801, 'grad_norm': 7.607491970062256, 'learning_rate': 5.445026178010472e-06, 'epoch': 1.46}


 73%|███████▎  | 282/384 [06:19<02:16,  1.34s/it]

{'loss': 1.9348, 'grad_norm': 7.425357341766357, 'learning_rate': 5.392670157068063e-06, 'epoch': 1.47}


 74%|███████▎  | 283/384 [06:20<02:13,  1.32s/it]

{'loss': 1.056, 'grad_norm': 10.968154907226562, 'learning_rate': 5.340314136125655e-06, 'epoch': 1.47}


 74%|███████▍  | 284/384 [06:22<02:16,  1.36s/it]

{'loss': 1.9094, 'grad_norm': 6.1460490226745605, 'learning_rate': 5.287958115183246e-06, 'epoch': 1.48}


 74%|███████▍  | 285/384 [06:23<02:12,  1.34s/it]

{'loss': 1.93, 'grad_norm': 4.916043281555176, 'learning_rate': 5.235602094240838e-06, 'epoch': 1.48}


 74%|███████▍  | 286/384 [06:24<02:09,  1.32s/it]

{'loss': 1.9072, 'grad_norm': 5.267256736755371, 'learning_rate': 5.18324607329843e-06, 'epoch': 1.49}


 75%|███████▍  | 287/384 [06:26<02:06,  1.31s/it]

{'loss': 2.2094, 'grad_norm': 5.596776008605957, 'learning_rate': 5.130890052356021e-06, 'epoch': 1.49}


 75%|███████▌  | 288/384 [06:27<02:09,  1.35s/it]

{'loss': 2.0468, 'grad_norm': 5.85833215713501, 'learning_rate': 5.078534031413613e-06, 'epoch': 1.5}


 75%|███████▌  | 289/384 [06:28<02:06,  1.33s/it]

{'loss': 2.1618, 'grad_norm': 6.016538619995117, 'learning_rate': 5.026178010471204e-06, 'epoch': 1.51}


 76%|███████▌  | 290/384 [06:30<02:08,  1.36s/it]

{'loss': 1.1309, 'grad_norm': 12.891338348388672, 'learning_rate': 4.9738219895287965e-06, 'epoch': 1.51}


 76%|███████▌  | 291/384 [06:31<02:04,  1.34s/it]

{'loss': 1.8118, 'grad_norm': 5.464110374450684, 'learning_rate': 4.921465968586388e-06, 'epoch': 1.52}


 76%|███████▌  | 292/384 [06:33<02:06,  1.37s/it]

{'loss': 1.9703, 'grad_norm': 5.6201324462890625, 'learning_rate': 4.86910994764398e-06, 'epoch': 1.52}


 76%|███████▋  | 293/384 [06:34<02:02,  1.35s/it]

{'loss': 1.4877, 'grad_norm': 6.801509380340576, 'learning_rate': 4.816753926701571e-06, 'epoch': 1.53}


 77%|███████▋  | 294/384 [06:35<01:59,  1.33s/it]

{'loss': 2.1788, 'grad_norm': 5.959479331970215, 'learning_rate': 4.764397905759163e-06, 'epoch': 1.53}


 77%|███████▋  | 295/384 [06:36<01:57,  1.32s/it]

{'loss': 1.7978, 'grad_norm': 6.16651725769043, 'learning_rate': 4.712041884816754e-06, 'epoch': 1.54}


 77%|███████▋  | 296/384 [06:38<02:00,  1.36s/it]

{'loss': 1.4277, 'grad_norm': 6.400694847106934, 'learning_rate': 4.659685863874346e-06, 'epoch': 1.54}


 77%|███████▋  | 297/384 [06:39<01:56,  1.33s/it]

{'loss': 1.2759, 'grad_norm': 5.831536293029785, 'learning_rate': 4.607329842931937e-06, 'epoch': 1.55}


 78%|███████▊  | 298/384 [06:41<01:53,  1.32s/it]

{'loss': 1.0341, 'grad_norm': 10.246345520019531, 'learning_rate': 4.554973821989529e-06, 'epoch': 1.55}


 78%|███████▊  | 299/384 [06:42<01:55,  1.36s/it]

{'loss': 1.3657, 'grad_norm': 5.917598247528076, 'learning_rate': 4.502617801047121e-06, 'epoch': 1.56}


 78%|███████▊  | 300/384 [06:43<01:52,  1.34s/it]

{'loss': 1.8697, 'grad_norm': 5.655407905578613, 'learning_rate': 4.450261780104713e-06, 'epoch': 1.56}


 78%|███████▊  | 301/384 [06:45<01:54,  1.37s/it]

{'loss': 1.3746, 'grad_norm': 6.497529983520508, 'learning_rate': 4.397905759162304e-06, 'epoch': 1.57}


 79%|███████▊  | 302/384 [06:46<01:49,  1.34s/it]

{'loss': 1.1647, 'grad_norm': 6.244491100311279, 'learning_rate': 4.345549738219896e-06, 'epoch': 1.57}


 79%|███████▉  | 303/384 [06:47<01:47,  1.32s/it]

{'loss': 2.0252, 'grad_norm': 5.280650615692139, 'learning_rate': 4.293193717277487e-06, 'epoch': 1.58}


 79%|███████▉  | 304/384 [06:49<01:44,  1.31s/it]

{'loss': 1.6421, 'grad_norm': 7.096103668212891, 'learning_rate': 4.240837696335079e-06, 'epoch': 1.58}


 79%|███████▉  | 305/384 [06:50<01:42,  1.30s/it]

{'loss': 1.5672, 'grad_norm': 6.147646903991699, 'learning_rate': 4.18848167539267e-06, 'epoch': 1.59}


 80%|███████▉  | 306/384 [06:51<01:40,  1.29s/it]

{'loss': 1.948, 'grad_norm': 5.556332588195801, 'learning_rate': 4.136125654450262e-06, 'epoch': 1.59}


 80%|███████▉  | 307/384 [06:52<01:38,  1.28s/it]

{'loss': 1.8444, 'grad_norm': 6.173013210296631, 'learning_rate': 4.083769633507854e-06, 'epoch': 1.6}


 80%|████████  | 308/384 [06:54<01:42,  1.34s/it]

{'loss': 1.0218, 'grad_norm': 6.181855201721191, 'learning_rate': 4.031413612565445e-06, 'epoch': 1.6}


 80%|████████  | 309/384 [06:55<01:42,  1.37s/it]

{'loss': 1.4811, 'grad_norm': 7.842129707336426, 'learning_rate': 3.9790575916230365e-06, 'epoch': 1.61}


 81%|████████  | 310/384 [06:57<01:43,  1.40s/it]

{'loss': 1.9682, 'grad_norm': 4.987566947937012, 'learning_rate': 3.926701570680629e-06, 'epoch': 1.61}


 81%|████████  | 311/384 [06:58<01:39,  1.37s/it]

{'loss': 1.9976, 'grad_norm': 6.831220626831055, 'learning_rate': 3.874345549738221e-06, 'epoch': 1.62}


 81%|████████▏ | 312/384 [06:59<01:36,  1.33s/it]

{'loss': 1.8816, 'grad_norm': 6.3770365715026855, 'learning_rate': 3.821989528795812e-06, 'epoch': 1.62}


 82%|████████▏ | 313/384 [07:01<01:33,  1.32s/it]

{'loss': 1.4573, 'grad_norm': 6.283664703369141, 'learning_rate': 3.7696335078534035e-06, 'epoch': 1.63}


 82%|████████▏ | 314/384 [07:02<01:31,  1.31s/it]

{'loss': 1.8711, 'grad_norm': 6.50789737701416, 'learning_rate': 3.717277486910995e-06, 'epoch': 1.64}


 82%|████████▏ | 315/384 [07:03<01:29,  1.30s/it]

{'loss': 2.0543, 'grad_norm': 5.447295665740967, 'learning_rate': 3.6649214659685865e-06, 'epoch': 1.64}


 82%|████████▏ | 316/384 [07:04<01:28,  1.30s/it]

{'loss': 2.0043, 'grad_norm': 5.589226245880127, 'learning_rate': 3.6125654450261782e-06, 'epoch': 1.65}


 83%|████████▎ | 317/384 [07:06<01:30,  1.34s/it]

{'loss': 1.8562, 'grad_norm': 6.53181266784668, 'learning_rate': 3.56020942408377e-06, 'epoch': 1.65}


 83%|████████▎ | 318/384 [07:07<01:27,  1.32s/it]

{'loss': 1.9866, 'grad_norm': 5.3224711418151855, 'learning_rate': 3.5078534031413613e-06, 'epoch': 1.66}


 83%|████████▎ | 319/384 [07:08<01:24,  1.31s/it]

{'loss': 2.2076, 'grad_norm': 5.971419811248779, 'learning_rate': 3.455497382198953e-06, 'epoch': 1.66}


 83%|████████▎ | 320/384 [07:10<01:22,  1.29s/it]

{'loss': 1.8465, 'grad_norm': 5.502999782562256, 'learning_rate': 3.403141361256545e-06, 'epoch': 1.67}


 84%|████████▎ | 321/384 [07:11<01:20,  1.28s/it]

{'loss': 2.069, 'grad_norm': 5.101822376251221, 'learning_rate': 3.3507853403141365e-06, 'epoch': 1.67}


 84%|████████▍ | 322/384 [07:12<01:19,  1.28s/it]

{'loss': 1.6754, 'grad_norm': 6.387209892272949, 'learning_rate': 3.2984293193717282e-06, 'epoch': 1.68}


 84%|████████▍ | 323/384 [07:14<01:21,  1.34s/it]

{'loss': 2.0543, 'grad_norm': 5.7346930503845215, 'learning_rate': 3.2460732984293196e-06, 'epoch': 1.68}


 84%|████████▍ | 324/384 [07:15<01:22,  1.37s/it]

{'loss': 1.3751, 'grad_norm': 6.1925249099731445, 'learning_rate': 3.1937172774869113e-06, 'epoch': 1.69}


 85%|████████▍ | 325/384 [07:16<01:19,  1.35s/it]

{'loss': 1.7631, 'grad_norm': 6.8996782302856445, 'learning_rate': 3.141361256544503e-06, 'epoch': 1.69}


 85%|████████▍ | 326/384 [07:18<01:20,  1.38s/it]

{'loss': 2.1638, 'grad_norm': 5.806210517883301, 'learning_rate': 3.0890052356020943e-06, 'epoch': 1.7}


 85%|████████▌ | 327/384 [07:19<01:17,  1.36s/it]

{'loss': 2.1706, 'grad_norm': 5.145326137542725, 'learning_rate': 3.036649214659686e-06, 'epoch': 1.7}


 85%|████████▌ | 328/384 [07:20<01:14,  1.33s/it]

{'loss': 2.031, 'grad_norm': 6.805805683135986, 'learning_rate': 2.9842931937172774e-06, 'epoch': 1.71}


 86%|████████▌ | 329/384 [07:22<01:12,  1.32s/it]

{'loss': 1.3216, 'grad_norm': 5.602585792541504, 'learning_rate': 2.931937172774869e-06, 'epoch': 1.71}


 86%|████████▌ | 330/384 [07:23<01:10,  1.30s/it]

{'loss': 1.3575, 'grad_norm': 10.475010871887207, 'learning_rate': 2.8795811518324613e-06, 'epoch': 1.72}


 86%|████████▌ | 331/384 [07:24<01:08,  1.29s/it]

{'loss': 0.6628, 'grad_norm': 7.505149841308594, 'learning_rate': 2.8272251308900526e-06, 'epoch': 1.72}


 86%|████████▋ | 332/384 [07:26<01:06,  1.29s/it]

{'loss': 1.7799, 'grad_norm': 6.617682933807373, 'learning_rate': 2.7748691099476443e-06, 'epoch': 1.73}


 87%|████████▋ | 333/384 [07:27<01:05,  1.28s/it]

{'loss': 2.1618, 'grad_norm': 6.718728542327881, 'learning_rate': 2.722513089005236e-06, 'epoch': 1.73}


 87%|████████▋ | 334/384 [07:28<01:04,  1.28s/it]

{'loss': 2.0821, 'grad_norm': 5.856939792633057, 'learning_rate': 2.6701570680628274e-06, 'epoch': 1.74}


 87%|████████▋ | 335/384 [07:30<01:05,  1.33s/it]

{'loss': 2.0841, 'grad_norm': 5.514623165130615, 'learning_rate': 2.617801047120419e-06, 'epoch': 1.74}


 88%|████████▊ | 336/384 [07:31<01:05,  1.37s/it]

{'loss': 1.8328, 'grad_norm': 6.0402727127075195, 'learning_rate': 2.5654450261780104e-06, 'epoch': 1.75}


 88%|████████▊ | 337/384 [07:32<01:03,  1.34s/it]

{'loss': 2.0062, 'grad_norm': 5.294955730438232, 'learning_rate': 2.513089005235602e-06, 'epoch': 1.76}


 88%|████████▊ | 338/384 [07:34<01:03,  1.38s/it]

{'loss': 1.3732, 'grad_norm': 7.360295295715332, 'learning_rate': 2.460732984293194e-06, 'epoch': 1.76}


 88%|████████▊ | 339/384 [07:35<01:01,  1.36s/it]

{'loss': 1.5051, 'grad_norm': 6.318129062652588, 'learning_rate': 2.4083769633507856e-06, 'epoch': 1.77}


 89%|████████▊ | 340/384 [07:37<01:01,  1.39s/it]

{'loss': 1.7414, 'grad_norm': 5.703719139099121, 'learning_rate': 2.356020942408377e-06, 'epoch': 1.77}


 89%|████████▉ | 341/384 [07:38<01:01,  1.42s/it]

{'loss': 0.8587, 'grad_norm': 8.146125793457031, 'learning_rate': 2.3036649214659687e-06, 'epoch': 1.78}


 89%|████████▉ | 342/384 [07:39<00:57,  1.38s/it]

{'loss': 1.9732, 'grad_norm': 5.297956466674805, 'learning_rate': 2.2513089005235604e-06, 'epoch': 1.78}


 89%|████████▉ | 343/384 [07:41<00:57,  1.41s/it]

{'loss': 1.7724, 'grad_norm': 6.804519176483154, 'learning_rate': 2.198952879581152e-06, 'epoch': 1.79}


 90%|████████▉ | 344/384 [07:42<00:54,  1.37s/it]

{'loss': 1.4216, 'grad_norm': 7.053571701049805, 'learning_rate': 2.1465968586387435e-06, 'epoch': 1.79}


 90%|████████▉ | 345/384 [07:43<00:52,  1.35s/it]

{'loss': 2.3168, 'grad_norm': 5.532853603363037, 'learning_rate': 2.094240837696335e-06, 'epoch': 1.8}


 90%|█████████ | 346/384 [07:45<00:50,  1.33s/it]

{'loss': 1.3211, 'grad_norm': 8.096820831298828, 'learning_rate': 2.041884816753927e-06, 'epoch': 1.8}


 90%|█████████ | 347/384 [07:46<00:50,  1.37s/it]

{'loss': 1.7327, 'grad_norm': 6.7509894371032715, 'learning_rate': 1.9895287958115183e-06, 'epoch': 1.81}


 91%|█████████ | 348/384 [07:47<00:48,  1.34s/it]

{'loss': 1.9506, 'grad_norm': 6.627098083496094, 'learning_rate': 1.9371727748691104e-06, 'epoch': 1.81}


 91%|█████████ | 349/384 [07:49<00:52,  1.50s/it]

{'loss': 2.1003, 'grad_norm': 5.497021675109863, 'learning_rate': 1.8848167539267017e-06, 'epoch': 1.82}


 91%|█████████ | 350/384 [07:51<00:50,  1.50s/it]

{'loss': 1.6842, 'grad_norm': 5.349874496459961, 'learning_rate': 1.8324607329842933e-06, 'epoch': 1.82}


 91%|█████████▏| 351/384 [07:52<00:50,  1.54s/it]

{'loss': 2.07, 'grad_norm': 5.07811164855957, 'learning_rate': 1.780104712041885e-06, 'epoch': 1.83}


 92%|█████████▏| 352/384 [07:54<00:48,  1.52s/it]

{'loss': 1.6393, 'grad_norm': 6.777271747589111, 'learning_rate': 1.7277486910994765e-06, 'epoch': 1.83}


 92%|█████████▏| 353/384 [07:55<00:45,  1.47s/it]

{'loss': 1.7482, 'grad_norm': 7.424725532531738, 'learning_rate': 1.6753926701570683e-06, 'epoch': 1.84}


 92%|█████████▏| 354/384 [07:57<00:43,  1.44s/it]

{'loss': 1.9683, 'grad_norm': 6.47988224029541, 'learning_rate': 1.6230366492146598e-06, 'epoch': 1.84}


 92%|█████████▏| 355/384 [07:58<00:41,  1.42s/it]

{'loss': 2.0513, 'grad_norm': 5.09385871887207, 'learning_rate': 1.5706806282722515e-06, 'epoch': 1.85}


 93%|█████████▎| 356/384 [07:59<00:40,  1.43s/it]

{'loss': 0.9696, 'grad_norm': 5.430362701416016, 'learning_rate': 1.518324607329843e-06, 'epoch': 1.85}


 93%|█████████▎| 357/384 [08:01<00:37,  1.39s/it]

{'loss': 2.0511, 'grad_norm': 5.954738616943359, 'learning_rate': 1.4659685863874346e-06, 'epoch': 1.86}


 93%|█████████▎| 358/384 [08:02<00:35,  1.36s/it]

{'loss': 1.752, 'grad_norm': 5.27864933013916, 'learning_rate': 1.4136125654450263e-06, 'epoch': 1.86}


 93%|█████████▎| 359/384 [08:03<00:33,  1.34s/it]

{'loss': 1.9459, 'grad_norm': 4.623623371124268, 'learning_rate': 1.361256544502618e-06, 'epoch': 1.87}


 94%|█████████▍| 360/384 [08:05<00:31,  1.32s/it]

{'loss': 1.9597, 'grad_norm': 6.440476894378662, 'learning_rate': 1.3089005235602096e-06, 'epoch': 1.88}


 94%|█████████▍| 361/384 [08:06<00:33,  1.47s/it]

{'loss': 1.9887, 'grad_norm': 5.155098915100098, 'learning_rate': 1.256544502617801e-06, 'epoch': 1.88}


 94%|█████████▍| 362/384 [08:08<00:32,  1.48s/it]

{'loss': 1.2893, 'grad_norm': 10.29442024230957, 'learning_rate': 1.2041884816753928e-06, 'epoch': 1.89}


 95%|█████████▍| 363/384 [08:10<00:31,  1.52s/it]

{'loss': 1.7759, 'grad_norm': 5.744540214538574, 'learning_rate': 1.1518324607329843e-06, 'epoch': 1.89}


 95%|█████████▍| 364/384 [08:11<00:29,  1.47s/it]

{'loss': 2.0837, 'grad_norm': 6.221816539764404, 'learning_rate': 1.099476439790576e-06, 'epoch': 1.9}


 95%|█████████▌| 365/384 [08:12<00:27,  1.43s/it]

{'loss': 1.8837, 'grad_norm': 5.985708236694336, 'learning_rate': 1.0471204188481676e-06, 'epoch': 1.9}


 95%|█████████▌| 366/384 [08:13<00:25,  1.39s/it]

{'loss': 1.6353, 'grad_norm': 6.870574474334717, 'learning_rate': 9.947643979057591e-07, 'epoch': 1.91}


 96%|█████████▌| 367/384 [08:15<00:23,  1.36s/it]

{'loss': 1.5119, 'grad_norm': 8.2836332321167, 'learning_rate': 9.424083769633509e-07, 'epoch': 1.91}


 96%|█████████▌| 368/384 [08:16<00:22,  1.39s/it]

{'loss': 1.5005, 'grad_norm': 6.9841203689575195, 'learning_rate': 8.900523560209425e-07, 'epoch': 1.92}


 96%|█████████▌| 369/384 [08:18<00:20,  1.37s/it]

{'loss': 1.9389, 'grad_norm': 6.060715198516846, 'learning_rate': 8.376963350785341e-07, 'epoch': 1.92}


 96%|█████████▋| 370/384 [08:19<00:18,  1.34s/it]

{'loss': 1.9263, 'grad_norm': 6.47554349899292, 'learning_rate': 7.853403141361258e-07, 'epoch': 1.93}


 97%|█████████▋| 371/384 [08:20<00:17,  1.32s/it]

{'loss': 1.9937, 'grad_norm': 7.4792985916137695, 'learning_rate': 7.329842931937173e-07, 'epoch': 1.93}


 97%|█████████▋| 372/384 [08:22<00:16,  1.36s/it]

{'loss': 2.0246, 'grad_norm': 6.171481132507324, 'learning_rate': 6.80628272251309e-07, 'epoch': 1.94}


 97%|█████████▋| 373/384 [08:23<00:14,  1.35s/it]

{'loss': 1.8217, 'grad_norm': 5.318910598754883, 'learning_rate': 6.282722513089005e-07, 'epoch': 1.94}


 97%|█████████▋| 374/384 [08:24<00:13,  1.33s/it]

{'loss': 1.5248, 'grad_norm': 8.699291229248047, 'learning_rate': 5.759162303664922e-07, 'epoch': 1.95}


 98%|█████████▊| 375/384 [08:25<00:11,  1.32s/it]

{'loss': 1.5858, 'grad_norm': 6.239152431488037, 'learning_rate': 5.235602094240838e-07, 'epoch': 1.95}


 98%|█████████▊| 376/384 [08:27<00:10,  1.30s/it]

{'loss': 0.7894, 'grad_norm': 9.297835350036621, 'learning_rate': 4.7120418848167543e-07, 'epoch': 1.96}


 98%|█████████▊| 377/384 [08:28<00:09,  1.35s/it]

{'loss': 1.9196, 'grad_norm': 5.521204471588135, 'learning_rate': 4.1884816753926706e-07, 'epoch': 1.96}


 98%|█████████▊| 378/384 [08:30<00:08,  1.39s/it]

{'loss': 0.8772, 'grad_norm': 5.713862895965576, 'learning_rate': 3.6649214659685864e-07, 'epoch': 1.97}


 99%|█████████▊| 379/384 [08:31<00:07,  1.41s/it]

{'loss': 1.9284, 'grad_norm': 4.935030937194824, 'learning_rate': 3.1413612565445027e-07, 'epoch': 1.97}


 99%|█████████▉| 380/384 [08:32<00:05,  1.37s/it]

{'loss': 1.8036, 'grad_norm': 6.689606666564941, 'learning_rate': 2.617801047120419e-07, 'epoch': 1.98}


 99%|█████████▉| 381/384 [08:34<00:04,  1.35s/it]

{'loss': 1.7304, 'grad_norm': 5.778144836425781, 'learning_rate': 2.0942408376963353e-07, 'epoch': 1.98}


 99%|█████████▉| 382/384 [08:35<00:02,  1.33s/it]

{'loss': 2.1056, 'grad_norm': 6.472676753997803, 'learning_rate': 1.5706806282722514e-07, 'epoch': 1.99}


100%|█████████▉| 383/384 [08:36<00:01,  1.32s/it]

{'loss': 1.9525, 'grad_norm': 7.035027503967285, 'learning_rate': 1.0471204188481677e-07, 'epoch': 1.99}


100%|██████████| 384/384 [08:38<00:00,  1.35s/it]

{'loss': 1.587, 'grad_norm': 8.534857749938965, 'learning_rate': 5.235602094240838e-08, 'epoch': 2.0}
{'train_runtime': 518.0545, 'train_samples_per_second': 1.482, 'train_steps_per_second': 0.741, 'train_loss': 1.817257465639462, 'epoch': 2.0}





TrainOutput(global_step=384, training_loss=1.817257465639462, metrics={'train_runtime': 518.0545, 'train_samples_per_second': 1.482, 'train_steps_per_second': 0.741, 'total_flos': 8181365472067584.0, 'train_loss': 1.817257465639462, 'epoch': 2.0})

## Evaluating the finetuned model

In [21]:
# STEP 9. Generate a completion with the finetuned model and compare it to the base generation.
ft_generation = generate_text("I'm afraid I've created a ")

print("Base model generation: " + base_generation + "\n\n")
print("Finetuned generation: " + ft_generation)

Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Base model generation: I'm afraid I've created a 2000-level problem with a 100-level solution.

I'm a 2000-level problem.

I'm a 2000-level problem.

I'm a 2000-level problem.

I'm a 2000-level problem.

I'm a 2000-level problem.

I'm a 2


Finetuned generation: I'm afraid I've created a  monster in putting
  life and sensation into the work of dead
  matter.  Speak to me that I may feel the
  delight of words. Speak that I may share
  the sympathies of our nature, which are
  more powerful and more real than all material
  appetites. Agatha, be not afraid; the
  picture of your beauty vanishes.  I behold you
  in your true form.  Come, my


A little more like the original text, right? Try experimenting with the hyperparameters to see if you can improve performance.

In [None]:
# STEP 10. Calculate the finetuned model's perplexity and compare it to the base model's.
ft_ppl = calc_perplexity(model)
print("Base model perplexity: " + str(base_ppl))
print("Finetuned model perplexity: " + str(ft_ppl))