In [2]:
import torch
import random
from transformers import AdamW
from datasets import load_dataset
from transformers import BertModel
from transformers import BertTokenizer

In [3]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = 'cpu'

### 定义数据集

In [4]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        super().__init__()
        dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

        def f(data):
            return len(data['text']) > 40
        
        self.dataset = dataset.filter(f)    # 过滤
    
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        text = self.dataset[idx]['text']

        # 切分一句话为前半句和后半句
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0

        # 有一半的概率把后半句替换为一句无关的话
        if(random.randint(0,1) == 0):
            j = random.randint(0, len(self.dataset)-1)
            sentence2 = self.dataset[j]['text'][20:40]
            label = 1

        return sentence1, sentence2, label

In [6]:
dataset = Dataset('train')

sentence1, sentence2, label = dataset[0]

len(dataset), sentence1, sentence2, label

Found cached dataset chn_senti_corp (C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)
Loading cached processed dataset at C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85\cache-8cbabdef7cb17c1f.arrow


(8001, '选择珠江花园的原因就是方便，有电动扶梯直', '点，哪里开线了。而且都已经做了相应的处理', 1)

### 加载tokenizer

In [7]:
# 加载字典和分词工具
token = BertTokenizer.from_pretrained('bert-base-chinese')

token

PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})

### 定义批处理函数

In [8]:
def collate_fn(data):
    sents = [i[:2] for i in data]
    labels = [i[2] for i in data]

    # 编码
    data = token.batch_encode_plus(batch_text_or_text_pairs=sents, 
                                    truncation=True,
                                    padding='max_length',
                                    max_length=45,
                                    return_tensors='pt',
                                    return_length=True)

    # input_ids: 编码之后的数字
    # attention: 是补零的位置是0, 其他位置是1(padding部分不参与注意力机制计算)
    input_ids = data['input_ids'].to(device)
    attention_mask = data['attention_mask'].to(device)
    token_type_ids = data['token_type_ids'].to(device)
    labels = torch.LongTensor(labels).to(device)

    return input_ids, attention_mask, token_type_ids, labels

### 定义数据加载器

In [18]:
loader = torch.utils.data.DataLoader(dataset = dataset,
                                    batch_size = 16, 
                                    collate_fn = collate_fn,
                                    shuffle = True,
                                    drop_last=True)

for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    break

print(len(loader))
print(token.decode(input_ids[0]))
input_ids.shape, attention_mask.shape, token_type_ids.shape, labels

500
[CLS] 分 量 够 轻 了 ， 13 寸 的 屏 幕 刚 刚 好 ， 再 小 看 得 [SEP] 就 不 舒 服 了 ， 外 观 不 错 ， 性 能 过 得 去 ， 加 [UNK] [SEP] [PAD] [PAD] [PAD] [PAD]


(torch.Size([16, 45]),
 torch.Size([16, 45]),
 torch.Size([16, 45]),
 tensor([0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1]))

### 加载bert中文模型

In [19]:
# 加载预训练模型
pretrained = BertModel.from_pretrained('bert-base-chinese').to(device)

# 不训练, 不需要计算梯度
for param in pretrained.parameters():
    param.requires_grad = False

# 模型试算
out = pretrained(input_ids = input_ids,
                attention_mask = attention_mask,
                token_type_ids = token_type_ids)

out.last_hidden_state.shape

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


torch.Size([16, 45, 768])

### 定义下游任务模型

In [20]:
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                            attention_mask=attention_mask,
                            token_type_ids=token_type_ids)
        out = self.fc(out.last_hidden_state[:, 0])

        out = out.softmax(dim=1)

        return out

model = Model().to(device)

model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids).shape

torch.Size([16, 2])

### 训练下游任务模型

In [21]:
# 训练
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss().to(device)

model.train()
for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader):
    out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if(i % 5 == 0):
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)

        print(i, loss.item(), accuracy)

    if(i == 100):
        break



0 0.716956615447998 0.4375
5 0.6302874088287354 0.75
10 0.5722857713699341 0.6875
15 0.5094899535179138 0.8125
20 0.5728960037231445 0.625
25 0.4202996790409088 1.0
30 0.4266606271266937 0.9375
35 0.4131791293621063 0.9375
40 0.5453267097473145 0.75
45 0.35341116786003113 1.0
50 0.4447466731071472 0.875
55 0.4176824986934662 0.875
60 0.40036770701408386 0.9375
65 0.42593082785606384 0.9375
70 0.49632930755615234 0.8125
75 0.46445268392562866 0.875
80 0.42487233877182007 0.875
85 0.38113802671432495 0.9375
90 0.522413969039917 0.8125
95 0.4350370168685913 0.9375
100 0.46008482575416565 0.8125


### 测试

In [22]:
def test():
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset = Dataset('validation'),
                                            batch_size = 32,
                                            collate_fn=collate_fn,
                                            shuffle = True,
                                            drop_last = True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if(i == 5):
            break

        print(i)

        with torch.no_grad():
            out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    return correct / total

test()


Found cached dataset chn_senti_corp (C:/Users/BeatsLeo/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)


  0%|          | 0/2 [00:00<?, ?ba/s]

0
1
2
3
4


0.84375

### 使用

In [None]:
sentence = ['不错，很有意思']

data = token.batch_encode_plus(batch_text_or_text_pairs=sentence, 
                                truncation=True,
                                padding='max_length',
                                max_length=45,
                                return_tensors='pt',
                                return_length=True)

input_ids = data['input_ids']
attention_mask = data['attention_mask']
token_type_ids = data['token_type_ids']

out = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
res = out.argmax(dim=1)

out, '是一对' if res.item() == 0 else '不是一对'