## Data


In [1]:
IGNORE_INDEX = -100
MAX_LENGTH = 128

In [2]:
from datasets import load_dataset

wnut_dataset = load_dataset("wnut_17", trust_remote_code=True)

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
wnut_dataset

DatasetDict({
    train: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 3394
    })
    validation: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 1009
    })
    test: Dataset({
        features: ['id', 'tokens', 'ner_tags'],
        num_rows: 1287
    })
})

In [4]:
tokens = wnut_dataset["train"][0]["tokens"]
ner_tags = wnut_dataset["train"][0]["ner_tags"]
assert len(tokens) == len(ner_tags)
tokens[:5], ner_tags[:5]

(['@paulwalk', 'It', "'s", 'the', 'view'], [0, 0, 0, 0, 0])

In [5]:
label_list = wnut_dataset["train"].features["ner_tags"].feature.names
label_list, len(label_list)

(['O',
  'B-corporation',
  'I-corporation',
  'B-creative-work',
  'I-creative-work',
  'B-group',
  'I-group',
  'B-location',
  'I-location',
  'B-person',
  'I-person',
  'B-product',
  'I-product'],
 13)

In [6]:
from transformers import AutoTokenizer

model_name = "distilbert/distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [7]:
example = wnut_dataset["train"][0]
tokenized_input = tokenizer(
    example["tokens"], max_length=MAX_LENGTH, truncation=True, is_split_into_words=True
)
tokens = tokenizer.convert_ids_to_tokens(tokenized_input["input_ids"])
tokens[:5], example["tokens"][:5], example["ner_tags"][:5]

(['[CLS]', '@', 'paul', '##walk', 'it'],
 ['@paulwalk', 'It', "'s", 'the', 'view'],
 [0, 0, 0, 0, 0])

In [8]:
def tokenize_and_align_labels(examples):
    tokenized_inputs = tokenizer(
        examples["tokens"],
        truncation=True,
        max_length=MAX_LENGTH,  #'max_length',
        is_split_into_words=True,
    )

    labels = []
    for i, label in enumerate(examples[f"ner_tags"]):
        # Map tokens to their respective word.
        word_ids = tokenized_inputs.word_ids(batch_index=i)

        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx and word_ids != previous_word_idx:
                label_ids.append(label[word_idx])
            else:
                label_ids.append(IGNORE_INDEX)

            # if word_idx is None:
            #     label_ids.append(IGNORE_INDEX)
            # elif (
            #     word_idx != previous_word_idx
            # ):  # Only label the first token of a given word.
            #     label_ids.append(label[word_idx])
            # else:
            #     label_ids.append(IGNORE_INDEX)
            previous_word_idx = word_idx
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

In [9]:
tokenized_wnut = wnut_dataset.map(tokenize_and_align_labels, batched=True)

In [10]:
row = tokenized_wnut["train"][0]

print(len(row["tokens"]), len(row["ner_tags"]))
print(len(row["input_ids"]), len(row["attention_mask"]), len(row["labels"]))
row.keys()

27 27
34 34 34


dict_keys(['id', 'tokens', 'ner_tags', 'input_ids', 'attention_mask', 'labels'])

In [11]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(
    tokenizer=tokenizer, padding="max_length", max_length=MAX_LENGTH
)

## Metrics


In [12]:
import evaluate

# https://huggingface.co/spaces/evaluate-metric/seqeval
seqeval = evaluate.load("seqeval")

In [13]:
# labels = [label_list[i] for i in example[f"ner_tags"]]
# labels

In [14]:
import numpy as np

# labels = [label_list[i] for i in example[f"ner_tags"]]


def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)  # 2
    # print(predictions)

    true_predictions = [
        [label_list[p] for (p, l) in zip(prediction, label) if l != IGNORE_INDEX]
        for prediction, label in zip(predictions, labels)
    ]
    true_labels = [
        [label_list[l] for (p, l) in zip(prediction, label) if l != IGNORE_INDEX]
        for prediction, label in zip(predictions, labels)
    ]
    # print(true_predictions)
    # print(true_labels)

    results = seqeval.compute(
        predictions=true_predictions, references=true_labels, zero_division=0
    )
    # print(results)

    return {
        "precision": results["overall_precision"],
        "recall": results["overall_recall"],
        "f1": results["overall_f1"],
        "accuracy": results["overall_accuracy"],
    }

In [15]:
# p = ([[[0, 0, 1], [0, 1, 0]]], [2, 1])
# compute_metrics(p)

## Model


In [16]:
label2id = {
    "O": 0,
    "B-corporation": 1,
    "I-corporation": 2,
    "B-creative-work": 3,
    "I-creative-work": 4,
    "B-group": 5,
    "I-group": 6,
    "B-location": 7,
    "I-location": 8,
    "B-person": 9,
    "I-person": 10,
    "B-product": 11,
    "I-product": 12,
}
id2label = {v: k for k, v in label2id.items()}

In [17]:
from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer

model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=13,
    id2label=id2label,
    label2id=label2id,
)

Some weights of DistilBertForTokenClassification were not initialized from the model checkpoint at distilbert/distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [18]:
from torchinfo import summary

summary(model, verbose=1)
None

Layer (type:depth-idx)                                  Param #
DistilBertForTokenClassification                        --
├─DistilBertModel: 1-1                                  --
│    └─Embeddings: 2-1                                  --
│    │    └─Embedding: 3-1                              23,440,896
│    │    └─Embedding: 3-2                              393,216
│    │    └─LayerNorm: 3-3                              1,536
│    │    └─Dropout: 3-4                                --
│    └─Transformer: 2-2                                 --
│    │    └─ModuleList: 3-5                             42,527,232
├─Dropout: 1-2                                          --
├─Linear: 1-3                                           9,997
Total params: 66,372,877
Trainable params: 66,372,877
Non-trainable params: 0


## LoRA


In [19]:
from peft import LoraConfig, get_peft_model

target_modules = ["q_lin", "k_lin", "v_lin"]

# Load config
# Write your code here
config = LoraConfig(
    r=64,
    lora_alpha=128,
    lora_dropout=0.1,
    bias="none",
    target_modules=target_modules,
    modules_to_save=["classifier"],
)
lora_model = get_peft_model(model, config)
summary(lora_model, verbose=1)
None

Layer (type:depth-idx)                                                 Param #
PeftModel                                                              --
├─LoraModel: 1-1                                                       --
│    └─DistilBertForTokenClassification: 2-1                           --
│    │    └─DistilBertModel: 3-1                                       68,132,352
│    │    └─Dropout: 3-2                                               --
│    │    └─ModulesToSaveWrapper: 3-3                                  19,994
Total params: 68,152,346
Trainable params: 1,779,469
Non-trainable params: 66,372,877


## ClearML


In [20]:
import os
from dotenv import load_dotenv
from IPython.display import clear_output

load_dotenv(".env")

CLEARML_API_ACCESS_KEY = os.getenv("CLEARML_API_ACCESS_KEY")
CLEARML_API_SECRET_KEY = os.getenv("CLEARML_API_SECRET_KEY")

%env CLEARML_WEB_HOST=https://app.clear.ml/
%env CLEARML_API_HOST=https://api.clear.ml/
%env CLEARML_FILES_HOST=https://files.clear.ml/
%env CLEARML_API_ACCESS_KEY=$CLEARML_API_ACCESS_KEY
%env CLEARML_API_SECRET_KEY=$CLEARML_API_SECRET_KEY
clear_output()

In [21]:
from clearml import Task

# default is ViT if input is empty
if not (task_name := input()):
    task_name = "BERT"

task = Task.init(
    task_name=task_name,
    project_name="PMLDL",
)

ClearML Task: created new task id=ba411791de9f4632a43f00fabd5ff593
2024-10-28 16:02:12,126 - clearml.Task - INFO - Storing jupyter notebook directly as code
ClearML results page: https://app.clear.ml/projects/3fdff4c44fe4469db0bf496ddae269dc/experiments/ba411791de9f4632a43f00fabd5ff593/output/log


huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)


## Training


In [24]:
batch_size = 32
epochs = 20

training_args = TrainingArguments(
    # output and logging
    output_dir="bert-output",
    # overwrite_output_dir=True,
    logging_steps=10,
    # remove_unused_columns=False,
    push_to_hub=False,
    report_to="clearml",
    # data
    label_names=["labels"],  # https://github.com/huggingface/transformers/issues/28530
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # hyperparameters
    learning_rate=2e-5,
    num_train_epochs=epochs,
    gradient_accumulation_steps=2,
    weight_decay=0.001,
    # strategies
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    # other
    seed=42,
)

In [25]:
from transformers import EarlyStoppingCallback

trainer = Trainer(
    model=lora_model,
    args=training_args,
    train_dataset=tokenized_wnut["train"],
    eval_dataset=tokenized_wnut["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=10)],
)

In [26]:
trainer.train()



  1%|          | 10/1060 [00:04<07:40,  2.28it/s]

{'loss': 4.8553, 'grad_norm': 13.343891143798828, 'learning_rate': 1.9811320754716984e-05, 'epoch': 0.19}


  2%|▏         | 20/1060 [00:09<07:25,  2.33it/s]

{'loss': 4.6007, 'grad_norm': 13.306522369384766, 'learning_rate': 1.9622641509433963e-05, 'epoch': 0.37}


  3%|▎         | 30/1060 [00:13<07:14,  2.37it/s]

{'loss': 4.3055, 'grad_norm': 13.72619342803955, 'learning_rate': 1.9433962264150945e-05, 'epoch': 0.56}


  4%|▍         | 40/1060 [00:17<07:24,  2.30it/s]

{'loss': 3.9578, 'grad_norm': 13.937024116516113, 'learning_rate': 1.9245283018867927e-05, 'epoch': 0.75}


  5%|▍         | 50/1060 [00:21<07:10,  2.34it/s]

{'loss': 3.515, 'grad_norm': 14.048351287841797, 'learning_rate': 1.905660377358491e-05, 'epoch': 0.93}


                                                 
  5%|▌         | 53/1060 [00:27<07:08,  2.35it/s]

{'eval_loss': 1.6174941062927246, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9132316035422736, 'eval_runtime': 4.5297, 'eval_samples_per_second': 284.126, 'eval_steps_per_second': 9.051, 'epoch': 0.99}


  6%|▌         | 60/1060 [00:48<20:19,  1.22s/it]  

{'loss': 3.0129, 'grad_norm': 14.340547561645508, 'learning_rate': 1.8867924528301888e-05, 'epoch': 1.12}


  7%|▋         | 70/1060 [00:52<07:34,  2.18it/s]

{'loss': 2.2954, 'grad_norm': 14.052679061889648, 'learning_rate': 1.867924528301887e-05, 'epoch': 1.31}


  8%|▊         | 80/1060 [00:57<06:55,  2.36it/s]

{'loss': 1.6378, 'grad_norm': 11.231733322143555, 'learning_rate': 1.8490566037735852e-05, 'epoch': 1.5}


  8%|▊         | 90/1060 [01:01<06:52,  2.35it/s]

{'loss': 1.0864, 'grad_norm': 7.511904716491699, 'learning_rate': 1.830188679245283e-05, 'epoch': 1.68}


  9%|▉         | 100/1060 [01:05<06:47,  2.36it/s]

{'loss': 0.7861, 'grad_norm': 3.4796345233917236, 'learning_rate': 1.8113207547169813e-05, 'epoch': 1.87}


                                                  
 10%|█         | 107/1060 [01:12<05:48,  2.73it/s]

{'eval_loss': 0.4801444709300995, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.3617, 'eval_samples_per_second': 295.068, 'eval_steps_per_second': 9.4, 'epoch': 2.0}


 10%|█         | 110/1060 [01:31<57:14,  3.62s/it]  

{'loss': 0.6114, 'grad_norm': 1.6160500049591064, 'learning_rate': 1.7924528301886795e-05, 'epoch': 2.06}


 11%|█▏        | 120/1060 [01:35<08:02,  1.95it/s]

{'loss': 0.6619, 'grad_norm': 0.5249636173248291, 'learning_rate': 1.7735849056603774e-05, 'epoch': 2.24}


 12%|█▏        | 130/1060 [01:39<06:35,  2.35it/s]

{'loss': 0.5742, 'grad_norm': 0.8887442350387573, 'learning_rate': 1.7547169811320756e-05, 'epoch': 2.43}


 13%|█▎        | 140/1060 [01:44<06:30,  2.36it/s]

{'loss': 0.5735, 'grad_norm': 0.44479459524154663, 'learning_rate': 1.735849056603774e-05, 'epoch': 2.62}


 14%|█▍        | 150/1060 [01:48<06:26,  2.36it/s]

{'loss': 0.5454, 'grad_norm': 0.538837730884552, 'learning_rate': 1.716981132075472e-05, 'epoch': 2.8}


 15%|█▌        | 160/1060 [01:52<06:21,  2.36it/s]

{'loss': 0.5771, 'grad_norm': 0.4966653287410736, 'learning_rate': 1.69811320754717e-05, 'epoch': 2.99}


                                                  
 15%|█▌        | 160/1060 [01:57<06:21,  2.36it/s]

{'eval_loss': 0.46249377727508545, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.3928, 'eval_samples_per_second': 292.981, 'eval_steps_per_second': 9.334, 'epoch': 2.99}


 16%|█▌        | 170/1060 [02:18<10:06,  1.47it/s]  

{'loss': 0.5844, 'grad_norm': 0.733722984790802, 'learning_rate': 1.679245283018868e-05, 'epoch': 3.18}


 17%|█▋        | 180/1060 [02:22<06:19,  2.32it/s]

{'loss': 0.4825, 'grad_norm': 0.3899109363555908, 'learning_rate': 1.6603773584905664e-05, 'epoch': 3.36}


 18%|█▊        | 190/1060 [02:27<06:09,  2.36it/s]

{'loss': 0.5398, 'grad_norm': 0.7129817605018616, 'learning_rate': 1.6415094339622643e-05, 'epoch': 3.55}


 19%|█▉        | 200/1060 [02:31<06:06,  2.35it/s]

{'loss': 0.598, 'grad_norm': 0.6791282296180725, 'learning_rate': 1.6226415094339625e-05, 'epoch': 3.74}


 20%|█▉        | 210/1060 [02:35<06:01,  2.35it/s]

{'loss': 0.5625, 'grad_norm': 0.4288803040981293, 'learning_rate': 1.6037735849056607e-05, 'epoch': 3.93}


                                                  
 20%|██        | 214/1060 [02:41<05:10,  2.73it/s]

{'eval_loss': 0.4457629323005676, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.4047, 'eval_samples_per_second': 292.187, 'eval_steps_per_second': 9.308, 'epoch': 4.0}


 21%|██        | 220/1060 [03:01<21:18,  1.52s/it]  

{'loss': 0.5593, 'grad_norm': 0.5247312188148499, 'learning_rate': 1.5849056603773586e-05, 'epoch': 4.11}


 22%|██▏       | 230/1060 [03:05<06:19,  2.19it/s]

{'loss': 0.4917, 'grad_norm': 0.8038358688354492, 'learning_rate': 1.5660377358490568e-05, 'epoch': 4.3}


 23%|██▎       | 240/1060 [03:10<05:49,  2.35it/s]

{'loss': 0.5056, 'grad_norm': 0.5636480450630188, 'learning_rate': 1.547169811320755e-05, 'epoch': 4.49}


 24%|██▎       | 250/1060 [03:14<05:45,  2.35it/s]

{'loss': 0.5423, 'grad_norm': 0.4851149022579193, 'learning_rate': 1.5283018867924532e-05, 'epoch': 4.67}


 25%|██▍       | 260/1060 [03:18<05:40,  2.35it/s]

{'loss': 0.5169, 'grad_norm': 0.5321747064590454, 'learning_rate': 1.5094339622641511e-05, 'epoch': 4.86}


                                                  
 25%|██▌       | 267/1060 [03:25<05:37,  2.35it/s]

{'eval_loss': 0.42166849970817566, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.4003, 'eval_samples_per_second': 292.477, 'eval_steps_per_second': 9.317, 'epoch': 4.99}


 25%|██▌       | 270/1060 [03:44<47:43,  3.62s/it]  

{'loss': 0.5573, 'grad_norm': 0.6594676375389099, 'learning_rate': 1.4905660377358491e-05, 'epoch': 5.05}


 26%|██▋       | 280/1060 [03:48<06:40,  1.95it/s]

{'loss': 0.5314, 'grad_norm': 0.5341221690177917, 'learning_rate': 1.4716981132075472e-05, 'epoch': 5.23}


 27%|██▋       | 290/1060 [03:53<05:29,  2.34it/s]

{'loss': 0.4266, 'grad_norm': 1.1582375764846802, 'learning_rate': 1.4528301886792452e-05, 'epoch': 5.42}


 28%|██▊       | 300/1060 [03:57<05:23,  2.35it/s]

{'loss': 0.4747, 'grad_norm': 0.6910334229469299, 'learning_rate': 1.4339622641509435e-05, 'epoch': 5.61}


 29%|██▉       | 310/1060 [04:01<05:18,  2.36it/s]

{'loss': 0.4514, 'grad_norm': 0.7015814781188965, 'learning_rate': 1.4150943396226415e-05, 'epoch': 5.79}


 30%|███       | 320/1060 [04:05<05:14,  2.35it/s]

{'loss': 0.4602, 'grad_norm': 1.0783578157424927, 'learning_rate': 1.3962264150943397e-05, 'epoch': 5.98}


                                                  
 30%|███       | 321/1060 [04:10<04:32,  2.71it/s]

{'eval_loss': 0.3825850486755371, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.413, 'eval_samples_per_second': 291.635, 'eval_steps_per_second': 9.291, 'epoch': 6.0}


 31%|███       | 330/1060 [04:31<09:44,  1.25it/s]  

{'loss': 0.5171, 'grad_norm': 1.2437933683395386, 'learning_rate': 1.3773584905660378e-05, 'epoch': 6.17}


 32%|███▏      | 340/1060 [04:36<05:11,  2.31it/s]

{'loss': 0.4157, 'grad_norm': 0.5337054133415222, 'learning_rate': 1.3584905660377358e-05, 'epoch': 6.36}


 33%|███▎      | 350/1060 [04:40<05:02,  2.35it/s]

{'loss': 0.4231, 'grad_norm': 1.2700152397155762, 'learning_rate': 1.339622641509434e-05, 'epoch': 6.54}


 34%|███▍      | 360/1060 [04:44<04:58,  2.35it/s]

{'loss': 0.4136, 'grad_norm': 0.8863049745559692, 'learning_rate': 1.320754716981132e-05, 'epoch': 6.73}


 35%|███▍      | 370/1060 [04:49<04:53,  2.35it/s]

{'loss': 0.3993, 'grad_norm': 1.1528431177139282, 'learning_rate': 1.3018867924528303e-05, 'epoch': 6.92}


                                                  
 35%|███▌      | 374/1060 [04:55<04:58,  2.30it/s]

{'eval_loss': 0.3781232237815857, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.3864, 'eval_samples_per_second': 293.404, 'eval_steps_per_second': 9.347, 'epoch': 6.99}


 36%|███▌      | 380/1060 [05:15<17:21,  1.53s/it]  

{'loss': 0.4132, 'grad_norm': 0.7648394107818604, 'learning_rate': 1.2830188679245283e-05, 'epoch': 7.1}


 37%|███▋      | 390/1060 [05:19<05:04,  2.20it/s]

{'loss': 0.404, 'grad_norm': 0.6905752420425415, 'learning_rate': 1.2641509433962264e-05, 'epoch': 7.29}


 38%|███▊      | 400/1060 [05:23<04:40,  2.35it/s]

{'loss': 0.4267, 'grad_norm': 0.9779990911483765, 'learning_rate': 1.2452830188679246e-05, 'epoch': 7.48}


 39%|███▊      | 410/1060 [05:27<04:37,  2.35it/s]

{'loss': 0.3835, 'grad_norm': 1.3765205144882202, 'learning_rate': 1.2264150943396227e-05, 'epoch': 7.66}


 40%|███▉      | 420/1060 [05:32<04:31,  2.35it/s]

{'loss': 0.4156, 'grad_norm': 0.7510159611701965, 'learning_rate': 1.2075471698113209e-05, 'epoch': 7.85}


                                                  
 40%|████      | 428/1060 [05:39<03:51,  2.73it/s]

{'eval_loss': 0.38876578211784363, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.4084, 'eval_samples_per_second': 291.946, 'eval_steps_per_second': 9.301, 'epoch': 8.0}


 41%|████      | 430/1060 [05:58<52:15,  4.98s/it]  

{'loss': 0.392, 'grad_norm': 0.4238787591457367, 'learning_rate': 1.188679245283019e-05, 'epoch': 8.04}


 42%|████▏     | 440/1060 [06:02<05:41,  1.82it/s]

{'loss': 0.4233, 'grad_norm': 0.5094231963157654, 'learning_rate': 1.169811320754717e-05, 'epoch': 8.22}


 42%|████▏     | 450/1060 [06:06<04:21,  2.34it/s]

{'loss': 0.4288, 'grad_norm': 1.126540184020996, 'learning_rate': 1.1509433962264152e-05, 'epoch': 8.41}


 43%|████▎     | 460/1060 [06:10<04:15,  2.35it/s]

{'loss': 0.4021, 'grad_norm': 0.5422381162643433, 'learning_rate': 1.1320754716981132e-05, 'epoch': 8.6}


 44%|████▍     | 470/1060 [06:15<04:14,  2.32it/s]

{'loss': 0.3702, 'grad_norm': 0.8516051173210144, 'learning_rate': 1.1132075471698115e-05, 'epoch': 8.79}


 45%|████▌     | 480/1060 [06:19<04:07,  2.34it/s]

{'loss': 0.3631, 'grad_norm': 0.8657976388931274, 'learning_rate': 1.0943396226415095e-05, 'epoch': 8.97}


                                                  
 45%|████▌     | 481/1060 [06:24<04:07,  2.34it/s]

{'eval_loss': 0.38644179701805115, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.4001, 'eval_samples_per_second': 292.491, 'eval_steps_per_second': 9.318, 'epoch': 8.99}


 46%|████▌     | 490/1060 [06:45<07:33,  1.26it/s]  

{'loss': 0.3976, 'grad_norm': 0.6678516864776611, 'learning_rate': 1.0754716981132076e-05, 'epoch': 9.16}


 47%|████▋     | 500/1060 [06:49<04:02,  2.31it/s]

{'loss': 0.4255, 'grad_norm': 0.44450706243515015, 'learning_rate': 1.0566037735849058e-05, 'epoch': 9.35}


 48%|████▊     | 510/1060 [06:53<03:54,  2.35it/s]

{'loss': 0.3699, 'grad_norm': 0.5086137652397156, 'learning_rate': 1.0377358490566038e-05, 'epoch': 9.53}


 49%|████▉     | 520/1060 [06:58<03:49,  2.35it/s]

{'loss': 0.382, 'grad_norm': 0.47330331802368164, 'learning_rate': 1.018867924528302e-05, 'epoch': 9.72}


 50%|█████     | 530/1060 [07:02<03:44,  2.36it/s]

{'loss': 0.3713, 'grad_norm': 0.4638141095638275, 'learning_rate': 1e-05, 'epoch': 9.91}


                                                  
 50%|█████     | 535/1060 [07:08<03:18,  2.65it/s]

{'eval_loss': 0.38724201917648315, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.5723, 'eval_samples_per_second': 281.477, 'eval_steps_per_second': 8.967, 'epoch': 10.0}


 51%|█████     | 540/1060 [07:28<17:30,  2.02s/it]  

{'loss': 0.3811, 'grad_norm': 0.5202161073684692, 'learning_rate': 9.811320754716981e-06, 'epoch': 10.09}


 52%|█████▏    | 550/1060 [07:32<03:59,  2.13it/s]

{'loss': 0.3784, 'grad_norm': 0.47577646374702454, 'learning_rate': 9.622641509433963e-06, 'epoch': 10.28}


 53%|█████▎    | 560/1060 [07:37<03:34,  2.34it/s]

{'loss': 0.3734, 'grad_norm': 1.0538872480392456, 'learning_rate': 9.433962264150944e-06, 'epoch': 10.47}


 54%|█████▍    | 570/1060 [07:41<03:28,  2.35it/s]

{'loss': 0.3822, 'grad_norm': 1.021440863609314, 'learning_rate': 9.245283018867926e-06, 'epoch': 10.65}


 55%|█████▍    | 580/1060 [07:45<03:24,  2.35it/s]

{'loss': 0.4077, 'grad_norm': 0.8767058849334717, 'learning_rate': 9.056603773584907e-06, 'epoch': 10.84}


                                                  
 55%|█████▌    | 588/1060 [07:53<03:20,  2.35it/s]

{'eval_loss': 0.38151252269744873, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.3988, 'eval_samples_per_second': 292.579, 'eval_steps_per_second': 9.321, 'epoch': 10.99}


 56%|█████▌    | 590/1060 [08:11<39:08,  5.00s/it]

{'loss': 0.3583, 'grad_norm': 0.5539172887802124, 'learning_rate': 8.867924528301887e-06, 'epoch': 11.03}


 57%|█████▋    | 600/1060 [08:16<04:14,  1.81it/s]

{'loss': 0.3543, 'grad_norm': 1.3451056480407715, 'learning_rate': 8.67924528301887e-06, 'epoch': 11.21}


 58%|█████▊    | 610/1060 [08:20<03:12,  2.34it/s]

{'loss': 0.4169, 'grad_norm': 0.866301417350769, 'learning_rate': 8.49056603773585e-06, 'epoch': 11.4}


 58%|█████▊    | 620/1060 [08:24<03:07,  2.35it/s]

{'loss': 0.395, 'grad_norm': 0.9138256311416626, 'learning_rate': 8.301886792452832e-06, 'epoch': 11.59}


 59%|█████▉    | 630/1060 [08:28<03:04,  2.33it/s]

{'loss': 0.3853, 'grad_norm': 0.38531455397605896, 'learning_rate': 8.113207547169812e-06, 'epoch': 11.78}


 60%|██████    | 640/1060 [08:33<03:00,  2.33it/s]

{'loss': 0.3616, 'grad_norm': 0.8884817957878113, 'learning_rate': 7.924528301886793e-06, 'epoch': 11.96}


                                                  
 61%|██████    | 642/1060 [08:38<02:34,  2.71it/s]

{'eval_loss': 0.38028204441070557, 'eval_precision': 0.0, 'eval_recall': 0.0, 'eval_f1': 0.0, 'eval_accuracy': 0.9206161479216546, 'eval_runtime': 4.4491, 'eval_samples_per_second': 289.271, 'eval_steps_per_second': 9.215, 'epoch': 12.0}


 61%|██████    | 642/1060 [08:55<05:48,  1.20it/s]

{'train_runtime': 541.3428, 'train_samples_per_second': 125.392, 'train_steps_per_second': 1.958, 'train_loss': 0.8518582899993825, 'epoch': 12.0}





TrainOutput(global_step=642, training_loss=0.8518582899993825, metrics={'train_runtime': 541.3428, 'train_samples_per_second': 125.392, 'train_steps_per_second': 1.958, 'total_flos': 1386234650382336.0, 'train_loss': 0.8518582899993825, 'epoch': 12.0})