# Building a Custom Environment in Aviary

## Defining a Custom Environment

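The example below walks through defining a custom language agent environment in Aviary.
We define a simple environment in which an agent calls tools to increment or decrement a counter; the episode ends, with a reward of 1, once the counter reaches the target value of 10.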

In [None]:
from pydantic import BaseModel

from aviary.core import Environment, Message, Tool, ToolRequestMessage


# State in this example is simply a counter
class CounterEnvState(BaseModel):
    # Current counter value: CounterEnv.reset() sets it to 0, and the
    # incr/decr tools each move it by one.
    count: int


class CounterEnv(Environment[CounterEnvState]):
    """A simple environment that allows an agent to modify a counter.

    The agent is offered two tools, ``incr`` and ``decr``, each of which moves
    the counter by one. An episode finishes, with a reward of 1.0, on the first
    step after which the counter equals ``target``.
    """

    async def reset(self) -> tuple[list[Message], list[Tool]]:
        """Initialize the environment with a counter set to 0.

        Returns:
            A two-tuple of the initial observation messages (a single message
            stating the goal and the current counter) and the tools the agent
            may call.
        """
        self.state = CounterEnvState(count=0)

        # Target count
        self.target = 10

        # Create tools allowing the agent to increment and decrement counter
        self.tools = [
            Tool.from_function(self.incr),
            Tool.from_function(self.decr),
        ]

        # Build the prompt from self.target so the instruction shown to the
        # agent can never drift from the value step() actually checks.
        initial_obs = Message(
            content=f"Count to {self.target}. counter={self.state.count}"
        )
        return [initial_obs], self.tools

    async def step(
        self, action: ToolRequestMessage
    ) -> tuple[list[Message], float, bool, bool]:
        """Execute the tool calls requested by the agent.

        Args:
            action: The agent's tool request; every call in it is executed via
                ``exec_tool_calls``.

        Returns:
            A four-tuple ``(observations, reward, done, truncated)``. ``reward``
            is 1.0 when the counter equals the target after the calls and 0.0
            otherwise, ``done`` is True in exactly that case, and ``truncated``
            is always False.
        """
        obs = await self.exec_tool_calls(action)

        done = self.state.count == self.target
        # Convert explicitly so the reward matches the annotated float type.
        reward = float(done)

        # Returns observations, reward, done, truncated
        return obs, reward, done, False

    def incr(self) -> str:
        """Increment the counter."""
        self.state.count += 1
        return f"counter={self.state.count}"

    def decr(self) -> str:
        """Decrement the counter."""
        self.state.count -= 1
        return f"counter={self.state.count}"

## Evaluating an Agent on the Environment

Following the definition of our custom environment, we can now evaluate a language agent
on the environment using Aviary's sister library LDP (https://github.com/Future-House/ldp).

In [None]:
from typing import Any, NamedTuple

from ldp.agent import Agent
from ldp.alg import RolloutManager
from ldp.graph import LLMCallOp, OpResult


class AgentState(NamedTuple):
    """A tuple-based container for the agent's state between interactions."""

    # Conversation history accumulated so far (observations and actions)
    messages: list[Message]
    # Tools the agent is allowed to call
    tools: list[Tool]


class SimpleAgent(Agent):
    """A minimal agent that picks each action with a single LLM call."""

    def __init__(self, **kwargs: Any) -> None:
        """Create the agent.

        Args:
            **kwargs: Forwarded unchanged to ``LLMCallOp``.
        """
        self._llm_call_op = LLMCallOp(**kwargs)

    async def init_state(self, tools: list[Tool]) -> AgentState:
        """Return a fresh state: empty message history plus the given tools."""
        # With namedtuple, you instantiate it like a regular class/function call
        return AgentState(messages=[], tools=tools)

    async def get_asv(
        self, agent_state: AgentState, obs: list[Message]
    ) -> tuple[OpResult[ToolRequestMessage], AgentState, float]:
        """Take an action, observe new state, return value.

        Args:
            agent_state: State carried over from the previous step.
            obs: New observation messages to append to the history.

        Returns:
            A three-tuple of the LLM op's result (the tool request itself is
            its ``.value``), the state extended with ``obs`` and that tool
            request, and a value estimate that is always 0.0.
        """
        # Build the prompt once; the same history is reused for the new state.
        history = agent_state.messages + obs
        action: OpResult[ToolRequestMessage] = await self._llm_call_op(
            config={"name": "gpt-4o", "temperature": 0.1},
            msgs=history,
            tools=agent_state.tools,
        )
        new_state = AgentState(
            messages=[*history, action.value],
            tools=agent_state.tools,
        )
        # Return action, state, value
        return action, new_state, 0.0

In [None]:
# Create a simple agent and perform rollouts on the environment
agent: SimpleAgent = SimpleAgent()

runner: RolloutManager = RolloutManager(agent=agent)

trajectories: list[tuple] = await runner.sample_trajectories(
    environment_factory=CounterEnv,
    batch_size=2,
)

In [None]:
# Display the sampled rollouts as this cell's output
trajectories

[(Trajectory(traj_id='7c038fd5a0c8429d9b6f45a42f3120f1', steps=[Transition(timestep=0, agent_state=AgentState(messages=[], tools=[Tool(type='function', info=FunctionInfo(name='incr', description='Increment the counter.', parameters=Parameters(type='object', properties={}, required=[]))), Tool(type='function', info=FunctionInfo(name='decr', description='Decrement the counter.', parameters=Parameters(type='object', properties={}, required=[])))]), next_agent_state=AgentState(messages=[Message(role='user', content='Count to 10. counter=0'), ToolRequestMessage(role='assistant', content=None, function_call=None, tool_calls=[ToolCall(id='call_szAYyJep11zrOcWIVgI03oyx', type='function', function=ToolCallFunction(arguments={}, name='incr')), ToolCall(id='call_IHk77OWwlwMnqQ8dzYqxb8nC', type='function', function=ToolCallFunction(arguments={}, name='incr')), ToolCall(id='call_d1Ttcj52VSJVOiApRpW0QfBM', type='function', function=ToolCallFunction(arguments={}, name='incr')), ToolCall(id='call_eQ1B

# End