In [42]:
import numpy as np
from collections  import defaultdict

In [None]:
class SimpleTokenizer:
    """Minimal byte-level Byte-Pair-Encoding (BPE) tokenizer.

    Token ids 0-255 are the raw byte values. Each learned merge gets the next
    free id starting at 256, so a tokenizer with ``vocabsize=N`` learns at most
    ``N - 256`` merges.

    Attributes:
        vocabsize: Target vocabulary size (256 byte tokens + learned merges).
        merges: Maps a pair of token ids ``(left, right)`` to the id that
            replaces it. Insertion order is the order the merges were learned.
        vocab: Maps every token id to the byte string it stands for.
    """

    def __init__(self, vocabsize=258) -> None:
        self.vocabsize = vocabsize
        self.merges = {}
        self.vocab = self._base_vocab()

    @staticmethod
    def _base_vocab():
        """Return the 256 single-byte tokens every vocabulary starts from."""
        return {i: bytes([i]) for i in range(256)}

    @staticmethod
    def get_stats(tokens):
        """Count how often each adjacent pair of tokens occurs.

        Args:
            tokens: Sequence of token ids.

        Returns:
            A ``defaultdict(int)`` mapping ``(left, right)`` to its count.
            It is empty when ``tokens`` has fewer than two elements.
        """
        counter = defaultdict(int)
        for pair in zip(tokens, tokens[1:]):
            counter[pair] += 1
        return counter

    def merge(self, tokens, maxpair, newtoken):
        """Replace every non-overlapping occurrence of ``maxpair`` by ``newtoken``.

        The scan runs left to right, so in ``[a, a, a]`` with pair ``(a, a)``
        only the first two elements are merged.

        Args:
            tokens: Sequence of token ids.
            maxpair: The ``(left, right)`` pair to replace.
            newtoken: The id written in place of each occurrence.

        Returns:
            A new list; ``tokens`` itself is not modified.
        """
        left, right = maxpair[0], maxpair[1]
        last = len(tokens) - 1
        newtokens = []
        i = 0
        while i < len(tokens):
            if i < last and tokens[i] == left and tokens[i + 1] == right:
                newtokens.append(newtoken)
                i += 2
            else:
                newtokens.append(tokens[i])
                i += 1
        return newtokens

    def train(self, text, verbose=True):
        """Learn up to ``vocabsize - 256`` merges from ``text``.

        Training always starts from a clean state. Without the reset, a second
        call would compute ``256 + len(self.merges)`` against stale entries and
        hand out an id that is already in use, corrupting ``merges``/``vocab``.

        Args:
            text: Training string; it is encoded to UTF-8 bytes first.
            verbose: When true (default), print progress messages.
        """
        tokens = list(text.encode("UTF-8"))
        if verbose:
            print(f"Initial token length: {len(tokens)}")

        # Drop anything learned by a previous call.
        self.merges = {}
        self.vocab = self._base_vocab()

        # A vocabsize below 256 yields an empty range, i.e. no merges.
        for _ in range(self.vocabsize - 256):
            stats = self.get_stats(tokens)
            if not stats:
                if verbose:
                    print("Not enough tokens, exiting early!")
                break

            # Ties go to the pair that was seen first in the token stream.
            maxpair = max(stats, key=stats.get)
            if verbose:
                print(f"Pair to merge: {maxpair}")

            newtoken = 256 + len(self.merges)
            tokens = self.merge(tokens, maxpair, newtoken)
            self.merges[maxpair] = newtoken
            # Both halves already exist in vocab (bytes or earlier merges),
            # so the vocabulary can be kept in sync as we go.
            self.vocab[newtoken] = self.vocab[maxpair[0]] + self.vocab[maxpair[1]]

    def encode(self, text):
        """Encode ``text`` into token ids using the learned merges.

        Merges are applied in the order they were learned: at each step the
        pair present in the sequence with the lowest merge id is replaced.

        Args:
            text: String to encode.

        Returns:
            List of token ids (empty for an empty string).
        """
        tokens = list(text.encode("UTF-8"))
        # Fewer than two tokens means there is no pair left to merge.
        while len(tokens) >= 2:
            stats = self.get_stats(tokens)
            # Unlearned pairs rank as +inf, so they are only picked when no
            # learned pair remains in the sequence.
            pair = min(stats, key=lambda p: self.merges.get(p, float("inf")))
            if pair not in self.merges:
                break
            tokens = self.merge(tokens, pair, self.merges[pair])
        return tokens

    def decode(self, input_ids):
        """Turn token ids back into text.

        Args:
            input_ids: Iterable of token ids present in ``self.vocab``.

        Returns:
            The decoded string. Byte sequences that are not valid UTF-8 are
            replaced with U+FFFD rather than raising.

        Raises:
            KeyError: If an id is not in the vocabulary.
        """
        bytes_seq = b"".join(self.vocab[i] for i in input_ids)
        return bytes_seq.decode("UTF-8", errors="replace")

    

In [107]:
# Step 1: Define the sample texts.
# Longer excerpt from Nathan Reed's blog post. It is kept under its own name
# for experimentation; previously it was assigned to `text` and immediately
# overwritten, so it never had any effect.
blog_excerpt = """Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄 The very name strikes fear and awe into the hearts of programmers worldwide. We all know we ought to "support Unicode" in our software (whatever that means—like using wchar_t for all the strings, right?). But Unicode can be abstruse, and diving into the thousand-page Unicode Standard plus its dozens of supplementary annexes, reports, and notes can be more than a little intimidating. I don't blame programmers for still finding the whole thing mysterious, even 30 years after Unicode's inception."""
# Short sample that is actually bound to `text`: mostly ASCII plus a few
# 4-byte emoji, so byte length and character length differ.
text = "Hello, world! This is a test. The quick brown fox jumps over the lazy dog. Unicode is important too: 😄 🅤🅝🅘🅒🅞🅓🅔"
print(f"Text: {text}")
print(f"Length in characters: {len(text)}")

Text: Hello, world! This is a test. The quick brown fox jumps over the lazy dog. Unicode is important too: 😄 🅤🅝🅘🅒🅞🅓🅔
Length in characters: 110


In [108]:
# Fit a fresh tokenizer on the sample text.
# vocabsize=258 is the class default: 256 byte tokens + 2 learned merges.
tokenizer = SimpleTokenizer(vocabsize=258)
tokenizer.train(text=text)

Initial token length: 134
Pair to merge: (240, 159)
Pair to merge: (256, 133)


In [109]:
# Round-trip check: encoding then decoding should give back the input string.
sample = "Hello, world"
encoded = tokenizer.encode(sample)
decoded = tokenizer.decode(encoded)
print(decoded)

Hello, world


In [105]:
# Display the token ids held in `encoded`.
# NOTE(review): this cell's execution count (105) is lower than the cells
# above it, so the saved output may be stale — re-run the notebook in order.
encoded

[256]