# 15.6 Subword Embedding

In English, words such as "helps", "helped", and "helping" are inflected forms of the same word "help". The relationship between "dog" and "dogs" is the same as that between "cat" and "cats", and the relationship between "boy" and "boyfriend" is the same as that between "girl" and "girlfriend". In other languages such as French and Spanish, many verbs have over 40 inflected forms, while in Finnish, a noun may have up to 15 cases. In linguistics, morphology studies word formation and word relationships. However, neither word2vec nor GloVe explores the internal structure of words.

## 15.6.1 fastText Model

Recall how words are represented in word2vec. In both the skip-gram model and the continuous bag of words (CBOW) model, different inflected forms of the same word are directly represented by different vectors with no shared parameters. To use morphological information, the fastText model proposed a subword embedding approach, where a subword is a character n-gram. Instead of learning word-level vector representations, fastText can be considered a subword-level skip-gram model, where each center word is represented by the sum of its subword vectors.

Let's illustrate how to obtain subwords for each center word in fastText using the word "where". First, add the special characters "<" and ">" at the beginning and end of the word to distinguish prefixes and suffixes from other subwords. Then, extract character n-grams from the word. For example, when $n=3$, we obtain all subwords of length 3: "<wh", "whe", "her", "ere", "re>", and the special subword "<where>".
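The extraction step can be sketched in a few lines of plain Python. This is a minimal illustration, not the fastText library's code; the helper name `char_ngrams` is hypothetical.

```python
def char_ngrams(word, n):
    """Return the character n-grams of `word` after adding the markers '<' and '>'."""
    marked = f'<{word}>'
    # Slide a window of width n over the marked word
    return [marked[i:i + n] for i in range(len(marked) - n + 1)]

word = 'where'
subwords = char_ngrams(word, 3)
special = f'<{word}>'  # the special subword: the whole marked word
print(subwords, special)
# ['<wh', 'whe', 'her', 'ere', 're>'] <where>
```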

In fastText, for any word $w$, denote by $\mathcal{G}_w$ the union of all its subwords of length between 3 and 6 and its special subword. The vocabulary is the union of the subwords of all words. Letting $\mathbf{z}_g$ be the vector of subword $g$ in the vocabulary, the vector $\mathbf{v}_w$ for word $w$ as a center word in the skip-gram model is the sum of its subword vectors:


$$
\mathbf{v}_w = \sum_{g \in \mathcal{G}_w} \mathbf{z}_g
$$

The rest of fastText is the same as the skip-gram model. Compared with the skip-gram model, the vocabulary in fastText is larger, resulting in more model parameters. Besides, to calculate the representation of a word, all its subword vectors have to be summed, leading to higher computational complexity.
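To make the sum concrete, here is a minimal sketch in plain Python, assuming $n$ ranges from 3 to 6 as above and using a small table of random 4-dimensional vectors as a stand-in for the learned subword vectors $\mathbf{z}_g$ (in real training these would be learned, not drawn at random). The names `subword_set`, `subword_vector`, and `center_vector` are hypothetical helpers, not part of the fastText library.

```python
import random

def subword_set(word, min_n=3, max_n=6):
    """Return G_w: the n-grams of '<word>' for min_n <= n <= max_n, plus '<word>'."""
    marked = f'<{word}>'
    grams = {marked[i:i + n]
             for n in range(min_n, max_n + 1)
             for i in range(len(marked) - n + 1)}
    grams.add(marked)  # the special subword
    return grams

dim = 4                 # toy embedding size (assumption)
rng = random.Random(0)  # fixed seed so the toy vectors are reproducible
z = {}                  # hypothetical subword table: subword g -> vector z_g

def subword_vector(g):
    """Look up z_g, creating a random vector the first time g is seen."""
    if g not in z:
        z[g] = [rng.uniform(-0.5, 0.5) for _ in range(dim)]
    return z[g]

def center_vector(word):
    """v_w: elementwise sum of z_g over all g in G_w."""
    v = [0.0] * dim
    # Sorting fixes the order in which new vectors are drawn from rng
    for g in sorted(subword_set(word)):
        v = [a + b for a, b in zip(v, subword_vector(g))]
    return v

print(len(subword_set('where')))  # 15
# "where" and "here" share these subwords, hence share their vectors z_g
print(sorted(subword_set('where') & subword_set('here')))
# ['ere', 'ere>', 'her', 'here', 'here>', 're>']
print(center_vector('where'))
```

Because "where" and "here" share several subwords, a vector $\mathbf{z}_g$ for a shared subword contributes to the representations of both words. The loop in `center_vector` also shows the extra cost mentioned above: every lookup of $\mathbf{v}_w$ sums $|\mathcal{G}_w|$ vectors instead of reading a single one.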


## 15.6.2 Byte Pair Encoding (BPE)

In fastText, all the extracted subwords have to be of the specified lengths, such as 3 to 6, so the vocabulary size cannot be predefined. To allow for variable-length subwords in a fixed-size vocabulary, we can apply a compression algorithm called byte pair encoding (BPE) to extract subwords.

BPE performs a statistical analysis of the training dataset to discover common symbols within a word, such as consecutive characters of arbitrary length. Starting from symbols of length 1, BPE iteratively merges the most frequent pair of consecutive symbols to produce new, longer symbols. Note that for efficiency, pairs crossing word boundaries are not considered. In the end, we can use such symbols as subwords to segment words. Byte pair encoding and its variants have been used for input representations in popular NLP pretraining models such as GPT-2 and RoBERTa. In the following, we illustrate how BPE works.

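We start from a vocabulary `symbols` that contains all lowercase English letters, a special end-of-word symbol `_`, and a special unknown symbol `[UNK]`.

```python
import collections

symbols = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
           'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
           '_', '[UNK]']
```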

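Since pairs crossing word boundaries are not considered, we only need a dictionary `raw_token_freqs` that maps each word (terminated by `_`) to its frequency in the dataset. Inserting a space between consecutive characters gives the initial segmentation, in which a space separates the symbols within a word.

```python
raw_token_freqs = {'fast_': 4, 'faster_': 3, 'tall_': 5, 'taller_': 4}
token_freqs = {}
for token, freq in raw_token_freqs.items():
    # Start from single characters separated by spaces
    token_freqs[' '.join(list(token))] = freq
print(token_freqs)
```

```
{'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}
```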

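`get_max_freq_pair` returns the most frequent pair of consecutive symbols within words, where each occurrence of a pair is weighted by the frequency of the word containing it.

```python
def get_max_freq_pair(token_freqs):
    pairs = collections.defaultdict(int)
    for token, freq in token_freqs.items():
        token_symbols = token.split()
        for i in range(len(token_symbols) - 1):
            # Key of `pairs` is a tuple of two consecutive symbols
            pairs[token_symbols[i], token_symbols[i + 1]] += freq
    # Key of `pairs` with the max value (ties go to the pair seen first)
    return max(pairs, key=pairs.get)
```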

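`merge_symbols` merges every occurrence of this pair into a new symbol and adds the new symbol to `symbols`.

```python
def merge_symbols(max_freq_pair, token_freqs, symbols):
    symbols.append(''.join(max_freq_pair))
    new_token_freqs = dict()
    for token, freq in token_freqs.items():
        old, new, i = token.split(), [], 0
        # Match whole symbols, not substrings: the pair ('a', 'l') must not
        # match inside 'ta l', as a plain str.replace would
        while i < len(old):
            if i < len(old) - 1 and (old[i], old[i + 1]) == max_freq_pair:
                new.append(old[i] + old[i + 1])
                i += 2
            else:
                new.append(old[i])
                i += 1
        new_token_freqs[' '.join(new)] = freq
    return new_token_freqs
```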

Now we run 10 iterations of merging. After each iteration we print the merged pair and the updated segmentation of every word.

```python
num_merges = 10
print('initial:', token_freqs)
for i in range(num_merges):
    max_freq_pair = get_max_freq_pair(token_freqs)
    token_freqs = merge_symbols(max_freq_pair, token_freqs, symbols)
    print(f'merge #{i + 1}:', max_freq_pair)
    print('  token_freqs:', token_freqs)
```

```
initial: {'f a s t _': 4, 'f a s t e r _': 3, 't a l l _': 5, 't a l l e r _': 4}
merge #1: ('t', 'a')
  token_freqs: {'f a s t _': 4, 'f a s t e r _': 3, 'ta l l _': 5, 'ta l l e r _': 4}
merge #2: ('ta', 'l')
  token_freqs: {'f a s t _': 4, 'f a s t e r _': 3, 'tal l _': 5, 'tal l e r _': 4}
merge #3: ('tal', 'l')
  token_freqs: {'f a s t _': 4, 'f a s t e r _': 3, 'tall _': 5, 'tall e r _': 4}
merge #4: ('f', 'a')
  token_freqs: {'fa s t _': 4, 'fa s t e r _': 3, 'tall _': 5, 'tall e r _': 4}
merge #5: ('fa', 's')
  token_freqs: {'fas t _': 4, 'fas t e r _': 3, 'tall _': 5, 'tall e r _': 4}
merge #6: ('fas', 't')
  token_freqs: {'fast _': 4, 'fast e r _': 3, 'tall _': 5, 'tall e r _': 4}
merge #7: ('e', 'r')
  token_freqs: {'fast _': 4, 'fast er _': 3, 'tall _': 5, 'tall er _': 4}
merge #8: ('er', '_')
  token_freqs: {'fast _': 4, 'fast er_': 3, 'tall _': 5, 'tall er_': 4}
merge #9: ('tall', '_')
  token_freqs: {'fast _': 4, 'fast er_': 3, 'tall_': 5, 'tall er_': 4}
merge #10: ('fast', '_')
  token_freqs: {'fast_': 4, 'fast er_': 3, 'tall_': 5, 'tall er_': 4}
```

After merging, `symbols` contains the 28 initial symbols followed by the 10 learned ones.

```python
print(symbols)
```

```
['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z', '_', '[UNK]', 'ta', 'tal', 'tall', 'fa', 'fas', 'fast', 'er', 'er_', 'tall_', 'fast_']
```

The words in the dataset are now segmented into the learned subwords:

```python
print(list(token_freqs.keys()))
```

```
['fast_', 'fast er_', 'tall_', 'tall er_']
```

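Given the learned symbols, `segment_BPE` segments each word greedily with the longest possible subwords from `symbols`. If the characters from some position onward cannot be matched, the rest of the word is replaced by a single `[UNK]`.

```python
def segment_BPE(tokens, symbols):
    outputs = []
    for token in tokens:
        start, end = 0, len(token)
        cur_output = []
        # Segment token with the longest possible subwords from symbols
        while start < len(token) and start < end:
            if token[start: end] in symbols:
                cur_output.append(token[start: end])
                start = end
                end = len(token)
            else:
                end -= 1
        if start < len(token):
            cur_output.append('[UNK]')
        outputs.append(' '.join(cur_output))
    return outputs
```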

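The words "tallest_" and "fatter_" do not appear in the dataset, yet they can be segmented with the learned symbols:

```python
tokens = ['tallest_', 'fatter_']
print(segment_BPE(tokens, symbols))
```

```
['tall e s t _', 'fa t t er_']
```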
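`segment_BPE` matches the longest symbol greedily. A common alternative, sketched here as an assumption rather than as this section's method, replays the learned merges on the characters of a new word in the order they were learned. The helper `apply_merges` is hypothetical, and `merges` is hard-coded from the ten merges produced by the loop above.

```python
def apply_merges(word, merges):
    """Segment `word` by replaying BPE merges in the order they were learned."""
    parts = list(word)  # start from single characters
    for left, right in merges:
        i, merged = 0, []
        while i < len(parts):
            # Merge only adjacent whole symbols equal to the current pair
            if i < len(parts) - 1 and parts[i] == left and parts[i + 1] == right:
                merged.append(left + right)
                i += 2
            else:
                merged.append(parts[i])
                i += 1
        parts = merged
    return ' '.join(parts)

# The ten merges learned above, in order
merges = [('t', 'a'), ('ta', 'l'), ('tal', 'l'), ('f', 'a'), ('fa', 's'),
          ('fas', 't'), ('e', 'r'), ('er', '_'), ('tall', '_'), ('fast', '_')]

print([apply_merges(w, merges) for w in ['tallest_', 'fatter_']])
# ['tall e s t _', 'fa t t er_']
```

On these two words both strategies give the same segmentation, but they are not guaranteed to agree in general.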
