In [2]:
from collections import defaultdict

In [31]:
from collections import defaultdict, Counter

class BytePairTokenizer:
  def __init__(self, target_vocab_size):
    self.target_vocab_size = target_vocab_size
    self.merges = []
  
  def _compute_word_frequencies(self, corpus):
    """
    Compute word frequencies from input corpus.
    """
    return Counter([word for text in corpus for word in text.split()])
  
  def _initialize_tokens(self, word_frequency):
    """
    Initialize tokens as single characters.
    """
    return {word:list(word) for word in word_frequency.keys()}
  
  def _compute_pair_frequencies(self, tokens, word_frequency):
    """
    Compute frequencies of adjacent token pairs.
    """
    pair_freq = defaultdict(int)
    for word, freq in word_frequency.items():
      token = tokens[word]
      for i in range(len(token) - 1):
        pair = (token[i], token[i + 1])
        pair_freq[pair] += freq
    return pair_freq
  
  def _find_best_pair(self, pair_freq):
    """
    Find most frequent pair in our tokenized corpus
    """
    return max(pair_freq, key=pair_freq.get)

  def _merge_tokens(self, tokens, pair):
    """
    Replace occurences of a pair of tokens with the merged form.
    """
    merged_tokens = {}
    merge_str = ''.join(pair)
    for word, token in tokens.items():
      new_token = []
      i = 0
      while i < len(token):
        if i < len(token) - 1 and (token[i], token[i + 1]) == pair:
          new_token.append(merge_str)
          i += 2
        else:
          new_token.append(token[i])
          i += 1
      merged_tokens[word] = new_token
    return merged_tokens

  def fit(self, corpus):
    word_frequency = self._compute_word_frequencies(corpus)
    tokens = self._initialize_tokens(word_frequency)

    while len(self.merges) + len(set("".join(tokens.keys()))) < self.target_vocab_size:
      pair_frequency = self._compute_pair_frequencies(tokens, word_frequency)
      best_pair = self._find_best_pair(pair_frequency)
      if not best_pair:
        break
      self.merges.append(best_pair)
      tokens = self._merge_tokens(tokens, best_pair)

  def get_merge_length(self, merges):
    '''
    Get maximum length of token in a merge pair.
    '''
    merge_lengths = []
    for merge in merges:
      merge_lengths.append(max([len(c) for c in merge]))
    return merge_lengths

  def sort_merges(self, merges):
    '''
    Sorts a list of tuples of merges, to be in the order of the maximum
    length of token they are merging.
    '''
    merge_lengths = bpe.get_merge_length(merges)
    return [merge for length, merge in sorted(zip(merge_lengths, merges))]

  def tokenize(self, corpus):
    '''
    Iteratively merge learnt pairs of tokens in a corpus.
    '''
    # Find max merge depth to iterate to
    max_merge_depth = max(self.get_merge_length(self.merges))
    
    # Sort merges by depth needed to reach them
    sorted_merges = self.sort_merges(self.merges)
    
    # Get merged string for each merge pair
    merge_dict = {m:m[0]+m[1] for m in sorted_merges}

    tokenized_corpus = []
    for document in corpus:
      document = list(document)
      
      # Re-iterate as many times as max merge depth to ensure we merge all pairs
      merge_depth = 0
      while merge_depth < max_merge_depth:
        
        # Iterate across entire document
        i = 0
        while i < len(document) - 1:
          pair = (document[i], document[i+1])
          if pair in merge_dict.keys():
            merge_string = merge_dict.get(pair)
            document[i] = merge_string
            del document[i+1]
          i += 1
        merge_depth += 1
        
      tokenized_corpus.append(document)
    return tokenized_corpus

corpus = [
  "Hello world!",
  "It is a great day to hug puppies.",
  "It would be a shame to not do so.",
  "This text is full of nonsense: I don't care!",
  "I hope I have enough pair variety here to get an interesting result.",
  "This project is going to be a challenge",
  "This is all about tokenization.",
  "I'm trying to make this tokenization algorithm work",
  "I really I hope this works.",
]

bpe = BytePairTokenizer(100)
bpe.fit(corpus)
print(f"Merges: {bpe.merges}")
tokens = bpe.tokenize(corpus)


Merges: [('i', 's'), ('t', 'o'), ('r', 'e'), ('l', 'l'), ('h', 'is'), ('e', 'n'), ('w', 'o'), ('o', 'n'), ('i', 'n'), ('wo', 'r'), ('a', 't'), ('h', 'a'), ('T', 'his'), ('in', 'g'), ('l', 'd'), ('I', 't'), ('u', 'g'), ('i', 'e'), ('s', '.'), ('b', 'e'), ('t', 'e'), ('h', 'o'), ('ho', 'p'), ('hop', 'e'), ('g', 'e'), ('re', 's'), ('g', 'o'), ('a', 'll'), ('to', 'k'), ('tok', 'en'), ('token', 'i'), ('tokeni', 'z'), ('tokeniz', 'at'), ('tokenizat', 'i'), ('tokenizati', 'on'), ('t', 'his'), ('wor', 'k'), ('H', 'e'), ('He', 'll'), ('Hell', 'o'), ('wor', 'ld'), ('world', '!'), ('g', 're'), ('gre', 'at'), ('d', 'a'), ('da', 'y'), ('h', 'ug'), ('p', 'u'), ('pu', 'p'), ('pup', 'p'), ('pupp', 'ie'), ('puppie', 's.'), ('wo', 'u'), ('wou', 'ld'), ('s', 'ha'), ('sha', 'm'), ('sham', 'e'), ('n', 'o'), ('no', 't'), ('d', 'o'), ('s', 'o'), ('so', '.'), ('te', 'x'), ('tex', 't'), ('f', 'u'), ('fu', 'll'), ('o', 'f'), ('n', 'on')]


In [32]:
tokens

[['Hello', ' ', 'world!'],
 ['It',
  ' ',
  'is',
  ' ',
  'a',
  ' ',
  'great',
  ' ',
  'day',
  ' ',
  'to',
  ' ',
  'hug',
  ' ',
  'puppies.'],
 ['It',
  ' ',
  'would',
  ' ',
  'be',
  ' ',
  'a',
  ' ',
  'shame',
  ' ',
  'to',
  ' ',
  'not',
  ' ',
  'do',
  ' ',
  'so.'],
 ['This',
  ' ',
  'text',
  ' ',
  'is',
  ' ',
  'full',
  ' ',
  'of',
  ' ',
  'no',
  'n',
  's',
  'en',
  's',
  'e',
  ':',
  ' ',
  'I',
  ' ',
  'do',
  'n',
  "'",
  't',
  ' ',
  'c',
  'a',
  're',
  '!'],
 ['I',
  ' ',
  'hope',
  ' ',
  'I',
  ' ',
  'ha',
  'v',
  'e',
  ' ',
  'en',
  'o',
  'ug',
  'h',
  ' ',
  'p',
  'a',
  'i',
  'r',
  ' ',
  'v',
  'a',
  'r',
  'ie',
  't',
  'y',
  ' ',
  'h',
  'e',
  're',
  ' ',
  'to',
  ' ',
  'ge',
  't',
  ' ',
  'a',
  'n',
  ' ',
  'in',
  'te',
  'res',
  't',
  'ing',
  ' ',
  'res',
  'u',
  'l',
  't',
  '.'],
 ['This',
  ' ',
  'p',
  'r',
  'o',
  'j',
  'e',
  'c',
  't',
  ' ',
  'is',
  ' ',
  'go',
  'ing',
  ' ',
  'to',
  ' '