Skip to content

Commit

Permalink
Bad attempt at representing output of name identification with NameList
Browse files Browse the repository at this point in the history
  • Loading branch information
MWYang committed Oct 7, 2019
1 parent fb5da80 commit b42469a
Show file tree
Hide file tree
Showing 7 changed files with 35 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,10 +45,10 @@ class HumanNameIdentifier
(
uid: String = UID[HumanNameIdentifier],
operationName: String = "human name identifier"
) extends UnaryEstimator[Text, Text](
) extends UnaryEstimator[Text, NameList](
uid = uid,
operationName = operationName
) with UniqueCountFun {
) {
// Parameter
val defaultThreshold = new DoubleParam(
parent = this,
Expand Down Expand Up @@ -128,10 +128,9 @@ class HumanNameIdentifier
}
}


/**
 * Unary model registered under the operation name "human name identifier".
 *
 * NOTE(review): this listing is a diff view, so it appears to show two variants of the
 * `extends` clause and of `transformFn` back to back: first a `Text => Text` variant, then a
 * `Text => NameList` variant. The commit title suggests the `NameList` variant is the one
 * being introduced; confirm against the actual file.
 *
 * @param uid         unique identifier passed through to `UnaryModel`
 * @param treatAsName flag that selects which branch of `transformFn` is taken
 */
class HumanNameIdentifierModel(override val uid: String, val treatAsName: Boolean)
extends UnaryModel[Text, Text]("human name identifier", uid) {
// For now, will just return a copy of the text, typed correctly
// Eventually we will want to map into other fields like Race or Gender
// Text => Text variant: re-types the input's value as a Name when treatAsName is true,
// otherwise hands the input back unchanged.
def transformFn: Text => Text = input =>
if (treatAsName) Name(input.value) else input
extends UnaryModel[Text, NameList]("human name identifier", uid) {
// Text => NameList variant: a one-element NameList holding Name(input.value) when
// treatAsName is true, otherwise a NameList built from an empty Seq.
// NOTE(review): no emptiness check on `input` is visible here, so with treatAsName = true an
// input without a value would still produce a one-element list. Confirm this is intended.
def transformFn: Text => NameList = input =>
if (treatAsName) NameList(Seq(Name(input.value))) else NameList(Seq.empty)
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,18 @@

package com.salesforce.op.stages.impl.feature

import com.salesforce.op.features.Feature
import com.salesforce.op.features.types._
import com.salesforce.op.stages.base.unary.{UnaryEstimator, UnaryModel}
import com.salesforce.op.test.{OpEstimatorSpec, TestFeatureBuilder}
import org.apache.spark.sql.DataFrame
import org.junit.runner.RunWith
import org.scalatest.junit.JUnitRunner


@RunWith(classOf[JUnitRunner])
class HumanNameIdentifierTest extends OpEstimatorSpec[Text, UnaryModel[Text, Text], UnaryEstimator[Text, Text]] {
class HumanNameIdentifierTest
extends OpEstimatorSpec[NameList, UnaryModel[Text, NameList], UnaryEstimator[Text, NameList]] {
/**
* Input Dataset to fit & transform
*/
Expand All @@ -52,9 +55,9 @@ class HumanNameIdentifierTest extends OpEstimatorSpec[Text, UnaryModel[Text, Tex
/**
* Expected result of the transformer applied on the Input Dataset
*/
val expectedResult: Seq[Text] = Seq("Blah").map(_.toText)
val expectedResult: Seq[NameList] = Seq(Seq(Name("Blah"))).map(_.toNameList)

private def identifyName(data: Seq[Text]) = {
private def identifyName(data: Seq[Text]): (DataFrame, Feature[Text], UnaryModel[Text, NameList], DataFrame) = {
val (newData, newFeature) = TestFeatureBuilder(data)
val model = estimator.setInput(newFeature).fit(newData)
val newResult = model.transform(newData)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,14 @@ case object FeatureSparkTypes {
val City = Text
val PostalCode = Text
val Street = Text
val Name = Text

// Vector
val OPVector = VectorType

// Lists
val TextList = ArrayType(Text, containsNull = true)
val NameList = ArrayType(Name, containsNull = true)
val DateList = ArrayType(Date, containsNull = true)
val DateTimeList = DateList
val Geolocation = ArrayType(Real, containsNull = true)
Expand Down Expand Up @@ -127,6 +129,7 @@ case object FeatureSparkTypes {

// Lists
case wt if wt =:= weakTypeOf[t.TextList] => TextList
case wt if wt =:= weakTypeOf[t.NameList] => NameList
case wt if wt =:= weakTypeOf[t.DateList] => DateList
case wt if wt =:= weakTypeOf[t.DateTimeList] => DateTimeList
case wt if wt =:= weakTypeOf[t.Geolocation] => Geolocation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ case object FeatureTypeDefaults {

// Lists
val TextList = new t.TextList(Seq.empty)
val NameList = new t.NameList(Seq.empty)
val DateList = new t.DateList(Seq.empty)
val DateTimeList = new t.DateTimeList(Seq.empty)
val Geolocation = new t.Geolocation(Seq.empty)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ object TextList {
def empty: TextList = FeatureTypeDefaults.TextList
}

/**
 * A list of Name values
 *
 * NOTE(review): the stored sequence is typed `Seq[Name]` while the parent is declared as
 * `OPList[Text]`. `OPList` is not visible here, so confirm that `Text` (rather than `Name`)
 * is the intended element type parameter.
 *
 * @param value list of Name values
 */
class NameList(val value: Seq[Name]) extends OPList[Text] {
// Varargs convenience constructor. The extra DummyImplicit parameter list gives this overload
// a signature distinct from the primary constructor, which it would otherwise clash with
// because a `Name*` parameter is itself a Seq[Name].
def this(v: Name*)(implicit d: DummyImplicit) = this(v)
}
object NameList {
// Factory that wraps the given sequence in a new NameList.
def apply(value: Seq[Name]): NameList = new NameList(value)
// Returns the instance held by FeatureTypeDefaults.NameList rather than allocating a new one.
def empty: NameList = FeatureTypeDefaults.NameList
}

/**
* A list of date values
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -309,8 +309,10 @@ object Street {
*
* @param value name
*/
class Name(value: Option[String]) extends Text(value){
class Name(value: Option[String], isFemale: Option[Boolean] = None) extends Text(value){
// NOTE(review): the two class headers above look like the before/after lines of a diff view.
// In the second header `isFemale` is a plain constructor parameter (no `val`), and nothing in
// the visible body reads it. Confirm whether it is meant to be exposed as a member.

// Convenience constructor: wraps the raw String with Option(...), so a null becomes None.
def this(value: String) = this(Option(value))
// Placeholder accessor: always None in the visible code.
def firstName: Option[String] = None
// Placeholder accessor: always None in the visible code.
def lastName: Option[String] = None
}
object Name {
def apply(value: Option[String]): Name = new Name(value)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,9 @@ package object types extends FeatureTypeSparkConverters {
def toText: Seq[Text] = v.map(_.toText)
def toEmail: Seq[Email] = v.map(_.toEmail)
}
/**
 * Enrichment for `Seq[Name]` that adds a conversion into the `NameList` feature type.
 * Declared as a value class (`extends AnyVal`).
 *
 * @param v the sequence of names to convert
 */
implicit class SeqNameConversions(val v: Seq[Name]) extends AnyVal {
/** Wraps the sequence, as is, in a new `NameList`. */
def toNameList: NameList = new NameList(v)
}
implicit class Tup3DoubleConversions(val v: (Double, Double, Double)) extends AnyVal {
def toGeolocation: Geolocation = new Geolocation(v)
}
Expand Down

0 comments on commit b42469a

Please sign in to comment.