In [1]:
import json
import argparse
from itertools import chain
from functools import partial

import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification
import evaluate
from datasets import Dataset, features
import numpy as np
from seqeval.metrics import recall_score, precision_score
from seqeval.metrics import classification_report
from seqeval.metrics import f1_score

# Checkpoint identifier handed to both `AutoTokenizer.from_pretrained` and
# `AutoModelForTokenClassification.from_pretrained` further down.
TRAINING_MODEL_PATH = "microsoft/deberta-v3-base"
# Value forwarded as `max_length` to the tokenizer for every document.
TRAINING_MAX_LENGTH = 1024
# Value forwarded as `output_dir` to `TrainingArguments`.
OUTPUT_DIR = "output"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the training documents. The context manager guarantees the file handle
# is closed, and the explicit encoding avoids depending on the locale default.
with open("train.json", encoding="utf-8") as train_file:
    data = json.load(train_file)

# Split documents by whether they carry at least one non-"O" label.
# NOTE: no downsampling actually happens here -- both groups are concatenated
# again below, so the only effect is that positive documents come first.
pos = []  # documents with at least one label other than "O"
neg = []  # documents whose labels are all "O"
for d in data:
    # Plain-Python check; also well-defined for a document with no labels.
    if any(label != "O" for label in d["labels"]):
        pos.append(d)
    else:
        neg.append(d)
print("original datapoints: ", len(data))

data = pos+neg
print("combined: ", len(data))

original datapoints:  6807
combined:  6807


In [3]:
# Every distinct tag that occurs anywhere in the corpus, in sorted order,
# plus the two lookup tables (tag -> index and index -> tag) built from it.
all_labels = sorted(set(chain.from_iterable(doc["labels"] for doc in data)))
label2id = {name: idx for idx, name in enumerate(all_labels)}
id2label = dict(enumerate(all_labels))

# The twelve non-"O" tags, listed explicitly.
# (Not referenced by any other cell visible in this notebook.)
target = [
    'B-EMAIL',
    'B-ID_NUM',
    'B-NAME_STUDENT',
    'B-PHONE_NUM',
    'B-STREET_ADDRESS',
    'B-URL_PERSONAL',
    'B-USERNAME',
    'I-ID_NUM',
    'I-NAME_STUDENT',
    'I-PHONE_NUM',
    'I-STREET_ADDRESS',
    'I-URL_PERSONAL',
]

print(id2label)

{0: 'B-EMAIL', 1: 'B-ID_NUM', 2: 'B-NAME_STUDENT', 3: 'B-PHONE_NUM', 4: 'B-STREET_ADDRESS', 5: 'B-URL_PERSONAL', 6: 'B-USERNAME', 7: 'I-ID_NUM', 8: 'I-NAME_STUDENT', 9: 'I-PHONE_NUM', 10: 'I-STREET_ADDRESS', 11: 'I-URL_PERSONAL', 12: 'O'}


In [4]:
def tokenize(example, tokenizer, label2id, max_length):
    """Tokenize one document and project its word-level labels onto the tokens.

    The document text is rebuilt from ``example["tokens"]`` and
    ``example["trailing_whitespace"]`` while a parallel list holding one label
    per character is filled from ``example["provided_labels"]``. Each token
    produced by ``tokenizer`` then receives the label of the first
    non-whitespace character it covers.

    Args:
        example: Mapping with the keys ``"tokens"``, ``"provided_labels"`` and
            ``"trailing_whitespace"`` (three parallel sequences).
        tokenizer: Callable accepting the text plus ``return_offsets_mapping``,
            ``max_length`` and ``truncation`` keyword arguments. Its result
            must be a mapping that also exposes ``offset_mapping`` and
            ``input_ids`` as attributes.
        label2id: Mapping from label string to integer id; must contain ``"O"``.
        max_length: Forwarded to the tokenizer as ``max_length``.

    Returns:
        A dict with every entry of the tokenizer output plus ``"labels"``
        (one label id per token) and ``"length"`` (number of input ids).
    """
    # Rebuild the text and keep one label per character alongside it.
    pieces = []
    char_labels = []
    for token, label, has_space in zip(
        example["tokens"], example["provided_labels"], example["trailing_whitespace"]
    ):
        pieces.append(token)
        char_labels.extend([label] * len(token))
        if has_space:
            pieces.append(" ")
            char_labels.append("O")
    text = "".join(pieces)

    # Ask for truncation explicitly instead of relying on whatever the
    # tokenizer does when only `max_length` is supplied.
    tokenized = tokenizer(
        text,
        return_offsets_mapping=True,
        max_length=max_length,
        truncation=True,
    )

    outside_id = label2id["O"]
    token_labels = []
    for start_idx, end_idx in tokenized.offset_mapping:
        # Tokens with an empty (0, 0) span (e.g. CLS) get the "O" label.
        if start_idx == 0 and end_idx == 0:
            token_labels.append(outside_id)
            continue
        # When the token starts with whitespace, look at the next character.
        if text[start_idx].isspace():
            start_idx += 1
        # A whitespace-only token at the very end of the text would now point
        # one past the last character; treat it as "O" instead of raising.
        if start_idx >= len(char_labels):
            token_labels.append(outside_id)
            continue
        token_labels.append(label2id[char_labels[start_idx]])

    return {**tokenized, "labels": token_labels, "length": len(tokenized.input_ids)}

In [5]:
# Tokenizer matching the checkpoint that is fine-tuned below.
tokenizer = AutoTokenizer.from_pretrained(TRAINING_MODEL_PATH)

# Turn the list of document dicts into a column-oriented Dataset and run
# `tokenize` over every row (3 worker processes) in one chained expression.
ds = Dataset.from_dict(
    {
        "full_text": [doc["full_text"] for doc in data],
        "document": [str(doc["document"]) for doc in data],
        "tokens": [doc["tokens"] for doc in data],
        "trailing_whitespace": [doc["trailing_whitespace"] for doc in data],
        "provided_labels": [doc["labels"] for doc in data],
    }
).map(
    tokenize,
    fn_kwargs=dict(
        tokenizer=tokenizer,
        label2id=label2id,
        max_length=TRAINING_MAX_LENGTH,
    ),
    num_proc=3,
)

Map (num_proc=3): 100%|██████████| 6807/6807 [00:17<00:00, 397.44 examples/s]


In [6]:
# Sanity check on the first document: compare the labelled words as provided
# with the labelled sub-word tokens produced by `tokenize`.
x = ds[0]

# Word level: every (token, label) pair whose label is not "O".
for pair in zip(x["tokens"], x["provided_labels"]):
    if pair[1] != "O":
        print(pair)

print("*"*100)

# Sub-word level: same listing after mapping label ids back to their names.
for piece, label_id in zip(tokenizer.convert_ids_to_tokens(x["input_ids"]), x["labels"]):
    label_name = id2label[label_id]
    if label_name != "O":
        print((piece, label_name))

('Nathalie', 'B-NAME_STUDENT')
('Sylla', 'I-NAME_STUDENT')
('Nathalie', 'B-NAME_STUDENT')
('Sylla', 'I-NAME_STUDENT')
('Nathalie', 'B-NAME_STUDENT')
('Sylla', 'I-NAME_STUDENT')
****************************************************************************************************
('N', 'B-NAME_STUDENT')
('atha', 'B-NAME_STUDENT')
('lie', 'B-NAME_STUDENT')
('▁S', 'I-NAME_STUDENT')
('ylla', 'I-NAME_STUDENT')
('N', 'B-NAME_STUDENT')
('atha', 'B-NAME_STUDENT')
('lie', 'B-NAME_STUDENT')
('▁S', 'I-NAME_STUDENT')
('ylla', 'I-NAME_STUDENT')
('N', 'B-NAME_STUDENT')
('atha', 'B-NAME_STUDENT')
('lie', 'B-NAME_STUDENT')
('▁S', 'I-NAME_STUDENT')
('ylla', 'I-NAME_STUDENT')


In [7]:
def compute_metrics(p, all_labels, beta=5.0):
    """Compute recall, precision and an F-beta score for token predictions.

    Args:
        p: Pair ``(predictions, labels)``. ``predictions`` is reduced with
            ``argmax`` over axis 2, so it needs at least three dimensions;
            ``labels`` holds the gold label ids, with ``-100`` marking
            positions to ignore.
        all_labels: Sequence mapping a label id to its label string.
        beta: Weight of recall relative to precision in the F-beta score.
            Defaults to 5, the value that was previously hard-coded.

    Returns:
        Dict with the keys ``'recall'``, ``'precision'`` and ``'f1'``.
        NOTE: ``'f1'`` holds the F-beta score (F5 by default), not a plain F1;
        the key is kept because ``metric_for_best_model="f1"`` refers to it.
    """
    predictions, labels = p
    predictions = np.argmax(predictions, axis=2)

    # Drop ignored positions (gold id -100) and map the remaining ids to strings.
    true_predictions = []
    true_labels = []
    for prediction, label in zip(predictions, labels):
        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(prediction, label) if gold_id != -100]
        true_predictions.append([all_labels[pred_id] for pred_id, _ in kept])
        true_labels.append([all_labels[gold_id] for _, gold_id in kept])

    recall = recall_score(true_labels, true_predictions)
    precision = precision_score(true_labels, true_predictions)

    # F-beta = (1 + beta^2) * P * R / (beta^2 * P + R). When both precision and
    # recall are 0 the denominator vanishes; report 0.0 instead of failing.
    beta_sq = beta * beta
    denominator = beta_sq * precision + recall
    if denominator > 0:
        fbeta = (1 + beta_sq) * recall * precision / denominator
    else:
        fbeta = 0.0

    return {
        'recall': recall,
        'precision': precision,
        'f1': fbeta,
    }

In [8]:
# Token-classification model loaded from the base checkpoint, with its output
# size and label maps taken from the label vocabulary built earlier.
model = AutoModelForTokenClassification.from_pretrained(
    TRAINING_MODEL_PATH,
    label2id=label2id,
    id2label=id2label,
    num_labels=len(all_labels),
    ignore_mismatched_sizes=True,
)

# Batch collator for token classification, configured with pad_to_multiple_of=16.
collator = DataCollatorForTokenClassification(
    tokenizer,
    pad_to_multiple_of=16,
)

  return self.fget.__get__(instance, owner)()
Some weights of DebertaV2ForTokenClassification were not initialized from the model checkpoint at microsoft/deberta-v3-base and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
args = TrainingArguments(
    # --- bookkeeping -------------------------------------------------------
    output_dir=OUTPUT_DIR,
    report_to="none",
    logging_steps=20,
    save_total_limit=1,
    # --- optimisation ------------------------------------------------------
    num_train_epochs=3,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,  # 4 samples x 2 accumulation steps per update
    learning_rate=2e-5,
    weight_decay=0.01,
    lr_scheduler_type='cosine',
    warmup_ratio=0.1,
    fp16=True,
    # --- evaluation --------------------------------------------------------
    # Evaluation is switched off for this run.
    # NOTE(review): with no evaluation the two "best model" settings below
    # presumably have no effect -- confirm before relying on them.
    evaluation_strategy="no",
    do_eval=False,
    metric_for_best_model="f1",
    greater_is_better=True,
)

# The full tokenized dataset is used for training; no eval dataset is passed.
trainer = Trainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=collator,
    train_dataset=ds,
    compute_metrics=partial(compute_metrics, all_labels=all_labels),
)

In [10]:
# Run fine-tuning with the Trainer configured above. As the last expression of
# the cell, the returned TrainOutput is echoed after the training log.
trainer.train()

  1%|          | 20/2553 [01:25<2:43:51,  3.88s/it]

{'loss': 1.8419, 'grad_norm': 34.67120361328125, 'learning_rate': 1.328125e-06, 'epoch': 0.02}


  2%|▏         | 40/2553 [02:54<2:15:53,  3.24s/it]

{'loss': 1.4023, 'grad_norm': 24.580671310424805, 'learning_rate': 2.8906250000000004e-06, 'epoch': 0.05}


  2%|▏         | 60/2553 [04:09<2:59:20,  4.32s/it]

{'loss': 0.4206, 'grad_norm': 3.322894811630249, 'learning_rate': 4.453125000000001e-06, 'epoch': 0.07}


  3%|▎         | 80/2553 [05:49<2:54:44,  4.24s/it]

{'loss': 0.0219, 'grad_norm': 0.10044004768133163, 'learning_rate': 6.0156250000000005e-06, 'epoch': 0.09}


  4%|▍         | 100/2553 [07:24<3:26:40,  5.06s/it]

{'loss': 0.0086, 'grad_norm': 0.04590647295117378, 'learning_rate': 7.578125e-06, 'epoch': 0.12}


  5%|▍         | 120/2553 [08:37<3:07:12,  4.62s/it]

{'loss': 0.0151, 'grad_norm': 0.10894951969385147, 'learning_rate': 9.140625e-06, 'epoch': 0.14}


  5%|▌         | 140/2553 [10:09<2:38:07,  3.93s/it]

{'loss': 0.009, 'grad_norm': 0.1373862475156784, 'learning_rate': 1.0703125000000001e-05, 'epoch': 0.16}


  6%|▋         | 160/2553 [11:27<2:47:41,  4.20s/it]

{'loss': 0.0194, 'grad_norm': 0.12569521367549896, 'learning_rate': 1.2265625000000002e-05, 'epoch': 0.19}


  7%|▋         | 180/2553 [12:57<3:40:43,  5.58s/it]

{'loss': 0.0042, 'grad_norm': 0.025305725634098053, 'learning_rate': 1.3828125e-05, 'epoch': 0.21}


  8%|▊         | 200/2553 [14:14<3:01:51,  4.64s/it]

{'loss': 0.0113, 'grad_norm': 0.08829409629106522, 'learning_rate': 1.5390625e-05, 'epoch': 0.24}


  9%|▊         | 220/2553 [15:31<2:57:16,  4.56s/it]

{'loss': 0.007, 'grad_norm': 0.04148557782173157, 'learning_rate': 1.6953125e-05, 'epoch': 0.26}


  9%|▉         | 240/2553 [16:39<1:53:33,  2.95s/it]

{'loss': 0.008, 'grad_norm': 0.5220126509666443, 'learning_rate': 1.8515625e-05, 'epoch': 0.28}


 10%|█         | 260/2553 [17:49<2:00:03,  3.14s/it]

{'loss': 0.0042, 'grad_norm': 0.07057224959135056, 'learning_rate': 1.9999990647069838e-05, 'epoch': 0.31}


 11%|█         | 280/2553 [18:57<1:49:01,  2.88s/it]

{'loss': 0.0029, 'grad_norm': 0.0602918416261673, 'learning_rate': 1.999587564069112e-05, 'epoch': 0.33}


 12%|█▏        | 300/2553 [20:14<1:48:04,  2.88s/it]

{'loss': 0.002, 'grad_norm': 0.08650904148817062, 'learning_rate': 1.99842818413438e-05, 'epoch': 0.35}


 13%|█▎        | 320/2553 [21:40<2:55:50,  4.72s/it]

{'loss': 0.0022, 'grad_norm': 0.055620379745960236, 'learning_rate': 1.9965217923367973e-05, 'epoch': 0.38}


 13%|█▎        | 340/2553 [23:11<3:06:36,  5.06s/it]

{'loss': 0.0102, 'grad_norm': 0.36773842573165894, 'learning_rate': 1.993869815015596e-05, 'epoch': 0.4}


 14%|█▍        | 360/2553 [24:38<2:38:29,  4.34s/it]

{'loss': 0.0037, 'grad_norm': 0.053256094455718994, 'learning_rate': 1.9904742363480555e-05, 'epoch': 0.42}


 15%|█▍        | 380/2553 [25:36<1:33:57,  2.59s/it]

{'loss': 0.002, 'grad_norm': 0.02086927928030491, 'learning_rate': 1.986337596864969e-05, 'epoch': 0.45}


 16%|█▌        | 400/2553 [26:52<2:18:53,  3.87s/it]

{'loss': 0.0064, 'grad_norm': 0.02077837847173214, 'learning_rate': 1.981462991549846e-05, 'epoch': 0.47}


 16%|█▋        | 420/2553 [28:07<2:38:43,  4.46s/it]

{'loss': 0.0013, 'grad_norm': 0.029213573783636093, 'learning_rate': 1.9758540675232866e-05, 'epoch': 0.49}


 17%|█▋        | 440/2553 [29:18<2:01:29,  3.45s/it]

{'loss': 0.0031, 'grad_norm': 0.006565848831087351, 'learning_rate': 1.9695150213142465e-05, 'epoch': 0.52}


 18%|█▊        | 460/2553 [30:44<2:22:14,  4.08s/it]

{'loss': 0.0044, 'grad_norm': 0.02498180791735649, 'learning_rate': 1.962450595720248e-05, 'epoch': 0.54}


 19%|█▉        | 480/2553 [32:04<2:14:19,  3.89s/it]

{'loss': 0.0021, 'grad_norm': 0.28366103768348694, 'learning_rate': 1.9546660762588737e-05, 'epoch': 0.56}


 20%|█▉        | 500/2553 [33:07<2:14:33,  3.93s/it]

{'loss': 0.0012, 'grad_norm': 0.005088325589895248, 'learning_rate': 1.9461672872132092e-05, 'epoch': 0.59}


 20%|██        | 520/2553 [34:29<2:03:44,  3.65s/it]

{'loss': 0.0027, 'grad_norm': 0.1837281435728073, 'learning_rate': 1.9369605872741822e-05, 'epoch': 0.61}


 21%|██        | 540/2553 [35:33<1:56:31,  3.47s/it]

{'loss': 0.0011, 'grad_norm': 0.0314670130610466, 'learning_rate': 1.927052864783069e-05, 'epoch': 0.63}


 22%|██▏       | 560/2553 [36:48<1:33:57,  2.83s/it]

{'loss': 0.0019, 'grad_norm': 0.009120455011725426, 'learning_rate': 1.916451532577721e-05, 'epoch': 0.66}


 23%|██▎       | 580/2553 [38:06<2:32:41,  4.64s/it]

{'loss': 0.0027, 'grad_norm': 0.0036727278493344784, 'learning_rate': 1.9051645224463672e-05, 'epoch': 0.68}


 24%|██▎       | 600/2553 [39:24<2:09:56,  3.99s/it]

{'loss': 0.001, 'grad_norm': 0.00715648103505373, 'learning_rate': 1.8932002791931502e-05, 'epoch': 0.71}


 24%|██▍       | 620/2553 [40:43<2:13:52,  4.16s/it]

{'loss': 0.0034, 'grad_norm': 0.007983054965734482, 'learning_rate': 1.8805677543198213e-05, 'epoch': 0.73}


 25%|██▌       | 640/2553 [41:56<1:43:50,  3.26s/it]

{'loss': 0.0014, 'grad_norm': 0.027531830593943596, 'learning_rate': 1.86727639932834e-05, 'epoch': 0.75}


 26%|██▌       | 660/2553 [43:15<2:00:41,  3.83s/it]

{'loss': 0.0013, 'grad_norm': 0.023379599675536156, 'learning_rate': 1.853336158649373e-05, 'epoch': 0.78}


 27%|██▋       | 680/2553 [44:34<1:59:47,  3.84s/it]

{'loss': 0.0019, 'grad_norm': 0.08342902362346649, 'learning_rate': 1.838757462201989e-05, 'epoch': 0.8}


 27%|██▋       | 700/2553 [45:56<1:53:13,  3.67s/it]

{'loss': 0.0025, 'grad_norm': 0.02872801385819912, 'learning_rate': 1.8235512175901253e-05, 'epoch': 0.82}


 28%|██▊       | 720/2553 [47:30<2:17:52,  4.51s/it]

{'loss': 0.0014, 'grad_norm': 0.0038250323850661516, 'learning_rate': 1.8077288019416463e-05, 'epoch': 0.85}


 29%|██▉       | 740/2553 [49:05<2:36:37,  5.18s/it]

{'loss': 0.0028, 'grad_norm': 0.01003308966755867, 'learning_rate': 1.7913020533961155e-05, 'epoch': 0.87}


 30%|██▉       | 760/2553 [50:55<2:37:56,  5.29s/it]

{'loss': 0.0024, 'grad_norm': 0.011119973845779896, 'learning_rate': 1.774283262247644e-05, 'epoch': 0.89}


 31%|███       | 780/2553 [52:35<2:12:55,  4.50s/it]

{'loss': 0.0013, 'grad_norm': 0.006767550483345985, 'learning_rate': 1.7566851617494403e-05, 'epoch': 0.92}


 31%|███▏      | 800/2553 [54:14<1:50:49,  3.79s/it]

{'loss': 0.0023, 'grad_norm': 0.018628723919391632, 'learning_rate': 1.7385209185869456e-05, 'epoch': 0.94}


 32%|███▏      | 820/2553 [55:24<1:37:29,  3.38s/it]

{'loss': 0.0018, 'grad_norm': 0.02739383652806282, 'learning_rate': 1.7198041230266783e-05, 'epoch': 0.96}


 33%|███▎      | 840/2553 [56:54<2:35:07,  5.43s/it]

{'loss': 0.0019, 'grad_norm': 0.008370666764676571, 'learning_rate': 1.700548778748162e-05, 'epoch': 0.99}


 34%|███▎      | 860/2553 [58:02<1:43:16,  3.66s/it]

{'loss': 0.0008, 'grad_norm': 0.010614277794957161, 'learning_rate': 1.6807692923665424e-05, 'epoch': 1.01}


 34%|███▍      | 880/2553 [59:19<1:39:50,  3.58s/it]

{'loss': 0.0013, 'grad_norm': 0.03094007447361946, 'learning_rate': 1.660480462653732e-05, 'epoch': 1.03}


 35%|███▌      | 900/2553 [1:00:27<1:19:42,  2.89s/it]

{'loss': 0.0025, 'grad_norm': 0.12118067592382431, 'learning_rate': 1.6396974694661493e-05, 'epoch': 1.06}


 36%|███▌      | 920/2553 [1:01:35<1:46:16,  3.91s/it]

{'loss': 0.0003, 'grad_norm': 0.0025812131352722645, 'learning_rate': 1.618435862387332e-05, 'epoch': 1.08}


 37%|███▋      | 940/2553 [1:03:11<1:57:52,  4.38s/it]

{'loss': 0.0017, 'grad_norm': 0.09015757590532303, 'learning_rate': 1.596711549093931e-05, 'epoch': 1.1}


 38%|███▊      | 960/2553 [1:05:08<2:37:44,  5.94s/it]

{'loss': 0.0017, 'grad_norm': 0.009390764869749546, 'learning_rate': 1.574540783453775e-05, 'epoch': 1.13}


 38%|███▊      | 980/2553 [1:06:27<1:09:47,  2.66s/it]

{'loss': 0.0007, 'grad_norm': 0.004280119203031063, 'learning_rate': 1.5519401533649275e-05, 'epoch': 1.15}


 39%|███▉      | 1000/2553 [1:07:49<1:13:17,  2.83s/it]

{'loss': 0.0008, 'grad_norm': 0.0014742824714630842, 'learning_rate': 1.528926568344821e-05, 'epoch': 1.18}


 40%|███▉      | 1020/2553 [1:09:20<2:01:58,  4.77s/it]

{'loss': 0.0006, 'grad_norm': 0.05127005651593208, 'learning_rate': 1.5055172468787597e-05, 'epoch': 1.2}


 41%|████      | 1040/2553 [1:10:44<2:11:15,  5.21s/it]

{'loss': 0.0011, 'grad_norm': 0.0043853651732206345, 'learning_rate': 1.4817297035372602e-05, 'epoch': 1.22}


 42%|████▏     | 1060/2553 [1:12:19<2:00:00,  4.82s/it]

{'loss': 0.0007, 'grad_norm': 0.017531612887978554, 'learning_rate': 1.4575817358718609e-05, 'epoch': 1.25}


 42%|████▏     | 1080/2553 [1:13:57<1:51:28,  4.54s/it]

{'loss': 0.0006, 'grad_norm': 0.059839878231287, 'learning_rate': 1.433091411099209e-05, 'epoch': 1.27}


 43%|████▎     | 1100/2553 [1:15:25<1:48:26,  4.48s/it]

{'loss': 0.0003, 'grad_norm': 0.05952679365873337, 'learning_rate': 1.408277052583389e-05, 'epoch': 1.29}


 44%|████▍     | 1120/2553 [1:16:40<1:28:44,  3.72s/it]

{'loss': 0.0017, 'grad_norm': 0.1452808529138565, 'learning_rate': 1.3831572261266035e-05, 'epoch': 1.32}


 45%|████▍     | 1140/2553 [1:17:53<1:41:11,  4.30s/it]

{'loss': 0.0006, 'grad_norm': 0.0027944978792220354, 'learning_rate': 1.3577507260784662e-05, 'epoch': 1.34}


 45%|████▌     | 1160/2553 [1:19:03<1:33:32,  4.03s/it]

{'loss': 0.0007, 'grad_norm': 0.003534047631546855, 'learning_rate': 1.3320765612742945e-05, 'epoch': 1.36}


 46%|████▌     | 1180/2553 [1:20:19<1:18:35,  3.43s/it]

{'loss': 0.0003, 'grad_norm': 0.002823177957907319, 'learning_rate': 1.3061539408129331e-05, 'epoch': 1.39}


 47%|████▋     | 1200/2553 [1:21:44<1:39:33,  4.41s/it]

{'loss': 0.0006, 'grad_norm': 0.006681990809738636, 'learning_rate': 1.2800022596847356e-05, 'epoch': 1.41}


 48%|████▊     | 1220/2553 [1:23:01<1:42:32,  4.62s/it]

{'loss': 0.0012, 'grad_norm': 0.3547970652580261, 'learning_rate': 1.2536410842604703e-05, 'epoch': 1.43}


 49%|████▊     | 1240/2553 [1:24:18<1:30:05,  4.12s/it]

{'loss': 0.0017, 'grad_norm': 0.002558843931183219, 'learning_rate': 1.2270901376519988e-05, 'epoch': 1.46}


 49%|████▉     | 1260/2553 [1:25:34<1:17:33,  3.60s/it]

{'loss': 0.0004, 'grad_norm': 0.012622611597180367, 'learning_rate': 1.20036928495568e-05, 'epoch': 1.48}


 50%|█████     | 1280/2553 [1:26:52<1:26:07,  4.06s/it]

{'loss': 0.0007, 'grad_norm': 0.1525263637304306, 'learning_rate': 1.1734985183895496e-05, 'epoch': 1.5}


 51%|█████     | 1300/2553 [1:28:08<1:07:57,  3.25s/it]

{'loss': 0.0021, 'grad_norm': 0.002412284491583705, 'learning_rate': 1.1464979423353828e-05, 'epoch': 1.53}


 52%|█████▏    | 1320/2553 [1:29:23<1:18:33,  3.82s/it]

{'loss': 0.0007, 'grad_norm': 0.0021953044924885035, 'learning_rate': 1.1193877582968393e-05, 'epoch': 1.55}


 52%|█████▏    | 1340/2553 [1:30:45<1:31:25,  4.52s/it]

{'loss': 0.0024, 'grad_norm': 0.02655080519616604, 'learning_rate': 1.0921882497849448e-05, 'epoch': 1.57}


 53%|█████▎    | 1360/2553 [1:32:17<1:38:22,  4.95s/it]

{'loss': 0.0014, 'grad_norm': 0.03975345939397812, 'learning_rate': 1.0649197671422122e-05, 'epoch': 1.6}


 54%|█████▍    | 1380/2553 [1:33:44<1:29:04,  4.56s/it]

{'loss': 0.0006, 'grad_norm': 0.018507815897464752, 'learning_rate': 1.0376027123167636e-05, 'epoch': 1.62}


 55%|█████▍    | 1400/2553 [1:34:55<1:10:15,  3.66s/it]

{'loss': 0.0006, 'grad_norm': 0.0012540126917883754, 'learning_rate': 1.0102575235978389e-05, 'epoch': 1.65}


 56%|█████▌    | 1420/2553 [1:36:17<1:12:15,  3.83s/it]

{'loss': 0.0003, 'grad_norm': 0.0012486637569963932, 'learning_rate': 9.82904660324116e-06, 'epoch': 1.67}


 56%|█████▋    | 1440/2553 [1:37:37<1:12:45,  3.92s/it]

{'loss': 0.0019, 'grad_norm': 0.07697788625955582, 'learning_rate': 9.555645875762823e-06, 'epoch': 1.69}


 57%|█████▋    | 1460/2553 [1:38:55<1:37:56,  5.38s/it]

{'loss': 0.0031, 'grad_norm': 0.004302347544580698, 'learning_rate': 9.282577608653077e-06, 'epoch': 1.72}


 58%|█████▊    | 1480/2553 [1:40:27<1:06:56,  3.74s/it]

{'loss': 0.0012, 'grad_norm': 0.007997998967766762, 'learning_rate': 9.010046108278813e-06, 'epoch': 1.74}


 59%|█████▉    | 1500/2553 [1:41:56<1:10:04,  3.99s/it]

{'loss': 0.0009, 'grad_norm': 0.01588195562362671, 'learning_rate': 8.738255279404556e-06, 'epoch': 1.76}


 60%|█████▉    | 1520/2553 [1:43:17<1:04:45,  3.76s/it]

{'loss': 0.0007, 'grad_norm': 0.0030946696642786264, 'learning_rate': 8.467408472633387e-06, 'epoch': 1.79}


 60%|██████    | 1540/2553 [1:44:36<59:59,  3.55s/it]  

{'loss': 0.0004, 'grad_norm': 0.0017630235524848104, 'learning_rate': 8.197708332262507e-06, 'epoch': 1.81}


 61%|██████    | 1560/2553 [1:46:10<1:20:46,  4.88s/it]

{'loss': 0.0012, 'grad_norm': 0.001720304018817842, 'learning_rate': 7.929356644667212e-06, 'epoch': 1.83}


 62%|██████▏   | 1580/2553 [1:47:44<1:13:53,  4.56s/it]

{'loss': 0.001, 'grad_norm': 0.047721799463033676, 'learning_rate': 7.66255418732677e-06, 'epoch': 1.86}


 63%|██████▎   | 1600/2553 [1:49:26<1:19:19,  4.99s/it]

{'loss': 0.0006, 'grad_norm': 0.05706571415066719, 'learning_rate': 7.39750057860517e-06, 'epoch': 1.88}


 63%|██████▎   | 1620/2553 [1:51:07<1:01:07,  3.93s/it]

{'loss': 0.0018, 'grad_norm': 0.009366340935230255, 'learning_rate': 7.134394128399061e-06, 'epoch': 1.9}


 64%|██████▍   | 1640/2553 [1:52:32<52:08,  3.43s/it]  

{'loss': 0.0003, 'grad_norm': 0.0013648731401190162, 'learning_rate': 6.873431689764696e-06, 'epoch': 1.93}


 65%|██████▌   | 1660/2553 [1:53:55<46:48,  3.14s/it]  

{'loss': 0.0016, 'grad_norm': 0.0016624890267848969, 'learning_rate': 6.6148085116348595e-06, 'epoch': 1.95}


 66%|██████▌   | 1680/2553 [1:55:29<59:35,  4.10s/it]  

{'loss': 0.0006, 'grad_norm': 0.08644425123929977, 'learning_rate': 6.358718092735992e-06, 'epoch': 1.97}


 67%|██████▋   | 1700/2553 [1:56:45<49:26,  3.48s/it]  

{'loss': 0.0003, 'grad_norm': 0.02995803765952587, 'learning_rate': 6.1053520368147865e-06, 'epoch': 2.0}


 67%|██████▋   | 1720/2553 [1:58:16<58:06,  4.19s/it]  

{'loss': 0.0005, 'grad_norm': 0.027580013498663902, 'learning_rate': 5.8548999092825745e-06, 'epoch': 2.02}


 68%|██████▊   | 1740/2553 [1:59:36<58:19,  4.30s/it]  

{'loss': 0.0004, 'grad_norm': 0.0015521232271566987, 'learning_rate': 5.607549095384805e-06, 'epoch': 2.04}


 69%|██████▉   | 1760/2553 [2:00:46<54:34,  4.13s/it]  

{'loss': 0.0007, 'grad_norm': 0.058435048907995224, 'learning_rate': 5.363484660001659e-06, 'epoch': 2.07}


 70%|██████▉   | 1780/2553 [2:01:58<39:11,  3.04s/it]

{'loss': 0.0009, 'grad_norm': 0.031119346618652344, 'learning_rate': 5.12288920918475e-06, 'epoch': 2.09}


 71%|███████   | 1800/2553 [2:03:17<50:10,  4.00s/it]  

{'loss': 0.0008, 'grad_norm': 0.1845477819442749, 'learning_rate': 4.885942753533516e-06, 'epoch': 2.12}


 71%|███████▏  | 1820/2553 [2:04:33<52:02,  4.26s/it]

{'loss': 0.0002, 'grad_norm': 0.0020683316979557276, 'learning_rate': 4.652822573513422e-06, 'epoch': 2.14}


 72%|███████▏  | 1840/2553 [2:05:41<48:09,  4.05s/it]

{'loss': 0.0005, 'grad_norm': 0.00274756271392107, 'learning_rate': 4.4237030868169185e-06, 'epoch': 2.16}


 73%|███████▎  | 1860/2553 [2:06:59<51:04,  4.42s/it]

{'loss': 0.0002, 'grad_norm': 0.004180385731160641, 'learning_rate': 4.198755717866205e-06, 'epoch': 2.19}


 74%|███████▎  | 1880/2553 [2:08:12<49:10,  4.38s/it]

{'loss': 0.0003, 'grad_norm': 0.004487969446927309, 'learning_rate': 3.978148769555553e-06, 'epoch': 2.21}


 74%|███████▍  | 1900/2553 [2:09:30<31:43,  2.92s/it]

{'loss': 0.0003, 'grad_norm': 0.015574700199067593, 'learning_rate': 3.7620472973291567e-06, 'epoch': 2.23}


 75%|███████▌  | 1920/2553 [2:10:40<31:56,  3.03s/it]

{'loss': 0.0008, 'grad_norm': 0.0010355972917750478, 'learning_rate': 3.5506129856886064e-06, 'epoch': 2.26}


 76%|███████▌  | 1940/2553 [2:11:57<43:25,  4.25s/it]

{'loss': 0.0009, 'grad_norm': 0.004691922105848789, 'learning_rate': 3.344004027222533e-06, 'epoch': 2.28}


 77%|███████▋  | 1960/2553 [2:13:06<22:10,  2.24s/it]

{'loss': 0.0005, 'grad_norm': 0.12228234112262726, 'learning_rate': 3.142375004248822e-06, 'epoch': 2.3}


 78%|███████▊  | 1980/2553 [2:14:19<27:25,  2.87s/it]

{'loss': 0.0009, 'grad_norm': 0.017235729843378067, 'learning_rate': 2.9458767731579974e-06, 'epoch': 2.33}


 78%|███████▊  | 2000/2553 [2:15:35<36:05,  3.92s/it]

{'loss': 0.0002, 'grad_norm': 0.023902123793959618, 'learning_rate': 2.7546563515442994e-06, 'epoch': 2.35}


 79%|███████▉  | 2020/2553 [2:16:55<27:40,  3.11s/it]

{'loss': 0.0001, 'grad_norm': 0.0009022831218317151, 'learning_rate': 2.568856808208905e-06, 'epoch': 2.37}


 80%|███████▉  | 2040/2553 [2:18:13<36:37,  4.28s/it]

{'loss': 0.0004, 'grad_norm': 0.025068266317248344, 'learning_rate': 2.388617156117583e-06, 'epoch': 2.4}


 81%|████████  | 2060/2553 [2:19:32<37:10,  4.52s/it]

{'loss': 0.0001, 'grad_norm': 0.003175614634528756, 'learning_rate': 2.214072248392879e-06, 'epoch': 2.42}


 81%|████████▏ | 2080/2553 [2:20:52<34:40,  4.40s/it]

{'loss': 0.0005, 'grad_norm': 0.004155599046498537, 'learning_rate': 2.0453526774186415e-06, 'epoch': 2.44}


 82%|████████▏ | 2100/2553 [2:22:01<23:45,  3.15s/it]

{'loss': 0.0004, 'grad_norm': 0.05457166209816933, 'learning_rate': 1.8825846771323963e-06, 'epoch': 2.47}


 83%|████████▎ | 2120/2553 [2:23:05<24:36,  3.41s/it]

{'loss': 0.0004, 'grad_norm': 0.0011616392293944955, 'learning_rate': 1.7258900285786206e-06, 'epoch': 2.49}


 84%|████████▍ | 2140/2553 [2:24:20<23:34,  3.42s/it]

{'loss': 0.0005, 'grad_norm': 0.0009765063878148794, 'learning_rate': 1.5753859687936624e-06, 'epoch': 2.51}


 85%|████████▍ | 2160/2553 [2:25:40<24:45,  3.78s/it]

{'loss': 0.0003, 'grad_norm': 0.009198080748319626, 'learning_rate': 1.4311851030904123e-06, 'epoch': 2.54}


 85%|████████▌ | 2180/2553 [2:27:12<28:39,  4.61s/it]

{'loss': 0.0002, 'grad_norm': 0.0020256510470062494, 'learning_rate': 1.2933953208083527e-06, 'epoch': 2.56}


 86%|████████▌ | 2200/2553 [2:28:35<20:37,  3.51s/it]

{'loss': 0.0009, 'grad_norm': 0.0009522838518023491, 'learning_rate': 1.1621197145920826e-06, 'epoch': 2.59}


 87%|████████▋ | 2220/2553 [2:30:05<27:50,  5.02s/it]

{'loss': 0.0006, 'grad_norm': 0.010744785889983177, 'learning_rate': 1.0374565032586503e-06, 'epoch': 2.61}


 88%|████████▊ | 2240/2553 [2:31:34<21:16,  4.08s/it]

{'loss': 0.0007, 'grad_norm': 0.0009587712702341378, 'learning_rate': 9.194989583114012e-07, 'epoch': 2.63}


 89%|████████▊ | 2260/2553 [2:32:47<22:36,  4.63s/it]

{'loss': 0.0002, 'grad_norm': 0.0008816990884952247, 'learning_rate': 8.083353341554057e-07, 'epoch': 2.66}


 89%|████████▉ | 2280/2553 [2:34:02<18:52,  4.15s/it]

{'loss': 0.0003, 'grad_norm': 0.008161123842000961, 'learning_rate': 7.040488020665559e-07, 'epoch': 2.68}


 90%|█████████ | 2300/2553 [2:35:25<18:20,  4.35s/it]

{'loss': 0.0002, 'grad_norm': 0.0009013846865855157, 'learning_rate': 6.067173879638299e-07, 'epoch': 2.7}


 91%|█████████ | 2320/2553 [2:36:46<14:41,  3.78s/it]

{'loss': 0.0002, 'grad_norm': 0.0009755913633853197, 'learning_rate': 5.164139140312597e-07, 'epoch': 2.73}


 92%|█████████▏| 2340/2553 [2:38:06<16:11,  4.56s/it]

{'loss': 0.0014, 'grad_norm': 0.012593158520758152, 'learning_rate': 4.3320594423323815e-07, 'epoch': 2.75}


 92%|█████████▏| 2360/2553 [2:39:23<09:59,  3.11s/it]

{'loss': 0.0007, 'grad_norm': 0.0009416328393854201, 'learning_rate': 3.5715573376399194e-07, 'epoch': 2.77}


 93%|█████████▎| 2380/2553 [2:40:37<13:29,  4.68s/it]

{'loss': 0.0017, 'grad_norm': 0.10463856905698776, 'learning_rate': 2.8832018246899076e-07, 'epoch': 2.8}


 94%|█████████▍| 2400/2553 [2:41:53<11:00,  4.32s/it]

{'loss': 0.0003, 'grad_norm': 0.0008544187876395881, 'learning_rate': 2.2675079227318375e-07, 'epoch': 2.82}


 95%|█████████▍| 2420/2553 [2:43:10<08:06,  3.66s/it]

{'loss': 0.0003, 'grad_norm': 0.0023097118828445673, 'learning_rate': 1.7249362864787177e-07, 'epoch': 2.84}


 96%|█████████▌| 2440/2553 [2:44:34<09:19,  4.95s/it]

{'loss': 0.0004, 'grad_norm': 0.0750170350074768, 'learning_rate': 1.2558928614508915e-07, 'epoch': 2.87}


 96%|█████████▋| 2460/2553 [2:45:53<05:02,  3.25s/it]

{'loss': 0.0003, 'grad_norm': 0.015982935205101967, 'learning_rate': 8.6072858025249e-08, 'epoch': 2.89}


 97%|█████████▋| 2480/2553 [2:47:19<05:11,  4.27s/it]

{'loss': 0.0007, 'grad_norm': 0.0008800806826911867, 'learning_rate': 5.397391000078966e-08, 'epoch': 2.91}


 98%|█████████▊| 2500/2553 [2:48:23<02:37,  2.98s/it]

{'loss': 0.0004, 'grad_norm': 0.0036601191386580467, 'learning_rate': 2.931645811546924e-08, 'epoch': 2.94}


 99%|█████████▊| 2520/2553 [2:49:48<01:55,  3.51s/it]

{'loss': 0.0005, 'grad_norm': 0.0008605211041867733, 'learning_rate': 1.2118950775851235e-08, 'epoch': 2.96}


 99%|█████████▉| 2540/2553 [2:51:12<01:01,  4.76s/it]

{'loss': 0.0006, 'grad_norm': 0.0012916140258312225, 'learning_rate': 2.3942549484312627e-09, 'epoch': 2.98}


100%|██████████| 2553/2553 [2:51:53<00:00,  4.04s/it]

{'train_runtime': 10313.0764, 'train_samples_per_second': 1.98, 'train_steps_per_second': 0.248, 'train_loss': 0.030670080014204994, 'epoch': 3.0}





TrainOutput(global_step=2553, training_loss=0.030670080014204994, metrics={'train_runtime': 10313.0764, 'train_samples_per_second': 1.98, 'train_steps_per_second': 0.248, 'train_loss': 0.030670080014204994, 'epoch': 3.0})