In [1]:
from tqdm import tqdm
#第7章/加载编码工具
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained('microsoft/deberta-v3-large')
model = AutoModel.from_pretrained('microsoft/deberta-v3-large')

  return self.fget.__get__(instance, owner)()


In [2]:
#第7章/试编码句子
out = tokenizer.batch_encode_plus(
    batch_text_or_text_pairs=['从明天起，做一个幸福的人。', '喂马，劈柴，周游世界。'],
    truncation=True,
    padding='max_length',
    max_length=17,
    return_tensors='pt',
    return_length=True)

#查看编码输出
for k, v in out.items():
    print(k, v.shape)

#把编码还原为句子
print(tokenizer.decode(out['input_ids'][0]))

input_ids torch.Size([2, 17])
token_type_ids torch.Size([2, 17])
attention_mask torch.Size([2, 17])
length torch.Size([2])
[CLS] 从明天起,做一个幸福的人。[SEP][PAD][PAD]


In [3]:
#第7章/定义数据集
import torch
from datasets import load_from_disk


class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        self.dataset = load_from_disk('./data/ChnSentiCorp')[split]

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, i):
        text = self.dataset[i]['text']
        label = self.dataset[i]['label']

        return text, label


dataset = Dataset('train')

len(dataset), dataset[20]

(9600, ('非常不错，服务很好，位于市中心区，交通方便，不过价格也高！', 1))

In [4]:
#第7章/定义计算设备
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

device

'cpu'

In [5]:
#第7章/数据整理函数
def collate_fn(data):
    sents = [i[0] for i in data]
    labels = [i[1] for i in data]

    #编码
    data = tokenizer.batch_encode_plus(batch_text_or_text_pairs=sents,
                                   truncation=True,
                                   padding='max_length',
                                   max_length=500,
                                   return_tensors='pt',
                                   return_length=True)

    #input_ids:编码之后的数字
    #attention_mask:是补零的位置是0,其他位置是1
    input_ids = data['input_ids']
    attention_mask = data['attention_mask']
    token_type_ids = data['token_type_ids']
    labels = torch.LongTensor(labels)

    #把数据移动到计算设备上
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)
    token_type_ids = token_type_ids.to(device)
    labels = labels.to(device)

    return input_ids, attention_mask, token_type_ids, labels

In [6]:
#第7章/数据整理函数试算
#模拟一批数据
data = [
    ('你站在桥上看风景', 1),
    ('看风景的人在楼上看你', 0),
    ('明月装饰了你的窗子', 1),
    ('你装饰了别人的梦', 0),
]

#试算
input_ids, attention_mask, token_type_ids, labels = collate_fn(data)

input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

(torch.Size([4, 500]),
 torch.Size([4, 500]),
 torch.Size([4, 500]),
 tensor([1, 0, 1, 0]))

In [7]:
#第7章/数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

len(loader)

600

In [8]:
#第7章/查看数据样例
for i, (input_ids, attention_mask, token_type_ids,
        labels) in enumerate(loader):
    break

input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

(torch.Size([16, 500]),
 torch.Size([16, 500]),
 torch.Size([16, 500]),
 tensor([0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1]))

In [9]:
#统计参数量
sum(i.numel() for i in model.parameters()) / 10000

43401.216

In [10]:
#第7章/不训练预训练模型,不需要计算梯度
for param in model.parameters():
    param.requires_grad_(True)

In [11]:
#第7章/预训练模型试算
#设定计算设备
model.to(device)

#模型试算
out = model(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)

out.last_hidden_state.shape

torch.Size([16, 500, 1024])

In [12]:
#第7章/定义下游任务模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(in_features=1024, out_features=2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        #使用预训练模型抽取数据特征
        
        out = model(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)

        #对抽取的特征只取第一个token <CLS>的结果做分类即可,结果[16, 1024]
        out = self.fc(out.last_hidden_state[:, 0])

        out = out.softmax(dim=1)

        return out


model2 = Model()

#设定计算设备
model2.to(device)

#试算
model2(input_ids=input_ids,
      attention_mask=attention_mask,
      token_type_ids=token_type_ids).shape

torch.Size([16, 2])

In [13]:
#第7章/训练
from transformers import AdamW
from transformers.optimization import get_scheduler


def train():
    #定义优化器
    optimizer = AdamW(model2.parameters(), lr=5e-4)

    #定义loss函数
    criterion = torch.nn.CrossEntropyLoss()

    #定义学习率调节器
    scheduler = get_scheduler(name='linear',
                              num_warmup_steps=0,
                              num_training_steps=len(loader),
                              optimizer=optimizer)

    #模型切换到训练模式
    model2.train()

    #按批次遍历训练集中的数据
    for i, (input_ids, attention_mask, token_type_ids,
            labels) in tqdm(enumerate(loader)):

        #模型计算
        out = model2(input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_type_ids=token_type_ids)

        #计算loss并使用梯度下降法优化模型参数
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()

        #输出各项数据的情况，便于观察
        if i % 20 == 0:
            out = out.argmax(dim=1)
            accuracy = (out == labels).sum().item() / len(labels)
            lr = optimizer.state_dict()['param_groups'][0]['lr']
            print(i, loss.item(), lr, accuracy)


train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
1it [00:43, 43.64s/it]

0 0.6900021433830261 0.0004991666666666666 0.5625


21it [15:15, 43.69s/it]

20 0.6911894083023071 0.0004825 0.5


41it [29:50, 42.87s/it]

40 0.6758537888526917 0.00046583333333333334 0.625


61it [43:45, 41.35s/it]

60 0.7096973657608032 0.00044916666666666667 0.5


81it [57:35, 41.61s/it]

80 0.7274434566497803 0.0004325 0.3125


101it [1:11:29, 41.66s/it]

100 0.6676782369613647 0.0004158333333333333 0.75


121it [1:25:28, 42.22s/it]

120 0.7407042384147644 0.0003991666666666667 0.5


141it [1:39:20, 41.63s/it]

140 0.7311270236968994 0.00038250000000000003 0.5625


161it [1:53:11, 41.51s/it]

160 0.627440333366394 0.00036583333333333335 0.875


181it [2:07:02, 41.53s/it]

180 0.694422721862793 0.0003491666666666667 0.4375


201it [2:20:50, 41.46s/it]

200 0.7326825857162476 0.0003325 0.4375


221it [2:34:46, 41.60s/it]

220 0.6928690671920776 0.0003158333333333334 0.5625


241it [2:48:38, 42.08s/it]

240 0.7145330905914307 0.0002991666666666667 0.5625


261it [3:02:31, 41.81s/it]

260 0.6566442847251892 0.0002825 0.75


281it [3:16:26, 41.77s/it]

280 0.7089576721191406 0.0002658333333333333 0.4375


301it [3:30:20, 41.48s/it]

300 0.7007237672805786 0.0002491666666666667 0.5


321it [3:44:17, 41.86s/it]

320 0.739691436290741 0.0002325 0.375


341it [3:58:09, 41.40s/it]

340 0.6811448931694031 0.00021583333333333334 0.4375


361it [4:12:06, 41.44s/it]

360 0.7289862036705017 0.00019916666666666667 0.5


381it [4:25:58, 41.42s/it]

380 0.6506887674331665 0.0001825 0.6875


401it [4:39:53, 41.22s/it]

400 0.7367125749588013 0.00016583333333333334 0.4375


421it [4:53:47, 41.67s/it]

420 0.654661238193512 0.00014916666666666667 0.6875


441it [5:07:39, 41.60s/it]

440 0.7555987238883972 0.00013250000000000002 0.375


461it [5:21:33, 41.93s/it]

460 0.7035426497459412 0.00011583333333333333 0.4375


481it [5:35:22, 41.52s/it]

480 0.669472873210907 9.916666666666667e-05 0.6875


501it [5:49:13, 41.55s/it]

500 0.7772378921508789 8.25e-05 0.25


521it [6:03:03, 41.53s/it]

520 0.6892099380493164 6.583333333333333e-05 0.5625


541it [6:16:57, 41.75s/it]

540 0.6918518543243408 4.9166666666666665e-05 0.4375


561it [6:30:43, 41.25s/it]

560 0.6738533973693848 3.2500000000000004e-05 0.5625


581it [6:44:28, 40.89s/it]

580 0.6709483861923218 1.5833333333333336e-05 0.5625


600it [6:57:29, 41.75s/it]


In [14]:
#第7章/测试
def test():
    #定义测试数据集加载器
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=30,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    
    print('length of test set: ', len(loader_test))

    #下游任务模型切换到运行模式
    model2.eval()
    correct = 0
    total = 0

    #按批次遍历测试集中的数据
    for i, (input_ids, attention_mask, token_type_ids,
            labels) in tqdm(enumerate(loader_test)):

        print(i)

        #计算
        with torch.no_grad():
            out = model2(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        #统计正确率
        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    print(correct / total)


test()

length of test set:  40


0it [00:00, ?it/s]

0


1it [00:17, 17.34s/it]

1


2it [00:34, 17.15s/it]

2


3it [00:51, 17.10s/it]

3


4it [01:08, 17.06s/it]

4


5it [01:25, 17.05s/it]

5


6it [01:42, 17.05s/it]

6


7it [01:59, 17.08s/it]

7


8it [02:16, 17.04s/it]

8


9it [02:33, 17.04s/it]

9


10it [02:50, 17.11s/it]

10


11it [03:07, 17.10s/it]

11


12it [03:24, 17.07s/it]

12


13it [03:41, 17.03s/it]

13


14it [03:58, 17.03s/it]

14


15it [04:16, 17.07s/it]

15


16it [04:33, 17.07s/it]

16


17it [04:50, 17.10s/it]

17


18it [05:07, 17.11s/it]

18


19it [05:24, 17.08s/it]

19


20it [05:41, 17.08s/it]

20


21it [05:58, 17.07s/it]

21


22it [06:15, 17.07s/it]

22


23it [06:32, 17.05s/it]

23


24it [06:49, 17.04s/it]

24


25it [07:06, 17.05s/it]

25


26it [07:23, 17.07s/it]

26


27it [07:40, 17.06s/it]

27


28it [07:58, 17.08s/it]

28


29it [08:14, 17.04s/it]

29


30it [08:32, 17.06s/it]

30


31it [08:49, 17.08s/it]

31


32it [09:06, 17.11s/it]

32


33it [09:23, 17.08s/it]

33


34it [09:40, 17.06s/it]

34


35it [09:57, 17.05s/it]

35


36it [10:14, 17.08s/it]

36


37it [10:31, 17.07s/it]

37


38it [10:48, 17.06s/it]

38


39it [11:05, 17.06s/it]

39


40it [11:22, 17.07s/it]

0.5891666666666666



