In [21]:
import torch
import torch.nn as nn
import math
from torch import Tensor

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
        """Multi-head scaled dot-product attention.

        Args:
            d_model:      embedding dimension of the inputs and of the output
            n_heads:      number of attention heads; must divide d_model evenly
            dropout:      dropout probability applied to the attention weights
                          before they are multiplied with the values
        """
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        self.d_model = d_model
        self.n_heads = n_heads
        # Per-head size; the value size per head is taken to equal d_key.
        self.d_key = d_model // n_heads

        # Projections are created in this order so seeded initialisation is unchanged.
        self.Wq = nn.Linear(d_model, d_model)    # query projection
        self.Wk = nn.Linear(d_model, d_model)    # key projection
        self.Wv = nn.Linear(d_model, d_model)    # value projection
        self.Wo = nn.Linear(d_model, d_model)    # output projection

        self.dropout = nn.Dropout(p=dropout)

    def _split_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch_size, length, d_model) -> (batch_size, n_heads, length, d_key)."""
        return x.view(batch_size, -1, self.n_heads, self.d_key).permute(0, 2, 1, 3)

    def _merge_heads(self, x: torch.Tensor, batch_size: int) -> torch.Tensor:
        """Reshape (batch_size, n_heads, length, d_key) -> (batch_size, length, d_model)."""
        # permute makes the tensor non-contiguous, so copy before view.
        return x.permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.n_heads * self.d_key)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            query:         (batch_size, q_length, d_model)
            key:           (batch_size, k_length, d_model)
            value:         (batch_size, k_length, d_model) -- same length as key
            mask:          optional; must broadcast against
                           (batch_size, n_heads, q_length, k_length). Positions where
                           mask == 0 get (near) zero attention weight.

        Returns:
            output:        attention output            (batch_size, q_length, d_model)
            attn_probs:    softmax weights before dropout
                           (batch_size, n_heads, q_length, k_length)
        """
        batch_size = key.size(0)

        # Project, then split the last dimension into n_heads chunks of d_key.
        Q = self._split_heads(self.Wq(query), batch_size)   # (B, H, q_len, d_key)
        K = self._split_heads(self.Wk(key), batch_size)     # (B, H, k_len, d_key)
        V = self._split_heads(self.Wv(value), batch_size)   # (B, H, k_len, d_key)

        # Scaled dot product QK^T / sqrt(d_key): (B, H, q_len, k_len)
        scaled_dot_prod = torch.matmul(Q, K.permute(0, 1, 3, 2)) / math.sqrt(self.d_key)

        # A large negative fill drives the softmax weight of masked positions to ~0.
        if mask is not None:
            scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -1e10)

        attn_probs = torch.softmax(scaled_dot_prod, dim=-1)

        # Weighted sum of values; dropout applies to the weights only, not to the returned attn_probs.
        A = torch.matmul(self.dropout(attn_probs), V)        # (B, H, q_len, d_key)

        output = self.Wo(self._merge_heads(A, batch_size))  # (B, q_len, d_model)

        return output, attn_probs


In [22]:
class Embeddings(nn.Module):
  def __init__(self, vocab_size: int, d_model: int, scale: bool = False):
    """Token-id -> dense-vector lookup table.

    Args:
      vocab_size:     size of vocabulary (number of rows in the table)
      d_model:        dimension of embeddings
      scale:          if True, multiply the looked-up vectors by sqrt(d_model)
                      as in Vaswani et al. (2017), Sec. 3.4. Defaults to False,
                      which returns the raw lookup (the previous behaviour).
    """
    super().__init__()

    # embedding look-up table (lut)
    self.lut = nn.Embedding(vocab_size, d_model)

    # dimension of embeddings, needed for the optional scaling
    self.d_model = d_model
    self.scale = scale

  def forward(self, x: Tensor):
    """
    Args:
      x: integer token ids (batch_size, seq_length)

    Returns:
      embedding vectors (batch_size, seq_length, d_model)
    """
    embedded = self.lut(x)
    if self.scale:
      embedded = embedded * math.sqrt(self.d_model)
    return embedded

In [23]:
class PositionalEncoding(nn.Module):
  def __init__(self, d_model: int, dropout: float = 0.1, max_length: int = 5000):
    """Fixed sinusoidal positional encodings added to the embeddings.

    PE[pos, 2i]   = sin(pos / 10000^(2i / d_model))
    PE[pos, 2i+1] = cos(pos / 10000^(2i / d_model))

    Args:
      d_model:      dimension of embeddings (odd values are supported)
      dropout:      dropout probability applied after adding the encodings
      max_length:   longest sequence forward() can accept
    """
    super().__init__()

    self.dropout = nn.Dropout(p=dropout)

    # (max_length, 1) column of positions 0 .. max_length-1
    position = torch.arange(max_length, dtype=torch.float).unsqueeze(1)

    # 10000^(-2i/d_model) for every even index 2i, computed in log space
    div_term = torch.exp(
        torch.arange(0, d_model, 2, dtype=torch.float) * -(math.log(10000.0) / d_model)
    )

    # (max_length, ceil(d_model / 2))
    angles = position * div_term

    pe = torch.zeros(max_length, d_model)
    pe[:, 0::2] = torch.sin(angles)
    # With odd d_model there is one fewer odd column than even column,
    # so only the first d_model // 2 angles get a cosine slot.
    pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # A leading batch dim lets it broadcast over (batch_size, seq_length, d_model).
    # Buffers are saved in state_dict but are not trained by the optimizer.
    self.register_buffer("pe", pe.unsqueeze(0))

  def forward(self, x: Tensor):
    """
    Args:
      x:        embeddings (batch_size, seq_length, d_model), seq_length <= max_length

    Returns:
                embeddings + positional encodings (batch_size, seq_length, d_model)
    """
    # The buffer has requires_grad=False, so no gradient flows into it.
    x = x + self.pe[:, : x.size(1)]
    return self.dropout(x)

In [24]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self, d_model: int, d_ffn: int, dropout: float = 0.1):
    """Per-position two-layer MLP: d_model -> d_ffn -> ReLU -> dropout -> d_model.

    Args:
        d_model:      width of the input and output vectors
        d_ffn:        width of the hidden (expanded) layer
        dropout:      dropout probability on the hidden activations
    """
    super().__init__()

    # Modules are registered in the order w_1, w_2, dropout.
    self.w_1 = nn.Linear(d_model, d_ffn)   # expand
    self.w_2 = nn.Linear(d_ffn, d_model)   # contract
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    """
    Args:
        x:            (batch_size, seq_length, d_model), e.g. the attention output

    Returns:
        tensor of the same shape as x
    """
    hidden = torch.relu(self.w_1(x))   # (batch_size, seq_length, d_ffn)
    hidden = self.dropout(hidden)
    return self.w_2(hidden)            # back to (batch_size, seq_length, d_model)

In [25]:
example = "Hello! This is an example of a paragraph that has been split into its basic components. I wonder what will come next! Any guesses?"


def tokenize(sequence, punctuation=("!", ".", "?")):
  """Split a sentence into lowercase word tokens.

  Args:
    sequence:     input text
    punctuation:  strings deleted from the text before splitting. They are
                  removed, not replaced by a space, so "a,b" with "," stripped
                  becomes "ab".

  Returns:
    list of lowercase tokens. Splitting on arbitrary whitespace means repeated
    spaces, stray punctuation between spaces, or empty input never yield "" tokens.
  """
  for punc in punctuation:
    sequence = sequence.replace(punc, "")

  return [token.lower() for token in sequence.split()]

def build_vocab(data):
  """Map each distinct token of `data` to an integer id.

  Tokens are produced by `tokenize`, de-duplicated, and numbered in
  alphabetical order starting from 0, so ids are stable across runs.
  """
  ordered_tokens = sorted(set(tokenize(data)))
  return {token: idx for idx, token in enumerate(ordered_tokens)}

# vocabulary (token -> id) built from the example paragraph
stoi = build_vocab(example)



In [26]:
torch.set_printoptions(precision=2, sci_mode=False)

# Seed the RNG so the randomly initialised weights and the dropout masks
# (and therefore the printed output) are identical on every Restart & Run All.
torch.manual_seed(0)

# example sentences; every word must appear in `stoi` (built from `example`)
sequences = ["I wonder what will come next!",
             "This is a basic example paragraph.",
             "Hello what is a basic split?"]

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]

# map every token to its vocabulary index
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]

# torch.tensor needs a rectangular list and no padding is applied here
assert len({len(seq) for seq in indexed_sequences}) == 1, "all sequences must have the same number of tokens"

# convert the sequences to a (batch_size, seq_length) tensor
tensor_sequences = torch.tensor(indexed_sequences, dtype=torch.long)

# vocab size
vocab_size = len(stoi)

# embedding dimensions
d_model = 8

# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)

# create the positional encodings (max_length must cover seq_length)
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)
assert tensor_sequences.size(1) <= 10, "sequence longer than the positional-encoding table"

# embed the sequence
embeddings = lut(tensor_sequences)

# positionally encode the sequences
X = pe(embeddings)

# set the n_heads
n_heads = 4

# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)

# self-attention: X is passed as query, key and value
output, attn_probs = attention(X, X, X, mask=None)

output

tensor([[[-0.22,  0.12, -0.19,  0.37, -0.54,  0.33,  0.01, -0.38],
         [-0.36,  0.29, -0.25,  0.54, -0.41,  0.44,  0.09, -0.45],
         [-0.27,  0.37, -0.35,  0.51, -0.57,  0.69,  0.16, -0.44],
         [-0.19,  0.29, -0.15,  0.61, -0.41,  0.76,  0.03, -0.62],
         [-0.31,  0.43, -0.43,  0.70, -0.86,  0.53,  0.26, -0.72],
         [-0.29,  0.51, -0.39,  0.56, -0.43,  0.39,  0.03, -0.61]],

        [[-0.44,  0.28, -0.27,  0.49, -0.53,  0.95,  0.19, -0.35],
         [-0.60,  0.25, -0.34,  0.32, -0.43,  0.78, -0.02, -0.32],
         [-0.42,  0.06, -0.18,  0.35, -0.61,  1.09,  0.13, -0.28],
         [-0.55,  0.51, -0.53,  0.44, -0.63,  0.70,  0.08, -0.41],
         [-0.51,  0.25, -0.36,  0.37, -0.47,  0.95,  0.08, -0.39],
         [-0.53,  0.36, -0.44,  0.46, -0.55,  1.04,  0.18, -0.43]],

        [[-0.54,  0.33, -0.31,  0.18, -0.33,  0.58, -0.05, -0.18],
         [-0.65,  0.33, -0.31,  0.45, -0.54,  0.72,  0.16, -0.36],
         [-0.79,  0.30, -0.25,  0.21, -0.34,  0.36, -0.12,

In [27]:
class EncoderLayer(nn.Module):
  def __init__(self, d_model: int, n_heads: int, d_ffn: int, dropout: float):
    """One post-norm Transformer encoder block.

    Self-attention, then a position-wise feed-forward network. Each sub-layer's
    output goes through dropout, is added to its input, and is layer-normalised.

    Args:
        d_model:      dimension of embeddings
        n_heads:      number of attention heads
        d_ffn:        hidden width of the feed-forward network
        dropout:      dropout probability
    """
    super().__init__()
    # Sub-layer 1: self-attention and the LayerNorm that follows it.
    self.attention = MultiHeadAttention(d_model, n_heads, dropout)
    self.attn_layer_norm = nn.LayerNorm(d_model)

    # Sub-layer 2: feed-forward network and its LayerNorm.
    self.positionwise_ffn = PositionwiseFeedForward(d_model, d_ffn, dropout)
    self.ffn_layer_norm = nn.LayerNorm(d_model)

    # Residual dropout shared by both sub-layers.
    self.dropout = nn.Dropout(dropout)

  def _add_and_norm(self, residual: Tensor, sublayer_out: Tensor, norm: nn.Module) -> Tensor:
    """Apply dropout to a sub-layer's output, add the residual, then normalise."""
    return norm(residual + self.dropout(sublayer_out))

  def forward(self, src: Tensor, src_mask: Tensor):
    """
    Args:
        src:          (batch_size, seq_length, d_model) encoded input
        src_mask:     optional attention mask, e.g. (batch_size, 1, 1, seq_length)
    Returns:
        src:          (batch_size, seq_length, d_model) block output
        attn_probs:   attention weights returned by the self-attention sub-layer
    """
    attended, attn_probs = self.attention(src, src, src, src_mask)
    src = self._add_and_norm(src, attended, self.attn_layer_norm)
    src = self._add_and_norm(src, self.positionwise_ffn(src), self.ffn_layer_norm)
    return src, attn_probs

In [28]:
class Encoder(nn.Module):
  def __init__(self, d_model: int, n_layers: int,
               n_heads: int, d_ffn: int, dropout: float = 0.1):
    """A stack of `n_layers` EncoderLayer blocks applied in sequence.

    Args:
        d_model:      dimension of embeddings
        n_layers:     number of encoder layers (0 makes the encoder an identity)
        n_heads:      number of heads
        d_ffn:        dimension of feed-forward network
        dropout:      probability of dropout occurring
    """
    super().__init__()

    # create n_layers encoder layers
    self.layers = nn.ModuleList([EncoderLayer(d_model, n_heads, d_ffn, dropout)
                                 for _ in range(n_layers)])

    # NOTE(review): not used in forward(); kept so the module structure is unchanged.
    self.dropout = nn.Dropout(dropout)

    # Attention weights of the last layer from the most recent forward() call
    # (None before the first call or when n_layers == 0).
    self.attn_probs = None

  def forward(self, src: Tensor, src_mask: Tensor):
    """
    Args:
        src:          embedded sequences                (batch_size, seq_length, d_model)
        src_mask:     optional mask, e.g.               (batch_size, 1, 1, seq_length)

    Returns:
        src:          sequences after self-attention    (batch_size, seq_length, d_model)
    """
    # Initialised so an empty stack does not raise UnboundLocalError.
    attn_probs = None

    for layer in self.layers:
      src, attn_probs = layer(src, src_mask)

    self.attn_probs = attn_probs

    return src


In [29]:
torch.set_printoptions(precision=2, sci_mode=False)

# Seed the RNG so the encoder weights and dropout masks (and the printed output)
# are identical on every Restart & Run All.
torch.manual_seed(0)

# example sentences; every word must appear in `stoi` (built from `example`)
sequences = ["I wonder what will come next!",
             "This is a basic example paragraph.",
             "Hello what is a basic split?"]

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]

# map every token to its vocabulary index
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]

# torch.tensor needs a rectangular list and no padding is applied here
assert len({len(seq) for seq in indexed_sequences}) == 1, "all sequences must have the same number of tokens"

# convert the sequences to a (batch_size, seq_length) tensor
tensor_sequences = torch.tensor(indexed_sequences, dtype=torch.long)

# parameters
vocab_size = len(stoi)
d_model = 8
d_ffn = d_model * 4 # 32
n_heads = 4
n_layers = 4
dropout = 0.1
max_length = 10  # must be >= seq_length

assert tensor_sequences.size(1) <= max_length, "sequence longer than the positional-encoding table"

# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)

# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=dropout, max_length=max_length)

# embed the sequence
embeddings = lut(tensor_sequences)

# positionally encode the sequences
X = pe(embeddings)

# initialize encoder
encoder = Encoder(d_model, n_layers, n_heads,
                  d_ffn, dropout)

# pass through encoder
encoder(src=X, src_mask=None)

tensor([[[-1.42,  0.82, -0.81,  1.10, -0.97,  1.50, -0.45,  0.23],
         [ 0.03,  1.52, -0.30,  1.75, -1.09, -0.39, -0.68, -0.84],
         [ 0.28,  0.09,  1.82,  0.64, -1.12,  0.43, -0.57, -1.56],
         [-0.15, -1.15,  1.07,  1.81, -1.20,  0.52, -0.72, -0.18],
         [-0.36, -0.07,  1.23,  1.02, -1.28,  0.48,  0.69, -1.72],
         [-1.16,  0.01,  1.55,  1.55, -1.22, -0.13, -0.02, -0.59]],

        [[ 1.54,  0.73, -0.90,  0.16, -0.18, -0.03, -1.96,  0.62],
         [ 1.14,  0.61,  0.81,  0.02, -0.26, -0.29,  0.30, -2.33],
         [ 1.54,  1.22, -0.83, -0.30,  0.51,  0.12, -1.64, -0.63],
         [ 2.17, -0.58,  0.44, -0.22,  0.72, -1.02, -0.70, -0.81],
         [ 0.62, -1.29,  0.30,  0.15,  0.93,  0.13, -1.94,  1.10],
         [-0.17,  1.08, -0.84,  1.64,  0.38,  0.26, -0.72, -1.63]],

        [[-1.18,  0.43,  0.39,  1.32, -1.75,  1.08,  0.20, -0.48],
         [ 0.44,  0.58,  1.83, -0.20, -0.47,  0.50, -1.46, -1.22],
         [ 0.94, -0.33,  1.51,  0.44, -1.04,  0.12,  0.22,