## config

In [None]:
import os
import pickle
import logging
import threading


def set_logger(config):
    """Configure root logging to write to a per-model log file and to the console.

    :param config: object exposing ``log_path`` (directory for log files) and
                   ``model`` (used as the log file's base name).
    """
    # makedirs also creates missing parent directories and, with exist_ok=True,
    # does not fail if the directory appears between a check and the creation.
    os.makedirs(config.log_path, exist_ok=True)
    logging.basicConfig(
        format='%(asctime)s %(levelname)-8s %(message)s',
        level=logging.INFO,
        datefmt='%Y-%m-%d %H:%M:%S',
        filename=os.path.join(config.log_path, '{}.log'.format(config.model)),
        filemode='a'
    )
    # Mirror every record to the console in addition to the log file.
    console = logging.StreamHandler()
    console.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s %(levelname)-8s %(message)s')
    console.setFormatter(formatter)
    logging.getLogger('').addHandler(console)

def load_file(fp: str, sep: str = None):
    """Read a UTF-8 text file line by line.

    :param fp:  path of the file to read.
    :param sep: optional field separator.
    :return: when ``sep`` is falsy, the raw lines exactly as read (line endings
             kept); otherwise one list per line, obtained by stripping the line
             and splitting it on ``sep``.
    """
    with open(fp, "r", encoding="utf-8") as handle:
        raw_lines = list(handle)
    if not sep:
        return raw_lines
    return [row.strip().split(sep) for row in raw_lines]

def get_labels(config):
    """Return the fixed tag inventory followed by the sequence boundary markers.

    ``config`` is accepted for interface compatibility but is not consulted;
    the label set is hard-coded here.
    """
    entity_tags = ['I-PER', 'I-ORG', 'I-LOC', 'B-LOC', 'B-PER', 'O', 'B-ORG']
    boundary_tags = ['<START>', '<END>']
    return entity_tags + boundary_tags


class Config(object):
    """Process-wide settings for training and evaluation, shared as one instance.

    Every ``Config()`` call returns the same object. Previously only the first
    instance was populated (the class-level ``_init_flag`` blocked initialisation
    of later instances), so a second ``Config()`` produced an object without any
    attributes.
    """
    _instance_lock = threading.Lock()
    _instance = None
    _init_flag = False

    def __new__(cls, *args, **kwargs):
        # Create the shared instance once; the lock keeps concurrent first calls
        # from building two objects.
        if cls._instance is None:
            with cls._instance_lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # __init__ runs on every Config() call; populate the shared instance
        # only the first time so later calls do not reset modified settings.
        with Config._instance_lock:
            if Config._init_flag:
                return
            self.base_path = os.path.abspath('/content/drive/MyDrive/Colab Notebooks')
            self._init_train_config()
            Config._init_flag = True

    def _init_train_config(self):
        """Set the default model choice, file locations and hyper-parameters."""
        self.label_list = []
        self.use_gpu = True
        self.device = "cuda"
        self.checkpoints = False  # set to False when starting from the pretrained model
        self.model = 'bert_bilstm_crf'  # one of ['bert_bilstm_crf','bilstm_crf','bilstm','crf','hmm']

        # Input data sets, logs and output directories
        self.train_file = os.path.join(self.base_path, 'data/train.txt')
        self.test_file = os.path.join(self.base_path, 'data/test.txt')
        # Read by NERDataset(mode="eval"), which is used when do_eval is True.
        # NOTE(review): file name assumed by analogy with train/test - confirm.
        self.dev_file = os.path.join(self.base_path, 'data/dev.txt')
        self.log_path = os.path.join(self.base_path, 'logs')
        self.output_path = os.path.join(self.base_path, 'output', self.model)
        self.trained_model_path = os.path.join(self.base_path, 'ckpts', self.model)
        # Load from the pretrained BERT directory unless resuming from our own checkpoint.
        self.model_name_or_path = os.path.join(self.base_path, 'ckpts', 'bert-base-chinese') if not self.checkpoints \
            else self.trained_model_path

        # Training parameters
        self.do_train = True
        self.do_eval = False
        self.need_birnn = True
        self.do_lower_case = True
        self.rnn_dim = 128
        self.max_seq_length = 128
        self.batch_size = 16
        self.num_train_epochs = 2
        self.ckpts_epoch = 1
        self.gradient_accumulation_steps = 1
        self.learning_rate = 3e-5
        self.adam_epsilon = 1e-8
        self.warmup_steps = 0
        self.logging_steps = 50
        self.remove_O = False


In [None]:
!pip install pytorch-crf

Collecting pytorch-crf
  Downloading pytorch_crf-0.7.2-py3-none-any.whl (9.5 kB)
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


# BERT_BiLSTM_CRF 模型

In [None]:
import torch.nn as nn
from transformers import BertPreTrainedModel, BertModel
from torchcrf import CRF


class BERT_BiLSTM_CRF(BertPreTrainedModel):
    """BERT encoder, optional BiLSTM layer, linear tag projection and a CRF head."""

    def __init__(self, config, need_birnn=False, rnn_dim=128):
        """
        :param config:     BERT configuration; ``hidden_size``, ``hidden_dropout_prob``
                           and ``num_labels`` are read from it.
        :param need_birnn: when true, a single-layer bidirectional LSTM is inserted
                           between BERT and the tag projection.
        :param rnn_dim:    hidden size of each LSTM direction.
        """
        super(BERT_BiLSTM_CRF, self).__init__(config)
        self.bert = BertModel(config)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        out_dim = config.hidden_size

        # Always record the flag: forward()/predict() read it, and leaving it
        # unset when need_birnn is False raised AttributeError there.
        self.need_birnn = need_birnn
        if need_birnn:
            self.birnn = nn.LSTM(input_size=config.hidden_size, hidden_size=rnn_dim, num_layers=1, bidirectional=True,
                                 batch_first=True)
            out_dim = rnn_dim * 2  # both directions are concatenated

        self.hidden2tag = nn.Linear(in_features=out_dim, out_features=config.num_labels)

        self.crf = CRF(num_tags=config.num_labels, batch_first=True)

    def _emissions(self, input_ids, token_type_ids, attention_mask):
        """Run BERT (+ optional BiLSTM), dropout and the tag projection."""
        outputs = self.bert(input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]  # [batch_size, seq_len, hidden_size]
        if self.need_birnn:
            # batch_first=True, so the output is [batch_size, seq_len, 2 * rnn_dim]
            sequence_output, _ = self.birnn(sequence_output)
        sequence_output = self.dropout(sequence_output)
        return self.hidden2tag(sequence_output)  # [batch_size, seq_len, num_labels]

    @staticmethod
    def _crf_mask(attention_mask):
        """Convert the attention mask to a boolean CRF mask (None stays None)."""
        # A boolean mask avoids the deprecation warning torch.where emits for
        # uint8 conditions inside the CRF; None lets the CRF use every position.
        return None if attention_mask is None else attention_mask.bool()

    def forward(self, input_ids, tags, token_type_ids=None, attention_mask=None):
        """
        :param input_ids:      [batch_size, seq_len] token ids
        :param tags:           [batch_size, seq_len] gold tag ids
        :param token_type_ids: [batch_size, seq_len] segment ids
        :param attention_mask: [batch_size, seq_len] marks the positions to attend to;
                               also used as the CRF mask
        :return: loss, the negated value returned by the CRF layer
        """
        emissions = self._emissions(input_ids, token_type_ids, attention_mask)
        loss = -1 * self.crf(emissions, tags, mask=self._crf_mask(attention_mask))
        return loss

    def predict(self, input_ids, token_type_ids=None, attention_mask=None):
        """
        :param input_ids:      [batch_size, seq_len] token ids
        :param token_type_ids: [batch_size, seq_len] segment ids
        :param attention_mask: [batch_size, seq_len] marks the positions to attend to;
                               also used as the CRF mask

        :return: pred_tags, a list with one decoded tag-id sequence per batch element
        """
        emissions = self._emissions(input_ids, token_type_ids, attention_mask)
        return self.crf.decode(emissions, self._crf_mask(attention_mask))


# Train and test

In [None]:
import os
import torch
import logging

from tqdm import tqdm, trange
from torch.utils.data import DataLoader, SequentialSampler, Dataset
from transformers import AdamW, get_linear_schedule_with_warmup, BertTokenizer, BertConfig

# from utils import *
# from dataloader import NERDataset
# from models import BERT_BiLSTM_CRF

from collections import Counter

class Metrics(object):
    """Evaluate a tagger: per-tag precision, recall and F1, plus weighted averages."""

    def __init__(self, golden_tags, predict_tags, remove_O=False):
        """
        :param golden_tags:  gold tags, one list per sentence (or an already flat list)
        :param predict_tags: predicted tags in the same layout as ``golden_tags``
        :param remove_O:     when true, positions whose gold tag is 'O' are dropped
                             from both sequences before scoring
        """
        # Concatenate the per-sentence tags: [[t1, t2], [t3, t4]...] --> [t1, t2, t3, t4...]
        self.golden_tags = self.flatten_lists(golden_tags)
        self.predict_tags = self.flatten_lists(predict_tags)
        if remove_O:  # drop the O tags to focus on entity tags only
            self._remove_Otags()

        # Helper counts used by the score computations
        self.tagset = set(self.golden_tags)
        self.correct_tags_number = self.count_correct_tags()
        self.predict_tags_counter = Counter(self.predict_tags)
        self.golden_tags_counter = Counter(self.golden_tags)

        self.precision_scores = self.cal_precision()
        self.recall_scores = self.cal_recall()
        self.f1_scores = self.cal_f1()

    def flatten_lists(self, lists):
        """Flatten one nesting level; non-list items are kept as they are."""
        flat = []
        for item in lists:
            if isinstance(item, list):
                flat.extend(item)
            else:
                flat.append(item)
        return flat

    def cal_precision(self):
        """Return {tag: precision}; 0.0 for a tag that was never predicted."""
        precision_scores = {}
        for tag in self.tagset:
            predicted = self.predict_tags_counter[tag]
            # A gold tag the model never predicted used to raise ZeroDivisionError.
            precision_scores[tag] = (
                self.correct_tags_number.get(tag, 0) / predicted if predicted else 0.0
            )
        return precision_scores

    def cal_recall(self):
        """Return {tag: recall}; 0.0 if the tag has no gold occurrences."""
        recall_scores = {}
        for tag in self.tagset:
            gold = self.golden_tags_counter[tag]
            recall_scores[tag] = self.correct_tags_number.get(tag, 0) / gold if gold else 0.0
        return recall_scores

    def cal_f1(self):
        """Return {tag: F1} computed from the stored precision and recall."""
        f1_scores = {}
        for tag in self.tagset:
            p, r = self.precision_scores[tag], self.recall_scores[tag]
            f1_scores[tag] = 2*p*r / (p+r+1e-10)  # small epsilon keeps the denominator non-zero
        return f1_scores

    def report_scores(self):
        """Log the results as a table, for example:
                      precision    recall  f1-score   support
              B-LOC      0.775     0.757     0.766      1084
              I-LOC      0.601     0.631     0.616       325
             B-MISC      0.698     0.499     0.582       339
             I-MISC      0.644     0.567     0.603       557
              B-ORG      0.795     0.801     0.798      1400
              I-ORG      0.831     0.773     0.801      1104
              B-PER      0.812     0.876     0.843       735
              I-PER      0.873     0.931     0.901       634
          avg/total      0.779     0.764     0.770      6178
        """
        # Header row
        header_format = '{:>9s}  {:>9} {:>9} {:>9} {:>9}'
        header = ['precision', 'recall', 'f1-score', 'support']
        logging.info(header_format.format('', *header))

        # One row per tag: precision, recall, F1 and support
        row_format = '{:>9s}  {:>9.4f} {:>9.4f} {:>9.4f} {:>9}'
        for tag in self.tagset:
            logging.info(row_format.format(
                tag,
                self.precision_scores[tag],
                self.recall_scores[tag],
                self.f1_scores[tag],
                self.golden_tags_counter[tag]
            ))

        # Weighted averages
        avg_metrics = self.cal_avg_metrics()
        logging.info(row_format.format(
            'avg/total',
            avg_metrics['precision'],
            avg_metrics['recall'],
            avg_metrics['f1_score'],
            len(self.golden_tags)
        ))

    def count_correct_tags(self):
        """Count, per tag, the positions where prediction and gold agree (the true positives)."""
        correct_dict = {}
        for gold_tag, predict_tag in zip(self.golden_tags, self.predict_tags):
            if gold_tag == predict_tag:
                correct_dict[gold_tag] = correct_dict.get(gold_tag, 0) + 1
        return correct_dict

    def cal_avg_metrics(self):
        """Return precision/recall/f1_score averaged over tags, weighted by gold support."""
        avg_metrics = {'precision': 0., 'recall': 0., 'f1_score': 0.}
        total = len(self.golden_tags)
        if total == 0:
            # Nothing to average over; avoid dividing by zero.
            return avg_metrics

        for tag in self.tagset:
            size = self.golden_tags_counter[tag]
            avg_metrics['precision'] += self.precision_scores[tag] * size
            avg_metrics['recall'] += self.recall_scores[tag] * size
            avg_metrics['f1_score'] += self.f1_scores[tag] * size

        for metric in avg_metrics:
            avg_metrics[metric] /= total

        return avg_metrics

    def _remove_Otags(self):
        """Drop every position whose gold tag is 'O' from both tag sequences."""
        length = len(self.golden_tags)
        # A set gives O(1) membership tests; a list made the filtering quadratic.
        O_tag_indices = {i for i, tag in enumerate(self.golden_tags) if tag == 'O'}
        self.golden_tags = [tag for i, tag in enumerate(self.golden_tags) if i not in O_tag_indices]
        self.predict_tags = [tag for i, tag in enumerate(self.predict_tags) if i not in O_tag_indices]
        removed_pct = len(O_tag_indices) / length * 100 if length else 0.0
        logging.info("原总标记数为{}，移除了{}个O标记，占比{:.2f}%".format(
            length,
            len(O_tag_indices),
            removed_pct
        ))

    def report_confusion_matrix(self):
        """Log the confusion matrix over the gold tag set."""
        logging.info("Confusion Matrix:")
        tag_list = list(self.tagset)
        tag_index = {tag: i for i, tag in enumerate(tag_list)}
        # matrix[i][j] counts how often gold tag i was predicted as tag j
        tags_size = len(tag_list)
        matrix = [[0] * tags_size for _ in range(tags_size)]

        for golden_tag, predict_tag in zip(self.golden_tags, self.predict_tags):
            row = tag_index.get(golden_tag)
            col = tag_index.get(predict_tag)
            # A few tags may appear only in the predictions; skip those pairs.
            if row is None or col is None:
                continue
            matrix[row][col] += 1

        row_format_ = '{:>7} ' * (tags_size+1)
        logging.info(row_format_.format("", *tag_list))
        for i, row in enumerate(matrix):
            logging.info(row_format_.format(tag_list[i], *row))


class Bert_Bilstm_Crf():
    """Training, evaluation and test driver for the BERT_BiLSTM_CRF tagger."""

    def __init__(self, config, device, use_gpu, n_gpu, writer, id2label):
        """
        :param config:   settings object (paths, hyper-parameters, label_list)
        :param device:   device the model and batches are moved to
        :param use_gpu:  whether GPU execution is enabled
        :param n_gpu:    number of GPUs; more than one enables DataParallel in train()
        :param writer:   summary writer used for add_scalar() and closed after training
        :param id2label: mapping from label id to label string
        """
        self.config = config
        self.device = device
        self.use_gpu = use_gpu
        self.n_gpu = n_gpu
        self.writer = writer
        self.id2label = id2label
        # Honour config.do_lower_case when present; True was the previous fixed value.
        self.tokenizer = BertTokenizer.from_pretrained(config.model_name_or_path,
                                                       do_lower_case=getattr(config, 'do_lower_case', True))
        bert_config = BertConfig.from_pretrained(config.model_name_or_path, num_labels=len(config.label_list))
        # NOTE(review): need_birnn is fixed to True here; config.need_birnn is not consulted.
        self.model = BERT_BiLSTM_CRF.from_pretrained(config.model_name_or_path, config=bert_config,
                                                     need_birnn=True, rnn_dim=config.rnn_dim)
        self.model.to(device)
        logging.info("loading tokenizer、bert_config and bert_bilstm_crf model successful!")

    @staticmethod
    def _batch_to_tensors(batch_data, device):
        """Return (input_ids, token_type_ids, attention_mask, label_ids) tensors on ``device``.

        Each field of the collated batch is a sequence of tensors (hence the
        torch.stack); stacking and transposing is presumed to yield
        [batch_size, seq_len]. Fields are looked up by name so the result does
        not depend on the key order of the batch.
        """
        return tuple(
            torch.stack(batch_data[key]).T.to(device)
            for key in ('input_ids', 'token_type_ids', 'attention_mask', 'label_ids')
        )

    def _save_model(self, target_dir):
        """Save the model (unwrapped from DataParallel if needed) and tokenizer to ``target_dir``."""
        model_to_save = self.model.module if hasattr(self.model, 'module') else self.model
        model_to_save.save_pretrained(target_dir)
        self.tokenizer.save_pretrained(target_dir)

    def train(self):
        """Train the model, optionally evaluating after each epoch, then save a checkpoint."""
        if self.use_gpu and self.n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        logging.info("starting load train data and data_loader...")
        dataset = NERDataset(self.config, self.tokenizer, mode='train')
        dataloader = DataLoader(dataset, self.config.batch_size, shuffle=True)

        logging.info("loading train data_set and data_loader successful!")

        # Optimizer: no weight decay for biases and LayerNorm weights
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=self.config.learning_rate, eps=self.config.adam_epsilon)

        # Learning-rate schedule over the total number of optimizer updates
        t_total = len(dataloader) // self.config.gradient_accumulation_steps * self.config.num_train_epochs
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=self.config.warmup_steps,
                                                    num_training_steps=t_total)
        logging.info("loading AdamW optimizer、Warmup LinearSchedule and calculate optimizer parameter successful!")

        logging.info("====================== Running training ======================")
        logging.info(
            f"Num Examples:  {len(dataset)}, Num Batch Step: {len(dataloader)}, "
            f"Num Epochs: {self.config.num_train_epochs}, Num scheduler steps：{t_total}")

        checkpoint_dir = os.path.join(self.config.trained_model_path, 'checkpoints')
        global_step, tr_loss, logging_loss, best_f1 = 0, 0.0, 0.0, 0.0
        for epoch in range(int(self.config.num_train_epochs)):
            # evaluate() puts the model in eval mode; switch back to training
            # mode at the start of every epoch so dropout stays active.
            self.model.train()
            for batch, batch_data in enumerate(tqdm(dataloader, desc="Train_DataLoader")):
                input_ids, token_type_ids, attention_mask, label_ids = self._batch_to_tensors(
                    batch_data, self.device)
                loss = self.model(input_ids, label_ids, token_type_ids, attention_mask)

                if self.use_gpu and self.n_gpu > 1:
                    # DataParallel returns one loss per replica
                    loss = loss.mean()

                if self.config.gradient_accumulation_steps > 1:
                    loss = loss / self.config.gradient_accumulation_steps

                logging.info(f"Epoch: {epoch}/{int(self.config.num_train_epochs)}\tBatch: {batch}/{len(dataloader)}\tLoss:{loss}")
                # Back-propagation
                loss.backward()
                tr_loss += loss.item()

                # Number of optimizer updates; corresponds to t_total above
                if (batch + 1) % self.config.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()
                    self.model.zero_grad()
                    global_step += 1

                    if self.config.logging_steps > 0 and global_step % self.config.logging_steps == 0:
                        tr_loss_avg = (tr_loss - logging_loss) / self.config.logging_steps
                        self.writer.add_scalar("Train/loss", tr_loss_avg, global_step)
                        logging_loss = tr_loss

            if self.config.do_eval:
                logging.info("====================== Running Eval ======================")
                eval_data = NERDataset(self.config, self.tokenizer, mode="eval")

                avg_metrics, cal_indicators, eval_sens = self.evaluate(
                    self.config, self.tokenizer, eval_data, self.model, self.id2label, self.device, tqdm_desc="Eval_DataLoader")
                f1_score = avg_metrics['f1_score']
                self.writer.add_scalar("Eval/precision", avg_metrics['precision'], epoch)
                self.writer.add_scalar("Eval/recall", avg_metrics['recall'], epoch)
                self.writer.add_scalar("Eval/f1_score", avg_metrics['f1_score'], epoch)

                # Keep the best-performing model, both at the top level and under checkpoints/
                if f1_score > best_f1:
                    logging.info(f"******** the best f1 is {f1_score}, save model !!! ********")
                    best_f1 = f1_score
                    self._save_model(self.config.trained_model_path)
                    self._save_model(checkpoint_dir)

        # Always save the final state as a checkpoint
        self._save_model(checkpoint_dir)

        self.writer.close()
        logging.info("NER model training successful!!!")

    @staticmethod
    def evaluate(config, tokenizer, dataset, model, id2label, device, tqdm_desc):
        """Decode ``dataset`` with ``model`` and score the predictions.

        :return: (avg_metrics, cal_indicators, eval_sens) where ``avg_metrics`` is the
                 dict from Metrics.cal_avg_metrics(), ``cal_indicators`` the Metrics
                 object and ``eval_sens`` a list of sentences, each a list of
                 (token, gold_label, predicted_label) tuples.
        """
        sampler = SequentialSampler(dataset)
        data_loader = DataLoader(dataset, sampler=sampler, batch_size=config.batch_size)
        if isinstance(model, torch.nn.DataParallel):
            model = model.module
        model.eval()

        # Work on a copy so the caller's mapping is not modified; -1 is the
        # padding label id and only needs a name while decoding here.
        id2label = dict(id2label)
        id2label[-1] = 'NULL'

        ori_tokens, ori_labels, pred_labels = [], [], []
        # Tokens and gold labels are taken from the batches themselves, so the
        # dataset is tokenised once instead of three times.
        for batch_data in tqdm(data_loader, desc=tqdm_desc):
            input_ids, token_type_ids, attention_mask, label_ids = Bert_Bilstm_Crf._batch_to_tensors(
                batch_data, device)

            with torch.no_grad():
                decoded = model.predict(input_ids, token_type_ids, attention_mask)

            # convert_ids_to_tokens keeps exactly one token per input position;
            # decode() joins word pieces and cleans up spacing, which shifted
            # the tokens against their labels.
            for ids in input_ids.tolist():
                ori_tokens.append(tokenizer.convert_ids_to_tokens(ids))
            for row in label_ids.tolist():
                ori_labels.append([id2label[idx] for idx in row])
            for path in decoded:
                pred_labels.append([id2label[idx] for idx in path])

        assert len(pred_labels) == len(ori_tokens) == len(ori_labels)
        eval_sens = []
        for ori_token, ori_label, pred_label in zip(ori_tokens, ori_labels, pred_labels):
            sen_tll = []
            for ot, ol, pl in zip(ori_token, ori_label, pred_label):
                if ot in ["[CLS]", "[SEP]", "[PAD]"]:
                    continue
                sen_tll.append((ot, ol, pl))
            eval_sens.append(sen_tll)

        golden_tags = [[ttl[1] for ttl in sen] for sen in eval_sens]
        predict_tags = [[ttl[2] for ttl in sen] for sen in eval_sens]
        cal_indicators = Metrics(golden_tags, predict_tags, remove_O=config.remove_O)
        avg_metrics = cal_indicators.cal_avg_metrics()  # keys: 'precision', 'recall', 'f1_score'

        return avg_metrics, cal_indicators, eval_sens

    def test(self):
        """Evaluate on the test set, log the scores and write token/gold/prediction rows to disk."""
        logging.info("====================== Running test ======================")
        dataset = NERDataset(self.config, self.tokenizer, mode='test')
        avg_metrics, cal_indicators, eval_sens = self.evaluate(
            self.config, self.tokenizer, dataset, self.model, self.id2label, self.device, tqdm_desc="Test_DataLoader")

        cal_indicators.report_scores()
        cal_indicators.report_confusion_matrix()
        # Write the test results locally; make sure the target directory exists first.
        os.makedirs(self.config.output_path, exist_ok=True)
        with open(os.path.join(self.config.output_path, "token_labels_test.txt"), "w", encoding="utf-8") as f:
            for sen in eval_sens:
                for ttl in sen:
                    f.write(f"{ttl[0]}\t{ttl[1]}\t{ttl[2]}\n")
                f.write("\n")


# data loader

In [None]:
import os
import logging
import torch
from torch.utils.data import Dataset, TensorDataset

class InputData(object):
    """One raw example: an identifier, its text and, optionally, its label string."""

    def __init__(self, guid, text, label=None):
        self.guid, self.text, self.label = guid, text, label


class InputFeatures(object):
    """Model-ready features for one example."""

    def __init__(self, input_ids, token_type_ids, attention_mask, label_id):
        """
        :param input_ids:       vocabulary ids of the tokens
        :param token_type_ids:  segment ids separating two sentences (0 for the first, 1 for the second)
        :param attention_mask:  marks which positions take part in self-attention
        :param label_id:        ids of the labels
        """
        (self.input_ids,
         self.token_type_ids,
         self.attention_mask,
         self.label_id) = (input_ids, token_type_ids, attention_mask, label_id)


class NERDataset(Dataset):
    """Dataset over a two-column (token, label) file with blank lines between sentences."""

    def __init__(self, config, tokenizer, mode="train"):
        """
        :param config:    settings object; the file path for ``mode``, ``label_list``
                          and ``max_seq_length`` are read from it
        :param tokenizer: tokenizer used to encode each sentence
        :param mode:      "train", "eval" or "test"; selects config.train_file,
                          config.dev_file or config.test_file
        :raises ValueError: for any other ``mode``
        """
        super(NERDataset, self).__init__()
        self.config = config
        self.tokenizer = tokenizer
        if mode == "train":
            self.file_path = config.train_file
        elif mode == "test":
            self.file_path = config.test_file
        elif mode == "eval":
            self.file_path = config.dev_file
        else:
            # The message previously omitted the accepted "eval" mode.
            raise ValueError("mode must be one of train, eval or test")

        self.tdt_data = self.get_data()
        self.len = len(self.tdt_data)

    def __len__(self):
        return self.len

    def __getitem__(self, idx):
        """Encode example ``idx``.

        The tokenizer is called on the sentence text (tokens joined without
        spaces) with padding/truncation to ``max_seq_length``; its result is
        returned with an extra ``label_ids`` entry added to ``.data``. The label
        ids are wrapped in <START>/<END> ids and padded with -1.
        """
        label_map = {label: i for i, label in enumerate(self.config.label_list)}
        max_seq_length = self.config.max_seq_length

        data = self.tdt_data[idx]
        data_text_list = data.text.split(" ")
        data_label_list = data.label.split(" ")
        assert len(data_text_list) == len(data_label_list)

        features = self.tokenizer(''.join(data_text_list), padding='max_length', max_length=max_seq_length, truncation=True)
        # The tokenizer truncates to max_seq_length, so the labels must be cut to
        # match (two positions are reserved for <START>/<END>). Without this, long
        # sentences produced label lists of differing lengths within a batch.
        # NOTE(review): this still assumes one token per character of the text;
        # confirm for inputs the tokenizer may merge or split.
        label_ids = [label_map[label] for label in data_label_list[:max_seq_length - 2]]
        label_ids = [label_map["<START>"]] + label_ids + [label_map["<END>"]]
        label_ids.extend([-1] * (max_seq_length - len(label_ids)))
        features.data['label_ids'] = label_ids

        return features

    def read_file(self):
        """Parse the data file into ``[label_string, word_string]`` pairs, one per sentence.

        Lines that split into exactly two whitespace-separated fields are read as
        (word, label); an empty line closes the current sentence. Labels and words
        of a sentence are each joined with single spaces.
        """
        lines, words, labels = [], [], []
        with open(self.file_path, "r", encoding="utf-8") as f:
            for line in f:
                tokens = line.strip().split()
                if len(tokens) == 2:
                    words.append(tokens[0])
                    labels.append(tokens[1])
                elif not tokens and words:
                    lines.append([' '.join(labels), ' '.join(words)])
                    words, labels = [], []
        # A file that does not end with a blank line used to lose its last sentence.
        if words:
            lines.append([' '.join(labels), ' '.join(words)])
        return lines

    def get_data(self):
        """Read the file and wrap each sentence in an InputData (guid is the sentence index as a string)."""
        return [
            InputData(guid=str(i), text=line[1], label=line[0])
            for i, line in enumerate(self.read_file())
        ]

    def word_piece_bool(self, text):
        """Return True if any space-separated item of ``text`` does not tokenize to exactly one token."""
        return any(len(self.tokenizer.tokenize(word)) != 1 for word in text.split(' '))

    @staticmethod
    def convert_data_to_features(self, tdt_data):
        """Convert InputData items into padded InputFeatures.

        For each example the tokens are wrapped in [CLS]/[SEP], converted to ids,
        and input_ids, token_type_ids, attention_mask and label_ids are padded
        with 0 up to ``max_seq_length``. Examples in which any item tokenizes to
        something other than a single token are logged and skipped.

        NOTE(review): this is declared as a staticmethod yet takes ``self`` as its
        first parameter, so the dataset instance has to be passed explicitly
        (``NERDataset.convert_data_to_features(ds, data)``). Left unchanged to
        keep the existing call signature - confirm the intended usage.
        """
        label_map = {label: i for i, label in enumerate(self.config.label_list)}
        max_seq_length = self.config.max_seq_length

        features = []
        for data in tdt_data:
            data_text_list = data.text.split(" ")
            data_label_list = data.label.split(" ")
            assert len(data_text_list) == len(data_label_list)

            tokens, labels, ori_tokens = [], [], []
            word_piece = False
            for i, word in enumerate(data_text_list):
                token = self.tokenizer.tokenize(word)
                tokens.extend(token)
                label = data_label_list[i]
                ori_tokens.append(word)
                # Only items that map to exactly one token keep their label
                if len(token) == 1:
                    labels.append(label)
                else:
                    word_piece = True

            if word_piece:
                logging.info("Error tokens!!! skip this lines, the content is: %s" % " ".join(data_text_list))
                continue

            assert len(tokens) == len(ori_tokens)

            if len(tokens) >= max_seq_length - 1:
                # Leave two positions for the leading and trailing markers
                tokens = tokens[0:(max_seq_length - 2)]
                labels = labels[0:(max_seq_length - 2)]

            label_ids = [label_map[label] for label in labels]
            new_tokens = ["[CLS]"] + tokens + ["[SEP]"]
            input_ids = self.tokenizer.convert_tokens_to_ids(new_tokens)
            token_type_ids = [0] * len(input_ids)
            attention_mask = [1] * len(input_ids)
            label_ids = [label_map["<START>"]] + label_ids + [label_map["<END>"]]

            while len(input_ids) < max_seq_length:
                input_ids.append(0)
                attention_mask.append(0)
                token_type_ids.append(0)
                label_ids.append(0)

            features.append(InputFeatures(input_ids=input_ids,
                                          token_type_ids=token_type_ids,
                                          attention_mask=attention_mask,
                                          label_id=label_ids))
        return features


# main

In [None]:
import torch

from torch.utils.tensorboard import SummaryWriter

# from utils import *
# from trainer import Bert_Bilstm_Crf


def main():
    """Build the configuration, logging and label maps, then run the tagger's test pass."""
    config = Config()
    set_logger(config)
    visual_dir = os.path.join(config.output_path, "visual")
    writer = SummaryWriter(log_dir=visual_dir, comment="ner")

    accumulation = config.gradient_accumulation_steps
    if accumulation < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(accumulation))

    # Use the GPU only when one is available and the configuration allows it.
    use_gpu = torch.cuda.is_available() and config.use_gpu
    if use_gpu:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')
    config.device = device
    n_gpu = torch.cuda.device_count()
    logging.info(f"available device: {device}，count_gpu: {n_gpu}")

    config.label_list = get_labels(config)
    label2id = {name: idx for idx, name in enumerate(config.label_list)}
    id2label = {idx: name for name, idx in label2id.items()}
    logging.info("loading label2id and id2label dictionary successful!")

    # Train / test the Bert_Bilstm_Crf model
    trainer_bbc = Bert_Bilstm_Crf(config, device, use_gpu, n_gpu, writer, id2label)
    # trainer_bbc.train() # training
    trainer_bbc.test()  # testing

if __name__ == '__main__':
    main()


Some weights of BERT_BiLSTM_CRF were not initialized from the model checkpoint at /content/drive/MyDrive/Colab Notebooks/ckpts/bert-base-chinese and are newly initialized: ['birnn.bias_hh_l0_reverse', 'hidden2tag.weight', 'birnn.bias_hh_l0', 'crf.end_transitions', 'birnn.weight_hh_l0_reverse', 'birnn.weight_ih_l0_reverse', 'birnn.bias_ih_l0_reverse', 'crf.start_transitions', 'crf.transitions', 'birnn.weight_ih_l0', 'hidden2tag.bias', 'birnn.bias_ih_l0', 'birnn.weight_hh_l0']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  score = torch.where(mask[i].unsqueeze(1), next_score, score)
Train_DataLoader: 100%|██████████| 3167/3167 [24:42<00:00,  2.14it/s]
Train_DataLoader: 100%|██████████| 3167/3167 [24:55<00:00,  2.12it/s]
Test_DataLoader: 100%|██████████| 290/290 [00:43<00:00,  6.65it/s]


# 预测

In [1]:
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report

# Score the token/gold/prediction file written by the test run and print a report.
txt_path = '/content/drive/MyDrive/Colab Notebooks/output/bert_bilstm_crf/token_labels_test.txt'
preds = []
labels = []
texts = []
error = 0  # rows whose predicted label is not one of the entity/O tags
with open(txt_path, "r", encoding="utf-8") as f:
    lines = f.readlines()
    for line in lines:
        t = line.split()
        # Only rows with exactly three whitespace-separated fields are used.
        if len(t) != 3:
            continue
        token, gold, pred = t
        if pred not in ['I-PER', 'I-ORG', 'I-LOC', 'B-LOC', 'B-PER', 'O', 'B-ORG']:
            error += 1
            continue
        texts.append(token)
        labels.append(gold)
        preds.append(pred)

# Macro-averaged scores over the tags, plus the per-tag report
precision = precision_score(labels, preds, average='macro')
recall = recall_score(labels, preds, average='macro')
f1 = f1_score(labels, preds, average='macro')
report = classification_report(labels, preds)

print()
print(report)
print()
print('precision: ', precision)
print('recall: ', recall)
print('f1_score: ', f1)
print('error: ', error)
print()
# Show the first 16 tokens with their gold and predicted labels
print('原文：', texts[:16])
print('标签：', labels[:16])
print('预测：', preds[:16])


              precision    recall  f1-score   support

       B-LOC       0.97      0.94      0.95      2871
       B-ORG       0.92      0.93      0.92      1327
       B-PER       0.97      0.97      0.97      1972
       I-LOC       0.95      0.93      0.94      4370
       I-ORG       0.93      0.96      0.95      5640
       I-PER       0.98      0.98      0.98      3844
           O       1.00      1.00      1.00    150935

    accuracy                           0.99    170959
   macro avg       0.96      0.96      0.96    170959
weighted avg       0.99      0.99      0.99    170959


precision:  0.9584684531096469
recall:  0.9576453856701352
f1_score:  0.9579837850878423
error:  1

原文： ['中', '共', '中', '央', '致', '中', '国', '致', '公', '党', '十', '一', '大', '的', '贺', '词']
标签： ['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'O', 'O']
预测： ['B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', 'O', 'B-ORG', 'I-ORG', 'I-ORG', 'I-ORG', '