6 changes: 3 additions & 3 deletions README.md
@@ -34,9 +34,9 @@ Currently supported and tested models from [onnx_zoo](https://github.com/onnx/mo

## Limitations
Known current version limitations are:
- `batch_size > 1` could deliver unexpected results due to ambiguity of onnx's BatchNorm layer.
That is why in this case for now we raise an assertion error.
Set `experimental=True` in `ConvertModel` to be able to use `batch_size > 1`.
- `batch_size > 1` is now supported by default.
BatchNorm layers use inference mode (running statistics), which is correct for ONNX models
exported for inference.
- Fine-tuning and training of converted models have not been tested yet, only inference.
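
A minimal usage sketch of the new behavior described above (the file name and input shape are placeholders, not part of this PR): a converted model can now be fed a batch larger than 1 directly, with BatchNorm layers relying on their stored running statistics.

```python
# Illustrative sketch: batch_size > 1 now works without the `experimental` flag.
# "resnet18.onnx" and the input shape are placeholders, not from this PR.
import onnx
import torch
from onnx2pytorch import ConvertModel

onnx_model = onnx.load("resnet18.onnx")
model = ConvertModel(onnx_model)      # no `experimental=True` needed anymore
model.eval()

batch = torch.randn(4, 3, 224, 224)   # batch of 4 instead of 1
with torch.no_grad():
    out = model(batch)
```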

## Development
13 changes: 1 addition & 12 deletions onnx2pytorch/convert/model.py
@@ -96,8 +96,7 @@ def __init__(
batch_dim: int
Dimension of the batch.
experimental: bool
Experimental implementation allows batch_size > 1. However,
batchnorm layers could potentially produce false outputs.
Deprecated; this argument no longer has any effect. Default: False
enable_pruning: bool
Track kept/pruned indices between different calls to forward pass.

@@ -143,12 +142,6 @@ def __init__(
self.onnx_model.graph, self, self.mapping
)

if experimental:
warnings.warn(
"Using experimental implementation that allows 'batch_size > 1'."
"Batchnorm layers could potentially produce false outputs."
)

def forward(self, *input_list, **input_dict):
if len(input_list) > 0 and len(input_dict) > 0:
raise ValueError(
@@ -160,10 +153,6 @@ def forward(self, *input_list, **input_dict):
if len(input_dict) > 0:
inputs = [input_dict[key] for key in self.input_names]

if not self.experimental and inputs[0].shape[self.batch_dim] > 1:
raise NotImplementedError(
"Input with larger batch size than 1 not supported yet."
)
activations = dict(zip(self.input_names, inputs))
still_needed_by = deepcopy(self.needed_by)
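
As a side note on the forward path shown above: a converted model accepts either positional inputs or keyword inputs keyed by the ONNX graph's input names, but not both at once. A small sketch, assuming a graph whose single input is named "input" (the name is hypothetical):

```python
# Sketch (assumes the ONNX graph has exactly one input named "input").
out = model(batch)           # positional, matched to self.input_names by order
out = model(input=batch)     # keyword, looked up via self.input_names
# model(batch, input=batch)  # mixing both styles raises ValueError
```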

11 changes: 11 additions & 0 deletions onnx2pytorch/convert/operations.py
@@ -232,6 +232,10 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
op = partial(torch.prod, **kwargs)
elif node.op_type == "ReduceSum":
op = ReduceSum(opset_version=opset_version, **extract_attributes(node))
elif node.op_type == "ReduceSumSquare":
op = ReduceSumSquare(
opset_version=opset_version, **extract_attributes(node)
)
elif node.op_type == "ReduceL2":
op = ReduceL2(opset_version=opset_version, **extract_attributes(node))
elif node.op_type == "Relu":
@@ -266,9 +270,16 @@ def convert_operations(onnx_graph, opset_version, batch_dim=0, enable_pruning=Tr
kwargs = dict(dim=-1)
kwargs.update(extract_attributes(node))
op = nn.Softmax(**kwargs)
elif node.op_type == "LogSoftmax":
kwargs = dict(dim=-1)
kwargs.update(extract_attributes(node))
op = nn.LogSoftmax(**kwargs)
elif node.op_type == "Softplus":
# ONNX Softplus has no attributes: y = ln(exp(x) + 1)
# PyTorch Softplus with beta=1 matches ONNX spec
op = nn.Softplus(beta=1)
elif node.op_type == "Softsign":
# ONNX Softsign has no attributes: y = x / (1 + |x|)
op = nn.Softsign()
elif node.op_type == "Split":
kwargs = extract_attributes(node)
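
The comments above state the ONNX definitions these PyTorch modules are expected to match; a quick numerical sanity check of those identities (not part of the PR):

```python
# Sanity check that the chosen PyTorch modules match the ONNX activation specs.
import torch
from torch import nn

x = torch.randn(5)

# Softplus: y = ln(exp(x) + 1), which is nn.Softplus with beta=1
assert torch.allclose(nn.Softplus(beta=1)(x), torch.log(torch.exp(x) + 1))

# Softsign: y = x / (1 + |x|)
assert torch.allclose(nn.Softsign()(x), x / (1 + x.abs()))

# LogSoftmax: log of Softmax along the last dimension (the default used above)
assert torch.allclose(nn.LogSoftmax(dim=-1)(x), torch.log(nn.Softmax(dim=-1)(x)))
```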
2 changes: 2 additions & 0 deletions onnx2pytorch/operations/__init__.py
@@ -30,6 +30,7 @@
from .randomuniformlike import RandomUniformLike
from .reducemax import ReduceMax
from .reducesum import ReduceSum
from .reducesumsquare import ReduceSumSquare
from .reducel2 import ReduceL2
from .reshape import Reshape
from .resize import Resize, Upsample
@@ -80,6 +81,7 @@
"RandomUniformLike",
"ReduceMax",
"ReduceSum",
"ReduceSumSquare",
"ReduceL2",
"Reshape",
"Resize",
6 changes: 4 additions & 2 deletions onnx2pytorch/operations/batchnorm.py
@@ -10,7 +10,6 @@ class _LazyBatchNorm(_LazyNormBase, _BatchNorm):

cls_to_become = _BatchNorm


except ImportError:
# for torch < 1.10.0
from torch.nn.modules.batchnorm import _LazyBatchNorm
@@ -49,14 +48,17 @@ def __init__(self, torch_params, *args, **kwargs):
for key, value in zip(keys, torch_params):
getattr(self.bnu, key).data = value

# Set to eval mode to use running statistics (ONNX inference behavior)
self.bnu.eval()

def forward(self, X, scale=None, B=None, input_mean=None, input_var=None):
if self.has_lazy:
self.bnu.initialize_parameters(X)

if scale is not None:
getattr(self.bnu, "weight").data = scale
if B is not None:
getattr(self.bnu, "bias").data = scale
getattr(self.bnu, "bias").data = B
if input_mean is not None:
getattr(self.bnu, "running_mean").data = input_mean
if input_var is not None:
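
To illustrate why the new `self.bnu.eval()` call matters (an illustration with standard PyTorch modules, not code from the PR): in training mode BatchNorm normalizes with per-batch statistics, while in eval mode it uses the stored running statistics, which is what ONNX BatchNormalization expects at inference time.

```python
# Training mode uses per-batch statistics; eval mode uses running statistics.
import torch
from torch import nn

bn = nn.BatchNorm2d(3)
bn.running_mean.fill_(0.5)
bn.running_var.fill_(2.0)

x = torch.randn(4, 3, 8, 8)

bn.train()
y_train = bn(x)   # normalized with the batch mean/var (also updates running stats)

bn.eval()
y_eval = bn(x)    # normalized with running_mean / running_var

assert not torch.allclose(y_train, y_eval)
```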
4 changes: 0 additions & 4 deletions onnx2pytorch/operations/loop.py
@@ -1,10 +1,7 @@
from collections import defaultdict
from copy import deepcopy
from functools import partial
from importlib import import_module
import warnings

import numpy as np
import onnx
import torch
from onnx import numpy_helper
@@ -71,7 +68,6 @@ def forward(self, enclosing_modules, enclosing_activations, *inputs):
"""

N = len(self.input_names) - 2
K = len(self.output_names) - (1 + N)

M = inputs[0]
cond = inputs[1]
43 changes: 43 additions & 0 deletions onnx2pytorch/operations/reducesumsquare.py
@@ -0,0 +1,43 @@
import torch
from torch import nn


class ReduceSumSquare(nn.Module):
"""
    Computes the sum of the squared elements of the input tensor along the provided axes.

Equivalent to ReduceSum(Square(data), axes, keepdim).
"""

def __init__(
self, opset_version, dim=None, keepdim=True, noop_with_empty_axes=False
):
self.opset_version = opset_version
self.dim = dim
self.keepdim = bool(keepdim)
self.noop_with_empty_axes = noop_with_empty_axes
super().__init__()

def forward(self, data: torch.Tensor, axes: torch.Tensor = None):
# In opset < 13, axes is an attribute (self.dim)
# In opset >= 13, axes is an optional input
if self.opset_version < 13:
dims = self.dim
else:
dims = axes

if dims is None:
if self.noop_with_empty_axes:
return data
else:
# Reduce over all dimensions
dims = tuple(range(data.ndim))

if isinstance(dims, int):
dim = dims
else:
dim = tuple(list(dims))

# Compute sum of squares: sum(x^2)
ret = torch.sum(torch.square(data), dim=dim, keepdim=self.keepdim)
return ret
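
A small usage sketch (not part of the diff) exercising the new module directly; the shapes are arbitrary:

```python
# Direct use of ReduceSumSquare; equivalent to summing squared elements.
import torch

x = torch.randn(2, 3, 4)

# opset < 13: axes arrive as the `dim` attribute
op = ReduceSumSquare(opset_version=11, dim=1, keepdim=True)
assert torch.allclose(op(x), (x ** 2).sum(dim=1, keepdim=True))
assert op(x).shape == (2, 1, 4)

# no axes and noop_with_empty_axes=False: reduce over every dimension
op_all = ReduceSumSquare(opset_version=13, keepdim=False)
assert torch.allclose(op_all(x), (x ** 2).sum())
```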