In [17]:
from collections import Counter, defaultdict
import nltk

class BytePairEncoding:
    def __init__(self, num_merges):
        self.num_merges = num_merges
        self.vocab = None
        self.merges = []

    def _get_stats(self, vocab):
        """Count frequency of pairs of symbols in the vocabulary."""
        pairs = defaultdict(int)
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pairs[(symbols[i], symbols[i + 1])] += freq
        return pairs

    def _merge_vocab(self, pair, vocab):
        """Merge the most frequent pair in the vocabulary."""
        new_vocab = {}
        bigram = ' '.join(pair)
        replacement = ''.join(pair)
        for word in vocab:
            # Replace all occurrences of the bigram with its replacement.
            new_word = word.replace(bigram, replacement)
            new_vocab[new_word] = vocab[word]
        return new_vocab

    def fit(self, corpus):
        """Train the Byte Pair Encoding model on the provided corpus."""
        # Initialize the vocabulary
        vocab = Counter()
        for word in corpus:
            # Clean the word by removing special characters and punctuation
            word = ''.join(char.lower() for char in word if char.isalnum())
            if word:  # Only process non-empty words
                tokens = ' '.join(word) + '_'
                vocab[tokens] += 1
        
        self.vocab = vocab

        for _ in range(self.num_merges):
            # Get the most frequent pair of symbols.
            pairs = self._get_stats(vocab)
            if not pairs:
                break
            best_pair = max(pairs, key=pairs.get)

            # Merge the most frequent pair.
            vocab = self._merge_vocab(best_pair, vocab)
            self.merges.append(best_pair)

        self.vocab = vocab

    def encode(self, word):
        """Encode a word using the learned merges."""
        # Clean the word by removing special characters and punctuation
        word = ''.join(char.lower() for char in word if char.isalnum())
        word = ' '.join(word) + '_'
        for pair in self.merges:
            bigram = ' '.join(pair)
            replacement = ''.join(pair)
            word = word.replace(bigram, replacement)
        return word.split()

    def decode(self, encoded_word):
        """Decode an encoded word back to its original form."""
        word = ' '.join(encoded_word)
        for pair in reversed(self.merges):
            replacement = ''.join(pair)
            bigram = ' '.join(pair)
            word = word.replace(replacement, bigram)
        return word.replace('_', '').replace(' ', '')

In [19]:
# Example Usage:
if __name__ == "__main__":
    corpus = ["low", "lowest", "newer", "wider"]
    nltk.download('gutenberg')
    from nltk.corpus import gutenberg
    book1 = gutenberg.raw("austen-emma.txt")
    book2 = gutenberg.raw("blake-poems.txt") 
    book3 = gutenberg.raw("shakespeare-hamlet.txt")
    
    training_text = book1 + " " +  book2 + " " + book3 + " " + "low"
    
    corpus = training_text.split()
    print(corpus[0])
    
    bpe = BytePairEncoding(num_merges=1000)
    bpe.fit(corpus)

    print("Learned merges:", bpe.merges)
    print("Vocabulary after training:", bpe.vocab)

    encoded_word = bpe.encode("lowest pi-chart")
    print("Encoded:", encoded_word)

    decoded_word = bpe.decode(encoded_word)
    print("Decoded:", decoded_word)

[nltk_data] Error loading gutenberg: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed:
[nltk_data]     unable to get local issuer certificate (_ssl.c:1000)>


[Emma
Learned merges: [('t', 'h'), ('a', 'n'), ('i', 'n'), ('e', 'r'), ('o', 'u'), ('h', 'a'), ('th', 'e_'), ('an', 'd_'), ('t', 'o_'), ('e', 'd_'), ('e', 'n'), ('in', 'g_'), ('e', 'a'), ('o', 'f_'), ('o', 'n'), ('h', 'e_'), ('e', 'l'), ('n', 'o'), ('a', 's_'), ('h', 'i'), ('o', 'r'), ('e', 's'), ('t', 'i'), ('a', 'r'), ('i', 't_'), ('a', 'l'), ('w', 'i'), ('l', 'd_'), ('no', 't_'), ('o', 'm'), ('h', 'er_'), ('w', 'as_'), ('y', 'ou_'), ('r', 'i'), ('l', 'y_'), ('s', 'he_'), ('r', 'e'), ('g', 'h'), ('b', 'e'), ('v', 'e_'), ('i', 's'), ('l', 'e_'), ('ou', 'ld_'), ('tha', 't_'), ('s', 'u'), ('s', 't_'), ('b', 'u'), ('c', 'h_'), ('l', 'i'), ('v', 'er'), ('e', 'm'), ('c', 'e_'), ('ver', 'y_'), ('bu', 't_'), ('s', 'a'), ('s', 'e_'), ('ha', 'd_'), ('a', 't'), ('m', 'e_'), ('f', 'or_'), ('o', 'o'), ('wi', 'th_'), ('s', 'i'), ('w', 'h'), ('s', 't'), ('l', 'e'), ('l', 'l_'), ('m', 'u'), ('r', 'o'), ('ha', 've_'), ('m', 'y_'), ('m', 'a'), ('l', 'a'), ('s', 'e'), ('gh', 't_'), ('k', 'e_'), ('d', '