# Tutorial 3: Building the Environment

**Goal:** Create a multi-agent environment using HERON's `MultiAgentEnv` base class.

**Time:** ~15 minutes

---

## The Environment's Role

The environment connects your agent hierarchy to RL training frameworks:

```
RLlib / Stable-Baselines3
    |
    +-- MultiAgentEnv (step, reset)
            |
            +-- SystemAgent (L3) <-> ProxyAgent (state management)
                    |
                    +-- CoordinatorAgent (L2) - Microgrid
                    |       +-- FieldAgent (L1) - Battery
                    |       +-- FieldAgent (L1) - Generator
                    +-- CoordinatorAgent (L2) - Microgrid
                            +-- ...
```

**Key Pattern:** The environment receives a pre-built `system_agent` in its constructor. The `ProxyAgent` handles state management and visibility filtering automatically.

## Step 1: Define Simple Agents

First, let's create minimal agents for this tutorial. In real applications, you'd use the agents from Tutorial 2.

In [None]:
import numpy as np
from typing import Any, ClassVar, Dict, List, Optional, Sequence
from dataclasses import dataclass
from gymnasium.spaces import Box

from heron.core.feature import FeatureProvider
from heron.core.action import Action
from heron.agents.field_agent import FieldAgent
from heron.agents.coordinator_agent import CoordinatorAgent
from heron.agents.system_agent import SystemAgent
from heron.scheduling.tick_config import TickConfig


# === Features (using current API with ClassVar visibility) ===
@dataclass(slots=True)
class BatterySOC(FeatureProvider):
    """Feature carrying a battery's state of charge.

    Fields:
        soc: State of charge; defaults to 0.5.
        capacity: Battery capacity; defaults to 2.0.
    """

    # Class-level visibility tags (a ClassVar, so not a dataclass field).
    visibility: ClassVar[Sequence[str]] = ["owner", "upper_level"]

    soc: float = 0.5
    capacity: float = 2.0


@dataclass(slots=True)
class GenOutput(FeatureProvider):
    """Feature carrying a generator's output.

    Fields:
        p_mw: Current output; defaults to 0.0.
        p_max: Maximum output; defaults to 5.0.
    """

    # Class-level visibility tags (a ClassVar, so not a dataclass field).
    visibility: ClassVar[Sequence[str]] = ["owner", "upper_level", "system"]

    p_mw: float = 0.0
    p_max: float = 5.0


@dataclass(slots=True)
class SystemFrequency(FeatureProvider):
    """Feature carrying the system frequency.

    Fields:
        frequency_hz: Current frequency; defaults to 60.0.
        nominal_hz: Nominal frequency; defaults to 60.0.
    """

    # Class-level visibility tags (a ClassVar, so not a dataclass field).
    visibility: ClassVar[Sequence[str]] = ["system"]

    frequency_hz: float = 60.0
    nominal_hz: float = 60.0


# === L1: Field Agents ===
class SimpleBattery(FieldAgent):
    """Simple battery field agent.

    Owns a single ``BatterySOC`` feature and exposes one continuous action
    bounded to [-1, 1]; ``apply_action`` scales it by ``max_power``.
    """

    def __init__(self, agent_id: str, capacity: float = 2.0, **kwargs):
        """Create the battery and hand its feature list to ``FieldAgent``.

        Args:
            agent_id: Identifier forwarded to the parent constructor.
            capacity: Battery capacity; ``apply_action`` divides power by it.
            **kwargs: Forwarded unchanged to ``FieldAgent.__init__``.
        """
        # Attributes are assigned before super().__init__, as in the
        # original tutorial code.
        self.capacity = capacity
        self.max_power = 0.5
        # Create features to pass to parent
        features = [BatterySOC(soc=0.5, capacity=capacity)]
        super().__init__(agent_id=agent_id, features=features, **kwargs)

    def init_action(self, features: Optional[List[FeatureProvider]] = None) -> Action:
        """Return a 1-dimensional continuous action bounded to [-1, 1].

        ``features`` is accepted for signature compatibility and is unused.
        A ``None`` default replaces the former mutable ``[]`` default.
        """
        action = Action()
        action.set_specs(dim_c=1, range=(np.array([-1.0]), np.array([1.0])))
        return action

    def set_action(self, action: Any, *args, **kwargs) -> None:
        """Store ``action`` as the agent's current action value.

        Scalars (including 0-d NumPy arrays) are converted with ``float``;
        sequences contribute their first element, or 0.0 when empty.
        """
        # np.ndim is used instead of probing for ``__iter__`` because 0-d
        # arrays expose ``__iter__`` yet have no ``len()``.
        if np.ndim(action) == 0:
            self.action.set_values(float(action))
        else:
            self.action.set_values(action[0] if len(action) > 0 else 0.0)

    def set_state(self, **kwargs) -> None:
        """Update the SOC from the ``soc`` keyword, clipped to [0.1, 0.9].

        Other keywords are ignored.
        """
        if 'soc' in kwargs:
            self.state.features['BatterySOC'].soc = np.clip(float(kwargs['soc']), 0.1, 0.9)

    def apply_action(self) -> None:
        """Advance the SOC by the current action (simplified physics)."""
        soc_feature = self.state.features['BatterySOC']
        # Normalised action -> power, then power -> SOC change.
        power = self.action.vector()[0] * self.max_power
        new_soc = soc_feature.soc + power / self.capacity
        self.set_state(soc=new_soc)

    def compute_local_reward(self, local_state: dict) -> float:
        """Return 0.0; this simplified agent has no local reward."""
        return 0.0

    @property
    def cost(self) -> float:
        """Operating cost; always 0.0 for this battery."""
        return 0.0

    @property
    def safety(self) -> float:
        """Distance by which the SOC lies outside the [0.2, 0.8] band."""
        soc = self.state.features['BatterySOC'].soc
        # Float literals keep the result a float even when both terms are zero.
        return max(0.0, 0.2 - soc) + max(0.0, soc - 0.8)


class SimpleGen(FieldAgent):
    """Simple generator field agent.

    Owns a single ``GenOutput`` feature and exposes one continuous action
    bounded to [0, 1]; ``apply_action`` scales it by ``p_max``.
    """

    def __init__(self, agent_id: str, p_max: float = 5.0, cost_per_mwh: float = 50.0, **kwargs):
        """Create the generator and hand its feature list to ``FieldAgent``.

        Args:
            agent_id: Identifier forwarded to the parent constructor.
            p_max: Maximum output; multiplies the action in ``apply_action``.
            cost_per_mwh: Multiplier applied to ``p_mw`` by ``cost``.
            **kwargs: Forwarded unchanged to ``FieldAgent.__init__``.
        """
        # Attributes are assigned before super().__init__, as in the
        # original tutorial code.
        self.p_max = p_max
        self.cost_per_mwh = cost_per_mwh
        features = [GenOutput(p_mw=0.0, p_max=p_max)]
        super().__init__(agent_id=agent_id, features=features, **kwargs)

    def init_action(self, features: Optional[List[FeatureProvider]] = None) -> Action:
        """Return a 1-dimensional continuous action bounded to [0, 1].

        ``features`` is accepted for signature compatibility and is unused.
        A ``None`` default replaces the former mutable ``[]`` default.
        """
        action = Action()
        action.set_specs(dim_c=1, range=(np.array([0.0]), np.array([1.0])))
        return action

    def set_action(self, action: Any, *args, **kwargs) -> None:
        """Store ``action`` as the agent's current action value.

        Scalars (including 0-d NumPy arrays) are converted with ``float``;
        sequences contribute their first element, or 0.0 when empty.
        """
        # np.ndim is used instead of probing for ``__iter__`` because 0-d
        # arrays expose ``__iter__`` yet have no ``len()``.
        if np.ndim(action) == 0:
            self.action.set_values(float(action))
        else:
            self.action.set_values(action[0] if len(action) > 0 else 0.0)

    def set_state(self, **kwargs) -> None:
        """Update the output from the ``p_mw`` keyword; others are ignored."""
        if 'p_mw' in kwargs:
            self.state.features['GenOutput'].p_mw = float(kwargs['p_mw'])

    def apply_action(self) -> None:
        """Set the generator output to the current action times ``p_max``."""
        power = self.action.vector()[0] * self.p_max
        self.set_state(p_mw=power)

    def compute_local_reward(self, local_state: dict) -> float:
        """Return 0.0; this simplified agent has no local reward."""
        return 0.0

    @property
    def cost(self) -> float:
        """Current output multiplied by ``cost_per_mwh``."""
        p_mw = self.state.features['GenOutput'].p_mw
        return p_mw * self.cost_per_mwh

    @property
    def safety(self) -> float:
        """Safety penalty; always 0.0 for this generator."""
        return 0.0


# === L2: Coordinator Agent ===
class SimpleMicrogrid(CoordinatorAgent):
    """Simple microgrid coordinator with pre-built subordinates.

    Each instance owns one ``SimpleBattery`` (``<agent_id>_bat``) and one
    ``SimpleGen`` (``<agent_id>_gen``).
    """

    def __init__(self, agent_id: str, load: float = 3.0, **kwargs):
        """Build the two field agents and register them with the parent.

        Args:
            agent_id: Coordinator identifier; also the prefix of the
                subordinate ids and their ``upstream_id``.
            load: Load value stored on the instance as ``self.load``.
            **kwargs: Forwarded to ``CoordinatorAgent.__init__``. A
                ``tick_config`` entry overrides the 60-second default.
        """
        self.load = load

        # Build subordinates BEFORE calling super().__init__
        battery = SimpleBattery(f'{agent_id}_bat', upstream_id=agent_id)
        gen = SimpleGen(f'{agent_id}_gen', upstream_id=agent_id)
        subordinates = {battery.agent_id: battery, gen.agent_id: gen}

        # Treat the tick configuration as a default rather than a fixed
        # keyword, so passing tick_config=... no longer raises a
        # duplicate-keyword TypeError.
        kwargs.setdefault(
            'tick_config', TickConfig.deterministic(tick_interval=60.0)
        )

        super().__init__(
            agent_id=agent_id,
            subordinates=subordinates,
            **kwargs
        )

    @property
    def cost(self) -> float:
        """Sum of the subordinates' ``cost`` values."""
        return sum(s.cost for s in self.subordinates.values())

    @property
    def safety(self) -> float:
        """Sum of the subordinates' ``safety`` values."""
        return sum(s.safety for s in self.subordinates.values())


# === L3: System Agent ===
class SimpleGridSystem(SystemAgent):
    """Top-level system agent that aggregates several microgrids."""

    def __init__(self, agent_id: str, subordinates: Dict[str, SimpleMicrogrid]):
        """Register a ``SystemFrequency`` feature and the given microgrids.

        Args:
            agent_id: Identifier forwarded to ``SystemAgent``.
            subordinates: Mapping of microgrid id to coordinator agent.
        """
        super().__init__(
            agent_id=agent_id,
            features=[SystemFrequency(frequency_hz=60.0)],
            subordinates=subordinates,
            tick_config=TickConfig.deterministic(tick_interval=300.0),
        )

    @property
    def cost(self) -> float:
        """Total ``cost`` over all subordinate microgrids."""
        total = 0
        for microgrid in self.subordinates.values():
            total += microgrid.cost
        return total

    @property
    def safety(self) -> float:
        """Total ``safety`` over all subordinate microgrids."""
        total = 0
        for microgrid in self.subordinates.values():
            total += microgrid.safety
        return total


# Cell-completion message: printed once the agent classes have been defined.
print("Agents defined: FieldAgent -> CoordinatorAgent -> SystemAgent")

## Step 2: Build the Multi-Agent Environment

HERON's `MultiAgentEnv` provides:
- Agent registration and hierarchy management
- `ProxyAgent` for state management and visibility filtering
- Standard `reset()` and `step()` interface

You need to implement three abstract methods:
1. `global_state_to_env_state()` - Convert proxy state to your simulation state
2. `run_simulation()` - Run physics/power flow
3. `env_state_to_global_state()` - Convert simulation results back to proxy state

In [None]:
from heron.envs.base import MultiAgentEnv
from dataclasses import dataclass, field


@dataclass
class SimpleEnvState:
    """Plain container passed between HERON's global state and the simulation.

    Every mapping is keyed by microgrid id and starts out empty.

    Attributes:
        battery_power: Battery power setpoint per microgrid.
        gen_power: Generator power setpoint per microgrid.
        battery_soc: Battery state of charge per microgrid.
        frequency: System frequency; defaults to 60.0.
        converged: Convergence flag of the simulation; defaults to True.
    """

    # --- Setpoints ---
    battery_power: Dict[str, float] = field(default_factory=dict)
    gen_power: Dict[str, float] = field(default_factory=dict)

    # --- Simulation results ---
    battery_soc: Dict[str, float] = field(default_factory=dict)
    frequency: float = 60.0
    converged: bool = True


class SimpleMultiMicrogridEnv(MultiAgentEnv):
    """Tutorial environment wiring a HERON agent hierarchy to a toy simulation.

    Highlights:
    1. The constructor is handed an already-built ``system_agent``.
    2. State bookkeeping is delegated to the base class / proxy agent.
    3. Three hooks translate state to and from ``SimpleEnvState`` and run
       the simulation in between.
    """

    def __init__(
        self,
        system_agent: SystemAgent,
        max_steps: int = 96,
        share_reward: bool = True,
        penalty: float = 10.0,
        dt: float = 1.0,
    ):
        """Store the configuration and register the hierarchy.

        Args:
            system_agent: Fully constructed system agent (with subordinates).
            max_steps: Maximum steps per episode (stored; not read by the
                methods defined in this class).
            share_reward: Shared-reward flag (stored; not read here).
            penalty: Power-imbalance penalty coefficient (stored; not read
                here).
            dt: Time step duration (stored; not read here).
        """
        self.max_steps = max_steps
        self.share_reward = share_reward
        self.penalty = penalty
        self.dt = dt
        self._step_count = 0

        # Per-microgrid load, falling back to 3.0 when a coordinator has no
        # ``load`` attribute.
        self._mg_loads = {
            mg_id: getattr(coordinator, 'load', 3.0)
            for mg_id, coordinator in system_agent.subordinates.items()
        }

        # The parent constructor receives the hierarchy and the env id.
        super().__init__(system_agent=system_agent, env_id="simple_microgrids")

    # ============================================
    # Required Abstract Methods
    # ============================================

    def global_state_to_env_state(self, global_state: Dict) -> SimpleEnvState:
        """Build a ``SimpleEnvState`` from the proxy's global state.

        Args:
            global_state: Mapping read as
                ``{"agent_states": {agent_id: {"features": {...}}}}``;
                missing keys are treated as empty.

        Returns:
            A fresh ``SimpleEnvState`` whose ``battery_soc`` and
            ``gen_power`` are filled from the ``BatterySOC`` / ``GenOutput``
            features, keyed by the owning microgrid id.
        """
        result = SimpleEnvState()

        for agent_id, serialized in global_state.get("agent_states", {}).items():
            feats = serialized.get("features", {})
            has_soc = "BatterySOC" in feats
            has_gen = "GenOutput" in feats
            if not (has_soc or has_gen):
                continue

            # Strip the trailing "_<suffix>" to get the microgrid id,
            # e.g. "mg_0_bat" -> "mg_0".
            owner_id = agent_id.rpartition("_")[0]

            if has_soc:
                result.battery_soc[owner_id] = feats["BatterySOC"].get("soc", 0.5)
            if has_gen:
                result.gen_power[owner_id] = feats["GenOutput"].get("p_mw", 0.0)

        return result

    def run_simulation(self, env_state: SimpleEnvState, *args, **kwargs) -> SimpleEnvState:
        """Run the toy physics step and advance the step counter.

        The frequency is set to 60.0 plus 0.01 times the difference between
        total generation and total load.

        Args:
            env_state: State holding the current generator outputs.

        Returns:
            The same ``env_state`` object, updated in place.
        """
        generation = sum(env_state.gen_power.values())
        demand = sum(self._mg_loads.values())
        imbalance = generation - demand

        env_state.frequency = 60.0 + imbalance * 0.01
        env_state.converged = True

        self._step_count += 1
        return env_state

    def env_state_to_global_state(self, env_state: SimpleEnvState) -> Dict:
        """Write the simulated frequency back into the serialized agent states.

        Args:
            env_state: State produced by ``run_simulation``.

        Returns:
            ``{"agent_states": ...}`` holding the proxy's serialized states,
            with every ``SystemFrequency`` feature's ``frequency_hz`` set to
            ``env_state.frequency``.
        """
        serialized = self.proxy_agent.get_serialized_agent_states()

        for state_dict in serialized.values():
            feats = state_dict.get("features", {})
            if "SystemFrequency" not in feats:
                continue
            feats["SystemFrequency"]["frequency_hz"] = env_state.frequency

        return {"agent_states": serialized}

    # ============================================
    # Override reset for custom initialization
    # ============================================

    def reset(self, *, seed: Optional[int] = None, **kwargs):
        """Zero the step counter, then delegate to the parent ``reset``.

        Returns:
            The ``(obs, info)`` pair produced by the parent class.
        """
        self._step_count = 0
        observations, reset_info = super().reset(seed=seed, **kwargs)
        return observations, reset_info


# Cell-completion message: printed once the environment class has been defined.
print("Environment defined with MultiAgentEnv pattern")

## Step 3: Create and Test the Environment

The key pattern: **Build agents first (bottom-up), then pass to environment.**

In [None]:
# Build agent hierarchy bottom-up
loads = [3.0, 4.0, 2.5]
# Derived from `loads` so the microgrid count can never disagree with the
# number of load values (the former hand-kept count could cause IndexError).
num_microgrids = len(loads)

# Step 1: Create coordinators with their subordinates
microgrids = {}
for i, load in enumerate(loads):
    mg_id = f'mg_{i}'
    microgrids[mg_id] = SimpleMicrogrid(agent_id=mg_id, load=load)

# Step 2: Create system agent with coordinators
system_agent = SimpleGridSystem(
    agent_id='grid_system',
    subordinates=microgrids
)

# Step 3: Create environment with system agent
env = SimpleMultiMicrogridEnv(
    system_agent=system_agent,
    max_steps=10,
    share_reward=True
)

# Print the resulting hierarchy, level by level
print("=== HERON 3-Level Hierarchy ===")
print(f"L3 SystemAgent: {env._system_agent.agent_id}")
print(f"L2 Coordinators: {list(env._system_agent.subordinates.keys())}")
for mg_id, mg in env._system_agent.subordinates.items():
    print(f"  {mg_id} -> L1 Subordinates: {list(mg.subordinates.keys())}")

print(f"\nRegistered agents: {list(env.registered_agents.keys())}")
print(f"ProxyAgent: {env.proxy_agent}")

In [None]:
# Smoke-test reset(): show which keys the observation and info dicts carry
obs, infos = env.reset()
print("After reset:")
for label, mapping in (("Observations", obs), ("Info keys", infos)):
    print(f"  {label}: {list(mapping.keys())}")

In [None]:
# Roll the environment forward for a handful of steps using random actions
print("Running 5 steps with random actions...\n")

for step in range(5):
    # One action per field agent, keyed by agent id. Battery actions are
    # drawn from [-1, 1) and generator actions from [0, 1).
    actions = {}
    for mg_id in microgrids:
        actions[f"{mg_id}_bat"] = np.random.uniform(-1, 1, size=(1,))
        actions[f"{mg_id}_gen"] = np.random.uniform(0, 1, size=(1,))

    obs, rewards, terminateds, truncateds, infos = env.step(actions)

    episode_over = terminateds.get('__all__', False)
    print(f"Step {step + 1}:")
    print(f"  Rewards: {rewards}")
    print(f"  Terminated: {episode_over}")

## Step 4: Understanding the Data Flow

### Environment Step Sequence

When `env.step(actions)` is called:

```
1. system_agent.execute(actions, proxy_agent)
   |
   +-- Agents observe via proxy (visibility filtering)
   +-- Agents compute actions (policy or upstream)
   +-- Agents apply actions (update local state)
   |
2. global_state_to_env_state(proxy.global_state)
   +-- Extract device setpoints from agent states
   |
3. run_simulation(env_state)
   +-- Run physics (power flow, frequency response)
   +-- Return updated env_state
   |
4. env_state_to_global_state(env_state)
   +-- Convert results to agent state updates
   +-- Proxy distributes to agents
   |
5. proxy_agent.get_step_results()
   +-- Returns (obs, rewards, terminated, truncated, info)
```

### Why This Pattern?

| Component | Responsibility |
|-----------|---------------|
| **Environment** | Simulation orchestration, reset/step interface |
| **ProxyAgent** | State management, visibility filtering |
| **SystemAgent** | Agent hierarchy coordination |
| **CoordinatorAgent** | Subordinate management |
| **FieldAgent** | State/action representation, reward computation |

## Step 5: Real-World Environment Example

For production environments, see `HierarchicalMicrogridEnv` in `powergrid/envs/hierarchical_microgrid_env.py`.

Key additions for real environments:

1. **Dataset Integration**: Load time-series data for loads, renewables, prices
2. **Power Flow Simulation**: Use PandaPower for AC power flow analysis
3. **EnvState Class**: Structured state with device setpoints and results

In [None]:
# Example: the production EnvState class
from powergrid.envs.common import EnvState

# Start from an empty EnvState
state = EnvState()

# Record one setpoint per device (battery first, then generator)
for device_id, p_set, q_set in (("battery_1", 0.5, 0.0), ("gen_1", 2.0, 0.1)):
    state.set_device_setpoint(device_id, P=p_set, Q=q_set, in_service=True)

# Feed the power-flow results back into the state
pf_results = {
    "converged": True,
    "voltage_min": 0.95,
    "voltage_max": 1.05,
    "grid_power": 1.5,
}
state.update_power_flow_results(pf_results)

print("EnvState example:")
print(f"  Device setpoints: {state.device_setpoints}")
print(f"  Power flow results: {state.power_flow_results}")
print(f"  Converged: {state.converged}")

## Key Takeaways

1. **Constructor Pattern**: Environment receives pre-built `system_agent`
   ```python
   env = MyEnv(system_agent=system_agent, ...)
   ```

2. **Three Required Methods**:
   ```python
   def global_state_to_env_state(self, global_state: Dict) -> EnvState:
       # Extract device setpoints from agent states
       ...

   def run_simulation(self, env_state: EnvState, *args, **kwargs) -> EnvState:
       # Run physics simulation
       ...

   def env_state_to_global_state(self, env_state: EnvState) -> Dict:
       # Convert results back to agent state format
       ...
   ```

3. **ProxyAgent Handles State**: Visibility filtering and state distribution are automatic

4. **Build Agents Bottom-Up**: FieldAgents -> CoordinatorAgents -> SystemAgent -> Environment

---

**Next:** [04_training_with_rllib.ipynb](04_training_with_rllib.ipynb) - Training with RLlib