In [2]:
def pretokenize(text):
    """
    Split raw text into a flat list of pre-tokens.

    Every maximal run of word characters (letters, digits and the
    underscore) becomes one token, and every other non-whitespace character
    becomes its own single-character token. Whitespace is dropped.

    Args:
        text: The string to split.

    Returns:
        A list of token strings in their original order; empty when the
        text contains no non-whitespace characters.
    """
    import re
    # First alternative: a run of word characters. Second alternative: one
    # character that is neither a word character nor whitespace.
    word_or_symbol = re.compile(r"\w+|[^\w\s]")
    return [match.group(0) for match in word_or_symbol.finditer(text)]

In [3]:
# Quick check of pretokenize: the ellipsis is split into three separate '.'
# tokens and the apostrophe in "Isn't" becomes its own token.
test_text = "Hello world! This is a test... Isn't it?"
pre_tokens = pretokenize(test_text)
print(pre_tokens)

['Hello', 'world', '!', 'This', 'is', 'a', 'test', '.', '.', '.', 'Isn', "'", 't', 'it', '?']


In [4]:
# Sample corpus: four short sentences used to train the toy BPE vocabulary.
corpus = [
    "This is the first sentence.",
    "This document is the second document.",
    "And this is the third one.",
    "Is this the first document?",
]

# Flatten the per-sentence pre-token lists into one list, keeping corpus order.
all_pre_tokens = [
    pre_token
    for sentence in corpus
    for pre_token in pretokenize(sentence)
]
print("Pre-tokens:", all_pre_tokens)

# The starting vocabulary is every distinct character seen in any pre-token.
initial_vocab = {char for pre_token in all_pre_tokens for char in pre_token}

print("Initial Character Vocabulary:", sorted(initial_vocab))


Pre-tokens: ['This', 'is', 'the', 'first', 'sentence', '.', 'This', 'document', 'is', 'the', 'second', 'document', '.', 'And', 'this', 'is', 'the', 'third', 'one', '.', 'Is', 'this', 'the', 'first', 'document', '?']
Initial Character Vocabulary: ['.', '?', 'A', 'I', 'T', 'c', 'd', 'e', 'f', 'h', 'i', 'm', 'n', 'o', 'r', 's', 't', 'u']


In [6]:
from collections import Counter
# How many times each distinct pre-token occurs in the corpus.
word_freqs = Counter(all_pre_tokens)

# Break every distinct pre-token into single characters. Purely alphanumeric
# tokens additionally get the end-of-word marker '</w>'; anything else
# (e.g. punctuation) is left as bare characters.
splits = {
    word: list(word) + (['</w>'] if word.isalnum() else [])
    for word in word_freqs
}
print("Word Frequencies:", word_freqs)
print("Initial Splits:", splits)


Word Frequencies: Counter({'the': 4, 'is': 3, '.': 3, 'document': 3, 'This': 2, 'first': 2, 'this': 2, 'sentence': 1, 'second': 1, 'And': 1, 'third': 1, 'one': 1, 'Is': 1, '?': 1})
Initial Splits: {'This': ['T', 'h', 'i', 's', '</w>'], 'is': ['i', 's', '</w>'], 'the': ['t', 'h', 'e', '</w>'], 'first': ['f', 'i', 'r', 's', 't', '</w>'], 'sentence': ['s', 'e', 'n', 't', 'e', 'n', 'c', 'e', '</w>'], '.': ['.'], 'document': ['d', 'o', 'c', 'u', 'm', 'e', 'n', 't', '</w>'], 'second': ['s', 'e', 'c', 'o', 'n', 'd', '</w>'], 'And': ['A', 'n', 'd', '</w>'], 'this': ['t', 'h', 'i', 's', '</w>'], 'third': ['t', 'h', 'i', 'r', 'd', '</w>'], 'one': ['o', 'n', 'e', '</w>'], 'Is': ['I', 's', '</w>'], '?': ['?']}


In [7]:
def get_pair_stats(splits, word_freqs):
    """
    Tally adjacent symbol pairs across all words, weighted by word frequency.

    Args:
        splits: Mapping from each word to its current list of symbols.
        word_freqs: Mapping from each word to its occurrence count.

    Returns:
        A Counter keyed by (left_symbol, right_symbol) tuples. Each pair's
        count is the sum, over every position where it occurs, of the
        frequency of the word it occurs in. Pairs are inserted in the order
        they are first seen (words in word_freqs order, left to right).
    """
    pair_counts = Counter()
    for word, count in word_freqs.items():
        pieces = splits[word]
        # Pairing the sequence with its one-step shift yields each adjacent
        # pair once, and nothing at all for fewer than two symbols.
        for left, right in zip(pieces, pieces[1:]):
            pair_counts[(left, right)] += count
    return pair_counts

# Pair counts over the untouched character-level splits.
initial_pair_stats = get_pair_stats(splits, word_freqs)
print("\nInitial Pair Stats (using </w>):", initial_pair_stats)

if not initial_pair_stats:
    print("No pairs found.")
else:
    # most_common(1) returns a one-element list holding a (pair, count) tuple.
    most_frequent_pair = initial_pair_stats.most_common(1)[0]
    print("Most Frequent Pair:", most_frequent_pair)



Initial Pair Stats (using </w>): Counter({('s', '</w>'): 8, ('i', 's'): 7, ('t', 'h'): 7, ('e', '</w>'): 6, ('h', 'i'): 5, ('t', '</w>'): 5, ('e', 'n'): 5, ('h', 'e'): 4, ('n', 't'): 4, ('i', 'r'): 3, ('d', 'o'): 3, ('o', 'c'): 3, ('c', 'u'): 3, ('u', 'm'): 3, ('m', 'e'): 3, ('d', '</w>'): 3, ('T', 'h'): 2, ('f', 'i'): 2, ('r', 's'): 2, ('s', 't'): 2, ('s', 'e'): 2, ('o', 'n'): 2, ('n', 'd'): 2, ('t', 'e'): 1, ('n', 'c'): 1, ('c', 'e'): 1, ('e', 'c'): 1, ('c', 'o'): 1, ('A', 'n'): 1, ('r', 'd'): 1, ('n', 'e'): 1, ('I', 's'): 1})
Most Frequent Pair: (('s', '</w>'), 8)


In [11]:
# Upper bound on the number of merge rules to learn.
num_merges = 100
# Learned rules in the order they were found: (left, right) -> merged symbol.
merges = {}
current_splits = splits.copy()
vocab = initial_vocab.copy()

for merge_step in range(num_merges):
    pair_stats = get_pair_stats(current_splits, word_freqs)
    # Nothing left to merge: every word is already a single symbol.
    if not pair_stats:
        break

    most_frequent_pair = max(pair_stats, key=pair_stats.get)
    freq = pair_stats[most_frequent_pair]
    # A pair seen only once is not worth a merge rule, so stop learning.
    if freq < 2:
        break

    new_symbol = ''.join(most_frequent_pair)
    merges[most_frequent_pair] = new_symbol
    vocab.add(new_symbol)

    # Rewrite every word, replacing each non-overlapping occurrence of the
    # chosen pair (scanning left to right) with the merged symbol.
    new_splits = {}
    for word, symbols in current_splits.items():
        rebuilt = []
        pos = 0
        while pos < len(symbols):
            # At the last position the slice has one element, so it can
            # never equal the two-element pair.
            if tuple(symbols[pos:pos + 2]) == most_frequent_pair:
                rebuilt.append(new_symbol)
                pos += 2
            else:
                rebuilt.append(symbols[pos])
                pos += 1
        new_splits[word] = rebuilt
    current_splits = new_splits
    



In [16]:
# Longest symbols first. The sort is stable, so symbols of equal length keep
# the order in which the set yields them.
vocab_by_length = sorted(vocab, key=len, reverse=True)
print("Final Vocabulary after Merges:", vocab_by_length)

Final Vocabulary after Merges: ['document</w>', 'first</w>', 'This</w>', 'this</w>', 'the</w>', 'documen', 'is</w>', 's</w>', 'e</w>', 'docum', 't</w>', 'd</w>', 'firs', 'docu', 'doc', 'fir', 'do', 'en', 'th', 'on', 'Th', 'ir', 's', 't', 'o', 'r', 'e', 'm', 'f', 'u', 'c', 'd', 'i', 'T', 'A', 'I', '?', 'h', 'n', '.']


In [30]:
def _apply_merges(symbols, merges, ranks):
    """Reduce one word's symbol list by applying merges in learned order.

    On every round the adjacent pair with the lowest rank (i.e. the earliest
    learned merge) is selected and all of its non-overlapping occurrences are
    merged left to right. The loop ends when no adjacent pair has a merge rule.
    """
    while len(symbols) > 1:
        candidates = [pair for pair in zip(symbols, symbols[1:]) if pair in ranks]
        if not candidates:
            break
        best_pair = min(candidates, key=ranks.__getitem__)
        merged_symbol = merges[best_pair]

        rebuilt = []
        pos = 0
        while pos < len(symbols):
            if pos < len(symbols) - 1 and (symbols[pos], symbols[pos + 1]) == best_pair:
                rebuilt.append(merged_symbol)
                pos += 2
            else:
                rebuilt.append(symbols[pos])
                pos += 1
        # Each round shortens the list by at least one symbol, so this terminates.
        symbols = rebuilt
    return symbols


def tokenize(text, pretokenizer, merges):
    """
    Tokenize text with a learned set of BPE merge rules.

    The text is pre-tokenized, each pre-token is split into characters
    (purely alphanumeric pre-tokens also get the end-of-word marker
    '</w>'), and the merge rules are then applied to each pre-token
    independently.

    Merges are applied in priority order: the rule that was learned first
    wins whenever several rules could apply to the same word. Applying
    whichever rule happens to match first from the left can produce a
    different segmentation when candidate pairs overlap.

    Args:
        text: The string to tokenize.
        pretokenizer: Callable mapping a string to a list of pre-token strings.
        merges: Dict mapping (left, right) symbol pairs to the merged symbol.
            Its insertion order is taken as the order in which the rules
            were learned.

    Returns:
        A list with one entry per pre-token; each entry is that pre-token's
        list of symbols after merging.
    """
    # Dicts preserve insertion order, so enumeration gives the learned rank.
    ranks = {pair: rank for rank, pair in enumerate(merges)}

    tokenized = []
    for token in pretokenizer(text):
        symbols = list(token)
        # Only purely alphanumeric pre-tokens carry the end-of-word marker.
        if token.isalnum():
            symbols.append('</w>')
        tokenized.append(_apply_merges(symbols, merges, ranks))
    return tokenized


    
        


In [33]:
# Two probes: one made mostly of words from the training corpus, and one
# made of words that never appeared in it.
new_text = "This is a first test document."
another_text = "Lowest common documents."

bpe_tokens = tokenize(new_text, pretokenize, merges)
bpe_tokens_unknown = tokenize(another_text, pretokenize, merges)

print(f"Original Text: '{new_text}'")
print(f"BPE Tokens: {bpe_tokens}")

print(f"\nOriginal Text: '{another_text}'")
print(f"BPE Tokens: {bpe_tokens_unknown}")

Original Text: 'This is a first test document.'
BPE Tokens: [['This</w>'], ['is</w>'], ['a', '</w>'], ['first</w>'], ['t', 'e', 's', 't</w>'], ['document</w>'], ['.']]

Original Text: 'Lowest common documents.'
BPE Tokens: [['L', 'o', 'w', 'e', 's', 't</w>'], ['c', 'o', 'm', 'm', 'on', '</w>'], ['documen', 't', 's</w>'], ['.']]
