Skip to content

Commit

Permalink
mocked out get_job(), only a few tests left
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Marin committed Mar 19, 2018
1 parent 59bf5dd commit 6dd9d4c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 3 deletions.
13 changes: 13 additions & 0 deletions mrjob/dataproc.py
Expand Up @@ -24,6 +24,7 @@
try:
import google.auth
import google.cloud.dataproc_v1
import google.cloud.dataproc_v1.types
from google.api_core.exceptions import NotFound
from google.api_core.grpc_helpers import create_channel
except:
Expand Down Expand Up @@ -110,6 +111,18 @@
_GCP_CLUSTER_NAME_REGEX = '(?:[a-z](?:[-a-z0-9]{0,53}[a-z0-9])?).'


# convert enum values to strings (e.g. 'RUNNING')

def _cluster_state_name(state_value):
return google.cloud.dataproc_v1.ClusterStatus.State.Name(state_value)


def _job_state_name(state_value):
return google.cloud.dataproc_v1.types.JobStatus.State.Name(state_value)




########## BEGIN - Helper fxns for _cluster_create_kwargs ##########
def _gcp_zone_uri(project, zone):
return (
Expand Down
42 changes: 39 additions & 3 deletions tests/mock_google/dataproc.py
Expand Up @@ -25,6 +25,20 @@
from google.cloud.dataproc_v1.types import Job
from google.cloud.dataproc_v1.types import JobStatus

from mrjob.dataproc import _cluster_state_name
from mrjob.dataproc import _job_state_name


# convert strings (e.g. 'RUNNING') to enum values

def _cluster_state_value(state_name):
return ClusterStatus.State.Value(state_name)


def _job_state_value(state_name):
return JobStatus.State.Value(state_name)



class MockGoogleDataprocClusterClient(object):

Expand Down Expand Up @@ -58,7 +72,7 @@ def create_cluster(self, project_id, region, cluster):
raise InvalidArgument('Cluster name is required')

# initialize cluster status
cluster.status.state = ClusterStatus.State.Value('CREATING')
cluster.status.state = _cluster_state_value('CREATING')

cluster_key = (project_id, region, cluster.cluster_name)

Expand All @@ -82,7 +96,7 @@ def get_cluster(self, project_id, region, cluster_name):

def _simulate_progress(self, mock_cluster):
# just move from STARTING to RUNNING
mock_cluster.status.state = ClusterStatus.State.Value('RUNNING')
mock_cluster.status.state = _cluster_state_value('RUNNING')


class MockGoogleDataprocJobClient(object):
Expand Down Expand Up @@ -118,7 +132,7 @@ def submit_job(self, project_id, region, job):
else:
job.reference.project_id = project_id

job.status.state = JobStatus.State.Value('PENDING')
job.status.state = _job_state_value('PENDING')

job_key = (project_id, region, job_id)

Expand All @@ -130,6 +144,28 @@ def submit_job(self, project_id, region, job):

return deepcopy(job)

def get_job(self, project_id, region, job_id):
job_key = (project_id, region, job_id)

job = self.mock_jobs.get(job_key)

if not job:
raise NotFound('Not found: Job ' + _job_path(*job_key))

result = deepcopy(job)
self._simulate_progress(job)
return result

def _simulate_progress(self, mock_job):
state = _job_state_name(mock_job.status.state)

if state == 'PENDING':
mock_job.status.state = _job_state_value('SETUP_DONE')
elif state == 'SETUP_DONE':
mock_job.status.state = _job_state_value('RUNNING')
elif state == 'RUNNING':
mock_job.status.state = _job_state_value('DONE')


def _cluster_path(project_id, region, cluster_name):
return 'projects/%s/regions/%s/clusters/%s' % (
Expand Down

0 comments on commit 6dd9d4c

Please sign in to comment.