In [1]:
import sys
import numpy as np
import pandas as pd
from collections import defaultdict
sys.path.append("..")
from model_utils.model import DeepSpeech2Model
from model_utils.network import deepspeech_LocalDotAtten, deepspeech_orig
from data_utils.dataloader import SpecgramGenerator
from torch.utils.data import DataLoader

import torch
import os

# DNN initialization

In [2]:
# Vocabulary handed to DeepSpeech2Model: apostrophe, space, then the lowercase letters a-z.
vocab_list = ["'", ' '] + [chr(code) for code in range(ord('a'), ord('z') + 1)]
# Constructor options for the model wrapper. "TBD" is passed as the pretrained
# path here; the actual checkpoint is loaded into ds2_model.model in a later cell.
model_options = dict(
    model=deepspeech_orig,
    vocab_list=vocab_list,
    pretrained_model_path="TBD",
    device="cuda",
)
ds2_model = DeepSpeech2Model(**model_options)

# External language-model scorer used during beam-search decoding.
# NOTE(review): 1.4 / 0.35 are presumably the LM alpha / beta weights -- confirm
# against init_ext_scorer's signature.
scorer_lm_path = "../models/lm/common_crawl_00.prune01111.trie.klm"
ds2_model.init_ext_scorer(1.4, 0.35, scorer_lm_path)

[INFO 2020-03-04 09:36:23,207 model.py:362] begin to initialize the external scorer for decoding
begin to initialize the external scorer for decoding
[INFO 2020-03-04 09:36:32,201 model.py:372] language model: is_character_based = 0, max_order = 5, dict_size = 400000
language model: is_character_based = 0, max_order = 5, dict_size = 400000
[INFO 2020-03-04 09:36:32,294 model.py:373] end initializing scorer
end initializing scorer


In [25]:
# Checkpoint whose parameters are loaded into the network held by ds2_model.
checkpoint_path = "iconect/models/pt_only_all_utt_wgt_transfer/exps/200302-10:10:12_lr2e-05-deepspeech/model_17600.pth"
checkpoint_state = torch.load(checkpoint_path)
# Kept as the cell's last expression so the key-matching report is displayed.
ds2_model.model.load_state_dict(checkpoint_state)

<All keys matched successfully>

# Data Loading

In [26]:
# Dataset options: no upper duration limit, a lower limit of 2
# (NOTE(review): unit is presumably seconds -- confirm in SpecgramGenerator).
dataset_options = dict(
    manifest="iconect/models/all_utt_wgt_transfer/data/val_kaldi.csv",
    vocab_filepath="../models/baidu_en8k/vocab.txt",
    mean_std_filepath="../models/baidu_en8k/mean_std.npz",
    max_duration=float('inf'),
    min_duration=2,
    segmented=False,
)
test_dataset = SpecgramGenerator(**dataset_options)

# shuffle=False keeps the batches in dataset order; batches are assembled by
# SpecgramGenerator.padding_batch.
dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=8,
    collate_fn=SpecgramGenerator.padding_batch,
)

# Decoding 

In [None]:
# Column-wise accumulator: every key becomes one column of the final DataFrame.
outputs = defaultdict(list)
beam_alpha = 1

# Beam-search settings shared by every batch.
# NOTE(review): beam_alpha here (1) differs from the 1.4 passed to
# init_ext_scorer in the model cell -- confirm which value is intended.
decoder_options = dict(
    beam_alpha=beam_alpha,
    beam_beta=0.35,
    beam_size=500,
    cutoff_prob=1.0,
    cutoff_top_n=40,
    num_processes=6,
)

for batch in dataloader:
    # Per-utterance output probabilities, then their beam-search transcripts.
    batch_probs = ds2_model.infer_batch_probs(infer_data=batch)
    batch_hypotheses = ds2_model.decode_batch_beam_search(
        probs_split=batch_probs, **decoder_options
    )

    # Same insertion order as the resulting columns: uttid, probs, asr, text.
    for column, values in (
        ("uttid", batch["uttid"]),
        ("probs", batch_probs),
        ("asr", batch_hypotheses),
        ("text", batch["trans"]),
    ):
        outputs[column].extend(values)

df = pd.DataFrame.from_dict(outputs)


In [24]:
# Persist the decoding results and read them back to check the round trip.
# NOTE(review): the file is written with to_pickle, so despite the ".csv"
# suffix it is a pickle and must be read with pd.read_pickle, not pd.read_csv.
# The name is left unchanged in case other code already refers to this path.
saving_path = "test/pt_200302_lr2e-5_model_on_kaldival.csv"

# to_pickle does not create missing parent directories; make sure "test/" exists.
os.makedirs(os.path.dirname(saving_path), exist_ok=True)

df.to_pickle(saving_path)
test_df = pd.read_pickle(saving_path)
test_df.head()

Unnamed: 0,uttid,probs,asr,text
0,C1083_VC_3_FullCon_Wk01_Day3_080719_Moderator_...,"[[0.0004430709, 0.0004414702, 0.028753586, 0.0...",as for anyway,that's the best part anyway
1,C1082_VC_3_FullCon_Wk01_Day3_081419_Participan...,"[[1.18267645e-08, 5.302059e-07, 8.0755824e-05,...",uh it's been a long time,uh it's been a long time
2,C1033_VC_3_FullCon_Wk01_Day3_032719_Participan...,"[[9.2032207e-07, 4.1930738e-07, 0.00011967613,...",planning,planting
3,C1083_VC_3_FullCon_Wk01_Day3_080719_Participan...,"[[4.688401e-05, 3.5874993e-05, 0.0046603284, 0...",you can't imagine,can't imagine
4,C1082_VC_3_FullCon_Wk01_Day3_081419_Participan...,"[[1.2688511e-07, 3.8794515e-06, 2.6128811e-05,...",it was part of my culture as i got old,it was part of my culture as i got older


In [None]:
from matplotlib import pyplot as plt
import logging
# Only let matplotlib log messages at WARNING level or above.
logging.getLogger('matplotlib').setLevel(logging.WARNING)
# Render every figure with matplotlib's "classic" style sheet.
plt.style.use('classic')

def plot_output(df, row=30):
    """Show the ``probs`` matrix stored in one row of ``df`` as a heat map.

    Args:
        df: Object indexable through ``df.iloc[row]`` whose rows expose a
            ``probs`` attribute. ``probs`` is summed over axis 1 and
            transposed, so it is expected to be a 2-D array.
        row: Positional index of the row to plot.

    Raises:
        AssertionError: If any axis-1 sum of ``probs`` differs from 1 by
            1e-6 or more.
    """
    probs = df.iloc[row].probs

    # Validate before any figure is created so that a failed check does not
    # leave an empty figure behind.
    deviation = np.abs(probs.sum(1) - 1)
    assert np.all(deviation < 1e-6), (
        "row {}: probs do not sum to 1 along axis 1 (max deviation {})".format(
            row, np.max(deviation)))

    fig, ax = plt.subplots(figsize=(30, 5))
    # Transposed so that the first axis of probs runs along the x-axis.
    image = ax.imshow(probs.T)
    fig.colorbar(image, ax=ax)
    
# Plot the probability matrix of the utterance at positional index 30.
plot_output(df, row=30)