[RFC] Supporting Eager Mode via torch.compile #115545
Comments
We already have precedent for something like that in sparse, if I'm not mistaken.
@malfet, do you mean the sparse operations have already enabled such a feature? I will double-check.
Currently, I can run some elementwise cases by providing a dedicated registration API.
…t torch.compile-based aten implemantion" This PR is a follow-up of RFC #115545. In this PR, we are trying to provide a registration mode to implement a single aten operation on the top of `torch.compile` and then register to aten. By now, the Python-based aten kernel implementation assumes the hermetic Python object. For `torch.compile`-based aten kernel implementation, the assumption will be broken. Because > While HermeticPyObject was enabled, we attempted to create a tensor subclass with __torch_dispatch__. This violates the invariant that operations in HermeticPyObject have equivalent C++ implementations. [ghstack-poisoned]
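For reference, the follow-up PRs in this thread work around this by excluding the Python dispatch key before invoking the compiled function. A minimal sketch of that pattern, using `aten.mul` as a stand-in op:

```python
import torch

# Exclude the Python dispatch key before calling the compiled function;
# this is the pattern the PR examples below use to avoid re-entrant
# Python dispatch while hermetic mode is active.
def compiled_kernel(*args, **kwargs):
    with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
        opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False)
        return opt_fn(*args, **kwargs)
```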
…based aten implemantion" This PR is a follow-up of RFC #115545. In this PR, we are trying to provide a registration mode to implement a single aten operation on the top of `torch.compile` and then register to aten. By now, the Python-based aten kernel implementation assumes the hermetic Python object. For `torch.compile`-based aten kernel implementation, the assumption will be broken. Because > While HermeticPyObject was enabled, we attempted to create a tensor subclass with __torch_dispatch__. This violates the invariant that operations in HermeticPyObject have equivalent C++ implementations. [ghstack-poisoned]
cc @jbschlosser on hermetic. I'll pipe up here later.
It seems like there are two things going on here:
The main difference between (1) and (2) is whether the final function is an aten operator (2) or something else (1). It sounds like this issue is pursuing (2), but I want to challenge that for a bit. Why can't we just do (1)? Concretely, the proposal would be to have something like:

```python
@faster_compile(fullgraph=True, backend="inductor")
def my_activation(x):
    return torch.where(x > 0, x ** 2, -0.5 * x)

# accessible via torch
torch.my_activation(x)
```
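For concreteness, a minimal sketch of what such a decorator could look like; `faster_compile` is the hypothetical API proposed above, not an existing one:

```python
import torch

def faster_compile(fullgraph=True, backend="inductor"):
    """Hypothetical decorator: compile a function and expose it under torch."""
    def wrap(fn):
        compiled = torch.compile(fn, fullgraph=fullgraph, backend=backend)
        setattr(torch, fn.__name__, compiled)  # make it accessible as torch.<name>
        return compiled
    return wrap
```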
I see: (1) is a superset of (2), where an aten op is a special case of an arbitrary op in (1)? So I guess the answer is yes? In general, I think (1) is a reasonable use case where one wants to implement a custom op efficiently by simply composing a sequence of aten ops in a device-agnostic manner. As an example, we are proposing such a lightweight optimization to bitsandbytes for some quant and dequant ops: TimDettmers/bitsandbytes#894. This would relieve us from manually implementing the native kernels. But the problems mentioned in the RFC still hold, in particular the Python and compilation overhead.
Why do you think option (2) would increase the number of aten operators? We do not intend to increase the number; we just want to support existing aten operators.
In terms of (1), the custom function can be accessible via torch, but I'm kind of concerned it might not be able to support all cases. Suppose a code block is as follows.

```python
import torch

a = torch.randn(10, device="xpu")
b = torch.randn(10, device="xpu")
# (EIKAN: I suppose the "add" here is not routed through torch.)
c = a.add(b)
```

May I know the idea to support "add" without any code change? Or does it mean users need to change the script to something like this?

```python
import torch

a = torch.randn(10, device="xpu")
b = torch.randn(10, device="xpu")

@faster_compile(fullgraph=True, backend="inductor")
def my_add(x, y):
    return torch.add(x, y)

# accessible via torch
c = torch.my_add(a, b)
```
Okay, I think I see what's going on. Is this RFC for using torch.compile to implement new operators in PyTorch, or is it to help alternative backends (like XPU) gain coverage for existing PyTorch operations? Based on the replies, it sounds like torch.compile supports XPU, but we do not support XPU in eager mode, so you want to somehow use torch.compile to add eager-mode support for XPU. Is my understanding correct? If this is the case, why are we registering the implementation to DispatchKey::CPU in the example? I would expect that we only register the torch.compile'd function to a DispatchKey for XPU (if that exists).
Yes
Yes
I am sorry for the confusion caused by taking the CPU dispatch key as an example; XPU support is still a work in progress, so it cannot yet be used to verify the idea. I will update the example code to avoid confusion.
@zou3519, thank you for pointing out the part that led to ambiguity in the idea. I have updated the example accordingly.
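To make the intended direction concrete, here is an illustrative sketch of registering a compiled kernel under a device-specific dispatch key; the "XPU" key name assumes XPU dispatch support exists, and the helper is a sketch rather than the RFC's final API:

```python
import torch

# Illustrative sketch: register a torch.compile-based kernel for an existing
# aten op under a device-specific dispatch key (here "XPU"), rather than
# DispatchKey::CPU as in the earlier example.
lib = torch.library.Library("aten", "IMPL")

def compiled_mul(*args, **kwargs):
    # Exclude the Python dispatch key so the compiled function does not
    # re-enter this registration (same guard as the PR examples below).
    with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
        opt_fn = torch.compile(torch.ops.aten.mul.Tensor, dynamic=False)
        return opt_fn(*args, **kwargs)

lib.impl("mul.Tensor", compiled_mul, "XPU")
```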
There was some related discussion on #116368.
I think #116996 will resolve the immediate tactical problem regarding hermetic mode. However, I still have some questions about the proposed implementation strategy here. In particular, how much overhead per eager-mode call is acceptable? A big reason why most of PyTorch moved into C++ was to reduce eager-mode overhead. With the prototype flow here, we have to dispatch back into Python for each op, and that's going to be expensive. Do you have a target latency here?
I collected the overhead info (CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz; PyTorch commit: e8a9d08). We plan to add a cache mechanism to mitigate the overhead.
Did you actually measure this? Because the overhead is definitely not negligible lol
I used the Python profiler to collect the performance data. The subsequent execution without re-compilation was ~1 ms for a simple case. I used the following script:

```python
import cProfile
import io
import pstats
from pstats import SortKey

import torch

x = torch.randn(10)  # placeholder inputs for the example
y = torch.randn(10)

# Invoke torch.compile (first call triggers compilation)
x - y

sortby = SortKey.CUMULATIVE
with cProfile.Profile() as pr:
    # Invoke torch.compile
    x - y

s = io.StringIO()
ps = pstats.Stats(pr, stream=s).sort_stats(sortby)
ps.print_stats()
print(s.getvalue())
```

I can improve the profiling accuracy to microseconds and then take another look. Meanwhile, I will profile some real workloads, not just a single aten operation.
By the way, there should be no big difference compared to the C++ ATen operation if we add a cache mechanism. Cache hit: Python (user script) -> C++ kernel (the kernel encapsulates a Python function compiled with torch.compile) -> C++ (AOTI).
I don't expect pstats to have the resolution you need. Just run the thing in a loop for N iterations and measure the overall time, and compare it with a CPU tensor.
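A minimal version of the suggested measurement (sizes and iteration counts are illustrative):

```python
import time
import torch

x = torch.randn(1024, 1024)
y = torch.randn(1024, 1024)

n = 10000
for _ in range(1000):  # warm-up, so one-time compilation is excluded
    x - y

beg = time.time()
for _ in range(n):
    x - y
end = time.time()
print(f"avg per call: {(end - beg) / n * 1e6:.2f} us")
```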
Sorry for the late reply. I used a high-resolution timer to measure the overhead on an A10, and yes, the overhead is NOT negligible. @ezyang, thanks for correcting me. We are narrowing down the detailed overhead to get deeper insights and will update here.
A flow that would get good perf is if you register C++ functions to the dispatcher, which directly call into the C++-wrapper-generated kernels. However, there is a missing piece, which is guard dispatch. @anijain2305's C++ guards could be part of the solution here, but there's still some more glue that would need to be implemented, since they aren't intended to be used this way.
…compile-for-eager" This PR is a follow-up of RFC #115545. In this PR, we are trying to enable a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we will store a persistent config to cache the kernel information for the aten operation. The persistent config consists of two parts - meta_info and kernel_path. - meta_info: The input tensors' shape, stride, device type, data type, and symbolic flag. - kernel_path: The path of the kernel produced by Inductor. When an aten operation is registered, the `kernel_holder` will load the persistent config and parse it to build the cache map; the meta_info is key, and the kernel library is the value. Currently, this PR only supports static shape to guard the kernel. Take a `mul` as an example. ```python class MulKernel: def __init__(self) -> None: pass def __call__(self, *args: Any, **kwargs: Any) -> Any: with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False): opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False, options={ "aot_inductor.eager_mode": True, "aot_inductor.eager_op_name": "mul_Tensor" } ) return opt_fn(*args, **kwargs) torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL") _, overload_names = torch._C._jit_get_operation("aten::mul") schema = torch._C._get_schema("aten::mul", overload_name) reg_name = schema.name if schema.overload_name: reg_name = f"{reg_name}.{schema.overload_name}" torch_compile_op_lib_impl.impl( reg_name, MulKernel(), "CUDA", compile_mode=True) a = torch.randn(1024, 1024, device=device) b = torch.randn(1024, 1024, device=device) warm_up_iter = 1000 iter = 10000 fn = torch.mul # Warm up for _ in range(warm_up_iter): fn(a, b) # Collect performance beg = time.time() for _ in range(iter): fn(a, b) end = time.time() print(f"E2E run: {end - beg}") ``` It will produce the config as follows. ```json [ { "meta_info": [ { "is_symbolic": "false", "device_type": "cuda", "dtype": "torch.float32", "sizes": "[1024, 1024]", "strides": "[1024, 1]" }, { "is_symbolic": "false", "device_type": "cuda", "dtype": "torch.float32", "sizes": "[1024, 1024]", "strides": "[1024, 1]" } ], "kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so" } ] ``` Performance-wise, we collected mul.Tensor through torch.compile w/ 10000 runs(e2e). The data is as follows. And we will collect data when we support dynamic shape. - Eager: ~266.11ms - W/O Cache: ~3455.54ms - W/ Cache and Cache Miss: ~3555.3ms - W/ Cache and Cache Hit: ~267.12ms Hardware: - CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz - GPU: CUDA A10 Software: - PyTorch Version: 39df084 - GPU Driver Version: 525.147.05 - CUDA Version: 12.0 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This PR is a follow-up to RFC #115545. It provides a registration API dedicated to eager-through-torch.compile. The major workflow of this API is as follows (see the sketch after this list):

- Load the cache.
- Check the cache against the input tensors.
  - Cache hit: run the cached kernel directly.
  - Cache miss: run AOTI to produce the kernel, then run the produced kernel. If AOTI fails to produce the kernel, invoke the Python fallback function.

Currently, this PR always falls back to the Python kernel; the cache mechanism will be implemented in another PR: #116368.
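A minimal Python sketch of the workflow above; `meta_key`, `kernel_cache`, and `aoti_compile` are illustrative names, not the PR's actual API:

```python
import torch

kernel_cache = {}  # meta_key -> compiled kernel

def meta_key(*tensors):
    # Key on device, dtype, sizes, and strides, mirroring the meta_info
    # described in the cache PR above.
    return tuple(
        (t.device.type, str(t.dtype), tuple(t.size()), tuple(t.stride()))
        for t in tensors
    )

def aoti_compile(op):
    # Hypothetical stand-in for producing a kernel via AOTInductor.
    return torch.compile(op, dynamic=False)

def run(op, fallback_fn, *args):
    key = meta_key(*args)
    kernel = kernel_cache.get(key)      # load/check cache
    if kernel is not None:
        return kernel(*args)            # cache hit
    try:
        kernel = aoti_compile(op)       # cache miss: produce the kernel
        kernel_cache[key] = kernel
        return kernel(*args)
    except Exception:
        return fallback_fn(*args)       # AOTI failed: Python fallback
```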
Pull Request resolved: #121387. Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/zou3519, https://github.com/jgong5
…compile-for-eager" This PR is a follow-up of RFC #115545. In this PR, we are trying to enable a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we will store a persistent config to cache the kernel information for the aten operation. The persistent config consists of two parts - meta_info and kernel_path. - meta_info: The input tensors' shape, stride, device type, data type, and symbolic flag. - kernel_path: The path of the kernel produced by Inductor. When an aten operation is registered, the `kernel_holder` will load the persistent config and parse it to build the cache map; the meta_info is key, and the kernel library is the value. Currently, this PR only supports static shape to guard the kernel. Take a `mul` as an example. ```python class MulKernel: def __init__(self) -> None: pass def __call__(self, *args: Any, **kwargs: Any) -> Any: with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False): opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False, options={ "aot_inductor.eager_mode": True, "aot_inductor.eager_op_name": "mul_Tensor" } ) return opt_fn(*args, **kwargs) torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL") _, overload_names = torch._C._jit_get_operation("aten::mul") schema = torch._C._get_schema("aten::mul", overload_name) reg_name = schema.name if schema.overload_name: reg_name = f"{reg_name}.{schema.overload_name}" torch_compile_op_lib_impl.impl( reg_name, MulKernel(), "CUDA", compile_mode=True) a = torch.randn(1024, 1024, device=device) b = torch.randn(1024, 1024, device=device) warm_up_iter = 1000 iter = 10000 fn = torch.mul # Warm up for _ in range(warm_up_iter): fn(a, b) # Collect performance beg = time.time() for _ in range(iter): fn(a, b) end = time.time() print(f"E2E run: {end - beg}") ``` It will produce the config as follows. ```json [ { "meta_info": [ { "is_symbolic": "false", "device_type": "cuda", "dtype": "torch.float32", "sizes": "[1024, 1024]", "strides": "[1024, 1]" }, { "is_symbolic": "false", "device_type": "cuda", "dtype": "torch.float32", "sizes": "[1024, 1024]", "strides": "[1024, 1]" } ], "kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so" } ] ``` Performance-wise, we collected mul.Tensor through torch.compile w/ 10000 runs(e2e). The data is as follows. And we will collect data when we support dynamic shape. - Eager: ~266.11ms - W/O Cache: ~3455.54ms - W/ Cache and Cache Miss: ~3555.3ms - W/ Cache and Cache Hit: ~267.12ms Hardware: - CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz - GPU: CUDA A10 Software: - PyTorch Version: 39df084 - GPU Driver Version: 525.147.05 - CUDA Version: 12.0 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This PR is a follow-up of RFC #115545. In this PR, we are trying to enable a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we will store a persistent config to cache the kernel information for the aten operation. The persistent config consists of two parts - meta_info and kernel_path. - meta_info: The input tensors' shape, stride, device type, data type, and symbolic flag. - kernel_path: The path of the kernel produced by Inductor. When an aten operation is registered, the `kernel_holder` will load the persistent config and parse it to build the cache map; the meta_info is key, and the kernel library is the value. Currently, this PR only supports static shape to guard the kernel. Take a `mul` as an example. ```python class MulKernel: def __init__(self) -> None: pass def __call__(self, *args: Any, **kwargs: Any) -> Any: with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False): opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False, options={ "aot_inductor.eager_mode": True, "aot_inductor.eager_op_name": "mul_Tensor" } ) return opt_fn(*args, **kwargs) torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL") _, overload_names = torch._C._jit_get_operation("aten::mul") schema = torch._C._get_schema("aten::mul", overload_name) reg_name = schema.name if schema.overload_name: reg_name = f"{reg_name}.{schema.overload_name}" torch_compile_op_lib_impl.impl( reg_name, MulKernel(), "CUDA", compile_mode=True) a = torch.randn(1024, 1024, device=device) b = torch.randn(1024, 1024, device=device) warm_up_iter = 1000 iter = 10000 fn = torch.mul # Warm up for _ in range(warm_up_iter): fn(a, b) # Collect performance beg = time.time() for _ in range(iter): fn(a, b) end = time.time() print(f"E2E run: {end - beg}") ``` It will produce the config as follows. ```json [ { "meta_info": [ { "is_symbolic": "false", "device_type": "cuda", "dtype": "torch.float32", "sizes": "[1024, 1024]", "strides": "[1024, 1]" }, { "is_symbolic": "false", "device_type": "cuda", "dtype": "torch.float32", "sizes": "[1024, 1024]", "strides": "[1024, 1]" } ], "kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so" } ] ``` Performance-wise, we collected mul.Tensor through torch.compile w/ 10000 runs(e2e). The data is as follows. And we will collect data when we support dynamic shape. - Eager: ~266.11ms - W/O Cache: ~3455.54ms - W/ Cache and Cache Miss: ~3555.3ms - W/ Cache and Cache Hit: ~267.12ms Hardware: - CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz - GPU: CUDA A10 Software: - PyTorch Version: 39df084 - GPU Driver Version: 525.147.05 - CUDA Version: 12.0 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This PR is a follow-up of RFC pytorch#115545. In this PR, we intend to provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API will be as follows. - Load cache - Check cache according to the input tensors - Cache Hit: Run the cached kernel directly - Cache Miss: Run the AOTI to produce kernel and run the produced kernel. If AOTI fails to produce the kernel, invoke the python fallback function. Currently, this PR always fallback to python kernel now and cache mechanism will be implemented in another PR - pytorch#116368 Pull Request resolved: pytorch#121387 Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/zou3519, https://github.com/jgong5
This PR is a follow-up of RFC pytorch#115545. In this PR, we intend to provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API will be as follows. - Load cache - Check cache according to the input tensors - Cache Hit: Run the cached kernel directly - Cache Miss: Run the AOTI to produce kernel and run the produced kernel. If AOTI fails to produce the kernel, invoke the python fallback function. Currently, this PR always fallback to python kernel now and cache mechanism will be implemented in another PR - pytorch#116368 Pull Request resolved: pytorch#121387 Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/zou3519, https://github.com/jgong5
This PR is a follow-up of RFC pytorch#115545. In this PR, we intend to provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API will be as follows. - Load cache - Check cache according to the input tensors - Cache Hit: Run the cached kernel directly - Cache Miss: Run the AOTI to produce kernel and run the produced kernel. If AOTI fails to produce the kernel, invoke the python fallback function. Currently, this PR always fallback to python kernel now and cache mechanism will be implemented in another PR - pytorch#116368 Pull Request resolved: pytorch#121387 Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/zou3519, https://github.com/jgong5
This PR is a follow-up of RFC #115545. In this PR, we intend to provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API will be as follows. - Load cache - Check cache according to the input tensors - Cache Hit: Run the cached kernel directly - Cache Miss: Run the AOTI to produce kernel and run the produced kernel. If AOTI fails to produce the kernel, invoke the python fallback function. Currently, this PR always fallback to python kernel now and cache mechanism will be implemented in another PR - #116368 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This PR is a follow-up of RFC #115545. In this PR, we intend to provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API will be as follows. - Load cache - Check cache according to the input tensors - Cache Hit: Run the cached kernel directly - Cache Miss: Run the AOTI to produce kernel and run the produced kernel. If AOTI fails to produce the kernel, invoke the python fallback function. Currently, this PR always fallback to python kernel now and cache mechanism will be implemented in another PR - #116368 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
…compile-for-eager" This PR is a follow-up of RFC #115545. In this PR, we are trying to enable a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we will store a persistent config to cache the kernel information for the aten operation. The persistent config consists of two parts - meta_info and kernel_path. - meta_info: The input tensors' shape, stride, device type, data type, and symbolic flag. - kernel_path: The path of the kernel produced by Inductor. When an aten operation is registered, the `kernel_holder` will load the persistent config and parse it to build the cache map; the meta_info is key, and the kernel library is the value. Currently, this PR only supports static shape to guard the kernel. Take a `mul` as an example. ```python class MulKernel: def __init__(self) -> None: pass def __call__(self, *args: Any, **kwargs: Any) -> Any: with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False): opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False, options={ "aot_inductor.eager_mode": True, "aot_inductor.eager_op_name": "mul_Tensor" } ) return opt_fn(*args, **kwargs) torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL") _, overload_names = torch._C._jit_get_operation("aten::mul") schema = torch._C._get_schema("aten::mul", overload_name) reg_name = schema.name if schema.overload_name: reg_name = f"{reg_name}.{schema.overload_name}" torch_compile_op_lib_impl.impl( reg_name, MulKernel(), "CUDA", compile_mode=True) a = torch.randn(1024, 1024, device=device) b = torch.randn(1024, 1024, device=device) warm_up_iter = 1000 iter = 10000 fn = torch.mul # Warm up for _ in range(warm_up_iter): fn(a, b) # Collect performance beg = time.time() for _ in range(iter): fn(a, b) end = time.time() print(f"E2E run: {end - beg}") ``` It will produce the config as follows. ```json [ { "meta_info": [ { "is_symbolic": false, "device_type": "cuda", "dtype": "torch.float32", "sizes": [1024, 1024], "strides": [1024, 1] }, { "is_symbolic": false, "device_type": "cuda", "dtype": "torch.float32", "sizes": [1024, 1024], "strides": [1024, 1] } ], "kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so" } ] ``` Performance-wise, we collected mul.Tensor through torch.compile w/ 10000 runs(e2e). The data is as follows. And we will collect data when we support dynamic shape. - Eager: ~266.11ms - W/O Cache: ~3455.54ms - W/ Cache and Cache Miss: ~3555.3ms - W/ Cache and Cache Hit: ~267.12ms Hardware: - CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz - GPU: CUDA A10 Software: - PyTorch Version: 39df084 - GPU Driver Version: 525.147.05 - CUDA Version: 12.0 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This PR is a follow-up of RFC #115545. In this PR, we are trying to enable a cache mechanism to accelerate **eager-through-torch.compile**. When **eager-through-torch.compile** is enabled, we will store a persistent config to cache the kernel information for the aten operation. The persistent config consists of two parts - meta_info and kernel_path. - meta_info: The input tensors' shape, stride, device type, data type, and symbolic flag. - kernel_path: The path of the kernel produced by Inductor. When an aten operation is registered, the `kernel_holder` will load the persistent config and parse it to build the cache map; the meta_info is key, and the kernel library is the value. Currently, this PR only supports static shape to guard the kernel. Take a `mul` as an example. ```python class MulKernel: def __init__(self) -> None: pass def __call__(self, *args: Any, **kwargs: Any) -> Any: with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False): opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False, options={ "aot_inductor.eager_mode": True, "aot_inductor.eager_op_name": "mul_Tensor" } ) return opt_fn(*args, **kwargs) torch_compile_op_lib_impl = torch.library.Library("aten", "IMPL") _, overload_names = torch._C._jit_get_operation("aten::mul") schema = torch._C._get_schema("aten::mul", overload_name) reg_name = schema.name if schema.overload_name: reg_name = f"{reg_name}.{schema.overload_name}" torch_compile_op_lib_impl.impl( reg_name, MulKernel(), "CUDA", compile_mode=True) a = torch.randn(1024, 1024, device=device) b = torch.randn(1024, 1024, device=device) warm_up_iter = 1000 iter = 10000 fn = torch.mul # Warm up for _ in range(warm_up_iter): fn(a, b) # Collect performance beg = time.time() for _ in range(iter): fn(a, b) end = time.time() print(f"E2E run: {end - beg}") ``` It will produce the config as follows. ```json [ { "meta_info": [ { "is_symbolic": false, "device_type": "cuda", "dtype": "torch.float32", "sizes": [1024, 1024], "strides": [1024, 1] }, { "is_symbolic": false, "device_type": "cuda", "dtype": "torch.float32", "sizes": [1024, 1024], "strides": [1024, 1] } ], "kernel_path": "/tmp/torchinductor_eikan/e4/ce4jw46i5l2e7v3tvr2pyglpjmahnp7x3hxaqotrvxwoeh5t6qzc.so" } ] ``` Performance-wise, we collected mul.Tensor through torch.compile w/ 10000 runs(e2e). The data is as follows. And we will collect data when we support dynamic shape. - Eager: ~266.11ms - W/O Cache: ~3455.54ms - W/ Cache and Cache Miss: ~3555.3ms - W/ Cache and Cache Hit: ~267.12ms Hardware: - CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz - GPU: CUDA A10 Software: - PyTorch Version: 39df084 - GPU Driver Version: 525.147.05 - CUDA Version: 12.0 cc voznesenskym penguinwu jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng wenzhe-nrv jiayisunx peterbell10 ipiszy yf225 chenyang78 kadeng muchulee8 aakhundov ColinPeppler amjames desertfire chauhang [ghstack-poisoned]
This PR is a follow-up of RFC #115545. In this PR, we intend to provide a registration API dedicated to eager-through-torch.compile. The major workflow of this API will be as follows.

- Load cache
- Check cache according to the input tensors
  - Cache Hit: Run the cached kernel directly
  - Cache Miss: Run AOTI to produce the kernel and run the produced kernel. If AOTI fails to produce the kernel, invoke the Python fallback function.

Currently, this PR always falls back to the Python kernel; the cache mechanism will be implemented in another PR - #116368

Differential Revision: [D57164385](https://our.internmc.facebook.com/intern/diff/D57164385)

Pull Request resolved: #121387
Approved by: https://github.com/desertfire, https://github.com/jansel, https://github.com/zou3519, https://github.com/jgong5
🚀 The feature, motivation and pitch
Motivation
By now, PyTorch has defined more than 2000 operations. Meanwhile, PyTorch users start from eager mode. If a new backend intends to support PyTorch eager mode, the backend, like Intel GPU, has to implement all these operations; otherwise, users may encounter `unimplemented` errors when they run PyTorch on the new backend. This scenario presents two substantial challenges - engineering effort and maintenance effort.

Given these challenges, we propose an alternate technical pathway for eager mode support through torch.compile, offering two main advantages over the traditional implementation approach:

- The backend only needs to implement the operations that cannot be supported by `torch.compile` or where `torch.compile` explicitly falls back to ATen.
- `torch.compile` allows generating performant backend code without the need for a backend-specific implementation of each ATen operation.

Approach
We propose compiling single aten operations using `torch.compile` and registering them in PyTorch on the Python side. This approach simplifies backend implementation and maintenance. Here's an illustrative example:
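For illustration, here is a trimmed variant of the `mul` registration shown in the commit message above - a minimal sketch, not a final API; the dispatch-key guard mirrors that snippet, and `mul.Tensor`/`CUDA` are just one overload and backend:

```python
import torch

class MulKernel:
    def __call__(self, *args, **kwargs):
        # Exclude the Python dispatch key so the compiled region does not
        # re-enter this very kernel when it calls aten.mul again.
        with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.Python, False):
            opt_fn = torch.compile(torch.ops.aten.mul, dynamic=False)
            return opt_fn(*args, **kwargs)

# Route aten::mul.Tensor on CUDA to the torch.compile-produced kernel.
lib = torch.library.Library("aten", "IMPL")
lib.impl("mul.Tensor", MulKernel(), "CUDA")
```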
Detail Design

The `torch.compile` invocation needs to be wrapped as a general Python function and then registered with the torch dispatcher; this mechanism ensures robust functionality. Meanwhile, we need to improve performance by mitigating the Python and `torch.compile` overhead, and we propose a cache mechanism for this purpose. So, the detailed design focuses on registration and caching.
Registration
To trigger `torch.compile` to produce a C++/Triton kernel, we always need to register a Python kernel for each ATen operation, just like the example code above. But we do not always need to invoke that Python kernel once its corresponding `torch.compile`-ed kernel has been produced. In addition, the context switch between Python and C++ introduces extra overhead, such as contention on the Python GIL, and a Python implementation is generally slower than its equivalent C++ implementation.
Therefore, we always prefer to avoid running code in the Python world.
To achieve this goal, we wrap the Python kernel in a C++ function/class, just like what torch has done for other Python kernels (`PythonKernelHolder`): when `torch.Library.Library.impl` is used to register a Python kernel, the Python kernel is wrapped by a `PythonKernelHolder`, which is then cast to a `BoxedFunctor`. We can reuse this mechanism and customize it a little to accelerate the operation by introducing a cache mechanism. The following section elaborates on the detailed design of the cache mechanism.
Suppose a class named `AOTICompatiblePythonKernelHolder` serves this purpose; its major features are as follows.

- Look up the cache to check whether a `torch.compile`-ed kernel is available
- On a cache miss, invoke `torch.compile` to produce the kernel
- Load the kernel produced by `torch.compile` and cache it

Its pseudo-code could be as follows.
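A minimal Python sketch of that pseudo-code, under the cache design described in the following sections; `make_cache_key`, `lookup_cache`, and `store_cache` are hypothetical helpers, and the real holder would live in C++:

```python
import torch

class AOTICompatiblePythonKernelHolder:
    """Routes an aten op to a cached torch.compile-ed kernel when possible."""

    def __init__(self, op_name, python_kernel):
        self.op_name = op_name              # e.g. "aten::mul.Tensor"
        self.python_kernel = python_kernel  # compiled on a cache miss

    def __call__(self, *args, **kwargs):
        key = make_cache_key(args)                # hypothetical helper
        kernel = lookup_cache(self.op_name, key)  # hypothetical helper
        if kernel is not None:
            # Cache hit: run the previously produced kernel directly.
            return kernel(*args, **kwargs)
        # Cache miss: let torch.compile produce the kernel, then cache it.
        compiled = torch.compile(self.python_kernel, dynamic=False)
        store_cache(self.op_name, key, compiled)  # hypothetical helper
        return compiled(*args, **kwargs)
```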
In addition, all the kernels produced by `torch.compile` for eager mode should be wrapped in C++ (`CppWrapper`) and loaded by `AOTIModelContainerRunner` to mitigate the Python overhead.

Cache
The overhead of `torch.compile` is non-negligible (CPU: Intel(R) Xeon(R) Platinum 8260 CPU @ 2.40GHz; PyTorch commit: e8a9d08):

- The first invocation of `torch.compile` in a process: this overhead is a one-time occurrence, triggered only during the first compilation.
- After the `torch.compile` setup, the overhead of executing a specific ATen operation for the first time is approximately 115 milliseconds.

Therefore, we introduce a cache mechanism to mitigate the overhead.
The cache should be a persistent cache: it avoids producing kernel code multiple times for a particular aten operation when the input parameters carry the same meta information as a previous compilation.
For a `Mul` operation, the torch.compile-based kernel with a cache could work as follows from the dispatch perspective.
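One plausible rendering of that dispatch flow, assuming the cache design in this RFC:

```
torch.mul(a, b)
  -> C++ dispatcher selects the registered kernel for aten::mul.Tensor
  -> cache lookup keyed on the inputs' meta info (device, dtype, sizes, strides)
       hit:  run the cached CppWrapper kernel via AOTIModelContainerRunner
       miss: fall back to Python, torch.compile produces the kernel,
             a cache entry is added, then the kernel runs
```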
Cache Key

The parameters of an aten operation are packed as a `torch::jit::stack`. Therefore, we can unpack the parameters one by one and extract information from each `torch::jit::IValue` to constitute the cache key for each operator. In addition, each `torch.compile`-ed kernel is dedicated to a particular aten operation, so the kernel knows the exact semantics of its implementation, and its schema is informative. Based on this information, the cache key could be constituted by the following factors, mirroring the persistent config's meta_info:

- Device type
- Data type
- Tensor sizes
- Tensor strides
- Symbolic flag (whether any dimension is symbolic)

Based on these factors, there are two options to establish the cache key:

- Option 1: Hash all the factors into a single hash key.
- Option 2: Keep the factors as a structured cache entry and compare the data fields directly.
Option 1 is the common practice, but it may introduce additional overhead, as the hash algorithm and hash-key lookup might be time-consuming.

Option 2 introduces additional design and implementation complexity compared to Option 1. However, its overhead should be lower, as it only needs to compare the data fields of a cache entry.

We will evaluate the overhead of both and then determine which option is best. As for complex data structures like tensor lists and integer lists, we will support them gradually.
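A minimal sketch of Option 1 under these assumptions, hashing each input tensor's meta information into a single key; `tensor_meta` and `make_cache_key` are hypothetical names, and the fields mirror the persistent config's meta_info:

```python
import torch

def tensor_meta(t: torch.Tensor):
    # The factors listed above: symbolic flag, device type, data type,
    # sizes, and strides (static shapes assumed for now).
    return (False, t.device.type, str(t.dtype), tuple(t.shape), tuple(t.stride()))

def make_cache_key(args):
    # Option 1: fold all factors into one hash key.
    return hash(tuple(tensor_meta(a) for a in args if isinstance(a, torch.Tensor)))

a = torch.randn(1024, 1024)
b = torch.randn(1024, 1024)
assert make_cache_key((a, b)) == make_cache_key((a.clone(), b.clone()))
```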
Cache Load
Regarding cache loading, there are also two different options:

- Option 1: Load the persistent cache while torch itself is being loaded.
- Option 2: Load the persistent cache upon the first `torch.compile`-based aten operation being invoked.

Loading the persistent cache may introduce additional overhead and take a longer time; therefore, Option 1 may slightly impact the user experience, since torch loading already takes a long time to finish its initialization. A further side effect is that the loading may turn out to be useless, as `torch.compile`-based aten operations may never be invoked.

Compared to Option 1, Option 2 is a trade-off solution; it ensures the cache is always useful for the current process. However, it may impact runtime performance, since the cache is initialized while a model/Python script is running.
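For comparison, a minimal sketch of the lazy variant in Option 2; `load_persistent_cache` is a hypothetical helper:

```python
_cache = None

def get_cache():
    # Option 2: defer loading the persistent cache until the first
    # torch.compile-based aten operation is actually invoked.
    global _cache
    if _cache is None:
        _cache = load_persistent_cache()  # hypothetical: parses the on-disk config
    return _cache
```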
We prefer option 1 from the performance perspective.
Cache Lookup
The cache lookup mechanism depends on the cache key design, and the implementation should be straightforward regardless of which option we take. There are two scenarios to handle - cache hit and cache miss.
- Cache Hit: Convert the `torch::jit::stack` to the input parameters of `AOTIModelContainerRunner` and then return the result, just like the current ATen C++ implementation.
- Cache Miss: Fall back to Python, invoke `torch.compile` to produce the kernel wrapped by `CppWrapper`, load it through `AOTIModelContainerRunner`, and add an entry to the cache.

From the C++ and Python perspective:

- Cache Hit: C++ (dispatcher) -> C++ (kernel loaded by `AOTIModelContainerRunner`)
- Cache Miss: C++ (dispatcher) -> Python (invoke `torch.compile` and add an entry to the cache) -> C++ (torch.compile generated operator)

Cache Store
The cache mechanism generates a unique key for each kernel produced by `torch.compile`. Regardless of whether the cache key is a hash key or a data structure, it will be serialized to disk to accelerate the next bootup, just like Inductor has done for Triton kernel tuning.

Beyond that, we need to highlight how the kernels produced by `torch.compile` are organized. We will create a dedicated directory for each ATen operation. The name combines the qualified name and the overload name, so the directory could be something like `{qualified_name}_{overload_name}`. Take `aten.add` as an example; the directory name could be `aten_add_int`. The motivation is that the operation name then does not need to be part of the cache key, which avoids string comparison.
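Under this naming scheme, the on-disk layout could look as follows (a sketch; file names are illustrative):

```
<cache_root>/
  aten_mul_Tensor/
    config.json   # serialized cache keys mapped to kernel paths
    <hash>.so     # CppWrapper kernel loaded by AOTIModelContainerRunner
  aten_add_int/
    config.json
    <hash>.so
```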
Currently, the default Inductor kernel cache is placed at `/tmp/torchinductor_{account_name}`, which is swept on each boot. To avoid this penalty, we'd prefer to store the cache in a non-temp folder.

Summary
This document delves into the implementation of PyTorch eager mode support via `torch.compile`. Currently, PyTorch has defined over 2000 operations, and for a new backend to support PyTorch eager mode, it must implement all of them, or users might encounter unimplemented errors. To address this challenge, we propose compiling single ATen operations using `torch.compile` and registering them. This approach allows for dynamic compilation and optimization of operations, offering more efficient support for different hardware backends.

The document details the registration mechanism, the cache mechanism, the cache key design, and the steps new backend maintainers need to take.

Overall, this proposal aims to simplify the maintenance of PyTorch backends while enhancing efficiency and user experience. Through this approach, it becomes easier to introduce new hardware backends to PyTorch.
Current Status and Challenges
We are working on this exploration and have enabled the above example by providing an alternative registration API as a proof of concept. We will support more operations, covering both inference and training, to check whether there are more feature gaps.

Besides the feature implementation, there are some challenges - in particular, the first-run compilation overhead and dynamic shape support. Regarding these challenges, we may address them by improving the persistent disk cache of `torch.compile` and its dynamic shape support.

Alternatives
No response
Additional context
Currently, registering a Python kernel for a particular ATen operation triggers the hermetic Python object assumption during ATen operation dispatching.
pytorch/torch/csrc/utils/python_dispatch.cpp
Line 174 in b88be16
However, the assumption will be broken if the Python kernel invokes `torch.compile` to implement its logic. Take the above code snippet as an example: `torch.compile` needs to build a `FakeTensor`, and `FakeTensor` defines `__torch_dispatch__`. This means `check_has_torch_dispatch` will return `True`, but the logic requires it to be `False` to ensure "that operations in HermeticPyObject have equivalent C++ implementations."

pytorch/torch/csrc/autograd/python_variable.cpp
Lines 220 to 229 in b88be16
Therefore, it cannot pass the check below -
pytorch/torch/csrc/autograd/python_variable.cpp
Lines 1962 to 1970 in b88be16
To address the issue, we can provide a dedicated registration API to indicate that a Python kernel invokes `torch.compile` for its implementation.

cc @ezyang @msaroufim @wconstab @bdhirsh @anijain2305 @zou3519