Merge branch 'main' into print_logging
Signed-off-by: Nikhil <r.nikhilsimha@gmail.com>
nikhilsimha committed May 17, 2024
2 parents c6daa14 + 34a7849 commit f830a89
Showing 5 changed files with 126 additions and 122 deletions.
82 changes: 45 additions & 37 deletions api/py/ai/chronon/repo/join_backfill.py
@@ -1,45 +1,50 @@
import logging
import json
import os
from typing import Optional

from ai.chronon.constants import ADAPTERS
from ai.chronon.join import Join
from ai.chronon.scheduler.interfaces.flow import Flow
from ai.chronon.scheduler.interfaces.node import Node
from ai.chronon.utils import get_join_output_table_name, join_part_name, sanitize
from ai.chronon.utils import (
convert_json_to_obj,
dict_to_bash_commands,
dict_to_exports,
get_join_output_table_name,
join_part_name,
sanitize,
)

SPARK_VERSION = "3.1.1"
SPARK_JAR_TYPE = "uber"
EXECUTOR_MEMORY = "4g"
DRIVER_MEMORY = "4g"
TASK_PREFIX = "compute_join"
logging.basicConfig(level=logging.INFO)
DEFAULT_SPARK_SETTINGS = {
"default": {
"spark_version": "3.1.1",
"executor_memory": "4G",
"driver_memory": "4G",
"executor_cores": 2,
}
}


class JoinBackfill:
def __init__(
self,
join: Join,
start_date: str,
end_date: str,
config_path: str,
s3_bucket: str,
spark_version: str = SPARK_VERSION,
executor_memory: str = EXECUTOR_MEMORY,
driver_memory: str = DRIVER_MEMORY,
settings: dict = DEFAULT_SPARK_SETTINGS,
):
self.dag_id = "_".join(
map(
sanitize, ["chronon_joins_backfill", os.path.basename(config_path).split("/")[-1], start_date, end_date]
)
)
self.join = join
self.start_date = start_date
self.end_date = end_date
self.s3_bucket = s3_bucket
self.config_path = config_path
self.spark_version = spark_version
self.executor_memory = executor_memory
self.driver_memory = driver_memory
self.settings = settings
with open(self.config_path, "r") as file:
config = file.read()
self.join = convert_json_to_obj(json.loads(config))

def build_flow(self) -> Flow:
"""
@@ -54,7 +59,7 @@ def build_flow(self) -> Flow:
final_node = Node(
f"{TASK_PREFIX}__{sanitize(get_join_output_table_name(self.join, full_name=True))}", self.run_final_join()
)
left_node = Node(f"{TASK_PREFIX}__left_table", self.run_left())
left_node = Node(f"{TASK_PREFIX}__left_table", self.run_left_table())
flow.add_node(final_node)
flow.add_node(left_node)
for join_part in self.join.joinParts:
@@ -65,31 +70,34 @@ def build_flow(self) -> Flow:
final_node.add_dependency(jp_node)
return flow

def command_template(self):
config_dir = os.path.dirname(self.config_path) + "/"
cmd = f"""
aws s3 cp {self.s3_bucket}{self.config_path} /tmp/{config_dir} &&
aws s3 cp {self.s3_bucket}run.py /tmp/ &&
aws s3 cp {self.s3_bucket}spark_submit.sh /tmp/ &&
export SPARK_VERSION={self.spark_version} &&
export EXECUTOR_MEMORY={self.executor_memory} &&
export DRIVER_MEMORY={self.driver_memory} &&
python3 /tmp/run.py \
--conf=/tmp/{self.config_path} --env=production --spark-submit-path /tmp/spark_submit.sh --ds={self.end_date}"""
def export_template(self, settings: dict):
return f"{dict_to_exports(settings)}"

def command_template(self, extra_args: dict):
if self.start_date:
cmd += f" --start-ds={self.start_date}"
return cmd
extra_args.update({"start_ds": self.start_date})
return f"""python3 /tmp/run.py --conf=/tmp/{self.config_path} --env=production --ds={self.end_date} \
{dict_to_bash_commands(extra_args)}"""

def run_join_part(self, join_part: str):
return self.command_template() + f" --mode=backfill --selected-join-parts={join_part} --use-cached-left"
args = {
"mode": "backfill",
"selected_join_parts": join_part,
"use_cached_left": None,
}
settings = self.settings.get(join_part, self.settings["default"])
return self.export_template(settings) + " && " + self.command_template(extra_args=args)

def run_left(self):
return self.command_template() + " --mode=backfill-left"
def run_left_table(self):
settings = self.settings.get("left_table", self.settings["default"])
return self.export_template(settings) + " && " + self.command_template(extra_args={"mode": "backfill-left"})

def run_final_join(self):
return self.command_template() + " --mode=backfill-final"
settings = self.settings.get("final_join", self.settings["default"])
return self.export_template(settings) + " && " + self.command_template(extra_args={"mode": "backfill-final"})

def run(self, orchestrator: str):
def run(self, orchestrator: str, overrides: Optional[dict] = None):
ADAPTERS.update(overrides)
orchestrator = ADAPTERS[orchestrator](dag_id=self.dag_id, start_date=self.start_date)
orchestrator.setup()
orchestrator.build_dag_from_flow(self.build_flow())
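
The refactor above replaces the individual spark_version/executor_memory/driver_memory arguments with a single settings dict; run_join_part, run_left_table, and run_final_join each pick their entry by task name and fall back to "default". A minimal sketch of how a caller might shape that dict (the join-part key below is made up for illustration):

    from ai.chronon.repo.join_backfill import DEFAULT_SPARK_SETTINGS

    settings = dict(DEFAULT_SPARK_SETTINGS)
    # Give one expensive join part more resources; every other task keeps "default".
    settings["user_activity_features"] = {
        "spark_version": "3.1.1",
        "executor_memory": "16G",
        "driver_memory": "8G",
        "executor_cores": 4,
    }
    # run_join_part resolves settings.get(join_part, settings["default"]) and
    # dict_to_exports turns the chosen entry into a chain of bash exports.
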
9 changes: 3 additions & 6 deletions api/py/ai/chronon/scheduler/adapters/airflow_adapter.py
@@ -1,16 +1,13 @@
from datetime import datetime

import airflow_client
from ai.chronon.scheduler.interfaces.orchestrator import WorkflowOrchestrator

from airflow import DAG
from airflow.operators.bash_operator import BashOperator

AIRFLOW_CLUSTER = airflow_client.Service.STONE


class AirflowAdapter(WorkflowOrchestrator):
def __init__(self, dag_id, start_date, schedule_interval="@once", airflow_cluster=AIRFLOW_CLUSTER):
def __init__(self, dag_id, start_date, schedule_interval="@once", airflow_cluster=None):
self.dag = DAG(
dag_id,
start_date=datetime.strptime(start_date, "%Y-%m-%d"),
@@ -19,7 +16,7 @@ def __init__(self, dag_id, start_date, schedule_interval="@once", airflow_cluste
self.airflow_cluster = airflow_cluster

def setup(self):
airflow_client.init(self.airflow_cluster)
"""Initialize a connection to Airflow"""

def schedule_task(self, node):
return BashOperator(task_id=node.name, dag=self.dag, bash_command=node.command)
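
With airflow_client removed, the adapter is a thin wrapper around a plain DAG of BashOperators. A rough usage sketch, assuming Node keeps the positional (name, command) signature used in join_backfill.py:

    from ai.chronon.scheduler.adapters.airflow_adapter import AirflowAdapter
    from ai.chronon.scheduler.interfaces.node import Node

    # dag_id/start_date mirror what JoinBackfill.run() passes to the adapter.
    adapter = AirflowAdapter(dag_id="chronon_joins_backfill_demo", start_date="2024-05-17")
    adapter.setup()  # now a no-op; it previously initialized the airflow_client connection
    task = adapter.schedule_task(Node("compute_join__left_table", "echo 'run.py invocation goes here'"))
    # task is a BashOperator attached to adapter.dag; trigger_run() is likewise a stub now.
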
@@ -37,4 +34,4 @@ def build_dag_from_flow(self, flow):
return self.dag

def trigger_run(self):
airflow_client.create_dag(self.dag, overwrite=True)
"""Trigger the DAG run"""
126 changes: 61 additions & 65 deletions api/py/ai/chronon/utils.py
@@ -56,21 +56,15 @@ def __init__(self):
self.old_name = "old.json"

def diff(self, new_json_str: object, old_json_str: object, skipped_keys=[]) -> str:
new_json = {
k: v for k, v in json.loads(new_json_str).items() if k not in skipped_keys
}
old_json = {
k: v for k, v in json.loads(old_json_str).items() if k not in skipped_keys
}
new_json = {k: v for k, v in json.loads(new_json_str).items() if k not in skipped_keys}
old_json = {k: v for k, v in json.loads(old_json_str).items() if k not in skipped_keys}

with open(os.path.join(self.temp_dir, self.old_name), mode="w") as old, open(
os.path.join(self.temp_dir, self.new_name), mode="w"
) as new:
old.write(json.dumps(old_json, sort_keys=True, indent=2))
new.write(json.dumps(new_json, sort_keys=True, indent=2))
diff_str = subprocess.run(
["diff", old.name, new.name], stdout=subprocess.PIPE
).stdout.decode("utf-8")
diff_str = subprocess.run(["diff", old.name, new.name], stdout=subprocess.PIPE).stdout.decode("utf-8")
return diff_str

def clean(self):
@@ -179,6 +173,28 @@ def sanitize(name):
return None


def dict_to_bash_commands(d):
    """
    Convert a dict into a bash command substring
    """
    if not d:
        return ""
    bash_commands = []
    for key, value in d.items():
        cmd = f"--{key.replace('_', '-')}={value}" if value else f"--{key.replace('_', '-')}"
        bash_commands.append(cmd)
    return " ".join(bash_commands)


def dict_to_exports(d):
    """
    Convert a dict into a chain of bash export statements (keys are upper-cased)
    """
    if not d:
        return ""
    exports = []
    for key, value in d.items():
        exports.append(f"export {key.upper()}={value}")
    return " && ".join(exports)
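
A quick illustration of the two helpers above with the kinds of dicts JoinBackfill passes them (the values are made up):

    from ai.chronon.utils import dict_to_bash_commands, dict_to_exports

    print(dict_to_bash_commands({"mode": "backfill", "selected_join_parts": "team_gb__0", "use_cached_left": None}))
    # --mode=backfill --selected-join-parts=team_gb__0 --use-cached-left

    print(dict_to_exports({"spark_version": "3.1.1", "executor_memory": "4G", "driver_memory": "4G"}))
    # export SPARK_VERSION=3.1.1 && export EXECUTOR_MEMORY=4G && export DRIVER_MEMORY=4G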


def output_table_name(obj, full_name: bool):
table_name = sanitize(obj.metaData.name)
db = obj.metaData.outputNamespace
@@ -191,17 +207,11 @@ def output_table_name(obj, full_name: bool):

def join_part_name(jp):
if jp.groupBy is None:
raise NotImplementedError(
"Join Part names for non group bys is not implemented."
)
if not jp.groupBy.metaData.name:
raise NotImplementedError("Join Part names for non group bys is not implemented.")
if not jp.groupBy.metaData.name and isinstance(jp.groupBy, api.GroupBy):
__set_name(jp.groupBy, api.GroupBy, "group_bys")
return "_".join(
[
component
for component in [jp.prefix, sanitize(jp.groupBy.metaData.name)]
if component is not None
]
[component for component in [jp.prefix, sanitize(jp.groupBy.metaData.name)] if component is not None]
)


@@ -240,23 +250,20 @@ def log_table_name(obj, full_name: bool = False):
return output_table_name(obj, full_name=full_name) + "_logged"


def get_staging_query_output_table_name(
staging_query: api.StagingQuery, full_name: bool = False
):
def get_staging_query_output_table_name(staging_query: api.StagingQuery, full_name: bool = False):
"""generate output table name for staging query job"""
__set_name(staging_query, api.StagingQuery, "staging_queries")
return output_table_name(staging_query, full_name=full_name)


def get_join_output_table_name(join: api.Join, full_name: bool = False):
"""generate output table name for join backfill job"""
__set_name(join, api.Join, "joins")
if isinstance(join, api.Join):
__set_name(join, api.Join, "joins")
# set output namespace
if not join.metaData.outputNamespace:
team_name = join.metaData.name.split(".")[0]
namespace = teams.get_team_conf(
os.path.join(chronon_root_path, TEAMS_FILE_PATH), team_name, "namespace"
)
namespace = teams.get_team_conf(os.path.join(chronon_root_path, TEAMS_FILE_PATH), team_name, "namespace")
join.metaData.outputNamespace = namespace
return output_table_name(join, full_name=full_name)

@@ -273,10 +280,7 @@ def get_dependencies(
if meta_data is not None:
result = [json.loads(dep) for dep in meta_data.dependencies]
elif dependencies:
result = [
{"name": wait_for_name(dep), "spec": dep, "start": start, "end": end}
for dep in dependencies
]
result = [{"name": wait_for_name(dep), "spec": dep, "start": start, "end": end} for dep in dependencies]
else:
if src.entities and src.entities.mutationTable:
# Opting to use no lag for all use cases because that the "safe catch-all" case when
@@ -286,23 +290,15 @@ def get_dependencies(
filter(
None,
[
wait_for_simple_schema(
src.entities.snapshotTable, lag, start, end
),
wait_for_simple_schema(
src.entities.mutationTable, lag, start, end
),
wait_for_simple_schema(src.entities.snapshotTable, lag, start, end),
wait_for_simple_schema(src.entities.mutationTable, lag, start, end),
],
)
)
elif src.entities:
result = [
wait_for_simple_schema(src.entities.snapshotTable, lag, start, end)
]
result = [wait_for_simple_schema(src.entities.snapshotTable, lag, start, end)]
elif src.joinSource:
parentJoinOutputTable = get_join_output_table_name(
src.joinSource.join, True
)
parentJoinOutputTable = get_join_output_table_name(src.joinSource.join, True)
result = [wait_for_simple_schema(parentJoinOutputTable, lag, start, end)]
else:
result = [wait_for_simple_schema(src.events.table, lag, start, end)]
@@ -316,31 +312,17 @@ def get_bootstrap_dependencies(bootstrap_parts) -> List[str]:
dependencies = []
for bootstrap_part in bootstrap_parts:
table = bootstrap_part.table
start = (
bootstrap_part.query.startPartition
if bootstrap_part.query is not None
else None
)
end = (
bootstrap_part.query.endPartition
if bootstrap_part.query is not None
else None
)
start = bootstrap_part.query.startPartition if bootstrap_part.query is not None else None
end = bootstrap_part.query.endPartition if bootstrap_part.query is not None else None
dependencies.append(wait_for_simple_schema(table, 0, start, end))
return [json.dumps(dep) for dep in dependencies]


def get_label_table_dependencies(label_part) -> List[str]:
label_info = [
(label.groupBy.sources, label.groupBy.metaData) for label in label_part.labels
]
label_info = [
(source, meta_data) for (sources, meta_data) in label_info for source in sources
]
label_info = [(label.groupBy.sources, label.groupBy.metaData) for label in label_part.labels]
label_info = [(source, meta_data) for (sources, meta_data) in label_info for source in sources]
label_dependencies = [
dep
for (source, meta_data) in label_info
for dep in get_dependencies(src=source, meta_data=meta_data)
dep for (source, meta_data) in label_info for dep in get_dependencies(src=source, meta_data=meta_data)
]
label_dependencies.append(
json.dumps(
@@ -360,9 +342,7 @@ def wait_for_simple_schema(table, lag, start, end):
clean_name = table_tokens[0]
subpartition_spec = "/".join(table_tokens[1:]) if len(table_tokens) > 1 else ""
return {
"name": "wait_for_{}_ds{}".format(
clean_name, "" if lag == 0 else f"_minus_{lag}"
),
"name": "wait_for_{}_ds{}".format(clean_name, "" if lag == 0 else f"_minus_{lag}"),
"spec": "{}/ds={}{}".format(
clean_name,
"{{ ds }}" if lag == 0 else "{{{{ macros.ds_add(ds, -{}) }}}}".format(lag),
@@ -388,8 +368,7 @@ def dedupe_in_order(seq):
def has_topic(group_by: api.GroupBy) -> bool:
"""Find if there's topic or mutationTopic for a source helps define streaming tasks"""
return any(
(source.entities and source.entities.mutationTopic)
or (source.events and source.events.topic)
(source.entities and source.entities.mutationTopic) or (source.events and source.events.topic)
for source in group_by.sources
)

@@ -468,3 +447,20 @@ def get_related_table_names(conf: ChrononJobTypes) -> List[str]:
related_tables.append(f"{table_name}_bootstrap")

return related_tables


class DotDict(dict):
    def __getattr__(self, attr):
        if attr in self:
            value = self[attr]
            return DotDict(value) if isinstance(value, dict) else value
        return None


def convert_json_to_obj(d):
    if isinstance(d, dict):
        return DotDict({k: convert_json_to_obj(v) for k, v in d.items()})
    elif isinstance(d, list):
        return [convert_json_to_obj(item) for item in d]
    else:
        return d
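
DotDict and convert_json_to_obj are what let JoinBackfill read a join config straight from JSON and still use attribute access (for example self.join.joinParts). A small self-contained illustration (the config snippet is made up):

    import json
    from ai.chronon.utils import convert_json_to_obj

    conf = convert_json_to_obj(json.loads('{"metaData": {"name": "team.my_join"}, "joinParts": []}'))
    print(conf.metaData.name)     # team.my_join
    print(conf.joinParts)         # []
    print(conf.missing_field)     # None instead of an AttributeError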
