In [8]:
from argparse import ArgumentParser
from json import dump
from logging import basicConfig, getLogger
from os import linesep, remove
from os.path import exists
from tempfile import NamedTemporaryFile
from typing import Dict, List, Tuple

from requests import get
from sentencepiece import SentencePieceProcessor
from tqdm import trange, tqdm

In [2]:
basicConfig()

In [14]:
class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Get SentencePiece
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)

    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

        # Merges
        merges = []
        for piece_l in tqdm(vocab.keys(), total=sp.GetPieceSize()):
            for piece_r in vocab.keys():
                merge = f"{piece_l}{piece_r}"
                piece_id = vocab.get(merge, None)
                if piece_id:
                    merges += [(piece_l, piece_r, piece_id)]
        merges = sorted(merges, key=lambda val: val[2])
        merges = [(val[0], val[1]) for val in merges]

        return vocab, merges

In [17]:
from itertools import chain
from multiprocessing import Pool, cpu_count

class SentencePieceExtractor:
    """
    Extractor implementation for SentencePiece trained models.
    https://github.com/google/sentencepiece
    """

    def __init__(self, model: str):
        # Get SentencePiece
        self.sp = SentencePieceProcessor()
        self.sp.Load(model)
    
    def extract(self) -> Tuple[Dict[str, int], List[Tuple]]:
        sp = self.sp
        vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())} 

        results = []
        with Pool(cpu_count()) as pool:
            results = pool.starmap(self.extract_merges, [(key, vocab,) for key in vocab.keys()])

        # Flatten and filter empty lists
        merges = list(chain.from_iterable(filter(None, results)))

        merges.sort(key=lambda val: val[2])
        merges = [(val[0], val[1]) for val in merges]
        
        return vocab, merges

    """
    Multiprocessing for merges.
    """
    @staticmethod
    def extract_merges(piece_l, vocab):
        merges = []
        for piece_r in vocab.keys():
            merge = f"{piece_l}{piece_r}"
            piece_id = vocab.get(merge, None)
            if piece_id:
                merges += [(piece_l, piece_r, piece_id)]
        
        return merges

In [9]:
import sys
sys.path.insert(0, "/projects/data/aamod/Libraries/tokenizers/bindings/python/scripts")
from sentencepiece_extractor import SentencePieceExtractor

In [15]:
sp = SentencePieceProcessor()
sp.Load("SP_Tokenizer/BG.model")

True

In [16]:
spe = SentencePieceExtractor("SP_Tokenizer/BG.model")

In [13]:
v, m = spe.extract()

100%|███████████████████████████████████████████████████████████████████████████████████████| 12000/12000 [00:00<00:00, 884765.38it/s]


In [17]:
v1, m1 = spe.extract()

100%|███████████████████████████████████████████████████████████████████████████████████████| 12000/12000 [00:00<00:00, 939373.80it/s]
100%|██████████████████████████████████████████████████████████████████████████████████████████| 12000/12000 [00:21<00:00, 563.81it/s]


In [18]:
v1 == v

True

In [19]:
m1 == m

True

In [23]:
v

{'#-': 0,
 '<PAD>': 1,
 '<unk>': 2,
 '<BOS>': 3,
 '<EOS>': 4,
 '\n': 5,
 '\r': 6,
 '<CLS>': 7,
 '<SEP>': 8,
 '<EOD>': 9,
 '<MASK>': 10,
 '<start_of_turn>': 11,
 '<end_of_turn>': 12,
 '〈|javascript|〉': 13,
 '〈|python|〉': 14,
 '〈|sql|〉': 15,
 '〈|shell|〉': 16,
 '〈|c|〉': 17,
 '〈|cpp|〉': 18,
 '〈|java|〉': 19,
 '〈|go|〉': 20,
 '<|reserved_0|>': 21,
 ' <|reserved_1|>': 22,
 ' <|reserved_2|>': 23,
 ' <|reserved_3|>': 24,
 ' <|reserved_4|>': 25,
 ' <|reserved_5|>': 26,
 ' <|reserved_6|>': 27,
 ' <|reserved_7|>': 28,
 ' <|reserved_8|>': 29,
 ' <|reserved_9|>': 30,
 ' <|reserved_10|>': 31,
 ' <|reserved_11|>': 32,
 ' <|reserved_12|>': 33,
 ' <|reserved_13|>': 34,
 ' <|reserved_14|>': 35,
 ' <|reserved_15|>': 36,
 ' <|reserved_16|>': 37,
 ' <|reserved_17|>': 38,
 ' <|reserved_18|>': 39,
 ' <|reserved_19|>': 40,
 ' <|reserved_20|>': 41,
 ' <|reserved_21|>': 42,
 ' <|reserved_22|>': 43,
 ' <|reserved_23|>': 44,
 ' <|reserved_24|>': 45,
 ' <|reserved_25|>': 46,
 ' <|reserved_26|>': 47,
 ' <|reserved_27

In [6]:
vocab = {sp.id_to_piece(index): index for index in trange(sp.GetPieceSize())}

100%|███████████████████████████████████████████████████████████████████████████████████████| 12000/12000 [00:00<00:00, 887104.50it/s]


In [7]:
from multiprocessing import Pool, cpu_count

In [11]:
def process_merges(piece_l, vocab):    
    merges = []
    for piece_r in vocab.keys():
        merge = f"{piece_l}{piece_r}"
        piece_id = vocab.get(merge, None)
        if piece_id:
            merges += [(piece_l, piece_r, piece_id)]
    
    return merges

In [12]:
## Think about sorting later, aim should be correctness

In [14]:
results = []
with Pool(cpu_count()) as pool:
    results = pool.starmap(process_merges, [(k, vocab,) for k in vocab.keys()])

In [15]:
merges = []
for i in results:
    if i != []:
        merges += i

In [16]:
merges = sorted(merges, key=lambda val: val[2])
merges = [(val[0], val[1]) for val in merges]

In [22]:
import json

In [23]:
with open("vocab_fast.json", "w") as f:
    json.dump(vocab, f)

In [24]:
type(merges)

list

In [25]:
merges[0]

('▁', '▁')

In [26]:
with open("merges_fast.txt", "w") as f:
    for item in merges:
        f.write(f"{item[0]} {item[1]}\n")

In [27]:
from tokenizers import SentencePieceBPETokenizer

In [28]:
tokenizer = SentencePieceBPETokenizer()

In [29]:
test = tokenizer.from_file("vocab_fast.json", "merges_fast.txt")

In [30]:
test

Tokenizer(vocabulary_size=128000, model=SentencePieceBPE, unk_token=<unk>, replacement=▁, add_prefix_space=True, dropout=None)

In [36]:
test.save_model("Final")

['Final/vocab.json', 'Final/merges.txt']

In [37]:
test.save("Final/tokenizer.json")

In [42]:
from transformers import PreTrainedTokenizerFast
tokenizer = PreTrainedTokenizerFast(tokenizer_file="Final/tokenizer.json")

In [44]:
tokenizer.save_pretrained("Paka_Final")

('Paka_Final/tokenizer_config.json',
 'Paka_Final/special_tokens_map.json',
 'Paka_Final/tokenizer.json')

In [46]:
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("Paka_Final", use_fast=True)

In [47]:
tokenizer.encode("How are tyou sff <BOS>ffds dsa<EOS>1!)")

[13576,
 6047,
 11980,
 5304,
 5315,
 5522,
 124378,
 3,
 5318,
 83741,
 5348,
 11795,
 4,
 124378,
 124400,
 64684]