/
neighbor_search.hpp
373 lines (335 loc) · 14.6 KB
/
neighbor_search.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
/**
* @file methods/neighbor_search/neighbor_search.hpp
* @author Ryan Curtin
*
* Defines the NeighborSearch class, which performs an abstract
* nearest-neighbor-like query on two datasets.
*
* mlpack is free software; you may redistribute it and/or modify it under the
* terms of the 3-clause BSD license. You should have received a copy of the
* 3-clause BSD license along with mlpack. If not, see
* http://www.opensource.org/licenses/BSD-3-Clause for more information.
*/
#ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
#define MLPACK_METHODS_NEIGHBOR_SEARCH_NEIGHBOR_SEARCH_HPP
#include <mlpack/core.hpp>
#include "neighbor_search_stat.hpp"
#include "sort_policies/nearest_neighbor_sort.hpp"
#include "sort_policies/furthest_neighbor_sort.hpp"
#include "neighbor_search_rules.hpp"
#include "unmap.hpp"
namespace mlpack {
// Forward declaration: NeighborSearch (declared below) names this wrapper as a
// friend, so the template must be visible before that class is defined. Only
// the arity of each template template parameter matters here, so the names of
// their own parameters are omitted.
template<typename SortPolicy,
         template<typename, typename, typename> class TreeType,
         template<typename> class DualTreeTraversalType,
         template<typename> class SingleTreeTraversalType>
class LeafSizeNSWrapper;
/**
 * The search strategies a NeighborSearch object can be configured with (via its
 * constructor or SearchMode()). DUAL_TREE_MODE is the default used by the
 * NeighborSearch constructors.
 */
enum NeighborSearchMode
{
  //! Naive search (presumably brute force); not available when a pre-built
  //! reference tree is passed in.
  NAIVE_MODE = 0,
  //! Search each query point individually against the reference tree.
  SINGLE_TREE_MODE = 1,
  //! Search with trees on both the query and the reference side.
  DUAL_TREE_MODE = 2,
  //! Greedy variant of single-tree search (NOTE(review): confirm exact
  //! semantics against neighbor_search_impl.hpp).
  GREEDY_SINGLE_TREE_MODE = 3
};
/**
 * The NeighborSearch class is a template class for performing distance-based
 * neighbor searches. It takes a query dataset and a reference dataset (or just
 * a reference dataset) and, for each point in the query dataset, finds the k
 * neighbors in the reference dataset which have the 'best' distance according
 * to a given sorting policy. The reference data is supplied at construction
 * time or through Train(); query data is supplied to Search(). The
 * Search(k, neighbors, distances) overload takes no query set and uses the
 * reference dataset as the query dataset (all-k-nearest-neighbors search).
 *
 * The template parameters SortPolicy and MetricType define the sort function
 * used and the metric (distance function) used. See the default policies,
 * NearestNeighborSort and EuclideanDistance, for the interfaces they provide.
 *
 * @tparam SortPolicy The sort policy for distances; see NearestNeighborSort.
 * @tparam MetricType The metric to use for computation.
 * @tparam MatType The type of data matrix.
 * @tparam TreeType The tree type to use; must adhere to the TreeType API.
 * @tparam DualTreeTraversalType The type of dual tree traversal to use
 *     (defaults to the tree's default traverser).
 * @tparam SingleTreeTraversalType The type of single tree traversal to use
 *     (defaults to the tree's default traverser).
 */
template<typename SortPolicy = NearestNeighborSort,
         typename MetricType = EuclideanDistance,
         typename MatType = arma::mat,
         template<typename TreeMetricType,
                  typename TreeStatType,
                  typename TreeMatType> class TreeType = KDTree,
         template<typename RuleType> class DualTreeTraversalType =
             TreeType<MetricType,
                      NeighborSearchStat<SortPolicy>,
                      MatType>::template DualTreeTraverser,
         template<typename RuleType> class SingleTreeTraversalType =
             TreeType<MetricType,
                      NeighborSearchStat<SortPolicy>,
                      MatType>::template SingleTreeTraverser>
class NeighborSearch
{
 public:
  //! Convenience typedef.
  typedef TreeType<MetricType, NeighborSearchStat<SortPolicy>, MatType> Tree;
  //! The type of element held in MatType.
  typedef typename MatType::elem_type ElemType;

  /**
   * Initialize the NeighborSearch object, passing a reference dataset (this is
   * the dataset which is searched). Optionally, perform the computation in
   * a different mode. An initialized distance metric can be given, for cases
   * where the metric has internal data (e.g. the MahalanobisDistance class).
   *
   * The reference set is taken by value: passing an lvalue copies it, while
   * passing std::move(yourReferenceSet) transfers ownership and avoids the
   * copy. The internal copy may be rearranged during tree-building.
   *
   * @param referenceSet Set of reference points.
   * @param mode Neighbor search mode.
   * @param epsilon Relative approximate error (non-negative).
   * @param metric An optional instance of the MetricType class.
   */
  NeighborSearch(MatType referenceSet,
                 const NeighborSearchMode mode = DUAL_TREE_MODE,
                 const double epsilon = 0,
                 const MetricType metric = MetricType());

  /**
   * Initialize the NeighborSearch object with a pre-constructed reference tree
   * (this is the tree built on the points that will be searched). Optionally,
   * choose to use single-tree mode. Naive mode is not available as an option
   * for this constructor. Additionally, an instantiated distance metric can be
   * given, for cases where the distance metric holds data.
   *
   * The tree is taken by value: passing an lvalue copies it, while passing
   * std::move(yourReferenceTree) transfers ownership and avoids the copy.
   *
   * @note
   * Mapping the points of the matrix back to their original indices is not done
   * when this constructor is used, so if the tree type you are using maps
   * points (like BinarySpaceTree), then you will have to perform the re-mapping
   * manually.
   *
   * @param referenceTree Pre-built tree for reference points.
   * @param mode Neighbor search mode.
   * @param epsilon Relative approximate error (non-negative).
   * @param metric Instantiated distance metric.
   */
  NeighborSearch(Tree referenceTree,
                 const NeighborSearchMode mode = DUAL_TREE_MODE,
                 const double epsilon = 0,
                 const MetricType metric = MetricType());

  /**
   * Create a NeighborSearch object without any reference data. If Search() is
   * called before a reference set is set with Train(), an exception will be
   * thrown.
   *
   * @param mode Neighbor search mode.
   * @param epsilon Relative approximate error (non-negative).
   * @param metric Instantiated metric.
   */
  NeighborSearch(const NeighborSearchMode mode = DUAL_TREE_MODE,
                 const double epsilon = 0,
                 const MetricType metric = MetricType());

  /**
   * Construct the NeighborSearch object by copying the given NeighborSearch
   * object.
   *
   * @param other NeighborSearch object to copy.
   */
  NeighborSearch(const NeighborSearch& other);

  /**
   * Construct the NeighborSearch object by taking ownership of the given
   * NeighborSearch object.
   *
   * @param other NeighborSearch object to take ownership of.
   */
  NeighborSearch(NeighborSearch&& other);

  /**
   * Copy the given NeighborSearch object.
   *
   * @param other NeighborSearch object to copy.
   */
  NeighborSearch& operator=(const NeighborSearch& other);

  /**
   * Take ownership of the given NeighborSearch object.
   *
   * @param other NeighborSearch object to take ownership of.
   */
  NeighborSearch& operator=(NeighborSearch&& other);

  /**
   * Delete the NeighborSearch object, releasing the resources it owns: the
   * reference tree and, in the situations where this object owns it, the
   * reference dataset. All other members clean up after themselves.
   */
  ~NeighborSearch();

  /**
   * Set the reference set to a new reference set, and build a tree if
   * necessary. The dataset is copied by default, but the copy can be avoided by
   * transferring the ownership of the dataset using std::move(). This method
   * is called 'Train()' in order to match the rest of the mlpack abstractions,
   * even though calling this "training" is maybe a bit of a stretch.
   *
   * @param referenceSet New set of reference data.
   */
  void Train(MatType referenceSet);

  /**
   * Set the reference tree to a new reference tree. The tree is copied by
   * default, but the copy can be avoided by using std::move() to transfer the
   * ownership of the tree. This method is called 'Train()' in order to match
   * the rest of the mlpack abstractions, even though calling this "training" is
   * maybe a bit of a stretch.
   *
   * @param referenceTree Pre-built tree for reference points.
   */
  void Train(Tree referenceTree);

  /**
   * For each point in the query set, compute the nearest neighbors and store
   * the output in the given matrices. The matrices will be set to the size of
   * k rows by n columns, where n is the number of points in the query dataset
   * and k is the number of neighbors being searched for.
   *
   * If querySet contains only a few query points, the extra cost of building a
   * tree on the points for dual-tree search may not be warranted, and it may be
   * worthwhile to use SINGLE_TREE_MODE instead (either by passing it to the
   * constructor or by assigning it through SearchMode()).
   *
   * @param querySet Set of query points (can be just one point).
   * @param k Number of neighbors to search for.
   * @param neighbors Matrix storing lists of neighbors for each query point.
   * @param distances Matrix storing distances of neighbors for each query
   *     point.
   */
  void Search(const MatType& querySet,
              const size_t k,
              arma::Mat<size_t>& neighbors,
              arma::Mat<ElemType>& distances);

  /**
   * Given a pre-built query tree, search for the nearest neighbors of each
   * point in the query tree, storing the output in the given matrices. The
   * matrices will be set to the size of k rows by n columns, where n is the
   * number of points in the query dataset and k is the number of neighbors
   * being searched for.
   *
   * Note that if you are calling Search() multiple times with a single query
   * tree, you need to reset the bounds in the statistic of each query node,
   * otherwise the result may be wrong! You can do this by calling
   * \c TreeType::Stat().Reset() on each node in the query tree.
   *
   * @param queryTree Tree built on query points.
   * @param k Number of neighbors to search for.
   * @param neighbors Matrix storing lists of neighbors for each query point.
   * @param distances Matrix storing distances of neighbors for each query
   *     point.
   * @param sameSet Denotes whether or not the reference and query sets are the
   *     same.
   */
  void Search(Tree& queryTree,
              const size_t k,
              arma::Mat<size_t>& neighbors,
              arma::Mat<ElemType>& distances,
              bool sameSet = false);

  /**
   * Search for the nearest neighbors of every point in the reference set. This
   * is basically equivalent to calling any other overload of Search() with the
   * reference set as the query set; so, this lets you do
   * all-k-nearest-neighbors search. The results are stored in the given
   * matrices. The matrices will be set to the size of k rows by n columns,
   * where n is the number of points in the reference dataset and k is the
   * number of neighbors being searched for.
   *
   * @param k Number of neighbors to search for.
   * @param neighbors Matrix storing lists of neighbors for each query point.
   * @param distances Matrix storing distances of neighbors for each query
   *     point.
   */
  void Search(const size_t k,
              arma::Mat<size_t>& neighbors,
              arma::Mat<ElemType>& distances);

  /**
   * Calculate the average relative error (effective error) between the
   * distances calculated and the true distances provided. The input matrices
   * must have the same size.
   *
   * Cases where the true distance is zero (the same point) or the calculated
   * distance is SortPolicy::WorstDistance() (didn't find enough points) will be
   * ignored.
   *
   * NOTE(review): both arguments are taken by non-const reference; if the
   * implementation does not modify them they could become const references
   * (the definition in neighbor_search_impl.hpp must change in step).
   *
   * @param foundDistances Matrix storing lists of calculated distances for each
   *     query point.
   * @param realDistances Matrix storing lists of true best distances for each
   *     query point.
   * @return Average relative error.
   */
  static double EffectiveError(arma::Mat<ElemType>& foundDistances,
                               arma::Mat<ElemType>& realDistances);

  /**
   * Calculate the recall (% of neighbors found) given the list of found
   * neighbors and the true set of neighbors. The recall returned will be in
   * the range [0, 1].
   *
   * NOTE(review): as with EffectiveError(), the arguments are non-const
   * references; confirm whether the implementation needs to modify them.
   *
   * @param foundNeighbors Matrix storing lists of calculated neighbors for each
   *     query point.
   * @param realNeighbors Matrix storing lists of true best neighbors for each
   *     query point.
   * @return Recall.
   */
  static double Recall(arma::Mat<size_t>& foundNeighbors,
                       arma::Mat<size_t>& realNeighbors);

  //! Return the total number of base case evaluations performed during the last
  //! search.
  size_t BaseCases() const { return baseCases; }

  //! Return the number of node combination scores during the last search.
  size_t Scores() const { return scores; }

  //! Access the search mode.
  NeighborSearchMode SearchMode() const { return searchMode; }
  //! Modify the search mode.
  NeighborSearchMode& SearchMode() { return searchMode; }

  //! Access the relative error to be considered in approximate search.
  double Epsilon() const { return epsilon; }
  //! Modify the relative error to be considered in approximate search.
  double& Epsilon() { return epsilon; }

  //! Access the reference dataset.
  const MatType& ReferenceSet() const { return *referenceSet; }

  //! Access the reference tree. NOTE(review): this dereferences the
  //! referenceTree pointer unconditionally; confirm a tree has been built
  //! (e.g. that the object is not in NAIVE_MODE) before calling.
  const Tree& ReferenceTree() const { return *referenceTree; }
  //! Modify the reference tree (same caveat as the const overload).
  Tree& ReferenceTree() { return *referenceTree; }

  //! Serialize the NeighborSearch model.
  template<typename Archive>
  void serialize(Archive& ar, const uint32_t version);

 private:
  // NOTE: the declaration order of the data members below is also their
  // initialization order; keep it in sync with the constructors' member
  // initializer lists in the implementation file.

  //! Permutations of reference points during tree building.
  std::vector<size_t> oldFromNewReferences;
  //! Pointer to the root of the reference tree.
  Tree* referenceTree;
  //! Reference dataset. In some situations we may be the owner of this.
  const MatType* referenceSet;

  //! Indicates the neighbor search mode.
  NeighborSearchMode searchMode;
  //! Indicates the relative error to be considered in approximate search.
  double epsilon;

  //! Instantiation of metric.
  MetricType metric;

  //! The total number of base cases.
  size_t baseCases;
  //! The total number of scores (applicable for non-naive search).
  size_t scores;

  //! If this is true, the reference tree bounds need to be reset on a call to
  //! Search() without a query set.
  bool treeNeedsReset;

  //! LeafSizeNSWrapper (the model wrapper class) needs access to internal
  //! members.
  friend class LeafSizeNSWrapper<SortPolicy, TreeType, DualTreeTraversalType,
      SingleTreeTraversalType>;
}; // class NeighborSearch
} // namespace mlpack
// Include implementation.
#include "neighbor_search_impl.hpp"
// Include convenience typedefs.
#include "typedef.hpp"
#endif