In [11]:
import json
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchtext
import pytorch_lightning as pl
from transformers import (
    DataProcessor,
    InputExample,
    BertTokenizer,
    BertForSequenceClassification,
    BertConfig,
    glue_convert_examples_to_features,
)

def set_random_seed(seed):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_random_seed(2020)
device = torch.device('cuda:0' if torch.cuda.is_available else 'cpu')

In [2]:
data_path = Path('/home/bnu/projects/CCKS2020-Entity-Linking/data/csv')
train_data = pd.read_csv(data_path/'train_type.csv', sep='\t')
train_data.head()

Unnamed: 0,entity,offset,rawtext,type
0,小品,0,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,Other
1,战狼故事,3,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,Work
2,吴京,10,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,Person
3,障碍,16,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,Other
4,爱人,20,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,Other


In [3]:
lengths = [len(t) for t in train_data['rawtext']]
print(max(lengths))

53


In [4]:
pickle_path = Path('/home/bnu/projects/CCKS2020-Entity-Linking/data/pickle')
idx_to_type = pd.read_pickle(pickle_path/'idx_to_type.pkl')
idx_to_type

['Person',
 'Work',
 'Medicine',
 'Game',
 'Other',
 'Organization',
 'Location',
 'Culture',
 'Biological',
 'VirtualThings',
 'Natural&Geography',
 'Website',
 'Event',
 'Brand',
 'Food',
 'Awards',
 'Time&Calendar',
 'Disease&Symptom',
 'Software',
 'Vehicle',
 'Education',
 'Constellation',
 'Diagnosis&Treatment',
 'Law&Regulation']

In [5]:
class ETProcessor(DataProcessor):

    def get_train_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(data_dir / 'train_type.csv'),
            set_type='train',
        )

    def get_dev_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(data_dir / 'valid_type.csv'),
            set_type='valid',
        )

    def get_labels(self):
        return idx_to_type

    def _create_examples(self, lines, set_type):
        examples = []
        for i, line in enumerate(lines):
            if i == 0:
                continue
            guid = f'{set_type}-{i}'
            try:
                text_a = line[0]
                text_b = line[2]
                label = line[3]  
                examples.append(InputExample(
                    guid=guid,
                    text_a=text_a,
                    text_b=text_b,
                    label=label,
                ))
            except:
                print(i)
                print(line)
        return examples

In [6]:
data_path = Path('/home/bnu/projects/CCKS2020-Entity-Linking/data/csv')
processor = ETProcessor()
examples = processor.get_train_examples(data_path)
print(examples[10])
print('Train:', len(examples))
examples = processor.get_dev_examples(data_path)
print('Valid:', len(examples))

InputExample(guid='train-11', text_a='动作', text_b='甄嬛传：安陵容怀孕时，雍正经常摸她的肚子，原来这动作大有深意', label='Other')
Train: 266740
Valid: 33074


In [7]:
def generate_dataloaders(tokenizer, data_path):
    def generate_dataloader_inner(examples):
        features = glue_convert_examples_to_features(
            examples,
            tokenizer,
            label_list=idx_to_type,
            max_length=64,
            output_mode='classification',
            pad_on_left=False,
            pad_token=tokenizer.pad_token_id,
            pad_token_segment_id=0)

        dataset = torch.utils.data.TensorDataset(
            torch.LongTensor([f.input_ids for f in features]),
            torch.LongTensor([f.attention_mask for f in features]),
            torch.LongTensor([f.token_type_ids for f in features]),
            torch.LongTensor([f.label for f in features])
        )

        sampler = torch.utils.data.RandomSampler(dataset)
        dataloader = torch.utils.data.DataLoader(
            dataset, sampler=sampler, batch_size=64
        )
        return dataloader

    # 训练数据
    train_examples = processor.get_train_examples(data_path)
    print('Load Example Finish')
    train_loader = generate_dataloader_inner(train_examples)
    print('Generate DataLoader Finish')
    
    # 验证数据
    valid_examples = processor.get_dev_examples(data_path)
    print('Load Example Finish')
    valid_loader = generate_dataloader_inner(valid_examples)
    print('Generate DataLoader Finish')
    
    return train_loader, valid_loader

In [8]:
pretrained_path = '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch'
tokenizer = BertTokenizer.from_pretrained(pretrained_path)
train_loader, valid_loader = generate_dataloaders(tokenizer, data_path)

INFO:transformers.tokenization_utils:Model name '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils:Didn't find file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_

Load Example Finish


INFO:transformers.data.processors.glue:Writing example 10000/266740
INFO:transformers.data.processors.glue:Writing example 20000/266740
INFO:transformers.data.processors.glue:Writing example 30000/266740
INFO:transformers.data.processors.glue:Writing example 40000/266740
INFO:transformers.data.processors.glue:Writing example 50000/266740
INFO:transformers.data.processors.glue:Writing example 60000/266740
INFO:transformers.data.processors.glue:Writing example 70000/266740
INFO:transformers.data.processors.glue:Writing example 80000/266740
INFO:transformers.data.processors.glue:Writing example 90000/266740
INFO:transformers.data.processors.glue:Writing example 100000/266740
INFO:transformers.data.processors.glue:Writing example 110000/266740
INFO:transformers.data.processors.glue:Writing example 120000/266740
INFO:transformers.data.processors.glue:Writing example 130000/266740
INFO:transformers.data.processors.glue:Writing example 140000/266740
INFO:transformers.data.processors.glue:Writ

Generate DataLoader Finish
Load Example Finish


INFO:transformers.data.processors.glue:Writing example 10000/33074
INFO:transformers.data.processors.glue:Writing example 20000/33074
INFO:transformers.data.processors.glue:Writing example 30000/33074


Generate DataLoader Finish


In [19]:
class ETRoBERTaModel(pl.LightningModule):

    def __init__(self, 
                 pretrained_path, 
                 train_loader, 
                 valid_loader):
        super(ETRoBERTaModel, self).__init__()
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        config = BertConfig.from_json_file(pretrained_path+'/bert_config.json')
        config.num_labels = len(idx_to_type)
        print(config)
        
        # 预训练模型
        self.ptm = BertForSequenceClassification.from_pretrained(
            pretrained_path+'/pytorch_model.bin',
            config=config,
        )

        # 损失函数
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask, token_type_ids):
        return self.ptm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids
        )[0]

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, label = batch
        out = self(input_ids, attention_mask, token_type_ids)

        loss = self.criterion(out, label)

        _, pred = torch.max(out, dim=1)
        acc = (pred == label).float().mean()

        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        input_ids, attention_mask, token_type_ids, label = batch
        out = self(input_ids, attention_mask, token_type_ids)

        loss = self.criterion(out, label)

        _, pred = torch.max(out, dim=1)
        acc = (pred == label).float().mean()

        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        return self.train_loader

    def val_dataloader(self):
        return self.valid_loader


In [20]:
model = ETRoBERTaModel(pretrained_path, train_loader, valid_loader)
batch = next(iter(train_loader))
model(batch[0], batch[1], batch[2])

INFO:transformers.modeling_utils:loading weights file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/pytorch_model.bin


BertConfig {
  "_num_labels": 24,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1",
    "2": "LABEL_2",
    "3": "LABEL_3",
    "4": "LABEL_4",
    "5": "LABEL_5",
    "6": "LABEL_6",
    "7": "LABEL_7",
    "8": "LABEL_8",
    "9": "LABEL_9",
    "10": "LABEL_10",
    "11": "LABEL_11",
    "12": "LABEL_12",
    "13": "LABEL_13",
    "14": "LABEL_14",
    "15": "LABEL_15",
    "16": "LABEL_16",
    "17": "LABEL_17",
    "18": "LABEL_18",
    "19": "LABEL_19",
    "20": "LABEL_20",
    "21": "LABEL_21",
    "22": "LABEL_22",
    "23": "LABEL_23"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_

INFO:transformers.modeling_utils:Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
INFO:transformers.modeling_utils:Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']


tensor([[ 0.5866,  0.0903,  0.2801,  ..., -0.0650, -1.0424,  0.5130],
        [ 0.5835,  0.2605,  0.1885,  ...,  0.0015, -0.8778,  0.4723],
        [ 0.2553,  0.1412,  0.1660,  ..., -0.0530, -0.8117,  0.3629],
        ...,
        [ 0.4258, -0.1628,  0.3814,  ..., -0.1921, -0.8992,  0.4901],
        [ 0.1804,  0.2798,  0.2134,  ..., -0.0246, -0.7629,  0.4898],
        [ 0.2131, -0.0873,  0.1173,  ...,  0.1244, -0.7905,  0.4977]],
       grad_fn=<AddmmBackward>)

In [21]:
trainer = pl.Trainer(
    max_epochs=1,
    val_check_interval=0.1,
    gpus=2,
    distributed_backend='dp',
    default_save_path='/media/bnu/data/pytorch-lightning-checkpoints/ETRoBERTaModel',
)
trainer.fit(model)

INFO:lightning:GPU available: True, used: True
INFO:lightning:VISIBLE GPUS: 0,1
INFO:lightning:
    | Name                                                 | Type                          | Params
---------------------------------------------------------------------------------------------------
0   | ptm                                                  | BertForSequenceClassification | 102 M 
1   | ptm.bert                                             | BertModel                     | 102 M 
2   | ptm.bert.embeddings                                  | BertEmbeddings                | 16 M  
3   | ptm.bert.embeddings.word_embeddings                  | Embedding                     | 16 M  
4   | ptm.bert.embeddings.position_embeddings              | Embedding                     | 393 K 
5   | ptm.bert.embeddings.token_type_embeddings            | Embedding                     | 1 K   
6   | ptm.bert.embeddings.LayerNorm                        | LayerNorm                     | 1 K   
7   



HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …



HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…

HBox(children=(FloatProgress(value=0.0, description='Validating', layout=Layout(flex='2'), max=517.0, style=Pr…




1

In [22]:
ckpt_path = Path('/home/bnu/projects/CCKS2020-Entity-Linking/ckpt/')
trainer.save_checkpoint(ckpt_path/'ET-RoBERTa-64-0420.ckpt')