From 09cbbb14eef917bb0615c39e8b60c9fa70c9e750 Mon Sep 17 00:00:00 2001
From: Lee Dongjin
Date: Mon, 11 Jun 2018 17:49:06 +0900
Subject: [PATCH] [SPARK-24513][ML] Attribute support in UnaryTransformer

---
 .../main/scala/org/apache/spark/ml/Transformer.scala | 11 +++++++++--
 1 file changed, 9 insertions(+), 2 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index a3a2b55adc25d..14390527a4018 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -116,10 +116,17 @@ abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
     StructType(outputFields)
   }
 
+  /**
+   * Returns [[Metadata]] to be attached to the output column.
+   */
+  protected def outputMetadata(outputSchema: StructType, dataset: Dataset[_]): Metadata =
+    Metadata.empty
+
   override def transform(dataset: Dataset[_]): DataFrame = {
-    transformSchema(dataset.schema, logging = true)
+    val outputSchema = transformSchema(dataset.schema, logging = true)
     val transformUDF = udf(this.createTransformFunc, outputDataType)
-    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
+    val metadata = outputMetadata(outputSchema, dataset)
+    dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))), metadata)
   }
 
   override def copy(extra: ParamMap): T = defaultCopy(extra)