<a href="https://colab.research.google.com/github/SatishWG/BytePairEncoding/blob/main/BytePairEncoding.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import ipykernel
import sys
import os
import numpy as np
import pandas as pd
import datasets
import torch
import re
import tiktoken

In [None]:
# Load the Hindi Wikipedia corpus (zicsx/Wikipedia-Hindi) via Hugging Face `datasets`.
from datasets import load_dataset
dataset = load_dataset(path="zicsx/Wikipedia-Hindi")

In [None]:
# Display the loaded dataset object; later cells index it as dataset['train'][row]['text'].
dataset

In [None]:
# Disabled: would concatenate every article in the train split into one
# newline-separated string. Kept for reference; the cells below work on a
# slice of the dataset instead. (If re-enabled, "\n".join(...) avoids the
# repeated string concatenation done by this loop.)
# texts = ""
# for example in dataset['train']:
#     texts += example['text'] + "\n" # Add a newline to separate texts

# print(f"Total number of characters in 'texts': {len(texts)}")
# # print(texts[:500]) # Display the first 500 characters to verify

In [None]:
# Read a sample: the text of the first article in the train split.
# text0 is reused later as the test string for the encode/decode round trips.
text0 = dataset['train'][0]['text']
print(text0)

In [None]:
# Take the text of the first 100 articles.
# Slice the rows first (dataset["train"][:100] -> dict of column lists) and then
# pick the 'text' column. In datasets versions where column access returns a
# plain list, dataset["train"]["text"][:100] would first build the entire column
# for the whole dump only to keep 100 entries.
texts100 = dataset["train"][:100]["text"]

# Optional: check the first few entries
print(texts100[:5])

In [None]:
# Save the 100-article sample to disk in Hugging Face `datasets` format.
from datasets import Dataset
# Convert the list of texts to a Dataset.
# NOTE: despite its name, dataset_10k holds only the len(texts100) == 100 rows.
dataset_10k = Dataset.from_dict({"text": texts100})

# Save to disk in Hugging Face format
dataset_10k.save_to_disk("texts100_dataset")

# (The check mark was previously mis-encoded as "‚úÖ" — UTF-8 bytes read as Mac Roman.)
print("✅ Dataset saved to 'texts100_dataset/'")

In [None]:
# Build one corpus string from the sampled articles, separated by single spaces.
text = " ".join(article for article in texts100)

# Peek at the first 500 characters to check the join looks right.
print(text[:500])

In [None]:
# `text` is a str, so len() counts characters (Unicode code points), not words.
print("length of text in characters: ", len(text))

In [None]:
# Character-level vocabulary: every distinct character in `text`, sorted by code point.
chars = sorted(set(text))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)

In [None]:
# Character-level codec built from `chars`: each character <-> its index.
# (encode/decode are redefined later in the notebook as byte-level BPE functions.)
itos = dict(enumerate(chars))               # int -> character
stoi = {ch: i for i, ch in itos.items()}    # character -> int

def encode(s):
    """Encode a string as a list of character ids; raises KeyError for characters not in `chars`."""
    return [stoi[c] for c in s]

def decode(l):
    """Decode a list of character ids back into a string."""
    return ''.join([itos[i] for i in l])

In [None]:
# Sanity-check the character-level codec on the sample article: print its ids,
# then the decoded string, which should equal text0 (both outputs are long).
print(encode(text0))
print(decode(encode(text0)))

In [None]:
# Encode the entire text with the character-level `encode` defined above
# (that name is redefined later for byte-level BPE) and store the ids in a
# 1-D torch.long tensor.

data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)
# print(data[:1000]) # the first 1000 character ids — what the model would actually see

In [None]:
# `data` already holds the character-encoded corpus from the previous cell;
# re-encoding the whole text here would only repeat that work.
print(data.shape, data.dtype)
print(data)

Apply a Hindi-specific (Devanagari) regex to split the text into words

In [None]:
# Regex for Devanagari (Hindi) words. The pattern is built from adjacent raw
# strings; the \uXXXX escapes are interpreted by the `re` module itself.
# A match is one or more repetitions of:
#   1. a base character: U+0904-U+0939 (independent vowels and consonants),
#      U+0958-U+0961 (nukta consonants, vocalic RR/LL) or U+0966-U+096F (digits);
#   2. zero or more further consonants, each optionally preceded by a nukta
#      (U+093C) or virama (U+094D), so conjuncts such as क्ष stay in one match;
#   3. zero or more combining marks: U+0900-U+0903 (candrabindu, anusvara,
#      visarga), U+093A-U+094F (vowel signs/matras, nukta, virama, ...),
#      U+0951-U+0957 and U+0962-U+0963.
# Any other character (Latin letters, whitespace, punctuation such as the danda
# U+0964, ZWJ/ZWNJ) ends the match.

devanagari_word = re.compile(
    r'(?:[\u0904-\u0939\u0958-\u0961\u0966-\u096F]'
    r'(?:[\u093C\u094D]?[\u0904-\u0939\u0958-\u0961])*'
    r'[\u0900-\u0903\u093A-\u094F\u0951-\u0957\u0962-\u0963]*'
    r')+'
)

# All groups are non-capturing, so findall() returns the full matched words
# (a list of str) in order of appearance.
dev_text = devanagari_word.findall(text)

# print("Devanagari words:", dev_text)
print("Number of words:", len(dev_text))

In [None]:
# Re-join the extracted Devanagari words with single spaces, then view the
# result as raw UTF-8 bytes (iterating a bytes object yields ints in 0..255).
encdev_text = ' '.join(dev_text)
dev_tokens = list(encdev_text.encode("utf-8"))

print('---')
print("length of text:", len(encdev_text))
print('---')
print("length of tokens:", len(dev_tokens))

In [None]:
# Optional English test string from https://www.reedbeta.com/blog/programmers-intro-to-unicode/
# (uncomment to tokenize it instead of the Hindi text):
# text = "Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to “support Unicode” in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don’t blame programmers for still finding the whole thing mysterious, even 30 years after Unicode’s inception."

# Raw UTF-8 bytes of the current (not regex-filtered) `text`, as ints in 0..255.
tokens = list(text.encode("utf-8"))

print('---')
print("length of text:", len(text))
print('---')
print("length of tokens:", len(tokens))

Let's find the pair of adjacent bytes that occurs most often, then replace every occurrence of it with a single new token.

In [None]:
def get_stats(ids):
    """Count adjacent pairs in a token sequence.

    Returns a dict mapping each consecutive pair (ids[k], ids[k + 1]) to the
    number of times it occurs. Keys appear in order of first occurrence, so
    max(stats, key=stats.get) breaks ties toward the earliest pair.
    Sequences shorter than two elements give an empty dict.
    """
    pair_counts = {}
    for position in range(len(ids) - 1):
        pair = (ids[position], ids[position + 1])
        if pair in pair_counts:
            pair_counts[pair] += 1
        else:
            pair_counts[pair] = 1
    return pair_counts

# Pair frequencies over the regex-filtered Hindi bytes
# (use get_stats(tokens) instead to count over the raw, unfiltered text).
stats = get_stats(dev_tokens)
print(stats)
# (count, pair) tuples, most frequent first
print(sorted([(count, pair) for pair, count in stats.items()], reverse=True))

In [None]:
# Most frequent adjacent pair; on ties max() keeps the first pair in dict order.
top_pair = max(stats.items(), key=lambda item: item[1])[0]
top_pair

In [None]:
# Show the most frequent pair as raw bytes. chr(224), chr(164) would read each
# UTF-8 *byte* as a Unicode code point ('à', '¤'), which says nothing about the
# Hindi text, and hard-codes the values instead of using `top_pair`.
# (0xE0 0xA4 is the UTF-8 prefix shared by U+0900–U+093F, which includes the
# Devanagari consonants; presumably why (224, 164) was the pair observed here.)
top_pair, bytes(top_pair)

In [None]:
def merge(ids, pair, idx):
    """Replace every occurrence of `pair` in `ids` with the single token `idx`.

    Scans left to right without overlapping matches, so [1, 1, 1] with
    pair (1, 1) becomes [idx, 1]. Returns a new list; `ids` is not modified.
    """
    merged = []
    pos = 0
    end = len(ids)
    while pos < end:
        is_match = pos + 1 < end and ids[pos] == pair[0] and ids[pos + 1] == pair[1]
        if is_match:
            merged.append(idx)
            pos += 2
            continue
        merged.append(ids[pos])
        pos += 1
    return merged

# Apply one merge: replace every occurrence of the top pair in the regex-filtered
# bytes with the first new token id, 256 (ids 0..255 are the raw bytes).
# tokens2 = merge(tokens, top_pair, 256)
tokens2 = merge(dev_tokens, top_pair, 256)


# print(tokens2)
# Sequence length after vs. before this single merge.
print("length:", len(tokens2), len(dev_tokens))

In [None]:
# Label the baselines. tokens2 was produced from dev_tokens (the regex-filtered
# text), so that is the fair "before" length. `tokens` holds the raw bytes of the
# unfiltered text, so comparing against it also counts the bytes removed by the
# regex, not just the effect of the merge.
print("after 1 merge:", len(tokens2),
      "| dev_tokens (before merge):", len(dev_tokens),
      "| raw text bytes:", len(tokens))

length: 2689593 3510382

Train BPE on a larger corpus (first 1,000 articles, Devanagari words only)



In [None]:
# Make the training text longer to get more representative pair statistics:
# take the first 1000 articles and keep only their Devanagari words.
# Rows are sliced before the 'text' column is selected, so only these 1000 texts
# are extracted. In datasets versions where column access returns a list, the
# old dataset["train"]["text"][:1000] first built the whole column.
text = " ".join(dataset["train"][:1000]["text"])
print("length of text before regex",len(text))
text = " ".join(devanagari_word.findall(text))
print("length of text after regex",len(text))
tokens = list(text.encode("utf-8"))  # raw UTF-8 bytes as ints in 0..255

In [None]:
# Size of the new training corpus: characters vs. UTF-8 bytes.
# tokens2 is left out: it was built from the earlier 100-article sample, so its
# length is not comparable with these numbers.
len(text), len(tokens)

In [None]:
def get_stats(ids):
    """Return {(a, b): count} for every adjacent pair in `ids`, in first-seen order.

    Same behavior as the get_stats defined earlier in the notebook; it is
    redefined here so this training cell can run on its own.
    """
    tally = {}
    for left, right in zip(ids[:-1], ids[1:]):
        tally.setdefault((left, right), 0)
        tally[(left, right)] += 1
    return tally

def merge(ids, pair, idx):
    """Return a new list where each non-overlapping `pair` in `ids` becomes `idx`.

    Matches are taken greedily from the left; `ids` itself is left untouched.
    """
    out = []
    skip_next = False
    last = len(ids) - 1
    for pos, token in enumerate(ids):
        if skip_next:
            # this element was consumed as the second half of a merged pair
            skip_next = False
            continue
        if pos < last and token == pair[0] and ids[pos + 1] == pair[1]:
            out.append(idx)
            skip_next = True
        else:
            out.append(token)
    return out

# --- Train a small BPE vocabulary: 256 byte tokens + (vocab_size - 256) merges ---
# Note: this reuses the name vocab_size (earlier it held the character-vocab size).
vocab_size = 276 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int, in the order the merges were learned
for i in range(num_merges):
  stats = get_stats(ids)
  if not stats:
    # Fewer than two tokens remain, so there is nothing left to merge
    # (max() on an empty dict would raise ValueError).
    break
  pair = max(stats, key=stats.get)  # most frequent pair; ties go to the first one seen
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

In [None]:
# Effect of the merges learned above: byte-sequence length before vs. after,
# and their ratio (higher means more compression).
print("tokens length:", len(tokens))
print("ids length:", len(ids))
print(f"compression ratio: {len(tokens) / len(ids):.2f}X")

In [None]:
def get_stats(ids):
    """Tally adjacent-pair frequencies of `ids` as {(prev, cur): count}.

    Keys are inserted in order of first occurrence; fewer than two ids yields {}.
    """
    counts = {}
    previous = None
    for position, current in enumerate(ids):
        if position > 0:
            key = (previous, current)
            counts[key] = counts.get(key, 0) + 1
        previous = current
    return counts

def merge(ids, pair, idx):
    """Collapse each left-to-right, non-overlapping occurrence of `pair` into `idx`.

    Returns a fresh list and does not modify `ids`.
    """
    result = []
    cursor = 0
    limit = len(ids) - 1  # index of the last element
    while cursor <= limit:
        if cursor < limit and ids[cursor] == pair[0] and ids[cursor + 1] == pair[1]:
            result.append(idx)
            cursor += 2
        else:
            result.append(ids[cursor])
            cursor += 1
    return result

# --- Train the full BPE vocabulary: 256 byte tokens + (vocab_size - 256) merges ---
# Cost note: every merge rescans the whole sequence twice (get_stats + merge),
# so this loop is O(num_merges * len(ids)) in pure Python and can run for a very
# long time on the 1000-article corpus; lower vocab_size for quick experiments.
vocab_size = 5000 # the desired final vocabulary size
num_merges = vocab_size - 256
ids = list(tokens) # copy so we don't destroy the original list

merges = {} # (int, int) -> int, in the order the merges were learned
for i in range(num_merges):
  stats = get_stats(ids)
  if not stats:
    # Fewer than two tokens remain, so there is nothing left to merge
    # (max() on an empty dict would raise ValueError).
    break
  pair = max(stats, key=stats.get)  # most frequent pair; ties go to the first one seen
  idx = 256 + i
  print(f"merging {pair} into a new token {idx}")
  ids = merge(ids, pair, idx)
  merges[pair] = idx

print("tokens length:", len(tokens))
print("ids length:", len(ids))
# Guard the ratio against an empty corpus (len(ids) == 0 -> ZeroDivisionError).
if ids:
  print(f"compression ratio: {len(tokens) / len(ids):.2f}X")
else:
  print("compression ratio: n/a (no tokens)")

### Decoding

Given a sequence of integers in the range [0, vocab_size), i.e. 0 to vocab_size − 1, what is the text?


In [None]:
# Base vocabulary: token ids 0..255 stand for the single bytes with those values.
vocab = dict((idx, bytes([idx])) for idx in range(256))
# Each merged id expands to the concatenation of its two parents' bytes. merges
# was filled in training order, so both parents are already in vocab by the
# time their child id is added.
for (p0, p1), idx in merges.items():
    vocab[idx] = b"".join((vocab[p0], vocab[p1]))

def decode(ids):
    """Turn BPE token ids back into a Python string.

    Each id is expanded to its bytes through the global `vocab`, the bytes are
    concatenated, and the result is decoded as UTF-8. Invalid byte sequences
    become U+FFFD (errors="replace") instead of raising.
    """
    raw = bytearray()
    for token_id in ids:
        raw += vocab[token_id]
    return bytes(raw).decode("utf-8", errors="replace")

print(decode([128, 255, 233]))

### Encoding

The other way around: Given a string, what are the tokens?


In [None]:
# Number of learned merge rules (vocab_size - 256 when training ran to completion).
len(merges)

In [None]:
def encode(text):
    """Encode a string into BPE token ids.

    Starts from the UTF-8 bytes of `text` and repeatedly applies the merge with
    the lowest id in `merges` (i.e. the earliest-learned one) among the pairs
    currently present, stopping when no present pair has a learned merge.
    """
    ids = list(text.encode("utf-8"))
    while len(ids) > 1:
        candidates = get_stats(ids)
        best = min(candidates, key=lambda p: merges.get(p, float("inf")))
        if best not in merges:
            break  # no remaining pair was ever merged during training
        ids = merge(ids, best, merges[best])
    return ids

# Edge case: the empty string encodes to an empty list of tokens.
print(encode(""))

In [None]:
# BPE round trip on the sample article; the printed text should equal text0.
print(decode(encode(text0)))

In [None]:
# Round-trip check: decode(encode(s)) must reproduce s exactly. The assert stops
# Restart-&-Run-All on a failure instead of printing an easily missed False.
text2 = decode(encode(text0))
print(text2 == text0)
assert text2 == text0, "BPE round trip did not reproduce text0"