In [12]:
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""


def get_stats(ids):
    stats = {}
    for pair in zip(ids, ids[1:]):
        stats[pair] = stats.get(pair, 0) + 1
    return stats

def merge(ids, pair, idx):
    i = 0
    newids = []
    while i < len(ids):
        if i < len(ids) - 1 and (ids[i], ids[i+1]) == pair:
            newids.append(idx)
            i += 2
        else:
            newids.append(ids[i])
            i += 1
    return newids

import regex as re

class RegexTokenizer:
    def __init__(self) -> None:
        self.pattern = GPT4_SPLIT_PATTERN
        self.compiled_pattern = re.compile(self.pattern)
        self.merges = {}
        self.vocab = {}
    def train(self, text, vocab_size, verbose=False):
        assert vocab_size > 256
        
        num_merges = vocab_size - 256

        text_bytes = text.encode('utf-8')
        ids = list(text_bytes)
        
        for i in range(num_merges):
            stats = get_stats(ids)
            pair = max(stats, key=stats.get)
            idx = 256 + i
            self.merges[pair] = idx
            ids = merge(ids, pair, idx)
        self.vocab = {i: bytes([i]) for i in range(256)}
        for (p0, p1), idx in self.merges.items():
            self.vocab[idx] = self.vocab[p0] + self.vocab[p1]

    def _encode(self, text):
        raw = text.encode('utf-8')
        ids = list(raw)
        while True:
            stats = get_stats(ids)
            if not stats:
                break
            pair = min(stats, key=lambda p: self.merges.get(p, float('inf')))
            if pair not in self.merges:
                break
            ids = merge(ids, pair, self.merges[pair])
        return ids
    def encode(self, text):
        
        splits = re.findall(self.compiled_pattern, text)
        output = []
        for s in splits:
            output.extend(self._encode(s))
        return output
    
    def decode(self, ids):
        raw = b''.join([self.vocab[idx] for idx in ids])
        text = raw.decode('utf-8', errors='replace')
        return text


In [13]:
with open('taylorswift.txt') as f:
    text = f.read()
tokenizer = RegexTokenizer()
tokenizer.train(text, 300)

In [14]:
print(len(text))

185561


In [15]:
print(tokenizer.merges)
print(tokenizer.vocab)


{(101, 32): 256, (44, 32): 257, (100, 32): 258, (46, 32): 259, (114, 32): 260, (50, 48): 261, (115, 32): 262, (105, 110): 263, (111, 110): 264, (114, 105): 265, (116, 32): 266, (116, 104): 267, (101, 258): 268, (257, 261): 269, (97, 110): 270, (97, 114): 271, (101, 260): 272, (121, 32): 273, (97, 108): 274, (267, 256): 275, (118, 268): 276, (119, 105): 277, (101, 114): 278, (264, 32): 279, (277, 102): 280, (82, 101): 281, (83, 280): 282, (111, 260): 283, (99, 104): 284, (269, 49): 285, (111, 109): 286, (98, 272): 287, (32, 275): 288, (97, 121): 289, (101, 110): 290, (111, 114): 291, (274, 32): 292, (101, 109): 293, (46, 10): 294, (265, 101): 295, (263, 103): 296, (269, 50): 297, (116, 105): 298, (289, 108): 299}
{0: b'\x00', 1: b'\x01', 2: b'\x02', 3: b'\x03', 4: b'\x04', 5: b'\x05', 6: b'\x06', 7: b'\x07', 8: b'\x08', 9: b'\t', 10: b'\n', 11: b'\x0b', 12: b'\x0c', 13: b'\r', 14: b'\x0e', 15: b'\x0f', 16: b'\x10', 17: b'\x11', 18: b'\x12', 19: b'\x13', 20: b'\x14', 21: b'\x15', 22: b'\

In [16]:
input = text
ids = tokenizer.encode(input)

out = tokenizer.decode(ids)
print(input == out)

True


In [17]:
import regex as re
GPT4_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
gpt4pat = re.compile(GPT4_SPLIT_PATTERN)
input = "Hello've world123 how's are you!!!?"
chunks = re.findall(gpt4pat, input)