Skip to content

Commit

Permalink
check whether AUTH_ROLE_PUBLIC is set in check_authentication (#38924)
Browse files Browse the repository at this point in the history
* fix(security): check whether AUTH_ROLE_PUBLIC is set in check_authentication

* test(api_connexion): ensure the auth_role_public is not set in minimal_app_for_api

* test(endpoints): add test case to each of the endpoints for auth_role_public cases
  • Loading branch information
Lee-W committed Apr 14, 2024
1 parent 3eac977 commit 7b60825
Show file tree
Hide file tree
Showing 10 changed files with 672 additions and 2 deletions.
6 changes: 6 additions & 0 deletions airflow/api_connexion/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,12 @@ def check_authentication() -> None:
response = auth.requires_authentication(Response)()
if response.status_code == 200:
return

# Even if the current_user is anonymous, the AUTH_ROLE_PUBLIC might still have permission.
appbuilder = get_airflow_app().appbuilder
if appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None):
return

# since this handler only checks authentication, not authorization,
# we should always return 401
raise Unauthenticated(headers=response.headers)
Expand Down
15 changes: 14 additions & 1 deletion tests/api_connexion/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ def minimal_app_for_api():
)
def factory():
with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}):
return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore
_app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore
_app.config["AUTH_ROLE_PUBLIC"] = None
return _app

return factory()

Expand All @@ -67,3 +69,14 @@ def dagbag():
)
DagBag(include_examples=True, read_dags_from_db=False).sync_to_db()
return DagBag(include_examples=True, read_dags_from_db=True)


@pytest.fixture
def set_auto_role_public(request):
app = request.getfixturevalue("minimal_app_for_api")
auto_role_public = app.config["AUTH_ROLE_PUBLIC"]
app.config["AUTH_ROLE_PUBLIC"] = request.param

yield

app.config["AUTH_ROLE_PUBLIC"] = auto_role_public
22 changes: 22 additions & 0 deletions tests/api_connexion/endpoints/test_config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,16 @@ def test_should_respond_403_when_expose_config_off(self):
assert response.status_code == 403
assert "chose not to expose" in response.json["detail"]

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code):
response = self.client.get("/api/v1/config", headers={"Accept": "application/json"})

assert response.status_code == expected_status_code


class TestGetValue:
@pytest.fixture(autouse=True)
Expand Down Expand Up @@ -339,3 +349,15 @@ def test_should_respond_403_when_expose_config_off(self):
)
assert response.status_code == 403
assert "chose not to expose" in response.json["detail"]

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code):
response = self.client.get(
"/api/v1/config/section/smtp/option/smtp_mail_from", headers={"Accept": "application/json"}
)

assert response.status_code == expected_status_code
89 changes: 89 additions & 0 deletions tests/api_connexion/endpoints/test_connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,22 @@ def test_should_raise_403_forbidden(self):
)
assert response.status_code == 403

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 204)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session):
connection_model = Connection(conn_id="test-connection", conn_type="test_type")
session.add(connection_model)
session.commit()
conn = session.query(Connection).all()
assert len(conn) == 1

response = self.client.delete("/api/v1/connections/test-connection")

assert response.status_code == expected_status_code


class TestGetConnection(TestConnectionEndpoint):
def test_should_respond_200(self, session):
Expand Down Expand Up @@ -178,6 +194,31 @@ def test_should_raises_401_unauthenticated(self):

assert_401(response)

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session):
connection_model = Connection(
conn_id="test-connection-id",
conn_type="mysql",
description="test description",
host="mysql",
login="login",
schema="testschema",
port=80,
extra='{"param": "value"}',
)
session.add(connection_model)
session.commit()
result = session.query(Connection).all()
assert len(result) == 1

response = self.client.get("/api/v1/connections/test-connection-id")

assert response.status_code == expected_status_code


class TestGetConnections(TestConnectionEndpoint):
def test_should_respond_200(self, session):
Expand Down Expand Up @@ -256,6 +297,16 @@ def test_should_raises_401_unauthenticated(self):

assert_401(response)

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code):
response = self.client.get("/api/v1/connections")

assert response.status_code == expected_status_code


class TestGetConnectionsPagination(TestConnectionEndpoint):
@pytest.mark.parametrize(
Expand Down Expand Up @@ -529,6 +580,21 @@ def test_should_raises_401_unauthenticated(self, session):

assert_401(response)

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session):
self._create_connection(session)

response = self.client.patch(
"/api/v1/connections/test-connection-id",
json={"connection_id": "test-connection-id", "conn_type": "test_type", "extra": '{"key": "var"}'},
)

assert response.status_code == expected_status_code


class TestPostConnection(TestConnectionEndpoint):
def test_post_should_respond_200(self, session):
Expand Down Expand Up @@ -610,6 +676,18 @@ def test_should_raises_401_unauthenticated(self):

assert_401(response)

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code):
response = self.client.post(
"/api/v1/connections", json={"connection_id": "test-connection-id", "conn_type": "test_type"}
)

assert response.status_code == expected_status_code


class TestConnection(TestConnectionEndpoint):
@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
Expand Down Expand Up @@ -663,3 +741,14 @@ def test_should_respond_403_by_default(self):
"Testing connections is disabled in Airflow configuration. "
"Contact your deployment admin to enable it."
)

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
@mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"})
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code):
payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"}
response = self.client.post("/api/v1/connections/test", json=payload)
assert response.status_code == expected_status_code
99 changes: 99 additions & 0 deletions tests/api_connexion/endpoints/test_dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,6 +317,24 @@ def test_should_respond_400_with_not_exists_fields(self, fields):
)
assert response.status_code == 400, f"Current code: {response.status_code}"

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session):
dag_model = DagModel(
dag_id="TEST_DAG_1",
fileloc="/tmp/dag_1.py",
schedule_interval=None,
is_paused=False,
)
session.add(dag_model)
session.commit()

response = self.client.get("/api/v1/dags/TEST_DAG_1")
assert response.status_code == expected_status_code


class TestGetDagDetails(TestDagEndpoint):
def test_should_respond_200(self, url_safe_serializer):
Expand Down Expand Up @@ -728,6 +746,18 @@ def test_should_respond_400_with_not_exists_fields(self):
)
assert response.status_code == 400, f"Current code: {response.status_code}"

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, url_safe_serializer):
self._create_dag_model_for_details_endpoint(self.dag_id)
url_safe_serializer.dumps("/tmp/dag.py")
response = self.client.get(f"/api/v1/dags/{self.dag_id}/details")

assert response.status_code == expected_status_code


class TestGetDags(TestDagEndpoint):
@provide_session
Expand Down Expand Up @@ -1259,6 +1289,22 @@ def test_should_respond_400_with_not_exists_fields(self):

assert response.status_code == 400, f"Current code: {response.status_code}"

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session):
self._create_dag_models(2)
self._create_deactivated_dag()

dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
assert len(dags_query.all()) == 3

response = self.client.get("api/v1/dags")

assert response.status_code == expected_status_code


class TestPatchDag(TestDagEndpoint):
def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, session):
Expand Down Expand Up @@ -1485,6 +1531,24 @@ def test_should_respond_403_unauthorized(self):

assert response.status_code == 403

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(
self, set_auto_role_public, expected_status_code, url_safe_serializer, session
):
url_safe_serializer.dumps("/tmp/dag_1.py")
dag_model = self._create_dag_model()
payload = {"is_paused": False}
response = self.client.patch(
f"/api/v1/dags/{dag_model.dag_id}",
json=payload,
)

assert response.status_code == expected_status_code


class TestPatchDags(TestDagEndpoint):
@provide_session
Expand Down Expand Up @@ -2291,6 +2355,29 @@ def test_should_respons_400_dag_id_pattern_missing(self):
)
assert response.status_code == 400

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 200)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(
self, set_auto_role_public, expected_status_code, session, url_safe_serializer
):
url_safe_serializer.dumps("/tmp/dag_1.py")
url_safe_serializer.dumps("/tmp/dag_2.py")
self._create_dag_models(2)
self._create_deactivated_dag()

dags_query = session.query(DagModel).filter(~DagModel.is_subdag)
assert len(dags_query.all()) == 3

response = self.client.patch(
"/api/v1/dags?dag_id_pattern=~",
json={"is_paused": False},
)

assert response.status_code == expected_status_code


class TestDeleteDagEndpoint(TestDagEndpoint):
def test_that_dag_can_be_deleted(self, session):
Expand Down Expand Up @@ -2342,3 +2429,15 @@ def test_users_without_delete_permission_cannot_delete_dag(self):
environ_overrides={"REMOTE_USER": "test_no_permissions"},
)
assert response.status_code == 403

@pytest.mark.parametrize(
"set_auto_role_public, expected_status_code",
(("Public", 403), ("Admin", 204)),
indirect=["set_auto_role_public"],
)
def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code):
self._create_dag_models(1)

response = self.client.delete("/api/v1/dags/TEST_DAG_1")

assert response.status_code == expected_status_code

0 comments on commit 7b60825

Please sign in to comment.