
[onert] Remove training info from nnfw_train_prepare #12383

Merged: 2 commits merged into Samsung:master on Dec 29, 2023

Conversation

@zetwhite (Contributor) commented Dec 28, 2023

This PR removes the training info argument from nnfw_train_prepare and makes nnfw_train_prepare use the training info stored in the session instead.

ONE-DCO-1.0-Signed-off-by: SeungHui Youn sseung.youn@samsung.com

Related: #11692
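
For context, a minimal sketch of the resulting C-API change; the exact declarations (and the header they live in) are assumptions inferred from this thread, not quoted from the repository:

  // Before this PR (assumed): training info was passed as an argument.
  // NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info);

  // After this PR (assumed): the argument is gone and the training info
  // already stored in the session is used instead.
  NNFW_STATUS nnfw_train_prepare(nnfw_session *session);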

Review comment on lines 1266 to +1268 (in the diff below, "-" marks removed lines and "+" marks added lines):
  try
  {
-   nnfw_train_info tinfo;
-   if (info != nullptr)
-   {
-     tinfo = *info;
-   }
-
-   auto convertLossType = [](const int &type) {
-     if (type == NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
-       return onert::ir::train::LossCode::MeanSquaredError;
-     if (type == NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY)
-       return onert::ir::train::LossCode::CategoricalCrossentropy;
-     else
-       throw std::runtime_error("not supported loss type");
-   };
-   auto convertLossReductionType = [](const int &type) {
-     if (type == NNFW_TRAIN_LOSS_REDUCTION_AUTO)
-       return onert::ir::train::LossReductionType::Auto;
-     else if (type == NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
-       return onert::ir::train::LossReductionType::SumOverBatchSize;
-     else if (type == NNFW_TRAIN_LOSS_REDUCTION_SUM)
-       return onert::ir::train::LossReductionType::Sum;
-     else
-       throw std::runtime_error("not supported loss reduction type");
-   };
-   onert::ir::train::LossInfo loss_info;
-   loss_info.loss_code = convertLossType(tinfo.loss_info.loss);
-   // TODO Consider the reduction type of model file
-   loss_info.reduction_type = convertLossReductionType(tinfo.loss_info.reduction_type);
-
-   auto convertOptType = [](const int &type) {
-     if (type == NNFW_TRAIN_OPTIMIZER_SGD)
-       return onert::ir::train::OptimizerCode::SGD;
-     else if (type == NNFW_TRAIN_OPTIMIZER_ADAM)
-       return onert::ir::train::OptimizerCode::Adam;
-     else
-       throw std::runtime_error("not supported optimizer type");
-   };
-   onert::ir::train::OptimizerInfo opt_info;
-   opt_info.learning_rate = tinfo.learning_rate;
-   opt_info.optim_code = convertOptType(tinfo.opt);
-
-   onert::ir::train::TrainingInfo training_info;
-   training_info.setBatchSize(tinfo.batch_size);
-   training_info.setLossInfo(loss_info);
-   training_info.setOptimizerInfo(opt_info);
+   if (not _train_info->isValid())
+     throw std::runtime_error{"training info is not valid"};

    auto compiler =
-     onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, &training_info);
+     onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, _train_info.get());
@zetwhite (Contributor, Author) commented Dec 28, 2023:

Instead of using the training info from the argument, use _train_info in the session.

We expect the training info in the session to be set by these APIs (see the sketch after this list):

  • nnfw_load_model.. - sets the training info from the model metadata.
  • nnfw_train_set_traininfo - set manually by the user.
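
A hypothetical call sequence under the new flow (a sketch only: nnfw_create_session and nnfw_load_model_from_file are from the public nnfw API, the nnfw_train_info field names follow the diff above, and the exact signatures are assumptions):

  nnfw_session *session = nullptr;
  nnfw_create_session(&session);

  // Path 1: loading the model may fill the session's training info
  // from the model metadata, if the model carries any.
  nnfw_load_model_from_file(session, "model.circle");

  // Path 2: the user sets it manually before preparing.
  nnfw_train_info tinfo;
  tinfo.learning_rate = 0.001f;
  tinfo.batch_size = 32;
  tinfo.loss_info.loss = NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY;
  tinfo.loss_info.reduction_type = NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE;
  tinfo.opt = NNFW_TRAIN_OPTIMIZER_ADAM;
  nnfw_train_set_traininfo(session, &tinfo);

  // No training-info argument anymore; the session's copy is used.
  nnfw_train_prepare(session);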

@zetwhite force-pushed the 1228/api_setter branch 2 times, most recently from 67ee378 to 545d78e, on December 29, 2023 02:09
@zetwhite changed the title from "[onert] Introduce nnfw_train_set_traininfo() API" to "[onert] Remove training info from nnfw_train_prepare" on Dec 29, 2023
@zetwhite marked this pull request as ready for review on December 29, 2023 02:30
@zetwhite added the labels approval: 2 (Require at least 2 approvals) and PR/ready for review (It is ready to review. Please review it.) on Dec 29, 2023
@zetwhite requested a review from a team on December 29, 2023 02:38
@jyoungyun (Contributor) commented:

The nnfw_session in nnfw_api_internal.h contains TrainingInfo as a unique_ptr member:

std::unique_ptr<onert::ir::train::TrainingInfo> _train_info;

This PR obtains the address of _train_info and passes it to TrainingCompiler:

onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, _train_info.get());

The TrainingCompiler stores a copy of _train_info:

const ir::train::TrainingInfo &training_info)
: _model{nnpkg->primary_model()}, _options{copts[0].get()}, _training_info{training_info}

The TrainingCompiler creates a TrainableExecutor using createTrainableExecutor:

auto executor = std::unique_ptr<exec::IExecutor>{
  ExecutorFactory::get().create(std::move(lowered_subg), executors, args, _training_info)};

createTrainableExecutor copies the optimizer and loss information from _training_info, and after that _training_info is no longer used.

At first, I thought you should use std::move instead of the get() API. However, while the session is alive, the _train_info pointer stays alive and other functions use the value behind it. So I think it's okay.
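
As a self-contained illustration of this ownership point (hypothetical stand-in types, not project code):

  #include <iostream>
  #include <memory>

  struct TrainingInfo { float learning_rate = 0.01f; }; // stand-in for onert::ir::train::TrainingInfo

  // The consumer only needs read access, like TrainingCompiler here.
  void compile(const TrainingInfo *info) { std::cout << info->learning_rate << '\n'; }

  int main()
  {
    auto train_info = std::make_unique<TrainingInfo>(); // owned by the "session"

    // get() lends the pointer; ownership stays with the session, so
    // train_info remains valid for later APIs (e.g. a future getter).
    compile(train_info.get());

    // std::move(train_info) would instead transfer ownership and leave
    // train_info null, breaking any later access while the session lives.
    std::cout << (train_info != nullptr) << '\n'; // prints 1: still owned
  }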

In addition, nnfw_session has _training_step to manage the epoch count. TrainingInfo also contains epoch information, so _training_step can be merged into TrainingInfo; std::move should not be performed, because this epoch count must be maintained while the session is alive.
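
A hypothetical shape of that merge (an assumption about the follow-up, not code from this PR):

  #include <cstdint>

  // Sketch: TrainingInfo absorbing the session-level step counter so that
  // epoch/step state lives in one place for the session's lifetime.
  class TrainingInfo
  {
  public:
    uint32_t trainingStep() const { return _training_step; }
    void recordTrainingStep() { ++_training_step; } // was tracked by nnfw_session::_training_step

  private:
    uint32_t _training_step = 0;
  };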

I left a long comment, but what I want to say is "LGTM". Thank you :)


@glistening (Contributor) left a comment:

LGTM. It breaks API compatibility. Let's merge at once.

@jyoungyun (Contributor) commented:

Please split this into two PRs: (api and tools)

If we split this into two PRs, it causes a build error.

@glistening merged commit 5bf552a into Samsung:master on Dec 29, 2023. 10 checks passed.
@zetwhite (Contributor, Author) commented:

@jyoungyun

Thank you a lot for checking these details :)

However, while the session is alive, the _train_info pointer stays alive and other functions use the value behind it.

I agree. Also, nnfw_train_get_traininfo (not yet merged into the master branch, but planned) needs to access _train_info in the session.

The TrainingCompiler stores a copy of _train_info.

I think this copy is redundant. Maybe we could remove it by making TrainingCompiler hold a raw pointer instead of the structure itself, similar to _options in the code below:

private:
std::shared_ptr<ir::Model> _model;
CompilerOptions *_options;
const ir::train::TrainingInfo _training_info;
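
A minimal sketch of that suggestion, with hypothetical simplified types rather than the project's real classes:

  #include <memory>

  struct TrainingInfo { int batch_size = 1; }; // stand-in for ir::train::TrainingInfo

  class TrainingCompiler
  {
  public:
    explicit TrainingCompiler(const TrainingInfo *info) : _training_info{info} {}
    int batchSize() const { return _training_info->batch_size; }

  private:
    const TrainingInfo *_training_info; // non-owning pointer instead of a copy, like _options
  };

  int main()
  {
    auto info = std::make_unique<TrainingInfo>(); // the session keeps ownership
    TrainingCompiler compiler{info.get()};        // no copy of the struct is made
    return compiler.batchSize() == 1 ? 0 : 1;
  }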

@jyoungyun (Contributor) commented:

I think this copy is redundant. Maybe we could remove it by making TrainingCompiler hold a raw pointer instead of the structure itself.

I agree with this point :)
And I'm making a PR to merge _training_step into TrainingInfo. Once that commit is merged, TrainingInfo will be more useful in various ways.

@zetwhite (Contributor, Author) commented Dec 29, 2023:

In addition, nnfw_session has _training_step to manage the epoch count.

@jyoungyun

Let me add some background about this, just in case.

I added epoch to TrainingInfo because the circle+ schema (the training-info flatbuffers schema) holds epoch information. (Sorry for the inner link 🔗)

Since there is no way to keep a model file's training info except by saving it in session._train_info, I put the model file's epoch in the session, just in case. But there is no exact plan to use the epoch in the session.

@jyoungyun (Contributor) commented:

The epoch in the circle+ schema is not necessary information, because the epoch count is not related to the training model. We should remove the epoch info from the circle+ schema. :)
