In [1]:
import torch
import torch.nn as nn
from crf import CRF
from kobert.pytorch_kobert import get_pytorch_kobert_model
from transformers import DistilBertModel
import config

class BERT_CRF_Joint(nn.Module):
    def __init__(self, config=config, bert=None, distill=False):
        super(BERT_CRF_Joint, self).__init__()
        
        
        #별도의 BERT모델을 지정하지 않으면 SKT KoBERT를 Default로 지정한다. 
        self.bert = bert
        self.distill=distill
        if bert is None:
            if self.distill == True:
                self.bert = DistilBertModel.from_pretrained('monologg/distilkobert')
            else:
                self.bert, self.vocab  = get_pytorch_kobert_model()
                
            for param in self.bert.parameters():
                param.requires_grad = True
            
        
        self.dropout = nn.Dropout(config.dropout)
        self.crf_linear = nn.Linear(config.hidden_size, config.num_entity)
        self.intent_classifier = nn.Linear(config.hidden_size, config.num_intent)
        self.bilstm  = nn.LSTM(config.hidden_size, config.hidden_size //2, 
                               batch_first=True, bidirectional=True )
        self.crf = CRF(num_tags=config.num_entity, batch_first=True)
    
    
    #Sentence의 길이만큼만 Attention을 취하기 위해 Mask를 생성한다.
    def get_attention_mask(self, input_ids, valid_length):
        attention_mask = torch.zeros_like(input_ids)
        for i, v in enumerate(valid_length):
            attention_mask[i][:v] = 1
        return attention_mask.float()
    
    def forward(self, input_ids, valid_length, token_type_ids, entity=None, intent=None):
        attention_mask = self.get_attention_mask(input_ids, valid_length)
        
        #all_encoder_layers는 BERT의 output representation 전부이고
        #poold_output은 CLS token의 representation 값이다.
        #기본 kobert와 distill kobert의 output형태가 다르기 때문에 분기처리하였다.
        if self.distill==True:
            outputs = self.bert(input_ids=input_ids.long(), 
                                attention_mask=attention_mask) 
            
            all_encoder_layers, pooled_output = outputs[0], outputs[0][:,0,:]
            
        else:
            all_encoder_layers, pooled_output = self.bert(input_ids=input_ids.long(),
                                                      token_type_ids=token_type_ids,
                                                      attention_mask=attention_mask)

        cls_out = pooled_output
        cls_out_drop = self.dropout(cls_out)
        logits = self.intent_classifier(cls_out_drop)
        
        # Entity on CRF
        last_encoder_layer = all_encoder_layers
        drop = self.dropout(last_encoder_layer)
        output, hc = self.bilstm(drop)
        linear = self.crf_linear(output)
        tag_seq = self.crf.decode(linear)

        # For training
        if entity is not None:
            log_likelihood = self.crf(linear, entity)       
            return log_likelihood, tag_seq, logits
        
        # For inference
        else: 
            confidence = self.crf.compute_confidence(linear, tag_seq)
            return tag_seq, confidence, logits
               

# TEST CODE

In [2]:
from tokenization_kobert import KoBertTokenizer
tokenizer = KoBertTokenizer.from_pretrained('monologg/kobert') # monologg/distilkobert도 동일
batch=[]
valid_length=[]
batch.append(tokenizer.encode("안녕하세요. 반갑습니다. 테스트코드를 작성합니다."))
valid_length.append(len(tokenizer.encode("안녕하세요. 반갑습니다. 테스트코드를 작성합니다.")))
batch.append(tokenizer.encode("오늘은 버트를 이용하여 테스트를 해보겠습니다."))
valid_length.append(len(tokenizer.encode("오늘은 버트를 이용하여 테스트를 해보겠습니다.")))
batch.append(tokenizer.encode("조인트 모델이 잘 돌아가는지 궁금하군요."))
valid_length.append(len(tokenizer.encode("조인트 모델이 잘 돌아가는지 궁금하군요.")))


maxlen=max(valid_length)
input_ids = [sen+[0]*(maxlen-len(sen)) for sen in batch]
input_ids= torch.tensor(input_ids)

In [3]:
joint = BERT_CRF_Joint(distill=True)
joint(input_ids, valid_length, token_type_ids=None)

([[53,
   26,
   174,
   244,
   128,
   134,
   130,
   128,
   165,
   161,
   233,
   103,
   245,
   128,
   165,
   124],
  [233,
   128,
   161,
   256,
   248,
   128,
   165,
   234,
   66,
   66,
   214,
   150,
   190,
   211,
   173,
   124],
  [53, 101, 128, 198, 103, 128, 66, 16, 103, 244, 128, 173, 55, 244, 146, 87]],
 0.004704668186604977,
 tensor([[ 0.0967,  0.0642,  0.3189, -0.1460,  0.2150,  0.2407, -0.2924,  0.0133,
          -0.0249, -0.0041, -0.1298, -0.3748,  0.0901, -0.1452,  0.0431,  0.0045,
          -0.3683, -0.4025,  0.2652, -0.1089, -0.0740,  0.3377,  0.4195,  0.2688,
          -0.1537, -0.6224, -0.5605,  0.2107, -0.1776,  0.3735, -0.5674,  0.3486,
           0.1657, -0.4137,  0.0564,  0.1058,  0.3253,  0.5036,  0.6185, -0.0107,
           0.1889, -0.6330,  0.2323, -0.0898,  0.0166,  0.1886, -0.3348, -0.0582,
           0.3292, -0.3958, -0.3806,  0.1372, -0.1704, -0.1735, -0.0093, -0.1659,
          -0.0498, -0.1903,  0.1411, -0.0853, -0.0279,  0.0663, -0.27