In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Load both checkpoints as causal-LM models, then the tokenizer that ships with each.
# Load order matches the names: original first, modified second.
model_ori, model_mod = (
    AutoModelForCausalLM.from_pretrained(checkpoint)
    for checkpoint in ('meta-llama/Llama-2-7b-hf', 'indonlp/cendol-llama2-7b-ind-vocab')
)

tokenizer_ori, tokenizer_mod = (
    AutoTokenizer.from_pretrained(checkpoint)
    for checkpoint in ('meta-llama/Llama-2-7b-hf', 'indonlp/cendol-llama2-7b-ind-vocab')
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

In [3]:
# (token, id) pairs of the original tokenizer, arranged by ascending token id.
ori_vocab = tokenizer_ori.vocab.items()
ori_vocab_list = list(ori_vocab)
ori_vocab_list.sort(key=lambda token_and_id: token_and_id[1])

In [4]:
# (token, id) pairs of the modified tokenizer, arranged by ascending token id.
mod_vocab = tokenizer_mod.vocab.items()
mod_vocab_list = list(mod_vocab)
mod_vocab_list.sort(key=lambda token_and_id: token_and_id[1])

torch.Size([4096])

In [29]:
import torch


tensor([[2., 1., 1.],
        [1., 2., 1.],
        [1., 1., 2.]])

In [44]:
# Check, for every entry of the modified vocabulary, that the matching row of the
# modified model's input-embedding matrix agrees with the original model:
#   * the unk row is compared with the original model's unk row,
#   * a token the original tokenizer also knows is compared with its original row,
#   * any other token is compared with the mean of the original embeddings of the
#     ids the original tokenizer encodes it into.
# c[True] / c[False] count the rows that pass / fail their comparison.
# NOTE(review): the enumerate index `i` is used as the embedding row, which assumes the
# ids in mod_vocab_list are contiguous and start at 0 — confirm.
c = {True: 0, False: 0}

with torch.inference_mode():
    # Look the two embedding matrices up once instead of on every iteration.
    embed_ori = model_ori.model.embed_tokens.weight.data
    embed_mod = model_mod.model.embed_tokens.weight.data

    for i, word in enumerate(mod_vocab_list):
        idx = tokenizer_ori.convert_tokens_to_ids(word[0])
        if i == tokenizer_mod.unk_token_id:
            is_match = (embed_ori[tokenizer_ori.unk_token_id, :] == embed_mod[i, :]).all().item()
        elif idx != tokenizer_ori.unk_token_id:
            # `idx` was produced by tokenizer_ori, so "unknown to the original vocabulary"
            # must be judged against tokenizer_ori's unk id, not tokenizer_mod's.
            is_match = (embed_ori[idx, :] == embed_mod[i, :]).all().item()
        else:
            # [1:] drops the first id returned by encode() (presumably BOS — confirm).
            subwords_idx = tokenizer_ori.encode(word[0])[1:]
            subword_embed = torch.stack(
                [embed_ori[subword_idx, :] for subword_idx in subwords_idx], dim=0
            ).mean(dim=0)
            # This branch uses an absolute tolerance of 1e-7 instead of exact equality.
            is_match = ((subword_embed - embed_mod[i, :]).abs() < 1e-7).all().item()
        c[is_match] += 1
c

{True: 32000, False: 0}

In [56]:
# Display the first element of `word` (presumably the token string of the (token, id)
# pair left over from the last iteration of the verification loop — confirm).
word[0]

'▁membalas'

In [57]:
# Show the pieces the original tokenizer splits this token string into.
tokenizer_ori.tokenize(word[0])

['▁', '▁memb', 'al', 'as']

In [58]:
# Same split obtained through encode(): skip the first id, then map the remaining ids
# back to their token strings in one call.
tokenizer_ori.convert_ids_to_tokens(tokenizer_ori.encode(word[0])[1:])

['▁', '▁memb', 'al', 'as']

In [59]:
# Show how the original tokenizer splits a plain string that has no '▁' prefix.
tokenizer_ori.tokenize('kan')

['▁kan']

In [60]:
# Same string through encode(): skip the first id, then map the remaining ids back to
# their token strings in one call.
tokenizer_ori.convert_ids_to_tokens(tokenizer_ori.encode('kan')[1:])

['▁kan']

# Calculate Token Efficiency using the new tokenizer

In [26]:
# Accumulates one statistics dict per dataset; the cells below append to it and the
# final cell turns it into a DataFrame / CSV. Re-running this cell resets the list.
stats = []

In [27]:
# Tokenizers compared in this section: `tokenizer` comes from the Cendol checkpoint,
# `llama_tokenizer` from the Llama-2 checkpoint.
# NOTE(review): these are the same two checkpoints already loaded as tokenizer_mod /
# tokenizer_ori earlier in the notebook.
# NOTE(review): `use_auth_token` is deprecated in recent transformers releases in favour
# of `token` — confirm against the installed version before changing it.
tokenizer = AutoTokenizer.from_pretrained('indonlp/cendol-llama2-7b-ind-vocab')
llama_tokenizer = AutoTokenizer.from_pretrained('meta-llama/Llama-2-7b-hf', use_auth_token=True)



In [28]:
# Token counts for the dataset labelled 'flores-ind' (dev + devtest sentences):
# whitespace-split words, Cendol tokens, Llama-2 tokens, and how many of the Cendol
# tokens are the tokenizer's unknown token.
# NOTE(review): `ind_data` is not defined in the visible cells — make sure it is loaded
# before this cell on a fresh kernel.
data, stat = ind_data, { 'split_count': 0, 'cendol_count': 0, 'llama_count': 0, 'unk_count': 0 }
for sent in (data['dev']['sentence'] + data['devtest']['sentence']):
    split_toks = sent.split()
    cendol_toks = tokenizer.tokenize(sent)
    llama_toks = llama_tokenizer.tokenize(sent)
    stat['split_count'] += len(split_toks)
    stat['cendol_count'] += len(cendol_toks)
    stat['llama_count'] += len(llama_toks)
    # Count unknown tokens in this sentence's Cendol tokenization. The original read
    # `lang_toks`, a name this cell never assigns.
    stat['unk_count'] += cendol_toks.count(tokenizer.unk_token)
stat['dset'] = 'flores-ind'
stats.append(stat)
stat

{'split_count': 38880,
 'cendol_count': 56496,
 'llama_count': 106644,
 'unk_count': 0,
 'dset': 'flores-ind'}

In [36]:
# Token counts over the `text` column of all three splits of the default
# 'indonlp/nusatranslation_mt' configuration, stored under the label 'nt-ind'.
# NOTE(review): the recorded output of this cell is identical to the 'nt-jav' row of the
# summary table at the end of the notebook — confirm that `text` in the default
# configuration really is the Indonesian side before trusting the 'nt-ind' label.
data, stat = datasets.load_dataset('indonlp/nusatranslation_mt'), { 'split_count': 0, 'cendol_count': 0, 'llama_count': 0, 'unk_count': 0 }
for sent in (data['train']['text'] + data['validation']['text'] + data['test']['text']):
    split_toks = sent.split()
    cendol_toks = tokenizer.tokenize(sent)
    llama_toks = llama_tokenizer.tokenize(sent)
    stat['split_count'] += len(split_toks)
    stat['cendol_count'] += len(cendol_toks)
    stat['llama_count'] += len(llama_toks)
    # Count unknown tokens in this sentence's Cendol tokenization. The original read
    # `lang_toks`, a name this cell never assigns.
    stat['unk_count'] += cendol_toks.count(tokenizer.unk_token)
stat['dset'] = 'nt-ind'
stats.append(stat)
stat

{'split_count': 211169,
 'cendol_count': 474522,
 'llama_count': 567920,
 'unk_count': 0,
 'dset': 'nt-ind'}

In [33]:
# Token counts over the `text_1` column of every listed NusaTranslation MT subset.
# One dict per subset is appended to `stats`; its label is 'nt-' plus the first three
# characters that follow the 'nusatranslation_mt_' prefix of the subset name.
for subset in [
    'nusatranslation_mt_abs_ind_nusantara_t2t', 'nusatranslation_mt_btk_ind_nusantara_t2t', 'nusatranslation_mt_bew_ind_nusantara_t2t',
    'nusatranslation_mt_bhp_ind_nusantara_t2t', 'nusatranslation_mt_jav_ind_nusantara_t2t', 'nusatranslation_mt_mad_ind_nusantara_t2t',
    'nusatranslation_mt_mak_ind_nusantara_t2t', 'nusatranslation_mt_min_ind_nusantara_t2t', 'nusatranslation_mt_mui_ind_nusantara_t2t',
    'nusatranslation_mt_rej_ind_nusantara_t2t', 'nusatranslation_mt_sun_ind_nusantara_t2t'
]:
    data = datasets.load_dataset('indonlp/nusatranslation_mt', name=subset)
    # The label depends only on the subset name, so set it once here; this also labels
    # a subset that turns out to contain no sentences.
    stat = {
        'dset': 'nt-' + subset.split('nusatranslation_mt_')[-1][:3],
        'split_count': 0, 'cendol_count': 0, 'llama_count': 0, 'unk_count': 0
    }
    for sent in (data['train']['text_1'] + data['validation']['text_1'] + data['test']['text_1']):
        split_toks = sent.split()
        cendol_toks = tokenizer.tokenize(sent)
        llama_toks = llama_tokenizer.tokenize(sent)
        stat['split_count'] += len(split_toks)
        stat['cendol_count'] += len(cendol_toks)
        stat['llama_count'] += len(llama_toks)
        # Count unknown tokens in this sentence's Cendol tokenization. The original read
        # `lang_toks`, a name this cell never assigns.
        stat['unk_count'] += cendol_toks.count(tokenizer.unk_token)
    stats.append(stat)

In [37]:
import pandas as pd
# Build the summary table once, write it to disk without the index column, then show it.
efficiency_df = pd.DataFrame(stats)
efficiency_df.to_csv('cendol_efficiency.csv', index=False)
efficiency_df

Unnamed: 0,split_count,cendol_count,llama_count,unk_count,dset
0,38880,56496,106644,0,flores-ind
1,38336,71046,90340,0,nt-abs
2,217281,493099,553066,0,nt-btk
3,213368,479957,610043,0,nt-bew
4,33597,78234,87223,0,nt-bhp
5,211169,474522,567920,0,nt-jav
6,215155,538931,604589,0,nt-mad
7,194351,543043,593848,0,nt-mak
8,214370,466615,578901,0,nt-min
9,37959,73366,97727,0,nt-mui
