Skip to content
6 changes: 3 additions & 3 deletions airflow/cli/commands/role_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@
from tabulate import tabulate

from airflow.utils import cli as cli_utils
from airflow.www.app import cached_appbuilder
from airflow.www.app import cached_app


def roles_list(args):
"""Lists all existing roles"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
roles = appbuilder.sm.get_all_roles()
print("Existing roles:\n")
role_names = sorted([[r.name] for r in roles])
Expand All @@ -38,6 +38,6 @@ def roles_list(args):
@cli_utils.action_logging
def roles_create(args):
"""Creates new empty role in DB"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
for role_name in args.role:
appbuilder.sm.add_role(role_name)
4 changes: 2 additions & 2 deletions airflow/cli/commands/sync_perm_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,13 @@
"""Sync permission command"""
from airflow.models import DagBag
from airflow.utils import cli as cli_utils
from airflow.www.app import cached_appbuilder
from airflow.www.app import cached_app


@cli_utils.action_logging
def sync_perm(args):
"""Updates permissions for existing roles and DAGs"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
print('Updating permission, view-menu for all existing roles')
appbuilder.sm.sync_roles()
print('Updating permission on all DAG views')
Expand Down
14 changes: 7 additions & 7 deletions airflow/cli/commands/user_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,12 @@
from tabulate import tabulate

from airflow.utils import cli as cli_utils
from airflow.www.app import cached_appbuilder
from airflow.www.app import cached_app


def users_list(args):
"""Lists users at the command line"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']
users = [[user.__getattribute__(field) for field in fields] for user in users]
Expand All @@ -44,7 +44,7 @@ def users_list(args):
@cli_utils.action_logging
def users_create(args):
"""Creates new user in the DB"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
role = appbuilder.sm.find_role(args.role)
if not role:
valid_roles = appbuilder.sm.get_all_roles()
Expand Down Expand Up @@ -74,7 +74,7 @@ def users_create(args):
@cli_utils.action_logging
def users_delete(args):
"""Deletes user from DB"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member

try:
user = next(u for u in appbuilder.sm.get_all_users()
Expand All @@ -98,7 +98,7 @@ def users_manage_role(args, remove=False):
raise SystemExit('Conflicting args: must supply either --username'
' or --email, but not both')

appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
user = (appbuilder.sm.find_user(username=args.username) or
appbuilder.sm.find_user(email=args.email))
if not user:
Expand Down Expand Up @@ -136,7 +136,7 @@ def users_manage_role(args, remove=False):

def users_export(args):
"""Exports all users to the json file"""
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
users = appbuilder.sm.get_all_users()
fields = ['id', 'username', 'email', 'first_name', 'last_name', 'roles']

Expand Down Expand Up @@ -184,7 +184,7 @@ def users_import(args):


def _import_users(users_list): # pylint: disable=redefined-outer-name
appbuilder = cached_appbuilder()
appbuilder = cached_app().appbuilder # pylint: disable=no-member
users_created = []
users_updated = []

Expand Down
2 changes: 1 addition & 1 deletion airflow/cli/commands/webserver_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def webserver(args):
print(
"Starting the web server on port {0} and host {1}.".format(
args.port, args.hostname))
app, _ = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
app = create_app(testing=conf.getboolean('core', 'unit_test_mode'))
app.run(debug=True, use_reloader=not app.config['TESTING'],
port=args.port, host=args.hostname,
ssl_context=(ssl_cert, ssl_key) if ssl_cert and ssl_key else None)
Expand Down
69 changes: 36 additions & 33 deletions airflow/www/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import logging
import socket
from datetime import timedelta
from typing import Any, Optional
from typing import Optional
from urllib.parse import urlparse

import flask
Expand All @@ -39,15 +39,18 @@
from airflow.utils.json import AirflowJsonEncoder
from airflow.www.static_config import configure_manifest_files

app = None # type: Any
appbuilder = None # type: Optional[AppBuilder]
app: Optional[Flask] = None
csrf = CSRFProtect()

log = logging.getLogger(__name__)


def root_app(env, resp):
resp(b'404 Not Found', [('Content-Type', 'text/plain')])
return [b'Apache Airflow is not at this location']


def create_app(config=None, testing=False, app_name="Airflow"):
global app, appbuilder
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This causes a side effect. We want the cache to be modified only by the cached_app method.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great!

app = Flask(__name__)
app.secret_key = conf.get('webserver', 'SECRET_KEY')

Expand All @@ -70,6 +73,31 @@ def create_app(config=None, testing=False, app_name="Airflow"):
app.json_encoder = AirflowJsonEncoder

csrf.init_app(app)

def apply_middlewares(flask_app: Flask):
# Apply DispatcherMiddleware
base_url = urlparse(conf.get('webserver', 'base_url'))[2]
if not base_url or base_url == '/':
base_url = ""
if base_url:
flask_app.wsgi_app = DispatcherMiddleware( # type: ignore
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This middleware is now optional. We use addresses in the form of blocked which do not work properly with this middleware. This middleware expects the addresses to be in the form /blocked. However, if we don't use this middleware, it doesn't matter.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we remove it then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't use it in tests - test_views. This is still needed in production for some users.

root_app,
mounts={base_url: flask_app.wsgi_app}
)

# Apply ProxyFix middleware
if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
flask_app.wsgi_app = ProxyFix( # type: ignore
flask_app.wsgi_app,
x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1)
)

apply_middlewares(app)

db = SQLA()
db.session = settings.Session
db.init_app(app)
Expand Down Expand Up @@ -286,36 +314,11 @@ def apply_caching(response):
def make_session_permanent():
flask_session.permanent = True

return app, appbuilder


def root_app(env, resp):
resp(b'404 Not Found', [('Content-Type', 'text/plain')])
return [b'Apache Airflow is not at this location']
return app


def cached_app(config=None, testing=False):
global app, appbuilder
if not app or not appbuilder:
base_url = urlparse(conf.get('webserver', 'base_url'))[2]
if not base_url or base_url == '/':
base_url = ""

app, _ = create_app(config=config, testing=testing)
app = DispatcherMiddleware(root_app, {base_url: app})
if conf.getboolean('webserver', 'ENABLE_PROXY_FIX'):
app = ProxyFix(
app,
x_for=conf.getint("webserver", "PROXY_FIX_X_FOR", fallback=1),
x_proto=conf.getint("webserver", "PROXY_FIX_X_PROTO", fallback=1),
x_host=conf.getint("webserver", "PROXY_FIX_X_HOST", fallback=1),
x_port=conf.getint("webserver", "PROXY_FIX_X_PORT", fallback=1),
x_prefix=conf.getint("webserver", "PROXY_FIX_X_PREFIX", fallback=1)
)
global app
if not app:
app = create_app(config=config, testing=testing)
return app


def cached_appbuilder(config=None, testing=False):
global appbuilder
cached_app(config=config, testing=testing)
return appbuilder
7 changes: 3 additions & 4 deletions airflow/www/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
# under the License.
#

from flask import g
from flask import current_app, g
from flask_appbuilder.security.sqla import models as sqla_models
from flask_appbuilder.security.sqla.manager import SecurityManager
from sqlalchemy import and_, or_
Expand All @@ -26,7 +26,6 @@
from airflow.exceptions import AirflowException
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import provide_session
from airflow.www.app import appbuilder
from airflow.www.utils import CustomSQLAInterface

EXISTING_ROLES = {
Expand Down Expand Up @@ -250,8 +249,8 @@ def get_user_roles(user=None):
if user is None:
user = g.user
if user.is_anonymous:
public_role = appbuilder.config.get('AUTH_ROLE_PUBLIC')
return [appbuilder.security_manager.find_role(public_role)] \
public_role = current_app.appbuilder.config.get('AUTH_ROLE_PUBLIC')
return [current_app.appbuilder.security_manager.find_role(public_role)] \
if public_role else []
return user.roles

Expand Down
21 changes: 10 additions & 11 deletions airflow/www/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
import markdown
import sqlalchemy as sqla
from flask import (
Markup, Response, escape, flash, jsonify, make_response, redirect, render_template, request,
Markup, Response, current_app, escape, flash, jsonify, make_response, redirect, render_template, request,
session as flask_session, url_for,
)
from flask_appbuilder import BaseView, ModelView, expose, has_access, permission_name
Expand Down Expand Up @@ -72,7 +72,6 @@
from airflow.utils.session import create_session, provide_session
from airflow.utils.state import State
from airflow.www import utils as wwwutils
from airflow.www.app import appbuilder
from airflow.www.decorators import action_logging, gzipped, has_dag_access
from airflow.www.forms import (
ConnectionForm, DagRunForm, DateTimeForm, DateTimeWithNumRunsForm, DateTimeWithNumRunsWithDagRunsForm,
Expand Down Expand Up @@ -270,7 +269,7 @@ def get_int_arg(value, default=0):
end = start + dags_per_page

# Get all the dag id the user could access
filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()

with create_session() as session:
# read orm_dags from the db
Expand Down Expand Up @@ -368,7 +367,7 @@ def get_int_arg(value, default=0):
def dag_stats(self, session=None):
dr = models.DagRun

allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]

Expand Down Expand Up @@ -416,7 +415,7 @@ def task_stats(self, session=None):
DagRun = models.DagRun
Dag = models.DagModel

allowed_dag_ids = set(appbuilder.sm.get_accessible_dag_ids())
allowed_dag_ids = set(current_app.appbuilder.sm.get_accessible_dag_ids())

if not allowed_dag_ids:
return wwwutils.json_response({})
Expand Down Expand Up @@ -512,7 +511,7 @@ def task_stats(self, session=None):
def last_dagruns(self, session=None):
DagRun = models.DagRun

allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()

if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
Expand Down Expand Up @@ -1167,7 +1166,7 @@ def dagrun_clear(self):
@has_access
@provide_session
def blocked(self, session=None):
allowed_dag_ids = appbuilder.sm.get_accessible_dag_ids()
allowed_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()

if 'all_dags' in allowed_dag_ids:
allowed_dag_ids = [dag_id for dag_id, in session.query(models.DagModel.dag_id)]
Expand Down Expand Up @@ -1912,7 +1911,7 @@ def refresh(self, session=None):

dag = dagbag.get_dag(dag_id)
# sync dag permission
appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)
current_app.appbuilder.sm.sync_perm_for_dag(dag_id, dag.access_control)

flash("DAG [{}] is now fresh as a daisy".format(dag_id))
return redirect(request.referrer)
Expand Down Expand Up @@ -2163,9 +2162,9 @@ def conf(self):

class DagFilter(BaseFilter):
def apply(self, query, func): # noqa
if appbuilder.sm.has_all_dags_access():
if current_app.appbuilder.sm.has_all_dags_access():
return query
filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
return query.filter(self.model.dag_id.in_(filter_dag_ids))


Expand Down Expand Up @@ -2800,7 +2799,7 @@ def autocomplete(self, session=None):
dag_ids_query = dag_ids_query.filter(DagModel.is_paused)
owners_query = owners_query.filter(DagModel.is_paused)

filter_dag_ids = appbuilder.sm.get_accessible_dag_ids()
filter_dag_ids = current_app.appbuilder.sm.get_accessible_dag_ids()
if 'all_dags' not in filter_dag_ids:
dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids))
owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids))
Expand Down
5 changes: 1 addition & 4 deletions tests/cli/commands/test_celery_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
from airflow.configuration import conf
from tests.test_utils.config import conf_vars

mock.patch('airflow.utils.cli.action_logging', lambda x: x).start()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This caused a side effect. I don't know why it affected this change. Probably some session is now saving properly.

mock_args = Namespace(queues=1, concurrency=1)


class TestWorkerPrecheck(unittest.TestCase):
@mock.patch('airflow.settings.validate_session')
Expand All @@ -42,7 +39,7 @@ def test_error(self, mock_validate_session):
"""
mock_validate_session.return_value = False
with self.assertRaises(SystemExit) as cm:
celery_command.worker(mock_args)
celery_command.worker(Namespace(queues=1, concurrency=1))
self.assertEqual(cm.exception.code, 1)

@conf_vars({('core', 'worker_precheck'): 'False'})
Expand Down
3 changes: 2 additions & 1 deletion tests/cli/commands/test_role_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,8 @@ def setUpClass(cls):

def setUp(self):
from airflow.www import app as application
self.app, self.appbuilder = application.create_app(testing=True)
self.app = application.create_app(testing=True)
self.appbuilder = self.app.appbuilder # pylint: disable=no-member
self.clear_roles_and_roles()

def tearDown(self):
Expand Down
Loading