Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Completed D400 for airflow/api_connexion/* directory #27718

Merged
merged 4 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions airflow/api_connexion/endpoints/config_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@


def _conf_dict_to_config(conf_dict: dict) -> Config:
"""Convert config dict to a Config object"""
"""Convert config dict to a Config object."""
config = Config(
sections=[
ConfigSection(
Expand All @@ -44,25 +44,25 @@ def _conf_dict_to_config(conf_dict: dict) -> Config:


def _option_to_text(config_option: ConfigOption) -> str:
"""Convert a single config option to text"""
"""Convert a single config option to text."""
return f"{config_option.key} = {config_option.value}"


def _section_to_text(config_section: ConfigSection) -> str:
"""Convert a single config section to text"""
"""Convert a single config section to text."""
return (
f"[{config_section.name}]{LINE_SEP}"
f"{LINE_SEP.join(_option_to_text(option) for option in config_section.options)}{LINE_SEP}"
)


def _config_to_text(config: Config) -> str:
"""Convert the entire config to text"""
"""Convert the entire config to text."""
return LINE_SEP.join(_section_to_text(s) for s in config.sections)


def _config_to_json(config: Config) -> str:
"""Convert a Config object to a JSON formatted string"""
"""Convert a Config object to a JSON formatted string."""
return json.dumps(config_schema.dump(config), indent=4)


Expand Down
14 changes: 8 additions & 6 deletions airflow/api_connexion/endpoints/connection_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_CONNECTION)])
@provide_session
def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a connection entry"""
"""Delete a connection entry."""
connection = session.query(Connection).filter_by(conn_id=connection_id).one_or_none()
if connection is None:
raise NotFound(
Expand All @@ -59,7 +59,7 @@ def delete_connection(*, connection_id: str, session: Session = NEW_SESSION) ->
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_CONNECTION)])
@provide_session
def get_connection(*, connection_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a connection entry"""
"""Get a connection entry."""
connection = session.query(Connection).filter(Connection.conn_id == connection_id).one_or_none()
if connection is None:
raise NotFound(
Expand All @@ -79,7 +79,7 @@ def get_connections(
order_by: str = "id",
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get all connection entries"""
"""Get all connection entries."""
to_replace = {"connection_id": "conn_id"}
allowed_filter_attrs = ["connection_id", "conn_type", "description", "host", "port", "id"]

Expand All @@ -100,7 +100,7 @@ def patch_connection(
update_mask: UpdateMask = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Update a connection entry"""
"""Update a connection entry."""
try:
data = connection_schema.load(request.json, partial=True)
except ValidationError as err:
Expand Down Expand Up @@ -134,7 +134,7 @@ def patch_connection(
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
@provide_session
def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
"""Create connection entry"""
"""Create connection entry."""
body = request.json
try:
data = connection_schema.load(body)
Expand All @@ -154,7 +154,9 @@ def post_connection(*, session: Session = NEW_SESSION) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_CONNECTION)])
def test_connection() -> APIResponse:
"""
To test a connection, this method first creates an in-memory dummy conn_id & exports that to an
Test an API connection.

This method first creates an in-memory dummy conn_id & exports that to an
env var, as some hook classes tries to find out the conn from their __init__ method & errors out
if not found. It also deletes the conn id env variable after the test.
"""
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def get_dags(
@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_DAG)])
@provide_session
def patch_dag(*, dag_id: str, update_mask: UpdateMask = None, session: Session = NEW_SESSION) -> APIResponse:
"""Update the specific DAG"""
"""Update the specific DAG."""
try:
patch_body = dag_schema.load(request.json, session=session)
except ValidationError as err:
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/dag_run_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@
)
@provide_session
def delete_dag_run(*, dag_id: str, dag_run_id: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a DAG Run"""
"""Delete a DAG Run."""
if session.query(DagRun).filter(DagRun.dag_id == dag_id, DagRun.run_id == dag_run_id).delete() == 0:
raise NotFound(detail=f"DAGRun with DAG ID: '{dag_id}' and DagRun ID: '{dag_run_id}' not found")
return NoContent, HTTPStatus.NO_CONTENT
Expand Down Expand Up @@ -237,7 +237,7 @@ def get_dag_runs(
)
@provide_session
def get_dag_runs_batch(*, session: Session = NEW_SESSION) -> APIResponse:
"""Get list of DAG Runs"""
"""Get list of DAG Runs."""
body = get_json_request_dict()
try:
data = dagruns_batch_form_schema.load(body)
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/dag_source_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@

@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG_CODE)])
def get_dag_source(*, file_token: str) -> Response:
"""Get source code using file token"""
"""Get source code using file token."""
secret_key = current_app.config["SECRET_KEY"]
auth_s = URLSafeSerializer(secret_key)
try:
Expand Down
6 changes: 3 additions & 3 deletions airflow/api_connexion/endpoints/dataset_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_DATASET)])
@provide_session
def get_dataset(uri: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a Dataset"""
"""Get a Dataset."""
dataset = (
session.query(DatasetModel)
.filter(DatasetModel.uri == uri)
Expand All @@ -64,7 +64,7 @@ def get_datasets(
order_by: str = "id",
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get datasets"""
"""Get datasets."""
allowed_attrs = ["id", "uri", "created_at", "updated_at"]

total_entries = session.query(func.count(DatasetModel.id)).scalar()
Expand Down Expand Up @@ -96,7 +96,7 @@ def get_dataset_events(
source_map_index: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get dataset events"""
"""Get dataset events."""
allowed_attrs = ["source_dag_id", "source_task_id", "source_run_id", "source_map_index", "timestamp"]

query = session.query(DatasetEvent)
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/event_log_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_AUDIT_LOG)])
@provide_session
def get_event_log(*, event_log_id: int, session: Session = NEW_SESSION) -> APIResponse:
"""Get a log entry"""
"""Get a log entry."""
event_log = session.query(Log).get(event_log_id)
if event_log is None:
raise NotFound("Event Log not found")
Expand All @@ -53,7 +53,7 @@ def get_event_logs(
order_by: str = "event_log_id",
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get all log entries from event log"""
"""Get all log entries from event log."""
to_replace = {"event_log_id": "id", "when": "dttm"}
allowed_filter_attrs = [
"event_log_id",
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/extra_link_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def get_extra_links(
task_id: str,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get extra links for task instance"""
"""Get extra links for task instance."""
from airflow.models.taskinstance import TaskInstance

dagbag: DagBag = get_airflow_app().dag_bag
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/health_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@


def get_health() -> APIResponse:
"""Return the health of the airflow scheduler and metadatabase"""
"""Return the health of the airflow scheduler and metadatabase."""
metadatabase_status = HEALTHY
latest_scheduler_heartbeat = None
scheduler_status = UNHEALTHY
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/import_error_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_IMPORT_ERROR)])
@provide_session
def get_import_error(*, import_error_id: int, session: Session = NEW_SESSION) -> APIResponse:
"""Get an import error"""
"""Get an import error."""
error = session.query(ImportErrorModel).get(import_error_id)

if error is None:
Expand All @@ -57,7 +57,7 @@ def get_import_errors(
order_by: str = "import_error_id",
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get all import errors"""
"""Get all import errors."""
to_replace = {"import_error_id": "id"}
allowed_filter_attrs = ["import_error_id", "timestamp", "filename"]
total_entries = session.query(func.count(ImportErrorModel.id)).scalar()
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/log_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_log(
token: str | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get logs for specific task instance"""
"""Get logs for specific task instance."""
key = get_airflow_app().config["SECRET_KEY"]
if not token:
metadata = {}
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/plugin_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PLUGIN)])
@format_parameters({"limit": check_limit})
def get_plugins(*, limit: int, offset: int = 0) -> APIResponse:
"""Get plugins endpoint"""
"""Get plugins endpoint."""
plugins_info = get_plugin_info()
collection = PluginCollection(plugins=plugins_info[offset:][:limit], total_entries=len(plugins_info))
return plugin_collection_schema.dump(collection)
10 changes: 5 additions & 5 deletions airflow/api_connexion/endpoints/pool_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_POOL)])
@provide_session
def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
"""Delete a pool"""
"""Delete a pool."""
if pool_name == "default_pool":
raise BadRequest(detail="Default Pool can't be deleted")
affected_count = session.query(Pool).filter(Pool.pool == pool_name).delete()
Expand All @@ -50,7 +50,7 @@ def delete_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIRespons
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_POOL)])
@provide_session
def get_pool(*, pool_name: str, session: Session = NEW_SESSION) -> APIResponse:
"""Get a pool"""
"""Get a pool."""
obj = session.query(Pool).filter(Pool.pool == pool_name).one_or_none()
if obj is None:
raise NotFound(detail=f"Pool with name:'{pool_name}' not found")
Expand All @@ -67,7 +67,7 @@ def get_pools(
offset: int | None = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get all pools"""
"""Get all pools."""
to_replace = {"name": "pool"}
allowed_filter_attrs = ["name", "slots", "id"]
total_entries = session.query(func.count(Pool.id)).scalar()
Expand All @@ -85,7 +85,7 @@ def patch_pool(
update_mask: UpdateMask = None,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Update a pool"""
"""Update a pool."""
request_dict = get_json_request_dict()
# Only slots can be modified in 'default_pool'
try:
Expand Down Expand Up @@ -136,7 +136,7 @@ def patch_pool(
@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_POOL)])
@provide_session
def post_pool(*, session: Session = NEW_SESSION) -> APIResponse:
"""Create a pool"""
"""Create a pool."""
required_fields = {"name", "slots"} # Pool would require both fields in the post request
fields_diff = required_fields - set(get_json_request_dict().keys())
if fields_diff:
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/provider_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _provider_mapper(provider: ProviderInfo) -> Provider:

@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_PROVIDER)])
def get_providers() -> APIResponse:
"""Get providers"""
"""Get providers."""
providers = [_provider_mapper(d) for d in ProvidersManager().providers.values()]
total_entries = len(providers)
return provider_collection_schema.dump(
Expand Down
1 change: 1 addition & 0 deletions airflow/api_connexion/endpoints/request_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@


def get_json_request_dict() -> Mapping[str, Any]:
"""Cast request dictionary to JSON."""
from flask import request

return cast(Mapping[str, Any], request.get_json())
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@

def _check_action_and_resource(sm: AirflowSecurityManager, perms: list[tuple[str, str]]) -> None:
"""
Checks if the action or resource exists and raise 400 if not
Checks if the action or resource exists and otherwise raise 400.

This function is intended for use in the REST API because it raise 400
"""
Expand All @@ -55,7 +55,7 @@ def _check_action_and_resource(sm: AirflowSecurityManager, perms: list[tuple[str

@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)])
def get_role(*, role_name: str) -> APIResponse:
"""Get role"""
"""Get role."""
ab_security_manager = get_airflow_app().appbuilder.sm
role = ab_security_manager.find_role(name=role_name)
if not role:
Expand All @@ -66,7 +66,7 @@ def get_role(*, role_name: str) -> APIResponse:
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ROLE)])
@format_parameters({"limit": check_limit})
def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None) -> APIResponse:
"""Get roles"""
"""Get roles."""
appbuilder = get_airflow_app().appbuilder
session = appbuilder.get_session
total_entries = session.query(func.count(Role.id)).scalar()
Expand All @@ -90,7 +90,7 @@ def get_roles(*, order_by: str = "name", limit: int, offset: int | None = None)
@security.requires_access([(permissions.ACTION_CAN_READ, permissions.RESOURCE_ACTION)])
@format_parameters({"limit": check_limit})
def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse:
"""Get permissions"""
"""Get permissions."""
session = get_airflow_app().appbuilder.get_session
total_entries = session.query(func.count(Action.id)).scalar()
query = session.query(Action)
Expand All @@ -100,7 +100,7 @@ def get_permissions(*, limit: int, offset: int | None = None) -> APIResponse:

@security.requires_access([(permissions.ACTION_CAN_DELETE, permissions.RESOURCE_ROLE)])
def delete_role(*, role_name: str) -> APIResponse:
"""Delete a role"""
"""Delete a role."""
ab_security_manager = get_airflow_app().appbuilder.sm
role = ab_security_manager.find_role(name=role_name)
if not role:
Expand All @@ -111,7 +111,7 @@ def delete_role(*, role_name: str) -> APIResponse:

@security.requires_access([(permissions.ACTION_CAN_EDIT, permissions.RESOURCE_ROLE)])
def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse:
"""Update a role"""
"""Update a role."""
appbuilder = get_airflow_app().appbuilder
security_manager = appbuilder.sm
body = request.json
Expand Down Expand Up @@ -145,7 +145,7 @@ def patch_role(*, role_name: str, update_mask: UpdateMask = None) -> APIResponse

@security.requires_access([(permissions.ACTION_CAN_CREATE, permissions.RESOURCE_ROLE)])
def post_role() -> APIResponse:
"""Create a new role"""
"""Create a new role."""
appbuilder = get_airflow_app().appbuilder
security_manager = appbuilder.sm
body = request.json
Expand Down
2 changes: 1 addition & 1 deletion airflow/api_connexion/endpoints/task_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def get_task(*, dag_id: str, task_id: str) -> APIResponse:
],
)
def get_tasks(*, dag_id: str, order_by: str = "task_id") -> APIResponse:
"""Get tasks for DAG"""
"""Get tasks for DAG."""
dag: DAG = get_airflow_app().dag_bag.get_dag(dag_id)
if not dag:
raise NotFound("DAG not found")
Expand Down
4 changes: 2 additions & 2 deletions airflow/api_connexion/endpoints/task_instance_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def get_task_instance(
task_id: str,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get task instance"""
"""Get task instance."""
query = (
session.query(TI)
.filter(TI.dag_id == dag_id, TI.run_id == dag_run_id, TI.task_id == task_id)
Expand Down Expand Up @@ -118,7 +118,7 @@ def get_mapped_task_instance(
map_index: int,
session: Session = NEW_SESSION,
) -> APIResponse:
"""Get task instance"""
"""Get task instance."""
query = (
session.query(TI)
.filter(
Expand Down
Loading