In [None]:
import wave
from time import time
import torch

import os
from tqdm import tqdm
import numpy as np
import webrtcvad
import contextlib
import soundfile as sf

from other.utils import get_files_by_extension

In [None]:
# Sample rate (Hz) handed to the VAD label maker constructed further down.
target_sample_rate = 16_000
# WebRTC-only settings, needed only when WebrtcVadLabelMaker is used:
# vad_window_ms = [10, 20, 30][2]
# vad_overlap_ratio = 0.9

In [None]:
class WebrtcVadLabelMaker:
    """Label speech regions of an audio file with the WebRTC voice-activity detector.

    The signal is scanned with windows of ``vad_window_ms`` milliseconds that
    overlap by ``vad_overlap_ratio``. Each sample is labelled speech when at
    least half of the windows covering it were classified as speech; samples
    not covered by any full window are labelled non-speech. Calling an instance
    returns a flat list ``[start_0, end_0, start_1, end_1, ...]`` of inclusive
    sample indices, or ``None`` when the file's sample rate differs from
    ``target_sample_rate``.
    """

    # Sample rates accepted by read_wave.
    _SUPPORTED_RATES = (8000, 16000, 32000, 48000)

    @staticmethod
    def read_wave(path):
        """Load a mono, 16-bit audio file as raw PCM bytes.

        Args:
            path: Path to a ``.wav`` or ``.flac`` file (extension matched
                case-insensitively).

        Returns:
            ``(pcm_data, sample_rate)`` where ``pcm_data`` holds 2 bytes per sample.

        Raises:
            AssertionError: if the file is not mono, not 16-bit (wav), or its
                sample rate is not in ``_SUPPORTED_RATES``.
            ValueError: if the extension is neither ``.wav`` nor ``.flac``.
        """
        ext = os.path.splitext(path)[1].lower()
        rates = WebrtcVadLabelMaker._SUPPORTED_RATES
        if ext == '.wav':
            with contextlib.closing(wave.open(path, 'rb')) as wf:
                assert wf.getnchannels() == 1, f"{path}: expected mono audio"
                assert wf.getsampwidth() == 2, f"{path}: expected 16-bit samples"
                sample_rate = wf.getframerate()
                assert sample_rate in rates, f"{path}: unsupported sample rate {sample_rate}"
                pcm_data = wf.readframes(wf.getnframes())
            return pcm_data, sample_rate
        if ext == '.flac':
            with sf.SoundFile(path, "r") as flac_file:
                sample_rate = flac_file.samplerate
                assert sample_rate in rates, f"{path}: unsupported sample rate {sample_rate}"
                assert flac_file.channels == 1, f"{path}: expected mono audio"
                # Read as int16 so every sample occupies 2 bytes, like the wav branch.
                pcm_data = flac_file.read(dtype="int16").tobytes()
            return pcm_data, sample_rate
        # Previously fell through and returned None, which failed later on unpacking.
        raise ValueError(f"Unsupported audio extension {ext!r} for {path}")

    @staticmethod
    def find_ones_regions(arr):
        """Return the boundaries of every run of ones in a 0/1 sequence.

        Args:
            arr: 1-D sequence of 0/1 (or bool) values.

        Returns:
            Flat list ``[start_0, end_0, start_1, end_1, ...]`` of inclusive
            indices as Python ints; empty when ``arr`` is empty or has no ones.

        Raises:
            ValueError: if the detected run starts and ends do not pair up.
        """
        # Cast to a signed int so np.diff also works for bool input.
        arr = np.asarray(arr).astype(np.int64, copy=False)
        if arr.size == 0:
            return []
        diff = np.diff(arr)
        starts = np.flatnonzero(diff == 1) + 1  # +1: diff[k] compares arr[k+1] with arr[k]
        ends = np.flatnonzero(diff == -1)

        # Runs touching either edge have no transition there.
        if arr[0] == 1:
            starts = np.insert(starts, 0, 0)
        if arr[-1] == 1:
            ends = np.append(ends, arr.size - 1)

        if len(starts) != len(ends):
            raise ValueError("Mismatch between starts and ends")

        # Interleave as [s0, e0, s1, e1, ...].
        return [int(idx) for pair in zip(starts, ends) for idx in pair]

    def __init__(self, mode=2, vad_window_ms=30, vad_overlap_ratio=0, target_sample_rate=16000, decider_function=None):
        """
        Args:
            mode: Mode passed to ``webrtcvad.Vad`` (its aggressiveness setting).
            vad_window_ms: Length in milliseconds of each window sent to the VAD.
                NOTE(review): webrtcvad only accepts 10, 20 or 30 ms frames.
            vad_overlap_ratio: Fraction of a window shared with the next one (0 = no overlap).
            target_sample_rate: Files with any other sample rate are skipped.
            decider_function: Stored on the instance; not used by this class's methods.
        """
        self.vad_window_ms = vad_window_ms
        self.vad_overlap_ratio = vad_overlap_ratio
        self.vad = webrtcvad.Vad(mode)
        self.target_sample_rate = target_sample_rate
        self.decider_function = decider_function

    def __call__(self, file_path):
        """Return inclusive ``[start, end, ...]`` sample indices of speech regions in ``file_path``.

        Returns ``None`` (after printing a message) when the file's sample rate
        differs from ``self.target_sample_rate``.
        """
        pcm, rate = WebrtcVadLabelMaker.read_wave(file_path)
        if rate != self.target_sample_rate:
            print(f"{file_path} has a rate of {rate} instead of {self.target_sample_rate}")
            return
        window = int(self.vad_window_ms * rate / 1000)  # samples per VAD frame
        # Guard against a step of 0 (overlap ratio of 1), which would never advance.
        step = max(1, int((1 - self.vad_overlap_ratio) * window))

        samples_count = len(pcm) // 2  # 16-bit PCM: 2 bytes per sample
        # Sized in samples (not bytes) so every entry can be covered by a window.
        pred_sum = np.zeros(samples_count, dtype=np.float32)
        pred_count = np.zeros(samples_count, dtype=np.float32)

        # Number of full windows that fit; the last one may end exactly at samples_count.
        n_frames = (samples_count - window) // step + 1 if samples_count >= window else 0
        for i in range(n_frames):
            s = i * step
            e = s + window
            is_speech = self.vad.is_speech(pcm[2 * s:2 * e], rate)
            pred_sum[s:e] += is_speech
            pred_count[s:e] += 1

        # Mean vote per sample. Samples outside every window get 0 (non-speech)
        # instead of 0/0 = NaN with a RuntimeWarning.
        vote = np.divide(pred_sum, pred_count, out=np.zeros_like(pred_sum), where=pred_count > 0)
        samples_pred = vote >= 0.5

        return self.find_ones_regions(samples_pred.astype(np.int32))

In [None]:
class SileroVadLabelMaker:
    """Label speech regions of an audio file with the Silero VAD model from torch.hub.

    Calling an instance returns a flat list ``[start_0, end_0, start_1, end_1, ...]``
    built from the ``'start'``/``'end'`` entries of the timestamps returned by
    Silero's ``get_speech_timestamps`` (requested with ``return_seconds=False``).
    """

    def __init__(self, sample_rate=8000):
        """
        Args:
            sample_rate: Rate passed to both ``read_audio`` and ``get_speech_timestamps``.
        """
        # Downloads/loads the model from the hub repo; returns (model, utils tuple).
        self.model, hub_utils = torch.hub.load(repo_or_dir='snakers4/silero-vad', model='silero_vad')
        # Only two of the five helpers in the utils tuple are needed here.
        self.get_speech_timestamps, _, self.read_audio, _, _ = hub_utils
        self.sample_rate = sample_rate

    def __call__(self, file_path):
        """Run the VAD on ``file_path`` and return the flattened start/end boundaries."""
        audio = self.read_audio(file_path, sampling_rate=self.sample_rate)
        segments = self.get_speech_timestamps(
            audio,
            self.model,
            return_seconds=False,
            sampling_rate=self.sample_rate,
        )
        return [
            boundary
            for segment in segments
            for boundary in (segment['start'], segment['end'])
        ]

In [None]:
# Dataset location (relative to the notebook) and the audio extension to collect.
# os.path.join instead of a backslash literal so the path also works on POSIX systems.
openSLR_data_directory = os.path.join('accent-dataset', 'common_accent')
ext = 'wav'
where_to_save = 'buffer'

# Alternative backend (requires vad_window_ms and vad_overlap_ratio to be defined):
# vad = WebrtcVadLabelMaker(2, vad_window_ms, vad_overlap_ratio, target_sample_rate)
vad = SileroVadLabelMaker(sample_rate=target_sample_rate)

# Keep only files whose relative path contains 'clean'.
audio_files_paths = [
    f for f in get_files_by_extension(openSLR_data_directory, ext=ext, rel=True)
    if 'clean' in f
]

# The output CSV name encodes the sample rate and the VAD backend.
# WebRTC alternative:
# labels_path = f'{vad.target_sample_rate}_{vad.vad_window_ms}_{int(vad.vad_overlap_ratio * 100)}_webrtc_labels.csv'
labels_path = os.path.join(where_to_save, f'{vad.sample_rate}_silero_labels.csv')

data_samples = len(audio_files_paths)
if data_samples:
    # np.random.choice raises ValueError on an empty list, so only sample an example when files exist.
    print(data_samples, "files like:", np.random.choice(audio_files_paths))
else:
    print(data_samples, "files found in", openSLR_data_directory)
print(labels_path)

In [None]:
if data_samples > 0:
    # open() below fails if the output directory does not exist yet.
    os.makedirs(os.path.dirname(labels_path) or '.', exist_ok=True)
    # NOTE: append mode — re-running this cell appends duplicate rows to an existing labels file.
    with open(labels_path, 'a') as file:
        # file.write("filename,labels" + '\n')

        progress = tqdm(audio_files_paths, total=data_samples)
        # Exponential moving averages (seconds) of per-file VAD time and CSV write time.
        vad_t, write_t = 0.0, 0.0
        ma = 0.8  # EMA weight given to the previous average
        for i, audio_path in enumerate(progress):
            s_vad = time()
            one_stamps = vad(os.path.join(openSLR_data_directory, audio_path))
            if one_stamps is None:
                # Label maker skipped the file (WebrtcVadLabelMaker does this on a sample-rate mismatch).
                continue
            e_vad = time()

            # One CSV row per file: relative path, then the '-'-joined region boundaries.
            file.write(audio_path + ',' + '-'.join(map(str, one_stamps)) + '\n')
            e_write = time()

            vad_t = ma * vad_t + (1 - ma) * (e_vad - s_vad)
            # Fixed: previously smoothed from the VAD average instead of write_t.
            write_t = ma * write_t + (1 - ma) * (e_write - e_vad)
            if i % 100 == 0:
                progress.set_description_str(f"vad: {vad_t * 1000:.1f}ms | write: {write_t * 1000:.1f}ms")

else:
    print(len(audio_files_paths), "audio files not found")