# 模型构建与损失函数


**目录：**
1. BERT分类模型

2. 损失函数计算

---


In [1]:
import torch
import torch.nn as nn

from torch.utils.data import TensorDataset, RandomSampler, DataLoader

# 以BERT为预训练模型进行讲解
from transformers import BertPreTrainedModel, BertModel, BertConfig

%cd ../

C:\Users\威威的小荔枝\Desktop\第五课_代码


### 意图分类任务的MLP层

In [2]:
# Fully connected (MLP) head used for intent classification
class IntentClassifier(nn.Module):
    """Apply dropout to the input features, then project them onto the intent labels.

    Args:
        input_dim: size of the last dimension of the incoming features.
        num_intent_labels: number of output units of the linear layer.
        dropout_rate: dropout probability applied before the linear layer.
    """

    def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout_rate)
        self.linear = nn.Linear(in_features=input_dim, out_features=num_intent_labels)

    def forward(self, x):
        """Return the intent logits for ``x``."""
        # x: [batch_size, input_dim]
        dropped = self.dropout(x)
        logits = self.linear(dropped)
        return logits

### 主要的模型框架

In [3]:
class ClsBERT(BertPreTrainedModel):
    """BERT encoder with one intent-classification head on top of the pooled output.

    Args:
        config: BERT configuration; ``config.hidden_size`` sets the classifier input size.
        args: settings object; ``args.dropout_rate`` is forwarded to the classifier head.
        intent_label_lst: list of intent labels; its length is the number of classes.
    """

    def __init__(self, config, args, intent_label_lst):
        super().__init__(config)
        self.args = args
        self.num_intent_labels = len(intent_label_lst)
        # Builds the encoder from `config` only. NOTE(review): pretrained weights are
        # presumably loaded only when the model is created through `from_pretrained`.
        self.bert = BertModel(config=config)

        self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)

    def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids=None):
        """Run the encoder and the intent head, optionally computing the loss.

        Args:
            input_ids: token ids passed to the BERT encoder.
            attention_mask: attention mask passed to the BERT encoder.
            token_type_ids: segment ids passed to the BERT encoder.
            intent_label_ids: gold intent labels. When ``None`` (the default) no loss
                is computed, so the model can be called for inference.

        Returns:
            Tuple ``(loss), logits, (hidden_states), (attentions)``; ``loss`` is present
            only when ``intent_label_ids`` is given.
        """
        outputs = self.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)

        pooled_output = outputs[1]  # [CLS]

        intent_logits = self.intent_classifier(pooled_output)

        # Keep hidden states and attentions after the logits if the encoder returned them.
        outputs = (intent_logits,) + outputs[2:]

        if intent_label_ids is not None:
            if self.num_intent_labels == 1:
                # A single output unit is treated as regression.
                intent_loss_fct = nn.MSELoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
            else:
                # CrossEntropyLoss takes raw logits; no explicit softmax is applied here.
                intent_loss_fct = nn.CrossEntropyLoss()
                intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))

            outputs = (intent_loss,) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)

### 损失函数 CrossEntropyLoss
PyTorch 中的 `CrossEntropyLoss()` 相当于把 softmax -> log -> NLLLoss 三步合并在一起计算，因此传入的应是未经归一化的 logits，我们不需要自己先求 softmax。
$$L=-\sum_{i=1}^{N} y_i \log \hat{y}_i$$
其中 $N$ 是类别数；$y_i$ 是真实类别的 one-hot 分布在第 $i$ 类上的取值，只有真实类别处为 1，其他都是 0；$\hat{y}_i$ 是经由 softmax 后模型预测的第 $i$ 类的概率。

- softmax将输出数据规范化为一个概率分布。

- 然后将Softmax之后的结果取log

- 最后将取 log 后的结果输入负对数似然损失函数（NLLLoss）

### 举例查看

In [4]:
from transformers import BertConfig, DistilBertConfig, AlbertConfig
from transformers import BertTokenizer, DistilBertTokenizer, AlbertTokenizer

from bert_finetune_cls.model import ClsBERT
from bert_finetune_cls.utils import init_logger, load_tokenizer, get_intent_labels
from bert_finetune_cls.data_loader import load_and_cache_examples

# Maps a model-type key to a (config class, model class, tokenizer class) triple.
MODEL_CLASSES = {
    'bert': (BertConfig, ClsBERT, BertTokenizer),
}

# Maps a model-type key to the path that is later handed to `from_pretrained`.
MODEL_PATH_MAP = {
    'bert': './bert_finetune_cls/resources/uncased_L-2_H-128_A-2',
}

In [5]:
# Build the run settings first
class Args:
    """Bare settings holder; the three class-level fields below default to None."""

    task = None
    data_dir = None
    intent_label_file = None


# Fill in the settings used by the rest of this walkthrough.
args = Args()
args.task = "atis"
args.data_dir = "./bert_finetune_cls/data"
args.intent_label_file = "intent_label.txt"
args.max_seq_len = 50
args.model_type = "bert"
args.model_dir = "bert_finetune_cls/experiments/outputs/clsbert_0"
# Resolve the checkpoint path from the model-type key.
args.model_name_or_path = MODEL_PATH_MAP[args.model_type]

args.train_batch_size = 4  # batch size given to the DataLoader
args.dropout_rate = 0.1  # read by ClsBERT.__init__ and passed to IntentClassifier



In [6]:
# Tokenizer comes from the project helper, driven by `args`.
tokenizer = load_tokenizer(args)
# Element 0 of the MODEL_CLASSES entry is the config class.
config = MODEL_CLASSES[args.model_type][0].from_pretrained(args.model_name_or_path)

intent_label_lst = get_intent_labels(args)

# NOTE(review): the model is created through the plain constructor, not
# `ClsBERT.from_pretrained(...)`, so the encoder weights are presumably left at their
# initial values -- confirm this is intended for the loss demonstration.
model = ClsBERT(config, args, intent_label_lst)

In [7]:
# Load the training dataset.
train_dataset = load_and_cache_examples(args, tokenizer, mode="train")

# Built-in torch sampler that yields the dataset indices in random order.
train_sampler = RandomSampler(train_dataset)
# The DataLoader groups the sampled examples into batches.
train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

device = "cpu"
for step, batch in enumerate(train_dataloader):

    # Only the first two batches are inspected; stop here instead of silently
    # iterating over the remainder of the dataset.
    if step > 1:
        break

    batch = tuple(t.to(device) for t in batch)  # move every tensor of the batch to `device`
    inputs = {"input_ids": batch[0],
              "attention_mask": batch[1],
              "token_type_ids": batch[2],
              "intent_label_ids": batch[3],}

    input_ids = inputs["input_ids"]
    print("input_ids: ", input_ids)

    attention_mask = inputs["attention_mask"]
    token_type_ids = inputs["token_type_ids"]
    intent_label_ids = inputs["intent_label_ids"]

    # Same steps as ClsBERT.forward, spelled out so each intermediate can be printed.
    outputs = model.bert(input_ids, attention_mask=attention_mask,
                            token_type_ids=token_type_ids)  # sequence_output, pooled_output, (hidden_states), (attentions)

    pooled_output = outputs[1]  # [CLS]  [4 * 128]
    intent_logits = model.intent_classifier(pooled_output)
    print("intent_logits: ", intent_logits)   # [4 * 22]
    print("intent_logits: ", intent_logits.shape)

    # CrossEntropyLoss is fed the raw logits; no explicit softmax is applied beforehand.
    intent_loss_fct = nn.CrossEntropyLoss()
    intent_loss = intent_loss_fct(intent_logits.view(-1, model.num_intent_labels), intent_label_ids.view(-1))
    print("intent_loss: ", intent_loss)
    
    
    

input_ids:  tensor([[ 101, 1045, 2342, 1037, 3462, 2006, 2250, 2710, 2013, 4361, 2000, 2624,
         5277, 2007, 1037, 3913, 7840, 1999, 5887,  102,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0],
        [ 101, 2425, 2033, 2055, 1996, 2598, 5193, 1999, 3190,  102,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0],
        [ 101, 1045, 2215, 2035, 7599, 2013, 5865, 2000, 2899, 5887, 2006, 9432,
          102,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,   