Skip to content

Commit

Permalink
Fixed a non-serializable error that was causing NaiveBayes test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
leahmcguire committed Mar 7, 2015
1 parent 2d0c1ba commit e2d925e
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,19 +46,19 @@ class NaiveBayesModel private[mllib] (
val labels: Array[Double],
val pi: Array[Double],
val theta: Array[Array[Double]],
val modelType: NaiveBayes.ModelType)
val modelType: String)
extends ClassificationModel with Serializable with Saveable {

def this(labels: Array[Double], pi: Array[Double], theta: Array[Array[Double]]) =
this(labels, pi, theta, NaiveBayes.Multinomial)
this(labels, pi, theta, NaiveBayes.Multinomial.toString)

private val brzPi = new BDV[Double](pi)
private val brzTheta = new BDM(theta(0).length, theta.length, theta.flatten).t

// Bernoulli scoring requires log(condprob) if the feature is 1, log(1 - condprob) if it is 0.
// This precomputes log(1.0 - exp(theta)) and its sum so that this condition can be applied
// with linear algebra in the predict function.
private val (brzNegTheta, brzNegThetaSum) = modelType match {
private val (brzNegTheta, brzNegThetaSum) = NaiveBayes.ModelType.fromString(modelType) match {
case NaiveBayes.Multinomial => (None, None)
case NaiveBayes.Bernoulli =>
val negTheta = brzLog((brzExp(brzTheta.copy) :*= (-1.0)) :+= 1.0) // log(1.0 - exp(x))
Expand All @@ -74,7 +74,7 @@ class NaiveBayesModel private[mllib] (
}

override def predict(testData: Vector): Double = {
modelType match {
NaiveBayes.ModelType.fromString(modelType) match {
case NaiveBayes.Multinomial =>
labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
case NaiveBayes.Bernoulli =>
Expand All @@ -84,7 +84,7 @@ class NaiveBayesModel private[mllib] (
}

override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType.toString)
val data = NaiveBayesModel.SaveLoadV1_0.Data(labels, pi, theta, modelType)
NaiveBayesModel.SaveLoadV1_0.save(sc, path, data)
}

Expand Down Expand Up @@ -137,15 +137,15 @@ object NaiveBayesModel extends Loader[NaiveBayesModel] {
val labels = data.getAs[Seq[Double]](0).toArray
val pi = data.getAs[Seq[Double]](1).toArray
val theta = data.getAs[Seq[Seq[Double]]](2).map(_.toArray).toArray
val modelType = NaiveBayes.ModelType.fromString(data.getString(3))
val modelType = NaiveBayes.ModelType.fromString(data.getString(3)).toString
new NaiveBayesModel(labels, pi, theta, modelType)
}
}

override def load(sc: SparkContext, path: String): NaiveBayesModel = {
def getModelType(metadata: JValue): NaiveBayes.ModelType = {
def getModelType(metadata: JValue): String = {
implicit val formats = DefaultFormats
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String])
NaiveBayes.ModelType.fromString((metadata \ "modelType").extract[String]).toString
}
val (loadedClassName, version, metadata) = loadMetadata(sc, path)
val classNameV1_0 = SaveLoadV1_0.thisClassName
Expand Down Expand Up @@ -265,7 +265,7 @@ class NaiveBayes private (
i += 1
}

new NaiveBayesModel(labels, pi, theta, modelType)
new NaiveBayesModel(labels, pi, theta, modelType.toString)
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ object NaiveBayesSuite {
sample: Int = 10): Seq[LabeledPoint] = {
val D = theta(0).length
val rnd = new Random(seed)

c
val _pi = pi.map(math.pow(math.E, _))
val _theta = theta.map(row => row.map(math.pow(math.E, _)))

Expand All @@ -77,7 +77,7 @@ object NaiveBayesSuite {

/** Binary labels, 3 features */
private val binaryModel = new NaiveBayesModel(labels = Array(0.0, 1.0), pi = Array(0.2, 0.8),
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli)
theta = Array(Array(0.1, 0.3, 0.6), Array(0.2, 0.4, 0.4)), NaiveBayes.Bernoulli.toString)
}

class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
Expand Down Expand Up @@ -111,7 +111,6 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {

test("Naive Bayes Multinomial") {
val nPoints = 1000

val pi = Array(0.5, 0.1, 0.4).map(math.log)
val theta = Array(
Array(0.70, 0.10, 0.10, 0.10), // label 0
Expand All @@ -120,7 +119,11 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
).map(_.map(math.log))

val testData = NaiveBayesSuite.generateNaiveBayesInput(
pi, theta, nPoints, 42, NaiveBayes.Multinomial)
pi,
theta,
nPoints,
42,
NaiveBayes.Multinomial)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()

Expand All @@ -144,7 +147,6 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {

test("Naive Bayes Bernoulli") {
val nPoints = 10000

val pi = Array(0.5, 0.3, 0.2).map(math.log)
val theta = Array(
Array(0.50, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.40), // label 0
Expand Down

0 comments on commit e2d925e

Please sign in to comment.