# Causal LM的decoder：主要是MHA和FFN部分。

In [1]:
import math
import torch
import torch.nn as nn

In [18]:
class SimpleDecoderLayer(nn.Module):
    """One post-LayerNorm Transformer decoder layer: causal multi-head self-attention, then an FFN.

    Each sub-layer is wrapped as ``LayerNorm(x + sublayer(x))``.

    Args:
        hidden_size: Model width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        attention_dropout_rate: Dropout probability used on the attention
            weights and on the FFN output.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by ``num_heads``.
    """

    def __init__(self, hidden_size, num_heads, attention_dropout_rate=0.1):
        super().__init__()
        if hidden_size % num_heads != 0:
            # Otherwise the head split in MHA (view to num_heads * head_size) fails later.
            raise ValueError(
                f"hidden_size ({hidden_size}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_size = hidden_size // num_heads

        # MHA part
        self.q_proj = nn.Linear(hidden_size, hidden_size)
        self.k_proj = nn.Linear(hidden_size, hidden_size)
        self.v_proj = nn.Linear(hidden_size, hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

        self.attention_dropout = nn.Dropout(attention_dropout_rate)
        self.attention_layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)

        # FFN part: expand -> GELU -> project back -> dropout -> residual + LayerNorm.
        # The expansion factor here is 4x (SwiGLU-style FFNs use 8/3 instead).
        self.up_proj = nn.Linear(hidden_size, hidden_size * 4)
        self.down_proj = nn.Linear(hidden_size * 4, hidden_size)
        self.act_fn = nn.GELU()

        self.ffn_dropout = nn.Dropout(attention_dropout_rate)
        self.ffn_layer_norm = nn.LayerNorm(hidden_size, eps=1e-6)

    def attention_layer(self, q, k, v, attention_mask=None):
        """Causal scaled dot-product attention over split heads, followed by ``o_proj``.

        Args:
            q, k, v: Tensors of shape (batch, num_heads, seq_len, head_size),
                as produced by ``MHA``.
            attention_mask: Optional mask broadcastable to
                (batch, num_heads, q_len, k_len); entries equal to 0 are masked
                out. It is combined with the causal mask, so a query can never
                attend to a later position.

        Returns:
            Tensor of shape (batch, seq_len, hidden_size) after ``o_proj``.
        """
        q_len, k_len = q.size(-2), k.size(-2)
        scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(self.head_size)

        # Boolean "may attend" mask. The causal part keeps the lower triangle
        # (key index <= query index). The triangle must be applied to the *mask*,
        # not to the scores: zeroing a score with tril() would still give that
        # future position weight exp(0) in the softmax.
        allowed = torch.ones(q_len, k_len, dtype=torch.bool, device=scores.device).tril()
        if attention_mask is not None:
            allowed = allowed & (attention_mask != 0)

        scores = scores.masked_fill(~allowed, float("-inf"))
        attention_weight = torch.softmax(scores, dim=-1)
        # A fully masked row (e.g. a padded query whose visible keys are all
        # padding) is all -inf and softmaxes to NaN. Zero every masked entry so
        # such rows contribute nothing instead of propagating NaN.
        attention_weight = attention_weight.masked_fill(~allowed, 0.0)
        attention_weight = self.attention_dropout(attention_weight)

        attention_result = torch.matmul(attention_weight, v)
        # Merge heads: (batch, heads, seq, head_size) -> (batch, seq, hidden_size).
        batch_size, seq_len = attention_result.shape[0], attention_result.shape[2]
        attention_result = (
            attention_result.transpose(1, 2)
            .contiguous()
            .view(batch_size, seq_len, self.hidden_size)
        )
        return self.o_proj(attention_result)

    def MHA(self, x, attention_mask=None):
        """Multi-head self-attention sub-layer with residual connection and post-LayerNorm.

        Args:
            x: Input of shape (batch, seq_len, hidden_size).
            attention_mask: Optional mask forwarded to ``attention_layer``.

        Returns:
            ``LayerNorm(x + attention(x))``, same shape as ``x``.
        """
        batch_size, seq_len, _ = x.shape

        def split_heads(t):
            # (batch, seq, hidden) -> (batch, heads, seq, head_size)
            return t.view(batch_size, seq_len, self.num_heads, self.head_size).transpose(1, 2)

        q = split_heads(self.q_proj(x))
        k = split_heads(self.k_proj(x))
        v = split_heads(self.v_proj(x))

        output = self.attention_layer(q, k, v, attention_mask)

        # residual + post-layernorm
        return self.attention_layer_norm(x + output)

    def FFN(self, x):
        """Position-wise feed-forward sub-layer with residual connection and post-LayerNorm.

        Returns ``LayerNorm(x + dropout(down(GELU(up(x)))))``, same shape as ``x``.
        """
        up = self.act_fn(self.up_proj(x))
        down = self.ffn_dropout(self.down_proj(up))
        # residual + post-layernorm
        return self.ffn_layer_norm(x + down)

    def forward(self, x, attention_mask=None):
        """Apply the attention sub-layer, then the FFN sub-layer."""
        x = self.MHA(x, attention_mask)
        x = self.FFN(x)
        return x

In [25]:
class Decoder(nn.Module):
    """Minimal causal language model: token embedding -> decoder layers -> vocab projection.

    Args:
        vocab_size: Number of token ids; inputs must lie in ``[0, vocab_size)``.
        hidden_size: Model width.
        num_heads: Attention heads per layer.
        attention_dropout_rate: Dropout probability forwarded to every layer.
        num_layers: Number of stacked ``SimpleDecoderLayer`` modules
            (default 1, matching the original single-layer model).
    """

    def __init__(self, vocab_size, hidden_size, num_heads, attention_dropout_rate=0.1, num_layers=1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.layer_list = nn.ModuleList(
            [SimpleDecoderLayer(hidden_size, num_heads, attention_dropout_rate) for _ in range(num_layers)]
        )
        self.output_layer = nn.Linear(hidden_size, vocab_size)

    def forward(self, x, attention_mask=None):
        """Map token ids of shape (batch, seq_len) to logits of shape (batch, seq_len, vocab_size).

        The returned values are raw (unnormalized) logits; apply softmax /
        log_softmax or a cross-entropy loss downstream as needed.
        """
        hidden = self.embedding(x)
        for layer in self.layer_list:
            hidden = layer(hidden, attention_mask)
        return self.output_layer(hidden)

In [38]:
# Smoke test: run a tiny one-layer Decoder on random token ids.
vocab_size, hidden_size, num_heads = 12, 64, 8
batch_size, seq_len = 3, 4

# randint's `high` is exclusive, so ids lie in [0, vocab_size - 1] = [0, 11],
# which keeps them valid indices into the embedding table.
x = torch.randint(low=0, high=vocab_size, size=(batch_size, seq_len))
print(x)
net = Decoder(vocab_size, hidden_size, num_heads)

# Key-padding mask, one row per batch element: 1 = keep, 0 = mask.
# (batch, seq) -> (batch, 1, 1, seq) -> repeated to (batch, heads, seq, seq)
# so it lines up with the per-head attention score matrix.
mask = (
    torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0], [1, 1, 1, 0]])
    .unsqueeze(1)
    .unsqueeze(2)
    .repeat(1, num_heads, seq_len, 1)
)

# Expected: torch.Size([3, 4, 12]) -> (batch, seq_len, vocab_size)
net(x, mask).shape

tensor([[ 8,  3,  0,  9],
        [ 5,  5,  0,  3],
        [ 9,  8,  0, 10]])


torch.Size([3, 4, 12])

In [13]:
# Scratch check: tril(ones) - ones is 0 on/below the diagonal and -1 above it;
# zeroing the upper triangle in place therefore leaves an all-zero tensor.
attention_mask = torch.ones(5, 5, 5)
test = torch.tril(attention_mask).sub(torch.ones_like(attention_mask))
test.tril_(diagonal=0)

tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]])

In [11]:
# Display the mask (a bare last expression is rendered by the notebook).
# NOTE(review): the captured output below is a 2-D 5x5 tensor, while cell In [13]
# defines attention_mask as torch.ones(5, 5, 5); the output appears stale (this
# cell ran as In [11]) — re-run to refresh.
attention_mask

tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

In [17]:
# Read batch size (dim 0) and sequence length (dim 2) off the mask, mirroring
# how attention_layer reads them from its result tensor.
batch_size, seq_len = (attention_mask.size(dim) for dim in (0, 2))
print(batch_size, seq_len)

(5, 5)