# Character Level Tokenization

Code taken from [Andrej Karpathy](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing).

In [1]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [2]:
from urllib.request import urlopen
from bs4 import BeautifulSoup

url = "https://raw.githubusercontent.com/cltk/hindi_text_ltrc/master/tulasidaas/Raamacharita_maanasa/1/main.txt"
html = urlopen(url).read()
soup = BeautifulSoup(html, features="html.parser")

# kill all script and style elements
for script in soup(["script", "style"]):
    script.extract()    # rip it out

# get text
text = soup.get_text()
ramayana_text = text
print(type(text))

<class 'str'>


In [3]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  198223


In [4]:
# let's look at the first 1000 characters
print(text[:1000])

जो सुमिरत सिधि होइ गन नायक करिबर बदन।
करउ अनुग्रह सोइ बुद्धि रासि सुभ गुन सदन॥1॥
मूक होइ बाचाल पंगु चढइ गिरिबर गहन।
जासु कृपाँ सो दयाल द्रवउ सकल कलि मल दहन॥2॥
नील सरोरुह स्याम तरुन अरुन बारिज नयन।
करउ सो मम उर धाम सदा छीरसागर सयन॥3॥
कुंद इंदु सम देह उमा रमन करुना अयन।
जाहि दीन पर नेह करउ कृपा मर्दन मयन॥4॥
बंदउ गुरु पद कंज कृपा सिंधु नररूप हरि।
महामोह तम पुंज जासु बचन रबि कर निकर॥5॥
बंदउ गुरु पद पदुम परागा। सुरुचि सुबास सरस अनुरागा॥
अमिय मूरिमय चूरन चारू। समन सकल भव रुज परिवारू॥
सुकृति संभु तन बिमल बिभूती। मंजुल मंगल मोद प्रसूती॥
जन मन मंजु मुकुर मल हरनी। किएँ तिलक गुन गन बस करनी॥
श्रीगुर पद नख मनि गन जोती। सुमिरत दिब्य द्रृष्टि हियँ होती॥
दलन मोह तम सो सप्रकासू। बड़े भाग उर आवइ जासू॥
उघरहिं बिमल बिलोचन ही के। मिटहिं दोष दुख भव रजनी के॥
सूझहिं राम चरित मनि मानिक। गुपुत प्रगट जहँ जो जेहि खानिक॥
दो0-जथा सुअंजन अंजि दृग साधक सिद्ध सुजान।
कौतुक देखत सैल बन भूतल भूरि निधान॥1॥

एहि महँ रघुपति नाम उदारा। अति पावन पुरान श्रुति सारा॥
मंगल भवन अमंगल हारी। उमा सहित जेहि जपत पुरारी॥
भनिति बिचित्र स

In [5]:
# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 (),-0123456789ûँंःअआइईउऊएऐओऔकखगघङचछजझटठडढणतथदधनपफबभमयरलवशषसह़ऽािीुूृेैोौ्।॥०१२३४५६७८९
87


In [6]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("एहि महँ रघुपति "))
print(decode(encode("एहि महँ रघुपति")))

[26, 61, 65, 1, 53, 61, 17, 1, 55, 33, 67, 49, 44, 65, 1]
एहि महँ रघुपति


In [7]:
# let's now encode the entire text dataset and store it into a torch.Tensor
import torch # we use PyTorch: https://pytorch.org
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the LLM look like this

torch.Size([198223]) torch.int64
tensor([37, 72,  1, 60, 67, 53, 65, 55, 44,  1, 60, 65, 47, 65,  1, 61, 72, 22,
         1, 32, 48,  1, 48, 64, 54, 30,  1, 30, 55, 65, 51, 55,  1, 51, 46, 48,
        75,  0, 30, 55, 24,  1, 20, 48, 67, 32, 74, 55, 61,  1, 60, 72, 22,  1,
        51, 67, 46, 74, 47, 65,  1, 55, 64, 60, 65,  1, 60, 67, 52,  1, 32, 67,
        48,  1, 60, 46, 48, 76,  7, 76,  0, 53, 68, 30,  1, 61, 72, 22,  1, 51,
        64, 35, 64, 56,  1, 49, 18, 32, 67,  1, 35, 42, 22,  1, 32, 65, 55, 65,
        51, 55,  1, 32, 61, 48, 75,  0, 37, 64, 60, 67,  1, 30, 69, 49, 64, 17,
         1, 60, 72,  1, 46, 54, 64, 56,  1, 46, 74, 55, 57, 24,  1, 60, 30, 56,
         1, 30, 56, 65,  1, 53, 56,  1, 46, 61, 48, 76,  8, 76,  0, 48, 66, 56,
         1, 60, 55, 72, 55, 67, 61,  1, 60, 74, 54, 64, 53,  1, 44, 55, 67, 48,
         1, 20, 55, 67, 48,  1, 51, 64, 55, 65, 37,  1, 48, 54, 48, 75,  0, 30,
        55, 24,  1, 60, 72,  1, 53, 53,  1, 24, 55,  1, 47, 64, 53,  1, 60, 46,
       

In [8]:
data = torch.tensor(encode("कौन है रघुपति"), dtype=torch.long)
print(data.shape, data.dtype)
print(data)

torch.Size([13]) torch.int64
tensor([30, 73, 48,  1, 61, 71,  1, 55, 33, 67, 49, 44, 65])


In [11]:
text = "नाम जीहँ जपि जागहिं जोगी। बिरति बिरंचि प्रपंच बियोगी॥"
tokens = text.encode("utf-8") # raw bytes
tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 for convenience
print('---')
print(text)
print("length:", len(text))
print('---')
print(tokens)
print("length:", len(tokens))

---
नाम जीहँ जपि जागहिं जोगी। बिरति बिरंचि प्रपंच बियोगी॥
length: 53
---
[224, 164, 168, 224, 164, 190, 224, 164, 174, 32, 224, 164, 156, 224, 165, 128, 224, 164, 185, 224, 164, 129, 32, 224, 164, 156, 224, 164, 170, 224, 164, 191, 32, 224, 164, 156, 224, 164, 190, 224, 164, 151, 224, 164, 185, 224, 164, 191, 224, 164, 130, 32, 224, 164, 156, 224, 165, 139, 224, 164, 151, 224, 165, 128, 224, 165, 164, 32, 224, 164, 172, 224, 164, 191, 224, 164, 176, 224, 164, 164, 224, 164, 191, 32, 224, 164, 172, 224, 164, 191, 224, 164, 176, 224, 164, 130, 224, 164, 154, 224, 164, 191, 32, 224, 164, 170, 224, 165, 141, 224, 164, 176, 224, 164, 170, 224, 164, 130, 224, 164, 154, 32, 224, 164, 172, 224, 164, 191, 224, 164, 175, 224, 165, 139, 224, 164, 151, 224, 165, 128, 224, 165, 165]
length: 143


Let's find the pair of bytes that occur most commonly and then replace them

In [30]:
print(list(tokens))

[224, 164, 168, 224, 164, 190, 224, 164, 174, 32, 224, 164, 156, 224, 165, 128, 224, 164, 185, 224, 164, 129, 32, 224, 164, 156, 224, 164, 170, 224, 164, 191, 32, 224, 164, 156, 224, 164, 190, 224, 164, 151, 224, 164, 185, 224, 164, 191, 224, 164, 130, 32, 224, 164, 156, 224, 165, 139, 224, 164, 151, 224, 165, 128, 224, 165, 164, 32, 224, 164, 172, 224, 164, 191, 224, 164, 176, 224, 164, 164, 224, 164, 191, 32, 224, 164, 172, 224, 164, 191, 224, 164, 176, 224, 164, 130, 224, 164, 154, 224, 164, 191, 32, 224, 164, 170, 224, 165, 141, 224, 164, 176, 224, 164, 170, 224, 164, 130, 224, 164, 154, 32, 224, 164, 172, 224, 164, 191, 224, 164, 175, 224, 165, 139, 224, 164, 151, 224, 165, 128, 224, 165, 165]


In [12]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]): # Pythonic way to iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

stats = get_stats(tokens)
print(stats)
print(sorted(((v,k) for k,v in stats.items()), reverse=True))

{(224, 164): 37, (164, 168): 1, (168, 224): 1, (164, 190): 2, (190, 224): 2, (164, 174): 1, (174, 32): 1, (32, 224): 8, (164, 156): 4, (156, 224): 4, (224, 165): 8, (165, 128): 3, (128, 224): 3, (164, 185): 2, (185, 224): 2, (164, 129): 1, (129, 32): 1, (164, 170): 3, (170, 224): 3, (164, 191): 7, (191, 32): 3, (164, 151): 3, (151, 224): 3, (191, 224): 4, (164, 130): 3, (130, 32): 1, (165, 139): 2, (139, 224): 2, (165, 164): 1, (164, 32): 1, (164, 172): 3, (172, 224): 3, (164, 176): 3, (176, 224): 3, (164, 164): 1, (164, 224): 1, (130, 224): 2, (164, 154): 2, (154, 224): 1, (165, 141): 1, (141, 224): 1, (154, 32): 1, (164, 175): 1, (175, 224): 1, (165, 165): 1}
[(37, (224, 164)), (8, (224, 165)), (8, (32, 224)), (7, (164, 191)), (4, (191, 224)), (4, (164, 156)), (4, (156, 224)), (3, (191, 32)), (3, (176, 224)), (3, (172, 224)), (3, (170, 224)), (3, (165, 128)), (3, (164, 176)), (3, (164, 172)), (3, (164, 170)), (3, (164, 151)), (3, (164, 130)), (3, (151, 224)), (3, (128, 224)), (2, (19

In [13]:
top_pair = max(stats, key=stats.get)
top_pair

(224, 164)

In [14]:
from IPython.display import HTML, display

def set_css():
  display(HTML('''
  <style>
    pre {
        white-space: pre-wrap;
    }
  </style>
  '''))
get_ipython().events.register('pre_run_cell', set_css)

In [16]:
def merge(ids, pair, idx):
  # in the list of ints (ids), replace all consecutive occurences of pair with the new token idx
  newids = []
  i = 0
  while i < len(ids):
    # if we are not at the very last position AND the pair matches, replace it
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

print(merge([5, 6, 6, 7, 9, 1], (6, 7), 99))

tokens2 = merge(tokens, top_pair, 256)


print(tokens2)
print("length:", len(tokens2), len(tokens))

[5, 6, 99, 9, 1]
[256, 168, 256, 190, 256, 174, 32, 256, 156, 224, 165, 128, 256, 185, 256, 129, 32, 256, 156, 256, 170, 256, 191, 32, 256, 156, 256, 190, 256, 151, 256, 185, 256, 191, 256, 130, 32, 256, 156, 224, 165, 139, 256, 151, 224, 165, 128, 224, 165, 164, 32, 256, 172, 256, 191, 256, 176, 256, 164, 256, 191, 32, 256, 172, 256, 191, 256, 176, 256, 130, 256, 154, 256, 191, 32, 256, 170, 224, 165, 141, 256, 176, 256, 170, 256, 130, 256, 154, 32, 256, 172, 256, 191, 256, 175, 224, 165, 139, 256, 151, 224, 165, 128, 224, 165, 165]
length: 106 143


In [17]:
tokens = ramayana_text.encode("utf-8") # raw bytes
tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 for convenience

In [18]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
  newids = []
  i = 0
  while i < len(ids):
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

# ---
vocab_size = 276 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(ids)
  pair = max(stats, key=stats.get)
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

merging (224, 164) into a new token 256
merging (32, 256) into a new token 257
merging (224, 165) into a new token 258
merging (256, 190) into a new token 259
merging (256, 191) into a new token 260
merging (256, 176) into a new token 261
merging (258, 129) into a new token 262
merging (259, 256) into a new token 263
merging (256, 168) into a new token 264
merging (256, 185) into a new token 265
merging (260, 257) into a new token 266
merging (260, 256) into a new token 267
merging (164, 257) into a new token 268
merging (258, 165) into a new token 269
merging (258, 128) into a new token 270
merging (258, 135) into a new token 271
merging (10, 256) into a new token 272
merging (257, 184) into a new token 273
merging (258, 141) into a new token 274
merging (258, 139) into a new token 275


In [19]:
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

tokens length: 516373
ids length: 222256
compression ratio: 2.32X


In [20]:
def get_stats(ids):
    counts = {}
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
  newids = []
  i = 0
  while i < len(ids):
    if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
      newids.append(idx)
      i += 2
    else:
      newids.append(ids[i])
      i += 1
  return newids

# ---
vocab_size = 1000 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int
for i in range(num_merges):
  stats = get_stats(ids)
  pair = max(stats, key=stats.get)
  idx = 256 + i
  # print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

tokens length: 516373
ids length: 87110
compression ratio: 5.93X


### Decoding



In [32]:
vocab = {idx: bytes([idx]) for idx in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
  # given ids (list of integers), return Python string
  tokens = b"".join(vocab[idx] for idx in ids)
  text = tokens.decode("utf-8", errors="replace")
  return text

print(decode([349]))

हु


### Encoding

The other way around: Given a string, what are the tokens?


In [23]:
merges

{(224, 164): 256,
 (32, 256): 257,
 (224, 165): 258,
 (256, 190): 259,
 (256, 191): 260,
 (256, 176): 261,
 (258, 129): 262,
 (259, 256): 263,
 (256, 168): 264,
 (256, 185): 265,
 (260, 257): 266,
 (260, 256): 267,
 (164, 257): 268,
 (258, 165): 269,
 (258, 128): 270,
 (258, 135): 271,
 (10, 256): 272,
 (257, 184): 273,
 (258, 141): 274,
 (258, 139): 275,
 (262, 256): 276,
 (258, 268): 277,
 (269, 272): 278,
 (257, 172): 279,
 (256, 130): 280,
 (257, 149): 281,
 (256, 164): 282,
 (256, 184): 283,
 (256, 178): 284,
 (257, 174): 285,
 (256, 149): 286,
 (257, 170): 287,
 (256, 174): 288,
 (263, 168): 289,
 (257, 168): 290,
 (257, 156): 291,
 (274, 261): 292,
 (259, 261): 293,
 (256, 172): 294,
 (256, 175): 295,
 (256, 170): 296,
 (256, 166): 297,
 (256, 151): 298,
 (256, 181): 299,
 (262, 264): 300,
 (258, 130): 301,
 (265, 267): 302,
 (257, 133): 303,
 (265, 266): 304,
 (259, 277): 305,
 (275, 256): 306,
 (256, 268): 307,
 (302, 130): 308,
 (257, 164): 309,
 (257, 166): 310,
 (257, 173):

In [33]:
def encode(text):
  # given a string, return list of integers (the tokens)
  tokens = list(text.encode("utf-8"))
  while len(tokens) >= 2:
    stats = get_stats(tokens)
    pair = min(stats, key=lambda p: merges.get(p, float("inf")))
    if pair not in merges:
      break # nothing else can be merged
    idx = merges[pair]
    tokens = merge(tokens, pair, idx)
  return tokens

print(encode("हु"))

[349]


In [26]:
print(decode(encode("नाम जीहँ जपि जागहिं जोगी।")))

नाम जीहँ जपि जागहिं जोगी।
