diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 51ce19d29cd29..ea779f7d409cf 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,7 +24,8 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.util.Identifiable +import org.apache.spark.ml.util.{SchemaUtils, Identifiable} +import org.apache.spark.sql.types.{DataType, StructType} /** * :: AlphaComponent :: @@ -380,6 +381,18 @@ trait Params extends Identifiable with Serializable { this } + /** + * Check whether the given schema contains an input column. + * @param colName Input column name + * @param dataType Input column DataType + */ + protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { + val actualDataType = schema(colName).dataType + SchemaUtils.checkColumnType(schema, colName, dataType) + require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}") + } + + /** * Gets the default value of a parameter. */