From 4d78066a434c877fe768d431324ce7cc01b8ef79 Mon Sep 17 00:00:00 2001
From: Mohamed Saad Ibn Seddik
Date: Fri, 3 Feb 2023 09:48:29 -0500
Subject: [PATCH] FilterbankFeaturesTA to match FilterbankFeatures

Signed-off-by: Mohamed Saad Ibn Seddik
---
 .../asr/parts/preprocessing/features.py | 44 +++++++++----------
 1 file changed, 22 insertions(+), 22 deletions(-)

diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py
index 1a33bca05a47..00452978f554 100644
--- a/nemo/collections/asr/parts/preprocessing/features.py
+++ b/nemo/collections/asr/parts/preprocessing/features.py
@@ -529,22 +529,22 @@ def __init__(
         if window not in self.torch_windows:
             raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")
 
-        self._win_length = n_window_size
-        self._hop_length = n_window_stride
+        self.win_length = n_window_size
+        self.hop_length = n_window_stride
         self._sample_rate = sample_rate
         self._normalize_strategy = normalize
         self._use_log = log
         self._preemphasis_value = preemph
-        self._log_zero_guard_type = log_zero_guard_type
-        self._log_zero_guard_value: Union[str, float] = log_zero_guard_value
-        self._dither_value = dither
-        self._pad_to = pad_to
-        self._pad_value = pad_value
-        self._num_fft = n_fft
+        self.log_zero_guard_type = log_zero_guard_type
+        self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
+        self.dither = dither
+        self.pad_to = pad_to
+        self.pad_value = pad_value
+        self.n_fft = n_fft
         self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
             sample_rate=self._sample_rate,
-            win_length=self._win_length,
-            hop_length=self._hop_length,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
             n_mels=nfilt,
             window_fn=self.torch_windows[window],
             mel_scale="slaney",
@@ -561,13 +561,13 @@ def filter_banks(self):
         return self._mel_spec_extractor.mel_scale.fb
 
     def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
-        if isinstance(self._log_zero_guard_value, float):
-            return self._log_zero_guard_value
-        return getattr(torch.finfo(dtype), self._log_zero_guard_value)
+        if isinstance(self.log_zero_guard_value, float):
+            return self.log_zero_guard_value
+        return getattr(torch.finfo(dtype), self.log_zero_guard_value)
 
     def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
-        if self.training and self._dither_value > 0.0:
-            noise = torch.randn_like(signals) * self._dither_value
+        if self.training and self.dither > 0.0:
+            noise = torch.randn_like(signals) * self.dither
             signals = signals + noise
         return signals
 
@@ -578,25 +578,25 @@ def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
         return signals
 
     def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
-        out_lengths = input_lengths.div(self._hop_length, rounding_mode="floor").add(1).long()
+        out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
         return out_lengths
 
     def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
         # Only apply during training; else need to capture dynamic shape for exported models
-        if not self.training or self._pad_to == 0 or features.shape[-1] % self._pad_to == 0:
+        if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
             return features
-        pad_length = self._pad_to - (features.shape[-1] % self._pad_to)
-        return torch.nn.functional.pad(features, pad=(0, pad_length), value=self._pad_value)
+        pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
+        return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)
 
     def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
         if self._use_log:
             zero_guard = self._resolve_log_zero_guard_value(features.dtype)
-            if self._log_zero_guard_type == "add":
+            if self.log_zero_guard_type == "add":
                 features = features + zero_guard
-            elif self._log_zero_guard_type == "clamp":
+            elif self.log_zero_guard_type == "clamp":
                 features = features.clamp(min=zero_guard)
             else:
-                raise ValueError(f"Unsupported log zero guard type: '{self._log_zero_guard_type}'")
+                raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
             features = features.log()
         return features
 
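
Note (not part of the patch): a minimal sketch of why the rename matters, assuming FilterbankFeatures already exposes these names as public attributes. After this change, downstream code can read the same framing parameters from either featurizer class; the hypothetical helper below mirrors the patched _compute_output_lengths (frames = samples // hop_length + 1).

    # Hypothetical usage, not from the patch: works with FilterbankFeatures or
    # FilterbankFeaturesTA now that both expose `hop_length` publicly.
    def num_frames(featurizer, num_samples: int) -> int:
        return num_samples // featurizer.hop_length + 1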