Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

start project

  • Loading branch information...
commit cac6f6c0c5cbfe91a11dfcbcb606270b3a011b96 0 parents
@DNCrane authored
669 Cover_Tree.h
@@ -0,0 +1,669 @@
+/*
+ * opencog/util/Cover_Tree.h
+ *
+ * Copyright (C) 2011 by Singularity Institute for Artificial Intelligence
+ * All Rights Reserved
+ *
+ * Written by David Crane <dncrane@gmail.com>
+ *
+ * This program is free software; you can redistribute it and/or modify
+ * it under the terms of the GNU Affero General Public License v3 as
+ * published by the Free Software Foundation and including the exceptions
+ * at http://opencog.org/wiki/Licenses
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU Affero General Public License
+ * along with this program; if not, write to:
+ * Free Software Foundation, Inc.,
+ * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
+ */
+
+#ifndef _COVER_TREE_H
+#define _COVER_TREE_H
+
+#include <vector>
+#include <map>
+#include <set>
+#include <cmath>
+#include <float.h>
+#include <iostream>
+
+/**
+ * Cover Tree. Allows for insertion, removal, and k-nearest-neighbor
+ * queries.
+ *
+ * The user should define double Point::distance(const Point& p) and
+ * bool Point::operator==(const Point& p), where
+ * p1.distance(p2)==0 doesn't necessarily mean that p1==p2).
+ *
+ * For example, a point could consist of a vector and a string
+ * name, where their distance measure is simply euclidean distance but to be
+ * equal they must have the same name in addition to having distance 0.
+ */
+template<class Point>
+class CoverTree
+{
+ /**
+ * Cover tree node. Consists of arbitrarily many points P, as long as
+ * they have distance 0 to each other. Keeps track of its children.
+ */
+ class CoverTreeNode
+ {
+ private:
+ //_childMap[i] is a vector of the node's children at level i
+ std::map<int,std::vector<CoverTreeNode* > > _childMap;
+ //_points is all of the points with distance 0 which are not equal.
+ std::vector<Point> _points;
+ public:
+ CoverTreeNode(const Point& p);
+ /**
+ * Returns the children of the node at level i. Note that this means
+ * the children exist in cover set i-1, not level i.
+ *
+ * Does not include the node itself, though technically every node
+ * has itself as a child in a cover tree.
+ */
+ std::vector<CoverTreeNode* > getChildren(int level) const;
+ void addChild(int level, CoverTreeNode* p);
+ void removeChild(int level, CoverTreeNode* p);
+ void addPoint(const Point& p);
+ void removePoint(const Point& p);
+ const std::vector<Point>& getPoints() { return _points; }
+ double distance(const CoverTreeNode& p) const;
+
+ bool isSingle() const;
+
+ const Point& getPoint() const;
+
+ /**
+ * Return every child of the node from any level. This is handy for
+ * the destructor.
+ */
+ std::vector<CoverTreeNode* > getAllChildren() const;
+ }; // CoverTreeNode class
+ private:
+ CoverTreeNode* _root;
+ unsigned int _numNodes;
+ int _maxLevel;//base^_maxLevel should be the max distance
+ //between any 2 points
+ int _minLevel;//A level beneath which there are no more new nodes.
+
+ std::vector<CoverTreeNode* >
+ kNearestNodes(const Point& p, const unsigned int& k);
+ /**
+ * Recursive implementation of the insert algorithm (see paper).
+ */
+ bool insert_rec(const Point& p,
+ const std::vector<std::pair<double, CoverTreeNode* > >& Qi,
+ const int& level);
+
+ /**
+ * Finds the node in Q with the minimum distance to p. Returns a
+ * pair consisting of this node and the distance.
+ */
+ std::pair<double, CoverTreeNode* >
+ distance(const Point& p,
+ const std::vector<CoverTreeNode* >& Q);
+
+
+ void remove_rec(const Point& p,
+ std::map<int,std::vector<std::pair<double,CoverTreeNode* > > >& coverSets,
+ int level,
+ bool& multi);
+
+ public:
+ static const double base = 2.0;
+
+ /**
+ * Constructs a cover tree which begins with all points in points.
+ *
+ * maxDist should be the maximum distance that any two points
+ * can have between each other. IE p.distance(q) < maxDist for all
+ * p,q that you will ever try to insert. The cover tree may be invalid
+ * if an inaccurate maxDist is given.
+ */
+
+ CoverTree(const double& maxDist,
+ const std::vector<Point>& points=std::vector<Point>());
+ ~CoverTree();
+
+ /**
+ * Just for testing/debugging. Returns true iff the cover tree satisfies the
+ * the covering tree invariants (every node in level i is greater than base^i
+ * distance from every other node, and every node in level i is less than
+ * or equal to base^i distance from its children). See the cover tree
+ * papers for details.
+ */
+ bool isValidTree();
+
+ /**
+ * Insert newPoint into the cover tree. If newPoint is already present,
+ * (that is, newPoint==p for some p already in the tree), then the tree
+ * is unchanged. If p.distance(newPoint)==0.0 but newPoint!=p, then
+ * newPoint WILL be inserted and both points may be returned in k-nearest-
+ * neighbor searches.
+ */
+ void insert(const Point& newPoint);
+
+ /**
+ * Remove point p from the cover tree. If p is not present in the tree,
+ * it will remain unchanged. Otherwise, this will remove exactly one
+ * point q from the tree satisfying p==q.
+ */
+ void remove(const Point& p);
+
+ /**
+ * Returns the k nearest points to p in order (the 0th element of the vector
+ * is closest to p, 1th is next, etc). It may return greater than k points
+ * if there is a tie for the kth place.
+ */
+ std::vector<Point> kNearestNeighbors(const Point& p, const unsigned int& k);
+
+ CoverTreeNode* getRoot() const;
+
+ /**
+ * Print the cover tree.
+ */
+ void print() const;
+}; // CoverTree class
+
+template<class Point>
+CoverTree<Point>::CoverTree(const double& maxDist,
+ const std::vector<Point>& points)
+{
+ _root=NULL;
+ _numNodes=0;
+ _maxLevel=ceilf(log(maxDist)/log(base));
+ _minLevel=_maxLevel-1;
+ typename std::vector<Point>::const_iterator it;
+ for(it=points.begin(); it!=points.end(); it++) {
+ this->insert(*it);
+ }
+}
+
+template<class Point>
+CoverTree<Point>::~CoverTree()
+{
+ if(_root==NULL) return;
+ //Get all of the root's children (from any level),
+ //delete the root, repeat for each of the children
+ std::vector<CoverTreeNode* > nodes;
+ nodes.push_back(_root);
+ while(!nodes.empty()) {
+ CoverTreeNode* byeNode = nodes[0];
+ nodes.erase(nodes.begin());
+ std::vector<CoverTreeNode* > children = byeNode->getAllChildren();
+ nodes.insert(nodes.begin(),children.begin(),children.end());
+ //std::cout << _numNodes << "\n";
+ delete byeNode;
+ _numNodes--;
+ }
+
+}
+
+template<class Point>
+std::vector<typename CoverTree<Point>::CoverTreeNode*>
+CoverTree<Point>::kNearestNodes(const Point& p, const unsigned int& k)
+{
+ if(_root==NULL) return std::vector<CoverTreeNode* >();
+ //maxDist is the kth nearest known point to p, and also the farthest
+ //point from p in the set minNodes defined below.
+ double maxDist = p.distance(_root->getPoint());
+ //minNodes stores the k nearest known points to p.
+ std::set<std::pair<double, CoverTreeNode* > > minNodes;
+
+ minNodes.insert(make_pair(maxDist,
+ _root));
+ std::vector<std::pair<double,CoverTreeNode* > >
+ Qj(1,make_pair(maxDist,_root));
+ for(int level = _maxLevel; level>=_minLevel;level--) {
+ typename std::vector<std::pair<double,CoverTreeNode* > >::const_iterator it;
+ int size = Qj.size();
+ for(int i=0; i<size; i++) {
+ std::vector<CoverTreeNode* > children =
+ Qj[i].second->getChildren(level);
+ typename std::vector<CoverTreeNode* >::const_iterator it2;
+ for(it2=children.begin(); it2!=children.end(); it2++) {
+ double d = p.distance((*it2)->getPoint());
+ if(d < maxDist || minNodes.size() < k) {
+ minNodes.insert(make_pair(d,*it2));
+ //--minNodes.end() gives us an iterator to the greatest
+ //element of minNodes.
+ if(minNodes.size() > k) minNodes.erase(--minNodes.end());
+ maxDist = (--minNodes.end())->first;
+ }
+ Qj.push_back(make_pair(d,*it2));
+ }
+ }
+ double sep = maxDist + pow(base, level);
+ size = Qj.size();
+ for(int i=0; i<size; i++) {
+ if(Qj[i].first > sep) {
+ //quickly removes an element from a vector w/o preserving order.
+ Qj[i]=Qj.back();
+ Qj.pop_back();
+ size--; i--;
+ }
+ }
+ }
+ std::vector<CoverTreeNode* > kNN;
+ typename std::set<std::pair<double, CoverTreeNode* > >::iterator it;
+ for(it=minNodes.begin();it!=minNodes.end();it++) {
+ kNN.push_back(it->second);
+ }
+ return kNN;
+}
+template<class Point>
+bool CoverTree<Point>::insert_rec(const Point& p,
+ const std::vector<std::pair<double, CoverTreeNode* > >& Qi,
+ const int& level)
+{
+ std::vector<std::pair<double, CoverTreeNode*> > Qj;
+ double sep = pow(base,level);
+ double minDist = DBL_MAX;
+ std::pair<double,CoverTreeNode*> minQiDist(DBL_MAX,NULL);
+ typename std::vector<std::pair<double, CoverTreeNode*> >::const_iterator it;
+ for(it=Qi.begin(); it!=Qi.end(); it++) {
+ if(it->first<minQiDist.first) minQiDist = *it;
+ if(it->first<minDist) minDist=it->first;
+ if(it->first<=sep) Qj.push_back(*it);
+ std::vector<CoverTreeNode*> children = it->second->getChildren(level);
+ typename std::vector<CoverTreeNode*>::const_iterator it2;
+ for(it2=children.begin();it2!=children.end();it2++) {
+ double d = p.distance((*it2)->getPoint());
+ if(d<minDist) minDist = d;
+ if(d<=sep) {
+ Qj.push_back(make_pair(d,*it2));
+ }
+ }
+ }
+ //std::cout << "level: " << level << ", sep: " << sep << ", dist: " << minQDist.first << "\n";
+ if(minDist > sep) {
+ return true;
+ } else {
+ bool found = insert_rec(p,Qj,level-1);
+ //std::pair<double,CoverTreeNode* > minQiDist = distance(p,Qi);
+ if(found && minQiDist.first <= sep) {
+ if(level-1<_minLevel) _minLevel=level-1;
+ minQiDist.second->addChild(level,
+ new CoverTreeNode(p));
+ //std::cout << "parent is ";
+ //minQiDist.second->getPoint().print();
+ _numNodes++;
+ return false;
+ } else {
+ return found;
+ }
+ }
+}
+
+template<class Point>
+void CoverTree<Point>::remove_rec(const Point& p,
+ std::map<int,std::vector<std::pair<double,CoverTreeNode* > > >& coverSets,
+ int level,
+ bool& multi)
+{
+ std::vector<std::pair<double, CoverTreeNode* > >& Qi = coverSets[level];
+ std::vector<std::pair<double, CoverTreeNode* > >& Qj = coverSets[level-1];
+ double minDist = DBL_MAX;
+ CoverTreeNode* minNode = _root;
+ CoverTreeNode* parent = 0;
+ double sep = pow(base, level);
+ typename std::vector<std::pair<double, CoverTreeNode* > >::const_iterator it;
+ //set Qj to be all children q of Qi such that p.distance(q)<=sep
+ //and also keep track of the minimum distance from p to a node in Qj
+ //note that every node has itself as a child, but the
+ //getChildren function only returns non-self-children.
+ for(it=Qi.begin();it!=Qi.end();it++) {
+ std::vector<CoverTreeNode* > children = it->second->getChildren(level);
+ double dist = it->first;
+ if(dist<minDist) {
+ minDist = dist;
+ minNode = it->second;
+ }
+ if(dist <= sep) {
+ Qj.push_back(*it);
+ }
+ typename std::vector<CoverTreeNode* >::iterator it2;
+ for(it2=children.begin();it2!=children.end();it2++) {
+ dist = p.distance((*it2)->getPoint());
+ if(dist<minDist) {
+ minDist = dist;
+ minNode = *it2;
+ if(dist == 0.0) parent = it->second;
+ }
+ if(dist <= sep) {
+ Qj.push_back(make_pair(dist,*it2));
+ }
+ }
+ }
+ if(level>_minLevel) remove_rec(p,coverSets,level-1,multi);
+ if(minDist == 0.0) {//if minDist is 0.0 then minNode must be removed.
+ if(!minNode->isSingle()) {
+ minNode->removePoint(p);
+ multi=true;
+ return;
+ }
+ //the multi flag indicates the point we removed is from a
+ //node containing multiple points, and we have removed it.
+ if(multi) return;
+ if(parent!=NULL) parent->removeChild(level, minNode);
+ std::vector<CoverTreeNode* > children = minNode->getChildren(level-1);
+ std::vector<std::pair<double, CoverTreeNode* > >& Q = coverSets[level-1];
+ if(Q.size()==1 && Q[0].second==minNode) {
+ Q.pop_back();
+ } else {
+ for(unsigned int i=0;i<Q.size();i++) {
+ if(Q[i].second==minNode) {
+ Q[i]=Q.back();
+ Q.pop_back();
+ break;
+ }
+ }
+ }
+ typename std::vector<CoverTreeNode* >::const_iterator it;
+ for(it=children.begin();it!=children.end();it++) {
+ int i = level-1;
+ Point q = (*it)->getPoint();
+ double minDQ = DBL_MAX;
+ CoverTreeNode* minDQNode;
+ double sep = pow(base,i);
+ bool br=false;
+ while(true) {
+ std::vector<std::pair<double, CoverTreeNode* > >&
+ Q = coverSets[i];
+ typename std::vector<std::pair<double, CoverTreeNode* > >::const_iterator it2;
+ minDQ = DBL_MAX;
+ for(it2=Q.begin();it2!=Q.end();it2++) {
+ double d = q.distance(it2->second->getPoint());
+ if(d<minDQ) {
+ minDQ = d;
+ minDQNode = it2->second;
+ if(d <=sep) {
+ br=true;
+ break;
+ }
+ }
+ }
+ minDQ=DBL_MAX;
+ if(br) break;
+ Q.push_back(make_pair((*it)->distance(p),*it));
+ i++;
+ sep = pow(base,i);
+ }
+ //minDQNode->getPoint().print();
+ //std::cout << " is level " << i << " parent of ";
+ //(*it)->getPoint().print();
+ minDQNode->addChild(i,*it);
+ }
+ if(parent!=NULL) {
+ delete minNode;
+ _numNodes--;
+ }
+ }
+}
+
+template<class Point>
+std::pair<double, typename CoverTree<Point>::CoverTreeNode* >
+CoverTree<Point>::distance(const Point& p,
+ const std::vector<CoverTreeNode* >& Q)
+{
+ double minDist = DBL_MAX;
+ CoverTreeNode* minNode;
+ typename std::vector<CoverTreeNode* >::const_iterator it;
+ for(it=Q.begin();it!=Q.end();it++) {
+ double dist = p.distance((*it)->getPoint());
+ if(dist < minDist) {
+ minDist = dist;
+ minNode = *it;
+ }
+ }
+ return std::pair<double, CoverTreeNode* >(minDist,minNode);
+}
+
+template<class Point>
+void CoverTree<Point>::insert(const Point& newPoint)
+{
+ if(_root==NULL) {
+ _root = new CoverTreeNode(newPoint);
+ _numNodes=1;
+ return;
+ }
+ //TODO: this is pretty inefficient, there may be a better way
+ //to check if the node already exists...
+ CoverTreeNode* n = kNearestNodes(newPoint,1)[0];
+ if(newPoint.distance(n->getPoint())==0.0) {
+ n->addPoint(newPoint);
+ } else {
+ //insert_rec acts under the assumption that there are no nodes with
+ //distance 0 to newPoint in the cover tree (the previous lines check it)
+ insert_rec(newPoint,
+ std::vector<std::pair<double, CoverTreeNode* > >
+ (1,make_pair(_root->distance(newPoint),_root)),
+ _maxLevel);
+ }
+}
+
+template<class Point>
+void CoverTree<Point>::remove(const Point& p)
+{
+ //Most of this function's code is for the special case of removing the root
+ if(_root==NULL) return;
+ bool removingRoot=p.distance(_root->getPoint())==0.0;
+ CoverTreeNode* newRoot=NULL;
+ if(removingRoot) {
+ if(_numNodes==1) {
+ //removing the last node...
+ delete _root;
+ _numNodes--;
+ _root=NULL;
+ return;
+ } else {
+ for(int i=_maxLevel;i>_minLevel;i--) {
+ if(!(_root->getChildren(i).empty())) {
+ newRoot = _root->getChildren(i).back();
+ _root->removeChild(i,newRoot);
+ break;
+ }
+ }
+ }
+ }
+ std::map<int, std::vector<std::pair<double, CoverTreeNode* > > > coverSets;
+ coverSets[_maxLevel].push_back(make_pair(_root->distance(p),_root));
+ if(removingRoot)
+ coverSets[_maxLevel].push_back(make_pair(newRoot->distance(p),newRoot));
+ bool multi = false;
+ remove_rec(p,coverSets,_maxLevel,multi);
+ if(removingRoot) {
+ delete _root;
+ _numNodes--;
+ _root=newRoot;
+ }
+}
+
+template<class Point>
+std::vector<Point> CoverTree<Point>::kNearestNeighbors(const Point& p,
+ const unsigned int& k)
+{
+ if(_root==NULL) return std::vector<Point>();
+ std::vector<CoverTreeNode* > v = kNearestNodes(p, k);
+ std::vector<Point> kNN;
+ typename std::vector<CoverTreeNode* >::iterator it;
+ for(it=v.begin();it!=v.end();it++) {
+ const std::vector<Point>& p = (*it)->getPoints();
+ kNN.insert(kNN.end(),p.begin(),p.end());
+ if(kNN.size() >= k) break;
+ }
+ return kNN;
+}
+
+template<class Point>
+void CoverTree<Point>::print() const
+{
+ int d = _maxLevel-_minLevel+1;
+ std::vector<CoverTreeNode* > Q;
+ Q.push_back(_root);
+ for(int i=0;i<d;i++) {
+ std::cout << "LEVEL " << _maxLevel-i << "\n";
+ typename std::vector<CoverTreeNode* >::const_iterator it;
+ for(it=Q.begin();it!=Q.end();it++) {
+ (*it)->getPoint().print();
+ std::vector<CoverTreeNode* >
+ children = (*it)->getChildren(_maxLevel-i);
+ typename std::vector<CoverTreeNode* >::const_iterator it2;
+ for(it2=children.begin();it2!=children.end();it2++) {
+ std::cout << " ";
+ (*it2)->getPoint().print();
+ }
+ }
+ std::vector<CoverTreeNode* > newQ;
+ for(it=Q.begin();it!=Q.end();it++) {
+ std::vector<CoverTreeNode* >
+ children = (*it)->getChildren(_maxLevel-i);
+ newQ.insert(newQ.end(),children.begin(),children.end());
+ }
+ Q.insert(Q.end(),newQ.begin(),newQ.end());
+ std::cout << "\n\n";
+ }
+}
+
+template<class Point>
+typename CoverTree<Point>::CoverTreeNode* CoverTree<Point>::getRoot() const
+{
+ return _root;
+}
+
+template<class Point>
+CoverTree<Point>::CoverTreeNode::CoverTreeNode(const Point& p) {
+ _points.push_back(p);
+}
+
+template<class Point>
+std::vector<typename CoverTree<Point>::CoverTreeNode*>
+CoverTree<Point>::CoverTreeNode::getChildren(int level) const
+{
+ typename std::map<int,std::vector<CoverTreeNode* > >::const_iterator
+ it = _childMap.find(level);
+ if(it!=_childMap.end()) {
+ return it->second;
+ }
+ return std::vector<CoverTreeNode* >();
+}
+
+template<class Point>
+void CoverTree<Point>::CoverTreeNode::addChild(int level, CoverTreeNode* p)
+{
+ _childMap[level].push_back(p);
+}
+
+template<class Point>
+void CoverTree<Point>::CoverTreeNode::removeChild(int level, CoverTreeNode* p)
+{
+ std::vector<CoverTreeNode* >& v = _childMap[level];
+ for(unsigned int i=0;i<v.size();i++) {
+ if(v[i]==p) {
+ v[i]=v.back();
+ v.pop_back();
+ break;
+ }
+ }
+}
+
+template<class Point>
+void CoverTree<Point>::CoverTreeNode::addPoint(const Point& p)
+{
+ typename std::vector<Point>::iterator it;
+ for(it=_points.begin();it!=_points.end();it++) {
+ if(*it==p) return;
+ }
+ _points.push_back(p);
+}
+
+template<class Point>
+void CoverTree<Point>::CoverTreeNode::removePoint(const Point& p)
+{
+ typename std::vector<Point>::iterator it;
+ for(it=_points.begin();it!=_points.end();it++) {
+ if(*it==p) {
+ _points.erase(it);
+ return;
+ }
+ }
+}
+
+template<class Point>
+double CoverTree<Point>::CoverTreeNode::distance(const CoverTreeNode& p) const
+{
+ return _points[0].distance(p.getPoint());
+}
+
+template<class Point>
+bool CoverTree<Point>::CoverTreeNode::isSingle() const
+{
+ if(_points.size()>1) return false;
+ return true;
+}
+
+template<class Point>
+const Point& CoverTree<Point>::CoverTreeNode::getPoint() const { return _points[0]; }
+
+template<class Point>
+std::vector<typename CoverTree<Point>::CoverTreeNode*>
+CoverTree<Point>::CoverTreeNode::getAllChildren() const
+{
+ std::vector<CoverTreeNode* > children;
+ typename std::map<int,std::vector<CoverTreeNode* > >::const_iterator it;
+ typename std::vector<CoverTreeNode* >::const_iterator it2;
+ for(it=_childMap.begin();it!=_childMap.end();it++) {
+ children.insert(children.end(), it->second.begin(), it->second.end());
+ }
+ return children;
+}
+
+template<class Point>
+bool CoverTree<Point>::isValidTree() {
+ if(_numNodes==0) {
+ if(_root==NULL) return true;
+ else return false;
+ }
+ std::vector<CoverTreeNode* > nodes;
+ nodes.push_back(_root);
+ for(int i=_maxLevel;i>_minLevel;i--) {
+ double sep = pow(base,i);
+ typename std::vector<CoverTreeNode* >::iterator it, it2;
+ //verify separation invariant of cover tree: for each level,
+ //every point is farther than base^level away
+ for(it=nodes.begin(); it!=nodes.end(); it++) {
+ for(it2=nodes.begin(); it2!=nodes.end(); it2++) {
+ double dist=(*it)->distance((*it2)->getPoint());
+ if(dist<=sep && dist!=0.0) {
+ std::cout << "Level " << i << " Separation invariant failed.\n";
+ return false;
+ }
+ }
+ }
+ std::vector<CoverTreeNode* > allChildren;
+ for(it=nodes.begin(); it!=nodes.end(); it++) {
+ std::vector<CoverTreeNode* > children = (*it)->getChildren(i);
+ //verify covering tree invariant: the children of node n at level
+ //i are no further than base^i away
+ for(it2=children.begin(); it2!=children.end(); it2++) {
+ double dist = (*it2)->distance((*it)->getPoint());
+ if(dist>sep) {
+ std::cout << "Level" << i << " covering tree invariant failed.n";
+ return false;
+ }
+ }
+ allChildren.insert
+ (allChildren.end(),children.begin(),children.end());
+ }
+ nodes.insert(nodes.begin(),allChildren.begin(),allChildren.end());
+ }
+ return true;
+}
+#endif // _COVER_TREE_H
+
42 Cover_Tree_Point.cc
@@ -0,0 +1,42 @@
+#include "Cover_Tree_Point.h"
+#include <vector>
+#include <iostream>
+#include <cmath>
+
+using namespace std;
+
+double CoverTreePoint::distance(const CoverTreePoint& p) const {
+ static int timescalled = 0;
+ //if(timescalled%1000000==0) cout << timescalled << "\n";
+ timescalled++;
+ const vector<double>& vec=p.getVec();
+ double dist=0;
+ int lim = vec.size();
+ for(int i=0; i<lim;i++) {
+ double d = vec[i]-_vec[i];
+ dist+=d*d;
+ }
+ dist=sqrt(dist);
+ return dist;
+}
+
+const vector<double>& CoverTreePoint::getVec() const {
+ return _vec;
+}
+
+const char& CoverTreePoint::getChar() const {
+ return _name;
+}
+
+void CoverTreePoint::print() const {
+ vector<double>::const_iterator it;
+ cout << "point " << _name << ": ";
+ for(it=_vec.begin();it!=_vec.end();it++) {
+ cout << *it << " ";
+ }
+ cout << "\n";
+}
+
+bool CoverTreePoint::operator==(const CoverTreePoint& p) const {
+ return (this->distance(p)==0.0 && _name==p.getChar());
+}
24 Cover_Tree_Point.h
@@ -0,0 +1,24 @@
+#ifndef _COVER_TREE_POINT_H
+#define _COVER_TREE_POINT_H
+
+#include <vector>
+
+/**
+ * A simple point class containing a vector of doubles and a single char name.
+ */
+class CoverTreePoint
+{
+private:
+ std::vector<double> _vec;
+ char _name;
+public:
+ CoverTreePoint(std::vector<double> v, char name) : _vec(v), _name(name) {}
+ double distance(const CoverTreePoint& p) const;
+ const std::vector<double>& getVec() const;
+ const char& getChar() const;
+ void print() const;
+ bool operator==(const CoverTreePoint&) const;
+};
+
+#endif // _COVER_TREE_POINT_H
+
12 Makefile
@@ -0,0 +1,12 @@
+FLAGS=-Wall -O3 -ffast-math -funroll-loops
+
+all: test
+
+Cover_Tree_Point.o: Cover_Tree_Point.h Cover_Tree_Point.cc
+ g++ -c $(FLAGS) Cover_Tree_Point.cc
+
+test: test.cc Cover_Tree.h Cover_Tree_Point.o Cover_Tree_Point.h Cover_Tree_Point.cc
+ g++ $(FLAGS) -o test test.cc Cover_Tree.h Cover_Tree_Point.o
+
+clean:
+ rm *.o
44 README
@@ -0,0 +1,44 @@
+This is a C++ implementation of the cover tree datastructure.
+
+Relevant links:
+https://secure.wikimedia.org/wikipedia/en/wiki/Cover_tree - Wikipedia's page
+on cover trees.
+http://hunch.net/~jl/projects/cover_tree/cover_tree.html - John Langford's (one
+of the inventors of cover trees) page on cover trees with links to papers. This
+implementation implements the cover tree algorithms for insert, removal, and
+k-nearest-neighbor search as described in the papers.
+
+To use the Cover Tree, you must implement your own Point class. CoverTreePoint
+is provided for testing and as an example. Your Point class must implement the
+following functions:
+
+double distance(const YourPoint& p);
+bool operator==(const YourPoint& p);
+and optionally (for debugging/printing only):
+void print();
+
+The distance function must be a Metric, meaning (from Wikipedia):
+1: d(x, y) = 0 if and only if x = y
+2: d(x, y) = d(y, x) (symmetry)
+3: d(x, z) =< d(x, y) + d(y, z) (subadditivity / triangle inequality).
+
+See https://secure.wikimedia.org/wikipedia/en/wiki/Metric_%28mathematics%29
+for details.
+
+actually, 1 does not exactly need to hold for this implementation; you can
+provide, for example, names for your points which are unrelated to distance
+but important for equality. You can insert multiple points with distance 0 to
+each other and the tree will keep track of them, but you cannot insert multiple
+points that are equal to each other; attempting to insert a point that
+already exists in the tree will not alter the tree at all.
+
+If you do not want to allow multiple nodes with distance 0, then just make
+your equality operator always return true when distance is 0. If you want
+to allow multiple points of distance 0 but do not want to add a name or other
+identifier to them, you can have your equality operator always return false.
+
+TODO:
+-The papers describe batch insert and batch-nearest-neighbors algorithms which
+may be worth implementing.
+-Try using a third "upper bound" argument for distance functions, beyond which
+the distance does not need to be calculated, to improve efficiency in practice.
147 test.cc
@@ -0,0 +1,147 @@
+#include "Cover_Tree.h"
+#include "Cover_Tree_Point.h"
+#include <vector>
+#include <iostream>
+#include <cstdlib>
+
+using namespace std;
+
+void testTree() {
+ vector<double> a;
+ a.push_back(1.0);
+ CoverTree<CoverTreePoint> cTree(10);
+ cTree.insert(CoverTreePoint(a,'a'));
+ a[0]=2.1; cTree.insert(CoverTreePoint(a,'a'));
+ a[0]=3.2; cTree.insert(CoverTreePoint(a,'a'));
+ a[0]=4.1; cTree.insert(CoverTreePoint(a,'a'));
+ a[0]=1.1; cTree.insert(CoverTreePoint(a,'a'));
+ a[0]=2.5; cTree.insert(CoverTreePoint(a,'a'));
+ a[0]=3.1; cTree.insert(CoverTreePoint(a,'a'));
+ if(cTree.isValidTree()) cout << "Insert test: \t\t\t\tPassed\n";
+ else cout << "Insert test: \t\t\t\tFailed\n";
+
+ a[0]=2.0; // the 5 nearest points to this are 2.1, 2.5, 1.1, 1, 3.1
+ vector<CoverTreePoint>
+ points = cTree.kNearestNeighbors(CoverTreePoint(a,'a'),5);
+ bool kNNGood=true;
+ //for(int i =0; i<points.size(); i++) {
+ // points[i].print();
+ //}
+ a[0]=2.1; if(!(CoverTreePoint(a,'a')==points[0])) kNNGood=false;
+ a[0]=2.5; if(!(CoverTreePoint(a,'a')==points[1])) kNNGood=false;
+ a[0]=1.1; if(!(CoverTreePoint(a,'a')==points[2])) kNNGood=false;
+ a[0]=1.0; if(!(CoverTreePoint(a,'a')==points[3])) kNNGood=false;
+ a[0]=3.1; if(!(CoverTreePoint(a,'a')==points[4])) kNNGood=false;
+ if(kNNGood) cout << "KNN test: \t\t\t\tPassed\n";
+ else cout << "KNN test: \t\t\t\tFailed\n";
+
+ cTree.insert(CoverTreePoint(a,'b'));
+ cTree.insert(CoverTreePoint(a,'c'));
+
+
+ points = cTree.kNearestNeighbors(CoverTreePoint(a,'a'), 1);
+ //there should be a three-way tie since there are 3 nodes with distance 0
+ if(points.size()==3 && points[2].distance(CoverTreePoint(a,'a'))==0.0)
+ cout << "Multiple 0 distance points test: \tPassed\n";
+ else cout << "Multiple 0 distance points test: \tFailed\n";
+
+ cTree.remove(CoverTreePoint(a,'b'));
+ points = cTree.kNearestNeighbors(CoverTreePoint(a,'a'), 1);
+ //there should be a two-way tie now since one was removed
+ if(points.size()==2 && points[1].distance(CoverTreePoint(a,'a'))==0.0)
+ cout << "Multiple 0 distance points removal test:Passed\n";
+ else cout << "Multiple 0 distance points removal test:Failed\n";
+
+ a[0]=2.124; cTree.remove(CoverTreePoint(a,'a'));
+ a[0]=4.683; cTree.remove(CoverTreePoint(a,'a'));
+ a[0]=9.123; cTree.remove(CoverTreePoint(a,'a'));
+ if(cTree.isValidTree())
+ cout << "Remove nonexistent point test: \t\tPassed\n";
+ else cout << "Remove nonexistent point test: \t\tFailed\n";
+
+ a[0]=3.2; cTree.remove(CoverTreePoint(a,'a'));
+ a[0]=1.1; cTree.remove(CoverTreePoint(a,'a'));
+ a[0]=2.5; cTree.remove(CoverTreePoint(a,'a'));
+ if(cTree.isValidTree()) cout << "Remove test: \t\t\t\tPassed\n";
+ else cout << "Remove test: \t\t\t\tPassed\n";
+
+ a[0]=1.0; cTree.remove(CoverTreePoint(a,'a'));
+ if(cTree.isValidTree()) cout << "Remove root test: \t\t\tPassed\n";
+ else cout << "Remove root test: \t\t\tPassed\n";
+
+ vector<double> start;
+ for(int j=0;j<5;j++) start.push_back((double)rand()/(double)RAND_MAX);
+
+ vector<CoverTreePoint> initVec;
+ initVec.push_back(CoverTreePoint(start,'a'));
+
+ CoverTree<CoverTreePoint>
+ cTree2(10,initVec);
+ cTree2.remove(CoverTreePoint(start,'a'));//Now the tree has no nodes...
+ //make sure it can handle trying to remove a node when
+ //it has no nodes to begin with...
+ cTree2.remove(CoverTreePoint(start,'a'));
+
+ points = vector<CoverTreePoint>();
+ for(int i=0;i<500;i++) {
+ vector<double> a;
+ for(int j=0;j<5;j++) {
+ a.push_back((double)rand()/(double)RAND_MAX);
+ }
+ cTree2.insert(CoverTreePoint(a,'a'));
+ points.push_back(CoverTreePoint(a,'a'));
+ }
+ if(cTree2.isValidTree()) cout << "500 random inserts test: \t\tPassed\n";
+ else cout << "500 random inserts test: \t\tPassed\n";
+
+ bool NNGood=true;
+ for(int i=0;i<100;i++) {
+ vector<CoverTreePoint>
+ v = cTree2.kNearestNeighbors(points[i],1);
+ if(v[0].distance(points[i])!=0.0) NNGood=false;
+ }
+ if(NNGood) cout << "Nearest Neighbor test: \t\t\tPassed\n";
+ else cout << "Nearest Neighbor test: \t\t\tFailed\n";
+
+ for(int i=0;i<490;i++) {
+ cTree2.remove(points[i]);
+ }
+ if(cTree2.isValidTree()) cout << "Remove random test: \t\t\tPassed\n";
+ else cout << "Remove random test: \t\t\tFailed\n";
+}
+
+void bigTest(unsigned int numNodes, unsigned int numDimensions){
+ vector<CoverTreePoint> points;
+ for(unsigned int i=0;i<numNodes;i++) {
+ vector<double> a;
+ for(unsigned int j=0;j<numDimensions;j++) {
+ a.push_back((double)rand()/(double)RAND_MAX);
+ }
+ points.push_back(CoverTreePoint(a,'a'));
+ }
+ cout << "Building Cover Tree with " << numNodes << " nodes\n";
+ CoverTree<CoverTreePoint> cTree(51,points);
+ cout << "Cover tree built.\n";
+
+ cout << "2000 random NN searches beginning...\n";
+ for(int i=0;i<2000;i++) {
+ vector<double> a;
+ for(unsigned int j=0;j<numDimensions;j++) {
+ a.push_back((double)rand()/(double)RAND_MAX);
+ }
+ cTree.kNearestNeighbors(CoverTreePoint(a,'a'),1);
+ }
+ cout << "NN searches done.\n";
+
+ cout << "Removing all nodes...\n";
+ for(unsigned int i=0; i<numNodes; i++) {
+ cTree.remove(points[i]);
+ }
+ cout << "Removal done.\n";
+}
+int main()
+{
+ testTree();
+ bigTest(3000,50);
+ return 0;
+}
Please sign in to comment.
Something went wrong with that request. Please try again.