Skip to content

Commit

Permalink
All but unit tests..woo.
Browse files Browse the repository at this point in the history
  • Loading branch information
cestella committed Feb 8, 2012
1 parent ba19eaa commit 3ec0c9b
Showing 1 changed file with 64 additions and 7 deletions.
71 changes: 64 additions & 7 deletions spatial_search/src/main/java/com/caseystella/KNN.java
Expand Up @@ -12,20 +12,66 @@

import com.google.common.collect.Iterables;

import com.google.common.primitives.Longs;

import com.sun.tools.javac.util.Pair;

public class KNN
{
public static class Payload
{
private RealVector vector;
private byte[] payload;

public Payload(RealVector vector, byte[] payload)
{
this.vector = vector;
this.payload = payload;
}

/**
* Gets the vector for this instance.
*
* @return The vector.
*/
public RealVector getVector()
{
return vector;
}

/**
* Gets the payload for this instance.
*
* @return The payload.
*/
public byte[] getPayload()
{
return payload;
}


}
public static interface IHashCreator
{
public Function<RealVector, Long> construct(int hashDimension, long seed);
}

public static interface IBackingStore
{
public void persist(long key, Payload payload);
public Iterable<Payload> getBucket(long key);
}

private Iterable<Function<RealVector, Long>> hashes;
private IBackingStore backingStore;
public KNN( int numHashes
, int hashDimension
, long seed
, IHashCreator creator
, IBackingStore backingStore
)
{
this.backingStore = backingStore;
List<Function<RealVector, Long > > hashList = new ArrayList<Function<RealVector,Long>> (numHashes);
for(int i = 0;i < numHashes;++i)
{
Expand All @@ -34,19 +80,30 @@ public KNN( int numHashes
hashes = hashList;
}

public Iterable<byte[]> query(RealVector q, Function<RealVector, Double> metric, double limit)
public Iterable< Payload> query(RealVector q, Function<RealVector, Double> metric, double limit)
{
List<Payload> results = new ArrayList<Payload>();
for(Function<RealVector, Long> hash : hashes)
{
//find the thing in the bucket

Iterable<Payload> values = backingStore.getBucket(hash.apply(q));
for(Payload value : values)
{
if(metric.apply(value.getVector()) < limit)
{
results.add(value);
}
}
}
return results;
}

public <T> Iterable<T> query( RealVector q
, Function<byte[], T> transformer
)

public void insert(Payload payload)
{
return Iterables.transform(query(q), transformer);
for(Function<RealVector, Long> hash : hashes)
{
long key = hash.apply(payload.getVector());
backingStore.persist(key, payload);
}
}
}

0 comments on commit 3ec0c9b

Please sign in to comment.