
[Performance] Quadratic behaviour in list operations with SequenceInsert in onnx #20696

Open
floriangriese opened this issue May 16, 2024 · 0 comments
Labels
performance issues related to performance regressions


Describe the issue

When performing list operations, e.g. splitting a tensor into a list of chunks and concatenating them again, the runtime of the ONNX-exported model grows quadratically with the number of elements.
The expected behaviour is a linear dependence on the number of elements, as with the eager PyTorch and TorchScript implementations.
The problem seems to lie in the ONNX SequenceInsert operator.
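A copy-on-insert sequence would produce exactly this kind of quadratic curve. ONNX operators are functional, so SequenceInsert returns a new output sequence rather than mutating its input. The sketch below is a pure-Python model, not ONNX Runtime code. It assumes, hypothetically, that each insert materialises that new sequence by copying every existing element. Under that assumption, building a sequence of k elements costs 0 + 1 + ... + (k - 1) = k(k - 1)/2 element copies, while an in-place append costs none. The function name `count_copies` is illustrative only.

```python
from typing import List, Tuple


def count_copies(num_inserts: int, copy_on_insert: bool) -> Tuple[List[int], int]:
    """Build a sequence of `num_inserts` elements and count element copies.

    Hypothetical model, not ONNX Runtime code: with copy_on_insert=True each
    insert returns a new sequence holding a copy of every existing element
    plus the new one (the assumed SequenceInsert behaviour). With
    copy_on_insert=False the sequence is appended to in place, as a Python or
    TorchScript list would be.
    """
    seq: List[int] = []
    copies = 0
    for item in range(num_inserts):
        if copy_on_insert:
            copies += len(seq)   # every element already present is copied once
            seq = seq + [item]   # a brand-new sequence object per insert
        else:
            seq.append(item)     # amortised O(1); existing elements are not copied
    return seq, copies


for k in (1_000, 2_000, 4_000):
    _, copies = count_copies(k, copy_on_insert=True)
    print(k, copies)  # prints: 1000 499500 / 2000 1999000 / 4000 7998000
```

Doubling the number of inserts roughly quadruples the copy count. In the benchmark, `torch.split(inputs, 10)` yields n / 10 chunks, so n = 100,000 would correspond to k = 10,000 inserts under this model.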

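[Figure: benchmark results of the minimal list-operations example for Python, TorchScript and ONNX (result_list_operations_python_torch_script_onnx_minimal_example)]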
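To put a number on "linear" versus "quadratic" in the benchmark results, one option is to fit a power law t ≈ c·n^p to the mean runtimes by least squares on log t versus log n. An exponent p ≈ 1 indicates linear scaling and p ≈ 2 indicates quadratic scaling. The helper `scaling_exponent` below is a hypothetical, standard-library-only sketch. The timings in the example are synthetic and are not measurements from this issue.

```python
import math
from typing import Sequence


def scaling_exponent(sizes: Sequence[float], times: Sequence[float]) -> float:
    """Estimate p in t ~ c * n**p by least squares on (log n, log t)."""
    if len(sizes) != len(times) or len(sizes) < 2:
        raise ValueError("need at least two (size, time) pairs of equal length")
    xs = [math.log(n) for n in sizes]
    ys = [math.log(t) for t in times]
    x_mean = sum(xs) / len(xs)
    y_mean = sum(ys) / len(ys)
    var = sum((x - x_mean) ** 2 for x in xs)
    if var == 0:
        raise ValueError("sizes must not all be equal")
    cov = sum((x - x_mean) * (y - y_mean) for x, y in zip(xs, ys))
    return cov / var  # slope of the log-log regression line


# Synthetic timings (seconds) for illustration only, not measurements from this issue.
sizes = [100_000, 200_000, 300_000, 400_000, 500_000]
linear_times = [2e-9 * n for n in sizes]
quadratic_times = [1e-12 * n**2 for n in sizes]
print(round(scaling_exponent(sizes, linear_times), 3))     # 1.0
print(round(scaling_exponent(sizes, quadratic_times), 3))  # 2.0
```

You can pass in the mean time that pytest-benchmark reports for each `n`, for example from a run with `--benchmark-json`. This gives one exponent for each of the eager, TorchScript and ONNX Runtime variants. A value near 2 for ONNX Runtime would confirm the quadratic trend.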

To reproduce

import io
import numpy as np
import onnxruntime
import pytest
import torch
from pytest_benchmark.fixture import BenchmarkFixture
from torch import Tensor

pytestmark = pytest.mark.benchmark


def list_operations(inputs: Tensor) -> Tensor:
    inputs_split = torch.split(inputs, 10)
    return torch.cat([inputs_split[k] for k in range(len(inputs_split))], 0)

#: TorchScript-compiled version of the function above
list_operations_torch_script = torch.jit.script(
    list_operations,
    example_inputs=[
        (
            torch.rand(10),
        ),
        (
            torch.rand(20),
        ),
        (
            torch.rand(100),
        ),
    ],
)

@pytest.mark.parametrize("n", [100_000, 200_000, 300_000, 400_000, 500_000])
def test_list_operations(benchmark: BenchmarkFixture, n: int) -> None:
    rng = np.random.default_rng(5337)
    inputs = torch.tensor(rng.gamma(10, scale=0.5, size=(n,)), dtype=torch.float32)

    benchmark(list_operations, inputs)

@pytest.mark.parametrize("n", [100_000, 200_000, 300_000, 400_000, 500_000])
def test_list_operations_torch_script(benchmark: BenchmarkFixture, n: int) -> None:
    rng = np.random.default_rng(5337)
    inputs = torch.tensor(rng.gamma(10, scale=0.5, size=(n,)), dtype=torch.float32)

    # Benchmark the TorchScript-compiled function, not the eager one.
    benchmark(list_operations_torch_script, inputs)


@pytest.mark.parametrize("n", [100_000, 200_000, 300_000, 400_000, 500_000])
def test_list_operations_onnx(benchmark: BenchmarkFixture, n: int) -> None:
    inputs = torch.rand(15)
    traced_list_operation_torch_script = torch.jit.trace(
        list_operations_torch_script,
        example_inputs=(inputs,),
    )
    torch.onnx.export(
        traced_list_operation_torch_script,
        (inputs,),
        buf := io.BytesIO(),
        input_names=["input"],
        output_names=["output"],
        dynamic_axes={
            "input": {0: "observations"},
        },
    )
    session = onnxruntime.InferenceSession(
        buf.getvalue(),
        sess_options=onnxruntime.SessionOptions(),
        providers=["CPUExecutionProvider"],
    )

    def list_operation_onnx(
        input: torch.Tensor,
    ) -> Tensor:
        (output,) = session.run(
            None,
            {
                "input": input.numpy(),
            },
        )
        return torch.as_tensor(output)

    rng = np.random.default_rng(5337)
    inputs = torch.tensor(rng.gamma(10, scale=0.5, size=(n,)), dtype=torch.float32)
    benchmark(list_operation_onnx, inputs)
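With an integer split size, `torch.split(inputs, 10)` returns chunks of 10 elements, with a shorter last chunk if the length is not a multiple of 10. `torch.cat` then joins them back together. As a result, `list_operations` is an identity round trip whose intermediate list holds ceil(n / 10) chunks. The pure-Python sketch below mirrors that structure, with plain lists standing in for tensors. `split_into_chunks` and `list_operations_py` are hypothetical helpers for illustration and are not part of the reproduction.

```python
import math
from typing import List, Sequence


def split_into_chunks(values: Sequence[float], size: int) -> List[List[float]]:
    """Mimic torch.split(t, size) on a 1-D input: chunks of `size`, shorter last chunk."""
    return [list(values[i:i + size]) for i in range(0, len(values), size)]


def list_operations_py(values: Sequence[float]) -> List[float]:
    """Pure-Python analogue of list_operations: split into chunks of 10, then concatenate."""
    chunks = split_into_chunks(values, 10)
    parts = [chunks[k] for k in range(len(chunks))]  # same indexing comprehension as the repro
    return [x for part in parts for x in part]       # stands in for torch.cat(parts, 0)


data = [float(i) for i in range(25)]
print([len(c) for c in split_into_chunks(data, 10)])  # [10, 10, 5]
assert list_operations_py(data) == data  # the round trip returns the input unchanged

# Number of list elements built per call for each benchmarked input size.
for n in (100_000, 200_000, 300_000, 400_000, 500_000):
    print(n, math.ceil(n / 10))
```

Across the parametrized sizes, the intermediate list therefore grows from 10,000 to 50,000 elements. If each element costs one SequenceInsert that copies the sequence built so far, that alone would explain a quadratic trend.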

Urgency

This is currently slowing down our inference pipeline, and we are trying to avoid list operations altogether.

Platform

Mac

OS Version

macOS-14.4.1-x86_64-i386-64bit

ONNX Runtime Installation

Released Package

ONNX Runtime Version or Commit ID

1.17.3

ONNX Runtime API

Python

Architecture

X64

Execution Provider

Default CPU

Execution Provider Library Version

CPUExecutionProvider

Model File

list_operation.onnx.zip

Is this a quantized model?

No

sophies927 added the performance label on May 16, 2024