<a href="https://colab.research.google.com/github/Typing784/tutorials/blob/master/HF_Note8.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
!pip install datasets evaluate transformers[sentencepiece]

Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
Collecting responses<0.19 (from evaluate)
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Installing collected packages: responses, evaluate
Successfully installed evaluate-0.4.1 responses-0.18.0


In [77]:
from datasets import load_dataset
from transformers import AutoTokenizer, DataCollatorWithPadding

# Tokenizer that matches the checkpoint fine-tuned later in the notebook.
checkpoint = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# MRPC configuration of the GLUE benchmark.
raw_datasets = load_dataset(path="glue", name="mrpc")

def tokenize_function(example):
    """Tokenize a batch of sentence pairs, producing one row per overflow chunk.

    Args:
        example: batch mapping of column name -> indexable sequence of values;
            must contain "sentence1" and "sentence2".

    Returns:
        The tokenizer output with "overflow_to_sample_mapping" removed and every
        column of ``example`` re-added, repeated once per chunk so that all
        columns have the same length as the tokenizer output.
    """
    encoded = tokenizer(
        example["sentence1"],
        example["sentence2"],
        truncation=True,
        max_length=32,
        return_overflowing_tokens=True,
    )
    # Entry j holds the index of the input row that chunk j was cut from.
    source_rows = encoded.pop("overflow_to_sample_mapping")

    # Replicate each original column so it lines up with the chunked encoding.
    for column, values in example.items():
        encoded[column] = [values[row] for row in source_rows]

    return encoded

# Collator built on the same tokenizer; it is handed to the DataLoaders as collate_fn.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Note: the label column has to be kept through this step.
tokenized_datasets = raw_datasets.map(function=tokenize_function, batched=True)


Map:   0%|          | 0/3668 [00:00<?, ? examples/s]

Map:   0%|          | 0/408 [00:00<?, ? examples/s]

Map:   0%|          | 0/1725 [00:00<?, ? examples/s]

在 tokenize 的过程中，无论是否截断（`truncation`），tokenizer 都不会自动做 padding；这里的 padding 交给 `DataCollatorWithPadding` 在组 batch 时完成。

In [78]:
# Display the raw DatasetDict (splits, columns and row counts) as loaded from GLUE/MRPC.
# tokenized_datasets['train']
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 408
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx'],
        num_rows: 1725
    })
})

In [79]:
# Display the DatasetDict produced by mapping tokenize_function over the raw splits.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 18525
    })
    validation: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 2051
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'idx', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 8702
    })
})

In [58]:
# Number of token ids in the second row of the tokenized training split.
len(tokenized_datasets['train'][1]['input_ids'])

27

In [80]:
# Drop the raw-text and id columns so that only the label and the tokenizer outputs remain.
# Filtering against the columns that are still present keeps the cell re-runnable:
# calling remove_columns a second time with already-removed names raises an error.
columns_to_drop = [
    name
    for name in ('sentence1', 'sentence2', 'idx')
    if name in tokenized_datasets['train'].column_names
]
if columns_to_drop:
    tokenized_datasets = tokenized_datasets.remove_columns(columns_to_drop)

# 'label' is deliberately not renamed to 'labels' here; in practice the rename
# turned out to be unnecessary.
tokenized_datasets.set_format('torch')
print(tokenized_datasets['train'].column_names)


['label', 'input_ids', 'token_type_ids', 'attention_mask']


In [81]:
from torch.utils.data import DataLoader, Dataset
# With this collate_fn, seq_len may differ from one batch to the next.
train_dataloader = DataLoader(tokenized_datasets['train'], shuffle=True, batch_size=8, collate_fn=data_collator)
eval_dataloader = DataLoader(tokenized_datasets['validation'], batch_size=8, collate_fn=data_collator)

# Peek at a single batch to see what the training DataLoader yields.
# `batch` is kept as a notebook-level name because the forward-pass sanity check reuses it.
batch = next(iter(train_dataloader))
{k: v.shape for k, v in batch.items()}



{'input_ids': torch.Size([8, 32]),
 'token_type_ids': torch.Size([8, 32]),
 'attention_mask': torch.Size([8, 32]),
 'labels': torch.Size([8])}

In [12]:
tokenized_datasets  # After the processing above it can be fed straight into a PyTorch DataLoader; its format now matches a PyTorch Dataset

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 3668
    })
    validation: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 408
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 1725
    })
})

In [82]:
from transformers import AutoModelForSequenceClassification
# Load the checkpoint with a sequence-classification head sized for 2 labels.
model = AutoModelForSequenceClassification.from_pretrained(checkpoint, num_labels=2)


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [84]:
# Sanity check: run the sample batch through the model and print the loss
# together with the shape of the logits.
outputs = model(**batch)
print(outputs.loss, outputs.logits.shape)
# batch

tensor(0.6202, grad_fn=<NllLossBackward0>) torch.Size([8, 2])


In [85]:
# transformers' own AdamW is deprecated in favour of the PyTorch implementation and is
# no longer exported by recent transformers releases, so import it from torch instead.
# NOTE(review): the two implementations do not share the same default hyperparameters
# (e.g. weight_decay); pass them explicitly if parity with older runs matters.
from torch.optim import AdamW
from transformers import get_scheduler

optimizer = AdamW(model.parameters(), lr=5e-5)

num_epochs = 3
num_training_steps = num_epochs * len(train_dataloader)  # num of batches * num of epochs
lr_scheduler = get_scheduler(
    'linear',
    optimizer=optimizer,  # the scheduler controls the learning rate of this optimizer
    num_warmup_steps=0,
    num_training_steps=num_training_steps)
print(num_training_steps)


6948


In [86]:
import torch

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move the model to the chosen device; as the last expression of the cell,
# the returned module is echoed in the output.
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12,

In [89]:
from tqdm import tqdm
import torch.nn as nn

# from_pretrained() returns the model in evaluation mode, so dropout stays disabled
# unless training mode is switched on explicitly. Calling train() here also undoes a
# model.eval() left behind by a previously executed evaluation cell.
model.train()

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, batch in enumerate(tqdm(train_dataloader)):
        # To train on the GPU, every tensor of the batch has to be moved to the device.
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = model(**batch).loss
        loss.backward()
        optimizer.step()
        lr_scheduler.step()  # advance the learning-rate schedule once per optimizer step
        optimizer.zero_grad()

        running_loss += loss.item()

        # Report the mean loss of each window of 10 steps, then start a new window.
        if i % 10 == 9:
            print('[%d, %5d] loss: %.3f' % (epoch + 1, i + 1, running_loss / 10))
            running_loss = 0.0




  0%|          | 10/2316 [00:29<1:53:03,  2.94s/it]

[1,    10] loss: 0.617


  1%|          | 20/2316 [00:59<1:50:23,  2.88s/it]

[1,    20] loss: 0.605


  1%|▏         | 30/2316 [01:30<1:56:24,  3.06s/it]

[1,    30] loss: 0.658


  2%|▏         | 40/2316 [02:00<1:51:23,  2.94s/it]

[1,    40] loss: 0.673


  2%|▏         | 50/2316 [02:30<1:59:10,  3.16s/it]

[1,    50] loss: 0.601


  3%|▎         | 60/2316 [03:00<1:53:12,  3.01s/it]

[1,    60] loss: 0.608


  3%|▎         | 70/2316 [03:30<1:49:14,  2.92s/it]

[1,    70] loss: 0.478


  3%|▎         | 80/2316 [04:01<1:54:17,  3.07s/it]

[1,    80] loss: 0.561


  4%|▍         | 90/2316 [04:30<1:48:13,  2.92s/it]

[1,    90] loss: 0.536


  4%|▍         | 100/2316 [05:01<1:56:49,  3.16s/it]

[1,   100] loss: 0.575


  5%|▍         | 110/2316 [05:31<1:49:42,  2.98s/it]

[1,   110] loss: 0.588


  5%|▌         | 120/2316 [06:01<1:47:21,  2.93s/it]

[1,   120] loss: 0.608


  6%|▌         | 130/2316 [06:31<1:50:47,  3.04s/it]

[1,   130] loss: 0.593


  6%|▌         | 140/2316 [07:01<1:47:00,  2.95s/it]

[1,   140] loss: 0.585


  6%|▋         | 150/2316 [07:31<1:48:00,  2.99s/it]

[1,   150] loss: 0.593


  7%|▋         | 160/2316 [08:01<1:50:02,  3.06s/it]

[1,   160] loss: 0.589


  7%|▋         | 170/2316 [08:32<1:50:36,  3.09s/it]

[1,   170] loss: 0.570


  8%|▊         | 180/2316 [09:02<1:47:34,  3.02s/it]

[1,   180] loss: 0.578


  8%|▊         | 190/2316 [09:32<1:44:43,  2.96s/it]

[1,   190] loss: 0.585


  9%|▊         | 200/2316 [10:01<1:44:41,  2.97s/it]

[1,   200] loss: 0.631


  9%|▉         | 210/2316 [10:32<1:46:30,  3.03s/it]

[1,   210] loss: 0.624


  9%|▉         | 220/2316 [11:01<1:41:35,  2.91s/it]

[1,   220] loss: 0.548


 10%|▉         | 230/2316 [11:31<1:48:44,  3.13s/it]

[1,   230] loss: 0.634


 10%|█         | 240/2316 [12:01<1:41:35,  2.94s/it]

[1,   240] loss: 0.460


 11%|█         | 250/2316 [12:30<1:39:13,  2.88s/it]

[1,   250] loss: 0.638


 11%|█         | 260/2316 [13:00<1:44:02,  3.04s/it]

[1,   260] loss: 0.642


 12%|█▏        | 270/2316 [13:30<1:40:11,  2.94s/it]

[1,   270] loss: 0.573


 12%|█▏        | 280/2316 [13:59<1:36:02,  2.83s/it]

[1,   280] loss: 0.585


 13%|█▎        | 290/2316 [14:29<1:41:26,  3.00s/it]

[1,   290] loss: 0.548


 13%|█▎        | 300/2316 [14:59<1:37:36,  2.90s/it]

[1,   300] loss: 0.556


 13%|█▎        | 310/2316 [15:29<1:43:39,  3.10s/it]

[1,   310] loss: 0.470


 14%|█▍        | 320/2316 [15:58<1:38:18,  2.96s/it]

[1,   320] loss: 0.593


 14%|█▍        | 330/2316 [16:28<1:37:15,  2.94s/it]

[1,   330] loss: 0.572


 15%|█▍        | 340/2316 [16:58<1:38:52,  3.00s/it]

[1,   340] loss: 0.617


 15%|█▌        | 350/2316 [17:27<1:35:01,  2.90s/it]

[1,   350] loss: 0.604


 16%|█▌        | 360/2316 [17:57<1:38:04,  3.01s/it]

[1,   360] loss: 0.624


 16%|█▌        | 370/2316 [18:26<1:35:14,  2.94s/it]

[1,   370] loss: 0.540


 16%|█▋        | 380/2316 [18:55<1:33:52,  2.91s/it]

[1,   380] loss: 0.492


 17%|█▋        | 390/2316 [19:25<1:36:29,  3.01s/it]

[1,   390] loss: 0.594


 17%|█▋        | 400/2316 [19:55<1:32:56,  2.91s/it]

[1,   400] loss: 0.600


 18%|█▊        | 410/2316 [20:25<1:35:34,  3.01s/it]

[1,   410] loss: 0.492


 18%|█▊        | 420/2316 [20:54<1:32:14,  2.92s/it]

[1,   420] loss: 0.670


 19%|█▊        | 430/2316 [21:25<1:37:56,  3.12s/it]

[1,   430] loss: 0.620


 19%|█▉        | 440/2316 [21:54<1:31:07,  2.91s/it]

[1,   440] loss: 0.540


 19%|█▉        | 450/2316 [22:23<1:29:32,  2.88s/it]

[1,   450] loss: 0.592


 20%|█▉        | 460/2316 [22:53<1:32:42,  3.00s/it]

[1,   460] loss: 0.548


 20%|██        | 470/2316 [23:23<1:28:37,  2.88s/it]

[1,   470] loss: 0.555


 21%|██        | 480/2316 [23:53<1:35:00,  3.10s/it]

[1,   480] loss: 0.550


 21%|██        | 490/2316 [24:22<1:29:19,  2.94s/it]

[1,   490] loss: 0.662


 22%|██▏       | 500/2316 [24:52<1:30:29,  2.99s/it]

[1,   500] loss: 0.594


 22%|██▏       | 510/2316 [25:22<1:30:39,  3.01s/it]

[1,   510] loss: 0.542


 22%|██▏       | 520/2316 [25:51<1:26:49,  2.90s/it]

[1,   520] loss: 0.642


 23%|██▎       | 530/2316 [26:22<1:29:11,  3.00s/it]

[1,   530] loss: 0.608


 23%|██▎       | 540/2316 [26:51<1:25:47,  2.90s/it]

[1,   540] loss: 0.569


 24%|██▎       | 550/2316 [27:21<1:31:36,  3.11s/it]

[1,   550] loss: 0.624


 24%|██▍       | 560/2316 [27:50<1:25:33,  2.92s/it]

[1,   560] loss: 0.598


 25%|██▍       | 570/2316 [28:20<1:23:33,  2.87s/it]

[1,   570] loss: 0.602


 25%|██▌       | 580/2316 [28:50<1:27:01,  3.01s/it]

[1,   580] loss: 0.556


 25%|██▌       | 590/2316 [29:19<1:23:07,  2.89s/it]

[1,   590] loss: 0.507


 26%|██▌       | 600/2316 [29:49<1:28:39,  3.10s/it]

[1,   600] loss: 0.519


 26%|██▋       | 610/2316 [30:19<1:23:46,  2.95s/it]

[1,   610] loss: 0.481


 27%|██▋       | 620/2316 [30:49<1:26:25,  3.06s/it]

[1,   620] loss: 0.697


 27%|██▋       | 630/2316 [31:20<1:28:21,  3.14s/it]

[1,   630] loss: 0.615


 28%|██▊       | 640/2316 [31:53<1:33:51,  3.36s/it]

[1,   640] loss: 0.498


 28%|██▊       | 650/2316 [32:23<1:21:46,  2.95s/it]

[1,   650] loss: 0.493


 28%|██▊       | 660/2316 [32:53<1:23:12,  3.01s/it]

[1,   660] loss: 0.637


 29%|██▉       | 670/2316 [33:22<1:21:28,  2.97s/it]

[1,   670] loss: 0.452


 29%|██▉       | 680/2316 [33:51<1:18:25,  2.88s/it]

[1,   680] loss: 0.692


 30%|██▉       | 690/2316 [34:21<1:20:40,  2.98s/it]

[1,   690] loss: 0.554


 30%|██▉       | 690/2316 [34:22<1:21:00,  2.99s/it]


KeyboardInterrupt: 

In [90]:
import evaluate

# Metric object for the MRPC task of GLUE; batches are accumulated and scored at the end.
metric = evaluate.load("glue", "mrpc")

model.eval()
# No gradients are needed for evaluation, so the whole loop runs under no_grad.
with torch.no_grad():
    for batch in eval_dataloader:
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        outputs = model(**batch)
        logits = outputs.logits
        # Predicted class = index of the largest logit along the last dimension.
        predictions = logits.argmax(dim=-1)
        metric.add_batch(predictions=predictions, references=batch["labels"])

metric.compute()

Downloading builder script:   0%|          | 0.00/5.75k [00:00<?, ?B/s]

{'accuracy': 0.6962457337883959, 'f1': 0.8000000000000002}