Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: enable 0-indexed labels for csoaa #3533

Merged
merged 9 commits into from
Dec 16, 2021
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
37 changes: 36 additions & 1 deletion test/core.vwtest.json
Original file line number Diff line number Diff line change
Expand Up @@ -4440,6 +4440,41 @@
"depends_on": [
373
]
},
{
"id": 375,
"desc": "Use csoaa with 1-indexing flag",
"vw_command": "-d train-sets/csoaa_1_ind.dat --csoaa 4 -p csoaa_1_ind.predict --indexing 1",
"diff_files": {
"stderr": "train-sets/ref/csoaa_1_ind.stderr",
"csoaa_1_ind.predict": "pred-sets/ref/csoaa_1_ind.predict"
},
"input_files": [
"train-sets/csoaa_1_ind.dat"
]
},
{
"id": 376,
"desc": "Use csoaa with auto-detect 1-indexing",
"vw_command": "-d train-sets/csoaa_1_ind.dat --csoaa 4 -p csoaa_1_ind_auto.predict",
"diff_files": {
"stderr": "train-sets/ref/csoaa_1_ind_auto.stderr",
"csoaa_1_ind_auto.predict": "pred-sets/ref/csoaa_1_ind.predict"
},
"input_files": [
"train-sets/csoaa_1_ind.dat"
]
},
{
"id": 377,
"desc": "Use csoaa with on 1-indexed data set with 0-indexing flag",
"vw_command": "-d train-sets/csoaa_1_ind.dat --csoaa 4 --indexing 0",
"diff_files": {
"stderr": "train-sets/ref/csoaa_1_ind_0_flag.stderr",
"stdout": "train-sets/ref/csoaa_1_ind_0_flag.stdout"
},
"input_files": [
"train-sets/csoaa_1_ind.dat"
]
}

]
5 changes: 5 additions & 0 deletions test/pred-sets/ref/csoaa_1_ind.predict
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1
1
1
3
4
5 changes: 5 additions & 0 deletions test/train-sets/csoaa_1_ind.dat
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
1:0 2:1 3:1 4:1 | a b c
1:1 2:1 3:0 4:1 | b c d
1:0 2:1 3:1 4:1 | a c e
1:1 2:1 3:1 4:0 | b d f
1:1 2:0 3:1 4:1 | d e f
23 changes: 23 additions & 0 deletions test/train-sets/ref/csoaa_1_ind.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
predictions = csoaa_1_ind.predict
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = train-sets/csoaa_1_ind.dat
num sources = 1
Enabled reductions: gd, scorer-identity, csoaa
Input label = cs
Output pred = multiclass
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 known 1 4
0.500000 1.000000 2 2.0 known 1 4
0.500000 0.500000 4 4.0 known 3 4

finished run
number of examples = 5
weighted example sum = 5.000000
weighted label sum = 0.000000
average loss = 0.600000
total feature number = 20
22 changes: 22 additions & 0 deletions test/train-sets/ref/csoaa_1_ind_0_flag.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = train-sets/csoaa_1_ind.dat
num sources = 1
Enabled reductions: gd, scorer-identity, csoaa
Input label = cs
Output pred = multiclass
average since example example current current current
loss last counter weight label predict features
1.000000 1.000000 1 1.0 known 0 4
1.000000 1.000000 2 2.0 known 1 4
0.750000 0.500000 4 4.0 known 3 4

finished run
number of examples = 5
weighted example sum = 5.000000
weighted label sum = 0.000000
average loss = 0.800000
total feature number = 20
5 changes: 5 additions & 0 deletions test/train-sets/ref/csoaa_1_ind_0_flag.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
[warning] label 4 is not in {0,3}. This won't work for 0-indexed actions.
23 changes: 23 additions & 0 deletions test/train-sets/ref/csoaa_1_ind_auto.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
predictions = csoaa_1_ind_auto.predict
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
using no cache
Reading datafile = train-sets/csoaa_1_ind.dat
num sources = 1
Enabled reductions: gd, scorer-identity, csoaa
Input label = cs
Output pred = multiclass
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 known 1 4
0.500000 1.000000 2 2.0 known 1 4
0.500000 0.500000 4 4.0 known 3 4

finished run
number of examples = 5
weighted example sum = 5.000000
weighted label sum = 0.000000
average loss = 0.600000
total feature number = 20
1 change: 1 addition & 0 deletions test/train-sets/ref/help.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -467,6 +467,7 @@ Cost Sensitive Active Learning Options:
--csa_debug Print debug stuff for cs_active
Cost Sensitive One Against All Options:
--csoaa arg One-against-all multiclass with <k> costs
--indexing arg Choose between 0 or 1-indexing. Choices: {0, 1}
Cost Sensitive One Against All with Label Dependent Features Options:
--csoaa_ldf arg Use one-against-all multiclass learning with label
dependent features
Expand Down
78 changes: 66 additions & 12 deletions vowpalwabbit/csoaa.cc
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,34 @@ namespace CSOAA
struct csoaa
{
uint32_t num_classes = 0;
int indexing = -1;
bool search = false;
polyprediction* pred = nullptr;
~csoaa() { free(pred); }
};

template <bool is_learn>
inline void inner_loop(single_learner& base, example& ec, uint32_t i, float cost, uint32_t& prediction, float& score,
float& partial_prediction)
float& partial_prediction, int& indexing)
bassmang marked this conversation as resolved.
Show resolved Hide resolved
{
if (is_learn)
{
ec.weight = (cost == FLT_MAX) ? 0.f : 1.f;
ec.l.simple.label = cost;
base.learn(ec, i - 1);
if (indexing == 0) { base.learn(ec, i); }
else
{
base.learn(ec, i - 1);
}
}
else
base.predict(ec, i - 1);
{
if (indexing == 0) { base.predict(ec, i); }
else
{
base.predict(ec, i - 1);
}
}

partial_prediction = ec.partial_prediction;
if (ec.partial_prediction < score || (ec.partial_prediction == score && i < prediction))
Expand All @@ -65,12 +77,39 @@ inline void inner_loop(single_learner& base, example& ec, uint32_t i, float cost
template <bool is_learn>
void predict_or_learn(csoaa& c, single_learner& base, example& ec)
{
if (!c.search)
{
for (auto& cost : ec.l.cs.costs)
{
auto& lbl = cost.class_index;
// Update indexing
if (c.indexing == -1 && lbl == 0) { c.indexing = 0; }
else if (c.indexing == -1 && lbl == c.num_classes)
{
c.indexing = 1;
}

// Label validation
if (c.indexing == 0 && lbl >= c.num_classes)
{
logger::log_warn(
"label {0} is not in {{0,{1}}}. This won't work for 0-indexed actions.", lbl, c.num_classes - 1);
lbl = 0;
}
else if (c.indexing == 1 && (lbl < 1 || lbl > c.num_classes))
{
logger::log_warn("label {0} is not in {{1,{1}}}. This won't work for 1-indexed actions.", lbl, c.num_classes);
lbl = static_cast<uint32_t>(c.num_classes);
}
}
}

COST_SENSITIVE::label ld = std::move(ec.l.cs);

// Guard example state restore against throws
auto restore_guard = VW::scope_exit([&ld, &ec] { ec.l.cs = std::move(ld); });

uint32_t prediction = 1;
uint32_t prediction = (c.indexing == 0) ? 0 : 1;
float score = FLT_MAX;
size_t pt_start = ec.passthrough ? ec.passthrough->size() : 0;
ec.l.simple = {0.};
Expand All @@ -81,7 +120,7 @@ void predict_or_learn(csoaa& c, single_learner& base, example& ec)
if (!ld.costs.empty())
{
for (auto& cl : ld.costs)
inner_loop<is_learn>(base, ec, cl.class_index, cl.x, prediction, score, cl.partial_prediction);
inner_loop<is_learn>(base, ec, cl.class_index, cl.x, prediction, score, cl.partial_prediction, c.indexing);
ec.partial_prediction = score;
}
else if (dont_learn)
Expand All @@ -90,17 +129,30 @@ void predict_or_learn(csoaa& c, single_learner& base, example& ec)
ec._reduction_features.template get<simple_label_reduction_features>().reset_to_default();

base.multipredict(ec, 0, c.num_classes, c.pred, false);
for (uint32_t i = 1; i <= c.num_classes; i++)
if (c.indexing == 0)
{
for (uint32_t i = 0; i <= c.num_classes; i++)
{
add_passthrough_feature(ec, i, c.pred[i].scalar);
if (c.pred[i].scalar < c.pred[prediction].scalar) prediction = i;
}
ec.partial_prediction = c.pred[prediction].scalar;
}
else
{
add_passthrough_feature(ec, i, c.pred[i - 1].scalar);
if (c.pred[i - 1].scalar < c.pred[prediction - 1].scalar) prediction = i;
for (uint32_t i = 1; i <= c.num_classes; i++)
{
add_passthrough_feature(ec, i, c.pred[i - 1].scalar);
if (c.pred[i - 1].scalar < c.pred[prediction - 1].scalar) prediction = i;
}
ec.partial_prediction = c.pred[prediction - 1].scalar;
}
ec.partial_prediction = c.pred[prediction - 1].scalar;
}
else
{
float temp;
for (uint32_t i = 1; i <= c.num_classes; i++) inner_loop<false>(base, ec, i, FLT_MAX, prediction, score, temp);
for (uint32_t i = 1; i <= c.num_classes; i++)
inner_loop<false>(base, ec, i, FLT_MAX, prediction, score, temp, c.indexing);
}

if (ec.passthrough)
Expand Down Expand Up @@ -137,13 +189,15 @@ base_learner* csoaa_setup(VW::setup_base_i& stack_builder)
VW::workspace& all = *stack_builder.get_all_pointer();
auto c = VW::make_unique<csoaa>();
option_group_definition new_options("Cost Sensitive One Against All");
new_options.add(
make_option("csoaa", c->num_classes).keep().necessary().help("One-against-all multiclass with <k> costs"));
new_options
.add(make_option("csoaa", c->num_classes).keep().necessary().help("One-against-all multiclass with <k> costs"))
.add(make_option("indexing", c->indexing).one_of({0, 1}).keep().help("Choose between 0 or 1-indexing"));
bassmang marked this conversation as resolved.
Show resolved Hide resolved

if (!options.add_parse_and_check_necessary(new_options)) return nullptr;

if (options.was_supplied("probabilities"))
{ THROW("Error: csoaa does not support probabilities flag, please use oaa or multilabel_oaa"); }
c->search = options.was_supplied("search");

c->pred = calloc_or_throw<polyprediction>(c->num_classes);
size_t ws = c->num_classes;
Expand Down