Skip to content

Commit

Permalink
Fix updating variables during variable imports (#33932)
Browse files Browse the repository at this point in the history
* Fix updating variables during variable imports

We should only create new variables during variable imports and not update
already existing variables.

* Apply suggestions from code review

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

* Use flag for variable import in cli and UI

* apply suggestions from code review

* Update airflow/cli/commands/variable_command.py

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>

* fixup! Update airflow/cli/commands/variable_command.py

---------

Co-authored-by: Tzu-ping Chung <uranusjr@gmail.com>
  • Loading branch information
ephraimbuddy and uranusjr committed Sep 1, 2023
1 parent c4967b0 commit 0e1c106
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 4 deletions.
8 changes: 7 additions & 1 deletion airflow/cli/cli_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -557,6 +557,12 @@ def string_lower_type(val):
help="Export all variables to JSON file",
type=argparse.FileType("w", encoding="UTF-8"),
)
ARG_VAR_ACTION_ON_EXISTING_KEY = Arg(
("-a", "--action-on-existing-key"),
help="Action to take if we encounter a variable key that already exists.",
default="overwrite",
choices=("overwrite", "fail", "skip"),
)

# kerberos
ARG_PRINCIPAL = Arg(("principal",), help="kerberos principal", nargs="?")
Expand Down Expand Up @@ -1454,7 +1460,7 @@ class GroupCommand(NamedTuple):
name="import",
help="Import variables",
func=lazy_load_command("airflow.cli.commands.variable_command.variables_import"),
args=(ARG_VAR_IMPORT, ARG_VERBOSE),
args=(ARG_VAR_IMPORT, ARG_VAR_ACTION_ON_EXISTING_KEY, ARG_VERBOSE),
),
ActionCommand(
name="export",
Expand Down
20 changes: 18 additions & 2 deletions airflow/cli/commands/variable_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
from airflow.utils import cli as cli_utils
from airflow.utils.cli import suppress_logs_and_warning
from airflow.utils.providers_configuration_loader import providers_configuration_loaded
from airflow.utils.session import create_session
from airflow.utils.session import create_session, provide_session


@suppress_logs_and_warning
Expand Down Expand Up @@ -76,7 +76,8 @@ def variables_delete(args):

@cli_utils.action_cli
@providers_configuration_loaded
def variables_import(args):
@provide_session
def variables_import(args, session):
"""Import variables from a given file."""
if not os.path.exists(args.file):
raise SystemExit("Missing variables file.")
Expand All @@ -86,7 +87,17 @@ def variables_import(args):
except JSONDecodeError:
raise SystemExit("Invalid variables file.")
suc_count = fail_count = 0
skipped = set()
action_on_existing = args.action_on_existing_key
existing_keys = set()
if action_on_existing != "overwrite":
existing_keys = set(session.scalars(select(Variable.key).where(Variable.key.in_(var_json))))
if action_on_existing == "fail" and existing_keys:
raise SystemExit(f"Failed. These keys: {sorted(existing_keys)} already exists.")
for k, v in var_json.items():
if action_on_existing == "skip" and k in existing_keys:
skipped.add(k)
continue
try:
Variable.set(k, v, serialize_json=not isinstance(v, str))
except Exception as e:
Expand All @@ -97,6 +108,11 @@ def variables_import(args):
print(f"{suc_count} of {len(var_json)} variables successfully updated.")
if fail_count:
print(f"{fail_count} variable(s) failed to be updated.")
if skipped:
print(
f"The variables with these keys: {list(sorted(skipped))} "
f"were skipped because they already exists"
)


@providers_configuration_loaded
Expand Down
12 changes: 12 additions & 0 deletions airflow/www/templates/airflow/variable_list.html
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@
<div class="form-group">
<input class="form-control-file" type="file" name="file">
</div>
<div class="form-group form-check">
<input type="radio" class="form-check-input" name="action_if_exists" value="overwrite" checked/>
<label class="form-check-label">Overwrite if exists</label>
</div>
<div class="form-group form-check">
<input type="radio" class="form-check-input" name="action_if_exists" value="fail"/>
<label class="form-check-label">Fail if exists</label>
</div>
<div class="form-group form-check">
<input type="radio" class="form-check-input" name="action_if_exists" value="skip" />
<label class="form-check-label">Skip if exists</label>
</div>
<button type="submit" class="btn">
<span class="material-icons">cloud_upload</span>
Import Variables
Expand Down
26 changes: 25 additions & 1 deletion airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -5137,17 +5137,34 @@ def action_varexport(self, items):
@expose("/varimport", methods=["POST"])
@auth.has_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_VARIABLE)])
@action_logging(event=f"{permissions.RESOURCE_VARIABLE.lower()}.varimport")
def varimport(self):
@provide_session
def varimport(self, session):
"""Import variables."""
try:
variable_dict = json.loads(request.files["file"].read())
action_on_existing = request.form.get("action_if_exists", "overwrite").lower()
except Exception:
self.update_redirect()
flash("Missing file or syntax error.", "error")
return redirect(self.get_redirect())
else:
existing_keys = set()
if action_on_existing != "overwrite":
existing_keys = set(
session.scalars(select(models.Variable.key).where(models.Variable.key.in_(variable_dict)))
)
if action_on_existing == "fail" and existing_keys:
failed_repr = ", ".join(repr(k) for k in sorted(existing_keys))
flash(f"Failed. The variables with these keys: {failed_repr} already exists.")
logging.error(f"Failed. The variables with these keys: {failed_repr} already exists.")
return redirect(location=request.referrer)
skipped = set()
suc_count = fail_count = 0
for k, v in variable_dict.items():
if action_on_existing == "skip" and k in existing_keys:
logging.warning("Variable: %s already exists, skipping.", k)
skipped.add(k)
continue
try:
models.Variable.set(k, v, serialize_json=not isinstance(v, str))
except Exception as exc:
Expand All @@ -5158,6 +5175,13 @@ def varimport(self):
flash(f"{suc_count} variable(s) successfully updated.")
if fail_count:
flash(f"{fail_count} variable(s) failed to be updated.", "error")
if skipped:
skipped_repr = ", ".join(repr(k) for k in sorted(skipped))
flash(
f"The variables with these keys: {skipped_repr} were skipped "
"because they already exists",
"warning",
)
self.update_redirect()
return redirect(self.get_redirect())

Expand Down
18 changes: 18 additions & 0 deletions tests/cli/commands/test_variable_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,24 @@ def test_variables_set_different_types(self):
assert Variable.get("false", deserialize_json=True) is False
assert Variable.get("null", deserialize_json=True) is None

# test variable import skip existing
# set varliable list to ["airflow"] and have it skip during import
variable_command.variables_set(self.parser.parse_args(["variables", "set", "list", '["airflow"]']))
variable_command.variables_import(
self.parser.parse_args(
["variables", "import", "variables_types.json", "--action-on-existing-key", "skip"]
)
)
assert ["airflow"] == Variable.get("list", deserialize_json=True) # should not be overwritten

# test variable import fails on existing when action is set to fail
with pytest.raises(SystemExit):
variable_command.variables_import(
self.parser.parse_args(
["variables", "import", "variables_types.json", "--action-on-existing-key", "fail"]
)
)

os.remove("variables_types.json")

def test_variables_list(self):
Expand Down
49 changes: 49 additions & 0 deletions tests/www/views/test_views_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,55 @@ def test_import_variables_success(session, admin_client):
_check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None)


def test_import_variables_override_existing_variables_if_set(session, admin_client, caplog):
assert session.query(Variable).count() == 0
Variable.set("str_key", "str_value")
content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists
bytes_content = io.BytesIO(bytes(content, encoding="utf-8"))

resp = admin_client.post(
"/variable/varimport",
data={"file": (bytes_content, "test.json"), "action_if_exist": "overwrite"},
follow_redirects=True,
)
check_content_in_response("2 variable(s) successfully updated.", resp)
_check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None)


def test_import_variables_skips_update_if_set(session, admin_client, caplog):
assert session.query(Variable).count() == 0
Variable.set("str_key", "str_value")
content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists
bytes_content = io.BytesIO(bytes(content, encoding="utf-8"))

resp = admin_client.post(
"/variable/varimport",
data={"file": (bytes_content, "test.json"), "action_if_exists": "skip"},
follow_redirects=True,
)
check_content_in_response("1 variable(s) successfully updated.", resp)

check_content_in_response(
"The variables with these keys: &#39;str_key&#39; were skipped because they already exists", resp
)
_check_last_log(session, dag_id=None, event="variables.varimport", execution_date=None)
assert "Variable: str_key already exists, skipping." in caplog.text


def test_import_variables_fails_if_action_if_exists_is_fail(session, admin_client, caplog):
assert session.query(Variable).count() == 0
Variable.set("str_key", "str_value")
content = '{"str_key": "str_value", "int_key": 60}' # str_key already exists
bytes_content = io.BytesIO(bytes(content, encoding="utf-8"))

admin_client.post(
"/variable/varimport",
data={"file": (bytes_content, "test.json"), "action_if_exists": "fail"},
follow_redirects=True,
)
assert "Failed. The variables with these keys: 'str_key' already exists." in caplog.text


def test_import_variables_anon(session, app):
assert session.query(Variable).count() == 0

Expand Down

0 comments on commit 0e1c106

Please sign in to comment.