In [1]:
from tqdm import tqdm

#第7章/加载编码工具
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

tokenizer

BertTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_length=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

In [2]:
#第7章/试编码句子
out = tokenizer.batch_encode_plus(
    batch_text_or_text_pairs=['从明天起，做一个幸福的人。', '喂马，劈柴，周游世界。'],
    truncation=True,
    padding='max_length',
    max_length=17,
    return_tensors='pt',
    return_length=True)

#查看编码输出
for k, v in out.items():
    print(k, v.shape)

#把编码还原为句子
print(tokenizer.decode(out['input_ids'][0]))

input_ids torch.Size([2, 17])
token_type_ids torch.Size([2, 17])
length torch.Size([2])
attention_mask torch.Size([2, 17])
[CLS] 从 明 天 起 ， 做 一 个 幸 福 的 人 。 [SEP] [PAD] [PAD]


In [3]:
#第7章/定义数据集
import torch
from datasets import load_from_disk


class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = load_from_disk('./data/ChnSentiCorp')[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text, label


dataset = Dataset('train')

len(dataset), dataset[20]

(9600, ('非常不错，服务很好，位于市中心区，交通方便，不过价格也高！', 1))

In [4]:
#第7章/定义计算设备
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

device

'cpu'

In [5]:
#第7章/数据整理函数
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    #把数据移动到计算设备上
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)

    return input_ids, attention_mask, token_type_ids, labels

In [6]:
#第7章/数据整理函数试算
#模拟一批数据
data = [
    ('你站在桥上看风景', 1),
    ('看风景的人在楼上看你', 0),
    ('明月装饰了你的窗子', 1),
    ('你装饰了别人的梦', 0),
]

#试算
input_ids, attention_mask, token_type_ids, labels = collate_fn(data)

input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

(torch.Size([4, 500]),
 torch.Size([4, 500]),
 torch.Size([4, 500]),
 tensor([1, 0, 1, 0]))

In [7]:
#第7章/数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

len(loader)

600

In [8]:
#第7章/查看数据样例
for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    break

input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

(torch.Size([16, 500]),
 torch.Size([16, 500]),
 torch.Size([16, 500]),
 tensor([1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1]))

In [9]:
#第7章/加载预训练模型
from transformers import BertModel

pretrained = BertModel.from_pretrained('bert-base-chinese')

#统计参数量
sum(i.numel() for i in pretrained.parameters()) / 10000

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


10226.7648

In [10]:
#第7章/不训练预训练模型,不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad_(False)

In [11]:
#第7章/预训练模型试算
#设定计算设备
pretrained.to(device)

#模型试算
out = pretrained(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)

out.last_hidden_state.shape

torch.Size([16, 500, 768])

In [12]:
#第7章/定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=768, out_features=2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        #使用预训练模型抽取数据特征
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        #对抽取的特征只取第一个token <CLS>的结果做分类即可,结果[16, 768]
        out = self.fc(out.last_hidden_state[:, 0])

        out = out.softmax(dim=1)

        return out


model = Model()

#设定计算设备
model.to(device)

#试算
model(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([16, 2])

In [13]:
#第7章/训练
from transformers import AdamW
from transformers.optimization import get_scheduler


def train():
    #定义优化器
    optimizer = AdamW(model.parameters(), lr=5e-4)

    #定义loss函数
    criterion = torch.nn.CrossEntropyLoss()

    #定义学习率调节器
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    #模型切换到训练模式
    model.train()

    #按批次遍历训练集中的数据
    for i, (input_ids, attention_mask, token_type_ids,
            labels) in tqdm(enumerate(loader)):

        #模型计算
        out = model(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)

        #计算loss并使用梯度下降法优化模型参数
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        #输出各项数据的情况，便于观察
        if i % 20 == 0:
            out = out.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)


train()

1it [00:02,  2.32s/it]

0 0.6699261665344238 0.0004991666666666666 0.5625


21it [00:47,  2.30s/it]

20 0.5833996534347534 0.0004825 0.8125


41it [01:34,  2.31s/it]

40 0.5714215636253357 0.00046583333333333334 0.75


61it [02:20,  2.28s/it]

60 0.6317876577377319 0.00044916666666666667 0.625


81it [03:05,  2.29s/it]

80 0.5275751352310181 0.0004325 0.875


101it [03:51,  2.32s/it]

100 0.5517035722732544 0.0004158333333333333 0.6875


121it [04:38,  2.32s/it]

120 0.5092716813087463 0.0003991666666666667 0.75


141it [05:24,  2.30s/it]

140 0.4396957755088806 0.00038250000000000003 0.875


161it [06:09,  2.27s/it]

160 0.4306226372718811 0.00036583333333333335 0.9375


181it [06:55,  2.27s/it]

180 0.4795974791049957 0.0003491666666666667 0.9375


201it [07:41,  2.29s/it]

200 0.5424973964691162 0.0003325 0.8125


221it [08:26,  2.29s/it]

220 0.4641924500465393 0.0003158333333333334 0.9375


241it [09:12,  2.27s/it]

240 0.43923714756965637 0.0002991666666666667 0.9375


261it [09:58,  2.29s/it]

260 0.5213017463684082 0.0002825 0.8125


281it [10:44,  2.29s/it]

280 0.49129173159599304 0.0002658333333333333 0.8125


301it [11:30,  2.28s/it]

300 0.4349232017993927 0.0002491666666666667 0.875


321it [12:16,  2.30s/it]

320 0.4721528887748718 0.0002325 0.8125


341it [13:02,  2.36s/it]

340 0.4591294527053833 0.00021583333333333334 0.8125


361it [13:49,  2.33s/it]

360 0.46062272787094116 0.00019916666666666667 0.875


381it [14:35,  2.30s/it]

380 0.3842809498310089 0.0001825 1.0


401it [15:21,  2.30s/it]

400 0.49416786432266235 0.00016583333333333334 0.8125


421it [16:08,  2.30s/it]

420 0.5344399809837341 0.00014916666666666667 0.8125


441it [16:54,  2.29s/it]

440 0.46505314111709595 0.00013250000000000002 0.8125


461it [17:40,  2.32s/it]

460 0.41973111033439636 0.00011583333333333333 0.875


481it [18:26,  2.30s/it]

480 0.46818891167640686 9.916666666666667e-05 0.875


501it [19:12,  2.33s/it]

500 0.41528356075286865 8.25e-05 1.0


521it [19:59,  2.30s/it]

520 0.3984796106815338 6.583333333333333e-05 0.9375


541it [20:45,  2.33s/it]

540 0.41384023427963257 4.9166666666666665e-05 0.9375


561it [21:31,  2.30s/it]

560 0.44433972239494324 3.2500000000000004e-05 0.875


581it [22:17,  2.28s/it]

580 0.37153586745262146 1.5833333333333336e-05 0.9375


600it [23:01,  2.30s/it]


In [14]:
#第7章/测试
def test():
    #定义测试数据集加载器
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=30,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    
    print('length of test set: ', len(loader_test))

    #下游任务模型切换到运行模式
    model.eval()
    correct = 0
    total = 0

    #按批次遍历测试集中的数据
    for i, (input_ids, attention_mask, token_type_ids,
            labels) in tqdm(enumerate(loader_test)):

        print(i)

        #计算
        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        #统计正确率
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

length of test set:  40


0it [00:00, ?it/s]

0


1it [00:04,  4.09s/it]

1


2it [00:08,  4.04s/it]

2


3it [00:12,  4.01s/it]

3


4it [00:16,  4.00s/it]

4


5it [00:20,  4.00s/it]

5


6it [00:24,  4.02s/it]

6


7it [00:28,  4.03s/it]

7


8it [00:32,  4.03s/it]

8


9it [00:36,  4.02s/it]

9


10it [00:40,  4.02s/it]

10


11it [00:44,  4.02s/it]

11


12it [00:48,  4.02s/it]

12


13it [00:52,  4.02s/it]

13


14it [00:56,  4.01s/it]

14


15it [01:00,  4.03s/it]

15


16it [01:04,  4.05s/it]

16


17it [01:08,  4.03s/it]

17


18it [01:12,  4.03s/it]

18


19it [01:16,  4.03s/it]

19


20it [01:20,  4.03s/it]

20


21it [01:24,  4.03s/it]

21


22it [01:28,  4.03s/it]

22


23it [01:32,  4.04s/it]

23


24it [01:36,  4.03s/it]

24


25it [01:40,  4.02s/it]

25


26it [01:44,  4.01s/it]

26


27it [01:48,  4.02s/it]

27


28it [01:52,  4.01s/it]

28


29it [01:56,  4.01s/it]

29


30it [02:00,  4.01s/it]

30


31it [02:04,  4.01s/it]

31


32it [02:08,  4.00s/it]

32


33it [02:12,  4.00s/it]

33


34it [02:16,  4.00s/it]

34


35it [02:20,  3.99s/it]

35


36it [02:24,  3.99s/it]

36


37it [02:28,  4.00s/it]

37


38it [02:32,  3.99s/it]

38


39it [02:36,  4.00s/it]

39


40it [02:40,  4.01s/it]

0.8775



