In [535]:
import re
from tkinter import filedialog 
from collections import Counter, defaultdict

In [470]:
f_path = filedialog.askopenfile(title="Select text file.").name
f = open(f_path)

In [536]:
full_text = f.read(500000)

In [537]:
def normalize(text):
    text = text.lower()

    text = text.replace('[inst]', ' <inst> ').replace('[/inst]', ' </inst> ')
    text = re.sub('[^a-z0-9<->/]+', ' ', text)

    text = re.sub('/s+', ' ', text)

    return text.strip()

In [538]:
texts = full_text.split()

words = []
for word in texts:
    word = normalize(word)
    words.extend(word.split())


word_freq = Counter(words)
# For debuging
# print(word_freq)

In [539]:
bpe_vocab = {}
for word, freq in word_freq.items():

    if word.startswith('<') and word.endswith('>'):
        bpe_vocab[(word,)] = freq

    else:
        chars = list(word)
        chars.append('</w>')
        
        
        bpe_vocab[tuple(chars)] = freq
    # bpe_vocab[freq] = tuple(chars)
# For debuging
print(bpe_vocab)

{('e', 'n', 'i', 'n', 'g', '</w>'): 1, ('i', 'f', '</w>'): 199, ('y', 'o', 'u', '</w>'): 1157, ('w', 'a', 'n', 't', '</w>'): 138, ('t', 'o', '</w>'): 2433, ('s', 'h', 'a', 'r', 'e', '</w>'): 114, ('<inst>',): 2934, ('i', '</w>'): 2543, ('t', 'h', 'i', 'n', 'k', '</w>'): 544, ('i', 't', '</w>'): 2335, ('i', 's', '</w>'): 1594, ('f', 'o', 'o', 'l', 'i', 's', 'h', '</w>'): 11, ('f', 'e', 'e', 'l', '</w>'): 1405, ('t', 'h', 'i', 's', '</w>'): 363, ('w', 'a', 'y', '</w>'): 248, ('w', 'h', 'e', 'n', '</w>'): 773, ('o', 't', 'h', 'e', 'r', 's', '</w>'): 139, ('h', 'a', 'v', 'e', '</w>'): 315, ('b', 'i', 'g', 'g', 'e', 'r', '</w>'): 10, ('p', 'r', 'o', 'b', 'l', 'e', 'm', 's', '</w>'): 42, ('</inst>',): 2943, ('c', 'a', 'n', '</w>'): 1765, ('b', 'e', '</w>'): 567, ('e', 'a', 's', 'y', '</w>'): 11, ('c', 'o', 'm', 'p', 'a', 'r', 'e', '</w>'): 2, ('o', 'u', 'r', '</w>'): 849, ('f', 'e', 'e', 'l', 'i', 'n', 'g', 's', '</w>'): 330, ('b', 'u', 't', '</w>'): 422, ('e', 'v', 'e', 'r', 'y', '</w>'): 6

In [540]:
def get_stats(vocab):
    pairs = defaultdict(int)

    # for freq, word in vocab.items():
    for word, freq in vocab.items():
        symbols = word
        for i in range(len(symbols)-1):
            # ****
            pairs[symbols[i], symbols[i+1]] += freq
    return pairs

In [541]:
def merge_vocab(pair, vocab):
    new_vocab = {}
    bigram = ''.join(pair)

    for word, freq in vocab.items():
        w = list(word)

        i = 0
        new_word = []

        while i < len(w):
            if i < len(w)-1 and (w[i], w[i+1]) == pair:
                new_word.append(bigram)
                i+=2
            else:
                new_word.append(w[i])
                i+=1

        new_vocab[tuple(new_word)] = freq

    return new_vocab

In [None]:
num_merges = 1000
merge_rules = []


for i in range(num_merges):
    pairs = get_stats(bpe_vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    bpe_vocab = merge_vocab(best, bpe_vocab)
    merge_rules.append(best)

    if i%100 == 0 or i==num_merges:
        print(f'Merging... {i}/{num_merges}')

# For debugging
# print(merge_rules)

Merging... 0/1000
Merging... 100/1000
Merging... 200/1000
Merging... 300/1000
Merging... 400/1000
Merging... 500/1000
Merging... 600/1000
Merging... 700/1000
Merging... 800/1000
Merging... 900/1000


In [None]:
def encode_word(word, merge_rules):

    if word.startswith('<') and word.endswith('>'):
        return [word]
    
    symbols = list(word)
    symbols.append('</w>')

    i = 0

    done = False

    for pair in merge_rules:
        i = 0
        while i < len(symbols)-1:
            if (symbols[i], symbols[i+1]) == pair:
                symbols[i:i+2] = [''.join(pair)]
            else:
                i+=1
    return symbols

# See an example
print(encode_word('keeping',merge_rules))

['keeping</w>']


In [None]:
# Encode whole of the text:
encoded_text = []
for word in normalize(full_text):
    encodeds = encode_word(word, merge_rules)
    for encoded in encodeds:
        encoded_text.append(encoded)

# For debugging
# print(encoded_text)