# bertEncode模型分解
* 语言：pytorch

In [2]:
import torch
from torch import nn
import math

## 1、embedding 层
#### nn.Embedding
    
    ** 参数：vocab_size,向量长度,padding_idx（填充id）
    
#### nn.LayerNorm(对最后一维(-1维)进行层归一化)

    ** 参数：hidden_size（被归一化的最后一维的大小）、eps（加在方差上的平滑项，防止除零）
    
#### nn.Dropout
    ** 参数：丢弃率

#### self.register_buffer("a",aa)
    ** 参数：变量名、aa(内容)
    ** 作用：注册一个变量a,值为aa,赋给self对象
 
#### tensor.expand(2,2)
    ** 参数：目标shape（多个整数或一个元组；-1表示该维度保持不变）
    ** 功能：将大小为1的维度广播到指定大小（不拷贝数据）
    
#### tensor.permute(1,0,2)
    ** 参数：新的维度顺序
    ** 功能：按给定顺序重排维度（转置的推广）
 
#### tensor.contiguous()
    ** 参数：无
    ** 功能：保证tensor在内存中的存储是连续的
    ** 备注：view()是建立在contiguous的变量之上，在使用transpose、permute之后，必须加一个contiguous,才能使用view

In [3]:
class BertEmbedding(nn.Module):
    """Input embedding stage: token + position + segment lookups, then LayerNorm and dropout.

    Reads from ``config``: ``vocab_size``, ``hidden_size``, ``pad_token_id``,
    ``max_position_embedding``, ``layer_norm_eps`` and ``dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embedding

        # Three lookup tables that all map ids to vectors of size `hidden`.
        self.token_embedding = nn.Embedding(
            config.vocab_size, hidden, padding_idx=config.pad_token_id
        )
        # NOTE: the attribute name (including its spelling) is part of the state-dict keys.
        self.position_embdedding = nn.Embedding(max_positions, hidden)
        # The segment table has three rows.
        self.segment_embedding = nn.Embedding(3, hidden)

        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout(config.dropout_prob)

        # Position ids 0..max_positions-1 stored as a buffer of shape [1, max_positions].
        all_positions = torch.arange(max_positions).expand((1, -1))
        self.register_buffer("position_ids", all_positions)

    def forward(self, seq, labels):
        """Embed ``seq`` and return the normalized, dropped-out sum of the three embeddings.

        Args:
            seq: token ids; dimension 1 is used as the sequence length.
            labels: ids looked up in the segment embedding table; the result
                must be addable to the token embeddings.
        """
        seq_len = seq.size(1)
        positions = self.position_ids[:, :seq_len]

        token_part = self.token_embedding(seq)
        position_part = self.position_embdedding(positions)
        segment_part = self.segment_embedding(labels)

        summed = token_part + position_part + segment_part
        normalized = self.LayerNorm(summed)
        return self.dropout(normalized)

## 2、Encode 层
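#### nn.ModuleList
    * 参数：由nn.Module组成的list
    * 功能：用法与list类似（可索引、可遍历），但其中的子模块会被注册到父模块上；它本身没有forward，不能直接调用，需要逐层遍历
#### nn.Linear
    * 参数：输入特征数（上一层size）、输出特征数（下一层size）
    * 功能：线性变换 y = xW^T + b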

In [4]:
class BertEncode(nn.Module):
    """Stack of ``config.layer_nums`` ``BertLayer`` modules applied one after another."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        # n identical layers, registered as submodules through ModuleList.
        self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.layer_nums)])

    def forward(self, x):
        """Run ``x`` through every layer in order and return the final hidden states.

        ``nn.ModuleList`` is only a container and cannot be called, so the
        layers are iterated explicitly. A layer that returns a tuple is
        expected to put its hidden states first; only that element is passed
        to the next layer.
        """
        hidden_states = x
        for layer in self.layers:
            layer_outputs = layer(hidden_states)
            if isinstance(layer_outputs, tuple):
                hidden_states = layer_outputs[0]
            else:
                hidden_states = layer_outputs
        return hidden_states

In [5]:
class BertLayer(nn.Module):
    """One encoder layer: an attention sub-block followed by a feed-forward sub-block.

    Reads from ``config``: ``chunk_size_feed_forward``, ``is_decoder`` and
    ``add_cross_attention`` (plus whatever the sub-modules read).
    """

    def __init__(self, config):
        super().__init__()
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.is_decoder = config.is_decoder
        # Attention followed by its own add & norm.
        self.attention = BertAttention(config)
        self.add_cross_attention = config.add_cross_attention
        if self.add_cross_attention:
            # A decoder gets one extra attention module.
            assert self.is_decoder, f"{self} should be used as a decoder model if cross attention is added"
            self.crossattention = BertAttention(config)
        # Feed-forward expansion, then projection back with add & norm.
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

    def forward(self, hidden_states, attention_mask=None, head_mask=None, encoder_hidden_states=None, output_attention=False):
        """Apply attention and the feed-forward block.

        Returns a tuple whose first element is the layer output, followed by
        every element after the first of what ``self.attention`` returned.
        ``encoder_hidden_states`` is accepted but not used by this method.
        """
        attn_results = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attention=output_attention,
        )
        attended = attn_results[0]
        extras = attn_results[1:]

        # The cross-attention branch below is not used by BERT; generative
        # pre-trained models use it, and it is one of the places where BERT
        # and GPT differ. Kept here for reference only:
        #
        #     if self.is_decoder and encoder_hidden_states is not None:
        #         assert hasattr(
        #             self, "crossattention"
        #         ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
        #         cross_attention_outputs = self.crossattention(
        #             attention_output,
        #             attention_mask,
        #             head_mask,
        #             encoder_hidden_states,
        #             encoder_attention_mask,
        #             output_attentions,
        #         )
        #         attention_output = cross_attention_outputs[0]
        #         outputs = outputs + cross_attention_outputs[1:]  # add cross attentions if we output attention weights
        #
        # The chunking helper of the reference implementation is not needed
        # for BERT either; it merely wraps BertIntermediate and BertOutput, so
        # feed_forward_chunk is called directly instead of:
        #
        #     layer_output = apply_chunking_to_forward(
        #         self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        #     )
        return (self.feed_forward_chunk(attended),) + extras

    def feed_forward_chunk(self, attention_output):
        """Run BertIntermediate, then BertOutput with ``attention_output`` as the residual input."""
        expanded = self.intermediate(attention_output)
        return self.output(expanded, attention_output)

In [6]:
#attention 内部
class BertAttention(nn.Module):
    """Attention block: ``BertSelfAttention`` followed by ``BertSelfOutput``."""

    def __init__(self, config):
        super().__init__()
        self.self_attention = BertSelfAttention(config)
        self.output = BertSelfOutput(config)
        self.pruned_heads = set()

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None, output_attention=False):
        """Run self-attention and the output sub-block.

        Returns:
            A tuple whose first element is the output of ``self.output``; any
            further elements returned by the self-attention module (the
            attention probabilities when ``output_attention`` is true) are
            appended unchanged.
        """
        self_outputs = self.self_attention(
            hidden_states,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attention,
        )
        # The self-attention module returns a 1-tuple when attention weights
        # are not requested, so index instead of unpacking into two names.
        attention_output = self.output(self_outputs[0], hidden_states)
        return (attention_output,) + tuple(self_outputs[1:])

#attention部分
class BertSelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Reads from ``config``: ``hidden_size``, ``num_attention_heads`` and
    ``dropout_prob``.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``.
    """

    def __init__(self, config):
        super().__init__()
        # The hidden size must split evenly across the heads.
        if config.hidden_size % config.num_attention_heads != 0:
            raise ValueError("不能整除")
        self.num_attention_heads = config.num_attention_heads
        self.attention_head_size = config.hidden_size // config.num_attention_heads
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        # Linear projections of the input.
        self.query = nn.Linear(config.hidden_size, self.all_head_size)
        self.value = nn.Linear(config.hidden_size, self.all_head_size)
        self.key = nn.Linear(config.hidden_size, self.all_head_size)

        self.normalize = nn.Softmax(dim=-1)

        self.dropout = nn.Dropout(config.dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dimension into heads: [..., seq, all_head] -> [bs, heads, seq, head_size]."""
        # Keep every leading dimension and replace only the last one.
        split_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
        x = x.view(*split_shape)
        return x.permute(0, 2, 1, 3)

    def forward(self, hidden_states, attention_mask=None, head_mask=None,
                encoder_hidden_states=None, encoder_attention_mask=None, output_attention=False):
        """Compute attention over ``hidden_states``.

        When ``encoder_hidden_states`` is given, keys and values are projected
        from it and ``encoder_attention_mask`` replaces ``attention_mask``.
        ``attention_mask`` is added to the raw scores; ``head_mask`` multiplies
        the attention probabilities.

        Returns:
            ``(context,)`` or, when ``output_attention`` is true,
            ``(context, attention_probs)``.
        """
        all_query = self.query(hidden_states)
        if encoder_hidden_states is not None:
            # Decoder-style use: attend over the encoder states.
            all_key = self.key(encoder_hidden_states)
            all_value = self.value(encoder_hidden_states)
            attention_mask = encoder_attention_mask
        else:
            all_key = self.key(hidden_states)
            all_value = self.value(hidden_states)

        # Split into heads.
        query_layer = self.transpose_for_scores(all_query)
        key_layer = self.transpose_for_scores(all_key)
        value_layer = self.transpose_for_scores(all_value)

        # Scaled dot-product scores: Q @ K^T / sqrt(d_k).
        d_k = query_layer.size(-1)
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) / math.sqrt(d_k)

        # Additive mask, applied out of place.
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Normalize over the key dimension.
        attention_probs = self.normalize(attention_scores)

        # Apply the dropout module created in __init__ to the probabilities.
        attention_probs = self.dropout(attention_probs)

        # head_mask can selectively mask individual heads.
        if head_mask is not None:
            attention_probs = attention_probs * head_mask

        context_layer = torch.matmul(attention_probs, value_layer)

        # Merge heads back: [bs, heads, seq, head_size] -> [bs, seq, all_head_size].
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        merged_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*merged_shape)

        if output_attention:
            return (context_layer, attention_probs)
        return (context_layer,)
        
#add & norm
class BertSelfOutput(nn.Module):
    """Add & norm after attention: dense projection, dropout, residual add, LayerNorm.

    Reads from ``config``: ``hidden_size``, ``dropout_prob`` and ``layer_norm_eps``.
    """

    def __init__(self, config):
        # nn.Module must be initialized before submodules can be assigned.
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.dropout_prob)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

    def forward(self, attention_outputs, hidden_state):
        """Return ``LayerNorm(dropout(dense(attention_outputs)) + hidden_state)``."""
        attention_outputs = self.dense(attention_outputs)
        attention_outputs = self.dropout(attention_outputs)
        return self.layer_norm(attention_outputs + hidden_state)

In [7]:
#全连接+激活函数
# Activation name -> nn.Module class. BertIntermediate instantiates the class
# when it resolves a name, so the values stay plain classes.
ACT2FN = {
    "relu": nn.ReLU,
    "gelu": nn.GELU,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
}


class BertIntermediate(nn.Module):
    """Feed-forward expansion: dense projection to ``intermediate_size`` plus activation.

    ``config.hidden_act`` is either a key of ``ACT2FN`` or a callable applied
    directly to the projected tensor.

    Raises:
        KeyError: if ``config.hidden_act`` is a string missing from ``ACT2FN``.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
        if isinstance(config.hidden_act, str):
            activation = ACT2FN[config.hidden_act]
        else:
            activation = config.hidden_act
        # A Module *class* has to be instantiated first; calling the class with
        # a tensor would build a module instead of applying the activation.
        if isinstance(activation, type) and issubclass(activation, nn.Module):
            activation = activation()
        self.intermediate_act_fn = activation

    def forward(self, hidden_states):
        """Return ``activation(dense(hidden_states))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.intermediate_act_fn(hidden_states)
        return hidden_states

In [8]:
# add & norm
class BertOutput(nn.Module):
    """Add & norm after the feed-forward expansion.

    Projects from ``intermediate_size`` back to ``hidden_size``, applies
    dropout, adds the residual input and normalizes. Reads from ``config``:
    ``intermediate_size``, ``hidden_size``, ``layer_norm_eps`` and ``dropout_prob``.
    """

    def __init__(self, config):
        # nn.Module must be initialized before submodules can be assigned.
        super().__init__()
        self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # The dropout layer class is `nn.Dropout`; `nn.dropout` does not exist.
        self.dropout = nn.Dropout(config.dropout_prob)

    def forward(self, hidden_states, input_tensor):
        """Return ``LayerNorm(input_tensor + dropout(dense(hidden_states)))``."""
        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        return self.layer_norm(input_tensor + hidden_states)

## 3、输出层

In [9]:
#分类层 
class BertPooler(nn.Module):
    """Pooling head: dense layer plus tanh applied to the position-0 hidden state."""

    def __init__(self, config):
        super().__init__()
        size = config.hidden_size
        self.dense = nn.Linear(size, size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states):
        """Return ``tanh(dense(hidden_states[:, 0]))``.

        Index 0 along dimension 1 is selected; presumably this is the [CLS]
        token of each sequence — confirm against the caller.
        """
        first_token = hidden_states[:, 0]
        return self.activation(self.dense(first_token))