In [1]:
import re
import collections

In [2]:
def format_word(text, space_token="_"):
    """Spell a word out as space-separated characters followed by an end marker.

    Args:
        text: The word (or punctuation mark) to split into characters.
        space_token: Marker appended after the last character.

    Returns:
        The characters of ``text`` joined by single spaces, then a space and
        ``space_token``; for example ``"hi"`` becomes ``"h i _"``.
    """
    spaced_chars = " ".join(text)
    return " ".join((spaced_chars, space_token))

In [3]:
def initialize_vocab(corpus):
    """Build the starting word vocabulary and the character counts.

    The corpus is lower-cased and every run of whitespace is collapsed into a
    single space. Words (runs of word characters and apostrophes) and the
    punctuation marks ``. , ! ? ;`` are then extracted; any other character is
    left out of the vocabulary but is still counted in ``tokens``.

    Args:
        corpus: Raw training text.

    Returns:
        A ``(vocab, tokens)`` tuple. ``vocab`` is a dict mapping each word, as
        formatted by ``format_word``, to its frequency, ordered from most to
        least frequent. ``tokens`` is a ``collections.Counter`` over every
        character of the normalised text, the space character included.
    """
    # Raw string: "\s" inside a plain literal is an invalid escape sequence.
    text = re.sub(r"\s+", " ", corpus.lower())
    all_words = re.findall(r"[\w']+|[.,!?;]", text)
    word_counts = collections.Counter(format_word(word) for word in all_words)
    # most_common() orders by descending count; equal counts keep the order in
    # which the words were first seen.
    vocab = dict(word_counts.most_common())
    tokens = collections.Counter(text)

    return vocab, tokens

In [4]:
corpus = """
Hi, my name is John and I love dog. His name is also John which is short for Johnny and he loves not only dogs but all kinds of animals as well.
Her name is Jane and it is funny how our names all start with the letter J and she loves cats. She is afraid of dogs.
"""

# Step-by-step version of initialize_vocab(): normalise the text, split it into
# words and punctuation, then count each formatted word.
# Raw string: "\s" inside a plain literal is an invalid escape sequence.
text = re.sub(r"\s+", " ", corpus.lower())
all_words = re.findall(r"[\w']+|[.,!?;]", text)
print("Text split:", all_words)

vocab = {}
for word in all_words:
    print("Before format:", word)
    word = format_word(word)
    print("After format:", word)
    vocab[word] = vocab.get(word, 0) + 1

# Most frequent words first; words with equal counts keep their first-seen order
vocab = dict(sorted(vocab.items(), key=lambda x: x[1], reverse=True))
# Character-level counts of the normalised text (spaces are counted as well)
tokens = collections.Counter(text)

Text split: ['hi', ',', 'my', 'name', 'is', 'john', 'and', 'i', 'love', 'dog', '.', 'his', 'name', 'is', 'also', 'john', 'which', 'is', 'short', 'for', 'johnny', 'and', 'he', 'loves', 'not', 'only', 'dogs', 'but', 'all', 'kinds', 'of', 'animals', 'as', 'well', '.', 'her', 'name', 'is', 'jane', 'and', 'it', 'is', 'funny', 'how', 'our', 'names', 'all', 'start', 'with', 'the', 'letter', 'j', 'and', 'she', 'loves', 'cats', '.', 'she', 'is', 'afraid', 'of', 'dogs', '.']
Before format: hi
After format: h i _
Before format: ,
After format: , _
Before format: my
After format: m y _
Before format: name
After format: n a m e _
Before format: is
After format: i s _
Before format: john
After format: j o h n _
Before format: and
After format: a n d _
Before format: i
After format: i _
Before format: love
After format: l o v e _
Before format: dog
After format: d o g _
Before format: .
After format: . _
Before format: his
After format: h i s _
Before format: name
After format: n a m e _
Before forma

In [5]:
# Word-frequency vocabulary built in the previous cell, most frequent word first
vocab

{'i s _': 6,
 'a n d _': 4,
 '. _': 4,
 'n a m e _': 3,
 'j o h n _': 2,
 'l o v e s _': 2,
 'd o g s _': 2,
 'a l l _': 2,
 'o f _': 2,
 's h e _': 2,
 'h i _': 1,
 ', _': 1,
 'm y _': 1,
 'i _': 1,
 'l o v e _': 1,
 'd o g _': 1,
 'h i s _': 1,
 'a l s o _': 1,
 'w h i c h _': 1,
 's h o r t _': 1,
 'f o r _': 1,
 'j o h n n y _': 1,
 'h e _': 1,
 'n o t _': 1,
 'o n l y _': 1,
 'b u t _': 1,
 'k i n d s _': 1,
 'a n i m a l s _': 1,
 'a s _': 1,
 'w e l l _': 1,
 'h e r _': 1,
 'j a n e _': 1,
 'i t _': 1,
 'f u n n y _': 1,
 'h o w _': 1,
 'o u r _': 1,
 'n a m e s _': 1,
 's t a r t _': 1,
 'w i t h _': 1,
 't h e _': 1,
 'l e t t e r _': 1,
 'j _': 1,
 'c a t s _': 1,
 'a f r a i d _': 1}

In [6]:
# Per-character counts of the normalised text (the space character is counted too)
tokens

Counter({' ': 59,
         's': 21,
         'n': 19,
         'a': 19,
         'o': 18,
         'e': 16,
         'h': 15,
         'i': 15,
         'l': 13,
         't': 11,
         'd': 9,
         'r': 7,
         'm': 6,
         'j': 5,
         'f': 5,
         'y': 4,
         '.': 4,
         'w': 4,
         'v': 3,
         'g': 3,
         'u': 3,
         'c': 2,
         ',': 1,
         'b': 1,
         'k': 1})

In [7]:
def get_bigram_counts(vocab):
    """Count adjacent symbol pairs across the vocabulary.

    Args:
        vocab: Mapping of space-separated symbol strings to their frequency.

    Returns:
        A dict mapping each ``(left, right)`` pair of neighbouring symbols to
        the summed frequency of the words containing it, ordered from the most
        to the least frequent pair.
    """
    tally = collections.defaultdict(int)
    for entry, freq in vocab.items():
        parts = entry.split()
        # Pair every symbol with its right-hand neighbour
        for left, right in zip(parts, parts[1:]):
            tally[(left, right)] += freq
    ranked = sorted(tally.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked)

In [8]:
# Step-by-step version of get_bigram_counts(): count every adjacent pair of
# symbols, weighted by how often the containing word occurs.
pairs = {}
for word, count in vocab.items():
    symbols = word.split()
    # Stop at the second-to-last symbol because each step looks at a window
    # of two neighbouring symbols: (symbols[i], symbols[i + 1])
    for i in range(len(symbols) - 1):
        pair = (symbols[i], symbols[i + 1])  # the bigram starting at position i
        pairs[pair] = pairs.get(pair, 0) + count

# Sort by count, most frequent bigram first
pairs = dict(sorted(pairs.items(), key=lambda x: x[1], reverse=True))
pairs

{('s', '_'): 16,
 ('e', '_'): 9,
 ('i', 's'): 7,
 ('a', 'n'): 6,
 ('n', 'd'): 5,
 ('d', '_'): 5,
 ('h', 'e'): 5,
 ('t', '_'): 5,
 ('.', '_'): 4,
 ('n', 'a'): 4,
 ('a', 'm'): 4,
 ('m', 'e'): 4,
 ('a', 'l'): 4,
 ('y', '_'): 4,
 ('r', '_'): 4,
 ('j', 'o'): 3,
 ('o', 'h'): 3,
 ('h', 'n'): 3,
 ('l', 'o'): 3,
 ('o', 'v'): 3,
 ('v', 'e'): 3,
 ('e', 's'): 3,
 ('d', 'o'): 3,
 ('o', 'g'): 3,
 ('l', 'l'): 3,
 ('l', '_'): 3,
 ('s', 'h'): 3,
 ('h', 'i'): 3,
 ('n', '_'): 2,
 ('g', 's'): 2,
 ('o', 'f'): 2,
 ('f', '_'): 2,
 ('i', '_'): 2,
 ('l', 's'): 2,
 ('h', '_'): 2,
 ('h', 'o'): 2,
 ('o', 'r'): 2,
 ('r', 't'): 2,
 ('n', 'n'): 2,
 ('n', 'y'): 2,
 ('e', 'r'): 2,
 ('i', 't'): 2,
 ('t', 'h'): 2,
 (',', '_'): 1,
 ('m', 'y'): 1,
 ('g', '_'): 1,
 ('s', 'o'): 1,
 ('o', '_'): 1,
 ('w', 'h'): 1,
 ('i', 'c'): 1,
 ('c', 'h'): 1,
 ('f', 'o'): 1,
 ('n', 'o'): 1,
 ('o', 't'): 1,
 ('o', 'n'): 1,
 ('n', 'l'): 1,
 ('l', 'y'): 1,
 ('b', 'u'): 1,
 ('u', 't'): 1,
 ('k', 'i'): 1,
 ('i', 'n'): 1,
 ('d', 's'): 1,
 ('n', 

In [9]:
def merge_vocab(pair, vocab_in):
    """Merge every occurrence of ``pair`` in the vocabulary into one symbol.

    Args:
        pair: Two adjacent symbols to join, e.g. ``("s", "_")``.
        vocab_in: Mapping of space-separated symbol strings to counts. It is
            not modified.

    Returns:
        A ``(vocab_out, (bigram, bytepair))`` tuple. ``vocab_out`` is a new
        dict whose keys have the pair merged and whose counts are unchanged.
        ``bigram`` is the regex-escaped ``"left right"`` string and
        ``bytepair`` is the two symbols concatenated.
    """
    bigram = re.escape(" ".join(pair))
    # The lookbehind/lookahead only allow whitespace or the string boundary
    # around the match, so the pair is matched as two whole symbols and never
    # as the tail or head of longer symbols.
    pattern = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
    bytepair = "".join(pair)
    # A callable replacement inserts `bytepair` literally. Passing it as a
    # plain replacement string would make re.sub interpret any backslash in it
    # as an escape sequence.
    vocab_out = {
        pattern.sub(lambda _match: bytepair, word): count
        for word, count in vocab_in.items()
    }
    return vocab_out, (bigram, bytepair)

In [10]:
def find_merges(vocab, tokens, num_merges):
    """Greedily learn up to ``num_merges`` merge rules.

    Each round picks the most frequent adjacent symbol pair (the first one
    encountered wins a tie), merges it everywhere in the vocabulary and
    records the rule.

    Args:
        vocab: Mapping of space-separated symbol strings to counts.
        tokens: Mapping of symbol to count. It is updated in place: each merged
            symbol is stored with the frequency of the pair that produced it.
            Counts of the existing symbols are left as they are.
        num_merges: Maximum number of merges to perform.

    Returns:
        A ``(vocab, tokens, merges)`` tuple, where ``merges`` is a list of
        ``(regex_pattern, merged_symbol)`` tuples in the order applied. Fewer
        than ``num_merges`` rules are returned when no pairs remain.
    """
    merges = []
    for _ in range(num_merges):
        pairs = get_bigram_counts(vocab)
        if not pairs:
            # Every word is already a single symbol. max() on an empty mapping
            # would raise ValueError, so stop early instead.
            break
        best_pair = max(pairs, key=pairs.get)
        best_count = pairs[best_pair]
        vocab, (bigram, bytepair) = merge_vocab(best_pair, vocab)
        # Store the same whitespace-delimited pattern that merge_vocab applied
        merges.append((r"(?<!\S)" + bigram + r"(?!\S)", bytepair))
        tokens[bytepair] = best_count
    return vocab, tokens, merges

In [11]:
# Pick the most frequent bigram. max() returns the first key with the highest
# count, so ties are resolved by the iteration order of `pairs`.
best_pair = max(pairs, key=pairs.get)
print("Best pair:", best_pair)
best_count = pairs[best_pair]
print("Best count:", best_count)

Best pair: ('s', '_')
Best count: 16


In [12]:
# Escape the "left right" string so that regex metacharacters in a symbol
# (for example ".") are matched literally.
bigram = re.escape(" ".join(best_pair))
print("Bigram:", bigram)
# Negative lookbehind r"(?<!\S)": the match must NOT be preceded by a non-whitespace character
# (whitespace or the start of the string is fine).
# Negative lookahead r"(?!\S)": the match must NOT be followed by a non-whitespace character
# (whitespace or the end of the string is fine).
# Together they only match the bigram when it stands as two whole symbols, not inside longer symbols.
p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
print("p:", p)
bytepair = "".join(best_pair)  # the merged symbol that replaces the bigram
print("bytepair:", bytepair)

Bigram: s\ _
p: re.compile('(?<!\\S)s\\ _(?!\\S)')
bytepair: s_


In [13]:
# Step-by-step version of merge_vocab(): apply the merge to every word
vocab_out = {}
for word in vocab:
    # Substitute each part of the word that matches the pattern `p`
    # with the merged symbol; words without a match are left unchanged
    w_out = p.sub(bytepair, word)
    print("")
    print("word:", word)
    print("w_out:", w_out)
    print("---")
    # The merged word keeps the count of the original word
    vocab_out[w_out] = vocab[word]


word: i s _
w_out: i s_
---

word: a n d _
w_out: a n d _
---

word: . _
w_out: . _
---

word: n a m e _
w_out: n a m e _
---

word: j o h n _
w_out: j o h n _
---

word: l o v e s _
w_out: l o v e s_
---

word: d o g s _
w_out: d o g s_
---

word: a l l _
w_out: a l l _
---

word: o f _
w_out: o f _
---

word: s h e _
w_out: s h e _
---

word: h i _
w_out: h i _
---

word: , _
w_out: , _
---

word: m y _
w_out: m y _
---

word: i _
w_out: i _
---

word: l o v e _
w_out: l o v e _
---

word: d o g _
w_out: d o g _
---

word: h i s _
w_out: h i s_
---

word: a l s o _
w_out: a l s o _
---

word: w h i c h _
w_out: w h i c h _
---

word: s h o r t _
w_out: s h o r t _
---

word: f o r _
w_out: f o r _
---

word: j o h n n y _
w_out: j o h n n y _
---

word: h e _
w_out: h e _
---

word: n o t _
w_out: n o t _
---

word: o n l y _
w_out: o n l y _
---

word: b u t _
w_out: b u t _
---

word: k i n d s _
w_out: k i n d s_
---

word: a n i m a l s _
w_out: a n i m a l s_
---

word: a s _
w

In [14]:
# Vocabulary after applying the single merge from the previous cell
vocab_out

{'i s_': 6,
 'a n d _': 4,
 '. _': 4,
 'n a m e _': 3,
 'j o h n _': 2,
 'l o v e s_': 2,
 'd o g s_': 2,
 'a l l _': 2,
 'o f _': 2,
 's h e _': 2,
 'h i _': 1,
 ', _': 1,
 'm y _': 1,
 'i _': 1,
 'l o v e _': 1,
 'd o g _': 1,
 'h i s_': 1,
 'a l s o _': 1,
 'w h i c h _': 1,
 's h o r t _': 1,
 'f o r _': 1,
 'j o h n n y _': 1,
 'h e _': 1,
 'n o t _': 1,
 'o n l y _': 1,
 'b u t _': 1,
 'k i n d s_': 1,
 'a n i m a l s_': 1,
 'a s_': 1,
 'w e l l _': 1,
 'h e r _': 1,
 'j a n e _': 1,
 'i t _': 1,
 'f u n n y _': 1,
 'h o w _': 1,
 'o u r _': 1,
 'n a m e s_': 1,
 's t a r t _': 1,
 'w i t h _': 1,
 't h e _': 1,
 'l e t t e r _': 1,
 'j _': 1,
 'c a t s_': 1,
 'a f r a i d _': 1}

In [15]:
merges = []
# Store the merge rule as a (regex pattern, merged symbol) tuple
merges.append((r"(?<!\S)" + bigram + r"(?!\S)", bytepair))  # store the merge rule
# Register the merged symbol with the frequency of the pair it came from;
# the counts of the individual characters are not changed
tokens[bytepair] = best_count

In [16]:
# Merge rules recorded so far, as (regex pattern, merged symbol) tuples
merges

[('(?<!\\S)s\\ _(?!\\S)', 's_')]

In [17]:
# The merged symbol now has its own entry; the other counts are unchanged
tokens

Counter({' ': 59,
         's': 21,
         'n': 19,
         'a': 19,
         'o': 18,
         'e': 16,
         's_': 16,
         'h': 15,
         'i': 15,
         'l': 13,
         't': 11,
         'd': 9,
         'r': 7,
         'm': 6,
         'j': 5,
         'f': 5,
         'y': 4,
         '.': 4,
         'w': 4,
         'v': 3,
         'g': 3,
         'u': 3,
         'c': 2,
         ',': 1,
         'b': 1,
         'k': 1})

In [18]:
def fit(text, num_merges):
    """Run the full pipeline: build the vocabulary, then learn the merges.

    Args:
        text: Training corpus.
        num_merges: Number of merge rounds passed on to ``find_merges``.

    Returns:
        A ``(characters, vocab, tokens, merges)`` tuple: the set of base
        characters, followed by the three values returned by ``find_merges``.
    """
    initial_vocab, char_counts = initialize_vocab(text)
    # Snapshot the base characters first: find_merges adds the merged symbols
    # to the same counter object.
    characters = set(char_counts.keys())
    merged_vocab, all_tokens, merge_rules = find_merges(
        initial_vocab, char_counts, num_merges
    )
    return characters, merged_vocab, all_tokens, merge_rules

In [19]:
# `text` is the normalised corpus from the walkthrough cell above; fit()
# lower-cases it and collapses whitespace again inside initialize_vocab().
# This rebinds the `vocab`, `tokens` and `merges` names used by the earlier cells.
characters, vocab, tokens, merges = fit(text, num_merges=10)

In [20]:
# Base characters captured by fit() before any merge was applied
characters

{' ',
 ',',
 '.',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'y'}

In [21]:
# Vocabulary after the 10 merges requested from fit()
vocab

{'is_': 6,
 'and_': 4,
 '._': 4,
 'nam e_': 3,
 'j o h n _': 2,
 'l o v e s_': 2,
 'd o g s_': 2,
 'a l l _': 2,
 'o f _': 2,
 's h e_': 2,
 'h i _': 1,
 ', _': 1,
 'm y _': 1,
 'i _': 1,
 'l o v e_': 1,
 'd o g _': 1,
 'h is_': 1,
 'a l s o _': 1,
 'w h i c h _': 1,
 's h o r t_': 1,
 'f o r _': 1,
 'j o h n n y _': 1,
 'h e_': 1,
 'n o t_': 1,
 'o n l y _': 1,
 'b u t_': 1,
 'k i n d s_': 1,
 'an i m a l s_': 1,
 'a s_': 1,
 'w e l l _': 1,
 'h e r _': 1,
 'j an e_': 1,
 'i t_': 1,
 'f u n n y _': 1,
 'h o w _': 1,
 'o u r _': 1,
 'nam e s_': 1,
 's t a r t_': 1,
 'w i t h _': 1,
 't h e_': 1,
 'l e t t e r _': 1,
 'j _': 1,
 'c a t s_': 1,
 'a f r a i d_': 1}

In [22]:
# Character counts plus one entry per merged symbol
tokens

Counter({' ': 59,
         's': 21,
         'n': 19,
         'a': 19,
         'o': 18,
         'e': 16,
         's_': 16,
         'h': 15,
         'i': 15,
         'l': 13,
         't': 11,
         'd': 9,
         'e_': 9,
         'r': 7,
         'is_': 7,
         'm': 6,
         'an': 6,
         'j': 5,
         'f': 5,
         'd_': 5,
         't_': 5,
         'y': 4,
         '.': 4,
         'w': 4,
         'and_': 4,
         '._': 4,
         'na': 4,
         'nam': 4,
         'v': 3,
         'g': 3,
         'u': 3,
         'c': 2,
         ',': 1,
         'b': 1,
         'k': 1})

In [23]:
# Learned merge rules, in the order they were applied
merges

[('(?<!\\S)s\\ _(?!\\S)', 's_'),
 ('(?<!\\S)e\\ _(?!\\S)', 'e_'),
 ('(?<!\\S)i\\ s_(?!\\S)', 'is_'),
 ('(?<!\\S)a\\ n(?!\\S)', 'an'),
 ('(?<!\\S)d\\ _(?!\\S)', 'd_'),
 ('(?<!\\S)t\\ _(?!\\S)', 't_'),
 ('(?<!\\S)an\\ d_(?!\\S)', 'and_'),
 ('(?<!\\S)\\.\\ _(?!\\S)', '._'),
 ('(?<!\\S)n\\ a(?!\\S)', 'na'),
 ('(?<!\\S)na\\ m(?!\\S)', 'nam')]

In [24]:
class BytePairEncoding:
    def __init__(self):
        self.characters = None
        self.vocab = None
        self.tokens = None
        self.merges = None
        
    def format_word(self, text, space_token="_"):
        return " ".join(list(text)) + " " + space_token
    
    def initialize_vocab(self, corpus):
        text = re.sub("\s+", " ", corpus.lower())
        all_words = re.findall(r"[\w']+|[.,!?;]", text)
        vocab = {}
        for word in all_words:
            word = self.format_word(word)
            vocab[word] = vocab.get(word, 0) + 1
        tokens = collections.Counter(text)
        vocab = dict(sorted(vocab.items(), key=lambda x: x[1], reverse=True))
        
        return vocab, tokens
    
    def get_bigram_counts(self, vocab):
        pairs = {}
        for word, count in vocab.items():
            symbols = word.split()
            for i in range(len(symbols) - 1):
                pair = (symbols[i], symbols[i + 1])
                pairs[pair] = pairs.get(pair, 0) + count
        pairs = dict(sorted(pairs.items(), key=lambda x: x[1], reverse=True))
        
        return pairs
    
    def merge_vocab(self, pair, vocab_in):
        vocab_out = {}
        bigram = re.escape(" ".join(pair))
        p = re.compile(r"(?<!\S)" + bigram + r"(?!\S)")
        bytepair = "".join(pair)
        for word in vocab_in:
            w_out = p.sub(bytepair, word)
            vocab_out[w_out] = vocab_in[word]
            
        return vocab_out, (bigram, bytepair)
    
    def find_merges(self, vocab, tokens, num_merges):
        merges = []
        for _ in range(num_merges):
            pairs = self.get_bigram_counts(vocab)
            best_pair = max(pairs, key=pairs.get)
            best_count = pairs[best_pair]
            vocab, (bigram, bytepair) = self.merge_vocab(best_pair, vocab)
            merges.append((r"(?<!\S)" + bigram + r"(?!\S)", bytepair))
            tokens[bytepair] = best_count
            
        return vocab, tokens, merges
    
    def fit(self, text, num_merges):
        vocab, tokens = self.initialize_vocab(text)
        self.characters = set(tokens.keys())
        self.vocab, self.tokens, self.merges = self.find_merges(vocab, tokens, num_merges)    