Skip to content

Commit

Permalink
[SYSTEMML-1224] Migrate Vector and LabeledPoint classes from mllib to ml
Browse files Browse the repository at this point in the history
Migrate:
mllib.linalg.DenseVector to ml.linalg.DenseVector.
mllib.linalg.Vector to ml.linalg.Vector.
mllib.linalg.Vectors to ml.linalg.Vectors.
mllib.linalg.VectorUDT to ml.linalg.VectorUDT.
mllib.regression.LabeledPoint to ml.feature.LabeledPoint.

Closes #369.
  • Loading branch information
deroneriksson committed Feb 4, 2017
1 parent 4049ce4 commit 578e595
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 88 deletions.
2 changes: 1 addition & 1 deletion src/main/java/org/apache/sysml/api/MLOutput.java
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public Dataset<Row> getDF(SQLContext sqlContext, String varName) throws DMLRunti
* Obtain the DataFrame
* @param sqlContext the SQLContext
* @param varName the variable name
* @param outputVector if true, returns a DataFrame with two columns: ID and org.apache.spark.mllib.linalg.Vector
* @param outputVector if true, returns a DataFrame with two columns: ID and org.apache.spark.ml.linalg.Vector
* @return the DataFrame
* @throws DMLRuntimeException if DMLRuntimeException occurs
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.ml.feature.LabeledPoint;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.mllib.linalg.distributed.CoordinateMatrix;
import org.apache.spark.mllib.linalg.distributed.MatrixEntry;
import org.apache.spark.sql.Dataset;
Expand Down Expand Up @@ -128,69 +128,6 @@ public static JavaPairRDD<MatrixIndexes, MatrixBlock> coordinateMatrixToBinaryBl
return coordinateMatrixToBinaryBlock(new JavaSparkContext(sc), input, mcIn, true);
}

/**
 * Convert a DataFrame whose single column holds string-encoded vectors
 * (e.g. {@code "((1.2, 3.4))"}, {@code "[1.2, 3.4]"}, {@code "(1.2 3.4)"})
 * into a DataFrame of {@code org.apache.spark.ml.linalg.Vector} values,
 * preserving the original column names.
 *
 * @param sqlContext the SQLContext used to create the output DataFrame
 * @param inputDF input DataFrame; each row must have at most one column,
 *                and that column must be a String
 * @return DataFrame with the same column names typed as VectorUDT
 * @throws DMLRuntimeException if a row has more than one column or a
 *                             non-String column
 */
public static Dataset<Row> stringDataFrameToVectorDataFrame(SQLContext sqlContext, Dataset<Row> inputDF)
		throws DMLRuntimeException {

	// Build the output schema: same column names, but typed as ml VectorUDT.
	StructField[] oldSchema = inputDF.schema().fields();
	StructField[] newSchema = new StructField[oldSchema.length];
	for (int i = 0; i < oldSchema.length; i++) {
		newSchema[i] = DataTypes.createStructField(oldSchema[i].name(), new VectorUDT(), true);
	}

	// Converter: one string row -> one row containing an ml.linalg Vector.
	class StringToVector implements Function<Tuple2<Row, Long>, Row> {
		private static final long serialVersionUID = -4733816995375745659L;

		@Override
		public Row call(Tuple2<Row, Long> arg0) throws Exception {
			Row oldRow = arg0._1;
			int oldNumCols = oldRow.length();
			if (oldNumCols > 1) {
				throw new DMLRuntimeException("The row must have at most one column");
			}

			ArrayList<Object> fieldsArr = new ArrayList<Object>();
			for (int i = 0; i < oldNumCols; i++) {
				Object ci = oldRow.get(i);
				if (!(ci instanceof String)) {
					throw new DMLRuntimeException("Only String is supported");
				}
				StringBuilder sb = new StringBuilder(((String) ci).trim());
				// Strip up to two levels of surrounding (...) or [...] nesting,
				// e.g. "((1.2,3.4))" or "[[1.2,3.4]]".
				// BUGFIX: the original loop read "for (int nid=0; i < 2; i++)",
				// testing and incrementing the OUTER index i instead of nid.
				// Also guard against empty/short strings before charAt calls.
				for (int nest = 0; nest < 2; nest++) {
					int len = sb.length();
					if (len >= 2 && ((sb.charAt(0) == '(' && sb.charAt(len - 1) == ')')
							|| (sb.charAt(0) == '[' && sb.charAt(len - 1) == ']'))) {
						sb.deleteCharAt(0);
						sb.setLength(sb.length() - 1);
					}
				}
				// ml.linalg.Vectors has no parse(String) (that API exists only
				// in mllib.linalg.Vectors), so tokenize on commas and/or
				// whitespace and build a dense vector explicitly. The original
				// code always wrapped the payload as "[...]" before parsing, so
				// it only ever produced dense vectors; this is equivalent.
				String payload = sb.toString().trim();
				double[] values;
				if (payload.isEmpty()) {
					values = new double[0];
				} else {
					String[] tokens = payload.split("[,\\s]+");
					values = new double[tokens.length];
					for (int j = 0; j < tokens.length; j++) {
						values[j] = Double.parseDouble(tokens[j]);
					}
				}
				fieldsArr.add(Vectors.dense(values));
			}
			return RowFactory.create(fieldsArr.toArray());
		}
	}

	// Output DataFrame built from the converted rows and the vector schema.
	JavaRDD<Row> newRows = inputDF.rdd().toJavaRDD().zipWithIndex().map(new StringToVector());
	// NOTE(review): sqlContext.createDataFrame(JavaRDD, StructType) reportedly
	// misbehaved here (see original TODO), so the RDD overload is used instead.
	return sqlContext.createDataFrame(newRows.rdd(), DataTypes.createStructType(newSchema));
}


public static Dataset<Row> projectColumns(Dataset<Row> df, ArrayList<String> columns) throws DMLRuntimeException {
ArrayList<String> columnToSelect = new ArrayList<String>();
for(int i = 1; i < columns.size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ class LogisticRegressionModel(override val uid: String)(
object LogisticRegressionExample {
import org.apache.spark.{ SparkConf, SparkContext }
import org.apache.spark.sql.types._
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.LabeledPoint

def main(args: Array[String]) = {
val sparkConf: SparkConf = new SparkConf();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,8 @@
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.linalg.DenseVector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.ml.linalg.DenseVector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.VectorUDT;
import org.apache.spark.mllib.linalg.Vectors;
import org.apache.spark.ml.linalg.Vector;
import org.apache.spark.ml.linalg.VectorUDT;
import org.apache.spark.ml.linalg.Vectors;
import org.apache.spark.rdd.RDD;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ package org.apache.sysml.api.ml

import org.scalatest.FunSuite
import org.scalatest.Matchers
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.ml.linalg.Vector
import scala.reflect.runtime.universe._

case class LabeledDocument[T:TypeTag](id: Long, text: String, label: Double)
Expand Down

0 comments on commit 578e595

Please sign in to comment.