In [2]:
from collections import defaultdict
from typing import List

In [3]:
from collections import defaultdict
from typing import List, Tuple

class BPETokenizer:
    """
    A Byte Pair Encoding (BPE) Tokenizer for training on a corpus and tokenizing text.

    Words are obtained by splitting text on single space characters. Training
    starts from a character-level vocabulary and repeatedly merges the most
    frequent adjacent symbol pair, recording each merge as a rule that
    ``tokenize`` later replays in the same order.

    Attributes:
        _corpus (List[str]): The input text corpus for training the tokenizer.
        _vocab_size (int): The desired size of the vocabulary.
        _merge_limit (int): Limit on the number of merges performed during training.
        _word_freq (defaultdict): Frequency of words in the corpus.
        _splits (dict): Mapping of words to their current split representation.
        _vocab (List[str]): The vocabulary of the tokenizer.
        _merge_rules (List[Tuple[str, str]]): List of merge rules learned during training.
    """

    def __init__(self, corpus: List[str], vocab_size: int, merge_limit: int):
        """
        Initializes the BPE Tokenizer with the given corpus, vocabulary size, and merge limit.

        Args:
            corpus (List[str]): List of text strings for training.
            vocab_size (int): Desired vocabulary size.
            merge_limit (int): Maximum number of merges to perform. ``None``
                disables the limit.
        """
        self._corpus = corpus
        self._vocab_size = vocab_size
        self._merge_limit = merge_limit
        self._word_freq = defaultdict(int)
        self._splits = dict()
        self._vocab = []
        self._merge_rules = []

    def train(self):
        """
        Trains the BPE tokenizer by learning merge rules from the corpus.

        Merging stops as soon as any of the following holds: the vocabulary
        has reached ``vocab_size``, ``merge_limit`` merges have been performed
        by this call, or no adjacent symbol pair is left to merge.
        """
        self._prepare_corpus()
        merges_done = 0
        while len(self._vocab) < self._vocab_size:
            if self._merge_limit is not None and merges_done >= self._merge_limit:
                break
            pairs_freq = self._get_pair_freq()
            most_freq_pair, _ = self._get_most_freq_pair(pairs_freq)
            if most_freq_pair is None:
                # Every word is already a single symbol; nothing left to merge.
                break
            self._merge(*most_freq_pair)
            self._update_vocab(most_freq_pair)
            self._update_merge_rules(most_freq_pair)
            merges_done += 1

    def tokenize(self, text: str) -> List[str]:
        """
        Tokenizes the input text using the learned merge rules.

        The text is split on single spaces, each word is broken into
        characters, and the merge rules are applied in the order they were
        learned. The tokens of all words are returned as one flat list, so
        word boundaries are not preserved in the output.

        Args:
            text (str): The input text to tokenize.

        Returns:
            List[str]: The tokenized representation of the input text.
        """
        splits = [list(word) for word in text.split(" ")]

        for a, b in self._merge_rules:
            # Match on the exact (left, right) symbols, as training does;
            # comparing concatenations would also merge pairs such as
            # ("a", "bc") for a rule ("ab", "c").
            splits = [self._apply_merge(split, a, b) for split in splits]

        return [token for split in splits for token in split]

    @staticmethod
    def decode(text: str):
        """
        Placeholder for decoding; not implemented yet.

        Declared as a static method so that it can be called both on the
        class and on an instance.

        Args:
            text (str): Currently ignored.

        Returns:
            None
        """
        return None

    @staticmethod
    def _apply_merge(symbols: List[str], a: str, b: str) -> List[str]:
        """
        Replaces every adjacent occurrence of ``a`` followed by ``b`` with ``a + b``.

        Occurrences are matched greedily from left to right in a single pass.

        Args:
            symbols (List[str]): The current symbols of one word.
            a (str): The first symbol of the pair.
            b (str): The second symbol of the pair.

        Returns:
            List[str]: The symbols after merging. The input list is returned
            unchanged when it holds fewer than two symbols.
        """
        if len(symbols) < 2:
            return symbols

        merged = []
        last = len(symbols) - 1
        i = 0
        while i <= last:
            if i < last and symbols[i] == a and symbols[i + 1] == b:
                merged.append(a + b)
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        return merged

    def _prepare_corpus(self):
        """
        Prepares the corpus by calculating word frequencies, creating splits, and building the initial vocabulary.
        """
        self._get_word_freq()
        self._create_splits()
        self._build_initial_vocab()

    def _get_word_freq(self):
        """
        Calculates the frequency of each word in the corpus.

        Counts are added to ``_word_freq``; words are the pieces obtained by
        splitting each text on single spaces.
        """
        for text in self._corpus:
            for word in text.split(" "):
                self._word_freq[word] += 1

    def _create_splits(self):
        """
        Splits each word in the corpus into characters.
        """
        self._splits = {word: list(word) for word in self._word_freq}

    def _build_initial_vocab(self):
        """
        Builds the initial vocabulary consisting of unique characters in the corpus.

        Characters are appended in first-seen order; entries already present
        in the vocabulary are not added again.
        """
        # Set used only for O(1) membership checks; the list keeps the order.
        seen = set(self._vocab)
        for symbols in self._splits.values():
            for c in symbols:
                if c not in seen:
                    seen.add(c)
                    self._vocab.append(c)

    def _get_pair_freq(self) -> defaultdict:
        """
        Calculates the frequency of adjacent symbol pairs in the corpus.

        Each pair is weighted by the frequency of the word it occurs in.

        Returns:
            defaultdict: Frequency of adjacent symbol pairs.
        """
        pairs_freq = defaultdict(int)
        for word, freq in self._word_freq.items():
            symbols = self._splits[word]
            for pair in zip(symbols, symbols[1:]):
                pairs_freq[pair] += freq
        return pairs_freq

    def _get_most_freq_pair(self, pairs_freq: defaultdict) -> Tuple[Tuple[str, str], int]:
        """
        Finds the most frequent symbol pair in the corpus.

        Ties are resolved in favour of the pair that appears first in
        ``pairs_freq``.

        Args:
            pairs_freq (defaultdict): Frequency of symbol pairs.

        Returns:
            Tuple[Tuple[str, str], int]: The most frequent pair and its
            frequency, or ``(None, None)`` when ``pairs_freq`` is empty.
        """
        if not pairs_freq:
            return None, None
        # max() returns the first maximal item, matching a strict ">" scan.
        return max(pairs_freq.items(), key=lambda item: item[1])

    def _merge(self, a: str, b: str):
        """
        Merges a pair of symbols in all words in the corpus.

        Args:
            a (str): The first symbol of the pair.
            b (str): The second symbol of the pair.
        """
        for word in self._splits:
            self._splits[word] = self._apply_merge(self._splits[word], a, b)

    def _update_vocab(self, pair: Tuple[str, str]):
        """
        Adds a merged pair to the vocabulary.

        Args:
            pair (Tuple[str, str]): The merged symbol pair.
        """
        self._vocab.append("".join(pair))

    def _update_merge_rules(self, merge_rule: Tuple[str, str]):
        """
        Updates the list of merge rules with the new rule.

        Args:
            merge_rule (Tuple[str, str]): The new merge rule.
        """
        self._merge_rules.append(merge_rule)


In [4]:
# Small toy corpus; each string is one training text made of space-separated words.
corpus = [
    "low lower lowest",
    "new newer newest",
    "widest wide",
    "brightest bright",
]

# Hyperparameters handed to the BPETokenizer constructor.
vocab_size, merge_limit = 20, 10

In [5]:
# Build an untrained tokenizer from the toy corpus and hyperparameters.
tokenizer = BPETokenizer(
    corpus=corpus,
    vocab_size=vocab_size,
    merge_limit=merge_limit,
)

In [6]:
# Learn the merge rules from the corpus (mutates the tokenizer in place).
tokenizer.train()

In [7]:
# Tokenize a sample that mixes a corpus word ("low") with strings that do not
# occur in the corpus ("lowing", "bri"); the cell displays the flat token list.
tokenizer.tokenize("lowing low bri")

['lo', 'wi', 'n', 'g', 'lo', 'w', 'b', 'r', 'i']