diff --git a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala
index 635756bf7..6dda16098 100644
--- a/spark/src/main/scala/ai/chronon/spark/GroupBy.scala
+++ b/spark/src/main/scala/ai/chronon/spark/GroupBy.scala
@@ -35,6 +35,7 @@ import org.slf4j.LoggerFactory
 import java.util
 import scala.collection.{Seq, mutable}
 import scala.util.ScalaJavaConversions.{JListOps, ListOps, MapOps}
+import scala.util.Try
 
 class GroupBy(val aggregations: Seq[api.Aggregation],
               val keyColumns: Seq[String],
@@ -515,22 +516,21 @@ object GroupBy {
     // Generate mutation Df if required, align the columns with inputDf so no additional schema is needed by aggregator.
     val mutationSources = groupByConf.sources.toScala.filter { _.isSetEntities }
     val mutationsColumnOrder = inputDf.columns ++ Constants.MutationFields.map(_.name)
+    val mutationQueriesTry = Try(
+      mutationSources.map(ms =>
+        renderDataSourceQuery(groupByConf,
+                              ms,
+                              groupByConf.getKeyColumns.toScala,
+                              queryRange.shift(1),
+                              tableUtils,
+                              groupByConf.maxWindow,
+                              groupByConf.inferredAccuracy,
+                              mutations = true)))
+
     def mutationDfFn(): DataFrame = {
       val df: DataFrame = if (groupByConf.inferredAccuracy == api.Accuracy.TEMPORAL && mutationSources.nonEmpty) {
-        val mutationDf = mutationSources
-          .map {
-            renderDataSourceQuery(groupByConf,
-                                  _,
-                                  groupByConf.getKeyColumns.toScala,
-                                  queryRange.shift(1),
-                                  tableUtils,
-                                  groupByConf.maxWindow,
-                                  groupByConf.inferredAccuracy,
-                                  mutations = true)
-          }
-          .map {
-            tableUtils.sql
-          }
+        val mutationDf = mutationQueriesTry.get
+          .map { tableUtils.sql }
          .reduce { (df1, df2) =>
            val columns1 = df1.schema.fields.map(_.name)
            df1.union(df2.selectExpr(columns1: _*))
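
Outside the patch itself, here is a minimal, self-contained sketch of the pattern this change adopts: the query rendering is wrapped eagerly in a `Try`, so any rendering failure is captured at construction time and only rethrown (via `.get`) when the mutation DataFrame is actually materialized inside the deferred function. The names `renderQuery`, `runSql`, and `buildMutationDf` are hypothetical stand-ins for `renderDataSourceQuery`, `tableUtils.sql`, and `mutationDfFn`, not Chronon APIs.

```scala
import scala.util.Try

object TryDeferralSketch {
  // Hypothetical stand-in for renderDataSourceQuery: may throw while rendering.
  def renderQuery(source: String): String =
    if (source.isEmpty) throw new IllegalArgumentException("empty source")
    else s"SELECT * FROM $source"

  // Hypothetical stand-in for tableUtils.sql: executes a rendered query.
  def runSql(query: String): Seq[String] = Seq(s"rows for: $query")

  def main(args: Array[String]): Unit = {
    val mutationSources = Seq("events_a", "events_b")

    // Rendering runs eagerly, but any exception is captured inside the Try
    // instead of propagating at this point.
    val renderedTry: Try[Seq[String]] = Try(mutationSources.map(renderQuery))

    // The failure, if any, only surfaces here, when the result is demanded via .get.
    def buildMutationDf(): Seq[Seq[String]] = renderedTry.get.map(runSql)

    println(buildMutationDf())
  }
}
```

Under this reading, the diff hoists the rendering of mutation queries out of `mutationDfFn()` while still deferring any rendering error to the caller that actually consumes the mutation DataFrame.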