Skip to content

Commit

Permalink
Fix to save the state of FTRL models (#1912)
Browse files Browse the repository at this point in the history
* first version of the KT algorithm

* changed from 'kt' to 'approximate cocob' and implemented normalization by average length of the feature vectors

* variant that works with squared loss

* fixed all bugs: works great in binary classification, slightly worse than the default one in oaa and cbify

* cleaned version, no bias used

* bias and fix bug

* another bug fix

* removed bias and added default params for logistic

* prediction is now stateless

* added comments

* added tests

* fix to ftrl state saving

* moved ftrl_size to a parameter of save_load_online_state
  • Loading branch information
bremen79 authored and JohnLangford committed Jun 5, 2019
1 parent 61aafc8 commit 538e9bb
Show file tree
Hide file tree
Showing 4 changed files with 31 additions and 11 deletions.
3 changes: 2 additions & 1 deletion test/save_resume_test.py
Expand Up @@ -153,8 +153,9 @@ def do_test(filename, args, verbose=None, repeat_args=None, known_failure=False)
errors += do_test(filename, '--loss_function logistic --link logistic')
errors += do_test(filename, '--nn 2')
errors += do_test(filename, '--binary')
errors += do_test(filename, '--ftrl', known_failure=True)
errors += do_test(filename, '--ftrl')
errors += do_test(filename, '--pistol', known_failure=True)
errors += do_test(filename, '--coin', known_failure=True)

# this one also fails but pollutes output
#errors += do_test(filename, '--ksvm', known_failure=True)
Expand Down
8 changes: 6 additions & 2 deletions vowpalwabbit/ftrl.cc
Expand Up @@ -39,6 +39,7 @@ struct ftrl
size_t no_win_counter;
size_t early_stop_thres;
double total_weight;
uint32_t ftrl_size;
};

struct uncertainty
Expand Down Expand Up @@ -151,7 +152,7 @@ void inner_update_pistol_state_and_predict(update_data& d, float x, float& wref)

float squared_theta = w[W_ZT] * w[W_ZT];
float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX]));
w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2 * tmp) * tmp;
w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp;

d.predict += w[W_XT] * x;
}
Expand Down Expand Up @@ -315,7 +316,7 @@ void save_load(ftrl& b, io_buf& model_file, bool read, bool text)
bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

if (resume)
GD::save_load_online_state(*all, model_file, read, text);
GD::save_load_online_state(*all, model_file, read, text, nullptr, b.ftrl_size);
else
GD::save_load_regressor(*all, model_file, read, text);
}
Expand Down Expand Up @@ -389,18 +390,21 @@ base_learner* ftrl_setup(options_i& options, vw& all)
else
learn_ptr = learn_proximal<false>;
all.weights.stride_shift(2); // NOTE: for more parameter storage
b->ftrl_size = 3;
}
else if (pistol)
{
algorithm_name = "PiSTOL";
learn_ptr = learn_pistol;
all.weights.stride_shift(2); // NOTE: for more parameter storage
b->ftrl_size = 4;
}
else if (coin)
{
algorithm_name = "Coin Betting";
learn_ptr = learn_cb;
all.weights.stride_shift(3); // NOTE: for more parameter storage
b->ftrl_size = 6;
}

b->data.ftrl_alpha = b->ftrl_alpha;
Expand Down
28 changes: 21 additions & 7 deletions vowpalwabbit/gd.cc
Expand Up @@ -762,7 +762,7 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
}

template <class T>
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, uint32_t ftrl_size, T& weights)
{
uint64_t length = (uint64_t)1 << all.num_bits;

Expand All @@ -786,8 +786,10 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
if (i >= length)
THROW("Model content is corrupted, weight vector index " << i << " must be less than total vector length "
<< length);
weight buff[4] = {0, 0, 0, 0};
if (g == NULL || (!g->adaptive && !g->normalized))
weight buff[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (ftrl_size>0)
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * ftrl_size, "");
else if (g == NULL || (!g->adaptive && !g->normalized))
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]), "");
else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * 2, "");
Expand All @@ -812,7 +814,19 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
else
brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);

if (g == nullptr || (!g->adaptive && !g->normalized))
if (ftrl_size==3) {
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
}
else if (ftrl_size==4) {
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text);
}
else if (ftrl_size==6) {
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " " << (&(*v))[5] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text);
}
else if (g == nullptr || (!g->adaptive && !g->normalized))
{
msg << ":" << *v << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
Expand All @@ -832,7 +846,7 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
}
}

void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, uint32_t ftrl_size)
{
// vw& all = *g.all;
stringstream msg;
Expand Down Expand Up @@ -931,9 +945,9 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
all.current_pass = 0;
}
if (all.weights.sparse)
save_load_online_state(all, model_file, read, text, g, msg, all.weights.sparse_weights);
save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.sparse_weights);
else
save_load_online_state(all, model_file, read, text, g, msg, all.weights.dense_weights);
save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.dense_weights);
}

template <class T>
Expand Down
3 changes: 2 additions & 1 deletion vowpalwabbit/gd.h
Expand Up @@ -23,7 +23,8 @@ struct gd;
float finalize_prediction(shared_data* sd, float ret);
void print_audit_features(vw&, example& ec);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr, uint32_t ftrl_size = 0);


template <class T>
struct multipredict_info
Expand Down

0 comments on commit 538e9bb

Please sign in to comment.