In [64]:
import os
from typing import BinaryIO


def find_chunk_boundaries(
    file: BinaryIO,
    desired_num_chunks: int,
    split_special_token: bytes,
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), "Must represent special token as a bytestring"

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

## Slpit on byte level, may end up with invalid utf-8 characters!

Byte-level BPE (like GPT-2/GPT-4)
In byte-level BPE, you do start with individual UTF-8 bytes as the base vocabulary. This means:

The initial vocabulary consists of all 256 possible byte values (0-255)
Multi-byte UTF-8 characters are split into their constituent bytes initially
For example, "café" becomes bytes like [99, 97, 102, 195, 169] where 195,169 represents "é"

Can this create invalid tokens? Yes, absolutely:

Individual byte tokens can represent incomplete UTF-8 sequences (like having just the first byte of a multi-byte character)
During merging, you can end up with subword tokens that contain invalid UTF-8 byte sequences
The model learns to work with these byte-level representations, even if they don't correspond to valid Unicode characters

Why byte-level BPE became popular
Byte-level BPE is preferred in modern models because:

Universal coverage: Can represent any text in any language/script
Handles unknown characters: No need for special UNK tokens
Consistent vocabulary size: Always 256 base tokens regardless of language

The trade-off is that the model must learn to reconstruct valid UTF-8 from potentially fragmented byte sequences, but modern transformers handle this quite well.RetryClaude can make mistakes. Please double-check responses.

In [65]:
import dask

In [130]:
import collections
import regex as re

def read_chunk(file_path: str, start_offset: int, end_offset: int) -> bytes:
    with open(file_path, "rb") as f:
        f.seek(start_offset)
        return f.read(end_offset - start_offset)

def split_on_special_characaters(data: bytes, pattern: bytes) -> list[bytes]:
    return re.split(pattern, data)

In [208]:
def _merge_tokens(token1: int | tuple[int, ...], token2: int | tuple[int, ...]) -> tuple[int, ...]:
    if isinstance(token1, int):
        token1 = (token1,)
    if isinstance(token2, int):
        token2 = (token2,)
    return (*token1, *token2)


def count_and_merge(pretoken_counts: collections.Counter) -> tuple[tuple[int, ...]]:
    pair_counts = collections.Counter()
    for word, count in pretoken_counts.items():
        for i in range(len(word) - 1):
            pair_counts[(word[i], word[i + 1])] += count

    # Find the most common byte pair
    # TODO: this is resolves ties arbitrarily
    most_common_pair = pair_counts.most_common(1)[0][0]
    pair_counts.pop(most_common_pair)

    # Merge the most common pair in the pret
    new_token = _merge_tokens(most_common_pair[0], most_common_pair[1])
    updates = []
    for word, count in pretoken_counts.items():
        new_word = []
        ix = 0
        updated = False
        while ix < len(word):
            if ix + 1 < len(word) and (word[ix], word[ix + 1]) == most_common_pair:
                new_word.append(new_token)
                ix += 2
                updated = True
            else:
                new_word.append(word[ix])
                ix += 1
        if updated:
            updates.append((word, tuple(new_word), count))

    # Update the pretoken_counts with the new tokens
    for old_word, new_word, count in updates:
        pretoken_counts.pop(old_word)
        pretoken_counts[new_word] = count

    return most_common_pair, new_token

In [212]:
# Patterns for the pretokenization
PRETOKENIZATION_PATTERN = re.compile(r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")

def get_chunk_boundaries(data_path: str, separator: bytes) -> list[int]:
    with open(data_path, "rb") as f:
        num_processes = 64
        boundaries = find_chunk_boundaries(f, num_processes, b"<|endoftext|>")

    return list(zip(boundaries[:-1], boundaries[1:]))


def train_bpe(
    input_path: str,
    vocab_size: int,
    separator: bytes = b"<|endoftext|>",
    special_tokens: list[str] = None
) -> dict[int, tuple[bytes]]:
    if not special_tokens:
        special_tokens = []
    vocabulary = collections.defaultdict(int)
    
    # add the special tokens to the vocabulary
    for ix, token in enumerate(special_tokens):
        vocabulary[token.encode("utf-8")] = ix

    # add initial byte tokens: assume the encoding is utf-8: 256
    for i in range(256):
        vocabulary[bytes([i])] = len(vocabulary)

    # Find chunk boundaries
    start_end_pairs = get_chunk_boundaries(input_path, separator=separator)

    # TODO: parallelize this part
    pretoken_counts = collections.Counter()

    # Read the data in chunks
    for start, end in start_end_pairs:
        data = read_chunk(input_path, start, end)

        # Split each chunk into stories using the EOS token
        stories = split_on_special_characaters(data, re.escape(separator))

        # Pre-tokenize the stories using a regex pattern to get individual words
        pretokenized = re.finditer(PRETOKENIZATION_PATTERN, stories[1].decode("utf-8", errors="ignore"))

        # Get the byte tokens for each word and merge with the global counts
        pretoken_counts += collections.Counter(tuple(word.group(0).encode('utf-8')) for word in pretokenized if word.group(0) is not None)

    # Merge the most common byte pairs until we reach the desired vocabulary size
    merges = []
    while len(vocabulary) < vocab_size:
        most_common_pair, new_token = count_and_merge(pretoken_counts)
        if not most_common_pair:
            break
        merges.append(most_common_pair)
        vocabulary[new_token] = len(vocabulary)

    # eos_pattern = re.escape(b"<|endoftext|>")
    return vocabulary, merges

In [215]:
data_path = 'data/TinyStoriesV2-GPT4-valid.txt'
vocab, merges = train_bpe(data_path, 300, special_tokens=['<|endoftext|>', '<|startoftext|>'])
print(len(vocab))
print(merges)

300
[(104, 101), (32, 116), (32, 97), (32, 115), (32, 119), ((32, 116), (104, 101)), (110, 100), (101, 100), (32, 98), ((32, 116), 111), ((32, 97), (110, 100)), (32, 104), (32, 84), (105, 110), (114, 101), (32, 102), (105, 116), (111, 117), ((32, 119), 97), (32, 108), (97, 121), (32, 99), (32, 100), (32, 112), (101, 114), ((32, 84), (104, 101)), (105, 115), (32, (104, 101)), (32, 109), (105, 109), ((32, 119, 97), 115), (111, 109), (111, 110), (97, 114), (97, 116), (32, 110), (105, 100), ((32, 115), 97), (32, 103), (32, 83), (105, 108), (111, 116)]
