Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[Numpy] FFI for cumsum and add (#17747)
Browse files Browse the repository at this point in the history
* FFI cumsum

* Dispatch ufunc

* Add PythonArg

* Remove unused data type

* Seperate op_utils and utils
  • Loading branch information
haojin2 authored Mar 5, 2020
1 parent aba4008 commit 938b35b
Show file tree
Hide file tree
Showing 24 changed files with 629 additions and 146 deletions.
13 changes: 5 additions & 8 deletions include/mxnet/runtime/c_runtime_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,14 +47,11 @@ typedef enum {
kNull = 4U,
kMXNetType = 5U,
kMXNetContext = 6U,
kArrayHandle = 7U,
kObjectHandle = 8U,
kModuleHandle = 9U,
kFuncHandle = 10U,
kStr = 11U,
kBytes = 12U,
kNDArrayContainer = 13U,
kNDArrayHandle = 14U,
kObjectHandle = 7U,
kStr = 8U,
kBytes = 9U,
kPyArg = 10U,
kNDArrayHandle = 11U,
// Extension codes for other frameworks to integrate MXNet PackedFunc.
// To make sure each framework's id do not conflict, use first and
// last sections to mark ranges.
Expand Down
36 changes: 6 additions & 30 deletions include/mxnet/runtime/packed_func.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
#include <mxnet/runtime/container.h>
#include <mxnet/runtime/ffi_helper.h>
#include <mxnet/runtime/data_type.h>
#include <mxnet/runtime/py_arg.h>
#include <mxnet/node/container.h>
#include <mxnet/ir/expr.h>
#include <mxnet/ndarray.h>
Expand Down Expand Up @@ -416,7 +417,6 @@ class MXNetPODValue_ {
}
operator void*() const {
if (type_code_ == kNull) return nullptr;
if (type_code_ == kArrayHandle) return value_.v_handle;
MXNET_CHECK_TYPE_CODE(type_code_, kHandle);
return value_.v_handle;
}
Expand Down Expand Up @@ -520,11 +520,6 @@ class MXNetArgValue : public MXNetPODValue_ {
MXNET_CHECK_TYPE_CODE(type_code_, kNDArrayHandle);
return reinterpret_cast<::mxnet::NDArray*>(value_.v_handle);
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -597,11 +592,6 @@ class MXNetRetValue : public MXNetPODValue_ {
operator MXNetDataType() const {
return MXNetDataType(operator DLDataType());
}
operator PackedFunc() const {
if (type_code_ == kNull) return PackedFunc();
MXNET_CHECK_TYPE_CODE(type_code_, kFuncHandle);
return *ptr<PackedFunc>();
}
template<typename FType>
operator TypedPackedFunc<FType>() const {
return TypedPackedFunc<FType>(operator PackedFunc());
Expand Down Expand Up @@ -668,10 +658,6 @@ class MXNetRetValue : public MXNetPODValue_ {
SwitchToObject(kObjectHandle, std::move(other));
return *this;
}
MXNetRetValue& operator=(PackedFunc f) {
this->SwitchToClass(kFuncHandle, f);
return *this;
}
template<typename FType>
MXNetRetValue& operator=(const TypedPackedFunc<FType>& f) {
return operator=(f.packed());
Expand All @@ -689,6 +675,11 @@ class MXNetRetValue : public MXNetPODValue_ {
value_.v_handle = reinterpret_cast<void*>(value);
return *this;
}
MXNetRetValue& operator=(const PythonArg& value) {
this->SwitchToPOD(kPyArg);
value_.v_int64 = value.offset();
return *this;
}
template<typename T,
typename = typename std::enable_if<
extension_type_info<T>::code != 0>::type>
Expand Down Expand Up @@ -717,7 +708,6 @@ class MXNetRetValue : public MXNetPODValue_ {
/*! \return The value field, if the data is POD */
const MXNetValue& value() const {
CHECK(type_code_ != kObjectHandle &&
type_code_ != kFuncHandle &&
type_code_ != kStr) << "MXNetRetValue.value can only be used for POD data";
return value_;
}
Expand All @@ -741,10 +731,6 @@ class MXNetRetValue : public MXNetPODValue_ {
SwitchToClass<std::string>(kBytes, other);
break;
}
case kFuncHandle: {
SwitchToClass<PackedFunc>(kFuncHandle, other);
break;
}
case kObjectHandle: {
*this = other.operator ObjectRef();
break;
Expand Down Expand Up @@ -792,7 +778,6 @@ class MXNetRetValue : public MXNetPODValue_ {
if (type_code_ == kNull) return;
switch (type_code_) {
case kStr: delete ptr<std::string>(); break;
case kFuncHandle: delete ptr<PackedFunc>(); break;
case kObjectHandle: {
static_cast<Object*>(value_.v_handle)->DecRef();
break;
Expand Down Expand Up @@ -857,7 +842,6 @@ inline const char* TypeCode2Str(int type_code) {
case kBytes: return "bytes";
case kHandle: return "handle";
case kNull: return "NULL";
case kFuncHandle: return "FunctionHandle";
case kObjectHandle: return "ObjectCell";
default: LOG(FATAL) << "unknown type_code="
<< static_cast<int>(type_code); return "";
Expand Down Expand Up @@ -1012,10 +996,6 @@ class MXNetArgsSetter {
values_[i].v_handle = value;
type_codes_[i] = kHandle;
}
void operator()(size_t i, DLTensor* value) const {
values_[i].v_handle = value;
type_codes_[i] = kArrayHandle;
}
void operator()(size_t i, const char* value) const {
values_[i].v_str = value;
type_codes_[i] = kStr;
Expand All @@ -1038,10 +1018,6 @@ class MXNetArgsSetter {
values_[i].v_handle = const_cast<MXNetByteArray*>(&value);
type_codes_[i] = kBytes;
}
void operator()(size_t i, const PackedFunc& value) const { // NOLINT(*)
values_[i].v_handle = const_cast<PackedFunc*>(&value);
type_codes_[i] = kFuncHandle;
}
template<typename FType>
void operator()(size_t i, const TypedPackedFunc<FType>& value) const { // NOLINT(*)
operator()(i, value.packed());
Expand Down
42 changes: 42 additions & 0 deletions include/mxnet/runtime/py_arg.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
/*
* \file py_arg.h
* \brief Python runtime arguments specifier.
*/
#ifndef MXNET_RUNTIME_PY_ARG_H_
#define MXNET_RUNTIME_PY_ARG_H_

namespace mxnet {
namespace runtime {

class PythonArg {
public:
explicit PythonArg(int offset): offset_(offset) {}
int offset() const {
return offset_;
}
private:
int offset_;
};

} // namespace runtime

} // namespace mxnet
#endif // MXNET_RUNTIME_PY_ARG_H_
7 changes: 6 additions & 1 deletion python/mxnet/_ffi/_ctypes/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"""
import ctypes
from numbers import Number, Integral
import numpy as onp

from ...base import get_last_ffi_error, _LIB
from ..base import c_str
Expand Down Expand Up @@ -66,6 +67,9 @@ def _make_mxnet_args(args, temp_args):
elif isinstance(arg, ctypes.c_void_p):
values[i].v_handle = arg
type_codes[i] = TypeCode.HANDLE
elif isinstance(arg, type):
values[i].v_str = c_str(onp.dtype(arg).name)
type_codes[i] = TypeCode.STR
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return values, type_codes, num_args
Expand Down Expand Up @@ -110,7 +114,8 @@ def __call__(self, *args):
raise get_last_ffi_error()
_ = temp_args
_ = args
return RETURN_SWITCH[ret_tcode.value](ret_val)
return (RETURN_SWITCH[ret_tcode.value](ret_val) if ret_tcode.value != TypeCode.PYARG
else RETURN_SWITCH[ret_tcode.value](ret_val, args))


_CLASS_OBJECT = None
Expand Down
16 changes: 7 additions & 9 deletions python/mxnet/_ffi/_ctypes/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,11 @@ class TypeCode(object):
NULL = 4
MXNET_TYPE = 5
MXNET_CONTEXT = 6
ARRAY_HANDLE = 7
OBJECT_HANDLE = 8
MODULE_HANDLE = 9
FUNC_HANDLE = 10
STR = 11
BYTES = 12
NDARRAY_CONTAINER = 13
NDARRAYHANDLE = 14
OBJECT_HANDLE = 7
STR = 8
BYTES = 9
PYARG = 10
NDARRAYHANDLE = 11
EXT_BEGIN = 15


Expand All @@ -54,5 +51,6 @@ class MXNetValue(ctypes.Union):
TypeCode.INT: lambda x: x.v_int64,
TypeCode.FLOAT: lambda x: x.v_float64,
TypeCode.NULL: lambda x: None,
TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle))
TypeCode.NDARRAYHANDLE: lambda x: _global_var._np_ndarray_cls(handle=NDArrayHandle(x.v_handle)),
TypeCode.PYARG: lambda x, args: args[x.v_int64],
}
13 changes: 5 additions & 8 deletions python/mxnet/_ffi/_cython/base.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,11 @@ cdef enum MXNetTypeCode:
kNull = 4
kMXNetType = 5
kMXNetContext = 6
kArrayHandle = 7
kObjectHandle = 8
kModuleHandle = 9
kFuncHandle = 10
kStr = 11
kBytes = 12
kNDArrayContainer = 13
kNDArrayHandle = 14
kObjectHandle = 7
kStr = 8
kBytes = 9
kPyArg = 10
kNDArrayHandle = 11
kExtBegin = 15

cdef extern from "mxnet/runtime/c_runtime_api.h":
Expand Down
18 changes: 13 additions & 5 deletions python/mxnet/_ffi/_cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"""Acknowledgement: This file originates from incubator-tvm"""

import ctypes
import numpy as onp
import traceback
from ...ndarray._internal import NDArrayBase
from numbers import Number, Integral
Expand Down Expand Up @@ -58,14 +59,23 @@ cdef inline int make_arg(object arg,
elif isinstance(arg, ctypes.c_void_p):
value[0].v_handle = c_handle(arg)
tcode[0] = kHandle
elif isinstance(arg, type):
tstr = c_str(onp.dtype(arg).name)
value[0].v_str = tstr
tcode[0] = kStr
temp_args.append(tstr)
else:
raise TypeError("Don't know how to handle type %s" % type(arg))
return 0


cdef inline object make_ret(MXNetValue value, int tcode):
cdef inline object make_ret(MXNetValue value, int tcode, tuple args):
"""convert result to return value."""
if tcode == kNull:
if tcode == kNDArrayHandle:
return c_make_array(value.v_handle)
elif tcode == kPyArg:
return args[value.v_int64]
elif tcode == kNull:
return None
elif tcode == kInt:
return value.v_int64
Expand All @@ -75,8 +85,6 @@ cdef inline object make_ret(MXNetValue value, int tcode):
return py_str(value.v_str)
elif tcode == kHandle:
return ctypes_handle(value.v_handle)
elif tcode == kNDArrayHandle:
return c_make_array(value.v_handle)
raise ValueError("Unhandled type code %d" % tcode)


Expand Down Expand Up @@ -160,4 +168,4 @@ cdef class FunctionBase:
cdef MXNetValue ret_val
cdef int ret_tcode
FuncCall(self.chandle, args, &ret_val, &ret_tcode)
return make_ret(ret_val, ret_tcode)
return make_ret(ret_val, ret_tcode, args)
51 changes: 0 additions & 51 deletions python/mxnet/_numpy_op_doc.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,57 +134,6 @@ def _np_sometrue(a, axis=None, keepdims=False, out=None):
pass


def _np_cumsum(a, axis=None, dtype=None, out=None):
"""
Return the cumulative sum of the elements along a given axis.
Parameters
----------
a : array_like
Input array.
axis : int, optional
Axis along which the cumulative sum is computed. The default
(None) is to compute the cumsum over the flattened array.
dtype : dtype, optional
Type of the returned array and of the accumulator in which the
elements are summed. If `dtype` is not specified, it defaults
to the dtype of `a`, unless `a` has an integer dtype with a
precision less than that of the default platform integer. In
that case, the default platform integer is used.
out : ndarray, optional
Alternative output array in which to place the result. It must
have the same shape and buffer length as the expected output
but the type will be cast if necessary. See `doc.ufuncs`
(Section "Output arguments") for more details.
Returns
-------
cumsum_along_axis : ndarray.
A new array holding the result is returned unless `out` is
specified, in which case a reference to `out` is returned. The
result has the same size as `a`, and the same shape as `a` if
`axis` is not None or `a` is a 1-d array.
Examples
--------
>>> a = np.array([[1,2,3], [4,5,6]])
>>> a
array([[1, 2, 3],
[4, 5, 6]])
>>> np.cumsum(a)
array([ 1, 3, 6, 10, 15, 21])
>>> np.cumsum(a, dtype=float) # specifies type of output value(s)
array([ 1., 3., 6., 10., 15., 21.])
>>> np.cumsum(a,axis=0) # sum over rows for each of the 3 columns
array([[1, 2, 3],
[5, 7, 9]])
>>> np.cumsum(a,axis=1) # sum over columns for each of the 2 rows
array([[ 1, 3, 6],
[ 4, 9, 15]])
"""
pass


def _npx_nonzero(a):
"""
Return the indices of the elements that are non-zero.
Expand Down
Loading

0 comments on commit 938b35b

Please sign in to comment.