In [1]:
pip install pytorch-crf



In [2]:
pip install transformers==3



In [3]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
import sys
sys.path.insert(0,"/content/drive/My Drive/python檔/aicup/run")
sys.path.insert(0,"/content/drive/My Drive/python檔/aicup")

In [5]:
import csv
import re
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torchcrf import CRF
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Functions

In [6]:
class model_crf(nn.Module):
	def __init__(self, n_tags, hidden_dim=256, batchsize= 64, num_layers= 1, lstm_dropout= 0, fc_dropout= 0.2):
		super(model_crf, self).__init__()
		self.num_layers = num_layers
		self.n_tags = n_tags
		self.lstm =  nn.LSTM(bidirectional=True, num_layers=num_layers, input_size=768, hidden_size=hidden_dim//2, dropout= lstm_dropout, batch_first=True)		
		self.hidden_dim = hidden_dim
		self.fc = nn.Linear(hidden_dim, self.n_tags)
		self.bert = BertModel.from_pretrained('bert-base-chinese')

		# for param in self.bert.parameters():
		# 	param.requires_grad = False
		# self.bert.eval()  # 知用来取bert embedding

		self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
		self.CRF = CRF(n_tags, batch_first= True)
		self.dropout = nn.Dropout(p= fc_dropout)
		self.hidden = self.init_hidden(batchsize)

	def init_hidden(self, batch_size):
		return (torch.randn(2*self.num_layers, batch_size, self.hidden_dim // 2).to(self.device),
				torch.randn(2*self.num_layers, batch_size, self.hidden_dim // 2).to(self.device))

	def forward(self, input_ids, attention_mask, tags):

		batch_size = input_ids.size(0)
		max_seq_len = input_ids.size(1)
		bert_output, _  = self.bert(input_ids.long(), attention_mask)
		seq_len = torch.sum(attention_mask, dim= 1).cpu().int()
		# print(seq_len)
		pack_input = pack_padded_sequence(input= bert_output, lengths= seq_len, batch_first= True, enforce_sorted= False)
		packed_lstm_out, _ = self.lstm(pack_input,self.init_hidden(batch_size= batch_size))
		lstm_enc, _=  pad_packed_sequence(packed_lstm_out, batch_first=True)
		# print(lstm_enc.size())
		lstm_enc = self.dropout(lstm_enc)
		lstm_feats = self.fc(lstm_enc)

		lstm_max_seq_len = lstm_feats.size(1)
		pad = torch.zeros(size=(batch_size, max_seq_len-lstm_max_seq_len, self.n_tags), dtype= torch.float).to(self.device)
		lstm_feats= torch.cat((lstm_feats, pad), dim= 1)
  
		# lstm_feats[:,:,:4] = lstm_feats[:,:,:5]*100
		# lstm_feats[:,:,5:9] = lstm_feats[:,:,5:9]*10
		# lstm_feats[:,:,9:11] = lstm_feats[:,:,9:11]*100
		# lstm_feats[:,:,11] = lstm_feats[:,:,11]*100
		# lstm_feats[:,:,12:17] = lstm_feats[:,:,12:17]*100
		# lstm_feats[:,:,17:21] = lstm_feats[:,:,17:21]*10
		# lstm_feats[:,:,21:23] = lstm_feats[:,:,21:23]*100
		# lstm_feats[:,:,23] = lstm_feats[:,:,23]*1

		lstm_feats[:,:,:23] = lstm_feats[:,:,:23]*100
		lstm_feats[:,:,23] = lstm_feats[:,:,23]*10

		loss = -self.CRF(lstm_feats, tags, attention_mask.bool(), reduction= 'token_mean')
		pred_seqs = self.CRF.decode(emissions= lstm_feats, mask= attention_mask.bool())
  
		return loss, pred_seqs

In [7]:
class test_output():
    def __init__(self, data):
        
        self.data_list = []
        self.word_id = []
        self.word_article_id = [] 
        article_id = 0
        word_id = 0
        data_list_tmp = []
        article_id_tmp = []
        word_id_tmp = []
        
        for row in data:
            
            data_tuple = tuple()
            if row == '\n':
                
                article_id += 1 
                word_id = 0
                self.word_id.append(word_id)
                self.word_article_id.append(article_id_tmp)
                self.data_list.append(data_list_tmp)
                data_list_tmp = []
                article_id_tmp = []
                word_id_tmp = []
            
            if (row[0] == '，') & (len(data_list_tmp) >= 64):

                article_id += 1 
                word_id = 0
                self.word_id.append(word_id)
                self.word_article_id.append(article_id_tmp)
                self.data_list.append(data_list_tmp)
                data_list_tmp = []
                article_id_tmp = []
                word_id_tmp = []

            else:
                
                row = row.strip('\n').split(' ')

                if row[0] in ['。', '？','！','～','：']:
                    
                    self.word_id.append(word_id_tmp)
                    self.word_article_id.append(article_id_tmp)
                    self.data_list.append(data_list_tmp)
                    data_list_tmp = []
                    article_id_tmp = []
                    word_id_tmp = []
                    
                elif row[0] in ['A','B','C','D','E','F','G','H','I','J','K','L','M','N','O','P','Q','R','S','T','U','V','W','X','Y','Z']:
                      
                    data_tuple = (row[0].lower(), article_id, word_id)
                    data_list_tmp.append(data_tuple)
                    article_id_tmp.append(article_id)
                    word_id_tmp.append(word_id)

                elif row[0] not in ['摁','嗯','啦','喔','欸','啊','齁','嘿','…','...','，']:
                    
                    data_tuple = (row[0], article_id, word_id)
                    data_list_tmp.append(data_tuple)
                    article_id_tmp.append(article_id)
                    word_id_tmp.append(word_id)
                    
                word_id += 1
                
        if len(data_list_tmp) != 0:
            self.data_list.append(data_list_tmp)
            self.word_id.append(word_id_tmp)
            self.word_article_id.append(article_id_tmp)
            
    def raw_output(self):
        return self.data_list, self.word_id, self.word_article_id
    
    def get_stcs(self):
        
        all_stcs = list()
        all_article_ids = list()
        all_word_ids = list()

        for stc_list in self.data_list:

            txt_len = len(stc_list) #(文章數，每個文章對應的總字數) (word, label)
            stc = str() #存字數= max_stc_len的字串
            article_ids = []
            word_ids = []
            

            for idx, (word,article_id, word_id) in enumerate(stc_list):

                stc += word
                article_ids.append(article_id)
                word_ids.append(word_id)

            all_stcs.append(stc)
            all_article_ids.append(article_ids)
            all_word_ids.append(word_ids)
        
        assert len(all_stcs) > 0, 'all stcs len = 0' 
        
        all_stcs_clean = []
        all_article_ids_clean = []
        all_word_ids_clean = []
        idx = 0
        
        for stc, article_id, word_id in zip(all_stcs, all_article_ids, all_word_ids):
            stc_clean = re.sub(r'(醫師)|(個管師)|(民眾)|(家屬)|(護理師)', '', stc)
            # print(stc, stc_clean, label)
            if len(stc_clean) > 1:  
            # print(stc_clean, stc)
                all_stcs_clean.append(stc)
                all_article_ids_clean.append(article_id)
                all_word_ids_clean.append(word_id)

            # 這一步就先把label 做 0 padding
            
        max_length = len(max(all_stcs_clean, key=len))
        assert max_length > 0, 'max length less than 1'

        print('sentences總數: {}'.format(len(all_stcs_clean)))
            
        return all_stcs_clean, all_article_ids_clean, all_word_ids_clean

# loading data

In [8]:
with open('/content/drive/My Drive/python檔/aicup/test_input.data', 'r', encoding= 'utf-8') as f:
    data = f.readlines()

print(data[:10])
    

['醫\n', '師\n', '：\n', '最\n', '近\n', '人\n', '有\n', '沒\n', '有\n', '什\n']


In [9]:
stcs, article_ids, word_ids = test_output(data= data).get_stcs()

sentences總數: 15391


In [10]:
idx = 0
for stc, word_id in zip(stcs, word_ids):
    if idx == 150:
        break
    print(stc, word_id)
    idx += 1 


最近人有沒有什麼不舒服 [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
沒有 [18, 19]
沒有 [24, 25]
我們本來說要月年底才打對不對 [28, 29, 30, 31, 32, 33, 34, 36, 37, 38, 39, 40, 41, 42]
對但醫師 [47, 49, 52, 53]
你是怕 [55, 56, 57]
因為我要治療牙開始要治療牙齒了 [62, 63, 64, 65, 66, 67, 68, 70, 71, 72, 73, 74, 75, 76, 77]
我又不常跑醫院 [88, 89, 90, 91, 92, 93, 94]
我想說那先來打我今天來領藥嘛那順便過來打 [104, 105, 106, 107, 108, 109, 110, 112, 113, 114, 115, 116, 117, 118, 120, 121, 122, 123, 124, 125]
是是是 [132, 133, 134]
你是要做 [137, 138, 139, 140]
植牙 [145, 146]
植牙的 [151, 152, 153]
對對對 [158, 159, 160]
你的免疫力是夠是沒有問題的 [168, 169, 170, 171, 172, 173, 174, 176, 177, 178, 179, 180, 181]
在我們這裡做的嗎 [183, 184, 185, 186, 187, 188, 189, 190]
對在新樓做的 [195, 197, 198, 199, 200, 201]
那這樣離你住的地方也有點距離你是住麻豆嗎 [208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221, 224, 225, 226, 227, 228, 229]
對住麻豆 [234, 236, 237, 238]
那辛苦了那我們就去等下就去打針這樣子 [243, 244, 245, 246, 249, 250, 251, 252, 253, 255, 256, 257, 258, 259, 260, 261, 262, 263]
沒什麼不舒服 [266, 267, 268, 269, 270, 271]
最近 [275, 276]
沒

In [11]:
clean_stcs, clean_article_id, clean_word_id = [], [] ,[]

for stc, article_id, word_id in zip(stcs, article_ids, word_ids):
#     print(stc, article_id, word_id)
    if stc not in ['沒有','也沒有','哪個','那個','算了','不用','有','有有有','有嗎','一點點', '謝謝','不會','不好意思','對不對','好不好','要嗎','還好']:
        clean_stcs.append(stc)
        clean_article_id.append(article_id)
        clean_word_id.append(word_id)
print(len(clean_stcs))

15117


In [12]:
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

In [13]:
max_len = max(len(txt) for txt in clean_stcs)
print(max_len)
print(len(clean_stcs))

85
15117


In [14]:
clean_stcs_space = [' '.join(list(stc)) for stc in clean_stcs]

In [15]:
clean_stcs_space[:50]

['最 近 人 有 沒 有 什 麼 不 舒 服',
 '我 們 本 來 說 要 月 年 底 才 打 對 不 對',
 '對 但 醫 師',
 '你 是 怕',
 '因 為 我 要 治 療 牙 開 始 要 治 療 牙 齒 了',
 '我 又 不 常 跑 醫 院',
 '我 想 說 那 先 來 打 我 今 天 來 領 藥 嘛 那 順 便 過 來 打',
 '是 是 是',
 '你 是 要 做',
 '植 牙',
 '植 牙 的',
 '對 對 對',
 '你 的 免 疫 力 是 夠 是 沒 有 問 題 的',
 '在 我 們 這 裡 做 的 嗎',
 '對 在 新 樓 做 的',
 '那 這 樣 離 你 住 的 地 方 也 有 點 距 離 你 是 住 麻 豆 嗎',
 '對 住 麻 豆',
 '那 辛 苦 了 那 我 們 就 去 等 下 就 去 打 針 這 樣 子',
 '沒 什 麼 不 舒 服',
 '最 近',
 '沒 有 都 沒 有',
 '最 近 今 年 都 是 打 四 價 的',
 '各 個 廠 牌 都 有 我 覺 得 都 沒 差',
 '反 正 我 也 都 沒 在 挑 的',
 '沒 關 係',
 '我 不 知 道 這 個 要 醫 師',
 '應 該 都 o k 都 還 好 都 很 安 全',
 '這 個 你 比 較 專 業 的',
 '血 壓 這 樣 o k',
 '不 過 可 能 打 針 要 等 一 下 因 為 最 近 打 的 人 很 多',
 '很 踴 躍',
 '對 就 聽 說 年 底 可 能 會 打 不 到',
 '會 這 樣',
 '好 有 可 能 假 如 拖 到 年 底 可 能 會 打 不 完',
 '好 那 這 樣 我 們 這 邊 就 結 束 了',
 '要 去 上 課 嗎',
 '這 學 期 已 經 開 始 不 上 課 了',
 '今 天 的 呃 應 該 是 昨 天 的 抽 血',
 '肝 、 腎 功 能 都 很 棒',
 '肌 酐 酸 0 . 7 9 換 算 成 腎 絲 球 過 濾 率 預 測 值 是 8 8',
 '這 個 通 常 6 0 以 上 就 夠 了',
 '6 0',
 '這 個 都 很 讚',
 'a s t / a l t 就 是 我 們 常 講 的 g o t 還 有 g p t',
 '分 別 是 2 0 跟 1 9'

In [16]:
encoding = tokenizer.batch_encode_plus(clean_stcs_space, 
                                 padding=True,
                                 add_special_tokens=False,
                                 return_attention_mask= True,
                                 return_token_type_ids= False,
                                #  is_split_into_words=True,
                                 return_tensors='pt')

In [17]:
encoding['input_ids'].size()

torch.Size([15117, 85])

In [18]:
encoding['attention_mask'].size()

torch.Size([15117, 85])

In [19]:
encoding['attention_mask'][:2]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])

In [20]:
# loss, preds = model(input_ids, 

-----

# 輸出成csv
1. 需要id to tag dictionary
2. 對每個預測句子做iteration
3. pred : [句子數, 句子長度]

In [21]:
model = model_crf(n_tags= 25).to('cuda')
device = 'cuda'

In [22]:
model = torch.load('/content/drive/My Drive/python檔/aicup/64_wp0.4_256_lr1e-4_30epoch_1e-4_layer1_0.75.pt')
torch.save(model.state_dict(), 'model_dict.pt')

In [23]:
model.load_state_dict(torch.load('model_dict.pt'), strict=False)

<All keys matched successfully>

In [None]:
batch_size= 128
pred_labels = []
print(len(clean_stcs)/batch_size)
print((len(clean_stcs) % batch_size))

model.eval()

for idx in range(int((len(clean_stcs)/batch_size))):
  input= encoding['input_ids'][idx*batch_size:(idx+1)*batch_size].to(device)
  mask = encoding['attention_mask'][idx*batch_size:(idx+1)*batch_size].to(device)
  tags= torch.zeros((input.size(0),input.size(1)), dtype=torch.long).to(device)
  with torch.no_grad():
    _, preds = model(input, mask, tags)
  for pred in preds:
    pred_labels.append(pred)

print('ok')

if (len(clean_stcs) % batch_size) != 0:
  idx = int((len(clean_stcs)/batch_size))
  input= encoding['input_ids'][idx*batch_size:].to(device)
  mask = encoding['attention_mask'][idx*batch_size:].to(device)
  tags= torch.zeros((input.size(0),input.size(1)), dtype=torch.long).to(device)
  with torch.no_grad():
    _, preds = model(input, mask, tags)
  for pred in preds:
    pred_labels.append(pred)

118.1015625
13


In [None]:
tag2id = {'B-ID': 0, 'B-clinical_event': 1, 'B-contact': 2, 'B-education': 3, 'B-family': 4, 'B-location': 5, 'B-med_exam': 6, 'B-money': 7, 'B-name': 8, 'B-organization': 9, 'B-profession': 10, 'B-time': 11, 'I-ID': 12, 'I-clinical_event': 13, 'I-contact': 14, 'I-education': 15, 'I-family': 16, 'I-location': 17, 'I-med_exam': 18, 'I-money': 19, 'I-name': 20, 'I-organization': 21, 'I-profession': 22, 'I-time': 23, 'O': 24}

In [None]:
id2tag ={v:k for k, v in tag2id.items()}

In [None]:
pred_labels_tag = []
for label in pred_labels:
  stc_label = [id2tag[id] for id in label]
  pred_labels_tag.append(stc_label)

In [None]:
pred_labels_tag[:3]

In [None]:
# _, preds = model(encoding['input_ids'].to('cuda'), encoding['attention_mask'].to('cuda'), tags= torch.zeros((encoding['input_ids'].size(0),encoding['input_ids'].size(1))).to('cuda'))

In [None]:
entity_text = []

for stc, labels, article_id, word_id in zip(clean_stcs, pred_labels_tag, clean_article_id, clean_word_id):
    if len(set(labels)) >1:
        print(stc)
        print(labels)
    entity = str()
    
    start_pos = 0
    end_pos = 0
    article = 0
    
    entity_type = str()
    
    
    for idx, label in enumerate(labels):
        if bool(re.match(r'B-', label)):
            entity += list(stc)[idx]
            start_pos = word_id[idx]
            article = article_id[idx]
            entity_type = label.split('B-')[1]
            
        elif bool(re.match(r'I-', label)):
            entity += list(stc)[idx]
            end_pos= word_id[idx]+1
            try:
              if (labels[idx+1] == 'O') & (entity_type!=''):
                  entity_text.append((article, start_pos, end_pos, entity, entity_type))
                  
                  entity = str()
                  start_pos = 0
                  end_pos = 0
                  article = 0
                  entity_type = str()
            except:
              pass

In [None]:
entity_text

In [None]:
with open('test_output.tsv', 'w', encoding='utf-8',newline='\n') as f:
  writer = csv.writer(f, delimiter='\t')
  writer.writerow(['article_id','start_position', 'end_position', 'entity_text', 'entity_type'])
  for (article, start_pos, end_pos, entity, entity_type) in entity_text:
    writer.writerow([str(article), str(start_pos), str(end_pos), str(entity), str(entity_type)])