In [1]:
pip install pytorch-crf

Collecting pytorch-crf
  Downloading https://files.pythonhosted.org/packages/96/7d/4c4688e26ea015fc118a0327e5726e6596836abce9182d3738be8ec2e32a/pytorch_crf-0.7.2-py3-none-any.whl
Installing collected packages: pytorch-crf
Successfully installed pytorch-crf-0.7.2


In [2]:
pip install transformers==3

Collecting transformers==3
[?25l  Downloading https://files.pythonhosted.org/packages/9c/35/1c3f6e62d81f5f0daff1384e6d5e6c5758682a8357ebc765ece2b9def62b/transformers-3.0.0-py3-none-any.whl (754kB)
[K     |████████████████████████████████| 757kB 9.4MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/e5/2d/6d4ca4bef9a67070fa1cac508606328329152b1df10bdf31fb6e4e727894/sentencepiece-0.1.94-cp36-cp36m-manylinux2014_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 12.3MB/s 
Collecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 29.3MB/s 
Collecting tokenizers==0.8.0-rc4
[?25l  Downloading https://files.pythonhosted.org/packages/e8/bd/e5abec46af977c8a1375c1dca7cb1e5b3ec392ef279067af7f6bc50491a0/tokenizers-0.8.0rc4-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)


In [3]:
# Mount Google Drive at /content/drive; the data file and the checkpoint used below are read from there.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
import sys

# Put the project folders at the front of the module search path;
# the `aicup` root ends up ahead of its `run` sub-folder.
sys.path[:0] = [
    "/content/drive/My Drive/python檔/aicup",
    "/content/drive/My Drive/python檔/aicup/run",
]

In [5]:
import csv
import re
from transformers import BertTokenizer, BertModel
import torch
import torch.nn as nn
from torchcrf import CRF
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

In [6]:
class model_crf(nn.Module):
    """BERT -> 2-layer bidirectional LSTM -> linear layer -> CRF sequence tagger.

    Args:
        n_tags: number of tag classes scored by the linear layer and the CRF.
        hidden_dim: total width of the BiLSTM output (each direction uses
            ``hidden_dim // 2`` units).
        batchsize: batch size used only to pre-build the ``hidden`` attribute.
    """

    def __init__(self, n_tags, hidden_dim=256, batchsize= 32):
        super(model_crf, self).__init__()
        self.n_tags = n_tags
        self.lstm = nn.LSTM(bidirectional=True, num_layers=2, input_size=768, hidden_size=hidden_dim//2, dropout= 0.3, batch_first=True)
        self.hidden_dim = hidden_dim
        self.fc = nn.Linear(hidden_dim, self.n_tags)
        self.bert = BertModel.from_pretrained('bert-base-chinese')

        # The BERT parameters are left trainable. To use BERT only as a frozen
        # embedding extractor, set requires_grad = False on self.bert.parameters()
        # and call self.bert.eval().

        # Default device for init_hidden() when the caller does not pass one.
        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        self.CRF = CRF(n_tags, batch_first= True)
        self.hidden = self.init_hidden(batchsize)

    def init_hidden(self, batch_size, device=None):
        """Draw a random initial (h_0, c_0) pair for the LSTM.

        Args:
            batch_size: number of sequences in the batch.
            device: device to place the tensors on; defaults to ``self.device``.

        Returns:
            Tuple of two tensors, each of shape ``(4, batch_size, hidden_dim // 2)``
            (2 layers x 2 directions).
        """
        if device is None:
            device = self.device
        state_shape = (2*2, batch_size, self.hidden_dim // 2)
        return (torch.randn(*state_shape).to(device),
                torch.randn(*state_shape).to(device))

    def forward(self, input_ids, attention_mask, tags):
        """Score a batch, compute the CRF loss and decode the best tag paths.

        Args:
            input_ids: token ids, indexed as (batch, sequence).
            attention_mask: 1 for real tokens and 0 for padding, same layout.
            tags: gold tag ids, same layout; always required because the loss
                is computed on every call.

        Returns:
            ``(loss, pred_seqs)``: the negated CRF log-likelihood with
            ``token_mean`` reduction, and the decoded tag-id sequences returned
            by ``CRF.decode``.
        """
        batch_size = input_ids.size(0)
        max_seq_len = input_ids.size(1)
        bert_output, _  = self.bert(input_ids.long(), attention_mask)

        # Build helper tensors on the device the activations actually live on, so the
        # module keeps working after being moved away from the device picked in __init__.
        device = bert_output.device

        # pack_padded_sequence needs the true sequence lengths on the CPU.
        seq_len = torch.sum(attention_mask, dim= 1).cpu().int()
        pack_input = pack_padded_sequence(input= bert_output, lengths= seq_len, batch_first= True, enforce_sorted= False)
        # NOTE(review): the initial LSTM state is re-drawn at random on every call, so
        # outputs are not reproducible unless the RNG is seeded.
        packed_lstm_out, _ = self.lstm(pack_input, self.init_hidden(batch_size= batch_size, device= device))
        lstm_enc, _=  pad_packed_sequence(packed_lstm_out, batch_first=True, padding_value=0)
        lstm_feats = self.fc(lstm_enc)

        # Unpacking can yield fewer time steps than the input had; pad the emissions
        # back to max_seq_len so they line up with `tags` and `attention_mask`.
        lstm_max_seq_len = lstm_feats.size(1)
        pad = lstm_feats.new_zeros((batch_size, max_seq_len-lstm_max_seq_len, self.n_tags))
        lstm_feats= torch.cat((lstm_feats, pad), dim= 1)

        # Scale the emission scores of every tag except the first and the last by 100.
        # NOTE(review): presumably meant to favour entity tags over padding / 'O' — confirm.
        lstm_feats[:,:,1:-1] = lstm_feats[:,:,1:-1]*100
        loss = -self.CRF(lstm_feats, tags, attention_mask.bool(), reduction= 'token_mean')
        pred_seqs = self.CRF.decode(emissions= lstm_feats, mask= attention_mask.bool())

        return loss, pred_seqs

In [7]:
# Read the raw test file as a list of lines (each keeps its trailing newline).
with open('/content/drive/My Drive/python檔/aicup/test_input.data', 'r', encoding= 'utf-8') as f:
    data = list(f)

# Peek at the first few lines to confirm the one-token-per-line layout.
print(data[:10])
    

['醫\n', '師\n', '：\n', '最\n', '近\n', '人\n', '有\n', '沒\n', '有\n', '什\n']


In [8]:
class test_output():
    """Group a one-token-per-line corpus into sentences and track token positions.

    ``data`` is an iterable of raw lines. A line consisting of just a newline ends
    the current article; on every other line the text before the first space is
    the token. Each kept token is stored as ``(token, article_id, word_id)``, where
    ``word_id`` counts the non-blank lines seen so far in the article.

    Attributes:
        data_list: list of sentences, each a list of token tuples.
        word_id: per sentence, the list of word ids of its tokens.
        word_article_id: per sentence, the list of article ids of its tokens.
    """

    # Tokens that close the running sentence; they are not stored themselves.
    _SENTENCE_BREAKS = frozenset(['。', '？', '！', '，', '～', '：', ',', '‧'])
    # Tokens dropped from sentences (their position is still counted).
    _FILLERS = frozenset(['摁', '嗯', '啦', '喔', '欸', '啊', '齁', '嘿', '…', '...'])
    # Single upper-case ASCII letters are stored lower-cased.
    _UPPERCASE = frozenset('ABCDEFGHIJKLMNOPQRSTUVWXYZ')
    # Speaker labels ignored when deciding whether a sentence has enough content.
    _SPEAKER_LABELS = r'(醫師)|(個管師)|(民眾)|(家屬)|(護理師)'

    def __init__(self, data):
        self.data_list = []
        self.word_id = []
        self.word_article_id = []

        article_id = 0
        word_id = 0
        tokens, token_articles, token_words = [], [], []

        for row in data:
            if row == '\n':
                # Blank line: close the running sentence and move on to the next article.
                # The sentence's own id list is stored (not the freshly reset counter),
                # so the three result lists stay parallel.
                self._flush(tokens, token_articles, token_words)
                tokens, token_articles, token_words = [], [], []
                article_id += 1
                word_id = 0
                continue

            word = row.strip('\n').split(' ')[0]

            if word in self._SENTENCE_BREAKS:
                self._flush(tokens, token_articles, token_words)
                tokens, token_articles, token_words = [], [], []
            elif word and word not in self._FILLERS:
                # An empty token (e.g. a line holding only a space) is skipped: it would
                # add ids without adding a character, breaking the alignment between the
                # sentence string and its id lists.
                if word in self._UPPERCASE:
                    word = word.lower()
                tokens.append((word, article_id, word_id))
                token_articles.append(article_id)
                token_words.append(word_id)

            # Every non-blank line advances the position, kept or not.
            word_id += 1

        if tokens:
            self._flush(tokens, token_articles, token_words)

    def _flush(self, tokens, token_articles, token_words):
        """Store one finished sentence in the three parallel result lists."""
        self.data_list.append(tokens)
        self.word_id.append(token_words)
        self.word_article_id.append(token_articles)

    def raw_output(self):
        """Return ``(data_list, word_id, word_article_id)`` exactly as collected."""
        return self.data_list, self.word_id, self.word_article_id

    def get_stcs(self):
        """Build sentence strings and keep those with real content.

        A sentence is kept when more than one character remains after removing the
        speaker labels. The returned string is the unmodified sentence, so it stays
        aligned with its id lists.

        Returns:
            ``(sentences, article_ids, word_ids)``: three parallel lists; the id
            entries are lists parallel to the tokens of each sentence.

        Raises:
            AssertionError: if no sentence was collected, or none survives the filter.
        """
        assert len(self.data_list) > 0, 'all stcs len = 0'

        all_stcs_clean = []
        all_article_ids_clean = []
        all_word_ids_clean = []

        for stc_list in self.data_list:
            stc = ''.join(word for word, _, _ in stc_list)
            if len(re.sub(self._SPEAKER_LABELS, '', stc)) > 1:
                all_stcs_clean.append(stc)
                all_article_ids_clean.append([article_id for _, article_id, _ in stc_list])
                all_word_ids_clean.append([word_id for _, _, word_id in stc_list])

        # default=0 lets the assertion below report the empty case with its own message.
        max_length = max((len(stc) for stc in all_stcs_clean), default=0)
        assert max_length > 0, 'max length less than 1'

        print('sentences總數: {}'.format(len(all_stcs_clean)))

        return all_stcs_clean, all_article_ids_clean, all_word_ids_clean

In [9]:
# Split the raw lines into sentences; article_ids / word_ids hold, per sentence,
# the article id and in-article position of each kept token.
stcs, article_ids, word_ids = test_output(data= data).get_stcs()

sentences總數: 24302


In [10]:
# Sentences that are excluded when they match one of these strings exactly.
SKIPPED_STCS = {'沒有', '也沒有', '哪個', '那個', '算了', '不用', '有', '有有有', '有嗎', '一點點',
                '謝謝', '不會', '不好意思', '對不對', '好不好', '要嗎', '還好'}

clean_stcs, clean_article_id, clean_word_id = [], [], []

# Keep the three lists parallel: a sentence and its id lists are kept or dropped together.
kept_rows = (row for row in zip(stcs, article_ids, word_ids) if row[0] not in SKIPPED_STCS)
for stc, article_id, word_id in kept_rows:
    clean_stcs.append(stc)
    clean_article_id.append(article_id)
    clean_word_id.append(word_id)

print(len(clean_stcs))

23673


In [11]:
# Tokenizer for the same 'bert-base-chinese' checkpoint that model_crf loads its BERT encoder from.
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=109540.0, style=ProgressStyle(descripti…




In [12]:
# Longest sentence (in characters) and number of sentences, as a sanity check before encoding.
max_len = max(map(len, clean_stcs))
print(max_len, len(clean_stcs), sep='\n')

51
23673


In [13]:
# Encode every sentence into one padded tensor batch of input ids plus attention mask.
# Special tokens are disabled, so no extra positions are added around a sentence.
# NOTE(review): later cells index the sentence string and its id lists by label position,
# which assumes exactly one token per character — confirm the tokenizer never merges or
# splits characters (e.g. Latin letters or digits) for this data.
encoding = tokenizer.batch_encode_plus(clean_stcs, 
                                 padding=True,
                                 add_special_tokens=False,
                                 return_attention_mask= True,
                                 return_token_type_ids= False,
                                #  is_split_into_words=True,
                                 return_tensors='pt')

In [14]:
encoding['input_ids'].size()

torch.Size([23673, 51])

In [15]:
encoding['attention_mask'].size()

torch.Size([23673, 51])

In [16]:
encoding['attention_mask'][:2]

tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
         0, 0, 0]])

In [17]:
# loss, preds = model(input_ids, 

-----

# 輸出成csv
1. 需要id to tag dictionary
2. 對每個預測句子做iteration
3. pred : [句子數, 句子長度]

In [18]:
# Use the GPU when one is present instead of assuming it, so the notebook also
# runs (slowly) on a CPU-only runtime.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 26 tag classes, matching the size of the tag vocabulary defined further down.
model = model_crf(n_tags= 26).to(device)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=624.0, style=ProgressStyle(description_…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=411577189.0, style=ProgressStyle(descri…




In [21]:
# NOTE(security): torch.load unpickles arbitrary Python objects — only load checkpoints you trust.
# The file stores a whole module object (state_dict() is called on it below), presumably
# a model_crf instance, so that class must already be defined in this kernel.
# map_location keeps the load working when the runtime's device differs from the one used for saving.
model = torch.load('/content/drive/My Drive/python檔/aicup/ner_model_batch32_wup4000_lstmhd256_lr5e-5_40epoch_adam_wde-3.pt', map_location=device)
# Also export just the weights, a lighter format than the pickled module.
torch.save(model.state_dict(), 'model_dict.pt')

In [23]:
# Re-load the weights that were just exported from this same model, i.e. a round-trip
# check of 'model_dict.pt'. NOTE(review): strict=False would not raise on missing or
# unexpected keys; the recorded output reports that all keys matched.
model.load_state_dict(torch.load('model_dict.pt'), strict=False)

<All keys matched successfully>

In [None]:
batch_size = 32
pred_labels = []

# Inference only: switch dropout off and skip autograd bookkeeping.
model.eval()
with torch.no_grad():
    # Stepping by batch_size covers the final partial batch in the same loop and never
    # feeds the model an empty batch when the sentence count is a multiple of batch_size.
    for start in range(0, len(clean_stcs), batch_size):
        batch_ids = encoding['input_ids'][start:start + batch_size].to(device)
        batch_mask = encoding['attention_mask'][start:start + batch_size].to(device)
        # forward() always computes the CRF loss, so it needs a tag tensor;
        # all-zero tags are placeholders and the resulting loss is discarded.
        dummy_tags = torch.zeros(batch_ids.shape, dtype=torch.long, device=device)
        _, preds = model(batch_ids, batch_mask, dummy_tags)
        # One decoded tag-id list per sentence, in input order.
        pred_labels.extend(preds)

In [None]:
# Tag vocabulary: id 0 is padding, followed by every B- tag, every I- tag, and finally 'O'.
_ENTITY_TYPES = ['ID', 'clinical_event', 'contact', 'education', 'family', 'location',
                 'med_exam', 'money', 'name', 'organization', 'profession', 'time']
_ALL_TAGS = (['[PAD]']
             + ['B-' + entity_type for entity_type in _ENTITY_TYPES]
             + ['I-' + entity_type for entity_type in _ENTITY_TYPES]
             + ['O'])
tag2id = {tag: tag_id for tag_id, tag in enumerate(_ALL_TAGS)}

In [None]:
# Reverse lookup: integer tag id -> tag string.
id2tag = dict(zip(tag2id.values(), tag2id.keys()))

In [None]:
# Translate every predicted tag-id sequence into its tag strings, one list per sentence.
pred_labels_tag = [[id2tag[tag_id] for tag_id in label] for label in pred_labels]

In [None]:
# Preview only: the full list has one entry per sentence and would flood the output.
pred_labels_tag[:5]

In [None]:
# _, preds = model(encoding['input_ids'].to('cuda'), encoding['attention_mask'].to('cuda'), tags= torch.zeros((encoding['input_ids'].size(0),encoding['input_ids'].size(1))).to('cuda'))

In [None]:
def _decode_entities(stc, labels, article_id, word_id):
    """Collect the entities described by one sentence's B-/I-/other label sequence.

    An entity starts at a ``B-`` label and is extended by the ``I-`` labels that
    follow it. It is closed by the next ``B-`` label, by any label that is neither
    ``B-`` nor ``I-``, or by the end of the sentence — so single-token entities and
    entities that end the sentence are reported too. An ``I-`` label with no open
    entity is ignored.

    Args:
        stc: sentence text; ``stc[i]`` is taken as the character for ``labels[i]``.
        labels: tag strings such as 'B-name', 'I-name' or 'O'.
        article_id: per-position article ids, parallel to ``labels``.
        word_id: per-position word ids, parallel to ``labels``.

    Returns:
        List of ``(article, start_pos, end_pos, entity, entity_type)`` tuples, where
        start_pos / end_pos are the word ids of the entity's first / last character.
    """
    found = []
    current = None  # [article, start_pos, end_pos, entity, entity_type] of the open entity

    for idx, label in enumerate(labels):
        if label.startswith('B-'):
            if current is not None:
                found.append(tuple(current))
            # A fresh entity: end_pos starts equal to start_pos and grows with each I- label.
            current = [article_id[idx], word_id[idx], word_id[idx], stc[idx], label[2:]]
        elif label.startswith('I-'):
            if current is not None:
                current[2] = word_id[idx]
                current[3] += stc[idx]
        else:
            if current is not None:
                found.append(tuple(current))
            current = None

    if current is not None:
        found.append(tuple(current))
    return found


entity_text = []

for stc, labels, article_id, word_id in zip(clean_stcs, pred_labels_tag, clean_article_id, clean_word_id):
    entity_text.extend(_decode_entities(stc, labels, article_id, word_id))

In [None]:
# Preview only: show the first few extracted entities rather than the whole list.
entity_text[:10]

In [None]:
# newline='' hands line-ending control to the csv module, as its documentation requires;
# without it the row terminators can be translated a second time on some platforms.
with open('test_output.tsv', 'w', encoding='utf-16', newline='') as f:
    writer = csv.writer(f, delimiter='\t')
    # One row per entity: article, start_pos, end_pos, entity, entity_type.
    # The csv writer stringifies the integer fields itself.
    writer.writerows(entity_text)