Add more type annotations to AWS hooks (#10671)
Co-authored-by: Kamil Breguła <kamil.bregula@polidea.com>
coopergillan and Kamil Breguła committed Sep 14, 2020
1 parent 9e42a97 commit 383a118
Showing 10 changed files with 85 additions and 46 deletions.
13 changes: 9 additions & 4 deletions airflow/providers/amazon/aws/hooks/aws_dynamodb.py
@@ -20,6 +20,8 @@
"""
This module contains the AWS DynamoDB hook
"""
from typing import Iterable, List, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

@@ -40,12 +42,15 @@ class AwsDynamoDBHook(AwsBaseHook):
:type table_name: str
"""

def __init__(self, *args, table_keys=None, table_name=None, **kwargs):
def __init__(
self, *args, table_keys: Optional[List] = None, table_name: Optional[str] = None, **kwargs
) -> None:
self.table_keys = table_keys
self.table_name = table_name
super().__init__(resource_type='dynamodb', *args, **kwargs)
kwargs["resource_type"] = "dynamodb"
super().__init__(*args, **kwargs)

def write_batch_data(self, items):
def write_batch_data(self, items: Iterable):
"""
Write batch items to DynamoDB table with provisioned throughput capacity.
"""
@@ -58,5 +63,5 @@ def write_batch_data(self, items):
return True
except Exception as general_error:
raise AirflowException(
'Failed to insert items in dynamodb, error: {error}'.format(error=str(general_error))
"Failed to insert items in dynamodb, error: {error}".format(error=str(general_error))
)
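For illustration, a minimal usage sketch of the newly annotated hook; the connection id, table name, keys, and items below are placeholders rather than anything defined in this commit:

from airflow.providers.amazon.aws.hooks.aws_dynamodb import AwsDynamoDBHook

# Placeholder connection and table details, matching the Optional[str] / Optional[List] hints above.
hook = AwsDynamoDBHook(
    aws_conn_id="aws_default",
    table_name="example_table",
    table_keys=["item_id"],
)

# write_batch_data accepts any Iterable of DynamoDB item dicts.
hook.write_batch_data([
    {"item_id": "1", "value": "a"},
    {"item_id": "2", "value": "b"},
])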
30 changes: 18 additions & 12 deletions airflow/providers/amazon/aws/hooks/base_aws.py
@@ -39,7 +39,7 @@


class _SessionFactory(LoggingMixin):
def __init__(self, conn: Connection, region_name: str, config: Config):
def __init__(self, conn: Connection, region_name: Optional[str], config: Config) -> None:
super().__init__()
self.conn = conn
self.region_name = region_name
@@ -191,7 +191,7 @@ def _assume_role_with_saml(
RoleArn=role_arn, PrincipalArn=principal_arn, SAMLAssertion=saml_assertion, **assume_role_kwargs
)

def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]):
def _fetch_saml_assertion_using_http_spegno_auth(self, saml_config: Dict[str, Any]) -> str:
import requests

# requests_gssapi will need paramiko > 2.6 since you'll need
@@ -285,9 +285,9 @@ def __init__(
self.config = config

if not (self.client_type or self.resource_type):
raise AirflowException('Either client_type or resource_type' ' must be provided.')
raise AirflowException('Either client_type or resource_type must be provided.')

def _get_credentials(self, region_name):
def _get_credentials(self, region_name: Optional[str]) -> Tuple[boto3.session.Session, Optional[str]]:

if not self.aws_conn_id:
session = boto3.session.Session(region_name=region_name)
@@ -327,7 +327,9 @@ def _get_credentials(self, region_name):
session = boto3.session.Session(region_name=region_name)
return session, None

def get_client_type(self, client_type, region_name=None, config=None):
def get_client_type(
self, client_type: str, region_name: Optional[str] = None, config: Optional[Config] = None,
) -> boto3.client:
"""Get the underlying boto3 client using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)

@@ -338,7 +340,9 @@ def get_client_type(self, client_type, region_name=None, config=None):

return session.client(client_type, endpoint_url=endpoint_url, config=config, verify=self.verify)

def get_resource_type(self, resource_type, region_name=None, config=None):
def get_resource_type(
self, resource_type: str, region_name: Optional[str] = None, config: Optional[Config] = None,
) -> boto3.resource:
"""Get the underlying boto3 resource using boto3 session"""
session, endpoint_url = self._get_credentials(region_name)

@@ -350,7 +354,7 @@ def get_resource_type(self, resource_type, region_name=None, config=None):
return session.resource(resource_type, endpoint_url=endpoint_url, config=config, verify=self.verify)

@cached_property
def conn(self):
def conn(self) -> Union[boto3.client, boto3.resource]:
"""
Get the underlying boto3 client/resource (cached)
@@ -365,7 +369,7 @@ def conn(self):
# Rare possibility - subclasses have not specified a client_type or resource_type
raise NotImplementedError('Could not get boto3 connection!')

def get_conn(self):
def get_conn(self) -> Union[boto3.client, boto3.resource]:
"""
Get the underlying boto3 client/resource (cached)
@@ -378,12 +382,12 @@ def get_conn(self):
# Compat shim
return self.conn

def get_session(self, region_name=None):
def get_session(self, region_name: Optional[str] = None) -> boto3.session.Session:
"""Get the underlying boto3.session."""
session, _ = self._get_credentials(region_name)
return session

def get_credentials(self, region_name=None):
def get_credentials(self, region_name: Optional[str] = None) -> Tuple[Optional[str], Optional[str]]:
"""
Get the underlying `botocore.Credentials` object.
@@ -395,7 +399,7 @@ def get_credentials(self, region_name=None):
# See https://stackoverflow.com/a/36291428/8283373
return session.get_credentials().get_frozen_credentials()

def expand_role(self, role):
def expand_role(self, role: str) -> str:
"""
If the IAM role is a role name, get the Amazon Resource Name (ARN) for the role.
If IAM role is already an IAM role ARN, no change is made.
@@ -409,7 +413,9 @@ def expand_role(self, role):
return self.get_client_type("iam").get_role(RoleName=role)["Role"]["Arn"]


def _parse_s3_config(config_file_name, config_format="boto", profile=None):
def _parse_s3_config(
config_file_name: str, config_format: Optional[str] = "boto", profile: Optional[str] = None
) -> Tuple[Optional[str], Optional[str]]:
"""
Parses a config file for s3 credentials. Can currently
parse boto, s3cmd.conf and AWS SDK config formats
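As a rough sketch of how the annotated accessors are meant to be called, assuming a placeholder connection id, region, and client type (none of which are part of this commit):

from botocore.config import Config

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

# Either client_type or resource_type must be provided, as the exception above enforces.
hook = AwsBaseHook(aws_conn_id="aws_default", client_type="s3")

# get_client_type now advertises Optional[str] / Optional[Config] parameters and a boto3 client return.
s3_client = hook.get_client_type("s3", region_name="eu-west-1", config=Config(retries={"max_attempts": 3}))

# get_credentials builds frozen botocore credentials from the underlying session.
credentials = hook.get_credentials()
print(credentials.access_key)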
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/hooks/ec2.py
@@ -33,8 +33,9 @@ class EC2Hook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, *args, **kwargs):
super().__init__(resource_type="ec2", *args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
kwargs["resource_type"] = "ec2"
super().__init__(*args, **kwargs)

def get_instance(self, instance_id: str):
"""
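A brief hedged sketch pairing the rewritten constructor with the existing get_instance call; the connection id, region, and instance id are placeholders:

from airflow.providers.amazon.aws.hooks.ec2 import EC2Hook

hook = EC2Hook(aws_conn_id="aws_default", region_name="us-east-1")

# get_instance returns the boto3 EC2 Instance resource for the given (placeholder) id.
instance = hook.get_instance(instance_id="i-0123456789abcdef0")
print(instance.state)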
13 changes: 8 additions & 5 deletions airflow/providers/amazon/aws/hooks/emr.py
@@ -15,6 +15,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Dict, List, Optional

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
@@ -32,13 +33,15 @@ class EmrHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, emr_conn_id=None, *args, **kwargs):
def __init__(self, emr_conn_id: Optional[str] = None, *args, **kwargs) -> None:
self.emr_conn_id = emr_conn_id
super().__init__(client_type='emr', *args, **kwargs)
kwargs["client_type"] = "emr"
super().__init__(*args, **kwargs)

def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
def get_cluster_id_by_name(self, emr_cluster_name: str, cluster_states: List[str]) -> Optional[str]:
"""
Fetch id of EMR cluster with given name and (optional) states. Will return only if single id is found.
Fetch id of EMR cluster with given name and (optional) states.
Will return only if single id is found.
:param emr_cluster_name: Name of a cluster to find
:type emr_cluster_name: str
@@ -63,7 +66,7 @@ def get_cluster_id_by_name(self, emr_cluster_name, cluster_states):
self.log.info('No cluster found for name %s', emr_cluster_name)
return None

def create_job_flow(self, job_flow_overrides):
def create_job_flow(self, job_flow_overrides: Dict):
"""
Creates a job flow using the config from the EMR connection.
Keys of the json extra hash may have the arguments of the boto3
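A short sketch of the annotated lookup; the connection ids and cluster name are placeholders:

from airflow.providers.amazon.aws.hooks.emr import EmrHook

hook = EmrHook(emr_conn_id="emr_default", aws_conn_id="aws_default")

# Per the docstring above, the id is returned only when exactly one cluster matches; otherwise None.
cluster_id = hook.get_cluster_id_by_name(
    emr_cluster_name="example-cluster",
    cluster_states=["RUNNING", "WAITING"],
)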
9 changes: 6 additions & 3 deletions airflow/providers/amazon/aws/hooks/kinesis.py
@@ -19,6 +19,8 @@
"""
This module contains AWS Firehose hook
"""
from typing import Iterable

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


@@ -36,11 +38,12 @@ class AwsFirehoseHook(AwsBaseHook):
:type delivery_stream: str
"""

def __init__(self, delivery_stream, *args, **kwargs):
def __init__(self, delivery_stream, *args, **kwargs) -> None:
self.delivery_stream = delivery_stream
super().__init__(client_type='firehose', *args, **kwargs)
kwargs["client_type"] = "firehose"
super().__init__(*args, **kwargs)

def put_records(self, records):
def put_records(self, records: Iterable):
"""
Write batch records to Kinesis Firehose
"""
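A hedged usage sketch; the delivery stream name and record payloads are placeholders, and put_records takes any Iterable of Firehose record dicts:

from airflow.providers.amazon.aws.hooks.kinesis import AwsFirehoseHook

hook = AwsFirehoseHook(delivery_stream="example-stream", aws_conn_id="aws_default")

# Each record uses the boto3 Firehose format: a dict with a 'Data' payload.
response = hook.put_records([
    {"Data": b'{"event": "click"}\n'},
    {"Data": b'{"event": "view"}\n'},
])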
15 changes: 8 additions & 7 deletions airflow/providers/amazon/aws/hooks/lambda_function.py
@@ -44,20 +44,21 @@ class AwsLambdaHook(AwsBaseHook):

def __init__(
self,
function_name,
log_type='None',
qualifier='$LATEST',
invocation_type='RequestResponse',
function_name: str,
log_type: str = 'None',
qualifier: str = '$LATEST',
invocation_type: str = 'RequestResponse',
*args,
**kwargs,
):
) -> None:
self.function_name = function_name
self.log_type = log_type
self.invocation_type = invocation_type
self.qualifier = qualifier
super().__init__(client_type='lambda', *args, **kwargs)
kwargs["client_type"] = "lambda"
super().__init__(*args, **kwargs)

def invoke_lambda(self, payload):
def invoke_lambda(self, payload: str) -> str:
"""
Invoke Lambda Function
"""
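A sketch of the fully annotated constructor plus invoke_lambda; the function name and event are placeholders, and the payload is serialized to a str to match the new annotation:

import json

from airflow.providers.amazon.aws.hooks.lambda_function import AwsLambdaHook

hook = AwsLambdaHook(
    function_name="example-function",
    invocation_type="RequestResponse",
    log_type="None",
    qualifier="$LATEST",
    aws_conn_id="aws_default",
)

# payload: str per the annotation above, so dump the event to JSON first.
response = hook.invoke_lambda(payload=json.dumps({"key": "value"}))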
19 changes: 14 additions & 5 deletions airflow/providers/amazon/aws/hooks/logs.py
@@ -20,6 +20,7 @@
This module contains a hook (AwsLogsHook) with some very basic
functionality for interacting with AWS CloudWatch.
"""
from typing import Dict, Generator, Optional

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook

@@ -35,10 +36,18 @@ class AwsLogsHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, *args, **kwargs):
super().__init__(client_type='logs', *args, **kwargs)

def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start_from_head=True):
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "logs"
super().__init__(*args, **kwargs)

def get_log_events(
self,
log_group: str,
log_stream_name: str,
start_time: int = 0,
skip: int = 0,
start_from_head: bool = True,
) -> Generator:
"""
A generator for log items in a single stream. This will yield all the
items that are available at the current moment.
@@ -67,7 +76,7 @@ def get_log_events(self, log_group, log_stream_name, start_time=0, skip=0, start
event_count = 1
while event_count > 0:
if next_token is not None:
token_arg = {'nextToken': next_token}
token_arg: Optional[Dict[str, str]] = {'nextToken': next_token}
else:
token_arg = {}

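A sketch of consuming the annotated generator; the connection id, log group, and stream name are placeholders:

from airflow.providers.amazon.aws.hooks.logs import AwsLogsHook

hook = AwsLogsHook(aws_conn_id="aws_default", region_name="us-east-1")

# get_log_events yields CloudWatch event dicts until the stream is exhausted.
for event in hook.get_log_events(
    log_group="/aws/lambda/example-function",
    log_stream_name="2020/09/14/[$LATEST]example",
    start_time=0,
    skip=0,
    start_from_head=True,
):
    print(event["timestamp"], event["message"])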
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/hooks/redshift.py
@@ -35,8 +35,9 @@ class RedshiftHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, *args, **kwargs):
super().__init__(client_type='redshift', *args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "redshift"
super().__init__(*args, **kwargs)

# TODO: Wrap create_cluster_snapshot
def cluster_status(self, cluster_identifier: str) -> str:
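A small hedged sketch combining the rewritten constructor with the existing cluster_status call; the connection id and cluster identifier are placeholders:

from airflow.providers.amazon.aws.hooks.redshift import RedshiftHook

hook = RedshiftHook(aws_conn_id="aws_default")

# Returns the status string reported by the Redshift API for the given (placeholder) cluster.
status = hook.cluster_status(cluster_identifier="example-cluster")
print(status)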
17 changes: 13 additions & 4 deletions airflow/providers/amazon/aws/hooks/sqs.py
@@ -19,6 +19,8 @@
"""
This module contains AWS SQS hook
"""
from typing import Dict, Optional

from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook


@@ -33,10 +35,11 @@ class SQSHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, *args, **kwargs):
super().__init__(client_type='sqs', *args, **kwargs)
def __init__(self, *args, **kwargs) -> None:
kwargs["client_type"] = "sqs"
super().__init__(*args, **kwargs)

def create_queue(self, queue_name, attributes=None):
def create_queue(self, queue_name: str, attributes: Optional[Dict] = None) -> Dict:
"""
Create queue using connection object
@@ -52,7 +55,13 @@ def create_queue(self, queue_name, attributes=None):
"""
return self.get_conn().create_queue(QueueName=queue_name, Attributes=attributes or {})

def send_message(self, queue_url, message_body, delay_seconds=0, message_attributes=None):
def send_message(
self,
queue_url: str,
message_body: str,
delay_seconds: int = 0,
message_attributes: Optional[Dict] = None,
) -> Dict:
"""
Send message to the queue
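A sketch of the annotated create_queue and send_message; the queue name, attributes, and message are placeholders:

from airflow.providers.amazon.aws.hooks.sqs import SQSHook

hook = SQSHook(aws_conn_id="aws_default")

# create_queue returns the boto3 response dict, which includes the new queue's URL.
queue = hook.create_queue(queue_name="example-queue", attributes={"DelaySeconds": "5"})

# send_message mirrors the boto3 call; message_attributes defaults to an empty dict.
hook.send_message(
    queue_url=queue["QueueUrl"],
    message_body="hello from Airflow",
    delay_seconds=0,
    message_attributes={},
)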
5 changes: 3 additions & 2 deletions airflow/providers/amazon/aws/hooks/step_function.py
@@ -32,8 +32,9 @@ class StepFunctionHook(AwsBaseHook):
:class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
"""

def __init__(self, region_name=None, *args, **kwargs):
super().__init__(client_type='stepfunctions', *args, **kwargs)
def __init__(self, region_name: Optional[str] = None, *args, **kwargs) -> None:
kwargs["client_type"] = "stepfunctions"
super().__init__(*args, **kwargs)

def start_execution(
self,
