diff --git a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala index 804ab5fb7e1fe..f4119f5ab6086 100644 --- a/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala +++ b/flink-staging/flink-ml/src/main/scala/org/apache/flink/ml/MLUtils.scala @@ -18,12 +18,13 @@ package org.apache.flink.ml -import org.apache.flink.api.common.functions.RichMapFunction +import org.apache.flink.api.common.functions.{RichFlatMapFunction, RichMapFunction} import org.apache.flink.api.java.operators.DataSink import org.apache.flink.api.scala._ import org.apache.flink.configuration.Configuration import org.apache.flink.ml.common.LabeledVector import org.apache.flink.ml.math.SparseVector +import org.apache.flink.util.Collector /** Convenience functions for machine learning tasks * @@ -53,17 +54,21 @@ object MLUtils { * file */ def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = { - val labelCOODS = env.readTextFile(filePath).flatMap { - line => - // remove all comments which start with a '#' - val commentFreeLine = line.takeWhile(_ != '#').trim - - if(commentFreeLine.nonEmpty) { - val splits = commentFreeLine.split(' ') - val label = splits.head.toDouble - val sparseFeatures = splits.tail - val coos = sparseFeatures.map { - str => + val labelCOODS = env.readTextFile(filePath).flatMap( + new RichFlatMapFunction[String, (Double, Array[(Int, Double)])] { + val splitPattern = "\\s+".r + + override def flatMap( + line: String, + out: Collector[(Double, Array[(Int, Double)])] + ): Unit = { + val commentFreeLine = line.takeWhile(_ != '#').trim + + if (commentFreeLine.nonEmpty) { + val splits = splitPattern.split(commentFreeLine) + val label = splits.head.toDouble + val sparseFeatures = splits.tail + val coos = sparseFeatures.flatMap { str => val pair = str.split(':') require(pair.length == 2, "Each feature entry has to have the form :") @@ -71,14 +76,13 @@ object MLUtils { val index = pair(0).toInt - 1 val value = pair(1).toDouble - (index, value) - } + Some((index, value)) + } - Some((label, coos)) - } else { - None + out.collect((label, coos)) + } } - } + }) // Calculate maximum dimension of vectors val dimensionDS = labelCOODS.map { diff --git a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala index d896937204e9e..f02f5ffc1c245 100644 --- a/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala +++ b/flink-staging/flink-ml/src/test/scala/org/apache/flink/ml/MLUtilsSuite.scala @@ -40,9 +40,9 @@ class MLUtilsSuite extends FlatSpec with Matchers with FlinkTestBase { val content = """ - |1 2:10.0 4:4.5 8:4.2 # foo + |1 2:10.0 4:4.5 8:4.2 # foo |-1 1:9.0 4:-4.5 7:2.4 # bar - |0.4 3:1.0 8:-5.6 10:1.0 + |0.4 3:1.0 8:-5.6 10:1.0 |-42.1 2:2.0 4:-6.1 3:5.1 # svm """.stripMargin