Skip to content
Permalink
Browse files

Bremen79 fix save ftrl (#1919)

From @bremen79 

* first version of the KT algorithm

* changed from 'kt' to 'approximate cocob' and implemented normalization by average length of the feature vectors

* variant that works with squared loss

* fixed all bugs: works great in binary classification, slightly worse than the default one in oaa and cbify

* cleaned version, no bias used

* bias and fix bug

* another bug fix

* removed bias and added default params for logistic

* prediction is now stateless

* added comments

* added tests

* fix to ftrl state saving

* moved ftrl_size to a parameter of save_load_online_state

* fixed all the bugs related to resume models

* removed comments

* remove dependence on global
  • Loading branch information...
JohnLangford committed Jun 6, 2019
1 parent 47be22d commit cb3019b8176180e5a4c32633ca9763a3eb563e16
Showing with 85 additions and 52 deletions.
  1. +3 −2 test/save_resume_test.py
  2. +2 −1 vowpalwabbit/OjaNewton.cc
  3. +7 −3 vowpalwabbit/ftrl.cc
  4. +69 −44 vowpalwabbit/gd.cc
  5. +2 −1 vowpalwabbit/gd.h
  6. +2 −1 vowpalwabbit/svrg.cc
@@ -153,8 +153,9 @@ def do_test(filename, args, verbose=None, repeat_args=None, known_failure=False)
errors += do_test(filename, '--loss_function logistic --link logistic')
errors += do_test(filename, '--nn 2')
errors += do_test(filename, '--binary')
errors += do_test(filename, '--ftrl', known_failure=True)
errors += do_test(filename, '--pistol', known_failure=True)
errors += do_test(filename, '--ftrl')
errors += do_test(filename, '--pistol')
errors += do_test(filename, '--coin')

# this one also fails but pollutes output
#errors += do_test(filename, '--ksvm', known_failure=True)
@@ -520,8 +520,9 @@ void save_load(OjaNewton& ON, io_buf& model_file, bool read, bool text)
msg << ":" << resume << "\n";
bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

double temp=0.;
if (resume)
GD::save_load_online_state(all, model_file, read, text);
GD::save_load_online_state(all, model_file, read, text, temp);
else
GD::save_load_regressor(all, model_file, read, text);
}
@@ -38,6 +38,7 @@ struct ftrl
struct update_data data;
size_t no_win_counter;
size_t early_stop_thres;
uint32_t ftrl_size;
double total_weight;
};

@@ -151,7 +152,7 @@ void inner_update_pistol_state_and_predict(update_data& d, float x, float& wref)

float squared_theta = w[W_ZT] * w[W_ZT];
float tmp = 1.f / (d.ftrl_alpha * w[W_MX] * (w[W_G2] + w[W_MX]));
w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2 * tmp) * tmp;
w[W_XT] = sqrt(w[W_G2]) * d.ftrl_beta * w[W_ZT] * correctedExp(squared_theta / 2.f * tmp) * tmp;

d.predict += w[W_XT] * x;
}
@@ -315,7 +316,7 @@ void save_load(ftrl& b, io_buf& model_file, bool read, bool text)
bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

if (resume)
GD::save_load_online_state(*all, model_file, read, text);
GD::save_load_online_state(*all, model_file, read, text, b.total_weight, nullptr, b.ftrl_size);
else
GD::save_load_regressor(*all, model_file, read, text);
}
@@ -376,7 +377,7 @@ base_learner* ftrl_setup(options_i& options, vw& all)
b->all = &all;
b->no_win_counter = 0;
b->all->normalized_sum_norm_x = 0;
b->total_weight = 0.;
b->total_weight = 0;

void (*learn_ptr)(ftrl&, single_learner&, example&) = nullptr;

@@ -389,18 +390,21 @@ base_learner* ftrl_setup(options_i& options, vw& all)
else
learn_ptr = learn_proximal<false>;
all.weights.stride_shift(2); // NOTE: for more parameter storage
b->ftrl_size = 3;
}
else if (pistol)
{
algorithm_name = "PiSTOL";
learn_ptr = learn_pistol;
all.weights.stride_shift(2); // NOTE: for more parameter storage
b->ftrl_size = 4;
}
else if (coin)
{
algorithm_name = "Coin Betting";
learn_ptr = learn_cb;
all.weights.stride_shift(3); // NOTE: for more parameter storage
b->ftrl_size = 6;
}

b->data.ftrl_alpha = b->ftrl_alpha;
@@ -43,7 +43,7 @@ namespace GD
{
struct gd
{
// double normalized_sum_norm_x;
// double normalized_sum_norm_x;
double total_weight;
size_t no_win_counter;
size_t early_stop_thres;
@@ -684,6 +684,23 @@ void sync_weights(vw& all)
all.sd->contraction = 1.;
}

// Emits one weight index through bin_text_write_fixed and returns that call's result.
//
// model_file : destination buffer, forwarded untouched to bin_text_write_fixed.
// msg        : stream that receives the decimal form of `i` before the write;
//              it is then forwarded to bin_text_write_fixed together with `text`.
// text       : forwarded to bin_text_write_fixed (presumably selects text vs. binary
//              output -- confirm against bin_text_write_fixed).
// num_bits   : when below 31 the index is narrowed to uint32_t and sizeof(uint32_t)
//              bytes are passed on; otherwise the full uint64_t is passed.
// i          : the index to emit.
size_t write_index(io_buf& model_file, stringstream& msg, bool text, uint32_t num_bits, uint64_t i) {
  msg << i;

  // Wide case: hand the 64-bit index over as-is.
  if (num_bits >= 31)
    return bin_text_write_fixed(model_file, reinterpret_cast<char*>(&i), sizeof(i), msg, text);

  // Narrow case: truncate to 32 bits so only sizeof(uint32_t) bytes are passed on.
  uint32_t narrow_index = static_cast<uint32_t>(i);
  return bin_text_write_fixed(model_file, reinterpret_cast<char*>(&narrow_index), sizeof(narrow_index), msg, text);
}

template <class T>
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text, T& weights)
{
@@ -738,16 +755,8 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text, T& w
{
i = v.index() >> weights.stride_shift();
stringstream msg;
msg << i;

if (all.num_bits < 31)
{
old_i = (uint32_t)i;
brw = bin_text_write_fixed(model_file, (char*)&old_i, sizeof(old_i), msg, text);
}
else
brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);

brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
}
@@ -762,7 +771,7 @@ void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text)
}

template <class T>
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, T& weights)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g, stringstream& msg, uint32_t ftrl_size, T& weights)
{
uint64_t length = (uint64_t)1 << all.num_bits;

@@ -786,8 +795,10 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
if (i >= length)
THROW("Model content is corrupted, weight vector index " << i << " must be less than total vector length "
<< length);
weight buff[4] = {0, 0, 0, 0};
if (g == NULL || (!g->adaptive && !g->normalized))
weight buff[8] = {0, 0, 0, 0, 0, 0, 0, 0};
if (ftrl_size>0)
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * ftrl_size, "");
else if (g == NULL || (!g->adaptive && !g->normalized))
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]), "");
else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
brw += model_file.bin_read_fixed((char*)buff, sizeof(buff[0]) * 2, "");
@@ -799,40 +810,60 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
}
} while (brw > 0);
else // write binary or text
for (typename T::iterator v = weights.begin(); v != weights.end(); ++v)
if (*v != 0.)
{
i = v.index() >> weights.stride_shift();
msg << i;
if (all.num_bits < 31)
{
old_i = (uint32_t)i;
brw = bin_text_write_fixed(model_file, (char*)&old_i, sizeof(old_i), msg, text);
}
else
brw = bin_text_write_fixed(model_file, (char*)&i, sizeof(i), msg, text);
for (typename T::iterator v = weights.begin(); v != weights.end(); ++v) {
i = v.index() >> weights.stride_shift();

if (g == nullptr || (!g->adaptive && !g->normalized))
{
if (ftrl_size==3) {
if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0.) {
brw = write_index(model_file, msg, text, all.num_bits, i);
msg <<":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
}
}
else if (ftrl_size==4) {
if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0. || (&(*v))[3]!=0.) {
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 4 * sizeof(*v), msg, text);
}
}
else if (ftrl_size==6) {
if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0. || (&(*v))[3]!=0. || (&(*v))[4]!=0. || (&(*v))[5]!=0.) {
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << " " << (&(*v))[3] << " " << (&(*v))[4] << " " << (&(*v))[5] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 6 * sizeof(*v), msg, text);
}
}
else if (g == nullptr || (!g->adaptive && !g->normalized))
{
if (*v != 0.) {
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), sizeof(*v), msg, text);
}
else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
{
// either adaptive or normalized
}
else if ((g->adaptive && !g->normalized) || (!g->adaptive && g->normalized))
{
// either adaptive or normalized
if (*v != 0. || (&(*v))[1]!=0.) {
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 2 * sizeof(*v), msg, text);
}
else
{
// adaptive and normalized
}
else
{
// adaptive and normalized
if (*v != 0. || (&(*v))[1]!=0. || (&(*v))[2]!=0.) {
brw = write_index(model_file, msg, text, all.num_bits, i);
msg << ":" << *v << " " << (&(*v))[1] << " " << (&(*v))[2] << "\n";
brw += bin_text_write_fixed(model_file, (char*)&(*v), 3 * sizeof(*v), msg, text);
}
}
}
}

void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, gd* g)
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, double& total_weight, gd* g, uint32_t ftrl_size)
{
// vw& all = *g.all;
stringstream msg;
@@ -891,13 +922,8 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
// restore some data to allow --save_resume work more accurate

// fix average loss
double total_weight = 0.; // value holder as g* may be null
if (!read && g != nullptr)
total_weight = g->total_weight;
msg << "gd::total_weight " << total_weight << "\n";
msg << "total_weight " << total_weight << "\n";
bin_text_read_write_fixed(model_file, (char*)&total_weight, sizeof(total_weight), "", read, msg, text);
if (read && g != nullptr)
g->total_weight = total_weight;

// fix "loss since last" for first printed out example details
msg << "sd::oec.weighted_labeled_examples " << all.sd->old_weighted_labeled_examples << "\n";
@@ -931,9 +957,9 @@ void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, g
all.current_pass = 0;
}
if (all.weights.sparse)
save_load_online_state(all, model_file, read, text, g, msg, all.weights.sparse_weights);
save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.sparse_weights);
else
save_load_online_state(all, model_file, read, text, g, msg, all.weights.dense_weights);
save_load_online_state(all, model_file, read, text, g, msg, ftrl_size, all.weights.dense_weights);
}

template <class T>
@@ -987,8 +1013,7 @@ void save_load(gd& g, io_buf& model_file, bool read, bool text)
<< "WARNING: --save_resume functionality is known to have inaccuracy in model files version less than "
<< VERSION_SAVE_RESUME_FIX << endl
<< endl;
// save_load_online_state(g, model_file, read, text);
save_load_online_state(all, model_file, read, text, &g);
save_load_online_state(all, model_file, read, text, g.total_weight, &g);
}
else
save_load_regressor(all, model_file, read, text);
@@ -23,7 +23,8 @@ struct gd;
float finalize_prediction(shared_data* sd, float ret);
void print_audit_features(vw&, example& ec);
void save_load_regressor(vw& all, io_buf& model_file, bool read, bool text);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, GD::gd* g = nullptr);
void save_load_online_state(vw& all, io_buf& model_file, bool read, bool text, double& total_weight, GD::gd* g = nullptr, uint32_t ftrl_size = 0);


template <class T>
struct multipredict_info
@@ -154,8 +154,9 @@ void save_load(svrg& s, io_buf& model_file, bool read, bool text)
msg << ":" << resume << "\n";
bin_text_read_write_fixed(model_file, (char*)&resume, sizeof(resume), "", read, msg, text);

double temp=0.;
if (resume)
GD::save_load_online_state(*s.all, model_file, read, text);
GD::save_load_online_state(*s.all, model_file, read, text, temp);
else
GD::save_load_regressor(*s.all, model_file, read, text);
}

0 comments on commit cb3019b

Please sign in to comment.
You can’t perform that action at this time.