<table style="width:100%">
<tr>
<td style="vertical-align:middle; text-align:left;">
<font size="2">
Supplementary code for the <a href="http://mng.bz/orYv">Build a Large Language Model From Scratch</a> book by <a href="https://sebastianraschka.com">Sebastian Raschka</a><br>
<br>Code repository: <a href="https://github.com/rasbt/LLMs-from-scratch">https://github.com/rasbt/LLMs-from-scratch</a>
</font>
</td>
<td style="vertical-align:middle; text-align:left;">
<a href="http://mng.bz/orYv"><img src="https://sebastianraschka.com/images/LLMs-from-scratch-images/cover-small.webp" width="100px"></a>
</td>
</tr>
</table>


 # Multi-head Attention Plus Data Loading
 # 多头注意力机制与数据加载

In [1]:
# NBVAL_IGNORE_OUTPUT
from importlib.metadata import version

# Report the installed torch version, looked up from the package metadata.
print("torch version:", version("torch"))

torch version: 2.2.2


The complete chapter code is located in [ch03.ipynb](./ch03.ipynb).
完整的章节代码位于 [ch03.ipynb](./ch03.ipynb)。

This notebook contains the chapter's main takeaway, the multi-head attention implementation, plus the data loading pipeline from chapter 2.
本笔记本包含主要要点：多头注意力机制的实现（以及第2章的数据加载流程）

## Data Loader from Chapter 2
## 第二章的数据加载器

In [2]:
# Import the required libraries
import tiktoken  # OpenAI's tokenizer library
import torch  # PyTorch deep learning framework
import torch.nn as nn  # PyTorch neural-network module
from torch.utils.data import Dataset, DataLoader  # PyTorch data-loading utilities


class GPTDatasetV1(Dataset):
    """Sliding-window dataset of (input, target) token-id tensor pairs.

    The whole text is tokenized once. Windows of ``max_length`` token ids are
    then taken every ``stride`` tokens; each target window is the matching
    input window shifted one token to the right.
    """

    def __init__(self, txt, tokenizer, max_length, stride):
        """Tokenize ``txt`` and cut it into input/target windows.

        Args:
            txt: Text to tokenize.
            tokenizer: Object with an ``encode(text, allowed_special=...)``
                method that returns a sequence of token ids.
            max_length: Number of token ids in every window.
            stride: Step between the start positions of consecutive windows.
        """
        ids = tokenizer.encode(txt, allowed_special={'<|endoftext|>'})

        # A start position is valid only if the one-token-shifted target
        # window still fits, hence the exclusive bound len(ids) - max_length.
        window_starts = range(0, len(ids) - max_length, stride)

        self.input_ids = [
            torch.tensor(ids[start:start + max_length])
            for start in window_starts
        ]
        self.target_ids = [
            torch.tensor(ids[start + 1:start + max_length + 1])
            for start in window_starts
        ]

    def __len__(self):
        """Return the number of windows."""
        return len(self.input_ids)

    def __getitem__(self, idx):
        """Return the ``(input_ids, target_ids)`` tensor pair at ``idx``."""
        return self.input_ids[idx], self.target_ids[idx]


def create_dataloader(txt, batch_size=4, max_length=256, stride=128, shuffle=True):
    """Build a DataLoader over GPT-2-tokenized sliding windows of ``txt``.

    Args:
        txt: Text to tokenize and window.
        batch_size: Number of (input, target) pairs per batch.
        max_length: Number of token ids per window.
        stride: Step between the start positions of consecutive windows.
        shuffle: Whether the DataLoader shuffles the windows.

    Returns:
        A ``DataLoader`` wrapping a ``GPTDatasetV1`` built from ``txt``.
    """
    return DataLoader(
        # The dataset tokenizes the text with the GPT-2 encoding.
        GPTDatasetV1(txt, tiktoken.get_encoding("gpt2"), max_length, stride),
        batch_size=batch_size,
        shuffle=shuffle,
    )


# Read the sample text file
with open("small-text-sample.txt", "r", encoding="utf-8") as f:
    raw_text = f.read()

# Build the GPT-2 tokenizer and encode the text
# NOTE(review): encoded_text is not referenced by any other cell visible here.
tokenizer = tiktoken.get_encoding("gpt2")
encoded_text = tokenizer.encode(raw_text)

# Model hyperparameters
vocab_size = 50257  # GPT-2 vocabulary size
output_dim = 256    # embedding dimension
max_len = 1024      # maximum sequence length
context_length = max_len  # context length (number of rows in the positional embedding)

# Token embedding layer
token_embedding_layer = nn.Embedding(vocab_size, output_dim)
# Positional embedding layer
pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)

# Sequence length actually used below, and the data loader built with it
# (stride == max_length, so consecutive windows do not overlap)
max_length = 4
dataloader = create_dataloader(raw_text, batch_size=8, max_length=max_length, stride=max_length)

In [3]:
# Iterate over the data loader (only the first batch is used)
for batch in dataloader:
    # Unpack the batch into inputs and targets
    x, y = batch

    # Token embeddings for the input ids
    token_embeddings = token_embedding_layer(x)
    # Positional embeddings for positions 0 .. max_length - 1
    pos_embeddings = pos_embedding_layer(torch.arange(max_length))

    # Add token and positional embeddings to get the input embeddings
    input_embeddings = token_embeddings + pos_embeddings

    # Leave the loop after processing the first batch
    break

In [4]:
# Print the shape of the input embeddings
# Shape is [batch_size, sequence_length, embedding_dim]
print(input_embeddings.shape)

torch.Size([8, 4, 256])


# Multi-head Attention from Chapter 3
# 第三章的多头注意力机制

## Variant A: Simple implementation
## 变体 A: 简单实现

In [5]:
class CausalSelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a causal mask.

    Each position attends only to itself and earlier positions: scores for
    later positions are set to -inf before the softmax.
    """

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        """Create the projections, the dropout layer and the mask buffer.

        Args:
            d_in: Size of the last dimension of the input.
            d_out: Size of the last dimension of the output.
            context_length: Side length of the square causal mask.
            dropout: Dropout probability applied to the attention weights.
            qkv_bias: Whether the query/key/value projections use a bias.
        """
        super().__init__()
        self.d_out = d_out

        # Created in query, key, value order so that seeded weight
        # initialisation is reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def forward(self, x):
        """Compute causally masked attention over ``x``.

        Args:
            x: Input tensor of shape [batch_size, seq_len, d_in].

        Returns:
            Context vectors of shape [batch_size, seq_len, d_out].
        """
        _, seq_len, _ = x.shape

        q = self.W_query(x)
        k = self.W_key(x)
        v = self.W_value(x)

        # Crop the stored mask to the actual sequence length.
        causal = self.mask[:seq_len, :seq_len].bool()
        scores = (q @ k.transpose(1, 2)).masked_fill(causal, -torch.inf)

        # Scale by the square root of the key dimension, then normalise.
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        return weights @ v


class MultiHeadAttentionWrapper(nn.Module):
    """Multi-head attention built from independent ``CausalSelfAttention`` heads.

    Every head processes the same input; the head outputs are concatenated
    along the last dimension and passed through a linear output projection.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """Create the heads and the output projection.

        Args:
            d_in: Size of the last dimension of the input.
            d_out: Output size of each individual head; the wrapper's output
                has ``d_out * num_heads`` features.
            context_length: Side length of each head's causal mask.
            dropout: Dropout probability used inside each head.
            num_heads: Number of attention heads.
            qkv_bias: Whether the query/key/value projections use a bias.
        """
        super().__init__()

        # The heads are registered before the output projection.
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(
                CausalSelfAttention(d_in, d_out, context_length, dropout, qkv_bias)
            )

        concat_dim = d_out * num_heads
        self.out_proj = nn.Linear(concat_dim, concat_dim)

    def forward(self, x):
        """Run every head on ``x``, concatenate and project.

        Args:
            x: Input tensor passed unchanged to each head.

        Returns:
            The projected concatenation of all head outputs.
        """
        per_head = [head(x) for head in self.heads]
        combined = torch.cat(per_head, dim=-1)
        return self.out_proj(combined)

In [6]:
# Set the random seed for reproducibility
torch.manual_seed(123)

# Context length and input dimension
context_length = max_length
d_in = output_dim

# Number of attention heads and the per-head output dimension
num_heads=2  # use 2 attention heads
d_out = d_in // num_heads  # per-head output dimension is the input dimension divided by the number of heads

# Create the multi-head attention model instance
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads)

# Forward pass on the input embeddings
batch = input_embeddings
context_vecs = mha(batch)

# Print the shape of the output tensor
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])


## Variant B: Alternative implementation
## 变体 B: 替代实现

In [7]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with one shared projection per Q/K/V.

    Queries, keys and values are each projected to ``d_out`` features once,
    then split into ``num_heads`` heads of ``head_dim`` features. Attention is
    computed for all heads in a single batched matrix multiplication.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        """Create the projections, the dropout layer and the mask buffer.

        Args:
            d_in: Size of the last dimension of the input.
            d_out: Size of the last dimension of the output; must be
                divisible by ``num_heads``.
            context_length: Side length of the square causal mask.
            dropout: Dropout probability applied to the attention weights.
            num_heads: Number of attention heads.
            qkv_bias: Whether the query/key/value projections use a bias.
        """
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads  # features handled by each head

        # Created in query, key, value, output order so that seeded weight
        # initialisation is reproducible.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)

        # Ones strictly above the diagonal mark the "future" positions.
        future_positions = torch.triu(
            torch.ones(context_length, context_length), diagonal=1
        )
        self.register_buffer('mask', future_positions)

    def _split_heads(self, projected, batch_size, num_tokens):
        """Reshape (b, num_tokens, d_out) to (b, num_heads, num_tokens, head_dim)."""
        per_head = projected.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        return per_head.transpose(1, 2)

    def forward(self, x):
        """Compute causally masked multi-head attention over ``x``.

        Args:
            x: Input tensor of shape (batch_size, num_tokens, d_in).

        Returns:
            Tensor of shape (batch_size, num_tokens, d_out).
        """
        b, num_tokens, _ = x.shape

        q = self._split_heads(self.W_query(x), b, num_tokens)
        k = self._split_heads(self.W_key(x), b, num_tokens)
        v = self._split_heads(self.W_value(x), b, num_tokens)

        # Crop the stored mask to the actual sequence length; it broadcasts
        # over the batch and head dimensions.
        causal = self.mask[:num_tokens, :num_tokens].bool()
        scores = (q @ k.transpose(2, 3)).masked_fill(causal, -torch.inf)

        # Scale by the square root of the per-head dimension, then normalise.
        scale = k.shape[-1] ** 0.5
        weights = self.dropout(torch.softmax(scores / scale, dim=-1))

        # (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, num_heads, head_dim)
        per_head_context = (weights @ v).transpose(1, 2)

        # Merge the heads back into one d_out-sized feature dimension.
        merged = per_head_context.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(merged)

In [8]:
# Set the random seed so the results are reproducible
torch.manual_seed(123)

# Context length and dimension parameters
context_length = max_length  # use the sequence length as the context length
d_in = output_dim  # input dimension is the embedding dimension
d_out = d_in  # output dimension matches the input dimension

# Initialise the multi-head attention module
# Arguments:
# - d_in: input dimension
# - d_out: output dimension
# - context_length: context length
# - dropout: dropout rate (set to 0 here)
# - num_heads: number of attention heads (set to 2 here)
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

# Pass the input embeddings through the multi-head attention layer
batch = input_embeddings
context_vecs = mha(batch)

# Print the shape of the output tensor
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([8, 4, 256])
