Skip to content

Commit

Permalink
Make n > input length yield empty output
Browse files: browse the repository at this point in the history
  • Loading branch information
Feynman Liang committed Jun 19, 2015
1 parent 9fadd36 commit d2c839f
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 19 deletions.
7 changes: 3 additions & 4 deletions mllib/src/main/scala/org/apache/spark/ml/feature/NGram.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ import org.apache.spark.sql.types.{ArrayType, DataType, StringType}
* words.
*
* When the input is empty, an empty array is returned.
* When the input array length is less than n (number of elements per n-gram), a single n-gram
* consisting of the input array is returned.
* When the input array length is less than n (number of elements per n-gram), no complete n-gram
* can be formed, so an empty array is returned.
*/
@Experimental
class NGram(override val uid: String)
Expand All @@ -57,8 +57,7 @@ class NGram(override val uid: String)
setDefault(n -> 2)

override protected def createTransformFunc: Seq[String] => Seq[String] = {
val minLength = $(n)
_.sliding(minLength).map(_.mkString(" ")).toSeq
_.iterator.sliding($(n)).withPartial(false).map(_.mkString(" ")).toSeq
}

override protected def validateInputType(inputType: DataType): Unit = {
Expand Down
31 changes: 16 additions & 15 deletions mllib/src/test/scala/org/apache/spark/ml/feature/NGramSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,61 +30,62 @@ class NGramSuite extends SparkFunSuite with MLlibTestSparkContext {
import org.apache.spark.ml.feature.NGramSuite._

test("default behavior yields bigram features") {
val NGramTransformer = new NGram()
val nGram = new NGram()
.setInputCol("inputTokens")
.setOutputCol("NGrams")
.setOutputCol("nGrams")
val dataset = sqlContext.createDataFrame(Seq(
NGramTestData(
Array("Test", "for", "ngram", "."),
Array("Test for", "for ngram", "ngram .")
)))
testNGram(NGramTransformer, dataset)
testNGram(nGram, dataset)
}

test("NGramLength=4 yields length 4 n-grams") {
val NGramTransformer = new NGram()
val nGram = new NGram()
.setInputCol("inputTokens")
.setOutputCol("NGrams")
.setOutputCol("nGrams")
.setN(4)
val dataset = sqlContext.createDataFrame(Seq(
NGramTestData(
Array("a", "b", "c", "d", "e"),
Array("a b c d", "b c d e")
)))
testNGram(NGramTransformer, dataset)
testNGram(nGram, dataset)
}

test("empty input yields empty output") {
val NGramTransformer = new NGram()
val nGram = new NGram()
.setInputCol("inputTokens")
.setOutputCol("NGrams")
.setOutputCol("nGrams")
.setN(4)
val dataset = sqlContext.createDataFrame(Seq(
NGramTestData(
Array(),
Array()
)))
testNGram(NGramTransformer, dataset)
testNGram(nGram, dataset)
}
test("input array < n yields a single n-gram consisting of input array") {
val NGramTransformer = new NGram()

test("input array < n yields empty output") {
val nGram = new NGram()
.setInputCol("inputTokens")
.setOutputCol("NGrams")
.setOutputCol("nGrams")
.setN(6)
val dataset = sqlContext.createDataFrame(Seq(
NGramTestData(
Array("a", "b", "c", "d", "e"),
Array("a b c d e")
Array()
)))
testNGram(NGramTransformer, dataset)
testNGram(nGram, dataset)
}
}

object NGramSuite extends SparkFunSuite {

def testNGram(t: NGram, dataset: DataFrame): Unit = {
t.transform(dataset)
.select("NGrams", "wantedNGrams")
.select("nGrams", "wantedNGrams")
.collect()
.foreach { case Row(actualNGrams, wantedNGrams) =>
assert(actualNGrams === wantedNGrams)
Expand Down

0 comments on commit d2c839f

Please sign in to comment.