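Wrap the [sentencepiece](https://github.com/google/sentencepiece) byte-pair encoder (BPE) in a small, convenient Python API: build a training file, fit the encoder, and check that a saved model can be reloaded.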

In [1]:
# --- Notebook configuration ---------------------------------------------
# Input: directory that is globbed for story files.
DATA_DIRECTORY = '../data/preprocessed_stories'
# Output: single text file the story lines are concatenated into.
TRAINING_FILE = 'cnn/summary_bpe_train.txt'

# Cap applied while copying lines into TRAINING_FILE.
N_TRAIN_SENTENCES = 10_000_000

# Tokenizer settings handed to BytePairEncoder.
MODEL_NAME = 'summarizer'
VOCAB_SIZE = 30_000

# Make training data

In [2]:
import glob
import os
import random
import tqdm
# Every entry directly inside DATA_DIRECTORY is treated as one input file.
FILES = glob.glob(os.path.join(DATA_DIRECTORY, '*'))

In [3]:
# Sanity check: number of input files found by the glob above.
len(FILES)

312085

In [4]:
# Only a prefix of FILES is consumed before the N_TRAIN_SENTENCES cap is hit,
# so the file order decides which stories the tokenizer is trained on.
# glob returns paths in arbitrary order and an unseeded shuffle differs per
# kernel; sorting first and shuffling with a fixed seed makes the training
# sample identical on every Restart & Run All.
FILES.sort()
random.Random(0).shuffle(FILES)  # 0 is an arbitrary but fixed seed

In [5]:
# Concatenate the input files into TRAINING_FILE, keeping at most
# N_TRAIN_SENTENCES lines (each line is counted as one sentence).
# The counter is checked *after* the write so exactly N_TRAIN_SENTENCES lines
# are kept; two plain `break`s replace the exception that was previously
# raised just to leave both loops.
n_sentences = 0
with open(TRAINING_FILE, 'w') as f_train:
    for file in tqdm.tqdm(FILES):
        if n_sentences >= N_TRAIN_SENTENCES:
            break  # quota already filled; do not open further files
        with open(file) as f:
            for line in f:
                # Tabs are flattened to single spaces before writing
                # (presumably the preprocessed files are tab-separated -- TODO confirm).
                f_train.write(line.replace('\t', ' '))
                n_sentences += 1
                if n_sentences >= N_TRAIN_SENTENCES:
                    break

 51%|█████     | 158099/312085 [01:24<01:22, 1873.67it/s]

# Fit Encoder

In [6]:
import numpy as np
import sentencepiece as spm

class BytePairEncoder:
    """Small convenience wrapper around a sentencepiece BPE model.

    An instance either trains a new model from ``training_file`` or wraps an
    already loaded ``processor`` (see :meth:`from_files`).

    Attributes:
        vocab_size: Vocabulary size requested for / read from the model.
        model_name: Prefix sentencepiece uses for the files it writes.
        training_file: Text file the model was trained on, or ``None``.
        model_file: Path of the ``.model`` file backing this encoder.
        vocab_file: Path of the matching ``.vocab`` file.
        processor: The underlying ``SentencePieceProcessor``.
    """

    def __init__(self, vocab_size, model_name, *, model_file=None, vocab_file=None,
                 training_file=None, processor=None, **kwargs):
        """Train a new BPE model, or wrap an existing processor.

        Args:
            vocab_size: Target vocabulary size passed to the trainer.
            model_name: Used as sentencepiece's ``model_prefix``.
            model_file: Path of an existing model; defaults to
                ``<model_name>.model``. Ignored when training (see below).
            vocab_file: Path of an existing vocab; defaults to
                ``<model_name>.vocab``. Ignored when training (see below).
            training_file: Text file to train on. Required when
                ``processor`` is ``None``.
            processor: Already loaded processor; when given, no training
                happens.
            **kwargs: Extra trainer options, forwarded as ``--key=value``.

        Raises:
            ValueError: If neither ``processor`` nor ``training_file`` is given.
        """
        self.vocab_size = vocab_size
        self.model_name = model_name
        self.training_file = training_file
        self.model_file = f'{self.model_name}.model' if model_file is None else model_file
        self.vocab_file = f'{self.model_name}.vocab' if vocab_file is None else vocab_file
        if processor is None:
            if training_file is None:
                raise ValueError('training_file cannot be None when processor is also None.')
            processor = self._fit(input=training_file, vocab_size=vocab_size,
                                  model_prefix=model_name, model_type='bpe',
                                  **kwargs)
            # The trainer is given model_prefix=model_name, so the artifacts are
            # named after it; record those paths so the attributes always point
            # at the model that was actually loaded.
            self.model_file = f'{model_name}.model'
            self.vocab_file = f'{model_name}.vocab'
        self.processor = processor

    def encode(self, text):
        """Return the token ids of ``text`` as a numpy array."""
        return np.array(self.processor.EncodeAsIds(text))

    def encode_as_pieces(self, text):
        """Return the subword pieces of ``text`` as produced by the processor."""
        return self.processor.EncodeAsPieces(text)

    def decode(self, ids):
        """Turn token ids back into text.

        ``ids`` may be a numpy array (as returned by :meth:`encode`) or any
        plain sequence of ints.
        """
        # np.asarray is a no-op for arrays and lets lists/tuples through too;
        # tolist() hands the processor a plain Python list.
        return self.processor.DecodeIds(np.asarray(ids).tolist())

    def decode_pieces(self, pieces):
        """Turn a list of subword pieces back into text."""
        return self.processor.DecodePieces(pieces)

    @classmethod
    def from_files(cls, model_file, vocab_file):
        """Build an encoder from previously saved model and vocab files.

        ``model_name`` is the model path without its final extension and
        ``vocab_size`` is the number of lines in ``vocab_file``.
        """
        # splitext only strips the trailing extension, so paths such as
        # './summarizer.model' or '../models/summarizer.model' keep their stem
        # (splitting on the first '.' would yield an empty name for them).
        model_name = os.path.splitext(model_file)[0]
        processor = cls._load_model(model_file)
        # The vocab file holds non-ASCII pieces, so read it as UTF-8 instead of
        # the locale default; an empty file simply counts as zero entries.
        with open(vocab_file, encoding='utf-8') as vocab:
            vocab_size = sum(1 for _ in vocab)
        return cls(vocab_size=vocab_size, model_name=model_name, processor=processor,
                   model_file=model_file, vocab_file=vocab_file)

    @staticmethod
    def _load_model(filename):
        """Load ``filename`` into a new ``SentencePieceProcessor``.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
        """
        # Fail loudly up front rather than depending on how the installed
        # sentencepiece version reports a missing model file.
        if not os.path.isfile(filename):
            raise FileNotFoundError(f'sentencepiece model not found: {filename}')
        processor = spm.SentencePieceProcessor()
        processor.Load(filename)
        return processor

    def _fit(self, **kwargs):
        """Train a model with the given trainer options and load the result.

        Each keyword becomes a ``--key=value`` flag of the trainer.
        """
        # NOTE: flags are joined with spaces, so a value that itself contains a
        # space (e.g. a path) would be split into separate arguments.
        params = ' '.join(f'--{k}={v}' for k, v in kwargs.items())
        spm.SentencePieceTrainer.Train(params)
        # Load the artifact named after model_prefix, not self.model_file, which
        # may be a caller-supplied path the trainer never wrote to.
        trained_model = f"{kwargs.get('model_prefix', self.model_name)}.model"
        return self._load_model(trained_model)

In [7]:
%%time
# Train a BPE model on TRAINING_FILE; MODEL_NAME is passed on as sentencepiece's
# model_prefix. Timed because this is the expensive step of the notebook.
bpe = BytePairEncoder(vocab_size=VOCAB_SIZE, model_name=MODEL_NAME, training_file=TRAINING_FILE)

CPU times: user 4min 56s, sys: 45.6 s, total: 5min 42s
Wall time: 5min 55s


 51%|█████     | 158099/312085 [07:20<07:08, 358.95it/s] 

In [8]:
# Flatten a short four-line verse into one space-separated sentence; it serves
# as the round-trip example for the cells below.
sample_text = '''
While the firelight's aglow
strange shadows in the flames will grow
till things we've never seen
will seem familiar
'''.strip().replace('\n', ' ')
sample_text

"While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [9]:
# Encode the sample sentence to token ids (returned as a numpy array).
ids = bpe.encode(sample_text)
ids

array([ 2201,     8,  1241,  3542, 29948, 29930,   262,  7487,  7718,
       22535,    30,     8,  8433,   239,  1706, 17841,  1508,   166,
       29948,    54,   926,   941,   239,  1487,  7467])

In [10]:
# Round trip: decoding the ids should give back sample_text.
bpe.decode(ids)

"While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [11]:
# Same sentence as subword pieces instead of ids.
pieces = bpe.encode_as_pieces(sample_text)
pieces

[b'\xe2\x96\x81While',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81fire',
 b'light',
 b"'",
 b's',
 b'\xe2\x96\x81ag',
 b'low',
 b'\xe2\x96\x81strange',
 b'\xe2\x96\x81shadows',
 b'\xe2\x96\x81in',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81flames',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81grow',
 b'\xe2\x96\x81till',
 b'\xe2\x96\x81things',
 b'\xe2\x96\x81we',
 b"'",
 b've',
 b'\xe2\x96\x81never',
 b'\xe2\x96\x81seen',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81seem',
 b'\xe2\x96\x81familiar']

In [12]:
# Round trip: joining the pieces through the processor should give back sample_text.
bpe.decode_pieces(pieces)

"While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [13]:
# Manual reconstruction for comparison: decode every piece, glue them together,
# then turn the '▁' markers into spaces.
''.join([p.decode('utf-8') for p in pieces]).replace('▁', ' ')

" While the firelight's aglow strange shadows in the flames will grow till things we've never seen will seem familiar"

In [14]:
!more cnn.vocab | head -20

<unk>	0
<s>	0
</s>	0
▁	-2.70172
entity	-2.75739
▁@	-2.75798
▁the	-3.20654
s	-3.45313
,	-3.48918
▁.	-3.62246
▁a	-4.06935
▁to	-4.14815
▁in	-4.20395
ed	-4.20493
▁"	-4.28682
▁of	-4.33887
and	-4.47272
ing	-4.55818
▁'	-4.65301
d	-4.91634


In [15]:
!more cnn.vocab | tail -20

office	-15.7772
amer	-15.7801
▁bei	-15.7805
gnan	-15.794
uns	-15.7945
▁jum	-15.7958
▁instructi	-15.7968
met	-15.7983
monstr	-15.7984
abet	-15.8024
ection	-15.8069
ivo	-15.8081
ighth	-15.8117
.0	-15.818
▁catche	-15.8277
record	-15.8278
▁fed	-15.839
?	-15.8402
q	-16.0888
j	-16.0889


# Test loading

In [16]:
# Reload the encoder from the files written during training. The names are
# derived from MODEL_NAME so this cell keeps working if the config changes.
bpe = BytePairEncoder.from_files(f'{MODEL_NAME}.model', f'{MODEL_NAME}.vocab')

In [17]:
# Tokenize the sample again with the reloaded encoder, for comparison with the
# pieces produced by the freshly trained one.
pieces = bpe.encode_as_pieces(sample_text)
pieces

[b'\xe2\x96\x81While',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81fire',
 b'light',
 b"'",
 b's',
 b'\xe2\x96\x81ag',
 b'low',
 b'\xe2\x96\x81strange',
 b'\xe2\x96\x81shadows',
 b'\xe2\x96\x81in',
 b'\xe2\x96\x81the',
 b'\xe2\x96\x81flames',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81grow',
 b'\xe2\x96\x81till',
 b'\xe2\x96\x81things',
 b'\xe2\x96\x81we',
 b"'",
 b've',
 b'\xe2\x96\x81never',
 b'\xe2\x96\x81seen',
 b'\xe2\x96\x81will',
 b'\xe2\x96\x81seem',
 b'\xe2\x96\x81familiar']

In [18]:
# from_files derives both values: vocab_size from the vocab file's line count,
# model_name from the model file name.
bpe.vocab_size, bpe.model_name

(30000, 'summarizer')