### Practice: Parameter Efficient Fine-Tuning
In this notebook, you're gonna fine-tune large language models within limited GPU memory.

In [1]:
%pip install --quiet transformers accelerate sentencepiece optimum peft bitsandbytes

import torch
import torch.nn as nn
import torch.nn.functional as F

import transformers
from tqdm.auto import tqdm, trange
assert torch.cuda.is_available(), "you need cuda for this part"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m23.0.1[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpython3 -m pip install --upgrade pip[0m




In [2]:
model_name = 'Enoch/llama-7b-hf'

# loading Llama tokenizer ...
tokenizer = transformers.LlamaTokenizer.from_pretrained(model_name, device_map=device)
tokenizer.pad_token_id = tokenizer.eos_token_id

# ... and the model itself
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    load_in_4bit=True, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False

model.gradient_checkpointing_enable()  # only store a small subset of activations, re-compute the rest.
model.enable_input_require_grads()     # override an implementation quirk in gradient checkpoints that disables backprop unless inputs require grad
# more on gradient checkpointing: https://pytorch.org/docs/stable/checkpoint.html https://arxiv.org/abs/1604.06174

You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message
2024-12-18 20:37:53.334906: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesCon

### Prompt tuning: the story of a fox (2 pts)

![img](https://i.imgur.com/Ux3qQAu.png) (source: theodd1souts.fandom.com)

In [None]:
prompt = 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

for i in range(10):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0].cpu().numpy().tolist()))


Output: <s>A quick brown fox jumps over the lazy dog.
A quick


What a blatant lie! This particular fox assures you that it didn't in fact jump over the lazy dog. No, sir! The fox was just minding its own business. __Your task is to train the model to say truth: no dog was jumped over today.__

In [None]:
the_truth = "A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
outputs = model(**batch)

next_word_logits = outputs.logits[:, :-1]
true_next_tokens = batch['input_ids'][:, 1:]
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

print("Loss:", loss)

Loss: tensor(3.0729, device='cuda:0', grad_fn=<NllLossBackward0>)


Except, we can't train the entire model - that would be 28GB gradients in float32. Instead, let's run [prompt tuning](https://arxiv.org/abs/2104.08691).

![img](https://i.imgur.com/VwNNKnb.png)


In [None]:
class WordEmbeddingsWithLearnedPrompts(nn.Module):
    """
    To perform prompt tuning, you will need to replace model's original word embeddings with a layer - THIS layer
     - that inserts trainable prompts instead of the first N token embeddings. """

    def __init__(self, word_embeddings: nn.Embedding, num_prompts: int):
        super().__init__()
        self.original_word_embeddings = word_embeddings
        self.num_prompts = num_prompts
        self.learnable_prompts = nn.Parameter(
            torch.randn(1, num_prompts, word_embeddings.embedding_dim), requires_grad=True)

    def forward(self, input_ids: torch.LongTensor):
        # input_ids shape: [batch_size, seq length]
        assert input_ids.dtype == torch.int64
        assert input_ids.shape[1] > self.num_prompts
        assert torch.all(input_ids[:, :self.num_prompts] == tokenizer.pad_token_id).item(), "don't forget to prepend several BOS tokens to input_ids"

        # Your task: embed input_ids, but replace the first :num_prompts: tokens with self.learnable_prompts
        # This is because we will prepend :num_prompts: padding tokens at the beginning

        # After you are done, you must produce a word embedding vector for each token in input_ids,
        # except that the first :num_prompts: vectors should equal learnable_prompts;
        # any additional vectors after first :num_prompts: ones should be embedded as usual
        # Note: since you're dealing with trainable params, please torch.cat instead of item assignment

        your_outputs_with_prompts_as_per_instructions_above = torch.cat([self.learnable_prompts, self.original_word_embeddings(input_ids[:, self.num_prompts:])], dim=1)

        return your_outputs_with_prompts_as_per_instructions_above

In [None]:
num_prompts = 16
test_emb_layer = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)
test_input_ids = tokenizer("a cat say on a may", return_tensors='pt')['input_ids'].to(device)

space_for_prompts = torch.full([len(test_input_ids), num_prompts], fill_value=tokenizer.pad_token_id,
                               dtype=torch.int64, device=device)
test_inputs_with_prompts = torch.cat([space_for_prompts, test_input_ids], dim=1)

with torch.cuda.amp.autocast():
  test_prompt_embeddings = test_emb_layer(test_inputs_with_prompts)

assert test_prompt_embeddings.shape[:2] == test_inputs_with_prompts.shape
assert test_prompt_embeddings.shape[-1] == model.config.hidden_size
assert torch.allclose(test_prompt_embeddings[:, :num_prompts], test_emb_layer.learnable_prompts.float())
assert torch.allclose(test_prompt_embeddings[:, num_prompts:], model.model.embed_tokens(test_input_ids).float())
print("Looks legit!")

Looks legit!


  with torch.cuda.amp.autocast():


__Now that it works,__ let's inject learnable prompts into the main model and teach it about foxes.

In [None]:
assert isinstance(model.model.embed_tokens, nn.Embedding), "you have already replaced the embedding layer. If the replacement is broken, please reload the model"

model.model.embed_tokens = WordEmbeddingsWithLearnedPrompts(model.model.embed_tokens, num_prompts=num_prompts).to(device)

opt = torch.optim.Adam([model.model.embed_tokens.learnable_prompts], lr=0.01)

In [None]:
the_truth = "A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)
space_for_prompts = torch.full([len(test_input_ids), num_prompts], fill_value=tokenizer.pad_token_id,
                               dtype=torch.int64, device=device)
batch['input_ids'] = torch.cat([space_for_prompts, batch['input_ids']], dim=1)
batch['attention_mask'] = torch.cat([torch.ones_like(space_for_prompts), batch['attention_mask']], dim=1)

outputs = model(**batch)
next_word_logits = outputs.logits[:, num_prompts : -1, :]
true_next_tokens = batch['input_ids'][:, num_prompts + 1:]
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
print("Loss:", loss)

while loss.item() > 0.1:
    outputs = model(**batch)
    next_word_logits = outputs.logits[:, num_prompts : -1, :]
    true_next_tokens = batch['input_ids'][:, num_prompts + 1:]
    loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

    opt.zero_grad()
    loss.backward()
    opt.step()


assert loss.item() <= 0.1
print("Good job!")

Loss: tensor(4.5127, device='cuda:0', grad_fn=<NllLossBackward0>)
Good job!


In [None]:
prompt = 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)
batch['input_ids'] = torch.cat([space_for_prompts, batch['input_ids']], dim=1)
batch['attention_mask'] = torch.cat([torch.ones_like(space_for_prompts), batch['attention_mask']], dim=1)


for i in range(15):
    next_token = model(**batch).logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0, num_prompts:].cpu().numpy().tolist()))

# if you did everything right, the model will deny that the fox jumped over the lazy dog


Output: <s>A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it


### Using HuggingFace PEFT (2 points)

[`peft`](https://huggingface.co/docs/peft/index) is a transformer's sister library that allows you to apply various __p__arameter __e__fficient __f__ine-__t__uning methods to pre-trained transformers. The library imlements both prompt tuning, prefix tuning, as well as several adapter-based techniques under a common interface:



In [3]:
import peft
assert isinstance(model.model.embed_tokens, nn.Embedding), "please reload the model"

peft_config = peft.PromptTuningConfig(task_type=peft.TaskType.CAUSAL_LM, num_virtual_tokens=16)
model = peft.get_peft_model(model, peft_config)  # note: for most peft methods, this line also modifies model in-place
print("Trainable parameters:", sum(p.numel() for p in model.parameters() if p.requires_grad))
print("Total parameters (excluding quantization):", sum(p.numel() for p in model.parameters()))

Trainable parameters: 65536
Total parameters (excluding quantization): 3500478464


In [None]:
# Your task: optimize the PEFT-wrapped model to achieve next token prediction loss < 0.1, but this time using PEFT
# Please note: you no longer need to prepend PAD tokens, but you still need to skip :num_virtual_tokens: first logits.
# Finally, generate the sentence to make sure that the model learned the truth.

In [22]:
opt = torch.optim.Adam(model.parameters(), lr=0.01)

In [23]:
the_truth = "A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it anyway!"
batch = tokenizer(the_truth, return_tensors='pt', return_token_type_ids=False).to(device)

num_virtual_tokens = 16
outputs = model(**batch)
next_word_logits = outputs.logits[:, num_virtual_tokens : -1, :]
true_next_tokens = batch['input_ids'][:, 1:]
loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))
print("Loss:", loss)

while loss.item() > 0.1:
    outputs = model(**batch)
    next_word_logits = outputs.logits[:, num_virtual_tokens : -1, :]
    true_next_tokens = batch['input_ids'][:, 1:]
    loss = F.cross_entropy(next_word_logits.flatten(0, 1), true_next_tokens.flatten(0, 1))

    opt.zero_grad()
    loss.backward()
    opt.step()

assert loss.item() <= 0.1
print("Good job!")

Loss: tensor(7.9091, device='cuda:0', grad_fn=<NllLossBackward0>)
Good job!


In [25]:
prompt = 'A quick brown fox'
batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

for i in range(15):
    next_token_logits = model(**batch).logits[:, num_virtual_tokens:, :]
    next_token = next_token_logits[0, -1].argmax(-1).reshape(1, 1)
    batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
    batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

print("\nOutput:", tokenizer.decode(batch['input_ids'][0, :].cpu().numpy().tolist()))


Output: <s>A quick brown fox did not jump over the lazy dog. Besides, that dog deserved it


### Parameter-efficient finetuning with LoRA (2 points)

When training on more serious tasks, you can use low-rank adapters based on the [LoRA paper](https://arxiv.org/pdf/2106.09685.pdf).

The core idea is to add low-rank adapters __in parallel with existing linear layers,__ like this:
<center><img src="https://i.imgur.com/6bQLNiG.png" width=240px></center>

In the original LoRA paper, the adapters were only added to attention projection matrices. However, [subsequent works](https://arxiv.org/abs/2305.14314) show that it is useful to adapt FFNs as well. But before we do any training, we need to implement the basic LoRA layer.

In [3]:
# re-load the model to remove any previous PEFT tuners
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    load_in_4bit=True, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s]

In [3]:
class LoRALayer(nn.Module):
    """Wraps a linear layer with LoRA-like adapter. Wraps an existing OPT linear layer"""
    def __init__(self, module: nn.Linear, rank: int):
        super().__init__()
        self.module = module  # pre-trained (frozen) linear layer
        self.adapter_A = nn.Parameter(torch.empty(module.in_features, rank, device=module.weight.device))
        nn.init.kaiming_uniform_(self.adapter_A, a=5 ** 0.5)
        self.adapter_B = nn.Parameter(torch.zeros(rank, module.out_features, device=module.weight.device))

    def forward(self, input):
        # Apply self.module and LoRA adapter, return the sum (self.module outputs + adapter outputs)
        pre_lora_forward = self.module(input)
        lora_forward = input @ self.adapter_A @ self.adapter_B
        return pre_lora_forward + lora_forward

In [9]:
# test your implementation
test_linear = nn.Linear(128, 128)
test_linear.weight.data[...] = torch.eye(128)
test_adapter = LoRALayer(test_linear, rank=8)

assert torch.allclose(test_adapter(torch.ones(1, 1, 128)), test_linear.bias + 1), "please check your forward pass"

test_adapter.adapter_A.data[...] = torch.linspace(0.1, -0.5, 128 * 8).view(128, 8)
test_adapter.adapter_B.data[...] = torch.linspace(0.5, -0.1, 128 * 8).view(8, 128)
test_linear.bias.data[...] = torch.linspace(1., -1., 128)

dummy_loss = F.mse_loss(test_adapter(torch.ones(1, 128) / 128).squeeze(), torch.linspace(-1, 1, 128))
assert torch.allclose(dummy_loss, torch.tensor(1.3711389), rtol=0, atol=1e-4)
dummy_loss.backward()
assert all(w.grad is not None for w in [test_adapter.adapter_A, test_adapter.adapter_B]), "some adapter weights have no grad"
assert torch.allclose(test_adapter.adapter_A.grad.sum(), torch.tensor(-0.60158), rtol=0, atol=1e-4), "bad grad w.r.t. A"
assert torch.allclose(test_adapter.adapter_B.grad.sum(), torch.tensor(0.9931), rtol=0, atol=1e-4), "bad grad w.r.t. B"
# note: bad grad means that your code is different from LoRA paper OR that your code is not autograd-friendly (e.g. no_grad)
del dummy_loss, test_linear, test_adapter
print("All tests passed!")

All tests passed!


### Apply LoRA to the model

The code below applies LoRA adapters on top of Q/K/V linear layers in Llama attention. You may also choose to modify other layers:
* self_attn.o_proj - attention output projection
* mlp.up_proj, mlp.gate_proj, mlp.down_proj - transformer feedforward layers
* lm_head - output LM head

__Note:__ please scroll down for the homework task

In [10]:
lora_rank = 8

for name, module in model.model.layers.named_modules():
    if 'LlamaDecoderLayer' in repr(type(module)):
        module.self_attn.q_proj = LoRALayer(module.self_attn.q_proj, rank=lora_rank).to(device)
        module.self_attn.k_proj = LoRALayer(module.self_attn.k_proj, rank=lora_rank).to(device)
        module.self_attn.v_proj = LoRALayer(module.self_attn.v_proj, rank=lora_rank).to(device)

assert sum(isinstance(module, LoRALayer) for module in model.modules()) == 96  # for Llama-7B

In [11]:
batch = tokenizer("This model wants to share its greatest secret:", return_tensors='pt', return_token_type_ids=False)
# test a single training step, make sure we get meaningful gradients
with torch.cuda.amp.autocast(dtype=torch.float32):
    out = model.forward(**batch)
    (out.logits.norm() / 100).backward()

for i, module in enumerate(model.modules()):
    if isinstance(module, LoRALayer):
        assert module.adapter_B.grad is not None
        assert module.adapter_B.grad.norm().item() > 0

model.zero_grad(set_to_none=True)
print("Grad check successful, well done!")

  with torch.cuda.amp.autocast(dtype=torch.float32):


Grad check successful, well done!


  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


### (example) How to train your model

The example below shows how to train the LoRA adapters on a dummy dataset. You will need to run a _similar_ training task later.

__Note:__ please scroll down for the homework task

In [12]:
# checking if the model can learn. Change max_steps for proper training
import datasets
data = datasets.load_dataset("Abirate/english_quotes", split="train[:32]") # 32 lines
data = data.map(lambda samples: tokenizer(samples['quote']), batched=True)
model._hf_peft_config_loaded = True  # silence a warning from HF trainer

trainer = transformers.Trainer(
    model=model, train_dataset=data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=2, gradient_accumulation_steps=1,
        # note: if you want larger batch size, increase gradient_accumulation_steps
        warmup_steps=250, max_steps=100, learning_rate=2e-4, fp16=True,
        logging_steps=1, output_dir='outputs', report_to=None),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
# if you see cache warnings, set `model.config.use_cache = False` to silence them. Please re-enable for inference!

trainer.train()

# NOTE: this is just an example! you do not have to wait for this progressbar to finish :)

README.md:   0%|          | 0.00/5.55k [00:00<?, ?B/s]

quotes.jsonl:   0%|          | 0.00/647k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/2508 [00:00<?, ? examples/s]

Map:   0%|          | 0/32 [00:00<?, ? examples/s]

max_steps is given, it will override any value given in num_train_epochs
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize
wandb: Paste an API key from your profile and hit enter, or press ctrl+c to quit:

 ··········


[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,1.229
2,0.373
3,1.4028
4,1.365
5,0.8374
6,1.5453
7,1.7293
8,1.208
9,0.5465
10,1.2249




UnboundLocalError: local variable 'active_adapters' referenced before assignment

### Final task: *actually* train the model (4 points)

Your task is to fine-tune the model to _generate python code_. Please use the above examples for inspiration. More specifically,

* __dataset:__ use [codeparrot-clean](https://huggingface.co/datasets/codeparrot/codeparrot-clean) or any other data containing python code. Since you do not need much data for this excercise, it is enough to use just shorter validation subset of `codeparrots`
* __preprocessing:__ select python code based on file extentions (.py)  (may skip in case of codeparrot - it is 100% python)
* __short lines:__ please take the first 512 characters of each line
* __adapter type:__ please use LoRA as defined above __plus at least one of:__
   - extra adapter on lm_head
   - extra adapter on MLP components (mlp.*)
   - trainable input embeddings (requires tweaking memory usage)

* __training:__ you do not have to train to convergence. If all goes well, your model should `.generate` code after 500 steps. Please use batch size of at least 4 (4 x 1 x 512 tokens) using `gradient_accumulation_steps=4`.


Note: the peft library also has LoRA implementation. However, we ask that for this assignment you show at least one complete training run with your own LoRA code.

__Alternative assignment:__ Instead of doing python code, feel free to substitute the task with any other dataset, e.g. your favorite artist or podcast, as long as it's ethical. If you choose your own task, please show examples of what your model learned - or did not learn, akin to the code examples below.

In [32]:
# re-load the model to remove any previous PEFT tuners
model = transformers.AutoModelForCausalLM.from_pretrained(
    model_name, device_map='auto', low_cpu_mem_usage=True, offload_state_dict=True,
    load_in_4bit=True, torch_dtype=torch.float32,  # weights are 4-bit; layernorms and activations are fp32
)
for param in model.parameters():
    param.requires_grad=False
model.gradient_checkpointing_enable()
model.enable_input_require_grads()

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.

Downloading shards: 100%|██████████| 33/33 [00:00<00:00, 1220.08it/s]

Loading checkpoint shards:   0%|          | 0/33 [00:00<?, ?it/s][A
Loading checkpoint shards:   3%|▎         | 1/33 [00:00<00:13,  2.36it/s][A
Loading checkpoint shards:   6%|▌         | 2/33 [00:00<00:12,  2.40it/s][A
Loading checkpoint shards:   9%|▉         | 3/33 [00:01<00:12,  2.42it/s][A
Loading checkpoint shards:  12%|█▏        | 4/33 [00:01<00:12,  2.42it/s][A
Loading checkpoint shards:  15%|█▌        | 5/33 [00:02<00:11,  2.39it/s][A
Loading checkpoint shards:  18%|█▊        | 6/33 [00:02<00:11,  2.34it/s][A
Loading checkpoint shards:  21%|██        | 7/33 [00:02<00:11,  2.31it/s][A
Loading checkpoint shards:  24%|██▍       | 8/33 [00:03<00:10,  2.30it/s][A
Loading checkpoint shards:  27%|██▋       | 9/33 [

In [33]:
prompts =  ['', 'import', 'from', 'while', 'try', 'if', 'for', 'torch']  # feel free to add a few more that are not 100% assiciated with Python

# <A WHOLE LOT OF YOUR CODE>
# generate baseline samples with the selected prompts before finetuning
# please feel free to use transformers.Trainer (as above) or your custom training code
# after the training concludes, please show examples of text generated by your model. It is expected to look like Python code fragments
# print the generation examples nicely (suggestion: use pandas or HTML) for easier comparison
# note: your LoRA-enhanced model can run generation the same way as the non-trained model (above)

In [29]:
def code_generator(prompt, max_len=100):
    batch = tokenizer(prompt, return_tensors='pt', return_token_type_ids=False).to(device)

    for i in range(max_len):
        next_token_logits = model(**batch).logits[:, :, :]
        next_token = next_token_logits[0, -1].argmax(-1).reshape(1, 1)
        batch['input_ids'] = torch.cat([batch['input_ids'], next_token], dim=-1)
        batch['attention_mask'] = torch.cat([batch['attention_mask'], torch.ones_like(next_token)], dim=-1)

    return tokenizer.decode(batch['input_ids'][0, :].cpu().numpy().tolist())

In [34]:
before_finetune_code = [code_generator(prompt) for prompt in prompts]

In [11]:
lora_rank = 8

for name, module in model.model.layers.named_modules():
    if 'LlamaDecoderLayer' in repr(type(module)):
        module.self_attn.q_proj = LoRALayer(module.self_attn.q_proj, rank=lora_rank).to(device)
        module.self_attn.k_proj = LoRALayer(module.self_attn.k_proj, rank=lora_rank).to(device)
        module.self_attn.v_proj = LoRALayer(module.self_attn.v_proj, rank=lora_rank).to(device)
        module.mlp.gate_proj = LoRALayer(module.mlp.gate_proj, rank=lora_rank).to(device)
        module.mlp.up_proj = LoRALayer(module.mlp.up_proj, rank=lora_rank).to(device)
        module.mlp.down_proj = LoRALayer(module.mlp.down_proj, rank=lora_rank).to(device)

assert sum(isinstance(module, LoRALayer) for module in model.modules()) == 192  # for Llama-7B

In [12]:
model.lm_head = LoRALayer(model.lm_head, rank=lora_rank).to(device)

In [4]:
# checking if the model can learn. Change max_steps for proper training
import datasets
data = datasets.load_dataset("codeparrot/codeparrot-clean-valid", split='train')

Repo card metadata block was not found. Setting CardData to empty.


In [5]:
def content_handler(data_item):
    tokenized_item = tokenizer(data_item['content'], return_tensors="pt", padding='max_length', max_length=512, truncation=True).to(model.device)
    return tokenized_item

In [6]:
data = data.map(content_handler, batched=True)

Map: 100%|██████████| 61373/61373 [10:56<00:00, 93.54 examples/s]


In [14]:
model._hf_peft_config_loaded = True  # silence a warning from HF trainer

trainer = transformers.Trainer(
    model=model, train_dataset=data,
    args=transformers.TrainingArguments(
        per_device_train_batch_size=1, gradient_accumulation_steps=4,
        # note: if you want larger batch size, increase gradient_accumulation_steps
        warmup_steps=250, max_steps=500, learning_rate=2e-4, fp16=True,
        logging_steps=1, output_dir='outputs', report_to=None),
    data_collator=transformers.DataCollatorForLanguageModeling(tokenizer, mlm=False)
)
# if you see cache warnings, set `model.config.use_cache = False` to silence them. Please re-enable for inference!

trainer.train()

# NOTE: this is just an example! you do not have to wait for this progressbar to finish :)

[2024-12-18 21:09:49,417] [INFO] [real_accelerator.py:219:get_accelerator] Setting ds_accelerator to cuda (auto detect)


/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
/usr/bin/ld: cannot find -laio: No such file or directory
collect2: error: ld returned 1 exit status
max_steps is given, it will override any value given in num_train_epochs
  0%|          | 0/500 [00:00<?, ?it/s]`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.
  0%|          | 1/500 [00:02<19:06,  2.30s/it]

{'loss': 1.1075, 'grad_norm': 9.958664894104004, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.0}


  0%|          | 2/500 [00:03<15:52,  1.91s/it]

{'loss': 1.0005, 'grad_norm': 10.636354446411133, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.0}


  1%|          | 3/500 [00:05<14:48,  1.79s/it]

{'loss': 1.4238, 'grad_norm': 10.357032775878906, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.0}


  1%|          | 4/500 [00:07<14:17,  1.73s/it]

{'loss': 0.6755, 'grad_norm': 8.130602836608887, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.0}


  1%|          | 5/500 [00:08<13:59,  1.70s/it]

{'loss': 1.3445, 'grad_norm': 9.773690223693848, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.0}


  1%|          | 6/500 [00:10<13:48,  1.68s/it]

{'loss': 0.9132, 'grad_norm': 14.143204689025879, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.0}


  1%|▏         | 7/500 [00:12<13:40,  1.66s/it]

{'loss': 1.1413, 'grad_norm': 10.938251495361328, 'learning_rate': 5.600000000000001e-06, 'epoch': 0.0}


  2%|▏         | 8/500 [00:13<13:33,  1.65s/it]

{'loss': 1.0574, 'grad_norm': 6.917974948883057, 'learning_rate': 6.4000000000000006e-06, 'epoch': 0.0}


  2%|▏         | 9/500 [00:15<13:30,  1.65s/it]

{'loss': 0.8478, 'grad_norm': 5.956939220428467, 'learning_rate': 7.2e-06, 'epoch': 0.0}


  2%|▏         | 10/500 [00:17<13:27,  1.65s/it]

{'loss': 0.9646, 'grad_norm': 3.804741859436035, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.0}


  2%|▏         | 11/500 [00:18<13:24,  1.64s/it]

{'loss': 0.8998, 'grad_norm': 3.373345375061035, 'learning_rate': 8.8e-06, 'epoch': 0.0}


  2%|▏         | 12/500 [00:20<13:21,  1.64s/it]

{'loss': 1.2857, 'grad_norm': 2.773005485534668, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.0}


  3%|▎         | 13/500 [00:21<13:18,  1.64s/it]

{'loss': 1.0488, 'grad_norm': 2.514221668243408, 'learning_rate': 1.04e-05, 'epoch': 0.0}


  3%|▎         | 14/500 [00:23<13:15,  1.64s/it]

{'loss': 1.2775, 'grad_norm': 3.1998727321624756, 'learning_rate': 1.1200000000000001e-05, 'epoch': 0.0}


  3%|▎         | 15/500 [00:25<13:14,  1.64s/it]

{'loss': 1.2843, 'grad_norm': 2.516591787338257, 'learning_rate': 1.2e-05, 'epoch': 0.0}


  3%|▎         | 16/500 [00:26<13:12,  1.64s/it]

{'loss': 1.3268, 'grad_norm': 2.7130000591278076, 'learning_rate': 1.2800000000000001e-05, 'epoch': 0.0}


  3%|▎         | 17/500 [00:28<13:09,  1.64s/it]

{'loss': 0.7802, 'grad_norm': 2.330303430557251, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.0}


  4%|▎         | 18/500 [00:30<13:08,  1.64s/it]

{'loss': 0.951, 'grad_norm': 3.112450361251831, 'learning_rate': 1.44e-05, 'epoch': 0.0}


  4%|▍         | 19/500 [00:31<13:07,  1.64s/it]

{'loss': 0.9791, 'grad_norm': 2.296211004257202, 'learning_rate': 1.52e-05, 'epoch': 0.0}


  4%|▍         | 20/500 [00:33<13:05,  1.64s/it]

{'loss': 0.9091, 'grad_norm': 2.0336759090423584, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.0}


  4%|▍         | 21/500 [00:35<13:03,  1.64s/it]

{'loss': 0.906, 'grad_norm': 2.8163819313049316, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.0}


  4%|▍         | 22/500 [00:36<13:01,  1.64s/it]

{'loss': 1.1586, 'grad_norm': 2.5169715881347656, 'learning_rate': 1.76e-05, 'epoch': 0.0}


  5%|▍         | 23/500 [00:38<13:00,  1.64s/it]

{'loss': 1.1455, 'grad_norm': 2.3193955421447754, 'learning_rate': 1.84e-05, 'epoch': 0.0}


  5%|▍         | 24/500 [00:39<12:58,  1.64s/it]

{'loss': 1.331, 'grad_norm': 3.4795398712158203, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.0}


  5%|▌         | 25/500 [00:41<12:56,  1.64s/it]

{'loss': 1.0602, 'grad_norm': 2.3575491905212402, 'learning_rate': 2e-05, 'epoch': 0.0}


  5%|▌         | 26/500 [00:43<12:54,  1.63s/it]

{'loss': 1.2019, 'grad_norm': 3.0827279090881348, 'learning_rate': 2.08e-05, 'epoch': 0.0}


  5%|▌         | 27/500 [00:44<12:53,  1.63s/it]

{'loss': 1.2454, 'grad_norm': 2.306868314743042, 'learning_rate': 2.16e-05, 'epoch': 0.0}


  6%|▌         | 28/500 [00:46<12:51,  1.63s/it]

{'loss': 1.1118, 'grad_norm': 2.0429000854492188, 'learning_rate': 2.2400000000000002e-05, 'epoch': 0.0}


  6%|▌         | 29/500 [00:48<12:49,  1.63s/it]

{'loss': 1.3033, 'grad_norm': 2.21431565284729, 'learning_rate': 2.32e-05, 'epoch': 0.0}


  6%|▌         | 30/500 [00:49<12:47,  1.63s/it]

{'loss': 0.8458, 'grad_norm': 2.2719640731811523, 'learning_rate': 2.4e-05, 'epoch': 0.0}


  6%|▌         | 31/500 [00:51<12:46,  1.63s/it]

{'loss': 0.8241, 'grad_norm': 2.590688467025757, 'learning_rate': 2.48e-05, 'epoch': 0.0}


  6%|▋         | 32/500 [00:53<12:44,  1.63s/it]

{'loss': 1.2835, 'grad_norm': 2.0941720008850098, 'learning_rate': 2.5600000000000002e-05, 'epoch': 0.0}


  7%|▋         | 33/500 [00:54<12:42,  1.63s/it]

{'loss': 0.8856, 'grad_norm': 1.8727165460586548, 'learning_rate': 2.64e-05, 'epoch': 0.0}


  7%|▋         | 34/500 [00:56<12:42,  1.64s/it]

{'loss': 0.9691, 'grad_norm': 2.1382999420166016, 'learning_rate': 2.7200000000000004e-05, 'epoch': 0.0}


  7%|▋         | 35/500 [00:57<12:40,  1.64s/it]

{'loss': 1.1348, 'grad_norm': 2.058391809463501, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.0}


  7%|▋         | 36/500 [00:59<12:38,  1.63s/it]

{'loss': 1.013, 'grad_norm': 1.949440598487854, 'learning_rate': 2.88e-05, 'epoch': 0.0}


  7%|▋         | 37/500 [01:01<12:36,  1.63s/it]

{'loss': 0.976, 'grad_norm': 2.1035830974578857, 'learning_rate': 2.96e-05, 'epoch': 0.0}


  8%|▊         | 38/500 [01:02<12:34,  1.63s/it]

{'loss': 1.0371, 'grad_norm': 2.158766984939575, 'learning_rate': 3.04e-05, 'epoch': 0.0}


  8%|▊         | 39/500 [01:04<12:32,  1.63s/it]

{'loss': 1.115, 'grad_norm': 1.790012001991272, 'learning_rate': 3.12e-05, 'epoch': 0.0}


  8%|▊         | 40/500 [01:06<12:30,  1.63s/it]

{'loss': 1.0773, 'grad_norm': 2.1120822429656982, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.0}


  8%|▊         | 41/500 [01:07<12:29,  1.63s/it]

{'loss': 0.9546, 'grad_norm': 1.9515409469604492, 'learning_rate': 3.2800000000000004e-05, 'epoch': 0.0}


  8%|▊         | 42/500 [01:09<12:27,  1.63s/it]

{'loss': 1.177, 'grad_norm': 1.9201194047927856, 'learning_rate': 3.3600000000000004e-05, 'epoch': 0.0}


  9%|▊         | 43/500 [01:10<12:25,  1.63s/it]

{'loss': 0.9196, 'grad_norm': 1.9394794702529907, 'learning_rate': 3.4399999999999996e-05, 'epoch': 0.0}


  9%|▉         | 44/500 [01:12<12:24,  1.63s/it]

{'loss': 0.8947, 'grad_norm': 1.9599976539611816, 'learning_rate': 3.52e-05, 'epoch': 0.0}


  9%|▉         | 45/500 [01:14<12:22,  1.63s/it]

{'loss': 0.9306, 'grad_norm': 1.7344435453414917, 'learning_rate': 3.6e-05, 'epoch': 0.0}


  9%|▉         | 46/500 [01:15<12:19,  1.63s/it]

{'loss': 1.3123, 'grad_norm': 2.1268956661224365, 'learning_rate': 3.68e-05, 'epoch': 0.0}


  9%|▉         | 47/500 [01:17<12:18,  1.63s/it]

{'loss': 1.2673, 'grad_norm': 1.8305596113204956, 'learning_rate': 3.76e-05, 'epoch': 0.0}


 10%|▉         | 48/500 [01:19<12:17,  1.63s/it]

{'loss': 1.0701, 'grad_norm': 2.178293228149414, 'learning_rate': 3.8400000000000005e-05, 'epoch': 0.0}


 10%|▉         | 49/500 [01:20<12:14,  1.63s/it]

{'loss': 1.4859, 'grad_norm': 2.05263614654541, 'learning_rate': 3.9200000000000004e-05, 'epoch': 0.0}


 10%|█         | 50/500 [01:22<12:13,  1.63s/it]

{'loss': 0.7849, 'grad_norm': 1.7817809581756592, 'learning_rate': 4e-05, 'epoch': 0.0}


 10%|█         | 51/500 [01:24<12:12,  1.63s/it]

{'loss': 0.9805, 'grad_norm': 1.7285983562469482, 'learning_rate': 4.08e-05, 'epoch': 0.0}


 10%|█         | 52/500 [01:25<12:10,  1.63s/it]

{'loss': 0.8092, 'grad_norm': 1.7652089595794678, 'learning_rate': 4.16e-05, 'epoch': 0.0}


 11%|█         | 53/500 [01:27<12:08,  1.63s/it]

{'loss': 1.1663, 'grad_norm': 2.038025379180908, 'learning_rate': 4.24e-05, 'epoch': 0.0}


 11%|█         | 54/500 [01:28<12:07,  1.63s/it]

{'loss': 0.844, 'grad_norm': 2.1890783309936523, 'learning_rate': 4.32e-05, 'epoch': 0.0}


 11%|█         | 55/500 [01:30<12:06,  1.63s/it]

{'loss': 1.0118, 'grad_norm': 1.9898532629013062, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.0}


 11%|█         | 56/500 [01:32<12:04,  1.63s/it]

{'loss': 0.8717, 'grad_norm': 1.8474277257919312, 'learning_rate': 4.4800000000000005e-05, 'epoch': 0.0}


 11%|█▏        | 57/500 [01:33<12:03,  1.63s/it]

{'loss': 0.9596, 'grad_norm': 1.8344168663024902, 'learning_rate': 4.5600000000000004e-05, 'epoch': 0.0}


 12%|█▏        | 58/500 [01:35<12:02,  1.63s/it]

{'loss': 0.9363, 'grad_norm': 2.0019900798797607, 'learning_rate': 4.64e-05, 'epoch': 0.0}


 12%|█▏        | 59/500 [01:37<12:01,  1.64s/it]

{'loss': 0.7335, 'grad_norm': 1.9039759635925293, 'learning_rate': 4.72e-05, 'epoch': 0.0}


 12%|█▏        | 60/500 [01:38<11:59,  1.63s/it]

{'loss': 1.2023, 'grad_norm': 1.8756877183914185, 'learning_rate': 4.8e-05, 'epoch': 0.0}


 12%|█▏        | 61/500 [01:40<11:57,  1.63s/it]

{'loss': 0.7312, 'grad_norm': 2.1120095252990723, 'learning_rate': 4.88e-05, 'epoch': 0.0}


 12%|█▏        | 62/500 [01:41<11:55,  1.63s/it]

{'loss': 0.6423, 'grad_norm': 2.23500394821167, 'learning_rate': 4.96e-05, 'epoch': 0.0}


 13%|█▎        | 63/500 [01:43<11:53,  1.63s/it]

{'loss': 0.9589, 'grad_norm': 2.1269054412841797, 'learning_rate': 5.0400000000000005e-05, 'epoch': 0.0}


 13%|█▎        | 64/500 [01:45<11:52,  1.63s/it]

{'loss': 1.185, 'grad_norm': 2.617119789123535, 'learning_rate': 5.1200000000000004e-05, 'epoch': 0.0}


 13%|█▎        | 65/500 [01:46<11:50,  1.63s/it]

{'loss': 1.1618, 'grad_norm': 2.060913324356079, 'learning_rate': 5.2000000000000004e-05, 'epoch': 0.0}


 13%|█▎        | 66/500 [01:48<11:49,  1.63s/it]

{'loss': 0.6128, 'grad_norm': 1.6228535175323486, 'learning_rate': 5.28e-05, 'epoch': 0.0}


 13%|█▎        | 67/500 [01:50<11:48,  1.64s/it]

{'loss': 0.9548, 'grad_norm': 1.8796191215515137, 'learning_rate': 5.360000000000001e-05, 'epoch': 0.0}


 14%|█▎        | 68/500 [01:51<11:46,  1.64s/it]

{'loss': 1.3123, 'grad_norm': 2.4129750728607178, 'learning_rate': 5.440000000000001e-05, 'epoch': 0.0}


 14%|█▍        | 69/500 [01:53<11:44,  1.63s/it]

{'loss': 0.78, 'grad_norm': 1.6544045209884644, 'learning_rate': 5.520000000000001e-05, 'epoch': 0.0}


 14%|█▍        | 70/500 [01:55<11:43,  1.63s/it]

{'loss': 0.9488, 'grad_norm': 2.4643521308898926, 'learning_rate': 5.6000000000000006e-05, 'epoch': 0.0}


 14%|█▍        | 71/500 [01:56<11:41,  1.63s/it]

{'loss': 0.9178, 'grad_norm': 1.79359769821167, 'learning_rate': 5.68e-05, 'epoch': 0.0}


 14%|█▍        | 72/500 [01:58<11:40,  1.64s/it]

{'loss': 0.9453, 'grad_norm': 1.9586142301559448, 'learning_rate': 5.76e-05, 'epoch': 0.0}


 15%|█▍        | 73/500 [01:59<11:37,  1.63s/it]

{'loss': 1.0887, 'grad_norm': 2.013502359390259, 'learning_rate': 5.8399999999999997e-05, 'epoch': 0.0}


 15%|█▍        | 74/500 [02:01<11:36,  1.64s/it]

{'loss': 1.3016, 'grad_norm': 2.098543643951416, 'learning_rate': 5.92e-05, 'epoch': 0.0}


 15%|█▌        | 75/500 [02:03<11:35,  1.64s/it]

{'loss': 1.0661, 'grad_norm': 2.2345263957977295, 'learning_rate': 6e-05, 'epoch': 0.0}


 15%|█▌        | 76/500 [02:04<11:34,  1.64s/it]

{'loss': 0.9748, 'grad_norm': 1.9515430927276611, 'learning_rate': 6.08e-05, 'epoch': 0.0}


 15%|█▌        | 77/500 [02:06<11:32,  1.64s/it]

{'loss': 0.9347, 'grad_norm': 2.1024298667907715, 'learning_rate': 6.16e-05, 'epoch': 0.01}


 16%|█▌        | 78/500 [02:08<11:30,  1.64s/it]

{'loss': 0.9564, 'grad_norm': 1.7681403160095215, 'learning_rate': 6.24e-05, 'epoch': 0.01}


 16%|█▌        | 79/500 [02:09<11:28,  1.63s/it]

{'loss': 1.13, 'grad_norm': 2.0950140953063965, 'learning_rate': 6.32e-05, 'epoch': 0.01}


 16%|█▌        | 80/500 [02:11<11:26,  1.64s/it]

{'loss': 1.2023, 'grad_norm': 1.9388152360916138, 'learning_rate': 6.400000000000001e-05, 'epoch': 0.01}


 16%|█▌        | 81/500 [02:13<11:25,  1.64s/it]

{'loss': 1.2998, 'grad_norm': 2.2966065406799316, 'learning_rate': 6.48e-05, 'epoch': 0.01}


 16%|█▋        | 82/500 [02:14<11:23,  1.63s/it]

{'loss': 1.1443, 'grad_norm': 2.2591452598571777, 'learning_rate': 6.560000000000001e-05, 'epoch': 0.01}


 17%|█▋        | 83/500 [02:16<11:21,  1.63s/it]

{'loss': 0.9924, 'grad_norm': 1.8750332593917847, 'learning_rate': 6.64e-05, 'epoch': 0.01}


 17%|█▋        | 84/500 [02:17<11:19,  1.63s/it]

{'loss': 1.1727, 'grad_norm': 2.152881145477295, 'learning_rate': 6.720000000000001e-05, 'epoch': 0.01}


 17%|█▋        | 85/500 [02:19<11:18,  1.63s/it]

{'loss': 0.9105, 'grad_norm': 1.7240092754364014, 'learning_rate': 6.800000000000001e-05, 'epoch': 0.01}


 17%|█▋        | 86/500 [02:21<11:16,  1.63s/it]

{'loss': 0.8959, 'grad_norm': 1.7424421310424805, 'learning_rate': 6.879999999999999e-05, 'epoch': 0.01}


 17%|█▋        | 87/500 [02:22<11:15,  1.64s/it]

{'loss': 1.2922, 'grad_norm': 1.6868888139724731, 'learning_rate': 6.96e-05, 'epoch': 0.01}


 18%|█▊        | 88/500 [02:24<11:13,  1.64s/it]

{'loss': 1.1043, 'grad_norm': 1.9439663887023926, 'learning_rate': 7.04e-05, 'epoch': 0.01}


 18%|█▊        | 89/500 [02:26<11:11,  1.63s/it]

{'loss': 0.8063, 'grad_norm': 1.550131916999817, 'learning_rate': 7.12e-05, 'epoch': 0.01}


 18%|█▊        | 90/500 [02:27<11:10,  1.63s/it]

{'loss': 0.903, 'grad_norm': 3.347158908843994, 'learning_rate': 7.2e-05, 'epoch': 0.01}


 18%|█▊        | 91/500 [02:29<11:08,  1.63s/it]

{'loss': 1.2436, 'grad_norm': 1.8389891386032104, 'learning_rate': 7.280000000000001e-05, 'epoch': 0.01}


 18%|█▊        | 92/500 [02:31<11:06,  1.63s/it]

{'loss': 0.9803, 'grad_norm': 1.673673391342163, 'learning_rate': 7.36e-05, 'epoch': 0.01}


 19%|█▊        | 93/500 [02:32<11:05,  1.64s/it]

{'loss': 0.816, 'grad_norm': 1.5490624904632568, 'learning_rate': 7.44e-05, 'epoch': 0.01}


 19%|█▉        | 94/500 [02:34<11:03,  1.64s/it]

{'loss': 1.0594, 'grad_norm': 2.2263340950012207, 'learning_rate': 7.52e-05, 'epoch': 0.01}


 19%|█▉        | 95/500 [02:35<11:01,  1.63s/it]

{'loss': 0.8374, 'grad_norm': 1.618703842163086, 'learning_rate': 7.6e-05, 'epoch': 0.01}


 19%|█▉        | 96/500 [02:37<11:00,  1.64s/it]

{'loss': 0.6523, 'grad_norm': 3.1396288871765137, 'learning_rate': 7.680000000000001e-05, 'epoch': 0.01}


 19%|█▉        | 97/500 [02:39<10:58,  1.63s/it]

{'loss': 0.8884, 'grad_norm': 1.800696611404419, 'learning_rate': 7.76e-05, 'epoch': 0.01}


 20%|█▉        | 98/500 [02:40<10:57,  1.64s/it]

{'loss': 1.0931, 'grad_norm': 1.6809250116348267, 'learning_rate': 7.840000000000001e-05, 'epoch': 0.01}


 20%|█▉        | 99/500 [02:42<10:56,  1.64s/it]

{'loss': 1.0181, 'grad_norm': 5.022506237030029, 'learning_rate': 7.920000000000001e-05, 'epoch': 0.01}


 20%|██        | 100/500 [02:44<10:54,  1.64s/it]

{'loss': 0.8854, 'grad_norm': 1.9565129280090332, 'learning_rate': 8e-05, 'epoch': 0.01}


 20%|██        | 101/500 [02:45<10:52,  1.64s/it]

{'loss': 0.8789, 'grad_norm': 1.6539382934570312, 'learning_rate': 8.080000000000001e-05, 'epoch': 0.01}


 20%|██        | 102/500 [02:47<10:51,  1.64s/it]

{'loss': 0.6175, 'grad_norm': 1.60969877243042, 'learning_rate': 8.16e-05, 'epoch': 0.01}


 21%|██        | 103/500 [02:49<10:49,  1.64s/it]

{'loss': 1.0994, 'grad_norm': 1.8011871576309204, 'learning_rate': 8.24e-05, 'epoch': 0.01}


 21%|██        | 104/500 [02:50<10:47,  1.64s/it]

{'loss': 0.8535, 'grad_norm': 1.6914161443710327, 'learning_rate': 8.32e-05, 'epoch': 0.01}


 21%|██        | 105/500 [02:52<10:46,  1.64s/it]

{'loss': 1.0068, 'grad_norm': 1.842616081237793, 'learning_rate': 8.4e-05, 'epoch': 0.01}


 21%|██        | 106/500 [02:53<10:44,  1.64s/it]

{'loss': 0.8903, 'grad_norm': 1.742197036743164, 'learning_rate': 8.48e-05, 'epoch': 0.01}


 21%|██▏       | 107/500 [02:55<10:42,  1.64s/it]

{'loss': 0.9225, 'grad_norm': 2.2147693634033203, 'learning_rate': 8.560000000000001e-05, 'epoch': 0.01}


 22%|██▏       | 108/500 [02:57<10:41,  1.64s/it]

{'loss': 0.775, 'grad_norm': 1.654103398323059, 'learning_rate': 8.64e-05, 'epoch': 0.01}


 22%|██▏       | 109/500 [02:58<10:39,  1.64s/it]

{'loss': 1.0431, 'grad_norm': 1.878722071647644, 'learning_rate': 8.72e-05, 'epoch': 0.01}


 22%|██▏       | 110/500 [03:00<10:37,  1.64s/it]

{'loss': 0.754, 'grad_norm': 1.7926084995269775, 'learning_rate': 8.800000000000001e-05, 'epoch': 0.01}


 22%|██▏       | 111/500 [03:02<10:36,  1.64s/it]

{'loss': 1.412, 'grad_norm': 1.8803273439407349, 'learning_rate': 8.88e-05, 'epoch': 0.01}


 22%|██▏       | 112/500 [03:03<10:34,  1.63s/it]

{'loss': 0.9973, 'grad_norm': 1.8294317722320557, 'learning_rate': 8.960000000000001e-05, 'epoch': 0.01}


 23%|██▎       | 113/500 [03:05<10:32,  1.63s/it]

{'loss': 0.8916, 'grad_norm': 2.4179463386535645, 'learning_rate': 9.04e-05, 'epoch': 0.01}


 23%|██▎       | 114/500 [03:07<10:31,  1.64s/it]

{'loss': 0.5039, 'grad_norm': 1.8161554336547852, 'learning_rate': 9.120000000000001e-05, 'epoch': 0.01}


 23%|██▎       | 115/500 [03:08<10:29,  1.63s/it]

{'loss': 0.8153, 'grad_norm': 1.8276764154434204, 'learning_rate': 9.200000000000001e-05, 'epoch': 0.01}


 23%|██▎       | 116/500 [03:10<10:27,  1.63s/it]

{'loss': 0.6817, 'grad_norm': 1.5013765096664429, 'learning_rate': 9.28e-05, 'epoch': 0.01}


 23%|██▎       | 117/500 [03:11<10:25,  1.63s/it]

{'loss': 1.621, 'grad_norm': 1.9100112915039062, 'learning_rate': 9.360000000000001e-05, 'epoch': 0.01}


 24%|██▎       | 118/500 [03:13<10:24,  1.63s/it]

{'loss': 0.8587, 'grad_norm': 1.9445546865463257, 'learning_rate': 9.44e-05, 'epoch': 0.01}


 24%|██▍       | 119/500 [03:15<10:22,  1.63s/it]

{'loss': 1.1297, 'grad_norm': 1.7822966575622559, 'learning_rate': 9.52e-05, 'epoch': 0.01}


 24%|██▍       | 120/500 [03:16<10:21,  1.64s/it]

{'loss': 1.033, 'grad_norm': 1.8105236291885376, 'learning_rate': 9.6e-05, 'epoch': 0.01}


 24%|██▍       | 121/500 [03:18<10:19,  1.63s/it]

{'loss': 1.484, 'grad_norm': 2.006600856781006, 'learning_rate': 9.680000000000001e-05, 'epoch': 0.01}


 24%|██▍       | 122/500 [03:20<10:17,  1.63s/it]

{'loss': 1.0457, 'grad_norm': 1.6487014293670654, 'learning_rate': 9.76e-05, 'epoch': 0.01}


 25%|██▍       | 123/500 [03:21<10:15,  1.63s/it]

{'loss': 1.0153, 'grad_norm': 1.927760124206543, 'learning_rate': 9.84e-05, 'epoch': 0.01}


 25%|██▍       | 124/500 [03:23<10:14,  1.63s/it]

{'loss': 1.4319, 'grad_norm': 1.736218810081482, 'learning_rate': 9.92e-05, 'epoch': 0.01}


 25%|██▌       | 125/500 [03:25<10:12,  1.63s/it]

{'loss': 1.2297, 'grad_norm': 1.7747565507888794, 'learning_rate': 0.0001, 'epoch': 0.01}


 25%|██▌       | 126/500 [03:26<10:10,  1.63s/it]

{'loss': 0.951, 'grad_norm': 1.7375296354293823, 'learning_rate': 0.00010080000000000001, 'epoch': 0.01}


 25%|██▌       | 127/500 [03:28<10:08,  1.63s/it]

{'loss': 0.8666, 'grad_norm': 1.5974416732788086, 'learning_rate': 0.0001016, 'epoch': 0.01}


 26%|██▌       | 128/500 [03:29<10:07,  1.63s/it]

{'loss': 0.6319, 'grad_norm': 1.7828048467636108, 'learning_rate': 0.00010240000000000001, 'epoch': 0.01}


 26%|██▌       | 129/500 [03:31<10:05,  1.63s/it]

{'loss': 0.7909, 'grad_norm': 1.6543679237365723, 'learning_rate': 0.0001032, 'epoch': 0.01}


 26%|██▌       | 130/500 [03:33<10:04,  1.63s/it]

{'loss': 1.1867, 'grad_norm': 1.8209333419799805, 'learning_rate': 0.00010400000000000001, 'epoch': 0.01}


 26%|██▌       | 131/500 [03:34<10:03,  1.63s/it]

{'loss': 0.7503, 'grad_norm': 2.1002516746520996, 'learning_rate': 0.00010480000000000001, 'epoch': 0.01}


 26%|██▋       | 132/500 [03:36<10:01,  1.63s/it]

{'loss': 0.7176, 'grad_norm': 2.246497631072998, 'learning_rate': 0.0001056, 'epoch': 0.01}


 27%|██▋       | 133/500 [03:38<10:00,  1.64s/it]

{'loss': 0.8207, 'grad_norm': 1.7854201793670654, 'learning_rate': 0.00010640000000000001, 'epoch': 0.01}


 27%|██▋       | 134/500 [03:39<09:58,  1.63s/it]

{'loss': 0.7208, 'grad_norm': 1.6226915121078491, 'learning_rate': 0.00010720000000000002, 'epoch': 0.01}


 27%|██▋       | 135/500 [03:41<09:56,  1.63s/it]

{'loss': 0.8666, 'grad_norm': 1.9868241548538208, 'learning_rate': 0.00010800000000000001, 'epoch': 0.01}


 27%|██▋       | 136/500 [03:42<09:54,  1.63s/it]

{'loss': 1.1808, 'grad_norm': 2.0403738021850586, 'learning_rate': 0.00010880000000000002, 'epoch': 0.01}


 27%|██▋       | 137/500 [03:44<09:53,  1.63s/it]

{'loss': 0.6998, 'grad_norm': 2.117292881011963, 'learning_rate': 0.00010960000000000001, 'epoch': 0.01}


 28%|██▊       | 138/500 [03:46<09:51,  1.63s/it]

{'loss': 1.1273, 'grad_norm': 2.2280020713806152, 'learning_rate': 0.00011040000000000001, 'epoch': 0.01}


 28%|██▊       | 139/500 [03:47<09:49,  1.63s/it]

{'loss': 0.9044, 'grad_norm': 1.7978522777557373, 'learning_rate': 0.00011120000000000002, 'epoch': 0.01}


 28%|██▊       | 140/500 [03:49<09:47,  1.63s/it]

{'loss': 0.84, 'grad_norm': 1.6713776588439941, 'learning_rate': 0.00011200000000000001, 'epoch': 0.01}


 28%|██▊       | 141/500 [03:51<09:46,  1.63s/it]

{'loss': 1.1585, 'grad_norm': 1.9057258367538452, 'learning_rate': 0.00011279999999999999, 'epoch': 0.01}


 28%|██▊       | 142/500 [03:52<09:45,  1.63s/it]

{'loss': 1.5581, 'grad_norm': 1.7278590202331543, 'learning_rate': 0.0001136, 'epoch': 0.01}


 29%|██▊       | 143/500 [03:54<09:43,  1.63s/it]

{'loss': 1.0897, 'grad_norm': 1.761773943901062, 'learning_rate': 0.0001144, 'epoch': 0.01}


 29%|██▉       | 144/500 [03:56<09:41,  1.63s/it]

{'loss': 1.1226, 'grad_norm': 1.7919100522994995, 'learning_rate': 0.0001152, 'epoch': 0.01}


 29%|██▉       | 145/500 [03:57<09:40,  1.64s/it]

{'loss': 1.2088, 'grad_norm': 1.6925092935562134, 'learning_rate': 0.000116, 'epoch': 0.01}


 29%|██▉       | 146/500 [03:59<09:39,  1.64s/it]

{'loss': 0.9619, 'grad_norm': 1.585002064704895, 'learning_rate': 0.00011679999999999999, 'epoch': 0.01}


 29%|██▉       | 147/500 [04:00<09:37,  1.64s/it]

{'loss': 1.1358, 'grad_norm': 1.5836524963378906, 'learning_rate': 0.0001176, 'epoch': 0.01}


 30%|██▉       | 148/500 [04:02<09:35,  1.64s/it]

{'loss': 1.0376, 'grad_norm': 1.7011502981185913, 'learning_rate': 0.0001184, 'epoch': 0.01}


 30%|██▉       | 149/500 [04:04<09:34,  1.64s/it]

{'loss': 1.5095, 'grad_norm': 2.07239031791687, 'learning_rate': 0.0001192, 'epoch': 0.01}


 30%|███       | 150/500 [04:05<09:32,  1.64s/it]

{'loss': 1.4159, 'grad_norm': 2.1525931358337402, 'learning_rate': 0.00012, 'epoch': 0.01}


 30%|███       | 151/500 [04:07<09:30,  1.64s/it]

{'loss': 0.9506, 'grad_norm': 2.417452812194824, 'learning_rate': 0.0001208, 'epoch': 0.01}


 30%|███       | 152/500 [04:09<09:28,  1.63s/it]

{'loss': 0.9253, 'grad_norm': 1.6790062189102173, 'learning_rate': 0.0001216, 'epoch': 0.01}


 31%|███       | 153/500 [04:10<09:26,  1.63s/it]

{'loss': 1.0423, 'grad_norm': 1.5899471044540405, 'learning_rate': 0.0001224, 'epoch': 0.01}


 31%|███       | 154/500 [04:12<09:25,  1.63s/it]

{'loss': 1.0328, 'grad_norm': 1.579790472984314, 'learning_rate': 0.0001232, 'epoch': 0.01}


 31%|███       | 155/500 [04:14<09:23,  1.63s/it]

{'loss': 0.6543, 'grad_norm': 1.3636727333068848, 'learning_rate': 0.000124, 'epoch': 0.01}


 31%|███       | 156/500 [04:15<09:21,  1.63s/it]

{'loss': 1.2287, 'grad_norm': 1.788979411125183, 'learning_rate': 0.0001248, 'epoch': 0.01}


 31%|███▏      | 157/500 [04:17<09:19,  1.63s/it]

{'loss': 0.8167, 'grad_norm': 1.4884939193725586, 'learning_rate': 0.00012560000000000002, 'epoch': 0.01}


 32%|███▏      | 158/500 [04:18<09:18,  1.63s/it]

{'loss': 0.8841, 'grad_norm': 1.3482173681259155, 'learning_rate': 0.0001264, 'epoch': 0.01}


 32%|███▏      | 159/500 [04:20<09:16,  1.63s/it]

{'loss': 0.915, 'grad_norm': 1.5456628799438477, 'learning_rate': 0.0001272, 'epoch': 0.01}


 32%|███▏      | 160/500 [04:22<09:15,  1.63s/it]

{'loss': 0.6615, 'grad_norm': 1.5279741287231445, 'learning_rate': 0.00012800000000000002, 'epoch': 0.01}


 32%|███▏      | 161/500 [04:23<09:13,  1.63s/it]

{'loss': 1.1013, 'grad_norm': 1.6769704818725586, 'learning_rate': 0.00012880000000000001, 'epoch': 0.01}


 32%|███▏      | 162/500 [04:25<09:11,  1.63s/it]

{'loss': 0.6654, 'grad_norm': 1.4124362468719482, 'learning_rate': 0.0001296, 'epoch': 0.01}


 33%|███▎      | 163/500 [04:27<09:10,  1.63s/it]

{'loss': 0.5504, 'grad_norm': 2.1134793758392334, 'learning_rate': 0.0001304, 'epoch': 0.01}


 33%|███▎      | 164/500 [04:28<09:08,  1.63s/it]

{'loss': 1.0002, 'grad_norm': 1.8545222282409668, 'learning_rate': 0.00013120000000000002, 'epoch': 0.01}


 33%|███▎      | 165/500 [04:30<09:07,  1.63s/it]

{'loss': 1.1198, 'grad_norm': 1.8394708633422852, 'learning_rate': 0.000132, 'epoch': 0.01}


 33%|███▎      | 166/500 [04:31<09:05,  1.63s/it]

{'loss': 0.7398, 'grad_norm': 1.5333759784698486, 'learning_rate': 0.0001328, 'epoch': 0.01}


 33%|███▎      | 167/500 [04:33<09:04,  1.63s/it]

{'loss': 0.9955, 'grad_norm': 1.7503355741500854, 'learning_rate': 0.00013360000000000002, 'epoch': 0.01}


 34%|███▎      | 168/500 [04:35<09:03,  1.64s/it]

{'loss': 0.6866, 'grad_norm': 1.3682163953781128, 'learning_rate': 0.00013440000000000001, 'epoch': 0.01}


 34%|███▍      | 169/500 [04:36<09:01,  1.63s/it]

{'loss': 0.5938, 'grad_norm': 1.4822001457214355, 'learning_rate': 0.0001352, 'epoch': 0.01}


 34%|███▍      | 170/500 [04:38<08:59,  1.63s/it]

{'loss': 0.6773, 'grad_norm': 1.4269042015075684, 'learning_rate': 0.00013600000000000003, 'epoch': 0.01}


 34%|███▍      | 171/500 [04:40<08:57,  1.63s/it]

{'loss': 1.1178, 'grad_norm': 1.65548574924469, 'learning_rate': 0.00013680000000000002, 'epoch': 0.01}


 34%|███▍      | 172/500 [04:41<08:55,  1.63s/it]

{'loss': 0.882, 'grad_norm': 1.8866180181503296, 'learning_rate': 0.00013759999999999998, 'epoch': 0.01}


 35%|███▍      | 173/500 [04:43<08:54,  1.63s/it]

{'loss': 1.0867, 'grad_norm': 1.5590859651565552, 'learning_rate': 0.0001384, 'epoch': 0.01}


 35%|███▍      | 174/500 [04:45<08:52,  1.63s/it]

{'loss': 0.9664, 'grad_norm': 1.6676576137542725, 'learning_rate': 0.0001392, 'epoch': 0.01}


 35%|███▌      | 175/500 [04:46<08:51,  1.63s/it]

{'loss': 1.0267, 'grad_norm': 1.7796900272369385, 'learning_rate': 0.00014, 'epoch': 0.01}


 35%|███▌      | 176/500 [04:48<08:49,  1.63s/it]

{'loss': 1.2319, 'grad_norm': 2.166712760925293, 'learning_rate': 0.0001408, 'epoch': 0.01}


 35%|███▌      | 177/500 [04:49<08:48,  1.64s/it]

{'loss': 1.1339, 'grad_norm': 1.6392652988433838, 'learning_rate': 0.0001416, 'epoch': 0.01}


 36%|███▌      | 178/500 [04:51<08:46,  1.64s/it]

{'loss': 0.7742, 'grad_norm': 1.5398057699203491, 'learning_rate': 0.0001424, 'epoch': 0.01}


 36%|███▌      | 179/500 [04:53<08:45,  1.64s/it]

{'loss': 1.1776, 'grad_norm': 1.4908267259597778, 'learning_rate': 0.0001432, 'epoch': 0.01}


 36%|███▌      | 180/500 [04:54<08:43,  1.64s/it]

{'loss': 0.7872, 'grad_norm': 1.6341063976287842, 'learning_rate': 0.000144, 'epoch': 0.01}


 36%|███▌      | 181/500 [04:56<08:41,  1.63s/it]

{'loss': 1.0362, 'grad_norm': 1.724847674369812, 'learning_rate': 0.0001448, 'epoch': 0.01}


 36%|███▋      | 182/500 [04:58<08:40,  1.64s/it]

{'loss': 0.9921, 'grad_norm': 1.8144066333770752, 'learning_rate': 0.00014560000000000002, 'epoch': 0.01}


 37%|███▋      | 183/500 [04:59<08:38,  1.63s/it]

{'loss': 1.1099, 'grad_norm': 1.66533362865448, 'learning_rate': 0.0001464, 'epoch': 0.01}


 37%|███▋      | 184/500 [05:01<08:36,  1.63s/it]

{'loss': 1.099, 'grad_norm': 1.6589025259017944, 'learning_rate': 0.0001472, 'epoch': 0.01}


 37%|███▋      | 185/500 [05:03<08:35,  1.64s/it]

{'loss': 0.9669, 'grad_norm': 2.308854818344116, 'learning_rate': 0.000148, 'epoch': 0.01}


 37%|███▋      | 186/500 [05:04<08:33,  1.64s/it]

{'loss': 1.4436, 'grad_norm': 1.9314411878585815, 'learning_rate': 0.0001488, 'epoch': 0.01}


 37%|███▋      | 187/500 [05:06<08:31,  1.63s/it]

{'loss': 1.1149, 'grad_norm': 2.2160379886627197, 'learning_rate': 0.0001496, 'epoch': 0.01}


 38%|███▊      | 188/500 [05:07<08:30,  1.63s/it]

{'loss': 1.0427, 'grad_norm': 1.7441619634628296, 'learning_rate': 0.0001504, 'epoch': 0.01}


 38%|███▊      | 189/500 [05:09<08:28,  1.63s/it]

{'loss': 0.8659, 'grad_norm': 1.5852364301681519, 'learning_rate': 0.00015120000000000002, 'epoch': 0.01}


 38%|███▊      | 190/500 [05:11<08:26,  1.64s/it]

{'loss': 0.9876, 'grad_norm': 1.488118052482605, 'learning_rate': 0.000152, 'epoch': 0.01}


 38%|███▊      | 191/500 [05:12<08:24,  1.63s/it]

{'loss': 0.9965, 'grad_norm': 1.4945793151855469, 'learning_rate': 0.0001528, 'epoch': 0.01}


 38%|███▊      | 192/500 [05:14<08:23,  1.63s/it]

{'loss': 1.0348, 'grad_norm': 1.657145619392395, 'learning_rate': 0.00015360000000000002, 'epoch': 0.01}


 39%|███▊      | 193/500 [05:16<08:21,  1.63s/it]

{'loss': 1.026, 'grad_norm': 1.570009708404541, 'learning_rate': 0.0001544, 'epoch': 0.01}


 39%|███▉      | 194/500 [05:17<08:19,  1.63s/it]

{'loss': 1.2624, 'grad_norm': 1.9386193752288818, 'learning_rate': 0.0001552, 'epoch': 0.01}


 39%|███▉      | 195/500 [05:19<08:18,  1.63s/it]

{'loss': 0.8859, 'grad_norm': 1.8751604557037354, 'learning_rate': 0.00015600000000000002, 'epoch': 0.01}


 39%|███▉      | 196/500 [05:21<08:17,  1.64s/it]

{'loss': 1.1757, 'grad_norm': 1.5165499448776245, 'learning_rate': 0.00015680000000000002, 'epoch': 0.01}


 39%|███▉      | 197/500 [05:22<08:15,  1.63s/it]

{'loss': 1.5875, 'grad_norm': 1.672820806503296, 'learning_rate': 0.0001576, 'epoch': 0.01}


 40%|███▉      | 198/500 [05:24<08:13,  1.63s/it]

{'loss': 0.8494, 'grad_norm': 1.4620497226715088, 'learning_rate': 0.00015840000000000003, 'epoch': 0.01}


 40%|███▉      | 199/500 [05:25<08:11,  1.63s/it]

{'loss': 1.0241, 'grad_norm': 1.6128382682800293, 'learning_rate': 0.00015920000000000002, 'epoch': 0.01}


 40%|████      | 200/500 [05:27<08:10,  1.63s/it]

{'loss': 0.7114, 'grad_norm': 1.463457703590393, 'learning_rate': 0.00016, 'epoch': 0.01}


 40%|████      | 201/500 [05:29<08:08,  1.63s/it]

{'loss': 0.7952, 'grad_norm': 1.9663381576538086, 'learning_rate': 0.0001608, 'epoch': 0.01}


 40%|████      | 202/500 [05:30<08:06,  1.63s/it]

{'loss': 0.82, 'grad_norm': 1.3471438884735107, 'learning_rate': 0.00016160000000000002, 'epoch': 0.01}


 41%|████      | 203/500 [05:32<08:05,  1.63s/it]

{'loss': 1.2711, 'grad_norm': 1.552713394165039, 'learning_rate': 0.00016240000000000002, 'epoch': 0.01}


 41%|████      | 204/500 [05:34<08:03,  1.63s/it]

{'loss': 1.0812, 'grad_norm': 1.6525088548660278, 'learning_rate': 0.0001632, 'epoch': 0.01}


 41%|████      | 205/500 [05:35<08:02,  1.63s/it]

{'loss': 1.1807, 'grad_norm': 1.5568053722381592, 'learning_rate': 0.000164, 'epoch': 0.01}


 41%|████      | 206/500 [05:37<08:00,  1.63s/it]

{'loss': 0.7802, 'grad_norm': 2.0203299522399902, 'learning_rate': 0.0001648, 'epoch': 0.01}


 41%|████▏     | 207/500 [05:38<07:58,  1.63s/it]

{'loss': 1.2346, 'grad_norm': 1.4882316589355469, 'learning_rate': 0.0001656, 'epoch': 0.01}


 42%|████▏     | 208/500 [05:40<07:56,  1.63s/it]

{'loss': 1.0824, 'grad_norm': 1.381462574005127, 'learning_rate': 0.0001664, 'epoch': 0.01}


 42%|████▏     | 209/500 [05:42<07:55,  1.63s/it]

{'loss': 1.4559, 'grad_norm': 1.6230849027633667, 'learning_rate': 0.0001672, 'epoch': 0.01}


 42%|████▏     | 210/500 [05:43<07:53,  1.63s/it]

{'loss': 1.1202, 'grad_norm': 1.6105053424835205, 'learning_rate': 0.000168, 'epoch': 0.01}


 42%|████▏     | 211/500 [05:45<07:51,  1.63s/it]

{'loss': 1.2804, 'grad_norm': 1.7575520277023315, 'learning_rate': 0.0001688, 'epoch': 0.01}


 42%|████▏     | 212/500 [05:47<07:50,  1.63s/it]

{'loss': 1.1476, 'grad_norm': 1.4534780979156494, 'learning_rate': 0.0001696, 'epoch': 0.01}


 43%|████▎     | 213/500 [05:48<07:48,  1.63s/it]

{'loss': 1.2089, 'grad_norm': 1.6715227365493774, 'learning_rate': 0.0001704, 'epoch': 0.01}


 43%|████▎     | 214/500 [05:50<07:47,  1.63s/it]

{'loss': 0.7338, 'grad_norm': 1.4267735481262207, 'learning_rate': 0.00017120000000000001, 'epoch': 0.01}


 43%|████▎     | 215/500 [05:52<07:45,  1.63s/it]

{'loss': 0.7939, 'grad_norm': 1.351159930229187, 'learning_rate': 0.000172, 'epoch': 0.01}


 43%|████▎     | 216/500 [05:53<07:43,  1.63s/it]

{'loss': 0.9597, 'grad_norm': 1.354401707649231, 'learning_rate': 0.0001728, 'epoch': 0.01}


 43%|████▎     | 217/500 [05:55<07:42,  1.63s/it]

{'loss': 1.275, 'grad_norm': 1.647094488143921, 'learning_rate': 0.00017360000000000002, 'epoch': 0.01}


 44%|████▎     | 218/500 [05:56<07:40,  1.63s/it]

{'loss': 1.1589, 'grad_norm': 1.5071678161621094, 'learning_rate': 0.0001744, 'epoch': 0.01}


 44%|████▍     | 219/500 [05:58<07:38,  1.63s/it]

{'loss': 0.7914, 'grad_norm': 1.335300326347351, 'learning_rate': 0.0001752, 'epoch': 0.01}


 44%|████▍     | 220/500 [06:00<07:37,  1.63s/it]

{'loss': 0.5997, 'grad_norm': 1.7317445278167725, 'learning_rate': 0.00017600000000000002, 'epoch': 0.01}


 44%|████▍     | 221/500 [06:01<07:36,  1.63s/it]

{'loss': 0.9697, 'grad_norm': 1.7267537117004395, 'learning_rate': 0.00017680000000000001, 'epoch': 0.01}


 44%|████▍     | 222/500 [06:03<07:34,  1.63s/it]

{'loss': 1.2529, 'grad_norm': 1.914074420928955, 'learning_rate': 0.0001776, 'epoch': 0.01}


 45%|████▍     | 223/500 [06:05<07:32,  1.63s/it]

{'loss': 0.9655, 'grad_norm': 1.403516411781311, 'learning_rate': 0.0001784, 'epoch': 0.01}


 45%|████▍     | 224/500 [06:06<07:31,  1.64s/it]

{'loss': 0.8335, 'grad_norm': 1.5503950119018555, 'learning_rate': 0.00017920000000000002, 'epoch': 0.01}


 45%|████▌     | 225/500 [06:08<07:29,  1.63s/it]

{'loss': 1.2629, 'grad_norm': 1.6929566860198975, 'learning_rate': 0.00018, 'epoch': 0.01}


 45%|████▌     | 226/500 [06:10<07:27,  1.63s/it]

{'loss': 0.9614, 'grad_norm': 2.4828429222106934, 'learning_rate': 0.0001808, 'epoch': 0.01}


                                                 

{'loss': 1.0357, 'grad_norm': 1.5519770383834839, 'learning_rate': 0.00018160000000000002, 'epoch': 0.01}


 46%|████▌     | 228/500 [06:13<07:23,  1.63s/it]

{'loss': 1.2403, 'grad_norm': 1.6836614608764648, 'learning_rate': 0.00018240000000000002, 'epoch': 0.01}


 46%|████▌     | 229/500 [06:14<07:22,  1.63s/it]

{'loss': 0.9789, 'grad_norm': 1.530909776687622, 'learning_rate': 0.0001832, 'epoch': 0.01}


 46%|████▌     | 230/500 [06:16<07:20,  1.63s/it]

{'loss': 0.771, 'grad_norm': 1.5825693607330322, 'learning_rate': 0.00018400000000000003, 'epoch': 0.01}


 46%|████▌     | 231/500 [06:18<07:19,  1.63s/it]

{'loss': 1.0412, 'grad_norm': 1.4103798866271973, 'learning_rate': 0.00018480000000000002, 'epoch': 0.02}


 46%|████▋     | 232/500 [06:19<07:17,  1.63s/it]

{'loss': 0.9276, 'grad_norm': 1.2779414653778076, 'learning_rate': 0.0001856, 'epoch': 0.02}


 47%|████▋     | 233/500 [06:21<07:15,  1.63s/it]

{'loss': 0.8348, 'grad_norm': 1.280408501625061, 'learning_rate': 0.00018640000000000003, 'epoch': 0.02}


 47%|████▋     | 234/500 [06:23<07:14,  1.63s/it]

{'loss': 1.07, 'grad_norm': 1.6178674697875977, 'learning_rate': 0.00018720000000000002, 'epoch': 0.02}


 47%|████▋     | 235/500 [06:24<07:12,  1.63s/it]

{'loss': 1.1227, 'grad_norm': 1.5126190185546875, 'learning_rate': 0.000188, 'epoch': 0.02}


 47%|████▋     | 236/500 [06:26<07:10,  1.63s/it]

{'loss': 1.1364, 'grad_norm': 2.2866687774658203, 'learning_rate': 0.0001888, 'epoch': 0.02}


 47%|████▋     | 237/500 [06:27<07:09,  1.63s/it]

{'loss': 0.7644, 'grad_norm': 1.7970434427261353, 'learning_rate': 0.0001896, 'epoch': 0.02}


 48%|████▊     | 238/500 [06:29<07:07,  1.63s/it]

{'loss': 1.1834, 'grad_norm': 1.6063437461853027, 'learning_rate': 0.0001904, 'epoch': 0.02}


 48%|████▊     | 239/500 [06:31<07:06,  1.63s/it]

{'loss': 0.8372, 'grad_norm': 1.3356834650039673, 'learning_rate': 0.0001912, 'epoch': 0.02}


 48%|████▊     | 240/500 [06:32<07:04,  1.63s/it]

{'loss': 0.8221, 'grad_norm': 1.5905930995941162, 'learning_rate': 0.000192, 'epoch': 0.02}


 48%|████▊     | 241/500 [06:34<07:03,  1.63s/it]

{'loss': 0.8172, 'grad_norm': 1.2271316051483154, 'learning_rate': 0.0001928, 'epoch': 0.02}


 48%|████▊     | 242/500 [06:36<07:01,  1.63s/it]

{'loss': 0.6781, 'grad_norm': 1.4689662456512451, 'learning_rate': 0.00019360000000000002, 'epoch': 0.02}


 49%|████▊     | 243/500 [06:37<06:59,  1.63s/it]

{'loss': 1.2693, 'grad_norm': 1.5498120784759521, 'learning_rate': 0.0001944, 'epoch': 0.02}


 49%|████▉     | 244/500 [06:39<06:57,  1.63s/it]

{'loss': 1.0594, 'grad_norm': 1.518126368522644, 'learning_rate': 0.0001952, 'epoch': 0.02}


 49%|████▉     | 245/500 [06:41<06:56,  1.63s/it]

{'loss': 1.1771, 'grad_norm': 1.55638587474823, 'learning_rate': 0.000196, 'epoch': 0.02}


 49%|████▉     | 246/500 [06:42<06:54,  1.63s/it]

{'loss': 0.9966, 'grad_norm': 1.4261810779571533, 'learning_rate': 0.0001968, 'epoch': 0.02}


 49%|████▉     | 247/500 [06:44<06:53,  1.63s/it]

{'loss': 1.3095, 'grad_norm': 1.638160228729248, 'learning_rate': 0.0001976, 'epoch': 0.02}


 50%|████▉     | 248/500 [06:45<06:51,  1.63s/it]

{'loss': 1.2725, 'grad_norm': 1.6754518747329712, 'learning_rate': 0.0001984, 'epoch': 0.02}


 50%|████▉     | 249/500 [06:47<06:50,  1.63s/it]

{'loss': 1.0022, 'grad_norm': 2.4520692825317383, 'learning_rate': 0.00019920000000000002, 'epoch': 0.02}


 50%|█████     | 250/500 [06:49<06:48,  1.64s/it]

{'loss': 1.0573, 'grad_norm': 1.326831340789795, 'learning_rate': 0.0002, 'epoch': 0.02}


 50%|█████     | 251/500 [06:50<06:46,  1.63s/it]

{'loss': 1.0156, 'grad_norm': 1.3122682571411133, 'learning_rate': 0.00019920000000000002, 'epoch': 0.02}


 50%|█████     | 252/500 [06:52<06:45,  1.63s/it]

{'loss': 1.166, 'grad_norm': 2.0937447547912598, 'learning_rate': 0.0001984, 'epoch': 0.02}


 51%|█████     | 253/500 [06:54<06:43,  1.63s/it]

{'loss': 0.6245, 'grad_norm': 1.1341845989227295, 'learning_rate': 0.0001976, 'epoch': 0.02}


 51%|█████     | 254/500 [06:55<06:42,  1.63s/it]

{'loss': 0.8326, 'grad_norm': 1.1626161336898804, 'learning_rate': 0.0001968, 'epoch': 0.02}


 51%|█████     | 255/500 [06:57<06:40,  1.64s/it]

{'loss': 0.8587, 'grad_norm': 1.3687927722930908, 'learning_rate': 0.000196, 'epoch': 0.02}


 51%|█████     | 256/500 [06:59<06:38,  1.63s/it]

{'loss': 0.9415, 'grad_norm': 1.443182349205017, 'learning_rate': 0.0001952, 'epoch': 0.02}


 51%|█████▏    | 257/500 [07:00<06:37,  1.63s/it]

{'loss': 1.0699, 'grad_norm': 1.4957154989242554, 'learning_rate': 0.0001944, 'epoch': 0.02}


 52%|█████▏    | 258/500 [07:02<06:35,  1.63s/it]

{'loss': 1.0102, 'grad_norm': 1.307771921157837, 'learning_rate': 0.00019360000000000002, 'epoch': 0.02}


 52%|█████▏    | 259/500 [07:03<06:33,  1.63s/it]

{'loss': 0.919, 'grad_norm': 1.6542714834213257, 'learning_rate': 0.0001928, 'epoch': 0.02}


 52%|█████▏    | 260/500 [07:05<06:32,  1.63s/it]

{'loss': 1.0906, 'grad_norm': 1.4570691585540771, 'learning_rate': 0.000192, 'epoch': 0.02}


 52%|█████▏    | 261/500 [07:07<06:30,  1.63s/it]

{'loss': 1.0563, 'grad_norm': 1.4133572578430176, 'learning_rate': 0.0001912, 'epoch': 0.02}


 52%|█████▏    | 262/500 [07:08<06:28,  1.63s/it]

{'loss': 0.6832, 'grad_norm': 1.4783408641815186, 'learning_rate': 0.0001904, 'epoch': 0.02}


 53%|█████▎    | 263/500 [07:10<06:26,  1.63s/it]

{'loss': 1.0986, 'grad_norm': 1.5264620780944824, 'learning_rate': 0.0001896, 'epoch': 0.02}


 53%|█████▎    | 264/500 [07:12<06:25,  1.63s/it]

{'loss': 0.9481, 'grad_norm': 1.5070178508758545, 'learning_rate': 0.0001888, 'epoch': 0.02}


 53%|█████▎    | 265/500 [07:13<06:23,  1.63s/it]

{'loss': 1.3575, 'grad_norm': 1.4197278022766113, 'learning_rate': 0.000188, 'epoch': 0.02}


 53%|█████▎    | 266/500 [07:15<06:22,  1.63s/it]

{'loss': 0.9903, 'grad_norm': 1.588348388671875, 'learning_rate': 0.00018720000000000002, 'epoch': 0.02}


 53%|█████▎    | 267/500 [07:16<06:20,  1.63s/it]

{'loss': 1.3476, 'grad_norm': 1.476796269416809, 'learning_rate': 0.00018640000000000003, 'epoch': 0.02}


 54%|█████▎    | 268/500 [07:18<06:18,  1.63s/it]

{'loss': 0.9008, 'grad_norm': 1.4072660207748413, 'learning_rate': 0.0001856, 'epoch': 0.02}


 54%|█████▍    | 269/500 [07:20<06:16,  1.63s/it]

{'loss': 0.6343, 'grad_norm': 1.8145833015441895, 'learning_rate': 0.00018480000000000002, 'epoch': 0.02}


 54%|█████▍    | 270/500 [07:21<06:15,  1.63s/it]

{'loss': 0.6359, 'grad_norm': 1.4709874391555786, 'learning_rate': 0.00018400000000000003, 'epoch': 0.02}


 54%|█████▍    | 271/500 [07:23<06:13,  1.63s/it]

{'loss': 0.7681, 'grad_norm': 1.4295666217803955, 'learning_rate': 0.0001832, 'epoch': 0.02}


 54%|█████▍    | 272/500 [07:25<06:12,  1.63s/it]

{'loss': 1.2301, 'grad_norm': 1.5791453123092651, 'learning_rate': 0.00018240000000000002, 'epoch': 0.02}


 55%|█████▍    | 273/500 [07:26<06:10,  1.63s/it]

{'loss': 1.2613, 'grad_norm': 1.5480996370315552, 'learning_rate': 0.00018160000000000002, 'epoch': 0.02}


 55%|█████▍    | 274/500 [07:28<06:08,  1.63s/it]

{'loss': 1.6321, 'grad_norm': 1.5232908725738525, 'learning_rate': 0.0001808, 'epoch': 0.02}


 55%|█████▌    | 275/500 [07:30<06:07,  1.63s/it]

{'loss': 0.9187, 'grad_norm': 1.275515079498291, 'learning_rate': 0.00018, 'epoch': 0.02}


 55%|█████▌    | 276/500 [07:31<06:05,  1.63s/it]

{'loss': 0.9942, 'grad_norm': 1.3503837585449219, 'learning_rate': 0.00017920000000000002, 'epoch': 0.02}


 55%|█████▌    | 277/500 [07:33<06:04,  1.63s/it]

{'loss': 0.9691, 'grad_norm': 1.4067039489746094, 'learning_rate': 0.0001784, 'epoch': 0.02}


 56%|█████▌    | 278/500 [07:34<06:02,  1.63s/it]

{'loss': 1.6352, 'grad_norm': 1.8567637205123901, 'learning_rate': 0.0001776, 'epoch': 0.02}


 56%|█████▌    | 279/500 [07:36<06:00,  1.63s/it]

{'loss': 1.1829, 'grad_norm': 1.4396535158157349, 'learning_rate': 0.00017680000000000001, 'epoch': 0.02}


 56%|█████▌    | 280/500 [07:38<05:59,  1.63s/it]

{'loss': 1.079, 'grad_norm': 1.533004641532898, 'learning_rate': 0.00017600000000000002, 'epoch': 0.02}


 56%|█████▌    | 281/500 [07:39<05:57,  1.63s/it]

{'loss': 0.8889, 'grad_norm': 1.3632042407989502, 'learning_rate': 0.0001752, 'epoch': 0.02}


 56%|█████▋    | 282/500 [07:41<05:55,  1.63s/it]

{'loss': 1.0559, 'grad_norm': 1.5870790481567383, 'learning_rate': 0.0001744, 'epoch': 0.02}


 57%|█████▋    | 283/500 [07:43<05:53,  1.63s/it]

{'loss': 0.7765, 'grad_norm': 1.3863394260406494, 'learning_rate': 0.00017360000000000002, 'epoch': 0.02}


 57%|█████▋    | 284/500 [07:44<05:52,  1.63s/it]

{'loss': 1.2136, 'grad_norm': 1.4117792844772339, 'learning_rate': 0.0001728, 'epoch': 0.02}


 57%|█████▋    | 285/500 [07:46<05:50,  1.63s/it]

{'loss': 0.7672, 'grad_norm': 1.4009807109832764, 'learning_rate': 0.000172, 'epoch': 0.02}


 57%|█████▋    | 286/500 [07:47<05:49,  1.63s/it]

{'loss': 1.2228, 'grad_norm': 1.46223783493042, 'learning_rate': 0.00017120000000000001, 'epoch': 0.02}


 57%|█████▋    | 287/500 [07:49<05:47,  1.63s/it]

{'loss': 1.0373, 'grad_norm': 1.2215322256088257, 'learning_rate': 0.0001704, 'epoch': 0.02}


 58%|█████▊    | 288/500 [07:51<05:45,  1.63s/it]

{'loss': 1.0833, 'grad_norm': 1.2538864612579346, 'learning_rate': 0.0001696, 'epoch': 0.02}


 58%|█████▊    | 289/500 [07:52<05:44,  1.63s/it]

{'loss': 1.2493, 'grad_norm': 1.6506984233856201, 'learning_rate': 0.0001688, 'epoch': 0.02}


 58%|█████▊    | 290/500 [07:54<05:42,  1.63s/it]

{'loss': 1.421, 'grad_norm': 1.309770107269287, 'learning_rate': 0.000168, 'epoch': 0.02}


 58%|█████▊    | 291/500 [07:56<05:41,  1.63s/it]

{'loss': 1.1914, 'grad_norm': 1.4009416103363037, 'learning_rate': 0.0001672, 'epoch': 0.02}


 58%|█████▊    | 292/500 [07:57<05:39,  1.63s/it]

{'loss': 0.8264, 'grad_norm': 1.4128608703613281, 'learning_rate': 0.0001664, 'epoch': 0.02}


 59%|█████▊    | 293/500 [07:59<05:37,  1.63s/it]

{'loss': 1.2384, 'grad_norm': 1.5083552598953247, 'learning_rate': 0.0001656, 'epoch': 0.02}


 59%|█████▉    | 294/500 [08:01<05:36,  1.63s/it]

{'loss': 1.1383, 'grad_norm': 1.4036983251571655, 'learning_rate': 0.0001648, 'epoch': 0.02}


 59%|█████▉    | 295/500 [08:02<05:34,  1.63s/it]

{'loss': 0.8696, 'grad_norm': 1.3365284204483032, 'learning_rate': 0.000164, 'epoch': 0.02}


 59%|█████▉    | 296/500 [08:04<05:32,  1.63s/it]

{'loss': 1.1928, 'grad_norm': 1.3188196420669556, 'learning_rate': 0.0001632, 'epoch': 0.02}


 59%|█████▉    | 297/500 [08:05<05:31,  1.63s/it]

{'loss': 1.0071, 'grad_norm': 1.2694014310836792, 'learning_rate': 0.00016240000000000002, 'epoch': 0.02}


 60%|█████▉    | 298/500 [08:07<05:29,  1.63s/it]

{'loss': 0.9431, 'grad_norm': 1.231036901473999, 'learning_rate': 0.00016160000000000002, 'epoch': 0.02}


 60%|█████▉    | 299/500 [08:09<05:27,  1.63s/it]

{'loss': 0.7936, 'grad_norm': 1.448082447052002, 'learning_rate': 0.0001608, 'epoch': 0.02}


 60%|██████    | 300/500 [08:10<05:26,  1.63s/it]

{'loss': 0.9853, 'grad_norm': 1.6087507009506226, 'learning_rate': 0.00016, 'epoch': 0.02}


 60%|██████    | 301/500 [08:12<05:24,  1.63s/it]

{'loss': 1.3052, 'grad_norm': 1.3622050285339355, 'learning_rate': 0.00015920000000000002, 'epoch': 0.02}


 60%|██████    | 302/500 [08:14<05:23,  1.63s/it]

{'loss': 0.7495, 'grad_norm': 1.0858871936798096, 'learning_rate': 0.00015840000000000003, 'epoch': 0.02}


 61%|██████    | 303/500 [08:15<05:21,  1.63s/it]

{'loss': 1.1237, 'grad_norm': 1.5032992362976074, 'learning_rate': 0.0001576, 'epoch': 0.02}


 61%|██████    | 304/500 [08:17<05:19,  1.63s/it]

{'loss': 1.0717, 'grad_norm': 1.664565086364746, 'learning_rate': 0.00015680000000000002, 'epoch': 0.02}


 61%|██████    | 305/500 [08:18<05:17,  1.63s/it]

{'loss': 1.1038, 'grad_norm': 1.5416405200958252, 'learning_rate': 0.00015600000000000002, 'epoch': 0.02}


 61%|██████    | 306/500 [08:20<05:16,  1.63s/it]

{'loss': 1.2932, 'grad_norm': 1.5375949144363403, 'learning_rate': 0.0001552, 'epoch': 0.02}


 61%|██████▏   | 307/500 [08:22<05:14,  1.63s/it]

{'loss': 1.1985, 'grad_norm': 1.3814051151275635, 'learning_rate': 0.0001544, 'epoch': 0.02}


 62%|██████▏   | 308/500 [08:23<05:12,  1.63s/it]

{'loss': 1.0849, 'grad_norm': 1.3224091529846191, 'learning_rate': 0.00015360000000000002, 'epoch': 0.02}


 62%|██████▏   | 309/500 [08:25<05:10,  1.63s/it]

{'loss': 0.7266, 'grad_norm': 1.1626375913619995, 'learning_rate': 0.0001528, 'epoch': 0.02}


 62%|██████▏   | 310/500 [08:27<05:09,  1.63s/it]

{'loss': 1.4868, 'grad_norm': 1.3433473110198975, 'learning_rate': 0.000152, 'epoch': 0.02}


 62%|██████▏   | 311/500 [08:28<05:07,  1.63s/it]

{'loss': 1.0662, 'grad_norm': 1.5674501657485962, 'learning_rate': 0.00015120000000000002, 'epoch': 0.02}


 62%|██████▏   | 312/500 [08:30<05:05,  1.63s/it]

{'loss': 0.9839, 'grad_norm': 1.5270897150039673, 'learning_rate': 0.0001504, 'epoch': 0.02}


 63%|██████▎   | 313/500 [08:32<05:04,  1.63s/it]

{'loss': 1.2371, 'grad_norm': 1.8545507192611694, 'learning_rate': 0.0001496, 'epoch': 0.02}


 63%|██████▎   | 314/500 [08:33<05:02,  1.63s/it]

{'loss': 0.8008, 'grad_norm': 1.4340969324111938, 'learning_rate': 0.0001488, 'epoch': 0.02}


 63%|██████▎   | 315/500 [08:35<05:01,  1.63s/it]

{'loss': 0.7301, 'grad_norm': 1.161571979522705, 'learning_rate': 0.000148, 'epoch': 0.02}


 63%|██████▎   | 316/500 [08:36<04:59,  1.63s/it]

{'loss': 1.1383, 'grad_norm': 1.6660701036453247, 'learning_rate': 0.0001472, 'epoch': 0.02}


 63%|██████▎   | 317/500 [08:38<04:58,  1.63s/it]

{'loss': 0.8935, 'grad_norm': 1.1537994146347046, 'learning_rate': 0.0001464, 'epoch': 0.02}


 64%|██████▎   | 318/500 [08:40<04:56,  1.63s/it]

{'loss': 0.9862, 'grad_norm': 1.7526382207870483, 'learning_rate': 0.00014560000000000002, 'epoch': 0.02}


 64%|██████▍   | 319/500 [08:41<04:54,  1.63s/it]

{'loss': 1.4396, 'grad_norm': 2.0458672046661377, 'learning_rate': 0.0001448, 'epoch': 0.02}


 64%|██████▍   | 320/500 [08:43<04:53,  1.63s/it]

{'loss': 0.8805, 'grad_norm': 1.0983359813690186, 'learning_rate': 0.000144, 'epoch': 0.02}


 64%|██████▍   | 321/500 [08:45<04:51,  1.63s/it]

{'loss': 1.1896, 'grad_norm': 1.4028509855270386, 'learning_rate': 0.0001432, 'epoch': 0.02}


 64%|██████▍   | 322/500 [08:46<04:49,  1.63s/it]

{'loss': 0.9133, 'grad_norm': 1.3520619869232178, 'learning_rate': 0.0001424, 'epoch': 0.02}


 65%|██████▍   | 323/500 [08:48<04:48,  1.63s/it]

{'loss': 0.9625, 'grad_norm': 1.722306251525879, 'learning_rate': 0.0001416, 'epoch': 0.02}


 65%|██████▍   | 324/500 [08:49<04:46,  1.63s/it]

{'loss': 0.9969, 'grad_norm': 1.8542009592056274, 'learning_rate': 0.0001408, 'epoch': 0.02}


 65%|██████▌   | 325/500 [08:51<04:44,  1.63s/it]

{'loss': 1.0885, 'grad_norm': 1.2235890626907349, 'learning_rate': 0.00014, 'epoch': 0.02}


 65%|██████▌   | 326/500 [08:53<04:43,  1.63s/it]

{'loss': 0.8256, 'grad_norm': 1.2326151132583618, 'learning_rate': 0.0001392, 'epoch': 0.02}


 65%|██████▌   | 327/500 [08:54<04:41,  1.63s/it]

{'loss': 1.4215, 'grad_norm': 1.6234488487243652, 'learning_rate': 0.0001384, 'epoch': 0.02}


 66%|██████▌   | 328/500 [08:56<04:39,  1.63s/it]

{'loss': 1.641, 'grad_norm': 1.4249414205551147, 'learning_rate': 0.00013759999999999998, 'epoch': 0.02}


 66%|██████▌   | 329/500 [08:58<04:38,  1.63s/it]

{'loss': 0.9243, 'grad_norm': 1.1247376203536987, 'learning_rate': 0.00013680000000000002, 'epoch': 0.02}


 66%|██████▌   | 330/500 [08:59<04:36,  1.63s/it]

{'loss': 0.8785, 'grad_norm': 1.2669590711593628, 'learning_rate': 0.00013600000000000003, 'epoch': 0.02}


 66%|██████▌   | 331/500 [09:01<04:34,  1.63s/it]

{'loss': 1.058, 'grad_norm': 1.404228687286377, 'learning_rate': 0.0001352, 'epoch': 0.02}


 66%|██████▋   | 332/500 [09:02<04:33,  1.63s/it]

{'loss': 0.7948, 'grad_norm': 1.2612550258636475, 'learning_rate': 0.00013440000000000001, 'epoch': 0.02}


 67%|██████▋   | 333/500 [09:04<04:31,  1.63s/it]

{'loss': 0.9947, 'grad_norm': 1.2975374460220337, 'learning_rate': 0.00013360000000000002, 'epoch': 0.02}


 67%|██████▋   | 334/500 [09:06<04:30,  1.63s/it]

{'loss': 1.2052, 'grad_norm': 1.389978051185608, 'learning_rate': 0.0001328, 'epoch': 0.02}


 67%|██████▋   | 335/500 [09:07<04:28,  1.63s/it]

{'loss': 1.2608, 'grad_norm': 1.3067333698272705, 'learning_rate': 0.000132, 'epoch': 0.02}


 67%|██████▋   | 336/500 [09:09<04:26,  1.63s/it]

{'loss': 1.0147, 'grad_norm': 1.0923635959625244, 'learning_rate': 0.00013120000000000002, 'epoch': 0.02}


 67%|██████▋   | 337/500 [09:11<04:25,  1.63s/it]

{'loss': 1.0272, 'grad_norm': 1.3865101337432861, 'learning_rate': 0.0001304, 'epoch': 0.02}


 68%|██████▊   | 338/500 [09:12<04:23,  1.63s/it]

{'loss': 1.1552, 'grad_norm': 1.3793721199035645, 'learning_rate': 0.0001296, 'epoch': 0.02}


 68%|██████▊   | 339/500 [09:14<04:22,  1.63s/it]

{'loss': 1.1206, 'grad_norm': 1.445041298866272, 'learning_rate': 0.00012880000000000001, 'epoch': 0.02}


 68%|██████▊   | 340/500 [09:15<04:20,  1.63s/it]

{'loss': 0.8608, 'grad_norm': 1.1938637495040894, 'learning_rate': 0.00012800000000000002, 'epoch': 0.02}


 68%|██████▊   | 341/500 [09:17<04:18,  1.63s/it]

{'loss': 0.9714, 'grad_norm': 1.4300309419631958, 'learning_rate': 0.0001272, 'epoch': 0.02}


 68%|██████▊   | 342/500 [09:19<04:17,  1.63s/it]

{'loss': 1.1605, 'grad_norm': 1.3672370910644531, 'learning_rate': 0.0001264, 'epoch': 0.02}


 69%|██████▊   | 343/500 [09:20<04:15,  1.63s/it]

{'loss': 0.6759, 'grad_norm': 1.1901456117630005, 'learning_rate': 0.00012560000000000002, 'epoch': 0.02}


 69%|██████▉   | 344/500 [09:22<04:13,  1.63s/it]

{'loss': 0.9645, 'grad_norm': 1.3162599802017212, 'learning_rate': 0.0001248, 'epoch': 0.02}


 69%|██████▉   | 345/500 [09:24<04:12,  1.63s/it]

{'loss': 1.2289, 'grad_norm': 1.3041964769363403, 'learning_rate': 0.000124, 'epoch': 0.02}


 69%|██████▉   | 346/500 [09:25<04:10,  1.63s/it]

{'loss': 1.0643, 'grad_norm': 1.4984873533248901, 'learning_rate': 0.0001232, 'epoch': 0.02}


 69%|██████▉   | 347/500 [09:27<04:08,  1.63s/it]

{'loss': 0.8492, 'grad_norm': 1.275097131729126, 'learning_rate': 0.0001224, 'epoch': 0.02}


 70%|██████▉   | 348/500 [09:28<04:07,  1.63s/it]

{'loss': 1.1029, 'grad_norm': 1.189995288848877, 'learning_rate': 0.0001216, 'epoch': 0.02}


 70%|██████▉   | 349/500 [09:30<04:05,  1.63s/it]

{'loss': 1.0407, 'grad_norm': 1.4015679359436035, 'learning_rate': 0.0001208, 'epoch': 0.02}


 70%|███████   | 350/500 [09:32<04:03,  1.63s/it]

{'loss': 0.8264, 'grad_norm': 1.2119569778442383, 'learning_rate': 0.00012, 'epoch': 0.02}


 70%|███████   | 351/500 [09:33<04:02,  1.63s/it]

{'loss': 0.9589, 'grad_norm': 1.82626211643219, 'learning_rate': 0.0001192, 'epoch': 0.02}


 70%|███████   | 352/500 [09:35<04:00,  1.63s/it]

{'loss': 0.8458, 'grad_norm': 1.368525743484497, 'learning_rate': 0.0001184, 'epoch': 0.02}


 71%|███████   | 353/500 [09:37<03:59,  1.63s/it]

{'loss': 0.8711, 'grad_norm': 1.3839961290359497, 'learning_rate': 0.0001176, 'epoch': 0.02}


 71%|███████   | 354/500 [09:38<03:57,  1.63s/it]

{'loss': 1.2359, 'grad_norm': 1.3660188913345337, 'learning_rate': 0.00011679999999999999, 'epoch': 0.02}


 71%|███████   | 355/500 [09:40<03:55,  1.63s/it]

{'loss': 1.1392, 'grad_norm': 2.7630903720855713, 'learning_rate': 0.000116, 'epoch': 0.02}


 71%|███████   | 356/500 [09:41<03:54,  1.63s/it]

{'loss': 0.9053, 'grad_norm': 1.1715997457504272, 'learning_rate': 0.0001152, 'epoch': 0.02}


 71%|███████▏  | 357/500 [09:43<03:52,  1.63s/it]

{'loss': 1.0564, 'grad_norm': 1.237025499343872, 'learning_rate': 0.0001144, 'epoch': 0.02}


 72%|███████▏  | 358/500 [09:45<03:50,  1.63s/it]

{'loss': 1.1716, 'grad_norm': 1.4573943614959717, 'learning_rate': 0.0001136, 'epoch': 0.02}


 72%|███████▏  | 359/500 [09:46<03:49,  1.63s/it]

{'loss': 1.069, 'grad_norm': 1.3615694046020508, 'learning_rate': 0.00011279999999999999, 'epoch': 0.02}


 72%|███████▏  | 360/500 [09:48<03:47,  1.63s/it]

{'loss': 1.2567, 'grad_norm': 1.3627982139587402, 'learning_rate': 0.00011200000000000001, 'epoch': 0.02}


 72%|███████▏  | 361/500 [09:50<03:46,  1.63s/it]

{'loss': 1.2127, 'grad_norm': 1.3777461051940918, 'learning_rate': 0.00011120000000000002, 'epoch': 0.02}


 72%|███████▏  | 362/500 [09:51<03:44,  1.63s/it]

{'loss': 1.0949, 'grad_norm': 1.31735360622406, 'learning_rate': 0.00011040000000000001, 'epoch': 0.02}


 73%|███████▎  | 363/500 [09:53<03:42,  1.63s/it]

{'loss': 1.0491, 'grad_norm': 1.1881794929504395, 'learning_rate': 0.00010960000000000001, 'epoch': 0.02}


 73%|███████▎  | 364/500 [09:55<03:41,  1.63s/it]

{'loss': 1.1184, 'grad_norm': 1.1656306982040405, 'learning_rate': 0.00010880000000000002, 'epoch': 0.02}


 73%|███████▎  | 365/500 [09:56<03:39,  1.63s/it]

{'loss': 1.4158, 'grad_norm': 1.551520824432373, 'learning_rate': 0.00010800000000000001, 'epoch': 0.02}


 73%|███████▎  | 366/500 [09:58<03:38,  1.63s/it]

{'loss': 1.0996, 'grad_norm': 1.3158007860183716, 'learning_rate': 0.00010720000000000002, 'epoch': 0.02}


 73%|███████▎  | 367/500 [09:59<03:36,  1.63s/it]

{'loss': 1.363, 'grad_norm': 1.4955027103424072, 'learning_rate': 0.00010640000000000001, 'epoch': 0.02}


 74%|███████▎  | 368/500 [10:01<03:34,  1.63s/it]

{'loss': 1.1315, 'grad_norm': 1.4372849464416504, 'learning_rate': 0.0001056, 'epoch': 0.02}


 74%|███████▍  | 369/500 [10:03<03:33,  1.63s/it]

{'loss': 0.8217, 'grad_norm': 1.4549458026885986, 'learning_rate': 0.00010480000000000001, 'epoch': 0.02}


 74%|███████▍  | 370/500 [10:04<03:31,  1.63s/it]

{'loss': 1.2641, 'grad_norm': 1.754401683807373, 'learning_rate': 0.00010400000000000001, 'epoch': 0.02}


 74%|███████▍  | 371/500 [10:06<03:29,  1.63s/it]

{'loss': 1.0417, 'grad_norm': 1.2727577686309814, 'learning_rate': 0.0001032, 'epoch': 0.02}


 74%|███████▍  | 372/500 [10:08<03:28,  1.63s/it]

{'loss': 0.6573, 'grad_norm': 1.1276887655258179, 'learning_rate': 0.00010240000000000001, 'epoch': 0.02}


 75%|███████▍  | 373/500 [10:09<03:26,  1.63s/it]

{'loss': 1.1939, 'grad_norm': 1.3871551752090454, 'learning_rate': 0.0001016, 'epoch': 0.02}


 75%|███████▍  | 374/500 [10:11<03:25,  1.63s/it]

{'loss': 1.1361, 'grad_norm': 1.2557176351547241, 'learning_rate': 0.00010080000000000001, 'epoch': 0.02}


 75%|███████▌  | 375/500 [10:12<03:23,  1.63s/it]

{'loss': 0.8941, 'grad_norm': 1.3712372779846191, 'learning_rate': 0.0001, 'epoch': 0.02}


 75%|███████▌  | 376/500 [10:14<03:21,  1.63s/it]

{'loss': 1.3641, 'grad_norm': 1.3602312803268433, 'learning_rate': 9.92e-05, 'epoch': 0.02}


 75%|███████▌  | 377/500 [10:16<03:20,  1.63s/it]

{'loss': 1.2844, 'grad_norm': 1.5043669939041138, 'learning_rate': 9.84e-05, 'epoch': 0.02}


 76%|███████▌  | 378/500 [10:17<03:18,  1.63s/it]

{'loss': 1.3349, 'grad_norm': 1.5664345026016235, 'learning_rate': 9.76e-05, 'epoch': 0.02}


 76%|███████▌  | 379/500 [10:19<03:17,  1.63s/it]

{'loss': 0.9914, 'grad_norm': 1.08395516872406, 'learning_rate': 9.680000000000001e-05, 'epoch': 0.02}


 76%|███████▌  | 380/500 [10:21<03:15,  1.63s/it]

{'loss': 0.9523, 'grad_norm': 1.2042745351791382, 'learning_rate': 9.6e-05, 'epoch': 0.02}


 76%|███████▌  | 381/500 [10:22<03:13,  1.63s/it]

{'loss': 0.9053, 'grad_norm': 1.1258329153060913, 'learning_rate': 9.52e-05, 'epoch': 0.02}


 76%|███████▋  | 382/500 [10:24<03:12,  1.63s/it]

{'loss': 0.761, 'grad_norm': 1.127242922782898, 'learning_rate': 9.44e-05, 'epoch': 0.02}


 77%|███████▋  | 383/500 [10:25<03:10,  1.63s/it]

{'loss': 1.0371, 'grad_norm': 1.2315423488616943, 'learning_rate': 9.360000000000001e-05, 'epoch': 0.02}


 77%|███████▋  | 384/500 [10:27<03:08,  1.63s/it]

{'loss': 0.9731, 'grad_norm': 1.236390233039856, 'learning_rate': 9.28e-05, 'epoch': 0.03}


 77%|███████▋  | 385/500 [10:29<03:06,  1.63s/it]

{'loss': 0.7568, 'grad_norm': 1.1252883672714233, 'learning_rate': 9.200000000000001e-05, 'epoch': 0.03}


 77%|███████▋  | 386/500 [10:30<03:05,  1.63s/it]

{'loss': 1.2622, 'grad_norm': 1.3062915802001953, 'learning_rate': 9.120000000000001e-05, 'epoch': 0.03}


 77%|███████▋  | 387/500 [10:32<03:03,  1.63s/it]

{'loss': 0.9964, 'grad_norm': 1.322613000869751, 'learning_rate': 9.04e-05, 'epoch': 0.03}


 78%|███████▊  | 388/500 [10:34<03:02,  1.63s/it]

{'loss': 1.1215, 'grad_norm': 1.3723728656768799, 'learning_rate': 8.960000000000001e-05, 'epoch': 0.03}


 78%|███████▊  | 389/500 [10:35<03:00,  1.63s/it]

{'loss': 1.0444, 'grad_norm': 1.2172638177871704, 'learning_rate': 8.88e-05, 'epoch': 0.03}


 78%|███████▊  | 390/500 [10:37<02:58,  1.63s/it]

{'loss': 0.879, 'grad_norm': 1.5813677310943604, 'learning_rate': 8.800000000000001e-05, 'epoch': 0.03}


 78%|███████▊  | 391/500 [10:38<02:57,  1.63s/it]

{'loss': 0.9249, 'grad_norm': 1.9351580142974854, 'learning_rate': 8.72e-05, 'epoch': 0.03}


 78%|███████▊  | 392/500 [10:40<02:55,  1.63s/it]

{'loss': 0.9175, 'grad_norm': 1.7009713649749756, 'learning_rate': 8.64e-05, 'epoch': 0.03}


 79%|███████▊  | 393/500 [10:42<03:02,  1.70s/it]

{'loss': 1.2546, 'grad_norm': 1.448543906211853, 'learning_rate': 8.560000000000001e-05, 'epoch': 0.03}


 79%|███████▉  | 394/500 [10:44<02:57,  1.68s/it]

{'loss': 1.002, 'grad_norm': 1.116928219795227, 'learning_rate': 8.48e-05, 'epoch': 0.03}


 79%|███████▉  | 395/500 [10:45<02:54,  1.66s/it]

{'loss': 1.0434, 'grad_norm': 1.213767647743225, 'learning_rate': 8.4e-05, 'epoch': 0.03}


 79%|███████▉  | 396/500 [10:47<02:51,  1.65s/it]

{'loss': 1.211, 'grad_norm': 1.2032907009124756, 'learning_rate': 8.32e-05, 'epoch': 0.03}


 79%|███████▉  | 397/500 [10:48<02:49,  1.65s/it]

{'loss': 1.0014, 'grad_norm': 1.4893544912338257, 'learning_rate': 8.24e-05, 'epoch': 0.03}


 80%|███████▉  | 398/500 [10:50<02:47,  1.64s/it]

{'loss': 0.7141, 'grad_norm': 1.1960383653640747, 'learning_rate': 8.16e-05, 'epoch': 0.03}


 80%|███████▉  | 399/500 [10:52<02:45,  1.64s/it]

{'loss': 1.1631, 'grad_norm': 1.587382197380066, 'learning_rate': 8.080000000000001e-05, 'epoch': 0.03}


 80%|████████  | 400/500 [10:53<02:43,  1.63s/it]

{'loss': 1.23, 'grad_norm': 1.2737855911254883, 'learning_rate': 8e-05, 'epoch': 0.03}


 80%|████████  | 401/500 [10:55<02:41,  1.63s/it]

{'loss': 0.9836, 'grad_norm': 1.173248052597046, 'learning_rate': 7.920000000000001e-05, 'epoch': 0.03}


 80%|████████  | 402/500 [10:57<02:39,  1.63s/it]

{'loss': 0.9894, 'grad_norm': 1.3249462842941284, 'learning_rate': 7.840000000000001e-05, 'epoch': 0.03}


 81%|████████  | 403/500 [10:58<02:38,  1.63s/it]

{'loss': 1.2554, 'grad_norm': 1.416970133781433, 'learning_rate': 7.76e-05, 'epoch': 0.03}


 81%|████████  | 404/500 [11:00<02:36,  1.63s/it]

{'loss': 1.2668, 'grad_norm': 1.453173041343689, 'learning_rate': 7.680000000000001e-05, 'epoch': 0.03}


 81%|████████  | 405/500 [11:01<02:34,  1.63s/it]

{'loss': 1.2556, 'grad_norm': 1.3612627983093262, 'learning_rate': 7.6e-05, 'epoch': 0.03}


 81%|████████  | 406/500 [11:03<02:33,  1.63s/it]

{'loss': 1.0105, 'grad_norm': 1.2603133916854858, 'learning_rate': 7.52e-05, 'epoch': 0.03}


 81%|████████▏ | 407/500 [11:05<02:31,  1.63s/it]

{'loss': 1.163, 'grad_norm': 1.8236860036849976, 'learning_rate': 7.44e-05, 'epoch': 0.03}


 82%|████████▏ | 408/500 [11:06<02:29,  1.63s/it]

{'loss': 1.2699, 'grad_norm': 1.3485418558120728, 'learning_rate': 7.36e-05, 'epoch': 0.03}


 82%|████████▏ | 409/500 [11:08<02:28,  1.63s/it]

{'loss': 1.2824, 'grad_norm': 1.2993643283843994, 'learning_rate': 7.280000000000001e-05, 'epoch': 0.03}


 82%|████████▏ | 410/500 [11:10<02:26,  1.63s/it]

{'loss': 0.943, 'grad_norm': 1.379197120666504, 'learning_rate': 7.2e-05, 'epoch': 0.03}


 82%|████████▏ | 411/500 [11:11<02:24,  1.63s/it]

{'loss': 1.2, 'grad_norm': 1.3079336881637573, 'learning_rate': 7.12e-05, 'epoch': 0.03}


 82%|████████▏ | 412/500 [11:13<02:23,  1.63s/it]

{'loss': 1.1309, 'grad_norm': 1.394849181175232, 'learning_rate': 7.04e-05, 'epoch': 0.03}


 83%|████████▎ | 413/500 [11:15<02:21,  1.63s/it]

{'loss': 1.0415, 'grad_norm': 1.2839776277542114, 'learning_rate': 6.96e-05, 'epoch': 0.03}


 83%|████████▎ | 414/500 [11:17<02:30,  1.75s/it]

{'loss': 0.662, 'grad_norm': 1.0919253826141357, 'learning_rate': 6.879999999999999e-05, 'epoch': 0.03}


 83%|████████▎ | 415/500 [11:18<02:25,  1.71s/it]

{'loss': 1.2319, 'grad_norm': 1.288647174835205, 'learning_rate': 6.800000000000001e-05, 'epoch': 0.03}


 83%|████████▎ | 416/500 [11:20<02:21,  1.69s/it]

{'loss': 0.9448, 'grad_norm': 1.2108290195465088, 'learning_rate': 6.720000000000001e-05, 'epoch': 0.03}


 83%|████████▎ | 417/500 [11:21<02:18,  1.67s/it]

{'loss': 0.842, 'grad_norm': 1.2709823846817017, 'learning_rate': 6.64e-05, 'epoch': 0.03}


 84%|████████▎ | 418/500 [11:23<02:15,  1.66s/it]

{'loss': 0.9714, 'grad_norm': 1.272174596786499, 'learning_rate': 6.560000000000001e-05, 'epoch': 0.03}


 84%|████████▍ | 419/500 [11:25<02:13,  1.65s/it]

{'loss': 0.4736, 'grad_norm': 1.486792802810669, 'learning_rate': 6.48e-05, 'epoch': 0.03}


 84%|████████▍ | 420/500 [11:26<02:11,  1.64s/it]

{'loss': 1.3218, 'grad_norm': 1.5025370121002197, 'learning_rate': 6.400000000000001e-05, 'epoch': 0.03}


 84%|████████▍ | 421/500 [11:28<02:09,  1.64s/it]

{'loss': 1.0177, 'grad_norm': 1.3201254606246948, 'learning_rate': 6.32e-05, 'epoch': 0.03}


 84%|████████▍ | 422/500 [11:30<02:07,  1.63s/it]

{'loss': 1.2795, 'grad_norm': 1.4079028367996216, 'learning_rate': 6.24e-05, 'epoch': 0.03}


 85%|████████▍ | 423/500 [11:31<02:05,  1.63s/it]

{'loss': 1.4249, 'grad_norm': 1.4907509088516235, 'learning_rate': 6.16e-05, 'epoch': 0.03}


 85%|████████▍ | 424/500 [11:33<02:03,  1.63s/it]

{'loss': 0.7886, 'grad_norm': 1.0670909881591797, 'learning_rate': 6.08e-05, 'epoch': 0.03}


 85%|████████▌ | 425/500 [11:34<02:02,  1.63s/it]

{'loss': 0.7987, 'grad_norm': 1.1520075798034668, 'learning_rate': 6e-05, 'epoch': 0.03}


 85%|████████▌ | 426/500 [11:36<02:00,  1.63s/it]

{'loss': 1.2324, 'grad_norm': 1.7460328340530396, 'learning_rate': 5.92e-05, 'epoch': 0.03}


 85%|████████▌ | 427/500 [11:38<01:58,  1.63s/it]

{'loss': 1.1475, 'grad_norm': 1.1884019374847412, 'learning_rate': 5.8399999999999997e-05, 'epoch': 0.03}


 86%|████████▌ | 428/500 [11:39<01:57,  1.63s/it]

{'loss': 0.8538, 'grad_norm': 1.228063702583313, 'learning_rate': 5.76e-05, 'epoch': 0.03}


 86%|████████▌ | 429/500 [11:41<01:55,  1.63s/it]

{'loss': 1.361, 'grad_norm': 1.3980408906936646, 'learning_rate': 5.68e-05, 'epoch': 0.03}


 86%|████████▌ | 430/500 [11:43<01:53,  1.63s/it]

{'loss': 0.8587, 'grad_norm': 1.1927605867385864, 'learning_rate': 5.6000000000000006e-05, 'epoch': 0.03}


 86%|████████▌ | 431/500 [11:44<01:52,  1.63s/it]

{'loss': 0.9095, 'grad_norm': 1.2141889333724976, 'learning_rate': 5.520000000000001e-05, 'epoch': 0.03}


 86%|████████▋ | 432/500 [11:46<01:50,  1.63s/it]

{'loss': 1.0281, 'grad_norm': 1.2853031158447266, 'learning_rate': 5.440000000000001e-05, 'epoch': 0.03}


 87%|████████▋ | 433/500 [11:47<01:49,  1.63s/it]

{'loss': 0.7382, 'grad_norm': 1.074002981185913, 'learning_rate': 5.360000000000001e-05, 'epoch': 0.03}


 87%|████████▋ | 434/500 [11:49<01:47,  1.63s/it]

{'loss': 0.858, 'grad_norm': 1.1633257865905762, 'learning_rate': 5.28e-05, 'epoch': 0.03}


 87%|████████▋ | 435/500 [11:51<01:45,  1.63s/it]

{'loss': 0.8956, 'grad_norm': 1.848616361618042, 'learning_rate': 5.2000000000000004e-05, 'epoch': 0.03}


 87%|████████▋ | 436/500 [11:52<01:44,  1.63s/it]

{'loss': 0.7403, 'grad_norm': 1.4529424905776978, 'learning_rate': 5.1200000000000004e-05, 'epoch': 0.03}


 87%|████████▋ | 437/500 [11:54<01:42,  1.63s/it]

{'loss': 1.1003, 'grad_norm': 1.2765549421310425, 'learning_rate': 5.0400000000000005e-05, 'epoch': 0.03}


 88%|████████▊ | 438/500 [11:56<01:40,  1.63s/it]

{'loss': 1.0247, 'grad_norm': 1.261775016784668, 'learning_rate': 4.96e-05, 'epoch': 0.03}


 88%|████████▊ | 439/500 [11:57<01:39,  1.63s/it]

{'loss': 0.9935, 'grad_norm': 1.1839841604232788, 'learning_rate': 4.88e-05, 'epoch': 0.03}


 88%|████████▊ | 440/500 [11:59<01:37,  1.63s/it]

{'loss': 0.8753, 'grad_norm': 0.9844209551811218, 'learning_rate': 4.8e-05, 'epoch': 0.03}


 88%|████████▊ | 441/500 [12:00<01:36,  1.63s/it]

{'loss': 0.6437, 'grad_norm': 1.1680328845977783, 'learning_rate': 4.72e-05, 'epoch': 0.03}


 88%|████████▊ | 442/500 [12:02<01:34,  1.63s/it]

{'loss': 0.8347, 'grad_norm': 1.1865073442459106, 'learning_rate': 4.64e-05, 'epoch': 0.03}


 89%|████████▊ | 443/500 [12:04<01:32,  1.63s/it]

{'loss': 0.7313, 'grad_norm': 1.2804219722747803, 'learning_rate': 4.5600000000000004e-05, 'epoch': 0.03}


 89%|████████▉ | 444/500 [12:05<01:31,  1.63s/it]

{'loss': 0.9844, 'grad_norm': 1.1490280628204346, 'learning_rate': 4.4800000000000005e-05, 'epoch': 0.03}


 89%|████████▉ | 445/500 [12:07<01:29,  1.63s/it]

{'loss': 1.2258, 'grad_norm': 1.318365454673767, 'learning_rate': 4.4000000000000006e-05, 'epoch': 0.03}


 89%|████████▉ | 446/500 [12:09<01:27,  1.63s/it]

{'loss': 1.3458, 'grad_norm': 1.3035773038864136, 'learning_rate': 4.32e-05, 'epoch': 0.03}


 89%|████████▉ | 447/500 [12:10<01:26,  1.63s/it]

{'loss': 1.0989, 'grad_norm': 1.3492556810379028, 'learning_rate': 4.24e-05, 'epoch': 0.03}


 90%|████████▉ | 448/500 [12:12<01:24,  1.63s/it]

{'loss': 1.4442, 'grad_norm': 1.4504685401916504, 'learning_rate': 4.16e-05, 'epoch': 0.03}


 90%|████████▉ | 449/500 [12:13<01:23,  1.63s/it]

{'loss': 0.6945, 'grad_norm': 1.1917155981063843, 'learning_rate': 4.08e-05, 'epoch': 0.03}


 90%|█████████ | 450/500 [12:15<01:21,  1.63s/it]

{'loss': 1.1355, 'grad_norm': 1.1929389238357544, 'learning_rate': 4e-05, 'epoch': 0.03}


 90%|█████████ | 451/500 [12:17<01:19,  1.63s/it]

{'loss': 0.8759, 'grad_norm': 1.0915582180023193, 'learning_rate': 3.9200000000000004e-05, 'epoch': 0.03}


 90%|█████████ | 452/500 [12:18<01:18,  1.63s/it]

{'loss': 1.0283, 'grad_norm': 1.1743122339248657, 'learning_rate': 3.8400000000000005e-05, 'epoch': 0.03}


 91%|█████████ | 453/500 [12:20<01:16,  1.63s/it]

{'loss': 0.8209, 'grad_norm': 1.0063540935516357, 'learning_rate': 3.76e-05, 'epoch': 0.03}


 91%|█████████ | 454/500 [12:22<01:14,  1.63s/it]

{'loss': 0.5894, 'grad_norm': 1.0845820903778076, 'learning_rate': 3.68e-05, 'epoch': 0.03}


 91%|█████████ | 455/500 [12:23<01:13,  1.63s/it]

{'loss': 0.6906, 'grad_norm': 1.556687831878662, 'learning_rate': 3.6e-05, 'epoch': 0.03}


 91%|█████████ | 456/500 [12:25<01:11,  1.63s/it]

{'loss': 0.9279, 'grad_norm': 1.3336527347564697, 'learning_rate': 3.52e-05, 'epoch': 0.03}


 91%|█████████▏| 457/500 [12:27<01:09,  1.63s/it]

{'loss': 1.0804, 'grad_norm': 1.4213913679122925, 'learning_rate': 3.4399999999999996e-05, 'epoch': 0.03}


 92%|█████████▏| 458/500 [12:28<01:08,  1.63s/it]

{'loss': 1.1328, 'grad_norm': 1.3273924589157104, 'learning_rate': 3.3600000000000004e-05, 'epoch': 0.03}


 92%|█████████▏| 459/500 [12:30<01:06,  1.63s/it]

{'loss': 1.1395, 'grad_norm': 1.5201661586761475, 'learning_rate': 3.2800000000000004e-05, 'epoch': 0.03}


 92%|█████████▏| 460/500 [12:31<01:05,  1.63s/it]

{'loss': 0.8544, 'grad_norm': 1.239298939704895, 'learning_rate': 3.2000000000000005e-05, 'epoch': 0.03}


 92%|█████████▏| 461/500 [12:33<01:03,  1.63s/it]

{'loss': 0.844, 'grad_norm': 1.171979546546936, 'learning_rate': 3.12e-05, 'epoch': 0.03}


 92%|█████████▏| 462/500 [12:35<01:01,  1.63s/it]

{'loss': 1.4349, 'grad_norm': 1.2754545211791992, 'learning_rate': 3.04e-05, 'epoch': 0.03}


 93%|█████████▎| 463/500 [12:36<01:00,  1.63s/it]

{'loss': 1.0075, 'grad_norm': 1.2357275485992432, 'learning_rate': 2.96e-05, 'epoch': 0.03}


 93%|█████████▎| 464/500 [12:38<00:58,  1.63s/it]

{'loss': 0.663, 'grad_norm': 1.0249427556991577, 'learning_rate': 2.88e-05, 'epoch': 0.03}


 93%|█████████▎| 465/500 [12:40<00:56,  1.63s/it]

{'loss': 1.0012, 'grad_norm': 1.1258524656295776, 'learning_rate': 2.8000000000000003e-05, 'epoch': 0.03}


 93%|█████████▎| 466/500 [12:41<00:55,  1.63s/it]

{'loss': 1.0468, 'grad_norm': 1.2545663118362427, 'learning_rate': 2.7200000000000004e-05, 'epoch': 0.03}


 93%|█████████▎| 467/500 [12:43<00:53,  1.63s/it]

{'loss': 1.1492, 'grad_norm': 1.5446910858154297, 'learning_rate': 2.64e-05, 'epoch': 0.03}


 94%|█████████▎| 468/500 [12:44<00:52,  1.63s/it]

{'loss': 1.2116, 'grad_norm': 1.3497676849365234, 'learning_rate': 2.5600000000000002e-05, 'epoch': 0.03}


 94%|█████████▍| 469/500 [12:46<00:50,  1.63s/it]

{'loss': 1.0153, 'grad_norm': 1.4157112836837769, 'learning_rate': 2.48e-05, 'epoch': 0.03}


 94%|█████████▍| 470/500 [12:48<00:48,  1.63s/it]

{'loss': 1.0601, 'grad_norm': 1.211586833000183, 'learning_rate': 2.4e-05, 'epoch': 0.03}


 94%|█████████▍| 471/500 [12:49<00:47,  1.63s/it]

{'loss': 0.9398, 'grad_norm': 1.2557697296142578, 'learning_rate': 2.32e-05, 'epoch': 0.03}


 94%|█████████▍| 472/500 [12:51<00:45,  1.63s/it]

{'loss': 0.8397, 'grad_norm': 1.4671001434326172, 'learning_rate': 2.2400000000000002e-05, 'epoch': 0.03}


 95%|█████████▍| 473/500 [12:53<00:43,  1.63s/it]

{'loss': 0.7429, 'grad_norm': 1.169774055480957, 'learning_rate': 2.16e-05, 'epoch': 0.03}


 95%|█████████▍| 474/500 [12:54<00:42,  1.63s/it]

{'loss': 1.1435, 'grad_norm': 1.2047168016433716, 'learning_rate': 2.08e-05, 'epoch': 0.03}


 95%|█████████▌| 475/500 [12:56<00:40,  1.63s/it]

{'loss': 1.3974, 'grad_norm': 1.3961127996444702, 'learning_rate': 2e-05, 'epoch': 0.03}


 95%|█████████▌| 476/500 [12:57<00:39,  1.63s/it]

{'loss': 0.8155, 'grad_norm': 1.3029673099517822, 'learning_rate': 1.9200000000000003e-05, 'epoch': 0.03}


 95%|█████████▌| 477/500 [12:59<00:37,  1.63s/it]

{'loss': 0.9623, 'grad_norm': 1.4393552541732788, 'learning_rate': 1.84e-05, 'epoch': 0.03}


 96%|█████████▌| 478/500 [13:01<00:35,  1.63s/it]

{'loss': 1.2081, 'grad_norm': 1.2896305322647095, 'learning_rate': 1.76e-05, 'epoch': 0.03}


 96%|█████████▌| 479/500 [13:02<00:34,  1.63s/it]

{'loss': 1.3153, 'grad_norm': 1.4638837575912476, 'learning_rate': 1.6800000000000002e-05, 'epoch': 0.03}


 96%|█████████▌| 480/500 [13:04<00:32,  1.63s/it]

{'loss': 1.1873, 'grad_norm': 1.334862470626831, 'learning_rate': 1.6000000000000003e-05, 'epoch': 0.03}


 96%|█████████▌| 481/500 [13:06<00:30,  1.63s/it]

{'loss': 1.1394, 'grad_norm': 1.358751654624939, 'learning_rate': 1.52e-05, 'epoch': 0.03}


 96%|█████████▋| 482/500 [13:07<00:29,  1.63s/it]

{'loss': 0.8094, 'grad_norm': 1.293616771697998, 'learning_rate': 1.44e-05, 'epoch': 0.03}


 97%|█████████▋| 483/500 [13:09<00:27,  1.63s/it]

{'loss': 0.861, 'grad_norm': 1.2255572080612183, 'learning_rate': 1.3600000000000002e-05, 'epoch': 0.03}


 97%|█████████▋| 484/500 [13:10<00:26,  1.63s/it]

{'loss': 1.0814, 'grad_norm': 1.2570701837539673, 'learning_rate': 1.2800000000000001e-05, 'epoch': 0.03}


 97%|█████████▋| 485/500 [13:12<00:24,  1.63s/it]

{'loss': 0.8414, 'grad_norm': 1.3477081060409546, 'learning_rate': 1.2e-05, 'epoch': 0.03}


 97%|█████████▋| 486/500 [13:14<00:22,  1.63s/it]

{'loss': 1.0183, 'grad_norm': 1.2374134063720703, 'learning_rate': 1.1200000000000001e-05, 'epoch': 0.03}


 97%|█████████▋| 487/500 [13:15<00:21,  1.63s/it]

{'loss': 0.8202, 'grad_norm': 1.4044742584228516, 'learning_rate': 1.04e-05, 'epoch': 0.03}


 98%|█████████▊| 488/500 [13:17<00:19,  1.63s/it]

{'loss': 0.8936, 'grad_norm': 1.3534083366394043, 'learning_rate': 9.600000000000001e-06, 'epoch': 0.03}


 98%|█████████▊| 489/500 [13:19<00:17,  1.63s/it]

{'loss': 0.988, 'grad_norm': 1.3966670036315918, 'learning_rate': 8.8e-06, 'epoch': 0.03}


 98%|█████████▊| 490/500 [13:20<00:16,  1.63s/it]

{'loss': 1.1272, 'grad_norm': 1.0868432521820068, 'learning_rate': 8.000000000000001e-06, 'epoch': 0.03}


 98%|█████████▊| 491/500 [13:22<00:14,  1.63s/it]

{'loss': 1.1211, 'grad_norm': 1.3881926536560059, 'learning_rate': 7.2e-06, 'epoch': 0.03}


 98%|█████████▊| 492/500 [13:23<00:13,  1.63s/it]

{'loss': 0.6539, 'grad_norm': 0.9183357357978821, 'learning_rate': 6.4000000000000006e-06, 'epoch': 0.03}


 99%|█████████▊| 493/500 [13:25<00:11,  1.63s/it]

{'loss': 1.1617, 'grad_norm': 1.3111252784729004, 'learning_rate': 5.600000000000001e-06, 'epoch': 0.03}


 99%|█████████▉| 494/500 [13:27<00:09,  1.62s/it]

{'loss': 1.1881, 'grad_norm': 1.299904227256775, 'learning_rate': 4.800000000000001e-06, 'epoch': 0.03}


 99%|█████████▉| 495/500 [13:28<00:08,  1.63s/it]

{'loss': 1.2242, 'grad_norm': 1.3795512914657593, 'learning_rate': 4.000000000000001e-06, 'epoch': 0.03}


 99%|█████████▉| 496/500 [13:30<00:06,  1.63s/it]

{'loss': 1.2499, 'grad_norm': 1.2474477291107178, 'learning_rate': 3.2000000000000003e-06, 'epoch': 0.03}


 99%|█████████▉| 497/500 [13:32<00:04,  1.63s/it]

{'loss': 0.9071, 'grad_norm': 1.0177066326141357, 'learning_rate': 2.4000000000000003e-06, 'epoch': 0.03}


100%|█████████▉| 498/500 [13:33<00:03,  1.63s/it]

{'loss': 0.9578, 'grad_norm': 1.1276839971542358, 'learning_rate': 1.6000000000000001e-06, 'epoch': 0.03}


100%|█████████▉| 499/500 [13:35<00:01,  1.63s/it]

{'loss': 0.7951, 'grad_norm': 1.06610107421875, 'learning_rate': 8.000000000000001e-07, 'epoch': 0.03}


100%|██████████| 500/500 [13:36<00:00,  1.63s/it]

{'loss': 1.0327, 'grad_norm': 1.386320948600769, 'learning_rate': 0.0, 'epoch': 0.03}




UnboundLocalError: local variable 'active_adapters' referenced before assignment

In [30]:
after_finetune_code = [code_generator(prompt) for prompt in prompts]

In [36]:
# This template helps to compare generated code samples in pretty table form
# feel free to present your work in other forms

from IPython.display import HTML, display
table_template = """<table style="border:1px solid black" >
  <tr>
    <th style="text-align: center; border:1px solid black">PROMPT</th>
    <th style="text-align: center; border:1px solid black">BEFORE</th>
    <th style="text-align: center; border:1px solid black">AFTER</th>
  </tr>
{}
</table>"""

row_template = '''  <tr>
    <td style="width:20%; border:1px solid black"><pre align="left">`{}`</pre></td>
    <td style="width:40%; border:1px solid black"><pre align="left">{}</pre></td>
    <td style="width:40%; border:1px solid black"><pre align="left">{}</pre></td>
  </tr>'''

rows = []

for i in range(len(prompts)):
    # replace placeholders in the format() arguments
    rows.append(row_template.format(prompts[i], before_finetune_code[i][3:], after_finetune_code[i][3:]))

display(HTML(table_template.format('\n'.join(rows))))

PROMPT,BEFORE,AFTER
``,▶▶ 2019-2020 School Year The 2019-2020 school year is here! We are so excited to welcome our new students and families to the school. We are also excited to welcome back our returning families. We are looking forward to another great year at the school. We are also looking forward to another great year of learning and growing together. We are also looking forward to another great year of learning and growing together,"# -*- coding: utf-8 -*- # # Copyright 2015-2016, OpenStack Foundation # # Licensed under the Apache License, Version 2.0 (the ""License""); you may # not use this file except in compliance with the License. You may obtain # a copy of the License at # # http://www.apache.org/licenses/LICEN"
`import`,import Foundation public extension NSURL {  public var absoluteString: String {  return String(cString: CFBundleGetBundleWithURL(self).UTF8String)  } }package com.google.gerrit.server.restapi; import static javax.ws.rs.core.MediaType.APPLICATION_JSON; import com.google.gerrit.extensions.restapi.RestApiModule; import,"import os import sys import time import traceback from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.management.base import BaseCommand from django.db import connection, transaction from django.db.models import get_models from django.db.utils import DatabaseError from django.utils.translation import ugettext_lazy as _ from django.utils.time"
`from`,from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import sys from absl import flags from tensorflow.python import pywrap_tensorflow from tensorflow.python.eager import context from tensorflow.python.eager import function from tensorflow.python.eager import test from tensorflow.python.eager import backprop from tensorflow.python.e,"from __future__ import absolute_import import os import sys import time import traceback from django.conf import settings from django.core.exceptions import ImproperlyConfigured from django.core.management.base import BaseCommand from django.db import connection, transaction from django.db.models import get_models from django.db.utils import DatabaseError from django.utils.translation import ugettext_"
`while`,"while(1) while(1) {  // do something } \end{code} Comment: This is not the same as the OP's code. Comment: @Jeffrey: It's the same as the OP's code, except that it's not a function. Comment: @Jeffrey: The OP's code is a function, but it's not a function declaration. Comment: @","while True:  try:  print(""Enter a number between 1 and 1000000000000000000000000000000000000000000000000000000000000000000000000000000000"
`try`,try to find the best solution for your needs. We are a team of professionals with a long experience in the field of web development. We are a team of professionals with a long experience in the field of web development. We are a team of professionals with a long experience in the field of web development. We are a team of professionals with a long experience in the field of web development. We are a team of professionals with a long experience in the field of web development,"try:  import unittest2 as unittest except ImportError:  import unittest from . import test_utils class TestBase(unittest.TestCase):  def test_get_all_names(self):  self.assertEqual(test_utils.get_all_names(), ['test_get_all_names'])  def test_get_all_names_with_args(self):"
`if`,"if ( !window.atmosphere ) {  window.atmosphere = {}; } (function () {  var o = atmosphere.util,  atmosphere = atmosphere.atmosphere = function () {  var _isClosed = false,  _isOpening = false,  _isOpen = false,  _isClosing = false,  _isError = false,  _isReady =","if (typeof exports !== 'undefined') {  exports.default = __WEBPACK_EXTERNAL_MODULE_1__; } else {  var __WEBPACK_EXTERNAL_MODULE_1__ = (function () {  'use strict';  var _createClass = function () { function defineProperties(target, props) { for (var i = 0; i < props.length"
`for`,for the 2019-2020 school year. The application process for the 2019-2020 school year is now open. The application process for the 2019-2020 school year is now open. Please click here to apply. The application process for the 2019-2020 school year is now open. Please click here to apply. The application process for the 2,"for i in range(1, 10):  print(i)  print(i * 2)  print(i * 3)  print(i * 4)  print(i * 5)  print(i * 6)  print(i * 7)  print(i * 8)  print(i * 9)  print(i * 10)"
`torch`,"torchbearer 2017-05-18 19:55:25 UTC #1 I’m a newbie to the world of RPGs, and I’m looking for a game that I can play with my wife. We’re both in our 30s, and we’re looking for a game that we can play together. We’re both new to the world of RPGs, and we’re looking",torch.setdefaulttensortype('torch.FloatTensor') torch.setdefaultnumpytype('float32') torch.setdefaultnumpytype('float64') torch.setdefaultnumpytype('int32') torch.setdefaultnumpytype('int64') torch.setdefaultnumpytype('uint32') torch.setdefaultnumpytype


If you reach this: congratulations! you've completed everything in this practice session.

If you want to dig deeper, try to implement prompt-tuning (for bonus points!).
You can read more about prompt tuning variants in paper [1](https://arxiv.org/abs/2104.08691) or paper [2](https://arxiv.org/abs/2101.00190). Both versions can be implemented by passing trainable prompts as `model.forward(..., past_key_values=your_prompts)`.



### Read more

* How post-training quantization works: https://arxiv.org/abs/2208.07339
* An overview of running large models: https://huggingface.co/docs/accelerate/package_reference/big_modeling
* A general library for different adapter types: https://adapterhub.ml/


### [extra info] Running other models.

This notebook's code can run with other models of similar size, such as [Falcon-7B](https://huggingface.co/tiiuae/falcon-7b), [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b) or [BLOOM-7.1B](https://huggingface.co/bigscience/bloom-7b1). However, they will require minor code tweaks:
1. change the model name in `AutoModelForCausalLM.from_pretrained()` __and__ `AutoTokenizer`
2. In the prompt tuning code, change `model.model.embed_tokens` to refer to the target model's word embeddings. Simply `print(model)` to navigate to them.
3. Change code to add Lora layers - specifically where you what the transformer block components, since those components now have different names.