[Schedule] Add .fuse() primitive #25
Conversation
Seems like we could ask for sample inputs in kwargs if required by the compiler backend. For example:
It seems this PR requires several changes to the subgraph matching mechanism. I will open a new PR to make step 1 & 2 work first. |
comaniac
left a comment
Overall LGTM. Later we should apply .fuse to the examples, which currently replace an entire MLP for bias_new_gelu.
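For context, bias_new_gelu refers to a bias-add followed by the "new" (tanh-approximation) GELU activation. A minimal stdlib-only sketch of the fused computation, operating on plain floats for illustration (a real fused kernel would work on tensors; the function name is illustrative, not slapo's API):

```python
import math

def fused_bias_new_gelu(x, bias):
    """Fuse the bias-add with the tanh ("new GELU") approximation.

    Illustrative scalar version; the actual fusion target operates
    on tensors inside the compiled subgraph.
    """
    y = x + bias
    return 0.5 * y * (1.0 + math.tanh(
        math.sqrt(2.0 / math.pi) * (y + 0.044715 * y ** 3)))

# The activation is exactly zero at zero input.
print(fused_bias_new_gelu(0.0, 0.0))  # -> 0.0
```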
slapo/pattern.py
Outdated
raise NotImplementedError

class CallModule(nn.Module):
The names of CallModule and call_module are confusing.
They are just used in different places. CallModule is used in a pattern class, while call_module can be used in a pattern function. Do you have any suggestions about the naming?
Hmm, based on your use cases, such as

class LinearReLUPattern(slapo.Pattern):
    def __init__(self):
        super().__init__()
        self.fc = CallModule(r"fc?")
        self.relu = nn.ReLU()

Here you're actually constructing a module instead of calling it, so I guess you could use a name like ModulePattern or something like that.
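For reference, the regex matching that CallModule(r"fc?") implies can be sketched without slapo: collect submodule names and test each against the pattern. This is a stdlib-only illustration (the helper name and the exact matching rule are assumptions, not slapo's implementation):

```python
import re

def match_module_names(named_modules, pattern):
    """Return the module paths whose last component fully matches the
    given regex -- a sketch of what a regex argument to a module
    pattern could mean, not slapo's actual matching logic."""
    rx = re.compile(pattern)
    return [name for name in named_modules
            if rx.fullmatch(name.split(".")[-1])]

names = ["encoder.fc", "encoder.relu", "decoder.fc1", "decoder.out"]
print(match_module_names(names, r"fc\d?"))  # -> ['encoder.fc', 'decoder.fc1']
```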
It may also lead to confusion with the Pattern class.
Well, to me it is a kind of pattern, especially since you have a regex in its argument.
Yes, I agree it is a pattern, and ModulePattern is a proper name, but since we have the Pattern base class which is also a module, I'm afraid it will cause confusion between these two.
Sure, I'll add it in the next PR.
Thanks @chhzh123
Description
This PR adds a new primitive called .fuse(subgraph, compiler) for operator fusion. Currently we only support pattern-based vertical operator fusion using TorchScript as the backend compiler. A simple example is shown below.

The fusion performance needs to be tested. Since I first create an identical torch.fx subgraph and pass it into TorchScript's scripting mode for optimization, this may cause performance issues if some of the operators are not recognized by TorchScript. Tracing mode can achieve the best performance, but it is hard to leverage since we cannot always obtain example inputs for each subgraph. Scripting a user-defined function is also not a good approach, since the backward pass is not captured by TorchScript. Therefore, only scripting an entire module is a good fit for our case, and we need to test the compatibility of torch.fx and TorchScript.
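The pattern-find-and-replace flow the description walks through can be sketched without torch: locate a run of ops matching the pattern and swap them for one compiled op. Here the graph is a plain list of (name, fn) stages and naive_compiler stands in for TorchScript; all names are illustrative, not slapo's actual .fuse() API:

```python
def fuse(graph, pattern, compiler):
    """Sketch of pattern-based vertical fusion: replace each run of
    stages whose names match `pattern` with a single compiled stage.
    The real primitive operates on torch.fx graphs with TorchScript
    as the compiler backend."""
    fused, i = [], 0
    while i < len(graph):
        window = [name for name, _ in graph[i:i + len(pattern)]]
        if window == list(pattern):
            fns = [fn for _, fn in graph[i:i + len(pattern)]]
            fused.append(("+".join(pattern), compiler(fns)))
            i += len(pattern)
        else:
            fused.append(graph[i])
            i += 1
    return fused

def naive_compiler(fns):
    # Stand-in for a real backend: compose the ops into one callable.
    def run(x):
        for fn in fns:
            x = fn(x)
        return x
    return run

graph = [("linear", lambda x: 2 * x),
         ("relu", lambda x: max(x, 0)),
         ("out", lambda x: x - 1)]
fused = fuse(graph, ["linear", "relu"], naive_compiler)
print([name for name, _ in fused])  # -> ['linear+relu', 'out']
print(fused[0][1](-3))              # -> 0, i.e. relu(2 * -3)
```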
Checklist
- Allow functions in torch.nn.functional to match modules in torch.nn. For example, F.relu and nn.ReLU should be treated as the same in pattern matching, since users cannot specify a module in a function pattern. ([Schedule] Refactor subgraph matching #35)
- Add .decompose() primitive and support decoupling bias from nn.Linear (Add flatten argument to .trace() #29)

Future plan (Updated)
The following features will be added in separate PRs.
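The F.relu / nn.ReLU equivalence mentioned in the checklist can be sketched as a canonicalization step: map each functional target to its module counterpart before comparing. The table contents and function names below are illustrative assumptions, not slapo's actual mapping:

```python
# Hypothetical equivalence table between functional ops and their
# module counterparts (keys mirror torch.nn.functional names, values
# mirror torch.nn class names; illustrative only).
FUNC_TO_MODULE = {
    "relu": "ReLU",
    "gelu": "GELU",
    "dropout": "Dropout",
}

def same_op(op_a, op_b):
    """Treat a call_function target (e.g. 'relu') and a call_module
    class name (e.g. 'ReLU') as the same op during pattern matching."""
    def canonical(name):
        return FUNC_TO_MODULE.get(name, name)
    return canonical(op_a) == canonical(op_b)

print(same_op("relu", "ReLU"))  # -> True
print(same_op("relu", "GELU"))  # -> False
```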