Permalink
Browse files

Return both IDs and distances from the Python search API

  • Loading branch information...
jonbakerfish committed Jul 19, 2015
1 parent ca62bfe commit 0c9ee8e765d1e4561787ef2432a1f551e2e99bb0
Showing with 36 additions and 8 deletions.
  1. +2 −0 .gitignore
  2. +34 −8 python/pykgraph.cpp
@@ -0,0 +1,2 @@
*.so
*.o
@@ -46,31 +46,55 @@ class KGraph {
python::object searchImpl (python::object const &data,
python::object const &query,
kgraph::KGraph::SearchParams params,
unsigned threads) {
unsigned threads,
bool withDistance) {
checkArray<TYPE>(data);
checkArray<TYPE>(query);
kgraph::MatrixProxy<TYPE> dmatrix(reinterpret_cast<PyArrayObject *>(data.ptr()));
kgraph::MatrixProxy<TYPE> qmatrix(reinterpret_cast<PyArrayObject *>(query.ptr()));
kgraph::MatrixOracle<TYPE, kgraph::metric::l2sqr> oracle(dmatrix);
npy_intp dims[] = {qmatrix.size(), params.K};
PyObject *result = PyArray_SimpleNew(2, dims, NPY_UINT32);
PyObject *distance = PyArray_SimpleNew(2, dims, NPY_FLOAT);
kgraph::MatrixProxy<unsigned, 1> rmatrix(reinterpret_cast<PyArrayObject *>(result));
kgraph::MatrixProxy<float, 1> distmatrix(reinterpret_cast<PyArrayObject *>(distance));
#ifdef _OPENMP
if (threads) ::omp_set_max_threads(threads);
#endif
if (hasIndex) {
#pragma omp parallel for reduction(+:cost)
for (unsigned i = 0; i < qmatrix.size(); ++i) {
index->search(oracle.query(qmatrix[i]), params, const_cast<unsigned *>(rmatrix[i]), NULL);
if (withDistance) {
index->search(oracle.query(qmatrix[i]), params, const_cast<unsigned *>(rmatrix[i]),
const_cast<float *>(distmatrix[i]),NULL);
}
else {
index->search(oracle.query(qmatrix[i]), params, const_cast<unsigned *>(rmatrix[i]), NULL);
}
}
}
else {
#pragma omp parallel for reduction(+:cost)
for (unsigned i = 0; i < qmatrix.size(); ++i) {
oracle.query(qmatrix[i]).search(params.K, params.epsilon, const_cast<unsigned *>(rmatrix[i]));
if (withDistance) {
oracle.query(qmatrix[i]).search(params.K, params.epsilon, const_cast<unsigned *>(rmatrix[i]),
const_cast<float *>(distmatrix[i]));
}
else {
oracle.query(qmatrix[i]).search(params.K, params.epsilon, const_cast<unsigned *>(rmatrix[i]), NULL);
}
}
}
return python::object(python::handle<>(result));
if (withDistance) {
PyObject* tup = PyTuple_New(2);
PyTuple_SetItem(tup,0,result);
PyTuple_SetItem(tup,1,distance);
return python::object(python::handle<>(tup));
}
else {
return python::object(python::handle<>(result));
}
}
public:
@@ -118,7 +142,8 @@ class KGraph {
unsigned P,
unsigned M,
unsigned T,
unsigned threads) {
unsigned threads,
bool withDistance) {
kgraph::KGraph::SearchParams params;
params.K = K;
params.P = P;
@@ -128,8 +153,8 @@ class KGraph {
PyArrayObject *pq = reinterpret_cast<PyArrayObject *>(data.ptr());
if (pd->descr->type_num != pq->descr->type_num) throw runtime_error("data and query have different types");
switch (pd->descr->type_num) {
case NPY_FLOAT: return searchImpl<float>(data, query, params, threads);
case NPY_DOUBLE: return searchImpl<double>(data, query, params, threads);
case NPY_FLOAT: return searchImpl<float>(data, query, params, threads, withDistance);
case NPY_DOUBLE: return searchImpl<double>(data, query, params, threads, withDistance);
}
throw runtime_error("data type not supported.");
return python::object();
@@ -160,7 +185,8 @@ BOOST_PYTHON_MODULE(pykgraph)
python::arg("P") = kgraph::default_P,
python::arg("M") = kgraph::default_M,
python::arg("T") = kgraph::default_T,
python::arg("threads") = 0))
python::arg("threads") = 0,
python::arg("withDistance") = false))
;
}

0 comments on commit 0c9ee8e

Please sign in to comment.