Add type annotations for numba and pytorch plugins #5129

Merged 10 commits on Nov 8, 2023
3 changes: 2 additions & 1 deletion dali/python/nvidia/dali/plugin/numba/__init__.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
+# Copyright (c) 2021-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,3 +13,4 @@
 # limitations under the License.

 from . import experimental  # noqa F401
+from . import fn  # noqa F401
dali/python/nvidia/dali/plugin/numba/experimental/__init__.py
@@ -371,6 +371,7 @@ def __init__(self, run_fn,
         NumbaFunction._check_minimal_numba_version()
         NumbaFunction._check_cuda_compatibility()

+        # TODO(klecki): Normalize the types into lists first, then apply the checks
         assert len(in_types) == len(ins_ndim), ("Number of input types "
                                                 "and input dimensions should match.")
         assert len(out_types) == len(outs_ndim), ("Number of output types "
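The assertions above enforce a one-entry-per-input (and per-output) pairing between the type and dimensionality lists. A minimal sketch of arguments that satisfy them; the concrete values are chosen purely for illustration:

    from nvidia.dali import types

    # One uint8 HWC image in, one float32 HWC image out: each list has exactly one entry.
    in_types = [types.DALIDataType.UINT8]
    ins_ndim = [3]
    out_types = [types.DALIDataType.FLOAT]
    outs_ndim = [3]
    assert len(in_types) == len(ins_ndim) and len(out_types) == len(outs_ndim)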
67 changes: 67 additions & 0 deletions dali/python/nvidia/dali/plugin/numba/experimental/__init__.pyi
@@ -0,0 +1,67 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union, List, Sequence, Callable

from nvidia.dali.data_node import DataNode

from nvidia.dali.types import DALIDataType

class NumbaFunction:
    """Invokes an njit-compiled Numba function.
    The run function should be a Python function that can be compiled in Numba ``nopython`` mode."""

    def __init__(
        self,
        run_fn: Optional[Callable[..., None]] = None,
        out_types: Optional[List[DALIDataType]] = None,
        in_types: Optional[List[DALIDataType]] = None,
        outs_ndim: Optional[List[int]] = None,
        ins_ndim: Optional[List[int]] = None,
        setup_fn: Optional[
            Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]]], None]
        ] = None,
        device: str = "cpu",
        batch_processing: bool = False,
        blocks: Optional[Sequence[int]] = None,
        threads_per_block: Optional[Sequence[int]] = None,
        bytes_per_sample_hint: Union[Sequence[int], int, None] = [0],
        seed: Optional[int] = -1,
    ) -> None: ...
    def __call__(
        self,
        __input_0: DataNode,
        __input_1: Optional[DataNode] = None,
        __input_2: Optional[DataNode] = None,
        __input_3: Optional[DataNode] = None,
        __input_4: Optional[DataNode] = None,
        __input_5: Optional[DataNode] = None,
        /,
        *,
        run_fn: Optional[Callable[..., None]] = None,
        out_types: Optional[List[DALIDataType]] = None,
        in_types: Optional[List[DALIDataType]] = None,
        outs_ndim: Optional[List[int]] = None,
        ins_ndim: Optional[List[int]] = None,
        setup_fn: Optional[
            Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]]], None]
        ] = None,
        device: str = "cpu",
        batch_processing: bool = False,
        blocks: Optional[Sequence[int]] = None,
        threads_per_block: Optional[Sequence[int]] = None,
        bytes_per_sample_hint: Union[Sequence[int], int, None] = [0],
        seed: Optional[int] = -1,
    ) -> Union[DataNode, Sequence[DataNode]]: ...
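For context, a sketch of a ``setup_fn`` matching the annotation above, ``Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]]], None]``: it receives per-sample output and input shapes, mutates the output shapes in place, and returns ``None``. The shape-swapping body is an illustrative assumption:

    def transpose_setup(outs, ins):
        # outs[0][i] / ins[0][i]: shape of sample i for output 0 / input 0.
        out0, in0 = outs[0], ins[0]
        for i in range(len(out0)):
            out0[i][0] = in0[i][1]  # swap the two extents of each sample
            out0[i][1] = in0[i][0]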
15 changes: 15 additions & 0 deletions dali/python/nvidia/dali/plugin/numba/fn/__init__.py
@@ -0,0 +1,15 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from . import experimental # noqa F401
13 changes: 13 additions & 0 deletions dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.py
@@ -0,0 +1,13 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
48 changes: 48 additions & 0 deletions dali/python/nvidia/dali/plugin/numba/fn/experimental/__init__.pyi
@@ -0,0 +1,48 @@
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Optional, Union, List, Sequence, Callable

from nvidia.dali.data_node import DataNode

from nvidia.dali.types import DALIDataType

def numba_function(
    __input_0: DataNode,
    __input_1: Optional[DataNode] = None,
    __input_2: Optional[DataNode] = None,
    __input_3: Optional[DataNode] = None,
    __input_4: Optional[DataNode] = None,
    __input_5: Optional[DataNode] = None,
    /,
    *,
    run_fn: Callable[..., None],
    out_types: List[DALIDataType],
    in_types: List[DALIDataType],
    outs_ndim: List[int],
    ins_ndim: List[int],
    setup_fn: Optional[Callable[[Sequence[Sequence[Any]], Sequence[Sequence[Any]]], None]] = None,
    batch_processing: bool = False,
    blocks: Optional[Sequence[int]] = None,
    threads_per_block: Optional[Sequence[int]] = None,
    bytes_per_sample_hint: Union[Sequence[int], int, None] = [0],
    preserve: Optional[bool] = False,
    seed: Optional[int] = -1,
    device: Optional[str] = None,
    name: Optional[str] = None,
) -> Union[DataNode, Sequence[DataNode]]:
    """Invokes an njit-compiled Numba function.
    The run function should be a Python function that can be compiled in Numba ``nopython`` mode."""
    ...
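A minimal sketch of how this stub's signature maps onto a real invocation; the pipeline parameters, source callback, and shapes are assumptions made purely for illustration:

    import numpy as np
    from nvidia.dali import pipeline_def, types
    import nvidia.dali.fn as fn
    from nvidia.dali.plugin.numba.fn.experimental import numba_function

    def double_sample(out0, in0):  # run_fn: writes outputs in place, returns None
        out0[:] = 2 * in0[:]

    @pipeline_def(batch_size=8, num_threads=2, device_id=0)
    def numba_pipe():
        data = fn.external_source(
            source=lambda info: np.full(4, info.idx_in_epoch, dtype=np.uint8),
            batch=False, dtype=types.UINT8,
        )
        return numba_function(
            data, run_fn=double_sample,
            in_types=[types.DALIDataType.UINT8], ins_ndim=[1],
            out_types=[types.DALIDataType.UINT8], outs_ndim=[1],
        )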
dali/python/nvidia/dali/plugin/pytorch/__init__.py
@@ -12,17 +12,34 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from nvidia.dali.backend import TensorGPU, TensorListGPU
-from nvidia.dali.pipeline import Pipeline
-import nvidia.dali.ops as ops
+import sys
+
+from typing import Union, Optional
+from typing import Any, Dict, List
+
+from nvidia.dali import internal as _internal
+from nvidia.dali import ops
+from nvidia.dali import types
+from nvidia.dali.backend import TensorCPU, TensorGPU, TensorListCPU, TensorListGPU
+from nvidia.dali.pipeline import Pipeline

 from nvidia.dali.plugin.base_iterator import _DaliBaseIterator
 from nvidia.dali.plugin.base_iterator import LastBatchPolicy

 import torch
-import torch.utils.dlpack as torch_dlpack
+import torch.utils.dlpack as torch_dlpack  # noqa: F401
 import ctypes
 import numpy as np

+from . import fn  # noqa: F401
+
+from nvidia.dali.plugin.pytorch._torch_function import TorchPythonFunction as TorchPythonFunction
+
+_internal._adjust_operator_module(TorchPythonFunction, sys.modules[__name__], [])
+
+ops._wrap_op(TorchPythonFunction, "fn", __name__)
+

to_torch_type = {
types.DALIDataType.FLOAT: torch.float32,
types.DALIDataType.FLOAT64: torch.float64,
@@ -36,7 +53,11 @@
 }


-def feed_ndarray(dali_tensor, arr, cuda_stream=None):
+def feed_ndarray(
+    dali_tensor: Union[TensorCPU, TensorGPU, TensorListCPU, TensorListGPU],
+    arr: torch.Tensor,
+    cuda_stream: Union[torch.cuda.Stream, Any, None] = None,
+) -> torch.Tensor:
     """
     Copy contents of DALI tensor to PyTorch's Tensor.
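A hedged sketch of the call pattern the new annotation describes: the destination ``torch.Tensor`` is preallocated to match the DALI tensor, and the copy can be ordered on a caller-supplied stream. ``pipe`` and the float32 dtype are assumptions for illustration:

    pipe_out, = pipe.run()                      # e.g. a TensorListGPU
    dali_tensor = pipe_out.as_tensor()          # TensorGPU with a uniform shape
    arr = torch.empty(dali_tensor.shape(), dtype=torch.float32, device="cuda")
    feed_ndarray(dali_tensor, arr, cuda_stream=torch.cuda.current_stream())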
@@ -161,17 +182,19 @@ class DALIGenericIterator(_DaliBaseIterator):
     next iteration will return ``[2, 3]``
     """

-    def __init__(self,
-                 pipelines,
-                 output_map,
-                 size=-1,
-                 reader_name=None,
-                 auto_reset=False,
-                 fill_last_batch=None,
-                 dynamic_shape=False,
-                 last_batch_padded=False,
-                 last_batch_policy=LastBatchPolicy.FILL,
-                 prepare_first_batch=True):
+    def __init__(
+        self,
+        pipelines: Union[List[Pipeline], Pipeline],
+        output_map: List[str],
+        size: int = -1,
+        reader_name: Optional[str] = None,
+        auto_reset: Union[str, bool, None] = False,
+        fill_last_batch: Optional[bool] = None,
+        dynamic_shape: Optional[bool] = False,
+        last_batch_padded: bool = False,
+        last_batch_policy: LastBatchPolicy = LastBatchPolicy.FILL,
+        prepare_first_batch: bool = True,
+    ) -> None:

         # check the assert first as _DaliBaseIterator would run the prefetch
         assert len(set(output_map)) == len(output_map), "output_map names should be distinct"
@@ -200,7 +223,7 @@ def __init__(self,
             "if `last_batch_policy` is set to PARTIAL and the requested batch size is " \
             "greater than the shard size."

-    def __next__(self):
+    def __next__(self) -> List[Dict[str, torch.Tensor]]:
         self._ever_consumed = True
         if self._first_batch is not None:
             batch = self._first_batch
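With the annotated ``__next__``, each iteration yields one ``Dict[str, torch.Tensor]`` per pipeline, keyed by the names in ``output_map``. A minimal consumption sketch; ``pipe`` and the reader name are assumptions:

    dali_iter = DALIGenericIterator([pipe], output_map=["data", "label"], reader_name="Reader")
    for batch in dali_iter:          # batch: List[Dict[str, torch.Tensor]]
        data = batch[0]["data"]      # tensors from the first (and only) pipeline
        label = batch[0]["label"]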
@@ -393,16 +416,18 @@ class DALIClassificationIterator(DALIGenericIterator):
     next iteration will return ``[2, 3]``
     """

-    def __init__(self,
-                 pipelines,
-                 size=-1,
-                 reader_name=None,
-                 auto_reset=False,
-                 fill_last_batch=None,
-                 dynamic_shape=False,
-                 last_batch_padded=False,
-                 last_batch_policy=LastBatchPolicy.FILL,
-                 prepare_first_batch=True):
+    def __init__(
+        self,
+        pipelines: Union[List[Pipeline], Pipeline],
+        size: int = -1,
+        reader_name: Optional[str] = None,
+        auto_reset: Union[str, bool, None] = False,
+        fill_last_batch: Optional[bool] = None,
+        dynamic_shape: Optional[bool] = False,
+        last_batch_padded: bool = False,
+        last_batch_policy: LastBatchPolicy = LastBatchPolicy.FILL,
+        prepare_first_batch: bool = True,
+    ) -> None:
         super(DALIClassificationIterator, self).__init__(pipelines, ["data", "label"],
                                                          size,
                                                          reader_name=reader_name,
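``DALIClassificationIterator`` hardwires ``output_map`` to ``["data", "label"]``, so the typed interface specializes to those two keys. A brief sketch under the same assumptions as above:

    it = DALIClassificationIterator([pipe], reader_name="Reader")
    for batch in it:
        images: torch.Tensor = batch[0]["data"]
        labels: torch.Tensor = batch[0]["label"]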
@@ -516,18 +541,20 @@ class DALIRaggedIterator(_DaliBaseIterator):
     last batch = ``[5, 6]``, next iteration will return ``[2, 3]``
     """

-    def __init__(self,
-                 pipelines,
-                 output_map,
-                 size=-1,
-                 reader_name=None,
-                 output_types=None,
-                 auto_reset=False,
-                 fill_last_batch=None,
-                 dynamic_shape=False,
-                 last_batch_padded=False,
-                 last_batch_policy=LastBatchPolicy.FILL,
-                 prepare_first_batch=True):
+    def __init__(
+        self,
+        pipelines: Union[List[Pipeline], Pipeline],
+        output_map: List[str],
+        size: int = -1,
+        reader_name: Optional[str] = None,
+        output_types: Optional[List[str]] = None,
+        auto_reset: Union[str, bool, None] = False,
+        fill_last_batch: Optional[bool] = None,
+        dynamic_shape: Optional[bool] = False,
+        last_batch_padded: bool = False,
+        last_batch_policy: LastBatchPolicy = LastBatchPolicy.FILL,
+        prepare_first_batch: bool = True,
+    ) -> None:

         # check the assert first as _DaliBaseIterator would run the prefetch
         self._output_tags = {
@@ -566,7 +593,7 @@ def __init__(self,
             "if `last_batch_policy` is set to PARTIAL and the requested batch size is " \
             "greater than the shard size."

-    def __next__(self):
+    def __next__(self) -> List[Dict[str, torch.Tensor]]:
         self._ever_consumed = True
         if self._first_batch is not None:
             batch = self._first_batch
@@ -705,56 +732,6 @@ def __next__(self):

         return data_batches

-    DENSE_TAG = "dense"
-    SPARSE_LIST_TAG = "sparse_list"
-    SPARSE_COO_TAG = "sparse_coo"
-
-
-class TorchPythonFunction(ops.PythonFunctionBase):
-    schema_name = "TorchPythonFunction"
-    _impl_module = "nvidia.dali.plugin.pytorch"
-    ops.register_cpu_op('TorchPythonFunction')
-    ops.register_gpu_op('TorchPythonFunction')
-
-    def _torch_stream_wrapper(self, function, *ins):
-        with torch.cuda.stream(self.stream):
-            out = function(*ins)
-        self.stream.synchronize()
-        return out
-
-    def torch_wrapper(self, batch_processing, function, device, *args):
-        func = function if device == 'cpu' else \
-               lambda *ins: self._torch_stream_wrapper(function, *ins)
-        if batch_processing:
-            return ops.PythonFunction.function_wrapper_batch(func,
-                                                             self.num_outputs,
-                                                             torch.utils.dlpack.from_dlpack,
-                                                             torch.utils.dlpack.to_dlpack,
-                                                             *args)
-        else:
-            return ops.PythonFunction.function_wrapper_per_sample(func,
-                                                                  self.num_outputs,
-                                                                  torch_dlpack.from_dlpack,
-                                                                  torch_dlpack.to_dlpack,
-                                                                  *args)
-
-    def __call__(self, *inputs, **kwargs):
-        pipeline = Pipeline.current()
-        if pipeline is None:
-            Pipeline._raise_no_current_pipeline("TorchPythonFunction")
-        if self.stream is None:
-            self.stream = torch.cuda.Stream(device=pipeline.device_id)
-        return super(TorchPythonFunction, self).__call__(*inputs, **kwargs)
-
-    def __init__(self, function, num_outputs=1, device='cpu', batch_processing=False, **kwargs):
-        self.stream = None
-        super(TorchPythonFunction, self).__init__(impl_name="DLTensorPythonFunctionImpl",
-                                                  function=lambda *ins:
-                                                  self.torch_wrapper(batch_processing,
-                                                                     function, device,
-                                                                     *ins),
-                                                  num_outputs=num_outputs, device=device,
-                                                  batch_processing=batch_processing, **kwargs)
-
-
-ops._wrap_op(TorchPythonFunction, "fn", __name__)
+    DENSE_TAG: str = "dense"
+    SPARSE_LIST_TAG: str = "sparse_list"
+    SPARSE_COO_TAG: str = "sparse_coo"
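The three tag constants are the accepted values of ``DALIRaggedIterator``'s ``output_types`` argument, annotated above as ``Optional[List[str]]``. Selecting a packing per output might look like this sketch, where the pipeline and output names are assumptions:

    it = DALIRaggedIterator(
        [pipe],
        output_map=["images", "boxes"],
        output_types=[DALIRaggedIterator.DENSE_TAG, DALIRaggedIterator.SPARSE_LIST_TAG],
        reader_name="Reader",
    )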