In [1]:
import random
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
from transformers import (
    DataProcessor,
    InputExample,
    BertTokenizer,
    BertConfig,
    BertForSequenceClassification,
    glue_convert_examples_to_features,
)

def set_random_seed(seed):
    """Seed every random number generator this notebook relies on.

    Seeds Python's ``random`` module, NumPy, and PyTorch (the CPU generator
    plus the current and all CUDA devices), then configures cuDNN with
    benchmarking disabled and deterministic mode enabled.

    Args:
        seed: Seed value forwarded unchanged to each generator.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: no auto-tuner, deterministic mode on.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

# Seed all random number generators with a fixed value before any other work.
set_random_seed(2020)
# Use the first GPU when CUDA is usable, otherwise fall back to the CPU.
# `is_available` has to be *called*: the bare function object is always
# truthy, which would select 'cuda:0' even on a CPU-only machine.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

INFO:transformers.file_utils:PyTorch version 1.4.0 available.
INFO:transformers.file_utils:TensorFlow version 2.1.0 available.


In [2]:
# Directory (relative to this notebook) that holds the TSV splits.
tsv_path = '../../data/tsv/'

# Read the tab-separated training split and preview its first rows.
train_data = pd.read_csv(
    tsv_path + 'EL_TRAIN.tsv',
    delimiter='\t',
)
train_data.head()

Unnamed: 0,text_id,entity,offset,short_text,kb_id,kb_text,kb_predicate_num,predict
0,1,小品,0,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,275897,义项描述:小品 代表:喜剧小品 中文名:小品 特点:短小精悍，情节简单等 基本要求:语言清晰...,8,1
1,1,吴京,10,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,218699,摘要:吴京，男，1963年4月出生，河北石家庄人，1986年2月加入中国共产党，1980年6...,9,0
2,1,吴京,10,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,200103,性别:男 毕业院校:楚雄州人民警察学校 出生日期:1980年10月 职务:公安局科技信息化和...,10,0
3,1,吴京,10,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,159056,职业:演员、导演 义项描述:中国内地男演员、导演 配偶:谢楠 国籍:中国 代表作品:流浪地球...,17,1
4,1,障碍,16,小品《战狼故事》中，吴京突破重重障碍解救爱人，深情告白太感人,249920,性质:体育竞技术语 日本語:障害　しょうがい 义项描述:体育竞技术语 距起跑点:270米 外...,8,0


In [3]:
class EntityLinkingProcessor(DataProcessor):
    """Convert entity-linking TSV files into ``InputExample`` objects.

    Each getter reads a TSV file through the inherited ``_read_tsv`` and
    passes the rows to ``_create_examples`` with a split tag that is used
    only to build example guids.
    """

    def get_train_examples(self, file_path):
        """Return the rows of ``file_path`` as examples tagged ``'train'``."""
        rows = self._read_tsv(file_path)
        return self._create_examples(rows, set_type='train')

    def get_dev_examples(self, file_path):
        """Return the rows of ``file_path`` as examples tagged ``'valid'``."""
        rows = self._read_tsv(file_path)
        return self._create_examples(rows, set_type='valid')

    def get_test_examples(self, file_path):
        """Return the rows of ``file_path`` as examples tagged ``'test'``."""
        rows = self._read_tsv(file_path)
        return self._create_examples(rows, set_type='test')

    def get_labels(self):
        """Return the two class labels as strings."""
        return ['0', '1']

    def _create_examples(self, lines, set_type):
        """Build one ``InputExample`` per row, skipping the first row.

        Args:
            lines: Iterable of rows, each an indexable sequence of strings.
            set_type: Prefix for the guid, which is ``f'{set_type}-{i}'``
                with ``i`` the zero-based row position.

        Returns:
            List of ``InputExample`` where ``text_a`` is column 1 and
            column 3 joined by a space, ``text_b`` is column 5 and
            ``label`` is the last column.
            NOTE(review): judging by the sample output printed in this
            notebook, these are the mention plus its sentence and the
            knowledge-base description -- confirm against the TSV schema.
        """
        return [
            InputExample(
                guid=f'{set_type}-{row_idx}',
                text_a=row[1] + ' ' + row[3],
                text_b=row[5],
                label=row[-1],
            )
            for row_idx, row in enumerate(lines)
            if row_idx != 0  # row 0 is skipped (presumably the header row)
        ]

In [4]:
processor = EntityLinkingProcessor()

# Parse each split's TSV file into a list of InputExample objects.
train_examples = processor.get_train_examples(tsv_path + 'EL_TRAIN.tsv')
valid_examples = processor.get_dev_examples(tsv_path + 'EL_VALID.tsv')
test_examples = processor.get_test_examples(tsv_path + 'EL_TEST.tsv')

# For every split, show one sample example followed by the split size.
for split_label, split_examples, sample_idx in (
    ('Train:', train_examples, 10),
    ('Valid:', valid_examples, 20),
    ('Test:', test_examples, 30),
):
    print(split_examples[sample_idx])
    print(split_label, len(split_examples))
    print()

InputExample(guid='train-11', text_a='甄嬛传 甄嬛传：安陵容怀孕时，雍正经常摸她的肚子，原来这动作大有深意', text_b='摘要:《甄嬛传》是经小说原作者流潋紫改编的一个新话剧，首度用戏剧的形式展现一个史诗类题材。 义项描述:话剧版《甄嬛传》', label='0')
Train: 535333

InputExample(guid='valid-21', text_a='中国作家 《中国作家》2013年度排行榜', text_b='中文名:中国作家 摘要:中国作家主要指中国现、当代作家。 起始:五四时期 标签:文化 包括:中国现当代作家 义项描述:作家分类 代表:鲁迅，朱自清', label='1')
Valid: 142939

InputExample(guid='test-31', text_a='妹妹 思追原来是个超级妹控，不愿妹妹嫁人，然而妹妹却喜欢一博老师', text_b='片长:8分4秒 外文名:Sister 主演:刘炳阳 编剧:宋思琪 imdb编码:tt9032798 色彩:黑白 义项描述:中国/美国2018年宋思琪执导影片 类型:动画电影 导演:宋思琪 出品公司:The Animation Showcase 发行公司:The Animation Showcase 制片地区:中国、美国 对白语言:普通话 摘要:《妹妹》是由宋思琪执导，刘炳阳参与配音的动画电影，该片于2018年6月15日在法国安纳西国际动画电影节上映。 上映时间:2018年6月15日（安纳西国际动画电影节） 中文名:妹妹', label='0')
Test: 309442



In [5]:
def create_dataloader(examples, 
                      tokenizer, 
                      max_length=384, 
                      shuffle=False,
                      batch_size=32,
                      num_workers=6):
    """Tokenize examples and wrap them in a ``DataLoader``.

    Args:
        examples: ``InputExample`` objects to convert.
        tokenizer: Tokenizer passed to ``glue_convert_examples_to_features``;
            its ``pad_token_id`` is used as the padding token.
        max_length: Sequence length forwarded to the feature converter.
        shuffle: Whether the loader shuffles the dataset.
        batch_size: Number of samples per batch.
        num_workers: Number of worker processes used by the loader.

    Returns:
        A ``torch.utils.data.DataLoader`` whose batches hold four long
        tensors: ``input_ids``, ``attention_mask``, ``token_type_ids``
        and ``label``, in that order.
    """
    features = glue_convert_examples_to_features(
        examples,
        tokenizer,
        label_list=['0', '1'],
        max_length=max_length,
        output_mode='classification',
        pad_on_left=False,
        pad_token=tokenizer.pad_token_id,
        pad_token_segment_id=0,
    )

    def column(name):
        # Stack one attribute of every feature into a single int64 tensor.
        return torch.tensor(
            [getattr(feature, name) for feature in features],
            dtype=torch.long,
        )

    dataset = torch.utils.data.TensorDataset(
        column('input_ids'),
        column('attention_mask'),
        column('token_type_ids'),
        column('label'),
    )

    # Honour the caller's batch_size; it used to be pinned to 32 here, so
    # passing a smaller value (e.g. to fit GPU memory) had no effect.
    return torch.utils.data.DataLoader(
        dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        num_workers=num_workers,
    )

In [6]:
# Local directory containing the pretrained model and tokenizer files.
pretrained_path = '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/'
tokenizer = BertTokenizer.from_pretrained(pretrained_path)

# Settings shared by all three loaders; only the examples and the
# shuffle flag differ between splits.
loader_options = dict(
    tokenizer=tokenizer,
    max_length=384,
    batch_size=32,
    num_workers=6,
)

# Only the training split is shuffled.
train_loader = create_dataloader(examples=train_examples, shuffle=True, **loader_options)
valid_loader = create_dataloader(examples=valid_examples, shuffle=False, **loader_options)
test_loader = create_dataloader(examples=test_examples, shuffle=False, **loader_options)

INFO:transformers.tokenization_utils:Model name '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/' not found in model shortcut name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese, bert-base-german-cased, bert-large-uncased-whole-word-masking, bert-large-cased-whole-word-masking, bert-large-uncased-whole-word-masking-finetuned-squad, bert-large-cased-whole-word-masking-finetuned-squad, bert-base-cased-finetuned-mrpc, bert-base-german-dbmdz-cased, bert-base-german-dbmdz-uncased, bert-base-finnish-cased-v1, bert-base-finnish-uncased-v1, bert-base-dutch-cased). Assuming '/media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/' is a path, a model identifier, or url to a directory containing tokenizer files.
INFO:transformers.tokenization_utils:Didn't find file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ex

INFO:transformers.data.processors.glue:attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:transformers.data.processors.glue:token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 

INFO:transformers.data.processors.glue:label: 0 (id = 0)
INFO:transformers.data.processors.glue:Writing example 10000/535333
INFO:transformers.data.processors.glue:Writing example 20000/535333
INFO:transformers.data.processors.glue:Writing example 30000/535333
INFO:transformers.data.processors.glue:Writing example 40000/535333
INFO:transformers.data.processors.glue:Writing example 50000/535333
INFO:transformers.data.processors.glue:Writing example 60000/535333
INFO:transformers.data.processors.glue:Writing example 70000/535333
INFO:transformers.data.processors.glue:Writing example 80000/535333
INFO:transformers.data.processors.glue:Writing example 90000/535333
INFO:transformers.data.processors.glue:Writing example 100000/535333
INFO:transformers.data.processors.glue:Writing example 110000/535333
INFO:transformers.data.processors.glue:Writing example 120000/535333
INFO:transformers.data.processors.glue:Writing example 130000/535333
INFO:transformers.data.processors.glue:Writing example 

INFO:transformers.data.processors.glue:token_type_ids: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:transformers.data.processors.glue:label: 0 (id = 0)
INFO:transformers.data.processors.glue:*** Example ***
INFO:transformers.data.processors.glue:guid: valid-3
INFO:transfo

INFO:transformers.data.processors.glue:attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:transformers.data.processors.glue:token_type_ids: 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 

INFO:transformers.data.processors.glue:token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:transformers.data.processors.glue:label: 0 (id = 0)
INFO:transformers.data.processors.glue:*** Example ***
INFO:transformers.data.processors.glue:guid: test-3
INFO:transfor

INFO:transformers.data.processors.glue:attention_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
INFO:transformers.data.processors.glue:token_type_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 

In [7]:
class EntityLinkingModel(pl.LightningModule):
    """Binary classifier built on ``BertForSequenceClassification``.

    The classification head is configured with a single output unit and is
    trained with ``BCEWithLogitsLoss``; a logit above 0 is counted as a
    positive prediction.

    The class reads the notebook-level globals ``pretrained_path``,
    ``train_loader`` and ``valid_loader``, so those must exist before the
    model is built and fitted.
    """

    def __init__(self):
        super(EntityLinkingModel, self).__init__()

        config = BertConfig.from_json_file(pretrained_path+'/bert_config.json')
        # A single output unit: the head emits one logit per example.
        config.num_labels = 1

        # Pretrained model
        self.bert = BertForSequenceClassification.from_pretrained(
            pretrained_path+'/pytorch_model.bin',
            config=config,
        )

        # Binary classification loss
        self.criterion = nn.BCEWithLogitsLoss()

    def forward(self, input_ids, attention_mask, token_type_ids):
        """Return one logit per example.

        Args:
            input_ids: Token ids of the batch.
            attention_mask: Mask passed through to the underlying model.
            token_type_ids: Segment ids passed through to the underlying model.

        Returns:
            The first element of the underlying model's output with its
            last dimension squeezed away.
        """
        logits = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )[0]
        # Squeeze only the trailing label dimension. A bare squeeze() would
        # also drop the batch dimension when the batch holds one example,
        # leaving a 0-d tensor whose shape no longer matches the labels.
        return logits.squeeze(-1)

    def _loss_and_accuracy(self, batch):
        """Compute the loss and accuracy for one batch (train and val share this)."""
        input_ids, attention_mask, token_type_ids, labels = batch
        logits = self(input_ids, attention_mask, token_type_ids)
        loss = self.criterion(logits, labels.float())

        # logit > 0 is equivalent to sigmoid(logit) > 0.5. Cast to the
        # labels' dtype so the equality check does not depend on
        # cross-dtype comparison support.
        preds = (logits > 0).to(labels.dtype)
        acc = (preds == labels).float().mean()
        return loss, acc

    def training_step(self, batch, batch_idx):
        """Return the training loss plus loss/accuracy for logging."""
        loss, acc = self._loss_and_accuracy(batch)

        tensorboard_logs = {'train_loss': loss, 'train_acc': acc}
        return {'loss': loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        """Return the loss and accuracy of one validation batch."""
        loss, acc = self._loss_and_accuracy(batch)

        return {'val_loss': loss, 'val_acc': acc}

    def validation_epoch_end(self, outputs):
        """Average the per-batch validation loss and accuracy."""
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        val_acc = torch.stack([x['val_acc'] for x in outputs]).mean()

        tensorboard_logs = {'val_loss': val_loss, 'val_acc': val_acc}
        return {'val_loss': val_loss, 'log': tensorboard_logs, 'progress_bar': tensorboard_logs}

    def configure_optimizers(self):
        """Return an Adam optimizer over the trainable parameters."""
        return torch.optim.Adam([p for p in self.parameters() if p.requires_grad], lr=2e-5, eps=1e-8)

    def train_dataloader(self):
        """Return the notebook-level ``train_loader``."""
        return train_loader

    def val_dataloader(self):
        """Return the notebook-level ``valid_loader``."""
        return valid_loader


In [9]:
# Release memory cached by PyTorch's CUDA allocator before building a new model.
torch.cuda.empty_cache()
model = EntityLinkingModel()
save_path = '/media/bnu/data/pytorch-lightning-checkpoints/EntityLinking/'

# Keep only the single best checkpoint, ranked by the lowest validation loss.
# save_top_k is an integer count, so pass 1 rather than the boolean True.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    filepath=save_path,
    save_top_k=1,
    verbose=True,
    monitor='val_loss',
    mode='min',
    prefix='EL_',
)

# NOTE(review): the recorded run of this cell ended with "CUDA out of memory";
# a smaller batch size or max_length for the loaders may be needed.
trainer = pl.Trainer(
    max_epochs=1,
#     val_check_interval=0.1,
    checkpoint_callback=checkpoint_callback,
#     gpus=1,
    gpus=2,
    distributed_backend='dp',
    default_save_path=save_path,
    profiler=True,
)

trainer.fit(model)

INFO:transformers.modeling_utils:loading weights file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch//pytorch_model.bin
INFO:transformers.modeling_utils:Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
INFO:transformers.modeling_utils:Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
INFO:lightning:GPU available: True, used: True
INFO:lightning:VISIBLE GPUS: 0,1
INFO:lightning:
    | Name                                                  | Type                          | Params
----------------------------------------------------------------------------------------------------


HBox(children=(FloatProgress(value=0.0, description='Validation sanity check', layout=Layout(flex='2'), max=5.…



HBox(children=(FloatProgress(value=1.0, bar_style='info', layout=Layout(flex='2'), max=1.0), HTML(value='')), …

RuntimeError: CUDA out of memory. Tried to allocate 18.00 MiB (GPU 0; 10.76 GiB total capacity; 9.35 GiB already allocated; 21.69 MiB free; 9.52 GiB reserved in total by PyTorch)

In [33]:
model = EntityLinkingModel()

# Smoke test: run a single validation batch through the freshly built
# model and print the resulting logits. The labels are unpacked but unused.
for input_ids, attention_mask, token_type_ids, labels in valid_loader:
    logits = model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        token_type_ids=token_type_ids,
    )
    print(logits)
    break

INFO:transformers.modeling_utils:loading weights file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch//pytorch_model.bin
INFO:transformers.modeling_utils:Weights of BertForSequenceClassification not initialized from pretrained model: ['classifier.weight', 'classifier.bias']
INFO:transformers.modeling_utils:Weights from pretrained model not used in BertForSequenceClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']


tensor([0.7366, 0.8043, 0.5222, 0.5221, 0.8672, 0.6929, 0.5325, 0.6781, 0.3220,
        0.4888, 0.5010, 0.2940, 0.4860, 0.5321, 0.5085, 0.5989, 0.5111, 0.5920,
        0.5531, 0.6092, 0.6262, 0.6959, 0.4237, 0.7809, 0.4985, 0.7876, 0.7559,
        0.5531, 0.5516, 0.8327, 0.6327, 0.6183])


In [35]:
# Directory for manually exported checkpoints.
ckpt_path = Path('/home/bnu/projects/CCKS2020-Entity-Linking/ckpt/')
# Write the trainer's current state to a named checkpoint file.
trainer.save_checkpoint(ckpt_path.joinpath('EL-RoBERTa-128-0419.ckpt'))

In [59]:
# NOTE(review): ELRoBERTaModel is not defined in the visible cells of this
# notebook; it must be defined or imported before this cell runs -- confirm.
model = ELRoBERTaModel.load_from_checkpoint(
    ckpt_path/'EL-RoBERTa-128-0419.ckpt',
    pretrained_path=pretrained_path,
    train_loader=None,
    valid_loader=None,
)
model.eval()
batch = next(iter(valid_loader))
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    outputs = model(batch[0], batch[1], batch[2])
# Softmax over the last dimension, then take column 1.
# NOTE(review): this assumes `outputs` is (batch, 2) logits -- confirm
# against the ELRoBERTaModel definition.
F.softmax(outputs, dim=-1)[:, 1]

INFO:transformers.configuration_utils:loading configuration file /media/bnu/data/transformers-pretrained-model/chinese_roberta_wwm_ext_pytorch/bert_config.json
INFO:transformers.configuration_utils:Model config BertConfig {
  "_num_labels": 2,
  "architectures": null,
  "attention_probs_dropout_prob": 0.1,
  "bad_words_ids": null,
  "bos_token_id": null,
  "decoder_start_token_id": null,
  "directionality": "bidi",
  "do_sample": false,
  "early_stopping": false,
  "eos_token_id": null,
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "id2label": {
    "0": "LABEL_0",
    "1": "LABEL_1"
  },
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "is_encoder_decoder": false,
  "label2id": {
    "LABEL_0": 0,
    "LABEL_1": 1
  },
  "layer_norm_eps": 1e-12,
  "length_penalty": 1.0,
  "max_length": 20,
  "max_position_embeddings": 512,
  "min_length": 0,
  "model_type": "bert",
  "no_repeat_ngram_size": 0,

tensor([6.3753e-03, 5.0309e-04, 4.3646e-04, 9.7265e-01, 9.9619e-01, 3.3942e-02,
        9.9538e-01, 7.3173e-01, 9.9759e-01, 9.9883e-01, 2.1060e-03, 9.8735e-01,
        3.8978e-04, 2.0384e-01, 9.5879e-01, 2.9894e-03, 1.2333e-02, 7.7380e-01,
        4.1989e-04, 1.2595e-01, 4.8184e-04, 9.7496e-01, 8.5529e-01, 1.4261e-03,
        4.5359e-04, 1.0327e-01, 4.1051e-04, 2.2023e-01, 1.5070e-02, 9.6544e-01,
        4.5117e-04, 4.6139e-03])