[inductor] Remove redundant views (pytorch#111773)
As a follow-up to pytorch#110740, this patch enables removing redundant complex views to allow more operation fusion.

E.g., given

```
@torch.compile
def foo(X, Y):
    Z = X + Y
    A = X + Y
    return A + Z
```

the generated code is:

```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tmp3 = tmp2 + tmp2
    tl.store(out_ptr0 + (x0), tmp3, xmask)

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [A], Original ATen: [aten.add]
        buf0 = aten.view.dtype(arg0_1, torch.float32)
        del arg0_1
        buf1 = buf0
        del buf0
        # Source Nodes: [A], Original ATen: [aten.add]
        buf2 = aten.view.dtype(arg1_1, torch.float32)
        del arg1_1
        buf3 = buf2
        del buf2
        buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [add_2], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0)
        del buf1
        del buf3
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf5 = aten.view.dtype(buf4, torch.complex64)
        del buf4
        buf6 = buf5
        del buf5
        return (buf6, )
```

whereas previously the generated code was:

```
@triton.jit
def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 6
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = xindex < xnumel
    x0 = xindex
    tmp0 = tl.load(in_ptr0 + (x0), xmask)
    tmp1 = tl.load(in_ptr1 + (x0), xmask)
    tmp2 = tmp0 + tmp1
    tl.store(out_ptr0 + (x0), tmp2, xmask)

def call(args):
    arg0_1, arg1_1 = args
    args.clear()
    assert_size_stride(arg0_1, (3, ), (1, ))
    assert_size_stride(arg1_1, (3, ), (1, ))
    with torch.cuda._DeviceGuard(0):
        torch.cuda.set_device(0) # no-op to ensure context
        # Source Nodes: [A], Original ATen: [aten.add]
        buf0 = aten.view.dtype(arg0_1, torch.float32)
        buf1 = buf0
        del buf0
        # Source Nodes: [A], Original ATen: [aten.add]
        buf2 = aten.view.dtype(arg1_1, torch.float32)
        buf3 = buf2
        del buf2
        buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32)
        # Source Nodes: [A], Original ATen: [aten.add]
        stream0 = get_cuda_stream(0)
        triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0)
        del buf1
        del buf3
        # Source Nodes: [A], Original ATen: [aten.add]
        buf5 = aten.view.dtype(buf4, torch.complex64)
        buf6 = buf5
        del buf5
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf7 = aten.view.dtype(buf6, torch.float32)
        del buf6
        buf8 = buf7
        del buf7
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf9 = aten.view.dtype(arg0_1, torch.float32)
        del arg0_1
        buf10 = buf9
        del buf9
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf11 = aten.view.dtype(arg1_1, torch.float32)
        del arg1_1
        buf12 = buf11
        del buf11
        buf13 = buf4; del buf4  # reuse
        # Source Nodes: [Z], Original ATen: [aten.add]
        triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0)
        del buf10
        del buf12
        # Source Nodes: [Z], Original ATen: [aten.add]
        buf14 = aten.view.dtype(buf13, torch.complex64)
        buf15 = buf14
        del buf14
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf16 = aten.view.dtype(buf15, torch.float32)
        del buf15
        buf17 = buf16
        del buf16
        buf18 = buf13; del buf13  # reuse
        # Source Nodes: [add_2], Original ATen: [aten.add]
        triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0)
        del buf17
        del buf8
        # Source Nodes: [add_2], Original ATen: [aten.add]
        buf19 = aten.view.dtype(buf18, torch.complex64)
        del buf18
        buf20 = buf19
        del buf19
        return (buf20, )
```
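
For context on where these views come from: Inductor lowers a complex add by reinterpreting the `complex64` inputs as `float32` via `aten.view.dtype`, adding the real views, and viewing the result back as `complex64`; that pattern is visible in both listings above. The sketch below is a hand-written illustration of that pattern (not Inductor's lowering code), assuming contiguous 1-D inputs:

```
import torch

def complex_add_via_views(x, y):
    # Reinterpret complex64 storage as float32 (real/imag interleaved),
    # so a (3,) complex tensor becomes a (6,) float tensor.
    xf = torch.ops.aten.view.dtype(x, torch.float32)
    yf = torch.ops.aten.view.dtype(y, torch.float32)
    # This element-wise add is what lands in the fused Triton kernel.
    out = xf + yf
    # Reinterpret the result back as complex64.
    return torch.ops.aten.view.dtype(out, torch.complex64)

x = torch.tensor([1 + 1j, 2 - 2j, 3j], dtype=torch.complex64)
y = torch.tensor([1 - 1j, 2 + 2j, -3j], dtype=torch.complex64)
assert torch.equal(complex_add_via_views(x, y), x + y)
```

With the duplicate `view.dtype` nodes removed, both adds read from the same float32 buffers and collapse into the single kernel launch shown in the new output, versus three launches previously.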

Pull Request resolved: pytorch#111773
Approved by: https://github.com/jansel
htyu authored and Skylion007 committed Nov 14, 2023
1 parent a692a27 commit 8edace7
Showing 2 changed files with 67 additions and 0 deletions.
13 changes: 13 additions & 0 deletions test/inductor/test_torchinductor.py
@@ -620,6 +620,19 @@ def fn(a, b, alpha):

        self.common(fn, (x, y, 2))

    def test_add_complex2(self):
        @torch.compile
        def fn(a, b):
            c = a + b
            d = a + b
            return c + d

        x = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])
        y = torch.tensor([1 + 1j, -1 + 1j, -2 + 2j, 3 - 3j, 0, 1j, 1, -1])

        _, code = run_and_get_code(fn, x, y)
        self.assertEqual(code[0].count("aten.view"), 3)

    def test_concat_add_inplace(self):
        def fn(x, y, z):
            return torch.cat([x, y], dim=1).add_(z)
54 changes: 54 additions & 0 deletions torch/_inductor/fx_passes/joint_graph.py
@@ -112,6 +112,59 @@ def replace_no_op(node, replace_input_index):
replace_no_op(node, 0)


@torch.utils._python_dispatch._disable_current_modes()
def remove_redundant_views(gm: torch.fx.GraphModule):
    """
    Removes redundant views by reusing existing ones.
    """

    # A dictionary mapping a tensor to all aliased views.
    views: Dict[torch.fx.Node, Dict[torch.dtype, torch.fx.Node]] = {}
    graph = gm.graph

    for node in graph.nodes:
        if node.op != "call_function":
            continue

        if node.target != torch.ops.aten.view.dtype:
            continue

        src = node.args[0]
        to_type = node.args[1]
        existing_views = views.get(src)
        is_needed = True

        if existing_views:
            # Replace the view with an existing view if available.
            alias = existing_views.get(to_type)
            if alias:
                is_needed = False
                node.replace_all_uses_with(alias)
                alias.meta.update(node.meta)
                graph.erase_node(node)
        else:
            from_type = src.meta["val"].dtype
            existing_views = {from_type: src}
            views[src] = existing_views

        if is_needed:
            # Save the new alias but do not replace the existing one.
            existing_views.setdefault(to_type, node)
            views[node] = existing_views

    # Clean up unused views.
    while True:
        unused_views = []
        for alias in views:
            if not alias.users:
                unused_views.append(alias)
        if len(unused_views) == 0:
            break
        for unused in unused_views:
            views.pop(unused)
            graph.erase_node(unused)


class UniformValueConstantFolder(ConstantFolder):
    """
    Runs constant folding and replaces tensors that have a uniform value
@@ -202,6 +255,7 @@ def constant_fold_uniform_value(gm: torch.fx.GraphModule):
ones.add(new_node)

    remove_no_ops(gm, zeros, ones)
    remove_redundant_views(gm)


def joint_graph_passes(graph: torch.fx.GraphModule):
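
To see the new pass in isolation, here is a small hypothetical driver (not part of this PR) that builds an FX graph with a duplicated `aten.view.dtype` node and runs `remove_redundant_views` over it. It assumes a PyTorch build that contains this change and that `meta["val"]` is populated on the view's source node, as it is in the joint graph; the stand-in tensor below exists only so the pass can read its dtype.

```
import torch
from torch._inductor.fx_passes.joint_graph import remove_redundant_views

# Build a graph equivalent to: view(x, float32) + view(x, float32).
g = torch.fx.Graph()
x = g.placeholder("x")
x.meta["val"] = torch.empty(3, dtype=torch.complex64)  # stand-in value; the pass only reads its dtype
v1 = g.call_function(torch.ops.aten.view.dtype, (x, torch.float32))
v2 = g.call_function(torch.ops.aten.view.dtype, (x, torch.float32))  # redundant duplicate
out = g.call_function(torch.ops.aten.add.Tensor, (v1, v2))
g.output(out)
gm = torch.fx.GraphModule(torch.nn.Module(), g)

remove_redundant_views(gm)
# The duplicate view is gone; the add now consumes the surviving view twice,
# mirroring the `tmp3 = tmp2 + tmp2` line in the fused kernel above.
assert sum(n.target == torch.ops.aten.view.dtype for n in gm.graph.nodes) == 1
print(gm.graph)
```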
