In [1]:
import pandas as pd
import numpy as np
import random
import torch
import torch.nn as nn
from tqdm import tqdm
from torch.utils.data import SequentialSampler, DataLoader

In [2]:
# Read the tab-separated test set, then show its schema and first rows.
data_path = '../dataset/test.csv'
df = pd.read_csv(data_path, sep="\t")
df.info()
df.head()



<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2657 entries, 0 to 2656
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   text    2657 non-null   object
dtypes: object(1)
memory usage: 20.9+ KB


Unnamed: 0,text
0,比如片子一开始说到的暴雪娱乐的作品有……CS……。其实我看到这个时也喷饭了，这
1,辽宁14家城商行开始执行央行房贷政策
2,而本场面对西布朗这样一支弱队，相信维冈不会放过这样的机会。目前，风扬数据下，由于平值从中阻隔，
3,《辐射》设计师Anderson转投inXile
4,中信晒卡是一张个性十足的借记卡，中信银行网站为其提供了一个特色操作平台，客户通过简单的操作，


In [3]:
from config import parse_args

# Build the run configuration via the project's `config.parse_args`.
args = parse_args()
def setup_seed(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    Args:
        seed: Value passed unchanged to each library's seeding function.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
# Seed all RNGs with the configured seed.
setup_seed(args.seed)

# BIO tag vocabulary; the reverse map is derived from the forward map so the
# two cannot drift apart.
args.tag2idx = {'O': 0, 'B-0': 1, 'I-0': 2}
args.idx2tag = {index: tag for tag, index in args.tag2idx.items()}

In [4]:
# Load the pickled vocabulary and record its size on the config.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
word2idx = pd.read_pickle('./word2idx.pkl')
args.word2idx = word2idx
args.vocab_len = len(word2idx)

print('vocab_len: ', args.vocab_len)

vocab_len:  3040


In [5]:
# Split every sentence into characters and wrap the result in a one-element
# list, giving one `[chars]` entry per row of `df`.
test_data = [[list(sentence)] for sentence in df['text']]

In [6]:
# test_data

In [7]:
from data_helper import NER_Dataset
# Wrap the samples in the project dataset (test mode) and iterate them
# sequentially so predictions come out in the same order as the input rows.
test_data_datsset = NER_Dataset(test_data, args, test_mode=True)
sampler = SequentialSampler(test_data_datsset)
dataloader = DataLoader(
    dataset=test_data_datsset,
    sampler=sampler,
    batch_size=args.test_batch_size,
)

In [8]:
# Pick the inference device. Prefer the first GPU; on a CPU-only machine fall
# back to 'cpu' when the config has no device or names a CUDA device, because
# `model.to(args.device)` below would otherwise fail. Any other configured
# device is left untouched.
if torch.cuda.is_available():
    args.device = 'cuda:0'
    print('使用：', args.device,' ing........')
elif str(getattr(args, 'device', 'cuda')).startswith('cuda'):
    args.device = 'cpu'

from model import BiLSTM_CRF
model = BiLSTM_CRF(args)
path = './save_model/best_model.pth'
# Map the checkpoint to CPU first so it opens regardless of the device it was
# saved from. NOTE: torch.load unpickles the file; only load trusted checkpoints.
model.load_state_dict(torch.load(path, map_location='cpu'))
# Last expression of the cell: moves the model to the chosen device and
# displays its structure.
model.to(args.device)

使用： cuda:0  ing........


BiLSTM_CRF(
  (embedding): Embedding(3040, 300, padding_idx=3039)
  (rnn): LSTM(300, 300, batch_first=True, bidirectional=True)
  (hidden2tag): Linear(in_features=600, out_features=3, bias=True)
  (crf): CRF(num_tags=3)
)

In [9]:
# Predicted tag sequence for every sample, in dataloader order.
predict_tag = []

model = model.eval()
with torch.no_grad():
    for batch in tqdm(dataloader, 'val'):
        # No label tensor is passed: the model is only asked for predictions.
        decoded = model(sentence_tensor=batch['sentence_tensor'].to(args.device),
                        label_tensor=None,
                        mask_tensor=batch['mask_tensor'].to(args.device))

        # Translate each predicted index sequence back into tag strings.
        predict_tag.extend([args.idx2tag[tag_id] for tag_id in tag_ids]
                           for tag_ids in decoded)

val: 100%|██████████| 84/84 [00:04<00:00, 20.07it/s]


In [10]:
from ark_nlp.factory.utils.conlleval import get_entity_bio
def extract_entity(label, text):
    """Return the pieces of `text` covered by the entity spans found in `label`.

    Args:
        label: Tag sequence handed to `get_entity_bio` (with `id2label=None`).
        text: Sequence that is sliced with the span indices.

    Returns:
        list: One `text[start:end + 1]` slice per span produced by
        `get_entity_bio`, in the order produced.
    """
    spans = get_entity_bio(label, id2label=None)
    # Each span unpacks as (type, start, end); the end index is inclusive.
    return [text[start: end + 1] for _type, start, end in spans]

In [11]:
# Spot-check the extraction on the first test sentence.
extract_entity(label=predict_tag[0], text=df['text'].iloc[0])

['暴雪娱乐']

In [12]:
# Extract the entity strings of every sentence, pairing each predicted tag
# sequence with the text stored under the same index in `df`.
tag_list = [
    extract_entity(label=tags, text=df['text'][row])
    for row, tags in enumerate(predict_tag)
]

In [13]:
# Display the extracted entities (one list per predicted tag sequence).
tag_list

[['暴雪娱乐'],
 ['辽宁14家城商行'],
 ['西布朗', '维冈'],
 ['inXile'],
 ['中信银行'],
 ['欧冠', '联盟杯'],
 ['星展银行香港分行'],
 ['民生蓝筹混合型', '民生加银'],
 ['北京澳际教育咨询有限公司', '澳新亚留学中心'],
 ['老特拉福德', '切尔西', '曼联'],
 [],
 ['布莱克本'],
 ['拉科', '马拉加'],
 ['挪威', '荷兰'],
 ['波特', '华纳兄弟公司'],
 ['网龙', '网龙'],
 ['StarsWar6Killer', '上海正大广场'],
 [],
 ['博洛尼亚'],
 ['光大三家银行'],
 ['西南证券办公室'],
 ['欧足联'],
 ['阿森纳', '基辅迪纳摩'],
 ['波尔多小镇一期'],
 ['科隆', 'IEM5世界总决赛'],
 ['宜昌高新区港窑路25号南苑', '南都御景1号楼12楼4'],
 ['楼兰', '罗布泊西北岸'],
 ['苏格兰', '意大利'],
 ['塞维利亚', '皇马', '马竞'],
 ['切尔西', '意大利', '罗马'],
 ['长安'],
 ['辉煌云上', '辉煌集团'],
 ['北三环马甸桥'],
 ['摩根大', '亚太区', '子行'],
 ['宝钢'],
 ['意甲'],
 ['长篇新行', '老三工作室'],
 ['荷兰', 'nec尼美根'],
 ['雷曼兄弟'],
 ['英国伦敦', 'Seas公司'],
 ['印尼', '奥地利'],
 ['纽卡斯尔', '热刺'],
 ['质监局', '旧换新工作小组'],
 ['南昌市兴业银行'],
 ['华夏'],
 ['梅塔利斯特'],
 ['意甲', '联盟杯'],
 ['民生银行'],
 ['万江街道', '益电子厂现货'],
 ['曜越太阳神'],
 ['那不勒斯', '卡利亚里'],
 ['德国复兴信贷银行', '美国国际集团（aig）'],
 ['中国航天科技集团'],
 [],
 ['ac米兰', '拉齐奥'],
 ['曼城', '英超'],
 ['imf'],
 ['吉隆滩'],
 ['昆明市', '昆明市'],
 ['太阳神', '橘子熊'],
 ['伊朗革命卫队', '波斯湾', '霍尔木兹海'],
 

In [14]:
# Write the per-sentence entity lists as a single 'tag' column, without the index.
# NOTE(review): the output name 'suubmit.csv' looks like a typo for 'submit.csv';
# confirm before renaming.
new_df = pd.DataFrame(data={'tag': tag_list})
new_df.to_csv(path_or_buf='suubmit.csv', index=False)
