From 50ed8d636bcbcb793dc34e53327b43f26f085214 Mon Sep 17 00:00:00 2001 From: darionyaphet Date: Tue, 13 Jun 2017 18:30:38 +0800 Subject: [PATCH] [SPARK-21066] LibSVM load just one input file --- .../spark/ml/source/libsvm/LibSVMRelation.scala | 6 ++---- .../ml/source/libsvm/LibSVMRelationSuite.scala | 13 ++++++++++--- 2 files changed, 12 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala index f68847a664b69..133fcaa13d9e3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/source/libsvm/LibSVMRelation.scala @@ -91,12 +91,10 @@ private[libsvm] class LibSVMFileFormat extends TextBasedFileFormat with DataSour val numFeatures: Int = libSVMOptions.numFeatures.getOrElse { // Infers number of features if the user doesn't specify (a valid) one. val dataFiles = files.filterNot(_.getPath.getName startsWith "_") - val path = if (dataFiles.length == 1) { - dataFiles.head.getPath.toUri.toString - } else if (dataFiles.isEmpty) { + val path = if (dataFiles.isEmpty) { throw new IOException("No input path specified for libsvm data") } else { - throw new IOException("Multiple input paths are not supported for libsvm data.") + dataFiles.map(_.getPath.toUri.toString).mkString(",") } val sc = sparkSession.sparkContext diff --git a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala index e164d279f3f02..aa0521afc6fc6 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/source/libsvm/LibSVMRelationSuite.scala @@ -35,15 +35,22 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext { override def beforeAll(): Unit = { super.beforeAll() - val lines = + val lines0 = """ |1 1:1.0 3:2.0 5:3.0 |0 + """.stripMargin + val lines1 = + """ |0 2:4.0 4:5.0 6:6.0 """.stripMargin val dir = Utils.createDirectory(tempDir.getCanonicalPath, "data") - val file = new File(dir, "part-00000") - Files.write(lines, file, StandardCharsets.UTF_8) + val succ = new File(dir, "_SUCCESS") + val file0 = new File(dir, "part-00000") + val file1 = new File(dir, "part-00001") + Files.write("", succ, StandardCharsets.UTF_8) + Files.write(lines0, file0, StandardCharsets.UTF_8) + Files.write(lines1, file1, StandardCharsets.UTF_8) path = dir.toURI.toString }