[inductor][cpp] epilogue support for gemm template (pytorch#126019)
As part of pytorch#125683, this PR adds epilogue support to the C++ GEMM template by reusing the C++ vector codegen on sub-slices of tensors. This is implemented by retracing the epilogue IR nodes with new ranges and offsets. New `codegen_loop_bodies` and `codegen_functions` methods are added to the C++ vector codegen for this purpose; the template kernel's `store_output` method leverages them to generate the epilogue code and store the final result.

Pull Request resolved: pytorch#126019
Approved by: https://github.com/jansel
jgong5 authored and pytorchmergebot committed May 16, 2024
1 parent 6065a4d commit 7844c20
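
To make the retracing idea concrete, here is a minimal, self-contained sketch (illustrative only, in plain NumPy rather than the actual Inductor API): the epilogue function is simply re-run over the sub-slice the template just produced, with its index variables shifted by the slice offsets.

import numpy as np

def epilogue(out, i, j):
    # e.g. a ReLU applied after the GEMM
    out[i, j] = max(out[i, j], 0.0)

def store_output_tile(out, tile, i0, j0):
    # store the GEMM sub-result, then retrace the epilogue over the same
    # sub-slice with the tile offsets (i0, j0) applied to the index variables
    m, n = tile.shape
    out[i0:i0 + m, j0:j0 + n] = tile
    for di in range(m):
        for dj in range(n):
            epilogue(out, i0 + di, j0 + dj)

out = np.zeros((4, 4), dtype=np.float32)
tile = np.array([[-1.0, 2.0], [3.0, -4.0]], dtype=np.float32)
store_output_tile(out, tile, 0, 2)  # out[0:2, 2:4] now holds relu(tile)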
Showing 7 changed files with 367 additions and 73 deletions.
135 changes: 131 additions & 4 deletions test/inductor/test_cpu_select_algorithm.py
@@ -1,5 +1,7 @@
# Owner(s): ["oncall: cpu inductor"]
import functools

import sys
import unittest
from unittest.mock import patch

@@ -17,14 +19,33 @@

from torch.testing._internal.common_utils import IS_MACOS, parametrize, TEST_MKL

try:
    try:
        from . import test_torchinductor
    except ImportError:
        import test_torchinductor
except unittest.SkipTest:
    if __name__ == "__main__":
        sys.exit(0)
    raise

check_model = test_torchinductor.check_model

aten = torch.ops.aten


def patches(fn):
    def skip_cache(self, choices, name, key, benchmark):
        if benchmark is None:
            return {}
-        return benchmark(choices)
+        timings = benchmark(choices)
+        for choice, timing in timings.items():
+            if isinstance(choice, select_algorithm.ExternKernelCaller):
+                # we intentionally make ATEN kernel slower to cover the cases
+                # where template kernels are always chosen with fusions applied
+                # and correctness checks at runtime.
+                timings[choice] = timing * 1000
+        return timings

    for patcher in [
        dynamo_config.patch(verbose=True),
@@ -49,6 +70,8 @@ def wrapped(*args, **kwargs):


class TestSelectAlgorithm(TestCase):
    common = check_model

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
@@ -67,15 +90,14 @@ def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

-            @torch.compile
            def forward(self, x):
                return self.linear(x)

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        B = (2, batch_size) if input_3d else (batch_size,)
        v = torch.randn(*B, in_features).to(dtype=dtype)
-        mod(v)
+        self.common(mod, (v,))
        self.assertEqual(
            counters["inductor"]["select_algorithm_autotune"],
            1 if out_features != 1 else 0,
@@ -104,18 +126,123 @@ def forward(self, x):
        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        v = torch.randn(in_features, batch_size).to(dtype=dtype)
-        mod(v.transpose(0, 1))
+        self.common(mod, (v.transpose(0, 1),))
        # TODO(jgong5): support transposed input
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 0)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bias", (True, False))
    @parametrize(
        "epilogue",
        (
            "relu",
            "gelu",
            "silu",
            "sigmoid",
            "tanh",
            "hardswish",
            "hardsigmoid",
            "leaky_relu",
            "hardtanh",
            "add",
            "sub",
            "mul",
            "div",
        ),
    )
    @dtypes(torch.float)
    def test_linear_with_pointwise(self, bias, epilogue, dtype):
        batch_size = 384
        in_features = 196
        out_features = 384

        class M(torch.nn.Module):
            def __init__(self, bias, epilogue, other):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)
                if epilogue == "relu":
                    self.epilogue = torch.nn.ReLU()
                elif epilogue == "gelu":
                    self.epilogue = torch.nn.GELU()
                elif epilogue == "silu":
                    self.epilogue = torch.nn.SiLU()
                elif epilogue == "sigmoid":
                    self.epilogue = torch.nn.Sigmoid()
                elif epilogue == "tanh":
                    self.epilogue = torch.nn.Tanh()
                elif epilogue == "hardswish":
                    self.epilogue = torch.nn.Hardswish()
                elif epilogue == "hardsigmoid":
                    self.epilogue = torch.nn.Hardsigmoid()
                elif epilogue == "leaky_relu":
                    self.epilogue = torch.nn.LeakyReLU()
                elif epilogue == "hardtanh":
                    self.epilogue = torch.nn.Hardtanh()
                elif epilogue == "add":
                    self.epilogue = lambda x: x + other
                elif epilogue == "sub":
                    self.epilogue = lambda x: x - other
                elif epilogue == "mul":
                    self.epilogue = lambda x: x * other
                elif epilogue == "div":
                    self.epilogue = lambda x: x / other

            def forward(self, x):
                return self.epilogue(self.linear(x))

        counters.clear()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        u = torch.randn(batch_size, out_features).to(dtype=dtype)
        mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
        self.common(mod, (v,))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

    @inductor_config.patch({"freezing": True})
    @patches
    @torch.no_grad
    @unittest.skipIf(not TEST_MKL, "Test requires MKL")
    @parametrize("bias", (True, False))
    @dtypes(torch.float)
    def test_linear_with_transpose(self, bias, dtype):
        batch_size = 384
        in_features = 196
        out_features = 128

        class M(torch.nn.Module):
            def __init__(self, bias):
                super().__init__()
                self.linear = torch.nn.Linear(in_features, out_features, bias)

            def forward(self, x, y):
                return self.linear(x).transpose(0, 1) + y

        counters.clear()
        mod = M(bias=bias).to(dtype=dtype).eval()
        v = torch.randn(batch_size, in_features).to(dtype=dtype)
        u = torch.randn(out_features, batch_size).to(dtype=dtype)
        self.common(mod, (v, u))
        self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
        self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)


@dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
class _DynamicShapesTestBase(TestCase):
    pass


class TestSelectAlgorithmDynamicShapes(_DynamicShapesTestBase):
    common = check_model
    test_linear_dynamic_shapes = TestSelectAlgorithm.test_linear_static_shapes
    test_linear_with_pointwise_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_with_pointwise
    )
    test_linear_with_transpose_dynamic_shapes = (
        TestSelectAlgorithm.test_linear_with_transpose
    )


instantiate_device_type_tests(TestSelectAlgorithm, globals(), only_for="cpu")
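Outside the test harness, the feature can be exercised with a small script like the following sketch (hedged: it assumes a build where the Inductor C++ GEMM template is available and picked under the max-autotune and freezing settings that the `@patches` decorator above applies):

import torch

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(196, 384)

    def forward(self, x):
        # the ReLU is the pointwise epilogue that can be fused into the GEMM
        return torch.relu(self.linear(x))

with torch.no_grad():
    mod = M().eval()
    x = torch.randn(384, 196)
    compiled = torch.compile(mod)
    # correctness check against eager, as self.common does in the tests
    torch.testing.assert_close(compiled(x), mod(x))
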
2 changes: 1 addition & 1 deletion torch/_inductor/codegen/common.py
@@ -1508,7 +1508,7 @@ def _bound_variable(name, *args, **kwargs):
            return ValueRanges.unknown()

        fx_node = V.interpreter.current_node
-        if fx_node.target == name:
+        if fx_node.target == name and self.node_to_bounds is not None:
            assert isinstance(self.node_to_bounds, dict)
            return self.node_to_bounds.get(fx_node, ValueRanges.unknown())
        elif config.compute_all_bounds and hasattr(ValueRangeAnalysis, name):
91 changes: 59 additions & 32 deletions torch/_inductor/codegen/cpp.py
@@ -46,7 +46,7 @@
    sympy_subs,
)

-from ..virtualized import ops, OpsValue, V
+from ..virtualized import NullKernelHandler, ops, OpsValue, V
from .common import (
    BracesBuffer,
    CppWrapperKernelArgs,
@@ -3148,27 +3148,11 @@ def is_memory_copy_scheduler_node(node: SchedulerNode):
        body: ir.LoopBody = node._body
        _legalize_lowp_fp(body)

-    def codegen_nodes(self, nodes: List[SchedulerNode]):
-        # Legalize BF16 node by adding to_dtype explicitly
-        self.legalize_lowp_fp_dtype(nodes)
-        self.data_type_propagation(nodes)
-
-        assert len(nodes) >= 1
-        first_node = nodes[0]
-        vec_dtype = (
-            first_node._lowp_fp_type  # type: ignore[attr-defined]
-            if all(
-                hasattr(_node, "_lowp_fp_type")
-                and _node._lowp_fp_type == first_node._lowp_fp_type  # type: ignore[attr-defined]
-                for _node in nodes
-            )
-            else torch.float
-        )
-
+    def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float):
+        # TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes
+        assert len(fn_list) == len(var_sizes_list)
        kernel_group = self.kernel_group
-        _, (group, reduction_group) = max(
-            nodes, key=lambda x: int(x.is_reduction())
-        ).group
+        group, reduction_group = max(var_sizes_list, key=lambda sizes: len(sizes[1]))

        self.set_ranges(group, reduction_group)

@@ -3184,22 +3168,22 @@ def codegen_kernel(cls, *args):
        def run(kernel):
            vars, reduction_vars = kernel.set_ranges(group, reduction_group)
            in_suffix = False
-            for node in nodes:
-                if node.group[1] in [
+            for fn, var_sizes in zip(fn_list, var_sizes_list):
+                if var_sizes in [
                    (group, reduction_group),
                    (group + reduction_group, ()),
                ]:
                    assert not in_suffix
-                    node.run(vars, reduction_vars)
+                    fn(vars, reduction_vars)
                else:
                    in_suffix = True
-                    assert node.group[1] == (
+                    assert var_sizes == (
                        group,
                        (),
-                    ), f"unexpected group: {node.group[1]} != {group}, {reduction_group}"
+                    ), f"unexpected group: {var_sizes} != {group}, {reduction_group}"
                    # we can fuse in some extra pointwise into the suffix
                    with kernel.write_to_suffix():
-                        node.run(vars, ())
+                        fn(vars, ())

        scalar_kernel = codegen_kernel(CppKernel)
        V.graph.removed_buffers |= scalar_kernel.removed_buffers
@@ -3211,8 +3195,8 @@

        def select_tiling_indices(tiling_factor):
            all_index = []
-            for node in nodes:
-                rw = dependencies.extract_read_writes(node._body, *node._sizes)
+            for fn, var_sizes in zip(fn_list, var_sizes_list):
+                rw = dependencies.extract_read_writes(fn, *var_sizes)
                all_index += [dep.index for dep in itertools.chain(rw.reads, rw.writes)]
            contig_vars = set()
            contig_vars_list = []
@@ -3326,6 +3310,41 @@ def select_tiling(dtype: torch.dtype = torch.float):
                inner_main_loop.set_kernel(tile2d_kernel)
                inner_tail_loop.set_kernel(vec_kernel)

    def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
        # TODO(jgong5): support lowp legalization
        for body in loop_bodies:
            DataTypePropagation.propagate_loopbody(body)
        self.codegen_functions(loop_bodies, var_sizes_list)

    def codegen_nodes(self, nodes: List[SchedulerNode]):
        # Legalize BF16 node by adding to_dtype explicitly
        self.legalize_lowp_fp_dtype(nodes)
        self.data_type_propagation(nodes)

        assert len(nodes) >= 1
        first_node = nodes[0]
        vec_dtype = (
            first_node._lowp_fp_type  # type: ignore[attr-defined]
            if all(
                hasattr(_node, "_lowp_fp_type")
                and _node._lowp_fp_type == first_node._lowp_fp_type  # type: ignore[attr-defined]
                for _node in nodes
            )
            else torch.float
        )

        def fn(node, *index_vars):
            node.decide_inplace_update()
            node.mark_run()
            if isinstance(V.kernel, NullKernelHandler):
                return node._body(*index_vars)
            else:
                return node.codegen(index_vars)

        fn_list = [functools.partial(fn, node) for node in nodes]
        var_sizes_list = [node.group[1] for node in nodes]
        self.codegen_functions(fn_list, var_sizes_list, vec_dtype)

    def codegen_loops(self, code, worksharing):
        self.codegen_loops_impl(self.loop_nest, code, worksharing)

@@ -3390,6 +3409,9 @@ def reset_kernel_group(self):
    def fuse(self, node1, node2):
        if node1.is_foreach() or node2.is_foreach():
            return ForeachKernelSchedulerNode.fuse(node1, node2)
        elif node1.is_template():
            assert not node2.is_template()
            return FusedSchedulerNode.fuse(node1, node2)
        else:
            if (
                self._why_fuse_nodes(node1, node2)
@@ -3588,7 +3610,9 @@ def _get_outer_loop_fusion_depth(self, node1, node2):

    def can_fuse_vertical_outer_loop(self, node1, node2):
        return (
-            node1.get_names() & node2.ancestors
+            not node1.is_template()
+            and not node2.is_template()
+            and node1.get_names() & node2.ancestors
            and not (
                self._can_fuse_horizontal_impl(node1, node2)
                and not node1.is_reduction()
@@ -3604,9 +3628,11 @@ def get_fusion_pair_priority(self, node1, node2):
        return 0

    def can_fuse_vertical(self, node1, node2):
-        # TODO(jgong5): support vertical fusion for template nodes
-        if node1.is_template() or node2.is_template():
+        if node2.is_template():
+            # TODO(jgong5): support pre-op fusion with template
            return False
+        if node1.is_template():
+            return not node2.is_reduction()
        return (
            self._can_fuse_horizontal_impl(node1, node2) and not node1.is_reduction()
        ) or self.can_fuse_vertical_outer_loop(node1, node2)
@@ -3689,6 +3715,7 @@ def codegen_template(
        kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
        with kernel:
            for node in [template_node, *epilogue_nodes]:
                node.decide_inplace_update()
                node.mark_run()
        src_code = render()

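For reference, the shape of the cpp.py refactor can be summarized in a short sketch (illustrative types and names only, not the real Inductor classes): loop codegen now consumes plain callables plus their iteration sizes, so the same path can trace either SchedulerNodes (via the `fn` partial in `codegen_nodes`) or the epilogue LoopBody objects that the template kernel's `store_output` hands to `codegen_loop_bodies`.

from typing import Callable, List, Tuple

Sizes = Tuple[tuple, tuple]  # (parallel var sizes, reduction var sizes)

def codegen_functions(fn_list: List[Callable], var_sizes_list: List[Sizes]) -> None:
    # pick the group with the most reduction vars, mirroring the diff above
    group, reduction_group = max(var_sizes_list, key=lambda s: len(s[1]))
    for fn, var_sizes in zip(fn_list, var_sizes_list):
        if var_sizes in [(group, reduction_group), (group + reduction_group, ())]:
            fn(("i0",), ("r0",))  # traced inside the main (reduction) loop
        else:
            # suffix: pointwise epilogue re-traced after the reduction
            assert var_sizes == (group, ())
            fn(("i0",), ())

# hypothetical bodies: a reduction "GEMM" and a pointwise "ReLU" epilogue
codegen_functions(
    [lambda v, r: print("gemm over", v, r), lambda v, r: print("relu over", v)],
    [((1024,), (64,)), ((1024,), ())],
)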
