[microNPU] Add support for SIGMOID #9627
Conversation
LGTM Overall! Just a couple minor things I noticed :)
        return y


    class SigmoidRewriter(DFPatternCallback):
This looks very similar to TanhRewriter, is it worth creating a subclass?
Good idea! I made the change
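For readers following the thread, the shared-base refactor discussed above can be sketched in plain Python. This is illustrative only: the class and composite names below are assumptions, and the real rewriters subclass TVM's DFPatternCallback, which is omitted here to keep the sketch self-contained.

```python
import math


class LutActivationRewriter:
    """Sketch of a shared base for LUT-based activation rewriters.

    The real implementation would subclass
    tvm.relay.dataflow_pattern.DFPatternCallback; here we only model the
    constructor state the two rewriters share.
    """

    def __init__(self, composite_name: str, activation_type: str, calc_func):
        self.composite_name = composite_name
        self.activation_type = activation_type
        # Scalar function used later to populate the NPU lookup table.
        self.calc_func = calc_func


class TanhRewriter(LutActivationRewriter):
    def __init__(self):
        # "ethos-u.tanh" is an illustrative composite name, not verbatim.
        super().__init__("ethos-u.tanh", "TANH", math.tanh)


class SigmoidRewriter(LutActivationRewriter):
    def __init__(self):
        super().__init__(
            "ethos-u.sigmoid", "SIGMOID", lambda x: 1.0 / (1.0 + math.exp(-x))
        )
```

With this shape, each concrete rewriter only supplies its composite name, activation type string, and scalar function.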
    def is_valid(self):
        """
        This function checks whether reshape has compatible attributes with the NPU
Suggested change:
-        This function checks whether reshape has compatible attributes with the NPU
+        This function checks whether sigmoid has compatible attributes with the NPU
Done
class Model(tf.Module):
    @tf.function
    def tanh_function(self, x):
Suggested change:
-    def tanh_function(self, x):
+    def sigmoid_function(self, x):
Done
    dtype = "int8"

    def create_tflite_graph():
        tf.config.run_functions_eagerly(True)
Regarding `tf.config.run_functions_eagerly(True)`: we don't need it, do we?
No we don't :D
Force-pushed from a550c44 to 90ac839 (compare).
Thanks for the review @lhutton1!
Some type hints are missing, could you add them?
        )(wildcard())
        self.pattern = (wildcard().has_attr({"Composite": params_class.composite_name}))(wildcard())
        self.activation_type = activation_type
        self.calc_func = calc_func

    def callback(self, pre, post, node_map):
Suggested change:
-    def callback(self, pre, post, node_map):
+    def callback(self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map):
Done
@@ -194,6 +205,48 @@ def __call__(self, *args, **kwargs):
     pass


 def sigmoid_calc_func(x):
Suggested change:
-def sigmoid_calc_func(x):
+def sigmoid_calc_func(x: float) -> float:
Done
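For reference, the scalar function at the heart of this change can be sketched as follows. This is a numerically stable variant of the textbook definition sigmoid(x) = 1 / (1 + e^-x); the PR's exact body may differ.

```python
import math


def sigmoid_calc_func(x: float) -> float:
    """Compute sigmoid(x) = 1 / (1 + e^-x) for a single float."""
    # Split on the sign of x so exp() never overflows for large |x|.
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)
```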
-    def __init__(self):
+    def __init__(self, params_class, activation_type, calc_func):
Suggested change:
-    def __init__(self, params_class, activation_type, calc_func):
+    def __init__(self, params_class: Type, activation_type: str, calc_func: Callable[[float], float]):
Done
@@ -125,30 +125,30 @@ def __call__(self, *args, **kwargs):
     pass


-def find_tanh_values(ifm_scale, ifm_zp, ofm_scale, ofm_zp):
-    """Method to calculate the values of the tanh lookup table"""
+def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
Suggested change:
-def get_lut_from_func(ifm_scale, ifm_zp, ofm_scale, ofm_zp, func):
+def get_lut_from_func(ifm_scale: float, ifm_zp: int, ofm_scale: float, ofm_zp: int, func: Callable[[float], float]) -> list[int]:
Done
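To make the generalisation concrete: a lookup table for an arbitrary scalar function can be built from the quantisation parameters roughly as below. This is an illustrative reconstruction using the suggested signature, not the PR's verbatim code; the signed int8 range and round-half-even behaviour are assumptions.

```python
from typing import Callable, List


def get_lut_from_func(
    ifm_scale: float,
    ifm_zp: int,
    ofm_scale: float,
    ofm_zp: int,
    func: Callable[[float], float],
) -> List[int]:
    """Build a 256-entry int8 lookup table for `func`."""
    lut = []
    for x in range(-128, 128):
        # Dequantise the int8 input value to a real number,
        x_real = ifm_scale * (x - ifm_zp)
        # apply the activation in real space,
        out_real = func(x_real)
        # then requantise and clamp back to the int8 range.
        lut_result = int(round(out_real / ofm_scale)) + ofm_zp
        lut.append(min(max(lut_result, -128), 127))
    return lut
```

Passing `math.tanh` or a sigmoid function as `func` then yields the per-activation tables the NPU consumes, which is what lets the tanh-specific helper be generalised.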
Apologies, something I missed in my first review. Not essential, but given @NicolaLancellotti's review I suppose there's no harm in including this also :)
@@ -1035,6 +1035,35 @@ def mean_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
     return pattern


+class SigmoidParams:
Should we also create a subclass for this and TanhParams?
Done
Add support for SIGMOID activation function using the lookup table mechanism in the NPU.

Force-pushed from 90ac839 to f818ef3 (compare).
Thanks for the reviews @NicolaLancellotti and @lhutton1! :)
Thanks @ekalda, LGTM!
LGTM!
LGTM.
Thanks all!