### A simple corpus
[From the Stanford NLP tutorial](https://youtu.be/tOMjTCO0htA?si=AyaI1zthwS1dFQgS)



In [1]:
# toy corpus; word frequencies: low x5, lowest x2, newer x6, wider x3, new x2
corpus = " ".join(
    ["low"] * 5 + ["lowest"] * 2 + ["newer"] * 6 + ["wider"] * 3 + ["new"] * 2
)

### Learner module

In [2]:
special_token = "#"
# _ can also be used instead of #


# get whitespace-separated tokens
# all words in text are whitespace separated; split() with no argument also
# collapses runs of spaces/tabs/newlines, so it never yields the empty tokens
# that split(" ") produces for e.g. a double space
tokens = corpus.split()

In [3]:
from collections import Counter

# tally how many times each word type occurs
token_counts = Counter()
token_counts.update(tokens)
print(token_counts)

Counter({'newer': 6, 'low': 5, 'wider': 3, 'lowest': 2, 'new': 2})


In [4]:
# The buffer is a scratch structure in which symbol pairs are counted and
# merged; what we ultimately care about is the vocabulary, not the buffer.
def create_buffer(token_counts):
    """Turn {word: count} into {"w o r d <special_token>": count}.

    Each word is spelled out as space-separated characters, followed by a
    space and the notebook-level ``special_token`` (mind the space!).
    """
    return {
        f"{' '.join(token_type)} {special_token}": freq
        for token_type, freq in token_counts.items()
    }

buffer = create_buffer(token_counts)
print(buffer)

{'l o w #': 5, 'l o w e s t #': 2, 'n e w e r #': 6, 'w i d e r #': 3, 'n e w #': 2}


In [5]:
# Three steps per merge:
#   1. count the frequency of every adjacent symbol pair
#   2. pick the most frequent pair ("symbol pair" in the literature)
#   3. merge it everywhere in the buffer

def get_pair_frequencies(buffer):
    """Return {(left, right): weighted count} over all buffer entries.

    Pairs are parts of tokens: if a token appears ``frequency`` times, so does
    every pair inside it. Pairs whose left symbol is the bare ``special_token``
    are skipped. Keys keep first-seen order.
    """
    pair_counts = dict()

    for entry, frequency in buffer.items():
        symbols = entry.split()
        for left, right in zip(symbols, symbols[1:]):
            if left == special_token:
                continue
            pair_counts[(left, right)] = pair_counts.get((left, right), 0) + frequency

    return pair_counts


get_pair_frequencies(buffer)

{('l', 'o'): 7,
 ('o', 'w'): 7,
 ('w', '#'): 7,
 ('w', 'e'): 8,
 ('e', 's'): 2,
 ('s', 't'): 2,
 ('t', '#'): 2,
 ('n', 'e'): 8,
 ('e', 'w'): 8,
 ('e', 'r'): 9,
 ('r', '#'): 9,
 ('w', 'i'): 3,
 ('i', 'd'): 3,
 ('d', 'e'): 3}

In [6]:
# new symbol is the pair with the best frequency
# (merged as a string, with a space) e.g. "l o" from ("l", "o")
def merge(new_symbol, buffer):
    """Fuse every occurrence of a symbol pair in the buffer.

    Args:
        new_symbol: the pair joined with a space, e.g. "l o" for ("l", "o").
        buffer: {space-separated symbols: frequency}; not modified.

    Returns:
        A new dict with the same frequencies in which each whole-symbol
        occurrence of ``new_symbol`` is fused ("l o" -> "lo").
    """
    import re

    # Anchor the match to whole symbols. A bare str.replace also matches across
    # symbol boundaries: merging "e w" in "ne w e w #" would give "new ew #"
    # instead of "ne w ew #".
    pattern = re.compile(r"(?<!\S)" + re.escape(new_symbol) + r"(?!\S)")
    replacement = new_symbol.replace(" ", "")

    merged_buffer = dict()
    for token_type, freq in buffer.items():
        # callable replacement: the fused symbol is inserted literally, with
        # no backslash-escape processing of its characters
        merged_token = pattern.sub(lambda _match: replacement, token_type)
        merged_buffer[merged_token] = freq

    return merged_buffer

merge(new_symbol="e r", buffer=buffer)
    

{'l o w #': 5,
 'l o w e s t #': 2,
 'n e w er #': 6,
 'w i d er #': 3,
 'n e w #': 2}

In [7]:
from tqdm.auto import trange

N_MERGES = 10

def fit(corpus, buffer, N_MERGES=N_MERGES):
    """Learn up to ``N_MERGES`` merges from ``buffer`` and return the vocabulary.

    The vocabulary starts as every non-whitespace character of ``corpus`` plus
    the notebook-level ``special_token``; each merge adds the concatenated pair.
    Stops early when no adjacent pairs are left to merge.
    """
    # split()/join drops all whitespace (not only " "), and the special token is
    # added explicitly so it is present even for a single-word corpus
    vocabulary = set("".join(corpus.split())) | {special_token}

    for _ in trange(N_MERGES):
        # find the best adjacent pair
        pair_counts = get_pair_frequencies(buffer)
        if not pair_counts:
            # every word is already a single symbol: max() would raise here
            break
        best_pair = max(pair_counts, key=pair_counts.get)  # type: ignore

        # update vocabulary
        vocabulary.add("".join(best_pair))

        # merge
        buffer = merge(" ".join(best_pair), buffer)

    return vocabulary


vocabulary = fit(corpus, buffer)

  0%|          | 0/10 [00:00<?, ?it/s]

In [8]:
sorted(vocabulary)

['#',
 'd',
 'e',
 'er',
 'er#',
 'i',
 'l',
 'lo',
 'low',
 'low#',
 'n',
 'ne',
 'new',
 'newer#',
 'o',
 'r',
 's',
 't',
 'w',
 'wi',
 'wid']

In [9]:
# let's put all that into a class

from typing import Dict, Set


class Learner(object):
    """BPE learner: repeatedly merges the most frequent adjacent symbol pair.

    Attributes:
        buffer: {"s y m b o l s <special_token>": word-type frequency}.
        vocabulary: corpus characters, the special token and every merged symbol.
        merge_order: {(left, right): 0-based step at which the pair was merged}.
    """

    def __init__(self, corpus: str, special_token="#") -> None:
        self.corpus = corpus
        self.special_token = special_token

        self.buffer: Dict[str, int] = dict()
        self.__create_buffer()

        self.vocabulary: Set[str] = set()
        self.__init_vocab()

        # order is important later on for using the tokenizer
        # to encode new text
        self.merge_order: Dict[tuple, int] = dict()
        self.merge_order_tracker = 0

    def __create_buffer(self) -> None:
        """Fill ``self.buffer`` with one space-separated entry per word type."""
        # split() (not split(" ")) so runs of whitespace never produce empty
        # tokens or tokens containing tabs/newlines
        token_counts = Counter(self.corpus.split())

        for token_type, freq in token_counts.items():
            # a space between all the chars, then this learner's own
            # end-of-word token (not the notebook-level global)
            buffer_entry = " ".join(token_type) + f" {self.special_token}"
            self.buffer[buffer_entry] = freq

    def __init_vocab(self) -> None:
        """Seed the vocabulary with every corpus character plus the special token."""
        # the special token is added whole, matching how the buffer treats it
        # as a single symbol (even when it is longer than one character)
        self.vocabulary = set("".join(self.corpus.split())) | {self.special_token}

    def __get_pair_frequencies(self) -> Dict:
        """Count adjacent symbol pairs, weighted by word-type frequency.

        Pairs whose left symbol is the bare special token are skipped. Keys
        keep first-seen order, which decides ties for max() in fit().
        """
        pair_counts = dict()

        for entry, frequency in self.buffer.items():
            symbols = entry.split()
            for left, right in zip(symbols, symbols[1:]):
                if left == self.special_token:
                    continue
                # pairs are parts of tokens:
                # if a token appears x times, so does the pair in it
                pair = (left, right)
                pair_counts[pair] = pair_counts.get(pair, 0) + frequency

        return pair_counts

    def __merge(self, new_symbol: str) -> None:
        """Fuse every whole-symbol occurrence of ``new_symbol`` ("a b" -> "ab").

        ``new_symbol`` is the best pair joined with a space, e.g. "l o" from
        ("l", "o"). Replaces ``self.buffer`` with the merged buffer.
        """
        import re

        # (?<!\S) / (?!\S) anchor the match to whole symbols; a bare
        # str.replace would also match across symbol boundaries
        # (e.g. "e w" inside "ne w").
        pattern = re.compile(r"(?<!\S)" + re.escape(new_symbol) + r"(?!\S)")
        replacement = new_symbol.replace(" ", "")

        merged_buffer = dict()
        for token_type, freq in self.buffer.items():
            # callable replacement: no backslash-escape processing
            merged_buffer[pattern.sub(lambda _match: replacement, token_type)] = freq

        # update buffer
        self.buffer = merged_buffer

    # k number of merges
    def fit(self, k: int) -> None:
        """Run up to ``k`` merge steps; stops early once no pairs remain.

        Each step adds the merged symbol to ``vocabulary`` and records the
        pair's rank in ``merge_order``. Calling fit again continues merging.
        """
        for _ in trange(k):
            # find the best adjacent pair
            pair_counts = self.__get_pair_frequencies()
            if not pair_counts:
                # every word is already a single symbol: nothing left to merge
                break
            best_pair = max(pair_counts, key=pair_counts.get)  # type: ignore

            # update vocabulary
            self.vocabulary.add("".join(best_pair))

            # merge inside this learner's own buffer
            self.__merge(" ".join(best_pair))

            # update merge order
            assert best_pair not in self.merge_order
            self.merge_order[best_pair] = self.merge_order_tracker
            self.merge_order_tracker += 1
    
learner = Learner(corpus=corpus)
learner.fit(k=10)
    

  0%|          | 0/10 [00:00<?, ?it/s]

In [10]:
# base corpus characters, the special token and every symbol added by fit()
learner.vocabulary

{'#',
 'd',
 'e',
 'er',
 'er#',
 'i',
 'l',
 'lo',
 'low',
 'low#',
 'n',
 'ne',
 'new',
 'newer#',
 'o',
 'r',
 's',
 't',
 'w',
 'wi',
 'wid'}

In [11]:
# each learned pair -> the 0-based step at which fit() merged it
learner.merge_order

{('e', 'r'): 0,
 ('er', '#'): 1,
 ('n', 'e'): 2,
 ('ne', 'w'): 3,
 ('l', 'o'): 4,
 ('lo', 'w'): 5,
 ('new', 'er#'): 6,
 ('low', '#'): 7,
 ('w', 'i'): 8,
 ('wi', 'd'): 9}

### Segmenter

In [12]:
from typing import Dict, List, Set

class Segmenter(object):
    """Encodes new words by replaying learned BPE merges in learned order."""

    def __init__(self, special_token="#") -> None:
        self.special_token = special_token

    def __get_adjacent_pairs(self, tokens: List[str]) -> Set:
        """Return the set of (left, right) pairs of neighbouring symbols."""
        return set(zip(tokens, tokens[1:]))

    def __merge(self, tokens: List[str], pair: tuple) -> List[str]:
        """Return ``tokens`` with every adjacent occurrence of ``pair`` fused.

        Scans whole symbols left to right without overlap, so a pair can never
        match across a symbol boundary (a string replace on the space-joined
        tokens can, e.g. "e w" inside "ne w").
        """
        first, second = pair
        merged: List[str] = []
        idx = 0
        while idx < len(tokens):
            if idx + 1 < len(tokens) and tokens[idx] == first and tokens[idx + 1] == second:
                merged.append(first + second)
                idx += 2
            else:
                merged.append(tokens[idx])
                idx += 1
        return merged

    def encode(self, word: str, merge_order: Dict, vocabulary: Set) -> List[str]:
        """Segment a single word into learned BPE symbols.

        Args:
            word: the word to encode; the special end-of-word token is appended
                unless ``word`` already ends with it.
            merge_order: {(left, right): step}, as built by ``Learner``.
            vocabulary: accepted for API compatibility; segmentation depends
                only on ``merge_order``.

        Returns:
            The list of symbols; the last one ends with the special token.
        """
        marker = self.special_token
        if word.endswith(marker):
            word = word[: len(word) - len(marker)]
        # keep the marker as one symbol, exactly as the Learner's buffer does,
        # so multi-character markers such as "</w>" are not split apart
        tokens = list(word) + [marker]

        while True:
            candidates = [
                (p, merge_order[p])
                for p in self.__get_adjacent_pairs(tokens)
                if p in merge_order
            ]
            if not candidates:
                break

            # greedily apply the merge that was learned earliest
            merge_pair, _ = min(candidates, key=lambda item: item[1])
            tokens = self.__merge(tokens, merge_pair)

        return tokens

segmenter = Segmenter()
segmenter.encode(
    word="newer",
    merge_order=learner.merge_order,
    vocabulary=learner.vocabulary,
)

['newer#']

### Full tokeniser

In [13]:
from tqdm.auto import tqdm

class BytePairTokeniser(object):
    """Convenience wrapper: a Learner for training plus a Segmenter for encoding."""

    def __init__(self, corpus: str, special_token="#") -> None:
        self.corpus = corpus
        self.special_token = special_token

        self.learner = Learner(self.corpus, self.special_token)
        self.segmenter = Segmenter(self.special_token)

    def get_vocabulary(self) -> Set:
        """Return the learner's vocabulary (the live set, not a copy)."""
        return self.learner.vocabulary

    def fit(self, k: int) -> None:
        """Delegate to ``Learner.fit(k)`` on the training corpus."""
        self.learner.fit(k)

    def encode_word(self, word: str) -> List[str]:
        """Segment one word with the merges learned so far."""
        return self.segmenter.encode(
            word, self.learner.merge_order, self.learner.vocabulary)

    def encode_sentence(self, sentence: str) -> List[List[str]]:
        """Segment each whitespace-separated word; one symbol list per word."""
        words = sentence.split()
        # iterate the list itself so tqdm knows the total; an enumerate object
        # has no len(), which left the progress bar at "0it"
        return [self.encode_word(word) for word in tqdm(words)]
    

In [14]:
tokeniser = BytePairTokeniser(corpus=corpus)

In [15]:
tokeniser.fit(k=10)

  0%|          | 0/10 [00:00<?, ?it/s]

In [16]:
tokeniser.encode_sentence(sentence=corpus)

0it [00:00, ?it/s]

[['low#'],
 ['low#'],
 ['low#'],
 ['low#'],
 ['low#'],
 ['low', 'e', 's', 't', '#'],
 ['low', 'e', 's', 't', '#'],
 ['newer#'],
 ['newer#'],
 ['newer#'],
 ['newer#'],
 ['newer#'],
 ['newer#'],
 ['wid', 'er#'],
 ['wid', 'er#'],
 ['wid', 'er#'],
 ['new', '#'],
 ['new', '#']]