diff --git a/cs/unittest/TestArguments.cs b/cs/unittest/TestArguments.cs index 1db92275011..e4a8785ac5b 100644 --- a/cs/unittest/TestArguments.cs +++ b/cs/unittest/TestArguments.cs @@ -1,128 +1,128 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; -using System; -using System.Collections.Generic; -using System.IO; -using System.Linq; -using System.Text; -using System.Threading; -using System.Threading.Tasks; -using VW; - -namespace cs_unittest -{ - [TestClass] - public class TestArgumentsClass - { - [TestMethod] - [TestCategory("Vowpal Wabbit")] - public void TestArguments() - { - using (var vw = new VowpalWabbit(new VowpalWabbitSettings("--cb_explore_adf --epsilon 0.3 --interact ud") { Verbose = true })) - { - // --cb_explore_adf --epsilon 0.3 --interact ud --cb_adf--csoaa_ldf multiline --csoaa_rank - Console.WriteLine(vw.Arguments.CommandLine); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--cb_explore_adf")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--epsilon 0.3")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--interact ud")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_ldf multiline")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_rank")); - vw.SaveModel("args.model"); - } - - using (var vw = new VowpalWabbit(new VowpalWabbitSettings { ModelStream = File.Open("args.model", FileMode.Open) })) - { - Console.WriteLine(vw.Arguments.CommandLine); - // --no_stdin--bit_precision 18--cb_explore_adf--epsilon 0.300000--cb_adf--cb_type ips --csoaa_ldf multiline--csoaa_rank--interact ud - - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--no_stdin")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--bit_precision 18")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--cb_explore_adf")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--epsilon 0.3")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--interact ud")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_ldf multiline")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_rank")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("--cb_type ips")); - } - } - - [TestMethod] - [TestCategory("Vowpal Wabbit")] - public void TestQuietAndTestArguments() - { - using (var vw = new VowpalWabbit("--quiet -t")) - { - vw.SaveModel("args.model"); - } - - using (var vw = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = File.Open("args.model", FileMode.Open) })) - { - Assert.IsFalse(vw.Arguments.CommandLine.Contains("--quiet")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("-t")); - - using (var vwSub = new VowpalWabbit(new VowpalWabbitSettings { Model = vw })) - { - Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("--quiet")); - Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("-t")); - } - } - - using (var vw = new VowpalWabbit("")) - { - vw.SaveModel("args.model"); - } - - using (var vw = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = File.Open("args.model", FileMode.Open) })) - { - Assert.IsFalse(vw.Arguments.CommandLine.Contains("--quiet")); - Assert.IsTrue(vw.Arguments.CommandLine.Contains("-t")); - - using (var vwSub = new VowpalWabbit(new VowpalWabbitSettings { Model = vw })) - { - Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("--quiet")); - Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("-t")); - } - } - } - - [TestMethod] - [TestCategory("Vowpal Wabbit")] - public void TestArgumentDeDup() - { - using (var vw = new VowpalWabbit("-l 0.3 -l 0.3 --learning_rate 0.3 -f model1 --save_resume -q ab")) - { - Assert.AreEqual(0.3f, vw.Native.Arguments.LearningRate); - } - - try - { - using (var vw = - new VowpalWabbit( - "--cb 2 --cb_type ips --cb_type dm --learning_rate 0.1 -f model_bad --save_resume -q ab")) - { - Assert.AreEqual(0.1f, vw.Native.Arguments.LearningRate); - } - - Assert.Fail("Disagreering arguments not detected"); - } - catch (VowpalWabbitException) - { } - - using (var vw = new VowpalWabbit("-i model1 --save_resume")) - { - Assert.AreEqual(0.5f, vw.Native.Arguments.LearningRate); - } - - using (var vw = new VowpalWabbit("-i model1 --save_resume -q ab -l 0.4")) - { - Assert.AreEqual(0.4f, vw.Native.Arguments.LearningRate); - } - - // make sure different representations of arguments are matched - using (var vw = new VowpalWabbit("--cb_explore_adf --epsilon 0.1 -f model2")) - { } - - using (var vw = new VowpalWabbit("--cb_explore_adf --epsilon 0.1000 -i model2")) - { } - } - } -} +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Threading; +using System.Threading.Tasks; +using VW; + +namespace cs_unittest +{ + [TestClass] + public class TestArgumentsClass + { + [TestMethod] + [TestCategory("Vowpal Wabbit")] + public void TestArguments() + { + using (var vw = new VowpalWabbit(new VowpalWabbitSettings("--cb_explore_adf --epsilon 0.3 --interact ud") { Verbose = true })) + { + // --cb_explore_adf --epsilon 0.3 --interact ud --cb_adf--csoaa_ldf multiline --csoaa_rank + Console.WriteLine(vw.Arguments.CommandLine); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--cb_explore_adf")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--epsilon 0.3")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--interact ud")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_ldf multiline")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_rank")); + vw.SaveModel("args.model"); + } + + using (var vw = new VowpalWabbit(new VowpalWabbitSettings { ModelStream = File.Open("args.model", FileMode.Open) })) + { + Console.WriteLine(vw.Arguments.CommandLine); + // --no_stdin--bit_precision 18--cb_explore_adf--epsilon 0.300000--cb_adf--cb_type ips --csoaa_ldf multiline--csoaa_rank--interact ud + + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--no_stdin")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--bit_precision 18")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--cb_explore_adf")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--epsilon 0.3")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--interact ud")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_ldf multiline")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--csoaa_rank")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("--cb_type mtr")); + } + } + + [TestMethod] + [TestCategory("Vowpal Wabbit")] + public void TestQuietAndTestArguments() + { + using (var vw = new VowpalWabbit("--quiet -t")) + { + vw.SaveModel("args.model"); + } + + using (var vw = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = File.Open("args.model", FileMode.Open) })) + { + Assert.IsFalse(vw.Arguments.CommandLine.Contains("--quiet")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("-t")); + + using (var vwSub = new VowpalWabbit(new VowpalWabbitSettings { Model = vw })) + { + Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("--quiet")); + Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("-t")); + } + } + + using (var vw = new VowpalWabbit("")) + { + vw.SaveModel("args.model"); + } + + using (var vw = new VowpalWabbitModel(new VowpalWabbitSettings { ModelStream = File.Open("args.model", FileMode.Open) })) + { + Assert.IsFalse(vw.Arguments.CommandLine.Contains("--quiet")); + Assert.IsTrue(vw.Arguments.CommandLine.Contains("-t")); + + using (var vwSub = new VowpalWabbit(new VowpalWabbitSettings { Model = vw })) + { + Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("--quiet")); + Assert.IsTrue(vwSub.Arguments.CommandLine.Contains("-t")); + } + } + } + + [TestMethod] + [TestCategory("Vowpal Wabbit")] + public void TestArgumentDeDup() + { + using (var vw = new VowpalWabbit("-l 0.3 -l 0.3 --learning_rate 0.3 -f model1 --save_resume -q ab")) + { + Assert.AreEqual(0.3f, vw.Native.Arguments.LearningRate); + } + + try + { + using (var vw = + new VowpalWabbit( + "--cb 2 --cb_type ips --cb_type dm --learning_rate 0.1 -f model_bad --save_resume -q ab")) + { + Assert.AreEqual(0.1f, vw.Native.Arguments.LearningRate); + } + + Assert.Fail("Disagreering arguments not detected"); + } + catch (VowpalWabbitException) + { } + + using (var vw = new VowpalWabbit("-i model1 --save_resume")) + { + Assert.AreEqual(0.5f, vw.Native.Arguments.LearningRate); + } + + using (var vw = new VowpalWabbit("-i model1 --save_resume -q ab -l 0.4")) + { + Assert.AreEqual(0.4f, vw.Native.Arguments.LearningRate); + } + + // make sure different representations of arguments are matched + using (var vw = new VowpalWabbit("--cb_explore_adf --epsilon 0.1 -f model2")) + { } + + using (var vw = new VowpalWabbit("--cb_explore_adf --epsilon 0.1000 -i model2")) + { } + } + } +} diff --git a/java/src/test/java/vowpalWabbit/learner/VWActionProbsLearnerTest.java b/java/src/test/java/vowpalWabbit/learner/VWActionProbsLearnerTest.java index a21be633cb9..06872d6a40e 100644 --- a/java/src/test/java/vowpalWabbit/learner/VWActionProbsLearnerTest.java +++ b/java/src/test/java/vowpalWabbit/learner/VWActionProbsLearnerTest.java @@ -112,8 +112,8 @@ public void testCBADFExplore() throws IOException { actionProb(1, 0.025f) ), actionProbs( - actionProb(1, 0.97499996f), - actionProb(0, 0.025f) + actionProb(0, 0.97499996f), + actionProb(1, 0.025f) ) }; vw.close(); diff --git a/java/src/test/java/vowpalWabbit/learner/VWActionScoresLearnerTest.java b/java/src/test/java/vowpalWabbit/learner/VWActionScoresLearnerTest.java index 37d86834879..8f1e27de934 100644 --- a/java/src/test/java/vowpalWabbit/learner/VWActionScoresLearnerTest.java +++ b/java/src/test/java/vowpalWabbit/learner/VWActionScoresLearnerTest.java @@ -97,16 +97,16 @@ private void testCBADF(boolean withRank) throws IOException { actionScore(1, 0) ), actionScores( - actionScore(0, 0.14991696f), - actionScore(1, 0.14991696f) + actionScore(0, 0.11246802f), + actionScore(1, 0.11246802f) ), actionScores( - actionScore(0, 0.27180168f), - actionScore(1, 0.31980497f) + actionScore(0, 0.3682006f), + actionScore(1, 0.5136312f) ), actionScores( - actionScore(1, 0.35295868f), - actionScore(0, 0.3869971f) + actionScore(0, 0.58848584f), + actionScore(1, 0.6244352f) ) }; vw.close(); @@ -117,8 +117,8 @@ private void testCBADF(boolean withRank) throws IOException { ActionScores[] expectedTestPreds = new ActionScores[]{ actionScores( - actionScore(0, 0.33543912f), - actionScore(1, 0.37897447f) + actionScore(0, 0.39904374f), + actionScore(1, 0.49083984f) ) }; diff --git a/test/RunTests b/test/RunTests index 25f5e5840cc..d4855d771c8 100755 --- a/test/RunTests +++ b/test/RunTests @@ -1226,7 +1226,7 @@ echo "" | {VW} # Test 84: check cb_adf {VW} --cb_adf -d train-sets/cb_test.ldf --noconstant - train-sets/ref/cb_adf.stderr + train-sets/ref/cb_adf_mtr.stderr # Test 85: check multilabel_oaa {VW} --multilabel_oaa 10 -d train-sets/multilabel -p multilabel.predict @@ -1577,8 +1577,9 @@ echo "1 | feature:1" | {VW} -a --initial_weight 0.1 --initial_t 0.3 train-sets/ref/157.stderr # Test 158: test decision service json parsing -{VW} -d train-sets/decisionservice.json --dsjson --cb_explore_adf --epsilon 0.2 --quadratic GT - train-sets/ref/decisionservice.stderr +{VW} -d train-sets/decisionservice.json --dsjson --cb_explore_adf --epsilon 0.2 --quadratic GT -P 1 -p cbe_adf_dsjson.predict + train-sets/ref/cbe_adf_dsjson.stderr + pred-sets/ref/cbe_adf_dsjson.predict # Test 159: test --bootstrap & --binary interaction {VW} -d train-sets/rcv1_mini.dat --bootstrap 5 --binary -c -k --passes 2 diff --git a/test/pred-sets/ref/cb_adf_rank.predict b/test/pred-sets/ref/cb_adf_rank.predict index b179679a319..98944b2f2ea 100644 --- a/test/pred-sets/ref/cb_adf_rank.predict +++ b/test/pred-sets/ref/cb_adf_rank.predict @@ -1,6 +1,6 @@ 0:0,1:0,2:0 -1:-0.182022,0:0.342521 +1:0,0:0.239436 -1:-0.256651,0:0.461775 +1:0,0:0.359153 diff --git a/test/pred-sets/ref/cbe_adf_dsjson.predict b/test/pred-sets/ref/cbe_adf_dsjson.predict new file mode 100644 index 00000000000..614cceb083c --- /dev/null +++ b/test/pred-sets/ref/cbe_adf_dsjson.predict @@ -0,0 +1,6 @@ +0:0.0833333,1:0.0833333,2:0.0833333,3:0.0833333,4:0.0833333,5:0.0833333,6:0.0833333,7:0.0833333,8:0.0833333,9:0.0833333,10:0.0833333,11:0.0833333 + +6:0.816667,5:0.0166667,9:0.0166667,2:0.0166667,10:0.0166667,1:0.0166667,3:0.0166667,7:0.0166667,4:0.0166667,0:0.0166667,8:0.0166667,11:0.0166667 + +6:0.816667,5:0.0166667,9:0.0166667,2:0.0166667,1:0.0166667,3:0.0166667,10:0.0166667,4:0.0166667,0:0.0166667,7:0.0166667,8:0.0166667,11:0.0166667 + diff --git a/test/pred-sets/ref/cbe_adf_softmax.predict b/test/pred-sets/ref/cbe_adf_softmax.predict index 85c568d2ff3..2a34d4a7df7 100644 --- a/test/pred-sets/ref/cbe_adf_softmax.predict +++ b/test/pred-sets/ref/cbe_adf_softmax.predict @@ -1,6 +1,6 @@ 0:0.333333,1:0.333333,2:0.333333 -1:0.628209,0:0.371791 +1:0.559575,0:0.440425 -1:0.67226,0:0.32774 +1:0.588836,0:0.411165 diff --git a/test/train-sets/ref/cb_adf_rank.stderr b/test/train-sets/ref/cb_adf_rank.stderr index 984e142243c..edc2b75a05b 100644 --- a/test/train-sets/ref/cb_adf_rank.stderr +++ b/test/train-sets/ref/cb_adf_rank.stderr @@ -9,7 +9,7 @@ num sources = 1 average since example example current current current loss last counter weight label predict features 2.000000 2.000000 1 1.0 known 0:0... 9 -1.000000 0.000000 2 2.0 known 1:-0.182022... 6 +1.000000 0.000000 2 2.0 known 1:0... 6 finished run number of examples = 3 diff --git a/test/train-sets/ref/cbe_adf_cover.stderr b/test/train-sets/ref/cbe_adf_cover.stderr index 2c37e2d40ae..06fc981e901 100644 --- a/test/train-sets/ref/cbe_adf_cover.stderr +++ b/test/train-sets/ref/cbe_adf_cover.stderr @@ -1,4 +1,5 @@ predictions = cbe_adf_cover.predict +warning: currently, mtr is only used for the first policy in cover, other policies use dr Num weight bits = 18 learning rate = 0.5 initial_t = 0 diff --git a/test/train-sets/ref/decisionservice.stderr b/test/train-sets/ref/cbe_adf_dsjson.stderr similarity index 75% rename from test/train-sets/ref/decisionservice.stderr rename to test/train-sets/ref/cbe_adf_dsjson.stderr index 7eee9ec91ea..2f070887361 100644 --- a/test/train-sets/ref/decisionservice.stderr +++ b/test/train-sets/ref/cbe_adf_dsjson.stderr @@ -1,4 +1,5 @@ -creating quadratic features for pairs: GT +creating quadratic features for pairs: GT +predictions = cbe_adf_dsjson.predict Num weight bits = 18 learning rate = 0.5 initial_t = 0 @@ -10,10 +11,11 @@ average since example example current current current loss last counter weight label predict features -0.102041 -0.102041 1 1.0 known 0:0.0833333... 361 -0.051020 0.000000 2 2.0 known 6:0.816667... 361 +-0.040816 -0.020408 3 3.0 known 6:0.816667... 361 finished run number of examples = 3 weighted example sum = 3.000000 weighted label sum = 0.000000 -average loss = -0.367347 +average loss = -0.040816 total feature number = 1083 diff --git a/test/train-sets/ref/cbe_adf_softmax.stderr b/test/train-sets/ref/cbe_adf_softmax.stderr index 74dc15ae256..f9d5640827b 100644 --- a/test/train-sets/ref/cbe_adf_softmax.stderr +++ b/test/train-sets/ref/cbe_adf_softmax.stderr @@ -9,7 +9,7 @@ num sources = 1 average since example example current current current loss last counter weight label predict features 0.666667 0.666667 1 1.0 known 0:0.333333... 9 -0.333333 0.000000 2 2.0 known 1:0.628209... 6 +0.333333 0.000000 2 2.0 known 1:0.559575... 6 finished run number of examples = 3 diff --git a/vowpalwabbit/cb_adf.cc b/vowpalwabbit/cb_adf.cc index ccc9cefe131..9da6ef7a2a1 100644 --- a/vowpalwabbit/cb_adf.cc +++ b/vowpalwabbit/cb_adf.cc @@ -341,7 +341,7 @@ base_learner* cb_adf_setup(options_i& options, vw& all) { auto ld = scoped_calloc_or_throw(); bool cb_adf_option = false; - std::string type_string = "ips"; + std::string type_string = "mtr"; option_group_definition new_options("Contextual Bandit with Action Dependent Features"); new_options @@ -350,7 +350,7 @@ base_learner* cb_adf_setup(options_i& options, vw& all) .help("Do Contextual Bandit learning with multiline action dependent features.")) .add(make_option("rank_all", ld->rank_all).keep().help("Return actions sorted by score order")) .add(make_option("no_predict", ld->no_predict).help("Do not do a prediction when training")) - .add(make_option("cb_type", type_string).keep().help("contextual bandit method to use in {ips,dm,dr, mtr}")); + .add(make_option("cb_type", type_string).keep().help("contextual bandit method to use in {ips,dm,dr,mtr}. Default: mtr")); options.add_and_parse(new_options); if (!cb_adf_option) @@ -384,8 +384,8 @@ base_learner* cb_adf_setup(options_i& options, vw& all) ld->gen_cs.cb_type = CB_TYPE_DM; else { - all.trace_message << "warning: cb_type must be in {'ips','dr','mtr','dm'}; resetting to ips." << std::endl; - ld->gen_cs.cb_type = CB_TYPE_IPS; + all.trace_message << "warning: cb_type must be in {'ips','dr','mtr','dm'}; resetting to mtr." << std::endl; + ld->gen_cs.cb_type = CB_TYPE_MTR; } all.delete_prediction = ACTION_SCORE::delete_action_scores; diff --git a/vowpalwabbit/cb_explore_adf.cc b/vowpalwabbit/cb_explore_adf.cc index 3203edaf8db..f37b26649b8 100644 --- a/vowpalwabbit/cb_explore_adf.cc +++ b/vowpalwabbit/cb_explore_adf.cc @@ -738,7 +738,7 @@ base_learner* cb_explore_adf_setup(options_i& options, vw& all) bool cb_explore_adf_option = false; bool softmax = false; bool regcb = false; - std::string type_string = "ips"; + std::string type_string = "mtr"; option_group_definition new_options("Contextual Bandit Exploration with Action Dependent Features"); new_options .add(make_option("cb_explore_adf", cb_explore_adf_option) @@ -766,7 +766,7 @@ base_learner* cb_explore_adf_setup(options_i& options, vw& all) .keep() .help("Only explore the first action in a tie-breaking event")) .add(make_option("lambda", data->lambda).keep().default_value(-1.f).help("parameter for softmax")) - .add(make_option("cb_type", type_string).keep().help("contextual bandit method to use in {ips,dm,dr}")); + .add(make_option("cb_type", type_string).keep().help("contextual bandit method to use in {ips,dr,mtr}. Default: mtr")); options.add_and_parse(new_options); if (!cb_explore_adf_option) @@ -831,8 +831,8 @@ base_learner* cb_explore_adf_setup(options_i& options, vw& all) } else { - all.trace_message << "warning: cb_type must be in {'ips','dr','mtr'}; resetting to ips." << std::endl; - options.replace("cb_type", "ips"); + all.trace_message << "warning: cb_type must be in {'ips','dr','mtr'}; resetting to mtr." << std::endl; + options.replace("cb_type", "mtr"); } if (data->explore_type == REGCB && data->gen_cs.cb_type != CB_TYPE_MTR)