In [2]:
# text = open("tests/fixtures/corpus.en").read()
text = open("../data/TinyStoriesV2-GPT4-valid.txt").read()

In [3]:
import regex as re
from collections import Counter

In [4]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [5]:
tuple([bytes([o]) for o in list("test string".encode("utf-8"))])

(b't', b'e', b's', b't', b' ', b's', b't', b'r', b'i', b'n', b'g')

In [6]:
def get_word_count(text, PAT):
    word_count = Counter()
    for word in re.finditer(PAT, text, re.IGNORECASE):
        word = word.captures()[0]
        word_count[
            tuple([bytes([o]) for o in list(word.encode("utf-8"))])
        ] += 1

    return word_count

In [7]:
def get_pair_count(word_count):
    pair_count = Counter()
    for word, freq in word_count.items(): 
        for p in zip(word, word[1:]):
            pair_count[p] += freq

    return pair_count

In [8]:
def merge(word_count, pair_count, top_pair):
    new_word_count = Counter()
    for word, freq in word_count.items(): # word_count -> dict[tuple[bytes], int]
        new_word = ()
        i = 0
        idxs = []

        new_tok: bytes = top_pair[0] + top_pair[1]

        while i+1 < len(word):
            if (word[i], word[i+1]) == top_pair:
                idxs.append(i)
                new_word += (new_tok, )
                pair_count[top_pair] -= freq # decrement top_pair's count
                if i-1 >= 0:
                    pair_count[(word[i-1], word[i])] -= freq # decrement left neighbour
                if i+2 < len(word):
                    pair_count[(word[i+1], word[i+2])] -= freq # decrement right neighbour
                i += 2
            else:
                new_word += (word[i], )
                i += 1
        if i < len(word): new_word += (word[i], )

        if idxs:
            # create and insert the new tuple
            new_word_count[new_word] = freq
            # add new neighbours' counts
            for idx in idxs:
                if idx-1 >= 0:
                    pair_count[(word[idx-1], new_tok)] += freq # left neighbour
                if idx+2 < len(word):
                    pair_count[(new_tok, word[idx+2])] += freq # right neighbour
        else:
            new_word_count[new_word] = freq

    return new_word_count

In [9]:
def merge(word_count, pair_count, top_pair):
    new_word_count = Counter()
    for word, freq in word_count.items(): # word_count -> dict[tuple[bytes], int]
        new_word = ()
        idxs = []
        new_tok: bytes = top_pair[0] + top_pair[1]

        i = 0
        pairs_to_decrement = set()
        while i+1 < len(word):
            if (word[i], word[i+1]) == top_pair:
                idxs.append(i)
                new_word += (new_tok, )
                if i > 0          : pairs_to_decrement.add((word[i-1], word[i]))
                if i+2 < len(word): pairs_to_decrement.add((word[i+1], word[i+2]))
                i += 2
            else:
                new_word += (word[i], )
                i += 1
        if i < len(word): new_word += (word[i], )

        for pair in pairs_to_decrement:
            if pair != top_pair:
                pair_count[pair] -= freq

        # create and insert the new tuple
        new_word_count[new_word] = freq
        if idxs:
            # add new neighbours' counts
            for (p1, p2) in zip(new_word, new_word[1:]):
                if p1 == new_tok or p2 == new_tok:
                    pair_count[(p1, p2)] += freq

    return new_word_count

In [10]:
special_tokens = ["<|endoftext|>"]
escaped_tokens = [re.escape(o) for o in special_tokens]
special_tok_pat = "|".join(escaped_tokens)
special_tok_pat

'<\\|endoftext\\|>'

In [11]:
# text = """low low low low low
# lower lower widest widest widest
# newest newest newest newest newest newest"""

In [12]:
word_count = get_word_count(text, PAT)
pair_count = get_pair_count(word_count)
merges = []
vocab = {idx: bytes([idx]) for idx in range(256)}

vocab_sz = 500
num_merges = vocab_sz - 256
for i in range(num_merges):
    new_byte_idx = 256 + i
    top_pair = max(pair_count, key=pair_count.get)
    vocab[new_byte_idx] = top_pair[0] + top_pair[1]
    del pair_count[top_pair]
    word_count = merge(word_count, pair_count, top_pair)
    merges.append(top_pair)

    mismatch_found = False
    for k_check, v_expected in get_pair_count(word_count).items():
        v_actual = pair_count.get(k_check, 0)
        if v_actual != v_expected:
            print(f"{i} MISMATCH for pair {k_check}: actual {v_actual}, expected {v_expected}")
            mismatch_found = True
    if mismatch_found:
        break
    # if pair_count[top_pair] < 0: print(pair_count[top_pair], top_pair)

0
0
0
0
0
0
5 MISMATCH for pair (b'a', b'n'): actual 98628, expected 98626
5 MISMATCH for pair (b'u', b'n'): actual 41986, expected 41937
5 MISMATCH for pair (b'd', b'e'): actual 52062, expected 51860
5 MISMATCH for pair (b'd', b's'): actual 7263, expected 7261


In [50]:
# sanity check
for k,v in get_pair_count(word_count).items():
    if v != pair_count[k]:
        print(f"k: {k}; actual v: {v}; pair_count v: {pair_count[k]}")

k: (b'ig', b'h'); actual v: 3084; pair_count v: 3085
k: (b'a', b'c'); actual v: 3130; pair_count v: 3182
k: (b't', b'i'); actual v: 1023; pair_count v: 1194
k: (b'n', b'g'); actual v: 1573; pair_count v: 2326
k: (b'd', b's'); actual v: 1139; pair_count v: 1141
k: (b'g', b'h'); actual v: 2348; pair_count v: 2349
k: (b'm', b'a'); actual v: 2604; pair_count v: 2605
k: (b'b', b'r'); actual v: 522; pair_count v: 525
k: (b'd', b'e'); actual v: 795; pair_count v: 997
k: (b'n', b'a'); actual v: 4180; pair_count v: 4557
k: (b'd', b'a'); actual v: 10; pair_count v: 11
k: (b'r', b'd'); actual v: 129; pair_count v: 270
k: (b'e', b'i'); actual v: 2; pair_count v: 3


In [35]:
t = """u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.
<|endoftext|>
Once upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.
Tom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my ball fell into the pit."
Sam and Tom went close to the pit. They were scared, but they wanted to find the red ball. They looked into the pit, but it was too dark to see. Tom said, "We must go in and search for my ball."
They went into the pit to search. It was dark and scary. They could not find the ball. They tried to get out, but the pit was too deep. Tom and Sam were stuck in the pit. They called for help, but no one could hear them. They were sad and scared, and they never got out of the pit.
<|endoftext|>


Tom and Lily were playing with their toys in the living room. They liked to build towers and bridges with their blocks and cars. Tom was very proud of his tall tower. He wanted to make it even taller, so he reached for more blocks.
"Tom, can I have some blocks too?" Lily asked. She wanted to make a bridge for her cars.
"No, these are mine. Go find your own," Tom said. He did not want to share with his sister. He pulled the blocks closer to him.
Lily felt sad and angry. She did not think Tom was being nice. She looked at his tower and had an idea. She decided to pull one of the blocks at the bottom of the tower.
Suddenly, the tower fell down with a loud crash. All the blocks and cars scattered on the floor. Tom and Lily were shocked. They felt the floor shake and heard a rumble. It was an earthquake!
"Mommy! Daddy!" they cried. They were scared and ran to their parents, who were in the kitchen.
"Are you okay, kids?" Mommy asked. She hugged them and checked if they were hurt.
"We're okay, Mommy. But our toys are broken," Lily said.
"I'm sorry, Lily. But toys are not important. You are important. We are safe and together. That's what matters," Mommy said.
Tom felt sorry for what he did. He realized he was selfish and mean to his sister. He saw how scared she was during the earthquake. He wanted to make her happy.
"Lily, I'm sorry I did not share with you. You can have all the blocks you want. I love you, sister," Tom said.
Lily smiled and hugged him. She forgave him and thanked him. She loved him too.
They went back to the living room and cleaned up their toys. They decided to build something together. They made a big house with a garden and a fence. They put their cars and dolls inside. They were happy and proud of their work.
Mommy and Daddy came to see their house. They praised them and gave them a treat. It was a lemon cake. It was sour, but they liked it. They learned that sharing is caring, and that family is sweet.
<|endoftext|>"""

In [36]:
for w in re.splititer(special_tok_pat, t):
    print(w)
    print("---")

u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little girl. She was very kind and the mole soon came to trust her. He leaned against her and she kept him safe. The mole had found his best friend.

---

Once upon a time, in a warm and sunny place, there was a big pit. A little boy named Tom liked to play near the pit. One day, Tom lost his red ball. He was very sad.
Tom asked his friend, Sam, to help him search for the ball. They looked high and low, but they could not find the ball. Tom said, "I think my ball fell into the pit."
Sam and Tom went close to the pit. They were scared, but they wanted to find the red ball. They looked into the pit, but it was too dark to see. Tom said, "We must go in and search for my ball."
They went into the pit to search. It was dark and scary. They could not find the ball. They tried to get out, but the pit was too deep. Tom and Sam were stuck in the pit. They called for help, but no one could hear them. The

In [37]:
import os
from typing import BinaryIO

def find_chunk_boundaries(
    file: BinaryIO, 
    desired_num_chunks: int, 
    split_special_token: bytes
) -> list[int]:
    """
    Chunk the file into parts that can be counted independently.
    May return fewer chunks if the boundaries end up overlapping.
    """
    assert isinstance(split_special_token, bytes), (
        "Must represent special token as a bytestring"
    )

    # Get total file size in bytes
    file.seek(0, os.SEEK_END)
    file_size = file.tell()
    file.seek(0)

    chunk_size = file_size // desired_num_chunks

    # Initial guesses for chunk boundary locations, uniformly spaced
    # Chunks start on previous index, don't include last index
    chunk_boundaries = [i * chunk_size for i in range(desired_num_chunks + 1)]
    chunk_boundaries[-1] = file_size

    mini_chunk_size = 4096  # Read ahead by 4k bytes at a time

    for bi in range(1, len(chunk_boundaries) - 1):
        initial_position = chunk_boundaries[bi]
        file.seek(initial_position)  # Start at boundary guess
        while True:
            mini_chunk = file.read(mini_chunk_size)  # Read a mini chunk

            # If EOF, this boundary should be at the end of the file
            if mini_chunk == b"":
                chunk_boundaries[bi] = file_size
                break

            # Find the special token in the mini chunk
            found_at = mini_chunk.find(split_special_token)
            if found_at != -1:
                chunk_boundaries[bi] = initial_position + found_at
                break
            initial_position += mini_chunk_size

    # Make sure all boundaries are unique, but might be fewer than desired_num_chunks
    return sorted(set(chunk_boundaries))

In [38]:
find_chunk_boundaries(open("../data/tiny_test.txt", "rb"), 20, "<|endoftext|>".encode("utf-8"))

[0,
 1066,
 3035,
 4075,
 4742,
 6584,
 7560,
 9498,
 10895,
 11739,
 12410,
 13340,
 14054,
 15759,
 17099,
 17737,
 18519]

In [39]:
with open("../data/tiny_test.txt", "rb") as f:
    boundaries = find_chunk_boundaries(f, 6, "<|endoftext|>".encode("utf-8"))
    f.seek(boundaries[1])
    chunk = f.read(boundaries[2] - boundaries[1]).decode("utf-8", errors="ignore")

In [40]:
print(chunk)

<|endoftext|>
One morning, a cat named Tom woke up. He felt happy because the sun was shining. Tom wanted to start his day, so he did a big stretch. He stretched his legs, his back, and his tail. It felt easy and good.
Tom went outside to play. He saw his friend, a dog named Max. Max was also stretching in the morning sun. They both felt very happy. They decided to play together and have fun all day.
At the end of the day, Tom and Max were tired. They had played all day and had lots of fun. They said goodbye to each other and went to their homes. Before going to sleep, they both did another easy stretch. Tom knew that tomorrow would be another happy morning.
<|endoftext|>


Lily and Tom were twins who liked to decorate things. They had a big box of crayons, stickers, and glitter. One day, they found a shiny copper pot in the kitchen. It was Mom's pot, but she was not home. Lily and Tom wanted to make it more pretty.
They took the pot to their room and put it on the floor. They opened t

In [41]:
import multiprocessing as mp

In [42]:
def foo(start, end, f, q):
    q.put(f"processing {start} - {end} bytes in {f.name}")

In [43]:
procs = []
res_q = mp.Queue()
results = []
with open("../data/tiny_test.txt", "rb") as f:
    boundaries = find_chunk_boundaries(f, 6, b"<|endoftext|>")

    for start, end in zip(boundaries[:-1], boundaries[1:]):
        proc = mp.Process(target=foo, args=(start, end, f, res_q))
        procs.append(proc)
        proc.start()

    for _ in range(6):
        results.append(res_q.get())
    
    for proc in procs:
        proc.join()

In [44]:
results

['processing 0 - 4075 bytes in ../data/tiny_test.txt',
 'processing 4075 - 6584 bytes in ../data/tiny_test.txt',
 'processing 6584 - 9498 bytes in ../data/tiny_test.txt',
 'processing 9498 - 12410 bytes in ../data/tiny_test.txt',
 'processing 12410 - 15759 bytes in ../data/tiny_test.txt',
 'processing 15759 - 18519 bytes in ../data/tiny_test.txt']

In [6]:
from cs336_basics.train_bpe import train_bpe

In [2]:
v, m, w, p = train_bpe("tests/fixtures/corpus.en", 500, ["<|endoftext|>"])

3 MISMATCH for pair (b'n', b'i'): actual 100, expected 112


In [10]:
# sanity check
for k,v in get_pair_count(w).items():
    if v != p[k]:
        print(f"k: {k}; actual v: {v}; pair_count v: {p[k]}")

k: (b's', b'a'); actual v: 3; pair_count v: 2
k: (b'g', b'in'); actual v: 39; pair_count v: 36
k: (b'n', b'i'); actual v: 14; pair_count v: 2
k: (b's', b'i'); actual v: 26; pair_count v: 24
k: (b'd', b'e'); actual v: 17; pair_count v: 16
k: (b'l', b'i'); actual v: 27; pair_count v: 26


In [7]:
v, m, w, p = train_bpe("../data/TinyStoriesV2-GPT4-valid.txt", 500, ["<|endoftext|>"])

7 MISMATCH for pair (b'd', b'e'): actual 42624, expected 43881


In [12]:
# sanity check
for k,v in get_pair_count(w).items():
    if v != p[k]:
        print(f"k: {k}; actual v: {v}; pair_count v: {p[k]}")

k: (b'd', b'e'); actual v: 43881; pair_count v: 42624


In [13]:
v, m, w, p = train_bpe("../data/tiny_test.txt", 500, ["<|endoftext|>"])

7 MISMATCH for pair (b'd', b'e'): actual 35, expected 36
