Generic Input/Output of DataFrames (feathr-ai#475)
* GenericLocation for DataFrame read/write

* WIP

* Generate id column

* Fix unit test

* Parse string into DataLocation

* Id column must be string

* Fix auth logic

* Fix unit test

* Fix id column generation

* CosmosDb Sink
windoze authored and ahlag committed Aug 26, 2022
1 parent 98d622d commit 8343ab3
Showing 22 changed files with 444 additions and 106 deletions.
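Taken together, the changes below let a DataFrame be read from or written to any Spark connector by name. As a minimal sketch of the intended usage, assuming a SparkSession `spark` and a DataFrame `df` are in scope (the endpoint and option values are illustrative):

import com.linkedin.feathr.offline.config.location.DataLocation

// The "generic" type tag maps to the new GenericLocation, and "__" in
// option keys becomes "." (see GenericLocation.setOption below).
val location = DataLocation(
  """{
    |  type: "generic"
    |  format: "cosmos.oltp"
    |  mode: "append"
    |  spark__cosmos__accountEndpoint: "https://example.documents.azure.com:443/"
    |  spark__cosmos__accountKey: "${COSMOS_KEY}"
    |  spark__cosmos__database: "feathr"
    |  spark__cosmos__container: "features"
    |}""".stripMargin)
location.writeDf(spark, df, None)

The `${COSMOS_KEY}` placeholder is expanded from a system property or environment variable by `LocationUtils.envSubstitute`; a non-existent variable is replaced by an empty string.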
4 changes: 3 additions & 1 deletion build.sbt
@@ -44,7 +44,9 @@ val localAndCloudCommonDependencies = Seq(
"net.snowflake" % "spark-snowflake_2.12" % "2.10.0-spark_3.2",
"org.apache.commons" % "commons-lang3" % "3.12.0",
"org.xerial" % "sqlite-jdbc" % "3.36.0.3",
"com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1"
"com.github.changvvb" %% "jackson-module-caseclass" % "1.1.1",
"com.azure.cosmos.spark" % "azure-cosmos-spark_3-1_2-12" % "4.11.1",
"org.eclipse.jetty" % "jetty-util" % "9.3.24.v20180605"
) // Common deps

val jdbcDrivers = Seq(
@@ -14,7 +14,7 @@ import com.linkedin.feathr.offline.anchored.anchorExtractor.{SQLConfigurableAnch
import com.linkedin.feathr.offline.anchored.feature.{FeatureAnchor, FeatureAnchorWithSource}
import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor, SQLSourceKeyExtractor}
import com.linkedin.feathr.offline.client.plugins.{AnchorExtractorAdaptor, FeathrUdfPluginContext, FeatureDerivationFunctionAdaptor, SimpleAnchorExtractorSparkAdaptor, SourceKeyExtractorAdaptor}
- import com.linkedin.feathr.offline.config.location.{InputLocation, Jdbc, KafkaEndpoint, LocationUtils, SimplePath}
+ import com.linkedin.feathr.offline.config.location.{DataLocation, KafkaEndpoint, LocationUtils, SimplePath}
import com.linkedin.feathr.offline.derived._
import com.linkedin.feathr.offline.derived.functions.{MvelFeatureDerivationFunction, SQLFeatureDerivationFunction, SeqJoinDerivationFunction, SimpleMvelDerivationFunction}
import com.linkedin.feathr.offline.source.{DataSource, SourceFormatType, TimeWindowParams}
@@ -735,7 +735,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] {
* 2. a placeholder with reserved string "PASSTHROUGH" for anchor defined pass-through features,
* since anchor defined pass-through features do not have path
*/
- val path: InputLocation = dataSourceType match {
+ val path: DataLocation = dataSourceType match {
case "KAFKA" =>
Option(node.get("config")) match {
case Some(field: ObjectNode) =>
@@ -748,7 +748,7 @@ private[offline] class DataSourceLoader extends JsonDeserializer[DataSource] {
case "PASSTHROUGH" => SimplePath("PASSTHROUGH")
case _ => Option(node.get("location")) match {
case Some(field: ObjectNode) =>
- LocationUtils.getMapper().treeToValue(field, classOf[InputLocation])
+ LocationUtils.getMapper().treeToValue(field, classOf[DataLocation])
case None => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR,
s"Data location is not defined for data source ${node.toPrettyString()}")
case _ => throw new FeathrConfigException(ErrorLabel.FEATHR_USER_ERROR,
@@ -1,15 +1,19 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.annotation.{JsonSubTypes, JsonTypeInfo}
import com.fasterxml.jackson.core.JacksonException
import com.fasterxml.jackson.databind.module.SimpleModule
import com.fasterxml.jackson.databind.{DeserializationFeature, ObjectMapper}
import com.fasterxml.jackson.module.caseclass.mapper.CaseClassObjectMapper
import com.jasonclawson.jackson.dataformat.hocon.HoconFactory
- import com.linkedin.feathr.common.FeathrJacksonScalaModule
+ import com.linkedin.feathr.common.{FeathrJacksonScalaModule, Header}
import com.linkedin.feathr.offline.config.DataSourceLoader
import com.linkedin.feathr.offline.source.DataSource
import com.typesafe.config.{Config, ConfigException}
import org.apache.spark.sql.{DataFrame, SparkSession}

import scala.collection.JavaConverters._

/**
* An InputLocation is a data source definition; it can be either HDFS files or a JDBC database connection
*/
@@ -20,38 +24,50 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
new JsonSubTypes.Type(value = classOf[SimplePath], name = "path"),
new JsonSubTypes.Type(value = classOf[PathList], name = "pathlist"),
new JsonSubTypes.Type(value = classOf[Jdbc], name = "jdbc"),
new JsonSubTypes.Type(value = classOf[GenericLocation], name = "generic"),
))
- trait InputLocation {
+ trait DataLocation {
/**
* Backward Compatibility
* Much existing code expects a simple path
*
* @return the `path` or `url` of the data source
*
* WARN: This method is deprecated, you must use match/case on InputLocation,
* and get `path` from `SimplePath` only
*/
@deprecated("Do not use this method in any new code, it will be removed soon")
def getPath: String

/**
* Backward Compatibility
*
* @return the `path` or `url` of the data source, wrapped in a List
*
* WARN: This method is deprecated, you must use match/case on InputLocation,
* and get `paths` from `PathList` only
*/
@deprecated("Do not use this method in any new code, it will be removed soon")
def getPathList: List[String]

/**
* Load DataFrame from Spark session
*
* @param ss SparkSession
* @return
*/
def loadDf(ss: SparkSession, dataIOParameters: Map[String, String] = Map()): DataFrame

/**
* Write DataFrame to the location
* @param ss SparkSession
* @param df DataFrame to write
*/
def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header])

/**
* Tell if this location is file based
*
* @return boolean
*/
def isFileBasedLocation(): Boolean
@@ -67,6 +83,7 @@ object LocationUtils {
/**
* String template substitution, replace "...${VAR}.." with corresponding System property or environment variable
* Non-existent patterns are replaced by an empty string.
*
* @param s String template to be processed
* @return Processed result
*/
@@ -76,6 +93,7 @@

/**
* Get an ObjectMapper to deserialize DataSource
*
* @return the ObjectMapper
*/
def getMapper(): ObjectMapper = {
@@ -86,3 +104,50 @@
.registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader))
}
}

object DataLocation {
/**
* Create DataLocation from string, try parsing the string as JSON and fallback to SimplePath
* @param cfg the input string
* @return DataLocation
*/
def apply(cfg: String): DataLocation = {
val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper)
.registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem
.configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true)
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader))
try {
// Cfg is either a plain path or a JSON object
if (cfg.trim.startsWith("{")) {
val location = jackson.readValue(cfg, classOf[DataLocation])
location
} else {
SimplePath(cfg)
}
} catch {
case _ @ (_: ConfigException | _: JacksonException) => SimplePath(cfg)
}
}

def apply(cfg: Config): DataLocation = {
apply(cfg.root().keySet().asScala.map(key => key -> cfg.getString(key)).toMap)
}

def apply(cfg: Any): DataLocation = {
val jackson = (new ObjectMapper(new HoconFactory) with CaseClassObjectMapper)
.registerModule(FeathrJacksonScalaModule) // DefaultScalaModule causes a fail on holdem
.configure(DeserializationFeature.ACCEPT_SINGLE_VALUE_AS_ARRAY, true)
.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false)
.registerModule(new SimpleModule().addDeserializer(classOf[DataSource], new DataSourceLoader))
try {
val location = jackson.convertValue(cfg, classOf[DataLocation])
location
} catch {
case e: JacksonException => {
print(e)
SimplePath(cfg.toString)
}
}
}
}
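For reference, a sketch of how the string overload above behaves on the two input shapes it distinguishes (the `Jdbc` field names are assumptions for illustration):

// A string that does not start with "{" (or fails to parse) becomes a SimplePath:
DataLocation("abfss://container@account.dfs.core.windows.net/features")
// => SimplePath("abfss://container@account.dfs.core.windows.net/features")

// A JSON/HOCON object is deserialized into the subtype named by its `type` tag:
DataLocation("""{ type: "jdbc", url: "jdbc:mysql://example:3306/feathr" }""")
// => a Jdbc location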
@@ -0,0 +1,161 @@
package com.linkedin.feathr.offline.config.location

import com.fasterxml.jackson.annotation.JsonAnySetter
import com.fasterxml.jackson.module.caseclass.annotation.CaseClassDeserialize
import com.linkedin.feathr.common.Header
import com.linkedin.feathr.common.exception.FeathrException
import com.linkedin.feathr.offline.generation.FeatureGenUtils
import com.linkedin.feathr.offline.join.DataFrameKeyCombiner
import net.minidev.json.annotate.JsonIgnore
import org.apache.log4j.Logger
import org.apache.spark.sql.functions.monotonically_increasing_id
import org.apache.spark.sql.{DataFrame, DataFrameWriter, Row, SparkSession}

@CaseClassDeserialize()
case class GenericLocation(format: String, mode: Option[String] = None) extends DataLocation {
val log: Logger = Logger.getLogger(getClass)
val options: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]()
val conf: collection.mutable.Map[String, String] = collection.mutable.Map[String, String]()

/**
* Backward Compatibility
* Much existing code expects a simple path
*
* @return the `path` or `url` of the data source
*
* WARN: This method is deprecated, you must use match/case on DataLocation,
* and get `path` from `SimplePath` only
*/
override def getPath: String = s"GenericLocation(${format})"

/**
* Backward Compatibility
*
* @return the `path` or `url` of the data source, wrapped in a List
*
* WARN: This method is deprecated, you must use match/case on DataLocation,
* and get `paths` from `PathList` only
*/
override def getPathList: List[String] = List(getPath)

/**
* Load DataFrame from Spark session
*
* @param ss SparkSession
* @return
*/
override def loadDf(ss: SparkSession, dataIOParameters: Map[String, String]): DataFrame = {
GenericLocationFixes.readDf(ss, this)
}

/**
* Write DataFrame to the location
*
* @param ss SparkSession
* @param df DataFrame to write
*/
override def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header]): Unit = {
GenericLocationFixes.writeDf(ss, df, header, this)
}

/**
* Tell if this location is file based
*
* @return boolean
*/
override def isFileBasedLocation(): Boolean = false

@JsonAnySetter
def setOption(key: String, value: Any): Unit = {
println(s"GenericLocation.setOption(key: $key, value: $value)")
if (key == null) {
log.warn("Got null key, skipping")
return
}
if (value == null) {
log.warn(s"Got null value for key '$key', skipping")
return
}
val v = value.toString
if (v == null) {
log.warn(s"Got invalid value for key '$key', skipping")
return
}
if (key.startsWith("__conf__")) {
conf += (key.stripPrefix("__conf__").replace("__", ".") -> LocationUtils.envSubstitute(v))
} else {
options += (key.replace("__", ".") -> LocationUtils.envSubstitute(v))
}
}
}
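// A sketch of how the @JsonAnySetter above folds unrecognized config keys
// into `options` and `conf`: "__" becomes "." (dotted keys would otherwise
// be nested by HOCON), and the "__conf__" prefix routes a key into the
// Spark session conf instead of the reader/writer options. The
// ElasticSearch option names are illustrative:
//
//   val loc = GenericLocation(format = "org.elasticsearch.spark.sql")
//   loc.setOption("es__nodes", "localhost")   // options("es.nodes") = "localhost"
//   loc.setOption("es__port", "9200")         // options("es.port") = "9200"
//   loc.setOption("__conf__spark__some__flag", "true")
//                                             // conf("spark.some.flag") = "true"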

/**
* Some Spark connectors need extra actions before read or write, namely CosmosDb and ElasticSearch
* Need to run specific fixes based on `format`
*/
object GenericLocationFixes {
def readDf(ss: SparkSession, location: GenericLocation): DataFrame = {
location.conf.foreach(e => {
ss.conf.set(e._1, e._2)
})
ss.read.format(location.format)
.options(location.options)
.load()
}

def writeDf(ss: SparkSession, df: DataFrame, header: Option[Header], location: GenericLocation) = {
location.conf.foreach(e => {
ss.conf.set(e._1, e._2)
})

location.format.toLowerCase() match {
case "cosmos.oltp" =>
// Ensure the database and the table exist before writing
val endpoint = location.options.getOrElse("spark.cosmos.accountEndpoint", throw new FeathrException("Missing spark__cosmos__accountEndpoint"))
val key = location.options.getOrElse("spark.cosmos.accountKey", throw new FeathrException("Missing spark__cosmos__accountKey"))
val databaseName = location.options.getOrElse("spark.cosmos.database", throw new FeathrException("Missing spark__cosmos__database"))
val tableName = location.options.getOrElse("spark.cosmos.container", throw new FeathrException("Missing spark__cosmos__container"))
ss.conf.set("spark.sql.catalog.cosmosCatalog", "com.azure.cosmos.spark.CosmosCatalog")
ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountEndpoint", endpoint)
ss.conf.set("spark.sql.catalog.cosmosCatalog.spark.cosmos.accountKey", key)
ss.sql(s"CREATE DATABASE IF NOT EXISTS cosmosCatalog.${databaseName};")
ss.sql(s"CREATE TABLE IF NOT EXISTS cosmosCatalog.${databaseName}.${tableName} using cosmos.oltp TBLPROPERTIES(partitionKeyPath = '/id')")

// CosmosDb requires the column `id` to exist and be the primary key, and `id` must be of `string` type
val keyDf = if (!df.columns.contains("id")) {
header match {
case Some(h) => {
// Generate key column from header info, which is required by CosmosDb
val (keyCol, keyedDf) = DataFrameKeyCombiner().combine(df, FeatureGenUtils.getKeyColumnsFromHeader(h))
// Rename key column to `id`
keyedDf.withColumnRenamed(keyCol, "id")
}
case None => {
// If there is no key column, we use an auto-generated monotonic id,
// but in this case the result could contain duplicates if the job is run multiple times.
// This function is for offline-storage usage; ideally the user should create a new container for every run.
df.withColumn("id", (monotonically_increasing_id().cast("string")))
}
}
} else {
// We already have an `id` column
// TODO: Should we do anything here?
// A corner case is that the `id` column exists but is not unique; then the output will be incomplete, as
// CosmosDb will overwrite the old entry with a new one with the same `id`.
// We can either rename the existing `id` column and use header/autogen key column, or we can tell user
// to avoid using `id` column for non-unique data, but both workarounds have pros and cons.
df
}
keyDf.write.format(location.format)
.options(location.options)
.mode(location.mode.getOrElse("append")) // CosmosDb doesn't support ErrorIfExist mode in batch mode
.save()
case _ =>
// Normal writing procedure, just set format and options then write
df.write.format(location.format)
.options(location.options)
.mode(location.mode.getOrElse("default"))
.save()
}
}
}
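To make the `id` fallback above concrete, here is a minimal sketch of the keyless branch, assuming a SparkSession `spark` is in scope (`monotonically_increasing_id` guarantees unique, monotonically increasing values, not consecutive ones):

import org.apache.spark.sql.functions.monotonically_increasing_id

val df = spark.range(3).toDF("value")
val keyed = df.withColumn("id", monotonically_increasing_id().cast("string"))
// `keyed` now carries a string `id` column (e.g. "0", "1", "2" on a single
// partition). Rerunning the job generates fresh ids, so earlier rows are not
// overwritten and duplicates can accumulate; hence the advice to use a new
// container per run.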