Skip to content
Permalink
Browse files
Upgrade FAB to 4.1.1 (#24399)
* Upgrade FAB to 4.1.1

The Flask Application Builder has been updated recently to
support a number of newer dependencies. This PR is an
attempt to migrate FAB to the newer version.

This includes:

* update setup.py and setup.cfg upper and lower bounds to
  account for proper version of dependencies that
  FAB < 4.0.0 was blocking from upgrade
* added typed Flask application retrieval with a custom
  application fields available for MyPy typing checks.
* fix typing to account for typing hints added in multiple
  upgraded libraries optional values and content of request
  returned as Mapping
* switch to PyJWT 2.* by using non-deprecated "required" claim as
  list rather than separate fields
* add the possibility to install providers without constraints
  so that we can avoid errors from conflicting constraints when
  upgrade-to-newer-dependencies is used
* add pre-commit to check that 2.4+ only get_airflow_app is not
  used in providers
* avoid Bad Request in case the request sent to Flask 2.0 does not
  have a JSON content type
* switch imports of internal classes to direct packages
  where classes are available rather than from "airflow.models" to
  satisfy MyPY
* synchronize changes of FAB Security Manager 4.1.1 with our copy
  of the Security Manager.
* add error handling for a few "None" cases detected by MyPY
* corrected test cases that were broken by the immutability of
  Flask 2 objects and the better escaping done by Flask 2
* updated test cases to account for redirection to "path" rather
  than full URL by Flask2

Fixes: #22397

* fixup! Upgrade FAB to 4.1.1
  • Loading branch information
potiuk committed Jun 22, 2022
1 parent c3c1f7e commit e2f19505bf3622935480e80bee55bf5b6d80097b
Showing 59 changed files with 762 additions and 561 deletions.
@@ -805,6 +805,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}"
run: >
breeze verify-provider-packages --use-airflow-version wheel --use-packages-from-dist
--package-format wheel
env:
SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}"
- name: "Remove airflow package and replace providers with 2.2-compliant versions"
run: |
rm -vf dist/apache_airflow-*.whl \
@@ -882,6 +884,8 @@ ${{ hashFiles('.pre-commit-config.yaml') }}"
run: >
breeze verify-provider-packages --use-airflow-version sdist --use-packages-from-dist
--package-format sdist
env:
SKIP_CONSTRAINTS: "${{ needs.build-info.outputs.upgradeToNewerDependencies }}"
- name: "Fix ownership"
run: breeze fix-ownership
if: always()
@@ -347,7 +347,7 @@ repos:
language: python
files: ^BREEZE\.rst$|^dev/breeze/.*$
pass_filenames: false
additional_dependencies: ['rich>=12.4.4', 'rich-click']
additional_dependencies: ['rich>=12.4.4', 'rich-click>=1.5']
- id: update-local-yml-file
name: Update mounts in the local yml file
entry: ./scripts/ci/pre_commit/pre_commit_local_yml_mounts.py
@@ -686,29 +686,47 @@ if [[ ${SKIP_ENVIRONMENT_INITIALIZATION=} != "true" ]]; then
echo "${COLOR_BLUE}Uninstalling airflow and providers"
echo
uninstall_airflow_and_providers
echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
echo
install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "none"
else
echo "${COLOR_BLUE}Install airflow from wheel package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_wheel "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
fi
uninstall_providers
elif [[ ${USE_AIRFLOW_VERSION} == "sdist" ]]; then
echo
echo "${COLOR_BLUE}Uninstalling airflow and providers"
echo
uninstall_airflow_and_providers
echo
echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
echo
install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "none"
else
echo "${COLOR_BLUE}Install airflow from sdist package with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_airflow_from_sdist "${AIRFLOW_EXTRAS}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
fi
uninstall_providers
else
echo
echo "${COLOR_BLUE}Uninstalling airflow and providers"
echo
uninstall_airflow_and_providers
echo
echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
if [[ ${SKIP_CONSTRAINTS,,=} == "true" ]]; then
echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' with no constraints.${COLOR_RESET}"
echo
install_released_airflow_version "${USE_AIRFLOW_VERSION}" "none"
else
echo "${COLOR_BLUE}Install released airflow from PyPI with extras: '${AIRFLOW_EXTRAS}' and constraints reference ${AIRFLOW_CONSTRAINTS_REFERENCE}.${COLOR_RESET}"
echo
install_released_airflow_version "${USE_AIRFLOW_VERSION}" "${AIRFLOW_CONSTRAINTS_REFERENCE}"
fi
fi
if [[ ${USE_PACKAGES_FROM_DIST=} == "true" ]]; then
echo
@@ -18,10 +18,11 @@
from functools import wraps
from typing import Any, Callable, Optional, Tuple, TypeVar, Union, cast

from flask import Response, current_app, request
from flask import Response, request
from flask_appbuilder.const import AUTH_LDAP
from flask_login import login_user

from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.www.fab_security.sqla.models import User

CLIENT_AUTH: Optional[Union[Tuple[str, str], Any]] = None
@@ -40,7 +41,7 @@ def auth_current_user() -> Optional[User]:
if auth is None or not auth.username or not auth.password:
return None

ab_security_manager = current_app.appbuilder.sm
ab_security_manager = get_airflow_app().appbuilder.sm
user = None
if ab_security_manager.auth_type == AUTH_LDAP:
user = ab_security_manager.auth_user_ldap(auth.username, auth.password)
@@ -19,7 +19,7 @@
from typing import Collection, Optional

from connexion import NoContent
from flask import current_app, g, request
from flask import g, request
from marshmallow import ValidationError
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import or_
@@ -38,6 +38,7 @@
from airflow.exceptions import AirflowException, DagNotFound
from airflow.models.dag import DagModel, DagTag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session


@@ -56,7 +57,7 @@ def get_dag(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG)])
def get_dag_details(*, dag_id: str) -> APIResponse:
"""Get details of DAG."""
dag: DAG = current_app.dag_bag.get_dag(dag_id)
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found", detail=f"The DAG with dag_id: {dag_id} was not found")
return dag_detail_schema.dump(dag)
@@ -83,7 +84,7 @@ def get_dags(
if dag_id_pattern:
dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))

readable_dags = current_app.appbuilder.sm.get_accessible_dag_ids(g.user)
readable_dags = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user)

dags_query = dags_query.filter(DagModel.dag_id.in_(readable_dags))
if tags:
@@ -143,7 +144,7 @@ def patch_dags(limit, session, offset=0, only_active=True, tags=None, dag_id_pat
if dag_id_pattern == '~':
dag_id_pattern = '%'
dags_query = dags_query.filter(DagModel.dag_id.ilike(f'%{dag_id_pattern}%'))
editable_dags = current_app.appbuilder.sm.get_editable_dag_ids(g.user)
editable_dags = get_airflow_app().appbuilder.sm.get_editable_dag_ids(g.user)

dags_query = dags_query.filter(DagModel.dag_id.in_(editable_dags))
if tags:
@@ -19,7 +19,7 @@

import pendulum
from connexion import NoContent
from flask import current_app, g, request
from flask import g
from marshmallow import ValidationError
from sqlalchemy import or_
from sqlalchemy.orm import Query, Session
@@ -30,6 +30,7 @@
set_dag_run_state_to_success,
)
from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_datetime, format_parameters
from airflow.api_connexion.schemas.dag_run_schema import (
@@ -47,6 +48,7 @@
from airflow.api_connexion.types import APIResponse
from airflow.models import DagModel, DagRun
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session
from airflow.utils.state import DagRunState
from airflow.utils.types import DagRunType
@@ -167,7 +169,7 @@ def get_dag_runs(

# This endpoint allows specifying ~ as the dag_id to retrieve DAG Runs for all DAGs.
if dag_id == "~":
appbuilder = current_app.appbuilder
appbuilder = get_airflow_app().appbuilder
query = query.filter(DagRun.dag_id.in_(appbuilder.sm.get_readable_dag_ids(g.user)))
else:
query = query.filter(DagRun.dag_id == dag_id)
@@ -199,13 +201,13 @@ def get_dag_runs(
@provide_session
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
"""Get list of DAG Runs"""
body = request.get_json()
body = get_json_request_dict()
try:
data = dagruns_batch_form_schema.load(body)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))

appbuilder = current_app.appbuilder
appbuilder = get_airflow_app().appbuilder
readable_dag_ids = appbuilder.sm.get_readable_dag_ids(g.user)
query = session.query(DagRun)
if data.get("dag_ids"):
@@ -252,7 +254,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
detail=f"DAG with dag_id: '{dag_id}' has import errors",
)
try:
post_body = dagrun_schema.load(request.json, session=session)
post_body = dagrun_schema.load(get_json_request_dict(), session=session)
except ValidationError as err:
raise BadRequest(detail=str(err))

@@ -268,7 +270,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
)
if not dagrun_instance:
try:
dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
dag_run = dag.create_dagrun(
run_type=DagRunType.MANUAL,
run_id=run_id,
@@ -277,7 +279,7 @@ def post_dag_run(*, dag_id: str, session: Session = NEW_SESSION) -> APIResponse:
state=DagRunState.QUEUED,
conf=post_body.get("conf"),
external_trigger=True,
dag_hash=current_app.dag_bag.dags_hash.get(dag_id),
dag_hash=get_airflow_app().dag_bag.dags_hash.get(dag_id),
)
return dagrun_schema.dump(dag_run)
except ValueError as ve:
@@ -310,12 +312,12 @@ def update_dag_run_state(*, dag_id: str, dag_run_id: str, session: Session = NEW
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
raise NotFound(error_message)
try:
post_body = set_dagrun_state_form_schema.load(request.json)
post_body = set_dagrun_state_form_schema.load(get_json_request_dict())
except ValidationError as err:
raise BadRequest(detail=str(err))

state = post_body['state']
dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if state == DagRunState.SUCCESS:
set_dag_run_state_to_success(dag=dag, run_id=dag_run.run_id, commit=True)
elif state == DagRunState.QUEUED:
@@ -339,15 +341,15 @@ def clear_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSIO
session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).one_or_none()
)
if dag_run is None:
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
error_message = f'Dag Run id {dag_run_id} not found in dag {dag_id}'
raise NotFound(error_message)
try:
post_body = clear_dagrun_form_schema.load(request.json)
post_body = clear_dagrun_form_schema.load(get_json_request_dict())
except ValidationError as err:
raise BadRequest(detail=str(err))

dry_run = post_body.get('dry_run', False)
dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
start_date = dag_run.logical_date
end_date = dag_run.logical_date

@@ -15,7 +15,6 @@
# specific language governing permissions and limitations
# under the License.

from flask import current_app
from sqlalchemy.orm.session import Session

from airflow import DAG
@@ -25,6 +24,7 @@
from airflow.exceptions import TaskNotFound
from airflow.models.dagbag import DagBag
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.session import NEW_SESSION, provide_session


@@ -46,7 +46,7 @@ def get_extra_links(
"""Get extra links for task instance"""
from airflow.models.taskinstance import TaskInstance

dagbag: DagBag = current_app.dag_bag
dagbag: DagBag = get_airflow_app().dag_bag
dag: DAG = dagbag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found", detail=f'DAG with ID = "{dag_id}" not found')
@@ -14,10 +14,9 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import Any, Optional

from flask import Response, current_app, request
from flask import Response, request
from itsdangerous.exc import BadSignature
from itsdangerous.url_safe import URLSafeSerializer
from sqlalchemy.orm.session import Session
@@ -29,6 +28,7 @@
from airflow.exceptions import TaskNotFound
from airflow.models import TaskInstance
from airflow.security import permissions
from airflow.utils.airflow_flask_app import get_airflow_app
from airflow.utils.log.log_reader import TaskLogReader
from airflow.utils.session import NEW_SESSION, provide_session

@@ -52,7 +52,7 @@ def get_log(
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get logs for specific task instance"""
key = current_app.config["SECRET_KEY"]
key = get_airflow_app().config["SECRET_KEY"]
if not token:
metadata = {}
else:
@@ -87,7 +87,7 @@ def get_log(
metadata['end_of_log'] = True
raise NotFound(title="TaskInstance not found")

dag = current_app.dag_bag.get_dag(dag_id)
dag = get_airflow_app().dag_bag.get_dag(dag_id)
if dag:
try:
ti.task = dag.get_task(ti.task_id)
@@ -101,7 +101,8 @@ def get_log(
if return_type == 'application/json' or return_type is None: # default
logs, metadata = task_log_reader.read_log_chunks(ti, task_try_number, metadata)
logs = logs[0] if task_try_number is not None else logs
token = URLSafeSerializer(key).dumps(metadata)
# we must have token here, so we can safely ignore it
token = URLSafeSerializer(key).dumps(metadata) # type: ignore[assignment]
return logs_schema.dump(LogResponseObject(continuation_token=token, content=logs))
# text/plain. Stream
logs = task_log_reader.read_log_stream(ti, task_try_number, metadata)
@@ -18,13 +18,14 @@
from http import HTTPStatus
from typing import Optional

from flask import Response, request
from flask import Response
from marshmallow import ValidationError
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from airflow.api_connexion import security
from airflow.api_connexion.endpoints.request_dict import get_json_request_dict
from airflow.api_connexion.exceptions import AlreadyExists, BadRequest, NotFound
from airflow.api_connexion.parameters import apply_sorting, check_limit, format_parameters
from airflow.api_connexion.schemas.pool_schema import PoolCollection, pool_collection_schema, pool_schema
@@ -85,9 +86,10 @@ def patch_pool(
session: Session = NEW_SESSION,
) -> APIResponse:
"""Update a pool"""
request_dict = get_json_request_dict()
# Only slots can be modified in 'default_pool'
try:
if pool_name == Pool.DEFAULT_POOL_NAME and request.json["name"] != Pool.DEFAULT_POOL_NAME:
if pool_name == Pool.DEFAULT_POOL_NAME and request_dict["name"] != Pool.DEFAULT_POOL_NAME:
if update_mask and len(update_mask) == 1 and update_mask[0].strip() == "slots":
pass
else:
@@ -100,7 +102,7 @@ def patch_pool(
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")

try:
patch_body = pool_schema.load(request.json)
patch_body = pool_schema.load(request_dict)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))

@@ -121,7 +123,7 @@ def patch_pool(

else:
required_fields = {"name", "slots"}
fields_diff = required_fields - set(request.json.keys())
fields_diff = required_fields - set(get_json_request_dict().keys())
if fields_diff:
raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}")

@@ -136,12 +138,12 @@ def patch_pool(
def post_pool(*, session: Session = NEW_SESSION) -> APIResponse:
"""Create a pool"""
required_fields = {"name", "slots"} # Pool would require both fields in the post request
fields_diff = required_fields - set(request.json.keys())
fields_diff = required_fields - set(get_json_request_dict().keys())
if fields_diff:
raise BadRequest(detail=f"Missing required property(ies): {sorted(fields_diff)}")

try:
post_body = pool_schema.load(request.json, session=session)
post_body = pool_schema.load(get_json_request_dict(), session=session)
except ValidationError as err:
raise BadRequest(detail=str(err.messages))

0 comments on commit e2f1950

Please sign in to comment.