In [1]:
import pandas as pd

def read_conll(filename):
    """Load a two-column, tab-separated CoNLL file as a token-level DataFrame.

    Every non-blank line is expected to hold ``word<TAB>label``; blank lines
    act as sentence separators.

    Parameters
    ----------
    filename : str or path-like
        Location of the CoNLL file to read.

    Returns
    -------
    pandas.DataFrame
        One row per token with the columns ``words``, ``labels`` and
        ``sentence_id``. The blank separator rows are removed; the row index
        of the remaining rows is left as read from the file.
    """
    import csv  # local import: only the quoting constant is needed

    df = pd.read_csv(filename,
                     sep='\t', header=None, keep_default_na=False,
                     names=['words', 'labels'], skip_blank_lines=False,
                     # CoNLL files are not CSV-quoted: a token that is a lone
                     # '"' must be read as data instead of opening a quoted
                     # field that swallows the following lines.
                     quoting=csv.QUOTE_NONE)
    # Each blank separator line increments the running counter, so all tokens
    # between two separators share the same sentence id.
    df['sentence_id'] = (df.words == '').cumsum()
    return df[df.words != '']

In [2]:
# Training split: a plain CSV (presumably already carrying the
# words / labels / sentence_id columns -- TODO confirm against the file).
train_data = pd.read_csv(
    '../Datasets/final_version_dataset/train_data.csv'
)
# Test split: stored in CoNLL layout, so it goes through read_conll.
test_data = read_conll(
    '../Datasets/final_version_dataset/test_data.txt'
)

In [3]:
def ner_id2tag(id):
    """Translate a numeric NER class id (0-6) into its BIO tag string.

    Raises ``KeyError`` when ``id`` is not one of the known class ids.
    """
    # The position of a tag in this list is its class id.
    tag_names = ['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']
    lookup = dict(enumerate(tag_names))
    return lookup[id]

In [4]:
def convert_to_conll_format(csv_file, start_index=17544, sentence_id_offset=46364,
                            output_file='../Datasets/msra_dataset/conll_msra_train.csv'):
    """Append the sentences of ``csv_file`` to a token-per-row CSV.

    For every processed row of the input, one output line
    ``word,label,sentence_id`` is produced per token, with the numeric tag
    translated through ``ner_id2tag``.

    Parameters
    ----------
    csv_file : str or path-like
        Input CSV with the columns ``tokens`` and ``ner_tags``.
        ``tokens`` must be the text of a Python list literal whose first
        element is the token sequence; ``ner_tags`` must become a valid list
        literal once its spaces are replaced by commas (e.g. ``[0 1 2]``).
    start_index : int, default 17544
        First input row to convert. The default keeps the previously
        hard-coded value (presumably the resume point of an earlier run --
        TODO confirm).
    sentence_id_offset : int, default 46364
        Added to the row number to form ``sentence_id``. The default keeps
        the previously hard-coded value.
    output_file : str or path-like
        CSV that the rows are appended to (no header is written).

    Raises
    ------
    IndexError
        If a row has fewer tags than tokens.
    """
    import ast  # local import: literal_eval parses the cells without running code

    dataframe = pd.read_csv(csv_file)
    tokens = dataframe['tokens']
    ner_tags = dataframe['ner_tags']

    rows = []
    for i in range(start_index, len(dataframe)):
        # Parse each cell once per sentence; literal_eval only accepts
        # literals, so file contents can never execute arbitrary code.
        words = ast.literal_eval(tokens[i])[0]
        tag_ids = ast.literal_eval(ner_tags[i].replace(' ', ','))
        if len(tag_ids) < len(words):
            raise IndexError(
                f'row {i}: {len(words)} tokens but only {len(tag_ids)} tags')

        sentence_id = i + sentence_id_offset
        rows.extend((word, ner_id2tag(tag_ids[position]), sentence_id)
                    for position, word in enumerate(words))

        if not (i % 100):
            print(f'current at index {i}')

    # One append for the whole batch instead of re-opening the file per token.
    if rows:
        output = pd.DataFrame(rows, columns=['words', 'labels', 'sentence_id'])
        output.to_csv(output_file, mode='a', header=False,
                      index=False, encoding='utf-8')

In [None]:
# Number of distinct sentences in each split, shown as a one-row table.
data = [[split['sentence_id'].nunique() for split in (train_data, test_data)]]
pd.DataFrame(data, columns=["Train", "Test"])

In [6]:
# Option overrides handed to the NER model through its ``args`` argument.
train_args = dict(
    reprocess_input_data=True,
    overwrite_output_dir=True,
    sliding_window=True,
    max_seq_length=64,
    num_train_epochs=15,
    train_batch_size=32,
    fp16=True,
    output_dir='/outputs/',
)

In [7]:
# Label inventory: a B-/I- pair for every entity type, in this fixed order,
# followed by the outside tag 'O'.
custom_label = [
    f'{prefix}-{entity}'
    for entity in (
        'GPE', 'PER', 'DATE', 'ORG', 'CARDINAL', 'NORP', 'LOC', 'TIME', 'FAC',
        'MONEY', 'ORDINAL', 'EVENT', 'WFA', 'QUANTITY', 'PERCENT', 'LANGUAGE',
        'PRODUCT', 'LAW',
    )
    for prefix in ('B', 'I')
] + ['O']

In [None]:
from simpletransformers.ner import NERModel
from transformers import AutoTokenizer
import pandas as pd
import logging

# Verbose logging for the notebook itself, but keep the 'transformers'
# logger at WARNING so its messages do not flood the cell output.
logging.basicConfig(level=logging.DEBUG)
transformers_logger = logging.getLogger('transformers')
transformers_logger.setLevel(logging.WARNING)

# Tokenizer comes from the 'bert-base-chinese' hub checkpoint; the NER model
# itself is restored from a local fine-tuned checkpoint directory.
# NOTE(review): `tokenizer` is not used anywhere in the visible cells, and
# `custom_label` is not passed to NERModel -- presumably the label list is
# restored together with the checkpoint; confirm.
tokenizer = AutoTokenizer.from_pretrained('bert-base-chinese')
model = NERModel('bert', '../model_output/checkpoint-35835-epoch-15', args=train_args)

# Training is disabled here because an already-trained checkpoint is loaded
# above. There is no development or validation set for this dataset; see
# https://simpletransformers.ai/docs/tips-and-tricks/#using-early-stopping

#model.train_model(train_data, output_dir='../model_output')

# Evaluate the loaded model on the test split.
# NOTE(review): `result` presumably holds the evaluation metrics reported by
# simpletransformers rather than a plain accuracy score -- confirm.
result, model_outputs, preds_list = model.eval_model(test_data)

In [None]:
# Sample news text used for a qualitative check of the model's predictions.
strs = """曾有兩次酒駕紀錄的台北市66歲張姓計程車司機開車時沒繫安全帶，被員警發現後上前攔停，但張男不但沒有配合受檢，反而腳踩油門加速逃逸，瘋狂闖紅燈逃逸，最後又棄車徒步逃逸。員警通報線上警網攔截圍捕，在捷運善導寺站4號出口處將張姓運將攔下，經酒測0.7毫克嚴重超標，全案依公共危險罪嫌移送台北地檢署偵辦。另外，依據《道路交通管理處罰條例》新規定，張男遭撤銷執業登記，要等12年才能再次開車謀生。轄區中正一分局介壽派出所員警上月20日凌晨12點巡邏時，見計程車駕駛駕車未繫安全帶，便鳴笛示意駕駛靠邊停車，但張男完全不理會警方攔查，從衡陽路右轉中華路，之後連闖3個紅燈，在大街小巷亂竄，企圖擺脫員警追緝。員警尾隨在後，擔心強硬追車恐造成其他用路人危險，先停車通報線上警網實施攔截圍捕。此時張男逆向闖入忠孝西路巷弄內行駛約5公尺並趁機關閉大燈，最後將計程車停在青島東路路邊停車格後，與同車友人裝沒事下車步行逃跑。"""
# Split into sentences on the full stop. The text ends with '。', so
# str.split() returns a trailing empty string; empty fragments are dropped so
# that no empty sequence is sent to the model or printed as a blank entry.
samples = [sentence for sentence in strs.split('。') if sentence.strip()]
# split_on_space=False: the text has no spaces, each sample is consumed
# character by character.
predictions, _ = model.predict(samples, split_on_space=False)
for idx in range(len(samples)):
    print('{}: '.format(idx))
    for word in predictions[idx]:
        print('{}'.format(word))