Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 117 additions & 93 deletions airflow/operators/cassandra_to_gcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,14 @@
This module contains an operator for copying
data from Cassandra to Google Cloud Storage in JSON format.
"""

import json
import warnings
from base64 import b64encode
from datetime import datetime
from decimal import Decimal
from tempfile import NamedTemporaryFile
from typing import Optional
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from uuid import UUID

from cassandra.util import Date, OrderedMapSerializedKey, SortedSet, Time
Expand Down Expand Up @@ -82,7 +83,7 @@ class CassandraToGoogleCloudStorageOperator(BaseOperator):
ui_color = '#a0e08c'

@apply_defaults
def __init__(self,
def __init__(self, # pylint: disable=too-many-arguments
cql: str,
bucket: str,
filename: str,
Expand Down Expand Up @@ -113,8 +114,6 @@ def __init__(self,
self.delegate_to = delegate_to
self.gzip = gzip

self.hook = None

# Default Cassandra to BigQuery type mapping
CQL_TYPE_MAP = {
'BytesType': 'BYTES',
Expand All @@ -141,8 +140,10 @@ def __init__(self,
'VarcharType': 'STRING',
}

def execute(self, context):
cursor = self._query_cassandra()
def execute(self, context: Dict[str, str]):
hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id)
cursor = hook.get_conn().execute(self.cql)

files_to_upload = self._write_local_data_files(cursor)

# If a schema is set, create a BQ schema JSON file.
Expand All @@ -160,16 +161,7 @@ def execute(self, context):
file_handle.close()

# Close all sessions and connections associated with this Cassandra cluster
self.hook.shutdown_cluster()

def _query_cassandra(self):
"""
Queries cassandra and returns a cursor to the results.
"""
self.hook = CassandraHook(cassandra_conn_id=self.cassandra_conn_id)
session = self.hook.get_conn()
cursor = session.execute(self.cql)
return cursor
hook.shutdown_cluster()

def _write_local_data_files(self, cursor):
"""
Expand All @@ -184,8 +176,8 @@ def _write_local_data_files(self, cursor):
tmp_file_handles = {self.filename.format(file_no): tmp_file_handle}
for row in cursor:
row_dict = self.generate_data_dict(row._fields, row)
s = json.dumps(row_dict).encode('utf-8')
tmp_file_handle.write(s)
content = json.dumps(row_dict).encode('utf-8')
tmp_file_handle.write(content)

# Append newline to make dumps BigQuery compatible.
tmp_file_handle.write(b'\n')
Expand All @@ -209,29 +201,41 @@ def _write_local_schema_file(self, cursor):
schema = []
tmp_schema_file_handle = NamedTemporaryFile(delete=True)

for name, type in zip(cursor.column_names, cursor.column_types):
schema.append(self.generate_schema_dict(name, type))
for name, type_ in zip(cursor.column_names, cursor.column_types):
schema.append(self.generate_schema_dict(name, type_))
json_serialized_schema = json.dumps(schema).encode('utf-8')

tmp_schema_file_handle.write(json_serialized_schema)
return {self.schema_filename: tmp_schema_file_handle}

def _upload_to_gcs(self, files_to_upload: Dict[str, Any]):
    """
    Uploads every local temporary file to the configured GCS bucket.

    :param files_to_upload: mapping of destination object name to a
        temporary file handle; only the handle's ``name`` attribute
        (its path on disk) is read here.
    """
    hook = GoogleCloudStorageHook(
        google_cloud_storage_conn_id=self.gcp_conn_id,
        delegate_to=self.delegate_to)
    for obj, tmp_file_handle in files_to_upload.items():
        # Keyword arguments keep this call independent of the positional
        # order of the hook's upload() parameters.
        hook.upload(
            bucket_name=self.bucket,
            object_name=obj,
            filename=tmp_file_handle.name,
            mime_type='application/json',
            gzip=self.gzip
        )

@classmethod
def generate_data_dict(cls, names: Iterable[str], values: Any) -> Dict[str, Any]:
    """
    Generates data structure that will be stored as file in GCS.

    :param names: field names, paired positionally with ``values``.
    :param values: iterable of raw values; each one is passed through
        ``convert_value`` before being stored.
    :return: dict mapping every name to its converted value. Pairing stops
        at the shorter of the two inputs.
    """
    return {n: cls.convert_value(v) for n, v in zip(names, values)}

@classmethod
def convert_value(cls, name, value):
def convert_value( # pylint: disable=too-many-return-statements
cls,
value: Optional[Any]
) -> Optional[Any]:
"""
Convert value to BQ type.
"""
if not value:
return value
elif isinstance(value, (str, int, float, bool, dict)):
Expand All @@ -247,120 +251,140 @@ def convert_value(cls, name, value):
elif isinstance(value, Time):
return str(value).split('.')[0]
elif isinstance(value, (list, SortedSet)):
return cls.convert_array_types(name, value)
return cls.convert_array_types(value)
elif hasattr(value, '_fields'):
return cls.convert_user_type(name, value)
return cls.convert_user_type(value)
elif isinstance(value, tuple):
return cls.convert_tuple_type(name, value)
return cls.convert_tuple_type(value)
elif isinstance(value, OrderedMapSerializedKey):
return cls.convert_map_type(name, value)
return cls.convert_map_type(value)
else:
raise AirflowException('unexpected value: ' + str(value))
raise AirflowException('Unexpected value: ' + str(value))

@classmethod
def convert_array_types(cls, value: Union[List[Any], SortedSet]) -> List[Any]:
    """
    Maps convert_value over array.

    :param value: iterable of elements to convert.
    :return: list holding the converted elements in iteration order.
    """
    return [cls.convert_value(nested_value) for nested_value in value]

@classmethod
def convert_user_type(cls, value: Any) -> Dict[str, Any]:
    """
    Converts a user type to RECORD that contains n fields, where n is the
    number of attributes. Each element in the user type class will be converted to its
    corresponding data type in BQ.

    :param value: object exposing ``_fields`` (its attribute names) and one
        attribute per listed name.
    :return: dict mapping each attribute name to its converted value.
    """
    names = value._fields
    # generate_data_dict already runs convert_value on every value, so the raw
    # attributes are handed over as-is instead of being converted twice.
    raw_values = [getattr(value, name) for name in names]
    return cls.generate_data_dict(names, raw_values)

@classmethod
def convert_tuple_type(cls, values: Tuple[Any, ...]) -> Dict[str, Any]:
    """
    Converts a tuple to RECORD that contains n fields, each will be converted
    to its corresponding data type in bq and will be named 'field_<index>', where
    index is determined by the order of the tuple elements defined in cassandra.

    :param values: tuple of arbitrary length (hence ``Tuple[Any, ...]``).
    :return: dict keyed 'field_0', 'field_1', ... built by ``generate_data_dict``.
    """
    names = ['field_' + str(i) for i in range(len(values))]
    # Value conversion is delegated to generate_data_dict.
    return cls.generate_data_dict(names, values)

@classmethod
def convert_map_type(cls, value: OrderedMapSerializedKey) -> List[Dict[str, Any]]:
    """
    Converts a map to a repeated RECORD that contains two fields: 'key' and 'value',
    each will be converted to its corresponding data type in BQ.

    :param value: mapping whose ``keys()`` and ``values()`` are paired positionally.
    :return: one ``{'key': ..., 'value': ...}`` dict per map entry.
    """
    return [
        {'key': cls.convert_value(k), 'value': cls.convert_value(v)}
        for k, v in zip(value.keys(), value.values())
    ]

@classmethod
def generate_schema_dict(cls, name: str, type_: Any) -> Dict[str, Any]:
    """
    Generates BQ schema.

    :param name: column (or nested field) name.
    :param type_: Cassandra type object describing the column.
    :return: dict with 'name', 'type' and 'mode' keys, plus 'fields' when
        ``get_bq_fields`` returns a non-empty list.
    """
    field_schema: Dict[str, Any] = {
        'name': name,
        # The key must stay 'type'; only the Python parameter was renamed
        # to ``type_`` to avoid shadowing the builtin.
        'type': cls.get_bq_type(type_),
        'mode': cls.get_bq_mode(type_),
    }
    fields = cls.get_bq_fields(type_)
    if fields:
        field_schema['fields'] = fields
    return field_schema

@classmethod
def get_bq_fields(cls, type_: Any) -> List[Dict[str, Any]]:
    """
    Converts non simple type value to BQ representation.

    :param type_: Cassandra type object; ``cassname``, ``subtypes`` and
        ``fieldnames`` are read from it (or from its first subtype).
    :return: list of nested field schemas; empty for simple types and for
        arrays whose element type is not a record type.
    """
    if cls.is_simple_type(type_):
        return []

    # Pick the type that actually carries the nested fields: the element
    # type for an array of records, the type itself for a plain record.
    if cls.is_array_type(type_) and cls.is_record_type(type_.subtypes[0]):
        record_type = type_.subtypes[0]
    elif cls.is_record_type(type_):
        record_type = type_
    else:
        return []

    names = record_type.fieldnames
    types = record_type.subtypes

    # Tuples and maps carry no field names, so synthesise them. The check
    # uses the record type's cassname: for an array of tuples/maps the outer
    # cassname is the array's, which would never match.
    if types and not names:
        if record_type.cassname == 'TupleType':
            names = ['field_' + str(i) for i in range(len(types))]
        elif record_type.cassname == 'MapType':
            names = ['key', 'value']

    # ``names or []`` keeps zip() safe if names is still None/empty here.
    return [cls.generate_schema_dict(n, t) for n, t in zip(names or [], types)]

@staticmethod
def is_simple_type(type_: Any) -> bool:
    """
    Check if type is a simple type.

    :param type_: Cassandra type object exposing ``cassname``.
    :return: True when ``cassname`` is a key of ``CQL_TYPE_MAP``.
    """
    return type_.cassname in CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP

@staticmethod
def is_array_type(type_: Any) -> bool:
    """
    Check if type is an array type.

    :param type_: Cassandra type object exposing ``cassname``.
    :return: True when ``cassname`` is 'ListType' or 'SetType'.
    """
    return type_.cassname in ('ListType', 'SetType')

@staticmethod
def is_record_type(type_: Any) -> bool:
    """
    Checks the record type.

    :param type_: Cassandra type object exposing ``cassname``.
    :return: True when ``cassname`` is 'UserType', 'TupleType' or 'MapType'.
    """
    return type_.cassname in ('UserType', 'TupleType', 'MapType')

@classmethod
def get_bq_type(cls, type_: Any) -> str:
    """
    Converts type to equivalent BQ type.

    :param type_: Cassandra type object exposing ``cassname`` (and
        ``subtypes`` for array types).
    :return: the ``CQL_TYPE_MAP`` entry for simple types, 'RECORD' for record
        types, and the element's BQ type for array types.
    :raises AirflowException: when the type is neither simple, record nor array.
    """
    if cls.is_simple_type(type_):
        return CassandraToGoogleCloudStorageOperator.CQL_TYPE_MAP[type_.cassname]
    if cls.is_record_type(type_):
        return 'RECORD'
    if cls.is_array_type(type_):
        # An array is described by its element type; repetition is
        # expressed separately by get_bq_mode.
        return cls.get_bq_type(type_.subtypes[0])
    raise AirflowException('Not a supported type: ' + type_.cassname)

@classmethod
def get_bq_mode(cls, type_: Any) -> str:
    """
    Converts type to equivalent BQ mode.

    :param type_: Cassandra type object exposing ``cassname``.
    :return: 'REPEATED' for array types and for 'MapType'; 'NULLABLE' for the
        remaining record types and for simple types.
    :raises AirflowException: when the type matches none of the above.
    """
    # MapType is tested before the generic record check so that it is
    # reported as REPEATED rather than NULLABLE.
    if cls.is_array_type(type_) or type_.cassname == 'MapType':
        return 'REPEATED'
    if cls.is_record_type(type_) or cls.is_simple_type(type_):
        return 'NULLABLE'
    raise AirflowException('Not a supported type: ' + type_.cassname)
2 changes: 0 additions & 2 deletions scripts/ci/pylint_todo.txt
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,6 @@
./airflow/models/variable.py
./airflow/models/xcom.py
./airflow/operators/bash_operator.py
./airflow/operators/cassandra_to_gcs.py
./airflow/operators/check_operator.py
./airflow/operators/dagrun_operator.py
./airflow/operators/druid_check_operator.py
Expand Down Expand Up @@ -303,7 +302,6 @@
./tests/models/test_pool.py
./tests/models/test_taskinstance.py
./tests/operators/test_bash_operator.py
./tests/operators/test_cassandra_to_gcs.py
./tests/operators/test_gcs_to_s3.py
./tests/operators/test_hive_operator.py
./tests/operators/test_operators.py
Expand Down
Loading