Permalink
Browse files

Updated LSH to be a tad more correct. Also, CLI programs to predict t…

…he parameters for stable distribution-based LSH were improved/enhanced.
  • Loading branch information...
cestella committed Mar 4, 2012
1 parent fc926be commit ae25973b58e05a0c9a6ef4b0a489caf0d265876b
@@ -0,0 +1,76 @@
+package com.caseystella.lsh;
+
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.random.RandomDataImpl;
+import org.apache.commons.math.random.RandomGenerator;
+
+import com.caseystella.interfaces.IDistanceMetric;
+import com.caseystella.interfaces.IHashCreator;
+import com.caseystella.lsh.interfaces.ILSH;
+import com.caseystella.math.L1DistanceMetric;
+import com.caseystella.math.stabledistribution.AbstractStableDistributionFunction;
+import com.google.common.base.Function;
+
+public class L1LSH extends AbstractStableDistributionFunction
+{
+ protected static class L1Sampler implements ISampler
+ {
+ /**
+ * {@inheritDoc}
+ * @throws MathException
+ *
+ * @see Function#apply(RandomDataImpl)
+ */
+ public double apply(RandomDataImpl randomData) throws MathException {
+
+ return randomData.nextCauchy(0, 1);
+
+ }
+
+ }
+ private static ISampler sampler = new L1Sampler();
+ public static IDistanceMetric metric = new L1DistanceMetric();
+
+ public static class Creator implements IHashCreator
+ {
+ int dim;
+ float w;
+
+ public Creator(int dim, float w)
+ {
+ this.dim = dim;
+ this.w = w;
+
+ }
+ @Override
+ public ILSH construct(long seed) throws MathException {
+ return new L1LSH(dim, w, seed);
+ }
+
+ }
+
+ /**
+ * Constructs a new instance.
+ * @throws MathException
+ */
+ public L1LSH(int dim, float w, RandomGenerator rand) throws MathException {
+ super(dim, w, rand);
+ }
+
+ public L1LSH(int dim, float w, long seed) throws MathException {
+
+ super(dim, w, seed);
+ }
+
+ @Override
+ public IDistanceMetric getMetric() {
+ return metric;
+ }
+
+ @Override
+ protected ISampler getSampler(
+ RandomDataImpl dataSampler) {
+ return sampler;
+ }
+
+}
@@ -0,0 +1,73 @@
+package com.caseystella.lsh;
+
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.random.RandomDataImpl;
+import org.apache.commons.math.random.RandomGenerator;
+
+import com.caseystella.interfaces.IDistanceMetric;
+import com.caseystella.interfaces.IHashCreator;
+import com.caseystella.lsh.interfaces.ILSH;
+import com.caseystella.math.L2DistanceMetric;
+import com.caseystella.math.stabledistribution.AbstractStableDistributionFunction;
+import com.google.common.base.Function;
+
+public class L2LSH extends AbstractStableDistributionFunction {
+
+ protected static class L2Sampler implements ISampler
+ {
+
+ /**
+ * {@inheritDoc}
+ * @see Function#apply(RandomDataImpl)
+ */
+ public double apply(RandomDataImpl randomData)
+ {
+ return randomData.nextGaussian(0,1);
+ }
+
+ }
+ public static class Creator implements IHashCreator
+ {
+ int dim;
+ float w;
+
+ public Creator(int dim, float w)
+ {
+ this.dim = dim;
+ this.w = w;
+
+ }
+ @Override
+ public ILSH construct(long seed) throws MathException {
+ return new L2LSH(dim, w, seed);
+ }
+
+ }
+ private static ISampler sampler = new L2Sampler();
+ public static IDistanceMetric metric = new L2DistanceMetric();
+ /**
+ * Constructs a new instance.
+ * @throws MathException
+ */
+ public L2LSH(int dim, float w, RandomGenerator rand) throws MathException {
+ super(dim, w, rand);
+ }
+
+ public L2LSH(int dim, float w, long seed) throws MathException {
+ super(dim, w, seed);
+ }
+
+
+ @Override
+ public IDistanceMetric getMetric() {
+ return metric;
+ }
+
+ @Override
+ protected ISampler getSampler(
+ RandomDataImpl dataSampler) {
+ return sampler;
+ }
+
+
+}
@@ -0,0 +1,65 @@
+package com.caseystella.lsh;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+
+import org.apache.commons.math.MathException;
+import org.apache.commons.math.linear.RealVector;
+
+import com.caseystella.interfaces.IDistanceMetric;
+import com.caseystella.interfaces.IHashCreator;
+import com.caseystella.lsh.interfaces.ILSH;
+
+public class RepeatingLSH implements ILSH
+{
+ public static class Creator implements IHashCreator
+ {
+ int numRepetitions;
+ IHashCreator underlyingCreator;
+
+
+ public Creator(int numRepetitions, IHashCreator underlyingCreator)
+ {
+ this.numRepetitions = numRepetitions;
+ this.underlyingCreator = underlyingCreator;
+
+ }
+ @Override
+ public ILSH construct(long seed) throws MathException {
+ return new RepeatingLSH(numRepetitions, underlyingCreator, seed);
+ }
+
+ }
+ private List<ILSH> lshList;
+ private int[] coefficients;
+ public RepeatingLSH(int numRepetitions, IHashCreator pHashCreator, long seed) throws MathException
+ {
+ lshList = new ArrayList<ILSH>();
+ coefficients = new int[numRepetitions];
+ Random r = new Random(seed);
+ for(int i = 0;i < numRepetitions;++i)
+ {
+ lshList.add(pHashCreator.construct(r.nextLong()));
+ coefficients[i] = Math.abs(r.nextInt());
+ }
+
+
+ }
+
+ @Override
+ public IDistanceMetric getMetric() {
+ return lshList.get(0).getMetric();
+ }
+
+ @Override
+ public long apply(RealVector vector) {
+ long ret = 0;
+ for(int i = 0;i < lshList.size();++i)
+ {
+ ret += coefficients[i] * lshList.get(i).apply(vector);
+ }
+ return ret;
+ }
+
+}
@@ -0,0 +1,11 @@
+package com.caseystella.lsh.interfaces;
+
+import org.apache.commons.math.linear.RealVector;
+
+import com.caseystella.interfaces.IDistanceMetric;
+
+public interface ILSH
+{
+ public IDistanceMetric getMetric();
+ public long apply(RealVector vector);
+}
@@ -0,0 +1,63 @@
+package com.caseystella.math.stabledistribution;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.PrintStream;
+
+import org.apache.commons.math.DimensionMismatchException;
+import org.apache.commons.math.linear.MatrixUtils;
+import org.apache.commons.math.linear.NotPositiveDefiniteMatrixException;
+import org.apache.commons.math.linear.RealMatrix;
+import org.apache.commons.math.random.CorrelatedRandomVectorGenerator;
+import org.apache.commons.math.random.GaussianRandomGenerator;
+import org.apache.commons.math.random.JDKRandomGenerator;
+import org.apache.commons.math.random.RandomGenerator;
+
+public class GenerateRandomVectorsCLI {
+
+ /**
+ * @param args
+ * @throws DimensionMismatchException
+ * @throws NotPositiveDefiniteMatrixException
+ * @throws FileNotFoundException
+ */
+ public static void main(String[] args) throws NotPositiveDefiniteMatrixException, DimensionMismatchException, FileNotFoundException
+ {
+ // Create and seed a RandomGenerator (could use any of the generators in the random package here)
+ RandomGenerator rg = new JDKRandomGenerator();
+ rg.setSeed(17399225432l); // Fixed seed means same results every time
+
+ // Create a GassianRandomGenerator using rg as its source of randomness
+ GaussianRandomGenerator rawGenerator = new GaussianRandomGenerator(rg);
+ double c = 3 * 4 * .5;
+ double[] mean = {1, 2};
+ double[][] cov = {{9, c}, {c, 16}};
+ RealMatrix covariance = MatrixUtils.createRealMatrix(cov);
+ // Create a CorrelatedRandomVectorGenerator using rawGenerator for the components
+ CorrelatedRandomVectorGenerator generator =
+ new CorrelatedRandomVectorGenerator(mean, covariance, 1.0e-12 * covariance.getNorm(), rawGenerator);
+ int num = Integer.parseInt(args[0]);
+ int dim = Integer.parseInt(args[1]);
+ PrintStream writer = args.length > 2?new PrintStream(new File(args[2])):System.out;
+ for(int i = 0;i < num;++i)
+ {
+ for(int j = 0;j < dim/2;++j)
+ {
+ // Use the generator to generate correlated vectors
+ double[] randomVector = generator.nextVector();
+ for(double d : randomVector)
+ {
+ writer.print(d + " ");
+ }
+ }
+ writer.println("");
+ }
+ if(writer != System.out)
+ {
+ writer.flush();
+ writer.close();
+ }
+
+ }
+
+}
Oops, something went wrong.

0 comments on commit ae25973

Please sign in to comment.