From 12bfb571a895a28a58d3189b0fc10cfc1b89e24c Mon Sep 17 00:00:00 2001 From: Jed Cunningham <66968678+jedcunningham@users.noreply.github.com> Date: Fri, 23 Sep 2022 13:28:33 -0700 Subject: [PATCH] Check user is active (#26635) (cherry picked from commit 59707cdf7eacb698ca375b5220af30a39ca1018c) --- airflow/www/app.py | 7 ++++++- airflow/www/extensions/init_security.py | 11 +++++++++++ tests/test_utils/decorators.py | 1 + tests/www/views/conftest.py | 1 + tests/www/views/test_session.py | 14 ++++++++++++++ tests/www/views/test_views_base.py | 13 +++++++++++-- 6 files changed, 44 insertions(+), 3 deletions(-) diff --git a/airflow/www/app.py b/airflow/www/app.py index b67314c99a8e9..d0c38b2936e97 100644 --- a/airflow/www/app.py +++ b/airflow/www/app.py @@ -39,7 +39,11 @@ from airflow.www.extensions.init_jinja_globals import init_jinja_globals from airflow.www.extensions.init_manifest_files import configure_manifest_files from airflow.www.extensions.init_robots import init_robots -from airflow.www.extensions.init_security import init_api_experimental_auth, init_xframe_protection +from airflow.www.extensions.init_security import ( + init_api_experimental_auth, + init_check_user_active, + init_xframe_protection, +) from airflow.www.extensions.init_session import init_airflow_session_interface from airflow.www.extensions.init_views import ( init_api_connexion, @@ -152,6 +156,7 @@ def create_app(config=None, testing=False): init_jinja_globals(flask_app) init_xframe_protection(flask_app) init_airflow_session_interface(flask_app) + init_check_user_active(flask_app) return flask_app diff --git a/airflow/www/extensions/init_security.py b/airflow/www/extensions/init_security.py index 1d96e351df017..b967b7408428a 100644 --- a/airflow/www/extensions/init_security.py +++ b/airflow/www/extensions/init_security.py @@ -19,6 +19,9 @@ import logging from importlib import import_module +from flask import g, redirect, url_for +from flask_login import logout_user + from airflow.configuration import conf from airflow.exceptions import AirflowConfigException, AirflowException @@ -60,3 +63,11 @@ def init_api_experimental_auth(app): except ImportError as err: log.critical("Cannot import %s for API authentication due to: %s", backend, err) raise AirflowException(err) + + +def init_check_user_active(app): + @app.before_request + def check_user_active(): + if g.user is not None and not g.user.is_anonymous and not g.user.is_active: + logout_user() + return redirect(url_for(app.appbuilder.sm.auth_view.endpoint + ".login")) diff --git a/tests/test_utils/decorators.py b/tests/test_utils/decorators.py index bdb8d678070c2..d0b71b502c3d0 100644 --- a/tests/test_utils/decorators.py +++ b/tests/test_utils/decorators.py @@ -45,6 +45,7 @@ def no_op(*args, **kwargs): "init_xframe_protection", "init_airflow_session_interface", "init_appbuilder", + "init_check_user_active", ] @functools.wraps(f) diff --git a/tests/www/views/conftest.py b/tests/www/views/conftest.py index 02c857180f0e1..ad562385bc42c 100644 --- a/tests/www/views/conftest.py +++ b/tests/www/views/conftest.py @@ -58,6 +58,7 @@ def app(examples_dag_bag): "init_jinja_globals", "init_plugins", "init_airflow_session_interface", + "init_check_user_active", ] ) def factory(): diff --git a/tests/www/views/test_session.py b/tests/www/views/test_session.py index 090bc503a8525..380239926404b 100644 --- a/tests/www/views/test_session.py +++ b/tests/www/views/test_session.py @@ -88,3 +88,17 @@ def test_session_id_rotates(app, user_client): new_session_cookie = get_session_cookie(user_client) assert new_session_cookie is not None assert old_session_cookie.value != new_session_cookie.value + + +def test_check_active_user(app, user_client): + user = app.appbuilder.sm.find_user(username="test_user") + user.active = False + resp = user_client.get("/home") + assert resp.status_code == 302 + assert "/login" in resp.headers.get("Location") + + # And they were logged out + user.active = True + resp = user_client.get("/home") + assert resp.status_code == 302 + assert "/login" in resp.headers.get("Location") diff --git a/tests/www/views/test_views_base.py b/tests/www/views/test_views_base.py index d0acc4df27d47..9c9c4f0aba68b 100644 --- a/tests/www/views/test_views_base.py +++ b/tests/www/views/test_views_base.py @@ -30,9 +30,18 @@ from tests.test_utils.www import check_content_in_response, check_content_not_in_response -def test_index(admin_client): +def test_index_redirect(admin_client): + resp = admin_client.get('/') + assert resp.status_code == 302 + assert '/home' in resp.headers.get("Location") + + resp = admin_client.get('/', follow_redirects=True) + check_content_in_response('DAGs', resp) + + +def test_homepage_query_count(admin_client): with assert_queries_count(16): - resp = admin_client.get('/', follow_redirects=True) + resp = admin_client.get('/home') check_content_in_response('DAGs', resp)