diff --git a/core/src/main/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifier.scala b/core/src/main/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifier.scala index d71d162f6a..7810c9a9c2 100644 --- a/core/src/main/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifier.scala +++ b/core/src/main/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifier.scala @@ -45,10 +45,10 @@ class HumanNameIdentifier ( uid: String = UID[HumanNameIdentifier], operationName: String = "human name identifier" -) extends UnaryEstimator[Text, Text]( +) extends UnaryEstimator[Text, NameList]( uid = uid, operationName = operationName -) with UniqueCountFun { +) { // Parameter val defaultThreshold = new DoubleParam( parent = this, @@ -128,10 +128,9 @@ class HumanNameIdentifier } } + class HumanNameIdentifierModel(override val uid: String, val treatAsName: Boolean) - extends UnaryModel[Text, Text]("human name identifier", uid) { - // For now, will just return a copy of the text, typed correctly - // Eventually we will want to map into other fields like Race or Gender - def transformFn: Text => Text = input => - if (treatAsName) Name(input.value) else input + extends UnaryModel[Text, NameList]("human name identifier", uid) { + def transformFn: Text => NameList = input => + if (treatAsName) NameList(Seq(Name(input.value))) else NameList(Seq.empty) } diff --git a/core/src/test/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifierTest.scala b/core/src/test/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifierTest.scala index 0577ecbe96..7785d43bba 100644 --- a/core/src/test/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifierTest.scala +++ b/core/src/test/scala/com/salesforce/op/stages/impl/feature/HumanNameIdentifierTest.scala @@ -30,15 +30,18 @@ package com.salesforce.op.stages.impl.feature +import com.salesforce.op.features.Feature import com.salesforce.op.features.types._ import com.salesforce.op.stages.base.unary.{UnaryEstimator, UnaryModel} import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder} +import org.apache.spark.sql.DataFrame import org.junit.runner.RunWith import org.scalatest.junit.JUnitRunner @RunWith(classOf[JUnitRunner]) -class HumanNameIdentifierTest extends OpEstimatorSpec[Text, UnaryModel[Text, Text], UnaryEstimator[Text, Text]] { +class HumanNameIdentifierTest + extends OpEstimatorSpec[NameList, UnaryModel[Text, NameList], UnaryEstimator[Text, NameList]] { /** * Input Dataset to fit & transform */ @@ -52,9 +55,9 @@ class HumanNameIdentifierTest extends OpEstimatorSpec[Text, UnaryModel[Text, Tex /** * Expected result of the transformer applied on the Input Dataset */ - val expectedResult: Seq[Text] = Seq("Blah").map(_.toText) + val expectedResult: Seq[NameList] = Seq(Seq(Name("Blah"))).map(_.toNameList) - private def identifyName(data: Seq[Text]) = { + private def identifyName(data: Seq[Text]): (DataFrame, Feature[Text], UnaryModel[Text, NameList], DataFrame) = { val (newData, newFeature) = TestFeatureBuilder(data) val model = estimator.setInput(newFeature).fit(newData) val newResult = model.transform(newData) diff --git a/features/src/main/scala/com/salesforce/op/features/FeatureSparkTypes.scala b/features/src/main/scala/com/salesforce/op/features/FeatureSparkTypes.scala index 73eee1f357..3dbdc856f2 100644 --- a/features/src/main/scala/com/salesforce/op/features/FeatureSparkTypes.scala +++ b/features/src/main/scala/com/salesforce/op/features/FeatureSparkTypes.scala @@ -74,12 +74,14 @@ case object FeatureSparkTypes { val City = Text val PostalCode = Text val Street = Text + val Name = Text // Vector val OPVector = VectorType // Lists val TextList = ArrayType(Text, containsNull = true) + val NameList = ArrayType(Name, containsNull = true) val DateList = ArrayType(Date, containsNull = true) val DateTimeList = DateList val Geolocation = ArrayType(Real, containsNull = true) @@ -127,6 +129,7 @@ case object FeatureSparkTypes { // Lists case wt if wt =:= weakTypeOf[t.TextList] => TextList + case wt if wt =:= weakTypeOf[t.NameList] => NameList case wt if wt =:= weakTypeOf[t.DateList] => DateList case wt if wt =:= weakTypeOf[t.DateTimeList] => DateTimeList case wt if wt =:= weakTypeOf[t.Geolocation] => Geolocation diff --git a/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeDefaults.scala b/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeDefaults.scala index 7308538c54..8abf69802d 100644 --- a/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeDefaults.scala +++ b/features/src/main/scala/com/salesforce/op/features/types/FeatureTypeDefaults.scala @@ -71,6 +71,7 @@ case object FeatureTypeDefaults { // Lists val TextList = new t.TextList(Seq.empty) + val NameList = new t.NameList(Seq.empty) val DateList = new t.DateList(Seq.empty) val DateTimeList = new t.DateTimeList(Seq.empty) val Geolocation = new t.Geolocation(Seq.empty) diff --git a/features/src/main/scala/com/salesforce/op/features/types/Lists.scala b/features/src/main/scala/com/salesforce/op/features/types/Lists.scala index eebf1f7a47..a4096021c0 100644 --- a/features/src/main/scala/com/salesforce/op/features/types/Lists.scala +++ b/features/src/main/scala/com/salesforce/op/features/types/Lists.scala @@ -43,6 +43,19 @@ object TextList { def empty: TextList = FeatureTypeDefaults.TextList } +/** + * A list of Name values + * + * @param value list of Name values + */ +class NameList(val value: Seq[Name]) extends OPList[Text] { + def this(v: Name*)(implicit d: DummyImplicit) = this(v) +} +object NameList { + def apply(value: Seq[Name]): NameList = new NameList(value) + def empty: NameList = FeatureTypeDefaults.NameList +} + /** * A list of date values * diff --git a/features/src/main/scala/com/salesforce/op/features/types/Text.scala b/features/src/main/scala/com/salesforce/op/features/types/Text.scala index edef39795b..ea104ca36f 100644 --- a/features/src/main/scala/com/salesforce/op/features/types/Text.scala +++ b/features/src/main/scala/com/salesforce/op/features/types/Text.scala @@ -309,8 +309,10 @@ object Street { * * @param value name */ -class Name(value: Option[String]) extends Text(value){ +class Name(value: Option[String], isFemale: Option[Boolean] = None) extends Text(value){ def this(value: String) = this(Option(value)) + def firstName: Option[String] = None + def lastName: Option[String] = None } object Name { def apply(value: Option[String]): Name = new Name(value) diff --git a/features/src/main/scala/com/salesforce/op/features/types/package.scala b/features/src/main/scala/com/salesforce/op/features/types/package.scala index 18955a41e0..344639f8d2 100644 --- a/features/src/main/scala/com/salesforce/op/features/types/package.scala +++ b/features/src/main/scala/com/salesforce/op/features/types/package.scala @@ -185,6 +185,9 @@ package object types extends FeatureTypeSparkConverters { def toText: Seq[Text] = v.map(_.toText) def toEmail: Seq[Email] = v.map(_.toEmail) } + implicit class SeqNameConversions(val v: Seq[Name]) extends AnyVal { + def toNameList: NameList = new NameList(v) + } implicit class Tup3DoubleConversions(val v: (Double, Double, Double)) extends AnyVal { def toGeolocation: Geolocation = new Geolocation(v) }