In [None]:
# -*- encoding:utf-8 -*-
"""
  This script provides a K-BERT example for classification.
"""
import os
import collections
import codecs
import sys
import json
import random
from typing import Dict, List
from multiprocessing import Process, Pool

import argparse
import numpy as np
import torch
import torch.nn as nn
from transformers import (BertConfig, BertTokenizer, 
                          BertModel, BertPreTrainedModel,
                          DistilBertConfig, DistilBertTokenizer,
                          DistilBertModel, DistilBertPreTrainedModel)

from uer.optimizers import BertAdam
from brain import KnowledgeGraph
from utils import set_seed, load_hyperparam, save_model
from constants import * 

# task model
class BertClassifier(nn.Module):
    """Classification head on top of a BERT-style encoder.

    The encoder is called with the keyword arguments ``input_ids``,
    ``attention_mask``, ``token_type_ids``, ``position_ids`` and
    ``return_dict=True``; its ``pooler_output`` is projected to
    ``labels_num`` logits by a single linear layer.
    """

    def __init__(self, args, model):
        """
        Args:
            args: object exposing ``labels_num``, ``hidden_size`` and ``no_vm``.
            model: encoder module, stored as ``self.bert``.
        """
        super().__init__()
        self.labels_num = args.labels_num
        self.bert = model
        self.output_layer_2 = nn.Linear(args.hidden_size, args.labels_num)
        self.softmax = nn.LogSoftmax(dim=-1)
        self.criterion = nn.NLLLoss()
        # The visible matrix is used unless it was explicitly disabled.
        self.use_vm = not args.no_vm
        print("[BertClassifier] use visible_matrix: {}".format(self.use_vm))

    def forward(self, src, label, vm=None, seg=None, pos=None):
        """
        Args:
            src: [batch_size x seq_length] token ids
            label: [batch_size] gold labels
            vm: visible matrix passed to the encoder as ``attention_mask``;
                ignored (replaced by None) when ``use_vm`` is False
            seg: segment ids passed as ``token_type_ids``
            pos: position ids passed as ``position_ids``

        Returns:
            Tuple ``(loss, logits)``: the NLL loss over log-softmaxed logits
            and the raw logits.
        """
        # Encoder.
        visible_matrix = vm if self.use_vm else None
        encoded = self.bert(
            input_ids=src,
            attention_mask=visible_matrix,
            token_type_ids=seg,
            position_ids=pos,
            return_dict=True,
        )
        logits = self.output_layer_2(encoded['pooler_output'])
        log_probs = self.softmax(logits.view(-1, self.labels_num))
        loss = self.criterion(log_probs, label.view(-1))
        return loss, logits

def _encode_with_knowledge(text, kg, bert_token, max_length, count_segments):
    """Inject knowledge into one text and turn it into model inputs.

    Args:
        text: sentence string, already prefixed/suffixed with special tokens.
        kg: object providing ``add_knowledge_with_vm``.
        bert_token: tokenizer providing ``convert_tokens_to_ids``.
        max_length: forwarded to ``add_knowledge_with_vm`` as ``max_length``.
        count_segments: when True the segment id is incremented after every
            ``SEP_TOKEN``; when False every token gets segment id 0.

    Returns:
        Tuple ``(token_ids, seg, pos, vm)`` for the single input text.
    """
    tokens, pos, vm, _ = kg.add_knowledge_with_vm([text], add_pad=True, max_length=max_length)
    # The knowledge graph works on batches; we passed exactly one text.
    tokens, pos, vm = tokens[0], pos[0], vm[0]
    token_ids = bert_token.convert_tokens_to_ids(tokens)

    if not count_segments:
        return token_ids, [0] * len(tokens), pos, vm

    seg = []
    seg_tag = 0
    for token in tokens:
        seg.append(seg_tag)
        if token == SEP_TOKEN:
            seg_tag += 1
    return token_ids, seg, pos, vm


def add_knowledge_worker(params):
    """Convert raw TSV lines into knowledge-enriched training instances.

    Args:
        params: tuple ``(p_id, sentences, columns, kg, bert_token, max_length)``
            where ``p_id`` is only used in progress messages, ``sentences`` is
            a list of tab-separated lines and ``columns`` maps column names
            (``label``, ``text_a``, ``text_b``, ``qid``) to column indices.

    Returns:
        List of tuples ``(token_ids, label, seg, pos, vm)``; lines with four
        columns additionally carry a trailing ``qid``. Lines with any other
        number of columns are skipped, and lines that fail to parse are
        reported and skipped.
    """
    p_id, sentences, columns, kg, bert_token, max_length = params
    sentences_num = len(sentences)
    dataset = []
    for line_id, line in enumerate(sentences):
        if line_id % 10000 == 0:
            print("Progress of process {}: {}/{}".format(p_id, line_id, sentences_num))
            sys.stdout.flush()
        line = line.strip().split('\t')
        try:
            label = int(line[columns["label"]])
            if len(line) == 2:
                # Single sentence: label + text_a.
                text = CLS_TOKEN + line[columns["text_a"]]
                token_ids, seg, pos, vm = _encode_with_knowledge(
                    text, kg, bert_token, max_length, count_segments=False)
                dataset.append((token_ids, label, seg, pos, vm))
            elif len(line) == 3:
                # Sentence pair: label + text_a + text_b.
                text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                text = CLS_TOKEN + text_a + SEP_TOKEN + text_b + SEP_TOKEN
                token_ids, seg, pos, vm = _encode_with_knowledge(
                    text, kg, bert_token, max_length, count_segments=True)
                dataset.append((token_ids, label, seg, pos, vm))
            elif len(line) == 4:  # for dbqa
                qid = int(line[columns["qid"]])
                text_a, text_b = line[columns["text_a"]], line[columns["text_b"]]
                text = CLS_TOKEN + text_a + SEP_TOKEN + text_b + SEP_TOKEN
                token_ids, seg, pos, vm = _encode_with_knowledge(
                    text, kg, bert_token, max_length, count_segments=True)
                dataset.append((token_ids, label, seg, pos, vm, qid))
        except Exception as exc:
            # Best effort: a malformed line is reported and skipped. Only
            # ``Exception`` is caught so that KeyboardInterrupt / SystemExit
            # still stop the worker instead of being silently swallowed.
            print("Error line: ", line, "({}: {})".format(type(exc).__name__, exc))
    return dataset

def set_seed(seed=7):
    """Seed the RNGs used by this notebook.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and the current CUDA
    device), exports ``PYTHONHASHSEED`` and sets
    ``torch.backends.cudnn.deterministic`` to True.

    Args:
        seed: value handed to every seeding function.
    """
    os.environ['PYTHONHASHSEED'] = "{}".format(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(seed)
    torch.backends.cudnn.deterministic = True
    
def load_hyperparam(args):
    """Copy model hyperparameters from the JSON config file onto ``args``.

    Args:
        args: object with a ``config_path`` attribute pointing to a UTF-8
            JSON file; the hyperparameters are assigned to it as attributes.

    Returns:
        The same ``args`` object. Each listed hyperparameter is overwritten
        with the config value, or with the default below when the key is absent.
    """
    # (name, default used when the key is absent from the config file)
    defaults = (
        ("emb_size", 768),
        ("hidden_size", 768),
        ("kernel_size", 3),
        ("block_size", 2),
        ("feedforward_size", None),
        ("heads_num", None),
        ("layers_num", 12),
        ("dropout", 0.1),
    )
    with codecs.open(args.config_path, "r", "utf-8") as config_file:
        config = json.load(config_file)

    for name, fallback in defaults:
        setattr(args, name, config.get(name, fallback))

    return args

def save_model(model, model_path):
    """Write the model's ``state_dict`` to ``model_path`` with ``torch.save``.

    If ``model`` exposes a ``module`` attribute, the wrapped module's state is
    saved instead, so the stored keys do not carry the "module" prefix.
    """
    unwrapped = getattr(model, "module", model)
    torch.save(unwrapped.state_dict(), model_path)

# Run the code

In [None]:
# Run configuration (plain dict; later cells wrap it for attribute access).
args = dict(
    model_path="./models/bert_origin/",
    config_path="./models/bert_origin/config.json",

    # Alternative dataset (NLPCC-DBQA); swap with the four entries below to use it.
    #     train_path="/input/datasets_K-BERT/nlpcc-dbqa/train.tsv",
    #     dev_path="/input/datasets_K-BERT/nlpcc-dbqa/dev.tsv",
    #     test_path="/input/datasets_K-BERT/nlpcc-dbqa/test.tsv",
    #     output_model_path="./outputs/kbert_nlpcc-dbqa_CnDbpedia.bin",
    train_path="/input/datasets_K-BERT/book_review/train.tsv",
    dev_path="/input/datasets_K-BERT/book_review/dev.tsv",
    test_path="/input/datasets_K-BERT/book_review/test.tsv",
    output_model_path="./outputs/kbert_book_review_CnDbpedia.bin",

    kg_name="CnDbpedia",

    batch_size=128,  # 32, 64, 128
    seq_length=256,
    learning_rate=2e-5,
    warmup=0.1,
    dropout=0.5,
    epochs_num=10,  # 5, 10, 20
    report_steps=100,
    seed=7,
    mean_reciprocal_rank=False,  # True for DBQA dataset
    workers_num=1,  # number of processes for loading the dataset; depends on the number of CPUs and threads
    no_vm=False,  # Disable the visible_matrix
)

class Args(dict):
    """Dictionary whose items can also be read and written as attributes.

    ``args.name`` is equivalent to ``args["name"]`` for both reading and
    assignment. Nested dictionaries are not converted.
    """
    __setattr__ = dict.__setitem__

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails. A missing key must
        # surface as AttributeError (not KeyError) so that ``hasattr`` and
        # ``getattr(obj, name, default)`` behave as the attribute protocol
        # requires.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None
args = Args(args)

# Load the hyperparameters from the config file.
args = load_hyperparam(args)

set_seed(args.seed)

# Read the header of the training file to map column names to indices, then
# count the distinct integer labels found in the remaining rows.
labels_set = set()
columns = {}
with open(args.train_path, mode="r", encoding="utf-8") as f:
    for line_id, line in enumerate(f):
        line = line.strip().split("\t")
        if line_id == 0:
            for i, column_name in enumerate(line):
                columns[column_name] = i
            if "label" not in columns:
                # Without this check every row would fail silently and
                # labels_num would end up as 0.
                raise ValueError(
                    "No 'label' column in the header of {}: {}".format(args.train_path, line))
            continue
        try:
            label = int(line[columns["label"]])
        except (ValueError, IndexError):
            # Best effort: rows with a missing or non-integer label are skipped.
            continue
        labels_set.add(label)
args.labels_num = len(labels_set)

# Build knowledge graph.
if args.kg_name == 'none':
    spo_files = []
else:
    spo_files = [args.kg_name]
kg = KnowledgeGraph(spo_files=spo_files, predicate=True)

In [None]:
# model
# BertConfig, BertTokenizer and BertModel are imported in the notebook's first cell.
bert_config = BertConfig.from_pretrained(args.model_path)
bert_token = BertTokenizer.from_pretrained(args.model_path)

# Resume from the checkpoint this run writes to (previously a hard-coded copy
# of the same path, which went stale as soon as the dataset was switched).
resume_path = args.output_model_path
if os.path.isfile(resume_path):
    bert_model = BertModel(config=bert_config)
    # NOTE(review): set after construction, so presumably this only changes the
    # config value and does not resize the position embeddings; confirm intent.
    bert_model.config.max_position_embeddings = args.seq_length  # maximum sentence length (256)
    model = BertClassifier(args, bert_model)
    # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
    # machine; the model is moved to the target device further down.
    state_dict = torch.load(resume_path, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=False)
    # strict=False ignores mismatches silently, so report them explicitly.
    if load_result.missing_keys or load_result.unexpected_keys:
        print("Checkpoint {} loaded with missing keys {} and unexpected keys {}".format(
            resume_path, load_result.missing_keys, load_result.unexpected_keys))
else:
    # No fine-tuned checkpoint yet (fresh run): start from the pretrained encoder.
    print("Checkpoint {} not found; initialising the encoder from {}".format(resume_path, args.model_path))
    bert_model = BertModel.from_pretrained(args.model_path, config=bert_config)
    bert_model.config.max_position_embeddings = args.seq_length  # maximum sentence length (256)
    model = BertClassifier(args, bert_model)

# For simplicity, we use DataParallel wrapper to use multiple GPUs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 1:
    print("{} GPUs are available. Let's use them.".format(torch.cuda.device_count()))
    model = nn.DataParallel(model)

model = model.to(device)

In [None]:
# Dataset loader.
def batch_loader(batch_size, input_ids, label_ids, seg_ids, pos_ids, vms):
    """Yield consecutive mini-batches over the five parallel inputs.

    Args:
        batch_size: number of instances per batch.
        input_ids, seg_ids, pos_ids: 2-D tensors sliced along the first axis.
        label_ids: sliced along its first axis.
        vms: sequence sliced with plain ``[start:stop]`` indexing.

    Yields:
        Tuples ``(input_ids, label_ids, seg_ids, pos_ids, vms)`` for each
        batch; a final shorter batch holds the leftover instances, if any.
    """
    total = input_ids.size()[0]
    full_batches = total // batch_size
    covered = full_batches * batch_size

    windows = [slice(k * batch_size, (k + 1) * batch_size) for k in range(full_batches)]
    if total > covered:
        # Leftover instances that do not fill a complete batch.
        windows.append(slice(covered, None))

    for window in windows:
        yield (
            input_ids[window, :],
            label_ids[window],
            seg_ids[window, :],
            pos_ids[window, :],
            vms[window],
        )

def read_dataset(path, workers_num=1):
    """Load a TSV file and convert its rows with ``add_knowledge_worker``.

    The first line of the file (the header) is skipped. Relies on the
    notebook-level ``columns``, ``kg``, ``bert_token`` and ``args``.

    Args:
        path: path of the UTF-8 encoded TSV file.
        workers_num: when greater than 1, the sentences are split into that
            many blocks and processed in a ``multiprocessing.Pool``.

    Returns:
        List of instances as produced by ``add_knowledge_worker``.
    """
    print("Loading sentences from {}".format(path))
    with open(path, mode='r', encoding="utf-8") as f:
        sentences = f.readlines()[1:]  # drop the header line
    sentence_num = len(sentences)

    print("There are {} sentence in total. We use {} processes to inject knowledge into sentences.".format(sentence_num, workers_num))
    if not workers_num > 1:
        # Single process: convert everything in place.
        return add_knowledge_worker((0, sentences, columns, kg, bert_token, args.seq_length))

    sentence_per_block = int(sentence_num / workers_num) + 1
    params = [
        (
            worker_id,
            sentences[worker_id * sentence_per_block: (worker_id + 1) * sentence_per_block],
            columns,
            kg,
            bert_token,
            args.seq_length,
        )
        for worker_id in range(workers_num)
    ]
    pool = Pool(workers_num)
    res = pool.map(add_knowledge_worker, params)
    pool.close()
    pool.join()
    # Flatten the per-worker blocks while keeping the original order.
    return [sample for block in res for sample in block]

# Evaluation function.
def _predict_probabilities(batch_size, input_ids, label_ids, seg_ids, pos_ids, vms):
    """Run the notebook-level ``model`` batch by batch without gradients.

    Yields:
        Tuples ``(probabilities, gold_labels)`` where ``probabilities`` is the
        softmax over the logits of one batch and ``gold_labels`` is the label
        batch already moved to ``device``.
    """
    for batch in batch_loader(batch_size, input_ids, label_ids, seg_ids, pos_ids, vms):
        input_ids_batch, label_ids_batch, seg_ids_batch, pos_ids_batch, vms_batch = batch

        vms_batch = torch.LongTensor(vms_batch)

        input_ids_batch = input_ids_batch.to(device)
        label_ids_batch = label_ids_batch.to(device)
        seg_ids_batch = seg_ids_batch.to(device)
        pos_ids_batch = pos_ids_batch.to(device)
        vms_batch = vms_batch.to(device)

        with torch.no_grad():
            try:
                _, logits = model(input_ids_batch, label_ids_batch, vms_batch, seg_ids_batch, pos_ids_batch)
            except Exception:
                # Print the shapes for debugging, then re-raise: continuing
                # would score this batch with undefined or stale logits.
                print("Forward pass failed; input_ids size: {}, visible matrix size: {}".format(
                    input_ids_batch.size(), vms_batch.size()))
                raise

        yield nn.Softmax(dim=1)(logits), label_ids_batch


def _group_by_question(qids, labels, scores):
    """Group consecutive rows that share a question id.

    Returns:
        List of ``(positive_positions, group_scores)``, one entry per run of
        equal consecutive qids; ``positive_positions`` holds the in-group
        indices of rows whose label is 1.
    """
    groups = []
    for row, (qid, label, score) in enumerate(zip(qids, labels, scores)):
        if row == 0 or qid != qids[row - 1]:
            groups.append(([], []))
        positive_positions, group_scores = groups[-1]
        if label == 1:
            positive_positions.append(len(group_scores))
        group_scores.append(score)
    return groups


def _mean_reciprocal_rank(groups):
    """Mean reciprocal rank of the best-ranked positive answer per question.

    Each question contributes exactly one value. A question without any
    positive answer contributes 0, and no questions at all give 0.0.
    """
    if not groups:
        return 0.0
    reciprocal_ranks = []
    for positive_positions, group_scores in groups:
        if not positive_positions:
            reciprocal_ranks.append(0.0)
            continue
        ranked = sorted(group_scores, reverse=True)
        # ``index`` returns the first (best) position of a score, so tied
        # scores yield one rank instead of several.
        best_rank = min(ranked.index(group_scores[p]) for p in positive_positions)
        reciprocal_ranks.append(1 / (best_rank + 1))
    return sum(reciprocal_ranks) / len(reciprocal_ranks)


def evaluate(args, is_test, metrics='Acc'):
    """Evaluate the notebook-level ``model`` on the dev or test set.

    Args:
        args: configuration providing ``test_path``/``dev_path``,
            ``workers_num``, ``batch_size``, ``labels_num`` and
            ``mean_reciprocal_rank``.
        is_test: evaluate on ``args.test_path`` when True, else on
            ``args.dev_path``; also enables the extra test-time printouts.
        metrics: ``'f1'`` returns the F1 of label 1; ``'Acc'`` and any other
            value return the accuracy. Ignored when
            ``args.mean_reciprocal_rank`` is set.

    Returns:
        Accuracy or label-1 F1 (classification mode), or the mean reciprocal
        rank when ``args.mean_reciprocal_rank`` is true.
    """
    dataset = read_dataset(args.test_path if is_test else args.dev_path, workers_num=args.workers_num)

    input_ids = torch.LongTensor([sample[0] for sample in dataset])
    label_ids = torch.LongTensor([sample[1] for sample in dataset])
    seg_ids = torch.LongTensor([sample[2] for sample in dataset])
    pos_ids = torch.LongTensor([sample[3] for sample in dataset])
    vms = [sample[4] for sample in dataset]  # one visible matrix per instance

    instances_num = input_ids.size()[0]
    if is_test:
        print("The number of evaluation instances: ", instances_num)

    model.eval()
    predictions = _predict_probabilities(args.batch_size, input_ids, label_ids, seg_ids, pos_ids, vms)

    if args.mean_reciprocal_rank:
        # Score of each candidate = probability of label 1.
        scores = []
        for probabilities, _ in predictions:
            scores.extend(probabilities[:, 1].tolist())

        # The question id is stored as the last field of every instance.
        groups = _group_by_question(
            [sample[-1] for sample in dataset],
            [sample[1] for sample in dataset],
            scores,
        )
        print(len(groups))
        print(sum(1 for positive_positions, _ in groups if positive_positions))
        MRR = _mean_reciprocal_rank(groups)
        print("MRR", MRR)
        return MRR

    correct = 0
    # Confusion matrix, indexed as [predicted label, gold label].
    confusion = torch.zeros(args.labels_num, args.labels_num, dtype=torch.long)
    for probabilities, gold in predictions:
        pred = torch.argmax(probabilities, dim=1)
        for j in range(pred.size()[0]):
            confusion[pred[j], gold[j]] += 1
        correct += torch.sum(pred == gold).item()

    if is_test:
        print("Confusion matrix:")
        print(confusion)
        print("Report precision, recall, and f1:")

    label_1_f1 = 0.0  # stays 0.0 when there is no label 1
    for i in range(confusion.size()[0]):
        true_positive = confusion[i, i].item()
        predicted_total = confusion[i, :].sum().item()
        gold_total = confusion[:, i].sum().item()
        # A label that is never predicted / never present gives 0 instead of
        # a ZeroDivisionError.
        p = true_positive / predicted_total if predicted_total else 0.0
        r = true_positive / gold_total if gold_total else 0.0
        f1 = 2 * p * r / (p + r) if (p + r) else 0.0
        if i == 1:
            label_1_f1 = f1
        print("Label {}: {:.3f}, {:.3f}, {:.3f}".format(i, p, r, f1))

    total = len(dataset)
    accuracy = correct / total if total else 0.0
    print("Acc. (Correct/Total): {:.4f} ({}/{}) ".format(accuracy, correct, total))
    if metrics == 'f1':
        return label_1_f1
    return accuracy

In [None]:
# Training phase.
print("Start training.")
trainset = read_dataset(args.train_path, workers_num=args.workers_num)
print("Shuffling dataset")
random.shuffle(trainset)
instances_num = len(trainset)
batch_size = args.batch_size

print("Trans data to tensor.")
print("input_ids")
input_ids = torch.LongTensor([example[0] for example in trainset])
print("label_ids")
label_ids = torch.LongTensor([example[1] for example in trainset])
print("seg_ids")
seg_ids = torch.LongTensor([example[2] for example in trainset])
print("pos_ids")
pos_ids = torch.LongTensor([example[3] for example in trainset])
print("vms")
vms = [example[4] for example in trainset]

# Total number of optimizer steps, used by the warmup schedule.
train_steps = int(instances_num * args.epochs_num / batch_size) + 1

print("Batch size: ", batch_size)
print("The number of training instances:", instances_num)

param_optimizer = list(model.named_parameters())
# Parameters excluded from weight decay. 'gamma'/'beta' are kept from the
# original list; Hugging Face modules name the LayerNorm scale
# 'LayerNorm.weight', so it must be listed too or it would be decayed.
no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.weight']
optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
]
optimizer = BertAdam(optimizer_grouped_parameters, lr=args.learning_rate, warmup=args.warmup, t_total=train_steps)

total_loss = 0.
result = 0.0
# best_result = 0.0
# NOTE(review): presumably the dev score of the checkpoint being resumed;
# use 0.0 when training from scratch. Confirm before reuse.
best_result = 0.8143

In [None]:
# Training phase 2
for epoch in range(1, args.epochs_num+1):
    model.train()
    # Start every epoch with an empty accumulator; otherwise the losses left
    # over after the previous epoch's last report inflate the next "Avg loss".
    total_loss = 0.
    for i, (input_ids_batch, label_ids_batch, seg_ids_batch, pos_ids_batch, vms_batch) in enumerate(batch_loader(batch_size, input_ids, label_ids, seg_ids, pos_ids, vms)):
        model.zero_grad()

        vms_batch = torch.LongTensor(vms_batch)

        input_ids_batch = input_ids_batch.to(device)
        label_ids_batch = label_ids_batch.to(device)
        seg_ids_batch = seg_ids_batch.to(device)
        pos_ids_batch = pos_ids_batch.to(device)
        vms_batch = vms_batch.to(device)

        loss, _ = model(input_ids_batch, label_ids_batch, vms_batch, seg_ids_batch, pos=pos_ids_batch)
        if torch.cuda.device_count() > 1:
            # Under DataParallel the loss holds one value per GPU.
            loss = torch.mean(loss)
        total_loss += loss.item()
        if (i + 1) % args.report_steps == 0:
            print("Epoch id: {}, Training steps: {}, Avg loss: {:.3f}".format(epoch, i+1, total_loss / args.report_steps))
            sys.stdout.flush()
            total_loss = 0.
        loss.backward()
        optimizer.step()

    print("Start evaluation on dev dataset.")
    result = evaluate(args, False)
    if result > best_result:
        # New best dev score: keep the checkpoint and report the test score.
        best_result = result
        save_model(model, args.output_model_path)

        print("Start evaluation on test dataset.")
        evaluate(args, True)

In [None]:
# Evaluation phase.
print("Final evaluation on the test dataset.")

# Load the best checkpoint into the bare module (checkpoints are saved
# without the "module" prefix). map_location keeps this working when the
# checkpoint was written on a different device.
best_state = torch.load(args.output_model_path, map_location=device)
if hasattr(model, "module"):
    model.module.load_state_dict(best_state)
else:
    model.load_state_dict(best_state)
evaluate(args, True)