In [82]:
import os
import regex as re
from typing import BinaryIO
from typing import Iterable, Iterator
from collections import defaultdict
from multiprocessing import Process, Queue
import time
from tqdm import tqdm

In [None]:
def find_chunk_boundaries(
    file: BinaryIO, 
    desired_num_chunks: int, 
    split_special_token: bytes
) -> list[int]:
    """
    Compute byte offsets that cut ``file`` into independently countable chunks.

    The file is first divided into ``desired_num_chunks`` equal-sized pieces;
    every interior cut is then moved forward to the next occurrence of
    ``split_special_token`` found at or after the initial guess (or to the end
    of the file when no further occurrence is found).

    Args:
        file: Binary file object supporting ``seek``/``tell``/``read``.
        desired_num_chunks: Number of chunks to aim for.
        split_special_token: Byte string on which chunks are allowed to start.

    Returns:
        Sorted, de-duplicated list of byte offsets. Consecutive entries are the
        ``[start, end)`` range of one chunk, so the list may describe fewer
        chunks than requested when several cuts land on the same offset.
    """
    assert isinstance(split_special_token, bytes), (
        "Must represent special token as a bytestring"
    )

    # Measure the file, then rewind.
    file.seek(0, os.SEEK_END)
    total_bytes = file.tell()
    file.seek(0)

    stride = total_bytes // desired_num_chunks

    # Evenly spaced first guesses; the final entry is pinned to the file end.
    offsets = [k * stride for k in range(desired_num_chunks + 1)]
    offsets[-1] = total_bytes

    read_ahead = 4096  # bytes scanned per read while searching for the token

    for slot in range(1, len(offsets) - 1):
        cursor = offsets[slot]
        file.seek(cursor)
        while (window := file.read(read_ahead)) != b"":
            hit = window.find(split_special_token)
            if hit != -1:
                # Snap the cut to the first byte of the token.
                offsets[slot] = cursor + hit
                break
            cursor += read_ahead
        else:
            # Ran off the end of the file without finding the token.
            offsets[slot] = total_bytes

    # Several cuts may have collapsed onto the same offset; keep one of each.
    return sorted(set(offsets))

def split_by_special_tokens(text: str, special_tokens: list[str]) -> list[str]:
    """
    Cut ``text`` at every special token, keeping the tokens in the result.

    Longer tokens are tried first, so a token that contains another one
    (e.g. a doubled ``<|endoftext|>``) is matched as a whole.

    Example:
        text = "Hello world! <|endoftext|> Great!"
        special_tokens = ["<|endoftext|>"]
        result = ['Hello world! ', '<|endoftext|>', ' Great!']

    With no special tokens the text is returned unchanged as a one-item list.
    """
    longest_first = sorted(special_tokens, key=len, reverse=True)
    if len(longest_first) == 0:
        return [text]

    # The capturing group makes re.split keep the delimiters in its output.
    alternation = "|".join(map(re.escape, longest_first))
    return re.split(f"({alternation})", text)

def pretokenize(text: str, special_tokens: list[str], drop_special_token: bool = True) -> list[bytes]:
    """
    Break ``text`` into UTF-8 encoded pretokens.

    The text is first split on the special tokens; each remaining segment is
    then tokenized with the regex below. A special token is emitted as one
    pretoken of its own when ``drop_special_token`` is False and is omitted
    otherwise.

    Returns:
        Flat list of pretokens in order of appearance, each as ``bytes``.
    """
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    pretokens = []
    for segment in split_by_special_tokens(text, special_tokens):
        if segment in special_tokens:
            # Special tokens are never run through the regex.
            if not drop_special_token:
                pretokens.append(segment.encode('utf-8'))
            continue
        # finditer yields matches lazily instead of building an intermediate list of strings.
        pretokens.extend(found.group().encode('utf-8') for found in re.finditer(PAT, segment))
    return pretokens

def worker(text: str, special_tokens: list[str], q: Queue):
    """
    Process entry point: pretokenize one chunk and put the result on ``q``.

    Exactly one item is put on the queue per call: the list of pretokens on
    success, or an empty list when any exception is raised (the error is only
    printed, so a failed chunk contributes no pretokens).
    """
    try:
        print("[worker] start")
        chunk_pretokens = pretokenize(text, special_tokens)
        print("[worker] done pretokenize, len:", len(chunk_pretokens))
        q.put(chunk_pretokens)
    except Exception as err:
        print("[worker] error:", err)
        # Still deliver something so the parent's q.get() does not block forever.
        q.put([])

def train_bpe(
    input_path: str | os.PathLike,
    vocab_size: int,
    special_tokens: list[str],
    **kwargs,
) -> tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
    """Given the path to an input corpus, train a BPE tokenizer and
    output its vocabulary and merges.

    Args:
        input_path (str | os.PathLike): Path to BPE tokenizer training data.
        vocab_size (int): Total number of items in the tokenizer's vocabulary (including special tokens).
        special_tokens (list[str]): A list of string special tokens to be added to the tokenizer vocabulary.
            They receive the ids directly after the 256 single-byte tokens. Occurrences in the
            corpus are used as split points and are dropped before pair counting, so they
            never take part in a merge.
        **kwargs: Optional settings. ``num_processes`` (int, default 4) is the number of
            chunks / worker processes used for pretokenization; other keys are ignored.

    Returns:
        tuple[dict[int, bytes], list[tuple[bytes, bytes]]]:
            vocab:
                The trained tokenizer vocabulary, a mapping from int (token ID in the vocabulary)
                to bytes (token bytes)
            merges:
                BPE merges. Each list item is a tuple of bytes (<token1>, <token2>),
                representing that <token1> was merged with <token2>.
                Merges are ordered by order of creation. Fewer than
                ``vocab_size - 256 - len(special_tokens)`` merges are returned when the
                corpus runs out of pairs to merge.
    """
    print("train_bpe: start")
    special_tokens = special_tokens or []
    num_merges = max(vocab_size - len(special_tokens) - 256, 0)

    # Initialize vocab: one entry per byte value, then the special tokens
    vocab = {x: bytes([x]) for x in range(256)}
    for i, token in enumerate(special_tokens):
        vocab[256 + i] = token.encode("utf-8")
    merges = []

    # Chunk the text file (at least one chunk, so the chunk size is never divided by zero)
    num_processes = max(1, int(kwargs.get("num_processes", 4)))
    chunk_list = []
    with open(input_path, "rb") as f:
        boundaries = find_chunk_boundaries(f, num_processes, "<|endoftext|>".encode("utf-8"))

        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
            chunk_list.append(chunk)

    print("train_bpe: after chunking")
    print("chunk_list len:", len(chunk_list))
    for i, chunk in enumerate(chunk_list):
        print(f"chunk {i} len:", len(chunk))

    # Parallelizing pretokenization: one process per chunk, results come back on a queue
    pretokens_list = []
    processes = []
    q = Queue()
    for chunk in chunk_list:
        print("[main] starting worker")
        p = Process(target=worker, args=(chunk, special_tokens, q))
        p.start()
        processes.append(p)

    # Drain the queue BEFORE joining: a child that still has data buffered for the
    # queue does not exit, so joining first could deadlock.
    # Results arrive in completion order, not start order; pair counts do not depend on it.
    for i, p in enumerate(processes):
        print(f"[main] waiting for q.get() from worker {i}")
        pretokens = q.get()
        print(f"[main] got result from worker {i}, len={len(pretokens)}")
        pretokens_list.append(pretokens)

    for i, p in enumerate(processes):
        print(f"[main] joining worker {i}")
        p.join()
        print(f"[main] worker {i} joined")

    pretokens = [token for tokens in pretokens_list for token in tokens]
    print("train_bpe: after pretokenization")
    print("total pretokens:", len(pretokens))
    if pretokens:  # an empty corpus has no first pretoken to report
        print("first pretoken len:", len(pretokens[0]))

    # Merging
    counts = defaultdict(int)      # frequency of each adjacent id pair
    index_dict = defaultdict(set)  # pair -> indices of the pretokens that contain it

    for j, pretoken in tqdm(enumerate(pretokens), total=len(pretokens), desc="Counting pairs"):
        # Iterating a bytes pretoken yields ints, i.e. the ids of its single-byte tokens
        for index1, index2 in zip(pretoken, pretoken[1:]):
            counts[index1, index2] += 1
            index_dict[index1, index2].add(j)

    for i in tqdm(range(num_merges), desc="Merging BPE pairs"):
        if not counts:
            break  # no adjacent pairs at all (empty corpus or single-byte pretokens only)

        # Highest count wins; ties go to the lexicographically greater pair of token bytes.
        # Example: max([("A", "B"), ("A", "C"), ("B", "ZZ"), ("BA", "A")]) = ('BA', 'A')
        # Raw bytes are compared: decoding with errors="ignore" would collapse every
        # token that is not valid UTF-8 on its own to '' and make the tie-break ambiguous.
        max_pair, max_count = max(
            counts.items(),
            key=lambda x: (x[1], vocab[x[0][0]], vocab[x[0][1]]),
        )
        if max_count <= 0:
            break  # every remaining pair has been merged away; nothing left to learn

        index1, index2 = max_pair

        new_index = 256 + len(special_tokens) + i

        vocab[new_index] = vocab[index1] + vocab[index2]
        merges.append((vocab[index1], vocab[index2]))

        # Rewrite the affected pretokens and update counts / index_dict in place
        merge(counts, index_dict, pretokens, max_pair, new_index)
    return (vocab, merges)

def merge(counts: dict[tuple[int, int], int], index_dict: dict[tuple[int, int],set[int]], pretokens: list[list[int]], max_pair: tuple[int, int], new_index: int):
    """Replace every occurrence of ``max_pair`` with ``new_index`` and update the bookkeeping.

    Args:
        counts: Frequency of each adjacent id pair; updated in place.
        index_dict: Maps a pair to the indices of the pretokens that may contain it;
            updated in place (entries are only ever added, so an index can be stale).
        pretokens: Pretokens as id sequences (lists of ints, or ``bytes`` before their
            first merge); each affected entry is replaced by a new list.
        max_pair: The pair of ids to merge.
        new_index: Id of the token that replaces ``max_pair``.

    Occurrences are merged left to right without overlap, so ``[a, a, a]`` with
    ``max_pair == (a, a)`` becomes ``[new, a]``.
    """
    # Snapshot the indices: index_dict is written to while we iterate.
    for i in tuple(index_dict[max_pair]):
        pretoken = pretokens[i]
        new_pretoken = []

        # Replace max_pair with new_index, scanning left to right
        j = 0
        last = len(pretoken) - 1
        while j <= last:
            if j < last and (pretoken[j], pretoken[j + 1]) == max_pair:
                new_pretoken.append(new_index)
                j += 2
            else:
                new_pretoken.append(pretoken[j])
                j += 1

        if len(new_pretoken) == len(pretoken):
            continue  # stale index: the pair no longer occurs in this pretoken

        # Re-count this pretoken's pairs from scratch instead of patching the
        # neighbours of each merge position: patching double-counts when two merged
        # occurrences are adjacent (e.g. [A, B, A, B] -> [AB, AB]).
        for pair in zip(pretoken, pretoken[1:]):
            counts[pair] -= 1
        for pair in zip(new_pretoken, new_pretoken[1:]):
            counts[pair] += 1
            index_dict[pair].add(i)

        pretokens[i] = new_pretoken



In [None]:
class BPETokenizer:
    def __init__(self, vocab: dict[int, bytes], merges: list[tuple[bytes, bytes]], special_tokens: list[str]| None = None):
        self.vocab = vocab
        self.merges = merges
        self.special_tokens = special_tokens or []

    @classmethod
    def from_files(cls, vocab_filepath: str, merges_filepath: str, special_tokens: list[str] | None = None):
        """Class method that constructs and return a Tokenizer from a serialized vocabulary and list of merges"""
        raise NotImplementedError

    def encode(self, text:str) -> list[int]:
        """Encode an input text into a sequence of token IDs."""

        vocab_reversed = {v: k for k, v in self.vocab.items()}  # int: bytes-->bytes: int
        byte_pretokens = pretokenize(text, self.special_tokens, drop_special_token=False)   # list[bytes]
        byte_special_tokens = [token.encode('utf-8') for token in self.special_tokens]
        pretokens = []  # list[list[int]]

        # Convert pretokens from bytes to list[int] by vocab
        for i, pretoken in enumerate(byte_pretokens):

            new_pretoken = []

            if pretoken in byte_special_tokens:
                index = vocab_reversed[pretoken]
                new_pretoken.append(index)
            else:
                for b in pretoken:  #普通token按照字节处理
                    index = vocab_reversed[bytes([b])]
                    new_pretoken.append(index)

            pretokens.append(new_pretoken)

        # Merge  三重循环
        #首先对pretokens中的每个pretoken进行merge
        #其次对合并规则表进行逐项匹配
        #对于每一个merge，遍历pretoken若其中有相邻的merge则合并，这样避免了跨单词合并
        for i, pretoken in enumerate(pretokens):
            for merge in self.merges:   #merges: list[tuple[bytes, bytes]],
                new_pretoken = []
                new_index = vocab_reversed[merge[0] + merge[1]]
                j = 0
                while j < len(pretoken):
                    if (j < len(pretoken)-1) and ((self.vocab[pretoken[j]], self.vocab[pretoken[j+1]]) == merge):
                        new_pretoken.append(new_index)
                        j += 2
                    else:
                        new_pretoken.append(pretoken[j])
                        j += 1
                        #new_pretoken保存的是合并一个merge后每个pretoken的id串
                pretoken = new_pretoken

            pretokens[i] = pretoken
            #pretokens[i]保存的是合并所有merge后每个pretoken的id串，然后对i循环所有的pretokens
        tokens = [token for pretoken in pretokens for token in pretoken] 
        return tokens

    def encode_iterable(self, iterable: Iterable[str]) -> Iterator[int]:
        """Given an iterable of strings (e.g., a Python file handle), 
        return a generator that lazily yields token IDs. 
        This is required for memory-eﬀicient tokenization of large files 
        that we cannot directly load into memory.
        """
        for line in iterable:
            for idx in self.encode(line):
                yield idx


    def decode(self, ids: list[int]) -> str:
        """Decode a sequence of token IDs into text."""
        tokens = bytes()
        vocab_size = len(self.vocab)
        replacement_char = "\uFFFD"

        for token_id in ids:
            if token_id < vocab_size:
                token = self.vocab[token_id]    # vocab: int: bytes
            else:
                token = bytes(replacement_char, encoding='utf-8')   # Replace tokens with Unicode replacement characters if index out of bounds

            tokens += token
        decoded = tokens.decode(encoding='utf-8', errors='replace')

        return decoded 

In [85]:
def test():
    """Print how tiktoken's GPT-2 encoding splits a sample containing ``<|endoftext|>``.

    Each token id is decoded on its own so the printed list shows the
    individual token strings.
    """
    import tiktoken

    sample = "Hello, how <|endoftext|><|endoftext|> are you?<|endoftext|>"
    permitted = {"<|endoftext|><|endoftext|>", "<|endoftext|>"}
    gpt2 = tiktoken.get_encoding('gpt2')
    token_ids = gpt2.encode(sample, allowed_special=permitted)
    print([gpt2.decode([token_id]) for token_id in token_ids])

In [86]:
# Train a small BPE vocabulary on the TinyStories validation split, then
# round-trip a sample string through the resulting tokenizer.
file_path = "../data/TinyStoriesV2-GPT4-valid.txt"
vocab_size = 330  # temporarily reduced
# The doubled token is listed as its own special token, so it is encoded as one id.
special_tokens = ["<|endoftext|>", "<|endoftext|><|endoftext|>"]

vocab, merges = train_bpe(file_path, vocab_size, special_tokens)

tokenizer = BPETokenizer(vocab, merges, special_tokens)
test_string = "Hello, how <|endoftext|><|endoftext|> are you?<|endoftext|>"
encoded = tokenizer.encode(test_string)
print("encoded:",encoded)
# Decode id by id to show the string each token stands for.
decoded = [tokenizer.decode([x]) for x in encoded]
print("decoded:", decoded)

# Round-trip check: the concatenated pieces must equal the input.
print(test_string == ''.join(decoded))

    # print(vocab)

train_bpe: start
train_bpe: after chunking
chunk_list len: 4
chunk 0 len: 5623437
chunk 1 len: 5624335
chunk 2 len: 5622697
chunk 3 len: 5622918
[main] starting worker
[main] starting worker
[worker] start
[main] starting worker
[worker] start
[main] starting worker
[worker] start
[worker] start
[main] waiting for q.get() from worker 0
[worker] done pretokenize, len: 1355040
[worker] done pretokenize, len: 1354266
[worker] done pretokenize, len: [worker] done pretokenize, len:1355347 
1354348
[main] got result from worker 0, len=1355040
[main] waiting for q.get() from worker 1
[main] got result from worker 1, len=1354266
[main] waiting for q.get() from worker 2
[main] got result from worker 2, len=1354348
[main] waiting for q.get() from worker 3
[main] got result from worker 3, len=1355347
[main] joining worker 0
[main] worker 0 joined
[main] joining worker 1
[main] worker 1 joined
[main] joining worker 2
[main] worker 2 joined
[main] joining worker 3
[main] worker 3 joined
train_bpe: 

Counting pairs: 100%|██████████| 5419001/5419001 [00:05<00:00, 910438.82it/s]
Merging BPE pairs: 100%|██████████| 72/72 [00:36<00:00,  1.98it/s]


pretokens内容 (ID转换后，合并前):
pretokens数量: 9
pretoken[0]: [72, 101, 108, 108, 111]
  对应字节串: [b'H', b'e', b'l', b'l', b'o']
pretoken[1]: [44]
  对应字节串: [b',']
pretoken[2]: [32, 104, 111, 119]
  对应字节串: [b' ', b'h', b'o', b'w']
pretoken[3]: [32]
  对应字节串: [b' ']
pretoken[4]: [257]
  对应字节串: [b'<|endoftext|><|endoftext|>']
pretoken[5]: [32, 97, 114, 101]
  对应字节串: [b' ', b'a', b'r', b'e']
pretoken[6]: [32, 121, 111, 117]
  对应字节串: [b' ', b'y', b'o', b'u']
pretoken[7]: [63]
  对应字节串: [b'?']
pretoken[8]: [256]
  对应字节串: [b'<|endoftext|>']
encoded: [72, 101, 294, 111, 44, 269, 319, 32, 257, 260, 274, 32, 121, 276, 63, 256]
decoded: ['H', 'e', 'll', 'o', ',', ' h', 'ow', ' ', '<|endoftext|><|endoftext|>', ' a', 're', ' ', 'y', 'ou', '?', '<|endoftext|>']
True


In [87]:
# Inspect the result: the first 50 vocabulary entries and the last 50 merges.
print(list(vocab.items())[:50])
print(list(merges)[-50:])

[(0, b'\x00'), (1, b'\x01'), (2, b'\x02'), (3, b'\x03'), (4, b'\x04'), (5, b'\x05'), (6, b'\x06'), (7, b'\x07'), (8, b'\x08'), (9, b'\t'), (10, b'\n'), (11, b'\x0b'), (12, b'\x0c'), (13, b'\r'), (14, b'\x0e'), (15, b'\x0f'), (16, b'\x10'), (17, b'\x11'), (18, b'\x12'), (19, b'\x13'), (20, b'\x14'), (21, b'\x15'), (22, b'\x16'), (23, b'\x17'), (24, b'\x18'), (25, b'\x19'), (26, b'\x1a'), (27, b'\x1b'), (28, b'\x1c'), (29, b'\x1d'), (30, b'\x1e'), (31, b'\x1f'), (32, b' '), (33, b'!'), (34, b'"'), (35, b'#'), (36, b'$'), (37, b'%'), (38, b'&'), (39, b"'"), (40, b'('), (41, b')'), (42, b'*'), (43, b'+'), (44, b','), (45, b'-'), (46, b'.'), (47, b'/'), (48, b'0'), (49, b'1')]
[(b' ', b'p'), (b'a', b'y'), (b' ', b'm'), (b'e', b'r'), (b' wa', b's'), (b' T', b'he'), (b'o', b'm'), (b' ', b'he'), (b'i', b's'), (b' ', b'n'), (b'i', b'm'), (b'a', b'r'), (b'o', b'n'), (b' s', b'a'), (b'l', b'l'), (b'i', b'd'), (b' h', b'a'), (b' ', b'g'), (b'a', b't'), (b' ', b'S'), (b'in', b'g'), (b'o', b't'), (b

In [88]:
# Invert the vocabulary (token bytes -> token id) and print it for inspection.
vocab_reversed = {v: k for k, v in vocab.items()}  # bytes: int
print(vocab_reversed)

{b'\x00': 0, b'\x01': 1, b'\x02': 2, b'\x03': 3, b'\x04': 4, b'\x05': 5, b'\x06': 6, b'\x07': 7, b'\x08': 8, b'\t': 9, b'\n': 10, b'\x0b': 11, b'\x0c': 12, b'\r': 13, b'\x0e': 14, b'\x0f': 15, b'\x10': 16, b'\x11': 17, b'\x12': 18, b'\x13': 19, b'\x14': 20, b'\x15': 21, b'\x16': 22, b'\x17': 23, b'\x18': 24, b'\x19': 25, b'\x1a': 26, b'\x1b': 27, b'\x1c': 28, b'\x1d': 29, b'\x1e': 30, b'\x1f': 31, b' ': 32, b'!': 33, b'"': 34, b'#': 35, b'$': 36, b'%': 37, b'&': 38, b"'": 39, b'(': 40, b')': 41, b'*': 42, b'+': 43, b',': 44, b'-': 45, b'.': 46, b'/': 47, b'0': 48, b'1': 49, b'2': 50, b'3': 51, b'4': 52, b'5': 53, b'6': 54, b'7': 55, b'8': 56, b'9': 57, b':': 58, b';': 59, b'<': 60, b'=': 61, b'>': 62, b'?': 63, b'@': 64, b'A': 65, b'B': 66, b'C': 67, b'D': 68, b'E': 69, b'F': 70, b'G': 71, b'H': 72, b'I': 73, b'J': 74, b'K': 75, b'L': 76, b'M': 77, b'N': 78, b'O': 79, b'P': 80, b'Q': 81, b'R': 82, b'S': 83, b'T': 84, b'U': 85, b'V': 86, b'W': 87, b'X': 88, b'Y': 89, b'Z': 90, b'[': 91,