-
Notifications
You must be signed in to change notification settings - Fork 157
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[onert/api] Add nnfw_train_set_traininfo() #12384
Conversation
@@ -220,6 +220,15 @@ typedef struct nnfw_train_info | |||
NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD; | |||
} nnfw_train_info; | |||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For reviewers
This PR just adds one new API (_train_info setter - named nnfw_train_set_traininfo
).
But I also have a plan to change nnfw_train_prepare
a bit.
The overall direction is like this.
AS-IS
nnfw_train_prepare()
gets train_info and compiles training model using the given train_info
TO-BE
nnfw_set_train_traininfo()
set train_info into sessionnnfw_train_prepare()
compile training model using the train_info inside the session
c89b1e5
to
c59b036
Compare
@@ -220,6 +220,15 @@ typedef struct nnfw_train_info | |||
NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD; | |||
} nnfw_train_info; | |||
|
|||
/** | |||
* @brief Set train_info into session | |||
* @note This function should be called in MODEL_LOADED state, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
User cannot see MODEL_LOADED
state, how about removing the content related to MODEL_LOADED
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Updated :)
* | ||
* @return @c NNFW_STATUS_NO_ERROR If successful | ||
*/ | ||
NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_info *train_info); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This API should be called after model_loaded
.
Because.. If the user set train_info before the model loaded,
train_info in the session can be overwritten while the model loaded.
This can be confusing.
So, To avoid unexpected overwritten while the model loaded, Let's make sure that nnfw_train_set_traininfo
is called after the model loaded.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please remove 'into session" in description as I suggested. I think "into session" is better to be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated :)
fc516e4
to
8e545fa
Compare
@@ -220,6 +220,15 @@ typedef struct nnfw_train_info | |||
NNFW_TRAIN_OPTIMIZER opt = NNFW_TRAIN_OPTIMIZER_SGD; | |||
} nnfw_train_info; | |||
|
|||
/** | |||
* @brief Set train_info into session |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please don't use variable name for description for user.
* @brief Set train_info into session | |
* @brief Set training info |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you. Updated :)
|
||
auto convertLossReductionType = [](const int &type) { | ||
if (type == NNFW_TRAIN_LOSS_REDUCTION_AUTO) | ||
return onert::ir::train::LossReductionType::Invalid; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
return onert::ir::train::LossReductionType::Invalid; | |
return onert::ir::train::LossReductionType::Auto; |
Maybe it should be Auto type?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for catching :)
Updated
This PR addis nnfw_train_set_traininfo() API into nnfw_experimental header. The 'nnfw_train_set_traininfo()' API is to set train_info in the session. ONE-DCO-1.0-Signed-off-by: SeungHui Youn sseung.youn@samsung.com
f1e80e3
to
54378e7
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, Thank you!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM :)
This PR adds
nnfw_train_set_traininfo()
API into nnfw_experimental header.The
nnfw_train_set_traininfo()
API is to set train_info in the session.ONE-DCO-1.0-Signed-off-by: SeungHui Youn sseung.youn@samsung.com
related issue : #11692