@@ -65,6 +65,7 @@ class FVectorNode: public FVector
65
65
int instantiate_vector ();
66
66
size_t get_ref_len () const ;
67
67
uchar *get_ref () const { return ref; }
68
+ bool is_new () const ;
68
69
69
70
static uchar *get_key (const FVectorNode *elem, size_t *key_len, my_bool);
70
71
};
@@ -76,6 +77,7 @@ class MHNSW_Context
76
77
TABLE *table;
77
78
Field *vec_field;
78
79
size_t vec_len= 0 ;
80
+ FVector *target= 0 ;
79
81
80
82
Hash_set<FVectorNode> node_cache{PSI_INSTRUMENT_MEM, FVectorNode::get_key};
81
83
@@ -133,6 +135,11 @@ size_t FVectorNode::get_ref_len() const
133
135
return ctx->table ->file ->ref_length ;
134
136
}
135
137
138
+ bool FVectorNode::is_new () const
139
+ {
140
+ return this == ctx->target ;
141
+ }
142
+
136
143
uchar *FVectorNode::get_key (const FVectorNode *elem, size_t *key_len, my_bool)
137
144
{
138
145
*key_len= elem->get_ref_len ();
@@ -315,6 +322,7 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
315
322
const FVectorNode &source_node,
316
323
const List<FVectorNode> &new_neighbors)
317
324
{
325
+ int err;
318
326
TABLE *graph= ctx->table ->hlindex ;
319
327
DBUG_ASSERT (new_neighbors.elements <= HNSW_MAX_M);
320
328
@@ -337,25 +345,24 @@ static int write_neighbors(MHNSW_Context *ctx, size_t layer_number,
337
345
graph->field [1 ]->store_binary (source_node.get_ref (), source_node.get_ref_len ());
338
346
graph->field [2 ]->store_binary (neighbor_array_bytes, total_size);
339
347
340
- uchar *key= static_cast <uchar*>(alloca (graph->key_info ->key_length ));
341
- key_copy (key, graph->record [0 ], graph->key_info , graph->key_info ->key_length );
342
-
343
- // XXX try to write first?
344
- int err= graph->file ->ha_index_read_map (graph->record [1 ], key, HA_WHOLE_KEY,
345
- HA_READ_KEY_EXACT);
346
-
347
- // no record
348
- if (err == HA_ERR_KEY_NOT_FOUND)
348
+ if (source_node.is_new ())
349
349
{
350
350
dbug_print_vec_ref (" INSERT " , layer_number, source_node);
351
351
err= graph->file ->ha_write_row (graph->record [0 ]);
352
352
}
353
- else if (!err)
353
+ else
354
354
{
355
355
dbug_print_vec_ref (" UPDATE " , layer_number, source_node);
356
356
dbug_print_vec_neigh (layer_number, new_neighbors);
357
357
358
- err= graph->file ->ha_update_row (graph->record [1 ], graph->record [0 ]);
358
+ uchar *key= static_cast <uchar*>(alloca (graph->key_info ->key_length ));
359
+ key_copy (key, graph->record [0 ], graph->key_info , graph->key_info ->key_length );
360
+
361
+ err= graph->file ->ha_index_read_map (graph->record [1 ], key,
362
+ HA_WHOLE_KEY, HA_READ_KEY_EXACT);
363
+ if (!err)
364
+ err= graph->file ->ha_update_row (graph->record [1 ], graph->record [0 ]);
365
+
359
366
}
360
367
my_safe_afree (neighbor_array_bytes, total_size);
361
368
return err;
@@ -416,7 +423,7 @@ static int update_neighbors(MHNSW_Context *ctx,
416
423
}
417
424
418
425
419
- static int search_layer (MHNSW_Context *ctx, const FVector &target,
426
+ static int search_layer (MHNSW_Context *ctx,
420
427
const List<FVectorNode> &start_nodes,
421
428
uint max_candidates_return, size_t layer,
422
429
List<FVectorNode> *result)
@@ -427,6 +434,7 @@ static int search_layer(MHNSW_Context *ctx, const FVector &target,
427
434
Queue<FVectorNode, const FVector> candidates;
428
435
Queue<FVectorNode, const FVector> best;
429
436
Hash_set<FVectorNode> visited (PSI_INSTRUMENT_MEM, FVectorNode::get_key);
437
+ const FVector &target= *ctx->target ;
430
438
431
439
candidates.init (10000 , false , cmp_vec, &target);
432
440
best.init (max_candidates_return, true , cmp_vec, &target);
@@ -537,20 +545,21 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
537
545
538
546
SCOPE_EXIT ([graph](){ graph->file ->ha_index_end (); });
539
547
548
+ h->position (table->record [0 ]);
549
+
540
550
if (int err= graph->file ->ha_index_last (graph->record [0 ]))
541
551
{
542
552
if (err != HA_ERR_END_OF_FILE)
543
553
return err;
544
554
545
555
// First insert!
546
- h->position (table->record [0 ]);
547
- return write_neighbors (&ctx, 0 , {&ctx, h->ref }, {});
556
+ FVectorNode target (&ctx, h->ref );
557
+ ctx.target = ⌖
558
+ return write_neighbors (&ctx, 0 , target, {});
548
559
}
549
560
550
561
longlong max_layer= graph->field [0 ]->val_int ();
551
562
552
- h->position (table->record [0 ]);
553
-
554
563
List<FVectorNode> candidates;
555
564
List<FVectorNode> start_nodes;
556
565
String ref_str, *ref_ptr;
@@ -570,14 +579,15 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
570
579
return bad_value_on_insert (vec_field);
571
580
572
581
FVectorNode target (&ctx, h->ref , res->ptr ());
582
+ ctx.target = ⌖
573
583
574
584
double new_num= my_rnd (&thd->rand );
575
585
double log= -std::log (new_num) * NORMALIZATION_FACTOR;
576
586
longlong new_node_layer= static_cast <longlong>(std::floor (log));
577
587
578
588
for (longlong cur_layer= max_layer; cur_layer > new_node_layer; cur_layer--)
579
589
{
580
- if (int err= search_layer (&ctx, target, start_nodes,
590
+ if (int err= search_layer (&ctx, start_nodes,
581
591
thd->variables .hnsw_ef_constructor , cur_layer,
582
592
&candidates))
583
593
return err;
@@ -590,7 +600,7 @@ int mhnsw_insert(TABLE *table, KEY *keyinfo)
590
600
cur_layer >= 0 ; cur_layer--)
591
601
{
592
602
List<FVectorNode> neighbors;
593
- if (int err= search_layer (&ctx, target, start_nodes,
603
+ if (int err= search_layer (&ctx, start_nodes,
594
604
thd->variables .hnsw_ef_constructor , cur_layer,
595
605
&candidates))
596
606
return err;
@@ -669,23 +679,23 @@ int mhnsw_read_first(TABLE *table, KEY *keyinfo, Item *dist, ulonglong limit)
669
679
res= vec_field->val_str (&buf);
670
680
671
681
FVector target (&ctx, res->ptr ());
682
+ ctx.target = ⌖
672
683
673
684
ulonglong ef_search= std::max<ulonglong>( // XXX why not always limit?
674
685
thd->variables .hnsw_ef_search , limit);
675
686
676
687
for (size_t cur_layer= max_layer; cur_layer > 0 ; cur_layer--)
677
688
{
678
689
// XXX in the paper ef_search=1 here
679
- if (int err= search_layer (&ctx, target, start_nodes, ef_search,
680
- cur_layer, &candidates))
690
+ if (int err= search_layer (&ctx, start_nodes, ef_search, cur_layer ,
691
+ &candidates))
681
692
return err;
682
693
start_nodes.empty ();
683
694
start_nodes.push_back (candidates.head (), &ctx.root ); // XXX so ef_search=1 ???
684
695
candidates.empty ();
685
696
}
686
697
687
- if (int err= search_layer (&ctx, target, start_nodes, ef_search, 0 ,
688
- &candidates))
698
+ if (int err= search_layer (&ctx, start_nodes, ef_search, 0 , &candidates))
689
699
return err;
690
700
691
701
size_t context_size= limit * h->ref_length + sizeof (ulonglong);
0 commit comments