In [13]:
import multiprocessing
import tensorflow as tf
import librosa
import numpy as np
from jiwer import wer

This is the audio file we are going to transcribe, as well as the ground truth transcription

In [92]:
audio_file = 'demo_input/84-121550-0000.flac'
transcript = 'BUT WITH FULL RAVISHMENT THE HOURS OF PRIME SINGING RECEIVED THEY IN THE MIDST OF LEAVES THAT EVER BORE A BURDEN TO THEIR RHYMES'

We first convert the transcript into integers, as well as defining a reverse mapping for decoding the final output.

In [101]:
alphabet = "abcdefghijklmnopqrstuvwxyz' @"
alphabet_dict = {c: ind for (ind, c) in enumerate(alphabet)}
index_dict = {ind: c for (ind, c) in enumerate(alphabet)}
transcript_ints = [alphabet_dict[letter] for letter in transcript.lower()]
print(transcript_ints)

[1, 20, 19, 27, 22, 8, 19, 7, 27, 5, 20, 11, 11, 27, 17, 0, 21, 8, 18, 7, 12, 4, 13, 19, 27, 19, 7, 4, 27, 7, 14, 20, 17, 18, 27, 14, 5, 27, 15, 17, 8, 12, 4, 27, 18, 8, 13, 6, 8, 13, 6, 27, 17, 4, 2, 4, 8, 21, 4, 3, 27, 19, 7, 4, 24, 27, 8, 13, 27, 19, 7, 4, 27, 12, 8, 3, 18, 19, 27, 14, 5, 27, 11, 4, 0, 21, 4, 18, 27, 19, 7, 0, 19, 27, 4, 21, 4, 17, 27, 1, 14, 17, 4, 27, 0, 27, 1, 20, 17, 3, 4, 13, 27, 19, 14, 27, 19, 7, 4, 8, 17, 27, 17, 7, 24, 12, 4, 18]


We then load the audio file and convert it to MFCCs (with an extra batch dimension).

In [70]:
def normalize(values):
    """
    Normalize values to mean 0 and std 1
    """
    return (values - np.mean(values)) / np.std(values)

def transform_audio_to_mfcc(audio_file, transcript, n_mfcc=13, n_fft=512, hop_length=160):
    audio_data, sample_rate = librosa.load(audio_file, sr=16000)

    mfcc = librosa.feature.mfcc(audio_data, sr=sample_rate, n_mfcc=n_mfcc, n_fft=n_fft, hop_length=hop_length)

    # add derivatives and normalize
    mfcc_delta = librosa.feature.delta(mfcc)
    mfcc_delta2 = librosa.feature.delta(mfcc, order=2)
    mfcc = np.concatenate((normalize(mfcc), normalize(mfcc_delta), normalize(mfcc_delta2)), axis=0)

    seq_length = mfcc.shape[1] // 2

    sequences = np.concatenate([[seq_length], transcript]).astype(np.int32)
    sequences = np.expand_dims(sequences, 0)
    mfcc_out = mfcc.T.astype(np.float32)
    mfcc_out = np.expand_dims(mfcc_out, 0)

    return mfcc_out, sequences

In [33]:
def log(std):
    """Log the given string to the standard output."""
    print("******* {}".format(std), flush=True)

We use the ctc decoder to decode the output of the network

In [99]:
def ctc_preparation(tensor, y_predict):
    if len(y_predict.shape) == 4:
        y_predict = tf.squeeze(y_predict, axis=1)
    y_predict = tf.transpose(y_predict, (1, 0, 2))
    sequence_lengths, labels = tensor[:, 0], tensor[:, 1:]
    idx = tf.where(tf.not_equal(labels, 28))
    sparse_labels = tf.SparseTensor(
        idx, tf.gather_nd(labels, idx), tf.shape(labels, out_type=tf.int64)
    )
    return sparse_labels, sequence_lengths, y_predict

def ctc_ler(y_true, y_predict):
    sparse_labels, logit_length, y_predict = ctc_preparation(y_true, y_predict)
    decoded, log_probabilities = tf.nn.ctc_greedy_decoder(
        y_predict, tf.cast(logit_length, tf.int32), merge_repeated=True
    )
    return tf.reduce_mean(
        tf.edit_distance(
            tf.cast(decoded[0], tf.int32), tf.cast(sparse_labels, tf.int32)
        ).numpy()
    ), tf.sparse.to_dense(decoded[0]).numpy()

def trans_int_to_string(trans_int):
    #create dictionary int -> string (0 -> a 1 -> b)
    string = ""
    alphabet = "abcdefghijklmnopqrstuvwxyz' @"
    alphabet_dict = {}
    count = 0
    for x in alphabet:
        alphabet_dict[count] = x
        count += 1
    for letter in trans_int:
        letter_np = np.array(letter).item(0)
        if letter_np != 28:
            string += alphabet_dict[letter_np]
    return string

def ctc_wer(y_true, y_predict):
    sparse_labels, logit_length, y_predict = ctc_preparation(y_true, y_predict)
    decoded, log_probabilities = tf.nn.ctc_greedy_decoder(
            y_predict, tf.cast(logit_length, tf.int32), merge_repeated=True
    )
    true_sentence = tf.cast(sparse_labels.values, tf.int32)
    return wer(str(trans_int_to_string(decoded[0].values)), str(trans_int_to_string(true_sentence)))

The TFLite file requires inputs of size 296, so we apply a window to the input

In [94]:
def evaluate_tflite(tflite_path, input_window_length =  296):
    """Evaluates tflite (fp32, int8)."""
    results = []
    data, label = transform_audio_to_mfcc(audio_file, transcript_ints)

    interpreter = tf.lite.Interpreter(model_path=tflite_path, num_threads=multiprocessing.cpu_count())
    interpreter.allocate_tensors()
    input_chunk = interpreter.get_input_details()[0]
    output_details = interpreter.get_output_details()[0]

    input_shape = input_chunk["shape"]
    log("eval_model() - input_shape: {}".format(input_shape))
    input_dtype = input_chunk["dtype"]
    output_dtype = output_details["dtype"]

    # Check if the input/output type is quantized,
    # set scale and zero-point accordingly
    if input_dtype != tf.float32:
        input_scale, input_zero_point = input_chunk["quantization"]
    else:
        input_scale, input_zero_point = 1, 0

    if output_dtype != tf.float32:
        output_scale, output_zero_point = output_details["quantization"]
    else:
        output_scale, output_zero_point = 1, 0


    data = data / input_scale + input_zero_point
    # Round the data up if dtype is int8, uint8 or int16
    if input_dtype is not np.float32:
        data = np.round(data)

    while data.shape[1] < input_window_length:
        data = np.append(data, data[:, -2:-1, :], axis=1)
    # Zero-pad any odd-length inputs
    if data.shape[1] % 2 == 1:
        # log('Input length is odd, zero-padding to even (first layer has stride 2)')
        data = np.concatenate([data, np.zeros((1, 1, data.shape[2]), dtype=input_dtype)], axis=1)

    context = 24 + 2 * (7 * 3 + 16)  # = 98 - theoretical max receptive field on each side
    size = input_chunk['shape'][1]
    inner = size - 2 * context
    data_end = data.shape[1]

    # Initialize variables for the sliding window loop
    data_pos = 0
    outputs = []

    while data_pos < data_end:
        if data_pos == 0:
            # Align inputs from the first window to the start of the data and include the intial context in the output
            start = data_pos
            end = start + size
            y_start = 0
            y_end = y_start + (size - context) // 2
            data_pos = end - context
        elif data_pos + inner + context >= data_end:
            # Shift left to align final window to the end of the data and include the final context in the output
            shift = (data_pos + inner + context) - data_end
            start = data_pos - context - shift
            end = start + size
            assert start >= 0
            y_start = (shift + context) // 2  # Will be even because we assert it above
            y_end = size // 2
            data_pos = data_end
        else:
            # Capture only the inner region from mid-input inferences, excluding output from both context regions
            start = data_pos - context
            end = start + size
            y_start = context // 2
            y_end = y_start + inner // 2
            data_pos = end - context

        interpreter.set_tensor(input_chunk["index"], tf.cast(data[:, start:end, :], input_dtype))
        interpreter.invoke()
        cur_output_data = interpreter.get_tensor(output_details["index"])[:, :, y_start:y_end, :]
        cur_output_data = output_scale * (
                cur_output_data.astype(np.float32) - output_zero_point
        )
        outputs.append(cur_output_data)

    complete = np.concatenate(outputs, axis=2)
    LER, output = ctc_ler(label, complete)
    WER = ctc_wer(label, complete)
    return output, LER , WER


In [107]:
wav2letter_tflite_path = "tflite_int8/tiny_wav2letter_int8.tflite"
output, LER , WER = evaluate_tflite(wav2letter_tflite_path)

decoded_output = [index_dict[value] for value in output[0]]
log(f'Transcribed File: {"".join(decoded_output)}')
log(f'Letter Error Rate is {LER}')
log(f'Word Error Rate is {WER}')

******* eval_model() - input_shape: [  1 296  39]
******* Input length is odd, zero-padding to even (first layer has stride 2)
******* Transcribed File: but with full ravishment the hours of prime singing received they in the midst of leaves that everborea burden to their rimes
******* Letter Error Rate is 0.03125
******* Word Error Rate is 1.05
