Skip to content

Commit b492025

Browse files
committed
mhnsw: don't guess whether it's insert or update
we know it every time
1 parent 267092d commit b492025

File tree

1 file changed

+32
-22
lines changed

1 file changed

+32
-22
lines changed

sql/vector_mhnsw.cc

Lines changed: 32 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -65,6 +65,7 @@ class FVectorNode: public FVector
6565
int instantiate_vector();
6666
size_t get_ref_len() const;
6767
uchar *get_ref() const { return ref; }
68+
bool is_new() const;
6869

6970
static uchar *get_key(const FVectorNode *elem, size_t *key_len, my_bool);
7071
};
@@ -76,6 +77,7 @@ class MHNSW_Context
7677
TABLE *table;
7778
Field *vec_field;
7879
size_t vec_len= 0;
80+
FVector *target= 0;
7981

8082
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
8183

@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const
133135
return ctx->table->file->ref_length;
134136
}
135137

138+
// Returns true only when this node is the very object that ctx->target
// points to (a pointer identity check, not a comparison of vector contents).
// mhnsw_insert() points ctx->target at the node it is inserting, and
// write_neighbors() uses this to choose between ha_write_row() for that node
// and an index lookup followed by ha_update_row() for any other node.
bool FVectorNode::is_new() const
139+
{
140+
return this == ctx->target;
141+
}
142+
136143
uchar *FVectorNode::get_key(const FVectorNode *elem, size_t *key_len, my_bool)
137144
{
138145
*key_len= elem->get_ref_len();
@@ -315,6 +322,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
315322
const FVectorNode &source_node,
316323
const List<FVectorNode> &new_neighbors)
317324
{
325+
int err;
318326
TABLE *graph= ctx->table->hlindex;
319327
DBUG_ASSERT(new_neighbors.elements <= HNSW_MAX_M);
320328

@@ -337,25 +345,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
337345
graph->field[1]->store_binary(source_node.get_ref(), source_node.get_ref_len());
338346
graph->field[2]->store_binary(neighbor_array_bytes, total_size);
339347

340-
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
341-
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
342-
343-
// XXX try to write first?
344-
int err= graph->file->ha_index_read_map(graph->record[1], key, HA_WHOLE_KEY,
345-
HA_READ_KEY_EXACT);
346-
347-
// no record
348-
if (err == HA_ERR_KEY_NOT_FOUND)
348+
if (source_node.is_new())
349349
{
350350
dbug_print_vec_ref("INSERT ", layer_number, source_node);
351351
err= graph->file->ha_write_row(graph->record[0]);
352352
}
353-
else if (!err)
353+
else
354354
{
355355
dbug_print_vec_ref("UPDATE ", layer_number, source_node);
356356
dbug_print_vec_neigh(layer_number, new_neighbors);
357357

358-
err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
358+
uchar *key= static_cast<uchar*>(alloca(graph->key_info->key_length));
359+
key_copy(key, graph->record[0], graph->key_info, graph->key_info->key_length);
360+
361+
err= graph->file->ha_index_read_map(graph->record[1], key,
362+
HA_WHOLE_KEY, HA_READ_KEY_EXACT);
363+
if (!err)
364+
err= graph->file->ha_update_row(graph->record[1], graph->record[0]);
365+
359366
}
360367
my_safe_afree(neighbor_array_bytes, total_size);
361368
return err;
@@ -416,7 +423,7 @@ static int update_neighbors(MHNSW_Context *ctx,
416423
}
417424

418425

419-
static int search_layer(MHNSW_Context *ctx, const FVector &target,
426+
static int search_layer(MHNSW_Context *ctx,
420427
const List<FVectorNode> &start_nodes,
421428
uint max_candidates_return, size_t layer,
422429
List<FVectorNode> *result)
@@ -427,6 +434,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
427434
Queue<FVectorNode, const FVector> candidates;
428435
Queue<FVectorNode, const FVector> best;
429436
Hash_set<FVectorNode> visited(PSI_INSTRUMENT_MEM, FVectorNode::get_key);
437+
const FVector &target= *ctx->target;
430438

431439
candidates.init(10000, false, cmp_vec, &target);
432440
best.init(max_candidates_return, true, cmp_vec, &target);
@@ -537,20 +545,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
537545

538546
SCOPE_EXIT([graph](){ graph->file->ha_index_end(); });
539547

548+
h->position(table->record[0]);
549+
540550
if (int err= graph->file->ha_index_last(graph->record[0]))
541551
{
542552
if (err != HA_ERR_END_OF_FILE)
543553
return err;
544554

545555
// First insert!
546-
h->position(table->record[0]);
547-
return write_neighbors(&ctx, 0, {&ctx, h->ref}, {});
556+
FVectorNode target(&ctx, h->ref);
557+
ctx.target= &target;
558+
return write_neighbors(&ctx, 0, target, {});
548559
}
549560

550561
longlong max_layer= graph->field[0]->val_int();
551562

552-
h->position(table->record[0]);
553-
554563
List<FVectorNode> candidates;
555564
List<FVectorNode> start_nodes;
556565
String ref_str, *ref_ptr;
@@ -570,14 +579,15 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
570579
return bad_value_on_insert(vec_field);
571580

572581
FVectorNode target(&ctx, h->ref, res->ptr());
582+
ctx.target= &target;
573583

574584
double new_num= my_rnd(&thd->rand);
575585
double log= -std::log(new_num) * NORMALIZATION_FACTOR;
576586
longlong new_node_layer= static_cast<longlong>(std::floor(log));
577587

578588
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
579589
{
580-
if (int err= search_layer(&ctx, target, start_nodes,
590+
if (int err= search_layer(&ctx, start_nodes,
581591
thd->variables.hnsw_ef_constructor, cur_layer,
582592
&candidates))
583593
return err;
@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
590600
cur_layer >= 0; cur_layer--)
591601
{
592602
List<FVectorNode> neighbors;
593-
if (int err= search_layer(&ctx, target, start_nodes,
603+
if (int err= search_layer(&ctx, start_nodes,
594604
thd->variables.hnsw_ef_constructor, cur_layer,
595605
&candidates))
596606
return err;
@@ -669,23 +679,23 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
669679
res= vec_field->val_str(&buf);
670680

671681
FVector target(&ctx, res->ptr());
682+
ctx.target= &target;
672683

673684
ulonglong ef_search= std::max<ulonglong>( //XXX why not always limit?
674685
thd->variables.hnsw_ef_search, limit);
675686

676687
for (size_t cur_layer= max_layer; cur_layer > 0; cur_layer--)
677688
{
678689
//XXX in the paper ef_search=1 here
679-
if (int err= search_layer(&ctx, target, start_nodes, ef_search,
680-
cur_layer, &candidates))
690+
if (int err= search_layer(&ctx, start_nodes, ef_search, cur_layer,
691+
&candidates))
681692
return err;
682693
start_nodes.empty();
683694
start_nodes.push_back(candidates.head(), &ctx.root); // XXX so ef_search=1 ???
684695
candidates.empty();
685696
}
686697

687-
if (int err= search_layer(&ctx, target, start_nodes, ef_search, 0,
688-
&candidates))
698+
if (int err= search_layer(&ctx, start_nodes, ef_search, 0, &candidates))
689699
return err;
690700

691701
size_t context_size= limit * h->ref_length + sizeof(ulonglong);

0 commit comments

Comments
 (0)