In [1]:
from collections import Counter, OrderedDict
import regex as re
import pdb
from tqdm import tqdm

In [2]:
# Toy setup for the exploratory cells below; everything is expressed as Unicode codepoints.
unk_token_id = 0  # id substituted for characters excluded from coverage

# 'p' (112) and 'x' (120): excluded from coverage due to low frequency during training,
# so they get replaced with unk_token_id.
unk_tokens = [ord(ch) for ch in "px"]

# Codepoints seen during training plus the unk id; 't' (116) is not in the vocabulary.
vocab = [ord(ch) for ch in " aelmps"] + [unk_token_id]

In [3]:
# Reserved ids for the special tokens. build_vocab() places the 256 byte-fallback
# ids directly after them, so these ids are expected to be 0..len-1.
LLAMA_SPECIAL_TOKENS = {
    '<|unk|>': 0,
    '<s>': 1,
    '</s>': 2,
    '<|pad|>': 3,
}
# Pre-tokenization split patterns. They rely on \p{...} Unicode classes, which is why
# the third-party `regex` module is imported as `re` at the top of the notebook.
GPT2_SPLIT_PATTERN = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
# Same as GPT2_SPLIT_PATTERN except that digits are matched one at a time (\p{N}{1}).
GPT2_SPLIT_PATTERN_SINGLE = r"""'(?:[sdmt]|ll|ve|re)| ?\p{L}+| ?\p{N}{1}| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"""
LLAMA_SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+""" # gpt4 with digits split

# Number of reserved ids; build_vocab() recomputes its own local copy of this value.
num_special_tokens = len(LLAMA_SPECIAL_TOKENS)
# Only the LLAMA pattern is compiled; train() uses it to split the text into chunks.
compiled_pattern = re.compile(LLAMA_SPLIT_PATTERN)
# Module-level value; build_vocab() and train() take their own character_coverage argument.
character_coverage=0.9995

In [4]:
# Toy sentence used by the exploratory cells and by the first build_vocab() call below.
text = "this is sample text"

In [5]:
codepoints = list(map(ord, text))  # one Unicode codepoint per character of the sample text
# codepoints

In [6]:
# Swap every codepoint listed in unk_tokens for unk_token_id; all others pass through.
codepoints = [unk_token_id if cp in unk_tokens else cp for cp in codepoints]
# codepoints

In [7]:
# [x if x in vocab else chr(x).encode('utf-8') for x in codepoints] # finding utf-8 encoding of codepoints not seen during training

In [8]:
# Codepoints outside the toy vocabulary are replaced by their UTF-8 encoding packed into a
# single big-endian integer (one int per character, not one per byte).
# byteorder is passed explicitly: it only became optional (defaulting to "big") in Python 3.11.
codepoints = [
    cp if cp in vocab else int.from_bytes(chr(cp).encode('utf-8'), "big")
    for cp in codepoints
]
# codepoints

In [9]:
def build_vocab(text, character_coverage=0.9995, special_tokens=None):
    """Build the base (pre-merge) vocabulary and the list of rare characters.

    Id layout of the returned vocabulary:
      * ``0 .. S-1``       the special tokens (``S = len(special_tokens)``),
      * ``S .. S+255``     one single-byte ``bytes`` token per byte value (UTF-8 fallback),
      * ``S+256 ..``       the covered characters of ``text``, most frequent first,
                           ties broken by character order.

    Args:
        text: Training text; character frequencies are counted over it.
        character_coverage: A character is "rare" when its count is below
            ``int((1 - character_coverage) * len(text))``.
        special_tokens: Mapping ``token string -> id``. Defaults to the notebook's
            ``LLAMA_SPECIAL_TOKENS``. The ids are expected to be ``0..S-1``, because
            the byte-fallback ids start at ``S``.

    Returns:
        ``(vocab_itos, rare_tokens)``: ``vocab_itos`` maps id -> token (``str`` or
        ``bytes``); ``rare_tokens`` is the list of characters excluded from coverage.
    """
    if special_tokens is None:
        # Looked up at call time so the function also works with a custom table.
        special_tokens = LLAMA_SPECIAL_TOKENS

    # Sort by descending frequency; break ties by character so the ids are deterministic.
    char_counts = Counter(text)
    by_frequency = sorted(char_counts.items(), key=lambda item: (-item[1], item[0]))

    # Split the characters into rare / covered using the frequency threshold.
    min_count = int((1 - character_coverage) * len(text))
    rare_tokens, unicode_vocab = [], []
    for char, freq in by_frequency:
        (rare_tokens if freq < min_count else unicode_vocab).append(char)

    num_special = len(special_tokens)
    vocab_itos = {token_id: token for token, token_id in special_tokens.items()}  # special tokens
    vocab_itos.update({num_special + byte: bytes([byte]) for byte in range(256)})  # utf-8 fallback bytes
    first_char_id = num_special + 256
    vocab_itos.update({first_char_id + i: char for i, char in enumerate(unicode_vocab)})  # covered characters
    return vocab_itos, rare_tokens

In [10]:
# Build the toy vocabulary from the sample sentence, then derive the reverse (token -> id) lookup.
vocab_itos, rare_tokens = build_vocab(text)
vocab_stoi = {token: token_id for token_id, token in vocab_itos.items()}
# vocab_itos

In [11]:
def get_token_ids(text, vocab_itos=None, rare_tokens=None, verbose=False):
    """Map each character of ``text`` to base-vocabulary ids (no merges applied).

    Per character:
      * listed in ``rare_tokens``        -> the unk id (0),
      * present in the vocabulary        -> its own id,
      * otherwise                        -> one id per byte of its UTF-8 encoding.

    Args:
        text: String to convert.
        vocab_itos: Mapping id -> token. When omitted, the notebook-level ``vocab_itos``
            is looked up at call time, so a retrained vocabulary is picked up instead
            of a snapshot taken when this function was defined.
        rare_tokens: Characters to replace with unk. When omitted, the notebook-level
            ``rare_tokens`` is looked up at call time.
        verbose: Print every character that falls back to UTF-8 bytes.

    Returns:
        List of integer token ids.

    Raises:
        KeyError: if a byte-fallback token is missing from ``vocab_itos``.
    """
    if vocab_itos is None:
        vocab_itos = globals()["vocab_itos"]
    if rare_tokens is None:
        rare_tokens = globals()["rare_tokens"]

    vocab_stoi = {token: token_id for token_id, token in vocab_itos.items()}
    rare = set(rare_tokens)  # constant-time membership instead of scanning a list per character
    unk_token_id = 0
    token_ids = []
    for char in text:
        if char in rare:
            token_ids.append(unk_token_id)
        elif char in vocab_stoi:
            token_ids.append(vocab_stoi[char])
        else:
            if verbose: print(f'utf-8 encoded : {char}')
            # Byte ids come from the vocabulary itself, not from a hard-coded
            # "number of special tokens" offset.
            token_ids.extend(vocab_stoi[bytes([byte])] for byte in char.encode("utf-8"))

    return token_ids

# Code from Karpathy

In [12]:
def get_counts_llama(ids, counts=None):
    """
    Tally adjacent id pairs of `ids`, ignoring any pair that touches unk_token_id.
    Example: [1, 2, 3, 1, 2] -> {(1, 2): 2, (2, 3): 1, (3, 1): 1}
    When `counts` is given, it is updated in place and returned; otherwise a new dict is created.
    """
    tally = counts if counts is not None else {}
    for left, right in zip(ids, ids[1:]):
        # pairs containing the unk id are never counted
        if left == unk_token_id or right == unk_token_id:
            continue
        key = (left, right)
        tally[key] = tally.get(key, 0) + 1
    return tally

def merge_llama(ids, pair, idx):
    """
    Return a copy of `ids` in which each left-to-right, non-overlapping occurrence of
    `pair` is replaced by the single id `idx`. An element equal to unk_token_id is always
    copied through unchanged and never starts a merge.
    """
    merged = []
    pos = 0
    last = len(ids) - 1
    while pos <= last:
        current = ids[pos]
        starts_pair = (
            current != unk_token_id
            and current == pair[0]
            and pos < last
            and ids[pos + 1] == pair[1]
        )
        if starts_pair:
            merged.append(idx)
            pos += 2
            continue
        merged.append(current)
        pos += 1
    return merged

In [13]:
# Load the training corpus. The context manager closes the file handle, and the explicit
# encoding avoids depending on the platform's default text encoding.
with open('tests/taylorswift.txt', encoding='utf-8') as corpus_file:
    # NOTE: readlines() keeps each line's trailing "\n", so joining with "\n" doubles every
    # line break. Left as-is so the corpus (and its length) matches earlier runs.
    text = "\n".join(corpus_file.readlines())
len(text)

186548

In [14]:
def train(text, required_vocab_size=350, verbose=False, character_coverage=0.9995):
    """Learn BPE merges on top of the base vocabulary produced by build_vocab().

    Steps:
      1. build the base vocabulary and the rare-character list from ``text``;
      2. split ``text`` into chunks with ``compiled_pattern`` (merges never cross chunks);
      3. convert every chunk to base ids, rare characters becoming the unk id;
      4. repeatedly merge the most frequent adjacent pair (pairs touching unk are
         ignored by get_counts_llama) until the vocabulary reaches
         ``required_vocab_size`` or no mergeable pair is left.

    Args:
        text: Training text.
        required_vocab_size: Target size of the final vocabulary; must be at least the
            size of the base vocabulary.
        verbose: Print one line per merge (and disable the progress bar).
        character_coverage: Forwarded to build_vocab().

    Returns:
        ``(vocab_itos, merges, rare_tokens)`` where ``merges`` maps ``(id, id) -> new id``.
        The vocabulary can be smaller than ``required_vocab_size`` if the text runs out
        of pairs to merge.

    Raises:
        AssertionError: if ``required_vocab_size`` is smaller than the base vocabulary.
    """
    vocab_itos, rare_tokens = build_vocab(text, character_coverage=character_coverage)
    vocab_len = len(vocab_itos)
    assert required_vocab_size >= vocab_len, f"{required_vocab_size} < {vocab_len}"
    num_merges = required_vocab_size - vocab_len
    if verbose: print(f"New Merges : {num_merges}")

    # split the text up into text chunks
    text_chunks = compiled_pattern.findall(text)

    # Base ids for each chunk. rare_tokens must be passed explicitly: otherwise
    # get_token_ids falls back to its default and the rare characters of THIS text are
    # byte-encoded instead of being replaced by the unk id.
    ids = [
        get_token_ids(chunk, vocab_itos=vocab_itos, rare_tokens=rare_tokens)
        for chunk in text_chunks
    ]

    # iteratively merge the most common pairs to create new tokens
    merges = {}  # (int, int) -> int
    for i in tqdm(range(num_merges), total=num_merges, disable=verbose):
        # count the number of times every consecutive pair appears
        counts = {}
        for chunk_ids in ids:
            # passing in counts will update it in place, adding up counts
            get_counts_llama(chunk_ids, counts)
        if not counts:
            # Every chunk is down to a single token (or only unk neighbours): nothing
            # left to merge, and max() on an empty dict would raise ValueError.
            if verbose: print(f"no more pairs to merge, stopping after {i} merges")
            break
        # find the pair with the highest count
        pair = max(counts, key=counts.get)
        # mint a new token: assign it the next available id
        idx = vocab_len + i
        # replace all occurrences of pair in ids with idx
        ids = [merge_llama(chunk_ids, pair, idx) for chunk_ids in ids]
        # save the merge
        merges[pair] = idx
        vocab_itos[idx] = vocab_itos[pair[0]] + vocab_itos[pair[1]]
        if verbose:
            print(f"merge {i+1}/{num_merges}: {pair} -> {idx} ({vocab_itos[idx]}) had {counts[pair]} occurrences")

    return vocab_itos, merges, rare_tokens

In [15]:
# character_coverage=1.0 makes build_vocab's frequency threshold int(0 * len(text)) == 0,
# so no character is classified as rare in this run.
vocab_itos,merges,rare_tokens=train(text,required_vocab_size=400,character_coverage=1.0)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 45/45 [00:02<00:00, 17.52it/s]


In [46]:
def encode(text, verbose=False):
    """Encode a string into a list of token ids using the trained notebook state.

    The text is first mapped to base ids with get_token_ids() (using the notebook-level
    ``vocab_itos`` and ``rare_tokens``). Then the learned ``merges`` are applied
    greedily, always picking the adjacent pair that was learned earliest (lowest merge
    id), until no adjacent pair is present in ``merges``.

    Args:
        text: String to encode; may be empty.
        verbose: Forwarded to get_token_ids() and used to report the pair on which
            merging stopped.

    Returns:
        List of integer token ids (empty for an empty string).
    """
    tokens = get_token_ids(text, vocab_itos=vocab_itos, rare_tokens=rare_tokens, verbose=verbose)
    while len(tokens) >= 2:
        counts = get_counts_llama(tokens)
        if not counts:
            # Only unk-adjacent pairs remain; min() on an empty dict would raise ValueError.
            break
        # Unknown pairs rank as +inf, so a learned pair always wins when one exists.
        pair = min(counts, key=lambda p: merges.get(p, float("inf")))
        if pair not in merges:
            if verbose: print(f"breaking out on pair : {pair}")
            break  # nothing else can be merged
        tokens = merge_llama(tokens, pair, merges[pair])
    return tokens

In [17]:
test = "My name is Taylor Swift and this is Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄"

In [18]:
# Encode the Unicode-heavy sample string; the result is decoded again in the cells below.
enc = encode(test,verbose=False)
# enc

In [50]:
def decode(ids):
    """Turn a list of token ids back into a string using the notebook-level ``vocab_itos``.

    Tokens are dispatched on their type rather than on an id threshold:
      * ``bytes`` tokens (UTF-8 fallback) are buffered, and each maximal run is decoded
        together, so multi-byte characters spread over several ids are reassembled.
        Invalid sequences become U+FFFD (``errors="replace"``).
      * ``str`` tokens (characters, merged tokens and special tokens such as
        ``<|unk|>``) are emitted as-is.

    Any pending byte run is flushed before a ``str`` token is emitted, so the output
    order always follows the id order.

    Args:
        ids: Iterable of integer token ids.

    Returns:
        The decoded string.

    Raises:
        KeyError: if an id is not present in ``vocab_itos``.
    """
    pieces = []
    pending_bytes = []

    def flush_bytes():
        # Decode the buffered byte run (if any) as one UTF-8 sequence.
        if pending_bytes:
            pieces.append(b"".join(pending_bytes).decode("utf-8", errors="replace"))
            pending_bytes.clear()

    for token_id in ids:
        token = vocab_itos[token_id]
        if isinstance(token, bytes):
            pending_bytes.append(token)
        else:
            flush_bytes()
            pieces.append(token)
    # flush out the last chunk
    flush_bytes()
    return "".join(pieces)

In [20]:
# vocab_itos

In [21]:
# Decode the ids produced above; the printed text should match the sample string.
print(decode(enc))

My name is Taylor Swift and this is Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄


In [22]:
# Show the round-tripped text and the original on consecutive lines for visual comparison.
# Two separate statements: a tuple of print() calls makes the cell echo `(None, None)`.
print(decode(encode(test)))
print(test)

My name is Taylor Swift and this is Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄
My name is Taylor Swift and this is Ｕｎｉｃｏｄｅ! 🅤🅝🅘🅒🅞🅓🅔‽ 🇺‌🇳‌🇮‌🇨‌🇴‌🇩‌🇪! 😄


(None, None)

In [23]:
# Round-trip check: encoding then decoding must reproduce the sample string exactly.
decode(encode(test))==test

True

In [25]:
# Print every line of `test` whose encode/decode round trip does not reproduce it.
for line in test.split('\n'):
    if decode(encode(line)) == line:
        continue  # lossless round trip, nothing to report
    print(line)

In [26]:
# Round-trip check on a second file. The context manager closes the file handle, and the
# explicit encoding avoids depending on the platform's default text encoding.
with open('toy.txt', encoding='utf-8') as toy_file:
    # readlines() keeps trailing "\n", so this join doubles line breaks (kept as in the original).
    test = "\n".join(toy_file.readlines())
decode(encode(test))==test

True

# Handling UNK tokens

In [28]:
# Retrain with train()'s default character_coverage (0.9995), so low-frequency characters
# end up in rare_tokens. This rebinds the notebook-level vocab_itos / merges / rare_tokens
# that encode() and decode() read.
vocab_itos,merges,rare_tokens=train(text,required_vocab_size=400)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 71/71 [00:04<00:00, 14.86it/s]


In [29]:
# Characters whose count fell below build_vocab's frequency threshold in the run above.
print(rare_tokens)

[';', 'j', '–', '$', 'q', '&', '?', 'Z', '—', '!', '•', '/', 'Q', 'X', '\t', 'é', '#', '@', '+', '£', '®', 'á', 'í', 'ñ', 'ö', '™']


In [33]:
# Force 'Y' to be treated as a rare character for the unk demos below.
# The membership guard keeps the cell idempotent: re-running it does not append duplicates.
if 'Y' not in rare_tokens:
    rare_tokens.append('Y')

In [48]:
# 'X' is in rare_tokens (see the printed list), so the first id is the unk id 0.
encode('Xenophobia')

[0, 361, 264, 289, 272, 264, 285, 263, 266]

In [51]:
# 'Y' was added to rare_tokens above, so it is encoded as unk and decodes to '<|unk|>'.
decode(encode('You'))

'<|unk|>ou'

In [38]:
# Manual walk-through of encode(), step 1: base ids before any merge is applied.
tokens = get_token_ids('You',vocab_itos=vocab_itos,rare_tokens=rare_tokens,verbose=False)
tokens

[0, 264, 275]

In [39]:
# Step 2: adjacent-pair counts; pairs that contain the unk id are skipped by get_counts_llama.
counts = get_counts_llama(tokens)
counts

{(264, 275): 1}

In [40]:
# Step 3: pick the pair with the lowest merge id (pairs absent from `merges` rank as +inf).
pair = min(counts, key=lambda p: merges.get(p, float("inf")))
pair

(264, 275)

In [41]:
# Step 4: apply that merge by hand, mirroring one iteration of encode()'s loop.
idx = merges[pair]
print(idx)
tokens = merge_llama(tokens, pair, idx)  # rebinds `tokens` to the merged list
print(tokens)

380
[0, 380]


In [44]:
# Debugging aid: toggles IPython's automatic post-mortem debugger (the recorded output shows it turned ON).
%pdb

Automatic pdb calling has been turned ON


In [52]:
# Look up the Unicode codepoint of '⁂' (the recorded output is 8258).
ord('⁂')

8258