In [5]:
import torch
import torch.nn as nn
from crf import CRF
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import DistilBertModel

import constant as config
class BERT_CRF_Joint(nn.Module):
    def __init__(self, config=config, bert=None, distill=False):
        super(BERT_CRF_Joint, self).__init__()
        
        
        #별도의 BERT모델을 지정하지 않으면 SKT KoBERT를 Default로 지정한다. 
        self.bert = bert
        self.distill=distill
        if bert is None:
            if self.distill == True:
                self.bert = DistilBertModel.from_pretrained('monologg/distilkobert')
            else:
                self.bert, self.vocab  = get_pytorch_kobert_model()
                
            for param in self.bert.parameters():
                param.requires_grad = True
            
        
        self.dropout = nn.Dropout(config.dropout)
        self.crf_linear = nn.Linear(config.hidden_size, config.num_entity)
        self.intent_classifier = nn.Linear(config.hidden_size, config.num_intent)
        self.bilstm  = nn.LSTM(config.hidden_size, config.hidden_size //2, 
                               batch_first=True, bidirectional=True )
        self.crf = CRF(num_tags=config.num_entity, batch_first=True)
    
    
    #Sentence의 길이만큼만 Attention을 취하기 위해 Mask를 생성한다.
    def get_attention_mask(self, input_ids, valid_length):
        attention_mask = torch.zeros_like(input_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()
    
    def forward(self, input_ids, valid_length, token_type_ids, entity=None, intent=None):
        attention_mask = self.get_attention_mask(input_ids, valid_length)
        
        #all_encoder_layers는 BERT의 output representation 전부이고
        #poold_output은 CLS token의 representation 값이다.
        #기본 kobert와 distill kobert의 output형태가 다르기 때문에 분기처리하였다.
        if self.distill==True:
            outputs = self.bert(input_ids=input_ids.long(), 
                                attention_mask=attention_mask) 
            
            all_encoder_layers, pooled_output = outputs[0], outputs[0][:,0,:]
            
        else:
            all_encoder_layers, pooled_output = self.bert(input_ids=input_ids.long(),
                                                      token_type_ids=token_type_ids,
                                                      attention_mask=attention_mask)

        cls_out = pooled_output
        cls_out_drop = self.dropout(cls_out)
        logits = self.intent_classifier(cls_out_drop)
        
        # Entity on CRF
        last_encoder_layer = all_encoder_layers
        drop = self.dropout(last_encoder_layer)
        output, hc = self.bilstm(drop)
        linear = self.crf_linear(output)
        tag_seq = self.crf.decode(linear)

        # For training
        if entity is not None:
            log_likelihood = self.crf(linear, entity)       
            return log_likelihood, tag_seq, logits
        
        # For inference
        else: 
            confidence = self.crf.compute_confidence(linear, tag_seq)
            return tag_seq, confidence, logits
               

# TEST CODE

In [8]:
from tokenization_kobert import KoBertTokenizer
tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert') # monologg/distilkobert도 동일
batch=[]
valid_length=[]
batch.append(tokenizer.encode("안녕하세요. 반갑습니다. 테스트코드를 작성합니다."))
valid_length.append(len(tokenizer.encode("안녕하세요. 반갑습니다. 테스트코드를 작성합니다.")))
batch.append(tokenizer.encode("오늘은 버트를 이용하여 테스트를 해보겠습니다."))
valid_length.append(len(tokenizer.encode("오늘은 버트를 이용하여 테스트를 해보겠습니다.")))
batch.append(tokenizer.encode("조인트 모델이 잘 돌아가는지 궁금하군요."))
valid_length.append(len(tokenizer.encode("조인트 모델이 잘 돌아가는지 궁금하군요.")))


maxlen=max(valid_length)
input_ids = [sen+[0]*(maxlen-len(sen)) for sen in batch]
input_ids= torch.tensor(input_ids)

In [9]:
joint = BERT_CRF_Joint(distill=True)
joint(input_ids, valid_length, token_type_ids=None)

([[11, 17, 1, 212, 212, 212, 110, 219, 92, 1, 212, 118, 215, 53, 118, 212],
  [212, 212, 118, 93, 92, 1, 212, 212, 212, 118, 194, 1, 212, 212, 212, 212],
  [230, 73, 35, 212, 93, 241, 82, 93, 11, 1, 194, 56, 92, 191, 92, 183]],
 0.00465765967965126,
 tensor([[-0.3264,  0.1196,  0.4184, -0.1342, -0.4970, -0.0875,  0.0306,  0.2577,
           0.1859,  0.2623,  0.1183,  0.0730, -0.4708,  0.1127, -0.1290,  0.1809,
           0.3338,  0.6026,  0.3228,  0.5081, -0.1881, -0.2677, -0.1861, -0.1454,
           0.7420,  0.0959,  0.1700, -0.0903, -0.0798,  0.1567,  0.0990,  0.5753,
           0.1216,  0.1859,  0.2001, -0.5072, -0.2218,  0.3213, -0.1307, -0.0891,
           0.2631,  0.3012, -0.6117,  0.3088,  0.2558, -0.8699,  0.3474,  0.3200,
           0.0498, -0.0896, -0.0058,  0.1389,  0.4176,  0.0896,  0.1912, -0.1358,
           0.0468,  0.3852,  0.2406,  0.1795,  0.5275, -0.1891,  0.6130,  0.4632,
          -0.1768,  0.1506,  0.0659, -0.0788, -0.2129,  0.0477, -0.3145, -0.0031],
         [-