In [27]:
!pip install librosa soundfile webrtcvad pydub



In [28]:
import os
import numpy as np
import librosa
import soundfile as sf
from pydub import AudioSegment
from google.colab import drive
import webrtcvad
import wave
import contextlib
from datetime import datetime
import logging
from typing import Tuple, List, Optional

In [29]:
class AudioProcessingError(Exception):
    """Error type raised by the audio pipeline when a processing step fails."""

class AudioProcessor:
    SUPPORTED_FORMATS = {'.wav', '.mp3', '.flac'}

    def __init__(self, target_sr: int = 16000, min_segment_length: int = 20, max_segment_length: int = 30):
        """
        Initialize the audio processor with support for multiple formats.

        Validates the arguments, stores them on the instance, creates the
        WebRTC voice activity detector and configures logging.

        Args:
            target_sr: Target sample rate in Hz
            min_segment_length: Minimum segment length in seconds
            max_segment_length: Maximum segment length in seconds

        Raises:
            AudioProcessingError: If an argument is invalid or any part of the
                setup fails; the underlying exception is attached as the cause
        """
        try:
            if target_sr <= 0:
                raise ValueError("Target sample rate must be positive")
            if min_segment_length <= 0 or max_segment_length <= 0:
                raise ValueError("Segment lengths must be positive")
            if min_segment_length >= max_segment_length:
                raise ValueError("Minimum segment length must be less than maximum segment length")

            self.target_sr = target_sr
            self.min_segment_length = min_segment_length
            self.max_segment_length = max_segment_length
            self.vad = webrtcvad.Vad(2)  # Reduced aggressiveness to 2 for better segment detection

            # basicConfig only has an effect the first time the root logger is
            # configured, so creating several processors is harmless.
            logging.basicConfig(level=logging.INFO,
                              format='%(asctime)s - %(levelname)s - %(message)s')
            self.logger = logging.getLogger(__name__)

        except Exception as e:
            # Chain the original error so the traceback shows the real cause.
            raise AudioProcessingError(f"Failed to initialize AudioProcessor: {str(e)}") from e

    def mount_drive(self) -> None:
        """
        Mount Google Drive at /content/drive and log the result.

        Raises:
            AudioProcessingError: If drive mounting fails; the underlying
                exception is attached as the cause
        """
        try:
            drive.mount('/content/drive')
            self.logger.info("Google Drive mounted successfully")
        except Exception as e:
            # Chain the original error so the traceback shows the real cause.
            raise AudioProcessingError(f"Failed to mount Google Drive: {str(e)}") from e

    def get_file_format(self, file_path: str) -> str:
        """
        Get the audio file format from the file path.

        Args:
            file_path: Path to the audio file

        Returns:
            File extension including the dot

        Raises:
            AudioProcessingError: If format is not supported
        """
        file_ext = os.path.splitext(file_path)[1].lower()
        if file_ext not in self.SUPPORTED_FORMATS:
            raise AudioProcessingError(f"Unsupported audio format: {file_ext}")
        return file_ext

    def validate_file_path(self, file_path: str) -> None:
        """
        Check that a path exists, is a regular file and can be read.

        The checks run in that order and the first one that fails decides
        the error message.

        Args:
            file_path: Path to the file

        Raises:
            AudioProcessingError: If any of the checks fails
        """
        problem = None
        if not os.path.exists(file_path):
            problem = f"File not found: {file_path}"
        elif not os.path.isfile(file_path):
            problem = f"Not a file: {file_path}"
        elif not os.access(file_path, os.R_OK):
            problem = f"File not readable: {file_path}"

        if problem is not None:
            raise AudioProcessingError(problem)

    def load_audio(self, file_path: str) -> Tuple[np.ndarray, int]:
        """
        Load an audio file in any supported format as mono float samples.

        MP3 files are decoded with pydub, scaled by the full-scale value of
        the decoded integer sample type and averaged across channels.
        WAV and FLAC files are loaded with librosa at their native sample
        rate, already downmixed to mono.

        Args:
            file_path: Path to the audio file

        Returns:
            Tuple of audio data and sample rate

        Raises:
            AudioProcessingError: If validation, decoding or conversion fails;
                the underlying exception is attached as the cause
        """
        try:
            self.validate_file_path(file_path)
            file_format = self.get_file_format(file_path)

            self.logger.info(f"Loading {file_format} audio file: {file_path}")

            if file_format != '.mp3':
                # Use librosa for WAV and FLAC
                audio, sr = librosa.load(file_path, sr=None, mono=True)
                return audio, sr

            # Use pydub for MP3 files
            audio_segment = AudioSegment.from_mp3(file_path)
            samples = np.array(audio_segment.get_array_of_samples())

            # Scale by the full-scale value of whatever integer width pydub
            # decoded to, instead of assuming 16-bit samples.
            full_scale = np.iinfo(samples.dtype).max
            audio_array = samples.astype(np.float32) / full_scale

            # Samples are interleaved per channel; average them to get mono
            # for any channel count, not only stereo.
            channels = audio_segment.channels
            if channels > 1:
                audio_array = audio_array.reshape((-1, channels)).mean(axis=1)

            return audio_array, audio_segment.frame_rate

        except Exception as e:
            raise AudioProcessingError(f"Failed to load audio file: {str(e)}") from e

    def save_segment(self, audio: np.ndarray, sr: int, filepath: str, original_format: str) -> None:
        """
        Save one audio segment in the same format as the input file.

        MP3 output is built with pydub from 16-bit mono PCM; every other
        format is written with soundfile, using the extension without its
        leading dot as the format name.

        Args:
            audio: Audio data (float samples, nominally within [-1.0, 1.0])
            sr: Sample rate
            filepath: Output file path
            original_format: Original audio format, leading dot included

        Raises:
            AudioProcessingError: If encoding or writing fails; the
                underlying exception is attached as the cause
        """
        try:
            if original_format == '.mp3':
                # Clip before scaling: samples outside [-1, 1] (for example
                # resampling overshoot) would otherwise wrap around when cast
                # to int16 and produce loud clicks.
                clipped = np.clip(audio, -1.0, 1.0)
                audio_int16 = (clipped * np.iinfo(np.int16).max).astype(np.int16)
                segment = AudioSegment(
                    audio_int16.tobytes(),
                    frame_rate=sr,
                    sample_width=2,
                    channels=1
                )
                segment.export(filepath, format='mp3')
            else:
                # Save as WAV or FLAC using soundfile
                sf.write(filepath, audio, sr, format=original_format[1:])

        except Exception as e:
            raise AudioProcessingError(f"Failed to save segment: {str(e)}") from e

    def save_segments(self, audio: np.ndarray, sr: int, segments: List[Tuple[float, float]],
                     output_dir: str, filename_prefix: str, original_format: str) -> List[str]:
        """
        Cut the given time ranges out of the audio and write them as
        sequentially numbered files in the original format.

        A segment that fails to save is logged and skipped; its number is
        reused for the next segment.

        Args:
            audio: Audio data
            sr: Sample rate
            segments: List of (start, end) tuples in seconds
            output_dir: Output directory (created if missing)
            filename_prefix: Prefix for output filenames
            original_format: Original audio format, leading dot included

        Returns:
            Names of the files that were written, in order

        Raises:
            AudioProcessingError: If the directory is not writable or no
                segment could be saved
        """
        try:
            os.makedirs(output_dir, exist_ok=True)
            writable = os.access(output_dir, os.W_OK)
            if not writable:
                raise AudioProcessingError(f"Output directory not writable: {output_dir}")

            next_number = self.get_next_file_number(output_dir, filename_prefix)
            segment_count = len(segments)
            written = []

            for position, (seg_start, seg_end) in enumerate(segments, start=1):
                try:
                    # Convert the time range to sample indices and slice it out
                    first = int(seg_start * sr)
                    last = int(seg_end * sr)
                    chunk = audio[first:last]

                    name = f"{filename_prefix}-{next_number:03d}{original_format}"
                    target = os.path.join(output_dir, name)

                    self.save_segment(chunk, sr, target, original_format)
                    written.append(name)

                    # Report progress for the file that was just written
                    print(f"Saved segment {next_number:03d} ({position}/{segment_count}): {name}")
                    print(f"  Duration: {(seg_end - seg_start):.2f} seconds")
                    print(f"  Time range: {seg_start:.2f}s - {seg_end:.2f}s")
                    print("-" * 50)

                    next_number += 1

                except Exception as e:
                    self.logger.error(f"Failed to save segment {next_number}: {str(e)}")

            if len(written) == 0:
                raise AudioProcessingError("No segments were successfully saved")

            return written

        except Exception as e:
            raise AudioProcessingError(f"Failed to save segments: {str(e)}")

    def detect_voice_activity(self, audio: np.ndarray, sr: int) -> List[Tuple[float, float]]:
        """
        Detect segments with voice activity.

        Args:
            audio: Audio data
            sr: Sample rate

        Returns:
            List of (start, end) tuples in seconds
        """
        try:
            # Convert to 16-bit PCM
            audio_pcm = (audio * 32768).astype(np.int16)

            # Parameters for VAD
            frame_duration = 20  # Reduced from 30ms to 20ms for finer granularity
            frames_per_window = sr * frame_duration // 1000

            # Split audio into frames
            frames = []
            for i in range(0, len(audio_pcm), frames_per_window):
                frame = audio_pcm[i:i + frames_per_window]
                if len(frame) < frames_per_window:
                    frame = np.pad(frame, (0, frames_per_window - len(frame)))
                frames.append(frame.tobytes())

            if not frames:
                raise AudioProcessingError("No valid frames found in audio")

            # Detect speech in frames with smoothing
            is_speech = []
            window_size = 3  # Smoothing window size
            for i in range(0, len(frames)):
                # Get average of surrounding frames
                window_start = max(0, i - window_size)
                window_end = min(len(frames), i + window_size + 1)
                window_speeches = []
                for frame in frames[window_start:window_end]:
                    try:
                        window_speeches.append(self.vad.is_speech(frame, sr))
                    except Exception as e:
                        self.logger.warning(f"Failed to process frame: {str(e)}")
                        window_speeches.append(False)
                # Use majority vote
                is_speech.append(sum(window_speeches) > len(window_speeches) / 2)

            # Find continuous speech segments with more lenient gap filling
            segments = []
            start = None
            current_duration = 0
            silence_duration = 0
            max_silence_gap = 0.5  # Maximum silence gap to bridge (in seconds)

            for i, speech in enumerate(is_speech):
                frame_time = frame_duration / 1000  # Convert ms to seconds

                if speech:
                    if start is None:
                        start = i
                    current_duration += frame_time
                    silence_duration = 0
                else:
                    if start is not None:
                        silence_duration += frame_time
                        if silence_duration >= max_silence_gap:
                            # End segment if silence is too long
                            if current_duration >= self.min_segment_length:
                                end = i - int(silence_duration * (1000/frame_duration))
                                segments.append((
                                    start * frame_time,
                                    end * frame_time
                                ))
                                self.logger.info(f"Found segment: {start * frame_time:.2f}s - {end * frame_time:.2f}s")
                            start = None
                            current_duration = 0
                            silence_duration = 0

                # Force split if segment is too long
                if current_duration >= self.max_segment_length:
                    end = i
                    segments.append((
                        start * frame_time,
                        end * frame_time
                    ))
                    self.logger.info(f"Split long segment: {start * frame_time:.2f}s - {end * frame_time:.2f}s")
                    start = i
                    current_duration = 0
                    silence_duration = 0

            # Handle the last segment
            if start is not None and current_duration >= self.min_segment_length:
                segments.append((
                    start * frame_time,
                    len(is_speech) * frame_time
                ))
                self.logger.info(f"Final segment: {start * frame_time:.2f}s - {len(is_speech) * frame_time:.2f}s")

            if not segments:
                self.logger.warning("No voice activity segments detected")
            else:
                self.logger.info(f"Total segments detected: {len(segments)}")

            return segments

        except Exception as e:
            raise AudioProcessingError(f"Failed to detect voice activity: {str(e)}")

    def get_next_file_number(self, output_dir: str, filename_prefix: str) -> int:
        """
        Find the next available file number in the sequence.

        Args:
            output_dir: Output directory
            filename_prefix: Prefix for output filenames

        Returns:
            Next available file number

        Raises:
            AudioProcessingError: If directory access fails
        """
        try:
            if not os.path.exists(output_dir):
                return 0

            existing_files = os.listdir(output_dir)
            existing_numbers = []

            # Extract existing numbers from filenames
            for filename in existing_files:
                if filename.startswith(filename_prefix) and filename.endswith('.wav'):
                    try:
                        num_str = filename.replace(filename_prefix + '-', '').replace('.wav', '')
                        num = int(num_str)
                        existing_numbers.append(num)
                    except ValueError:
                        continue

            return max(existing_numbers + [-1]) + 1

        except Exception as e:
            raise AudioProcessingError(f"Failed to get next file number: {str(e)}")

    def process_audio_file(self, input_file: str, output_dir: str, filename_prefix: str) -> Tuple[int, List[str]]:
        """
        Process a single audio file and maintain original format.

        Loads the file, resamples it to the target sample rate when needed,
        detects voice segments and saves them to the output directory.

        Args:
            input_file: Path to input audio file
            output_dir: Output directory
            filename_prefix: Prefix for output filenames

        Returns:
            Tuple of number of segments actually saved and list of saved
            filenames; (0, []) when no voice segment is detected

        Raises:
            AudioProcessingError: If any processing step fails; the
                underlying exception is attached as the cause
        """
        try:
            # Get original format
            original_format = self.get_file_format(input_file)

            # Load audio
            self.logger.info("Starting audio processing")
            audio, sr = self.load_audio(input_file)

            # Resample if necessary
            if sr != self.target_sr:
                self.logger.info(f"Resampling from {sr}Hz to {self.target_sr}Hz")
                audio = librosa.resample(audio, orig_sr=sr, target_sr=self.target_sr)
                sr = self.target_sr

            # Detect voice activity segments
            self.logger.info("Detecting voice activity")
            segments = self.detect_voice_activity(audio, sr)

            if not segments:
                self.logger.warning("No voice segments detected in the audio file")
                return 0, []

            # Save segments in original format
            self.logger.info(f"Saving segments in {original_format} format")
            saved_files = self.save_segments(audio, sr, segments, output_dir, filename_prefix, original_format)

            # Individual saves may fail and be skipped, so report what was
            # written rather than what was detected.
            if len(saved_files) != len(segments):
                self.logger.warning(
                    f"Saved {len(saved_files)} of {len(segments)} detected segments"
                )

            return len(saved_files), saved_files

        except Exception as e:
            raise AudioProcessingError(f"Failed to process audio file: {str(e)}") from e

def main():
    """
    Split one recording stored on Google Drive into voice segments.

    Builds an AudioProcessor, mounts Google Drive if the input file is not
    visible yet, processes the file and prints a summary. Errors are printed
    and logged instead of being propagated.
    """
    try:
        # Initialize processor
        processor = AudioProcessor(
            target_sr=16000,
            min_segment_length=20,
            max_segment_length=30
        )

        # Configure paths
        input_file = '/content/drive/Shareddrives/CS307-Thesis/Dataset/test-data/test_raw.mp3'  # Can be .mp3, .wav, or .flac
        output_dir = '/content/drive/Shareddrives/CS307-Thesis/Dataset/test-data/'
        filename_prefix = 'training'

        # Both paths live on Google Drive: mount it when the input is not
        # reachable instead of failing with "File not found".
        if not os.path.exists(input_file):
            processor.mount_drive()

        # Process audio file
        num_segments, saved_files = processor.process_audio_file(input_file, output_dir, filename_prefix)

        print("\nProcessing complete!")
        print(f"Created {num_segments} segments.")
        print("\nSaved files:")
        for filename in saved_files:
            print(f"- {filename}")

    except AudioProcessingError as e:
        print(f"\nError during audio processing: {str(e)}")
        logging.error(f"Audio processing failed: {str(e)}")
    except Exception as e:
        print(f"\nUnexpected error: {str(e)}")
        logging.error(f"Unexpected error: {str(e)}", exc_info=True)


if __name__ == "__main__":
    main()

Saved segment 000 (1/198): training-000.mp3
  Duration: 31.10 seconds
  Time range: 0.02s - 31.12s
--------------------------------------------------
Saved segment 001 (2/198): training-001.mp3
  Duration: 30.58 seconds
  Time range: 31.12s - 61.70s
--------------------------------------------------
Saved segment 002 (3/198): training-002.mp3
  Duration: 30.58 seconds
  Time range: 61.70s - 92.28s
--------------------------------------------------
Saved segment 003 (4/198): training-003.mp3
  Duration: 31.06 seconds
  Time range: 92.28s - 123.34s
--------------------------------------------------
Saved segment 004 (5/198): training-004.mp3
  Duration: 30.26 seconds
  Time range: 123.34s - 153.60s
--------------------------------------------------
Saved segment 005 (6/198): training-005.mp3
  Duration: 30.06 seconds
  Time range: 153.60s - 183.66s
--------------------------------------------------
Saved segment 006 (7/198): training-006.mp3
  Duration: 33.70 seconds
  Time range: 191.08