Skip to content

Commit

Permalink
mocked out submit_job()
Browse files Browse the repository at this point in the history
  • Loading branch information
Dave Marin committed Mar 19, 2018
1 parent 0c3f753 commit 524939d
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
6 changes: 6 additions & 0 deletions tests/mock_google/case.py
Expand Up @@ -38,6 +38,10 @@ def setUp(self):
# google.cloud.dataproc_v1.types.Cluster
self.mock_clusters = {}

# maps (project_id, region, cluster_name) to job_id to a
# google.cloud.dataproc_v1.types.Job
self.mock_jobs = {}

self.mock_credentials = Credentials('mock_token')

# Maps bucket name to a dictionary with the key
Expand Down Expand Up @@ -67,11 +71,13 @@ def auth_default(self):
def cluster_client(self, channel=None, credentials=None):
return MockGoogleDataprocClusterClient(
mock_clusters=self.mock_clusters,
mock_jobs=self.mock_jobs,
mock_gcs_fs=self.mock_gcs_fs)

def job_client(self, channel=None, credentials=None):
return MockGoogleDataprocJobClient(
mock_clusters=self.mock_clusters,
mock_jobs=self.mock_jobs,
mock_gcs_fs=self.mock_gcs_fs)

def storage_client(self, project=None, credentials=None):
Expand Down
57 changes: 53 additions & 4 deletions tests/mock_google/dataproc.py
Expand Up @@ -21,14 +21,22 @@
from google.api_core.exceptions import NotFound
from google.cloud.dataproc_v1.types import Cluster
from google.cloud.dataproc_v1.types import ClusterStatus
from google.cloud.dataproc_v1.types import Job
from google.cloud.dataproc_v1.types import JobStatus


class MockGoogleDataprocClusterClient(object):

"""Mock out google.cloud.dataproc_v1.ClusterControllerClient"""
def __init__(self, mock_clusters, mock_gcs_fs):
def __init__(self, mock_clusters, mock_jobs, mock_gcs_fs):
# maps (project_id, region, cluster_name) to a
# google.cloud.dataproc_v1.types.Cluster
self.mock_clusters = mock_clusters

# maps (project_id, region, cluster_name, job_name) to a
# google.cloud.dataproc_v1.types.Job
self.mock_jobs = mock_jobs

# see case.py
self.mock_gcs_fs = mock_gcs_fs

Expand All @@ -42,8 +50,8 @@ def create_cluster(self, project_id, region, cluster):
raise InvalidArgument(
'If provided, CreateClusterRequest.cluster.project_id must'
' match CreateClusterRequest.project_id')
else:
cluster.project_id = project_id
else:
cluster.project_id = project_id

if not cluster.cluster_name:
raise InvalidArgument('Cluster name is required')
Expand All @@ -52,6 +60,9 @@ def create_cluster(self, project_id, region, cluster):
cluster.status.state = ClusterStatus.State.Value('CREATING')

cluster_key = (project_id, region, cluster.cluster_name)

# TODO: check conflict with existing cluster

self.mock_clusters[cluster_key] = cluster

def get_cluster(self, project_id, region, cluster_name):
Expand All @@ -76,6 +87,44 @@ def _simulate_progress(self, mock_cluster):
class MockGoogleDataprocJobClient(object):

"""Mock out google.cloud.dataproc_v1.JobControllerClient"""
def __init__(self, mock_clusters, mock_gcs_fs):
def __init__(self, mock_clusters, mock_jobs, mock_gcs_fs):
self.mock_clusters = mock_clusters
self.mock_jobs = mock_jobs
self.mock_gcs_fs = mock_gcs_fs

def submit_job(self, project_id, region, job):
# convert dict to object
if not isinstance(job, Job):
job = Job(**job)

if not (job.reference.project_id and job.reference.job_id):
raise NotImplementedError('generation of job IDs not implemented')

if not job.placement.cluster_name:
raise InvalidArgument('Cluster name is required')

if not job.hadoop_job:
raise NotImplementedError('only hadoop jobs are supported')

if job.reference.project_id:
if job.reference.project_id != project_id:
raise InvalidArgument(
'If provided, SubmitJobRequest.job.job_reference'
'.project_id must match SubmitJobRequest.project_id')
else:
job.reference.project_id = project_id

# copy the job as submitted
# TODO: look at what submit_job really returns
result = deepcopy(job)

job.status.state = JobStatus.State.Value('PENDING')

cluster_key = (project_id, region, job.placement.cluster_name)
cluster_jobs = self.mock_jobs.setdefault(cluster_key, {})

# TODO: check for conflict

cluster_jobs[job.reference.job_id] = job

return result

0 comments on commit 524939d

Please sign in to comment.