In [3]:
from collections import defaultdict

In [33]:
from collections import defaultdict, Counter

class BytePairTokenizer:
    def __init__(self, target_vocab_size):
        self.target_vocab_size = target_vocab_size
        self.merges = []
    
    def _compute_word_frequencies(self, corpus):
        """
        Compute word frequencies from input corpus.
        """
        return Counter([word for text in corpus for word in text.split()])
    
    def _initialize_tokens(self, word_frequency):
        """
        Initialize tokens as single characters.
        """
        return {word:list(word) for word in word_frequency.keys()}
    
    def _compute_pair_frequencies(self, tokens, word_frequency):
        """
        Compute frequencies of adjacent token pairs.
        """
        pair_freq = defaultdict(int)
        for word, freq in word_frequency.items():
            token = tokens[word]
            for i in range(len(token) - 1):
                pair = (token[i], token[i + 1])
                pair_freq[pair] += freq
        return pair_freq
    
    def _find_best_pair(self, pair_freq):
        """
        Find most frequent pair in our tokenized corpus
        """
        return max(pair_freq, key=pair_freq.get)

    def _merge_tokens(self, tokens, pair):
        """
        Replace occurences of a pair of tokens with the merged form.
        """
        merged_tokens = {}
        merge_str = ''.join(pair)
        for word, token in tokens.items():
            new_token = []
            i = 0
            while i < len(token):
                if i < len(token) - 1 and (token[i], token[i + 1]) == pair:
                    new_token.append(merge_str)
                    i += 2
                else:
                    new_token.append(token[i])
                    i += 1
            merged_tokens[word] = new_token
        return merged_tokens


    def fit(self, corpus):
       word_frequency = self._compute_word_frequencies(corpus)
       tokens = self._initialize_tokens(word_frequency)

       while len(self.merges) + len(set("".join(tokens.keys()))) < self.target_vocab_size:
           pair_frequency = self._compute_pair_frequencies(tokens, word_frequency)
           best_pair = self._find_best_pair(pair_frequency)
           if not best_pair:
               break
           self.merges.append(best_pair)
           tokens = self._merge_tokens(tokens, best_pair)

    def tokenize(self, corpus):
      tokenized_corpus = []
      for string in corpus:
        
        tokenized_string = []
        i=0
        while i < len(string) - 1:
          if (i < len(string) - 1) & ((string[i], string[i + 1]) in self.merges):
            tokenized_string.append(string[i] + string[i+1])
            # i += 2
          else:
            tokenized_string.append(string[i])
            i += 1
        tokenized_corpus.append(tokenized_string)
      print(tokenized_corpus)
            
corpus = [
    # "Hello world!",
    # "It is a great day to hug puppies.",
    # "It would be a shame to not do so.",
    # "This text is full of nonsense: I don't care!",
    # "I hope I have enough pair variety here to get an interesting result.",
    # "This project is going to be a challenge",
    "This is all about tokenization.",
    "I'm trying to make this tokenization algorithm work",
    # "I really I hope this works.",
        ]

bpe = BytePairTokenizer(30)
bpe.fit(corpus)
print(bpe.merges)
bpe.tokenize(corpus)

[('i', 's'), ('t', 'o'), ('k', 'e'), ('h', 'is'), ('a', 'l'), ('to', 'ke'), ('toke', 'n'), ('token', 'i')]


KeyboardInterrupt: 

In [None]:
# Maybe I can split my text into pairs, check if pairs in merge dict, replace pair with that value until I can't find anymore
string_pairs = [(string[i], string[i+1]) for i in range(len(string) - 1)]
string_from_pairs = [pair[0] for pair in string_pairs]