From acf3e17b46e44b20b4144b585e08de7151b517d4 Mon Sep 17 00:00:00 2001 From: Joshi Date: Wed, 1 Jul 2015 16:07:35 -0700 Subject: [PATCH] update checkInputColumn to print more info if needed --- .../scala/org/apache/spark/ml/param/params.scala | 15 +-------------- .../org/apache/spark/ml/util/SchemaUtils.scala | 5 +++-- 2 files changed, 4 insertions(+), 16 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index ea779f7d409cf..51ce19d29cd29 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -24,8 +24,7 @@ import scala.annotation.varargs import scala.collection.mutable import org.apache.spark.annotation.AlphaComponent -import org.apache.spark.ml.util.{SchemaUtils, Identifiable} -import org.apache.spark.sql.types.{DataType, StructType} +import org.apache.spark.ml.util.Identifiable /** * :: AlphaComponent :: @@ -381,18 +380,6 @@ trait Params extends Identifiable with Serializable { this } - /** - * Check whether the given schema contains an input column. - * @param colName Input column name - * @param dataType Input column DataType - */ - protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = { - val actualDataType = schema(colName).dataType - SchemaUtils.checkColumnType(schema, colName, dataType) - require(actualDataType.equals(dataType), s"Input column Name: $colName Description: ${getParam(colName)}") - } - - /** * Gets the default value of a parameter. */ diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala index 0383bf0b382b7..9252618715625 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/SchemaUtils.scala @@ -34,10 +34,11 @@ object SchemaUtils { * @param colName column name * @param dataType required column data type */ - def checkColumnType(schema: StructType, colName: String, dataType: DataType): Unit = { + def checkColumnType(schema: StructType, colName: String, dataType: DataType, + msg: String = ""): Unit = { val actualDataType = schema(colName).dataType require(actualDataType.equals(dataType), - s"Column $colName must be of type $dataType but was actually $actualDataType.") + s"Column $colName must be of type $dataType but was actually $actualDataType.$msg") } /**