In [None]:
import warnings
# Silences every warning process-wide (including library deprecation notices).
# NOTE(review): consider narrowing this with a category/module filter.
warnings.simplefilter('ignore')

import logging

import re

import pandas as pd
# Use fully-qualified option names: since pandas 1.4 the short pattern 'max_rows'
# also matches 'styler.render.max_rows' and set_option raises OptionError.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_colwidth', 100)

from tqdm import tqdm
# Enables tqdm's pandas integration.
# NOTE(review): none of the visible cells use the pandas progress methods,
# so this call may be removable.
tqdm.pandas()

# Token-classification model and its configuration object (used in the training cells below).
from simpletransformers.ner import NERModel, NERArgs

In [None]:
# Raw training corpus; its `text` column is parsed later as whitespace-separated `word/tag` tokens.
train = pd.read_csv(filepath_or_buffer='raw_data/train.csv')
train.head(n=5)

In [None]:
# Substring rewrites applied to both train and test text before tagging/prediction:
# several (apparently compound) medicine names are collapsed to a shorter form.
# The dict is iterated in insertion order when the rewrites are applied.
clean_medicine = {
    **dict.fromkeys(('甲霜锰锌', '烯酰锰锌', '霜脲锰锌', '恶霜锰锌'), "锰锌"),
    '春雷王铜': "王铜",
    '阿维哒螨灵': "哒螨灵",
    '苯甲丙环唑': "丙环唑",
}

In [None]:
# Tag suffixes that mark entities in the raw `word/tag` tokens, mapped to the
# entity kind used in the B_/I_/E_ label names.
_ENTITY_TAGS = {'n_crop': 'crop', 'n_disease': 'disease', 'n_medicine': 'medicine'}


def _char_rows(sentence_id, word, tag):
    """Return the [sentence_id, char, label] rows for one `word/tag` token.

    Non-entity tags give one 'O' row per character. Entity tags give
    B_<kind> for the first character, I_<kind> for any middle characters and
    E_<kind> for the last one.
    """
    kind = _ENTITY_TAGS.get(tag)
    if kind is None:
        return [[sentence_id, c, 'O'] for c in word]
    if len(word) < 2:
        # The B_/I_/E_ scheme has no single-character tag, so such entities are
        # reported and dropped (their characters do not appear in the sentence).
        print(f"word len < 2: {word}")
        return []
    return (
        [[sentence_id, word[0], f'B_{kind}']]
        + [[sentence_id, c, f'I_{kind}'] for c in word[1:-1]]
        + [[sentence_id, word[-1], f'E_{kind}']]
    )


# Flat list of [sentence_id, char, label] rows, one per kept character.
train_data = list()

for id_, row in tqdm(train.iterrows()):
    # Strip digits, punctuation and a few other characters, then normalise
    # medicine names exactly as the test-time cleanup does.
    text = re.sub(r"[-\d\.%·\+，。％一＋]", "", row['text'])
    text = text.replace('多   /w', '多/w')
    for src, dst in clean_medicine.items():
        text = text.replace(src, dst)

    # str.split() already removes every whitespace character, so items need no
    # further space stripping.
    for item in text.split():
        try:
            w, lbl = item.split('/')
        except ValueError:
            # Token without exactly one '/': drop anything that looks like a tag
            # (tags may contain '_', e.g. n_crop) and keep the rest as plain text.
            item = re.sub(r'/[a-z_]+', '', item)
            train_data.extend([id_, c, 'O'] for c in item)
            continue
        train_data.extend(_char_rows(id_, w, lbl))

In [None]:
# Turn the collected rows into a long-format frame: one row per character,
# with the column names that are later handed to NERModel.train_model.
train_data = pd.DataFrame(data=train_data, columns=['sentence_id', 'words', 'labels'])
train_data.head(n=5)

In [None]:
train_data['labels'].value_counts()  # per-label character counts (class balance check)

In [None]:
# Label set for the B_/I_/E_ character scheme: one triple per entity kind,
# followed by 'O' for characters outside any entity.
labels = [
    f'{position}_{kind}'
    for kind in ('crop', 'disease', 'medicine')
    for position in ('B', 'I', 'E')
] + ['O']

In [None]:
# Hold out the last 300 source sentences (by sentence_id) for evaluation.
eval_data = train_data.loc[train_data['sentence_id'].ge(len(train) - 300)]
eval_data.head(n=5)

In [None]:
# Keep everything before the held-out tail as the training split.
train_data = train_data.loc[train_data['sentence_id'].lt(len(train) - 300)]

(train_data.shape, eval_data.shape)

In [None]:
# Training configuration for the NER model.
model_args = NERArgs()
model_args.train_batch_size = 8
model_args.num_train_epochs = 5
model_args.fp16 = False
# eval_data is passed to train_model, so it is scored during training.
model_args.evaluate_during_training = True
# Fix the random seed so repeated runs are reproducible.
model_args.manual_seed = 42
# train_model refuses to write into a non-empty output directory unless this is
# set, which breaks Restart & Run All after the first run.
# NOTE: existing files in the output directory are overwritten.
model_args.overwrite_output_dir = True

In [None]:
import torch

# BERT token classifier initialised from the hfl/chinese-bert-wwm-ext checkpoint,
# using the label set and training arguments defined above.
# use_cuda follows GPU availability: NERModel raises a ValueError when
# use_cuda=True (its default) and CUDA is unavailable.
model = NERModel("bert",
                 "hfl/chinese-bert-wwm-ext",
                 labels=labels,
                 args=model_args,
                 use_cuda=torch.cuda.is_available())

In [None]:
model.train_model(train_data=train_data, eval_data=eval_data)  # fine-tune; eval split scored during training

In [None]:
# Final evaluation on the held-out split; `result` holds the summary metrics.
(result, model_outputs, preds_list) = model.eval_model(eval_data=eval_data)
result

In [None]:
# Raw test sentences to tag; only the `text` column is used below.
test = pd.read_csv(filepath_or_buffer='raw_data/test.csv')
test.head(n=5)

In [None]:
(len(test.index), len(test.columns))  # (rows, columns) of the test set

In [None]:
def _clean_test_text(raw_text):
    """Normalise one raw test sentence before prediction.

    Same cleanup as the training cell, except that ASCII spaces are stripped
    too, so the model receives one unbroken character sequence.
    """
    text = re.sub(r"[-\d\.%·\+，。％一＋ ]", "", raw_text)
    text = text.replace('多   /w', '多/w')
    for src, dst in clean_medicine.items():
        text = text.replace(src, dst)
    return text


def _decode_bie(char_labels, kinds=('crop', 'disease', 'medicine')):
    """Collect entity strings from a list of (char, label) pairs.

    A span starts at a `B_<kind>` label and is emitted when the matching
    `E_<kind>` label is reached. `I_<kind>` and `O` labels in between are
    absorbed into the span (stray `O`s are tolerated). Any other label before
    the `E_<kind>` abandons the span, so emitted entities are always contiguous
    characters. Spans that never reach their `E_<kind>` are dropped.

    Returns a dict mapping each kind to its entity strings, in text order.
    """
    found = {kind: [] for kind in kinds}
    n_chars = len(char_labels)
    for start in range(n_chars):
        first_char, first_label = char_labels[start]
        kind = first_label[2:] if first_label.startswith('B_') else None
        if kind not in found:
            continue
        span = [first_char]
        for pos in range(start + 1, n_chars):
            ch, label = char_labels[pos]
            if label == f'I_{kind}' or label == 'O':
                span.append(ch)
            elif label == f'E_{kind}':
                span.append(ch)
                found[kind].append("".join(span))
                break
            else:
                # A label from another entity/position interrupts the span.
                break
    return found


# One [id, crops, diseases, medicines] row per test sentence.
test_data = list()

# Clean all sentences first, then predict them in a single call; NERModel batches
# internally, which is far faster than one predict() call per row.
test_texts = [_clean_test_text(raw) for raw in tqdm(test['text'])]
test_preds, _ = model.predict(test_texts, split_on_space=False)

# NOTE(review): the submission id is the DataFrame index (as before); if test.csv
# carries its own id column, that may be what should be written — confirm.
for id_, sentence_preds in zip(test.index, test_preds):
    # Each prediction is a list of {char: label} dicts; flatten to (char, label) pairs.
    char_labels = [pair for token in sentence_preds for pair in token.items()]
    entities = _decode_bie(char_labels)
    test_data.append([id_, entities['crop'], entities['disease'], entities['medicine']])

In [None]:
# One row per test sentence; each entity column holds the list of extracted strings.
test_data = pd.DataFrame(data=test_data, columns=['id', 'n_crop', 'n_disease', 'n_medicine'])

test_data.head(n=10)

In [None]:
# Entity lists are written via their string representation, without the index column.
test_data.to_csv(path_or_buf='submission.csv', index=False)