In [None]:
import string
import jiwer
from mfcc_model import brnn_ctc_mfcc
import numpy as np
import tensorflow as tf
from tensorflow import keras
from keras.models import Model
import tensorflow.keras.backend as K
from mfcc_generator import DataGenerator, pad_wav, pad_text, get_wavs_path, get_text_path, dictionary

In [None]:
# GPU setup. CUDA_VISIBLE_DEVICES is only honoured if it is set before anything
# enumerates/initialises the GPUs, so it must come before tf.config / tf.test calls
# (the original called tf.test.gpu_device_name() first, which made it a no-op).
import os
from tensorflow.python.client import device_lib
# print(device_lib.list_local_devices())  # uncomment to list the visible devices
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Allocate GPU memory on demand. A ConfigProto only takes effect when passed to a
# tf.compat.v1.Session, which this notebook never creates, so enable memory growth
# on the physical devices instead.
for gpu in tf.config.list_physical_devices("GPU"):
    try:
        tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as err:
        # Raised when the GPUs were already initialised (e.g. this cell is re-run).
        print(f"Could not enable memory growth on {gpu.name}: {err}")

# Kept for backward compatibility with code that may build a v1 Session from it.
config = tf.compat.v1.ConfigProto()
config.gpu_options.allow_growth = True

# Name of the GPU TensorFlow will use ('' when none is visible), shown as cell output.
tf.test.gpu_device_name()

In [None]:
# Load and pad the test split (audio inputs, input lengths, transcripts, label lengths).
def get_test_data(test_path, nfilt=None):
    """Load and pad the test split found under ``test_path``.

    Args:
        test_path: directory handed to ``get_wavs_path`` to locate the audio files.
        nfilt: value forwarded unchanged to ``pad_wav`` (presumably the number of
            filterbank channels — TODO confirm against mfcc_generator). When omitted,
            a module-level ``nfilt`` global is used, which is what the original code
            relied on. That global is not defined in this notebook, so prefer passing
            it explicitly.

    Returns:
        ``(audios_val, input_length_val, texts_val, label_length_val)`` exactly as
        produced by ``pad_wav`` and ``pad_text``.

    Raises:
        NameError: if ``nfilt`` is neither passed nor defined as a global.
    """
    if nfilt is None:
        # Fail fast, before the (potentially slow) file discovery below, instead of
        # hitting an opaque NameError halfway through.
        try:
            nfilt = globals()["nfilt"]
        except KeyError:
            raise NameError(
                "get_test_data() needs `nfilt` (forwarded to pad_wav); pass it "
                "explicitly or define a global `nfilt` before calling."
            ) from None

    wavs_path_val = get_wavs_path(test_path)
    txts_path_val = get_text_path(wavs_path_val)

    audios_val, input_length_val = pad_wav(wavs_path_val, nfilt)
    texts_val, label_length_val = pad_text(txts_path_val)

    return audios_val, input_length_val, texts_val, label_length_val

In [None]:
# Build the model, load the trained weights, predict, and greedily CTC-decode the outputs.

def predict(weight_save_path, audios_val, input_length_val, nfeature, nclass, lr_rate, momentum):
    """Rebuild the network, load saved weights, and greedily CTC-decode its predictions.

    Args:
        weight_save_path: weights file passed to ``load_weights(..., by_name=True)``.
        audios_val: model input batch, passed straight to ``model.predict``.
        input_length_val: per-utterance sequence lengths. Only the first item of
            each entry is used (assumes entries shaped like ``[length]`` — TODO
            confirm against ``pad_wav``).
        nfeature, nclass, lr_rate, momentum: forwarded positionally to
            ``brnn_ctc_mfcc``, followed by two ``False`` flags whose meaning is
            defined in ``mfcc_model`` (not visible here).

    Returns:
        The best greedy decode path from ``K.ctc_decode`` (element ``[0][0]`` of
        its result), materialised with ``K.get_value``.
    """
    network = brnn_ctc_mfcc(nfeature, nclass, lr_rate, momentum, False, False)
    network.model.load_weights(weight_save_path, by_name=True)

    # Frame-level class scores for every utterance in the batch.
    scores = network.model.predict(audios_val)

    # ctc_decode wants a flat sequence of lengths, one per utterance.
    seq_lengths = []
    for entry in input_length_val:
        seq_lengths.append(entry[0])

    decoded_paths = K.ctc_decode(scores, input_length=seq_lengths, greedy=True)[0]
    return K.get_value(decoded_paths[0])

In [None]:
# Convert class indices back to characters
def decode(nums, dic=None):
    """Map a sequence of class indices back to a string.

    Args:
        nums: iterable of class indices (e.g. a 1-D integer numpy array with the
            padding already removed). Every index must be valid for ``dic[index]``.
        dic: optional index -> character lookup (anything supporting ``dic[index]``).
            Defaults to the second value returned by ``dictionary()``. Pass it
            explicitly to avoid rebuilding the lookup on every call.

    Returns:
        The looked-up characters concatenated into one ``str``.
    """
    if dic is None:
        _, dic = dictionary()
    return "".join(dic[klass] for klass in nums)


# Calculate the average per-utterance word error rate (WER) of the model predictions
def evaluate(y_pred, texts_val):
    """Print and return the mean per-utterance word error rate (WER).

    Each prediction row and each reference row is cut at its first ``-1`` (padding)
    when one is present, mapped back to characters with ``decode``, and scored with
    ``jiwer.wer(reference, hypothesis)``. The per-utterance scores are averaged
    (unweighted, i.e. not the corpus-level WER).

    Args:
        y_pred: iterable of integer label sequences, e.g. rows of the dense
            ``ctc_decode`` output, which may or may not contain ``-1`` padding.
        texts_val: iterable of integer reference label sequences, possibly padded
            with ``-1``.

    Returns:
        The mean WER as a float, or ``nan`` when there are no utterances. The line
        ``<sum> <count> <mean>`` is also printed, as before.
    """
    def _strip_padding(seq, pad_value=-1):
        # Keep everything before the first pad_value. Rows with no padding at all
        # (e.g. the longest decode in a batch) are kept whole; the original code
        # raised IndexError on such prediction rows.
        seq = np.asarray(seq)
        pad_positions = np.flatnonzero(seq == pad_value)
        return seq[:pad_positions[0]] if pad_positions.size else seq

    WER, count = 0, 0
    for pred, true in zip(y_pred, texts_val):
        hypothesis = decode(_strip_padding(pred))
        reference = decode(_strip_padding(true))
        WER += jiwer.wer(reference, hypothesis)
        count += 1

    # Guard against an empty evaluation set instead of raising ZeroDivisionError.
    mean_wer = WER / count if count else float("nan")
    print(WER, count, mean_wer)
    return mean_wer

In [None]:
# Configuration: data location, weights file and model hyper-parameters.
import os

# Root of the TIMIT test split. Override with the TIMIT_TEST_PATH environment
# variable instead of editing the notebook.
test_path = os.environ.get("TIMIT_TEST_PATH", "/TIMIT/data/TEST")

# Trained weights, loaded by layer name into the freshly built model in predict().
weight_save_path = "mfcc_Adam_weights.h5"

# Forwarded positionally to brnn_ctc_mfcc(nfeature, nclass, lr_rate, momentum, ...).
nfeature = 120       # input feature dimension (presumably) — TODO confirm
nclass = 27          # number of output classes
lr_rate = 10**(-4)   # optimiser settings; presumably unused for inference — confirm
momentum = 0.9

# NOTE(review): get_test_data() forwards an `nfilt` value to pad_wav, but no `nfilt`
# is defined anywhere in this notebook. Define it here once its value is known.

In [None]:
def __main__():
    """Run the evaluation end to end: load test data, predict with saved weights, score WER."""
    test_audio, test_input_lengths, test_texts, _unused_label_lengths = get_test_data(test_path)
    decoded = predict(
        weight_save_path,
        test_audio,
        test_input_lengths,
        nfeature,
        nclass,
        lr_rate,
        momentum,
    )
    evaluate(decoded, test_texts)


__main__()