FilterbankFeaturesTA to match FilterbankFeatures #5913

Merged 2 commits on Feb 3, 2023
44 changes: 22 additions & 22 deletions nemo/collections/asr/parts/preprocessing/features.py
@@ -529,22 +529,22 @@ def __init__(
         if window not in self.torch_windows:
             raise ValueError(f"Got window value '{window}' but expected a member of {self.torch_windows.keys()}")
 
-        self._win_length = n_window_size
-        self._hop_length = n_window_stride
+        self.win_length = n_window_size
+        self.hop_length = n_window_stride
         self._sample_rate = sample_rate
         self._normalize_strategy = normalize
         self._use_log = log
         self._preemphasis_value = preemph
-        self._log_zero_guard_type = log_zero_guard_type
-        self._log_zero_guard_value: Union[str, float] = log_zero_guard_value
-        self._dither_value = dither
-        self._pad_to = pad_to
-        self._pad_value = pad_value
-        self._num_fft = n_fft
+        self.log_zero_guard_type = log_zero_guard_type
+        self.log_zero_guard_value: Union[str, float] = log_zero_guard_value
+        self.dither = dither
+        self.pad_to = pad_to
+        self.pad_value = pad_value
+        self.n_fft = n_fft
         self._mel_spec_extractor: torchaudio.transforms.MelSpectrogram = torchaudio.transforms.MelSpectrogram(
             sample_rate=self._sample_rate,
-            win_length=self._win_length,
-            hop_length=self._hop_length,
+            win_length=self.win_length,
+            hop_length=self.hop_length,
             n_mels=nfilt,
             window_fn=self.torch_windows[window],
             mel_scale="slaney",
@@ -561,13 +561,13 @@ def filter_banks(self):
         return self._mel_spec_extractor.mel_scale.fb
 
     def _resolve_log_zero_guard_value(self, dtype: torch.dtype) -> float:
-        if isinstance(self._log_zero_guard_value, float):
-            return self._log_zero_guard_value
-        return getattr(torch.finfo(dtype), self._log_zero_guard_value)
+        if isinstance(self.log_zero_guard_value, float):
+            return self.log_zero_guard_value
+        return getattr(torch.finfo(dtype), self.log_zero_guard_value)
 
     def _apply_dithering(self, signals: torch.Tensor) -> torch.Tensor:
-        if self.training and self._dither_value > 0.0:
-            noise = torch.randn_like(signals) * self._dither_value
+        if self.training and self.dither > 0.0:
+            noise = torch.randn_like(signals) * self.dither
             signals = signals + noise
         return signals
 
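Side note (not part of the diff): the `_resolve_log_zero_guard_value` helper above accepts either a float or the name of a `torch.finfo` field such as `"tiny"` or `"eps"`, which it looks up against the feature dtype. A minimal sketch of that lookup, assuming a recent PyTorch:

```python
import torch

# A string guard value is resolved via torch.finfo for the features' dtype,
# so the guard scales with the numeric precision actually in use.
for dtype in (torch.float32, torch.float16):
    print(dtype, getattr(torch.finfo(dtype), "tiny"), getattr(torch.finfo(dtype), "eps"))
# float32: tiny ~1.18e-38, eps ~1.19e-07
# float16: tiny ~6.10e-05, eps ~9.77e-04
```
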
@@ -578,25 +578,25 @@ def _apply_preemphasis(self, signals: torch.Tensor) -> torch.Tensor:
         return signals
 
     def _compute_output_lengths(self, input_lengths: torch.Tensor) -> torch.Tensor:
-        out_lengths = input_lengths.div(self._hop_length, rounding_mode="floor").add(1).long()
+        out_lengths = input_lengths.div(self.hop_length, rounding_mode="floor").add(1).long()
         return out_lengths
 
     def _apply_pad_to(self, features: torch.Tensor) -> torch.Tensor:
         # Only apply during training; else need to capture dynamic shape for exported models
-        if not self.training or self._pad_to == 0 or features.shape[-1] % self._pad_to == 0:
+        if not self.training or self.pad_to == 0 or features.shape[-1] % self.pad_to == 0:
             return features
-        pad_length = self._pad_to - (features.shape[-1] % self._pad_to)
-        return torch.nn.functional.pad(features, pad=(0, pad_length), value=self._pad_value)
+        pad_length = self.pad_to - (features.shape[-1] % self.pad_to)
+        return torch.nn.functional.pad(features, pad=(0, pad_length), value=self.pad_value)
 
     def _apply_log(self, features: torch.Tensor) -> torch.Tensor:
         if self._use_log:
             zero_guard = self._resolve_log_zero_guard_value(features.dtype)
-            if self._log_zero_guard_type == "add":
+            if self.log_zero_guard_type == "add":
                 features = features + zero_guard
-            elif self._log_zero_guard_type == "clamp":
+            elif self.log_zero_guard_type == "clamp":
                 features = features.clamp(min=zero_guard)
             else:
-                raise ValueError(f"Unsupported log zero guard type: '{self._log_zero_guard_type}'")
+                raise ValueError(f"Unsupported log zero guard type: '{self.log_zero_guard_type}'")
             features = features.log()
         return features

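For readers skimming the diff: the net effect is that `FilterbankFeaturesTA` now exposes the same public attribute names as `FilterbankFeatures` (`win_length`, `hop_length`, `n_fft`, `dither`, `pad_to`, `pad_value`, `log_zero_guard_type`, `log_zero_guard_value`), so downstream code can duck-type on either featurizer. A minimal sketch of that usage; the constructor keyword arguments are assumptions inferred from this file, so check your NeMo version for the exact signatures:

```python
from nemo.collections.asr.parts.preprocessing.features import (
    FilterbankFeatures,
    FilterbankFeaturesTA,
)


def expected_frames(featurizer, num_samples: int) -> int:
    # Relies only on the shared public attribute; before this PR the
    # torchaudio-based class kept it private as `_hop_length`.
    return num_samples // featurizer.hop_length + 1


# Hypothetical usage: both classes are assumed to accept these keyword arguments.
for cls in (FilterbankFeatures, FilterbankFeaturesTA):
    feat = cls(sample_rate=16000, n_window_size=400, n_window_stride=160, nfilt=80)
    print(cls.__name__, feat.hop_length, expected_frames(feat, 16000))
```
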