# Causal LM: masked (causal) self-attention + FFN

![decoder](./../image/Decoder.png)

In [1]:
import math
import torch
import torch.nn as nn 

class SimpleDecoderLayer(nn.Module):
    """One post-norm decoder block: causal multi-head self-attention followed by an FFN.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        hidden_dim: model width ``h``; must be divisible by ``head_num``.
        head_num: number of attention heads.
        attention_dropout_rate: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``hidden_dim`` is not divisible by ``head_num``.
    """

    def __init__(self, hidden_dim: int, head_num: int, attention_dropout_rate: float = 0.1):
        super().__init__()
        if hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num  # heads split hidden_dim evenly

        # Multi-head attention projections and its post-norm.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout_att = nn.Dropout(attention_dropout_rate)
        self.att_ln = nn.LayerNorm(hidden_dim, eps=1e-7)

        # Position-wise FFN with a 4x expansion (SwiGLU-style FFNs often use 8/3 instead).
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)
        self.act_fn = nn.GELU()
        self.dropout_ffn = nn.Dropout(1e-1)
        self.ffn_ln = nn.LayerNorm(hidden_dim, eps=1e-7)

    def attention_layer(self, query, key, value, attention_mask=None):
        """Causal scaled dot-product attention followed by the output projection.

        Args:
            query, key, value: tensors of shape (b, head_num, s, head_dim).
            attention_mask: optional mask broadcastable to (b, head_num, s, s), e.g. a
                full (b, head_num, s, s) mask or a compact (b, 1, 1, s) key-padding mask.
                Non-zero entries mark positions that may be attended to. It is always
                combined with a causal (lower-triangular) mask.

        Returns:
            Tensor of shape (b, s, hidden_dim).
        """
        scores = query @ key.transpose(-2, -1) / math.sqrt(self.head_dim)

        # Causal mask: query i may only attend to keys j <= i.
        q_len, k_len = scores.shape[-2:]
        allowed = torch.ones(q_len, k_len, device=scores.device).tril().bool()
        if attention_mask is not None:
            # Combine by broadcasting instead of calling tril() on the user mask, so
            # compact padding masks such as (b, 1, 1, s) are also handled correctly.
            allowed = allowed & (attention_mask != 0)
        scores = scores.masked_fill(~allowed, float("-inf"))

        attention_weight = torch.softmax(scores, dim=-1)
        # A query row whose keys are all masked softmaxes to NaN; make it attend to nothing
        # instead of propagating NaN through every later layer.
        attention_weight = torch.nan_to_num(attention_weight, nan=0.0)
        attention_weight = self.dropout_att(attention_weight)
        mid_out = attention_weight @ value  # (b, head_num, s, head_dim)

        # Concatenate heads: (b, head_num, s, head_dim) -> (b, s, head_num * head_dim).
        b, _, s, _ = mid_out.size()
        mid_out = mid_out.transpose(1, 2).contiguous().view(b, s, -1)
        return self.o(mid_out)

    def mha(self, x, mask=None):
        """Multi-head self-attention sub-layer with residual connection and post-LayerNorm.

        Args:
            x: input of shape (b, s, hidden_dim).
            mask: optional attention mask; see ``attention_layer``.

        Returns:
            ``att_ln(x + attention(x))`` with shape (b, s, hidden_dim).
        """
        batch, seq_len, _ = x.size()

        def split_heads(t):
            # (b, s, h) -> (b, head_num, s, head_dim)
            return t.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        q_state = split_heads(self.q(x))
        k_state = split_heads(self.k(x))
        v_state = split_heads(self.v(x))

        output = self.attention_layer(q_state, k_state, v_state, mask)
        # Post-norm residual, mirroring ffn(): without it att_ln is unused and the
        # input signal is discarded after every attention sub-layer.
        return self.att_ln(x + output)

    def ffn(self, x):
        """Feed-forward sub-layer: ``ffn_ln(x + dropout(down(gelu(up(x)))))``."""
        hidden = self.act_fn(self.up_proj(x))
        down = self.dropout_ffn(self.down_proj(hidden))
        return self.ffn_ln(x + down)

    def forward(self, x, attention_mask=None):
        """Apply the attention then FFN sub-layers; (b, s, hidden_dim) -> (b, s, hidden_dim)."""
        x = self.mha(x, attention_mask)
        return self.ffn(x)
    
class Decoder(nn.Module):
    """Toy causal decoder: token embedding -> stacked SimpleDecoderLayer -> vocab projection.

    The defaults reproduce the original hard-coded configuration
    (vocab 12, width 64, 8 heads, 5 layers, shape printing on).

    Args:
        vocab_size: number of token ids (embedding rows and output classes).
        hidden_dim: model width passed to every layer.
        head_num: attention heads per layer.
        layer_num: number of stacked decoder layers.
        verbose: print activation shapes during ``forward`` (debug aid).
    """

    def __init__(self, vocab_size: int = 12, hidden_dim: int = 64, head_num: int = 8,
                 layer_num: int = 5, verbose: bool = True):
        super().__init__()
        self.verbose = verbose
        # Build modules in the original order (layers, embedding, output head) so that
        # parameter initialisation consumes the RNG identically.
        self.layer_list = nn.ModuleList(
            [SimpleDecoderLayer(hidden_dim, head_num) for _ in range(layer_num)]
        )
        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, mask=None):
        """Return per-position probabilities over the vocabulary.

        Args:
            x: integer token ids of shape (b, s).
            mask: optional attention mask forwarded unchanged to every layer.

        Returns:
            Softmax probabilities of shape (b, s, vocab_size). For training with
            ``nn.CrossEntropyLoss`` the pre-softmax logits would be needed instead.
        """
        x = self.emb(x)  # (b, s) -> (b, s, hidden_dim)
        for i, layer in enumerate(self.layer_list):
            x = layer(x, mask)
            if self.verbose:
                print(f"{i}th x shape", x.shape)

        if self.verbose:
            print(x.shape)
        logits = self.out(x)
        return torch.softmax(logits, dim=-1)
        
# Toy batch: 3 sequences of 4 random token ids drawn from [0, 12).
x = torch.randint(low=0, high=12, size=(3, 4))
print("x shape:", x.shape)
net = Decoder()
# Per-sequence key mask (1 = attendable, 0 = masked), broadcast to
# (batch=3, heads=8, query=4, key=4): (3, 4) -> (3, 1, 1, 4) -> repeat -> (3, 8, 4, 4).
key_padding = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)
mask = key_padding[:, None, None, :].repeat(1, 8, 4, 1)
print("mask:\r\n", mask)
net(x, mask)


x shape: torch.Size([3, 4])
mask:
 tensor([[[[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]],

         [[1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1],
          [1, 1, 1, 1]]],


        [[[1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0]],

         [[1, 1, 0, 0],
          [1, 1, 0, 0],
          [1, 1, 0, 0],
         

tensor([[[0.0603, 0.1562, 0.0590, 0.0412, 0.1855, 0.0615, 0.1033, 0.0349,
          0.1118, 0.0742, 0.0667, 0.0453],
         [0.0366, 0.1617, 0.0741, 0.0388, 0.2317, 0.0572, 0.0811, 0.0302,
          0.0665, 0.0848, 0.1008, 0.0365],
         [0.0418, 0.1505, 0.0955, 0.0324, 0.2099, 0.0725, 0.0871, 0.0310,
          0.0683, 0.0802, 0.0934, 0.0373],
         [0.0404, 0.1444, 0.0796, 0.0364, 0.2727, 0.0623, 0.0735, 0.0293,
          0.0661, 0.0742, 0.0886, 0.0325]],

        [[0.0464, 0.0750, 0.1038, 0.0929, 0.2257, 0.0752, 0.0397, 0.0793,
          0.0719, 0.0613, 0.0437, 0.0852],
         [0.0468, 0.0684, 0.1046, 0.1067, 0.2086, 0.0758, 0.0434, 0.0797,
          0.0793, 0.0527, 0.0456, 0.0885],
         [0.0445, 0.0625, 0.0895, 0.0616, 0.2889, 0.0768, 0.0429, 0.0754,
          0.0620, 0.0664, 0.0350, 0.0945],
         [0.0498, 0.0833, 0.0996, 0.1033, 0.2040, 0.0784, 0.0426, 0.0722,
          0.0733, 0.0513, 0.0512, 0.0910]],

        [[0.0940, 0.0516, 0.0455, 0.1461, 0.0885, 0.0735, 0.