In [47]:
import torch
from transformers import AdamW
from datasets import load_dataset
from transformers import BertModel
from transformers import BertTokenizer

In [48]:
# Pick the compute device. The original notebook always forced CPU (the CUDA
# check was immediately overwritten); FORCE_CPU keeps that default explicit.
# Set FORCE_CPU = False to use a GPU when one is available.
FORCE_CPU = True
device = 'cpu' if FORCE_CPU or not torch.cuda.is_available() else 'cuda'

### 定义数据集

In [49]:
class Dataset(torch.utils.data.Dataset):
    """ChnSentiCorp split that yields only the raw review text.

    Reviews of ``min_length`` characters or fewer are dropped so that samples
    tend to be long enough to have a real token at the fixed position that
    ``collate_fn`` masks.

    NOTE(review): the filter counts characters, not tokens. Runs of digits or
    Latin letters (e.g. "15.4寸") may tokenize into fewer tokens than
    characters, so a kept review can still end before the masked position.
    Confirm whether this matters.
    """

    def __init__(self, split, min_length=30):
        super().__init__()
        dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

        # Pass the threshold via fn_kwargs so the predicate stays a plain static method.
        self.dataset = dataset.filter(self._keep, fn_kwargs={'min_length': min_length})

    @staticmethod
    def _keep(example, min_length=30):
        """Return True when ``example['text']`` is strictly longer than ``min_length`` characters."""
        return len(example['text']) > min_length

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return only the review text of sample ``idx``; the label column is unused."""
        return self.dataset[idx]['text']

In [50]:
# Training split, keeping only reviews longer than 30 characters.
dataset = Dataset(split='train')

(len(dataset), dataset[1])

Found cached dataset chn_senti_corp (C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)
Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85\cache-9bcfb9e05326d42e.arrow


(9192, '15.4寸笔记本的键盘确实爽，基本跟台式机差不多了，蛮喜欢数字小键盘，输数字特方便，样子也很美观，做工也相当不错')

### 加载tokenizer

In [51]:
# Load the bert-base-chinese vocabulary together with its tokenizer.
token = BertTokenizer.from_pretrained(pretrained_model_name_or_path='bert-base-chinese')

token

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### 定义批处理函数

In [52]:
def collate_fn(data, mask_position=15, max_length=30):
    """Tokenize a batch of texts and replace one fixed token position with [MASK].

    Args:
        data: list of raw review strings (what ``Dataset.__getitem__`` returns).
        mask_position: index of the token to hide and predict (index 0 is [CLS]).
        max_length: every sequence is truncated/padded to exactly this many tokens.

    Returns:
        (input_ids, attention_mask, token_type_ids, labels), all on ``device``.
        The first three are (batch, max_length). ``labels`` is (batch,) and holds
        the original token ids that were overwritten with [MASK].

    Raises:
        ValueError: if ``mask_position`` does not fall inside ``max_length``.
    """
    if not 0 <= mask_position < max_length:
        raise ValueError(f'mask_position must be in [0, {max_length}), got {mask_position}')

    encoded = token.batch_encode_plus(batch_text_or_text_pairs=data,
                                      truncation=True,
                                      padding='max_length',
                                      max_length=max_length,
                                      return_tensors='pt')

    # input_ids: token ids after encoding.
    # attention_mask: 0 on padding positions (excluded from attention), 1 elsewhere.
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)
    token_type_ids = encoded['token_type_ids'].to(device)

    # Save the true ids as targets before overwriting them; clone() because the
    # column slice is a view into input_ids.
    # NOTE: if a text yields fewer than mask_position tokens after [CLS], the
    # target here is [SEP] or [PAD] rather than a real word.
    labels = input_ids[:, mask_position].clone()
    # mask_token_id avoids rebuilding the full vocab dict on every batch.
    input_ids[:, mask_position] = token.mask_token_id

    return input_ids, attention_mask, token_type_ids, labels

### 定义数据加载器

In [53]:
# Shuffled training batches of 16; drop_last keeps every batch full-sized.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# Fetch a single batch to sanity-check masking and tensor shapes.
input_ids, attention_mask, token_type_ids, labels = next(iter(loader))

print(len(loader))
print(token.decode(input_ids[0]))
print(token.decode(labels[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels.shape

574
[CLS] 这 本 书 的 作 者 自 己 有 多 年 食 品 添 [MASK] 剂 行 业 工 作 的 经 验 ， 而 且 自 己 [SEP]
加


(torch.Size([16, 30]),
 torch.Size([16, 30]),
 torch.Size([16, 30]),
 torch.Size([16]))

### 加载bert中文模型

In [54]:
# Load the pretrained Chinese BERT encoder and place it on the chosen device.
pretrained = BertModel.from_pretrained('bert-base-chinese')
pretrained = pretrained.to(device)

# The encoder is not trained: freeze all of its parameters so no gradients are computed.
pretrained.requires_grad_(False)

# Smoke-test a forward pass on the sample batch.
out = pretrained(input_ids=input_ids,
                 attention_mask=attention_mask,
                 token_type_ids=token_type_ids)

out.last_hidden_state.shape

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([16, 30, 768])

### 定义下游任务模型

In [55]:
class Model(torch.nn.Module):
    """Linear head that predicts the token hidden at one fixed position.

    The frozen encoder is the module-level ``pretrained`` BertModel. It is NOT a
    submodule, so ``parameters()`` returns only the decoder weight and bias.

    Args:
        hidden_size: width of the encoder's hidden states (768 in the original).
        vocab_size: number of output classes; defaults to ``token.vocab_size``.
        mask_position: sequence index whose hidden state is decoded. It must match
            the position ``collate_fn`` masks (15 in the original).
    """

    def __init__(self, hidden_size=768, vocab_size=None, mask_position=15):
        super().__init__()
        if vocab_size is None:
            vocab_size = token.vocab_size
        self.mask_position = mask_position
        self.decoder = torch.nn.Linear(hidden_size, vocab_size, bias=False)
        # Bias kept as a separate parameter and attached to the decoder;
        # presumably mirrors the cls.predictions.{decoder.weight,bias} checkpoint layout.
        self.bias = torch.nn.Parameter(torch.zeros(vocab_size))
        self.decoder.bias = self.bias

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return (batch, vocab_size) logits for the token at ``self.mask_position``."""
        # Encoder is frozen; skip graph construction for it.
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        return self.decoder(out.last_hidden_state[:, self.mask_position])

# Build the trainable head and move it to the selected device.
model = Model()
model = model.to(device)

# Smoke test on the sample batch: expect (batch, vocab_size) logits.
model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
).shape

torch.Size([16, 21128])

### 训练下游任务模型

In [56]:
# Train the decoder head only; the BERT encoder stays frozen.
EPOCHS = 5
LOG_EVERY = 50  # print loss/accuracy every LOG_EVERY batches

# torch.optim.AdamW replaces the deprecated transformers.AdamW. eps/weight_decay
# are set to transformers.AdamW's documented defaults (1e-6, 0.0) to keep the
# original optimisation settings; torch's own defaults differ.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4, eps=1e-6, weight_decay=0.0)
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(EPOCHS):
    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
        out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        loss = criterion(out, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if i % LOG_EVERY == 0:
            preds = out.argmax(dim=1)
            accuracy = (preds == labels).sum().item() / len(labels)
            print(epoch, i, loss.item(), accuracy)

0 0 10.064632415771484 0.0


KeyboardInterrupt: 

### 测试

In [None]:
def test(max_batches=5):
    """Measure masked-token accuracy on random batches of the test split.

    Args:
        max_batches: number of 32-sample batches to score (original: 5);
            ``None`` scores every batch.

    Returns:
        Fraction of masked positions whose argmax prediction equals the label.

    Raises:
        ValueError: if no batch was evaluated (e.g. ``max_batches <= 0``).
    """
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('test'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    with torch.no_grad():
        for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
            if max_batches is not None and i >= max_batches:
                break

            print(i)

            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
            out = out.argmax(dim=1)
            correct += (out == labels).sum().item()
            total += len(labels)

            # Show one example per batch: masked input, then true token and prediction.
            print(token.decode(input_ids[0]))
            print(token.decode(labels[0]), token.decode(out[0]))

    if total == 0:
        raise ValueError('no test batches were evaluated')
    return correct / total


test()
