In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
# Show which PyTorch build this notebook was executed with.
print(torch.__version__)
# The bare string below is the cell's last expression, so the notebook echoes it as the
# cell output. Its Chinese heading means "PyTorch source directory"; the path is where
# the package lived on the machine that produced the recorded output.
"""
pytorch源码目录
/usr/local/lib/python3.8/dist-packages/torch
"""

1.12.0+cu113


'\npytorch源码目录\n/usr/local/lib/python3.8/dist-packages/torch\n'

In [2]:
## Hyperparameters
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Model dimensions
embed_len = 512
num_heads = 8
stack_len = 6
drop_out = 0.1

# Data dimensions
batch_size = 8
input_vocab_size = 7000
output_vocab_size = 7000

print(device)

cpu


In [3]:
## embedding block
class InputEmbedding(nn.Module):
    """
    Token embedding plus learned position embedding, followed by dropout.

    Args:
        input_vocab_size: number of rows in the token embedding table.
        embed_len: size of every embedding vector.
        dropout: dropout probability applied to the summed embeddings.
        device: stored as ``self.device`` for callers that read it. ``forward`` does not
            use it; position indices are created on the device of the incoming tensor,
            so the module keeps working after being moved with ``.to(...)``.
        max_len: number of rows in the position table, i.e. the longest sequence
            ``forward`` can index. ``None`` (the default) keeps the previous behaviour
            of sizing the table with ``input_vocab_size``.
    """
    def __init__(self, input_vocab_size=7000, embed_len=512, dropout=0.1, device=device, max_len=None):
        super(InputEmbedding, self).__init__()
        self.input_vocab_size = input_vocab_size
        self.embed_len = embed_len
        self.dropout = dropout
        self.device = device
        # Sequence length and vocabulary size are unrelated quantities; keep the old
        # table size only as the default so existing instantiations are unchanged.
        self.max_len = input_vocab_size if max_len is None else max_len
        # Token lookup table
        self.word_embedding_layer = nn.Embedding(self.input_vocab_size, self.embed_len)
        # Learned position lookup table
        self.position_embedding_layer = nn.Embedding(self.max_len, self.embed_len)
        # Dropout applied to the sum of both embeddings
        self.dropout_layer = nn.Dropout(p=self.dropout)

    def forward(self, input):
        """
        Args:
            input: integer token ids of shape [batch_size, seq_len].

        Returns:
            Tensor of shape [batch_size, seq_len, embed_len].
        """
        word_embedding = self.word_embedding_layer(input)

        # Unpacking two values keeps the original requirement that `input` is 2-D.
        _, seq_len = input.shape
        # Build the indices where the input already lives instead of on a stored device
        # string, which goes stale as soon as the module is moved.
        positions_vector = torch.arange(seq_len, device=input.device)
        # [seq_len, embed_len]; broadcasting over the batch dimension replaces the
        # explicit per-batch expand and produces the same sum.
        positional_encoding = self.position_embedding_layer(positions_vector)

        return self.dropout_layer(word_embedding + positional_encoding)

In [4]:
# Smoke test for InputEmbedding: embed one random batch of token ids and print the shapes.
seq_len = 20
# Random ids in [0, input_vocab_size), shape [batch_size, seq_len].
# NOTE(review): the name `input` shadows the Python builtin for the rest of the session.
input = torch.randint(input_vocab_size, (batch_size, seq_len)).to(device)
print(input.shape)
embedding = InputEmbedding().to(device)
output = embedding(input)
# The recorded run printed [8, 20, 512] and float32 (512 is InputEmbedding's default embed_len).
print(output.shape, output.dtype)
# print(output)

torch.Size([8, 20])
torch.Size([8, 20, 512]) torch.float32


In [5]:
## Scaled dot-product Attention
class ScaledDotProduct(nn.Module):
    """
    Scaled dot-product attention: softmax(Q K^T / sqrt(d_head)) V, optionally causal.

    Args:
        embed_len: stored as ``self.embed_len`` and ``self.d_k`` for backward
            compatibility. The scaling factor is taken from the last dimension of the
            queries at call time, so this value no longer affects the result.
        mask: any value other than ``None`` enables causal masking, i.e. query
            position ``i`` may only attend to key positions ``0..i``.
    """
    def __init__(self, embed_len=512, mask=None):
        super(ScaledDotProduct, self).__init__()
        self.embed_len = embed_len
        self.mask = mask
        self.d_k = embed_len
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, queries, keys, values):
        """
        queries.shape: [batch_size, num_heads, q_len, head_length]
        keys.shape:    [batch_size, num_heads, k_len, head_length]
        values.shape:  [batch_size, k_len, num_heads, head_length]
                       (note the layout: it is transposed to head-major here)

        Returns a tensor of shape [batch_size, num_heads, q_len, head_length].
        """
        # Scale by the dimension the dot product is actually taken over (the per-head
        # size), not by the full model width.
        head_length = queries.size(-1)
        scores = torch.matmul(queries, torch.transpose(keys, 2, 3))  # [batch, heads, q_len, k_len]
        scores = scores / math.sqrt(head_length)

        if self.mask is not None:
            # Mask BEFORE the softmax so every row is still a probability distribution.
            # Zeroing entries after the softmax leaves rows that sum to less than one.
            q_len, k_len = scores.shape[-2], scores.shape[-1]
            query_pos = torch.arange(q_len, device=scores.device).unsqueeze(1)
            key_pos = torch.arange(k_len, device=scores.device).unsqueeze(0)
            # Key 0 is never masked, so no row becomes all -inf.
            scores = scores.masked_fill(key_pos > query_pos, float('-inf'))

        weights = self.softmax(scores)
        return torch.matmul(weights, torch.transpose(values, 1, 2))
        
    

In [6]:
## multihead attn block
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention: project Q/K/V, split into heads, attend, merge, project.

    Args:
        num_heads: number of attention heads.
        embed_len: model width; must be divisible by ``num_heads``.
        mask: any value other than ``None`` makes the attention causal.

    Raises:
        ValueError: if ``embed_len`` is not a multiple of ``num_heads``.
    """
    def __init__(self, num_heads=8, embed_len=512, mask=None):
        super(MultiHeadAttention, self).__init__()
        if embed_len % num_heads != 0:
            # Truncating the head size would only surface later as a reshape error.
            raise ValueError(
                "embed_len (%d) must be divisible by num_heads (%d)" % (embed_len, num_heads)
            )
        self.embed_len = embed_len
        self.num_heads = num_heads
        self.mask = mask
        self.head_length = self.embed_len // self.num_heads

        self.q_in = self.k_in = self.v_in = self.embed_len
        # Input projections
        self.q_linear = nn.Linear(self.q_in, self.q_in)
        self.k_linear = nn.Linear(self.k_in, self.k_in)
        self.v_linear = nn.Linear(self.v_in, self.v_in)

        # Tell the attention module the per-head size so its scaling factor is based
        # on the dimension the dot product is taken over.
        self.attn = ScaledDotProduct(
            embed_len=self.head_length,
            mask=True if self.mask is not None else None,
        )

        self.output_linear = nn.Linear(self.q_in, self.q_in)

    def forward(self, queries, keys, values):
        """
        queries.shape: [batch_size, q_len, embed_len]
        keys.shape:    [batch_size, kv_len, embed_len]
        values.shape:  [batch_size, kv_len, embed_len]

        ``q_len`` and ``kv_len`` may differ (cross-attention).
        Returns a tensor of shape [batch_size, q_len, embed_len].
        """
        batch_size, q_len = queries.shape[0], queries.shape[1]
        # Keys and values carry their own sequence length; reusing the query length
        # here breaks the reshape whenever the two sequences differ in length.
        k_len, v_len = keys.shape[1], values.shape[1]

        # Project, split the last dimension into heads, then move heads ahead of time:
        # [batch, len, embed] -> [batch, len, heads, head_length] -> [batch, heads, len, head_length]
        queries = self.q_linear(queries).reshape(batch_size, q_len, self.num_heads, self.head_length)
        queries = queries.transpose(1, 2)
        keys = self.k_linear(keys).reshape(batch_size, k_len, self.num_heads, self.head_length)
        keys = keys.transpose(1, 2)
        # Values stay as [batch, len, heads, head_length]: ScaledDotProduct.forward
        # performs the head/sequence transpose on values itself.
        values = self.v_linear(values).reshape(batch_size, v_len, self.num_heads, self.head_length)

        sdp_output = self.attn(queries, keys, values)  # [batch, heads, q_len, head_length]
        # Back to [batch, q_len, heads, head_length], then merge the heads.
        sdp_output = sdp_output.transpose(1, 2)
        sdp_output = sdp_output.reshape(batch_size, q_len, self.embed_len)
        return self.output_linear(sdp_output)
        
    

In [7]:
## encoder block
class EncoderBlock(nn.Module):
    """
    One attention + feed-forward block with post-norm residual connections.

    Args:
        embed_len: model width, forwarded to the attention layer and both LayerNorms.
        dropout: dropout probability applied to the output of each sub-layer.
        num_heads: number of attention heads (default 8, the value used previously).
    """
    def __init__(self, embed_len=512, dropout=0.1, num_heads=8):
        super(EncoderBlock, self).__init__()
        self.embed_len = embed_len
        self.dropout = dropout
        self.num_heads = num_heads
        # Forward the width explicitly; relying on the attention layer's default made
        # any embed_len other than 512 fail against the LayerNorms below.
        self.multihead_attn = MultiHeadAttention(num_heads=self.num_heads, embed_len=self.embed_len)
        self.firstnorm = nn.LayerNorm(self.embed_len)
        self.secondnorm = nn.LayerNorm(self.embed_len)
        self.dropout_layer = nn.Dropout(p=self.dropout)

        # Position-wise feed-forward network with a 4x wider hidden layer.
        self.feed_forward = nn.Sequential(
            nn.Linear(self.embed_len, self.embed_len * 4),
            nn.ReLU(),
            nn.Linear(self.embed_len * 4, self.embed_len),
        )

    def forward(self, queries, keys, values):
        """
        queries.shape: [batch_size, seq_len, embed_len]
        keys.shape:    [batch_size, seq_len, embed_len]
        values.shape:  [batch_size, seq_len, embed_len]

        The residual path follows ``queries``.
        Returns a tensor of shape [batch_size, seq_len, embed_len].
        """
        attn_output = self.multihead_attn(queries, keys, values)
        # Dropout on the attention sub-layer output as well, matching the treatment
        # of the feed-forward sub-layer below (no effect in eval mode).
        attn_output = self.dropout_layer(attn_output)
        # Add & Norm
        first_sublayer_output = self.firstnorm(attn_output + queries)
        # Feed-forward
        ff_output = self.feed_forward(first_sublayer_output)
        ff_output = self.dropout_layer(ff_output)
        # Add & Norm
        return self.secondnorm(ff_output + first_sublayer_output)
        

In [8]:
## decoder block
class DecoderBlock(nn.Module):
    """
    Decoder block: causal self-attention, then cross-attention plus feed-forward.

    The second and third sub-layers are provided by an ``EncoderBlock`` whose queries
    come from this block and whose keys/values come from the encoder.

    Args:
        embed_len: model width, forwarded to every sub-module.
        dropout: dropout probability, forwarded to every sub-module.
    """
    def __init__(self, embed_len=embed_len, dropout=drop_out):
        super(DecoderBlock, self).__init__()
        self.embed_len = embed_len
        self.dropout = dropout

        self.maskedMultihead_attn = MultiHeadAttention(embed_len=self.embed_len, mask=True)
        self.firstnorm = nn.LayerNorm(self.embed_len)
        self.dropout_layer = nn.Dropout(p=dropout)

        # Forward the configuration instead of relying on EncoderBlock's defaults.
        self.encoder_block = EncoderBlock(embed_len=self.embed_len, dropout=self.dropout)

    def forward(self, queries, keys, values):
        """
        Args:
            queries: decoder-side states, [batch_size, tgt_len, embed_len].
            keys, values: encoder output, [batch_size, src_len, embed_len].

        Returns:
            Tensor of shape [batch_size, tgt_len, embed_len].
        """
        # Masked SELF-attention: the target sequence attends to itself. Feeding the
        # encoder output here would apply the causal mask across source positions and
        # never let target tokens see earlier target tokens.
        masked_multihead_attn_output = self.maskedMultihead_attn(queries, queries, queries)
        masked_multihead_attn_output = self.dropout_layer(masked_multihead_attn_output)
        # Add & Norm: keep the residual path, as the other sub-layers do.
        first_sublayer_output = self.firstnorm(masked_multihead_attn_output + queries)

        # Cross-attention over the encoder output, then feed-forward.
        return self.encoder_block(first_sublayer_output, keys, values)

In [9]:
## Transformer
class Transformer(nn.Module):
    """
    Encoder-decoder Transformer returning a probability distribution per target position.

    Source and target ids go through the same ``InputEmbedding`` instance, so target
    ids must also be valid indices of that embedding's vocabulary.

    Args:
        embed_len: model width, forwarded to every sub-module.
        stack_len: number of encoder blocks and of decoder blocks.
        device: device every sub-module is moved to at construction time.
        output_vocab_size: size of the final projection.
    """
    def __init__(self, embed_len=embed_len, stack_len=stack_len, device=device, output_vocab_size=output_vocab_size):
        super(Transformer, self).__init__()
        self.embed_len = embed_len
        self.stack_len = stack_len
        self.device = device
        self.output_vocab_size = output_vocab_size

        # Forward embed_len/device rather than relying on each sub-module's defaults.
        self.embedding = InputEmbedding(embed_len=self.embed_len, device=self.device).to(self.device)
        # Move the stacks too, so all sub-modules start on the same device.
        self.enc_stack = nn.ModuleList(
            [EncoderBlock(embed_len=self.embed_len) for _ in range(stack_len)]
        ).to(self.device)
        self.dec_stack = nn.ModuleList(
            [DecoderBlock(embed_len=self.embed_len) for _ in range(stack_len)]
        ).to(self.device)

        self.final_linear = nn.Linear(self.embed_len, self.output_vocab_size).to(self.device)
        # Normalise over the vocabulary axis. Without an explicit dim, PyTorch picks
        # dim 0 for 3-D input (and warns), which normalises across the batch.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, test_input, test_output):
        """
        Args:
            test_input: source token ids, [batch_size, src_len].
            test_output: target token ids, [batch_size, tgt_len].

        Returns:
            Probabilities of shape [batch_size, tgt_len, output_vocab_size]; each
            vector along the last axis sums to one.
        """
        enc_output = self.embedding(test_input)
        for enc_layer in self.enc_stack:
            # Self-attention: queries, keys and values are all the running encoder state.
            enc_output = enc_layer(enc_output, enc_output, enc_output)

        dec_output = self.embedding(test_output)
        for dec_layer in self.dec_stack:
            # Every decoder block attends to the final encoder output.
            dec_output = dec_layer(dec_output, enc_output, enc_output)

        final_output = self.final_linear(dec_output)
        return self.softmax(final_output)
        

In [10]:
# End-to-end smoke test: random source and target id batches, each of shape [batch_size, 30].
input_tokens = torch.randint(input_vocab_size, (batch_size, 30)).to(device)
tgt_tokens = torch.randint(output_vocab_size, (batch_size, 30)).to(device)

# Build the model with its default arguments and run a single forward pass.
transformer = Transformer().to(device)
transformer_output = transformer(input_tokens, tgt_tokens)


  return self.softmax(final_output)


In [11]:
# One vector of length output_vocab_size per target position; the recorded run printed [8, 30, 7000].
print(transformer_output.size())

torch.Size([8, 30, 7000])
