# Tokenizer 

Many of the oddities of LLMs trace back to tokenization, which we will dive into today.

In the tiny Shakespeare nanoGPT, we simply used a character-level tokenizer. This makes sequences very long, i.e. the context length explodes. Since __the cost of self-attention grows quadratically with context length__, we are incentivized to use tokens that each cover more text. <br>
Sub-word tokenizers are the most common choice, built with algorithms such as __Byte pair encoding__ (BPE).
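A rough back-of-the-envelope version of this argument, looking only at the $T \times T$ attention score matrix for a sequence of $T$ tokens and assuming (as a common rule of thumb for English, not a measurement from these notes) that one subword token covers about 4 characters:

$$
\text{attention cost} \propto T^2, \qquad T_{\text{subword}} \approx \frac{T_{\text{char}}}{4}
\;\implies\; \frac{\text{cost}_{\text{subword}}}{\text{cost}_{\text{char}}} \approx \left(\frac{1}{4}\right)^2 = \frac{1}{16}
$$

Equivalently, for a fixed block_size, the same number of positions covers roughly 4× more text.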

__Tokenization is essentially the process of converting text into sequences of integers and vice versa.__
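A minimal sketch of both directions, in the spirit of the character-level tokenizer used for tiny Shakespeare. The toy text and the names `stoi`, `itos`, `encode`, `decode` are illustrative:

```python
# Character-level tokenizer: every distinct character gets its own integer id
text = "hello world"
chars = sorted(set(text))                     # vocabulary = unique characters
stoi = {ch: i for i, ch in enumerate(chars)}  # string -> integer
itos = {i: ch for ch, i in stoi.items()}      # integer -> string

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return "".join(itos[i] for i in ids)

ids = encode("hello")
print(ids)          # [3, 2, 4, 4, 5]
print(decode(ids))  # hello
```

Every character costs one token here, which is exactly why the context length explodes.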


Experiment with tokenizers here: https://tiktokenizer.vercel.app

- Tokenization can be perplexing. For example:
1. Egg  = E + gg
2. I have an egg = I h + ave + an + egg
3. Egg vs egg vs EGG

The same word can be tokenized differently depending on case, context (e.g. whether it is preceded by a space), etc.

- encoding: 'Eg' (text chunk) -> 1223 (token id) -> look up row 1223 of the embedding table
- decoding: the reverse, token id -> text chunk
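A minimal PyTorch sketch of the last step of the encoding path, assuming a GPT-2-sized vocabulary (50257) and a toy embedding width of 16. The id 1223 is the same illustrative value as above. The lookup in `nn.Embedding` is just row indexing into its weight matrix:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
vocab_size, n_embd = 50257, 16          # 50257 = GPT-2 vocab size; n_embd is a toy value
emb = nn.Embedding(vocab_size, n_embd)  # one learnable row per token id

ids = torch.tensor([1223])  # illustrative token id produced by the tokenizer
vec = emb(ids)              # shape (1, n_embd)

# The lookup returns exactly row 1223 of the weight matrix
print(vec.shape)                              # torch.Size([1, 16])
print(torch.equal(vec[0], emb.weight[1223]))  # True
```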

- Training data usually contains much more English than other languages, so English ends up with longer subword tokens, while in, say, Korean or Japanese the tokens are shorter (often only a character or two). <br>
=> more tokens for the same content => less of it fits within the block_size

- In the GPT-2 tokenizer, for instance, each space ' ' in indentation is its own token. So Python indentation costs many tokens, bloating the context while conveying little useful information. This is one reason GPT-2 performed poorly on Python tasks.

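- The same text may encode to a different number of tokens depending on the tokenizer.
    - __Intuition__: the GPT-4 tokenizer has roughly 2 $\times$ the vocabulary of the GPT-2 tokenizer (~100k vs ~50k tokens), so the same text takes noticeably fewer tokens. But more is not always better: as the embedding table (and the output layer) grows, each token is seen less often during training and the output probabilities are spread over more entries.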
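To make "the embedding table grows" concrete, consider that the table has one row per vocabulary entry, so its parameter count is vocab_size × n_embd. This sketch assumes GPT-2 small's width of 768 for both vocabularies purely for comparison and uses a round 100,000 for the ~100k vocabulary. It is an illustration, not GPT-4's actual (unpublished) configuration:

```python
def embedding_params(vocab_size, n_embd):
    # one row of n_embd learnable numbers per token in the vocabulary
    return vocab_size * n_embd

n_embd = 768  # assumption: GPT-2 small's width, used for both vocabularies below
gpt2_params = embedding_params(50257, n_embd)    # GPT-2 vocabulary size
approx_100k = embedding_params(100_000, n_embd)  # round ~100k vocabulary

print(f"{gpt2_params:,}")  # 38,597,376
print(f"{approx_100k:,}")  # 76,800,000
```

An output layer of the same shape (unless its weights are tied to the embedding, as in GPT-2) adds the same count again, and its softmax runs over every vocabulary entry.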


In [108]:
import torch

In [1]:
"안녕하세요 👋 (hello in Korean!)"

'안녕하세요 👋 (hello in Korean!)'

What even is a string in Python? => a sequence of [Unicode code points](https://en.wikipedia.org/wiki/Unicode)

A character's code point can be obtained with Python's built-in `ord()` function.

A full breakdown of [Unicode](https://www.reedbeta.com/blog/programmers-intro-to-unicode/)

In [6]:
print([ord(char) for char in "안녕하세요 👋"])

[50504, 45397, 54616, 49464, 50836, 32, 128075]
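`chr()` is the inverse of `ord()`. Printing the same code points in the conventional `U+` hex notation and mapping them back recovers the original string:

```python
s = "안녕하세요 👋"
code_points = [ord(ch) for ch in s]

# Conventional Unicode notation: U+ followed by at least 4 uppercase hex digits
print([f"U+{cp:04X}" for cp in code_points])
# ['U+C548', 'U+B155', 'U+D558', 'U+C138', 'U+C694', 'U+0020', 'U+1F44B']

# chr() maps each code point back to its character
print("".join(chr(cp) for cp in code_points) == s)  # True
```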


Why not tokenize directly on Unicode code points?
- The vocabulary would be very large (~150k characters)
- The standard keeps changing, so the vocabulary would not be stable

Standardization efforts led to the [UTF-8](http://utf8everywhere.org/), UTF-16 and UTF-32 encodings, which map code points to bytes. UTF-8 is the only one of these that is backward compatible with ASCII.
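A quick check of the ASCII-compatibility claim: for pure-ASCII text, UTF-8 produces exactly the ASCII bytes, while UTF-16 and UTF-32 pad each character with zero bytes. The `-le` (little-endian) codec variants are used so Python does not prepend a byte-order mark:

```python
s = "hello"
print(s.encode("ascii"))      # b'hello'
print(s.encode("utf-8"))      # b'hello'  (identical to ASCII)
print(s.encode("utf-16-le"))  # b'h\x00e\x00l\x00l\x00o\x00'
print(len(s.encode("utf-32-le")))  # 20, i.e. 4 bytes per character
```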

In [None]:
s = "안녕하세요 👋 (hello in Korean!"

utf8 = list(s.encode('utf-8'))    # one int (0..255) per byte
utf16 = list(s.encode('utf-16'))  # 'utf-16' output starts with a 2-byte byte-order mark

print(len(s), len(utf8), len(utf16))  # 25 38 54

One may remark that UTF-16 looks 'wasteful': it expresses the same string with a longer list, since every ASCII character takes 2 bytes (one of them zero) instead of 1.

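But _why_ is the UTF-8 encoding longer than the number of characters? $\implies$ because UTF-8 is variable-length: ASCII chars (a, e, b, ',', %, etc.) take 1 byte, each Korean char here takes 3 bytes, and the emoji takes 4 bytes. <br>
Still, using raw UTF-8 bytes as tokens would give a tiny vocabulary (256) and very long sequences. So we come up with a middle ground: [__Byte pair encoding__](https://en.wikipedia.org/wiki/Byte-pair_encoding), which starts from bytes and repeatedly merges frequent pairs into new tokens.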
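The per-character byte counts can be checked directly. The `-le` codec avoids counting a byte-order mark. Note that UTF-16 is actually shorter than UTF-8 for the Korean character, and equal for the emoji:

```python
chars = ["a", "%", "안", "👋"]

# (bytes in UTF-8, bytes in UTF-16 without a byte-order mark) for each character
widths = {ch: (len(ch.encode("utf-8")), len(ch.encode("utf-16-le"))) for ch in chars}
print(widths)  # {'a': (1, 2), '%': (1, 2), '안': (3, 2), '👋': (4, 4)}
```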

Now let's implement BPE on a simple text.

In [33]:
# text = open('text.txt').read()  -- may fail: the platform default encoding (e.g. cp1252 on Windows) can't decode emojis etc., so specify the encoding!

with open('text.txt', encoding='utf-8') as f:
    text = f.read()

list(text[:5].encode('utf-8'))

[34, 239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131]

In [44]:
tokens = list(map(int, text.encode('utf-8')))  # raw bytes -> list of ints in 0..255

# print(len(text), len(tokens))

Step 1 of BPE: iterate over the UTF-8 encoded text (`tokens`) and find the pair of consecutive bytes that occurs most frequently.

In [46]:
# first, count how often each consecutive pair occurs
# a very pythonic way to do this:

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
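An equivalent sketch using the standard library's `collections.Counter`, which counts the `zip(ids, ids[1:])` pairs in one call and returns the top pairs via `most_common` (equal counts are ordered by first occurrence). The toy string `"aaabdaaabac"` (the example from the Wikipedia BPE article) and the name `get_stats_counter` are illustrative:

```python
from collections import Counter

def get_stats_counter(ids):
    # same counting as get_stats: one entry per consecutive pair
    return Counter(zip(ids, ids[1:]))

demo_ids = list("aaabdaaabac".encode("utf-8"))  # a=97, b=98, c=99, d=100
stats_c = get_stats_counter(demo_ids)
print(stats_c.most_common(2))  # [((97, 97), 4), ((97, 98), 2)]
```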

In [58]:
# sort (count, pair) tuples so the most frequent pairs come first
print(sorted(((v,k) for k,v in stats.items()), reverse=True))

[(20, (101, 32)), (15, (240, 159)), (12, (226, 128)), (12, (105, 110)), (10, (115, 32)), (10, (97, 110)), (10, (32, 97)), (9, (32, 116)), (8, (116, 104)), (7, (159, 135)), (7, (159, 133)), (7, (97, 114)), (6, (239, 189)), (6, (140, 240)), (6, (128, 140)), (6, (116, 32)), (6, (114, 32)), (6, (111, 114)), (6, (110, 103)), (6, (110, 100)), (6, (109, 101)), (6, (104, 101)), (6, (101, 114)), (6, (32, 105)), (5, (117, 115)), (5, (115, 116)), (5, (110, 32)), (5, (100, 101)), (5, (44, 32)), (5, (32, 115)), (4, (116, 105)), (4, (116, 101)), (4, (115, 44)), (4, (114, 105)), (4, (111, 117)), (4, (111, 100)), (4, (110, 116)), (4, (110, 105)), (4, (105, 99)), (4, (104, 97)), (4, (103, 32)), (4, (101, 97)), (4, (100, 32)), (4, (99, 111)), (4, (97, 109)), (4, (85, 110)), (4, (32, 119)), (4, (32, 111)), (4, (32, 102)), (4, (32, 85)), (3, (118, 101)), (3, (116, 115)), (3, (116, 114)), (3, (116, 111)), (3, (114, 116)), (3, (114, 115)), (3, (114, 101)), (3, (111, 102)), (3, (111, 32)), (3, (108, 108)), (

What does `max(stats, key=stats.get)` do?

By default, `max(dictionary)` would return the largest key (based on Python’s default ordering, not useful here).

But when you give it `key=stats.get`, you’re telling Python: <br>
👉 “Instead of comparing the keys directly, compare them based on their values in the dictionary.”

### Internally:

Suppose `stats = {('a','b'): 3, ('b','c'): 5, ('c','d'): 2}`. Python evaluates `stats.get` on each key, step by step:

For `('a','b')`: `stats.get(('a','b'))` = 3

For `('b','c')`: `stats.get(('b','c'))` = 5

For `('c','d')`: `stats.get(('c','d'))` = 2

__Now it just picks the key with the largest value: ('b','c') because 5 is the maximum.__
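The walkthrough above as runnable code, using the same hypothetical toy counts. It also shows the plain `max(dict)` behaviour mentioned earlier:

```python
# hypothetical toy counts from the walkthrough
toy_stats = {('a', 'b'): 3, ('b', 'c'): 5, ('c', 'd'): 2}

print(max(toy_stats))                                # ('c', 'd'): largest key, tuples compared lexicographically
print(max(toy_stats, key=toy_stats.get))             # ('b', 'c'): key with the largest value
print(max(toy_stats.items(), key=lambda kv: kv[1]))  # (('b', 'c'), 5): pair and its count together
```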

In [66]:
# equivalent: max(stats, key=lambda k: stats[k])

max(stats, key = stats.get) 

(101, 32)

In [111]:
def get_stats2(ids):
    # self-implementation: count pairs in a 2D table instead of a dict
    dim = max(ids) + 1  # ids run from 0 to max(ids) inclusive, so the table needs max(ids)+1 rows/cols
    count = torch.zeros(dim, dim, dtype=torch.long)

    for a, b in zip(ids, ids[1:]):  # consecutive pairs, as in get_stats
        count[a, b] += 1

    # argmax over the flattened table, then map the flat index back to (row, col)
    flat_index = torch.argmax(count)
    row = flat_index // count.shape[1]
    col = flat_index % count.shape[1]
    return (row.item(), col.item())

get_stats2(tokens)

(101, 32)

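Essentially, Karpathy's `get_stats()` returns a dictionary and then `max` extracts the most frequent token pair. Here I do it natively in a count table using argmax. Two caveats: the table has $(\max(\text{ids})+1)^2$ entries, which is fine for the 256 byte values but grows quickly once merged token ids are added. Also, when several pairs tie for the top count, `torch.argmax` (first maximum in row-major order) and `max` over the dict (first maximum in insertion order) may return different pairs.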
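After finding the most frequent pair, the next BPE step is to replace it with a new token id and repeat. Below is a minimal sketch, assuming new ids start at 256 (right after the 256 byte values). The function names `merge`, `train_bpe` and `bpe_decode` are hypothetical. Ties are broken by `max`, which keeps the first pair counted:

```python
from collections import Counter

def merge(ids, pair, idx):
    # replace every non-overlapping occurrence of `pair`, scanning left to right, with `idx`
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def train_bpe(text, num_merges):
    ids = list(text.encode("utf-8"))
    merges = {}  # (id, id) -> new id, in the order the merges were learned
    for k in range(num_merges):
        counts = Counter(zip(ids, ids[1:]))
        if not counts:
            break
        pair = max(counts, key=counts.get)  # most frequent pair (first one on ties)
        new_id = 256 + k                    # bytes use 0..255, so new tokens start at 256
        ids = merge(ids, pair, new_id)
        merges[pair] = new_id
    return ids, merges

def bpe_decode(ids, merges):
    vocab = {i: bytes([i]) for i in range(256)}
    for (a, b), idx in merges.items():  # insertion order: parts exist before they are combined
        vocab[idx] = vocab[a] + vocab[b]
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

bpe_ids, bpe_merges = train_bpe("aaabdaaabac", 3)
print(bpe_ids)     # [258, 100, 258, 97, 99]
print(bpe_merges)  # {(97, 97): 256, (256, 97): 257, (257, 98): 258}
print(bpe_decode(bpe_ids, bpe_merges))  # aaabdaaabac
```

Three merges shrink 11 bytes to 5 tokens, and decoding recovers the original text exactly. `errors="replace"` guards against token sequences that do not form valid UTF-8.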