In [1]:
from collections import Counter, defaultdict
import regex as re

In [2]:
PAT = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""

In [3]:
text = open("../data/tiny_test.txt").read()
print(text[:100])

u don't have to be scared of the loud dog, I'll protect you". The mole felt so safe with the little 


In [4]:
vocab = Counter()
for word in re.finditer(PAT, text):
    word = word.captures()[0]
    byte_chars = tuple([bytes([o]) for o in list(word.encode("utf-8"))])
    vocab[byte_chars] += 1

In [5]:
def get_vocab(
    corpus: str
) -> Counter:
    vocab = Counter()
    for word in re.finditer(PAT, corpus):
        word = word.captures()[0]
        byte_chars = tuple([bytes([o]) for o in list(word.encode("utf-8"))])
        vocab[byte_chars] += 1
    return vocab

In [6]:
vocab = get_vocab(text)
sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1], reverse=True)

In [7]:
pair_stats, pair_indices = Counter(), defaultdict(Counter)
for i, (word, freq) in enumerate(vocab.items()):
    for p in zip(word, word[1:]):
        pair_stats[p] += freq
        pair_indices[p][i] += 1

In [8]:
def get_pair_stats(
    vocab: Counter
) -> tuple[Counter, defaultdict[tuple, Counter]]:
    pair_stats, pair_indices = Counter(), defaultdict(Counter)
    for i, (word, freq) in enumerate(vocab.items()):
        for p in zip(word, word[1:]):
            pair_stats[p] += freq
            pair_indices[p][i] += 1
        
    return pair_stats, pair_indices

In [9]:
pair_stats, pair_indices = get_pair_stats(vocab)

In [10]:
top_pair = max(pair_stats, key=pair_stats.get)
top_pair

(b' ', b't')

In [53]:
def replace_pair(
    pair: tuple[bytes, bytes],
    vocab: list[tuple[tuple, int]],
    indices: defaultdict[tuple, Counter]
) -> list[tuple[int, tuple, tuple, int]]:

    changes = []
    for j, freq in indices[pair].items():
        if freq < 1:
            continue
        word, freq = vocab[j]
        first, second = pair
        new_tok = pair[0] + pair[1]

        i = 0
        new_word = ()
        is_changed = False
        while i+1 < len(word):
            if word[i] == first and word[i+1] == second:
                new_word += (new_tok, )
                i += 2
                is_changed = True
            else:
                new_word += (word[i], )
                i += 1
        if i < len(word): new_word += (word[i], )

        # if is_changed:
        #     sorted_vocab[j] = (new_word, freq)
        #     changes.append((j, new_word, word, freq))
        vocab[j] = (new_word, freq)
        changes.append((j, new_word, word, freq))

    
    return changes

In [12]:
changes = replace_pair(top_pair, sorted_vocab, pair_indices)
changes

[(4, (b' t', b'o'), (b' ', b't', b'o'), 131),
 (324,
  (b' t', b'r', b'e', b'a', b't'),
  (b' ', b't', b'r', b'e', b'a', b't'),
  2),
 (358, (b' t', b'u', b'b'), (b' ', b't', b'u', b'b'), 2),
 (455,
  (b' t', b'h', b'e', b's', b'e'),
  (b' ', b't', b'h', b'e', b's', b'e'),
  1),
 (556,
  (b' t', b'o', b'u', b'c', b'h'),
  (b' ', b't', b'o', b'u', b'c', b'h'),
  1),
 (586,
  (b' t', b'o', b'w', b'e', b'l'),
  (b' ', b't', b'o', b'w', b'e', b'l'),
  1)]

In [13]:
len(changes)

6

In [54]:
def update_pair_stats(
    pair: tuple[bytes, bytes],
    changes: list[tuple[int, tuple, tuple, int]],
    stats: Counter,
    indices: defaultdict[tuple, Counter]
):
    del stats[pair]
    del indices[pair]

    first, second = pair
    new_tok = first + second

    for j, word, old_word, freq in changes:

        i = 0
        while True:
            try:
                i = old_word.index(first, i)
            except ValueError:
                break

            if i+1 < len(old_word) and old_word[i+1] == second:
                # decrement left neighbour if present
                if i > 0:
                    prev = old_word[i-1:i+1]
                    stats[prev] -= freq
                    indices[prev][j] -= 1

                # decrement right neighbour
                if i+2 < len(old_word):
                    # walking on some thin indexing ice right here
                    # skip if the sequence is A B C B C and top_pair is B C, because the frequency of "C B" will be reduced by the previous code block
                    if i < len(old_word)-3:
                        if old_word[i+2] != first or old_word[i+3] != second:
                            next_ = old_word[i+1:i+3]
                            stats[next_] -= freq
                            indices[next_][j] -= 1

                i += 2
            else:
                i += 1
        
        i = 0
        while True:
            try:
                i = word.index(new_tok, i)
            except ValueError:
                break
                
            if i > 0:
                prev = word[i-1:i+1]
                stats[prev] += freq
                indices[prev][j] += 1
            
            if i+1 < len(word) and word[i+1] != new_tok:
                next_ = word[i:i+2]
                stats[next_] += freq
                indices[next_][j] += 1
            i += 1

In [15]:
update_pair_stats(top_pair, changes, pair_stats, pair_indices)

In [16]:

# # sanity check
# for k,v in get_pair_count(word_count).items():
#     if v != pair_count[k]:
#         print(f"k: {k}; actual v: {v}; pair_count v: {pair_count[k]}")

In [55]:
vocab_sz = 1000
num_merges = vocab_sz - 256

text = open("../data/TinyStoriesV2-GPT4-valid.txt").read()
# text = open("../data/tiny_test.txt").read()
vocab = get_vocab(text)
sorted_vocab = sorted(list(vocab.items()), key=lambda x: x[1], reverse=True)
pair_stats, pair_indices = get_pair_stats(vocab)
merges = []

for _ in range(num_merges):
    top_pair = max(pair_stats, key=pair_stats.get)
    merges.append(top_pair)
    changes = replace_pair(top_pair, sorted_vocab, pair_indices)
    update_pair_stats(top_pair, changes, pair_stats, pair_indices)

In [58]:
tv = Counter()
for b,c in sorted_vocab:
    tv[b] += c

In [59]:
# sanity check
for k,v in get_pair_stats(tv)[0].items():
    if v != pair_stats[k]:
        print(f"k: {k}; actual v: {v}; pair_count v: {pair_stats[k]}")

k: (b' ', b't'); actual v: 612144; pair_count v: -25109
k: (b't', b'h'); actual v: 430879; pair_count v: -7158
k: (b'h', b'e'); actual v: 563857; pair_count v: -11055
k: (b' ', b'a'); actual v: 475712; pair_count v: -1550
k: (b'a', b'n'); actual v: 317667; pair_count v: -2433
k: (b'n', b'd'); actual v: 308163; pair_count v: -411
k: (b't', b'o'); actual v: 223512; pair_count v: -6604
k: (b' ', b'w'); actual v: 306737; pair_count v: -11291
k: (b'w', b'a'); actual v: 176392; pair_count v: -1671
k: (b'a', b's'); actual v: 145017; pair_count v: -515
k: (b' ', b'i'); actual v: 130641; pair_count v: -1505
k: (b'i', b't'); actual v: 145957; pair_count v: -1614
k: (b' ', b'H'); actual v: 56047; pair_count v: 0
k: (b'H', b'e'); actual v: 56277; pair_count v: 0
k: (b' ', b'"'); actual v: 47787; pair_count v: 0
k: (b' ', b'T'); actual v: 114428; pair_count v: -52425
k: (b'T', b'h'); actual v: 78523; pair_count v: 0
k: (b' ', b's'); actual v: 285262; pair_count v: -19050
k: (b's', b'a'); actual v: 

In [1]:
from tests.common import gpt2_bytes_to_unicode

In [6]:
dict(sorted(list(gpt2_bytes_to_unicode().items())))

{0: 'Ā',
 1: 'ā',
 2: 'Ă',
 3: 'ă',
 4: 'Ą',
 5: 'ą',
 6: 'Ć',
 7: 'ć',
 8: 'Ĉ',
 9: 'ĉ',
 10: 'Ċ',
 11: 'ċ',
 12: 'Č',
 13: 'č',
 14: 'Ď',
 15: 'ď',
 16: 'Đ',
 17: 'đ',
 18: 'Ē',
 19: 'ē',
 20: 'Ĕ',
 21: 'ĕ',
 22: 'Ė',
 23: 'ė',
 24: 'Ę',
 25: 'ę',
 26: 'Ě',
 27: 'ě',
 28: 'Ĝ',
 29: 'ĝ',
 30: 'Ğ',
 31: 'ğ',
 32: 'Ġ',
 33: '!',
 34: '"',
 35: '#',
 36: '$',
 37: '%',
 38: '&',
 39: "'",
 40: '(',
 41: ')',
 42: '*',
 43: '+',
 44: ',',
 45: '-',
 46: '.',
 47: '/',
 48: '0',
 49: '1',
 50: '2',
 51: '3',
 52: '4',
 53: '5',
 54: '6',
 55: '7',
 56: '8',
 57: '9',
 58: ':',
 59: ';',
 60: '<',
 61: '=',
 62: '>',
 63: '?',
 64: '@',
 65: 'A',
 66: 'B',
 67: 'C',
 68: 'D',
 69: 'E',
 70: 'F',
 71: 'G',
 72: 'H',
 73: 'I',
 74: 'J',
 75: 'K',
 76: 'L',
 77: 'M',
 78: 'N',
 79: 'O',
 80: 'P',
 81: 'Q',
 82: 'R',
 83: 'S',
 84: 'T',
 85: 'U',
 86: 'V',
 87: 'W',
 88: 'X',
 89: 'Y',
 90: 'Z',
 91: '[',
 92: '\\',
 93: ']',
 94: '^',
 95: '_',
 96: '`',
 97: 'a',
 98: 'b',
 99: 'c',
 100: 'd'

In [12]:
[(o[0], bytes([t for t in o[1]])) for o in list(gpt2_bytes_to_unicode().items())]

TypeError: 'str' object cannot be interpreted as an integer

In [1]:
from tests.test_train_bpe import test_train_bpe

In [2]:
v, r = test_train_bpe()

In [3]:
v

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [4]:
r

{0: b'<|endoftext|>',
 1: b'!',
 2: b'"',
 3: b'#',
 4: b'$',
 5: b'%',
 6: b'&',
 7: b"'",
 8: b'(',
 9: b')',
 10: b'*',
 11: b'+',
 12: b',',
 13: b'-',
 14: b'.',
 15: b'/',
 16: b'0',
 17: b'1',
 18: b'2',
 19: b'3',
 20: b'4',
 21: b'5',
 22: b'6',
 23: b'7',
 24: b'8',
 25: b'9',
 26: b':',
 27: b';',
 28: b'<',
 29: b'=',
 30: b'>',
 31: b'?',
 32: b'@',
 33: b'A',
 34: b'B',
 35: b'C',
 36: b'D',
 37: b'E',
 38: b'F',
 39: b'G',
 40: b'H',
 41: b'I',
 42: b'J',
 43: b'K',
 44: b'L',
 45: b'M',
 46: b'N',
 47: b'O',
 48: b'P',
 49: b'Q',
 50: b'R',
 51: b'S',
 52: b'T',
 53: b'U',
 54: b'V',
 55: b'W',
 56: b'X',
 57: b'Y',
 58: b'Z',
 59: b'[',
 60: b'\\',
 61: b']',
 62: b'^',
 63: b'_',
 64: b'`',
 65: b'a',
 66: b'b',
 67: b'c',
 68: b'd',
 69: b'e',
 70: b'f',
 71: b'g',
 72: b'h',
 73: b'i',
 74: b'j',
 75: b'k',
 76: b'l',
 77: b'm',
 78: b'n',
 79: b'o',
 80: b'p',
 81: b'q',
 82: b'r',
 83: b's',
 84: b't',
 85: b'u',
 86: b'v',
 87: b'w',
 88: b'x',
 89: b'y',
 90: b'

In [1]:
from cs336_basics.train_bpe import train_bpe

In [3]:
train_bpe("../data/test.txt", 300, ["<|endoftext|>"])

({0: b'\x00',
  1: b'\x01',
  2: b'\x02',
  3: b'\x03',
  4: b'\x04',
  5: b'\x05',
  6: b'\x06',
  7: b'\x07',
  8: b'\x08',
  9: b'\t',
  10: b'\n',
  11: b'\x0b',
  12: b'\x0c',
  13: b'\r',
  14: b'\x0e',
  15: b'\x0f',
  16: b'\x10',
  17: b'\x11',
  18: b'\x12',
  19: b'\x13',
  20: b'\x14',
  21: b'\x15',
  22: b'\x16',
  23: b'\x17',
  24: b'\x18',
  25: b'\x19',
  26: b'\x1a',
  27: b'\x1b',
  28: b'\x1c',
  29: b'\x1d',
  30: b'\x1e',
  31: b'\x1f',
  32: b' ',
  33: b'!',
  34: b'"',
  35: b'#',
  36: b'$',
  37: b'%',
  38: b'&',
  39: b"'",
  40: b'(',
  41: b')',
  42: b'*',
  43: b'+',
  44: b',',
  45: b'-',
  46: b'.',
  47: b'/',
  48: b'0',
  49: b'1',
  50: b'2',
  51: b'3',
  52: b'4',
  53: b'5',
  54: b'6',
  55: b'7',
  56: b'8',
  57: b'9',
  58: b':',
  59: b';',
  60: b'<',
  61: b'=',
  62: b'>',
  63: b'?',
  64: b'@',
  65: b'A',
  66: b'B',
  67: b'C',
  68: b'D',
  69: b'E',
  70: b'F',
  71: b'G',
  72: b'H',
  73: b'I',
  74: b'J',
  75: b'K',
  76: b'