In [None]:
#### !/usr/bin/env python
# coding: utf-8

import os
from os.path import join, exists
import collections
import codecs
import sys
import json
import random
from typing import Dict, List
from multiprocessing import Process, Pool
from collections import OrderedDict
import logging
from datetime import datetime
from tqdm import tqdm, trange
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter

import argparse
import numpy as np
import torch
import torch.nn as nn
import transformers
from transformers import (BertConfig, BertTokenizer, 
                          BertModel, BertPreTrainedModel,
                          DistilBertConfig, DistilBertTokenizer,
                          GPT2Config, GPT2LMHeadModel,
                          AdamW, get_linear_schedule_with_warmup)

from modeling_distilbert import (DistilBertModel, 
                                 DistilBertPreTrainedModel)

from uer.optimizers import BertAdam
from brain import KnowledgeGraph
from utils import (set_seed, create_logger, save_model, 
                    calculate_loss_and_accuracy)
from constants import * 

# arguments setting
def load_hyperparam(args):
    """Copy encoder hyper-parameters from the JSON config file onto ``args``.

    The file at ``args.enc_config_path`` is read as UTF-8 JSON. Every
    hyper-parameter listed below is then written onto ``args`` as an
    attribute, using the value from the file when the key is present and the
    listed fallback otherwise. An attribute of the same name that already
    exists on ``args`` is overwritten.

    Args:
        args: Settings object exposing ``enc_config_path`` and accepting
            attribute assignment.

    Returns:
        The same ``args`` object, updated in place.
    """
    # (hyper-parameter name, fallback used when the config file omits it)
    fallback_values = (
        ("emb_size", 768),
        ("hidden_size", 768),
        ("kernel_size", 3),
        ("block_size", 2),
        ("feedforward_size", None),
        ("heads_num", None),
        ("layers_num", 12),
        ("dropout", 0.1),
    )

    with codecs.open(args.enc_config_path, "r", "utf-8") as config_file:
        config = json.load(config_file)

    for key, fallback in fallback_values:
        setattr(args, key, config.get(key, fallback))

    return args

# Run configuration: model locations, data files, outputs and training knobs.
args = dict(
    # Encoder / decoder checkpoints and their config files.
    enc_model_path="./models/distilbert-chinese/",
    enc_config_path="./models/distilbert-chinese/config.json",
    dec_model_path="./models/gpt2_dialogue_model/",
    dec_config_path="./models/gpt2_dialogue_model/config.json",

    # Data files (the dev split is currently disabled).
    train_path="/input/datasets_K-BERT/STC-corpus/STC.json",
    # dev_path="/input/datasets_K-BERT/book_review/dev.tsv",
    test_path="/input/datasets_K-BERT/STC-corpus/STC_test.json",

    # Outputs, knowledge graph name, logging.
    model_output_path="./outputs/encdec_STC_CnDbpedia.bin",
    kg_name="CnDbpedia",
    log_path="/output/enc_dec_log.txt",
    tb_writer_dir="/output",

    # Optimisation settings.
    batch_size=32,  # 32, 64, 128
    seq_length=256,
    learning_rate=2e-5,  # 2e-5, 5e-5
    warmup=0.1,
    dropout=0.5,
    epochs_num=2,  # 5, 10, 20
    log_step=10,  # report the loss every this many steps
    max_grad_norm=1.0,  # gradient clipping threshold
    gradient_accumulation=2,  # one optimizer step per n backward passes / batches
    seed=7,
    mean_reciprocal_rank=False,  # True for DBQA dataset
    workers_num=1,  # number of processes for loading the dataset; depends on CPU/thread count
    no_vm=False,  # Disable the visible_matrix
)

class Args(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``obj.key`` is equivalent to ``obj["key"]`` and ``obj.key = value`` stores
    ``value`` under ``"key"``. The conversion is shallow: nested dictionaries
    are returned as plain ``dict`` objects.

    Reading a missing attribute raises ``AttributeError`` (not ``KeyError``),
    so ``hasattr``, ``getattr(obj, name, default)`` and the protocol lookups
    done by ``copy`` / ``pickle`` behave as they do for ordinary objects.
    """

    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only invoked when normal attribute lookup fails; fall back to the
        # mapping and translate a missing key into the exception type the
        # attribute protocol expects.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
args = Args(args)  # wrap the plain dict so settings can be read as attributes (args.seed, ...)
args = load_hyperparam(args) # Load the hyperparameters from the config file.

# basic setting
# The logger is built from the settings object and the run is seeded with args.seed.
logger = create_logger(args)
set_seed(args.seed)

In [None]:
# Build the knowledge graph. A kg_name of 'none' means no triple files are
# loaded; otherwise the configured name is the single entry of the list.
spo_files = [] if args.kg_name == 'none' else [args.kg_name]
kg = KnowledgeGraph(spo_files=spo_files, predicate=True)

In [None]:
# build model
## whole model
class EncDecModel(transformers.PreTrainedModel):
    """Couples an encoder with a decoder through the first sequence position.

    The encoder output at position 0 replaces the decoder's input embedding at
    position 0 (the original author notes this is the ``[CLS]`` slot); all other
    decoder inputs are the decoder's own token embeddings of ``label_ids``.

    Args:
        config: Configuration passed to ``PreTrainedModel.__init__``.
        args: Run settings. Accepted for interface compatibility; not read here.
        enc_model: Encoder module, stored as ``self.enc``.
        dec_model: Decoder module, stored as ``self.dec``. It must expose
            ``transformer.wte`` and accept ``inputs_embeds``.
    """

    def __init__(self, config, args, enc_model, dec_model):
        super().__init__(config)
        self.enc = enc_model
        self.dec = dec_model

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=True,
        label_ids=None
    ):
        """Encode the source, inject its first state into the decoder input, decode.

        All arguments except ``label_ids`` are forwarded to the encoder
        unchanged. ``label_ids`` holds the decoder token ids and is required.

        Returns:
            Whatever ``self.dec`` returns for the modified input embeddings.

        Raises:
            ValueError: If ``label_ids`` is not provided.
        """
        if label_ids is None:
            raise ValueError("label_ids is required to build the decoder input embeddings")

        enc_output = self.enc(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            # Forward the caller's head mask (it used to be replaced by None).
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )
        enc_hidden_states = enc_output[0]
        label_embeds = self.dec.transformer.wte(label_ids)  # token ids -> embeddings
        # Overwrite only position 0 of every sequence with the encoder state at
        # position 0. Assumes both models share the hidden size - TODO confirm.
        label_embeds[:, 0] = enc_hidden_states[:, 0]
        outputs = self.dec(inputs_embeds=label_embeds)

        return outputs

In [None]:
## encoder: DistilBERT
enc_model_config = DistilBertConfig.from_pretrained(args.enc_model_path)
enc_model_token = BertTokenizer.from_pretrained(args.enc_model_path)
enc_model = DistilBertModel(config=enc_model_config)
# Record the maximum sentence length (256) on the config.
# NOTE(review): this is set after the model was constructed - confirm it has the intended effect.
enc_model.config.max_position_embeddings = args.seq_length

# Load the fine-tuned encoder weights. map_location="cpu" lets a checkpoint
# saved on a GPU load on a CPU-only host; the model is moved to the target
# device later.
module_dict = torch.load("./outputs/distilbert_book_review_CnDbpedia.bin", map_location="cpu")
new_module_dict = OrderedDict()
# Work around mismatched parameter names: keys that start with 'origin' lose
# their first 7 characters (the prefix plus its separator).
for layer in module_dict.keys():
    if layer[:6] == 'origin':
        new_module_dict[layer[7:]] = module_dict[layer]
    else:
        new_module_dict[layer] = module_dict[layer]
# strict=False tolerates name mismatches, so report what was actually skipped
# instead of silently running with partially initialised weights.
load_result = enc_model.load_state_dict(new_module_dict, strict=False)
if load_result.missing_keys:
    logger.warning("encoder weights missing from checkpoint: {}".format(load_result.missing_keys))
if load_result.unexpected_keys:
    logger.warning("checkpoint entries not used by encoder: {}".format(load_result.unexpected_keys))

## decoder model
dec_model_token = BertTokenizer.from_pretrained(args.dec_model_path)
dec_model = GPT2LMHeadModel.from_pretrained(args.dec_model_path)
model = EncDecModel(config=enc_model_config, args=args,enc_model=enc_model, dec_model=dec_model)

In [None]:
def add_knowledge_worker(params):
    """Turn (source, reply) sentence pairs into model-ready training examples.

    Args:
        params: Tuple ``(p_id, sentences, kg, enc_tokenizer, dec_tokenizer,
            max_length)``. ``p_id`` is only used in progress messages,
            ``sentences`` is a sequence of 2-item pairs, ``kg`` provides
            ``add_knowledge_with_vm`` and both tokenizers provide
            ``convert_tokens_to_ids`` (the decoder one also ``tokenize``).

    Returns:
        A list of ``(token_ids, text_dec, seg, pos, vm)`` tuples, one per pair
        that was processed successfully. ``text_dec`` is padded with ``PAD_ID``
        or truncated to exactly ``max_length`` ids. Pairs that raise during
        processing are reported on stdout and skipped. An empty ``sentences``
        yields an empty list.
    """
    p_id, sentences, kg, enc_tokenizer, dec_tokenizer, max_length = params
    sentences_num = len(sentences)
    dataset = []
    if sentences_num == 0:
        # Nothing to do; also avoids indexing sentences[0] below.
        return dataset
    assert len(sentences[0]) == 2
    for line_id, line in enumerate(sentences):
        if line_id % 10000 == 0:
            print("Progress of process {}: {}/{}".format(p_id, line_id, sentences_num))
            sys.stdout.flush()
        try:
            text_a, text_b = line[0], line[1]

            # Encoder side: prepend the CLS marker and let the knowledge graph
            # produce tokens, positions and the visible matrix.
            text = CLS_TOKEN + text_a
            tokens, pos, vm, _ = kg.add_knowledge_with_vm([text], add_pad=True, max_length=max_length)
            tokens = tokens[0]
            pos = pos[0]
            vm = vm[0]
            token_ids = enc_tokenizer.convert_tokens_to_ids(tokens)

            # Decoder side: tokenize the reply, convert to ids, then pad or
            # truncate to max_length.
            text_dec = CLS_TOKEN + text_b
            text_dec = dec_tokenizer.tokenize(text_dec)  # str -> list of tokens
            text_dec = dec_tokenizer.convert_tokens_to_ids(text_dec)  # list of tokens -> list of ids
            pad_num = max_length - len(text_dec)
            if pad_num > 0:
                text_dec.extend([PAD_ID] * pad_num)
            else:
                text_dec = text_dec[:max_length]

            # Segment ids: the counter increases after each SEP token.
            seg = []
            seg_tag = 0
            for t in tokens:
                seg.append(seg_tag)
                if t == SEP_TOKEN:
                    seg_tag += 1

            dataset.append((token_ids, text_dec, seg, pos, vm))
        except Exception as err:
            # Best effort: skip the bad pair but say why. Catching Exception
            # (not a bare except) lets KeyboardInterrupt/SystemExit propagate.
            print(f"Error line: {line_id}, {line} ({type(err).__name__}: {err})")
    return dataset

# Dataset loader.
def batch_loader(batch_size, input_ids, text_dec, seg_ids, pos_ids, vms):
    """Yield consecutive mini-batches sliced from the parallel inputs.

    ``input_ids``, ``seg_ids`` and ``pos_ids`` are sliced along their first
    axis with all columns kept; ``text_dec`` and ``vms`` are sliced along
    their first axis only. Every batch holds ``batch_size`` rows except
    possibly the last one, which holds the remaining rows when the total is
    not a multiple of ``batch_size``.

    Yields:
        Tuples ``(input_ids_batch, text_dec_batch, seg_ids_batch,
        pos_ids_batch, vms_batch)``.
    """
    total = input_ids.size()[0]
    full_batches = total // batch_size

    # (start, stop) row bounds of every batch; stop=None means "to the end".
    bounds = [(k * batch_size, (k + 1) * batch_size) for k in range(full_batches)]
    covered = full_batches * batch_size
    if total > covered:
        bounds.append((covered, None))

    for start, stop in bounds:
        rows = slice(start, stop)
        yield (
            input_ids[rows, :],
            text_dec[rows],
            seg_ids[rows, :],
            pos_ids[rows, :],
            vms[rows],
        )

## read dataset
# NOTE(review): the examples used for training below are read from
# args.test_path (JSON key 'test'); args.train_path is not used in this
# section - confirm this is intentional.
sentences = []
with open(args.test_path, mode='r', encoding="utf-8") as f:
    data = json.load(f)
    sentences = data['test']

# Knowledge injection runs in this process (worker id 0), hence the literal 1.
print("There are {} sentence in total. We use {} processes to inject knowledge into sentences.".format(len(sentences), 1))
dataset = add_knowledge_worker((0, sentences, kg, enc_model_token, dec_model_token, args.seq_length))

## Post-process the dataset
random.shuffle(dataset)

# Split the example tuples (token_ids, text_dec, seg, pos, vm) into one
# container per field; each step is logged before it runs.
logger.info("Trans data to tensor.")
logger.info("input_ids")
input_ids = torch.LongTensor([example[0] for example in dataset])
logger.info("text_dec")
text_dec = torch.LongTensor([example[1] for example in dataset])
logger.info("seg_ids")
seg_ids = torch.LongTensor([example[2] for example in dataset])
logger.info("pos_ids")
pos_ids = torch.LongTensor([example[3] for example in dataset])
logger.info("vms")
# Visible matrices stay a Python list here; they are converted per batch in the training loop.
vms = [example[4] for example in dataset]

instances_num = len(dataset)
# NOTE(review): train_steps is not referenced again in the visible code.
train_steps = int(instances_num * args.epochs_num / args.batch_size) + 1

In [None]:
## Device placement / data parallelism
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    # Use the configured logger so the message lands in the run log.
    logger.info("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
    model = nn.DataParallel(model)  # replicate the model across GPUs

model = model.to(device)

## Record the number of model parameters
num_parameters = sum(parameter.numel() for parameter in model.parameters())
logger.info('number of model parameters: {}'.format(num_parameters))

## Build the optimizer and the learning-rate scheduler
# t_total is the number of optimizer steps actually taken: the training loop
# steps once every `gradient_accumulation` batches, and an epoch has
# ceil(instances / batch_size) batches. Counting instances instead of batches
# would stretch the schedule by a factor of batch_size, so warm-up would never
# finish.
batches_per_epoch = (len(dataset) + args.batch_size - 1) // args.batch_size
t_total = batches_per_epoch // args.gradient_accumulation * args.epochs_num
# NOTE(review): warm-up covers half of all steps; args.warmup is not used here - confirm.
warmup_steps = t_total // 2

optimizer = AdamW(model.parameters(), lr=args.learning_rate)
# Linear warm-up followed by linear decay.
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=warmup_steps, num_training_steps=t_total)

## Create the output directory for the trained model (including missing parents)
os.makedirs(args.model_output_path, exist_ok=True)

## Start training
logger.info("***** Running training *****")
logger.info(f"  Num Epochs = {args.epochs_num}")
logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {args.batch_size * args.gradient_accumulation}")
logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation}")
logger.info(f"  Total optimization steps = {t_total}")
logger.info(f"  Batch size: {args.batch_size}")
logger.info(f"  The number of training instances: {instances_num}")
logger.info('starting training')
# Loss accumulated over the current gradient-accumulation window
running_loss = 0
# Total number of optimizer steps taken so far
overall_step = 0
# TensorBoard writer
tb_writer = SummaryWriter(log_dir=args.tb_writer_dir)
# Number of out-of-memory events
oom_time = 0

In [None]:
# Training loop
for epoch in range(args.epochs_num):
    epoch_start_time = datetime.now()
    model.train()
    # Start every epoch with a clean accumulation window: drop gradients and
    # loss left over from an incomplete window at the end of the previous epoch.
    optimizer.zero_grad()
    running_loss = 0
    for batch_idx, (input_ids_batch, text_dec_batch, seg_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(args.batch_size, input_ids, text_dec, seg_ids, pos_ids, vms)):
        # Note (original author): GPT-2's forward() predicts one next token for
        # a given context rather than a whole sequence. With n input token ids
        # it returns n hidden states, and the n-th hidden state is used to
        # predict token n+1.
        vms_batch = torch.LongTensor(vms_batch)

        input_ids_batch = input_ids_batch.to(device)
        text_dec_batch = text_dec_batch.to(device)
        seg_ids_batch = seg_ids_batch.to(device)
        pos_ids_batch = pos_ids_batch.to(device)
        vms_batch = vms_batch.to(device)

        outputs = model(
            input_ids=input_ids_batch,
            attention_mask=vms_batch,
            position_ids=pos_ids_batch,
            token_type_ids=seg_ids_batch,
            label_ids=text_dec_batch
        )
        loss, accuracy = calculate_loss_and_accuracy(outputs, labels=text_dec_batch, device=device)
        loss = loss.mean()
        accuracy = accuracy.mean()
        if args.gradient_accumulation > 1:
            # Scale only the loss so the accumulated gradient is the window
            # average; accuracy is a metric and must not be scaled.
            loss = loss / args.gradient_accumulation
        # Gradients are NOT zeroed per batch, so consecutive backward passes
        # add up until the optimizer step below.
        loss.backward()
        running_loss += loss.item()
        # After the configured number of accumulated batches, update the parameters.
        if (batch_idx + 1) % args.gradient_accumulation == 0:
            # Clip once, on the fully accumulated gradient, right before the step.
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()  # gradient descent update
            scheduler.step()  # warm-up / learning-rate schedule
            optimizer.zero_grad()  # open a new accumulation window
            overall_step += 1
            # Update the log and TensorBoard every log_step optimizer steps.
            if overall_step % args.log_step == 0:
                logger.info(
                    "batch {} of epoch {}, loss {}, accuracy {}".format(batch_idx + 1, epoch + 1, running_loss,
                                                                        accuracy.item()))
                tb_writer.add_scalar('loss', running_loss, overall_step)
            running_loss = 0
    logger.info('saving model for epoch {}'.format(epoch + 1))

    model_path = join(args.model_output_path, 'model_epoch{}'.format(epoch + 1))
    os.makedirs(model_path, exist_ok=True)
    # Unwrap DataParallel (if used) before saving.
    model_to_save = model.module if hasattr(model, 'module') else model
    model_to_save.save_pretrained(model_path)
    logger.info('epoch {} finished'.format(epoch + 1))
    epoch_finish_time = datetime.now()
    logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))
# Flush pending TensorBoard events to disk.
tb_writer.close()
logger.info('training finished')