In [2]:
!uv version

cs336-basics [36m1.0.6[39m


In [110]:
from collections import defaultdict
import regex as re
from pprint import pprint

### Unicode

In [10]:
# Code point 97 is the lowercase letter 'a'.
chr(97)

'a'

In [None]:
# A bare expression is echoed through repr(), so the NUL character shows up as an escape.
chr(0)  # displayed via __repr__()

'\x00'

In [None]:
# print() goes through str(), which writes the NUL character itself instead of an escape.
print(chr(0))  # displayed via __str__()

 


In [35]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [32]:
# print() writes the embedded NUL character itself rather than the \x00 escape.
print("this is a test" + chr(0) + "string")

this is a test string


### Unicode Encoding

In [54]:
test_string = "hello! こんにちは!"
utf8_encoded = test_string.encode("utf-8")

# Raw bytes, their integer values, and the container type, one per line.
print(utf8_encoded, list(utf8_encoded), type(utf8_encoded), sep="\n")

# Character count vs. byte count, followed by the round trip back to str.
print(len(test_string), len(utf8_encoded), utf8_encoded.decode("utf-8"), sep="\n")

b'hello! \xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf!'
[104, 101, 108, 108, 111, 33, 32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175, 33]
<class 'bytes'>
13
23
hello! こんにちは!


In [55]:
test_string = "hello"
utf8_encoded = test_string.encode("utf-8")

# The encoded bytes as a bytes literal and as a list of integers.
print(utf8_encoded, list(utf8_encoded), sep="\n")

# Container type, then character count next to byte count.
print(type(utf8_encoded), len(test_string), len(utf8_encoded), sep="\n")

# Decoding restores the original string.
print(utf8_encoded.decode("utf-8"))

b'hello'
[104, 101, 108, 108, 111]
<class 'bytes'>
5
5
hello


In [None]:
# Encode the sample string as UTF-32 and inspect the result.
test_string = "hello! こんにちは!"
utf32_encoded = test_string.encode("utf-32")
print(utf32_encoded)
print(type(utf32_encoded))

print(len(test_string))
print(len(utf32_encoded))
# Round-trip the UTF-32 bytes produced in this cell. Decoding `utf8_encoded`
# here was a copy-paste slip: it depends on a variable from an earlier cell
# and feeds UTF-8 bytes to the UTF-32 codec.
print(utf32_encoded.decode("utf-32"))

b'\xff\xfe\x00\x00h\x00\x00\x00e\x00\x00\x00l\x00\x00\x00l\x00\x00\x00o\x00\x00\x00'
<class 'bytes'>
5
24
hello! こんにちは!


In [50]:
def decode_utf8_bytes_to_str_wrong(bytestring: bytes):
    """Decode ``bytestring`` one byte at a time (intentionally incorrect).

    Every byte is decoded as if it were a complete UTF-8 sequence and the
    resulting pieces are concatenated. This works for ASCII-only input but
    raises ``UnicodeDecodeError`` as soon as a byte belonging to a multi-byte
    character is reached.
    """
    pieces = []
    for value in bytestring:
        single_byte = bytes([value])
        pieces.append(single_byte.decode("utf-8"))
    return "".join(pieces)

# Demonstration: this call is expected to raise UnicodeDecodeError, because the
# first byte of "こ" (0xe3) is not a complete UTF-8 sequence on its own.
decode_utf8_bytes_to_str_wrong("hello! こんにちは!".encode("utf-8"))

UnicodeDecodeError: 'utf-8' codec can't decode byte 0xe3 in position 0: unexpected end of data

### Subword Tokenization

In [1]:
# Pre-tokenization pattern. It uses \p{...} Unicode property classes, so it must
# be compiled with the third-party `regex` module rather than the stdlib `re`.
# Alternatives, tried left to right:
#   '(?:[sdmt]|ll|ve|re)   apostrophe suffixes such as 's, 't, 'll, 've, 're
#    ?\p{L}+               a run of letters, optionally preceded by one space
#    ?\p{N}+               a run of numeric characters, optionally preceded by one space
#    ?[^\s\p{L}\p{N}]+     a run of other non-whitespace characters, optionally preceded by one space
#   \s+(?!\S)              whitespace that is not followed by a non-whitespace character
#   \s+                    any remaining whitespace
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [2]:
import regex as re
# finditer returns a lazy iterator of match objects; nothing is consumed in this cell.
iterator = re.finditer(PAT,  "some text that i'll pre-tokenize")

### Compute BPE Merges (Thought Experiment)

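1. Parallelize pre-tokenization: split the text corpus into chunks at special-token boundaries and process the chunks in parallel.
2. Remove the special tokens.
3. Pre-tokenize: split the text into subwords with the regex and build a subword -> frequency dict.
4. Count the byte-pair frequencies within each subword: (A, B) -> freq.
5. Identify the pair (A, B) with the highest frequency.
6. Incrementally update the pair frequencies and the subwords affected by merging (A, B).
7. Go back to step 4 (now using the updated frequencies) and repeat until the target vocabulary size is reached.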

In [204]:
# Seed the vocabulary with every single-byte value.
vocab = set(range(256))
example_text = "low low low low low lower lower widest widest widest newest newest newest newest newest newest こんにちは こんにちは"

# Count each pre-token, keyed by the tuple of its UTF-8 byte values.
pretokenizers = defaultdict(int)
for match in re.finditer(PAT, example_text):
    pretokenizers[tuple(match.group().encode("utf-8"))] += 1

In [205]:
pprint(pretokenizers)

defaultdict(<class 'int'>,
            {(32, 108, 111, 119): 4,
             (32, 108, 111, 119, 101, 114): 2,
             (32, 110, 101, 119, 101, 115, 116): 6,
             (32, 119, 105, 100, 101, 115, 116): 3,
             (32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175): 2,
             (108, 111, 119): 1})


In [206]:
# Tally adjacent pairs inside each pre-token, weighted by how often the
# pre-token occurs. Pairs are never formed across pre-token boundaries.
freq_dict = defaultdict(int)
for pretoken, cnt in pretokenizers.items():
    for pair in zip(pretoken, pretoken[1:]):
        freq_dict[pair] += cnt

In [207]:
freq_dict

defaultdict(int,
            {(108, 111): 7,
             (111, 119): 7,
             (32, 108): 6,
             (119, 101): 8,
             (101, 114): 2,
             (32, 119): 3,
             (119, 105): 3,
             (105, 100): 3,
             (100, 101): 3,
             (101, 115): 9,
             (115, 116): 9,
             (32, 110): 6,
             (110, 101): 6,
             (101, 119): 6,
             (32, 227): 2,
             (227, 129): 8,
             (129, 147): 2,
             (147, 227): 4,
             (227, 130): 2,
             (130, 147): 2,
             (129, 171): 2,
             (171, 227): 2,
             (129, 161): 2,
             (161, 227): 2,
             (129, 175): 2})

In [208]:
def update_pretoken(t: tuple, pattern: tuple) -> tuple:
    """Replace every occurrence of ``pattern`` in ``t`` with one merged element.

    The scan runs left to right and never reuses an element, so overlapping
    candidates are resolved greedily: merging ``(1, 1)`` in ``(1, 1, 1)``
    gives ``((1, 1), 1)``.

    Args:
        t: Tuple of tokens making up one pre-token.
        pattern: Contiguous run of tokens to merge. It is inserted into the
            result as a single nested tuple element.

    Returns:
        A new tuple in which each match is replaced by ``pattern`` itself;
        all other elements are copied unchanged.

    Raises:
        ValueError: If ``pattern`` is empty. An empty pattern matches at every
            position without advancing, so the scan would never terminate.
    """
    n = len(pattern)
    if n == 0:
        raise ValueError("pattern must contain at least one token")

    result = []
    i = 0
    while i < len(t):
        if t[i:i+n] == pattern:
            # Consume the whole match and emit it as one element.
            result.append(pattern)
            i += n
        else:
            result.append(t[i])
            i += 1

    return tuple(result)

In [209]:
# Pick the most frequent pair; equal counts are resolved in favor of the
# larger pair under tuple comparison.
new_token = max(freq_dict, key=lambda pair: (freq_dict[pair], pair))
print(new_token)

# Register the merged pair as a vocabulary entry and retire its own count.
vocab.add(new_token)
del freq_dict[new_token]


(115, 116)


In [210]:
# update
# Re-derive the pair counts of every pre-token that contains `new_token`.
# Diffing the full list of adjacent pairs before and after the merge keeps the
# counts right even when the pair occurs back-to-back (A, B, A, B) or overlaps
# itself (A, A, A); patching only the neighbors of each occurrence credits
# pairs such as (B, AB) that do not exist in the rewritten pre-token.
pretoken_changes = defaultdict(int)
for pretoken, cnt in pretokenizers.items():
    old_pairs = list(zip(pretoken, pretoken[1:]))
    if new_token not in old_pairs:
        continue

    new_pretoken = update_pretoken(pretoken, new_token)

    changes = defaultdict(int)
    for pair in old_pairs:
        # The selected pair's own entry was deleted from freq_dict when it was
        # chosen, so it must not be decremented again here.
        if pair != new_token:
            changes[pair] -= cnt
    for pair in zip(new_pretoken, new_pretoken[1:]):
        changes[pair] += cnt

    for k, v in changes.items():
        # Pairs untouched by the merge cancel out to zero; skip them so no
        # spurious keys are created in the defaultdict.
        if v != 0:
            freq_dict[k] += v
    pretoken_changes[(pretoken, new_pretoken)] = cnt

# Swap the rewritten pre-tokens in only after the scan, so the dict is not
# modified while it is being iterated.
for change, cnt in pretoken_changes.items():
    pt = change[0]
    npt = change[1]
    del pretokenizers[pt]
    pretokenizers[npt] = cnt



In [211]:
pprint(pretokenizers)

defaultdict(<class 'int'>,
            {(32, 108, 111, 119): 4,
             (32, 108, 111, 119, 101, 114): 2,
             (32, 110, 101, 119, 101, (115, 116)): 6,
             (32, 119, 105, 100, 101, (115, 116)): 3,
             (32, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175): 2,
             (108, 111, 119): 1})


In [212]:
pprint(freq_dict)

defaultdict(<class 'int'>,
            {(32, 108): 6,
             (32, 110): 6,
             (32, 119): 3,
             (32, 227): 2,
             (100, 101): 3,
             (101, 114): 2,
             (101, (115, 116)): 9,
             (101, 115): 0,
             (101, 119): 6,
             (105, 100): 3,
             (108, 111): 7,
             (110, 101): 6,
             (111, 119): 7,
             (119, 101): 8,
             (119, 105): 3,
             (129, 147): 2,
             (129, 161): 2,
             (129, 171): 2,
             (129, 175): 2,
             (130, 147): 2,
             (147, 227): 4,
             (161, 227): 2,
             (171, 227): 2,
             (227, 129): 8,
             (227, 130): 2})


In [213]:
pprint(vocab)

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63,
 64,
 65,
 66,
 67,
 68,
 69,
 70,
 71,
 72,
 73,
 74,
 75,
 76,
 77,
 78,
 79,
 80,
 81,
 82,
 83,
 84,
 85,
 86,
 87,
 88,
 89,
 90,
 91,
 92,
 93,
 94,
 95,
 96,
 97,
 98,
 99,
 100,
 101,
 102,
 103,
 104,
 105,
 106,
 107,
 108,
 109,
 110,
 111,
 112,
 113,
 114,
 115,
 116,
 117,
 118,
 119,
 120,
 121,
 122,
 123,
 124,
 125,
 126,
 127,
 128,
 129,
 130,
 131,
 132,
 133,
 134,
 135,
 136,
 137,
 138,
 139,
 140,
 141,
 142,
 143,
 144,
 145,
 146,
 147,
 148,
 149,
 150,
 151,
 152,
 153,
 154,
 155,
 156,
 157,
 158,
 159,
 160,
 161,
 162,
 163,
 164,
 165,
 166,
 167,
 168,
 169,
 170,
 171,
 172,
 173,
 174,
 175,
 176,
 177,
 178,
 179,
 180,
 181,
 182,
 183,
 184,


### Complete BPE Implementation

In [328]:
from collections import defaultdict
import regex as re
from pprint import pprint

def train_bpe(text: str, target_vocab_size: int, debug = False):
    """
    Trains a byte-pair encoding (BPE) tokenizer on `text` until `target_vocab_size` is reached.

    The text is split into pre-tokens with a regex and pairs are only counted
    inside a pre-token, so merges never cross pre-token boundaries. Each round
    merges the most frequent adjacent pair; equal counts are resolved in favor
    of the lexicographically greater pair. Training stops early (with a
    printed notice) when no pairs are left to merge.

    Args:
        text (str): Training corpus.
        target_vocab_size (int): Desired vocabulary size, counting the 256
            single-byte tokens the vocabulary starts with.
        debug (bool): When true, print the pre-token and pair tables as
            training progresses.

    Returns:
        vocab (dict[int, bytes]): Final vocabulary mapping token IDs to byte tokens.
        merges (list[tuple[bytes, bytes]]): List of BPE merges in creation order.
    """
    pattern = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
    vocab = [bytes([i]) for i in range(256)]

    def merge_pretoken(tokens: tuple, pair: tuple) -> tuple:
        """Greedily replace each left-to-right occurrence of `pair` with its concatenation."""
        merged = b''.join(pair)
        result, i, n = [], 0, len(pair)
        while i < len(tokens):
            if tokens[i:i+n] == pair:
                result.append(merged)
                i += n
            else:
                result.append(tokens[i])
                i += 1
        return tuple(result)

    def adjacent_pairs(tokens: tuple):
        """Yield every (left, right) pair of neighboring tokens."""
        return zip(tokens, tokens[1:])

    # initial pretokenization: each pre-token becomes a tuple of single-byte tokens
    token_counts = defaultdict(int)
    for m in re.finditer(pattern, text):
        b = m.group().encode("utf-8")
        token_counts[tuple(bytes([x]) for x in b)] += 1

    if debug:
        print("Initial pretokens:")
        pprint(token_counts)

    # initial pair frequencies, weighted by pre-token count
    pair_freqs = defaultdict(int)
    for token, count in token_counts.items():
        for pair in adjacent_pairs(token):
            pair_freqs[pair] += count
    if debug:
        print("\nInitial pair frequencies:")
        pprint(pair_freqs)

    merges = []
    while len(vocab) < target_vocab_size:
        if not pair_freqs:
            print("No more pairs to merge")
            break

        new_pair = max(pair_freqs.items(), key=lambda x: (x[1], x[0]))[0]
        left, right = new_pair
        merged_token = left + right
        if debug:
            print(f"\nMerging: {new_pair} -> {merged_token}")
        merges.append(new_pair)
        vocab.append(merged_token)

        pretoken_updates = {}
        for token, count in token_counts.items():
            if new_pair not in adjacent_pairs(token):
                continue

            new_token = merge_pretoken(token, new_pair)

            # Diff the pre-token's pairs before and after the merge instead of
            # patching the neighbors of each occurrence. Neighbor patching
            # miscounts when the pair repeats back-to-back (a, b, a, b) or
            # overlaps itself (a, a, a), because the "neighbor" is then part
            # of another occurrence that is merged as well.
            delta = defaultdict(int)
            for pair in adjacent_pairs(token):
                delta[pair] -= count
            for pair in adjacent_pairs(new_token):
                delta[pair] += count

            for pair, change in delta.items():
                if change == 0:
                    continue
                pair_freqs[pair] += change
                if pair_freqs[pair] <= 0:
                    del pair_freqs[pair]
            pretoken_updates[token] = (new_token, count)

        # Every occurrence of the merged pair has been subtracted above, so its
        # entry is normally gone already; drop it defensively in case it is not.
        pair_freqs.pop(new_pair, None)

        # Apply the rewrites after the scan so the dict is not modified while
        # it is being iterated.
        for old, (new, c) in pretoken_updates.items():
            del token_counts[old]
            token_counts[new] = c
        if debug:
            print(token_counts)

    vocab = {i: token for i, token in enumerate(vocab)}
    return vocab, merges


In [329]:
# Train on the toy corpus, asking for six merges on top of the 256 single-byte tokens.
example_text = (
    "low low low low low lower lower "
    "widest widest widest "
    "newest newest newest newest newest newest"
)
vocab, merges = train_bpe(example_text, target_vocab_size=256 + 6)

In [327]:
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [330]:
merges

[(b's', b't'),
 (b'e', b'st'),
 (b'o', b'w'),
 (b'l', b'ow'),
 (b'w', b'est'),
 (b'n', b'e')]