Skip to content

Increasing rss in pipeline execution #6153

@bveldhoen

Description

@bveldhoen

Version

1.52.0, 1.53.0

Describe the bug.

Thanks in advance for taking a look at the description below. I may be doing something incorrectly. If so, please let me know!

Environment:

In custom docker container based on nvcr.io/nvidia/cuda:13.1.0-cudnn-devel-ubuntu24.04
$ nvcc --version
...
Build cuda_13.0.r13.0/compiler.36424714_0
>>> nvidia.dali.__version__
'1.52.0'

(Also reproduced with version 1.53.0)

We're experiencing OOM-killer issues: the process is killed after a few days. We're seeing steadily increasing RSS (resident set size) memory when executing a pipeline in a loop.

The external_source_pipeline is meant to be a control test. The difference between external_source_pipeline and resize_pipeline is a fn.resize.

Is there anything that can be changed in the code that uses resize_pipeline, so that the RSS doesn't keep increasing?

Minimum reproducible example

test_dali_pipeline.py:

import datetime
import gc
import os
import psutil
import sys

import numpy as np
import cupy as cp

from nvidia.dali import pipeline_def
import nvidia.dali.fn as fn
import nvidia.dali.types as types


@pipeline_def(num_threads=1, batch_size=1, prefetch_queue_depth=1)
def external_source_pipeline():
    """Control pipeline: a bare GPU external source with no further operators."""
    batch = fn.external_source(
        device="gpu",
        batch=True,
        name="input_batch",
    )
    return batch # / 128  # adding this also causes rss to increase


@pipeline_def(num_threads=1, batch_size=1, prefetch_queue_depth=1)
def reshape_pipeline():
    """External source followed by a no-copy reshape to NCHW-like layout."""
    batch = fn.external_source(
        device="gpu",
        batch=True,
        name="input_batch",
    )
    reshaped = fn.reshape(batch, shape=[1, 3, 1440, 2560])
    return reshaped


@pipeline_def(num_threads=1, batch_size=1, prefetch_queue_depth=1)
def resize_pipeline():
    """External source followed by a GPU linear resize to 224x224 (stretch)."""
    batch = fn.external_source(
        device="gpu",
        batch=True,
        name="input_batch",
    )
    resized = fn.resize(
        batch,
        device="gpu",
        mode="stretch",
        resize_x=224,
        resize_y=224,
        antialias=False,
        subpixel_scale=False,
        interp_type=types.DALIInterpType.INTERP_LINEAR,
        mag_filter=types.DALIInterpType.INTERP_LINEAR,
        min_filter=types.DALIInterpType.INTERP_LINEAR,
    )
    return resized


@pipeline_def(num_threads=1, batch_size=1, prefetch_queue_depth=1)
def rotate_pipeline():
    """External source followed by a GPU 90-degree rotation (preserve=True)."""
    batch = fn.external_source(
        device="gpu",
        batch=True,
        name="input_batch",
    )
    rotated = fn.rotate(
        batch,
        device="gpu",
        angle=90,
        preserve=True,
    )
    return rotated


def get_rss_mb() -> float:
    """Return the resident set size (RSS) of the current process in MiB."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return rss_bytes / (1024 ** 2)


def report_rss(tag: str) -> float:
    """Force a GC pass, then print and return the current RSS in MiB."""
    gc.collect()
    current_mb = get_rss_mb()
    print(f"{datetime.datetime.now()} [{tag}] rss_mb: {current_mb}")
    return current_mb


def create_input_batch():
    """Return a one-sample CPU batch: a random uint8 tensor of shape (1, 1440, 2560, 3)."""
    sample = np.random.randint(0, high=256, size=(1, 1440, 2560, 3), dtype=np.uint8)
    return [sample]


def create_input_batch_gpu():
    """Return a one-sample GPU batch: a random uint8 CuPy tensor of shape (1, 1440, 2560, 3)."""
    sample = cp.random.randint(0, high=256, size=(1, 1440, 2560, 3), dtype=np.uint8)
    return [sample]


# def exec_pipe(pipe, input_batch):
#     # input_batch = [cp.asarray(input) for input in input_batch]
#     pipe.feed_input("input_batch", input_batch)
#     (output_batch,) = pipe.run()
#     return output_batch


# This method gives the same results as the feed_input + run variant above.
def exec_pipe(pipe, input_batch):
    """Feed one batch into the pipeline and return its single output batch."""
    (result,) = pipe.run(input_batch=input_batch)
    return result


def test_pipeline_rss(
    pipe,
    input_batch,
    report_mem_usage_interval: int = 2000,
    test_count: int | None = None
):
    """Run the pipeline `test_count` times and report RSS growth.

    Prints the RSS at regular intervals and, at the end, the total RSS
    delta plus the average growth in bytes per pipeline execution.

    Args:
        pipe: a built DALI pipeline with an "input_batch" external source.
        input_batch: the batch fed on every iteration.
        report_mem_usage_interval: iterations between RSS reports.
        test_count: total iterations; defaults to 10x the report interval.
    """
    if test_count is None:
        test_count = report_mem_usage_interval * 10

    rss_mb_begin = report_rss("begin test_pipeline")
    for i in range(test_count):
        output_batch = exec_pipe(pipe, input_batch)
        # print(f"output_batch: {output_batch}")
        # Drop the reference so the output buffers can be recycled.
        output_batch = None

        if i % report_mem_usage_interval == 0:
            # report_rss expects a string tag; the original passed the raw int.
            report_rss(str(i))

    rss_mb_end = report_rss("end test_pipeline")

    rss_mb_diff = rss_mb_end - rss_mb_begin
    # Guard against ZeroDivisionError for a degenerate test_count of 0.
    rss_bytes_diff_per_iteration = (
        (rss_mb_diff * 1024 * 1024) / test_count if test_count else 0.0
    )
    print(f"{rss_mb_diff} MB / {test_count} = {rss_bytes_diff_per_iteration} byte/execution")


def test_pipeline_memray(
    filename: str,
    pipe,
    input_batch,
):
    """Trace a single pipeline execution with memray, writing to `filename`.

    Any pre-existing capture file is removed first so the run starts clean.
    """
    if os.path.exists(filename):
        os.remove(filename)

    import memray

    tracker = memray.Tracker(
        filename,
        native_traces=True,
        trace_python_allocators=True,
        track_object_lifetimes=False,
        follow_fork=False,
        memory_interval_ms=10,
        file_format=memray.FileFormat.ALL_ALLOCATIONS,
    )
    with tracker:
        output_batch = exec_pipe(pipe, input_batch)
        # Release the output and collect so only true leaks remain in the capture.
        output_batch = None
        gc.collect()

def main(
    pipeline_name: str,
    test_type: str,
    report_mem_usage_interval: int = 2000,
):
    """Build the named pipeline, warm it up, and run the requested test.

    Args:
        pipeline_name: one of the pipeline factory names below.
        test_type: "rss" for the RSS-growth loop, "memray" for a traced run.
        report_mem_usage_interval: iterations between RSS reports (also
            sizes the warm-up loop).

    Raises:
        ValueError: if `pipeline_name` is not recognized.
    """
    print(f"Executing test_type {test_type} for pipeline: {pipeline_name}")
    # report_rss("begin main")

    input_batch = create_input_batch()
    # input_batch = create_input_batch_gpu()

    # Dispatch table: the original if-chain left `pipe` unbound for an
    # unknown name, producing a confusing NameError at pipe.build().
    pipelines = {
        "external_source_pipeline": external_source_pipeline,  # rss stays constant
        "reshape_pipeline": reshape_pipeline,  # rss increases
        "resize_pipeline": resize_pipeline,  # rss increases
        "rotate_pipeline": rotate_pipeline,  # rss increases
    }
    try:
        pipe = pipelines[pipeline_name]()
    except KeyError:
        raise ValueError(f"Unknown pipeline name: {pipeline_name}") from None

    pipe.build()

    # Warm-up so allocator pools stabilize before any measurement starts.
    for _ in range(2 * report_mem_usage_interval):
        output_batch = exec_pipe(pipe, input_batch)

    output_batch = None
    gc.collect()

    if test_type == "rss":
        test_pipeline_rss(
            pipe,
            input_batch,
            report_mem_usage_interval,
        )
    elif test_type == "memray":
        memray_filename = f"{pipeline_name}.bin"
        test_pipeline_memray(
            memray_filename,
            pipe,
            input_batch,
        )

    # Drop references so the pipeline and inputs can be reclaimed.
    input_batch = None
    pipe = None

    # report_rss("end main")


if __name__ == "__main__":
    # Usage: test_dali_pipeline.py <pipeline_name> <test_type> <interval>
    pipeline_name = sys.argv[1]
    test_type = sys.argv[2]
    report_mem_usage_interval = int(sys.argv[3])

    # Bug fix: the interval parsed from argv[3] was previously dropped
    # instead of being forwarded to main().
    main(pipeline_name, test_type, report_mem_usage_interval)

    # To run the pipeline using memray:
    # PYTHONMALLOC=malloc python3 test_dali_pipeline.py

    # Create memray flamegraph overview:
    # python3 -m memray flamegraph --leaks --force resize_pipeline.bin --output resize_pipeline-flamegraph.html

    # Create memray table overview:
    # python3 -m memray table --leaks --force resize_pipeline.bin --output resize_pipeline-table.html

test_dali_pipeline.sh:

#!/bin/bash
# Run the DALI repro script for one pipeline, then (for memray runs)
# render the captured allocations as HTML reports.

PIPELINE_NAME="${1:-resize_pipeline}"
TEST_TYPE="${2:-memray}"
REPORT_MEM_USAGE_INTERVAL="${3:-2000}"

python3 test_dali_pipeline.py "${PIPELINE_NAME}" "${TEST_TYPE}" "${REPORT_MEM_USAGE_INTERVAL}"

if [ "${TEST_TYPE}" = "memray" ]; then
    # Flamegraph and table overviews of leaked allocations.
    python3 -m memray flamegraph --leaks --force "${PIPELINE_NAME}.bin" --output "${PIPELINE_NAME}-flamegraph.html"
    python3 -m memray table --leaks --force "${PIPELINE_NAME}.bin" --output "${PIPELINE_NAME}-table.html"
fi

test_dali_pipeline_all.sh:

#!/bin/bash
# Exercise every pipeline variant: first the RSS-growth loops, then the
# memray leak captures.

for name in external_source_pipeline reshape_pipeline resize_pipeline rotate_pipeline; do
    ./test_dali_pipeline.sh "${name}" "rss"
done

# PYTHONMALLOC=malloc routes Python allocations through malloc so memray
# can see them; exporting once covers all subsequent runs.
export PYTHONMALLOC=malloc
for name in external_source_pipeline reshape_pipeline resize_pipeline rotate_pipeline; do
    ./test_dali_pipeline.sh "${name}" "memray"
done

Relevant log output

$ ./test_dali_pipeline_all.sh
Executing test_type rss for pipeline: external_source_pipeline
2026-01-12 17:11:05.469598 [begin test_pipeline] rss_mb: 423.26953125
2026-01-12 17:11:05.479715 [0] rss_mb: 423.26953125
2026-01-12 17:11:07.767429 [2000] rss_mb: 423.26953125
2026-01-12 17:11:10.532311 [4000] rss_mb: 423.26953125
2026-01-12 17:11:13.201345 [6000] rss_mb: 423.26953125
2026-01-12 17:11:15.699299 [8000] rss_mb: 423.26953125
2026-01-12 17:11:18.194227 [10000] rss_mb: 423.26953125
2026-01-12 17:11:20.692626 [12000] rss_mb: 423.26953125
2026-01-12 17:11:23.627351 [14000] rss_mb: 423.26953125
2026-01-12 17:11:26.039452 [16000] rss_mb: 423.26953125
2026-01-12 17:11:28.363611 [18000] rss_mb: 423.26953125
2026-01-12 17:11:30.944950 [end test_pipeline] rss_mb: 423.26953125
0.0 MB / 20000 = 0.0 byte/execution
Executing test_type rss for pipeline: reshape_pipeline
2026-01-12 17:11:37.407176 [begin test_pipeline] rss_mb: 419.546875
2026-01-12 17:11:37.417318 [0] rss_mb: 419.546875
2026-01-12 17:11:39.770390 [2000] rss_mb: 419.546875
2026-01-12 17:11:42.521751 [4000] rss_mb: 419.546875
2026-01-12 17:11:45.226038 [6000] rss_mb: 419.546875
2026-01-12 17:11:47.737420 [8000] rss_mb: 419.546875
2026-01-12 17:11:50.202665 [10000] rss_mb: 419.546875
2026-01-12 17:11:52.686424 [12000] rss_mb: 419.546875
2026-01-12 17:11:55.683424 [14000] rss_mb: 419.546875
2026-01-12 17:11:57.997830 [16000] rss_mb: 419.546875
2026-01-12 17:12:00.348124 [18000] rss_mb: 419.546875
2026-01-12 17:12:02.705068 [end test_pipeline] rss_mb: 419.546875
0.0 MB / 20000 = 0.0 byte/execution
Executing test_type rss for pipeline: resize_pipeline
2026-01-12 17:12:09.127668 [begin test_pipeline] rss_mb: 552.77734375
2026-01-12 17:12:09.139451 [0] rss_mb: 552.77734375
2026-01-12 17:12:11.609857 [2000] rss_mb: 553.02734375
2026-01-12 17:12:14.347069 [4000] rss_mb: 553.27734375
2026-01-12 17:12:17.614641 [6000] rss_mb: 553.52734375
2026-01-12 17:12:20.282539 [8000] rss_mb: 553.77734375
2026-01-12 17:12:22.963455 [10000] rss_mb: 553.77734375
2026-01-12 17:12:25.694125 [12000] rss_mb: 554.02734375
2026-01-12 17:12:28.615628 [14000] rss_mb: 554.02734375
2026-01-12 17:12:31.198280 [16000] rss_mb: 554.52734375
2026-01-12 17:12:33.903700 [18000] rss_mb: 554.52734375
2026-01-12 17:12:36.778276 [end test_pipeline] rss_mb: 554.77734375
2.0 MB / 20000 = 104.8576 byte/execution
Executing test_type rss for pipeline: rotate_pipeline
2026-01-12 17:12:44.330738 [begin test_pipeline] rss_mb: 436.046875
2026-01-12 17:12:44.343241 [0] rss_mb: 436.046875
2026-01-12 17:12:47.565645 [2000] rss_mb: 438.796875
2026-01-12 17:12:51.010564 [4000] rss_mb: 440.796875
2026-01-12 17:12:54.315621 [6000] rss_mb: 440.796875
2026-01-12 17:12:57.579050 [8000] rss_mb: 440.796875
2026-01-12 17:13:01.271073 [10000] rss_mb: 441.046875
2026-01-12 17:13:04.523953 [12000] rss_mb: 441.046875
2026-01-12 17:13:07.813349 [14000] rss_mb: 441.046875
2026-01-12 17:13:11.662529 [16000] rss_mb: 441.046875
2026-01-12 17:13:15.182125 [18000] rss_mb: 441.546875
2026-01-12 17:13:18.441670 [end test_pipeline] rss_mb: 441.796875
5.75 MB / 20000 = 301.4656 byte/execution
Executing test_type memray for pipeline: external_source_pipeline
Memray WARNING: Correcting symbol for malloc from 0x420620 to 0x70e61c244650
Memray WARNING: Correcting symbol for free from 0x420ab0 to 0x70e61c244d30
Wrote external_source_pipeline-flamegraph.html
Wrote external_source_pipeline-table.html
Executing test_type memray for pipeline: reshape_pipeline
Memray WARNING: Correcting symbol for malloc from 0x420620 to 0x76c1b0fb1650
Memray WARNING: Correcting symbol for free from 0x420ab0 to 0x76c1b0fb1d30
Wrote reshape_pipeline-flamegraph.html
Wrote reshape_pipeline-table.html
Executing test_type memray for pipeline: resize_pipeline
Memray WARNING: Correcting symbol for malloc from 0x420620 to 0x7929ab2b7650
Memray WARNING: Correcting symbol for free from 0x420ab0 to 0x7929ab2b7d30
Wrote resize_pipeline-flamegraph.html
Wrote resize_pipeline-table.html
Executing test_type memray for pipeline: rotate_pipeline
Memray WARNING: Correcting symbol for malloc from 0x420620 to 0x7fd2c5cef650
Memray WARNING: Correcting symbol for free from 0x420ab0 to 0x7fd2c5cefd30
Wrote rotate_pipeline-flamegraph.html
Wrote rotate_pipeline-table.html

Other/Misc.

I've also included memory tracking using memray. These results (memray table --leaks) show 136 entries for external_source_pipeline, vs. 333 entries for resize_pipeline.

The stack trace reported by memray flamegraph --leaks for resize_pipeline:

 (output_batch,) = pipe.run(input_batch=input_batch)
self.feed_input(inp_name, inp_value)
self._feed_input(name, data, layout, cuda_stream, use_copy_kernel)
self._pipe.SetExternalTensorInput(name, data, cuda_stream_ptr, use_copy_kernel)
cfunction_call at ../Objects/methodobject.c:537
pybind11::cpp_function::dispatcher(_object*, _object*, _object*) at <unknown>:0
pybind11::cpp_function::initialize<dali::python::ExposePipeline(pybind11::module_&)::{lambda(dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, pybind11::object, bool)#1}, void, dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, pybind11::object, bool, pybind11::name, pybind11::is_method, pybind11::sibling, pybind11::arg, pybind11::arg, pybind11::arg_v, pybind11::arg_v>(dali::python::ExposePipeline(pybind11::module_&)::{lambda(dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, pybind11::object, bool)#1}&&, void (*)(dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, pybind11::object, bool), pybind11::name const&, pybind11::is_method const&, pybind11::sibling const&, pybind11::arg const&, pybind11::arg const&, pybind11::arg_v const&, pybind11::arg_v const&)::{lambda(pybind11::detail::function_call&)#1}::_FUN(pybind11::detail::function_call&) at <unknown>:0
dali::python::ExposePipeline(pybind11::module_&)::{lambda(dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, pybind11::object, bool)#1}::operator()(dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, pybind11::object, bool) const [clone .isra.0] at <unknown>:0
void dali::python::FeedPipeline<dali::CPUBackend>(dali::Pipeline*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, pybind11::list, dali::AccessOrder, bool, bool) at <unknown>:0
void dali::Pipeline::SetExternalInputHelper<dali::CPUBackend>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, dali::TensorList<dali::CPUBackend> const&, std::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, dali::AccessOrder, dali::InputOperatorSettingMode, bool) at <unknown>:0
void dali::Pipeline::SetDataSourceHelper<dali::CPUBackend, dali::GPUBackend>(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, dali::TensorList<dali::CPUBackend> const&, std::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, dali::OperatorBase*, dali::AccessOrder, dali::InputOperatorSettingMode, bool) at <unknown>:0
std::enable_if<std::is_same<dali::GPUBackend, dali::GPUBackend>::value, void>::type dali::InputOperator<dali::GPUBackend>::CopyUserData<dali::CPUBackend, dali::GPUBackend>(dali::TensorList<dali::CPUBackend> const&, std::optional<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, dali::AccessOrder, bool, bool) at <unknown>:0
void dali::TensorList<dali::GPUBackend>::Copy<dali::CPUBackend>(dali::TensorList<dali::CPUBackend> const&, dali::AccessOrder, bool) at <unknown>:0
dali::TensorList<dali::GPUBackend>::Resize(dali::TensorListShape<-1> const&, _DALIDataType, dali::BatchContiguity) at <unknown>:0
dali::Buffer<dali::GPUBackend>::reserve(unsigned long, dali::AccessOrder) at <unknown>:0
dali::AllocBuffer(unsigned long, bool, int, dali::AccessOrder, dali::GPUBackend*) at <unknown>:0
dali::mm::async_pool_resource<dali::cuda_for_dali::memory_kind::device, dali::mm::cuda_vm_resource, std::mutex, void>::do_allocate_async(unsigned long, unsigned long, dali::cuda_for_dali::stream_view) at <unknown>:0
dali::mm::cuda_vm_resource::do_allocate(unsigned long, unsigned long) at <unknown>:0
dali::mm::cuda_vm_resource::va_allocate(unsigned long) at <unknown>:0
cuMemAddressReserve at <unknown>:0
<unknown> at <unknown>:0
<unknown> at <unknown>:0
<unknown> at <unknown>:0
<unknown> at <unknown>:0
<unknown> at <unknown>:0
<unknown> at <unknown>:0

However, this same stacktrace is reported for external_source_pipeline, so I don't know how accurate these reports are.

No response

Check for duplicates

  • I have searched the open bugs/issues and have found no duplicates for this bug report

Metadata

Metadata

Assignees

Labels

bugSomething isn't working

Type

No type

Projects

No projects

Relationships

None yet

Development

No branches or pull requests

Issue actions