Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[onert] Remove training info from nnfw_train_prepare #12383

Merged
merged 2 commits into from
Dec 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 4 additions & 6 deletions runtime/onert/api/include/nnfw_experimental.h
Original file line number Diff line number Diff line change
Expand Up @@ -233,16 +233,14 @@ NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_inf
* @brief Prepare session to be ready for training
* @note The session will be entered into training mode
*
* If training info is NOT set in session, this function returns @c NNFW_STATUS_ERROR .
* You should set training info using {@link nnfw_train_set_traininfo}.
*
* @param[in] session The session to be prepared for training
* @param[in] info Training information.
* If info is nullptr, it will not change training information.
* If it is nullptr and model has not training information,
* it will use default training information.
* Default training information is {learning_rate = 0.001f, batch_size = 1}
*
* @return @c NNFW_STATUS_NO_ERROR if successful
*/
NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info);
NNFW_STATUS nnfw_train_prepare(nnfw_session *session);

/**
* @brief Set training input
Expand Down
4 changes: 2 additions & 2 deletions runtime/onert/api/src/nnfw_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -394,10 +394,10 @@ NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_inf
return session->train_set_traininfo(info);
}

NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info)
NNFW_STATUS nnfw_train_prepare(nnfw_session *session)
{
NNFW_RETURN_ERROR_IF_NULL(session);
return session->train_prepare(info);
return session->train_prepare();
}

NNFW_STATUS nnfw_train_input_tensorinfo(nnfw_session *session, uint32_t index,
Expand Down
54 changes: 7 additions & 47 deletions runtime/onert/api/src/nnfw_api_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1242,7 +1242,7 @@ NNFW_STATUS nnfw_session::train_set_traininfo(const nnfw_train_info *info)
return NNFW_STATUS_NO_ERROR;
}

NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info)
NNFW_STATUS nnfw_session::train_prepare()
{
// We may need different state to represent training model is loaded
if (!isStateModelLoaded())
Expand All @@ -1256,56 +1256,16 @@ NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info)
return NNFW_STATUS_INVALID_STATE;
}

// after model loaded, it ensures that _train_info is not nullptr
assert(_train_info != nullptr);

try
{
nnfw_train_info tinfo;
if (info != nullptr)
{
tinfo = *info;
}

auto convertLossType = [](const int &type) {
if (type == NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
return onert::ir::train::LossCode::MeanSquaredError;
if (type == NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY)
return onert::ir::train::LossCode::CategoricalCrossentropy;
else
throw std::runtime_error("not supported loss type");
};
auto convertLossReductionType = [](const int &type) {
if (type == NNFW_TRAIN_LOSS_REDUCTION_AUTO)
return onert::ir::train::LossReductionType::Auto;
else if (type == NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
return onert::ir::train::LossReductionType::SumOverBatchSize;
else if (type == NNFW_TRAIN_LOSS_REDUCTION_SUM)
return onert::ir::train::LossReductionType::Sum;
else
throw std::runtime_error("not supported loss reduction type");
};
onert::ir::train::LossInfo loss_info;
loss_info.loss_code = convertLossType(tinfo.loss_info.loss);
// TODO Consider the reduction type of model file
loss_info.reduction_type = convertLossReductionType(tinfo.loss_info.reduction_type);

auto convertOptType = [](const int &type) {
if (type == NNFW_TRAIN_OPTIMIZER_SGD)
return onert::ir::train::OptimizerCode::SGD;
else if (type == NNFW_TRAIN_OPTIMIZER_ADAM)
return onert::ir::train::OptimizerCode::Adam;
else
throw std::runtime_error("not supported optimizer type");
};
onert::ir::train::OptimizerInfo opt_info;
opt_info.learning_rate = tinfo.learning_rate;
opt_info.optim_code = convertOptType(tinfo.opt);

onert::ir::train::TrainingInfo training_info;
training_info.setBatchSize(tinfo.batch_size);
training_info.setLossInfo(loss_info);
training_info.setOptimizerInfo(opt_info);
if (not _train_info->isValid())
throw std::runtime_error{"training info is not valid"};

auto compiler =
onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, &training_info);
onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, _train_info.get());
Comment on lines 1262 to +1268
Copy link
Contributor Author

@zetwhite zetwhite Dec 28, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of using training info from the argument, use _train_info in the session.

We expect training info in the session to be set by these APIs:

  • nnfw_load_model.. - sets training info from model metadata.
  • nnfw_train_set_traininfo - set manually by the user

_nnpkg.reset();
_compiler_artifact = compiler->compile();
_execution = std::make_unique<onert::exec::Execution>(_compiler_artifact->_executors);
Expand Down
2 changes: 1 addition & 1 deletion runtime/onert/api/src/nnfw_api_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ struct nnfw_session
NNFW_STATUS set_backends_per_operation(const char *backend_settings);

NNFW_STATUS train_set_traininfo(const nnfw_train_info *info);
NNFW_STATUS train_prepare(const nnfw_train_info *info);
NNFW_STATUS train_prepare();
NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
NNFW_STATUS train_set_input(uint32_t index, const void *input,
Expand Down
6 changes: 4 additions & 2 deletions tests/tools/onert_train/src/onert_train.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,11 +185,13 @@ int main(const int argc, char **argv)
std::cout << "== training parameter ==" << std::endl;
std::cout << tri;
std::cout << "========================" << std::endl;

NNPR_ENSURE_STATUS(nnfw_train_set_traininfo(session, &tri));

// prepare execution

// TODO When nnfw_{prepare|run} are failed, can't catch the time
measure.run(PhaseType::PREPARE,
[&]() { NNPR_ENSURE_STATUS(nnfw_train_prepare(session, &tri)); });
measure.run(PhaseType::PREPARE, [&]() { NNPR_ENSURE_STATUS(nnfw_train_prepare(session)); });

// prepare input and expected tensor info lists
std::vector<nnfw_tensorinfo> input_infos;
Expand Down