## Defining a Custom Environment

The example below walks through defining a custom language agent environment in Aviary. 
We define a simple environment where an agent takes actions to modify a counter.

In [22]:
from collections import namedtuple
from aviary.core import Environment, Message, ToolRequestMessage, Tool
from pydantic import BaseModel

# The environment's entire state is one integer counter.
class CounterEnvState(BaseModel):
    count: int

class CounterEnv(Environment[CounterEnvState]):
    """Toy environment whose only state is an integer counter.

    The agent is offered two tools, ``incr`` and ``decr``, which raise or
    lower the counter by one.
    """

    async def reset(self):
        """Start a new episode with the counter at zero.

        Returns:
            A pair of (initial observation messages, available tools).
        """
        self.state = CounterEnvState(count=0)

        # Wrap the bound methods so that tool calls mutate this instance.
        self.tools = [Tool.from_function(fn) for fn in (self.incr, self.decr)]

        # The first observation reports the freshly reset counter value.
        initial_obs = Message(content=f"counter={self.state.count}")
        return [initial_obs], self.tools

    async def step(self, action: ToolRequestMessage):
        """Run the requested tool call(s) and score the resulting state.

        Returns:
            A 4-tuple of (observations, reward, done, truncated).
        """
        observations = await self.exec_tool_calls(action)

        # Reward is the squared counter value after the tools have run.
        reward = self.state.count ** 2

        # NOTE(review): a square is never negative, so `done` is always False
        # here -- confirm that this is the intended termination behaviour.
        done = reward < 0
        truncated = False
        return observations, reward, done, truncated

    def incr(self):
        """Increment the counter."""
        self.state.count = self.state.count + 1
        return f"counter={self.state.count}"

    def decr(self):
        """Decrement the counter."""
        self.state.count = self.state.count - 1
        return f"counter={self.state.count}"

## Evaluating an Agent on the Environment

Following the definition of our custom environment, we can now evaluate a language agent
on the environment using Aviary's sister library LDP (https://github.com/Future-House/ldp).

In [None]:
from ldp.agent import Agent
from ldp.graph import LLMCallOp
from ldp.alg.rollout import RolloutManager

class AgentState:
    """Plain holder for what the agent carries from one step to the next.

    Attributes:
        messages: The value passed as ``messages``, stored as-is (not copied).
        tools: The value passed as ``tools``, stored as-is (not copied).
    """

    def __init__(self, messages, tools):
        # Store the references directly; callers decide whether to copy.
        self.messages, self.tools = messages, tools

class SimpleAgent(Agent):
    """Minimal agent that forwards the conversation to a single LLM call op."""

    def __init__(self, **kwargs):
        # All keyword arguments are handed straight to the LLM call op.
        self._llm_call_op = LLMCallOp(**kwargs)

    async def init_state(self, tools):
        """Build the starting state: no message history plus the given tools."""
        return AgentState(messages=[], tools=tools)

    async def get_asv(self, agent_state, obs):
        """Pick the next action and return (action, new agent state, value).

        The LLM is prompted with the stored messages followed by the new
        observations. The returned state appends both the observations and
        the chosen action's value to the history. The value estimate is a
        constant 0.0.
        """
        llm_config = {"name": "gpt-4o", "temperature": 0.1}
        prompt = agent_state.messages + obs
        action = await self._llm_call_op(
            config=llm_config,
            msgs=prompt,
            tools=agent_state.tools,
        )

        # Rebuild the history from the original pieces so the list given to
        # the LLM op is never shared with the new state.
        updated_messages = agent_state.messages + obs + [action.value]
        next_state = AgentState(updated_messages, agent_state.tools)

        # This simple agent has no value function, hence the fixed 0.0.
        return action, next_state, 0.0

import litellm
# Modify the MAX_CALLBACKS constant before using LiteLLM
# Note this is currently a known issue in litellm: https://github.com/BerriAI/litellm/issues/9792
# This overrides an attribute on litellm's callback manager, so it must run
# before the first LLM call below.
litellm.utils.logging_callback_manager.MAX_CALLBACKS = 100  # Set a higher limit

# Create a simple agent and perform rollouts on the environment
# No keyword arguments are given, so LLMCallOp is constructed with its defaults.
agent = SimpleAgent()

runner = RolloutManager(agent=agent)

# The CounterEnv class itself is passed as the factory, not an instance.
# Top-level `await` works here because this is a notebook cell.
# NOTE(review): CounterEnv.step never reports done (count ** 2 is never
# negative) -- confirm the rollout terminates, e.g. via a step limit.
# NOTE(review): batch_size presumably sets how many rollouts run -- verify.
trajectories = await runner.sample_trajectories(
    environment_factory=CounterEnv, 
    batch_size=2,
)