Use Python asyncio gRPC
cbornet committed Nov 16, 2023
1 parent e235660 commit fb38344
Showing 9 changed files with 196 additions and 221 deletions.
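At its core, the change swaps the thread-pool-backed grpc.server for grpc.aio.server and turns the servicer methods into coroutines that push blocking agent code onto worker threads. The minimal sketch below shows that pattern in isolation; EchoService and the commented-out registration call are illustrative placeholders, not modules from this repository.

import asyncio

import grpc


class EchoService:
    # Stand-in servicer; a real one subclasses the generated *Servicer base class.
    async def Echo(self, request, context):
        # Blocking work is pushed off the event loop with asyncio.to_thread,
        # mirroring how the agent's read/process/write calls are handled below.
        return await asyncio.to_thread(lambda: request)


async def serve() -> None:
    server = grpc.aio.server()  # replaces grpc.server(ThreadPoolExecutor(...))
    port = server.add_insecure_port("[::]:0")
    # echo_pb2_grpc.add_EchoServiceServicer_to_server(EchoService(), server)
    await server.start()  # start, stop and wait_for_termination are now awaited
    print(f"Listening on port {port}")
    await server.wait_for_termination()


if __name__ == "__main__":
    asyncio.run(serve())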
@@ -20,4 +20,5 @@ tox = "*"
black = "*"
ruff = "*"
pytest = "*"
pytest-asyncio = "*"
grpcio-tools = "*"

@@ -14,11 +14,21 @@
# limitations under the License.
#

import asyncio
import logging
import sys

from langstream_grpc.grpc_service import AgentServer


async def main(target, config, context):
server = AgentServer(target)
await server.init(config, context)
await server.start()
await server.grpc_server.wait_for_termination()
await server.stop()


if __name__ == "__main__":
logging.addLevelName(logging.WARNING, "WARN")
logging.basicConfig(
@@ -34,7 +44,4 @@
)
sys.exit(1)

server = AgentServer(sys.argv[1], sys.argv[2], sys.argv[3])
server.start()
server.grpc_server.wait_for_termination()
server.stop()
asyncio.run(main(*sys.argv[1:]))
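For reference, the new coroutine entrypoint can be driven the same way from other async code. A usage sketch follows; the configuration payload is a made-up placeholder (a real config must at least provide the className of an importable agent):

import asyncio
import json

from langstream_grpc.grpc_service import AgentServer


async def run_agent_server():
    config = json.dumps({"className": "my_module.MyAgent"})  # hypothetical agent class
    context = json.dumps({})
    server = AgentServer("[::]:0")
    await server.init(config, context)
    await server.start()
    try:
        await server.grpc_server.wait_for_termination()
    finally:
        await server.stop()


# asyncio.run(run_agent_server())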
@@ -14,7 +14,7 @@
# limitations under the License.
#

import concurrent
import asyncio
import importlib
import json
import os
@@ -23,7 +23,7 @@
import threading
from concurrent.futures import Future
from io import BytesIO
from typing import Iterable, Union, List, Tuple, Any, Optional, Dict
from typing import Union, List, Tuple, Any, Optional, Dict, AsyncIterable

import fastavro
import grpc
@@ -78,30 +78,28 @@ def __init__(self, agent: Union[Agent, Source, Sink, Processor]):
self.schemas = {}
self.client_schemas = {}

def agent_info(self, _, __):
info = call_method_if_exists(self.agent, "agent_info") or {}
async def agent_info(self, _, __):
info = await acall_method_if_exists(self.agent, "agent_info") or {}
return InfoResponse(json_info=json.dumps(info))

def get_topic_producer_records(self, request_iterator, context):
# TODO: to be implementedbla
for _ in request_iterator:
yield None
async def get_topic_producer_records(self, request_iterator, context):
# TODO: to be implemented
async for _ in request_iterator:
yield

def read(self, requests: Iterable[SourceRequest], _):
async def read(self, requests: AsyncIterable[SourceRequest], _):
read_records = {}
op_result = []
read_thread = threading.Thread(
target=self.handle_read_requests,
args=(requests, read_records, op_result),
)
last_record_id = 0
read_thread.start()
read_requests_task = asyncio.create_task(
self.handle_read_requests(requests, read_records, op_result)
)
while True:
if len(op_result) > 0:
if op_result[0] is True:
break
raise op_result[0]
records = self.agent.read()
records = await asyncio.to_thread(self.agent.read)
if len(records) > 0:
records = [wrap_in_record(record) for record in records]
grpc_records = []
@@ -115,25 +113,25 @@ def read(self, requests: Iterable[SourceRequest], _):
grpc_records[i].record_id = last_record_id
read_records[last_record_id] = record
yield SourceResponse(records=grpc_records)
read_thread.join()
read_requests_task.cancel()

def handle_read_requests(
async def handle_read_requests(
self,
requests: Iterable[SourceRequest],
requests: AsyncIterable[SourceRequest],
read_records: Dict[int, Record],
read_result,
):
try:
for request in requests:
async for request in requests:
if len(request.committed_records) > 0:
for record_id in request.committed_records:
record = read_records.pop(record_id, None)
if record is not None:
call_method_if_exists(self.agent, "commit", record)
await acall_method_if_exists(self.agent, "commit", record)
if request.HasField("permanent_failure"):
failure = request.permanent_failure
record = read_records.pop(failure.record_id, None)
call_method_if_exists(
await acall_method_if_exists(
self.agent,
"permanent_failure",
record,
@@ -159,83 +157,49 @@ def handle_requests(handler, requests):
pass
thread.join()

def process(self, requests: Iterable[ProcessorRequest], _):
return self.handle_requests(self.handle_process_requests, requests)

def process_record(
self, source_record, get_processed_fn, get_processed_args, process_results
):
grpc_result = ProcessorResult(record_id=source_record.record_id)
try:
processed_records = get_processed_fn(*get_processed_args)
if isinstance(processed_records, Future):
processed_records.add_done_callback(
lambda f: self.process_record(
source_record, f.result, (), process_results
)
)
else:
for record in processed_records:
schemas, grpc_record = self.to_grpc_record(wrap_in_record(record))
for schema in schemas:
process_results.put(ProcessorResponse(schema=schema))
grpc_result.records.append(grpc_record)
process_results.put(ProcessorResponse(results=[grpc_result]))
except Exception as e:
grpc_result.error = str(e)
process_results.put(ProcessorResponse(results=[grpc_result]))

def handle_process_requests(
self, requests: Iterable[ProcessorRequest], process_results
):
for request in requests:
async def process(self, requests: AsyncIterable[ProcessorRequest], _):
async for request in requests:
if request.HasField("schema"):
schema = fastavro.parse_schema(json.loads(request.schema.value))
self.client_schemas[request.schema.schema_id] = schema
if len(request.records) > 0:
for source_record in request.records:
self.process_record(
source_record,
lambda r: self.agent.process(self.from_grpc_record(r)),
(source_record,),
process_results,
)
process_results.put(True)

def write(self, requests: Iterable[SinkRequest], _):
return self.handle_requests(self.handle_write_requests, requests)

def write_record(
self, source_record, get_written_fn, get_written_args, write_results
):
try:
result = get_written_fn(*get_written_args)
if isinstance(result, Future):
result.add_done_callback(
lambda f: self.write_record(
source_record, f.result, (), write_results
)
)
else:
write_results.put(SinkResponse(record_id=source_record.record_id))
except Exception as e:
write_results.put(
SinkResponse(record_id=source_record.record_id, error=str(e))
)

def handle_write_requests(self, requests: Iterable[SinkRequest], write_results):
for request in requests:
grpc_result = ProcessorResult(record_id=source_record.record_id)
try:
processed_records = await asyncio.to_thread(
self.agent.process, self.from_grpc_record(source_record)
)
if isinstance(processed_records, Future):
processed_records = await asyncio.wrap_future(
processed_records
)
for record in processed_records:
schemas, grpc_record = self.to_grpc_record(
wrap_in_record(record)
)
for schema in schemas:
yield ProcessorResponse(schema=schema)
grpc_result.records.append(grpc_record)
yield ProcessorResponse(results=[grpc_result])
except Exception as e:
grpc_result.error = str(e)
yield ProcessorResponse(results=[grpc_result])

async def write(self, requests: AsyncIterable[SinkRequest], context):
async for request in requests:
if request.HasField("schema"):
schema = fastavro.parse_schema(json.loads(request.schema.value))
self.client_schemas[request.schema.schema_id] = schema
if request.HasField("record"):
self.write_record(
request.record,
lambda r: self.agent.write(self.from_grpc_record(r)),
(request.record,),
write_results,
)
write_results.put(True)
try:
result = await asyncio.to_thread(
self.agent.write, self.from_grpc_record(request.record)
)
if isinstance(result, Future):
await asyncio.wrap_future(result)
yield SinkResponse(record_id=request.record.record_id)
except Exception as e:
yield SinkResponse(record_id=request.record.record_id, error=str(e))

def from_grpc_record(self, record: GrpcRecord) -> SimpleRecord:
return RecordWithId(
@@ -333,6 +297,12 @@ def call_method_if_exists(klass, method, *args, **kwargs):
return None


async def acall_method_if_exists(klass, method, *args, **kwargs):
return await asyncio.to_thread(
call_method_if_exists, klass, method, *args, **kwargs
)


class MainExecutor(threading.Thread):
def __init__(self, onError, klass, method, *args, **kwargs):
threading.Thread.__init__(self)
@@ -364,17 +334,16 @@ def call_method_new_thread_if_exists(klass, methodName, *args, **kwargs):
def crash_process():
logging.error("Main method with an error. Exiting process.")
os.exit(1)
return


def init_agent(configuration, context) -> Agent:
async def init_agent(configuration, context) -> Agent:
full_class_name = configuration["className"]
class_name = full_class_name.split(".")[-1]
module_name = full_class_name[: -len(class_name) - 1]
module = importlib.import_module(module_name)
agent = getattr(module, class_name)()
context_impl = DefaultAgentContext(configuration, context)
call_method_if_exists(agent, "init", configuration, context_impl)
await acall_method_if_exists(agent, "init", configuration, context_impl)
return agent


@@ -388,36 +357,35 @@ def get_persistent_state_directory(self) -> Optional[str]:


class AgentServer(object):
def __init__(self, target: str, config: str, context: str):
self.thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=10)
def __init__(self, target: str):
self.target = target
self.grpc_server = grpc.server(self.thread_pool)
self.grpc_server = grpc.aio.server()
self.port = self.grpc_server.add_insecure_port(target)
self.agent = None

async def init(self, config, context):
configuration = json.loads(config)
logging.debug("Configuration: " + json.dumps(configuration))
environment = configuration.get("environment", [])
logging.debug("Environment: " + json.dumps(environment))

for env in environment:
key = env["key"]
value = env["value"]
logging.debug(f"Setting environment variable {key}={value}")
os.environ[key] = value
self.agent = await init_agent(configuration, json.loads(context))

self.agent = init_agent(configuration, json.loads(context))

def start(self):
call_method_if_exists(self.agent, "start")
async def start(self):
await acall_method_if_exists(self.agent, "start")
call_method_new_thread_if_exists(self.agent, "main", crash_process)

agent_pb2_grpc.add_AgentServiceServicer_to_server(
AgentService(self.agent), self.grpc_server
)
self.grpc_server.start()

await self.grpc_server.start()
logging.info("GRPC Server started, listening on " + self.target)

def stop(self):
self.grpc_server.stop(None)
call_method_if_exists(self.agent, "close")
self.thread_pool.shutdown(wait=True)
async def stop(self):
await self.grpc_server.stop(None)
await acall_method_if_exists(self.agent, "close")
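The recurring pattern in this file is to offload the agent's (potentially blocking) callbacks with asyncio.to_thread and, when an agent returns a concurrent.futures.Future, to bridge it back onto the event loop with asyncio.wrap_future. A small self-contained illustration of both, using a made-up blocking_process function:

import asyncio
from concurrent.futures import Future, ThreadPoolExecutor


def blocking_process(record):
    # Stand-in for a synchronous agent process()/write() implementation.
    return [record]


async def handle(record, executor: ThreadPoolExecutor):
    # Plain blocking call: run it on a worker thread so the gRPC event loop stays free.
    processed = await asyncio.to_thread(blocking_process, record)

    # The agent returned a concurrent.futures.Future: wrap it so it can be awaited.
    fut: Future = executor.submit(blocking_process, record)
    also_processed = await asyncio.wrap_future(fut)
    return processed, also_processed


# Example usage:
# with ThreadPoolExecutor() as pool:
#     asyncio.run(handle({"value": 42}, pool))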
@@ -29,18 +29,17 @@ def __init__(self, class_name, agent_config={}, context={}):
self.config["className"] = class_name
self.context = context
self.server: Optional[AgentServer] = None
self.channel: Optional[grpc.Channel] = None
self.channel: Optional[grpc.aio.Channel] = None
self.stub: Optional[AgentServiceStub] = None

def __enter__(self):
self.server = AgentServer(
"[::]:0", json.dumps(self.config), json.dumps(self.context)
)
self.server.start()
self.channel = grpc.insecure_channel("localhost:%d" % self.server.port)
async def __aenter__(self):
self.server = AgentServer("[::]:0")
await self.server.init(json.dumps(self.config), json.dumps(self.context))
await self.server.start()
self.channel = grpc.aio.insecure_channel("localhost:%d" % self.server.port)
self.stub = AgentServiceStub(channel=self.channel)
return self

def __exit__(self, *args):
self.channel.close()
self.server.stop()
async def __aexit__(self, *args):
await self.channel.close()
await self.server.stop()
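With pytest-asyncio (added to the dev dependencies above), tests are expected to exercise these coroutines from async test functions. A sketch of such a test follows; rather than guessing the name of the helper class shown above, it drives AgentServer directly, and the agent class name is a placeholder, not taken from this commit:

import json

import grpc
import pytest

from langstream_grpc.grpc_service import AgentServer


@pytest.mark.asyncio
async def test_server_lifecycle():
    server = AgentServer("[::]:0")
    # "tests.my_agents.MyProcessor" is a placeholder; point it at a real agent class.
    config = json.dumps({"className": "tests.my_agents.MyProcessor"})
    await server.init(config, json.dumps({}))
    await server.start()
    async with grpc.aio.insecure_channel(f"localhost:{server.port}") as channel:
        await channel.channel_ready()  # the aio channel connects to the aio server
    await server.stop()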