# Using BPE From Tiktoken

In [1]:
from functools import lru_cache
from collections import deque,Counter
from transformers import GPT2Tokenizer,GPT2TokenizerFast
from importlib.metadata import version
import os,json,tiktoken,transformers,regex as re
print(f"Tiktoken version: {version('tiktoken')}")

Tiktoken version: 0.12.0


In [2]:
# Load tiktoken's 'gpt2' encoding and encode a short two-line sample.
tik_tokenizer=tiktoken.get_encoding('gpt2')
text='''Hello, world.
Is this a test?'''
# allowed_special names <|endoftext|> as permitted input; this sample does not contain it.
integers=tik_tokenizer.encode(text,allowed_special={'<|endoftext|>'})
# Bare last expression so the notebook displays the token ids.
integers

[15496, 11, 995, 13, 198, 3792, 428, 257, 1332, 30]

In [3]:
# Decode the ids back to text to check the round trip against the input sample.
strings=tik_tokenizer.decode(integers)
strings

'Hello, world.\nIs this a test?'

In [4]:
# Display the n_vocab attribute of the loaded 'gpt2' encoding.
tik_tokenizer.n_vocab

50257

# Using BPE Via HuggingFace Transformers

In [5]:
# Record the installed transformers version alongside the benchmark results.
transformers.__version__

'5.0.0rc1'

# Quick Performance Benchmark
## Original OpenAI GPT-2 Tokenizer

In [6]:
@lru_cache()
def bytes_to_unicode():
    """Build the byte -> unicode-character table used by the BPE encoder.

    Byte values in the ranges '!'..'~', '¡'..'¬' and '®'..'ÿ' map to the
    character with the same code point. Every remaining byte value is assigned,
    in ascending byte order, the next unused code point starting at 256.

    Returns:
        dict mapping each of the 256 byte values (int) to a one-character
        string. The result is memoised by ``lru_cache``, so all callers share
        the same dict object.
    """
    kept=[*range(ord('!'),ord('~')+1),*range(ord('¡'),ord('¬')+1),*range(ord('®'),ord('ÿ')+1)]
    byte_values=list(kept)
    code_points=list(kept)
    shift=0
    for byte in range(256):
        if byte in kept:
            continue
        # Bytes outside the kept ranges are remapped to 256, 257, ... in order.
        byte_values.append(byte)
        code_points.append(256+shift)
        shift+=1
    return {byte:chr(point) for byte,point in zip(byte_values,code_points)}
def get_pairs(word):
    """Return the set of adjacent symbol pairs in ``word``.

    Args:
        word: A sliceable sequence of symbols (e.g. a tuple of strings or a
            plain string).

    Returns:
        set of ``(left, right)`` tuples, one per distinct adjacent pair. An
        empty or single-symbol ``word`` yields an empty set.
    """
    # zip stops at the shorter iterable, so an empty word needs no special
    # case (indexing word[0] would raise IndexError on it).
    return set(zip(word,word[1:]))
class Encoder:
    """Byte-level BPE encoder/decoder driven by a vocabulary and a merge list.

    Args:
        encoder: dict mapping token string -> token id.
        bpe_merges: ordered sequence of ``(left, right)`` merge pairs; earlier
            entries have higher merge priority.
        errors: error-handling mode passed to ``bytes.decode`` in ``decode``.
    """
    def __init__(self,encoder,bpe_merges,errors='replace'):
        self.encoder=encoder
        self.decoder={token_id:token for token,token_id in self.encoder.items()}
        self.errors=errors
        self.byte_encoder=bytes_to_unicode()
        self.byte_decoder={char:byte for byte,char in self.byte_encoder.items()}
        # Position in the merge list is the rank: lower rank merges first.
        self.bpe_ranks={merge:rank for rank,merge in enumerate(bpe_merges)}
        # Memo of bpe() results, keyed by the byte-mapped token string.
        self.cache={}
        self.pat=re.compile(r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""")
    def bpe(self,token):
        """Apply the ranked merges to ``token``.

        Returns the resulting symbols joined by single spaces. A token with
        fewer than two characters is returned unchanged (and not cached).
        """
        try:
            return self.cache[token]
        except KeyError:
            pass
        symbols=tuple(token)
        candidates=get_pairs(symbols)
        if not candidates:
            return token
        unranked=float('inf')
        while True:
            best=min(candidates,key=lambda pair:self.bpe_ranks.get(pair,unranked))
            if best not in self.bpe_ranks:
                # No remaining adjacent pair is a known merge.
                break
            left,right=best
            merged=[]
            pos=0
            last=len(symbols)-1
            # Left-to-right scan that fuses every non-overlapping (left, right).
            while pos<=last:
                if pos<last and symbols[pos]==left and symbols[pos+1]==right:
                    merged.append(left+right)
                    pos+=2
                else:
                    merged.append(symbols[pos])
                    pos+=1
            symbols=tuple(merged)
            if len(symbols)==1:
                break
            candidates=get_pairs(symbols)
        joined=' '.join(symbols)
        self.cache[token]=joined
        return joined
    def encode(self,text):
        """Split ``text`` with the pre-tokenization regex and return token ids."""
        ids=[]
        for piece in self.pat.findall(text):
            # Map each UTF-8 byte to its stand-in character before merging.
            mapped=''.join(map(self.byte_encoder.__getitem__,piece.encode('utf-8')))
            for sub_token in self.bpe(mapped).split(' '):
                ids.append(self.encoder[sub_token])
        return ids
    def decode(self,tokens):
        """Turn token ids back into text, decoding bytes with ``self.errors``."""
        chars=''.join(self.decoder[token] for token in tokens)
        raw=bytearray(self.byte_decoder[char] for char in chars)
        return raw.decode('utf-8',errors=self.errors)
def get_encoder(model_name,models_dir):
    """Load ``encoder.json`` and ``vocab.bpe`` for a model and build an Encoder.

    Args:
        model_name: Sub-directory of ``models_dir`` holding the two files.
        models_dir: Directory that contains ``model_name``.

    Returns:
        An ``Encoder`` built from the loaded vocabulary and merge list.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If ``encoder.json`` is not valid JSON.
    """
    model_dir=os.path.join(models_dir,model_name)
    # Explicit UTF-8 for both files: the default text encoding is
    # platform-dependent.
    with open(os.path.join(model_dir,'encoder.json'),'r',encoding='utf-8') as f:
        encoder=json.load(f)
    with open(os.path.join(model_dir,'vocab.bpe'),'r',encoding='utf-8') as f:
        bpe_data=f.read()
    # The first line is always skipped (presumably the '#version' header of
    # vocab.bpe). Blank lines are filtered instead of blindly dropping the last
    # element, so a file without a trailing newline keeps its final merge.
    merge_lines=bpe_data.split('\n')[1:]
    bpe_merges=[tuple(line.split()) for line in merge_lines if line.strip()]
    return Encoder(encoder=encoder,
                   bpe_merges=bpe_merges)
# Build the reference tokenizer from ./gpt2_model/encoder.json and ./gpt2_model/vocab.bpe.
orig_tokenizer=get_encoder(model_name='gpt2_model',
                           models_dir='.')
# raw_text is the shared benchmark input reused by the later %timeit cells.
with open('../1-adding_bells_whistles_to_training_loop/verdict.txt','r',encoding='utf-8') as f:
    raw_text = f.read()
# Encoder.bpe memoises per-token results in self.cache, so after the first pass
# the timed loops mostly measure warm-cache encoding.
%timeit orig_tokenizer.encode(raw_text)

6.93 ms ± 242 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Tiktoken OpenAI GPT-2 Tokenizer

In [7]:
# Time tiktoken's GPT-2 encoding on the same raw_text input.
%timeit tik_tokenizer.encode(raw_text)

1.43 ms ± 5.94 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


## HuggingFace OpenAI GPT-2 Tokenizer

In [8]:
# Load the pretrained 'gpt2' tokenizer through GPT2Tokenizer and time it on raw_text.
hf_tokenizer=GPT2Tokenizer.from_pretrained('gpt2')
%timeit hf_tokenizer(raw_text)['input_ids']

Token indices sequence length is longer than the specified maximum sequence length for this model (5145 > 1024). Running this sequence through the model will result in indexing errors


5.91 ms ± 207 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [9]:
# max_length=5145 equals the sequence length reported in the warning of the
# previous cell, so truncation=True does not shorten this input.
%timeit hf_tokenizer(raw_text,max_length=5145,truncation=True)['input_ids']

5.53 ms ± 42 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [10]:
# Same benchmark using the GPT2TokenizerFast class.
hf_tokenizer_fast=GPT2TokenizerFast.from_pretrained('gpt2')
%timeit hf_tokenizer_fast(raw_text)['input_ids']

Token indices sequence length is longer than the specified maximum sequence length for this model (5145 > 1024). Running this sequence through the model will result in indexing errors


5.52 ms ± 27.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [11]:
# Fast tokenizer with max_length set to the full sequence length (5145).
%timeit hf_tokenizer_fast(raw_text,max_length=5145,truncation=True)['input_ids']

5.52 ms ± 31.5 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


# GPT-2 Tokenizer

In [12]:
class BPETokenizerSimple:
    """Small BPE tokenizer that can be trained on text or loaded from GPT-2 files.

    Attributes (all plain dicts):
        vocab: token id -> token string.
        inverse_vocab: token string -> token id.
        bpe_merges: ``(id, id)`` pair -> merged token id. Filled by ``train``
            and ``load_vocab_and_merges``.
        bpe_ranks: ``(str, str)`` pair -> merge rank (lower merges first).
            Filled by ``load_vocab_and_merges_from_openai``. When non-empty it
            takes precedence over ``bpe_merges`` in ``tokenize_with_bpe``.
    """
    def __init__(self):
        self.vocab={}
        self.inverse_vocab={}
        self.bpe_merges={}
        self.bpe_ranks={}
    def train(self,text,vocab_size,allowed_special={'<|endoftext|>'}):
        """Learn BPE merges from ``text``.

        Args:
            text: Training text. Every space except one at index 0 is replaced
                by 'Ġ'; a space at index 0 is dropped.
            vocab_size: Upper bound on the number of token ids. Merging stops
                earlier when no adjacent pair is left.
            allowed_special: Iterable of special tokens to add to the
                vocabulary, or a falsy value for none. The default set is only
                iterated, never mutated.
        """
        # Start from a clean merge state so training again, or training after
        # loading GPT-2 files, cannot mix in stale merges or ranks.
        self.bpe_merges={}
        self.bpe_ranks={}
        processed_text=[]
        for i,char in enumerate(text):
            if char==' ' and i!=0:
                processed_text.append('Ġ')
            if char!=' ':
                processed_text.append(char)
        processed_text=''.join(processed_text)
        # Ids 0..255 are the first 256 code points; further characters seen in
        # the text follow in sorted order.
        unique_chars=[chr(i) for i in range(256)]
        unique_chars.extend(char for char in sorted(set(processed_text)) if char not in unique_chars)
        if 'Ġ' not in unique_chars:
            unique_chars.append('Ġ')
        self.vocab={i:char for i,char in enumerate(unique_chars)}
        self.inverse_vocab={char:i for i,char in self.vocab.items()}
        if allowed_special:
            for token in allowed_special:
                if token not in self.inverse_vocab:
                    new_id=len(self.vocab)
                    self.vocab[new_id]=token
                    self.inverse_vocab[token]=new_id
        token_ids=[self.inverse_vocab[char] for char in processed_text]
        for new_id in range(len(self.vocab),vocab_size):
            pair_id=self.find_freq_pair(token_ids,mode='most')
            if pair_id is None:
                break
            token_ids=self.replace_pair(token_ids,pair_id,new_id)
            self.bpe_merges[pair_id]=new_id
        # Merges are stored in creation order, so both halves of every pair are
        # already present in vocab when its merged string is built.
        for (p0,p1),new_id in self.bpe_merges.items():
            merged_token=self.vocab[p0]+self.vocab[p1]
            self.vocab[new_id]=merged_token
            self.inverse_vocab[merged_token]=new_id
    def load_vocab_and_merges_from_openai(self,vocab_path,bpe_merges_path):
        """Load a GPT-2 style ``encoder.json`` / ``vocab.bpe`` pair.

        Args:
            vocab_path: JSON file mapping token string -> token id.
            bpe_merges_path: Text file with one ``left right`` merge per line;
                a first line starting with '#' is skipped.

        Raises:
            KeyError: If 'Ċ' is not at id 198, '<|endoftext|>' is not at id
                50256, or '\\r' is absent and id 201 is not in the vocabulary.
        """
        with open(vocab_path,'r',encoding='utf-8') as file:
            loaded_vocab=json.load(file)
            self.vocab={int(v):k for k,v in loaded_vocab.items()}
            self.inverse_vocab={k:int(v) for k,v in loaded_vocab.items()}
        if 'Ċ' not in self.inverse_vocab or self.inverse_vocab['Ċ']!=198:
            raise KeyError("Vocabulary missing GPT-2 newline glyph 'Ċ' at id 198.")
        if '<|endoftext|>' not in self.inverse_vocab or self.inverse_vocab['<|endoftext|>']!=50256:
            raise KeyError('Vocabulary missing <|endoftext|> at id 50256.')
        # Alias the literal control characters so encode() can look them up directly.
        if '\n' not in self.inverse_vocab:
            self.inverse_vocab['\n']=self.inverse_vocab['Ċ']
        if '\r' not in self.inverse_vocab:
            if 201 in self.vocab:
                self.inverse_vocab['\r']=201
            else:
                raise KeyError('Vocabulary missing carriage return token at id 201.')
        self.bpe_ranks={}
        with open(bpe_merges_path,'r',encoding='utf-8') as file:
            lines=file.readlines()
        if lines and lines[0].startswith('#'):
            lines=lines[1:]
        rank=0
        for line in lines:
            pair=line.split()
            # Skip blank or malformed lines; unpacking an empty split would
            # otherwise raise ValueError.
            if len(pair)!=2:
                continue
            token1,token2=pair
            # Merges that mention tokens missing from the vocabulary are
            # ignored and do not consume a rank.
            if token1 in self.inverse_vocab and token2 in self.inverse_vocab:
                self.bpe_ranks[(token1,token2)]=rank
                rank+=1
    def encode(self,text,allowed_special=None):
        """Encode ``text`` into a list of token ids.

        Args:
            text: Input string.
            allowed_special: Collection of special-token strings that may
                appear in ``text`` and are emitted as single ids. ``None`` (or
                an empty collection) allows none.

        Returns:
            list of int token ids.

        Raises:
            ValueError: If ``text`` contains a vocabulary entry of the form
                ``<|...|>`` that is not allowed, if an allowed special token is
                missing from the vocabulary, or if a character is not in the
                vocabulary.
        """
        # One up-front scan covers the whole input, so the pieces between
        # special tokens do not need to be re-validated.
        specials_in_vocab=[tok for tok in self.inverse_vocab if tok.startswith('<|') and tok.endswith('|>')]
        permitted=() if allowed_special is None else allowed_special
        disallowed=[tok for tok in specials_in_vocab if tok in text and tok not in permitted]
        if disallowed:
            raise ValueError(f'Disallowed special tokens encountered in text: {disallowed}')
        token_ids=[]
        if allowed_special is not None and len(allowed_special)>0:
            # Longest first so a special token that is a prefix of another
            # cannot shadow the longer one.
            special_pattern='('+'|'.join(re.escape(tok) for tok in sorted(allowed_special,key=len,reverse=True))+')'
            last_index=0
            for match in re.finditer(special_pattern,text):
                prefix=text[last_index:match.start()]
                token_ids.extend(self._encode_plain(prefix))
                special_token=match.group(0)
                if special_token in self.inverse_vocab:
                    token_ids.append(self.inverse_vocab[special_token])
                else:
                    raise ValueError(f'Special token {special_token} not found in vocabulary.')
                last_index=match.end()
            text=text[last_index:]
        token_ids.extend(self._encode_plain(text))
        return token_ids
    def _encode_plain(self,text):
        """Encode text that holds no special tokens: pre-tokenize, then map to ids."""
        tokens=[]
        for part in re.split(r'(\r\n|\r|\n)',text):
            if part=='':
                continue
            if part=='\r\n':
                tokens.append('\r')
                tokens.append('\n')
                continue
            if part=='\r' or part=='\n':
                tokens.append(part)
                continue
            # NOTE: the pattern only matches runs of spaces and runs of
            # non-whitespace, so other whitespace (e.g. tabs) is dropped.
            pending_spaces=0
            for m in re.finditer(r'( +)|(\S+)',part):
                if m.group(1) is not None:
                    pending_spaces+=len(m.group(1))
                    continue
                word=m.group(2)
                if pending_spaces>0:
                    # Extra spaces come first as standalone 'Ġ' tokens and the
                    # last space is attached to the word, which keeps the
                    # original order ("a  b" -> 'a', 'Ġ', 'Ġb').
                    for _ in range(pending_spaces-1):
                        tokens.append('Ġ')
                    tokens.append('Ġ'+word)
                    pending_spaces=0
                else:
                    tokens.append(word)
            # Trailing spaces with no following word.
            for _ in range(pending_spaces):
                tokens.append('Ġ')
        token_ids=[]
        for tok in tokens:
            if tok in self.inverse_vocab:
                token_ids.append(self.inverse_vocab[tok])
            else:
                token_ids.extend(self.tokenize_with_bpe(tok))
        return token_ids
    def tokenize_with_bpe(self,token):
        """Split one pre-token into ids by applying BPE merges.

        Uses ``bpe_ranks`` (lowest rank first) when it is non-empty; otherwise
        repeatedly sweeps left to right merging any pair found in
        ``bpe_merges`` until a sweep merges nothing.

        Raises:
            ValueError: If a character of ``token`` is not in the vocabulary.
        """
        token_ids=[self.inverse_vocab.get(char,None) for char in token]
        if None in token_ids:
            missing_chars=[char for char,tid in zip(token,token_ids) if tid is None]
            raise ValueError(f'Characters not found in vocab: {missing_chars}')
        if not self.bpe_ranks:
            can_merge=True
            while can_merge and len(token_ids)>1:
                can_merge=False
                new_tokens=[]
                i=0
                while i<len(token_ids)-1:
                    pair=(token_ids[i],
                          token_ids[i+1])
                    if pair in self.bpe_merges:
                        merged_token_id=self.bpe_merges[pair]
                        new_tokens.append(merged_token_id)
                        i+=2
                        can_merge=True
                    else:
                        new_tokens.append(token_ids[i])
                        i+=1
                # Carry over the final id when it was not consumed by a merge.
                if i<len(token_ids):
                    new_tokens.append(token_ids[i])
                token_ids=new_tokens
            return token_ids
        symbols=[self.vocab[id_num] for id_num in token_ids]
        while True:
            pairs=set(zip(symbols,symbols[1:]))
            if not pairs:
                break
            min_rank=float('inf')
            bigram=None
            for p in pairs:
                r=self.bpe_ranks.get(p,float('inf'))
                if r<min_rank:
                    min_rank=r
                    bigram=p
            if bigram is None or bigram not in self.bpe_ranks:
                break
            first,second=bigram
            new_symbols=[]
            i=0
            while i<len(symbols):
                if i<len(symbols)-1 and symbols[i]==first and symbols[i+1]==second:
                    new_symbols.append(first+second)
                    i+=2
                else:
                    new_symbols.append(symbols[i])
                    i+=1
            symbols=new_symbols
            if len(symbols)==1:
                break
        merged_ids=[self.inverse_vocab[sym] for sym in symbols]
        return merged_ids
    def decode(self,token_ids):
        """Convert token ids back into text.

        A leading 'Ġ' becomes a space. Ids 198 and 201 are rendered as '\\n'
        and '\\r' only when the vocabulary holds the GPT-2 glyphs 'Ċ' and 'č'
        there, so a trained vocabulary (where those ids are ordinary
        characters) round-trips unchanged.

        Raises:
            ValueError: If an id is not in the vocabulary.
        """
        out=[]
        for tid in token_ids:
            if tid not in self.vocab:
                raise ValueError(f'Token ID {tid} not found in vocab.')
            tok=self.vocab[tid]
            if tok=='\n' or (tid==198 and tok=='Ċ'):
                out.append('\n')
            elif tok=='\r' or (tid==201 and tok=='č'):
                out.append('\r')
            elif tok.startswith('Ġ'):
                out.append(' '+tok[1:])
            else:
                out.append(tok)
        return ''.join(out)
    def save_vocab_and_merges(self,vocab_path,bpe_merges_path):
        """Write ``vocab`` and ``bpe_merges`` to two JSON files."""
        with open(vocab_path,'w',encoding='utf-8') as file:
            json.dump(self.vocab,file,ensure_ascii=False,indent=2)
        with open(bpe_merges_path,'w',encoding='utf-8') as file:
            merges_list=[{'pair':list(pair),
                          'new_id':new_id} for pair,new_id in self.bpe_merges.items()]
            json.dump(merges_list,file,ensure_ascii=False,indent=2)
    def load_vocab_and_merges(self,vocab_path,bpe_merges_path):
        """Load files written by ``save_vocab_and_merges``.

        Replaces the current vocabulary and merges, and clears ``bpe_ranks`` so
        ranks left by an earlier GPT-2 load cannot override the loaded merges.
        """
        with open(vocab_path,'r',encoding='utf-8') as file:
            loaded_vocab=json.load(file)
            self.vocab={int(k):v for k,v in loaded_vocab.items()}
            self.inverse_vocab={v:int(k) for k,v in loaded_vocab.items()}
        self.bpe_merges={}
        self.bpe_ranks={}
        with open(bpe_merges_path,'r',encoding='utf-8') as file:
            merges_list=json.load(file)
            for merge in merges_list:
                pair=tuple(merge['pair'])
                new_id=merge['new_id']
                self.bpe_merges[pair]=new_id
    def get_special_token_id(self,token):
        """Return the id of ``token`` or ``None`` when it is not in the vocabulary.

        Deliberately not memoised: it is a single dict lookup, and a cache on a
        method would keep returning stale answers after the vocabulary changes
        while also holding a reference to ``self``.
        """
        return self.inverse_vocab.get(token,None)
    @staticmethod
    def find_freq_pair(token_ids,mode='most'):
        """Return the most or least frequent adjacent id pair, or ``None`` if there is none.

        Raises:
            ValueError: If ``mode`` is neither 'most' nor 'least'.
        """
        pairs=Counter(zip(token_ids,token_ids[1:]))
        if not pairs:
            return None
        if mode=='most':
            return max(pairs.items(),key=lambda x:x[1])[0]
        elif mode=='least':
            return min(pairs.items(),key=lambda x:x[1])[0]
        else:
            raise ValueError("Invalid mode: Choose 'most', 'least'.")
    @staticmethod
    def replace_pair(token_ids,pair_id,new_id):
        """Return a new list with each non-overlapping ``pair_id`` replaced by ``new_id``."""
        dq=deque(token_ids)
        replaced=[]
        while dq:
            current=dq.popleft()
            if dq and (current,dq[0])==pair_id:
                replaced.append(new_id)
                # Consume the second half of the pair as well.
                dq.popleft()
            else:
                replaced.append(current)
        return replaced
# Load the GPT-2 vocabulary and merges into the from-scratch tokenizer and time
# it on the same raw_text used by the other benchmark cells.
tokenizer_gpt2=BPETokenizerSimple()
tokenizer_gpt2.load_vocab_and_merges_from_openai(vocab_path=os.path.join('gpt2_model','encoder.json'),
                                                 bpe_merges_path=os.path.join('gpt2_model','vocab.bpe'))
%timeit tokenizer_gpt2.encode(raw_text)

20.3 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
