Skip to content

Commit

Permalink
Bugfix adaptive spec augment time masking (#2398)
Browse files Browse the repository at this point in the history
* bugfix adaptive spec augment

Signed-off-by: smajumdar <titu1994@gmail.com>

* Revert freq mask guard

Signed-off-by: smajumdar <titu1994@gmail.com>

* Revert freq mask guard

Signed-off-by: smajumdar <titu1994@gmail.com>

* Remove static time width clamping

Signed-off-by: smajumdar <titu1994@gmail.com>
  • Loading branch information
titu1994 authored Jun 24, 2021
1 parent 5525f69 commit c01740f
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 30 deletions.
8 changes: 4 additions & 4 deletions nemo/collections/asr/modules/audio_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,7 +434,7 @@ def input_types(self):
"""
return {
"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType(), optional=True),
"length": NeuralType(tuple('B'), LengthsType()),
}

@property
Expand Down Expand Up @@ -474,7 +474,7 @@ def __init__(
mask_value=mask_value,
)
else:
self.spec_augment = lambda input_spec: input_spec
self.spec_augment = lambda input_spec, length: input_spec

# Check if numba is supported, and use a Numba kernel if it is
if use_numba_spec_augment and numba_utils.numba_cuda_is_supported(__NUMBA_MINIMUM_VERSION__):
Expand All @@ -490,15 +490,15 @@ def __init__(
self.spec_augment_numba = None

@typecheck()
def forward(self, input_spec, length):
    """Apply cutout followed by SpecAugment masking to a batch of spectrograms.

    Args:
        input_spec: spectrogram batch of shape (B, D, T).
        length: per-example valid lengths, shape (B,).

    Returns:
        The augmented spectrogram batch, same shape as ``input_spec``.
    """
    augmented_spec = self.spec_cutout(input_spec=input_spec)

    # Prefer the Numba CUDA kernel when it is present and the launch
    # heuristics approve (correct numba version, tensor on GPU, length
    # provided); otherwise fall back to the pure-PyTorch implementation.
    use_numba = self.spec_augment_numba is not None and spec_augment_launch_heuristics(
        augmented_spec, length
    )
    if use_numba:
        return self.spec_augment_numba(input_spec=augmented_spec, length=length)
    return self.spec_augment(input_spec=augmented_spec, length=length)


Expand Down
33 changes: 24 additions & 9 deletions nemo/collections/asr/parts/numba/spec_augment/spec_aug_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,11 +251,6 @@ def forward(self, input_spec, length):
sh = input_spec.shape
bs = sh[0]

if self.adaptive_temporal_width:
time_width = max(1, int(sh[2] * self.time_width))
else:
time_width = self.time_width

# Construct the freq and time masks as well as start positions
if self.freq_masks > 0:
freq_starts = torch.randint(
Expand All @@ -267,10 +262,30 @@ def forward(self, input_spec, length):
freq_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device)

if self.time_masks > 0:
    # Per-sample mask width: in adaptive mode the width scales with each
    # example's valid length (at least 1 frame); otherwise a fixed width
    # is broadcast to every example in the batch.
    if self.adaptive_temporal_width:
        time_width = (length * self.time_width).int().clamp(min=1)
    else:
        time_width = (
            torch.tensor(self.time_width, dtype=torch.int32, device=input_spec.device)
            .unsqueeze(0)
            .repeat(sh[0])
        )

    # Sample mask start positions and mask lengths independently for each
    # example, bounded by that example's valid length (not the padded T).
    time_starts = []
    time_lengths = []
    for idx in range(sh[0]):
        time_starts.append(
            torch.randint(
                0, max(1, length[idx] - time_width[idx]), size=[1, self.time_masks], device=input_spec.device
            )
        )
        time_lengths.append(
            torch.randint(0, time_width[idx] + 1, size=[1, self.time_masks], device=input_spec.device)
        )

    # BUGFIX: concatenate the sampled *starts* into time_starts. The
    # previous code concatenated time_lengths into both tensors, silently
    # discarding the sampled start positions.
    time_starts = torch.cat(time_starts, 0)
    time_lengths = torch.cat(time_lengths, 0)

else:
    # No time masking requested: emit zero-length masks at position 0 so
    # downstream kernels can treat the batch uniformly.
    time_starts = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device)
    time_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=input_spec.device)
Expand Down
23 changes: 13 additions & 10 deletions nemo/collections/asr/parts/submodules/spectr_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import torch.nn as nn

from nemo.core.classes import Typing, typecheck
from nemo.core.neural_types import NeuralType, SpectrogramType
from nemo.core.neural_types import LengthsType, NeuralType, SpectrogramType


class SpecAugment(nn.Module, Typing):
Expand All @@ -43,7 +43,10 @@ class SpecAugment(nn.Module, Typing):
def input_types(self):
"""Returns definitions of module input types
"""
return {"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType())}
return {
"input_spec": NeuralType(('B', 'D', 'T'), SpectrogramType()),
"length": NeuralType(tuple('B'), LengthsType()),
}

@property
def output_types(self):
Expand All @@ -54,7 +57,7 @@ def output_types(self):
def __init__(
self, freq_masks=0, time_masks=0, freq_width=10, time_width=10, rng=None, mask_value=0.0,
):
super(SpecAugment, self).__init__()
super().__init__()

self._rng = random.Random() if rng is None else rng

Expand All @@ -76,14 +79,9 @@ def __init__(

@typecheck()
@torch.no_grad()
def forward(self, input_spec):
def forward(self, input_spec, length):
sh = input_spec.shape

if self.adaptive_temporal_width:
time_width = max(1, int(sh[2] * self.time_width))
else:
time_width = self.time_width

for idx in range(sh[0]):
for i in range(self.freq_masks):
x_left = self._rng.randint(0, sh[1] - self.freq_width)
Expand All @@ -93,7 +91,12 @@ def forward(self, input_spec):
input_spec[idx, x_left : x_left + w, :] = self.mask_value

for i in range(self.time_masks):
y_left = self._rng.randint(0, sh[2] - time_width)
if self.adaptive_temporal_width:
time_width = max(1, int(length[idx] * self.time_width))
else:
time_width = self.time_width

y_left = self._rng.randint(0, max(1, length[idx] - time_width))

w = self._rng.randint(0, time_width)

Expand Down
30 changes: 24 additions & 6 deletions tests/collections/asr/numba/spec_augment/test_spec_aug_numba.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,7 @@ def prepare_data(b, f, t, device='cuda', freq_masks=0, time_masks=0, freq_width=

adaptive_temporal_width = True

if adaptive_temporal_width:
time_width = max(1, int(sh[2] * time_width))
else:
time_width = time_width
orginal_time_width = time_width

# Construct the freq and time masks as well as start positions
if freq_masks > 0:
Expand All @@ -71,8 +68,29 @@ def prepare_data(b, f, t, device='cuda', freq_masks=0, time_masks=0, freq_width=
freq_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=x.device)

if time_masks > 0:
    # Per-sample mask width: adaptive mode scales with each example's
    # valid length x_len (min 1 frame); otherwise broadcast the fixed
    # width to the whole batch.
    # NOTE(review): 'orginal_time_width' is the (misspelled) name assigned
    # earlier in this function; kept as-is so this hunk stands alone.
    if adaptive_temporal_width:
        time_width = (x_len * orginal_time_width).int().clamp(min=1)
    else:
        time_width = (
            torch.tensor(orginal_time_width, dtype=torch.int32, device=x.device)
            .unsqueeze(0)
            .repeat(sh[0])
        )

    # Sample start positions and mask lengths per example, bounded by the
    # example's valid length rather than the padded time dimension.
    time_starts = []
    time_lengths = []
    for idx in range(sh[0]):
        time_starts.append(
            torch.randint(
                0, max(1, x_len[idx] - time_width[idx]), size=[1, time_masks], device=x.device
            )
        )
        time_lengths.append(
            torch.randint(0, time_width[idx] + 1, size=[1, time_masks], device=x.device)
        )

    # BUGFIX: concatenate the sampled *starts*; the previous code
    # concatenated time_lengths into both tensors, discarding the starts.
    time_starts = torch.cat(time_starts, 0)
    time_lengths = torch.cat(time_lengths, 0)
else:
    # No time masking: zero-length masks so shapes stay uniform.
    time_starts = torch.zeros([bs, 1], dtype=torch.int64, device=x.device)
    time_lengths = torch.zeros([bs, 1], dtype=torch.int64, device=x.device)
Expand Down
2 changes: 1 addition & 1 deletion tests/collections/asr/test_asr_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def test_SpectrogramAugmentationr(self):
input_signal = torch.randn(size=(4, 512))
length = torch.randint(low=161, high=500, size=[4])
res0 = instance0(input_signal=input_signal, length=length)
res = instance1(input_spec=res0[0])
res = instance1(input_spec=res0[0], length=length)

assert res.shape == res0[0].shape

Expand Down

0 comments on commit c01740f

Please sign in to comment.