Fix Google providers type errors
ahidalgob committed May 17, 2023
1 parent 51a6787 commit 31d7cd6
Showing 22 changed files with 169 additions and 72 deletions.
37 changes: 20 additions & 17 deletions airflow/providers/google/cloud/hooks/bigquery.py
@@ -45,6 +45,7 @@
LoadJob,
QueryJob,
SchemaField,
UnknownJob,
)
from google.cloud.bigquery.dataset import AccessEntry, Dataset, DatasetListItem, DatasetReference
from google.cloud.bigquery.table import EncryptionConfiguration, Row, RowIterator, Table, TableReference
@@ -319,7 +320,7 @@ def create_empty_table(
view: dict | None = None,
materialized_view: dict | None = None,
encryption_configuration: dict | None = None,
retry: Retry | None = DEFAULT_RETRY,
retry: Retry = DEFAULT_RETRY,
location: str | None = None,
exists_ok: bool = True,
) -> Table:
@@ -1062,7 +1063,9 @@ def get_datasets_list(
# If iterator is requested, we cannot perform a list() on it to log the number
# of datasets because we will have started iteration
if return_iterator:
return iterator
# The iterator returned by list_datasets() is a HTTPIterator but annotated
# as Iterator
return iterator # type: ignore

datasets_list = list(iterator)
self.log.info("Datasets List: %s", len(datasets_list))
@@ -1294,9 +1297,9 @@ def list_rows(
selected_fields = selected_fields.split(",")

if selected_fields:
selected_fields = [SchemaField(n, "") for n in selected_fields]
selected_fields_sequence = [SchemaField(n, "") for n in selected_fields]
else:
selected_fields = None
selected_fields_sequence = None

table = self._resolve_table_reference(
table_resource={},
@@ -1307,7 +1310,7 @@

iterator = self.get_client(project_id=project_id, location=location).list_rows(
table=Table.from_api_repr(table),
selected_fields=selected_fields,
selected_fields=selected_fields_sequence,
max_results=max_results,
page_token=page_token,
start_index=start_index,
@@ -1503,17 +1506,17 @@ def cancel_job(
@GoogleBaseHook.fallback_to_default_project_id
def get_job(
self,
job_id: str | None = None,
job_id: str,
project_id: str | None = None,
location: str | None = None,
) -> CopyJob | QueryJob | LoadJob | ExtractJob:
) -> CopyJob | QueryJob | LoadJob | ExtractJob | UnknownJob:
"""
Retrieves a BigQuery job. For more information see:
https://cloud.google.com/bigquery/docs/reference/v2/jobs
:param job_id: The ID of the job. The ID must contain only letters (a-z, A-Z),
numbers (0-9), underscores (_), or dashes (-). The maximum length is 1,024
characters. If not provided then uuid will be generated.
characters.
:param project_id: Google Cloud Project where the job is running
:param location: location the job is running
"""
@@ -1570,30 +1573,30 @@ def insert_job(
"jobReference": {"jobId": job_id, "projectId": project_id, "location": location},
}

supported_jobs = {
supported_jobs: dict[str, type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob]] = {
LoadJob._JOB_TYPE: LoadJob,
CopyJob._JOB_TYPE: CopyJob,
ExtractJob._JOB_TYPE: ExtractJob,
QueryJob._JOB_TYPE: QueryJob,
}

job = None
job: type[CopyJob] | type[QueryJob] | type[LoadJob] | type[ExtractJob] | None = None
for job_type, job_object in supported_jobs.items():
if job_type in configuration:
job = job_object
break

if not job:
raise AirflowException(f"Unknown job type. Supported types: {supported_jobs.keys()}")
job = job.from_api_repr(job_data, client)
self.log.info("Inserting job %s", job.job_id)
job_api_repr = job.from_api_repr(job_data, client)
self.log.info("Inserting job %s", job_api_repr.job_id)
if nowait:
# Initiate the job and don't wait for it to complete.
job._begin()
job_api_repr._begin()
else:
# Start the job and wait for it to complete and get the result.
job.result(timeout=timeout, retry=retry)
return job
job_api_repr.result(timeout=timeout, retry=retry)
return job_api_repr

def run_with_configuration(self, configuration: dict) -> str:
"""
@@ -2527,7 +2530,7 @@ def get_datasets_list(self, *args, **kwargs) -> list | HTTPIterator:
)
return self.hook.get_datasets_list(*args, **kwargs)

def get_dataset(self, *args, **kwargs) -> dict:
def get_dataset(self, *args, **kwargs) -> Dataset:
"""
This method is deprecated.
Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.get_dataset`
@@ -2671,7 +2674,7 @@ def run_copy(self, *args, **kwargs) -> str:
)
return self.hook.run_copy(*args, **kwargs)

def run_extract(self, *args, **kwargs) -> str:
def run_extract(self, *args, **kwargs) -> str | BigQueryJob:
"""
This method is deprecated.
Please use `airflow.providers.google.cloud.hooks.bigquery.BigQueryHook.run_extract`
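Most of the bigquery.py changes exist to give mypy something it can follow: UnknownJob joins the import list and the get_job return union (the client can hand back an UnknownJob when it cannot classify a job), job_id becomes a required argument now that the docstring no longer promises to generate one, and in list_rows the converted field list gets its own name (selected_fields_sequence) instead of being written back into a parameter with a different declared type. The subtlest change is in insert_job, where the looked-up job *class* and the *instance* built by from_api_repr now have separate names. A minimal, self-contained sketch of that pattern, using dummy classes rather than the real BigQuery job classes:

from __future__ import annotations


class LoadJob:
    _JOB_TYPE = "load"

    @classmethod
    def from_api_repr(cls, resource: dict) -> LoadJob:
        # Stand-in for the real classmethod that builds a job from its API representation.
        return cls()


class QueryJob:
    _JOB_TYPE = "query"

    @classmethod
    def from_api_repr(cls, resource: dict) -> QueryJob:
        return cls()


# The mapping holds classes, so its values are annotated with type[...].
supported_jobs: dict[str, type[LoadJob] | type[QueryJob]] = {
    LoadJob._JOB_TYPE: LoadJob,
    QueryJob._JOB_TYPE: QueryJob,
}

job_class: type[LoadJob] | type[QueryJob] | None = supported_jobs.get("query")
if job_class is None:
    raise ValueError("Unknown job type")

# Binding the instance to a new name (job_api_repr in the hook) keeps mypy from
# treating the same variable as both a class and an instance.
job_api_repr = job_class.from_api_repr({})
print(type(job_api_repr).__name__)  # QueryJob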
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/compute.py
@@ -143,7 +143,7 @@ def insert_instance_template(
timeout=timeout,
metadata=metadata,
)
self._wait_for_operation_to_complete(operation_name=operation.name, project_id=self.project_id)
self._wait_for_operation_to_complete(operation_name=operation.name, project_id=project_id)

@GoogleBaseHook.fallback_to_default_project_id
def delete_instance_template(
@@ -196,7 +196,7 @@ def delete_instance_template(
timeout=timeout,
metadata=metadata,
)
self._wait_for_operation_to_complete(operation_name=operation.name, project_id=self.project_id)
self._wait_for_operation_to_complete(operation_name=operation.name, project_id=project_id)

@GoogleBaseHook.fallback_to_default_project_id
def get_instance_template(
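Both compute.py hunks forward the project_id argument to _wait_for_operation_to_complete instead of self.project_id. Because insert_instance_template and delete_instance_template are wrapped in GoogleBaseHook.fallback_to_default_project_id, that argument already holds either the caller's explicit project or the connection default, so the waiter now polls the same project the operation was submitted to. A rough sketch of how such a fallback decorator behaves; this is an illustration, not Airflow's actual implementation:

from __future__ import annotations

import functools


def fallback_to_default_project_id(func):
    """Fill in project_id from the hook default when the caller omits it (illustrative only)."""

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        if kwargs.get("project_id") is None:
            kwargs["project_id"] = self.project_id
        return func(self, *args, **kwargs)

    return wrapper


class FakeComputeHook:
    project_id = "default-project"

    @fallback_to_default_project_id
    def insert_instance_template(self, *, project_id: str | None = None) -> str:
        # Inside the body the argument is already resolved, so downstream calls
        # (such as waiting for the operation) should reuse it rather than self.project_id.
        return f"waiting in {project_id}"


print(FakeComputeHook().insert_instance_template(project_id="explicit-project"))
print(FakeComputeHook().insert_instance_template())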
4 changes: 3 additions & 1 deletion airflow/providers/google/cloud/hooks/mlengine.py
@@ -587,7 +587,9 @@ async def get_job_status(
self._check_fileds(project_id=project_id, job_id=job_id)
async with ClientSession() as session:
try:
job = await self.get_job(project_id=project_id, job_id=job_id, session=session)
job = await self.get_job(
project_id=project_id, job_id=job_id, session=session # type: ignore
)
job = await job.json(content_type=None)
self.log.info("Retrieving json_response: %s", job)

19 changes: 17 additions & 2 deletions airflow/providers/google/cloud/hooks/natural_language.py
@@ -68,7 +68,7 @@ def __init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self._conn = None
self._conn: LanguageServiceClient | None = None

def get_conn(self) -> LanguageServiceClient:
"""
@@ -104,6 +104,8 @@ def analyze_entities(
"""
client = self.get_conn()

if isinstance(document, dict):
document = Document(document)
return client.analyze_entities(
document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata
)
@@ -132,6 +134,8 @@ def analyze_entity_sentiment(
"""
client = self.get_conn()

if isinstance(document, dict):
document = Document(document)
return client.analyze_entity_sentiment(
document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata
)
@@ -159,6 +163,8 @@ def analyze_sentiment(
"""
client = self.get_conn()

if isinstance(document, dict):
document = Document(document)
return client.analyze_sentiment(
document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata
)
@@ -187,6 +193,8 @@ def analyze_syntax(
"""
client = self.get_conn()

if isinstance(document, dict):
document = Document(document)
return client.analyze_syntax(
document=document, encoding_type=encoding_type, retry=retry, timeout=timeout, metadata=metadata
)
@@ -196,7 +204,7 @@ def annotate_text(
self,
document: dict | Document,
features: dict | AnnotateTextRequest.Features,
encoding_type: EncodingType = None,
encoding_type: EncodingType | None = None,
retry: Retry | _MethodDefault = DEFAULT,
timeout: float | None = None,
metadata: Sequence[tuple[str, str]] = (),
@@ -218,6 +226,11 @@
"""
client = self.get_conn()

if isinstance(document, dict):
document = Document(document)
if isinstance(features, dict):
features = AnnotateTextRequest.Features(features)

return client.annotate_text(
document=document,
features=features,
@@ -248,4 +261,6 @@ def classify_text(
"""
client = self.get_conn()

if isinstance(document, dict):
document = Document(document)
return client.classify_text(document=document, retry=retry, timeout=timeout, metadata=metadata)
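The pattern repeated through natural_language.py is a small normalization step: if the caller passed a plain dict, upgrade it to the corresponding proto message (Document, or AnnotateTextRequest.Features in annotate_text) before handing it to the typed client method. Operators keep their dict-friendly interface and mypy sees the type the client expects. A hedged sketch of that step, with illustrative field values that are not taken from this commit:

from __future__ import annotations

from google.cloud.language_v1.types import Document


def ensure_document(document: dict | Document) -> Document:
    # proto-plus messages accept a mapping in their constructor, so a plain dict
    # can be upgraded in place before it reaches the typed client call.
    if isinstance(document, dict):
        document = Document(document)
    return document


doc = ensure_document({"content": "Airflow schedules workflows.", "type_": Document.Type.PLAIN_TEXT})
print(doc.type_)  # Type.PLAIN_TEXT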
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/spanner.py
@@ -55,7 +55,7 @@ def __init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self._client = None
self._client: Client | None = None

def _get_client(self, project_id: str) -> Client:
"""
@@ -75,7 +75,7 @@ def get_instance(
self,
instance_id: str,
project_id: str,
) -> Instance:
) -> Instance | None:
"""
Gets information about a particular instance.
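With get_instance now typed Instance | None, presumably because the hook returns None when the instance does not exist, callers are expected to narrow the result before using it. A small, hypothetical usage sketch; the connection id, instance id, and project are made up:

from airflow.exceptions import AirflowException
from airflow.providers.google.cloud.hooks.spanner import SpannerHook

hook = SpannerHook(gcp_conn_id="google_cloud_default")
instance = hook.get_instance(instance_id="my-spanner-instance", project_id="my-project")
if instance is None:
    # The Optional return type makes the missing-instance case explicit.
    raise AirflowException("Spanner instance 'my-spanner-instance' does not exist")
print(instance.name)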
7 changes: 6 additions & 1 deletion airflow/providers/google/cloud/hooks/speech_to_text.py
@@ -59,7 +59,7 @@ def __init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self._client = None
self._client: SpeechClient | None = None

def get_conn(self) -> SpeechClient:
"""
@@ -92,6 +92,11 @@ def recognize_speech(
Note that if retry is specified, the timeout applies to each individual attempt.
"""
client = self.get_conn()
if isinstance(config, dict):
config = RecognitionConfig(config)
if isinstance(audio, dict):
audio = RecognitionAudio(audio)

response = client.recognize(config=config, audio=audio, retry=retry, timeout=timeout)
self.log.info("Recognised speech: %s", response)
return response
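recognize_speech gets the same dict-to-proto normalization as the natural-language hook, so a call like the following (hypothetical bucket and settings) keeps working with plain dicts while the client receives proper RecognitionConfig and RecognitionAudio messages:

from airflow.providers.google.cloud.hooks.speech_to_text import CloudSpeechToTextHook

hook = CloudSpeechToTextHook(gcp_conn_id="google_cloud_default")
response = hook.recognize_speech(
    config={"encoding": "LINEAR16", "sample_rate_hertz": 16000, "language_code": "en-US"},
    audio={"uri": "gs://my-bucket/sample.wav"},
)
for result in response.results:
    # Each result carries one or more alternative transcriptions.
    print(result.alternatives[0].transcript)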
7 changes: 7 additions & 0 deletions airflow/providers/google/cloud/hooks/text_to_speech.py
@@ -107,6 +107,13 @@ def synthesize_speech(
https://googleapis.github.io/google-cloud-python/latest/texttospeech/gapic/v1/types.html#google.cloud.texttospeech_v1.types.SynthesizeSpeechResponse
"""
client = self.get_conn()

if isinstance(input_data, dict):
input_data = SynthesisInput(input_data)
if isinstance(voice, dict):
voice = VoiceSelectionParams(voice)
if isinstance(audio_config, dict):
audio_config = AudioConfig(audio_config)
self.log.info("Synthesizing input: %s", input_data)

return client.synthesize_speech(
4 changes: 2 additions & 2 deletions airflow/providers/google/cloud/hooks/video_intelligence.py
@@ -66,7 +66,7 @@ def __init__(
gcp_conn_id=gcp_conn_id,
impersonation_chain=impersonation_chain,
)
self._conn = None
self._conn: VideoIntelligenceServiceClient | None = None

def get_conn(self) -> VideoIntelligenceServiceClient:
"""Returns Gcp Video Intelligence Service client"""
@@ -82,7 +82,7 @@ def annotate_video(
input_uri: str | None = None,
input_content: bytes | None = None,
features: Sequence[Feature] | None = None,
video_context: dict | VideoContext = None,
video_context: dict | VideoContext | None = None,
output_uri: str | None = None,
location: str | None = None,
retry: Retry | _MethodDefault = DEFAULT,
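The video_context: dict | VideoContext = None signature was an implicit-Optional default, which current mypy settings reject; the fix simply spells out | None. The same change appears in natural_language.py for encoding_type. A generic before-and-after illustration (not the provider code):

from __future__ import annotations


def before(video_context: dict = None) -> None:  # mypy: a None default needs an Optional annotation
    ...


def after(video_context: dict | None = None) -> None:
    ...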
