Skip to content

Commit

Permalink
Fixes MAHOUT-1202.
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan Filimon committed May 3, 2013
1 parent bf82164 commit 221b595
Show file tree
Hide file tree
Showing 29 changed files with 2,462 additions and 614 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.common.parameters.Parameter;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

/**
* This class implements a "Chebyshev distance" metric by finding the maximum difference
Expand Down Expand Up @@ -52,17 +52,7 @@ public double distance(Vector v1, Vector v2) {
if (v1.size() != v2.size()) {
throw new CardinalityException(v1.size(), v2.size());
}
double result = 0.0;
Vector vector = v1.minus(v2);
Iterator<Vector.Element> iter = vector.iterateNonZero();
while (iter.hasNext()) {
Vector.Element e = iter.next();
double d = Math.abs(e.get());
if (d > result) {
result = d;
}
}
return result;
return v1.aggregate(v2, Functions.MAX_ABS, Functions.MINUS);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;

import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.common.parameters.Parameter;
import org.apache.mahout.math.CardinalityException;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

/**
* This class implements a "manhattan distance" metric by summing the absolute values of the difference
Expand Down Expand Up @@ -60,15 +60,7 @@ public double distance(Vector v1, Vector v2) {
if (v1.size() != v2.size()) {
throw new CardinalityException(v1.size(), v2.size());
}
double result = 0;
Vector vector = v1.minus(v2);
Iterator<Vector.Element> iter = vector.iterateNonZero();
// this contains all non zero elements between the two
while (iter.hasNext()) {
Vector.Element e = iter.next();
result += Math.abs(e.get());
}
return result;
return v1.aggregate(v2, Functions.PLUS, Functions.MINUS_ABS);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@
package org.apache.mahout.common.distance;

import java.util.Collection;
import java.util.Iterator;
import java.util.List;

import com.google.common.collect.Lists;
import org.apache.hadoop.conf.Configuration;
import org.apache.mahout.common.parameters.DoubleParameter;
import org.apache.mahout.common.parameters.Parameter;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.Vector.Element;
import org.apache.mahout.math.function.Functions;

/**
* Implement Minkowski distance, a real-valued generalization of the
Expand Down Expand Up @@ -82,14 +81,7 @@ public void setExponent(double exponent) {
*/
@Override
public double distance(Vector v1, Vector v2) {
Vector distVector = v1.minus(v2);
double sum = 0.0;
Iterator<Element> it = distVector.iterateNonZero();
while (it.hasNext()) {
Element e = it.next();
sum += Math.pow(Math.abs(e.get()), exponent);
}
return Math.pow(sum, 1.0 / exponent);
return Math.pow(v1.aggregate(v2, Functions.PLUS, Functions.minusAbsPow(exponent)), 1.0 / exponent);
}

// TODO: how?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,8 @@

package org.apache.mahout.common.distance;

import java.util.Iterator;

import org.apache.mahout.math.Vector;
import org.apache.mahout.math.function.Functions;

/**
* Tanimoto coefficient implementation.
Expand All @@ -42,8 +41,10 @@ public double distance(Vector a, Vector b) {
double ab;
double denominator;
if (getWeights() != null) {
ab = dot(b, a); // b is SequentialAccess
denominator = dot(a, a) + dot(b, b) - ab;
ab = a.times(b).aggregate(getWeights(), Functions.PLUS, Functions.MULT);
denominator = a.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT)
+ b.aggregate(getWeights(), Functions.PLUS, Functions.MULT_SQUARE_LEFT)
- ab;
} else {
ab = b.dot(a); // b is SequentialAccess
denominator = a.getLengthSquared() + b.getLengthSquared() - ab;
Expand All @@ -53,28 +54,13 @@ public double distance(Vector a, Vector b) {
denominator = ab;
}
if (denominator > 0) {
// denom == 0 only when dot(a,a) == dot(b,b) == dot(a,b) == 0
// denominator == 0 only when dot(a,a) == dot(b,b) == dot(a,b) == 0
return 1.0 - ab / denominator;
} else {
return 0.0;
}
}

/**
* Computes a weighted dot product of {@code a} and {@code b}, visiting only the
* non-zero elements of {@code a}. Each term is the element of {@code a}, times the
* value of {@code b} at the same index, times the weight at that index as returned
* by {@code getWeights()}.
*
* @param a vector whose non-zero elements drive the iteration
* @param b vector read by index through {@code getQuick}; may be the same instance as {@code a}
* @return the sum of the weighted element products
*/
public double dot(Vector a, Vector b) {
// Reference comparison: when both arguments are the same object the element is
// squared directly instead of being looked up again through b.getQuick().
boolean sameVector = a == b;
Iterator<Vector.Element> it = a.iterateNonZero();
Vector.Element el;
// NOTE(review): weights is dereferenced without a null check; the caller visible in
// this file only invokes dot() when getWeights() != null - confirm for other callers.
Vector weights = getWeights();
double dot = 0.0;
// The loop also stops early if the iterator ever hands back a null element.
while (it.hasNext() && (el = it.next()) != null) {
double elementValue = el.get();
double value = elementValue * (sameVector ? elementValue : b.getQuick(el.index()));
value *= weights.getQuick(el.index());
dot += value;
}
return dot;
}


@Override
public double distance(double centroidLengthSquare, Vector centroid, Vector v) {
return distance(centroid, v); // TODO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.mahout.math.AbstractVector;
import org.apache.mahout.math.DenseVector;
import org.apache.mahout.math.Matrix;
import org.apache.mahout.math.OrderedIntDoubleMapping;
import org.apache.mahout.math.Vector;
import org.apache.mahout.math.hadoop.stochasticsvd.UpperTriangular;

Expand Down Expand Up @@ -595,11 +596,43 @@ public int getNumNondefaultElements() {
throw new UnsupportedOperationException();
}

/**
* Reports the lookup cost of this vector as the constant 1.
* <p/>
* NOTE(review): presumably a relative cost estimate used when choosing how to
* aggregate/assign over vectors - confirm against the Vector interface contract.
*
* @return always 1
*/
@Override
public double getLookupCost() {
return 1;
}

/**
* Reports the cost of advancing an iterator over this vector as the constant 1.
* <p/>
* NOTE(review): presumably a relative cost estimate, on the same scale as
* getLookupCost() - confirm against the Vector interface contract.
*
* @return always 1
*/
@Override
public double getIteratorAdvanceCost() {
return 1;
}

/**
* Declares that this vector counts as a constant-time-add vector.
* <p/>
* NOTE(review): presumably means setting a previously unset element does not
* require shifting other elements - confirm against the Vector interface contract.
*
* @return always {@code true}
*/
@Override
public boolean isAddConstantTime() {
return true;
}

/**
* Not supported by this vector implementation.
*
* @param rows requested number of rows (ignored)
* @param columns requested number of columns (ignored)
* @return never returns normally
* @throws UnsupportedOperationException always
*/
@Override
public Matrix matrixLike(int rows, int columns) {
throw new UnsupportedOperationException();
}

/**
 * Applies a batch of (index, value) updates to this vector by writing each value into the
 * {@code viewed} matrix at position ({@code rowNum}, index) through {@code setQuick}.
 * <p/>
 * Used internally by assign() to update multiple indices and values at once.
 * Only really useful for sparse vectors (especially SequentialAccessSparseVector).
 * <p/>
 * If someone ever adds a new type of sparse vectors, this method must merge (index, value) pairs into the vector.
 *
 * @param updates a mapping of indices to values to merge in the vector.
 */
@Override
public void mergeUpdates(OrderedIntDoubleMapping updates) {
  int[] indices = updates.getIndices();
  double[] values = updates.getValues();
  // Only the first getNumMappings() entries of the two arrays are consumed; the count is
  // read once because nothing in the loop modifies the mapping.
  int numMappings = updates.getNumMappings();
  for (int i = 0; i < numMappings; ++i) {
    viewed.setQuick(rowNum, indices[i], values[i]);
  }
}

}

}
7 changes: 7 additions & 0 deletions math/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -172,5 +172,12 @@
<artifactId>junit</artifactId>
<scope>test</scope>
</dependency>

<dependency>
<groupId>org.easymock</groupId>
<artifactId>easymock</artifactId>
<version>3.1</version>
<scope>test</scope>
</dependency>
</dependencies>
</project>
30 changes: 30 additions & 0 deletions math/src/main/java/org/apache/mahout/math/AbstractMatrix.java
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,7 @@ public double determinant() {

}

@SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException")
@Override
public Matrix clone() {
AbstractMatrix clone;
Expand Down Expand Up @@ -617,6 +618,7 @@ protected TransposeViewVector(Matrix m, int offset, boolean rowToColumn) {
numCols = rowToColumn ? m.numCols() : m.numRows();
}

@SuppressWarnings("CloneDoesntCallSuperClone")
@Override
public Vector clone() {
Vector v = new DenseVector(size());
Expand Down Expand Up @@ -683,6 +685,19 @@ public void set(double value) {
};
}

/**
* Rejects batched (index, value) updates: this view cannot be mutated through
* mergeUpdates, so the call always fails and {@code updates} is never read.
*
* @param updates a mapping of indices to values to merge in the vector (ignored).
* @throws UnsupportedOperationException always, with the message "Cannot mutate TransposeViewVector"
*/
@Override
public void mergeUpdates(OrderedIntDoubleMapping updates) {
throw new UnsupportedOperationException("Cannot mutate TransposeViewVector");
}

@Override
public double getQuick(int index) {
Vector v = rowToColumn ? matrix.viewColumn(index) : matrix.viewRow(index);
Expand Down Expand Up @@ -725,6 +740,21 @@ public Vector like(int cardinality) {
public int getNumNondefaultElements() {
return size();
}

/**
 * Reports the lookup cost of the wrapped matrix's first column (when {@code rowToColumn}
 * is set) or first row (otherwise).
 *
 * @return the lookup cost reported by that vector
 */
@Override
public double getLookupCost() {
  Vector firstSlice = rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0);
  return firstSlice.getLookupCost();
}

/**
 * Reports the iterator-advance cost of the wrapped matrix's first column (when
 * {@code rowToColumn} is set) or first row (otherwise).
 *
 * @return the iterator-advance cost reported by that vector
 */
@Override
public double getIteratorAdvanceCost() {
  Vector firstSlice;
  if (rowToColumn) {
    firstSlice = matrix.viewColumn(0);
  } else {
    firstSlice = matrix.viewRow(0);
  }
  return firstSlice.getIteratorAdvanceCost();
}

/**
 * Reports whether the wrapped matrix's first column (when {@code rowToColumn} is set)
 * or first row (otherwise) supports constant-time adds.
 *
 * @return the answer given by that vector
 */
@Override
public boolean isAddConstantTime() {
  Vector firstSlice = rowToColumn ? matrix.viewColumn(0) : matrix.viewRow(0);
  return firstSlice.isAddConstantTime();
}
}

@Override
Expand Down

0 comments on commit 221b595

Please sign in to comment.