# COLX 585 Trends in Computational Linguistics
##  Lab tutorial 2: BPE and BERT


# Byte Pair Encoding

Neural machine translation (NMT) models typically operate with a fixed vocabulary, but translation is an **open-vocabulary problem**: at test time we can encounter words that are not in the training vocabulary. **Byte pair encoding (BPE)** lets an NMT model handle an open vocabulary by encoding rare and unknown words as sequences of subword units. The intuition is that many word classes can be translated via units smaller than words.

More information in the paper:
Sennrich, Rico, Barry Haddow, and Alexandra Birch. "Neural Machine Translation of Rare Words with Subword Units." Proceedings of the 54th Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers). Vol. 1. 2016. http://www.aclweb.org/anthology/P16-1162

BPE learns an ordered list of merges, which is later used to split words, including out-of-vocabulary words, into subword units.

## Training algorithm

* Compute the frequencies of all words in the training corpus
* Start with a vocabulary that consists of the **singleton symbols** (characters) in the training corpus
* To get a vocabulary of $n$ merges, iterate $n$ times:
    1. Find the most frequent pair of adjacent symbols in the training data
    2. Add the pair to the list of merges
    3. Add the merged symbol to the vocabulary

![](./bpe_al.png)
Picture Courtesy: https://www.aclweb.org/anthology/P16-1162.pdf

We first count the frequency of each word in the corpus.

We then append a special stop token </w\> to the end of each word.

Next we split each word into characters. Initially, the tokens of a word are its characters plus the additional </w\> token. For example, the tokens for the word “low” are [“l”, “o”, “w”, “</w\>”], in order.

After counting all the words in the dataset, we get a vocabulary that maps each tokenized word to its count, such as

In [1]:
{'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

{'l o w </w>': 5,
 'l o w e r </w>': 2,
 'n e w e s t </w>': 6,
 'w i d e s t </w>': 3}
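
The steps above (count words, append the stop token, split into characters) can be sketched as follows. This is a minimal illustration under stated assumptions: the helper name `build_vocab` and the toy corpus string are made up for this example and are not part of the original notebook, and words are split on whitespace only.

```python
import collections

def build_vocab(corpus):
    """Count words and store each one as space-separated characters plus '</w>'.

    `build_vocab` is a hypothetical helper for illustration; it splits on whitespace only.
    """
    word_counts = collections.Counter(corpus.split())
    return {' '.join(word) + ' </w>': count for word, count in word_counts.items()}

# Toy corpus constructed so that the counts match the dictionary shown above.
toy_corpus = ' '.join(['low'] * 5 + ['lower'] * 2 + ['newest'] * 6 + ['widest'] * 3)
vocab = build_vocab(toy_corpus)
print(vocab)
```

A real corpus would also need decisions about casing and punctuation, which this sketch leaves out.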

In [2]:
import re, collections

def get_stats(train_data):
    """Compute frequencies of adjacent pairs of symbols.
    input: train_data,  e.g., {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
    output: pairs:  e.g., {('l', 'o'): 7, ('o', 'w'): 7, ('w', '</w>'): 5 ...}
    """
    pairs = collections.defaultdict(int) # initialize a counter 
    for word, freq in train_data.items():  # iterate over each tokenized word and its frequency
        symbols = word.split()            # split word by whitespace
        for i in range(len(symbols)-1):
            pairs[symbols[i],symbols[i+1]] += freq    # calculate frequencies of adjacent pairs
    return pairs            # return counter

def merge_vocab(best_pair, train_data_in):
    """
    merge the most frequent pair and update the training dataset 
    input: 
    best_pair: e.g., ('e', 's') the most frequent pair
    train_data_in:  e.g., {'l o w </w>': 5,..., 'w i d e s t </w>': 3} 
    output: 
    train_data_out: {'l o w </w>': 5, ..., 'n e w es t </w>': 6, 'w i d es t </w>': 3} merged dataset
    """ 
    train_data_out = {}    # create a new container to hold merged dataset
    
    bigram = ' '.join(best_pair)   # combine the most frequent pair,  e.g., 'e s'
    bigram = re.escape(bigram)     # add a backslash before special characters (here, the whitespace), e.g., 'e\ s'
    p = re.compile(r'(?<!\S)' + bigram + r'(?!\S)') # regular expression that matches the target bigram (the most frequent pair) only as whole symbols
    for word in train_data_in:
        w_out = p.sub(''.join(best_pair), word)   # ''.join(best_pair), e.g., 'es', word e.g., 'n e w e s t' ---> 'n e w es t' 
        # find the string that match the regular expression and merge the target pair in train_data, 
        # e.g., the best pair is (e, s), hence we merge 'es' in 'n e w e s t </w>', the new word is 'n e w es t </w>'
        train_data_out[w_out] = train_data_in[word]  # add new data point to container
    return train_data_out


In each iteration, we count the frequency of each pair of adjacent symbols, find the most frequent pair, and merge its two symbols into a single new token.
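
To make the counting step concrete, the pair frequencies for the initial toy data can be tabulated as below. This is a small sketch that uses `collections.Counter` instead of the `get_stats` function above; the variable name `pair_counts` is an assumption introduced for this example.

```python
import collections

train_data = {'l o w </w>': 5, 'l o w e r </w>': 2,
              'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

# Add each word's frequency to every pair of adjacent symbols it contains.
pair_counts = collections.Counter()
for word, freq in train_data.items():
    symbols = word.split()
    for left, right in zip(symbols, symbols[1:]):
        pair_counts[(left, right)] += freq

# Show the four most frequent pairs.
for pair, count in pair_counts.most_common(4):
    print(pair, count)
```

Three pairs share the top count of 9, so which one is merged first depends on how ties are broken.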


In [3]:
from IPython.display import display, Markdown, Latex

train_data = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}

bpe_codes = {}
bpe_codes_reverse = {}

num_merges = 10

for i in range(num_merges):
    display(Markdown("### Iteration {}".format(i + 1)))
    pairs = get_stats(train_data)      # use defined function
    best_pair = max(pairs, key=pairs.get)   # get the most frequent pair
    train_data = merge_vocab(best_pair, train_data)   # use defined function
    
    bpe_codes[best_pair] = i    # save merging history
    bpe_codes_reverse[best_pair[0] + best_pair[1]] = best_pair  # reversing dictionary
    
    print("new merge: {}".format(best_pair))
    print("train data: {}".format(train_data))

### Iteration 1

new merge: ('e', 's')
train data: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}


### Iteration 2

new merge: ('es', 't')
train data: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}


### Iteration 3

new merge: ('est', '</w>')
train data: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


### Iteration 4

new merge: ('l', 'o')
train data: {'lo w </w>': 5, 'lo w e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


### Iteration 5

new merge: ('lo', 'w')
train data: {'low </w>': 5, 'low e r </w>': 2, 'n e w est</w>': 6, 'w i d est</w>': 3}


### Iteration 6

new merge: ('n', 'e')
train data: {'low </w>': 5, 'low e r </w>': 2, 'ne w est</w>': 6, 'w i d est</w>': 3}


### Iteration 7

new merge: ('ne', 'w')
train data: {'low </w>': 5, 'low e r </w>': 2, 'new est</w>': 6, 'w i d est</w>': 3}


### Iteration 8

new merge: ('new', 'est</w>')
train data: {'low </w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3}


### Iteration 9

new merge: ('low', '</w>')
train data: {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'w i d est</w>': 3}


### Iteration 10

new merge: ('w', 'i')
train data: {'low</w>': 5, 'low e r </w>': 2, 'newest</w>': 6, 'wi d est</w>': 3}


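As the output shows, in the first iteration the pair (“e”, “s”) occurs 6 + 3 = 9 times, which is the highest count. The pairs (“s”, “t”) and (“t”, “</w\>”) also occur 9 times; `max` returns the first of the tied pairs it encounters. We merge “e” and “s” into a new token “es”.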

```
##Iteration 1
new merge: ('e', 's')
train data: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w es t </w>': 6, 'w i d es t </w>': 3}

```

In the second iteration, the tokens “es” and “t” occur together 6 + 3 = 9 times, which is the highest count, so we merge them into a new token “est”.

```
##Iteration 2
new merge: ('es', 't')
train data: {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w est </w>': 6, 'w i d est </w>': 3}
```

In the third iteration, the pair “est” and “</w\>” is the most frequent (again 6 + 3 = 9 occurrences), so we merge it into “est</w\>”.

The stop token “</w\>” is important. Without it, a token “st” could come from the start of the word “star” or from the end of the word “newest”, and these two uses are quite different. With it, the token “st</w\>” can only occur at the end of a word, so the model immediately knows that it belongs to a word like “newest” and not to “star”.
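
The effect of the stop token can be seen in a small sketch. The helpers `to_symbols` and `merge_pair` are hypothetical names introduced only for this illustration, and the two merges applied here are chosen by hand rather than learned.

```python
def to_symbols(word, end_marker='</w>'):
    """Split a word into characters, appending the end-of-word marker if one is given."""
    symbols = list(word)
    if end_marker:
        symbols.append(end_marker)
    return symbols

def merge_pair(symbols, pair):
    """Merge every adjacent occurrence of `pair` into a single symbol."""
    merged, i = [], 0
    while i < len(symbols):
        if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == pair:
            merged.append(symbols[i] + symbols[i + 1])
            i += 2
        else:
            merged.append(symbols[i])
            i += 1
    return merged

for word in ['newest', 'star']:
    # Without a marker, the same token 'st' appears in both words.
    plain = merge_pair(to_symbols(word, end_marker=None), ('s', 't'))
    # With the marker, only a word-final 'st' can become 'st</w>'.
    marked = merge_pair(merge_pair(to_symbols(word), ('s', 't')), ('st', '</w>'))
    print(word, plain, marked)
```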

We can control the size of the vocabulary either by fixing the number of merge iterations or by stopping once a maximum number of tokens is reached.
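
Both stopping criteria can be sketched in one training function. This is a minimal sketch, not the notebook's own code: `count_pairs`, `apply_merge`, `learn_bpe`, and the parameters `max_merges` and `max_vocab_size` are names assumed for this example, and the vocabulary size is taken to be the number of initial symbols plus the number of merges.

```python
import collections
import re

def count_pairs(train_data):
    """Frequency of each pair of adjacent symbols, weighted by word frequency."""
    pairs = collections.Counter()
    for word, freq in train_data.items():
        symbols = word.split()
        for pair in zip(symbols, symbols[1:]):
            pairs[pair] += freq
    return pairs

def apply_merge(pair, train_data):
    """Replace every whole-symbol occurrence of `pair` with the merged symbol."""
    pattern = re.compile(r'(?<!\S)' + re.escape(' '.join(pair)) + r'(?!\S)')
    return {pattern.sub(''.join(pair), word): freq for word, freq in train_data.items()}

def learn_bpe(train_data, max_merges=None, max_vocab_size=None):
    """Learn merges until a merge limit or a vocabulary-size limit is reached."""
    base_symbols = {s for word in train_data for s in word.split()}
    merges = []
    while True:
        if max_merges is not None and len(merges) >= max_merges:
            break
        if max_vocab_size is not None and len(base_symbols) + len(merges) >= max_vocab_size:
            break
        pairs = count_pairs(train_data)
        if not pairs:  # every word is already a single symbol
            break
        best_pair = max(pairs, key=pairs.get)
        train_data = apply_merge(best_pair, train_data)
        merges.append(best_pair)
    return merges

toy = {'l o w </w>': 5, 'l o w e r </w>': 2, 'n e w e s t </w>': 6, 'w i d e s t </w>': 3}
print(learn_bpe(toy, max_merges=3))
print(learn_bpe(toy, max_vocab_size=14))  # 11 initial symbols + 3 merges
```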


Finally, we get a byte-pair encoding vocabulary, `bpe_codes`, which maps each merged pair to the iteration at which it was learned.

In [5]:
print(bpe_codes)

{('e', 's'): 0, ('es', 't'): 1, ('est', '</w>'): 2, ('l', 'o'): 3, ('lo', 'w'): 4, ('n', 'e'): 5, ('ne', 'w'): 6, ('new', 'est</w>'): 7, ('low', '</w>'): 8, ('w', 'i'): 9}


Then we use this byte-pair encoding vocabulary to tokenize a given word.

## Apply BPE

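To tokenize a word, we split it into characters and then repeatedly apply the learned merges in the order in which they were learned: at each step, among the adjacent pairs present in the word, we merge the one with the lowest merge index in `bpe_codes`. The algorithm stops when no pair in the word is a learned merge.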

In [7]:
def get_pairs(word):
    """Return set of symbol pairs in a word.
    Word is represented as a tuple of symbols (symbols being variable-length strings).
    """
    pairs = set()
    prev_char = word[0]
    for char in word[1:]:
        pairs.add((prev_char, char))
        prev_char = char
    return pairs


def encode(orig):
    """Encode word based on list of BPE merge operations, which are applied consecutively"""

    word = tuple(orig) + ('</w>',)   # split into characters and add the </w> token
    display(Markdown("__word split into characters:__ <tt>{}</tt>".format(word)))

    pairs = get_pairs(word)    # use defined function to get symbol pairs, e.g., ('w', 'e'), ('o', 'w')...

    if not pairs:
        return orig

    iteration = 0
    while True:
        iteration += 1
        display(Markdown("__Iteration {}:__".format(iteration)))
        
        print("bigrams in the word: {}".format(pairs))
        # find the candidate pair for merging: the pair that was learned earliest (lowest merge index)
        bigram = min(pairs, key = lambda pair: bpe_codes.get(pair, float('inf')))  
        print("candidate for merging: {}".format(bigram))
        # stop when no pair in the word is a learned merge
        if bigram not in bpe_codes:
            display(Markdown("__Candidate not in BPE merges, algorithm stops.__"))
            break
        first, second = bigram
        # merge the candidate pair and update the word tuple
        new_word = []
        i = 0
        while i < len(word):
            try:
                j = word.index(first, i)
                new_word.extend(word[i:j])
                i = j
            except ValueError:  # `first` does not occur in the rest of the word
                new_word.extend(word[i:])
                break
            # if match the candidate pair, merge them
            if word[i] == first and i < len(word)-1 and word[i+1] == second:
                new_word.append(first+second)
                i += 2
            else:
                new_word.append(word[i])
                i += 1
        new_word = tuple(new_word)
        word = new_word
        print("word after merging: {}".format(word))
        if len(word) == 1:
            break
        else:
            pairs = get_pairs(word)

    # don't print end-of-word symbols
    if word[-1] == '</w>':
        word = word[:-1]
    elif word[-1].endswith('</w>'):
        word = word[:-1] + (word[-1].replace('</w>',''),)
   
    return word

The word **lowest** was not in the training data, but both **low** and the ending **est** were learned as merges, so the word splits as we would expect.

In [8]:
encode("lowest")

__word split into characters:__ <tt>('l', 'o', 'w', 'e', 's', 't', '</w>')</tt>

__Iteration 1:__

bigrams in the word: {('l', 'o'), ('o', 'w'), ('t', '</w>'), ('w', 'e'), ('e', 's'), ('s', 't')}
candidate for merging: ('e', 's')
word after merging: ('l', 'o', 'w', 'es', 't', '</w>')


__Iteration 2:__

bigrams in the word: {('l', 'o'), ('o', 'w'), ('t', '</w>'), ('w', 'es'), ('es', 't')}
candidate for merging: ('es', 't')
word after merging: ('l', 'o', 'w', 'est', '</w>')


__Iteration 3:__

bigrams in the word: {('w', 'est'), ('l', 'o'), ('o', 'w'), ('est', '</w>')}
candidate for merging: ('est', '</w>')
word after merging: ('l', 'o', 'w', 'est</w>')


__Iteration 4:__

bigrams in the word: {('l', 'o'), ('o', 'w'), ('w', 'est</w>')}
candidate for merging: ('l', 'o')
word after merging: ('lo', 'w', 'est</w>')


__Iteration 5:__

bigrams in the word: {('lo', 'w'), ('w', 'est</w>')}
candidate for merging: ('lo', 'w')
word after merging: ('low', 'est</w>')


__Iteration 6:__

bigrams in the word: {('low', 'est</w>')}
candidate for merging: ('low', 'est</w>')


__Candidate not in BPE merges, algorithm stops.__

('low', 'est')
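
For reference, the same procedure can be written more compactly without the notebook display calls. This is a minimal sketch rather than the tutorial's own function: the name `bpe_encode` and its `merges` list argument are assumptions introduced here, and the merge list is copied by hand from the `bpe_codes` output above.

```python
def bpe_encode(word, merges):
    """Apply learned merges to `word`, always choosing the earliest-learned applicable pair."""
    ranks = {pair: rank for rank, pair in enumerate(merges)}
    symbols = list(word) + ['</w>']
    while len(symbols) > 1:
        pairs = set(zip(symbols, symbols[1:]))
        best = min(pairs, key=lambda p: ranks.get(p, float('inf')))
        if best not in ranks:  # no learned merge applies any more
            break
        merged, i = [], 0
        while i < len(symbols):
            if i < len(symbols) - 1 and (symbols[i], symbols[i + 1]) == best:
                merged.append(symbols[i] + symbols[i + 1])
                i += 2
            else:
                merged.append(symbols[i])
                i += 1
        symbols = merged
    # Drop the end-of-word marker from the returned pieces.
    if symbols[-1] == '</w>':
        symbols = symbols[:-1]
    elif symbols[-1].endswith('</w>'):
        symbols[-1] = symbols[-1][:-len('</w>')]
    return tuple(symbols)

# Merge list copied from the training output earlier in this tutorial.
merges = [('e', 's'), ('es', 't'), ('est', '</w>'), ('l', 'o'), ('lo', 'w'),
          ('n', 'e'), ('ne', 'w'), ('new', 'est</w>'), ('low', '</w>'), ('w', 'i')]

for w in ['lowest', 'newest', 'lower']:
    print(w, bpe_encode(w, merges))
```

Words seen in training, such as "newest", collapse to a single piece, while "lower" keeps the unmerged characters "e" and "r".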

# SentencePiece Python module
This section introduces SentencePiece, a useful Python package for subword tokenization.

Ref: https://github.com/google/sentencepiece

## Install

In [9]:
!pip3 install sentencepiece


`SentencePiece` supports BPE (byte pair encoding) for subword segmentation via the `--model_type=bpe` flag.

`SentencePiece` generates the BPE vocabulary automatically during training.

In [10]:
import sentencepiece as spm

In [11]:
### generate the BPE vocabulary
'''
vocab_size=n sets the size of the vocabulary (the number of pieces).
'''
spm.SentencePieceTrainer.train('--input=bpe_text.txt --model_prefix=m_bpe --vocab_size=200 --model_type=bpe')
## generates two files: m_bpe.model and m_bpe.vocab

True

You can inspect the BPE vocabulary in the `m_bpe.vocab` file.
```
<unk>	0
<s>	0
</s>	0
▁t	-0
he	-1
▁a	-2
in	-3
▁s	-4
▁w	-5
▁the	-6
▁o	-7
re	-8
▁b	-9
▁m	-10
ou	-11
ed	-12
▁I	-13
.
.
.
```
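
The listing above can be read programmatically. The sketch below assumes the file has one piece per line followed by a tab and a number, as shown; `parse_vocab` and `sample_vocab` are names made up for this example, and the sample string is a hand-copied excerpt of the listing rather than the real file.

```python
# Hand-copied excerpt of the listing above (assumed tab-separated).
sample_vocab = "<unk>\t0\n<s>\t0\n</s>\t0\n▁t\t-0\nhe\t-1\n▁a\t-2\nin\t-3\n"

def parse_vocab(text):
    """Return a list of (piece, score) tuples from tab-separated vocabulary text."""
    entries = []
    for line in text.splitlines():
        if not line.strip():
            continue  # skip blank lines
        piece, score = line.rsplit('\t', 1)
        entries.append((piece, float(score)))
    return entries

entries = parse_vocab(sample_vocab)
for piece, score in entries:
    print(repr(piece), score)
```

For the real file, the same function could be applied to the contents of `m_bpe.vocab`, read with `encoding='utf-8'` so that the `▁` character is preserved.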

### Apply BPE

In [12]:
# check model 
sp_bpe = spm.SentencePieceProcessor()
sp_bpe.load('m_bpe.model')

print('*** BPE ***')
print(sp_bpe.encode_as_pieces('this is a test hello world'))
# '▁' (U+2581) marks the start of a word

*** BPE ***
['▁this', '▁is', '▁a', '▁t', 'es', 't', '▁he', 'll', 'o', '▁w', 'or', 'ld']


After tokenization, the output contains both whole **words** and **subwords**.

In this vocabulary, pieces that occur **at the start of a word** are prefixed with **‘▁’** (U+2581). Namely, `SentencePiece` uses `▁` to mark word boundaries instead of the `</w>` used in the previous example.

In the vocabulary of BERT, Devlin et al. use a `##` prefix to mark subwords that continue a word, that is, pieces that do not start a word.
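
The two boundary conventions can be compared by turning pieces back into text. This is a sketch with hand-written helpers: `join_sentencepiece` and `join_wordpiece` are hypothetical names, not library functions, and the `##` token list is invented for illustration rather than produced by a real BERT tokenizer.

```python
def join_sentencepiece(pieces):
    """Rebuild text from pieces in which '▁' (U+2581) marks the start of a word."""
    return ''.join(pieces).replace('\u2581', ' ').strip()

def join_wordpiece(tokens):
    """Rebuild text from tokens in which a leading '##' marks a word continuation."""
    words = []
    for token in tokens:
        if token.startswith('##') and words:
            words[-1] += token[2:]   # glue the continuation onto the previous word
        else:
            words.append(token)
    return ' '.join(words)

# Pieces copied from the SentencePiece output above.
sp_pieces = ['▁this', '▁is', '▁a', '▁t', 'es', 't', '▁he', 'll', 'o', '▁w', 'or', 'ld']
print(join_sentencepiece(sp_pieces))   # this is a test hello world

# Hypothetical '##'-style split, written by hand for illustration only.
wp_tokens = ['this', 'is', 'a', 'token', '##ization', 'test']
print(join_wordpiece(wp_tokens))       # this is a tokenization test
```

The marker sits on word-initial pieces in one scheme and on word-internal pieces in the other, but both carry enough information to restore the word boundaries.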

### Reference Reading:
* [Neural Machine Translation of Rare Words with Subword Units](https://www.aclweb.org/anthology/P16-1162.pdf)
* [Subword Regularization: Improving Neural Network Translation Models with Multiple Subword Candidates](https://arxiv.org/pdf/1804.10959.pdf)
* [SentencePiece: A simple and language independent subword tokenizer and detokenizer for Neural Text Processing](https://arxiv.org/pdf/1808.06226.pdf)
* [A New Algorithm for Data Compression](https://www.derczynski.com/papers/archive/BPE_Gage.pdf)
* https://leimao.github.io/blog/Byte-Pair-Encoding/
* http://ufal.mff.cuni.cz/~helcl/courses/npfl116/ipython/byte_pair_encoding.html