forked from mlpack/mlpack
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #9 from Yashwants19/Generate-cpp/R
Generate cpp/r
- Loading branch information
Showing
85 changed files
with
1,700 additions
and
695 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
# FindRModule.cmake: find a specific R module. | ||
function(find_r_module module) | ||
string(TOUPPER ${module} module_upper) | ||
if (NOT R_${module_upper}) | ||
if (ARGC GREATER 1) | ||
# Not required but we have version constraints. | ||
set(VERSION_REQ ${ARGV1}) | ||
endif () | ||
# A module's location is usually a directory, but for binary modules | ||
# it's a .so file. | ||
execute_process(COMMAND "Rscript" "-e" "find.package('${module}')" | ||
RESULT_VARIABLE _${module}_status | ||
OUTPUT_VARIABLE _${module}_location | ||
ERROR_QUIET OUTPUT_STRIP_TRAILING_WHITESPACE) | ||
|
||
if (NOT _${module}_status) | ||
# Now we have to check the version. | ||
if (VERSION_REQ) | ||
execute_process(COMMAND "Rscript" "-e" "packageVersion('${module}')" | ||
RESULT_VARIABLE _version_status | ||
OUTPUT_VARIABLE _version_compare | ||
OUTPUT_STRIP_TRAILING_WHITESPACE) | ||
|
||
string(REGEX MATCHALL "‘[0-9._]*’" _version_compare "${_version_compare}") | ||
string(REGEX REPLACE "‘" "" _version_compare "${_version_compare}") | ||
string(REGEX REPLACE "’" "" _version_compare "${_version_compare}") | ||
if ("${_version_compare}" GREATER_EQUAL "${VERSION_REQ}") | ||
set(R_${module_upper} ${_${module}_location} CACHE STRING | ||
"Location of R module ${module}") | ||
else () | ||
message(WARNING "Unsuitable version of R module ${module} (${VERSION_REQ} or greater required).") | ||
endif () | ||
else () | ||
# No version requirement so we are done. | ||
set(R_${module_upper} ${_${module}_location} CACHE STRING "Location of R module ${module}") | ||
endif () | ||
endif () | ||
endif () | ||
|
||
include(FindPackageHandleStandardArgs) | ||
find_package_handle_standard_args(R_${module} DEFAULT_MSG R_${module_upper}) | ||
endfunction () |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
# AppendSerialization.cmake: append imports for serialization and | ||
# deserialization for mlpack model types to the existing list of serialization | ||
# and deserialization imports. | ||
|
||
# This function depends on the following variables being set: | ||
# | ||
# * PROGRAM_MAIN_FILE: the file containing the mlpackMain() function. | ||
# * SERIALIZATION_FILE: file to append types to | ||
# * SERIALIZE: It is of bool type. If SERIALIZE is true we have to print | ||
# Serialize, else Deserialize. | ||
# | ||
function(append_serialization SERIALIZATION_FILE PROGRAM_MAIN_FILE SERIALIZE) | ||
include("${CMAKE_SOURCE_DIR}/CMake/StripType.cmake") | ||
strip_type("${PROGRAM_MAIN_FILE}") | ||
|
||
list(LENGTH MODEL_TYPES NUM_MODEL_TYPES) | ||
if (${NUM_MODEL_TYPES} GREATER 0) | ||
math(EXPR LOOP_MAX "${NUM_MODEL_TYPES}-1") | ||
foreach (INDEX RANGE ${LOOP_MAX}) | ||
list(GET MODEL_TYPES ${INDEX} MODEL_TYPE) | ||
list(GET MODEL_SAFE_TYPES ${INDEX} MODEL_SAFE_TYPE) | ||
file(READ "${SERIALIZATION_FILE}" SERIALIZATION_FILE_CONTENTS) | ||
if (SERIALIZE) | ||
# See if the model type already exists. | ||
string(FIND | ||
"${SERIALIZATION_FILE_CONTENTS}" | ||
"\"${MODEL_SAFE_TYPE}\" = Serialize${MODEL_SAFE_TYPE}Ptr," | ||
FIND_OUT) | ||
|
||
# If it doesn't exist, append it. | ||
if (${FIND_OUT} EQUAL -1) | ||
# Now append the type to the list of types, and define any serialization | ||
# function. | ||
file(APPEND | ||
"${SERIALIZATION_FILE}" | ||
" \"${MODEL_SAFE_TYPE}\" = Serialize${MODEL_SAFE_TYPE}Ptr,\n") | ||
endif() | ||
elseif (NOT SERIALIZE) | ||
# See if the model type already exists. | ||
string(FIND | ||
"${SERIALIZATION_FILE_CONTENTS}" | ||
"\"${MODEL_SAFE_TYPE}\" = Deserialize${MODEL_SAFE_TYPE}Ptr," | ||
FIND_OUT) | ||
|
||
# If it doesn't exist, append it. | ||
if (${FIND_OUT} EQUAL -1) | ||
# Now append the type to the list of types, and define any deserialization | ||
# function. | ||
file(APPEND | ||
"${SERIALIZATION_FILE}" | ||
" \"${MODEL_SAFE_TYPE}\" = Deserialize${MODEL_SAFE_TYPE}Ptr,\n") | ||
endif() | ||
endif() | ||
endforeach () | ||
endif() | ||
endfunction() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# ConfigureRCPP.cmake: generate an mlpack .cpp file for a R binding given | ||
# input arguments. | ||
# | ||
# This file depends on the following variables being set: | ||
# | ||
# * PROGRAM_NAME: name of the binding | ||
# * PROGRAM_MAIN_FILE: the file containing the mlpackMain() function. | ||
# * R_CPP_IN: path of the r_method.cpp.in file. | ||
# * R_CPP_OUT: name of the output .cpp file. | ||
include("${SOURCE_DIR}/CMake/StripType.cmake") | ||
strip_type("${PROGRAM_MAIN_FILE}") | ||
|
||
file(READ "${MODEL_FILE}" MODEL_FILE_TYPE) | ||
if (NOT (MODEL_FILE_TYPE MATCHES "\"${MODEL_SAFE_TYPES}\"")) | ||
file(APPEND "${MODEL_FILE}" "\"${MODEL_SAFE_TYPES}\"\n") | ||
# Now, generate the implementation of the functions we need. | ||
set(MODEL_PTR_IMPLS "") | ||
list(LENGTH MODEL_TYPES NUM_MODEL_TYPES) | ||
# Append content to the list. | ||
if (${NUM_MODEL_TYPES} GREATER 0) | ||
math(EXPR LOOP_MAX "${NUM_MODEL_TYPES}-1") | ||
foreach (INDEX RANGE ${LOOP_MAX}) | ||
list(GET MODEL_TYPES ${INDEX} MODEL_TYPE) | ||
list(GET MODEL_SAFE_TYPES ${INDEX} MODEL_SAFE_TYPE) | ||
|
||
# Define typedef for the model. | ||
set(MODEL_PTR_TYPEDEF "${MODEL_PTR_TYPEDEF}Rcpp::XPtr<${MODEL_TYPE}>") | ||
|
||
# Generate the implementation. | ||
set(MODEL_PTR_IMPLS "${MODEL_PTR_IMPLS} | ||
// Get the pointer to a ${MODEL_TYPE} parameter. | ||
// [[Rcpp::export]] | ||
SEXP CLI_GetParam${MODEL_SAFE_TYPE}Ptr(const std::string& paramName) | ||
{ | ||
return std::move((${MODEL_PTR_TYPEDEF}) CLI::GetParam<${MODEL_TYPE}*>(paramName)); | ||
} | ||
// Set the pointer to a ${MODEL_TYPE} parameter. | ||
// [[Rcpp::export]] | ||
void CLI_SetParam${MODEL_SAFE_TYPE}Ptr(const std::string& paramName, SEXP ptr) | ||
{ | ||
CLI::GetParam<${MODEL_TYPE}*>(paramName) = Rcpp::as<${MODEL_PTR_TYPEDEF}>(ptr); | ||
CLI::SetPassed(paramName); | ||
} | ||
// Serialize a ${MODEL_TYPE} pointer. | ||
// [[Rcpp::export]] | ||
Rcpp::RawVector Serialize${MODEL_SAFE_TYPE}Ptr(SEXP ptr) | ||
{ | ||
std::ostringstream oss; | ||
{ | ||
boost::archive::binary_oarchive oa(oss); | ||
oa << boost::serialization::make_nvp(\"${MODEL_SAFE_TYPE}\", | ||
*Rcpp::as<${MODEL_PTR_TYPEDEF}>(ptr)); | ||
} | ||
Rcpp::RawVector raw_vec(oss.str().size()); | ||
// Copy the string buffer so we can return one that won't get deallocated when | ||
// we exit this function. | ||
memcpy(&raw_vec[0], oss.str().c_str(), oss.str().size()); | ||
raw_vec.attr(\"type\") = \"${MODEL_SAFE_TYPE}\"; | ||
return raw_vec; | ||
} | ||
// Deserialize a ${MODEL_TYPE} pointer. | ||
// [[Rcpp::export]] | ||
SEXP Deserialize${MODEL_SAFE_TYPE}Ptr(Rcpp::RawVector str) | ||
{ | ||
${MODEL_TYPE}* ptr = new ${MODEL_TYPE}(); | ||
std::istringstream iss(std::string((char *) &str[0], str.size())); | ||
{ | ||
boost::archive::binary_iarchive ia(iss); | ||
ia >> boost::serialization::make_nvp(\"${MODEL_SAFE_TYPE}\", *ptr); | ||
} | ||
// R will be responsible for freeing this. | ||
return std::move((${MODEL_PTR_TYPEDEF})ptr); | ||
} | ||
") | ||
endforeach () | ||
endif() | ||
endif() | ||
|
||
# Now configure the files. | ||
configure_file("${R_CPP_IN}" "${R_CPP_OUT}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# StripType.cmake: Extract ModelType from the main file and turn it into | ||
# something that has no special characters that can simply be used. | ||
|
||
# This function depends on the following variables being set: | ||
# | ||
# * PROGRAM_MAIN_FILE: the file containing the mlpackMain() function. | ||
# | ||
function(strip_type PROGRAM_MAIN_FILE) | ||
# We need to parse the main file and find any PARAM_MODEL_* lines. | ||
file(READ "${PROGRAM_MAIN_FILE}" MAIN_FILE) | ||
|
||
# Grab all "PARAM_MODEL_IN(Model,", "PARAM_MODEL_IN_REQ(Model,", | ||
# "PARAM_MODEL_OUT(Model,". | ||
string(REGEX MATCHALL "PARAM_MODEL_IN\\([A-Za-z_<>]*," MODELS_IN | ||
"${MAIN_FILE}") | ||
string(REGEX MATCHALL "PARAM_MODEL_IN_REQ\\([A-Za-z_<>]*," MODELS_IN_REQ | ||
"${MAIN_FILE}") | ||
string(REGEX MATCHALL "PARAM_MODEL_OUT\\([A-Za-z_]*," MODELS_OUT "${MAIN_FILE}") | ||
|
||
string(REGEX REPLACE "PARAM_MODEL_IN\\(" "" MODELS_IN_STRIP1 "${MODELS_IN}") | ||
string(REGEX REPLACE "," "" MODELS_IN_STRIP2 "${MODELS_IN_STRIP1}") | ||
string(REGEX REPLACE "[<>,]" "" MODELS_IN_SAFE_STRIP2 "${MODELS_IN_STRIP1}") | ||
|
||
string(REGEX REPLACE "PARAM_MODEL_IN_REQ\\(" "" MODELS_IN_REQ_STRIP1 | ||
"${MODELS_IN_REQ}") | ||
string(REGEX REPLACE "," "" MODELS_IN_REQ_STRIP2 "${MODELS_IN_REQ_STRIP1}") | ||
string(REGEX REPLACE "[<>,]" "" MODELS_IN_REQ_SAFE_STRIP2 | ||
"${MODELS_IN_REQ_STRIP1}") | ||
|
||
string(REGEX REPLACE "PARAM_MODEL_OUT\\(" "" MODELS_OUT_STRIP1 "${MODELS_OUT}") | ||
string(REGEX REPLACE "," "" MODELS_OUT_STRIP2 "${MODELS_OUT_STRIP1}") | ||
string(REGEX REPLACE "[<>,]" "" MODELS_OUT_SAFE_STRIP2 "${MODELS_OUT_STRIP1}") | ||
|
||
set(MODEL_TYPES ${MODELS_IN_STRIP2} ${MODELS_IN_REQ_STRIP2} | ||
${MODELS_OUT_STRIP2}) | ||
set(MODEL_SAFE_TYPES ${MODELS_IN_SAFE_STRIP2} ${MODELS_IN_REQ_SAFE_STRIP2} | ||
${MODELS_OUT_SAFE_STRIP2}) | ||
if (MODEL_TYPES) | ||
list(REMOVE_DUPLICATES MODEL_TYPES) | ||
endif () | ||
if (MODEL_SAFE_TYPES) | ||
list(REMOVE_DUPLICATES MODEL_SAFE_TYPES) | ||
endif () | ||
|
||
set(MODEL_TYPES ${MODEL_TYPES} PARENT_SCOPE) | ||
set(MODEL_SAFE_TYPES ${MODEL_SAFE_TYPES} PARENT_SCOPE) | ||
endfunction() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.