In [1]:
from pathlib import Path

with open("data/train.csv") as f:
    train = f.read()
with open("data/test.csv") as f:
    test = f.read()
with open("data/validation.csv") as f:
    val = f.read()


In [2]:
#ord fucntion used to get the unicode code point of a character
ord('😂')

128514

In [3]:
#encoding the character to its utf-8 byte representation
'😂'.encode('utf-8')

b'\xf0\x9f\x98\x82'

In [4]:
"hello my name is my name is my name is Slim Shady".encode("utf-8")

b'hello my name is my name is my name is Slim Shady'

In [5]:
#to see as numeric reps use list
list('😂'.encode('utf-8'))

[240, 159, 152, 130]

In [6]:
list("hello my name is my name is my name is Slim Shady".encode("utf-8"))

[104,
 101,
 108,
 108,
 111,
 32,
 109,
 121,
 32,
 110,
 97,
 109,
 101,
 32,
 105,
 115,
 32,
 109,
 121,
 32,
 110,
 97,
 109,
 101,
 32,
 105,
 115,
 32,
 109,
 121,
 32,
 110,
 97,
 109,
 101,
 32,
 105,
 115,
 32,
 83,
 108,
 105,
 109,
 32,
 83,
 104,
 97,
 100,
 121]

In [7]:
sample = "The “I’m Feeling Lucky” button, a feature of Google’s search engine, is akin to a digital shortcut that whisks users away to the top search result page without displaying the full list of search results. This button, positioned next to the standard “Google Search” button, offers a quick one-click path to what Google’s algorithm determines as the most relevant page for a given search query. The feature is designed for those who trust Google’s ranking system implicitly, allowing them to bypass the search results page entirely."
sample_bytes = sample.encode("utf-8")

print(sample)
print('len of sample:', len(sample))

tokens = list(map(int, sample_bytes))
print(tokens)
print('len of tokens:', len(tokens)) # slightly longer due to utf-8 encoding

The “I’m Feeling Lucky” button, a feature of Google’s search engine, is akin to a digital shortcut that whisks users away to the top search result page without displaying the full list of search results. This button, positioned next to the standard “Google Search” button, offers a quick one-click path to what Google’s algorithm determines as the most relevant page for a given search query. The feature is designed for those who trust Google’s ranking system implicitly, allowing them to bypass the search results page entirely.
len of sample: 530
[84, 104, 101, 32, 226, 128, 156, 73, 226, 128, 153, 109, 32, 70, 101, 101, 108, 105, 110, 103, 32, 76, 117, 99, 107, 121, 226, 128, 157, 32, 98, 117, 116, 116, 111, 110, 44, 32, 97, 32, 102, 101, 97, 116, 117, 114, 101, 32, 111, 102, 32, 71, 111, 111, 103, 108, 101, 226, 128, 153, 115, 32, 115, 101, 97, 114, 99, 104, 32, 101, 110, 103, 105, 110, 101, 44, 32, 105, 115, 32, 97, 107, 105, 110, 32, 116, 111, 32, 97, 32, 100, 105, 103, 105, 116, 97, 

first step will be to find the pairs off bytes which repeat

In [8]:
def get_pair_counts(tokens: list[int]) -> dict[tuple[int, int], int]:
    pair_counts = {}
    for i in range(len(tokens) - 1):
        pair = (tokens[i], tokens[i + 1])
        if pair not in pair_counts:
            pair_counts[pair] = 0
        pair_counts[pair] += 1
    return pair_counts

In [9]:
token_pairs = get_pair_counts(tokens)

In [10]:
print(sorted(((v,k) for k,v in token_pairs.items()), reverse=True)[:10])

[(15, (32, 116)), (14, (101, 32)), (13, (115, 32)), (11, (116, 104)), (10, (116, 32)), (9, (116, 111)), (9, (32, 97)), (8, (226, 128)), (8, (104, 101)), (8, (101, 97))]


In [11]:
#prepare to mint new token by discovering what the max token id is
max(tokens)
# here is 226 but we will reserve 256 ids for single byte tokens (0-255) so new tokens will start from 256


226

In [12]:
# most common pair
top_pair = max(token_pairs.keys(), key=lambda x: token_pairs[x])
top_pair


(32, 116)

In [13]:
def merge_pair(tokens: list[int], pair: tuple[int, int], new_token_id: int) -> list[int]:
    merged_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
            merged_tokens.append(new_token_id)
            i += 2  # Skip the next token as it's part of the merged pair
        else:
            merged_tokens.append(tokens[i])
            i += 1
    return merged_tokens

In [14]:
new_tokens = merge_pair(tokens, top_pair, 256)

In [15]:
sum(1 for t in new_tokens if t == 256)

15

In [16]:
print(new_tokens)
print(f'len of new tokens: {len(new_tokens)}')

[84, 104, 101, 32, 226, 128, 156, 73, 226, 128, 153, 109, 32, 70, 101, 101, 108, 105, 110, 103, 32, 76, 117, 99, 107, 121, 226, 128, 157, 32, 98, 117, 116, 116, 111, 110, 44, 32, 97, 32, 102, 101, 97, 116, 117, 114, 101, 32, 111, 102, 32, 71, 111, 111, 103, 108, 101, 226, 128, 153, 115, 32, 115, 101, 97, 114, 99, 104, 32, 101, 110, 103, 105, 110, 101, 44, 32, 105, 115, 32, 97, 107, 105, 110, 256, 111, 32, 97, 32, 100, 105, 103, 105, 116, 97, 108, 32, 115, 104, 111, 114, 116, 99, 117, 116, 256, 104, 97, 116, 32, 119, 104, 105, 115, 107, 115, 32, 117, 115, 101, 114, 115, 32, 97, 119, 97, 121, 256, 111, 256, 104, 101, 256, 111, 112, 32, 115, 101, 97, 114, 99, 104, 32, 114, 101, 115, 117, 108, 116, 32, 112, 97, 103, 101, 32, 119, 105, 116, 104, 111, 117, 116, 32, 100, 105, 115, 112, 108, 97, 121, 105, 110, 103, 256, 104, 101, 32, 102, 117, 108, 108, 32, 108, 105, 115, 116, 32, 111, 102, 32, 115, 101, 97, 114, 99, 104, 32, 114, 101, 115, 117, 108, 116, 115, 46, 32, 84, 104, 105, 115, 32, 98

In [17]:
vocab_size = 280 # this is the vocab size we want after all merges are done, yes i know our sample size is too small to reach this but just for demo
num_merges = vocab_size - 256 # 256 is the starting point for new tokens
mapping = {}
new_tokens = tokens.copy()

for merge_id in range(max(256,max(new_tokens)), max(256,max(new_tokens)) + num_merges):
    token_pairs = get_pair_counts(new_tokens)
    if not token_pairs:
        break  # No more pairs to merge
    top_pair = max(token_pairs.keys(), key=lambda x: token_pairs[x])
    new_tokens = merge_pair(new_tokens, top_pair, merge_id)
    mapping[top_pair] = merge_id
    print(f'Merge {merge_id - 255}: Merged pair {top_pair} into new token ID {merge_id}')
    print(f'Len of tokens after merge: {len(new_tokens)}')

Merge 1: Merged pair (32, 116) into new token ID 256
Len of tokens after merge: 531
Merge 2: Merged pair (101, 32) into new token ID 257
Len of tokens after merge: 518
Merge 3: Merged pair (115, 32) into new token ID 258
Len of tokens after merge: 507
Merge 4: Merged pair (226, 128) into new token ID 259
Len of tokens after merge: 499
Merge 5: Merged pair (101, 97) into new token ID 260
Len of tokens after merge: 491
Merge 6: Merged pair (256, 104) into new token ID 261
Len of tokens after merge: 483
Merge 7: Merged pair (116, 32) into new token ID 262
Len of tokens after merge: 475
Merge 8: Merged pair (105, 110) into new token ID 263
Len of tokens after merge: 468
Merge 9: Merged pair (260, 114) into new token ID 264
Len of tokens after merge: 462
Merge 10: Merged pair (264, 99) into new token ID 265
Len of tokens after merge: 456
Merge 11: Merged pair (265, 104) into new token ID 266
Len of tokens after merge: 450
Merge 12: Merged pair (256, 111) into new token ID 267
Len of tokens 

#successfully reduced sequence  size from 531 to 388

In [18]:
mapping

{(32, 116): 256,
 (101, 32): 257,
 (115, 32): 258,
 (226, 128): 259,
 (101, 97): 260,
 (256, 104): 261,
 (116, 32): 262,
 (105, 110): 263,
 (260, 114): 264,
 (264, 99): 265,
 (265, 104): 266,
 (256, 111): 267,
 (111, 110): 268,
 (44, 32): 269,
 (115, 266): 270,
 (270, 32): 271,
 (258, 97): 272,
 (105, 116): 273,
 (114, 101): 274,
 (112, 97): 275,
 (259, 153): 276,
 (263, 103): 277,
 (117, 116): 278,
 (71, 111): 279}

great now lets do it on train data

In [19]:
encoded = train.encode('utf-8')
train_tokens = list(map(int, encoded))

In [20]:
vocab_size = 286
num_merges = vocab_size - 256 # 256 is the starting point for new tokens
mapping = {}
new_tokens = train_tokens.copy()

for merge_id in range(max(256,max(new_tokens)), max(256,max(new_tokens)) + num_merges):
    token_pairs = get_pair_counts(new_tokens)
    if not token_pairs:
        break  # No more pairs to merge
    top_pair = max(token_pairs.keys(), key=lambda x: token_pairs[x])
    new_tokens = merge_pair(new_tokens, top_pair, merge_id)
    mapping[top_pair] = merge_id
    print(f'Merge {merge_id - 255}: Merged pair {top_pair} into new token ID {merge_id}')
    print(f'Len of tokens after merge: {len(new_tokens)}')

Merge 1: Merged pair (101, 32) into new token ID 256
Len of tokens after merge: 978852
Merge 2: Merged pair (116, 104) into new token ID 257
Len of tokens after merge: 958260
Merge 3: Merged pair (116, 32) into new token ID 258
Len of tokens after merge: 943381
Merge 4: Merged pair (115, 32) into new token ID 259
Len of tokens after merge: 929395
Merge 5: Merged pair (100, 32) into new token ID 260
Len of tokens after merge: 916600
Merge 6: Merged pair (44, 32) into new token ID 261
Len of tokens after merge: 904115
Merge 7: Merged pair (111, 117) into new token ID 262
Len of tokens after merge: 892609
Merge 8: Merged pair (101, 114) into new token ID 263
Len of tokens after merge: 882050
Merge 9: Merged pair (105, 110) into new token ID 264
Len of tokens after merge: 872519
Merge 10: Merged pair (121, 32) into new token ID 265
Len of tokens after merge: 863202
Merge 11: Merged pair (97, 110) into new token ID 266
Len of tokens after merge: 854060
Merge 12: Merged pair (111, 114) into 

In [21]:
compression_ratio = len(train_tokens) / len(new_tokens)
print(f'Compression ratio on train data: {compression_ratio:.2f} after bpe on train data for num merges: {num_merges}')

Compression ratio on train data: 1.34 after bpe on train data for num merges: 30


In [22]:
mapping

{(101, 32): 256,
 (116, 104): 257,
 (116, 32): 258,
 (115, 32): 259,
 (100, 32): 260,
 (44, 32): 261,
 (111, 117): 262,
 (101, 114): 263,
 (105, 110): 264,
 (121, 32): 265,
 (97, 110): 266,
 (111, 114): 267,
 (58, 10): 268,
 (111, 32): 269,
 (101, 110): 270,
 (97, 114): 271,
 (10, 10): 272,
 (32, 257): 273,
 (111, 110): 274,
 (108, 108): 275,
 (104, 97): 276,
 (44, 10): 277,
 (105, 259): 278,
 (101, 115): 279,
 (46, 272): 280,
 (121, 262): 281,
 (32, 115): 282,
 (116, 269): 283,
 (101, 97): 284,
 (266, 260): 285}

# encoding and decoding

In [23]:
#decoding
import warnings

vocab = {i: bytes([i]).decode('utf-8', errors='ignore') for i in range(256)}
for k, v in mapping.items():
    vocab[v] = vocab[k[0]] + vocab[k[1]]

print(vocab)

def decode(tokens:list[int]|int) -> str: # slightly inefficient, actually makes moresense to keep the vocab dict key as byte with the value as ints then convert
    #the inputted tokens to bytes then decode the bytes to string at the end with errors set to replace
    decoded_tokens = ''
    if isinstance(tokens, int):
        tokens = [tokens]

    for token in tokens:
        if token in vocab:
            decoded_tokens += vocab[token]
        else:
            decoded_tokens += ''  # or handle unknown token as needed
            warnings.warn(f'Token {token} not in vocabulary during decoding.')
    return decoded_tokens


{0: '\x00', 1: '\x01', 2: '\x02', 3: '\x03', 4: '\x04', 5: '\x05', 6: '\x06', 7: '\x07', 8: '\x08', 9: '\t', 10: '\n', 11: '\x0b', 12: '\x0c', 13: '\r', 14: '\x0e', 15: '\x0f', 16: '\x10', 17: '\x11', 18: '\x12', 19: '\x13', 20: '\x14', 21: '\x15', 22: '\x16', 23: '\x17', 24: '\x18', 25: '\x19', 26: '\x1a', 27: '\x1b', 28: '\x1c', 29: '\x1d', 30: '\x1e', 31: '\x1f', 32: ' ', 33: '!', 34: '"', 35: '#', 36: '$', 37: '%', 38: '&', 39: "'", 40: '(', 41: ')', 42: '*', 43: '+', 44: ',', 45: '-', 46: '.', 47: '/', 48: '0', 49: '1', 50: '2', 51: '3', 52: '4', 53: '5', 54: '6', 55: '7', 56: '8', 57: '9', 58: ':', 59: ';', 60: '<', 61: '=', 62: '>', 63: '?', 64: '@', 65: 'A', 66: 'B', 67: 'C', 68: 'D', 69: 'E', 70: 'F', 71: 'G', 72: 'H', 73: 'I', 74: 'J', 75: 'K', 76: 'L', 77: 'M', 78: 'N', 79: 'O', 80: 'P', 81: 'Q', 82: 'R', 83: 'S', 84: 'T', 85: 'U', 86: 'V', 87: 'W', 88: 'X', 89: 'Y', 90: 'Z', 91: '[', 92: '\\', 93: ']', 94: '^', 95: '_', 96: '`', 97: 'a', 98: 'b', 99: 'c', 100: 'd', 101: 'e'

In [24]:
decode([281,284, 285, 257, 259])

'youeaand ths '

In [25]:
print(mapping
      )

{(101, 32): 256, (116, 104): 257, (116, 32): 258, (115, 32): 259, (100, 32): 260, (44, 32): 261, (111, 117): 262, (101, 114): 263, (105, 110): 264, (121, 32): 265, (97, 110): 266, (111, 114): 267, (58, 10): 268, (111, 32): 269, (101, 110): 270, (97, 114): 271, (10, 10): 272, (32, 257): 273, (111, 110): 274, (108, 108): 275, (104, 97): 276, (44, 10): 277, (105, 259): 278, (101, 115): 279, (46, 272): 280, (121, 262): 281, (32, 115): 282, (116, 269): 283, (101, 97): 284, (266, 260): 285}


In [32]:
def encode(text:str) -> int:
    encoded_bytes = text.encode('utf-8')
    int_reps = list(map(int, encoded_bytes))
    # print(int_reps)

    #apply vocabulary merges
    while True:
        pairs = get_pair_counts(int_reps)
        # print(pairs)
        if len(pairs) > 0:
            selected_pair = min(pairs, key=lambda x: mapping.get(x, float('inf'))) # gets the pair with the lowest merge id
        else:
            break
        if selected_pair not in mapping:
            break  # no more merges can be applied
        new_token = mapping[selected_pair]
        int_reps = merge_pair(int_reps, selected_pair, new_token)
    return int_reps

In [35]:
encode('h')

[104]

In [34]:
encode('youeaand ths ') == [281,284, 285, 257, 259]

True

In [36]:
validation = decode(encode(val))
print(validation == val)

True


In [None]:
import warnings
import torch

class NaiveBPETokenizer:
    def __init__(self):
        self.mapping = {}
        self.vocab  = {}

    @staticmethod
    def _get_pair_counts(tokens: list[int]) -> dict[tuple[int, int], int]:
        pair_counts = {}
        for i in range(len(tokens) - 1):
            pair = (tokens[i], tokens[i + 1])
            if pair not in pair_counts:
                pair_counts[pair] = 0
            pair_counts[pair] += 1
        return pair_counts

    @staticmethod
    def _merge_pair(tokens: list[int], pair: tuple[int, int], new_token_id: int) -> list[int]:
        merged_tokens = []
        i = 0
        while i < len(tokens):
            if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == pair:
                merged_tokens.append(new_token_id)
                i += 2  # Skip the next token as it's part of the merged pair
            else:
                merged_tokens.append(tokens[i])
                i += 1
        return merged_tokens

    def _build_vocab(self):
        """Build vocabulary from merges for efficient decoding."""
        # Initialize with base bytes
        self.vocab = {i: bytes([i]) for i in range(256)}

        # Apply merges in order to build vocabulary
        for (token1, token2), new_token in self.mapping.items():
            if token1 in self.vocab and token2 in self.vocab:
                self.vocab[new_token] = self.vocab[token1] + self.vocab[token2]

    def train(self, data:list[str] | str, vocab_size:int=1000):
        if isinstance(data, list):
            data = ''.join(data)

        tokens = list(data.encode('utf-8'))

        num_merges = vocab_size - 256 # 256 is the starting point for new tokens
        new_tokens = tokens.copy()
        for merge_id in range(max(256,max(new_tokens)), max(256,max(new_tokens)) + num_merges):
            token_pairs = self._get_pair_counts(new_tokens)
            if not token_pairs:
                break  # No more pairs to merge
            top_pair = max(token_pairs.keys(), key=lambda x: token_pairs[x])
            new_tokens = self._merge_pair(new_tokens, top_pair, merge_id)
            self.mapping[top_pair] = merge_id
            print(f'Merge {merge_id - 255}: Merged pair {top_pair} into new token ID {merge_id}')
            print(f'Len of tokens after merge: {len(new_tokens)}')
        self._build_vocab()
        print('Training complete.')
        print(f'Compression ratio: {len(tokens) / len(new_tokens):.2f}')

    def encode(self, text:str) -> torch.Tensor[int]:
        encoded_bytes = text.encode('utf-8')
        int_reps = list(map(int, encoded_bytes))
        # print(int_reps)

        #apply vocabulary merges
        while True:
            pairs = self._get_pair_counts(int_reps)
            if len(pairs) > 0:
                selected_pair = min(pairs, key=lambda x: self.mapping.get(x, float('inf'))) # gets the pair with the lowest merge id
            else:
                break
            if selected_pair not in self.mapping:
                break  # no more merges can be applied
            new_token = self.mapping[selected_pair]
            int_reps = self._merge_pair(int_reps, selected_pair, new_token)
        return torch.tensor(int_reps)

    def decode(self, tokens:list[int]|int | torch.Tensor) -> str:
        if isinstance(tokens, int):
            tokens = [tokens]

        if not self.vocab:
            self._build_vocab()

        decoded_bytes = b''
        for token in tokens:
            if token in self.vocab:
                decoded_bytes += self.vocab[token]
            else:
                warnings.warn(f'Token {token} not in vocabulary')

        return decoded_bytes.decode('utf-8', errors='replace')


In [42]:
bpe = NaiveBPETokenizer()

In [43]:
bpe.train(train, vocab_size=1000)

Merge 1: Merged pair (101, 32) into new token ID 256
Len of tokens after merge: 978852
Merge 2: Merged pair (116, 104) into new token ID 257
Len of tokens after merge: 958260
Merge 3: Merged pair (116, 32) into new token ID 258
Len of tokens after merge: 943381
Merge 4: Merged pair (115, 32) into new token ID 259
Len of tokens after merge: 929395
Merge 5: Merged pair (100, 32) into new token ID 260
Len of tokens after merge: 916600
Merge 6: Merged pair (44, 32) into new token ID 261
Len of tokens after merge: 904115
Merge 7: Merged pair (111, 117) into new token ID 262
Len of tokens after merge: 892609
Merge 8: Merged pair (101, 114) into new token ID 263
Len of tokens after merge: 882050
Merge 9: Merged pair (105, 110) into new token ID 264
Len of tokens after merge: 872519
Merge 10: Merged pair (121, 32) into new token ID 265
Len of tokens after merge: 863202
Merge 11: Merged pair (97, 110) into new token ID 266
Len of tokens after merge: 854060
Merge 12: Merged pair (111, 114) into 

In [46]:
test_enc = bpe.encode(test)

In [47]:
test_dec = bpe.decode(test_enc)

In [48]:
test_dec == test

True

In [53]:
with open(Path("data/test_encoded.txt"), "w") as f:
    f.write(" ".join(map(str, test_enc)))

In [55]:
enc = bpe.encode("Hello world")
bpe.decode(enc)

'Hello world'