Skip to content

Commit

Permalink
add batch test
Browse files Browse the repository at this point in the history
  • Loading branch information
caillonantoine committed May 11, 2023
1 parent f389f18 commit 2bf45fe
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 12 deletions.
15 changes: 9 additions & 6 deletions extras/generate_test_model.py
@@ -1,7 +1,10 @@
import torch, torch.nn as nn
import nn_tilde
from typing import List, Tuple

import torch
import torch.nn as nn

import nn_tilde


class AudioUtils(nn_tilde.Module):

Expand Down Expand Up @@ -169,7 +172,7 @@ def get_fractal(self) -> Tuple[int, float]:
# return -1 if the attribute was wrong.
@torch.jit.export
def set_gain_factor(self, x: float) -> int:
self.gain_factor= (x,)
self.gain_factor = (x, )
return 0

@torch.jit.export
Expand All @@ -182,14 +185,14 @@ def set_polynomial_factors(self, factor1: float, factor2: float,
@torch.jit.export
def set_saturate_mode(self, x: str):
if (x == 'tanh') or (x == 'clip'):
self.saturate_mode = (x,)
self.saturate_mode = (x, )
return 0
else:
return -1

@torch.jit.export
def set_invert_signal(self, x: bool):
self.invert_signal = (x,)
self.invert_signal = (x, )
return 0

@torch.jit.export
Expand All @@ -198,7 +201,7 @@ def set_fractal(self, factor: int, amount: float):
return -1
elif factor % 2 != 0:
return -1
self.fractal = ( factor,float(amount))
self.fractal = (factor, float(amount))
return 0


Expand Down
13 changes: 9 additions & 4 deletions python_tools/__init__.py
@@ -1,8 +1,8 @@
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import torch
import logging
import inspect
import logging
from typing import Any, Callable, Optional, Sequence, Tuple, Union

import torch

TYPE_HASH = {bool: 0, int: 1, float: 2, str: 3, torch.Tensor: 4}

Expand Down Expand Up @@ -80,6 +80,9 @@ def register_method(
raise ValueError(
("Output tensor must have exactly 3 dimensions, "
f"got {len(y.shape)}"))
if y.shape[0] != 1:
raise ValueError(
f"Expecting single batch output, got {y.shape[0]}")
if y.shape[1] != out_channels:
raise ValueError((
f"Wrong number of output channels for method \"{method_name}\", "
Expand All @@ -89,7 +92,7 @@ def register_method(
(f"Wrong output length for method \"{method_name}\", "
f"expected {test_buffer_size//out_ratio} "
f"got {y.shape[2]}"))

logging.info(f"Testing method {method_name} with mc.nn~ API")
x = torch.zeros(4, in_channels, test_buffer_size // in_ratio)
y = getattr(self, method_name)(x)
Expand All @@ -98,6 +101,8 @@ def register_method(
raise ValueError(
("Output tensor must have exactly 3 dimensions, "
f"got {len(y.shape)}"))
if y.shape[0] != 4:
raise ValueError(f"Expecting 4 batch output, got {y.shape[0]}")
if y.shape[1] != out_channels:
raise ValueError((
f"Wrong number of output channels for method \"{method_name}\", "
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Expand Up @@ -21,12 +21,12 @@
long_description=readme,
long_description_content_type="text/markdown",
packages=['nn_tilde'],
package_dir= {'nn_tilde': 'python_tools'},
package_dir={'nn_tilde': 'python_tools'},
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
],
install_requires=requirements.split("\n"),
python_requires='>=3.7',
)
)

0 comments on commit 2bf45fe

Please sign in to comment.