In [6]:
text = """low low low low low
lower lower widest widest widest <|endoftext|>
newest newest newest newest newest newest
"""

medium_text = """
He said, “Wow, that is a really amazing vase! Can I buy it?”
The shopkeeper smiled and said, “Of course you can. You can take it home and show all your friends how amazing it is!”
So Ben took the vase home and he was so proud of it! He called his friends over and showed them the amazing vase. All his friends thought the vase was beautiful and couldn't believe how lucky Ben was.
And that's how Ben found an amazing vase in the store!
<|endoftext|>
Once upon a time, there was a reliable otter named Ollie. He lived in a river with his family. They all loved to play and swim together.
One day, Ollie's mom said, "Ollie, hurry and get some fish for dinner!" Ollie swam fast to catch fish. He saw his friend, the duck. "Hi, Ollie!" said the duck. "Hi, duck!" said Ollie. "I need to hurry and catch fish for my family."
While Ollie was catching fish, he found a big shiny stone. He thought, "This is not a fish, but it is so pretty!" Ollie took the shiny stone home to show his family. They all looked at the shiny stone and smiled. The shiny stone made everyone happy, and they forgot about the fish for dinner.
<|endoftext|>
One day, a little boy named Tim went to the park. He saw a big tiger. The tiger was not mean, but very easy to play with. Tim and the tiger played all day. They had lots of fun.
Then, something unexpected happened. The tiger started to shake. Tim was scared. He did not know what was going on. But then, the tiger turned into a nice dog. Tim was very surprised.
Tim and the dog played together now. They were very happy. The dog was easy to play with too. At the end of the day, Tim went home with his new friend.
<|endoftext|>
"""

In [7]:
import regex as re
from typing import List
import regex as re


def split_by_special_tokens(text: str, special_tokens: list[str]) -> List[str]:
    """
    Split text by special tokens, preserving the tokens themselves.
    """
    special_tokens_sorted = sorted(special_tokens, key=len, reverse=True)

    if not special_tokens_sorted:
        parts = [text]
    else:
        pattern = "|".join(re.escape(tok) for tok in special_tokens_sorted)
        print(f"Using pattern: {pattern}")
        parts = re.split("(" + pattern + ")", text)

    return parts

In [8]:
split_by_special_tokens(text, ["<|endoftext|>"])

Using pattern: <\|endoftext\|>


['low low low low low\nlower lower widest widest widest ',
 '<|endoftext|>',
 '\nnewest newest newest newest newest newest\n']

In [8]:
from collections import Counter
from typing import Tuple

# 1. Count the number of words in the text
word_counts = Counter(text.split())


def to_tuple(word: str) -> Tuple[bytes, ...]:
    return tuple(bytes([b]) for b in word.encode("utf-8"))


def flatten_bytes(tup: Tuple[bytes, ...]) -> str:
    return "".join(b.decode("utf-8") for b in tup)


def bpe_merge(corpus: dict[Tuple[bytes, ...], int], num_merges: int):
    merges = []
    vocab = {}

    # 1. 初始化 vocab：每个单字节 bytes → ID (0–255)
    idx = 0
    for word in corpus:
        for token in word:
            if token not in vocab.values():
                vocab[idx] = token
                idx += 1
    next_token_id = max(vocab.keys(), default=-1) + 1

    for step in range(num_merges):
        pair_freq = Counter()

        # Count adjacent pairs
        for word, freq in corpus.items():
            for i in range(len(word) - 1):
                pair = (word[i], word[i + 1])
                pair_freq[pair] += freq

        if not pair_freq:
            break

        # Most frequent pair (tie-break lexicographically)
        most_common = pair_freq.most_common()
        max_freq = most_common[0][1]
        top_pairs = [p for p, c in most_common if c == max_freq]
        pair_to_merge = max(top_pairs)

        merges.append(pair_to_merge)

        # Add merged token to vocab
        merged_token = pair_to_merge[0] + pair_to_merge[1]
        vocab[next_token_id] = merged_token
        next_token_id += 1

        # Merge all occurrences
        new_corpus = {}
        for word, freq in corpus.items():
            new_word = []
            i = 0
            while i < len(word):
                if i < len(word) - 1 and (word[i], word[i + 1]) == pair_to_merge:
                    new_word.append(merged_token)
                    i += 2
                else:
                    new_word.append(word[i])
                    i += 1
            new_corpus[tuple(new_word)] = freq

        corpus = new_corpus

        print(f"Step {step + 1}: merge '{flatten_bytes(pair_to_merge)}'")
        print("Corpus:")
        for w, c in corpus.items():
            print(f"  {flatten_bytes(w)}: {c}")
        print()

    return corpus, merges, vocab


corpus = {to_tuple(word): count for word, count in word_counts.items()}

In [11]:
final_vocab

{0: b'l',
 1: b'o',
 2: b'w',
 3: b'e',
 4: b'r',
 5: b'i',
 6: b'd',
 7: b's',
 8: b't',
 9: b'n',
 10: b'st',
 11: b'est',
 12: b'ow',
 13: b'low',
 14: b'west',
 15: b'ne'}

In [9]:
final_corpus, merge_history, final_vocab = bpe_merge(corpus, num_merges=6)

Step 1: merge 'st'
Corpus:
  low: 5
  lower: 2
  widest: 3
  newest: 6

Step 2: merge 'est'
Corpus:
  low: 5
  lower: 2
  widest: 3
  newest: 6

Step 3: merge 'ow'
Corpus:
  low: 5
  lower: 2
  widest: 3
  newest: 6

Step 4: merge 'low'
Corpus:
  low: 5
  lower: 2
  widest: 3
  newest: 6

Step 5: merge 'west'
Corpus:
  low: 5
  lower: 2
  widest: 3
  newest: 6

Step 6: merge 'ne'
Corpus:
  low: 5
  lower: 2
  widest: 3
  newest: 6



In [10]:
print("Final vocab:")
for k, v in final_vocab.items():
    print(f"{k}: {v} → '{v.decode('utf-8', errors='replace')}'")

Final vocab:
0: b'l' → 'l'
1: b'o' → 'o'
2: b'w' → 'w'
3: b'e' → 'e'
4: b'r' → 'r'
5: b'i' → 'i'
6: b'd' → 'd'
7: b's' → 's'
8: b't' → 't'
9: b'n' → 'n'
10: b'st' → 'st'
11: b'est' → 'est'
12: b'ow' → 'ow'
13: b'low' → 'low'
14: b'west' → 'west'
15: b'ne' → 'ne'
