In [1]:
import re, collections
from IPython.display import display, Markdown, Latex
from statsmodels.formula.api import ordinal_gee

In [2]:
# Number of BPE merge operations to learn from the toy corpus below.
num_merges = 10

# Toy training corpus: each word is spelled as space-separated characters and
# mapped to how often it occurs.
# NOTE(review): every key ends with a trailing space, which presumably stood in
# for an explicit end-of-word marker (e.g. '</w>'). `str.split()` in `get_stats`
# discards it, so word boundaries are not part of the learned pairs — confirm.
dictionary = dict([
    ('l o w ', 5),
    ('l o w e r ', 2),
    ('n e w e s t ', 6),
    ('w i d e s t ', 3),
])

In [3]:
def get_stats(dictionary):
    """Count adjacent symbol pairs across the corpus, weighted by word frequency.

    Args:
        dictionary: mapping from a space-separated symbol sequence to its count.

    Returns:
        A ``collections.defaultdict(int)`` keyed by ``(left, right)`` symbol
        tuples, in first-seen order. The counts are also printed.
    """
    pair_counts = collections.defaultdict(int)
    for symbol_seq, count in dictionary.items():
        tokens = symbol_seq.split()
        # Walk each neighbouring (left, right) pair of symbols in this word.
        for left, right in zip(tokens, tokens[1:]):
            pair_counts[(left, right)] += count
    print('현재 pair들의 빈도수 :', dict(pair_counts))
    return pair_counts

def merge_dictionary(pair, v_in):
    """Return a new dictionary in which every standalone occurrence of ``pair`` is fused.

    Args:
        pair: ``(left, right)`` symbols; each adjacent ``"left right"`` occurrence
            becomes the single symbol ``left + right``.
        v_in: mapping from a space-separated symbol sequence to its frequency.

    Returns:
        A new dict with the merged keys. If two input keys end up with the same
        segmentation, their frequencies are summed rather than overwritten.
    """
    v_out = {}
    merged = ''.join(pair)
    bigram = re.escape(' '.join(pair))
    # The lookarounds require whitespace (or a string edge) on both sides, so
    # only whole symbols match: 'e s' inside 'e st' or 'xe s' is left alone.
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)')
    for word, freq in v_in.items():
        # Use a function replacement so backslashes in `merged` are inserted
        # literally instead of being interpreted as re template escapes.
        w_out = p.sub(lambda _match: merged, word)
        v_out[w_out] = v_out.get(w_out, 0) + freq
    return v_out

# Learn the merge table by repeatedly fusing the most frequent adjacent pair.
# bpe_codes maps each merged pair to its rank (0 = learned first); `encode`
# uses that rank as merge priority. bpe_codes_reverse maps the merged symbol
# back to the pair it was built from.
bpe_codes = {}
bpe_codes_reverse = {}

for i in range(num_merges):
    display(Markdown('### Iteration {}'.format(i + 1)))
    pairs = get_stats(dictionary)
    if not pairs:
        # Every word is already a single symbol; max() on an empty dict would
        # raise ValueError, so stop before num_merges is reached.
        print('no more pairs to merge, stopping early')
        break
    # Ties go to the pair seen first, since max() keeps the first maximum.
    best = max(pairs, key=pairs.get)
    dictionary = merge_dictionary(best, dictionary)

    bpe_codes[best] = i
    bpe_codes_reverse[best[0] + best[1]] = best

    print('new merge: {}'.format(best))
    print('dictionary: {}'.format(dictionary))

### Iteration 1

현재 pair들의 빈도수 : {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 8, ('e', 'r'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('e', 's'): 9, ('s', 't'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'e'): 3}
new merge: ('e', 's')
dictionary: {'l o w ': 5, 'l o w e r ': 2, 'n e w es t ': 6, 'w i d es t ': 3}


### Iteration 2

현재 pair들의 빈도수 : {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 2, ('e', 'r'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'es'): 6, ('es', 't'): 9, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'es'): 3}
new merge: ('es', 't')
dictionary: {'l o w ': 5, 'l o w e r ': 2, 'n e w est ': 6, 'w i d est ': 3}


### Iteration 3

현재 pair들의 빈도수 : {('l', 'o'): 7, ('o', 'w'): 7, ('w', 'e'): 2, ('e', 'r'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('l', 'o')
dictionary: {'lo w ': 5, 'lo w e r ': 2, 'n e w est ': 6, 'w i d est ': 3}


### Iteration 4

현재 pair들의 빈도수 : {('lo', 'w'): 7, ('w', 'e'): 2, ('e', 'r'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('lo', 'w')
dictionary: {'low ': 5, 'low e r ': 2, 'n e w est ': 6, 'w i d est ': 3}


### Iteration 5

현재 pair들의 빈도수 : {('low', 'e'): 2, ('e', 'r'): 2, ('n', 'e'): 6, ('e', 'w'): 6, ('w', 'est'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('n', 'e')
dictionary: {'low ': 5, 'low e r ': 2, 'ne w est ': 6, 'w i d est ': 3}


### Iteration 6

현재 pair들의 빈도수 : {('low', 'e'): 2, ('e', 'r'): 2, ('ne', 'w'): 6, ('w', 'est'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('ne', 'w')
dictionary: {'low ': 5, 'low e r ': 2, 'new est ': 6, 'w i d est ': 3}


### Iteration 7

현재 pair들의 빈도수 : {('low', 'e'): 2, ('e', 'r'): 2, ('new', 'est'): 6, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('new', 'est')
dictionary: {'low ': 5, 'low e r ': 2, 'newest ': 6, 'w i d est ': 3}


### Iteration 8

현재 pair들의 빈도수 : {('low', 'e'): 2, ('e', 'r'): 2, ('w', 'i'): 3, ('i', 'd'): 3, ('d', 'est'): 3}
new merge: ('w', 'i')
dictionary: {'low ': 5, 'low e r ': 2, 'newest ': 6, 'wi d est ': 3}


### Iteration 9

현재 pair들의 빈도수 : {('low', 'e'): 2, ('e', 'r'): 2, ('wi', 'd'): 3, ('d', 'est'): 3}
new merge: ('wi', 'd')
dictionary: {'low ': 5, 'low e r ': 2, 'newest ': 6, 'wid est ': 3}


### Iteration 10

현재 pair들의 빈도수 : {('low', 'e'): 2, ('e', 'r'): 2, ('wid', 'est'): 3}
new merge: ('wid', 'est')
dictionary: {'low ': 5, 'low e r ': 2, 'newest ': 6, 'widest ': 3}


In [6]:
def get_pairs(word):
    """Return the set of adjacent symbol pairs in a word.

    Word is represented as a tuple of symbols (symbols being variable-length strings).
    A word with fewer than two symbols (including an empty one) has no pairs,
    so an empty set is returned.
    """
    # zip stops at the shorter input, so empty/single-symbol words are safe.
    return set(zip(word, word[1:]))


def encode(orig, end_of_word=''):
    """Encode a word by applying the learned BPE merges in priority order.

    The word is split into characters and ``end_of_word`` is appended. Then the
    adjacent pair with the lowest rank in the module-level ``bpe_codes`` is
    merged, repeatedly, until the best remaining pair is not a known merge or
    the word collapses to a single symbol. Intermediate steps are shown with
    ``display``/``print``.

    Args:
        orig: the word (string) to segment.
        end_of_word: boundary symbol appended before merging and stripped from
            the result. Defaults to ``''`` to match the dictionary above, whose
            keys carry no explicit marker.

    Returns:
        A tuple of subword symbols with the end-of-word marker removed, or
        ``orig`` itself when it yields no pairs (e.g. the empty string).
    """
    word = tuple(orig) + (end_of_word,)
    display(Markdown("__word split into characters:__ {}".format(word)))

    pairs = get_pairs(word)

    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        display(Markdown("__Iteration {}:__".format(iteration)))

        print("bigrams in the word: {}".format(pairs))
        # Lowest rank = earliest learned merge; unknown pairs rank as +inf.
        bigram = min(pairs, key=lambda pair: bpe_codes.get(pair, float('inf')))
        print("candidate for merging: {}".format(bigram))
        if bigram not in bpe_codes:
            display(Markdown("__Candidate not in BPE merges, algorithm stops.__"))
            break
        first, second = bigram
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
            except ValueError:
                # `first` does not occur again: copy the remaining tail as-is.
                new_word.extend(word[i:])
                break
            new_word.extend(word[i:j])
            i = j

            if word[i] == first and i < len(word) - 1 and word[i + 1] == second:
                new_word.append(first + second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        word = tuple(new_word)
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        pairs = get_pairs(word)

    # Don't print end-of-word symbols: drop a standalone marker, or strip it
    # as a suffix when it was merged into the last symbol.
    if word[-1] == end_of_word:
        word = word[:-1]
    elif end_of_word and word[-1].endswith(end_of_word):
        word = word[:-1] + (word[-1][:-len(end_of_word)],)

    return word


In [7]:
encode('loki')

__word split into characters:__ ('l', 'o', 'k', 'i', '')

__Iteration 1:__

bigrams in the word: {('k', 'i'), ('o', 'k'), ('l', 'o'), ('i', '')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'k', 'i', '')


__Iteration 2:__

bigrams in the word: {('lo', 'k'), ('i', ''), ('k', 'i')}
candidate for merging: ('lo', 'k')


__Candidate not in BPE merges, algorithm stops.__

('lo', 'k', 'i')

In [8]:
encode('lowest')

__word split into characters:__ ('l', 'o', 'w', 'e', 's', 't', '')

__Iteration 1:__

bigrams in the word: {('l', 'o'), ('s', 't'), ('t', ''), ('o', 'w'), ('e', 's'), ('w', 'e')}
candidate for merging: ('e', 's')
word after merging: ('l', 'o', 'w', 'es', 't', '')


__Iteration 2:__

bigrams in the word: {('l', 'o'), ('w', 'es'), ('o', 'w'), ('t', ''), ('es', 't')}
candidate for merging: ('es', 't')
word after merging: ('l', 'o', 'w', 'est', '')


__Iteration 3:__

bigrams in the word: {('w', 'est'), ('l', 'o'), ('est', ''), ('o', 'w')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'w', 'est', '')


__Iteration 4:__

bigrams in the word: {('w', 'est'), ('lo', 'w'), ('est', '')}
candidate for merging: ('lo', 'w')
word after merging: ('low', 'est', '')


__Iteration 5:__

bigrams in the word: {('low', 'est'), ('est', '')}
candidate for merging: ('low', 'est')


__Candidate not in BPE merges, algorithm stops.__

('low', 'est')