
[Primitive] Add fallback fusion #78

Merged
comaniac merged 2 commits into awslabs:main from chhzh123:fallback_fusion on Mar 9, 2023

Conversation

@chhzh123 (Contributor) commented Mar 9, 2023

Description

This PR adds a fallback fusion option to the .fuse() primitive, which directly puts the operations of the given subgraph into an nn.Sequential module while preserving exactly the same computation. This is useful for debugging and for later dispatching to different backends. In this way, users do not even need to register a new compiler for Slapo; they can simply replace this "fake fused" module with their own efficient module using .replace().
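As a rough illustration in plain PyTorch (not the actual Slapo implementation), "fake fusing" a linear + activation pair just regroups the existing submodules under one nn.Sequential without changing the math:

```python
# Hypothetical sketch of "fallback fusion": instead of compiling the
# matched subgraph, wrap its operations in an nn.Sequential so the
# computation is unchanged but the ops are grouped under one module.
import torch
from torch import nn

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        self.act = nn.ReLU()

    def forward(self, x):
        return self.act(self.linear(x))

model = Model()
# "Fake fuse" linear + relu into a single submodule (names here are
# illustrative, not Slapo's internals).
fused = nn.Sequential(model.linear, model.act)

x = torch.randn(2, 4)
# The fused module computes exactly the same result as the original graph.
assert torch.allclose(model(x), fused(x))
```

The grouped module can then be matched and swapped out as a unit later.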

Checklist

  • PR's title starts with a category (e.g. [Bugfix], [Model], [Tutorial], etc)
  • Changes are complete (i.e. I finished coding on this PR)
  • All changes have test coverage
  • Code is well-documented

@chhzh123 chhzh123 requested a review from comaniac March 9, 2023 05:59
@comaniac (Contributor) commented Mar 9, 2023

This is an interesting feature. I agree that this could be useful for debugging, but I didn't get the point of replacing it with an efficient module. Shouldn't it be done with just .find and .replace?

In addition, I would suggest using torchscript as the default value instead of None, as I believe most users would expect to see a speedup from .fuse.

@chhzh123 (Contributor, Author) commented Mar 9, 2023

I'm imagining that some backend compilers or passes will further handle these "fake fused" modules. Since the optimized fusion module may not be available ahead of time, this can be viewed as a delayed version of op fusion. TorchScript is tightly coupled with the CPU/GPU backends and has tracer limitations; if users want another full-graph compiler to handle those fused ops, using TorchScript will just make things more complicated. I can think of two use cases now:

  1. It mimics the preprocessing step of quantization in PyTorch. Specifically, torch.quantization.fuse_modules does the same thing as this PR. Later on, PyTorch takes another pass that converts those fake fused modules into real fused modules on CPU.
  2. If I leverage another backend compiler like HeteroCL, which focuses more on operator-level optimization, it is easier for the low-level compiler to accept an encapsulated operation and then dispatch the fused op to accelerators.
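For reference, the PyTorch quantization workflow mentioned in point 1 can be exercised directly: fuse_modules groups conv/bn/relu into one fused module while keeping the eval-mode computation numerically unchanged (the "0"/"1"/"2" names come from nn.Sequential; shapes here are arbitrary):

```python
import torch
from torch import nn
from torch.ao.quantization import fuse_modules

# A tiny conv-bn-relu stack; this kind of fusion requires eval mode.
m = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU()).eval()
# Fuse the three submodules, addressed by their names in nn.Sequential.
fused = fuse_modules(m, [["0", "1", "2"]])

x = torch.randn(1, 3, 8, 8)
# The fused model computes the same result up to float rounding.
assert torch.allclose(m(x), fused(x), atol=1e-4)
```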

This feature may not be very useful for now, but it does not break any current facilities, and it gives users more options for conducting graph-level optimization.
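The delayed-fusion workflow described above then amounts to swapping the fake-fused nn.Sequential for an optimized module that computes the same function. A minimal sketch, where FusedLinearReLU is a hypothetical stand-in for a backend-compiled kernel:

```python
import torch
from torch import nn

class FusedLinearReLU(nn.Module):
    # Hypothetical "efficient" replacement; in practice this would be a
    # backend-compiled or hand-optimized fused kernel.
    def __init__(self, linear):
        super().__init__()
        self.linear = linear

    def forward(self, x):
        return torch.relu(self.linear(x))

# The fake-fused module produced by fallback fusion.
model = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
# Later, swap it for the optimized implementation (mimicking .replace()).
replacement = FusedLinearReLU(model[0])

x = torch.randn(2, 4)
# Both modules compute the same function.
assert torch.allclose(model(x), replacement(x))
```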

@comaniac (Contributor) commented Mar 9, 2023

So you actually mean that users may want to compile the matched subgraph in an arbitrary way and then use .replace to put the compiled module back. That makes sense to me.

@comaniac comaniac left a comment


LGTM

@comaniac comaniac merged commit 10ac17d into awslabs:main Mar 9, 2023
@comaniac commented Mar 9, 2023

Thanks @chhzh123
