Skip to content

Commit

Permalink
Implement B_aux according to what was discussed in mlpack#642.
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcosPividori committed May 30, 2016
1 parent ec1c46a commit 558a4fd
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 14 deletions.
41 changes: 27 additions & 14 deletions src/mlpack/methods/neighbor_search/neighbor_search_rules_impl.hpp
Expand Up @@ -344,40 +344,52 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::

double worstDistance = SortPolicy::BestDistance();
double bestDistance = SortPolicy::WorstDistance();
double bestPointDistance = SortPolicy::WorstDistance();
double auxDistance = SortPolicy::WorstDistance();

// Loop over points held in the node.
for (size_t i = 0; i < queryNode.NumPoints(); ++i)
{
const double distance = distances(distances.n_rows - 1, queryNode.Point(i));
if (SortPolicy::IsBetter(worstDistance, distance))
worstDistance = distance;
if (SortPolicy::IsBetter(distance, bestDistance))
bestDistance = distance;
if (SortPolicy::IsBetter(distance, bestPointDistance))
bestPointDistance = distance;
}

// Add triangle inequality adjustment to best distance. It is possible this
// could be tighter for some certain types of trees.
bestDistance = SortPolicy::CombineWorst(bestDistance,
queryNode.FurthestPointDistance() +
queryNode.FurthestDescendantDistance());
auxDistance = bestPointDistance;

// Loop over children of the node, and use their cached information to
// assemble bounds.
for (size_t i = 0; i < queryNode.NumChildren(); ++i)
{
const double firstBound = queryNode.Child(i).Stat().FirstBound();
const double adjustment = std::max(0.0,
queryNode.FurthestDescendantDistance() -
queryNode.Child(i).FurthestDescendantDistance());
const double adjustedSecondBound = SortPolicy::CombineWorst(
queryNode.Child(i).Stat().SecondBound(), 2 * adjustment);
const double auxBound = queryNode.Child(i).Stat().AuxBound();

if (SortPolicy::IsBetter(worstDistance, firstBound))
worstDistance = firstBound;
if (SortPolicy::IsBetter(adjustedSecondBound, bestDistance))
bestDistance = adjustedSecondBound;
if (SortPolicy::IsBetter(auxBound, auxDistance))
auxDistance = auxBound;
}

// Add triangle inequality adjustment to best distance. It is possible this
// could be tighter for some certain types of trees.
bestDistance = SortPolicy::CombineWorst(auxDistance,
2 * queryNode.FurthestDescendantDistance());

// Add triangle inequality adjustment to best distance of points in node.
bestPointDistance = SortPolicy::CombineWorst(bestPointDistance,
queryNode.FurthestPointDistance() +
queryNode.FurthestDescendantDistance());

if (SortPolicy::IsBetter(bestPointDistance, bestDistance))
bestDistance = bestPointDistance;

// At this point:
// worstDistance holds the value of B_1(N_q).
// bestDistance holds the value of B_2(N_q).
// auxDistance holds the value of B_aux(N_q).

// Now consider the parent bounds.
if (queryNode.Parent() != NULL)
{
Expand Down Expand Up @@ -405,6 +417,7 @@ inline double NeighborSearchRules<SortPolicy, MetricType, TreeType>::
// Cache bounds for later.
queryNode.Stat().FirstBound() = worstDistance;
queryNode.Stat().SecondBound() = bestDistance;
queryNode.Stat().AuxBound() = auxDistance;

if (SortPolicy::IsBetter(worstDistance, bestDistance))
return worstDistance;
Expand Down
10 changes: 10 additions & 0 deletions src/mlpack/methods/neighbor_search/neighbor_search_stat.hpp
Expand Up @@ -29,6 +29,9 @@ class NeighborSearchStat
//! using the best descendant candidate distance modified by the furthest
//! descendant distance.
double secondBound;
//! The aux bound on the node's neighbor distances (B_aux). This represents
//! the best descendant candidate distance (used to calculate secondBound).
double auxBound;
//! The better of the two bounds.
double bound;

Expand All @@ -45,6 +48,7 @@ class NeighborSearchStat
NeighborSearchStat() :
firstBound(SortPolicy::WorstDistance()),
secondBound(SortPolicy::WorstDistance()),
auxBound(SortPolicy::WorstDistance()),
bound(SortPolicy::WorstDistance()),
lastDistance(0.0) { }

Expand All @@ -56,6 +60,7 @@ class NeighborSearchStat
NeighborSearchStat(TreeType& /* node */) :
firstBound(SortPolicy::WorstDistance()),
secondBound(SortPolicy::WorstDistance()),
auxBound(SortPolicy::WorstDistance()),
bound(SortPolicy::WorstDistance()),
lastDistance(0.0) { }

Expand All @@ -67,6 +72,10 @@ class NeighborSearchStat
double SecondBound() const { return secondBound; }
//! Modify the second bound.
double& SecondBound() { return secondBound; }
//! Get the aux bound.
double AuxBound() const { return auxBound; }
//! Modify the aux bound.
double& AuxBound() { return auxBound; }
//! Get the overall bound (the better of the two bounds).
double Bound() const { return bound; }
//! Modify the overall bound (it should be the better of the two bounds).
Expand All @@ -84,6 +93,7 @@ class NeighborSearchStat

ar & CreateNVP(firstBound, "firstBound");
ar & CreateNVP(secondBound, "secondBound");
ar & CreateNVP(auxBound, "auxBound");
ar & CreateNVP(bound, "bound");
ar & CreateNVP(lastDistance, "lastDistance");
}
Expand Down

0 comments on commit 558a4fd

Please sign in to comment.