18 changes: 18 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
@@ -752,6 +752,24 @@
},
"sqlState" : "56K00"
},
"CONNECT_ML" : {
"message" : [
"Generic Spark Connect ML error."
],
"subClass" : {
"ATTRIBUTE_NOT_ALLOWED" : {
"message" : [
"<attribute> is not allowed to be accessed."
]
},
"UNSUPPORTED_EXCEPTION" : {
"message" : [
"<message>"
]
}
},
"sqlState" : "XX000"
},
"CONVERSION_INVALID_INPUT" : {
"message" : [
"The value <str> (<fmt>) cannot be converted to <targetType> because it is malformed. Correct the value as per the syntax, or change its format. Use <suggestion> to tolerate malformed input and return NULL instead."
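The two CONNECT_ML sub-conditions above are message templates whose <attribute> and <message> placeholders are filled from parameters attached to the error when it is raised. A minimal Python sketch of that substitution, for illustration only (Spark's error framework performs the real rendering from error-conditions.json):

# Illustrative only: mimics how the templates in error-conditions.json are rendered.
def render(template: str, params: dict) -> str:
    msg = template
    for key, value in params.items():
        msg = msg.replace(f"<{key}>", value)
    return msg

print(render("<attribute> is not allowed to be accessed.", {"attribute": "summary"}))
# prints: summary is not allowed to be accessed.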
2 changes: 2 additions & 0 deletions dev/sparktestsupport/modules.py
@@ -686,6 +686,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_legacy_mode_classification",
"pyspark.ml.tests.connect.test_legacy_mode_pipeline",
"pyspark.ml.tests.connect.test_legacy_mode_tuning",
"pyspark.ml.tests.test_classification",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy and it isn't available there
@@ -1106,6 +1107,7 @@ def __hash__(self):
"pyspark.ml.tests.connect.test_connect_classification",
"pyspark.ml.tests.connect.test_connect_pipeline",
"pyspark.ml.tests.connect.test_connect_tuning",
"pyspark.ml.tests.connect.test_connect_spark_ml_classification",
],
excluded_python_implementations=[
"PyPy" # Skip these tests under PyPy since they require numpy, pandas, and pyarrow and
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML estimators.
# Register an estimator here when adding support for a new one.
org.apache.spark.ml.classification.LogisticRegression
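Under Spark Connect, the Python client refers to an estimator by the JVM class name registered above, and the server only instantiates names it can discover through this file. A hedged sketch of the client-side message, assuming the MlOperator proto exposes the name/uid/type fields used elsewhere in this PR and an ESTIMATOR enum value (the uid is made up for illustration):

import pyspark.sql.connect.proto as pb2

# Illustrative: how a client-side LogisticRegression would be described on the wire.
operator = pb2.MlOperator(
    name="org.apache.spark.ml.classification.LogisticRegression",  # must be registered above
    uid="LogisticRegression_example_uid",                          # made-up uid for illustration
    type=pb2.MlOperator.ESTIMATOR,                                  # assumed enum value
)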
@@ -0,0 +1,20 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Spark Connect ML uses ServiceLoader to discover the supported Spark ML non-model transformers.
# Register a transformer here when adding support for a new one.
org.apache.spark.ml.feature.VectorAssembler
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification

import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.Vector
import org.apache.spark.ml.util.Summary
import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions.{col, lit}
@@ -28,7 +29,7 @@ import org.apache.spark.sql.types.DoubleType
/**
* Abstraction for multiclass classification results for a given model.
*/
private[classification] trait ClassificationSummary extends Serializable {
private[classification] trait ClassificationSummary extends Summary with Serializable {

/**
* Dataframe output by the model's `transform` method.
9 changes: 7 additions & 2 deletions mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -19,11 +19,11 @@ package org.apache.spark.ml.param

import java.lang.reflect.Modifier
import java.util.{List => JList}
import java.util.NoSuchElementException

import scala.annotation.varargs
import scala.collection.mutable
import scala.jdk.CollectionConverters._
import scala.reflect.ClassTag

import org.json4s._
import org.json4s.jackson.JsonMethods._
@@ -45,9 +45,14 @@ import org.apache.spark.util.ArrayImplicits._
* See [[ParamValidators]] for factory methods for common validation functions.
* @tparam T param value type
*/
class Param[T](val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
class Param[T: ClassTag](
val parent: String, val name: String, val doc: String, val isValid: T => Boolean)
extends Serializable {

// Spark Connect ML needs the type information of T, which is erased at compile time.
// Use a ClassTag to preserve the runtime type of T.
val paramValueClassTag = implicitly[ClassTag[T]]

def this(parent: Identifiable, name: String, doc: String, isValid: T => Boolean) =
this(parent.uid, name, doc, isValid)

28 changes: 28 additions & 0 deletions mllib/src/main/scala/org/apache/spark/ml/util/Summary.scala
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.ml.util

import org.apache.spark.annotation.Since

/**
* Trait for model summaries.
* All summaries should extend this trait in order to support Spark Connect.
*/
@Since("4.0.0")
private[spark] trait Summary
@@ -21,6 +21,7 @@
import java.util.List;

import org.apache.spark.ml.util.Identifiable$;
import scala.reflect.ClassTag;

/**
* A subclass of Params for testing.
@@ -110,7 +111,7 @@ private void init() {
ParamValidators.inRange(0.0, 1.0));
List<String> validStrings = Arrays.asList("a", "b");
myStringParam_ = new Param<>(this, "myStringParam", "this is a string param",
ParamValidators.inArray(validStrings));
ParamValidators.inArray(validStrings), ClassTag.apply(String.class));
myDoubleArrayParam_ =
new DoubleArrayParam(this, "myDoubleArrayParam", "this is a double param");

7 changes: 7 additions & 0 deletions python/pyspark/ml/classification.py
@@ -62,6 +62,7 @@
HasSolver,
HasParallelism,
)
from pyspark.ml.remote.util import try_remote_attribute_relation
from pyspark.ml.tree import (
_DecisionTreeModel,
_DecisionTreeParams,
@@ -336,6 +337,7 @@ class _ClassificationSummary(JavaWrapper):

@property
@since("3.1.0")
@try_remote_attribute_relation
def predictions(self) -> DataFrame:
"""
Dataframe output by the model's `transform` method.
@@ -521,6 +523,7 @@ def scoreCol(self) -> str:
return self._call_java("scoreCol")

@property
@try_remote_attribute_relation
def roc(self) -> DataFrame:
"""
Returns the receiver operating characteristic (ROC) curve,
@@ -546,6 +549,7 @@ def areaUnderROC(self) -> float:

@property
@since("3.1.0")
@try_remote_attribute_relation
def pr(self) -> DataFrame:
"""
Returns the precision-recall curve, which is a Dataframe
@@ -556,6 +560,7 @@ def pr(self) -> DataFrame:

@property
@since("3.1.0")
@try_remote_attribute_relation
def fMeasureByThreshold(self) -> DataFrame:
"""
Returns a dataframe with two fields (threshold, F-Measure) curve
@@ -565,6 +570,7 @@ def fMeasureByThreshold(self) -> DataFrame:

@property
@since("3.1.0")
@try_remote_attribute_relation
def precisionByThreshold(self) -> DataFrame:
"""
Returns a dataframe with two fields (threshold, precision) curve.
@@ -575,6 +581,7 @@ def precisionByThreshold(self) -> DataFrame:

@property
@since("3.1.0")
@try_remote_attribute_relation
def recallByThreshold(self) -> DataFrame:
"""
Returns a dataframe with two fields (threshold, recall) curve.
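All of the DataFrame-returning summary properties above are wrapped with try_remote_attribute_relation so that, under Spark Connect, they stop calling into a local JVM and instead fetch the attribute through the server-side object cache. Below is a simplified sketch of what such a decorator does: the environment-variable check mirrors pyspark's is_remote(), and _fetch_remote_attribute is a hypothetical helper standing in for the AttributeRelation plumbing in pyspark/ml/remote/proto.py.

import functools
import os

def try_remote_attribute_relation(f):
    # Simplified stand-in for pyspark.ml.remote.util.try_remote_attribute_relation.
    @functools.wraps(f)
    def wrapped(self, *args, **kwargs):
        if "SPARK_CONNECT_MODE_ENABLED" in os.environ:
            # Spark Connect path: ask the server to evaluate the attribute on the
            # cached summary and wrap the result in a client-side DataFrame.
            return self._fetch_remote_attribute(f.__name__)  # hypothetical helper
        # Classic path: call through to the JVM-backed implementation.
        return f(self, *args, **kwargs)
    return wrapped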
16 changes: 16 additions & 0 deletions python/pyspark/ml/remote/__init__.py
@@ -0,0 +1,16 @@
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
76 changes: 76 additions & 0 deletions python/pyspark/ml/remote/proto.py
@@ -0,0 +1,76 @@
Contributor:
We're in a bit of a bad position here: at a high level we should use the pyspark.ml.connect module, but that package is occupied by the previous implementation approach.
I'm wondering if we should do some cleanup later to delete the previous code and move this one to the right package.
@WeichenXu123 what do you think?

Contributor Author:
I could also move the existing "remote" package to "connect" if necessary.

Contributor:
Agreed that we'd better move it to the right package.

#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, TYPE_CHECKING, List

import pyspark.sql.connect.proto as pb2
from pyspark.sql.connect.plan import LogicalPlan

if TYPE_CHECKING:
from pyspark.sql.connect.client import SparkConnectClient


class TransformerRelation(LogicalPlan):
"""A logical plan for transforming of a transformer which could be a cached model
or a non-model transformer like VectorAssembler."""

def __init__(
self,
child: Optional["LogicalPlan"],
name: str,
ml_params: pb2.MlParams,
uid: str = "",
is_model: bool = True,
) -> None:
super().__init__(child)
self._name = name
self._ml_params = ml_params
self._uid = uid
self._is_model = is_model

def plan(self, session: "SparkConnectClient") -> pb2.Relation:
assert self._child is not None
plan = self._create_proto_relation()
plan.ml_relation.transform.input.CopyFrom(self._child.plan(session))

if self._is_model:
plan.ml_relation.transform.obj_ref.CopyFrom(pb2.ObjectRef(id=self._name))
else:
plan.ml_relation.transform.transformer.CopyFrom(
pb2.MlOperator(name=self._name, uid=self._uid, type=pb2.MlOperator.TRANSFORMER)
)

if self._ml_params is not None:
plan.ml_relation.transform.params.CopyFrom(self._ml_params)

return plan
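A usage sketch for TransformerRelation, showing how a client-side VectorAssembler.transform could build its result plan. The param wiring is illustrative rather than the PR's exact helpers, and the session is responsible for wrapping the returned plan in a connect DataFrame; the sketch assumes the imports and classes defined in this file.

import pyspark.sql.connect.proto as pb2

def vector_assembler_plan(dataset_plan: LogicalPlan, uid: str) -> TransformerRelation:
    params = pb2.MlParams()  # in practice, populated from the transformer's param map
    return TransformerRelation(
        child=dataset_plan,                                   # the input DataFrame's plan
        name="org.apache.spark.ml.feature.VectorAssembler",   # must be in the ServiceLoader registry
        ml_params=params,
        uid=uid,
        is_model=False,                                       # not a reference to a cached model
    )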


class AttributeRelation(LogicalPlan):
"""A logical plan used in ML to represent an attribute of an instance, which
could be a model or a summary. This attribute returns a DataFrame.
"""

def __init__(self, ref_id: str, methods: List[pb2.Fetch.Method]) -> None:
super().__init__(None)
self._ref_id = ref_id
self._methods = methods

def plan(self, session: "SparkConnectClient") -> pb2.Relation:
plan = self._create_proto_relation()
plan.ml_relation.fetch.obj_ref.CopyFrom(pb2.ObjectRef(id=self._ref_id))
plan.ml_relation.fetch.methods.extend(self._methods)
return plan
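And a sketch of AttributeRelation in use, expressing something like model.summary.roc. The Fetch.Method message is assumed to carry the method name in a field called "method" (matching the ml proto used above), and the object id is made up; executing a DataFrame built on this plan makes the server chain summary and then roc on the cached model and stream back the resulting rows.

import pyspark.sql.connect.proto as pb2

# Illustrative: fetch model.summary.roc from the server-side object cache.
methods = [
    pb2.Fetch.Method(method="summary"),  # field name "method" is an assumption
    pb2.Fetch.Method(method="roc"),
]
roc_plan = AttributeRelation("example-model-ref-id", methods)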