## CausalDecoder -1 (不是transformer的decoder)


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F


In [49]:
class CausalDecoder(nn.Module):
    def __init__(self, hidden_num, head_num, dropout=0.1):
        super(CausalDecoder, self).__init__()
        self.hidden_num = hidden_num
        self.head_num = head_num
        self.dropout = dropout
        self.head_dim = hidden_num // head_num  
        
        # for ffn
        self.layernorm_ffn = nn.LayerNorm(hidden_num, eps=1e-6)
        self.up_prj = nn.Linear(hidden_num, hidden_num*4)
        self.down_prj = nn.Linear(hidden_num*4, hidden_num)
        self.act_fn = nn.ReLU()
        self.drop_ffn = nn.Dropout(dropout)

        # for mutiheadattention
        self.q_prj = nn.Linear(hidden_num, hidden_num)
        self.k_prj = nn.Linear(hidden_num, hidden_num)
        self.v_prj = nn.Linear(hidden_num, hidden_num)
        self.o_prj = nn.Linear(hidden_num, hidden_num)
        self.drop_attn = nn.Dropout(dropout)
        self.layernorm_attn = nn.LayerNorm(hidden_num, eps=1e-6)

    def atten_output(self, Q, K, V, mask):
        #  [batch, head_num, seq_len, head_dim]

        # attention
        K_ = K.transpose(2, 3)
        # K_.shape = [batch, head_num, head_dim, seq_len]
        # Q.shape = [batch, head_num, seq_len, head_dim]
        atten_weight = torch.matmul(Q, K_) / (self.head_dim ** 0.5)
        # atten_weight.shape = [batch, head_num, seq_len, seq_len]
        print(f"K after projection: {K.shape}")
        print(f"K after view: {K.shape}")
        print(f"K after first permute: {K.shape}")
        print(f"K after final permute: {K_.shape}")
        
        #mask
        if mask is not None:
            atten_mask = mask.tril()
            atten_weight = atten_weight.masked_fill(atten_mask == 0, -1e9)
        else:
            atten_mask = torch.ones_like(atten_weight).tril()
            # [batch, head_num, seq_len, seq_len]
            atten_weight = atten_weight.masked_fill(atten_mask == 0, -1e9)
        atten_score = F.softmax(atten_weight, dim=-1)
        atten_score = self.drop_attn(atten_score)

        # output
        mid_output = torch.matmul(atten_score, V)
        # mid_output.shape = [batch, head_num, seq_len, head_dim]

        mid_output = mid_output.permute(0, 2, 1, 3).contiguous()
        # mid_output.shape = [batch, seq_len, head_num, head_dim]

        batch_size, seq_len, _, _ = mid_output.size()
        print('mid_output-已经转化：\n', mid_output.shape)

        mid_output = mid_output.view(batch_size, seq_len, -1)
        # mid_output.shape = [batch, seq_len, head_num * head_dim]


        print('mid_output.shape:\n', mid_output.shape)
        output = self.o_prj(mid_output)
        # output.shape = [batch, seq_len, hidden_num]

        return output

    def atten_block(self, x, mask=None):
        batch_size, seq_len, _ = x.size()
        # [batch, seq_len, hidden_num]

        Q = self.q_prj(x).view(batch_size, seq_len, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        K = self.k_prj(x).view(batch_size, seq_len, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        V = self.v_prj(x).view(batch_size, seq_len, self.head_num, self.head_dim).permute(0, 2, 1, 3)
        # Q.shape = [batch, head_num, seq_len, head_dim]
        print('Q.shape:\n' ,Q.shape)

        output_atten = self.atten_output(Q, K, V, mask)
        print('output_atten.shape:\n', output_atten.shape)
        output_atten = self.layernorm_attn(output_atten + x)


        return output_atten
    
    def ffn(self, x):
        up = self.up_prj(self.act_fn(x))
        down = self.down_prj(up)
        output = self.drop_ffn(down)
        output_ffn = self.layernorm_ffn(output + x)
        return output_ffn
    
    def forward(self, x, mask=None):
        output_atten = self.atten_block(x, mask)
        output_ffn = self.ffn(output_atten)
        return output_ffn

# 测试代码
x = torch.randn(3, 4, 128)  # [batch_size, seq_len, dim]
# mask.shape = [batch, head_num, seq_len, seq_len]
b = torch.ones(3, 4, 4)
mask = b.unsqueeze(1).repeat(1, 8, 1, 1)
# mask.shape = [3, 8, 4, 4]
decoder = CausalDecoder(128, 8)
output = decoder(x, mask)
print( '最后的输出：\n',output.shape)

Q.shape:
 torch.Size([3, 8, 4, 16])
K after projection: torch.Size([3, 8, 4, 16])
K after view: torch.Size([3, 8, 4, 16])
K after first permute: torch.Size([3, 8, 4, 16])
K after final permute: torch.Size([3, 8, 16, 4])
mid_output-已经转化：
 torch.Size([3, 4, 8, 16])
mid_output.shape:
 torch.Size([3, 4, 128])
output_atten.shape:
 torch.Size([3, 4, 128])
最后的输出：
 torch.Size([3, 4, 128])
