diff --git a/README.md b/README.md
index f0790f1..2bf49ab 100644
--- a/README.md
+++ b/README.md
@@ -34,9 +34,9 @@ Currently supported and tested models from [onnx_zoo](https://github.com/onnx/mo
 
 ## Limitations
 Known current version limitations are:
-- `batch_size > 1` could deliver unexpected results due to ambiguity of onnx's BatchNorm layer.
-That is why in this case for now we raise an assertion error.
-Set `experimental=True` in `ConvertModel` to be able to use `batch_size > 1`.
+- `batch_size > 1` is now supported by default.
+BatchNorm layers use inference mode (running statistics), which is correct for ONNX models
+exported for inference.
 - Fine tuning and training of converted models was not tested yet, only inference.
 
 ## Development
diff --git a/onnx2pytorch/convert/model.py b/onnx2pytorch/convert/model.py
index ea16ff1..827960c 100644
--- a/onnx2pytorch/convert/model.py
+++ b/onnx2pytorch/convert/model.py
@@ -96,8 +96,7 @@ def __init__(
         batch_dim: int
             Dimension of the batch.
         experimental: bool
-            Experimental implementation allows batch_size > 1. However,
-            batchnorm layers could potentially produce false outputs.
+            Deprecated; this flag no longer has any effect.
             Default: False
         enable_pruning: bool
             Track kept/pruned indices between different calls to forward pass.
@@ -143,12 +142,6 @@ def __init__(
             self.onnx_model.graph, self, self.mapping
         )
 
-        if experimental:
-            warnings.warn(
-                "Using experimental implementation that allows 'batch_size > 1'."
-                "Batchnorm layers could potentially produce false outputs."
-            )
-
     def forward(self, *input_list, **input_dict):
         if len(input_list) > 0 and len(input_dict) > 0:
             raise ValueError(
@@ -160,10 +153,6 @@ def forward(self, *input_list, **input_dict):
         if len(input_dict) > 0:
             inputs = [input_dict[key] for key in self.input_names]
 
-        if not self.experimental and inputs[0].shape[self.batch_dim] > 1:
-            raise NotImplementedError(
-                "Input with larger batch size than 1 not supported yet."
-            )
 
         activations = dict(zip(self.input_names, inputs))
         still_needed_by = deepcopy(self.needed_by)
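For reviewers, a minimal usage sketch of the converter after this change (not part of the patch; the model path and input shape below are placeholders): batched inputs now run through `ConvertModel` directly, and the `experimental` flag can simply be omitted.

```python
import onnx
import torch

from onnx2pytorch.convert import ConvertModel

# "model.onnx" and the NCHW shape are illustrative placeholders.
onnx_model = onnx.load("model.onnx")
converted = ConvertModel(onnx_model)  # no experimental=True needed anymore

batch = torch.randn(4, 3, 224, 224)  # batch_size > 1 is accepted by default
with torch.no_grad():
    output = converted(batch)
```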
diff --git a/onnx2pytorch/convert/operations.py b/onnx2pytorch/convert/operations.py
index 66d6869..dbd0577 100644
--- a/onnx2pytorch/convert/operations.py
+++ b/onnx2pytorch/convert/operations.py
@@ -232,6 +232,10 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
             op = partial(torch.prod, **kwargs)
         elif node.op_type == "ReduceSum":
             op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
+        elif node.op_type == "ReduceSumSquare":
+            op = ReduceSumSquare(
+                opset_version=opset_version, **extract_attributes(node)
+            )
         elif node.op_type == "ReduceL2":
             op = ReduceL2(opset_version=opset_version, **extract_attributes(node))
         elif node.op_type == "Relu":
@@ -266,9 +270,16 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
             kwargs = dict(dim=-1)
             kwargs.update(extract_attributes(node))
             op = nn.Softmax(**kwargs)
+        elif node.op_type == "LogSoftmax":
+            kwargs = dict(dim=-1)
+            kwargs.update(extract_attributes(node))
+            op = nn.LogSoftmax(**kwargs)
         elif node.op_type == "Softplus":
+            # ONNX Softplus has no attributes: y = ln(exp(x) + 1)
+            # PyTorch Softplus with beta=1 matches ONNX spec
             op = nn.Softplus(beta=1)
         elif node.op_type == "Softsign":
+            # ONNX Softsign has no attributes: y = x / (1 + |x|)
             op = nn.Softsign()
         elif node.op_type == "Split":
             kwargs = extract_attributes(node)
diff --git a/onnx2pytorch/operations/__init__.py b/onnx2pytorch/operations/__init__.py
index 6c25656..9459e01 100644
--- a/onnx2pytorch/operations/__init__.py
+++ b/onnx2pytorch/operations/__init__.py
@@ -30,6 +30,7 @@
 from .randomuniformlike import RandomUniformLike
 from .reducemax import ReduceMax
 from .reducesum import ReduceSum
+from .reducesumsquare import ReduceSumSquare
 from .reducel2 import ReduceL2
 from .reshape import Reshape
 from .resize import Resize, Upsample
@@ -80,6 +81,7 @@
     "RandomUniformLike",
     "ReduceMax",
     "ReduceSum",
+    "ReduceSumSquare",
     "ReduceL2",
     "Reshape",
     "Resize",
diff --git a/onnx2pytorch/operations/batchnorm.py b/onnx2pytorch/operations/batchnorm.py
index e4ba779..087f195 100644
--- a/onnx2pytorch/operations/batchnorm.py
+++ b/onnx2pytorch/operations/batchnorm.py
@@ -10,7 +10,6 @@
     class _LazyBatchNorm(_LazyNormBase, _BatchNorm):
         cls_to_become = _BatchNorm
 
-
 except ImportError:
     # for torch < 1.10.0
     from torch.nn.modules.batchnorm import _LazyBatchNorm
@@ -49,6 +48,9 @@ def __init__(self, torch_params, *args, **kwargs):
         for key, value in zip(keys, torch_params):
             getattr(self.bnu, key).data = value
 
+        # Set to eval mode to use running statistics (ONNX inference behavior)
+        self.bnu.eval()
+
     def forward(self, X, scale=None, B=None, input_mean=None, input_var=None):
         if self.has_lazy:
             self.bnu.initialize_parameters(X)
@@ -56,7 +58,7 @@ def forward(self, X, scale=None, B=None, input_mean=None, input_var=None):
         if scale is not None:
             getattr(self.bnu, "weight").data = scale
         if B is not None:
-            getattr(self.bnu, "bias").data = scale
+            getattr(self.bnu, "bias").data = B
         if input_mean is not None:
             getattr(self.bnu, "running_mean").data = input_mean
         if input_var is not None:
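Why forcing eval mode fixes `batch_size > 1`: in inference mode, ONNX BatchNormalization normalizes with the stored running statistics rather than statistics of the current batch, so the result does not depend on the batch size. A small sketch using only standard PyTorch (illustrative, not part of the patch) that checks the closed-form formula against `torch.nn.functional.batch_norm` with `training=False`:

```python
import torch
import torch.nn.functional as F

# Inference-mode BatchNorm as ONNX defines it: statistics come from the stored
# running mean/variance, never from the current batch, so batch size is irrelevant.
x = torch.randn(4, 3, 5, 5)  # arbitrary batch_size > 1
scale, bias = torch.randn(3), torch.randn(3)
mean, var = torch.randn(3), torch.rand(3) + 0.1
eps = 1e-5

# Y = scale * (X - mean) / sqrt(var + eps) + bias
manual = scale.view(1, -1, 1, 1) * (x - mean.view(1, -1, 1, 1)) / torch.sqrt(
    var.view(1, -1, 1, 1) + eps
) + bias.view(1, -1, 1, 1)
builtin = F.batch_norm(x, mean, var, weight=scale, bias=bias, training=False, eps=eps)

torch.testing.assert_close(manual, builtin)
```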
diff --git a/onnx2pytorch/operations/loop.py b/onnx2pytorch/operations/loop.py
index 80ffe0f..22083fd 100644
--- a/onnx2pytorch/operations/loop.py
+++ b/onnx2pytorch/operations/loop.py
@@ -1,10 +1,7 @@
 from collections import defaultdict
-from copy import deepcopy
 from functools import partial
 from importlib import import_module
-import warnings
 
-import numpy as np
 import onnx
 import torch
 from onnx import numpy_helper
@@ -71,7 +68,6 @@ def forward(self, enclosing_modules, enclosing_activations, *inputs):
         """
         N = len(self.input_names) - 2
-        K = len(self.output_names) - (1 + N)
 
         M = inputs[0]
         cond = inputs[1]
 
diff --git a/onnx2pytorch/operations/reducesumsquare.py b/onnx2pytorch/operations/reducesumsquare.py
new file mode 100644
index 0000000..29b2b8b
--- /dev/null
+++ b/onnx2pytorch/operations/reducesumsquare.py
@@ -0,0 +1,43 @@
+import torch
+from torch import nn
+
+
+class ReduceSumSquare(nn.Module):
+    """
+    Computes the sum of the squared elements of the input tensor along the provided axes.
+
+    Equivalent to ReduceSum(Square(data), axes, keepdim).
+    """
+
+    def __init__(
+        self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
+    ):
+        self.opset_version = opset_version
+        self.dim = dim
+        self.keepdim = bool(keepdim)
+        self.noop_with_empty_axes = noop_with_empty_axes
+        super().__init__()
+
+    def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
+        # In opset < 13, axes is an attribute (self.dim)
+        # In opset >= 13, axes is an optional input
+        if self.opset_version < 13:
+            dims = self.dim
+        else:
+            dims = axes
+
+        if dims is None:
+            if self.noop_with_empty_axes:
+                return data
+            else:
+                # Reduce over all dimensions
+                dims = tuple(range(data.ndim))
+
+        if isinstance(dims, int):
+            dim = dims
+        else:
+            dim = tuple(list(dims))
+
+        # Compute sum of squares: sum(x^2)
+        ret = torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim)
+        return ret
diff --git a/tests/onnx2pytorch/operations/test_batchnorm.py b/tests/onnx2pytorch/operations/test_batchnorm.py
new file mode 100644
index 0000000..e1985bb
--- /dev/null
+++ b/tests/onnx2pytorch/operations/test_batchnorm.py
@@ -0,0 +1,307 @@
+import numpy as np
+import onnxruntime as ort
+import pytest
+import torch
+from onnx import helper, TensorProto
+
+from onnx2pytorch.convert import ConvertModel
+from onnx2pytorch.operations.batchnorm import BatchNormWrapper
+
+
+@pytest.mark.parametrize(
+    "batch_size,channels,height,width,epsilon,momentum",
+    [
+        # Test with batch_size=1
+        (1, 3, 5, 5, 1e-5, 0.9),
+        # Test with batch_size>1 (the critical case)
+        (2, 3, 5, 5, 1e-5, 0.9),
+        (4, 3, 5, 5, 1e-5, 0.9),
+        (8, 16, 7, 7, 1e-5, 0.9),
+        # Test with different epsilons
+        (2, 3, 5, 5, 1e-3, 0.9),
+        (2, 3, 5, 5, 1e-7, 0.9),
+        # Test with different momentums
+        (2, 3, 5, 5, 1e-5, 0.1),
+        (2, 3, 5, 5, 1e-5, 0.99),
+        # Test with different spatial dimensions
+        (2, 8, 10, 10, 1e-5, 0.9),
+        (2, 16, 3, 3, 1e-5, 0.9),
+    ],
+)
+def test_batchnorm_onnxruntime(batch_size, channels, height, width, epsilon, momentum):
+    """Test BatchNorm against onnxruntime with various batch sizes."""
+    np.random.seed(42)
+    torch.manual_seed(42)
+
+    # Create input
+    X = np.random.randn(batch_size, channels, height, width).astype(np.float32)
+
+    # Create BatchNorm parameters
+    scale = np.random.randn(channels).astype(np.float32)
+    bias = np.random.randn(channels).astype(np.float32)
+    mean = np.random.randn(channels).astype(np.float32)
+    var = np.abs(np.random.randn(channels).astype(np.float32)) + 0.1  # Ensure positive
+
+    # Create ONNX graph with BatchNormalization node
+    input_tensor = helper.make_tensor_value_info(
+        "X", TensorProto.FLOAT, [batch_size, channels, height, width]
+    )
+    output_tensor = helper.make_tensor_value_info(
+        "Y", TensorProto.FLOAT, [batch_size, channels, height, width]
+    )
+
+    scale_init = helper.make_tensor(
+        "scale", TensorProto.FLOAT, [channels], scale.tolist()
+    )
+    bias_init =
helper.make_tensor("B", TensorProto.FLOAT, [channels], bias.tolist()) + mean_init = helper.make_tensor("mean", TensorProto.FLOAT, [channels], mean.tolist()) + var_init = helper.make_tensor("var", TensorProto.FLOAT, [channels], var.tolist()) + + bn_node = helper.make_node( + "BatchNormalization", + inputs=["X", "scale", "B", "mean", "var"], + outputs=["Y"], + epsilon=epsilon, + momentum=momentum, + ) + + graph = helper.make_graph( + [bn_node], + "batchnorm_test", + [input_tensor], + [output_tensor], + [scale_init, bias_init, mean_init, var_init], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch and run + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Compare outputs + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + msg=f"BatchNorm mismatch for batch_size={batch_size}, channels={channels}", + ) + + +def test_batchnorm_bias_fix(): + """Test that the bias parameter is correctly applied (not overwritten by scale).""" + np.random.seed(42) + + batch_size = 2 + channels = 4 + height, width = 5, 5 + + X = np.random.randn(batch_size, channels, height, width).astype(np.float32) + + # Create BatchNorm parameters with distinct scale and bias + scale = np.ones(channels, dtype=np.float32) * 2.0 # Scale = 2 + bias = np.ones(channels, dtype=np.float32) * 5.0 # Bias = 5 (should NOT be 2!) + mean = np.zeros(channels, dtype=np.float32) + var = np.ones(channels, dtype=np.float32) + + # Create ONNX model + input_tensor = helper.make_tensor_value_info( + "X", TensorProto.FLOAT, [batch_size, channels, height, width] + ) + output_tensor = helper.make_tensor_value_info( + "Y", TensorProto.FLOAT, [batch_size, channels, height, width] + ) + + scale_init = helper.make_tensor( + "scale", TensorProto.FLOAT, [channels], scale.tolist() + ) + bias_init = helper.make_tensor("B", TensorProto.FLOAT, [channels], bias.tolist()) + mean_init = helper.make_tensor("mean", TensorProto.FLOAT, [channels], mean.tolist()) + var_init = helper.make_tensor("var", TensorProto.FLOAT, [channels], var.tolist()) + + bn_node = helper.make_node( + "BatchNormalization", + inputs=["X", "scale", "B", "mean", "var"], + outputs=["Y"], + epsilon=1e-5, + ) + + graph = helper.make_graph( + [bn_node], + "batchnorm_bias_test", + [input_tensor], + [output_tensor], + [scale_init, bias_init, mean_init, var_init], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8 + ) + + # Run with onnxruntime (ground truth) + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # If bias was incorrectly set to scale (the bug), outputs would differ + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + msg="Bias parameter was not correctly applied", + ) + + # Verify that the output includes the bias (should be around 5, not 2) + # After normalization: (X - 0) / sqrt(1 + eps) * 2 + 5 ≈ X * 2 + 5 + # The mean should be 
around 5 (from bias), not 2 (from scale) + output_mean_per_channel = o2p_output.mean(dim=(0, 2, 3)) + # The mean should be close to bias (5), not scale (2) + # Note: This is approximate since X is random + assert torch.allclose( + output_mean_per_channel, torch.tensor([5.0] * channels), rtol=1, atol=1 + ) + + +def test_batchnorm_eval_mode(): + """Test that BatchNorm uses eval mode (running statistics).""" + + channels = 4 + scale = torch.ones(channels) + bias = torch.zeros(channels) + running_mean = torch.randn(channels) + running_var = torch.abs(torch.randn(channels)) + 0.1 + + # Create BatchNormWrapper + bn_wrapper = BatchNormWrapper([scale, bias, running_mean, running_var]) + + # Verify it's in eval mode + assert not bn_wrapper.bnu.training, "BatchNorm should be in eval mode" + + # Test with batch_size > 1 + X = torch.randn(4, channels, 5, 5) + + output = bn_wrapper(X) + + # In eval mode, it should use running_mean and running_var, + # not compute statistics from the current batch + # Verify output shape + assert output.shape == X.shape + + +def test_batchnorm_formula(): + """Test that BatchNorm implements the correct formula.""" + batch_size = 2 + channels = 3 + height, width = 4, 4 + + X = torch.randn(batch_size, channels, height, width) + + scale = torch.ones(channels) * 2.0 + bias = torch.ones(channels) * 3.0 + mean = torch.zeros(channels) + var = torch.ones(channels) + epsilon = 1e-5 + + # Manual computation: Y = scale * (X - mean) / sqrt(var + epsilon) + bias + expected = scale.view(1, -1, 1, 1) * (X - mean.view(1, -1, 1, 1)) / torch.sqrt( + var.view(1, -1, 1, 1) + epsilon + ) + bias.view(1, -1, 1, 1) + + # Using BatchNormWrapper + + bn_wrapper = BatchNormWrapper([scale, bias, mean, var], eps=epsilon) + output = bn_wrapper(X) + + torch.testing.assert_close(output, expected, rtol=1e-5, atol=1e-5) + + +@pytest.mark.parametrize("batch_size", [1, 2, 4, 8]) +def test_batchnorm_consistency_across_batch_sizes(batch_size): + """Test that BatchNorm produces consistent results across different batch sizes.""" + np.random.seed(42) + torch.manual_seed(42) + + channels = 8 + height, width = 6, 6 + + # Create a deterministic input pattern + X = np.random.randn(batch_size, channels, height, width).astype(np.float32) + + scale = np.random.randn(channels).astype(np.float32) + bias = np.random.randn(channels).astype(np.float32) + mean = np.random.randn(channels).astype(np.float32) + var = np.abs(np.random.randn(channels).astype(np.float32)) + 0.1 + + # Create ONNX model + input_tensor = helper.make_tensor_value_info( + "X", TensorProto.FLOAT, [batch_size, channels, height, width] + ) + output_tensor = helper.make_tensor_value_info( + "Y", TensorProto.FLOAT, [batch_size, channels, height, width] + ) + + scale_init = helper.make_tensor( + "scale", TensorProto.FLOAT, [channels], scale.tolist() + ) + bias_init = helper.make_tensor("B", TensorProto.FLOAT, [channels], bias.tolist()) + mean_init = helper.make_tensor("mean", TensorProto.FLOAT, [channels], mean.tolist()) + var_init = helper.make_tensor("var", TensorProto.FLOAT, [channels], var.tolist()) + + bn_node = helper.make_node( + "BatchNormalization", + inputs=["X", "scale", "B", "mean", "var"], + outputs=["Y"], + epsilon=1e-5, + ) + + graph = helper.make_graph( + [bn_node], + "batchnorm_consistency_test", + [input_tensor], + [output_tensor], + [scale_init, bias_init, mean_init, var_init], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = 
ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Should match onnxruntime regardless of batch size + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + msg=f"BatchNorm failed for batch_size={batch_size}", + ) diff --git a/tests/onnx2pytorch/operations/test_logsoftmax.py b/tests/onnx2pytorch/operations/test_logsoftmax.py new file mode 100644 index 0000000..b36c6cf --- /dev/null +++ b/tests/onnx2pytorch/operations/test_logsoftmax.py @@ -0,0 +1,135 @@ +import numpy as np +import onnxruntime as ort +import pytest +import torch +from onnx import helper, TensorProto + +from onnx2pytorch.convert import ConvertModel + + +@pytest.mark.parametrize( + "axis,input_shape", + [ + (-1, [2, 3, 4]), # Default axis=-1 + (0, [2, 3, 4]), + (1, [2, 3, 4]), + (2, [2, 3, 4]), + (-2, [2, 3, 4]), + (1, [5, 10]), # 2D input + (-1, [8]), # 1D input + ], +) +def test_logsoftmax_onnxruntime(axis, input_shape): + """Test LogSoftmax against onnxruntime.""" + np.random.seed(42) + + # Create input + X = np.random.randn(*input_shape).astype(np.float32) + + # Create ONNX graph with LogSoftmax node + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, input_shape) + + logsoftmax_node = helper.make_node( + "LogSoftmax", + inputs=["X"], + outputs=["Y"], + axis=axis, + ) + + graph = helper.make_graph( + [logsoftmax_node], + "logsoftmax_test", + [input_tensor], + [output_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 13)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch and run + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Compare outputs + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + ) + + +def test_logsoftmax_default_axis(): + """Test LogSoftmax with default axis=-1.""" + np.random.seed(42) + + input_shape = [2, 3, 4] + X = np.random.randn(*input_shape).astype(np.float32) + + # Create ONNX graph WITHOUT specifying axis (should default to -1) + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, input_shape) + + logsoftmax_node = helper.make_node( + "LogSoftmax", + inputs=["X"], + outputs=["Y"], + # No axis specified - should default to -1 + ) + + graph = helper.make_graph( + [logsoftmax_node], + "logsoftmax_test", + [input_tensor], + [output_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 13)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch and run + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Compare outputs + torch.testing.assert_close( + o2p_output, + 
torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + ) + + +def test_logsoftmax_properties(): + """Test mathematical properties of LogSoftmax.""" + # LogSoftmax(x) = log(Softmax(x)) + X = torch.randn(2, 5) + + logsoftmax_output = torch.nn.functional.log_softmax(X, dim=-1) + softmax_output = torch.nn.functional.softmax(X, dim=-1) + log_of_softmax = torch.log(softmax_output) + + torch.testing.assert_close(logsoftmax_output, log_of_softmax, rtol=1e-5, atol=1e-5) + + # Sum of exp(log_softmax) should be 1 + sum_exp = torch.exp(logsoftmax_output).sum(dim=-1) + torch.testing.assert_close(sum_exp, torch.ones_like(sum_exp), rtol=1e-5, atol=1e-5) diff --git a/tests/onnx2pytorch/operations/test_reducesumsquare.py b/tests/onnx2pytorch/operations/test_reducesumsquare.py new file mode 100644 index 0000000..2daa821 --- /dev/null +++ b/tests/onnx2pytorch/operations/test_reducesumsquare.py @@ -0,0 +1,224 @@ +import numpy as np +import onnxruntime as ort +import pytest +import torch +from onnx import helper, TensorProto + +from onnx2pytorch.convert import ConvertModel +from onnx2pytorch.operations.reducesumsquare import ReduceSumSquare + + +@pytest.mark.parametrize( + "input_shape,axes,keepdims", + [ + # Test with different axes + ([3, 4, 5], [0], 1), + ([3, 4, 5], [1], 1), + ([3, 4, 5], [2], 1), + ([3, 4, 5], [-1], 1), + # Test with multiple axes + ([3, 4, 5], [0, 1], 1), + ([3, 4, 5], [1, 2], 1), + ([3, 4, 5], [0, 2], 1), + # Test with keepdims=0 + ([3, 4, 5], [1], 0), + ([3, 4, 5], [0, 2], 0), + # Test with all axes (None means reduce all) + ([3, 4, 5], None, 1), + ([3, 4, 5], None, 0), + # Test 2D inputs + ([5, 10], [0], 1), + ([5, 10], [1], 1), + ([5, 10], None, 1), + # Test 1D inputs + ([10], [0], 1), + ([10], None, 1), + ], +) +def test_reducesumsquare_onnxruntime(input_shape, axes, keepdims): + """Test ReduceSumSquare against onnxruntime.""" + np.random.seed(42) + + # Create input + X = np.random.randn(*input_shape).astype(np.float32) + + # Create ONNX graph with ReduceSumSquare node + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, None) + + # Use axes as attribute (supported in all opset versions) + node_attrs = {"keepdims": keepdims} + if axes is not None: + node_attrs["axes"] = axes + + reducesumsquare_node = helper.make_node( + "ReduceSumSquare", + inputs=["X"], + outputs=["Y"], + **node_attrs, + ) + + graph = helper.make_graph( + [reducesumsquare_node], + "reducesumsquare_test", + [input_tensor], + [output_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch and run + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Compare outputs + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + ) + + +def test_reducesumsquare_formula(): + """Test that ReduceSumSquare implements sum(x^2).""" + X = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + + # Manual computation + expected_all = torch.sum(X**2) + expected_axis0 = torch.sum(X**2, dim=0, keepdim=True) + expected_axis1 = torch.sum(X**2, dim=1, keepdim=True) + + # Test reduce all + + op_all = ReduceSumSquare(opset_version=13, dim=None, 
keepdim=True) + result_all = op_all(X) + torch.testing.assert_close( + result_all, expected_all.view(1, 1), rtol=1e-6, atol=1e-6 + ) + + # Test reduce axis 0 + op_axis0 = ReduceSumSquare(opset_version=11, dim=0, keepdim=True) + result_axis0 = op_axis0(X) + torch.testing.assert_close(result_axis0, expected_axis0, rtol=1e-6, atol=1e-6) + + # Test reduce axis 1 + op_axis1 = ReduceSumSquare(opset_version=11, dim=1, keepdim=True) + result_axis1 = op_axis1(X) + torch.testing.assert_close(result_axis1, expected_axis1, rtol=1e-6, atol=1e-6) + + +def test_reducesumsquare_keepdims(): + """Test keepdims parameter.""" + X = torch.randn(2, 3, 4) + + # With keepdims=True + op_keep = ReduceSumSquare(opset_version=11, dim=1, keepdim=True) + result_keep = op_keep(X) + assert result_keep.shape == (2, 1, 4) + + # With keepdims=False + op_no_keep = ReduceSumSquare(opset_version=11, dim=1, keepdim=False) + result_no_keep = op_no_keep(X) + assert result_no_keep.shape == (2, 4) + + # Values should be the same (just different shapes) + torch.testing.assert_close( + result_keep.squeeze(1), result_no_keep, rtol=1e-6, atol=1e-6 + ) + + +def test_reducesumsquare_noop_with_empty_axes(): + """Test noop_with_empty_axes parameter.""" + X = torch.randn(2, 3, 4) + + # With noop_with_empty_axes=True and no axes, should return input unchanged + op_noop = ReduceSumSquare( + opset_version=13, dim=None, keepdim=True, noop_with_empty_axes=True + ) + result_noop = op_noop(X) + torch.testing.assert_close(result_noop, X, rtol=1e-6, atol=1e-6) + + # With noop_with_empty_axes=False and no axes, should reduce all + op_reduce = ReduceSumSquare( + opset_version=13, dim=None, keepdim=True, noop_with_empty_axes=False + ) + result_reduce = op_reduce(X) + expected = torch.sum(X**2).view(1, 1, 1) + torch.testing.assert_close(result_reduce, expected, rtol=1e-6, atol=1e-6) + + +def test_reducesumsquare_with_axes_input(): + """Test with axes as an input tensor (for frameworks that support it).""" + X = torch.randn(2, 3, 4) + + # Opset 13+ supports axes as input + op = ReduceSumSquare(opset_version=13, dim=None, keepdim=True) + + # Provide axes as a tensor + axes = torch.tensor([0, 2], dtype=torch.int64) + result = op(X, axes) + + # Expected: reduce along axes 0 and 2 + expected = torch.sum(X**2, dim=(0, 2), keepdim=True) + torch.testing.assert_close(result, expected, rtol=1e-6, atol=1e-6) + assert result.shape == (1, 3, 1) + + +def test_reducesumsquare_vs_reducesum_square(): + """Test that ReduceSumSquare(x) == ReduceSum(Square(x)).""" + X = torch.randn(3, 4, 5) + + # ReduceSumSquare + op_sumsquare = ReduceSumSquare(opset_version=11, dim=1, keepdim=True) + result_sumsquare = op_sumsquare(X) + + # ReduceSum(Square(x)) + result_square_sum = torch.sum(X**2, dim=1, keepdim=True) + + torch.testing.assert_close( + result_sumsquare, result_square_sum, rtol=1e-6, atol=1e-6 + ) + + +def test_reducesumsquare_negative_axis(): + """Test with negative axis values.""" + X = torch.randn(2, 3, 4) + + # axis=-1 should be equivalent to axis=2 + op_neg = ReduceSumSquare(opset_version=11, dim=-1, keepdim=True) + result_neg = op_neg(X) + + op_pos = ReduceSumSquare(opset_version=11, dim=2, keepdim=True) + result_pos = op_pos(X) + + torch.testing.assert_close(result_neg, result_pos, rtol=1e-6, atol=1e-6) + + +def test_reducesumsquare_gradient(): + """Test that gradients flow correctly through ReduceSumSquare.""" + X = torch.randn(2, 3, 4, requires_grad=True) + + op = ReduceSumSquare(opset_version=11, dim=1, keepdim=True) + result = op(X) + + # Compute gradient + loss 
= result.sum() + loss.backward() + + # Gradient of sum(x^2) with respect to x is 2x + # After summing along dim=1, gradient should be 2x broadcast along dim=1 + expected_grad = 2 * X + + assert X.grad is not None + torch.testing.assert_close(X.grad, expected_grad, rtol=1e-5, atol=1e-5) diff --git a/tests/onnx2pytorch/operations/test_softplus.py b/tests/onnx2pytorch/operations/test_softplus.py new file mode 100644 index 0000000..bbba7ae --- /dev/null +++ b/tests/onnx2pytorch/operations/test_softplus.py @@ -0,0 +1,127 @@ +import numpy as np +import onnxruntime as ort +import pytest +import torch +from onnx import helper, TensorProto + +from onnx2pytorch.convert import ConvertModel + + +@pytest.mark.parametrize( + "input_shape", + [ + [2, 3, 4], + [5, 10], + [8], + [1, 1, 5, 5], + ], +) +def test_softplus_default_onnxruntime(input_shape): + """Test Softplus with default parameters against onnxruntime.""" + np.random.seed(42) + + # Create input + X = np.random.randn(*input_shape).astype(np.float32) + + # Create ONNX graph with Softplus node (default parameters) + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, input_shape) + + softplus_node = helper.make_node( + "Softplus", + inputs=["X"], + outputs=["Y"], + ) + + graph = helper.make_graph( + [softplus_node], + "softplus_test", + [input_tensor], + [output_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch and run + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Compare outputs + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + ) + + +def test_softplus_properties(): + """Test mathematical properties of Softplus.""" + # Softplus(x) = log(1 + exp(x)) + X = torch.randn(10, 20) + + softplus_output = torch.nn.functional.softplus(X) + manual_output = torch.log(1 + torch.exp(X)) + + # Note: For very large X, exp(X) overflows, so softplus uses approximation + # Compare only for reasonable values + mask = X < 10 + torch.testing.assert_close( + softplus_output[mask], manual_output[mask], rtol=1e-5, atol=1e-5 + ) + + # Softplus should always be positive + assert (softplus_output > 0).all() + + # For large positive x, softplus(x) ≈ x + large_x = torch.tensor([10.0, 20.0, 50.0]) + softplus_large = torch.nn.functional.softplus(large_x) + torch.testing.assert_close(softplus_large, large_x, rtol=1e-2, atol=1e-2) + + # For large negative x, softplus(x) ≈ 0 + small_x = torch.tensor([-10.0, -20.0, -50.0]) + softplus_small = torch.nn.functional.softplus(small_x) + assert (softplus_small < 0.01).all() + + +def test_softplus_vs_relu(): + """Test that Softplus is a smooth approximation of ReLU.""" + X = torch.linspace(-5, 5, 100) + + softplus_output = torch.nn.functional.softplus(X) + relu_output = torch.nn.functional.relu(X) + + # Softplus should be close to ReLU for large positive values + mask = X > 3 + torch.testing.assert_close( + softplus_output[mask], relu_output[mask], rtol=0.1, atol=0.1 + ) + + # Softplus should be smooth (no sharp corner at 0 like ReLU) + # At x=0: softplus(0) = log(2) ≈ 0.693, relu(0) = 0 + softplus_at_zero = 
torch.nn.functional.softplus(torch.tensor([0.0])) + assert abs(softplus_at_zero.item() - 0.693) < 0.01 + + +def test_softplus_gradient(): + """Test that Softplus gradient is sigmoid.""" + # d/dx softplus(x) = sigmoid(x) = 1/(1 + exp(-x)) + X = torch.randn(5, 5, requires_grad=True) + + output = torch.nn.functional.softplus(X) + output.sum().backward() + + # Gradient should be sigmoid(X) + expected_grad = torch.sigmoid(X) + + torch.testing.assert_close(X.grad, expected_grad, rtol=1e-5, atol=1e-5) diff --git a/tests/onnx2pytorch/operations/test_softsign.py b/tests/onnx2pytorch/operations/test_softsign.py new file mode 100644 index 0000000..4188925 --- /dev/null +++ b/tests/onnx2pytorch/operations/test_softsign.py @@ -0,0 +1,166 @@ +import numpy as np +import onnxruntime as ort +import pytest +import torch +from onnx import helper, TensorProto + +from onnx2pytorch.convert import ConvertModel + + +@pytest.mark.parametrize( + "input_shape", + [ + [2, 3, 4], + [5, 10], + [8], + [1, 1, 5, 5], + [3, 3, 3, 3], + ], +) +def test_softsign_onnxruntime(input_shape): + """Test Softsign against onnxruntime.""" + np.random.seed(42) + + # Create input with varied values (positive, negative, zero) + X = ( + np.random.randn(*input_shape).astype(np.float32) * 5 + ) # Scale to get larger values + + # Create ONNX graph with Softsign node + input_tensor = helper.make_tensor_value_info("X", TensorProto.FLOAT, input_shape) + output_tensor = helper.make_tensor_value_info("Y", TensorProto.FLOAT, input_shape) + + softsign_node = helper.make_node( + "Softsign", + inputs=["X"], + outputs=["Y"], + ) + + graph = helper.make_graph( + [softsign_node], + "softsign_test", + [input_tensor], + [output_tensor], + ) + + model = helper.make_model( + graph, opset_imports=[helper.make_opsetid("", 11)], ir_version=8 + ) + + # Run with onnxruntime + ort_session = ort.InferenceSession(model.SerializeToString()) + ort_outputs = ort_session.run(None, {"X": X}) + expected_Y = ort_outputs[0] + + # Convert to PyTorch and run + o2p_model = ConvertModel(model, experimental=True) + X_torch = torch.from_numpy(X) + + with torch.no_grad(): + o2p_output = o2p_model(X_torch) + + # Compare outputs + torch.testing.assert_close( + o2p_output, + torch.from_numpy(expected_Y), + rtol=1e-5, + atol=1e-5, + ) + + +def test_softsign_formula(): + """Test that Softsign implements x / (1 + |x|).""" + X = torch.tensor([-5.0, -2.0, -1.0, 0.0, 1.0, 2.0, 5.0]) + + softsign_output = torch.nn.functional.softsign(X) + manual_output = X / (1 + torch.abs(X)) + + torch.testing.assert_close(softsign_output, manual_output, rtol=1e-6, atol=1e-6) + + +def test_softsign_properties(): + """Test mathematical properties of Softsign.""" + X = torch.linspace(-10, 10, 100) + + softsign_output = torch.nn.functional.softsign(X) + + # Softsign output should be in range (-1, 1) + assert (softsign_output > -1).all() + assert (softsign_output < 1).all() + + # Softsign(0) = 0 + zero_output = torch.nn.functional.softsign(torch.tensor([0.0])) + assert abs(zero_output.item()) < 1e-6 + + # Softsign is odd function: softsign(-x) = -softsign(x) + X_test = torch.tensor([1.0, 2.0, 3.0, 5.0]) + softsign_pos = torch.nn.functional.softsign(X_test) + softsign_neg = torch.nn.functional.softsign(-X_test) + torch.testing.assert_close(softsign_neg, -softsign_pos, rtol=1e-6, atol=1e-6) + + # For large |x|, softsign(x) approaches ±1 + large_pos = torch.nn.functional.softsign(torch.tensor([100.0])) + large_neg = torch.nn.functional.softsign(torch.tensor([-100.0])) + assert abs(large_pos.item() - 1.0) < 0.01 
+ assert abs(large_neg.item() + 1.0) < 0.01 + + +def test_softsign_vs_tanh(): + """Test that Softsign is similar to tanh but with different shape.""" + X = torch.linspace(-5, 5, 100) + + softsign_output = torch.nn.functional.softsign(X) + tanh_output = torch.tanh(X) + + # Both should be in (-1, 1) + assert (softsign_output > -1).all() and (softsign_output < 1).all() + assert (tanh_output > -1).all() and (tanh_output < 1).all() + + # Both should be odd functions passing through origin + # Find the index closest to x=0 + zero_idx = X.abs().argmin() + # At x≈0, both should be close to 0 + assert abs(softsign_output[zero_idx].item()) < 0.1 + assert abs(tanh_output[zero_idx].item()) < 0.1 + + # Softsign approaches asymptotes more slowly than tanh + # At x=5: tanh(5) ≈ 0.9999, softsign(5) = 5/6 ≈ 0.833 + assert abs(tanh_output[-1].item() - 1.0) < abs(softsign_output[-1].item() - 1.0) + + +def test_softsign_gradient(): + """Test Softsign gradient: d/dx softsign(x) = 1 / (1 + |x|)^2.""" + X = torch.tensor([-2.0, -1.0, 0.0, 1.0, 2.0], requires_grad=True) + + output = torch.nn.functional.softsign(X) + output.sum().backward() + + # Expected gradient + expected_grad = 1.0 / (1.0 + torch.abs(X)) ** 2 + + torch.testing.assert_close(X.grad, expected_grad, rtol=1e-5, atol=1e-5) + + +def test_softsign_extreme_values(): + """Test Softsign with extreme input values.""" + # Very large positive (but not so large that floating point precision makes it exactly 1.0) + large_pos = torch.tensor([100.0, 1000.0, 10000.0]) + output_pos = torch.nn.functional.softsign(large_pos) + assert ( + output_pos <= 1.0 + ).all() # Use <= since at extreme values it can be exactly 1.0 + assert (output_pos > 0.99).all() # Should be very close to 1 + + # Very large negative + large_neg = torch.tensor([-100.0, -1000.0, -10000.0]) + output_neg = torch.nn.functional.softsign(large_neg) + assert ( + output_neg >= -1.0 + ).all() # Use >= since at extreme values it can be exactly -1.0 + assert (output_neg < -0.99).all() # Should be very close to -1 + + # Very small (near zero) + small = torch.tensor([1e-6, -1e-6, 1e-10]) + output_small = torch.nn.functional.softsign(small) + # For small x, softsign(x) ≈ x + torch.testing.assert_close(output_small, small, rtol=1e-3, atol=1e-3)
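Finally, a hedged end-to-end sketch of what the new BatchNorm tests exercise in miniature (the network, shapes, and tolerances below are illustrative placeholders, not part of the patch): export a BatchNorm-containing model in eval mode, convert it back with `ConvertModel`, and compare a batched forward pass against the original network.

```python
import io

import onnx
import torch
from torch import nn

from onnx2pytorch.convert import ConvertModel

# Illustrative round trip; the network, shapes, and tolerances are placeholders.
net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).eval()
dummy = torch.randn(1, 3, 16, 16)

buffer = io.BytesIO()
torch.onnx.export(
    net,
    dummy,
    buffer,
    input_names=["X"],
    output_names=["Y"],
    dynamic_axes={"X": {0: "batch"}, "Y": {0: "batch"}},
)
onnx_model = onnx.load_model_from_string(buffer.getvalue())

converted = ConvertModel(onnx_model)
batch = torch.randn(4, 3, 16, 16)  # batch_size > 1 works without the experimental flag
with torch.no_grad():
    torch.testing.assert_close(converted(batch), net(batch), rtol=1e-4, atol=1e-4)
```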