In [None]:
import copy
import json
import os
import shutil

from tokenizers import Tokenizer
from tokenizers.models import BPE
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.models import BPE

In [None]:
# Base Llama 3 tokenizer whose vocab/merges are extended in the cells below.
original_tokenizer = AutoTokenizer.from_pretrained(
    pretrained_model_name_or_path="meta-llama/Meta-Llama-3-8B"
)

In [None]:
# Parse the backend tokenizer's full JSON serialization into a dict so that its
# "model" section can be edited and re-serialized later.
initial_tokenizer_json = json.loads(original_tokenizer.backend_tokenizer.to_str())

In [None]:
# Save a pretty-printed snapshot of the tokenizer JSON (still unmodified at this
# point in a top-to-bottom run) for inspection.
# encoding="utf-8" is required: ensure_ascii=False writes byte-level BPE symbols
# (e.g. "Ġ") verbatim, which a locale-default encoding such as cp1252 cannot encode.
with open("llama_tokenizer.json", "w", encoding="utf-8") as f:
    json.dump(initial_tokenizer_json, f, indent=2, ensure_ascii=False)

In [None]:
# Keep references to the original model vocab and merges for before/after comparison.
# The backend tokenizer is serialized and parsed once, not once per field.
_original_model = json.loads(original_tokenizer.backend_tokenizer.to_str())["model"]
initial_tok_vocab = _original_model["vocab"]
initial_tok_merges = _original_model["merges"]
del _original_model

In [None]:
# Read the extended merge list (one merge per line, trailing newlines kept for now).
# NOTE(review): assumes the merges file is UTF-8 encoded. An explicit encoding
# avoids silently mis-decoding non-ASCII byte-level symbols on non-UTF-8 locales.
with open("llama_3_ext.merges", encoding="utf-8") as f:
    new_merges = f.readlines()

In [None]:
# Drop the single trailing newline readlines() leaves on each line (the last line
# may have none); everything else in the line is kept verbatim.
merges = [line[:-1] if line.endswith("\n") else line for line in new_merges]

del new_merges

In [None]:
# Load the extended vocab (token -> id) and order its entries by ascending id.
# NOTE(review): assumes the vocab file is UTF-8 JSON. An explicit encoding avoids
# locale-dependent decoding of non-ASCII token strings.
with open("llama_3_ext.vocab", encoding="utf-8") as f:
    vocab = dict(sorted(json.load(f).items(), key=lambda item: item[1]))

In [None]:
# Llama 3 reserves ids 128000..128255 for its 256 special/added tokens
# (<|begin_of_text|> is 128000). Every extended-vocab id at or above the start of
# that range is shifted past it so new tokens never collide with special tokens.
SPECIAL_TOKENS_START = 128000
NUM_SPECIAL_TOKENS = 256

# Values are plain ints, so a comprehension suffices (no deepcopy needed); key
# order is preserved.
vocab_fixed = {
    token: idx + NUM_SPECIAL_TOKENS if idx >= SPECIAL_TOKENS_START else idx
    for token, idx in vocab.items()
}

In [None]:
# Swap the extended merges and id-shifted vocab into the parsed tokenizer JSON.
initial_tokenizer_json["model"].update(merges=merges, vocab=vocab_fixed)

In [None]:
# Compare model-vocab sizes before and after the extension.
# Reuses the vocab parsed earlier instead of re-serializing the whole tokenizer.
print("After change:", len(initial_tokenizer_json["model"]["vocab"]))
print("Before change:", len(initial_tok_vocab))

In [None]:
# Compare merge-list lengths before and after the extension.
# Reuses the merges parsed earlier instead of re-serializing the whole tokenizer.
print("After change:", len(initial_tokenizer_json["model"]["merges"]))
print("Before change:", len(initial_tok_merges))

In [None]:
# Load the original tokenizer
original_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3-8B")
tokenizer_json = original_tokenizer.backend_tokenizer.to_str()

tokenizer = Tokenizer.from_str(tokenizer_json)

added_tokens = {token: idx for token, idx in original_tokenizer.get_added_vocab().items()}
combined_vocab = {**added_tokens, **vocab_fixed}

bpe = BPE(combined_vocab, [(merge.split()[0], merge.split()[1]) for merge in merges])

new_tokenizer = Tokenizer(bpe)

if tokenizer.pre_tokenizer:
    new_tokenizer.pre_tokenizer = tokenizer.pre_tokenizer

if tokenizer.normalizer:
    new_tokenizer.normalizer = tokenizer.normalizer

if tokenizer.decoder:
    new_tokenizer.decoder = tokenizer.decoder

if tokenizer.post_processor:
    new_tokenizer.post_processor = tokenizer.post_processor

new_transformer_tokenizer = PreTrainedTokenizerFast(tokenizer_object=new_tokenizer)

new_transformer_tokenizer.save_pretrained("/maybe")

In [None]:
# Save the untouched original tokenizer so its config files can be copied next.
original_tokenizer.save_pretrained(save_directory="/temp_tok")

In [None]:
# Patch the freshly saved /maybe/tokenizer.json with the original Llama 3 pipeline
# components, then copy the original tokenizer config and special-tokens map next to it.


def _copy_json(src_path, dst_path):
    """Re-serialize the JSON file at `src_path` into `dst_path` (pretty-printed, UTF-8)."""
    with open(src_path, encoding="utf-8") as src:
        data = json.load(src)
    # The destination must be opened for writing. Opening it read-only (as before)
    # makes json.dump raise io.UnsupportedOperation.
    with open(dst_path, "w", encoding="utf-8") as dst:
        json.dump(data, dst, indent=2, ensure_ascii=False)


with open("/maybe/tokenizer.json", encoding="utf-8") as f:
    tok_json = json.load(f)

for component in ("pre_tokenizer", "normalizer", "decoder", "post_processor"):
    tok_json[component] = initial_tokenizer_json[component]

# encoding="utf-8" because ensure_ascii=False writes non-ASCII token strings verbatim.
with open("/maybe/tokenizer.json", "w", encoding="utf-8") as f:
    json.dump(tok_json, f, indent=2, ensure_ascii=False)

for config_name in ("tokenizer_config.json", "special_tokens_map.json"):
    _copy_json(f"/temp_tok/{config_name}", f"/maybe/{config_name}")

In [None]:
# Remove the temporary copy of the original tokenizer. Like `rm -rf`, ignore_errors=True
# tolerates a missing directory, but no shell is spawned and no `rm` binary is required.
shutil.rmtree("/temp_tok", ignore_errors=True)

In [None]:
# Reload the patched extended tokenizer from disk through the standard Auto* entry point.
tt = AutoTokenizer.from_pretrained(pretrained_model_name_or_path="/maybe")

In [None]:
# Display the reloaded extended tokenizer for comparison with the original.
tt

In [None]:
# Display the original Llama 3 tokenizer for side-by-side comparison.
original_tokenizer

In [None]:
# Invert the extended tokenizer's token -> id mapping and look up id 128000.
itos = {token_id: token for token, token_id in tt.get_vocab().items()}
itos[128000]

In [None]:
# Invert the ORIGINAL tokenizer's token -> id mapping and look up id 128000.
# This must read `original_tokenizer` (not the extended `tt`) for the comparison
# with `itos` to mean anything.
itos_original = {token_id: token for token, token_id in original_tokenizer.get_vocab().items()}
itos_original[128000]

In [None]:
# Encode the <|begin_of_text|> marker with the extended tokenizer.
tt.encode(text="<|begin_of_text|>")

In [None]:
# Reference encoding of the same string with the original tokenizer.
original_tokenizer.encode(text="<|begin_of_text|>")