In [1]:
import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import SequentialSampler, DataLoader
from transformers import BertTokenizer
from model import GlobalPointer, GlobalPointerNERPredictor

In [2]:
# Read the tab-separated test set, then show its schema and first rows.
data_path = '../dataset/test.csv'
df = pd.read_csv(data_path, sep="\t")
df.info()
df.head()



<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2657 entries, 0 to 2656
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    2657 non-null   object
dtypes: object(1)
memory usage: 20.9+ KB


Unnamed: 0,text
0,比如片子一开始说到的暴雪娱乐的作品有……CS……。其实我看到这个时也喷饭了，这
1,辽宁14家城商行开始执行央行房贷政策
2,而本场面对西布朗这样一支弱队，相信维冈不会放过这样的机会。目前，风扬数据下，由于平值从中阻隔，
3,《辐射》设计师Anderson转投inXile
4,中信晒卡是一张个性十足的借记卡，中信银行网站为其提供了一个特色操作平台，客户通过简单的操作，


In [3]:
from config import parse_args

# Run configuration from the project-local config module; this notebook reads
# args.seed, args.bert_dir and args.device from it.
args = parse_args()
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value passed unchanged to ``random.seed``, ``np.random.seed``
            and ``torch.manual_seed``, in that order.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
# Fix the RNG state using the configured seed.
setup_seed(args.seed)

# BIO tag vocabulary stored on the config object; idx2tag is the exact
# inverse of tag2idx.
args.tag2idx = dict(zip(('O', 'B-0', 'I-0'), range(3)))
args.idx2tag = {idx: tag for tag, idx in args.tag2idx.items()}

In [4]:
# Tokenizers and the entity-type vocabulary used for inference.
import os
from ark_nlp.model.ner.global_pointer_bert import Tokenizer
from ark_nlp.factory.utils.conlleval import get_entity_bio

# HuggingFace vocabulary loaded from args.bert_dir, then handed to the ark-nlp
# Tokenizer together with max_seq_len=128.
tokenizer = BertTokenizer.from_pretrained(args.bert_dir)
ark_tokenizer = Tokenizer(vocab=tokenizer, max_seq_len=128)

# Entity-type <-> id lookup tables; id2Ent is the inverse of Ent2id.
# NOTE(review): 'O' (letter) is registered next to '0' (digit) as a second
# entity type. len(Ent2id) is passed to the model constructor, so this
# presumably has to match the saved checkpoint — confirm before changing.
Ent2id = {'0': 0, 'O': 1}
id2Ent = {ent_id: ent for ent, ent_id in Ent2id.items()}



In [5]:
# Always assign the inference device: first GPU when CUDA is present, CPU
# otherwise. Assigning it unconditionally keeps CPU-only hosts from depending
# on whatever default parse_args() supplied.
args.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print('使用：', args.device,' ing........')

# len(Ent2id) and 64 are passed positionally
# (presumably entity-type count and head size — TODO confirm against model.py).
model = GlobalPointer(args, len(Ent2id), 64)

# Deserialize the checkpoint onto CPU so loading works whatever device it was
# saved from, then move the model to the target device exactly once.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
path = './save_model/best_model.pth'
model.load_state_dict(torch.load(path, map_location='cpu'))
model = model.to(args.device)

使用： cuda:0  ing........


Some weights of the model checkpoint at ../hfl/chinese-roberta-wwm-ext were not used when initializing BertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
# Put the network in evaluation mode and wrap it in the project predictor.
model.eval()
ner_predictor_instance = GlobalPointerNERPredictor(model, ark_tokenizer, Ent2id, tokenizer)

from tqdm import tqdm

# One character-level BIO tag list per input row, in row order.
predict_results = []

for i in tqdm(range(len(df))):
    _line = df['text'][i]
    label = ['O'] * len(_line)
    for span in ner_predictor_instance.predict_one_sample(_line):
        start, end, ent_type = span['start_idx'], span['end_idx'], span['type']
        head_tag = label[start]

        # Skip the span when (substring checks on the current tags):
        #   * its start position already carries an 'I' tag, or
        #   * its start carries a 'B' tag and the end tag contains no 'O', or
        #   * its start tag contains 'O' and the end tag contains 'B'.
        # label[end] is only read when the first half of a clause holds.
        if ('I' in head_tag
                or ('B' in head_tag and 'O' not in label[end])
                or ('O' in head_tag and 'B' in label[end])):
            continue

        # Tag the first position with B-<type> and positions start+1..end with I-<type>.
        label[start] = 'B-' + ent_type
        label[start + 1: end + 1] = (end - start) * [('I-' + ent_type)]

    predict_results.append(label)

100%|██████████| 2657/2657 [00:27<00:00, 96.36it/s] 


In [7]:
# Sanity check: print how many tag sequences the prediction loop produced.
print(len(predict_results))

2657


In [8]:
from ark_nlp.factory.utils.conlleval import get_entity_bio
def extract_entity(label, text):
    """Return the surface strings of the entities encoded in a tag sequence.

    Args:
        label: Tag sequence passed to ``get_entity_bio`` (with ``id2label=None``).
        text: Source text the tags refer to; sliced with the inclusive
            ``[start, end]`` indices of each chunk ``get_entity_bio`` yields.

    Returns:
        list: One substring of ``text`` per decoded chunk, in the order the
        chunks are yielded.
    """
    return [
        text[begin: finish + 1]
        for _kind, begin, finish in get_entity_bio(label, id2label=None)
    ]

In [9]:
# Decode every predicted tag sequence into the entity strings of its source row.
tag_list = [
    extract_entity(label=tags, text=df['text'][row])
    for row, tags in enumerate(predict_results)
]

In [10]:
# Display the extracted entity strings for every row.
tag_list

[['暴雪娱乐'],
 ['辽宁'],
 ['西布朗', '维冈'],
 ['inXile'],
 ['中信', '中信银行'],
 ['欧冠', '联盟杯'],
 ['星展银行香港分行'],
 ['民生蓝筹混合型基金', '民生加银'],
 ['北京澳际教育咨询有限公司', '澳新亚留学中心'],
 ['老特拉福德', '切尔西', '曼联'],
 ['柳州'],
 ['布莱克本'],
 ['拉科', '马拉加'],
 ['挪威', '荷兰'],
 ['华纳兄弟公司'],
 ['网龙'],
 ['上海正大广场'],
 ['盛大'],
 ['博洛尼亚'],
 ['光大'],
 ['西南证券办公室'],
 ['欧足联'],
 ['阿森纳', '基辅迪纳摩'],
 ['波尔多小镇一期'],
 ['科隆', 'IEM5世界总决赛'],
 ['宜昌高新区港窑路25号南苑二期14栋南都御景1号楼12楼46-1-148'],
 ['罗布泊西北岸'],
 ['苏格兰', '意大利'],
 ['塞维利亚', '皇马', '马竞'],
 ['切尔西', '罗马'],
 ['邙山', '长安', '陈仓'],
 ['辉煌云上', '辉煌集团'],
 ['北三环马甸桥', 'cbd'],
 ['摩根大通'],
 ['宝钢股份'],
 ['意甲'],
 ['798工厂'],
 ['荷兰', 'nec尼美根'],
 ['雷曼兄弟'],
 ['英国伦敦', 'Candella'],
 ['印尼', '奥地利'],
 ['纽卡斯尔', '热刺'],
 [],
 ['南昌市兴业银行'],
 ['华夏'],
 ['梅塔利斯特'],
 ['意甲球队', '联盟杯'],
 ['民生银行'],
 ['广东东莞市万江区东莞市万江街道生益电子厂'],
 ['曜越太阳神跑跑战队'],
 ['那不勒斯', '卡利亚里'],
 ['德国复兴信贷银行'],
 ['中国航天科技集团'],
 ['萨尔瓦多'],
 ['ac米兰', '拉齐奥'],
 ['曼城', '英超'],
 ['imf'],
 ['吉隆滩'],
 [],
 ['太阳神', '橘子熊'],
 ['波斯湾', '霍尔木兹海峡水域'],
 ['芬兰'],
 ['广发行'],
 ['中国国际数码互动娱乐展览会', '上海国际博览中心'],
 ['波鸿', '

In [11]:
# Persist the per-row entity lists as a single-column CSV without the index.
# NOTE(review): the file name 'suubmit.csv' looks like a typo for 'submit.csv';
# it is kept unchanged because other steps may rely on it — confirm before renaming.
new_df = pd.DataFrame(data={'tag': tag_list})
new_df.to_csv(path_or_buf='suubmit.csv', index=False)
