In [22]:
from dataclasses import dataclass
from collections import defaultdict
@dataclass(frozen=True)
class BPETokenizerParams:
    """The complete specification of a BPETokenizer.

    Attributes:
        vocab: Maps a token index to the bytes that token stands for.
        merges: Maps a pair of token indices to the index of the token
            that replaces the pair.
    """

    # index -> bytes
    vocab: dict[int, bytes]
    # (index1, index2) -> new_index
    merges: dict[tuple[int, int], int]

## Character-based tokenization
    
- A Unicode string is a sequence of Unicode characters.
- Each character can be converted into a code point (an integer) via `ord`.

In [2]:
class Tokenizer:
    """Abstract interface that every tokenizer implements."""

    def encode(self, string: str) -> list[int]:
        """Convert `string` into a list of integer token indices.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        raise NotImplementedError

    def decode(self, indices: list[int]) -> str:
        """Convert a list of token indices back into a string.

        Subclasses must override this; the base version always raises
        NotImplementedError.
        """
        raise NotImplementedError

In [5]:
class CharacterTokenizer(Tokenizer):
    """Tokenizer whose tokens are the Unicode code points of the input characters."""

    def encode(self, string: str) -> list[int]:
        """Return the code point of every character in `string`, in order."""
        return [ord(character) for character in string]

    def decode(self, indices: list[int]) -> str:
        """Rebuild the string by mapping each code point back to its character."""
        return "".join(map(chr, indices))

In [8]:
# Sample text mixing ASCII, an emoji, and CJK characters.
string = "Hello, 🌍! 你好!"

In [13]:
# Character-level tokenization: each character becomes its Unicode code point.
tokenizer=CharacterTokenizer()
tokenizer.encode(string)

[72, 101, 108, 108, 111, 44, 32, 127757, 33, 32, 20320, 22909, 33]

In [14]:
class ByteTokenizer(Tokenizer):
    """Tokenizer that represents a string as the sequence of its UTF-8 bytes."""

    def encode(self, string: str) -> list[int]:
        """Return the UTF-8 byte values of `string` as a list of integers."""
        # Iterating over a bytes object already yields ints.
        return list(string.encode("utf-8"))

    def decode(self, indices: list[int]) -> str:
        """Pack `indices` into bytes and decode them as UTF-8."""
        return bytes(indices).decode("utf-8")

In [15]:
# Rebind `tokenizer` to the byte-level implementation.
tokenizer = ByteTokenizer()

In [19]:
def train_bpe(string: str, num_merges: int) -> BPETokenizerParams:
    """Learn byte-pair-encoding merges from `string`.

    Training starts from the UTF-8 bytes of `string` (one token per byte,
    indices 0-255) and repeatedly replaces the most frequent adjacent pair
    of tokens with a newly allocated index (256, 257, ...).

    Relies on the notebook's `merge` helper, which must be defined before
    this function is called.

    Args:
        string: Training text.
        num_merges: Maximum number of merges to learn. Fewer are learned
            when the token sequence shrinks below two tokens first.

    Returns:
        A BPETokenizerParams holding the vocabulary (index -> bytes) and
        the learned merges ((index1, index2) -> new index).
    """
    # Start with the list of bytes of the string.
    indices = list(map(int, string.encode("utf-8")))
    merges: dict[tuple[int, int], int] = {}
    vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}

    for i in range(num_merges):
        # Count the number of occurrences of each pair of adjacent tokens.
        counts = defaultdict(int)
        for index1, index2 in zip(indices, indices[1:]):
            counts[(index1, index2)] += 1

        # With fewer than two tokens left there is nothing to merge, and
        # max() would raise ValueError on the empty mapping.
        if not counts:
            break

        # Find the most common pair; ties go to the pair seen first.
        pair = max(counts, key=counts.get)
        index1, index2 = pair

        # Merge that pair into a new token.
        new_index = 256 + i
        merges[pair] = new_index
        vocab[new_index] = vocab[index1] + vocab[index2]
        # Call the `merge` helper; the original invoked the merges dict here.
        indices = merge(indices, pair, new_index)
    return BPETokenizerParams(vocab=vocab, merges=merges)


In [20]:
# UTF-8 byte values of the sample string, one integer per byte.
indices = list(string.encode("utf-8"))
indices


[72,
 101,
 108,
 108,
 111,
 44,
 32,
 240,
 159,
 140,
 141,
 33,
 32,
 228,
 189,
 160,
 229,
 165,
 189,
 33]

In [21]:
# Base vocabulary: every possible byte value maps to its own one-byte string.
vocab: dict[int, bytes] = {
    byte_value: bytes([byte_value]) for byte_value in range(256)
}
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [43]:
def merge(indices: list[int], pair: tuple[int, int], new_index: int) -> list[int]:  # @inspect indices, @inspect pair, @inspect new_index
    """Return a copy of `indices` in which each occurrence of `pair` is collapsed to `new_index`.

    The scan runs left to right and matches never overlap: once two
    adjacent tokens are merged, the second one is not reused as the start
    of another match.
    """
    new_indices = []  # @inspect new_indices
    skip_next = False  # True when the current token was consumed by a merge
    for position, token in enumerate(indices):
        if skip_next:
            skip_next = False
            continue
        if position + 1 < len(indices) and token == pair[0] and indices[position + 1] == pair[1]:
            new_indices.append(new_index)
            skip_next = True
        else:
            new_indices.append(token)
    return new_indices

In [None]:
# Start from the raw UTF-8 bytes of `string`, one token per byte.
indices = list(string.encode("utf-8"))  # @inspect indices
merges: dict[tuple[int, int], int] = {}  # (index1, index2) -> merged index
vocab: dict[int, bytes] = {x: bytes([x]) for x in range(256)}  # index -> bytes
for i in range(8):
    # Tally every adjacent token pair in the current sequence.
    counts = defaultdict(int)
    for index1, index2 in zip(indices[:-1], indices[1:]):
        counts[(index1, index2)] += 1

    # Pick the most frequent pair; ties go to the pair seen first.
    pair = max(counts, key=lambda candidate: counts[candidate])
    index1, index2 = pair

    # Allocate a fresh index for the pair and rewrite the sequence with it.
    new_index = 256 + i
    merges[pair] = new_index
    vocab[new_index] = vocab[index1] + vocab[index2]
    indices = merge(indices, pair, new_index)

72 101


In [44]:
# Code point 0 is the NUL character; its repr is the escape sequence '\x00'.
chr(0)

'\x00'

In [None]:
# This bare expression is not the cell's last statement, so nothing is displayed for it.
chr(0)
# print() writes the raw NUL character rather than its '\x00' escape.
print(chr(0))

# The NUL sits between the two words without a printable glyph of its own.
print("this is a test" + chr(0) + "string")

 
this is a test string


In [48]:
"this is a test" + chr(0) + "string"

'this is a test\x00string'

In [49]:
# Printing NUL emits the raw character, which has no visible glyph.
print(chr(0))

 
