-
Notifications
You must be signed in to change notification settings - Fork 177
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Modify add_license_url
DAG to use batched_update
#4370
Changes from 5 commits
f5f52dd
ac25e18
ca08682
fb33638
5ba6bec
5e11841
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -16,14 +16,15 @@ | |
from airflow.exceptions import AirflowSkipException | ||
from airflow.models.abstractoperator import AbstractOperator | ||
from airflow.models.param import Param | ||
from airflow.utils.state import State | ||
from airflow.utils.trigger_rule import TriggerRule | ||
from airflow.operators.trigger_dagrun import TriggerDagRunOperator | ||
from psycopg2._json import Json | ||
from tabulate import tabulate | ||
|
||
from common import slack | ||
from common.constants import DAG_DEFAULT_ARGS, POSTGRES_CONN_ID | ||
from common.licenses import get_license_info_from_license_pair | ||
from common.sql import RETURN_ROW_COUNT, PostgresHook | ||
from common.sql import PostgresHook | ||
from database.batched_update.constants import DAG_ID as BATCHED_UPDATE_DAG_ID | ||
|
||
|
||
DAG_ID = "add_license_url" | ||
|
@@ -54,148 +55,106 @@ def run_sql( | |
|
||
|
||
@task | ||
def get_license_groups(query: str, ti=None) -> list[tuple[str, str]]: | ||
def get_licenses(query: str, ti=None) -> list[tuple[str, str, str]]: | ||
""" | ||
Get license groups of rows that don't have a `license_url` in their | ||
`meta_data` field. | ||
`meta_data` field and notify the start of the DAG. | ||
|
||
:return: List of (license, version) tuples. | ||
""" | ||
license_groups = run_sql(query, dag_task=ti.task) | ||
|
||
total_nulls = sum(group[2] for group in license_groups) | ||
licenses_detailed = "\n".join( | ||
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups | ||
licenses, invalid = [], [] | ||
headers = ["license", "version", "count"] | ||
tabulate_params = { | ||
"headers": headers, | ||
"showindex": True, | ||
"tablefmt": "rounded_grid", | ||
"floatfmt": ".1f", | ||
"intfmt": ",", | ||
} | ||
|
||
for row in license_groups: | ||
license_, license_version, _ = row | ||
license_info = get_license_info_from_license_pair(license_, license_version) | ||
if license_info is None: | ||
invalid.append(row) | ||
else: | ||
licenses.append(license_info) | ||
|
||
license_groups = [lg for lg in license_groups if lg not in invalid] | ||
|
||
message = ( | ||
f""" | ||
Starting `{DAG_ID}` DAG. Found {len(license_groups):.0f} license groups with {total_nulls:.0f} | ||
records to back fill `license_url` in `meta_data`.\nCount per license-version: | ||
``` | ||
{tabulate(license_groups, **tabulate_params)} | ||
``` | ||
""" | ||
if license_groups | ||
else f""" | ||
No license groups found with records missing `license_url` in `meta_data`. The `{DAG_ID}` DAG is done. | ||
""" | ||
) | ||
|
||
message = f""" | ||
Starting `{DAG_ID}` DAG. Found {len(license_groups)} license groups with {total_nulls} | ||
records without `license_url` in `meta_data` left.\nCount per license-version: | ||
{licenses_detailed} | ||
""" | ||
if invalid: | ||
message += f""" | ||
\nThe following *invalid license(s)* were found and will be skipped: | ||
``` | ||
{tabulate(invalid, **tabulate_params)} | ||
``` | ||
""" | ||
|
||
slack.send_message( | ||
message, | ||
username="Airflow DAG Data Normalization - license_url", | ||
dag_id=DAG_ID, | ||
) | ||
|
||
return [(group[0], group[1]) for group in license_groups] | ||
|
||
return licenses | ||
|
||
@task(max_active_tis_per_dag=1, execution_timeout=timedelta(hours=36)) | ||
def update_license_url(license_group: tuple[str, str], batch_size: int, ti=None) -> int: | ||
""" | ||
Add license_url to meta_data batching all records with the same license. | ||
|
||
:param license_group: tuple of license and version | ||
:param batch_size: number of records to update in one update statement | ||
:param ti: automatically passed by Airflow, used to set the execution timeout. | ||
""" | ||
license_, version = license_group | ||
license_info = get_license_info_from_license_pair(license_, version) | ||
if license_info is None: | ||
raise AirflowSkipException( | ||
f"No license pair ({license_}, {version}) in the license map." | ||
) | ||
*_, license_url = license_info | ||
|
||
logging.info( | ||
f"Will add `license_url` in `meta_data` for records with license " | ||
f"{license_} {version} to {license_url}." | ||
) | ||
def get_license_conf(license_info) -> dict: | ||
license_, license_version, license_url = license_info | ||
license_url_dict = {"license_url": license_url} | ||
query_id = f"add_license_url_{license_}_{license_version}" | ||
for char_to_remove in [".", "-"]: | ||
query_id = query_id.replace(char_to_remove, "_") | ||
|
||
conf = { | ||
"query_id": query_id, | ||
"table_name": "image", | ||
"select_query": ( | ||
f"WHERE license = '{license_}' AND license_version = '{license_version}' " | ||
f"AND meta_data->>'license_url' IS NULL" | ||
), | ||
# Merge existing metadata with the new license_url | ||
"update_query": f"SET meta_data = ({Json(license_url_dict)}::jsonb || meta_data), updated_on = now()", | ||
"update_timeout": 259200, # 3 days in seconds | ||
"dry_run": False, | ||
"resume_update": False, | ||
} | ||
return conf | ||
|
||
# Merge existing metadata with the new license_url | ||
update_query = dedent( | ||
f""" | ||
UPDATE image | ||
SET meta_data = ({Json(license_url_dict)}::jsonb || meta_data), updated_on = now() | ||
WHERE identifier IN ( | ||
SELECT identifier | ||
FROM image | ||
WHERE license = '{license_}' AND license_version = '{version}' | ||
AND meta_data->>'license_url' IS NULL | ||
LIMIT {batch_size} | ||
FOR UPDATE SKIP LOCKED | ||
); | ||
""" | ||
) | ||
total_updated = 0 | ||
updated_count = 1 | ||
while updated_count: | ||
updated_count = run_sql( | ||
update_query, | ||
log_sql=total_updated == 0, | ||
method="run", | ||
handler=RETURN_ROW_COUNT, | ||
autocommit=True, | ||
dag_task=ti.task, | ||
) | ||
total_updated += updated_count | ||
logger.info(f"Updated {total_updated} rows with {license_url}.") | ||
|
||
return total_updated | ||
|
||
|
||
@task(trigger_rule=TriggerRule.ALL_DONE) | ||
def report_completion(updated, query: str, ti=None): | ||
""" | ||
Check for null in `meta_data` and send a message to Slack with the statistics | ||
of the DAG run. | ||
|
||
:param updated: total number of records updated | ||
:param query: SQL query to get the count of records left with `license_url` as NULL | ||
:param ti: automatically passed by Airflow, used to set the execution timeout. | ||
""" | ||
total_updated = sum(updated) if updated else 0 | ||
|
||
license_groups = run_sql(query, dag_task=ti.task) | ||
total_nulls = sum(group[2] for group in license_groups) | ||
licenses_detailed = "\n".join( | ||
f"{group[0]} \t{group[1]} \t{group[2]}" for group in license_groups | ||
) | ||
|
||
message = f""" | ||
`{DAG_ID}` DAG run completed. Updated {total_updated} record(s) with `license_url` in the | ||
`meta_data` field. Found {len(license_groups)} license groups with {total_nulls} record(s) left pending. | ||
""" | ||
if total_nulls != 0: | ||
message += f"\nCount per license-version:\n{licenses_detailed}" | ||
|
||
slack.send_message( | ||
message, | ||
username="Airflow DAG Data Normalization - license_url", | ||
dag_id=DAG_ID, | ||
) | ||
|
||
@task | ||
def get_confs(licenses, batch_size: int) -> list[dict]: | ||
if not licenses: | ||
raise AirflowSkipException("No config required.") | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This sounds confusing. What does "No config required." mean here? Should it be the opposite, "License config required."? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If there are no licenses to backfill, then the DAG stops here. There is no need to create a set of configurations for the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Definitely clearer :) Thank you! |
||
|
||
@task(trigger_rule=TriggerRule.ALL_DONE) | ||
def report_failed_license_pairs(dag_run=None): | ||
""" | ||
Send a message to Slack with the license-version pairs that could not be found | ||
in the license map. | ||
""" | ||
skipped_tasks = [ | ||
dag_task | ||
for dag_task in dag_run.get_task_instances(state=State.SKIPPED) | ||
if "update_license_url" in dag_task.task_id | ||
return [ | ||
{"batch_size": batch_size, **get_license_conf(license_info)} | ||
for license_info in licenses | ||
] | ||
|
||
if not skipped_tasks: | ||
raise AirflowSkipException | ||
|
||
message = ( | ||
f""" | ||
One or more license pairs could not be found in the license map while running | ||
the `{DAG_ID}` DAG. See the logs for more details: | ||
""" | ||
) + "\n".join( | ||
f" - <{dag_task.log_url}|{dag_task.task_id}>" for dag_task in skipped_tasks[:5] | ||
) | ||
|
||
slack.send_alert( | ||
message, | ||
username="Airflow DAG Data Normalization - license_url", | ||
@task | ||
def notify_slack(): | ||
slack.send_message( | ||
"Finished processing the groups of licenses.", | ||
username=f"Airflow DAG Data Normalization - {DAG_ID}", | ||
dag_id=DAG_ID, | ||
) | ||
|
||
|
@@ -227,12 +186,19 @@ def add_license_url(): | |
GROUP BY license, license_version | ||
""") | ||
|
||
license_groups = get_license_groups(query) | ||
updated = update_license_url.partial(batch_size="{{ params.batch_size }}").expand( | ||
license_group=license_groups | ||
) | ||
report_completion(updated, query) | ||
updated >> report_failed_license_pairs() | ||
licenses = get_licenses(query) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: I think it would be easier to understand the code flow if the query is moved inside There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Right, the previous version of the DAG reused this query for two tasks. Now this can be in there as you suggest 👍 |
||
|
||
trigger = TriggerDagRunOperator.partial( | ||
task_id="trigger_batched_update", | ||
trigger_dag_id=BATCHED_UPDATE_DAG_ID, | ||
wait_for_completion=True, | ||
execution_timeout=timedelta(hours=5), | ||
max_active_tis_per_dag=1, | ||
map_index_template="""{{ task.conf['query_id'] }}""", | ||
retries=0, | ||
).expand(conf=get_confs(licenses, batch_size="{{ params.batch_size }}")) | ||
|
||
trigger >> notify_slack() | ||
|
||
|
||
add_license_url() |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,4 +8,5 @@ psycopg2-binary | |
requests-file==2.0.* | ||
requests-oauthlib | ||
retry==0.9.2 | ||
tabulate==0.9.0 | ||
tldextract==5.1.2 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could maybe be better expressed as: