# Causal LM: self-attention + ffn

![decoder](./../image/Decoder.png)

In [10]:
import math
import torch
import torch.nn as nn 

class SimpleDecoderLayer(nn.Module):
    """One post-norm decoder layer: causal multi-head self-attention, then an FFN.

    Both sub-layers follow the same pattern: ``LayerNorm(x + sublayer(x))``.

    Args:
        hidden_dim: Width of the input/output features.
        head_num: Number of attention heads; must divide ``hidden_dim`` evenly.
        attention_dropout_rate: Dropout probability applied to the attention
            weights after the softmax.

    Raises:
        ValueError: If ``head_num`` is not positive or does not divide
            ``hidden_dim``.
    """

    def __init__(self, hidden_dim: int, head_num: int, attention_dropout_rate: float = 0.1):
        super().__init__()
        # Fail early with a clear message instead of a shape error inside .view().
        if head_num <= 0 or hidden_dim % head_num != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by head_num ({head_num})"
            )
        self.hidden_dim = hidden_dim
        self.head_num = head_num
        self.head_dim = hidden_dim // head_num  # exact division, checked above

        # Multi-head attention projections.
        self.q = nn.Linear(hidden_dim, hidden_dim)
        self.k = nn.Linear(hidden_dim, hidden_dim)
        self.v = nn.Linear(hidden_dim, hidden_dim)
        self.o = nn.Linear(hidden_dim, hidden_dim)
        self.dropout_att = nn.Dropout(attention_dropout_rate)
        self.att_ln = nn.LayerNorm(hidden_dim, eps=1e-7)

        # Feed-forward network with a 4x expansion
        # (a SwiGLU-style FFN would typically use 8/3 instead).
        self.up_proj = nn.Linear(hidden_dim, hidden_dim * 4)
        self.down_proj = nn.Linear(hidden_dim * 4, hidden_dim)
        self.act_fn = nn.GELU()
        # Registered for completeness; ffn() does not apply it.
        self.dropout_ffn = nn.Dropout(1e-1)
        self.ffn_ln = nn.LayerNorm(hidden_dim, eps=1e-7)

    def attention_layer(self, query, key, value, attention_mask=None):
        """Scaled dot-product attention with a causal mask, merged back to (b, s, h).

        Args:
            query, key, value: Tensors of shape (b, head_num, s, head_dim).
            attention_mask: Optional tensor of shape (b, head_num, s, s); entries
                equal to 0 are masked out. Its lower triangle is taken, so
                causality is always enforced. ``None`` means "causal only".

        Returns:
            Tensor of shape (b, s, h) after the output projection ``self.o``.

        Note:
            A query row whose keys are all masked gets a softmax over only
            ``-inf`` values, which yields NaN for that row.
        """
        # (b, head_num, s, head_dim) @ (b, head_num, head_dim, s) -> (b, head_num, s, s)
        scores = query @ key.transpose(2, 3) / math.sqrt(self.head_dim)

        # No mask given: start from all-ones so only the causal constraint applies.
        if attention_mask is None:
            attention_mask = torch.ones_like(scores)
        causal_mask = attention_mask.tril()
        scores = scores.masked_fill(causal_mask == 0, float("-inf"))

        weights = torch.softmax(scores, dim=-1)
        weights = self.dropout_att(weights)
        context = weights @ value  # (b, head_num, s, head_dim)

        # Concatenate heads: (b, head_num, s, head_dim) -> (b, s, h)
        context = context.transpose(1, 2).contiguous()
        b, s, _, _ = context.size()
        context = context.view(b, s, -1)

        return self.o(context)

    def mha(self, x, mask=None):
        """Multi-head self-attention sub-layer with residual and post-LayerNorm.

        Args:
            x: Input of shape (b, s, h).
            mask: Optional attention mask forwarded to ``attention_layer``.

        Returns:
            ``att_ln(x + attention(x))`` with shape (b, s, h).
        """
        batch, seq_len, _ = x.size()

        Q = self.q(x)
        K = self.k(x)
        V = self.v(x)  # (b, s, h)

        # Split heads: (b, s, h) -> (b, head_num, s, head_dim)
        q_state = Q.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        k_state = K.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)
        v_state = V.view(batch, seq_len, self.head_num, self.head_dim).transpose(1, 2)

        output = self.attention_layer(q_state, k_state, v_state, mask)

        # Residual connection + post layer norm, mirroring ffn().
        return self.att_ln(x + output)

    def ffn(self, x):
        """Feed-forward sub-layer: ``ffn_ln(x + down_proj(GELU(up_proj(x))))``."""
        hidden = self.act_fn(self.up_proj(x))
        down = self.down_proj(hidden)
        # Residual connection + post layer norm.
        return self.ffn_ln(x + down)

    def forward(self, x, attention_mask=None):
        """Apply the attention sub-layer, then the FFN sub-layer.

        Args:
            x: Input of shape (b, s, h).
            attention_mask: Optional mask, see ``attention_layer``.

        Returns:
            Tensor of shape (b, s, h).
        """
        x = self.mha(x, attention_mask)
        x = self.ffn(x)
        return x
    
class Decoder(nn.Module):
    """Toy causal language model: embedding -> stacked decoder layers -> vocab softmax.

    Args:
        vocab_size: Number of token ids (embedding rows and output classes).
        hidden_dim: Model width passed to every ``SimpleDecoderLayer``.
        head_num: Number of attention heads per layer.
        num_layers: Number of stacked ``SimpleDecoderLayer`` modules.
        verbose: When True (default), ``forward`` prints intermediate shapes.
    """

    def __init__(
        self,
        vocab_size: int = 12,
        hidden_dim: int = 64,
        head_num: int = 8,
        num_layers: int = 5,
        verbose: bool = True,
    ):
        super().__init__()
        self.verbose = verbose
        self.layer_list = nn.ModuleList([
            SimpleDecoderLayer(hidden_dim, head_num) for _ in range(num_layers)
        ])

        self.emb = nn.Embedding(vocab_size, hidden_dim)
        self.out = nn.Linear(hidden_dim, vocab_size)

    def forward(self, x, mask = None):
        """Return per-position probabilities over the vocabulary.

        Args:
            x: Integer token ids of shape (batch, seq).
            mask: Optional attention mask handed unchanged to every layer.

        Returns:
            Tensor of shape (batch, seq, vocab_size); the last dimension is a
            softmax distribution.
        """
        if self.verbose:
            print("before decoder", x.shape)
        x = self.emb(x)
        if self.verbose:
            print("after decoder", x.shape)
        for i, l in enumerate(self.layer_list):
            x = l(x, mask)
            if self.verbose:
                print(f"{i}th x shape", x.shape)

        if self.verbose:
            print(x.shape)
        output = self.out(x)
        return torch.softmax(output, dim = -1)
        
# Toy batch of token ids with shape (3, 4), values drawn from [0, 12).
x = torch.randint(0, 12, (3, 4))
print("x shape:", x.shape)
net = Decoder()
# Per-sequence mask (0 entries are masked out in attention):
# (3, 4) -> insert two singleton axes -> (3, 1, 1, 4) -> repeat -> (3, 8, 4, 4)
mask = torch.tensor(
    [
        [1, 1, 1, 1],
        [1, 1, 0, 0],
        [1, 0, 0, 0],
    ]
)[:, None, None, :].repeat(1, 8, 4, 1)
net(x, mask)


x shape: torch.Size([3, 4])
before decoder torch.Size([3, 4])
after decoder torch.Size([3, 4, 64])
0th x shape torch.Size([3, 4, 64])
1th x shape torch.Size([3, 4, 64])
2th x shape torch.Size([3, 4, 64])
3th x shape torch.Size([3, 4, 64])
4th x shape torch.Size([3, 4, 64])
torch.Size([3, 4, 64])


tensor([[[0.0317, 0.1421, 0.0564, 0.1285, 0.0365, 0.1562, 0.0980, 0.0621,
          0.0408, 0.0440, 0.1040, 0.0997],
         [0.0479, 0.0925, 0.0254, 0.1728, 0.0583, 0.1789, 0.1241, 0.0221,
          0.0248, 0.0808, 0.0968, 0.0756],
         [0.0431, 0.0758, 0.0206, 0.1860, 0.0403, 0.1420, 0.1442, 0.0248,
          0.0288, 0.0702, 0.1299, 0.0944],
         [0.0370, 0.0825, 0.0203, 0.2120, 0.0398, 0.1552, 0.1327, 0.0206,
          0.0260, 0.0679, 0.1268, 0.0792]],

        [[0.0929, 0.0690, 0.1036, 0.0735, 0.0419, 0.1938, 0.0473, 0.0289,
          0.1132, 0.0743, 0.0304, 0.1311],
         [0.0756, 0.0650, 0.1395, 0.0889, 0.0414, 0.1910, 0.0485, 0.0263,
          0.1188, 0.0743, 0.0263, 0.1044],
         [0.0643, 0.0724, 0.1088, 0.0915, 0.0437, 0.2451, 0.0373, 0.0203,
          0.1137, 0.0747, 0.0276, 0.1006],
         [0.0730, 0.0809, 0.1183, 0.0893, 0.0377, 0.2379, 0.0497, 0.0213,
          0.0756, 0.0707, 0.0351, 0.1104]],

        [[0.0669, 0.1487, 0.0986, 0.1311, 0.0246, 0.1026, 0.