-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
vowpalWabbit_learner_VWActionScoresLearner.cc
37 lines (31 loc) · 1.63 KB
/
vowpalWabbit_learner_VWActionScoresLearner.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include "vowpalWabbit_learner_VWActionScoresLearner.h"
#include "jni_base_learner.h"
#include "vw/core/vw.h"
jobject action_scores_prediction(example* vec, JNIEnv* env)
{
jclass action_score_class = env->FindClass("vowpalWabbit/responses/ActionScore");
jmethodID action_score_constructor = env->GetMethodID(action_score_class, "<init>", "(IF)V");
ACTION_SCORE::action_scores a_s = vec->pred.a_s;
size_t num_values = a_s.size();
jobjectArray j_action_scores = env->NewObjectArray(num_values, action_score_class, 0);
jclass action_scores_class = env->FindClass("vowpalWabbit/responses/ActionScores");
for (uint32_t i = 0; i < num_values; ++i)
{
ACTION_SCORE::action_score a = a_s[i];
jobject j_action_score = env->NewObject(action_score_class, action_score_constructor, a.action, a.score);
env->SetObjectArrayElement(j_action_scores, i, j_action_score);
}
jmethodID action_scores_constructor =
env->GetMethodID(action_scores_class, "<init>", "([LvowpalWabbit/responses/ActionScore;)V");
return env->NewObject(action_scores_class, action_scores_constructor, j_action_scores);
}
JNIEXPORT jobject JNICALL Java_vowpalWabbit_learner_VWActionScoresLearner_predict(
JNIEnv* env, jobject obj, jstring example_string, jboolean learn, jlong vwPtr)
{
return base_predict<jobject>(env, example_string, learn, vwPtr, action_scores_prediction);
}
JNIEXPORT jobject JNICALL Java_vowpalWabbit_learner_VWActionScoresLearner_predictMultiline(
JNIEnv* env, jobject obj, jobjectArray example_strings, jboolean learn, jlong vwPtr)
{
return base_predict<jobject>(env, example_strings, learn, vwPtr, action_scores_prediction);
}