In [1]:
import json
with open("input.txt", 'r', encoding='utf-8') as f:
    text = f.read()

In [2]:
corpus = list(set(text.split()))
vocab = set()
pair_scores = {}

for word in corpus:
    split = []
    for idx, letter in enumerate(word):
        split.append(letter if idx == 0 else "##" + letter)
    vocab.update(split)

In [3]:
from tqdm import tqdm

def get_word_from_token(token):
    if token[0] == "#":
        return token[2:]
    return token

def combine_pair(pair):
    ''' ex. 'h' and '##e' become 'he' '''
    return pair[0] + get_word_from_token(pair[1])

def get_frequency_from_token(token):
    word = get_word_from_token(token)
    frequency = 0

    for item in corpus:
        for i in range(len(item) - len(token) + 1):
            if word == item[i : i + len(token)]:
                frequency += 1
    return frequency

def get_score_from_pair(pair):
    ''' The score of a pair is freq(pair) / (freq(pair[0]) * freq(pair[1]))'''
    word1 = get_word_from_token(pair[0])
    word2 = get_word_from_token(pair[1])
    combined_word = get_word_from_token(combine_pair(pair))

    combined_freq = get_frequency_from_token(combined_word)
    # freq1 = get_frequency_from_token(word1)
    # freq2 = get_frequency_from_token(word2)

    return combined_freq

def find_best_pair(pairs):
    global pair_scores
    best_pair, best_score = 0, 0

    for pair in pairs:
        if pair not in pair_scores:
            pair_scores[pair] = get_score_from_pair(pair)
        
        score = pair_scores[pair]
              
        if score >= best_score:
            best_score = score
            best_pair = pair
    
    return best_pair, best_score

In [4]:
def token_equality(token, index, subword):
    if not get_word_from_token(token) == subword:
        return False
    if index == 0 and token[0] == "#":
        return False
    if index != 0 and token[0] != "#":
        return False
    return True

def tokenize(word):
    tokenized = []
    ptr = 0

    while ptr < len(word):
        remaining_word = word[ptr:]

        for i in range(len(remaining_word), 0, -1):
            subword = remaining_word[:i]
            found = False
            for token in vocab:
                if token_equality(token, ptr, subword):
                    tokenized.append(token)
                    ptr += len(subword)
                    found = True
                    break
            if found:
                break
    
    return tokenized

def get_pairs_from_vocab(vocab):
    pairs = set()
    for word in corpus:
        tokenized_word = tokenize(word)
        pairs.update([(tokenized_word[i], tokenized_word[i+1]) for i in range(len(tokenized_word) - 1)])
    return pairs

In [7]:
for _ in tqdm(range(500)):
    pairs = get_pairs_from_vocab(vocab)
    best_pair, best_score = find_best_pair(pairs)
    best_pair = combine_pair(best_pair)
    vocab.add(best_pair)

with open("tokenizer_tokens", 'w') as f:
    json.dump(list(vocab), f)

100%|██████████| 500/500 [2:46:49<00:00, 20.02s/it]  
