# Fine-tuning a deep-learning model with HuggingFace `🤗Datasets` library and PyTorch

Based on the 🤗Datasets [Quick tour example](https://huggingface.co/docs/datasets/quicktour.html)

Colab author: Manuel Romero / [@mrm8488](https://twitter.com/mrm8488)

In [1]:
!nvidia-smi

Sat Nov 14 13:42:22 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 455.32.00    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   58C    P8    10W /  70W |      0MiB / 15079MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               

In [2]:
!pip install -q datasets
!pip install -q transformers

In [3]:
# Make sure a recent pyarrow (>= 0.16) is loaded in this session; if an older one is
# still imported, kill the runtime so that Colab restarts and picks up the new version
import pyarrow
if int(pyarrow.__version__.split('.')[1]) < 16 and int(pyarrow.__version__.split('.')[0]) == 0:
    import os
    os.kill(os.getpid(), 9)
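The same "older than 0.16" test can be written as a single comparison of `(major, minor)` tuples, which reads more directly than two separate conditions. This is a sketch only: `needs_restart` is a hypothetical helper, and the `(0, 16)` threshold is simply the one used in the cell above.

```python
def needs_restart(version: str, minimum=(0, 16)) -> bool:
    """Return True if `version` (e.g. "0.14.1") is older than `minimum`.

    Hypothetical helper; only the numeric major/minor parts are compared.
    """
    major, minor = (int(part) for part in version.split(".")[:2])
    # Tuples compare element by element, so (0, 15) < (0, 16) < (1, 0)
    return (major, minor) < minimum


print(needs_restart("0.14.1"))  # True
print(needs_restart("0.17.1"))  # False
print(needs_restart("2.0.0"))   # False
```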

### Load the model and the tokenizer

In [4]:
from transformers import AutoModelForSequenceClassification, AutoTokenizer
model = AutoModelForSequenceClassification.from_pretrained('bert-base-cased')
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')

Some weights of the model checkpoint at bert-base-cased were not used when initializing BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at b

### Load and process the dataset

In [5]:
from datasets import load_dataset
dataset = load_dataset('glue', 'mrpc', split='train')

Downloading and preparing dataset glue/mrpc (download: 1.43 MiB, generated: 1.43 MiB, post-processed: Unknown size, total: 2.85 MiB) to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4...

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/mrpc/1.0.0/7c99657241149a24692c402a5c3f34d4c9f1df5ac2e4c3759fadea38f6cb29c4. Subsequent calls will reuse this data.


In [6]:
len(dataset)

3668

In [7]:
dataset[0]

{'idx': 0,
 'label': 1,
 'sentence1': 'Amrozi accused his brother , whom he called " the witness " , of deliberately distorting his evidence .',
 'sentence2': 'Referring to him as only " the witness " , Amrozi accused his brother of deliberately distorting his evidence .'}

In [8]:
dataset.features

{'idx': Value(dtype='int32', id=None),
 'label': ClassLabel(num_classes=2, names=['not_equivalent', 'equivalent'], names_file=None, id=None),
 'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None)}

In [9]:
# dataset.filter(lambda example: example['label'] == dataset.features['label'].str2int('equivalent'))[0]

In [10]:
# dataset.filter(lambda example: example['label'] == dataset.features['label'].str2int('not_equivalent'))[0]
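`dataset.features['label']` is a `ClassLabel` whose `names` list maps `0` to `'not_equivalent'` and `1` to `'equivalent'`; the commented-out `filter` calls use `str2int` to turn a label name into that integer id. The pure-Python stand-in below illustrates the name/id mapping and the filter on a few made-up rows. `MiniClassLabel` is hypothetical and implements only the two lookups used here, not the library class.

```python
class MiniClassLabel:
    """Hypothetical stand-in mirroring ClassLabel's name <-> id mapping."""

    def __init__(self, names):
        self.names = list(names)
        self._index = {name: i for i, name in enumerate(self.names)}

    def str2int(self, name):
        # Position of the name in `names` is its integer label
        return self._index[name]

    def int2str(self, idx):
        return self.names[idx]


label = MiniClassLabel(["not_equivalent", "equivalent"])

# Made-up rows shaped like the MRPC examples (only the fields we need)
rows = [
    {"idx": 0, "label": 1},
    {"idx": 1, "label": 0},
    {"idx": 2, "label": 1},
]

# Same predicate as the commented-out dataset.filter(...) call
equivalent = [r for r in rows if r["label"] == label.str2int("equivalent")]
print([r["idx"] for r in equivalent])  # [0, 2]
print(label.int2str(0))                # not_equivalent
```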

#### Tokenizing the dataset

In [11]:
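def encode(examples):
    # Tokenize the two sentences as one pair; truncate to the model's maximum
    # length and pad every example to that same length
    return tokenizer(examples['sentence1'], examples['sentence2'], truncation=True, padding='max_length')

# batched=True passes a dict of lists (a slice of rows) to `encode` instead of one row at a time
dataset = dataset.map(encode, batched=True)
dataset[0]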
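With `batched=True`, `map` hands the function a dict of lists (one list per column) covering a slice of rows rather than a single example, then merges the returned columns back into the dataset. The sketch below imitates that on a plain dict of lists. `batched_map` is a hypothetical helper, not part of 🤗Datasets, and its `batch_size=1000` is assumed to match the library's default for `map`; with the 3,668 MRPC training rows that would mean ceil(3668 / 1000) = 4 calls to `encode`.

```python
def batched_map(columns, fn, batch_size=1000):
    """Hypothetical sketch of a batched map over a column-oriented table.

    `columns` maps column name -> list of values (all lists equal length).
    `fn` receives a dict of lists for one batch and returns a dict of lists.
    """
    n = len(next(iter(columns.values())))
    out = {name: list(values) for name, values in columns.items()}
    new_cols = {}
    for start in range(0, n, batch_size):
        # Slice every column the same way to build one batch
        batch = {name: values[start:start + batch_size] for name, values in columns.items()}
        for name, values in fn(batch).items():
            new_cols.setdefault(name, []).extend(values)
    # Returned columns are added, or replace existing ones with the same name
    out.update(new_cols)
    return out


table = {"sentence1": ["a b c", "d e", "f"], "label": [1, 0, 1]}
result = batched_map(table, lambda b: {"n_words": [len(s.split()) for s in b["sentence1"]]}, batch_size=2)
print(result["n_words"])  # [3, 2, 1]
print(sorted(result))     # ['label', 'n_words', 'sentence1']

# The `labels` step in the next section follows the same pattern
renamed = batched_map(table, lambda b: {"labels": b["label"]})
print(renamed["labels"])  # [1, 0, 1]
```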

{'attention_mask': [1, 1, 1, ..., 1, 0, 0, 0, ..., 0,
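`padding='max_length'` without an explicit `max_length` pads every pair to the tokenizer's model maximum (512 tokens for bert-base-cased), which is why most of each `attention_mask` is zeros: the first example has 52 real tokens before the padding starts. The sketch below shows the BERT pair layout `[CLS] A [SEP] B [SEP]`, the matching `token_type_ids`, and the mask, working on already-tokenized ids. `encode_pair` is hypothetical; the special-token ids 101/102/0 are assumed to be bert-base-cased's `[CLS]`/`[SEP]`/`[PAD]` (101 does start every row of the batch printed below), and the truncation loop only approximates the tokenizer's default "longest first" strategy.

```python
def encode_pair(ids_a, ids_b, max_length, cls_id=101, sep_id=102, pad_id=0):
    """Hypothetical BERT-style pair encoding with truncation and padding."""
    a, b = list(ids_a), list(ids_b)
    # Drop tokens from the end of the longer sequence until the pair
    # plus its 3 special tokens fits in max_length
    while len(a) + len(b) + 3 > max_length:
        if len(a) > len(b):
            a.pop()
        else:
            b.pop()
    input_ids = [cls_id] + a + [sep_id] + b + [sep_id]
    # Segment A (with [CLS] and its [SEP]) gets type 0, segment B gets type 1
    token_type_ids = [0] * (len(a) + 2) + [1] * (len(b) + 1)
    # 1 marks real tokens; padding positions get 0 so attention ignores them
    attention_mask = [1] * len(input_ids)
    pad = max_length - len(input_ids)
    return {
        "input_ids": input_ids + [pad_id] * pad,
        "token_type_ids": token_type_ids + [0] * pad,
        "attention_mask": attention_mask + [0] * pad,
    }


# Arbitrary token ids, used only for illustration
enc = encode_pair([7277, 2180], [10684], max_length=8)
print(enc["input_ids"])       # [101, 7277, 2180, 102, 10684, 102, 0, 0]
print(enc["token_type_ids"])  # [0, 0, 0, 0, 1, 1, 0, 0]
print(enc["attention_mask"])  # [1, 1, 1, 1, 1, 1, 0, 0]
```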
  

#### Formatting the dataset

In [12]:
# The model looks for the target under the key `labels`, so copy `label` into a new column
dataset = dataset.map(lambda examples: {'labels': examples['label']}, batched=True)

In [13]:
import torch
# Return PyTorch tensors, and only for the columns the model consumes
dataset.set_format(type='torch', columns=['input_ids', 'token_type_ids', 'attention_mask', 'labels'])
dataloader = torch.utils.data.DataLoader(dataset, batch_size=8)
next(iter(dataloader))

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         ...,
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0],
         [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[  101,  7277,  2180,  ...,     0,     0,     0],
         [  101, 10684,  2599,  ...,     0,     0,     0],
         [  101,  1220,  1125,  ...,     0,     0,     0],
         ...,
         [  101, 16944,  1107,  ...,     0,     0,     0],
         [  101,  1109, 11896,  ...,     0,     0,     0],
         [  101,  1109,  4173,  ...,     0,     0,     0]]),
 'labels': tensor([1, 0, 1, 0, 1, 1, 0, 1]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         ...,
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0],
         [0, 0, 0,  ..., 0, 0, 0]])}
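`DataLoader(dataset, batch_size=8)` walks the rows in order (no `shuffle=True` is passed) and collates each group of 8 row dicts into one dict whose values are stacked along a new batch dimension, as in the output above. The sketch below mimics that with plain lists standing in for tensors; `collate` and `batches` are hypothetical helpers. It also reproduces the batch count in the training progress bars: 3,668 rows / 8 = 458.5, rounded up to 459 batches, the last one holding only 4 rows.

```python
import math


def collate(rows):
    """Turn a list of row dicts into one dict of lists (one list per key)."""
    return {key: [row[key] for row in rows] for key in rows[0]}


def batches(rows, batch_size):
    # Sequential, unshuffled batching; the final batch may be smaller
    for start in range(0, len(rows), batch_size):
        yield collate(rows[start:start + batch_size])


# Toy rows shaped like the formatted dataset (values are made up)
rows = [{"input_ids": [101, i], "labels": i % 2} for i in range(5)]
out = list(batches(rows, 2))
print(len(out))             # 3
print(out[0]["labels"])     # [0, 1]
print(out[-1]["input_ids"]) # [[101, 4]]

n_rows, batch_size = 3668, 8
print(math.ceil(n_rows / batch_size))  # 459
print(n_rows % batch_size)             # 4
```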

#### Training

In [14]:
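from tqdm import tqdm
# Use the GPU when one is available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Put the model in training mode (enables dropout) and move it to the device
model.train().to(device)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-5)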
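`torch.optim.AdamW` is Adam with decoupled weight decay: the weights are shrunk directly instead of adding a decay term to the gradient. Below is a scalar sketch of one update, written from the published update rule rather than from PyTorch's source. `adamw_step` is a hypothetical name, and the defaults other than `lr` (`betas=(0.9, 0.999)`, `eps=1e-8`, `weight_decay=0.01`) are assumed to match PyTorch's.

```python
import math


def adamw_step(param, grad, m, v, t, lr=1e-5, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01):
    """One AdamW update for a single scalar parameter; t counts steps from 1."""
    beta1, beta2 = betas
    # Decoupled weight decay: shrink the parameter itself, not the gradient
    param = param - lr * weight_decay * param
    # Exponential moving averages of the gradient and the squared gradient
    m = beta1 * m + (1 - beta1) * grad
    v = beta2 * v + (1 - beta2) * grad * grad
    # Bias correction, since both averages start at zero
    m_hat = m / (1 - beta1 ** t)
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v


# First step without decay: m_hat / sqrt(v_hat) is +-1, so the move is about lr
p, m, v = adamw_step(1.0, 0.5, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.0)
print(round(p, 6))  # 0.9

# Zero gradient with decay: only the decoupled shrink is applied
p2, _, _ = adamw_step(1.0, 0.0, 0.0, 0.0, t=1, lr=0.1, weight_decay=0.01)
print(round(p2, 6))  # 0.999
```

Because the first bias-corrected step has magnitude close to `lr` whatever the gradient size, `lr=1e-5` here means each weight moves by roughly 1e-5 at the start of fine-tuning.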

In [15]:
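for epoch in range(3):
    for i, batch in enumerate(tqdm(dataloader)):
        # Move every tensor in the batch to the same device as the model
        batch = {k: v.to(device) for k, v in batch.items()}
        # Because the batch contains `labels`, the first output is the cross-entropy loss
        outputs = model(**batch)
        loss = outputs[0]
        # Backpropagate, update the weights, then clear the gradients for the next step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        # Log the loss of the current batch every 10 steps
        if i % 10 == 0:
            print(f"loss: {loss.item()}")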

  0%|          | 1/459 [00:00<06:50,  1.12it/s]

loss: 0.793382465839386


  2%|▏         | 11/459 [00:08<05:32,  1.35it/s]

loss: 0.7165786027908325


  5%|▍         | 21/459 [00:15<05:25,  1.35it/s]

loss: 0.7045443654060364


  7%|▋         | 31/459 [00:23<05:20,  1.34it/s]

loss: 0.566281795501709


  9%|▉         | 41/459 [00:30<05:15,  1.33it/s]

loss: 0.5868521928787231


 11%|█         | 51/459 [00:38<05:09,  1.32it/s]

loss: 0.49931037425994873


 13%|█▎        | 61/459 [00:45<05:05,  1.30it/s]

loss: 0.7353698015213013


 15%|█▌        | 71/459 [00:53<04:59,  1.30it/s]

loss: 0.6399345397949219


 18%|█▊        | 81/459 [01:01<04:53,  1.29it/s]

loss: 0.694754421710968


 20%|█▉        | 91/459 [01:08<04:47,  1.28it/s]

loss: 0.6593285202980042


 22%|██▏       | 101/459 [01:16<04:42,  1.27it/s]

loss: 0.570824146270752


 24%|██▍       | 111/459 [01:24<04:36,  1.26it/s]

loss: 0.5997903347015381


 26%|██▋       | 121/459 [01:32<04:28,  1.26it/s]

loss: 0.42902088165283203


 29%|██▊       | 131/459 [01:40<04:20,  1.26it/s]

loss: 0.5499044060707092


 31%|███       | 141/459 [01:48<04:09,  1.28it/s]

loss: 0.4047330915927887


 33%|███▎      | 151/459 [01:56<03:59,  1.29it/s]

loss: 0.49817344546318054


 35%|███▌      | 161/459 [02:03<03:51,  1.29it/s]

loss: 0.7351539134979248


 37%|███▋      | 171/459 [02:11<03:43,  1.29it/s]

loss: 0.6151454448699951


 39%|███▉      | 181/459 [02:19<03:36,  1.28it/s]

loss: 0.6475334167480469


 42%|████▏     | 191/459 [02:27<03:30,  1.27it/s]

loss: 0.6804612874984741


 44%|████▍     | 201/459 [02:34<03:23,  1.27it/s]

loss: 0.2948799729347229


 46%|████▌     | 211/459 [02:42<03:15,  1.27it/s]

loss: 0.2861334979534149


 48%|████▊     | 221/459 [02:50<03:07,  1.27it/s]

loss: 0.6884828805923462


 50%|█████     | 231/459 [02:58<02:59,  1.27it/s]

loss: 0.3266623914241791


 53%|█████▎    | 241/459 [03:06<02:51,  1.27it/s]

loss: 0.4848356544971466


 55%|█████▍    | 251/459 [03:14<02:42,  1.28it/s]

loss: 0.6599732041358948


 57%|█████▋    | 261/459 [03:21<02:34,  1.28it/s]

loss: 0.5659237504005432


 59%|█████▉    | 271/459 [03:29<02:27,  1.27it/s]

loss: 0.4902201294898987


 61%|██████    | 281/459 [03:37<02:19,  1.27it/s]

loss: 0.38434478640556335


 63%|██████▎   | 291/459 [03:45<02:12,  1.27it/s]

loss: 0.38033080101013184


 66%|██████▌   | 301/459 [03:53<02:04,  1.27it/s]

loss: 0.7902589440345764


 68%|██████▊   | 311/459 [04:01<01:56,  1.27it/s]

loss: 0.4547716975212097


 70%|██████▉   | 321/459 [04:08<01:48,  1.27it/s]

loss: 0.3094363212585449


 72%|███████▏  | 331/459 [04:16<01:40,  1.27it/s]

loss: 0.19041335582733154


 74%|███████▍  | 341/459 [04:24<01:32,  1.27it/s]

loss: 0.3507658839225769


 76%|███████▋  | 351/459 [04:32<01:24,  1.27it/s]

loss: 0.551235556602478


 79%|███████▊  | 361/459 [04:40<01:16,  1.28it/s]

loss: 0.20859478414058685


 81%|████████  | 371/459 [04:47<01:09,  1.28it/s]

loss: 0.4408032298088074


 83%|████████▎ | 381/459 [04:55<01:01,  1.28it/s]

loss: 0.1755213886499405


 85%|████████▌ | 391/459 [05:03<00:53,  1.27it/s]

loss: 0.4821820855140686


 87%|████████▋ | 401/459 [05:11<00:45,  1.28it/s]

loss: 0.8104358911514282


 90%|████████▉ | 411/459 [05:19<00:37,  1.27it/s]

loss: 0.4309285879135132


 92%|█████████▏| 421/459 [05:27<00:29,  1.27it/s]

loss: 0.5295751690864563


 94%|█████████▍| 431/459 [05:34<00:22,  1.27it/s]

loss: 0.3362852931022644


 96%|█████████▌| 441/459 [05:42<00:14,  1.27it/s]

loss: 0.8957725763320923


 98%|█████████▊| 451/459 [05:50<00:06,  1.28it/s]

loss: 0.7903518676757812


100%|██████████| 459/459 [05:56<00:00,  1.29it/s]
  0%|          | 1/459 [00:00<06:02,  1.26it/s]

loss: 0.7254610061645508


  2%|▏         | 11/459 [00:08<05:52,  1.27it/s]

loss: 0.9955068230628967


  5%|▍         | 21/459 [00:16<05:44,  1.27it/s]

loss: 0.33960676193237305


  7%|▋         | 31/459 [00:24<05:36,  1.27it/s]

loss: 0.2201566845178604


  9%|▉         | 41/459 [00:32<05:28,  1.27it/s]

loss: 0.390663743019104


 11%|█         | 51/459 [00:39<05:20,  1.27it/s]

loss: 0.25429534912109375


 13%|█▎        | 61/459 [00:47<05:11,  1.28it/s]

loss: 0.948885440826416


 15%|█▌        | 71/459 [00:55<05:04,  1.28it/s]

loss: 0.42604541778564453


 18%|█▊        | 81/459 [01:03<04:56,  1.28it/s]

loss: 0.4235728681087494


 20%|█▉        | 91/459 [01:11<04:48,  1.28it/s]

loss: 0.46968773007392883


 22%|██▏       | 101/459 [01:18<04:41,  1.27it/s]

loss: 0.3030402362346649


 24%|██▍       | 111/459 [01:26<04:33,  1.27it/s]

loss: 0.23066949844360352


 26%|██▋       | 121/459 [01:34<04:25,  1.27it/s]

loss: 0.4219714403152466


 29%|██▊       | 131/459 [01:42<04:17,  1.27it/s]

loss: 0.1519598662853241


 31%|███       | 141/459 [01:50<04:10,  1.27it/s]

loss: 0.31331920623779297


 33%|███▎      | 151/459 [01:58<04:01,  1.27it/s]

loss: 0.23015624284744263


 35%|███▌      | 161/459 [02:06<03:54,  1.27it/s]

loss: 0.31204837560653687


 37%|███▋      | 171/459 [02:13<03:46,  1.27it/s]

loss: 0.18736374378204346


 39%|███▉      | 181/459 [02:21<03:38,  1.27it/s]

loss: 0.4696507751941681


 42%|████▏     | 191/459 [02:29<03:30,  1.27it/s]

loss: 0.5062205195426941


 44%|████▍     | 201/459 [02:37<03:22,  1.27it/s]

loss: 0.31908324360847473


 46%|████▌     | 211/459 [02:45<03:14,  1.27it/s]

loss: 0.06602268666028976


 48%|████▊     | 221/459 [02:53<03:07,  1.27it/s]

loss: 0.3697710335254669


 50%|█████     | 231/459 [03:00<02:58,  1.27it/s]

loss: 0.07028135657310486


 53%|█████▎    | 241/459 [03:08<02:51,  1.27it/s]

loss: 0.5077078938484192


 55%|█████▍    | 251/459 [03:16<02:44,  1.26it/s]

loss: 0.9763592481613159


 57%|█████▋    | 261/459 [03:24<02:35,  1.27it/s]

loss: 0.16136033833026886


 59%|█████▉    | 271/459 [03:32<02:27,  1.27it/s]

loss: 0.3476882874965668


 61%|██████    | 281/459 [03:40<02:19,  1.27it/s]

loss: 0.10727621614933014


 63%|██████▎   | 291/459 [03:47<02:11,  1.27it/s]

loss: 0.09166322648525238


 66%|██████▌   | 301/459 [03:55<02:04,  1.27it/s]

loss: 0.5973555445671082


 68%|██████▊   | 311/459 [04:03<01:56,  1.27it/s]

loss: 0.39631080627441406


 70%|██████▉   | 321/459 [04:11<01:48,  1.27it/s]

loss: 0.3525766134262085


 72%|███████▏  | 331/459 [04:19<01:40,  1.27it/s]

loss: 0.07093723118305206


 74%|███████▍  | 341/459 [04:27<01:32,  1.27it/s]

loss: 0.3173861801624298


 76%|███████▋  | 351/459 [04:34<01:25,  1.27it/s]

loss: 0.5595075488090515


 79%|███████▊  | 361/459 [04:42<01:17,  1.27it/s]

loss: 0.07867054641246796


 81%|████████  | 371/459 [04:50<01:09,  1.27it/s]

loss: 0.22574225068092346


 83%|████████▎ | 381/459 [04:58<01:01,  1.27it/s]

loss: 0.12201197445392609


 85%|████████▌ | 391/459 [05:06<00:53,  1.27it/s]

loss: 0.5256179571151733


 87%|████████▋ | 401/459 [05:14<00:45,  1.27it/s]

loss: 0.44406235218048096


 90%|████████▉ | 411/459 [05:22<00:37,  1.28it/s]

loss: 0.5397162437438965


 92%|█████████▏| 421/459 [05:29<00:29,  1.27it/s]

loss: 0.15145479142665863


 94%|█████████▍| 431/459 [05:37<00:21,  1.27it/s]

loss: 0.10703088343143463


 96%|█████████▌| 441/459 [05:45<00:14,  1.27it/s]

loss: 0.5912159085273743


 98%|█████████▊| 451/459 [05:53<00:06,  1.27it/s]

loss: 0.3390589952468872


100%|██████████| 459/459 [05:59<00:00,  1.28it/s]
  0%|          | 1/459 [00:00<06:02,  1.26it/s]

loss: 0.2634665071964264


  2%|▏         | 11/459 [00:08<05:51,  1.27it/s]

loss: 0.6131705641746521


  5%|▍         | 21/459 [00:16<05:43,  1.27it/s]

loss: 0.2913869619369507


  7%|▋         | 31/459 [00:24<05:35,  1.27it/s]

loss: 0.06220335140824318


  9%|▉         | 41/459 [00:32<05:27,  1.27it/s]

loss: 0.0500989705324173


 11%|█         | 51/459 [00:39<05:19,  1.28it/s]

loss: 0.12938112020492554


 13%|█▎        | 61/459 [00:47<05:12,  1.27it/s]

loss: 0.34653350710868835


 15%|█▌        | 71/459 [00:55<05:04,  1.27it/s]

loss: 0.09276345372200012


 18%|█▊        | 81/459 [01:03<04:56,  1.27it/s]

loss: 0.1475953459739685


 20%|█▉        | 91/459 [01:11<04:49,  1.27it/s]

loss: 0.3668742775917053


 22%|██▏       | 101/459 [01:19<04:40,  1.27it/s]

loss: 0.08190273493528366


 24%|██▍       | 111/459 [01:26<04:33,  1.27it/s]

loss: 0.3522956073284149


 26%|██▋       | 121/459 [01:34<04:25,  1.27it/s]

loss: 0.42303070425987244


 29%|██▊       | 131/459 [01:42<04:17,  1.27it/s]

loss: 0.27230963110923767


 31%|███       | 141/459 [01:50<04:09,  1.27it/s]

loss: 0.1597849428653717


 33%|███▎      | 151/459 [01:58<04:01,  1.27it/s]

loss: 0.11209230124950409


 35%|███▌      | 161/459 [02:06<03:54,  1.27it/s]

loss: 0.10532563924789429


 37%|███▋      | 171/459 [02:13<03:46,  1.27it/s]

loss: 0.04381905868649483


 39%|███▉      | 181/459 [02:21<03:38,  1.27it/s]

loss: 0.10122339427471161


 42%|████▏     | 191/459 [02:29<03:31,  1.27it/s]

loss: 0.3767867386341095


 44%|████▍     | 201/459 [02:37<03:22,  1.27it/s]

loss: 0.11810891330242157


 46%|████▌     | 211/459 [02:45<03:14,  1.27it/s]

loss: 0.010220718570053577


 48%|████▊     | 221/459 [02:52<03:06,  1.28it/s]

loss: 0.12831054627895355


 50%|█████     | 231/459 [03:00<02:58,  1.28it/s]

loss: 0.1644396185874939


 53%|█████▎    | 241/459 [03:08<02:50,  1.28it/s]

loss: 0.6394615769386292


 55%|█████▍    | 251/459 [03:16<02:42,  1.28it/s]

loss: 0.37773266434669495


 57%|█████▋    | 261/459 [03:24<02:35,  1.27it/s]

loss: 0.052201732993125916


 59%|█████▉    | 271/459 [03:32<02:27,  1.27it/s]

loss: 0.08773839473724365


 61%|██████    | 281/459 [03:39<02:19,  1.27it/s]

loss: 0.08117695897817612


 63%|██████▎   | 291/459 [03:47<02:12,  1.27it/s]

loss: 0.12272791564464569


 66%|██████▌   | 301/459 [03:55<02:04,  1.27it/s]

loss: 0.34258291125297546


 68%|██████▊   | 311/459 [04:03<01:56,  1.28it/s]

loss: 0.12695850431919098


 70%|██████▉   | 321/459 [04:11<01:48,  1.27it/s]

loss: 0.13969255983829498


 72%|███████▏  | 331/459 [04:19<01:40,  1.27it/s]

loss: 0.020897645503282547


 74%|███████▍  | 341/459 [04:26<01:33,  1.27it/s]

loss: 0.09689651429653168


 76%|███████▋  | 351/459 [04:34<01:25,  1.27it/s]

loss: 0.9152028560638428


 79%|███████▊  | 361/459 [04:42<01:16,  1.28it/s]

loss: 0.005689491983503103


 81%|████████  | 371/459 [04:50<01:09,  1.27it/s]

loss: 0.011376610025763512


 83%|████████▎ | 381/459 [04:58<01:01,  1.27it/s]

loss: 0.4315567910671234


 85%|████████▌ | 391/459 [05:06<00:53,  1.27it/s]

loss: 0.2688358426094055


 87%|████████▋ | 401/459 [05:13<00:45,  1.27it/s]

loss: 0.24602092802524567


 90%|████████▉ | 411/459 [05:21<00:37,  1.27it/s]

loss: 0.1759956181049347


 92%|█████████▏| 421/459 [05:29<00:29,  1.28it/s]

loss: 0.05238579213619232


 94%|█████████▍| 431/459 [05:37<00:21,  1.27it/s]

loss: 0.022939883172512054


 96%|█████████▌| 441/459 [05:45<00:14,  1.28it/s]

loss: 0.5062085390090942


 98%|█████████▊| 451/459 [05:53<00:06,  1.28it/s]

loss: 0.2902717888355255


100%|██████████| 459/459 [05:58<00:00,  1.28it/s]
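Each printed value is the loss of a single batch of 8 examples, so it jumps around a lot (for example 0.915 at step 351 and 0.0057 at step 361 of the third epoch). Averaging, or an exponential moving average, makes the downward trend across epochs easier to read. The sketch below applies both to the first five logged values of epochs 1 and 3 copied from the output above; `ema` and `mean` are hypothetical helpers and `alpha=0.1` is an arbitrary smoothing choice.

```python
def ema(values, alpha=0.1):
    """Exponential moving average; alpha is the weight of the newest value."""
    smoothed, current = [], None
    for x in values:
        current = x if current is None else alpha * x + (1 - alpha) * current
        smoothed.append(current)
    return smoothed


def mean(values):
    return sum(values) / len(values)


# First five logged losses of epoch 1 and epoch 3 (steps 0, 10, 20, 30, 40)
epoch1 = [0.793382465839386, 0.7165786027908325, 0.7045443654060364,
          0.566281795501709, 0.5868521928787231]
epoch3 = [0.2634665071964264, 0.6131705641746521, 0.2913869619369507,
          0.06220335140824318, 0.0500989705324173]

print(round(mean(epoch1), 2))  # 0.67
print(round(mean(epoch3), 2))  # 0.26
print(ema(epoch3))             # smoothed curve, one value per logged step
```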
