Skip to content

Commit 885eb19

Browse files
committed
cleanup search_layer()
to return only as many elements as needed, the caller no longer needs to overallocate result arrays for throwaway nodes
1 parent fa2078d commit 885eb19

File tree

1 file changed

+32
-19
lines changed

1 file changed

+32
-19
lines changed

sql/vector_mhnsw.cc

Lines changed: 32 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -886,16 +886,29 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
886886
}
887887

888888
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
889-
Neighborhood *start_nodes, uint ef, size_t layer,
890-
Neighborhood *result, bool skip_deleted)
889+
Neighborhood *start_nodes, uint result_size,
890+
size_t layer, Neighborhood *result, bool construction)
891891
{
892892
DBUG_ASSERT(start_nodes->num > 0);
893893
result->num= 0;
894894

895895
MEM_ROOT * const root= graph->in_use->mem_root;
896+
Queue<Visited> candidates, best;
897+
bool skip_deleted;
898+
uint ef= result_size;
896899

897-
Queue<Visited> candidates;
898-
Queue<Visited> best;
900+
if (construction)
901+
{
902+
skip_deleted= false;
903+
if (ef > 1)
904+
ef= std::max(ef_construction, ef);
905+
}
906+
else
907+
{
908+
skip_deleted= layer == 0;
909+
if (ef > 1 || layer == 0)
910+
ef= std::max(graph->in_use->variables.mhnsw_min_limit, ef);
911+
}
899912

900913
// WARNING! heuristic here
901914
const double est_heuristic= 8 * std::sqrt(ctx->max_neighbors(layer));
@@ -905,23 +918,21 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
905918
candidates.init(10000, false, Visited::cmp);
906919
best.init(ef, true, Visited::cmp);
907920

921+
DBUG_ASSERT(start_nodes->num <= result_size);
908922
for (size_t i=0; i < start_nodes->num; i++)
909923
{
910924
Visited *v= visited.create(start_nodes->links[i]);
911925
candidates.push(v);
912926
if (skip_deleted && v->node->deleted)
913927
continue;
914-
if (best.elements() < ef)
915-
best.push(v);
916-
else if (v->distance_to_target < best.top()->distance_to_target)
917-
best.replace_top(v);
928+
best.push(v);
918929
}
919930

920931
float furthest_best= FLT_MAX;
921932
while (candidates.elements())
922933
{
923934
const Visited &cur= *candidates.pop();
924-
if (cur.distance_to_target > furthest_best && best.elements() == ef)
935+
if (cur.distance_to_target > furthest_best && best.is_full())
925936
break; // All possible candidates are worse than what we have
926937

927938
visited.flush();
@@ -941,7 +952,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
941952
if (int err= links[i]->load(graph))
942953
return err;
943954
Visited *v= visited.create(links[i]);
944-
if (best.elements() < ef)
955+
if (!best.is_full())
945956
{
946957
candidates.push(v);
947958
if (skip_deleted && v->node->deleted)
@@ -966,6 +977,9 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
966977
set_if_bigger(ctx->ef_power, ef_power); // not atomic, but it's ok
967978
}
968979

980+
while (best.elements() > result_size)
981+
best.pop();
982+
969983
result->num= best.elements();
970984
for (FVectorNode **links= result->links + result->num; best.elements();)
971985
*--links= best.pop()->node;
@@ -1033,9 +1047,10 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
10331047
root_make_savepoint(thd->mem_root, &memroot_sv);
10341048
SCOPE_EXIT([memroot_sv](){ root_free_to_savepoint(&memroot_sv); });
10351049

1050+
const size_t max_found= ctx->max_neighbors(0);
10361051
Neighborhood candidates, start_nodes;
1037-
candidates.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction);
1038-
start_nodes.init(thd->alloc<FVectorNode*>(ef_construction + 7), ef_construction);
1052+
candidates.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
1053+
start_nodes.init(thd->alloc<FVectorNode*>(max_found + 7), max_found);
10391054
start_nodes.links[start_nodes.num++]= ctx->start;
10401055

10411056
const double NORMALIZATION_FACTOR= 1 / std::log(ctx->M);
@@ -1063,7 +1078,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
10631078
{
10641079
uint max_neighbors= ctx->max_neighbors(cur_layer);
10651080
if (int err= search_layer(ctx, graph, target->vec, &start_nodes,
1066-
ef_construction, cur_layer, &candidates, false))
1081+
max_neighbors, cur_layer, &candidates, true))
10671082
return err;
10681083

10691084
if (int err= select_neighbors(ctx, graph, cur_layer, *target, candidates,
@@ -1106,11 +1121,9 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
11061121
if (err)
11071122
return err;
11081123

1109-
size_t ef= thd->variables.mhnsw_min_limit;
1110-
11111124
Neighborhood candidates, start_nodes;
1112-
candidates.init(thd->alloc<FVectorNode*>(ef + 7), ef);
1113-
start_nodes.init(thd->alloc<FVectorNode*>(ef + 7), ef);
1125+
candidates.init(thd->alloc<FVectorNode*>(limit + 7), limit);
1126+
start_nodes.init(thd->alloc<FVectorNode*>(limit + 7), limit);
11141127

11151128
// one could put all max_layer nodes in start_nodes
11161129
// but it has no effect on the recall or speed
@@ -1146,8 +1159,8 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
11461159
std::swap(start_nodes, candidates);
11471160
}
11481161

1149-
if (int err= search_layer(ctx, graph, target, &start_nodes, ef, 0,
1150-
&candidates, true))
1162+
if (int err= search_layer(ctx, graph, target, &start_nodes,
1163+
static_cast<uint>(limit), 0, &candidates, false))
11511164
return err;
11521165

11531166
if (limit > candidates.num)

0 commit comments

Comments
 (0)