Add support for external SWA library (feathr-ai#1093)
* working test

* Minor comment

* bump version

* documentation update

* update version

---------

Co-authored-by: rkashyap <rkashyap@linkedin.com>
Co-authored-by: Rakesh Kashyap Hanasoge Padmanabha <rkashyap@rkashyap-mn3.linkedin.biz>
3 people authored and Yuqing-cat committed May 23, 2023
1 parent bbdeebe commit e493503
Showing 6 changed files with 58 additions and 19 deletions.
FeathrClient.scala
@@ -11,6 +11,7 @@ import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlanner
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.swa.SWAHandler
import com.linkedin.feathr.offline.util._
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.internal.SQLConf
@@ -29,7 +30,8 @@ import scala.util.{Failure, Success}
*
*/
class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups: FeatureGroups, logicalPlanner: MultiStageJoinPlanner,
featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext]) {
featureGroupsUpdater: FeatureGroupsUpdater, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext],
swaHandler: Option[SWAHandler]) {
private val log = LogManager.getLogger(getClass)

type KeyTagStringTuple = Seq[String]
@@ -306,7 +308,7 @@ class FeathrClient private[offline] (sparkSession: SparkSession, featureGroups:
s"regular expression: ${AnchorUtils.featureNamePattern}, but found feature names: $invalidFeatureNames")
}

val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext)
val joiner = new DataFrameFeatureJoiner(logicalPlan=logicalPlan,dataPathHandlers=dataPathHandlers, mvelContext, swaHandler)
// Check conflicts between feature names and data set column names
val conflictFeatureNames: Seq[String] = findConflictFeatureNames(keyTaggedFeatures, left.schema.fieldNames)
val joinConfigSettings = joinConfig.settings
@@ -418,6 +420,7 @@ object FeathrClient {
private var featureDefConfs: List[FeathrConfig] = List()
private var dataPathHandlers: List[DataPathHandler] = List()
private var mvelContext: Option[FeathrExpressionExecutionContext] = None;
private var swaHandler: Option[SWAHandler] = None;


/**
@@ -581,6 +584,16 @@
this
}

/**
* Adds an optional SWA handler so that sliding window aggregation joins can be delegated to an external SWA library.
* @param _swaHandler the handler wrapping the external SWA join implementation, or None to keep the built-in join
* @return this builder
*/
def addSWAHandler(_swaHandler: Option[SWAHandler]): Builder = {
this.swaHandler = _swaHandler
this
}

/**
* Build a new instance of the FeathrClient from the added feathr definition configs and any local overrides.
*
@@ -614,7 +627,8 @@ object FeathrClient {
featureDefConfigs = featureDefConfigs ++ featureDefConfs

val featureGroups = FeatureGroupsGenerator(featureDefConfigs, Some(localDefConfigs)).getFeatureGroups()
val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(), dataPathHandlers, mvelContext)
val feathrClient = new FeathrClient(sparkSession, featureGroups, MultiStageJoinPlanner(), FeatureGroupsUpdater(),
dataPathHandlers, mvelContext, swaHandler)

feathrClient
}
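
Putting the new builder hook together: a minimal usage sketch, not part of this commit. It assumes the calling code lives under the com.linkedin.feathr.offline package (SWAHandler is private[offline]), that import paths follow the class names shown in the diff, and that featureDefAsString holds a feature-definition HOCON string; the handler here simply delegates back to the built-in SlidingWindowJoin to illustrate the required shape.

```scala
import com.linkedin.feathr.offline.client.FeathrClient
import com.linkedin.feathr.offline.swa.SWAHandler
import com.linkedin.feathr.swj.SlidingWindowJoin
import org.apache.spark.sql.SparkSession

def buildClientWithSwaHandler(ss: SparkSession, featureDefAsString: String): FeathrClient =
  FeathrClient
    .builder(ss)
    .addFeatureDef(featureDefAsString)
    // Register an external SWA join; leaving this out (None) keeps the built-in path.
    .addSWAHandler(Some(SWAHandler((label, facts) => SlidingWindowJoin.join(label, facts))))
    .build()
```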
LocalFeatureJoinJob.scala
@@ -6,6 +6,7 @@ import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.source.dataloader.{DataLoaderFactory, DataLoaderHandler}
import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType}
import com.linkedin.feathr.offline.swa.SWAHandler
import com.linkedin.feathr.offline.util.FeathrTestUtils.createSparkSession
import com.linkedin.feathr.offline.util.{FeaturizedDatasetMetadata, SparkFeaturizedDataset}
import org.apache.spark.sql.SparkSession
@@ -36,10 +37,11 @@ object LocalFeatureJoinJob {
extraParams: Array[String] = Array(),
ss: SparkSession = ss,
dataPathHandlers: List[DataPathHandler],
mvelContext: Option[FeathrExpressionExecutionContext]): SparkFeaturizedDataset = {
mvelContext: Option[FeathrExpressionExecutionContext],
swaHandler: Option[SWAHandler]): SparkFeaturizedDataset = {
val joinConfig = FeatureJoinConfig.parseJoinConfig(joinConfigAsHoconString)
val feathrClient = FeathrClient.builder(ss).addFeatureDef(featureDefAsString).addDataPathHandlers(dataPathHandlers)
.addFeathrExpressionContext(mvelContext).build()
.addFeathrExpressionContext(mvelContext).addSWAHandler(swaHandler).build()
val outputPath: String = FeatureJoinJob.SKIP_OUTPUT

val defaultParams = Array(
@@ -67,10 +69,11 @@
extraParams: Array[String] = Array(),
ss: SparkSession = ss,
dataPathHandlers: List[DataPathHandler],
mvelContext: Option[FeathrExpressionExecutionContext]=None): SparkFeaturizedDataset = {
mvelContext: Option[FeathrExpressionExecutionContext]=None,
swaHandler: Option[SWAHandler] = None): SparkFeaturizedDataset = {
val dataLoaderHandlers: List[DataLoaderHandler] = dataPathHandlers.map(_.dataLoaderHandler)
val obsDf = loadObservationAsFDS(ss, observationDataPath,dataLoaderHandlers=dataLoaderHandlers)
joinWithObsDFAndHoconJoinConfig(joinConfigAsHoconString, featureDefAsString, obsDf, extraParams, ss, dataPathHandlers=dataPathHandlers, mvelContext)
joinWithObsDFAndHoconJoinConfig(joinConfigAsHoconString, featureDefAsString, obsDf, extraParams, ss, dataPathHandlers=dataPathHandlers, mvelContext, swaHandler)
}

/**
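
For tests, the threaded-through parameter can be exercised via the updated helper. A hedged sketch: the names of the first three parameters are inferred from the method body shown above, LocalFeatureJoinJob's package is not shown in this diff, and the fixture values are hypothetical.

```scala
import com.linkedin.feathr.offline.swa.SWAHandler
import com.linkedin.feathr.offline.util.SparkFeaturizedDataset

def runLocalJoin(joinConfigStr: String,
                 featureDefStr: String,
                 obsPath: String,
                 externalHandler: SWAHandler): SparkFeaturizedDataset =
  LocalFeatureJoinJob.joinWithHoconJoinConfig(
    joinConfigAsHoconString = joinConfigStr, // join config HOCON fixture
    featureDefAsString = featureDefStr,      // feature definition HOCON fixture
    observationDataPath = obsPath,           // path to observation data
    dataPathHandlers = List.empty,
    swaHandler = Some(externalHandler))      // new parameter; defaults to None
```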
DataFrameFeatureJoiner.scala
@@ -15,7 +15,7 @@ import com.linkedin.feathr.offline.join.workflow._
import com.linkedin.feathr.offline.logical.{FeatureGroups, MultiStageJoinPlan, MultiStageJoinPlanner}
import com.linkedin.feathr.offline.mvel.plugins.FeathrExpressionExecutionContext
import com.linkedin.feathr.offline.source.accessor.DataPathHandler
import com.linkedin.feathr.offline.swa.SlidingWindowAggregationJoiner
import com.linkedin.feathr.offline.swa.{SWAHandler, SlidingWindowAggregationJoiner}
import com.linkedin.feathr.offline.transformation.AnchorToDataSourceMapper
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
import com.linkedin.feathr.offline.transformation.FeatureColumnFormat.FeatureColumnFormat
@@ -32,7 +32,8 @@ import scala.collection.JavaConverters._
* Joiner to join observation with feature data using Spark DataFrame API
* @param logicalPlan analyzed feature info
*/
private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, dataPathHandlers: List[DataPathHandler], mvelContext: Option[FeathrExpressionExecutionContext]) extends Serializable {
private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, dataPathHandlers: List[DataPathHandler],
mvelContext: Option[FeathrExpressionExecutionContext], swaHandler: Option[SWAHandler]) extends Serializable {
@transient lazy val log = LogManager.getLogger(getClass.getName)
@transient lazy val anchorToDataSourceMapper = new AnchorToDataSourceMapper(dataPathHandlers)
private val windowAggFeatureStages = logicalPlan.windowAggFeatureStages
@@ -210,7 +211,7 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
JoinExecutionContext(ss, updatedLogicalPlan, updatedFeatureGroups, bloomFilters, Some(saltedJoinFrequentItemDFs))
// 3. Join sliding window aggregation features
val FeatureDataFrame(withWindowAggFeatureDF, inferredSWAFeatureTypes) =
joinSWAFeatures(ss, obsToJoinWithFeatures, joinConfig, featureGroups, failOnMissingPartition, bloomFilters, swaObsTime)
joinSWAFeatures(ss, obsToJoinWithFeatures, joinConfig, featureGroups, failOnMissingPartition, bloomFilters, swaObsTime, swaHandler)

// 4. Join basic anchored features
val anchoredFeatureJoinStep =
@@ -320,7 +321,8 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
featureGroups: FeatureGroups,
failOnMissingPartition: Boolean,
bloomFilters: Option[Map[Seq[Int], BloomFilter]],
swaObsTime: Option[DateTimeInterval]): FeatureDataFrame = {
swaObsTime: Option[DateTimeInterval],
swaHandler: Option[SWAHandler]): FeatureDataFrame = {
if (windowAggFeatureStages.isEmpty) {
offline.FeatureDataFrame(obsToJoinWithFeatures, Map())
} else {
@@ -334,7 +336,8 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
requiredWindowAggFeatures,
bloomFilters,
swaObsTime,
failOnMissingPartition)
failOnMissingPartition,
swaHandler)
}
}
}
DataSourceAccessor.scala
@@ -62,7 +62,7 @@ private[offline] object DataSourceAccessor {
new NonTimeBasedDataSourceAccessor(ss, dataLoaderFactory, source, expectDatumType)
} else {
import scala.util.control.Breaks._

val timeInterval = dateIntervalOpt.get
var dataAccessorOpt: Option[DataSourceAccessor] = None
breakable {
@@ -149,7 +149,7 @@
*/
private[offline] case class DataAccessorHandler(
validatePath: String => Boolean,
getAccessor:
getAccessor:
(
SparkSession,
DataSource,
@@ -168,4 +168,5 @@
private[offline] case class DataPathHandler(
dataAccessorHandler: DataAccessorHandler,
dataLoaderHandler: DataLoaderHandler
)
)

SlidingWindowAggregationJoiner.scala
@@ -11,13 +11,15 @@ import com.linkedin.feathr.offline.config.FeatureJoinConfig
import com.linkedin.feathr.offline.exception.FeathrIllegalStateException
import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager
import com.linkedin.feathr.offline.join.DataFrameKeyCombiner
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
import com.linkedin.feathr.offline.transformation.AnchorToDataSourceMapper
import com.linkedin.feathr.offline.transformation.DataFrameDefaultValueSubstituter.substituteDefaults
import com.linkedin.feathr.offline.util.{DataFrameUtils, FeathrUtils}
import com.linkedin.feathr.offline.util.FeathrUtils.shouldCheckPoint
import com.linkedin.feathr.offline.util.datetime.DateTimeInterval
import com.linkedin.feathr.offline.{FeatureDataFrame, JoinStage}
import com.linkedin.feathr.swj.{LabelData, SlidingWindowJoin}
import com.linkedin.feathr.swj.{FactData, LabelData, SlidingWindowJoin}
import com.linkedin.feathr.{common, offline}
import org.apache.logging.log4j.LogManager
import org.apache.spark.sql.functions.{col, lit}
@@ -26,6 +28,20 @@ import org.apache.spark.util.sketch.BloomFilter

import scala.collection.mutable

/**
* Case class wrapping an externally supplied SWA (sliding window aggregation) join implementation.
* @param join the external SWA join method: takes the label data and the list of fact data, returns the joined DataFrame
*/
private[offline] case class SWAHandler(
/**
* The SWA join method
*/
join:
(
LabelData,
List[FactData]
) => DataFrame
)
/**
* Sliding window aggregation joiner
* @param allWindowAggFeatures all window aggregation features
@@ -64,7 +80,8 @@ private[offline] class SlidingWindowAggregationJoiner(
requiredWindowAggFeatures: Seq[common.ErasedEntityTaggedFeature],
bloomFilters: Option[Map[Seq[Int], BloomFilter]],
swaObsTimeOpt: Option[DateTimeInterval],
failOnMissingPartition: Boolean): FeatureDataFrame = {
failOnMissingPartition: Boolean,
swaHandler: Option[SWAHandler]): FeatureDataFrame = {
val joinConfigSettings = joinConfig.settings
// extract time window settings
if (joinConfigSettings.isEmpty) {
@@ -239,12 +256,13 @@
}
val origContextObsColumns = labelDataDef.dataSource.columns

contextDF = SlidingWindowJoin.join(labelDataDef, factDataDefs.toList)
contextDF = if (swaHandler.isDefined) swaHandler.get.join(labelDataDef, factDataDefs.toList) else SlidingWindowJoin.join(labelDataDef, factDataDefs.toList)

contextDF = if (shouldFilterNulls && !factDataRowsWithNulls.isEmpty) {
val nullDfWithFeatureCols = joinedFeatures.foldLeft(factDataRowsWithNulls)((s, x) => s.withColumn(x, lit(null)))
contextDF.union(nullDfWithFeatureCols)
} else contextDF

val defaults = windowAggAnchorDFThisStage.flatMap(s => s._1.featureAnchor.defaults)
val userSpecifiedTypesConfig = windowAggAnchorDFThisStage.flatMap(_._1.featureAnchor.featureTypeConfigs)

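
The dispatch added above (swaHandler.get.join when a handler is present, SlidingWindowJoin.join otherwise) means any function matching SWAHandler's join shape can take over the SWA step. A hedged sketch of wrapping a hypothetical external library: MyExternalSwaLibrary and its slidingWindowJoin are illustrative placeholders, not part of this commit, and as with the builder sketch above this assumes package access to the private[offline] SWAHandler.

```scala
import com.linkedin.feathr.offline.swa.SWAHandler
import com.linkedin.feathr.swj.{FactData, LabelData}
import org.apache.spark.sql.DataFrame

// Hypothetical external SWA library; only the (LabelData, List[FactData]) => DataFrame
// shape of its entry point matters to feathr.
object MyExternalSwaLibrary {
  def slidingWindowJoin(label: LabelData, facts: List[FactData]): DataFrame =
    ??? // an optimized sliding window join would live here
}

// Wrap the entry point; passing Some(externalHandler) to
// FeathrClient.builder(...).addSWAHandler(...) routes SWA joins through it.
val externalHandler: SWAHandler =
  SWAHandler((label, facts) => MyExternalSwaLibrary.slidingWindowJoin(label, facts))
```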
gradle.properties (1 addition & 1 deletion)
@@ -1,3 +1,3 @@
version=1.0.0
version=1.0.1-rc1
SONATYPE_AUTOMATIC_RELEASE=true
POM_ARTIFACT_ID=feathr_2.12
