Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] KDTree optimization #7

Merged
merged 21 commits into from Oct 23, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -18,6 +18,9 @@

import lombok.val;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.custom.KnnMinDistance;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.primitives.Pair;

import java.io.Serializable;
import java.util.ArrayList;
Expand All @@ -28,113 +31,114 @@
*/
public class HyperRect implements Serializable {

private List<Interval> points;
//private List<Interval> points;
private float[] lowerEnds;
private float[] higherEnds;
private INDArray lowerEndsIND;
private INDArray higherEndsIND;

public HyperRect(float[] lowerEndsIn, float[] higherEndsIn) {
this.lowerEnds = new float[lowerEndsIn.length];
this.higherEnds = new float[lowerEndsIn.length];
System.arraycopy(lowerEndsIn, 0 , this.lowerEnds, 0, lowerEndsIn.length);
System.arraycopy(higherEndsIn, 0 , this.higherEnds, 0, higherEndsIn.length);
lowerEndsIND = Nd4j.createFromArray(lowerEnds);
higherEndsIND = Nd4j.createFromArray(higherEnds);
}

public HyperRect(List<Interval> points) {
//this.points = points;
this.points = new ArrayList<>(points.size());
for (int i = 0; i < points.size(); ++i) {
Interval newInterval = new Interval(points.get(i).lower, points.get(i).higher);
this.points.add(newInterval);
}
public HyperRect(float[] point) {
this(point, point);
}

public HyperRect(Pair<float[], float[]> ends) {
this(ends.getFirst(), ends.getSecond());
}


public void enlargeTo(INDArray point) {
for (int i = 0; i < points.size(); i++)
points.get(i).enlarge(point.getDouble(i));
float[] pointAsArray = point.toFloatVector();
for (int i = 0; i < lowerEnds.length; i++) {
float p = pointAsArray[i];
if (lowerEnds[i] > p)
lowerEnds[i] = p;
else if (higherEnds[i] < p)
higherEnds[i] = p;
}
}


public static List<Interval> point(INDArray vector) {
List<Interval> ret = new ArrayList<>();
public static Pair<float[],float[]> point(INDArray vector) {
Pair<float[],float[]> ret = new Pair<>();
float[] curr = new float[(int)vector.length()];
for (int i = 0; i < vector.length(); i++) {
double curr = vector.getDouble(i);
ret.add(new Interval(curr, curr));
curr[i] = vector.getFloat(i);
}
ret.setFirst(curr);
ret.setSecond(curr);
return ret;
}


public List<Boolean> contains(INDArray hPoint) {
/*public List<Boolean> contains(INDArray hPoint) {
List<Boolean> ret = new ArrayList<>();
for (int i = 0; i < hPoint.length(); i++)
ret.add(points.get(i).contains(hPoint.getDouble(i)));
return ret;
}

public double minDistance(INDArray hPoint) {
double ret = 0.0;
for (int i = 0; i < hPoint.length(); i++) {
double p = hPoint.getDouble(i);
Interval interval = points.get(i);
if (!interval.contains(p)) {
if (p < interval.lower)
ret += Math.pow((p - interval.lower), 2);
else
ret += Math.pow((p - interval.higher), 2);
}
ret.add(lowerEnds[i] <= hPoint.getDouble(i) &&
higherEnds[i] >= hPoint.getDouble(i));
}

ret = Math.pow(ret, 0.5);
return ret;
}*/

public double minDistance(INDArray hPoint, INDArray output) {
Nd4j.exec(new KnnMinDistance(hPoint, lowerEndsIND, higherEndsIND, output));
return output.getFloat(0);

/*double ret = 0.0;
double[] pointAsArray = hPoint.toDoubleVector();
for (int i = 0; i < pointAsArray.length; i++) {
double p = pointAsArray[i];
if (!(lowerEnds[i] <= p || higherEnds[i] <= p)) {
if (p < lowerEnds[i])
ret += Math.pow((p - lowerEnds[i]), 2);
else
ret += Math.pow((p - higherEnds[i]), 2);
}
}
ret = Math.pow(ret, 0.5);
return ret;*/
}

public HyperRect getUpper(INDArray hPoint, int desc) {
Interval interval = points.get(desc);
double d = hPoint.getDouble(desc);
if (interval.higher < d)
//Interval interval = points.get(desc);
float higher = higherEnds[desc];
float d = hPoint.getFloat(desc);
if (higher < d)
return null;
HyperRect ret = new HyperRect(new ArrayList<>(points));
Interval i2 = ret.points.get(desc);
if (i2.lower < d)
i2.lower = d;
HyperRect ret = new HyperRect(lowerEnds,higherEnds);
if (ret.lowerEnds[desc] < d)
ret.lowerEnds[desc] = d;
return ret;
}

public HyperRect getLower(INDArray hPoint, int desc) {
Interval interval = points.get(desc);
double d = hPoint.getDouble(desc);
if (interval.lower > d)
//Interval interval = points.get(desc);
float lower = lowerEnds[desc];
float d = hPoint.getFloat(desc);
if (lower > d)
return null;
HyperRect ret = new HyperRect(new ArrayList<>(points));
Interval i2 = ret.points.get(desc);
if (i2.higher > d)
i2.higher = d;
HyperRect ret = new HyperRect(lowerEnds,higherEnds);
//Interval i2 = ret.points.get(desc);
if (ret.higherEnds[desc] > d)
ret.higherEnds[desc] = d;
return ret;
}

@Override
public String toString() {
String retVal = "";
retVal += "[";
for (val point : points) {
retVal += "(" + point.lower + " - " + point.higher + ") ";
for (int i = 0; i < lowerEnds.length; ++i) {
retVal += "(" + lowerEnds[i] + " - " + higherEnds[i] + ") ";
}
retVal += "]";
return retVal;
}

public static class Interval {
private double lower, higher;

public Interval(double lower, double higher) {
this.lower = lower;
this.higher = higher;
}

public boolean contains(double point) {
return lower <= point || point <= higher;

}

public void enlarge(double p) {
if (lower > p)
lower = p;
else if (higher < p)
higher = p;
}

}

}
Expand Up @@ -56,7 +56,7 @@ public void insert(INDArray point) {

if (root == null) {
root = new KDNode(point);
rect = new HyperRect(HyperRect.point(point));
rect = new HyperRect(/*HyperRect.point(point)*/ point.toFloatVector());
} else {
int disc = 0;
KDNode node = root;
Expand Down Expand Up @@ -125,38 +125,43 @@ public INDArray delete(INDArray point) {
return node.getPoint();
}

// Share this data for recursive calls of "knn"
private float currentDistance;
private INDArray currentPoint;
private INDArray minDistance = Nd4j.scalar(0.f);


public List<Pair<Double, INDArray>> knn(INDArray point, double distance) {
List<Pair<Double, INDArray>> best = new ArrayList<>();
knn(root, point, rect, distance, best, 0);
Collections.sort(best, new Comparator<Pair<Double, INDArray>>() {
public List<Pair<Float, INDArray>> knn(INDArray point, float distance) {
List<Pair<Float, INDArray>> best = new ArrayList<>();
currentDistance = distance;
currentPoint = point;
knn(root, rect, best, 0);
Collections.sort(best, new Comparator<Pair<Float, INDArray>>() {
@Override
public int compare(Pair<Double, INDArray> o1, Pair<Double, INDArray> o2) {
return Double.compare(o1.getKey(), o2.getKey());
public int compare(Pair<Float, INDArray> o1, Pair<Float, INDArray> o2) {
return Float.compare(o1.getKey(), o2.getKey());
}
});

return best;
}


private void knn(KDNode node, INDArray point, HyperRect rect, double dist, List<Pair<Double, INDArray>> best,
int _disc) {
if (node == null || rect == null || rect.minDistance(point) > dist)
private void knn(KDNode node, HyperRect rect, List<Pair<Float, INDArray>> best, int _disc) {
if (node == null || rect == null || rect.minDistance(currentPoint, minDistance) > currentDistance)
return;
int _discNext = (_disc + 1) % dims;
double distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(point,node.point)).getFinalResult()
.doubleValue();
float distance = Nd4j.getExecutioner().execAndReturn(new EuclideanDistance(currentPoint,node.point, minDistance)).getFinalResult()
.floatValue();

if (distance <= dist) {
if (distance <= currentDistance) {
best.add(Pair.of(distance, node.getPoint()));
}

HyperRect lower = rect.getLower(node.point, _disc);
HyperRect upper = rect.getUpper(node.point, _disc);
knn(node.getLeft(), point, lower, dist, best, _discNext);
knn(node.getRight(), point, upper, dist, best, _discNext);
knn(node.getLeft(), lower, best, _discNext);
knn(node.getRight(), upper, best, _discNext);
}

/**
Expand All @@ -171,7 +176,7 @@ public Pair<Double, INDArray> nn(INDArray point) {

private Pair<Double, INDArray> nn(KDNode node, INDArray point, HyperRect rect, double dist, INDArray best,
int _disc) {
if (node == null || rect.minDistance(point) > dist)
if (node == null || rect.minDistance(point, minDistance) > dist)
return Pair.of(Double.POSITIVE_INFINITY, null);

int _discNext = (_disc + 1) % dims;
Expand Down