Skip to content

Commit

Permalink
Ensure exported models output float tensors
Browse files Browse the repository at this point in the history
  • Loading branch information
hugofloresgarcia committed Jun 19, 2023
1 parent e144b5d commit 8c6fa75
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions python_tools/__init__.py
Expand Up @@ -93,6 +93,8 @@ def register_method(
(f"Wrong output length for method \"{method_name}\", "
f"expected {test_buffer_size//out_ratio} "
f"got {y.shape[2]}"))
if y.dtype != torch.float:
raise ValueError(f"Output tensor must be of type float, got {y.dtype}")

if cc.MAX_BATCH_SIZE > 1:
logging.info(f"Testing method {method_name} with mc.nn~ API")
Expand Down

0 comments on commit 8c6fa75

Please sign in to comment.