In [None]:
pip install pytorch-crf

Collecting pytorch-crf
  Downloading https://files.pythonhosted.org/packages/96/7d/4c4688e26ea015fc118a0327e5726e6596836abce9182d3738be8ec2e32a/pytorch_crf-0.7.2-py3-none-any.whl
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


In [None]:
pip install transformers==3

Collecting transformers==3
[?25l  Downloading https://files.pythonhosted.org/packages/9c/35/1c3f6e62d81f5f0daff1384e6d5e6c5758682a8357ebc765ece2b9def62b/transformers-3.0.0-py3-none-any.whl (754kB)
[K     |████████████████████████████████| 757kB 10.8MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/e5/2d/6d4ca4bef9a67070fa1cac508606328329152b1df10bdf31fb6e4e727894/sentencepiece-0.1.94-cp36-cp36m-manylinux2014_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 22.9MB/s 
Collecting tokenizers==0.8.0-rc4
[?25l  Downloading https://files.pythonhosted.org/packages/e8/bd/e5abec46af977c8a1375c1dca7cb1e5b3ec392ef279067af7f6bc50491a0/tokenizers-0.8.0rc4-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)
[K     |████████████████████████████████| 3.0MB 43.2MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)

In [None]:
from google.colab import drive

# Mount Google Drive so the data and checkpoint paths under /content/drive resolve.
drive.mount(mountpoint='/content/drive')

Mounted at /content/drive


In [None]:
import sys

# Make the project's helper modules importable. Each path is inserted at the front,
# so the last one inserted ("aicup") ends up first on sys.path, followed by "aicup/run".
for extra_path in ("/content/drive/My Drive/python檔/aicup/run",
                   "/content/drive/My Drive/python檔/aicup"):
    sys.path.insert(0, extra_path)

In [None]:
import csv
import re
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torchcrf import CRF
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Functions

In [None]:
class model_crf(nn.Module):
    """BERT -> BiLSTM -> linear -> CRF sequence tagger.

    Args:
        n_tags: number of tag ids scored by the linear layer and the CRF.
        hidden_dim: total BiLSTM hidden size (hidden_dim // 2 per direction).
        batchsize: batch size of the `self.hidden` states created at construction.
        num_layers: number of stacked BiLSTM layers.
        lstm_dropout: dropout between LSTM layers (PyTorch only applies it when num_layers > 1).
        fc_dropout: dropout applied to the BiLSTM output before the linear layer.
    """

    def __init__(self, n_tags, hidden_dim=256, batchsize=64, num_layers=1, lstm_dropout=0, fc_dropout=0.2):
        super(model_crf, self).__init__()
        self.num_layers = num_layers
        self.n_tags = n_tags
        # input_size=768 matches the hidden size of bert-base-chinese.
        self.lstm = nn.LSTM(bidirectional=True, num_layers=num_layers, input_size=768,
                            hidden_size=hidden_dim // 2, dropout=lstm_dropout, batch_first=True)
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, self.n_tags)
        # BERT parameters are left trainable. To use BERT only as a frozen embedding
        # extractor, set requires_grad = False on self.bert.parameters() and call self.bert.eval().
        self.bert = BertModel.from_pretrained('bert-base-chinese')

        # Device seen at construction time. forward() uses the device of its actual
        # activations instead, so the module keeps working after `.to(...)`.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.CRF = CRF(n_tags, batch_first=True)
        self.dropout = nn.Dropout(p=fc_dropout)
        # Not read by forward(), which builds fresh states per batch; kept so existing
        # checkpoints and attribute access stay compatible.
        self.hidden = self.init_hidden(batchsize)

    def init_hidden(self, batch_size, device=None):
        """Return initial (h0, c0) for the BiLSTM.

        Each tensor has shape (2 * num_layers, batch_size, hidden_dim // 2).
        In training mode the states are random (as the model was trained). In eval mode
        they are zeros, so identical inputs yield identical predictions.

        Args:
            batch_size: number of sequences in the batch.
            device: device for the states; defaults to `self.device`.
        """
        if device is None:
            device = self.device
        shape = (2 * self.num_layers, batch_size, self.hidden_dim // 2)
        if self.training:
            return (torch.randn(shape, device=device), torch.randn(shape, device=device))
        return (torch.zeros(shape, device=device), torch.zeros(shape, device=device))

    def forward(self, input_ids, attention_mask, tags):
        """Score `tags` with the CRF and decode the best tag sequence for each input.

        Args:
            input_ids: token ids, (batch, seq_len).
            attention_mask: 1 for real tokens, 0 for padding. Its row sums are used as
                sequence lengths, so padding is assumed to be on the right.
            tags: tag ids (batch, seq_len); they only influence the returned loss.

        Returns:
            (loss, pred_seqs): the negative CRF log-likelihood averaged over unmasked
            tokens, and one list of predicted tag ids per sequence (masked positions dropped).
        """
        batch_size = input_ids.size(0)
        max_seq_len = input_ids.size(1)
        bert_output, _ = self.bert(input_ids.long(), attention_mask)
        # Create new tensors on the device the model actually ran on, not on the device
        # recorded at construction/unpickling time.
        device = bert_output.device

        # pack_padded_sequence expects lengths on the CPU.
        seq_len = torch.sum(attention_mask, dim=1).cpu().int()
        pack_input = pack_padded_sequence(input=bert_output, lengths=seq_len, batch_first=True, enforce_sorted=False)
        packed_lstm_out, _ = self.lstm(pack_input, self.init_hidden(batch_size=batch_size, device=device))
        lstm_enc, _ = pad_packed_sequence(packed_lstm_out, batch_first=True)
        lstm_enc = self.dropout(lstm_enc)
        lstm_feats = self.fc(lstm_enc)

        # pad_packed_sequence only pads up to the longest sequence in this batch. Zero-pad
        # the emissions back to max_seq_len so they line up with attention_mask and tags.
        lstm_max_seq_len = lstm_feats.size(1)
        pad = torch.zeros(size=(batch_size, max_seq_len - lstm_max_seq_len, self.n_tags),
                          dtype=lstm_feats.dtype, device=device)
        lstm_feats = torch.cat((lstm_feats, pad), dim=1)

        # Hard-coded emission rescaling, kept exactly as trained: tag ids 0-22 are x100,
        # tag id 23 is x10, and higher ids are unchanged.
        # NOTE(review): the rationale is undocumented and this requires n_tags > 23.
        lstm_feats[:, :, :23] = lstm_feats[:, :, :23] * 100
        lstm_feats[:, :, 23] = lstm_feats[:, :, 23] * 10

        mask = attention_mask.bool()
        loss = -self.CRF(lstm_feats, tags, mask, reduction='token_mean')
        pred_seqs = self.CRF.decode(emissions=lstm_feats, mask=mask)
        return loss, pred_seqs

In [None]:
class test_output():
    """Split a one-character-per-line test file into sentence fragments.

    `data` is a list of lines as returned by `readlines()`. The first space-separated
    field of each line is the character. A line consisting only of a newline separates
    articles. Fragments are cut at clause/sentence punctuation (the mark itself is not
    kept), filler interjections are skipped, and A-Z is lowercased.

    Every kept character remembers its article id and its 0-based position (`word_id`)
    inside that article. Skipped characters and punctuation still advance `word_id`, so
    the positions match the original file.
    """

    # Punctuation that closes the current fragment.
    _BREAK_CHARS = ('。', '？', '！', '，', '～', '：')
    # Upper-case ASCII letters are lowercased (presumably to match the uncased vocabulary; TODO confirm).
    _UPPER_CHARS = frozenset('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    # Filler interjections / ellipses that are dropped from the text.
    _FILLER_CHARS = ('摁', '嗯', '啦', '喔', '欸', '啊', '齁', '嘿', '…', '...', '，')

    def __init__(self, data):
        self.data_list = []        # per fragment: list of (char, article_id, word_id) tuples
        self.word_id = []          # per fragment: list of word ids
        self.word_article_id = []  # per fragment: list of article ids
        article_id = 0
        word_id = 0
        chars, article_ids, word_ids = [], [], []

        for row in data:
            if row == '\n':
                # Article boundary: close the current fragment (even if empty), then
                # restart positions for the next article.
                self._flush(chars, article_ids, word_ids)
                chars, article_ids, word_ids = [], [], []
                article_id += 1
                word_id = 0
                continue

            token = row.strip('\n').split(' ')[0]
            if token in self._BREAK_CHARS:
                self._flush(chars, article_ids, word_ids)
                chars, article_ids, word_ids = [], [], []
            elif token not in self._FILLER_CHARS:
                if token in self._UPPER_CHARS:
                    token = token.lower()
                chars.append((token, article_id, word_id))
                article_ids.append(article_id)
                word_ids.append(word_id)
            word_id += 1

        if chars:
            self._flush(chars, article_ids, word_ids)

    def _flush(self, chars, article_ids, word_ids):
        """Record one finished fragment in the three parallel per-fragment lists."""
        self.data_list.append(chars)
        self.word_id.append(word_ids)
        self.word_article_id.append(article_ids)

    def raw_output(self):
        """Return (data_list, word_id, word_article_id), one entry per fragment."""
        return self.data_list, self.word_id, self.word_article_id

    def get_stcs(self):
        """Return (sentences, article_ids, word_ids) for fragments worth predicting on.

        A fragment is kept only if more than one character remains after removing
        speaker tags (醫師, 個管師, 民眾, 家屬, 護理師). The kept sentence is the
        unmodified fragment text, so its character i still matches word_ids[i].

        Raises:
            AssertionError: if there are no fragments, or none survive the filter.
        """
        all_stcs, all_article_ids, all_word_ids = [], [], []
        for stc_list in self.data_list:
            all_stcs.append(''.join(word for word, _, _ in stc_list))
            all_article_ids.append([article_id for _, article_id, _ in stc_list])
            all_word_ids.append([word_id for _, _, word_id in stc_list])

        assert len(all_stcs) > 0, 'all stcs len = 0'

        all_stcs_clean = []
        all_article_ids_clean = []
        all_word_ids_clean = []
        for stc, article_id, word_id in zip(all_stcs, all_article_ids, all_word_ids):
            stc_clean = re.sub(r'(醫師)|(個管師)|(民眾)|(家屬)|(護理師)', '', stc)
            if len(stc_clean) > 1:
                all_stcs_clean.append(stc)
                all_article_ids_clean.append(article_id)
                all_word_ids_clean.append(word_id)

        # Check emptiness explicitly; max() on an empty list would raise ValueError first.
        assert all_stcs_clean, 'max length less than 1'

        print('sentences總數: {}'.format(len(all_stcs_clean)))

        return all_stcs_clean, all_article_ids_clean, all_word_ids_clean

# Loading data

In [None]:
# Raw test file: one character per line, blank lines separate articles.
INPUT_PATH = '/content/drive/My Drive/python檔/aicup/test_input.data'
with open(INPUT_PATH, 'r', encoding='utf-8') as fh:
    data = list(fh)

print(data[:10])
    

['醫\n', '師\n', '：\n', '最\n', '近\n', '人\n', '有\n', '沒\n', '有\n', '什\n']


In [None]:
sentence_splitter = test_output(data=data)
stcs, article_ids, word_ids = sentence_splitter.get_stcs()

sentences總數: 24302


In [None]:
# Preview the first 150 fragments together with their word ids.
PREVIEW_COUNT = 150
for stc, word_id in zip(stcs[:PREVIEW_COUNT], word_ids[:PREVIEW_COUNT]):
    print(stc, word_id)


最近人有沒有什麼不舒服 [3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
沒有 [18, 19]
沒有 [24, 25]
我們本來說要月 [28, 29, 30, 31, 32, 33, 34]
年底才打對不對 [36, 37, 38, 39, 40, 41, 42]
你是怕 [55, 56, 57]
因為我要治療牙 [62, 63, 64, 65, 66, 67, 68]
開始要治療牙齒了 [70, 71, 72, 73, 74, 75, 76, 77]
我又不常跑醫院 [88, 89, 90, 91, 92, 93, 94]
我想說那先來打 [104, 105, 106, 107, 108, 109, 110]
我今天來領藥嘛 [112, 113, 114, 115, 116, 117, 118]
那順便過來打 [120, 121, 122, 123, 124, 125]
是是是 [132, 133, 134]
你是要做 [137, 138, 139, 140]
植牙 [145, 146]
植牙的 [151, 152, 153]
對對對 [158, 159, 160]
你的免疫力是夠 [168, 169, 170, 171, 172, 173, 174]
是沒有問題的 [176, 177, 178, 179, 180, 181]
在我們這裡做的嗎 [183, 184, 185, 186, 187, 188, 189, 190]
在新樓做的 [197, 198, 199, 200, 201]
那這樣離你住的地方也有點距離 [208, 209, 210, 211, 212, 213, 214, 215, 216, 217, 218, 219, 220, 221]
你是住麻豆嗎 [224, 225, 226, 227, 228, 229]
住麻豆 [236, 237, 238]
那辛苦了 [243, 244, 245, 246]
那我們就去 [249, 250, 251, 252, 253]
等下就去打針這樣子 [255, 256, 257, 258, 259, 260, 261, 262, 263]
沒什麼不舒服 [266, 267, 268, 269, 270, 271]
最近 [275, 276]
沒有 [281, 282]
都沒有 [

In [None]:
# Stock short replies (presumably entity-free); fragments exactly equal to one of these are dropped.
SHORT_REPLIES = {'沒有', '也沒有', '哪個', '那個', '算了', '不用', '有', '有有有', '有嗎', '一點點',
                 '謝謝', '不會', '不好意思', '對不對', '好不好', '要嗎', '還好'}

kept_rows = [(stc, article_id, word_id)
             for stc, article_id, word_id in zip(stcs, article_ids, word_ids)
             if stc not in SHORT_REPLIES]
clean_stcs = [row[0] for row in kept_rows]
clean_article_id = [row[1] for row in kept_rows]
clean_word_id = [row[2] for row in kept_rows]
print(len(clean_stcs))

23673


In [None]:
PRETRAINED_NAME = 'bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(PRETRAINED_NAME)

In [None]:
# Longest sentence (in characters) and number of sentences to encode.
max_len = max(map(len, clean_stcs))
print(max_len, len(clean_stcs), sep='\n')

51
23673


In [None]:
# Encode all sentences at once; padding=True pads every row to the longest sentence.
# No [CLS]/[SEP] (add_special_tokens=False), so token i is meant to be character i.
encoding = tokenizer.batch_encode_plus(clean_stcs,
                                       padding=True,
                                       add_special_tokens=False,
                                       return_attention_mask=True,
                                       return_token_type_ids=False,
                                       return_tensors='pt')

# Predicted labels are later mapped back to characters by position. If WordPiece keeps a
# run of digits/latin letters as fewer tokens than characters, every later label in that
# sentence shifts silently, so report any sentence whose token count != character count.
n_tokens_per_stc = encoding['attention_mask'].sum(dim=1).tolist()
misaligned = [i for i, (stc, n_tok) in enumerate(zip(clean_stcs, n_tokens_per_stc)) if n_tok != len(stc)]
if misaligned:
    print('WARNING: {} sentences are not one token per character (first indices: {}); '
          'their entity offsets may be wrong.'.format(len(misaligned), misaligned[:5]))

In [None]:
encoding['input_ids'].shape  # (num_sentences, padded_length)

torch.Size([23673, 51])

In [None]:
encoding['attention_mask'].shape  # should match input_ids

torch.Size([23673, 51])

In [None]:
encoding['attention_mask'][0:2]  # first two rows: 1 = real token, 0 = padding

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0]])

In [None]:
# NOTE(review): leftover, incomplete draft of a forward call; the batched inference
# loop further down is what actually produces the predictions.
# loss, preds = model(input_ids, 

-----

# 輸出成 TSV（test_output.tsv）
1. 需要 id → tag 的對照表（id2tag）
2. 對每個預測句子做 iteration，把 BIO 標籤還原成實體
3. pred：長度為句子數的 list，每個元素是該句每個字的 tag id（長度 = 該句長度）

In [None]:
# Use the GPU when available so the notebook still runs on CPU-only machines.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# 26 tags = len(tag2id): ids 0..25, with 'O' = 25. The checkpoint's predictions contain
# id 25, so 25 tags would be too few for its CRF/linear weights.
model = model_crf(n_tags=26).to(device)

In [None]:
# The checkpoint is a whole pickled nn.Module, not just a state_dict, so torch.load
# unpickles arbitrary objects. Only load checkpoints from a trusted source.
# map_location places its tensors on `device` even if it was saved from a GPU.
model = torch.load('/content/drive/My Drive/python檔/aicup/64_wp0.4_256_lr1e-4_30epoch_1e-4_layer1_0.75.pt',
                   map_location=device)
# Also keep a plain state_dict copy, which is portable across code changes.
torch.save(model.state_dict(), 'model_dict.pt')

In [None]:
# strict=True (the default) turns any missing or unexpected key into an error instead of
# silently loading a partially initialised model.
model.load_state_dict(torch.load('model_dict.pt', map_location=device))

<All keys matched successfully>

In [None]:
# Batched CRF decoding over all cleaned sentences, in order.
batch_size = 128
pred_labels = []
print(len(clean_stcs) / batch_size)

# eval() disables dropout (the model's fc dropout and BERT's internal dropout) so
# predictions are not randomly perturbed. no_grad() skips building an autograd graph
# we never use, which saves GPU memory.
model.eval()
with torch.no_grad():
    # Stepping by batch_size covers the final partial batch exactly once and never
    # produces an empty batch, even when the count is a multiple of batch_size.
    for start in range(0, len(clean_stcs), batch_size):
        batch_input_ids = encoding['input_ids'][start:start + batch_size].to(device)
        batch_mask = encoding['attention_mask'][start:start + batch_size].to(device)
        # forward() also computes the CRF loss and therefore needs tags; these dummy
        # zeros only affect that ignored loss, not the decoded predictions.
        dummy_tags = torch.zeros(batch_input_ids.shape, dtype=torch.long, device=device)
        _, preds = model(batch_input_ids, batch_mask, dummy_tags)
        pred_labels.extend(preds)

In [None]:
# Tag vocabulary: [PAD]=0, then B- tags for each entity type (1-12), then I- tags (13-24), then O=25.
ENTITY_TYPES = ['ID', 'clinical_event', 'contact', 'education', 'family', 'location',
                'med_exam', 'money', 'name', 'organization', 'profession', 'time']
tag2id = {tag: tag_id for tag_id, tag in enumerate(
    ['[PAD]'] + ['B-' + t for t in ENTITY_TYPES] + ['I-' + t for t in ENTITY_TYPES] + ['O'])}

In [None]:
id2tag = dict(zip(tag2id.values(), tag2id.keys()))  # inverse of tag2id

In [None]:
# Map every predicted tag id back to its tag string, sentence by sentence.
pred_labels_tag = [[id2tag[tag_id] for tag_id in label] for label in pred_labels]

In [None]:
pred_labels_tag[0:3]  # tags for the first three sentences

[['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O'],
 ['O', 'O', 'O', 'O', 'O', 'O', 'B-time'],
 ['O', 'O', 'O', 'O', 'O', 'O', 'O']]

In [None]:
# NOTE(review): unused single-pass inference over the entire dataset, superseded by the
# batched loop above; running every sentence at once would likely exhaust GPU memory.
# _, preds = model(encoding['input_ids'].to('cuda'), encoding['attention_mask'].to('cuda'), tags= torch.zeros((encoding['input_ids'].size(0),encoding['input_ids'].size(1))).to('cuda'))

In [None]:
def _extract_entities(stc, labels, article_id, word_id):
    """Decode one sentence's BIO tag strings into entity tuples.

    Args:
        stc: sentence text. `labels[i]` is assumed to tag character `stc[i]`, which only
            holds if the tokenizer produced one token per character.
        labels: predicted tag strings such as 'B-time', 'I-time', 'O'.
        article_id: article id of each character of `stc`.
        word_id: position of each character inside its article.

    Returns:
        List of (article, start_position, end_position, entity_text, entity_type) tuples.
        end_position is the last character's word id + 1. Word ids need not be contiguous,
        so end - start may exceed len(entity_text).

    A 'B-' tag opens an entity (closing any open one), following 'I-' tags extend it, and
    any other tag or the end of the sentence closes it. An 'I-' tag with no open entity
    is ignored.
    """
    entities = []
    start = None      # index of the open entity's first character, or None
    entity_type = ''
    # A trailing 'O' sentinel closes an entity that runs to the end of the sentence.
    for idx, label in enumerate(list(labels) + ['O']):
        if start is not None and label.startswith('I-'):
            continue  # still inside the open entity
        if start is not None:
            entities.append((article_id[start], word_id[start], word_id[idx - 1] + 1,
                             stc[start:idx], entity_type))
            start = None
        if label.startswith('B-'):
            start, entity_type = idx, label[len('B-'):]
    return entities


entity_text = []
for stc, labels, article_id, word_id in zip(clean_stcs, pred_labels_tag, clean_article_id, clean_word_id):
    entity_text.extend(_extract_entities(stc, labels, article_id, word_id))

In [None]:
# Show the total count plus a preview instead of dumping every predicted entity.
print(len(entity_text))
entity_text[:30]

[(0, 113, 115, '今天', 'time'),
 (0, 198, 200, '新樓', 'location'),
 (0, 227, 229, '麻豆', 'location'),
 (1, 15, 18, '這學期', 'time'),
 (1, 31, 33, '今天', 'time'),
 (1, 40, 42, '昨天', 'time'),
 (1, 360, 363, '黃醫師', 'name'),
 (1, 385, 388, '6.6', 'med_exam'),
 (1, 469, 472, '三個月', 'time'),
 (1, 626, 629, '黃醫師', 'name'),
 (1, 742, 744, '中午', 'time'),
 (2, 402, 405, '三個月', 'time'),
 (2, 792, 795, '兩個月', 'time'),
 (2, 1090, 1092, '7.', 'med_exam'),
 (2, 1394, 1396, '45', 'money'),
 (2, 1406, 1408, '今天', 'time'),
 (3, 6, 8, '昨天', 'time'),
 (3, 11, 13, '一天', 'time'),
 (3, 221, 223, '昨天', 'time'),
 (3, 225, 227, '早上', 'time'),
 (3, 263, 267, '今天早上', 'time'),
 (3, 538, 540, '前天', 'time'),
 (3, 602, 604, '每天', 'time'),
 (3, 643, 645, '昨天', 'time'),
 (3, 662, 664, '昨天', 'time'),
 (3, 673, 675, '今天', 'time'),
 (3, 678, 681, '吳醫師', 'name'),
 (3, 748, 750, '今天', 'time'),
 (3, 817, 819, '今天', 'time'),
 (3, 821, 823, '今天', 'time'),
 (3, 867, 870, '禮拜二', 'time'),
 (3, 871, 874, '禮拜三', 'time'),
 (4, 6, 8, '嘉明', 

In [None]:
# One tab-separated row per predicted entity, preceded by a header row.
with open('test_output.tsv', 'w', encoding='utf-8', newline='\n') as out_file:
    tsv_writer = csv.writer(out_file, delimiter='\t')
    tsv_writer.writerow(['article_id', 'start_position', 'end_position', 'entity_text', 'entity_type'])
    tsv_writer.writerows(
        [str(article), str(start_pos), str(end_pos), str(entity), str(entity_type)]
        for article, start_pos, end_pos, entity, entity_type in entity_text
    )