Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ org.apache.spark.ml.classification.RandomForestClassifier
org.apache.spark.ml.classification.GBTClassifier

# regression
org.apache.spark.ml.regression.AFTSurvivalRegression
org.apache.spark.ml.regression.IsotonicRegression
org.apache.spark.ml.regression.LinearRegression
org.apache.spark.ml.regression.GeneralizedLinearRegression
org.apache.spark.ml.regression.DecisionTreeRegressor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ org.apache.spark.ml.classification.RandomForestClassificationModel
org.apache.spark.ml.classification.GBTClassificationModel

# regression
org.apache.spark.ml.regression.AFTSurvivalRegressionModel
org.apache.spark.ml.regression.IsotonicRegressionModel
org.apache.spark.ml.regression.LinearRegressionModel
org.apache.spark.ml.regression.GeneralizedLinearRegressionModel
org.apache.spark.ml.regression.DecisionTreeRegressionModel
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,9 @@ class AFTSurvivalRegressionModel private[ml] (
extends RegressionModel[Vector, AFTSurvivalRegressionModel] with AFTSurvivalRegressionParams
with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("aftSurvReg"),
Vectors.empty, Double.NaN, Double.NaN)

@Since("3.0.0")
override def numFeatures: Int = coefficients.size

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,8 @@ class IsotonicRegressionModel private[ml] (
private val oldModel: MLlibIsotonicRegressionModel)
extends Model[IsotonicRegressionModel] with IsotonicRegressionBase with MLWritable {

private[ml] def this() = this(Identifiable.randomUID("isoReg"), null)

/** @group setParam */
@Since("1.5.0")
def setFeaturesCol(value: String): this.type = set(featuresCol, value)
Expand Down
102 changes: 102 additions & 0 deletions python/pyspark/ml/tests/test_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@
from pyspark.ml.linalg import Vectors
from pyspark.sql import SparkSession
from pyspark.ml.regression import (
AFTSurvivalRegression,
AFTSurvivalRegressionModel,
IsotonicRegression,
IsotonicRegressionModel,
LinearRegression,
LinearRegressionModel,
GeneralizedLinearRegression,
Expand Down Expand Up @@ -57,6 +61,104 @@ def df(self):
.sortWithinPartitions("weight")
)

def test_aft_survival(self):
spark = self.spark
df = spark.createDataFrame(
[(1.0, Vectors.dense(1.0), 1.0), (1e-40, Vectors.sparse(1, [], []), 0.0)],
["label", "features", "censor"],
)

aft = AFTSurvivalRegression()
aft.setMaxIter(1)
self.assertEqual(aft.getMaxIter(), 1)

model = aft.fit(df)
self.assertEqual(aft.uid, model.uid)
self.assertEqual(model.numFeatures, 1)
self.assertTrue(np.allclose(model.intercept, 0.0, atol=1e-4), model.intercept)
self.assertTrue(
np.allclose(model.coefficients.toArray(), [0.0], atol=1e-4), model.coefficients
)
self.assertTrue(np.allclose(model.scale, 1.0, atol=1e-4), model.scale)

vec = Vectors.dense(6.3)
pred = model.predict(vec)
self.assertEqual(pred, 1.0)
pred = model.predictQuantiles(vec)
self.assertTrue(
np.allclose(
pred,
[
0.010050335853501444,
0.051293294387550536,
0.1053605156578263,
0.2876820724517809,
0.6931471805599453,
1.3862943611198906,
2.302585092994046,
2.9957322735539895,
4.60517018598809,
],
atol=1e-4,
),
pred,
)

output = model.transform(df)
expected_cols = ["label", "features", "censor", "prediction"]
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 2)

# Model save & load
with tempfile.TemporaryDirectory(prefix="aft_survival") as d:
aft.write().overwrite().save(d)
aft2 = AFTSurvivalRegression.load(d)
self.assertEqual(str(aft), str(aft2))

model.write().overwrite().save(d)
model2 = AFTSurvivalRegressionModel.load(d)
self.assertEqual(str(model), str(model2))

def test_isotonic_regression(self):
spark = self.spark
df = spark.createDataFrame(
[(1.0, Vectors.dense(1.0)), (0.0, Vectors.sparse(1, [], []))], ["label", "features"]
)

ir = IsotonicRegression(
isotonic=True,
featureIndex=0,
)
self.assertTrue(ir.getIsotonic())
self.assertEqual(ir.getFeatureIndex(), 0)

model = ir.fit(df)
self.assertEqual(model.numFeatures, 1)
self.assertTrue(
np.allclose(model.boundaries.toArray(), [0.0, 1.0], atol=1e-4), model.boundaries
)
self.assertTrue(
np.allclose(model.predictions.toArray(), [0.0, 1.0], atol=1e-4), model.predictions
)

pred = model.predict(1.0)
self.assertTrue(np.allclose(pred, 1.0, atol=1e-4), pred)

output = model.transform(df)
expected_cols = ["label", "features", "prediction"]
self.assertEqual(output.columns, expected_cols)
self.assertEqual(output.count(), 2)

# Model save & load
with tempfile.TemporaryDirectory(prefix="isotonic_regression") as d:
ir.write().overwrite().save(d)
ir2 = IsotonicRegression.load(d)
self.assertEqual(str(ir), str(ir2))

model.write().overwrite().save(d)
model2 = IsotonicRegressionModel.load(d)
self.assertEqual(str(model), str(model2))

def test_linear_regression(self):
df = self.df
lr = LinearRegression(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -532,6 +532,12 @@ private[ml] object MLUtils {
(classOf[BinaryLogisticRegressionSummary], Set("scoreCol")),

// Regression Models
(
classOf[AFTSurvivalRegressionModel],
Set("intercept", "coefficients", "scale", "predictQuantiles")),
(
classOf[IsotonicRegressionModel],
Set("boundaries", "predictions", "numFeatures", "predict")),
(
classOf[GeneralizedLinearRegressionModel],
Set("intercept", "coefficients", "numFeatures", "evaluate")),
Expand Down
Loading