## GPT Tokenizer - Coding along

In this notebook, we take the first steps towards a tokenizer. We learn to encode strings to bytes (using the UTF-8 standard) and back with `.encode()` / `.decode()`. \
UTF-8 uses multi-byte encodings for most code points, but the 128 ASCII characters are each encoded as a single byte, which keeps UTF-8 backward compatible with ASCII. \
We end up with a list of bytes that could already be called tokens. Additionally, we perform a compression operation called _byte pair encoding_. It repeatedly merges the most frequent pair of adjacent tokens and substitutes it with a new token, so the total number of tokens decreases. To undo this, we keep a mapping from each new token back to the pair it replaced and use it for decoding.
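
To make the single-byte versus multi-byte point concrete, here is a small sketch that counts the UTF-8 bytes of a few characters. The sample characters and the names `samples`, `byte_lengths` and `round_trip` are chosen for illustration only.

```python
# Compare how many bytes UTF-8 needs for different characters.
samples = ["a", "é", "€", "😀"]
byte_lengths = {ch: len(ch.encode("utf-8")) for ch in samples}

for ch, n in byte_lengths.items():
    # list(bytes) yields the raw byte values as integers in 0..255
    print(f"{ch!r}: {n} byte(s) -> {list(ch.encode('utf-8'))}")

# Encoding followed by decoding gives back the original string.
round_trip = "héllo".encode("utf-8").decode("utf-8")
print(round_trip)
```

Only the ASCII character `a` fits into one byte; the others need two, three and four bytes, so a string can have fewer characters than bytes.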

In [11]:
import torch

In [12]:
sent = 'hello world'
token = sent.encode("utf-8")
inttuples = list(token)
inttuples

[104, 101, 108, 108, 111, 32, 119, 111, 114, 108, 100]

In [13]:
print(str(len(sent)) + ': ' + sent)
print(str(len(inttuples)) + ': ' + ' '.join(list(map(str,inttuples))))

11: hello world
11: 104 101 108 108 111 32 119 111 114 108 100


In [14]:
def bpe_algorithm_old(intlist):
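    '''Deprecated first attempt: merges the most frequent pair until no pair occurs more than once; new tokens are negative integers.'''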
    newtoken = -1
    ref = {}
    while True:
        inttuples = list(zip(intlist, intlist[1:]))
        counts = {i:0 for i in set(inttuples)}
        for i in inttuples:
            counts[i] += 1

        maxtuple = max(counts, key=lambda x: counts[x])
        if max(counts.values()) == 1:
            return intlist, ref
        ref.update({newtoken: maxtuple})
        ii = 0
        newlist = []
        while ii <= len(intlist)-2:
            if (intlist[ii], intlist[ii+1]) == maxtuple:
                newlist.append(newtoken)
                ii +=1
            else: 
                newlist.append(intlist[ii])
            ii += 1
            if ii == len(intlist)-1:
                newlist.append(intlist[-1])
        intlist = newlist
        newtoken -= 1

bpe_algorithm_old([1,1,1,2,4,1,1,1,2,1,3])

([-3, 4, -3, 1, 3], {-1: (1, 1), -2: (1, 2), -3: (-1, -2)})
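
The lookup table returned above is enough to undo the compression. The following is a minimal sketch of such a decoder; the helper name `expand_old` is hypothetical and not part of the notebook, and it assumes that every key of `ref` is a merged token and never an original value.

```python
def expand_old(intlist, ref):
    """Undo the merges by replacing each merged token with its pair, recursively."""
    out = []
    stack = list(reversed(intlist))  # pop() then yields tokens left to right
    while stack:
        tok = stack.pop()
        if tok in ref:
            left, right = ref[tok]
            # push right first so that left is expanded first
            stack.append(right)
            stack.append(left)
        else:
            out.append(tok)
    return out

compressed = [-3, 4, -3, 1, 3]
ref = {-1: (1, 1), -2: (1, 2), -3: (-1, -2)}
restored = expand_old(compressed, ref)
print(restored)
```

For the output shown above, this restores the original input list `[1, 1, 1, 2, 4, 1, 1, 1, 2, 1, 3]`.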

In [15]:
text = """Byte pair encoding  Article Talk Read Edit View history  Tools From Wikipedia, the free encyclopedia Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial "tokens"). Then, successively the most frequent pair of adjacent characters is merged into a new, 2-characters long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from final vocabulary tokens and initial-set characters.[8]  All the unique tokens found in a corpus are listed in a token vocabulary, the size of which, in the case of GPT-3.5 and GPT-4, is 100256.  The difference between the modified and the original algorithm is that the original algorithm does not merge the most frequent pair of bytes of data, but replaces them by a new byte that was not contained in the initial dataset. A lookup table of the replacements is required to rebuild the initial dataset. The algorithm is effective for the tokenization because it does not require large computational overheads and remains consistent and reliable.  Original algorithm The original algorithm operates by iteratively replacing the most common contiguous sequences of characters in a target piece of text with unused \'placeholder\' bytes. The iteration ends when no sequences can be found, leaving the target text effectively compressed. 
Decompression can be performed by reversing this process, querying known placeholder terms against their corresponding denoted sequence, per a lookup table. In the original paper, this lookup table is encoded and stored alongside the compressed text.  Example Suppose the data to be encoded is  aaabdaaabac The byte pair "aa" occurs most often, so it will be replaced by a byte that is not used in the data, such as "Z". Now there is the following data and replacement table:  ZabdZabac Z=aa Then the process is repeated with byte pair "ab", replacing it with "Y":  ZYdZYac Y=ab Z=aa The only literal byte pair left occurs only once, and the encoding might stop here. Alternatively, the process could continue with recursive byte pair encoding, replacing "ZY" with "X":  XdXac X=ZY Y=ab Z=aa This data cannot be compressed further by byte pair encoding because there are no pairs of bytes that occur more than once.  To decompress the data, simply perform the replacements in the reverse order."""
tokens = list(text.encode('utf-8'))

print(text, len(text), sep='\n')
print(tokens, len(tokens), sep='\n')

Byte pair encoding  Article Talk Read Edit View history  Tools From Wikipedia, the free encyclopedia Byte pair encoding[1][2] (also known as digram coding)[3] is an algorithm, first described in 1994 by Philip Gage for encoding strings of text into tabular form for use in downstream modeling.[4] Its modification is notable as the large language model tokenizer with an ability to combine both tokens that encode single characters (including single digits or single punctuation marks) and those that encode whole words (even the longest compound words).[5][6][7] This modification, in the first step, assumes all unique characters to be an initial set of 1-character long n-grams (i.e. initial "tokens"). Then, successively the most frequent pair of adjacent characters is merged into a new, 2-characters long n-gram and all instances of the pair are replaced by this new token. This is repeated until a vocabulary of prescribed size is obtained. Note that new words can always be constructed from f

In [17]:
def get_stats(ids):
    '''Count how often each pair of adjacent ids occurs, most frequent first.'''
    stats = {i:0 for i in set(zip(ids, ids[1:]))}
    for i in zip(ids, ids[1:]):
        stats[i] += 1
    sorted_stats = dict(sorted(stats.items(), key=lambda x: x[1], reverse=True))
    return sorted_stats

def merge(ids, pair, idx):
    '''Replace every non-overlapping occurrence of pair in ids with the new token idx.'''
    merged_ids = []
    i = 0
    while i < len(ids)-1:
        if (ids[i], ids[i+1]) == pair: 
            merged_ids += [idx]
            i += 2
        else:
            merged_ids += [ids[i]]
            i += 1
    if i == len(ids) - 1:
        merged_ids += [ids[i]]
    return merged_ids

#---
vocab_size = 276
num_merges = vocab_size - 256
ids = list(tokens)

merges = {}
for i in range(num_merges):
    stats = get_stats(ids)
    pair = max(stats, key=stats.get)
    idx = 256 + i
    print(f"merging {pair} into a new token {idx}")
    ids = merge(ids, pair, idx)
    merges[pair] = idx

merging (101, 32) into a new token 256
merging (115, 32) into a new token 257
merging (116, 104) into a new token 258
merging (105, 110) into a new token 259
merging (116, 32) into a new token 260
merging (101, 110) into a new token 261
merging (100, 32) into a new token 262
merging (114, 101) into a new token 263
merging (101, 114) into a new token 264
merging (258, 256) into a new token 265
merging (99, 111) into a new token 266
merging (32, 97) into a new token 267
merging (97, 99) into a new token 268
merging (111, 114) into a new token 269
merging (259, 103) into a new token 270
merging (105, 257) into a new token 271
merging (32, 265) into a new token 272
merging (116, 97) into a new token 273
merging (97, 108) into a new token 274
merging (101, 262) into a new token 275
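
To follow the training loop by hand, it helps to run it on the short Wikipedia example string `aaabdaaabac`. The sketch below restates the two helpers compactly under the hypothetical names `count_pairs` and `merge_pair`; it uses `collections.Counter`, so ties between equally frequent pairs are broken by first occurrence, which may differ from the tie-breaking of the cell above.

```python
from collections import Counter

def count_pairs(ids):
    # Counter keeps keys in order of first occurrence
    return Counter(zip(ids, ids[1:]))

def merge_pair(ids, pair, idx):
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)  # replace the pair by the new token
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

toy_ids = list("aaabdaaabac".encode("utf-8"))
toy_merges = {}
for step in range(3):
    stats = count_pairs(toy_ids)
    pair = max(stats, key=stats.get)
    new_idx = 256 + step
    toy_ids = merge_pair(toy_ids, pair, new_idx)
    toy_merges[pair] = new_idx
    print(f"merging {pair} into {new_idx}: {toy_ids}")
```

After three merges the 11 bytes shrink to 5 tokens, the same length as `XdXac` in the Wikipedia example.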


In [18]:
print('tokens length:', len(tokens))
print('ids length:', len(ids))
print(f'compression ratio: {len(tokens) / len(ids):.2f}X')

tokens length: 2925
ids length: 2227
compression ratio: 1.31X


In [19]:
def encode(text):
    tokens = list(text.encode('utf-8'))
    while len(tokens) >= 2:
        stats = get_stats(tokens)
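        # pick the pair that was merged earliest in training (lowest new-token index);
        # pairs that were never merged rank as infinity
        pair = min(stats, key=lambda p: merges.get(p, float('inf')))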
        if pair not in merges:
            break # nothing else can be merged
        idx = merges[pair]
        tokens = merge(tokens, pair, idx)
    return tokens

encode('hello world')

[104, 101, 108, 108, 111, 32, 119, 269, 108, 100]
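
The order of merges matters when encoding: a later merge can only apply once the earlier merge it builds on has been applied. Here is a self-contained sketch with a hand-written merge table; `toy_table`, `toy_encode` and `apply_merge` are hypothetical names, and the table is made up for illustration rather than learned from the text above.

```python
toy_table = {(97, 98): 256,   # "a" + "b" -> 256
             (256, 99): 257}  # 256 + "c" -> 257, needs the first merge

def apply_merge(ids, pair, idx):
    out, i = [], 0
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i + 1]) == pair:
            out.append(idx)
            i += 2
        else:
            out.append(ids[i])
            i += 1
    return out

def toy_encode(text, table):
    ids = list(text.encode("utf-8"))
    while len(ids) >= 2:
        pairs = set(zip(ids, ids[1:]))
        # lowest new-token index = earliest merge; unknown pairs rank as infinity
        pair = min(pairs, key=lambda p: table.get(p, float("inf")))
        if pair not in table:
            break  # no known merge applies any more
        ids = apply_merge(ids, pair, table[pair])
    return ids

print(toy_encode("abc", toy_table))
print(toy_encode("cab", toy_table))
print(toy_encode("xyz", toy_table))
```

Here `"abc"` collapses into the single token 257, `"cab"` only uses the first merge, and `"xyz"` stays as raw bytes because none of its pairs is in the table.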

In [20]:
vocab = {idx: bytes([idx]) for idx in range(256)}
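# merges keeps insertion order, so both parts of a pair are already in vocab when the pair is reached
for (p0,p1), idx in merges.items():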
    vocab[idx] = vocab[p0] + vocab[p1]

def decode(ids):
    tokens = b''.join(vocab[idx] for idx in ids)
    text = tokens.decode('utf-8', errors='replace')
    return text


decode(encode("""hello world!"""))


'hello world!'
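
Not every sequence of token ids decodes to valid UTF-8, which is presumably why `decode` passes `errors='replace'`. A short sketch of the difference, using a lone continuation byte as an example of invalid input (the variable names are illustrative):

```python
lone_continuation = bytes([128])  # 0x80 cannot start a UTF-8 sequence

try:
    lone_continuation.decode("utf-8")  # strict mode is the default
    strict_ok = True
except UnicodeDecodeError:
    strict_ok = False

# errors="replace" substitutes U+FFFD (the replacement character) instead of raising
lenient = lone_continuation.decode("utf-8", errors="replace")
print(strict_ok, repr(lenient))
```

With strict decoding the call raises `UnicodeDecodeError`; with `errors="replace"` the invalid byte turns into the replacement character, so decoding never fails.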