Author : Clement Ong 

In [None]:
import nltk
nltk.download('stopwords')
nltk.download('punkt')

# basic pre-processing
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.stem import PorterStemmer

import re
from collections import Counter, defaultdict
from operator import itemgetter
import numpy as np  


In [None]:
# Module-level result holders. result_vocab and result_merge are reassigned by
# the learn_bpe(...) call further down; result_words is not referenced again
# in the visible code.
result_words = []
result_vocab = []
result_merge = []
# Toy corpus: word -> integer count (presumably the word's frequency in the
# corpus -- learn_bpe weights pair counts by these values).
toy_vocab = {'newest':6, 'low':5, 'widest':3, 'lower':2, 'longer':1}
# NOTE(review): looks like the expected symbol vocabulary for toy_vocab
# (longest symbols first, "_" marking end of word). Nothing visible here
# compares it against learn_bpe's output -- confirm before relying on it.
toy_result_vocab = ['longer_', 'newest_', 'widest_', 'ewest_', 'idest_', 'lower_', 'dest_', 'est_', 'ger_', 'low_', 'er_', 'lon', 'low', 'es', 'ew', 'lo', 'r_', 't_', 'w_',
                    'd', 'e', 'g', 'i', 'l', 'n', 'o', 's', 'w']
# NOTE(review): looks like the expected merge rules for toy_vocab, each
# written as "left right" in the order they would be learned -- confirm.
toy_result_merge = ['e s', 'es t_', 'l o', 'e w', 'ew est_', 'n ewest_', 'lo w_', 'd est_', 'e r_', 'i dest_', 'w idest_',
                    'lo w', 'low er_', 'g er_', 'lo n', 'lon ger_']

# Shared pre-processing resources used by preprocess() and bpe_tokenize():
# a set of English stop words (set membership is O(1)) and one Porter stemmer.
stop_words = set(stopwords.words('english'))
stemmer = PorterStemmer()

def preprocess(review):
    """Turn a raw review string into a list of stemmed, lower-cased tokens.

    ``<br />`` markers are replaced by a space, every run of non-letter
    characters collapses to a single space, the text is split with NLTK's
    ``word_tokenize``, English stop words are dropped and the remaining words
    are Porter-stemmed.

    Args:
        review: Raw review text, possibly containing ``<br />`` tags.

    Returns:
        list[str]: Stemmed tokens in their original order.
    """
    # Replace the tag with a space rather than "": otherwise the words on
    # either side of "<br />" are glued together into one bogus token.
    text = re.sub('[^a-zA-Z]+', ' ', review.replace("<br />", " "))

    tokens = []
    for word in word_tokenize(text):
        lowered = word.lower()
        # The stop-word list is lower-case, so the membership test must use
        # the lower-cased form; testing the raw token lets "The", "And", ...
        # slip through the filter.
        if lowered not in stop_words:
            tokens.append(stemmer.stem(lowered))
    return tokens

def sort_bpe_vocab(vocab):
    """Return the items of ``vocab`` ordered longest first.

    Items of equal length keep their natural ascending order, because the
    second sort is stable and runs on an already ascending list.

    Args:
        vocab: Iterable of sized, mutually comparable items (strings or
            tuples in this file).

    Returns:
        list: A new list; ``vocab`` itself is not modified.
    """
    ordered = list(vocab)
    # Pass 1: natural ascending order, which becomes the tie-break.
    ordered.sort()
    # Pass 2: stable sort by length, longest first.
    ordered.sort(key=len, reverse=True)
    return ordered


In [None]:
def build_vocab(vocab_dict):
    """Segment every word into characters and collect the initial symbols.

    Each word is rewritten as its characters joined by single spaces with an
    end-of-word marker ``_`` glued to the last character
    (``'low'`` -> ``'l o w_'``).

    Args:
        vocab_dict: Mapping of word -> count.

    Returns:
        tuple: ``(segmented, symbols)`` where ``segmented`` maps each
        space-separated word to its original count and ``symbols`` is a list
        of the unique starting symbols (single characters plus the
        ``<last char>_`` endings) in no particular order.
    """
    segmented = {
        " ".join(word) + "_": count for word, count in vocab_dict.items()
    }

    # Everything before the final two characters is "char space char space ...";
    # the final two characters are the last letter fused with the "_" marker.
    inner_chars = [
        ch for entry in segmented for ch in entry[:-2] if ch != ' '
    ]
    end_symbols = [entry[-2:] for entry in segmented]

    return segmented, list(set(inner_chars + end_symbols))


def get_stats(vocab_dict):
    """Count adjacent symbol pairs, weighted by each word's count.

    Args:
        vocab_dict: Mapping of space-separated symbol string -> count, e.g.
            ``{'l o w_': 5}``.

    Returns:
        defaultdict: ``(left, right)`` symbol tuple -> summed count, with
        keys in order of first appearance.
    """
    pair_counts = defaultdict(int)
    for segmented, count in vocab_dict.items():
        parts = segmented.split(' ')
        # zip against the list shifted by one to walk neighbouring symbols.
        for neighbours in zip(parts, parts[1:]):
            pair_counts[neighbours] += count
    return pair_counts

def find_adjacent_char_merge(adj_vocab_dict):
    """Choose the symbol pair with the highest count.

    Args:
        adj_vocab_dict: Non-empty mapping of pair -> count, e.g.
            ``{('n', 'e'): 2, ('e', 'w'): 2, ('w', 'e'): 1}``.

    Returns:
        The key with the highest count. When several keys share that count,
        they are ordered with ``sort_bpe_vocab`` and the first one is
        returned.
    """
    candidates = list(adj_vocab_dict.keys())
    counts = np.array(list(adj_vocab_dict.values()))
    # Positions of every entry that reaches the maximum count.
    top_positions = np.where(counts == counts.max())[0]

    if len(top_positions) == 1:
        return candidates[top_positions[0]]

    # Several pairs tie for first place: break the tie deterministically.
    tied = [candidates[pos] for pos in top_positions]
    return sort_bpe_vocab(tied)[0]

def merge_vocab(pair, v_in):
    """Fuse every occurrence of ``pair`` into one symbol in all words.

    Args:
        pair: Two-element sequence ``(left, right)`` of symbols to merge.
        v_in: Mapping of space-separated symbol string -> count.

    Returns:
        dict: New mapping with ``"left right"`` replaced by ``"leftright"``
        wherever both are whole symbols; counts are carried over unchanged.
    """
    merged_symbol = ''.join(pair)
    # The lookarounds anchor the match on whitespace/string edges so only
    # whole symbols are merged, never the tail or head of a longer symbol.
    pattern = re.compile(
        r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)'
    )
    # A callable replacement inserts the merged symbol verbatim. Passing the
    # string itself would make ``re`` treat it as a template, so a backslash
    # inside a symbol would be read as an escape ("\n" -> newline,
    # "\d" -> error).
    return {
        pattern.sub(lambda _match: merged_symbol, word): count
        for word, count in v_in.items()
    }

def learn_bpe(vocab, num_vocab):
    """Learn a BPE symbol vocabulary and its merge rules.

    Starting from the single-character symbols produced by ``build_vocab``,
    the best pair (as chosen by ``find_adjacent_char_merge``) is merged
    repeatedly until the symbol list holds ``num_vocab`` entries or no
    adjacent pairs are left.

    Args:
        vocab: Mapping of word -> count.
        num_vocab: Target number of symbols; learning stops once reached.

    Returns:
        tuple: ``(symbols, merges)`` where ``symbols`` is ordered by
        ``sort_bpe_vocab`` and ``merges`` lists the learned rules as
        ``"left right"`` strings in the order they were learned.
    """
    segmented, symbols = build_vocab(vocab)
    merges = []

    while len(symbols) < num_vocab:
        pair_counts = get_stats(segmented)
        if not pair_counts:
            # Every word is already a single symbol: nothing left to merge.
            break
        best_pair = find_adjacent_char_merge(pair_counts)
        segmented = merge_vocab(best_pair, segmented)
        symbols.append("".join(best_pair))
        merges.append(" ".join(best_pair))

    return sort_bpe_vocab(symbols), merges

# Learn BPE on the toy vocabulary and overwrite the module-level result
# holders. learn_bpe stops early when no adjacent pairs remain, so
# num_vocab=5000 acts as an upper bound rather than a size guarantee.
result_vocab, result_merge = learn_bpe(vocab=toy_vocab, num_vocab=5000)
# Bare expression: as the last statement of a notebook cell it displays the
# number of learned symbols; it has no effect when run as a plain script.
len(result_vocab)

In [None]:
# bpe tokenizer based on learned vocab
def bpe_tokenize(sentence, bpe_merges, bpe_vocab):
    """Split ``sentence`` into BPE symbols using previously learned merges.

    The sentence is cleaned with ``preprocess`` (tag/non-letter removal,
    stop-word filtering, stemming). Each remaining word is broken into
    characters with an end-of-word ``_`` marker, the merge rules are applied
    in order, and every resulting symbol missing from ``bpe_vocab`` is
    replaced by ``"UNK"``. A word with no known symbol at all contributes a
    single ``"UNK"``.

    Args:
        sentence: Raw input text.
        bpe_merges: Merge rules as ``"left right"`` strings, in the order
            they were learned.
        bpe_vocab: Collection of known symbols.

    Returns:
        list[str]: Flat list of symbols (and ``"UNK"`` markers) for the
        whole sentence.
    """
    # Compile every rule once, outside the per-word loop. The pattern uses
    # the same whole-symbol boundaries as merge_vocab: without the trailing
    # (?!\S) a rule such as "e s" would also fire on "e st_" and create
    # symbols that training could never have produced.
    merge_rules = [
        (
            re.compile(r'(?<!\S)' + re.escape(pair) + r'(?!\S)'),
            pair.replace(" ", ""),
        )
        for pair in bpe_merges
    ]
    # Set membership is O(1); the learned vocabulary is passed in as a list.
    known_symbols = set(bpe_vocab)

    result_tokens = []
    for word in preprocess(sentence):
        segmented = " ".join(word) + "_"
        for pattern, merged in merge_rules:
            # Callable replacement: insert the merged symbol verbatim.
            segmented = pattern.sub(
                lambda _match, merged=merged: merged, segmented
            )

        symbols = segmented.split()
        if any(sym in known_symbols for sym in symbols):
            result_tokens.extend(
                sym if sym in known_symbols else "UNK" for sym in symbols
            )
        else:
            # Nothing recognisable in this word: collapse it to one marker.
            result_tokens.append("UNK")
    return result_tokens

# Demo: tokenize a sentence with the merges/vocabulary learned above.
# NOTE(review): 'fewest', 'gewest' and 'clement' look like deliberate
# out-of-vocabulary probes for the "UNK" handling -- confirm intent.
print(bpe_tokenize('fewest gewest makes me new clement', result_merge, result_vocab))