In [13]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
from transformers import (
    DataProcessor,
    InputExample,
    BertTokenizer,
    BertConfig,
    BertForSequenceClassification,
    glue_convert_examples_to_features,
)

def set_random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy's global RNG and PyTorch (the
    default generator plus the CUDA generators), then disables cuDNN
    benchmarking and requests deterministic cuDNN kernels.

    Args:
        seed: Integer seed passed unchanged to each library.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(2020)
# `is_available` has to be *called*: the bare function object is always
# truthy, which selected 'cuda:0' even on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [11]:
# Filesystem locations used throughout the notebook. Other cells build file
# names with plain string concatenation (e.g. TSV_PATH + 'EL_VALID.tsv'), so
# every value here must keep its trailing slash.

# Directory containing trained Lightning checkpoints (*.ckpt).
CKPT_PATH = '../../ckpt/'
# Directory containing the EL_TRAIN / EL_VALID / EL_TEST TSV files.
TSV_PATH = '../../data/tsv/'
# Directory where tokenized features are cached as pickle files.
PICKLE_PATH = '../../data/pickle/'
# Directory where prediction TSVs are written.
RESULT_PATH = '../../data/result/'
# Directory with the pretrained tokenizer files, bert_config.json and
# pytorch_model.bin.
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
PRETRAINED_PATH = '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/'

In [3]:
class EntityLinkingProcessor(DataProcessor):
    """Data processing for the entity-linking task.

    Reads TSV files into ``InputExample`` text pairs and converts them into
    PyTorch ``DataLoader`` objects, caching tokenized features as pickles.
    """

    def get_train_examples(self, file_path):
        """Read the TSV at ``file_path`` and return examples tagged 'train'."""
        rows = self._read_tsv(file_path)
        return self._create_examples(rows, set_type='train')

    def get_dev_examples(self, file_path):
        """Read the TSV at ``file_path`` and return examples tagged 'valid'."""
        rows = self._read_tsv(file_path)
        return self._create_examples(rows, set_type='valid')

    def get_test_examples(self, file_path):
        """Read the TSV at ``file_path`` and return examples tagged 'test'."""
        rows = self._read_tsv(file_path)
        return self._create_examples(rows, set_type='test')

    def get_labels(self):
        """Return the two class labels as strings."""
        return ['0', '1']

    def _create_examples(self, lines, set_type):
        """Turn parsed TSV rows into ``InputExample`` objects.

        The first row is skipped. For every other row, columns 1 and 3 are
        joined with a space to form ``text_a``, column 5 becomes ``text_b``
        and the last column becomes the label.
        NOTE(review): the meaning of each column is not visible here;
        confirm against the TSV schema.

        Args:
            lines: Sequence of rows, each an indexable sequence of strings.
            set_type: Prefix for the example guid ('train', 'valid', 'test').

        Returns:
            List of ``InputExample``; guids are ``'<set_type>-<row index>'``.
        """
        return [
            InputExample(
                guid=f'{set_type}-{row_idx}',
                text_a=row[1] + ' ' + row[3],
                text_b=row[5],
                label=row[-1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx != 0  # row 0 is never converted into an example
        ]

    def create_dataloader(self, examples, tokenizer, max_length=384,
                          shuffle=False, batch_size=32, use_pickle=False):
        """Build a ``DataLoader`` of (input_ids, attention_mask, token_type_ids, label).

        The cache file name is derived from the guid prefix of the first
        example, e.g. 'valid-1' -> 'EL_FEATURE_VALID.pkl' under PICKLE_PATH.

        Args:
            examples: Non-empty list of ``InputExample``.
            tokenizer: Tokenizer handed to the feature converter.
            max_length: Maximum sequence length for the feature converter.
            shuffle: Whether the returned loader shuffles.
            batch_size: Batch size of the returned loader.
            use_pickle: If true, load cached features and skip tokenization;
                ``tokenizer`` and ``max_length`` are then not consulted. If
                false, tokenize and overwrite the cache file.

        Returns:
            A ``torch.utils.data.DataLoader`` over a ``TensorDataset``.
        """
        split_tag = examples[0].guid.split('-')[0].upper()
        cache_file = PICKLE_PATH + 'EL_FEATURE_' + split_tag + '.pkl'

        if not use_pickle:
            features = glue_convert_examples_to_features(
                examples,
                tokenizer,
                label_list=self.get_labels(),
                max_length=max_length,
                output_mode='classification',
                pad_on_left=False,
                pad_token=tokenizer.pad_token_id,
                pad_token_segment_id=0,
            )
            pd.to_pickle(features, cache_file)
        else:
            # Unpickling runs arbitrary code; only load cache files you created.
            features = pd.read_pickle(cache_file)

        # One LongTensor per feature attribute, in the order the batches are unpacked.
        field_names = ('input_ids', 'attention_mask', 'token_type_ids', 'label')
        tensors = [
            torch.LongTensor([getattr(feature, field) for feature in features])
            for field in field_names
        ]
        dataset = torch.utils.data.TensorDataset(*tensors)

        return torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
        )

In [4]:
# Smoke test: build the validation dataloader and print its first batch.
processor = EntityLinkingProcessor()
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_PATH)
valid_examples = processor.get_dev_examples(TSV_PATH + 'EL_VALID.tsv')
# use_pickle=True loads the cached features from PICKLE_PATH instead of
# re-tokenizing, so that cache file must already exist.
valid_loader = processor.create_dataloader(
    examples=valid_examples,
    tokenizer=tokenizer,
    max_length=384,
    shuffle=False,
    batch_size=8,
    use_pickle=True,
)

# Show only the first batch: [input_ids, attention_mask, token_type_ids, labels].
for batch in valid_loader:
    print(batch)
    break

INFO:transformers.tokenization_utils:Model name '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils:Didn't find file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ex

[tensor([[ 101, 1921,  678,  ...,    0,    0,    0],
        [ 101, 3719, 1649,  ...,    0,    0,    0],
        [ 101, 3719, 1649,  ...,    0,    0,    0],
        ...,
        [ 101, 1322, 2791,  ...,    0,    0,    0],
        [ 101, 1139, 4909,  ...,    0,    0,    0],
        [ 101, 1139, 4909,  ...,    0,    0,    0]]), tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        ...,
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]]), tensor([1, 0, 0, 1, 1, 0, 0, 1])]


In [5]:
class EntityLinkingModel(pl.LightningModule):
    """Entity-linking model: a pretrained BERT classifier with one output logit.

    The module builds its own train/valid/test dataloaders in
    ``prepare_data`` and is trained as a binary classifier with
    ``nn.BCEWithLogitsLoss``.

    Args:
        max_length: Maximum input length passed to the feature converter.
        batch_size: Batch size of every dataloader built in ``prepare_data``.
        use_pickle: Forwarded to ``EntityLinkingProcessor.create_dataloader``;
            when true, cached features are loaded instead of re-tokenizing.
    """

    def __init__(self, max_length=384, batch_size=32, use_pickle=True):
        super().__init__()
        # Maximum input length.
        self.max_length = max_length
        self.batch_size = batch_size
        self.use_pickle = use_pickle

        self.tokenizer = BertTokenizer.from_pretrained(PRETRAINED_PATH)

        # Pretrained model configuration; a single label means a single logit.
        self.config = BertConfig.from_json_file(PRETRAINED_PATH + 'bert_config.json')
        self.config.num_labels = 1

        # Pretrained model weights.
        self.bert = BertForSequenceClassification.from_pretrained(
            PRETRAINED_PATH + 'pytorch_model.bin',
            config=self.config,
        )

        # Binary classification loss, applied to raw logits.
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one logit per example.

        Args:
            input_ids: Token id tensor for the batch.
            attention_mask: Attention mask tensor for the batch.
            token_type_ids: Segment id tensor for the batch.

        Returns:
            The first element of the BERT output with its last dimension
            squeezed away (presumably (batch, 1) -> (batch,), since
            ``num_labels`` is 1 - TODO confirm).
        """
        logits = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        # Squeeze only the last axis: a bare squeeze() would also drop the
        # batch axis for a one-example batch and yield a 0-d tensor, which no
        # longer matches the shape of `labels` in the loss.
        return logits.squeeze(-1)

    def _build_loader(self, examples, shuffle):
        """Build one dataloader from ``examples`` with this module's settings."""
        return self.processor.create_dataloader(
            examples=examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=shuffle,
            # Honour the constructor argument instead of a hard-coded 32.
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )

    def prepare_data(self):
        """Read the three TSV splits and build their dataloaders.

        Sets ``processor``, ``{train,valid,test}_examples`` and
        ``{train,valid,test}_loader`` on the module. Only the training
        loader shuffles.
        """
        self.processor = EntityLinkingProcessor()
        self.train_examples = self.processor.get_train_examples(TSV_PATH + 'EL_TRAIN.tsv')
        self.valid_examples = self.processor.get_dev_examples(TSV_PATH + 'EL_VALID.tsv')
        self.test_examples = self.processor.get_test_examples(TSV_PATH + 'EL_TEST.tsv')

        self.train_loader = self._build_loader(self.train_examples, shuffle=True)
        self.valid_loader = self._build_loader(self.valid_examples, shuffle=False)
        self.test_loader = self._build_loader(self.test_examples, shuffle=False)

    def _loss_and_accuracy(self, batch):
        """Run the model on ``batch`` and return ``(loss, accuracy)``.

        ``batch`` is unpacked as (input_ids, attention_mask, token_type_ids,
        labels). A logit above 0 is counted as a prediction of class 1.
        """
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())

        preds = (logits > 0).int()
        acc = (preds == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Compute training loss and accuracy for one batch.

        Returns:
            Dict with 'loss' plus the metrics under 'log' and 'progress_bar'.
        """
        loss, acc = self._loss_and_accuracy(batch)

        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Compute validation loss and accuracy for one batch."""
        loss, acc = self._loss_and_accuracy(batch)
        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation loss and accuracy.

        Note that this is an unweighted mean over batches, so a smaller
        final batch counts as much as a full one.

        Args:
            outputs: List of dicts returned by ``validation_step``.
        """
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over all trainable parameters."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        """Return the training loader built in ``prepare_data``."""
        return self.train_loader

    def val_dataloader(self):
        """Return the validation loader built in ``prepare_data``."""
        return self.valid_loader


In [6]:
# Load the trained checkpoint and predict a 0/1 label for every validation row.
model = EntityLinkingModel.load_from_checkpoint(
    checkpoint_path=CKPT_PATH+'EL_BASE_EPOCH0.ckpt',
)
model.to(device)
model.eval()

result = []
# Inference only: without no_grad() every forward pass keeps an autograd
# graph alive, wasting memory and time across the whole validation set.
with torch.no_grad():
    for batch in tqdm(valid_loader):
        input_ids, attention_mask, token_type_ids, labels = (
            tensor.to(device) for tensor in batch
        )
        logits = model(input_ids, attention_mask, token_type_ids)
        preds = (logits > 0).int()
        # reshape(-1) guarantees a 1-d tensor, so a one-example batch still
        # extends `result` with a list rather than failing on a bare int.
        result.extend(preds.reshape(-1).tolist())

INFO:transformers.tokenization_utils:Model name '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils:Didn't find file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ex

HBox(children=(FloatProgress(value=0.0, max=17868.0), HTML(value='')))




In [12]:
# Attach the predictions to the validation table and save it.
# The assignment requires `result` to have exactly one entry per data row
# of EL_VALID.tsv, in file order (valid_loader was built with shuffle=False).
valid_data = pd.read_csv(TSV_PATH+'EL_VALID.tsv', sep='\t')
valid_data['result'] = result
valid_data.to_csv(RESULT_PATH+'EL_VALID_RESULT.tsv', index=False, sep='\t')

In [14]:
# Accuracy of the model output ('result') against the 'predict' column read
# from EL_VALID.tsv.
# NOTE(review): 'predict' is presumably the gold-label column despite its
# name - confirm against the TSV schema. accuracy_score expects
# (y_true, y_pred); the value is the same either way for plain accuracy.
print(accuracy_score(valid_data['predict'], valid_data['result']))

0.9549318240648108
