In [4]:
%%writefile logger.py
import time
import sys
import logging
import numpy as np


def get_logger(filename):
    """Return a logger instance that writes in filename
    Args:
        filename: (string) path to log.txt
    Returns:
        logger: (instance of logger)
    """
    # 实例化一 logger 对象, 若为 .getLogger(), 则返回 rootLogger, 为所有 logger 实例的父 logger
    logger = logging.getLogger('logger')
    # 为 logger 实例设定输出级别
    logger.setLevel(logging.DEBUG)
    # 将 logger 内容输出到文件
    file_handler = logging.FileHandler(filename)
    stream_hander = logging.StreamHandler()
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    logger.addHandler(file_handler)
    logger.addHandler(stream_hander)
    return logger


class Progbar(object):
    """Progbar class copied from keras (https://github.com/fchollet/keras/)
    Displays a progress bar.
    Small edit : added strict arg to update
    Arguments
        target: Total number of steps expected.
        interval: Minimum visual progress update interval (in seconds).
    """
    def __init__(self, target, width=30, verbose=1):
        self.width = width
        self.target = target
        self.sum_values = {}
        self.unique_values = []
        self.start = time.time()
        self.total_width = 0
        self.seen_so_far = 0
        self.verbose = verbose

    def update(self, current, values=None, exact=None, strict=None):
        """Updates the progress bar.
        Arguments
            current: Index of current step.
            values: List of tuples (name, value_for_last_step).
                The progress bar will display averages for these values.
            exact: List of tuples (name, value_for_last_step).
                The progress bar will display these values directly.
        """
        if strict is None:
            strict = []
        if exact is None:
            exact = []
        if values is None:
            values = []
        for k, v in values:
            if k not in self.sum_values:
                self.sum_values[k] = [v * (current - self.seen_so_far), current - self.seen_so_far]
                self.unique_values.append(k)
            else:
                self.sum_values[k][0] += v * (current - self.seen_so_far)
                self.sum_values[k][1] += (current - self.seen_so_far)
        for k, v in exact:
            if k not in self.sum_values:
                self.unique_values.append(k)
            self.sum_values[k] = [v, 1]

        for k, v in strict:
            if k not in self.sum_values:
                self.unique_values.append(k)
            self.sum_values[k] = v

        self.seen_so_far = current

        now = time.time()
        if self.verbose == 1:
            prev_total_width = self.total_width
            sys.stdout.write("\b" * prev_total_width)
            sys.stdout.write("\r")
            numdigits = int(np.floor(np.log10(self.target))) + 1
            barstr = '%%%dd/%%%dd [' % (numdigits, numdigits)
            bar = barstr % (current, self.target)
            prog = float(current)/self.target
            prog_width = int(self.width*prog)
            if prog_width > 0:
                bar += ('='*(prog_width-1))
                if current < self.target:
                    bar += '>'
                else:
                    bar += '='
            bar += ('.'*(self.width-prog_width))
            bar += ']'
            sys.stdout.write(bar)
            self.total_width = len(bar)
            if current:
                time_per_unit = (now - self.start) / current
            else:
                time_per_unit = 0
            eta = time_per_unit*(self.target - current)
            info = ''
            if current < self.target:
                info += ' - ETA: %ds' % eta
            else:
                info += ' - %ds' % (now - self.start)
            for k in self.unique_values:
                if type(self.sum_values[k]) is list:
                    info += ' - %s: %.5f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1]))
                else:
                    info += ' - %s: %.5f' % (k, self.sum_values[k])
            self.total_width += len(info)
            if prev_total_width > self.total_width:
                info += ((prev_total_width-self.total_width) * ' ')
            sys.stdout.write(info)
            sys.stdout.flush()
            if current >= self.target:
                sys.stdout.write("\n")
        if self.verbose == 2:
            if current >= self.target:
                info = '%ds' % (now - self.start)
                for k in self.unique_values:
                    info += ' - %s: %.4f' % (k, self.sum_values[k][0] / max(1, self.sum_values[k][1]))
                sys.stdout.write(info + "\n")

    def add(self, n, values=None):
        if values is None:
            values = []
        self.update(self.seen_so_far+n, values)

Writing logger.py


In [5]:
import json
import os
import torch
import transformers as tfs
import random
from torch import nn
from torch import optim
from logger import Progbar

In [15]:
if BASE_MODEL_PATH == "":
    curdir = os.path.dirname(os.path.abspath(__file__))
    model_path = os.path.join(curdir, "model")
    if not os.path.exists(model_path):
        os.mkdir(model_path)
else:
    model_path = BASE_MODEL_PATH

In [16]:
model_path

'model'

In [24]:
BASE_DATASET_PATH = 'dataset'#数据文件
BASE_MODEL_PATH = 'model'#微调文件
PRETRAINED_BERT_ENCODER_PATH = 'pretrain'#预训练文件
dataset_path = BASE_DATASET_PATH
mdoel_path = 'BASE_MODEL_PATH'
#bert微调模型
FINETUNED_BERT_ENCODER_PATH = os.path.join(model_path,"finetuned_bert.bin")
POSITIVE_TRAIN_FILE_PATH = os.path.join(dataset_path, "postive_train.json")
POSITIVE_TRAIN_INFO_PATH = os.path.join(dataset_path, "positive_info.json")
UNLABELED_TRAIN_FILE_PATH = os.path.join(dataset_path,"unlabeled_train.json")
BERT_MODEL_SAVE_PATH = model_path
BATCH_SIZE=2
EPOCH=1
#获取数据集的标签及其大小
def get_label_set_and_sample_num(config_path, sample_num=False):
    with open(config_path, "r", encoding="UTF-8") as input_file:
        json_data = json.loads(input_file.readline())
        if sample_num:
            return json_data["label_list"], json_data["total_num"]
        else:
            return json_data["label_list"]

# 获取一个epoch需要的batch数
def get_steps_per_epoch(line_count, batch_size):
    return line_count // batch_size if line_count % batch_size == 0 else line_count // batch_size + 1


# 定义输入到Bert中的文本的格式,即标题,正文的组织形式
def prepare_sequence(title: str, body: str):
    return (title, body[:256] + "|" + body[-256:])


# 迭代器: 逐条读取数据并输出文本和标签
def get_text_and_label_index_iterator(input_path):
    with open(input_path, 'r', encoding="gbk") as input_file:
        for line in input_file:
            json_data = json.loads(line)
            text = prepare_sequence(json_data["title"], json_data["body"])
            label = json_data['label']

            yield text, label


# 迭代器: 生成一个batch的数据
def get_bert_iterator_batch(data_path, batch_size=32):
    keras_bert_iter = get_text_and_label_index_iterator(data_path)
    continue_iterator = True
    while continue_iterator:
        data_list = []
        for _ in range(batch_size):
            try:
                data = next(keras_bert_iter)
                data_list.append(data)
            except StopIteration:
                continue_iterator = False
        random.shuffle(data_list)

        text_list = []
        label_list = []

        for data in data_list:
            text, label = data
            text_list.append(text)
            label_list.append(label)

        yield text_list, label_list

    return False


class BertClassificationModel(nn.Module):
    """Bert分类器模型"""
    def __init__(self, model_path, predicted_size, hidden_size=768):
        super(BertClassificationModel, self).__init__()
        model_class, tokenizer_class = tfs.BertModel, tfs.BertTokenizer
        self.tokenizer = tokenizer_class.from_pretrained(model_path)
        self.bert = model_class.from_pretrained(model_path)
        self.linear = nn.Linear(hidden_size, predicted_size)
        self.dropout = nn.Dropout(p=0.2)

    def forward(self, batch_sentences):
        batch_tokenized = self.tokenizer.batch_encode_plus(batch_sentences, add_special_tokens=True,
                                                           max_length=512,pad_to_max_length=True)

        input_ids = torch.tensor(batch_tokenized['input_ids'])#.cuda()
        token_type_ids = torch.tensor(batch_tokenized['token_type_ids'])#.cuda()
        attention_mask = torch.tensor(batch_tokenized['attention_mask'])#.cuda()

        bert_output = self.bert(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask)
        bert_cls_hidden_state = bert_output[0][:, 0, :]
        linear_output = self.dropout(self.linear(bert_cls_hidden_state))#.cuda()
        return linear_output

def train_bert():
    if os.path.exists(POSITIVE_TRAIN_INFO_PATH):
        labels_set, total_num = get_label_set_and_sample_num(POSITIVE_TRAIN_INFO_PATH, True)
    else:
        print("Found no positive_info.json, please rerun the Preprocess.py.")
        exit()

    #torch.cuda.set_device(0)

    print("Start training model...")
    # train the model
    steps = get_steps_per_epoch(total_num, BATCH_SIZE)

    bert_classifier_model = BertClassificationModel(PRETRAINED_BERT_ENCODER_PATH, len(labels_set))
    #bert_classifier_model = bert_classifier_model.cuda()

    # 不同子网络设定不同的学习率
    Bert_model_param = []
    Bert_downstream_param = []
    number = 0
    for items, _ in bert_classifier_model.named_parameters():
        if "bert" in items:
            Bert_model_param.append(_)
        else:
            Bert_downstream_param.append(_)
        number += _.numel()
    param_groups = [{"params": Bert_model_param, "lr": 1e-5},
                    {"params": Bert_downstream_param, "lr": 1e-4}]
    optimizer = optim.Adam(param_groups, eps=1e-7, weight_decay=0.001)
    StepLR = torch.optim.lr_scheduler.StepLR(optimizer, step_size=steps, gamma=0.6)
    criterion = nn.CrossEntropyLoss()
    bert_classifier_model.train()
    progbar = Progbar(target=steps)

    for epoch in range(EPOCH):
        model_save_path = os.path.join(BERT_MODEL_SAVE_PATH, "model_epoch{}.pkl".format(epoch))

        dataset_iterator = get_bert_iterator_batch(POSITIVE_TRAIN_FILE_PATH, BATCH_SIZE)

        for i, iteration in enumerate(dataset_iterator):
            # 清空梯度
            bert_classifier_model.zero_grad()
            text = iteration[0]
            labels = torch.tensor(iteration[1])#.cuda()
            optimizer.zero_grad()
            output = bert_classifier_model(text)
            loss = criterion(output, labels)#.cuda()
            loss.backward()

            # 更新模型参数
            optimizer.step()
            # 学习率优化器计数
            StepLR.step()
            progbar.update(i + 1, None, None, [("train loss", loss.item()), ("bert_lr", optimizer.state_dict()["param_groups"][0]["lr"]), ("fc_lr", optimizer.state_dict()["param_groups"][1]["lr"])])

            if i == steps - 1:
                break

        # 保存完整的 BERT 分类器模型
        torch.save(bert_classifier_model, model_save_path)
        # 单独保存经 fune tune 的 BertEncoder模型
        torch.save(bert_classifier_model.bert, FINETUNED_BERT_ENCODER_PATH)
        print("epoch {} is over!\n".format(epoch))

    print("\nTraining is over!\n")  

In [25]:
train_bert()

Start training model...


Some weights of the model checkpoint at pretrain were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Truncation was not explicitly activated but `max_length` is provided a specific value, please use `truncation=True` to explicitly truncate examples to max length. D

 8242/38227 [=====>........................] - ETA: 338935s - train loss: 0.97607 - bert_lr: 0.00001 - fc_lr: 0.00010

KeyboardInterrupt: 