In [1]:
from array import array
from collections import Counter

In [2]:
# Sample text mixing ASCII, Korean, an emoji and a long run of a repeated character
s = "hello world!!!? (안녕하세요!) lol123 😉 fffffffff"

In [3]:
# Encode the text as UTF-8 and keep every byte as an int (0-255);
# characters outside ASCII take up several bytes each.
byte_list = list(s.encode("utf-8"))
print(byte_list)

[104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100, 33, 33, 33, 63, 32, 40, 236, 149, 136, 235, 133, 149, 237, 149, 152, 236, 132, 184, 236, 154, 148, 33, 41, 32, 108, 111, 108, 49, 50, 51, 32, 240, 159, 152, 137, 32, 102, 102, 102, 102, 102, 102, 102, 102, 102]


In [4]:
def create_pairs(l):
    """Return every pair of adjacent elements of ``l`` as a list of tuples.

    The pairs overlap: ``[1, 2, 3]`` gives ``[(1, 2), (2, 3)]``. An input with
    fewer than two elements gives an empty list.
    """
    # Zipping the sequence with itself shifted by one lines each element up
    # with its right-hand neighbour; zip stops at the shorter (shifted) one.
    shifted = l[1:]
    return list(zip(l, shifted))

# Sanity check: four elements yield three overlapping pairs
assert create_pairs([1, 2, 3, 4]) == [(1, 2), (2, 3), (3, 4)]

In [5]:
# All adjacent byte pairs of the sample text; these are the merge candidates
byte_pairs = create_pairs(byte_list)
print(byte_pairs)

[(104, 101), (101, 108), (108, 108), (108, 111), (111, 32), (32, 119), (119, 111), (111, 114), (114, 108), (108, 100), (100, 33), (33, 33), (33, 33), (33, 63), (63, 32), (32, 40), (40, 236), (236, 149), (149, 136), (136, 235), (235, 133), (133, 149), (149, 237), (237, 149), (149, 152), (152, 236), (236, 132), (132, 184), (184, 236), (236, 154), (154, 148), (148, 33), (33, 41), (41, 32), (32, 108), (108, 111), (111, 108), (108, 49), (49, 50), (50, 51), (51, 32), (32, 240), (240, 159), (159, 152), (152, 137), (137, 32), (32, 102), (102, 102), (102, 102), (102, 102), (102, 102), (102, 102), (102, 102), (102, 102), (102, 102)]


In [6]:
def replace(byte_list, idx, pair_to_replace):
    """Return a new list where each occurrence of ``pair_to_replace`` becomes ``idx``.

    The sequence is scanned left to right. When the two elements at the
    current position equal ``pair_to_replace`` (a tuple), ``idx`` is emitted
    and both elements are consumed, so matches never overlap. Otherwise the
    current element is copied through unchanged. ``byte_list`` is not modified.
    """
    merged = []
    total = len(byte_list)
    pos = 0
    while pos < total:
        # At the last element this window holds a single item and therefore
        # cannot equal a two-element pair.
        window = tuple(byte_list[pos:pos + 2])
        matched = window == pair_to_replace
        merged.append(idx if matched else byte_list[pos])
        pos += 2 if matched else 1

    return merged


def merge(byte_list, idx, merges):
    """Perform one merge step: fuse the most frequent adjacent pair into ``idx``.

    Args:
        byte_list: Current token sequence.
        idx: Token id to assign to the merged pair.
        merges: Dict of ``pair -> token id``; updated in place with the pair
            chosen in this step.

    Returns:
        A new list in which the occurrences of the chosen pair have been
        replaced by ``idx``. If the sequence has fewer than two tokens there
        is nothing to merge: a copy is returned and ``merges`` is left
        unchanged.
    """
    pair_counts = Counter(create_pairs(byte_list))
    if not pair_counts:
        # No adjacent pairs exist, so most_common(1) would be empty and
        # indexing it with [0] would raise IndexError.
        return list(byte_list)

    # Ties between equally frequent pairs are resolved by Counter.most_common.
    most_common_pair, _ = pair_counts.most_common(1)[0]
    merges[most_common_pair] = idx

    return replace(byte_list, idx, most_common_pair)


# Both occurrences of (3, 2) collapse into the new id 10; other tokens are untouched
assert replace([3, 2, 5, 9, 3, 2], 10, (3, 2)) == [10, 5, 9, 10]


In [7]:
# Learn `num_merges` merges, giving each newly created token the next free id.
num_merges = 5
idx = 256  # single byte values take ids 0-255, so merged tokens start at 256
new_tokens = byte_list[:]  # work on a copy so byte_list stays intact
merges = {}  # merged pair -> new token id, filled in by merge()

for _ in range(num_merges):
    new_tokens = merge(new_tokens, idx, merges)
    idx += 1

In [8]:
# Every applied merge replaces two tokens with one, so the sequence shrinks
print(f"Length before merge: {len(byte_list)}")
print(f"Length after merge: {len(new_tokens)}")

Length before merge: 56
Length after merge: 46


In [9]:
# Learned merge table: (left token, right token) -> new token id, in the order learned
print(merges)

{(102, 102): 256, (256, 256): 257, (108, 111): 258, (33, 33): 259, (104, 101): 260}


In [10]:
def encode(byte_list, merges):
    """Tokenize a byte sequence by replaying the learned merges.

    Args:
        byte_list: Sequence of byte values to encode.
        merges: Dict of ``pair -> token id`` as produced during training.

    Returns:
        A new list of token ids; ``byte_list`` itself is not modified.
    """
    new_tokens = list(byte_list)

    # Merges have to be replayed in the order they were learned, because a
    # later pair may contain a token created by an earlier merge (for example
    # (256, 256) can only match once 256 exists). Token ids are assigned in
    # increasing order during training, so sorting by id recovers that order
    # without relying on the insertion order of the dict.
    for m_pair, idx in sorted(merges.items(), key=lambda item: item[1]):
        new_tokens = replace(new_tokens, idx, m_pair)

    return new_tokens

# Re-encoding the training bytes should give the same length as the training loop produced
print(len(encode(byte_list, merges)))

46


In [11]:
def decode(encoded_list, merges):
    """Expand merged tokens back into the byte values they stand for.

    Args:
        encoded_list: Sequence of token ids, e.g. the result of ``encode``.
        merges: Dict of ``pair -> token id`` used for encoding.

    Returns:
        A new list in which every merged token id has been replaced by the
        two tokens it was built from, repeatedly, until no id from ``merges``
        remains. ``encoded_list`` itself is not modified.
    """
    tokens = list(encoded_list)

    # Undo merges from the newest (highest id) to the oldest: expanding a
    # newer token can expose older merged tokens, which a later pass expands.
    # Sorting by id instead of calling reversed() on the dict view keeps this
    # working on Python versions before 3.8 and makes the result independent
    # of the insertion order of the dict.
    newest_first = sorted(merges.items(), key=lambda item: item[1], reverse=True)
    for m_pair, idx in newest_first:
        expanded = []
        for token in tokens:
            if token == idx:
                expanded.extend(m_pair)
            else:
                expanded.append(token)
        tokens = expanded

    return tokens


# Expansion runs newest merge first:
# [6, 5, 9]    -> [6, 5, 1, 8];    9 -> (1, 8)
# [6, 5, 1, 8] -> [6, 5, 1, 1, 2]; 8 -> (1, 2)
assert decode([6, 5, 9], {(1, 2): 8, (1, 8): 9}) == [6, 5, 1, 1, 2]

In [12]:
# Round trip: decoding the encoded tokens must give back the original bytes
encoded_list = encode(byte_list, merges)
decoded_list = decode(encoded_list, merges)

# Compare the contents rather than only the lengths: a wrong expansion could
# still produce a list of the right length.
assert decoded_list == byte_list

In [13]:
# Pack the decoded ints into unsigned bytes and read them back as UTF-8 text;
# the result should be the same string as the original
arr = array('B', decoded_list)
bytes(arr).decode("utf-8")

'hello world!!!? (안녕하세요!) lol123 😉 fffffffff'