In [1]:
import os
import nemo
import copy
import tqdm
import timeit
import shutil
import pathlib
import attrdict
import numpy as np

from ruamel import yaml
from Bio import pairwise2
from nemo.collections import asr as nemo_asr
from nemo.collections.asr.helpers import post_process_predictions, post_process_transcripts, word_error_rate

In [2]:
NGC_MAP = '~/Downloads/librivox-train-all.json'
LOCAL_MAP = [
    '/home/stanislavv/data/nemo-librispeech/train_clean_100.json',
    '/home/stanislavv/data/nemo-librispeech/train_clean_360.json',
    '/home/stanislavv/data/nemo-librispeech/train_other_500.json',
]

In [3]:
def get_ngc_id(audio_file):
    return os.path.basename(audio_file)[:-4]


order = [
    get_ngc_id(example['audio_file']) 
    for example in nemo.collections.asr.parts.manifest.item_iter(NGC_MAP)
]
order[:5]

['792-127528-0000',
 '792-127528-0001',
 '792-127528-0002',
 '792-127528-0003',
 '792-127528-0004']

In [4]:
local_id_index = [
    (get_ngc_id(example['audio_file']), i)
    for i, example in enumerate(nemo.collections.asr.parts.manifest.item_iter(LOCAL_MAP))
]
local_id_index[:5]

[('1594-135914-0000', 0),
 ('1594-135914-0001', 1),
 ('1594-135914-0002', 2),
 ('1594-135914-0003', 3),
 ('1594-135914-0004', 4)]

In [5]:
local_id_index_dict = dict(local_id_index)
new_order = [local_id_index_dict[id_] for id_ in order]
new_order[:5]

[133099, 133100, 133101, 133102, 133103]

In [6]:
AD_HOC_MAP = '/home/stanislavv/data/nemo-librispeech/train-all.json'

lines = []
for mf in LOCAL_MAP:
    with open(mf, 'r') as f:
        lines.extend(list(f))
lines = [lines[id_] for id_ in new_order]

with open(AD_HOC_MAP, 'w') as f:
    for line in lines:
        f.write(line)

In [7]:
runner = nemo.core.NeuralModuleFactory(
    placement=nemo.core.DeviceType.GPU
)

[NeMo W 2020-03-13 14:51:14 deprecated:68] Function ``_get_trainer`` is deprecated. It is going to be removed in the future version.


In [8]:
MODEL_CONFIG = 'examples/asr/configs/quartznet15x5.yaml'
yaml_loader = yaml.YAML(typ="safe")
with open(MODEL_CONFIG) as f:
    config = attrdict.AttrDict(yaml_loader.load(f))
labels = list(config.labels)
str(labels)

'[\' \', \'a\', \'b\', \'c\', \'d\', \'e\', \'f\', \'g\', \'h\', \'i\', \'j\', \'k\', \'l\', \'m\', \'n\', \'o\', \'p\', \'q\', \'r\', \'s\', \'t\', \'u\', \'v\', \'w\', \'x\', \'y\', \'z\', "\'"]'

In [9]:
eval_dl_params = copy.deepcopy(config.AudioToTextDataLayer)
eval_dl_params.update(config.AudioToTextDataLayer["train"])
del eval_dl_params["train"]
del eval_dl_params["eval"]
eval_dl_params['shuffle'] = False
data_layer = nemo_asr.AudioToTextDataLayer(
    manifest_filepath=AD_HOC_MAP,
    sample_rate=config.sample_rate,
    labels=config.labels,
    batch_size=32,
    **eval_dl_params,
)

[NeMo I 2020-03-13 14:52:31 collections:138] Dataset loaded with 278627 files totalling 948.67 hours
[NeMo I 2020-03-13 14:52:31 collections:139] 2614 files were filtered totalling 12.38 hours


In [10]:
data_preprocessor = nemo_asr.AudioToMelSpectrogramPreprocessor(
    sample_rate=config.sample_rate, **config.AudioToMelSpectrogramPreprocessor
)
jasper_encoder = nemo_asr.JasperEncoder(
    feat_in=config.AudioToMelSpectrogramPreprocessor["features"], **config.JasperEncoder
)
jasper_decoder = nemo_asr.JasperDecoderForCTC(
    feat_in=config.JasperEncoder["jasper"][-1]["filters"], num_classes=len(config.labels)
)
greedy_decoder = nemo_asr.GreedyCTCDecoder()

[NeMo I 2020-03-13 14:52:31 features:144] PADDING: 16
[NeMo I 2020-03-13 14:52:31 features:152] STFT using conv


In [11]:
audio_signal_e1, a_sig_length_e1, transcript_e1, transcript_len_e1 = data_layer()
processed_signal_e1, p_length_e1 = data_preprocessor(input_signal=audio_signal_e1, length=a_sig_length_e1)
encoded_e1, encoded_len_e1 = jasper_encoder(audio_signal=processed_signal_e1, length=p_length_e1)
log_probs_e1 = jasper_decoder(encoder_output=encoded_e1)
predictions_e1 = greedy_decoder(log_probs=log_probs_e1)

In [12]:
eval_tensors = [log_probs_e1, predictions_e1, transcript_e1, transcript_len_e1, encoded_len_e1, p_length_e1]
load_dir = '/home/stanislavv/data/nemo-quartznet-15x5'
evaluated_tensors = runner.infer(tensors=eval_tensors, checkpoint_dir=load_dir)

[NeMo I 2020-03-13 14:52:35 actions:1459] Restoring JasperEncoder from /home/stanislavv/data/nemo-quartznet-15x5/JasperEncoder-STEP-406556.pt
[NeMo I 2020-03-13 14:52:36 actions:1459] Restoring JasperDecoderForCTC from /home/stanislavv/data/nemo-quartznet-15x5/JasperDecoderForCTC-STEP-406556.pt
[NeMo I 2020-03-13 14:52:36 actions:729] Evaluating batch 0 out of 8708
[NeMo I 2020-03-13 14:57:31 actions:729] Evaluating batch 870 out of 8708
[NeMo I 2020-03-13 15:02:28 actions:729] Evaluating batch 1740 out of 8708
[NeMo I 2020-03-13 15:07:25 actions:729] Evaluating batch 2610 out of 8708
[NeMo I 2020-03-13 15:12:23 actions:729] Evaluating batch 3480 out of 8708
[NeMo I 2020-03-13 15:17:20 actions:729] Evaluating batch 4350 out of 8708
[NeMo I 2020-03-13 15:22:20 actions:729] Evaluating batch 5220 out of 8708
[NeMo I 2020-03-13 15:27:19 actions:729] Evaluating batch 6090 out of 8708
[NeMo I 2020-03-13 15:32:18 actions:729] Evaluating batch 6960 out of 8708
[NeMo I 2020-03-13 15:37:17 actio

In [13]:
references = post_process_transcripts(evaluated_tensors[2], evaluated_tensors[3], config.labels)
greedy_hypotheses = post_process_predictions(evaluated_tensors[1], config.labels)
word_error_rate(greedy_hypotheses, references)

0.0049496181304478715

In [14]:
text = evaluated_tensors[2][0]
text_len = evaluated_tensors[3][0]
ctc_tokens = evaluated_tensors[1][0]
ctc_logprobs = evaluated_tensors[0][0]
ctc_len = evaluated_tensors[4][0]
mel_len = evaluated_tensors[-1][0]

In [15]:
text1 = text[0].numpy()
text_len1 = text_len[0].numpy().item()
ctc_tokens1 = ctc_tokens[0].numpy()
ctc_logprobs1 = ctc_logprobs[0].numpy()
ctc_len1 = ctc_len[0].numpy().item()
mel_len1 = mel_len[0].numpy().item()

In [16]:
text1 = list(text1[:text_len1])
ctc_tokens1 = list(ctc_tokens1[:ctc_len1])
ctc_logprobs1 = ctc_logprobs1[:ctc_len1]
len(text1), len(ctc_tokens1), ctc_logprobs1.shape

(174, 810, (810, 29))

In [17]:
class Processer:
    def __init__(self, labels):
        labels = labels + ['<BLANK>']
        self.space_id = labels.index(' ')
        self.blank_id = len(labels) - 1
        self.labels_map = dict([(i, labels[i]) for i in range(len(labels))])
    
    def bound_text(self, tokens):
        return [self.space_id] + tokens + [self.space_id]
    
    def bound_ctc(self, tokens, logprobs):
        tokens = [self.space_id, self.blank_id] + tokens + [self.blank_id, self.space_id]
        
        logprobs = np.lib.pad(logprobs, ((2, 2), (0, 0)), 'edge')

        def swap(col, a, b):
            logprobs[col][a], logprobs[col][b] = logprobs[col][b], logprobs[col][a]
        
        first_token, last_token = tokens[2], tokens[-3]
        swap(0, first_token, self.space_id)
        swap(1, first_token, self.blank_id)
        swap(-1, last_token, self.space_id)
        swap(-2, last_token, self.blank_id)

        return tokens, logprobs
    
    def merge(self, tokens):
        output_tokens = []
        output_cnts = []
        cnt = 0
        for i in range(len(tokens)):
            if i != 0 and (tokens[i - 1] != tokens[i]):
                output_tokens.append(tokens[i - 1])
                output_cnts.append(cnt)

                cnt = 0

            cnt += 1

        output_tokens.append(tokens[-1])
        output_cnts.append(cnt)
        
        assert sum(output_cnts) == len(tokens), f'SUM_CHECK {sum(output_cnts)} vs {len(tokens)}'

        return output_tokens, output_cnts
    
    def merge_with_blanks(self, tokens, cnts, logprobs=None):
        def choose_sep(l, r, a, b):
            # `tokens[l] == a and tokens[r] == b`.
            sum_a, sum_b = logprobs[l, a], logprobs[l + 1:r + 1, b].sum()
            best_sum, best_sep = sum_a + sum_b, 0
            for sep in range(1, r - l):
                sum_a += logprobs[l + sep, a]
                sum_b -= logprobs[l + sep, b]
                if sum_a + sum_b > best_sum:
                    best_sum, best_sep = sum_a + sum_b, sep

            return best_sep
        
        output_tokens = []
        output_durs = []
        blank_cnt = 0
        total_cnt = 0
        for token, cnt in zip(tokens, cnts):
            total_cnt += cnt
            if token == self.blank_id:
                blank_cnt += cnt
                continue
            
            output_tokens.append(token)
            
            if logprobs is None:
                # Half half.
                left_cnt = blank_cnt // 2
            else:
                # Clever sep choice based on sum of log probs.
                left_cnt = choose_sep(
                    l=total_cnt - cnt - blank_cnt - 1,
                    r=total_cnt - cnt,
                    a=output_tokens[-1],
                    b=token,
                )
            right_cnt = blank_cnt - left_cnt
            blank_cnt = 0
            
            if left_cnt:
                output_durs[-1] += left_cnt
            output_durs.append(cnt + right_cnt)
        
        output_durs[-1] += blank_cnt

        assert sum(output_durs) == sum(cnts), f'SUM_CHECK {sum(output_durs)} vs {sum(cnts)}'

        return output_tokens, output_durs
    
    def align(self, output_tokens, gt_text):
        def make_str(tokens):
            return ''.join(self.labels_map[c] for c in tokens)
        
        s = make_str(output_tokens)
        t = make_str(gt_text)
        alignmet = pairwise2.align.globalxx(s, t)[0]
        sa, ta, *_ = alignmet
        return sa, ta
    
    def generate(self, gt_text, alignment, durs):
        output_tokens = []
        output_cnts = []
        si, ti = 0, 0    
        for sc, tc in zip(*alignment):
            if sc == '-':
                output_tokens.append(self.blank_id)
                output_cnts.append(2 * durs[si])
                si += 1
            elif tc == '-':
                output_tokens.append(gt_text[ti])
                output_cnts.append(0)
                ti += 1
            else:
                output_tokens.append(gt_text[ti])
                output_cnts.append(2 * durs[si])
                si += 1
                ti += 1

        assert sum(output_cnts) == 2 * sum(durs)
        
        return output_tokens, output_cnts

    def __call__(self, text, ctc_tokens, ctc_logprobs, mel_len):
        # This adds +2 tokens.
        text = self.bound_text(text)
        # This add +4 tokens, 2 of them are blank.
        ctc_tokens, ctc_logprobs = self.bound_ctc(ctc_tokens, ctc_logprobs)

        ctc_tokens, ctc_cnts = self.merge(ctc_tokens)
        ctc_tokens, ctc_durs = self.merge_with_blanks(ctc_tokens, ctc_cnts, ctc_logprobs)
        
        alignment = self.align(text, ctc_tokens)
        tokens, cnts = self.generate(text, alignment, ctc_durs)
        tokens, durs = self.merge_with_blanks(tokens, cnts)
        assert tokens == text, 'EXACT_TOKENS_MATCH_CHECK'

        def adjust(start, direction, value):
            i = start
            while value != 0:
                dur = durs[i]
                
                if value < 0:
                    durs[i] = dur - value
                else:
                    durs[i] = max(dur - value, 0)
                
                value -= dur - durs[i]
                i += direction

        adjust(0, 1, 4)
        adjust(-1, -1, sum(durs) - mel_len)  # Including 4 suffix bound tokens.
        assert durs[0] >= 0, f'{durs[0]}'
        assert durs[-1] >= 0, f'{durs[-1]}'
        
        durs = np.array(durs, dtype=np.long)
        assert durs.shape[0] == len(text), f'LEN_CHECK {durs.shape[0]} vs {len(text)}'
        assert np.sum(durs) == mel_len, f'SUM_CHECK {np.sum(durs)} vs {mel_len}'

        return durs

processer = Processer(labels)
durs1 = processer(text1, ctc_tokens1, ctc_logprobs1, mel_len1)
durs1

array([  0,  38,   8,   2,  10,   6,   4,   4,   4,  14,  16,   2,   8,
         6,   2,  12,   4, 294,   8,   8,   4,   4,   8,   6,   6,   6,
         2,   4,   4,  58,   4,   2,   2,   2,   4,   4,   2,   2,   2,
         4,   4,   4,   2,   4,  12,   4,   8,   4,  10,   4,   6,   4,
         8,   8,   4,   4,   4,  16,   4,   4,   4,   8,   2,   2,   2,
         8,   6,   2,   8,   2,  10,  12,   6,   6,  72,   4,   2,   6,
         6,   2,   2,   6,   4,   2,   2,   2,   8,   4,   8,   8,  10,
         2,   2,   6,   2,   6,   6,   2,   4,   8,   4,   8,  10,   2,
         6,   8,  40,  28,   4,   4,   2,   8,   6,   6,   4,   2,   2,
        52,   8,   6,  12,   6,   4,   2,   6,   6,   4,   6,  10,   6,
         2,   8,   6,   2,  10,   2,   2,   2,   4,  12,   6,   6,   2,
         8,   6,   4,  10,   4,   2,   2,  16,   6,   2,  46,  18,   8,
         4,   6,   4,   2,   8,   2,   4,   8,   4,   4,   4,   8,   8,
        10,   4,   6,  10,   6,  84,   3])

In [18]:
f = lambda: processer(text1, ctc_tokens1, ctc_logprobs1, mel_len1)
n, m = 10000, 280000
(timeit.timeit(f, number=n) / n) * m

944.1908576488495

In [19]:
# NOTE: This is going pretty fast.
durs_dir = pathlib.Path('/home/stanislavv/data/nemo-librispeech/durs')
durs_dir.mkdir(exist_ok=True)
k = -1

for batch in tqdm.tqdm(zip(*evaluated_tensors), total=8789):
    text = batch[2].numpy()
    text_len = batch[3].numpy()
    ctc_tokens = batch[1].numpy()
    ctc_logprobs = batch[0].numpy()
    ctc_len = batch[4].numpy()
    mel_len = batch[-1].numpy()

    for text1, text_len1, ctc_tokens1, ctc_logprobs1, ctc_len1, mel_len1 in zip(
        text, text_len, ctc_tokens, ctc_logprobs, ctc_len, mel_len
    ):
        text1 = list(text1[:text_len1])
        ctc_tokens1 = list(ctc_tokens1[:ctc_len1])
        ctc_logprobs1 = ctc_logprobs1[:ctc_len1]
        mel_len1 = mel_len1.item()

        durs1 = processer(text1, ctc_tokens1, ctc_logprobs1, mel_len1)

        k += 1
        np.save(durs_dir / f'{k}.npy', durs1, allow_pickle=False)

 99%|█████████▉| 8708/8789 [18:02<00:10,  8.04it/s]


In [12]:
durs_dir = pathlib.Path('/home/stanislavv/data/nemo-librispeech/durs')
arrays = []
for index in tqdm.trange(278627):
    arrays.append(np.load(durs_dir / f'{index}.npy'))

100%|██████████| 278627/278627 [01:09<00:00, 3999.21it/s]


In [4]:
len(arrays)

278627

In [13]:
np.save('/home/stanislavv/data/nemo-librispeech/all-durs.npy', arrays)

In [13]:
arrays = np.load('/home/stanislavv/data/nemo-librispeech/bigdurs.npy', allow_pickle=True)

In [14]:
arrays[0]

array([  0,  38,   8,   2,  10,   6,   4,   4,   4,  14,  16,   2,   8,
         6,   2,  12,   4, 294,   8,   8,   4,   4,   8,   6,   6,   6,
         2,   4,   4,  58,   4,   2,   2,   2,   4,   4,   2,   2,   2,
         4,   4,   4,   2,   4,  12,   4,   8,   4,  10,   4,   6,   4,
         8,   8,   4,   4,   4,  16,   4,   4,   4,   8,   2,   2,   2,
         8,   6,   2,   8,   2,  10,  12,   6,   6,  72,   4,   2,   6,
         6,   2,   2,   6,   4,   2,   2,   2,   8,   4,   8,   8,  10,
         2,   2,   6,   2,   6,   6,   2,   4,   8,   4,   8,  10,   2,
         6,   8,  40,  28,   4,   4,   2,   8,   6,   6,   4,   2,   2,
        52,   8,   6,  12,   6,   4,   2,   6,   6,   4,   6,  10,   6,
         2,   8,   6,   2,  10,   2,   2,   2,   4,  12,   6,   6,   2,
         8,   6,   4,  10,   4,   2,   2,  16,   6,   2,  46,  18,   8,
         4,   6,   4,   2,   8,   2,   4,   8,   4,   4,   4,   8,   8,
        10,   4,   6,  10,   6,  84,   3])

In [9]:
arrays[0].dtype

dtype('int64')

In [5]:
len(list(durs_dir.iterdir()))

278627