<a href="https://colab.research.google.com/github/QasimWani/simple-transformer/blob/main/transformers/tokenization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Deep dive into the mystical world of tokenization. This is fundamental to making LLMs work

# Topics:
# 1. Character tokenization
# 2. Byte tokenization
# 3. Byte-pair encoding - BPE
# 4. WordPiece tokenization

In [None]:
import torch
import numpy as np
from abc import ABC, abstractmethod
from collections import defaultdict, Counter

In [None]:
# Sample text that the tokenizers in this notebook are run on.
# It contains non-ASCII characters (curly quotes, em/en dashes and an emoji),
# so its UTF-8 byte length is larger than its character count.
data = """
“If you’re going to try, go all the way.

Otherwise, don’t even start.

This could mean losing girlfriends, wives, relatives and maybe even your mind.

It could mean not eating for three or four days.

It could mean freezing on a park bench.

It could mean jail.

It could mean derision.

It could mean mockery — isolation.

Isolation is a gift.

All the others are a test of your endurance, of how much you really want to do it.

And, you’ll do it, despite rejection and the worst odds.

And it will be better than anything else you can imagine.

If you’re going to try, go all the way.

There is no other feeling like that.

You will be alone with the gods, and the nights will flame with fire.

You will ride life straight to perfect laughter.

It’s the only good fight there is.”

- Charles Bukowski (1920 – 1994) 🪦
"""

In [None]:
class BaseTokenizer(ABC):
  '''
  Common interface for tokenizers that turn text into a list of tokens and back.

  Subclasses implement `encode` and `decode`; the round-trip check and the
  compression metric below are built on top of those two methods.
  '''

  @abstractmethod
  def encode(self, text: str):
    '''Convert `text` into a list of tokens.'''

  @abstractmethod
  def decode(self, tokens: list):
    '''Convert a list of tokens back into text.'''

  def is_lossless(self, text):
    '''
    Check that decoding the encoding of `text` reproduces `text` exactly.

    Returns True on success.
    Raises AssertionError (with both strings in the message) when the round
    trip alters the text. The error is raised explicitly rather than through
    an `assert` statement so the check still runs under `python -O`.
    '''
    decoded_text = self.decode(self.encode(text))
    if decoded_text != text:
      raise AssertionError(f"{text} != {decoded_text}")
    return True

  def compression_factor(self, text):
    '''
    Return len(text) / number of tokens, i.e. characters per token.

    The round trip is verified first, so a lossy tokenizer raises
    AssertionError instead of reporting a meaningless ratio.
    Raises ZeroDivisionError when `text` encodes to zero tokens (for example
    an empty string), because the ratio is undefined in that case.
    '''
    self.is_lossless(text)
    num_tokens = len(self.encode(text))
    if num_tokens == 0:
      raise ZeroDivisionError(
          "compression factor is undefined: the text encoded to zero tokens")
    return len(text) / num_tokens # ratio of uncompressed to compressed representation


In [None]:
class CharTokenizer(BaseTokenizer):
  '''Tokenizer whose tokens are the Unicode code points of the characters.'''

  def encode(self, text):
    '''
    Map every character of `text` to its Unicode code point.

    Simple, but a poor default tokenizer for two reasons:
    1) The vocabulary is too large - up to roughly 1M entries.
    2) The vocabulary is not frozen. Unicode keeps adding new characters,
       whereas a tokenizer should have a fixed vocabulary.
    '''
    return list(map(ord, text))

  def decode(self, tokens: list):
    '''Turn each code point back into its character and concatenate them.'''
    return ''.join(map(chr, tokens))


In [None]:
class ByteTokenizer(BaseTokenizer):
  '''Tokenizer whose tokens are the UTF-8 bytes of the text.'''

  def encode(self, text: str):
    '''
    Return the UTF-8 bytes of `text` as a list of ints in range(0, 256).

    Upside: the vocab_size is bounded to 256.
    Downside: sequences get very long. A single character may need
    several bytes, so `ByteTokenizer.compression_factor` can never
    exceed 1.0.

    A better alternative to this is using BPE.
    '''
    utf8_bytes = text.encode('utf-8')
    return [*utf8_bytes]

  def decode(self, tokens):
    '''Reassemble the byte values and decode them as UTF-8 text.'''
    raw = bytes(tokens)
    return str(raw, 'utf-8')


In [None]:
class BPE(BaseTokenizer):
  '''
  Byte-pair encoding (BPE) tokenizer operating on UTF-8 bytes.

  Token ids 0-255 are the raw byte values; each learned merge gets the next
  id, starting at 256. Training is lazy: as long as `merge_mapping` is empty,
  `encode` learns the merges from the text it is given.
  '''

  def __init__(self, vocab_size):
    '''
    vocab_size: maximum number of token ids (256 byte values + learned merges).
    Raises AssertionError when vocab_size is smaller than 256.
    '''
    assert vocab_size >= 256, "BPE requires the vocab size to be at least 256"
    # pair -> merged token id; insertion order is the order the merges were learned
    self.merge_mapping = {}
    self.vocab_size = vocab_size

  def most_frequent_pair(self, tokens):
    '''
    Return the adjacent token pair that occurs most often in `tokens`.

    Returns None when no pair occurs at least twice, which includes inputs
    with fewer than two tokens. Ties go to the pair that was seen first.
    '''
    freq = Counter(zip(tokens, tokens[1:])) # consecutive indices search
    if not freq:
      # Fewer than two tokens: there are no pairs, and max() on an empty
      # mapping would raise ValueError.
      return None

    max_pair = max(freq, key=freq.get) # gets the highest value
    return max_pair if freq[max_pair] > 1 else None

  def merge(self, tokens, pair, token_id):
    '''
    Return a copy of `tokens` where every non-overlapping occurrence of
    `pair` (scanning left to right) is replaced by `token_id`.
    '''
    left, right = pair
    new_tokens = []
    i = 0
    num_tokens = len(tokens)
    while i < num_tokens:
      if i < num_tokens - 1 and tokens[i] == left and tokens[i + 1] == right:
        new_tokens.append(token_id)
        i += 2 # skip both members of the pair
      else:
        new_tokens.append(tokens[i])
        i += 1
    return new_tokens

  def join(self, tokens, pair, token_id):
    '''
    Opposite of `merge`: return a copy of `tokens` where every `token_id`
    is expanded back into the two tokens of `pair`.
    '''
    new_tokens = []
    for token in tokens:
      if token == token_id:
        new_tokens.extend(pair)
      else:
        new_tokens.append(token)
    return new_tokens

  def encode_once(self, text: str):
    '''
    Learn the merges from `text`, store them in `self.merge_mapping` and
    return `text` encoded with them.

    Proposed by Philip Gage in 1994 for data compression.
    The idea is as follows (https://en.wikipedia.org/wiki/Byte-pair_encoding):
    Suppose you have the string: aaabdaaabac

    We first identify the most frequent byte-pair and replace it with a
    byte code not represented in the string.

    ZabdZabac ; Z = aa

    Repeat the process with the most frequent byte-pair: ab

    ZYdZYac ; Y = ab, Z = aa

    Because the remaining pairs only occur once, we can either end here OR
    go further to optimize it even more:

    XdXac ; X = ZY, Y = ab, Z = aa

    Decoding is simple, just move in reverse order.

    1. A disadvantage of BPE is that there's no sense of true word/byte tokenization.
    Because the most frequent pairings are tokenized, we could have weird tokens like x representing facti (assuming the word is faction).
    Additionally there's no notion of capturing semantics between tokens. 'dog' and 'dog.' may have very different tokens.
    To solve for this, GPT-2 implemented a regex expression: https://github.com/openai/gpt-2/blob/master/src/encoder.py#L53

    This processes the string by splitting the punctuation into a separate item in the list.
    So, if you have an input string: "I love dogs. They're 2 nice" --> [I, love, dogs, ., They, 're, 2, nice]

    2. Data distribution MATTERS a lot. Frequent words have very high compression ratios because they occur more frequently.
    But languages that don't occur frequently in the train set have a much lower compression ratio and are usually quite fragmented.
    '''

    # We have a base vocab size of 256. That just comes from the fact that
    # every character is represented by a series of bytes and there are 256
    # possible byte values, i.e. [0, 255].
    # So the very first thing we need to do is calculate the number of merges
    # we are allowed to make.
    num_merges = self.vocab_size - 256

    tokens = list(text.encode('utf-8'))

    merge_mapping = {} # pair as key, new_token_id as value

    for step in range(num_merges):
      # Step 1. get the most frequent byte pair
      pair = self.most_frequent_pair(tokens)
      if pair is None: # no pair occurs more than once, so we can stop here.
        print(f'Merging terminated early at step: {step}')
        break
      # Step 2. Merge
      new_token_id = 256 + step
      tokens = self.merge(tokens, pair, new_token_id)
      merge_mapping[pair] = new_token_id

    self.merge_mapping = merge_mapping
    # Count the merges actually made. The loop index is one short whenever
    # the loop runs to completion without hitting the early break.
    print(f'Final vocab size: {256 + len(merge_mapping)}')
    return tokens

  def encode(self, text):
    '''
    Encode `text` into token ids.

    When no merges have been learned yet this trains on `text` (see
    `encode_once`); otherwise the learned merges are replayed on the UTF-8
    bytes of `text` in the order they were learned.
    '''
    if not self.merge_mapping:
      return self.encode_once(text)

    tokens = list(text.encode('utf-8'))
    for pair, token_id in self.merge_mapping.items():
      tokens = self.merge(tokens, pair, token_id)
    return tokens


  def decode(self, tokens):
    '''
    Expand merged tokens back into bytes and decode them as UTF-8 text.

    With an empty merge table this is equivalent to plain byte decoding.
    '''
    new_tokens = list(tokens) # create a copy
    # Undo the merges newest-first: a later merge may be built from earlier
    # merged tokens, so it has to be expanded before them. Dicts keep
    # insertion order (Python 3.7+); materialising the items in a list keeps
    # reversed() working on 3.7, where dict views are not reversible.
    for pair, token_id in reversed(list(self.merge_mapping.items())):
      new_tokens = self.join(new_tokens, pair, token_id)

    return bytes(new_tokens).decode('utf-8')

In [None]:
from collections import Counter

class WordPiece(BaseTokenizer):
  '''
  WordPiece-style tokenizer operating on UTF-8 bytes.

  Works like BPE (ids 0-255 are raw bytes, learned merges start at id 256),
  but each step merges the pair with the highest likelihood score instead of
  the highest raw frequency. Training is lazy: as long as `merge_mapping` is
  empty, `encode` learns the merges from the text it is given.
  '''

  def __init__(self, vocab_size: int):
    '''
    vocab_size: maximum number of token ids (256 byte values + learned merges).
    Raises AssertionError when vocab_size is smaller than 256.
    '''
    assert vocab_size >= 256, "WordPiece requires the vocab size to be at least 256"
    # pair -> merged token id; insertion order is the order the merges were learned
    self.merge_mapping = {}
    self.vocab_size = vocab_size

  def compute_maximum_likelihood(self, tokens):
    '''
    Return the adjacent pair with the highest score
    count(x, y) / ((eps + count(x)) * (eps + count(y))), or None when no
    pair occurs at least twice.
    '''
    if len(tokens) <= 1:
      return None

    unigram = Counter(tokens)
    bigram = Counter(zip(tokens, tokens[1:]))

    # Max. likelihood can be deceptive in that rare tokens will have a likelihood score of 1 since freq(pair) = 1 and freq(token1) and freq(token2) = 1
    # To circumvent this, we just impose a simple filter where we only take bigrams that have a frequency of at least 2
    candidates = {pair: count for pair, count in bigram.items() if count > 1}
    if not candidates: # terminate early
      return None

    eps = 1 # Smoothing factor to prevent rare tokens from having too high of PMI (pointwise mutual information)

    def score(pair):
      left, right = pair
      return candidates[pair] / ((eps + unigram[left]) * (eps + unigram[right]))

    # Two-tier tiebreaker: if two pairs have the same likelihood, pick the one that has higher joint frequency
    return max(candidates, key=lambda pair: (score(pair), candidates[pair]))


  def merge(self, tokens, pair, new_token_id):
    '''
    Return a copy of `tokens` where every non-overlapping occurrence of
    `pair` (scanning left to right) is replaced by `new_token_id`.
    '''
    left, right = pair
    new_tokens = []
    num_tokens = len(tokens)

    i = 0
    while i < num_tokens:
      if i < num_tokens - 1 and tokens[i] == left and tokens[i + 1] == right:
        new_tokens.append(new_token_id)
        i += 2 # skip both members of the pair
      else:
        new_tokens.append(tokens[i])
        i += 1
    return new_tokens

  def encode_once(self, text):
    '''
    Learn the merges from `text`, store them in `self.merge_mapping` and
    return `text` encoded with them.

    The algorithm is identical to BPE, except that instead of picking the pair with the highest frequency
    we pick the pair with the highest likelihood score, i.e.
    score(x, y) = frequency(x, y) / (frequency(x) * frequency(y)), with add-one smoothing on the unigram counts.
    The score measures how much more often the pair occurs than it would if the two tokens were independent.
    '''

    tokens = list(text.encode('utf-8'))

    num_merges = self.vocab_size - 256

    merge_mapping = {} # pair, token_id
    for step in range(num_merges):
      # Step 1. find pair with highest likelihood
      pair = self.compute_maximum_likelihood(tokens)
      if pair is None:
        print(f"Terminating early with {step} merges")
        break
      # Step 2. merge
      new_token_id = 256 + step
      tokens = self.merge(tokens, pair, new_token_id)
      merge_mapping[pair] = new_token_id

    # Count the merges actually made. The loop index is one short whenever
    # the loop runs to completion without hitting the early break.
    print(f"Final vocab size: {256 + len(merge_mapping)}")
    self.merge_mapping = merge_mapping
    return tokens

  def encode(self, text):
    '''
    Encode `text` into token ids.

    When no merges have been learned yet this trains on `text` (see
    `encode_once`); otherwise the learned merges are replayed on the UTF-8
    bytes of `text` in the order they were learned.
    '''
    if not self.merge_mapping:
      return self.encode_once(text)

    tokens = list(text.encode('utf-8'))
    for pair, token_id in self.merge_mapping.items():
      tokens = self.merge(tokens, pair, token_id)
    return tokens


  def join(self, tokens, pair, token_id):
    '''
    Opposite of `merge`: return a copy of `tokens` where every `token_id`
    is expanded back into the two tokens of `pair`.
    '''
    new_tokens = []
    for token in tokens:
      if token == token_id:
        new_tokens.extend(pair)
      else:
        new_tokens.append(token)
    return new_tokens

  def decode(self, tokens):
    '''Expand merged tokens back into bytes and decode them as UTF-8 text.'''
    new_tokens = list(tokens)
    # Undo the merges newest-first: a later merge may be built from earlier
    # merged tokens. Materialising the items in a list keeps reversed()
    # working on Python 3.7, where dict views are not reversible.
    for pair, token_id in reversed(list(self.merge_mapping.items())):
      new_tokens = self.join(new_tokens, pair, token_id)
    return bytes(new_tokens).decode('utf-8')


In [None]:
# Build each tokenizer and measure its compression factor on the sample text.
# All four factors are computed before anything is reported, so messages
# printed while BPE / WordPiece learn their merges come out ahead of the report.
char_tokenizer = CharTokenizer()
char_compression_factor = char_tokenizer.compression_factor(data)

byte_tokenizer = ByteTokenizer()
byte_compression_factor = byte_tokenizer.compression_factor(data)

bpe_tokenizer = BPE(vocab_size=400)
bpe_compression_factor = bpe_tokenizer.compression_factor(data)

wordpiece_tokenizer = WordPiece(vocab_size=400)
wordpiece_compression_factor = wordpiece_tokenizer.compression_factor(data)

# One print call, one report line per argument.
print(
  "\nCompression factors - higher is better",
  f"Char compression factor: {round(char_compression_factor, 3)}",
  f"Byte compression factor: {round(byte_compression_factor, 3)}",
  f"BPE compression factor: {round(bpe_compression_factor, 3)}",
  f"WordPiece compression factor: {round(wordpiece_compression_factor, 3)}",
  sep="\n",
)

Merging terminated early at step: 111
Final vocab size: 367
Final vocab size: 399

Compression factors - higher is better
Char compression factor: 1.0
Byte compression factor: 0.975
BPE compression factor: 2.465
WordPiece compression factor: 2.495
