### 1. Embedding

In [8]:
import torch 
import torch.nn as nn  
import torch.nn.functional as F
import math 

In [9]:
class Embedding(nn.Module):
    """Token embedding lookup multiplied by sqrt(d_model).

    Input:  integer token ids, [batch_size, seq_len]
    Output: embedding vectors, [batch_size, seq_len, d_model]

    Args:
        vocab_size: number of rows in the embedding table.
        d_model: embedding width; also determines the sqrt(d_model) scale.
    """

    def __init__(self,vocab_size,d_model):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size,d_model)

    def forward(self,x):
        # sqrt(d_model) scaling of the looked-up vectors, as in the original Transformer.
        scale = math.sqrt(self.d_model)
        vectors = self.embedding(x)
        return vectors * scale

### 2. PositionalEncoding

In [10]:
class PositionalEncoding(nn.Module):
    """Adds fixed sinusoidal position encodings to the input, then applies dropout.

    Even feature columns get sin(pos * w_i), odd columns get cos(pos * w_i), with
    w_i = 10000 ** (-2i / d_model). Works for both even and odd ``d_model``.

    Input:  [batch_size, seq_len, d_model], with seq_len <= max_len
    Output: [batch_size, seq_len, d_model]

    Args:
        d_model: feature width of the input.
        max_len: number of positions precomputed; longer inputs cannot be encoded.
        dropout: dropout probability applied after adding the encoding.
    """
    def __init__(self,d_model,max_len = 5000,dropout = 0.1):
        super(PositionalEncoding,self).__init__()

        # pe:       [max_len, d_model] -> [1, max_len, d_model] for broadcasting over the batch
        # pos:      [max_len, 1]
        # div_term: [ceil(d_model / 2)], one frequency per even column

        self.d_model = d_model
        self.dropout = nn.Dropout(dropout)

        pe = torch.zeros(max_len,d_model)

        pos = torch.arange(0,max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0,d_model,2) * (-math.log(10000) / d_model)
        )

        pe[:,0::2] = torch.sin(pos * div_term)
        # For odd d_model there is one fewer odd column than even columns, so the
        # last frequency is dropped for the cosine half (a no-op for even d_model).
        pe[:,1::2] = torch.cos(pos * div_term[: d_model // 2])

        pe = pe.unsqueeze(0)
        # Buffer, not Parameter: not trained, but moves with .to() and is saved in state_dict.
        self.register_buffer("pe",pe)

    def forward(self,x):
        # Slice the first seq_len positions; broadcasting adds them to every batch row.
        x = x + self.pe[:,:x.size(1),:]
        x = self.dropout(x)
        return x

### 3. MultiHeadAttention

In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The projected queries/keys/values are split into ``heads`` heads of width
    ``d_k = d_model // heads``, attended independently, concatenated and projected by ``wo``.

    Args:
        d_model: model width; must be divisible by ``heads``.
        heads: number of attention heads.
        dropout: dropout probability applied to the attention weights.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``heads``.
    """
    def __init__(self,d_model,heads,dropout = 0.1):
        super(MultiHeadAttention,self).__init__()

        # Fail early with a clear message; otherwise the .view() calls in forward
        # would raise an opaque shape error on the first call.
        if d_model % heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by heads ({heads})")

        self.d_model = d_model
        self.heads = heads
        self.d_k = d_model // heads
        self.dropout = nn.Dropout(dropout)

        self.wq = nn.Linear(d_model,d_model)
        self.wk = nn.Linear(d_model,d_model)
        self.wv = nn.Linear(d_model,d_model)
        self.wo = nn.Linear(d_model,d_model)

    def forward(self,query,key,value,mask = None):
        """Attend from ``query`` to ``key``/``value``.

        Args:
            query: [N, q_len, d_model]
            key:   [N, k_len, d_model]
            value: [N, v_len, d_model]; v_len must equal k_len for weights @ V.
            mask:  optional tensor broadcastable to [N, heads, q_len, k_len]; positions
                   where ``mask == 0`` are excluded (e.g. a [q_len, k_len] causal mask
                   or a [N, 1, 1, k_len] padding mask).

        Returns:
            (output, attention_weights): output is [N, q_len, d_model]; the weights are
            [N, heads, q_len, k_len] and are returned *after* dropout.
        """
        N = query.size(0)
        q_len,k_len,v_len = query.size(1), key.size(1), value.size(1)

        # Project, split into heads, and move heads before the sequence axis:
        # [N, len, d_model] -> [N, len, heads, d_k] -> [N, heads, len, d_k]
        Q = self.wq(query).view(N,q_len,self.heads,self.d_k).transpose(1,2)
        K = self.wk(key).view(N,k_len,self.heads,self.d_k).transpose(1,2)
        V = self.wv(value).view(N,v_len,self.heads,self.d_k).transpose(1,2)

        # QK^T / sqrt(d_k): [N, heads, q_len, k_len]
        scores = torch.matmul(Q,K.transpose(-2,-1)) / math.sqrt(self.d_k)

        if mask is not None:
            # Use the most negative finite value of the scores' dtype: a literal -1e9
            # overflows float16 and makes masked_fill raise under mixed precision.
            scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)

        attention_weights = F.softmax(scores,dim = -1)

        attention_weights = self.dropout(attention_weights)

        # [N, heads, q_len, d_k]
        context = torch.matmul(attention_weights,V)

        # Merge heads back: [N, q_len, heads, d_k] -> [N, q_len, d_model]
        context = context.transpose(1,2).contiguous().view(N,q_len,self.d_model)

        output = self.wo(context)

        return output, attention_weights
        




### 4. PositionWiseFeedForward

In [12]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Computes Linear(d_ff -> d_model)(Dropout(ReLU(Linear(d_model -> d_ff)(x)))).

    Input:  [N, seq_len, d_model]
    Output: [N, seq_len, d_model]
    """
    def __init__(self,d_model,d_ff,dropout=0.1):
        super().__init__()

        # Dropout sits inside the Sequential, so no separate self.dropout attribute is needed.
        layers = [
            nn.Linear(d_model,d_ff),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_ff,d_model),
        ]
        self.ffn = nn.Sequential(*layers)

    def forward(self,x):
        return self.ffn(x)

### 5. AddNorm

In [13]:
class AddNorm(nn.Module):
    """Post-norm residual connection: LayerNorm(x + Dropout(sublayer)).

    Note that ``sublayer`` is the sublayer's *output tensor*, not a callable.
    Both ``x`` and ``sublayer`` are [N, seq_len, d_model].
    """
    def __init__(self,d_model,dropout = 0.1):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self,x,sublayer):
        residual_sum = x + self.dropout(sublayer)
        return self.norm(residual_sum)

        

### 6. EncoderLayer

In [14]:
class EncoderLayer(nn.Module):
    """One post-norm encoder layer: self-attention then feed-forward, each wrapped in AddNorm.

    Input/Output: [N, src_len, d_model]
    """
    def __init__(self,d_model,d_ff,heads,dropout = 0.1):
        super().__init__()

        # Submodules are created in this order; changing it would change seeded parameter init.
        self.self_attn = MultiHeadAttention(d_model,heads,dropout)
        self.addnorm1 = AddNorm(d_model,dropout)
        self.ffn = PositionWiseFeedForward(d_model,d_ff,dropout)
        self.addnorm2 = AddNorm(d_model,dropout)

    def forward(self,x,mask):
        # The attention weights (second element of the returned tuple) are not needed here.
        attended = self.self_attn(x,x,x,mask)[0]
        x = self.addnorm1(x,attended)
        return self.addnorm2(x,self.ffn(x))

### 7. Encoder

In [15]:
class Encoder(nn.Module):
    """Transformer encoder: token embedding + positional encoding, then a stack of EncoderLayers.

    forward(src, src_mask):
        src:      token ids, [N, src_len]
        src_mask: passed unchanged to every layer's self-attention (may be None)
        returns:  [N, src_len, d_model]
    """
    def __init__(self,vocab_size,d_model,heads,d_ff,num_layers,max_len = 5000,dropout = 0.1):
        super().__init__()

        self.embedding = Embedding(vocab_size,d_model)
        self.positionencoding = PositionalEncoding(d_model,max_len,dropout)

        stack = [EncoderLayer(d_model,d_ff,heads,dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)

    def forward(self,src,src_mask):
        hidden = self.positionencoding(self.embedding(src))
        for encoder_layer in self.layers:
            hidden = encoder_layer(hidden,src_mask)
        return hidden


### 8. DecoderLayer

In [16]:
class DecoderLayer(nn.Module):
    """One post-norm decoder layer: masked self-attention, encoder-decoder attention, feed-forward.

    Each sublayer is wrapped in AddNorm (residual + dropout + LayerNorm).

    Note the argument order of forward: (x, encoder_output, src_mask, tgt_mask) --
    src_mask comes before tgt_mask here, the opposite of Decoder.forward.
    """
    def __init__(self,d_model,d_ff,heads,dropout = 0.1):
        super().__init__()

        # Submodules are created in this order; changing it would change seeded parameter init.
        self.self_attn = MultiHeadAttention(d_model,heads,dropout)
        self.addnorm1 = AddNorm(d_model,dropout)
        self.enc_dec_attn = MultiHeadAttention(d_model,heads,dropout)
        self.addnorm2 = AddNorm(d_model,dropout)
        self.ffn = PositionWiseFeedForward(d_model,d_ff,dropout)
        self.addnorm3 = AddNorm(d_model,dropout)

    def forward(self,x,encoder_output,src_mask,tgt_mask):
        """
        Args:
            x: decoder input (embedded + positionally encoded target), [N, tgt_len, d_model]
            encoder_output: [N, src_len, d_model]
            src_mask: mask for the encoder-decoder attention
            tgt_mask: mask for the self-attention (e.g. a causal mask)
        """
        self_attended = self.self_attn(x,x,x,tgt_mask)[0]
        x = self.addnorm1(x,self_attended)

        # Queries come from the decoder; keys and values from the encoder output.
        cross_attended = self.enc_dec_attn(x,encoder_output,encoder_output,src_mask)[0]
        x = self.addnorm2(x,cross_attended)

        return self.addnorm3(x,self.ffn(x))





### 9. Decoder

In [27]:
class Decoder(nn.Module):
    """Transformer decoder: embedding + positional encoding, DecoderLayer stack, vocab projection.

    forward(tgt, encoder_output, tgt_mask=None, src_mask=None):
        tgt:            target token ids, [N, tgt_len]
        encoder_output: [N, src_len, d_model]
        returns:        unnormalized logits (no softmax), [N, tgt_len, vocab_size]
    """
    def __init__(self,vocab_size,d_model,heads,d_ff,num_layers,max_len = 5000,dropout = 0.1):
        super().__init__()

        # tgt [N, tgt_len] -> [N, tgt_len, d_model]
        self.embedding = Embedding(vocab_size,d_model)

        # Attribute name is misspelled, but renaming it would change state_dict keys.
        self.positonalencoding = PositionalEncoding(d_model,max_len,dropout)

        stack = [DecoderLayer(d_model,d_ff,heads,dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(stack)
        self.output_linear = nn.Linear(d_model,vocab_size)

    def forward(self,tgt,encoder_output,tgt_mask=None,src_mask=None):
        hidden = self.positonalencoding(self.embedding(tgt))

        for decoder_layer in self.layers:
            # DecoderLayer takes src_mask before tgt_mask.
            hidden = decoder_layer(hidden,encoder_output,src_mask,tgt_mask)

        return self.output_linear(hidden)


### 10. Transformer

In [30]:
class Transformer(nn.Module):
    """Encoder-decoder Transformer mapping source token ids to target-vocabulary logits.

    Args:
        src_vocab_size / tgt_vocab_size: vocabulary sizes of the encoder / decoder embeddings.
        d_model, heads, d_ff: model width, attention heads, feed-forward hidden width.
        num_encoder_layers / num_decoder_layers: depth of each stack.
        max_len: longest sequence the positional encodings cover.
        dropout: dropout probability used throughout.
    """
    def __init__(self,src_vocab_size,tgt_vocab_size,d_model=512,heads = 8,d_ff= 2048,
                 num_encoder_layers= 6,num_decoder_layers = 6,max_len = 5000,dropout = 0.1):
        super().__init__()

        self.encoder = Encoder(src_vocab_size,d_model,heads,d_ff,num_encoder_layers,max_len,dropout)
        self.decoder = Decoder(tgt_vocab_size,d_model,heads,d_ff,num_decoder_layers,max_len,dropout)

    def forward(self,src,tgt,src_mask = None,tgt_mask = None):
        """Encode ``src`` and decode ``tgt`` against it.

        Args:
            src: source token ids, [N, src_len]
            tgt: target token ids, [N, tgt_len]
            src_mask: optional mask over source positions, used by the encoder
                self-attention and the decoder's encoder-decoder attention.
            tgt_mask: optional mask for the decoder self-attention (e.g. causal).

        Returns:
            Logits of shape [N, tgt_len, tgt_vocab_size].
        """
        encoder_output = self.encoder(src,src_mask)
        # Forward the caller's masks (by keyword, since Decoder.forward takes tgt_mask
        # before src_mask). Hard-coding None here silently disables causal/padding masking.
        decoder_output = self.decoder(tgt,encoder_output,tgt_mask = tgt_mask,src_mask = src_mask)

        return decoder_output

In [31]:
# Smoke test: build a small Transformer, check the output shape of a forward pass,
# then check that a causal target mask passed to Transformer.forward takes effect.
if __name__ == "__main__":
    # Fix the global RNG so parameter init and the random token ids are reproducible.
    torch.manual_seed(42)
        
    def quick_test():
        """Run the smoke test; return True on success, False on any failure.

        Exceptions are caught so the cell still ends with a summary line, but the
        full traceback is printed instead of only the exception message.
        """
        import traceback

        src_vocab_size = 100
        tgt_vocab_size = 100
        d_model = 64
        num_encoder_layers = 6
        num_decoder_layers = 6
        heads = 4
        d_ff = 128
        max_len = 50
        
        print("创建模型...")
        model = Transformer(
            src_vocab_size=src_vocab_size,
            tgt_vocab_size=tgt_vocab_size,
            d_model=d_model,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            heads=heads,
            d_ff=d_ff,
            max_len=max_len
        )
        
        print("创建测试数据...")
        batch_size = 2
        src_len = 10
        tgt_len = 8
        
        # Ids start at 1, leaving 0 free to act as a padding id in the mask below.
        src = torch.randint(1, src_vocab_size, (batch_size, src_len))
        tgt = torch.randint(1, tgt_vocab_size, (batch_size, tgt_len))
        
        print(f"源序列形状: {src.shape}")   # [batch_size, src_len]
        print(f"目标序列形状: {tgt.shape}") # [batch_size, tgt_len]
        
        print("运行前向传播...")
        try:
            output = model(src, tgt)
            print(f"前向传播成功!")
            print(f"输出形状: {output.shape}")
            print(f"预期形状: [{batch_size}, {tgt_len}, {tgt_vocab_size}]")
            
            if output.shape != (batch_size, tgt_len, tgt_vocab_size):
                print(f"形状不匹配！期望: {(batch_size, tgt_len, tgt_vocab_size)}，实际: {output.shape}")
                return False
            print("形状匹配！基本功能测试通过！")

            # Causal-mask check. eval() disables dropout so the two runs are comparable.
            # With a lower-triangular tgt_mask, changing only the LAST target token
            # must leave the logits of every earlier position unchanged.
            model.eval()
            src_mask = (src != 0).unsqueeze(1).unsqueeze(2)       # [N, 1, 1, src_len]
            tgt_mask = torch.tril(torch.ones(tgt_len, tgt_len))    # [tgt_len, tgt_len]
            tgt_changed = tgt.clone()
            # Always differs from the original id and stays in [1, tgt_vocab_size).
            tgt_changed[:, -1] = tgt[:, -1] % (tgt_vocab_size - 1) + 1
            with torch.no_grad():
                out_orig = model(src, tgt, src_mask, tgt_mask)
                out_changed = model(src, tgt_changed, src_mask, tgt_mask)
            if not torch.allclose(out_orig[:, :-1], out_changed[:, :-1]):
                print("因果掩码未生效：修改最后一个目标词改变了之前位置的输出")
                return False
            print("因果掩码测试通过！")
            return True
                
        except Exception as e:
            print(f"前向传播失败: {e}")
            traceback.print_exc()
            return False
    
    success = quick_test()
    
    if success:
        print("Transformer代码基本功能正常! ")
    else:
        print("请检查代码实现")

创建模型...
创建测试数据...
源序列形状: torch.Size([2, 10])
目标序列形状: torch.Size([2, 8])
运行前向传播...
前向传播成功!
输出形状: torch.Size([2, 8, 100])
预期形状: [2, 8, 100]
形状匹配！基本功能测试通过！
Transformer代码基本功能正常! 
