Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,13 @@

package org.apache.flink.ml

import org.apache.flink.api.common.functions.RichMapFunction
import org.apache.flink.api.common.functions.{RichFlatMapFunction, RichMapFunction}
import org.apache.flink.api.java.operators.DataSink
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.ml.common.LabeledVector
import org.apache.flink.ml.math.SparseVector
import org.apache.flink.util.Collector

/** Convenience functions for machine learning tasks
*
Expand Down Expand Up @@ -53,32 +54,35 @@ object MLUtils {
* file
*/
def readLibSVM(env: ExecutionEnvironment, filePath: String): DataSet[LabeledVector] = {
val labelCOODS = env.readTextFile(filePath).flatMap {
line =>
// remove all comments which start with a '#'
val commentFreeLine = line.takeWhile(_ != '#').trim

if(commentFreeLine.nonEmpty) {
val splits = commentFreeLine.split(' ')
val label = splits.head.toDouble
val sparseFeatures = splits.tail
val coos = sparseFeatures.map {
str =>
val labelCOODS = env.readTextFile(filePath).flatMap(
new RichFlatMapFunction[String, (Double, Array[(Int, Double)])] {
val splitPattern = "\\s+".r

override def flatMap(
line: String,
out: Collector[(Double, Array[(Int, Double)])]
): Unit = {
val commentFreeLine = line.takeWhile(_ != '#').trim

if (commentFreeLine.nonEmpty) {
val splits = splitPattern.split(commentFreeLine)
val label = splits.head.toDouble
val sparseFeatures = splits.tail
val coos = sparseFeatures.flatMap { str =>
val pair = str.split(':')
require(pair.length == 2, "Each feature entry has to have the form <feature>:<value>")

// libSVM index is 1-based, but we expect it to be 0-based
val index = pair(0).toInt - 1
val value = pair(1).toDouble

(index, value)
}
Some((index, value))
}

Some((label, coos))
} else {
None
out.collect((label, coos))
}
}
}
})

// Calculate maximum dimension of vectors
val dimensionDS = labelCOODS.map {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@ class MLUtilsSuite extends FlatSpec with Matchers with FlinkTestBase {

val content =
"""
|1 2:10.0 4:4.5 8:4.2 # foo
|1 2:10.0 4:4.5 8:4.2 # foo
|-1 1:9.0 4:-4.5 7:2.4 # bar
|0.4 3:1.0 8:-5.6 10:1.0
|0.4 3:1.0 8:-5.6 10:1.0
|-42.1 2:2.0 4:-6.1 3:5.1 # svm
""".stripMargin

Expand Down