[SPARK-21633][ML][Python] UnaryTransformer in Python #18746
Conversation
@jkbradley @thunterdb @MrBago Could you please review this?
Test build #79988 has finished for PR 18746 at commit
Test build #79991 has finished for PR 18746 at commit
I suggest adding Python example code for UnaryTransformer, like the Scala example MyTransformer.
python/pyspark/ml/base.py
Outdated
        return StructType(outputFields)

    def transform(self, dataset, paramMap=None):
        transformSchema(dataset.schema())
There seems to be a problem here.
transform accepts a paramMap, but createTransformFunc has no way to get the passed-in paramMap, so I think something is lost.
Because a custom UnaryTransformer only needs to override createTransformFunc, the base class needs to handle the passed-in paramMap properly.
Right, I accidentally overrode transform instead of _transform. Fixed!
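The distinction discussed above can be sketched without Spark. In this minimal, hypothetical stand-in (the classes `SketchTransformer` and `ShiftTransformer` and the `params` dictionary are illustrative assumptions, not PySpark's API), the public `transform` applies any passed-in `paramMap` overrides to a copy and then delegates to `_transform`. A subclass that overrides only `_transform` and `createTransformFunc` therefore still sees the overrides, whereas overriding `transform` directly would bypass that handling.

```python
import copy


class SketchTransformer:
    """Hypothetical stand-in for a transformer base class (not PySpark's)."""

    def __init__(self, **params):
        self.params = dict(params)

    def transform(self, dataset, paramMap=None):
        # Apply overrides to a copy so the original instance is unchanged,
        # then delegate to the hook that subclasses implement.
        if paramMap:
            clone = copy.copy(self)
            clone.params = {**self.params, **paramMap}
            return clone._transform(dataset)
        return self._transform(dataset)

    def _transform(self, dataset):
        raise NotImplementedError


class ShiftTransformer(SketchTransformer):
    """Overrides only the hooks, never the public transform()."""

    def createTransformFunc(self):
        shift = self.params["shift"]  # read at call time, so overrides apply
        return lambda x: x + shift

    def _transform(self, dataset):
        func = self.createTransformFunc()
        return [func(x) for x in dataset]


t = ShiftTransformer(shift=1.0)
default_result = t.transform([0.0, 1.0])
override_result = t.transform([0.0, 1.0], {"shift": 10.0})
```

Here the dataset is a plain list purely to keep the sketch runnable; the point is only where the parameter overrides are resolved.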
python/pyspark/ml/base.py
Outdated
    def transform(self, dataset, paramMap=None):
        transformSchema(dataset.schema())
        transformUDF = udf(self.createTransformFunc(), self.outputDataType())
        dataset.withColumn(self.getOutputCol(), transformUDF(self.getInputCol()))
udf needs its first parameter to be a function, so why do you pass in the return value of self.createTransformFunc here?
self.createTransformFunc returns a function, which is then passed to udf, so in this case I think it is okay.
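The exchange above turns on the difference between a function and the value a function returns. A minimal sketch in plain Python, where the hypothetical `apply_to_column` stands in for the role that `udf` plays (it is not Spark's `udf`, and both names are assumptions for illustration): `create_transform_func` is a factory, so calling it yields the function that is handed to the wrapper.

```python
def create_transform_func(shift):
    """Factory: returns a function of one argument (hypothetical example)."""
    def transform_func(x):
        return x + shift
    return transform_func


def apply_to_column(func, column):
    """Stand-in for a UDF wrapper: requires a callable as its first argument."""
    if not callable(func):
        raise TypeError("first argument must be a function")
    return [func(value) for value in column]


# Calling the factory produces a function, which is what the wrapper expects.
shifted = apply_to_column(create_transform_func(2.0), [1.0, 2.0, 3.0])
```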
Test build #80183 has finished for PR 18746 at commit
@ajaysaini725 Is there a JIRA for this PR? Please tag this PR in the title.
Also, you can remove "implemented" from the title, and please update the description now that you have tests.
OK done with review pass. Thanks for the PR!
python/pyspark/ml/base.py
Outdated
@inherit_doc
class UnaryTransformer(HasInputCol, HasOutputCol, Transformer):
    """
    Abstract class for transformers that tae one input column, apply a transoformation to it,
typo: tae
Actually multiple typos. Why not just copy the text from Scala?
python/pyspark/ml/base.py
Outdated
    @abstractmethod
    def createTransformFunc(self):
        """
        Creates the transoform function using the given param map.
Please use the IntelliJ spellcheck feature
python/pyspark/ml/base.py
Outdated
    def _transform(self, dataset):
        self.transformSchema(dataset.schema)
        transformFunc = self.createTransformFunc()
remove
python/pyspark/ml/tests.py
Outdated
        df = df.withColumn("input", df.input.cast(dataType="double"))

        transformed_df = transformer.transform(df)
        output = transformed_df.select("output").collect()
It's better practice to select both input & output and collect both for comparison, rather than relying on DataFrame rows maintaining their order.
python/pyspark/ml/tests.py
Outdated
@@ -1957,6 +1987,24 @@ def test_chisquaretest(self):
        self.assertTrue(all(field in fieldNames for field in expectedFields))


class UnaryTransformerTests(SparkSessionTestCase):

    def test_unary_transformer_transform(self):
Could you also please test validateInputType?
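A minimal sketch of what such a test could look like, written in plain Python so it runs without Spark. The `MockShiftTransformer` class, its use of string type names in place of Spark data types, and the choice of `TypeError` are all assumptions for illustration; the mock and error type used in the PR may differ.

```python
import unittest


class MockShiftTransformer:
    """Hypothetical mock; types are plain strings rather than Spark DataTypes."""

    def validateInputType(self, inputType):
        # Reject anything other than the one supported input type.
        if inputType != "double":
            raise TypeError("Bad input type: {}. Requires double.".format(inputType))


class ValidateInputTypeTest(unittest.TestCase):

    def test_accepts_double(self):
        # Should return without raising.
        MockShiftTransformer().validateInputType("double")

    def test_rejects_string(self):
        with self.assertRaises(TypeError):
            MockShiftTransformer().validateInputType("string")


suite = unittest.defaultTestLoader.loadTestsFromTestCase(ValidateInputTypeTest)
result = unittest.TextTestRunner(verbosity=0).run(suite)
```

Covering both the accepted and the rejected type keeps the test from passing trivially if validation is accidentally skipped.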
Test build #80225 has finished for PR 18746 at commit
@ajaysaini725 Is there a JIRA for this PR? Please tag this PR in the title.
Just 1 comment left!
python/pyspark/ml/tests.py
Outdated
        df = df.withColumn("input", df.input.cast(dataType="double"))

        transformed_df = transformer.transform(df)
        inputCol = transformed_df.select("input").collect()
Do this instead:
results = transformed_df.select("input", "output").collect()
for res in results:
    self.assertEqual(res.input + shiftVal, res.output)
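To illustrate why the row-wise comparison above is more robust, here is a small sketch in plain Python. The `Row` namedtuple and the value chosen for `shiftVal` are stand-ins (assumptions), not the PR's test fixtures; the rows are shuffled to mimic a collected result whose order is not guaranteed.

```python
import random
from collections import namedtuple

# Stand-in for a collected DataFrame row carrying both columns.
Row = namedtuple("Row", ["input", "output"])

shiftVal = 1.0
results = [Row(input=float(i), output=float(i) + shiftVal) for i in range(5)]

# Mimic an arbitrary row order.
random.Random(0).shuffle(results)

# Each row carries its own input, so the check does not depend on ordering.
all_match = all(res.input + shiftVal == res.output for res in results)
```

Comparing two separately collected columns by position would only be valid if both collections happened to come back in the same order.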
Test build #80228 has finished for PR 18746 at commit
LGTM
What changes were proposed in this pull request?
Implemented UnaryTransformer in Python.
How was this patch tested?
This patch was tested by creating a MockUnaryTransformer class in the unit tests that extends UnaryTransformer and testing that the transform function produced correct output.