
[AMP][Pass][Typing] Add faster type inference #9735

Merged
merged 19 commits into from Jan 4, 2022

Conversation

AndrewZhaoLuo
Contributor

@AndrewZhaoLuo AndrewZhaoLuo commented Dec 14, 2021

This PR adds a faster type inference pass designed specifically for the Automatic Mixed Precision (AMP) pass. The issue is that the AMP pass uses the existing type inference infrastructure extensively, but that infrastructure was not designed for the AMP workload.

AMP works by walking the expression graph in topological order, replacing nodes with casted versions, and it relies on type inference extensively to do this. However, in order to use type inference we must, for every subgraph, build an IRModule and run inference on it. The current type inferencer ignores previously populated type information and essentially repopulates the type fields of the entire subgraph under examination. With N nodes arranged in a linear chain, the AMP pass examines N subgraphs, and for the i-th subgraph the IRModule construction and type inference touch i nodes. This gives us at least O(N^2) runtime, which is bad.
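The quadratic blow-up can be made concrete with a toy Python model. This is not TVM code; `infer_type_from_scratch` and `amp_pass_cost` are hypothetical names for illustration only:

```python
# Toy model (not TVM code): a linear chain of N nodes where "type inference"
# for node i must revisit every one of its i predecessors, as the existing
# inferencer effectively does when handed a fresh IRModule per subgraph.

def infer_type_from_scratch(node_index):
    """Re-derives types for nodes 0..node_index: node_index + 1 visits."""
    return node_index + 1

def amp_pass_cost(n):
    """AMP examines one subgraph per node, re-running full inference each time."""
    return sum(infer_type_from_scratch(i) for i in range(n))

# 1 + 2 + ... + N = N(N+1)/2 visits total, i.e. O(N^2).
assert amp_pass_cost(100) == 100 * 101 // 2
```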

The key issues are therefore:

  1. Type inference needs an IRModule, and building one touches all nodes in the graph
  2. Type inference looks at all nodes in the graph and does not reuse existing type information

The solution I came up with is a bit of a hack that lets me avoid rewriting the type inference pass (which is essential infrastructure and would take a long time to change). Given an expression graph with partially populated type information, we can very easily construct, for any subgraph, an analogous graph with the same type: we just replace nodes whose types are already known with a constant or variable expression of that type. This means that if we are only interested in the type of a single node, we can extract a much smaller subgraph containing all the information needed to infer its type. We then build an IRModule and run standard type inference on this much smaller subgraph.
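The extraction idea can be sketched roughly as follows. This is a minimal Python model with hypothetical `Node` and `extract_same_typed_subgraph` stand-ins, not TVM's real Relay classes; the actual pass replaces already-typed sub-expressions with constants or variables of the same type:

```python
# Sketch (hypothetical classes, not TVM's API): wherever a sub-expression
# already has a known checked type, replace it with a dummy variable carrying
# that type. The result has the same type as the original expression but is
# only as deep as the untyped frontier, so inference on it is cheap.
from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    op: str
    args: List["Node"] = field(default_factory=list)
    checked_type: Optional[str] = None  # populated type info, if any

def extract_same_typed_subgraph(expr: Node) -> Node:
    if expr.checked_type is not None:
        # Child's type is already known: stand in a typed dummy var,
        # so inference never needs to descend into it.
        return Node("var", checked_type=expr.checked_type)
    return Node(expr.op, [extract_same_typed_subgraph(a) for a in expr.args])

# A chain where everything below `mid` is already typed:
leafy = Node("conv2d", checked_type="float16")
mid = Node("add", [leafy, Node("const", checked_type="float16")])
top = Node("relu", [mid])

small = extract_same_typed_subgraph(top)
# Only `relu` and `add` survive; the typed children become dummy vars.
assert [a.op for a in small.args[0].args] == ["var", "var"]
```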

This yields a roughly 100x reduction in AMP pass runtime: for example, arcfaceresnet100 on a 2020 M1 MacBook Pro went from 20s to 0.2s.

@AndrewZhaoLuo AndrewZhaoLuo changed the title [WIP] Add faster type inference [WIP][AMP][Pass][Typing] Add faster type inference Dec 14, 2021
@AndrewZhaoLuo AndrewZhaoLuo changed the title [WIP][AMP][Pass][Typing] Add faster type inference [AMP][Pass][Typing] Add faster type inference Dec 14, 2021
@AndrewZhaoLuo
Contributor Author

Discussed with @jroesch and @mbs-octoml; the main changes we want to make are renaming "Fast" --> "Local" and better documenting the pre-conditions.

@AndrewZhaoLuo
Contributor Author

This is now ready for review

Contributor

@mbs-octoml mbs-octoml left a comment


Never looked at to_mixed_precision.cc before but boy do I see why this would help!
Just some nits, thanks, pretty sure this is going to get more use.

return mod->Lookup("main").as<FunctionNode>()->body->checked_type();
Type checked_type = expr->checked_type_;
if (checked_type.defined()) {
return checked_type;
Contributor

// The expression has not been changed AND its existing type
// is known to still be valid. (See the special handling for tuples etc.
// below for where we null out checked_type_ when we cannot be
// sure it is still valid.)

(though see my comment below)

Contributor Author

Done

@@ -381,6 +381,18 @@ class MixedPrecisionPass : public MixedModeMutator {
return Call(cur_op, new_args, pre_call_node->attrs, new_arg_types, pre_call_node->span);
}

Expr Rewrite_(const TupleGetItemNode* pre, const Expr& post) {
// The old checked type in the expression may not be valid so clear it
post->checked_type_ = Type(nullptr);
Contributor

am I missing something, or will checked_type_ = null iff some sub-expression of post has been rewritten and thus its type has changed?
i.e. checked_type_ is non-null only if pre == post.get() ??

Contributor Author

Hmm, you would think so, but it looks like the mutator does not by default invalidate the checked_type (and appears to reuse the reference, giving us this problem).

I can dig a little deeper, but if I remove this line for TupleGetItemNode the checked type will be wrong (it will be fp32 instead of fp16)

Contributor Author

https://github.com/apache/tvm/blob/main/src/relay/ir/expr_functor.cc#L248

Here is the behavior for generating post; there is some copy-on-write stuff whose full mechanics I don't quite understand, so 🤷

Contributor

Ah! It's the COW, that makes sense. I think that means we should be clearing checked_type_ on COW but let's not dig ourselves any deeper until we've thought about incremental type inference a bit more carefully.
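A toy illustration of the copy-on-write hazard discussed here (hypothetical classes, not TVM's actual COW machinery): when a mutator returns the original object because nothing structurally changed, the stale cached type survives with it.

```python
# Toy model of the stale-type hazard: a mutator that reuses the original
# object when its inputs are unchanged also reuses its cached type
# annotation, which may no longer be valid after the pass.

class TupleGetItem:
    def __init__(self, tuple_value, index, checked_type=None):
        self.tuple_value = tuple_value
        self.index = index
        self.checked_type = checked_type  # cached from a previous inference run

def mutate(node, new_tuple_value):
    if new_tuple_value is node.tuple_value:
        # Copy-on-write: nothing changed, so the same object (and its
        # possibly stale cached type) is returned.
        return node
    return TupleGetItem(new_tuple_value, node.index)  # fresh node, type cleared

shared_tuple = ("tensor", "float32")
old = TupleGetItem(shared_tuple, 0, checked_type="float32")
post = mutate(old, shared_tuple)
assert post is old and post.checked_type == "float32"  # stale fp32 survives

# Hence the fix in this PR: explicitly null out the rewritten node's type.
post.checked_type = None
```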

@@ -824,8 +824,107 @@ void AddGlobalTypes(IRModule mod) {
}
}

class SameTypedSubgraphExtractor : public ExprMutator {
/*
Contributor

micro nit: move to before the class, use /*! etc.

Contributor

nit: Returns the largest sub-graph whose inner nodes need types and whose leaves are vars standing in
for already typed sub-expressions.

Contributor Author

Done

}

private:
Expr get_analogous_expression(const Expr& expr) {
Contributor

nit: GetAnalogousExpression

Contributor Author

Done

return VisitExpr(expr);
}

return Var("dummy_var", expr->checked_type(), expr->span);
Contributor

// Since the expression already has a checked_type, which we trust, we don't need
// full type inference to enter it. So stub it out with a dummy var of the same type.

Contributor Author

Done

@AndrewZhaoLuo
Contributor Author

AndrewZhaoLuo commented Dec 17, 2021

Was trying to play around with replacing some type inference in the dynamic_to_static pass and ran into some small errors related to FunctionNodes, so I'm going to add a few basic tests.

@AndrewZhaoLuo
Contributor Author

Added (or rather replaced) some tests. PTAL @mbs-octoml

@masahi masahi merged commit 9cc1df6 into apache:main Jan 4, 2022
@FrozenGene
Member

FrozenGene commented Jan 4, 2022

@AndrewZhaoLuo Sorry for the late reply. Does this help us solve the ADT problem in our MixedPrecision pass? Let us imagine we have one if node in our Relay graph; it will be converted into two subgraphs as mentioned by you in this PR. For example:

fn main():
    let %1 = xxx;
    let %2 = if (%1) {
    let %3: = @func___inference_a(%4, %5, %6) 
  } else {
    let %7: = @func___inference_b(%8, %9)
  };  

Then we have two subgraphs, func___inference_a and func___inference_b. Does this help us make our two subgraphs type-infer correctly? As I see, you have supported GlobalVarNode.

@AndrewZhaoLuo
Contributor Author

AndrewZhaoLuo commented Jan 7, 2022

@AndrewZhaoLuo Sorry for the late reply. Does this help us solve the ADT problem in our MixedPrecision pass? Let us imagine we have one if node in our Relay graph; it will be converted into two subgraphs as mentioned by you in this PR. For example:

fn main():
    let %1 = xxx;
    let %2 = if (%1) {
    let %3: = @func___inference_a(%4, %5, %6) 
  } else {
    let %7: = @func___inference_b(%8, %9)
  };  

Then we have two subgraphs, func___inference_a and func___inference_b. Does this help us make our two subgraphs type-infer correctly? As I see, you have supported GlobalVarNode.

@FrozenGene I'm not sure I understand the concern 😅; global var nodes are just used to reference function calls, right? These functions have a known type ahead of time, right?

ylc pushed a commit to ylc/tvm that referenced this pull request Jan 7, 2022
* reuse checked types

* analogous subgraph

* brr go fast

* clean up src logs

* clean up PR more

* more clean up

* more documenetation

* clean up

* formatting

* rename fast --> local

* more ocmments

* jostle ci

* type inference

* change comment for SameTypedSubgraphExtractor

* get_analogous_expression -> GetAnalogousExpression

* comment in GetAnaalogousExpression

* add comment

* replace infer tests

* jostle
@FrozenGene
Member

@AndrewZhaoLuo Sorry for the late reply. Does this help us solve the ADT problem in our MixedPrecision pass? Let us imagine we have one if node in our Relay graph; it will be converted into two subgraphs as mentioned by you in this PR. For example:

fn main():
    let %1 = xxx;
    let %2 = if (%1) {
    let %3: = @func___inference_a(%4, %5, %6) 
  } else {
    let %7: = @func___inference_b(%8, %9)
  };  

Then we have two subgraphs, func___inference_a and func___inference_b. Does this help us make our two subgraphs type-infer correctly? As I see, you have supported GlobalVarNode.

@FrozenGene I'm not sure I understand the concern 😅; global var nodes are just used to reference function calls, right? These functions have a known type ahead of time, right?

@AndrewZhaoLuo Yes. In fact, I saw your PR supports GlobalVarNode; I thought you would leverage it to solve this TODO: https://github.com/apache/tvm/blob/main/src/relay/transforms/to_mixed_precision.cc#L297

@AndrewZhaoLuo
Contributor Author

@FrozenGene ah yes, the type inference will work, but we need to think about how to handle it properly for AMP; when I initially wrote AMP I ignored constructs not usually found in most real-life models. It is on the list of TODOs here: #8296

ylc pushed a commit to ylc/tvm that referenced this pull request Jan 13, 2022