In [34]:
import re
from tkinter import filedialog 
from collections import Counter, defaultdict

In [35]:
f_path = filedialog.askopenfile(title="Select text file.").name
f = open(f_path)

In [36]:
full_text = f.read(5000)

In [37]:
def normalize(text):
    text = text.lower()

    text = text.replace('[inst]', ' <inst> ').replace('[/inst]', ' </inst> ')
    text = re.sub('[^a-z0-9<->/]+', ' ', text)

    text = re.sub('/s+', ' ', text)

    return text.strip()

In [38]:
texts = full_text.split()

words = []
for word in texts:
    word = normalize(word)
    words.extend(word.split())


word_freq = Counter(words)
# For debuging
# print(word_freq)

In [39]:
bpe_vocab = {}
for word, freq in word_freq.items():

    if word.startswith('<') and word.endswith('>'):
        bpe_vocab[(word,)] = freq

    else:
        chars = list(word)
        chars.append('</w>')
        
        
        bpe_vocab[tuple(chars)] = freq
        # bpe_vocab[freq] = tuple(chars)
# For debuging
# print(bpe_vocab.hea)

In [40]:
def get_stats(vocab):
    pairs = defaultdict(int)

    # for freq, word in vocab.items():
    for word, freq in vocab.items():
        symbols = word
        for i in range(len(symbols)-1):
            # ****
            pairs[symbols[i], symbols[i+1]] += freq
    return pairs

In [41]:
def merge_vocab(pair, vocab):
    new_vocab = {}
    bigram = ''.join(pair)

    for word, freq in vocab.items():
        w = list(word)

        i = 0
        new_word = []

        while i < len(w):
            if i < len(w)-1 and (w[i], w[i+1]) == pair:
                new_word.append(bigram)
                i+=2
            else:
                new_word.append(w[i])
                i+=1

        new_vocab[tuple(new_word)] = freq

    return new_vocab

In [12]:
num_merges = 100000
merge_rules = []


for i in range(num_merges):
    pairs = get_stats(bpe_vocab)
    if not pairs:
        break
    best = max(pairs, key=pairs.get)
    bpe_vocab = merge_vocab(best, bpe_vocab)
    merge_rules.append(best)

    if i%200 == 0 or i==num_merges:
        # print(len(list(new_bpe_vocab)))
        print(f'Merging... {i}/{num_merges}')

# For debugging
# print(merge_rules)

Merging... 0/100000
Merging... 200/100000
Merging... 400/100000
Merging... 600/100000
Merging... 800/100000


In [13]:
def encode_word(word, merge_rules):

    if word.startswith('<') and word.endswith('>'):
        return [word]
    
    word = normalize(word)

    symbols = list(word)
    symbols.append('</w>')

    i = 0

    for pair in merge_rules:
        i = 0
        while i < len(symbols)-1:
            if (symbols[i], symbols[i+1]) == pair:
                symbols[i:i+2] = [''.join(pair)]
            else:
                i+=1
    return symbols

# See an example
print(encode_word('hello there i am mohammad',merge_rules))

['hell', 'o', ' ', 'ther', 'e', ' ', 'i', ' ', 'a', 'm', ' ', 'mo', 'ha', 'm', 'ma', 'd</w>']


In [32]:
def buid_vocab(merge_rules):
    pairs = []
    for pair in merge_rules:
        pairs.append(''.join(pair))

    token_to_id = {}
    for id, tok in enumerate(pairs):
        token_to_id[tok] = id

    id_to_token = {}
    for id, tok in enumerate(pairs):
        token_to_id[id] = tok

    return token_to_id, id_to_token

tok2id, id2tok = buid_vocab(merge_rules)

In [None]:
# Encode whole of the text:
encoded_text = []
for word in normalize(full_text):
    encodeds = encode_word(word, merge_rules)
    for encoded in encodeds:
        encoded_text.append(encoded)

# For debugging
# print(encoded_text)