In [37]:
import time
import asyncio
import requests
import starlette

import ray
from ray import serve
from ray.experimental.dag.input_node import InputNode
from ray.serve.drivers import DAGDriver
from ray.serve.http_adapters import json_request

In [38]:
@serve.deployment
async def avg_preprocessor(input_data):
    """Return the arithmetic mean of ``input_data``.

    Args:
        input_data: A sized collection of numbers; it is passed to both
            ``sum`` and ``len``.

    Returns:
        ``sum(input_data) / len(input_data)``.

    Raises:
        ZeroDivisionError: If ``input_data`` is empty.
    """
    # Artificial 0.15 s pause standing in for real preprocessing work
    await asyncio.sleep(0.15)
    total = sum(input_data)
    count = len(input_data)
    return total / count

In [39]:
@serve.deployment
class Model:
    """Toy model that renders "weight * input" as a string instead of computing it."""

    def __init__(self, weight: int):
        # Stored as-is and interpolated into the string built by `forward`.
        self.weight = weight

    async def forward(self, input: int):
        """Return the string ``"(<weight> * <input>)"`` after a simulated delay.

        NOTE(review): `input` is annotated as ``int``, but `Combiner.run` in
        this notebook passes a string; the value is only string-formatted, so
        either works.
        """
        # Artificial 0.3 s pause standing in for real inference work
        await asyncio.sleep(0.3)
        return "({} * {})".format(self.weight, input)

In [40]:
@serve.deployment
class Combiner:
    """Deployment that forwards a preprocessed value to a model and labels the result."""

    def __init__(self, m: Model):
        # Downstream model; `run` invokes `forward.remote(...)` on it.
        self.m = m

    async def run(self, req_part, operation):
        """Send ``req_part`` to the model and wrap the gathered result in a label.

        Args:
            req_part: Value embedded into the request string sent to the model.
            operation: ``"sum"`` selects the ``sum(...)`` label; any other
                value selects ``max(...)``.

        Returns:
            A string ``"sum(<results>)"`` or ``"max(<results>)"``, where
            ``<results>`` is the string form of the list returned by
            ``asyncio.gather``.
        """
        # Build the model request; the opening parenthesis is not closed here.
        model_request = "({}".format(req_part)

        # Dispatch the request to the model
        pending = self.m.forward.remote(model_request)

        # gather() yields a list, so `results` holds a single element
        results = await asyncio.gather(pending)

        # User-supplied operation only decides which label wraps the results
        label = "sum" if operation == "sum" else "max"
        return "{}({})".format(label, results)

In [41]:
# Assemble the deployment graph
with InputNode() as dag_input:
    # Element 0 of the user input feeds the averaging preprocessor
    preprocessed_2 = avg_preprocessor.bind(dag_input[0])

    # Model node whose constructor receives weight=1
    m1 = Model.bind(1)

    # The model node is handed to Combiner's constructor via bind()
    combiner = Combiner.bind(m1)

    # Combiner.run gets the preprocessor output plus element 1 of the user input
    dag = combiner.run.bind(preprocessed_2, dag_input[1])

    # DAGDriver is the ingress deployment: two replicas under "/my-dag",
    # with json_request as the HTTP adapter.
    driver_deployment = DAGDriver.options(route_prefix="/my-dag", num_replicas=2)
    serve_dag = driver_deployment.bind(dag, http_adapter=json_request)

In [42]:
# Run the graph built above; the returned handle is kept so later cells can
# call `dag_handle.predict.remote(...)` on it.
dag_handle = serve.run(serve_dag)

[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:18:41,859 controller 57869 checkpoint_path.py:17 - Using RayInternalKVStore for controller checkpoint and recovery.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:18:41,964 controller 57869 http_state.py:112 - Starting HTTP proxy with name 'SERVE_CONTROLLER_ACTOR:SERVE_PROXY_ACTOR-node:127.0.0.1-0' on node 'node:127.0.0.1-0' listening on '127.0.0.1:8000'
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:18:42,488 controller 57869 deployment_state.py:1216 - Adding 1 replicas to deployment 'avg_preprocessor'.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:18:42,494 controller 57869 deployment_state.py:1216 - Adding 1 replicas to deployment 'Model'.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:18:42,500 controller 57869 deployment_state.py:1216 - Adding 1 replicas to deployment 'Combiner'.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:18:42,505 controller 57869 depl

In [43]:
# Warm up: send one request through the whole graph.
# Payload layout matches the DAG inputs: index 0 ([1, 2]) goes to
# avg_preprocessor, index 1 ("sum") is the `operation` passed to Combiner.run.
ray.get(dag_handle.predict.remote([[1, 2], "sum"]))

[2m[36m(DAGDriver pid=57877)[0m You are retrieving a sync handle inside an asyncio loop. Try getting client.get_handle(.., sync=False) to get better performance. Learn more at https://docs.ray.io/en/master/serve/http-servehandle.html#sync-and-async-handles
[2m[36m(DAGDriver pid=57877)[0m You are retrieving a sync handle inside an asyncio loop. Try getting client.get_handle(.., sync=False) to get better performance. Learn more at https://docs.ray.io/en/master/serve/http-servehandle.html#sync-and-async-handles
[2m[36m(avg_preprocessor pid=57873)[0m INFO 2022-07-05 17:18:45,170 avg_preprocessor avg_preprocessor#LigqIN replica.py:478 - HANDLE __call__ OK 151.3ms
[2m[36m(Combiner pid=57875)[0m You are retrieving a sync handle inside an asyncio loop. Try getting client.get_handle(.., sync=False) to get better performance. Learn more at https://docs.ray.io/en/master/serve/http-servehandle.html#sync-and-async-handles


"sum(['(1 * (1.5)'])"

[2m[36m(Model pid=57874)[0m INFO 2022-07-05 17:18:45,488 Model Model#uQCfEK replica.py:478 - HANDLE forward OK 302.0ms
[2m[36m(Combiner pid=57875)[0m INFO 2022-07-05 17:18:45,493 Combiner Combiner#EGMgfR replica.py:478 - HANDLE run OK 318.6ms
[2m[36m(DAGDriver pid=57877)[0m INFO 2022-07-05 17:18:45,496 DAGDriver DAGDriver#xewkyl replica.py:478 - HANDLE predict OK 492.4ms


In [44]:
# Shut Serve down; the log output below shows the replicas of every
# deployment being removed.
serve.shutdown()

[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:19:12,885 controller 57869 deployment_state.py:1240 - Removing 1 replicas from deployment 'avg_preprocessor'.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:19:12,893 controller 57869 deployment_state.py:1240 - Removing 1 replicas from deployment 'Model'.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:19:12,897 controller 57869 deployment_state.py:1240 - Removing 1 replicas from deployment 'Combiner'.
[2m[36m(ServeController pid=57869)[0m INFO 2022-07-05 17:19:12,899 controller 57869 deployment_state.py:1240 - Removing 2 replicas from deployment 'DAGDriver'.
