Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #2 from tqchen/R
Browse files Browse the repository at this point in the history
[R] MINOR Change
  • Loading branch information
Qiang Kou (KK) committed Oct 6, 2015
2 parents 39ac76b + 90b7370 commit 6781426
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 16 deletions.
11 changes: 8 additions & 3 deletions R-package/src/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@
#define MXNET_RCPP_BASE_H_

#include <Rcpp.h>
#include <dmlc/base.h>
#include <mxnet/c_api.h>
// to be removed
#include <dmlc/logging.h>

namespace mxnet {
namespace R {

// change to Rcpp::cerr later, for compatiblity of older version for now
#define RLOG_FATAL LOG(FATAL)
#define RLOG_FATAL ::Rcpp::Rcerr

// checking macro for R side
#define RCHECK(x) \
if (!(x)) \
RLOG_FATAL << "Check " \
"failed: " #x << ' '

/*!
* \brief protected MXNet C API call, report R error if happens.
Expand Down
15 changes: 6 additions & 9 deletions R-package/src/ndarray.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
#include <Rcpp.h>
#include <dmlc/base.h>
#include "./base.h"
#include "./ndarray.h"

Expand Down Expand Up @@ -41,7 +40,7 @@ SEXP NDArray::Load(const std::string& filename) {
&out_name_size, &out_names));
Rcpp::List out(out_size);
for (mx_uint i = 0; i < out_size; ++i) {
out[i] = Rcpp::XPtr<NDArray>(new NDArray(out_arr[i]));
out[i] = NDArray::RObject(out_arr[i]);
}
if (out_name_size != 0) {
std::vector<std::string> lst_names(out_size);
Expand Down Expand Up @@ -105,9 +104,8 @@ NDArrayFunction::NDArrayFunction(FunctionHandle handle)

SEXP NDArrayFunction::operator() (SEXP* args) {
BEGIN_RCPP;
if (!accept_empty_out_) {
RLOG_FATAL << "not yet support mutate target";
}
RCHECK(accept_empty_out_)
<< "not yet support mutate target";
NDArrayHandle ohandle;
MX_CALL(MXNDArrayCreateNone(&ohandle));
std::vector<mx_float> scalars(num_scalars_);
Expand All @@ -124,7 +122,7 @@ SEXP NDArrayFunction::operator() (SEXP* args) {
dmlc::BeginPtr(use_vars),
dmlc::BeginPtr(scalars),
&ohandle));
return Rcpp::XPtr<NDArray>(new NDArray(ohandle));
return NDArray::RObject(ohandle);
END_RCPP;
}

Expand All @@ -137,9 +135,8 @@ void NDArray::InitRcppModule() {

void NDArrayFunction::InitRcppModule() {
Rcpp::Module* scope = ::getCurrentScope();
if (scope == NULL) {
RLOG_FATAL << "Init Module need to be called inside scope";
}
RCHECK(scope != NULL)
<< "Init Module need to be called inside scope";
mx_uint out_size;
FunctionHandle *arr;
MX_CALL(MXListFunctions(&out_size, &arr));
Expand Down
18 changes: 14 additions & 4 deletions R-package/src/ndarray.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ class NDArray {
/*! \brief default constructor */
NDArray() {}
/*!
* \brief construct NDArray from handle
* \brief create a R object that correspond to the NDArray
* \param handle the NDArrayHandle needed for output.
* \param writable Whether the NDArray is writable or not.
*/
explicit NDArray(NDArrayHandle handle,
bool writable = true)
: handle_(handle), writable_(writable) {}
static SEXP RObject(NDArrayHandle handle, bool writable = true) {
NDArray *nd = new NDArray();
nd->handle_ = handle;
nd->writable_ = writable;
// will call destructor after finalize
return Rcpp::XPtr<NDArray>(nd, true);
}
/*!
* \brief Load a list of ndarray from the file.
* \param filename the name of the file.
Expand All @@ -42,6 +46,12 @@ class NDArray {
/*! \brief static function to initialize the Rcpp functions */
static void InitRcppModule();

/*! \brief destructor */
~NDArray() {
// free the handle
MX_CALL(MXNDArrayFree(handle_));
}

private:
// declare friend class
friend class NDArrayFunction;
Expand Down

0 comments on commit 6781426

Please sign in to comment.