[AIRFLOW-1816] Add region param to Dataproc operators
Closes #2788 from DanSedov/master
DanSedov authored and criccomini committed Nov 15, 2017
1 parent 5157b5a commit d04519e
Showing 1 changed file with 10 additions and 2 deletions.

airflow/contrib/operators/dataproc_operator.py
@@ -441,6 +441,7 @@ def __init__(
             dataproc_pig_jars=None,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            region='global',
             *args,
             **kwargs):
         """
@@ -474,6 +475,8 @@ def __init__(
             For this to work, the service account making the request must have domain-wide
             delegation enabled.
         :type delegate_to: string
+        :param region: The specified region where the dataproc cluster is created.
+        :type region: string
         """
         super(DataProcPigOperator, self).__init__(*args, **kwargs)
         self.gcp_conn_id = gcp_conn_id
@@ -485,6 +488,7 @@ def __init__(
         self.cluster_name = cluster_name
         self.dataproc_properties = dataproc_pig_properties
         self.dataproc_jars = dataproc_pig_jars
+        self.region = region
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -500,7 +504,7 @@ def execute(self, context):
         job.add_jar_file_uris(self.dataproc_jars)
         job.set_job_name(self.job_name)
 
-        hook.submit(hook.project_id, job.build())
+        hook.submit(hook.project_id, job.build(), self.region)
 
 
 class DataProcHiveOperator(BaseOperator):
@@ -606,6 +610,7 @@ def __init__(
             dataproc_spark_jars=None,
             gcp_conn_id='google_cloud_default',
             delegate_to=None,
+            region='global',
             *args,
             **kwargs):
         """
@@ -635,6 +640,8 @@ def __init__(
             For this to work, the service account making the request must have domain-wide
             delegation enabled.
         :type delegate_to: string
+        :param region: The specified region where the dataproc cluster is created.
+        :type region: string
         """
         super(DataProcSparkSqlOperator, self).__init__(*args, **kwargs)
         self.gcp_conn_id = gcp_conn_id
@@ -646,6 +653,7 @@ def __init__(
         self.cluster_name = cluster_name
         self.dataproc_properties = dataproc_spark_properties
         self.dataproc_jars = dataproc_spark_jars
+        self.region = region
 
     def execute(self, context):
         hook = DataProcHook(gcp_conn_id=self.gcp_conn_id,
@@ -662,7 +670,7 @@ def execute(self, context):
         job.add_jar_file_uris(self.dataproc_jars)
         job.set_job_name(self.job_name)
 
-        hook.submit(hook.project_id, job.build())
+        hook.submit(hook.project_id, job.build(), self.region)
 
 
 class DataProcSparkOperator(BaseOperator):
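For context, here is a minimal usage sketch (not part of this commit) showing how the new region argument would be passed to one of the patched operators. The DAG id, cluster name, and region value below are illustrative assumptions; only the region keyword itself comes from this change.

    # Minimal sketch, assuming illustrative DAG id, cluster name, and region;
    # `region` is the parameter added by this commit.
    from datetime import datetime

    from airflow import DAG
    from airflow.contrib.operators.dataproc_operator import DataProcPigOperator

    dag = DAG('example_dataproc_pig',            # hypothetical DAG id
              start_date=datetime(2017, 11, 15),
              schedule_interval=None)

    run_pig = DataProcPigOperator(
        task_id='run_pig',
        query="sh echo 'hello'",
        cluster_name='my-cluster',               # hypothetical cluster name
        region='europe-west1',                   # new parameter from this commit
        gcp_conn_id='google_cloud_default',
        dag=dag)

If region is omitted, the default 'global' preserves the previous behavior of submitting to the global Dataproc region.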
