In [None]:
from transformers import (
    BertTokenizerFast,
    AutoModelForTokenClassification,
    AutoModelForMaskedLM,
    AutoModel
)
from transformers import pipeline

# Fast BERT tokenizer for the 'bert-base-chinese' checkpoint; later cells use
# this module-level `tokenizer`.
tokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese')
# Load the same checkpoint through three different Auto* classes and print each
# module tree, separated by a dashed rule. `model` is rebound on every load, so
# after this cell it refers to the plain AutoModel instance.
model = AutoModelForMaskedLM.from_pretrained('bert-base-chinese')
print(model)
print("-" * 72)
model = AutoModelForTokenClassification.from_pretrained('bert-base-chinese')
print(model)
print("-" * 72)
model = AutoModel.from_pretrained('bert-base-chinese')
print(model)

In [None]:
# Tokenizer demo: encode two sentences that are already split into single
# characters (is_split_into_words=True), pad them to a common length and return
# PyTorch tensors. The value of this expression is the cell output.
tokenizer.batch_encode_plus(
    [[
        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间',
        '的', '海', '域', '。'
    ],
     [
         '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一',
         '流', '的', '设', '计', '师', '主', '持', '设', '计', '，', '整', '个', '建', '筑',
         '群', '精', '美', '而', '恢', '宏', '。'
     ]],
    truncation=True,
    padding=True,
    return_tensors='pt',
    is_split_into_words=True)

In [None]:
import torch
from datasets import load_dataset, load_from_disk


class Dataset(torch.utils.data.Dataset):
    """NER dataset yielding ``(tokens, ner_tags)`` pairs for one split.

    The data is read from disk with ``load_from_disk`` and sentences that are
    too long for the encoder are filtered out once, at construction time.

    Args:
        split: Name of the split to load (the notebook uses ``'train'`` and
            ``'validation'``).
        dataset_path: Directory passed to ``load_from_disk``. Defaults to the
            notebook's original location.
        max_length: Maximum encoded sequence length. Sentences with more than
            ``max_length - 2`` tokens are dropped (two positions are held back,
            presumably for the [CLS]/[SEP] special tokens).
    """

    def __init__(self, split, dataset_path='./data/ner/', max_length=512):
        # Tag names noted by the original author:
        # names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']

        # Online alternative to the offline load below:
        # dataset = load_dataset(path='peoples_daily_ner', split=split)

        # Offline load from a local copy of the dataset.
        dataset = load_from_disk(dataset_path=dataset_path)[split]

        token_budget = max_length - 2

        # Drop sentences that are too long.
        def fits_in_budget(data):
            return len(data['tokens']) <= token_budget

        self.dataset = dataset.filter(fits_in_budget)

    def __len__(self):
        """Return the number of sentences kept after filtering."""
        return len(self.dataset)

    def __getitem__(self, i):
        """Return ``(tokens, ner_tags)`` for sentence ``i``."""
        # Fetch the row once; indexing the underlying dataset twice would
        # materialise the same record twice for every sample.
        row = self.dataset[i]

        return row['tokens'], row['ner_tags']


# Build the training split; later cells use this module-level `dataset`.
dataset = Dataset('train')

# Peek at the first sample.
tokens, labels = dataset[0]

# Cell output: dataset size plus the first sentence's tokens and tag ids.
len(dataset), tokens, labels

In [None]:
def collate_fn(data):
    """Collate ``(tokens, labels)`` samples into a padded batch.

    Args:
        data: List of ``(tokens, labels)`` pairs as produced by ``Dataset``,
            where ``tokens`` is a list of strings and ``labels`` a list of
            integer tag ids of the same length.

    Returns:
        Tuple ``(inputs, labels)``: ``inputs`` is the tokenizer's batch
        encoding (PyTorch tensors, padded to the longest sequence in the
        batch) and ``labels`` is a ``LongTensor`` with one tag id per encoded
        position. Positions that do not belong to an input token (special
        tokens and padding) carry the extra label id 7.
    """
    # Label id given to every position that is not backed by an input token.
    filler_label = 7

    tokens = [sample[0] for sample in data]
    word_labels = [sample[1] for sample in data]

    inputs = tokenizer.batch_encode_plus(tokens,
                                         truncation=True,
                                         padding=True,
                                         return_tensors='pt',
                                         is_split_into_words=True)

    # Align through the tokenizer's word ids rather than by position: a token
    # that is split into several word-pieces (or dropped entirely) would
    # otherwise shift every following label by one or more slots.
    labels = []
    for row, tags in enumerate(word_labels):
        aligned = []
        for word_id in inputs.word_ids(batch_index=row):
            if word_id is None or word_id >= len(tags):
                aligned.append(filler_label)
            else:
                aligned.append(tags[word_id])
        labels.append(aligned)

    return inputs, torch.LongTensor(labels)

In [None]:
# Training loader: shuffled batches of 16 built by collate_fn; the last
# incomplete batch is dropped.
loader = torch.utils.data.DataLoader(dataset=dataset,
                                     batch_size=16,
                                     collate_fn=collate_fn,
                                     shuffle=True,
                                     drop_last=True)

# Inspect a sample batch: number of batches, first decoded sentence, its labels.
for i, (inputs, labels) in enumerate(loader):
    print(len(loader))
    print(tokenizer.decode(inputs['input_ids'][0]))
    print(labels[0])
    break


# `inputs` still refers to the batch fetched above; print each tensor's shape.
# Later cells reuse this `inputs` for trial forward passes.
for k, v in inputs.items():
    print(k, v.shape)

In [None]:
# Encoder shared through the module-level name `pretrained`; the Model class
# defined further down reads this global directly.
pretrained = AutoModel.from_pretrained('bert-base-chinese')

# Count the parameters (printed in units of 10,000).
print(sum(i.numel() for i in pretrained.parameters()) / 10000)

# Trial forward pass on the sample batch.
# [b, lens] -> [b, lens, 768]
pretrained(**inputs).last_hidden_state.shape

In [None]:
class Model(torch.nn.Module):
    """NER head on top of the module-level ``pretrained`` encoder.

    Pipeline: encoder hidden states -> GRU -> linear layer -> softmax over
    8 classes. The encoder is only attached to this module (as
    ``self.pretrained``) while fine-tuning is enabled; otherwise the global
    ``pretrained`` is called under ``torch.no_grad()``.
    """

    def __init__(self):
        super().__init__()
        # Fine-tuning is off until fine_tuneing(True) is called.
        self.tuneing = False
        self.pretrained = None

        self.rnn = torch.nn.GRU(input_size=768,
                                hidden_size=768,
                                batch_first=True)
        self.fc = torch.nn.Linear(in_features=768, out_features=8)

    def forward(self, inputs):
        """Return per-position class probabilities.

        Args:
            inputs: Mapping of keyword arguments for the encoder (the
                tokenizer's batch encoding in this notebook).

        Returns:
            Tensor of shape [b, lens, 8], softmax-normalised over the last
            dimension.
        """
        if not self.tuneing:
            # Frozen encoder: use the global module and skip autograd for it.
            with torch.no_grad():
                hidden = pretrained(**inputs).last_hidden_state
        else:
            hidden = self.pretrained(**inputs).last_hidden_state

        hidden, _ = self.rnn(hidden)
        scores = self.fc(hidden)

        return torch.softmax(scores, dim=2)

    def fine_tuneing(self, tuneing):
        """Switch encoder fine-tuning on or off.

        When enabled, the encoder's parameters require gradients, it is put in
        train mode and attached as ``self.pretrained``. When disabled, its
        parameters are frozen, it is put in eval mode and detached again.
        """
        self.tuneing = tuneing
        trainable = True if tuneing else False

        for param in pretrained.parameters():
            param.requires_grad_(trainable)

        if trainable:
            pretrained.train()
            self.pretrained = pretrained
        else:
            pretrained.eval()
            self.pretrained = None


# Rebind the module-level `model` to the NER model defined above.
model = Model()

# Trial forward pass on the sample batch; the cell output is the result's shape.
model(inputs).shape

In [None]:
def reshape_and_remove_pad(outs, labels, attention_mask):
    """Flatten a batch and drop the padded positions.

    Args:
        outs: Model outputs of shape [b, lens, num_classes].
        labels: Label ids of shape [b, lens].
        attention_mask: Mask of shape [b, lens]; positions equal to 1 are kept.

    Returns:
        Tuple ``(outs, labels)`` with shapes [c, num_classes] and [c], where
        ``c`` is the number of positions whose mask value is 1.
    """
    # Flatten so the loss can be computed per position.
    # [b, lens, num_classes] -> [b*lens, num_classes]
    # The class count is read from the tensor instead of being fixed at 8, so
    # the helper keeps working if the classifier's output size changes.
    outs = outs.reshape(-1, outs.shape[-1])
    # [b, lens] -> [b*lens]
    labels = labels.reshape(-1)

    # Ignore the results computed for padding.
    # [b, lens] -> [b*lens - pad]
    select = attention_mask.reshape(-1) == 1

    return outs[select], labels[select]


# Smoke test on random data: 2 sequences of 3 positions with 8 classes and an
# all-ones mask, so no position is removed.
reshape_and_remove_pad(
    torch.randn(2, 3, 8),
    torch.ones(2, 3),
    torch.ones(2, 3)
)

In [None]:
def get_correct_and_total_count(labels, outs):
    """Count correct predictions overall and on non-zero labels.

    Args:
        labels: Label ids of shape [n].
        outs: Per-class scores of shape [n, num_classes].

    Returns:
        Tuple ``(correct, total, correct_content, total_content)``. The first
        pair covers every position; the second pair covers only positions
        whose label is not 0.
    """
    # [n, num_classes] -> [n]
    predicted = outs.argmax(dim=1)
    hits = predicted == labels

    # Label 0 is so frequent that including it easily inflates the accuracy,
    # so positions with a non-zero label are also counted separately.
    is_content = labels != 0

    correct = hits.sum().item()
    total = len(labels)
    correct_content = hits[is_content].sum().item()
    total_content = int(is_content.sum().item())

    return correct, total, correct_content, total_content


# Smoke test: 16 labels that are all 1 against random scores over 8 classes.
get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))

In [None]:
from transformers import AdamW


# Training
def train(epochs):
    """Train the module-level ``model`` on the module-level ``loader``.

    Each epoch runs at most 101 steps (it stops after step 100), prints
    progress every 50 steps and saves the whole model object to
    ``./logs/checkpoint/ner.bin`` at the end of the epoch.

    Args:
        epochs: Number of epochs to run.
    """
    import os

    # Smaller learning rate when the encoder itself is being fine-tuned.
    lr = 2e-5 if model.tuneing else 5e-4

    optimizer = AdamW(model.parameters(), lr=lr)
    # Model.forward already returns softmax probabilities. CrossEntropyLoss
    # applies log-softmax to its input itself, so feeding it probabilities
    # normalises twice and flattens the gradients. The matching loss for
    # probabilities is NLLLoss over their logarithm.
    criterion = torch.nn.NLLLoss()

    checkpoint_path = './logs/checkpoint/ner.bin'
    # torch.save does not create missing directories.
    os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

    model.train()
    for epoch in range(epochs):
        for step, (inputs, labels) in enumerate(loader):
            # Forward pass.
            # [b, lens] -> [b, lens, 8]
            outs = model(inputs)

            # Flatten outs and labels and remove the padded positions.
            # outs -> [b, lens, 8] -> [c, 8]
            # labels -> [b, lens] -> [c]
            outs, labels = reshape_and_remove_pad(outs, labels,
                                                  inputs['attention_mask'])

            # Gradient step. The clamp keeps log() finite if a probability
            # underflows to exactly zero.
            loss = criterion(outs.clamp_min(1e-12).log(), labels)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

            if step % 50 == 0:
                counts = get_correct_and_total_count(labels, outs)

                accuracy = counts[0] / counts[1]
                accuracy_content = counts[2] / counts[3]

                print(epoch, step, loss.item(), accuracy, accuracy_content)

            # Cap the epoch so a run stays short.
            if step == 100:
                break

        torch.save(model, checkpoint_path)

In [None]:
# Train with the encoder frozen: print the model's parameter count (in units
# of 10,000) and run one epoch.
model.fine_tuneing(False)
print(sum(p.numel() for p in model.parameters()) / 10000)
# Disabled alternative: enable encoder fine-tuning before counting and training.
# model.fine_tuneing(True)
# print(sum(p.numel() for p in model.parameters()) / 10000)
train(1)

In [None]:
def test():
    """Evaluate the saved checkpoint on five validation batches.

    Loads the model object from ``./logs/checkpoint/ner.bin``, scores the first
    five shuffled batches (128 sentences each) of the ``'validation'`` split
    and prints the overall accuracy followed by the accuracy on non-zero
    labels.
    """
    model_load = torch.load('./logs/checkpoint/ner.bin')
    model_load.eval()

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=128,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    # Running sums: [correct, total, correct_content, total_content]
    running = [0, 0, 0, 0]

    for step, (inputs, labels) in enumerate(loader_test):
        if step == 5:
            break
        print(step)

        with torch.no_grad():
            # [b, lens] -> [b, lens, 8]
            outs = model_load(inputs)

        # Flatten outs and labels and remove the padded positions.
        # outs -> [b, lens, 8] -> [c, 8]
        # labels -> [b, lens] -> [c]
        outs, labels = reshape_and_remove_pad(outs, labels,
                                              inputs['attention_mask'])

        batch_counts = get_correct_and_total_count(labels, outs)
        running = [acc + cur for acc, cur in zip(running, batch_counts)]

    correct, total, correct_content, total_content = running
    print(correct / total, correct_content / total_content)


# Run the evaluation; it reads the checkpoint file written by train().
test()

In [None]:
def predict():
    """Print predictions next to the gold tags for one validation batch.

    Loads the model object from ``./logs/checkpoint/ner.bin``, takes one
    shuffled batch of 8 validation sentences and, for each sentence, prints
    the decoded text, then the gold tag line, then the predicted tag line.
    In a tag line a position tagged 0 is shown as '·'; any other position is
    shown as its decoded token followed by the tag id.
    """
    batch_size = 8

    model_load = torch.load('./logs/checkpoint/ner.bin')
    model_load.eval()

    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),
                                              batch_size=batch_size,
                                              collate_fn=collate_fn,
                                              shuffle=True,
                                              drop_last=True)

    # Only the first batch is needed.
    for _, (inputs, labels) in enumerate(loader_test):
        break

    with torch.no_grad():
        # [b, lens] -> [b, lens, 8] -> [b, lens]
        outs = model_load(inputs).argmax(dim=2)

    for row in range(batch_size):
        # Remove the padded positions.
        keep = inputs['attention_mask'][row] == 1
        input_id = inputs['input_ids'][row, keep]
        predicted = outs[row, keep]
        gold = labels[row, keep]

        # Print the original sentence.
        print(tokenizer.decode(input_id).replace(' ', ''))

        # Print the tags: gold first, prediction second.
        for tags in (gold, predicted):
            pieces = []
            for token_id, tag in zip(input_id, tags):
                if tag == 0:
                    pieces.append('·')
                else:
                    pieces.append(tokenizer.decode(token_id) + str(tag.item()))

            print(''.join(pieces))
        print('==========================')


# Show sample predictions; it reads the checkpoint file written by train().
predict()