diff --git a/vowpalwabbit/core/include/vw/core/learner.h b/vowpalwabbit/core/include/vw/core/learner.h index 02a97cef1a9..2bffe33d81f 100644 --- a/vowpalwabbit/core/include/vw/core/learner.h +++ b/vowpalwabbit/core/include/vw/core/learner.h @@ -482,10 +482,18 @@ class common_learner_builder this->learner_ptr->_cleanup_example_f = [fn_ptr, data](polymorphic_ex ex) { fn_ptr(*data, ex); }; ) + // Set the function pointer for persisting metrics for the learner. + // This function enforces that force reductions can only add metrics in their own namespace. + // The metrics output dictionary is a dictionary of dictionaries, where each sub-dictionary + // contains the metrics for a single learner. The key for each sub-dictionary is the learner's name. LEARNER_BUILDER_DEFINE(set_persist_metrics(void (*fn_ptr)(DataT&, metric_sink&)), assert(fn_ptr != nullptr); DataT* data = this->learner_data.get(); - this->learner_ptr->_persist_metrics_f = [fn_ptr, data](metric_sink& metrics) { fn_ptr(*data, metrics); }; + std::string learner_name = this->name; + this->learner_ptr->_persist_metrics_f = [fn_ptr, data, learner_name](metric_sink& metrics) { + metrics[learner_name] = {}; + fn_ptr(*data, metrics[learner_name]); + }; ) LEARNER_BUILDER_DEFINE(set_pre_save_load(void (*fn_ptr)(VW::workspace& all, DataT&)),