Skip to content

Commit

Permalink
Ensure @contextmanager decorates generator func (#23103)
Browse files Browse the repository at this point in the history
(cherry picked from commit e589855)
  • Loading branch information
uranusjr authored and potiuk committed May 31, 2022
1 parent b74b58b commit 1db9c17
Show file tree
Hide file tree
Showing 11 changed files with 48 additions and 28 deletions.
4 changes: 2 additions & 2 deletions airflow/cli/commands/task_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import os
import textwrap
from contextlib import contextmanager, redirect_stderr, redirect_stdout
from typing import Dict, List, Optional, Tuple, Union
from typing import Dict, Generator, List, Optional, Tuple, Union

from pendulum.parsing.exceptions import ParserError
from sqlalchemy.orm.exc import NoResultFound
Expand Down Expand Up @@ -269,7 +269,7 @@ def _extract_external_executor_id(args) -> Optional[str]:


@contextmanager
def _capture_task_logs(ti):
def _capture_task_logs(ti: TaskInstance) -> Generator[None, None, None]:
"""Manage logging context for a task run
- Replace the root logger configuration with the airflow.task configuration
Expand Down
3 changes: 1 addition & 2 deletions airflow/models/taskinstance.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@
Dict,
Generator,
Iterable,
Iterator,
List,
NamedTuple,
Optional,
Expand Down Expand Up @@ -142,7 +141,7 @@


@contextlib.contextmanager
def set_current_context(context: Context) -> Iterator[Context]:
def set_current_context(context: Context) -> Generator[Context, None, None]:
"""
Sets the current execution context to the provided context object.
This method should be called once per Task execution, before calling operator.execute.
Expand Down
19 changes: 16 additions & 3 deletions airflow/providers/google/cloud/hooks/gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,20 @@
from io import BytesIO
from os import path
from tempfile import NamedTemporaryFile
from typing import Callable, List, Optional, Sequence, Set, Tuple, TypeVar, Union, cast, overload
from typing import (
IO,
Callable,
Generator,
List,
Optional,
Sequence,
Set,
Tuple,
TypeVar,
Union,
cast,
overload,
)
from urllib.parse import urlparse

from google.api_core.exceptions import NotFound
Expand Down Expand Up @@ -385,7 +398,7 @@ def provide_file(
object_name: Optional[str] = None,
object_url: Optional[str] = None,
dir: Optional[str] = None,
):
) -> Generator[IO[bytes], None, None]:
"""
Downloads the file to a temporary directory and returns a file handle
Expand Down Expand Up @@ -413,7 +426,7 @@ def provide_file_and_upload(
bucket_name: str = PROVIDE_BUCKET,
object_name: Optional[str] = None,
object_url: Optional[str] = None,
):
) -> Generator[IO[bytes], None, None]:
"""
Creates temporary file, returns a file handle and uploads the files content
on close.
Expand Down
9 changes: 6 additions & 3 deletions airflow/providers/google/cloud/utils/credentials_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,10 @@ def build_gcp_conn(


@contextmanager
def provide_gcp_credentials(key_file_path: Optional[str] = None, key_file_dict: Optional[Dict] = None):
def provide_gcp_credentials(
key_file_path: Optional[str] = None,
key_file_dict: Optional[Dict] = None,
) -> Generator[None, None, None]:
"""
Context manager that provides a Google Cloud credentials for application supporting
`Application Default Credentials (ADC) strategy`__.
Expand Down Expand Up @@ -111,7 +114,7 @@ def provide_gcp_connection(
key_file_path: Optional[str] = None,
scopes: Optional[Sequence] = None,
project_id: Optional[str] = None,
) -> Generator:
) -> Generator[None, None, None]:
"""
Context manager that provides a temporary value of :envvar:`AIRFLOW_CONN_GOOGLE_CLOUD_DEFAULT`
connection. It build a new connection that includes path to provided service json,
Expand All @@ -135,7 +138,7 @@ def provide_gcp_conn_and_credentials(
key_file_path: Optional[str] = None,
scopes: Optional[Sequence] = None,
project_id: Optional[str] = None,
) -> Generator:
) -> Generator[None, None, None]:
"""
Context manager that provides both:
Expand Down
10 changes: 5 additions & 5 deletions airflow/providers/google/common/hooks/base_google.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import warnings
from contextlib import ExitStack, contextmanager
from subprocess import check_output
from typing import Any, Callable, Dict, Optional, Sequence, Tuple, TypeVar, Union, cast
from typing import Any, Callable, Dict, Generator, Optional, Sequence, Tuple, TypeVar, Union, cast

import google.auth
import google.auth.credentials
Expand Down Expand Up @@ -459,16 +459,16 @@ def wrapper(self: GoogleBaseHook, *args, **kwargs):
return cast(T, wrapper)

@contextmanager
def provide_gcp_credential_file_as_context(self):
def provide_gcp_credential_file_as_context(self) -> Generator[Optional[str], None, None]:
"""
Context manager that provides a Google Cloud credentials for application supporting `Application
Default Credentials (ADC) strategy <https://cloud.google.com/docs/authentication/production>`__.
It can be used to provide credentials for external programs (e.g. gcloud) that expect authorization
file in ``GOOGLE_APPLICATION_CREDENTIALS`` environment variable.
"""
key_path = self._get_field('key_path', None) # type: Optional[str] #
keyfile_dict = self._get_field('keyfile_dict', None) # type: Optional[Dict]
key_path: Optional[str] = self._get_field('key_path', None)
keyfile_dict: Optional[str] = self._get_field('keyfile_dict', None)
if key_path and keyfile_dict:
raise AirflowException(
"The `keyfile_dict` and `key_path` fields are mutually exclusive. "
Expand All @@ -490,7 +490,7 @@ def provide_gcp_credential_file_as_context(self):
yield None

@contextmanager
def provide_authorized_gcloud(self):
def provide_authorized_gcloud(self) -> Generator[None, None, None]:
"""
Provides a separate gcloud configuration with current credentials.
Expand Down
4 changes: 2 additions & 2 deletions airflow/providers/microsoft/psrp/hooks/psrp.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from contextlib import contextmanager
from copy import copy
from logging import DEBUG, ERROR, INFO, WARNING
from typing import Any, Callable, Dict, Iterator, Optional
from typing import Any, Callable, Dict, Generator, Optional
from weakref import WeakKeyDictionary

from pypsrp.host import PSHost
Expand Down Expand Up @@ -155,7 +155,7 @@ def apply_extra(d, keys):
return pool

@contextmanager
def invoke(self) -> Iterator[PowerShell]:
def invoke(self) -> Generator[PowerShell, None, None]:
"""
Context manager that yields a PowerShell object to which commands can be
added. Upon exit, the commands will be invoked.
Expand Down
11 changes: 8 additions & 3 deletions airflow/utils/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import warnings
from dataclasses import dataclass
from tempfile import gettempdir
from typing import TYPE_CHECKING, Callable, Iterable, List, Optional, Tuple, Union
from typing import TYPE_CHECKING, Callable, Generator, Iterable, List, Optional, Tuple, Union

from sqlalchemy import Table, and_, column, exc, func, inspect, or_, select, table, text, tuple_
from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -68,6 +68,7 @@
from airflow.version import version

if TYPE_CHECKING:
from alembic.runtime.environment import EnvironmentContext
from alembic.script import ScriptDirectory
from sqlalchemy.orm import Query

Expand Down Expand Up @@ -709,7 +710,7 @@ def check_migrations(timeout):


@contextlib.contextmanager
def _configured_alembic_environment():
def _configured_alembic_environment() -> Generator["EnvironmentContext", None, None]:
from alembic.runtime.environment import EnvironmentContext

config = _get_alembic_config()
Expand Down Expand Up @@ -1606,7 +1607,11 @@ def __str__(self):


@contextlib.contextmanager
def create_global_lock(session: Session, lock: DBLocks, lock_timeout=1800):
def create_global_lock(
session: Session,
lock: DBLocks,
lock_timeout: int = 1800,
) -> Generator[None, None, None]:
"""Contextmanager that will create and teardown a global db lock."""
conn = session.get_bind().connect()
dialect = conn.dialect
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/process_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
import pty

from contextlib import contextmanager
from typing import Dict, List, Optional
from typing import Dict, Generator, List, Optional

import psutil
from lockfile.pidlockfile import PIDLockFile
Expand Down Expand Up @@ -258,7 +258,7 @@ def kill_child_processes_by_pids(pids_to_kill: List[int], timeout: int = 5) -> N


@contextmanager
def patch_environ(new_env_variables: Dict[str, str]):
def patch_environ(new_env_variables: Dict[str, str]) -> Generator[None, None, None]:
"""
Sets environment variables in context. After leaving the context, it restores its original state.
Expand Down
4 changes: 2 additions & 2 deletions airflow/utils/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@
import contextlib
from functools import wraps
from inspect import signature
from typing import Callable, Iterator, TypeVar, cast
from typing import Callable, Generator, TypeVar, cast

from airflow import settings


@contextlib.contextmanager
def create_session() -> Iterator[settings.SASession]:
def create_session() -> Generator[settings.SASession, None, None]:
"""Contextmanager that will create and teardown a session."""
if not settings.Session:
raise RuntimeError("Session must be set before!")
Expand Down
4 changes: 2 additions & 2 deletions dev/breeze/src/airflow_breeze/utils/run_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from functools import lru_cache
from pathlib import Path
from re import match
from typing import Dict, List, Mapping, Optional, Union
from typing import Dict, Generator, List, Mapping, Optional, Union

from airflow_breeze.branch_defaults import AIRFLOW_BRANCH
from airflow_breeze.params._common_build_params import _CommonBuildParams
Expand Down Expand Up @@ -213,7 +213,7 @@ def instruct_build_image(python: str):


@contextlib.contextmanager
def working_directory(source_path: Path):
def working_directory(source_path: Path) -> Generator[None, None, None]:
"""
# Equivalent of pushd and popd in bash script.
# https://stackoverflow.com/a/42441759/3101838
Expand Down
4 changes: 2 additions & 2 deletions dev/provider_packages/prepare_provider_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from os.path import dirname, relpath
from pathlib import Path
from shutil import copyfile
from typing import Any, Dict, Iterable, List, NamedTuple, Optional, Set, Tuple, Union
from typing import Any, Dict, Generator, Iterable, List, NamedTuple, Optional, Set, Tuple, Union

import jsonschema
import rich_click as click
Expand Down Expand Up @@ -195,7 +195,7 @@ def cli():


@contextmanager
def with_group(title):
def with_group(title: str) -> Generator[None, None, None]:
"""
If used in GitHub Action, creates an expandable group in the GitHub Action log.
Otherwise, display simple text groups.
Expand Down

0 comments on commit 1db9c17

Please sign in to comment.