Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update perturb.py #5231

Merged
merged 9 commits into from
Oct 26, 2022
42 changes: 38 additions & 4 deletions nemo/collections/asr/parts/preprocessing/perturb.py
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,11 @@ def get_one_noise_sample(self, target_sr):
)

def perturb(self, data, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
ref_mic (int): reference mic index for scaling multi-channel audios
"""
noise = read_one_audiosegment(
self._manifest,
data.sample_rate,
Expand All @@ -442,6 +447,18 @@ def perturb(self, data, ref_mic=0):
self.perturb_with_input_noise(data, noise, ref_mic=ref_mic)

def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
noise (AudioSegment): noise data
data_rms (Union[float, List[float]]): rms_db for data input
ref_mic (int): reference mic index for scaling multi-channel audios
"""
if data.num_channels != noise.num_channels:
raise ValueError(
f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})."
)

snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
if data_rms is None:
data_rms = data.rms_db
Expand All @@ -467,14 +484,31 @@ def perturb_with_input_noise(self, data, noise, data_rms=None, ref_mic=0):
else:
data._samples += noise._samples

def perturb_with_foreground_noise(
self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1,
):
def perturb_with_foreground_noise(self, data, noise, data_rms=None, max_noise_dur=2, max_additions=1, ref_mic=0):
"""
Args:
data (AudioSegment): audio data
noise (AudioSegment): noise data
data_rms (Union[float, List[float]]): rms_db for data input
max_noise_dur (float): max noise duration
max_additions (int): upper bound on the number of times noise is added
ref_mic (int): reference mic index for scaling multi-channel audios
"""
if data.num_channels != noise.num_channels:
raise ValueError(
f"Found mismatched channels for data ({data.num_channels}) and noise ({noise.num_channels})."
)

snr_db = self._rng.uniform(self._min_snr_db, self._max_snr_db)
if not data_rms:
data_rms = data.rms_db

noise_gain_db = min(data_rms - noise.rms_db - snr_db, self._max_gain_db)
if data.num_channels > 1:
noise_gain_db = data_rms[ref_mic] - noise.rms_db[ref_mic] - snr_db
else:
noise_gain_db = data_rms - noise.rms_db - snr_db
noise_gain_db = min(noise_gain_db, self._max_gain_db)

n_additions = self._rng.randint(1, max_additions)

for i in range(n_additions):
Expand Down