From 7b608250468740954c6b0af7a5f7f23dfa52b473 Mon Sep 17 00:00:00 2001 From: Wei Lee Date: Sun, 14 Apr 2024 18:06:43 +0800 Subject: [PATCH] check whether AUTH_ROLE_PUBLIC is set in check_authentication (#38924) * fix(security): check whether AUTH_ROLE_PUBLIC is set in check_authentication * test(api_connexion): ensure the auth_role_public is not set in minimal_app_for_api * test(endpoints): add test case to each of the endpoints for auth_role_public cases --- airflow/api_connexion/security.py | 6 + tests/api_connexion/conftest.py | 15 +- .../endpoints/test_config_endpoint.py | 22 +++ .../endpoints/test_connection_endpoint.py | 89 +++++++++ .../endpoints/test_dag_endpoint.py | 99 ++++++++++ .../endpoints/test_dag_run_endpoint.py | 185 +++++++++++++++++ .../endpoints/test_dag_source_endpoint.py | 16 ++ .../endpoints/test_dag_warning_endpoint.py | 12 ++ .../endpoints/test_dataset_endpoint.py | 186 +++++++++++++++++- .../endpoints/test_event_log_endpoint.py | 44 +++++ 10 files changed, 672 insertions(+), 2 deletions(-) diff --git a/airflow/api_connexion/security.py b/airflow/api_connexion/security.py index 1cc044d9dd310..660bc6cce2370 100644 --- a/airflow/api_connexion/security.py +++ b/airflow/api_connexion/security.py @@ -49,6 +49,12 @@ def check_authentication() -> None: response = auth.requires_authentication(Response)() if response.status_code == 200: return + + # Even if the current_user is anonymous, the AUTH_ROLE_PUBLIC might still have permission. + appbuilder = get_airflow_app().appbuilder + if appbuilder.get_app.config.get("AUTH_ROLE_PUBLIC", None): + return + # since this handler only checks authentication, not authorization, # we should always return 401 raise Unauthenticated(headers=response.headers) diff --git a/tests/api_connexion/conftest.py b/tests/api_connexion/conftest.py index c860a78f27167..481f07fe7371d 100644 --- a/tests/api_connexion/conftest.py +++ b/tests/api_connexion/conftest.py @@ -40,7 +40,9 @@ def minimal_app_for_api(): ) def factory(): with conf_vars({("api", "auth_backends"): "tests.test_utils.remote_user_api_auth_backend"}): - return app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + _app = app.create_app(testing=True, config={"WTF_CSRF_ENABLED": False}) # type:ignore + _app.config["AUTH_ROLE_PUBLIC"] = None + return _app return factory() @@ -67,3 +69,14 @@ def dagbag(): ) DagBag(include_examples=True, read_dags_from_db=False).sync_to_db() return DagBag(include_examples=True, read_dags_from_db=True) + + +@pytest.fixture +def set_auto_role_public(request): + app = request.getfixturevalue("minimal_app_for_api") + auto_role_public = app.config["AUTH_ROLE_PUBLIC"] + app.config["AUTH_ROLE_PUBLIC"] = request.param + + yield + + app.config["AUTH_ROLE_PUBLIC"] = auto_role_public diff --git a/tests/api_connexion/endpoints/test_config_endpoint.py b/tests/api_connexion/endpoints/test_config_endpoint.py index c091c4ef1c9f3..3dd5814e5d79e 100644 --- a/tests/api_connexion/endpoints/test_config_endpoint.py +++ b/tests/api_connexion/endpoints/test_config_endpoint.py @@ -222,6 +222,16 @@ def test_should_respond_403_when_expose_config_off(self): assert response.status_code == 403 assert "chose not to expose" in response.json["detail"] + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + response = self.client.get("/api/v1/config", headers={"Accept": "application/json"}) + + assert response.status_code == expected_status_code + class TestGetValue: @pytest.fixture(autouse=True) @@ -339,3 +349,15 @@ def test_should_respond_403_when_expose_config_off(self): ) assert response.status_code == 403 assert "chose not to expose" in response.json["detail"] + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + response = self.client.get( + "/api/v1/config/section/smtp/option/smtp_mail_from", headers={"Accept": "application/json"} + ) + + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_connection_endpoint.py b/tests/api_connexion/endpoints/test_connection_endpoint.py index dc0f2893e01ce..c88b8a56de9d5 100644 --- a/tests/api_connexion/endpoints/test_connection_endpoint.py +++ b/tests/api_connexion/endpoints/test_connection_endpoint.py @@ -112,6 +112,22 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 204)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + connection_model = Connection(conn_id="test-connection", conn_type="test_type") + session.add(connection_model) + session.commit() + conn = session.query(Connection).all() + assert len(conn) == 1 + + response = self.client.delete("/api/v1/connections/test-connection") + + assert response.status_code == expected_status_code + class TestGetConnection(TestConnectionEndpoint): def test_should_respond_200(self, session): @@ -178,6 +194,31 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + connection_model = Connection( + conn_id="test-connection-id", + conn_type="mysql", + description="test description", + host="mysql", + login="login", + schema="testschema", + port=80, + extra='{"param": "value"}', + ) + session.add(connection_model) + session.commit() + result = session.query(Connection).all() + assert len(result) == 1 + + response = self.client.get("/api/v1/connections/test-connection-id") + + assert response.status_code == expected_status_code + class TestGetConnections(TestConnectionEndpoint): def test_should_respond_200(self, session): @@ -256,6 +297,16 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + response = self.client.get("/api/v1/connections") + + assert response.status_code == expected_status_code + class TestGetConnectionsPagination(TestConnectionEndpoint): @pytest.mark.parametrize( @@ -529,6 +580,21 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + self._create_connection(session) + + response = self.client.patch( + "/api/v1/connections/test-connection-id", + json={"connection_id": "test-connection-id", "conn_type": "test_type", "extra": '{"key": "var"}'}, + ) + + assert response.status_code == expected_status_code + class TestPostConnection(TestConnectionEndpoint): def test_post_should_respond_200(self, session): @@ -610,6 +676,18 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + response = self.client.post( + "/api/v1/connections", json={"connection_id": "test-connection-id", "conn_type": "test_type"} + ) + + assert response.status_code == expected_status_code + class TestConnection(TestConnectionEndpoint): @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) @@ -663,3 +741,14 @@ def test_should_respond_403_by_default(self): "Testing connections is disabled in Airflow configuration. " "Contact your deployment admin to enable it." ) + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + @mock.patch.dict(os.environ, {"AIRFLOW__CORE__TEST_CONNECTION": "Enabled"}) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + payload = {"connection_id": "test-connection-id", "conn_type": "sqlite"} + response = self.client.post("/api/v1/connections/test", json=payload) + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dag_endpoint.py b/tests/api_connexion/endpoints/test_dag_endpoint.py index 8578f633cf6bb..b514faba276d9 100644 --- a/tests/api_connexion/endpoints/test_dag_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_endpoint.py @@ -317,6 +317,24 @@ def test_should_respond_400_with_not_exists_fields(self, fields): ) assert response.status_code == 400, f"Current code: {response.status_code}" + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + dag_model = DagModel( + dag_id="TEST_DAG_1", + fileloc="/tmp/dag_1.py", + schedule_interval=None, + is_paused=False, + ) + session.add(dag_model) + session.commit() + + response = self.client.get("/api/v1/dags/TEST_DAG_1") + assert response.status_code == expected_status_code + class TestGetDagDetails(TestDagEndpoint): def test_should_respond_200(self, url_safe_serializer): @@ -728,6 +746,18 @@ def test_should_respond_400_with_not_exists_fields(self): ) assert response.status_code == 400, f"Current code: {response.status_code}" + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, url_safe_serializer): + self._create_dag_model_for_details_endpoint(self.dag_id) + url_safe_serializer.dumps("/tmp/dag.py") + response = self.client.get(f"/api/v1/dags/{self.dag_id}/details") + + assert response.status_code == expected_status_code + class TestGetDags(TestDagEndpoint): @provide_session @@ -1259,6 +1289,22 @@ def test_should_respond_400_with_not_exists_fields(self): assert response.status_code == 400, f"Current code: {response.status_code}" + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + self._create_dag_models(2) + self._create_deactivated_dag() + + dags_query = session.query(DagModel).filter(~DagModel.is_subdag) + assert len(dags_query.all()) == 3 + + response = self.client.get("api/v1/dags") + + assert response.status_code == expected_status_code + class TestPatchDag(TestDagEndpoint): def test_should_respond_200_on_patch_is_paused(self, url_safe_serializer, session): @@ -1485,6 +1531,24 @@ def test_should_respond_403_unauthorized(self): assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, url_safe_serializer, session + ): + url_safe_serializer.dumps("/tmp/dag_1.py") + dag_model = self._create_dag_model() + payload = {"is_paused": False} + response = self.client.patch( + f"/api/v1/dags/{dag_model.dag_id}", + json=payload, + ) + + assert response.status_code == expected_status_code + class TestPatchDags(TestDagEndpoint): @provide_session @@ -2291,6 +2355,29 @@ def test_should_respons_400_dag_id_pattern_missing(self): ) assert response.status_code == 400 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, session, url_safe_serializer + ): + url_safe_serializer.dumps("/tmp/dag_1.py") + url_safe_serializer.dumps("/tmp/dag_2.py") + self._create_dag_models(2) + self._create_deactivated_dag() + + dags_query = session.query(DagModel).filter(~DagModel.is_subdag) + assert len(dags_query.all()) == 3 + + response = self.client.patch( + "/api/v1/dags?dag_id_pattern=~", + json={"is_paused": False}, + ) + + assert response.status_code == expected_status_code + class TestDeleteDagEndpoint(TestDagEndpoint): def test_that_dag_can_be_deleted(self, session): @@ -2342,3 +2429,15 @@ def test_users_without_delete_permission_cannot_delete_dag(self): environ_overrides={"REMOTE_USER": "test_no_permissions"}, ) assert response.status_code == 403 + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 204)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + self._create_dag_models(1) + + response = self.client.delete("/api/v1/dags/TEST_DAG_1") + + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dag_run_endpoint.py b/tests/api_connexion/endpoints/test_dag_run_endpoint.py index f6ace160998cc..5182ef427e624 100644 --- a/tests/api_connexion/endpoints/test_dag_run_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_run_endpoint.py @@ -214,6 +214,18 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 204)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + session.add_all(self._create_test_dag_run()) + session.commit() + response = self.client.delete("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID_1") + + assert response.status_code == expected_status_code + class TestGetDagRun(TestDagRunEndpoint): def test_should_respond_200(self, session): @@ -333,6 +345,29 @@ def test_should_respond_400_with_not_exists_fields(self, session): ) assert response.status_code == 400, f"Current code: {response.status_code}" + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + dagrun_model = DagRun( + dag_id="TEST_DAG_ID", + run_id="TEST_DAG_RUN_ID", + run_type=DagRunType.MANUAL, + execution_date=timezone.parse(self.default_time), + start_date=timezone.parse(self.default_time), + external_trigger=True, + state="running", + ) + session.add(dagrun_model) + session.commit() + result = session.query(DagRun).all() + assert len(result) == 1 + + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID") + assert response.status_code == expected_status_code + class TestGetDagRuns(TestDagRunEndpoint): def test_should_respond_200(self, session): @@ -508,6 +543,18 @@ def test_should_respond_400_with_not_exists_fields(self): ) assert response.status_code == 400, f"Current code: {response.status_code}" + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + self._create_test_dag_run() + result = session.query(DagRun).all() + assert len(result) == 2 + response = self.client.get("api/v1/dags/TEST_DAG_ID/dagRuns") + assert response.status_code == expected_status_code + class TestGetDagRunsPagination(TestDagRunEndpoint): @pytest.mark.parametrize( @@ -931,6 +978,18 @@ def test_should_raises_401_unauthenticated(self): assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + self._create_test_dag_run() + + response = self.client.post("api/v1/dags/~/dagRuns/list", json={"dag_ids": ["TEST_DAG_ID"]}) + + assert response.status_code == expected_status_code + class TestGetDagRunBatchPagination(TestDagRunEndpoint): @pytest.mark.parametrize( @@ -1564,6 +1623,26 @@ def test_should_raises_403_unauthorized(self, username): ) assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + execution_date = "2020-11-10T08:25:56.939143+00:00" + logical_date = "2020-11-10T08:25:56.939143+00:00" + self._create_dag("TEST_DAG_ID") + + response = self.client.post( + "api/v1/dags/TEST_DAG_ID/dagRuns", + json={ + "execution_date": execution_date, + "logical_date": logical_date, + }, + ) + + assert response.status_code == expected_status_code + class TestPatchDagRunState(TestDagRunEndpoint): @pytest.mark.parametrize("state", ["failed", "success", "queued"]) @@ -1687,6 +1766,31 @@ def test_should_respond_404(self): ) assert response.status_code == 404 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, dag_maker, session): + dag_id = "TEST_DAG_ID" + dag_run_id = "TEST_DAG_RUN_ID" + with dag_maker(dag_id) as dag: + task = EmptyOperator(task_id="task_id", dag=dag) + self.app.dag_bag.bag_dag(dag, root_dag=dag) + dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=DagRunType.SCHEDULED) + ti = dr.get_task_instance(task_id="task_id") + ti.task = task + ti.state = State.RUNNING + session.merge(ti) + session.commit() + + response = self.client.patch( + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", + json={"state": "failed"}, + ) + + assert response.status_code == expected_status_code + class TestClearDagRun(TestDagRunEndpoint): def test_should_respond_200(self, dag_maker, session): @@ -1822,6 +1926,31 @@ def test_should_respond_404(self): ) assert response.status_code == 404 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, dag_maker, session): + dag_id = "TEST_DAG_ID" + dag_run_id = "TEST_DAG_RUN_ID" + with dag_maker(dag_id) as dag: + task = EmptyOperator(task_id="task_id", dag=dag) + self.app.dag_bag.bag_dag(dag, root_dag=dag) + dr = dag_maker.create_dagrun(run_id=dag_run_id, run_type=DagRunType.SCHEDULED) + ti = dr.get_task_instance(task_id="task_id") + ti.task = task + ti.state = State.RUNNING + session.merge(ti) + session.commit() + + response = self.client.patch( + f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}", + json={"state": "failed"}, + ) + + assert response.status_code == expected_status_code + @pytest.mark.need_serialized_dag class TestGetDagRunDatasetTriggerEvents(TestDagRunEndpoint): @@ -1916,6 +2045,42 @@ def test_should_raises_401_unauthenticated(self, session): assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, dag_maker, session): + dataset1 = Dataset(uri="ds1") + + with dag_maker(dag_id="source_dag", start_date=timezone.utcnow(), session=session): + EmptyOperator(task_id="task", outlets=[dataset1]) + dr = dag_maker.create_dagrun() + ti = dr.task_instances[0] + + ds1_id = session.query(DatasetModel.id).filter_by(uri=dataset1.uri).scalar() + event = DatasetEvent( + dataset_id=ds1_id, + source_task_id=ti.task_id, + source_dag_id=ti.dag_id, + source_run_id=ti.run_id, + source_map_index=ti.map_index, + ) + session.add(event) + + with dag_maker(dag_id="TEST_DAG_ID", start_date=timezone.utcnow(), session=session): + pass + dr = dag_maker.create_dagrun(run_id="TEST_DAG_RUN_ID", run_type=DagRunType.DATASET_TRIGGERED) + dr.consumed_dataset_events.append(event) + + session.commit() + assert event.timestamp + + response = self.client.get( + "api/v1/dags/TEST_DAG_ID/dagRuns/TEST_DAG_RUN_ID/upstreamDatasetEvents", + ) + assert response.status_code == expected_status_code + class TestSetDagRunNote(TestDagRunEndpoint): def test_should_respond_200(self, dag_maker, session): @@ -2046,3 +2211,23 @@ def test_should_respond_200_with_anonymous_user(self, dag_maker, session): json={"note": "I am setting a note with anonymous user"}, ) assert response.status_code == 200 + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + dag_runs: list[DagRun] = self._create_test_dag_run(DagRunState.SUCCESS) + session.add_all(dag_runs) + session.commit() + created_dr: DagRun = dag_runs[0] + new_note_value = "My super cool DagRun notes" + response = self.client.patch( + f"api/v1/dags/{created_dr.dag_id}/dagRuns/{created_dr.run_id}/setNote", + json={"note": new_note_value}, + ) + + session.query(DagRun).filter(DagRun.run_id == created_dr.run_id).first() + + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dag_source_endpoint.py b/tests/api_connexion/endpoints/test_dag_source_endpoint.py index d48d7e1c02fc6..14c7d1534d4dc 100644 --- a/tests/api_connexion/endpoints/test_dag_source_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_source_endpoint.py @@ -202,3 +202,19 @@ def test_should_respond_403_some_dags_not_readable_in_the_file(self, url_safe_se ) assert response.status_code == 403 assert read_dag.status_code == 200 + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, url_safe_serializer): + dagbag = DagBag(dag_folder=EXAMPLE_DAG_FILE) + dagbag.sync_to_db() + test_dag: DAG = dagbag.dags[TEST_DAG_ID] + self._get_dag_file_docstring(test_dag.fileloc) + + url = f"/api/v1/dagSources/{url_safe_serializer.dumps(test_dag.fileloc)}" + response = self.client.get(url, headers={"Accept": "text/plain"}) + + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py index 9310956d24f63..cc398329b9644 100644 --- a/tests/api_connexion/endpoints/test_dag_warning_endpoint.py +++ b/tests/api_connexion/endpoints/test_dag_warning_endpoint.py @@ -170,3 +170,15 @@ def test_should_raise_403_forbidden_when_user_has_no_dag_read_permission(self): query_string={"dag_id": "dag1"}, ) assert response.status_code == 403 + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code): + response = self.client.get( + "/api/v1/dagWarnings", + query_string={"dag_id": "dag1", "warning_type": "non-existent pool"}, + ) + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_dataset_endpoint.py b/tests/api_connexion/endpoints/test_dataset_endpoint.py index a2451fb30ac26..5b6e2f24146e4 100644 --- a/tests/api_connexion/endpoints/test_dataset_endpoint.py +++ b/tests/api_connexion/endpoints/test_dataset_endpoint.py @@ -143,6 +143,22 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.get(f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}") assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + self._create_dataset(session) + assert session.query(DatasetModel).count() == 1 + + with assert_queries_count(5): + response = self.client.get( + f"/api/v1/datasets/{urllib.parse.quote('s3://bucket/key', safe='')}", + ) + + assert response.status_code == expected_status_code + class TestGetDatasets(TestDatasetEndpoint): def test_should_respond_200(self, session): @@ -313,6 +329,31 @@ def test_filter_datasets_by_dag_ids_and_uri_pattern_works( response_data = response.json assert len(response_data["datasets"]) == expected_num + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + datasets = [ + DatasetModel( + id=i, + uri=f"s3://bucket/key/{i}", + extra={"foo": "bar"}, + created_at=timezone.parse(self.default_time), + updated_at=timezone.parse(self.default_time), + ) + for i in [1, 2] + ] + session.add_all(datasets) + session.commit() + assert session.query(DatasetModel).count() == 2 + + with assert_queries_count(8): + response = self.client.get("/api/v1/datasets") + + assert response.status_code == expected_status_code + class TestGetDatasetsEndpointPagination(TestDatasetEndpoint): @pytest.mark.parametrize( @@ -579,6 +620,32 @@ def test_includes_created_dagrun(self, session): "total_entries": 1, } + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + self._create_dataset(session) + common = { + "dataset_id": 1, + "extra": {"foo": "bar"}, + "source_dag_id": "foo", + "source_task_id": "bar", + "source_run_id": "custom", + "source_map_index": -1, + "created_dagruns": [], + } + + events = [DatasetEvent(id=i, timestamp=timezone.parse(self.default_time), **common) for i in [1, 2]] + session.add_all(events) + session.commit() + assert session.query(DatasetEvent).count() == 2 + + response = self.client.get("/api/v1/datasets/events") + + assert response.status_code == expected_status_code + class TestPostDatasetEvents(TestDatasetEndpoint): @pytest.fixture @@ -651,6 +718,19 @@ def test_should_raises_401_unauthenticated(self, session): response = self.client.post("/api/v1/datasets/events", json={"dataset_uri": "TEST_DATASET_URI"}) assert_401(response) + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + @pytest.mark.usefixtures("time_freezer") + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, session): + self._create_dataset(session) + event_payload = {"dataset_uri": "s3://bucket/key", "extra": {"foo": "bar"}} + response = self.client.post("/api/v1/datasets/events", json=event_payload) + + assert response.status_code == expected_status_code + class TestGetDatasetEventsEndpointPagination(TestDatasetEndpoint): @pytest.mark.parametrize( @@ -821,6 +901,27 @@ def test_should_raise_403_forbidden(self, session): assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + @pytest.mark.usefixtures("time_freezer") + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, create_dummy_dag, session + ): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_dataset(session).id + self._create_dataset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + ) + + assert response.status_code == expected_status_code + class TestDeleteDagDatasetQueuedEvent(TestDatasetEndpoint): def test_delete_should_respond_204(self, session, create_dummy_dag): @@ -882,7 +983,7 @@ def test_should_raise_403_forbidden(self, session): class TestGetDagDatasetQueuedEvents(TestQueuedEventEndpoint): @pytest.mark.usefixtures("time_freezer") - def test_should_respond_200(self, session, create_dummy_dag, time_freezer): + def test_should_respond_200(self, session, create_dummy_dag): dag, _ = create_dummy_dag() dag_id = dag.dag_id dataset_id = self._create_dataset(session).id @@ -938,6 +1039,24 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, session, create_dummy_dag + ): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_dataset(session).id + self._create_dataset_dag_run_queues(dag_id, dataset_id, session) + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + ) + assert response.status_code == expected_status_code + class TestDeleteDagDatasetQueuedEvents(TestDatasetEndpoint): def test_should_respond_404(self): @@ -973,6 +1092,31 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 204)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, session, create_dummy_dag + ): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_uri = "s3://bucket/key" + dataset_id = self._create_dataset(session).id + + ddrq = DatasetDagRunQueue(target_dag_id=dag_id, dataset_id=dataset_id) + session.add(ddrq) + session.commit() + conn = session.query(DatasetDagRunQueue).all() + assert len(conn) == 1 + + response = self.client.delete( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent/{dataset_uri}", + ) + + assert response.status_code == expected_status_code + class TestGetDatasetQueuedEvents(TestQueuedEventEndpoint): @pytest.mark.usefixtures("time_freezer") @@ -1033,6 +1177,26 @@ def test_should_raise_403_forbidden(self): assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + @pytest.mark.usefixtures("time_freezer") + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, session, create_dummy_dag + ): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_dataset(session).id + self._create_dataset_dag_run_queues(dag_id, dataset_id, session) + + response = self.client.get( + f"/api/v1/dags/{dag_id}/datasets/queuedEvent", + ) + + assert response.status_code == expected_status_code + class TestDeleteDatasetQueuedEvents(TestQueuedEventEndpoint): def test_delete_should_respond_204(self, session, create_dummy_dag): @@ -1084,3 +1248,23 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 204)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, session, create_dummy_dag + ): + dag, _ = create_dummy_dag() + dag_id = dag.dag_id + dataset_id = self._create_dataset(session).id + self._create_dataset_dag_run_queues(dag_id, dataset_id, session) + dataset_uri = "s3://bucket/key" + + response = self.client.delete( + f"/api/v1/datasets/queuedEvent/{dataset_uri}", + ) + + assert response.status_code == expected_status_code diff --git a/tests/api_connexion/endpoints/test_event_log_endpoint.py b/tests/api_connexion/endpoints/test_event_log_endpoint.py index 6e71a86b948d5..6738858ddd00f 100644 --- a/tests/api_connexion/endpoints/test_event_log_endpoint.py +++ b/tests/api_connexion/endpoints/test_event_log_endpoint.py @@ -109,6 +109,21 @@ def setup_attrs(self, configured_app) -> None: def teardown_method(self) -> None: clear_db_logs() + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, log_model): + event_log_id = log_model.id + response = self.client.get( + f"/api/v1/eventLogs/{event_log_id}", environ_overrides={"REMOTE_USER": "test"} + ) + + response = self.client.get("/api/v1/eventLogs") + + assert response.status_code == expected_status_code + class TestGetEventLog(TestEventLogEndpoint): def test_should_respond_200(self, log_model): @@ -152,6 +167,18 @@ def test_should_raise_403_forbidden(self): ) assert response.status_code == 403 + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set(self, set_auto_role_public, expected_status_code, log_model): + event_log_id = log_model.id + + response = self.client.get(f"/api/v1/eventLogs/{event_log_id}") + + assert response.status_code == expected_status_code + class TestGetEventLogs(TestEventLogEndpoint): def test_should_respond_200(self, session, create_log_model): @@ -349,6 +376,23 @@ def test_should_filter_eventlogs_by_excluded_events(self, create_log_model): assert response_data["total_entries"] == 1 assert {"cli_scheduler"} == {x["event"] for x in response_data["event_logs"]} + @pytest.mark.parametrize( + "set_auto_role_public, expected_status_code", + (("Public", 403), ("Admin", 200)), + indirect=["set_auto_role_public"], + ) + def test_with_auth_role_public_set( + self, set_auto_role_public, expected_status_code, create_log_model, session + ): + log_model_3 = Log(event="cli_scheduler", owner="root", extra='{"host_name": "e24b454f002a"}') + log_model_3.dttm = self.default_time_2 + + session.add(log_model_3) + session.flush() + response = self.client.get("/api/v1/eventLogs") + + assert response.status_code == expected_status_code + class TestGetEventLogPagination(TestEventLogEndpoint): @pytest.mark.parametrize(