diff --git a/api/py/ai/chronon/repo/join_backfill.py b/api/py/ai/chronon/repo/join_backfill.py
index 9df9e4d3c..a366a8144 100644
--- a/api/py/ai/chronon/repo/join_backfill.py
+++ b/api/py/ai/chronon/repo/join_backfill.py
@@ -1,45 +1,50 @@
-import logging
+import json
 import os
+from typing import Optional

 from ai.chronon.constants import ADAPTERS
-from ai.chronon.join import Join
 from ai.chronon.scheduler.interfaces.flow import Flow
 from ai.chronon.scheduler.interfaces.node import Node
-from ai.chronon.utils import get_join_output_table_name, join_part_name, sanitize
+from ai.chronon.utils import (
+    convert_json_to_obj,
+    dict_to_bash_commands,
+    dict_to_exports,
+    get_join_output_table_name,
+    join_part_name,
+    sanitize,
+)

-SPARK_VERSION = "3.1.1"
-SPARK_JAR_TYPE = "uber"
-EXECUTOR_MEMORY = "4g"
-DRIVER_MEMORY = "4g"
 TASK_PREFIX = "compute_join"

-logging.basicConfig(level=logging.INFO)
+DEFAULT_SPARK_SETTINGS = {
+    "default": {
+        "spark_version": "3.1.1",
+        "executor_memory": "4G",
+        "driver_memory": "4G",
+        "executor_cores": 2,
+    }
+}


 class JoinBackfill:
     def __init__(
         self,
-        join: Join,
         start_date: str,
         end_date: str,
         config_path: str,
-        s3_bucket: str,
-        spark_version: str = SPARK_VERSION,
-        executor_memory: str = EXECUTOR_MEMORY,
-        driver_memory: str = DRIVER_MEMORY,
+        settings: dict = DEFAULT_SPARK_SETTINGS,
     ):
         self.dag_id = "_".join(
             map(
                 sanitize,
                 ["chronon_joins_backfill", os.path.basename(config_path).split("/")[-1], start_date, end_date]
             )
         )
-        self.join = join
         self.start_date = start_date
         self.end_date = end_date
-        self.s3_bucket = s3_bucket
         self.config_path = config_path
-        self.spark_version = spark_version
-        self.executor_memory = executor_memory
-        self.driver_memory = driver_memory
+        self.settings = settings
+        with open(self.config_path, "r") as file:
+            config = file.read()
+        self.join = convert_json_to_obj(json.loads(config))

     def build_flow(self) -> Flow:
         """
@@ -54,7 +59,7 @@ def build_flow(self) -> Flow:
         final_node = Node(
             f"{TASK_PREFIX}__{sanitize(get_join_output_table_name(self.join, full_name=True))}", self.run_final_join()
         )
-        left_node = Node(f"{TASK_PREFIX}__left_table", self.run_left())
+        left_node = Node(f"{TASK_PREFIX}__left_table", self.run_left_table())
         flow.add_node(final_node)
         flow.add_node(left_node)
         for join_part in self.join.joinParts:
@@ -65,31 +70,34 @@ def build_flow(self) -> Flow:
             final_node.add_dependency(jp_node)
         return flow

-    def command_template(self):
-        config_dir = os.path.dirname(self.config_path) + "/"
-        cmd = f"""
-        aws s3 cp {self.s3_bucket}{self.config_path} /tmp/{config_dir} &&
-        aws s3 cp {self.s3_bucket}run.py /tmp/ &&
-        aws s3 cp {self.s3_bucket}spark_submit.sh /tmp/ &&
-        export SPARK_VERSION={self.spark_version} &&
-        export EXECUTOR_MEMORY={self.executor_memory} &&
-        export DRIVER_MEMORY={self.driver_memory} &&
-        python3 /tmp/run.py \
---conf=/tmp/{self.config_path} --env=production --spark-submit-path /tmp/spark_submit.sh --ds={self.end_date}"""
+    def export_template(self, settings: dict):
+        return f"{dict_to_exports(settings)}"
+
+    def command_template(self, extra_args: dict):
         if self.start_date:
-            cmd += f" --start-ds={self.start_date}"
-        return cmd
+            extra_args.update({"start_ds": self.start_date})
+        return f"""python3 /tmp/run.py --conf=/tmp/{self.config_path} --env=production --ds={self.end_date} \
+{dict_to_bash_commands(extra_args)}"""

     def run_join_part(self, join_part: str):
-        return self.command_template() + f" --mode=backfill --selected-join-parts={join_part} --use-cached-left"
+        args = {
+            "mode": "backfill",
"selected_join_parts": join_part, + "use_cached_left": None, + } + settings = self.settings.get(join_part, self.settings["default"]) + return self.export_template(settings) + " && " + self.command_template(extra_args=args) - def run_left(self): - return self.command_template() + " --mode=backfill-left" + def run_left_table(self): + settings = self.settings.get("left_table", self.settings["default"]) + return self.export_template(settings) + " && " + self.command_template(extra_args={"mode": "backfill-left"}) def run_final_join(self): - return self.command_template() + " --mode=backfill-final" + settings = self.settings.get("final_join", self.settings["default"]) + return self.export_template(settings) + " && " + self.command_template(extra_args={"mode": "backfill-final"}) - def run(self, orchestrator: str): + def run(self, orchestrator: str, overrides: Optional[dict] = None): + ADAPTERS.update(overrides) orchestrator = ADAPTERS[orchestrator](dag_id=self.dag_id, start_date=self.start_date) orchestrator.setup() orchestrator.build_dag_from_flow(self.build_flow()) diff --git a/api/py/ai/chronon/scheduler/adapters/airflow_adapter.py b/api/py/ai/chronon/scheduler/adapters/airflow_adapter.py index fef16b889..c047b911e 100644 --- a/api/py/ai/chronon/scheduler/adapters/airflow_adapter.py +++ b/api/py/ai/chronon/scheduler/adapters/airflow_adapter.py @@ -1,16 +1,13 @@ from datetime import datetime -import airflow_client from ai.chronon.scheduler.interfaces.orchestrator import WorkflowOrchestrator from airflow import DAG from airflow.operators.bash_operator import BashOperator -AIRFLOW_CLUSTER = airflow_client.Service.STONE - class AirflowAdapter(WorkflowOrchestrator): - def __init__(self, dag_id, start_date, schedule_interval="@once", airflow_cluster=AIRFLOW_CLUSTER): + def __init__(self, dag_id, start_date, schedule_interval="@once", airflow_cluster=None): self.dag = DAG( dag_id, start_date=datetime.strptime(start_date, "%Y-%m-%d"), @@ -19,7 +16,7 @@ def __init__(self, dag_id, start_date, schedule_interval="@once", airflow_cluste self.airflow_cluster = airflow_cluster def setup(self): - airflow_client.init(self.airflow_cluster) + """Initialize a connection to Airflow""" def schedule_task(self, node): return BashOperator(task_id=node.name, dag=self.dag, bash_command=node.command) @@ -37,4 +34,4 @@ def build_dag_from_flow(self, flow): return self.dag def trigger_run(self): - airflow_client.create_dag(self.dag, overwrite=True) + """Trigger the DAG run""" diff --git a/api/py/ai/chronon/utils.py b/api/py/ai/chronon/utils.py index 5f021b4a9..9db29f43f 100644 --- a/api/py/ai/chronon/utils.py +++ b/api/py/ai/chronon/utils.py @@ -56,21 +56,15 @@ def __init__(self): self.old_name = "old.json" def diff(self, new_json_str: object, old_json_str: object, skipped_keys=[]) -> str: - new_json = { - k: v for k, v in json.loads(new_json_str).items() if k not in skipped_keys - } - old_json = { - k: v for k, v in json.loads(old_json_str).items() if k not in skipped_keys - } + new_json = {k: v for k, v in json.loads(new_json_str).items() if k not in skipped_keys} + old_json = {k: v for k, v in json.loads(old_json_str).items() if k not in skipped_keys} with open(os.path.join(self.temp_dir, self.old_name), mode="w") as old, open( os.path.join(self.temp_dir, self.new_name), mode="w" ) as new: old.write(json.dumps(old_json, sort_keys=True, indent=2)) new.write(json.dumps(new_json, sort_keys=True, indent=2)) - diff_str = subprocess.run( - ["diff", old.name, new.name], stdout=subprocess.PIPE - ).stdout.decode("utf-8") + 
+        diff_str = subprocess.run(["diff", old.name, new.name], stdout=subprocess.PIPE).stdout.decode("utf-8")
         return diff_str

     def clean(self):
@@ -179,6 +173,28 @@ def sanitize(name):
     return None


+def dict_to_bash_commands(d):
+    """
+    Convert a dict into a bash command substring
+    """
+    if not d:
+        return ""
+    bash_commands = []
+    for key, value in d.items():
+        cmd = f"--{key.replace('_', '-')}={value}" if value else f"--{key.replace('_', '-')}"
+        bash_commands.append(cmd)
+    return " ".join(bash_commands)
+
+
+def dict_to_exports(d):
+    if not d:
+        return ""
+    exports = []
+    for key, value in d.items():
+        exports.append(f"export {key.upper()}={value}")
+    return " && ".join(exports)
+
+
 def output_table_name(obj, full_name: bool):
     table_name = sanitize(obj.metaData.name)
     db = obj.metaData.outputNamespace
@@ -191,17 +207,11 @@ def output_table_name(obj, full_name: bool):

 def join_part_name(jp):
     if jp.groupBy is None:
-        raise NotImplementedError(
-            "Join Part names for non group bys is not implemented."
-        )
-    if not jp.groupBy.metaData.name:
+        raise NotImplementedError("Join Part names for non group bys is not implemented.")
+    if not jp.groupBy.metaData.name and isinstance(jp.groupBy, api.GroupBy):
         __set_name(jp.groupBy, api.GroupBy, "group_bys")
     return "_".join(
-        [
-            component
-            for component in [jp.prefix, sanitize(jp.groupBy.metaData.name)]
-            if component is not None
-        ]
+        [component for component in [jp.prefix, sanitize(jp.groupBy.metaData.name)] if component is not None]
     )


@@ -240,9 +250,7 @@ def log_table_name(obj, full_name: bool = False):
     return output_table_name(obj, full_name=full_name) + "_logged"


-def get_staging_query_output_table_name(
-    staging_query: api.StagingQuery, full_name: bool = False
-):
+def get_staging_query_output_table_name(staging_query: api.StagingQuery, full_name: bool = False):
     """generate output table name for staging query job"""
     __set_name(staging_query, api.StagingQuery, "staging_queries")
     return output_table_name(staging_query, full_name=full_name)
@@ -250,13 +258,12 @@ def get_staging_query_output_table_name(

 def get_join_output_table_name(join: api.Join, full_name: bool = False):
     """generate output table name for join backfill job"""
-    __set_name(join, api.Join, "joins")
+    if isinstance(join, api.Join):
+        __set_name(join, api.Join, "joins")
     # set output namespace
     if not join.metaData.outputNamespace:
         team_name = join.metaData.name.split(".")[0]
-        namespace = teams.get_team_conf(
-            os.path.join(chronon_root_path, TEAMS_FILE_PATH), team_name, "namespace"
-        )
+        namespace = teams.get_team_conf(os.path.join(chronon_root_path, TEAMS_FILE_PATH), team_name, "namespace")
         join.metaData.outputNamespace = namespace
     return output_table_name(join, full_name=full_name)


@@ -273,10 +280,7 @@ def get_dependencies(
     if meta_data is not None:
         result = [json.loads(dep) for dep in meta_data.dependencies]
     elif dependencies:
-        result = [
-            {"name": wait_for_name(dep), "spec": dep, "start": start, "end": end}
-            for dep in dependencies
-        ]
+        result = [{"name": wait_for_name(dep), "spec": dep, "start": start, "end": end} for dep in dependencies]
     else:
         if src.entities and src.entities.mutationTable:
             # Opting to use no lag for all use cases because that the "safe catch-all" case when
             result = list(
                 filter(
                     None,
                     [
-                        wait_for_simple_schema(
-                            src.entities.snapshotTable, lag, start, end
-                        ),
-                        wait_for_simple_schema(
-                            src.entities.mutationTable, lag, start, end
-                        ),
+                        wait_for_simple_schema(src.entities.snapshotTable, lag, start, end),
+                        wait_for_simple_schema(src.entities.mutationTable, lag, start, end),
                     ],
                 )
             )
         elif src.entities:
-            result = [
-                wait_for_simple_schema(src.entities.snapshotTable, lag, start, end)
-            ]
+            result = [wait_for_simple_schema(src.entities.snapshotTable, lag, start, end)]
         elif src.joinSource:
-            parentJoinOutputTable = get_join_output_table_name(
-                src.joinSource.join, True
-            )
+            parentJoinOutputTable = get_join_output_table_name(src.joinSource.join, True)
             result = [wait_for_simple_schema(parentJoinOutputTable, lag, start, end)]
         else:
             result = [wait_for_simple_schema(src.events.table, lag, start, end)]
@@ -316,31 +312,17 @@ def get_bootstrap_dependencies(bootstrap_parts) -> List[str]:
     dependencies = []
     for bootstrap_part in bootstrap_parts:
         table = bootstrap_part.table
-        start = (
-            bootstrap_part.query.startPartition
-            if bootstrap_part.query is not None
-            else None
-        )
-        end = (
-            bootstrap_part.query.endPartition
-            if bootstrap_part.query is not None
-            else None
-        )
+        start = bootstrap_part.query.startPartition if bootstrap_part.query is not None else None
+        end = bootstrap_part.query.endPartition if bootstrap_part.query is not None else None
         dependencies.append(wait_for_simple_schema(table, 0, start, end))
     return [json.dumps(dep) for dep in dependencies]


 def get_label_table_dependencies(label_part) -> List[str]:
-    label_info = [
-        (label.groupBy.sources, label.groupBy.metaData) for label in label_part.labels
-    ]
-    label_info = [
-        (source, meta_data) for (sources, meta_data) in label_info for source in sources
-    ]
+    label_info = [(label.groupBy.sources, label.groupBy.metaData) for label in label_part.labels]
+    label_info = [(source, meta_data) for (sources, meta_data) in label_info for source in sources]
     label_dependencies = [
-        dep
-        for (source, meta_data) in label_info
-        for dep in get_dependencies(src=source, meta_data=meta_data)
+        dep for (source, meta_data) in label_info for dep in get_dependencies(src=source, meta_data=meta_data)
     ]
     label_dependencies.append(
         json.dumps(
@@ -360,9 +342,7 @@ def wait_for_simple_schema(table, lag, start, end):
     clean_name = table_tokens[0]
     subpartition_spec = "/".join(table_tokens[1:]) if len(table_tokens) > 1 else ""
     return {
-        "name": "wait_for_{}_ds{}".format(
-            clean_name, "" if lag == 0 else f"_minus_{lag}"
-        ),
+        "name": "wait_for_{}_ds{}".format(clean_name, "" if lag == 0 else f"_minus_{lag}"),
         "spec": "{}/ds={}{}".format(
             clean_name,
             "{{ ds }}" if lag == 0 else "{{{{ macros.ds_add(ds, -{}) }}}}".format(lag),
@@ -388,8 +368,7 @@ def dedupe_in_order(seq):

 def has_topic(group_by: api.GroupBy) -> bool:
     """Find if there's topic or mutationTopic for a source helps define streaming tasks"""
     return any(
-        (source.entities and source.entities.mutationTopic)
-        or (source.events and source.events.topic)
+        (source.entities and source.entities.mutationTopic) or (source.events and source.events.topic)
         for source in group_by.sources
     )


@@ -468,3 +447,20 @@ def get_related_table_names(conf: ChrononJobTypes) -> List[str]:
         related_tables.append(f"{table_name}_bootstrap")

     return related_tables
+
+
+class DotDict(dict):
+    def __getattr__(self, attr):
+        if attr in self:
+            value = self[attr]
+            return DotDict(value) if isinstance(value, dict) else value
+        return None
+
+
+def convert_json_to_obj(d):
+    if isinstance(d, dict):
+        return DotDict({k: convert_json_to_obj(v) for k, v in d.items()})
+    elif isinstance(d, list):
+        return [convert_json_to_obj(item) for item in d]
+    else:
+        return d
diff --git a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
index 840a11d64..cf63a77ed 100644
--- a/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
+++ b/spark/src/main/scala/ai/chronon/spark/JoinBase.scala
@@ -411,16 +411,6 @@ abstract class JoinBase(joinConf: api.Join,
         s"groupBy.metaData.team needs to be set for joinPart ${jp.groupBy.metaData.name}")
     }

-    val source = joinConf.left
-    if (useBootstrapForLeft) {
-      tableUtils.log("Overwriting left side to use saved Bootstrap table...")
-      source.overwriteTable(bootstrapTable)
-      val query = source.query
-      // sets map and where clauses already applied to bootstrap transformation
-      query.setSelects(null)
-      query.setWheres(null)
-    }
-
     // Run validations before starting the job
     val today = tableUtils.partitionSpec.at(System.currentTimeMillis())
     val analyzer = new Analyzer(tableUtils, joinConf, today, today, silenceMode = true)
@@ -440,9 +430,21 @@ abstract class JoinBase(joinConf: api.Join,

     // First run command to archive tables that have changed semantically since the last run
     val archivedAtTs = Instant.now()
+    // TODO: We should not archive the output table in the case of selected join parts mode
     tablesToRecompute(joinConf, outputTable, tableUtils).foreach(
       tableUtils.archiveOrDropTableIfExists(_, Some(archivedAtTs)))

+    // Check semantic hash before overwriting left side
+    val source = joinConf.left
+    if (useBootstrapForLeft) {
+      logger.info("Overwriting left side to use saved Bootstrap table...")
+      source.overwriteTable(bootstrapTable)
+      val query = source.query
+      // sets map and where clauses already applied to bootstrap transformation
+      query.setSelects(null)
+      query.setWheres(null)
+    }
+
     // detect holes and chunks to fill
     // OverrideStartPartition is used to replace the start partition of the join config. This is useful when
     // 1 - User would like to test run with different start partition
diff --git a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
index d9d8859a4..5603807fb 100644
--- a/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
+++ b/spark/src/main/scala/ai/chronon/spark/TableUtils.scala
@@ -24,6 +24,7 @@ import ai.chronon.api.Extensions._
 import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException
 import ai.chronon.spark.Extensions.{DfStats, DfWithStats}
 import jnr.ffi.annotations.Synchronized
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project}
 import org.apache.spark.sql.functions._
@@ -260,17 +261,17 @@ case class TableUtils(sparkSession: SparkSession) {
       sparkSession.sql(s"SELECT * FROM $tableName where $partitionColumn='$partitionFilter' LIMIT 1").collect()
       true
     } catch {
-      case e: RuntimeException =>
+      case e: SparkException =>
         if (e.getMessage.contains("ACCESS DENIED"))
           logger.error(s"[Error] No access to table: $tableName ")
         else {
          logger.error(s"[Error] Encountered exception when reading table: $tableName.")
-          e.printStackTrace()
        }
+        e.printStackTrace()
         false
-      case ex: Exception =>
+      case e: Exception =>
         logger.error(s"[Error] Encountered exception when reading table: $tableName.")
-        ex.printStackTrace()
+        e.printStackTrace()
         true
     }
 }
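
For reviewers, a minimal usage sketch of the reworked JoinBackfill entry point (not part of the patch). The config path, dates, the settings override, and the "airflow" adapter key are illustrative assumptions; it presumes a compiled join config JSON already exists at config_path and that ADAPTERS registers the Airflow adapter under that key.

# Hypothetical driver script; every literal value below is a placeholder.
from ai.chronon.repo.join_backfill import DEFAULT_SPARK_SETTINGS, JoinBackfill

settings = dict(DEFAULT_SPARK_SETTINGS)
# Optional per-task override: any task without its own entry ("left_table",
# "final_join", or a join part name) falls back to settings["default"].
settings["final_join"] = {
    "spark_version": "3.1.1",
    "executor_memory": "8G",
    "driver_memory": "8G",
    "executor_cores": 4,
}

backfill = JoinBackfill(
    start_date="2023-01-01",
    end_date="2023-01-31",
    config_path="production/joins/my_team/my_join.v1",  # compiled join config JSON
    settings=settings,
)
# Builds the flow (left table first, then each join part, then the final join)
# and hands it to the adapter registered under "airflow" in ADAPTERS (key name assumed).
backfill.run(orchestrator="airflow")

Each node's bash command is the dict_to_exports prefix joined to the run.py invocation from command_template, so a join part task runs roughly: export SPARK_VERSION=3.1.1 && export EXECUTOR_MEMORY=4G && export DRIVER_MEMORY=4G && export EXECUTOR_CORES=2 && python3 /tmp/run.py --conf=/tmp/<config_path> --env=production --ds=<end_date> --mode=backfill --selected-join-parts=<join_part> --use-cached-left --start-ds=<start_date>.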