Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-6255] [MLLIB] Support multiclass classification in Python API #5137

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -78,7 +78,13 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector): JList[Object] = {
try {
val model = learner.run(data.rdd.persist(StorageLevel.MEMORY_AND_DISK), initialWeights)
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
if (model.isInstanceOf[LogisticRegressionModel]) {
val lrModel = model.asInstanceOf[LogisticRegressionModel]
List(lrModel.weights, lrModel.intercept, lrModel.numFeatures, lrModel.numClasses)
.map(_.asInstanceOf[Object]).asJava
} else {
List(model.weights, model.intercept).map(_.asInstanceOf[Object]).asJava
}
} finally {
data.rdd.unpersist(blocking = false)
}
Expand Down Expand Up @@ -181,9 +187,11 @@ private[python] class PythonMLLibAPI extends Serializable {
miniBatchFraction: Double,
initialWeights: Vector,
regType: String,
intercept: Boolean): JList[Object] = {
intercept: Boolean,
validateData: Boolean): JList[Object] = {
val SVMAlg = new SVMWithSGD()
SVMAlg.setIntercept(intercept)
.setValidateData(validateData)
SVMAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
Expand All @@ -207,9 +215,11 @@ private[python] class PythonMLLibAPI extends Serializable {
initialWeights: Vector,
regParam: Double,
regType: String,
intercept: Boolean): JList[Object] = {
intercept: Boolean,
validateData: Boolean): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithSGD()
LogRegAlg.setIntercept(intercept)
.setValidateData(validateData)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
Expand All @@ -233,9 +243,13 @@ private[python] class PythonMLLibAPI extends Serializable {
regType: String,
intercept: Boolean,
corrections: Int,
tolerance: Double): JList[Object] = {
tolerance: Double,
validateData: Boolean,
numClasses: Int): JList[Object] = {
val LogRegAlg = new LogisticRegressionWithLBFGS()
LogRegAlg.setIntercept(intercept)
.setValidateData(validateData)
.setNumClasses(numClasses)
LogRegAlg.optimizer
.setNumIterations(numIterations)
.setRegParam(regParam)
Expand Down
134 changes: 108 additions & 26 deletions python/pyspark/mllib/classification.py
Expand Up @@ -22,7 +22,7 @@

from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
from pyspark.mllib.util import Saveable, Loader, inherit_doc

Expand All @@ -31,13 +31,13 @@
'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']


class LinearBinaryClassificationModel(LinearModel):
class LinearClassificationModel(LinearModel):
"""
Represents a linear binary classification model that predicts to whether an
example is positive (1.0) or negative (0.0).
A private abstract class representing a multiclass classification model.
The categories are represented by int values: 0, 1, 2, etc.
"""
def __init__(self, weights, intercept):
super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
super(LinearClassificationModel, self).__init__(weights, intercept)
self._threshold = None

def setThreshold(self, value):
Expand All @@ -47,14 +47,26 @@ def setThreshold(self, value):
Sets the threshold that separates positive predictions from negative
predictions. An example with prediction score greater than or equal
to this threshold is identified as a positive, and negative otherwise.
It is used for binary classification only.
"""
self._threshold = value

@property
def threshold(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you make this a property, please?

"""
.. note:: Experimental

Returns the threshold (if any) used for converting raw prediction scores
into 0/1 predictions. It is used for binary classification only.
"""
return self._threshold

def clearThreshold(self):
"""
.. note:: Experimental

Clears the threshold so that `predict` will output raw prediction scores.
It is used for binary classification only.
"""
self._threshold = None

Expand All @@ -66,7 +78,7 @@ def predict(self, test):
raise NotImplementedError


class LogisticRegressionModel(LinearBinaryClassificationModel):
class LogisticRegressionModel(LinearClassificationModel):

"""A linear binary classification model derived from logistic regression.

Expand Down Expand Up @@ -112,10 +124,39 @@ class LogisticRegressionModel(LinearBinaryClassificationModel):
... os.removedirs(path)
... except:
... pass
>>> multi_class_data = [
... LabeledPoint(0.0, [0.0, 1.0, 0.0]),
... LabeledPoint(1.0, [1.0, 0.0, 0.0]),
... LabeledPoint(2.0, [0.0, 0.0, 1.0])
... ]
>>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3)
>>> mcm.predict([0.0, 0.5, 0.0])
0
>>> mcm.predict([0.8, 0.0, 0.0])
1
>>> mcm.predict([0.0, 0.0, 0.3])
2
"""
def __init__(self, weights, intercept):
def __init__(self, weights, intercept, numFeatures, numClasses):
super(LogisticRegressionModel, self).__init__(weights, intercept)
self._numFeatures = int(numFeatures)
self._numClasses = int(numClasses)
self._threshold = 0.5
if self._numClasses == 2:
self._dataWithBiasSize = None
self._weightsMatrix = None
else:
self._dataWithBiasSize = self._coeff.size / (self._numClasses - 1)
self._weightsMatrix = self._coeff.toArray().reshape(self._numClasses - 1,
self._dataWithBiasSize)

@property
def numFeatures(self):
return self._numFeatures

@property
def numClasses(self):
return self._numClasses

def predict(self, x):
"""
Expand All @@ -126,20 +167,38 @@ def predict(self, x):
return x.map(lambda v: self.predict(v))

x = _convert_to_vector(x)
margin = self.weights.dot(x) + self._intercept
if margin > 0:
prob = 1 / (1 + exp(-margin))
if self.numClasses == 2:
margin = self.weights.dot(x) + self._intercept
if margin > 0:
prob = 1 / (1 + exp(-margin))
else:
exp_margin = exp(margin)
prob = exp_margin / (1 + exp_margin)
if self._threshold is None:
return prob
else:
return 1 if prob > self._threshold else 0
else:
exp_margin = exp(margin)
prob = exp_margin / (1 + exp_margin)
if self._threshold is None:
return prob
else:
return 1 if prob > self._threshold else 0
best_class = 0
max_margin = 0.0
if x.size + 1 == self._dataWithBiasSize:
for i in range(0, self._numClasses - 1):
margin = x.dot(self._weightsMatrix[i][0:x.size]) + \
self._weightsMatrix[i][x.size]
if margin > max_margin:
max_margin = margin
best_class = i + 1
else:
for i in range(0, self._numClasses - 1):
margin = x.dot(self._weightsMatrix[i])
if margin > max_margin:
max_margin = margin
best_class = i + 1
return best_class

def save(self, sc, path):
java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be updated to take numClasses, numFeatures

_py2java(sc, self._coeff), self.intercept)
_py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses)
java_model.save(sc._jsc.sc(), path)

@classmethod
Expand All @@ -148,8 +207,10 @@ def load(cls, sc, path):
sc._jsc.sc(), path)
weights = _java2py(sc, java_model.weights())
intercept = java_model.intercept()
numFeatures = java_model.numFeatures()
numClasses = java_model.numClasses()
threshold = java_model.getThreshold().get()
model = LogisticRegressionModel(weights, intercept)
model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses)
model.setThreshold(threshold)
return model

Expand All @@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object):

@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
initialWeights=None, regParam=0.01, regType="l2", intercept=False):
initialWeights=None, regParam=0.01, regType="l2", intercept=False,
validateData=True):
"""
Train a logistic regression model on the given data.

Expand All @@ -184,11 +246,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
:param validateData: Boolean parameter which indicates if the
algorithm should validate data before training.
(default: True)
"""
def train(rdd, i):
return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
float(step), float(miniBatchFraction), i, float(regParam), regType,
bool(intercept))
bool(intercept), bool(validateData))

return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)

Expand All @@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object):

@classmethod
def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
intercept=False, corrections=10, tolerance=1e-4):
intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2):
"""
Train a logistic regression model on the given data.

Expand All @@ -223,6 +288,11 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
update (default: 10).
:param tolerance: The convergence tolerance of iterations for
L-BFGS (default: 1e-4).
:param validateData: Boolean parameter which indicates if the
algorithm should validate data before training.
(default: True)
:param numClasses: The number of classes (i.e., outcomes) a label can take
in Multinomial Logistic Regression (default: 2).

>>> data = [
... LabeledPoint(0.0, [0.0, 1.0]),
Expand All @@ -237,12 +307,20 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType
def train(rdd, i):
return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
float(regParam), regType, bool(intercept), int(corrections),
float(tolerance))

float(tolerance), bool(validateData), int(numClasses))

if initialWeights is None:
if numClasses == 2:
initialWeights = [0.0] * len(data.first().features)
else:
if intercept:
initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1)
else:
initialWeights = [0.0] * len(data.first().features) * (numClasses - 1)
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)


class SVMModel(LinearBinaryClassificationModel):
class SVMModel(LinearClassificationModel):

"""A support vector machine.

Expand Down Expand Up @@ -325,7 +403,8 @@ class SVMWithSGD(object):

@classmethod
def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False):
miniBatchFraction=1.0, initialWeights=None, regType="l2",
intercept=False, validateData=True):
"""
Train a support vector machine on the given data.

Expand All @@ -351,11 +430,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
:param validateData: Boolean parameter which indicates if the
algorithm should validate data before training.
(default: True)
"""
def train(rdd, i):
return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
float(regParam), float(miniBatchFraction), i, regType,
bool(intercept))
bool(intercept), bool(validateData))

return _regression_train_wrapper(train, SVMModel, data, initialWeights)

Expand Down
10 changes: 8 additions & 2 deletions python/pyspark/mllib/regression.py
Expand Up @@ -160,13 +160,19 @@ def load(cls, sc, path):
# return the result of a call to the appropriate JVM stub.
# _regression_train_wrapper is responsible for setup and error checking.
def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
from pyspark.mllib.classification import LogisticRegressionModel
first = data.first()
if not isinstance(first, LabeledPoint):
raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
if initial_weights is None:
initial_weights = [0.0] * len(data.first().features)
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)
if (modelClass == LogisticRegressionModel):
weights, intercept, numFeatures, numClasses = train_func(
data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept, numFeatures, numClasses)
else:
weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)


class LinearRegressionWithSGD(object):
Expand Down