<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc"><ul class="toc-item"><li><span><a href="#Fine-tuning-a-deep-learning-model-with-HuggingFace-🤗Datasets-library-and-PyTorch" data-toc-modified-id="Fine-tuning-a-deep-learning-model-with-HuggingFace-🤗Datasets-library-and-PyTorch-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Fine-tuning a deep-learning model with HuggingFace <code>🤗Datasets</code> library and PyTorch</a></span><ul class="toc-item"><li><ul class="toc-item"><li><span><a href="#Load-the-model-and-the-tokenizer" data-toc-modified-id="Load-the-model-and-the-tokenizer-1.0.1"><span class="toc-item-num">1.0.1&nbsp;&nbsp;</span>Load the model and the tokenizer</a></span></li><li><span><a href="#Load-and-process-the-dataset" data-toc-modified-id="Load-and-process-the-dataset-1.0.2"><span class="toc-item-num">1.0.2&nbsp;&nbsp;</span>Load and process the dataset</a></span><ul class="toc-item"><li><span><a href="#Tokenizing-the-dataset" data-toc-modified-id="Tokenizing-the-dataset-1.0.2.1"><span class="toc-item-num">1.0.2.1&nbsp;&nbsp;</span>Tokenizing the dataset</a></span></li><li><span><a href="#Formatting-the-dataset" data-toc-modified-id="Formatting-the-dataset-1.0.2.2"><span class="toc-item-num">1.0.2.2&nbsp;&nbsp;</span>Formatting the dataset</a></span></li><li><span><a href="#Training" data-toc-modified-id="Training-1.0.2.3"><span class="toc-item-num">1.0.2.3&nbsp;&nbsp;</span>Training</a></span></li></ul></li></ul></li></ul></li></ul></div>

# Fine-tuning a deep-learning model with HuggingFace `🤗Datasets` library and PyTorch

Based on the 🤗Datasets [Quick tour example](https://huggingface.co/docs/datasets/quicktour.html).

Colab author: Manuel Romero / [@mrm8488](https://twitter.com/mrm8488)

In [3]:
# Require pyarrow >= 0.16. If an older version is loaded in this session, kill the process so that Colab restarts the runtime and picks up the newer one.
import pyarrow
if int(pyarrow.__version__.split('.')[1]) < 16 and int(pyarrow.__version__.split('.')[0]) == 0:
    import os
    os.kill(os.getpid(), 9)
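The check above compares the major and minor components separately. A minimal sketch of the same test written as a tuple comparison follows. The helper names `parse_major_minor` and `needs_restart` are hypothetical (they are not part of pyarrow), and the `(0, 16)` threshold is taken from the cell above.

```python
def parse_major_minor(version):
    """Return (major, minor) as ints from a version string such as '0.15.1'."""
    major, minor = version.split('.')[:2]
    return int(major), int(minor)


def needs_restart(version, minimum=(0, 16)):
    """True when `version` is older than `minimum` (tuples compare element by element)."""
    return parse_major_minor(version) < minimum


# Sample version strings, for illustration only
for v in ['0.15.1', '0.16.0', '1.0.1']:
    print(v, needs_restart(v))
```

In the notebook this would be called as `needs_restart(pyarrow.__version__)` before deciding whether to kill the process.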

### Load the model and the tokenizer

In [4]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

### Load and process the dataset

In [5]:
from datasets import load_dataset
dataset = load_dataset('glue', 'mrpc', split='train')

Reusing dataset glue (/home/quan/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4)


In [7]:
len(dataset)

3668

In [9]:
dataset[0]

{'idx': 0,
 'label': 1,
 'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'}

In [10]:
dataset.features

{'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None),
 'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None),
 'idx': Value(dtype='int32', id=None)}

In [None]:
# dataset.filter(lambda example: example['label'] == dataset.features['label'].str2int('equivalent'))[0]

In [None]:
# dataset.filter(lambda example: example['label'] == dataset.features['label'].str2int('not_equivalent'))[0]
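The two commented-out cells filter by label name. As a rough illustration of what that lookup amounts to, the sketch below mimics the name-to-id mapping with a plain list and filters a few rows. It is not the `ClassLabel` implementation; the `names` order is copied from `dataset.features` above, and every row except the first is made up for the example.

```python
# Label names in the order reported by dataset.features['label']
names = ['not_equivalent', 'equivalent']


def str2int(name):
    """Map a label name to its integer id (position in `names`)."""
    return names.index(name)


def int2str(label_id):
    """Map an integer id back to its label name."""
    return names[label_id]


rows = [
    {'idx': 0, 'label': 1},  # label of dataset[0], as printed above
    {'idx': 1, 'label': 0},  # hypothetical row
    {'idx': 2, 'label': 1},  # hypothetical row
]

# Plain-Python equivalent of filtering on the 'equivalent' class
equivalent_rows = [r for r in rows if r['label'] == str2int('equivalent')]
print(int2str(rows[0]['label']), [r['idx'] for r in equivalent_rows])
```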

#### Tokenizing the dataset

In [11]:
def encode(examples):
    return tokenizer(examples['sentence1'], examples['sentence2'], truncation=True, padding='max_length')
dataset = dataset.map(encode, batched=True)


Loading cached processed dataset at /home/quan/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4/cache-ecb3eae944990a41.arrow


In [12]:
list(dataset[0].keys())

['attention_mask',
 'idx',
 'input_ids',
 'label',
 'sentence1',
 'sentence2',
 'token_type_ids']

In [26]:
len(dataset[0]['input_ids']),len(dataset[0]['attention_mask']),len(dataset[0]['token_type_ids'])

(512, 512, 512)
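All three lists have length 512 because `padding='max_length'` pads every example up to the maximum length the tokenizer reports for this model. A minimal sketch of that right-padding step, assuming the `[PAD]` id is 0 as in the BERT vocabulary; `pad_to_length` is a hypothetical helper, not a tokenizer method, and truncation is deliberately not modelled.

```python
PAD_ID = 0  # id of the [PAD] token in the BERT vocabulary


def pad_to_length(ids, length, pad_id=PAD_ID):
    """Right-pad a list of token ids with `pad_id` until it has `length` entries."""
    if len(ids) > length:
        raise ValueError("sequence is longer than the target length; truncation is not modelled here")
    return ids + [pad_id] * (length - len(ids))


short = [101, 7277, 2180, 102]  # a handful of ids, for illustration only
padded = pad_to_length(short, 8)
print(padded)
print(len(pad_to_length(short, 512)))
```

Padding to a fixed length keeps every batch the same shape, at the cost of spending compute on padding positions for short sentence pairs.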

In [34]:
print(dataset[0]['input_ids'])
print(dataset[0]['input_ids'].index(0))

[101, 7277, 2180, 5303, 4806, 1117, 1711, 117, 2292, 1119, 1270, 107, 1103, 7737, 107, 117, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102, 11336, 6732, 3384, 1106, 1140, 1112, 1178, 107, 1103, 7737, 107, 117, 7277, 2180, 5303, 4806, 1117, 1711, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

In [32]:
print([tokenizer.decode(i) for i in dataset[0]['input_ids']])

['[CLS]', 'Am', '##ro', '##zi', 'accused', 'his', 'brother', ',', 'whom', 'he', 'called', '"', 'the', 'witness', '"', ',', 'of', 'deliberately', 'di', '##sto', '##rting', 'his', 'evidence', '.', '[SEP]', 'Re', '##fer', '##ring', 'to', 'him', 'as', 'only', '"', 'the', 'witness', '"', ',', 'Am', '##ro', '##zi', 'accused', 'his', 'brother', 'of', 'deliberately', 'di', '##sto', '##rting', 'his', 'evidence', '.', '[SEP]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[PAD]', '[P

In [38]:
print(dataset[0]['token_type_ids']) 
print(dataset[0]['token_type_ids'].index(1)) # index of the first token of the 2nd sentence (token type 1)

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 

In [35]:
print(dataset[0]['attention_mask'])
print(dataset[0]['attention_mask'].index(0))

[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
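To make the relationship between `input_ids`, `token_type_ids` and `attention_mask` concrete, the sketch below rebuilds the two masks from the ids printed above. It is an illustration under stated assumptions, not the tokenizer's own logic: it assumes `[SEP]` has id 102, `[PAD]` has id 0, and that no real token shares the padding id. `derive_masks` is a hypothetical helper.

```python
SEP_ID, PAD_ID = 102, 0

# The non-padding ids of dataset[0]['input_ids'], copied from the output above
real_ids = [
    101, 7277, 2180, 5303, 4806, 1117, 1711, 117, 2292, 1119, 1270, 107, 1103,
    7737, 107, 117, 1104, 9938, 4267, 12223, 21811, 1117, 2554, 119, 102,
    11336, 6732, 3384, 1106, 1140, 1112, 1178, 107, 1103, 7737, 107, 117, 7277,
    2180, 5303, 4806, 1117, 1711, 1104, 9938, 4267, 12223, 21811, 1117, 2554,
    119, 102,
]
input_ids = real_ids + [PAD_ID] * (512 - len(real_ids))


def derive_masks(ids, sep_id=SEP_ID, pad_id=PAD_ID):
    """Rebuild attention_mask and token_type_ids from padded sentence-pair ids."""
    attention_mask = [0 if t == pad_id else 1 for t in ids]
    token_type_ids = []
    seen_first_sep = False
    for t in ids:
        if t == pad_id:
            token_type_ids.append(0)  # padding positions get type 0
            continue
        # tokens after the first [SEP] belong to the second sentence
        token_type_ids.append(1 if seen_first_sep else 0)
        if t == sep_id:
            seen_first_sep = True
    return attention_mask, token_type_ids


attention_mask, token_type_ids = derive_masks(input_ids)
print(attention_mask.index(0))  # first padding position
print(token_type_ids.index(1))  # first token of the second sentence
print(sum(attention_mask), sum(token_type_ids))  # real tokens, second-sentence tokens
```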

In [21]:
print(dataset[0]['sentence1'])

Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .


In [22]:
print(dataset[0]['sentence2'])

Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .




#### Formatting the dataset

In [42]:
dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True)





In [43]:
dataset[0]['label'],dataset[0]['labels']

(1, 1)

In [44]:
import torch
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
next(iter(dataloader))



{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  101,  7277,  2180,  ...,     0,     0,     0],
         [  101, 10684,  2599,  ...,     0,     0,     0],
         [  101,  1220,  1125,  ...,     0,     0,     0],
         ...,
         [  101, 16944,  1107,  ...,     0,     0,     0],
         [  101,  1109, 11896,  ...,     0,     0,     0],
         [  101,  1109,  4173,  ...,     0,     0,     0]]),
 'labels': tensor([1, 0, 1, 0, 1, 1, 0, 1]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]])}

In [47]:
tmp = next(iter(dataloader))

In [50]:
len(tmp) # 4 keys

4

In [51]:
tmp['attention_mask'].shape # bs = 8

torch.Size([8, 512])
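With 3668 examples and a batch size of 8, the number of batches per epoch follows from a ceiling division. The short sketch below works this out, assuming the `DataLoader` keeps the final partial batch (its default behaviour when `drop_last` is not set).

```python
import math

num_examples = 3668  # len(dataset), as reported earlier
batch_size = 8

# The last batch is kept even when it is smaller than batch_size
num_batches = math.ceil(num_examples / batch_size)
last_batch_size = num_examples - (num_batches - 1) * batch_size
print(num_batches, last_batch_size)
```

The batch count is the total that a progress bar wrapped around `dataloader` would report for one epoch.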

#### Training

In [45]:
from tqdm import tqdm
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

cuda


In [46]:
model.train().to(device)

# Pass all model parameters to AdamW, so the whole network (not only the classification head) is fine-tuned
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)

In [52]:
type(model)

transformers.models.bert.modeling_bert.BertForSequenceClassification

In [55]:
tmp_cuda =  {k: v.to(device) for k, v in tmp.items()}

tmp_outp = model(**tmp_cuda)

In [56]:
len(tmp_outp) # the model output can be indexed like a tuple; here it holds 2 elements (loss and logits)


2

In [59]:
# In this case the first element is the loss: the cross-entropy (negative log-likelihood) between the logits and `labels`.
# The sequence-classification model computes it because the batch passed to it includes `labels`.
tmp_outp[0]

tensor(0.7886, device='cuda:0', grad_fn=<NllLossBackward>)

In [63]:
tmp_outp[1].shape

torch.Size([8, 2])

In [64]:
tmp_outp[1] # the 2nd element holds the logits of the classification head (pre-softmax scores), one row per example

tensor([[ 0.3832,  0.0596],
        [ 0.2036, -0.1933],
        [ 0.1324, -0.3716],
        [ 0.4472,  0.1715],
        [ 0.2803, -0.1362],
        [ 0.4707,  0.0265],
        [ 0.4794, -0.0466],
        [ 0.3093, -0.3229]], device='cuda:0', grad_fn=<AddmmBackward>)
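As a sanity check on how the loss relates to these logits, the sketch below recomputes the mean cross-entropy in plain Python and derives the predicted classes. The logits are copied from the output above (rounded to four decimals) and the labels from the first batch printed in the formatting section; it assumes `tmp` is that same first batch, which holds as long as the `DataLoader` does not shuffle.

```python
import math

# Logits copied from tmp_outp[1]; labels copied from the first batch shown earlier
logits = [
    [0.3832, 0.0596],
    [0.2036, -0.1933],
    [0.1324, -0.3716],
    [0.4472, 0.1715],
    [0.2803, -0.1362],
    [0.4707, 0.0265],
    [0.4794, -0.0466],
    [0.3093, -0.3229],
]
labels = [1, 0, 1, 0, 1, 1, 0, 1]


def log_softmax(row):
    """Numerically stable log-softmax of one row of logits."""
    m = max(row)
    log_sum = m + math.log(sum(math.exp(x - m) for x in row))
    return [x - log_sum for x in row]


# Cross-entropy: mean negative log-probability assigned to the true class
log_probs = [log_softmax(row) for row in logits]
loss = -sum(lp[y] for lp, y in zip(log_probs, labels)) / len(labels)

# Predicted class = index of the largest logit in each row
predictions = [row.index(max(row)) for row in logits]
accuracy = sum(p == y for p, y in zip(predictions, labels)) / len(labels)

print(f"loss: {loss:.4f}")
print(predictions, accuracy)
```

The recomputed value should land close to the 0.7886 reported for `tmp_outp[0]`. Before any training, the freshly initialised head favours class 0 for every example in this batch, which is why the loss sits near `ln(2) ≈ 0.693`, the value expected from an uninformed two-class classifier.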

In [65]:
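for epoch in range(3):
    for i, batch in enumerate(tqdm(dataloader)):
        # move every tensor in the batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs[0]  # the loss comes first because the batch contains 'labels'
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()  # clear gradients so they do not accumulate across batches
        if i % 10 == 0:
            print(f"loss: {loss.item()}")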

  0%|          | 2/459 [00:00<01:28,  5.17it/s]

loss: 0.7385174632072449


  3%|▎         | 12/459 [00:01<01:09,  6.41it/s]

loss: 0.6396863460540771


  5%|▍         | 22/459 [00:03<01:08,  6.36it/s]

loss: 0.5599258542060852


  7%|▋         | 32/459 [00:05<01:06,  6.40it/s]

loss: 0.580287516117096


  9%|▉         | 42/459 [00:06<01:04,  6.42it/s]

loss: 0.5933288335800171


 11%|█▏        | 52/459 [00:08<01:04,  6.27it/s]

loss: 0.5439509153366089


 14%|█▎        | 62/459 [00:09<01:02,  6.40it/s]

loss: 0.784173846244812


 16%|█▌        | 72/459 [00:11<01:02,  6.19it/s]

loss: 0.5773892998695374


 18%|█▊        | 82/459 [00:12<01:00,  6.19it/s]

loss: 0.6961305141448975


 20%|██        | 92/459 [00:14<00:58,  6.29it/s]

loss: 0.7113283276557922


 22%|██▏       | 102/459 [00:16<00:57,  6.23it/s]

loss: 0.5707622170448303


 24%|██▍       | 112/459 [00:17<00:56,  6.14it/s]

loss: 0.7540099024772644


 27%|██▋       | 122/459 [00:19<00:54,  6.15it/s]

loss: 0.48102545738220215


 29%|██▉       | 132/459 [00:21<00:52,  6.18it/s]

loss: 0.7389267086982727


 31%|███       | 142/459 [00:22<00:51,  6.17it/s]

loss: 0.5252507328987122


 33%|███▎      | 152/459 [00:24<00:50,  6.12it/s]

loss: 0.481222540140152


 35%|███▌      | 162/459 [00:25<00:48,  6.12it/s]

loss: 0.6592655181884766


 37%|███▋      | 172/459 [00:27<00:45,  6.33it/s]

loss: 0.4697287380695343


 40%|███▉      | 182/459 [00:29<00:43,  6.33it/s]

loss: 0.6142392158508301


 42%|████▏     | 192/459 [00:30<00:42,  6.31it/s]

loss: 0.6753271818161011


 44%|████▍     | 202/459 [00:32<00:41,  6.13it/s]

loss: 0.3276720643043518


 46%|████▌     | 212/459 [00:33<00:39,  6.28it/s]

loss: 0.2992999255657196


 48%|████▊     | 222/459 [00:35<00:38,  6.19it/s]

loss: 0.7070814371109009


 51%|█████     | 232/459 [00:37<00:36,  6.20it/s]

loss: 0.35821616649627686


 53%|█████▎    | 242/459 [00:38<00:34,  6.27it/s]

loss: 0.5617762804031372


 55%|█████▍    | 252/459 [00:40<00:33,  6.14it/s]

loss: 0.7065333127975464


 57%|█████▋    | 262/459 [00:41<00:31,  6.17it/s]

loss: 0.6581903100013733


 59%|█████▉    | 272/459 [00:43<00:30,  6.14it/s]

loss: 0.5468242764472961


 61%|██████▏   | 282/459 [00:45<00:28,  6.23it/s]

loss: 0.43561458587646484


 64%|██████▎   | 292/459 [00:46<00:27,  6.15it/s]

loss: 0.2961084842681885


 66%|██████▌   | 302/459 [00:48<00:25,  6.24it/s]

loss: 0.723567008972168


 68%|██████▊   | 312/459 [00:50<00:23,  6.16it/s]

loss: 0.4911409616470337


 70%|███████   | 322/459 [00:51<00:22,  6.21it/s]

loss: 0.2670542597770691


 72%|███████▏  | 332/459 [00:53<00:20,  6.17it/s]

loss: 0.2044954001903534


 75%|███████▍  | 342/459 [00:54<00:18,  6.27it/s]

loss: 0.31846243143081665


 77%|███████▋  | 352/459 [00:56<00:17,  6.26it/s]

loss: 0.466289758682251


 79%|███████▉  | 362/459 [00:58<00:15,  6.26it/s]

loss: 0.15207155048847198


 81%|████████  | 372/459 [00:59<00:15,  5.66it/s]

loss: 0.32111233472824097


 83%|████████▎ | 382/459 [01:01<00:12,  6.19it/s]

loss: 0.16522280871868134


 85%|████████▌ | 392/459 [01:03<00:10,  6.25it/s]

loss: 0.4492148160934448


 88%|████████▊ | 402/459 [01:04<00:09,  6.18it/s]

loss: 1.1196686029434204


 90%|████████▉ | 412/459 [01:06<00:07,  6.23it/s]

loss: 0.3662930727005005


 92%|█████████▏| 422/459 [01:07<00:05,  6.22it/s]

loss: 0.3457149565219879


 94%|█████████▍| 432/459 [01:09<00:04,  6.21it/s]

loss: 0.2896239161491394


 96%|█████████▋| 442/459 [01:11<00:02,  6.13it/s]

loss: 0.7615089416503906


 98%|█████████▊| 452/459 [01:12<00:01,  6.05it/s]

loss: 0.6900796294212341


100%|██████████| 459/459 [01:13<00:00,  6.21it/s]
  0%|          | 2/459 [00:00<01:15,  6.04it/s]

loss: 0.6395913362503052


  3%|▎         | 12/459 [00:01<01:13,  6.08it/s]

loss: 1.0151634216308594


  5%|▍         | 22/459 [00:03<01:12,  6.03it/s]

loss: 0.30757787823677063


  7%|▋         | 32/459 [00:05<01:09,  6.10it/s]

loss: 0.1513495147228241


  9%|▉         | 42/459 [00:06<01:09,  6.02it/s]

loss: 0.37375688552856445


 11%|█▏        | 52/459 [00:08<01:06,  6.11it/s]

loss: 0.2504567801952362


 14%|█▎        | 62/459 [00:10<01:05,  6.09it/s]

loss: 0.843349814414978


 16%|█▌        | 72/459 [00:11<01:03,  6.05it/s]

loss: 0.663969874382019


 18%|█▊        | 82/459 [00:13<01:02,  6.07it/s]

loss: 0.5129697322845459


 20%|██        | 92/459 [00:15<01:00,  6.04it/s]

loss: 0.44292324781417847


 22%|██▏       | 102/459 [00:16<00:59,  6.04it/s]

loss: 0.3749794363975525


 24%|██▍       | 112/459 [00:18<00:56,  6.15it/s]

loss: 0.507390022277832


 27%|██▋       | 122/459 [00:20<00:54,  6.13it/s]

loss: 0.3866831064224243


 29%|██▉       | 132/459 [00:21<00:54,  5.97it/s]

loss: 0.33217400312423706


 31%|███       | 142/459 [00:23<00:51,  6.11it/s]

loss: 0.20392589271068573


 33%|███▎      | 152/459 [00:24<00:50,  6.09it/s]

loss: 0.21191030740737915


 35%|███▌      | 162/459 [00:26<00:48,  6.08it/s]

loss: 0.35859110951423645


 37%|███▋      | 172/459 [00:28<00:47,  6.08it/s]

loss: 0.26240360736846924


 40%|███▉      | 182/459 [00:29<00:44,  6.20it/s]

loss: 0.46699756383895874


 42%|████▏     | 192/459 [00:31<00:44,  6.05it/s]

loss: 0.4052182734012604


 44%|████▍     | 202/459 [00:33<00:42,  6.11it/s]

loss: 0.17324300110340118


 46%|████▌     | 212/459 [00:34<00:39,  6.18it/s]

loss: 0.06379885226488113


 48%|████▊     | 222/459 [00:36<00:38,  6.20it/s]

loss: 0.5905288457870483


 51%|█████     | 232/459 [00:38<00:37,  6.10it/s]

loss: 0.1922043263912201


 53%|█████▎    | 242/459 [00:39<00:35,  6.13it/s]

loss: 0.40016743540763855


 55%|█████▍    | 252/459 [00:41<00:34,  6.04it/s]

loss: 0.5182885527610779


 57%|█████▋    | 262/459 [00:43<00:32,  6.03it/s]

loss: 0.24538220465183258


 59%|█████▉    | 272/459 [00:44<00:30,  6.14it/s]

loss: 0.3315081000328064


 61%|██████▏   | 282/459 [00:46<00:28,  6.12it/s]

loss: 0.20186308026313782


 64%|██████▎   | 292/459 [00:47<00:27,  5.98it/s]

loss: 0.2977391481399536


 66%|██████▌   | 302/459 [00:49<00:25,  6.13it/s]

loss: 0.43343889713287354


 68%|██████▊   | 312/459 [00:51<00:24,  6.10it/s]

loss: 0.2447441965341568


 70%|███████   | 322/459 [00:52<00:22,  6.06it/s]

loss: 0.14581812918186188


 72%|███████▏  | 332/459 [00:54<00:20,  6.10it/s]

loss: 0.042760152369737625


 75%|███████▍  | 342/459 [00:56<00:19,  6.10it/s]

loss: 0.21027135848999023


 77%|███████▋  | 352/459 [00:57<00:17,  6.03it/s]

loss: 0.30647265911102295


 79%|███████▉  | 362/459 [00:59<00:16,  6.03it/s]

loss: 0.05156349763274193


 81%|████████  | 372/459 [01:01<00:14,  6.03it/s]

loss: 0.1352405548095703


 83%|████████▎ | 382/459 [01:02<00:12,  6.12it/s]

loss: 0.18216073513031006


 85%|████████▌ | 392/459 [01:04<00:11,  6.06it/s]

loss: 0.18528293073177338


 88%|████████▊ | 402/459 [01:06<00:09,  6.07it/s]

loss: 0.48762351274490356


 90%|████████▉ | 412/459 [01:07<00:07,  6.02it/s]

loss: 0.34017160534858704


 92%|█████████▏| 422/459 [01:09<00:06,  6.02it/s]

loss: 0.07095564901828766


 94%|█████████▍| 432/459 [01:11<00:04,  5.97it/s]

loss: 0.05436824634671211


 96%|█████████▋| 442/459 [01:12<00:02,  6.16it/s]

loss: 0.7057621479034424


 98%|█████████▊| 452/459 [01:14<00:01,  6.03it/s]

loss: 0.32769590616226196


100%|██████████| 459/459 [01:15<00:00,  6.08it/s]
  0%|          | 2/459 [00:00<01:15,  6.07it/s]

loss: 0.15865696966648102


  3%|▎         | 12/459 [00:01<01:13,  6.10it/s]

loss: 0.9185944199562073


  5%|▍         | 22/459 [00:03<01:10,  6.19it/s]

loss: 0.15081919729709625


  7%|▋         | 32/459 [00:05<01:10,  6.08it/s]

loss: 0.0896751657128334


  9%|▉         | 42/459 [00:06<01:08,  6.06it/s]

loss: 0.06325718015432358


 11%|█▏        | 52/459 [00:08<01:06,  6.08it/s]

loss: 0.04721106216311455


 14%|█▎        | 62/459 [00:10<01:04,  6.16it/s]

loss: 0.5421358942985535


 16%|█▌        | 72/459 [00:11<01:02,  6.16it/s]

loss: 0.6455751061439514


 18%|█▊        | 82/459 [00:13<01:01,  6.12it/s]

loss: 0.1967630237340927


 20%|██        | 92/459 [00:15<00:59,  6.13it/s]

loss: 0.3068765103816986


 22%|██▏       | 102/459 [00:16<00:59,  5.97it/s]

loss: 0.367424875497818


 24%|██▍       | 112/459 [00:18<00:57,  6.08it/s]

loss: 0.11551333963871002


 27%|██▋       | 122/459 [00:19<00:55,  6.10it/s]

loss: 0.17283883690834045


 29%|██▉       | 132/459 [00:21<00:53,  6.09it/s]

loss: 0.4919939935207367


 31%|███       | 142/459 [00:23<00:51,  6.14it/s]

loss: 0.24960841238498688


 33%|███▎      | 152/459 [00:24<00:50,  6.12it/s]

loss: 0.05049392580986023


 35%|███▌      | 162/459 [00:26<00:49,  5.95it/s]

loss: 0.39103004336357117


 37%|███▋      | 172/459 [00:28<00:47,  6.00it/s]

loss: 0.08831337094306946


 40%|███▉      | 182/459 [00:29<00:46,  6.01it/s]

loss: 0.058148086071014404


 42%|████▏     | 192/459 [00:31<00:43,  6.14it/s]

loss: 0.5763173699378967


 44%|████▍     | 202/459 [00:33<00:42,  6.09it/s]

loss: 0.01973964087665081


 46%|████▌     | 212/459 [00:34<00:40,  6.12it/s]

loss: 0.04940488934516907


 48%|████▊     | 222/459 [00:36<00:38,  6.09it/s]

loss: 0.045024577528238297


 51%|█████     | 232/459 [00:38<00:37,  6.04it/s]

loss: 0.2407701313495636


 53%|█████▎    | 242/459 [00:39<00:35,  6.12it/s]

loss: 0.4587419927120209


 55%|█████▍    | 252/459 [00:41<00:33,  6.14it/s]

loss: 0.43034496903419495


 57%|█████▋    | 262/459 [00:42<00:32,  6.07it/s]

loss: 0.04627221077680588


 59%|█████▉    | 272/459 [00:44<00:31,  5.98it/s]

loss: 0.11038823425769806


 61%|██████▏   | 282/459 [00:46<00:29,  5.98it/s]

loss: 0.026460517197847366


 64%|██████▎   | 292/459 [00:47<00:27,  5.98it/s]

loss: 0.031553205102682114


 66%|██████▌   | 302/459 [00:49<00:26,  5.98it/s]

loss: 0.14787940680980682


 68%|██████▊   | 312/459 [00:51<00:24,  5.98it/s]

loss: 0.154608353972435


 70%|███████   | 322/459 [00:52<00:22,  5.99it/s]

loss: 0.04358350485563278


 72%|███████▏  | 332/459 [00:54<00:20,  6.05it/s]

loss: 0.018410921096801758


 75%|███████▍  | 342/459 [00:56<00:19,  6.11it/s]

loss: 0.05729895085096359


 77%|███████▋  | 352/459 [00:57<00:17,  6.03it/s]

loss: 0.28072720766067505


 79%|███████▉  | 362/459 [00:59<00:15,  6.19it/s]

loss: 0.06360605359077454


 81%|████████  | 372/459 [01:01<00:14,  6.13it/s]

loss: 0.050791747868061066


 83%|████████▎ | 382/459 [01:02<00:12,  6.04it/s]

loss: 0.18050871789455414


 85%|████████▌ | 392/459 [01:04<00:11,  6.06it/s]

loss: 0.038048695772886276


 88%|████████▊ | 402/459 [01:06<00:09,  6.10it/s]

loss: 0.824906587600708


 90%|████████▉ | 412/459 [01:07<00:07,  6.06it/s]

loss: 0.0655549168586731


 92%|█████████▏| 422/459 [01:09<00:06,  6.08it/s]

loss: 0.045639004558324814


 94%|█████████▍| 432/459 [01:11<00:04,  6.02it/s]

loss: 0.02948756143450737


 96%|█████████▋| 442/459 [01:12<00:02,  6.14it/s]

loss: 0.4335087537765503


 98%|█████████▊| 452/459 [01:14<00:01,  6.04it/s]

loss: 0.10622315853834152


100%|██████████| 459/459 [01:15<00:00,  6.08it/s]
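The per-batch losses printed above are noisy, since each value comes from only 8 examples. One common way to read the trend is to smooth them with an exponential moving average. The sketch below does this for the first ten losses logged in epoch 1 (rounded to four decimals); the helper name and the smoothing factor `alpha=0.1` are arbitrary choices for illustration, not something the original training loop uses.

```python
def exponential_moving_average(values, alpha=0.1):
    """Smooth a sequence: each output mixes the new value with the running average."""
    smoothed = []
    running = None
    for v in values:
        running = v if running is None else alpha * v + (1 - alpha) * running
        smoothed.append(running)
    return smoothed


# First ten losses logged in epoch 1 (every 10th batch), rounded
losses = [0.7385, 0.6397, 0.5599, 0.5803, 0.5933, 0.5440, 0.7842, 0.5774, 0.6961, 0.7113]
smoothed = exponential_moving_average(losses)

for raw, avg in zip(losses, smoothed):
    print(f"raw: {raw:.4f}  smoothed: {avg:.4f}")
```

A smaller `alpha` gives a smoother but slower-reacting curve; a larger one tracks the raw values more closely.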
