# CNN-bi-LSTM for Speech Recognition

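This notebook implements a CNN-bi-LSTM neural network for speech recognition on the LJ Speech dataset. The model transcribes speech to text by combining convolutional neural network (CNN) layers, which extract local time–frequency features from log-mel spectrograms, with bidirectional long short-term memory (LSTM) layers, which model context in both directions. It is trained with the Connectionist Temporal Classification (CTC) loss, so no frame-level alignment between audio and text is required.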


In [None]:
# Import necessary libraries
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import wave
from IPython.display import Audio
from sklearn.model_selection import train_test_split


## 1. Data Loading and Exploration

First, let's load the metadata and explore the dataset.


In [None]:
# Path to the dataset
import csv

data_path = "data/LJSpeech-1.1/"
wavs_path = os.path.join(data_path, "wavs/")
metadata_path = os.path.join(data_path, "metadata.csv")

# Load metadata.
# The file is pipe-separated with no header row. Quoting is disabled
# (csv.QUOTE_NONE) so quote characters inside transcriptions are kept verbatim.
# keep_default_na=False keeps text such as "NA" or "null" as strings instead of
# turning it into NaN, which would break the str-based text preprocessing later.
metadata_df = pd.read_csv(
    metadata_path,
    sep="|",
    header=None,
    names=["file_id", "transcription", "normalized_transcription"],
    quoting=csv.QUOTE_NONE,
    keep_default_na=False,
)
metadata_df["wav_file"] = metadata_df["file_id"].apply(lambda x: os.path.join(wavs_path, f"{x}.wav"))

# Display the first few rows
metadata_df.head()


In [None]:
# Dataset size and a check that every referenced wav file is present.
print(f"Total number of samples: {len(metadata_df)}")

missing_files = list(filter(lambda path: not os.path.exists(path), metadata_df["wav_file"]))
print(f"Number of missing wav files: {len(missing_files)}")

# Character-length statistics of the raw transcriptions.
metadata_df["transcription_length"] = metadata_df["transcription"].map(len)
lengths = metadata_df["transcription_length"]
print(f"Average transcription length: {lengths.mean():.2f} characters")
print(f"Min transcription length: {lengths.min()} characters")
print(f"Max transcription length: {lengths.max()} characters")


## 2. Audio Preprocessing

Now, let's implement functions to load and preprocess the audio files. We'll convert each waveform to a log-mel spectrogram (80 mel bins per frame), which will be the input to our neural network.


In [None]:
# Function to read a wav file and return the audio data
def read_wav_file(wav_file):
    """Read a 16-bit PCM WAV file and return mono float samples in [-1, 1).

    Multi-channel files are down-mixed to mono by averaging the channels, so
    the result is always a 1-D array suitable for ``tf.signal.stft``.

    Args:
        wav_file: Path to the ``.wav`` file.

    Returns:
        A tuple ``(audio_data, frame_rate)``. ``audio_data`` is a 1-D
        ``np.float32`` array and ``frame_rate`` is the sample rate in Hz.

    Raises:
        ValueError: If the samples are not 16 bits wide, because the int16
            decoding below would otherwise silently produce garbage.
    """
    with wave.open(wav_file, "rb") as wav:
        n_channels = wav.getnchannels()
        sample_width = wav.getsampwidth()
        frame_rate = wav.getframerate()
        raw_frames = wav.readframes(wav.getnframes())

    if sample_width != 2:
        raise ValueError(
            f"Unsupported sample width {sample_width} bytes in {wav_file}; "
            "expected 16-bit PCM."
        )

    # WAV PCM samples are little-endian; "<i2" makes that explicit regardless
    # of the host byte order. Dividing by 2**15 maps int16 onto [-1, 1).
    audio_data = np.frombuffer(raw_frames, dtype="<i2").astype(np.float32) / 32768.0

    # Multi-channel samples are interleaved frame by frame: reshape to
    # (frames, channels) and average so callers always get mono audio.
    if n_channels > 1:
        audio_data = audio_data.reshape(-1, n_channels).mean(axis=1)

    return audio_data, frame_rate


In [None]:
# Function to compute spectrogram from audio data
def compute_spectrogram(
    audio_data,
    frame_rate,
    frame_length=256,
    frame_step=160,
    fft_length=384,
    num_mel_bins=80,
    lower_edge_hertz=80.0,
    upper_edge_hertz=7600.0,
):
    """Compute a natural-log mel spectrogram of a 1-D waveform.

    Args:
        audio_data: 1-D float waveform (as returned by ``read_wav_file``).
        frame_rate: Sample rate of ``audio_data`` in Hz.
        frame_length: STFT window length in samples.
        frame_step: Hop between successive STFT frames in samples.
        fft_length: FFT size; the linear spectrogram has
            ``fft_length // 2 + 1`` frequency bins.
        num_mel_bins: Number of mel bands per output frame. This must match
            the model's ``input_dim``.
        lower_edge_hertz: Lowest frequency covered by the mel filterbank.
        upper_edge_hertz: Highest frequency covered by the mel filterbank.
            It should not exceed the Nyquist frequency ``frame_rate / 2``.

    Returns:
        A ``(num_frames, num_mel_bins)`` float tensor holding
        ``log(mel + 1e-6)``. The log is natural, not decibels.
    """
    # Short-Time Fourier Transform with a Hann window.
    stft = tf.signal.stft(
        audio_data,
        frame_length=frame_length,
        frame_step=frame_step,
        fft_length=fft_length,
        window_fn=tf.signal.hann_window,
    )

    # Magnitude spectrogram: (num_frames, fft_length // 2 + 1).
    spectrogram = tf.abs(stft)

    # Project the linear-frequency bins onto the mel scale.
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, spectrogram.shape[-1], frame_rate, lower_edge_hertz, upper_edge_hertz
    )
    mel_spectrogram = tf.tensordot(spectrogram, linear_to_mel_weight_matrix, 1)

    # Natural-log compression (not dB); the epsilon avoids log(0) on silent frames.
    log_mel_spectrogram = tf.math.log(mel_spectrogram + 1e-6)

    return log_mel_spectrogram


In [None]:
# Let's visualize the log-mel spectrogram of one audio file
sample_wav_file = metadata_df["wav_file"].iloc[0]
sample_audio_data, sample_frame_rate = read_wav_file(sample_wav_file)
sample_spectrogram = compute_spectrogram(sample_audio_data, sample_frame_rate)

plt.figure(figsize=(10, 4))
# The spectrogram is (frames, mel bins); transpose so time runs along the x-axis.
plt.imshow(tf.transpose(sample_spectrogram), aspect="auto", origin="lower")
# compute_spectrogram applies tf.math.log (natural log), so the values are
# log-magnitudes, not decibels.
plt.colorbar(format="%+2.0f", label="log mel magnitude")
plt.title("Log Mel Spectrogram")
plt.xlabel("Time (frames)")
plt.ylabel("Mel Frequency")
plt.tight_layout()
plt.show()

# Play the audio
Audio(sample_audio_data, rate=sample_frame_rate)


## 3. Text Preprocessing

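Next, let's preprocess the text data. We'll create a character-level tokenizer to convert text to sequences of integers. The text is lowercased, and every character outside the vocabulary (lowercase letters, apostrophe, `?`, `!` and space) is dropped before encoding. Dropped characters include digits, commas, periods and hyphens.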


In [None]:
# Character vocabulary: lowercase letters plus apostrophe, '?', '!' and space.
characters = list("abcdefghijklmnopqrstuvwxyz'?! ")

# Forward lookup (character -> integer id). The "" OOV token takes one index
# of its own in addition to the characters above.
char_to_num = tf.keras.layers.StringLookup(vocabulary=characters, oov_token="")

# Inverse lookup (integer id -> character), built from the exact same vocabulary.
num_to_char = tf.keras.layers.StringLookup(
    vocabulary=char_to_num.get_vocabulary(),
    oov_token="",
    invert=True,
)

print(
    f"The vocabulary is: {char_to_num.get_vocabulary()} "
    f"(size: {len(char_to_num.get_vocabulary())})"
)


In [None]:
# Function to preprocess text
def preprocess_text(text):
    """Encode a transcription as a tensor of character ids.

    The text is lowercased and every character absent from the global
    ``characters`` vocabulary is dropped before lookup, so each remaining
    character maps through ``char_to_num`` to an in-vocabulary id.

    Args:
        text: The transcription string.

    Returns:
        A 1-D tensor of ids, one per kept character (empty if none are kept).
    """
    kept = "".join(filter(lambda ch: ch in characters, text.lower()))
    return char_to_num(tf.strings.unicode_split(kept, input_encoding="UTF-8"))


In [None]:
# Round-trip check: encode the first transcription, then decode the ids back.
sample_text = metadata_df["transcription"].iloc[0]
print(f"Original text: {sample_text}")

processed_text = preprocess_text(sample_text)
print(f"Processed text: {processed_text}")

decoded_text = (
    tf.strings.reduce_join(num_to_char(processed_text))
    .numpy()
    .decode("utf-8")
)
print(f"Decoded text: {decoded_text}")


## 4. Data Generator

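Now, let's create a data generator that loads and preprocesses the data in batches. It pads spectrograms and labels to a common length and computes the sequence lengths required by the CTC loss.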


In [None]:
# Hold out 20% of the utterances for validation (fixed seed for reproducibility).
train_df, val_df = train_test_split(
    metadata_df,
    test_size=0.2,
    random_state=42,
)
for split_name, split_df in (("Training", train_df), ("Validation", val_df)):
    print(f"{split_name} set size: {len(split_df)}")


In [None]:
def calc_output_length(input_length: int, num_pooling_layers: int = 2, pool_size: int = 2) -> int:
    """Return the number of time steps left after the CNN's pooling layers.

    ``build_model`` applies two ``MaxPooling2D((2, 2))`` layers with Keras'
    default "valid" padding, and each one floor-divides the time axis by 2.
    The "same"-padded convolutions leave the time axis unchanged. The defaults
    reproduce that architecture; pass other values if the model changes.

    Args:
        input_length: Number of spectrogram frames before the CNN.
        num_pooling_layers: How many pooling layers stride over the time axis.
        pool_size: Pool size (and stride) of each of those layers along time.

    Returns:
        The number of time steps the model emits, i.e. the CTC ``input_length``.
    """
    length = input_length
    for _ in range(num_pooling_layers):
        length //= pool_size
    return length


# Create a data generator
class AudioDataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that yields padded batches for CTC training.

    Each item is ``(inputs, outputs)``. ``inputs`` is a dict keyed by the
    training model's input names:

    * ``"input"``: float tensor ``(batch, max_frames, mel_bins)``, zero-padded
      along time.
    * ``"input_length"``: int32 ``(batch, 1)``. This is each sample's number of
      model time steps, i.e. ``calc_output_length`` of its *unpadded* frame
      count. Padding frames are therefore not treated as speech.
    * ``"label"``: int array ``(batch, max_label_len)`` of character ids,
      right-padded with 0.
    * ``"label_length"``: int32 ``(batch, 1)``, the unpadded label lengths.

    ``outputs`` is a dummy zero tensor. The training model computes the CTC
    loss itself and its compile loss just returns the model output.

    A sample is dropped when its spectrogram is empty, when its encoded label
    is empty, or when the pooled sequence is too short for CTC to align the
    label.
    """

    def __init__(self, dataframe, batch_size=32, shuffle=True,
                 text_column="normalized_transcription", **kwargs):
        """
        Args:
            dataframe: DataFrame with a ``"wav_file"`` column and a text column.
            batch_size: Maximum samples per batch. Invalid samples are
                dropped, so a batch may come out smaller.
            shuffle: Shuffle the sample order initially and after every epoch.
            text_column: Column used as the label text. ``preprocess_text``
                drops every character outside the vocabulary (including
                digits), so the normalized column is the default. In LJ Speech
                that column presumably spells out numbers and abbreviations;
                verify this against your metadata.csv.
            **kwargs: Forwarded to ``tf.keras.utils.Sequence``.
        """
        super().__init__(**kwargs)
        self.dataframe = dataframe
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.text_column = text_column
        self.indices = np.arange(len(dataframe))
        if self.shuffle:
            np.random.shuffle(self.indices)

    def __len__(self):
        # Batches per epoch; the last one may be partial.
        return int(np.ceil(len(self.dataframe) / self.batch_size))

    @staticmethod
    def _min_ctc_input_length(label_ids):
        """Return the minimum number of model time steps CTC needs for ``label_ids``.

        Each label needs one step. Every pair of identical adjacent labels also
        needs a blank step between them, or CTC would merge them.
        """
        ids = np.asarray(label_ids)
        return int(ids.size + np.count_nonzero(ids[1:] == ids[:-1]))

    def _load_sample(self, row):
        """Return ``(spectrogram, encoded_label)`` for one row, or ``None`` to skip it."""
        audio_data, frame_rate = read_wav_file(row["wav_file"])
        spectrogram = compute_spectrogram(audio_data, frame_rate)

        text = row[self.text_column]
        if not isinstance(text, str):
            # Missing cells come back from pandas as NaN; treat them as empty labels.
            text = ""
        text_encoded = preprocess_text(text)

        # Skip empty inputs or labels first, so the CTC check below never
        # runs on degenerate samples.
        if spectrogram.shape[0] == 0:
            print(f"Skipping sample with zero-length spectrogram: {row['wav_file']}")
            return None
        if len(text_encoded) == 0:
            print(f"Skipping sample with empty label: {row['wav_file']} transcription: {text}")
            return None

        # The model emits calc_output_length(frames) steps after pooling, not
        # one step per raw frame. Checking the raw frame count would let
        # unalignable samples through, and CTC would then yield an infinite loss.
        if calc_output_length(spectrogram.shape[0]) < self._min_ctc_input_length(text_encoded):
            return None

        return spectrogram, text_encoded

    def __getitem__(self, idx):
        batch_indices = self.indices[idx * self.batch_size:(idx + 1) * self.batch_size]
        batch_df = self.dataframe.iloc[batch_indices]

        samples = [self._load_sample(row) for _, row in batch_df.iterrows()]
        samples = [sample for sample in samples if sample is not None]
        if not samples:
            raise ValueError("No valid samples in batch")

        batch_spectrograms, batch_labels = zip(*samples)
        frame_counts = [int(spectrogram.shape[0]) for spectrogram in batch_spectrograms]
        max_frames = max(frame_counts)

        # Zero-pad every spectrogram along time to the longest one in the batch.
        padded_spectrograms = tf.stack([
            tf.pad(spectrogram, [[0, max_frames - n_frames], [0, 0]], "CONSTANT")
            for spectrogram, n_frames in zip(batch_spectrograms, frame_counts)
        ])

        # Right-pad labels with 0; 0 is the OOV id, which preprocess_text never emits.
        padded_labels = tf.keras.preprocessing.sequence.pad_sequences(
            list(batch_labels), padding="post"
        )

        # Per-sample lengths, so CTC ignores the padded tail of shorter samples.
        input_lengths = tf.expand_dims(
            tf.constant([calc_output_length(n) for n in frame_counts], dtype=tf.int32),
            axis=-1,
        )
        label_lengths = tf.expand_dims(
            tf.constant([len(label) for label in batch_labels], dtype=tf.int32),
            axis=-1,
        )

        inputs = {
            "input": padded_spectrograms,
            "input_length": input_lengths,
            "label_length": label_lengths,
            "label": padded_labels,
        }
        outputs = tf.zeros((len(frame_counts), 1), dtype=tf.float32)

        return inputs, outputs

    def on_epoch_end(self):
        if self.shuffle:
            np.random.shuffle(self.indices)

In [None]:
# Training batches are reshuffled every epoch; validation keeps a fixed order.
batch_size = 32
generator_kwargs = {"batch_size": batch_size}
train_generator = AudioDataGenerator(train_df, shuffle=True, **generator_kwargs)
val_generator = AudioDataGenerator(val_df, shuffle=False, **generator_kwargs)

# Smoke test: build the first training batch to surface data problems early.
inputs, outputs = train_generator[0]

## 5. CNN-bi-LSTM Model

Now, let's implement the CNN-bi-LSTM model for speech recognition.


In [None]:
# Define the CTC loss function
def ctc_loss(y_true, y_pred):
    """Mean CTC loss for dense, zero-padded labels.

    NOTE: ``build_model`` computes its loss with an in-graph ``Lambda`` layer,
    so the cells shown here never use this standalone function.

    Args:
        y_true: ``(batch, max_label_len)`` integer labels. This assumes they
            are right-padded with 0, as ``AudioDataGenerator`` produces them.
        y_pred: ``(batch, time, num_classes)`` softmax outputs.

    Returns:
        Scalar mean of the per-sample CTC losses.
    """
    batch_len = tf.shape(y_true)[0]

    # Every sample is scored over the full time axis of y_pred.
    input_length = tf.shape(y_pred)[1] * tf.ones(shape=(batch_len, 1), dtype=tf.int32)

    # Count only the real (non-zero) label ids. Using tf.shape(y_true)[1]
    # would count the zero padding as part of every label and corrupt the
    # loss for all but the longest transcription in the batch.
    label_length = tf.reduce_sum(
        tf.cast(tf.not_equal(y_true, 0), tf.int32), axis=1, keepdims=True
    )

    loss = tf.keras.backend.ctc_batch_cost(y_true, y_pred, input_length, label_length)
    return tf.reduce_mean(loss)

In [None]:
# Define a custom layer to get tensor shape
class ShapeLayer(layers.Layer):
    """Output the run-time shape of the input as a 1-D integer tensor.

    ``tf.shape`` is evaluated when the layer runs, so dimensions that are
    unknown (``None``) while the model is being built, such as batch size or
    sequence length, take their actual values.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        return tf.shape(inputs)


# Define a custom layer for dynamic reshaping
class DynamicReshapeLayer(layers.Layer):
    """Flatten every axis after the time axis into a single feature axis.

    Maps ``(batch, time, d2, d3)`` to ``(batch, time, d2 * d3)``, and a 5-D
    ``(batch, time, d2, d3, d4)`` to ``(batch, time, d2 * d3 * d4)``. The
    sizes come from ``tf.shape`` (run time), so batch and time may be unknown.
    In ``build_model`` the input is the 4-D ``Conv2D``/``BatchNormalization``
    output ``(batch, time, freq, channels)``, so the LSTMs receive
    ``(batch, time, freq * channels)``.
    """

    def __init__(self, **kwargs):
        super(DynamicReshapeLayer, self).__init__(**kwargs)

    def call(self, inputs):
        # Get the shape information from the inputs
        batch_size = tf.shape(inputs)[0]
        # NOTE: a channels-last Conv2D output is 4-D: [batch, time, freq, channels].
        # The 4-D branch below calls the last two axes "height" and "width"
        # and sets channels = 1, which gives the same flattened size
        # (freq * channels). The else branch expects a 5-D input; any rank
        # other than 4 goes there and needs at least 5 axes (it indexes shape[4]).
        shape = tf.shape(inputs)

        # Handle both 4D and 5D tensors.
        # The rank comes from the static shape, so the branch is fixed when the layer is traced.
        if len(inputs.shape) == 4:  # [batch, time, height, width]
            time_steps = shape[1]
            height = shape[2]
            width = shape[3]
            channels = 1
        else:  # [batch, time, height, width, channels]
            time_steps = shape[1]
            height = shape[2]
            width = shape[3]
            channels = shape[4]

        # Reshape to flatten height, width, and channels while preserving time dimension
        return tf.reshape(inputs, [batch_size, time_steps, height * width * channels])


# Define the CNN-bi-LSTM model
def build_model(input_dim, output_dim):
    """Build the CTC training model and the matching inference model.

    Args:
        input_dim: Number of features per spectrogram frame (mel bins).
        output_dim: Number of label classes, excluding the CTC blank. The
            softmax layer has ``output_dim + 1`` units, and the extra last
            class is the blank that ``tf.keras.backend.ctc_batch_cost`` uses.

    Returns:
        ``(training_model, prediction_model)``:

        * ``training_model`` takes the inputs ``"input"``, ``"input_length"``,
          ``"label"`` and ``"label_length"`` and outputs the per-sample CTC
          loss computed in a ``Lambda`` layer. Its compile loss returns
          ``y_pred`` unchanged, so the targets passed to ``fit`` are ignored.
        * ``prediction_model`` maps ``"input"`` to per-step softmax
          probabilities. It is built from the same layer objects, so both
          models share weights.

    ``input_length`` must count model time steps *after* the two 2x2
    max-pooling layers (see ``calc_output_length``), not raw frames.
    """
    # Input layer
    # Spectrogram: (batch, time, input_dim); time is variable-length.
    input_spectrogram = layers.Input(shape=(None, input_dim), name="input")
    input_length = layers.Input(shape=(1,), dtype=tf.int32, name="input_length")
    label = layers.Input(shape=(None,), dtype=tf.int32, name="label")
    label_length = layers.Input(shape=(1,), dtype=tf.int32, name="label_length")

    # Expand dimensions for CNN: (batch, time, input_dim, 1) — one input channel.
    x = layers.Reshape((-1, input_dim, 1))(input_spectrogram)

    # CNN layers. Each MaxPooling2D((2, 2)) floor-halves both the time and the
    # frequency axis; the "same"-padded convolutions keep their sizes.
    x = layers.Conv2D(32, (3, 3), activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.Conv2D(64, (3, 3), activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)
    x = layers.MaxPooling2D((2, 2))(x)

    x = layers.Conv2D(128, (3, 3), activation="relu", padding="same")(x)
    x = layers.BatchNormalization()(x)

    # Reshape for RNN using the custom DynamicReshapeLayer:
    # (batch, time // 4, input_dim // 4, 128) -> (batch, time // 4, (input_dim // 4) * 128)
    dynamic_reshape_layer = DynamicReshapeLayer()
    x = dynamic_reshape_layer(x)

    # Bidirectional LSTM layers (full sequences kept for per-step CTC outputs)
    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)
    x = layers.Dropout(0.25)(x)

    x = layers.Bidirectional(layers.LSTM(128, return_sequences=True))(x)
    x = layers.Dropout(0.25)(x)

    # Output layer: one unit per vocabulary entry, plus the CTC blank as the last class.
    x = layers.Dense(output_dim + 1, activation="softmax")(x)

    # Define the loss function: per-sample CTC loss, shape (batch, 1).
    def ctc_loss_function(args):
        y_pred, labels, input_length, label_length = args
        return tf.keras.backend.ctc_batch_cost(labels, y_pred, input_length, label_length)

    # Add the CTC loss
    ctc_output = layers.Lambda(ctc_loss_function, output_shape=(1,), name="ctc")(
        [x, label, input_length, label_length]
    )
    # Define the training model
    training_model = tf.keras.Model(
        inputs=[input_spectrogram, input_length, label, label_length],
        outputs=ctc_output
    )

    # Compile the model. The model output already is the loss, so the loss
    # function passes y_pred through and ignores the dummy y_true.
    training_model.compile(optimizer=tf.keras.optimizers.Adam(), loss=lambda y_true, y_pred: y_pred)

    # Define the prediction model (shares all layers with training_model)
    prediction_model = tf.keras.Model(inputs=input_spectrogram, outputs=x)

    return training_model, prediction_model


In [None]:
# Instantiate the models. output_dim counts every vocabulary entry (including
# the "" OOV token); build_model adds one more output unit for the CTC blank.
input_dim = 80  # mel bins per frame; must match compute_spectrogram's output
output_dim = len(char_to_num.get_vocabulary())

training_model, prediction_model = build_model(input_dim=input_dim, output_dim=output_dim)

# Layer-by-layer overview of the training graph.
training_model.summary()


## 6. Model Training

Now, let's train the model.


In [None]:
# Define callbacks
checkpoint_path = "models/cnn_bilstm_speech_recognition.weights.h5"

# Create the checkpoint directory up front. Otherwise, if the Keras version in
# use does not create it, the first save at the end of an epoch fails.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Keep only the weights of the epoch with the lowest validation loss.
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
    checkpoint_path,
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
    verbose=1
)

# Stop after 10 epochs without val_loss improvement and roll back to the best weights.
early_stopping_callback = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=10,
    restore_best_weights=True,
    verbose=1
)

# Cut the learning rate by 5x after 5 stagnant epochs, down to at least 1e-6.
reduce_lr_callback = tf.keras.callbacks.ReduceLROnPlateau(
    monitor="val_loss",
    factor=0.2,
    patience=5,
    min_lr=1e-6,
    verbose=1
)


In [None]:
# Train the model.
# A single epoch serves as a quick smoke run; raise `epochs` (e.g. to 50) for
# real training. EarlyStopping ends the run sooner if val_loss stops improving.
epochs = 1

# Set to True to continue from the checkpoint of a previous run.
resume_from_checkpoint = False
if resume_from_checkpoint:
    try:
        training_model.load_weights(checkpoint_path)
    except FileNotFoundError:
        pass

callbacks = [checkpoint_callback, early_stopping_callback, reduce_lr_callback]
history = training_model.fit(
    train_generator,
    epochs=epochs,
    validation_data=val_generator,
    callbacks=callbacks,
)


In [None]:
# Plot the training history.
# ReduceLROnPlateau logs the learning rate as "learning_rate" in Keras 3 and as
# "lr" in older tf.keras releases. Use whichever key is present instead of
# hard-coding one, which raises KeyError on the other version.
lr_key = next((key for key in ("learning_rate", "lr") if key in history.history), None)

plt.figure(figsize=(12, 4))
plt.subplot(1, 2, 1)
plt.plot(history.history["loss"])
plt.plot(history.history["val_loss"])
plt.title("Model Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend(["Train", "Validation"])

plt.subplot(1, 2, 2)
if lr_key is not None:
    plt.plot(history.history[lr_key])
else:
    plt.text(0.5, 0.5, "Learning rate not recorded", ha="center", va="center")
plt.title("Learning Rate")
plt.xlabel("Epoch")
plt.ylabel("Learning Rate")
plt.tight_layout()
plt.show()


## 7. Model Evaluation and Inference

Now, let's evaluate the model and use it for inference.


In [None]:
# Load the best model weights.
# The checkpoint was written while fitting training_model, so restore it into
# that same model. prediction_model is built from the same layer objects
# (see build_model), so it receives the restored weights as well, and loading
# never depends on the two models' differing layer lists.
training_model.load_weights(checkpoint_path)


In [None]:
# Function to decode the predictions
def decode_predictions(pred):
    """Greedy (best-path) CTC decoding of softmax outputs.

    At each time step the most likely class is taken. Consecutive duplicates
    are then merged and the CTC blank is removed. The blank is the last class,
    as used by ``tf.keras.backend.ctc_batch_cost``. Skipping the merge would
    turn a best path such as ``h h <blank> i`` into "hhi" instead of "hi".

    Args:
        pred: ``(batch, time, num_classes)`` probabilities (e.g. from
            ``prediction_model.predict``). A single ``(time, num_classes)``
            sequence is also accepted.

    Returns:
        A NumPy object array of UTF-8 ``bytes``, one decoded string per
        sample, or a single ``bytes`` value for a 2-D input.
    """
    probs = np.asarray(pred)
    single = probs.ndim == 2
    if single:
        probs = probs[np.newaxis]

    blank_id = probs.shape[-1] - 1
    best_paths = np.argmax(probs, axis=-1)

    decoded = []
    for path in best_paths:
        # Standard CTC collapse: keep a symbol only where it differs from the
        # previous step, then drop blanks (so "a <blank> a" still gives "aa").
        ids = [
            int(label)
            for step, label in enumerate(path)
            if label != blank_id and (step == 0 or label != path[step - 1])
        ]
        chars = num_to_char(tf.constant(ids, dtype=tf.int32))
        decoded.append(tf.strings.reduce_join(chars).numpy())

    result = np.array(decoded, dtype=object)
    return result[0] if single else result


In [None]:
# Function to perform CTC beam search decoding
def decode_batch_predictions(pred, beam_width=10, greedy=False):
    """Decode a batch of softmax outputs with Keras' CTC decoder.

    Args:
        pred: ``(batch, time, num_classes)`` probabilities from ``prediction_model``.
        beam_width: Beam size for beam search (ignored when ``greedy`` is True).
        greedy: Use best-path decoding instead of beam search.

    Returns:
        A list of decoded strings, one per batch element.
    """
    # Every sample is decoded over the full time axis; lengths must be integers.
    input_len = np.full(pred.shape[0], pred.shape[1], dtype=np.int32)
    results = tf.keras.backend.ctc_decode(
        pred, input_length=input_len, greedy=greedy, beam_width=beam_width
    )[0][0]

    output_text = []
    for result in results:
        # The dense decoder output is padded with -1. Those ids fall outside
        # the vocabulary, so num_to_char should map them to the "" OOV token.
        text = tf.strings.reduce_join(num_to_char(result)).numpy().decode("utf-8")
        output_text.append(text)

    return output_text


In [None]:
# Test the model on a few samples
# Import `display` explicitly: it is only implicitly available inside IPython.
from IPython.display import display

test_samples = val_df.sample(5)

for _, row in test_samples.iterrows():
    # Load and preprocess audio
    audio_data, frame_rate = read_wav_file(row["wav_file"])
    spectrogram = compute_spectrogram(audio_data, frame_rate)

    # Add a batch axis: (1, frames, mel_bins)
    spectrogram = tf.expand_dims(spectrogram, axis=0)

    # Make prediction (verbose=0 suppresses the per-call progress bar)
    prediction = prediction_model.predict(spectrogram, verbose=0)

    # Decode prediction
    decoded_prediction = decode_batch_predictions(prediction)[0]

    # Print results
    print(f"Original text: {row['transcription']}")
    print(f"Normalized text: {row['normalized_transcription']}")
    print(f"Predicted text: {decoded_prediction}")
    print("-" * 80)

    # Play the audio
    display(Audio(audio_data, rate=frame_rate))
