In [7]:
from cs336_basics.pretokenization_example import find_chunk_boundaries
import regex as re

# Input corpus handed to bpe_train() in the driver call further down.
# NOTE(review): this is an absolute path under one user's home directory —
# presumably it needs adjusting before running on another machine.
sample_tiny_path = "/Users/prateekmahadevappahavanur/Documents/GitHub/test_task/assignment1-basics/tests/fixtures/tinystories_sample.txt"


def bpe_train(file_path,vocab_size,special_tokens):
    """Train a byte-level BPE vocabulary on the text file at ``file_path``.

    The file is read in chunks, split on the special tokens (so no merge can
    span one), pre-tokenized with a GPT-2 style regex, and the resulting
    byte sequences are merged greedily by adjacent-pair frequency.

    Args:
        file_path: Path of the training corpus; opened in binary mode and
            decoded as UTF-8 (undecodable bytes are dropped).
        vocab_size: Target vocabulary size, counting the special tokens and
            the 256 single-byte tokens.
        special_tokens: Strings that are added to the vocabulary verbatim and
            never take part in merges. May be empty.

    Returns:
        A ``(vocab, merges)`` tuple. ``vocab`` maps token id -> bytes, with
        special tokens first, then the 256 single bytes, then merged tokens
        in creation order. ``merges`` lists the merged ``(bytes, bytes)``
        pairs in creation order. If the corpus runs out of adjacent pairs
        before ``vocab_size`` is reached, training stops early and the
        vocabulary is smaller than requested.
    """
    PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    pretoken_re = re.compile(PAT)

    # Longest token first: regex alternation takes the first alternative that
    # matches, so a token that is a prefix of another must come after it.
    special_token_pattern = "|".join(
        re.escape(token)
        for token in sorted(special_tokens, key=len, reverse=True)
    )

    # Maps a pre-token (tuple of ids) to how often it occurs in the corpus.
    # Ids below 256 are raw byte values; ids >= 256 are vocab indices of
    # merged tokens (see get_token_bytes).
    word_freqs = {}

    with open(file_path, "rb") as f:
        # NOTE(review): the chunk count and the boundary token are hard-coded
        # here and are independent of `special_tokens` — confirm this is
        # intended for corpora that use a different document separator.
        boundaries = find_chunk_boundaries(
            f, 2, "<|endoftext|>".encode("utf-8"))

        # Serial implementation; each start/end pair could be handed to a
        # separate process instead.
        for start, end in zip(boundaries[:-1], boundaries[1:]):
            f.seek(start)
            chunk = f.read(end - start).decode("utf-8", errors="ignore")
            # With no special tokens the pattern is empty; splitting on an
            # empty pattern would not leave the chunk intact, so skip it.
            if special_token_pattern:
                segments = re.split(special_token_pattern, chunk)
            else:
                segments = [chunk]
            for segment in segments:
                if not segment:
                    continue
                for match in pretoken_re.finditer(segment):
                    token_bytes = tuple(match.group().encode("utf-8"))
                    word_freqs[token_bytes] = word_freqs.get(token_bytes, 0) + 1

    def get_count(mapped_list):
        """Return frequency-weighted counts of every adjacent id pair."""
        count = {}
        for word, freq in mapped_list.items():
            for pair in zip(word, word[1:]):
                count[pair] = count.get(pair, 0) + freq
        return count

    def merge_pair(test_list,pair,replacement):
        """Replace each non-overlapping occurrence of ``pair`` (scanning left
        to right) with ``replacement`` and return the new sequence as a tuple.
        """
        first, second = pair
        last = len(test_list) - 1
        merged = []
        i = 0
        while i <= last:
            if i < last and test_list[i] == first and test_list[i + 1] == second:
                merged.append(replacement)
                i += 2
            else:
                merged.append(test_list[i])
                i += 1
        return tuple(merged)

    def get_token_bytes(token_id, vocab_list):
        """Return the byte string that ``token_id`` stands for."""
        if token_id < 256:
            # Raw byte value. This is NOT a vocab index, because the vocab
            # is offset by the special tokens stored in front of the bytes.
            return bytes([token_id])
        # Merged-token ids are allocated as len(vocab_list), which is always
        # >= 256, so they index the vocabulary directly.
        return vocab_list[token_id]

    # Initialize vocabulary: special tokens first, then all 256 single bytes.
    vocab_list = [special_token.encode("utf-8") for special_token in special_tokens]
    vocab_list.extend(bytes([i]) for i in range(256))

    num_merges = vocab_size - len(special_tokens) - 256

    merges = []
    for _ in range(num_merges):
        pair_counts = get_count(word_freqs)
        if not pair_counts:
            # Every pre-token has collapsed to a single id; nothing to merge.
            break

        # Highest count wins; ties go to the lexicographically greatest pair
        # compared on the tokens' actual byte content, not on their ids.
        best_pair = max(
            pair_counts,
            key=lambda pair: (
                pair_counts[pair],
                (
                    get_token_bytes(pair[0], vocab_list),
                    get_token_bytes(pair[1], vocab_list),
                ),
            ),
        )

        first_bytes = get_token_bytes(best_pair[0], vocab_list)
        second_bytes = get_token_bytes(best_pair[1], vocab_list)

        new_token_id = len(vocab_list)
        vocab_list.append(first_bytes + second_bytes)
        merges.append((first_bytes, second_bytes))

        # Rewrite every pre-token with the new id. Distinct pre-tokens stay
        # distinct because the new id did not appear anywhere before.
        word_freqs = {
            merge_pair(word, best_pair, new_token_id): freq
            for word, freq in word_freqs.items()
        }

    # Build final outputs
    vocab_dict = dict(enumerate(vocab_list))

    return vocab_dict, merges

# Train on the sample corpus with a target vocabulary of 300 entries:
# 1 special token + 256 single bytes leaves room for 43 merges.
vocab_dict,merges_bytes = bpe_train(sample_tiny_path,300,["<|endoftext|>"])





{0: b'<|endoftext|>', 1: b'\x00', 2: b'\x01', 3: b'\x02', 4: b'\x03', 5: b'\x04', 6: b'\x05', 7: b'\x06', 8: b'\x07', 9: b'\x08', 10: b'\t', 11: b'\n', 12: b'\x0b', 13: b'\x0c', 14: b'\r', 15: b'\x0e', 16: b'\x0f', 17: b'\x10', 18: b'\x11', 19: b'\x12', 20: b'\x13', 21: b'\x14', 22: b'\x15', 23: b'\x16', 24: b'\x17', 25: b'\x18', 26: b'\x19', 27: b'\x1a', 28: b'\x1b', 29: b'\x1c', 30: b'\x1d', 31: b'\x1e', 32: b'\x1f', 33: b' ', 34: b'!', 35: b'"', 36: b'#', 37: b'$', 38: b'%', 39: b'&', 40: b"'", 41: b'(', 42: b')', 43: b'*', 44: b'+', 45: b',', 46: b'-', 47: b'.', 48: b'/', 49: b'0', 50: b'1', 51: b'2', 52: b'3', 53: b'4', 54: b'5', 55: b'6', 56: b'7', 57: b'8', 58: b'9', 59: b':', 60: b';', 61: b'<', 62: b'=', 63: b'>', 64: b'?', 65: b'@', 66: b'A', 67: b'B', 68: b'C', 69: b'D', 70: b'E', 71: b'F', 72: b'G', 73: b'H', 74: b'I', 75: b'J', 76: b'K', 77: b'L', 78: b'M', 79: b'N', 80: b'O', 81: b'P', 82: b'Q', 83: b'R', 84: b'S', 85: b'T', 86: b'U', 87: b'V', 88: b'W', 89: b'X', 90: b'Y

In [None]:
from multiprocessing import Pool

def square(x):
    return 