**PyTorch Implementation of the Transformer from Scratch: Demo**

Note: Much of the code and many of the figures are borrowed from:
- https://github.com/BoXiaolei/MyTransformer_pytorch/blob/main/MyTransformer.ipynb
- https://nn.labml.ai/index.html 
- https://github.com/hyunwoongko/transformer 
- https://jalammar.github.io/illustrated-transformer/
- https://lilianweng.github.io/posts/2023-01-27-the-transformer-family-v2/ 

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import math 
import torch.utils.data as data
import numpy as np

<div>
<img src="images/transformer.png" width="800"/>
</div>

# 1. Positional Encoding

The positional encodings have the same dimension, d_model, as the token embeddings, so the two can be summed element-wise.

## Sinusoidal Positional Encoding
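For reference, this is the sinusoidal encoding from the original Transformer paper, which the cell below implements. Here $pos$ is the token position and $i$ indexes pairs of embedding dimensions:

$$
\begin{aligned}
PE_{(pos,\,2i)}   &= \sin\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) \\
PE_{(pos,\,2i+1)} &= \cos\!\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right)
\end{aligned}
$$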

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model = 512, dropout = 0.1, max_len = 5000):
        '''
        d_model: dimension of the word embedding, 512
        max_len: the maximum number of tokens in a sentence, 5000
        '''
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)
        # Generate a max_len * d_model matrix, that is, 5000 * 512
        # 5000 is the maximum number of tokens in a sentence, 
        # and 512 is the length of a token represented by a vector.
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) # pos: [max_len,1]
        # First compute the argument of sin/cos: pos is [5000,1] and the denominator 10000^(2i/d_model) is [256];
        # broadcasting the division gives [5000,256]
        div_term = pos / pow(10000.0,torch.arange(0, d_model, 2).float() / d_model)
        pe[:, 0::2] = torch.sin(div_term)
        pe[:, 1::2] = torch.cos(div_term)
        # A sentence needs to do pe once, and there will be multiple sentences in a batch,
        # so add one dimension to do broadcast when adding to a batch of input data
        pe = pe.unsqueeze(0) # [5000,512] -> [1,5000,512] 
        # register_buffer stores tensors that belong to the module state but are not trainable parameters:
        # they are saved in state_dict and moved by .to()/.cuda(), but never updated by the optimizer
        self.register_buffer('pe', pe)
        
        
    def forward(self, x):
        '''x: [batch_size, seq_len, d_model]'''
        # pe was precomputed for up to max_len (5000) positions; slice it to the actual seq_len of x
        x = x + self.pe[:, :x.size(1), :]  # add (not concatenate) the positional encoding to the embeddings
        return self.dropout(x) 
        # return: [batch_size, seq_len, d_model]
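As a sanity check, the sketch below rebuilds the table the same vectorized way as `PositionalEncoding.__init__` (with a small, arbitrary `max_len=50` and `d_model=16`) and compares every entry with the scalar formula. `sinusoidal_table` and `sinusoidal_entry` are illustrative helpers, not part of the sources above. Even dimensions hold sines and odd dimensions hold cosines, so at position 0 the row alternates between 0 and 1.

```python
import math
import torch

def sinusoidal_table(max_len: int, d_model: int) -> torch.Tensor:
    """Illustrative helper: builds the table the same vectorized way as PositionalEncoding."""
    pe = torch.zeros(max_len, d_model)
    pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)             # [max_len, 1]
    denom = torch.pow(10000.0, torch.arange(0, d_model, 2).float() / d_model)  # [d_model / 2]
    angle = pos / denom                                     # broadcast -> [max_len, d_model / 2]
    pe[:, 0::2] = torch.sin(angle)   # even dimensions 2i
    pe[:, 1::2] = torch.cos(angle)   # odd dimensions 2i + 1
    return pe

def sinusoidal_entry(pos: int, j: int, d_model: int) -> float:
    """Illustrative helper: scalar formula for a single entry PE(pos, j)."""
    i = j // 2
    angle = pos / (10000.0 ** (2 * i / d_model))
    return math.sin(angle) if j % 2 == 0 else math.cos(angle)

table = sinusoidal_table(max_len=50, d_model=16)
max_err = max(abs(table[p, j].item() - sinusoidal_entry(p, j, 16))
              for p in range(50) for j in range(16))
print(table.shape, max_err)
```

The small remaining error comes only from float32 versus float64 arithmetic.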
        

## Learned Positional Encoding

Learned positional encoding assigns each position a learned vector that encodes its absolute position (Gehring et al., 2017); furthermore, this encoding can be learned separately for each layer (Al-Rfou et al., 2018).
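For contrast with the fixed sinusoidal table, a learned encoding can be sketched as a trainable lookup table indexed by position. The sketch below assumes a single table shared by all layers (not the per-layer variant) and an arbitrary `max_len`; `LearnedPositionalEncoding` is a hypothetical name. Unlike the sinusoidal version, it has no vectors for positions at or beyond `max_len`.

```python
import torch
import torch.nn as nn

class LearnedPositionalEncoding(nn.Module):
    """Hypothetical sketch: one trainable d_model-dim vector per absolute position."""
    def __init__(self, d_model: int = 512, max_len: int = 512, dropout: float = 0.1):
        super().__init__()
        # Row p of this table is the learned vector for position p (a trainable parameter, not a buffer)
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, d_model], with seq_len <= max_len
        positions = torch.arange(x.size(1), device=x.device)   # [seq_len]
        # [seq_len, d_model] broadcasts over the batch dimension
        return self.dropout(x + self.pos_emb(positions))

enc = LearnedPositionalEncoding(d_model=16, max_len=32, dropout=0.0)
x = torch.zeros(2, 10, 16)
out = enc(x)
print(out.shape)
```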

# 2. Mask

## Padding Mask

In the Transformer architecture, a **padding mask** is used to handle **sequences of different lengths**. Here's why it's important:

- **Variable Sequence Lengths**: In many natural language processing tasks, input sequences (such as sentences) vary in length. However, batched tensor operations require every sequence in a batch to have the same length. To handle this, shorter sequences are padded with a special token (such as [PAD]) up to the length of the longest sequence in the batch.

- **Ignoring Padding Tokens during Training and Inference**: Padding tokens are placeholders, not actual data, so the model must not treat them as meaningful input. The padding mask marks these positions so that their attention scores are replaced by a very large negative value before the softmax, which drives their attention weights to (effectively) zero.

- **Attention Accuracy**: Transformers use attention to weigh the importance of different parts of the input sequence. Without a padding mask, attention would assign weight to padding tokens, diluting the contribution of the real tokens and producing less accurate outputs.

- **Preventing Data Leakage**: In some cases, especially in language modeling tasks, padding at the beginning or end of sequences might inadvertently reveal information about the sequence. Masking ensures that the model's predictions are based only on actual data, not on these artificial padding tokens.

**In summary**, the padding mask is a critical component for handling variable-length input sequences: it ensures that the attention mechanism focuses on the meaningful parts of the input, which improves the quality of the model's outputs.

- When is this calculated mask used?

Multiplying the query by the transposed key gives an attention score matrix of size (len_q, len_k), and the mask returned by `get_attn_pad_mask` (defined below) is applied to this matrix. The element in row i, column j is the attention score of the i-th word in the q sequence to the j-th word in the k sequence. Row i therefore holds the attention of that q word to every word in k, and column j holds the attention of every q word to the j-th word in k. Because a padding token in k carries no information, no word in q should attend to it, so every column that corresponds to a padded key position is set to True.

- Why is only the padding position of k masked, and not that of q? (That is, why are only the last few columns of the matrix returned by this function True, and not the last few rows as well?)

Logically, a padding token should neither attract attention nor pay attention to other tokens, since any attention it computes toward other words is meaningless. Here we cut a corner, for a reason: the attention of a padding token in q to the words in k is never used, because we never use a padding token to predict the next word. Moreover, however its vector representation is updated, it does not affect the computation for the other words in q, so we leave it alone. Padding in k is different: if it is not masked, it absorbs a meaningless share of attention from every word in q, which biases what the model learns.
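To make the column-only masking concrete, here is a small standalone sketch with random stand-in scores (not real QK^T values) and padding index 0, as assumed by `get_attn_pad_mask`. The padded key columns receive zero weight, while the row of the padded query still forms a valid distribution over the real keys; that row's output is simply never used.

```python
import torch

torch.manual_seed(0)
PAD = 0  # assumption: padding token index 0, as in get_attn_pad_mask

seq_q = torch.tensor([[5, 6, 7, 0]])      # query side: last position is padding
seq_k = torch.tensor([[5, 6, 7, 0, 0]])   # key side: last two positions are padding

# Key-padding mask: True where the KEY is padding, repeated for every query row
mask = seq_k.eq(PAD).unsqueeze(1).expand(-1, seq_q.size(1), -1)   # [1, len_q, len_k]

scores = torch.randn(1, seq_q.size(1), seq_k.size(1))   # stand-in for QK^T / sqrt(d_k)
scores = scores.masked_fill(mask, -1e9)                  # masked columns -> very negative
weights = torch.softmax(scores, dim=-1)                  # [1, len_q, len_k]

print(weights)
```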

In [3]:
# The sentences fed to the model vary in length, so they are padded to the length of the longest sentence with the placeholder 'P' (index 0). These placeholders are meaningless, so their positions are marked True. The function below returns a Boolean tensor indicating whether each key position is a placeholder.
# Return: tensor [batch_size, len_q, len_k], True means the key position is a placeholder

def get_attn_pad_mask(seq_q, seq_k):
    '''
    seq_q: [batch_size, len_q]
    seq_k: [batch_size, len_k]
    '''
    batch_size, len_q = seq_q.size()
    _,          len_k = seq_k.size()
    # seq_k.eq(0): True where the element of seq_k is the padding index 0, False otherwise
    pad_attn_mask = seq_k.eq(0).unsqueeze(1) # pad_attn_mask: [batch_size, 1, len_k]

    # Give each of the len_q queries its own copy of the key mask by expanding the second dimension len_q times.
    # expand() does not allocate new memory: it returns a view in which every row refers to the same underlying data,
    # so writing to any row would modify all of them. This is safe here because the mask is never modified.
    return pad_attn_mask.expand(batch_size, len_q, len_k) # return: [batch_size, len_q, len_k]
    # The result holds batch_size matrices of shape len_q x len_k with True/False entries.
    # Entry (i, j) is True when the attention of the i-th query word to the j-th key word is meaningless,
    # i.e., when the j-th key position is padding.

In [4]:
# test for get_attn_pad_mask
seq_q = torch.tensor([[1,2,3,4,0,0],[3,4,5,6,7,0],[2,3,4,0,0,0]])
seq_k = torch.tensor([[1,2,3,4,5],[1,2,0,0,0],[1,2,3,0,0]])
print(get_attn_pad_mask(seq_q, seq_k))

tensor([[[False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False],
         [False, False, False, False, False]],

        [[False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False,  True,  True,  True]],

        [[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]]])


## Subsequence Mask

This mask is used in the first sublayer of each decoder layer, "Masked Multi-Head Self-Attention". Its goal is to prevent the model from seeing future tokens.

Take the example below:

Assume one sentence with 5 tokens is the decoder input. The element in row i, column j denotes the attention of the i-th token to the j-th token.

The i-th token (row) can only see itself and the tokens before it; the tokens after it are filtered out. So 1 means filtered and 0 means kept.

<div>
<img src="images/subsequenceMask.png" width="500"/>
</div>
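The sketch below checks this property numerically with a hypothetical single-head attention that uses the raw inputs as Q, K and V (no learned projections, an assumption made only to keep it short). Changing the last two tokens leaves the outputs at positions 0 to 2 unchanged, which is what lets the decoder be trained on all positions in parallel without seeing the future. Using `-inf` is safe here because the main diagonal is never masked, so every row keeps at least one finite score.

```python
import torch

torch.manual_seed(0)
seq_len, d = 5, 8

def causal_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical single-head self-attention with a subsequence (causal) mask; x: [seq_len, d]."""
    # True above the main diagonal = "future" positions to filter out
    future = torch.triu(torch.ones(x.size(0), x.size(0), dtype=torch.bool), diagonal=1)
    scores = x @ x.T / d ** 0.5
    scores = scores.masked_fill(future, float('-inf'))   # row i only keeps columns <= i
    return torch.softmax(scores, dim=-1) @ x

x = torch.randn(seq_len, d)
x_changed = x.clone()
x_changed[3:] = torch.randn(2, d)     # rewrite the last two ("future") tokens

out, out_changed = causal_self_attention(x), causal_self_attention(x_changed)
print(torch.allclose(out[:3], out_changed[:3]))   # positions 0-2 cannot see positions 3-4
```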

In [5]:
# To prevent positions from attending to subsequent positions
def get_attn_subsequence_mask(seq):
    """
    seq: [batch_size, tgt_len]
    This is only used in decoder, so the length of seq is the length of target sentence.
    """
    batch_size, tgt_len = seq.shape
    attn_shape = [batch_size, tgt_len, tgt_len]
    # torch.triu (like np.triu) returns a copy of a matrix with the elements below the given diagonal zeroed.
    # diagonal=1 is the diagonal just above the main one, so the main diagonal is zeroed too:
    # each token can attend to itself (0 = keep) but not to later tokens (1 = filter).
    # Equivalent NumPy version:
    # subsequence_mask = np.triu(np.ones(attn_shape), k=1)
    # subsequence_mask = torch.from_numpy(subsequence_mask).byte()
    subsequence_mask = torch.triu(torch.ones(attn_shape), diagonal=1).byte() # .byte() is equivalent to .to(torch.uint8)
    # The mask only contains 0 and 1, so uint8 is used to save memory.
    return subsequence_mask  # return: [batch_size, tgt_len, tgt_len]

### test for get_attn_subsequence_mask
seq = torch.tensor([[1,2,3,0,0],[1,2,3,4,0]])
print(get_attn_subsequence_mask(seq))

tensor([[[0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1],
         [0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0]]], dtype=torch.uint8)


# 3. Scaled Dot Product Attention

<div>
<img src="images/ScaledDotProductAttention.png" width="500"/>
</div>
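In formula form, the cell below computes, for each head:

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V
$$

where masked entries of $QK^\top$ are replaced by a very large negative value (or $-\infty$) before the softmax.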

In [6]:

class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super(ScaledDotProductAttention, self).__init__()

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, n_heads, len_q, d_k]  
        K: [batch_size, n_heads, len_k, d_k]
        V: [batch_size, n_heads, len_v(=len_k), d_v] 
        Two types of attention:
        1) self-attention: Q, K and V all come from the same sequence
        2) cross-attention: K and V are the encoder's output, so K and V have the same shape
        attn_mask: [batch_size, n_heads, len_q, len_k]
        
        len_q need not equal len_k: len_q is the length of the query sentence (the decoder input,
        i.e., the sentence being predicted, in cross-attention), while len_k is the length of the
        key sentence (the encoder input, i.e., the source sentence).
        '''
        batch_size, n_heads, len_q, d_k = Q.shape 
        # 1) compute the attention scores QK^T / sqrt(d_k)
        scores = torch.matmul(Q, K.transpose(-1, -2)) / math.sqrt(d_k) # scores: [batch_size, n_heads, len_q, len_k]
        # 2) optional masking: padding mask (encoder and decoder) and subsequence mask (decoder self-attention)
        if attn_mask is not None:
            # Fill positions where the mask is True with a large negative value (-1e9) rather than -inf,
            # so the softmax drives their weights to ~0 without producing NaN if an entire row is masked.
            scores.masked_fill_(attn_mask, -1e9)

        # 3) softmax over the keys to get the attention weights
        attn = torch.softmax(scores, dim=-1)  # attn: [batch_size, n_heads, len_q, len_k]
        # 4) weight the values V by the attention weights
        context = torch.matmul(attn, V)  # context: [batch_size, n_heads, len_q, d_v]
        # The returned context [batch_size, n_heads, len_q, d_v] still represents batch_size sentences;
        # the 512-dimensional vector of each word has simply been split into 8 parts. Each of the 8 heads
        # handles 512/8 = 64 dimensions of the whole sentence, and the heads are concatenated back along
        # the feature dimension afterwards.
        return context
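The division by sqrt(d_k) keeps the scores in a range where the softmax does not saturate into a near one-hot distribution with tiny gradients. The sketch below assumes that query and key entries are independent with mean 0 and variance 1 (an idealization, following the original paper's argument) and shows numerically that the dot product of two such d_k-dimensional vectors has variance about d_k, while the scaled score has variance about 1.

```python
import torch

torch.manual_seed(0)
n_samples, d_k = 20000, 64

# Assumption: query and key components are independent with mean 0 and variance 1
q = torch.randn(n_samples, d_k)
k = torch.randn(n_samples, d_k)

raw = (q * k).sum(dim=-1)          # one dot product q.k per sample
scaled = raw / d_k ** 0.5          # the scaling used in ScaledDotProductAttention

print(f"var(q.k)           = {raw.var().item():.1f}   (theory: d_k = {d_k})")
print(f"var(q.k / sqrt(d_k)) = {scaled.var().item():.2f}  (theory: 1)")
```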

# 4. MultiHeadAttention

<div>
<img src="images/MultiHeadAttention_v2.png" width="600"/>
</div>

In [7]:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model = 512, n_heads = 8, dropout_rate = 0.0):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        self.dropout_rate = dropout_rate
        self.head_dim = d_model // n_heads
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.scaled_dot_product_attention = ScaledDotProductAttention()
        self.W_O = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, Q, K, V, attn_mask):
        '''
        Q: [batch_size, len_q, d_model]
        K: [batch_size, len_k, d_model]
        V: [batch_size, len_v, d_model] 
        attn_mask: [batch_size, len_q, len_k]
        '''
        batch_size, len_q, d_model = Q.shape
        batch_size, len_k, d_model = K.shape
        batch_size, len_v, d_model = V.shape
        
        # 1) linear projection
        Q = self.W_Q(Q)
        K = self.W_K(K)
        V = self.W_V(V)
        
        # 2) split by heads
        # [batch_size, len_q, d_model] -> [batch_size, len_q, n_heads, head_dim]
        Q = Q.reshape(batch_size, len_q, self.n_heads, self.head_dim)
        K = K.reshape(batch_size, len_k, self.n_heads, self.head_dim)
        V = V.reshape(batch_size, len_v, self.n_heads, self.head_dim)
        
        # 3) transpose for attention dot product
        # [batch_size, len_q, n_heads, head_dim] -> [batch_size, n_heads, len_q, head_dim]
        Q = Q.transpose(1, 2)
        K = K.transpose(1, 2)
        V = V.transpose(1, 2)
        
        # 4) attention
        # attn_mask: [batch_size, len_q, len_k] -> [batch_size, n_heads, len_q, len_k]
        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(1).repeat(1, self.n_heads, 1, 1)
            # memory-saving alternative: attn_mask.unsqueeze(1).expand(-1, self.n_heads, -1, -1)
        context = self.scaled_dot_product_attention(Q, K, V, attn_mask)
        # context: [batch_size, n_heads, len_q, head_dim]
        
        # 5) concat heads
        # method 1:
        output = context.transpose(1, 2).reshape(batch_size, len_q, self.d_model)
        # output: [batch_size, len_q, d_model]
        
        # method 2:
        # output = torch.cat([context[:,i,:,:] for i in range(self.n_heads)], dim=-1)
        # output: [batch_size, len_q, d_model]
        
        # 6) output projection W_O applied to the concatenated heads
        output = self.W_O(output)
        return output # output: [batch_size, len_q, d_model]
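For comparison, PyTorch ships a built-in `nn.MultiheadAttention`. For boolean masks it uses the same convention as this notebook: True marks a position that must not be attended to (note that `torch.nn.functional.scaled_dot_product_attention` uses the opposite convention for boolean masks). The sketch below uses it for cross-attention with arbitrary sizes (`d_model=16`, 4 heads): 4 decoder queries attend over 5 encoder positions, the last three of which are padding in the second sentence. Its weights are independent of the class above, so only shapes and masking behavior are comparable, not values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, n_heads = 16, 4
mha = nn.MultiheadAttention(embed_dim=d_model, num_heads=n_heads, batch_first=True)

dec_x = torch.randn(2, 4, d_model)    # queries: [batch, tgt_len, d_model]
enc_out = torch.randn(2, 5, d_model)  # keys/values (encoder output): [batch, src_len, d_model]

# Boolean key-padding mask, True = padding (same convention as get_attn_pad_mask)
src = torch.tensor([[1, 2, 3, 4, 5],
                    [1, 2, 0, 0, 0]])
key_padding_mask = src.eq(0)          # [batch, src_len]

out, weights = mha(dec_x, enc_out, enc_out, key_padding_mask=key_padding_mask)
print(out.shape)      # output length follows the queries (tgt_len)
print(weights.shape)  # weights averaged over heads: [batch, tgt_len, src_len]
```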

# 5. Feed-Forward Networks

In [8]:
class PositionwiseFeedForward(nn.Module):
    def __init__(self, 
                 d_model = 512, 
                 d_ff = 2048, 
                 dropout_rate = 0.0):
        super(PositionwiseFeedForward, self).__init__()
        self.d_model = d_model
        self.d_ff = d_ff # dimension of the hidden layer of the feed-forward network
        self.W_1 = nn.Linear(d_model, d_ff)
        self.W_2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout_rate)
        self.relu = nn.ReLU()
    def forward(self, x):
        '''
        x: [batch_size, seq_len, d_model]
        '''
        output = self.relu(self.W_1(x))
        output = self.W_2(output)
        
        return output
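The network computes $\mathrm{FFN}(x) = \max(0, xW_1 + b_1)W_2 + b_2$. It is "position-wise" because `nn.Linear` acts only on the last dimension, so every position is transformed independently by the same weights. The standalone sketch below (arbitrary small sizes, with `nn.Sequential` standing in for the class above) checks that applying the network to the whole sequence equals applying it to each position separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model, d_ff = 8, 32
# Same structure as PositionwiseFeedForward: Linear -> ReLU -> Linear
ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

x = torch.randn(2, 5, d_model)        # [batch_size, seq_len, d_model]
full = ffn(x)                         # applied to the whole tensor at once

# Apply the same network to each position separately and stack the results
per_position = torch.stack([ffn(x[:, t, :]) for t in range(x.size(1))], dim=1)
print(torch.allclose(full, per_position, atol=1e-6))
```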

# 6. LayerNorm

For each sample, normalize every token's vector across its features (the last dimension), then apply a learned scale (gamma) and shift (beta).

<div>
<img src="images/LayerNorm.png" width="700"/>
</div>

In [9]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        """Normalize x over its last (feature) dimension.

        Args:
            x (torch.Tensor): input of shape [..., d_model]

        Returns:
            torch.Tensor: normalized tensor with the same shape as x
        """
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        # '-1' means last dimension. 

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out
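A quick standalone check (arbitrary input shape) that the normalization in `LayerNorm.forward`, with its initial values gamma = 1 and beta = 0, matches PyTorch's `torch.nn.functional.layer_norm`. Both use the biased variance (`unbiased=False`). Note that the class above uses `eps=1e-12`, whereas `nn.LayerNorm` defaults to `eps=1e-5`; the check passes the same `eps` to both.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(2, 5, 8)              # [batch_size, seq_len, d_model]
eps = 1e-12                           # same eps as the LayerNorm class above

# Manual version: statistics over the last (feature) dimension only
mean = x.mean(-1, keepdim=True)
var = x.var(-1, unbiased=False, keepdim=True)   # biased variance
manual = (x - mean) / torch.sqrt(var + eps)     # gamma = 1, beta = 0 (initial values)

reference = F.layer_norm(x, normalized_shape=(8,), eps=eps)
print(torch.allclose(manual, reference, atol=1e-5))
```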


# 7. Encoder

## Encoder Layer

<div>
<img src="images/encoder.png" width="600"/>
</div>
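Each of the two sublayers is wrapped in a residual connection followed by layer normalization ("Add & Norm"), which `EncoderLayer.forward` below implements as:

$$
\mathrm{output} = \mathrm{LayerNorm}\big(x + \mathrm{Sublayer}(x)\big)
$$

where $\mathrm{Sublayer}$ is either multi-head self-attention or the position-wise feed-forward network. The original paper also applies dropout to $\mathrm{Sublayer}(x)$ before the addition; the code below defines `self.dropout` but does not apply it.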

In [10]:

class EncoderLayer(nn.Module):
    def __init__(self, 
                 d_model = 512, 
                 d_ff = 2048, 
                 n_heads = 8, 
                 dropout_rate = 0.0):
        super(EncoderLayer, self).__init__()
        self.enc_self_attn = MultiHeadAttention(d_model, n_heads, dropout_rate)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_ff, dropout_rate)
        self.layer_norm1 = LayerNorm(d_model)
        self.layer_norm2 = LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)
        
    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs:         [batch_size, src_len, d_model]
        enc_self_attn_mask: [batch_size, src_len, src_len]
        '''
        # Sublayer 1: self attention
        enc_outputs = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)
        # enc_outputs: [batch_size, src_len, d_model]
        
        # add & norm
        enc_outputs = self.layer_norm1(enc_inputs + enc_outputs)
        # enc_outputs: [batch_size, src_len, d_model]
        
        # Sublayer 2: position-wise feed forward network
        enc_ff_outputs = self.pos_ffn(enc_outputs)
        # enc_ff_outputs: [batch_size, src_len, d_model]
        
        # add & norm
        enc_outputs = self.layer_norm2(enc_outputs + enc_ff_outputs)
        # enc_outputs: [batch_size, src_len, d_model]
        
        return enc_outputs

## Encoder

<div>
<img src="images/EncoderDecoder_2.png" width="600"/>
</div>

In [11]:

class Encoder(nn.Module):
    def __init__(
        self,
        d_model = 512, 
        d_ff = 2048, 
        n_heads = 8, 
        n_layers = 6,
        dropout_rate = 0.0, 
        ):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([EncoderLayer(d_model, d_ff, n_heads, dropout_rate) for _ in range(n_layers)])
    
    def forward(self, enc_inputs, enc_self_attn_mask):
        '''
        enc_inputs: [batch_size, src_len, d_model]
        '''
        enc_outputs = enc_inputs
        # encoding
        for layer in self.layers:
            enc_outputs = layer(enc_outputs, enc_self_attn_mask)
        # enc_outputs: [batch_size, src_len, d_model]
        
        return enc_outputs

# 8. Decoder

## Decoder Layer

<div>
<img src="images/decoder.png" width="600"/>
</div>

In [12]:

class DecoderLayer(nn.Module):
    def __init__(self, 
                 d_model = 512, 
                 d_ff = 2048, 
                 n_heads = 8, 
                 dropout_rate = 0.1):
        super(DecoderLayer, self).__init__()
        self.dec_self_attn = MultiHeadAttention(d_model, n_heads, dropout_rate)
        self.dec_cros_attn = MultiHeadAttention(d_model, n_heads, dropout_rate)
        self.pos_ffn = PositionwiseFeedForward(d_model, d_ff, dropout_rate)
        self.layer_norm_1 = LayerNorm(d_model)
        self.layer_norm_2 = LayerNorm(d_model)
        self.layer_norm_3 = LayerNorm(d_model)
        
    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask):
        '''
        dec_inputs:         [batch_size, tgt_len, d_model]
        enc_outputs:        [batch_size, src_len, d_model]
        dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        dec_enc_attn_mask:  [batch_size, tgt_len, src_len]
        '''
        # Sublayer 1: self attention
        dec_outputs = self.dec_self_attn(dec_inputs, dec_inputs, dec_inputs, dec_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model]
        
        # add & norm
        dec_outputs = self.layer_norm_1(dec_inputs + dec_outputs)
        # dec_outputs: [batch_size, tgt_len, d_model]
        
        # Sublayer 2: encoder-decoder cross attention, 
        # Q (decoder), K (encoder output), V (encoder output)
        dec_outputs_2 = self.dec_cros_attn(dec_outputs, enc_outputs, enc_outputs, dec_enc_attn_mask)
        # dec_outputs_2: [batch_size, tgt_len, d_model]
        
        # add & norm
        dec_outputs_2 = self.layer_norm_2(dec_outputs + dec_outputs_2)
        # dec_outputs_2: [batch_size, tgt_len, d_model]
        
        # Sublayer 3: position-wise feed forward network
        dec_outputs_3 = self.pos_ffn(dec_outputs_2)
        # dec_outputs_3: [batch_size, tgt_len, d_model]
        
        # add & norm
        dec_outputs_3 = self.layer_norm_3(dec_outputs_2 + dec_outputs_3)
        # dec_outputs_3: [batch_size, tgt_len, d_model]
        
        return dec_outputs_3 

## Decoder

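- nn.Embedding
- PositionalEncoding
- 6 DecoderLayers

In this implementation, the embedding and positional encoding are applied in the `Transformer` class, so the `Decoder` below only stacks the decoder layers.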

<div>
<img src="images/EncoderDecoder_2.png" width="600"/>
</div>

In [13]:

class Decoder(nn.Module):
    def __init__(self, 
                 d_model = 512, 
                 d_ff = 2048, 
                 n_heads = 8, 
                 n_layers = 6,
                 dropout_rate = 0.1):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([DecoderLayer(d_model, d_ff, n_heads, dropout_rate) for _ in range(n_layers)])
    
    def forward(self, dec_inputs, enc_outputs, dec_self_attn_mask, dec_self_attn_subsequence_mask, memory_self_attn_mask):
        '''
        dec_inputs:         [batch_size, tgt_len, d_model] (already embedded, with positional encoding)
        enc_outputs:        [batch_size, src_len, d_model]
        sometimes, enc_outputs is also called memory in the paper
        '''
        # combine two masks in decoder self attention
        dec_self_attn_mask = torch.gt((dec_self_attn_mask + dec_self_attn_subsequence_mask), 0)
        # dec_self_attn_mask: [batch_size, tgt_len, tgt_len]
        
        output = dec_inputs
        # decoding
        for layer in self.layers:
            output = layer(output, enc_outputs, dec_self_attn_mask, memory_self_attn_mask)
        # dec_outputs: [batch_size, tgt_len, d_model]
        
        return output
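`Decoder.forward` merges the padding mask and the subsequence mask with `torch.gt(a + b, 0)`: a position is masked if either mask marks it. The standalone sketch below, using a hypothetical 5-token target with two padding tokens, shows that this is the same as a logical OR of the two masks, which is a more explicit way to write it.

```python
import torch

dec_tokens = torch.tensor([[6, 1, 2, 0, 0]])     # hypothetical target batch; 0 = padding
tgt_len = dec_tokens.size(1)

pad_mask = dec_tokens.eq(0).unsqueeze(1).expand(-1, tgt_len, -1)              # bool  [1, tgt_len, tgt_len]
subseq_mask = torch.triu(torch.ones(1, tgt_len, tgt_len), diagonal=1).byte()  # uint8 [1, tgt_len, tgt_len]

# What Decoder.forward does: add the masks, then anything > 0 is masked
combined_gt = torch.gt(pad_mask + subseq_mask, 0)
# Equivalent, more explicit formulation: logical OR of the two boolean masks
combined_or = pad_mask | subseq_mask.bool()

print(combined_gt.int())
```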

# 9. Transformer

- Encoder
- Decoder
- Dense (the `generator` linear layer that projects `d_model` to the target vocabulary size)

<div>
<img src="images/EncoderDecoder_1.png" width="600"/>
</div>

In [14]:

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens): # tokens: [batch_size, seq_len]
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


# The Transformer model
class Transformer(nn.Module):
    def __init__(self, 
                 src_vocab_size, 
                 tgt_vocab_size, 
                 d_model = 512, 
                 d_ff = 2048, 
                 n_heads = 8, 
                 n_layers = 6,
                 dropout_rate = 0.1):
        super(Transformer, self).__init__()
        self.src_tok_emb = TokenEmbedding(src_vocab_size, d_model) #[batch_size, src_len, d_model]
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, d_model) #[batch_size, tgt_len, d_model]
        self.pos_embedding = PositionalEncoding(d_model, dropout_rate)

        self.Encoder = Encoder(d_model, d_ff, n_heads, n_layers, dropout_rate)
        self.Decoder = Decoder(d_model, d_ff, n_heads, n_layers, dropout_rate)
        self.generator = nn.Linear(d_model, tgt_vocab_size)
        
    def create_masks(self, src, tgt):
        '''
        src: [batch_size, src_len]
        tgt: [batch_size, tgt_len]
        
        ''' 
        # padding mask for encoder self attention: enc_self_attn_pad_mask
        src_key_padding_mask = get_attn_pad_mask(src, src) #[batch_size, src_len, src_len]
        
        # padding mask for decoder self attention: dec_self_attn_pad_mask
        tgt_key_padding_mask = get_attn_pad_mask(tgt, tgt) # [batch_size, tgt_len, tgt_len]
        
        # encoder-decoder cross attention mask: cross_attn_mask
        memory_key_padding_mask = get_attn_pad_mask(tgt, src) #[batch_size, tgt_len, src_len]
        
        # sequence mask (only exists in decoder)
        tgt_mask = get_attn_subsequence_mask(tgt) # [batch_size, tgt_len, tgt_len]
        
        # move the subsequence mask to the same device as the inputs (e.g., the GPU)
        tgt_mask = tgt_mask.to(tgt.device)
        
        return  tgt_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask


    def forward(self, enc_inputs, dec_inputs):
        '''
        enc_inputs: [batch_size, src_len]
        dec_inputs: [batch_size, tgt_len]
        '''
        # embedding + position encoding
        src_emb = self.pos_embedding(self.src_tok_emb(enc_inputs))
        tgt_emb = self.pos_embedding(self.tgt_tok_emb(dec_inputs))
        
        # prepare masks
        tgt_mask, src_key_padding_mask, tgt_key_padding_mask, memory_key_padding_mask = self.create_masks(enc_inputs, dec_inputs)
        
        # encoding
        enc_outputs = self.Encoder(src_emb, src_key_padding_mask)
        # enc_outputs: [batch_size, src_len, d_model]
        
        # decoding
        dec_outputs = self.Decoder(tgt_emb, enc_outputs, tgt_key_padding_mask, tgt_mask, memory_key_padding_mask)
        # dec_outputs: [batch_size, tgt_len, d_model]
        
        # projection
        dec_logits = self.generator(dec_outputs)
        # dec_logits: [batch_size, tgt_len, tgt_vocab_size]
        
        return dec_logits.view(-1, dec_logits.size(-1))  #  [batch_size * tgt_len, tgt_vocab_size]
        # The batch holds batch_size sentences: the first tgt_len rows are the logits (unnormalized
        # scores) for the positions of the first sentence, the next tgt_len rows are those of the
        # second sentence, and so on.

# 10. Demo of Training

## Dataset Preparation 

In [15]:
# S: Start of sentence
# E: End of sentence
# P: padding, so that all sentences have the same length

sentence = [
    # enc_input   dec_input    dec_output
    ['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]
# source vocab
src_vocab = {'P':0, 'ich':1,'mochte':2,'ein':3,'bier':4,'cola':5}
src_vocab_size = len(src_vocab) # 6

# target vocab (including special symbols)
tgt_vocab = {'P':0,'i':1,'want':2,'a':3,'beer':4,'coke':5,'S':6,'E':7,'.':8}

# reverse mapping dictionary: idx -> word
idx2word = {v:k for k,v in tgt_vocab.items()}
tgt_vocab_size = len(tgt_vocab) # 9

src_len = 5
# the length of the longest sentence in the input sequence, 
# which is actually the number of tokens in the longest sentence
tgt_len = 6
# the length of the longest sentence in the dec_input/dec_output sequence

In [16]:
# This function converts the raw sentences into tensors of token indices
def make_data(sentence):
    enc_inputs, dec_inputs, dec_outputs = [],[],[]
    for i in range(len(sentence)):
        enc_input = [src_vocab[word] for word in sentence[i][0].split()]
        dec_input = [tgt_vocab[word] for word in sentence[i][1].split()]
        dec_output = [tgt_vocab[word] for word in sentence[i][2].split()]
        
        enc_inputs.append(enc_input)
        dec_inputs.append(dec_input)
        dec_outputs.append(dec_output)
        
    # torch.LongTensor stores 64-bit integers (token indices), while torch.Tensor can hold floats, integers, bools and other types
    return torch.LongTensor(enc_inputs),torch.LongTensor(dec_inputs),torch.LongTensor(dec_outputs)

enc_inputs, dec_inputs, dec_outputs = make_data(sentence)

print(' enc_inputs: \n', enc_inputs)  # enc_inputs: [2,5]
print(' dec_inputs: \n', dec_inputs)  # dec_inputs: [2,6]
print(' dec_outputs: \n', dec_outputs) # dec_outputs: [2,6]

 enc_inputs: 
 tensor([[1, 2, 3, 4, 0],
        [1, 2, 3, 5, 0]])
 dec_inputs: 
 tensor([[6, 1, 2, 3, 4, 8],
        [6, 1, 2, 3, 5, 8]])
 dec_outputs: 
 tensor([[1, 2, 3, 4, 8, 7],
        [1, 2, 3, 5, 8, 7]])
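The three columns of `sentence` follow teacher forcing: the decoder input is the target sentence shifted right by one position with `S` in front, and the label at each position is the next token, ending with `E`. The sketch below checks this relationship on the token ids printed above (copied in by hand).

```python
import torch

# Token ids copied from the printed output above (tgt_vocab: 'S' = 6, 'E' = 7)
dec_inputs = torch.tensor([[6, 1, 2, 3, 4, 8],
                           [6, 1, 2, 3, 5, 8]])
dec_outputs = torch.tensor([[1, 2, 3, 4, 8, 7],
                            [1, 2, 3, 5, 8, 7]])
S = 6

# Full target sequence: 'S i want a beer . E'
full = torch.cat([torch.full((2, 1), S, dtype=torch.long), dec_outputs], dim=1)

print(torch.equal(full[:, :-1], dec_inputs))   # decoder input = all but the last token
print(torch.equal(full[:, 1:], dec_outputs))   # labels = all but the first token
```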


In [17]:
class MyDataSet(data.Dataset):
    def __init__(self,enc_inputs, dec_inputs, dec_outputs):
        super(MyDataSet,self).__init__()
        self.enc_inputs = enc_inputs
        self.dec_inputs = dec_inputs
        self.dec_outputs = dec_outputs
        
    def __len__(self):
        # in the example above, enc_inputs.shape = [2,5], so return 2
        return self.enc_inputs.shape[0] 
    
    # return a set of enc_input, dec_input, dec_output, according to the index
    def __getitem__(self, idx):
        return self.enc_inputs[idx], self.dec_inputs[idx], self.dec_outputs[idx]

# DataLoader
loader = data.DataLoader(dataset=MyDataSet(enc_inputs,dec_inputs, dec_outputs),batch_size=2,shuffle=True)
len(loader)

1

## Model Training

In [18]:

model = Transformer(src_vocab_size, 
                    tgt_vocab_size, 
                    d_model = 512, 
                    d_ff = 2048, 
                    n_heads = 8, 
                    n_layers = 6,
                    dropout_rate = 0.1).cuda()
model.train()
# Loss function: ignore class 0 and compute no loss for it (it is padding and carries no meaning)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)

# training
num_epochs = 10
for epoch in range(num_epochs):
    for enc_inputs, dec_inputs, dec_outputs_true in loader:
        '''
        enc_inputs: [batch_size, src_len] [2,5]
        dec_inputs: [batch_size, tgt_len] [2,6]
        dec_outputs_true: [batch_size, tgt_len] [2,6]
        '''
        enc_inputs, dec_inputs, dec_outputs_true = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs_true.cuda()
        outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]

        # print("pred: ", outputs.shape, outputs)
        # print("true: ", dec_outputs_true.view(-1).shape, dec_outputs_true.view(-1))
        loss = criterion(outputs, dec_outputs_true.view(-1))  # use the flattened dec_outputs_true as the target

        # weight updates
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        print(f'Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item()}')
        # break

torch.save(model, 'MyTransformer.pth')

Epoch [1/10], Loss: 2.3747642040252686
Epoch [2/10], Loss: 2.3079771995544434
Epoch [3/10], Loss: 1.9671508073806763
Epoch [4/10], Loss: 1.7235007286071777
Epoch [5/10], Loss: 1.407813549041748
Epoch [6/10], Loss: 1.1282033920288086
Epoch [7/10], Loss: 0.8996966481208801
Epoch [8/10], Loss: 0.6844269633293152
Epoch [9/10], Loss: 0.5365076661109924
Epoch [10/10], Loss: 0.369119793176651
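A standalone sketch of what the loss line in the training loop computes, with random stand-in logits and arbitrary small sizes: the `[batch_size, tgt_len, vocab]` logits are flattened row by row, as in `Transformer.forward`, and `CrossEntropyLoss(ignore_index=0)` averages only over the non-padding targets.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, tgt_len, vocab = 2, 4, 9
logits = torch.randn(batch_size, tgt_len, vocab)   # stand-in for the generator output
targets = torch.tensor([[1, 2, 3, 0],              # 0 = padding
                        [4, 5, 0, 0]])

# Flatten like Transformer.forward and the training loop:
# row k of flat_logits is sentence k // tgt_len, position k % tgt_len
flat_logits = logits.view(-1, vocab)    # [batch_size * tgt_len, vocab]
flat_targets = targets.view(-1)         # [batch_size * tgt_len]

loss = nn.CrossEntropyLoss(ignore_index=0)(flat_logits, flat_targets)

# Same value computed by hand: average the per-token loss over non-padding positions only
keep = flat_targets != 0
per_token = nn.functional.cross_entropy(flat_logits, flat_targets, reduction='none')
manual = per_token[keep].mean()
print(loss.item(), manual.item())
```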


## Model Inference

In [20]:
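# The original paper uses beam search with a beam size of 4. For simplicity, we use greedy decoding here:
# no candidates are kept, and at each step the most probable token is chosen as the output.
# Without greedy_decoder, the model implemented above would make only a single prediction (['i']) and would
# not be autoregressive, so we use the trained Encoder and Decoder to implement autoregression manually:
# the Decoder's previous output is fed back as its next input until the terminator is predicted.
def greedy_decoder(model, enc_input, start_symbol, max_len=tgt_len):
    """enc_input: [1, src_len], a single sentence"""
    # The Encoder expects embedded inputs and a padding mask, not raw token ids
    src_mask = get_attn_pad_mask(enc_input, enc_input)  # [1, src_len, src_len]
    enc_outputs = model.Encoder(model.pos_embedding(model.src_tok_emb(enc_input)), src_mask)  # [1, src_len, d_model]
    # Empty [1, 0] tensor with the same dtype and device as enc_input, filled step by step below
    dec_input = torch.zeros(1, 0).type_as(enc_input)
    next_symbol = start_symbol
    while True:
        # Append next_symbol as a [1, 1] tensor along the last dimension; this is the new decoder input
        next_token = torch.tensor([[next_symbol]], dtype=enc_input.dtype, device=enc_input.device)
        dec_input = torch.cat([dec_input, next_token], -1)  # dec_input: [1, number of tokens so far]
        # Rebuild the decoder masks for the current prefix
        tgt_mask, _, tgt_pad_mask, memory_mask = model.create_masks(enc_input, dec_input)
        dec_emb = model.pos_embedding(model.tgt_tok_emb(dec_input))
        dec_outputs = model.Decoder(dec_emb, enc_outputs, tgt_pad_mask, tgt_mask, memory_mask)  # [1, cur_len, d_model]
        projected = model.generator(dec_outputs)  # projected: [1, cur_len, tgt_vocab_size]
        # max returns a (values, indices) tuple; [1] takes the indices, i.e., the predicted token ids.
        # keepdim=False drops the reduced dimension.
        prob = projected.squeeze(0).max(dim=-1, keepdim=False)[1]  # prob: [cur_len]
        # prob holds the predicted token for every position so far; the last entry is the newly predicted
        # next token. Because attention is causal, later tokens never change earlier predictions.
        next_symbol = prob[-1].item()
        # Stop at '.', or after max_len tokens in case '.' is never predicted
        if next_symbol == tgt_vocab['.'] or dec_input.size(1) >= max_len:
            break
    return dec_input  # dec_input: [1, generated length]


# Test
# Note: on PyTorch 2.6 and later, torch.load defaults to weights_only=True; loading a whole pickled
# model then requires torch.load('MyTransformer.pth', weights_only=False).
model = torch.load('MyTransformer.pth')
model.eval()
with torch.no_grad():
    # Take one batch from the loader manually
    enc_inputs, _, _ = next(iter(loader))
    enc_inputs = enc_inputs.cuda()
    for i in range(len(enc_inputs)):
        greedy_dec_input = greedy_decoder(model, enc_inputs[i].view(1, -1), start_symbol=tgt_vocab['S'])
        predict  = model(enc_inputs[i].view(1, -1), greedy_dec_input) # predict: [batch_size * tgt_len, tgt_vocab_size]
        predict = predict.max(dim=-1, keepdim=False)[1]
        # greedy_dec_input comes from greedy decoding, which picks the most likely word at each step given only
        # the hypotheses so far, so it is not necessarily the globally optimal output. Here the source sentence
        # and the whole greedy prefix are passed through the model once more; because decoder self-attention
        # is causal, each position again sees only its prefix, so this pass returns the predictions for all
        # positions (including the final '.') in a single forward call.
        print(enc_inputs[i], '->', [idx2word[n.item()] for n in predict])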

In [None]:
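# Explore what the multi-head reshape from (batch_size, seq_len, d_model) to (batch_size, n_heads, seq_len, d_k/v) means

# 1. The initial q
q = torch.arange(120).reshape(2,5,12)
print(q)
print('------------------')
batch_size = 2
seq_len = 5
d_model = 12
n_heads = 3
d_k = 4

# 2. Split into n_heads heads
new_q = q.view(batch_size, -1, n_heads, d_k).transpose(1,2)
# Shape change of the line above: (2,5,12) -> (2,5,3,4) -> (2,3,5,4)
# Meaning: initially batch_size is 2, i.e., the batch holds 2 sentences, each with 5 words, each word a 12-dim vector.
# Afterwards batch_size is still 2, but each batch item has 3 heads; each head sees the whole sentence of 5 words,
# with each word represented by a 4-dim slice of its vector.

print(new_q)
print(new_q.shape) # torch.Size([2, 3, 5, 4])
print('------------------')

# 3. Merge the n_heads heads back: undo the transpose of new_q first, then flatten the last two dimensions
final_q = new_q.transpose(1,2).contiguous().view(batch_size, -1, d_model)
print(final_q)
print(final_q.shape)
print(torch.equal(final_q, q))  # the original element order is recovered
print('------------------')

# Transposing the original q (q.transpose(1,2)) instead of new_q would scramble the element order.
# Concatenating the heads along the last dimension is an equivalent way to merge them:
final_q2 = torch.cat([new_q[:,i,:,:] for i in range(new_q.size(1))], dim=-1)
print(final_q2)
print(final_q2.shape)
print(torch.equal(final_q2, q))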
