In [43]:
from collections import defaultdict
import nltk
import pandas as pd
import numpy as np
from tqdm import tqdm # Progress bar

In [104]:
class WordPieceTokenizer:
    """Toy WordPiece-style subword tokenizer trained on a single text.

    The text is split on whitespace into words. Each word starts out as a
    sequence of single-character tokens in which every character except the
    first carries ``self.prefix`` (``"cat"`` -> ``["c", "#a", "#t"]``).
    Training repeatedly merges the best-scoring adjacent token pair into a new
    vocabulary entry and then re-tokenizes the whole text with the enlarged
    vocabulary (greedy longest-match).
    """

    def __init__(self, text: str, max_vocabulary_size: int):
        """Store the training text and build the character-level starting state.

        Args:
            text: Training text; it is split on whitespace into words.
            max_vocabulary_size: Bound used to stop training (see ``update``).
        """
        self.text = text
        self.max_vocabulary_size = max_vocabulary_size
        self.prefix = "#"  # Marks a token that is not at the start of a word
        self.unknown_token = "[UNK]"  # Emitted for a word that cannot be tokenized
        self.corpus = list()  # The text, tokenized using self.vocabulary (one list per word)
        self.vocabulary = defaultdict(int)  # Learned token -> its frequency in self.corpus
        self.init_corpus_and_vocabulary()

    def init_corpus_and_vocabulary(self):
        """Initializes corpus and vocabulary from single characters of self.text.

        Any previous state is discarded first, so calling this again resets the
        tokenizer instead of double-counting every token.
        """
        self.corpus = list()
        self.vocabulary = defaultdict(int)
        for word in self.text.split():
            # Default tokens: "x" for the first character, "#x" for the others
            tokenized_word = [
                char if idx == 0 else self.prefix + char
                for idx, char in enumerate(word)
            ]
            for token in tokenized_word:
                self.vocabulary[token] += 1
            self.corpus.append(tokenized_word)

    def train(self, normalize: bool = True):
        """Run merge steps until ``update`` reports that nothing more can be learned.

        Args:
            normalize: Forwarded to ``update``; selects the pair-scoring rule.
        """
        # Each successful step adds one token, so max_vocabulary_size is an
        # upper bound on the number of useful iterations.
        for _ in tqdm(range(self.max_vocabulary_size)):
            if not self.update(normalize=normalize):
                break

    def get_token_pairs(self, normalize: bool = True) -> pd.Series:
        """Score every adjacent token pair of the current corpus.

        Args:
            normalize: If True, a pair's score is its count divided by the
                product of the two tokens' frequencies (WordPiece scoring).
                If False, the score is the raw pair count.

        Returns:
            Series indexed by ``(left, right)`` token pairs, sorted by score in
            descending order. Empty when no word has more than one token.
        """
        pair_counts = defaultdict(int)
        for tokenized_word in self.corpus:
            # Consecutive tokens of the word; yields nothing for 0 or 1 tokens
            for pair in zip(tokenized_word, tokenized_word[1:]):
                pair_counts[pair] += 1

        if not pair_counts:
            return pd.Series(dtype=float)

        if normalize:
            scores = {
                (a, b): count / (self.vocabulary[a] * self.vocabulary[b])
                for (a, b), count in pair_counts.items()
            }
        else:
            scores = dict(pair_counts)
        return pd.Series(scores).sort_values(ascending=False)

    def update(self, normalize: bool = True) -> bool:
        """Perform one merge step.

        Adds the best-scoring pair as a new token, re-tokenizes the text and
        refreshes all token frequencies.

        Args:
            normalize: Scoring rule, see ``get_token_pairs``.

        Returns:
            True if a token was added; False if the size limit was reached or
            no pair is left to merge.
        """
        # Training stops once the vocabulary holds max_vocabulary_size - 1 tokens.
        # NOTE(review): presumably one slot is kept for self.unknown_token, which
        # is never stored in the vocabulary -- confirm.
        if len(self.vocabulary) >= self.max_vocabulary_size - 1:
            return False

        scores = self.get_token_pairs(normalize=normalize)
        if scores.empty:
            return False

        (left, right), _ = next(iter(scores.items()))
        # The right token always continues a word, so it starts with the prefix.
        # Strip only that leading marker: the token itself may contain a literal
        # "#" taken from the text, which must survive the merge.
        merged = left + right.removeprefix(self.prefix)
        self.vocabulary[merged] += 0  # Register the token; counted for real below

        # Very lazy update (just re-tokenize every word after updating vocabulary)
        self.corpus = [self.tokenize_word(word) for word in self.text.split()]
        self._recount_vocabulary()
        return True

    def _recount_vocabulary(self):
        """Set every vocabulary frequency to the token's count in self.corpus."""
        # Tokens that no longer occur keep their entry (with count 0) so they
        # remain available to tokenize_word.
        for token in self.vocabulary:
            self.vocabulary[token] = 0
        for tokenized_word in self.corpus:
            for token in tokenized_word:
                self.vocabulary[token] += 1

    def tokenize_text(self, text: str) -> list[str]:
        """Tokenize a whitespace-separated text into one flat list of tokens."""
        return [
            token
            for word in text.split()
            for token in self.tokenize_word(word)
        ]

    def tokenize_word(self, word: str) -> list[str]:
        """Tokenize a single word by greedy longest-match against the vocabulary.

        Args:
            word: A word without whitespace.

        Returns:
            The tokens covering the word from left to right, or
            ``[self.unknown_token]`` if some position cannot be matched by any
            vocabulary entry. An empty word gives an empty list.
        """
        tokens = list()
        start = 0
        while start < len(word):
            # Find the longest vocabulary entry starting at `start`
            new_token = None
            end = len(word)
            while start < end:
                piece = word[start:end]
                candidate = piece if start == 0 else self.prefix + piece
                if candidate in self.vocabulary:
                    new_token = candidate
                    break
                end -= 1

            if new_token is None:  # No token applicable (not even single chars)
                return [self.unknown_token]

            tokens.append(new_token)
            start = end  # Continue right after the matched piece

        return tokens


In [107]:
# Small training corpus for the demo
text = """
Years ago, the fearsome Pirate King, Gol D. Roger was executed leaving a huge pile of treasure and the famous "One Piece" behind.
Whoever claims the "One Piece" will be named the new King of the Pirates.
Monkey D. Luffy, a boy who consumed a "Devil Fruit," decides to follow in the footsteps of his idol, the pirate Shanks, and find the One Piece.
It helps, of course, that his body has the properties of rubber and that he's surrounded by a bevy of skilled fighters and thieves to help him along the way.
Luffy will do anything to get the One Piece and become King of the Pirates!
"""

# Learn merges using raw pair counts (normalize=False) rather than normalized scores
tokenizer = WordPieceTokenizer(max_vocabulary_size=180, text=text)
tokenizer.train(normalize=False)

# Phrase to tokenize with the trained vocabulary; the last expression is the cell output
test = "consumed a"
tokenizer.tokenize_text(test)

 66%|██████▌   | 118/180 [00:00<00:00, 531.85it/s]


['con', '#su', '#med', 'a']