In [1]:
import torch

def create_causal_mask(seq_len, device=None):
    """Return a ``(seq_len, seq_len)`` bool mask that is True strictly above the diagonal.

    Entry ``[i, j]`` is True exactly when ``j > i``, i.e. when query position ``i``
    would look at a later position ``j``. In this notebook it is passed to the
    decoder as ``tgt_mask``; whether True means "blocked" is decided by the
    attention implementation (presumably yes — TODO confirm in transformer_layers).

    Args:
        seq_len: number of positions along each side of the square mask.
        device: optional device to move the mask to; when None the mask stays
            where ``torch.ones`` allocated it.
    """
    # tril() keeps the diagonal and everything below it; inverting that leaves
    # only the strictly-upper (future) positions set.
    allowed = torch.ones(seq_len, seq_len).tril().bool()
    future = ~allowed
    return future if device is None else future.to(device)

def create_padding_mask(seq_lengths, max_len, device=None):
    """Build a key-padding mask that is True at positions past each sequence's length.

    Args:
        seq_lengths: sequence of non-negative ints (or a 1-D integer tensor),
            one valid length per batch element.
        max_len: padded length; size of the mask's second dimension.
        device: optional device for the returned mask. When None, torch's
            default device is used, matching the previous ``torch.zeros`` allocation.

    Returns:
        Bool tensor of shape ``(len(seq_lengths), max_len)``. ``mask[i, j]`` is True
        iff ``j >= seq_lengths[i]``. A length ``>= max_len`` gives an all-False row.

    Raises:
        ValueError: if any length is negative. A negative length would otherwise be
            interpreted through negative slicing and mask the wrong (trailing) positions.
    """
    positions = torch.arange(max_len, device=device)
    # Put the lengths on the same device as `positions` so the comparison below is valid.
    lengths = torch.as_tensor(seq_lengths, dtype=torch.long, device=positions.device)
    if (lengths < 0).any():
        raise ValueError(f"seq_lengths must be non-negative, got {lengths.tolist()}")
    # Broadcast (1, max_len) against (batch, 1): one comparison replaces the per-row loop.
    return positions.unsqueeze(0) >= lengths.unsqueeze(1)

In [2]:
from transformer_layers import MultiHeadAttention, FeedForward
import torch.nn as nn


class Decoder(nn.Module):
    """Stack of post-norm decoder layers: self-attention, optional cross-attention, feed-forward.

    Every sub-layer is wrapped as ``norm(x + dropout(sublayer(x)))``. An optional
    ``norm`` module is applied once to the output of the final layer.

    NOTE(review): ``MultiHeadAttention`` and ``FeedForward`` come from the project's
    ``transformer_layers`` module. This class assumes the attention is called as
    ``(query, key, value, attn_mask, key_padding_mask)`` and returns a tensor that
    can be added to ``query``. TODO confirm against that module.
    """

    def __init__(self, d_model, nhead, d_ff, num_layers, dropout=0.1, norm=None):
        super().__init__()
        self.layers = nn.ModuleList(
            [self._build_layer(d_model, nhead, d_ff, dropout) for _ in range(num_layers)]
        )
        self.num_layers = num_layers
        self.norm = norm  # optional final normalization applied after the last layer

    @staticmethod
    def _build_layer(d_model, nhead, d_ff, dropout):
        """Create one decoder layer's sub-modules (same keys and construction order as before)."""
        return nn.ModuleDict({
            'self_attn': MultiHeadAttention(d_model, nhead, dropout),
            'cross_attn': MultiHeadAttention(d_model, nhead, dropout),
            'feed_forward': FeedForward(d_model, d_ff, dropout),
            'norm1': nn.LayerNorm(d_model),
            'norm2': nn.LayerNorm(d_model),
            'norm3': nn.LayerNorm(d_model),
            'dropout': nn.Dropout(dropout),
        })

    @staticmethod
    def _add_and_norm(layer, norm_key, residual, sublayer_out):
        """Residual connection: ``layer[norm_key](residual + dropout(sublayer_out))``."""
        return layer[norm_key](residual + layer['dropout'](sublayer_out))

    def forward(self, tgt, memory, tgt_mask=None, memory_mask=None,
                tgt_key_padding_mask=None, memory_key_padding_mask=None):
        """Run ``tgt`` through every layer.

        Cross-attention over ``memory`` runs only when ``memory`` is not None. All
        masks are forwarded unchanged to the attention modules.
        """
        hidden = tgt
        for layer in self.layers:
            # Self-attention (with the target masks) + residual + LayerNorm.
            hidden = self._add_and_norm(
                layer, 'norm1', hidden,
                layer['self_attn'](hidden, hidden, hidden, tgt_mask, tgt_key_padding_mask),
            )
            # Cross-attention over the encoder memory, skipped when no memory is given.
            if memory is not None:
                hidden = self._add_and_norm(
                    layer, 'norm2', hidden,
                    layer['cross_attn'](hidden, memory, memory,
                                        memory_mask, memory_key_padding_mask),
                )
            # Position-wise feed-forward + residual + LayerNorm.
            hidden = self._add_and_norm(layer, 'norm3', hidden, layer['feed_forward'](hidden))

        return hidden if self.norm is None else self.norm(hidden)


# Smoke test: build a decoder stack and run one forward pass on random inputs.
d_model = 512
nhead = 8
d_ff = 2048
num_layers = 6
seq_len = 10
batch_size = 2
SEED = 0

# Seed before building the model so weight init and the random inputs are
# reproducible under Restart & Run All.
torch.manual_seed(SEED)

decoder = Decoder(d_model, nhead, d_ff, num_layers)
tgt = torch.randn(batch_size, seq_len, d_model)
memory = torch.randn(batch_size, seq_len, d_model)

# Masks. The causal mask is repeated per batch element -> (batch_size, seq_len, seq_len).
# NOTE(review): whether MultiHeadAttention needs the batch dim or would broadcast a
# 2-D mask is not visible here — TODO confirm in transformer_layers.
tgt_mask = create_causal_mask(seq_len).unsqueeze(0).expand(batch_size, -1, -1)
# One valid length per batch element (the lists must have batch_size entries).
tgt_key_padding_mask = create_padding_mask([10, 8], seq_len)
memory_key_padding_mask = create_padding_mask([10, 7], seq_len)

output = decoder(
    tgt,
    memory,
    tgt_mask=tgt_mask,
    memory_mask=None,
    tgt_key_padding_mask=tgt_key_padding_mask,
    memory_key_padding_mask=memory_key_padding_mask
)
# Fail loudly instead of relying on a reader to compare the printed shape by eye.
assert output.shape == (batch_size, seq_len, d_model), output.shape
print(output.shape)  # expected: (batch_size, seq_len, d_model)


torch.Size([2, 10, 512])
