diff --git a/setup.py b/setup.py index 086f60e..b9c1c4c 100644 --- a/setup.py +++ b/setup.py @@ -33,7 +33,7 @@ print("**could not install pip or setuptools_scm, version is defaulted") def myversion(): - version = '2.0.22' + version = '2.0.23' try: mversion = get_version() s = mversion.split('.') @@ -118,6 +118,7 @@ def read(*names, **kwargs): 'Programming Language :: Python :: 3.6', 'Programming Language :: Python :: 3.7', 'Programming Language :: Python :: 3.8', + 'Programming Language :: Python :: 3.9', 'Programming Language :: Python :: Implementation :: CPython', 'Programming Language :: Python :: Implementation :: PyPy', 'Topic :: Utilities', diff --git a/src/pnumpy/_pnumpy.cpp b/src/pnumpy/_pnumpy.cpp index 2e04c64..b73b4c1 100644 --- a/src/pnumpy/_pnumpy.cpp +++ b/src/pnumpy/_pnumpy.cpp @@ -345,7 +345,7 @@ stArangeFunc g_ArangeFuncLUT[ATOP_LAST]; // set to 0 to disable -stSettings g_Settings = { 1, 0, 0, 0, 0 }; +stSettings g_Settings = { 1, 0, 0, 0, 0, 0, 0 }; // Macro used just before call a ufunc #define LEDGER_START() g_Settings.LedgerEnabled = 0; int64_t ledgerStartTime = __rdtsc(); @@ -1555,7 +1555,6 @@ PyObject* newinit(PyObject* self, PyObject* args, PyObject* kwargs) { // Allocate an array of 10 PyArrayObject* pTemp= AllocateNumpyArray(1, dims, srcdtype); - //PyArray_Descr* pSrcDtype = PyArray_DescrFromType(srcdtype); PyArray_Descr* pSrcDtype = PyArray_DESCR(pTemp); if (pSrcDtype) { @@ -1662,9 +1661,10 @@ PyObject* newinit(PyObject* self, PyObject* args, PyObject* kwargs) { // } //} } - //Py_DECREF(pSrcDtype); - - //Py_DECREF(arr); + //Py_DECREF(pSrcDtype); + //Py_DECREF(pTemp); + // + } RETURN_NONE; } @@ -1675,6 +1675,94 @@ PyObject* newinit(PyObject* self, PyObject* args, PyObject* kwargs) { RETURN_NONE; } +//--------------------------------------------------------- +// GetItem hook +// +extern "C" +PyObject * GetItemHook(PyObject* aValues, PyObject* aIndex) { + + // Quick check for an array or we bail + if (PyType_IsSubtype(aIndex->ob_type, &PyArray_Type)) { + int32_t numpyValuesType = PyArray_TYPE((PyArrayObject*)aValues); + int32_t numpyIndexType = PyArray_TYPE((PyArrayObject*)aIndex); + + PyObject* result = NULL; + if (numpyIndexType == NPY_BOOL) { + // special path for boolean + result = BooleanIndexInternal((PyArrayObject*)aValues, (PyArrayObject*)aIndex); + if (result) { + return result; + } + // clear error since punting + PyErr_Clear(); + + } else + // TODO: improve this to handle strings/unicode/datetime/other + if (numpyIndexType <= NPY_LONGDOUBLE) { + result = getitem(aValues, aIndex); + if (result) { + return result; + } + PyErr_Clear(); + } + } + return g_Settings.NumpyGetItem(aValues, aIndex); +} + +//--------------------------------------------------------- +// Call to get hook +extern "C" +PyObject * hook_enable(PyObject * self, PyObject * args) { + + if (g_Settings.NumpyGetItem == NULL) { + npy_intp dims[1] = { 10 }; + + // Allocate an array of 10 bools + PyArrayObject* pTemp = AllocateNumpyArray(1, dims, 0); + if (pTemp) { + struct _typeobject* pNumpyType = ((PyObject*)pTemp)->ob_type; + + // Not hooked yet + // PyNumberMethods* numbermethods = pNumpyType->tp_as_number; + // richcmpfunc comparefunc = pNumpyType->tp_richcompare; + // __setitem__ + // objobjargproc PyMappingMethods.mp_ass_subscript + + // __getitem__ + // Reroute hook + g_Settings.NumpyGetItem= pNumpyType->tp_as_mapping->mp_subscript; + pNumpyType->tp_as_mapping->mp_subscript = GetItemHook; + + Py_DECREF(pTemp); + RETURN_TRUE; + } + } + RETURN_FALSE; +} + +//--------------------------------------------------------- +// Call to remove previous hook +extern "C" +PyObject * hook_disable(PyObject * self, PyObject * args) { + if (g_Settings.NumpyGetItem != NULL) { + npy_intp dims[1] = { 10 }; + + // Allocate an array of 10 bools + PyArrayObject* pTemp = AllocateNumpyArray(1, dims, 0); + if (pTemp) { + struct _typeobject* pNumpyType = ((PyObject*)pTemp)->ob_type; + // __getitem__ + // Put hook back + pNumpyType->tp_as_mapping->mp_subscript = g_Settings.NumpyGetItem; + g_Settings.NumpyGetItem = NULL; + Py_DECREF(pTemp); + RETURN_TRUE; + } + } + RETURN_FALSE; +} + + extern "C" PyObject * atop_enable(PyObject * self, PyObject * args) { g_Settings.AtopEnabled = TRUE; diff --git a/src/pnumpy/common.h b/src/pnumpy/common.h index 16ec643..4a6412b 100644 --- a/src/pnumpy/common.h +++ b/src/pnumpy/common.h @@ -26,6 +26,8 @@ struct stSettings { int32_t RecyclerEnabled; int32_t ZigZag; // set to 0 to disable int32_t Initialized; + int32_t Reserved; + binaryfunc NumpyGetItem; // optional hook }; extern stSettings g_Settings; @@ -112,6 +114,9 @@ extern ArrayInfo* BuildArrayInfo( extern void FreeArrayInfo(ArrayInfo* pAlloc); +extern PyObject* BooleanIndexInternal(PyArrayObject* aValues, PyArrayObject* aIndex); +extern "C" PyObject *getitem(PyObject * self, PyObject * args); + #define RETURN_NONE Py_INCREF(Py_None); return Py_None; #define RETURN_FALSE Py_XINCREF(Py_False); return Py_False; #define RETURN_TRUE Py_XINCREF(Py_True); return Py_True; diff --git a/src/pnumpy/module_init.cpp b/src/pnumpy/module_init.cpp index 60b2656..d1da675 100644 --- a/src/pnumpy/module_init.cpp +++ b/src/pnumpy/module_init.cpp @@ -37,6 +37,9 @@ extern "C" PyObject* recycler_disable(PyObject * self, PyObject * args); extern "C" PyObject* recycler_isenabled(PyObject * self, PyObject * args); extern "C" PyObject* recycler_info(PyObject * self, PyObject * args); +extern "C" PyObject * hook_enable(PyObject * self, PyObject * args); +extern "C" PyObject * hook_disable(PyObject * self, PyObject * args); + extern "C" PyObject* timer_gettsc(PyObject * self, PyObject * args); extern "C" PyObject* timer_getutc(PyObject * self, PyObject * args); extern "C" PyObject* cpustring(PyObject * self, PyObject * args); @@ -69,6 +72,8 @@ static PyMethodDef module_functions[] = { {"thread_zigzag", (PyCFunction)thread_zigzag, METH_VARARGS, "toggle zigzag mode"}, {"timer_gettsc", (PyCFunction)timer_gettsc, METH_VARARGS, TIMER_GETTSC_DOC}, {"timer_getutc", (PyCFunction)timer_getutc, METH_VARARGS, TIMER_GETUTC_DOC}, + {"hook_enable", (PyCFunction)hook_enable, METH_VARARGS, "Enable hook for numpy array __getitem__ for fancy and bool indexing"}, + {"hook_disable", (PyCFunction)hook_disable, METH_VARARGS, "Disable hook for numpy array __getitem__ for fancy and bool indexing"}, {"ledger_enable", (PyCFunction)ledger_enable, METH_VARARGS, LEDGER_ENABLE_DOC}, {"ledger_disable", (PyCFunction)ledger_disable, METH_VARARGS, LEDGER_DISABLE_DOC}, {"ledger_isenabled", (PyCFunction)ledger_isenabled, METH_VARARGS, LEDGER_ISENABLED_DOC},