# Setup

In [1]:
import torch
from torch.optim import AdamW
from transformers import BertTokenizer, BertForSequenceClassification, BertForQuestionAnswering, DataCollatorWithPadding, TrainingArguments, Trainer, BertConfig
from datasets import load_dataset
import numpy as np
import evaluate
from torch.utils.data import DataLoader
import math
import torch.nn.modules as nn
from packaging import version
from typing import Optional, Tuple
from torch.nn.attention import SDPBackend, sdpa_kernel
import time
import random
import collections
from transformers import AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


# Custom Attention Class

In [2]:
class CustomAttention(nn.Module):
    def __init__(self, config, num_context_tokens):
        super().__init__()
        # The following equality must be possible num_attention_heads * attention_head_size = config.hidden_size.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})")
        
        # assign attention_head size such that num_attention_heads * attention_head_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)    #    
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Initialize projections matrices and dropout layer
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)        

        # Central Attention complexity is O(num_context_tokens * sequence_length * hidden_size)
        self.num_context_tokens = num_context_tokens
    
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (batch_size, num_tokens, num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3) # (batch_size, num_attention_heads, num_tokens, attention_head_size)

    # Adapted from BertSelfAttention
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        
        # Project input into query, key, value matrices respectively
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Split q,k, and v for the central attention.
        # Tokens ordering: [CLS] + [CONTEXT_1] + [CONTEXT_2] ... [CONTEXT_N] + ALL SENTENCE TOKENS + [SEP] + [PAD] + [PAD] + [PAD] ....
        # first token is [CLS], then the context tokens, then the sentence, then the [SEP], and finally the padding.
        c = self.num_context_tokens + 1  # [CLS] + context tokens are treated as the central context tokens.
        query_layer_context     = query_layer[:,:,:c,:]
        query_layer_sentence    = query_layer[:,:,c:,:]
        key_layer_context       = key_layer[:,:,:c,:]
        key_layer_sentence      = key_layer[:,:,c:,:]
        value_layer_context     = value_layer[:,:,:c,:]
        value_layer_sentence    = value_layer[:,:,c:,:]

        if attention_mask is not None:
            attention_mask_context  = attention_mask[:,:,:c,c:]
            attention_mask_sentence = attention_mask[:,:,c:,:c]  # should always be zeros, since the central context tokens will never contain padding
            attention_mask_sentence_to_itself = attention_mask.diagonal(dim1=-2,dim2=-1)[:,:,c:].unsqueeze(-1) # (batch_size, 1, n, 1)
          
        if self.num_context_tokens == 0: # Original Attention
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            # Normalize the attention scores to probabilities.
            attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)
            attn_output = torch.matmul(attention_probs, value_layer)

        else: # Central Attention 

            attention_scores_context_to_sentence = torch.matmul(query_layer_context, key_layer_sentence.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
            attention_scores_sentence_to_context = torch.matmul(query_layer_sentence, key_layer_context.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
            attention_scores_sentence_to_itself = torch.sum(query_layer_sentence * key_layer_sentence,axis = -1, keepdim=True)/ math.sqrt(self.attention_head_size) # (batch_size, attention_heads, n, 1)

            # print(attention_scores_sentence_to_context.shape) # torch.Size([8, 2, 383, 2])
            # print(attention_scores_sentence_to_itself.shape) # torch.Size([8, 2, 383, 1]) 

            if attention_mask is not None:
                attention_scores_context_to_sentence = attention_scores_context_to_sentence + attention_mask_context
                attention_scores_sentence_to_context = attention_scores_sentence_to_context + attention_mask_sentence
                attention_scores_sentence_to_itself = attention_scores_sentence_to_itself + attention_mask_sentence_to_itself
                # print(attention_mask_sentence_to_itself.shape) # torch.Size([8, 1, 383, 1])
                # print(attention_scores_sentence_to_itself.shape) # torch.Size([8, 2, 383, 1])

            attention_probs_context_to_sentence = torch.nn.functional.softmax(attention_scores_context_to_sentence, dim=-1)
            attention_scores_sentence_to_context_and_itself = torch.cat((attention_scores_sentence_to_context,attention_scores_sentence_to_itself), dim = -1)  # (batch_size, attentino_heads, n, c+1)
            attention_probs_sentence_to_context_and_itself = torch.nn.functional.softmax(attention_scores_sentence_to_context_and_itself, dim=-1)

            # print(attention_scores_sentence_to_context_and_itself.shape) # torch.Size([8, 2, 383, 3])

            attention_probs_context_to_sentence = self.dropout(attention_probs_context_to_sentence)
            attention_probs_sentence_to_context_and_itself = self.dropout(attention_probs_sentence_to_context_and_itself)

            attention_probs_sentence_to_context = attention_probs_sentence_to_context_and_itself[:,:,:,:-1]
            attention_probs_sentence_to_itself = attention_probs_sentence_to_context_and_itself[:,:,:,-1:]

            # print(attention_probs_sentence_to_context.shape) # torch.Size([8, 2, 383, 2])
            # print(attention_probs_sentence_to_itself.shape) # torch.Size([8, 2, 383, 1])

            attn_output_context = torch.matmul(attention_probs_context_to_sentence, value_layer_sentence)
            attn_output_sentence_to_context= torch.matmul(attention_probs_sentence_to_context, value_layer_context)
            attn_output_sentence_to_itself = attention_probs_sentence_to_itself * value_layer_sentence

            # print(attn_output_sentence_to_context.shape) # torch.Size([8, 2, 383, 64])
            # print(attn_output_sentence_to_itself.shape) # torch.Size([8, 2, 383, 64])

            attn_output_sentence = attn_output_sentence_to_context + attn_output_sentence_to_itself

            # (batch_size, num_attention_heads, c, h)

            attn_output = torch.cat((attn_output_context, attn_output_sentence), dim = 2)

        # reformatting attention output to (batch_size, num_tokens, hidden_size), which is the contextual embedding for each token.
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        new_attn_output_shape = attn_output.size()[:-2] + (self.all_head_size,)
        attn_output = attn_output.view(new_attn_output_shape)

        # Try those
        # attn_output[:,c:,:] = hidden_states[:,c:,:] 
        # let every token form sentence attend to themselves.

        outputs = (attn_output,)

        return outputs

# Classification Task

## Prepare Dataset Function

In [3]:
def prep_dataset(name, num_context_tokens):

    # Using BERT tokenizer
    checkpoint = "bert-base-uncased"
    tokenizer = BertTokenizer.from_pretrained(checkpoint)

    # Add the special context tokens
    special_tokens = {"additional_special_tokens": ["[CONTEXT_"+ str(i) + "]" for i in range(1,num_context_tokens+1)]}
    tokenizer.add_special_tokens(special_tokens)
    context_tokens_ids = [len(tokenizer)-num_context_tokens+i for i in range(num_context_tokens)]

    if name == "sst2":
        raw_dataset = load_dataset("glue", "sst2")
        del raw_dataset["test"] # the labels are -1
        raw_dataset = raw_dataset.remove_columns(["idx"])
        raw_dataset["test"] = raw_dataset["validation"]
        del raw_dataset["validation"]
    elif name == "imdb":
        raw_dataset = load_dataset("imdb")
        del raw_dataset["unsupervised"]
        raw_dataset = raw_dataset.rename_column("text", "sentence")

    def tokenize_and_add_context_tokens(example):    
        output = tokenizer(example["sentence"])
        for i in range(len(output["input_ids"])):
            output["input_ids"][i] = output["input_ids"][i][:1] + context_tokens_ids + output["input_ids"][i][1:]  # [CLS] + context_tokens + sentence
            output["token_type_ids"][i].extend([0] * num_context_tokens)
            output["attention_mask"][i].extend([1] * num_context_tokens)
        return output
    
    tokenized_dataset = raw_dataset.map(tokenize_and_add_context_tokens, batched=True)
    tokenized_dataset = tokenized_dataset.remove_columns(["sentence"])
    tokenized_dataset = tokenized_dataset.rename_column("label", "labels")
    tokenized_dataset.set_format("torch")

    return tokenized_dataset, tokenizer   

## Initialize BERT model with random weights

In [4]:
def initialize_model(num_context_tokens,num_layers=1, max_position_embeddings=4096, intermediate_size=1024, hidden_size=128, num_attention_heads=2):

    vocab_size = 30522
    # Create the model from scratch
    config = BertConfig(
        hidden_size=hidden_size,   # Size of hidden layers
        num_hidden_layers=num_layers,  # Number of transformer layers
        num_attention_heads=num_attention_heads,  # Number of attention heads in each layer
        intermediate_size=intermediate_size,  # Size of the "intermediate" layer (in the feed-forward part of each transformer layer)
        max_position_embeddings= max_position_embeddings,  # Maximum sequence length
    )

    model = BertForSequenceClassification(config)
    model.resize_token_embeddings(vocab_size + num_context_tokens)

    # change the attention of each layer individually 
    for layer in range(num_layers):
        model.bert.encoder.layer[layer].attention.self = CustomAttention(model.config, num_context_tokens)
    
    return model

## Train & Helper Functions

In [6]:
# Prepare training framework
def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

def compute_metrics(eval_preds):
    metric = evaluate.load("glue", "sst2") # just calculates the accuracy
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)
    return metric.compute(predictions=predictions, references=labels)

def train(model, dataset, tokenizer, num_epochs=3, learning_rate = 5e-5, batch_size = 8):

    training_args = TrainingArguments(
        output_dir = "Bert",
        overwrite_output_dir = True,
        run_name = "Bert",
        num_train_epochs = num_epochs,
        learning_rate = learning_rate,
        lr_scheduler_type = "linear",
        per_device_train_batch_size = batch_size,
        per_device_eval_batch_size = batch_size,
        eval_strategy = "no",
        save_strategy = "no",  
        )
    
    trainer = Trainer(
        model,
        training_args,
        train_dataset=dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics
    )

    trainer.train()
    
    return trainer

## Run Experiments  

In [7]:
# if num_context_tokens = 0, then the normal attention will be used, otherwise the central-context attention attention will be used.
def run_experiment(dataset_name,num_context_tokens, seed):
    set_seed(seed)
    dataset, tokenizer = prep_dataset(dataset_name,num_context_tokens)
    model = initialize_model(num_context_tokens=num_context_tokens)
    trainer = train(model,dataset["train"],tokenizer, batch_size = 16)
    acc = trainer.predict(dataset["test"]).metrics["test_accuracy"]

    # log results
    if dataset_name == "sst2":
        with open("SST2_Results.txt", "a") as file:
            file.write("\n")
            file.write(f"Num Context Tokens = {num_context_tokens}      Random Seed = {seed}\n")            
            file.write(f"Accuracy = {acc*100:.2f} %" + "\n")
    else:
        with open("IMDB_Results.txt", "a") as file:
            file.write("\n")
            file.write(f"Num Context Tokens = {num_context_tokens}      Random Seed = {seed}\n")            
            file.write(f"Accuracy = {acc*100:.2f} %" + "\n")

In [8]:
seeds = [2025, 17, 771]
num_context_tokens_options = [0, 1, 8, 32, 128]
datasets = ["sst2", "imdb"]
for dataset_name in datasets:
    for seed in seeds:
        for num_context_tokens in num_context_tokens_options:
            run_experiment(dataset_name, num_context_tokens=num_context_tokens, seed=seed)

  trainer = Trainer(
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mamroadil420[0m ([33mamroadil420-university-of-cambridge[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


  4%|▍         | 509/12630 [00:09<04:12, 48.00it/s]

{'loss': 0.681, 'grad_norm': 2.402047872543335, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1006/12630 [00:20<04:15, 45.53it/s]

{'loss': 0.5838, 'grad_norm': 6.353432655334473, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1501/12630 [00:56<17:33, 10.56it/s]

{'loss': 0.4671, 'grad_norm': 12.074444770812988, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2002/12630 [01:25<14:48, 11.96it/s]

{'loss': 0.4107, 'grad_norm': 12.162517547607422, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2501/12630 [02:01<05:20, 31.61it/s]

{'loss': 0.3989, 'grad_norm': 11.13410758972168, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3003/12630 [02:35<06:07, 26.19it/s]

{'loss': 0.3673, 'grad_norm': 7.544828414916992, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3502/12630 [03:08<09:10, 16.59it/s]

{'loss': 0.3523, 'grad_norm': 8.15599536895752, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4002/12630 [03:43<13:22, 10.76it/s]

{'loss': 0.3395, 'grad_norm': 10.765381813049316, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4503/12630 [04:17<03:14, 41.87it/s]

{'loss': 0.3184, 'grad_norm': 8.075982093811035, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5001/12630 [04:38<11:00, 11.55it/s]

{'loss': 0.2858, 'grad_norm': 4.257066249847412, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5502/12630 [05:08<08:01, 14.80it/s]

{'loss': 0.2797, 'grad_norm': 18.837757110595703, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6001/12630 [05:44<08:46, 12.58it/s]

{'loss': 0.2732, 'grad_norm': 9.32409381866455, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 51%|█████▏    | 6501/12630 [06:20<10:22,  9.85it/s]

{'loss': 0.2785, 'grad_norm': 16.80219841003418, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7004/12630 [07:00<06:34, 14.27it/s]

{'loss': 0.2704, 'grad_norm': 7.835187911987305, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7503/12630 [07:32<02:16, 37.68it/s]

{'loss': 0.2671, 'grad_norm': 9.29244327545166, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8001/12630 [07:56<06:43, 11.48it/s]

{'loss': 0.2712, 'grad_norm': 11.323286056518555, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8501/12630 [08:29<05:35, 12.31it/s]

{'loss': 0.258, 'grad_norm': 3.1486778259277344, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9005/12630 [09:01<02:47, 21.64it/s]

{'loss': 0.2397, 'grad_norm': 1.5793640613555908, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9508/12630 [09:30<01:30, 34.46it/s]

{'loss': 0.24, 'grad_norm': 6.081349849700928, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10004/12630 [10:01<01:37, 26.88it/s]

{'loss': 0.2348, 'grad_norm': 10.095458984375, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10511/12630 [10:34<00:39, 53.10it/s]

{'loss': 0.2421, 'grad_norm': 7.6935296058654785, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11011/12630 [10:49<00:27, 58.67it/s]

{'loss': 0.2438, 'grad_norm': 8.254450798034668, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11513/12630 [11:06<00:18, 59.54it/s]

{'loss': 0.2456, 'grad_norm': 12.548931121826172, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12002/12630 [11:18<00:42, 14.81it/s]

{'loss': 0.2445, 'grad_norm': 6.992396354675293, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12505/12630 [11:30<00:02, 53.61it/s]

{'loss': 0.2267, 'grad_norm': 5.858194351196289, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [11:35<00:00, 18.15it/s]


{'train_runtime': 697.6379, 'train_samples_per_second': 289.616, 'train_steps_per_second': 18.104, 'train_loss': 0.3199532894699326, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 50.22it/s] 
The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`
  trainer = Trainer(
  4%|▍         | 505/12630 [00:16<11:33, 17.49it/s]

{'loss': 0.6817, 'grad_norm': 1.878991961479187, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1001/12630 [00:32<12:56, 14.98it/s]

{'loss': 0.6055, 'grad_norm': 10.927739143371582, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1500/12630 [00:51<17:10, 10.80it/s]

{'loss': 0.4745, 'grad_norm': 20.428590774536133, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2001/12630 [01:09<19:26,  9.11it/s]

{'loss': 0.4199, 'grad_norm': 8.297346115112305, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2502/12630 [01:50<17:10,  9.83it/s]

{'loss': 0.4034, 'grad_norm': 9.096566200256348, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3001/12630 [02:31<14:00, 11.45it/s]

{'loss': 0.3742, 'grad_norm': 14.160886764526367, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3500/12630 [03:14<13:12, 11.52it/s]

{'loss': 0.3616, 'grad_norm': 12.829171180725098, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4002/12630 [03:56<13:47, 10.42it/s]

{'loss': 0.3473, 'grad_norm': 5.714238166809082, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4507/12630 [04:25<02:57, 45.68it/s]

{'loss': 0.314, 'grad_norm': 15.302688598632812, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5007/12630 [04:36<02:50, 44.73it/s]

{'loss': 0.2892, 'grad_norm': 11.769537925720215, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5509/12630 [04:52<02:31, 47.11it/s]

{'loss': 0.279, 'grad_norm': 23.674449920654297, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6008/12630 [05:05<01:57, 56.15it/s]

{'loss': 0.276, 'grad_norm': 8.658174514770508, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6508/12630 [05:27<02:18, 44.22it/s]

{'loss': 0.2853, 'grad_norm': 12.60753345489502, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7007/12630 [05:37<01:53, 49.47it/s]

{'loss': 0.275, 'grad_norm': 10.737445831298828, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7506/12630 [05:47<01:55, 44.26it/s]

{'loss': 0.2669, 'grad_norm': 16.640262603759766, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8003/12630 [05:59<01:51, 41.43it/s]

{'loss': 0.2731, 'grad_norm': 5.816969394683838, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8506/12630 [06:15<01:23, 49.29it/s]

{'loss': 0.262, 'grad_norm': 8.526118278503418, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9002/12630 [06:29<03:45, 16.08it/s]

{'loss': 0.2331, 'grad_norm': 10.306279182434082, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9501/12630 [07:00<02:27, 21.25it/s]

{'loss': 0.239, 'grad_norm': 5.451944828033447, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10007/12630 [07:15<00:50, 52.20it/s]

{'loss': 0.2384, 'grad_norm': 6.707048416137695, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10507/12630 [07:26<00:40, 52.67it/s]

{'loss': 0.2451, 'grad_norm': 4.500412940979004, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11005/12630 [07:37<00:39, 41.29it/s]

{'loss': 0.2487, 'grad_norm': 7.228950023651123, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11503/12630 [07:47<00:24, 46.75it/s]

{'loss': 0.2508, 'grad_norm': 6.4427032470703125, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12008/12630 [07:57<00:13, 46.31it/s]

{'loss': 0.2513, 'grad_norm': 10.380163192749023, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12502/12630 [08:11<00:04, 26.70it/s]

{'loss': 0.2358, 'grad_norm': 13.990981101989746, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [08:15<00:00, 25.50it/s]


{'train_runtime': 495.3539, 'train_samples_per_second': 407.884, 'train_steps_per_second': 25.497, 'train_loss': 0.32431392866000236, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 49.91it/s] 
  4%|▍         | 506/12630 [00:12<04:46, 42.27it/s]

{'loss': 0.6826, 'grad_norm': 1.7072207927703857, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1007/12630 [00:25<04:37, 41.83it/s]

{'loss': 0.5943, 'grad_norm': 14.093764305114746, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1504/12630 [00:38<04:28, 41.47it/s]

{'loss': 0.4691, 'grad_norm': 11.32673168182373, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2007/12630 [00:54<03:33, 49.87it/s]

{'loss': 0.4152, 'grad_norm': 11.149014472961426, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2503/12630 [01:06<06:11, 27.26it/s]

{'loss': 0.3947, 'grad_norm': 7.254334926605225, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3006/12630 [01:22<04:11, 38.26it/s]

{'loss': 0.3675, 'grad_norm': 13.399767875671387, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3505/12630 [01:42<05:02, 30.17it/s]

{'loss': 0.3628, 'grad_norm': 14.66536808013916, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4004/12630 [01:59<03:24, 42.27it/s]

{'loss': 0.3466, 'grad_norm': 14.193891525268555, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4505/12630 [02:19<03:31, 38.41it/s]

{'loss': 0.3109, 'grad_norm': 8.119940757751465, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5003/12630 [02:36<05:06, 24.90it/s]

{'loss': 0.2935, 'grad_norm': 10.70576286315918, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5505/12630 [02:47<02:55, 40.57it/s]

{'loss': 0.277, 'grad_norm': 7.641109943389893, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6005/12630 [03:00<02:40, 41.27it/s]

{'loss': 0.2788, 'grad_norm': 16.684232711791992, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6509/12630 [03:20<02:00, 50.88it/s]

{'loss': 0.284, 'grad_norm': 14.771650314331055, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7003/12630 [03:31<02:12, 42.61it/s]

{'loss': 0.2765, 'grad_norm': 9.000585556030273, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7501/12630 [03:44<06:34, 12.99it/s]

{'loss': 0.2617, 'grad_norm': 14.324974060058594, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8006/12630 [04:06<01:52, 40.94it/s]

{'loss': 0.2682, 'grad_norm': 4.844738006591797, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8506/12630 [04:18<01:50, 37.28it/s]

{'loss': 0.2573, 'grad_norm': 2.322145938873291, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9002/12630 [04:32<03:17, 18.33it/s]

{'loss': 0.2328, 'grad_norm': 13.811562538146973, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9505/12630 [05:09<02:00, 25.92it/s]

{'loss': 0.2416, 'grad_norm': 9.878108024597168, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10006/12630 [05:40<01:25, 30.82it/s]

{'loss': 0.2371, 'grad_norm': 4.893286228179932, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10500/12630 [06:05<02:27, 14.48it/s]

{'loss': 0.2426, 'grad_norm': 8.10074234008789, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11001/12630 [06:32<02:33, 10.61it/s]

{'loss': 0.2442, 'grad_norm': 2.0688769817352295, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11503/12630 [07:00<00:57, 19.59it/s]

{'loss': 0.2501, 'grad_norm': 14.17969799041748, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12006/12630 [07:23<00:15, 40.02it/s]

{'loss': 0.2404, 'grad_norm': 8.53059196472168, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12505/12630 [07:40<00:03, 31.97it/s]

{'loss': 0.2271, 'grad_norm': 3.9471993446350098, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [07:43<00:00, 27.26it/s]


{'train_runtime': 463.3028, 'train_samples_per_second': 436.101, 'train_steps_per_second': 27.261, 'train_loss': 0.3214501031315336, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 47.89it/s] 
  4%|▍         | 500/12630 [00:14<13:52, 14.57it/s]

{'loss': 0.6811, 'grad_norm': 2.4172298908233643, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1006/12630 [00:28<04:34, 42.42it/s]

{'loss': 0.5825, 'grad_norm': 9.08128547668457, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1504/12630 [00:40<04:31, 40.98it/s]

{'loss': 0.4632, 'grad_norm': 14.306540489196777, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2006/12630 [00:53<04:29, 39.43it/s]

{'loss': 0.4115, 'grad_norm': 12.958028793334961, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2505/12630 [01:06<04:22, 38.58it/s]

{'loss': 0.397, 'grad_norm': 15.131685256958008, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3004/12630 [01:19<03:58, 40.33it/s]

{'loss': 0.3682, 'grad_norm': 10.99029541015625, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3505/12630 [01:32<03:41, 41.21it/s]

{'loss': 0.3586, 'grad_norm': 20.811891555786133, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4005/12630 [01:48<04:00, 35.93it/s]

{'loss': 0.3391, 'grad_norm': 21.864967346191406, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4506/12630 [02:02<03:32, 38.24it/s]

{'loss': 0.3118, 'grad_norm': 14.69084644317627, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5006/12630 [02:14<03:06, 40.93it/s]

{'loss': 0.2909, 'grad_norm': 16.957361221313477, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5505/12630 [02:27<03:23, 35.05it/s]

{'loss': 0.2834, 'grad_norm': 12.773378372192383, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6004/12630 [02:40<02:54, 38.07it/s]

{'loss': 0.2773, 'grad_norm': 7.428937911987305, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 51%|█████▏    | 6504/12630 [02:58<02:22, 43.10it/s]

{'loss': 0.2766, 'grad_norm': 23.762636184692383, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7004/12630 [03:11<02:31, 37.17it/s]

{'loss': 0.2729, 'grad_norm': 11.521417617797852, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7506/12630 [03:27<04:07, 20.74it/s]

{'loss': 0.2666, 'grad_norm': 7.850220203399658, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8009/12630 [03:40<01:33, 49.39it/s]

{'loss': 0.2715, 'grad_norm': 10.537398338317871, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8504/12630 [03:55<01:45, 38.93it/s]

{'loss': 0.2575, 'grad_norm': 3.0921638011932373, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9006/12630 [04:17<01:41, 35.80it/s]

{'loss': 0.2303, 'grad_norm': 18.285852432250977, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9503/12630 [04:32<01:33, 33.53it/s]

{'loss': 0.2366, 'grad_norm': 10.120707511901855, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10005/12630 [04:45<01:04, 40.85it/s]

{'loss': 0.2345, 'grad_norm': 6.394253730773926, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10507/12630 [05:07<00:47, 44.28it/s]

{'loss': 0.2457, 'grad_norm': 21.568603515625, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11000/12630 [05:20<00:43, 37.75it/s]

{'loss': 0.2429, 'grad_norm': 4.333149433135986, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11506/12630 [05:44<00:47, 23.71it/s]

{'loss': 0.2513, 'grad_norm': 10.252379417419434, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12004/12630 [05:57<00:16, 37.48it/s]

{'loss': 0.2372, 'grad_norm': 6.619465351104736, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12507/12630 [06:14<00:02, 41.94it/s]

{'loss': 0.2307, 'grad_norm': 15.270214080810547, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [06:17<00:00, 33.44it/s]


{'train_runtime': 377.6798, 'train_samples_per_second': 534.969, 'train_steps_per_second': 33.441, 'train_loss': 0.3200253523632542, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 48.67it/s] 
  4%|▍         | 505/12630 [00:35<13:07, 15.40it/s]

{'loss': 0.6814, 'grad_norm': 2.434263229370117, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1005/12630 [00:49<05:42, 33.94it/s]

{'loss': 0.599, 'grad_norm': 9.764373779296875, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1506/12630 [01:06<05:13, 35.50it/s]

{'loss': 0.4797, 'grad_norm': 21.54187774658203, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2007/12630 [01:19<04:26, 39.81it/s]

{'loss': 0.419, 'grad_norm': 15.189645767211914, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2508/12630 [01:42<05:15, 32.05it/s]

{'loss': 0.4015, 'grad_norm': 9.109330177307129, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3004/12630 [02:04<04:36, 34.87it/s]

{'loss': 0.3775, 'grad_norm': 11.466562271118164, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3503/12630 [02:20<04:17, 35.50it/s]

{'loss': 0.3551, 'grad_norm': 22.509498596191406, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4002/12630 [02:46<03:50, 37.35it/s]

{'loss': 0.3496, 'grad_norm': 22.520875930786133, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4505/12630 [03:03<04:46, 28.37it/s]

{'loss': 0.3191, 'grad_norm': 11.82175350189209, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5003/12630 [03:19<04:14, 29.92it/s]

{'loss': 0.2927, 'grad_norm': 11.340157508850098, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5503/12630 [03:34<03:14, 36.68it/s]

{'loss': 0.2763, 'grad_norm': 8.49588680267334, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6004/12630 [03:50<03:03, 36.04it/s]

{'loss': 0.2765, 'grad_norm': 13.330621719360352, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6505/12630 [04:10<02:29, 41.07it/s]

{'loss': 0.2846, 'grad_norm': 7.917685031890869, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7002/12630 [04:39<03:37, 25.85it/s]

{'loss': 0.2743, 'grad_norm': 10.52058219909668, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7504/12630 [05:04<03:31, 24.24it/s]

{'loss': 0.2665, 'grad_norm': 8.902249336242676, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8004/12630 [05:28<01:54, 40.28it/s]

{'loss': 0.2741, 'grad_norm': 12.250800132751465, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8505/12630 [05:43<01:47, 38.32it/s]

{'loss': 0.2654, 'grad_norm': 0.9441694021224976, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9004/12630 [06:04<02:22, 25.38it/s]

{'loss': 0.2394, 'grad_norm': 10.208800315856934, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9501/12630 [06:25<03:24, 15.31it/s]

{'loss': 0.236, 'grad_norm': 7.209898948669434, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10005/12630 [06:48<01:12, 36.41it/s]

{'loss': 0.2354, 'grad_norm': 11.30345344543457, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10503/12630 [07:03<01:00, 35.15it/s]

{'loss': 0.2472, 'grad_norm': 4.47744083404541, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11004/12630 [07:17<00:55, 29.37it/s]

{'loss': 0.2477, 'grad_norm': 8.923486709594727, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11506/12630 [07:31<00:31, 35.62it/s]

{'loss': 0.2496, 'grad_norm': 15.693288803100586, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12006/12630 [07:48<00:16, 38.57it/s]

{'loss': 0.2421, 'grad_norm': 9.700057983398438, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12504/12630 [08:11<00:03, 37.89it/s]

{'loss': 0.233, 'grad_norm': 12.430502891540527, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [08:16<00:00, 25.42it/s]


{'train_runtime': 496.9295, 'train_samples_per_second': 406.591, 'train_steps_per_second': 25.416, 'train_loss': 0.3239708733577532, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 30.30it/s]
  4%|▍         | 506/12630 [00:11<03:52, 52.26it/s]

{'loss': 0.684, 'grad_norm': 1.5542278289794922, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1008/12630 [00:22<04:34, 42.40it/s]

{'loss': 0.584, 'grad_norm': 11.456214904785156, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1504/12630 [00:33<04:13, 43.95it/s]

{'loss': 0.4648, 'grad_norm': 12.156997680664062, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2008/12630 [00:45<03:45, 47.20it/s]

{'loss': 0.4124, 'grad_norm': 6.873230934143066, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2504/12630 [00:57<04:36, 36.64it/s]

{'loss': 0.3901, 'grad_norm': 5.705912113189697, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3007/12630 [01:10<04:00, 40.03it/s]

{'loss': 0.3633, 'grad_norm': 10.166162490844727, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3504/12630 [01:22<03:52, 39.27it/s]

{'loss': 0.3487, 'grad_norm': 12.241859436035156, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4004/12630 [01:34<03:32, 40.65it/s]

{'loss': 0.3357, 'grad_norm': 5.9624247550964355, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4505/12630 [01:47<03:20, 40.50it/s]

{'loss': 0.3119, 'grad_norm': 7.897247314453125, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5004/12630 [02:00<03:16, 38.78it/s]

{'loss': 0.2837, 'grad_norm': 11.620152473449707, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5508/12630 [02:12<02:56, 40.33it/s]

{'loss': 0.2729, 'grad_norm': 11.561004638671875, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6005/12630 [02:24<02:29, 44.29it/s]

{'loss': 0.2762, 'grad_norm': 13.957167625427246, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6506/12630 [02:37<02:20, 43.46it/s]

{'loss': 0.2746, 'grad_norm': 15.257967948913574, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7005/12630 [02:50<01:56, 48.14it/s]

{'loss': 0.2765, 'grad_norm': 9.46495532989502, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7507/12630 [03:04<02:01, 42.12it/s]

{'loss': 0.266, 'grad_norm': 11.728475570678711, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8007/12630 [03:16<01:45, 43.71it/s]

{'loss': 0.2706, 'grad_norm': 8.237868309020996, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8504/12630 [03:32<01:30, 45.46it/s]

{'loss': 0.2626, 'grad_norm': 1.3016644716262817, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9010/12630 [03:43<01:04, 56.52it/s]

{'loss': 0.2339, 'grad_norm': 11.897324562072754, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9508/12630 [03:53<01:02, 49.85it/s]

{'loss': 0.2345, 'grad_norm': 14.345879554748535, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10007/12630 [04:04<00:57, 45.91it/s]

{'loss': 0.2389, 'grad_norm': 7.652599811553955, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10504/12630 [04:16<00:47, 44.74it/s]

{'loss': 0.2453, 'grad_norm': 4.913441181182861, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11008/12630 [04:27<00:36, 45.04it/s]

{'loss': 0.243, 'grad_norm': 8.764333724975586, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11506/12630 [04:39<00:25, 44.68it/s]

{'loss': 0.2395, 'grad_norm': 24.49036407470703, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12005/12630 [04:52<00:16, 37.88it/s]

{'loss': 0.24, 'grad_norm': 14.434050559997559, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12507/12630 [05:06<00:02, 42.15it/s]

{'loss': 0.2298, 'grad_norm': 4.059504508972168, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [05:09<00:00, 40.77it/s]


{'train_runtime': 309.8048, 'train_samples_per_second': 652.175, 'train_steps_per_second': 40.768, 'train_loss': 0.3184571182359709, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 45.49it/s] 
  4%|▍         | 502/12630 [00:16<14:38, 13.81it/s]

{'loss': 0.6827, 'grad_norm': 1.7673999071121216, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1002/12630 [00:29<05:06, 37.96it/s]

{'loss': 0.5898, 'grad_norm': 6.131294250488281, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1503/12630 [00:43<07:23, 25.08it/s]

{'loss': 0.4568, 'grad_norm': 9.647164344787598, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2002/12630 [01:00<08:36, 20.57it/s]

{'loss': 0.4163, 'grad_norm': 13.067976951599121, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2507/12630 [01:16<03:34, 47.16it/s]

{'loss': 0.3948, 'grad_norm': 6.895547389984131, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3008/12630 [01:28<03:39, 43.80it/s]

{'loss': 0.3646, 'grad_norm': 14.270785331726074, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3509/12630 [01:41<03:31, 43.16it/s]

{'loss': 0.3551, 'grad_norm': 14.741532325744629, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4005/12630 [01:54<03:29, 41.15it/s]

{'loss': 0.3513, 'grad_norm': 7.079202651977539, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4505/12630 [02:08<03:46, 35.94it/s]

{'loss': 0.3147, 'grad_norm': 9.635496139526367, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5005/12630 [02:23<03:23, 37.45it/s]

{'loss': 0.2926, 'grad_norm': 13.752192497253418, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5503/12630 [02:39<02:39, 44.72it/s]

{'loss': 0.2813, 'grad_norm': 22.174829483032227, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6007/12630 [02:53<02:19, 47.48it/s]

{'loss': 0.2758, 'grad_norm': 15.034001350402832, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6506/12630 [03:06<02:45, 37.09it/s]

{'loss': 0.2859, 'grad_norm': 11.608208656311035, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7007/12630 [03:21<01:58, 47.37it/s]

{'loss': 0.274, 'grad_norm': 7.8395915031433105, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7503/12630 [03:43<02:03, 41.42it/s]

{'loss': 0.2656, 'grad_norm': 12.255988121032715, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8003/12630 [03:56<03:35, 21.50it/s]

{'loss': 0.2715, 'grad_norm': 8.686813354492188, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8505/12630 [04:10<01:37, 42.40it/s]

{'loss': 0.2647, 'grad_norm': 1.4311209917068481, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9001/12630 [04:25<05:34, 10.85it/s]

{'loss': 0.2332, 'grad_norm': 7.225192546844482, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9504/12630 [04:45<01:07, 46.26it/s]

{'loss': 0.2398, 'grad_norm': 11.396669387817383, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10008/12630 [05:02<01:01, 42.94it/s]

{'loss': 0.244, 'grad_norm': 13.725428581237793, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10508/12630 [05:19<00:57, 36.94it/s]

{'loss': 0.2479, 'grad_norm': 20.594467163085938, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11005/12630 [05:35<00:43, 37.37it/s]

{'loss': 0.2404, 'grad_norm': 10.676422119140625, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11504/12630 [05:49<00:26, 42.37it/s]

{'loss': 0.2462, 'grad_norm': 14.399084091186523, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12004/12630 [06:02<00:18, 33.70it/s]

{'loss': 0.2504, 'grad_norm': 9.915960311889648, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12502/12630 [06:24<00:07, 17.44it/s]

{'loss': 0.2331, 'grad_norm': 9.366720199584961, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [06:28<00:00, 32.52it/s]


{'train_runtime': 388.364, 'train_samples_per_second': 520.252, 'train_steps_per_second': 32.521, 'train_loss': 0.32201059240249064, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 39.71it/s]
  4%|▍         | 506/12630 [00:13<06:18, 32.05it/s]

{'loss': 0.6818, 'grad_norm': 2.085576295852661, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1005/12630 [00:27<05:18, 36.46it/s]

{'loss': 0.5886, 'grad_norm': 9.966522216796875, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1504/12630 [00:40<04:59, 37.13it/s]

{'loss': 0.4717, 'grad_norm': 22.627866744995117, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2005/12630 [00:53<04:18, 41.06it/s]

{'loss': 0.4029, 'grad_norm': 14.033493041992188, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2508/12630 [01:06<04:03, 41.60it/s]

{'loss': 0.3913, 'grad_norm': 14.450969696044922, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3007/12630 [01:19<03:45, 42.74it/s]

{'loss': 0.3662, 'grad_norm': 16.82050132751465, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3503/12630 [01:32<03:35, 42.30it/s]

{'loss': 0.3628, 'grad_norm': 8.692012786865234, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4004/12630 [01:45<04:35, 31.30it/s]

{'loss': 0.3424, 'grad_norm': 6.085916042327881, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4507/12630 [01:58<03:26, 39.36it/s]

{'loss': 0.3148, 'grad_norm': 16.392772674560547, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5006/12630 [02:10<03:55, 32.38it/s]

{'loss': 0.2961, 'grad_norm': 8.319581031799316, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5504/12630 [02:24<03:19, 35.80it/s]

{'loss': 0.2756, 'grad_norm': 7.8058180809021, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6007/12630 [02:37<02:41, 40.95it/s]

{'loss': 0.268, 'grad_norm': 7.860093116760254, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6507/12630 [02:51<02:52, 35.52it/s]

{'loss': 0.2809, 'grad_norm': 16.929147720336914, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7000/12630 [03:10<02:41, 34.95it/s]

{'loss': 0.2725, 'grad_norm': 14.43064022064209, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7504/12630 [03:25<02:29, 34.38it/s]

{'loss': 0.2645, 'grad_norm': 9.956308364868164, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8005/12630 [03:44<01:53, 40.85it/s]

{'loss': 0.2728, 'grad_norm': 9.56751823425293, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8505/12630 [04:01<02:10, 31.65it/s]

{'loss': 0.2612, 'grad_norm': 2.858032703399658, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9003/12630 [04:17<01:58, 30.50it/s]

{'loss': 0.235, 'grad_norm': 5.597317218780518, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9506/12630 [04:37<01:09, 44.78it/s]

{'loss': 0.2308, 'grad_norm': 5.2912139892578125, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10007/12630 [04:56<01:08, 38.04it/s]

{'loss': 0.2374, 'grad_norm': 14.218727111816406, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10503/12630 [05:10<00:51, 41.58it/s]

{'loss': 0.2441, 'grad_norm': 15.219283103942871, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11003/12630 [05:23<00:55, 29.50it/s]

{'loss': 0.2452, 'grad_norm': 5.710454940795898, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11504/12630 [05:39<00:27, 40.23it/s]

{'loss': 0.2494, 'grad_norm': 21.129940032958984, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12005/12630 [05:53<00:17, 36.13it/s]

{'loss': 0.2437, 'grad_norm': 8.109538078308105, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12508/12630 [06:07<00:02, 44.99it/s]

{'loss': 0.2237, 'grad_norm': 4.222653388977051, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [06:11<00:00, 33.96it/s]


{'train_runtime': 371.9203, 'train_samples_per_second': 543.254, 'train_steps_per_second': 33.959, 'train_loss': 0.3201258950071116, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 47.09it/s] 
  4%|▍         | 503/12630 [00:18<04:53, 41.38it/s]

{'loss': 0.6828, 'grad_norm': 2.1971943378448486, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1000/12630 [00:37<09:42, 19.98it/s]

{'loss': 0.5836, 'grad_norm': 11.54851245880127, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1505/12630 [00:58<04:10, 44.32it/s]

{'loss': 0.4569, 'grad_norm': 12.772872924804688, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2001/12630 [01:25<12:45, 13.88it/s]

{'loss': 0.3979, 'grad_norm': 12.26815128326416, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2505/12630 [01:42<03:56, 42.75it/s]

{'loss': 0.388, 'grad_norm': 23.505586624145508, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3002/12630 [02:01<17:23,  9.23it/s]

{'loss': 0.3625, 'grad_norm': 10.984942436218262, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3504/12630 [02:24<09:35, 15.86it/s]

{'loss': 0.3516, 'grad_norm': 18.717334747314453, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4006/12630 [02:41<04:08, 34.71it/s]

{'loss': 0.3354, 'grad_norm': 6.682971477508545, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4503/12630 [02:53<03:54, 34.59it/s]

{'loss': 0.3113, 'grad_norm': 11.738722801208496, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5000/12630 [03:15<05:55, 21.48it/s]

{'loss': 0.2884, 'grad_norm': 11.97757625579834, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5503/12630 [03:37<03:25, 34.73it/s]

{'loss': 0.275, 'grad_norm': 14.942835807800293, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6003/12630 [03:59<05:46, 19.13it/s]

{'loss': 0.2715, 'grad_norm': 11.957066535949707, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6506/12630 [04:21<03:00, 33.92it/s]

{'loss': 0.2824, 'grad_norm': 17.704750061035156, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7008/12630 [04:36<02:11, 42.86it/s]

{'loss': 0.2757, 'grad_norm': 11.050206184387207, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7504/12630 [04:50<02:01, 42.06it/s]

{'loss': 0.2668, 'grad_norm': 9.800654411315918, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8005/12630 [05:03<01:56, 39.87it/s]

{'loss': 0.2762, 'grad_norm': 9.526663780212402, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8507/12630 [05:16<01:40, 41.19it/s]

{'loss': 0.2623, 'grad_norm': 5.172091960906982, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9002/12630 [05:28<01:43, 34.90it/s]

{'loss': 0.2348, 'grad_norm': 12.481339454650879, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9507/12630 [05:41<01:15, 41.55it/s]

{'loss': 0.2441, 'grad_norm': 12.42886734008789, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10004/12630 [06:04<02:00, 21.86it/s]

{'loss': 0.2346, 'grad_norm': 11.446856498718262, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10505/12630 [06:24<01:02, 33.92it/s]

{'loss': 0.2431, 'grad_norm': 3.141536235809326, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11009/12630 [06:37<00:32, 49.61it/s]

{'loss': 0.2454, 'grad_norm': 6.338654518127441, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11505/12630 [06:57<00:29, 38.38it/s]

{'loss': 0.2489, 'grad_norm': 12.227042198181152, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12003/12630 [07:18<00:13, 45.64it/s]

{'loss': 0.2411, 'grad_norm': 11.464835166931152, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12503/12630 [07:31<00:03, 37.95it/s]

{'loss': 0.2299, 'grad_norm': 1.7952773571014404, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [07:34<00:00, 27.80it/s]


{'train_runtime': 454.4039, 'train_samples_per_second': 444.642, 'train_steps_per_second': 27.795, 'train_loss': 0.3187029821195021, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 43.20it/s] 
  4%|▍         | 506/12630 [00:15<05:18, 38.07it/s]

{'loss': 0.6828, 'grad_norm': 1.9546953439712524, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1006/12630 [00:29<05:05, 38.04it/s]

{'loss': 0.5978, 'grad_norm': 7.856589317321777, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1500/12630 [00:51<09:36, 19.31it/s]

{'loss': 0.4674, 'grad_norm': 16.061309814453125, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2000/12630 [01:09<09:06, 19.44it/s]

{'loss': 0.4091, 'grad_norm': 9.434154510498047, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2503/12630 [01:28<08:10, 20.64it/s]

{'loss': 0.3897, 'grad_norm': 12.929949760437012, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3006/12630 [01:48<04:39, 34.40it/s]

{'loss': 0.3605, 'grad_norm': 12.345394134521484, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3507/12630 [02:01<03:21, 45.20it/s]

{'loss': 0.3508, 'grad_norm': 7.5666069984436035, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4007/12630 [02:16<03:58, 36.12it/s]

{'loss': 0.3424, 'grad_norm': 11.00944995880127, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4502/12630 [02:33<04:14, 31.99it/s]

{'loss': 0.3116, 'grad_norm': 14.028851509094238, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5004/12630 [02:48<03:40, 34.59it/s]

{'loss': 0.2888, 'grad_norm': 8.682361602783203, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5506/12630 [03:02<03:41, 32.14it/s]

{'loss': 0.273, 'grad_norm': 8.61757755279541, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6005/12630 [03:14<02:19, 47.52it/s]

{'loss': 0.2694, 'grad_norm': 10.42542839050293, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6507/12630 [03:27<02:29, 40.84it/s]

{'loss': 0.2827, 'grad_norm': 19.87401580810547, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7006/12630 [03:42<02:56, 31.83it/s]

{'loss': 0.269, 'grad_norm': 8.66531753540039, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7500/12630 [03:57<04:54, 17.41it/s]

{'loss': 0.2595, 'grad_norm': 10.059662818908691, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8006/12630 [04:14<01:57, 39.26it/s]

{'loss': 0.2706, 'grad_norm': 10.914349555969238, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8507/12630 [04:30<01:41, 40.49it/s]

{'loss': 0.2581, 'grad_norm': 1.2980238199234009, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9000/12630 [04:50<02:38, 22.85it/s]

{'loss': 0.2285, 'grad_norm': 1.7826229333877563, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9506/12630 [05:09<01:15, 41.27it/s]

{'loss': 0.2399, 'grad_norm': 9.02224349975586, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10005/12630 [05:30<01:56, 22.53it/s]

{'loss': 0.2371, 'grad_norm': 8.867856979370117, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10501/12630 [05:47<02:03, 17.18it/s]

{'loss': 0.2402, 'grad_norm': 6.618567943572998, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11006/12630 [06:09<00:43, 37.11it/s]

{'loss': 0.2414, 'grad_norm': 11.140395164489746, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11506/12630 [06:29<00:33, 33.17it/s]

{'loss': 0.2467, 'grad_norm': 14.155277252197266, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12006/12630 [06:45<00:17, 35.25it/s]

{'loss': 0.2441, 'grad_norm': 8.487775802612305, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12505/12630 [06:58<00:03, 38.07it/s]

{'loss': 0.2325, 'grad_norm': 10.017410278320312, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [07:01<00:00, 29.93it/s]


{'train_runtime': 421.925, 'train_samples_per_second': 478.87, 'train_steps_per_second': 29.934, 'train_loss': 0.3188454681601868, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 41.80it/s] 
  4%|▍         | 504/12630 [00:13<04:15, 47.39it/s]

{'loss': 0.682, 'grad_norm': 1.9929654598236084, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1001/12630 [00:28<16:06, 12.03it/s]

{'loss': 0.5722, 'grad_norm': 10.409335136413574, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1501/12630 [01:02<16:12, 11.45it/s]

{'loss': 0.4582, 'grad_norm': 11.6270112991333, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2008/12630 [01:22<03:59, 44.31it/s]

{'loss': 0.3961, 'grad_norm': 8.114145278930664, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2500/12630 [01:35<12:58, 13.02it/s]

{'loss': 0.3902, 'grad_norm': 10.9232177734375, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3006/12630 [01:59<05:01, 31.87it/s]

{'loss': 0.371, 'grad_norm': 14.025274276733398, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3502/12630 [02:31<13:09, 11.56it/s]

{'loss': 0.3562, 'grad_norm': 9.31747817993164, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4007/12630 [02:46<02:59, 48.09it/s]

{'loss': 0.3445, 'grad_norm': 14.842039108276367, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4506/12630 [02:56<02:57, 45.79it/s]

{'loss': 0.3123, 'grad_norm': 11.65176010131836, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5006/12630 [03:09<03:09, 40.22it/s]

{'loss': 0.2893, 'grad_norm': 6.432310581207275, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5510/12630 [03:22<01:52, 63.05it/s]

{'loss': 0.2808, 'grad_norm': 5.108368873596191, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6006/12630 [03:30<01:50, 60.21it/s]

{'loss': 0.2772, 'grad_norm': 9.796430587768555, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6511/12630 [03:38<01:44, 58.66it/s]

{'loss': 0.2747, 'grad_norm': 17.965288162231445, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7009/12630 [03:46<01:33, 60.22it/s]

{'loss': 0.2805, 'grad_norm': 9.458638191223145, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7505/12630 [03:54<01:51, 45.80it/s]

{'loss': 0.2688, 'grad_norm': 13.117256164550781, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8009/12630 [04:02<01:11, 64.51it/s]

{'loss': 0.271, 'grad_norm': 8.93063735961914, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8507/12630 [04:11<01:04, 63.83it/s]

{'loss': 0.2582, 'grad_norm': 4.3155035972595215, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9006/12630 [04:19<00:56, 64.33it/s]

{'loss': 0.2381, 'grad_norm': 5.222204685211182, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9507/12630 [04:27<00:51, 60.70it/s]

{'loss': 0.2405, 'grad_norm': 9.65869140625, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10001/12630 [04:36<01:33, 28.01it/s]

{'loss': 0.2374, 'grad_norm': 16.683135986328125, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10511/12630 [04:46<00:34, 61.70it/s]

{'loss': 0.2428, 'grad_norm': 9.413870811462402, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11006/12630 [04:55<00:27, 59.15it/s]

{'loss': 0.2427, 'grad_norm': 19.81125259399414, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11506/12630 [05:03<00:18, 61.63it/s]

{'loss': 0.2434, 'grad_norm': 11.571051597595215, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12008/12630 [05:12<00:12, 49.59it/s]

{'loss': 0.241, 'grad_norm': 21.128816604614258, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12507/12630 [05:22<00:02, 47.42it/s]

{'loss': 0.2334, 'grad_norm': 14.131233215332031, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [05:25<00:00, 38.83it/s]


{'train_runtime': 325.3226, 'train_samples_per_second': 621.067, 'train_steps_per_second': 38.823, 'train_loss': 0.3193842530722478, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 43.58it/s] 
  4%|▍         | 505/12630 [00:19<09:38, 20.97it/s]

{'loss': 0.683, 'grad_norm': 1.7202165126800537, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1004/12630 [00:49<05:01, 38.60it/s]

{'loss': 0.5818, 'grad_norm': 6.864801406860352, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1504/12630 [01:06<05:44, 32.33it/s]

{'loss': 0.456, 'grad_norm': 12.867807388305664, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2004/12630 [01:19<05:00, 35.33it/s]

{'loss': 0.4029, 'grad_norm': 20.309972763061523, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2506/12630 [01:32<03:53, 43.42it/s]

{'loss': 0.3922, 'grad_norm': 11.55708122253418, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3001/12630 [01:49<14:04, 11.41it/s]

{'loss': 0.365, 'grad_norm': 6.694890975952148, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3508/12630 [02:07<03:26, 44.20it/s]

{'loss': 0.36, 'grad_norm': 10.44481086730957, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4008/12630 [02:32<03:49, 37.61it/s]

{'loss': 0.3524, 'grad_norm': 4.332815170288086, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4503/12630 [03:09<08:14, 16.43it/s]

{'loss': 0.3144, 'grad_norm': 7.026075839996338, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5001/12630 [03:28<03:15, 39.00it/s]

{'loss': 0.2956, 'grad_norm': 9.677919387817383, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5503/12630 [03:49<03:08, 37.88it/s]

{'loss': 0.2805, 'grad_norm': 6.065370082855225, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6005/12630 [04:06<03:11, 34.59it/s]

{'loss': 0.276, 'grad_norm': 11.281490325927734, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 51%|█████▏    | 6501/12630 [04:30<06:35, 15.50it/s]

{'loss': 0.2852, 'grad_norm': 8.91300106048584, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7004/12630 [04:58<02:33, 36.59it/s]

{'loss': 0.2758, 'grad_norm': 11.807966232299805, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7507/12630 [05:08<01:31, 55.99it/s]

{'loss': 0.2691, 'grad_norm': 13.169997215270996, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8007/12630 [05:17<01:25, 54.37it/s]

{'loss': 0.275, 'grad_norm': 3.8826866149902344, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8506/12630 [05:26<01:13, 56.20it/s]

{'loss': 0.26, 'grad_norm': 4.51228141784668, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9007/12630 [05:34<01:03, 57.24it/s]

{'loss': 0.2367, 'grad_norm': 16.51564598083496, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9512/12630 [05:43<00:52, 58.91it/s]

{'loss': 0.2402, 'grad_norm': 12.721693992614746, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10009/12630 [05:53<00:52, 50.17it/s]

{'loss': 0.2406, 'grad_norm': 6.318933963775635, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10506/12630 [06:02<00:37, 57.05it/s]

{'loss': 0.245, 'grad_norm': 9.314985275268555, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11006/12630 [06:11<00:29, 55.30it/s]

{'loss': 0.2421, 'grad_norm': 9.3798828125, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11508/12630 [06:20<00:22, 50.90it/s]

{'loss': 0.2475, 'grad_norm': 15.718522071838379, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12012/12630 [06:30<00:13, 46.87it/s]

{'loss': 0.2497, 'grad_norm': 7.659643650054932, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12506/12630 [06:39<00:04, 26.70it/s]

{'loss': 0.2328, 'grad_norm': 2.733236312866211, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [06:42<00:00, 31.41it/s]


{'train_runtime': 402.1344, 'train_samples_per_second': 502.436, 'train_steps_per_second': 31.407, 'train_loss': 0.32143427183584666, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 41.21it/s] 
  4%|▍         | 505/12630 [00:18<05:34, 36.27it/s]

{'loss': 0.6834, 'grad_norm': 1.608457088470459, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1005/12630 [00:33<05:09, 37.56it/s]

{'loss': 0.5997, 'grad_norm': 7.167250633239746, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1505/12630 [00:47<04:32, 40.80it/s]

{'loss': 0.4722, 'grad_norm': 14.226116180419922, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2005/12630 [01:00<04:07, 42.97it/s]

{'loss': 0.4062, 'grad_norm': 14.005451202392578, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2506/12630 [01:11<04:05, 41.21it/s]

{'loss': 0.3962, 'grad_norm': 4.0499186515808105, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3006/12630 [01:27<03:16, 48.97it/s]

{'loss': 0.3695, 'grad_norm': 12.395832061767578, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3504/12630 [01:38<03:58, 38.20it/s]

{'loss': 0.3527, 'grad_norm': 11.522063255310059, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4003/12630 [01:51<03:15, 44.07it/s]

{'loss': 0.3429, 'grad_norm': 11.571231842041016, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4506/12630 [02:04<04:04, 33.29it/s]

{'loss': 0.3139, 'grad_norm': 8.227351188659668, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5005/12630 [02:17<03:24, 37.33it/s]

{'loss': 0.2945, 'grad_norm': 9.746760368347168, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5502/12630 [02:49<09:41, 12.26it/s]

{'loss': 0.2859, 'grad_norm': 7.436403751373291, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6000/12630 [03:02<03:53, 28.37it/s]

{'loss': 0.2703, 'grad_norm': 10.212536811828613, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 52%|█████▏    | 6507/12630 [03:33<02:16, 44.73it/s]

{'loss': 0.2859, 'grad_norm': 12.119010925292969, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7006/12630 [03:50<02:03, 45.53it/s]

{'loss': 0.2712, 'grad_norm': 5.347606182098389, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7503/12630 [04:16<01:57, 43.63it/s]

{'loss': 0.2676, 'grad_norm': 12.43415355682373, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8001/12630 [04:43<04:50, 15.91it/s]

{'loss': 0.2704, 'grad_norm': 10.044366836547852, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8502/12630 [04:57<05:15, 13.09it/s]

{'loss': 0.2571, 'grad_norm': 3.128206253051758, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9003/12630 [05:24<03:20, 18.06it/s]

{'loss': 0.2332, 'grad_norm': 7.600353717803955, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9502/12630 [05:59<03:10, 16.38it/s]

{'loss': 0.2353, 'grad_norm': 12.903080940246582, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10006/12630 [06:16<01:10, 37.36it/s]

{'loss': 0.2393, 'grad_norm': 9.413187980651855, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10505/12630 [06:29<00:48, 43.88it/s]

{'loss': 0.2394, 'grad_norm': 13.17304515838623, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11007/12630 [06:43<00:43, 37.50it/s]

{'loss': 0.2461, 'grad_norm': 3.8704020977020264, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11506/12630 [06:56<00:47, 23.55it/s]

{'loss': 0.2484, 'grad_norm': 24.19544792175293, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12006/12630 [07:13<00:14, 42.50it/s]

{'loss': 0.2451, 'grad_norm': 11.285748481750488, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12502/12630 [07:34<00:11, 10.76it/s]

{'loss': 0.2208, 'grad_norm': 3.1752610206604004, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [07:43<00:00, 27.23it/s]


{'train_runtime': 463.7179, 'train_samples_per_second': 435.711, 'train_steps_per_second': 27.236, 'train_loss': 0.3210560157576536, 'epoch': 3.0}


100%|██████████| 55/55 [00:02<00:00, 22.84it/s]
  4%|▍         | 501/12630 [00:35<13:49, 14.62it/s]

{'loss': 0.6829, 'grad_norm': 1.7970457077026367, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1001/12630 [01:12<18:13, 10.63it/s]

{'loss': 0.5734, 'grad_norm': 7.3290815353393555, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1506/12630 [01:46<04:20, 42.69it/s]

{'loss': 0.458, 'grad_norm': 16.108062744140625, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2004/12630 [02:00<04:14, 41.74it/s]

{'loss': 0.4063, 'grad_norm': 13.818406105041504, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2505/12630 [02:12<03:59, 42.21it/s]

{'loss': 0.3904, 'grad_norm': 9.23214340209961, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3008/12630 [02:25<03:57, 40.55it/s]

{'loss': 0.357, 'grad_norm': 10.363924980163574, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3501/12630 [02:59<08:19, 18.26it/s]

{'loss': 0.3539, 'grad_norm': 10.293551445007324, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4005/12630 [03:19<03:15, 44.03it/s]

{'loss': 0.3395, 'grad_norm': 6.436017990112305, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4502/12630 [03:38<08:37, 15.71it/s]

{'loss': 0.3121, 'grad_norm': 9.682419776916504, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5006/12630 [04:09<03:07, 40.72it/s]

{'loss': 0.2857, 'grad_norm': 13.683486938476562, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5505/12630 [04:23<02:54, 40.77it/s]

{'loss': 0.2794, 'grad_norm': 13.149089813232422, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6001/12630 [04:51<08:20, 13.26it/s]

{'loss': 0.2687, 'grad_norm': 9.62001895904541, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 51%|█████▏    | 6502/12630 [05:15<09:15, 11.04it/s]

{'loss': 0.2813, 'grad_norm': 11.09862995147705, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7003/12630 [05:35<02:19, 40.47it/s]

{'loss': 0.2739, 'grad_norm': 11.435145378112793, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7507/12630 [05:46<01:48, 47.08it/s]

{'loss': 0.2605, 'grad_norm': 11.659944534301758, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8005/12630 [06:05<01:48, 42.78it/s]

{'loss': 0.2731, 'grad_norm': 7.11970329284668, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8502/12630 [06:16<01:24, 48.68it/s]

{'loss': 0.2559, 'grad_norm': 0.6110323667526245, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9002/12630 [06:52<04:58, 12.15it/s]

{'loss': 0.2329, 'grad_norm': 8.200034141540527, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9502/12630 [07:28<04:33, 11.42it/s]

{'loss': 0.2387, 'grad_norm': 5.027324676513672, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10002/12630 [08:08<03:56, 11.10it/s]

{'loss': 0.24, 'grad_norm': 1.433214545249939, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10502/12630 [08:45<03:13, 10.99it/s]

{'loss': 0.241, 'grad_norm': 9.207327842712402, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11002/12630 [09:23<01:34, 17.25it/s]

{'loss': 0.2449, 'grad_norm': 6.88547945022583, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11507/12630 [09:58<00:29, 38.24it/s]

{'loss': 0.2432, 'grad_norm': 21.283525466918945, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12004/12630 [10:32<00:27, 23.10it/s]

{'loss': 0.2379, 'grad_norm': 4.448768615722656, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12505/12630 [11:09<00:04, 27.48it/s]

{'loss': 0.23, 'grad_norm': 3.5228543281555176, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [11:17<00:00, 18.64it/s]


{'train_runtime': 677.4825, 'train_samples_per_second': 298.232, 'train_steps_per_second': 18.643, 'train_loss': 0.3176317139457165, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 44.00it/s]
  4%|▍         | 502/12630 [00:25<20:00, 10.11it/s]

{'loss': 0.6801, 'grad_norm': 2.3464624881744385, 'learning_rate': 4.802058590657166e-05, 'epoch': 0.12}


  8%|▊         | 1001/12630 [01:01<07:25, 26.10it/s]

{'loss': 0.5745, 'grad_norm': 4.315123558044434, 'learning_rate': 4.604117181314331e-05, 'epoch': 0.24}


 12%|█▏        | 1501/12630 [01:43<11:24, 16.26it/s]

{'loss': 0.4557, 'grad_norm': 20.011852264404297, 'learning_rate': 4.406175771971497e-05, 'epoch': 0.36}


 16%|█▌        | 2001/12630 [02:25<09:28, 18.68it/s]

{'loss': 0.3992, 'grad_norm': 10.660030364990234, 'learning_rate': 4.208234362628662e-05, 'epoch': 0.48}


 20%|█▉        | 2500/12630 [03:09<07:07, 23.72it/s]

{'loss': 0.3919, 'grad_norm': 18.773780822753906, 'learning_rate': 4.0102929532858274e-05, 'epoch': 0.59}


 24%|██▍       | 3000/12630 [03:51<11:07, 14.42it/s]

{'loss': 0.3651, 'grad_norm': 9.234095573425293, 'learning_rate': 3.812351543942993e-05, 'epoch': 0.71}


 28%|██▊       | 3501/12630 [04:34<15:27,  9.85it/s]

{'loss': 0.3587, 'grad_norm': 5.500280857086182, 'learning_rate': 3.614410134600158e-05, 'epoch': 0.83}


 32%|███▏      | 4002/12630 [05:19<11:07, 12.93it/s]

{'loss': 0.3476, 'grad_norm': 8.510976791381836, 'learning_rate': 3.4164687252573244e-05, 'epoch': 0.95}


 36%|███▌      | 4504/12630 [05:45<03:50, 35.28it/s]

{'loss': 0.3101, 'grad_norm': 7.702953815460205, 'learning_rate': 3.21852731591449e-05, 'epoch': 1.07}


 40%|███▉      | 5007/12630 [06:07<04:40, 27.18it/s]

{'loss': 0.2902, 'grad_norm': 19.940528869628906, 'learning_rate': 3.0205859065716553e-05, 'epoch': 1.19}


 44%|████▎     | 5503/12630 [06:45<09:26, 12.59it/s]

{'loss': 0.2734, 'grad_norm': 5.804046630859375, 'learning_rate': 2.82264449722882e-05, 'epoch': 1.31}


 48%|████▊     | 6002/12630 [07:23<08:03, 13.72it/s]

{'loss': 0.2758, 'grad_norm': 10.556024551391602, 'learning_rate': 2.6247030878859858e-05, 'epoch': 1.43}


 51%|█████▏    | 6500/12630 [07:58<10:04, 10.15it/s]

{'loss': 0.2756, 'grad_norm': 14.702404022216797, 'learning_rate': 2.4267616785431512e-05, 'epoch': 1.54}


 55%|█████▌    | 7001/12630 [08:36<02:19, 40.30it/s]

{'loss': 0.2739, 'grad_norm': 11.393826484680176, 'learning_rate': 2.228820269200317e-05, 'epoch': 1.66}


 59%|█████▉    | 7501/12630 [09:19<09:20,  9.15it/s]

{'loss': 0.2596, 'grad_norm': 7.653048038482666, 'learning_rate': 2.0308788598574824e-05, 'epoch': 1.78}


 63%|██████▎   | 8004/12630 [09:45<01:42, 45.08it/s]

{'loss': 0.2718, 'grad_norm': 8.35500717163086, 'learning_rate': 1.8329374505146475e-05, 'epoch': 1.9}


 67%|██████▋   | 8502/12630 [10:02<04:32, 15.16it/s]

{'loss': 0.2643, 'grad_norm': 1.344560146331787, 'learning_rate': 1.6349960411718133e-05, 'epoch': 2.02}


 71%|███████▏  | 9005/12630 [10:18<01:32, 39.15it/s]

{'loss': 0.2383, 'grad_norm': 15.753072738647461, 'learning_rate': 1.4370546318289787e-05, 'epoch': 2.14}


 75%|███████▌  | 9501/12630 [10:48<04:49, 10.81it/s]

{'loss': 0.2392, 'grad_norm': 13.478131294250488, 'learning_rate': 1.2391132224861442e-05, 'epoch': 2.26}


 79%|███████▉  | 10008/12630 [11:08<01:05, 40.01it/s]

{'loss': 0.2454, 'grad_norm': 8.819955825805664, 'learning_rate': 1.0411718131433096e-05, 'epoch': 2.38}


 83%|████████▎ | 10503/12630 [11:20<00:51, 41.42it/s]

{'loss': 0.249, 'grad_norm': 11.823079109191895, 'learning_rate': 8.432304038004752e-06, 'epoch': 2.49}


 87%|████████▋ | 11004/12630 [11:44<00:37, 42.81it/s]

{'loss': 0.2444, 'grad_norm': 9.496817588806152, 'learning_rate': 6.4528899445764055e-06, 'epoch': 2.61}


 91%|█████████ | 11500/12630 [12:01<01:07, 16.63it/s]

{'loss': 0.2488, 'grad_norm': 14.715801239013672, 'learning_rate': 4.47347585114806e-06, 'epoch': 2.73}


 95%|█████████▌| 12007/12630 [12:31<00:15, 39.69it/s]

{'loss': 0.2476, 'grad_norm': 9.80858325958252, 'learning_rate': 2.494061757719715e-06, 'epoch': 2.85}


 99%|█████████▉| 12502/12630 [12:53<00:10, 12.12it/s]

{'loss': 0.2311, 'grad_norm': 24.997697830200195, 'learning_rate': 5.146476642913698e-07, 'epoch': 2.97}


100%|██████████| 12630/12630 [12:58<00:00, 16.23it/s]


{'train_runtime': 778.1309, 'train_samples_per_second': 259.657, 'train_steps_per_second': 16.231, 'train_loss': 0.31958351180559097, 'epoch': 3.0}


100%|██████████| 55/55 [00:01<00:00, 45.86it/s] 
 11%|█         | 502/4689 [00:45<03:56, 17.74it/s]

{'loss': 0.6883, 'grad_norm': 1.5795190334320068, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1004/4689 [01:19<02:51, 21.47it/s]

{'loss': 0.4732, 'grad_norm': 5.931919574737549, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1501/4689 [02:03<04:40, 11.35it/s]

{'loss': 0.389, 'grad_norm': 12.642401695251465, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2000/4689 [02:47<02:22, 18.86it/s]

{'loss': 0.3141, 'grad_norm': 23.479047775268555, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [03:38<04:38,  7.86it/s]

{'loss': 0.2779, 'grad_norm': 24.330202102661133, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3003/4689 [04:18<02:14, 12.51it/s]

{'loss': 0.2806, 'grad_norm': 27.7852840423584, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3503/4689 [05:06<01:07, 17.68it/s]

{'loss': 0.2425, 'grad_norm': 47.115604400634766, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4004/4689 [05:33<00:31, 21.59it/s]

{'loss': 0.2376, 'grad_norm': 6.6867499351501465, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [06:00<00:09, 19.09it/s]

{'loss': 0.2354, 'grad_norm': 3.4770352840423584, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:13<00:00, 12.56it/s]


{'train_runtime': 373.3345, 'train_samples_per_second': 200.892, 'train_steps_per_second': 12.56, 'train_loss': 0.34443309927932436, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:01<00:00, 25.54it/s]
 11%|█         | 503/4689 [00:27<02:26, 28.49it/s]

{'loss': 0.6793, 'grad_norm': 2.663402557373047, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1001/4689 [00:48<06:39,  9.23it/s]

{'loss': 0.4898, 'grad_norm': 7.5676422119140625, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1504/4689 [01:10<01:59, 26.65it/s]

{'loss': 0.3953, 'grad_norm': 14.162248611450195, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2005/4689 [01:27<01:35, 28.10it/s]

{'loss': 0.3188, 'grad_norm': 20.887218475341797, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [01:52<05:32,  6.59it/s]

{'loss': 0.2839, 'grad_norm': 18.217679977416992, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3001/4689 [02:40<04:00,  7.03it/s]

{'loss': 0.2808, 'grad_norm': 20.771020889282227, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3506/4689 [03:18<00:37, 31.46it/s]

{'loss': 0.2411, 'grad_norm': 45.658935546875, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4004/4689 [03:56<00:29, 23.26it/s]

{'loss': 0.2341, 'grad_norm': 8.102738380432129, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [04:41<00:25,  7.29it/s]

{'loss': 0.2447, 'grad_norm': 11.26942253112793, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [04:58<00:00, 15.72it/s]


{'train_runtime': 298.3334, 'train_samples_per_second': 251.397, 'train_steps_per_second': 15.717, 'train_loss': 0.3473158014890964, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:17<00:00, 20.08it/s]
Map:   0%|          | 0/25000 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (532 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 25000/25000 [03:23<00:00, 122.83 examples/s]
 11%|█         | 503/4689 [00:20<02:37, 26.63it/s]

{'loss': 0.6781, 'grad_norm': 2.566037178039551, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1005/4689 [00:40<02:17, 26.83it/s]

{'loss': 0.477, 'grad_norm': 7.015756130218506, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1504/4689 [01:00<02:20, 22.73it/s]

{'loss': 0.3886, 'grad_norm': 10.732378005981445, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2001/4689 [01:38<06:53,  6.50it/s]

{'loss': 0.3134, 'grad_norm': 29.06294059753418, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2503/4689 [02:37<01:33, 23.26it/s]

{'loss': 0.2884, 'grad_norm': 13.614151954650879, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3000/4689 [02:58<01:03, 26.46it/s]

{'loss': 0.2882, 'grad_norm': 7.381897449493408, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [03:25<02:43,  7.26it/s]

{'loss': 0.2488, 'grad_norm': 74.97118377685547, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4004/4689 [04:15<00:50, 13.50it/s]

{'loss': 0.2442, 'grad_norm': 0.5268580317497253, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [04:57<00:28,  6.69it/s]

{'loss': 0.2506, 'grad_norm': 12.697624206542969, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [05:12<00:00, 15.00it/s]


{'train_runtime': 312.6061, 'train_samples_per_second': 239.919, 'train_steps_per_second': 15.0, 'train_loss': 0.34878610654486397, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:25<00:00, 18.24it/s]
 11%|█         | 503/4689 [00:39<02:53, 24.19it/s]

{'loss': 0.6823, 'grad_norm': 2.994295358657837, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1002/4689 [01:20<02:52, 21.39it/s]

{'loss': 0.4795, 'grad_norm': 6.172872066497803, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1500/4689 [02:00<02:59, 17.79it/s]

{'loss': 0.4034, 'grad_norm': 11.212258338928223, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2003/4689 [02:40<02:08, 20.91it/s]

{'loss': 0.3268, 'grad_norm': 11.621807098388672, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2503/4689 [03:10<01:23, 26.34it/s]

{'loss': 0.2915, 'grad_norm': 8.632268905639648, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3001/4689 [03:58<03:18,  8.51it/s]

{'loss': 0.2848, 'grad_norm': 24.18243408203125, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [04:46<02:41,  7.34it/s]

{'loss': 0.2473, 'grad_norm': 51.634708404541016, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4002/4689 [05:31<01:01, 11.15it/s]

{'loss': 0.2389, 'grad_norm': 3.233166217803955, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4503/4689 [06:11<00:09, 19.28it/s]

{'loss': 0.243, 'grad_norm': 14.348790168762207, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:27<00:00, 12.09it/s]


{'train_runtime': 387.7577, 'train_samples_per_second': 193.42, 'train_steps_per_second': 12.093, 'train_loss': 0.35038552962531155, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:33<00:00, 16.79it/s]
Map:   0%|          | 0/25000 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (720 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 25000/25000 [04:56<00:00, 84.23 examples/s] 
Map: 100%|██████████| 25000/25000 [01:51<00:00, 223.67 examples/s]
 11%|█         | 503/4689 [00:21<03:01, 23.03it/s]

{'loss': 0.6804, 'grad_norm': 2.6618316173553467, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1005/4689 [00:42<02:24, 25.54it/s]

{'loss': 0.4911, 'grad_norm': 8.78788948059082, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1503/4689 [01:12<02:31, 21.00it/s]

{'loss': 0.4241, 'grad_norm': 11.335580825805664, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2001/4689 [01:50<04:49,  9.29it/s]

{'loss': 0.3399, 'grad_norm': 25.998159408569336, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2503/4689 [02:28<01:46, 20.53it/s]

{'loss': 0.3029, 'grad_norm': 25.900848388671875, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3001/4689 [03:03<02:49,  9.98it/s]

{'loss': 0.3002, 'grad_norm': 16.75875473022461, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [03:41<02:47,  7.10it/s]

{'loss': 0.2559, 'grad_norm': 43.327613830566406, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4001/4689 [04:16<01:07, 10.21it/s]

{'loss': 0.2529, 'grad_norm': 3.2257580757141113, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4500/4689 [05:01<00:12, 14.75it/s]

{'loss': 0.2458, 'grad_norm': 6.924628734588623, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [05:19<00:00, 14.69it/s]


{'train_runtime': 319.1408, 'train_samples_per_second': 235.006, 'train_steps_per_second': 14.693, 'train_loss': 0.36147251727103574, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:25<00:00, 18.37it/s]
 11%|█         | 501/4689 [00:51<07:20,  9.52it/s]

{'loss': 0.6731, 'grad_norm': 3.2304186820983887, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1003/4689 [01:35<02:31, 24.32it/s] 

{'loss': 0.5145, 'grad_norm': 4.526483535766602, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1501/4689 [02:08<05:55,  8.96it/s]

{'loss': 0.4535, 'grad_norm': 7.193591594696045, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2001/4689 [02:45<02:33, 17.56it/s]

{'loss': 0.3841, 'grad_norm': 15.056717872619629, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2503/4689 [03:17<02:39, 13.69it/s]

{'loss': 0.336, 'grad_norm': 27.891576766967773, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3003/4689 [03:46<01:18, 21.46it/s]

{'loss': 0.3306, 'grad_norm': 13.423554420471191, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3503/4689 [04:19<01:04, 18.49it/s]

{'loss': 0.2767, 'grad_norm': 48.286861419677734, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4004/4689 [04:49<00:32, 20.77it/s]

{'loss': 0.2643, 'grad_norm': 12.456216812133789, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4502/4689 [05:17<00:10, 18.29it/s]

{'loss': 0.2752, 'grad_norm': 10.071730613708496, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [05:35<00:00, 13.96it/s]


{'train_runtime': 335.8315, 'train_samples_per_second': 223.326, 'train_steps_per_second': 13.962, 'train_loss': 0.3844040008108863, 'epoch': 3.0}


100%|██████████| 1563/1563 [00:57<00:00, 27.09it/s]
 11%|█         | 503/4689 [00:17<02:16, 30.62it/s]

{'loss': 0.6823, 'grad_norm': 2.2263879776000977, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1000/4689 [00:49<06:43,  9.14it/s]

{'loss': 0.4844, 'grad_norm': 6.003767490386963, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1500/4689 [01:26<04:24, 12.04it/s]

{'loss': 0.3992, 'grad_norm': 17.222274780273438, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2000/4689 [02:01<01:48, 24.86it/s]

{'loss': 0.3134, 'grad_norm': 25.32183837890625, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [02:42<05:05,  7.17it/s]

{'loss': 0.2911, 'grad_norm': 22.27342987060547, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3000/4689 [03:24<02:39, 10.60it/s]

{'loss': 0.2892, 'grad_norm': 8.776931762695312, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3505/4689 [04:10<01:04, 18.42it/s]

{'loss': 0.241, 'grad_norm': 61.78215789794922, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4002/4689 [04:58<01:09,  9.89it/s]

{'loss': 0.2379, 'grad_norm': 5.02754545211792, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4500/4689 [05:39<00:11, 16.94it/s]

{'loss': 0.2486, 'grad_norm': 12.843093872070312, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [05:53<00:00, 13.25it/s]


{'train_runtime': 353.9068, 'train_samples_per_second': 211.92, 'train_steps_per_second': 13.249, 'train_loss': 0.34963607299960886, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:12<00:00, 21.48it/s]
 11%|█         | 505/4689 [00:47<02:49, 24.65it/s]

{'loss': 0.6849, 'grad_norm': 1.685364007949829, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1003/4689 [01:30<04:08, 14.84it/s]

{'loss': 0.4775, 'grad_norm': 4.333187103271484, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1500/4689 [02:09<03:14, 16.40it/s]

{'loss': 0.3884, 'grad_norm': 11.353154182434082, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2003/4689 [02:43<01:45, 25.48it/s]

{'loss': 0.315, 'grad_norm': 30.046735763549805, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [03:18<03:39,  9.97it/s]

{'loss': 0.2828, 'grad_norm': 20.55183982849121, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3002/4689 [04:01<02:52,  9.79it/s]

{'loss': 0.2848, 'grad_norm': 6.836095809936523, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [04:42<01:53, 10.47it/s]

{'loss': 0.2379, 'grad_norm': 53.8409309387207, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4001/4689 [05:25<01:27,  7.87it/s]

{'loss': 0.2416, 'grad_norm': 0.5431622862815857, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [06:01<00:09, 20.05it/s]

{'loss': 0.2446, 'grad_norm': 4.462981224060059, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:16<00:00, 12.45it/s]


{'train_runtime': 376.5177, 'train_samples_per_second': 199.194, 'train_steps_per_second': 12.454, 'train_loss': 0.34617914373520264, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:45<00:00, 14.88it/s]
 11%|█         | 504/4689 [00:40<03:12, 21.71it/s]

{'loss': 0.6804, 'grad_norm': 2.4530439376831055, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1000/4689 [01:27<03:42, 16.58it/s]

{'loss': 0.4794, 'grad_norm': 6.712974548339844, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1505/4689 [02:11<01:57, 27.13it/s]

{'loss': 0.3931, 'grad_norm': 10.4100980758667, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2003/4689 [02:48<02:45, 16.25it/s]

{'loss': 0.3174, 'grad_norm': 31.20381736755371, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2502/4689 [03:30<02:38, 13.76it/s]

{'loss': 0.2839, 'grad_norm': 16.75237274169922, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3005/4689 [04:14<01:15, 22.28it/s]

{'loss': 0.2863, 'grad_norm': 18.510601043701172, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [04:55<02:15,  8.74it/s]

{'loss': 0.2416, 'grad_norm': 48.8510856628418, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4001/4689 [05:41<01:08,  9.98it/s]

{'loss': 0.2329, 'grad_norm': 0.9729458093643188, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [06:21<00:15, 12.28it/s]

{'loss': 0.2456, 'grad_norm': 8.706291198730469, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:33<00:00, 11.90it/s]


{'train_runtime': 393.9322, 'train_samples_per_second': 190.388, 'train_steps_per_second': 11.903, 'train_loss': 0.3469288948630188, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:19<00:00, 19.58it/s]
Map:   0%|          | 0/25000 [00:00<?, ? examples/s]Token indices sequence length is longer than the specified maximum sequence length for this model (532 > 512). Running this sequence through the model will result in indexing errors
Map: 100%|██████████| 25000/25000 [02:38<00:00, 157.95 examples/s]
 11%|█         | 504/4689 [00:22<02:56, 23.71it/s]

{'loss': 0.6895, 'grad_norm': 1.7727229595184326, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1002/4689 [00:46<02:21, 26.03it/s]

{'loss': 0.502, 'grad_norm': 5.15670919418335, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1502/4689 [01:18<09:29,  5.59it/s]

{'loss': 0.3932, 'grad_norm': 17.386240005493164, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2002/4689 [01:45<02:12, 20.25it/s]

{'loss': 0.3121, 'grad_norm': 21.574377059936523, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [02:27<05:32,  6.57it/s]

{'loss': 0.281, 'grad_norm': 28.624109268188477, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3003/4689 [03:05<01:23, 20.20it/s]

{'loss': 0.2823, 'grad_norm': 14.247049331665039, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3503/4689 [03:39<01:06, 17.92it/s]

{'loss': 0.2391, 'grad_norm': 58.882720947265625, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4002/4689 [04:09<00:28, 24.06it/s]

{'loss': 0.2323, 'grad_norm': 1.12212073802948, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [04:53<00:27,  6.86it/s]

{'loss': 0.2464, 'grad_norm': 23.494747161865234, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [05:12<00:00, 15.01it/s]


{'train_runtime': 312.4807, 'train_samples_per_second': 240.015, 'train_steps_per_second': 15.006, 'train_loss': 0.34898609707263467, 'epoch': 3.0}


100%|██████████| 1563/1563 [02:07<00:00, 12.27it/s]
 11%|█         | 501/4689 [00:56<09:36,  7.26it/s]

{'loss': 0.6854, 'grad_norm': 2.292799949645996, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1001/4689 [02:02<06:51,  8.97it/s] 

{'loss': 0.5139, 'grad_norm': 4.141778469085693, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1501/4689 [02:59<07:54,  6.72it/s]

{'loss': 0.3986, 'grad_norm': 12.730401992797852, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2001/4689 [03:58<05:05,  8.80it/s]

{'loss': 0.3173, 'grad_norm': 31.381786346435547, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [04:52<04:25,  8.24it/s]

{'loss': 0.2782, 'grad_norm': 32.185546875, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3000/4689 [05:47<02:51,  9.86it/s]

{'loss': 0.2881, 'grad_norm': 16.915699005126953, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3502/4689 [06:42<02:00,  9.85it/s]

{'loss': 0.2475, 'grad_norm': 60.469627380371094, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4002/4689 [07:34<00:59, 11.47it/s]

{'loss': 0.2379, 'grad_norm': 13.436944961547852, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [08:23<00:23,  7.94it/s]

{'loss': 0.2479, 'grad_norm': 8.436692237854004, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [08:43<00:00,  8.95it/s]


{'train_runtime': 523.8598, 'train_samples_per_second': 143.168, 'train_steps_per_second': 8.951, 'train_loss': 0.35229111694023535, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:31<00:00, 17.16it/s]
 11%|█         | 500/4689 [00:48<04:54, 14.23it/s]

{'loss': 0.6862, 'grad_norm': 1.3980872631072998, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1001/4689 [01:22<05:33, 11.05it/s]

{'loss': 0.5239, 'grad_norm': 3.8759493827819824, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1506/4689 [02:02<01:48, 29.26it/s]

{'loss': 0.44, 'grad_norm': 8.883119583129883, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2000/4689 [02:38<04:13, 10.61it/s]

{'loss': 0.3437, 'grad_norm': 16.732534408569336, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2500/4689 [03:31<04:37,  7.88it/s]

{'loss': 0.3078, 'grad_norm': 6.999098777770996, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3005/4689 [04:19<01:35, 17.68it/s]

{'loss': 0.2994, 'grad_norm': 17.98838233947754, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3500/4689 [04:59<01:10, 16.91it/s]

{'loss': 0.2563, 'grad_norm': 53.45664978027344, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4004/4689 [05:39<00:43, 15.84it/s]

{'loss': 0.251, 'grad_norm': 1.6906161308288574, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4504/4689 [06:23<00:12, 15.34it/s]

{'loss': 0.2563, 'grad_norm': 1.5543040037155151, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:39<00:00, 11.75it/s]


{'train_runtime': 399.0685, 'train_samples_per_second': 187.938, 'train_steps_per_second': 11.75, 'train_loss': 0.36888867762836725, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:17<00:00, 20.07it/s]
 11%|█         | 500/4689 [00:43<03:42, 18.81it/s]

{'loss': 0.6818, 'grad_norm': 2.397397041320801, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1002/4689 [01:24<02:54, 21.12it/s]

{'loss': 0.5053, 'grad_norm': 5.6941304206848145, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1502/4689 [02:02<02:25, 21.96it/s]

{'loss': 0.3983, 'grad_norm': 10.952552795410156, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2002/4689 [02:42<04:06, 10.92it/s]

{'loss': 0.3248, 'grad_norm': 28.193437576293945, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [03:30<04:44,  7.69it/s]

{'loss': 0.2896, 'grad_norm': 6.946237087249756, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3000/4689 [04:16<02:38, 10.67it/s]

{'loss': 0.2883, 'grad_norm': 8.934955596923828, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [05:01<02:17,  8.63it/s]

{'loss': 0.2538, 'grad_norm': 35.99958801269531, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4005/4689 [05:43<00:26, 25.62it/s]

{'loss': 0.2407, 'grad_norm': 0.4945574104785919, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4502/4689 [06:17<00:12, 14.49it/s]

{'loss': 0.2516, 'grad_norm': 7.32130765914917, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:34<00:00, 11.89it/s]


{'train_runtime': 394.3453, 'train_samples_per_second': 190.189, 'train_steps_per_second': 11.891, 'train_loss': 0.3548241336860746, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:40<00:00, 15.48it/s]
 11%|█         | 502/4689 [00:43<04:24, 15.85it/s]

{'loss': 0.6863, 'grad_norm': 1.8990086317062378, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1001/4689 [01:18<06:53,  8.92it/s]

{'loss': 0.5078, 'grad_norm': 6.81248140335083, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1505/4689 [01:50<01:53, 28.02it/s]

{'loss': 0.4051, 'grad_norm': 5.922446250915527, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2004/4689 [02:25<02:58, 15.02it/s]

{'loss': 0.3335, 'grad_norm': 10.657487869262695, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2501/4689 [03:13<03:17, 11.10it/s]

{'loss': 0.2914, 'grad_norm': 16.379518508911133, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3001/4689 [04:00<03:14,  8.66it/s]

{'loss': 0.29, 'grad_norm': 8.019688606262207, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3500/4689 [04:39<00:55, 21.50it/s]

{'loss': 0.2499, 'grad_norm': 26.030622482299805, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4001/4689 [05:24<01:34,  7.30it/s]

{'loss': 0.2412, 'grad_norm': 7.126326084136963, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4501/4689 [05:56<00:19,  9.54it/s]

{'loss': 0.2478, 'grad_norm': 17.170303344726562, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:16<00:00, 12.44it/s]


{'train_runtime': 376.9944, 'train_samples_per_second': 198.942, 'train_steps_per_second': 12.438, 'train_loss': 0.3566835321681437, 'epoch': 3.0}


100%|██████████| 1563/1563 [01:15<00:00, 20.81it/s]
 11%|█         | 504/4689 [00:51<04:59, 13.96it/s]

{'loss': 0.6759, 'grad_norm': 2.862013816833496, 'learning_rate': 4.466837278737471e-05, 'epoch': 0.32}


 21%|██▏       | 1002/4689 [01:35<02:39, 23.11it/s]

{'loss': 0.5105, 'grad_norm': 6.085923671722412, 'learning_rate': 3.933674557474942e-05, 'epoch': 0.64}


 32%|███▏      | 1500/4689 [02:13<04:32, 11.72it/s]

{'loss': 0.4666, 'grad_norm': 8.154423713684082, 'learning_rate': 3.400511836212412e-05, 'epoch': 0.96}


 43%|████▎     | 2001/4689 [02:54<04:26, 10.08it/s]

{'loss': 0.3579, 'grad_norm': 22.074750900268555, 'learning_rate': 2.8673491149498826e-05, 'epoch': 1.28}


 53%|█████▎    | 2502/4689 [03:40<03:32, 10.29it/s]

{'loss': 0.3021, 'grad_norm': 26.79508399963379, 'learning_rate': 2.3341863936873534e-05, 'epoch': 1.6}


 64%|██████▍   | 3001/4689 [04:25<04:03,  6.94it/s]

{'loss': 0.3016, 'grad_norm': 17.336950302124023, 'learning_rate': 1.801023672424824e-05, 'epoch': 1.92}


 75%|███████▍  | 3501/4689 [05:05<02:06,  9.40it/s]

{'loss': 0.2529, 'grad_norm': 53.79215621948242, 'learning_rate': 1.2678609511622949e-05, 'epoch': 2.24}


 85%|████████▌ | 4001/4689 [05:45<01:13,  9.40it/s]

{'loss': 0.244, 'grad_norm': 19.0843563079834, 'learning_rate': 7.346982298997654e-06, 'epoch': 2.56}


 96%|█████████▌| 4503/4689 [06:26<00:10, 18.52it/s]

{'loss': 0.2586, 'grad_norm': 5.479204177856445, 'learning_rate': 2.015355086372361e-06, 'epoch': 2.88}


100%|██████████| 4689/4689 [06:39<00:00, 11.75it/s]


{'train_runtime': 399.0347, 'train_samples_per_second': 187.954, 'train_steps_per_second': 11.751, 'train_loss': 0.36943519951870607, 'epoch': 3.0}


100%|██████████| 1563/1563 [02:03<00:00, 12.62it/s]


# QA TASK

## Prohibiting the token from attending to itself

In [9]:
class CustomAttention(nn.Module):
    def __init__(self, config, num_context_tokens):
        super().__init__()
        # The following equality must be possible num_attention_heads * attention_head_size = config.hidden_size.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})")
        
        # assign attention_head size such that num_attention_heads * attention_head_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)    #    
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Initialize projections matrices and dropout layer
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)        

        # Central Attention complexity is O(num_context_tokens * sequence_length * hidden_size)
        self.num_context_tokens = num_context_tokens
    
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (batch_size, num_tokens, num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3) # (batch_size, num_attention_heads, num_tokens, attention_head_size)

    # Adapted from BertSelfAttention
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        
        # Project input into query, key, value matrices respectively
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Split q,k, and v for the central attention.
        # Tokens ordering: [CLS] + [CONTEXT_1] + [CONTEXT_2] ... [CONTEXT_N] + ALL SENTENCE TOKENS + [SEP] + [PAD] + [PAD] + [PAD] ....
        # first token is [CLS], then the context tokens, then the sentence, then the [SEP], and finally the padding.
        c = self.num_context_tokens + 1  # [CLS] + context tokens are treated as the central context tokens.
        query_layer_context     = query_layer[:,:,:c,:]
        query_layer_sentence    = query_layer[:,:,c:,:]
        key_layer_context       = key_layer[:,:,:c,:]
        key_layer_sentence      = key_layer[:,:,c:,:]
        value_layer_context     = value_layer[:,:,:c,:]
        value_layer_sentence    = value_layer[:,:,c:,:]

        if attention_mask is not None:
            attention_mask_context  = attention_mask[:,:,:c,c:]
            attention_mask_sentence = attention_mask[:,:,c:,:c]  # should always be zeros, since the central context tokens will never contain padding
          
        if self.num_context_tokens == 0: # Original Attention
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            # Normalize the attention scores to probabilities.
            attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)
            attn_output = torch.matmul(attention_probs, value_layer)

        else: # Central Attention 

            attention_scores_context_to_sentence = torch.matmul(query_layer_context, key_layer_sentence.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
            attention_scores_sentence_to_context = torch.matmul(query_layer_sentence, key_layer_context.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

            if attention_mask is not None:
                attention_scores_context_to_sentence = attention_scores_context_to_sentence + attention_mask_context
                attention_scores_sentence_to_context = attention_scores_sentence_to_context + attention_mask_sentence

            attention_probs_context_to_sentence = torch.nn.functional.softmax(attention_scores_context_to_sentence, dim=-1)
            attention_probs_sentence_to_context = torch.nn.functional.softmax(attention_scores_sentence_to_context, dim=-1)

            attention_probs_context_to_sentence = self.dropout(attention_probs_context_to_sentence)
            attention_probs_sentence_to_context = self.dropout(attention_probs_sentence_to_context)

            attn_output_context = torch.matmul(attention_probs_context_to_sentence, value_layer_sentence)
            attn_output_sentence= torch.matmul(attention_probs_sentence_to_context, value_layer_context)

            attn_output = torch.cat((attn_output_context, attn_output_sentence), dim = 2)

        # reformatting attention output to (batch_size, num_tokens, hidden_size), which is the contextual embedding for each token.
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        new_attn_output_shape = attn_output.size()[:-2] + (self.all_head_size,)
        attn_output = attn_output.view(new_attn_output_shape)

        outputs = (attn_output,)

        return outputs

## Dataset Preprocessing

In [10]:
def preprocess_dataset(num_context_tokens, max_length=384, stride=128):

    def preprocess_training_examples(examples):

        questions = [q.strip() for q in examples["question"]]

        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        for i in range(len(inputs["input_ids"])):
            inputs["input_ids"][i] = inputs["input_ids"][i][:1] + context_tokens_ids + inputs["input_ids"][i][1:]  # [CLS] + context_tokens + sentence
            inputs["token_type_ids"][i] = ([0] * num_context_tokens) + inputs["token_type_ids"][i]
            inputs["attention_mask"][i] = ([1] * num_context_tokens) + inputs["attention_mask"][i]
            inputs["offset_mapping"][i] = inputs["offset_mapping"][i][:1] + ([(0,0)] * num_context_tokens) + inputs["offset_mapping"][i][1:]
            
        offset_mapping = inputs.pop("offset_mapping")
        sample_map = inputs.pop("overflow_to_sample_mapping")
        answers = examples["answers"]
        start_positions = []
        end_positions = []

        for i, offset in enumerate(offset_mapping):
            sample_idx = sample_map[i]
            answer = answers[sample_idx]
            start_char = answer["answer_start"][0]
            end_char = answer["answer_start"][0] + len(answer["text"][0])
            sequence_ids = [None] * num_context_tokens + inputs.sequence_ids(i)

            # Find the start and end of the context
            idx = 0
            while sequence_ids[idx] != 1:
                idx += 1
            context_start = idx
            while sequence_ids[idx] == 1:
                idx += 1
            context_end = idx - 1

            # If the answer is not fully inside the context, label is (0, 0)
            if offset[context_start][0] > start_char or offset[context_end][1] < end_char:
                start_positions.append(0)
                end_positions.append(0)
            else:
                # Otherwise it's the start and end token positions
                idx = context_start
                while idx <= context_end and offset[idx][0] <= start_char:
                    idx += 1
                start_positions.append(idx - 1)

                idx = context_end
                while idx >= context_start and offset[idx][1] >= end_char:
                    idx -= 1
                end_positions.append(idx + 1)

        inputs["start_positions"] = start_positions
        inputs["end_positions"] = end_positions
        return inputs


    def preprocess_validation_examples(examples):
        questions = [q.strip() for q in examples["question"]]
        inputs = tokenizer(
            questions,
            examples["context"],
            max_length=max_length,
            truncation="only_second",
            stride=stride,
            return_overflowing_tokens=True,
            return_offsets_mapping=True,
            padding="max_length",
        )

        for i in range(len(inputs["input_ids"])):
            inputs["input_ids"][i] = inputs["input_ids"][i][:1] + context_tokens_ids + inputs["input_ids"][i][1:]  # [CLS] + context_tokens + sentence
            inputs["token_type_ids"][i] = ([0] * num_context_tokens) + inputs["token_type_ids"][i]
            inputs["attention_mask"][i] = ([1] * num_context_tokens) + inputs["attention_mask"][i]
            inputs["offset_mapping"][i] = inputs["offset_mapping"][i][:1] + ([(0,0)] * num_context_tokens) + inputs["offset_mapping"][i][1:]

        sample_map = inputs.pop("overflow_to_sample_mapping")
        example_ids = []

        for i in range(len(inputs["input_ids"])):
            sample_idx = sample_map[i]
            example_ids.append(examples["id"][sample_idx])

            sequence_ids = [None] * num_context_tokens +  inputs.sequence_ids(i)
            offset = inputs["offset_mapping"][i]
            inputs["offset_mapping"][i] = [
                o if sequence_ids[k] == 1 else None for k, o in enumerate(offset)
            ]

        inputs["example_id"] = example_ids
        return inputs
    
    raw_datasets = load_dataset("squad")

    # load tokenizer
    model_checkpoint = "bert-base-uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_checkpoint) 

    # add special tokens
    special_tokens = {"additional_special_tokens": ["[CONTEXT_"+ str(i) + "]" for i in range(1,num_context_tokens+1)]}
    tokenizer.add_special_tokens(special_tokens)
    context_tokens_ids = [len(tokenizer)-num_context_tokens+i for i in range(num_context_tokens)]

    train_dataset = raw_datasets["train"].map(preprocess_training_examples,batched=True,remove_columns=raw_datasets["train"].column_names)
    validation_dataset = raw_datasets["validation"].map(preprocess_validation_examples,batched=True,remove_columns=raw_datasets["validation"].column_names)

    return train_dataset, validation_dataset, tokenizer, raw_datasets

## Evaluation Metric

In [11]:
from tqdm.auto import tqdm
metric = evaluate.load("squad")
n_best = 20
max_answer_length = 30

def compute_metrics(start_logits, end_logits, features, examples):
    example_to_features = collections.defaultdict(list)
    for idx, feature in enumerate(features):
        example_to_features[feature["example_id"]].append(idx)

    predicted_answers = []
    for example in tqdm(examples):
        example_id = example["id"]
        context = example["context"]
        answers = []

        # Loop through all features associated with that example
        for feature_index in example_to_features[example_id]:
            start_logit = start_logits[feature_index]
            end_logit = end_logits[feature_index]
            offsets = features[feature_index]["offset_mapping"]

            start_indexes = np.argsort(start_logit)[-1 : -n_best - 1 : -1].tolist()
            end_indexes = np.argsort(end_logit)[-1 : -n_best - 1 : -1].tolist()
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # Skip answers that are not fully in the context
                    if offsets[start_index] is None or offsets[end_index] is None:
                        continue
                    # Skip answers with a length that is either < 0 or > max_answer_length
                    if (
                        end_index < start_index
                        or end_index - start_index + 1 > max_answer_length
                    ):
                        continue

                    answer = {
                        "text": context[offsets[start_index][0] : offsets[end_index][1]],
                        "logit_score": start_logit[start_index] + end_logit[end_index],
                    }
                    answers.append(answer)

        # Select the answer with the best score
        if len(answers) > 0:
            best_answer = max(answers, key=lambda x: x["logit_score"])
            predicted_answers.append(
                {"id": example_id, "prediction_text": best_answer["text"]}
            )
        else:
            predicted_answers.append({"id": example_id, "prediction_text": ""})

    theoretical_answers = [{"id": ex["id"], "answers": ex["answers"]} for ex in examples]
    return metric.compute(predictions=predicted_answers, references=theoretical_answers)

## Initialize Model

In [17]:
def initialize_model(checkpoint, num_context_tokens):
    
    pretrained_model = BertForQuestionAnswering.from_pretrained(checkpoint)
    model = BertForQuestionAnswering(pretrained_model.config)

    for layer in range(pretrained_model.config.num_hidden_layers):
        model.bert.encoder.layer[layer].attention.self = CustomAttention(model.config, num_context_tokens)

    # for p in model.parameters():
    #     print(p[0][:10])
    #     break

    # model.load_state_dict(pretrained_model.state_dict())
    model.resize_token_embeddings(pretrained_model.config.vocab_size + num_context_tokens)

    # for p in model.parameters():
    #     print(p[0][:10])
    #     break

    return model

## Train and Seed functions

In [13]:
def set_seed(seed):
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        
def train(model, dataset, tokenizer, batch_size = 8):
    args = TrainingArguments(
        "bert-finetuned-squad",
        evaluation_strategy="no",
        save_strategy="no",
        learning_rate=2e-5,
        num_train_epochs=3,
        weight_decay=0.01,
        per_device_train_batch_size = batch_size,
        per_device_eval_batch_size = batch_size,
    )
    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=dataset,
        tokenizer=tokenizer,
    )
    trainer.train()
    return trainer

## Run Experiments  

In [18]:
def run_experiment(num_context_tokens, seed):
# if num_context_tokens = 0, then the normal attention will be used, otherwise the central-context attention attention will be used.
    set_seed(seed)
    train_dataset, validation_dataset, tokenizer, raw_datasets = preprocess_dataset(num_context_tokens)
    model = initialize_model("prajjwal1/bert-tiny", num_context_tokens) # "prajjwal1/bert-small"
    trainer = train(model,train_dataset,tokenizer,batch_size = 8)

    # Evaluation
    predictions, _, _ = trainer.predict(validation_dataset)
    start_logits, end_logits = predictions
    results = compute_metrics(start_logits, end_logits, validation_dataset, raw_datasets["validation"])

    # log results
    with open("From_Scratch_QA_Results.txt", "a") as file:
        file.write("\n")

        file.write(f"Num Context Tokens = {num_context_tokens}      Random Seed = {seed}\n")
        
        # Write the experiment results
        file.write(str(results) + "\n")

    # original_attention_pretrained {'exact_match': 33.40586565752129, 'f1': 45.79680690116114}  , Bert-tiny (25 min)
    # pretrained (bert small) {'exact_match': 45.193945127719964, 'f1': 57.1308813120565}  Bert small (72 min)

In [None]:
seeds = [2025, 17, 771]
num_context_tokens_options = [0, 1, 8, 32, 128]
for seed in seeds:
    for num_context_tokens in num_context_tokens_options:
        run_experiment(num_context_tokens= num_context_tokens, seed= seed)

# # Num Context Tokens = 0      Random Seed = 2025
# # {'exact_match': 28.94039735099338, 'f1': 40.779770669530336} (20 min)

Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(
  2%|▏         | 501/33198 [00:17<56:53,  9.58it/s]

{'loss': 5.1034, 'grad_norm': 6.323197364807129, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1001/33198 [00:54<52:48, 10.16it/s]

{'loss': 4.7121, 'grad_norm': 5.728161811828613, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1500/33198 [01:24<46:44, 11.30it/s]

{'loss': 4.6087, 'grad_norm': 10.445772171020508, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2002/33198 [02:00<48:41, 10.68it/s]

{'loss': 4.5498, 'grad_norm': 6.27555513381958, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2501/33198 [02:38<50:55, 10.05it/s]

{'loss': 4.5022, 'grad_norm': 6.942661285400391, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3001/33198 [03:12<45:49, 10.98it/s]

{'loss': 4.4584, 'grad_norm': 9.381860733032227, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3501/33198 [03:46<47:41, 10.38it/s]

{'loss': 4.4582, 'grad_norm': 8.040411949157715, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4001/33198 [04:21<46:34, 10.45it/s]

{'loss': 4.4247, 'grad_norm': 8.752968788146973, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4502/33198 [04:53<45:52, 10.43it/s]

{'loss': 4.3868, 'grad_norm': 7.910554885864258, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5007/33198 [05:28<16:39, 28.20it/s]

{'loss': 4.3457, 'grad_norm': 9.95098876953125, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5503/33198 [05:52<22:16, 20.71it/s]

{'loss': 4.3509, 'grad_norm': 7.940922260284424, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6004/33198 [06:20<29:03, 15.60it/s]

{'loss': 4.3235, 'grad_norm': 7.440256118774414, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6502/33198 [06:58<29:58, 14.84it/s]

{'loss': 4.303, 'grad_norm': 9.122773170471191, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7001/33198 [07:25<20:30, 21.29it/s]

{'loss': 4.2849, 'grad_norm': 8.112447738647461, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7502/33198 [08:03<38:19, 11.18it/s]

{'loss': 4.2816, 'grad_norm': 8.647308349609375, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8004/33198 [08:38<20:02, 20.95it/s]

{'loss': 4.2701, 'grad_norm': 9.336305618286133, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8506/33198 [09:08<11:04, 37.19it/s]

{'loss': 4.2547, 'grad_norm': 10.55488395690918, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9002/33198 [09:38<34:13, 11.78it/s]

{'loss': 4.2379, 'grad_norm': 8.122830390930176, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9500/33198 [10:07<35:46, 11.04it/s]

{'loss': 4.2272, 'grad_norm': 8.25277328491211, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10007/33198 [10:36<09:19, 41.42it/s]

{'loss': 4.2147, 'grad_norm': 9.071359634399414, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10500/33198 [10:56<12:18, 30.74it/s]

{'loss': 4.2139, 'grad_norm': 9.932822227478027, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11007/33198 [11:17<07:16, 50.86it/s]

{'loss': 4.2053, 'grad_norm': 9.4308443069458, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11500/33198 [11:37<07:41, 47.04it/s]

{'loss': 4.1053, 'grad_norm': 8.759631156921387, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12008/33198 [12:10<07:15, 48.70it/s]

{'loss': 4.0888, 'grad_norm': 11.609369277954102, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12501/33198 [12:36<24:04, 14.33it/s]

{'loss': 4.0777, 'grad_norm': 12.198697090148926, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13000/33198 [12:55<11:50, 28.44it/s]

{'loss': 4.0637, 'grad_norm': 9.501053810119629, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13504/33198 [13:12<08:02, 40.80it/s]

{'loss': 4.0829, 'grad_norm': 10.89737606048584, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14008/33198 [13:32<06:58, 45.83it/s]

{'loss': 4.0715, 'grad_norm': 8.15256118774414, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14505/33198 [13:52<15:47, 19.74it/s]

{'loss': 4.1074, 'grad_norm': 7.888158798217773, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15000/33198 [14:18<07:11, 42.17it/s]

{'loss': 4.0986, 'grad_norm': 9.797060012817383, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15503/33198 [14:41<10:33, 27.93it/s]

{'loss': 4.0876, 'grad_norm': 8.744867324829102, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16004/33198 [15:07<16:05, 17.80it/s]

{'loss': 4.1181, 'grad_norm': 9.294384002685547, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16505/33198 [15:23<06:57, 39.97it/s]

{'loss': 4.0993, 'grad_norm': 10.6412353515625, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17007/33198 [15:36<06:16, 42.96it/s]

{'loss': 4.0762, 'grad_norm': 8.953187942504883, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17509/33198 [15:53<05:40, 46.02it/s]

{'loss': 4.092, 'grad_norm': 11.196188926696777, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18005/33198 [16:15<06:00, 42.12it/s]

{'loss': 4.1206, 'grad_norm': 8.786276817321777, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18507/33198 [16:34<05:05, 48.09it/s]

{'loss': 4.086, 'grad_norm': 8.987215042114258, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19009/33198 [16:58<07:20, 32.18it/s]

{'loss': 4.0678, 'grad_norm': 8.356576919555664, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▉    | 19506/33198 [17:27<04:54, 46.54it/s]

{'loss': 4.0926, 'grad_norm': 7.361836910247803, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20001/33198 [17:50<20:57, 10.50it/s]

{'loss': 4.053, 'grad_norm': 8.924410820007324, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20501/33198 [18:23<21:56,  9.65it/s]

{'loss': 4.0849, 'grad_norm': 9.059311866760254, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21007/33198 [18:44<04:11, 48.38it/s]

{'loss': 4.0952, 'grad_norm': 7.650246620178223, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21504/33198 [19:11<04:08, 47.09it/s]

{'loss': 4.0634, 'grad_norm': 9.837230682373047, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22005/33198 [19:22<04:49, 38.60it/s]

{'loss': 4.0762, 'grad_norm': 8.584066390991211, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22504/33198 [19:40<03:53, 45.82it/s]

{'loss': 4.0474, 'grad_norm': 9.074055671691895, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23000/33198 [20:15<16:55, 10.04it/s]

{'loss': 4.0353, 'grad_norm': 9.347384452819824, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23501/33198 [20:47<16:58,  9.52it/s]

{'loss': 4.018, 'grad_norm': 9.503423690795898, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24004/33198 [21:12<03:36, 42.55it/s]

{'loss': 4.0108, 'grad_norm': 10.553555488586426, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24501/33198 [21:38<14:26, 10.03it/s]

{'loss': 4.0001, 'grad_norm': 9.598219871520996, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25007/33198 [22:12<03:17, 41.44it/s]

{'loss': 4.0166, 'grad_norm': 7.998139381408691, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25502/33198 [22:43<03:28, 36.90it/s]

{'loss': 4.0072, 'grad_norm': 8.402083396911621, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26001/33198 [23:07<07:45, 15.47it/s]

{'loss': 4.005, 'grad_norm': 9.77381420135498, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26506/33198 [23:41<02:20, 47.54it/s]

{'loss': 4.0196, 'grad_norm': 11.921496391296387, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27001/33198 [24:03<07:11, 14.36it/s]

{'loss': 4.0, 'grad_norm': 10.335360527038574, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27500/33198 [24:33<04:41, 20.28it/s]

{'loss': 4.0321, 'grad_norm': 8.275535583496094, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28002/33198 [25:01<06:50, 12.67it/s]

{'loss': 4.0301, 'grad_norm': 12.73633861541748, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28502/33198 [25:17<07:30, 10.43it/s]

{'loss': 4.0164, 'grad_norm': 9.644628524780273, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29002/33198 [25:42<03:30, 19.94it/s]

{'loss': 4.0125, 'grad_norm': 9.345446586608887, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29507/33198 [25:59<01:12, 50.92it/s]

{'loss': 4.0151, 'grad_norm': 9.146974563598633, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30002/33198 [26:19<01:11, 44.98it/s]

{'loss': 4.0498, 'grad_norm': 9.55467700958252, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30503/33198 [26:39<02:46, 16.16it/s]

{'loss': 4.0202, 'grad_norm': 10.084405899047852, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31008/33198 [26:55<00:48, 45.11it/s]

{'loss': 4.027, 'grad_norm': 8.763700485229492, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31505/33198 [27:13<00:41, 41.01it/s]

{'loss': 4.0249, 'grad_norm': 11.461695671081543, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32007/33198 [27:23<00:25, 46.40it/s]

{'loss': 4.0168, 'grad_norm': 9.292656898498535, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32504/33198 [27:35<00:33, 20.81it/s]

{'loss': 4.0502, 'grad_norm': 7.590266704559326, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33003/33198 [27:58<00:11, 16.70it/s]

{'loss': 4.0544, 'grad_norm': 8.869757652282715, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [28:12<00:00, 19.62it/s]


{'train_runtime': 1692.2427, 'train_samples_per_second': 156.935, 'train_steps_per_second': 19.618, 'train_loss': 4.16792212331774, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:25<00:00, 52.94it/s]
100%|██████████| 10570/10570 [00:53<00:00, 199.04it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 502/33198 [00:12<14:50, 36.73it/s]

{'loss': 5.0729, 'grad_norm': 7.941679000854492, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1005/33198 [00:25<12:58, 41.38it/s]

{'loss': 4.7247, 'grad_norm': 6.872503280639648, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1507/33198 [00:38<12:39, 41.72it/s]

{'loss': 4.6222, 'grad_norm': 10.825322151184082, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2004/33198 [00:57<11:11, 46.45it/s]  

{'loss': 4.5648, 'grad_norm': 6.597901821136475, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2507/33198 [01:16<11:17, 45.28it/s]  

{'loss': 4.5154, 'grad_norm': 7.078191757202148, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3003/33198 [01:32<13:01, 38.65it/s]

{'loss': 4.4652, 'grad_norm': 8.752741813659668, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3507/33198 [01:56<11:55, 41.50it/s]  

{'loss': 4.4624, 'grad_norm': 9.558876037597656, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4006/33198 [02:21<25:58, 18.74it/s]  

{'loss': 4.4264, 'grad_norm': 8.570479393005371, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4505/33198 [02:40<10:41, 44.75it/s]

{'loss': 4.3891, 'grad_norm': 7.386697769165039, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5007/33198 [03:06<11:44, 40.02it/s]  

{'loss': 4.346, 'grad_norm': 9.392831802368164, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5507/33198 [03:19<10:20, 44.63it/s]

{'loss': 4.3478, 'grad_norm': 8.139887809753418, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6005/33198 [03:32<11:51, 38.22it/s]

{'loss': 4.3279, 'grad_norm': 7.381692886352539, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6505/33198 [03:56<20:47, 21.39it/s]  

{'loss': 4.3086, 'grad_norm': 8.918241500854492, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7001/33198 [04:09<30:14, 14.44it/s]

{'loss': 4.293, 'grad_norm': 7.625133991241455, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7504/33198 [04:43<25:12, 16.99it/s]  

{'loss': 4.2811, 'grad_norm': 8.53279972076416, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8007/33198 [05:01<10:12, 41.14it/s]

{'loss': 4.2697, 'grad_norm': 8.860167503356934, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8500/33198 [05:28<38:30, 10.69it/s]

{'loss': 4.2592, 'grad_norm': 9.299895286560059, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9000/33198 [05:52<30:59, 13.01it/s]

{'loss': 4.2375, 'grad_norm': 8.359989166259766, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9500/33198 [06:24<12:10, 32.43it/s]

{'loss': 4.2304, 'grad_norm': 8.046449661254883, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10003/33198 [07:01<14:12, 27.22it/s]

{'loss': 4.2205, 'grad_norm': 9.207715034484863, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10501/33198 [07:29<53:35,  7.06it/s]

{'loss': 4.2184, 'grad_norm': 10.006412506103516, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11008/33198 [08:07<10:56, 33.81it/s]  

{'loss': 4.2143, 'grad_norm': 8.87746810913086, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11500/33198 [08:33<18:49, 19.21it/s]

{'loss': 4.1077, 'grad_norm': 8.445332527160645, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12000/33198 [09:00<22:28, 15.72it/s]

{'loss': 4.0897, 'grad_norm': 12.525162696838379, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12502/33198 [09:34<36:13,  9.52it/s]

{'loss': 4.077, 'grad_norm': 10.727319717407227, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13005/33198 [09:55<07:51, 42.84it/s]

{'loss': 4.07, 'grad_norm': 9.244545936584473, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13507/33198 [10:21<08:17, 39.60it/s]

{'loss': 4.0831, 'grad_norm': 11.254522323608398, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14006/33198 [10:39<07:21, 43.43it/s]

{'loss': 4.0767, 'grad_norm': 7.645788669586182, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14501/33198 [11:00<19:40, 15.84it/s]

{'loss': 4.1097, 'grad_norm': 7.787367343902588, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15005/33198 [11:25<06:46, 44.76it/s]

{'loss': 4.1001, 'grad_norm': 8.500865936279297, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15503/33198 [11:45<06:51, 42.95it/s]

{'loss': 4.0914, 'grad_norm': 8.647102355957031, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16007/33198 [12:16<06:42, 42.75it/s]

{'loss': 4.113, 'grad_norm': 9.514381408691406, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16500/33198 [12:48<20:02, 13.88it/s]

{'loss': 4.1044, 'grad_norm': 9.739919662475586, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17002/33198 [13:21<12:08, 22.22it/s]

{'loss': 4.0784, 'grad_norm': 10.08538818359375, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17503/33198 [13:51<12:06, 21.60it/s]

{'loss': 4.0983, 'grad_norm': 10.184470176696777, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18000/33198 [14:24<12:03, 21.00it/s]

{'loss': 4.1237, 'grad_norm': 8.355866432189941, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18501/33198 [14:41<07:42, 31.76it/s]

{'loss': 4.085, 'grad_norm': 8.298413276672363, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19003/33198 [15:14<13:47, 17.16it/s]

{'loss': 4.0786, 'grad_norm': 8.222920417785645, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19501/33198 [15:47<25:44,  8.87it/s]

{'loss': 4.0958, 'grad_norm': 7.7430500984191895, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20000/33198 [16:06<10:26, 21.06it/s]

{'loss': 4.0549, 'grad_norm': 9.160365104675293, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20507/33198 [16:37<04:47, 44.22it/s]

{'loss': 4.087, 'grad_norm': 8.935318946838379, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21008/33198 [16:52<05:01, 40.43it/s]

{'loss': 4.0957, 'grad_norm': 7.751334190368652, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21508/33198 [17:11<05:06, 38.14it/s]

{'loss': 4.0732, 'grad_norm': 9.467473983764648, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22007/33198 [17:24<05:44, 32.52it/s]

{'loss': 4.0778, 'grad_norm': 8.609932899475098, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22507/33198 [17:38<04:07, 43.26it/s]

{'loss': 4.0472, 'grad_norm': 8.542832374572754, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23005/33198 [17:52<04:09, 40.81it/s]

{'loss': 4.0409, 'grad_norm': 9.47598934173584, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23506/33198 [18:12<03:55, 41.21it/s]

{'loss': 4.0208, 'grad_norm': 8.947019577026367, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24008/33198 [18:26<03:33, 43.06it/s]

{'loss': 4.0164, 'grad_norm': 9.972373008728027, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24504/33198 [18:48<09:56, 14.56it/s]

{'loss': 4.005, 'grad_norm': 9.927330017089844, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25000/33198 [19:21<02:55, 46.72it/s]

{'loss': 4.0244, 'grad_norm': 7.521032333374023, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25509/33198 [19:46<03:14, 39.63it/s]

{'loss': 4.0036, 'grad_norm': 7.349842071533203, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26004/33198 [20:05<03:02, 39.44it/s]

{'loss': 4.0034, 'grad_norm': 10.096124649047852, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26501/33198 [20:32<02:35, 43.16it/s]

{'loss': 4.021, 'grad_norm': 11.71284294128418, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27001/33198 [20:59<03:41, 28.02it/s]

{'loss': 4.0056, 'grad_norm': 8.97580337524414, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27508/33198 [21:21<02:00, 47.16it/s]

{'loss': 4.0341, 'grad_norm': 6.88792610168457, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28001/33198 [21:42<12:22,  7.00it/s]

{'loss': 4.031, 'grad_norm': 9.918972969055176, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28506/33198 [22:00<02:33, 30.57it/s]

{'loss': 4.0258, 'grad_norm': 8.871109962463379, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29004/33198 [22:13<01:44, 40.20it/s]

{'loss': 4.0201, 'grad_norm': 8.763395309448242, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29501/33198 [22:32<06:26,  9.57it/s]

{'loss': 4.0167, 'grad_norm': 8.564107894897461, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30003/33198 [22:51<01:10, 45.46it/s]

{'loss': 4.0504, 'grad_norm': 9.909037590026855, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30504/33198 [23:10<01:16, 35.15it/s]

{'loss': 4.0244, 'grad_norm': 9.498963356018066, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31005/33198 [23:21<00:46, 47.22it/s]

{'loss': 4.0304, 'grad_norm': 8.68661880493164, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31505/33198 [23:50<00:38, 44.24it/s]

{'loss': 4.0227, 'grad_norm': 9.837090492248535, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32003/33198 [24:12<00:32, 37.21it/s]

{'loss': 4.0178, 'grad_norm': 9.263204574584961, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32500/33198 [24:27<00:24, 28.76it/s]

{'loss': 4.0504, 'grad_norm': 7.834678649902344, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33006/33198 [24:45<00:04, 42.55it/s]

{'loss': 4.0572, 'grad_norm': 8.40782642364502, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [24:57<00:00, 22.18it/s]


{'train_runtime': 1496.9999, 'train_samples_per_second': 177.403, 'train_steps_per_second': 22.176, 'train_loss': 4.170915504472573, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:17<00:00, 79.28it/s]
100%|██████████| 10570/10570 [00:28<00:00, 373.28it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 506/33198 [00:17<13:38, 39.96it/s]

{'loss': 5.0991, 'grad_norm': 5.834445476531982, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1001/33198 [00:40<53:31, 10.03it/s] 

{'loss': 4.7027, 'grad_norm': 5.7961530685424805, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1503/33198 [01:14<27:31, 19.20it/s]  

{'loss': 4.6118, 'grad_norm': 8.408381462097168, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2007/33198 [01:37<15:48, 32.89it/s]  

{'loss': 4.5528, 'grad_norm': 6.306215286254883, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2507/33198 [01:56<12:36, 40.58it/s]

{'loss': 4.5125, 'grad_norm': 6.756565093994141, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3008/33198 [02:14<15:11, 33.11it/s]

{'loss': 4.4651, 'grad_norm': 8.144278526306152, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3500/33198 [02:30<26:18, 18.81it/s]

{'loss': 4.4556, 'grad_norm': 9.603997230529785, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4005/33198 [02:53<11:53, 40.89it/s]  

{'loss': 4.4254, 'grad_norm': 8.648253440856934, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4508/33198 [03:20<10:52, 43.95it/s]  

{'loss': 4.3893, 'grad_norm': 7.650664806365967, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5000/33198 [03:48<19:38, 23.93it/s]

{'loss': 4.3442, 'grad_norm': 10.183305740356445, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5507/33198 [04:15<11:45, 39.27it/s]

{'loss': 4.3518, 'grad_norm': 7.716952323913574, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6007/33198 [04:33<13:39, 33.17it/s]

{'loss': 4.3289, 'grad_norm': 7.286598205566406, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6501/33198 [04:58<53:51,  8.26it/s]

{'loss': 4.3055, 'grad_norm': 8.29345417022705, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7006/33198 [05:16<09:33, 45.65it/s]

{'loss': 4.2936, 'grad_norm': 7.742722988128662, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7507/33198 [05:34<10:19, 41.44it/s]

{'loss': 4.2778, 'grad_norm': 9.46928882598877, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8010/33198 [05:48<09:34, 43.83it/s]

{'loss': 4.2759, 'grad_norm': 9.051563262939453, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8500/33198 [06:02<23:10, 17.76it/s]

{'loss': 4.2593, 'grad_norm': 9.149791717529297, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9007/33198 [06:31<21:31, 18.73it/s]

{'loss': 4.2295, 'grad_norm': 7.946518898010254, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9504/33198 [06:51<09:24, 41.98it/s]

{'loss': 4.2282, 'grad_norm': 8.156983375549316, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10006/33198 [07:13<08:37, 44.81it/s]

{'loss': 4.2199, 'grad_norm': 8.72265625, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10503/33198 [07:30<10:36, 35.63it/s]

{'loss': 4.2181, 'grad_norm': 10.077139854431152, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11001/33198 [07:59<45:24,  8.15it/s]

{'loss': 4.2064, 'grad_norm': 8.714138984680176, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11507/33198 [08:25<07:46, 46.51it/s]

{'loss': 4.105, 'grad_norm': 8.740378379821777, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12004/33198 [08:36<08:04, 43.73it/s]

{'loss': 4.0833, 'grad_norm': 11.82756519317627, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12505/33198 [08:48<07:24, 46.60it/s]

{'loss': 4.0729, 'grad_norm': 11.287479400634766, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13007/33198 [09:12<08:55, 37.70it/s]

{'loss': 4.0766, 'grad_norm': 9.669682502746582, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13506/33198 [09:38<08:01, 40.90it/s]

{'loss': 4.0812, 'grad_norm': 9.991632461547852, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14000/33198 [10:04<07:44, 41.30it/s]

{'loss': 4.0724, 'grad_norm': 8.136473655700684, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14505/33198 [10:17<07:51, 39.63it/s]

{'loss': 4.101, 'grad_norm': 7.6542534828186035, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15007/33198 [10:36<07:36, 39.84it/s]

{'loss': 4.0993, 'grad_norm': 9.363479614257812, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15506/33198 [10:48<07:59, 36.88it/s]

{'loss': 4.0908, 'grad_norm': 8.56434440612793, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16003/33198 [11:08<13:33, 21.14it/s]

{'loss': 4.1166, 'grad_norm': 8.786827087402344, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16505/33198 [11:21<07:14, 38.44it/s]

{'loss': 4.1049, 'grad_norm': 10.510222434997559, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17003/33198 [11:36<06:34, 41.01it/s]

{'loss': 4.0786, 'grad_norm': 8.381999969482422, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17504/33198 [11:59<20:19, 12.87it/s]

{'loss': 4.0932, 'grad_norm': 10.298502922058105, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18006/33198 [12:16<05:34, 45.40it/s]

{'loss': 4.1185, 'grad_norm': 8.285289764404297, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18506/33198 [12:28<05:48, 42.18it/s]

{'loss': 4.0896, 'grad_norm': 8.242209434509277, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19002/33198 [12:48<13:32, 17.48it/s]

{'loss': 4.073, 'grad_norm': 8.090459823608398, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▉    | 19506/33198 [13:00<04:46, 47.75it/s]

{'loss': 4.0937, 'grad_norm': 7.42319393157959, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20002/33198 [13:19<23:07,  9.51it/s]

{'loss': 4.0587, 'grad_norm': 9.007667541503906, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20504/33198 [13:38<06:53, 30.69it/s]

{'loss': 4.0857, 'grad_norm': 8.722502708435059, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21004/33198 [13:51<05:31, 36.81it/s]

{'loss': 4.0983, 'grad_norm': 6.790106773376465, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21504/33198 [14:03<04:44, 41.04it/s]

{'loss': 4.0672, 'grad_norm': 9.484733581542969, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22004/33198 [14:19<05:47, 32.19it/s]

{'loss': 4.0811, 'grad_norm': 8.290836334228516, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22504/33198 [14:38<05:18, 33.53it/s]

{'loss': 4.0475, 'grad_norm': 8.382569313049316, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23005/33198 [14:52<04:50, 35.09it/s]

{'loss': 4.0407, 'grad_norm': 8.786572456359863, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23506/33198 [15:09<04:56, 32.67it/s]

{'loss': 4.0253, 'grad_norm': 9.053515434265137, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24007/33198 [15:32<03:42, 41.24it/s]

{'loss': 4.0174, 'grad_norm': 9.484186172485352, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24505/33198 [15:45<04:12, 34.39it/s]

{'loss': 3.9943, 'grad_norm': 8.779736518859863, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25006/33198 [16:00<03:21, 40.62it/s]

{'loss': 4.0216, 'grad_norm': 7.314587593078613, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25502/33198 [16:20<04:04, 31.47it/s]

{'loss': 4.0025, 'grad_norm': 8.182217597961426, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26001/33198 [16:38<10:25, 11.50it/s]

{'loss': 4.0017, 'grad_norm': 9.133003234863281, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26503/33198 [17:01<05:38, 19.76it/s]

{'loss': 4.023, 'grad_norm': 10.582517623901367, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27006/33198 [17:16<02:30, 41.05it/s]

{'loss': 3.9992, 'grad_norm': 10.038690567016602, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27504/33198 [17:28<02:04, 45.58it/s]

{'loss': 4.0312, 'grad_norm': 7.1152448654174805, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28001/33198 [17:58<05:56, 14.58it/s]

{'loss': 4.0313, 'grad_norm': 9.871419906616211, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28507/33198 [18:23<02:04, 37.62it/s]

{'loss': 4.0171, 'grad_norm': 8.22767162322998, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29004/33198 [18:53<01:48, 38.56it/s]

{'loss': 4.0146, 'grad_norm': 9.237774848937988, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29504/33198 [19:14<01:31, 40.49it/s]

{'loss': 4.0191, 'grad_norm': 8.053627014160156, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30001/33198 [19:36<04:39, 11.44it/s]

{'loss': 4.0568, 'grad_norm': 8.966979026794434, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30505/33198 [19:52<01:28, 30.36it/s]

{'loss': 4.0224, 'grad_norm': 8.943285942077637, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31003/33198 [20:05<00:55, 39.28it/s]

{'loss': 4.0298, 'grad_norm': 8.511804580688477, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31500/33198 [20:19<01:46, 15.91it/s]

{'loss': 4.0318, 'grad_norm': 10.546942710876465, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32005/33198 [20:42<00:28, 41.25it/s]

{'loss': 4.0172, 'grad_norm': 9.566107749938965, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32502/33198 [20:59<00:16, 42.12it/s]

{'loss': 4.0578, 'grad_norm': 7.778215408325195, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33003/33198 [21:20<00:06, 30.80it/s]

{'loss': 4.0555, 'grad_norm': 8.666788101196289, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [21:27<00:00, 25.79it/s]


{'train_runtime': 1287.2374, 'train_samples_per_second': 206.312, 'train_steps_per_second': 25.79, 'train_loss': 4.169689796898651, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:22<00:00, 59.99it/s]
100%|██████████| 10570/10570 [00:26<00:00, 397.31it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 506/33198 [00:14<17:02, 31.97it/s]

{'loss': 5.1223, 'grad_norm': 5.414699077606201, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1006/33198 [00:27<12:18, 43.60it/s]

{'loss': 4.7046, 'grad_norm': 6.818000793457031, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1508/33198 [00:42<12:27, 42.37it/s]

{'loss': 4.6006, 'grad_norm': 10.136534690856934, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2005/33198 [00:59<13:20, 38.98it/s]  

{'loss': 4.5472, 'grad_norm': 6.3734965324401855, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2505/33198 [01:12<14:19, 35.70it/s]

{'loss': 4.5078, 'grad_norm': 6.877419471740723, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3000/33198 [01:36<41:21, 12.17it/s]  

{'loss': 4.4586, 'grad_norm': 8.550073623657227, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3505/33198 [01:57<14:46, 33.51it/s]

{'loss': 4.456, 'grad_norm': 8.610538482666016, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4008/33198 [02:09<12:41, 38.33it/s]

{'loss': 4.4218, 'grad_norm': 8.33630084991455, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4506/33198 [02:24<11:34, 41.34it/s]

{'loss': 4.3905, 'grad_norm': 7.44803524017334, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5006/33198 [02:39<15:03, 31.22it/s]

{'loss': 4.3423, 'grad_norm': 9.64171314239502, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5504/33198 [02:53<13:20, 34.61it/s]

{'loss': 4.3511, 'grad_norm': 7.79118537902832, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6004/33198 [03:09<11:21, 39.93it/s]

{'loss': 4.3242, 'grad_norm': 7.247951507568359, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6502/33198 [03:24<13:01, 34.15it/s]

{'loss': 4.3043, 'grad_norm': 8.706296920776367, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7006/33198 [03:40<12:51, 33.93it/s]

{'loss': 4.2935, 'grad_norm': 7.399068355560303, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7505/33198 [03:54<09:59, 42.85it/s]

{'loss': 4.2767, 'grad_norm': 8.793529510498047, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8003/33198 [04:09<09:19, 45.07it/s]

{'loss': 4.2715, 'grad_norm': 8.837897300720215, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8507/33198 [04:22<09:58, 41.26it/s]

{'loss': 4.2513, 'grad_norm': 9.54263973236084, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9006/33198 [04:44<10:42, 37.64it/s]

{'loss': 4.2287, 'grad_norm': 7.910299777984619, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9505/33198 [04:56<09:35, 41.18it/s]

{'loss': 4.2265, 'grad_norm': 8.427882194519043, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10005/33198 [05:09<11:47, 32.80it/s]

{'loss': 4.2182, 'grad_norm': 8.798215866088867, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10508/33198 [05:21<09:32, 39.63it/s]

{'loss': 4.2168, 'grad_norm': 10.202123641967773, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11000/33198 [05:43<17:35, 21.03it/s]

{'loss': 4.2091, 'grad_norm': 8.551733016967773, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11501/33198 [06:18<43:11,  8.37it/s]

{'loss': 4.1006, 'grad_norm': 8.643393516540527, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12001/33198 [06:50<42:36,  8.29it/s]

{'loss': 4.0833, 'grad_norm': 12.093172073364258, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12501/33198 [07:24<35:31,  9.71it/s]

{'loss': 4.0752, 'grad_norm': 11.386963844299316, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13001/33198 [07:59<41:46,  8.06it/s]

{'loss': 4.066, 'grad_norm': 8.97738265991211, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13507/33198 [08:32<06:57, 47.12it/s]

{'loss': 4.0824, 'grad_norm': 10.780667304992676, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14002/33198 [08:48<30:30, 10.49it/s]

{'loss': 4.0726, 'grad_norm': 7.828829288482666, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14506/33198 [09:15<08:48, 35.34it/s]

{'loss': 4.1095, 'grad_norm': 7.648870944976807, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15000/33198 [09:47<23:39, 12.82it/s]

{'loss': 4.0992, 'grad_norm': 9.61574649810791, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15505/33198 [10:22<06:23, 46.16it/s]

{'loss': 4.0906, 'grad_norm': 8.450557708740234, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16007/33198 [10:46<06:52, 41.72it/s]

{'loss': 4.1176, 'grad_norm': 9.074934005737305, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16503/33198 [11:10<13:58, 19.91it/s]

{'loss': 4.1025, 'grad_norm': 9.929291725158691, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17006/33198 [11:35<07:43, 34.93it/s]

{'loss': 4.0756, 'grad_norm': 8.940652847290039, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17500/33198 [12:09<21:34, 12.13it/s]

{'loss': 4.0935, 'grad_norm': 9.919208526611328, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18000/33198 [12:39<21:25, 11.82it/s]

{'loss': 4.1219, 'grad_norm': 8.654227256774902, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18504/33198 [13:09<05:31, 44.34it/s]

{'loss': 4.0857, 'grad_norm': 8.549979209899902, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19003/33198 [13:43<08:50, 26.74it/s]

{'loss': 4.0689, 'grad_norm': 8.50699520111084, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19503/33198 [14:09<05:13, 43.70it/s]

{'loss': 4.0928, 'grad_norm': 7.158329963684082, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20005/33198 [14:31<07:11, 30.57it/s]

{'loss': 4.0577, 'grad_norm': 9.2674560546875, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20506/33198 [15:08<05:26, 38.85it/s]

{'loss': 4.0803, 'grad_norm': 8.797717094421387, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21003/33198 [15:20<04:22, 46.48it/s]

{'loss': 4.0939, 'grad_norm': 7.293040752410889, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21503/33198 [15:37<05:44, 33.99it/s]

{'loss': 4.0713, 'grad_norm': 10.092254638671875, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22006/33198 [16:06<04:43, 39.54it/s]

{'loss': 4.0775, 'grad_norm': 8.67422103881836, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22500/33198 [16:40<15:09, 11.76it/s]

{'loss': 4.0465, 'grad_norm': 8.647361755371094, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23005/33198 [16:57<09:25, 18.02it/s]

{'loss': 4.0373, 'grad_norm': 9.026477813720703, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23507/33198 [17:13<07:42, 20.95it/s]

{'loss': 4.0194, 'grad_norm': 9.165351867675781, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24005/33198 [17:37<05:32, 27.65it/s]

{'loss': 4.0113, 'grad_norm': 10.34458065032959, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24500/33198 [17:59<15:24,  9.41it/s]

{'loss': 3.9958, 'grad_norm': 9.71119213104248, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25004/33198 [18:28<03:18, 41.20it/s]

{'loss': 4.025, 'grad_norm': 7.024103164672852, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25501/33198 [18:58<15:45,  8.14it/s]

{'loss': 3.9983, 'grad_norm': 8.05852222442627, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26001/33198 [19:35<13:15,  9.05it/s]

{'loss': 4.002, 'grad_norm': 10.028365135192871, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26504/33198 [20:01<04:13, 26.42it/s]

{'loss': 4.0248, 'grad_norm': 10.826424598693848, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27000/33198 [20:19<06:49, 15.15it/s]

{'loss': 3.9997, 'grad_norm': 9.548158645629883, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27501/33198 [21:00<11:06,  8.55it/s]

{'loss': 4.0331, 'grad_norm': 7.77686071395874, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28005/33198 [21:29<05:04, 17.07it/s]

{'loss': 4.0338, 'grad_norm': 9.795063018798828, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28501/33198 [21:48<10:11,  7.68it/s]

{'loss': 4.0199, 'grad_norm': 8.856990814208984, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29004/33198 [22:10<01:31, 46.08it/s]

{'loss': 4.0142, 'grad_norm': 9.098101615905762, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29506/33198 [22:22<02:38, 23.29it/s]

{'loss': 4.0164, 'grad_norm': 8.84589958190918, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30006/33198 [22:33<01:09, 45.75it/s]

{'loss': 4.0511, 'grad_norm': 9.08362102508545, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30501/33198 [22:48<03:37, 12.40it/s]

{'loss': 4.0226, 'grad_norm': 9.486519813537598, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31007/33198 [23:04<00:50, 43.31it/s]

{'loss': 4.0279, 'grad_norm': 7.953108310699463, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31502/33198 [23:22<00:56, 29.95it/s]

{'loss': 4.027, 'grad_norm': 10.913875579833984, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32008/33198 [23:47<00:38, 30.92it/s]

{'loss': 4.0214, 'grad_norm': 9.034618377685547, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32506/33198 [24:05<00:16, 40.92it/s]

{'loss': 4.0531, 'grad_norm': 7.891056060791016, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33006/33198 [24:17<00:04, 45.05it/s]

{'loss': 4.0525, 'grad_norm': 8.88704776763916, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [24:25<00:00, 22.65it/s]


{'train_runtime': 1465.9939, 'train_samples_per_second': 181.155, 'train_steps_per_second': 22.645, 'train_loss': 4.168477348724298, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:16<00:00, 80.26it/s]
100%|██████████| 10570/10570 [00:23<00:00, 446.39it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 507/33198 [00:13<14:44, 36.94it/s]

{'loss': 5.1652, 'grad_norm': 5.35341215133667, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1004/33198 [00:32<16:17, 32.93it/s] 

{'loss': 4.7197, 'grad_norm': 5.920331954956055, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1507/33198 [00:51<15:02, 35.12it/s]  

{'loss': 4.6036, 'grad_norm': 9.216773986816406, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2001/33198 [01:19<58:09,  8.94it/s]  

{'loss': 4.5473, 'grad_norm': 6.176753044128418, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2505/33198 [01:44<18:09, 28.16it/s]  

{'loss': 4.5061, 'grad_norm': 6.648797035217285, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3006/33198 [02:05<13:54, 36.16it/s]

{'loss': 4.4616, 'grad_norm': 7.878019332885742, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3505/33198 [02:31<12:16, 40.34it/s]  

{'loss': 4.4555, 'grad_norm': 7.837931156158447, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4002/33198 [02:51<29:40, 16.40it/s]  

{'loss': 4.4165, 'grad_norm': 7.94058895111084, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4508/33198 [03:07<12:19, 38.80it/s]

{'loss': 4.3927, 'grad_norm': 7.39168643951416, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5003/33198 [03:31<27:56, 16.81it/s]

{'loss': 4.3414, 'grad_norm': 9.316086769104004, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5500/33198 [03:55<27:11, 16.98it/s]  

{'loss': 4.3418, 'grad_norm': 7.830732345581055, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6007/33198 [04:19<22:00, 20.59it/s]

{'loss': 4.3209, 'grad_norm': 7.122154712677002, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6507/33198 [04:31<11:11, 39.77it/s]

{'loss': 4.3035, 'grad_norm': 8.36758041381836, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7005/33198 [04:43<10:39, 40.97it/s]

{'loss': 4.2931, 'grad_norm': 7.527405261993408, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7506/33198 [04:56<10:28, 40.87it/s]

{'loss': 4.276, 'grad_norm': 8.975018501281738, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8003/33198 [05:16<19:01, 22.06it/s]

{'loss': 4.2686, 'grad_norm': 8.126352310180664, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8507/33198 [05:33<09:42, 42.41it/s]

{'loss': 4.2551, 'grad_norm': 9.822023391723633, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9005/33198 [05:48<10:38, 37.87it/s]

{'loss': 4.2295, 'grad_norm': 7.88865327835083, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9508/33198 [06:00<09:41, 40.74it/s]

{'loss': 4.2244, 'grad_norm': 8.021478652954102, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10003/33198 [06:13<11:56, 32.36it/s]

{'loss': 4.2165, 'grad_norm': 8.627165794372559, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10500/33198 [06:30<33:32, 11.28it/s]

{'loss': 4.2123, 'grad_norm': 9.449482917785645, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11006/33198 [07:04<08:46, 42.12it/s]

{'loss': 4.2038, 'grad_norm': 8.25234317779541, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11505/33198 [07:37<08:25, 42.89it/s]

{'loss': 4.1028, 'grad_norm': 8.729914665222168, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12002/33198 [07:52<15:29, 22.79it/s]

{'loss': 4.0836, 'grad_norm': 11.559292793273926, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12503/33198 [08:06<11:24, 30.24it/s]

{'loss': 4.0759, 'grad_norm': 10.795162200927734, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13006/33198 [08:27<08:12, 41.00it/s]

{'loss': 4.0649, 'grad_norm': 9.08052921295166, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13501/33198 [08:51<26:59, 12.16it/s]

{'loss': 4.0763, 'grad_norm': 11.236367225646973, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14004/33198 [09:12<08:18, 38.52it/s]

{'loss': 4.0698, 'grad_norm': 7.647765636444092, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14504/33198 [09:25<08:33, 36.43it/s]

{'loss': 4.1028, 'grad_norm': 7.443155288696289, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15002/33198 [09:39<08:39, 35.00it/s]

{'loss': 4.0993, 'grad_norm': 9.083023071289062, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15506/33198 [09:57<07:45, 37.99it/s]

{'loss': 4.087, 'grad_norm': 8.030420303344727, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16004/33198 [10:11<08:17, 34.57it/s]

{'loss': 4.1117, 'grad_norm': 8.613899230957031, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16504/33198 [10:25<08:05, 34.42it/s]

{'loss': 4.103, 'grad_norm': 9.191893577575684, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17003/33198 [10:38<07:28, 36.13it/s]

{'loss': 4.0737, 'grad_norm': 9.544635772705078, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17507/33198 [10:52<06:30, 40.17it/s]

{'loss': 4.0925, 'grad_norm': 10.312397003173828, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18004/33198 [11:06<07:07, 35.53it/s]

{'loss': 4.1209, 'grad_norm': 8.412800788879395, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18505/33198 [11:20<06:49, 35.86it/s]

{'loss': 4.0833, 'grad_norm': 7.658194541931152, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19005/33198 [11:34<06:54, 34.27it/s]

{'loss': 4.0716, 'grad_norm': 8.291237831115723, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▉    | 19506/33198 [11:49<06:25, 35.53it/s]

{'loss': 4.0976, 'grad_norm': 7.008786678314209, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20002/33198 [12:03<07:27, 29.50it/s]

{'loss': 4.0568, 'grad_norm': 9.158574104309082, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20505/33198 [12:17<05:41, 37.16it/s]

{'loss': 4.0793, 'grad_norm': 9.067159652709961, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21006/33198 [12:31<05:37, 36.09it/s]

{'loss': 4.0935, 'grad_norm': 6.86757230758667, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21504/33198 [12:45<05:21, 36.33it/s]

{'loss': 4.0659, 'grad_norm': 8.78319263458252, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22002/33198 [12:59<07:46, 24.02it/s]

{'loss': 4.0736, 'grad_norm': 8.411920547485352, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22502/33198 [13:14<05:11, 34.32it/s]

{'loss': 4.0505, 'grad_norm': 7.948464393615723, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23002/33198 [13:30<04:23, 38.67it/s]

{'loss': 4.041, 'grad_norm': 8.989825248718262, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23504/33198 [13:53<04:39, 34.71it/s]

{'loss': 4.0157, 'grad_norm': 8.818476676940918, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24007/33198 [14:08<04:25, 34.66it/s]

{'loss': 4.0116, 'grad_norm': 10.309024810791016, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24503/33198 [14:28<07:31, 19.24it/s]

{'loss': 3.9892, 'grad_norm': 9.254508972167969, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25003/33198 [14:49<04:16, 31.91it/s]

{'loss': 4.021, 'grad_norm': 7.024820804595947, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25506/33198 [15:05<04:06, 31.22it/s]

{'loss': 3.9991, 'grad_norm': 7.994571685791016, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26004/33198 [15:23<03:30, 34.17it/s]

{'loss': 4.0065, 'grad_norm': 9.437405586242676, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26506/33198 [15:45<03:06, 35.81it/s]

{'loss': 4.0272, 'grad_norm': 10.51068115234375, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27006/33198 [16:06<02:40, 38.52it/s]

{'loss': 4.0024, 'grad_norm': 9.30744457244873, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27504/33198 [16:21<02:36, 36.34it/s]

{'loss': 4.0287, 'grad_norm': 6.8978681564331055, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28005/33198 [16:44<02:24, 35.87it/s]

{'loss': 4.0285, 'grad_norm': 10.378766059875488, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28506/33198 [17:00<02:13, 35.05it/s]

{'loss': 4.0211, 'grad_norm': 8.685905456542969, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29004/33198 [17:18<01:37, 42.90it/s]

{'loss': 4.0168, 'grad_norm': 8.445740699768066, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29506/33198 [17:36<01:31, 40.33it/s]

{'loss': 4.0204, 'grad_norm': 8.455123901367188, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30004/33198 [17:49<01:23, 38.38it/s]

{'loss': 4.0507, 'grad_norm': 8.335433959960938, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30503/33198 [18:01<01:13, 36.88it/s]

{'loss': 4.0236, 'grad_norm': 9.302851676940918, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31009/33198 [18:14<00:55, 39.50it/s]

{'loss': 4.0272, 'grad_norm': 8.624465942382812, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31508/33198 [18:27<00:42, 39.67it/s]

{'loss': 4.0259, 'grad_norm': 10.057394981384277, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32005/33198 [18:45<00:33, 35.26it/s]

{'loss': 4.019, 'grad_norm': 8.779182434082031, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32505/33198 [18:57<00:19, 36.35it/s]

{'loss': 4.0552, 'grad_norm': 8.26949691772461, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33004/33198 [19:11<00:05, 35.39it/s]

{'loss': 4.0627, 'grad_norm': 8.925567626953125, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [19:16<00:00, 28.70it/s]


{'train_runtime': 1156.7251, 'train_samples_per_second': 229.59, 'train_steps_per_second': 28.7, 'train_loss': 4.168669657819536, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:45<00:00, 29.75it/s]
100%|██████████| 10570/10570 [00:34<00:00, 308.25it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 504/33198 [00:13<11:44, 46.42it/s]

{'loss': 5.1795, 'grad_norm': 6.254992961883545, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1004/33198 [00:25<14:38, 36.64it/s]

{'loss': 4.7135, 'grad_norm': 7.177899360656738, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1505/33198 [00:37<14:25, 36.61it/s]

{'loss': 4.6144, 'grad_norm': 8.292195320129395, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2004/33198 [00:50<14:11, 36.62it/s]

{'loss': 4.5474, 'grad_norm': 6.929385662078857, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2507/33198 [01:04<13:46, 37.14it/s]

{'loss': 4.5037, 'grad_norm': 7.352331161499023, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3006/33198 [01:20<10:41, 47.04it/s]

{'loss': 4.466, 'grad_norm': 9.526460647583008, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3506/33198 [01:51<14:35, 33.91it/s]  

{'loss': 4.4621, 'grad_norm': 8.213481903076172, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4003/33198 [02:10<13:33, 35.87it/s]

{'loss': 4.4224, 'grad_norm': 9.231157302856445, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4504/33198 [02:21<11:55, 40.10it/s]

{'loss': 4.3917, 'grad_norm': 8.276362419128418, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5005/33198 [02:32<10:44, 43.71it/s]

{'loss': 4.3547, 'grad_norm': 10.070159912109375, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5506/33198 [02:43<08:59, 51.30it/s]

{'loss': 4.3613, 'grad_norm': 7.645312309265137, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6005/33198 [02:54<09:49, 46.09it/s]

{'loss': 4.3247, 'grad_norm': 7.939721584320068, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6506/33198 [03:04<09:51, 45.13it/s]

{'loss': 4.3054, 'grad_norm': 8.64254093170166, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7004/33198 [03:15<09:43, 44.90it/s]

{'loss': 4.2984, 'grad_norm': 8.187499046325684, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7508/33198 [03:26<09:15, 46.25it/s]

{'loss': 4.2859, 'grad_norm': 9.308192253112793, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8003/33198 [03:41<23:57, 17.52it/s]

{'loss': 4.2787, 'grad_norm': 8.707722663879395, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8504/33198 [03:56<11:51, 34.69it/s]

{'loss': 4.251, 'grad_norm': 10.750065803527832, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9003/33198 [04:15<11:45, 34.28it/s]

{'loss': 4.2396, 'grad_norm': 8.395997047424316, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9504/33198 [04:29<09:25, 41.89it/s]

{'loss': 4.2251, 'grad_norm': 7.9275898933410645, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10001/33198 [04:43<27:58, 13.82it/s]

{'loss': 4.2159, 'grad_norm': 10.159613609313965, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10505/33198 [05:14<16:42, 22.64it/s]

{'loss': 4.2215, 'grad_norm': 9.770827293395996, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11002/33198 [05:30<21:15, 17.41it/s]

{'loss': 4.2127, 'grad_norm': 9.952482223510742, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11506/33198 [05:45<09:39, 37.44it/s]

{'loss': 4.1079, 'grad_norm': 9.061357498168945, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12002/33198 [06:15<33:42, 10.48it/s]

{'loss': 4.0869, 'grad_norm': 12.939123153686523, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12505/33198 [06:38<12:41, 27.19it/s]

{'loss': 4.0652, 'grad_norm': 11.846282958984375, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13001/33198 [06:57<32:00, 10.52it/s]

{'loss': 4.0678, 'grad_norm': 9.681297302246094, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13504/33198 [07:29<14:12, 23.10it/s]

{'loss': 4.0871, 'grad_norm': 10.733711242675781, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14007/33198 [08:09<07:04, 45.22it/s]

{'loss': 4.0684, 'grad_norm': 8.03077507019043, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14505/33198 [08:26<09:40, 32.18it/s]

{'loss': 4.1097, 'grad_norm': 8.193658828735352, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15007/33198 [08:53<09:17, 32.61it/s]

{'loss': 4.0959, 'grad_norm': 10.741477966308594, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15504/33198 [09:05<07:36, 38.79it/s]

{'loss': 4.1009, 'grad_norm': 8.177634239196777, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16003/33198 [09:22<08:39, 33.11it/s]

{'loss': 4.1188, 'grad_norm': 9.04446029663086, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16504/33198 [09:38<06:14, 44.56it/s]

{'loss': 4.1027, 'grad_norm': 10.752859115600586, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17007/33198 [10:00<06:06, 44.14it/s]

{'loss': 4.0816, 'grad_norm': 10.358766555786133, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17500/33198 [10:22<24:22, 10.73it/s]

{'loss': 4.0922, 'grad_norm': 11.210701942443848, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18001/33198 [10:46<16:12, 15.63it/s]

{'loss': 4.1185, 'grad_norm': 8.756023406982422, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18506/33198 [11:05<07:30, 32.62it/s]

{'loss': 4.0904, 'grad_norm': 8.447178840637207, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19005/33198 [11:19<05:51, 40.41it/s]

{'loss': 4.0775, 'grad_norm': 8.526969909667969, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▉    | 19506/33198 [11:35<13:05, 17.43it/s]

{'loss': 4.0935, 'grad_norm': 7.853848457336426, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20006/33198 [11:47<04:58, 44.13it/s]

{'loss': 4.0572, 'grad_norm': 9.321426391601562, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20507/33198 [11:59<04:51, 43.52it/s]

{'loss': 4.0812, 'grad_norm': 9.615680694580078, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21009/33198 [12:13<04:30, 45.05it/s]

{'loss': 4.0917, 'grad_norm': 7.78338623046875, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21500/33198 [12:37<08:12, 23.77it/s]

{'loss': 4.0634, 'grad_norm': 10.1216459274292, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22000/33198 [13:16<18:02, 10.35it/s]

{'loss': 4.0707, 'grad_norm': 8.669027328491211, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22507/33198 [13:29<03:37, 49.09it/s]

{'loss': 4.039, 'grad_norm': 8.595809936523438, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23006/33198 [13:41<03:53, 43.67it/s]

{'loss': 4.0318, 'grad_norm': 9.351387023925781, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23507/33198 [13:52<03:56, 40.95it/s]

{'loss': 4.0124, 'grad_norm': 9.297517776489258, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24001/33198 [14:03<10:07, 15.15it/s]

{'loss': 4.0081, 'grad_norm': 11.126371383666992, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24500/33198 [14:32<02:47, 51.96it/s]

{'loss': 4.0001, 'grad_norm': 9.153526306152344, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25008/33198 [14:44<02:46, 49.08it/s]

{'loss': 4.017, 'grad_norm': 7.226767063140869, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25508/33198 [14:55<02:28, 51.93it/s]

{'loss': 4.0014, 'grad_norm': 8.286852836608887, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26007/33198 [15:08<02:32, 47.27it/s]

{'loss': 4.0038, 'grad_norm': 9.406195640563965, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26507/33198 [15:20<02:33, 43.57it/s]

{'loss': 4.0171, 'grad_norm': 12.307840347290039, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27003/33198 [15:38<05:22, 19.21it/s]

{'loss': 4.0037, 'grad_norm': 9.362841606140137, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27502/33198 [15:57<02:34, 36.77it/s]

{'loss': 4.0273, 'grad_norm': 9.017616271972656, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28001/33198 [16:15<03:07, 27.68it/s]

{'loss': 4.0257, 'grad_norm': 11.74907112121582, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28501/33198 [16:31<02:00, 38.91it/s]

{'loss': 4.0142, 'grad_norm': 10.005936622619629, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29004/33198 [16:47<02:48, 24.85it/s]

{'loss': 4.0133, 'grad_norm': 9.037223815917969, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29505/33198 [17:02<02:19, 26.38it/s]

{'loss': 4.0137, 'grad_norm': 9.905045509338379, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30005/33198 [17:20<02:38, 20.16it/s]

{'loss': 4.046, 'grad_norm': 8.964784622192383, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30502/33198 [17:46<03:16, 13.75it/s]

{'loss': 4.0216, 'grad_norm': 9.485424041748047, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31002/33198 [18:13<03:21, 10.88it/s]

{'loss': 4.0261, 'grad_norm': 9.373808860778809, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31500/33198 [18:40<01:51, 15.18it/s]

{'loss': 4.023, 'grad_norm': 10.402161598205566, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32006/33198 [19:05<00:58, 20.25it/s]

{'loss': 4.0093, 'grad_norm': 8.97222900390625, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32500/33198 [19:30<00:14, 47.56it/s]

{'loss': 4.0498, 'grad_norm': 7.984977722167969, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33007/33198 [19:56<00:06, 27.78it/s]

{'loss': 4.0516, 'grad_norm': 8.861978530883789, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [20:09<00:00, 27.45it/s]


{'train_runtime': 1209.3939, 'train_samples_per_second': 219.591, 'train_steps_per_second': 27.45, 'train_loss': 4.169787449839316, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:19<00:00, 70.33it/s] 
100%|██████████| 10570/10570 [00:51<00:00, 206.86it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 500/33198 [00:14<25:38, 21.25it/s]

{'loss': 5.1765, 'grad_norm': 7.014313220977783, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1001/33198 [00:37<1:06:44,  8.04it/s]

{'loss': 4.7133, 'grad_norm': 7.358351230621338, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1501/33198 [01:03<16:34, 31.86it/s]  

{'loss': 4.6248, 'grad_norm': 8.83659553527832, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2003/33198 [01:27<13:01, 39.90it/s]  

{'loss': 4.557, 'grad_norm': 7.243553638458252, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2503/33198 [01:47<24:57, 20.50it/s]

{'loss': 4.5173, 'grad_norm': 6.93038272857666, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3005/33198 [02:02<12:55, 38.92it/s]

{'loss': 4.4636, 'grad_norm': 8.193278312683105, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3503/33198 [02:16<12:41, 39.02it/s]

{'loss': 4.4697, 'grad_norm': 7.019516468048096, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4006/33198 [02:29<10:13, 47.59it/s]

{'loss': 4.4249, 'grad_norm': 8.234868049621582, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4508/33198 [02:42<11:14, 42.56it/s]

{'loss': 4.3959, 'grad_norm': 7.741369724273682, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5006/33198 [02:54<11:31, 40.77it/s]

{'loss': 4.3527, 'grad_norm': 10.754257202148438, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5503/33198 [03:06<12:45, 36.19it/s]

{'loss': 4.3559, 'grad_norm': 7.7927165031433105, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6007/33198 [03:21<11:18, 40.07it/s]

{'loss': 4.3264, 'grad_norm': 7.5743513107299805, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6507/33198 [03:37<10:54, 40.79it/s]

{'loss': 4.3135, 'grad_norm': 7.964633941650391, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7004/33198 [03:50<10:59, 39.71it/s]

{'loss': 4.2964, 'grad_norm': 7.863864898681641, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7504/33198 [04:04<11:07, 38.48it/s]

{'loss': 4.2924, 'grad_norm': 9.626738548278809, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8003/33198 [04:17<10:15, 40.95it/s]

{'loss': 4.2732, 'grad_norm': 8.714133262634277, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8500/33198 [04:41<24:19, 16.92it/s]

{'loss': 4.2554, 'grad_norm': 9.504000663757324, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9006/33198 [05:10<16:46, 24.03it/s]

{'loss': 4.2422, 'grad_norm': 8.626008987426758, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9506/33198 [05:31<11:58, 32.99it/s]

{'loss': 4.2306, 'grad_norm': 8.230433464050293, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10001/33198 [05:47<46:08,  8.38it/s]

{'loss': 4.2224, 'grad_norm': 9.006218910217285, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10505/33198 [06:15<09:53, 38.20it/s]

{'loss': 4.2277, 'grad_norm': 9.547746658325195, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11002/33198 [06:27<09:03, 40.86it/s]

{'loss': 4.2122, 'grad_norm': 8.921868324279785, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11505/33198 [06:39<09:48, 36.83it/s]

{'loss': 4.1157, 'grad_norm': 8.629565238952637, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12006/33198 [06:52<09:38, 36.66it/s]

{'loss': 4.0902, 'grad_norm': 13.255416870117188, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12506/33198 [07:10<08:53, 38.77it/s]

{'loss': 4.074, 'grad_norm': 10.108229637145996, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13007/33198 [07:33<13:35, 24.77it/s]

{'loss': 4.0785, 'grad_norm': 9.21137523651123, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13506/33198 [07:50<07:31, 43.61it/s]

{'loss': 4.0921, 'grad_norm': 9.819316864013672, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14000/33198 [08:07<18:34, 17.23it/s]

{'loss': 4.0735, 'grad_norm': 7.978635787963867, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14507/33198 [08:30<08:57, 34.75it/s]

{'loss': 4.1071, 'grad_norm': 8.42748737335205, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15003/33198 [08:44<09:15, 32.77it/s]

{'loss': 4.0994, 'grad_norm': 10.063434600830078, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15500/33198 [08:58<17:31, 16.83it/s]

{'loss': 4.1024, 'grad_norm': 8.467025756835938, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16005/33198 [09:19<07:56, 36.10it/s]

{'loss': 4.12, 'grad_norm': 8.668538093566895, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16505/33198 [09:33<08:47, 31.67it/s]

{'loss': 4.1071, 'grad_norm': 10.079604148864746, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17001/33198 [09:47<29:10,  9.25it/s]

{'loss': 4.0807, 'grad_norm': 9.223477363586426, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17505/33198 [10:26<06:03, 43.23it/s]

{'loss': 4.0979, 'grad_norm': 10.609026908874512, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18005/33198 [10:38<06:20, 39.96it/s]

{'loss': 4.1178, 'grad_norm': 9.912605285644531, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18502/33198 [10:52<07:35, 32.28it/s]

{'loss': 4.0893, 'grad_norm': 7.919819355010986, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19003/33198 [11:09<07:29, 31.61it/s]

{'loss': 4.0763, 'grad_norm': 8.49129867553711, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19500/33198 [11:29<11:11, 20.41it/s]

{'loss': 4.1013, 'grad_norm': 7.389713764190674, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20006/33198 [11:48<09:22, 23.46it/s]

{'loss': 4.0594, 'grad_norm': 8.756551742553711, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20504/33198 [12:07<04:44, 44.55it/s]

{'loss': 4.0845, 'grad_norm': 8.886938095092773, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21008/33198 [12:20<04:27, 45.60it/s]

{'loss': 4.0989, 'grad_norm': 7.256392002105713, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21508/33198 [12:33<05:11, 37.51it/s]

{'loss': 4.0688, 'grad_norm': 9.113428115844727, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22005/33198 [12:46<05:06, 36.47it/s]

{'loss': 4.0733, 'grad_norm': 8.544156074523926, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22504/33198 [13:02<04:17, 41.53it/s]

{'loss': 4.0471, 'grad_norm': 8.657062530517578, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23003/33198 [13:21<03:51, 44.02it/s]

{'loss': 4.0384, 'grad_norm': 9.719781875610352, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23505/33198 [13:34<03:55, 41.22it/s]

{'loss': 4.0186, 'grad_norm': 8.159346580505371, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24001/33198 [13:55<20:19,  7.54it/s]

{'loss': 4.0107, 'grad_norm': 12.816184997558594, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24507/33198 [14:16<03:32, 40.92it/s]

{'loss': 4.0014, 'grad_norm': 10.092711448669434, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25004/33198 [14:29<03:05, 44.19it/s]

{'loss': 4.0267, 'grad_norm': 7.738800048828125, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25500/33198 [14:42<06:29, 19.76it/s]

{'loss': 4.0048, 'grad_norm': 7.486096382141113, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26007/33198 [14:56<03:15, 36.80it/s]

{'loss': 4.0061, 'grad_norm': 9.327330589294434, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26506/33198 [15:08<02:22, 46.90it/s]

{'loss': 4.027, 'grad_norm': 12.0715913772583, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27006/33198 [15:21<02:38, 39.04it/s]

{'loss': 4.0096, 'grad_norm': 9.3324613571167, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27507/33198 [15:34<02:26, 38.90it/s]

{'loss': 4.0314, 'grad_norm': 8.248950004577637, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28002/33198 [15:51<07:49, 11.06it/s]

{'loss': 4.032, 'grad_norm': 10.849385261535645, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28507/33198 [16:10<01:36, 48.70it/s]

{'loss': 4.0179, 'grad_norm': 9.548460006713867, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29003/33198 [16:23<01:54, 36.79it/s]

{'loss': 4.0206, 'grad_norm': 9.290727615356445, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29504/33198 [16:35<01:21, 45.36it/s]

{'loss': 4.0216, 'grad_norm': 8.789266586303711, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30005/33198 [16:48<01:15, 42.50it/s]

{'loss': 4.054, 'grad_norm': 8.914877891540527, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30506/33198 [17:02<01:39, 27.19it/s]

{'loss': 4.019, 'grad_norm': 9.908989906311035, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31006/33198 [17:15<00:49, 44.25it/s]

{'loss': 4.0362, 'grad_norm': 9.137788772583008, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31507/33198 [17:29<00:34, 49.40it/s]

{'loss': 4.0272, 'grad_norm': 9.55882453918457, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32007/33198 [17:40<00:27, 44.04it/s]

{'loss': 4.0268, 'grad_norm': 8.893476486206055, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32506/33198 [17:53<00:16, 42.38it/s]

{'loss': 4.0534, 'grad_norm': 8.28627872467041, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33006/33198 [18:12<00:05, 35.47it/s]

{'loss': 4.0569, 'grad_norm': 9.711369514465332, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [18:17<00:00, 30.24it/s]


{'train_runtime': 1097.7646, 'train_samples_per_second': 241.921, 'train_steps_per_second': 30.241, 'train_loss': 4.174039788185541, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:23<00:00, 57.24it/s]
100%|██████████| 10570/10570 [00:24<00:00, 425.52it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 507/33198 [00:13<12:51, 42.36it/s]

{'loss': 5.1973, 'grad_norm': 5.812710762023926, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1009/33198 [00:26<13:24, 40.00it/s]

{'loss': 4.7033, 'grad_norm': 8.673225402832031, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1506/33198 [00:39<13:21, 39.52it/s]

{'loss': 4.6148, 'grad_norm': 8.949975967407227, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2007/33198 [00:52<13:24, 38.76it/s]

{'loss': 4.5556, 'grad_norm': 6.674206733703613, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2505/33198 [01:20<12:48, 39.93it/s]  

{'loss': 4.5107, 'grad_norm': 7.264606952667236, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3008/33198 [01:32<11:15, 44.70it/s]

{'loss': 4.4665, 'grad_norm': 8.772937774658203, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3505/33198 [01:45<16:24, 30.16it/s]

{'loss': 4.4636, 'grad_norm': 7.439962863922119, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4007/33198 [02:04<11:12, 43.41it/s]  

{'loss': 4.4164, 'grad_norm': 8.324767112731934, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4503/33198 [02:20<13:42, 34.89it/s]

{'loss': 4.3921, 'grad_norm': 7.875037670135498, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5005/33198 [02:35<10:51, 43.29it/s]

{'loss': 4.346, 'grad_norm': 10.167865753173828, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5504/33198 [02:47<10:58, 42.06it/s]

{'loss': 4.3558, 'grad_norm': 7.721701145172119, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6003/33198 [03:02<10:39, 42.56it/s]

{'loss': 4.3256, 'grad_norm': 7.515357971191406, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6505/33198 [03:19<13:54, 32.00it/s]

{'loss': 4.3132, 'grad_norm': 8.361444473266602, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7005/33198 [03:35<13:28, 32.39it/s]

{'loss': 4.2997, 'grad_norm': 7.924822807312012, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7504/33198 [03:49<10:24, 41.16it/s]

{'loss': 4.2847, 'grad_norm': 8.876337051391602, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8010/33198 [04:01<09:19, 45.03it/s]

{'loss': 4.2726, 'grad_norm': 8.007807731628418, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8506/33198 [04:13<09:38, 42.67it/s]

{'loss': 4.2554, 'grad_norm': 9.899553298950195, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9007/33198 [04:25<09:04, 44.45it/s]

{'loss': 4.2484, 'grad_norm': 8.519908905029297, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9506/33198 [04:39<08:17, 47.61it/s]

{'loss': 4.2282, 'grad_norm': 7.797108173370361, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10007/33198 [04:51<08:36, 44.91it/s]

{'loss': 4.2203, 'grad_norm': 8.756908416748047, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10506/33198 [05:03<09:10, 41.18it/s]

{'loss': 4.2249, 'grad_norm': 9.240748405456543, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11007/33198 [05:19<09:03, 40.80it/s]

{'loss': 4.214, 'grad_norm': 8.871993064880371, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11507/33198 [05:31<08:26, 42.80it/s]

{'loss': 4.1125, 'grad_norm': 8.41720962524414, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12006/33198 [05:43<08:26, 41.82it/s]

{'loss': 4.0922, 'grad_norm': 12.446883201599121, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12505/33198 [05:56<07:53, 43.67it/s]

{'loss': 4.0741, 'grad_norm': 10.822868347167969, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13007/33198 [06:08<07:58, 42.16it/s]

{'loss': 4.0771, 'grad_norm': 9.328271865844727, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13507/33198 [06:20<07:41, 42.65it/s]

{'loss': 4.0844, 'grad_norm': 9.745953559875488, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14006/33198 [06:32<08:36, 37.13it/s]

{'loss': 4.0773, 'grad_norm': 7.878422737121582, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14504/33198 [06:44<07:09, 43.48it/s]

{'loss': 4.1026, 'grad_norm': 7.899666786193848, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15006/33198 [06:58<09:29, 31.94it/s]

{'loss': 4.0945, 'grad_norm': 9.905467987060547, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15506/33198 [07:12<07:46, 37.94it/s]

{'loss': 4.1049, 'grad_norm': 8.29261302947998, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16005/33198 [07:26<06:40, 42.97it/s]

{'loss': 4.1187, 'grad_norm': 8.9486665725708, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16506/33198 [07:39<07:32, 36.87it/s]

{'loss': 4.109, 'grad_norm': 10.757320404052734, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17005/33198 [07:52<07:44, 34.85it/s]

{'loss': 4.0872, 'grad_norm': 8.89176082611084, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17504/33198 [08:05<06:02, 43.26it/s]

{'loss': 4.0978, 'grad_norm': 10.769739151000977, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18006/33198 [08:19<08:06, 31.21it/s]

{'loss': 4.1257, 'grad_norm': 9.38382339477539, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18504/33198 [08:33<05:55, 41.29it/s]

{'loss': 4.0947, 'grad_norm': 8.142663955688477, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19004/33198 [08:46<07:28, 31.61it/s]

{'loss': 4.077, 'grad_norm': 8.833924293518066, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▉    | 19507/33198 [08:59<05:17, 43.09it/s]

{'loss': 4.1017, 'grad_norm': 7.303516864776611, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20008/33198 [09:12<05:12, 42.16it/s]

{'loss': 4.0606, 'grad_norm': 9.228391647338867, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20504/33198 [09:29<05:26, 38.94it/s]

{'loss': 4.086, 'grad_norm': 9.113387107849121, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21007/33198 [09:42<04:44, 42.80it/s]

{'loss': 4.0962, 'grad_norm': 7.469878196716309, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21510/33198 [09:55<04:50, 40.29it/s]

{'loss': 4.073, 'grad_norm': 9.034071922302246, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22003/33198 [10:08<05:25, 34.44it/s]

{'loss': 4.0744, 'grad_norm': 7.641732692718506, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22506/33198 [10:22<04:43, 37.69it/s]

{'loss': 4.0448, 'grad_norm': 8.371180534362793, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23007/33198 [10:35<03:41, 46.06it/s]

{'loss': 4.0403, 'grad_norm': 10.019253730773926, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23505/33198 [10:48<03:23, 47.58it/s]

{'loss': 4.0158, 'grad_norm': 8.560962677001953, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24003/33198 [11:02<04:22, 34.97it/s]

{'loss': 4.0154, 'grad_norm': 11.331335067749023, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24503/33198 [11:14<04:10, 34.74it/s]

{'loss': 4.0068, 'grad_norm': 8.806781768798828, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25006/33198 [11:29<04:09, 32.88it/s]

{'loss': 4.0263, 'grad_norm': 7.061975479125977, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25503/33198 [11:43<04:21, 29.44it/s]

{'loss': 4.0031, 'grad_norm': 7.5069732666015625, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26006/33198 [11:59<03:01, 39.58it/s]

{'loss': 4.0083, 'grad_norm': 8.841962814331055, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26504/33198 [12:13<03:32, 31.50it/s]

{'loss': 4.0252, 'grad_norm': 10.563920974731445, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27008/33198 [12:27<02:46, 37.18it/s]

{'loss': 4.0036, 'grad_norm': 9.242264747619629, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27502/33198 [12:40<02:46, 34.19it/s]

{'loss': 4.0275, 'grad_norm': 8.527857780456543, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28002/33198 [12:54<02:03, 42.02it/s]

{'loss': 4.0334, 'grad_norm': 10.713822364807129, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28504/33198 [13:07<02:24, 32.39it/s]

{'loss': 4.0246, 'grad_norm': 10.206165313720703, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29003/33198 [13:20<02:27, 28.47it/s]

{'loss': 4.0172, 'grad_norm': 9.00136661529541, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29509/33198 [13:34<01:32, 39.93it/s]

{'loss': 4.0223, 'grad_norm': 9.250991821289062, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30007/33198 [13:48<01:10, 45.14it/s]

{'loss': 4.0539, 'grad_norm': 8.755867004394531, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30507/33198 [14:09<01:09, 38.53it/s]

{'loss': 4.0235, 'grad_norm': 9.770101547241211, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31005/33198 [14:21<00:53, 41.12it/s]

{'loss': 4.0337, 'grad_norm': 8.105023384094238, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31504/33198 [14:36<00:41, 41.16it/s]

{'loss': 4.0292, 'grad_norm': 9.075948715209961, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32000/33198 [14:59<00:54, 21.91it/s]

{'loss': 4.0204, 'grad_norm': 8.669865608215332, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32501/33198 [15:35<01:06, 10.56it/s]

{'loss': 4.056, 'grad_norm': 7.954188823699951, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33007/33198 [16:13<00:04, 39.17it/s]

{'loss': 4.0537, 'grad_norm': 8.72110366821289, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [16:19<00:00, 33.90it/s]


{'train_runtime': 979.2677, 'train_samples_per_second': 271.194, 'train_steps_per_second': 33.901, 'train_loss': 4.173695216617554, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:16<00:00, 79.49it/s]
100%|██████████| 10570/10570 [00:24<00:00, 431.75it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 504/33198 [00:13<12:09, 44.82it/s]

{'loss': 5.217, 'grad_norm': 6.099491596221924, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1003/33198 [00:31<33:09, 16.18it/s] 

{'loss': 4.7044, 'grad_norm': 8.236241340637207, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1507/33198 [01:00<13:28, 39.20it/s]  

{'loss': 4.6055, 'grad_norm': 7.926356792449951, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2006/33198 [01:31<12:58, 40.05it/s]  

{'loss': 4.5468, 'grad_norm': 6.85430383682251, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2507/33198 [01:56<21:47, 23.47it/s]  

{'loss': 4.5053, 'grad_norm': 7.698514461517334, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3006/33198 [02:13<14:12, 35.43it/s]

{'loss': 4.4597, 'grad_norm': 8.498244285583496, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3503/33198 [02:26<11:41, 42.33it/s]

{'loss': 4.4602, 'grad_norm': 7.34637975692749, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4003/33198 [02:45<15:00, 32.43it/s]

{'loss': 4.4183, 'grad_norm': 8.288337707519531, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4504/33198 [02:59<12:04, 39.59it/s]

{'loss': 4.3896, 'grad_norm': 7.944289207458496, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5007/33198 [03:15<14:52, 31.60it/s]

{'loss': 4.3524, 'grad_norm': 10.08199405670166, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5507/33198 [03:29<13:57, 33.07it/s]

{'loss': 4.3541, 'grad_norm': 7.995953559875488, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6008/33198 [03:42<09:57, 45.51it/s]

{'loss': 4.3199, 'grad_norm': 7.322112083435059, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6509/33198 [03:59<09:26, 47.13it/s]

{'loss': 4.3068, 'grad_norm': 8.150160789489746, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7007/33198 [04:16<10:19, 42.27it/s]

{'loss': 4.2989, 'grad_norm': 7.836596488952637, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7505/33198 [04:27<09:21, 45.79it/s]

{'loss': 4.285, 'grad_norm': 8.538581848144531, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8004/33198 [04:38<09:08, 45.89it/s]

{'loss': 4.269, 'grad_norm': 8.964563369750977, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8503/33198 [04:57<13:22, 30.77it/s]

{'loss': 4.2533, 'grad_norm': 9.361740112304688, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9003/33198 [05:20<12:11, 33.08it/s]

{'loss': 4.2368, 'grad_norm': 8.852286338806152, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9501/33198 [05:49<29:03, 13.59it/s]

{'loss': 4.23, 'grad_norm': 8.148773193359375, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10007/33198 [06:27<11:16, 34.31it/s]

{'loss': 4.2262, 'grad_norm': 9.081343650817871, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10505/33198 [06:51<08:14, 45.90it/s]

{'loss': 4.2228, 'grad_norm': 9.50588321685791, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11004/33198 [07:17<16:35, 22.31it/s]

{'loss': 4.2143, 'grad_norm': 8.903220176696777, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11500/33198 [07:53<34:18, 10.54it/s]

{'loss': 4.1139, 'grad_norm': 8.513594627380371, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12001/33198 [08:33<38:13,  9.24it/s]

{'loss': 4.0895, 'grad_norm': 12.217830657958984, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12500/33198 [09:09<13:34, 25.42it/s]

{'loss': 4.0802, 'grad_norm': 11.103292465209961, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13000/33198 [09:36<06:58, 48.21it/s]

{'loss': 4.075, 'grad_norm': 9.206415176391602, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13500/33198 [10:08<13:10, 24.91it/s]

{'loss': 4.0914, 'grad_norm': 9.818493843078613, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14004/33198 [10:24<19:43, 16.21it/s]

{'loss': 4.0767, 'grad_norm': 7.897668361663818, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14508/33198 [10:43<06:59, 44.55it/s]

{'loss': 4.0997, 'grad_norm': 8.186816215515137, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15004/33198 [11:09<07:08, 42.43it/s]

{'loss': 4.0946, 'grad_norm': 10.038397789001465, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15502/33198 [11:26<23:03, 12.79it/s]

{'loss': 4.0979, 'grad_norm': 8.496382713317871, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16001/33198 [11:55<30:06,  9.52it/s]

{'loss': 4.1162, 'grad_norm': 8.933939933776855, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16504/33198 [12:20<05:44, 48.46it/s]

{'loss': 4.1091, 'grad_norm': 9.959688186645508, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17007/33198 [12:42<05:50, 46.23it/s]

{'loss': 4.0813, 'grad_norm': 9.275493621826172, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17505/33198 [13:09<08:51, 29.51it/s]

{'loss': 4.0974, 'grad_norm': 9.73459243774414, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18000/33198 [13:36<23:03, 10.99it/s]

{'loss': 4.1155, 'grad_norm': 9.31815242767334, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18507/33198 [14:01<07:40, 31.87it/s]

{'loss': 4.0957, 'grad_norm': 7.844474792480469, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19001/33198 [14:28<21:19, 11.10it/s]

{'loss': 4.0721, 'grad_norm': 9.098621368408203, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19501/33198 [14:56<24:21,  9.37it/s]

{'loss': 4.096, 'grad_norm': 7.1940741539001465, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20000/33198 [15:13<17:21, 12.68it/s]

{'loss': 4.0613, 'grad_norm': 9.117616653442383, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20509/33198 [15:34<04:31, 46.79it/s]

{'loss': 4.0807, 'grad_norm': 9.63761043548584, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21006/33198 [15:45<06:06, 33.25it/s]

{'loss': 4.097, 'grad_norm': 7.165627956390381, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21506/33198 [15:59<04:02, 48.15it/s]

{'loss': 4.0707, 'grad_norm': 8.89576244354248, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22006/33198 [16:11<04:51, 38.39it/s]

{'loss': 4.0743, 'grad_norm': 8.638818740844727, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22503/33198 [16:23<05:02, 35.41it/s]

{'loss': 4.0474, 'grad_norm': 8.223994255065918, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23006/33198 [16:34<03:58, 42.67it/s]

{'loss': 4.0346, 'grad_norm': 9.08349323272705, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23508/33198 [16:46<03:34, 45.08it/s]

{'loss': 4.0197, 'grad_norm': 9.126009941101074, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24004/33198 [16:57<03:27, 44.33it/s]

{'loss': 4.0078, 'grad_norm': 10.384025573730469, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24503/33198 [17:09<03:32, 41.01it/s]

{'loss': 4.0027, 'grad_norm': 9.951871871948242, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25007/33198 [17:23<05:09, 26.44it/s]

{'loss': 4.0279, 'grad_norm': 7.0226969718933105, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25506/33198 [17:41<03:02, 42.22it/s]

{'loss': 4.0037, 'grad_norm': 7.506753444671631, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26003/33198 [17:54<03:00, 39.81it/s]

{'loss': 4.0142, 'grad_norm': 9.113380432128906, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26507/33198 [18:08<02:29, 44.74it/s]

{'loss': 4.0263, 'grad_norm': 11.76244068145752, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27006/33198 [18:21<02:12, 46.74it/s]

{'loss': 4.0077, 'grad_norm': 9.543758392333984, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27507/33198 [18:34<02:02, 46.43it/s]

{'loss': 4.0283, 'grad_norm': 7.587759971618652, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28003/33198 [18:48<03:00, 28.70it/s]

{'loss': 4.0305, 'grad_norm': 9.76353645324707, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28506/33198 [19:11<01:40, 46.59it/s]

{'loss': 4.0147, 'grad_norm': 9.267171859741211, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29004/33198 [19:24<01:45, 39.85it/s]

{'loss': 4.0196, 'grad_norm': 8.76077651977539, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29504/33198 [19:40<01:57, 31.40it/s]

{'loss': 4.0207, 'grad_norm': 8.497650146484375, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30005/33198 [19:54<01:18, 40.87it/s]

{'loss': 4.0506, 'grad_norm': 9.258277893066406, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30507/33198 [20:09<01:05, 41.22it/s]

{'loss': 4.0242, 'grad_norm': 9.40515422821045, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31005/33198 [20:25<00:55, 39.83it/s]

{'loss': 4.0368, 'grad_norm': 8.737187385559082, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31506/33198 [20:39<00:39, 43.19it/s]

{'loss': 4.0183, 'grad_norm': 9.782020568847656, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32003/33198 [20:50<00:29, 40.37it/s]

{'loss': 4.0202, 'grad_norm': 8.812300682067871, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32500/33198 [21:02<00:23, 30.00it/s]

{'loss': 4.0572, 'grad_norm': 8.545795440673828, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33003/33198 [21:20<00:04, 41.16it/s]

{'loss': 4.0596, 'grad_norm': 9.000713348388672, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [21:28<00:00, 25.76it/s]


{'train_runtime': 1288.7128, 'train_samples_per_second': 206.075, 'train_steps_per_second': 25.761, 'train_loss': 4.17244749093919, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:15<00:00, 86.10it/s]
100%|██████████| 10570/10570 [00:18<00:00, 581.85it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 504/33198 [00:13<13:41, 39.77it/s]

{'loss': 5.2706, 'grad_norm': 5.368374347686768, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1007/33198 [00:26<13:22, 40.10it/s]

{'loss': 4.7152, 'grad_norm': 7.780563831329346, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1505/33198 [00:38<12:14, 43.15it/s]

{'loss': 4.6189, 'grad_norm': 7.318145275115967, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2005/33198 [00:50<12:15, 42.41it/s]

{'loss': 4.5473, 'grad_norm': 6.445275783538818, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2504/33198 [01:02<12:12, 41.92it/s]

{'loss': 4.5031, 'grad_norm': 7.091943740844727, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3004/33198 [01:14<12:28, 40.34it/s]

{'loss': 4.4573, 'grad_norm': 8.431456565856934, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3503/33198 [01:26<12:21, 40.03it/s]

{'loss': 4.4596, 'grad_norm': 6.9903459548950195, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4007/33198 [01:38<11:29, 42.36it/s]

{'loss': 4.4174, 'grad_norm': 8.135255813598633, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4506/33198 [01:50<11:32, 41.42it/s]

{'loss': 4.3886, 'grad_norm': 7.750444412231445, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5005/33198 [02:02<12:03, 38.95it/s]

{'loss': 4.3487, 'grad_norm': 9.338872909545898, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5505/33198 [02:14<11:52, 38.84it/s]

{'loss': 4.3521, 'grad_norm': 7.788793563842773, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6007/33198 [02:26<11:55, 37.98it/s]

{'loss': 4.3275, 'grad_norm': 7.447761535644531, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6503/33198 [02:43<14:05, 31.58it/s]

{'loss': 4.3068, 'grad_norm': 7.901050090789795, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7004/33198 [02:59<11:38, 37.48it/s]

{'loss': 4.2941, 'grad_norm': 7.688968181610107, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7504/33198 [03:12<10:42, 39.98it/s]

{'loss': 4.2862, 'grad_norm': 8.246253967285156, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8005/33198 [03:25<10:29, 39.99it/s]

{'loss': 4.2711, 'grad_norm': 8.7944917678833, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8503/33198 [03:38<11:30, 35.77it/s]

{'loss': 4.2516, 'grad_norm': 9.871116638183594, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9001/33198 [03:52<12:04, 33.39it/s]

{'loss': 4.2419, 'grad_norm': 7.849687099456787, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9503/33198 [04:08<11:32, 34.19it/s]

{'loss': 4.2217, 'grad_norm': 7.857113361358643, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10005/33198 [04:23<11:43, 32.99it/s]

{'loss': 4.219, 'grad_norm': 8.973508834838867, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10505/33198 [04:39<10:16, 36.79it/s]

{'loss': 4.2118, 'grad_norm': 8.951766014099121, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11001/33198 [05:05<26:05, 14.18it/s]

{'loss': 4.2119, 'grad_norm': 8.288154602050781, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11504/33198 [05:23<09:44, 37.13it/s]

{'loss': 4.107, 'grad_norm': 8.58151912689209, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12000/33198 [05:43<15:39, 22.55it/s]

{'loss': 4.0869, 'grad_norm': 11.697239875793457, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12506/33198 [05:57<08:24, 41.04it/s]

{'loss': 4.0693, 'grad_norm': 10.06063461303711, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13005/33198 [06:09<08:13, 40.89it/s]

{'loss': 4.0668, 'grad_norm': 8.644493103027344, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13508/33198 [06:31<07:57, 41.21it/s]

{'loss': 4.0869, 'grad_norm': 10.02542781829834, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14000/33198 [06:50<32:17,  9.91it/s]

{'loss': 4.0744, 'grad_norm': 7.551969051361084, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14504/33198 [07:15<15:50, 19.67it/s]

{'loss': 4.1031, 'grad_norm': 7.728491306304932, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15004/33198 [07:28<07:16, 41.67it/s]

{'loss': 4.0945, 'grad_norm': 9.380645751953125, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15502/33198 [07:43<08:49, 33.41it/s]

{'loss': 4.0915, 'grad_norm': 8.888790130615234, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16002/33198 [08:00<07:59, 35.83it/s]

{'loss': 4.1139, 'grad_norm': 8.526517868041992, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16507/33198 [08:18<06:55, 40.17it/s]

{'loss': 4.1073, 'grad_norm': 10.719942092895508, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17000/33198 [08:32<12:26, 21.71it/s]

{'loss': 4.0774, 'grad_norm': 9.644766807556152, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17503/33198 [08:48<06:37, 39.47it/s]

{'loss': 4.0986, 'grad_norm': 9.716432571411133, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18005/33198 [09:00<06:13, 40.68it/s]

{'loss': 4.1221, 'grad_norm': 8.815722465515137, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18505/33198 [09:12<05:49, 42.10it/s]

{'loss': 4.0884, 'grad_norm': 8.525140762329102, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19004/33198 [09:25<08:22, 28.25it/s]

{'loss': 4.0697, 'grad_norm': 8.747166633605957, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▉    | 19506/33198 [09:39<05:37, 40.61it/s]

{'loss': 4.0939, 'grad_norm': 7.294510364532471, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20006/33198 [09:51<05:26, 40.38it/s]

{'loss': 4.055, 'grad_norm': 8.919330596923828, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20503/33198 [10:04<05:30, 38.44it/s]

{'loss': 4.0865, 'grad_norm': 8.50809383392334, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21007/33198 [10:18<04:46, 42.49it/s]

{'loss': 4.0987, 'grad_norm': 7.384470462799072, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21504/33198 [10:32<05:04, 38.42it/s]

{'loss': 4.0669, 'grad_norm': 8.966097831726074, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22000/33198 [10:48<08:11, 22.77it/s]

{'loss': 4.0741, 'grad_norm': 8.624255180358887, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22506/33198 [11:05<04:38, 38.44it/s]

{'loss': 4.0472, 'grad_norm': 7.410487651824951, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23007/33198 [11:19<04:10, 40.71it/s]

{'loss': 4.0382, 'grad_norm': 9.602871894836426, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23506/33198 [11:35<04:01, 40.17it/s]

{'loss': 4.0204, 'grad_norm': 9.407121658325195, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24005/33198 [11:49<04:36, 33.30it/s]

{'loss': 4.0177, 'grad_norm': 10.579171180725098, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24505/33198 [12:05<03:45, 38.51it/s]

{'loss': 4.0048, 'grad_norm': 8.650748252868652, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25005/33198 [12:20<03:15, 41.96it/s]

{'loss': 4.0295, 'grad_norm': 6.545716285705566, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25504/33198 [12:33<03:12, 39.88it/s]

{'loss': 4.0057, 'grad_norm': 7.320461273193359, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26005/33198 [12:46<03:01, 39.57it/s]

{'loss': 4.014, 'grad_norm': 9.713167190551758, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26507/33198 [12:58<02:47, 39.86it/s]

{'loss': 4.0231, 'grad_norm': 11.531954765319824, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27005/33198 [13:11<02:32, 40.53it/s]

{'loss': 4.0098, 'grad_norm': 8.332605361938477, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27507/33198 [13:23<02:28, 38.39it/s]

{'loss': 4.0315, 'grad_norm': 7.897073745727539, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28007/33198 [13:36<02:11, 39.42it/s]

{'loss': 4.0309, 'grad_norm': 9.898113250732422, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28506/33198 [13:49<02:00, 38.94it/s]

{'loss': 4.0146, 'grad_norm': 8.681078910827637, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29005/33198 [14:09<02:44, 25.44it/s]

{'loss': 4.0164, 'grad_norm': 8.960919380187988, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29503/33198 [14:26<01:40, 36.83it/s]

{'loss': 4.0204, 'grad_norm': 8.696645736694336, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30005/33198 [14:40<01:25, 37.55it/s]

{'loss': 4.0541, 'grad_norm': 9.216174125671387, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30500/33198 [14:54<01:14, 36.07it/s]

{'loss': 4.0232, 'grad_norm': 8.696391105651855, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31003/33198 [15:26<01:18, 28.04it/s]

{'loss': 4.0363, 'grad_norm': 7.56039571762085, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31506/33198 [16:02<01:29, 18.81it/s]

{'loss': 4.0251, 'grad_norm': 9.686541557312012, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32001/33198 [16:30<02:14,  8.89it/s]

{'loss': 4.0189, 'grad_norm': 8.666413307189941, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32505/33198 [16:42<00:16, 41.83it/s]

{'loss': 4.0578, 'grad_norm': 7.78354024887085, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33009/33198 [16:59<00:04, 39.55it/s]

{'loss': 4.0663, 'grad_norm': 8.967992782592773, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [17:04<00:00, 32.41it/s]


{'train_runtime': 1024.4163, 'train_samples_per_second': 259.242, 'train_steps_per_second': 32.407, 'train_loss': 4.172847074674938, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:17<00:00, 77.87it/s]
100%|██████████| 10570/10570 [00:18<00:00, 581.78it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 504/33198 [00:11<12:17, 44.34it/s]

{'loss': 5.1859, 'grad_norm': 5.726118564605713, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1003/33198 [00:22<17:09, 31.27it/s]

{'loss': 4.716, 'grad_norm': 7.0972723960876465, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1503/33198 [00:37<17:46, 29.72it/s]

{'loss': 4.6095, 'grad_norm': 13.187422752380371, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2005/33198 [00:53<17:20, 29.98it/s]

{'loss': 4.5441, 'grad_norm': 5.930774688720703, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2509/33198 [01:05<10:13, 50.00it/s]

{'loss': 4.5047, 'grad_norm': 7.053867816925049, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3010/33198 [01:15<09:42, 51.78it/s]

{'loss': 4.4582, 'grad_norm': 8.948625564575195, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3508/33198 [01:25<10:05, 49.05it/s]

{'loss': 4.4515, 'grad_norm': 8.110678672790527, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4009/33198 [01:35<09:28, 51.32it/s]

{'loss': 4.4071, 'grad_norm': 8.523616790771484, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4504/33198 [01:45<09:27, 50.59it/s]

{'loss': 4.3824, 'grad_norm': 8.074662208557129, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5009/33198 [01:55<09:11, 51.09it/s]

{'loss': 4.3374, 'grad_norm': 10.193090438842773, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5506/33198 [02:05<09:37, 47.98it/s]

{'loss': 4.3472, 'grad_norm': 7.834934711456299, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6005/33198 [02:15<09:05, 49.88it/s]

{'loss': 4.3163, 'grad_norm': 7.51124382019043, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6506/33198 [02:25<08:40, 51.29it/s]

{'loss': 4.2969, 'grad_norm': 8.349343299865723, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7009/33198 [02:36<09:20, 46.75it/s]

{'loss': 4.2842, 'grad_norm': 7.691927909851074, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7505/33198 [02:45<08:37, 49.67it/s]

{'loss': 4.277, 'grad_norm': 8.319233894348145, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8009/33198 [02:56<08:25, 49.87it/s]

{'loss': 4.2533, 'grad_norm': 8.288060188293457, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8504/33198 [03:06<08:32, 48.22it/s]

{'loss': 4.2476, 'grad_norm': 11.107903480529785, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9008/33198 [03:16<08:04, 49.93it/s]

{'loss': 4.2301, 'grad_norm': 8.558582305908203, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9508/33198 [03:26<08:00, 49.27it/s]

{'loss': 4.2169, 'grad_norm': 8.018875122070312, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10008/33198 [03:36<08:02, 48.06it/s]

{'loss': 4.2104, 'grad_norm': 11.144255638122559, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10504/33198 [03:47<07:58, 47.40it/s]

{'loss': 4.198, 'grad_norm': 9.397102355957031, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11007/33198 [03:57<07:22, 50.20it/s]

{'loss': 4.1834, 'grad_norm': 8.782341003417969, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11504/33198 [04:07<07:33, 47.81it/s]

{'loss': 4.0834, 'grad_norm': 8.25867748260498, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12006/33198 [04:17<07:25, 47.52it/s]

{'loss': 4.0649, 'grad_norm': 13.425212860107422, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12506/33198 [04:28<07:12, 47.79it/s]

{'loss': 4.0341, 'grad_norm': 10.939164161682129, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13005/33198 [04:38<07:13, 46.63it/s]

{'loss': 4.0163, 'grad_norm': 9.302475929260254, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13508/33198 [04:48<06:44, 48.71it/s]

{'loss': 4.024, 'grad_norm': 11.092337608337402, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14008/33198 [05:01<07:23, 43.28it/s]

{'loss': 4.0073, 'grad_norm': 8.460589408874512, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14509/33198 [05:12<06:16, 49.59it/s]

{'loss': 4.0409, 'grad_norm': 14.329657554626465, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15006/33198 [05:22<06:38, 45.67it/s]

{'loss': 4.0187, 'grad_norm': 11.513931274414062, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15507/33198 [05:32<05:52, 50.21it/s]

{'loss': 4.0178, 'grad_norm': 8.716485977172852, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16005/33198 [05:43<05:58, 47.93it/s]

{'loss': 4.04, 'grad_norm': 9.269510269165039, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16500/33198 [05:54<05:34, 49.97it/s]

{'loss': 4.0168, 'grad_norm': 16.871644973754883, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17008/33198 [06:15<05:26, 49.57it/s]

{'loss': 3.9901, 'grad_norm': 17.62178611755371, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17505/33198 [06:34<05:40, 46.15it/s]

{'loss': 3.9977, 'grad_norm': 13.905438423156738, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18009/33198 [06:44<05:33, 45.61it/s]

{'loss': 4.0188, 'grad_norm': 8.827104568481445, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18502/33198 [07:02<21:41, 11.29it/s]

{'loss': 3.9846, 'grad_norm': 7.857240676879883, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19006/33198 [07:20<04:38, 50.93it/s]

{'loss': 3.9648, 'grad_norm': 12.830521583557129, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19501/33198 [07:39<16:15, 14.05it/s]

{'loss': 3.9952, 'grad_norm': 9.387167930603027, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20010/33198 [08:00<04:26, 49.49it/s]

{'loss': 3.9504, 'grad_norm': 12.965628623962402, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20507/33198 [08:24<04:36, 45.87it/s]

{'loss': 3.9718, 'grad_norm': 9.302750587463379, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21010/33198 [08:42<04:20, 46.70it/s]

{'loss': 3.9776, 'grad_norm': 7.562581539154053, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21507/33198 [08:56<03:49, 50.92it/s]

{'loss': 3.945, 'grad_norm': 10.434149742126465, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22007/33198 [09:15<04:04, 45.81it/s]

{'loss': 3.9367, 'grad_norm': 11.146270751953125, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22503/33198 [09:34<08:28, 21.02it/s]

{'loss': 3.9304, 'grad_norm': 14.871341705322266, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23004/33198 [09:50<03:33, 47.69it/s]

{'loss': 3.9219, 'grad_norm': 11.19936752319336, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23507/33198 [10:00<03:09, 51.25it/s]

{'loss': 3.8876, 'grad_norm': 9.987467765808105, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24001/33198 [10:18<13:30, 11.35it/s]

{'loss': 3.8802, 'grad_norm': 11.827798843383789, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24501/33198 [10:40<06:39, 21.77it/s]

{'loss': 3.8592, 'grad_norm': 9.623414039611816, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25008/33198 [10:59<04:46, 28.55it/s]

{'loss': 3.8942, 'grad_norm': 7.6078948974609375, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25509/33198 [11:21<02:34, 49.70it/s]

{'loss': 3.8657, 'grad_norm': 11.444421768188477, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26005/33198 [11:48<04:13, 28.39it/s]

{'loss': 3.8805, 'grad_norm': 9.198184967041016, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26506/33198 [12:01<02:10, 51.29it/s]

{'loss': 3.8823, 'grad_norm': 16.474964141845703, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27002/33198 [12:20<09:20, 11.05it/s]

{'loss': 3.8737, 'grad_norm': 9.695333480834961, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27501/33198 [12:37<03:27, 27.47it/s]

{'loss': 3.9004, 'grad_norm': 10.363393783569336, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28000/33198 [12:57<04:20, 19.93it/s]

{'loss': 3.8983, 'grad_norm': 21.62123680114746, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28506/33198 [13:11<01:29, 52.20it/s]

{'loss': 3.8701, 'grad_norm': 10.028822898864746, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29002/33198 [13:26<05:57, 11.75it/s]

{'loss': 3.8719, 'grad_norm': 9.515891075134277, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29506/33198 [13:44<01:23, 44.25it/s]

{'loss': 3.8881, 'grad_norm': 11.395096778869629, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30008/33198 [13:59<01:03, 50.43it/s]

{'loss': 3.9195, 'grad_norm': 9.173688888549805, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30504/33198 [14:18<00:56, 47.97it/s]

{'loss': 3.8881, 'grad_norm': 9.900225639343262, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31009/33198 [14:35<00:46, 47.35it/s]

{'loss': 3.8884, 'grad_norm': 11.242599487304688, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31501/33198 [14:57<01:56, 14.62it/s]

{'loss': 3.8881, 'grad_norm': 11.116098403930664, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32003/33198 [15:16<00:34, 34.58it/s]

{'loss': 3.8801, 'grad_norm': 10.67619800567627, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32501/33198 [15:41<01:04, 10.86it/s]

{'loss': 3.9301, 'grad_norm': 11.348482131958008, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33006/33198 [15:56<00:03, 49.13it/s]

{'loss': 3.9141, 'grad_norm': 9.164691925048828, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [16:00<00:00, 34.57it/s]


{'train_runtime': 960.3681, 'train_samples_per_second': 276.531, 'train_steps_per_second': 34.568, 'train_loss': 4.09527921765689, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:16<00:00, 82.19it/s]
100%|██████████| 10570/10570 [00:24<00:00, 425.81it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 505/33198 [00:11<11:01, 49.46it/s]

{'loss': 5.1583, 'grad_norm': 6.544259548187256, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1001/33198 [00:31<42:45, 12.55it/s] 

{'loss': 4.7311, 'grad_norm': 6.5698676109313965, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1501/33198 [00:53<46:25, 11.38it/s]

{'loss': 4.6258, 'grad_norm': 12.783973693847656, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2003/33198 [01:16<27:56, 18.61it/s]  

{'loss': 4.5592, 'grad_norm': 7.086500644683838, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2503/33198 [01:35<17:14, 29.69it/s]  

{'loss': 4.513, 'grad_norm': 6.795592308044434, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3002/33198 [01:53<13:32, 37.15it/s]

{'loss': 4.4671, 'grad_norm': 8.10591983795166, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3504/33198 [02:11<15:19, 32.29it/s]

{'loss': 4.4526, 'grad_norm': 7.604242324829102, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4000/33198 [02:25<11:43, 41.51it/s]

{'loss': 4.4152, 'grad_norm': 8.294495582580566, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4500/33198 [02:41<34:37, 13.81it/s]

{'loss': 4.3853, 'grad_norm': 7.54940128326416, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5000/33198 [03:09<16:18, 28.83it/s]

{'loss': 4.346, 'grad_norm': 9.94802474975586, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5505/33198 [03:31<10:09, 45.46it/s]

{'loss': 4.3524, 'grad_norm': 8.504803657531738, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6004/33198 [03:57<12:54, 35.12it/s]

{'loss': 4.3201, 'grad_norm': 7.513419151306152, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6504/33198 [04:14<12:40, 35.10it/s]

{'loss': 4.3073, 'grad_norm': 8.126593589782715, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7006/33198 [04:27<09:29, 45.96it/s]

{'loss': 4.2905, 'grad_norm': 7.557707786560059, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7506/33198 [04:40<10:27, 40.93it/s]

{'loss': 4.2771, 'grad_norm': 8.074536323547363, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8005/33198 [04:53<09:47, 42.91it/s]

{'loss': 4.2675, 'grad_norm': 8.313448905944824, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8500/33198 [05:10<09:41, 42.47it/s]

{'loss': 4.2547, 'grad_norm': 9.84346866607666, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9000/33198 [05:42<21:04, 19.13it/s]

{'loss': 4.2411, 'grad_norm': 8.149506568908691, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9507/33198 [05:56<08:50, 44.64it/s]

{'loss': 4.224, 'grad_norm': 7.779999256134033, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10002/33198 [06:18<16:13, 23.84it/s]

{'loss': 4.2204, 'grad_norm': 9.395424842834473, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10509/33198 [06:30<08:53, 42.54it/s]

{'loss': 4.2105, 'grad_norm': 9.6415433883667, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11000/33198 [06:43<30:34, 12.10it/s]

{'loss': 4.2152, 'grad_norm': 8.485074996948242, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11502/33198 [07:06<07:42, 46.90it/s]

{'loss': 4.1024, 'grad_norm': 8.556230545043945, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12008/33198 [07:19<07:42, 45.81it/s]

{'loss': 4.0925, 'grad_norm': 11.130523681640625, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12504/33198 [07:40<14:27, 23.85it/s]

{'loss': 4.0696, 'grad_norm': 11.338006019592285, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13006/33198 [07:55<14:54, 22.58it/s]

{'loss': 4.073, 'grad_norm': 9.256583213806152, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13507/33198 [08:08<07:41, 42.67it/s]

{'loss': 4.0793, 'grad_norm': 10.196284294128418, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14000/33198 [08:23<13:21, 23.95it/s]

{'loss': 4.0682, 'grad_norm': 8.163761138916016, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14505/33198 [08:43<07:00, 44.47it/s]

{'loss': 4.1051, 'grad_norm': 8.161203384399414, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15003/33198 [08:56<08:16, 36.66it/s]

{'loss': 4.0975, 'grad_norm': 9.158161163330078, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15504/33198 [09:09<07:48, 37.75it/s]

{'loss': 4.0952, 'grad_norm': 9.164891242980957, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16000/33198 [09:24<09:45, 29.36it/s]

{'loss': 4.1067, 'grad_norm': 8.752315521240234, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16510/33198 [09:36<05:50, 47.57it/s]

{'loss': 4.1055, 'grad_norm': 10.383354187011719, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17007/33198 [10:00<07:19, 36.86it/s]

{'loss': 4.0796, 'grad_norm': 9.79237174987793, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17505/33198 [10:12<05:46, 45.31it/s]

{'loss': 4.0968, 'grad_norm': 10.247963905334473, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18002/33198 [10:42<11:48, 21.44it/s]

{'loss': 4.1179, 'grad_norm': 9.19451904296875, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18506/33198 [11:16<06:25, 38.12it/s]

{'loss': 4.0847, 'grad_norm': 8.705649375915527, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19002/33198 [11:37<24:56,  9.48it/s]

{'loss': 4.0713, 'grad_norm': 8.558287620544434, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19502/33198 [12:10<06:44, 33.82it/s]

{'loss': 4.0935, 'grad_norm': 8.019159317016602, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20008/33198 [12:23<04:45, 46.16it/s]

{'loss': 4.0552, 'grad_norm': 9.556673049926758, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20506/33198 [12:36<04:41, 45.14it/s]

{'loss': 4.0795, 'grad_norm': 9.856563568115234, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21003/33198 [12:56<08:12, 24.76it/s]

{'loss': 4.0941, 'grad_norm': 7.428546905517578, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21505/33198 [13:17<04:55, 39.51it/s]

{'loss': 4.0642, 'grad_norm': 9.581589698791504, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22007/33198 [13:37<04:27, 41.86it/s]

{'loss': 4.0728, 'grad_norm': 8.810335159301758, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22500/33198 [14:02<06:34, 27.09it/s]

{'loss': 4.0419, 'grad_norm': 9.428764343261719, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23008/33198 [14:14<03:40, 46.12it/s]

{'loss': 4.0405, 'grad_norm': 9.340229034423828, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23506/33198 [14:25<04:12, 38.41it/s]

{'loss': 4.0113, 'grad_norm': 9.975574493408203, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24009/33198 [14:37<03:08, 48.83it/s]

{'loss': 4.0106, 'grad_norm': 10.5370512008667, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24503/33198 [14:49<03:54, 37.09it/s]

{'loss': 3.9971, 'grad_norm': 8.968664169311523, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25003/33198 [15:03<04:14, 32.15it/s]

{'loss': 4.02, 'grad_norm': 7.605747699737549, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25503/33198 [15:17<02:59, 42.90it/s]

{'loss': 4.007, 'grad_norm': 7.678272724151611, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26004/33198 [15:31<02:56, 40.78it/s]

{'loss': 4.0107, 'grad_norm': 9.010610580444336, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26504/33198 [15:45<03:35, 31.01it/s]

{'loss': 4.0198, 'grad_norm': 11.803719520568848, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27008/33198 [15:58<02:13, 46.47it/s]

{'loss': 4.0032, 'grad_norm': 9.153314590454102, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27505/33198 [16:10<03:15, 29.08it/s]

{'loss': 4.0366, 'grad_norm': 7.775594234466553, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28004/33198 [16:25<03:31, 24.52it/s]

{'loss': 4.0267, 'grad_norm': 10.3668794631958, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28504/33198 [16:37<01:58, 39.53it/s]

{'loss': 4.0166, 'grad_norm': 10.164084434509277, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29007/33198 [16:50<01:41, 41.48it/s]

{'loss': 4.0107, 'grad_norm': 8.750020027160645, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29508/33198 [17:13<01:23, 44.32it/s]

{'loss': 4.0186, 'grad_norm': 9.88864803314209, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30004/33198 [17:28<01:26, 36.75it/s]

{'loss': 4.0508, 'grad_norm': 8.862512588500977, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30505/33198 [17:50<01:48, 24.85it/s]

{'loss': 4.0226, 'grad_norm': 9.0410795211792, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31002/33198 [18:07<01:14, 29.36it/s]

{'loss': 4.0364, 'grad_norm': 9.511988639831543, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31507/33198 [18:20<00:38, 44.37it/s]

{'loss': 4.0304, 'grad_norm': 9.850687026977539, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32006/33198 [18:31<00:29, 40.58it/s]

{'loss': 4.0214, 'grad_norm': 8.560542106628418, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32507/33198 [18:52<00:14, 47.43it/s]

{'loss': 4.0526, 'grad_norm': 7.8927693367004395, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33008/33198 [19:06<00:03, 49.68it/s]

{'loss': 4.0566, 'grad_norm': 8.636191368103027, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [19:10<00:00, 28.85it/s]


{'train_runtime': 1150.7659, 'train_samples_per_second': 230.778, 'train_steps_per_second': 28.849, 'train_loss': 4.17007298288277, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:17<00:00, 79.11it/s]
100%|██████████| 10570/10570 [00:26<00:00, 400.48it/s]
Some weights of BertForQuestionAnswering were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['qa_outputs.bias', 'qa_outputs.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  2%|▏         | 506/33198 [00:15<13:28, 40.45it/s]

{'loss': 5.1758, 'grad_norm': 5.5125579833984375, 'learning_rate': 1.9698777034761133e-05, 'epoch': 0.05}


  3%|▎         | 1001/33198 [00:30<38:36, 13.90it/s]

{'loss': 4.7071, 'grad_norm': 6.63002347946167, 'learning_rate': 1.939755406952226e-05, 'epoch': 0.09}


  5%|▍         | 1502/33198 [00:47<20:19, 25.98it/s]

{'loss': 4.6081, 'grad_norm': 11.04152774810791, 'learning_rate': 1.9096331104283393e-05, 'epoch': 0.14}


  6%|▌         | 2003/33198 [01:07<14:41, 35.37it/s]

{'loss': 4.5518, 'grad_norm': 6.294684886932373, 'learning_rate': 1.879510813904452e-05, 'epoch': 0.18}


  8%|▊         | 2503/33198 [01:22<13:21, 38.30it/s]

{'loss': 4.5094, 'grad_norm': 7.173090934753418, 'learning_rate': 1.8493885173805653e-05, 'epoch': 0.23}


  9%|▉         | 3007/33198 [01:35<10:12, 49.26it/s]

{'loss': 4.4654, 'grad_norm': 8.332958221435547, 'learning_rate': 1.819266220856678e-05, 'epoch': 0.27}


 11%|█         | 3500/33198 [01:53<15:26, 32.04it/s]

{'loss': 4.458, 'grad_norm': 7.909667491912842, 'learning_rate': 1.7891439243327913e-05, 'epoch': 0.32}


 12%|█▏        | 4010/33198 [02:19<13:11, 36.90it/s]  

{'loss': 4.4195, 'grad_norm': 8.394230842590332, 'learning_rate': 1.759021627808904e-05, 'epoch': 0.36}


 14%|█▎        | 4500/33198 [02:44<11:30, 41.58it/s]

{'loss': 4.3887, 'grad_norm': 7.515993595123291, 'learning_rate': 1.7288993312850173e-05, 'epoch': 0.41}


 15%|█▌        | 5003/33198 [03:08<24:30, 19.17it/s]

{'loss': 4.3453, 'grad_norm': 10.207367897033691, 'learning_rate': 1.6987770347611305e-05, 'epoch': 0.45}


 17%|█▋        | 5500/33198 [03:26<10:40, 43.24it/s]

{'loss': 4.3517, 'grad_norm': 8.25554370880127, 'learning_rate': 1.6686547382372433e-05, 'epoch': 0.5}


 18%|█▊        | 6000/33198 [03:51<28:07, 16.12it/s]

{'loss': 4.3206, 'grad_norm': 7.186047077178955, 'learning_rate': 1.6385324417133565e-05, 'epoch': 0.54}


 20%|█▉        | 6503/33198 [04:08<09:40, 45.97it/s]

{'loss': 4.3032, 'grad_norm': 7.878291130065918, 'learning_rate': 1.6084101451894693e-05, 'epoch': 0.59}


 21%|██        | 7004/33198 [04:26<12:14, 35.65it/s]

{'loss': 4.2922, 'grad_norm': 7.570741653442383, 'learning_rate': 1.5782878486655825e-05, 'epoch': 0.63}


 23%|██▎       | 7503/33198 [04:46<17:47, 24.08it/s]

{'loss': 4.28, 'grad_norm': 8.650850296020508, 'learning_rate': 1.5481655521416953e-05, 'epoch': 0.68}


 24%|██▍       | 8006/33198 [05:06<14:21, 29.25it/s]

{'loss': 4.2731, 'grad_norm': 8.624935150146484, 'learning_rate': 1.5180432556178085e-05, 'epoch': 0.72}


 26%|██▌       | 8500/33198 [05:25<40:09, 10.25it/s]

{'loss': 4.2496, 'grad_norm': 9.063841819763184, 'learning_rate': 1.4879209590939215e-05, 'epoch': 0.77}


 27%|██▋       | 9001/33198 [05:46<08:40, 46.46it/s]

{'loss': 4.2363, 'grad_norm': 8.024989128112793, 'learning_rate': 1.4577986625700345e-05, 'epoch': 0.81}


 29%|██▊       | 9501/33198 [06:09<58:23,  6.76it/s]

{'loss': 4.2262, 'grad_norm': 7.842356204986572, 'learning_rate': 1.4276763660461475e-05, 'epoch': 0.86}


 30%|███       | 10007/33198 [06:27<08:24, 45.99it/s]

{'loss': 4.2235, 'grad_norm': 8.722813606262207, 'learning_rate': 1.3975540695222605e-05, 'epoch': 0.9}


 32%|███▏      | 10503/33198 [06:50<09:43, 38.87it/s]

{'loss': 4.2106, 'grad_norm': 8.646865844726562, 'learning_rate': 1.3674317729983735e-05, 'epoch': 0.95}


 33%|███▎      | 11007/33198 [07:18<09:36, 38.50it/s]

{'loss': 4.2075, 'grad_norm': 8.524059295654297, 'learning_rate': 1.3373094764744865e-05, 'epoch': 0.99}


 35%|███▍      | 11510/33198 [07:32<08:27, 42.70it/s]

{'loss': 4.1018, 'grad_norm': 8.107661247253418, 'learning_rate': 1.3071871799505995e-05, 'epoch': 1.04}


 36%|███▌      | 12007/33198 [07:47<08:24, 41.97it/s]

{'loss': 4.0895, 'grad_norm': 11.561569213867188, 'learning_rate': 1.2770648834267125e-05, 'epoch': 1.08}


 38%|███▊      | 12506/33198 [08:11<08:35, 40.15it/s]

{'loss': 4.0721, 'grad_norm': 10.343844413757324, 'learning_rate': 1.2469425869028256e-05, 'epoch': 1.13}


 39%|███▉      | 13001/33198 [08:33<39:30,  8.52it/s]

{'loss': 4.0726, 'grad_norm': 9.49120044708252, 'learning_rate': 1.2168202903789386e-05, 'epoch': 1.17}


 41%|████      | 13504/33198 [08:56<08:16, 39.67it/s]

{'loss': 4.0827, 'grad_norm': 9.725366592407227, 'learning_rate': 1.1866979938550516e-05, 'epoch': 1.22}


 42%|████▏     | 14004/33198 [09:19<08:37, 37.06it/s]

{'loss': 4.0641, 'grad_norm': 7.842074871063232, 'learning_rate': 1.1565756973311646e-05, 'epoch': 1.27}


 44%|████▎     | 14503/33198 [09:34<08:59, 34.65it/s]

{'loss': 4.0981, 'grad_norm': 8.050152778625488, 'learning_rate': 1.1264534008072776e-05, 'epoch': 1.31}


 45%|████▌     | 15004/33198 [09:46<07:03, 42.96it/s]

{'loss': 4.101, 'grad_norm': 9.106961250305176, 'learning_rate': 1.0963311042833906e-05, 'epoch': 1.36}


 47%|████▋     | 15507/33198 [09:58<05:59, 49.17it/s]

{'loss': 4.0913, 'grad_norm': 8.097149848937988, 'learning_rate': 1.0662088077595036e-05, 'epoch': 1.4}


 48%|████▊     | 16004/33198 [10:09<06:53, 41.60it/s]

{'loss': 4.1089, 'grad_norm': 9.241143226623535, 'learning_rate': 1.0360865112356166e-05, 'epoch': 1.45}


 50%|████▉     | 16509/33198 [10:30<06:06, 45.59it/s]

{'loss': 4.1055, 'grad_norm': 10.787370681762695, 'learning_rate': 1.0059642147117296e-05, 'epoch': 1.49}


 51%|█████     | 17008/33198 [10:50<06:00, 44.96it/s]

{'loss': 4.0761, 'grad_norm': 10.075386047363281, 'learning_rate': 9.758419181878426e-06, 'epoch': 1.54}


 53%|█████▎    | 17506/33198 [11:06<05:30, 47.43it/s]

{'loss': 4.0916, 'grad_norm': 10.574637413024902, 'learning_rate': 9.457196216639558e-06, 'epoch': 1.58}


 54%|█████▍    | 18008/33198 [11:31<05:33, 45.48it/s]

{'loss': 4.1162, 'grad_norm': 9.505813598632812, 'learning_rate': 9.155973251400688e-06, 'epoch': 1.63}


 56%|█████▌    | 18506/33198 [11:42<04:58, 49.25it/s]

{'loss': 4.0827, 'grad_norm': 8.407498359680176, 'learning_rate': 8.854750286161818e-06, 'epoch': 1.67}


 57%|█████▋    | 19009/33198 [12:00<05:07, 46.20it/s]

{'loss': 4.0705, 'grad_norm': 8.155019760131836, 'learning_rate': 8.553527320922948e-06, 'epoch': 1.72}


 59%|█████▊    | 19502/33198 [12:25<04:46, 47.88it/s]

{'loss': 4.0953, 'grad_norm': 7.073212623596191, 'learning_rate': 8.252304355684078e-06, 'epoch': 1.76}


 60%|██████    | 20007/33198 [12:43<04:56, 44.53it/s]

{'loss': 4.059, 'grad_norm': 9.24216365814209, 'learning_rate': 7.951081390445208e-06, 'epoch': 1.81}


 62%|██████▏   | 20508/33198 [13:02<04:34, 46.22it/s]

{'loss': 4.0831, 'grad_norm': 8.869059562683105, 'learning_rate': 7.649858425206338e-06, 'epoch': 1.85}


 63%|██████▎   | 21005/33198 [13:16<06:16, 32.38it/s]

{'loss': 4.0945, 'grad_norm': 7.321341514587402, 'learning_rate': 7.348635459967468e-06, 'epoch': 1.9}


 65%|██████▍   | 21503/33198 [13:28<04:09, 46.81it/s]

{'loss': 4.0607, 'grad_norm': 9.526588439941406, 'learning_rate': 7.0474124947285985e-06, 'epoch': 1.94}


 66%|██████▋   | 22003/33198 [13:44<06:12, 30.02it/s]

{'loss': 4.075, 'grad_norm': 8.057596206665039, 'learning_rate': 6.7461895294897285e-06, 'epoch': 1.99}


 68%|██████▊   | 22503/33198 [13:56<03:39, 48.81it/s]

{'loss': 4.0456, 'grad_norm': 8.89096450805664, 'learning_rate': 6.4449665642508585e-06, 'epoch': 2.03}


 69%|██████▉   | 23003/33198 [14:16<04:20, 39.07it/s]

{'loss': 4.0447, 'grad_norm': 9.401945114135742, 'learning_rate': 6.1437435990119885e-06, 'epoch': 2.08}


 71%|███████   | 23508/33198 [14:30<03:23, 47.55it/s]

{'loss': 4.0139, 'grad_norm': 9.035786628723145, 'learning_rate': 5.8425206337731185e-06, 'epoch': 2.12}


 72%|███████▏  | 24001/33198 [14:58<17:44,  8.64it/s]

{'loss': 4.0095, 'grad_norm': 10.816130638122559, 'learning_rate': 5.541297668534249e-06, 'epoch': 2.17}


 74%|███████▍  | 24507/33198 [15:12<03:06, 46.67it/s]

{'loss': 4.0032, 'grad_norm': 8.94147777557373, 'learning_rate': 5.240074703295379e-06, 'epoch': 2.21}


 75%|███████▌  | 25005/33198 [15:24<03:39, 37.37it/s]

{'loss': 4.0191, 'grad_norm': 7.287137031555176, 'learning_rate': 4.93885173805651e-06, 'epoch': 2.26}


 77%|███████▋  | 25506/33198 [15:36<02:56, 43.62it/s]

{'loss': 3.9989, 'grad_norm': 7.463711261749268, 'learning_rate': 4.63762877281764e-06, 'epoch': 2.3}


 78%|███████▊  | 26006/33198 [15:54<03:01, 39.66it/s]

{'loss': 4.0034, 'grad_norm': 9.124420166015625, 'learning_rate': 4.33640580757877e-06, 'epoch': 2.35}


 80%|███████▉  | 26503/33198 [16:13<03:14, 34.49it/s]

{'loss': 4.0212, 'grad_norm': 12.119963645935059, 'learning_rate': 4.0351828423399e-06, 'epoch': 2.39}


 81%|████████▏ | 27008/33198 [16:32<03:10, 32.55it/s]

{'loss': 4.0015, 'grad_norm': 9.110553741455078, 'learning_rate': 3.7339598771010305e-06, 'epoch': 2.44}


 83%|████████▎ | 27506/33198 [16:51<03:01, 31.42it/s]

{'loss': 4.0319, 'grad_norm': 7.574251651763916, 'learning_rate': 3.4327369118621605e-06, 'epoch': 2.49}


 84%|████████▍ | 28006/33198 [17:10<02:17, 37.87it/s]

{'loss': 4.0286, 'grad_norm': 9.909834861755371, 'learning_rate': 3.131513946623291e-06, 'epoch': 2.53}


 86%|████████▌ | 28507/33198 [17:32<01:39, 47.26it/s]

{'loss': 4.0203, 'grad_norm': 10.3059720993042, 'learning_rate': 2.830290981384421e-06, 'epoch': 2.58}


 87%|████████▋ | 29002/33198 [17:49<01:30, 46.46it/s]

{'loss': 4.0133, 'grad_norm': 8.450298309326172, 'learning_rate': 2.5290680161455513e-06, 'epoch': 2.62}


 89%|████████▉ | 29506/33198 [18:08<01:34, 38.94it/s]

{'loss': 4.0192, 'grad_norm': 9.6996431350708, 'learning_rate': 2.2278450509066812e-06, 'epoch': 2.67}


 90%|█████████ | 30008/33198 [18:24<01:09, 45.78it/s]

{'loss': 4.0561, 'grad_norm': 8.549276351928711, 'learning_rate': 1.9266220856678112e-06, 'epoch': 2.71}


 92%|█████████▏| 30508/33198 [18:37<01:35, 28.21it/s]

{'loss': 4.0254, 'grad_norm': 9.294112205505371, 'learning_rate': 1.6253991204289416e-06, 'epoch': 2.76}


 93%|█████████▎| 31000/33198 [19:00<02:40, 13.73it/s]

{'loss': 4.0324, 'grad_norm': 8.739338874816895, 'learning_rate': 1.3241761551900716e-06, 'epoch': 2.8}


 95%|█████████▍| 31508/33198 [19:22<00:36, 46.11it/s]

{'loss': 4.0307, 'grad_norm': 9.754219055175781, 'learning_rate': 1.022953189951202e-06, 'epoch': 2.85}


 96%|█████████▋| 32009/33198 [19:44<00:29, 39.79it/s]

{'loss': 4.0141, 'grad_norm': 8.608089447021484, 'learning_rate': 7.217302247123321e-07, 'epoch': 2.89}


 98%|█████████▊| 32507/33198 [20:14<00:15, 43.50it/s]

{'loss': 4.0575, 'grad_norm': 8.105183601379395, 'learning_rate': 4.205072594734623e-07, 'epoch': 2.94}


 99%|█████████▉| 33001/33198 [20:34<00:11, 16.47it/s]

{'loss': 4.0602, 'grad_norm': 9.111811637878418, 'learning_rate': 1.1928429423459245e-07, 'epoch': 2.98}


100%|██████████| 33198/33198 [20:44<00:00, 26.69it/s]


{'train_runtime': 1244.0204, 'train_samples_per_second': 213.479, 'train_steps_per_second': 26.686, 'train_loss': 4.169612178300345, 'epoch': 3.0}


100%|██████████| 1348/1348 [00:24<00:00, 54.21it/s] 
100%|██████████| 10570/10570 [00:21<00:00, 496.32it/s]


# End of Code

In [None]:
# Reformer	LSH + Reversible layers	O(n log n)	Moderate
# Linformer	Low-rank projection	O(n)	Slight drop
# Performer	Kernel-based attention (FAVOR+)	O(n)	Very good
# BigBird	Structured sparse attention	O(n)	Excellent
# Nyströmformer	Nyström matrix approximation	O(n)	Good
# FlashAttention	CUDA-optimized full attention	O(n²) but fast	No tradeoff

# Logs

Using the 4 layers BERT and 64 batch size:
    the original attention training took 8:17 min and gave an accuracy of 79.82 %
    the proposed attention training took 9:08 and gave an accuracy of 80.96%


It takes 80 min to fine tune Bert-base-uncased on sst2 (batch_size = 128)



QA

Num Context Tokens = 0      Random Seed = 2025
{'exact_match': 28.94039735099338, 'f1': 40.779770669530336}

Num Context Tokens = 1      Random Seed = 2025
{'exact_match': 9.725638599810786, 'f1': 19.227043974258013}

Num Context Tokens = 4      Random Seed = 2025
{'exact_match': 8.684957426679281, 'f1': 18.15225450570832}

Num Context Tokens = 16      Random Seed = 2025
{'exact_match': 9.01608325449385, 'f1': 17.624647379035757}

Num Context Tokens = 32      Random Seed = 2025
{'exact_match': 4.711447492904447, 'f1': 10.245614370758615}

Num Context Tokens = 64      Random Seed = 2025
{'exact_match': 4.654683065279092, 'f1': 10.208962959376224}

Num Context Tokens = 128      Random Seed = 2025
{'exact_match': 3.424787133396405, 'f1': 7.686542942813391}

Num Context Tokens = 0      Random Seed = 17
{'exact_match': 32.02459791863765, 'f1': 44.28220365441659}

Num Context Tokens = 1      Random Seed = 17
{'exact_match': 9.7918637653737, 'f1': 19.70289598515182}

Num Context Tokens = 4      Random Seed = 17
{'exact_match': 7.994323557237465, 'f1': 16.63525060675849}

Num Context Tokens = 16      Random Seed = 17
{'exact_match': 10.189214758751183, 'f1': 19.176374788338425}

Num Context Tokens = 32      Random Seed = 17
{'exact_match': 9.754020813623463, 'f1': 18.620783689403815}

Num Context Tokens = 64      Random Seed = 17
{'exact_match': 7.994323557237465, 'f1': 15.505388032599994}


# Attention (original + mine v1)

In [None]:
class CustomAttention(nn.Module):
    def __init__(self, config, num_context_tokens):
        super().__init__()
        # The following equality must be possible num_attention_heads * attention_head_size = config.hidden_size.
        if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
            raise ValueError(f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention " f"heads ({config.num_attention_heads})")
        
        # assign attention_head size such that num_attention_heads * attention_head_size = config.hidden_size
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = int(config.hidden_size / config.num_attention_heads)    #    
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Initialize projections matrices and dropout layer
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.dropout_prob = config.attention_probs_dropout_prob
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)        

        # Central Attention complexity is O(num_context_tokens * sequence_length * hidden_size)
        self.num_context_tokens = num_context_tokens
    
    def transpose_for_scores(self, x: torch.Tensor) -> torch.Tensor:
        new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size) # (batch_size, num_tokens, num_attention_heads, attention_head_size)
        x = x.view(new_x_shape)
        return x.permute(0, 2, 1, 3) # (batch_size, num_attention_heads, num_tokens, attention_head_size)

    # Adapted from BertSelfAttention
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        encoder_attention_mask: Optional[torch.FloatTensor] = None,
        past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        output_attentions: Optional[bool] = False,
    ) -> Tuple[torch.Tensor]:
        
        # Project input into query, key, value matrices respectively
        query_layer = self.transpose_for_scores(self.query(hidden_states))
        key_layer = self.transpose_for_scores(self.key(hidden_states))
        value_layer = self.transpose_for_scores(self.value(hidden_states))

        # Split q,k, and v for the central attention.
        # Tokens ordering: [CLS] + [CONTEXT_1] + [CONTEXT_2] ... [CONTEXT_N] + ALL SENTENCE TOKENS + [SEP] + [PAD] + [PAD] + [PAD] ....
        # first token is [CLS], then the context tokens, then the sentence, then the [SEP], and finally the padding.
        c = self.num_context_tokens + 1  # [CLS] + context tokens are treated as the central context tokens.
        query_layer_context     = query_layer[:,:,:c,:]
        query_layer_sentence    = query_layer[:,:,c:,:]
        key_layer_context       = key_layer[:,:,:c,:]
        key_layer_sentence      = key_layer[:,:,c:,:]
        value_layer_context     = value_layer[:,:,:c,:]
        value_layer_sentence    = value_layer[:,:,c:,:]

        if attention_mask is not None:
            attention_mask_context  = attention_mask[:,:,:c,c:]
            attention_mask_sentence = attention_mask[:,:,c:,:c]  # should always be zeros, since the central context tokens will never contain padding
          
        if self.num_context_tokens == 0: # Original Attention
            # Take the dot product between "query" and "key" to get the raw attention scores.
            attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

            # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
            if attention_mask is not None:
                attention_scores = attention_scores + attention_mask

            # Normalize the attention scores to probabilities.
            attention_probs = torch.nn.functional.softmax(attention_scores, dim=-1)

            # This is actually dropping out entire tokens to attend to, which might
            # seem a bit unusual, but is taken from the original Transformer paper.
            attention_probs = self.dropout(attention_probs)
            attn_output = torch.matmul(attention_probs, value_layer)

        else: # Central Attention 

            attention_scores_context_to_sentence = torch.matmul(query_layer_context, key_layer_sentence.transpose(-1, -2)) / math.sqrt(self.attention_head_size)
            attention_scores_sentence_to_context = torch.matmul(query_layer_sentence, key_layer_context.transpose(-1, -2)) / math.sqrt(self.attention_head_size)

            if attention_mask is not None:
                attention_scores_context_to_sentence = attention_scores_context_to_sentence + attention_mask_context
                attention_scores_sentence_to_context = attention_scores_sentence_to_context + attention_mask_sentence

            attention_probs_context_to_sentence = torch.nn.functional.softmax(attention_scores_context_to_sentence, dim=-1)
            attention_probs_sentence_to_context = torch.nn.functional.softmax(attention_scores_sentence_to_context, dim=-1)

            attention_probs_context_to_sentence = self.dropout(attention_probs_context_to_sentence)
            attention_probs_sentence_to_context = self.dropout(attention_probs_sentence_to_context)

            attn_output_context = torch.matmul(attention_probs_context_to_sentence, value_layer_sentence)
            attn_output_sentence= torch.matmul(attention_probs_sentence_to_context, value_layer_context)

            attn_output = torch.cat((attn_output_context, attn_output_sentence), dim = 2)

        # reformatting attention output to (batch_size, num_tokens, hidden_size), which is the contextual embedding for each token.
        attn_output = attn_output.permute(0, 2, 1, 3).contiguous()
        new_attn_output_shape = attn_output.size()[:-2] + (self.all_head_size,)
        attn_output = attn_output.view(new_attn_output_shape)

        outputs = (attn_output,)

        return outputs

# Debugging and Testing