### Step 1

Write the `BasicTokenizer` class, with the following three core functions:

- `def train(self, text, vocab_size, verbose=False)`
- `def encode(self, text)`
- `def decode(self, ids)`

Train your tokenizer on whatever text you like and visualize the merged tokens. Do they look reasonable? One default test you may wish to use is the text file `tests/taylorswift.txt`.


In [1]:
def get_stats(ids, counts=None):
    """
    Given a list of integers, return a dictionary of counts of consecutive pairs
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    Optionally allows to update an existing dictionary of counts
    """
    counts = {} if counts is None else counts
    for pair in zip(ids, ids[1:]): # iterate consecutive elements
        counts[pair] = counts.get(pair, 0) + 1
    return counts

def merge(ids, pair, idx):
    newids = []
    i = 0
    while i < len(ids):
        if i < len(ids) - 1 and ids[i] == pair[0] and ids[i+1] == pair[1]:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids


class Tokenizer():
    def __init__(self):
        # default: vocab size of 256 (all bytes), no merges, no patterns
        self.merges = {} # (int, int) -> int
        self.pattern = "" # str
        self.special_tokens = {} # str -> int, e.g. {'<|endoftext|>': 100257}
        self.vocab = self._build_vocab() # int -> bytes
        
    def _build_vocab(self):
        self.vocab = {idx: bytes([idx]) for idx in range(256)}
        
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]
        for special, idx in self.special_tokens.items():
            self.vocab[idx] = special.encode("utf-8") 
        


class BasicTokenizer(Tokenizer):
    
    def __init__(self):
        super().__init__()
        
    def _tokenize(self, text):
        tokens = text.encode("utf-8") # raw bytes
        tokens = list(map(int, tokens)) # convert to a list of integers in range 0..255 for convenience
        return tokens
        
    def train(self, text, vocab_size, verbose=False):
        self.vocab_size=vocab_size
        num_merges = self.vocab_size - 256
        
        tokens=self._tokenize(text)
        ids = list(tokens)
        self.merges = {} # (int, int) -> int
        for i in range(num_merges):
            stats = get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            print(f"merging {pair} into a new token {idx}")
            ids = merge(ids, pair, idx)
            self.merges[pair] = idx
    
    def encode(self, text):
         # given a string, return list of integers (the tokens)
        tokens = list(text.encode("utf-8"))
        while len(tokens) >= 2:
            stats = get_stats(tokens)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # nothing else can be merged
            idx = self.merges[pair]
            tokens = merge(tokens, pair, idx)
        return tokens
    
    def decode(self, ids):
        # given ids (list of integers), return Python string
        tokens = b"".join(self.vocab[idx] for idx in ids)
        text = tokens.decode("utf-8", errors="replace")
        return text

#### data

In [2]:
from pathlib import Path

text=Path('tests/taylorswift.txt').read_text()

In [3]:
tokenizer=BasicTokenizer()
tokenizer.train(text, 276)

merging (101, 32) into a new token 256
merging (44, 32) into a new token 257
merging (100, 32) into a new token 258
merging (46, 32) into a new token 259
merging (114, 32) into a new token 260
merging (50, 48) into a new token 261
merging (115, 32) into a new token 262
merging (105, 110) into a new token 263
merging (111, 110) into a new token 264
merging (114, 105) into a new token 265
merging (116, 32) into a new token 266
merging (116, 104) into a new token 267
merging (101, 258) into a new token 268
merging (257, 261) into a new token 269
merging (97, 110) into a new token 270
merging (97, 114) into a new token 271
merging (101, 260) into a new token 272
merging (121, 32) into a new token 273
merging (97, 108) into a new token 274
merging (267, 256) into a new token 275


In [4]:
!pytest tests/test_tokenizer.py::test_encode_decode_identity

platform win32 -- Python 3.9.12, pytest-7.1.2, pluggy-1.0.0
rootdir: C:\Users\RistoHinno\python\minbpe
plugins: anyio-3.5.0, Faker-23.2.1, typeguard-3.0.2
collected 12 items

tests\test_tokenizer.py [32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m.[0m[32m                                     [100%][0m



## compare with correct solution

In [5]:
import minbpe as minbpe

In [6]:
tokenizer_minbpe=minbpe.BasicTokenizer()
tokenizer_minbpe.train(text, 276)

In [7]:
tokenizer_minbpe.merges

{(101, 32): 256,
 (44, 32): 257,
 (100, 32): 258,
 (46, 32): 259,
 (114, 32): 260,
 (50, 48): 261,
 (115, 32): 262,
 (105, 110): 263,
 (111, 110): 264,
 (114, 105): 265,
 (116, 32): 266,
 (116, 104): 267,
 (101, 258): 268,
 (257, 261): 269,
 (97, 110): 270,
 (97, 114): 271,
 (101, 260): 272,
 (121, 32): 273,
 (97, 108): 274,
 (267, 256): 275}

In [8]:
tokenizer.merges

{(101, 32): 256,
 (44, 32): 257,
 (100, 32): 258,
 (46, 32): 259,
 (114, 32): 260,
 (50, 48): 261,
 (115, 32): 262,
 (105, 110): 263,
 (111, 110): 264,
 (114, 105): 265,
 (116, 32): 266,
 (116, 104): 267,
 (101, 258): 268,
 (257, 261): 269,
 (97, 110): 270,
 (97, 114): 271,
 (101, 260): 272,
 (121, 32): 273,
 (97, 108): 274,
 (267, 256): 275}

### Step 2

Convert you `BasicTokenizer` into a `RegexTokenizer`, which takes a regex pattern and splits the text exactly as GPT-4 would. Process the parts separately as before, then concatenate the results. Retrain your tokenizer and compare the results before and after. You should see that you will now have no tokens that go across categories (numbers, letters, punctuation, more than one whitespace). Use the GPT-4 pattern:

```
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
```


In [9]:
import regex as re

class RegexTokenizer(BasicTokenizer):
    
    def __init__(self, regex_pattern=r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""):
        super().__init__()
        self.regex_pattern=regex_pattern
        
    def _tokenize(self, text):
        GPT4_SPLIT_PATTERN = re.compile(self.regex_pattern)
        tokens=re.findall(GPT4_SPLIT_PATTERN, text)
        tokens=[list(tok.encode("utf-8")) for tok in tokens]
        return tokens
    
    
    def train(self, text, vocab_size, verbose=False):
        self.vocab_size=vocab_size
        num_merges = self.vocab_size - 256
        
        ids=self._tokenize(text)
        self.merges = {} # (int, int) -> int
        self.vocab = {idx: bytes([idx]) for idx in range(256)} # idx -> bytes
        for i in range(num_merges):
            stats = {}
            for chunk_ids in ids:
#                 print(stats)
                stats = get_stats(chunk_ids, stats)
            # find the pair with the highest count
            pair = max(stats, key=stats.get)
            # mint a new token: assign it the next available id
            idx = 256 + i
            # replace all occurrences of pair in ids with idx
            ids = [merge(chunk_ids, pair, idx) for chunk_ids in ids]
            # save the merge
            self.merges[pair] = idx
            self.vocab[idx] = self.vocab[pair[0]] + self.vocab[pair[1]]


In [10]:
GPT4_SPLIT_PATTERN = re.compile(r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""")
tokens=re.findall(GPT4_SPLIT_PATTERN, text)
tokens=[list(tok.encode("utf-8")) for tok in tokens]

In [11]:
tokens[:5]

[[67, 111, 112, 121],
 [32, 112, 97, 115, 116, 101],
 [32, 111, 102],
 [32, 116, 104, 101],
 [32, 87, 105, 107, 105, 112, 101, 100, 105, 97]]

In [12]:
reg_tokenizer_minbpe=minbpe.RegexTokenizer()
reg_tokenizer_minbpe.train(text, 276)

In [13]:
reg_tokenizer=RegexTokenizer()
reg_tokenizer.train(text, 276)

In [14]:
reg_tokenizer.merges

{(101, 114): 256,
 (50, 48): 257,
 (111, 114): 258,
 (105, 110): 259,
 (101, 100): 260,
 (32, 116): 261,
 (111, 110): 262,
 (104, 101): 263,
 (32, 83): 264,
 (97, 114): 265,
 (97, 110): 266,
 (32, 65): 267,
 (261, 263): 268,
 (97, 108): 269,
 (114, 105): 270,
 (118, 260): 271,
 (115, 116): 272,
 (119, 105): 273,
 (32, 82): 274,
 (257, 49): 275}

In [15]:
reg_tokenizer_minbpe.merges

{(101, 114): 256,
 (50, 48): 257,
 (111, 114): 258,
 (105, 110): 259,
 (101, 100): 260,
 (32, 116): 261,
 (111, 110): 262,
 (104, 101): 263,
 (32, 83): 264,
 (97, 114): 265,
 (97, 110): 266,
 (32, 65): 267,
 (261, 263): 268,
 (97, 108): 269,
 (114, 105): 270,
 (118, 260): 271,
 (115, 116): 272,
 (119, 105): 273,
 (32, 82): 274,
 (257, 49): 275}

In [16]:
reg_tokenizer_minbpe.vocab

{0: b'\x00',
 1: b'\x01',
 2: b'\x02',
 3: b'\x03',
 4: b'\x04',
 5: b'\x05',
 6: b'\x06',
 7: b'\x07',
 8: b'\x08',
 9: b'\t',
 10: b'\n',
 11: b'\x0b',
 12: b'\x0c',
 13: b'\r',
 14: b'\x0e',
 15: b'\x0f',
 16: b'\x10',
 17: b'\x11',
 18: b'\x12',
 19: b'\x13',
 20: b'\x14',
 21: b'\x15',
 22: b'\x16',
 23: b'\x17',
 24: b'\x18',
 25: b'\x19',
 26: b'\x1a',
 27: b'\x1b',
 28: b'\x1c',
 29: b'\x1d',
 30: b'\x1e',
 31: b'\x1f',
 32: b' ',
 33: b'!',
 34: b'"',
 35: b'#',
 36: b'$',
 37: b'%',
 38: b'&',
 39: b"'",
 40: b'(',
 41: b')',
 42: b'*',
 43: b'+',
 44: b',',
 45: b'-',
 46: b'.',
 47: b'/',
 48: b'0',
 49: b'1',
 50: b'2',
 51: b'3',
 52: b'4',
 53: b'5',
 54: b'6',
 55: b'7',
 56: b'8',
 57: b'9',
 58: b':',
 59: b';',
 60: b'<',
 61: b'=',
 62: b'>',
 63: b'?',
 64: b'@',
 65: b'A',
 66: b'B',
 67: b'C',
 68: b'D',
 69: b'E',
 70: b'F',
 71: b'G',
 72: b'H',
 73: b'I',
 74: b'J',
 75: b'K',
 76: b'L',
 77: b'M',
 78: b'N',
 79: b'O',
 80: b'P',
 81: b'Q',
 82: b'R',
 83: b'

In [17]:
reg_tokenizer.vocab[77].decode('utf-8')

'M'

In [18]:
reg_tokenizer_minbpe.encode('this is an text')

[116, 104, 105, 115, 32, 105, 115, 32, 266, 261, 101, 120, 116]

In [19]:
reg_tokenizer.encode('this is an text')

[116, 104, 105, 115, 32, 105, 115, 32, 266, 261, 101, 120, 116]

### Step 3

You're now ready to load the merges from the GPT-4 tokenizer and show that your tokenizer produces the identical results for both `encode` and `decode`, matching [tiktoken](https://github.com/openai/tiktoken).

```
# match this
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
```

Unfortunately, you will run into two issues:

1. It is not trivial to recover the raw merges from the GPT-4 tokenizer. You can easily recover what we call `vocab` here, and what they call and store under `enc._mergeable_ranks`. Feel free to copy paste the `recover_merges` function in `minbpe/gpt4.py`, which takes these ranks and returns the raw merges. If you wish to know how this function works, read [this](https://github.com/openai/tiktoken/issues/60) and [this](https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306). Basically, under some conditions it is enough to only store the parent nodes (and their rank) and get rid of the precise details of which children merged up to any parent.
2. Second, the GPT-4 tokenizer for some reason permutes its raw bytes. It stores this permutation in the first 256 elements of the mergeable ranks, so you can recover this byte shuffle relatively simply as `byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}`. In both your encode and decode, you'll have to shuffle bytes around accordingly. If you're stuck, reference the minbpe/gpt4.py` file for hints.


In [20]:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("hello world!!!? (안녕하세요!) lol123 😉")
text = enc.decode(ids) # get the same text back
text

'hello world!!!? (안녕하세요!) lol123 😉'

In [21]:
ids

[15339,
 1917,
 12340,
 30,
 320,
 31495,
 230,
 75265,
 243,
 92245,
 16715,
 28509,
 4513,
 57037]

In [22]:
# enc._mergeable_ranks

In [23]:
import pdb

def bpe(mergeable_ranks, token, max_rank):
    # helper function used in get_gpt4_merges() to reconstruct the merge forest
    #
    parts = [bytes([b]) for b in token]
#     pdb.set_trace()
    while True:
        min_idx = None
        min_rank = None
        for i, pair in enumerate(zip(parts[:-1], parts[1:])):
            rank = mergeable_ranks.get(pair[0] + pair[1])
            if rank is not None and (min_rank is None or rank < min_rank):
                min_idx = i
                min_rank = rank
        if min_rank is None or (max_rank is not None and min_rank >= max_rank):
            break
        assert min_idx is not None
        parts = parts[:min_idx] + [parts[min_idx] + parts[min_idx + 1]] + parts[min_idx + 2:]
    return parts


def recover_merges(mergeable_ranks):
    # the `merges` are already the byte sequences in their merged state.
    # so we have to recover the original pairings. We can do this by doing
    # a small BPE training run on all the tokens, in their order.
    # also see https://github.com/openai/tiktoken/issues/60
    # also see https://github.com/karpathy/minbpe/issues/11#issuecomment-1950805306
    merges = {}
    for token, rank in mergeable_ranks.items():
        if len(token) == 1:
            continue # skip raw bytes
        pair = tuple(bpe(mergeable_ranks, token, max_rank=rank))
        assert len(pair) == 2
        # recover the integer ranks of the pair
        ix0 = mergeable_ranks[pair[0]]
        ix1 = mergeable_ranks[pair[1]]
        merges[(ix0, ix1)] = rank

    return merges

merges=recover_merges(enc._mergeable_ranks)
merges

{(220, 220): 256,
 (256, 256): 257,
 (72, 77): 258,
 (220, 83): 259,
 (257, 257): 260,
 (68, 81): 261,
 (256, 220): 262,
 (78, 77): 263,
 (220, 64): 264,
 (81, 68): 265,
 (64, 83): 266,
 (82, 83): 267,
 (68, 77): 268,
 (78, 81): 269,
 (259, 71): 270,
 (198, 198): 271,
 (220, 66): 272,
 (75, 68): 273,
 (220, 82): 274,
 (72, 83): 275,
 (64, 77): 276,
 (64, 81): 277,
 (64, 75): 278,
 (270, 68): 279,
 (26, 198): 280,
 (220, 79): 281,
 (220, 69): 282,
 (78, 84): 283,
 (220, 28): 284,
 (72, 82): 285,
 (257, 262): 286,
 (258, 70): 287,
 (68, 82): 288,
 (220, 86): 289,
 (72, 263): 290,
 (68, 67): 291,
 (72, 66): 292,
 (220, 65): 293,
 (220, 67): 294,
 (68, 83): 295,
 (220, 76): 296,
 (220, 78): 297,
 (197, 197): 298,
 (81, 78): 299,
 (64, 82): 300,
 (68, 75): 301,
 (66, 83): 302,
 (77, 67): 303,
 (220, 258): 304,
 (220, 71): 305,
 (268, 83): 306,
 (72, 67): 307,
 (220, 77): 308,
 (64, 76): 309,
 (260, 262): 310,
 (259, 78): 311,
 (220, 265): 312,
 (12, 12): 313,
 (220, 90): 314,
 (297, 69): 31

In [24]:
token=b' when'
parts = [bytes([b]) for b in token]
parts

[b' ', b'w', b'h', b'e', b'n']

In [25]:
for i , tok in enumerate(zip(parts[:-1], parts[1:])):
    print(i)
    print(tok)

0
(b' ', b'w')
1
(b'w', b'h')
2
(b'h', b'e')
3
(b'e', b'n')


In [26]:
byte_shuffle = {i: enc._mergeable_ranks[bytes([i])] for i in range(256)}

In [27]:
# byte_shuffle

In [28]:
bytes([60])

b'<'

In [29]:
class GPT4Tokenizer(RegexTokenizer):
    
    def __init__(self, regex_pattern=r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""):
        super().__init__()
        self.regex_pattern=regex_pattern
        
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        self.merges = recover_merges(mergeable_ranks)
         # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        
        # now here is another tricky part.
        # for some reason, the tokens corresponding to individual bytes
        # are permuted in a different order. This is completely non-sensical
        # and probably historical, but therefore we have to deal with it here.
        self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
                
        
    def encode(self, text):
         # given a string, return list of integers (the tokens)
        tokens = list(text.encode("utf-8"))
        tokens = bytes(self.byte_shuffle[b] for b in tokens)
        while len(tokens) >= 2:
            stats = get_stats(tokens)
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break # nothing else can be merged
            idx = self.merges[pair]
            tokens = merge(tokens, pair, idx)
        return tokens
        
    def decode(self, ids):
        # given ids (list of integers), return Python string
        tokens = b"".join(self.vocab[idx] for idx in ids)
        tokens = bytes(self.inverse_byte_shuffle[b] for b in tokens)
        text = tokens.decode("utf-8", errors="replace")
        return text
    

In [30]:
gpt4=GPT4Tokenizer()

In [31]:
gpt4_minbpe=minbpe.GPT4Tokenizer()

In [32]:
text="hello world!!!? (안녕하세요!) lol123 😉"
ids = enc.encode(text)
text = enc.decode(ids) # get the same text back

In [33]:
ids

[15339,
 1917,
 12340,
 30,
 320,
 31495,
 230,
 75265,
 243,
 92245,
 16715,
 28509,
 4513,
 57037]

In [34]:
gpt4.encode(text)

[15339,
 1917,
 12340,
 30,
 320,
 31495,
 230,
 75265,
 243,
 92245,
 16715,
 28509,
 4513,
 57037]

In [35]:
gpt4_minbpe.encode(text)

[15339,
 1917,
 12340,
 30,
 320,
 31495,
 230,
 75265,
 243,
 92245,
 16715,
 28509,
 4513,
 57037]

In [36]:
gpt4.decode(gpt4.encode(text))

'hello world!!!? (안녕하세요!) lol123 😉'

In [37]:
gpt4_minbpe.decode(gpt4_minbpe.encode(text))

'hello world!!!? (안녕하세요!) lol123 😉'

### Step 4

(Optional, irritating, not obviously useful) Add the ability to handle special tokens. You'll then be able to match the output of tiktoken even when special tokens are present, e.g.:

```
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
```

Without `allowed_special` tiktoken will error.

In [38]:
import tiktoken
enc = tiktoken.get_encoding("cl100k_base") # this is the GPT-4 tokenizer
ids = enc.encode("<|endoftext|>hello world", allowed_special="all")
ids

[100257, 15339, 1917]

In [39]:
len(enc._mergeable_ranks)

100256

In [40]:
class GPT4Tokenizer(RegexTokenizer):
    
    def __init__(self, 
                 regex_pattern=r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""",
                special_tokens=None):
        super().__init__()
        self.regex_pattern=re.compile(regex_pattern)
        
        enc = tiktoken.get_encoding("cl100k_base")
        mergeable_ranks = enc._mergeable_ranks
        self.merges = recover_merges(mergeable_ranks)
         # reconstruct the vocab from the merges
        vocab = {idx: bytes([idx]) for idx in range(256)}
        for (p0, p1), idx in self.merges.items():
            vocab[idx] = vocab[p0] + vocab[p1]
        self.vocab = vocab
        
        # now here is another tricky part.
        # for some reason, the tokens corresponding to individual bytes
        # are permuted in a different order. This is completely non-sensical
        # and probably historical, but therefore we have to deal with it here.
        self.byte_shuffle = {i: mergeable_ranks[bytes([i])] for i in range(256)}
        self.inverse_byte_shuffle = {v: k for k, v in self.byte_shuffle.items()}
        if special_tokens is not None:
            self.special_tokens=special_tokens
            self.inverese_special_tokens={v:k for k, v in self.special_tokens.items()}
        else:
            self.special_tokens=None
            self.inverse_special_tokens=None
                
    def encode_normal(self, text):        
        chunk_tokens=re.findall(self.regex_pattern, text)
        chunk_tokens=[list(tok.encode("utf-8")) for tok in chunk_tokens]
        tokens_merged=[]
        for tokens in chunk_tokens:
            tokens = bytes(self.byte_shuffle[b] for b in tokens)
            while len(tokens) >= 2:
                stats = get_stats(tokens)
#                 print(stats)
                #get the most minimal pair that could be merged first
                pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
                print(pair)
                if pair not in self.merges:
                    break # nothing else can be merged
                idx = self.merges[pair]
                tokens = merge(tokens, pair, idx)
            tokens_merged.extend(tokens)
        return tokens_merged
        
    def encode(self, text, allowed_special=None):
        if allowed_special=='all':
            special_pattern = "(" + "|".join(re.escape(k) for k in self.special_tokens) + ")"
            special_chunks = re.split(special_pattern, text)
            ids=[]
            for special_chunk in special_chunks:
                if special_chunk in self.special_tokens:
                    ids.append(self.special_tokens[special_chunk])
                else:
                    ids.extend(self.encode_normal(special_chunk))
                    
        return ids
        
    def decode(self, ids):
        # given ids (list of integers), return Python string
        tokens = b"".join(self.vocab[idx] for idx in ids)
        tokens = bytes(self.inverse_byte_shuffle[b] for b in tokens)
        text = tokens.decode("utf-8", errors="replace")
        return text
    

In [41]:
GPT4_SPECIAL_TOKENS = {
    '<|endoftext|>': 100257,
    '<|fim_prefix|>': 100258,
    '<|fim_middle|>': 100259,
    '<|fim_suffix|>': 100260,
    '<|endofprompt|>': 100276
}

In [42]:
gpt4_tokenizer=GPT4Tokenizer(special_tokens=GPT4_SPECIAL_TOKENS)
gpt4_tokenizer.encode("<|endoftext|>hello world", allowed_special='all')

(68, 75)
(75, 78)
(301, 385)
(71, 4896)
(78, 81)
(220, 86)
(75, 67)
(269, 509)
(289, 1410)


[100257, 15339, 1917]

In [43]:
gpt4_tokenizer.merges[(68, 75)]

301

In [44]:
gpt4_tokenizer.merges[(75, 78)]

385