In [None]:
import torch
from transformers import AutoModel, AutoTokenizer
from datasets import load_dataset, load_from_disk

In [None]:
device = 'cuda' if(torch.cuda.is_available()) else 'cpu'
print(device)

### 加载编码工具

In [None]:
# 加载分词器
tokenizer = AutoTokenizer.from_pretrained('hfl/rbt6')

print(tokenizer)

# 分词测试
tokenizer.batch_encode_plus(
    [[
        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间'', 的', '海', '域', '。'
    ],
    [
        '这', '座', '依', '山'', 傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一', '流', '的', '设', '计', '师', '主', '持', '设', '计', '，', '整', '个', '建', '筑', '群', '精', '美', '而', '恢', '宏', '。'
    ]],
    truncation=True,        
    padding=True,
    return_tensors='pt',
    is_split_into_words=True,   # 告诉分词器已经被分好词了
)

### 定义数据集

In [None]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self, split):
        # names = ['0', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']   # labels[单词, 人名开始, 人名中间, 组织开始, 组织中间, 地名开始, 地名中间]

        # 在线加载数据集
        dataset = load_dataset(path='peoples_daily_ner', split=split)

        # 过滤掉太长的句子
        def f(data):
            return len(data['tokens']) <= 512 -2
        
        self.dataset = dataset.filter(f)
    
    def __len__(self):
        return len(self.dataset)
    
    def __getitem__(self, i):
        tokens = self.dataset[i]['tokens']
        labels = self.dataset[i]['ner_tags']
        
        return tokens, labels

In [None]:
dataset = Dataset('train')

tokens, labels = dataset[0]
len(dataset), tokens, labels

### 数据加载器

In [None]:
# 数据整理函数
def collate_fn(data):
    tokens = [i[0] for i in data]
    print(tokens)
    labels = [i[1] for i in data]

    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         padding=True,
                                         return_tensors='pt',
                                         is_split_into_words=True)
    lens = inputs['input_ids'].shape[1]

    # 由于每个label长度不等, 将其填充到最长长度(lens)
    for i in range(len(labels)):
        labels[i] = [7] + labels[i]
        labels[i] += [7] * lens
        labels[i] = labels[i][:lens]
    
    return inputs, torch.LongTensor(labels)


# 数据加载器
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=4,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

In [None]:
# 查看数据样例
for i, (inputs, labels) in enumerate(loader):
    break

print(len(loader))
print(tokenizer.decode(inputs['input_ids'][0]))
print(labels[0])

for k, v in inputs.items():
    print(k, v.shape)

### 加载预训练模型

In [None]:
# 加载预训练模型
pretrained = AutoModel.from_pretrained('hfl/rbt6')

# 统计参数量
print(sum(i.numel() for i in pretrained.parameters()) / 10000)

# 模型试算
#[b, lens] -> [b, lens, 768]
pretrained(**inputs).last_hidden_state.shape

### 定义下游模型

In [None]:
# 定义下游模型
class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.tuneing = False
        self.pretrained = None

        self.rnn = torch.nn.GRU(768, 768)
        self.fc = torch.nn.Linear(768, 8)

    def forward(self, inputs):
        if(self.tuneing):
            out = self.pretrained(**inputs).last_hidden_state
        else:
            with torch.no_grad():
                out = pretrained(**inputs).last_hidden_state

        out, _ = self.rnn(out)
        out = self.fc(out).softmax(dim=2)

        return out

    def fine_tuneing(self, tuneing):
        self.tuneing = tuneing
        if(tuneing):
            for i in pretrained.parameters():
                i.requires_grad = True
            
            pretrained.train()
            self.pretrained = pretrained
        else:
            for i in pretrained.parameters():
                i.requires_grad = False
            
            pretrained.eval()
            self.pretrained = None

In [None]:
model = Model()

model(inputs).shape

### 工具函数

In [None]:
# 对计算结果和label变形, 并移除pad
def reshape_and_remove_pad(outs, labels, attention_mask):
    # 变形, 便于计算loss
    # [b, lens, 8] -> [b*lens, 8]
    outs = outs.reshape(-1, 8)
    # [b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    # 忽略对pad的计算结果
    # [b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1
    outs = outs[select]
    labels = labels[select]

    return outs, labels

reshape_and_remove_pad(torch.randn(2,3,8), torch.ones(2,3), torch.ones(2,3))

In [None]:
# 获取正确数量和总数
def get_correct_and_total_count(labels, outs):
    # [b*lens, 8] -> [b*lens]
    outs = outs.argmax(dim=1)
    correct = (outs == labels).sum().item()
    total = len(labels)

    # 计算除了0以外元素的正确率, 因为0太多了, 包括的话, 正确率很容易虚高
    select = labels != 0
    outs = outs[select]
    labels = labels[select]
    correct_content = (outs == labels).sum().item()
    total_content = len(labels)

    return correct, total, correct_content, total_content

get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))

### 定义训练函数

In [None]:
# 训练
def train(model, pretrained, epoches):
    lr = 2e-5 if(model.tuneing) else 5e-4

    # 训练
    optimizer = torch.optim.AdamW(model.parameters(), lr = lr)
    criterion = torch.nn.CrossEntropyLoss().to(device)

    model.train()
    model = model.to(device)
    model.rnn = model.rnn.to(device)
    model.fc = model.fc.to(device)
    pretrained = pretrained.to(device)
    
    for epoch in range(epoches):
        for step, (inputs, labels) in enumerate(loader):
            # 模型计算
            # [b, lens] -> [b, lens, 8]
            for i in inputs:
                inputs[i] = inputs[i].to(device)
            labels = labels.to(device)
            outs = model(inputs)

            # 对outs和label变形, 并且移除pad
            # outs -> [b, lens, 8] -> [c, 8]
            # labels -> [b, lens] -> [c]
            outs, labels = reshape_and_remove_pad(outs, labels, inputs['attention_mask'])

            # 梯度下降
            loss = criterion(outs, labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if(step % 50 == 0):
                counts = get_correct_and_total_count(labels, outs)

                accuracy = counts[0] / counts[1]
                accuracy_content = counts[2] / counts[3]

                print(epoch, step, loss.item(), accuracy, accuracy_content)

    # torch.save(model, '命名体识别_中文.model')
    model = model.cpu()
    model.rnn = model.rnn.cpu()
    model.fc = model.fc.cpu()
    pretrained = pretrained.cpu()
    return model


In [None]:
# 模型总参数
model.fine_tuneing(False)
print(sum(i.numel() for i in model.parameters()) / 10000)
train(model, pretrained, 1)

model.fine_tuneing(True)
print(sum(i.numel() for i in model.parameters()) / 10000)
train(model, pretrained, 2)

### 测试

In [15]:
def test():
    model.eval()

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=128,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)
    correct = 0
    total = 0

    correct_content = 0
    total_content = 0

    for step, (inputs, labels) in enumerate(loader_test):
        if(step == 5):
            break
        print(step)

        with torch.no_grad():
            # [b, lens] -> [b, lens, 8] -> [b, lens]
            outs = model(inputs)
        
        # 对outs和label变形, 并且移除pad
        # outs -> [b, lens, 8] -> [c, 8]
        # labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels, inputs['attention_mask'])

        counts = get_correct_and_total_count(labels, outs)
        correct += counts[0]
        total += counts[1]
        correct_content += counts[2]
        total_content += counts[3]

    print(correct / total, correct_content / total_content)

test()

1
2
3
4
0.9886694829223565 0.953558926487748


### 使用

In [42]:
sentence1 = '在泸州的小狗就是一个大傻逼'
sentence2 = '在泸州的刘天一就是一个大傻逼'
names = ['0', '姓', '名', '组织开始', '组织中间', '地名开始', '地名中间', 'NONE']

inputs1 = tokenizer.encode_plus(sentence1, return_tensors='pt')
inputs2 = tokenizer.encode_plus(sentence2, return_tensors='pt')

out1 = model(inputs1)
out2 = model(inputs2)

print(sentence1, ': ')
for n, i in enumerate(out1.argmax(dim=2).tolist()[0][1:-1]):
    print(sentence1[n], ': ' , names[i])
print()
print(sentence2, ': ')
for n, i in enumerate(out2.argmax(dim=2).tolist()[0][1:-1]):
    print(sentence2[n], ': ' , names[i])

在泸州的小狗就是一个大傻逼 : 
在 :  0
泸 :  地名开始
州 :  地名中间
的 :  0
小 :  0
狗 :  0
就 :  0
是 :  0
一 :  0
个 :  0
大 :  0
傻 :  0
逼 :  0

在泸州的刘天一就是一个大傻逼 : 
在 :  0
泸 :  地名开始
州 :  地名中间
的 :  0
刘 :  姓
天 :  名
一 :  名
就 :  0
是 :  0
一 :  0
个 :  0
大 :  0
傻 :  0
逼 :  0
