Skip to content
Permalink
Browse files

Bare Java JNI Bindings optimized for Apache Spark (#1798)

* working version for Spark
* package natives up with Java code. This includes boost program options
* added dynamic port allocation to spanning tree to allow for parallel hyper-parameter tuning on Spark
* extracting VW command line args from model
* removed java header generation and checked in the headers just like the classic headers
* expose native hashing for testing
* fixed Java hashing
* formatted code
* improved Spanning Tree error message
  • Loading branch information...
eisber committed Jun 4, 2019
1 parent 7b1fd30 commit fa5c3dfbcf689d72bc9795d7c0e59967eea866d3
Showing with 1,879 additions and 194 deletions.
  1. +1 −1 build-linux.sh
  2. +70 −3 java/CMakeLists.txt
  3. +13 −0 java/README.md
  4. +23 −1 java/pom.xml.in
  5. +7 −32 java/src/main/c++/jni_base_learner.cc
  6. +27 −37 java/src/main/c++/jni_base_learner.h
  7. +71 −0 java/src/main/c++/jni_spark_cluster.cc
  8. +422 −0 java/src/main/c++/jni_spark_vw.cc
  9. +42 −0 java/src/main/c++/jni_spark_vw.h
  10. +211 −0 java/src/main/c++/jni_spark_vw_generated.h
  11. +45 −0 java/src/main/c++/util.cc
  12. +7 −0 java/src/main/c++/util.h
  13. +44 −0 java/src/main/c++/vector_io_buf.cc
  14. +31 −0 java/src/main/c++/vector_io_buf.h
  15. +4 −3 java/src/main/c++/vowpalWabbit_VW.cc
  16. +17 −10 java/src/main/c++/vowpalWabbit_learner_VWActionProbsLearner.cc
  17. +17 −10 java/src/main/c++/vowpalWabbit_learner_VWActionScoresLearner.cc
  18. +52 −38 java/src/main/c++/vowpalWabbit_learner_VWLearners.cc
  19. +11 −7 java/src/main/c++/vowpalWabbit_learner_VWMulticlassLearner.cc
  20. +14 −9 java/src/main/c++/vowpalWabbit_learner_VWMultilabelsLearner.cc
  21. +11 −7 java/src/main/c++/vowpalWabbit_learner_VWProbLearner.cc
  22. +11 −7 java/src/main/c++/vowpalWabbit_learner_VWScalarLearner.cc
  23. +14 −9 java/src/main/c++/vowpalWabbit_learner_VWScalarsLearner.cc
  24. +34 −0 java/src/main/java/org/vowpalwabbit/spark/ClusterSpanningTree.java
  25. +57 −0 java/src/main/java/org/vowpalwabbit/spark/Native.java
  26. +36 −0 java/src/main/java/org/vowpalwabbit/spark/VowpalWabbitArguments.java
  27. +105 −0 java/src/main/java/org/vowpalwabbit/spark/VowpalWabbitExample.java
  28. +78 −0 java/src/main/java/org/vowpalwabbit/spark/VowpalWabbitMurmur.java
  29. +124 −0 java/src/main/java/org/vowpalwabbit/spark/VowpalWabbitNative.java
  30. +35 −0 java/src/main/java/org/vowpalwabbit/spark/prediction/ScalarPrediction.java
  31. +190 −0 java/src/test/java/org/vowpalwabbit/spark/VowpalWabbitNativeIT.java
  32. +5 −0 java/src/test/resources/dataset1.csv
  33. +10 −0 java/src/test/resources/test.txt
  34. +3 −2 vowpalwabbit/allreduce.h
  35. +0 −1 vowpalwabbit/allreduce_sockets.cc
  36. +5 −2 vowpalwabbit/parse_args.cc
  37. +28 −13 vowpalwabbit/spanning_tree.cc
  38. +4 −2 vowpalwabbit/spanning_tree.h
@@ -31,7 +31,7 @@ make test_with_output
cd ..

# Run Java build and test
mvn clean test -f java/pom.xml
mvn verify -f java/pom.xml

# Run python build and tests
cd python
@@ -14,6 +14,10 @@ if(JNI_FOUND)
${src_base}/vowpalWabbit_learner_VWScalarLearner.h
${src_base}/vowpalWabbit_learner_VWScalarsLearner.h
${src_base}/vowpalWabbit_VW.h
${src_base}/jni_spark_vw.h
${src_base}/jni_spark_vw_generated.h
${src_base}/vector_io_buf.h
${src_base}/util.h
)

set(vw_jni_sources
@@ -27,18 +31,28 @@ if(JNI_FOUND)
${src_base}/vowpalWabbit_learner_VWScalarLearner.cc
${src_base}/vowpalWabbit_learner_VWScalarsLearner.cc
${src_base}/vowpalWabbit_VW.cc
${src_base}/jni_spark_vw.cc
${src_base}/vector_io_buf.cc
${src_base}/jni_spark_cluster.cc
${src_base}/util.cc
${CMAKE_CURRENT_SOURCE_DIR}/../vowpalwabbit/spanning_tree.cc
)

add_library(vw_jni SHARED ${vw_jni_headers} ${vw_jni_sources})
target_link_libraries(vw_jni PUBLIC vw)
target_include_directories(vw_jni PUBLIC ${CMAKE_CURRENT_SOURCE_DIR} ${JNI_INCLUDE_DIRS})
target_include_directories(vw_jni PUBLIC
${CMAKE_CURRENT_SOURCE_DIR}
${JNI_INCLUDE_DIRS})

# Ensure target directory exists
file(MAKE_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/target/)
add_custom_command(TARGET vw_jni POST_BUILD
COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:vw_jni> ${CMAKE_CURRENT_SOURCE_DIR}/target/
COMMAND ${CMAKE_COMMAND} -E copy $<TARGET_FILE:vw_jni> ${CMAKE_CURRENT_SOURCE_DIR}/target/bin/
)

# enable-new-dtags and rpath enables shared object library lookup in the location of libvw_spark_jni.so
target_link_libraries(vw_jni PUBLIC -Wl,--enable-new-dtags -Wl,-rpath,\"\$ORIGIN\" vw)

# Replace version number in POM
configure_file(pom.xml.in ${CMAKE_CURRENT_SOURCE_DIR}/pom.xml @ONLY)

@@ -54,4 +68,57 @@ if(JNI_FOUND)
LIBRARY DESTINATION ${JAVA_INSTALL_PATH}
)
endif()
endif()

if(NOT WIN32)
# Ensure target directory exists
file(MAKE_DIRECTORY target/classes)
file(MAKE_DIRECTORY target/test-classes)
file(MAKE_DIRECTORY target/bin/natives/linux_64)

# Development
# - uncomment the following section to generate the jni headers
# - it's commented to speed up the build as it's not expected to change frequently

# find_package(Java)
# include(UseJava)

# add_custom_target(javacompile
# COMMAND mvn compile
# WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
# COMMENT "Compile classes for javah")

# create_javah(TARGET javaheaders
# CLASSES
# org.vowpalwabbit.spark.VowpalWabbitNative
# org.vowpalwabbit.spark.VowpalWabbitExample
# org.vowpalwabbit.spark.ClusterSpanningTree
# CLASSPATH ${CMAKE_CURRENT_SOURCE_DIR}/target/classes
# OUTPUT_NAME ${CMAKE_CURRENT_SOURCE_DIR}/src/main/c++/jni_spark_vw_generated.h)
# add_dependencies(javaheaders javacompile)

add_custom_command(TARGET vw_jni POST_BUILD
COMMAND ldd $<TARGET_FILE:vw_jni> | grep -E 'boost|libz' | grep -oP '=> \\K\\S+' | xargs -i cp {} target/bin/natives/linux_64
COMMAND cp $<TARGET_FILE:vw_jni> target/bin/natives/linux_64
COMMAND echo $<TARGET_FILE:vw-bin> > ${CMAKE_CURRENT_SOURCE_DIR}/target/test-classes/vw-bin.txt
COMMAND mvn verify
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
COMMENT "Copying shared libary dependencies to output directory")

# Replace version number in POM
configure_file(pom.xml.in ${CMAKE_CURRENT_SOURCE_DIR}/pom.xml @ONLY)

if (${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
set(JAVA_INSTALL_PATH /usr/lib)
else()
set(JAVA_INSTALL_PATH /Library/Java/Extensions)
endif()

if(VW_INSTALL)
install(TARGETS vw_jni
RUNTIME DESTINATION ${JAVA_INSTALL_PATH}
LIBRARY DESTINATION ${JAVA_INSTALL_PATH}
)
endif()

endif()
endif()
@@ -50,3 +50,16 @@ It should also be noted that Vowpal Wabbit makes all attempts at compatibility b
| ---------- | ---------------------------------------- |
| 8.4.1 | 10bd09ab06f59291e04ad7805e88fd3e693b7159 |
| 8.1.0 | 9e5831a72d5b0a124c845dcaec75879f498b355f |

# Spark Layer
To improve performance when hosting VW in Spark an additional optimized layer can be found in org.vowpalwabbit.spark.*. The actual VW/Spark integration will be available throogh [MMLSpark](https://github.com/Azure/mmlspark).
## Features
1. Native dependencies are included in the JAR file.
2. Features are expected to be already hashed.
3. Multi-pass support.
## Limitations
1. Only simple label is supported for now (e.g. classification/regression).
@@ -50,6 +50,10 @@
<name>John Langford</name>
<email>jl@hunch.net</email>
</developer>
<developer>
<name>Markus Cozowicz</name>
<email>marcozo@microsoft.com</email>
</developer>
</developers>

<properties>
@@ -79,6 +83,11 @@
</dependencies>

<build>
<resources>
<resource>
<directory>target/bin</directory>
</resource>
</resources>
<testResources>
<testResource>
<directory>${project.build.directory}</directory>
@@ -152,6 +161,19 @@
</dependency>
</dependencies>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<version>2.19.1</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
@@ -175,7 +197,7 @@
<!--<forkMode>once</forkMode>-->
<forkCount>1</forkCount>
<reuseForks>false</reuseForks>
<argLine>-Djava.library.path=${project.build.directory}</argLine>
<argLine>-Djava.library.path=${project.build.directory}/bin</argLine>

</configuration>
</plugin>
@@ -1,37 +1,11 @@
#include "../../../../vowpalwabbit/vw.h"
#include "../../../../vowpalwabbit/vw_exception.h"
#include "vw.h"
#include "vw_exception.h"

#include "jni_base_learner.h"

void throw_java_exception(JNIEnv *env, const char* name, const char* msg)
{ jclass jc = env->FindClass(name);
if (jc)
env->ThrowNew(jc, msg);
}

void rethrow_cpp_exception_as_java_exception(JNIEnv *env)
{ try
{ throw;
}
catch(const std::bad_alloc& e)
{ throw_java_exception(env, "java/lang/OutOfMemoryError", e.what());
}
catch(const VW::vw_unrecognised_option_exception& e)
{ throw_java_exception(env, "java/lang/IllegalArgumentException", e.what());
}
catch(const std::exception& e)
{ throw_java_exception(env, "java/lang/Exception", e.what());
}

catch (...)
{ throw_java_exception(env, "java/lang/Error", "Unidentified exception => "
"rethrow_cpp_exception_as_java_exception "
"may require some completion...");
}
}

example* read_example(JNIEnv *env, jstring example_string, vw* vwInstance)
{ const char* utf_string = env->GetStringUTFChars(example_string, NULL);
example* read_example(JNIEnv* env, jstring example_string, vw* vwInstance)
{
const char* utf_string = env->GetStringUTFChars(example_string, NULL);
example* ex = read_example(utf_string, vwInstance);

env->ReleaseStringUTFChars(example_string, utf_string);
@@ -41,5 +15,6 @@ example* read_example(JNIEnv *env, jstring example_string, vw* vwInstance)
}

example* read_example(const char* example_string, vw* vwInstance)
{ return VW::read_example(*vwInstance, example_string);
{
return VW::read_example(*vwInstance, example_string);
}
@@ -3,11 +3,9 @@

#include <jni.h>
#include <functional>
#include "util.h"

void throw_java_exception(JNIEnv *env, const char* name, const char* msg);
void rethrow_cpp_exception_as_java_exception(JNIEnv *env);

example* read_example(JNIEnv *env, jstring example_string, vw* vwInstance);
example* read_example(JNIEnv* env, jstring example_string, vw* vwInstance);
example* read_example(const char* example_string, vw* vwInstance);

// It would appear that after reading posts like
@@ -16,17 +14,13 @@ example* read_example(const char* example_string, vw* vwInstance);
// http://stackoverflow.com/questions/3203305/write-a-function-that-accepts-a-lambda-expression-as-argument
// it is more efficient to use another type parameter instead of std::function<T(example*)>
// but more difficult to read.
template<typename T, typename F>
T base_predict(
JNIEnv *env,
example* ex,
bool learn,
vw* vwInstance,
const F& predictor,
const bool predict)
{ T result = 0;
template <typename T, typename F>
T base_predict(JNIEnv* env, example* ex, bool learn, vw* vwInstance, const F& predictor, const bool predict)
{
T result = 0;
try
{ if (learn)
{
if (learn)
vwInstance->learn(*ex);
else
vwInstance->predict(*ex);
@@ -37,36 +31,30 @@ T base_predict(
vwInstance->finish_example(*ex);
}
catch (...)
{ rethrow_cpp_exception_as_java_exception(env);
{
rethrow_cpp_exception_as_java_exception(env);
}
return result;
}

template<typename T, typename F>
T base_predict(
JNIEnv *env,
jstring example_string,
jboolean learn,
jlong vwPtr,
const F& predictor)
{ vw* vwInstance = (vw*)vwPtr;
template <typename T, typename F>
T base_predict(JNIEnv* env, jstring example_string, jboolean learn, jlong vwPtr, const F& predictor)
{
vw* vwInstance = (vw*)vwPtr;
example* ex = read_example(env, example_string, vwInstance);
return base_predict<T>(env, ex, learn, vwInstance, predictor, true);
}

template<typename T, typename F>
T base_predict(
JNIEnv *env,
jobjectArray example_strings,
jboolean learn,
jlong vwPtr,
const F& predictor)
{ vw* vwInstance = (vw*)vwPtr;
template <typename T, typename F>
T base_predict(JNIEnv* env, jobjectArray example_strings, jboolean learn, jlong vwPtr, const F& predictor)
{
vw* vwInstance = (vw*)vwPtr;
int example_count = env->GetArrayLength(example_strings);
multi_ex ex_coll; // When doing multiline prediction the final result is stored in the FIRST example parsed.
multi_ex ex_coll; // When doing multiline prediction the final result is stored in the FIRST example parsed.
example* first_example = NULL;
for (int i=0; i<example_count; i++)
{ jstring example_string = (jstring) (env->GetObjectArrayElement(example_strings, i));
for (int i = 0; i < example_count; i++)
{
jstring example_string = (jstring)(env->GetObjectArrayElement(example_strings, i));
example* ex = read_example(env, example_string, vwInstance);
ex_coll.push_back(ex);
if (i == 0)
@@ -75,18 +63,20 @@ T base_predict(
env->DeleteLocalRef(example_strings);

try
{ if (learn)
{
if (learn)
vwInstance->learn(ex_coll);
else
vwInstance->predict(ex_coll);
}
catch (...)
{ rethrow_cpp_exception_as_java_exception(env);
{
rethrow_cpp_exception_as_java_exception(env);
}

vwInstance->finish_example(ex_coll);

return predictor(first_example, env);
}

#endif // VW_BASE_LEARNER_H
#endif // VW_BASE_LEARNER_H

0 comments on commit fa5c3df

Please sign in to comment.
You can’t perform that action at this time.