# Main Idea Behind BPE
## Bits, Bytes

In [1]:
from functools import lru_cache
from collections import deque,Counter
import os
text = "This's some text."
# UTF-8 encode the string; a bytearray exposes every byte as an int in 0-255.
byte_ary = bytearray(text, encoding='utf-8')
byte_ary

bytearray(b"This\'s some text.")

In [2]:
ids = list(byte_ary)  # one integer ID per UTF-8 byte
print(ids, f'Number of characters: {len(text)}', f'Number of token IDs: {len(ids)}', sep='\n')

[84, 104, 105, 115, 39, 115, 32, 115, 111, 109, 101, 32, 116, 101, 120, 116, 46]
Number of characters: 17
Number of token IDs: 17


# BPE Implementation Walkthrough
## Training, Encoding, Decoding

In [3]:
class BPETokenizerSimple:
    """Minimal character-level byte-pair-encoding (BPE) tokenizer.

    ``train`` learns merge rules from a corpus; ``encode``/``decode`` convert
    between text and token ids. A space in front of a word is represented by
    the GPT-2 glyph 'Ġ' attached to the start of that word.
    """

    def __init__(self):
        self.vocab = {}          # token id -> token string
        self.inverse_vocab = {}  # token string -> token id
        self.bpe_merges = {}     # (left id, right id) -> merged id, in learning order

    def train(self, text, vocab_size, allowed_special={'<|endoftext|>'}):
        """Learn a vocabulary and merge rules from ``text``.

        The base vocabulary is ``chr(0)..chr(255)``, then any other characters
        of ``text`` in sorted order, then 'Ġ' if still missing, then every
        token of ``allowed_special`` not already present. One id per learned
        merge follows until ``vocab_size`` is reached or no adjacent pair is
        left to merge.

        Args:
            text: Training corpus. A space at position 0 is dropped; every
                other space becomes 'Ġ'.
            vocab_size: Target vocabulary size (no merges are learned if the
                base vocabulary already reaches it).
            allowed_special: Special tokens to reserve ids for (not mutated).
        """
        # Reset merges: the vocab is rebuilt below, but merges from a previous
        # train() call would otherwise survive and overwrite the new ids.
        self.bpe_merges = {}

        processed_text = []
        for i, char in enumerate(text):
            if char == ' ' and i != 0:
                processed_text.append('Ġ')
            if char != ' ':
                processed_text.append(char)
        processed_text = ''.join(processed_text)

        unique_chars = [chr(i) for i in range(256)]
        base_chars = set(unique_chars)
        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in base_chars)
        if 'Ġ' not in unique_chars:
            unique_chars.append('Ġ')

        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}

        if allowed_special:
            for token in allowed_special:
                if token not in self.inverse_vocab:
                    new_id = len(self.vocab)
                    self.vocab[new_id] = token
                    self.inverse_vocab[token] = new_id

        token_ids = [self.inverse_vocab[char] for char in processed_text]
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode='most')
            if pair_id is None:  # fewer than two tokens left: nothing to merge
                break
            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        # Merges are stored in learning order, so both halves of each merge
        # already exist in the vocab by the time it is materialized.
        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def encode(self, text):
        """Encode ``text`` into a list of token ids.

        The text is split on whitespace (so runs of spaces/tabs collapse) with
        newlines kept as their own tokens. Every word after the first that
        does not start with a newline gets a 'Ġ' prefix. A word found whole in
        the vocabulary maps to one id; otherwise it is merged via
        ``tokenize_with_bpe``.

        Raises:
            ValueError: if a character is not in the vocabulary.
        """
        tokens = []
        words = text.replace('\n', ' \n ').split()
        for i, word in enumerate(words):
            if i > 0 and not word.startswith('\n'):
                tokens.append('Ġ' + word)
            else:
                tokens.append(word)

        token_ids = []
        for token in tokens:
            if token in self.inverse_vocab:
                token_ids.append(self.inverse_vocab[token])
            else:
                token_ids.extend(self.tokenize_with_bpe(token))
        return token_ids

    def tokenize_with_bpe(self, token):
        """Split one pre-token into ids by applying ``bpe_merges``.

        Starts from one id per character and performs greedy left-to-right
        passes, replacing each known pair, until a pass merges nothing.

        Raises:
            ValueError: if a character of ``token`` is not in the vocabulary.
        """
        token_ids = [self.inverse_vocab.get(char, None) for char in token]
        if None in token_ids:
            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]
            raise ValueError(f'Characters not found in vocab: {missing_chars}')

        can_merge = True
        while can_merge and len(token_ids) > 1:
            can_merge = False
            new_tokens = []
            i = 0
            while i < len(token_ids) - 1:
                pair = (token_ids[i], token_ids[i + 1])
                if pair in self.bpe_merges:
                    new_tokens.append(self.bpe_merges[pair])
                    i += 2
                    can_merge = True
                else:
                    new_tokens.append(token_ids[i])
                    i += 1
            if i < len(token_ids):  # last element was not consumed by a merge
                new_tokens.append(token_ids[i])
            token_ids = new_tokens
        return token_ids

    def decode(self, token_ids):
        """Convert ids back to text; a leading 'Ġ' is rendered as a space.

        Raises:
            ValueError: if an id is not in the vocabulary.
        """
        decoded_string = ''
        for token_id in token_ids:
            if token_id not in self.vocab:
                raise ValueError(f'Token ID {token_id} not found in vocab.')
            token = self.vocab[token_id]
            if token.startswith('Ġ'):
                decoded_string += ' ' + token[1:]
            else:
                decoded_string += token
        return decoded_string

    def get_special_token_id(self, token):
        """Return the id of ``token`` (e.g. '<|endoftext|>'), or None if absent."""
        # Not wrapped in functools.lru_cache: a method-level cache keeps every
        # instance alive and keeps returning ids from before a later train().
        return self.inverse_vocab.get(token, None)

    @staticmethod
    def find_freq_pair(token_ids, mode='most'):
        """Return the most (or least) frequent adjacent id pair, or None if there is none.

        Raises:
            ValueError: if ``mode`` is neither 'most' nor 'least'.
        """
        pairs = Counter(zip(token_ids, token_ids[1:]))
        if not pairs:  # fewer than two ids; train() relies on this None
            return None
        if mode == 'most':
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == 'least':
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("Invalid mode: Choose 'most', 'least'.")

    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        """Return a copy of ``token_ids`` with non-overlapping ``pair_id`` occurrences replaced by ``new_id``."""
        dq = deque(token_ids)
        replaced = []
        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                dq.popleft()
            else:
                replaced.append(current)
        return replaced
# Train on the short story used throughout the earlier chapters.
with open('./1-adding_bells_whistles_to_training_loop/verdict.txt', encoding='utf-8') as f:
    text = f.read()

tokenizer = BPETokenizerSimple()
tokenizer.train(text, vocab_size=1000, allowed_special={'<|endoftext|>'})
len(tokenizer.vocab)

1000

In [4]:
len(tokenizer.bpe_merges)  # number of learned merge rules (one new vocab id each)

742

In [5]:
input_text = 'Jack embraced beauty through art, life.'
token_ids = tokenizer.encode(input_text)
print(token_ids, f'Number of characters: {len(input_text)}', f'Number of token IDs: {len(token_ids)}', tokenizer.decode(token_ids), sep='\n')

[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 44, 256, 326, 972, 46]
Number of characters: 39
Number of token IDs: 19
Jack embraced beauty through art, life.


In [6]:
# Show which text piece each id stands for.
for token_id in token_ids:
    print(token_id, '->', tokenizer.decode([token_id]))

424 -> Jack
256 ->  
654 -> em
531 -> br
302 -> ac
311 -> ed
256 ->  
296 -> be
97 -> a
465 -> ut
121 -> y
595 ->  through
841 ->  ar
116 -> t
44 -> ,
256 ->  
326 -> li
972 -> fe
46 -> .


In [7]:
# Round trip: encoding then decoding should reproduce the input.
tokenizer.decode(tokenizer.encode("This's some text."))

"This's some text."

In [8]:
import re,json
class BPETokenizerSimple:
    """Character-level BPE tokenizer that can be trained here or load GPT-2's files.

    Two merge tables are supported:

    * ``bpe_merges`` -- ``{(left_id, right_id): merged_id}``, filled by
      ``train`` or ``load_vocab_and_merges`` and applied greedily left to right.
    * ``bpe_ranks`` -- ``{(left_str, right_str): rank}``, filled by
      ``load_vocab_and_merges_from_openai`` and applied lowest rank first.

    ``tokenize_with_bpe`` uses ``bpe_ranks`` whenever it is non-empty. A space
    in front of a word is represented by the GPT-2 glyph 'Ġ'.
    """

    def __init__(self):
        self.vocab = {}          # token id -> token string
        self.inverse_vocab = {}  # token string -> token id (may also hold aliases)
        self.bpe_merges = {}     # (left id, right id) -> merged id
        self.bpe_ranks = {}      # (left str, right str) -> merge priority (lower first)

    def train(self, text, vocab_size, allowed_special={'<|endoftext|>'}):
        """Learn a vocabulary and merge rules from ``text``.

        The base vocabulary is ``chr(0)..chr(255)``, then any other characters
        of ``text`` in sorted order, then 'Ġ' if still missing, then every
        token of ``allowed_special`` not already present. One id per learned
        merge follows until ``vocab_size`` is reached or no adjacent pair is
        left to merge.

        Args:
            text: Training corpus. A space at position 0 is dropped; every
                other space becomes 'Ġ'.
            vocab_size: Target vocabulary size (no merges are learned if the
                base vocabulary already reaches it).
            allowed_special: Special tokens to reserve ids for (not mutated).
        """
        # Reset learned state. Stale merges would overwrite ids of the rebuilt
        # vocab, and stale ranks would make tokenize_with_bpe take the GPT-2 path.
        self.bpe_merges = {}
        self.bpe_ranks = {}

        processed_text = []
        for i, char in enumerate(text):
            if char == ' ' and i != 0:
                processed_text.append('Ġ')
            if char != ' ':
                processed_text.append(char)
        processed_text = ''.join(processed_text)

        unique_chars = [chr(i) for i in range(256)]
        base_chars = set(unique_chars)
        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in base_chars)
        if 'Ġ' not in unique_chars:
            unique_chars.append('Ġ')

        self.vocab = {i: char for i, char in enumerate(unique_chars)}
        self.inverse_vocab = {char: i for i, char in self.vocab.items()}

        if allowed_special:
            for token in allowed_special:
                if token not in self.inverse_vocab:
                    new_id = len(self.vocab)
                    self.vocab[new_id] = token
                    self.inverse_vocab[token] = new_id

        token_ids = [self.inverse_vocab[char] for char in processed_text]
        for new_id in range(len(self.vocab), vocab_size):
            pair_id = self.find_freq_pair(token_ids, mode='most')
            if pair_id is None:  # fewer than two tokens left
                break
            token_ids = self.replace_pair(token_ids, pair_id, new_id)
            self.bpe_merges[pair_id] = new_id

        # Merges are stored in learning order, so both halves of each merge
        # already exist in the vocab by the time it is materialized.
        for (p0, p1), new_id in self.bpe_merges.items():
            merged_token = self.vocab[p0] + self.vocab[p1]
            self.vocab[new_id] = merged_token
            self.inverse_vocab[merged_token] = new_id

    def load_vocab_and_merges_from_openai(self, vocab_path, bpe_merges_path):
        """Load GPT-2's ``encoder.json`` vocabulary and ``vocab.bpe`` merge list.

        Raw newline and carriage-return characters are aliased in
        ``inverse_vocab`` to ids 198 ('Ċ') and 201 respectively, so ``encode``
        can look them up directly. Any merges learned by ``train`` are dropped.

        Raises:
            KeyError: if the vocabulary lacks 'Ċ' at id 198, '<|endoftext|>'
                at id 50256, or (when no carriage-return entry exists) id 201.
        """
        with open(vocab_path, 'r', encoding='utf-8') as file:
            loaded_vocab = json.load(file)
        self.vocab = {int(v): k for k, v in loaded_vocab.items()}
        self.inverse_vocab = {k: int(v) for k, v in loaded_vocab.items()}

        if self.inverse_vocab.get('Ċ') != 198:
            raise KeyError("Vocabulary missing GPT-2 newline glyph 'Ċ' at id 198.")
        if self.inverse_vocab.get('<|endoftext|>') != 50256:
            raise KeyError('Vocabulary missing <|endoftext|> at id 50256.')
        if '\n' not in self.inverse_vocab:
            self.inverse_vocab['\n'] = self.inverse_vocab['Ċ']
        if '\r' not in self.inverse_vocab:
            if 201 not in self.vocab:
                raise KeyError('Vocabulary missing carriage return token at id 201.')
            self.inverse_vocab['\r'] = 201

        # Id-based merges from train()/load_vocab_and_merges() refer to a
        # different vocabulary; discard them together with any old ranks.
        self.bpe_merges = {}
        self.bpe_ranks = {}
        with open(bpe_merges_path, 'r', encoding='utf-8') as file:
            lines = file.readlines()
        if lines and lines[0].startswith('#'):  # "#version: ..." header line
            lines = lines[1:]
        rank = 0
        for line in lines:
            parts = line.split()
            if len(parts) != 2:  # blank or malformed line: skip instead of crashing
                continue
            token1, token2 = parts
            # Keep only merges whose halves exist; ranks stay dense over kept merges.
            if token1 in self.inverse_vocab and token2 in self.inverse_vocab:
                self.bpe_ranks[(token1, token2)] = rank
                rank += 1

    def encode(self, text, allowed_special=None):
        """Encode ``text`` into a list of token ids.

        Args:
            text: String to encode.
            allowed_special: Collection of special tokens (e.g.
                ``{'<|endoftext|>'}``) that may appear in ``text``; each is
                emitted as its single id. ``None`` (default) allows none.

        Returns:
            List of int token ids.

        Raises:
            ValueError: if ``text`` contains a vocabulary special token
                (``<|...|>``) that is not allowed, if an allowed special
                token is missing from the vocabulary, or if a character is not
                in the vocabulary.
        """
        specials_in_vocab = [tok for tok in self.inverse_vocab
                             if tok.startswith('<|') and tok.endswith('|>')]
        allowed = () if allowed_special is None else allowed_special
        disallowed = [tok for tok in specials_in_vocab if tok in text and tok not in allowed]
        if disallowed:
            raise ValueError(f'Disallowed special tokens encountered in text: {disallowed}')

        token_ids = []
        if allowed_special is not None and len(allowed_special) > 0:
            # Longest-first alternation so a longer special wins over a shorter one.
            special_pattern = '(' + '|'.join(
                re.escape(tok) for tok in sorted(allowed_special, key=len, reverse=True)) + ')'
            last_index = 0
            for match in re.finditer(special_pattern, text):
                prefix = text[last_index:match.start()]
                token_ids.extend(self.encode(prefix, allowed_special=None))
                special_token = match.group(0)
                if special_token not in self.inverse_vocab:
                    raise ValueError(f'Special token {special_token} not found in vocabulary.')
                token_ids.append(self.inverse_vocab[special_token])
                last_index = match.end()
            # The remainder is a substring of text, so the disallowed check above covers it.
            text = text[last_index:]

        for tok in self._split_into_words(text):
            if tok in self.inverse_vocab:
                token_ids.append(self.inverse_vocab[tok])
            else:
                token_ids.extend(self.tokenize_with_bpe(tok))
        return token_ids

    @staticmethod
    def _split_into_words(text):
        """Pre-tokenize: 'Ġ'-prefixed words, bare 'Ġ' for extra spaces, CR/LF as their own tokens."""
        tokens = []
        for part in re.split(r'(\r\n|\r|\n)', text):
            if part == '':
                continue
            if part == '\r\n':
                tokens.append('\r')
                tokens.append('\n')
                continue
            if part == '\r' or part == '\n':
                tokens.append(part)
                continue
            pending_spaces = 0
            # NOTE(review): whitespace other than ' ', CR and LF (e.g. tabs)
            # matches neither group and is silently dropped.
            for m in re.finditer(r'( +)|(\S+)', part):
                if m.group(1) is not None:
                    pending_spaces += len(m.group(1))
                    continue
                word = m.group(2)
                if pending_spaces > 0:
                    # Extra spaces go first as standalone 'Ġ'; the last space
                    # attaches to the word, so decoding keeps the original order.
                    tokens.extend(['Ġ'] * (pending_spaces - 1))
                    tokens.append('Ġ' + word)
                    pending_spaces = 0
                else:
                    tokens.append(word)
            tokens.extend(['Ġ'] * pending_spaces)  # trailing spaces
        return tokens

    def tokenize_with_bpe(self, token):
        """Split one pre-token string into vocabulary ids by applying merges.

        Starts from one id per character. If ``bpe_ranks`` is populated, the
        lowest-ranked adjacent pair is merged repeatedly (GPT-2 style);
        otherwise ``bpe_merges`` is applied in greedy left-to-right passes
        until a pass merges nothing.

        Raises:
            ValueError: if a character of ``token`` is not in the vocabulary.
        """
        token_ids = [self.inverse_vocab.get(char, None) for char in token]
        if None in token_ids:
            missing_chars = [char for char, tid in zip(token, token_ids) if tid is None]
            raise ValueError(f'Characters not found in vocab: {missing_chars}')

        if not self.bpe_ranks:
            can_merge = True
            while can_merge and len(token_ids) > 1:
                can_merge = False
                new_tokens = []
                i = 0
                while i < len(token_ids) - 1:
                    pair = (token_ids[i], token_ids[i + 1])
                    if pair in self.bpe_merges:
                        new_tokens.append(self.bpe_merges[pair])
                        i += 2
                        can_merge = True
                    else:
                        new_tokens.append(token_ids[i])
                        i += 1
                if i < len(token_ids):  # last element was not consumed by a merge
                    new_tokens.append(token_ids[i])
                token_ids = new_tokens
            return token_ids

        symbols = [self.vocab[id_num] for id_num in token_ids]
        while len(symbols) > 1:
            pairs = set(zip(symbols, symbols[1:]))
            # Ranks are unique, so the best known pair is unambiguous.
            bigram = min(pairs, key=lambda p: self.bpe_ranks.get(p, float('inf')))
            if bigram not in self.bpe_ranks:  # no adjacent pair is mergeable
                break
            first, second = bigram
            new_symbols = []
            i = 0
            while i < len(symbols):
                if i < len(symbols) - 1 and symbols[i] == first and symbols[i + 1] == second:
                    new_symbols.append(first + second)
                    i += 2
                else:
                    new_symbols.append(symbols[i])
                    i += 1
            symbols = new_symbols
        return [self.inverse_vocab[sym] for sym in symbols]

    def decode(self, token_ids):
        """Convert ids back to text; a leading 'Ġ' is rendered as a space.

        Raises:
            ValueError: if an id is not in the vocabulary.
        """
        # Resolve the newline/CR ids from this vocabulary instead of hard-coding
        # GPT-2's 198/201: a trained vocab maps LF to 10 and uses 198/201 for
        # 'Æ'/'É', while load_vocab_and_merges_from_openai aliases LF/CR to 198/201.
        newline_id = self.inverse_vocab.get('\n')
        cr_id = self.inverse_vocab.get('\r')
        out = []
        for tid in token_ids:
            if tid not in self.vocab:
                raise ValueError(f'Token ID {tid} not found in vocab.')
            tok = self.vocab[tid]
            if tok == '\n' or tid == newline_id:
                out.append('\n')
            elif tok == '\r' or tid == cr_id:
                out.append('\r')
            elif tok.startswith('Ġ'):
                out.append(' ' + tok[1:])
            else:
                out.append(tok)
        return ''.join(out)

    def save_vocab_and_merges(self, vocab_path, bpe_merges_path):
        """Write ``vocab`` and ``bpe_merges`` as JSON (``bpe_ranks`` is not saved)."""
        with open(vocab_path, 'w', encoding='utf-8') as file:
            json.dump(self.vocab, file, ensure_ascii=False, indent=2)
        with open(bpe_merges_path, 'w', encoding='utf-8') as file:
            merges_list = [{'pair': list(pair), 'new_id': new_id}
                           for pair, new_id in self.bpe_merges.items()]
            json.dump(merges_list, file, ensure_ascii=False, indent=2)

    def load_vocab_and_merges(self, vocab_path, bpe_merges_path):
        """Load files written by ``save_vocab_and_merges``, replacing any previous merges/ranks."""
        with open(vocab_path, 'r', encoding='utf-8') as file:
            loaded_vocab = json.load(file)
        # JSON object keys are strings; restore the integer ids.
        self.vocab = {int(k): v for k, v in loaded_vocab.items()}
        self.inverse_vocab = {v: int(k) for k, v in loaded_vocab.items()}
        # Stale ranks would force the GPT-2 merge path; stale merges would mix vocabularies.
        self.bpe_merges = {}
        self.bpe_ranks = {}
        with open(bpe_merges_path, 'r', encoding='utf-8') as file:
            merges_list = json.load(file)
        for merge in merges_list:
            self.bpe_merges[tuple(merge['pair'])] = merge['new_id']

    def get_special_token_id(self, token):
        """Return the id of ``token`` (e.g. '<|endoftext|>'), or None if absent."""
        # Not wrapped in functools.lru_cache: a method-level cache keeps every
        # instance alive and returns ids from before a later train()/load call.
        return self.inverse_vocab.get(token, None)

    @staticmethod
    def find_freq_pair(token_ids, mode='most'):
        """Return the most (or least) frequent adjacent id pair, or None if there is none.

        Raises:
            ValueError: if ``mode`` is neither 'most' nor 'least'.
        """
        pairs = Counter(zip(token_ids, token_ids[1:]))
        if not pairs:
            return None
        if mode == 'most':
            return max(pairs.items(), key=lambda x: x[1])[0]
        elif mode == 'least':
            return min(pairs.items(), key=lambda x: x[1])[0]
        else:
            raise ValueError("Invalid mode: Choose 'most', 'least'.")

    @staticmethod
    def replace_pair(token_ids, pair_id, new_id):
        """Return a copy of ``token_ids`` with non-overlapping ``pair_id`` occurrences replaced by ``new_id``."""
        dq = deque(token_ids)
        replaced = []
        while dq:
            current = dq.popleft()
            if dq and (current, dq[0]) == pair_id:
                replaced.append(new_id)
                dq.popleft()
            else:
                replaced.append(current)
        return replaced
# Retrain with the extended class so encode() can accept <|endoftext|>.
tokenizer = BPETokenizerSimple()
tokenizer.train(text, vocab_size=1000, allowed_special={'<|endoftext|>'})
input_text = 'Jack embraced beauty through art, life.<|endoftext|>'
token_ids = tokenizer.encode(input_text, allowed_special={'<|endoftext|>'})
print(token_ids, f'Number of characters: {len(input_text)}', f'Number of token IDs: {len(token_ids)}', tokenizer.decode(token_ids), sep='\n')

[424, 256, 654, 531, 302, 311, 256, 296, 97, 465, 121, 595, 841, 116, 44, 256, 326, 972, 46, 257]
Number of characters: 52
Number of token IDs: 20
Jack embraced beauty through art, life.<|endoftext|>


In [9]:
# Show which text piece each id stands for (the special token is last).
for token_id in token_ids:
    print(token_id, '->', tokenizer.decode([token_id]))

424 -> Jack
256 ->  
654 -> em
531 -> br
302 -> ac
311 -> ed
256 ->  
296 -> be
97 -> a
465 -> ut
121 -> y
595 ->  through
841 ->  ar
116 -> t
44 -> ,
256 ->  
326 -> li
972 -> fe
46 -> .
257 -> <|endoftext|>


In [10]:
# Newlines are encoded as their own tokens and must survive the round trip.
tokenizer.decode(tokenizer.encode("This's some text with \n newline characters."))

"This's some text with \n newline characters."

## Loading Original GPT-2 BPE Tokenizer From OpenAI

In [11]:
tokenizer_gpt2 = BPETokenizerSimple()
tokenizer_gpt2.load_vocab_and_merges_from_openai(
    vocab_path=os.path.join('./5-comparing_various_bpe_implementations/gpt2_model', 'encoder.json'),
    bpe_merges_path=os.path.join('./5-comparing_various_bpe_implementations/gpt2_model', 'vocab.bpe'),
)
len(tokenizer_gpt2.vocab)

50257

In [12]:
# Encode with the pre-trained GPT-2 vocabulary and merge ranks.
input_text = "This's some text."
token_ids = tokenizer_gpt2.encode(input_text)
token_ids

[1212, 338, 617, 2420, 13]