<a href="https://colab.research.google.com/github/ak2742/mlplay/blob/PyTorch-Models/tokenizer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Mount Google Drive into this Colab runtime; the corpus loaded in the next
# cell is read from a path under this mount point.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
#@title load data

# Corpus = a short non-ASCII seed string (Devanagari + emoji), presumably to
# exercise multi-byte UTF-8, followed by the Shakespeare text from Drive.
# NOTE: no separator is inserted, so the seed runs straight into the file's
# first line.
text = "नमस्ते 💖 World!!"
file_path = "/content/drive/MyDrive/Colab Notebooks/shakespeare.txt"
with open(file_path, "r", encoding="utf-8") as f:
    text = text + f.read()

# Character-level view of the corpus (distinct code points, sorted).
chars = sorted(set(text))
vocab_size = len(chars)

# Other views of the same text: [ord(c) for c in text] gives Unicode code
# points; text.encode("utf-16") / ("utf-32") give alternative byte encodings.

In [None]:
#@title get utf-8 codes for text

# Byte-level starting point for BPE: iterating a bytes object already yields
# ints in 0..255, so list() is enough to get the token ids.
tokens = list(text.encode("utf-8"))
print(len(text))    # length in Unicode code points
print(len(tokens))  # length in UTF-8 bytes (larger when text has non-ASCII)

In [None]:
#@title fn to get the n of occurences of a pair in tokens

def get_stats(ids):
    """Count how often each adjacent pair of ids occurs.

    Args:
        ids: an indexable sequence of token ids (e.g. a list of ints or a
            ``bytes`` object).

    Returns:
        dict mapping ``(left, right)`` tuples to their occurrence count. Keys
        appear in order of first occurrence, so ``max(..., key=...)`` breaks
        ties in favour of the earliest pair.
    """
    counts = {}
    # Pair every element with its successor; sequences shorter than 2 yield
    # no pairs and thus an empty dict.
    for pair in zip(ids, ids[1:]):
        counts[pair] = counts.get(pair, 0) + 1
    return counts

# Example: find the most frequent pair
#   stats = get_stats(tokens)
#   top_pair = max(stats, key=stats.get)

In [None]:
#@title fn to merge a pair into a new token

def merge_pairs(ids, pair, idx):
    """Replace every occurrence of ``pair`` in ``ids`` with the single id ``idx``.

    Occurrences are matched greedily left to right and do not overlap, e.g.
    merging (5, 5) in [5, 5, 5] gives [idx, 5].

    Args:
        ids: sequence of token ids to scan.
        pair: the two consecutive ids to look for.
        idx: id that replaces each matched pair.

    Returns:
        A new list; ``ids`` itself is not modified.
    """
    merged = []
    pos = 0
    end = len(ids)
    while pos < end:
        is_match = (
            pos + 1 < end
            and ids[pos] == pair[0]
            and ids[pos + 1] == pair[1]
        )
        if is_match:
            merged.append(idx)
            pos += 2  # consume both members of the pair
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

In [None]:
#@title merge tokens upto vocab size

# Target vocabulary after training: 256 raw byte values + learned merges.
# NOTE: this rebinds `vocab_size`, replacing the character-count value set
# in the "load data" cell.
vocab_size = 280  # max unique tokens after merge
num_merges = vocab_size - 256
ids = list(tokens)

merges = {}  # {pair -> new_token}, in the order the merges were learned

for i in range(num_merges):
    stats = get_stats(ids)
    if not stats:
        # Fewer than two tokens remain, so there is nothing left to merge;
        # max() over an empty dict would raise ValueError.
        break
    # Greedily pick the most frequent adjacent pair (ties -> first seen).
    top_pair = max(stats, key=stats.get)
    idx = 256 + i  # new ids are allocated sequentially after the byte range
    ids = merge_pairs(ids, top_pair, idx)
    merges[top_pair] = idx

merges

In [None]:
# Compare sequence length before and after applying the learned merges.
n_before, n_after = len(tokens), len(ids)
print(n_before)  # byte-level tokens before merges
print(n_after)   # tokens after merges
print(f"compression ratio {n_before / n_after:.2f}x")


In [None]:
#@title encoder-decoder

# Token id -> the bytes it stands for. Ids 0..255 are single raw bytes; each
# merged id expands to the concatenation of its two children. `merges` is
# iterated in the order the merges were learned, so both children are
# already present when a merged id is added.
vocab = {byte: bytes([byte]) for byte in range(256)}
for (p0, p1), idx in merges.items():
    vocab[idx] = vocab[p0] + vocab[p1]


def decode(ids):
    """Turn a list of token ids back into a string.

    Unknown ids raise KeyError. Byte sequences that are not valid UTF-8
    (e.g. a multi-byte character cut in half) become U+FFFD instead of
    raising.
    """
    raw = b"".join(map(vocab.__getitem__, ids))
    return raw.decode("utf-8", errors="replace")


def encode(text):
    """Turn a string into token ids using the learned ``merges``.

    Starts from the raw UTF-8 bytes. On each pass it merges the adjacent
    pair with the smallest merge id, i.e. the merge learned earliest. It
    stops when no adjacent pair has a learned merge, or when fewer than two
    ids remain.
    """
    out = list(text.encode("utf-8"))
    while len(out) > 1:
        pair_counts = get_stats(out)
        # Pairs without a learned merge rank as +inf, so they are chosen
        # only when no mergeable pair exists.
        best = min(pair_counts, key=lambda p: merges.get(p, float("inf")))
        if best not in merges:
            break
        out = merge_pairs(out, best, merges[best])
    return out


In [None]:
#@title gpt regex to split text

# Third-party `regex` module: needed for the \p{L} / \p{N} Unicode property
# classes used below.
import regex as re

# GPT-style pre-tokenization pattern; it appears to match the GPT-4 split
# pattern popularised by minbpe (verify before relying on exact parity).
# Alternatives, tried left to right:
#   - English contractions ('s, 'd, 'm, 't, 'll, 've, 're; case-insensitive)
#   - a run of letters, optionally preceded by one non-letter/non-digit
#     character (other than CR/LF)
#   - 1 to 3 digits
#   - punctuation/symbol runs (optional leading space, trailing newlines)
#   - newline and whitespace runs
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
reg = re.compile(GPT4_SPLIT_PATTERN)

# Preview how the first 100 characters of the corpus get chunked.
preview_chunks = reg.findall(text[:100])
print(preview_chunks)

In [None]:
#@title test

# Round-trip check: decode(encode(x)) should reproduce x exactly.
txt = text
encoded_txt = encode(txt)
decoded_txt = decode(encoded_txt)

print(decoded_txt[:100])  # eyeball the start of the reconstruction
roundtrip_ok = decoded_txt == txt
print(roundtrip_ok)