v0.4.10: add simple perf tools
FindDefinition committed Aug 9, 2023
1 parent b816f6f commit 387bd08
Showing 5 changed files with 118 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -1,4 +1,8 @@
 # Changelog
+## [0.4.11] - 2023-08-09
+### Added
+- add simple perf tools
+
 ## [0.4.10] - 2023-06-15
 ### Fixed
 - fix a bug when compiling code with arch < sm_75
20 changes: 13 additions & 7 deletions cumm/inliner/__init__.py
@@ -67,7 +67,7 @@
 from cumm.common import TensorViewKernel
 import enum
 from pccm.utils import get_qualname_of_type
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, ContextManager, Dict, List, Optional, Set, Tuple, Type, Union
 import numpy as np
 from cumm import dtypes, tensorview as tv
 from cumm.common import TensorViewNVRTC, GemmBasic, TensorViewViewClass
@@ -116,7 +116,7 @@ class CUDAMode(enum.Enum):
 def torch_tensor_to_tv(ten,
                        dtype: Optional[int] = None,
                        shape: Optional[List[int]] = None):
-    assert ten.is_contiguous(), "must be contiguous tensor"
+    # assert ten.is_contiguous(), "must be contiguous tensor"
     ptr = ten.data_ptr()
     device = ten.device
     if device.type == "cpu":
@@ -229,14 +229,16 @@ def __init__(self,
                  verbose_path: str = "",
                  measure_time: bool = False,
                  is_cpu: bool = False,
-                 capture_tensor_as_tview: bool = False) -> None:
+                 capture_tensor_as_tview: bool = False,
+                 perf_context: Optional[ContextManager] = None) -> None:
         self.mode = mode
         self.launch = launch
         self.verbose = verbose
         self.verbose_path = verbose_path
         self.measure_time = measure_time
         self.is_cpu = is_cpu
         self.capture_tensor_as_tview = capture_tensor_as_tview
+        self.perf_context = perf_context


_NVRTC_FUNC_NAME = f"{PCCM_INLINE_NAMESPACE}::{PCCM_INLINE_FUNCTION_NAME}"
@@ -285,7 +287,6 @@ def __init__(
         self._remote_addr = remote_addr
 
 
-
     def get_nvrtc_module(self, name: str) -> Optional[CummNVRTCModule]:
         for k, v in self.modules.items():
             if name == k[1]:
@@ -394,7 +395,7 @@ def run_func(self,
                  user_args: Optional[_NVRTCInlineParams] = None):
         assert user_args is not None
         launch = user_args.launch
-        return func.run_kernel(_NVRTC_FUNC_NAME, launch, *args)
+        return func.run_kernel(_NVRTC_FUNC_NAME, launch, *args, perf_context=user_args.perf_context)
 
     def kernel_raw(self,
                    name: str,
@@ -403,11 +404,13 @@ def kernel_raw(self,
                    verbose_path: str = "",
                    disable_cache: bool = False,
                    capture_tensor_as_tview: bool = False,
+                   perf_context: Optional[ContextManager] = None,
                    *,
                    _frame_cnt: int = 2):
         verbose = verbose_path != ""
         user_arg = _NVRTCInlineParams(CUDAMode.KernelRaw, param, verbose,
-                                      verbose_path, capture_tensor_as_tview=capture_tensor_as_tview)
+                                      verbose_path, capture_tensor_as_tview=capture_tensor_as_tview,
+                                      perf_context=perf_context)
         if capture_tensor_as_tview:
             if not isinstance(code, pccm.FunctionCode):
                 code_pccm = pccm.code()
@@ -446,13 +449,16 @@ def kernel_1d(self,
                   verbose_path: str = "",
                   disable_cache: bool = False,
                   capture_tensor_as_tview: bool = False,
+                  perf_context: Optional[ContextManager] = None,
                   *,
                   _frame_cnt: int = 2):
         verbose = verbose_path != ""
         num = int(num)
         user_arg = _NVRTCInlineParams(CUDAMode.Kernel1D,
                                       self.get_1d_param(num, stream=stream),
                                       verbose, verbose_path,
-                                      capture_tensor_as_tview=capture_tensor_as_tview)
+                                      capture_tensor_as_tview=capture_tensor_as_tview,
+                                      perf_context=perf_context)
         additional_args = {
             _CUMM_KERNEL_1D_SIZE_NAME: num,
         }
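The inliner changes above all do the same thing: kernel_raw and kernel_1d accept an optional perf_context, store it in _NVRTCInlineParams, and run_func forwards it to CummNVRTCModule.run_kernel, which enters the context manager around the actual launch. A minimal sketch of the new parameter, assuming the usual inliner setup and $-capture syntax; the function name and kernel body are illustrative:

from cumm import tensorview as tv
from cumm.common import TensorViewNVRTC
from cumm.inliner import NVRTCInlineBuilder
from cumm.perftools import perf_context

inliner = NVRTCInlineBuilder([TensorViewNVRTC])

def double_inplace(ten: tv.Tensor):
    # the context manager is entered around the launch inside run_kernel,
    # so the CUDA events recorded by perf_context bracket the kernel itself
    inliner.kernel_1d("double_inplace", ten.dim(0), 0, f"""
    $ten[i] *= 2.0f;
    """, perf_context=perf_context("double_inplace"))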
91 changes: 91 additions & 0 deletions cumm/perftools.py
@@ -0,0 +1,91 @@
import contextlib
from typing import Dict, List, Optional, Tuple
from cumm import tensorview as tv

import contextvars


class PerfContext:
    def __init__(self) -> None:
        self._ns_stack: List[str] = []
        self._measures: Dict[Tuple[str, ...], List[Tuple[tv.CUDAEvent,
                                                          tv.CUDAEvent]]] = {}
        self._print_pair: List[Tuple[str, float]] = []
        self.perf_result: Dict[Tuple[str, ...], List[float]] = {}
        self._enable_controlled_by_root: Optional[bool] = None


# holds the currently-active PerfContext so nested scopes can share it
PERF_CONTEXT: contextvars.ContextVar[
    Optional[PerfContext]] = contextvars.ContextVar("perf_context",
                                                    default=None)


@contextlib.contextmanager
def __enter_perf_context(perf_ctx: PerfContext):
    token = PERF_CONTEXT.set(perf_ctx)
    try:
        yield perf_ctx
    finally:
        PERF_CONTEXT.reset(token)


@contextlib.contextmanager
def perf_context(name: str,
                 *,
                 stream: int = 0,
                 enable: bool = True,
                 print_result: bool = True,
                 control_child_enable: bool = False):
    ctx = PERF_CONTEXT.get()
    enter_null = contextlib.nullcontext()
    is_root = False
    root_key = None
    if ctx is None:
        # root scope: create the shared context and publish it
        ctx = PerfContext()
        if control_child_enable:
            ctx._enable_controlled_by_root = enable
        is_root = True
        root_key = (name, )
        enter_null = __enter_perf_context(ctx)
    if ctx._enable_controlled_by_root is not None:
        enable = ctx._enable_controlled_by_root
    if not enable:
        yield None
        return
    ctx._ns_stack.append(name)
    root_time = 1
    try:
        with enter_null:
            ev_start = tv.CUDAEvent("")
            ev_stop = tv.CUDAEvent("")
            ev_start.record(stream)
            yield ctx
            ev_stop.record(stream)
            key = tuple(ctx._ns_stack)
            if key not in ctx._measures:
                ctx._measures[key] = []
            ctx._measures[key].append((ev_start, ev_stop))

    finally:
        ctx._ns_stack.pop()
        if is_root:
            # root scope exit: sync all recorded events and report
            all_times: Dict[Tuple[str, ...], List[float]] = {}
            for key, data in ctx._measures.items():
                for pair in data:
                    pair[0].sync()
                    pair[1].sync()
                times = [tv.CUDAEvent.duration(x[0], x[1]) for x in data]
                all_times[key] = times
                if key == root_key:
                    root_time = times[0]
            ctx.perf_result = all_times
            ctx._measures.clear()
            if print_result:
                for key, data in all_times.items():
                    time = sum(data, 0)
                    if len(key) > 1:
                        print(
                            f"[{key[-1]}@{len(data)}]({(time / root_time) * 100:.3f}%): {time:.4}"
                        )
                    else:
                        print(f"[{key[-1]}@{len(data)}]: {time:.4}")
10 changes: 7 additions & 3 deletions cumm/tensorview/__init__.py
@@ -17,7 +17,7 @@
 from enum import Enum
 import enum
 from functools import partial
-from typing import Callable, Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union, ContextManager
 
 import numpy as np
 from pccm import Argument
@@ -247,7 +248,8 @@ def run_kernel_unchecked(self, name: str, launch: LaunchParam, *args: Tuple[Tens
 
     def run_kernel(self, name: str, launch: LaunchParam,
                    *args: Union[Tensor, int, float, List[int], List[float],
-                                Tuple[float, ...], Tuple[int, ...]]):
+                                Tuple[float, ...], Tuple[int, ...]],
+                   perf_context: Optional[ContextManager] = None):
         metas: List[NVRTCArgMeta] = [NVRTCArgMeta(NVRTCArgBaseType.Scalar, False, -1, [])] * len(args)
         if self.name_to_meta:
             assert name in self.name_to_meta, f"can't find your kernel {name}, available: {self.name_to_meta.keys()}"
@@ -316,7 +317,10 @@ def run_kernel(self, name: str, launch: LaunchParam,
             else:
                 assert isinstance(arg, Tensor)
                 kernel_args.append((arg, _NVRTCModule.kTensor))
-
+        if perf_context is not None:
+            with perf_context:
+                return self._mod.run_kernel(name, launch.blocks, launch.threads,
+                                            launch.smem, launch.stream, kernel_args)
         return self._mod.run_kernel(name, launch.blocks, launch.threads,
                                     launch.smem, launch.stream, kernel_args)

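run_kernel only asks for some ContextManager, not specifically a cumm PerfContext, so any profiling scope can be dropped in. A sketch with a hypothetical host-side timer (wall_timer is not part of cumm, and because launches are asynchronous it measures launch overhead only, unless the stream is synchronized inside the scope):

import contextlib
import time

@contextlib.contextmanager
def wall_timer(label: str):
    # hypothetical helper, not part of cumm
    start = time.perf_counter()
    try:
        yield
    finally:
        print(f"{label}: {(time.perf_counter() - start) * 1e3:.3f} ms")

# mod.run_kernel("my_kernel", launch, *args, perf_context=wall_timer("my_kernel"))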
5 changes: 3 additions & 2 deletions test/cutest_nvrtc.py
@@ -5,6 +5,7 @@
 from cumm.inliner import NVRTCInlineBuilder
 from cumm.common import TensorView, TensorViewCPU, TensorViewNVRTCHashKernel, TensorViewArrayLinalg, EigenLib
 import numpy as np
+from cumm.perftools import perf_context
 # @lineprof.lineprof_wrapper_cpp
 def test_nvrtc():
 
@@ -87,7 +88,7 @@ def test_nvrtc2():
     table.insert(5, 1);
     table2.insert(5, 1);
     """)
 
     inliner.kernel_1d("wtf3", 1, 0, f"""
     namespace op = tv::arrayops;
     tv::array<float, 3> a{{2.010012, 0.530250, 0.630409}};
@@ -108,7 +109,7 @@ def test_nvrtc2():
     tv::array_nd<float, 4, 4> imu2enu{{}};
     auto corners2 = corners.op<op::transform_3d>(imu2enu);
-    """)
+    """, perf_context=perf_context("wtf3"))
 
     print(a.cpu().numpy())
 
