Skip to content

Commit

Permalink
Test libsvm
Browse files Browse the repository at this point in the history
  • Loading branch information
MaxGekk committed Oct 18, 2020
1 parent 5f739dd commit e579d48
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 15 deletions.
Expand Up @@ -27,13 +27,26 @@ import org.apache.spark.ml.attribute.AttributeGroup
import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.linalg.SQLDataTypes.VectorType
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{FakeFileSystemRequiringDSOption, Row, SaveMode}
import org.apache.spark.sql.catalyst.plans.SQLHelper
import org.apache.spark.sql.{Row, SaveMode}
import org.apache.spark.sql.execution.datasources.CommonFileDataSourceSuite
import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
import org.apache.spark.util.Utils

class LibSVMRelationSuite
extends SparkFunSuite
with MLlibTestSparkContext
with CommonFileDataSourceSuite {

// Data source short name that the shared CommonFileDataSourceSuite test passes to `.format(...)`.
override protected def dataSourceFormat = "libsvm"
// Single-row DataFrame (a Double label plus a sparse Vector of features, both non-nullable)
// that replaces the default string dataset written by the shared test.
// NOTE(review): presumably needed because the libsvm writer expects a label/vector schema,
// not a plain string column - confirm.
override protected def inputDataset = {
  val schema = StructType(Seq(
    StructField("labelFoo", DoubleType, nullable = false),
    StructField("featuresBar", VectorType, nullable = false)))
  // One row is enough: the shared test only needs something that can be written out.
  val rows = java.util.Collections.singletonList(Row(1.0, Vectors.sparse(1, Seq((0, 1.0)))))
  spark.createDataFrame(rows, schema)
}

class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext with SQLHelper {
// Path for dataset
var path: String = _

Expand Down Expand Up @@ -212,13 +225,4 @@ class LibSVMRelationSuite extends SparkFunSuite with MLlibTestSparkContext with
assert(v == Vectors.sparse(2, Seq((0, 2.0), (1, 3.0))))
}
}

test("SPARK-33101: should propagate Hadoop config from DS options to underlying file system") {
  // Route the `file` scheme to the fake file system and turn off the FS cache for the
  // duration of the block.
  // NOTE(review): presumably the fake file system rejects access unless `ds_option`
  // reaches the Hadoop configuration - confirm against FakeFileSystemRequiringDSOption.
  val fakeFileSystemConf = Seq(
    "fs.file.impl" -> classOf[FakeFileSystemRequiringDSOption].getName,
    "fs.file.impl.disable.cache" -> "true")
  withSQLConf(fakeFileSystemConf: _*) {
    val loaded = spark.read
      .format("libsvm")
      .option("ds_option", "value")
      .load(path)
    // A successful load whose first column is `label` shows the read went through.
    assert(loaded.columns(0) == "label")
  }
}
}
Expand Up @@ -17,13 +17,18 @@

package org.apache.spark.sql.execution.datasources

import org.apache.spark.sql.{Encoders, FakeFileSystemRequiringDSOption, QueryTest, Row}
import org.scalatest.funsuite.AnyFunSuite

import org.apache.spark.sql.{Dataset, Encoders, FakeFileSystemRequiringDSOption, SparkSession}
import org.apache.spark.sql.catalyst.plans.SQLHelper

// The trait contains tests for all file-based data sources. Tests that are not applicable to
// all file-based data sources should be placed in `FileBasedDataSourceSuite`.
trait CommonFileDataSourceSuite { self: QueryTest =>
trait CommonFileDataSourceSuite extends SQLHelper { self: AnyFunSuite =>

// Session used by the shared tests; supplied by the concrete suite that mixes this trait in.
protected def spark: SparkSession
// Data source name used in the shared test's title and passed to `.format(...)` when writing.
protected def dataSourceFormat: String
// Dataset written by the shared test. Defaults to a single-string dataset; suites may
// override it to supply different data.
protected def inputDataset: Dataset[_] = spark.createDataset(Seq("abc"))(Encoders.STRING)

test(s"Propagate Hadoop configs from $dataSourceFormat options to underlying file system") {
withSQLConf(
Expand All @@ -33,7 +38,7 @@ trait CommonFileDataSourceSuite { self: QueryTest =>
withTempPath { dir =>
val path = dir.getAbsolutePath
val conf = Map("ds_option" -> "value", "mergeSchema" -> mergeSchema.toString)
spark.createDataset(Seq("abc"))(Encoders.STRING)
inputDataset
.write
.options(conf)
.format(dataSourceFormat)
Expand Down

0 comments on commit e579d48

Please sign in to comment.