In [1]:
import wave
from time import time
import torch
import os
from tqdm import tqdm
import numpy as np
import pandas as pd
import webrtcvad
import contextlib
import soundfile as sf
import scipy.io.wavfile as wav
import resampy

from other.utils import get_files_by_extension

In [2]:
class WebrtcVadLabelMaker:
    @staticmethod
    def resample_pcm_wav(file_path, target_sr, output_path=None):
        """
        Reads a PCM-formatted WAV file, resamples if the sample rate differs from target_sr,
        and saves the output as a PCM WAV file.

        Parameters:
        - file_path: str, path to the input WAV file.
        - target_sr: int, target sample rate (Hz).
        - output_path: str (optional), path to save the resampled WAV file.
                    If None, the output is saved as "<original_filename>_resampled.wav" in the same directory.
        """
        # Read original wav file
        orig_sr, data = wav.read(file_path)

        # Check if resampling is needed
        if orig_sr == target_sr:
            if output_path is None:
                output_path = f"{os.path.splitext(file_path)[0]}_copy.wav"
            wav.write(output_path, orig_sr, data)
            return

        # Convert data to float32 for processing. PCM int16 data ranges -32768 to 32767.
        data_float = data.astype(np.float32)
        if data.dtype == np.int16:
            data_float /= 32768.0  # Normalize to roughly [-1, 1)

        # Resample the data.
        # Handle mono and multi-channel audio:
        if data_float.ndim == 1:
            data_resampled = resampy.resample(data_float, orig_sr, target_sr)
        else:
            # For multi-channel, resample each channel separately.
            channels = []
            for ch in range(data_float.shape[1]):
                ch_resampled = resampy.resample(data_float[:, ch], orig_sr, target_sr)
                channels.append(ch_resampled)
            # Stack channels back into a 2D array (samples x channels)
            data_resampled = np.stack(channels, axis=-1)

        # Convert resampled data back to int16.
        # Scale the float data back to the int16 range and clip to avoid overflow.
        data_resampled = np.clip(data_resampled * 32768, -32768, 32767).astype(np.int16)

        # Determine the output path if not provided.
        if output_path is None:
            base, ext = os.path.splitext(file_path)
            output_path = f"{base}_resampled.wav"

        # Save the resampled file.
        wav.write(output_path, target_sr, data_resampled)

    @staticmethod
    def read_wave(path, target_sr=None):
        ext = os.path.splitext(path)[1]
        if ext == '.wav':
            with contextlib.closing(wave.open(path, 'rb')) as wf:
                comp_type = wf.getcomptype()
                assert comp_type == 'NONE'
            WebrtcVadLabelMaker.resample_pcm_wav(path, target_sr, path)
            with contextlib.closing(wave.open(path, 'rb')) as wf:
                num_channels = wf.getnchannels()
                assert num_channels == 1
                sample_width = wf.getsampwidth()
                assert sample_width == 2
                sample_rate = wf.getframerate()
                assert sample_rate in (8000, 16000, 32000, 48000)
                pcm_data = wf.readframes(wf.getnframes())
                return pcm_data, sample_rate
        elif ext == '.flac':
            with sf.SoundFile(path, "r") as flac_file:
                pcm_data = flac_file.read(dtype="int16").tobytes()
                sample_rate = flac_file.samplerate
                assert sample_rate in (8000, 16000, 32000, 48000)
                num_channels = flac_file.channels
                assert num_channels == 1
                return pcm_data, sample_rate

    @staticmethod
    def find_ones_regions(arr, threshold=0):
        diff = np.diff(arr)
        starts = np.where(diff == 1)[0] + 1  # +1 because diff shifts left
        ends = np.where(diff == -1)[0]

        # Handle edge cases
        if arr[0] == 1:
            starts = np.insert(starts, 0, 0)
        if arr[-1] == 1:
            ends = np.append(ends, len(arr) - 1)

        # Ensure starts and ends are the same length
        if len(starts) != len(ends):
            raise ValueError("Mismatch between starts and ends")

        # Flatten the starts and ends into a single list
        result = []
        for s, e in zip(starts, ends):
            if e - s < threshold:
                continue
            result.extend([s, e])

        if len(result) > 2:
            result = [result[0], result[-1]]

        return result

    def __init__(self, mode=2, vad_window_ms=30, min_region_ms=30, vad_overlap_ratio=0, target_sample_rate=16000, decider_function=None):
        self.vad_window_ms = vad_window_ms
        self.vad_overlap_ratio = vad_overlap_ratio
        self.vad = webrtcvad.Vad(mode)
        self.target_sample_rate = target_sample_rate
        self.decider_function = decider_function
        self.min_region_ms = min_region_ms

    def __call__(self, file_path):
        wave, rate = WebrtcVadLabelMaker.read_wave(file_path, target_sr=self.target_sample_rate)
        if rate != self.target_sample_rate:
            print(f"{file_path} has a rate of {rate} instead of {self.target_sample_rate}")
            return
        window = int(self.vad_window_ms * rate / 1000)
        step = int((1 - self.vad_overlap_ratio) * window)

        samples_count = len(wave) // 2
        samples_pred_sum = np.zeros(len(wave), dtype=np.float32)
        samples_pred_count = np.zeros(len(wave), dtype=np.float32)

        n_frames = int((samples_count - window) / step)
        for i in range(n_frames):
            s = i * step
            e = s + window
            is_speech = self.vad.is_speech(wave[2 * s:2 * e], rate)
            samples_pred_sum[s:e] += is_speech
            samples_pred_count[s:e] += 1

        samples_pred = (samples_pred_sum / (samples_pred_count + 1e-8)) >= 0.5
        ones_regions = self.find_ones_regions(samples_pred.astype(np.int32), threshold=self.min_region_ms * rate / 1000)

        return ones_regions

In [None]:
class SileroVadLabelMaker:
    def __init__(self, sample_rate=8000):
        self.model, utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
        (self.get_speech_timestamps, _, self.read_audio, _, _) = utils
        self.sample_rate = sample_rate
    
    def __call__(self, file_path):
        wav = self.read_audio(file_path, sampling_rate=self.sample_rate)
        speech_timestamps = self.get_speech_timestamps(
        wav,
        self.model,
        return_seconds=False,
        sampling_rate=self.sample_rate,
        # min_silence_duration_ms=50,
        # min_speech_duration_ms=150,
        )
        ones_regions = []
        for stamps in speech_timestamps:
            ones_regions.append(stamps['start'])
            ones_regions.append(stamps['end'])
        # if len(ones_regions) > 0:
        #     ones_regions = [ones_regions[0], ones_regions[1]]
        return ones_regions

In [17]:
target_sample_rate = 8000
vad_window_ms = [10, 20, 30][2]
vad_overlap_ratio = 0.9

In [18]:
openSLR_data_directory, ext = 'datasets\google_commands_v2', 'wav'
# openSLR_data_directory, ext = "../data/MSDWild/raw_wav", 'wav'
where_to_save = 'buffer'

# vad = WebrtcVadLabelMaker(
#     mode=2,
#     vad_window_ms=vad_window_ms,
#     vad_overlap_ratio=vad_overlap_ratio, 
#     target_sample_rate=target_sample_rate,
#     min_region_ms=60)
vad = SileroVadLabelMaker()

audio_files_paths = get_files_by_extension(openSLR_data_directory, ext=ext, rel=True)

# labels_path = f'{vad.target_sample_rate}_{vad.vad_window_ms}_{int(vad.vad_overlap_ratio * 100)}_webrtc_labels.csv'
labels_path = f'{vad.sample_rate}_silerovad_labels.csv'
labels_path = os.path.join(where_to_save, labels_path)
os.makedirs(where_to_save, exist_ok=True)
data_samples = len(audio_files_paths)
print(data_samples, "files like:", np.random.choice(audio_files_paths))
print(labels_path)

Using cache found in C:\Users\narek/.cache\torch\hub\snakers4_silero-vad_master


17665 files like: down\627c0bec_nohash_2.wav
buffer\8000_silerovad_labels.csv


In [19]:
if data_samples > 0:
    with open(labels_path, 'w') as file:
        file.write("filename,labels" + '\n')

        t = tqdm(audio_files_paths, total=data_samples)
        webrtcvad_t, write_t = 0, 0
        ma = 0.8
        for i, audio_path in enumerate(t):
            s_vad = time()
            filepath = os.path.join(openSLR_data_directory, audio_path)
            one_stamps = vad(filepath)
            if one_stamps is None:
                continue
            e_vad = time()
            path_parts = audio_path.split(os.sep)
            filename = path_parts[-1]


            file.write(audio_path + ',' + '-'.join(map(str, one_stamps)) + '\n')
            e_write = time()

            webrtcvad_t = ma * webrtcvad_t + (1 - ma) * (e_vad - s_vad)
            write_t = ma * webrtcvad_t + (1 - ma) * (e_write - e_vad)
            if i % 100 == 0:
                t.set_description_str(f"webrtcvad: {webrtcvad_t * 1000:.1f}ms | write: {write_t * 1000:.1f}ms")

else:
    print(len(audio_files_paths), "audio files not found")

webrtcvad: 34.8ms | write: 27.8ms:  50%|█████     | 8857/17665 [04:15<04:14, 34.65it/s]


KeyboardInterrupt: 

In [8]:
df = pd.read_csv(labels_path)

problematics = df[df.isnull().any(axis=1)].filename.values.tolist()
problematics = set(os.path.basename(problematic).split('_')[0] for problematic in problematics)
len(problematics)

mask = df['filename'].apply(lambda x: any(p in x for p in problematics))

df = df[~mask]
df.to_csv(labels_path.replace('.csv', '_filtered.csv'), index=False)