In [None]:
import torch
import random
from transformers import AdamW
from datasets import load_dataset
from transformers import BertModel
from transformers import BertTokenizer

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### 定义数据集

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Sentence-pair dataset derived from the ``seamew/ChnSentiCorp`` texts.

    Each item is a ``(sentence1, sentence2, label)`` tuple:

    * ``sentence1`` is the first 20 characters of a text.
    * ``sentence2`` is either characters 20-40 of the same text
      (``label == 0``) or characters 20-40 of a *different*, randomly chosen
      text (``label == 1``).

    Items are sampled with the global ``random`` module, so the same index can
    yield a different pair on every access.
    """

    def __init__(self, split):
        """Load ``split`` and keep only texts longer than 40 characters.

        Args:
            split: Name of the dataset split passed to ``load_dataset``
                (the notebook uses ``'train'`` and ``'validation'``).
        """
        super().__init__()
        dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)

        def f(data):
            # Texts must be longer than 40 characters so both 20-character
            # halves are fully populated.
            return len(data['text']) > 40

        self.dataset = dataset.filter(f)

    def __len__(self):
        """Return the number of texts that survived the length filter."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Build one sentence pair for the text at position ``idx``.

        Returns:
            ``(sentence1, sentence2, label)`` where ``label`` is 0 when
            ``sentence2`` continues ``sentence1`` and 1 when it was taken from
            another text.
        """
        text = self.dataset[idx]['text']

        # Split the text into a first half and a second half.
        sentence1 = text[:20]
        sentence2 = text[20:40]
        label = 0

        size = len(self.dataset)

        # With probability 1/2 replace the second half with one taken from a
        # different text. A negative pair needs at least two texts.
        if size > 1 and random.randint(0, 1) == 0:
            own = idx % size  # normalise negative indices
            # Draw uniformly from the other size - 1 positions: sample from
            # [0, size - 2] and skip over our own slot. Without this, j could
            # equal idx and a true continuation would be labelled 1.
            j = random.randrange(size - 1)
            if j >= own:
                j += 1
            sentence2 = self.dataset[j]['text'][20:40]
            label = 1

        return sentence1, sentence2, label

In [None]:
# Build the training split and draw one example from it as a smoke test.
dataset = Dataset(split='train')

sentence1, sentence2, label = dataset[0]

# Dataset size followed by the sampled pair and its label.
(len(dataset), sentence1, sentence2, label)

### 加载tokenizer

In [None]:
# Load the vocabulary and tokenizer of the Chinese BERT checkpoint.
checkpoint_name = 'bert-base-chinese'
token = BertTokenizer.from_pretrained(checkpoint_name)

token

### 定义批处理函数

In [None]:
def collate_fn(data):
    """Collate ``(sentence1, sentence2, label)`` samples into batch tensors.

    Args:
        data: List of samples as produced by the dataset; the first two
            entries of each sample are the sentence pair, the third is the
            label.

    Returns:
        ``(input_ids, attention_mask, token_type_ids, labels)``, all moved to
        ``device``.
    """
    pairs = [sample[:2] for sample in data]
    targets = [sample[2] for sample in data]

    # Encode every pair, truncating / padding each one to exactly 45 tokens.
    encoded = token.batch_encode_plus(
        batch_text_or_text_pairs=pairs,
        truncation=True,
        padding='max_length',
        max_length=45,
        return_tensors='pt',
        return_length=True,
    )

    # input_ids: the token ids produced by the encoding.
    # attention_mask: 0 at padded positions and 1 elsewhere, so padding takes
    # no part in the attention computation.
    features = tuple(
        encoded[key].to(device)
        for key in ('input_ids', 'attention_mask', 'token_type_ids')
    )
    label_tensor = torch.LongTensor(targets).to(device)

    return (*features, label_tensor)

### 定义数据加载器

In [None]:
# Shuffled training loader; the last incomplete batch is dropped so every
# batch holds exactly 16 pairs.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=16,
    collate_fn=collate_fn,
    shuffle=True,
    drop_last=True,
)

# Pull out the first batch only; it is reused below for trial forward passes.
for i, batch in enumerate(loader):
    input_ids, attention_mask, token_type_ids, labels = batch
    break

print(len(loader))
print(token.decode(input_ids[0]))
(input_ids.shape, attention_mask.shape, token_type_ids.shape, labels)

### 加载bert中文模型

In [None]:
# Load the pretrained Chinese BERT encoder onto the selected device.
pretrained = BertModel.from_pretrained('bert-base-chinese')
pretrained = pretrained.to(device)

# The encoder is not trained here, so none of its weights need gradients.
for param in pretrained.parameters():
    param.requires_grad_(False)

# Trial forward pass on the sample batch to inspect the output shape.
out = pretrained(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
)

out.last_hidden_state.shape

### 定义下游任务模型

In [None]:
class Model(torch.nn.Module):
    """Two-class linear head on top of the module-level ``pretrained`` encoder.

    The encoder is looked up as a global and is not registered as a submodule,
    so ``Model.parameters()`` contains only the linear layer.
    """

    def __init__(self):
        super().__init__()
        # Maps the 768-dimensional first-token hidden state to 2 class scores.
        self.fc = torch.nn.Linear(768, 2)

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return unnormalised class scores (logits) of shape ``(batch, 2)``.

        Raw logits are returned on purpose: ``torch.nn.CrossEntropyLoss``
        applies log-softmax itself, so a softmax here would be applied twice
        and flatten the gradients. ``argmax`` over the logits gives the same
        predictions as over probabilities; call ``.softmax(dim=1)`` on the
        result when probabilities are needed.
        """
        # The encoder only extracts features; no gradient flows into it.
        with torch.no_grad():
            out = pretrained(input_ids=input_ids,
                             attention_mask=attention_mask,
                             token_type_ids=token_type_ids)

        # Classify from the hidden state at the first token position.
        return self.fc(out.last_hidden_state[:, 0])

# Instantiate the downstream model and place it on the selected device.
model = Model()
model = model.to(device)

# Trial forward pass on the sample batch to check the output shape.
model(
    input_ids=input_ids,
    attention_mask=attention_mask,
    token_type_ids=token_type_ids,
).shape

### 训练下游任务模型

In [None]:
# Training: optimise the parameters exposed by `model` with AdamW and a
# cross-entropy objective.
optimizer = AdamW(model.parameters(), lr=5e-4)
criterion = torch.nn.CrossEntropyLoss()
criterion = criterion.to(device)

model.train()
for i, batch in enumerate(loader):
    input_ids, attention_mask, token_type_ids, labels = batch
    out = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )

    loss = criterion(out, labels)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Report loss and batch accuracy every 5 steps.
    if i % 5 == 0:
        out = out.argmax(dim=1)
        accuracy = (out == labels).sum().item() / len(labels)

        print(i, loss.item(), accuracy)

    # Stop after step 100, i.e. 101 batches in total.
    if i == 100:
        break

### 测试

In [None]:
def test(max_batches=5):
    """Estimate accuracy on a random sample of the validation split.

    Args:
        max_batches: Maximum number of shuffled 32-pair batches to evaluate.
            Defaults to 5; pass ``None`` to evaluate every full batch.

    Returns:
        Fraction of correctly classified pairs among the evaluated batches,
        or ``0.0`` when no batch was evaluated (for example when the split
        holds fewer than 32 usable texts or ``max_batches`` is 0).
    """
    model.eval()
    correct = 0
    total = 0

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=32,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_test):
        if max_batches is not None and i >= max_batches:
            break

        print(i)

        # Inference only: no autograd bookkeeping needed.
        with torch.no_grad():
            out = model(input_ids=input_ids,
                        attention_mask=attention_mask,
                        token_type_ids=token_type_ids)

        out = out.argmax(dim=1)
        correct += (out == labels).sum().item()
        total += len(labels)

    # Guard against dividing by zero when nothing was evaluated.
    if total == 0:
        return 0.0

    return correct / total

test()


### 使用

In [None]:
# Inference demo. The model was trained on sentence PAIRS (label 0 = the
# second piece continues the first, label 1 = it does not), so the input must
# be a (first, second) pair; a single sentence gives a meaningless verdict.
# The pair below is one review split into two 20-character pieces, which
# mirrors how the training data is built.
sentence = [('这家酒店的位置非常好，离地铁站很近，房间',
             '也很干净，服务员态度热情，下次还会再来住')]

# Encode exactly as in collate_fn so the input matches the training format.
data = token.batch_encode_plus(batch_text_or_text_pairs=sentence,
                               truncation=True,
                               padding='max_length',
                               max_length=45,
                               return_tensors='pt',
                               return_length=True)

# The tensors must live on the same device as the model; leaving them on the
# CPU fails with a device mismatch when the model is on CUDA.
input_ids = data['input_ids'].to(device)
attention_mask = data['attention_mask'].to(device)
token_type_ids = data['token_type_ids'].to(device)

model.eval()
with torch.no_grad():
    out = model(input_ids=input_ids,
                attention_mask=attention_mask,
                token_type_ids=token_type_ids)
res = out.argmax(dim=1)

out, '是一对' if res.item() == 0 else '不是一对'