In [None]:
# test.py
import collections
import multiprocessing
import typing as t

# Assuming these are defined elsewhere or passed in
# For demonstration, let's provide dummy implementations
def get_pair_counts(
    tokens: t.List[int],
    p_counts: t.Optional[t.Dict[tuple[int, int], int]] = None,
) -> t.Dict[tuple[int, int], int]:
    """Count adjacent token pairs, accumulating into ``p_counts``.

    Args:
        tokens: Sequence of token ids; pairs are ``(tokens[k], tokens[k + 1])``.
        p_counts: Optional accumulator mapping pair -> count. It is updated in
            place and returned. When omitted (or ``None``), a fresh dict is
            created, so callers can simply write ``get_pair_counts(tokens)``.
            For multiprocessing, give each worker its own dict and combine the
            returned dicts afterwards.

    Returns:
        The accumulator dict (the same object as ``p_counts`` when one is given).
    """
    # A None default (rather than ``{}``) avoids sharing one dict between calls.
    if p_counts is None:
        p_counts = {}
    for pair in zip(tokens, tokens[1:]):
        p_counts[pair] = p_counts.get(pair, 0) + 1
    return p_counts

def merge_pair(tokens: t.List[int], pair_to_merge: tuple[int, int], new_token_id: int) -> t.List[int]:
    """Return a copy of ``tokens`` with every non-overlapping ``pair_to_merge`` replaced.

    The scan runs left to right. A matched pair is replaced by ``new_token_id``,
    and both of its tokens are consumed, so ``[a, a, a]`` merged on ``(a, a)``
    gives ``[new, a]``.
    """
    result: t.List[int] = []
    total = len(tokens)
    idx = 0
    while idx < total:
        is_match = idx + 1 < total and (tokens[idx], tokens[idx + 1]) == pair_to_merge
        result.append(new_token_id if is_match else tokens[idx])
        idx += 2 if is_match else 1
    return result

# --- Your main class/function where the loop resides ---

class BPEProcessor:
    """Learn BPE merges over pre-split token chunks using a process pool.

    In each merge round:
    1. Adjacent pairs are counted per chunk in parallel.
    2. The most frequent pair becomes a new token.
    3. Every chunk is rewritten in parallel with that pair merged.

    Pairs are only counted within a chunk, never across two chunks.
    """

    def __init__(self, num_merges: int, num_processes: t.Optional[int] = None):
        """
        Args:
            num_merges: Maximum number of merge rounds ``perform_merges`` runs.
            num_processes: Number of worker processes; ``None`` means
                ``multiprocessing.cpu_count()``.
        """
        self._num_merges = num_merges
        self._num_processes = num_processes if num_processes is not None else multiprocessing.cpu_count()

    @staticmethod
    def _count_chunk_pairs(tokens: t.List[int]) -> t.Dict[tuple[int, int], int]:
        """Pool worker: count the adjacent pairs of a single chunk into a fresh dict."""
        # Must be a named, picklable function: Pool.map pickles its callable and
        # a lambda raises PicklingError. Passing a new dict on every call also
        # keeps counts from leaking between chunks handled by the same worker.
        return get_pair_counts(tokens, {})

    def perform_merges(self, initial_chunks_tokens: t.List[t.List[int]], initial_vocab: t.Dict[int, bytes]):
        """Perform up to ``num_merges`` BPE merges with multiprocessing acceleration.

        Args:
            initial_chunks_tokens: One list of token ids per chunk.
            initial_vocab: Token id -> bytes. It must contain every id that can
                end up in a merged pair. It is not mutated; a copy is extended
                and returned.

        Returns:
            A tuple ``(chunks_tokens, merge_ranks, vocab)``:
            - the merged chunks;
            - the learned rules (pair -> new token id, in learned order);
            - the extended vocab.

        Stops early once no chunk has two or more tokens left.
        """
        from functools import partial

        chunks_tokens = initial_chunks_tokens
        vocab = dict(initial_vocab)  # copy so the caller's dict is left untouched
        merge_ranks: t.Dict[tuple[int, int], int] = {}  # pair -> new token id
        # New ids start right after the byte range (256) but never reuse an id
        # the initial vocab already defines (e.g. special tokens).
        next_token_id = max(256, max(vocab, default=-1) + 1)

        print(f"Starting BPE merges with {self._num_processes} processes...")

        # One pool is reused across all merge rounds to avoid process start-up cost.
        with multiprocessing.Pool(processes=self._num_processes) as pool:
            for i in range(self._num_merges):
                print(f"\n--- Merge Iteration {i+1}/{self._num_merges} ---")

                # Step 1: per-chunk pair counts in parallel, summed here.
                all_partial_p_counts = pool.map(self._count_chunk_pairs, chunks_tokens)
                p_counts: t.Dict[tuple[int, int], int] = collections.defaultdict(int)
                for partial_counts in all_partial_p_counts:
                    for pair, count in partial_counts.items():
                        p_counts[pair] += count

                if not p_counts:
                    print("No more pairs to merge. Stopping early.")
                    break

                # On ties, max() keeps the first pair inserted.
                occur_most_pair: tuple[int, int] = max(p_counts, key=p_counts.get)
                new_token: int = next_token_id
                next_token_id += 1

                merge_ranks[occur_most_pair] = new_token
                vocab[new_token] = vocab[occur_most_pair[0]] + vocab[occur_most_pair[1]]

                print(f"Merging pair {occur_most_pair} (representing '{vocab[occur_most_pair[0]].decode(errors='replace')}' + '{vocab[occur_most_pair[1]].decode(errors='replace')}') into new token ID {new_token}")

                # Step 2: rewrite every chunk with the new rule, in parallel.
                # partial() binds the rule while staying picklable, unlike a lambda.
                merge_func_for_pool = partial(merge_pair,
                                              pair_to_merge=occur_most_pair,
                                              new_token_id=new_token)
                chunks_tokens = pool.map(merge_func_for_pool, chunks_tokens)

        print("\nBPE merges completed.")
        return chunks_tokens, merge_ranks, vocab

# --- Example Usage ---
if __name__ == "__main__":
    # Toy corpus: each chunk is a short word spelled out as ASCII code points.
    initial_tokens_data = [
        [ord(ch) for ch in word]
        for word in ("ababa", "abc", "cab", "abac")
    ]

    # Base vocabulary: every single byte value maps to its one-byte string.
    initial_vocab_data = {byte_value: bytes((byte_value,)) for byte_value in range(256)}

    # Three merge rounds on two worker processes keep the demo small.
    processor = BPEProcessor(num_merges=3, num_processes=2)
    final_chunks, final_merge_ranks, final_vocab = processor.perform_merges(initial_tokens_data, initial_vocab_data)

    print("\n--- Final Results ---")
    print("Final Chunks Tokens:", final_chunks)
    print("Final Merge Ranks:", final_merge_ranks)
    print("Final Vocab:")
    for token_id, token_bytes in final_vocab.items():
        try:
            decoded = token_bytes.decode('utf-8')
        except UnicodeDecodeError:
            print(f"  {token_id}: {token_bytes} (undecodable)")
        else:
            print(f"  {token_id}: '{decoded}'")


if __name__ == "__main__":
    # Quick sanity check of get_pair_counts on an empty token list.
    tokens = []
    # Pass a fresh accumulator explicitly: get_pair_counts takes the dict to
    # update as its second argument.
    print(get_pair_counts(tokens, {}))

    # The function updates and returns the dict it was given.
    p_counts = {}
    returned = get_pair_counts(tokens, p_counts)

    print(returned)

    


In [None]:
import re
import collections
import typing as t
import os
import multiprocessing # 提前导入用于后续的多进程部分
from functools import partial

# 假设你的 get_pair_counts 和 merge_pair 函数已经定义好
# 它们应该接受单个 tokens 列表作为输入

# ----------------------------------------------------
# 辅助函数 (保持不变或略作调整以适应单次调用)
# ----------------------------------------------------
def get_pair_counts_for_chunk(tokens: t.List[int]) -> t.Dict[tuple[int, int], int]:
    """Count adjacent ``(left, right)`` token pairs in one token list.

    Returns a new ``defaultdict(int)`` local to this call, so each worker
    process can build its own counts for the caller to combine later.
    """
    counts: t.Dict[tuple[int, int], int] = collections.defaultdict(int)
    for left, right in zip(tokens, tokens[1:]):
        counts[(left, right)] += 1
    return counts

def merge_pair_in_chunk(tokens: t.List[int], pair_to_merge: tuple[int, int], new_token_id: int) -> t.List[int]:
    """Merge every non-overlapping occurrence of ``pair_to_merge`` in one token list.

    Scans left to right. When ``(tokens[k], tokens[k + 1])`` equals the pair, it
    emits ``new_token_id`` and skips ``tokens[k + 1]``. Returns a new list.
    """
    output: t.List[int] = []
    last_index = len(tokens) - 1
    skip_next = False
    for idx, tok in enumerate(tokens):
        if skip_next:
            # This token was consumed as the right half of the previous merge.
            skip_next = False
            continue
        if idx < last_index and (tok, tokens[idx + 1]) == pair_to_merge:
            output.append(new_token_id)
            skip_next = True
        else:
            output.append(tok)
    return output
# ----------------------------------------------------


class BPEMemoryOptimized:
    """Streaming byte-level BPE trainer that re-reads the corpus on every merge round.

    Instead of holding the tokenized corpus in memory, each merge round:
    - streams the corpus file in line batches;
    - re-encodes each batch with the merges learned so far, inside pool workers;
    - returns only pair counts to the parent.

    Text is first split with ``pat_str``; pairs are counted (and merges applied)
    only within those pieces, never across piece boundaries.
    """

    def __init__(self, pat_str: str, num_merges: int):
        """
        Args:
            pat_str: Regular expression given to ``re.findall`` to split text
                into pieces before byte-level BPE is applied.
            num_merges: Maximum number of merge rounds ``train`` performs.
        """
        self.pat_str = pat_str
        self._num_merges = num_merges
        # Base vocabulary: one token per byte value.
        self._vocab: t.Dict[int, bytes] = {i: bytes([i]) for i in range(256)}
        # Learned merge rules (pair -> new token id), kept in learned order.
        self._merge_ranks: t.Dict[tuple[int, int], int] = {}
        self._next_token_id = 256  # first id available for merged tokens

    @staticmethod
    def _apply_merges(tokens: t.List[int], merge_ranks: t.Dict[tuple[int, int], int]) -> t.List[int]:
        """Apply every merge rule, in learned (insertion) order, to one piece's tokens."""
        for pair_to_merge, new_token_id in merge_ranks.items():
            if len(tokens) < 2:
                break  # a single token can no longer be merged
            tokens = merge_pair_in_chunk(tokens, pair_to_merge, new_token_id)
        return tokens

    @staticmethod
    def _count_pairs_in_text(text: str, pat_str: str,
                             merge_ranks: t.Dict[tuple[int, int], int]) -> t.Dict[tuple[int, int], int]:
        """Pool worker: encode one batch of text and return its adjacent-pair counts.

        Identical pieces are encoded once and their pair counts are weighted by
        how often the piece occurs, which gives the same totals as encoding
        every occurrence.
        """
        pair_counts: t.Dict[tuple[int, int], int] = collections.defaultdict(int)
        piece_freqs = collections.Counter(re.findall(pat_str, text))
        for piece, freq in piece_freqs.items():
            tokens = BPEMemoryOptimized._apply_merges(list(piece.encode('utf-8')), merge_ranks)
            for pair, count in get_pair_counts_for_chunk(tokens).items():
                pair_counts[pair] += count * freq
        return pair_counts

    @staticmethod
    def _accumulate_counts(total: t.Dict[tuple[int, int], int],
                           partials: t.Iterable[t.Dict[tuple[int, int], int]]) -> None:
        """Add every partial pair-count dict into ``total`` (a defaultdict(int))."""
        for partial_counts in partials:
            for pair, count in partial_counts.items():
                total[pair] += count

    @staticmethod
    def _iter_text_batches(corpus_filepath: str, batch_size_lines: int) -> t.Iterator[str]:
        """Yield the UTF-8 corpus as strings of at most ``batch_size_lines`` lines each."""
        lines: t.List[str] = []
        with open(corpus_filepath, 'r', encoding='utf-8') as f:
            for line in f:
                lines.append(line)
                if len(lines) == batch_size_lines:
                    yield "".join(lines)
                    lines = []
        if lines:  # trailing partial batch
            yield "".join(lines)

    def _tokenize_and_apply_merges(self, text_chunk: str) -> t.List[int]:
        """Encode ``text_chunk`` into token ids using every merge learned so far.

        Each ``pat_str`` piece is UTF-8 encoded and merged independently. The
        per-piece results are then concatenated into one flat list.
        """
        tokens: t.List[int] = []
        for piece in re.findall(self.pat_str, text_chunk):
            tokens.extend(self._apply_merges(list(piece.encode('utf-8')), self._merge_ranks))
        return tokens

    def train(self, corpus_filepath: str, batch_size_lines: int = 1000,
              num_processes: t.Optional[int] = None):
        """Run BPE training over a large UTF-8 text file with bounded memory.

        Args:
            corpus_filepath: Path of the text corpus; it is re-read every round.
            batch_size_lines: Number of lines per unit of work sent to a worker.
            num_processes: Worker process count; ``None`` means
                ``multiprocessing.cpu_count()``.

        Returns:
            ``(vocab, merge_ranks)`` — the instance's vocab and merge-rule dicts.

        Raises:
            ValueError: if ``batch_size_lines`` is not positive.
        """
        if batch_size_lines <= 0:
            raise ValueError(f"batch_size_lines must be a positive integer, got {batch_size_lines}")
        if num_processes is None:
            num_processes = multiprocessing.cpu_count()
        # At most this many text batches are held in the parent at once, which
        # keeps memory bounded regardless of corpus size.
        max_batches_in_flight = 2 * num_processes

        print(f"开始 BPE 训练，共 {self._num_merges} 轮合并，批处理大小为 {batch_size_lines} 行。")

        with multiprocessing.Pool(processes=num_processes) as pool:
            for i in range(self._num_merges):
                p_counts: t.Dict[tuple[int, int], int] = collections.defaultdict(int)
                print(f"\n--- 合并迭代 {i+1}/{self._num_merges} ---")

                # Workers tokenize, apply the rules learned so far and count pairs.
                # Only the counts travel back, never the token lists. The rules
                # are snapshotted so every batch in this round sees the same state.
                count_batch = partial(
                    self._count_pairs_in_text,
                    pat_str=self.pat_str,
                    merge_ranks=dict(self._merge_ranks),
                )
                pending: t.List[str] = []
                for batch_text in self._iter_text_batches(corpus_filepath, batch_size_lines):
                    pending.append(batch_text)
                    if len(pending) >= max_batches_in_flight:
                        self._accumulate_counts(p_counts, pool.map(count_batch, pending))
                        pending = []
                if pending:
                    self._accumulate_counts(p_counts, pool.map(count_batch, pending))

                if not p_counts:
                    print("没有更多可合并的 pair，提前停止。")
                    break

                # On ties, max() keeps the first pair inserted.
                occur_most_pair: tuple[int, int] = max(p_counts, key=p_counts.get)
                new_token_id: int = self._next_token_id
                self._next_token_id += 1

                # Record the new rule and the bytes of the new token.
                self._merge_ranks[occur_most_pair] = new_token_id
                self._vocab[new_token_id] = self._vocab[occur_most_pair[0]] + self._vocab[occur_most_pair[1]]

                print(f"合并 pair {occur_most_pair} (新 token ID: {new_token_id})")

                # No token lists are kept between rounds. The next round re-reads
                # the file and re-applies the updated _merge_ranks.

        print("\nBPE 训练完成。")
        return self._vocab, self._merge_ranks

# --- 使用示例 (需要一个较大的虚拟语料库文件) ---
if __name__ == "__main__":
    dummy_corpus_path = "large_dummy_corpus.txt"

    # Build the synthetic corpus only on the first run; later runs reuse it.
    # It has 50,000 identical lines of 1,031 characters (about 51.5 MB).
    if not os.path.exists(dummy_corpus_path):
        print(f"创建虚拟语料库文件: {dummy_corpus_path}")
        sample_line = "The quick brown fox jumps over the lazy dog. " * 20 + " hello world " * 10 + "\n"
        with open(dummy_corpus_path, 'w', encoding='utf-8') as f:
            f.writelines(sample_line for _ in range(50000))
        print("虚拟语料库已创建。")

    # pat_str pre-splits text into runs of word characters, runs of whitespace,
    # and single remaining characters (punctuation).
    bpe_trainer = BPEMemoryOptimized(pat_str=r"\w+|\s+|[^\w\s]", num_merges=20)
    final_vocab, final_merges = bpe_trainer.train(dummy_corpus_path, batch_size_lines=1000)

    print("\n--- 训练结果 ---")
    print(f"最终词汇表大小: {len(final_vocab)}")
    print(f"合并规则数量: {len(final_merges)}")
    # Optional: inspect the first learned tokens and rules.
    # print(dict(list(final_vocab.items())[256:266]))
    # print(dict(list(final_merges.items())[:10]))

    # Optional: delete the generated corpus.
    # os.remove(dummy_corpus_path)