In [None]:

# FILE ORGANIZATION
# dereverberation_dataset/
# ├── train/
# │   ├── clean/          (folder)
# │   └── reverberant/    (folder)
# ├── test/
# │   ├── output/         (folder)
# │   └── reverberant/    (folder - EMPTY)
# └── dereverberation_checkpoints/
#     ├── best_model.pt
#     └── checkpoint_epoch_X.pt

In [23]:
import numpy as np
import librosa
import soundfile as sf
from scipy import signal
from scipy.fft import rfft, irfft
import torch
import torch.nn as nn
import os

class AudioPreprocessor:
    """
    Preprocessing pipeline for LSTM-based dereverberation.
    Implements Step 1 from the paper.

    The pipeline is: STFT magnitude (Hamming window) -> cubic-root
    compression -> optional per-frequency-bin mean/std normalization.
    """

    def __init__(self,
                 sample_rate=16000,
                 frame_length_ms=32,
                 frame_shift_ms=8,
                 n_fft=512,
                 normalize=True):
        """
        Initialize the audio preprocessor.

        Args:
            sample_rate: Audio sampling rate (Hz)
            frame_length_ms: Frame length in milliseconds (32ms as per paper)
            frame_shift_ms: Frame shift/hop in milliseconds (8ms as per paper)
            n_fft: FFT size (512 points as per paper)
            normalize: Whether to normalize features
        """
        self.sample_rate = sample_rate
        self.frame_length_ms = frame_length_ms
        self.frame_shift_ms = frame_shift_ms
        self.n_fft = n_fft
        self.normalize = normalize

        # Convert ms to samples (truncating to whole samples)
        self.frame_length = int(frame_length_ms * sample_rate / 1000)  # 512 samples at 16kHz
        self.hop_length = int(frame_shift_ms * sample_rate / 1000)      # 128 samples at 16kHz

        # Number of frequency bins of a one-sided spectrum (257 for 512-point FFT)
        self.n_freq_bins = n_fft // 2 + 1

        # Statistics for normalization (to be computed from training data)
        self.feature_mean = None
        self.feature_std = None

        print(f"Initialized AudioPreprocessor:")
        print(f"  Sample rate: {sample_rate} Hz")
        print(f"  Frame length: {self.frame_length} samples ({frame_length_ms} ms)")
        print(f"  Hop length: {self.hop_length} samples ({frame_shift_ms} ms)")
        print(f"  FFT size: {n_fft}")
        print(f"  Frequency bins: {self.n_freq_bins}")

    @staticmethod
    def _is_audio_path(audio):
        """Return True when `audio` is a filesystem path (str or os.PathLike) rather than samples."""
        return isinstance(audio, (str, os.PathLike))

    def _store_stats(self, features):
        """Compute and store per-frequency-bin mean/std of `features` (frames along axis 0)."""
        self.feature_mean = np.mean(features, axis=0, keepdims=True)
        # Clamp the std so constant bins do not cause a division by zero
        self.feature_std = np.maximum(np.std(features, axis=0, keepdims=True), 1e-8)

    def load_audio(self, audio_path):
        """
        Load audio file and resample to target sample rate.

        Args:
            audio_path: Path to audio file

        Returns:
            audio: Audio time series (mono, at self.sample_rate)
        """
        audio, _ = librosa.load(audio_path, sr=self.sample_rate, mono=True)
        return audio

    def extract_magnitude_spectrum(self, audio):
        """
        Extract magnitude spectrum using STFT with Hamming window.

        Args:
            audio: Input audio signal (1D numpy array)

        Returns:
            magnitude: Magnitude spectrum (n_frames, n_freq_bins)
            phase: Phase spectrum (n_frames, n_freq_bins)
        """
        # Apply STFT with Hamming window
        stft_matrix = librosa.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.frame_length,
            window='hamming',
            center=True,
            pad_mode='reflect'
        )

        # librosa returns (n_freq_bins, n_frames); transpose to time-major
        magnitude = np.abs(stft_matrix).T  # Shape: (n_frames, n_freq_bins)
        phase = np.angle(stft_matrix).T     # Shape: (n_frames, n_freq_bins)

        return magnitude, phase

    def apply_cubic_root_compression(self, magnitude):
        """
        Apply cubic root compression to magnitude spectrum.

        Args:
            magnitude: Magnitude spectrum (n_frames, n_freq_bins)

        Returns:
            compressed: Cubic root compressed magnitude
        """
        # Cubic root compression: Y_compressed = Y^(1/3)
        compressed = np.power(magnitude, 1.0/3.0)
        return compressed

    def normalize_features(self, features, compute_stats=False):
        """
        Normalize features to zero mean and unit variance per frequency bin.

        Args:
            features: Input features (n_frames, n_freq_bins)
            compute_stats: If True, compute and store mean/std from this data
                (replacing any previously stored statistics)

        Returns:
            normalized: Normalized features

        Raises:
            ValueError: If no statistics are stored and compute_stats is False.
        """
        if compute_stats:
            # Statistics are taken over frames (axis 0), one value per frequency bin
            self._store_stats(features)
            print(f"Computed normalization statistics:")
            print(f"  Mean shape: {self.feature_mean.shape}")
            print(f"  Std shape: {self.feature_std.shape}")

        if self.feature_mean is None or self.feature_std is None:
            raise ValueError("Normalization statistics not computed. Set compute_stats=True first.")

        # Normalize
        normalized = (features - self.feature_mean) / self.feature_std
        return normalized

    def process_audio(self, audio, compute_stats=False, return_phase=True):
        """
        Complete preprocessing pipeline for a single audio signal.

        Args:
            audio: Input audio signal (1D numpy array) or path to audio file
                (str or os.PathLike)
            compute_stats: If True, compute normalization statistics
            return_phase: If True, return phase information

        Returns:
            features: Preprocessed features (n_frames, n_freq_bins)
            phase: Phase spectrum (if return_phase=True)
        """
        # Load audio if path is provided
        if self._is_audio_path(audio):
            audio = self.load_audio(audio)

        # Step 1: Extract magnitude spectrum
        magnitude, phase = self.extract_magnitude_spectrum(audio)

        # Step 2: Apply cubic root compression
        compressed = self.apply_cubic_root_compression(magnitude)

        # Step 3: Normalize (if enabled)
        if self.normalize:
            features = self.normalize_features(compressed, compute_stats=compute_stats)
        else:
            features = compressed

        if return_phase:
            return features, phase
        else:
            return features

    def compute_normalization_stats_from_dataset(self, audio_list):
        """
        Compute normalization statistics from a list of audio files/arrays.
        This should be called on ALL REVERBERANT training audio files.

        Note: the compressed spectra of every item are held in memory at once
        before the statistics are computed.

        Args:
            audio_list: List of audio file paths or numpy arrays (reverberant only)

        Raises:
            ValueError: If audio_list is empty.
        """
        if len(audio_list) == 0:
            # Fail with a clear message instead of an opaque np.concatenate error
            raise ValueError("audio_list is empty; cannot compute normalization statistics.")

        all_features = []

        print(f"Computing normalization statistics from {len(audio_list)} reverberant audio files...")
        for i, audio in enumerate(audio_list):
            if self._is_audio_path(audio):
                audio = self.load_audio(audio)

            magnitude, _ = self.extract_magnitude_spectrum(audio)
            compressed = self.apply_cubic_root_compression(magnitude)
            all_features.append(compressed)

            if (i + 1) % 100 == 0:
                print(f"  Processed {i + 1}/{len(audio_list)} files")

        # Stack the frames of all files and compute global per-bin statistics
        self._store_stats(np.concatenate(all_features, axis=0))

        print(f"Normalization statistics computed successfully!")
        print(f"  Mean shape: {self.feature_mean.shape}")
        print(f"  Mean range: [{self.feature_mean.min():.4f}, {self.feature_mean.max():.4f}]")
        print(f"  Std shape: {self.feature_std.shape}")
        print(f"  Std range: [{self.feature_std.min():.4f}, {self.feature_std.max():.4f}]")


class TrainingTargetGenerator:
    """
    UNIFIED preprocessing for training pairs.
    Handles normalization correctly: normalize reverb (input), keep clean unnormalized (target).
    """
    def __init__(self, preprocessor):
        """
        Args:
            preprocessor: AudioPreprocessor instance with computed normalization stats
        """
        self.preprocessor = preprocessor
        # Mirror the preprocessor's framing parameters for convenience
        self.sample_rate = preprocessor.sample_rate
        self.n_fft = preprocessor.n_fft
        self.frame_length = preprocessor.frame_length
        self.hop_length = preprocessor.hop_length
        self.n_freq_bins = preprocessor.n_freq_bins

    def generate_training_pair_from_real_data(self, clean_audio_path, reverb_audio_path):
        """
        Process clean + reverberant pair for training.
        IMPORTANT: Returns NORMALIZED reverb features and UNNORMALIZED clean features.

        Args:
            clean_audio_path: Path to clean/dry audio file
            reverb_audio_path: Path to reverberant audio file

        Returns:
            reverb_features: NORMALIZED cubic-root compressed reverberant features (input)
            target_features: UNNORMALIZED cubic-root compressed clean features (target)
            reverb_phase: Phase from reverberant audio (for reconstruction)

        Raises:
            ValueError: If the preprocessor has no normalization statistics yet.
        """
        # Fail fast: check the statistics before spending time on loading and STFTs
        if self.preprocessor.feature_mean is None or self.preprocessor.feature_std is None:
            raise ValueError("Preprocessor must have normalization stats computed before processing pairs!")

        # Load both audios
        clean_audio = self.preprocessor.load_audio(clean_audio_path)
        reverb_audio = self.preprocessor.load_audio(reverb_audio_path)

        # Ensure same length (truncate to shorter one) so both yield the same frame count
        min_len = min(len(clean_audio), len(reverb_audio))
        clean_audio = clean_audio[:min_len]
        reverb_audio = reverb_audio[:min_len]

        # Extract magnitude spectra (only the reverberant phase is kept)
        clean_mag, _ = self.preprocessor.extract_magnitude_spectrum(clean_audio)
        reverb_mag, reverb_phase = self.preprocessor.extract_magnitude_spectrum(reverb_audio)

        # Apply cubic root compression to both
        clean_compressed = self.preprocessor.apply_cubic_root_compression(clean_mag)
        reverb_compressed = self.preprocessor.apply_cubic_root_compression(reverb_mag)

        # CRITICAL: Normalize ONLY reverb features (input to LSTM)
        # Clean features stay unnormalized (target for LSTM)
        reverb_features = self.preprocessor.normalize_features(reverb_compressed, compute_stats=False)
        target_features = clean_compressed  # Keep clean unnormalized

        return reverb_features, target_features, reverb_phase

    def save_audio_examples(self, clean_path, reverb_path, output_prefix='step2_example'):
        """
        Save the (resampled) original audio files for comparison.

        Writes `{output_prefix}_clean.wav` and `{output_prefix}_reverberant.wav`
        at the preprocessor's sample rate.
        """
        clean_audio = self.preprocessor.load_audio(clean_path)
        reverb_audio = self.preprocessor.load_audio(reverb_path)

        sf.write(f'{output_prefix}_clean.wav', clean_audio, self.sample_rate)
        sf.write(f'{output_prefix}_reverberant.wav', reverb_audio, self.sample_rate)


class RealDataPreparer:
    """
    Prepares real-world clean + reverberant pairs for training.
    Handles normalization statistics computation and pair processing.
    """

    # Extensions treated as audio; matched case-insensitively.
    AUDIO_EXTENSIONS = ('.wav', '.flac', '.mp3', '.m4a', '.ogg')

    def __init__(self, sample_rate=16000):
        self.sample_rate = sample_rate
        self.preprocessor = AudioPreprocessor(sample_rate=sample_rate, normalize=True)

    def _list_audio_files(self, directory):
        """Return the sorted audio file names found directly inside `directory`."""
        return sorted(
            name for name in os.listdir(directory)
            if name.lower().endswith(self.AUDIO_EXTENSIONS)
        )

    def prepare_dataset_from_folder(self, data_folder):
        """
        Expects folder structure:
        data_folder/
            clean/
                file1.wav, file2.wav, ...
            reverberant/
                file1.wav, file2.wav, ...

        Clean and reverberant files are paired by their position in the
        sorted file listings, so both folders must hold the same number of
        files with matching sort order.

        Returns:
            reverb_features_list: List of normalized reverb features (inputs)
            target_features_list: List of unnormalized clean features (targets)

        Raises:
            ValueError: If either folder has no audio files, or the two
                folders contain a different number of audio files.
        """
        clean_dir = os.path.join(data_folder, 'clean')
        reverb_dir = os.path.join(data_folder, 'reverberant')

        clean_files = self._list_audio_files(clean_dir)
        reverb_files = self._list_audio_files(reverb_dir)

        print(f"Found {len(clean_files)} clean files and {len(reverb_files)} reverberant files")

        if not clean_files or not reverb_files:
            raise ValueError(
                f"No audio files found in '{clean_dir}' or '{reverb_dir}'."
            )
        if len(clean_files) != len(reverb_files):
            # zip() would silently drop the surplus files and may mis-pair the rest
            raise ValueError(
                f"Clean/reverberant file counts differ "
                f"({len(clean_files)} vs {len(reverb_files)}); cannot pair them by sorted order."
            )

        # STEP 1: Compute normalization stats from ALL reverberant files
        reverb_paths = [os.path.join(reverb_dir, f) for f in reverb_files]
        self.preprocessor.compute_normalization_stats_from_dataset(reverb_paths)

        # STEP 2: Process all pairs with computed normalization stats
        reverb_features_list = []
        target_features_list = []

        target_gen = TrainingTargetGenerator(preprocessor=self.preprocessor)

        print(f"\nProcessing {len(clean_files)} training pairs...")
        for i, (clean_file, reverb_file) in enumerate(zip(clean_files, reverb_files)):
            clean_path = os.path.join(clean_dir, clean_file)
            reverb_path = os.path.join(reverb_dir, reverb_file)

            # Generate training pair (normalized reverb + unnormalized clean)
            rev_feat, target_feat, _ = target_gen.generate_training_pair_from_real_data(
                clean_path, reverb_path
            )

            reverb_features_list.append(rev_feat)
            target_features_list.append(target_feat)

            if (i + 1) % 10 == 0:
                print(f"  Processed {i + 1}/{len(clean_files)} pairs")

        print(f"‚úì Dataset preparation complete!")
        print(f"  Total pairs: {len(reverb_features_list)}")
        print(f"  Input (reverb): normalized, cubic-root compressed")
        print(f"  Target (clean): unnormalized, cubic-root compressed")

        return reverb_features_list, target_features_list


# === STEP 1 & 2 EXECUTION - UNIFIED AND CORRECTED ===
if __name__ == "__main__":
    # Script entry point for Steps 1 & 2: locate the dataset, compute
    # normalization statistics from the reverberant training audio, build
    # (normalized reverb, unnormalized clean) training pairs, and preprocess
    # any reverberant test files for later inference.
    print("\n" + "="*60)
    print("UNIFIED Steps 1 & 2: Audio Preprocessing for Training")
    print("="*60)

    # Extensions treated as audio files (matched case-insensitively)
    AUDIO_EXTENSIONS = ('.wav', '.flac', '.mp3', '.m4a', '.ogg')

    # Check if running in Google Colab (the module only exists there)
    try:
        import google.colab
        IN_COLAB = True
        print("Running in Google Colab")
    except ImportError:
        IN_COLAB = False
        print("Running locally")

    # Mount Google Drive if in Colab
    if IN_COLAB:
        from google.colab import drive
        drive.mount('/content/drive')

        # MODIFY THIS PATH to match your Google Drive structure
        drive_base = '/content/drive/MyDrive/dereverberation_dataset'
        train_folder = os.path.join(drive_base, 'train')
        test_folder = os.path.join(drive_base, 'test')

        print(f"\nüìÅ Using Google Drive path: {drive_base}")
        print(f"   Dataset structure:")
        print(f"   {drive_base}/")
        print(f"     ‚îú‚îÄ‚îÄ train/")
        print(f"     ‚îÇ   ‚îú‚îÄ‚îÄ clean/")
        print(f"     ‚îÇ   ‚îî‚îÄ‚îÄ reverberant/")
        print(f"     ‚îî‚îÄ‚îÄ test/")
        print(f"         ‚îî‚îÄ‚îÄ reverberant/")
    else:
        # Local paths
        train_folder = './train'
        test_folder = './test'
        print(f"\nüìÅ Using local paths: ./train and ./test")

    train_clean_dir = os.path.join(train_folder, 'clean')
    train_reverb_dir = os.path.join(train_folder, 'reverberant')
    test_reverb_dir = os.path.join(test_folder, 'reverberant')

    # Check if train directories exist
    if not os.path.exists(train_clean_dir) or not os.path.exists(train_reverb_dir):
        print(f"\n‚ùå ERROR: Required training directory structure not found!")
        print(f"\nExpected structure:")
        if IN_COLAB:
            print(f"  /content/drive/MyDrive/dereverberation_dataset/")
            print(f"    ‚îú‚îÄ‚îÄ train/")
            print(f"    ‚îÇ   ‚îú‚îÄ‚îÄ clean/      (your clean/dry audio files)")
            print(f"    ‚îÇ   ‚îî‚îÄ‚îÄ reverberant/ (your reverberant audio files)")
            print(f"    ‚îî‚îÄ‚îÄ test/           (optional)")
            print(f"        ‚îî‚îÄ‚îÄ reverberant/ (reverberant test files - outputs will be generated)")
            print(f"\nüí° TIP: Upload your dataset to Google Drive first!")
            print(f"   Then modify 'drive_base' variable in the code to match your path.")
        else:
            print(f"  ./train/")
            print(f"    ‚îú‚îÄ‚îÄ clean/      (your clean/dry audio files)")
            print(f"    ‚îî‚îÄ‚îÄ reverberant/ (your reverberant audio files)")
            print(f"  ./test/           (optional)")
            print(f"    ‚îî‚îÄ‚îÄ reverberant/ (reverberant test files - outputs will be generated)")
        print(f"\nPlease create this structure and place your audio files accordingly.")
        # SystemExit instead of the interactive-only exit() helper
        raise SystemExit(1)

    # Check if test directory exists (only reverberant folder needed)
    has_test_data = os.path.exists(test_reverb_dir)

    # Get file lists
    clean_files = sorted([f for f in os.listdir(train_clean_dir) if f.lower().endswith(AUDIO_EXTENSIONS)])
    reverb_files = sorted([f for f in os.listdir(train_reverb_dir) if f.lower().endswith(AUDIO_EXTENSIONS)])

    if not clean_files or not reverb_files:
        print(f"‚ùå ERROR: No audio files found in train/clean/ or train/reverberant/")
        raise SystemExit(1)

    print(f"\n‚úÖ Found {len(clean_files)} clean training files")
    print(f"‚úÖ Found {len(reverb_files)} reverberant training files")

    # Pairs are formed by sorted position; unequal counts would be silently
    # truncated by zip() and could mis-pair clean and reverberant files.
    if len(clean_files) != len(reverb_files):
        print(f"ERROR: train/clean/ and train/reverberant/ must contain the same number of files "
              f"(found {len(clean_files)} vs {len(reverb_files)}); pairs are matched by sorted order.")
        raise SystemExit(1)

    if has_test_data:
        test_reverb_files = sorted([f for f in os.listdir(test_reverb_dir) if f.lower().endswith(AUDIO_EXTENSIONS)])

        if test_reverb_files:
            print(f"‚úÖ Found {len(test_reverb_files)} reverberant test files (for inference)")
        else:
            print(f"‚ö†Ô∏è  Test folder exists but no audio files found")
            has_test_data = False
    else:
        print(f"‚ÑπÔ∏è  No test dataset found (optional)")
        test_reverb_files = []

    # Initialize preprocessor
    preprocessor = AudioPreprocessor(
        sample_rate=16000,
        frame_length_ms=32,
        frame_shift_ms=8,
        n_fft=512,
        normalize=True
    )

    # STEP 1: Compute normalization statistics from ALL reverberant audio
    print("\n" + "="*60)
    print("STEP 1: Computing normalization statistics")
    print("="*60)
    reverb_paths = [os.path.join(train_reverb_dir, f) for f in reverb_files]
    preprocessor.compute_normalization_stats_from_dataset(reverb_paths)

    # STEP 2: Generate training pairs
    print("\n" + "="*60)
    print("STEP 2: Generating training pairs")
    print("="*60)

    target_gen = TrainingTargetGenerator(preprocessor=preprocessor)
    reverb_features_list = []
    target_features_list = []
    phase_list = []

    try:
        # NOTE: `reverb_features` / `target_features` stay bound to the last
        # training pair after this loop; the Step 3 cell reads them.
        for i, (clean_file, reverb_file) in enumerate(zip(clean_files, reverb_files)):
            clean_path = os.path.join(train_clean_dir, clean_file)
            reverb_path = os.path.join(train_reverb_dir, reverb_file)

            reverb_features, target_features, reverb_phase = target_gen.generate_training_pair_from_real_data(
                clean_path, reverb_path
            )

            reverb_features_list.append(reverb_features)
            target_features_list.append(target_features)
            phase_list.append(reverb_phase)

            if (i + 1) % 10 == 0 or (i + 1) == len(clean_files):
                print(f"  Processed {i + 1}/{len(clean_files)} training pairs")

        print(f"\n‚úÖ Training dataset generated successfully!")
        print(f"  Total training pairs: {len(reverb_features_list)}")
        print(f"  Input shape (each): {reverb_features_list[0].shape} (normalized)")
        print(f"  Target shape (each): {target_features_list[0].shape} (unnormalized)")

        # Process test dataset if available (inference only - no targets)
        if has_test_data:
            print("\n" + "="*60)
            print("STEP 3: Processing test data (inference)")
            print("="*60)

            test_reverb_features_list = []
            test_phase_list = []
            test_filenames = []

            for i, reverb_file in enumerate(test_reverb_files):
                reverb_path = os.path.join(test_reverb_dir, reverb_file)

                # Load and process test audio. The results go into test_*
                # names so the last training pair (`reverb_features`,
                # `reverb_phase`, `target_features`) is not overwritten.
                reverb_audio = preprocessor.load_audio(reverb_path)
                reverb_mag, test_phase = preprocessor.extract_magnitude_spectrum(reverb_audio)
                reverb_compressed = preprocessor.apply_cubic_root_compression(reverb_mag)
                test_features = preprocessor.normalize_features(reverb_compressed, compute_stats=False)

                test_reverb_features_list.append(test_features)
                test_phase_list.append(test_phase)
                test_filenames.append(reverb_file)

                if (i + 1) % 10 == 0 or (i + 1) == len(test_reverb_files):
                    print(f"  Processed {i + 1}/{len(test_reverb_files)} test files")

            print(f"\n‚úÖ Test dataset processed successfully!")
            print(f"  Total test files: {len(test_reverb_features_list)}")
            print(f"  Input shape (each): {test_reverb_features_list[0].shape} (normalized)")
            print(f"  Note: Outputs will be generated after model training")

        print(f"\n" + "="*60)
        print("PROCESSING COMPLETE")
        print("="*60)
        print("‚úì Normalization stats computed from reverberant training audio")
        print("‚úì Training pairs created:")
        print("  - Input: Normalized reverberant features")
        print("  - Target: Unnormalized clean features")
        if has_test_data:
            print("‚úì Test data processed (ready for inference after training)")
            print(f"  - Test outputs will be saved to: {test_folder}/output/")
        print("‚úì Ready for Step 3 (Model Training)")
        print("="*60)

    except Exception as e:
        # Top-level boundary: report the failure with a traceback instead of
        # aborting the notebook cell.
        print(f"‚ùå Error in processing: {e}")
        import traceback
        traceback.print_exc()


UNIFIED Steps 1 & 2: Audio Preprocessing for Training
Running in Google Colab
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

📁 Using Google Drive path: /content/drive/MyDrive/dereverberation_dataset
   Dataset structure:
   /content/drive/MyDrive/dereverberation_dataset/
     ├── train/
     │   ├── clean/
     │   └── reverberant/
     └── test/
         └── reverberant/

✅ Found 2 clean training files
✅ Found 2 reverberant training files
⚠️  Test folder exists but no audio files found
Initialized AudioPreprocessor:
  Sample rate: 16000 Hz
  Frame length: 512 samples (32 ms)
  Hop length: 128 samples (8 ms)
  FFT size: 512
  Frequency bins: 257

STEP 1: Computing normalization statistics
Computing normalization statistics from 2 reverberant audio files...
Normalization statistics computed successfully!
  Mean shape: (1, 257)
  Mean range: [0.2559, 1.4560]
  

In [13]:
import torch
import torch.nn as nn
import numpy as np

class LSTMDereverberation(nn.Module):
    """
    LSTM-based Speech Dereverberation Model.
    Implements Step 3 from the paper.

    A stack of single-layer LSTMs (each wrapped in WeightDropLSTM) followed
    by a linear projection back to `input_size` and a ReLU.
    """

    def __init__(self,
                 input_size=257,
                 hidden_size=512,
                 num_layers=2,
                 dropout=0.3,
                 weight_dropout=0.5):
        """
        Args:
            input_size: Number of features per frame (also the output size)
            hidden_size: Hidden units per LSTM layer
            num_layers: Number of stacked LSTM layers
            dropout: Dropout probability applied between LSTM layers
            weight_dropout: Dropout probability on the recurrent weights
        """
        super(LSTMDereverberation, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.dropout = dropout
        self.weight_dropout = weight_dropout

        # Build LSTM layers with weight dropout
        self.lstm_layers = nn.ModuleList()

        for i in range(num_layers):
            # Only the first layer sees the raw features
            layer_input_size = input_size if i == 0 else hidden_size

            # Create LSTM layer
            lstm = nn.LSTM(
                input_size=layer_input_size,
                hidden_size=hidden_size,
                num_layers=1,
                batch_first=True,
                dropout=0
            )

            # Apply weight dropout to recurrent connections
            lstm = WeightDropLSTM(lstm, dropout=weight_dropout)

            self.lstm_layers.append(lstm)

        # Dropout between LSTM layers
        self.dropout_layer = nn.Dropout(dropout)

        # Linear projection layer to map hidden state to magnitude spectrum
        self.linear = nn.Linear(hidden_size, input_size)

        # ReLU activation to ensure positive output
        self.relu = nn.ReLU()

        # Initialize weights with orthogonal initialization
        self._init_weights()

        print(f"Initialized LSTMDereverberation model:")
        print(f"  Input size: {input_size}")
        print(f"  Hidden size: {hidden_size}")
        print(f"  Num LSTM layers: {num_layers}")
        print(f"  Dropout (between layers): {dropout}")
        print(f"  Weight dropout (recurrent): {weight_dropout}")
        print(f"  Total parameters: {self.count_parameters():,}")

    def _init_weights(self):
        """Orthogonal initialization as per paper (all biases are zeroed)"""
        for name, param in self.named_parameters():
            if 'weight_hh_raw' in name:  # WeightDropLSTM recurrent weights
                nn.init.orthogonal_(param)
            elif 'weight_ih' in name:  # LSTM input weights
                nn.init.orthogonal_(param)
            elif 'weight' in name and param.dim() >= 2:  # Linear layer weights
                nn.init.orthogonal_(param)
            elif 'bias' in name:
                nn.init.zeros_(param)

    def forward(self, x, hidden_states=None):
        """
        Forward pass through the LSTM dereverberation model.

        Args:
            x: Input features (batch_size, seq_len, input_size)
            hidden_states: Optional list of (h, c) tuples for each LSTM layer

        Returns:
            output: Enhanced features (batch_size, seq_len, input_size)
            new_hidden_states: Updated hidden states for each layer
        """
        # Unpacking three values also rejects inputs that are not 3-D
        batch_size, _, _ = x.shape

        # Initialize hidden states if not provided
        if hidden_states is None:
            hidden_states = self._init_hidden(batch_size, x.device)

        # Pass through LSTM layers
        lstm_out = x
        new_hidden_states = []

        for i, lstm in enumerate(self.lstm_layers):
            # LSTM forward pass
            lstm_out, (h, c) = lstm(lstm_out, hidden_states[i])
            new_hidden_states.append((h, c))

            # Apply dropout between layers (not after last layer)
            if i < self.num_layers - 1:
                lstm_out = self.dropout_layer(lstm_out)

        # Linear projection to magnitude spectrum
        projected = self.linear(lstm_out)

        # Apply ReLU to ensure positive magnitude estimates
        output = self.relu(projected)

        return output, new_hidden_states

    def _init_hidden(self, batch_size, device):
        """Initialize zero (h, c) states for all LSTM layers"""
        hidden_states = []
        for _ in range(self.num_layers):
            h_0 = torch.zeros(1, batch_size, self.hidden_size, device=device)
            c_0 = torch.zeros(1, batch_size, self.hidden_size, device=device)
            hidden_states.append((h_0, c_0))
        return hidden_states

    def count_parameters(self):
        """Count total trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def predict(self, reverb_features, return_hidden=False):
        """
        Inference mode prediction.

        Runs the forward pass in eval mode without gradient tracking, then
        restores whichever train/eval mode the model was in before the call.

        Args:
            reverb_features: Reverberant features (seq_len, input_size) or (batch_size, seq_len, input_size)
            return_hidden: Whether to return hidden states

        Returns:
            enhanced_features: Enhanced/denoised features, always with a
                leading batch dimension (a 2-D input yields batch size 1)
            hidden_states: (optional) Hidden states if return_hidden=True
        """
        # Handle single sequence input
        if reverb_features.dim() == 2:
            reverb_features = reverb_features.unsqueeze(0)

        # Forward pass in eval mode; restore the caller's mode afterwards so a
        # mid-training prediction does not leave dropout disabled.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                enhanced_features, hidden_states = self.forward(reverb_features)
        finally:
            self.train(was_training)

        if return_hidden:
            return enhanced_features, hidden_states
        else:
            return enhanced_features


class WeightDropLSTM(nn.Module):
    """
    LSTM with weight dropout applied to recurrent connections.

    The hidden-to-hidden weight of the wrapped single-layer LSTM is owned by
    this module as `weight_hh_raw`; on every forward pass a (possibly
    dropped-out) copy is handed to the LSTM as `weight_hh_l0`.
    """
    def __init__(self, lstm, dropout=0.5):
        """
        Args:
            lstm: nn.LSTM whose `weight_hh_l0` is to be regularized
            dropout: Probability of zeroing each recurrent weight, in [0, 1)

        Raises:
            ValueError: If dropout is outside [0, 1).
        """
        super(WeightDropLSTM, self).__init__()
        if not 0.0 <= dropout < 1.0:
            # dropout == 1 would divide by zero when rescaling the kept weights
            raise ValueError(f"dropout must be in [0, 1), got {dropout}")
        self.lstm = lstm
        self.dropout = dropout

        # Store the original recurrent weight
        # We need to save it and remove it from the LSTM's parameters
        w_hh = self.lstm.weight_hh_l0.data.clone()
        self.weight_hh_raw = nn.Parameter(w_hh)

        # Remove the original weight from LSTM
        # This prevents it from being optimized separately
        del self.lstm._parameters['weight_hh_l0']

    def forward(self, x, hidden=None):
        """
        Forward pass with weight dropout applied.

        During training: Apply dropout mask to recurrent weights
        During evaluation: Use full weights without dropout

        Args:
            x: Input passed to the wrapped LSTM
            hidden: Optional (h, c) initial state for the wrapped LSTM

        Returns:
            output, hidden: Whatever the wrapped LSTM returns
        """
        # Apply dropout to recurrent weights during training
        if self.training and self.dropout > 0:
            # One mask per forward call, so every timestep shares it
            mask = self.weight_hh_raw.new_ones(self.weight_hh_raw.size()).bernoulli_(1 - self.dropout)
            # Rescale the kept weights by 1 / (1 - dropout) to maintain expected value
            w_hh = self.weight_hh_raw * mask / (1 - self.dropout)
        else:
            # Hand over a plain-tensor view, never the nn.Parameter itself:
            # assigning a Parameter would re-register `weight_hh_l0` on the
            # LSTM, which duplicates the weight in state_dict and makes the
            # next training-mode assignment of a plain tensor raise TypeError.
            # The view shares storage and keeps the autograd link.
            w_hh = self.weight_hh_raw.view_as(self.weight_hh_raw)

        # Temporarily assign the weight to LSTM
        self.lstm.weight_hh_l0 = w_hh

        # Run LSTM forward pass
        output, hidden = self.lstm(x, hidden)

        return output, hidden


class DereverberationLoss(nn.Module):
    """
    Training loss for the dereverberation model.

    Wraps ``nn.MSELoss``; the features it is applied to are expected to be
    cubic-root compressed magnitudes. Only ``loss_type='mse'`` is supported.
    """
    def __init__(self, loss_type='mse'):
        super().__init__()
        self.loss_type = loss_type
        # Guard clause: anything other than MSE is rejected.
        if loss_type != 'mse':
            raise ValueError(f"Unknown loss type: {loss_type}")
        self.criterion = nn.MSELoss()

    def forward(self, predicted, target):
        """
        Evaluate the configured criterion on a prediction/target pair.

        Args:
            predicted: Model output (batch_size, seq_len, input_size)
            target: Target clean features (batch_size, seq_len, input_size)

        Returns:
            Scalar loss tensor.
        """
        loss = self.criterion(predicted, target)
        return loss


# === STEP 3 EXECUTION - USING DATA FROM STEP 2 ===
if __name__ == "__main__":
    # Smoke test for the Step 3 model: build it, run one forward pass on the
    # Step 2 features when they exist (otherwise on random data), and print
    # a summary.
    print("\n" + "="*60)
    print("LSTM Dereverberation Model - Step 3")
    print("="*60)

    # Initialize model
    model = LSTMDereverberation(
        input_size=257,
        hidden_size=512,
        num_layers=2,
        dropout=0.3,
        weight_dropout=0.5
    )

    print("\nModel Architecture:")
    print(model)

    # Test the model with the data from Step 2
    print("\n" + "="*60)
    print("Testing with Step 2 data...")
    print("="*60)

    # The real-data path reads BOTH arrays, so both must exist; checking only
    # one of them would end in a NameError on the other.
    step2_data_available = 'reverb_features' in locals() and 'target_features' in locals()

    if step2_data_available:
        # Add a leading batch dimension to the Step 2 arrays.
        input_features = torch.from_numpy(reverb_features).float().unsqueeze(0)
        target_features_tensor = torch.from_numpy(target_features).float().unsqueeze(0)

        print(f"Input shape from Step 2: {input_features.shape}")
        print(f"Target shape from Step 2: {target_features_tensor.shape}")

        # Forward pass with real data
        model.eval()
        with torch.no_grad():
            output, hidden_states = model(input_features)

        print(f"Model output shape: {output.shape}")

        # Test loss with real data
        criterion = DereverberationLoss(loss_type='mse')
        loss = criterion(output, target_features_tensor)
        print(f"MSE Loss with real data: {loss.item():.4f}")

        # Summary lines specific to the real-data path.
        data_summary = [
            "‚úì Successfully tested with Step 2 data",
            f"‚úì Input: normalized reverb features {input_features.shape}",
            f"‚úì Target: unnormalized clean features {target_features_tensor.shape}",
        ]

    else:
        # Fallback: test with dummy data (if Step 2 variables aren't available)
        print("‚ö†Ô∏è  Step 2 variables not found. Using dummy data for testing...")
        batch_size = 1
        seq_len = 100
        input_size = 257

        dummy_input = torch.randn(batch_size, seq_len, input_size)
        print(f"Dummy input shape: {dummy_input.shape}")

        model.eval()
        with torch.no_grad():
            output, hidden_states = model(dummy_input)

        print(f"Model output shape: {output.shape}")

        # Summary lines specific to the dummy-data path.
        data_summary = [
            "‚úì Model tested with dummy data",
            "‚ö†Ô∏è  Run Step 2 first to test with real data",
        ]

    # Shared summary; only the middle lines differ between the two paths.
    print("\n" + "="*60)
    print("STEP 3 SUMMARY:")
    print("="*60)
    print("‚úì LSTM model initialized with 2 layers, 512 units each")
    print("‚úì Weight dropout (0.5) applied to recurrent connections")
    print("‚úì Dropout (0.3) between LSTM layers")
    print("‚úì Linear projection + ReLU for positive output")
    print("‚úì Orthogonal weight initialization applied")
    print("‚úì MSE loss function defined for training")
    for summary_line in data_summary:
        print(summary_line)
    print("‚úì Model ready for training in Step 4!")
    print("="*60)


LSTM Dereverberation Model - Step 3
Initialized LSTMDereverberation model:
  Input size: 257
  Hidden size: 512
  Num LSTM layers: 2
  Dropout (between layers): 0.3
  Weight dropout (recurrent): 0.5
  Total parameters: 3,812,097

Model Architecture:
LSTMDereverberation(
  (lstm_layers): ModuleList(
    (0): WeightDropLSTM(
      (lstm): LSTM(257, 512, batch_first=True)
    )
    (1): WeightDropLSTM(
      (lstm): LSTM(512, 512, batch_first=True)
    )
  )
  (dropout_layer): Dropout(p=0.3, inplace=False)
  (linear): Linear(in_features=512, out_features=257, bias=True)
  (relu): ReLU()
)

Testing with Step 2 data...
Input shape from Step 2: torch.Size([1, 751, 257])
Target shape from Step 2: torch.Size([1, 751, 257])
Model output shape: torch.Size([1, 751, 257])
MSE Loss with real data: 0.1943

STEP 3 SUMMARY:
✓ LSTM model initialized with 2 layers, 512 units each
✓ Weight dropout (0.5) applied to recurrent connections
✓ Dropout (0.3) between LSTM layers
✓ Linear projection + ReLU for positive output

In [18]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from tqdm import tqdm
import os

# === STEP 4 CLASS DEFINITIONS ===

class DereverberationDataset(Dataset):
    """
    Paired (reverberant, clean) feature dataset for dereverberation training.

    Every item is a triple:
    - reverb features: normalized, cubic-root compressed magnitude spectrum
    - target features: cubic-root compressed magnitude spectrum (clean, NOT normalized)
    - the number of frames in the reverberant sequence
    """

    def __init__(self, reverb_features_list, target_features_list):
        """
        Args:
            reverb_features_list: List of numpy arrays, each shape (seq_len, 257)
            target_features_list: List of numpy arrays, each shape (seq_len, 257)
        """
        assert len(reverb_features_list) == len(target_features_list), \
            f"Mismatch: {len(reverb_features_list)} reverb vs {len(target_features_list)} target files"

        def _to_tensor(features):
            # numpy arrays become float32 tensors; anything else passes through.
            if isinstance(features, np.ndarray):
                return torch.from_numpy(features).float()
            return features

        converted = [
            (_to_tensor(reverb), _to_tensor(clean))
            for reverb, clean in zip(reverb_features_list, target_features_list)
        ]
        self.reverb_features = [reverb for reverb, _ in converted]
        self.target_features = [clean for _, clean in converted]

        self.num_samples = len(self.reverb_features)

        print(f"Dataset initialized with {self.num_samples} samples")

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        """
        Returns:
            reverb: (seq_len, 257) - normalized, cubic-root compressed (reverberant)
            target: (seq_len, 257) - cubic-root compressed (clean, NOT normalized)
            seq_len: scalar - length of sequence (taken from the reverberant features)
        """
        reverb_seq = self.reverb_features[idx]
        return reverb_seq, self.target_features[idx], reverb_seq.shape[0]


def collate_fn_variable_length(batch):
    """
    Batch variable-length samples by zero-padding them to a common length.

    Args:
        batch: List of tuples (reverb, target, seq_len)

    Returns:
        reverb_padded: (batch_size, max_seq_len, 257)
        target_padded: (batch_size, max_seq_len, 257)
        lengths: (batch_size,) - actual lengths before padding
    """
    reverb_seqs = []
    target_seqs = []
    frame_counts = []
    for sample in batch:
        reverb_seqs.append(sample[0])
        target_seqs.append(sample[1])
        frame_counts.append(sample[2])

    # Zero-pad each group up to its longest sequence.
    padded_reverb = pad_sequence(reverb_seqs, batch_first=True, padding_value=0.0)
    padded_target = pad_sequence(target_seqs, batch_first=True, padding_value=0.0)

    return padded_reverb, padded_target, torch.tensor(frame_counts)


class DereverberationTrainer:
    """
    Trainer for LSTM-based dereverberation model.
    Implements Step 4 from the paper.

    Owns the data loaders, the Adam optimizer, the per-epoch train/validate
    loops, a length-masked MSE loss, and checkpoint save/load.
    """

    def __init__(self,
                 model,
                 train_dataset,
                 val_dataset=None,
                 batch_size=8,
                 learning_rate=0.001,
                 device='cuda' if torch.cuda.is_available() else 'cpu',
                 checkpoint_dir=None):  # Make checkpoint_dir optional
        """
        Args:
            model: LSTMDereverberation model (from Step 3)
            train_dataset: Training dataset (DereverberationDataset)
            val_dataset: Validation dataset (DereverberationDataset) - optional
            batch_size: Batch size (8 as per paper)
            learning_rate: Learning rate for Adam optimizer
            device: Device to train on
            checkpoint_dir: Directory to save checkpoints. When None, Google
                Drive is mounted if running on Colab; otherwise a local
                './dereverberation_checkpoints' folder is used.
        """
        self.model = model.to(device)
        self.device = device
        self.batch_size = batch_size

        # Set checkpoint directory to your existing dataset folder
        if checkpoint_dir is None:
            try:
                from google.colab import drive
                drive.mount('/content/drive')
                self.checkpoint_dir = '/content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints'
            except Exception:
                # Not on Colab (or the mount failed): fall back to a local
                # folder. `Exception` rather than a bare except so that
                # KeyboardInterrupt / SystemExit still propagate.
                self.checkpoint_dir = './dereverberation_checkpoints'
        else:
            self.checkpoint_dir = checkpoint_dir

        print("Model weights already initialized with orthogonal initialization (from Step 3)")

        # Pin host memory for any CUDA device spec ('cuda', 'cuda:0',
        # torch.device('cuda')), not only the exact string 'cuda'.
        use_pinned_memory = torch.device(device).type == 'cuda'

        # Create training data loader
        self.train_loader = DataLoader(
            train_dataset,
            batch_size=batch_size,
            shuffle=True,
            collate_fn=collate_fn_variable_length,
            num_workers=0,
            pin_memory=use_pinned_memory
        )

        # Create validation data loader if provided
        self.has_val = val_dataset is not None and len(val_dataset) > 0
        if self.has_val:
            self.val_loader = DataLoader(
                val_dataset,
                batch_size=batch_size,
                shuffle=False,
                collate_fn=collate_fn_variable_length,
                num_workers=0,
                pin_memory=use_pinned_memory
            )
        else:
            self.val_loader = None
            print("‚ö†Ô∏è  No validation dataset provided - will only track training loss")

        # Loss function: MSE (as per paper). The training loops use
        # compute_loss_with_masking() so padded frames are excluded; this
        # attribute is kept for callers that want the plain criterion.
        self.criterion = nn.MSELoss(reduction='mean')

        # Optimizer: Adam (as per paper)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=learning_rate
        )

        # Create checkpoint directory
        os.makedirs(self.checkpoint_dir, exist_ok=True)

        # Training history
        self.train_losses = []
        self.val_losses = []
        # Tracks the training loss instead when there is no validation set.
        self.best_val_loss = float('inf')

        print(f"\nTrainer initialized:")
        print(f"  Device: {device}")
        print(f"  Batch size: {batch_size}")
        print(f"  Learning rate: {learning_rate}")
        print(f"  Training samples: {len(train_dataset)}")
        if self.has_val:
            print(f"  Validation samples: {len(val_dataset)}")
        print(f"  Training batches per epoch: {len(self.train_loader)}")
        if self.has_val:
            print(f"  Validation batches per epoch: {len(self.val_loader)}")
        print(f"  Checkpoint directory: {self.checkpoint_dir}")

    def train_epoch(self, epoch):
        """
        Train for one epoch.

        Args:
            epoch: Current epoch number (used for the progress-bar label)

        Returns:
            avg_loss: Mean of the per-batch training losses for the epoch
        """
        self.model.train()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {epoch} [Train]")

        for reverb, target, lengths in pbar:
            # Move to device
            reverb = reverb.to(self.device)
            target = target.to(self.device)
            lengths = lengths.to(self.device)

            # Zero gradients
            self.optimizer.zero_grad()

            # Forward pass
            predicted, _ = self.model(reverb)

            # Compute loss (MSE in cubic root compressed space)
            loss = self.compute_loss_with_masking(predicted, target, lengths)

            # Backward pass
            loss.backward()

            # Gradient clipping (recommended for RNNs)
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=5.0)

            # Update weights
            self.optimizer.step()

            # Accumulate loss
            batch_loss = loss.item()
            total_loss += batch_loss
            num_batches += 1

            # Update progress bar
            pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

        avg_loss = total_loss / num_batches
        return avg_loss

    def validate(self, epoch):
        """
        Validate the model.

        Args:
            epoch: Current epoch number (used for the progress-bar label)

        Returns:
            avg_loss: Mean of the per-batch validation losses (or None if no
                validation set)
        """
        if not self.has_val:
            return None

        self.model.eval()
        total_loss = 0.0
        num_batches = 0

        pbar = tqdm(self.val_loader, desc=f"Epoch {epoch} [Val]")

        with torch.no_grad():
            for reverb, target, lengths in pbar:
                # Move to device
                reverb = reverb.to(self.device)
                target = target.to(self.device)
                lengths = lengths.to(self.device)

                # Forward pass
                predicted, _ = self.model(reverb)

                # Compute loss
                loss = self.compute_loss_with_masking(predicted, target, lengths)

                # Accumulate loss
                batch_loss = loss.item()
                total_loss += batch_loss
                num_batches += 1

                # Update progress bar
                pbar.set_postfix({'loss': f'{batch_loss:.4f}'})

        avg_loss = total_loss / num_batches
        return avg_loss

    def compute_loss_with_masking(self, predicted, target, lengths):
        """
        Compute MSE loss only on non-padded frames.

        Args:
            predicted: (batch_size, max_seq_len, 257)
            target: (batch_size, max_seq_len, 257)
            lengths: (batch_size,) - actual sequence lengths

        Returns:
            loss: Scalar loss value (0 when the batch has no valid frames)
        """
        _, max_seq_len, feat_dim = predicted.shape

        # Build the mask on the tensors' own device so it always matches them.
        frame_index = torch.arange(max_seq_len, device=predicted.device)
        valid = frame_index.unsqueeze(0) < lengths.to(predicted.device).unsqueeze(1)
        valid = valid.unsqueeze(-1).expand(-1, -1, feat_dim)  # (batch_size, max_seq_len, 257)

        # Zero the error on padded frames *before* squaring. Multiplying by
        # the mask would let a non-finite value in a padded frame leak into
        # the loss (nan * 0 is nan); masked_fill discards it outright.
        diff = (predicted - target).masked_fill(~valid, 0.0)
        squared_diff = diff ** 2

        # Average over non-padded elements only; clamp avoids 0/0 when every
        # length in the batch is zero.
        total_elements = valid.sum().clamp(min=1)
        loss = squared_diff.sum() / total_elements

        return loss

    def train(self, num_epochs, save_every=5):
        """
        Train the model for multiple epochs.

        After every epoch the best model so far is saved (judged by
        validation loss, or by training loss when there is no validation
        set), and a periodic checkpoint is written every ``save_every``
        epochs.

        Args:
            num_epochs: Number of epochs to train
            save_every: Save checkpoint every N epochs
        """
        print(f"\n{'='*60}")
        print(f"Starting training for {num_epochs} epochs")
        print(f"Checkpoints will be saved to: {self.checkpoint_dir}")
        print(f"{'='*60}\n")

        for epoch in range(1, num_epochs + 1):
            # Train
            train_loss = self.train_epoch(epoch)
            self.train_losses.append(train_loss)

            # Validate (if validation set exists)
            val_loss = self.validate(epoch) if self.has_val else None
            if val_loss is not None:
                self.val_losses.append(val_loss)

            # Print epoch summary
            print(f"\nEpoch {epoch}/{num_epochs}")
            print(f"  Train Loss: {train_loss:.4f}")
            if val_loss is not None:
                print(f"  Val Loss:   {val_loss:.4f}")

            # Save best model based on validation loss (or training loss if no validation)
            loss_to_compare = val_loss if val_loss is not None else train_loss
            if loss_to_compare < self.best_val_loss:
                self.best_val_loss = loss_to_compare
                self.save_checkpoint(epoch, is_best=True)
                metric_name = "Val" if val_loss is not None else "Train"
                print(f"  ‚úì New best model saved! ({metric_name} Loss: {loss_to_compare:.4f})")

            # Save periodic checkpoint
            if epoch % save_every == 0:
                self.save_checkpoint(epoch, is_best=False)
                print(f"  ‚úì Checkpoint saved at epoch {epoch}")

            print()

        print(f"{'='*60}")
        print(f"Training completed!")
        metric_name = "validation" if self.has_val else "training"
        print(f"Best {metric_name} loss: {self.best_val_loss:.4f}")
        print(f"All models saved to: {self.checkpoint_dir}")
        print(f"{'='*60}\n")

    def save_checkpoint(self, epoch, is_best=False):
        """
        Save model, optimizer and loss history to the checkpoint directory.

        Args:
            epoch: Current epoch
            is_best: When True the file is 'best_model.pt' (overwritten each
                time); otherwise 'checkpoint_epoch_<epoch>.pt'
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'best_val_loss': self.best_val_loss
        }

        if is_best:
            path = os.path.join(self.checkpoint_dir, 'best_model.pt')
        else:
            path = os.path.join(self.checkpoint_dir, f'checkpoint_epoch_{epoch}.pt')

        torch.save(checkpoint, path)
        print(f"  Checkpoint saved: {path}")

    def load_checkpoint(self, checkpoint_path=None):
        """
        Restore model, optimizer and loss history from a checkpoint file.

        Args:
            checkpoint_path: Path to checkpoint file (if None, loads best_model.pt
                from the checkpoint directory)
        """
        if checkpoint_path is None:
            checkpoint_path = os.path.join(self.checkpoint_dir, 'best_model.pt')

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.best_val_loss = checkpoint['best_val_loss']

        print(f"Checkpoint loaded from {checkpoint_path}")
        print(f"Epoch: {checkpoint['epoch']}")
        print(f"Best val loss: {self.best_val_loss:.4f}")


# === STEP 4 EXECUTION - ADAPTIVE FOR ANY NUMBER OF PAIRS ===
if __name__ == "__main__":
    # Step 4 driver: build datasets from the Step 2 features, train the
    # Step 3 model, then (for the list-format path) evaluate the best
    # checkpoint and print a report.

    def _resolve_checkpoint_dir():
        """Mount Google Drive on Colab, else fall back to a local folder; return the path."""
        try:
            from google.colab import drive
            drive.mount('/content/drive')
            path = '/content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints'
            print(f"‚úì Google Drive mounted - checkpoints will save to: {path}")
        except Exception:
            # `Exception` (not a bare except) so Ctrl-C still interrupts.
            path = './dereverberation_checkpoints'
            print(f"‚ö†Ô∏è  Running locally - checkpoints will save to: {path}")
        return path

    def _build_model():
        """Construct the Step 3 model with the configuration used throughout the notebook."""
        return LSTMDereverberation(
            input_size=257,
            hidden_size=512,
            num_layers=2,
            dropout=0.3,
            weight_dropout=0.5
        )

    print("\n" + "="*60)
    print("LSTM Dereverberation Training - Step 4")
    print("="*60)

    # Check if we have training data from Step 2
    # Step 2 should provide: reverb_features_list and target_features_list
    if 'reverb_features_list' in locals() and 'target_features_list' in locals():
        num_samples = len(reverb_features_list)
        print(f"‚úì Found {num_samples} training pairs from Step 2")

        # Adaptive splitting strategy based on number of samples
        if num_samples == 1:
            # Single pair - no validation
            print("‚ö†Ô∏è  Single audio pair - using for training only (no validation)")
            train_dataset = DereverberationDataset(reverb_features_list, target_features_list)
            val_dataset = None
            batch_size = 1
            num_epochs = 10
            print("   Model will memorize this single example")

        elif num_samples < 10:
            # Few pairs (2-9) - use all for training
            print(f"‚ö†Ô∏è  Only {num_samples} pairs - using all for training (no validation)")
            train_dataset = DereverberationDataset(reverb_features_list, target_features_list)
            val_dataset = None
            batch_size = min(8, num_samples)
            num_epochs = 10
            print(f"   Batch size: {batch_size}")

        else:
            # Many pairs (10+) - split train/val 80/20 (in list order, no shuffling)
            split_idx = int(0.8 * num_samples)
            print(f"‚úì Splitting {num_samples} pairs: {split_idx} train, {num_samples - split_idx} validation")

            train_dataset = DereverberationDataset(
                reverb_features_list[:split_idx],
                target_features_list[:split_idx]
            )
            val_dataset = DereverberationDataset(
                reverb_features_list[split_idx:],
                target_features_list[split_idx:]
            )
            batch_size = 8
            num_epochs = 10

        # Initialize model from Step 3
        if 'LSTMDereverberation' not in dir():
            print("‚ùå LSTMDereverberation class not found!")
            print("   Please run Step 3 first or ensure it's in the same script")
            raise ImportError("LSTMDereverberation class must be defined before Step 4")

        model = _build_model()

        # Set checkpoint path to your existing dataset folder
        checkpoint_dir = _resolve_checkpoint_dir()

        # Initialize trainer
        trainer = DereverberationTrainer(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            batch_size=batch_size,
            learning_rate=0.001,
            checkpoint_dir=checkpoint_dir
        )

        print(f"\nTraining configuration:")
        print(f"  Total samples: {num_samples}")
        print(f"  Training samples: {len(train_dataset)}")
        if val_dataset is not None:
            print(f"  Validation samples: {len(val_dataset)}")
        print(f"  Batch size: {batch_size}")
        print(f"  Epochs: {num_epochs}")
        print(f"  Learning rate: 0.001")
        print(f"  Checkpoint location: {checkpoint_dir}")
        print(f"  Training objective: normalized reverb ‚Üí unnormalized clean")

        # Start training automatically
        print("\nStarting training...")
        trainer.train(num_epochs=num_epochs, save_every=10)

        # ============================================================
        # POST-TRAINING EVALUATION
        # ============================================================
        print("\n" + "="*60)
        print("POST-TRAINING EVALUATION")
        print("="*60)

        # Load the best model for evaluation
        best_model_path = os.path.join(checkpoint_dir, 'best_model.pt')
        trainer.load_checkpoint(best_model_path)

        # Evaluate on the validation set if there is one, else on the training set
        print("\nEvaluating model performance...")
        trainer.model.eval()

        has_validation = val_dataset is not None
        eval_dataset = val_dataset if has_validation else train_dataset
        eval_loader = DataLoader(
            eval_dataset,
            batch_size=1,
            shuffle=False,
            collate_fn=collate_fn_variable_length
        )

        total_mse = 0.0
        total_mae = 0.0
        # Separate counter: reusing `num_samples` here would overwrite the
        # number of training pairs that is reported elsewhere.
        num_evaluated = 0

        with torch.no_grad():
            for reverb, target, lengths in tqdm(eval_loader, desc="Evaluating"):
                reverb = reverb.to(trainer.device)
                target = target.to(trainer.device)
                lengths = lengths.to(trainer.device)

                # Forward pass
                predicted, _ = trainer.model(reverb)

                # Compute metrics (only on non-padded regions)
                _, max_seq_len, feat_dim = predicted.shape
                mask = torch.arange(max_seq_len, device=trainer.device).unsqueeze(0) < lengths.unsqueeze(1)
                mask = mask.unsqueeze(-1).expand(-1, -1, feat_dim)

                # MSE
                mse = ((predicted - target) ** 2 * mask).sum() / mask.sum()
                total_mse += mse.item()

                # MAE (Mean Absolute Error)
                mae = (torch.abs(predicted - target) * mask).sum() / mask.sum()
                total_mae += mae.item()

                num_evaluated += 1

        avg_mse = total_mse / num_evaluated
        avg_mae = total_mae / num_evaluated
        rmse = np.sqrt(avg_mse)

        # NOTE: this is just the MSE expressed in dB (-10*log10(MSE)); it is
        # not a measured SNR gain, only a rough indicator.
        snr_improvement_estimate = -10 * np.log10(avg_mse + 1e-10)

        print("\n" + "="*60)
        print("MODEL PERFORMANCE METRICS")
        print("="*60)
        print(f"üìä Evaluation Dataset: {'Validation' if has_validation else 'Training'} set")
        print(f"üìà Number of samples evaluated: {num_evaluated}")
        print(f"\nüéØ Reconstruction Metrics (in cubic-root compressed space):")
        print(f"   ‚Ä¢ Mean Squared Error (MSE):  {avg_mse:.6f}")
        print(f"   ‚Ä¢ Root Mean Squared Error:   {rmse:.6f}")
        print(f"   ‚Ä¢ Mean Absolute Error (MAE): {avg_mae:.6f}")
        print(f"\nüîä Estimated Quality Improvement:")
        print(f"   ‚Ä¢ SNR Improvement (approx):  {snr_improvement_estimate:.2f} dB")

        # Interpret the results
        print(f"\nüí° Interpretation:")
        if avg_mse < 0.01:
            print("   ‚úÖ EXCELLENT: Very low error - model learned the mapping well!")
        elif avg_mse < 0.05:
            print("   ‚úÖ GOOD: Reasonable error - model shows learning progress")
        elif avg_mse < 0.1:
            print("   ‚ö†Ô∏è  FAIR: Moderate error - may need more training or data")
        else:
            print("   ‚ùå POOR: High error - needs more epochs, data, or architecture tuning")

        if num_evaluated == 1:
            print("   ‚ö†Ô∏è  NOTE: Evaluated on single sample - model likely memorized it")
            print("      For real assessment, test on unseen audio files!")
        elif not has_validation:
            print("   ‚ö†Ô∏è  NOTE: Evaluated on training data - may be overfitting")
            print("      For real assessment, use separate validation/test set!")

        # Training history visualization
        print(f"\nüìâ Training History:")
        print(f"   ‚Ä¢ Initial training loss: {trainer.train_losses[0]:.6f}")
        print(f"   ‚Ä¢ Final training loss:   {trainer.train_losses[-1]:.6f}")
        print(f"   ‚Ä¢ Loss reduction:        {((trainer.train_losses[0] - trainer.train_losses[-1]) / trainer.train_losses[0] * 100):.1f}%")

        if trainer.val_losses:
            print(f"   ‚Ä¢ Best validation loss:  {trainer.best_val_loss:.6f}")

        print("\n" + "="*60)
        print("STEP 4 COMPLETED:")
        print("="*60)
        # Report the pairs actually used for training, not the eval count.
        print(f"‚úì Model trained on {len(train_dataset)} audio pairs")
        print(f"‚úì Checkpoints saved to: {checkpoint_dir}")
        print("‚úì Best model saved as 'best_model.pt'")
        print("‚úì Ready for Step 5 (Inference)")
        print("="*60)

    elif 'reverb_features' in locals() and 'target_features' in locals():
        # Fallback: single pair from old Step 2 format
        print("‚úì Found single training pair from Step 2 (old format)")
        print("‚ö†Ô∏è  Converting to list format...")

        reverb_features_list = [reverb_features]
        target_features_list = [target_features]

        train_dataset = DereverberationDataset(reverb_features_list, target_features_list)
        val_dataset = None
        batch_size = 1
        num_epochs = 10

        if 'LSTMDereverberation' not in dir():
            raise ImportError("LSTMDereverberation class must be defined before Step 4")

        # Set checkpoint path to your existing dataset folder
        checkpoint_dir = _resolve_checkpoint_dir()

        model = _build_model()

        trainer = DereverberationTrainer(
            model=model,
            train_dataset=train_dataset,
            val_dataset=val_dataset,
            batch_size=batch_size,
            learning_rate=0.001,
            checkpoint_dir=checkpoint_dir
        )

        print("\nStarting training...")
        trainer.train(num_epochs=num_epochs, save_every=10)

        print("\n" + "="*60)
        print("STEP 4 COMPLETED:")
        print("="*60)
        print("‚úì Model trained on single audio pair")
        print(f"‚úì Checkpoints saved to: {checkpoint_dir}")
        print("‚úì Best model saved as 'best_model.pt'")
        print("‚úì Ready for Step 5 (Inference)")
        print("="*60)

    else:
        print("‚ùå No training data found from Step 2")
        print("Please run Step 2 first to generate training data")
        print("\nExpected variables:")
        print("  - reverb_features_list: List of reverberant features")
        print("  - target_features_list: List of clean target features")

        print("\n" + "="*60)
        print("STEP 4 SUMMARY:")
        print("="*60)
        print("‚úì Dataset class for variable-length sequences")
        print("‚úì Collate function for padding")
        print("‚úì Trainer with MSE loss in cubic root space")
        print("‚úì Adam optimizer with gradient clipping")
        print("‚úì Checkpoint saving to your dataset folder")
        print("‚úì Adaptive splitting: 1 pair ‚Üí no val, 10+ pairs ‚Üí 80/20 split")
        print("‚ö†Ô∏è  Waiting for training data from Step 2")
        print("="*60)


LSTM Dereverberation Training - Step 4
✓ Found 2 training pairs from Step 2
⚠️  Only 2 pairs - using all for training (no validation)
Dataset initialized with 2 samples
   Batch size: 2
Initialized LSTMDereverberation model:
  Input size: 257
  Hidden size: 512
  Num LSTM layers: 2
  Dropout (between layers): 0.3
  Weight dropout (recurrent): 0.5
  Total parameters: 3,812,097
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
‚úì Google Drive mounted - checkpoints will save to: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints
Model weights already initialized with orthogonal initialization (from Step 3)
⚠️  No validation dataset provided - will only track training loss

Trainer initialized:
  Device: cpu
  Batch size: 2
  Learning rate: 0.001
  Training samples: 2
  Training batches per epoch: 1
  Checkpoint directory: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints

Epoch 1 [Train]: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1/1 [00:01<00:00,  1.21s/it, loss=0.2341]



Epoch 1/10
  Train Loss: 0.2341
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ‚úì New best model saved! (Train Loss: 0.2341)



Epoch 2 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.34s/it, loss=0.2168]



Epoch 2/10
  Train Loss: 0.2168
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.2168)



Epoch 3 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.24s/it, loss=0.1904]



Epoch 3/10
  Train Loss: 0.1904
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1904)



Epoch 4 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it, loss=0.1708]



Epoch 4/10
  Train Loss: 0.1708
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1708)



Epoch 5 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.29s/it, loss=0.1562]



Epoch 5/10
  Train Loss: 0.1562
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1562)



Epoch 6 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.21s/it, loss=0.1392]



Epoch 6/10
  Train Loss: 0.1392
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1392)



Epoch 7 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.70s/it, loss=0.1312]



Epoch 7/10
  Train Loss: 0.1312
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1312)



Epoch 8 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.71s/it, loss=0.1252]



Epoch 8/10
  Train Loss: 0.1252
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1252)



Epoch 9 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.38s/it, loss=0.1184]



Epoch 9/10
  Train Loss: 0.1184
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1184)



Epoch 10 [Train]: 100%|██████████| 1/1 [00:01<00:00,  1.22s/it, loss=0.1097]



Epoch 10/10
  Train Loss: 0.1097
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
  ✓ New best model saved! (Train Loss: 0.1097)
  Checkpoint saved: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/checkpoint_epoch_10.pt
  ✓ Checkpoint saved at epoch 10

Training completed!
Best training loss: 0.1097
All models saved to: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints


POST-TRAINING EVALUATION
Checkpoint loaded from /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints/best_model.pt
Epoch: 10
Best val loss: 0.1097

Evaluating model performance...


Evaluating: 100%|██████████| 2/2 [00:00<00:00,  2.66it/s]


MODEL PERFORMANCE METRICS
📊 Evaluation Dataset: Training set
📈 Number of samples evaluated: 2

🎯 Reconstruction Metrics (in cubic-root compressed space):
   • Mean Squared Error (MSE):  0.101044
   • Root Mean Squared Error:   0.317874
   • Mean Absolute Error (MAE): 0.214388

🔊 Estimated Quality Improvement:
   • SNR Improvement (approx):  9.95 dB

💡 Interpretation:
   ❌ POOR: High error - needs more epochs, data, or architecture tuning
   ⚠️  NOTE: Evaluated on training data - may be overfitting
      For real assessment, use separate validation/test set!

📉 Training History:
   • Initial training loss: 0.234087
   • Final training loss:   0.109708
   • Loss reduction:        53.1%

STEP 4 COMPLETED:
✓ Model trained on 2 audio pairs
✓ Checkpoints saved to: /content/drive/MyDrive/Serveroperation_dataset/dereverberation_checkpoints
✓ Best model saved as 'best_model.pt'
✓ Ready for Step 5 (Inference)





In [22]:
import torch
import numpy as np
import librosa
import soundfile as sf
import os

class DereverberationInference:
    """
    Inference pipeline for LSTM-based speech dereverberation.
    Implements Steps 5 & 6 from the paper.

    The pipeline extracts a compressed, normalized magnitude spectrum from
    reverberant audio, runs it through the trained LSTM, undoes the
    normalization and compression, and resynthesizes a waveform using the
    phase of the reverberant input.
    """

    def __init__(self,
                 model_checkpoint_path,
                 preprocessor,
                 device='cuda' if torch.cuda.is_available() else 'cpu'):
        """
        Initialize inference pipeline.

        Args:
            model_checkpoint_path: Path to a checkpoint file readable by
                torch.load that contains 'model_state_dict', 'epoch' and
                'best_val_loss' entries.
            preprocessor: Preprocessor whose feature_mean / feature_std
                normalization statistics have already been computed.
            device: Torch device used for the model and its input.

        Raises:
            ValueError: If the preprocessor has no normalization statistics.
        """
        self.device = device
        self.preprocessor = preprocessor

        # Verify preprocessor has normalization statistics
        if self.preprocessor.feature_mean is None or self.preprocessor.feature_std is None:
            raise ValueError(
                "Preprocessor must have normalization statistics computed."
            )

        # Load trained model
        print(f"Loading model from: {model_checkpoint_path}")
        self.model = self._load_model(model_checkpoint_path)
        # Inference only: switch the module to evaluation mode.
        self.model.eval()

        print(f"Inference pipeline initialized on device: {device}")

    def _load_model(self, checkpoint_path):
        """
        Load trained model from checkpoint.

        Args:
            checkpoint_path: Path to the checkpoint file.

        Returns:
            The LSTMDereverberation model with the checkpoint weights loaded,
            moved to self.device.
        """
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Reconstruct model architecture
        # NOTE(review): these hyperparameters presumably must match the ones
        # used when the checkpoint was trained - confirm against the trainer.
        model = LSTMDereverberation(
            input_size=257,
            hidden_size=512,
            num_layers=2,
            dropout=0.3,
            weight_dropout=0.5
        )

        # Load trained weights
        model.load_state_dict(checkpoint['model_state_dict'])
        model.to(self.device)

        print("Model loaded successfully!")
        print(f"  Trained for {checkpoint['epoch']} epochs")
        print(f"  Best validation loss: {checkpoint['best_val_loss']:.4f}")

        return model

    def enhance_audio(self, audio_input):
        """
        Complete enhancement pipeline: Steps 5 & 6.

        Args:
            audio_input: Either a path (str or os.PathLike) to an audio file,
                which is loaded as mono at the preprocessor's sample rate, or
                an array-like of audio samples that is used as given.

        Returns:
            The enhanced time-domain signal, rescaled so its peak absolute
            value is 0.95 (left unscaled if the signal is all zeros).

        Raises:
            ValueError: If the audio contains no samples.
        """
        print("\n" + "="*60)
        print("STEP 5: Enhancement Process")
        print("="*60)

        # Load audio if file path provided
        if isinstance(audio_input, (str, os.PathLike)):
            print(f"Loading audio from: {audio_input}")
            audio, _ = librosa.load(audio_input, sr=self.preprocessor.sample_rate, mono=True)
        else:
            audio = np.asarray(audio_input)

        # Fail early with a clear message instead of deep inside the STFT.
        if len(audio) == 0:
            raise ValueError("Cannot enhance empty audio input.")

        print(f"Audio length: {len(audio)} samples ({len(audio)/self.preprocessor.sample_rate:.2f} seconds)")

        # Step 5.1: Extract magnitude spectrum and phase from reverberant audio
        print("Extracting magnitude spectrum and phase...")
        magnitude_reverb, phase_reverb = self.preprocessor.extract_magnitude_spectrum(audio)

        # Step 5.2: Apply cubic root compression
        print("Applying cubic root compression...")
        compressed_reverb = self.preprocessor.apply_cubic_root_compression(magnitude_reverb)

        # Step 5.3: Normalize using training statistics
        print("Normalizing features using training statistics...")
        normalized_reverb = self.preprocessor.normalize_features(
            compressed_reverb,
            compute_stats=False
        )

        # Step 5.4: Prepare input for LSTM
        print("Preparing input for LSTM...")
        input_tensor = torch.from_numpy(normalized_reverb).float()
        input_tensor = input_tensor.unsqueeze(0)  # (1, seq_len, 257)
        input_tensor = input_tensor.to(self.device)

        # Step 5.5: Forward pass through LSTM
        print("Running LSTM inference...")
        with torch.no_grad():
            enhanced_normalized, _ = self.model(input_tensor)

        # Move back to CPU and remove batch dimension
        enhanced_normalized = enhanced_normalized.squeeze(0).cpu().numpy()

        print("\n" + "="*60)
        print("STEP 6: Signal Reconstruction")
        print("="*60)

        # Step 6.1: Denormalize the LSTM output
        print("Denormalizing LSTM output...")
        enhanced_compressed = self._denormalize(enhanced_normalized)

        # Step 6.2: Reverse cubic root compression (cube it)
        print("Reversing cubic root compression (cubing)...")
        enhanced_magnitude = self._compressed_to_magnitude(enhanced_compressed)

        # Step 6.3: Combine enhanced magnitude with original reverberant phase
        print("Combining with original phase...")
        enhanced_complex = enhanced_magnitude * np.exp(1j * phase_reverb)

        # Step 6.4: Reconstruct time-domain signal via inverse STFT
        print("Reconstructing time-domain signal (inverse STFT)...")
        enhanced_audio = self._inverse_stft(enhanced_complex.T)  # Transpose for librosa

        # Normalize to prevent clipping
        max_val = np.max(np.abs(enhanced_audio))
        if max_val > 0:
            enhanced_audio = enhanced_audio / max_val * 0.95

        print("\n" + "="*60)
        print("Enhancement complete!")
        print("="*60)

        return enhanced_audio

    def _compressed_to_magnitude(self, compressed_features):
        """
        Undo cubic root compression and return a non-negative magnitude.

        The network output is not constrained to be non-negative once it has
        been denormalized. Cubing a negative value would give a negative
        "magnitude", which is equivalent to flipping the phase of that bin by
        pi, so negative predictions are clamped to zero before cubing.

        Args:
            compressed_features: Array of features in the cubic-root domain.

        Returns:
            Array of the same shape holding max(x, 0) ** 3.
        """
        return np.power(np.maximum(compressed_features, 0.0), 3.0)

    def _denormalize(self, normalized_features):
        """
        Denormalize features using training statistics.

        Args:
            normalized_features: Array of normalized features.

        Returns:
            normalized_features * feature_std + feature_mean, using the
            statistics stored on the preprocessor.
        """
        mean = self.preprocessor.feature_mean
        std = self.preprocessor.feature_std

        denormalized = normalized_features * std + mean
        return denormalized

    def _inverse_stft(self, complex_spectrum):
        """
        Reconstruct time-domain signal from complex spectrum.

        Args:
            complex_spectrum: Complex STFT matrix in the orientation expected
                by librosa.istft (the caller transposes before passing it in).

        Returns:
            The time-domain signal produced by librosa.istft.
        """
        # NOTE(review): window/center presumably mirror the settings used by
        # the preprocessor's forward STFT - confirm against
        # extract_magnitude_spectrum.
        audio = librosa.istft(
            complex_spectrum,
            hop_length=self.preprocessor.hop_length,
            win_length=self.preprocessor.frame_length,
            window='hamming',
            center=True
        )
        return audio

    def save_audio(self, audio, output_path, sample_rate=None):
        """
        Save audio to WAV file.

        Args:
            audio: Audio samples to write.
            output_path: Destination file path.
            sample_rate: Sample rate to write; defaults to the preprocessor's
                sample rate when None.
        """
        if sample_rate is None:
            sample_rate = self.preprocessor.sample_rate

        sf.write(output_path, audio, sample_rate)
        print(f"Audio saved to: {output_path}")


# === STEP 5 & 6 EXECUTION ===
# Driver cell: mounts Google Drive, finds the test audio and a trained
# checkpoint, enhances every test file and writes the results to the output
# folder. Any failure is reported with a traceback instead of aborting the cell.

def find_audio_files(folder, extensions=('.wav', '.flac', '.mp3', '.m4a', '.ogg')):
    """
    Return the audio file names found directly inside a folder.

    Args:
        folder: Directory to scan (not recursive).
        extensions: File extensions to accept, matched case-insensitively.

    Returns:
        Sorted list of matching file names (names only, not full paths).
    """
    suffixes = tuple(extensions)
    return sorted(f for f in os.listdir(folder) if f.lower().endswith(suffixes))


def pick_checkpoint(checkpoint_dir):
    """
    Choose which '.pt' checkpoint in a directory to use for inference.

    A file whose name contains 'best_model' is preferred. If there is none,
    the checkpoint with the highest number in its file name is used (e.g. the
    latest 'checkpoint_epoch_X.pt'), so the lookup no longer fails with an
    IndexError when only per-epoch checkpoints exist.

    Args:
        checkpoint_dir: Directory containing checkpoint files.

    Returns:
        Full path of the chosen checkpoint, or None if the directory holds
        no '.pt' files.
    """
    import re

    names = sorted(f for f in os.listdir(checkpoint_dir) if f.endswith('.pt'))
    if not names:
        return None

    # Match on the file name only, so a directory name can never match.
    best = [f for f in names if 'best_model' in f]
    if best:
        return os.path.join(checkpoint_dir, best[0])

    def epoch_of(name):
        match = re.search(r'(\d+)', name)
        return int(match.group(1)) if match else -1

    return os.path.join(checkpoint_dir, max(names, key=epoch_of))


print("\n" + "="*60)
print("LSTM Dereverberation - Steps 5 & 6: Inference & Reconstruction")
print("="*60)

try:
    from google.colab import drive
    drive.mount('/content/drive')

    # DIRECT PATHS - NO EXPLORATION
    test_reverb_path = '/content/drive/MyDrive/dereverberation_dataset/test/reverberant'
    output_folder = '/content/drive/MyDrive/dereverberation_dataset/test/output'
    checkpoint_dir = '/content/drive/MyDrive/dereverberation_dataset/dereverberation_checkpoints'

    # Create output folder
    os.makedirs(output_folder, exist_ok=True)

    # Get audio files directly (sorted so the processing order is deterministic)
    test_audio_files = find_audio_files(test_reverb_path)

    if not test_audio_files:
        print("‚ùå No audio files found in test/reverberant folder")
        print("Please add audio files to proceed")
    else:
        print(f"‚úÖ Found {len(test_audio_files)} audio files")

        # Check for trained models
        model_checkpoint = pick_checkpoint(checkpoint_dir)

        if model_checkpoint is None:
            print("‚ùå No trained model found")
        else:
            if 'best_model' in os.path.basename(model_checkpoint):
                print(f"‚úÖ Using best model: {model_checkpoint}")
            else:
                print(f"best_model checkpoint not found - using latest checkpoint: {model_checkpoint}")

            # A freshly constructed AudioPreprocessor has feature_mean and
            # feature_std set to None, and DereverberationInference rejects
            # such a preprocessor. Reuse a preprocessor already fitted in
            # this session when there is one.
            # NOTE(review): assumes earlier cells bind the fitted instance to
            # the name `preprocessor` - TODO confirm.
            preprocessor = globals().get('preprocessor')
            if (preprocessor is None
                    or getattr(preprocessor, 'feature_mean', None) is None
                    or getattr(preprocessor, 'feature_std', None) is None):
                # Initialize preprocessor (statistics still have to be
                # computed before inference can succeed)
                preprocessor = AudioPreprocessor(
                    sample_rate=16000,
                    frame_length_ms=32,
                    frame_shift_ms=8,
                    n_fft=512,
                    normalize=True
                )

            # Initialize inference pipeline
            inference = DereverberationInference(
                model_checkpoint_path=model_checkpoint,
                preprocessor=preprocessor
            )

            # Process all test files
            for audio_file in test_audio_files:
                input_audio_path = os.path.join(test_reverb_path, audio_file)
                output_filename = f"enhanced_{os.path.splitext(audio_file)[0]}.wav"
                output_path = os.path.join(output_folder, output_filename)

                print(f"\nüéØ Enhancing: {audio_file}")
                enhanced_audio = inference.enhance_audio(input_audio_path)
                inference.save_audio(enhanced_audio, output_path)

            print(f"\nüéâ ALL TEST FILES PROCESSED!")
            print(f"Output saved to: {output_folder}")

except Exception as e:
    # Top-level boundary of the notebook cell: report and keep the kernel alive.
    print(f"‚ùå Error: {e}")
    import traceback
    traceback.print_exc()

print("\n" + "="*60)
print("STEPS 5 & 6 COMPLETE")
print("="*60)


LSTM Dereverberation - Steps 5 & 6: Inference & Reconstruction
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
❌ No audio files found in test/reverberant folder
Please add audio files to proceed

STEPS 5 & 6 COMPLETE
