<a href="https://colab.research.google.com/github/HernanDL/Noise-Cancellation-Using-GenAI/blob/main/Simple_CNN_(Test).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Step 1: Load Required Libraries
#!pip install torch torchaudio librosa matplotlib numpy

import io

import numpy as np
import matplotlib.pyplot as plt
import librosa
import librosa.display
import tensorflow as tf
from scipy.signal import filtfilt
from google.colab import files
from tensorflow.keras.callbacks import EarlyStopping
from keras import backend as K
import IPython.display as ipd
import torch

# Step 2: Load Input Noise Waveform
# files.upload() returns {original_filename: file_bytes}. If a file with that
# name already exists, Colab writes the new upload under a de-duplicated name
# (e.g. "name (1).wav") but the dict key is still the ORIGINAL name, so loading
# the key from disk would silently read the older file. Decoding the uploaded
# bytes directly always uses the file that was just uploaded.
uploaded = files.upload()
if not uploaded:
    raise RuntimeError("No file was uploaded; please select an audio file to continue.")

input_file = next(iter(uploaded))
# sr=None keeps the file's native sampling rate (no resampling).
input_waveform, sr = librosa.load(io.BytesIO(uploaded[input_file]), sr=None)

print(f"input_waveform len: {len(input_waveform)}")
sampling_rate_khz = sr / 1000  # Convert to kHz
print(f"Sampling Rate: {sampling_rate_khz:.2f} kHz")

duration = librosa.get_duration(y=input_waveform, sr=sr)
print(f"Duration: {duration:.2f} seconds")

Saving Bn-ord-B3-mf-N-N.wav to Bn-ord-B3-mf-N-N (1).wav
input_waveform len: 201450
Sampling Rate: 44.10 kHz
Duration: 4.57 seconds


In [None]:
# Start from a clean Keras state so re-running this cell does not keep piling
# up models/layers from previous runs. Use the tf.keras backend, the same one
# every model in this notebook is built with.
tf.keras.backend.clear_session()

# Seed Python, NumPy and TensorFlow RNGs so weight initialization, and hence
# the trained result, is reproducible across Restart & Run All.
SEED = 42
tf.keras.utils.set_random_seed(SEED)
# NOTE(review): torch is never used to allocate anything in this notebook, so
# the former torch.cuda.empty_cache() call had nothing to free and was dropped.

# Step 3: Preprocess the Data
def preprocess_waveform(waveform):
    """Peak-normalize a waveform so its largest absolute sample becomes 1.0.

    Args:
        waveform: array-like of audio samples.

    Returns:
        A new NumPy array equal to ``waveform / max(|waveform|)``. A silent
        (all-zero) waveform is returned unscaled, instead of turning into
        NaNs through a 0/0 division.

    Raises:
        ValueError: if ``waveform`` is empty (there is no peak to normalize by).
    """
    waveform = np.asarray(waveform)
    if waveform.size == 0:
        raise ValueError("preprocess_waveform() received an empty waveform")
    peak = np.max(np.abs(waveform))
    if peak == 0:
        # Silent clip: nothing to scale, and dividing by 0 would yield NaNs.
        peak = 1
    return waveform / peak

# Peak-normalize the loaded clip before handing it to the network.
input_waveform = preprocess_waveform(input_waveform)

# Conv1D expects (batch, timesteps, channels): one clip, all of its samples,
# one channel.
X_cnn = np.reshape(input_waveform, (1, input_waveform.size, 1))

# Step 4: Define the CNN Model with Input Layer
# Fully-convolutional design: every Conv1D uses padding='same' and there is no
# pooling, so the model emits exactly one output sample per input sample.
# The output shape is (batch, n_samples, 1), which is the same shape as
# X_cnn and as the -X_cnn training target.
# The previous Flatten -> Dense(128) -> Dense(len(input_waveform)) head grew
# with clip length. For a 201,450-sample clip that is ~3.2M flattened features
# feeding Dense(128), i.e. roughly 440M trainable weights. It also tied the
# model to one exact clip length.
# Input(shape=(None, channels)) accepts clips of any length.
cnn_model = tf.keras.Sequential([
    tf.keras.layers.Input(shape=(None, X_cnn.shape[2])),
    tf.keras.layers.Conv1D(32, kernel_size=3, padding='same', activation='relu'),
    tf.keras.layers.Conv1D(64, kernel_size=3, padding='same', activation='relu'),
    # tanh bounds the predicted anti-noise to [-1, 1], the same range as the
    # peak-normalized input.
    tf.keras.layers.Conv1D(1, kernel_size=3, padding='same', activation='tanh'),
])

cnn_model.compile(optimizer='adam', loss='mean_squared_error')
cnn_model.summary()

# Step 5: Train the CNN Model
# The target is the sign-inverted input (anti-noise): input + target == 0.
target_waveform = np.negative(X_cnn)

# Halt when the training loss has not improved for 3 epochs in a row.
early_stopping = EarlyStopping(
    monitor='loss',
    patience=3,
    verbose=1,
)

# Deliberately short run: this is a demonstration, not full training.
history = cnn_model.fit(
    X_cnn,
    target_waveform,
    batch_size=1,
    epochs=10,
    verbose=1,
    callbacks=[early_stopping],
)

# Visualize Training Loss
train_loss = history.history['loss']
plt.plot(train_loss, label='Train Loss')
plt.title('Training Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

# Run the trained CNN on the clip, then collapse the batch/channel axes into a
# single 1-D array of samples.
predicted_waveform_cnn = np.ravel(cnn_model.predict(X_cnn))

# apply filter to remove residual noise
def gaussian_filter_linear_phase(signal, window_size, sigma, half_width=2.5):
    """Smooth ``signal`` with a unit-sum Gaussian FIR kernel, with zero phase shift.

    The kernel has ``window_size`` taps placed evenly over
    ``[-half_width, half_width]``. ``sigma`` is measured in those same kernel
    units, NOT in samples. Because the kernel sums to 1, a constant signal
    passes through unchanged.

    ``scipy.signal.filtfilt`` runs the kernel forward and then backward, so the
    result is zero-phase (no time shift). That is stronger than the "linear
    phase" in this function's name, which is kept for compatibility. The
    effective smoothing kernel is the Gaussian convolved with itself.

    Args:
        signal: 1-D array of samples to smooth.
        window_size: number of kernel taps (int >= 1).
        sigma: Gaussian width in kernel units (must be > 0).
        half_width: kernel extent on each side of its center, in the same units
            as ``sigma``. Defaults to 2.5, the extent previously hard-coded.

    Returns:
        The smoothed signal, as a NumPy array of the same length.

    Raises:
        ValueError: if ``window_size`` < 1, ``sigma`` <= 0, or ``signal`` is
            empty.
    """
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    if sigma <= 0:
        raise ValueError(f"sigma must be > 0, got {sigma}")
    signal = np.asarray(signal)
    if signal.ndim == 0 or signal.shape[-1] == 0:
        raise ValueError("gaussian_filter_linear_phase() received an empty signal")

    positions = np.linspace(-half_width, half_width, window_size)
    gauss_kernel = np.exp(-positions**2 / (2 * sigma**2))
    gauss_kernel /= np.sum(gauss_kernel)

    # filtfilt pads 3 * len(kernel) samples at each edge by default, and it
    # raises if the signal is not longer than that. Clamp padlen so short clips
    # still work; long signals use exactly the default padding.
    padlen = min(3 * window_size, signal.shape[-1] - 1)
    return filtfilt(gauss_kernel, 1, signal, padlen=padlen)

# Low-pass smooth the CNN's predicted anti-noise before it is added back onto
# the input.
# Kernel taps and Gaussian width. sigma is in the kernel's own coordinate
# units, not samples (see gaussian_filter_linear_phase).
SMOOTHING_WINDOW_SIZE = 15
SMOOTHING_SIGMA = 1.5
predicted_waveform_cnn = gaussian_filter_linear_phase(
    predicted_waveform_cnn,
    window_size=SMOOTHING_WINDOW_SIZE,
    sigma=SMOOTHING_SIGMA,
)


# Step 6: Visualize Results for CNN
# Residual left after adding the predicted anti-noise back onto the input.
# The training target was -input, so a perfect prediction would make this zero.
combined_waveform_cnn = input_waveform + predicted_waveform_cnn

fig, (ax_input, ax_pred, ax_combined) = plt.subplots(3, 1, figsize=(15, 8))

librosa.display.waveshow(input_waveform, sr=sr, ax=ax_input)
ax_input.set(title='Input Waveform', xlabel='Time (s)', ylabel='Amplitude')

librosa.display.waveshow(predicted_waveform_cnn, sr=sr, ax=ax_pred)
ax_pred.set(title='Predicted Waveform (CNN)', xlabel='Time (s)', ylabel='Amplitude')

librosa.display.waveshow(input_waveform, sr=sr, ax=ax_combined, label='Input Waveform', color='b')
librosa.display.waveshow(combined_waveform_cnn, sr=sr, ax=ax_combined, label='Combined Waveform', color='r')
ax_combined.set(title='Input and Combined Signals (CNN)', xlabel='Time (s)', ylabel='Amplitude')
# Both overlapping traces are labelled. Draw the legend so the reader can
# tell which colour is the input and which is the residual.
ax_combined.legend()

fig.tight_layout()
plt.show()

# Audio Playback for CNN Results
# By default ipd.Audio rescales each clip to full scale (normalize=True). That
# would make the residual "Combined" clip sound as loud as the input and hide
# how much noise was cancelled. Play every clip at its true relative level
# instead. normalize=False requires samples in [-1, 1], so clip to that range
# (the input + prediction sum can exceed it).
playback_clips = [
    ("Playing Input Waveform:", input_waveform),
    ("Playing Predicted Waveform (CNN):", predicted_waveform_cnn),
    ("Playing Combined Signal (CNN):", combined_waveform_cnn),
]
for caption, clip in playback_clips:
    print(caption)
    ipd.display(ipd.Audio(np.clip(clip, -1.0, 1.0), rate=sr, normalize=False))


Epoch 1/10
