In [1]:
import sys
import numpy as np
import pandas as pd
from collections import defaultdict
sys.path.append("..")
from model_utils.model import DeepSpeech2Model
from model_utils.network import deepspeech_LocalDotAtten, deepspeech_orig
from data_utils.dataloader import SpecgramGenerator
from torch.utils.data import DataLoader

import torch
import os

# DNN initialization

In [2]:
# Character vocabulary handed to the model: apostrophe, space, then 'a'..'z'.
vocab_list = ["'", ' '] + [chr(code) for code in range(ord('a'), ord('z') + 1)]

# Build the model around the deepspeech_LocalDotAtten network on the CUDA device.
ds2_model = DeepSpeech2Model(model=deepspeech_LocalDotAtten, vocab_list=vocab_list, device="cuda")

# Attach the external language-model scorer used during decoding.
# NOTE(review): 1.4 and 0.35 are presumably the LM weighting parameters --
# confirm against the signature of init_ext_scorer.
ds2_model.init_ext_scorer(
    1.4,
    0.35,
    "../models/lm/common_crawl_00.prune01111.trie.klm",
)

[INFO 2020-03-05 16:33:39,017 model.py:361] begin to initialize the external scorer for decoding
begin to initialize the external scorer for decoding
[INFO 2020-03-05 16:33:47,976 model.py:371] language model: is_character_based = 0, max_order = 5, dict_size = 400000
language model: is_character_based = 0, max_order = 5, dict_size = 400000
[INFO 2020-03-05 16:33:48,069 model.py:372] end initializing scorer
end initializing scorer


In [3]:
# Checkpoint to restore into the model; the path is relative to the working directory.
checkpoint_path = (
    "iconect/models/pt_only_all_utt_wgt_transfer_mean/exps/"
    "200304-12:06:03_lr7e-06-deepspeech/model_final.pth"
)
ds2_model.load_weights(checkpoint_path)

[INFO 2020-03-05 16:33:48,177 model.py:511] load weights from: iconect/models/pt_only_all_utt_wgt_transfer_mean/exps/200304-12:06:03_lr7e-06-deepspeech/model_final.pth
load weights from: iconect/models/pt_only_all_utt_wgt_transfer_mean/exps/200304-12:06:03_lr7e-06-deepspeech/model_final.pth
[INFO 2020-03-05 16:33:48,179 model.py:512] excluded weights: {'deepspeech_bottleneck.bottleneck.bias', 'deepspeech_bottleneck.bottleneck.weight'}
excluded weights: {'deepspeech_bottleneck.bottleneck.bias', 'deepspeech_bottleneck.bottleneck.weight'}


# Data Loading

In [8]:
# Dataset built from the test manifest; the vocabulary and mean/std files are
# read from the baidu_en8k model directory.
# NOTE(review): min_duration=3 is presumably in seconds -- confirm in SpecgramGenerator.
test_dataset = SpecgramGenerator(
    manifest="iconect/models/pt_only_all_utt_wgt_transfer_mean/data/test_short.csv",
    vocab_filepath="../models/baidu_en8k/vocab.txt",
    mean_std_filepath="../models/baidu_en8k/mean_std.npz",
    min_duration=3,
    max_duration=float('inf'),
    segmented=False,
)

# Batches are drawn in manifest order (shuffle=False) and assembled by
# SpecgramGenerator.padding_batch.
dataloader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=8,
    collate_fn=SpecgramGenerator.padding_batch,
)

# Decoding 

In [9]:
# Column name -> list of per-utterance values, grown batch by batch.
outputs = defaultdict(list)
beam_alpha = 1

# Beam-search settings; they do not change between batches, so build them once.
decode_options = dict(
    beam_alpha=beam_alpha,
    beam_beta=0.35,
    beam_size=500,
    cutoff_prob=1.0,
    cutoff_top_n=40,
    num_processes=6,
)

for i_batch, sample_batched in enumerate(dataloader):
    # Network outputs for the batch, then their beam-search transcripts.
    batch_results = ds2_model.infer_batch_probs(infer_data=sample_batched)
    batch_transcripts_beam = ds2_model.decode_batch_beam_search(
        probs_split=batch_results, **decode_options
    )

    # Keep ids, raw outputs, hypotheses and reference text side by side.
    outputs["uttid"].extend(sample_batched["uttid"])
    outputs["probs"].extend(batch_results)
    outputs["asr"].extend(batch_transcripts_beam)
    outputs["text"].extend(sample_batched["trans"])

# One row per utterance, columns in insertion order: uttid, probs, asr, text.
df = pd.DataFrame.from_dict(outputs)


In [10]:
# Persist the decoding results, then read them back to check the round trip.
# NOTE(review): the file is written as a pickle even though the path ends in
# ".csv"; the name is left unchanged so anything already reading this path
# keeps working. The name also says "on_val" while the manifest loaded above
# is test_short.csv -- confirm which split is intended.
saving_path = "test/pt_200304_lr7e-6_model_on_val.csv"

# to_pickle does not create missing parent directories, so make sure the
# target directory exists before writing.
os.makedirs(os.path.dirname(saving_path), exist_ok=True)
df.to_pickle(saving_path)

# NOTE: unpickling can execute arbitrary code; only load files written by this notebook.
test_df = pd.read_pickle(saving_path)
test_df.head()

Unnamed: 0,uttid,probs,asr,text
0,C1007_VC_3_FullCon_Wk01_Day3_100318_Participan...,"[[3.8415106e-05, 7.407037e-05, 0.037902262, 0....",that seer called the crawl or the free to,that's that's either called the crawl or the f...
1,C1028_VC_3_FullCon_Wk01_Day3_040319_Participan...,"[[0.00011470138, 0.00021678227, 0.006443831, 0...",have people that's why there's so many people ...,uh people that's why there's so many people ou...
2,C1028_VC_3_FullCon_Wk01_Day3_040319_Participan...,"[[8.468798e-05, 0.00047555013, 0.008537458, 0....",were down at the beach the hand of bay part,were down at the beach nehalem bay park
3,C2016_VC_3_FullCon_Wk01_Day4_050219_Participan...,"[[0.00010233711, 2.3265064e-05, 0.036521416, 0...",new ones that it's like uh the virginian,but the new ones is like um the virginian
4,C1007_VC_3_FullCon_Wk01_Day3_100318_Participan...,"[[5.9622544e-06, 1.21478015e-05, 0.0060807224,...",there and swim somewhere between a half a mile...,there and swim somewhere between a half a mile...


In [None]:
from matplotlib import pyplot as plt
import logging
# Raise the 'matplotlib' logger to WARNING so lower-level messages are not shown.
mpl_logger = logging.getLogger('matplotlib')
mpl_logger.setLevel(logging.WARNING)

# Draw every figure with matplotlib's 'classic' style sheet.
plt.style.use('classic')

def plot_output(df, row=30, atol=1e-6):
    """Show the ``probs`` matrix stored in one row of ``df`` as a heat map.

    Args:
        df: DataFrame with a ``probs`` column whose entries are 2-D arrays.
            Presumably shaped (frames, output symbols) -- TODO confirm; the
            code itself only requires every slice along axis 1 to sum to 1.
        row: Positional (``iloc``) index of the row to draw.
        atol: Largest allowed absolute deviation of a sum along axis 1 from 1.

    Raises:
        AssertionError: If any slice of ``probs`` along axis 1 does not sum
            to 1 within ``atol``.
        IndexError: If ``row`` is outside the frame (raised by ``iloc``).
    """
    # Fetch the matrix once instead of re-indexing the frame for each use.
    probs = np.asarray(df["probs"].iloc[row])

    # Validate before touching matplotlib so a bad row does not leave an
    # empty figure behind. The comparison is vectorized over all slices.
    deviation = np.abs(probs.sum(1) - 1)
    assert np.all(deviation < atol), (
        "row {}: probs do not sum to 1 along axis 1 (max deviation {:.3g})".format(
            row, deviation.max()
        )
    )

    fig, ax = plt.subplots(figsize=(30, 5))
    # Transposed so that axis 0 of probs runs along the wide horizontal axis.
    image = ax.imshow(probs.T)
    fig.colorbar(image, ax=ax)
    ax.set_title("probs, row {}".format(row))
    ax.set_xlabel("probs axis 0")
    ax.set_ylabel("probs axis 1")
    
# Draw the stored probs of the utterance at positional index 30 (the function's default row).
plot_output(df, row=30)