-
Notifications
You must be signed in to change notification settings - Fork 1.9k
/
vowpalWabbit_learner_VWLearners.cc
109 lines (101 loc) · 3.37 KB
/
vowpalWabbit_learner_VWLearners.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
#include "vowpalWabbit_learner_VWLearners.h"
#include "jni_base_learner.h"
#include "vw/core/vw.h"
// Slash-separated JNI name of the Java VWReturnType class; passed to FindClass in getReturnType.
#define RETURN_TYPE "vowpalWabbit/learner/VWLearners$VWReturnType"
// JNI field type signature ("L<class name>;") for fields of type RETURN_TYPE; passed to GetStaticFieldID.
#define RETURN_TYPE_INSTANCE "L" RETURN_TYPE ";"
/**
 * JNI entry point: creates a VW workspace from an argument string.
 *
 * The Java string is copied into a std::string and the JNI character buffer is
 * released before VW::initialize runs, so the buffer is not leaked regardless
 * of whether initialization succeeds or throws.
 *
 * @param env     JNI environment of the calling thread.
 * @param obj     Class reference supplied by JNI (unused).
 * @param command Java string holding the arguments forwarded to VW::initialize.
 * @return The workspace pointer stored in a jlong, or 0 if the string could not
 *         be read or a C++ exception was caught (the exception is handed to
 *         rethrow_cpp_exception_as_java_exception).
 */
JNIEXPORT jlong JNICALL Java_vowpalWabbit_learner_VWLearners_initialize(JNIEnv* env, jclass obj, jstring command)
{
  jlong vwPtr = 0;
  try
  {
    const char* utf_string = env->GetStringUTFChars(command, nullptr);
    // A null result means the JVM could not allocate the buffer; per the JNI
    // spec an OutOfMemoryError is already pending, so just report failure.
    if (utf_string == nullptr) { return vwPtr; }

    // Copy first, then release: the JNI buffer must be returned to the JVM.
    std::string commandCpp(utf_string);
    env->ReleaseStringUTFChars(command, utf_string);

    VW::workspace* vwInstance = VW::initialize(commandCpp);
    vwPtr = reinterpret_cast<jlong>(vwInstance);
  }
  catch (...)
  {
    rethrow_cpp_exception_as_java_exception(env);
  }
  return vwPtr;
}
/**
 * JNI entry point: drives the parser and learner again for workspaces that are
 * configured with more than one pass. Does nothing when numpasses is 1 or less.
 *
 * @param env   JNI environment of the calling thread.
 * @param obj   Class reference supplied by JNI (unused).
 * @param vwPtr Workspace pointer previously stored in a jlong.
 *
 * Any C++ exception is handed to rethrow_cpp_exception_as_java_exception.
 */
JNIEXPORT void JNICALL Java_vowpalWabbit_learner_VWLearners_performRemainingPasses(JNIEnv* env, jclass obj, jlong vwPtr)
{
  try
  {
    auto& workspace = *reinterpret_cast<VW::workspace*>(vwPtr);

    const bool hasMorePasses = workspace.numpasses > 1;
    if (!hasMorePasses) { return; }

    // Flag the source for reset before restarting the parser and the driver.
    workspace.do_reset_source = true;
    VW::start_parser(workspace);
    VW::LEARNER::generic_driver(workspace);
    VW::end_parser(workspace);
  }
  catch (...)
  {
    rethrow_cpp_exception_as_java_exception(env);
  }
}
/**
 * JNI entry point: passes the workspace behind vwPtr to VW::finish.
 *
 * @param env   JNI environment of the calling thread.
 * @param obj   Class reference supplied by JNI (unused).
 * @param vwPtr Workspace pointer previously stored in a jlong.
 *
 * NOTE(review): VW::finish presumably tears the workspace down, so vwPtr should
 * not be used after this call - confirm against the VW API.
 * Any C++ exception is handed to rethrow_cpp_exception_as_java_exception.
 */
JNIEXPORT void JNICALL Java_vowpalWabbit_learner_VWLearners_closeInstance(JNIEnv* env, jclass obj, jlong vwPtr)
{
  try
  {
    VW::finish(*reinterpret_cast<VW::workspace*>(vwPtr));
  }
  catch (...)
  {
    rethrow_cpp_exception_as_java_exception(env);
  }
}
/**
 * JNI entry point: saves the workspace's predictor to the given file.
 *
 * @param env      JNI environment of the calling thread.
 * @param obj      Class reference supplied by JNI (unused).
 * @param vwPtr    Workspace pointer previously stored in a jlong.
 * @param filename Java string with the path passed to VW::save_predictor.
 *
 * Returns without saving if the file name cannot be read from the JVM.
 * Any C++ exception is handed to rethrow_cpp_exception_as_java_exception.
 */
JNIEXPORT void JNICALL Java_vowpalWabbit_learner_VWLearners_saveModel(
    JNIEnv* env, jclass obj, jlong vwPtr, jstring filename)
{
  try
  {
    const char* utf_string = env->GetStringUTFChars(filename, nullptr);
    // A null result means the JVM could not allocate the buffer (an
    // OutOfMemoryError is pending); building a std::string from it would be UB.
    if (utf_string == nullptr) { return; }

    // Copy the characters, then hand the JNI buffer and local reference back.
    std::string filenameCpp(utf_string);
    env->ReleaseStringUTFChars(filename, utf_string);
    env->DeleteLocalRef(filename);

    VW::save_predictor(*reinterpret_cast<VW::workspace*>(vwPtr), filenameCpp);
  }
  catch (...)
  {
    rethrow_cpp_exception_as_java_exception(env);
  }
}
/**
 * JNI entry point: maps the workspace's output prediction type to the matching
 * static field of the Java VWReturnType class and returns that field's value.
 *
 * @param env   JNI environment of the calling thread.
 * @param obj   Class reference supplied by JNI (unused).
 * @param vwPtr Workspace pointer previously stored in a jlong.
 * @return The VWReturnType value for the prediction type ("Unknown" for any
 *         type not listed below), or nullptr if the class or field lookup
 *         failed, in which case the JVM already has an exception pending.
 */
JNIEXPORT jobject JNICALL Java_vowpalWabbit_learner_VWLearners_getReturnType(JNIEnv* env, jclass obj, jlong vwPtr)
{
  jclass clVWReturnType = env->FindClass(RETURN_TYPE);
  // FindClass returns null with a pending exception when the class is missing;
  // continuing with a null class would crash the JVM.
  if (clVWReturnType == nullptr) { return nullptr; }

  // Pick the Java field name first so the JNI lookup happens in one place.
  const char* fieldName = "Unknown";
  VW::workspace* vwInstance = reinterpret_cast<VW::workspace*>(vwPtr);
  switch (vwInstance->l->get_output_prediction_type())
  {
    case VW::prediction_type_t::action_probs:
      fieldName = "ActionProbs";
      break;
    case VW::prediction_type_t::action_scores:
      fieldName = "ActionScores";
      break;
    case VW::prediction_type_t::multiclass:
      fieldName = "Multiclass";
      break;
    case VW::prediction_type_t::multilabels:
      fieldName = "Multilabels";
      break;
    case VW::prediction_type_t::prob:
      fieldName = "Prob";
      break;
    case VW::prediction_type_t::scalar:
      fieldName = "Scalar";
      break;
    case VW::prediction_type_t::scalars:
      fieldName = "Scalars";
      break;
    case VW::prediction_type_t::decision_probs:
      fieldName = "DecisionProbs";
      break;
    default:
      break;
  }

  jfieldID field = env->GetStaticFieldID(clVWReturnType, fieldName, RETURN_TYPE_INSTANCE);
  // A null field ID means the lookup failed and an exception is pending.
  if (field == nullptr) { return nullptr; }

  return env->GetStaticObjectField(clVWReturnType, field);
}