In [18]:
class BPETokenizer():
    def __init__(self):
        self.merge = {}
        self.id_to_char = {}
        self.char_to_id = {}
    def train(self, input_texts, vocab_size):
        '''
        BPE Algorithm Training
        :param input_texts: list of input texts
        :param vocab_size: vocabulary size
        :return:
        '''
        #1.tokenize the input
        unique_chars = list(set(list(input_texts)))
        #2.initialize the base vocabulary
        id_to_char = {idx:char for idx, char in enumerate(unique_chars)}
        char_to_id = {char:idx for idx, char in enumerate(unique_chars)}
        #3.map the input texts to id using vocabulary
        ids = [char_to_id[char] for char in input_texts]

        merge_times = vocab_size-len(unique_chars)
        vocab_size = len(unique_chars)-1
        merge = {}
        #4.iteratively merge tokens until the vocabulary reaches the target size
        for i in range(merge_times):
            if len(ids) == 1:
                break
            # count the frequencies of adjacent subwords
            stats = self.stats(ids)

            #find the most frequent pair of adjacent subwords
            pair = max(stats, key=stats.get)
            #update the vocab(bi-direction mapping)
            vocab_size += 1
            id_to_char[vocab_size] = id_to_char[pair[0]] + id_to_char[pair[1]]
            char_to_id[id_to_char[pair[0]] + id_to_char[pair[1]]] = vocab_size
            #record the merge Rule
            merge[pair] = vocab_size

            #merge ids based on the current merges
            ids = self.merge_ids(ids, pair, vocab_size)
        self.merge = merge
        self.id_to_char = id_to_char
        self.char_to_id = char_to_id

    def stats(self, ids):
        '''
        count the frequencies of adjacent subwords
        :param ids:
        :return:
        '''
        count = {}
        for item in zip(ids[:-1], ids[1:]):
            count[item] = count.get(item, 0) + 1
        return count
    def merge_ids(self, ids, pair, idx):
        '''
        merges adjacent subword pairs in the corpus and updates ids
        :param ids: the list of ids before the update
        :param pair: the specific pair of ids to be merged
        :param idx: the new id assigned to theis merged pair in the vocabulary
        :return: a new list of ids with the pairs replaced
        '''
        new_ids = []
        i = 0

        while i< len(ids):
            if ids[i]==pair[0] and i<len(ids)-1 and ids[i+1]==pair[1]:
                new_ids.append(idx)
                i+=2
            else:
                new_ids.append(ids[i])
                i+=1
        return new_ids

    def encode(self, text):
        '''
        Tokenize a text into a list of ids
        :param text:
        :return:
        '''
        #1 segmentation the text
        ids = [self.char_to_id[c] for c in text]
        print(ids)
        #2 perform multiple merges using the merge dictionary to get the final output
        while len(ids)>=2:
            stats = self.stats(ids)
            pair = min(stats, key=lambda p:self.merge.get(p, float('inf')))
            if pair not in self.merge:
                break
            ids = self.merge_ids(ids, pair, self.merge[pair])
        return ids

    def decode(self, ids):
        '''
        convert the ids back to text
        :param ids:
        :return:
        '''
        return "".join([self.id_to_char[index] for index in ids])


In [19]:
t1 = BPETokenizer()

In [20]:
train_text = """
    hello, this is a training text. The tokenizer will split the text into words and assign an id
    to each word. This is a fantastic world.
    """
t1.train(input_texts=train_text, vocab_size=48)
t1.id_to_char

{0: 'n',
 1: ',',
 2: 'l',
 3: 'z',
 4: 'd',
 5: 'g',
 6: '\n',
 7: 'f',
 8: 'a',
 9: 'w',
 10: 'e',
 11: 'o',
 12: 'r',
 13: 'p',
 14: 'c',
 15: 'x',
 16: 'h',
 17: '.',
 18: 'k',
 19: ' ',
 20: 't',
 21: 'T',
 22: 'i',
 23: 's',
 24: '  ',
 25: ' t',
 26: 's ',
 27: 'is ',
 28: ' w',
 29: '\n  ',
 30: '\n    ',
 31: 'he',
 32: 'in',
 33: ' wo',
 34: ' wor',
 35: 'an',
 36: 'll',
 37: 'his ',
 38: 'his is ',
 39: 'his is a',
 40: ' te',
 41: ' tex',
 42: ' text',
 43: '. ',
 44: '. T',
 45: 'to',
 46: ' word',
 47: 'as'}

In [21]:
t1.encode("hello world")

[16, 10, 2, 2, 11, 19, 9, 11, 12, 2, 4]


[31, 36, 11, 34, 2, 4]

In [23]:
t1.decode([31, 36, 11, 34, 2, 4])

'hello world'

In [24]:
t1.decode([16, 10, 2, 2, 11, 19, 9, 11, 12, 2, 4])

'hello world'