In [None]:
import regex as re
from cs336_basics.pretokenization_example import find_chunk_boundaries


input_path:str = "../data/TinyStoriesV2-GPT4-valid.txt"
vocab_size:int = 1024
special_tokens:list[str] = ["<|endoftext|>", "<|startoftext|>", ]

# 1.初始化词表
vocab:dict[int,bytes] = {i:bytes([i]) for i in range(256)}
merges:list[tuple[bytes, bytes]] = []
next_token_id = 256  
existing_bytes_token:set[bytes] = set(vocab.values())# 集合记录词表中已合并出的字节token和特殊token，用于高效索引
# 添加special tokens到词表
for sp_token in special_tokens:
    if len(vocab) >= vocab_size:
        break
    sp_token_bytes = sp_token.encode('utf-8')
    if sp_token_bytes not in existing_bytes_token:
        vocab[next_token_id] = sp_token_bytes
        existing_bytes_token.add(sp_token_bytes)
        next_token_id += 1

print(f"vocab: {vocab}")
print(f"existing_bytes_token: {existing_bytes_token}")
print(f"vocab_size: {len(vocab)}")

vocab: {0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\x16', 23: b'\x17', 24: b'\x18', 25: b'\x19', 26: b'\x1a', 27: b'\x1b', 28: b'\x1c', 29: b'\x1d', 30: b'\x1e', 31: b'\x1f', 32: b' ', 33: b'!', 34: b'"', 35: b'#', 36: b'$', 37: b'%', 38: b'&', 39: b"'", 40: b'(', 41: b')', 42: b'*', 43: b'+', 44: b',', 45: b'-', 46: b'.', 47: b'/', 48: b'0', 49: b'1', 50: b'2', 51: b'3', 52: b'4', 53: b'5', 54: b'6', 55: b'7', 56: b'8', 57: b'9', 58: b':', 59: b';', 60: b'<', 61: b'=', 62: b'>', 63: b'?', 64: b'@', 65: b'A', 66: b'B', 67: b'C', 68: b'D', 69: b'E', 70: b'F', 71: b'G', 72: b'H', 73: b'I', 74: b'J', 75: b'K', 76: b'L', 77: b'M', 78: b'N', 79: b'O', 80: b'P', 81: b'Q', 82: b'R', 83: b'S', 84: b'T', 85: b'U', 86: b'V', 87: b'W', 88: b'X', 89: b'Y', 90: b'Z', 91

In [40]:
from collections import Counter
# 2.pre-tokenization
num_processes = 8  # Number of processes to use for parallelization
with open(input_path, 'rb') as f:
    boundaries:list[int] = find_chunk_boundaries(
        f, num_processes, "<|endoftext|>".encode("utf-8"))
    # 切分文件为多个部分
    print(f"Chunk boundaries: {boundaries}")

    for start, end in zip(boundaries[:-1], boundaries[1:]):
        f.seek(start)
        # 以字节形式读取，解码为字符串
        chunk:str = f.read(end - start).decode("utf-8", errors="ignore")
        print(f"Processing chunk from {start} to {end}, length: {len(chunk)}")
        print(f"Chunk content: {chunk[:100]}...")  
        # 删除特殊token
        split_pattern:str = "|".join([re.escape(special_token) for special_token in set(special_tokens)])
        # 将chunk按特殊token进行分割
        if not split_pattern:
            chunk_no_special_token:list[str] = [chunk]
        else:
            chunk_no_special_token:list[str] = re.split(split_pattern, chunk)
        # 对每个chunk进行预分词
        PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
        pre_token_counter:Counter[tuple[bytes,...]] = Counter() # 预分词频率表
        for sub_chunk in chunk_no_special_token:
            if not sub_chunk or sub_chunk.isspace():
                continue
# 预分词后应得到频率表dict[tuple[bytes],int]
            for match in re.finditer(PAT, sub_chunk):
                pre_token:bytes = match.group(0).encode('utf-8')
                pre_token_bytes = tuple(bytes([b]) for b in pre_token)  
                # print(pre_token_bytes)
                pre_token_counter[pre_token_bytes] += 1
           
    print(pre_token_counter)

Chunk boundaries: [0, 2813541, 5625758, 8438541, 11252559, 14064356, 16877372, 19690315, 22502601]
Processing chunk from 0 to 2813541, length: 2812473
Chunk content: u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little ...
Processing chunk from 2813541 to 5625758, length: 2810964
Chunk content: <|endoftext|>

Tommy was excited that it was time to play with his new toy car. He ran down the stai...
Processing chunk from 5625758 to 8438541, length: 2811507
Chunk content: <|endoftext|>
Once upon a time, in a green forest, there lived a big bear and a little bunny. They w...
Processing chunk from 8438541 to 11252559, length: 2812828
Chunk content: <|endoftext|>

Once there was a young girl called Lucy. She was only three years old and loved playi...
Processing chunk from 11252559 to 14064356, length: 2810919
Chunk content: <|endoftext|>
Once upon a time, there was a little boy named Tim. Tim loved to print pictures of ani...
Processing chunk fr

In [None]:
from cs336_basics.train_bpe import merge,get_stat,train_bpe,pre_tokenization
pre_token_counter = pre_tokenization(input_path, special_tokens)


6765
Pre-token: b'\n', Count: 19105
Pre-token: b'One', Count: 2009
Pre-token: b' day', Count: 5358
Pre-token: b',', Count: 29297
Pre-token: b' a', Count: 18887
Pre-token: b' bald', Count: 13
Pre-token: b' man', Count: 711
Pre-token: b' named', Count: 2528
Pre-token: b' Tom', Count: 2134


In [None]:

vocab,merges = train_bpe(input_path, vocab_size=vocab_size, special_tokens=special_tokens)
print(f"vocab: {vocab}")
print(f"merges: {merges}")

PatternError: bad escape \p at position 23

In [None]:
# 3.迭代合并字节对
while len(vocab) < 300:
    # 找到频率最高的字节对
    most_common_pair = get_stat(pre_token_counter,existing_bytes_token)
    if not most_common_pair:
        break
    b0, b1 = most_common_pair
    merge_token = b0 + b1
    # 如果合并后的token已经存在于词表中，则只需清除pre-token中的单个token跳过
    if merge_token in existing_bytes_token:
        pre_token_counter.pop(merge_token, None)
        continue
    # 将合并后的token添加到词表
    vocab[next_token_id] = merge_token
    existing_bytes_token.add(merge_token)
    merges.append((b0, b1))
    next_token_id += 1
    # 更新pre-token计数器
    print(len(vocab))
    print(merge_token)
    pre_token_counter:Counter[bytes] = merge(pre_token_counter, b0, b1, merge_token)
print(f"vocab {vocab}")

KeyboardInterrupt: 

PatternError: bad escape \p at position 23