Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify add_license_url DAG to use batched_update #4370

Merged
merged 6 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
218 changes: 92 additions & 126 deletions catalog/dags/maintenance/add_license_url.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
from airflow.exceptions import AirflowSkipException
from airflow.models.abstractoperator import AbstractOperator
from airflow.models.param import Param
from airflow.utils.state import State
from airflow.utils.trigger_rule import TriggerRule
from airflow.operators.trigger_dagrun import TriggerDagRunOperator
from psycopg2._json import Json
from tabulate import tabulate

from common import slack
from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID
from common.licenses import get_license_info_from_license_pair
from common.sql import RETURN_ROW_COUNT, PostgresHook
from common.sql import PostgresHook
from database.batched_update.constants import DAG_ID as BATCHED_UPDATE_DAG_ID


DAG_ID = "add_license_url"
Expand Down Expand Up @@ -54,148 +55,106 @@ def run_sql(


@task
def get_license_groups(query: str, ti=None) -> list[tuple[str, str]]:
def get_licenses(query: str, ti=None) -> list[tuple[str, str, str]]:
"""
Get license groups of rows that don't have a `license_url` in their
`meta_data` field.
`meta_data` field and notify the start of the DAG.

:return: List of (license, license_version, license_url) tuples for the valid license pairs.
"""
license_groups = run_sql(query, dag_task=ti.task)

total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
licenses, invalid = [], []
headers = ["license", "version", "count"]
tabulate_params = {
"headers": headers,
"showindex": True,
"tablefmt": "rounded_grid",
"floatfmt": ".1f",
"intfmt": ",",
}

for row in license_groups:
license_, license_version, _ = row
license_info = get_license_info_from_license_pair(license_, license_version)
if license_info is None:
invalid.append(row)
else:
licenses.append(license_info)

license_groups = [lg for lg in license_groups if lg not in invalid]

message = (
f"""
Starting `{DAG_ID}` DAG. Found {len(license_groups):.0f} license groups with {total_nulls:.0f}
records to back fill `license_url` in `meta_data`.\nCount per license-version:
```
{tabulate(license_groups, **tabulate_params)}
```
"""
if license_groups
else f"""
No license groups found with records missing `license_url` in `meta_data`. The `{DAG_ID}` DAG is done.
"""
)

message = f"""
Starting `{DAG_ID}` DAG. Found {len(license_groups)} license groups with {total_nulls}
records without `license_url` in `meta_data` left.\nCount per license-version:
{licenses_detailed}
"""
if invalid:
message += f"""
\nThe following *invalid license(s)* were found and will be skipped:
```
{tabulate(invalid, **tabulate_params)}
```
"""

slack.send_message(
message,
username="Airflow DAG Data Normalization - license_url",
dag_id=DAG_ID,
)

return [(group[0], group[1]) for group in license_groups]

return licenses

@task(max_active_tis_per_dag=1, execution_timeout=timedelta(hours=36))
def update_license_url(license_group: tuple[str, str], batch_size: int, ti=None) -> int:
"""
Add license_url to meta_data batching all records with the same license.

:param license_group: tuple of license and version
:param batch_size: number of records to update in one update statement
:param ti: automatically passed by Airflow, used to set the execution timeout.
"""
license_, version = license_group
license_info = get_license_info_from_license_pair(license_, version)
if license_info is None:
raise AirflowSkipException(
f"No license pair ({license_}, {version}) in the license map."
)
*_, license_url = license_info

logging.info(
f"Will add `license_url` in `meta_data` for records with license "
f"{license_} {version} to {license_url}."
)
def get_license_conf(license_info) -> dict:
license_, license_version, license_url = license_info
license_url_dict = {"license_url": license_url}
query_id = f"add_license_url_{license_}_{license_version}"
for char_to_remove in [".", "-"]:
query_id = query_id.replace(char_to_remove, "_")

conf = {
"query_id": query_id,
"table_name": "image",
"select_query": (
f"WHERE license = '{license_}' AND license_version = '{license_version}' "
f"AND meta_data->>'license_url' IS NULL"
),
# Merge existing metadata with the new license_url
"update_query": f"SET meta_data = ({Json(license_url_dict)}::jsonb || meta_data), updated_on = now()",
"update_timeout": 259200, # 3 days in seconds
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could maybe be better expressed as:

Suggested change
"update_timeout": 259200, # 3 days in seconds
"update_timeout": 60 * 60 * 24 * 3, # 3 days in seconds

"dry_run": False,
"resume_update": False,
}
return conf

# Merge existing metadata with the new license_url
update_query = dedent(
f"""
UPDATE image
SET meta_data = ({Json(license_url_dict)}::jsonb || meta_data), updated_on = now()
WHERE identifier IN (
SELECT identifier
FROM image
WHERE license = '{license_}' AND license_version = '{version}'
AND meta_data->>'license_url' IS NULL
LIMIT {batch_size}
FOR UPDATE SKIP LOCKED
);
"""
)
total_updated = 0
updated_count = 1
while updated_count:
updated_count = run_sql(
update_query,
log_sql=total_updated == 0,
method="run",
handler=RETURN_ROW_COUNT,
autocommit=True,
dag_task=ti.task,
)
total_updated += updated_count
logger.info(f"Updated {total_updated} rows with {license_url}.")

return total_updated


@task(trigger_rule=TriggerRule.ALL_DONE)
def report_completion(updated, query: str, ti=None):
"""
Check for null in `meta_data` and send a message to Slack with the statistics
of the DAG run.

:param updated: total number of records updated
:param query: SQL query to get the count of records left with `license_url` as NULL
:param ti: automatically passed by Airflow, used to set the execution timeout.
"""
total_updated = sum(updated) if updated else 0

license_groups = run_sql(query, dag_task=ti.task)
total_nulls = sum(group[2] for group in license_groups)
licenses_detailed = "\n".join(
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups
)

message = f"""
`{DAG_ID}` DAG run completed. Updated {total_updated} record(s) with `license_url` in the
`meta_data` field. Found {len(license_groups)} license groups with {total_nulls} record(s) left pending.
"""
if total_nulls != 0:
message += f"\nCount per license-version:\n{licenses_detailed}"

slack.send_message(
message,
username="Airflow DAG Data Normalization - license_url",
dag_id=DAG_ID,
)

@task
def get_confs(licenses, batch_size: int) -> list[dict]:
if not licenses:
raise AirflowSkipException("No config required.")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This sounds confusing. What does "No config required." mean here? Should it be the opposite, "License config required."?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there are no licenses to backfill, then the DAG stops here. There is no need to create a set of configurations for the batched_update DAG. I rephrased it; I hope it's clearer now!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely clearer :) Thank you!


@task(trigger_rule=TriggerRule.ALL_DONE)
def report_failed_license_pairs(dag_run=None):
"""
Send a message to Slack with the license-version pairs that could not be found
in the license map.
"""
skipped_tasks = [
dag_task
for dag_task in dag_run.get_task_instances(state=State.SKIPPED)
if "update_license_url" in dag_task.task_id
return [
{"batch_size": batch_size, **get_license_conf(license_info)}
for license_info in licenses
]

if not skipped_tasks:
raise AirflowSkipException

message = (
f"""
One or more license pairs could not be found in the license map while running
the `{DAG_ID}` DAG. See the logs for more details:
"""
) + "\n".join(
f" - <{dag_task.log_url}|{dag_task.task_id}>" for dag_task in skipped_tasks[:5]
)

slack.send_alert(
message,
username="Airflow DAG Data Normalization - license_url",
@task
def notify_slack():
    """Send a Slack message saying the license groups have been processed."""
    # Build the arguments up front so the call itself stays on one line.
    slack_username = f"Airflow DAG Data Normalization - {DAG_ID}"
    completion_text = "Finished processing the groups of licenses."
    slack.send_message(completion_text, username=slack_username, dag_id=DAG_ID)

Expand Down Expand Up @@ -227,12 +186,19 @@ def add_license_url():
GROUP BY license, license_version
""")

license_groups = get_license_groups(query)
updated = update_license_url.partial(batch_size="{{ params.batch_size }}").expand(
license_group=license_groups
)
report_completion(updated, query)
updated >> report_failed_license_pairs()
licenses = get_licenses(query)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: I think it would be easier to understand the code flow if the query were moved inside the `get_licenses` function.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, the previous version of the DAG reused this query for two tasks. Now this can be in there as you suggest 👍


trigger = TriggerDagRunOperator.partial(
task_id="trigger_batched_update",
trigger_dag_id=BATCHED_UPDATE_DAG_ID,
wait_for_completion=True,
execution_timeout=timedelta(hours=5),
max_active_tis_per_dag=1,
map_index_template="""{{ task.conf['query_id'] }}""",
retries=0,
).expand(conf=get_confs(licenses, batch_size="{{ params.batch_size }}"))

trigger >> notify_slack()


add_license_url()
1 change: 1 addition & 0 deletions catalog/requirements-prod.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,5 @@ psycopg2-binary
requests-file==2.0.*
requests-oauthlib
retry==0.9.2
tabulate==0.9.0
tldextract==5.1.2