diff --git a/airflow/operators/cassandra_to_gcs.py b/airflow/operators/cassandra_to_gcs.py index 40b197363d2d4..8343161e89d7a 100644 --- a/airflow/operators/cassandra_to_gcs.py +++ b/airflow/operators/cassandra_to_gcs.py @@ -20,13 +20,14 @@ This module contains operator for copying data from Cassandra to Google cloud storage in JSON format. """ + import json import warnings from base64 import b64encode from datetime import datetime from decimal import Decimal from tempfile import NamedTemporaryFile -from typing import Optional +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union from uuid import UUID from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time @@ -82,7 +83,7 @@ class CassandraToGoogleCloudStorageOperator(BaseOperator): ui_color = '#a0e08c' @apply_defaults - def __init__(self, + def __init__(self, # pylint: disable=too-many-arguments cql: str, bucket: str, filename: str, @@ -113,8 +114,6 @@ def __init__(self, self.delegate_to = delegate_to self.gzip = gzip - self.hook = None - # Default Cassandra to BigQuery type mapping CQL_TYPE_MAP = { 'BytesType': 'BYTES', @@ -141,8 +140,10 @@ def __init__(self, 'VarcharType': 'STRING', } - def execute(self, context): - cursor = self._query_cassandra() + def execute(self, context: Dict[str, str]): + hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) + cursor = hook.get_conn().execute(self.cql) + files_to_upload = self._write_local_data_files(cursor) # If a schema is set, create a BQ schema JSON file. @@ -160,16 +161,7 @@ def execute(self, context): file_handle.close() # Close all sessions and connection associated with this Cassandra cluster - self.hook.shutdown_cluster() - - def _query_cassandra(self): - """ - Queries cassandra and returns a cursor to the results. - """ - self.hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id) - session = self.hook.get_conn() - cursor = session.execute(self.cql) - return cursor + hook.shutdown_cluster() def _write_local_data_files(self, cursor): """ @@ -184,8 +176,8 @@ def _write_local_data_files(self, cursor): tmp_file_handles = {self.filename.format(file_no): tmp_file_handle} for row in cursor: row_dict = self.generate_data_dict(row._fields, row) - s = json.dumps(row_dict).encode('utf-8') - tmp_file_handle.write(s) + content = json.dumps(row_dict).encode('utf-8') + tmp_file_handle.write(content) # Append newline to make dumps BigQuery compatible. tmp_file_handle.write(b'\n') @@ -209,29 +201,41 @@ def _write_local_schema_file(self, cursor): schema = [] tmp_schema_file_handle = NamedTemporaryFile(delete=True) - for name, type in zip(cursor.column_names, cursor.column_types): - schema.append(self.generate_schema_dict(name, type)) + for name, type_ in zip(cursor.column_names, cursor.column_types): + schema.append(self.generate_schema_dict(name, type_)) json_serialized_schema = json.dumps(schema).encode('utf-8') tmp_schema_file_handle.write(json_serialized_schema) return {self.schema_filename: tmp_schema_file_handle} - def _upload_to_gcs(self, files_to_upload): + def _upload_to_gcs(self, files_to_upload: Dict[str, Any]): hook = GoogleCloudStorageHook( google_cloud_storage_conn_id=self.gcp_conn_id, delegate_to=self.delegate_to) - for object, tmp_file_handle in files_to_upload.items(): - hook.upload(self.bucket, object, tmp_file_handle.name, 'application/json', self.gzip) + for obj, tmp_file_handle in files_to_upload.items(): + hook.upload( + bucket_name=self.bucket, + object_name=obj, + filename=tmp_file_handle.name, + mime_type='application/json', + gzip=self.gzip + ) @classmethod - def generate_data_dict(cls, names, values): - row_dict = {} - for name, value in zip(names, values): - row_dict.update({name: cls.convert_value(name, value)}) - return row_dict + def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]: + """ + Generates data structure that will be stored as file in GCS. + """ + return {n: cls.convert_value(v) for n, v in zip(names, values)} @classmethod - def convert_value(cls, name, value): + def convert_value( # pylint: disable=too-many-return-statements + cls, + value: Optional[Any] + ) -> Optional[Any]: + """ + Convert value to BQ type. + """ if not value: return value elif isinstance(value, (str, int, float, bool, dict)): @@ -247,44 +251,46 @@ def convert_value(cls, name, value): elif isinstance(value, Time): return str(value).split('.')[0] elif isinstance(value, (list, SortedSet)): - return cls.convert_array_types(name, value) + return cls.convert_array_types(value) elif hasattr(value, '_fields'): - return cls.convert_user_type(name, value) + return cls.convert_user_type(value) elif isinstance(value, tuple): - return cls.convert_tuple_type(name, value) + return cls.convert_tuple_type(value) elif isinstance(value, OrderedMapSerializedKey): - return cls.convert_map_type(name, value) + return cls.convert_map_type(value) else: - raise AirflowException('unexpected value: ' + str(value)) + raise AirflowException('Unexpected value: ' + str(value)) @classmethod - def convert_array_types(cls, name, value): - return [cls.convert_value(name, nested_value) for nested_value in value] + def convert_array_types(cls, value: Union[List[Any], SortedSet]) -> List[Any]: + """ + Maps convert_value over array. + """ + return [cls.convert_value(nested_value) for nested_value in value] @classmethod - def convert_user_type(cls, name, value): + def convert_user_type(cls, value: Any) -> Dict[str, Any]: """ Converts a user type to RECORD that contains n fields, where n is the number of attributes. Each element in the user type class will be converted to its corresponding data type in BQ. """ names = value._fields - values = [cls.convert_value(name, getattr(value, name)) for name in names] + values = [cls.convert_value(getattr(value, name)) for name in names] return cls.generate_data_dict(names, values) @classmethod - def convert_tuple_type(cls, name, value): + def convert_tuple_type(cls, values: Tuple[Any]) -> Dict[str, Any]: """ Converts a tuple to RECORD that contains n fields, each will be converted to its corresponding data type in bq and will be named 'field_', where index is determined by the order of the tuple elements defined in cassandra. """ - names = ['field_' + str(i) for i in range(len(value))] - values = [cls.convert_value(name, value) for name, value in zip(names, value)] + names = ['field_' + str(i) for i in range(len(values))] return cls.generate_data_dict(names, values) @classmethod - def convert_map_type(cls, name, value): + def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]: """ Converts a map to a repeated RECORD that contains two fields: 'key' and 'value', each will be converted to its corresponding data type in BQ. @@ -292,75 +298,93 @@ def convert_map_type(cls, name, value): converted_map = [] for k, v in zip(value.keys(), value.values()): converted_map.append({ - 'key': cls.convert_value('key', k), - 'value': cls.convert_value('value', v) + 'key': cls.convert_value(k), + 'value': cls.convert_value(v) }) return converted_map @classmethod - def generate_schema_dict(cls, name, type): - field_schema = dict() + def generate_schema_dict(cls, name: str, type_: Any) -> Dict[str, Any]: + """ + Generates BQ schema. + """ + field_schema: Dict[str, Any] = dict() field_schema.update({'name': name}) - field_schema.update({'type': cls.get_bq_type(type)}) - field_schema.update({'mode': cls.get_bq_mode(type)}) - fields = cls.get_bq_fields(name, type) + field_schema.update({'type_': cls.get_bq_type(type_)}) + field_schema.update({'mode': cls.get_bq_mode(type_)}) + fields = cls.get_bq_fields(type_) if fields: field_schema.update({'fields': fields}) return field_schema @classmethod - def get_bq_fields(cls, name, type): - fields = [] - - if not cls.is_simple_type(type): - names, types = [], [] - - if cls.is_array_type(type) and cls.is_record_type(type.subtypes[0]): - names = type.subtypes[0].fieldnames - types = type.subtypes[0].subtypes - elif cls.is_record_type(type): - names = type.fieldnames - types = type.subtypes - - if types and not names and type.cassname == 'TupleType': - names = ['field_' + str(i) for i in range(len(types))] - elif types and not names and type.cassname == 'MapType': - names = ['key', 'value'] - - for name, type in zip(names, types): - field = cls.generate_schema_dict(name, type) - fields.append(field) - - return fields - - @classmethod - def is_simple_type(cls, type): - return type.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP + def get_bq_fields(cls, type_: Any) -> List[Dict[str, Any]]: + """ + Converts non simple type value to BQ representation. + """ + if cls.is_simple_type(type_): + return [] + + # In case of not simple type + names: List[str] = [] + types: List[Any] = [] + if cls.is_array_type(type_) and cls.is_record_type(type_.subtypes[0]): + names = type_.subtypes[0].fieldnames + types = type_.subtypes[0].subtypes + elif cls.is_record_type(type_): + names = type_.fieldnames + types = type_.subtypes + + if types and not names and type_.cassname == 'TupleType': + names = ['field_' + str(i) for i in range(len(types))] + elif types and not names and type_.cassname == 'MapType': + names = ['key', 'value'] + + return [cls.generate_schema_dict(n, t) for n, t in zip(names, types)] + + @staticmethod + def is_simple_type(type_: Any) -> bool: + """ + Check if type is a simple type. + """ + return type_.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP - @classmethod - def is_array_type(cls, type): - return type.cassname in ['ListType', 'SetType'] + @staticmethod + def is_array_type(type_: Any) -> bool: + """ + Check if type is an array type. + """ + return type_.cassname in ['ListType', 'SetType'] - @classmethod - def is_record_type(cls, type): - return type.cassname in ['UserType', 'TupleType', 'MapType'] + @staticmethod + def is_record_type(type_: Any) -> bool: + """ + Checks the record type. + """ + return type_.cassname in ['UserType', 'TupleType', 'MapType'] @classmethod - def get_bq_type(cls, type): - if cls.is_simple_type(type): - return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type.cassname] - elif cls.is_record_type(type): + def get_bq_type(cls, type_: Any) -> str: + """ + Converts type to equivalent BQ type. + """ + if cls.is_simple_type(type_): + return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type_.cassname] + elif cls.is_record_type(type_): return 'RECORD' - elif cls.is_array_type(type): - return cls.get_bq_type(type.subtypes[0]) + elif cls.is_array_type(type_): + return cls.get_bq_type(type_.subtypes[0]) else: - raise AirflowException('Not a supported type: ' + type.cassname) + raise AirflowException('Not a supported type_: ' + type_.cassname) @classmethod - def get_bq_mode(cls, type): - if cls.is_array_type(type) or type.cassname == 'MapType': + def get_bq_mode(cls, type_: Any) -> str: + """ + Converts type to equivalent BQ mode. + """ + if cls.is_array_type(type_) or type_.cassname == 'MapType': return 'REPEATED' - elif cls.is_record_type(type) or cls.is_simple_type(type): + elif cls.is_record_type(type_) or cls.is_simple_type(type_): return 'NULLABLE' else: - raise AirflowException('Not a supported type: ' + type.cassname) + raise AirflowException('Not a supported type_: ' + type_.cassname) diff --git a/scripts/ci/pylint_todo.txt b/scripts/ci/pylint_todo.txt index 4ab5fe486b258..65f113957abf2 100644 --- a/scripts/ci/pylint_todo.txt +++ b/scripts/ci/pylint_todo.txt @@ -149,7 +149,6 @@ ./airflow/models/variable.py ./airflow/models/xcom.py ./airflow/operators/bash_operator.py -./airflow/operators/cassandra_to_gcs.py ./airflow/operators/check_operator.py ./airflow/operators/dagrun_operator.py ./airflow/operators/druid_check_operator.py @@ -303,7 +302,6 @@ ./tests/models/test_pool.py ./tests/models/test_taskinstance.py ./tests/operators/test_bash_operator.py -./tests/operators/test_cassandra_to_gcs.py ./tests/operators/test_gcs_to_s3.py ./tests/operators/test_hive_operator.py ./tests/operators/test_operators.py diff --git a/tests/operators/test_cassandra_to_gcs.py b/tests/operators/test_cassandra_to_gcs.py index 1a10228f1035f..d295e6099f77c 100644 --- a/tests/operators/test_cassandra_to_gcs.py +++ b/tests/operators/test_cassandra_to_gcs.py @@ -51,57 +51,59 @@ def test_execute(self, mock_hook, mock_upload, mock_tempfile): operator.execute(None) mock_hook.return_value.get_conn.assert_called_once_with() - call_schema = call(test_bucket, schema, TMP_FILE_NAME, "application/json", gzip) - call_data = call(test_bucket, filename, TMP_FILE_NAME, "application/json", gzip) + call_schema = call(bucket_name=test_bucket, object_name=schema, + filename=TMP_FILE_NAME, mime_type="application/json", gzip=gzip) + call_data = call(bucket_name=test_bucket, object_name=filename, + filename=TMP_FILE_NAME, mime_type="application/json", gzip=gzip) mock_upload.assert_has_calls([call_schema, call_data], any_order=True) def test_convert_value(self): op = CassandraToGoogleCloudStorageOperator - self.assertEqual(op.convert_value("None", None), None) - self.assertEqual(op.convert_value("int", 1), 1) - self.assertEqual(op.convert_value("float", 1.0), 1.0) - self.assertEqual(op.convert_value("str", "text"), "text") - self.assertEqual(op.convert_value("bool", True), True) - self.assertEqual(op.convert_value("dict", {"a": "b"}), {"a": "b"}) + self.assertEqual(op.convert_value(None), None) + self.assertEqual(op.convert_value(1), 1) + self.assertEqual(op.convert_value(1.0), 1.0) + self.assertEqual(op.convert_value("text"), "text") + self.assertEqual(op.convert_value(True), True) + self.assertEqual(op.convert_value({"a": "b"}), {"a": "b"}) from datetime import datetime now = datetime.now() - self.assertEqual(op.convert_value("datetime", now), str(now)) + self.assertEqual(op.convert_value(now), str(now)) from cassandra.util import Date date_str = "2018-01-01" date = Date(date_str) - self.assertEqual(op.convert_value("date", date), str(date_str)) + self.assertEqual(op.convert_value(date), str(date_str)) import uuid from base64 import b64encode test_uuid = uuid.uuid4() encoded_uuid = b64encode(test_uuid.bytes).decode("ascii") - self.assertEqual(op.convert_value("uuid", test_uuid), encoded_uuid) + self.assertEqual(op.convert_value(test_uuid), encoded_uuid) - b = b"abc" - encoded_b = b64encode(b).decode("ascii") - self.assertEqual(op.convert_value("binary", b), encoded_b) + byte_str = b"abc" + encoded_b = b64encode(byte_str).decode("ascii") + self.assertEqual(op.convert_value(byte_str), encoded_b) from decimal import Decimal - d = Decimal(1.0) - self.assertEqual(op.convert_value("decimal", d), float(d)) + decimal = Decimal(1.0) + self.assertEqual(op.convert_value(decimal), float(decimal)) from cassandra.util import Time time = Time(0) - self.assertEqual(op.convert_value("time", time), "00:00:00") + self.assertEqual(op.convert_value(time), "00:00:00") date_str_lst = ["2018-01-01", "2018-01-02", "2018-01-03"] date_lst = [Date(d) for d in date_str_lst] - self.assertEqual(op.convert_value("list", date_lst), date_str_lst) + self.assertEqual(op.convert_value(date_lst), date_str_lst) date_tpl = tuple(date_lst) self.assertEqual( - op.convert_value("tuple", date_tpl), + op.convert_value(date_tpl), {"field_0": "2018-01-01", "field_1": "2018-01-02", "field_2": "2018-01-03"}, )