Permalink
Browse files

added prune.

  • Loading branch information...
aaalgo committed Apr 16, 2014
1 parent cedb52b commit 5d29654393c0e19e85e496f4280d7a83473f68bd
Showing with 232 additions and 90 deletions.
  1. +1 −1 Makefile
  2. +145 −77 kgraph.cpp
  3. +10 −1 kgraph.h
  4. +71 −10 prune.cpp
  5. +5 −1 search.cpp
View
@@ -12,7 +12,7 @@ NABO_LIBS+=-lnabo
COMMON=kgraph.o metric.o
HEADERS=kgraph.h kgraph-data.h
PROGS=index search split fvec2lshkit
PROGS=index search prune split fvec2lshkit
EXTRA_PROGS=test
FLANN_PROGS=flann_index flann_search
NABO_PROGS=nabo_search
View
@@ -318,71 +318,107 @@ namespace kgraph {
throw runtime_error("dataset larger than index");
}
vector<Neighbor> knn(params.K + params.P +1);
vector<Neighbor> results;
boost::dynamic_bitset<> flags(graph.size(), false);
unsigned L = params.init;
BOOST_VERIFY(L < params.K);
if (L) BOOST_VERIFY(ids);
for (unsigned l = 0; l < params.init; ++l) {
knn[l].id = ids[l];
if (params.init && params.T > 1) {
throw runtime_error("when init > 0, T must be 1.");
}
if (L == 0) {
unsigned seed = params.seed;
if (seed == 0) seed = time(NULL);
mt19937 rng(seed);
L = params.P;
vector<unsigned> random(L);
GenRandom(rng, &random[0], L, graph.size());
for (unsigned l = 0; l < L; ++l) {
knn[l].id = random[l];
}
}
for (unsigned k = 0; k < L; ++k) {
flags[knn[k].id] = false;
knn[k].flag = true;
knn[k].dist = oracle(knn[k].id);
}
sort(knn.begin(), knn.begin() + L);
unsigned seed = params.seed;
unsigned updates = 0;
if (seed == 0) seed = time(NULL);
mt19937 rng(seed);
unsigned n_comps = 0;
unsigned k = 0;
while (k < L) {
unsigned nk = L;
if (knn[k].flag) {
knn[k].flag = false;
unsigned cur = knn[k].id;
unsigned maxM = M[cur];
auto const &neighbors = graph[cur];
if (maxM > neighbors.size()) {
maxM = neighbors.size();
for (unsigned trial = 0; trial < params.T; ++trial) {
unsigned L = params.init;
if (L == 0) { // generate random starting points
vector<unsigned> random(params.P);
GenRandom(rng, &random[0], random.size(), graph.size());
for (unsigned s: random) {
if (!flags[s]) {
knn[L++].id = s;
}
}
}
else { // user-provided starting points.
BOOST_VERIFY(ids);
BOOST_VERIFY(L < params.K);
for (unsigned l = 0; l < L; ++l) {
knn[l].id = ids[l];
}
for (unsigned m = 0; m < maxM; ++m) {
unsigned id = neighbors[m];
if (flags[id]) continue;
flags[id] = true;
++n_comps;
float dist = oracle(id);
if (dist > params.epsilon) continue;
Neighbor nn(id, oracle(id));
unsigned r = UpdateKnnList(&knn[0], L, nn);
if (L + 1 < knn.size()) ++L;
if (r < nk) {
nk = r;
}
for (unsigned k = 0; k < L; ++k) {
flags[knn[k].id] = false;
knn[k].flag = true;
knn[k].dist = oracle(knn[k].id);
}
sort(knn.begin(), knn.begin() + L);
unsigned k = 0;
while (k < L) {
unsigned nk = L;
if (knn[k].flag) {
knn[k].flag = false;
unsigned cur = knn[k].id;
unsigned maxM = M[cur];
if (params.M > maxM) maxM = params.M;
auto const &neighbors = graph[cur];
if (maxM > neighbors.size()) {
maxM = neighbors.size();
}
for (unsigned m = 0; m < maxM; ++m) {
unsigned id = neighbors[m];
if (flags[id]) continue;
flags[id] = true;
++n_comps;
float dist = oracle(id);
Neighbor nn(id, dist);
unsigned r = UpdateKnnList(&knn[0], L, nn);
if (L + 1 < knn.size()) ++L;
if (r < nk) {
nk = r;
}
}
}
if (nk <= k) {
k = nk;
}
else {
++k;
}
}
if (nk <= k) {
k = nk;
if (L > params.K) L = params.K;
if (results.empty()) {
results.reserve(params.K + 1);
results.resize(L + 1);
copy(knn.begin(), knn.begin() + L, results.begin());
}
else {
++k;
// update results
for (unsigned l = 0; l < L; ++l) {
unsigned r = UpdateKnnList(&results[0], results.size() - 1, knn[l]);
if (r < results.size() /* inserted */ && results.size() < (params.K + 1)) {
results.resize(results.size() + 1);
}
}
}
}
// check epsilon
{
for (unsigned l = 0; l < results.size(); ++l) {
if (results[l].dist > params.epsilon) {
results.resize(l);
break;
}
}
}
if (L > params.K) L = params.K;
unsigned L = results.size() - 1;
BOOST_VERIFY(L <= params.K);
// check epsilon
if (ids) {
for (unsigned k = 0; k < L; ++k) {
ids[k] = knn[k].id;
ids[k] = results[k].id;
}
}
if (pinfo) {
@@ -399,6 +435,58 @@ namespace kgraph {
*pL = v.size();
copy(v.begin(), v.end(), nns);
}
// Level-1 prune: shorten every adjacency list to at most M[i] entries.
// Lists that already hold M[i] entries or fewer are left untouched.
void prune1 () {
    unsigned node = 0;
    for (auto &nbrs: graph) {
        unsigned const keep = M[node++];
        if (keep < nbrs.size()) {
            nbrs.resize(keep);
        }
    }
}
// Level-2 prune: rebuild every adjacency list, dropping neighbors that are
// already known to be reachable through an edge kept earlier.
//
// Lists are scanned rank by rank (l = 0, 1, ...) across all nodes. The entry
// T = graph[i][l] is kept only if T is not yet in reachable[i]. When it is
// kept:
//   - reachable[i] gains T and every neighbor T has kept so far;
//   - every node r that has already kept an edge r -> i gains T as well.
//
// M[i] is remapped to the number of entries that survived among the first
// M[i] original entries. On return both `graph` and `M` are replaced.
void prune2 () {
    vector<vector<unsigned>> new_graph(graph.size());
    vector<unsigned> new_M(graph.size());
    // reverse[t] lists the nodes that have kept an edge pointing to t.
    vector<vector<unsigned>> reverse(graph.size());
    vector<set<unsigned>> reachable(graph.size());
    // L = length of the longest adjacency list, i.e. the number of ranks to scan.
    unsigned L = 0;
    for (auto const &v: graph) {
        if (v.size() > L) L = v.size();
    }
    cerr << "Level 2 Prune ..." << endl;
    for (unsigned l = 0; l < L; ++l) {
        cerr << l << endl;
        for (unsigned i = 0; i < graph.size(); ++i) {
            if (l >= graph[i].size()) continue;
            unsigned T = graph[i][l];
            if (reachable[i].insert(T).second) { // inserted
                new_graph[i].push_back(T);
                reverse[T].push_back(i);
                // mark newly reachable nodes
                for (auto n2: new_graph[T]) {
                    reachable[i].insert(n2);
                }
                for (auto r: reverse[i]) {
                    reachable[r].insert(T);
                }
            }
            // Rank l is the last of the first M[i] entries: record how many
            // of them were kept.
            if (l + 1 == M[i]) {
                new_M[i] = new_graph[i].size();
            }
        }
    }
    // A node whose M[i] exceeds its list length never reaches the
    // `l + 1 == M[i]` rank above, which would leave new_M[i] at 0 even
    // though entries were kept. The whole list lies within the first M[i]
    // entries in that case, so the new M is the full pruned length.
    for (unsigned i = 0; i < graph.size(); ++i) {
        if (M[i] > graph[i].size()) {
            new_M[i] = new_graph[i].size();
        }
    }
    graph.swap(new_graph);
    M.swap(new_M);
}
// Run the pruning passes selected by the bits of `level`: PRUNE_LEVEL_1
// triggers prune1() and PRUNE_LEVEL_2 triggers prune2(), in that order when
// both bits are set. `oracle` is not used by either pass.
virtual void prune (IndexOracle const &oracle, unsigned level) {
    bool const run_level1 = (level & PRUNE_LEVEL_1) != 0;
    bool const run_level2 = (level & PRUNE_LEVEL_2) != 0;
    if (run_level1) prune1();
    if (run_level2) prune2();
}
};
class KGraphConstructor: public KGraphImpl {
@@ -581,34 +669,6 @@ namespace kgraph {
}
}
// Prune the graph by dropping neighbors that are already reachable through
// an edge kept earlier.
//
// Only the first M[i] entries of each adjacency list are considered, scanned
// rank by rank (l = 0, 1, ...) across all nodes. An entry T = graph[i][l]
// that is already in reachable[i] is skipped. Otherwise it is kept,
// reachable[i] gains the neighbors T has kept so far, and every node r that
// has already kept an edge r -> i gains T.
//
// On return `graph` holds the pruned lists and M[i] is the pruned list size.
void prune2 () {
    vector<vector<unsigned>> new_graph(graph.size());
    // reverse[t] lists the nodes that have kept an edge pointing to t.
    vector<vector<unsigned>> reverse(graph.size());
    vector<set<unsigned>> reachable(graph.size());
    // max_element returns end() for an empty range; never dereference that.
    unsigned L = M.empty() ? 0 : *max_element(M.begin(), M.end());
    cerr << "Pruning ..." << endl;
    for (unsigned l = 0; l < L; ++l) {
        cerr << l << endl;
        for (unsigned i = 0; i < graph.size(); ++i) {
            // Stop at rank M[i], and also at the end of the stored list so
            // that an M[i] larger than the list cannot index out of bounds.
            if (l >= M[i] || l >= graph[i].size()) continue;
            unsigned T = graph[i][l];
            if (reachable[i].count(T)) continue;
            new_graph[i].push_back(T);
            reverse[T].push_back(i);
            // mark newly reachable nodes
            for (auto n2: new_graph[T]) {
                reachable[i].insert(n2);
            }
            for (auto r: reverse[i]) {
                reachable[r].insert(T);
            }
        }
    }
    graph.swap(new_graph);
    for (unsigned i = 0; i < graph.size(); ++i) {
        M[i] = graph[i].size();
    }
}
public:
KGraphConstructor (IndexOracle const &o, IndexParams const &p, IndexInfo *r)
: oracle(o), params(p), pinfo(r), nhoods(o.size()), n_comps(0)
@@ -694,21 +754,29 @@ namespace kgraph {
M[n] = nhoods[n].M;
auto const &pool = nhoods[n].pool;
unsigned K = params.L;
/*
if (params.prune == 1) {
K = M[n];
}
*/
knn.resize(K);
for (unsigned k = 0; k < K; ++k) {
knn[k] = pool[k].id;
}
}
/*
if (params.prune == 2) {
prune2();
}
*/
if (params.prune) {
prune(o, params.prune);
}
if (pinfo) {
*pinfo = info;
}
}
};
void KGraphImpl::build (IndexOracle const &oracle, IndexParams const &param, IndexInfo *info) {
View
@@ -8,6 +8,8 @@ namespace kgraph {
// Default parameter values. default_K, default_P, default_M and default_T
// are the values the SearchParams constructor assigns to its K, P, M and T
// fields.
static unsigned const default_L = 50;        // NOTE(review): consumer not visible here; confirm which parameter uses it
static unsigned const default_K = 10;        // default for SearchParams::K
static unsigned const default_P = 100;       // default for SearchParams::P
static unsigned const default_M = 0;         // default for SearchParams::M
static unsigned const default_T = 1;         // default for SearchParams::T
static unsigned const default_S = 10;        // NOTE(review): consumer not visible here; confirm
static unsigned const default_R = 100;       // NOTE(review): consumer not visible here; confirm
static unsigned const default_controls = 100; // NOTE(review): consumer not visible here; confirm
@@ -16,6 +18,10 @@ namespace kgraph {
static float const default_recall = 0.98;     // NOTE(review): consumer not visible here; confirm
static float const default_epsilon = 1e30;    // default for SearchParams::epsilon
static unsigned const default_verbosity = 1;  // NOTE(review): presumably the initial value of `verbosity`; confirm
// Bit flags selecting pruning passes. They can be OR-ed together; each one
// is tested on its own with `level & PRUNE_LEVEL_x`.
enum {
    PRUNE_LEVEL_1 = 1 << 0,
    PRUNE_LEVEL_2 = 1 << 1
};
// Default prune level: 0 has neither flag set.
static unsigned const default_prune = 0;
extern unsigned verbosity;
@@ -53,12 +59,14 @@ namespace kgraph {
struct SearchParams {
unsigned K;
unsigned M;
unsigned P;
unsigned T;
float epsilon;
unsigned seed;
unsigned init;
SearchParams (): K(default_K), P(default_P), epsilon(default_epsilon), seed(1998), init(0) {
SearchParams (): K(default_K), M(default_M), P(default_P), T(default_T), epsilon(default_epsilon), seed(1998), init(0) {
}
};
@@ -86,6 +94,7 @@ namespace kgraph {
virtual void load (char const *path) = 0;
virtual void save (char const *path) const = 0; // save to file
virtual void build (IndexOracle const &oracle, IndexParams const &params, IndexInfo *info) = 0;
virtual void prune (IndexOracle const &oracle, unsigned level) = 0;
virtual unsigned search (SearchOracle const &oracle, SearchParams const &params, unsigned *ids, SearchInfo *info) const = 0;
static KGraph *create ();
virtual void get_nn (unsigned id, unsigned *nns, unsigned *M, unsigned *L) const = 0;
Oops, something went wrong.

0 comments on commit 5d29654

Please sign in to comment.