Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove agent context from Python API #441

Merged
merged 1 commit into from
Sep 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,12 @@
Source,
Processor,
CommitCallback,
TopicConsumer,
TopicProducer,
AgentContext,
)
from .util import SimpleRecord, SingleRecordProcessor

__all__ = [
"Record",
"RecordType",
"TopicConsumer",
"TopicProducer",
"AgentContext",
"Agent",
"Source",
"Sink",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,16 @@
#

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Tuple, Dict, Union, Optional
from typing import Any, List, Tuple, Dict, Union

__all__ = [
"Record",
"RecordType",
"AgentContext",
"Agent",
"Source",
"Sink",
"Processor",
"CommitCallback",
"TopicConsumer",
"TopicProducer",
]


Expand Down Expand Up @@ -65,67 +61,6 @@ def headers(self) -> List[Tuple[str, Any]]:
RecordType = Union[Record, list, tuple]


class TopicConsumer(ABC):
"""The topic consumer interface"""

def start(self):
"""Start the consumer."""
pass

def close(self):
"""Close the consumer"""
pass

def read(self) -> List[Record]:
"""Read records from the topic."""
return []

def commit(self, records: List[Record]):
"""Commit records."""
pass

def get_native_consumer(self) -> Any:
"""Return the native wrapped consumer"""
pass

def get_info(self) -> Dict[str, Any]:
"""Return the consumer info"""
return {}


class TopicProducer(ABC):
"""The topic producer interface"""

def start(self):
"""Start the producer."""
pass

def close(self):
"""Close the producer."""
pass

def write(self, records: List[Record]):
"""Write records to the topic."""
pass

def get_native_producer(self) -> Any:
"""Return the native wrapped producer"""
pass

def get_info(self) -> Dict[str, Any]:
"""Return the producer info"""
return {}


@dataclass
class AgentContext(object):
"""The Agent context"""

topic_consumer: Optional[TopicConsumer] = None
topic_producer: Optional[TopicProducer] = None
global_agent_id: Optional[str] = None


class Agent(ABC):
"""The Agent interface"""

Expand All @@ -145,10 +80,6 @@ def agent_info(self) -> Dict[str, Any]:
"""Return the agent information."""
return {}

def set_context(self, context: AgentContext):
"""Set the agent context."""
pass


class Source(Agent):
"""The Source agent interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,11 @@
#

from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, List, Tuple, Dict, Union, Optional
from typing import Any, List, Tuple, Dict, Union

__all__ = [
"Record",
"RecordType",
"AgentContext",
"Agent",
"Source",
"Sink",
Expand Down Expand Up @@ -117,15 +115,6 @@ def get_info(self) -> Dict[str, Any]:
return {}


@dataclass
class AgentContext(object):
"""The Agent context"""

topic_consumer: Optional[TopicConsumer] = None
topic_producer: Optional[TopicProducer] = None
global_agent_id: Optional[str] = None


class Agent(ABC):
"""The Agent interface"""

Expand All @@ -145,10 +134,6 @@ def agent_info(self) -> Dict[str, Any]:
"""Return the agent information."""
return {}

def set_context(self, context: AgentContext):
"""Set the agent context."""
pass


class Source(Agent):
"""The Source agent interface
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
Processor,
Record,
CommitCallback,
AgentContext,
RecordType,
)
from .util import SingleRecordProcessor, SimpleRecord
Expand Down Expand Up @@ -92,7 +91,7 @@ class ComponentType(str, Enum):
class RuntimeAgent(Agent):
def __init__(
self,
agent: Agent,
agent: Union[Source, Sink, Processor],
component_type: ComponentType,
agent_id,
agent_type,
Expand Down Expand Up @@ -136,9 +135,6 @@ def get_agent_status(self) -> List[Dict[str, Any]]:
}
]

def set_context(self, context: AgentContext):
call_method_if_exists(self.agent, "set_context", context)


class RuntimeSource(RuntimeAgent, Source):
def __init__(self, source: Source, agent_id, agent_type, started_at=None):
Expand Down Expand Up @@ -329,15 +325,6 @@ def run(configuration, agent=None, agent_info: AgentInfo = AgentInfo(), max_loop

agent_info.processor = processor

agent_context = AgentContext(
topic_consumer=consumer,
topic_producer=producer,
global_agent_id=application_agent_id,
)

for component in {a.agent: a for a in {source, sink, processor}}.values():
component.set_context(agent_context)

run_main_loop(
source,
sink,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
Record,
SimpleRecord,
SingleRecordProcessor,
AgentContext,
)
from langstream.api import RecordType
from langstream_runtime import runtime
Expand Down Expand Up @@ -67,12 +66,6 @@ def test_simple_agent(_):
"start": 1,
"close": 1,
"set_commit_callback": 1,
"set_context": 1,
"agent_context": {
"consumer": "<class 'langstream_runtime.runtime.NoopTopicConsumer'>",
"producer": "<class 'langstream_runtime.runtime.NoopTopicProducer'>",
"global-agent-id": "testApplicationId-testAgentId"
},
"records": [
"some record 0 processed",
"some record 1 processed"
Expand All @@ -98,12 +91,6 @@ def test_simple_agent(_):
"start": 1,
"close": 1,
"set_commit_callback": 1,
"set_context": 1,
"agent_context": {
"consumer": "<class 'langstream_runtime.runtime.NoopTopicConsumer'>",
"producer": "<class 'langstream_runtime.runtime.NoopTopicProducer'>",
"global-agent-id": "testApplicationId-testAgentId"
},
"records": [
"some record 0 processed",
"some record 1 processed"
Expand All @@ -129,12 +116,6 @@ def test_simple_agent(_):
"start": 1,
"close": 1,
"set_commit_callback": 1,
"set_context": 1,
"agent_context": {
"consumer": "<class 'langstream_runtime.runtime.NoopTopicConsumer'>",
"producer": "<class 'langstream_runtime.runtime.NoopTopicProducer'>",
"global-agent-id": "testApplicationId-testAgentId"
},
"records": [
"some record 0 processed",
"some record 1 processed"
Expand All @@ -158,8 +139,6 @@ def __init__(self):
"start": 0,
"close": 0,
"set_commit_callback": 0,
"set_context": 0,
"agent_context": {},
"records": [],
}
self.commit_callback = None
Expand All @@ -186,14 +165,6 @@ def write(self, records):
def process_record(self, record: Record) -> List[RecordType]:
return [(record.value() + " processed",)]

def set_context(self, agent_context: AgentContext):
self.context["set_context"] += 1
self.context["agent_context"] = {
"consumer": str(type(agent_context.topic_consumer)),
"producer": str(type(agent_context.topic_producer)),
"global-agent-id": agent_context.global_agent_id,
}

def agent_info(self) -> Dict[str, Any]:
return self.context

Expand Down
Loading