Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[inductor] Remove redundant views (pytorch#111773)
As a follow-up to pytorch#110740, this patches enables removing redundant complex views to allow more operation fusing. E.g, given ``` @torch.compile def foo(X, Y): Z = X + Y A = X + Y return A + Z ``` the generated code is: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tmp3 = tmp2 + tmp2 tl.store(out_ptr0 + (x0), tmp3, xmask) ''') def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [add_2], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [add_2], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) del buf4 buf6 = buf5 del buf5 return (buf6, ) ``` whereas previously the generated code was: ``` @triton.jit def triton_(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr): xnumel = 6 xoffset = tl.program_id(0) * XBLOCK xindex = xoffset + tl.arange(0, XBLOCK)[:] xmask = xindex < xnumel x0 = xindex tmp0 = tl.load(in_ptr0 + (x0), xmask) tmp1 = tl.load(in_ptr1 + (x0), xmask) tmp2 = tmp0 + tmp1 tl.store(out_ptr0 + (x0), tmp2, xmask) def call(args): arg0_1, arg1_1 = args args.clear() assert_size_stride(arg0_1, (3, ), (1, )) assert_size_stride(arg1_1, (3, ), (1, )) with torch.cuda._DeviceGuard(0): torch.cuda.set_device(0) # no-op to ensure context # Source Nodes: [A], Original ATen: [aten.add] buf0 = aten.view.dtype(arg0_1, torch.float32) buf1 = buf0 del buf0 # Source Nodes: [A], Original ATen: [aten.add] buf2 = aten.view.dtype(arg1_1, torch.float32) buf3 = buf2 del buf2 buf4 = empty_strided((6, ), (1, ), device='cuda', dtype=torch.float32) # Source Nodes: [A], Original ATen: [aten.add] stream0 = get_cuda_stream(0) triton_poi_fused_add_0.run(buf1, buf3, buf4, 6, grid=grid(6), stream=stream0) del buf1 del buf3 # Source Nodes: [A], Original ATen: [aten.add] buf5 = aten.view.dtype(buf4, torch.complex64) buf6 = buf5 del buf5 # Source Nodes: [add_2], Original ATen: [aten.add] buf7 = aten.view.dtype(buf6, torch.float32) del buf6 buf8 = buf7 del buf7 # Source Nodes: [Z], Original ATen: [aten.add] buf9 = aten.view.dtype(arg0_1, torch.float32) del arg0_1 buf10 = buf9 del buf9 # Source Nodes: [Z], Original ATen: [aten.add] buf11 = aten.view.dtype(arg1_1, torch.float32) del arg1_1 buf12 = buf11 del buf11 buf13 = buf4; del buf4 # reuse # Source Nodes: [Z], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf10, buf12, buf13, 6, grid=grid(6), stream=stream0) del buf10 del buf12 # Source Nodes: [Z], Original ATen: [aten.add] buf14 = aten.view.dtype(buf13, torch.complex64) buf15 = buf14 del buf14 # Source Nodes: [add_2], Original ATen: [aten.add] buf16 = aten.view.dtype(buf15, torch.float32) del buf15 buf17 = buf16 del buf16 buf18 = buf13; del buf13 # reuse # Source Nodes: [add_2], Original ATen: [aten.add] triton_poi_fused_add_0.run(buf8, buf17, buf18, 6, grid=grid(6), stream=stream0) del buf17 del buf8 # Source Nodes: [add_2], Original ATen: [aten.add] buf19 = aten.view.dtype(buf18, torch.complex64) del buf18 buf20 = buf19 del buf19 return (buf20, ) ``` Pull Request resolved: pytorch#111773 Approved by: https://github.com/jansel
- Loading branch information