diff --git a/test/save_resume_test.py b/test/save_resume_test.py index 08daf908db9..8b16eb0cdec 100644 --- a/test/save_resume_test.py +++ b/test/save_resume_test.py @@ -153,8 +153,9 @@ def do_test(filename, args, verbose=None, repeat_args=None, known_failure=False) errors += do_test(filename, '--loss_function logistic --link logistic') errors += do_test(filename, '--nn 2') errors += do_test(filename, '--binary') - errors += do_test(filename, '--ftrl', known_failure=True) + errors += do_test(filename, '--ftrl') errors += do_test(filename, '--pistol', known_failure=True) + errors += do_test(filename, '--coin', known_failure=True) # this one also fails but pollutes output #errors += do_test(filename, '--ksvm', known_failure=True) diff --git a/vowpalwabbit/ftrl.cc b/vowpalwabbit/ftrl.cc index 806ae9160b0..188656a5abc 100644 --- a/vowpalwabbit/ftrl.cc +++ b/vowpalwabbit/ftrl.cc @@ -39,6 +39,7 @@ struct ftrl size_t no_win_counter; size_t early_stop_thres; double total_weight; + uint32_t ftrl_size; }; struct uncertainty @@ -151,7 +152,7 @@ void inner_update_pistol_state_and_predict(update_data& d, float x, float& wref) float squared_theta = w[W_ZT] * w[W_ZT]; float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX])); - w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2 * tmp) * tmp; + w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp; d.predict += w[W_XT] * x; } @@ -315,7 +316,7 @@ void save_load(ftrl& b, io_buf& model_file, bool read, bool text) bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text); if (resume) - GD::save_load_online_state(*all, model_file, read, text); + GD::save_load_online_state(*all, model_file, read, text, nullptr, b.ftrl_size); else GD::save_load_regressor(*all, model_file, read, text); } @@ -389,18 +390,21 @@ base_learner* ftrl_setup(options_i& options, vw& all) else learn_ptr = learn_proximal; all.weights.stride_shift(2); // NOTE: for more parameter storage + b->ftrl_size = 3; } else if (pistol) { algorithm_name = "PiSTOL"; learn_ptr = learn_pistol; all.weights.stride_shift(2); // NOTE: for more parameter storage + b->ftrl_size = 4; } else if (coin) { algorithm_name = "Coin Betting"; learn_ptr = learn_cb; all.weights.stride_shift(3); // NOTE: for more parameter storage + b->ftrl_size = 6; } b->data.ftrl_alpha = b->ftrl_alpha; diff --git a/vowpalwabbit/gd.cc b/vowpalwabbit/gd.cc index 70fe76cf2f3..d0a4fd5900d 100644 --- a/vowpalwabbit/gd.cc +++ b/vowpalwabbit/gd.cc @@ -762,7 +762,7 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text) } template -void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights) +void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, uint32_t ftrl_size, T& weights) { uint64_t length = (uint64_t)1 << all.num_bits; @@ -786,8 +786,10 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g if (i >= length) THROW("Model content is corrupted, weight vector index " << i << " must be less than total vector length " << length); - weight buff[4] = {0, 0, 0, 0}; - if (g == NULL || (!g->adaptive && !g->normalized)) + weight buff[8] = {0, 0, 0, 0, 0, 0, 0, 0}; + if (ftrl_size>0) + brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * ftrl_size, ""); + else if (g == NULL || (!g->adaptive && !g->normalized)) brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]), ""); else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized)) brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * 2, ""); @@ -812,7 +814,19 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g else brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text); - if (g == nullptr || (!g->adaptive && !g->normalized)) + if (ftrl_size==3) { + msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n"; + brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text); + } + else if (ftrl_size==4) { + msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n"; + brw += bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text); + } + else if (ftrl_size==6) { + msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " " << (&(*v))[5] << "\n"; + brw += bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text); + } + else if (g == nullptr || (!g->adaptive && !g->normalized)) { msg << ":" << *v << "\n"; brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text); @@ -832,7 +846,7 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g } } -void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g) +void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, uint32_t ftrl_size) { // vw& all = *g.all; stringstream msg; @@ -931,9 +945,9 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g all.current_pass = 0; } if (all.weights.sparse) - save_load_online_state(all, model_file, read, text, g, msg, all.weights.sparse_weights); + save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.sparse_weights); else - save_load_online_state(all, model_file, read, text, g, msg, all.weights.dense_weights); + save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.dense_weights); } template diff --git a/vowpalwabbit/gd.h b/vowpalwabbit/gd.h index 697e61cba51..4c9380d3f62 100644 --- a/vowpalwabbit/gd.h +++ b/vowpalwabbit/gd.h @@ -23,7 +23,8 @@ struct gd; float finalize_prediction(shared_data* sd, float ret); void print_audit_features(vw&, example& ec); void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text); -void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr); +void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr, uint32_t ftrl_size = 0); + template struct multipredict_info