<a href="https://colab.research.google.com/github/Terrykamau/PROJECTS/blob/main/noise_reduction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import librosa
import soundfile as sf
from scipy import signal
from typing import Tuple, Optional
import logging

logger = logging.getLogger(__name__)


class SpectralSubtractor:
    """
    Spectral subtraction noise reduction using various methods.

    Implements basic spectral subtraction, power spectral subtraction,
    and multi-band spectral subtraction with over-subtraction factor
    and spectral floor to prevent musical noise.

    The noise spectrum is estimated from the first STFT frames of the
    input, so the recording is expected to start with a noise-only segment.
    """

    # Small constant that keeps divisions finite on zero-power bins.
    _EPS = 1e-10

    def __init__(
        self,
        frame_length: int = 2048,
        hop_length: int = 512,
        alpha: float = 2.0,
        beta: float = 0.01,
        method: str = "power"
    ):
        """
        Initialize spectral subtractor.

        Values outside the ranges below are accepted; a warning is logged.

        Args:
            frame_length: STFT frame length
            hop_length: STFT hop length
            alpha: Over-subtraction factor (warns outside 1.0-10.0)
            beta: Spectral floor factor (warns outside 0.001-0.5)
            method: 'basic', 'power', or 'multiband' (checked in denoise)
        """
        self.frame_length = frame_length
        self.hop_length = hop_length
        self.alpha = alpha
        self.beta = beta
        self.method = method

        # Validate parameters
        if alpha < 1.0 or alpha > 10.0:
            logger.warning("Alpha=%s outside typical range [1.0, 10.0]", alpha)
        if beta < 0.001 or beta > 0.5:
            logger.warning("Beta=%s outside typical range [0.001, 0.5]", beta)

    def estimate_noise_spectrum(
        self,
        noisy_stft: np.ndarray,
        noise_frames: int = 10
    ) -> np.ndarray:
        """
        Estimate noise spectrum from initial frames of audio.

        Args:
            noisy_stft: STFT of noisy signal, shape (n_freq, n_frames)
            noise_frames: Number of initial frames to use for noise estimation

        Returns:
            Estimated noise power spectrum, shape (n_freq,)

        Raises:
            ValueError: If the STFT contains no frames.
        """
        n_frames = noisy_stft.shape[1]
        if n_frames == 0:
            raise ValueError("Cannot estimate noise from an STFT with no frames")

        if noise_frames > n_frames:
            # Never fall back to zero frames: averaging an empty slice
            # yields NaN, which would poison the whole output.
            noise_frames = max(1, n_frames // 2)
            logger.warning("Reduced noise_frames to %s", noise_frames)
        elif noise_frames < 1:
            noise_frames = 1
            logger.warning("noise_frames must be >= 1; using %s", noise_frames)

        # Use initial frames for noise estimation
        noise_stft = noisy_stft[:, :noise_frames]
        return np.mean(np.abs(noise_stft) ** 2, axis=1)

    def _subtraction_gain(self, noisy_power, noise_power, alpha, beta):
        """
        Compute the gain 1 - alpha * N / |Y|^2, limited to [beta, 1].

        ``alpha`` and ``beta`` may be scalars or arrays of shape (n_freq, 1)
        holding per-frequency values.
        """
        noise = noise_power[:, np.newaxis] + self._EPS
        # Flooring the denominator keeps silent bins from dividing by zero;
        # such bins end up at the spectral floor whenever alpha >= 1.
        gain = 1.0 - alpha * noise / np.maximum(noisy_power, self._EPS)

        # Apply spectral floor, then ensure gain <= 1
        gain = np.maximum(gain, beta)
        return np.minimum(gain, 1.0)

    def basic_spectral_subtraction(
        self,
        noisy_stft: np.ndarray,
        noise_power: np.ndarray
    ) -> np.ndarray:
        """
        Basic spectral subtraction: |Y|² - α|N|² with spectral floor.

        Args:
            noisy_stft: STFT of noisy signal
            noise_power: Estimated noise power spectrum

        Returns:
            Enhanced STFT
        """
        noisy_magnitude = np.abs(noisy_stft)
        noisy_phase = np.angle(noisy_stft)
        noisy_power = noisy_magnitude ** 2

        # Spectral subtraction
        enhanced_power = noisy_power - self.alpha * noise_power[:, np.newaxis]

        # Apply spectral floor (also keeps the power non-negative for sqrt)
        spectral_floor = self.beta * noisy_power
        enhanced_power = np.maximum(enhanced_power, spectral_floor)

        # Reconstruct magnitude and combine with original phase
        enhanced_magnitude = np.sqrt(enhanced_power)
        return enhanced_magnitude * np.exp(1j * noisy_phase)

    def power_spectral_subtraction(
        self,
        noisy_stft: np.ndarray,
        noise_power: np.ndarray
    ) -> np.ndarray:
        """
        Power spectral subtraction with Wiener-like gain function.

        Args:
            noisy_stft: STFT of noisy signal
            noise_power: Estimated noise power spectrum

        Returns:
            Enhanced STFT
        """
        noisy_power = np.abs(noisy_stft) ** 2
        gain = self._subtraction_gain(
            noisy_power, noise_power, self.alpha, self.beta
        )

        # A real gain scales the magnitude and leaves the phase unchanged.
        return gain * noisy_stft

    def multiband_spectral_subtraction(
        self,
        noisy_stft: np.ndarray,
        noise_power: np.ndarray,
        n_bands: int = 6
    ) -> np.ndarray:
        """
        Multi-band spectral subtraction with frequency-dependent parameters.

        The frequency axis is split into ``n_bands`` contiguous bands. Alpha
        rises linearly from 0.5x to 2x ``self.alpha`` across the bands and
        beta falls from 2x to 0.5x ``self.beta``.

        Args:
            noisy_stft: STFT of noisy signal
            noise_power: Estimated noise power spectrum
            n_bands: Number of frequency bands

        Returns:
            Enhanced STFT

        Raises:
            ValueError: If ``n_bands`` is less than 1.
        """
        if n_bands < 1:
            raise ValueError("n_bands must be at least 1")

        n_freq = noisy_stft.shape[0]

        # Define frequency bands
        band_edges = np.linspace(0, n_freq, n_bands + 1, dtype=int)
        band_sizes = np.diff(band_edges)

        # Different alpha/beta values for different bands
        alpha_bands = np.linspace(self.alpha * 0.5, self.alpha * 2.0, n_bands)
        beta_bands = np.linspace(self.beta * 2.0, self.beta * 0.5, n_bands)

        noisy_power = np.abs(noisy_stft) ** 2

        # Expand the per-band values to one value per frequency bin so all
        # bands are processed in a single vectorised pass. This avoids
        # constructing a temporary subtractor per band, whose parameter
        # validation would log warnings for the scaled band values.
        dtype = noisy_power.dtype
        alpha = np.repeat(alpha_bands, band_sizes).astype(dtype)[:, np.newaxis]
        beta = np.repeat(beta_bands, band_sizes).astype(dtype)[:, np.newaxis]

        gain = self._subtraction_gain(noisy_power, noise_power, alpha, beta)
        return gain * noisy_stft

    def smooth_gain_function(
        self,
        gain: np.ndarray,
        smoothing_factor: float = 0.98
    ) -> np.ndarray:
        """
        Apply temporal smoothing to gain function to reduce musical noise.

        Each frame is a first-order recursive average of the previous
        smoothed frame and the current gain; the first frame is copied.

        Args:
            gain: Gain function, shape (n_freq, n_frames)
            smoothing_factor: Smoothing factor (0.9-0.99)

        Returns:
            Smoothed gain function
        """
        smoothed_gain = np.zeros_like(gain)
        if gain.shape[1] == 0:
            # Nothing to smooth; indexing frame 0 would raise.
            return smoothed_gain

        smoothed_gain[:, 0] = gain[:, 0]
        for t in range(1, gain.shape[1]):
            smoothed_gain[:, t] = (
                smoothing_factor * smoothed_gain[:, t - 1] +
                (1 - smoothing_factor) * gain[:, t]
            )

        return smoothed_gain

    def denoise(
        self,
        audio_data: np.ndarray,
        sample_rate: int,
        noise_frames: int = 10,
        apply_smoothing: bool = True
    ) -> np.ndarray:
        """
        Apply spectral subtraction to remove noise from audio.

        Args:
            audio_data: Input audio signal
            sample_rate: Sample rate in Hz
            noise_frames: Number of frames for noise estimation
            apply_smoothing: Whether to apply gain smoothing (only used by
                the 'power' and 'multiband' methods)

        Returns:
            Denoised audio signal

        Raises:
            ValueError: If the audio is empty, the sample rate is not
                positive, or ``self.method`` is not a known method.
        """
        # Validate input
        if len(audio_data) == 0:
            raise ValueError("Audio data cannot be empty")
        if sample_rate <= 0:
            raise ValueError("Sample rate must be positive")

        # Resolve the method up front so a typo fails before any heavy work.
        subtract = {
            "basic": self.basic_spectral_subtraction,
            "power": self.power_spectral_subtraction,
            "multiband": self.multiband_spectral_subtraction,
        }.get(self.method)
        if subtract is None:
            raise ValueError(f"Unknown method: {self.method}")

        logger.info("Applying %s spectral subtraction", self.method)
        logger.debug("Audio length: %.2fs", len(audio_data) / sample_rate)

        # Compute STFT
        noisy_stft = librosa.stft(
            audio_data,
            n_fft=self.frame_length,
            hop_length=self.hop_length,
            window='hann'
        )

        # Estimate noise spectrum
        noise_power = self.estimate_noise_spectrum(noisy_stft, noise_frames)

        # Apply the selected spectral subtraction
        enhanced_stft = subtract(noisy_stft, noise_power)

        # Apply gain smoothing if requested
        if apply_smoothing and self.method in ["power", "multiband"]:
            gain = np.abs(enhanced_stft) / (np.abs(noisy_stft) + self._EPS)
            smoothed_gain = self.smooth_gain_function(gain)
            enhanced_stft = smoothed_gain * noisy_stft

        # Reconstruct audio. n_fft is passed explicitly because it cannot be
        # inferred correctly from the STFT shape when frame_length is odd.
        enhanced_audio = librosa.istft(
            enhanced_stft,
            hop_length=self.hop_length,
            n_fft=self.frame_length,
            window='hann',
            length=len(audio_data)
        )

        logger.info("Spectral subtraction completed")
        return enhanced_audio


def load_audio(filepath: str, target_sr: int = 16000) -> Tuple[np.ndarray, int]:
    """
    Read an audio file through librosa, resampling to ``target_sr``.

    A failure is logged at error level and the original exception is
    re-raised unchanged.

    Args:
        filepath: Path to audio file
        target_sr: Sample rate passed to ``librosa.load`` as ``sr``

    Returns:
        Tuple of (audio_data, sample_rate)
    """
    try:
        samples, rate = librosa.load(filepath, sr=target_sr)
        logger.info(f"Loaded {filepath}: {len(samples)} samples at {rate} Hz")
    except Exception as e:
        logger.error(f"Failed to load audio {filepath}: {e}")
        raise
    return samples, rate


def save_audio(
    audio_data: np.ndarray,
    filepath: str,
    sample_rate: int = 16000
) -> None:
    """
    Save audio to file.

    If the peak absolute sample exceeds 1.0 the signal is scaled down so
    that its peak is 1.0 and a warning is logged. A failure is logged at
    error level and the original exception is re-raised.

    Args:
        audio_data: Audio signal to save
        filepath: Output file path
        sample_rate: Sample rate in Hz
    """
    try:
        # np.max raises on an empty array, so treat an empty signal as
        # having zero peak instead of failing before anything is written.
        peak = np.max(np.abs(audio_data)) if np.size(audio_data) else 0.0

        # Normalize to prevent clipping
        if peak > 1.0:
            audio_data = audio_data / peak
            logger.warning("Audio normalized to prevent clipping")

        sf.write(filepath, audio_data, sample_rate)
        logger.info(f"Saved audio to {filepath}")
    except Exception as e:
        logger.error(f"Failed to save audio {filepath}: {e}")
        raise


def calculate_snr(clean_signal: np.ndarray, noisy_signal: np.ndarray) -> float:
    """
    Calculate Signal-to-Noise Ratio in dB.

    The noise is taken as ``noisy_signal - clean_signal``. Both inputs are
    converted to float64 first so that integer PCM data cannot overflow
    when subtracted or squared.

    Args:
        clean_signal: Clean reference signal
        noisy_signal: Noisy signal, same shape as ``clean_signal``

    Returns:
        SNR in dB; ``inf`` when the noise power is zero and ``-inf`` when
        the clean signal has zero power but the noise does not.

    Raises:
        ValueError: If the signals differ in shape or are empty.
    """
    clean = np.asarray(clean_signal, dtype=np.float64)
    noisy = np.asarray(noisy_signal, dtype=np.float64)

    # Mismatched shapes could otherwise broadcast silently, e.g.
    # (N,) against (N, 1), and produce a meaningless result.
    if clean.shape != noisy.shape:
        raise ValueError(
            f"Signal shapes differ: {clean.shape} vs {noisy.shape}"
        )
    if clean.size == 0:
        raise ValueError("Signals cannot be empty")

    signal_power = np.mean(clean ** 2)
    noise_power = np.mean((noisy - clean) ** 2)

    if noise_power == 0:
        return float('inf')
    if signal_power == 0:
        # log10(0) would warn; return the limit value directly.
        return float('-inf')

    return float(10 * np.log10(signal_power / noise_power))


# Example usage and testing
def _run_synthetic_demo() -> None:
    """Denoise a synthetic two-tone signal and print the SNR before/after."""
    duration = 3.0  # seconds
    lead_in = 0.5   # seconds of noise-only audio at the start
    sample_rate = 16000

    # Sample instants spaced exactly 1/sample_rate apart. np.linspace with
    # its default endpoint would stretch the spacing and detune the tones.
    t = np.arange(int(sample_rate * duration)) / sample_rate

    # Clean signal: combination of sine waves
    clean_signal = (
        0.5 * np.sin(2 * np.pi * 440 * t) +  # A4
        0.3 * np.sin(2 * np.pi * 880 * t)    # A5
    )
    # SpectralSubtractor.estimate_noise_spectrum averages the first frames,
    # so the tones must be absent there or they are subtracted as "noise".
    clean_signal[t < lead_in] = 0.0

    # Add white noise
    noise = 0.2 * np.random.randn(len(clean_signal))
    noisy_signal = clean_signal + noise

    # Calculate original SNR
    original_snr = calculate_snr(clean_signal, noisy_signal)
    print(f"Original SNR: {original_snr:.2f} dB")

    # Apply spectral subtraction
    subtractor = SpectralSubtractor(method="power", alpha=2.0, beta=0.01)
    enhanced_signal = subtractor.denoise(noisy_signal, sample_rate)

    # Calculate improved SNR
    improved_snr = calculate_snr(clean_signal, enhanced_signal)
    print(f"Enhanced SNR: {improved_snr:.2f} dB")
    print(f"SNR improvement: {improved_snr - original_snr:.2f} dB")

    # Save synthetic example
    save_audio(noisy_signal, "synthetic_noisy.wav", sample_rate)
    save_audio(enhanced_signal, "synthetic_enhanced.wav", sample_rate)
    print("Synthetic example saved!")


def main() -> None:
    """Denoise the example file, or run a synthetic demo if it is missing."""
    # Configure logging
    logging.basicConfig(level=logging.INFO)

    input_file = "/content/audio1.wav"  # Replace with your file

    # Only the load is guarded: a FileNotFoundError raised later (e.g. while
    # saving) should not be reported as a bad input path.
    try:
        audio_data, sample_rate = load_audio(input_file)
    except FileNotFoundError:
        audio_data = None

    if audio_data is None:
        print("Please provide a valid audio file path")

        # Generate synthetic example instead
        print("Generating synthetic example...")
        _run_synthetic_demo()
        return

    # Create spectral subtractor
    subtractor = SpectralSubtractor(
        frame_length=2048,
        hop_length=512,
        alpha=2.0,
        beta=0.01,
        method="power"  # or "basic", "multiband"
    )

    # Apply noise reduction
    enhanced_audio = subtractor.denoise(
        audio_data,
        sample_rate,
        noise_frames=20,
        apply_smoothing=True
    )

    # Save result
    output_file = "enhanced_audio.wav"
    save_audio(enhanced_audio, output_file, sample_rate)

    print("Noise reduction completed!")
    print(f"Input: {input_file}")
    print(f"Output: {output_file}")


if __name__ == "__main__":
    main()

Noise reduction completed!
Input: /content/audio1.wav
Output: enhanced_audio.wav
