# 使用RoBERTa模型完成entity linking

In [1]:
import random
from pathlib import Path
import os

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from tqdm.notebook import tqdm
from sklearn.metrics import accuracy_score
from transformers import (
    DataProcessor,
    InputExample,
    BertTokenizer,
    BertConfig,
    BertForSequenceClassification,
    glue_convert_examples_to_features,
)

def set_random_seed(seed):
    """Seed every RNG used in this notebook and force deterministic cuDNN.

    Args:
        seed: integer seed applied to Python's ``random``, NumPy and PyTorch
            (CPU and all CUDA devices).
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Give up cuDNN autotuning speed in exchange for reproducible kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(2020)
# `is_available` must be *called*: the bare function object is always truthy,
# which would select 'cuda:0' even on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [2]:
# Locations of: the CCKS2020 entity-linking data, the pretrained Chinese
# RoBERTa-wwm-large weights, and the directory checkpoints are written to.
data_path, roberta_path, model_path = (
    "../../data/ccks2020_el_data_v1",
    "../../../../../../../research/pretrained_models/chinese_roberta_wwm_large_ext_pytorch",
    "pytorch-lightning-checkpoints/ELRoBERTaModel",
)

按照惯例，我们来写一个data processor用于预处理数据。

In [6]:
class EntityLinkingProcessor(DataProcessor):
    """Entity-linking data processor.

    Reads the ``*_link.tsv`` files (the first row is a header) and turns each
    row into a binary sentence-pair classification example.
    """

    def get_train_examples(self, file_path):
        """Build training examples (guid prefix ``train``) from a TSV file."""
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='train',
        )

    def get_dev_examples(self, file_path):
        """Build validation examples (guid prefix ``valid``) from a TSV file."""
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='valid',
        )

    def get_test_examples(self, file_path):
        """Build test examples (guid prefix ``test``) from a TSV file."""
        return self._create_examples(
            self._read_tsv(file_path),
            set_type='test',
        )

    def get_labels(self):
        """Binary label vocabulary; a label's index here is its integer class id."""
        return ['0', '1']

    def _create_examples(self, lines, set_type):
        """Convert TSV rows into ``InputExample`` objects.

        Rows are expected to have 5 columns. ``text_a`` is column 0 and
        column 2 joined by a space, ``text_b`` is column 3, and the label is
        the last column. The first row is skipped as a header. Rows with an
        unexpected column count are printed to aid debugging, but they are
        still used.
        """
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue  # header row
            guid = f'{set_type}-{i}'
            if len(line) != 5:
                print(line)
            text_a = line[0] + ' ' + line[2]
            text_b = line[3]
            label = line[-1]
            examples.append(InputExample(
                guid=guid,
                text_a=text_a,
                text_b=text_b,
                label=label,
            ))
        return examples

    def create_dataloader(self, examples, tokenizer, max_length=384,
                          shuffle=False, batch_size=32, use_pickle=False,
                          cache_dir=None):
        """Tokenize ``examples`` and wrap the features in a ``DataLoader``.

        Features are cached as ``EL_FEATURE_<SET_TYPE>.pkl`` in ``cache_dir``
        (by default the module-level ``data_path``). With ``use_pickle=True``,
        an existing cache is reused only if it still matches the request: the
        same number of features as ``examples``, each padded to
        ``max_length``. Otherwise the features are regenerated and the cache
        is overwritten.

        Returns:
            A DataLoader yielding ``(input_ids, attention_mask,
            token_type_ids, labels)`` batches of LongTensors.

        Raises:
            ValueError: if ``examples`` is empty (the cache name is derived
                from the first example's guid).
        """
        if not examples:
            raise ValueError('examples must not be empty')
        if cache_dir is None:
            cache_dir = data_path
        set_type = examples[0].guid.split('-')[0]
        pickle_path = os.path.join(cache_dir, 'EL_FEATURE_' + set_type.upper() + '.pkl')

        features = None
        if use_pickle and os.path.isfile(pickle_path):
            # NOTE: unpickling can execute code; only load caches written here.
            cached = pd.read_pickle(pickle_path)
            # The cache file name ignores max_length and the data itself, so a
            # cache left over from another configuration must be rejected.
            if (len(cached) == len(examples)
                    and all(len(f.input_ids) == max_length for f in cached)):
                features = cached
        if features is None:
            features = glue_convert_examples_to_features(
                examples,
                tokenizer,
                label_list=self.get_labels(),
                max_length=max_length,
                output_mode='classification',
            )
            pd.to_pickle(features, pickle_path)

        dataset = torch.utils.data.TensorDataset(
            torch.tensor([f.input_ids for f in features], dtype=torch.long),
            torch.tensor([f.attention_mask for f in features], dtype=torch.long),
            torch.tensor([f.token_type_ids for f in features], dtype=torch.long),
            torch.tensor([f.label for f in features], dtype=torch.long),
        )

        dataloader = torch.utils.data.DataLoader(
            dataset,
            shuffle=shuffle,
            batch_size=batch_size,
        )
        return dataloader

In [7]:
# processor = EntityLinkingProcessor()
# tokenizer = BertTokenizer.from_pretrained(roberta_path)
# valid_examples = processor.get_dev_examples(os.path.join(data_path, 'dev_link.tsv'))
# valid_loader = processor.create_dataloader(
#     examples=valid_examples,
#     tokenizer=tokenizer,
#     max_length=384,
#     shuffle=False,
#     batch_size=8,
#     use_pickle=True,
# )
# for batch in valid_loader:
#     print(batch)
#     break

# train_examples = processor.get_train_examples(os.path.join(data_path, 'train_link.tsv'))
# train_loader = processor.create_dataloader(
#     examples=train_examples,
#     tokenizer=tokenizer,
#     max_length=384,
#     shuffle=False,
#     batch_size=8,
#     use_pickle=True,
# )
# for batch in train_loader:
#     print(batch)
#     break
    


In [8]:
class EntityLinkingModel(pl.LightningModule):
    """Entity-linking model.

    The RoBERTa weights are loaded through the BERT classes, with a single
    output logit per (text_a, text_b) pair, trained with BCE-with-logits.
    """

    def __init__(self, max_length=384, batch_size=32, use_pickle=True):
        super(EntityLinkingModel, self).__init__()
        # Maximum input length in tokens
        self.max_length = max_length
        self.batch_size = batch_size
        # Whether create_dataloader may reuse cached feature pickles
        self.use_pickle = use_pickle

        self.tokenizer = BertTokenizer.from_pretrained(roberta_path)

        # Pretrained model configuration; one output logit for binary classification
        self.config = BertConfig.from_json_file(os.path.join(roberta_path, 'config.json'))
        self.config.num_labels = 1

        # Pretrained encoder with a sequence-classification head
        self.bert = BertForSequenceClassification.from_pretrained(
            os.path.join(roberta_path, 'pytorch_model.bin'),
            config=self.config,
        )

        # Binary classification loss applied to the raw logit
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one logit per example, shape ``(batch,)``."""
        logits = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        # With num_labels == 1 the logits are (batch, 1). squeeze(-1) drops
        # only that axis. A bare squeeze() would also drop the batch axis when
        # a batch holds one example, and BCEWithLogitsLoss would then reject
        # the target-size mismatch.
        return logits.squeeze(-1)

    def prepare_data(self):
        """Read the train/dev TSVs and build their DataLoaders."""
        # NOTE(review): Lightning may run prepare_data on a single process in
        # distributed setups; assigning state here is fine for the single-GPU
        # run below, but setup() would be needed for multi-GPU training.
        self.processor = EntityLinkingProcessor()
        self.train_examples = self.processor.get_train_examples(os.path.join(data_path, 'train_link.tsv'))
        self.valid_examples = self.processor.get_dev_examples(os.path.join(data_path , 'dev_link.tsv'))
#         self.test_examples = self.processor.get_test_examples(os.path.join(data_path, 'test_link.tsv'))

        self.train_loader = self.processor.create_dataloader(
            examples=self.train_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=True,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
        self.valid_loader = self.processor.create_dataloader(
            examples=self.valid_examples,
            tokenizer=self.tokenizer,
            max_length=self.max_length,
            shuffle=False,
            batch_size=self.batch_size,
            use_pickle=self.use_pickle,
        )
#         self.test_loader = self.processor.create_dataloader(
#             examples=self.test_examples,
#             tokenizer=self.tokenizer,
#             max_length=self.max_length,
#             shuffle=False,
#             batch_size=32,
# #             use_pickle=self.use_pickle,
#         )

    def _shared_step(self, batch):
        """Forward a batch; return (mean loss, number of correct predictions, batch size)."""
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())
        # A positive logit means sigmoid(logit) > 0.5, i.e. predict class 1.
        preds = (logits > 0).int()
        correct = (preds == labels).sum()
        return loss, correct, labels.size(0)

    def training_step(self, batch, batch_idx):
        """Compute the training loss and batch accuracy."""
        loss, correct, count = self._shared_step(batch)
        acc = correct.float() / count

        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Compute the validation loss and accuracy, plus counts for exact epoch metrics."""
        loss, correct, count = self._shared_step(batch)
        return {
            'val_loss': loss,
            'val_acc': correct.float() / count,
            'val_correct': correct,
            'val_count': torch.tensor(count, device=loss.device),
        }

    def validation_epoch_end(self, outputs):
        """Aggregate validation metrics over the whole dataset.

        Each batch is weighted by its size, so a smaller final batch does not
        bias the epoch loss or accuracy.
        """
        counts = torch.stack([x['val_count'] for x in outputs]).float()
        losses = torch.stack([x['val_loss'] for x in outputs])
        correct = torch.stack([x['val_correct'] for x in outputs]).float()
        total = counts.sum()
        val_loss = (losses * counts).sum() / total
        val_acc = correct.sum() / total

        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        """Use Adam over all trainable parameters."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        """Return the loader built in prepare_data()."""
        return self.train_loader

    def val_dataloader(self):
        """Return the loader built in prepare_data()."""
        return self.valid_loader

In [9]:
# Build the model with its defaults spelled out (384-token inputs, cached
# features allowed) and a batch size of 16. prepare_data() is not called
# here; the Trainer invokes it during fit().
model = EntityLinkingModel(max_length=384, batch_size=16, use_pickle=True)

In [10]:
# Single-GPU training capped at 5000 optimizer steps, validating every 10% of
# an epoch; checkpoints are written under `model_path`.
trainer = pl.Trainer(
    gpus=1,
    max_steps=5000,
    val_check_interval=0.1,
    weights_save_path=model_path,
)
trainer.fit(model)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type                          | Params
------------------------------------------------------------
0 | bert      | BertForSequenceClassification | 102 M 
1 | criterion | BCEWithLogitsLoss             | 0     


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…




1

In [8]:
# model = EntityLinkingModel.load_from_checkpoint(
# #     checkpoint_path=model_path + 'EL_BASE_EPOCH0.ckpt',
# )
# model.to(device)
# model.eval()

# result = []
# for batch in tqdm(valid_loader):
#     for i in range(len(batch)):
#         batch[i] = batch[i].to(device)
#     input_ids, attention_mask, token_type_ids, labels = batch
#     logits = model(input_ids, attention_mask, token_type_ids)
#     preds = (logits > 0).int()
#     result.extend(preds.tolist())

In [9]:
# Predict 0/1 labels for the validation set with the trained model.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.eval()  # disable dropout for inference
model = model.to(device)
result = []
with torch.no_grad():
    for batch in tqdm(model.val_dataloader()):
        batch = [t.to(device) for t in batch]
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = model(input_ids, attention_mask, token_type_ids)
        # Flatten so that a 0-dim result (batch of one) still yields a list
        # for extend() instead of a bare int.
        preds = (logits > 0).int().reshape(-1)
        result.extend(preds.tolist())

HBox(children=(FloatProgress(value=0.0, max=63.0), HTML(value='')))




In [10]:
# Display the predicted 0/1 labels for the validation set.
result

[1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 0,
 0,
 1,
 0,
 1,
 1,
 1,
 0,
 0,
 1,
 0,
 1,
 0,
 0,
 0,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 1,
 1,
 0,
 0,
 0,
 0,
 1,
 1,
 0,
 1,
 1,
 0,


In [20]:
# Gold labels of the validation set, in the same (unshuffled) order as `result`.
answers = [
    label
    for *_, batch_labels in model.val_dataloader()
    for label in batch_labels.tolist()
]
    

In [23]:
# Validation accuracy: the fraction of predictions equal to the gold labels.
# Both lists come from the same unshuffled loader, so they must align 1:1;
# fail loudly instead of comparing arrays of different lengths.
if len(answers) != len(result):
    raise ValueError(
        f'got {len(answers)} labels but {len(result)} predictions'
    )
np.mean(np.asarray(answers) == np.asarray(result))

0.8868868868868869