Permalink
Browse files

added S parameter to search

  • Loading branch information...
aaalgo committed Jun 17, 2016
1 parent c3a7f14 commit 953c36c60fe9bdd5a0dc35e22db2d30b68b086ce
Showing with 133 additions and 91 deletions.
  1. +107 −76 kgraph.cpp
  2. +4 −1 kgraph.h
  3. +3 −0 python/pykgraph.cpp
  4. +19 −14 search.cpp
View
@@ -68,6 +68,15 @@ namespace kgraph {
}
};
// extended neighbor structure for search time
struct NeighborX: public Neighbor {
uint16_t m;
uint16_t M; // actual M used
NeighborX () {}
NeighborX (unsigned i, float d): Neighbor(i, d, true), m(0), M(0) {
}
};
static inline bool operator < (Neighbor const &n1, Neighbor const &n2) {
return n1.dist < n2.dist;
}
@@ -159,7 +168,8 @@ namespace kgraph {
// Special case: K == 0
// addr[0] <- nn
// return 0
static inline unsigned UpdateKnnList (Neighbor *addr, unsigned K, Neighbor nn) {
template <typename NeighborT>
unsigned UpdateKnnListHelper (NeighborT *addr, unsigned K, NeighborT nn) {
// find the location to insert
unsigned j;
unsigned i = K;
@@ -186,6 +196,14 @@ namespace kgraph {
return i;
}
static inline unsigned UpdateKnnList (Neighbor *addr, unsigned K, Neighbor nn) {
return UpdateKnnListHelper<Neighbor>(addr, K, nn);
}
static inline unsigned UpdateKnnList (NeighborX *addr, unsigned K, NeighborX nn) {
return UpdateKnnListHelper<NeighborX>(addr, K, nn);
}
void LinearSearch (IndexOracle const &oracle, unsigned i, unsigned K, vector<Neighbor> *pnns) {
vector<Neighbor> nns(K+1);
unsigned N = oracle.size();
@@ -256,46 +274,12 @@ namespace kgraph {
vector<vector<Neighbor>> graph;
bool no_dist; // Distance & flag information in Neighbor is not valid.
void reverse (int rev_k) {
if (rev_k == 0) return;
if (no_dist) throw runtime_error("Need distance information to reverse graph");
vector<vector<Neighbor>> ng; // new graph adds on original one
{
cerr << "Graph completion with reverse edges..." << endl;
ng = graph;
progress_display progress(graph.size(), cerr);
for (unsigned i = 0; i < graph.size(); ++i) {
auto const &v = graph[i];
unsigned K = M[i];
if (rev_k > 0) {
K = rev_k;
if (K > v.size()) K = v.size();
}
//if (v.size() < XX) XX = v.size();
for (unsigned j = 0; j < K; ++j) {
auto const &e = v[j];
auto re = e;
re.id = i;
ng[e.id].push_back(re);
}
++progress;
}
graph.swap(ng);
}
{
cerr << "Reranking edges..." << endl;
progress_display progress(graph.size(), cerr);
#pragma omp parallel for
for (unsigned i = 0; i < graph.size(); ++i) {
auto &v = graph[i];
std::sort(v.begin(), v.end());
v.resize(std::unique(v.begin(), v.end()) - v.begin());
M[i] = v.size();
#pragma omp critical
++progress;
}
}
// actual M for a node that should be used in search time
unsigned actual_M (unsigned pM, unsigned i) const {
return std::min(std::max(M[i], pM), unsigned(graph[i].size()));
}
public:
virtual ~KGraphImpl () {
}
@@ -427,8 +411,10 @@ namespace kgraph {
}
return oracle.search(params.K, params.epsilon, ids, dists);
}
vector<Neighbor> knn(params.K + params.P +1);
vector<Neighbor> results;
vector<NeighborX> knn(params.K + params.P +1);
vector<NeighborX> results;
// flags access is totally random, so use small block to avoid
// extra memory access
boost::dynamic_bitset<> flags(graph.size(), false);
if (params.init && params.T > 1) {
@@ -460,47 +446,50 @@ namespace kgraph {
}
}
for (unsigned k = 0; k < L; ++k) {
flags[knn[k].id] = true;
knn[k].flag = true;
knn[k].dist = oracle(knn[k].id);
auto &e = knn[k];
flags[e.id] = true;
e.flag = true;
e.dist = oracle(e.id);
e.m = 0;
e.M = actual_M(params.M, e.id);
}
sort(knn.begin(), knn.begin() + L);
unsigned k = 0;
while (k < L) {
unsigned nk = L;
if (knn[k].flag) {
knn[k].flag = false;
unsigned cur = knn[k].id;
//BOOST_VERIFY(cur < graph.size());
unsigned maxM = M[cur];
if (params.M > maxM) maxM = params.M;
auto const &neighbors = graph[cur];
if (maxM > neighbors.size()) {
maxM = neighbors.size();
}
for (unsigned m = 0; m < maxM; ++m) {
unsigned id = neighbors[m].id;
//BOOST_VERIFY(id < graph.size());
if (flags[id]) continue;
flags[id] = true;
++n_comps;
float dist = oracle(id);
Neighbor nn(id, dist);
unsigned r = UpdateKnnList(&knn[0], L, nn);
BOOST_VERIFY(r <= L);
//if (r > L) continue;
if (L + 1 < knn.size()) ++L;
if (r < nk) {
nk = r;
}
}
auto &e = knn[k];
if (!e.flag) { // all neighbors of this node checked
++k;
continue;
}
if (nk <= k) {
k = nk;
unsigned beginM = e.m;
unsigned endM = beginM + params.S; // check this many entries
if (endM > e.M) { // we are done with this node
e.flag = false;
endM = e.M;
}
else {
++k;
e.m = endM;
// all modification to knn[k] must have been done now,
// as we might be relocating knn[k] in the loop below
auto const &neighbors = graph[e.id];
for (unsigned m = beginM; m < endM; ++m) {
unsigned id = neighbors[m].id;
//BOOST_VERIFY(id < graph.size());
if (flags[id]) continue;
flags[id] = true;
++n_comps;
float dist = oracle(id);
NeighborX nn(id, dist);
unsigned r = UpdateKnnList(&knn[0], L, nn);
BOOST_VERIFY(r <= L);
//if (r > L) continue;
if (L + 1 < knn.size()) ++L;
if (r < L) {
knn[r].M = actual_M(params.M, id);
if (r < k) {
k = r;
}
}
}
}
if (L > params.K) L = params.K;
@@ -657,6 +646,48 @@ namespace kgraph {
prune2();
}
}
void reverse (int rev_k) {
if (rev_k == 0) return;
if (no_dist) throw runtime_error("Need distance information to reverse graph");
{
cerr << "Graph completion with reverse edges..." << endl;
vector<vector<Neighbor>> ng(graph.size()); // new graph adds on original one
//ng = graph;
progress_display progress(graph.size(), cerr);
for (unsigned i = 0; i < graph.size(); ++i) {
auto const &v = graph[i];
unsigned K = M[i];
if (rev_k > 0) {
K = rev_k;
if (K > v.size()) K = v.size();
}
//if (v.size() < XX) XX = v.size();
for (unsigned j = 0; j < K; ++j) {
auto const &e = v[j];
auto re = e;
re.id = i;
ng[i].push_back(e);
ng[e.id].push_back(re);
}
++progress;
}
graph.swap(ng);
}
{
cerr << "Reranking edges..." << endl;
progress_display progress(graph.size(), cerr);
#pragma omp parallel for
for (unsigned i = 0; i < graph.size(); ++i) {
auto &v = graph[i];
std::sort(v.begin(), v.end());
v.resize(std::unique(v.begin(), v.end()) - v.begin());
M[i] = v.size();
#pragma omp critical
++progress;
}
}
}
};
class KGraphConstructor: public KGraphImpl {
View
@@ -112,13 +112,14 @@ namespace kgraph {
unsigned K;
unsigned M;
unsigned P;
unsigned S;
unsigned T;
float epsilon;
unsigned seed;
unsigned init;
/// Construct with default values.
SearchParams (): K(default_K), M(default_M), P(default_P), T(default_T), epsilon(default_epsilon), seed(1998), init(0) {
SearchParams (): K(default_K), M(default_M), P(default_P), S(default_S), T(default_T), epsilon(default_epsilon), seed(1998), init(0) {
}
};
@@ -219,6 +220,8 @@ namespace kgraph {
* @params L Actually returned number of neighbors, output only.
*/
virtual void get_nn (unsigned id, unsigned *nns, float *dists, unsigned *M, unsigned *L) const = 0;
virtual void reverse (int) = 0;
};
}
View
@@ -424,6 +424,7 @@ class KGraph {
unsigned K,
unsigned P,
unsigned M,
unsigned S,
unsigned T,
unsigned threads,
bool withDistance,
@@ -432,6 +433,7 @@ class KGraph {
params.K = K;
params.P = P;
params.M = M;
params.S = S;
params.T = T;
params.threads = threads;
params.withDistance = withDistance;
@@ -467,6 +469,7 @@ BOOST_PYTHON_MODULE(pykgraph)
python::arg("K") = kgraph::default_K,
python::arg("P") = kgraph::default_P,
python::arg("M") = kgraph::default_M,
python::arg("S") = kgraph::default_S,
python::arg("T") = kgraph::default_T,
python::arg("threads") = 0,
python::arg("withDistance") = false,
View
@@ -35,7 +35,7 @@ int main (int argc, char *argv[]) {
string output_path;
string init_path;
string eval_path;
unsigned K, M, P, T;
unsigned K, M, P, S, T;
po::options_description desc_visible("General options");
desc_visible.add_options()
@@ -50,6 +50,7 @@ int main (int argc, char *argv[]) {
(",K", po::value(&K)->default_value(default_K), "")
(",M", po::value(&M)->default_value(default_M), "")
(",P", po::value(&P)->default_value(default_P), "")
(",S", po::value(&S)->default_value(default_S), "")
(",T", po::value(&T)->default_value(default_T), "")
("linear", "")
("l2norm", "l2-normalize data, so as to mimic cosine similarity")
@@ -106,42 +107,46 @@ int main (int argc, char *argv[]) {
float cost = 0;
float time = 0;
if (vm.count("linear")) {
boost::timer::auto_cpu_timer timer;
result.resize(query.size(), K);
boost::progress_display progress(query.size(), cerr);
{ // to ensure auto_cpu_timer destruct before "time" computation
boost::timer::auto_cpu_timer timer;
#pragma omp parallel for
for (unsigned i = 0; i < query.size(); ++i) {
oracle.query(query[i]).search(K, default_epsilon, result[i]);
for (unsigned i = 0; i < query.size(); ++i) {
oracle.query(query[i]).search(K, default_epsilon, result[i]);
#pragma omp critical
++progress;
++progress;
}
time = timer.elapsed().wall / 1e9;
}
cost = 1.0;
time = timer.elapsed().wall / 1e9;
}
else {
result.resize(query.size(), K);
KGraph::SearchParams params;
params.K = K;
params.M = M;
params.P = P;
params.S = S;
params.T = T;
params.init = init;
KGraph *kgraph = KGraph::create();
kgraph->load(index_path.c_str());
boost::timer::auto_cpu_timer timer;
cerr << "Searching..." << endl;
boost::progress_display progress(query.size(), cerr);
{
boost::timer::auto_cpu_timer timer;
#pragma omp parallel for reduction(+:cost)
for (unsigned i = 0; i < query.size(); ++i) {
KGraph::SearchInfo info;
kgraph->search(oracle.query(query[i]), params, result[i], &info);
for (unsigned i = 0; i < query.size(); ++i) {
KGraph::SearchInfo info;
kgraph->search(oracle.query(query[i]), params, result[i], &info);
#pragma omp critical
++progress;
cost += info.cost;
++progress;
cost += info.cost;
}
time = timer.elapsed().wall / 1e9;
}
cost /= query.size();
time = timer.elapsed().wall / 1e9;
//cerr << "Cost: " << cost << endl;
delete kgraph;
}

0 comments on commit 953c36c

Please sign in to comment.