diff --git a/spark/src/main/scala/ai/chronon/spark/Join.scala b/spark/src/main/scala/ai/chronon/spark/Join.scala
index 0adff47f4..98ac3cd4e 100644
--- a/spark/src/main/scala/ai/chronon/spark/Join.scala
+++ b/spark/src/main/scala/ai/chronon/spark/Join.scala
@@ -261,9 +261,21 @@ class Join(joinConf: api.Join,
     // combine bootstrap table and join part tables
     // sequentially join bootstrap table and each join part table. some column may exist both on left and right because
     // a bootstrap source can cover a partial date range. we combine the columns using coalesce-rule
+    var previous: Option[DataFrame] = None
-    rightResults
+    rightResults.zipWithIndex
       .foldLeft(bootstrapDf) {
-        case (partialDf, (rightPart, rightDf)) => joinWithLeft(partialDf, rightDf, rightPart)
+        case (partialDf, ((rightPart, rightDf), i)) =>
+          val next = joinWithLeft(partialDf, rightDf, rightPart)
+          // Join breaks are added to prevent the Spark app from stalling on a Join that involves too many
+          // rightParts.
+          if (((i + 1) % tableUtils.finalJoinParallelism) == 0 && (i != (rightResults.size - 1))) {
+            val withBreak = tableUtils.addJoinBreak(next)
+            previous.foreach(_.unpersist())
+            previous = Some(withBreak)
+            withBreak
+          } else {
+            next
+          }
       }
       // drop all processing metadata columns
       .drop(Constants.MatchedHashes, Constants.TimePartitionColumn)
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index 5d9d71144..75fe58d61 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -68,6 +68,7 @@ case class TableUtils(sparkSession: SparkSession) {
 
   val joinPartParallelism: Int = sparkSession.conf.get("spark.chronon.join.part.parallelism", "1").toInt
   val aggregationParallelism: Int = sparkSession.conf.get("spark.chronon.group_by.parallelism", "1000").toInt
+  val finalJoinParallelism: Int = sparkSession.conf.get("spark.chronon.join.final_join_parallelism", "8").toInt
   val maxWait: Int = sparkSession.conf.get("spark.chronon.wait.hours", "48").toInt
 
   sparkSession.sparkContext.setLogLevel("ERROR")
@@ -324,6 +325,10 @@ case class TableUtils(sparkSession: SparkSession) {
     df
   }
 
+  // A "join break" persists an intermediate DataFrame so that downstream stages plan
+  // against the cached relation rather than the full upstream join lineage.
+  def addJoinBreak(dataFrame: DataFrame): DataFrame =
+    dataFrame.persist(cacheLevel.getOrElse(StorageLevel.MEMORY_AND_DISK))
+
   def insertUnPartitioned(df: DataFrame,
                           tableName: String,
                           tableProperties: Map[String, String] = null,
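
Note for reviewers on the mechanism: persisting an intermediate DataFrame lets Spark substitute the cached relation into downstream query plans, so the final join's plan no longer grows with every chained joinWithLeft call. Below is a minimal, self-contained sketch of the same pattern outside Chronon; JoinBreakSketch, foldWithBreaks, breakEvery, and the joinKey column are illustrative names, not Chronon API:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.storage.StorageLevel

    object JoinBreakSketch {
      // Left-join each right-hand DataFrame in turn, persisting every `breakEvery`-th
      // intermediate result so the accumulated query plan stays shallow, and
      // unpersisting the previous break once it has been superseded.
      def foldWithBreaks(left: DataFrame, rights: Seq[DataFrame], breakEvery: Int): DataFrame = {
        var previous: Option[DataFrame] = None
        rights.zipWithIndex.foldLeft(left) {
          case (acc, (right, i)) =>
            val next = acc.join(right, Seq("joinKey"), "left")
            // Insert a break after every `breakEvery` joins, except after the last join.
            if ((i + 1) % breakEvery == 0 && i != rights.size - 1) {
              val withBreak = next.persist(StorageLevel.MEMORY_AND_DISK)
              previous.foreach(_.unpersist())
              previous = Some(withBreak)
              withBreak
            } else {
              next
            }
        }
      }
    }

The new knob is set through the standard Spark conf mechanism, e.g. spark-submit --conf spark.chronon.join.final_join_parallelism=4; with the default of 8, a break is inserted after every 8th chained join, skipping the final one.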