In [3]:
from collections import Counter

class BytePairTokenizer:
    """Character-level byte-pair-encoding (BPE) tokenizer.

    ``fit`` learns an ordered list of merge rules (pairs of adjacent tokens)
    from a corpus of strings; ``tokenize`` replays those rules, in the order
    they were learnt, on new documents.

    Attributes:
        target_vocab_size: Vocabulary size (distinct characters in the
            training words plus learnt merges) at which ``fit`` stops.
        merges: Learnt merge rules as 2-tuples of token strings, in the
            order they were learnt.
    """

    def __init__(self, target_vocab_size):
        self.target_vocab_size = target_vocab_size
        self.merges = []

    def _compute_word_frequencies(self, corpus):
        """Count whitespace-separated words across all documents in *corpus*."""
        return Counter(" ".join(corpus).split())

    def _initialize_tokens(self, word_frequency):
        """Map every word to its initial tokenization: a list of characters."""
        return {word: list(word) for word in word_frequency}

    def _compute_pair_frequencies(self, tokens, word_frequency):
        """Count adjacent token pairs, weighting each word by its frequency.

        Args:
            tokens: Mapping of word -> current list of tokens for that word.
            word_frequency: Mapping of word -> number of occurrences.

        Returns:
            A ``Counter`` keyed by ``(left, right)`` token tuples.
        """
        pair_freq = Counter()
        for word, freq in word_frequency.items():
            symbols = tokens[word]
            for pair in zip(symbols, symbols[1:]):
                pair_freq[pair] += freq
        return pair_freq

    def _find_best_pair(self, pair_freq):
        """Return the most frequent pair, or ``None`` if there are no pairs."""
        if pair_freq:
            return pair_freq.most_common(1)[0][0]
        return None

    @staticmethod
    def _apply_merge(symbols, pair, merged):
        """Return a copy of *symbols* with each *pair* occurrence merged.

        The scan is a single left-to-right pass over non-overlapping
        occurrences: once two tokens are merged, the scan resumes after them,
        so a freshly merged token is never re-examined in the same pass.
        """
        left, right = pair
        result = []
        last = len(symbols) - 1
        i = 0
        while i <= last:
            if i < last and symbols[i] == left and symbols[i + 1] == right:
                result.append(merged)
                i += 2
            else:
                result.append(symbols[i])
                i += 1
        return result

    def _merge_tokens(self, tokens, pair):
        """Replace occurrences of *pair* with its merged form in every word.

        Args:
            tokens: Mapping of word -> current list of tokens for that word.
            pair: ``(left, right)`` token tuple to merge.

        Returns:
            A new mapping with the same keys and updated token lists.
        """
        merged = "".join(pair)
        return {
            word: self._apply_merge(symbols, pair, merged)
            for word, symbols in tokens.items()
        }

    def fit(self, corpus):
        """Learn merge rules from *corpus* (an iterable of strings).

        Repeatedly merges the most frequent adjacent token pair until the
        number of distinct characters plus the number of merges reaches
        ``target_vocab_size`` or no pairs remain. The result is stored in
        ``self.merges``; any merges from a previous ``fit`` are discarded.
        """
        # Start from a clean slate: merges learnt on an earlier corpus do not
        # correspond to the fresh character-level tokens built below.
        self.merges = []

        word_frequency = self._compute_word_frequencies(corpus)
        tokens = self._initialize_tokens(word_frequency)

        # The keys of `tokens` (the words) never change while merging, so the
        # base alphabet size is computed once rather than on every iteration.
        alphabet_size = len(set("".join(tokens)))

        while len(self.merges) + alphabet_size < self.target_vocab_size:
            pair_frequency = self._compute_pair_frequencies(tokens, word_frequency)
            best_pair = self._find_best_pair(pair_frequency)
            if best_pair is None:
                break
            self.merges.append(best_pair)
            tokens = self._merge_tokens(tokens, best_pair)

    def get_merge_length(self, merges):
        """Return, for each merge pair, the length of its longest token."""
        return [max(len(part) for part in merge) for merge in merges]

    def sort_merges(self, merges):
        """Return *merges* sorted by the length of their longest token.

        Merges with the same length are ordered by comparing the merge
        tuples themselves.
        """
        merge_lengths = self.get_merge_length(merges)
        return [merge for _, merge in sorted(zip(merge_lengths, merges))]

    def tokenize(self, corpus):
        """Tokenize each document in *corpus* by replaying the learnt merges.

        Every document is first split into single characters (whitespace
        included), then each merge rule is applied over the whole document in
        the order it was learnt.

        Returns:
            A list with one list of token strings per document.
        """
        merge_rules = [(pair, pair[0] + pair[1]) for pair in self.merges]

        tokenized_corpus = []
        for document in corpus:
            symbols = list(document)
            for pair, merged in merge_rules:
                symbols = self._apply_merge(symbols, pair, merged)
            tokenized_corpus.append(symbols)
        return tokenized_corpus

# Sample corpus: one string per document.
corpus = [
    "Hello world!",
    "It is a great day to hug puppies.",
    "It would be a shame to not do so.",
    "This text is full of nonsense: I don't care!",
    "I hope I have enough pair variety here to get an interesting result.",
    "This project is going to be a challenge",
    "This is all about tokenization.",
    "I'm trying to make this tokenization algorithm work",
    "I really I hope this works.",
]

In [4]:
# Learn merges from the sample corpus with a target vocabulary size of 100,
# then tokenize that same corpus and print the per-document token lists.
bpe = BytePairTokenizer(100)
bpe.fit(corpus)
tokens = bpe.tokenize(corpus)
print(tokens)

[['Hello', ' ', 'world!'], ['It', ' ', 'is', ' ', 'a', ' ', 'great', ' ', 'day', ' ', 'to', ' ', 'hug', ' ', 'puppies.'], ['It', ' ', 'would', ' ', 'be', ' ', 'a', ' ', 'shame', ' ', 'to', ' ', 'not', ' ', 'do', ' ', 'so.'], ['This', ' ', 'text', ' ', 'is', ' ', 'full', ' ', 'of', ' ', 'non', 's', 'en', 's', 'e', ':', ' ', 'I', ' ', 'd', 'on', "'", 't', ' ', 'c', 'a', 're', '!'], ['I', ' ', 'hope', ' ', 'I', ' ', 'ha', 'v', 'e', ' ', 'en', 'o', 'ug', 'h', ' ', 'p', 'a', 'i', 'r', ' ', 'v', 'a', 'r', 'ie', 't', 'y', ' ', 'h', 'e', 're', ' ', 'to', ' ', 'ge', 't', ' ', 'a', 'n', ' ', 'in', 'te', 'res', 't', 'ing', ' ', 'res', 'u', 'l', 't', '.'], ['This', ' ', 'p', 'r', 'o', 'j', 'e', 'c', 't', ' ', 'is', ' ', 'go', 'ing', ' ', 'to', ' ', 'be', ' ', 'a', ' ', 'c', 'ha', 'll', 'en', 'ge'], ['This', ' ', 'is', ' ', 'all', ' ', 'a', 'b', 'o', 'u', 't', ' ', 'tokenization', '.'], ['I', "'", 'm', ' ', 't', 'r', 'y', 'ing', ' ', 'to', ' ', 'm', 'a', 'k', 'e', ' ', 'this', ' ', 'tokenization', 