diff --git a/extras/generate_test_model.py b/extras/generate_test_model.py index 7c81b00..39578c7 100644 --- a/extras/generate_test_model.py +++ b/extras/generate_test_model.py @@ -1,7 +1,10 @@ -import torch, torch.nn as nn -import nn_tilde from typing import List, Tuple +import torch +import torch.nn as nn + +import nn_tilde + class AudioUtils(nn_tilde.Module): @@ -169,7 +172,7 @@ def get_fractal(self) -> Tuple[int, float]: # return -1 if the attribute was wrong. @torch.jit.export def set_gain_factor(self, x: float) -> int: - self.gain_factor= (x,) + self.gain_factor = (x, ) return 0 @torch.jit.export @@ -182,14 +185,14 @@ def set_polynomial_factors(self, factor1: float, factor2: float, @torch.jit.export def set_saturate_mode(self, x: str): if (x == 'tanh') or (x == 'clip'): - self.saturate_mode = (x,) + self.saturate_mode = (x, ) return 0 else: return -1 @torch.jit.export def set_invert_signal(self, x: bool): - self.invert_signal = (x,) + self.invert_signal = (x, ) return 0 @torch.jit.export @@ -198,7 +201,7 @@ def set_fractal(self, factor: int, amount: float): return -1 elif factor % 2 != 0: return -1 - self.fractal = ( factor,float(amount)) + self.fractal = (factor, float(amount)) return 0 diff --git a/python_tools/__init__.py b/python_tools/__init__.py index 22c48f9..b79fe8d 100644 --- a/python_tools/__init__.py +++ b/python_tools/__init__.py @@ -1,8 +1,8 @@ -from typing import Any, Callable, Optional, Sequence, Tuple, Union -import torch -import logging import inspect +import logging +from typing import Any, Callable, Optional, Sequence, Tuple, Union +import torch TYPE_HASH = {bool: 0, int: 1, float: 2, str: 3, torch.Tensor: 4} @@ -80,6 +80,9 @@ def register_method( raise ValueError( ("Output tensor must have exactly 3 dimensions, " f"got {len(y.shape)}")) + if y.shape[0] != 1: + raise ValueError( + f"Expecting single batch output, got {y.shape[0]}") if y.shape[1] != out_channels: raise ValueError(( f"Wrong number of output channels for method \"{method_name}\", " @@ -89,7 +92,7 @@ def register_method( (f"Wrong output length for method \"{method_name}\", " f"expected {test_buffer_size//out_ratio} " f"got {y.shape[2]}")) - + logging.info(f"Testing method {method_name} with mc.nn~ API") x = torch.zeros(4, in_channels, test_buffer_size // in_ratio) y = getattr(self, method_name)(x) @@ -98,6 +101,8 @@ def register_method( raise ValueError( ("Output tensor must have exactly 3 dimensions, " f"got {len(y.shape)}")) + if y.shape[0] != 4: + raise ValueError(f"Expecting 4 batch output, got {y.shape[0]}") if y.shape[1] != out_channels: raise ValueError(( f"Wrong number of output channels for method \"{method_name}\", " diff --git a/setup.py b/setup.py index 09bddf7..55ffd32 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,7 @@ long_description=readme, long_description_content_type="text/markdown", packages=['nn_tilde'], - package_dir= {'nn_tilde': 'python_tools'}, + package_dir={'nn_tilde': 'python_tools'}, classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -29,4 +29,4 @@ ], install_requires=requirements.split("\n"), python_requires='>=3.7', -) \ No newline at end of file +)