
Commit c6daa14

Add print logging behind a flag
nikhilsimha committed May 3, 2024
1 parent a05382e commit c6daa14
Showing 5 changed files with 93 additions and 83 deletions.
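
The changes below swap direct logger.info calls in GroupBy.scala and Join.scala for a tableUtils.log helper, and mark each logger as implicit so the helper can pick it up from the calling scope. The TableUtils side of the change is in one of the files not rendered on this page, so the following is only a rough sketch of how a flag-gated helper of that shape might look; the flag name spark.chronon.print.logs, the class name TableUtilsSketch, and the method signature are illustrative assumptions, not taken from the commit.

```scala
// Hypothetical sketch only: the real TableUtils change is in a file not shown on this page.
import org.apache.spark.sql.SparkSession
import org.slf4j.Logger

case class TableUtilsSketch(sparkSession: SparkSession) {
  // Assumed flag, read once from the Spark conf and defaulting to "off".
  private lazy val printLogsEnabled: Boolean =
    sparkSession.conf.get("spark.chronon.print.logs", "false").toBoolean

  // Matches the call sites in the diff: one message argument plus an implicit logger.
  def log(message: => String)(implicit logger: Logger): Unit = {
    if (printLogsEnabled) println(message) // print to stdout only when the flag is set
    logger.info(message)                   // always forward to slf4j as before
  }
}
```
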
34 changes: 17 additions & 17 deletions spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -43,7 +43,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
skewFilter: Option[String] = None,
finalize: Boolean = true)
extends Serializable {
- @transient lazy val logger = LoggerFactory.getLogger(getClass)
+ @transient implicit lazy val logger = LoggerFactory.getLogger(getClass)

protected[spark] val tsIndex: Int = inputDf.schema.fieldNames.indexOf(Constants.TimeColumn)
protected val selectedSchema: Array[(String, api.DataType)] = SparkConversions.toChrononSchema(inputDf.schema)
@@ -120,7 +120,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],
inputDf -> updateFunc
}

- logger.info(s"""
+ tableUtils.log(s"""
|Prepped input schema
|${preppedInputDf.schema.pretty}
|""".stripMargin)
@@ -393,7 +393,7 @@ class GroupBy(val aggregations: Seq[api.Aggregation],

// TODO: truncate queryRange for caching
object GroupBy {
- @transient lazy val logger = LoggerFactory.getLogger(getClass)
+ @transient implicit lazy val logger = LoggerFactory.getLogger(getClass)

// Need to use a case class here to allow null matching
case class SourceDataProfile(earliestRequired: String, earliestPresent: String, latestAllowed: String)
@@ -408,7 +408,7 @@ object GroupBy {
val result = groupByConf.deepCopy()
val newSources: java.util.List[api.Source] = groupByConf.sources.toScala.map { source =>
if (source.isSetJoinSource) {
- logger.info("Join source detected. Materializing the join.")
+ tableUtils.log("Join source detected. Materializing the join.")
val joinSource = source.getJoinSource
val joinConf = joinSource.join
// materialize the table with the right end date. QueryRange.end could be shifted for temporal events
@@ -421,7 +421,7 @@ object GroupBy {
if (computeDependency) {
val df = join.computeJoin()
if (showDf) {
- logger.info(
+ tableUtils.log(
s"printing output data from groupby::join_source: ${groupByConf.metaData.name}::${joinConf.metaData.name}")
df.prettyPrint()
}
@@ -462,7 +462,7 @@ object GroupBy {
finalize: Boolean = true,
mutationScan: Boolean = true,
showDf: Boolean = false): GroupBy = {
- logger.info(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----")
+ tableUtils.log(s"\n----[Processing GroupBy: ${groupByConfOld.metaData.name}]----")
val groupByConf = replaceJoinSource(groupByConfOld, queryRange, tableUtils, computeDependency, showDf)
val inputDf = groupByConf.sources.toScala
.map { source =>
@@ -496,7 +496,7 @@ object GroupBy {
val keyColumns = groupByConf.getKeyColumns.toScala
val skewFilteredDf = skewFilter
.map { sf =>
- logger.info(s"$logPrefix filtering using skew filter:\n $sf")
+ tableUtils.log(s"$logPrefix filtering using skew filter:\n $sf")
val filtered = inputDf.filter(sf)
filtered
}
@@ -508,7 +508,7 @@ object GroupBy {
val nullFilterClause = groupByConf.keyColumns.toScala.map(key => s"($key IS NOT NULL)").mkString(" OR ")
val nullFiltered = processedInputDf.filter(nullFilterClause)
if (showDf) {
- logger.info(s"printing input date for groupBy: ${groupByConf.metaData.name}")
+ tableUtils.log(s"printing input date for groupBy: ${groupByConf.metaData.name}")
nullFiltered.prettyPrint()
}

@@ -580,7 +580,7 @@ object GroupBy {
val queryableDataRange =
PartitionRange(dataProfile.earliestRequired, Seq(queryEnd, dataProfile.latestAllowed).max)(tableUtils)
val intersectedRange = sourceRange.intersect(queryableDataRange)
- logger.info(s"""
+ tableUtils.log(s"""
|Computing intersected range as:
| query range: $queryRange
| query window: $window
@@ -628,14 +628,14 @@ object GroupBy {
Some(Constants.TimeColumn -> Option(source.query.timeColumn).getOrElse(dsBasedTimestamp))
}
}
- logger.info(s"""
+ tableUtils.log(s"""
|Time Mapping: $timeMapping
|""".stripMargin)
metaColumns ++= timeMapping

val partitionConditions = intersectedRange.map(_.whereClauses()).getOrElse(Seq.empty)

- logger.info(s"""
+ tableUtils.log(s"""
|Rendering source query:
| intersected/effective scan range: $intersectedRange
| partitionConditions: $partitionConditions
@@ -692,25 +692,25 @@ object GroupBy {
skipFirstHole = skipFirstHole)

if (groupByUnfilledRangesOpt.isEmpty) {
- logger.info(s"""Nothing to backfill for $outputTable - given
+ tableUtils.log(s"""Nothing to backfill for $outputTable - given
|endPartition of $endPartition
|backfill start of $overrideStart
|Exiting...""".stripMargin)
return
}
val groupByUnfilledRanges = groupByUnfilledRangesOpt.get
- logger.info(s"group by unfilled ranges: $groupByUnfilledRanges")
+ tableUtils.log(s"group by unfilled ranges: $groupByUnfilledRanges")
val exceptions = mutable.Buffer.empty[String]
groupByUnfilledRanges.foreach {
case groupByUnfilledRange =>
try {
val stepRanges = stepDays.map(groupByUnfilledRange.steps).getOrElse(Seq(groupByUnfilledRange))
- logger.info(s"Group By ranges to compute: ${stepRanges.map {
+ tableUtils.log(s"Group By ranges to compute: ${stepRanges.map {
_.toString
}.pretty}")
stepRanges.zipWithIndex.foreach {
case (range, index) =>
- logger.info(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]")
+ tableUtils.log(s"Computing group by for range: $range [${index + 1}/${stepRanges.size}]")
val groupByBackfill = from(groupByConf, range, tableUtils, computeDependency = true)
val outputDf = groupByConf.dataModel match {
// group by backfills have to be snapshot only
@@ -724,9 +724,9 @@ object GroupBy {
val result = outputDf.select(finalOutputColumns: _*)
result.save(outputTable, tableProps)
}
- logger.info(s"Wrote to table $outputTable, into partitions: $range")
+ tableUtils.log(s"Wrote to table $outputTable, into partitions: $range")
}
- logger.info(s"Wrote to table $outputTable for range: $groupByUnfilledRange")
+ tableUtils.log(s"Wrote to table $outputTable for range: $groupByUnfilledRange")

} catch {
case err: Throwable =>
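
Note the pattern in the GroupBy.scala hunks above: the lazy val logger becomes implicit, which is what lets the one-argument tableUtils.log(...) calls compile if the helper takes the logger as an implicit parameter, as assumed in the sketch near the top of this page. A minimal, self-contained illustration of that resolution:

```scala
import org.slf4j.{Logger, LoggerFactory}

object ImplicitLoggerExample {
  // Same shape as the diff: the logger is declared implicit so callers need not pass it.
  implicit lazy val logger: Logger = LoggerFactory.getLogger(getClass)

  // Stand-in for tableUtils.log with the assumed (message)(implicit logger) signature.
  def log(message: String)(implicit logger: Logger): Unit = logger.info(message)

  def main(args: Array[String]): Unit =
    log("Join source detected. Materializing the join.") // implicit logger resolved from scope
}
```
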
18 changes: 9 additions & 9 deletions spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -196,13 +196,13 @@ class Join(joinConf: api.Join,
(joinPartMetadata, coveringSets)
}

- logger.info(
+ tableUtils.log(
s"\n======= CoveringSet for JoinPart ${joinConf.metaData.name} for PartitionRange(${leftRange.start}, ${leftRange.end}) =======\n")
coveringSetsPerJoinPart.foreach {
case (joinPartMetadata, coveringSets) =>
- logger.info(s"Bootstrap sets for join part ${joinPartMetadata.joinPart.groupBy.metaData.name}")
+ tableUtils.log(s"Bootstrap sets for join part ${joinPartMetadata.joinPart.groupBy.metaData.name}")
coveringSets.foreach { coveringSet =>
- logger.info(
+ tableUtils.log(
s"CoveringSet(hash=${coveringSet.hashes.prettyInline}, rowCount=${coveringSet.rowCount}, isCovering=${coveringSet.isCovering})")
}
}
@@ -221,7 +221,7 @@ class Join(joinConf: api.Join,
}
val wheres = Seq(s"ds >= '${effectiveRange.start}'", s"ds <= '${effectiveRange.end}'")
val sql = QueryUtils.build(null, partTable, wheres)
- logger.info(s"Pulling data from joinPart table with: $sql")
+ tableUtils.log(s"Pulling data from joinPart table with: $sql")
val df = tableUtils.sparkSession.sql(sql)
(joinPart, df)
}
@@ -436,7 +436,7 @@ class Join(joinConf: api.Join,

val result = baseDf.select(finalOutputColumns: _*)
if (showDf) {
- logger.info(s"printing results for join: ${joinConf.metaData.name}")
+ tableUtils.log(s"printing results for join: ${joinConf.metaData.name}")
result.prettyPrint()
}
result
@@ -505,15 +505,15 @@ class Join(joinConf: api.Join,
val joinedDf = parts.foldLeft(initDf) {
case (partialDf, part) => {

- logger.info(s"\nProcessing Bootstrap from table ${part.table} for range ${unfilledRange}")
+ tableUtils.log(s"\nProcessing Bootstrap from table ${part.table} for range ${unfilledRange}")

val bootstrapRange = if (part.isSetQuery) {
unfilledRange.intersect(PartitionRange(part.startPartition, part.endPartition)(tableUtils))
} else {
unfilledRange
}
if (!bootstrapRange.valid) {
- logger.info(s"partition range of bootstrap table ${part.table} is beyond unfilled range")
+ tableUtils.log(s"partition range of bootstrap table ${part.table} is beyond unfilled range")
partialDf
} else {
var bootstrapDf = tableUtils.sql(
@@ -559,7 +559,7 @@ class Join(joinConf: api.Join,
})

val elapsedMins = (System.currentTimeMillis() - startMillis) / (60 * 1000)
- logger.info(s"Finished computing bootstrap table ${joinConf.metaData.bootstrapTable} in ${elapsedMins} minutes")
+ tableUtils.log(s"Finished computing bootstrap table ${joinConf.metaData.bootstrapTable} in ${elapsedMins} minutes")

tableUtils.sql(range.genScanQuery(query = null, table = bootstrapTable))
}
@@ -578,7 +578,7 @@ class Join(joinConf: api.Join,
return Some(bootstrapDfWithStats)
}
val filterExpr = CoveringSet.toFilterExpression(coveringSets)
- logger.info(s"Using covering set filter: $filterExpr")
+ tableUtils.log(s"Using covering set filter: $filterExpr")
val filteredDf = bootstrapDf.where(filterExpr)
val filteredCount = filteredDf.count()
if (bootstrapDfWithStats.count == filteredCount) { // counting is faster than computing stats
(The remaining 3 changed files in this commit were not loaded in this view.)
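
If the helper reads its flag from the Spark conf as assumed above, turning the extra print output on for a run would just mean setting that conf before constructing the table utilities. The conf key and constructor below come from the earlier sketch, not from this commit:

```scala
// Hypothetical usage of the sketched helper; the flag name and TableUtilsSketch are assumptions.
import org.apache.spark.sql.SparkSession
import org.slf4j.{Logger, LoggerFactory}

object BackfillWithPrintLogging {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("chronon-backfill-example")
      .config("spark.chronon.print.logs", "true") // flag on: log() also prints to stdout
      .getOrCreate()

    implicit val logger: Logger = LoggerFactory.getLogger(getClass)
    val tableUtils = TableUtilsSketch(spark) // defined in the sketch near the top of this page
    tableUtils.log("Group By ranges to compute: ...")
    spark.stop()
  }
}
```
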
