In [1]:
from collections import Counter, defaultdict
import re

def get_vocab(text):
    vocab = Counter(text.split())
    return vocab

def merge_vocab(pair, vocab):
    bigram = ' '.join(pair)
    replacement = ''.join(pair)
    new_vocab = {}
    
    for word in vocab:
        new_word = re.sub(r'\b' + re.escape(bigram) + r'\b', replacement, word)
        new_vocab[new_word] = vocab[word]
    
    return new_vocab

def bpe(text, num_merges):
    text = re.sub(r'(\S)', r'\1 ', text).strip()
    vocab = get_vocab(text)
    
    for _ in range(num_merges):
        pairs = Counter()
        for word, freq in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                pairs[pair] += freq
        
        if not pairs:
            break
        
        most_freq_pair = pairs.most_common(1)[0][0]
        vocab = merge_vocab(most_freq_pair, vocab)
    
    return vocab



In [2]:
text = "low low lower lowest"
num_merges = 10
bpe_vocab = bpe(text, num_merges)

print("Vocabulary after BPE:")
for word, freq in bpe_vocab.items():
    print(f"{word}: {freq}")


Vocabulary after BPE:
l: 4
o: 4
w: 4
e: 2
r: 1
s: 1
t: 1
