In [1]:
from common import load_train_index
from pathlib import Path
from hw_asr.base.base_text_encoder import BaseTextEncoder

# Target vocabulary size for the SentencePiece model trained on this corpus.
VOCAB_SIZE = 500

index_directory = Path('pretrained_model/index/')

tokenizer_directory = Path('pretrained_model/tokenizer')
# parents=True: do not fail when `pretrained_model/` itself is missing.
tokenizer_directory.mkdir(parents=True, exist_ok=True)
# NOTE: despite the name, `texts_directory` is the corpus *file*, one sentence per line.
texts_directory = tokenizer_directory / 'texts.txt'
model_directory = tokenizer_directory / f'sentence_piece_vocab_{VOCAB_SIZE}'

datasets = load_train_index(index_directory)

# Flatten every observation of every dataset into one list of normalized
# transcriptions.
sentences = [
    BaseTextEncoder.normalize_text(observation['text'])
    for dataset in datasets
    for observation in dataset
]

# Write the corpus explicitly as UTF-8: SentencePiece reads its training input
# as UTF-8, while a bare open() would use the platform locale encoding
# (e.g. cp1252 on Windows) and corrupt any non-ASCII character.
with open(texts_directory, 'w', encoding='utf-8') as f:
    f.writelines(f'{sentence}\n' for sentence in sentences)

In [8]:
import sentencepiece as spm

# Common path prefix of the tokenizer artefacts; the `.model` file is looked
# up (and loaded) under this prefix.
model_prefix = tokenizer_directory / f'sentence_piece_vocab_{VOCAB_SIZE}'

# Train a BPE tokenizer only when no `.model` file exists yet, so re-running
# the cell reuses the model already on disk.
if not model_prefix.with_suffix('.model').exists():
    spm.SentencePieceTrainer.train(
        input=texts_directory,
        model_prefix=model_prefix,
        vocab_size=VOCAB_SIZE,
        model_type='bpe',
    )

sp_model = spm.SentencePieceProcessor(model_file=f'{model_prefix}.model')

In [50]:
# Sanity check: the trained tokenizer must map its unknown token to id 0.
# NOTE(review): presumably downstream code relies on id 0 being reserved — confirm.
assert sp_model.unk_id() == 0

In [34]:
# Display the first entry of the normalized training corpus.
sentences[0]

'it had no ornamentation being exceedingly plain in appearance'

In [42]:
# Tokenize a sample sentence and show its sub-word pieces separated by `|`,
# with the '▁' marker characters removed from each piece.
encoded = sp_model.Encode(
    'it had no ornamentation being exceedingly plain in appearance'
)
print(
    *(sp_model.IdToPiece(piece_id).replace('▁', '') for piece_id in encoded),
    sep='|',
)

it|had|no|or|n|am|ent|ation|be|ing|ex|ce|ed|ing|ly|pl|ain|in|app|e|ar|ance


In [65]:
# Decode a sequence holding only the unknown-token id; the result is stored in
# `tmp` and not displayed by this cell.
tmp = sp_model.Decode([sp_model.unk_id()])

In [33]:
# Round-trip check of the project's CTC BPE text encoder: encode a sentence,
# print the encoded ids, then decode them back to text.
from utils import reload
# Must run before the import below.
# NOTE(review): presumably re-imports the `hw_asr` package so source edits are
# picked up without restarting the kernel — confirm against utils.reload.
reload('hw_asr')
from hw_asr.text_encoder.ctc_char_bpe_encoder import CTCCharBpeEncoder  # noqa

# Redefined here so the cell also runs without the first cell's globals.
VOCAB_SIZE = 500

# The argument is the same tokenizer path prefix used when training the
# SentencePiece model (no `.model` suffix).
encoder = CTCCharBpeEncoder(f'pretrained_model/tokenizer/sentence_piece_vocab_{VOCAB_SIZE}')
encoded = encoder.encode('hello world every day')
print(encoded)
# `encoded` prints as a tensor with one row; take that row as a NumPy array
# and decode it. The decoded text is the cell's displayed output.
encoder.ctc_decode_enhanced(encoded[0].numpy())

tensor([[ 33.,  39., 478., 205.,  57., 361., 308.]])


'hello world every day'

In [19]:
import os
import torch
import pathlib

# The checkpoint is loaded with `PosixPath` aliased to `WindowsPath`, which
# implies it contains pickled `pathlib.PosixPath` objects that cannot be
# instantiated on Windows. The alias is only needed (and only valid) on
# Windows, so it is applied there alone and always undone afterwards instead
# of staying patched for the rest of the kernel session.
temp = pathlib.PosixPath  # original class, kept under the name the cell already used
if os.name == 'nt':
    pathlib.PosixPath = pathlib.WindowsPath
try:
    # SECURITY: torch.load unpickles arbitrary Python objects — only load
    # checkpoints from a trusted source.
    # map_location='cpu' lets a checkpoint saved on a GPU machine load on a
    # host without CUDA.
    checkpoint = torch.load(
        'pretrained_model/model_checkpoint.pth',
        map_location='cpu',
    )
finally:
    pathlib.PosixPath = temp

In [20]:
from hw_asr.model.deep_speech import DeepSpeech2


# Rebuild the network with the constructor arguments used for the pretrained
# checkpoint, then restore its weights.
model = DeepSpeech2(
    n_class=28,
    n_feats=128,
)
model.load_state_dict(checkpoint['state_dict'])



<All keys matched successfully>

In [25]:
from torch import nn

# Swap the output layer for a freshly initialised, bias-free linear layer that
# emits VOCAB_SIZE values per step instead of the pretrained head's outputs.
# NOTE(review): 1024 is assumed to match the input width of the layer being
# replaced — confirm against DeepSpeech2.
model.fc = nn.Linear(in_features=1024, out_features=VOCAB_SIZE, bias=False)

In [37]:
import json
from torch.optim import Adam


# Forward slashes: the former 'hw_asr\configs\...' literal relied on the
# invalid escape sequences `\c` and `\d` (a SyntaxWarning on recent Python)
# and only resolved on Windows. This spelling works on every platform.
with open('hw_asr/configs/deep_speech_2_server_bpe.json', encoding='utf-8') as f:
    config = json.load(f)

# torch.save does not create missing directories, so create `tmp/` first.
Path('tmp').mkdir(parents=True, exist_ok=True)

# Package the model with its replaced output layer as a checkpoint with three
# keys: the weights, a zeroed `monitor_best` value and the loaded config.
torch.save({
    'state_dict': model.state_dict(),
    'monitor_best': 0,
    'config': config
}, 'tmp/bpe_model.pth')