Skip to content

Commit

Permalink
fixed test_end_to_end()
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Marin committed Mar 19, 2018
1 parent 299701f commit 8a7b9d3
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 24 deletions.
9 changes: 3 additions & 6 deletions mrjob/dataproc.py
Expand Up @@ -25,11 +25,10 @@
import google.auth
import google.cloud.dataproc_v1
import google.cloud.dataproc_v1.types
from google.api_core.exceptions import NotFound
import google.api_core.exceptions
from google.api_core.grpc_helpers import create_channel
except:
google = None
NotFound = None
create_channel = None

import mrjob
Expand Down Expand Up @@ -121,8 +120,6 @@ def _job_state_name(state_value):
return google.cloud.dataproc_v1.types.JobStatus.State.Name(state_value)




########## BEGIN - Helper fxns for _cluster_create_kwargs ##########
def _gcp_zone_uri(project, zone):
return (
Expand Down Expand Up @@ -502,7 +499,7 @@ def _create_fs_tmp_bucket(self, bucket_name, location=None):
try:
self.fs.get_bucket(bucket_name)
return
except NotFound:
except google.api_core.exceptions.NotFound:
pass

log.info('creating FS bucket %r' % bucket_name)
Expand Down Expand Up @@ -661,7 +658,7 @@ def _launch_cluster(self):
try:
self._get_cluster(self._cluster_id)
log.info('Adding job to existing cluster - %s' % self._cluster_id)
except NotFound:
except google.api_core.exceptions.NotFound:
log.info(
'Creating Dataproc Hadoop cluster - %s' % self._cluster_id)

Expand Down
64 changes: 56 additions & 8 deletions tests/mock_google/dataproc.py
Expand Up @@ -82,6 +82,16 @@ def create_cluster(self, project_id, region, cluster):

self.mock_clusters[cluster_key] = cluster

def delete_cluster(self, project_id, region, cluster_name):
    """Mark the named mock cluster as DELETING.

    Raises NotFound if the cluster isn't in the mock registry.
    The actual removal happens later, in _simulate_progress().
    """
    key = (project_id, region, cluster_name)
    cluster = self.mock_clusters.get(key)

    # NOTE: truthiness (not an ``is None`` check) matches the real
    # client's handling of empty protobuf messages
    if not cluster:
        raise NotFound('Not found: Cluster ' + _cluster_path(*key))

    cluster.status.state = _cluster_state_value('DELETING')

def get_cluster(self, project_id, region, cluster_name):
cluster_key = (project_id, region, cluster_name)
if cluster_key not in self.mock_clusters:
Expand All @@ -91,12 +101,22 @@ def get_cluster(self, project_id, region, cluster_name):
cluster = self.mock_clusters[cluster_key]

result = deepcopy(cluster)
self._simulate_progress(cluster)
self._simulate_progress(project_id, region, cluster_name)
return result

def _simulate_progress(self, mock_cluster):
# just move from STARTING to RUNNING
mock_cluster.status.state = _cluster_state_value('RUNNING')
def _simulate_progress(self, project_id, region, cluster_name):
    """Advance the named mock cluster one step through its lifecycle.

    A cluster in the DELETING state disappears from the registry
    entirely; any other cluster is bumped straight to RUNNING (the
    mock skips intermediate states such as STARTING).
    """
    key = (project_id, region, cluster_name)
    cluster = self.mock_clusters[key]

    if _cluster_state_name(cluster.status.state) == 'DELETING':
        # deletion completes: the cluster ceases to exist
        del self.mock_clusters[key]
        return

    # anything else jumps directly to RUNNING
    cluster.status.state = _cluster_state_value('RUNNING')




class MockGoogleDataprocJobClient(object):
Expand Down Expand Up @@ -134,7 +154,7 @@ def submit_job(self, project_id, region, job):
else:
job.reference.project_id = project_id

job.status.state = _job_state_value('PENDING')
job.status.state = _job_state_value('SETUP_DONE')

job_key = (project_id, region, job_id)

Expand All @@ -158,12 +178,40 @@ def get_job(self, project_id, region, job_id):
self._simulate_progress(job)
return result

def list_jobs(self, project_id, region, page_size=None,
              cluster_name=None, job_state_matcher=None):
    """Yield deep copies of mock jobs matching the given filters.

    Only the ACTIVE job state matcher is supported, and *page_size*
    is not mocked at all. Like the real client, this is a generator,
    so nothing (including the NotImplementedError checks beyond
    page_size) fires until it is iterated.
    """
    if page_size:
        raise NotImplementedError('page_size is not mocked')

    active_states = ('PENDING', 'RUNNING', 'CANCEL_PENDING')

    for (job_project_id, job_region, _), job in self.mock_jobs.items():
        # skip jobs belonging to a different project/region
        if (job_project_id, job_region) != (project_id, region):
            continue

        if cluster_name and job.placement.cluster_name != cluster_name:
            continue

        if job_state_matcher:
            # deliberately checked per-job, inside the loop, to match
            # the original lazy-raise behavior of this mock
            if job_state_matcher != _STATE_MATCHER_ACTIVE:
                raise NotImplementedError(
                    'only ACTIVE job state matcher is mocked')

            if _job_state_name(job.status.state) not in active_states:
                continue

        yield deepcopy(job)

def _simulate_progress(self, mock_job):
state = _job_state_name(mock_job.status.state)

if state == 'PENDING':
mock_job.status.state = _job_state_value('SETUP_DONE')
elif state == 'SETUP_DONE':
if state == 'SETUP_DONE':
mock_job.status.state = _job_state_value('PENDING')
elif state == 'PENDING':
mock_job.status.state = _job_state_value('RUNNING')
elif state == 'RUNNING':
if self.mock_jobs_succeed:
Expand Down
19 changes: 9 additions & 10 deletions tests/test_dataproc.py
Expand Up @@ -135,13 +135,14 @@ def test_end_to_end(self):

# make sure our input and output formats are attached to
# the correct steps
jobs_list = runner.jobs_client.list_jobs(
projectId=runner._project_id,
region=_DATAPROC_API_REGION).execute()
jobs = jobs_list['items']
jobs = list(runner._list_jobs())
self.assertEqual(len(jobs), 2)

step_0_args = jobs[0]['hadoopJob']['args']
step_1_args = jobs[1]['hadoopJob']['args']
# put earliest job first
jobs.sort(key=lambda j: j.reference.job_id)

step_0_args = jobs[0].hadoop_job.args
step_1_args = jobs[1].hadoop_job.args

self.assertIn('-inputformat', step_0_args)
self.assertNotIn('-outputformat', step_0_args)
Expand Down Expand Up @@ -176,10 +177,8 @@ def test_end_to_end(self):
self.assertEqual(len(fake_gcs_output), len(output_dirs))

# job should get terminated
cluster = (
self._dataproc_client._cache_clusters[_TEST_PROJECT][cluster_id])
cluster_state = self._dataproc_client.get_state(cluster)
self.assertEqual(cluster_state, 'DELETING')
cluster = runner._get_cluster(cluster_id)
self.assertEqual(_cluster_state_name(cluster.status.state), 'DELETING')

def test_failed_job(self):
mr_job = MRTwoStepJob(['-r', 'dataproc', '-v'])
Expand Down

0 comments on commit 8a7b9d3

Please sign in to comment.