Skip to content

Commit

Permalink
Merge fa3463c into ea4134e
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Feb 12, 2020
2 parents ea4134e + fa3463c commit 86e890d
Show file tree
Hide file tree
Showing 145 changed files with 3,966 additions and 2,958 deletions.
11 changes: 4 additions & 7 deletions cs/cli/vowpalwabbit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ void VowpalWabbit::Driver()
void VowpalWabbit::RunMultiPass()
{ if (m_vw->numpasses > 1)
{ try
{ adjust_used_index(*m_vw);
{
m_vw->do_reset_source = true;
VW::start_parser(*m_vw);
LEARNER::generic_driver(*m_vw);
Expand Down Expand Up @@ -307,7 +307,7 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By
auto ex = GetOrCreateNativeExample();
state->examples->Add(ex);

v_array<example*> examples = v_init<example*>();
v_array<example*> examples;
example* native_example = ex->m_example;
examples.push_back(native_example);

Expand All @@ -326,9 +326,6 @@ List<VowpalWabbitExample^>^ VowpalWabbit::ParseDecisionServiceJson(cli::array<By
// finalize example
VW::setup_examples(*m_vw, examples);

// delete native array of pointers, keep examples
examples.delete_v();

header->EventId = gcnew String(interaction.eventId.c_str());
header->Actions = gcnew cli::array<int>((int)interaction.actions.size());
int index = 0;
Expand Down Expand Up @@ -789,15 +786,15 @@ VowpalWabbitExample^ VowpalWabbit::GetOrCreateNativeExample()
if (ex == nullptr)
{ try
{ auto ex = VW::alloc_examples(0, 1);
m_vw->p->lp.default_label(&ex->l);
m_vw->p->lp.default_label(ex->l);
return gcnew VowpalWabbitExample(this, ex);
}
CATCHRETHROW
}

try
{ VW::empty_example(*m_vw, *ex->m_example);
m_vw->p->lp.default_label(&ex->m_example->l);
m_vw->p->lp.default_label(ex->m_example->l);

return ex;
}
Expand Down
10 changes: 5 additions & 5 deletions cs/cli/vw_example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ void VowpalWabbitExample::Label::set(ILabel^ label)
label->UpdateExample(m_owner->Native->m_vw, m_example);

// we need to update the example weight as setup_example() can be called prior to this call.
m_example->weight = m_owner->Native->m_vw->p->lp.get_weight(&m_example->l);
m_example->weight = m_owner->Native->m_vw->p->lp.get_weight(m_example->l);
}

void VowpalWabbitExample::MakeEmpty(VowpalWabbit^ vw)
Expand Down Expand Up @@ -280,8 +280,8 @@ System::String^ VowpalWabbitExample::Diff(VowpalWabbit^ vw, VowpalWabbitExample^
}

String^ VowpalWabbitSimpleLabelComparator::Diff(VowpalWabbitExample^ ex1, VowpalWabbitExample^ ex2)
{ auto s1 = ex1->m_example->l.simple;
auto s2 = ex2->m_example->l.simple;
{ auto& s1 = ex1->m_example->l.simple();
auto& s2 = ex2->m_example->l.simple();

if (!(FloatEqual(s1.initial, s2.initial) &&
FloatEqual(s1.label, s2.label) &&
Expand All @@ -296,8 +296,8 @@ String^ VowpalWabbitSimpleLabelComparator::Diff(VowpalWabbitExample^ ex1, Vowpal
}

String^ VowpalWabbitContextualBanditLabelComparator::Diff(VowpalWabbitExample^ ex1, VowpalWabbitExample^ ex2)
{ auto s1 = ex1->m_example->l.cb;
auto s2 = ex2->m_example->l.cb;
{ auto& s1 = ex1->m_example->l.cb();
auto& s2 = ex2->m_example->l.cb();

if (s1.costs.size() != s2.costs.size())
{ return System::String::Format("Cost size differ: {0} vs {1}", s1.costs.size(), s2.costs.size());
Expand Down
104 changes: 66 additions & 38 deletions cs/cli/vw_prediction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,17 @@
namespace VW
{
void CheckExample(vw* vw, example* ex, prediction_type_t type)
{ if (vw == nullptr)
{
if (vw == nullptr)
throw gcnew ArgumentNullException("vw");

if (ex == nullptr)
throw gcnew ArgumentNullException("ex");

auto ex_pred_type = vw->l->pred_type;
if (ex_pred_type != type)
{ auto sb = gcnew StringBuilder();
{
auto sb = gcnew StringBuilder();
sb->Append("Prediction type must be ");
sb->Append(gcnew String(to_string(type)));
sb->Append(" but is ");
Expand All @@ -29,20 +31,23 @@ void CheckExample(vw* vw, example* ex, prediction_type_t type)
}

float VowpalWabbitScalarPredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
{
CheckExample(vw, ex, PredictionType);

try
{ return VW::get_prediction(ex);
{
return VW::get_prediction(ex);
}
CATCHRETHROW
}


VowpalWabbitScalar VowpalWabbitScalarConfidencePredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
{
CheckExample(vw, ex, PredictionType);

try
{ VowpalWabbitScalar ret;
{
VowpalWabbitScalar ret;

ret.Value = VW::get_prediction(ex);
ret.Confidence = ex->confidence;
Expand All @@ -52,37 +57,41 @@ VowpalWabbitScalar VowpalWabbitScalarConfidencePredictionFactory::Create(vw* vw,
CATCHRETHROW
}

cli::array<float>^ VowpalWabbitScalarsPredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
cli::array<float> ^ VowpalWabbitScalarsPredictionFactory::Create(vw* vw, example* ex)
{
CheckExample(vw, ex, PredictionType);

try
{ auto& scalars = ex->pred.scalars;
{
auto& scalars = ex->pred.scalars();
auto values = gcnew cli::array<float>((int)scalars.size());
int index = 0;
for (float s : scalars)
values[index++] = s;
for (float s : scalars) values[index++] = s;

return values;
}
CATCHRETHROW
}

float VowpalWabbitProbabilityPredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
{
CheckExample(vw, ex, PredictionType);

return ex->pred.prob;
return ex->pred.prob();
}

float VowpalWabbitCostSensitivePredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
{
CheckExample(vw, ex, PredictionType);

try
{ return VW::get_cost_sensitive_prediction(ex);
{
return VW::get_cost_sensitive_prediction(ex);
}
CATCHRETHROW
}

Dictionary<int, float>^ VowpalWabbitMulticlassProbabilitiesPredictionFactory::Create(vw* vw, example* ex)
Dictionary<int, float> ^ VowpalWabbitMulticlassProbabilitiesPredictionFactory::Create(vw* vw, example* ex)
{
#if _DEBUG
if (ex == nullptr)
Expand All @@ -91,33 +100,38 @@ Dictionary<int, float>^ VowpalWabbitMulticlassProbabilitiesPredictionFactory::Cr
v_array<float> confidence_scores;

try
{ confidence_scores = VW::get_cost_sensitive_prediction_confidence_scores(ex);
{
confidence_scores = VW::get_cost_sensitive_prediction_confidence_scores(ex);
}
CATCHRETHROW

auto values = gcnew Dictionary<int, float>();
int i = 0;
for (auto& val : confidence_scores)
{ values->Add(++i, val);
{
values->Add(++i, val);
}

return values;
}

uint32_t VowpalWabbitMulticlassPredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
{
CheckExample(vw, ex, PredictionType);

return ex->pred.multiclass;
return ex->pred.multiclass();
}

cli::array<int>^ VowpalWabbitMultilabelPredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, prediction_type_t::multilabels);
cli::array<int> ^ VowpalWabbitMultilabelPredictionFactory::Create(vw* vw, example* ex)
{
CheckExample(vw, ex, prediction_type_t::multilabels);

size_t length;
uint32_t* labels;

try
{ labels = VW::get_multilabel_predictions(ex, length);
{
labels = VW::get_multilabel_predictions(ex, length);
}
CATCHRETHROW

Expand All @@ -132,38 +146,51 @@ cli::array<int>^ VowpalWabbitMultilabelPredictionFactory::Create(vw* vw, example
return values;
}

cli::array<ActionScore>^ VowpalWabbitActionScoreBasePredictionFactory::Create(vw* vw, example* ex)
{ CheckExample(vw, ex, PredictionType);
cli::array<ActionScore> ^ VowpalWabbitActionScoreBasePredictionFactory::Create(vw* vw, example* ex)
{
CheckExample(vw, ex, PredictionType);

auto& a_s = ex->pred.a_s;
auto values = gcnew cli::array<ActionScore>((int)a_s.size());
ACTION_SCORE::action_scores* a_s = nullptr;
if (ex->pred.get_type() == prediction_type_t::action_scores)
{
a_s = &ex->pred.action_scores();
}
else
{
a_s = &ex->pred.action_probs();
}
auto values = gcnew cli::array<ActionScore>((int)a_s->size());

auto index = 0;
for (auto& as : a_s)
{ values[index].Action = as.action;
for (auto& as : *a_s)
{
values[index].Action = as.action;
values[index].Score = as.score;
index++;
}

return values;
}

cli::array<float>^ VowpalWabbitTopicPredictionFactory::Create(vw* vw, example* ex)
{ if (ex == nullptr)
cli::array<float> ^ VowpalWabbitTopicPredictionFactory::Create(vw* vw, example* ex)
{
if (ex == nullptr)
throw gcnew ArgumentNullException("ex");

auto values = gcnew cli::array<float>(vw->lda);
Marshal::Copy(IntPtr(ex->pred.scalars.begin()), values, 0, vw->lda);
Marshal::Copy(IntPtr(ex->pred.scalars().begin()), values, 0, vw->lda);

return values;
}

System::Object^ VowpalWabbitDynamicPredictionFactory::Create(vw* vw, example* ex)
{ if (ex == nullptr)
System::Object ^ VowpalWabbitDynamicPredictionFactory::Create(vw* vw, example* ex)
{
if (ex == nullptr)
throw gcnew ArgumentNullException("ex");

switch (vw->l->pred_type)
{ case prediction_type_t::scalar:
{
case prediction_type_t::scalar:
return VowpalWabbitPredictionType::Scalar->Create(vw, ex);
case prediction_type_t::scalars:
return VowpalWabbitPredictionType::Scalars->Create(vw, ex);
Expand All @@ -180,11 +207,12 @@ System::Object^ VowpalWabbitDynamicPredictionFactory::Create(vw* vw, example* ex
case prediction_type_t::multiclassprobs:
return VowpalWabbitPredictionType::MultiClassProbabilities->Create(vw, ex);
default:
{ auto sb = gcnew StringBuilder();
{
auto sb = gcnew StringBuilder();
sb->Append("Unsupported prediction type: ");
sb->Append(gcnew String(to_string(vw->l->pred_type)));
throw gcnew ArgumentException(sb->ToString());
}
}
}
}
} // namespace VW
3 changes: 2 additions & 1 deletion java/src/main/c++/jni_base_learner.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ T base_predict(JNIEnv* env, jobjectArray example_strings, jboolean learn, jlong
rethrow_cpp_exception_as_java_exception(env);
}

T result = predictor(first_example, env);
vwInstance->finish_example(ex_coll);

return predictor(first_example, env);
return result;
}

#endif // VW_BASE_LEARNER_H
10 changes: 5 additions & 5 deletions java/src/main/c++/jni_spark_vw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ JNIEXPORT jlong JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_initiali

try
{
example* ex = VW::alloc_examples(0, 1);
example* ex = VW::alloc_examples(1);
ex->interactions = &all->interactions;

if (isEmpty)
Expand All @@ -265,7 +265,7 @@ JNIEXPORT jlong JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_initiali
VW::read_line(*all, ex, &empty);
}
else
all->p->lp.default_label(&ex->l);
all->p->lp.default_label(ex->l);

return (jlong) new VowpalWabbitExampleWrapper(all, ex);
}
Expand Down Expand Up @@ -297,7 +297,7 @@ JNIEXPORT void JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_clear(JNI
try
{
VW::empty_example(*all, *ex);
all->p->lp.default_label(&ex->l);
all->p->lp.default_label(ex->l);
}
catch (...)
{
Expand Down Expand Up @@ -444,7 +444,7 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_getPre
ctr = env->GetMethodID(predClass, "<init>", "(F)V");
CHECK_JNI_EXCEPTION(nullptr);

return env->NewObject(predClass, ctr, ex->pred.prob);
return env->NewObject(predClass, ctr, ex->pred.prob());

case prediction_type_t::multiclass:
predClass = env->FindClass("java/lang/Integer");
Expand All @@ -453,7 +453,7 @@ JNIEXPORT jobject JNICALL Java_org_vowpalwabbit_spark_VowpalWabbitExample_getPre
ctr = env->GetMethodID(predClass, "<init>", "(I)V");
CHECK_JNI_EXCEPTION(nullptr);

return env->NewObject(predClass, ctr, ex->pred.multiclass);
return env->NewObject(predClass, ctr, ex->pred.multiclass());

case prediction_type_t::scalars:
return scalars_predictor(ex, env);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@ jobject action_probs_prediction(example *vec, JNIEnv *env)
jclass action_prob_class = env->FindClass("vowpalWabbit/responses/ActionProb");
jmethodID action_prob_constructor = env->GetMethodID(action_prob_class, "<init>", "(IF)V");

// The action_probs prediction_type_t is just a placeholder identifying when the aciton_scores
// The action_probs prediction_type_t is just a placeholder identifying when the action_scores
// should be treated as probabilities or scores. That is why this function references a_s yet returns
// ActionProbs to the Java side.
ACTION_SCORE::action_scores a_s = vec->pred.a_s;
const auto& a_s = vec->pred.action_probs();
size_t num_values = a_s.size();
jobjectArray j_action_probs = env->NewObjectArray(num_values, action_prob_class, 0);

jclass action_probs_class = env->FindClass("vowpalWabbit/responses/ActionProbs");
for (uint32_t i = 0; i < num_values; ++i)
{
ACTION_SCORE::action_score a = a_s[i];
const auto& a = a_s[i];
jobject j_action_prob = env->NewObject(action_prob_class, action_prob_constructor, a.action, a.score);
env->SetObjectArrayElement(j_action_probs, i, j_action_prob);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,14 @@ jobject action_scores_prediction(example *vec, JNIEnv *env)
jclass action_score_class = env->FindClass("vowpalWabbit/responses/ActionScore");
jmethodID action_score_constructor = env->GetMethodID(action_score_class, "<init>", "(IF)V");

ACTION_SCORE::action_scores a_s = vec->pred.a_s;
const auto a_s = vec->pred.action_scores();
size_t num_values = a_s.size();
jobjectArray j_action_scores = env->NewObjectArray(num_values, action_score_class, 0);

jclass action_scores_class = env->FindClass("vowpalWabbit/responses/ActionScores");
for (uint32_t i = 0; i < num_values; ++i)
{
ACTION_SCORE::action_score a = a_s[i];
const auto a = a_s[i];
jobject j_action_score = env->NewObject(action_score_class, action_score_constructor, a.action, a.score);
env->SetObjectArrayElement(j_action_scores, i, j_action_score);
}
Expand Down

0 comments on commit 86e890d

Please sign in to comment.