In [1]:
!pip install transformers

Collecting transformers
  Downloading transformers-4.18.0-py3-none-any.whl (4.0 MB)
[K     |████████████████████████████████| 4.0 MB 31.3 MB/s 
Collecting tokenizers!=0.11.3,<0.13,>=0.11.1
  Downloading tokenizers-0.12.1-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.6 MB)
[K     |████████████████████████████████| 6.6 MB 65.5 MB/s 
Collecting pyyaml>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 596 kB 61.9 MB/s 
Collecting sacremoses
  Downloading sacremoses-0.0.49-py3-none-any.whl (895 kB)
[K     |████████████████████████████████| 895 kB 56.4 MB/s 
[?25hCollecting huggingface-hub<1.0,>=0.1.0
  Downloading huggingface_hub-0.5.1-py3-none-any.whl (77 kB)
[K     |████████████████████████████████| 77 kB 4.8 MB/s 
Installing collected packages: pyyaml, tokenizers, sacremoses, huggingface-hub, transformers
  Attempting uninstall: pyyaml
    F

In [2]:
import json
import os
from collections import Counter

import numpy as np

from transformers import BertModel, BertTokenizer,BertTokenizerFast
import torch
import torch
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm


In [3]:
"""
Date: 2021-06-01 17:18:25
LastEditors: GodK
"""
import time

common = {
    "exp_name": "cluener",
    "encoder": "BERT",
    "data_home": "data",
    "bert_path": "./pretrained_models/bert-base-chinese",  # bert-base-cased， bert-base-chinese
    "run_type": "train",  # train,eval
    "f1_2_save": 0.5,  # 存模型的最低f1值
    "logger": "default"  # wandb or default，default意味着只输出日志到控制台
}

# wandb的配置，只有在logger=wandb时生效。用于可视化训练过程
wandb_config = {
    "run_name": time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime()),
    "log_interval": 10
}

train_config = {
    "train_data": "train.json",
    "valid_data": "dev.json",
    "ent2id": "ent2id.json",
    "path_to_save_model": "./outputs",  # 在logger不是wandb时生效
    "hyper_parameters": {
        "lr": 5e-5,
        "batch_size": 1,
        "epochs": 50,
        "seed": 2333,
        "max_seq_len": 128,
        "scheduler": "CAWR"
    }
}

eval_config = {
    "model_state_dir": "./outputs/cluener/",  # 预测时注意填写模型路径（时间tag文件夹）
    "run_id": "",
    "last_k_model": 1,  # 取倒数第几个model_state
    "test_data": "test.json",
    "ent2id": "ent2id.json",
    "save_res_dir": "./results",
    "hyper_parameters": {
        "batch_size": 16,
        "max_seq_len": 512,
    }

}

cawr_scheduler = {
    # CosineAnnealingWarmRestarts
    "T_mult": 1,
    "rewarm_epoch_num": 2,
}
step_scheduler = {
    # StepLR
    "decay_rate": 0.999,
    "decay_steps": 100,
}

# ---------------------------------------------
train_config["hyper_parameters"].update(**cawr_scheduler, **step_scheduler)
train_config = {**train_config, **common, **wandb_config}
eval_config = {**eval_config, **common}


In [5]:
model_name='bert-base-chinese'
seq_length=128
tokenizer = BertTokenizerFast.from_pretrained(model_name)

Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/107k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/263k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/624 [00:00<?, ?B/s]

In [6]:
def load_data(data_path, data_type="train"):
    """读取数据集

    Args:
        data_path (str): 数据存放路径
        data_type (str, optional): 数据类型. Defaults to "train".

    Returns:
        (json): train和valid中一条数据格式：{"text":"","entity_list":[(start, end, label), (start, end, label)...]}
    """
    if data_type == "train" or data_type == "valid":
        # datas = []
        # with open(data_path, encoding="utf-8") as f:
        #     for line in f:
        #         line = json.loads(line)
        #         item = {}
        #         item["text"] = line["text"]
        #         item["entity_list"] = []
        #         for k, v in line['label'].items():
        #             for spans in v.values():
        #                 for start, end in spans:
        #                     item["entity_list"].append((start, end, k))
        #
        #         datas.append(item)
        datas = []
        with open(data_path, encoding="utf-8") as f:
            for line in f.readlines():
                line = json.loads(line)
                item = {}
                item["text"] = line["text"]
                item["entity_list"] = []
                for k, v in line['label'].items():
                    for k_, span in v.items():
                        for spans in span:
                            item["entity_list"].append((spans[0], spans[1], k))

                datas.append(item)
        return datas
    else:
        return json.load(open(data_path, encoding="utf-8"))

def find_head_idx(source, target):
    target_len = len(target)
    for i in range(len(source)):
        if source[i: i + target_len] == target:
            return i,i + target_len-1
    return -1,-1


entity2id={
    "O":0,
"B-address": 1, "I-address": 2,
"B-book": 3,"I-book": 4,
"B-company": 5,"I-company": 6,
"B-game": 7, "I-game":8,
"B-government": 9, "I-government": 10,
    "B-movie": 11, "I-movie": 12,
    "B-name": 13, "I-name": 14,
    "B-organization": 15,"I-organization": 16,
    "B-position": 17, "I-position": 18,
    "B-scene": 19,"I-scene": 20
}

BERT_MAX_LEN = 128

## data generator

In [31]:
class Preprocessor(object):
    def __init__(self, tokenizer):
        super(Preprocessor, self).__init__()
        self.tokenizer = tokenizer
        self.add_special_tokens = True

    def get_ent2token_spans(self, text, entity_list):
        """实体列表转为token_spans

        Args:
            text (str): 原始文本
            entity_list (list): [(start, end, ent_type),(start, end, ent_type)...]
        """
        ent2token_spans = []

        inputs = self.tokenizer(text, add_special_tokens=True, return_offsets_mapping=True)
        token2char_span_mapping = inputs["offset_mapping"]
        text2tokens = self.tokenizer.tokenize(text, add_special_tokens=self.add_special_tokens)

        for ent_span in entity_list:
            ent = text[ent_span[0]:ent_span[1] + 1]
            ent2token = self.tokenizer.tokenize(ent, add_special_tokens=False)

            # 寻找ent的token_span
            token_start_indexs = [i for i, v in enumerate(text2tokens) if v == ent2token[0]]

            token_end_indexs = [i for i, v in enumerate(text2tokens) if v == ent2token[-1]]

            token_start_index = list(filter(lambda x: token2char_span_mapping[x][0] == ent_span[0], token_start_indexs))
            token_end_index = list(filter(lambda x: token2char_span_mapping[x][-1] - 1 == ent_span[1], token_end_indexs))  # token2char_span_mapping[x][-1]-1 减1是因为原始的char_span是闭区间，而token2char_span是开区间

            if len(token_start_index) == 0 or len(token_end_index) == 0:
                # print(f'[{ent}] 无法对应到 [{text}] 的token_span，已丢弃')
                continue
            token_span = (token_start_index[0], token_end_index[0], ent_span[2])
            ent2token_spans.append(token_span)

        return ent2token_spans

class data_generator:
    def __init__(self, data, tokenizer, maxlen):
        self.data = data
        self.batch_size = len(self.data)
        self.tokenizer = tokenizer
        self.maxlen = maxlen
        self.preprocessor = Preprocessor(tokenizer)

    def __len__(self):
        return self.batch_size

    def generator(self):
        # while True:
        # idxs = list(range(len(self.data)))
        sent_lengths = []
        tokens_batch, segments_batch, token_type_batch, labels_ids = [], [], [], []
        start_indexs = []
        end_indexs = []
        # tokens_batch--------tokens ids;segments_batch-------attention mask; token_type_ids--------；labels_ids------ 每个实体的label标签
        for sample in self.data:

            text_len = BERT_MAX_LEN
            inputs = self.tokenizer(sample['text'], return_tensors='pt', add_special_tokens=True, truncation=True, padding=True,
                                    max_length=BERT_MAX_LEN)
            token_ids, segment_ids, token_type_ids = inputs['input_ids'], inputs['attention_mask'], inputs[
                'token_type_ids']

            sent_lengths.append(token_ids.shape[1])
            pad_len = BERT_MAX_LEN - token_ids.shape[1]
            pad_seq = torch.zeros(1, pad_len)
            token_ids = torch.cat((token_ids, pad_seq), dim=1)
            segment_ids = torch.cat((segment_ids, pad_seq), dim=1)
            token_type_ids = torch.cat((token_type_ids, pad_seq), dim=1)
            #句子的label标签
            labels=torch.zeros(self.maxlen)
            ent2token_spans = self.preprocessor.get_ent2token_spans(
                sample["text"], sample["entity_list"]
            )
            for start, end, label in ent2token_spans:
                labels[start]=entity2id['B-'+label]
                if start+1!=end+1:
                    labels[start+1:end+1]=entity2id['I-'+label]

            labels_ids.append(labels)
            token_type_batch.append(token_type_ids)
            tokens_batch.append(token_ids)
            segments_batch.append(segment_ids)

            # print(sub_heads,sub_tails,torch.where(obj_heads==1),torch.where(obj_tails)==1)
        return tokens_batch, segments_batch, token_type_batch, labels_ids,sent_lengths

In [33]:
if __name__ == '__main__':
    train_path=os.path.join(common["data_home"], common["exp_name"], train_config["train_data"])
    train_data = load_data(train_path)
    print('train_data_length',len(train_data))
    train_config=train_config
    maxlen = train_config["hyper_parameters"]['max_seq_len']
    train_tokens_batch, train_segments_batch, train_token_type_batch, train_labels_ids,sent_lengths = \
        data_generator(train_data,tokenizer,maxlen).generator()

    vaild_path=os.path.join(common["data_home"], common["exp_name"], train_config["valid_data"])
    vaild_data=load_data(vaild_path)[:-500]
    print('valid_data_length',len(vaild_data))
    test_data=load_data(vaild_path)[-500:]
    vaild_tokens_batch, vaild_segments_batch, vaild_token_type_batch, valid_labels_ids,sent_lengths = \
        data_generator(vaild_data, tokenizer, maxlen).generator()

    test_tokens_batch, test_segments_batch, test_token_type_batch, test_labels_ids,test_sent_lengths = \
        data_generator(test_data, tokenizer, maxlen).generator()
sign=['train_','valid_','test_']
for di in sign:
    if di == 'train_':
        tokens_batch, segments_batch, token_type_batch, labels_ids= train_tokens_batch, train_segments_batch, train_token_type_batch, train_labels_ids
        tokens_batch = torch.cat([l for l in tokens_batch]).int()
        print(tokens_batch.shape)
        segments_batch = torch.cat([l for l in segments_batch]).int()  # sents_length*128
        # 命名实体识别

        token_type_batch = torch.cat([l for l in token_type_batch]).int()
        labels_ids_batch = torch.cat([l for l in labels_ids]).reshape(tokens_batch.shape[0], -1).int()
        train_dataset=TensorDataset(tokens_batch,segments_batch,token_type_batch,labels_ids_batch)

    elif di=='valid_':
        tokens_batch, segments_batch, token_type_batch, labels_ids =  vaild_tokens_batch, vaild_segments_batch, vaild_token_type_batch, valid_labels_ids
        tokens_batch = torch.cat([l for l in tokens_batch]).int()
        print(tokens_batch.shape)
        segments_batch = torch.cat([l for l in segments_batch]).int()  # sents_length*128
        # 命名实体识别

        token_type_batch = torch.cat([l for l in token_type_batch]).int()
        labels_ids_batch = torch.cat([l for l in labels_ids]).reshape(tokens_batch.shape[0], -1).int()
        valid_dataset=TensorDataset(tokens_batch,segments_batch,token_type_batch,labels_ids_batch)

    elif di=='test_':
        tokens_batch, segments_batch, token_type_batch, labels_ids,sent_lengths = test_tokens_batch, test_segments_batch, test_token_type_batch, test_labels_ids,test_sent_lengths
        tokens_batch = torch.cat([l for l in tokens_batch]).int()
        print(tokens_batch.shape)
        segments_batch = torch.cat([l for l in segments_batch]).int()  # sents_length*128
        # 命名实体识别
        sent_lengths_batch = torch.tensor(sent_lengths).reshape(tokens_batch.shape[0], -1).int()
        token_type_batch = torch.cat([l for l in token_type_batch]).int()
        labels_ids_batch = torch.cat([l for l in labels_ids]).reshape(tokens_batch.shape[0], -1).int()
        test_dataset = TensorDataset(tokens_batch, segments_batch, token_type_batch, labels_ids_batch,sent_lengths_batch)

print('train_dataset_length',len(train_dataset))
print('vaild_dataset_length',len(valid_dataset))
print('test_dataset_length',len(test_dataset))

train_data_length 10748
valid_data_length 843
torch.Size([10748, 128])
torch.Size([843, 128])
torch.Size([500, 128])
train_dataset_length 10748
vaild_dataset_length 843
test_dataset_length 500


In [13]:
!pip install pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


In [14]:
from transformers import AdamW
from transformers import BertModel
from torch.nn import functional as F
from torchcrf import CRF

In [15]:
import torch.nn as nn
class E2EModel(nn.Module):
    def __init__(self):
        super(E2EModel, self).__init__()
        self.tagset_size=len(entity2id.keys())
        self.hidden_dim=768
        self.encode=BertModel.from_pretrained(model_name)
        # self.lstm = nn.LSTM(embedding_dim, hidden_dim // 2,
        #                     num_layers=1, bidirectional=True)
        self.classifier1 = nn.Linear(768, 128)
        # Maps the output of the LSTM into tag space.
        self.classifier2 = nn.Linear(128, self.tagset_size)

        self.crf=CRF(self.tagset_size,batch_first=True)

    def forward(self,inputs_id,att_mask,token_type_ids,labels_ids):
        bo=self.encode(inputs_id,att_mask,token_type_ids)[0]#B*L*768
        # x=x.view(128, -1, 768)
        x=self.classifier1(bo)
        # x=x.view(128, hidden_dim)
        sub=self.classifier2(x)
        slogits=sub
        soutput = (slogits,)

        sloss = self.crf(emissions=slogits, tags=labels_ids.long(), mask=att_mask.byte())
        soutputs = (-1 * sloss,) + soutput
        return bo,soutputs  # (loss), scores

In [16]:
def same_seeds(seed):
    # torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    # np.random.seed(seed)
    # random.seed(seed)
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
same_seeds(0)

In [17]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
from torch.optim.lr_scheduler import LambdaLR
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """ Create a schedule with a learning rate that decreases linearly after
    linearly increasing during a warmup period.
    """
    def lr_lambda(current_step):
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        return max(0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


In [22]:
def train(eenet, train_dataset, val_set, num_epochs, learning_rate, batch_size):
    i = 1
    model_path = 'model'
    train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=2)
    print('执行次数为：{}'.format(i))
    t_total = len(train_loader) * num_epochs
    # warmup_steps = int(t_total * warmup_proportion)
    warmup_steps = 50
    optimizer = AdamW(eenet.parameters(), lr=learning_rate)
    scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps,
                                                num_training_steps=t_total)

    eenet = eenet.to(device)

    # tokens_batch[:rc_train_size], segments_batch[:rc_train_size], token_type_batch[:rc_train_size], labels_ids_batch[:rc_train_size], sub_batch_tags[:rc_train_size], obj_batch_tags[:rc_train_size]
    for epoch in range(num_epochs):
        eenet.train()
        train_loss = 0.
        step = 0
        for batch_idx, data in enumerate(tqdm(train_loader)):
            torch.autograd.set_detect_anomaly(True)

            inputs_id, att_mask, token_type_ids, labels_ids = data
            text, mask, token_type_ids, labels_ids = inputs_id.to(device), att_mask.to(device), token_type_ids.to(
                device), labels_ids.to(device)
            optimizer.zero_grad()
            # ee _model识别实体----------------------------------------------
            x, soutputs = eenet(text, mask, token_type_ids, labels_ids)
            s_loss, slogits = soutputs[:2]
            # 准备一个batch(batch_size=1)的训练数据
            spre = eenet.crf.decode(slogits, mask.byte())
            # ee_type model 识别实体类型-----------------------------------------------
            t_loss = s_loss
            t_loss.backward()
            # all_loss.backward()
            train_loss += t_loss.item()
            optimizer.step()
            scheduler.step()
            step += 1

            if step % 100 == 0:
                print('\n')
                print('*' * 10)
                print(
                    'train_epoch|{},average_loss={}'.format(epoch + 1,  train_loss /(batch_size*step)))

        torch.save(eenet.state_dict(), 'ee2' + model_path)

        eenet.eval()
        min_loss = 1
        with torch.no_grad():
            val_loss = 0.
            val_step = 0
            vaild_loader = DataLoader(val_set, batch_size=batch_size)
            for batch_idx, data in enumerate(tqdm(vaild_loader)):

                inputs_id, att_mask, token_type_ids, labels_ids = data
                text, mask, token_type_ids, labels_ids = inputs_id.to(device), att_mask.to(device), token_type_ids.to(
                    device), labels_ids.to(device)
                # ee _model识别实体----------------------------------------------
                x, soutputs = eenet(text, mask, token_type_ids, labels_ids)
                s_loss, slogits = soutputs[:2]
                # all_loss.backward()
                val_loss += s_loss.item()

                val_step += 1

                if val_step % 100 == 0:
                    print('\n')
                    print('-' * 10)
                    print(
                        'val_epoch|{},val_averge_loss={}'.format(epoch + 1, val_loss / (batch_size*val_step)))

            if val_loss < min_loss:
                min_loss = val_loss
                print('save model')
                torch.save(eenet.state_dict(), 'ee_best' + model_path)

eenet=E2EModel()
train(eenet,train_dataset, valid_dataset,20, 1e-5, 32)

Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


执行次数为：1


 30%|██▉       | 100/336 [02:48<06:35,  1.68s/it]



**********
train_epoch|1,average_loss=58.130461349487305


 60%|█████▉    | 200/336 [05:37<03:49,  1.68s/it]



**********
train_epoch|1,average_loss=39.80197557449341


 89%|████████▉ | 300/336 [08:25<01:00,  1.67s/it]



**********
train_epoch|1,average_loss=30.95521881421407


100%|██████████| 336/336 [09:26<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  4.21it/s]
 30%|██▉       | 100/336 [02:49<06:42,  1.70s/it]



**********
train_epoch|2,average_loss=9.69733693599701


 60%|█████▉    | 200/336 [05:38<03:49,  1.68s/it]



**********
train_epoch|2,average_loss=9.164098937511444


 89%|████████▉ | 300/336 [08:27<00:59,  1.66s/it]



**********
train_epoch|2,average_loss=8.583388728300731


100%|██████████| 336/336 [09:28<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  4.23it/s]
 30%|██▉       | 100/336 [02:49<06:36,  1.68s/it]



**********
train_epoch|3,average_loss=6.843464889526367


 60%|█████▉    | 200/336 [05:38<03:56,  1.74s/it]



**********
train_epoch|3,average_loss=6.687037572860718


 89%|████████▉ | 300/336 [08:26<01:00,  1.67s/it]



**********
train_epoch|3,average_loss=6.416750960350036


100%|██████████| 336/336 [09:26<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  4.20it/s]
 30%|██▉       | 100/336 [02:48<06:38,  1.69s/it]



**********
train_epoch|4,average_loss=5.572550456523896


 60%|█████▉    | 200/336 [05:37<03:48,  1.68s/it]



**********
train_epoch|4,average_loss=5.492569981813431


 89%|████████▉ | 300/336 [08:26<01:00,  1.67s/it]



**********
train_epoch|4,average_loss=5.306122080485026


100%|██████████| 336/336 [09:27<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  3.99it/s]
 30%|██▉       | 100/336 [02:49<06:50,  1.74s/it]



**********
train_epoch|5,average_loss=4.754248151779175


 60%|█████▉    | 200/336 [05:38<03:52,  1.71s/it]



**********
train_epoch|5,average_loss=4.682039009332657


 89%|████████▉ | 300/336 [08:28<01:00,  1.68s/it]



**********
train_epoch|5,average_loss=4.518847116629282


100%|██████████| 336/336 [09:29<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  3.99it/s]
 30%|██▉       | 100/336 [02:49<06:39,  1.69s/it]



**********
train_epoch|6,average_loss=4.207011156082153


 60%|█████▉    | 200/336 [05:39<03:50,  1.69s/it]



**********
train_epoch|6,average_loss=4.10266646027565


 89%|████████▉ | 300/336 [08:28<01:01,  1.71s/it]



**********
train_epoch|6,average_loss=3.9433957159519197


100%|██████████| 336/336 [09:30<00:00,  1.70s/it]
100%|██████████| 27/27 [00:06<00:00,  3.99it/s]
 30%|██▉       | 100/336 [02:49<06:37,  1.68s/it]



**********
train_epoch|7,average_loss=3.693643810749054


 60%|█████▉    | 200/336 [05:39<03:49,  1.69s/it]



**********
train_epoch|7,average_loss=3.5966575217247008


 89%|████████▉ | 300/336 [08:28<01:02,  1.74s/it]



**********
train_epoch|7,average_loss=3.4570091744263967


100%|██████████| 336/336 [09:29<00:00,  1.70s/it]
100%|██████████| 27/27 [00:06<00:00,  4.00it/s]
 30%|██▉       | 100/336 [02:49<06:38,  1.69s/it]



**********
train_epoch|8,average_loss=3.2743906712532045


 60%|█████▉    | 200/336 [05:38<03:54,  1.72s/it]



**********
train_epoch|8,average_loss=3.206157926917076


 89%|████████▉ | 300/336 [08:28<01:00,  1.68s/it]



**********
train_epoch|8,average_loss=3.061591362953186


100%|██████████| 336/336 [09:28<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  4.01it/s]
 30%|██▉       | 100/336 [02:49<06:36,  1.68s/it]



**********
train_epoch|9,average_loss=2.978748859167099


 60%|█████▉    | 200/336 [05:39<03:48,  1.68s/it]



**********
train_epoch|9,average_loss=2.8603312373161316


 89%|████████▉ | 300/336 [08:28<01:00,  1.69s/it]



**********
train_epoch|9,average_loss=2.731859248081843


100%|██████████| 336/336 [09:29<00:00,  1.70s/it]
100%|██████████| 27/27 [00:06<00:00,  4.00it/s]
 30%|██▉       | 100/336 [02:49<06:39,  1.69s/it]



**********
train_epoch|10,average_loss=2.6229579532146454


 60%|█████▉    | 200/336 [05:38<03:48,  1.68s/it]



**********
train_epoch|10,average_loss=2.5364594638347624


 89%|████████▉ | 300/336 [08:27<01:01,  1.70s/it]



**********
train_epoch|10,average_loss=2.4291844590504965


100%|██████████| 336/336 [09:28<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  3.99it/s]
 30%|██▉       | 100/336 [02:50<06:46,  1.72s/it]



**********
train_epoch|11,average_loss=2.3757664799690246


 60%|█████▉    | 200/336 [05:40<03:53,  1.72s/it]



**********
train_epoch|11,average_loss=2.327839229106903


 89%|████████▉ | 300/336 [08:31<01:01,  1.71s/it]



**********
train_epoch|11,average_loss=2.1994695778687796


100%|██████████| 336/336 [09:33<00:00,  1.71s/it]
100%|██████████| 27/27 [00:06<00:00,  3.96it/s]
 30%|██▉       | 100/336 [02:49<06:35,  1.68s/it]



**********
train_epoch|12,average_loss=2.1650555777549743


 60%|█████▉    | 200/336 [05:39<03:49,  1.69s/it]



**********
train_epoch|12,average_loss=2.074055982232094


 89%|████████▉ | 300/336 [08:29<01:02,  1.73s/it]



**********
train_epoch|12,average_loss=1.9762180825074513


100%|██████████| 336/336 [09:31<00:00,  1.70s/it]
100%|██████████| 27/27 [00:06<00:00,  3.99it/s]
 30%|██▉       | 100/336 [02:50<06:37,  1.68s/it]



**********
train_epoch|13,average_loss=2.0072878193855286


 60%|█████▉    | 200/336 [05:40<03:50,  1.69s/it]



**********
train_epoch|13,average_loss=1.9491668635606765


 89%|████████▉ | 300/336 [08:30<01:00,  1.68s/it]



**********
train_epoch|13,average_loss=1.8174493972460428


100%|██████████| 336/336 [09:32<00:00,  1.70s/it]
100%|██████████| 27/27 [00:06<00:00,  4.04it/s]
 30%|██▉       | 100/336 [02:50<06:37,  1.68s/it]



**********
train_epoch|14,average_loss=1.7974305868148803


 60%|█████▉    | 200/336 [05:40<03:53,  1.72s/it]



**********
train_epoch|14,average_loss=1.7723235648870468


 89%|████████▉ | 300/336 [08:30<01:00,  1.68s/it]



**********
train_epoch|14,average_loss=1.6715702950954436


100%|██████████| 336/336 [09:31<00:00,  1.70s/it]
100%|██████████| 27/27 [00:06<00:00,  4.20it/s]
 30%|██▉       | 100/336 [02:49<06:39,  1.69s/it]



**********
train_epoch|15,average_loss=1.6449739694595338


 60%|█████▉    | 200/336 [05:38<03:55,  1.73s/it]



**********
train_epoch|15,average_loss=1.6178591203689576


 89%|████████▉ | 300/336 [08:27<00:59,  1.66s/it]



**********
train_epoch|15,average_loss=1.5289241806666056


100%|██████████| 336/336 [09:28<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  4.13it/s]
 30%|██▉       | 100/336 [02:49<06:36,  1.68s/it]



**********
train_epoch|16,average_loss=1.5116298770904542


 60%|█████▉    | 200/336 [05:38<03:46,  1.67s/it]



**********
train_epoch|16,average_loss=1.5201425313949586


 89%|████████▉ | 300/336 [08:26<01:00,  1.68s/it]



**********
train_epoch|16,average_loss=1.4350100107987722


100%|██████████| 336/336 [09:27<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  4.20it/s]
 30%|██▉       | 100/336 [02:48<06:44,  1.71s/it]



**********
train_epoch|17,average_loss=1.4627880716323853


 60%|█████▉    | 200/336 [05:36<03:47,  1.67s/it]



**********
train_epoch|17,average_loss=1.4519072216749191


 89%|████████▉ | 300/336 [08:25<01:00,  1.69s/it]



**********
train_epoch|17,average_loss=1.3558174693584442


100%|██████████| 336/336 [09:26<00:00,  1.69s/it]
100%|██████████| 27/27 [00:06<00:00,  3.97it/s]
 30%|██▉       | 100/336 [02:50<06:40,  1.70s/it]



**********
train_epoch|18,average_loss=1.3877946770191192


 57%|█████▋    | 191/336 [05:26<04:07,  1.71s/it]


KeyboardInterrupt: ignored

In [23]:
#test-------------------------------
eenet=E2EModel()
model_path='model'
eenet.load_state_dict(torch.load('ee2' + model_path))
id2label={k:v for v,k in entity2id.items()}


Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [24]:
def extrac_triple(text, pre_):  # 抽取一个triple

    subjects = []
    labels=[]
    for i in range(len(text)):
        sj = text[i]

        if  '##' in sj:
            sj = subjects[-1]+sj.lstrip("##")
            subjects[-1]=sj
        else:
            lb = pre_[i]
            labels.append(id2label[lb])
            subjects.append(sj)

    # print(subjects,labels)
    # print(len(subjects))
    # print(len(labels))
    return subjects,labels

def get_entity_bio(seq,id2label):
    """Gets entities from sequence.
    note: BIO
    Args:
        seq (list): sequence of labels.
    Returns:
        list: list of (chunk_type, chunk_start, chunk_end).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        get_entity_bio(seq)
        #output
        [['PER', 0,1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
            chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        elif tag.startswith('I-') and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx

            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks

In [25]:
f = open('实体标签输出.json', mode='w+', encoding='utf_8')
results = []
preds_total = []
true_total = []
def get_entity_bios(seq,id2label):
    """Gets entities from sequence.
    note: BIOS
    Args:
        seq (list): sequence of labels.
    Returns:
        list: list of (chunk_type, chunk_start, chunk_end).
    Example:
        # >>> seq = ['B-PER', 'I-PER', 'O', 'S-LOC']
        # >>> get_entity_bios(seq)
        [['PER', 0,1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("S-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[2] = indx
            chunk[0] = tag.split('-')[1]
            chunks.append(chunk)
            chunk = (-1, -1, -1)
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
        elif tag.startswith('I-') and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks

def get_entity_bio(seq,id2label):
    """Gets entities from sequence.
    note: BIO
    Args:
        seq (list): sequence of labels.
    Returns:
        list: list of (chunk_type, chunk_start, chunk_end).
    Example:
        seq = ['B-PER', 'I-PER', 'O', 'B-LOC']
        get_entity_bio(seq)
        #output
        [['PER', 0,1], ['LOC', 3, 3]]
    """
    chunks = []
    chunk = [-1, -1, -1]
    for indx, tag in enumerate(seq):
        if not isinstance(tag, str):
            tag = id2label[tag]
        if tag.startswith("B-"):
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
            chunk[1] = indx
            chunk[0] = tag.split('-')[1]
            chunk[2] = indx
            if indx == len(seq) - 1:
                chunks.append(chunk)
        elif tag.startswith('I-') and chunk[1] != -1:
            _type = tag.split('-')[1]
            if _type == chunk[0]:
                chunk[2] = indx

            if indx == len(seq) - 1:
                chunks.append(chunk)
        else:
            if chunk[2] != -1:
                chunks.append(chunk)
            chunk = [-1, -1, -1]
    return chunks

def get_entities(seq,id2label,markup='bios'):
    '''
    :param seq:
    :param id2label:
    :param markup:
    :return:
    '''
    assert markup in ['bio','bios']
    if markup =='bio':
        return get_entity_bio(seq,id2label)
    else:
        return get_entity_bios(seq,id2label)
class SeqEntityScore(object):
    def __init__(self, id2label,markup='bios'):
        self.id2label = id2label
        self.markup = markup
        self.reset()

    def reset(self):
        self.origins = []
        self.founds = []
        self.rights = []

    def compute(self, origin, found, right):
        recall = 0 if origin == 0 else (right / origin)
        precision = 0 if found == 0 else (right / found)
        f1 = 0. if recall + precision == 0 else (2 * precision * recall) / (precision + recall)
        return recall, precision, f1

    def result(self):
        class_info = {}
        origin_counter = Counter([x[0] for x in self.origins])
        found_counter = Counter([x[0] for x in self.founds])
        right_counter = Counter([x[0] for x in self.rights])
        for type_, count in origin_counter.items():
            origin = count
            found = found_counter.get(type_, 0)
            right = right_counter.get(type_, 0)
            recall, precision, f1 = self.compute(origin, found, right)
            class_info[type_] = {"acc": round(precision, 4), 'recall': round(recall, 4), 'f1': round(f1, 4)}
        origin = len(self.origins)
        found = len(self.founds)
        right = len(self.rights)
        recall, precision, f1 = self.compute(origin, found, right)
        return {'acc': precision, 'recall': recall, 'f1': f1}, class_info

    def update(self, label_paths, pred_paths):
        '''
        labels_paths: [[],[],[],....]
        pred_paths: [[],[],[],.....]

        :param label_paths:
        :param pred_paths:
        :return:
        Example:
            >>> labels_paths = [['O', 'O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
            >>> pred_paths = [['O', 'O', 'B-MISC', 'I-MISC', 'I-MISC', 'I-MISC', 'O'], ['B-PER', 'I-PER', 'O']]
        '''
        for label_path, pre_path in zip(label_paths, pred_paths):
            label_entities = get_entities(label_path, self.id2label,self.markup)
            pre_entities = get_entities(pre_path, self.id2label,self.markup)
            self.origins.extend(label_entities)
            self.founds.extend(pre_entities)
            self.rights.extend([pre_entity for pre_entity in pre_entities if pre_entity in label_entities])

metric = SeqEntityScore(id2label, markup='bio')

In [35]:
with torch.no_grad():
    eenet.eval()
    eenet.cpu()
    eval_loss = 0.0
    epoch = 0
    nb_eval_steps = 0
    dict_t = {}
    test_loader = DataLoader(test_dataset, batch_size=1)
    step = 0
    acc = 0
    total = 0

    for batch_idx, data in enumerate(tqdm(test_loader)):
        json_d = {}
        json_d['id'] = batch_idx
        inputs_id, att_mask, token_type_ids, labels_ids ,dlength = data

        inputs = {"input_ids": data[0], "attention_mask": data[1], "labels": data[3]}
        # ee _model识别实体----------------------------------------------
        x, soutputs = eenet(inputs_id, att_mask, token_type_ids, labels_ids)
        tmp_eval_loss, logits = soutputs[:2]
        # obj_htar, obj_ttar = obj_htar.unsqueeze(-1), obj_ttar.unsqueeze(-1)
        # total_loss.backward()
        tags = eenet.crf.decode(logits, att_mask.byte())
        eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1
        out_label_ids = inputs['labels'].cpu().numpy().tolist()

        input_lens = [len(i) for i in tags]
        # tags = tags.squeeze(0).cpu().numpy().tolist()
        i = 0
        label = labels_ids[i]
        text = tokenizer.convert_ids_to_tokens(inputs_id[[i][0]])
        text = text[1:dlength[i] - 1]
        pre_ = torch.tensor(tags[i][1:dlength[i] - 1])
        true_ = label[1:len(pre_) + 1]
        acc += (pre_ == true_).sum().item()
        pre_ = tags[i][1:dlength[i] - 1]
        true_ = true_.numpy().tolist()
        json_d['tag_seq'] = " ".join([id2label[x] for x in pre_])
        pre_label_entities = get_entity_bio(pre_, id2label)
        true_label_entites = get_entity_bio(true_, id2label)
        json_d['pre_entities'] = pre_label_entities
        json_d['true_label_entites'] = true_label_entites
        results.append(json_d)
        preds_total.append(pre_)
        true_total.append(true_)

        total += len(pre_)
        subjects, labels = extrac_triple(text, pre_)
        for w, t in zip(subjects, labels):
            f.write(f'{w} {t}\n')
        f.write('\n')

se1=SeqEntityScore(id2label)
se1.update(preds_total,true_total)
s=se1.result()
print(s)

100%|██████████| 500/500 [03:35<00:00,  2.32it/s]

({'acc': 0.7954545454545454, 'recall': 0.7954545454545454, 'f1': 0.7954545454545455}, {'government': {'acc': 0.91, 'recall': 0.8349, 'f1': 0.8708}, 'company': {'acc': 0.8134, 'recall': 0.7676, 'f1': 0.7899}, 'organization': {'acc': 0.7311, 'recall': 0.8447, 'f1': 0.7838}, 'scene': {'acc': 0.725, 'recall': 0.7945, 'f1': 0.7582}, 'name': {'acc': 0.8889, 'recall': 0.8696, 'f1': 0.8791}, 'address': {'acc': 0.6693, 'recall': 0.6391, 'f1': 0.6538}, 'position': {'acc': 0.7747, 'recall': 0.7663, 'f1': 0.7705}, 'book': {'acc': 0.7377, 'recall': 0.8491, 'f1': 0.7895}, 'game': {'acc': 0.8429, 'recall': 0.7468, 'f1': 0.7919}, 'movie': {'acc': 0.8511, 'recall': 1.0, 'f1': 0.9195}})



