In [1]:
import warnings
warnings.simplefilter('ignore')

import re

import numpy as np
import pandas as pd
pd.set_option('max_rows', 100)

from tqdm import tqdm
tqdm.pandas()

from simpletransformers.ner import NERModel, NERArgs



# 文本清洗

In [2]:
!cp -v 'raw_data/bmes_train.json' 'raw_data/bmes_train_corrected.json' 

# 这四个是多了中括号
!sed -i 's/二河闸\[/二河闸/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/克州\[/克州/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/流域\[/流域/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/徐水\[-/徐水-/' 'raw_data/bmes_train_corrected.json'

# 这些不规范 :(
!sed -i 's/#河流\*\]//' 'raw_data/bmes_train_corrected.json'
!sed -i 's/岳西县#地区\*，经\[@/岳西县#地区*[@]，经/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/\[@姜明/[@]姜明/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/义和村#地区\*，全\[@/义和村#地区*[@]，全[@]/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/，全\[@\]/，全/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/幸福乡苇子沟#地区\*，经动力区\[@黎明乡、/幸福乡苇子沟，经动力区黎明乡、/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/苇子沟#地区\*，经动力区\[@黎明乡-LOC/苇子沟-LOC", "黎明乡-LOC/' 'raw_data/bmes_train_corrected.json'
!sed -i 's/遂城镇#地区\*\]/遂城镇/' 'raw_data/bmes_train_corrected.json'

'raw_data/bmes_train.json' -> 'raw_data/bmes_train_corrected.json'


In [3]:
label_map = {
    '河流': 'RIV',
    '地区': 'LOC',
    '水库': 'RES',
    '水利术语': 'TER',
}


with open('raw_data/bmes_train_corrected.json') as f:
    lines = f.read().split('\n')
    

correct_lines = list()

for line in lines:
    if re.search('"text":', line):  # 文本
        line = re.sub(r'#[^\*]+\*\[@\]', '', line)
        
    if re.search(r'^\s+"[^a-z]+.*\*\[@\]', line):   # 标签
        line = line.replace('*[@]', '","')
        line = line.replace('#', '-')
        for label in label_map.keys():
            line = line.replace(label, label_map[label]) 
    
    correct_lines.append(line)
    
    
with open('raw_data/bmes_train_corrected.json', 'w') as f:
    f.write("\n".join(correct_lines))

# 文本预处理

In [4]:
train = pd.read_json('raw_data/bmes_train_corrected.json')
train.head()

Unnamed: 0,id,text,entities
0,train_0,1976年冬，配合兴建涡河闸水利工程，开挖了涡河引河，全长1.2公里。,[涡河-RIV]
1,train_1,宋代，甓社湖曾现珠光（河蚌珠光），并为在甓社湖居室临窗夜读的著名学者孙觉亲眼所见，被其好友沈...,"[珠湖-LAK, 高邮湖-LAK, 甓社湖-LAK]"
2,train_2,2008年5月10日上午，山东省第五座跨黄河大桥，黄河首座———济南建邦黄河大桥在济南西外环...,"[黄河-RIV, 济南-LOC, 山东省-LOC]"
3,train_3,学校曾获国家及省市级“职业教育先进单位”称号，是全国计算机应用技术考试培训基地（NIT），安...,"[安徽省-LOC, 太湖县-LOC]"
4,train_4,东西两源汇合后，进入平原区，北流经过石埠嘴、船涨埠，至白洋淀后进入瓦埠湖。,"[白洋淀-LAK, 瓦埠湖-LAK]"


In [5]:
train_data = list()

for i, row in tqdm(train.iterrows()):
    entities = row['entities']
    text = row['text']
    id_list = [i] * len(text)
    label_list = ["O"] * len(text)
    finished = list()
    for entity in entities:
        name, label = entity.split('-')
        for m in re.finditer(name, text):     # 找到字符串的起始位置和它们所属的标签
            start = m.start()
            end = m.end()
            if start not in finished:
                label_list[start] = f'B_{label}'
                finished.append(start)
                for i in range(start+1, end):
                    if i not in finished:
                        label_list[i] = f'I_{label}'
                        finished.append(i)
    train_data.extend(zip(id_list, text, label_list))

4919it [00:01, 3157.75it/s]


In [6]:
train_data = pd.DataFrame(
    train_data, columns=["sentence_id", "words", "labels"]
)

train_data.head(10)

Unnamed: 0,sentence_id,words,labels
0,0,1,O
1,0,9,O
2,0,7,O
3,0,6,O
4,0,年,O
5,0,冬,O
6,0,，,O
7,0,配,O
8,0,合,O
9,0,兴,O


In [8]:
# train_data[train_data.sentence_id==2354]

In [9]:
print(train_data.shape, train_data.sentence_id.nunique())

fold = 0
eval_data = train_data[train_data['sentence_id'] % 10 == fold]
train_data = train_data[train_data['sentence_id'] % 10 != fold]

print(train_data.shape, eval_data.shape)

(263624, 3) 4919
(237503, 3) (26121, 3)


# 模型训练

In [10]:
model_args = NERArgs()
model_args.train_batch_size = 8
model_args.num_train_epochs = 8
model_args.fp16 = False
model_args.evaluate_during_training = True
model_args.overwrite_output_dir = True
model_args.save_steps = -1
model_args.save_eval_checkpoints = False

In [11]:
model = NERModel("bert", 
                 "hfl/chinese-roberta-wwm-ext",
                 labels=train_data.labels.unique().tolist(),
                 args=model_args)

Some weights of the model checkpoint at hfl/chinese-roberta-wwm-ext were not used when initializing BertForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPretraining model).
- This IS NOT expected if you are initializing BertForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForTokenClassification were not initialized from the model checkpoint at h

In [12]:
model.train_model(train_data, eval_data=eval_data)

HBox(children=(FloatProgress(value=0.0, max=4427.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Epoch', max=8.0, style=ProgressStyle(description_width='i…

HBox(children=(FloatProgress(value=0.0, description='Running Epoch 0 of 8', max=554.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 1 of 8', max=554.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 2 of 8', max=554.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 3 of 8', max=554.0, style=ProgressStyle(des…

HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…





HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 4 of 8', max=554.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 5 of 8', max=554.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 6 of 8', max=554.0, style=ProgressStyle(des…




HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




HBox(children=(FloatProgress(value=0.0, description='Running Epoch 7 of 8', max=554.0, style=ProgressStyle(des…

HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…





HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…





(4432,
 {'global_step': [554, 1108, 1662, 2000, 2216, 2770, 3324, 3878, 4000, 4432],
  'precision': [0.7824257425742575,
   0.8135056713268267,
   0.7879464285714286,
   0.7716972034715526,
   0.7897943640517898,
   0.7923910304862686,
   0.7997442455242967,
   0.7842364532019704,
   0.7814261210487625,
   0.8137513171759747],
  'recall': [0.8375728669846317,
   0.8171701112877583,
   0.8418124006359301,
   0.8481717011128775,
   0.8243243243243243,
   0.8333333333333334,
   0.8285638579756227,
   0.8436671966083731,
   0.8449920508744038,
   0.8184949655537891],
  'f1_score': [0.8090606603532121,
   0.8153337739590218,
   0.813989239046887,
   0.8081292602878062,
   0.806690003889537,
   0.8123466356709286,
   0.8138990109318064,
   0.8128669900434006,
   0.8119669000636537,
   0.8161162483487451],
  'train_loss': [0.41048386693000793,
   0.28162282705307007,
   0.06361792981624603,
   0.12904062867164612,
   0.022538699209690094,
   0.008450004272162914,
   0.01704234629869461,
   0.

In [13]:
result, _, _ = model.eval_model(eval_data)
result

HBox(children=(FloatProgress(value=0.0, max=492.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Evaluation', max=62.0, style=ProgressStyle(descri…




{'eval_loss': 0.33824595208129576,
 'precision': 0.8137513171759747,
 'recall': 0.8184949655537891,
 'f1_score': 0.8161162483487451}

# 模型预测及生成提交文件

In [14]:
test = pd.read_json('raw_data/bmes_test.json')
test.head()

Unnamed: 0,id,text
0,test_0,2013年有1个项目获得国家创新基金，2个项目获国家科技型中小企业创新基金支持，1家民营企业...
1,test_1,汉江遥堤堤防防洪标准为防御汉江1964年型洪水，堤顶高程按设计洪水位加超高2米确定，面宽10...
2,test_2,纳支流七里川河后称红岩河，主流经嘴头镇、白云乡、王家凌乡，于王家凌乡之擂鼓滩出本县境入留坝县。
3,test_3,云山水库，位于七虎林河上游，1958年由复转官兵承建，水库的蓄水量为4750万立方米，控制了...
4,test_4,武穴正处武山湖深水口古北江遗址上，原称青林湖，古长江九大要穴之一。


In [15]:
test_data = list()

for i, row in tqdm(test.iterrows()):
    text = row['text']
    test_data.append(text)
    
preds, raw_outputs = model.predict(test_data, split_on_space=False)

628it [00:00, 9060.22it/s]


HBox(children=(FloatProgress(value=0.0, max=628.0), HTML(value='')))




HBox(children=(FloatProgress(value=0.0, description='Running Prediction', max=79.0, style=ProgressStyle(descri…




In [16]:
sub = pd.read_json('raw_data/bmes_sample.json')
sub.head()

Unnamed: 0,id,entities
0,test_0,"[20-LOC, 称号-TER]"
1,test_1,"[汉江-LOC, :3-TER]"
2,test_2,"[纳支-LOC, 坝县-TER]"
3,test_3,"[云山-LOC, 水量-TER]"
4,test_4,"[武穴-LOC, 之一-TER]"


In [17]:
final = list()

for pred in tqdm(preds):
    res_list = list() 
    for char in pred:
        if list(char.values())[0] != 'O':
            res_list.extend(zip(char.keys(), char.values()))
            chars = [i[0] for i in res_list]
            labels = [i[1] for i in res_list]
            words = list()
            word = ''
            for i, label in enumerate(labels):
                if label.startswith('B_'):
                    if i > 0:
                        lbl = labels[i-1].replace('I_', '')
                        words.append(f'{word}-{lbl}')
                    word = chars[i]
                else:
                    word += chars[i]
                    if i == len(labels) - 1:
                        lbl = labels[i].replace('I_', '')
                        words.append(f'{word}-{lbl}')
            ret = list()
            for item in words:
                if item not in ret:
                    ret.append(item)
    final.append(ret)

100%|██████████| 628/628 [00:00<00:00, 6863.85it/s]


In [18]:
sub['entities'] = final
sub.head()

Unnamed: 0,id,entities
0,test_0,"[临河区-LOC, 巴彦淖尔市-LOC]"
1,test_1,[汉江-RIV]
2,test_2,"[七里川河-RIV, 红岩河-RIV, 嘴头镇-LOC, 白云乡-LOC, 王家凌乡-LOC..."
3,test_3,"[云山水库-RES, 七虎林河-RIV, 上游-TER]"
4,test_4,"[武山湖-LAK, 古-B_RIV, 北江-RIV, 青林湖-LAK, 长江-RIV]"


In [19]:
sub.to_json('baseline.json', orient='records', force_ascii=False)