Skip to content

Commit

Permalink
Fix return type inconsistencies in composition classes
Browse files Browse the repository at this point in the history
  • Loading branch information
iver56 committed Jun 29, 2022
1 parent 0a7e1eb commit 0963d45
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 6 deletions.
27 changes: 26 additions & 1 deletion tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,32 @@


class TestCompose(unittest.TestCase):
def test_compose(self):
def test_compose_without_specifying_output_type(self):
samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32)
sample_rate = 16000

augment = Compose(
[
Gain(min_gain_in_db=-6.000001, max_gain_in_db=-6, p=1.0),
PolarityInversion(p=1.0),
]
)
processed_samples = augment(
samples=torch.from_numpy(samples), sample_rate=sample_rate
)
# This dtype should be torch.Tensor until we switch to ObjectDict as default
assert type(processed_samples) == torch.Tensor
processed_samples = processed_samples.numpy()
expected_factor = -convert_decibels_to_amplitude_ratio(-6)
assert_almost_equal(
processed_samples,
expected_factor
* np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32),
decimal=6,
)
self.assertEqual(processed_samples.dtype, np.float32)

def test_compose_dict(self):
samples = np.array([[[1.0, 0.5, -0.25, -0.125, 0.0]]], dtype=np.float32)
sample_rate = 16000

Expand Down
37 changes: 37 additions & 0 deletions tests/test_one_of.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
import unittest

import torch

from torch_audiomentations import PolarityInversion, PeakNormalization, Gain, OneOf
from torch_audiomentations.utils.object_dict import ObjectDict


class TestOneOf(unittest.TestCase):
def setUp(self):
self.sample_rate = 16000
self.audio = torch.randn(1, 1, 16000)

self.transforms = [
Gain(min_gain_in_db=-6.000001, max_gain_in_db=-2, p=1.0),
PolarityInversion(p=1.0),
PeakNormalization(p=1.0),
]

def test_one_of_without_specifying_output_type(self):
augment = OneOf(self.transforms)

self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet
output = augment(
samples=self.audio, sample_rate=self.sample_rate
)
# This dtype should be torch.Tensor until we switch to ObjectDict by default
assert type(output) == torch.Tensor

def test_one_of_dict(self):
augment = OneOf(self.transforms, output_type="dict")

self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet
output = augment(
samples=self.audio, sample_rate=self.sample_rate
)
assert type(output) == ObjectDict
19 changes: 16 additions & 3 deletions tests/test_someof.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from torch_audiomentations import PolarityInversion, PeakNormalization, Gain
from torch_audiomentations import SomeOf
from torch_audiomentations.utils.object_dict import ObjectDict


class TestSomeOf(unittest.TestCase):
Expand All @@ -19,13 +20,25 @@ def setUp(self):
PeakNormalization(p=1.0),
]

def test_someof(self):
def test_someof_without_specifying_output_type(self):
augment = SomeOf(2, self.transforms)

self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet
output = augment(
samples=self.audio, sample_rate=self.sample_rate
)
# This dtype should be torch.Tensor until we switch to ObjectDict by default
assert type(output) == torch.Tensor
self.assertEqual(len(augment.transform_indexes), 2) # 2 transforms applied

def test_someof_dict(self):
augment = SomeOf(2, self.transforms, output_type="dict")

self.assertEqual(len(augment.transform_indexes), 0) # no transforms applied yet
processed_samples = augment(
output = augment(
samples=self.audio, sample_rate=self.sample_rate
).samples
)
assert type(output) == ObjectDict
self.assertEqual(len(augment.transform_indexes), 2) # 2 transforms applied

def test_someof_with_p_zero(self):
Expand Down
4 changes: 2 additions & 2 deletions torch_audiomentations/core/composition.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def forward(
# FIXME: do we really want to support regular nn.Module?
inputs.samples = self.transforms[i](inputs.samples)

return inputs
return inputs.samples if self.output_type == "tensor" else inputs


class SomeOf(BaseCompose):
Expand Down Expand Up @@ -214,7 +214,7 @@ def forward(
# FIXME: do we really want to support regular nn.Module?
inputs.samples = self.transforms[i](inputs.samples)

return inputs
return inputs.samples if self.output_type == "tensor" else inputs


class OneOf(SomeOf):
Expand Down

0 comments on commit 0963d45

Please sign in to comment.