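Wrap the [sentencepiece](https://github.com/google/sentencepiece) byte-pair-encoding (BPE) tokenizer in a small Python class with a nice API: build a training file, fit the encoder, then encode and decode text.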

In [1]:
# Limit on how many lines are copied into the BPE training file.
N_TRAIN_SENTENCES = 10_000_000
# Vocabulary size handed to the trainer; also embedded in the file names below.
VOCAB_SIZE = 10_000
# Base name of the model; BytePairEncoder appends `_{vocab_size}` to it.
MODEL_NAME = 'summarizer'
# Plain-text file that the training lines are written to.
TRAINING_FILE = f'summary_bpe_train_{VOCAB_SIZE}.txt'
# Directory whose entries are globbed as the raw input files.
DATA_DIRECTORY = '../data/preprocessed_stories'

# Make training data

In [2]:
import glob
import os
import random
import tqdm
from tqdm import tqdm_notebook as tqdm
# glob returns matches in a filesystem-dependent order; sorting gives the same
# listing on every machine and every run.
FILES = sorted(glob.glob(os.path.join(DATA_DIRECTORY, '*')))

In [3]:
len(FILES)

311971

In [4]:
# Sort first so the result does not depend on the order FILES is currently in
# (for example after re-running this cell), then shuffle with a fixed seed so
# the same files end up in the training sample on every run.
FILES.sort()
random.Random(0).shuffle(FILES)

In [5]:
def write_training_file(files, out_path, max_sentences):
    """Copy lines from ``files`` into ``out_path`` until ``max_sentences`` are written.

    Tabs in each line are replaced by single spaces. Files are read in the
    order given, and reading stops as soon as the limit is reached.

    Args:
        files: iterable of paths to read, one sentence per line.
        out_path: path of the training file to (over)write.
        max_sentences: maximum number of lines to write; 0 writes nothing.

    Returns:
        The number of lines actually written (less than ``max_sentences`` when
        the inputs run out first).
    """
    n_written = 0
    with open(out_path, 'w') as f_train:
        for file in tqdm(files):
            with open(file) as f:
                for line in f:
                    # Check before writing so exactly `max_sentences` lines are
                    # kept; returning here also closes both files via `with`.
                    if n_written >= max_sentences:
                        return n_written
                    f_train.write(line.replace('\t', ' '))
                    n_written += 1
    return n_written


n_sentences = write_training_file(FILES, TRAINING_FILE, N_TRAIN_SENTENCES)

HBox(children=(IntProgress(value=0, max=311971), HTML(value='')))

# Fit Encoder

In [15]:
import numpy as np
import sentencepiece as spm

class BytePairEncoder:
    """Small convenience wrapper around a sentencepiece BPE processor.

    An instance either wraps an already loaded ``processor`` or, when none is
    given, trains a new model from ``training_file`` and loads the result.

    Attributes:
        vocab_size: vocabulary size the model was (or will be) trained with.
        model_name: ``'{model_name}_{vocab_size}'``; used as the trainer's
            ``model_prefix``.
        training_file: path of the text file used for training, or None.
        model_file: path of the ``.model`` file.
        vocab_file: path of the ``.vocab`` file.
        model_type: always ``'bpe'``.
        processor: the underlying processor object used for all conversions.
    """

    def __init__(self, vocab_size, model_name, *, model_file=None, vocab_file=None,
                 training_file=None, processor=None, **kwargs):
        """Wrap ``processor`` or train a new model when it is None.

        Args:
            vocab_size: size of the vocabulary.
            model_name: base name; ``_{vocab_size}`` is appended to it.
            model_file: path of the model file; defaults to
                ``'{model_name}_{vocab_size}.model'``.
            vocab_file: path of the vocab file; defaults to
                ``'{model_name}_{vocab_size}.vocab'``.
            training_file: text file to train on; required when ``processor``
                is None.
            processor: an already loaded processor; skips training when given.
            **kwargs: extra ``--key=value`` flags forwarded to the trainer.

        Raises:
            ValueError: if both ``processor`` and ``training_file`` are None.
        """
        self.vocab_size = vocab_size
        self.model_name = f'{model_name}_{vocab_size}'
        self.training_file = training_file
        self.model_file = f'{self.model_name}.model' if model_file is None else model_file
        self.vocab_file = f'{self.model_name}.vocab' if vocab_file is None else vocab_file
        self.model_type = 'bpe'
        if processor is None:
            if training_file is None:
                raise ValueError('training_file cannot be None when processor is also None.')
            processor = self._fit(input=training_file, vocab_size=self.vocab_size,
                                  model_prefix=self.model_name, model_type=self.model_type,
                                  **kwargs)
        self.processor = processor

    def encode(self, text):
        """Return the token ids of ``text`` as a numpy array."""
        return np.array(self.processor.EncodeAsIds(text))

    def encode_as_pieces(self, text):
        """Return the sub-word pieces of ``text`` as produced by the processor."""
        return self.processor.EncodeAsPieces(text)

    def decode(self, ids):
        """Turn token ids back into text.

        ``ids`` may be a numpy array (as returned by ``encode``) or any plain
        sequence of ints; it is converted to a Python list before decoding.
        """
        return self.processor.DecodeIds(np.asarray(ids).tolist())

    def decode_pieces(self, pieces):
        """Turn a sequence of sub-word pieces back into text."""
        return self.processor.DecodePieces(pieces)

    @classmethod
    def from_files(cls, model_file, vocab_file):
        """Build an encoder from an existing model/vocab file pair.

        The vocabulary size is the number of lines in ``vocab_file``. The
        model name is ``model_file`` without its extension; a trailing
        ``_{vocab_size}`` is dropped because ``__init__`` appends it again, so
        files written by this class load back under their original name.

        Args:
            model_file: path of the ``.model`` file to load.
            vocab_file: path of the matching ``.vocab`` file.

        Returns:
            A ``BytePairEncoder`` wrapping the loaded processor.
        """
        # splitext only strips the final extension, so dots elsewhere in the
        # path (e.g. './x.model' or '../models/x.model') are preserved.
        model_name = os.path.splitext(model_file)[0]
        processor = cls._load_model(model_file)
        with open(vocab_file) as f:
            vocab_size = sum(1 for _ in f)
        suffix = f'_{vocab_size}'
        if model_name.endswith(suffix) and len(model_name) > len(suffix):
            model_name = model_name[:-len(suffix)]
        return cls(vocab_size=vocab_size, model_name=model_name, processor=processor,
                   model_file=model_file, vocab_file=vocab_file)

    @staticmethod
    def _load_model(filename):
        """Create a sentencepiece processor and load ``filename`` into it."""
        processor = spm.SentencePieceProcessor()
        processor.Load(filename)
        return processor

    def _fit(self, **kwargs):
        """Train a model with the given trainer flags and load the result.

        Each keyword becomes a ``--key=value`` flag; the assembled flag string
        is printed before training starts.
        """
        params = ' '.join([f'--{k}={v}' for k, v in kwargs.items()])
        print(params)
        spm.SentencePieceTrainer.Train(params)
        # NOTE(review): the trainer names its output after model_prefix, while
        # this loads self.model_file; the two presumably only coincide when
        # model_file was left at its default -- confirm before passing a
        # custom model_file together with training_file.
        processor = self._load_model(self.model_file)
        return processor

In [16]:
%%time
# No processor is passed, so constructing the encoder trains a new BPE model
# on TRAINING_FILE and then loads it; the trainer flags are printed below.
bpe = BytePairEncoder(
    vocab_size=VOCAB_SIZE,
    model_name=MODEL_NAME,
    training_file=TRAINING_FILE, 
)

--input=summary_bpe_train_10000.txt --vocab_size=10000 --model_prefix=summarizer_10000 --model_type=bpe
CPU times: user 2min 40s, sys: 40.2 s, total: 3min 20s
Wall time: 3min 45s


In [17]:
bpe.model_name

'summarizer_10000'

In [18]:
sample_text = ' '.join('''
While the firelight's aglow
strange shadows in the flames will grow
till things we've never seen
will seem familiar
'''.strip().split('\n'))
sample_text

"While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [19]:
ids = bpe.encode(sample_text)
ids

array([2232,    8, 1247, 3551, 9948, 9930,  262, 7425, 7730,  151,   62,
       2173,   30,    8, 8435,  239, 1697,    3,   96, 1519,  166, 9948,
         54,  925,  941,  239, 1473, 7337])

In [20]:
bpe.decode(ids)

"While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [21]:
pieces = bpe.encode_as_pieces(sample_text)
pieces

[b'\xe2\x96\x81While',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81fire',
 b'light',
 b"'",
 b's',
 b'\xe2\x96\x81ag',
 b'low',
 b'\xe2\x96\x81strange',
 b'\xe2\x96\x81sh',
 b'ad',
 b'ows',
 b'\xe2\x96\x81in',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81flames',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81grow',
 b'\xe2\x96\x81t',
 b'ill',
 b'\xe2\x96\x81things',
 b'\xe2\x96\x81we',
 b"'",
 b've',
 b'\xe2\x96\x81never',
 b'\xe2\x96\x81seen',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81seem',
 b'\xe2\x96\x81familiar']

In [22]:
bpe.decode_pieces(pieces)

"While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [23]:
''.join(p.decode('utf-8').replace('▁', ' ') for p in pieces)

" While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [24]:
!head -20 cnn.vocab

head: cannot open 'cnn.vocab' for reading: No such file or directory


In [25]:
!tail -20 cnn.vocab

tail: cannot open 'cnn.vocab' for reading: No such file or directory


# Test loading

In [26]:
# Load the files produced by the training run in this notebook, whose trainer
# prefix is '<MODEL_NAME>_<VOCAB_SIZE>', rather than a hard-coded name.
bpe = BytePairEncoder.from_files(f'{MODEL_NAME}_{VOCAB_SIZE}.model',
                                 f'{MODEL_NAME}_{VOCAB_SIZE}.vocab')

In [27]:
pieces = bpe.encode_as_pieces(sample_text)
pieces

[b'\xe2\x96\x81While',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81fire',
 b'light',
 b"'",
 b's',
 b'\xe2\x96\x81ag',
 b'low',
 b'\xe2\x96\x81strange',
 b'\xe2\x96\x81sh',
 b'ad',
 b'ows',
 b'\xe2\x96\x81in',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81flames',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81grow',
 b'\xe2\x96\x81t',
 b'ill',
 b'\xe2\x96\x81things',
 b'\xe2\x96\x81we',
 b"'",
 b've',
 b'\xe2\x96\x81never',
 b'\xe2\x96\x81seen',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81seem',
 b'\xe2\x96\x81familiar']

In [28]:
bpe.vocab_size, bpe.model_name

(10000, 'summarizer_10000')