# Raw byte tokenizer for machine code representation
Based on the byte pair encoding (BPE) algorithm.

Training process:
1. Add all bytes to the vocabulary (00 - FF)
2. From the training sequence, find the most frequent pair of tokens
3. Merge that pair of tokens into a new token with a unique id
4. Save the merge rule
5. Repeat from step 2 until the target vocabulary size is reached (or the whole training sequence has been merged into a single token)

Encoding process:
1. Apply the merge rules, in the order they were learned, to the input sequence until none of them can be applied any more
2. Convert the resulting sequence of tokens to a sequence of ids

Decoding process:
1. Convert the sequence of ids back to a sequence of tokens (byte strings). Concatenating them restores the original bytes.

In [1]:
import random
import lorem
from collections import defaultdict
from collections import Counter

In [2]:
# Seed the global RNG so that the randomly generated lorem text below is
# reproducible (assumes `lorem` draws from Python's global `random` state — TODO confirm).
random.seed(0)

In [3]:
def autolog(message):
    """Print *message* prefixed with the name of the calling function.

    Despite the name, this writes to stdout with ``print``; it does not use
    the ``logging`` module.  Output format: ``[<caller name>] <message>``.

    Args:
        message: Any object; it is converted with ``str()`` before printing.
    """
    import inspect

    frame = inspect.currentframe()
    try:
        # Step one frame back so we report the caller, not autolog itself.
        # currentframe() may return None on Python implementations without
        # stack-frame support, so fall back to a placeholder name.
        caller = getattr(frame, "f_back", None)
        name = caller.f_code.co_name if caller is not None else "<unknown>"
    finally:
        # Drop the frame reference explicitly to avoid a reference cycle
        # (recommended by the inspect module documentation).
        del frame
    text = str(message)
    print(f"[{name}] {text}")

In [4]:
class ByteTokenizer:
    """Byte-level byte pair encoding (BPE) tokenizer.

    The vocabulary always starts with the 256 single-byte tokens.
    ``train`` then appends one merged token per learned merge rule, until
    ``vocabulary_size`` entries exist or the training data can no longer
    be merged.
    """

    def __init__(self, vocabulary_size):
        """Create an untrained tokenizer.

        Args:
            vocabulary_size: Target number of vocabulary entries, including
                the 256 single-byte base tokens.
        """
        self.vocabulary_size = vocabulary_size
        # The attributes below are populated by train().
        self.vocabulary = None   # list of bytes tokens; list index == token id
        self.token2id = None     # dict: bytes token -> id
        self.id2token = None     # dict: id -> bytes token
        self.merge_rules = None  # list of (left, right, left + right), in learned order

    @staticmethod
    def _find_most_frequent_pair(sequence):
        """Return ``(left, right, left + right)`` for the most frequent adjacent pair.

        Ties are broken by first occurrence in *sequence*, because
        ``Counter.most_common`` keeps equal counts in first-encountered order.

        Raises:
            IndexError: if *sequence* has fewer than two tokens.
        """
        pair_count = Counter(zip(sequence, sequence[1:]))
        left, right = pair_count.most_common(1)[0][0]
        return (left, right, left + right)

    @staticmethod
    def _apply_merge_rule(sequence, merge_rule):
        """Replace non-overlapping ``(left, right)`` pairs greedily, scanning left to right.

        Returns:
            ``(new_sequence, applied)``, where *applied* is True if at least
            one pair was merged.
        """
        left, right, merged = merge_rule
        new_sequence = []
        applied = False
        length = len(sequence)

        i = 0
        while i < length:
            token = sequence[i]
            if token == left and i + 1 < length and sequence[i + 1] == right:
                new_sequence.append(merged)
                i += 2  # consume both halves of the pair
                applied = True
            else:
                new_sequence.append(token)
                i += 1
        return new_sequence, applied

    def train(self, sequence):
        """Learn merge rules from *sequence*, an iterable of ints in 0..255 (e.g. ``bytes``).

        Returns:
            The final vocabulary size.  It is smaller than ``vocabulary_size``
            when the training data is merged into a single token first.  It is
            256 when ``vocabulary_size <= 256`` or *sequence* has fewer than
            two bytes.
        """
        tokens = [bytes([x]) for x in sequence]

        self.vocabulary = [bytes([x]) for x in range(256)]
        self.merge_rules = []

        n_merges = 0
        # A pair needs at least two tokens.  Checking here also keeps
        # _find_most_frequent_pair from raising IndexError on empty or
        # one-byte training input, and stops once everything is one token.
        while len(self.vocabulary) < self.vocabulary_size and len(tokens) >= 2:
            n_merges += 1
            if n_merges % 100 == 0:
                # Report progress as current / target vocabulary size.
                autolog(f"{len(self.vocabulary)} / {self.vocabulary_size}")
            merge_rule = self._find_most_frequent_pair(tokens)
            self.vocabulary.append(merge_rule[2])
            self.merge_rules.append(merge_rule)
            tokens, _ = self._apply_merge_rule(tokens, merge_rule)

        # Two different merges can produce the same bytes (e.g. ab+c and a+bc).
        # In that case token2id keeps the later id; both ids detokenize identically.
        self.token2id = {token: idx for idx, token in enumerate(self.vocabulary)}
        self.id2token = dict(enumerate(self.vocabulary))
        return len(self.vocabulary)

    def tokenize(self, sequence):
        """Encode *sequence* (an iterable of ints in 0..255) into a list of token ids.

        Merge rules are applied in learned order.  Passes repeat until every
        rule's most recent attempt merged nothing.  ``train`` must have been
        called first.
        """
        tokens = [bytes([x]) for x in sequence]
        # A rule that merged nothing on its last attempt is skipped on later passes.
        active = {rule: True for rule in self.merge_rules}

        while True:
            for rule in self.merge_rules:
                if active[rule]:
                    tokens, active[rule] = self._apply_merge_rule(tokens, rule)
            if not any(active.values()):
                break

        # Every remaining token is either a single byte or a learned merge,
        # so all of them are in the vocabulary.
        return [self.token2id[token] for token in tokens]

    def detokenize(self, sequence):
        """Map token ids back to their byte-string tokens.

        ``b"".join`` of the result restores the original bytes.
        """
        return [self.id2token[i] for i in sequence]

In [5]:
# Build the training corpus from 200 chunks of lorem ipsum text joined by
# spaces, then encode it to raw ASCII bytes for the byte-level tokenizer.
training_text = " ".join(lorem.text() for _ in range(200)).encode("ascii")
len(training_text)

287960

In [12]:
# Profile one full training run.
import cProfile
import pstats

tokenizer = ByteTokenizer(1500)
# The raw profiler stats are written to the file "train_stats"; they are then
# loaded back and printed, sorted by cumulative time.
cProfile.run("tokenizer.train(training_text)", filename="train_stats")
p = pstats.Stats("train_stats")
p.sort_stats("cumulative").print_stats()

Sun Dec  3 00:03:26 2023    train_stats

         108022687 function calls in 67.909 seconds

   Ordered by: cumulative time

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
        1    0.000    0.000   67.909   67.909 {built-in method builtins.exec}
        1    0.000    0.000   67.909   67.909 <string>:1(<module>)
        1    0.778    0.778   67.909   67.909 /tmp/ipykernel_3899/1356102597.py:56(train)
     1244   39.842    0.032   51.295    0.041 /tmp/ipykernel_3899/1356102597.py:34(_apply_merge_rule)
     1244    0.152    0.000   15.745    0.013 /tmp/ipykernel_3899/1356102597.py:9(_find_most_frequent_pair)
     1244    0.016    0.000   14.256    0.011 /usr/lib/python3.11/collections/__init__.py:587(__init__)
     1244    0.009    0.000   14.239    0.011 /usr/lib/python3.11/collections/__init__.py:660(update)
     1244   14.219    0.011   14.219    0.011 {built-in method _collections._count_elements}
 54425757    6.315    0.000    6.315    0.000 {built-in me

<pstats.Stats at 0x7f7d3859c450>

In [7]:
# Train a fresh tokenizer with a target of 1500 vocabulary entries
# (256 byte tokens plus learned merges).  The displayed value is the final
# vocabulary size returned by train().
tokenizer = ByteTokenizer(vocabulary_size=1500)
tokenizer.train(training_text)

1500

In [8]:
# Inspect the vocabulary: ids 0-255 are the single bytes, and later ids are the
# merged tokens in the order train() learned them.
tokenizer.vocabulary

[b'\x00',
 b'\x01',
 b'\x02',
 b'\x03',
 b'\x04',
 b'\x05',
 b'\x06',
 b'\x07',
 b'\x08',
 b'\t',
 b'\n',
 b'\x0b',
 b'\x0c',
 b'\r',
 b'\x0e',
 b'\x0f',
 b'\x10',
 b'\x11',
 b'\x12',
 b'\x13',
 b'\x14',
 b'\x15',
 b'\x16',
 b'\x17',
 b'\x18',
 b'\x19',
 b'\x1a',
 b'\x1b',
 b'\x1c',
 b'\x1d',
 b'\x1e',
 b'\x1f',
 b' ',
 b'!',
 b'"',
 b'#',
 b'$',
 b'%',
 b'&',
 b"'",
 b'(',
 b')',
 b'*',
 b'+',
 b',',
 b'-',
 b'.',
 b'/',
 b'0',
 b'1',
 b'2',
 b'3',
 b'4',
 b'5',
 b'6',
 b'7',
 b'8',
 b'9',
 b':',
 b';',
 b'<',
 b'=',
 b'>',
 b'?',
 b'@',
 b'A',
 b'B',
 b'C',
 b'D',
 b'E',
 b'F',
 b'G',
 b'H',
 b'I',
 b'J',
 b'K',
 b'L',
 b'M',
 b'N',
 b'O',
 b'P',
 b'Q',
 b'R',
 b'S',
 b'T',
 b'U',
 b'V',
 b'W',
 b'X',
 b'Y',
 b'Z',
 b'[',
 b'\\',
 b']',
 b'^',
 b'_',
 b'`',
 b'a',
 b'b',
 b'c',
 b'd',
 b'e',
 b'f',
 b'g',
 b'h',
 b'i',
 b'j',
 b'k',
 b'l',
 b'm',
 b'n',
 b'o',
 b'p',
 b'q',
 b'r',
 b's',
 b't',
 b'u',
 b'v',
 b'w',
 b'x',
 b'y',
 b'z',
 b'{',
 b'|',
 b'}',
 b'~',
 b'\x7f',
 b'\x80',


In [9]:
# Generate one lorem sentence and encode it as ASCII bytes to test the tokenizer on.
test_sentence = lorem.sentence().encode("ascii")
test_sentence

b'Voluptatem non non labore dolorem sed.'

In [10]:
# Encode the test sentence into a list of token ids.
tokenizer.tokenize(test_sentence)

[434, 657, 991, 397, 46]

In [11]:
# Round trip: encode the sentence, then map the ids back to their byte-string tokens.
tokenizer.detokenize(tokenizer.tokenize(test_sentence))

[b'Voluptatem ', b'non non ', b'labore dolorem ', b'sed', b'.']