Skip to content

Commit

Permalink
Add job submission api to ray
Browse files Browse the repository at this point in the history
  • Loading branch information
kartik4949 committed Jun 13, 2024
1 parent fc9184c commit 00ac16a
Show file tree
Hide file tree
Showing 12 changed files with 217 additions and 44 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Handle ProgrammingError of SnowFlake for non-existing objects
- Updated the use cases.
- Update references to components and artifacts.
- Fix Ray async compute by switching to the job submission API.

## [0.1.1](https://github.com/SuperDuperDB/superduperdb/compare/0.0.20...0.1.0) (2023-Feb-09)

Expand Down
4 changes: 2 additions & 2 deletions deploy/testenv/env/smoke/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ cluster:
strategy: null
uri: http://cdc:8001
compute:
uri: ray://ray-head:10001
uri: http://ray-head:8265
vector_search:
uri: http://vector-search:8000
backfill_batch_size: 100
Expand All @@ -28,4 +28,4 @@ retries:
stop_after_attempt: 2
wait_max: 10.0
wait_min: 4.0
wait_multiplier: 1.0
wait_multiplier: 1.0
8 changes: 8 additions & 0 deletions superduperdb/backends/base/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ def type(self) -> str:
"""Return the type of compute engine."""
pass

@abstractproperty
def remote(self) -> bool:
    """Return whether this is a remote compute engine."""
    pass

@abstractproperty
def name(self) -> str:
"""Return the name of current compute engine."""
Expand Down Expand Up @@ -63,3 +68,6 @@ def disconnect(self) -> None:
def shutdown(self) -> None:
    """Shut down the compute cluster."""
    pass

def execute_task(self, job_id, dependencies, compute_kwargs={}):
    """
    Execute task function for distributed backends.

    The base implementation has no body beyond this docstring, so it
    returns ``None``.

    :param job_id: Identifier of the job to execute.
    :param dependencies: Dependencies of the job.
    :param compute_kwargs: Compute kwargs for the job.
    """
5 changes: 5 additions & 0 deletions superduperdb/backends/local/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@ def __init__(
):
self.__outputs: t.Dict = {}

@property
def remote(self) -> bool:
    """Return whether this is a remote compute engine (always ``False``)."""
    return False

@property
def type(self) -> str:
"""The type of the backend."""
Expand Down
76 changes: 61 additions & 15 deletions superduperdb/backends/ray/compute.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import json
import os
import typing as t
import uuid

import ray
from ray.job_submission import JobSubmissionClient

from superduperdb import logging
from superduperdb.backends.base.compute import ComputeBackend
Expand All @@ -22,10 +26,13 @@ def __init__(
):
self._futures_collection: t.Dict[str, ray.ObjectRef] = {}
self.address = address
if local:
ray.init(ignore_reinit_error=True)
else:
ray.init(address=address, **kwargs, ignore_reinit_error=True)

self.client = JobSubmissionClient(self.address)

@property
def remote(self) -> bool:
    """Return whether this is a remote compute engine (always ``True``)."""
    return True

@property
def type(self) -> str:
Expand All @@ -37,24 +44,58 @@ def name(self) -> str:
"""The name of the compute backend."""
return f"ray://{self.address}"

def submit(
self, function: t.Callable, *args, compute_kwargs: t.Dict = {}, **kwargs
def submit(self, identifier, dependencies=(), compute_kwargs={}):
    """
    Submit job to remote cluster.

    The job runs ``superduperdb.jobs.job.remote_job`` through a
    ``python -c`` entrypoint handed to the job submission client.
    Environment variables starting with ``SUPERDUPERDB_`` are forwarded
    to the job through ``runtime_env``.

    :param identifier: Job identifier, must parse as a UUID.
    :param dependencies: List of dependencies on the job; ``None`` entries
                         are dropped.
    :param compute_kwargs: Compute kwargs for the job.
    :raises ValueError: If ``identifier`` is not a valid UUID.
    :return: The job id returned by the job submission client.
    """
    # Local import: only this method needs it.
    import shlex

    try:
        uuid.UUID(str(identifier))
    except ValueError as err:
        raise ValueError(f'Identifier {identifier} is not valid') from err

    dependencies = [d for d in dependencies if d is not None]

    # The payload travels as JSON text that is decoded inside the job:
    # JSON literals such as true/false/null are not valid Python
    # expressions, so they must not be pasted into the snippet as code.
    call_args = [json.dumps(str(identifier))]
    if dependencies:
        call_args.append(
            f"dependencies=json.loads({json.dumps(dependencies)!r})"
        )
    if compute_kwargs:
        call_args.append(
            f"compute_kwargs=json.loads({json.dumps(compute_kwargs)!r})"
        )
    code = (
        "import json; from superduperdb.jobs.job import remote_job; "
        f"remote_job({', '.join(call_args)})"
    )
    # The entrypoint is a shell command, so quote the snippet; a quote
    # character inside the payload must not terminate the argument.
    entrypoint = f"python -c {shlex.quote(code)}"

    env_vars = {
        k: v for k, v in os.environ.items() if k.startswith('SUPERDUPERDB_')
    }
    runtime_env = {'env_vars': env_vars} if env_vars else {}
    return self.client.submit_job(entrypoint=entrypoint, runtime_env=runtime_env)

def execute_task(
self, job_id, dependencies, compute_kwargs={}
) -> t.Tuple[ray.ObjectRef, str]:
"""
Submits a function to the ray server for execution.
:param function: The function to be executed.
:param args: Positional arguments to be passed to the function.
:param compute_kwargs: Additional keyword arguments to be passed to ray API.
:param kwargs: Keyword arguments to be passed to the function.
"""

def _dependable_remote_job(function, *args, **kwargs):
if (
function.__name__ in ['method_job', 'callable_job']
and 'dependencies' in kwargs
):
dependencies = kwargs['dependencies']
if 'dependencies' in kwargs:
dependencies = kwargs.pop('dependencies', None)
if dependencies:
ray.wait(dependencies)
return function(*args, **kwargs)
Expand All @@ -63,12 +104,17 @@ def _dependable_remote_job(function, *args, **kwargs):
remote_function = ray.remote(**compute_kwargs)(_dependable_remote_job)
else:
remote_function = ray.remote(_dependable_remote_job)
future = remote_function.remote(function, *args, **kwargs)

from superduperdb.jobs.job import remote_task

future = remote_function.remote(remote_task, job_id, dependencies=dependencies)

ray.get(future)
task_id = str(future.task_id().hex())
self._futures_collection[task_id] = future

logging.success(
f"Job submitted on {self}. function: {function}; "
f"Job submitted on {self}. function: remote_job; "
f"task: {task_id}; job_id: {str(future.job_id())}"
)
return future, task_id
Expand Down Expand Up @@ -108,4 +154,4 @@ def disconnect(self) -> None:

def shutdown(self) -> None:
"""Shuts down the ray cluster."""
ray.shutdown()
raise NotImplementedError
2 changes: 1 addition & 1 deletion superduperdb/backends/sqlalchemy/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def _init_tables(self):
Column('method_name', type_string),
Column('stdout', type_json_as_string),
Column('stderr', type_json_as_string),
Column('cls', type_string),
Column('_path', type_string),
Column('job_id', type_string),
*job_table_args,
)
Expand Down
21 changes: 14 additions & 7 deletions superduperdb/base/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,20 @@
from superduperdb.misc.anonymize import anonymize_url


def _get_metadata_store(cfg):
    """
    Connect to the metadata store specified in the configuration.

    :param cfg: Configuration object; its ``metadata_store`` attribute is
                passed on to ``_build_databackend_impl``.
    """
    store_uri = cfg.metadata_store
    logging.info("Connecting to Metadata Client:", store_uri)
    metadata = _build_databackend_impl(store_uri, metadata_stores, type='metadata')
    return metadata


def _build_metadata(cfg, databackend: t.Optional['BaseDataBackend'] = None):
# Connect to metadata store.
# ------------------------------
# 1. try to connect to the metadata store specified in the configuration.
# 2. if that fails, try to connect to the data backend engine.
# 3. if that fails, try to connect to the data backend uri.
if cfg.metadata_store is not None:
# try to connect to the metadata store specified in the configuration.
logging.info("Connecting to Metadata Client:", cfg.metadata_store)
return _build_databackend_impl(
cfg.metadata_store, metadata_stores, type='metadata'
)
return _get_metadata_store(cfg)
else:
try:
# try to connect to the data backend engine.
Expand Down Expand Up @@ -151,7 +153,12 @@ def _build_databackend_impl(uri, mapping, type: str = 'data_backend'):
return mapping['sqlalchemy'](sql_conn, name)


def _build_compute(compute):
def build_compute(compute):
"""
Helper function to build compute backend.
:param compute: Compute uri.
"""
logging.info("Connecting to compute client:", compute)

if compute is None:
Expand Down Expand Up @@ -180,7 +187,7 @@ def build_datalayer(cfg=None, databackend=None, **kwargs) -> Datalayer:
assert metadata

artifact_store = _build_artifact_store(cfg.artifact_store, databackend)
compute = _build_compute(cfg.cluster.compute.uri)
compute = build_compute(cfg.cluster.compute.uri)

datalayer = Datalayer(
databackend=databackend,
Expand Down
2 changes: 1 addition & 1 deletion superduperdb/base/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ class Config(BaseConfig):

envs: dc.InitVar[t.Optional[t.Dict[str, str]]] = None

data_backend: str = 'mongodb://localhost:27017/test_db'
data_backend: str = 'mongodb://mongodb:27017/test_db'
lance_home: str = os.path.join('.superduperdb', 'vector_indices')

artifact_store: t.Optional[str] = None
Expand Down
4 changes: 3 additions & 1 deletion superduperdb/ext/sentence_transformers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,15 +88,17 @@ def predict(self, X, *args, **kwargs):
result = self.postprocess(result)
return result

@t.no_type_check
@ensure_initialized
def predict_batches(self, dataset: t.Union[t.List, QueryDataset]) -> t.List:
"""Predict on a dataset.
:param dataset: The dataset to predict on.
"""
if self.preprocess is not None:
dataset = list(map(self.preprocess, dataset)) # type: ignore[arg-type]
dataset = list(map(self.preprocess, dataset))
assert self.object is not None

results = self.object.encode(dataset, **self.predict_kwargs)
if self.postprocess is not None:
results = self.postprocess(results)
Expand Down
Loading

0 comments on commit 00ac16a

Please sign in to comment.