# Denoise Audio

In [None]:
from audioautoencoder.denoising import *

In [None]:
# --------------- In Loop Parameters --------------
# Which checkpoint to evaluate: architecture tag, the SNR range (dB) used in the
# checkpoint folder name, and the checkpoint file inside that folder.
model_name = 'UNetConv10_mask'
SNRdB_load = [-10, 10]
load_file = 'Autoencodermodel_earlystopping.pth'
load_path = (
    '/content/drive/MyDrive/Projects/ML_Projects/De-noising-autoencoder/Models_Denoising/'
    f'Checkpoints_{model_name}_{SNRdB_load[0]}-{SNRdB_load[1]}/'
    f'{load_file}'
)

# Build the network on the GPU when one is available and hand it to
# DenoisingLoader together with the checkpoint path (presumably it loads the
# weights — confirm); keep the model object it exposes afterwards.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model_one = UNetConv10().to(device)
denoiser = DenoisingLoader(model_one, load_path)
model_one = denoiser.model
print('Loaded Model')

In [None]:
import os

whole_files = '/content/drive/MyDrive/Datasets/Music/MUSDB18/test/'

# Collect every '.wav' file whose name contains 'mixture', anywhere below
# whole_files. os.walk yields entries in arbitrary, filesystem-dependent order,
# so the list is sorted: later cells pick a song by index
# (song_files[file_number]) and that index must mean the same track every run.
song_files = sorted(
    os.path.join(root, f)
    for root, _dirs, files in os.walk(whole_files)
    for f in files
    if f.endswith('.wav') and 'mixture' in f
)

print(f"\nTotal matching files: {len(song_files)}")

In [None]:
# Noise recording that generate_audio_with_noise mixes into the clean song.
# Alternative candidates are kept commented out for quick switching; only the
# final, uncommented assignment is active.
#noise_file = '/content/drive/MyDrive/Datasets/Noise/All_Noise/splits_v2/val/crowd noise (4)_xDoJJ9.wav'
#noise_file = '/content/drive/MyDrive/Datasets/Noise/All_Noise/splits_v2/val/plane-noise-passengers-sound_s8OrJQ.mp3'
#noise_file = '/content/drive/MyDrive/Datasets/Noise/All_Noise/splits_v2/val/Crowd - Mall ambience_UdLE4r.wav'
#noise_file = '/content/drive/MyDrive/Datasets/Noise/All_Noise/splits_v2/test/Crowd - Cheering - Strong cheering and soft rhythmic cheering_Pxj5eZ.wav'
#noise_file = '/content/drive/MyDrive/Datasets/Noise/All_Noise/splits_v2/test/Crowd - Street parade with music_FgF6cW.wav'
noise_file = '/content/drive/MyDrive/Datasets/Noise/All_Noise/splits_v2/test/Robocup 2019 4.1_vTYYoj.mp3'

In [None]:
def generate_audio_with_noise(audio_file, noise_file, start_time=10, duration=10,
                             signal_level=1, noise_level=0.1, sr=44100, plot=False):
    """
    Load a clean audio file and a noise recording, cut the same time window
    from both, peak-normalise and scale each, and mix them.

    Parameters:
        audio_file (str): Path to the clean (main) audio file.
        noise_file (str): Path to the noise recording to mix in.
        start_time (int): Start of the window, in seconds, in both files.
        duration (int): Length of the window, in seconds.
        signal_level (float): Peak level of the clean audio after normalisation.
        noise_level (float): Peak level of the noise after normalisation.
        sr (int): Expected sample rate of ``audio_file`` (default: 44100 Hz).
            The noise is resampled to this rate if its own rate differs.
        plot (bool): If True, plot the noise, clean and noisy waveforms.

    Returns:
        noisy_audio (np.ndarray): 1-D mixture, clipped to [-1, 1].
        sr (int): Sample rate of ``noisy_audio`` (the ``sr`` argument).

    Raises:
        AssertionError: If ``audio_file`` is not sampled at ``sr``.
        ValueError: If the requested window is empty or silent in either file.
    """
    # Load audio and noise
    audio, audio_sr = load_audio_file(audio_file)
    noise_waveform, noise_sr = load_audio_file(noise_file)

    # Work in numpy from here on (load_audio_file may return torch tensors).
    audio = audio.cpu().numpy() if isinstance(audio, torch.Tensor) else np.asarray(audio)
    noise_waveform = (noise_waveform.cpu().numpy() if isinstance(noise_waveform, torch.Tensor)
                      else np.asarray(noise_waveform))

    # Keep only the first channel of multi-channel input. Like the previous
    # `x[0]`, this assumes a channels-first (channels, samples) layout — TODO
    # confirm against load_audio_file. Testing ndim instead of `len(x) == 2`
    # also covers mono (1, samples) arrays, which otherwise stayed 2-D and were
    # sliced to nothing below.
    if audio.ndim > 1:
        audio = audio[0]
    if noise_waveform.ndim > 1:
        noise_waveform = noise_waveform[0]

    print('Noise Sample Rate:', noise_sr)

    assert audio_sr == sr, f"Expected sample rate {sr}, but got {audio_sr}"

    # Trim audio and noise to the specified start time and duration (each at its own rate)
    audio = audio[int(start_time * sr):int((start_time + duration) * sr)]
    noise_waveform = noise_waveform[int(start_time * noise_sr):int((start_time + duration) * noise_sr)]
    if audio.size == 0:
        raise ValueError(f"No audio samples in window {start_time}-{start_time + duration} s of {audio_file}")
    if noise_waveform.size == 0:
        raise ValueError(f"No noise samples in window {start_time}-{start_time + duration} s of {noise_file}")

    # Put the noise on the audio's sample grid; without this the mix below fails
    # with a shape mismatch whenever the noise file uses a different rate.
    if noise_sr != sr:
        import librosa
        noise_waveform = librosa.resample(np.asarray(noise_waveform, dtype=np.float64),
                                          orig_sr=noise_sr, target_sr=sr)

    # Match the audio length: repeat noise that is too short, truncate noise that is too long.
    noise_waveform = np.resize(noise_waveform, audio.shape)

    audio_peak = np.max(np.abs(audio))
    noise_peak = np.max(np.abs(noise_waveform))
    if audio_peak == 0 or noise_peak == 0:
        raise ValueError("Audio or noise window is silent; cannot normalise it or compute an SNR")

    # Peak-normalise, scale each to its requested level, and keep within [-1, 1]
    audio = np.clip((audio / audio_peak) * signal_level, -1, 1)
    noise_waveform = np.clip((noise_waveform / noise_peak) * noise_level, -1, 1)

    # Add noise to the signal
    noisy_audio = np.clip(audio + noise_waveform, -1, 1)

    # SNR of the (pre-clipping) mixture components, in dB
    signal_power = np.mean(audio**2)
    noise_power = np.mean(noise_waveform**2)
    snr = 10 * np.log10(signal_power / noise_power)

    print(f"SNR: {snr:.2f} dB")

    # Plot results
    if plot:
        for waveform, label in ((noise_waveform, "Noise"),
                                (audio, "Clean Audio"),
                                (noisy_audio, "Noisy Audio")):
            plt.figure(figsize=(10, 4))
            plt.plot(waveform, label=label)
            plt.legend()
            plt.show()

    return noisy_audio, sr

In [None]:
import os
import torch
import numpy as np
import random
import matplotlib.pyplot as plt
import soundfile as sf

def denoise_audio_chunk(chunk, sr, model, scalers, chunk_samples=2*44100, device=None):
    """
    Denoise one fixed-length chunk of audio with ``model``.

    Features are extracted from ``chunk`` and scaled/stacked by
    transform_features, then passed through ``model``. The model's spectrogram
    output is thresholded, un-scaled and converted from dB to amplitude. It is
    then re-synthesised using the *input* chunk's phase.

    Parameters:
        chunk (np.ndarray): Audio samples; presumably 1-D and ``chunk_samples``
            long — TODO confirm against extract_features.
        sr (int): Sample rate of ``chunk``.
        model (torch.nn.Module): Denoising network, expected to already be on
            ``device`` and in eval mode (AudioDenoiser does both).
        scalers (dict): Fitted feature scalers (see transform_features / inverse_scale).
        chunk_samples (int): Length used for feature extraction and for the
            returned waveform.
        device (str | torch.device, optional): Inference device; defaults to
            CUDA when available, else CPU.

    Returns:
        np.ndarray: Denoised waveform of ``chunk_samples`` samples.
    """
    device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    features = extract_features(chunk, sr, audio_length=chunk_samples)
    transformed_features, metadata = transform_features(features, scalers)

    # Add a batch dimension; as_tensor accepts both torch tensors and numpy arrays.
    input_tensor = torch.as_tensor(transformed_features, dtype=torch.float32).unsqueeze(0).to(device)

    # Inference only: no_grad avoids building an autograd graph for every chunk.
    with torch.no_grad():
        denoised = model(input_tensor)
    denoised_array = denoised.cpu().numpy()[0]

    denoised_spectrogram = reconstruct_spectrogram(denoised_array, metadata)

    # Zero out bins below half of the spectrogram's mean value (percentage=0.5).
    denoised_spectrogram = threshold_spectrogram(denoised_spectrogram, np.mean(denoised_spectrogram), percentage=0.5)

    # De-normalize, then dB -> linear amplitude
    denoised_spectrogram = inverse_scale(denoised_spectrogram, scalers)
    denoised_spectrogram = librosa.db_to_amplitude(denoised_spectrogram)

    # Resynthesise with the noisy input's phase
    return magphase_to_waveform(denoised_spectrogram, features['phase'], chunk_samples)

class AudioDenoiser:
    """
    Chunked denoising pipeline.

    A waveform is split into overlapping fixed-length chunks, each chunk is
    denoised with ``denoise_audio_chunk``, and the results are stitched back
    together with Hann-weighted overlap-add.
    """

    def __init__(self, model_one, scalers, output_path, sample_rate=44100, chunk_duration=2, step_size=0.5, device=None):
        """
        Parameters:
            model_one (torch.nn.Module): AI model for denoising.
            scalers (dict): Fitted feature scalers passed to denoise_audio_chunk.
            output_path (str): Directory to save output files.
            sample_rate (int): Sample rate (default 44100 Hz).
            chunk_duration (float): Duration of each chunk in seconds.
            step_size (float): Hop between chunk starts as a fraction of the
                chunk length (0.5 -> 50 % overlap).
            device (str, optional): Device for PyTorch computation ("cuda" or "cpu").
        """
        self.model_one = model_one
        self.output_path = output_path
        self.sample_rate = sample_rate
        # int() so fractional chunk durations still give a valid slice length
        self.chunk_samples = int(sample_rate * chunk_duration)
        self.scalers = scalers
        # At least one sample; a zero hop would make range() below raise
        self.step_samples = max(1, int(self.chunk_samples * step_size))
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.model_one.to(self.device)
        self.model_one.eval()

    def process_audio(self, waveform, sr):
        """
        Denoise ``waveform`` chunk by chunk, save the input and output as WAV
        files, and plot their spectrograms.

        Parameters:
            waveform (np.ndarray): 1-D (noisy) audio to denoise.
            sr (int): Sample rate of ``waveform``; must equal ``self.sample_rate``.

        Returns:
            Tuple of (reconstructed_audio, reconstructed_audio_input): the
            denoised waveform (same length as ``waveform``) and the unmodified input.
        """
        self.waveform = waveform
        assert sr == self.sample_rate, f"Sample rate mismatch: expected {self.sample_rate}, got {sr}"

        n_samples = len(waveform)
        # Zero-pad input shorter than one chunk so at least one chunk is denoised.
        padded = np.pad(np.asarray(waveform), (0, max(0, self.chunk_samples - n_samples)))
        starts = self._chunk_starts(len(padded))

        processed_audio = [
            denoise_audio_chunk(padded[start:start + self.chunk_samples], sr, self.model_one,
                                self.scalers, self.chunk_samples, self.device)
            for start in starts
        ]

        # Reconstruct waveform with overlap-add; the input is kept as-is for comparison
        reconstructed_audio = self._overlap_add(processed_audio, starts, n_samples)
        reconstructed_audio_input = waveform

        # Save output
        self._save_audio(reconstructed_audio, "output_audio_song.wav")
        self._save_audio(reconstructed_audio_input, "input_audio_song.wav")

        # Plot spectrograms
        self._plot_spectrograms(reconstructed_audio, reconstructed_audio_input)

        return reconstructed_audio, reconstructed_audio_input

    def _chunk_starts(self, n_samples):
        """
        Return the start offsets of the chunks that tile ``n_samples`` samples.

        Chunks start every ``step_samples``. If that leaves a tail uncovered,
        one extra chunk aligned to the end is appended so every sample is
        denoised. Requires ``n_samples >= self.chunk_samples``.
        """
        last_start = n_samples - self.chunk_samples
        starts = list(range(0, last_start + 1, self.step_samples))
        if starts[-1] != last_start:
            starts.append(last_start)
        return starts

    def _overlap_add(self, chunks, starts=None, length=None):
        """
        Reconstruct a waveform from overlapping chunks (weighted overlap-add).

        Each chunk is multiplied by a Hann window before summing, and the sum is
        divided by the summed windows. Overlapping regions are therefore a
        weighted average of the chunks, not their sum. Half-second Hann fades
        are then applied at both ends.

        Parameters:
            chunks (list[np.ndarray]): Chunks of ``self.chunk_samples`` samples.
            starts (list[int], optional): Start offset of each chunk; defaults
                to ``self._chunk_starts`` over the current waveform.
            length (int, optional): Output length; defaults to ``len(self.waveform)``.

        Returns:
            np.ndarray: Reconstructed waveform of ``length`` samples in [-1, 1].
        """
        if length is None:
            length = len(self.waveform)
        # Chunks may extend past `length` when the input was padded to one full chunk.
        total = max(length, self.chunk_samples)
        if starts is None:
            starts = self._chunk_starts(total)

        window = np.hanning(self.chunk_samples)
        reconstructed = np.zeros(total)
        weight = np.zeros(total)
        for start, chunk in zip(starts, chunks):
            reconstructed[start:start + self.chunk_samples] += chunk * window
            weight[start:start + self.chunk_samples] += window

        # Samples whose summed window is ~0 (the very first/last sample) end up ~0.
        reconstructed = reconstructed[:length] / np.maximum(weight[:length], 1e-6)
        reconstructed = np.clip(reconstructed, -1, 1)

        # Fade in/out over half a second (or half the signal, if shorter).
        fade = min(self.sample_rate // 2, length // 2)
        if fade > 0:
            ramp = np.hanning(2 * fade)
            reconstructed[:fade] *= ramp[:fade]
            reconstructed[-fade:] *= ramp[fade:]

        return reconstructed

    def _save_audio(self, audio, filename):
        """Peak-normalise ``audio`` and write it into ``output_path`` under a timestamped name."""
        output_filename = os.path.join(self.output_path, add_datetime_to_filename(filename))
        # Normalise by the absolute peak: np.max(audio) ignores negative peaks,
        # and divides by zero or flips the sign for non-positive signals.
        peak = np.max(np.abs(audio)) if len(audio) else 0
        sf.write(output_filename, audio / peak if peak > 0 else audio, self.sample_rate)
        print(f"Saved: {output_filename}")

    def _plot_spectrograms(self, reconstructed_audio, reconstructed_audio_input):
        """Plot mel spectrograms of the processed and input audio side by side on a shared dB colour scale."""

        import librosa
        import librosa.display
        import matplotlib.pyplot as plt

        hop_length = 1024

        # melspectrogram returns a power (|X|**2) spectrogram by default, so it
        # is converted with power_to_db; amplitude_to_db would double every dB value.
        Sxx1_db = librosa.power_to_db(
            librosa.feature.melspectrogram(y=reconstructed_audio, sr=self.sample_rate, n_fft=2048, hop_length=hop_length),
            ref=np.max)
        Sxx2_db = librosa.power_to_db(
            librosa.feature.melspectrogram(y=reconstructed_audio_input, sr=self.sample_rate, n_fft=2048, hop_length=hop_length),
            ref=np.max)

        # Compute shared color limits
        vmin, vmax = min(Sxx1_db.min(), Sxx2_db.min()), max(Sxx1_db.max(), Sxx2_db.max())

        fig, axes = plt.subplots(1, 2, figsize=(14, 6), sharey=True)

        # x_axis/y_axis give real time and mel-frequency ticks to match the labels below.
        img1 = librosa.display.specshow(Sxx1_db, sr=self.sample_rate, hop_length=hop_length, x_axis="time", y_axis="mel",
                                        cmap="viridis", ax=axes[0], vmin=vmin, vmax=vmax)
        axes[0].set_title("Spectrogram of Processed Audio")
        axes[0].set_xlabel("Time (s)")
        axes[0].set_ylabel("Frequency (Hz)")

        librosa.display.specshow(Sxx2_db, sr=self.sample_rate, hop_length=hop_length, x_axis="time", y_axis="mel",
                                 cmap="viridis", ax=axes[1], vmin=vmin, vmax=vmax)
        axes[1].set_title("Spectrogram of Input Audio")
        axes[1].set_xlabel("Time (s)")

        # Add shared colorbar
        fig.colorbar(img1, ax=axes, orientation="vertical", fraction=0.02, pad=0.02, label="Amplitude (dB)")

        plt.show()


def resample_feature(feature, target_shape):
    """
    Resize a 2-D feature map to ``target_shape`` with bilinear interpolation.

    Only the first two entries of ``target_shape`` (rows, columns) are used.
    The array is treated as a single-channel image for
    ``torch.nn.functional.interpolate`` (align_corners=False) and returned as
    a 2-D float32 numpy array.
    """
    new_rows, new_cols = target_shape[0], target_shape[1]
    image = torch.tensor(feature, dtype=torch.float32)[None, None]  # (1, 1, H, W)
    resized = F.interpolate(image, size=(new_rows, new_cols), mode="bilinear", align_corners=False)
    return resized[0, 0].numpy()

def transform_features(features, scalers, sampling_rate=44100, n_fft=2048):
    """
    Scale extracted features and stack them into the model's input tensor.

    The spectrogram and the edge map are each contributed four times: the full
    band, plus the 0-4000 Hz, 0-1000 Hz and 0-200 Hz sub-bands stretched back
    to the full spectrogram shape, giving low frequencies more resolution. The
    cepstrum is resized to the same shape.

    Parameters:
        features (dict): Must contain 2-D arrays 'spectrogram', 'edges' and
            'cepstrum' (presumably (freq_bins, frames) from extract_features —
            TODO confirm).
        scalers (dict): Fitted scalers exposing ``.transform``, keyed by
            'input_features_spectrogram', 'input_features_edges' and
            'input_features_cepstrum'.
        sampling_rate (int): Sample rate the spectrogram was computed at.
        n_fft (int): FFT size of the spectrogram; its rows are assumed to be
            the n_fft // 2 + 1 STFT bins.

    Returns:
        inputs (torch.Tensor): float32 tensor of shape (9, H, W), where (H, W)
            is the spectrogram shape, with values mapped by x / 3 + 0.5.
        metadata (dict): Sub-band row indices and shapes, used by
            reconstruct_spectrogram to undo the band resampling.
    """
    input_spectrogram = features['spectrogram']
    input_edges = features['edges']
    input_cepstrum = features['cepstrum']

    # Every channel is resized to the spectrogram's shape.
    target_shape = input_spectrogram.shape

    # Apply the fitted scalers; each expects one flattened sample (1, n_values).
    input_spectrogram = scalers["input_features_spectrogram"].transform(input_spectrogram.reshape(1, -1)).reshape(input_spectrogram.shape)
    input_edges = scalers["input_features_edges"].transform(input_edges.reshape(1, -1)).reshape(input_edges.shape)
    input_cepstrum = scalers["input_features_cepstrum"].transform(input_cepstrum.reshape(1, -1)).reshape(input_cepstrum.shape)

    # Centre frequency of each STFT row.
    freqs = np.linspace(0, sampling_rate / 2, n_fft // 2 + 1)

    # Row indices of the 0-4000 Hz (hf), 0-1000 Hz (mf) and 0-200 Hz (lf) sub-bands
    min_freq, hf, mf, lf = 0, 4000, 1000, 200
    freq_indices_hf = np.where((freqs >= min_freq) & (freqs <= hf))[0]
    freq_indices_mf = np.where((freqs >= min_freq) & (freqs <= mf))[0]
    freq_indices_lf = np.where((freqs >= min_freq) & (freqs <= lf))[0]

    # Stretch each sub-band of the spectrogram and edges to the full shape
    input_spectrogram_hf = resample_feature(input_spectrogram[freq_indices_hf, :], target_shape)
    input_spectrogram_mf = resample_feature(input_spectrogram[freq_indices_mf, :], target_shape)
    input_spectrogram_lf = resample_feature(input_spectrogram[freq_indices_lf, :], target_shape)
    input_edges_hf = resample_feature(input_edges[freq_indices_hf, :], target_shape)
    input_edges_mf = resample_feature(input_edges[freq_indices_mf, :], target_shape)
    input_edges_lf = resample_feature(input_edges[freq_indices_lf, :], target_shape)

    # Resample the cepstral features so they match the spectrogram shape
    input_cepstrum = resample_feature(input_cepstrum, target_shape)

    # 9 channels: spectrogram x4 (full, hf, mf, lf), edges x4, cepstrum. Phase is not an input.
    inputs = torch.tensor(np.stack([
        input_spectrogram, input_spectrogram_hf, input_spectrogram_mf, input_spectrogram_lf,
        input_edges, input_edges_hf, input_edges_mf, input_edges_lf,
        input_cepstrum
    ], axis=0), dtype=torch.float32)  # Shape: (9, H, W)

    # x -> x / 3 + 0.5 maps [-1.5, 1.5] onto [0, 1]; inverse_scale undoes it with the same factor 3.
    a = 3
    inputs = (inputs / a) + 0.5

    # Band shapes/indices needed to paste the bands back during reconstruction
    metadata = {
        "hf_shape": input_spectrogram[freq_indices_hf, :].shape,
        "mf_shape": input_spectrogram[freq_indices_mf, :].shape,
        "lf_shape": input_spectrogram[freq_indices_lf, :].shape,
        "freq_indices_hf": freq_indices_hf,
        "freq_indices_mf": freq_indices_mf,
        "freq_indices_lf": freq_indices_lf
    }

    return inputs, metadata

def reconstruct_spectrogram(outputs, metadata):
    """
    Merge the model's band channels back into a single spectrogram.

    Channel 0 is copied as the full-band base. Channels 1, 2 and 3 are resized
    back to the hf, mf and lf band shapes stored in ``metadata``. They are
    written over the matching rows in that order, so where band rows overlap
    the later band overwrites the earlier one. ``outputs`` itself is not modified.
    """
    bands = (
        (1, "freq_indices_hf", "hf_shape"),
        (2, "freq_indices_mf", "mf_shape"),
        (3, "freq_indices_lf", "lf_shape"),
    )
    merged = np.array(outputs[0])
    for channel, rows_key, shape_key in bands:
        merged[metadata[rows_key], :] = resample_feature(outputs[channel], metadata[shape_key])
    return merged

def inverse_scale(out_spectrogram, scalers):
    """
    Undo the input scaling for a spectrogram produced by the model.

    First reverses the x / 3 + 0.5 mapping (y -> (y - 0.5) * 3). Then applies
    ``scalers["input_features_spectrogram"].inverse_transform`` to the
    flattened (1, n_values) array and reshapes the result back.
    """
    centred = 3 * (out_spectrogram - 0.5)
    restored = scalers["input_features_spectrogram"].inverse_transform(np.asarray(centred).reshape(1, -1))
    return restored.reshape(centred.shape)

def threshold_spectrogram(spectrogram, threshold, percentage=0.8):
    """
    Zero every value below ``threshold * percentage``.

    Args:
        spectrogram (np.ndarray): Input array.
        threshold (float): Reference threshold value.
        percentage (float): Fraction of ``threshold`` used as the actual cutoff.

    Returns:
        np.ndarray: Copy of ``spectrogram`` with values under the cutoff set to zero.
    """
    cutoff = threshold * percentage
    keep = spectrogram >= cutoff
    return np.where(keep, spectrogram, 0)

def magphase_to_waveform(magnitude, phase, audio_length=44100):
    """
    Resynthesise a waveform from an STFT magnitude and phase.

    Parameters:
        magnitude (np.ndarray): Linear STFT magnitude.
        phase (np.ndarray): STFT phase in radians, same shape as ``magnitude``.
        audio_length (int): Number of output samples passed to ``librosa.istft``.

    Returns:
        np.array: Time-domain waveform of ``audio_length`` samples.
    """
    complex_spectrum = magnitude * np.exp(1j * phase)
    waveform = librosa.istft(complex_spectrum, length=audio_length)
    return waveform

In [None]:
# Pick one test song, mix in the selected noise, and denoise the mixture.
# NOTE(review): `scalers` and `output_path` are not defined in the visible
# cells — presumably they come from the star import; confirm.
file_number = 25
noisy_audio, sr = generate_audio_with_noise(
    song_files[file_number],
    noise_file,
    start_time=20,
    duration=10,
    noise_level=0.7,
)
denoiser = AudioDenoiser(
    model_one,
    scalers,
    output_path=output_path,
    chunk_duration=2,
    step_size=0.5,
)
reconstructed_audio, reconstructed_input = denoiser.process_audio(noisy_audio, sr)

In [None]:
import soundfile as sf
from google.colab import files

def save_wav_sf(file_path, audio_array, sample_rate):
    """
    Write ``audio_array`` to ``file_path`` as a 16-bit PCM WAV, then ask Colab
    to download the file to the browser.
    """
    # Other soundfile subtypes that work here: "PCM_24", "PCM_32", "FLOAT".
    sf.write(file_path, audio_array, sample_rate, subtype="PCM_16")
    files.download(file_path)

# Download the denoised output first, then the noisy input, for the chosen song.
for _prefix, _audio in (("output", reconstructed_audio), ("input", reconstructed_input)):
    save_wav_sf(f"{_prefix}_{file_number}.wav", _audio, sr)

---