In [39]:
import torch
from transformers import AdamW
from datasets import load_dataset
from transformers import BertModel
from transformers import BertTokenizer

In [40]:
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### 定义数据集

In [41]:
class Dataset(torch.utils.data.Dataset):
    """Review texts from the ChnSentiCorp corpus, restricted to long samples.

    Args:
        split: name of the dataset split to load (e.g. ``'train'`` or ``'test'``).
        min_length: only texts with strictly more than this many characters
            are kept. Defaults to 30, the threshold that was previously
            hard-coded.
    """

    def __init__(self, split, min_length=30):
        super().__init__()
        raw = load_dataset(path='seamew/ChnSentiCorp', split=split)
        self.min_length = min_length

        def is_long_enough(row):
            # Character count of the raw text, not the number of tokens.
            return len(row['text']) > min_length

        # Drop the short reviews.
        self.dataset = raw.filter(is_long_enough)

    def __len__(self):
        """Number of texts that passed the length filter."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the raw text of sample ``idx`` (only the text field is returned)."""
        return self.dataset[idx]['text']

In [42]:
# Build the training split, then display how many reviews survived the
# length filter together with one sample text.
dataset = Dataset(split='train')

len(dataset), dataset[1]

Found cached dataset chn_senti_corp (C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)
Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85\cache-9bcfb9e05326d42e.arrow


(9192, '15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错')

### 加载tokenizer

In [43]:
# Load the vocabulary and tokenizer of the bert-base-chinese checkpoint.
token = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-chinese')

# Display the tokenizer configuration as the cell output.
token

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### 定义批处理函数

In [44]:
def collate_fn(data, mask_loc=15, max_length=30):
    """Tokenize a batch of texts and mask one fixed position for prediction.

    Args:
        data: list of raw text strings forming one batch.
        mask_loc: index of the token that is replaced by the mask token.
            Defaults to 15, the position that was previously hard-coded.
        max_length: every sequence is padded or truncated to this many
            tokens. Defaults to 30.

    Returns:
        Tuple ``(input_ids, attention_mask, token_type_ids, labels)`` of
        tensors moved to ``device``. ``labels`` holds, for every sample, the
        original token id found at ``mask_loc`` before it was masked.
    """
    # Encode the whole batch to fixed-length tensors.
    encoded = token.batch_encode_plus(batch_text_or_text_pairs=data,
                                      truncation=True,
                                      padding='max_length',
                                      max_length=max_length,
                                      return_tensors='pt')

    # input_ids: the token ids of the encoded text.
    # attention_mask: 0 on padded positions, 1 elsewhere (padding takes no
    # part in the attention computation).
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded['token_type_ids'].to(device)

    # Save the true token before overwriting it. The slice is a view of
    # input_ids, so it must be cloned or the labels would become mask ids too.
    # NOTE(review): a text that tokenizes to fewer than mask_loc tokens would
    # yield a padding/special token as label -- presumably the dataset's
    # length filter is meant to prevent this; verify.
    labels = input_ids[:, mask_loc].clone()
    # mask_token_id avoids rebuilding the full vocabulary dict on every batch.
    input_ids[:, mask_loc] = token.mask_token_id

    return input_ids, attention_mask, token_type_ids, labels

### 定义数据加载器

In [45]:
# Batches of 16 reviews, reshuffled on every pass; the final incomplete batch
# is dropped so that all batches have the same size.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

# Fetch a single batch for inspection (replaces the for/break idiom and its
# unused loop index).
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))

print(len(loader))
print(token.decode(input_ids[0]))
print(token.decode(labels[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape

574
[CLS] 每 次 开 机 时 音 箱 都 有 杂 音 ， 必 须 [MASK] 机 重 启 才 可 以 ， 现 在 又 不 超 过 [SEP]
关


(torch.Size([16, 30]),
 torch.Size([16, 30]),
 torch.Size([16, 30]),
 torch.Size([16]))

### 加载bert中文模型

In [46]:
# Load the pretrained Chinese BERT encoder and place it on the target device.
pretrained = BertModel.from_pretrained('bert-base-chinese').to(device)

# The encoder is not trained, so none of its parameters need gradients.
pretrained.requires_grad_(False)

# Trial run on the sample batch to inspect the shape of the hidden states.
out = pretrained(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)

out.last_hidden_state.shape

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([16, 30, 768])

### 定义下游任务模型

In [47]:
class Model(torch.nn.Module):
    """Masked-token prediction head on top of the module-level BERT encoder.

    The encoder (``pretrained``) is called without gradient tracking; the
    parameters owned by this module are the decoder weight and its bias.
    """

    def __init__(self):
        super().__init__()
        vocab_size = token.vocab_size
        # Projects a 768-dimensional hidden state to one logit per vocabulary entry.
        self.decoder = torch.nn.Linear(768, vocab_size, bias=False)
        # The bias is created as a separate zero-initialised parameter and then
        # attached to the decoder, so the same tensor is reachable both as
        # `bias` and as `decoder.bias`.
        self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, input_ids, attention_mask, token_type_ids, mask_loc = 15):
        """Return vocabulary logits for the token at position ``mask_loc``.

        Args:
            input_ids, attention_mask, token_type_ids: encoder inputs, passed
                through to ``pretrained`` unchanged.
            mask_loc: sequence index whose hidden state is decoded (default 15).

        Returns:
            Tensor with one row of vocabulary logits per sample in the batch.
        """
        # The encoder only provides features; no graph is recorded for it.
        with torch.no_grad():
            encoded = pretrained(input_ids=input_ids,
                                 attention_mask=attention_mask,
                                 token_type_ids=token_type_ids)

        hidden_at_mask = encoded.last_hidden_state[:, mask_loc]
        return self.decoder(hidden_at_mask)

# Instantiate the prediction head and move its parameters to the chosen device.
model = Model()
model = model.to(device)

# Smoke test on the sample batch: display the shape of the returned logits.
model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
).shape

torch.Size([16, 21128])

### 训练下游任务模型

In [48]:
# Train the decoder head.
# torch.optim.AdamW is used instead of the deprecated transformers.AdamW;
# eps and weight_decay are set explicitly to the defaults of the transformers
# implementation so the optimisation stays comparable.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss().to(device)

model.train()
for epoch in range(5):
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # Every 50 steps, report the loss and accuracy of the current batch.
        if i % 50 == 0:
            predicted = out.argmax(dim=1)
            accuracy = (predicted == labels).sum().item() / len(labels)

            print(epoch, i, loss.item(), accuracy)



0 0 10.13753890991211 0.0
0 50 7.677145004272461 0.375
0 100 6.194839954376221 0.125
0 150 5.391766548156738 0.1875
0 200 5.27683162689209 0.3125
0 250 3.697930097579956 0.5
0 300 4.379644870758057 0.5
0 350 5.089016437530518 0.25
0 400 2.495898723602295 0.625
0 450 4.2066450119018555 0.3125
0 500 4.9397053718566895 0.25
0 550 2.8925065994262695 0.5625
1 0 2.7506492137908936 0.625
1 50 2.4521594047546387 0.625
1 100 3.1646573543548584 0.4375
1 150 1.888915777206421 0.75
1 200 1.1031101942062378 0.875
1 250 2.2484664916992188 0.6875
1 300 2.2036807537078857 0.6875
1 350 3.0364530086517334 0.375
1 400 1.3976575136184692 0.6875
1 450 2.1122806072235107 0.625
1 500 1.6212650537490845 0.75
1 550 1.2603724002838135 0.75
2 0 0.9670196175575256 0.9375
2 50 1.4742933511734009 0.75
2 100 1.244105577468872 0.6875
2 150 0.8302542567253113 0.875
2 200 0.7655669450759888 0.875
2 250 0.8564367294311523 0.8125
2 300 0.8216100335121155 0.75
2 350 1.0247125625610352 0.6875
2 400 0.8160069584846497 0.812

### 测试

In [49]:
def test(num_batches=5):
    """Evaluate the masked-token model on a sample of the test split.

    Args:
        num_batches: maximum number of shuffled batches (32 reviews each) to
            evaluate. Defaults to 5, the limit that was previously hard-coded.

    Returns:
        Fraction of evaluated samples whose masked token was predicted
        correctly, or NaN when no batch was evaluated (previously this case
        raised ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0

    # Shuffling means each call scores a different random sample of the split.
    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if i >= num_batches:
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

        # Show the first sample of the batch: masked input, truth, prediction.
        print(token.decode(input_ids[0]))
        print('标准: ', token.decode(labels[0]))
        print('预测: ', token.decode(out[0]))

    if total == 0:
        # Nothing was evaluated (num_batches <= 0 or fewer than 32 test samples).
        return float('nan')

    return correct / total

# Run the evaluation; the returned accuracy is displayed as the cell output.
test()


Found cached dataset chn_senti_corp (C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)
Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85\cache-7a7e6f6e90083519.arrow


0
[CLS] 集 成 显 卡, 玩 使 命 召 唤 5 等 要 求 [MASK] 高 的 游 戏 比 较 卡, 没 有 高 清 输 [SEP]
标准:  很
预测:  很
1
[CLS] 卡 梅 拉 是 之 勇 敢 的 小 鸡 简 单 而 执 [MASK] 我 喜 欢 哦 ~ 同 事 的 小 宝 宝 也 很 [SEP]
标准:  着
预测:  着
2
[CLS] 渡 假 村 周 围 景 色 不 错, 但 较 落 乡 [MASK] 硬 件 太 差, 服 务 水 准 有 待 提 高 [SEP]
标准:  .
预测:  ,
3
[CLS] 外 观 很 美 观 大 方 ， 配 置 均 衡 合 理 [MASK] 做 工 不 错 ， 屏 幕 效 果 很 棒 ， 镜 [SEP]
标准:  ，
预测:  ，
4
[CLS] 我 在 京 东 网 买 的 华 硕 笔 记 本 ， 第 [MASK] 天 就 发 现 严 重 的 问 题 ， 无 法 使 [SEP]
标准:  二
预测:  二


0.66875

### 使用

In [80]:
# Fill in the [MASK] token of a hand-written sentence with the trained head.
sentence = ['太没[MASK]理']
data = token.batch_encode_plus(batch_text_or_text_pairs=sentence, 
                                truncation=True,
                                padding='max_length',
                                max_length=30,
                                return_tensors='pt',
                                return_length=True)

input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
token_type_ids = data['token_type_ids'].to(device)

# Locate the mask token in the encoded ids instead of hard-coding its index,
# so the cell keeps working when the sentence is edited.
mask_positions = (input_ids[0] == token.mask_token_id).nonzero(as_tuple=True)[0]
if mask_positions.numel() != 1:
    raise ValueError('sentence must contain exactly one [MASK] token within the first 30 tokens')
mask_loc = mask_positions.item()

# Inference only: no gradient graph is needed.
with torch.no_grad():
    out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, mask_loc = mask_loc)
res = out.argmax(dim=1)

token.decode(res)

'道'