diff --git a/runtime/onert/api/include/nnfw_experimental.h b/runtime/onert/api/include/nnfw_experimental.h
index 4c25b04e5e9..93ff9de2808 100644
--- a/runtime/onert/api/include/nnfw_experimental.h
+++ b/runtime/onert/api/include/nnfw_experimental.h
@@ -233,16 +233,14 @@ NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_inf
  * @brief Prepare session to be ready for training
  * @note The session will be entered into training mode
  *
+ * If training info is NOT set in session, this function returns @c NNFW_STATUS_ERROR .
+ * You should set training info using {@link nnfw_train_set_traininfo}.
+ *
  * @param[in] session The session to be prepared for training
- * @param[in] info    Training information.
- *                    If info is nullptr, it will not change training information.
- *                    If it is nullptr and model has not training information,
- *                    it will use default training information.
- *                    Default training information is {learning_rate = 0.001f, batch_size = 1}
  *
  * @return @c NNFW_STATUS_NO_ERROR if successful
  */
-NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info);
+NNFW_STATUS nnfw_train_prepare(nnfw_session *session);
 
 /**
  * @brief Set training input
diff --git a/runtime/onert/api/src/nnfw_api.cc b/runtime/onert/api/src/nnfw_api.cc
index a713b66ac87..6ebaf0970f1 100644
--- a/runtime/onert/api/src/nnfw_api.cc
+++ b/runtime/onert/api/src/nnfw_api.cc
@@ -394,10 +394,10 @@ NNFW_STATUS nnfw_train_set_traininfo(nnfw_session *session, const nnfw_train_inf
   return session->train_set_traininfo(info);
 }
 
-NNFW_STATUS nnfw_train_prepare(nnfw_session *session, const nnfw_train_info *info)
+NNFW_STATUS nnfw_train_prepare(nnfw_session *session)
 {
   NNFW_RETURN_ERROR_IF_NULL(session);
-  return session->train_prepare(info);
+  return session->train_prepare();
 }
 
 NNFW_STATUS nnfw_train_input_tensorinfo(nnfw_session *session, uint32_t index,
diff --git a/runtime/onert/api/src/nnfw_api_internal.cc b/runtime/onert/api/src/nnfw_api_internal.cc
index fbdd0c7c00c..b28b65f6777 100644
--- a/runtime/onert/api/src/nnfw_api_internal.cc
+++ b/runtime/onert/api/src/nnfw_api_internal.cc
@@ -1242,7 +1242,7 @@ NNFW_STATUS nnfw_session::train_set_traininfo(const nnfw_train_info *info)
   return NNFW_STATUS_NO_ERROR;
 }
 
-NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info)
+NNFW_STATUS nnfw_session::train_prepare()
 {
   // We may need different state to represent training model is loaded
   if (!isStateModelLoaded())
@@ -1256,56 +1256,16 @@ NNFW_STATUS nnfw_session::train_prepare(const nnfw_train_info *info)
     return NNFW_STATUS_INVALID_STATE;
   }
 
+  // after model loaded, it ensures that _train_info is not nullptr
+  assert(_train_info != nullptr);
+
   try
   {
-    nnfw_train_info tinfo;
-    if (info != nullptr)
-    {
-      tinfo = *info;
-    }
-
-    auto convertLossType = [](const int &type) {
-      if (type == NNFW_TRAIN_LOSS_MEAN_SQUARED_ERROR)
-        return onert::ir::train::LossCode::MeanSquaredError;
-      if (type == NNFW_TRAIN_LOSS_CATEGORICAL_CROSSENTROPY)
-        return onert::ir::train::LossCode::CategoricalCrossentropy;
-      else
-        throw std::runtime_error("not supported loss type");
-    };
-    auto convertLossReductionType = [](const int &type) {
-      if (type == NNFW_TRAIN_LOSS_REDUCTION_AUTO)
-        return onert::ir::train::LossReductionType::Auto;
-      else if (type == NNFW_TRAIN_LOSS_REDUCTION_SUM_OVER_BATCH_SIZE)
-        return onert::ir::train::LossReductionType::SumOverBatchSize;
-      else if (type == NNFW_TRAIN_LOSS_REDUCTION_SUM)
-        return onert::ir::train::LossReductionType::Sum;
-      else
-        throw std::runtime_error("not supported loss reduction type");
-    };
-    onert::ir::train::LossInfo loss_info;
-    loss_info.loss_code = convertLossType(tinfo.loss_info.loss);
-    // TODO Consider the reduction type of model file
-    loss_info.reduction_type = convertLossReductionType(tinfo.loss_info.reduction_type);
-
-    auto convertOptType = [](const int &type) {
-      if (type == NNFW_TRAIN_OPTIMIZER_SGD)
-        return onert::ir::train::OptimizerCode::SGD;
-      else if (type == NNFW_TRAIN_OPTIMIZER_ADAM)
-        return onert::ir::train::OptimizerCode::Adam;
-      else
-        throw std::runtime_error("not supported optimizer type");
-    };
-    onert::ir::train::OptimizerInfo opt_info;
-    opt_info.learning_rate = tinfo.learning_rate;
-    opt_info.optim_code = convertOptType(tinfo.opt);
-
-    onert::ir::train::TrainingInfo training_info;
-    training_info.setBatchSize(tinfo.batch_size);
-    training_info.setLossInfo(loss_info);
-    training_info.setOptimizerInfo(opt_info);
+    if (not _train_info->isValid())
+      throw std::runtime_error{"training info is not valid"};
 
     auto compiler =
-      onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, &training_info);
+      onert::compiler::CompilerFactory::get().create(_nnpkg, _coptions, _train_info.get());
     _nnpkg.reset();
     _compiler_artifact = compiler->compile();
     _execution = std::make_unique<onert::exec::Execution>(_compiler_artifact->_executors);
diff --git a/runtime/onert/api/src/nnfw_api_internal.h b/runtime/onert/api/src/nnfw_api_internal.h
index 1ea67450880..5f803921fc6 100644
--- a/runtime/onert/api/src/nnfw_api_internal.h
+++ b/runtime/onert/api/src/nnfw_api_internal.h
@@ -171,7 +171,7 @@ struct nnfw_session
   NNFW_STATUS set_backends_per_operation(const char *backend_settings);
 
   NNFW_STATUS train_set_traininfo(const nnfw_train_info *info);
-  NNFW_STATUS train_prepare(const nnfw_train_info *info);
+  NNFW_STATUS train_prepare();
   NNFW_STATUS train_input_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
   NNFW_STATUS train_expected_tensorinfo(uint32_t index, nnfw_tensorinfo *ti);
   NNFW_STATUS train_set_input(uint32_t index, const void *input,
diff --git a/tests/tools/onert_train/src/onert_train.cc b/tests/tools/onert_train/src/onert_train.cc
index 4f5bf33eca7..247d09bd22e 100644
--- a/tests/tools/onert_train/src/onert_train.cc
+++ b/tests/tools/onert_train/src/onert_train.cc
@@ -185,11 +185,13 @@ int main(const int argc, char **argv)
     std::cout << "== training parameter ==" << std::endl;
     std::cout << tri;
     std::cout << "========================" << std::endl;
+
+    NNPR_ENSURE_STATUS(nnfw_train_set_traininfo(session, &tri));
+
     // prepare execution
     // TODO When nnfw_{prepare|run} are failed, can't catch the time
-    measure.run(PhaseType::PREPARE,
-                [&]() { NNPR_ENSURE_STATUS(nnfw_train_prepare(session, &tri)); });
+    measure.run(PhaseType::PREPARE, [&]() { NNPR_ENSURE_STATUS(nnfw_train_prepare(session)); });
 
     // prepare input and expected tensor info lists
     std::vector<nnfw_tensorinfo> input_infos;