## Unicode and Unicode encodings 

**Note**: This notebook is heavily inspired by Andrej Karpathy's video [Let's build the GPT tokenizer](https://www.youtube.com/watch?v=zduSFxRajkE&t=1715s) 

In Python, a string (`str`) is a sequence of Unicode code points. Each code point is identified by a number, its index in the Unicode codespace, customarily written in hexadecimal with the prefix “U+”.
To retrieve the code point of a character in Python, pass the character to the `ord` function. 
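
As a quick illustration, here is a minimal sketch of `ord`, its inverse `chr`, and the "U+" notation. The helper name `u_plus` and the sample characters are illustrative choices, not part of the original notebook.

```python
def u_plus(ch: str) -> str:
    """Format the code point of a single character in the customary U+ notation."""
    return f"U+{ord(ch):04X}"

# Sample characters chosen for illustration only (written as escapes to avoid encoding surprises).
for ch in ["A", "\u00e9", "\u65e5", "\U0001F600"]:
    code_point = ord(ch)
    print(f"{ch!r}: ord = {code_point}, notation = {u_plus(ch)}")
    assert chr(code_point) == ch  # chr is the inverse of ord
```

The third sample is the first character of the Japanese string used later in this notebook, so its code point matches the first entry of the `unicode_s` output.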

Unicode defines three encoding forms: UTF-8, UTF-16 and UTF-32. These encodings allow the standard's abstract code points to be stored and processed as binary data (refer to [A Programmer's Introduction to Unicode](https://www.reedbeta.com/blog/programmers-intro-to-unicode/)). 
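
To make the difference between the three encodings concrete, the sketch below counts how many bytes a single character occupies in each. The helper `encoded_lengths` and the sample characters are assumptions made for illustration. The little-endian codec names are used because Python's plain `"utf-16"` and `"utf-32"` codecs prepend a byte order mark, which would inflate the counts.

```python
def encoded_lengths(ch: str) -> dict:
    """Number of bytes `ch` occupies in each encoding (no byte order mark)."""
    return {enc: len(ch.encode(enc)) for enc in ("utf-8", "utf-16-le", "utf-32-le")}

# Sample characters chosen for illustration only.
for ch in ["A", "\u00e9", "\u65e5", "\U0001F600"]:
    print(f"{ch!r}: {encoded_lengths(ch)}")
```

UTF-8 is variable-width (1 to 4 bytes), UTF-16 uses 2 or 4 bytes, and UTF-32 always uses 4 bytes per code point.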

With UTF-8, each character is represented as 1 to 4 bytes, which we can view directly as integers in the range 0 to 255. We can then use the byte pair encoding algorithm to shorten the token sequence, allowing our language model to fit more text into its attention context. 

## Byte pair encoding

In Karpathy's video, he implements the byte pair encoding algorithm by looping over all adjacent byte pairs and counting their frequencies. He then uses the `max()` function to get the most frequent pair, and finally builds a new byte sequence in which every occurrence of that pair is replaced by a new token. For a sequence of n tokens, each of these three steps is a linear pass (`max()` scans the table of distinct pairs, which has at most n - 1 entries), so one merge step costs roughly 3n operations, which is **O(n)**. 
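
The count-then-`max()` approach described above can be sketched as follows. This is a reconstruction from the description, not a verbatim copy of the code in the video; the function names and the toy input are assumptions.

```python
def get_stats(ids):
    """Count how often each adjacent pair occurs in the sequence."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def most_frequent_pair(ids):
    """Separate pass over the table of counts to find the most frequent pair."""
    stats = get_stats(ids)
    return max(stats, key=stats.get)

# Toy input chosen for illustration: the pair (97, 97), i.e. "aa", occurs 4 times.
ids = list("aaabdaaabac".encode("utf-8"))
print(most_frequent_pair(ids))
```

Counting is one pass over the sequence and `max()` is a second pass over the distinct pairs. Replacing the chosen pair is a third pass.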

Another way of implementing this procedure, without relying on `max()`, is to keep a variable that always holds the most frequent byte pair seen so far while counting. This removes the separate `max()` pass, leaving roughly 2n operations per merge step. The asymptotic complexity is still **O(n)**; only the constant factor improves. 

In [30]:
s = "日本からこんにちは (Hello from Japan!)"

In [31]:
unicode_s = [ord(x) for x in s] 

# bytes representation converted to integers representation 
utf_8_s = list(s.encode("utf-8"))

print(f"{unicode_s} | length: {len(unicode_s)}")
print(f"{utf_8_s} | length: {len(utf_8_s)}") # longer length meaning some characters need more than 1 byte to be represented

[26085, 26412, 12363, 12425, 12371, 12435, 12395, 12385, 12399, 32, 40, 72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 74, 97, 112, 97, 110, 33, 41] | length: 29
[230, 151, 165, 230, 156, 172, 227, 129, 139, 227, 130, 137, 227, 129, 147, 227, 130, 147, 227, 129, 171, 227, 129, 161, 227, 129, 175, 32, 40, 72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 74, 97, 112, 97, 110, 33, 41] | length: 47


In [32]:
# encode a long paragraph
text = " UTF-8, each code point is stored using 1 to 4 bytes, based on its index value.  UTF-8 uses a system of binary prefixes, in which the high bits of each byte mark whether it’s a single byte, the beginning of a multi-byte sequence, or a continuation byte; the remaining bits, concatenated, give the code point index."

tokens = text.encode("utf-8")
tokens = list(tokens) 

print(text) 
print(f"length: {len(text)}")
print("----")
print(tokens) 
print(f"length: {len(tokens)}")

 UTF-8, each code point is stored using 1 to 4 bytes, based on its index value.  UTF-8 uses a system of binary prefixes, in which the high bits of each byte mark whether it’s a single byte, the beginning of a multi-byte sequence, or a continuation byte; the remaining bits, concatenated, give the code point index.
length: 314
----
[32, 85, 84, 70, 45, 56, 44, 32, 101, 97, 99, 104, 32, 99, 111, 100, 101, 32, 112, 111, 105, 110, 116, 32, 105, 115, 32, 115, 116, 111, 114, 101, 100, 32, 117, 115, 105, 110, 103, 32, 49, 32, 116, 111, 32, 52, 32, 98, 121, 116, 101, 115, 44, 32, 98, 97, 115, 101, 100, 32, 111, 110, 32, 105, 116, 115, 32, 105, 110, 100, 101, 120, 32, 118, 97, 108, 117, 101, 46, 32, 32, 85, 84, 70, 45, 56, 32, 117, 115, 101, 115, 32, 97, 32, 115, 121, 115, 116, 101, 109, 32, 111, 102, 32, 98, 105, 110, 97, 114, 121, 32, 112, 114, 101, 102, 105, 120, 101, 115, 44, 32, 105, 110, 32, 119, 104, 105, 99, 104, 32, 116, 104, 101, 32, 104, 105, 103, 104, 32, 98, 105, 116, 115, 32, 111, 

In [33]:
# ----- byte pair encoding -----

def merge(tokens, byte_pair_max, byte_pair_representation): 
    """Replace every occurrence of the pair `byte_pair_max` with the new token id."""
    new_tokens = [] 

    i = 0 
    while i < len(tokens): 
        # compare pairs as tuples: concatenated strings are ambiguous,
        # e.g. (1, 23) and (12, 3) would both become "123"
        if i < len(tokens) - 1 and (tokens[i], tokens[i + 1]) == byte_pair_max: 
            new_tokens.append(byte_pair_representation)
            i += 2 # skip the next token because the two were merged 
        else: 
            new_tokens.append(tokens[i]) 
            i += 1

    return new_tokens
    
def byte_pair_encoding(tokens, merges: dict, counter=256): 
    # count every adjacent pair, tracking the most frequent one in the same pass 
    pair_lookup = {} 
    byte_pair_max = None
    byte_pair_max_frequency = 0

    for i in range(len(tokens) - 1): 
        byte_pair = (tokens[i], tokens[i + 1])
        pair_lookup[byte_pair] = pair_lookup.get(byte_pair, 0) + 1
        
        if pair_lookup[byte_pair] > byte_pair_max_frequency: 
            byte_pair_max_frequency = pair_lookup[byte_pair]
            byte_pair_max = byte_pair

    if byte_pair_max is None: 
        return tokens # fewer than two tokens, nothing to merge 

    # the new token id that represents the most frequent pair 
    byte_pair_representation = counter

    # replace the pair in the token sequence and record the merge
    tokens = merge(tokens=tokens, byte_pair_max=byte_pair_max, byte_pair_representation=byte_pair_representation)
    merges[byte_pair_max] = byte_pair_representation

    print(f"byte pair with most frequency: {byte_pair_max[0]}{byte_pair_max[1]} frequency: {byte_pair_max_frequency} | new representation: {counter} | new token length: {len(tokens)}")
    print("--")
    print(tokens) 
    print("---------------------")

    return tokens

def byte_pair_encoding_full_procedure(tokens, depth=0): 
    merges = {}
    for i in range(depth): 
        tokens = byte_pair_encoding(tokens, merges=merges, counter=256+i)

    return (tokens, merges)

In [34]:
new_tokens, merges = byte_pair_encoding_full_procedure(tokens, depth=20)

byte pair with most frequency: 105110 frequency: 13 | new representation: 256 | new token length: 303
--
[32, 85, 84, 70, 45, 56, 44, 32, 101, 97, 99, 104, 32, 99, 111, 100, 101, 32, 112, 111, 256, 116, 32, 105, 115, 32, 115, 116, 111, 114, 101, 100, 32, 117, 115, 256, 103, 32, 49, 32, 116, 111, 32, 52, 32, 98, 121, 116, 101, 115, 44, 32, 98, 97, 115, 101, 100, 32, 111, 110, 32, 105, 116, 115, 32, 256, 100, 101, 120, 32, 118, 97, 108, 117, 101, 46, 32, 32, 85, 84, 70, 45, 56, 32, 117, 115, 101, 115, 32, 97, 32, 115, 121, 115, 116, 101, 109, 32, 111, 102, 32, 98, 256, 97, 114, 121, 32, 112, 114, 101, 102, 105, 120, 101, 115, 44, 32, 256, 32, 119, 104, 105, 99, 104, 32, 116, 104, 101, 32, 104, 105, 103, 104, 32, 98, 105, 116, 115, 32, 111, 102, 32, 101, 97, 99, 104, 32, 98, 121, 116, 101, 32, 109, 97, 114, 107, 32, 119, 104, 101, 116, 104, 101, 114, 32, 105, 116, 226, 128, 153, 115, 32, 97, 32, 115, 256, 103, 108, 101, 32, 98, 121, 116, 101, 44, 32, 116, 104, 101, 32, 98, 101, 103, 256, 

In [39]:
for key, value in merges.items(): 
    print(f"Merge token {key[0]}, {key[1]} -> {value}")

Merge token 105, 110 -> 256
Merge token 101, 32 -> 257
Merge token 32, 98 -> 258
Merge token 116, 101 -> 259
Merge token 44, 32 -> 260
Merge token 115, 32 -> 261
Merge token 116, 104 -> 262
Merge token 97, 32 -> 263
Merge token 256, 103 -> 264
Merge token 105, 116 -> 265
Merge token 111, 110 -> 266
Merge token 262, 257 -> 267
Merge token 99, 104 -> 268
Merge token 111, 102 -> 269
Merge token 115, 101 -> 270
Merge token 258, 121 -> 271
Merge token 114, 101 -> 272
Merge token 116, 111 -> 273
Merge token 100, 32 -> 274
Merge token 32, 85 -> 275


In [43]:
print(f"compression ratio: {len(tokens) / len(new_tokens):.2f}x")

compression ratio: 1.42x


## Decode and Encode 

With the final vocabulary of the tokenizer in hand, we can implement its two directions: encoding text into tokens, and decoding tokens back into text. 

### Decode 

One way to do this is to start from a base vocabulary of the 256 single-byte values, 0 $\to$ 255. Then, in the order the merges were learned, we add each new merge token to the vocabulary. Its byte string is the concatenation of the byte strings of the two tokens it merges. 
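
The steps above, plus the final bytes-to-text conversion, can be sketched as below. The names `build_vocab`, `decode` and `example_merges` are hypothetical, and the two merges are a toy table rather than the one learned in this notebook. The sketch uses `bytes([i])`, which builds the single byte with value `i`, and `errors="replace"` so that a token sequence that is not valid UTF-8 decodes to the replacement character instead of raising an exception.

```python
def build_vocab(merges: dict) -> dict:
    """Map every token id to its byte string, starting from the 256 single bytes."""
    vocab = {i: bytes([i]) for i in range(256)}
    # dicts preserve insertion order, so merges are replayed in the order they were learned
    for (left, right), new_id in merges.items():
        vocab[new_id] = vocab[left] + vocab[right]
    return vocab

def decode(ids, vocab: dict) -> str:
    """Concatenate the byte strings of all tokens and decode them as UTF-8."""
    raw = b"".join(vocab[i] for i in ids)
    return raw.decode("utf-8", errors="replace")

# Toy merge table for illustration: "i" + "n" -> 256, then token 256 + "g" -> 257.
example_merges = {(105, 110): 256, (256, 103): 257}
example_vocab = build_vocab(example_merges)
print(decode([115, 257], example_vocab))
```

Because token 257 is defined in terms of token 256, the order of the merges matters: 256 must already be in the vocabulary when 257 is added.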

In [56]:
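a = bytes(1)    # bytes(n) with an int argument builds n zero bytes, so this is b'\x00'
b = bytes(255)  # 255 zero bytes, not the single byte 0xFF (that would be bytes([255]))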

print(a) 
print(b) 

test_string = b"".join([a, b]) 

print(test_string.decode("utf-8"))

b'\x00'
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x

In [49]:
# ----- decode -----
# bytes([i]) is the single byte with value i; bytes(i) would be i zero bytes
vocab = {i: bytes([i]) for i in range(256)}

# add merged tokens in the order the merges were learned (dicts preserve insertion order)
for key, value in merges.items(): 
    vocab[value] = vocab[key[0]] + vocab[key[1]]

In [50]:
vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',