From 8bf481b290389ac84f0774c354324009fe42c38d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 9 Oct 2015 21:18:12 -0700 Subject: [PATCH 1/9] add PMML export for Naive Bayes --- .../mllib/classification/NaiveBayes.scala | 4 +- .../export/NaiveBayesPMMLModelExport.scala | 93 +++++++++++++++++++ .../pmml/export/PMMLModelExportFactory.scala | 5 +- .../NaiveBayesPMMLModelExportSuite.scala | 83 +++++++++++++++++ 4 files changed, 182 insertions(+), 3 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index a956084ae06e8..8dcbd7b6734a3 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,6 +19,8 @@ package org.apache.spark.mllib.classification import java.lang.{Iterable => JIterable} +import org.apache.spark.mllib.pmml.PMMLExportable + import scala.collection.JavaConverters._ import org.json4s.JsonDSL._ @@ -47,7 +49,7 @@ class NaiveBayesModel private[spark] ( @Since("0.9.0") val pi: Array[Double], @Since("0.9.0") val theta: Array[Array[Double]], @Since("1.4.0") val modelType: String) - extends ClassificationModel with Serializable with Saveable { + extends ClassificationModel with Serializable with Saveable with PMMLExportable { import NaiveBayes.{Bernoulli, Multinomial, supportedModelTypes} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala new file mode 100644 index 0000000000000..41da51360ba17 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -0,0 +1,93 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.mllib.classification.{NaiveBayesModel => SNaiveBayesModel} + +/** + * PMML Model Export for GeneralizedLinearModel abstract class + */ +private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, description: String) + extends PMMLModelExport { + + populateNaiveBayesPMML(model) + + /** + * Export the input Naive Bayes model to PMML format. + */ + private def populateNaiveBayesPMML(model: SNaiveBayesModel): Unit = { + pmml.getHeader.setDescription(description) + + val nbModel = new NaiveBayesModel() + + nbModel.setAlgorithmName(model.modelType) + nbModel.setFunctionName(MiningFunctionType.CLASSIFICATION) + nbModel.setModelName(description) + + val fields = new SArray[FieldName](model.theta(0).length) + val dataDictionary = new DataDictionary() + val miningSchema = new MiningSchema() + val bayesInputs = new BayesInputs() + val bayesOutput = new BayesOutput() + + val labelIndices = model.pi.indices + val featureIndices = model.theta.head.indices + + // add Bayes input + for (i <- featureIndices) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) + + val pairs = labelIndices.map { label => + new TargetValueCount().withValue("target_" + label).withCount(model.theta(label)(i)) + } + + val bayesInput = new BayesInput() + val pairCounts = new PairCounts() + .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairs: _*)) + bayesInput.withFieldName(fields(i)).withPairCounts(pairCounts) + bayesInputs.withBayesInputs(bayesInput) + } + + // add Bayes output + val targetValueCounts = model.pi.zipWithIndex.map { case (x, i) => + new TargetValueCount().withValue("target_" + i).withCount(x) } + bayesOutput + .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) + + // add target field + val targetField = FieldName.create("target") + dataDictionary.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.DOUBLE) + .withValues(labelIndices.map { x => new Value().withValue(x.toString)}: _*)) + miningSchema.withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.TARGET)) + + nbModel.setMiningSchema(miningSchema) + nbModel.setBayesInputs(bayesInputs) + nbModel.setBayesOutput(bayesOutput) + + dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + + pmml.setDataDictionary(dataDictionary) + pmml.withModels(nbModel) + } +} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala index 29bd689e1185a..ff133aa147711 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactory.scala @@ -19,8 +19,7 @@ package org.apache.spark.mllib.pmml.export import org.dmg.pmml.RegressionNormalizationMethodType -import org.apache.spark.mllib.classification.LogisticRegressionModel -import org.apache.spark.mllib.classification.SVMModel +import org.apache.spark.mllib.classification.{NaiveBayesModel, LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.regression.LassoModel import org.apache.spark.mllib.regression.LinearRegressionModel @@ -55,6 +54,8 @@ private[mllib] object PMMLModelExportFactory { throw new IllegalArgumentException( "PMML Export not supported for Multinomial Logistic Regression") } + case nb: NaiveBayesModel => + new NaiveBayesPMMLModelExport(nb, "naive bayes") case _ => throw new IllegalArgumentException( "PMML Export not supported for model: " + model.getClass.getName) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala new file mode 100644 index 0000000000000..87b04c686dedb --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.pmml.export + +import scala.{Array => SArray} + +import org.dmg.pmml._ + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel => SNaiveBayesModel} + +class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { + + test("Naive Bayes PMML export") { + val label = SArray(0.0, 1.0, 2.0) + val pi = SArray(0.5, 0.1, 0.4).map(math.log) + val theta = SArray( + SArray(0.70, 0.10, 0.10, 0.10), // label 0 + SArray(0.10, 0.70, 0.10, 0.10), // label 1 + SArray(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val nbModel = new SNaiveBayesModel(label, pi, theta, NaiveBayes.Multinomial) + val nbModelExport = PMMLModelExportFactory.createPMMLModelExport(nbModel) + val pmml = nbModelExport.getPmml + + assert(pmml.getHeader.getDescription === "naive bayes") + assert(pmml.getDataDictionary.getNumberOfFields === theta(0).length + 1) + + // assert Bayes input + val pmmlRegressionModel = pmml.getModels.get(0).asInstanceOf[NaiveBayesModel] + val bayesInputs = pmmlRegressionModel.getBayesInputs + assert(bayesInputs.getBayesInputs.size() === 4) + + val bIter = bayesInputs.iterator() + var i = 0 + while (bIter.hasNext) { + val bayesInput = bIter.next() + assert(bayesInput.getFieldName.getValue === "field_" + i) + val pIter = bayesInput.getPairCounts.iterator() + while (pIter.hasNext) { + val pairs = pIter.next() + val tIter = pairs.getTargetValueCounts.iterator() + var j = 0 + while (tIter.hasNext) { + val targetValueCount = tIter.next() + assert(targetValueCount.getCount === theta(j)(i)) + j += 1 + } + } + i += 1 + } + + // assert Bayes output + val bayesOutput = pmmlRegressionModel.getBayesOutput.getTargetValueCounts + assert(bayesOutput.getTargetValueCounts.size() === pi.length) + + val bayesOutputIter = bayesOutput.iterator() + i = 0 + while (bayesOutputIter.hasNext) { + val targetCount = bayesOutputIter.next() + assert(targetCount.getValue === "target_" + i) + assert(targetCount.getCount === pi(i)) + i += 1 + } + } +} + From 1a609f5a6968c0f5f68f72975f2966547bb9a501 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 27 Oct 2015 21:12:47 -0700 Subject: [PATCH 2/9] fix errors --- .../spark/mllib/classification/NaiveBayes.scala | 3 +-- .../pmml/export/NaiveBayesPMMLModelExport.scala | 3 ++- .../export/NaiveBayesPMMLModelExportSuite.scala | 2 +- .../export/PMMLModelExportFactorySuite.scala | 17 ++++++++++++++++- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 8dcbd7b6734a3..1a43519220ae9 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -19,8 +19,6 @@ package org.apache.spark.mllib.classification import java.lang.{Iterable => JIterable} -import org.apache.spark.mllib.pmml.PMMLExportable - import scala.collection.JavaConverters._ import org.json4s.JsonDSL._ @@ -29,6 +27,7 @@ import org.json4s.jackson.JsonMethods._ import org.apache.spark.{Logging, SparkContext, SparkException} import org.apache.spark.annotation.Since import org.apache.spark.mllib.linalg.{BLAS, DenseMatrix, DenseVector, SparseVector, Vector} +import org.apache.spark.mllib.pmml.PMMLExportable import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.{Loader, Saveable} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index 41da51360ba17..fdd183bd7fa51 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.pmml.export +// Scala Array is conflict with Array imported in PMML. import scala.{Array => SArray} import org.dmg.pmml._ @@ -24,7 +25,7 @@ import org.dmg.pmml._ import org.apache.spark.mllib.classification.{NaiveBayesModel => SNaiveBayesModel} /** - * PMML Model Export for GeneralizedLinearModel abstract class + * PMML Model Export for Naive Bayes abstract class */ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, description: String) extends PMMLModelExport { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala index 87b04c686dedb..314c1286d9c6b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.mllib.pmml.export +// Scala Array is conflict with Array imported in PMML. import scala.{Array => SArray} import org.dmg.pmml._ @@ -80,4 +81,3 @@ class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { } } } - diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index af49450961750..b30c1fa4ac428 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.pmml.export import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.classification.{LogisticRegressionModel, SVMModel} +import org.apache.spark.mllib.classification.{NaiveBayesModel, NaiveBayes, LogisticRegressionModel, SVMModel} import org.apache.spark.mllib.clustering.KMeansModel import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.{LassoModel, LinearRegressionModel, RidgeRegressionModel} @@ -38,6 +38,21 @@ class PMMLModelExportFactorySuite extends SparkFunSuite { assert(modelExport.isInstanceOf[KMeansPMMLModelExport]) } + test("PMMLModelExportFactory create NaiveBayesPMMLModelExport when passing a NaiveBayesModel") { + val label = Array(0.0, 1.0, 2.0) + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val nbModel = new NaiveBayesModel(label, pi, theta, NaiveBayes.Multinomial) + val modelExport = PMMLModelExportFactory.createPMMLModelExport(nbModel) + + assert(modelExport.isInstanceOf[NaiveBayesPMMLModelExport]) + } + test("PMMLModelExportFactory create GeneralizedLinearPMMLModelExport when passing a " + "LinearRegressionModel, RidgeRegressionModel or LassoModel") { val linearInput = LinearDataGenerator.generateLinearInput(3.0, Array(10.0, 10.0), 1, 17) From dd5224b4ce8d70f441249fe523795008566b4004 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Mon, 9 Nov 2015 13:43:41 +0800 Subject: [PATCH 3/9] fix errors --- .../pmml/export/NaiveBayesPMMLModelExport.scala | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index fdd183bd7fa51..9511bb6f54725 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -56,31 +56,33 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript // add Bayes input for (i <- featureIndices) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) + .withValues(SArray(new Value().withValue(i.toDouble.toString)): _*)) miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) val pairs = labelIndices.map { label => - new TargetValueCount().withValue("target_" + label).withCount(model.theta(label)(i)) + new TargetValueCount().withValue(label.toDouble.toString).withCount(model.theta(label)(i)) } val bayesInput = new BayesInput() val pairCounts = new PairCounts() .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairs: _*)) + .withValue(i.toDouble.toString) bayesInput.withFieldName(fields(i)).withPairCounts(pairCounts) bayesInputs.withBayesInputs(bayesInput) } // add Bayes output val targetValueCounts = model.pi.zipWithIndex.map { case (x, i) => - new TargetValueCount().withValue("target_" + i).withCount(x) } + new TargetValueCount().withValue(i.toDouble.toString).withCount(x) } bayesOutput .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) // add target field - val targetField = FieldName.create("target") + val targetField = FieldName.create("class") dataDictionary.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.DOUBLE) - .withValues(labelIndices.map { x => new Value().withValue(x.toString)}: _*)) - miningSchema.withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.TARGET)) + .withValues(labelIndices.map { x => new Value().withValue(x.toDouble.toString)}: _*)) + miningSchema.withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.PREDICTED)) nbModel.setMiningSchema(miningSchema) nbModel.setBayesInputs(bayesInputs) From e1295aafc9430df8f8211a96b02223ce25cdd1ca Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 11 Nov 2015 16:17:08 +0800 Subject: [PATCH 4/9] fix multi-normial dist naive bayes --- .../export/NaiveBayesPMMLModelExport.scala | 70 +++++++++++++------ 1 file changed, 48 insertions(+), 22 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index 9511bb6f54725..bac8884395b10 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -22,7 +22,7 @@ import scala.{Array => SArray} import org.dmg.pmml._ -import org.apache.spark.mllib.classification.{NaiveBayesModel => SNaiveBayesModel} +import org.apache.spark.mllib.classification.{NaiveBayesModel => SNaiveBayesModel, NaiveBayes} /** * PMML Model Export for Naive Bayes abstract class @@ -53,30 +53,48 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript val labelIndices = model.pi.indices val featureIndices = model.theta.head.indices - // add Bayes input - for (i <- featureIndices) { - fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) - .withValues(SArray(new Value().withValue(i.toDouble.toString)): _*)) - miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) - val pairs = labelIndices.map { label => - new TargetValueCount().withValue(label.toDouble.toString).withCount(model.theta(label)(i)) - } + if (model.modelType == NaiveBayes.Multinomial) { + // add Bayes input + for (i <- featureIndices) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) + miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) - val bayesInput = new BayesInput() - val pairCounts = new PairCounts() - .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairs: _*)) - .withValue(i.toDouble.toString) - bayesInput.withFieldName(fields(i)).withPairCounts(pairCounts) - bayesInputs.withBayesInputs(bayesInput) - } + val stats = labelIndices.map { label => + new TargetValueStat().withValue(label.toDouble.toString) + .withContinuousDistribution( + new GaussianDistribution().withMean(math.exp(model.theta(label)(i))).withVariance(1.0)) + } - // add Bayes output - val targetValueCounts = model.pi.zipWithIndex.map { case (x, i) => - new TargetValueCount().withValue(i.toDouble.toString).withCount(x) } - bayesOutput - .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) + val targetValueStats = new TargetValueStats().withTargetValueStats(stats: _*) + + val bayesInput = new BayesInput() + bayesInput.withFieldName(fields(i)).withTargetValueStats(targetValueStats) + bayesInputs.withBayesInputs(bayesInput) + } + } else if (model.modelType == NaiveBayes.Bernoulli) { + // add Bayes input + for (i <- featureIndices) { + fields(i) = FieldName.create("field_" + i) + dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) + .withValues(SArray(new Value().withValue(i.toDouble.toString)): _*)) + miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) + + val pairs = labelIndices.map { label => + new TargetValueCount().withValue(label.toDouble.toString).withCount(model.theta(label)(i)) + } + + val bayesInput = new BayesInput() + val pairCounts = new PairCounts() + .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairs: _*)) + .withValue(i.toDouble.toString) + bayesInput.withFieldName(fields(i)).withPairCounts(pairCounts) + bayesInputs.withBayesInputs(bayesInput) + } + } else { + throw new Exception("Unsupported model type.") + } // add target field val targetField = FieldName.create("class") @@ -84,6 +102,14 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript .withValues(labelIndices.map { x => new Value().withValue(x.toDouble.toString)}: _*)) miningSchema.withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.PREDICTED)) + // add Bayes output + val targetValueCounts = model.pi.zipWithIndex.map { case (x, i) => + new TargetValueCount().withValue(i.toDouble.toString).withCount(math.exp(x)) } + bayesOutput + .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) + .withFieldName(targetField) + + nbModel.setMiningSchema(miningSchema) nbModel.setBayesInputs(bayesInputs) nbModel.setBayesOutput(bayesOutput) From 3eb227ff7d6ef16ce6d965fd601319316bba8afe Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 12 Nov 2015 16:03:58 +0800 Subject: [PATCH 5/9] fix bernulli model --- .../export/NaiveBayesPMMLModelExport.scala | 23 ++++++++++++++----- .../NaiveBayesPMMLModelExportSuite.scala | 4 ++-- 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index bac8884395b10..be30741221183 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -81,15 +81,26 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript .withValues(SArray(new Value().withValue(i.toDouble.toString)): _*)) miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) - val pairs = labelIndices.map { label => - new TargetValueCount().withValue(label.toDouble.toString).withCount(model.theta(label)(i)) + val pairsExist = labelIndices.map { label => + new TargetValueCount().withValue(label.toDouble.toString).withCount(math.exp(model.theta(label)(i))) } + val pairCountsExist = new PairCounts() + .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairsExist: _*)) + .withValue("1.0") + + val pairsAbsent = labelIndices.map { label => + new TargetValueCount().withValue(label.toDouble.toString).withCount(1.0 - math.exp(model.theta(label)(i))) + } + + val pairCountsAbsent = new PairCounts() + .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairsAbsent: _*)) + .withValue("0.0") + + val bayesInput = new BayesInput() - val pairCounts = new PairCounts() - .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairs: _*)) - .withValue(i.toDouble.toString) - bayesInput.withFieldName(fields(i)).withPairCounts(pairCounts) + + bayesInput.withFieldName(fields(i)).withPairCounts(pairCountsExist, pairCountsAbsent) bayesInputs.withBayesInputs(bayesInput) } } else { diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala index 314c1286d9c6b..c4f87894bdf6f 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala @@ -75,8 +75,8 @@ class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { i = 0 while (bayesOutputIter.hasNext) { val targetCount = bayesOutputIter.next() - assert(targetCount.getValue === "target_" + i) - assert(targetCount.getCount === pi(i)) + assert(targetCount.getValue.toDouble === i) + assert(targetCount.getCount === math.exp(pi(i))) i += 1 } } From 7d8fcb72b0f737a44c282dc226bf52d387b46690 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Thu, 12 Nov 2015 17:14:10 +0800 Subject: [PATCH 6/9] fix style --- .../export/NaiveBayesPMMLModelExport.scala | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index be30741221183..1cc61cb4df8e1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -53,18 +53,19 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript val labelIndices = model.pi.indices val featureIndices = model.theta.head.indices - + // add Bayes input if (model.modelType == NaiveBayes.Multinomial) { - // add Bayes input for (i <- featureIndices) { fields(i) = FieldName.create("field_" + i) dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) + miningSchema + .withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) val stats = labelIndices.map { label => new TargetValueStat().withValue(label.toDouble.toString) .withContinuousDistribution( - new GaussianDistribution().withMean(math.exp(model.theta(label)(i))).withVariance(1.0)) + new GaussianDistribution() + .withMean(math.exp(model.theta(label)(i))).withVariance(1.0)) } val targetValueStats = new TargetValueStats().withTargetValueStats(stats: _*) @@ -74,15 +75,16 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript bayesInputs.withBayesInputs(bayesInput) } } else if (model.modelType == NaiveBayes.Bernoulli) { - // add Bayes input for (i <- featureIndices) { fields(i) = FieldName.create("field_" + i) dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) .withValues(SArray(new Value().withValue(i.toDouble.toString)): _*)) - miningSchema.withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) + miningSchema + .withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) val pairsExist = labelIndices.map { label => - new TargetValueCount().withValue(label.toDouble.toString).withCount(math.exp(model.theta(label)(i))) + new TargetValueCount() + .withValue(label.toDouble.toString).withCount(math.exp(model.theta(label)(i))) } val pairCountsExist = new PairCounts() @@ -90,14 +92,14 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript .withValue("1.0") val pairsAbsent = labelIndices.map { label => - new TargetValueCount().withValue(label.toDouble.toString).withCount(1.0 - math.exp(model.theta(label)(i))) + new TargetValueCount() + .withValue(label.toDouble.toString).withCount(1.0 - math.exp(model.theta(label)(i))) } val pairCountsAbsent = new PairCounts() .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairsAbsent: _*)) .withValue("0.0") - val bayesInput = new BayesInput() bayesInput.withFieldName(fields(i)).withPairCounts(pairCountsExist, pairCountsAbsent) @@ -111,7 +113,8 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript val targetField = FieldName.create("class") dataDictionary.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.DOUBLE) .withValues(labelIndices.map { x => new Value().withValue(x.toDouble.toString)}: _*)) - miningSchema.withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.PREDICTED)) + miningSchema + .withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.PREDICTED)) // add Bayes output val targetValueCounts = model.pi.zipWithIndex.map { case (x, i) => @@ -120,7 +123,6 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) .withFieldName(targetField) - nbModel.setMiningSchema(miningSchema) nbModel.setBayesInputs(bayesInputs) nbModel.setBayesOutput(bayesOutput) From 4dad4db9de085832c3d275db742e6422d876709b Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Fri, 13 Nov 2015 15:35:06 +0800 Subject: [PATCH 7/9] add output --- .../export/NaiveBayesPMMLModelExport.scala | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index 1cc61cb4df8e1..31e3cb44658d2 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -78,7 +78,7 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript for (i <- featureIndices) { fields(i) = FieldName.create("field_" + i) dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) - .withValues(SArray(new Value().withValue(i.toDouble.toString)): _*)) + .withValues(new Value("0.0"), new Value("1.0"))) miningSchema .withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) @@ -123,9 +123,25 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) .withFieldName(targetField) + // add output + val output = new Output() + output.withOutputFields( + new OutputField() + .withName(FieldName.create("Predicted_class")) + .withFeature(ResultFeatureType.PREDICTED_VALUE)) + output.withOutputFields(labelIndices.map { label => + new OutputField() + .withName(FieldName.create(s"Probability_${label.toDouble}")) + .withOpType(OpType.CONTINUOUS) + .withDataType(DataType.DOUBLE) + .withFeature(ResultFeatureType.PROBABILITY) + .withValue(s"${label.toDouble}") + }: _*) + nbModel.setMiningSchema(miningSchema) nbModel.setBayesInputs(bayesInputs) nbModel.setBayesOutput(bayesOutput) + nbModel.setOutput(output) dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) From 5a89d9dd5af22b08365efb70512932fd8cbf896d Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Tue, 8 Dec 2015 17:28:18 +0800 Subject: [PATCH 8/9] remove multinomial case --- .../export/NaiveBayesPMMLModelExport.scala | 25 +++---------------- .../NaiveBayesPMMLModelExportSuite.scala | 18 ++++++++++--- .../export/PMMLModelExportFactorySuite.scala | 2 +- 3 files changed, 19 insertions(+), 26 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index 31e3cb44658d2..6e1f64c87f820 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -54,27 +54,7 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript val featureIndices = model.theta.head.indices // add Bayes input - if (model.modelType == NaiveBayes.Multinomial) { - for (i <- featureIndices) { - fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CONTINUOUS, DataType.DOUBLE)) - miningSchema - .withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) - - val stats = labelIndices.map { label => - new TargetValueStat().withValue(label.toDouble.toString) - .withContinuousDistribution( - new GaussianDistribution() - .withMean(math.exp(model.theta(label)(i))).withVariance(1.0)) - } - - val targetValueStats = new TargetValueStats().withTargetValueStats(stats: _*) - - val bayesInput = new BayesInput() - bayesInput.withFieldName(fields(i)).withTargetValueStats(targetValueStats) - bayesInputs.withBayesInputs(bayesInput) - } - } else if (model.modelType == NaiveBayes.Bernoulli) { + if (model.modelType == NaiveBayes.Bernoulli) { for (i <- featureIndices) { fields(i) = FieldName.create("field_" + i) dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) @@ -106,7 +86,8 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript bayesInputs.withBayesInputs(bayesInput) } } else { - throw new Exception("Unsupported model type.") + throw new IllegalArgumentException( + "Naive Bayes model PMML export only supports Bernoulli model type for now.") } // add target field diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala index c4f87894bdf6f..4162ffe476d89 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExportSuite.scala @@ -24,10 +24,11 @@ import org.dmg.pmml._ import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel => SNaiveBayesModel} +import org.apache.spark.mllib.util.TestingUtils._ class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { - test("Naive Bayes PMML export") { + test("Naive Bayes PMML export: Bernoulli model type") { val label = SArray(0.0, 1.0, 2.0) val pi = SArray(0.5, 0.1, 0.4).map(math.log) val theta = SArray( @@ -36,7 +37,7 @@ class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { SArray(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - val nbModel = new SNaiveBayesModel(label, pi, theta, NaiveBayes.Multinomial) + val nbModel = new SNaiveBayesModel(label, pi, theta, NaiveBayes.Bernoulli) val nbModelExport = PMMLModelExportFactory.createPMMLModelExport(nbModel) val pmml = nbModelExport.getPmml @@ -53,6 +54,10 @@ class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { while (bIter.hasNext) { val bayesInput = bIter.next() assert(bayesInput.getFieldName.getValue === "field_" + i) + val pairCounts = bayesInput.getPairCounts + assert(pairCounts.size() === 2, + "Only two values in one variables is allowed in Bernoulli model type.") + var k = 0 val pIter = bayesInput.getPairCounts.iterator() while (pIter.hasNext) { val pairs = pIter.next() @@ -60,9 +65,16 @@ class NaiveBayesPMMLModelExportSuite extends SparkFunSuite { var j = 0 while (tIter.hasNext) { val targetValueCount = tIter.next() - assert(targetValueCount.getCount === theta(j)(i)) + if (k == 0) { + // test values of pairsExist + assert(math.log(targetValueCount.getCount) ~== theta(j)(i) relTol 1e-5) + } else { + // test values of pairsAbsent + assert(math.log(1 - targetValueCount.getCount) ~== theta(j)(i) relTol 1e-5) + } j += 1 } + k += 1 } i += 1 } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index b30c1fa4ac428..c38bd756f50ca 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -47,7 +47,7 @@ class PMMLModelExportFactorySuite extends SparkFunSuite { Array(0.10, 0.10, 0.70, 0.10) // label 2 ).map(_.map(math.log)) - val nbModel = new NaiveBayesModel(label, pi, theta, NaiveBayes.Multinomial) + val nbModel = new NaiveBayesModel(label, pi, theta, NaiveBayes.Bernoulli) val modelExport = PMMLModelExportFactory.createPMMLModelExport(nbModel) assert(modelExport.isInstanceOf[NaiveBayesPMMLModelExport]) From b17491d5f2138c8bb24db1498a9c2f8b27943046 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 9 Dec 2015 00:17:07 +0800 Subject: [PATCH 9/9] change API with JPMML 1.2.7 --- .../export/NaiveBayesPMMLModelExport.scala | 58 ++++++++++--------- .../export/PMMLModelExportFactorySuite.scala | 17 ++++++ 2 files changed, 47 insertions(+), 28 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala index 6e1f64c87f820..18ed52673d764 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/NaiveBayesPMMLModelExport.scala @@ -57,33 +57,34 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript if (model.modelType == NaiveBayes.Bernoulli) { for (i <- featureIndices) { fields(i) = FieldName.create("field_" + i) - dataDictionary.withDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) - .withValues(new Value("0.0"), new Value("1.0"))) + dataDictionary + .addDataFields(new DataField(fields(i), OpType.CATEGORICAL, DataType.DOUBLE) + .addValues(new Value("0.0"), new Value("1.0"))) miningSchema - .withMiningFields(new MiningField(fields(i)).withUsageType(FieldUsageType.ACTIVE)) + .addMiningFields(new MiningField(fields(i)).setUsageType(FieldUsageType.ACTIVE)) val pairsExist = labelIndices.map { label => new TargetValueCount() - .withValue(label.toDouble.toString).withCount(math.exp(model.theta(label)(i))) + .setValue(label.toDouble.toString).setCount(math.exp(model.theta(label)(i))) } val pairCountsExist = new PairCounts() - .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairsExist: _*)) - .withValue("1.0") + .setTargetValueCounts(new TargetValueCounts().addTargetValueCounts(pairsExist: _*)) + .setValue("1.0") val pairsAbsent = labelIndices.map { label => new TargetValueCount() - .withValue(label.toDouble.toString).withCount(1.0 - math.exp(model.theta(label)(i))) + .setValue(label.toDouble.toString).setCount(1.0 - math.exp(model.theta(label)(i))) } val pairCountsAbsent = new PairCounts() - .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(pairsAbsent: _*)) - .withValue("0.0") + .setTargetValueCounts(new TargetValueCounts().addTargetValueCounts(pairsAbsent: _*)) + .setValue("0.0") val bayesInput = new BayesInput() - bayesInput.withFieldName(fields(i)).withPairCounts(pairCountsExist, pairCountsAbsent) - bayesInputs.withBayesInputs(bayesInput) + bayesInput.setFieldName(fields(i)).addPairCounts(pairCountsExist, pairCountsAbsent) + bayesInputs.addBayesInputs(bayesInput) } } else { throw new IllegalArgumentException( @@ -92,31 +93,32 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript // add target field val targetField = FieldName.create("class") - dataDictionary.withDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.DOUBLE) - .withValues(labelIndices.map { x => new Value().withValue(x.toDouble.toString)}: _*)) + dataDictionary + .addDataFields(new DataField(targetField, OpType.CATEGORICAL, DataType.DOUBLE) + .addValues(labelIndices.map { x => new Value().setValue(x.toDouble.toString)}: _*)) miningSchema - .withMiningFields(new MiningField(targetField).withUsageType(FieldUsageType.PREDICTED)) + .addMiningFields(new MiningField(targetField).setUsageType(FieldUsageType.PREDICTED)) // add Bayes output val targetValueCounts = model.pi.zipWithIndex.map { case (x, i) => - new TargetValueCount().withValue(i.toDouble.toString).withCount(math.exp(x)) } + new TargetValueCount().setValue(i.toDouble.toString).setCount(math.exp(x)) } bayesOutput - .withTargetValueCounts(new TargetValueCounts().withTargetValueCounts(targetValueCounts: _*)) - .withFieldName(targetField) + .setTargetValueCounts(new TargetValueCounts().addTargetValueCounts(targetValueCounts: _*)) + .setFieldName(targetField) // add output val output = new Output() - output.withOutputFields( + output.addOutputFields( new OutputField() - .withName(FieldName.create("Predicted_class")) - .withFeature(ResultFeatureType.PREDICTED_VALUE)) - output.withOutputFields(labelIndices.map { label => + .setName(FieldName.create("Predicted_class")) + .setFeature(FeatureType.PREDICTED_VALUE)) + output.addOutputFields(labelIndices.map { label => new OutputField() - .withName(FieldName.create(s"Probability_${label.toDouble}")) - .withOpType(OpType.CONTINUOUS) - .withDataType(DataType.DOUBLE) - .withFeature(ResultFeatureType.PROBABILITY) - .withValue(s"${label.toDouble}") + .setName(FieldName.create(s"Probability_${label.toDouble}")) + .setOpType(OpType.CONTINUOUS) + .setDataType(DataType.DOUBLE) + .setFeature(FeatureType.PROBABILITY) + .setValue(s"${label.toDouble}") }: _*) nbModel.setMiningSchema(miningSchema) @@ -124,9 +126,9 @@ private[mllib] class NaiveBayesPMMLModelExport(model: SNaiveBayesModel, descript nbModel.setBayesOutput(bayesOutput) nbModel.setOutput(output) - dataDictionary.withNumberOfFields(dataDictionary.getDataFields.size) + dataDictionary.setNumberOfFields(dataDictionary.getDataFields.size) pmml.setDataDictionary(dataDictionary) - pmml.withModels(nbModel) + pmml.addModels(nbModel) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala index c38bd756f50ca..37af201a9fe8d 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/pmml/export/PMMLModelExportFactorySuite.scala @@ -106,4 +106,21 @@ class PMMLModelExportFactorySuite extends SparkFunSuite { PMMLModelExportFactory.createPMMLModelExport(invalidModel) } } + + test("PMMLModelExportFactory throw IllegalArgumentException " + + "when passing a Multinomial Naive Bayes") { + val label = Array(0.0, 1.0, 2.0) + val pi = Array(0.5, 0.1, 0.4).map(math.log) + val theta = Array( + Array(0.70, 0.10, 0.10, 0.10), // label 0 + Array(0.10, 0.70, 0.10, 0.10), // label 1 + Array(0.10, 0.10, 0.70, 0.10) // label 2 + ).map(_.map(math.log)) + + val nbModel = new NaiveBayesModel(label, pi, theta, NaiveBayes.Multinomial) + + intercept[IllegalArgumentException] { + PMMLModelExportFactory.createPMMLModelExport(nbModel) + } + } }