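1. The `threshold` parameter (0% to 100%) controls whether a predicted label is accepted as plausible; the default is 60%, i.e. 0.6.
2. The drug-name input has a maximum length of 64 (tokens, per `max_length=64` in `predict`); anything beyond that is truncated.
3. Only a single drug name per input is supported.
4. Accuracy is low when a typo falls on a key character of the drug name or when there are several typos; swapped character order has little effect.
5. When the output is very close to the correct class but still wrong, adjust `k` (1, 2, or 3) to get more candidates.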
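
The interaction between `threshold` and `k` in notes 1 and 5 can be sketched in plain Python. This is only an illustration, not the notebook's implementation: the helper name `top_k_labels`, the labels, and the scores are all made up, and the sketch assumes the model yields one confidence score per class.

```python
def top_k_labels(scores, labels, threshold=0.6, k=2):
    """Return up to k (label, score) pairs whose score reaches the threshold.

    `scores` and `labels` are hypothetical parallel lists. An empty result
    corresponds to the "other" fallback used when nothing is confident enough.
    """
    if k not in (1, 2, 3):
        raise ValueError("k must be 1, 2 or 3")
    # Rank class indices from the highest score to the lowest.
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    # Keep the best k candidates, then drop those below the threshold.
    return [(labels[i], scores[i]) for i in ranked[:k] if scores[i] >= threshold]


# Made-up scores for four made-up classes.
labels = ["A", "B", "C", "D"]
scores = [0.10, 0.78, 0.70, 0.30]

print(top_k_labels(scores, labels, threshold=0.6, k=1))
print(top_k_labels(scores, labels, threshold=0.6, k=2))
print(top_k_labels(scores, labels, threshold=0.8, k=2))
```

Raising `k` only surfaces extra candidates that also pass the threshold, so a higher `threshold` can still leave the second and third columns empty.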

In [1]:
# %pip install -r requirements.txt

In [2]:
import numpy as np
import pandas as pd
from sklearn.preprocessing import LabelEncoder
import torch
from transformers import AutoTokenizer
from utils.model import *

k = 2
threshold = 0.6
# source_max_length = 64

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# device = 'cpu'  # uncomment to force CPU

In [3]:
def predict(texts, model, tokenizer, threshold=0.6, k=1, guess=False):
    """Classify drug names and return a DataFrame with up to k candidates per input."""
    assert k in (1, 2, 3)
    model.eval()
    # Rebuild the id -> label mapping from the label file.
    atc_label = pd.read_csv('datasets/atc_label.csv')
    le = LabelEncoder()
    le.fit(atc_label.label)
    id_to_label = dict(zip(le.transform(le.classes_), le.classes_))
    cleaned = [text.strip().upper() for text in texts]
    res1, res2, res3 = [], [], []
    # Pad or truncate every input to 64 tokens.
    input_ids = tokenizer(cleaned, padding='max_length', truncation=True, max_length=64)['input_ids']
    with torch.no_grad():  # inference only, no gradients needed
        scores = model({'input_ids': torch.tensor(input_ids).to(device), 'mode': 'predict'}).cpu().numpy()
    for score in scores:
        if np.max(score) < threshold:
            # Nothing is confident enough: report '其它' ("other").
            # With guess=True the best label is still shown in the second column.
            res1.append('其它')
            res2.append(f'{id_to_label[np.argmax(score)]}' if guess else '')
            res3.append('')
        else:
            # argsort is ascending, so the last index is the best match.
            third, second, first = score.argsort()[-3:]
            res1.append(f'{id_to_label[first]} ({min(1., score[first])*100:.1f}%)')
            res2.append(f'{id_to_label[second]} ({min(1., score[second])*100:.1f}%)' if score[second] >= threshold else '')
            res3.append(f'{id_to_label[third]} ({min(1., score[third])*100:.1f}%)' if score[third] >= threshold else '')
    if k == 1:
        return pd.DataFrame({'text': texts, 'predicted': res1})
    elif k == 2:
        return pd.DataFrame({'text': texts, 'first': res1, 'second': res2})
    else:
        return pd.DataFrame({'text': texts, 'first': res1, 'second': res2, 'third': res3})
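
The tokenizer call in `predict` uses `padding='max_length'`, `truncation=True`, and `max_length=64`, so every input ends up as exactly 64 token ids. The effect can be sketched without a tokenizer, as below. The helper `pad_or_truncate` and the padding id `0` are assumptions for illustration only; a real tokenizer also handles special tokens when truncating, which this sketch ignores.

```python
def pad_or_truncate(token_ids, max_length=64, pad_id=0):
    """Force a list of token ids to exactly max_length entries.

    Hypothetical helper: pad_id=0 is an assumed padding id, and special
    tokens are not treated separately as a real tokenizer would.
    """
    clipped = token_ids[:max_length]  # drop everything past max_length
    return clipped + [pad_id] * (max_length - len(clipped))  # fill the rest


short_input = list(range(1, 11))   # 10 made-up token ids
long_input = list(range(1, 101))   # 100 made-up token ids

print(len(pad_or_truncate(short_input)), len(pad_or_truncate(long_input)))
```

Anything past position 64 is silently discarded, which is why trailing text in a very long input cannot influence the prediction.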

In [4]:
tokenizer = AutoTokenizer.from_pretrained('tokenizer')
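# map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
# Note: on PyTorch >= 2.6, loading a fully pickled model also needs weights_only=False.
model = torch.load('models/atc_model_20.pt', map_location=device).to(device)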

In [5]:
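# Read one drug name per line, skipping blank lines (e.g. from a trailing newline).
with open('input.txt', encoding='utf8') as f:
    texts = [line for line in f.read().split('\n') if line.strip()]
pred_info = predict(texts, model, tokenizer, threshold=threshold, k=k, guess=False)
pred_info.to_csv('output.csv', index=False, header=False)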
pred_info

Unnamed: 0,text,first,second
0,八益母胶囊（甲类）,八珍益母胶囊 (87.1%),
1,95%乙纯,乙醇 (60.4%),
2,6味地黄丸(ID=2002738663),六味地黄丸(水丸) (67.3%),
3,VC翘银片(基药),维C银翘片 (77.8%),银翘解毒丸 (69.5%)
4,醋钙酸片（的灵）,其它各类治疗用药品 (82.0%),
5,Y维生圣#素AD滴剂(乙)(限儿童)一岁以下,维生素A和维生素D的复方 (100.0%),
6,丁柜儿齐贴（丙),丁桂儿脐贴 (66.3%),
7,丁酸氢化可的松乳膏△[门特],丁酸氢化可的松 (97.7%),
8,(基)盐酸四环素醋酸可的松眼膏,四环素的复方 (74.2%),四环素 (73.9%)
9,复方颗粒,其它,


In [6]:
atc_data = pd.read_csv('datasets/atc_clean.csv')
sample = atc_data.sample(n=15)
texts = sample.source.tolist()
expected = sample.target.tolist()
pred_info = predict(texts, model, tokenizer, threshold=threshold, k=k)
pred_info['expected'] = expected
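# Strip the trailing " (xx.x%)" confidence, then compare the label with the expected target.
pred_info[''] = [l.rsplit(' (', 1)[0] for l in pred_info['first']] == pred_info['expected']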
pred_info

Unnamed: 0,text,first,second,expected,Unnamed: 5
0,硫酸镁注射液术中断脐后用,硫酸镁 (100.0%),,硫酸镁,True
1,四制艾叶,四制香附丸 (86.3%),,四制香附丸,True
2,坤泰胶囊乙类,坤泰胶囊 (97.2%),,坤泰胶囊,True
3,肝素钠乳膏,肝素 (100.0%),,肝素,True
4,多柔比星基本,多柔比星（阿霉素） (87.0%),,多柔比星（阿霉素）,True
5,人免疫球蛋白丙,标准的人免疫球蛋白 (87.0%),静注乙型肝炎人免疫球蛋白(pH4) (61.8%),标准的人免疫球蛋白,True
6,停银杏叶滴丸国,银杏叶丸 (88.0%),,银杏叶丸,True
7,卡铂注射液50mg支,卡铂 (100.0%),,卡铂,True
8,润肠胶囊,润肠胶囊 (83.4%),,润肠胶囊,True
9,泛昔洛韦片丽珠风片盒,泛昔洛韦 (97.2%),,泛昔洛韦,True
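
The per-row comparison above can be summarized as a single top-1 accuracy figure. The sketch below is one way to do that in plain Python, under the assumption that each cell in the `first` column is either a bare label or a label followed by a ` (xx.x%)` suffix. The helper names `strip_confidence` and `top1_accuracy` are hypothetical, and the third example row is made up to show a miss.

```python
def strip_confidence(cell):
    """Drop a trailing ' (xx.x%)' suffix, keeping parentheses inside the label."""
    return cell.rsplit(" (", 1)[0]


def top1_accuracy(predicted_cells, expected_labels):
    """Fraction of rows whose top prediction equals the expected label."""
    if len(predicted_cells) != len(expected_labels):
        raise ValueError("predictions and expectations must have equal length")
    if not expected_labels:
        raise ValueError("cannot compute accuracy on an empty sample")
    hits = sum(strip_confidence(p) == e
               for p, e in zip(predicted_cells, expected_labels))
    return hits / len(expected_labels)


# The first two rows mirror the output above; the third is a made-up miss.
predicted = ["硫酸镁 (100.0%)", "坤泰胶囊 (97.2%)", "其它"]
expected = ["硫酸镁", "坤泰胶囊", "卡铂"]

print(f"top-1 accuracy: {top1_accuracy(predicted, expected):.2f}")
```

Splitting on the last ` (` rather than on whitespace keeps labels such as `六味地黄丸(水丸)` intact.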
