Commit

address comments
yanboliang committed Mar 27, 2015
1 parent b0d9c63 commit fc7990b
Showing 1 changed file with 41 additions and 22 deletions.
63 changes: 41 additions & 22 deletions python/pyspark/mllib/classification.py
@@ -33,8 +33,9 @@

class LinearClassificationModel(LinearModel):
"""
-    Represents a linear binary classification model that predicts to whether an
-    example is positive (1.0) or negative (0.0).
    A private abstract class representing a classification model that predicts
    to which of a set of categories an example belongs. The categories are
    represented by int values: 0, 1, 2, etc.
"""
def __init__(self, weights, intercept):
super(LinearClassificationModel, self).__init__(weights, intercept)
@@ -121,6 +122,18 @@ class LogisticRegressionModel(LinearClassificationModel):
... os.removedirs(path)
... except:
... pass
>>> multi_class_data = [
... LabeledPoint(0.0, [0.0, 1.0, 0.0]),
... LabeledPoint(1.0, [1.0, 0.0, 0.0]),
... LabeledPoint(2.0, [0.0, 0.0, 1.0])
... ]
>>> mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3)
>>> mcm.predict([0.0, 0.5, 0.0])
0
>>> mcm.predict([0.8, 0.0, 0.0])
1
>>> mcm.predict([0.0, 0.0, 0.3])
2
"""
def __init__(self, weights, intercept, numFeatures, numClasses):
super(LogisticRegressionModel, self).__init__(weights, intercept)
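The doctest added above exercises the new multiclass code path end to end. As a rough standalone equivalent (a sketch only: it assumes a live SparkContext bound to sc and a build that includes this patch), the same training call looks like:

from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

multi_class_data = [
    LabeledPoint(0.0, [0.0, 1.0, 0.0]),
    LabeledPoint(1.0, [1.0, 0.0, 0.0]),
    LabeledPoint(2.0, [0.0, 0.0, 1.0]),
]
mcm = LogisticRegressionWithLBFGS.train(data=sc.parallelize(multi_class_data), numClasses=3)
print(mcm.predict([0.8, 0.0, 0.0]))  # expected: 1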
@@ -136,6 +149,20 @@ def numFeatures(self):
def numClasses(self):
return self._numClasses

    @property
    def dataWithBiasSize(self):
        return self.weights.size // (self.numClasses - 1)

    @property
    def weightsMatrix(self):
        if self.numClasses == 2:
            return None
        else:
            if not isinstance(self.weights, DenseVector):
                raise ValueError("weights only supports dense vector but got type "
                                 + str(type(self.weights)))
            return self.weights.toArray().reshape(self.numClasses - 1, self.dataWithBiasSize)

def predict(self, x):
"""
Predict values for a single data point or an RDD of points using
@@ -157,30 +184,21 @@ def predict(self, x):
else:
return 1 if prob > self._threshold else 0
else:
-            data_with_bias_size = self.weights.size / (self.numClasses - 1)
-            if not isinstance(self.weights, DenseVector):
-                raise ValueError("weights only supports dense vector but got type "
-                                 + type(self.weights))
-            weights_matrix = self.weights.toArray().reshape(self.numClasses - 1,
-                                                            data_with_bias_size)
-            margins = []
-            for i in range(0, self.numClasses - 1):
-                if x.size + 1 == data_with_bias_size:
-                    margin = x.dot(weights_matrix[i][0:x.size]) + weights_matrix[i][x.size]
-                else:
-                    margin = x.dot(weights_matrix[i])
-                margins.append(margin)
            best_class = 0
            max_margin = 0.0
-            for i in range(0, len(margins)):
-                if(margins[i] > max_margin):
-                    max_margin = margins[i]
            for i in range(0, self.numClasses - 1):
                if x.size + 1 == self.dataWithBiasSize:
                    margin = x.dot(self.weightsMatrix[i][0:x.size]) + self.weightsMatrix[i][x.size]
                else:
                    margin = x.dot(self.weightsMatrix[i])
                if margin > max_margin:
                    max_margin = margin
                    best_class = i + 1
            return best_class

def save(self, sc, path):
java_model = sc._jvm.org.apache.spark.mllib.classification.LogisticRegressionModel(
-            _py2java(sc, self._coeff), self.intercept)
            _py2java(sc, self._coeff), self.intercept, self.numFeatures, self.numClasses)
java_model.save(sc._jsc.sc(), path)

@classmethod
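To make the multinomial branch of predict() in the hunk above easier to follow: the flat weights vector is viewed as a (numClasses - 1) x dataWithBiasSize matrix, one margin is computed per non-pivot class (optionally adding a bias term), and class 0 wins unless some margin is strictly positive. Below is a minimal NumPy sketch of that logic, not part of the commit; predict_multiclass, w, and x are illustrative names.

import numpy as np

def predict_multiclass(weights, num_classes, x):
    # View the flat weight vector as (num_classes - 1) rows, as weightsMatrix
    # does; class 0 is the pivot class with an implicit margin of 0.0.
    data_with_bias_size = weights.size // (num_classes - 1)
    weights_matrix = weights.reshape(num_classes - 1, data_with_bias_size)
    best_class, max_margin = 0, 0.0
    for i in range(num_classes - 1):
        if x.size + 1 == data_with_bias_size:
            # Trained with an intercept: the last entry of each row is the bias.
            margin = x.dot(weights_matrix[i][:x.size]) + weights_matrix[i][x.size]
        else:
            margin = x.dot(weights_matrix[i])
        if margin > max_margin:
            max_margin, best_class = margin, i + 1
    return best_class

# Three classes, three features, no intercept: (3 - 1) * 3 = 6 weights.
w = np.array([1.0, 0.0, 0.0,   # row for class 1
              0.0, 0.0, 1.0])  # row for class 2
print(predict_multiclass(w, 3, np.array([0.0, 0.5, 0.0])))  # 0
print(predict_multiclass(w, 3, np.array([0.8, 0.0, 0.0])))  # 1
print(predict_multiclass(w, 3, np.array([0.0, 0.0, 0.3])))  # 2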
@@ -295,10 +313,11 @@ def train(rdd, i):
if initialWeights is None:
if numClasses == 2:
initialWeights = [0.0] * len(data.first().features)
-            elif intercept:
-                initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1)
            else:
-                initialWeights = [0.0] * len(data.first().features) * (numClasses - 1)
                if intercept:
                    initialWeights = [0.0] * (len(data.first().features) + 1) * (numClasses - 1)
                else:
                    initialWeights = [0.0] * len(data.first().features) * (numClasses - 1)
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
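As a quick check on the initial-weights sizing above: a binary model starts from one zero per feature, while a K-class model allocates (K - 1) stacked rows, each one slot wider when an intercept is fit. A minimal sketch of that rule follows; initial_weights_length is a hypothetical helper, not part of the commit.

def initial_weights_length(num_features, num_classes, intercept):
    # Mirrors the branching in LogisticRegressionWithLBFGS.train above.
    if num_classes == 2:
        return num_features
    per_class = num_features + 1 if intercept else num_features
    return per_class * (num_classes - 1)

print(initial_weights_length(4, 2, intercept=True))   # 4
print(initial_weights_length(4, 3, intercept=False))  # 8  = 4 * (3 - 1)
print(initial_weights_length(4, 3, intercept=True))   # 10 = (4 + 1) * (3 - 1)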


