[SPARK-13568] [ML] Create feature transformer to impute missing values #11601
Conversation
Test build #52734 has finished for PR 11601 at commit
val strategy: Param[String] = new Param(this, "strategy", "strategy for imputation. " +
  "If mean, then replace missing values using the mean along the axis. " +
  "If median, then replace missing values using the median along the axis. " +
  "If most, then replace missing values using the most frequent value along the axis.")
Could you add a param validation function, since there are a limited number of valid strategies? You can add an attribute like val supportedMissingValueStrategies = Set("mean", "median", "most") to the Imputer companion object, as is done here.
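A minimal plain-Scala sketch of the validation being suggested (ImputerStrategies and validateStrategy are illustrative names, not the PR's actual API; in Spark ML this would normally be wired up via a Param validator):

```scala
// Hypothetical sketch: reject unsupported strategy names up front.
object ImputerStrategies {
  val supportedMissingValueStrategies = Set("mean", "median", "most")

  def validateStrategy(strategy: String): Unit =
    require(supportedMissingValueStrategies.contains(strategy),
      s"Unsupported strategy: $strategy. " +
        s"Supported: ${supportedMissingValueStrategies.mkString(", ")}")
}
```

Keeping the set in the companion object lets both the Param validator and error messages refer to one source of truth.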
I added the validation to validateParameter (which should be moved, since it's deprecated). Thanks for the suggestion; I'll add them.
Looking at the JIRAs, it is unclear whether any concrete decisions were made regarding handling Vectors and how NaN values should be handled in colStats. Is there any update?
I prefer to keep Statistics.colStats(rdd) unchanged for now. As the unit tests in this PR suggest, we can cover Double and Vector for now.
Test build #52842 has finished for PR 11601 at commit
val colStatistics = $(strategy) match {
  case "mean" =>
    filteredDF.selectExpr(s"avg($colName)").first().getDouble(0)
  case "median" =>
I think we should favour using the new approxQuantile SQL stat function here rather than computing exactly.
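For reference, a plain-Scala exact median, which is what the approximate version replaces. The PR instead uses DataFrame.stat.approxQuantile(col, Array(0.5), relativeError), which avoids a full sort; note that for even-length input this sketch averages the two middle values, while approxQuantile returns an element within the relative-error bound:

```scala
// Exact median by full sort: O(n log n), and requires all data locally.
def exactMedian(xs: Seq[Double]): Double = {
  val sorted = xs.sorted
  val n = sorted.length
  if (n % 2 == 1) sorted(n / 2)
  else (sorted(n / 2 - 1) + sorted(n / 2)) / 2.0
}
```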
Test build #53923 has finished for PR 11601 at commit
Test build #53931 has finished for PR 11601 at commit
Test build #73268 has started for PR 11601 at commit
Looks like CI was interrupted.
/** @group getParam */
def getMissingValue: Double = $(missingValue)

/**
Fix comment indentation here.
 * All Null values in the input column are treated as missing, and so are also imputed.
 */
@Experimental
class Imputer @Since("2.1.0")(override val uid: String)
All @Since annotations -> 2.2.0
/**
 * Params for [[Imputer]] and [[ImputerModel]].
 */
private[feature] trait ImputerParams extends Params with HasInputCols with HasOutputCol {
We don't use HasOutputCol anymore, correct?
Sure, however I didn't get your first comment. Do you mean we should remove the import?
object Imputer extends DefaultParamsReadable[Imputer] {

  /** Set of strategy names that Imputer currently supports. */
  private[ml] val supportedStrategyNames = Set("mean", "median")
Could we factor out the mean and median names into private[ml] vals, to be used instead of the raw strings throughout?
Yes, that's better.
  case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
  case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)(0)
}
surrogate.asInstanceOf[Double]
Is the asInstanceOf[Double] necessary here?
no, will remove it.
test("ImputerModel read/write") {
  val spark = this.spark
  import spark.implicits._
  val surrogateDF = Seq(1.234).toDF("myInputCol")
This should be the "surrogate" col name - though I see we don't actually use it in load or transform.
this happens to be the correct column name for now.
OK - should we add a test here to check that the column names of instance and newInstance match up? (The check below is just for the actual values of the surrogate, correct?)
var outputDF = dataset
val surrogates = surrogateDF.head().getSeq[Double](0)

$(inputCols).indices.foreach { i =>
You could do $(inputCols).zip($(outputCols)).zip(surrogates).map { case ((inputCol, outputCol), icSurrogate) => ...
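A plain-collections illustration of the suggested zip pattern (the column names and surrogate values here are made up for the example; in the PR the same shape is applied to the Params' arrays):

```scala
// Pair each input column with its output column and its surrogate value,
// instead of indexing three arrays by position.
val inputCols = Array("value1", "value2")
val outputCols = Array("out1", "out2")
val surrogates = Seq(1.5, 2.5)

val plan = inputCols.zip(outputCols).zip(surrogates).map {
  case ((inputCol, outputCol), icSurrogate) => (inputCol, outputCol, icSurrogate)
}
```

The zipped form makes the per-column correspondence explicit and removes the off-by-one risk of `indices.foreach { i => ... }`.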
val localOutputCols = $(outputCols)
var outputSchema = schema

$(inputCols).indices.foreach { i =>
Can do $(inputCols).zip($(outputCols)).foreach { case (inputCol, outputCol) => ...
}
val surrogate = $(strategy) match {
  case "mean" => filtered.select(avg(inputCol)).first().getDouble(0)
  case "median" => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001)
.head
 * Model fitted by [[Imputer]].
 *
 * @param surrogateDF Value by which missing values in the input columns will be replaced. This
 *   is stored using DataFrame with input column names and the corresponding surrogates.
This is misleading - you're just storing the array of surrogates... did you mean something different? Otherwise the comment must be changed.
It sounds like you had the idea of storing the surrogates as something like:
+------+---------+
|column|surrogate|
+------+---------+
| col1| 1.2|
| col2| 3.4|
| col3| 5.4|
+------+---------+
?
I refactored it a little for better extensibility.
inputCol1  | inputCol2
-----------|-----------
surrogate1 | surrogate2
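A sketch of this wide, single-row layout using a plain Map in place of the one-row DataFrame (column names and values are illustrative):

```scala
// One row keyed by input column name; each cell holds that column's surrogate.
val surrogateRow: Map[String, Double] = Map("inputCol1" -> 1.2, "inputCol2" -> 3.4)
val inputCols = Array("inputCol1", "inputCol2")

// Surrogates recovered in inputCols order, as transform() would need them.
val surrogates = inputCols.map(surrogateRow)
```

Keying by input column name (rather than by position) is what makes the format extensible: adding a column doesn't disturb existing lookups.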
jenkins retest this please
Test build #73753 has finished for PR 11601 at commit
Thanks a lot for making a pass @MLnick. The last update mainly focused on the interface and behavior change. I'll make a pass and also address your comments.
Hi @MLnick, I changed the surrogateDF format for better extensibility in the last update and added unit tests for multi-column support. Let me know if I missed anything.
Test build #73868 has finished for PR 11601 at commit
|
Made a pass. A few minor comments.
 * The imputation strategy.
 * If "mean", then replace missing values using the mean value of the feature.
 * If "median", then replace missing values using the approximate median value of the
 * feature (relative error less than 0.001).
I think remove the part "(relative error less than 0.001)". This can be moved to the overall ScalaDoc for Imputer at L95.
/**
 * :: Experimental ::
 * Imputation estimator for completing missing values, either using the mean or the median
 * of the column in which the missing values are located. The input column should be of
As mentioned above at https://github.com/apache/spark/pull/11601/files#r104403880, you can add the note about relative error here. Something like "For computing median, approxQuantile is used with a relative error of X" (provide a ScalaDoc link to approxQuantile).
I didn't add the link, as it may break javadoc generation.
Ah right - perhaps just mention using approxQuantile?
@Since("2.2.0")
def setMissingValue(value: Double): this.type = set(missingValue, value)

import org.apache.spark.ml.feature.Imputer._
This import should probably be above with the others (or within fit).
}
val surrogate = $(strategy) match {
  case Imputer.mean => filtered.select(avg(inputCol)).as[Double].first()
  case Imputer.median => filtered.stat.approxQuantile(inputCol, Array(0.5), 0.001).head
Not really sure about the relative error here - perhaps 0.01 is sufficient?
Later perhaps we can even expose it as an expert param (but not for now)
I tried it before; 0.01 and 0.001 actually take the same time, even for a large dataset. Agree we can make it a param later.
override def transform(dataset: Dataset[_]): DataFrame = {
  transformSchema(dataset.schema, logging = true)
  var outputDF = dataset
  val surrogates = surrogateDF.select($(inputCols).head, $(inputCols).tail: _*).head().toSeq
Maybe this is slightly cleaner: surrogateDF.select($(inputCols).map(col): _*)
  .setInputCols(Array("value1", "value2"))
  .setOutputCols(Array("out1"))
  .setStrategy(strategy)
intercept[IllegalArgumentException] {
Also test for the thrown message here, using withClue.
}

object ImputerSuite{
space before {
Seq("mean", "median").foreach { strategy =>
  val imputer = new Imputer().setInputCols(Array("value")).setOutputCols(Array("out"))
    .setStrategy(strategy)
  intercept[SparkException] {
Check message here also.
)).toDF("id", "value1", "value2", "value3")
Seq("mean", "median").foreach { strategy =>
  // inputCols and outCols length different
  val imputer = new Imputer()
You can also perhaps use withClue to put a message on the subtest / exception assertion, e.g. withClue("Imputer should fail if inputCols and outputCols are different length").
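A plain-Scala analogue of ScalaTest's intercept plus a message assertion (interceptMessage is a made-up helper name; in the actual suite you'd use ScalaTest's intercept and withClue directly):

```scala
// Run a block expected to throw, and return the exception's message.
def interceptMessage(body: => Unit): String =
  try { body; sys.error("expected an exception") }
  catch { case e: IllegalArgumentException => e.getMessage }

val msg = interceptMessage {
  require(false, "inputCols and outputCols should have the same length")
}
assert(msg.contains("same length"))
```

Asserting on the message (not just the exception type) catches regressions where the right exception is thrown for the wrong reason.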
Test build #74038 has finished for PR 11601 at commit
@@ -99,7 +98,8 @@ private[feature] trait ImputerParams extends Params with HasInputCols {
  * (SPARK-15041) and possibly creates incorrect values for a categorical feature.
  *
  * Note that the mean/median value is computed after filtering out missing values.
- * All Null values in the input column are treated as missing, and so are also imputed.
+ * All Null values in the input column are treated as missing, and so are also imputed. For
+ * computing median, DataFrameStatFunctions.approxQuantile is used with a relative error of 0.001.
Ah I see it is here - nevermind
val ic = col(inputCol)
val filtered = dataset.select(ic.cast(DoubleType))
  .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if(filtered.rdd.isEmpty()) {
I think we can do filtered.take(1).size == 0, which should be more efficient.
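The intuition, illustrated with plain-Scala iterators rather than RDDs: a take(1)-style emptiness check only ever inspects one element, even over a very long (here unbounded) sequence, whereas counting would have to traverse everything.

```scala
// take(1) looks at a single element, so the check is cheap regardless of size.
val isEmpty = Iterator.from(1).take(1).isEmpty      // unbounded source, still fast
val emptyIsEmpty = Iterator.empty.take(1).isEmpty   // true: nothing to take
```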
  .filter(ic.isNotNull && ic =!= $(missingValue) && !ic.isNaN)
if(filtered.rdd.isEmpty()) {
  throw new SparkException(s"surrogate cannot be computed. " +
    s"All the values in $inputCol are Null, Nan or missingValue ($missingValue)")
($missingValue) -> ${$(missingValue)} ?
Made a few last comments. LGTM. cc @sethah @jkbradley I am going to merge this for 2.2. Let me know if you have any final comments.
By the way, out of curiosity, I tested things out on a cluster (4x workers, 192 cores & 480GB RAM total), with 100 columns of 100 million doubles each (1% missing), not cached.
Test build #74216 has finished for PR 11601 at commit
Thanks @MLnick for being the shepherd and providing consistent help on discussion and review. The performance test matches what I got from my local environment.
jenkins retest this please
Created SPARK-19969 to track doc and examples to be done for 2.2 release. I can help with this if you're tied up. |
Test build #74651 has finished for PR 11601 at commit
Merged to master. Thanks @hhbyyh and also everyone for reviews. |
What changes were proposed in this pull request?
jira: https://issues.apache.org/jira/browse/SPARK-13568
It is quite common to encounter missing values in data sets. It would be useful to implement a Transformer that can impute missing data points, similar to e.g. Imputer in scikit-learn.
Initially, options for imputation could include mean, median and most frequent, but we could add various other approaches, where possible existing DataFrame code can be used (e.g. for approximate quantiles etc).
Currently this PR supports imputation for Double and Vector (null and NaN in Vector).
How was this patch tested?
New unit tests and manual tests.