Skip to content

Commit

Permalink
feat: enable 0-indexed labels for csoaa (#3533)
Browse files Browse the repository at this point in the history
* feat: enable 0-indexed labels for csoaa

* clang

* address comments

* info output

* 177
  • Loading branch information
bassmang committed Dec 16, 2021
1 parent 47cf09a commit 471c4f0
Show file tree
Hide file tree
Showing 13 changed files with 206 additions and 15 deletions.
37 changes: 36 additions & 1 deletion test/core.vwtest.json
Original file line number Diff line number Diff line change
Expand Up @@ -4440,6 +4440,41 @@
"depends_on": [
373
]
},
{
"id": 375,
"desc": "Use csoaa with 1-indexing flag",
"vw_command": "-d train-sets/csoaa_1_ind.dat --csoaa 4 -p csoaa_1_ind.predict --indexing 1",
"diff_files": {
"stderr": "train-sets/ref/csoaa_1_ind.stderr",
"csoaa_1_ind.predict": "pred-sets/ref/csoaa_1_ind.predict"
},
"input_files": [
"train-sets/csoaa_1_ind.dat"
]
},
{
"id": 376,
"desc": "Use csoaa with auto-detect 1-indexing",
"vw_command": "-d train-sets/csoaa_1_ind.dat --csoaa 4 -p csoaa_1_ind_auto.predict",
"diff_files": {
"stderr": "train-sets/ref/csoaa_1_ind_auto.stderr",
"csoaa_1_ind_auto.predict": "pred-sets/ref/csoaa_1_ind.predict"
},
"input_files": [
"train-sets/csoaa_1_ind.dat"
]
},
{
"id": 377,
"desc": "Use csoaa with on 1-indexed data set with 0-indexing flag",
"vw_command": "-d train-sets/csoaa_1_ind.dat --csoaa 4 --indexing 0",
"diff_files": {
"stderr": "train-sets/ref/csoaa_1_ind_0_flag.stderr",
"stdout": "train-sets/ref/csoaa_1_ind_0_flag.stdout"
},
"input_files": [
"train-sets/csoaa_1_ind.dat"
]
}

]
5 changes: 5 additions & 0 deletions test/pred-sets/ref/csoaa_1_ind.predict
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1
1
1
3
4
1 change: 1 addition & 0 deletions test/test-sets/ref/backwards.stdout
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
[info] label 3 found -- labels are now considered 1-indexed.
Id
Min label:-1
Max label:1
Expand Down
5 changes: 5 additions & 0 deletions test/train-sets/csoaa_1_ind.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1:0 2:1 3:1 4:1 | a b c
1:1 2:1 3:0 4:1 | b c d
1:0 2:1 3:1 4:1 | a c e
1:1 2:1 3:1 4:0 | b d f
1:1 2:0 3:1 4:1 | d e f
1 change: 1 addition & 0 deletions test/train-sets/ref/audit2.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
f^a:57421:1:0@0 f^b:62865:1:0@0 f^c:228993:1:0@0 e^x:125153:1:0@0 e^y:27665:1:0@0 e^z:259285:1:0@0 Constant:202097:1:0@0 e^x*f^a:96493:1:0@0 e^x*f^b:93489:1:0@0 e^x*f^c:189985:1:0@0 e^y*f^a:195965:1:0@0 e^y*f^b:190625:1:0@0 e^y*f^c:91057:1:0@0 e^z*f^a:166385:1:0@0 e^z*f^b:171053:1:0@0 e^z*f^c:71485:1:0@0
0
f^a:57422:1:0@0 f^b:62866:1:0@0 f^c:228994:1:0@0 e^x:125154:1:0@0 e^y:27666:1:0@0 e^z:259286:1:0@0 Constant:202098:1:0@0 e^x*f^a:96494:1:0@0 e^x*f^b:93490:1:0@0 e^x*f^c:189986:1:0@0 e^y*f^a:195966:1:0@0 e^y*f^b:190626:1:0@0 e^y*f^c:91058:1:0@0 e^z*f^a:166386:1:0@0 e^z*f^b:171054:1:0@0 e^z*f^c:71486:1:0@0
[info] label 3 found -- labels are now considered 1-indexed.
-0.324225
f^a:57420:1:-0.0540374@4 f^c:228992:1:-0.0540374@4 e^x:125152:1:-0.0540374@4 Constant:202096:1:-0.0540374@4 e^x*f^a:96492:1:-0.0540374@4 e^x*f^c:189984:1:-0.0540374@4 f^y:154732:1:0@0 e^x*f^y:246988:1:0@0
-0.324225
Expand Down
23 changes: 23 additions & 0 deletions test/train-sets/ref/csoaa_1_ind.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
predictions = csoaa_1_ind.predict
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = train-sets/csoaa_1_ind.dat
num sources = 1
Enabled reductions: gd, scorer-identity, csoaa
Input label = cs
Output pred = multiclass
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 known 1 4
0.500000 1.000000 2 2.0 known 1 4
0.500000 0.500000 4 4.0 known 3 4

finished run
number of examples = 5
weighted example sum = 5.000000
weighted label sum = 0.000000
average loss = 0.600000
total feature number = 20
22 changes: 22 additions & 0 deletions test/train-sets/ref/csoaa_1_ind_0_flag.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = train-sets/csoaa_1_ind.dat
num sources = 1
Enabled reductions: gd, scorer-identity, csoaa
Input label = cs
Output pred = multiclass
average since example example current current current
loss last counter weight label predict features
1.000000 1.000000 1 1.0 known 0 4
1.000000 1.000000 2 2.0 known 1 4
0.750000 0.500000 4 4.0 known 3 4

finished run
number of examples = 5
weighted example sum = 5.000000
weighted label sum = 0.000000
average loss = 0.800000
total feature number = 20
5 changes: 5 additions & 0 deletions test/train-sets/ref/csoaa_1_ind_0_flag.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
23 changes: 23 additions & 0 deletions test/train-sets/ref/csoaa_1_ind_auto.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
predictions = csoaa_1_ind_auto.predict
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = train-sets/csoaa_1_ind.dat
num sources = 1
Enabled reductions: gd, scorer-identity, csoaa
Input label = cs
Output pred = multiclass
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 known 1 4
0.500000 1.000000 2 2.0 known 1 4
0.500000 0.500000 4 4.0 known 3 4

finished run
number of examples = 5
weighted example sum = 5.000000
weighted label sum = 0.000000
average loss = 0.600000
total feature number = 20
1 change: 1 addition & 0 deletions test/train-sets/ref/help.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ Cost Sensitive Active Learning Options:
--csa_debug Print debug stuff for cs_active
Cost Sensitive One Against All Options:
--csoaa arg One-against-all multiclass with <k> costs
--indexing arg Choose between 0 or 1-indexing. Choices: {0, 1}
Cost Sensitive One Against All with Label Dependent Features Options:
--csoaa_ldf arg Use one-against-all multiclass learning with label
dependent features
Expand Down
1 change: 1 addition & 0 deletions test/train-sets/ref/oaa_mixed_probabilities.stdout
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
[info] label 0 found -- labels are now considered 0-indexed.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
83 changes: 71 additions & 12 deletions vowpalwabbit/csoaa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,22 +21,34 @@ namespace CSOAA
struct csoaa
{
uint32_t num_classes = 0;
int indexing = -1;
bool search = false;
polyprediction* pred = nullptr;
~csoaa() { free(pred); }
};

template <bool is_learn>
inline void inner_loop(single_learner& base, example& ec, uint32_t i, float cost, uint32_t& prediction, float& score,
float& partial_prediction)
float& partial_prediction, int indexing)
{
if (is_learn)
{
ec.weight = (cost == FLT_MAX) ? 0.f : 1.f;
ec.l.simple.label = cost;
base.learn(ec, i - 1);
if (indexing == 0) { base.learn(ec, i); }
else
{
base.learn(ec, i - 1);
}
}
else
base.predict(ec, i - 1);
{
if (indexing == 0) { base.predict(ec, i); }
else
{
base.predict(ec, i - 1);
}
}

partial_prediction = ec.partial_prediction;
if (ec.partial_prediction < score || (ec.partial_prediction == score && i < prediction))
Expand All @@ -52,12 +64,44 @@ inline void inner_loop(single_learner& base, example& ec, uint32_t i, float cost
template <bool is_learn>
void predict_or_learn(csoaa& c, single_learner& base, example& ec)
{
if (!c.search)
{
for (auto& cost : ec.l.cs.costs)
{
auto& lbl = cost.class_index;
// Update indexing
if (c.indexing == -1 && lbl == 0)
{
logger::log_info("label 0 found -- labels are now considered 0-indexed.");
c.indexing = 0;
}
else if (c.indexing == -1 && lbl == c.num_classes)
{
logger::log_info("label {0} found -- labels are now considered 1-indexed.", c.num_classes);
c.indexing = 1;
}

// Label validation
if (c.indexing == 0 && lbl >= c.num_classes)
{
logger::log_warn(
"label {0} is not in {{0,{1}}}. This won't work for 0-indexed actions.", lbl, c.num_classes - 1);
lbl = 0;
}
else if (c.indexing == 1 && (lbl < 1 || lbl > c.num_classes))
{
logger::log_warn("label {0} is not in {{1,{1}}}. This won't work for 1-indexed actions.", lbl, c.num_classes);
lbl = static_cast<uint32_t>(c.num_classes);
}
}
}

COST_SENSITIVE::label ld = std::move(ec.l.cs);

// Guard example state restore against throws
auto restore_guard = VW::scope_exit([&ld, &ec] { ec.l.cs = std::move(ld); });

uint32_t prediction = 1;
uint32_t prediction = (c.indexing == 0) ? 0 : 1;
float score = FLT_MAX;
size_t pt_start = ec.passthrough ? ec.passthrough->size() : 0;
ec.l.simple = {0.};
Expand All @@ -68,7 +112,7 @@ void predict_or_learn(csoaa& c, single_learner& base, example& ec)
if (!ld.costs.empty())
{
for (auto& cl : ld.costs)
inner_loop<is_learn>(base, ec, cl.class_index, cl.x, prediction, score, cl.partial_prediction);
inner_loop<is_learn>(base, ec, cl.class_index, cl.x, prediction, score, cl.partial_prediction, c.indexing);
ec.partial_prediction = score;
}
else if (dont_learn)
Expand All @@ -77,17 +121,30 @@ void predict_or_learn(csoaa& c, single_learner& base, example& ec)
ec._reduction_features.template get<simple_label_reduction_features>().reset_to_default();

base.multipredict(ec, 0, c.num_classes, c.pred, false);
for (uint32_t i = 1; i <= c.num_classes; i++)
if (c.indexing == 0)
{
add_passthrough_feature(ec, i, c.pred[i - 1].scalar);
if (c.pred[i - 1].scalar < c.pred[prediction - 1].scalar) prediction = i;
for (uint32_t i = 0; i <= c.num_classes; i++)
{
add_passthrough_feature(ec, i, c.pred[i].scalar);
if (c.pred[i].scalar < c.pred[prediction].scalar) prediction = i;
}
ec.partial_prediction = c.pred[prediction].scalar;
}
else
{
for (uint32_t i = 1; i <= c.num_classes; i++)
{
add_passthrough_feature(ec, i, c.pred[i - 1].scalar);
if (c.pred[i - 1].scalar < c.pred[prediction - 1].scalar) prediction = i;
}
ec.partial_prediction = c.pred[prediction - 1].scalar;
}
ec.partial_prediction = c.pred[prediction - 1].scalar;
}
else
{
float temp;
for (uint32_t i = 1; i <= c.num_classes; i++) inner_loop<false>(base, ec, i, FLT_MAX, prediction, score, temp);
for (uint32_t i = 1; i <= c.num_classes; i++)
inner_loop<false>(base, ec, i, FLT_MAX, prediction, score, temp, c.indexing);
}

if (ec.passthrough)
Expand Down Expand Up @@ -124,13 +181,15 @@ base_learner* csoaa_setup(VW::setup_base_i& stack_builder)
VW::workspace& all = *stack_builder.get_all_pointer();
auto c = VW::make_unique<csoaa>();
option_group_definition new_options("Cost Sensitive One Against All");
new_options.add(
make_option("csoaa", c->num_classes).keep().necessary().help("One-against-all multiclass with <k> costs"));
new_options
.add(make_option("csoaa", c->num_classes).keep().necessary().help("One-against-all multiclass with <k> costs"))
.add(make_option("indexing", c->indexing).one_of({0, 1}).keep().help("Choose between 0 or 1-indexing"));

if (!options.add_parse_and_check_necessary(new_options)) return nullptr;

if (options.was_supplied("probabilities"))
{ THROW("Error: csoaa does not support probabilities flag, please use oaa or multilabel_oaa"); }
c->search = options.was_supplied("search");

c->pred = calloc_or_throw<polyprediction>(c->num_classes);
size_t ws = c->num_classes;
Expand Down
14 changes: 12 additions & 2 deletions vowpalwabbit/oaa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,14 @@ struct oaa
void learn_randomized(oaa& o, VW::LEARNER::single_learner& base, example& ec)
{
// Update indexing
if (o.indexing == -1 && ec.l.multi.label == 0) { o.indexing = 0; }
if (o.indexing == -1 && ec.l.multi.label == 0)
{
logger::log_info("label 0 found -- labels are now considered 0-indexed.");
o.indexing = 0;
}
else if (o.indexing == -1 && ec.l.multi.label == o.k)
{
logger::log_info("label {0} found -- labels are now considered 1-indexed.", o.k);
o.indexing = 1;
}

Expand Down Expand Up @@ -98,9 +103,14 @@ template <bool print_all, bool scores, bool probabilities>
void learn(oaa& o, VW::LEARNER::single_learner& base, example& ec)
{
// Update indexing
if (o.indexing == -1 && ec.l.multi.label == 0) { o.indexing = 0; }
if (o.indexing == -1 && ec.l.multi.label == 0)
{
logger::log_info("label 0 found -- labels are now considered 0-indexed.");
o.indexing = 0;
}
else if (o.indexing == -1 && ec.l.multi.label == o.k)
{
logger::log_info("label {0} found -- labels are now considered 1-indexed.", o.k);
o.indexing = 1;
}

Expand Down

0 comments on commit 471c4f0

Please sign in to comment.