Wrapping Mel-spectrogram extraction into the model so that it accepts raw audio

In [1]:
# Mount Google Drive so the trained CNN under MyDrive can be loaded below.
from google.colab import drive

drive.mount(mountpoint="/content/drive")

Mounted at /content/drive


In [2]:
import tensorflow as tf
import numpy as np
from tensorflow.keras.models import load_model
from tensorflow.keras import layers

In [5]:
# Constants
SAMPLE_RATE = 16000                    # Hz
DURATION = 2                           # seconds of audio per clip
NUM_SAMPLES = SAMPLE_RATE * DURATION   # length of the raw-audio model input
N_MELS = 40                            # mel bands fed to the CNN
FRAME_LENGTH = 512                     # STFT window length, in samples
FRAME_STEP = 256                       # STFT hop size, in samples
FFT_LENGTH = 512                       # gives FFT_LENGTH // 2 + 1 = 257 frequency bins
EXPECTED_TIME_STEPS = 63               # The number of time steps expected by the CNN model

# Frames produced by tf.signal.stft without end padding:
# 1 + (32000 - 512) // 256 = 124 with the values above.
NUM_STFT_FRAMES = 1 + (NUM_SAMPLES - FRAME_LENGTH) // FRAME_STEP

# NOTE(review): 124 frames is about twice EXPECTED_TIME_STEPS, so cropping to 63
# keeps only the first ~1.02 s of each 2 s clip. 63 frames is what a 512-sample
# hop over 32000 samples gives with end/center padding (ceil(32000 / 512) = 63).
# Confirm which hop size the CNN was trained with before trusting this wrapper.
if NUM_STFT_FRAMES > EXPECTED_TIME_STEPS:
    print(
        f"⚠️ STFT produces {NUM_STFT_FRAMES} frames per clip; only the first "
        f"{EXPECTED_TIME_STEPS} will reach the CNN."
    )

# Path to the trained CNN, kept as a constant so it is easy to change.
MODEL_PATH = "/content/drive/MyDrive/cnn_audio_classifier_approach (1).h5"

# Load your trained CNN model that takes (40, 63, 1), i.e. (N_MELS, EXPECTED_TIME_STEPS, 1)
cnn_model = load_model(MODEL_PATH)

# ----------------------------
# Wrapper Model: Raw audio -> Mel Spectrogram -> CNN
# ----------------------------

# Custom Keras Layer for STFT
class STFTLayer(layers.Layer):
    """Magnitude short-time Fourier transform of raw waveforms.

    Wraps ``tf.signal.stft`` (with its default window function) and returns
    ``abs(stft)``, i.e. the magnitude, not the power. For an input of shape
    ``(..., samples)`` the output has shape
    ``(..., num_frames, fft_length // 2 + 1)``.

    Args:
        frame_length: Window length in samples.
        frame_step: Hop size in samples between consecutive frames.
        fft_length: FFT size; sets the number of frequency bins.
        pad_end: Forwarded to ``tf.signal.stft``. If True, the signal end is
            zero-padded so a final partial frame is kept. Defaults to False,
            which matches the original behavior.
    """

    def __init__(self, frame_length, frame_step, fft_length, pad_end=False, **kwargs):
        super().__init__(**kwargs)
        self.frame_length = frame_length
        self.frame_step = frame_step
        self.fft_length = fft_length
        self.pad_end = pad_end

    def call(self, inputs):
        stft = tf.signal.stft(
            inputs,
            frame_length=self.frame_length,
            frame_step=self.frame_step,
            fft_length=self.fft_length,
            pad_end=self.pad_end,
        )
        return tf.abs(stft)

    def get_config(self):
        # Explicit config so the layer can be re-created when the saved model
        # is reloaded (older tf.keras requires this for layers with __init__ args).
        config = super().get_config()
        config.update({
            "frame_length": self.frame_length,
            "frame_step": self.frame_step,
            "fft_length": self.fft_length,
            "pad_end": self.pad_end,
        })
        return config

# Custom Keras Layer for Mel Spectrogram
class MelSpectrogramLayer(layers.Layer):
    """Project a linear-frequency spectrogram onto a mel filterbank.

    The filterbank matrix is built once in ``build`` from the size of the
    input's last axis (the number of linear-frequency bins). It is applied as a
    contraction over that axis, so ``(..., num_spectrogram_bins)`` becomes
    ``(..., num_mel_bins)``. The matrix is a constant, not a trainable weight.

    Args:
        num_mel_bins: Number of mel bands in the output.
        sample_rate: Sample rate (Hz) of the audio the spectrogram came from.
        lower_edge_hertz: Lowest frequency (Hz) covered by the filterbank.
        upper_edge_hertz: Highest frequency (Hz) covered by the filterbank.

    Raises:
        ValueError: In ``build``, if the input's last dimension is not
            statically known.
    """

    def __init__(self, num_mel_bins, sample_rate, lower_edge_hertz, upper_edge_hertz, **kwargs):
        super().__init__(**kwargs)
        self.num_mel_bins = num_mel_bins
        self.sample_rate = sample_rate
        self.lower_edge_hertz = lower_edge_hertz
        self.upper_edge_hertz = upper_edge_hertz

    def build(self, input_shape):
        num_spectrogram_bins = input_shape[-1]
        if num_spectrogram_bins is None:
            raise ValueError(
                "MelSpectrogramLayer needs a statically known last dimension "
                f"(number of spectrogram bins); got input shape {input_shape}."
            )
        # Shape (num_spectrogram_bins, num_mel_bins).
        self.linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
            num_mel_bins=self.num_mel_bins,
            num_spectrogram_bins=int(num_spectrogram_bins),
            sample_rate=self.sample_rate,
            lower_edge_hertz=self.lower_edge_hertz,
            upper_edge_hertz=self.upper_edge_hertz
        )
        super().build(input_shape)

    def call(self, inputs):
        mel_spectrogram = tf.tensordot(inputs, self.linear_to_mel_weight_matrix, axes=1)
        # tf.tensordot's shape inference does not always preserve the static
        # shape here; restore (..., num_mel_bins) for downstream layers.
        mel_spectrogram.set_shape(inputs.shape[:-1].concatenate(self.linear_to_mel_weight_matrix.shape[-1:]))
        return mel_spectrogram

    def get_config(self):
        # Explicit config so the layer can be re-created on model reload.
        config = super().get_config()
        config.update({
            "num_mel_bins": self.num_mel_bins,
            "sample_rate": self.sample_rate,
            "lower_edge_hertz": self.lower_edge_hertz,
            "upper_edge_hertz": self.upper_edge_hertz,
        })
        return config

# Custom Keras Layer for Log Scale
class LogScaleLayer(layers.Layer):
    """Natural-log compression: ``log(inputs + epsilon)``.

    This is the natural logarithm, not decibels (``10 * log10``). ``epsilon``
    keeps the result finite where the input is zero.
    NOTE(review): this must match the scaling applied when the CNN was
    trained. Confirm against the training preprocessing.

    Args:
        epsilon: Small constant added before taking the log.
    """

    def __init__(self, epsilon=1e-6, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def call(self, inputs):
        return tf.math.log(inputs + self.epsilon)

    def get_config(self):
        # Explicit config so the layer can be re-created on model reload.
        config = super().get_config()
        config.update({"epsilon": self.epsilon})
        return config

# Custom Keras Layer for Expanding Dimensions
class ExpandDimsLayer(layers.Layer):
    """Insert a length-1 axis at position ``axis`` (``tf.expand_dims``).

    Args:
        axis: Position of the new axis; negative values count from the end.
    """

    def __init__(self, axis, **kwargs):
        super().__init__(**kwargs)
        self.axis = axis

    def call(self, inputs):
        return tf.expand_dims(inputs, axis=self.axis)

    def get_config(self):
        # Explicit config so the layer can be re-created on model reload.
        config = super().get_config()
        config.update({"axis": self.axis})
        return config

# Custom Keras Layer for Cropping Time Steps
class CroppingTimeLayer(layers.Layer):
    """Keep at most ``target_time_steps`` entries along axis 1 (the time axis).

    Inputs whose axis 1 already has ``target_time_steps`` or fewer entries pass
    through unchanged. They are not padded.

    Args:
        target_time_steps: Maximum number of time steps to keep.
    """

    def __init__(self, target_time_steps, **kwargs):
        super().__init__(**kwargs)
        self.target_time_steps = target_time_steps

    def call(self, inputs):
        # A slice end past the axis length is clamped (as in NumPy/Python
        # slicing), so no tf.cond is needed for the "already short" case.
        # A constant slice bound also keeps the time dimension static when the
        # input length is known. The previous tf.cond lost that because its
        # two branches had different shapes. Works for any rank >= 2.
        return inputs[:, :self.target_time_steps]

    def get_config(self):
        # Explicit config so the layer can be re-created on model reload.
        config = super().get_config()
        config.update({"target_time_steps": self.target_time_steps})
        return config

# Custom Keras Layer for Transposing Dimensions
class TransposeLayer(layers.Layer):
    """Permute tensor axes with ``tf.transpose``.

    Args:
        perm: Axis permutation, including the batch axis (e.g. ``[0, 2, 1, 3]``).
    """

    def __init__(self, perm, **kwargs):
        super().__init__(**kwargs)
        self.perm = perm

    def call(self, inputs):
        return tf.transpose(inputs, perm=self.perm)

    def get_config(self):
        # Explicit config so the layer can be re-created on model reload.
        config = super().get_config()
        config.update({"perm": list(self.perm)})
        return config


# Step 1: Input layer for raw audio (SAMPLE_RATE * DURATION samples per clip)
input_audio = tf.keras.Input(shape=(SAMPLE_RATE * DURATION,), name="raw_audio")

# Step 2: Magnitude STFT -> (batch, frames, FFT_LENGTH // 2 + 1)
spectrogram = STFTLayer(
    frame_length=FRAME_LENGTH,
    frame_step=FRAME_STEP,
    fft_length=FFT_LENGTH
)(input_audio)

# Step 3: Apply an 80-7600 Hz mel filterbank -> (batch, frames, N_MELS)
mel_spectrogram = MelSpectrogramLayer(
    num_mel_bins=N_MELS,
    sample_rate=SAMPLE_RATE,
    lower_edge_hertz=80.0,
    upper_edge_hertz=7600.0
)(spectrogram)

# Step 4: Natural-log compression, log(x + 1e-6). This is NOT decibels.
# NOTE(review): must match the scaling used when the CNN was trained. Confirm.
log_mel_spectrogram = LogScaleLayer()(mel_spectrogram)

# Step 5: Add a channel dimension -> (batch, frames, N_MELS, 1)
log_mel_spectrogram = ExpandDimsLayer(axis=-1)(log_mel_spectrogram)

# Step 6: Keep only the first EXPECTED_TIME_STEPS frames.
# NOTE(review): with FRAME_STEP = 256 and no end padding, the STFT yields
# 1 + (32000 - 512) // 256 = 124 frames. This crop therefore keeps only
# about the first 1.02 s of each 2 s clip. Confirm it matches training.
cropped_mel_spectrogram = CroppingTimeLayer(target_time_steps=EXPECTED_TIME_STEPS)(log_mel_spectrogram)

# Step 7: Transpose to the CNN's layout (Batch, N_MELS, TimeSteps, Channels)
transposed_mel_spectrogram = TransposeLayer(perm=[0, 2, 1, 3])(cropped_mel_spectrogram)

# Step 8: Pass through CNN model
output = cnn_model(transposed_mel_spectrogram)

# Final model: raw audio in, CNN prediction out
full_model = tf.keras.Model(inputs=input_audio, outputs=output, name="RawAudioToPrediction")

# Save full model. The path is relative, so it resolves against the kernel's
# working directory (not the mounted Drive).
full_model.save("cnn_with_preprocessing_2.h5")



Converting the wrapped model to TensorFlow Lite (.tflite)

In [6]:
# Convert the end-to-end model (raw audio in, prediction out) to a TFLite flatbuffer.
converter = tf.lite.TFLiteConverter.from_keras_model(full_model)
tflite_model = converter.convert()

# Persist the serialized bytes; the relative path resolves against the
# kernel's working directory.
with open("cnn_with_preprocessing_2.tflite", mode="wb") as tflite_file:
    tflite_file.write(tflite_model)

print("✅ Exported cnn_with_preprocessing_2.tflite")


Saved artifact at '/tmp/tmpa_vd5laq'. The following endpoints are available:

* Endpoint 'serve'
  args_0 (POSITIONAL_ONLY): TensorSpec(shape=(None, 32000), dtype=tf.float32, name='raw_audio')
Output Type:
  TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)
Captures:
  140252764705744: TensorSpec(shape=(257, 40), dtype=tf.float32, name=None)
  140252783985488: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783986832: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783988944: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783990288: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783990096: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783987984: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783991824: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783993552: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783992016: TensorSpec(shape=(), dtype=tf.resource, name=None)
  140252783994128: 