In [9]:
import torch
import torch.nn as nn
import math
from torch import Tensor

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int = 512, n_heads: int = 8, dropout: float = 0.1):
        """
        Args:
            d_model:      嵌入的维度
            n_heads:      自注意力头的数量
            dropout:      丢弃概率
        """
        super().__init__()
        assert d_model % n_heads == 0            # 确保头的数量可以整除嵌入的维度
        self.d_model = d_model                   # 512 维度
        self.n_heads = n_heads                   # 8 个头
        self.d_key = d_model // n_heads          # 假设 d_value 等于 d_key | 512/8=64

        self.Wq = nn.Linear(d_model, d_model)    # 查询权重
        self.Wk = nn.Linear(d_model, d_model)    # 键权重
        self.Wv = nn.Linear(d_model, d_model)    # 值权重
        self.Wo = nn.Linear(d_model, d_model)    # 输出权重

        self.dropout = nn.Dropout(p=dropout)     # 初始化 dropout 层

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: torch.Tensor = None):
        """
        Args:
            query:         查询向量         (batch_size, q_length, d_model)
            key:           键向量           (batch_size, k_length, d_model)
            value:         值向量           (batch_size, s_length, d_model)
            mask:          解码器的掩码     

        Returns:
            output:        注意力值         (batch_size, q_length, d_model)
            attn_probs:    softmax 分数     (batch_size, n_heads, q_length, k_length)
        """
        batch_size = key.size(0)                  

        # 计算查询、键和值张量
        Q = self.Wq(query)                       # (32, 10, 512) x (512, 512) = (32, 10, 512)
        K = self.Wk(key)                         # (32, 10, 512) x (512, 512) = (32, 10, 512)
        V = self.Wv(value)                       # (32, 10, 512) x (512, 512) = (32, 10, 512)

        # 将每个张量分割为 n 个头以计算注意力
        # 查询张量
        Q = Q.view(batch_size,                   
                   -1,                           
                   self.n_heads,              
                   self.d_key
                   ).permute(0, 2, 1, 3)         
        # 键张量
        K = K.view(batch_size,                   
                   -1,                           
                   self.n_heads,              
                   self.d_key
                   ).permute(0, 2, 1, 3)         
        # 值张量
        V = V.view(batch_size,                   
                   -1,                           
                   self.n_heads, 
                   self.d_key
                   ).permute(0, 2, 1, 3)         
        
        # 计算注意力
        # 缩放点积 -> QK^{T}
        scaled_dot_prod = torch.matmul(Q,        
                                       K.permute(0, 1, 3, 2)
                                       ) / math.sqrt(self.d_key)      

        # 将掩码位置为 0 的位置填充为 (-1e10)
        if mask is not None:
            scaled_dot_prod = scaled_dot_prod.masked_fill(mask == 0, -1e10)

        # 应用 softmax 
        attn_probs = torch.softmax(scaled_dot_prod, dim=-1)
        
        # 乘以值以获得注意力
        A = torch.matmul(self.dropout(attn_probs), V)       
                                                       
        # 将注意力重塑回 (32, 10, 512)
        A = A.permute(0, 2, 1, 3).contiguous()              
        A = A.view(batch_size, -1, self.n_heads*self.d_key) 
        
        # 通过最终权重层
        output = self.Wo(A)                                 

        return output, attn_probs                           


In [10]:
class Embeddings(nn.Module):
  def __init__(self, vocab_size: int, d_model: int):
    """
    Args:
      vocab_size:     size of vocabulary
      d_model:        dimension of embeddings
    """
    # inherit from nn.Module
    super().__init__()   
     
    # embedding look-up table (lut)                          
    self.lut = nn.Embedding(vocab_size, d_model)   

    # dimension of embeddings 
    self.d_model = d_model                          

  def forward(self, x: Tensor):
    """
    Args:
      x: input Tensor (batch_size, seq_length)
      
    Returns:
        embedding vector
    """
    # embeddings by constant sqrt(d_model)
    # return self.lut(x) * math.sqrt(self.d_model)  
    return self.lut(x)

NameError: name 'Tensor' is not defined

In [5]:
example = "Hello! This is an example of a paragraph that has been split into its basic components. I wonder what will come next! Any guesses?"


def tokenize(sequence):
  # remove punctuation
  for punc in ["!", ".", "?"]:
    sequence = sequence.replace(punc, "")
  
  # split the sequence on spaces and lowercase each token
  return [token.lower() for token in sequence.split(" ")]

def build_vocab(data):
  # tokenize the data and remove duplicates
  vocab = list(set(tokenize(data)))

  # sort the vocabulary
  vocab.sort()

  # assign an integer to each word
  stoi = {word:i for i, word in enumerate(vocab)}

  return stoi

# build the vocab
stoi = build_vocab(example)



In [7]:
torch.set_printoptions(precision=2, sci_mode=False)

# convert the sequences to integers
sequences = ["I wonder what will come next!",
             "This is a basic example paragraph.",
             "Hello what is a basic split?"]

# tokenize the sequences
tokenized_sequences = [tokenize(seq) for seq in sequences]

# index the sequences 
indexed_sequences = [[stoi[word] for word in seq] for seq in tokenized_sequences]

# convert the sequences to a tensor
tensor_sequences = torch.tensor(indexed_sequences).long()

# vocab size
vocab_size = len(stoi)

# embedding dimensions
d_model = 8

# create the embeddings
lut = Embeddings(vocab_size, d_model) # look-up table (lut)

# create the positional encodings
pe = PositionalEncoding(d_model=d_model, dropout=0.1, max_length=10)

# embed the sequence
embeddings = lut(tensor_sequences)

# positionally encode the sequences
X = pe(embeddings)

# set the n_heads
n_heads = 4

# create the attention layer
attention = MultiHeadAttention(d_model, n_heads, dropout=0.1)

# pass X through the attention layer three times to create Q, K, and V
output, attn_probs = attention(X, X, X, mask=None)

output

NameError: name 'Embeddings' is not defined