Vocal tract length perturbation #139

Closed · wants to merge 3 commits
18 changes: 18 additions & 0 deletions tests/test_vtlp.py
@@ -0,0 +1,18 @@
import numpy as np
import torch

from torch_audiomentations import VocalTractLengthPerturbation


class TestVTLP:
def test_vtlp(self):
samples = torch.randn(3, 1, 32000)
sample_rate = 16000

augment = VocalTractLengthPerturbation(p=1.0, output_type="dict")
processed_samples = augment(
samples=samples, sample_rate=sample_rate
).samples.numpy()

assert processed_samples.shape == samples.shape
assert processed_samples.dtype == np.float32
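
For context, a minimal usage sketch (not part of this diff) with illustrative parameter values, assuming 16 kHz mono audio; the warp-factor range and probability below are assumptions, not values taken from the PR:

import torch
from torch_audiomentations import VocalTractLengthPerturbation

# Hypothetical example: warp the frequency axis by a factor drawn from [0.85, 1.15]
# for roughly half of the examples in each batch.
augment = VocalTractLengthPerturbation(
    min_warp_factor=0.85,
    max_warp_factor=1.15,
    p=0.5,
    output_type="dict",
)

audio = torch.randn(8, 1, 32000)  # (batch_size, num_channels=1, num_samples)
perturbed = augment(samples=audio, sample_rate=16000).samples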
1 change: 1 addition & 0 deletions torch_audiomentations/__init__.py
@@ -15,6 +15,7 @@
from .augmentations.shift import Shift
from .augmentations.shuffle_channels import ShuffleChannels
from .augmentations.time_inversion import TimeInversion
from .augmentations.vtlp import VocalTractLengthPerturbation
from .core.composition import Compose, SomeOf, OneOf
from .utils.config import from_dict, from_yaml
from .utils.convolution import convolve
158 changes: 158 additions & 0 deletions torch_audiomentations/augmentations/vtlp.py
@@ -0,0 +1,158 @@
import torch
import torchaudio.transforms as T
from torch import Tensor
from typing import Optional

from ..core.transforms_interface import BaseWaveformTransform
from ..utils.object_dict import ObjectDict


class VocalTractLengthPerturbation(BaseWaveformTransform):
"""
Apply Vocal Tract Length Perturbation (VTLP) as defined in
http://www.cs.toronto.edu/~hinton/absps/perturb.pdf

In plain terms: the waveform is turned into a spectrogram, the frequency axis is warped
by a random factor (imitating a speaker with a shorter or longer vocal tract), and the
warped spectrogram is turned back into a waveform.
Collaborator review comment: It would be nice if you could explain a bit more what this transform actually does, in an "explain it like I'm five" fashion, so that the average developer (including me) can understand.

"""
supported_modes = {"per_example"}

supports_multichannel = False
requires_sample_rate = True

supports_target = True
requires_target = False

def __init__(
self,
min_warp_factor: float = 0.9,
max_warp_factor: float = 1.1,
n_fft: int = 1024,
hop_length: int = 256,
mode: str = "per_example",
p: float = 0.5,
p_mode: Optional[str] = None,
target_rate: Optional[int] = None,
output_type: Optional[str] = None,
):
"""
:param min_warp_factor: The minimum warp factor to use.
:param max_warp_factor: The maximum warp factor to use.
:param n_fft: The number of FFT bins to use for the STFT.
:param hop_length: The hop length to use for the STFT.
:param mode: ``per_example`` is the only mode supported by this transform.
:param p: The probability of applying this transform to a given example.
:param p_mode: The mode used for applying the probability ``p`` (defaults to the value of ``mode``).
:param target_rate: The sample rate of the targets, if targets are given.
:param output_type: ``"dict"`` or ``"tensor"``; determines the return type of the transform.
"""
super().__init__(
mode=mode,
p=p,
p_mode=p_mode,
target_rate=target_rate,
output_type=output_type,
)
if min_warp_factor >= max_warp_factor:
raise ValueError("max_warp_factor must be > min_warp_factor")

self.min_warp_factor = min_warp_factor
self.max_warp_factor = max_warp_factor
self.n_fft = n_fft
self.hop_length = hop_length

@classmethod
def get_scale_factors(
cls,
n_freqs: int,
warp_factor: Tensor,
sample_rate: int,
fhi: int = 4800
) -> Tensor:

factors = []
freqs = torch.linspace(0, 1, n_freqs)

# Piecewise-linear warping from the VTLP paper: frequencies below f_boundary are scaled
# by warp_factor; frequencies above it are mapped linearly so that the Nyquist frequency
# is preserved.
f_boundary = fhi * min(warp_factor, 1) / warp_factor
nyquist = sample_rate / 2
scale = fhi * min(warp_factor, 1)

for f in freqs:
f *= sample_rate
if f <= f_boundary:
factors.append(f * warp_factor)
else:
warp_freq = nyquist - (nyquist - scale) / (nyquist - scale / warp_factor) * (nyquist - f)
factors.append(warp_freq)

factors = torch.FloatTensor(factors)
factors *= (n_freqs - 1) / torch.max(factors) # normalize so the warped positions stay within [0, n_freqs - 1]

return factors

def randomize_parameters(
self,
samples: Tensor = None,
sample_rate: Optional[int] = None,
targets: Optional[Tensor] = None,
target_rate: Optional[int] = None,
):
dist = torch.distributions.Uniform(
low=torch.tensor(self.min_warp_factor, dtype=torch.float32, device=samples.device),
high=torch.tensor(self.max_warp_factor, dtype=torch.float32, device=samples.device),
validate_args=True,
)
self.transform_parameters['warp_factor'] = dist.sample()

def apply_transform(
self,
samples: Tensor,
sample_rate: int,
targets: Optional[Tensor] = None,
target_rate: Optional[int] = None
) -> ObjectDict:

batch_size, num_channels, num_samples = samples.shape
assert num_channels == 1, "Only single channel audio is supported"

n_to_pad = self.hop_length - (num_samples % self.hop_length) # pad so the length is a multiple of the hop length
padded_samples = torch.nn.functional.pad(samples, (0, n_to_pad), 'constant', 0.)

original_spect = torch.stft(
padded_samples.squeeze(1),
n_fft=self.n_fft,
hop_length=self.hop_length,
return_complex=True,
)
n_freqs = original_spect.size(-2)
transformed_spect = torch.zeros_like(original_spect)

warp_factors = self.get_scale_factors(
n_freqs,
self.transform_parameters['warp_factor'],
sample_rate,
).to(samples.device)

# apply the warp to the spectrogram: each frequency bin is redistributed to its warped
# (fractional) position, split linearly between the two nearest destination bins
for i in range(n_freqs):
if i == 0 or i + 1 >= n_freqs:
transformed_spect[:, i, :] = original_spect[:, i, :]
else:
warp_up = warp_factors[i] - torch.floor(warp_factors[i])
warp_down = 1. - warp_up
pos = int(torch.floor(warp_factors[i]))

transformed_spect[:, pos, :] += warp_down * original_spect[:, i, :]
transformed_spect[:, pos + 1, :] += warp_up * original_spect[:, i, :]

transformed_samples = torch.istft(
transformed_spect,
n_fft=self.n_fft,
hop_length=self.hop_length,
)[:, :-n_to_pad]
transformed_samples = transformed_samples.unsqueeze(1)

return ObjectDict(
samples=transformed_samples,
sample_rate=sample_rate,
targets=targets,
target_rate=target_rate,
)
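
To make the warping in get_scale_factors easier to picture, here is a standalone sketch of the same piecewise-linear frequency mapping. vtlp_warp is a hypothetical helper written only for illustration (it is not part of this PR), and the 16 kHz sample rate is an assumption:

import torch

def vtlp_warp(freqs_hz: torch.Tensor, warp_factor: float, sample_rate: int, fhi: float = 4800.0) -> torch.Tensor:
    # Frequencies up to f_boundary are scaled by warp_factor; the remaining band is
    # mapped linearly so that the Nyquist frequency maps to itself.
    nyquist = sample_rate / 2
    scale = fhi * min(warp_factor, 1.0)
    f_boundary = scale / warp_factor
    return torch.where(
        freqs_hz <= f_boundary,
        freqs_hz * warp_factor,
        nyquist - (nyquist - scale) / (nyquist - scale / warp_factor) * (nyquist - freqs_hz),
    )

freqs = torch.linspace(0, 8000, steps=9)  # 0 Hz .. Nyquist for 16 kHz audio
print(vtlp_warp(freqs, warp_factor=1.1, sample_rate=16000))
# Low frequencies are stretched by 1.1; the mapping bends at f_boundary (about 4.4 kHz here)
# so that 8000 Hz stays fixed.

The transform itself then converts these warped frequencies to fractional STFT bin indices and redistributes each bin's energy between the two nearest destination bins, as in the loop in apply_transform.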