In [None]:
import tensorflow as tf
import keras

# Load the saved Keras model from the working directory; it is wrapped by
# ExportModel below and re-exported as a TensorFlow SavedModel.
model = keras.models.load_model("qPrunedCNNLSTM.keras")


def preprocess_audio(audio_input, sample_rate=16000, segment_duration=1, frame_duration=0.025,
                     hop_duration=0.010, num_bands=50):
    """Turn one audio clip into a batched log-mel spectrogram.

    The clip is squeezed to a 1-D waveform, zero-padded at the end up to one
    segment, L2-normalized as a whole, converted to an STFT magnitude
    spectrogram (512-point FFT), projected onto ``num_bands`` mel bands and
    log-compressed.

    Args:
        audio_input: Tensor holding a single clip. Size-1 axes are removed, so
            shapes such as ``(samples,)``, ``(samples, 1)`` and ``(1, samples)``
            are all accepted. More than one clip at a time is not supported.
        sample_rate: Sampling rate in Hz, used to convert durations to sample
            counts and to build the mel filterbank.
        segment_duration: Minimum clip length in seconds; shorter clips are
            zero-padded at the end. Longer clips are left untouched.
        frame_duration: STFT window length in seconds.
        hop_duration: STFT hop length in seconds.
        num_bands: Number of mel bands.

    Returns:
        Float tensor of shape ``(1, frames, num_bands, 1)``: a leading batch
        axis and a trailing channel axis are added around the log-mel
        spectrogram.
    """
    segment_samples = int(segment_duration * sample_rate)
    frame_samples = int(frame_duration * sample_rate)
    hop_samples = int(hop_duration * sample_rate)

    # Squeeze first and measure the *squeezed* waveform. Measuring the raw
    # input instead reads the size-1 batch axis of a (1, samples) tensor as the
    # clip length and appends almost a full segment of zeros to a clip that is
    # already long enough.
    waveform = tf.squeeze(audio_input)
    missing_samples = tf.maximum(0, segment_samples - tf.shape(waveform)[0])
    audio_data_padded = tf.pad(waveform, [[0, missing_samples]])

    # No axis is given, so the normalization runs over the whole waveform.
    audio_data_normalized = tf.math.l2_normalize(audio_data_padded)

    stft = tf.signal.stft(audio_data_normalized, fft_length=512, frame_length=frame_samples, frame_step=hop_samples)
    magnitude_spectrogram = tf.abs(stft)

    # The filterbank maps the STFT frequency bins (last axis) to mel bands,
    # using TensorFlow's default lower/upper edge frequencies.
    mel_filterbank = tf.signal.linear_to_mel_weight_matrix(num_bands, tf.shape(magnitude_spectrogram)[-1],
                                                           sample_rate)
    mel_spectrogram = tf.matmul(magnitude_spectrogram, mel_filterbank)

    # The small offset keeps the log finite for empty (all-zero) bands.
    log_mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)
    log_mel_spectrogram = tf.expand_dims(log_mel_spectrogram, axis=-1)  # channel axis
    log_mel_spectrogram = tf.expand_dims(log_mel_spectrogram, axis=0)  # batch axis

    return log_mel_spectrogram


class ExportModel(tf.Module):
    """Exportable wrapper that runs audio preprocessing and the model together.

    Two concrete traces of ``__call__`` are created at construction time so
    they can be registered as separate serving signatures:

    * ``call_string_input`` takes a scalar string tensor holding a WAV path.
    * ``call_float_input`` takes a float32 tensor of shape ``[None, 16000]``.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

        # Trace the string entry point first, then the float one.
        path_spec = tf.TensorSpec(shape=(), dtype=tf.string)
        waveform_spec = tf.TensorSpec(shape=[None, 16000], dtype=tf.float32)
        self.call_string_input = self.__call__.get_concrete_function(input=path_spec)
        self.call_float_input = self.__call__.get_concrete_function(input=waveform_spec)

    @tf.function
    def __call__(self, input):
        """Classify audio given either a WAV file path or a waveform tensor.

        Args:
            input: A string tensor (path to a WAV file, which is read and
                decoded to 16000 mono samples) or a float waveform tensor that
                is passed to preprocessing unchanged.

        Returns:
            Dict with ``'predictions'`` (raw model output), ``'class_ids'``
            (argmax over the last axis) and ``'class_names'`` (the label
            string for each class id).
        """
        # The dtype is known at trace time, so this branch is resolved once
        # per concrete function rather than per call.
        if input.dtype == tf.string:
            wav_bytes = tf.io.read_file(input)
            decoded, _ = tf.audio.decode_wav(wav_bytes, desired_channels=1, desired_samples=16000)
            waveform = tf.squeeze(decoded, axis=-1)  # drop the channel axis
        else:
            waveform = input

        features = preprocess_audio(waveform)
        scores = self.model(features)

        class_labels = [
            "background", "down", "go", "left", "no", "off",
            "on", "right", "stop", "up", "yes", "unknown",
        ]
        class_ids = tf.argmax(scores, axis=-1)
        class_names = tf.gather(class_labels, class_ids)

        return {'predictions': scores,
                'class_ids': class_ids,
                'class_names': class_names}

# Wrap the loaded Keras model; constructing the wrapper traces both entry points.
export = ExportModel(model)

# Write the SavedModel to ./SavedModel with one signature per input kind:
# 'serving_default' takes a WAV file path, 'serving_float' takes a waveform.
tf.saved_model.save(export, "SavedModel", signatures={
    'serving_default': export.call_string_input,
    'serving_float': export.call_float_input})

# Reload the export from disk to check that it round-trips.
loaded = tf.saved_model.load("SavedModel")

# Waveform input: decode the WAV file here and call the 'serving_float' signature.
x = '0d53e045_nohash_1.wav'
x = tf.io.read_file(x)
x, _ = tf.audio.decode_wav(x, desired_channels=1, desired_samples=16000,)
x = tf.squeeze(x, axis=-1)  # drop the channel axis
# NOTE(review): x has no batch axis at this point, while 'serving_float' was
# traced with shape [None, 16000] — confirm the intended input rank.
output = loaded.signatures['serving_float'](x)
print(output)
# File-path input: call the reloaded module directly with the WAV path.
out = loaded("0d53e045_nohash_1.wav")
print(out)

