Skip to content

Commit

Permalink
small changes for nvrtc
Browse files Browse the repository at this point in the history
  • Loading branch information
FindDefinition committed Jan 29, 2024
1 parent 20ac6f9 commit 9e11bf9
Show file tree
Hide file tree
Showing 12 changed files with 212 additions and 18 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
# Changelog
## [0.5.2] - 2024-01-02
### Added
- add `run_in_process` support for the inliner to debug unrecoverable CUDA errors such as invalid memory access (700) without restarting the whole process. This option copies all tensor data to the CPU, sends it to a child process (spawn mode), runs the kernel in the child process, and copies the results back through the CPU to the main process. This slows down execution, but it is very useful for debugging.
- add macro `TV_ASSERT_WITH_PRINT` to print inside an assert.
- change the inliner function name to use the user-provided name for easier debugging.

## [0.5.1] - 2023-12-26
### Fixed
- fix a small bug in `mp_helper.h`
Expand Down
8 changes: 8 additions & 0 deletions cumm/core_cc/tensorview_bind.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,14 @@ class Tensor:
def copy_(self, other: "Tensor", ctx: Context) -> None:
...

@overload
def copy_storage_(self, other: "Tensor") -> None:
...

@overload
def copy_storage_(self, other: "Tensor", ctx: Context) -> None:
...

@overload
def zero_(self) -> "Tensor":
...
Expand Down
32 changes: 24 additions & 8 deletions cumm/inliner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,9 @@
"""

import contextlib
import re
import pccm
from pccm.builder.inliner import InlineBuilder, InlineBuilderPlugin, PCCM_INLINE_NAMESPACE, PCCM_INLINE_FUNCTION_NAME
from pccm.builder.inliner import InlineBuilder, InlineBuilderPlugin, PCCM_INLINE_NAMESPACE, PCCM_INLINE_FUNCTION_NAME, PCCM_INLINE_FUNCTION_NAME_FORMAT
from cumm.nvrtc import CummLLVMModule, CummNVRTCModule, CummNVRTCModuleBase, create_nvrtc_code
from pathlib import Path
from cumm.common import TensorViewKernel
Expand Down Expand Up @@ -104,6 +105,7 @@ def _cached_get_torch_dtype_to_tv():
torch.int8: tv.int8,
torch.int16: tv.int16,
torch.uint8: tv.uint8,
torch.bfloat16: tv.bfloat16,
})
return _TORCH_DTYPE_TO_TV

Expand Down Expand Up @@ -230,7 +232,8 @@ def __init__(self,
measure_time: bool = False,
is_cpu: bool = False,
capture_tensor_as_tview: bool = False,
perf_context: Optional[ContextManager] = None) -> None:
perf_context: Optional[ContextManager] = None,
run_in_process: bool = False) -> None:
self.mode = mode
self.launch = launch
self.verbose = verbose
Expand All @@ -239,9 +242,11 @@ def __init__(self,
self.is_cpu = is_cpu
self.capture_tensor_as_tview = capture_tensor_as_tview
self.perf_context = perf_context
self.run_in_process = run_in_process


_NVRTC_FUNC_NAME = f"{PCCM_INLINE_NAMESPACE}::{PCCM_INLINE_FUNCTION_NAME}"
_NVRTC_FUNC_NAME_FORMAT = f"{PCCM_INLINE_NAMESPACE}::{PCCM_INLINE_FUNCTION_NAME_FORMAT}"

_DEFAULT_KERNEL_PLUGINS: Dict[str, InlineBuilderPlugin] = {
"numpy.ndarray": NumpyPlugin(),
Expand Down Expand Up @@ -298,7 +303,7 @@ def get_nvrtc_module(self, name: str) -> Optional[CummNVRTCModule]:
def get_nvrtc_kernel_attrs(self, name: str) -> Dict[str, int]:
nvrtc_mod = self.get_nvrtc_module(name)
assert nvrtc_mod is not None
return nvrtc_mod.get_kernel_attrs(nvrtc_mod.get_lowered_name(_NVRTC_FUNC_NAME))
return nvrtc_mod.get_kernel_attrs(nvrtc_mod.get_lowered_name(_NVRTC_FUNC_NAME_FORMAT.format(name)))

def get_save_root(self,
path: Path,
Expand Down Expand Up @@ -398,13 +403,21 @@ def build(self,

return mod

def _get_nvrtc_inline_func_name_for_debug(self, name: str):
    """Return the namespaced inline function name built from ``name``.

    Every character of ``name`` that is not an ASCII letter or digit is
    replaced with ``_`` before the result is substituted into
    ``_NVRTC_FUNC_NAME_FORMAT``.
    """
    sanitized = re.sub('[^0-9a-zA-Z]', '_', name)
    return _NVRTC_FUNC_NAME_FORMAT.format(sanitized)

def run_func(self,
name: str,
func: CummNVRTCModuleBase,
*args,
user_args: Optional[_NVRTCInlineParams] = None):
assert user_args is not None
launch = user_args.launch
return func.run_kernel(_NVRTC_FUNC_NAME, launch, *args, perf_context=user_args.perf_context)
if user_args.run_in_process:
return func.run_kernel_in_spawn_process(self._get_nvrtc_inline_func_name_for_debug(name), launch, *args)
else:
return func.run_kernel(self._get_nvrtc_inline_func_name_for_debug(name), launch, *args, perf_context=user_args.perf_context)


def kernel_raw(self,
name: str,
Expand All @@ -414,12 +427,13 @@ def kernel_raw(self,
disable_cache: bool = False,
capture_tensor_as_tview: bool = False,
perf_context: Optional[ContextManager] = None,
run_in_process: bool = False,
*,
_frame_cnt: int = 2):
verbose = verbose_path != ""
user_arg = _NVRTCInlineParams(CUDAMode.KernelRaw, param, verbose,
verbose_path, capture_tensor_as_tview=capture_tensor_as_tview,
perf_context=perf_context)
perf_context=perf_context, run_in_process=run_in_process)
if capture_tensor_as_tview:
if not isinstance(code, pccm.FunctionCode):
code_pccm = pccm.code()
Expand Down Expand Up @@ -459,15 +473,17 @@ def kernel_1d(self,
disable_cache: bool = False,
capture_tensor_as_tview: bool = False,
perf_context: Optional[ContextManager] = None,
run_in_process: bool = False,
*,
_frame_cnt: int = 2):
_frame_cnt: int = 2,
maximum_1d_threads: Optional[int] = None):
verbose = verbose_path != ""
num = int(num)
user_arg = _NVRTCInlineParams(CUDAMode.Kernel1D,
self.get_1d_param(num, stream=stream),
self.get_1d_param(num, stream=stream, maximum_1d_threads=maximum_1d_threads),
verbose, verbose_path,
capture_tensor_as_tview=capture_tensor_as_tview,
perf_context=perf_context)
perf_context=perf_context, run_in_process=run_in_process)
additional_args = {
_CUMM_KERNEL_1D_SIZE_NAME: num,
}
Expand Down
56 changes: 55 additions & 1 deletion cumm/tensorview/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,15 @@
from enum import Enum
import enum
from functools import partial
import io
import traceback
from typing import Callable, Dict, List, Optional, Tuple, Union, ContextManager

import numpy as np
from pccm import Argument
from pccm.middlewares.pybind import (TemplateTypeStmt,
_simple_template_type_parser)

import multiprocessing
from cumm.core_cc import tensorview_bind
from cumm.core_cc.tensorview_bind import CUDAKernelTimer, CUDAEvent, Context
from cumm.core_cc.tensorview_bind import NVRTCModule as _NVRTCModule
Expand Down Expand Up @@ -245,6 +247,58 @@ def run_kernel_unchecked(self, name: str, launch: LaunchParam, *args: Tuple[Tens
return self._mod.run_kernel(name, launch.blocks, launch.threads,
launch.smem, launch.stream, list(args))

def run_kernel_in_spawn_process(self, name: str, launch: LaunchParam,
               *args: Union[Tensor, int, float, List[int], List[float],
                            Tuple[float, ...], Tuple[int, ...]], timeout: Optional[float] = None):
    """Run kernel ``name`` in a child process started with the "spawn" method.

    Each ``Tensor`` argument is converted with ``arg.cpu().numpy()`` and sent
    to the child together with the non-tensor arguments. The child posts
    either the (possibly modified) argument list or a traceback string on a
    queue. On success, the returned arrays are written back into the
    original tensors with ``copy_storage_``.

    Args:
        name: kernel name forwarded to ``run_kernel`` in the child.
        launch: launch parameters forwarded to the child.
        *args: kernel arguments; only ``Tensor`` instances are copied back.
        timeout: maximum number of seconds to wait for the child's result
            (and afterwards for the child to exit). ``None`` waits
            indefinitely, but a child that dies without posting a result is
            still detected.

    Raises:
        ValueError: the child reported an exception; the message carries
            the child's traceback.
        RuntimeError: the child exited without posting any result.
        TimeoutError: no result arrived within ``timeout`` seconds; the
            child is terminated before raising.
    """
    # Function-scope imports keep this block self-contained.
    import queue
    import time

    ctx = multiprocessing.get_context("spawn")
    arg_kernel_meta = [
        (arg.cpu().numpy(), True) if isinstance(arg, Tensor) else (arg, False)
        for arg in args
    ]
    ret_queue = ctx.Queue()
    proc = ctx.Process(target=self._run_kernel_in_spawn_process_func, args=(ret_queue, name, launch, arg_kernel_meta))
    proc.daemon = True
    proc.start()
    deadline = None if timeout is None else time.monotonic() + timeout
    # The queue must be drained before joining the process to avoid a
    # deadlock. Poll instead of blocking forever so that a crashed child or
    # a deadlocked kernel cannot hang the parent.
    while True:
        try:
            res = ret_queue.get(timeout=0.1)
            break
        except queue.Empty:
            pass
        if not proc.is_alive():
            # The child may have posted its result right before exiting;
            # try once more before declaring it lost.
            try:
                res = ret_queue.get(timeout=1.0)
                break
            except queue.Empty:
                raise RuntimeError(
                    f"kernel process for {name!r} exited with code "
                    f"{proc.exitcode} without returning a result") from None
        if deadline is not None and time.monotonic() >= deadline:
            proc.terminate()
            proc.join()
            raise TimeoutError(
                f"kernel process for {name!r} produced no result within "
                f"{timeout} seconds")
    proc.join(timeout)
    if isinstance(res, str):
        raise ValueError(f"Error, traceback: \n{res}")
    for arg, (arg_np, _) in zip(args, res):
        if isinstance(arg, Tensor):
            # use copy_storage_ to avoid strided tensor copy which
            # isn't supported
            arg.copy_storage_(from_numpy(arg_np))


def _run_kernel_in_spawn_process_func(self, q: multiprocessing.Queue, name: str, launch: LaunchParam,
        arg_proc_metas: List[Tuple[Union[np.ndarray, int, float, List[int], List[float],
                                         Tuple[float, ...], Tuple[int, ...]], bool]]):
    """Process target used by ``run_kernel_in_spawn_process``.

    Uploads every array flagged as a tensor with ``from_numpy(...).cuda()``,
    runs the kernel, writes the tensor contents back into the same arrays
    and puts ``arg_proc_metas`` on ``q``. If anything raises while running
    the kernel or copying back, the formatted traceback string is put on
    ``q`` instead.
    """
    # NOTE(review): presumably the caller's stream is not usable in this
    # process, hence the reset to 0 - confirm.
    launch.stream = 0
    kernel_args = []
    host_device_pairs = []
    for value, is_tensor in arg_proc_metas:
        if not is_tensor:
            kernel_args.append(value)
            continue
        assert isinstance(value, np.ndarray)
        device_tensor = from_numpy(value).cuda()
        host_device_pairs.append((value, device_tensor))
        kernel_args.append(device_tensor)
    try:
        self.run_kernel(name, launch, *kernel_args)
        for host_array, device_tensor in host_device_pairs:
            host_array[:] = device_tensor.cpu().numpy()
        q.put(arg_proc_metas)
    except BaseException:
        # Catch-all on purpose: whatever goes wrong is reported to the
        # waiting parent as a traceback string.
        q.put(traceback.format_exc())

def run_kernel(self, name: str, launch: LaunchParam,
*args: Union[Tensor, int, float, List[int], List[float],
Tuple[float, ...], Tuple[int, ...]],
Expand Down
3 changes: 3 additions & 0 deletions cumm/tensorview_bind.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,6 +612,9 @@ def bind_tensorview(self):
.def("copy_", [](tv::Tensor& t, const tv::Tensor& other, tv::Context ctx) -> void{
t.copy_(other, ctx);
}, py::arg("other"), py::arg("ctx") = tv::Context())
.def("copy_storage_", [](tv::Tensor& t, const tv::Tensor& other, tv::Context ctx) -> void{
t.copy_storage_(other, ctx);
}, py::arg("other"), py::arg("ctx") = tv::Context())
.def("copy_2d_pitched_", [](tv::Tensor& t, const tv::Tensor& other, tv::Context ctx) -> void{
t.copy_2d_pitched_(other, ctx);
}, py::arg("other"), py::arg("ctx") = tv::Context())
Expand Down
8 changes: 8 additions & 0 deletions cumm/tensorview_bind_anno.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@ class Tensor:
def copy_(self, other: "Tensor", ctx: Context) -> None:
...

@overload
def copy_storage_(self, other: "Tensor") -> None:
...

@overload
def copy_storage_(self, other: "Tensor", ctx: Context) -> None:
...

@overload
def zero_(self) -> "Tensor":
...
Expand Down
3 changes: 1 addition & 2 deletions include/tensorview/core/printf2.h
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ template <class T, size_t N> struct type_to_format<T[N]> {
};

template <size_t N> struct type_to_format<const char[N]> {
static constexpr auto value = type_to_format<const char*>::value;
static constexpr auto value = type_to_format<const char *>::value;
};

template <char Sep, class... Ts> struct types_to_format;
Expand Down Expand Up @@ -217,7 +217,6 @@ TV_HOST_DEVICE_INLINE void printf2_array(T const (&arg)[N], Ts &&...args) {
std::forward<Ts>(args)...);
}


template <char Sep = ' ', unsigned Tx = 0, class... Ts>
TV_HOST_DEVICE_INLINE void printf2_array_once(Ts &&...args) {
#if defined(__CUDA_ARCH__)
Expand Down
8 changes: 4 additions & 4 deletions include/tensorview/cuda/launch.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,27 +30,27 @@ constexpr int CUDA_MAX_GRID = 65535;

// CUDA: number of blocks for threads.

inline int getNumThreads(const int N) {
inline int64_t getNumThreads(const int64_t N) {
if (N > CUDA_NUM_THREADS) {
return CUDA_NUM_THREADS;
}
return DivUp(N, 32) * 32;
}

template <size_t MaxNumThreads> inline int getNumThreadsEx(const int N) {
template <size_t MaxNumThreads> inline int64_t getNumThreadsEx(const int64_t N) {
if (N > MaxNumThreads) {
return MaxNumThreads;
}
return DivUp(N, 32) * 32;
}

inline int getBlocks(const int N) {
inline int64_t getBlocks(const int64_t N) {
TV_ASSERT_RT_ERR(N > 0,
"CUDA kernel launch blocks must be positive, but got N=", N);
return DivUp(N, getNumThreads(N));
}

template <size_t MaxNumThreads> inline int getBlocksEx(const int N) {
template <size_t MaxNumThreads> inline int64_t getBlocksEx(const int64_t N) {
TV_ASSERT_RT_ERR(N > 0,
"CUDA kernel launch blocks must be positive, but got N=", N);
return DivUp(N, getNumThreadsEx<MaxNumThreads>(N));
Expand Down
68 changes: 68 additions & 0 deletions include/tensorview/tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,8 @@ enum DeviceType {
kDeviceCUDA = 0,
};

class Tensor;

namespace detail {

using dtype_collection_t =
Expand Down Expand Up @@ -152,7 +154,9 @@ using all_int_tensor_types_t =
std::tuple<int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t,
uint64_t>;


template <typename T> class TensorStorage {
friend Tensor;
public:
TensorStorage(size_t size, int device = -1, bool managed = false,
bool pinned = false)
Expand Down Expand Up @@ -1340,6 +1344,70 @@ struct Tensor {
}
}

// Copy the entire storage buffer of `tensor` into this tensor's storage.
// Works on the storage pointers directly: the shapes are only reported in
// the size-mismatch message, not compared. Requires this tensor to pass
// writable_check(), both tensors to be non-empty and both storages to
// report the same size(). The copy routine is picked from the device() of
// each side (-1 is handled as host, >= 0 as CUDA device); when `ctx`
// carries a CUDA stream it is forwarded to the CUDA copy helpers.
void copy_storage_(const Tensor &tensor, Context ctx = Context()) {
  writable_check();
  TV_ASSERT_RT_ERR(!this->empty() && !tensor.empty(), "must not empty");
  TV_ASSERT_RT_ERR(this->storage_->size() == tensor.storage_->size(), "storage must have same size", this->shape(), tensor.shape(), this->storage_->size(), tensor.storage_->size());
  auto dst = this->storage_->ptr_;
  auto src = tensor.storage_->ptr_;
  const auto count = this->storage_->size();
  const bool dst_is_cpu = this->device() == -1;
  const bool src_is_cpu = tensor.device() == -1;
  if (dst_is_cpu && src_is_cpu) {
#ifdef TV_CUDA
    // use memcpy instead to avoid cuda context init
    if (!this->pinned()) {
      std::copy(src, src + count, dst);
    } else if (ctx.has_cuda_stream()) {
      host2host(dst, src, count, ctx.cuda_stream());
    } else {
      host2host(dst, src, count);
    }
#else
    std::copy(src, src + count, dst);
#endif
    return;
  }
#ifdef TV_CUDA
  const bool dst_is_gpu = this->device() >= 0;
  const bool src_is_gpu = tensor.device() >= 0;
  if (dst_is_gpu && src_is_cpu) {
    if (ctx.has_cuda_stream()) {
      host2dev(dst, src, count, ctx.cuda_stream());
    } else {
      host2dev(dst, src, count);
    }
    return;
  }
  if (dst_is_cpu && src_is_gpu) {
    if (ctx.has_cuda_stream()) {
      dev2host(dst, src, count, ctx.cuda_stream());
    } else {
      dev2host(dst, src, count);
    }
    return;
  }
  if (dst_is_gpu && src_is_gpu) {
    if (ctx.has_cuda_stream()) {
      dev2dev(dst, src, count, ctx.cuda_stream());
    } else {
      dev2dev(dst, src, count);
    }
    return;
  }
#endif
  // Reached for any device combination not handled above.
  TV_THROW_RT_ERR("only support cpu tensor");
}


void copy_2d_pitched_(const Tensor &tensor, Context ctx = Context()) {
writable_check();
TV_ASSERT_RT_ERR(!this->empty() && !tensor.empty(), "must not empty");
Expand Down
5 changes: 3 additions & 2 deletions test/cutest_nvrtc_dev.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ def test_nvrtc_std():

inliner.kernel_1d("nvrtc_std", 1, 0,
f"""
assert(int(i) < 0);
float x = 0.356632;
float y = 0.346854;
float z = 0.998650;
Expand All @@ -24,9 +25,9 @@ def test_nvrtc_std():
""")
print(inliner.get_nvrtc_kernel_attrs("nvrtc_std"))
# print(inliner.get_nvrtc_kernel_attrs("nvrtc_std"))

mod = CummNVRTCModule([TensorViewNVRTCDev()], verbose=True)
# mod = CummNVRTCModule([TensorViewNVRTCDev()], verbose=True)


if __name__ == "__main__":
Expand Down
Loading

0 comments on commit 9e11bf9

Please sign in to comment.