[halide-backend] Add GPU support
ghstack-source-id: 7a791d8e523fc31585d6cb8b295b8c66e33dc32b
Pull Request resolved: pytorch#127506
jansel committed Jun 21, 2024
1 parent 051724d commit 6ca7919
Showing 7 changed files with 90 additions and 17 deletions.
11 changes: 9 additions & 2 deletions test/inductor/test_halide.py
@@ -29,8 +29,11 @@


 make_halide = config.patch(
-    cpu_backend="halide",
-    fallback_random=True,  # TODO(jansel): support random
+    {
+        "cpu_backend": "halide",
+        "cuda_backend": "halide",
+        "fallback_random": True,  # TODO(jansel): support random
+    }
 )


@@ -111,6 +114,10 @@ def generate(g):
 SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
 CpuHalideTests = make_halide(test_torchinductor.CpuTests)

+if test_torchinductor.HAS_GPU:
+    SweepInputsGPUHalideTest = make_halide(test_torchinductor.SweepInputsGPUTest)
+    GPUHalideTests = make_halide(test_torchinductor.GPUTests)
+
 if __name__ == "__main__":
     if HAS_CPU and not IS_MACOS and HAS_HALIDE:
         run_tests(needs="filelock")
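As a usage note (not part of the diff): `config.patch` accepts a dict of option names as well as keyword arguments, which is what lets a single patch set both `cpu_backend` and `cuda_backend` here. A minimal sketch of compiling through the Halide backends with the same patch, assuming a working Halide installation (`HAS_HALIDE`); the function and shapes are illustrative:

    import torch
    import torch._inductor.config as inductor_config

    # Same patch the test harness above applies, in dict form.
    with inductor_config.patch(
        {"cpu_backend": "halide", "cuda_backend": "halide", "fallback_random": True}
    ):
        fn = torch.compile(lambda x: torch.relu(x) + 1)
        out = fn(torch.randn(64, 64))  # codegen goes through the Halide CPU backend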
9 changes: 6 additions & 3 deletions test/inductor/test_torchinductor.py
Expand Up @@ -749,7 +749,9 @@ def is_cpp_backend(device):


def is_halide_backend(device):
return getattr(device, "type", device) == "cpu" and config.cpu_backend == "halide"
if getattr(device, "type", device) == "cpu":
return config.cpu_backend == "halide"
return config.cuda_backend == "halide"


def skip_if_halide(fn):
@@ -4602,7 +4604,7 @@ def forward(self, l_input_: torch.Tensor):
         mod = Repro().to(device=GPU_TYPE)
         o1 = mod(inp)
         o2 = torch.compile(mod)(inp)
-        self.assertEqual(o1, o2)
+        self.assertEqual(o1, o2, rtol=1e-3, atol=1e-3)

     @patch.object(config.trace, "enabled", True)
     def test_layer_norm(self):
@@ -8415,6 +8417,7 @@ def forward(arg38_1, arg81_1, getitem_17, new_zeros_default_4):
         self.common(forward, args, atol=1e-5, rtol=1e-5)

     @requires_gpu()
+    @skip_if_halide  # cascading accuracy issues due to rsqrt fallback
     def test_tmp_not_defined_issue3(self):
         from torch import device

@@ -9863,7 +9866,7 @@ def fn(x: torch.Tensor) -> torch.Tensor:
         inp2 = torch.as_strided(base2, (64, 64), (64, 1), offset2)
         ref2 = fn(inp2)
         res2 = fn_c(inp2)
-        self.assertEqual(ref2, res2)
+        self.assertEqual(ref2, res2, atol=1e-5, rtol=1e-5)

     @requires_gpu()
     @config.patch(assume_aligned_inputs=False)
4 changes: 2 additions & 2 deletions torch/_inductor/codegen/common.py
Expand Up @@ -209,9 +209,10 @@ def init_backend_registration():

if get_scheduling_for_device("cuda") is None:
# CUDACombinedScheduling combines Triton and CUDA C++ scheduling for CUDA devices via delegation
cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}
register_backend_for_device(
"cuda",
CUDACombinedScheduling,
lambda *args, **kwargs: cuda_backends[config.cuda_backend](*args, **kwargs),
WrapperCodeGen,
CppWrapperCuda,
)
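The lambda registered above defers the Triton-vs-Halide choice until a scheduling object is actually constructed, so patching `config.cuda_backend` after `init_backend_registration()` has run still takes effect. A rough, named equivalent of the same idea (not the literal implementation; import paths are the usual locations as of this commit and should be treated as an assumption):

    from torch._inductor import config
    from torch._inductor.codegen.cuda_combined_scheduling import CUDACombinedScheduling
    from torch._inductor.codegen.halide import HalideScheduling

    cuda_backends = {"triton": CUDACombinedScheduling, "halide": HalideScheduling}

    def _cuda_scheduling_factory(*args, **kwargs):
        # Looked up at call time rather than registration time, so
        # config.patch(cuda_backend="halide") applies to later compilations.
        return cuda_backends[config.cuda_backend](*args, **kwargs)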
@@ -1340,7 +1341,6 @@ def python_argdefs(self):
             arg_defs.append("ws_ptr")
             call_args.append("workspace")
             precompile_args.append(self.workspace_arg)
-
         return arg_defs, call_args, precompile_args, arg_types

     def aliases(self):
57 changes: 55 additions & 2 deletions torch/_inductor/codegen/halide.py
Expand Up @@ -1100,7 +1100,10 @@ def reduction(

if isinstance(value, tuple):
assert reduction_type == "welford_combine"
raise NotImplementedError("welford_combine")
self.cse.reduction_cache[
cache_key
] = result_tuple = self.welford_combine_impl(*value)
return result_tuple

assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
reduction_vars = {*self.reduction_renames}
@@ -1154,6 +1157,44 @@ def reduction(
         self.cse.reduction_cache[cache_key] = result_var
         return result_var

+    def welford_combine_impl(self, mean, m2, weight):
+        assert isinstance(mean, HalideCSEVariable) and mean.used_dims is not None
+        assert isinstance(m2, HalideCSEVariable) and m2.used_dims is not None
+        assert isinstance(weight, HalideCSEVariable) and weight.used_dims is not None
+        used_dims = {*mean.used_dims, *m2.used_dims, *weight.used_dims} or {
+            *self.halide_vars
+        }
+        used_dims -= {*self.reduction_renames}
+        result_var = self.newfunc(self.sort_used_dims(used_dims))
+        default = [f"hl.cast({x.name}.type(), 0)" for x in (mean, m2, weight)]
+        pfx = result_var.name
+        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(default)}])")
+        self.body.writeline(f"{pfx}_mean_1 = {result_var}[0]")
+        self.body.writeline(f"{pfx}_m2_1 = {result_var}[1]")
+        self.body.writeline(f"{pfx}_weight_1 = {result_var}[2]")
+        self.body.writeline(f"{pfx}_mean_2 = {mean.subs_str(self.reduction_renames)}")
+        self.body.writeline(f"{pfx}_m2_2 = {m2.subs_str(self.reduction_renames)}")
+        self.body.writeline(
+            f"{pfx}_weight_2 = {weight.subs_str(self.reduction_renames)}"
+        )
+        self.body.writeline(f"{pfx}_delta = {pfx}_mean_2 - {pfx}_mean_1")
+        self.body.writeline(f"{pfx}_new_weight = {pfx}_weight_1 + {pfx}_weight_2")
+        self.body.writeline(
+            f"{pfx}_w2_over_w = hl.select({pfx}_new_weight == 0.0, 0.0, {pfx}_weight_2 / {pfx}_new_weight)"
+        )
+        update = [
+            f"{pfx}_mean_1 + {pfx}_delta * {pfx}_w2_over_w",
+            f"{pfx}_m2_1 + {pfx}_m2_2 + {pfx}_delta * {pfx}_delta * {pfx}_weight_1 * {pfx}_w2_over_w",
+            f"{pfx}_new_weight",
+        ]
+        self.body.writeline(f"{result_var} = hl.Tuple([{', '.join(update)}])")
+
+        unpacked = []
+        for i in range(3):
+            unpacked.append(self.newfunc(result_var.used_dims))
+            self.body.writeline(f"{unpacked[-1]} = {result_var}[{i}]")
+        return tuple(unpacked)
+
     def genfunc(
         self, line, used_dims, *, bounds=ValueRanges.unknown()
     ) -> HalideCSEVariable:
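For reference, the expressions emitted by `welford_combine_impl` implement the standard parallel Welford merge of two partial (mean, M2, weight) accumulators. A plain-Python sketch of the same update, useful for checking the generated formulas (names below are illustrative, not part of the diff):

    def welford_combine(mean1, m2_1, weight1, mean2, m2_2, weight2):
        # Merge two partial Welford accumulators into one.
        new_weight = weight1 + weight2
        delta = mean2 - mean1
        # Mirrors the hl.select guard above: avoid dividing by a zero combined weight.
        w2_over_w = 0.0 if new_weight == 0.0 else weight2 / new_weight
        mean = mean1 + delta * w2_over_w
        m2 = m2_1 + m2_2 + delta * delta * weight1 * w2_over_w
        return mean, m2, new_weight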
@@ -1361,7 +1402,9 @@ def update_index(m):
             dims = self.buffer_dimensions[arg.name]
             range_hints = []
             for i, dim in enumerate(dims):
-                hint = V.graph.sizevars.size_hint(dim.size, fallback=1)
+                hint = self._autoscheduler_workarounds(
+                    V.graph.sizevars.size_hint(dim.size, fallback=1)
+                )
                 range_hints.append(f"hl.Range(0, {hint})")
                 if "out" not in arg.name:
                     code.writeline(f"{arg.name}.dim({i}).set_min(0)")
@@ -1402,6 +1445,16 @@ def update_index(m):
         )
         return code.getvalue()

+    @staticmethod
+    def _autoscheduler_workarounds(n):
+        if (
+            config.halide.scheduler_cuda == "Anderson2021"
+            and V.graph.scheduler.get_current_device_or_throw().type == "cuda"
+        ):
+            # workaround https://github.com/halide/Halide/issues/8246
+            n = max(2, n)
+        return n
+
     def call_kernel(self, name: str, node=None):
         """Codegen a call to this kernel"""
         wrapper = V.graph.wrapper_code
1 change: 1 addition & 0 deletions torch/_inductor/codegen/triton.py
@@ -2659,6 +2659,7 @@ def add_numel_to_call_args_and_grid(self, name, call_args, arg_types, grid):

     def call_kernel(self, name: str, node: Optional[IRNode] = None):
         wrapper = V.graph.wrapper_code
+        wrapper.write_triton_header_once()
         _, call_args, _, arg_types = self.args.python_argdefs()
         grid: List[Any] = []
         self.add_numel_to_call_args_and_grid(name, call_args, arg_types, grid)
18 changes: 10 additions & 8 deletions torch/_inductor/codegen/wrapper.py
Expand Up @@ -524,15 +524,17 @@ def write_header(self) -> None:
@cache_on_self
def write_triton_header_once(self) -> None:
self.header.splice(
"""
f"""
import triton
import triton.language as tl
from {} import grid, split_scan_grid, start_graph, end_graph
{}
""".format(
triton_heuristics.__name__,
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"),
)
from {triton_heuristics.__name__} import grid, split_scan_grid, start_graph, end_graph
"""
)

@cache_on_self
def write_get_raw_stream_header_once(self) -> None:
self.header.writeline(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)

def add_meta_once(self, meta: TritonMetaParams) -> str:
@@ -603,7 +605,7 @@ def call(args):
     # that stream caching happens per graph instance. this
     # is important for nested subgraph codegening.
     def write_get_raw_stream(self, device_idx: int, graph=None) -> str:
-        self.write_triton_header_once()
+        self.write_get_raw_stream_header_once()
         name = f"stream{device_idx}"
         self.writeline(f"{name} = get_raw_stream({device_idx})")
         return name
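The effect of this split is that asking for a raw device stream no longer pulls the Triton imports into the generated wrapper, which matters for Halide GPU kernels that need a stream but never call Triton. Roughly, the two helpers now emit headers like the following (illustrative output; the exact module paths come from `triton_heuristics.__name__` and the device's `device_ops`, so treat them as assumptions):

    # write_triton_header_once(), emitted only when a Triton kernel is launched:
    import triton
    import triton.language as tl
    from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, start_graph, end_graph

    # write_get_raw_stream_header_once(), emitted whenever a raw stream is requested (CUDA shown):
    from torch._C import _cuda_getCurrentRawStream as get_raw_stream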
7 changes: 7 additions & 0 deletions torch/_inductor/config.py
@@ -860,6 +860,9 @@ class cuda:
 # Backend to use for CPU codegen either "cpp" or "halide" (experimental)
 cpu_backend = "cpp"

+# Backend to use for CUDA codegen either "triton" or "halide" (experimental)
+cuda_backend = "triton"
+

 class halide:
     # Base halide target to use for CPU devices
@@ -879,6 +882,10 @@ class halide:
     # Controls `debug` flag passed to Halide target
     debug = False

+    # Enable (or fallback on) scan kernels such as cumsum
+    # Halide autoschedulers struggle with these kernels
+    scan_kernels = False
+

 # create a directory containing lots of debug information
 class trace:
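As a usage note (not part of the diff), both new knobs are ordinary Inductor configs and can be set or patched like any other option; a hedged example, assuming a Halide toolchain that supports the chosen GPU autoscheduler:

    import torch._inductor.config as inductor_config

    inductor_config.cuda_backend = "halide"      # codegen CUDA kernels with Halide instead of Triton
    inductor_config.halide.scan_kernels = True   # also allow Halide codegen for scan ops such as cumsum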
