Skip to content

Commit

Permalink
Add empty and corner test cases, fix names and spaces
Browse files Browse the repository at this point in the history
  • Loading branch information
Feynman Liang committed Jun 19, 2015
1 parent fe93873 commit 9fadd36
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 15 deletions.
19 changes: 11 additions & 8 deletions mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,10 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
* values in the input array are ignored.
* It returns an array of n-grams where each n-gram is represented by a space-separated string of
* words.
*
* When the input is empty, an empty array is returned.
* When the input array length is less than n (number of elements per n-gram), a single n-gram
* consisting of the input array is returned.
*/
@Experimental
class NGram(override val uid: String)
Expand All @@ -38,28 +42,27 @@ class NGram(override val uid: String)

/**
 * N-gram length: the number of elements (tokens) in each produced n-gram, >= 1.
 * This is the exact sliding-window size, not a minimum.
 * Default: 2 (bigram features).
 * @group param
 */
val n: IntParam = new IntParam(this, "n", "number of elements per n-gram (>=1)",
  ParamValidators.gtEq(1))

/** @group setParam */
def setN(value: Int): this.type = set(n, value)

/** @group getParam */
def getN: Int = $(n)

// Bigrams by default.
setDefault(n -> 2)

override protected def createTransformFunc: Seq[String] => Seq[String] = {
  // Window size for `sliding`: each window of `n` consecutive tokens becomes one
  // space-separated n-gram. `sliding` yields the whole (shorter) sequence as a
  // single window when the input has fewer than `n` elements, and yields nothing
  // for an empty input — which gives the documented corner-case behavior.
  val nGramLength = $(n)
  _.sliding(nGramLength).map(_.mkString(" ")).toSeq
}

override protected def validateInputType(inputType: DataType): Unit = {
  // The input column must hold the token sequence as an array of strings.
  require(inputType.sameType(ArrayType(StringType)),
    s"Input type must be ArrayType(StringType) but got $inputType.")
}

Expand Down
39 changes: 32 additions & 7 deletions mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,28 +30,53 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.NGramSuite._

test("default behavior yields bigram features") {
  // No setN call: the default (n = 2) should produce bigrams.
  val nGramTransformer = new NGram()
    .setInputCol("inputTokens")
    .setOutputCol("NGrams")
  val dataset = sqlContext.createDataFrame(Seq(
    NGramTestData(
      Array("Test", "for", "ngram", "."),
      Array("Test for", "for ngram", "ngram .")
    )))
  testNGram(nGramTransformer, dataset)
}

test("NGramLength=4 yields length 4 n-grams") {
  // With n = 4, a 5-token input yields the two possible 4-grams.
  val nGramTransformer = new NGram()
    .setInputCol("inputTokens")
    .setOutputCol("NGrams")
    .setN(4)
  val dataset = sqlContext.createDataFrame(Seq(
    NGramTestData(
      Array("a", "b", "c", "d", "e"),
      Array("a b c d", "b c d e")
    )))
  testNGram(nGramTransformer, dataset)
}

test("empty input yields empty output") {
  // An empty token sequence has no n-grams, so the output array is empty too.
  val nGramTransformer = new NGram()
    .setInputCol("inputTokens")
    .setOutputCol("NGrams")
    .setN(4)
  val dataset = sqlContext.createDataFrame(
    Seq(NGramTestData(Array(), Array())))
  testNGram(nGramTransformer, dataset)
}
test("input array < n yields a single n-gram consisting of input array") {
  // Five tokens with n = 6: the whole sequence is emitted as one n-gram.
  val nGramTransformer = new NGram()
    .setInputCol("inputTokens")
    .setOutputCol("NGrams")
    .setN(6)
  val dataset = sqlContext.createDataFrame(
    Seq(NGramTestData(
      Array("a", "b", "c", "d", "e"),
      Array("a b c d e"))))
  testNGram(nGramTransformer, dataset)
}
}

Expand All @@ -62,7 +87,7 @@ object NGramSuite extends SparkFunSuite {
.select("NGrams", "wantedNGrams")
.collect()
.foreach { case Row(actualNGrams, wantedNGrams) =>
assert(actualNGrams === wantedNGrams)
}
assert(actualNGrams === wantedNGrams)
}
}
}

0 comments on commit 9fadd36

Please sign in to comment.