In [1]:
import regex as re

In [2]:
class RegexTokenizer:
    def __init__(self):
        self.merge={}
        self.split_pattern=r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
        self.compiled_pattern=re.compile(self.split_pattern)
    def getstat(self,tokens,count):
        for p0,p1 in zip(tokens[:],tokens[1:]):
            count[(p0,p1)]=count.get((p0,p1),0)+1
        
    def mergepair(self,ids,pair,idx):
        new_tokens=[]
        i=0
        while i<len(ids):
            if i<len(ids)-1 and ids[i]==pair[0] and ids[i+1]==pair[1]:
                new_tokens.append(idx)
                i+=2
            else:
                new_tokens.append(ids[i])
                i+=1
        return new_tokens
        
    def train(self,text,vocab_size,verbose=False):
        if vocab_size>=256:
            word_chunks=re.findall(self.compiled_pattern,text)
            tokens=[list(ch.encode('utf-8')) for ch in word_chunks]
            for i in range(vocab_size-256):
                count={}
                for j in tokens:
                    self.getstat(j,count)
                if len(count):
                    maxpair=max(count,key=count.get)
                    tokens=[self.mergepair(t,maxpair,256+i) for t in tokens]
                    self.merge[maxpair]=256+i
                    if verbose:
                        print(f"Pair {maxpair} merged into {256+i}")
            vocab={k:bytes([k]) for k in range(256)}
            for k,v in self.merge.items():
                vocab[v]=vocab[k[0]]+vocab[k[1]]
            self.vocab=vocab
        else:
            print("The vocab size should be greater than or equal to 256")

    def encode(self,text):
        encoded_tokens=[]
        word_chunks=re.findall(self.compiled_pattern,text)
        for i in word_chunks:
            tokens=list(i.encode('utf-8'))
            count={}
            while(len(tokens)>=2): 
                self.getstat(i,count)
                min_index_pair=min(count,key=lambda x:self.merge.get(x,float('inf')))
                if min_index_pair not in self.merge:
                    break
                tokens=self.mergepair(tokens,min_index_pair,self.merge[min_index_pair])
            encoded_tokens.append(tokens)
        return encoded_tokens

    def decode(self,ids):
        text=b"".join(self.vocab[j] for x in ids for j in x)
        text=text.decode('utf-8',errors='replace')
        return text

    def show_merge_values(self):
        for k,v in self.merge.items():
            print(f'{k}:{self.vocab[v].decode("utf-8",errors="replace")}')

In [3]:
with open('taylorswift_wikipedia.txt','r') as f:
    text=f.read()

In [4]:
tokenizer=RegexTokenizer()
tokenizer.train(text,vocab_size=300,verbose=True)

Pair (101, 114) merged into 256
Pair (50, 48) merged into 257
Pair (111, 114) merged into 258
Pair (105, 110) merged into 259
Pair (101, 100) merged into 260
Pair (32, 116) merged into 261
Pair (111, 110) merged into 262
Pair (104, 101) merged into 263
Pair (32, 83) merged into 264
Pair (97, 114) merged into 265
Pair (97, 110) merged into 266
Pair (32, 65) merged into 267
Pair (261, 263) merged into 268
Pair (97, 108) merged into 269
Pair (114, 105) merged into 270
Pair (118, 260) merged into 271
Pair (115, 116) merged into 272
Pair (119, 105) merged into 273
Pair (32, 82) merged into 274
Pair (257, 49) merged into 275
Pair (32, 102) merged into 276
Pair (257, 50) merged into 277
Pair (32, 84) merged into 278
Pair (102, 116) merged into 279
Pair (97, 121) merged into 280
Pair (32, 34) merged into 281
Pair (273, 279) merged into 282
Pair (101, 116) merged into 283
Pair (264, 282) merged into 284
Pair (99, 104) merged into 285
Pair (98, 256) merged into 286
Pair (97, 116) merged into 287

In [5]:
encoded_values=tokenizer.encode("Hello how are you")
print(encoded_values)

[[72, 101, 108, 108, 111], [32, 104, 111, 119], [32, 97, 114, 101], [32, 121, 111, 117]]


In [6]:
tokenizer.decode(encoded_values)

'Hello how are you'

In [7]:
text1="hello world!!!? (안녕하세요!) lol123 😉"
text2=tokenizer.decode(tokenizer.encode(text1))
print(text1==text2)

True


In [8]:
tokenizer.show_merge_values()

(101, 114):er
(50, 48):20
(111, 114):or
(105, 110):in
(101, 100):ed
(32, 116): t
(111, 110):on
(104, 101):he
(32, 83): S
(97, 114):ar
(97, 110):an
(32, 65): A
(261, 263): the
(97, 108):al
(114, 105):ri
(118, 260):ved
(115, 116):st
(119, 105):wi
(32, 82): R
(257, 49):201
(32, 102): f
(257, 50):202
(32, 84): T
(102, 116):ft
(97, 121):ay
(32, 34): "
(273, 279):wift
(101, 116):et
(264, 282): Swift
(99, 104):ch
(98, 256):ber
(97, 116):at
(111, 109):om
(101, 115):es
(101, 110):en
(101, 109):em
(34, 46):".
(32, 40): (
(46, 10):.

(259, 103):ing
(108, 258):lor
(32, 77): M
(105, 103):ig
(32, 262): on
