[SPARK-15096][ML] LogisticRegression MultiClassSummarizer numClasses can fail if no valid labels are found

## What changes were proposed in this pull request?

Return 0 from `numClasses` when no valid labels have been seen (`distinctMap` is empty), instead of letting `distinctMap.keySet.max` throw `UnsupportedOperationException: empty.max`.

## How was this patch tested?

Added a unit test that calls `histogram` and `numClasses` on a `MultiClassSummarizer` that has not seen any labels.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #12969 from wangmiao1981/logisticR.
wangmiao1981 authored and srowen committed May 14, 2016
1 parent 0f1f31d commit 354f8f1
Showing 2 changed files with 5 additions and 1 deletion.
@@ -745,7 +745,7 @@ private[classification] class MultiClassSummarizer extends Serializable {
def countInvalid: Long = totalInvalidCnt

/** @return The number of distinct labels in the input dataset. */
- def numClasses: Int = distinctMap.keySet.max + 1
+ def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1

/** @return The weightSum of each label in the input dataset. */
def histogram: Array[Double] = {
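
The guard in the hunk above can be tried in isolation. What follows is a minimal sketch, not the Spark implementation: `LabelCounter` is a hypothetical stand-in for `MultiClassSummarizer`, the `(count, weightSum)` shape of `distinctMap` is an assumption, and `numClassesUnguarded` exists only to show the pre-patch behavior. Labels are assumed to be non-negative integers.

```scala
import scala.collection.mutable

// Hypothetical stand-in for MultiClassSummarizer; only the label bookkeeping is modeled.
class LabelCounter {
  // label -> (count, weightSum). The field name follows the diff; its value type is an assumption.
  private val distinctMap = mutable.HashMap.empty[Int, (Long, Double)]

  // Record one observation of a (non-negative) label with the given weight.
  def add(label: Int, weight: Double = 1.0): this.type = {
    val (count, weightSum) = distinctMap.getOrElse(label, (0L, 0.0))
    distinctMap.put(label, (count + 1L, weightSum + weight))
    this
  }

  // Pre-patch expression: calling max on an empty key set throws.
  def numClassesUnguarded: Int = distinctMap.keySet.max + 1

  // Post-patch expression: an empty map yields 0 classes.
  def numClasses: Int = if (distinctMap.isEmpty) 0 else distinctMap.keySet.max + 1

  // Weight sum per label index; labels never seen contribute 0.0.
  def histogram: Array[Double] =
    Array.tabulate(numClasses)(i => distinctMap.getOrElse(i, (0L, 0.0))._2)
}
```

Because `histogram` sizes its result from `numClasses`, guarding `numClasses` is enough for `histogram` to return an empty array on a fresh instance, which is what the new test assertions check.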
@@ -256,6 +256,10 @@ class LogisticRegressionSuite
assert(summarizer4.countInvalid === 2)
assert(summarizer4.numClasses === 4)

+ val summarizer5 = new MultiClassSummarizer
+ assert(summarizer5.histogram.isEmpty)
+ assert(summarizer5.numClasses === 0)
+
// small map merges large one
val summarizerA = summarizer1.merge(summarizer2)
assert(summarizerA.hashCode() === summarizer2.hashCode())