New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
[SPARK-6255] [MLLIB] Support multiclass classification in Python API #5137
Changes from 3 commits
ded847c
b0d9c63
fc7990b
444d5e2
0bd531e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -22,7 +22,7 @@ | |
|
||
from pyspark import RDD | ||
from pyspark.mllib.common import callMLlibFunc, _py2java, _java2py | ||
from pyspark.mllib.linalg import SparseVector, _convert_to_vector | ||
from pyspark.mllib.linalg import DenseVector, SparseVector, _convert_to_vector | ||
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper | ||
from pyspark.mllib.util import Saveable, Loader, inherit_doc | ||
|
||
|
@@ -31,13 +31,14 @@ | |
'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes'] | ||
|
||
|
||
class LinearBinaryClassificationModel(LinearModel): | ||
class LinearClassificationModel(LinearModel): | ||
""" | ||
Represents a linear binary classification model that predicts to whether an | ||
example is positive (1.0) or negative (0.0). | ||
A private abstract class represents a classification model that predicts to | ||
which of a set of categories an example belongs. The categories are represented by | ||
int values: 0, 1, 2, etc. | ||
""" | ||
def __init__(self, weights, intercept): | ||
super(LinearBinaryClassificationModel, self).__init__(weights, intercept) | ||
super(LinearClassificationModel, self).__init__(weights, intercept) | ||
self._threshold = None | ||
|
||
def setThreshold(self, value): | ||
|
@@ -50,6 +51,15 @@ def setThreshold(self, value): | |
""" | ||
self._threshold = value | ||
|
||
def getThreshold(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I realized that this should probably be a property called "threshold" instead of a method like this, in order to fit with existing Python conventions. |
||
""" | ||
.. note:: Experimental | ||
|
||
Returns the threshold (if any) used for converting raw prediction scores | ||
into 0/1 predictions. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add note: For binary classification only. |
||
""" | ||
return self._threshold | ||
|
||
def clearThreshold(self): | ||
""" | ||
.. note:: Experimental | ||
|
@@ -66,7 +76,7 @@ def predict(self, test): | |
raise NotImplementedError | ||
|
||
|
||
class LogisticRegressionModel(LinearBinaryClassificationModel): | ||
class LogisticRegressionModel(LinearClassificationModel): | ||
|
||
"""A linear binary classification model derived from logistic regression. | ||
|
||
|
@@ -112,11 +122,47 @@ class LogisticRegressionModel(LinearBinaryClassificationModel): | |
... os.removedirs(path) | ||
... except: | ||
... pass | ||
>>> multi_class_data = [ | ||
... LabeledPoint(0.0, [0.0, 1.0, 0.0]), | ||
... LabeledPoint(1.0, [1.0, 0.0, 0.0]), | ||
... LabeledPoint(2.0, [0.0, 0.0, 1.0]) | ||
... ] | ||
>>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3) | ||
>>> mcm.predict([0.0, 0.5, 0.0]) | ||
0 | ||
>>> mcm.predict([0.8, 0.0, 0.0]) | ||
1 | ||
>>> mcm.predict([0.0, 0.0, 0.3]) | ||
2 | ||
""" | ||
def __init__(self, weights, intercept): | ||
def __init__(self, weights, intercept, numFeatures, numClasses): | ||
super(LogisticRegressionModel, self).__init__(weights, intercept) | ||
self._numFeatures = int(numFeatures) | ||
self._numClasses = int(numClasses) | ||
self._threshold = 0.5 | ||
|
||
@property | ||
def numFeatures(self): | ||
return self._numFeatures | ||
|
||
@property | ||
def numClasses(self): | ||
return self._numClasses | ||
|
||
@property | ||
def dataWithBiasSize(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't think we want to expose dataWithBiasSize or weightsMatrix. Can you please define them in init and add underscores in front of their names to make it clear they are private? |
||
return self.weights.size / (self.numClasses - 1) | ||
|
||
@property | ||
def weightsMatrix(self): | ||
if self.numClasses == 2: | ||
return None | ||
else: | ||
if not isinstance(self.weights, DenseVector): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use _convert_to_vector as in LinearModel. That will check types and support Python types beyond DenseVector. |
||
raise ValueError("weights only supports dense vector but got type " | ||
+ type(self.weights)) | ||
return self.weights.toArray().reshape(self.numClasses - 1, self.dataWithBiasSize) | ||
|
||
def predict(self, x): | ||
""" | ||
Predict values for a single data point or an RDD of points using | ||
|
@@ -126,20 +172,33 @@ def predict(self, x): | |
return x.map(lambda v: self.predict(v)) | ||
|
||
x = _convert_to_vector(x) | ||
margin = self.weights.dot(x) + self._intercept | ||
if margin > 0: | ||
prob = 1 / (1 + exp(-margin)) | ||
if self.numClasses == 2: | ||
margin = self.weights.dot(x) + self._intercept | ||
if margin > 0: | ||
prob = 1 / (1 + exp(-margin)) | ||
else: | ||
exp_margin = exp(margin) | ||
prob = exp_margin / (1 + exp_margin) | ||
if self._threshold is None: | ||
return prob | ||
else: | ||
return 1 if prob > self._threshold else 0 | ||
else: | ||
exp_margin = exp(margin) | ||
prob = exp_margin / (1 + exp_margin) | ||
if self._threshold is None: | ||
return prob | ||
else: | ||
return 1 if prob > self._threshold else 0 | ||
best_class = 0 | ||
max_margin = 0.0 | ||
for i in range(0, self.numClasses - 1): | ||
if x.size + 1 == self.dataWithBiasSize: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This test can be moved outside of the loop. |
||
margin = x.dot(self.weightsMatrix[i][0:x.size]) + self.weightsMatrix[i][x.size] | ||
else: | ||
margin = x.dot(self.weightsMatrix[i]) | ||
if margin > max_margin: | ||
max_margin = margin | ||
best_class = i + 1 | ||
return best_class | ||
|
||
def save(self, sc, path): | ||
java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be updated to take numClasses, numFeatures |
||
_py2java(sc, self._coeff), self.intercept) | ||
_py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses) | ||
java_model.save(sc._jsc.sc(), path) | ||
|
||
@classmethod | ||
|
@@ -148,8 +207,10 @@ def load(cls, sc, path): | |
sc._jsc.sc(), path) | ||
weights = _java2py(sc, java_model.weights()) | ||
intercept = java_model.intercept() | ||
numFeatures = java_model.numFeatures() | ||
numClasses = java_model.numClasses() | ||
threshold = java_model.getThreshold().get() | ||
model = LogisticRegressionModel(weights, intercept) | ||
model = LogisticRegressionModel(weights, intercept, numFeatures, numClasses) | ||
model.setThreshold(threshold) | ||
return model | ||
|
||
|
@@ -158,7 +219,8 @@ class LogisticRegressionWithSGD(object): | |
|
||
@classmethod | ||
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, | ||
initialWeights=None, regParam=0.01, regType="l2", intercept=False): | ||
initialWeights=None, regParam=0.01, regType="l2", intercept=False, | ||
validateData=True): | ||
""" | ||
Train a logistic regression model on the given data. | ||
|
||
|
@@ -184,11 +246,14 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0, | |
or not of the augmented representation for | ||
training data (i.e. whether bias features | ||
are activated or not). | ||
:param validateData: Boolean parameter which indicates if the | ||
algorithm should validate data before training. | ||
(default: True) | ||
""" | ||
def train(rdd, i): | ||
return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations), | ||
float(step), float(miniBatchFraction), i, float(regParam), regType, | ||
bool(intercept)) | ||
bool(intercept), bool(validateData)) | ||
|
||
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) | ||
|
||
|
@@ -197,7 +262,7 @@ class LogisticRegressionWithLBFGS(object): | |
|
||
@classmethod | ||
def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2", | ||
intercept=False, corrections=10, tolerance=1e-4): | ||
intercept=False, corrections=10, tolerance=1e-4, validateData=True, numClasses=2): | ||
""" | ||
Train a logistic regression model on the given data. | ||
|
||
|
@@ -223,6 +288,12 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType | |
update (default: 10). | ||
:param tolerance: The convergence tolerance of iterations for | ||
L-BFGS (default: 1e-4). | ||
:param validateData: Boolean parameter which indicates if the | ||
algorithm should validate data before training. | ||
(default: True) | ||
:param numClasses: The number of possible outcomes for k classes | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better phrasing: "The number of possible outcomes for a k-class classification problem" |
||
classification problem in Multinomial Logistic | ||
Regression (default: 2). | ||
|
||
>>> data = [ | ||
... LabeledPoint(0.0, [0.0, 1.0]), | ||
|
@@ -237,12 +308,20 @@ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType | |
def train(rdd, i): | ||
return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i, | ||
float(regParam), regType, bool(intercept), int(corrections), | ||
float(tolerance)) | ||
|
||
float(tolerance), bool(validateData), int(numClasses)) | ||
|
||
if initialWeights is None: | ||
if numClasses == 2: | ||
initialWeights = [0.0] * len(data.first().features) | ||
else: | ||
if intercept: | ||
initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1) | ||
else: | ||
initialWeights = [0.0] * len(data.first().features) * (numClasses - 1) | ||
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights) | ||
|
||
|
||
class SVMModel(LinearBinaryClassificationModel): | ||
class SVMModel(LinearClassificationModel): | ||
|
||
"""A support vector machine. | ||
|
||
|
@@ -325,7 +404,8 @@ class SVMWithSGD(object): | |
|
||
@classmethod | ||
def train(cls, data, iterations=100, step=1.0, regParam=0.01, | ||
miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False): | ||
miniBatchFraction=1.0, initialWeights=None, regType="l2", | ||
intercept=False, validateData=True): | ||
""" | ||
Train a support vector machine on the given data. | ||
|
||
|
@@ -351,11 +431,14 @@ def train(cls, data, iterations=100, step=1.0, regParam=0.01, | |
or not of the augmented representation for | ||
training data (i.e. whether bias features | ||
are activated or not). | ||
:param validateData: Boolean parameter which indicates if the | ||
algorithm should validate data before training. | ||
(default: True) | ||
""" | ||
def train(rdd, i): | ||
return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step), | ||
float(regParam), float(miniBatchFraction), i, regType, | ||
bool(intercept)) | ||
bool(intercept), bool(validateData)) | ||
|
||
return _regression_train_wrapper(train, SVMModel, data, initialWeights) | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Better phrasing: "A private abstract class represents a classification model that predicts to which of a set of categories an example belongs."
-->
"A private abstract class representing a multiclass classification model."