# Tokenizer 

Many of the oddities of LLMs trace back to tokenization, which we will deep dive into today.

In the tiny-Shakespeare nanoGPT, we simply used a character-level tokenizer. This makes the sequences, and hence the context length needed for a given amount of text, explode. Since the cost of self-attention grows __quadratically with context length__, we are incentivized to pack more characters into each token. <br>
Sub-word tokenizers, built with algorithms such as __Byte pair encoding__, are the most common choice.
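To make the quadratic cost concrete, here is a back-of-the-envelope sketch. The text length and the 4-characters-per-token average are hypothetical numbers chosen for illustration, not measurements.

```python
# Hypothetical numbers for illustration only
text_len_chars = 1000   # length of some text, in characters
chars_per_token = 4     # assumed average for a subword tokenizer on English text

T_char = text_len_chars                        # character-level: one token per character
T_subword = text_len_chars // chars_per_token  # subword: fewer, longer tokens

# Self-attention scores every token against every other token: a T x T matrix per head
attn_char = T_char ** 2
attn_subword = T_subword ** 2

print(attn_char, attn_subword, attn_char // attn_subword)  # 1000000 62500 16
```

A 4x shorter sequence means 16x fewer attention scores to compute.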

__Tokenization is essentially the process of converting text into sequences of integers and vice versa.__
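For reference, a minimal character-level tokenizer in the spirit of the tiny-Shakespeare nanoGPT could look like this; `train_text` is a hypothetical stand-in for the training corpus. Every character is one token, so the token sequence is exactly as long as the text.

```python
train_text = "hello world"  # hypothetical stand-in for the training corpus

# Vocabulary: every distinct character gets an integer id
chars = sorted(set(train_text))
stoi = {ch: i for i, ch in enumerate(chars)}  # character -> integer
itos = {i: ch for ch, i in stoi.items()}      # integer -> character

def encode_chars(s):
    """text -> list of integers"""
    return [stoi[c] for c in s]

def decode_chars(ids):
    """list of integers -> text"""
    return "".join(itos[i] for i in ids)

char_ids = encode_chars("hello")
print(char_ids)                # [3, 2, 4, 4, 5]
print(decode_chars(char_ids))  # hello
```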


Experiment with tokenizers here: https://tiktokenizer.vercel.app

- Tokenization can be perplexing, for example:
1. Egg  = E + gg
2. I have an egg = I h + ave + an + egg
3. Egg vs egg vs EGG

The same word can be tokenized differently depending on case, context (e.g. whether it follows a space), etc.

- encoding: 'Eg' (text) -> 1223 (token id) -> fetch row 1223 of the embedding table
- decoding: vice versa
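The "fetch a row" step is exactly what an embedding layer does. A small PyTorch sketch: the token id 1223 is the hypothetical id from the example above, 50257 is GPT-2's vocabulary size, and the embedding width of 8 is an arbitrary toy value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size = 50257  # GPT-2's vocabulary size
n_embd = 8          # toy embedding width, chosen arbitrarily

emb = nn.Embedding(vocab_size, n_embd)  # embedding table: one learnable row per token id

token_ids = torch.tensor([1223])  # hypothetical token id for 'Eg'
vectors = emb(token_ids)          # shape (1, n_embd)

# The lookup is literally row 1223 of the weight matrix
print(vectors.shape)                              # torch.Size([1, 8])
print(torch.equal(vectors[0], emb.weight[1223]))  # True
```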

- Tokenizer training data usually contains much more English than other languages. As a result, text in, say, Korean or Japanese is split into shorter tokens (often just 1-2 characters), whereas English gets longer subword tokens. <br>
=> More tokens for the same content => less content fits within the block_size

- In the GPT-2 tokenizer, for instance, each space ' ' in Python indentation becomes its own token. Indentation therefore costs many tokens, bloating the sequence the transformer has to process while conveying little useful information. This is one reason GPT-2 performed poorly on Python tasks.

- The same text may have a different number of tokens depending on the tokenizer.
    - __Intuition__: the GPT-4 tokenizer has roughly 2 $\times$ the vocabulary of the GPT-2 tokenizer (~100k vs ~50k tokens), so the same text encodes to fewer tokens. More is not always better, though: as the vocabulary $\uparrow$, the embedding table and output layer grow, and the output probabilities are spread over more tokens.
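One way to see the cost of a bigger vocabulary: both the token embedding table and the output layer (which produces one logit per token) scale linearly with the vocabulary size. A rough parameter count, assuming a model width of 768 (GPT-2 small's) and rounding GPT-4's vocabulary to 100k:

```python
n_embd = 768  # assumed model width (GPT-2 small uses 768)

vocab_sizes = {
    "GPT-2": 50257,
    "GPT-4 (~100k)": 100_000,  # rounded; roughly 2x GPT-2's vocabulary
}

params = {}
for name, n_vocab in vocab_sizes.items():
    embedding_params = n_vocab * n_embd  # token embedding table: one row per token
    output_params = n_embd * n_vocab     # output layer: one logit per token (bias and weight tying ignored)
    params[name] = (embedding_params, output_params)
    print(f"{name}: embedding {embedding_params:,} | output layer {output_params:,}")
```

Doubling the vocabulary roughly doubles both, and with the same training data each individual token is seen less often.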


In [1]:
import torch

In [2]:
"안녕하세요 👋 (hello in Korean!)"

'안녕하세요 👋 (hello in Korean!)'

What even is a string in Python? => an immutable sequence of [Unicode code points](https://en.wikipedia.org/wiki/Unicode)

The code point of a character can be obtained with the `ord()` function; `chr()` is its inverse.

A full breakdown of [Unicode](https://www.reedbeta.com/blog/programmers-intro-to-unicode/)

In [6]:
print([ord(char) for char in "안녕하세요 👋"])

[50504, 45397, 54616, 49464, 50836, 32, 128075]
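`chr()` undoes `ord()`. A quick round-trip check, also printing the code points in the usual `U+XXXX` hexadecimal notation:

```python
s = "안녕하세요 👋"

code_points = [ord(ch) for ch in s]  # characters -> integers
print([f"U+{cp:04X}" for cp in code_points])
# ['U+C548', 'U+B155', 'U+D558', 'U+C138', 'U+C694', 'U+0020', 'U+1F44B']

restored = "".join(chr(cp) for cp in code_points)  # integers -> characters
print(restored == s)  # True
```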


Why not use Unicode code points directly as tokens?
- The vocabulary would be very large (~150k code points)
- The standard keeps changing, so the vocabulary would not be stable

To store text as bytes, Unicode defines the [UTF-8](http://utf8everywhere.org/), UTF-16 and UTF-32 encodings. UTF-8 is the only one of these that is backward compatible with ASCII.
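A small sketch comparing the encodings: UTF-8 uses 1 to 4 bytes per character, UTF-16 uses 2 or 4, and UTF-32 always uses 4; pure-ASCII text has identical bytes under ASCII and UTF-8. The `-le` (little-endian) codec variants are used here so that Python does not prepend a byte-order mark.

```python
encodings = ("utf-8", "utf-16-le", "utf-32-le")

# bytes needed per character under each encoding
sizes = {ch: [len(ch.encode(enc)) for enc in encodings] for ch in ["A", "é", "안", "👋"]}
for ch, n_bytes in sizes.items():
    print(repr(ch), n_bytes)
# 'A' [1, 2, 4]
# 'é' [2, 2, 4]
# '안' [3, 2, 4]
# '👋' [4, 4, 4]

# Backward compatibility: ASCII text encodes to the same bytes in UTF-8
ascii_compatible = "hello".encode("ascii") == "hello".encode("utf-8")
print(ascii_compatible)  # True
```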

In [None]:
s = "안녕하세요 👋 (hello in Korean!"
print(len(s))                    # 25 code points
print(len(s.encode('utf-8')))    # 38 bytes
print(len(s.encode('utf-16')))   # 54 bytes, including a 2-byte byte-order mark

UTF-16 looks a bit 'wasteful' here, since it expresses the same string using a longer list of bytes (for ASCII characters, one of the two bytes is always 0).

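But _why_ is the UTF-8 encoding longer than the number of characters? $\implies$ because UTF-8 is variable-length: ASCII characters (a, e, b, ',', k, l, % etc.) take 1 byte, each Korean character takes 3 bytes, and the emoji takes 4 bytes. Here: $19 \times 1 + 5 \times 3 + 1 \times 4 = 38$ bytes. <br>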

__Each byte is a number between 0 and 255. So the 👋 emoji = [240, 159, 145, 139], "A" = [65], "e" = [101], and so on.__
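A quick way to check these byte values and the 38-byte count: encode each character on its own.

```python
s = "안녕하세요 👋 (hello in Korean!"

for ch in ["A", "e", "안", "👋"]:
    print(repr(ch), list(ch.encode("utf-8")))
# 'A' [65]
# 'e' [101]
# '안' [236, 149, 136]
# '👋' [240, 159, 145, 139]

# The encoded length is the sum of the per-character byte counts
per_char = [len(ch.encode("utf-8")) for ch in s]
print(sum(per_char), len(s.encode("utf-8")))  # 38 38
```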

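Still, using raw UTF-8 bytes directly as tokens would give a tiny vocabulary (256) and very long sequences. So we look for a middle ground: [__Byte pair encoding__](https://en.wikipedia.org/wiki/Byte-pair_encoding), which starts from bytes and iteratively merges the most frequent pair into a new token.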

Now let's implement BPE on a simple text.

In [3]:
# text = open('text.txt').read() -- doesn't work reliably, e.g. on Windows, where the default encoding (cp1252) can't correctly decode UTF-8 emojis etc.; so specify the encoding!

with open('text.txt', encoding='utf-8') as f:
    text = f.read()

list(text[:5].encode('utf-8'))

[34, 239, 188, 181, 239, 189, 142, 239, 189, 137, 239, 189, 131]

In [90]:
tokens = text.encode("utf-8") # raw bytes

print(type(tokens))
print(f'combined bytes in hex: {tokens[:10]} | all bytes in hex together: {tokens[:10].hex()}')

tokens = list(map(int, text.encode('utf-8'))) # list of ints in the range 0..255

print(tokens[:10])

# print(len(text), len(tokens))

<class 'bytes'>
combined bytes in hex: b'"\xef\xbc\xb5\xef\xbd\x8e\xef\xbd\x89' | all bytes in hex together: 22efbcb5efbd8eefbd89
[34, 239, 188, 181, 239, 189, 142, 239, 189, 137]


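`b'"\xef\xbc\xb5\xef\xbd\x8e\xef\xbd\x89'`: the `b` prefix marks a `bytes` object, and `\x` means the next 2 characters are one byte written in hexadecimal (e.g. `ef`, `bc`, `b5`). Printable ASCII bytes, such as the leading `"` (34), are shown as the character itself.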
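To see the same thing interactively: indexing a `bytes` object gives plain ints, and `.hex()` / `bytes.fromhex()` convert to and from the hex string. The string below reuses the first two characters of `text` (a double quote and a fullwidth "Ｕ").

```python
b = '"Ｕ'.encode("utf-8")  # a double quote followed by a fullwidth 'U' (U+FF35)

print(b)                # b'"\xef\xbc\xb5'
print(b.hex())          # 22efbcb5
print(b[1], hex(b[1]))  # 239 0xef  (indexing bytes yields an int)

assert bytes.fromhex(b.hex()) == b  # hex string -> bytes round trip
```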

In [None]:
some = 'some words'
print(list(map(int, some.encode('utf-8')))) # map(int, ...) is unnecessary: iterating over bytes already yields ints

print(list(some.encode('utf-8')))

[115, 111, 109, 101, 32, 119, 111, 114, 100, 115]
[115, 111, 109, 101, 32, 119, 111, 114, 100, 115]


Step 1 of BPE: iterate over the UTF-8 encoded text (`tokens`) and find the pair of consecutive bytes that occurs most frequently.

In [5]:
# find the most frequently occurring pair first
# a very pythonic way to do this: 

def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)

In [None]:
# sort the (count, pair) tuples in descending order of count
print(sorted(((v,k) for k,v in stats.items()), reverse=True)) 

# equivalent: max(stats, key=lambda k: stats[k])
# 'ranking by the value to get the maximum key' 
top_pair = max(stats, key = stats.get) 
top_pair

[(20, (101, 32)), (15, (240, 159)), (12, (226, 128)), (12, (105, 110)), (10, (115, 32)), (10, (97, 110)), (10, (32, 97)), (9, (32, 116)), (8, (116, 104)), (7, (159, 135)), (7, (159, 133)), (7, (97, 114)), (6, (239, 189)), (6, (140, 240)), (6, (128, 140)), (6, (116, 32)), (6, (114, 32)), (6, (111, 114)), (6, (110, 103)), (6, (110, 100)), (6, (109, 101)), (6, (104, 101)), (6, (101, 114)), (6, (32, 105)), (5, (117, 115)), (5, (115, 116)), (5, (110, 32)), (5, (100, 101)), (5, (44, 32)), (5, (32, 115)), (4, (116, 105)), (4, (116, 101)), (4, (115, 44)), (4, (114, 105)), (4, (111, 117)), (4, (111, 100)), (4, (110, 116)), (4, (110, 105)), (4, (105, 99)), (4, (104, 97)), (4, (103, 32)), (4, (101, 97)), (4, (100, 32)), (4, (99, 111)), (4, (97, 109)), (4, (85, 110)), (4, (32, 119)), (4, (32, 111)), (4, (32, 102)), (4, (32, 85)), (3, (118, 101)), (3, (116, 115)), (3, (116, 114)), (3, (116, 111)), (3, (114, 116)), (3, (114, 115)), (3, (114, 101)), (3, (111, 102)), (3, (111, 32)), (3, (108, 108)), (

(101, 32)

Note:<br>
Andrej's implementation for `top_pair` just runs way faster than my own version below: <br>

```python
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(ids)
pair = max(stats, key=stats.get)
```

In [53]:
def get_stats2(ids):

    # self-implementation using a 2D count table: count[a, b] = number of times b follows a
    
    dim = max(ids)
    count = torch.zeros(dim+1, dim+1) # +1 because ids range from 0 to dim inclusive

    for k in range(len(ids)-1):
        count[ids[k], ids[k+1]]  += 1
    
    # arg max for 2D
    flat_index = torch.argmax(count) 
    row = flat_index // count.shape[1]
    col = flat_index % count.shape[1]
    return (row.item(), col.item())

top_pair = get_stats2(tokens)
top_pair

(101, 32)

Essentially, Karpathy's get_stats() returns a dictionary of pair counts and then uses max() to extract the most frequent pair, whereas I fill a 2D count table and take its argmax directly.
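A third option, not from the lecture: the standard library's `collections.Counter` can count the pairs in one line, and `most_common(1)` returns the top pair. A sketch on a hypothetical toy list (names prefixed with `toy_` so they don't overwrite the notebook's `top_pair`):

```python
from collections import Counter

def get_stats_counter(ids):
    """Count every consecutive pair (ids[i], ids[i+1])."""
    return Counter(zip(ids, ids[1:]))

toy_ids = [1, 2, 3, 1, 2, 4]  # hypothetical toy token list
toy_stats = get_stats_counter(toy_ids)
toy_top_pair = toy_stats.most_common(1)[0][0]
print(toy_top_pair)  # (1, 2)
```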

In [24]:
chr(101), chr(32)

('e', ' ')

$\implies$ 'e ' is the most common pair, so a lot of words in the text seem to end with 'e'.

Since IDs 0-255 are already taken by the raw bytes, we must define a new token ID (256) for 'e '.

The merge logic can be captured in a function: <br>
__P.S.:__ (I wrote one on my own, but it missed edge cases and could introduce bugs. Need to genuinely learn these basic manipulations!)

In [39]:
def merge(ids, pair, idx):
  # in the list of ints (ids), replace all consecutive occurrences of pair with the new token idx
  newids = []
  i = 0
  while i < len(ids):
    # if we are not at the very last position AND the pair matches, replace it
    if i < len(ids) - 1 and (ids[i],ids[i+1]) == (pair[0], pair[1]):
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

tokens2 = merge(tokens, top_pair, 256)
print('Original length = ', len(tokens), '\nLength after merge = ', len(tokens2))

Original length =  617 
Length after merge =  597


So now we have the tools to identify the top_pair and merge it. <br>

__How much do we compress? That is a hyperparameter__. More merges => larger vocabulary, shorter token sequences.

Let's perform this exercise on the complete text of [this blog](https://www.reedbeta.com/blog/programmers-intro-to-unicode/).

In [43]:
with open('more_text.txt', encoding='utf-8') as f:
    full_text = f.read()

full_tokens = full_text.encode("utf-8") # raw bytes
full_tokens = list(map(int, full_tokens)) # convert to a list of integers in range 0..255 for convenience

len(full_text), len(full_tokens)

(22636, 23884)

Now let's put everything together in one cell:

In [91]:
vocab_size = 276 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(full_tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
length_post_merge = []

for i in range(num_merges):
  pair = get_stats2(ids)
  idx = 256 + i
  ids = merge(ids, pair, idx)
  
  # tracking
  length_post_merge.append(len(ids))
  merges[pair] = idx

  print(f"merging {pair} into a new token {idx} | sequence length = {length_post_merge[i]}")

merging (101, 32) into a new token 256 | sequence length = 23252
merging (105, 110) into a new token 257 | sequence length = 22813
merging (115, 32) into a new token 258 | sequence length = 22402
merging (116, 104) into a new token 259 | sequence length = 22065
merging (101, 114) into a new token 260 | sequence length = 21777
merging (99, 111) into a new token 261 | sequence length = 21490
merging (116, 32) into a new token 262 | sequence length = 21208
merging (226, 128) into a new token 263 | sequence length = 20962
merging (44, 32) into a new token 264 | sequence length = 20720
merging (97, 110) into a new token 265 | sequence length = 20494
merging (111, 114) into a new token 266 | sequence length = 20283
merging (100, 32) into a new token 267 | sequence length = 20074
merging (97, 114) into a new token 268 | sequence length = 19897
merging (101, 110) into a new token 269 | sequence length = 19728
merging (257, 103) into a new token 270 | sequence length = 19565
merging (261, 100) 

In [92]:
print(f"In {num_merges} merges:")
print(f"compression ratio = {len(full_tokens)} -> {length_post_merge[-1]} = {len(full_tokens)/length_post_merge[-1]:.2f}x")

In 20 merges:
compression ratio = 23884 -> 18832 = 1.27x


## Tokenizer is independent of the LM

Note, the Tokenizer is a completely separate, independent module from the LLM. It has its own training dataset of text (which could be different from that of the LLM), on which you train the vocabulary using the Byte Pair Encoding (BPE) algorithm. It then translates back and forth between raw text and sequences of tokens. The LLM later only ever sees the tokens and never directly deals with any text.

<img title="a title" alt="Alt text" src="llm_tokenizer.png" width = 60%>

It facilitates raw text $\rightleftharpoons$ tokens conversion. 

The composition of the tokenizer's training set (how much French, Japanese, code, etc. it contains) directly affects LM performance, because the merges are learned from that data: well-represented content gets longer merged tokens and therefore shorter, denser sequences.
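Because the tokenizer is its own module, it can be trained on one text and then applied to any other. Below is a compact, self-contained sketch of that split (train merges, then encode/decode), reusing the same pair-counting and merging logic as this notebook under new names. The `bpe_encode` step is an assumption of this sketch, not something implemented in these notes so far: it applies the learned merges greedily, earliest-learned merge first. The training text is hypothetical.

```python
def count_pairs(ids):
    """Count consecutive pairs, as in get_stats()."""
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge_pair(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the new token `idx`, as in merge()."""
    newids, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

def bpe_train(text, vocab_size):
    """Learn (vocab_size - 256) merges from `text`; returns {pair: new_id} in merge order."""
    ids = list(text.encode("utf-8"))
    merges = {}
    for idx in range(256, vocab_size):
        stats = count_pairs(ids)
        if not stats:  # fewer than 2 tokens left: nothing to merge
            break
        pair = max(stats, key=stats.get)
        ids = merge_pair(ids, pair, idx)
        merges[pair] = idx
    return merges

def bpe_encode(text, merges):
    """text -> token ids, applying the learned merges earliest-first."""
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        stats = count_pairs(ids)
        # the pair with the lowest merge id was learned first; unknown pairs rank last
        pair = min(stats, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:  # no learned merge applies anymore
            break
        ids = merge_pair(ids, pair, merges[pair])
    return ids

def bpe_decode(ids, merges):
    """token ids -> text, expanding each id back into its bytes."""
    vocab = {idx: bytes([idx]) for idx in range(256)}
    for (p0, p1), idx in merges.items():  # merge order guarantees p0 and p1 are already in vocab
        vocab[idx] = vocab[p0] + vocab[p1]
    return b"".join(vocab[i] for i in ids).decode("utf-8", errors="replace")

# Usage: train on one (hypothetical) text, then apply to a different one
tokenizer_text = "the cat sat on the mat, then the cat sat on the hat."
toy_merges = bpe_train(tokenizer_text, vocab_size=266)

sample = "the bat sat on the cat"
encoded = bpe_encode(sample, toy_merges)
print(len(sample.encode("utf-8")), "bytes ->", len(encoded), "tokens")
print(bpe_decode(encoded, toy_merges) == sample)  # True
```

Note that `sample` never appeared in `tokenizer_text`, yet it still round-trips: the tokenizer only needs its merges, and anything it has no merge for simply stays as raw bytes.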

### Next step

Now let's create a decode function that accepts a list of integers (token IDs) and returns the corresponding Python string, including our newly merged tokens.

In [76]:
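# vocab maps each token id to its bytes; ids 0..255 are the raw bytes themselves
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items(): # dicts keep insertion order, so p0 and p1 are already in vocab
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    """Returns a python string corresponding to the input list of integers ids"""
    text_bytes = b"".join(vocab[idx] for idx in ids)
    return text_bytes.decode('utf-8', errors='replace') # invalid UTF-8 becomes U+FFFD instead of raising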

In [None]:
bytes([17]) # b prefix = bytes object; \x11 = the byte 17 written in hexadecimal

b'\x11'
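`bytes()` only accepts values in 0..255, which is why a decode built directly on `bytes(ids)` cannot handle merged token ids such as 256: each merged id has to be expanded back into its byte sequence first. A quick check:

```python
print(bytes([17]))  # b'\x11'

raised = False
try:
    bytes([256])  # 256 is a merged token id, not a valid byte value
except ValueError as err:
    raised = True
    print("ValueError:", err)  # ValueError: bytes must be in range(0, 256)
```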