## unicode1

`ord()` returns the Unicode code point of a character as an integer.
`chr()` converts a Unicode code point back to the corresponding character.

`chr(0)` is the null character (NUL), which Python displays as `'\x00'`.
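As a quick illustration of the two functions, the sketch below round-trips a few sample characters (the sample list is chosen here for illustration only) through `ord()` and `chr()`.

```python
# ord() and chr() are inverses of each other for valid code points.
samples = ["a", "é", "中", "\x00"]

code_points = [ord(ch) for ch in samples]
print(code_points)            # [97, 233, 20013, 0]

round_trip = [chr(cp) for cp in code_points]
print(round_trip == samples)  # True

# The null character has no visible glyph, so repr() is the clearest way to show it.
print(repr(chr(0)))           # '\x00'
```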

In [3]:
print(chr(0))

 


In [4]:
s = '\n'
print(s)
print(repr(s))



'\n'


Although `chr(0)` is the null character (`\0`), it does not terminate a Python string the way it would terminate a C string.

In [5]:
chr(0)
print(chr(0))
"this is a test" + chr(0) + "string"
print("this is a test" + chr(0) + "string")

 
this is a test string


## unicode2

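1. Efficiency and byte-level representation
Modern tokenizers (such as WordPiece, used by BERT, and BPE, used by GPT) split text at the subword level, and an increasingly common approach is to learn subwords directly over bytes.

UTF-8 is a variable-length encoding (1 to 4 bytes per character) with one key advantage: its single-byte range is fully compatible with ASCII. All English letters, digits, and common ASCII punctuation are therefore encoded as a single byte.

How the tokenizer works: when BPE is trained on UTF-8, it is effectively performing merge operations on byte sequences. The model starts from individual bytes and learns to combine frequently co-occurring bytes into larger tokens.

With UTF-8:

An English letter (such as a) takes 1 byte.

An accented European character (such as é) takes 2 bytes.

A common Chinese character (such as 中) takes 3 bytes.

BPE is free to merge single bytes or groups of bytes into tokens. For example, it may learn "the" as one token (a merge of 3 bytes), and it may also learn "中" as one token (a merge of 3 specific bytes). This makes the approach flexible and efficient.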
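The merging idea can be sketched on a toy byte sequence. The helper `merge_once` and the sample text are hypothetical names chosen for this illustration, and ties are broken by taking the lexicographically greatest pair, which is an assumption of the sketch rather than a rule stated in these notes.

```python
from collections import Counter

def merge_once(tokens):
    """Merge the most frequent adjacent pair in a list of bytes tokens."""
    pairs = Counter(zip(tokens, tokens[1:]))
    if not pairs:
        return tokens, None
    # Highest count wins; ties go to the lexicographically greatest pair (assumption).
    best = max(pairs, key=lambda p: (pairs[p], p))
    merged, i = [], 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
            merged.append(tokens[i] + tokens[i + 1])
            i += 2
        else:
            merged.append(tokens[i])
            i += 1
    return merged, best

# Start from single bytes: each "中" contributes 3 bytes, each ASCII character 1 byte.
tokens = [bytes([b]) for b in "中中中 the the".encode("utf-8")]
for _ in range(2):
    tokens, best = merge_once(tokens)
    print(best, len(tokens))

# After two merges the three bytes of "中" have become a single token.
print(tokens[0] == "中".encode("utf-8"))
```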

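With UTF-16:

Most commonly used characters (the entire Basic Multilingual Plane, BMP) take a fixed 2 bytes; characters outside the BMP take 4 bytes (a surrogate pair).

Even the English letter a is therefore encoded as 0x0061 (2 bytes). This doubles the length of the byte sequence for English text before BPE even starts, and half of those bytes (the 0x00 high bytes) carry almost no information.

The tokenizer has to process and merge away these uninformative 0x00 bytes, which makes learning less efficient.

With UTF-32:

Every character takes a fixed 4 bytes.

This is the worst case for efficiency: a plain English text file is 4 times the size of its UTF-8 version.

The initial sequence (before BPE) consists of 4-byte units, and for ASCII text three of the four bytes in every character are 0. That is a large waste for both tokenizer training and model computation.

Conclusion 1: for ASCII-heavy text, UTF-8 gives the tokenizer the most compact, most information-dense starting point (a byte sequence) and avoids the large number of padding zero bytes found in UTF-16/32. (For CJK-heavy text UTF-16 can be smaller, at 2 bytes per BMP character versus 3 in UTF-8, but it is not ASCII-compatible.)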
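The size comparison can be checked directly. The sketch below uses the `-le` codec variants so that Python does not prepend a byte-order mark; the helper names and sample strings are chosen for illustration.

```python
ENCODINGS = ("utf-8", "utf-16-le", "utf-32-le")  # "-le" variants omit the byte-order mark

def encoded_sizes(text: str) -> dict:
    """Number of bytes the text occupies in each encoding."""
    return {enc: len(text.encode(enc)) for enc in ENCODINGS}

def zero_byte_fraction(text: str, encoding: str) -> float:
    """Share of the encoded bytes that are 0x00."""
    data = text.encode(encoding)
    return data.count(0) / len(data)

for ch in ["a", "é", "中"]:
    print(repr(ch), encoded_sizes(ch))

english = "hello world"
for enc in ENCODINGS:
    print(enc, len(english.encode(enc)), zero_byte_fraction(english, enc))
```

For the ASCII sample, half of the UTF-16 bytes and three quarters of the UTF-32 bytes are zero, while the UTF-8 encoding contains none.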

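2. Vocabulary size and model complexity
One goal of tokenizer training is to produce an efficient vocabulary of a fixed size.

The base alphabet of UTF-8 bytes has only 256 possible values (byte values 0-255), a small and manageable starting point. Starting from these 256 bytes, BPE can reliably build a vocabulary of tens to hundreds of thousands of tokens that can represent any word in any language.

3. Compatibility and practicality
The reality of the web and of data: the vast majority of text on the internet is encoded in UTF-8, and operating systems, file tooling, and network protocols widely support it. Using UTF-8 from data source to model training means fewer encoding conversions and fewer opportunities for errors.

Unified processing: a byte-level tokenizer built on UTF-8 (for example GPT-2's byte-level BPE, or SentencePiece with byte fallback enabled) makes a single, unified multilingual model possible. Such a model can handle arbitrary mixed-language text without knowing the language in advance, because every language is ultimately decomposed into sequences over the same 256 byte values.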
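A minimal sketch of that claim: any mix of scripts reduces to integers in the range 0-255, and the original text can be recovered from them. The sample sentence is arbitrary.

```python
text = "Hello, 世界! Привет"

# Every character, whatever its script, becomes one or more byte values in 0..255.
ids = list(text.encode("utf-8"))
print(len(text), len(ids), max(ids))

# The byte IDs are lossless: decoding them restores the original string.
restored = bytes(ids).decode("utf-8")
print(restored == text)
```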

In [6]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
 return "".join([bytes([b]).decode("utf-8") for b in bytestring])
decode_utf8_bytes_to_str_wrong("hello".encode("utf-8"))

'hello'

In [None]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
 return "".join([bytes([b]).decode("utf-8") for b in bytestring])
decode_utf8_bytes_to_str_wrong("中文".encode("utf-8"))

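The error occurs because decode_utf8_bytes_to_str_wrong() iterates over the byte string, wraps each integer byte in a single-byte bytes object, decodes each one as UTF-8 on its own, and joins the results. 'hello' works because each of its characters is a valid ASCII character, so decoding one byte at a time succeeds. Each character in '中文' is encoded as three bytes; none of those bytes is a valid UTF-8 sequence on its own, so decoding them individually raises UnicodeDecodeError.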
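For contrast, here is a sketch of two ways to decode correctly: decode the whole byte string at once, or, when bytes arrive in pieces, use the standard library's incremental decoder, which buffers incomplete characters. The function names are chosen for this illustration.

```python
import codecs

def decode_utf8_bytes_to_str(bytestring: bytes) -> str:
    """Decode the whole byte string at once so multi-byte characters stay intact."""
    return bytestring.decode("utf-8")

def decode_utf8_chunks(chunks) -> str:
    """Decode bytes that arrive in arbitrary pieces, buffering incomplete characters."""
    decoder = codecs.getincrementaldecoder("utf-8")()
    pieces = [decoder.decode(chunk) for chunk in chunks]
    pieces.append(decoder.decode(b"", final=True))  # raises if bytes are left over
    return "".join(pieces)

data = "中文".encode("utf-8")
whole = decode_utf8_bytes_to_str(data)
streamed = decode_utf8_chunks(bytes([b]) for b in data)

# Decoding a lone lead byte, as the byte-by-byte version does, fails.
try:
    data[:1].decode("utf-8")
    lone_byte_ok = True
except UnicodeDecodeError:
    lone_byte_ok = False

print(whole, streamed, lone_byte_ok)
```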

In [None]:
bytestring = "hello".encode("utf-8")
print(bytestring)
bytestring = "中文".encode("utf-8")
print(bytestring)

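A valid two-byte UTF-8 sequence has the form 110xxxxx 10xxxxxx.
Example of an invalid sequence: 0xC1 0xBF (0xC1 is an invalid start byte, because it could only begin an overlong encoding).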
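These rules can be checked with Python's strict UTF-8 decoder. `is_valid_utf8` is a small helper written for this illustration.

```python
def is_valid_utf8(data: bytes) -> bool:
    """Return True if the bytes form a complete, well-formed UTF-8 sequence."""
    try:
        data.decode("utf-8")
        return True
    except UnicodeDecodeError:
        return False

print(is_valid_utf8(b"\xc3\xa9"))  # True: 11000011 10101001 decodes to 'é'
print(is_valid_utf8(b"\xc1\xbf"))  # False: 0xC1 is never a valid start byte
print(is_valid_utf8(b"\xc3\x28"))  # False: the second byte is not of the form 10xxxxxx
print(is_valid_utf8(b"\xc3"))      # False: the sequence is truncated
```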

## train_bpe

In [None]:
import collections
from typing import Dict, List, Tuple

def train_bpe(input_path: str, vocab_size: int, special_tokens: List[str], max_merges: int = None) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    Train a byte-level BPE tokenizer.

    Args:
        input_path: path to the training text file
        vocab_size: final vocabulary size (base bytes, special tokens, and merged tokens)
        special_tokens: special tokens to add to the vocabulary
        max_merges: maximum number of merges; None means no limit

    Returns:
        vocab: mapping from token ID to bytes
        merges: list of BPE merges, in the order they were created
    """

    # 1. Initialise the vocabulary (256 bytes + special tokens)
    vocab = {}
    next_id = 0

    # Base byte tokens (0-255)
    for i in range(256):
        vocab[next_id] = bytes([i])
        next_id += 1

    # Special tokens
    for token in special_tokens:
        vocab[next_id] = token.encode('utf-8')
        next_id += 1

    # 2. Read the training data
    with open(input_path, 'r', encoding='utf-8') as f:
        text = f.read()

    # 3. Pre-tokenise: split on whitespace and count word frequencies
    words = text.split()
    word_freqs = collections.Counter(words)

    # Represent each word as a tuple of single-byte tokens
    vocab_stats = {}
    for word, freq in word_freqs.items():
        byte_word = word.encode('utf-8')
        tokens = [bytes([b]) for b in byte_word]
        vocab_stats[tuple(tokens)] = freq

    # 4. BPE training loop
    merges = []
    merge_count = 0

    while len(vocab) < vocab_size:
        # Stop once the merge limit is reached
        if max_merges is not None and merge_count >= max_merges:
            break

        # Count every adjacent token pair, weighted by word frequency
        pair_freqs = collections.Counter()
        for tokens, freq in vocab_stats.items():
            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                pair_freqs[pair] += freq

        if not pair_freqs:
            break  # nothing left to merge

        # Pick the most frequent pair; ties go to the pair seen last in counting order
        max_freq = max(pair_freqs.values())
        best_pairs = [pair for pair, freq in pair_freqs.items() if freq == max_freq]
        best_pair = best_pairs[-1]

        # Create the merged token and add it to the vocabulary
        new_token = best_pair[0] + best_pair[1]
        vocab[next_id] = new_token
        next_id += 1

        # Record the merge
        merges.append(best_pair)
        merge_count += 1

        # Rewrite every word, replacing each occurrence of best_pair with new_token
        new_vocab_stats = {}
        for tokens, freq in vocab_stats.items():
            new_tokens = []
            i = 0
            while i < len(tokens):
                if (i < len(tokens) - 1 and
                        tokens[i] == best_pair[0] and
                        tokens[i + 1] == best_pair[1]):
                    new_tokens.append(new_token)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            new_vocab_stats[tuple(new_tokens)] = freq

        vocab_stats = new_vocab_stats

    return vocab, merges

# Test function
def test_bpe():
    """Test the BPE training function."""
    import os

    user_dir = os.path.expanduser('~')
    file_path = os.path.join(user_dir, 'Downloads', 'tinystories_sample.txt')

    # Train the BPE tokenizer with at most 6 merges
    vocab, merges = train_bpe(
        input_path=file_path,
        vocab_size=10000,  # assumed value: left blank in the original; the parallel test uses 10000
        special_tokens=['<pad>', '<unk>', '<s>', '</s>'],
        max_merges=6  # cap the number of merges
    )

    print("Vocabulary size:", len(vocab))
    print("Merges performed:", len(merges))
    print("\nFirst 20 vocabulary entries:")
    for token_id, token_bytes in list(vocab.items())[:20]:
        print(f"ID {token_id}: {token_bytes} -> {token_bytes.decode('utf-8', errors='replace')}")

    print("\nSpecial tokens:")
    for token in ['<pad>', '<unk>', '<s>', '</s>']:
        token_bytes = token.encode('utf-8')
        for token_id, bytes_val in vocab.items():
            if bytes_val == token_bytes:
                print(f"ID {token_id}: {token_bytes} -> {token}")
                break

    print(f"\nNumber of merges: {len(merges)}")
    if merges:
        print("Merge list:")
        for i, merge in enumerate(merges):
            print(f"  {i+1}: {merge[0]} + {merge[1]} -> {merge[0] + merge[1]}")
    print(vocab)
    print('\n')
    print('\n')
    print('\n')
    print(merges)

if __name__ == "__main__":
    test_bpe()
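When several pairs share the highest frequency, train_bpe picks the last one in counting order, so the result depends on the order in which words were first seen. As a hedged sketch of an alternative (an assumption of this example, not something the code above does), ties can instead be broken by the lexicographically greatest pair, which does not depend on data order. The toy counts below are made up.

```python
from collections import Counter

# Two pairs tie at frequency 5; insertion order stands in for counting order.
pair_freqs = Counter()
pair_freqs[(b"b", b"c")] = 5
pair_freqs[(b"a", b"b")] = 5
pair_freqs[(b"c", b"d")] = 2

max_freq = max(pair_freqs.values())

# Rule used in train_bpe: the last pair in counting order among the tied pairs.
last_inserted = [p for p, f in pair_freqs.items() if f == max_freq][-1]

# Alternative: highest frequency first, then the lexicographically greatest pair.
lexicographic = max(pair_freqs, key=lambda p: (pair_freqs[p], p))

print(last_inserted)   # (b'a', b'b')
print(lexicographic)   # (b'b', b'c')
```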

Putting it together:

1. Create the main function and import dependencies

In [None]:
import os
import collections
import multiprocessing as mp
from typing import Dict, List, Tuple, BinaryIO
import heapq

def parallel_bpe_training(
    input_path: str,
    vocab_size: int,
    special_tokens: List[str],
    max_merges: int = None,
    num_processes: int = 4,
    chunk_token: bytes = b"<|endoftext|>"
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    Main entry point for parallel BPE training.

    Args:
        input_path: path to the input file
        vocab_size: target vocabulary size
        special_tokens: list of special tokens
        max_merges: maximum number of merges
        num_processes: number of worker processes
        chunk_token: token used to align chunk boundaries
    """

    # 1. Process the file in chunks
    chunk_results = process_chunks_parallel(
        input_path, num_processes, chunk_token, special_tokens
    )

    # 2. Merge the per-chunk statistics
    global_word_freqs = merge_chunk_statistics(chunk_results)

    # 3. Train BPE on the global statistics
    vocab, merges = train_bpe_from_statistics(
        global_word_freqs, vocab_size, special_tokens, max_merges
    )

    return vocab, merges

2. Process chunks in parallel

In [None]:
def process_chunk(args):
    """Worker function that processes a single chunk."""
    start, end, input_path, chunk_token, special_tokens = args

    word_freqs = collections.Counter()

    with open(input_path, 'rb') as f:
        f.seek(start)
        chunk_data = f.read(end - start)

        try:
            # Decode the chunk and split it on whitespace
            text = chunk_data.decode('utf-8', errors='ignore')
            words = text.split()

            # Count word frequencies, keyed by each word's UTF-8 bytes
            for word in words:
                byte_word = word.encode('utf-8')
                word_freqs[byte_word] += 1

        except Exception as e:
            print(f"Error while processing chunk [{start}-{end}]: {e}")

    return dict(word_freqs)

def process_chunks_parallel(input_path, num_processes, chunk_token, special_tokens):
    """Process all chunks in parallel."""

    # Find the chunk boundaries
    with open(input_path, 'rb') as f:
        boundaries = find_chunk_boundaries(f, num_processes, chunk_token)

    # Build one argument tuple per chunk
    chunk_args = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        chunk_args.append((start, end, input_path, chunk_token, special_tokens))

    # Run the workers
    with mp.Pool(processes=num_processes) as pool:
        results = pool.map(process_chunk, chunk_args)

    return results

3. Merge the chunk statistics

In [None]:
def merge_chunk_statistics(chunk_results):
    """Merge the word-frequency statistics from all chunks."""
    global_word_freqs = collections.Counter()
    
    for chunk_freqs in chunk_results:
        for word_bytes, freq in chunk_freqs.items():
            global_word_freqs[word_bytes] += freq
    
    return global_word_freqs
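The reduction step in merge_chunk_statistics is a plain sum of per-chunk counts. As a small sketch with made-up chunk results, `Counter.update` performs the same addition in one call per chunk:

```python
from collections import Counter

# Hypothetical per-chunk word counts, keyed by UTF-8 bytes as in process_chunk.
chunk_results = [
    {b"the": 3, b"cat": 1},
    {b"the": 2, b"dog": 4},
]

merged = Counter()
for chunk_freqs in chunk_results:
    merged.update(chunk_freqs)  # adds counts instead of overwriting them

print(merged[b"the"], merged[b"cat"], merged[b"dog"])  # 5 1 4
```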

4. Train BPE from the statistics

In [None]:
def train_bpe_from_statistics(
    word_freqs: Dict[bytes, int],
    vocab_size: int,
    special_tokens: List[str],
    max_merges: int = None
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    Train BPE from word-frequency statistics.
    """

    # Initialise the vocabulary
    vocab = {}
    next_id = 0

    # Base byte tokens (0-255)
    for i in range(256):
        vocab[next_id] = bytes([i])
        next_id += 1

    # Special tokens
    for token in special_tokens:
        vocab[next_id] = token.encode('utf-8')
        next_id += 1

    # Represent each word as a tuple of single-byte tokens
    vocab_stats = {}
    for word_bytes, freq in word_freqs.items():
        tokens = [bytes([b]) for b in word_bytes]
        vocab_stats[tuple(tokens)] = freq

    # BPE training loop
    merges = []
    merge_count = 0

    while len(vocab) < vocab_size:
        if max_merges is not None and merge_count >= max_merges:
            break

        # Count every adjacent token pair, weighted by word frequency
        pair_freqs = collections.Counter()
        for tokens, freq in vocab_stats.items():
            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                pair_freqs[pair] += freq

        if not pair_freqs:
            break

        # Pick the most frequent pair; ties go to the pair seen last in counting order
        max_freq = max(pair_freqs.values())
        best_pairs = [pair for pair, freq in pair_freqs.items() if freq == max_freq]
        best_pair = best_pairs[-1]

        # Create the merged token
        new_token = best_pair[0] + best_pair[1]
        vocab[next_id] = new_token
        next_id += 1

        # Record the merge
        merges.append(best_pair)
        merge_count += 1

        # Rewrite every word, replacing each occurrence of best_pair with new_token
        new_vocab_stats = {}
        for tokens, freq in vocab_stats.items():
            new_tokens = []
            i = 0
            while i < len(tokens):
                if (i < len(tokens) - 1 and
                        tokens[i] == best_pair[0] and
                        tokens[i + 1] == best_pair[1]):
                    new_tokens.append(new_token)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            new_vocab_stats[tuple(new_tokens)] = freq

        vocab_stats = new_vocab_stats

    return vocab, merges

5. The complete combined code

In [None]:
import os
import collections
import multiprocessing as mp
from typing import Dict, List, Tuple, BinaryIO


def find_chunk_boundaries(
        file: BinaryIO,  # file object opened in binary mode
        desired_num_chunks: int,  # desired number of chunks
        split_special_token: bytes,  # token to split on (as bytes)
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    file.seek(0, os.SEEK_END)
    file_size = file.tell()  # total file size in bytes
    file.seek(0)  # rewind

    chunk_size = file_size // desired_num_chunks  # nominal chunk size

    # Start from evenly spaced boundaries
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size  # the last boundary is always the end of the file

    mini_chunk_size = 4096  # search 4 KB at a time

    for bi in range(1, len(chunk_boundaries) - 1):  # adjust the interior boundaries
        initial_position = chunk_boundaries[bi]  # initial guess for this boundary
        file.seek(initial_position)

        while True:
            mini_chunk = file.read(mini_chunk_size)  # read one mini chunk

            if mini_chunk == b"":  # reached the end of the file
                chunk_boundaries[bi] = file_size
                break

            # Look for the split token; one that straddles two mini chunks is not detected
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:  # found: place the boundary just after the token
                chunk_boundaries[bi] = initial_position + found_at + len(split_special_token)
                break

            initial_position += mini_chunk_size  # keep searching forward

    return sorted(set(chunk_boundaries))  # deduplicate and sort


def process_chunk(args):
    """Worker function that processes a single chunk."""
    start, end, input_path, chunk_token, special_tokens = args

    word_freqs = collections.Counter()

    with open(input_path, 'rb') as f:
        f.seek(start)
        chunk_data = f.read(end - start)

        try:
            # Decode the chunk and split it on whitespace
            text = chunk_data.decode('utf-8', errors='ignore')
            words = text.split()

            # Count word frequencies, keyed by each word's UTF-8 bytes
            for word in words:
                byte_word = word.encode('utf-8')
                word_freqs[byte_word] += 1

        except Exception as e:
            print(f"Error while processing chunk [{start}-{end}]: {e}")

    return dict(word_freqs)


def process_chunks_parallel(input_path, num_processes, chunk_token, special_tokens):
    """Process all chunks in parallel."""

    # Find the chunk boundaries
    with open(input_path, 'rb') as f:
        boundaries = find_chunk_boundaries(f, num_processes, chunk_token)

    print(f"Chunk boundaries: {boundaries}")

    # Build one argument tuple per chunk
    chunk_args = []
    for start, end in zip(boundaries[:-1], boundaries[1:]):
        chunk_args.append((start, end, input_path, chunk_token, special_tokens))

    print(f"Created {len(chunk_args)} chunk tasks")

    # Run the workers
    with mp.Pool(processes=num_processes) as pool:
        results = pool.map(process_chunk, chunk_args)

    return results


def merge_chunk_statistics(chunk_results):
    """Merge the word-frequency statistics from all chunks."""
    global_word_freqs = collections.Counter()

    for chunk_freqs in chunk_results:
        for word_bytes, freq in chunk_freqs.items():
            global_word_freqs[word_bytes] += freq

    return global_word_freqs


def train_bpe_from_statistics(
        word_freqs: Dict[bytes, int],
        vocab_size: int,
        special_tokens: List[str],
        max_merges: int = None
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    Train BPE from word-frequency statistics.
    """

    # Initialise the vocabulary
    vocab = {}
    next_id = 0

    # Base byte tokens (0-255)
    for i in range(256):
        vocab[next_id] = bytes([i])
        next_id += 1

    # Special tokens
    for token in special_tokens:
        vocab[next_id] = token.encode('utf-8')
        next_id += 1

    # Represent each word as a tuple of single-byte tokens
    vocab_stats = {}
    for word_bytes, freq in word_freqs.items():
        tokens = [bytes([b]) for b in word_bytes]
        vocab_stats[tuple(tokens)] = freq

    # BPE training loop
    merges = []
    merge_count = 0

    while len(vocab) < vocab_size:
        if max_merges is not None and merge_count >= max_merges:
            break

        # Count every adjacent token pair, weighted by word frequency
        pair_freqs = collections.Counter()
        for tokens, freq in vocab_stats.items():
            for i in range(len(tokens) - 1):
                pair = (tokens[i], tokens[i + 1])
                pair_freqs[pair] += freq

        if not pair_freqs:
            break

        # Pick the most frequent pair; ties go to the pair seen last in counting order
        max_freq = max(pair_freqs.values())
        best_pairs = [pair for pair, freq in pair_freqs.items() if freq == max_freq]
        best_pair = best_pairs[-1]

        # Create the merged token
        new_token = best_pair[0] + best_pair[1]
        vocab[next_id] = new_token
        next_id += 1

        # Record the merge
        merges.append(best_pair)
        merge_count += 1

        # Rewrite every word, replacing each occurrence of best_pair with new_token
        new_vocab_stats = {}
        for tokens, freq in vocab_stats.items():
            new_tokens = []
            i = 0
            while i < len(tokens):
                if (i < len(tokens) - 1 and
                        tokens[i] == best_pair[0] and
                        tokens[i + 1] == best_pair[1]):
                    new_tokens.append(new_token)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            new_vocab_stats[tuple(new_tokens)] = freq

        vocab_stats = new_vocab_stats

    return vocab, merges


def parallel_bpe_training(
        input_path: str,
        vocab_size: int,
        special_tokens: List[str],
        max_merges: int = None,
        num_processes: int = 4,
        chunk_token: bytes = b"<|endoftext|>"
) -> Tuple[Dict[int, bytes], List[Tuple[bytes, bytes]]]:
    """
    The complete parallel BPE training pipeline.
    """
    print("Starting parallel BPE training...")
    print(f"File: {input_path}")
    print(f"Processes: {num_processes}")

    # 1. Process the chunks in parallel
    print("Stage 1: processing chunks in parallel...")
    chunk_results = process_chunks_parallel(
        input_path, num_processes, chunk_token, special_tokens
    )

    # 2. Merge the statistics
    print("Stage 2: merging statistics...")
    global_word_freqs = merge_chunk_statistics(chunk_results)
    print(f"Counted {len(global_word_freqs)} unique words")

    # 3. Train BPE
    print("Stage 3: BPE training...")
    vocab, merges = train_bpe_from_statistics(
        global_word_freqs, vocab_size, special_tokens, max_merges
    )

    print(f"Training finished! Vocabulary size: {len(vocab)}, merges: {len(merges)}")
    return vocab, merges


def test_parallel_bpe():
    """Test parallel BPE training."""
    user_dir = os.path.expanduser('~')
    file_path = os.path.join(user_dir, 'Downloads', 'tinystories_sample.txt')

    # Report the file size
    file_size = os.path.getsize(file_path)
    print(f"Test file size: {file_size} bytes")

    # Train in parallel
    vocab, merges = parallel_bpe_training(
        input_path=file_path,
        vocab_size=10000,
        special_tokens=['<pad>', '<unk>', '<s>', '</s>'],
        max_merges=None,
        num_processes=2
    )

    print("\nTraining results:")
    print(f"Vocabulary size: {len(vocab)}")
    print(f"Number of merges: {len(merges)}")
    print(merges)
    print(vocab)

if __name__ == "__main__":
    # Note: this guard is required when using multiprocessing on Windows
    test_parallel_bpe()

## implemented tokenizer
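Byte-level BPE encoders conventionally replay the learned merges in training order: among all adjacent pairs in the sequence, the pair learned earliest is merged first. The sketch below illustrates that rule on its own; `apply_merges` and the toy merge list are hypothetical names and data chosen for this example.

```python
from typing import Dict, List, Tuple

def apply_merges(tokens: List[bytes], ranks: Dict[Tuple[bytes, bytes], int]) -> List[bytes]:
    """Repeatedly merge the adjacent pair with the lowest rank (learned earliest)."""
    tokens = list(tokens)
    while len(tokens) > 1:
        best = min(zip(tokens, tokens[1:]), key=lambda p: ranks.get(p, float("inf")))
        if best not in ranks:
            break  # no adjacent pair is mergeable
        merged, i = [], 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best:
                merged.append(tokens[i] + tokens[i + 1])
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

# (b, c) was learned before (a, b), so it must be applied first.
merges = [(b"b", b"c"), (b"a", b"b")]
ranks = {pair: i for i, pair in enumerate(merges)}

result = apply_merges([b"a", b"b", b"c"], ranks)
print(result)  # [b'a', b'bc']
```

A plain left-to-right scan that merges the first mergeable pair it meets would produce `[b'ab', b'c']` here, which does not match what training saw.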

In [None]:
import collections
from typing import Dict, List, Tuple, Iterable, Iterator, Optional,BinaryIO
import os
import multiprocessing as mp
import time

# find_chunk_boundaries, process_chunk, process_chunks_parallel, merge_chunk_statistics,
# train_bpe_from_statistics and parallel_bpe_training are reused unchanged from the
# combined cell in the train_bpe section above.


class Tokenizer:
    """A BPE tokenizer."""

    def __init__(self, vocab: Dict[int, bytes], merges: List[Tuple[bytes, bytes]],
                 special_tokens: Optional[List[str]] = None):
        """
        Construct a tokenizer from a vocabulary, a merge list, and special tokens.

        Args:
            vocab: mapping from token ID to bytes
            merges: list of BPE merges
            special_tokens: list of special tokens
        """
        self.vocab = vocab.copy()  # copy so the caller's vocabulary is not modified
        self.merges = merges.copy()

        # Inverse vocabulary (bytes -> ID)
        self.vocab_inv = {token: idx for idx, token in self.vocab.items()}

        # Special tokens
        self.special_tokens = special_tokens or []
        self.special_token_ids = {}

        # Add any special token that is not already in the vocabulary
        for token in self.special_tokens:
            token_bytes = token.encode('utf-8')
            if token_bytes not in self.vocab_inv:
                # Assign a new ID
                new_id = max(self.vocab.keys()) + 1
                self.vocab[new_id] = token_bytes
                self.vocab_inv[token_bytes] = new_id
            self.special_token_ids[token] = self.vocab_inv[token_bytes]

        # Merge priority: merges learned earlier get a lower rank
        self.merge_priority = {}
        for i, (a, b) in enumerate(self.merges):
            self.merge_priority[(a, b)] = i

    @classmethod
    def from_files(cls, vocab_filepath: str, merges_filepath: str, special_tokens: Optional[List[str]] = None):
        """
        Construct a tokenizer from a vocabulary file and a merges file.

        Args:
            vocab_filepath: path to the vocabulary file
            merges_filepath: path to the merges file
            special_tokens: list of special tokens
        """

        def parse_bytes(text: str) -> bytes:
            """Parse a b"..." or b'...' literal, a hex string, or raw text into bytes."""
            if ((text.startswith('b"') and text.endswith('"')) or
                    (text.startswith("b'") and text.endswith("'"))):
                return text[2:-1].encode('utf-8').decode('unicode_escape').encode('latin1')
            try:
                # Otherwise assume hex; raw text made only of hex digits is parsed as hex too
                return bytes.fromhex(text)
            except ValueError:
                # Not valid hex: treat it as raw text
                return text.encode('utf-8')

        vocab = {}
        merges = []

        # Load the vocabulary: each line is "<id> <bytes>"
        with open(vocab_filepath, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    vocab[int(parts[0])] = parse_bytes(parts[1])

        # Load the merges: each line is "<bytes> <bytes>"
        with open(merges_filepath, 'r', encoding='utf-8') as f:
            for line in f:
                parts = line.strip().split()
                if len(parts) >= 2:
                    merges.append((parse_bytes(parts[0]), parse_bytes(parts[1])))

        return cls(vocab, merges, special_tokens)

    def encode(self, text: str) -> List[int]:
        """
        Encode the input text into a sequence of token IDs.

        Args:
            text: input text

        Returns:
            list of token IDs
        """
        # Convert the text to bytes
        byte_sequence = text.encode('utf-8')

        # Initial tokenization: one token per byte
        tokens = [bytes([b]) for b in byte_sequence]

        # Apply the BPE merges
        tokens = self._apply_merges(tokens)

        # Convert tokens to IDs
        token_ids = []
        for token in tokens:
            if token in self.vocab_inv:
                token_ids.append(self.vocab_inv[token])
            else:
                # Unknown token: use <unk> if it is defined, otherwise fall back to bytes
                if self.special_tokens and '<unk>' in self.special_token_ids:
                    token_ids.append(self.special_token_ids['<unk>'])
                else:
                    # Fall back to encoding each byte separately
                    for b in token:
                        byte_token = bytes([b])
                        token_ids.append(self.vocab_inv[byte_token])

        return token_ids

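    def _apply_merges(self, tokens: List[bytes]) -> List[bytes]:
        """Apply the BPE merges in the order they were learned (lowest rank first)."""
        if not self.merges:
            return tokens

        tokens = list(tokens)
        while len(tokens) > 1:
            # Among all adjacent pairs, pick the one learned earliest during training
            best_pair = min(
                zip(tokens, tokens[1:]),
                key=lambda pair: self.merge_priority.get(pair, float('inf')),
            )
            if best_pair not in self.merge_priority:
                break  # no adjacent pair is mergeable

            # Merge every occurrence of best_pair, scanning left to right
            merged = best_pair[0] + best_pair[1]
            new_tokens = []
            i = 0
            while i < len(tokens):
                if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == best_pair:
                    new_tokens.append(merged)
                    i += 2
                else:
                    new_tokens.append(tokens[i])
                    i += 1
            tokens = new_tokens

        return tokens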

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        """
        Encode an iterable of strings, lazily yielding token IDs.

        Args:
            iterable: an iterable of strings

        Yields:
            Token IDs, one at a time
        """
        for text in iterable:
            yield from self.encode(text)

    def decode(self, ids: List[int]) -> str:
        """
        Decode a sequence of token IDs back into text.

        Args:
            ids: list of token IDs

        Returns:
            The decoded text
        """
        replacement_char = b'\xef\xbf\xbd'  # UTF-8 encoding of U+FFFD

        # Unknown IDs decode to the replacement character
        byte_sequence = b''.join(self.vocab.get(token_id, replacement_char) for token_id in ids)

        # Invalid UTF-8 (e.g. a truncated multi-byte character) also becomes U+FFFD
        return byte_sequence.decode('utf-8', errors='replace')


# Test function
def test_tokenizer():
    """Train a tokenizer on a sample file, then report compression ratio and throughput."""

    # Locate the sample corpus
    user_dir = os.path.expanduser('~')
    file_path = os.path.join(user_dir, 'Downloads', 'tinystories_sample.txt')

    # Check the file size
    file_size = os.path.getsize(file_path)
    print(f"Test file size: {file_size} bytes")

    # Train in parallel
    special_tokens = ['<pad>', '<unk>', '<s>', '</s>']
    vocab, merges = parallel_bpe_training(
        input_path=file_path,
        vocab_size=10000,
        special_tokens=special_tokens,
        max_merges=None,
        num_processes=2
    )
    # Build the tokenizer
    tokenizer = Tokenizer(vocab, merges, special_tokens)

    def calculate_compression_ratio(text, tokenizer):
        # Original size, in characters
        original_size = len(text)

        # Encode into tokens
        tokens = tokenizer.encode(text)

        # Encoded size, in tokens
        encoded_size = len(tokens)

        # Compression ratio = original size / encoded size
        compression_ratio = original_size / encoded_size if encoded_size > 0 else 0

        return original_size, encoded_size, compression_ratio

    def calculate_throughput(text, tokenizer, iterations=100):
        # Encoding throughput (perf_counter is a monotonic, high-resolution clock)
        start_time = time.perf_counter()
        for _ in range(iterations):
            _ = tokenizer.encode(text)
        encode_time = time.perf_counter() - start_time

        tokens = tokenizer.encode(text)

        # Decoding throughput
        start_time = time.perf_counter()
        for _ in range(iterations):
            _ = tokenizer.decode(tokens)
        decode_time = time.perf_counter() - start_time

        # Characters processed per second
        encode_throughput = (len(text) * iterations) / encode_time
        decode_throughput = (len(text) * iterations) / decode_time

        # Tokens processed per second
        encode_token_throughput = (len(tokens) * iterations) / encode_time
        decode_token_throughput = (len(tokens) * iterations) / decode_time

        return (encode_throughput, decode_throughput,
                encode_token_throughput, decode_token_throughput)
    with open(file_path, 'r', encoding='utf-8') as f:
        file_content = f.read()
    # Encode and decode the text
    encoded_tokens = tokenizer.encode(file_content)
    decoded_text = tokenizer.decode(encoded_tokens)

    # Compute the compression ratio
    original_size, encoded_size, compression_ratio = calculate_compression_ratio(file_content, tokenizer)

    # Compute the throughput
    encode_throughput, decode_throughput, encode_token_throughput, decode_token_throughput = calculate_throughput(
        file_content[:10000], tokenizer)

    print("\n" + "=" * 50)
    print("Tokenizer performance results")
    print("=" * 50)

    print(f"Original text size: {original_size} characters")
    print(f"Encoded token count: {encoded_size} tokens")
    print(f"Compression ratio: {compression_ratio:.2f} (characters/token)")

    print("\nThroughput test:")
    print(f"Encoding throughput: {encode_throughput:.2f} characters/s")
    print(f"Decoding throughput: {decode_throughput:.2f} characters/s")
    print(f"Encoding throughput: {encode_token_throughput:.2f} tokens/s")
    print(f"Decoding throughput: {decode_token_throughput:.2f} tokens/s")

    # Test file saving and loading (simulated)
    #print("\nTesting file I/O simulation...")

    # Build a second tokenizer to test encode_iterable
    #simple_vocab = {i: bytes([i]) for i in range(256)}
    #simple_vocab[256] = b'test'
    #simple_merges = []
    #simple_tokenizer = Tokenizer(simple_vocab, simple_merges)

    # Test encode_iterable
    #texts = ["hello", "world", "test"]
    #print("Testing encode_iterable:")
    #for token_id in simple_tokenizer.encode_iterable(texts):
    #    print(f"  Yielded token ID: {token_id}")


if __name__ == "__main__":
    test_tokenizer()

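Results:
Original text size: 3786 characters
Encoded token count: 1881 tokens
Compression ratio: 2.01 (characters/token)

Throughput test:
Encoding throughput: 9363.55 characters/s
Decoding throughput: 14811296.15 characters/s
Encoding throughput: 4652.10 tokens/s
Decoding throughput: 7358702.61 tokens/s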
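
Encoding is roughly three orders of magnitude slower than decoding in these numbers. Decoding is a dictionary lookup per token, while encoding has to search for applicable merges. The merge order itself can be illustrated in isolation. The following is a minimal standalone sketch, not the class method above: `apply_merges` and the toy merge list are hypothetical, and it assumes that a merge's rank is simply its index in the merge list.

```python
from typing import Dict, List, Tuple

def apply_merges(tokens: List[bytes], merges: List[Tuple[bytes, bytes]]) -> List[bytes]:
    """Repeatedly merge the adjacent pair with the lowest rank (learned earliest)."""
    ranks: Dict[Tuple[bytes, bytes], int] = {pair: i for i, pair in enumerate(merges)}
    tokens = list(tokens)
    while len(tokens) > 1:
        # Keep only the adjacent pairs that correspond to a known merge.
        candidates = [pair for pair in zip(tokens, tokens[1:]) if pair in ranks]
        if not candidates:
            break
        a, b = min(candidates, key=ranks.__getitem__)
        # Merge every non-overlapping occurrence of the chosen pair.
        merged, i = [], 0
        while i < len(tokens):
            if i + 1 < len(tokens) and tokens[i] == a and tokens[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(tokens[i])
                i += 1
        tokens = merged
    return tokens

# Hypothetical merge list, in the order a trainer might have learned it.
toy_merges = [(b"t", b"h"), (b"th", b"e"), (b"h", b"e")]
toy_tokens = [bytes([b]) for b in "the he".encode("utf-8")]
result = apply_merges(toy_tokens, toy_merges)
print(result)  # [b'the', b' ', b'he']
```

In "the", the pair (t, h) is merged before (h, e) because it has the lower rank, so the result is `the` rather than `t` + `he`. A purely left-to-right scan that ignores rank can produce a different segmentation from the one seen during training.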


## Linear

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

class Linear(nn.Module):
    def __init__(self, in_features, out_features, device=None, dtype=None):
        """
        Args:
            in_features: int - size of the last dimension of the input
            out_features: int - size of the last dimension of the output
            device: torch.device | None = None - device on which to store the parameters
            dtype: torch.dtype | None = None - data type of the parameters
        """
        # Call the parent constructor (once)
        super().__init__()

        # Store the dimensions
        self.in_features = in_features
        self.out_features = out_features

        # Only forward device/dtype when they were given
        factory_kwargs = {}
        if device is not None:
            factory_kwargs['device'] = device
        if dtype is not None:
            factory_kwargs['dtype'] = dtype

        # Create the weight parameter W (shape: out_features × in_features)
        self.weight = nn.Parameter(torch.empty((out_features, in_features), **factory_kwargs))

        # Initialize the weight from a truncated normal distribution
        self.reset_parameters()

    def reset_parameters(self):
        """Initialize the weight from a truncated normal distribution."""
        init.trunc_normal_(self.weight, mean=0.0, std=1.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the linear transformation to the input.

        Args:
            x: torch.Tensor - input tensor of shape (..., in_features)

        Returns:
            torch.Tensor - output tensor of shape (..., out_features)
        """
        # Linear transformation: output = x @ W^T
        # W has shape (out_features, in_features), so it is transposed to
        # (in_features, out_features) for the matrix multiplication
        return x @ self.weight.t()
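
A quick shape check makes the `(out_features, in_features)` weight layout concrete. This sketch uses a bare weight tensor instead of the class above, so it stays self-contained; the sizes are arbitrary, and the comparison against `torch.nn.functional.linear` is only a sanity check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
in_features, out_features = 4, 3
weight = torch.randn(out_features, in_features)  # same layout as Linear.weight
x = torch.randn(2, 5, in_features)               # arbitrary leading dimensions

# The matmul broadcasts over the leading dimensions of x.
y = x @ weight.t()

assert y.shape == (2, 5, out_features)
assert torch.allclose(y, F.linear(x, weight), atol=1e-6)
```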

## Embedding

In [None]:
import torch
import torch.nn as nn
import torch.nn.init as init

class Embedding(nn.Module):
    def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        
        # Allocate the weight matrix
        self.weight = nn.Parameter(torch.empty((num_embeddings, embedding_dim), 
                                              device=device, dtype=dtype))
        
        # Initialize the weight from a truncated normal distribution
        init.trunc_normal_(self.weight, mean=0.0, std=1.0)
    
    def forward(self, token_ids: torch.Tensor) -> torch.Tensor:
        # Look up the embedding vector for each token ID
        return self.weight[token_ids]
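
Indexing the weight with an integer tensor is all the lookup needs: the result takes the shape of `token_ids` with `embedding_dim` appended. A small sketch with a bare weight tensor and arbitrary sizes; the comparison with `torch.nn.functional.embedding` is a sanity check only.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
weight = torch.randn(10, 4)                        # (num_embeddings, embedding_dim)
token_ids = torch.tensor([[1, 2, 3], [7, 7, 0]])   # (batch, seq_len), int64

looked_up = weight[token_ids]                      # (batch, seq_len, embedding_dim)

assert looked_up.shape == (2, 3, 4)
assert torch.equal(looked_up, F.embedding(token_ids, weight))
# Identical IDs map to identical vectors.
assert torch.equal(looked_up[1, 0], looked_up[1, 1])
```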

## RMSNorm

In [None]:
import torch
import torch.nn as nn

class RMSNorm(nn.Module):
    def __init__(self, d_model: int, eps: float = 1e-5, device=None, dtype=None):
        super().__init__()
        self.d_model = d_model
        self.eps = eps
        
        # Learnable scale parameter gamma, initialized to ones
        factory_kwargs = {'device': device, 'dtype': dtype}
        self.gamma = nn.Parameter(torch.ones(d_model, **factory_kwargs))
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Remember the input dtype so the result can be cast back
        original_dtype = x.dtype
        
        # Upcast to float32 for the normalization
        x_float = x.float()
        
        # Root mean square (RMS) over the last dimension (d_model);
        # keepdim=True keeps that dimension so the result broadcasts
        rms = torch.sqrt(torch.mean(x_float ** 2, dim=-1, keepdim=True) + self.eps)
        
        # Normalize by the RMS
        x_normalized = x_float / rms
        
        # Apply the scale parameter gamma
        x_normalized = x_normalized * self.gamma
        
        # Cast back to the original dtype
        return x_normalized.to(original_dtype)
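
One way to sanity-check the normalization: with `gamma` set to ones, every output vector should have a root mean square of about 1, whatever the input scale. The sketch below restates the forward pass as a hypothetical free function `rms_norm` so it runs on its own.

```python
import torch

def rms_norm(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Functional restatement of the forward pass above (illustrative helper)."""
    x_float = x.float()
    rms = torch.sqrt(torch.mean(x_float ** 2, dim=-1, keepdim=True) + eps)
    return (x_float / rms * gamma).to(x.dtype)

torch.manual_seed(0)
x = torch.randn(2, 3, 8) * 50.0           # deliberately large scale
y = rms_norm(x, torch.ones(8))

# Each vector along the last dimension now has RMS close to 1.
row_rms = y.pow(2).mean(dim=-1).sqrt()
assert torch.allclose(row_rms, torch.ones(2, 3), atol=1e-4)

# The output dtype follows the input dtype, even though the math runs in float32.
y_bf16 = rms_norm(x.to(torch.bfloat16), torch.ones(8))
assert y_bf16.dtype == torch.bfloat16
```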

## Softmax

In [None]:
import torch

def run_softmax(tensor, dim):
    """
    Apply softmax to the input tensor along the given dimension.

    Args:
        tensor (torch.Tensor): input tensor.
        dim (int): dimension along which to apply softmax.

    Returns:
        torch.Tensor: tensor of the same shape as the input; values along
        `dim` form a probability distribution.
    """
    # Subtract the maximum along `dim` for numerical stability
    max_vals = torch.max(tensor, dim=dim, keepdim=True).values
    shifted = tensor - max_vals
    exp_tensor = torch.exp(shifted)
    sum_exp = torch.sum(exp_tensor, dim=dim, keepdim=True)
    softmax_tensor = exp_tensor / sum_exp
    return softmax_tensor

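Why subtract the maximum: when the raw values are large, exp(x) overflows the floating-point range. Subtracting the maximum leaves the result unchanged, because the common factor exp(-max) cancels between numerator and denominator, and it guarantees that every exponent is at most 0.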
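
The overflow is easy to reproduce. In this sketch, `naive_softmax` and `stable_softmax` are illustrative helpers rather than part of the assignment code; the logits are chosen so that exp overflows in float32.

```python
import torch

def naive_softmax(t: torch.Tensor, dim: int) -> torch.Tensor:
    e = torch.exp(t)                       # exp(1000.) is inf in float32
    return e / e.sum(dim=dim, keepdim=True)

def stable_softmax(t: torch.Tensor, dim: int) -> torch.Tensor:
    shifted = t - t.max(dim=dim, keepdim=True).values   # largest entry becomes 0
    e = torch.exp(shifted)
    return e / e.sum(dim=dim, keepdim=True)

logits = torch.tensor([[1000.0, 1001.0, 1002.0]])
naive = naive_softmax(logits, dim=-1)      # inf / inf -> nan everywhere
stable = stable_softmax(logits, dim=-1)

assert torch.isnan(naive).all()
assert torch.allclose(stable, torch.softmax(logits, dim=-1))
# Shifting all logits by a constant does not change the result.
assert torch.allclose(stable, stable_softmax(torch.tensor([[0.0, 1.0, 2.0]]), dim=-1))
```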

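Keeping the dimension (keepdim=True) means that after a reduction such as a sum or a max, the reduced dimension is kept with size 1 instead of being removed, so the result still has the same number of dimensions as the input.
Suppose we have a 2×3 tensor:
tensor = torch.tensor([[1, 2, 3],
                       [4, 5, 6]])
Case 1: dimension not kept (keepdim=False, the default)
max_vals = torch.max(tensor, dim=1)  # maximum of each row (reduces dim=1)
print(max_vals.values)  # output: tensor([3, 6])
print(max_vals.values.shape)  # output: torch.Size([2])

The 2×3 tensor becomes the 1-D tensor [3, 6]: the number of dimensions drops from 2 to 1.

Case 2: dimension kept (keepdim=True)

max_vals = torch.max(tensor, dim=1, keepdim=True)
print(max_vals.values)  # output: tensor([[3], [6]])
print(max_vals.values.shape)  # output: torch.Size([2, 1])

The result is a 2×1 tensor, still 2-D, so it broadcasts directly against the original 2×3 tensor.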

## positionwise_feedforward

In [None]:
import torch
import torch.nn as nn

class SwiGLUFeedForward(nn.Module):
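    def __init__(self, d_model, dropout=0.1):
        super().__init__()
        
        # Hidden size: about (8/3) * d_model, rounded to a multiple of 64
        d_ff_raw = (8/3) * d_model
        d_ff = int(round(d_ff_raw / 64)) * 64
        
        # First linear layer expands to 2 * d_ff (split into two halves for the GLU)
        self.linear1 = nn.Linear(d_model, 2 * d_ff)
        
        # Second linear layer projects back to the model dimension
        self.linear2 = nn.Linear(d_ff, d_model)
        
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x):
        # First projection produces both halves at once
        x_proj = self.linear1(x)
        
        # Split into the activation branch and the gate branch
        x1, x2 = x_proj.chunk(2, dim=-1)
        
        # SiLU (Swish) written out by hand: x * sigmoid(x)
        swish = x1 * torch.sigmoid(x1)
        
        # SwiGLU: multiply the SiLU branch element-wise by the second branch,
        # which stays linear (applying a sigmoid to it would not be SwiGLU)
        glu_output = swish * x2
        
        # Apply dropout
        glu_output = self.dropout(glu_output)
        
        # Second projection back to the model dimension
        output = self.linear2(glu_output)
        
        return output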

# Adapter function for the tests
def run_swiglu(input_tensor, d_model, dropout=0.1):
    """
    Args:
        input_tensor: input tensor
        d_model: model dimension
        dropout: dropout rate
    Returns:
        output_tensor: output tensor
    """
    swiglu = SwiGLUFeedForward(d_model, dropout)
    return swiglu(input_tensor)
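
The hidden-size rule is easier to follow with concrete numbers. The helper below, `swiglu_hidden_dim`, is a hypothetical name that restates the rounding used in the constructor. Note that `round` picks the nearest multiple of 64, not the next one up, so very small model dimensions can round down to 0.

```python
def swiglu_hidden_dim(d_model: int, multiple: int = 64) -> int:
    """Nearest multiple of `multiple` to (8/3) * d_model, as in the constructor above."""
    return int(round((8 / 3) * d_model / multiple)) * multiple

for d_model in (256, 512, 768):
    print(d_model, swiglu_hidden_dim(d_model))
# 256 704
# 512 1344
# 768 2048

# Edge case: (8/3) * 8 is about 21.3, which is closer to 0 than to 64.
assert swiglu_hidden_dim(8) == 0
```

If small models matter, clamping the result with something like `max(multiple, ...)` would avoid a zero-width hidden layer; that guard is a suggestion, not part of the code above.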

## RoPE

In [None]:
import torch
import torch.nn as nn
import math

class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, theta: float, d_k: int, max_seq_len: int, device=None):
        super().__init__()
        self.theta = theta
        self.d_k = d_k
        self.max_seq_len = max_seq_len
        
        # Precompute the sine and cosine tables
        positions = torch.arange(max_seq_len, device=device).float()
        indices = torch.arange(0, d_k, 2, device=device).float()
        
        # Angle for position pos and pair i: pos * theta^(-2i/d_k)
        theta_powered = theta ** (-indices / d_k)
        angles = positions.unsqueeze(1) * theta_powered.unsqueeze(0)  # (max_seq_len, d_k/2)
        
        # Sine and cosine of every angle
        cos_vals = torch.cos(angles)  # (max_seq_len, d_k/2)
        sin_vals = torch.sin(angles)  # (max_seq_len, d_k/2)
        
        # Register as buffers: they move with the module across devices and need no gradient;
        # persistent=False keeps them out of the state_dict
        self.register_buffer('cos_cached', cos_vals, persistent=False)
        self.register_buffer('sin_cached', sin_vals, persistent=False)

    def forward(self, x: torch.Tensor, token_positions: torch.Tensor) -> torch.Tensor:
        """
        Apply rotary position embeddings to the input tensor.

        Args:
            x: input tensor of shape (..., seq_len, d_k)
            token_positions: integer tensor of token positions, shape (..., seq_len);
                its leading dimensions must broadcast against those of x

        Returns:
            Tensor with RoPE applied, same shape as x
        """
        # Split the last dimension into d_k/2 pairs of adjacent features
        x_pairs = x.reshape(*x.shape[:-1], self.d_k // 2, 2)
        x_even, x_odd = x_pairs[..., 0], x_pairs[..., 1]  # each (..., seq_len, d_k/2)

        # Look up the precomputed values. Indexing keeps the shape of token_positions,
        # so a plain (seq_len,) position tensor broadcasts over any batch dimensions
        cos_vals = self.cos_cached[token_positions]  # (..., seq_len, d_k/2)
        sin_vals = self.sin_cached[token_positions]  # (..., seq_len, d_k/2)

        # Rotate each pair (x_{2i}, x_{2i+1}) by its position-dependent angle
        x_rotated = torch.stack([
            x_even * cos_vals - x_odd * sin_vals,
            x_even * sin_vals + x_odd * cos_vals
        ], dim=-1)

        # Merge the pairs back into the original shape
        return x_rotated.reshape(x.shape)

Shape of the tensor x: (batch_size, ..., seq_len, d_k)
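
Two properties are worth checking for any RoPE implementation: the rotation preserves vector norms, and the dot product between a rotated query and a rotated key depends only on their relative position. The sketch below uses a hypothetical standalone function `rope` with the same adjacent-pair convention as the class above and an assumed default of `theta=10000.0`.

```python
import torch

def rope(x: torch.Tensor, positions: torch.Tensor, theta: float = 10000.0) -> torch.Tensor:
    """x: (..., seq_len, d_k); positions: (seq_len,) integer positions."""
    d_k = x.shape[-1]
    freqs = theta ** (-torch.arange(0, d_k, 2).float() / d_k)   # (d_k/2,)
    angles = positions.float().unsqueeze(-1) * freqs             # (seq_len, d_k/2)
    cos, sin = torch.cos(angles), torch.sin(angles)
    x_even, x_odd = x[..., 0::2], x[..., 1::2]                   # adjacent pairs
    out = torch.stack([x_even * cos - x_odd * sin,
                       x_even * sin + x_odd * cos], dim=-1)
    return out.flatten(-2)                                       # back to (..., seq_len, d_k)

torch.manual_seed(0)
x = torch.randn(2, 5, 8)
rotated = rope(x, torch.arange(5))

# 1. A rotation does not change the length of any vector.
assert torch.allclose(rotated.norm(dim=-1), x.norm(dim=-1), atol=1e-5)

# 2. The score depends only on the offset between query and key positions.
q, k = torch.randn(1, 8), torch.randn(1, 8)
score_a = (rope(q, torch.tensor([3])) * rope(k, torch.tensor([1]))).sum()
score_b = (rope(q, torch.tensor([10])) * rope(k, torch.tensor([8]))).sum()
assert torch.allclose(score_a, score_b, atol=1e-4)
```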

## scaled_dot_product_attention

In [None]:
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(query, key, value, mask=None):
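    """
    Args:
        query: tensor of shape (batch_size, ..., seq_len_q, d_k)
        key: tensor of shape (batch_size, ..., seq_len_k, d_k)
        value: tensor of shape (batch_size, ..., seq_len_v, d_v) (seq_len_k must equal seq_len_v)
        mask: optional boolean tensor of shape (seq_len_q, seq_len_k);
            True marks positions that may be attended to

    Returns:
        Tensor of shape (batch_size, ..., seq_len_q, d_v)
    """
    # d_k is the size of the last dimension of query (the same as for key)
    d_k = query.size(-1)
    
    # Dot products, scaled by sqrt(d_k)
    scores = torch.matmul(query, key.transpose(-2, -1)) / (d_k ** 0.5)
    
    # Apply the mask if one was given
    if mask is not None:
        # Positions where the mask is False get the most negative value of the dtype,
        # so their softmax weight is effectively zero (a literal -1e9 overflows float16)
        scores = scores.masked_fill(mask.logical_not(), torch.finfo(scores.dtype).min)
    
    # Softmax over the last dimension (the key sequence) gives the attention weights
    attention_weights = F.softmax(scores, dim=-1)
    
    # Weighted sum of the values
    output = torch.matmul(attention_weights, value)
    
    return output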
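
A causal mask is the usual way this function is called in a decoder. The sketch below uses a compact restatement named `attention` (a hypothetical helper, so the block runs on its own) and checks one consequence of causality: the first position can only attend to itself, so its output equals the first value vector.

```python
import math
import torch

def attention(q, k, v, mask=None):
    """Compact restatement of scaled dot-product attention (illustrative)."""
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))
    if mask is not None:
        # False entries are blocked; True entries may be attended to.
        scores = scores.masked_fill(mask.logical_not(), torch.finfo(scores.dtype).min)
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
seq_len, d_k = 4, 8
q = torch.randn(2, seq_len, d_k)
k = torch.randn(2, seq_len, d_k)
v = torch.randn(2, seq_len, d_k)

# Lower-triangular mask: position i may attend to positions 0..i.
causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
out = attention(q, k, v, causal_mask)

assert out.shape == (2, seq_len, d_k)
# Position 0 sees only itself, so its output is exactly v at position 0.
assert torch.allclose(out[:, 0], v[:, 0], atol=1e-6)
```

The `(seq_len_q, seq_len_k)` mask broadcasts over the batch dimension, which is why a single mask serves every sequence in the batch.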

## multihead_self_attention