In [51]:
import csv
import re
from transformers import BertTokenizer
import torch
import torch.nn as nn

In [52]:
class model_crf(nn.Module):
    def __init__(self, n_tags, hidden_dim=768, batchsize= 32):
        super(model_crf, self).__init__()
        self.n_tags = n_tags
        self.lstm =  nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=hidden_dim//2, dropout= 0.3, batch_first=True)		
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, self.n_tags)
        self.bert = BertModel.from_pretrained('bert-base-chinese')

        # for param in self.bert.parameters():
        # 	param.requires_grad = False
        # self.bert.eval()  # 知用来取bert embedding

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.CRF = CRF(n_tags, batch_first= True)
        self.hidden = self.init_hidden(batchsize)

    def init_hidden(self, batch_size):
        return (torch.randn(2*2, batch_size, self.hidden_dim // 2).to(self.device),
                torch.randn(2*2, batch_size, self.hidden_dim // 2).to(self.device))

    def forward(self, input_ids, attention_mask, tags):

        batch_size = input_ids.size(0)
        max_seq_len = input_ids.size(1)
        bert_output, _  = self.bert(input_ids.long(), attention_mask)

        seq_len = torch.sum(attention_mask, dim= 1).cpu().int()
        # print(seq_len)
        pack_input = pack_padded_sequence(input= bert_output, lengths= seq_len, batch_first= True, enforce_sorted= False)
        packed_lstm_out, _ = self.lstm(pack_input,self.init_hidden(batch_size= batch_size))
        lstm_enc, _=  pad_packed_sequence(packed_lstm_out, batch_first=True, padding_value=0)
        # print(lstm_enc.size())
        lstm_feats = self.fc(lstm_enc)

        lstm_max_seq_len = lstm_feats.size(1)
        pad = torch.zeros(size=(batch_size, max_seq_len-lstm_max_seq_len, self.n_tags), dtype= torch.float).to(self.device)
        lstm_feats= torch.cat((lstm_feats, pad), dim= 1)

        lstm_feats[:,:,1:-1] = lstm_feats[:,:,1:-1]*100
        loss = -self.CRF(lstm_feats, tags, attention_mask.bool(), reduction= 'token_mean')
        pred_seqs = self.CRF.decode(emissions= lstm_feats, mask= attention_mask.bool())

        return loss, pred_seqs

In [2]:
with open('test_input.data', 'r', encoding= 'utf-8') as f:
    data = f.readlines()

print(data[:10])
    

['醫\n', '師\n', '：\n', '最\n', '近\n', '人\n', '有\n', '沒\n', '有\n', '什\n']


In [28]:
class test_output():
    def __init__(self, data):
        
        self.data_list = []
        self.word_id = []
        self.word_article_id = [] 
        article_id = 0
        word_id = 0
        data_list_tmp = []
        article_id_tmp = []
        word_id_tmp = []
        
        for row in data:
            
            data_tuple = tuple()
            if row == '\n':
                
                article_id += 1 
                word_id = 0
                self.word_id.append(word_id)
                self.word_article_id.append(article_id_tmp)
                self.data_list.append(data_list_tmp)
                data_list_tmp = []
                article_id_tmp = []
                word_id_tmp = []

            else:
                
                row = row.strip('\n').split(' ')

                if row[0] in ['。', '？','！','，','～','：',',','‧']:
                    
                    self.word_id.append(word_id_tmp)
                    self.word_article_id.append(article_id_tmp)
                    self.data_list.append(data_list_tmp)
                    data_list_tmp = []
                    article_id_tmp = []
                    word_id_tmp = []
                    
                elif row[0] in ['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']:
                      
                    data_tuple = (row[0].lower(), article_id, word_id)
                    data_list_tmp.append(data_tuple)
                    article_id_tmp.append(article_id)
                    word_id_tmp.append(word_id)

                elif row[0] not in ['摁','嗯','啦','喔','欸','啊','齁','嘿','…','...']:
                    
                    data_tuple = (row[0], article_id, word_id)
                    data_list_tmp.append(data_tuple)
                    article_id_tmp.append(article_id)
                    word_id_tmp.append(word_id)
                    
                word_id += 1
                
        if len(data_list_tmp) != 0:
            self.data_list.append(data_list_tmp)
            self.word_id.append(word_id_tmp)
            self.word_article_id.append(article_id_tmp)
            
    def raw_output(self):
        return self.data_list, self.word_id, self.word_article_id
    
    def get_stcs(self):
        
        all_stcs = list()
        all_article_ids = list()
        all_word_ids = list()

        for stc_list in self.data_list:

            txt_len = len(stc_list) #(文章數，每個文章對應的總字數) (word, label)
            stc = str() #存字數= max_stc_len的字串
            article_ids = []
            word_ids = []
            

            for idx, (word,article_id, word_id) in enumerate(stc_list):

                stc += word
                article_ids.append(article_id)
                word_ids.append(word_id)

            all_stcs.append(stc)
            all_article_ids.append(article_ids)
            all_word_ids.append(word_ids)
        
        assert len(all_stcs) > 0, 'all stcs len = 0' 
        
        all_stcs_clean = []
        all_article_ids_clean = []
        all_word_ids_clean = []
        idx = 0
        
        for stc, article_id, word_id in zip(all_stcs, all_article_ids, all_word_ids):
            stc_clean = re.sub(r'(醫師)|(個管師)|(民眾)|(家屬)|(護理師)', '', stc)
            # print(stc, stc_clean, label)
            if len(stc_clean) > 1:  
            # print(stc_clean, stc)
                all_stcs_clean.append(stc)
                all_article_ids_clean.append(article_id)
                all_word_ids_clean.append(word_id)

            # 這一步就先把label 做 0 padding
            
        max_length = len(max(all_stcs_clean, key=len))
        assert max_length > 0, 'max length less than 1'

        print('sentences總數: {}'.format(len(all_stcs_clean)))
            
        return all_stcs_clean, all_article_ids_clean, all_word_ids_clean

In [29]:
stcs, article_ids, word_ids = test_output(data= data).get_stcs()

sentences總數: 24302


In [32]:
clean_stcs, clean_article_id, clean_word_id = [], [] ,[]

for stc, article_id, word_id in zip(stcs, article_ids, word_ids):
#     print(stc, article_id, word_id)
    if stc not in ['沒有','也沒有','哪個','那個','算了','不用','有','有有有','有嗎','一點點', '謝謝','不會','不好意思','對不對','好不好','要嗎','還好']:
        clean_stcs.append(stc)
        clean_article_id.append(article_id)
        clean_word_id.append(word_id)
print(len(clean_stcs))

23673


In [21]:
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [18]:
max_len = max(len(txt) for txt in clean_stcs)
print(max_len)
print(len(clean_stcs))

51


In [22]:
encoding = tokenizer.batch_encode_plus(clean_stcs, 
                                 padding=True,
                                 add_special_tokens=False,
                                 return_attention_mask= True,
                                 return_token_type_ids= False,
                                 is_split_into_words=True,
                                 return_tensors='pt')

In [23]:
encoding['input_ids'].size()

torch.Size([23935, 51])

In [24]:
encoding['attention_mask'].size()

torch.Size([23935, 51])

In [25]:
encoding['attention_mask'][:2]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0]])

In [None]:
# loss, preds = model(input_ids, 

-----

# 輸出成csv
1. 需要id to tag dictionary
2. 對每個預測句子做iteration
3. pred : [句子數, 句子長度]

In [None]:
model = torch.load('ner_model_batch32_wup4000_lstmhd256_lr5e-5_40epoch_adam_wde-3.pt')

In [26]:
entity_text = []
for stc, labels, article_id, word_id in zip(clean_stcs, pred_labels, clean_article_id, clean_word_id):
    for label in labels:
        if bool(re.match(r'B-', label)):
            

['最近人有沒有什麼不舒服',
 '我們本來說要月',
 '年底才打對不對',
 '你是怕',
 '因為我要治療牙',
 '開始要治療牙齒了',
 '我又不常跑醫院',
 '我想說那先來打',
 '我今天來領藥嘛',
 '那順便過來打',
 '是是是',
 '你是要做',
 '植牙',
 '植牙的',
 '對對對',
 '你的免疫力是夠',
 '是沒有問題的',
 '在我們這裡做的嗎',
 '在新樓做的',
 '那這樣離你住的地方也有點距離',
 '你是住麻豆嗎',
 '住麻豆',
 '那辛苦了',
 '那我們就去',
 '等下就去打針這樣子',
 '沒什麼不舒服',
 '最近',
 '都沒有',
 '最近',
 '今年都是打四價的',
 '各個廠牌都有',
 '我覺得都沒差',
 '反正我也都沒在挑的',
 '沒關係',
 '我不知道',
 '這個要醫師',
 '應該都ok',
 '都還好',
 '都很安全',
 '這個你比較專業的',
 '血壓這樣ok',
 '不過可能打針要等一下',
 '因為最近打的人很多',
 '很踴躍',
 '就聽說年底可能會打不到',
 '會這樣',
 '有可能',
 '假如拖到年底可能會打不完',
 '那這樣我們這邊就結束了',
 '要去上課嗎',
 '這學期已經開始不上課了',
 '今天的',
 '應該是昨天的抽血',
 '肝、腎功能都很棒',
 '肌酐酸',
 '0.79',
 '換算成腎絲球過濾率預測值是88',
 '這個通常60以上就夠了',
 '60',
 '這個都很讚',
 '謝謝',
 'ast/alt就是我們常講的got還有gpt',
 '分別是20跟19',
 '那也是很理想',
 'crp就是那個',
 '我們那個發炎',
 '體內發炎的時候',
 '肝臟會分泌這個c反應蛋白',
 '這個也都2.6',
 '是大概也是ok',
 'ok',
 '那我們接下來就是等',
 '我們就等那個民眾',
 '超音波',
 '超音波做完',
 '再來停藥這樣子',
 '那個糖尿病',
 '那個血糖那個醫師',
 '血糖',
 '黃醫師說現在6.6',
 '那這樣不錯',
 '6.6不錯',
 '那時候8.3',
 '我記得那時候8.3',
 '四個月前8.3',
 '所以這樣吃應該可以',
 '可以',
 '應該可以',
 '

In [45]:
labels = ['O','O','B-','I-','I-','O']

In [50]:
entity = []
temp= []
last_temp = []

for x in labels:
    if x == 'O':
        if last_temp[-1] == 'I':
            entity.append(temp)
            last_temp.append(x)
        else:
            temp = []
            last_temp.append(x)
        
    elif bool(rematch(r'(B-)', x)):
        temp.append(x)
        last_temp.append(x)
    elif (bool(rematch(r'(I-)',x)))&(last_temp[-1] != 'O'):
        temp.append(x)
        last_temp.append(x)

IndexError: list index out of range

In [44]:
bool(re.match(r'(B-)|(I-)', 'B-orga')

True