Skip to content

Commit 3c6e836

Browse files
committed
generous_furthest optimization
make generosity depend on 1. M. Keep small M's fast, increase generosity for larger M's to get better recall. 2. distance. Keep generosity small when vectors are far from the target, increase generosity when the search gets closer. This allows to examine more relevant vectors but doesn't waste time examining irrelevant vectors. Particularly important with cosine metric when the distance is bounded
1 parent fb04cad commit 3c6e836

File tree

1 file changed

+20
-4
lines changed

1 file changed

+20
-4
lines changed

sql/vector_mhnsw.cc

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626

2727
// Algorithm parameters
2828
static constexpr float alpha = 1.1f;
29-
static constexpr float generosity = 1.1f;
3029
static constexpr uint ef_construction= 10;
3130

3231
static ulonglong mhnsw_cache_size;
@@ -334,6 +333,7 @@ class MHNSW_Context : public Sql_alloc
334333
size_t vec_len= 0;
335334
size_t byte_len= 0;
336335
Atomic_relaxed<double> ef_power{0.6}; // for the bloom filter size heuristic
336+
Atomic_relaxed<float> diameter{0}; // for the generosity heuristic
337337
FVectorNode *start= 0;
338338
const uint tref_len;
339339
const uint gref_len;
@@ -957,6 +957,17 @@ static int update_second_degree_neighbors(MHNSW_Context *ctx, TABLE *graph,
957957
return 0;
958958
}
959959

960+
961+
static inline float generous_furthest(const Queue<Visited> &q, float maxd, float g)
962+
{
963+
float d0=maxd*g/2;
964+
float d= q.top()->distance_to_target;
965+
float k= 5;
966+
float x= (d-d0)/d0;
967+
float sigmoid= k*x/std::sqrt(1+(k*k-1)*x*x); // or any other sigmoid
968+
return d*(1 + (g - 1)/2 * (1 - sigmoid));
969+
}
970+
960971
static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
961972
Neighborhood *start_nodes, uint result_size,
962973
size_t layer, Neighborhood *result, bool construction)
@@ -968,6 +979,7 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
968979
Queue<Visited> candidates, best;
969980
bool skip_deleted;
970981
uint ef= result_size;
982+
float generosity= 1.1f + ctx->M/500.0f;
971983

972984
if (construction)
973985
{
@@ -991,17 +1003,19 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
9911003
best.init(ef, true, Visited::cmp);
9921004

9931005
DBUG_ASSERT(start_nodes->num <= result_size);
1006+
float max_distance= ctx->diameter;
9941007
for (size_t i=0; i < start_nodes->num; i++)
9951008
{
9961009
Visited *v= visited.create(start_nodes->links[i]);
1010+
max_distance= std::max(max_distance, v->distance_to_target);
9971011
candidates.push(v);
9981012
if (skip_deleted && v->node->deleted)
9991013
continue;
10001014
best.push(v);
10011015
}
10021016

10031017
float furthest_best= best.is_empty() ? FLT_MAX
1004-
: best.top()->distance_to_target * generosity;
1018+
: generous_furthest(best, max_distance, generosity);
10051019
while (candidates.elements())
10061020
{
10071021
const Visited &cur= *candidates.pop();
@@ -1027,11 +1041,12 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
10271041
Visited *v= visited.create(links[i]);
10281042
if (!best.is_full())
10291043
{
1044+
max_distance= std::max(max_distance, v->distance_to_target);
10301045
candidates.push(v);
10311046
if (skip_deleted && v->node->deleted)
10321047
continue;
10331048
best.push(v);
1034-
furthest_best= best.top()->distance_to_target * generosity;
1049+
furthest_best= generous_furthest(best, max_distance, generosity);
10351050
}
10361051
else if (v->distance_to_target < furthest_best)
10371052
{
@@ -1041,12 +1056,13 @@ static int search_layer(MHNSW_Context *ctx, TABLE *graph, const FVector *target,
10411056
if (v->distance_to_target < best.top()->distance_to_target)
10421057
{
10431058
best.replace_top(v);
1044-
furthest_best= best.top()->distance_to_target * generosity;
1059+
furthest_best= generous_furthest(best, max_distance, generosity);
10451060
}
10461061
}
10471062
}
10481063
}
10491064
}
1065+
set_if_bigger(ctx->diameter, max_distance); // not atomic, but it's ok
10501066
if (ef > 1 && visited.count*2 > est_size)
10511067
{
10521068
double ef_power= std::log(visited.count*2/est_heuristic) / std::log(ef);

0 commit comments

Comments
 (0)