# Tutorial 3: Building the Environment

**Goal:** Create a PettingZoo-compatible environment using HERON's adapter.

**Time:** ~15 minutes

---

## The Environment's Role

The environment **only talks to L3 SystemAgent**, which coordinates the hierarchy:

```
RLlib / StableBaselines3
    │
    └── PettingZoo API (step, reset, observe)
            │
            └── HERON Environment ←→ SystemAgent (L3)
                                          │
                                          ├── Microgrid 0 (L2 CoordinatorAgent)
                                          │   ├── Battery (L1 FieldAgent)
                                          │   └── Generator (L1 FieldAgent)
                                          ├── Microgrid 1 ...
                                          └── Microgrid 2 ...
```

**Critical pattern:** Environment ONLY interacts with L3. L3 coordinates L2. L2 coordinates L1.

## Step 1: Define Our Simple Agents

These are the agents from Tutorial 2. They have **no `step()` method** — physics lives in the environment.

In [None]:
import numpy as np
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from gymnasium.spaces import Box

from heron.core.feature import FeatureProvider
from heron.agents.field_agent import FieldAgent
from heron.agents.coordinator_agent import CoordinatorAgent
from heron.agents.system_agent import SystemAgent
from heron.protocols.vertical import SetpointProtocol
from heron.scheduling.tick_config import TickConfig


# === Features ===
@dataclass
class BatterySOC(FeatureProvider):
    """Battery state-of-charge feature.

    ``visibility`` lists the levels allowed to observe this feature
    (here ``'owner'`` and ``'upper_level'``).
    """
    visibility = ['owner', 'upper_level']
    soc: float = 0.5

    def vector(self) -> np.ndarray:
        """Return the observation vector ``[soc]`` as float32."""
        return np.array([self.soc], dtype=np.float32)

    def names(self) -> List[str]:
        return ['soc']

    def to_dict(self) -> Dict[str, float]:
        return {'soc': self.soc}

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def set_values(self, **kw):
        """Update ``soc`` from ``kw['soc']`` (if given), clipped to [0.1, 0.9].

        ``np.clip`` returns a NumPy scalar; it is converted back to a Python
        ``float`` so the field keeps its declared type and ``to_dict()``
        returns plain floats like the other features.
        """
        if 'soc' in kw:
            self.soc = float(np.clip(float(kw['soc']), 0.1, 0.9))


@dataclass
class GenOutput(FeatureProvider):
    """Generator output feature, observed as output normalised by ``p_max``.

    ``visibility`` lists the levels allowed to observe this feature.
    """
    visibility = ['owner', 'upper_level', 'system']
    p_mw: float = 0.0
    p_max: float = 5.0

    def vector(self) -> np.ndarray:
        """Return ``[p_mw / p_max]`` as a float32 array."""
        utilisation = self.p_mw / self.p_max
        return np.array([utilisation], dtype=np.float32)

    def names(self) -> List[str]:
        return ['p_norm']

    def to_dict(self) -> Dict[str, float]:
        return dict(p_mw=self.p_mw, p_max=self.p_max)

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def set_values(self, **kw):
        """Update ``p_mw`` if present in ``kw``; any other key is ignored."""
        if 'p_mw' not in kw:
            return
        self.p_mw = float(kw['p_mw'])


@dataclass
class SystemFrequency(FeatureProvider):
    """System frequency feature, observed as deviation from 60 Hz.

    ``visibility`` lists the levels allowed to observe this feature.
    """
    visibility = ['system']
    frequency_hz: float = 60.0

    def vector(self) -> np.ndarray:
        """Return ``[frequency_hz - 60.0]`` as a float32 array."""
        deviation = self.frequency_hz - 60.0
        return np.array([deviation], dtype=np.float32)

    def names(self) -> List[str]:
        return ['freq_deviation']

    def to_dict(self) -> Dict[str, float]:
        return dict(frequency_hz=self.frequency_hz)

    @classmethod
    def from_dict(cls, d):
        return cls(**d)

    def set_values(self, **kw):
        """Update ``frequency_hz`` if present in ``kw``; any other key is ignored."""
        if 'frequency_hz' not in kw:
            return
        self.frequency_hz = float(kw['frequency_hz'])


# === L1: Field Agents ===
class SimpleBattery(FieldAgent):
    """L1 battery agent holding a ``BatterySOC`` feature and a 1-D action in [-1, 1].

    Args:
        agent_id: Unique id of this agent.
        capacity: Battery capacity stored on the agent.
        max_power: Maximum charge/discharge power magnitude stored on the
            agent (default 0.5).
        **kwargs: Forwarded to ``FieldAgent`` (e.g. ``upstream_id``).
    """

    def __init__(self, agent_id: str, capacity: float = 2.0, max_power: float = 0.5, **kwargs):
        # Attributes are set before super().__init__ — presumably because the
        # base constructor invokes set_state()/set_action() (TODO confirm).
        self.capacity = capacity
        self.max_power = max_power
        super().__init__(agent_id=agent_id, **kwargs)

    def set_state(self):
        """Create the SOC feature (initial 0.5) and register it in the agent state."""
        self.soc_feature = BatterySOC(soc=0.5)
        self.state.features.append(self.soc_feature)

    def set_action(self):
        """Declare a single action dimension bounded to [-1, 1]."""
        self.action.set_specs(dim_c=1, range=(np.array([-1.0]), np.array([1.0])))

    def update_state(self, **env_state):
        """Copy ``env_state['soc']`` into the SOC feature when provided."""
        if 'soc' in env_state:
            self.soc_feature.set_values(soc=env_state['soc'])


class SimpleGen(FieldAgent):
    """L1 generator agent holding a ``GenOutput`` feature and a 1-D action in [0, 1].

    Args:
        agent_id: Unique id of this agent.
        p_max: Generator capacity, also used to normalise the output feature.
        cost_per_mwh: Generation cost coefficient stored on the agent.
        **kwargs: Forwarded to ``FieldAgent`` (e.g. ``upstream_id``).
    """

    def __init__(self, agent_id: str, p_max: float = 5.0, cost_per_mwh: float = 50.0, **kwargs):
        # Configure before the base constructor runs (it may call set_state; TODO confirm).
        self.p_max, self.cost_per_mwh = p_max, cost_per_mwh
        super().__init__(agent_id=agent_id, **kwargs)

    def set_state(self):
        """Create the output feature at 0 MW and register it in the agent state."""
        feature = GenOutput(p_mw=0.0, p_max=self.p_max)
        self.output_feature = feature
        self.state.features.append(feature)

    def set_action(self):
        """Declare a single action dimension bounded to [0, 1]."""
        lower, upper = np.array([0.0]), np.array([1.0])
        self.action.set_specs(dim_c=1, range=(lower, upper))

    def update_state(self, **env_state):
        """Copy ``env_state['p_mw']`` into the output feature when provided."""
        if 'p_mw' not in env_state:
            return
        self.output_feature.set_values(p_mw=env_state['p_mw'])


# === L2: Coordinator Agent ===
class SimpleMicrogrid(CoordinatorAgent):
    """L2 coordinator owning one battery and one generator (L1), using ``SetpointProtocol``.

    Args:
        agent_id: Unique id of this microgrid; also used as the prefix of its
            subordinates' ids (``<id>_bat``, ``<id>_gen``).
        load: Load value stored on the agent.
        **kwargs: Forwarded to ``CoordinatorAgent``.
    """

    def __init__(self, agent_id: str, load: float = 3.0, **kwargs):
        self.load = load
        # Keep our own copy of the id: _build_subordinates reads it, and it may run
        # inside the base constructor before agent_id is assigned (TODO confirm).
        self._my_id = agent_id
        protocol = SetpointProtocol()
        tick_config = TickConfig.deterministic(tick_interval=60.0)
        super().__init__(
            agent_id=agent_id,
            protocol=protocol,
            tick_config=tick_config,
            **kwargs
        )

    def _build_subordinates(self, configs, env_id=None, upstream_id=None):
        """Create the battery and generator agents; ``configs``/``env_id``/``upstream_id`` are unused."""
        parent = self._my_id
        self.battery = SimpleBattery(f'{parent}_bat', upstream_id=parent)
        self.gen = SimpleGen(f'{parent}_gen', upstream_id=parent)
        return {child.agent_id: child for child in (self.battery, self.gen)}


# === L3: System Agent ===
class SimpleGridSystem(SystemAgent):
    """System agent (L3) managing multiple microgrid coordinators (L2).

    Args:
        agent_id: Unique id of the system agent.
        microgrids: Coordinators to attach. When non-empty, they are stored in
            ``self.coordinators`` keyed by ``agent_id`` and each one's
            ``upstream_id`` is pointed at this agent. ``None`` or an empty
            list leaves ``coordinators`` as set up by the base class.
    """

    def __init__(self, agent_id: str, microgrids: Optional[List[SimpleMicrogrid]] = None):
        self._init_mgs = microgrids or []
        super().__init__(
            agent_id=agent_id,
            tick_config=TickConfig.deterministic(tick_interval=300.0),
        )
        # Setup coordinators after the base constructor so they are not overwritten.
        if microgrids:
            self.coordinators = {mg.agent_id: mg for mg in microgrids}
            for mg in microgrids:
                mg.upstream_id = agent_id

    def set_state(self):
        """Create the system frequency feature (60 Hz) and register it in the agent state."""
        self.freq_feature = SystemFrequency(frequency_hz=60.0)
        self.state.features.append(self.freq_feature)


levels_banner = "All 3 levels defined: FieldAgent -> CoordinatorAgent -> SystemAgent"
print(levels_banner)

## Step 2: Build the Multi-Agent Environment

HERON provides `PettingZooParallelEnv` - an adapter that combines:
- PettingZoo's `ParallelEnv` interface (for RL framework compatibility)
- HERON's `HeronEnvCore` mixin (for agent management, event-driven execution)

**Critical Design Principle: Environment ↔ L3 Only**

The environment should:
1. **Maintain its own physics state** (battery SOCs, gen outputs, etc.)
2. **Only interact with SystemAgent (L3)** - never directly access L2 or L1
3. **Push state via `update_from_environment()`** - L3 propagates to L2 → L1
4. **Get observations via `observe()`** - flows L1 → L2 → L3 → Env

In [None]:
from heron.envs.adapters import PettingZooParallelEnv
from heron.core.observation import OBS_KEY_SUBORDINATE_OBS


class SimpleMultiMicrogridEnv(PettingZooParallelEnv):
    """Multi-agent environment demonstrating proper HERON patterns.

    Key Patterns:
    1. Environment ONLY talks to SystemAgent (L3)
    2. Observations flow: L1 → L2 → L3 → Env
    3. Actions flow: Env → L3 → L2 → L1
    4. State updates flow: Env → L3 → L2 → L1

    Each agent ``mg_<i>`` controls one microgrid with a 2-D action
    ``[battery, generator]``:
    - ``battery`` in [-1, 1] scales the battery's ``max_power``. Positive values
      charge the battery, i.e. draw power from the microgrid.
    - ``generator`` in [0, 1] scales the generator's ``p_max``.

    Out-of-range actions are clipped to these bounds before use.

    Args:
        num_microgrids: Number of microgrids (= RL agents); must be >= 1.
        max_steps: Episode length. Reaching it is reported as a truncation
            (time limit), not a termination.
        share_reward: If True, every agent receives the collective reward
            split evenly; otherwise each agent receives its own local reward.
        penalty: Weight on absolute power imbalance in the reward.

    Raises:
        ValueError: If ``num_microgrids`` < 1.
    """

    metadata = {'render_modes': ['human'], 'name': 'simple_microgrids_v0'}

    # SOC operating window enforced by the physics (same bounds as BatterySOC.set_values).
    _SOC_MIN, _SOC_MAX = 0.1, 0.9
    # Frequency deviation (Hz) per unit of net supply surplus (supply - load).
    _FREQ_GAIN = 0.01

    def __init__(
        self,
        num_microgrids: int = 3,
        max_steps: int = 96,
        share_reward: bool = True,
        penalty: float = 10.0,
    ):
        # Zero microgrids would make the shared-reward split divide by zero.
        if num_microgrids < 1:
            raise ValueError(f"num_microgrids must be >= 1, got {num_microgrids}")
        super().__init__(env_id="simple_microgrids")

        self.num_microgrids = num_microgrids
        self.max_steps = max_steps
        self.share_reward = share_reward
        self.penalty = penalty
        self.dt = 1.0

        # === ENVIRONMENT'S OWN PHYSICS STATE + AGENT HIERARCHY (L3 -> L2 -> L1) ===
        self._physics = {
            'frequency': 60.0,
            'microgrids': {}
        }
        loads = [3.0, 4.0, 2.5]
        self._mg_ids = [f'mg_{i}' for i in range(num_microgrids)]
        microgrid_list = []
        for i, mg_id in enumerate(self._mg_ids):
            load = loads[i % len(loads)]
            self._physics['microgrids'][mg_id] = {
                'load': load,
                'battery': {'soc': 0.5, 'capacity': 2.0, 'max_power': 0.5},
                'generator': {'p_mw': 0.0, 'p_max': 5.0, 'cost_per_mwh': 50.0},
            }
            microgrid_list.append(SimpleMicrogrid(agent_id=mg_id, load=load))

        self._grid_system = SimpleGridSystem(
            agent_id='grid_system',
            microgrids=microgrid_list
        )
        self.set_system_agent(self._grid_system)

        # PettingZoo setup
        self._set_agent_ids(self._mg_ids)

        # Action bounds are kept on the instance so step() can clip to them.
        self._act_low = np.array([-1.0, 0.0], dtype=np.float32)
        self._act_high = np.array([1.0, 1.0], dtype=np.float32)
        obs_spaces = {aid: Box(-np.inf, np.inf, (3,), np.float32) for aid in self._mg_ids}
        act_spaces = {aid: Box(self._act_low.copy(), self._act_high.copy(), dtype=np.float32)
                      for aid in self._mg_ids}
        self._init_spaces(action_spaces=act_spaces, observation_spaces=obs_spaces)
        self._step_count = 0

    def _build_env_state(self) -> Dict[str, Any]:
        """Build the nested env_state dict pushed to L3 via ``update_from_environment``.

        Layout: ``{'system': {...}, 'coordinators': {mg_id: {'subordinates': {sub_id: {...}}}}}``.
        Inner keys match the feature class names used by the agents
        (presumably how HERON routes the values — confirm against the adapter).
        """
        env_state = {
            'system': {'SystemFrequency': {'frequency_hz': self._physics['frequency']}},
            'coordinators': {}
        }
        for mg_id, mg_physics in self._physics['microgrids'].items():
            bat_id, gen_id = f'{mg_id}_bat', f'{mg_id}_gen'
            env_state['coordinators'][mg_id] = {
                'subordinates': {
                    bat_id: {'BatterySOC': {'soc': mg_physics['battery']['soc']}},
                    gen_id: {'GenOutput': {'p_mw': mg_physics['generator']['p_mw']}}
                }
            }
        return env_state

    def _extract_observations(self, sys_obs) -> Dict[str, np.ndarray]:
        """Build per-microgrid observations ``[soc, p_mw / p_max, load / 10]``.

        NOTE(review): ``sys_obs`` is currently unused. Values are read from the
        environment's physics state, which was pushed to the agents just before
        ``observe()`` was called.
        """
        observations = {}
        for mg_id in self._mg_ids:
            mg_physics = self._physics['microgrids'][mg_id]
            observations[mg_id] = np.array([
                mg_physics['battery']['soc'],
                mg_physics['generator']['p_mw'] / mg_physics['generator']['p_max'],
                mg_physics['load'] / 10.0
            ], dtype=np.float32)
        return observations

    def _get_global_state(self) -> Dict[str, Any]:
        """Build global_state dict for observe().

        global_state is optional context passed down to all agents during observe().
        Use it for environment-wide info agents might need for observations:
        - Simulation time, market prices, weather conditions, etc.

        In this simple tutorial, we don't need any global context,
        so we return an empty dict. In real scenarios:

            return {
                'time': self._step_count,
                'market_price': self._electricity_price,
                'weather': {'solar': 800, 'wind': 5.0},
            }
        """
        return {}

    def reset(self, seed=None, options=None):
        """Reset physics and agents; return ``(observations, infos)``.

        ``seed`` and ``options`` are accepted for API compatibility but unused:
        the reset state is the same fixed values every time.
        """
        self._step_count = 0
        self._agents = self._possible_agents.copy()

        # Reset physics
        self._physics['frequency'] = 60.0
        for mg_id in self._physics['microgrids']:
            self._physics['microgrids'][mg_id]['battery']['soc'] = 0.5
            self._physics['microgrids'][mg_id]['generator']['p_mw'] = 0.0

        # Push to L3 → L2 → L1
        self._grid_system.update_from_environment(self._build_env_state())

        # Get observations L1 → L2 → L3 → Env
        sys_obs = self._grid_system.observe(global_state=self._get_global_state())
        return self._extract_observations(sys_obs), {aid: {} for aid in self.agents}

    def step(self, actions: Dict[str, np.ndarray]):
        """Environment step with proper HERON action/state flow.

        Flow:
        1. Clip actions to the action-space bounds
        2. Get observation through L3 (for action routing context)
        3. Route actions: L3.act() → L2.act() → L1.act()
        4. Run physics (environment's job)
        5. Push state: Env → L3 → L2 → L1
        6. Get observations: L1 → L2 → L3 → Env

        Returns:
            ``(observations, rewards, terminateds, truncateds, infos)``. The
            done dicts include an ``'__all__'`` key. Hitting ``max_steps`` sets
            ``truncateds``; ``terminateds`` stays False because the task has no
            terminal state.
        """
        self._step_count += 1

        # === 1. CLIP ONCE so the hierarchy and the physics see identical actions ===
        clipped = {aid: self._clip_action(act) for aid, act in actions.items()}

        # === 2. GET OBSERVATION THROUGH L3 ===
        # global_state: optional context for agents (time, prices, weather, etc.)
        sys_obs = self._grid_system.observe(global_state=self._get_global_state())

        # === 3. ROUTE ACTIONS THROUGH HIERARCHY (Env → L3 → L2 → L1) ===
        self._grid_system.act(sys_obs, upstream_action=clipped)

        # === 4. RUN PHYSICS (environment's job) ===
        results = {}
        total_supply, total_load = 0.0, 0.0
        for mg_id in self._mg_ids:
            mg_phys = self._physics['microgrids'][mg_id]
            results[mg_id], net_supply = self._simulate_microgrid(
                mg_phys, clipped.get(mg_id, np.zeros(2)))
            total_supply += net_supply
            total_load += mg_phys['load']
        total_cost = sum(r['cost'] for r in results.values())
        total_imbalance = sum(r['imbalance'] for r in results.values())

        # System frequency follows the same net supply (gen - battery charging)
        # that the imbalance term uses.
        self._physics['frequency'] = 60.0 + (total_supply - total_load) * self._FREQ_GAIN

        # === 5. PUSH STATE TO L3 (propagates to L2 → L1) ===
        self._grid_system.update_from_environment(self._build_env_state())

        # === 6. GET OBSERVATIONS THROUGH L3 ===
        sys_obs = self._grid_system.observe(global_state=self._get_global_state())
        observations = self._extract_observations(sys_obs)

        # Rewards
        collective_reward = -(total_cost + self.penalty * total_imbalance)
        if self.share_reward:
            rewards = {aid: collective_reward / self.num_microgrids for aid in self.agents}
        else:
            rewards = {aid: -(results[aid]['cost'] + self.penalty * results[aid]['imbalance'])
                       for aid in self.agents}

        # A step budget is a time limit, not a terminal MDP state. Reporting it
        # as truncation lets learners that bootstrap value estimates treat the
        # final observation correctly.
        time_up = self._step_count >= self.max_steps
        terminateds = {aid: False for aid in self.agents}
        terminateds['__all__'] = False
        truncateds = {aid: time_up for aid in self.agents}
        truncateds['__all__'] = time_up

        return observations, rewards, terminateds, truncateds, {aid: results[aid] for aid in self.agents}

    def _clip_action(self, action) -> np.ndarray:
        """Return ``action`` as a flat float64 array clipped to the action-space bounds."""
        flat = np.asarray(action, dtype=np.float64).reshape(-1)
        return np.clip(flat, self._act_low, self._act_high)

    def _simulate_microgrid(self, mg_phys: Dict[str, Any],
                            action: np.ndarray) -> Tuple[Dict[str, float], float]:
        """Advance one microgrid's physics by ``self.dt``, mutating ``mg_phys`` in place.

        Returns:
            ``(result, net_supply)``:
            - ``result`` holds this step's generation ``cost`` and absolute
              ``imbalance`` against load.
            - ``net_supply`` is generator output minus the power absorbed by
              the battery.
        """
        bat, gen = mg_phys['battery'], mg_phys['generator']

        # Battery: the command requests a power, but SOC is kept inside the
        # operating window. The power actually exchanged is derived from the
        # realised SOC change, so a full (empty) battery cannot keep absorbing
        # (supplying) power it has no room (charge) for.
        requested_power = float(action[0]) * bat['max_power']
        old_soc = bat['soc']
        new_soc = float(np.clip(old_soc + requested_power * self.dt / bat['capacity'],
                                self._SOC_MIN, self._SOC_MAX))
        bat_power = (new_soc - old_soc) * bat['capacity'] / self.dt
        bat['soc'] = new_soc

        # Generator physics
        gen_power = float(action[1]) * gen['p_max']
        gen['p_mw'] = gen_power
        gen_cost = gen_power * self.dt * gen['cost_per_mwh']

        # Balance: charging (positive bat_power) consumes power.
        net_supply = gen_power - bat_power
        imbalance = abs(mg_phys['load'] - net_supply)
        return {'cost': gen_cost, 'imbalance': imbalance}, net_supply


env_banner = "Environment with proper HERON action flow: Env → L3 → L2 → L1"
print(env_banner)

In [None]:
# Test the environment with L3-only interaction:
# inspect the hierarchy, the environment's own physics state, then reset.
env = SimpleMultiMicrogridEnv(num_microgrids=3, max_steps=10, share_reward=True)

print("=== HERON 3-Level Hierarchy ===")
print(f"L3 SystemAgent: {env.system_agent.agent_id}")
print(f"L2 Coordinators: {list(env.system_agent.coordinators.keys())}")
for mg_id, mg in env.system_agent.coordinators.items():
    print(f"  {mg_id} -> L1 Subordinates: {list(mg.subordinates.keys())}")

print(f"\nRL Agents (PettingZoo): {env.possible_agents}")
print("\n=== Environment's Physics State (separate from agents) ===")
print(f"env._physics keys: {list(env._physics.keys())}")
mg0_physics = env._physics['microgrids']['mg_0']
print(f"mg_0 physics: battery_soc={mg0_physics['battery']['soc']}, "
      f"gen_power={mg0_physics['generator']['p_mw']}")

# Reset - physics is read from env._physics; observations come back through L3
obs, infos = env.reset()
print("\nAfter reset (observations came through L3):")
for aid, o in obs.items():
    print(f"  {aid}: {o}")

In [None]:
# Run a few steps; step() routes actions through the hierarchy internally.
print("Running 5 steps...\n")
step_flow_lines = (
    "Inside step():",
    "  1. sys_obs = system_agent.observe()      # L1→L2→L3→Env",
    "  2. system_agent.act(obs, actions)        # Env→L3→L2→L1",
    "  3. Run physics                           # Environment's job",
    "  4. system_agent.update_from_environment  # Env→L3→L2→L1",
    "  5. Return observations                   # L1→L2→L3→Env\n",
)
for flow_line in step_flow_lines:
    print(flow_line)

for step in range(5):
    # Random actions stand in for a trained RL policy.
    actions = {aid: env.action_spaces[aid].sample() for aid in env.agents}

    # One call performs: observe → act → physics → update → return obs
    obs, rewards, terminateds, truncateds, infos = env.step(actions)

    print(f"Step {step + 1}:")
    for mg_id in env._mg_ids:
        mg_phys = env._physics['microgrids'][mg_id]
        battery_state, generator_state = mg_phys['battery'], mg_phys['generator']
        print(f"  {mg_id}: SOC={battery_state['soc']:.2f}, Gen={generator_state['p_mw']:.2f}MW")
    print(f"  System Frequency: {env._physics['frequency']:.3f} Hz")
    print()

## Step 3: Understanding the L3-Only Pattern

### Why Environment ↔ L3 Only?

This pattern ensures clean separation of concerns:

| Component | Responsibility |
|-----------|---------------|
| **Environment** | Physics simulation, reward calculation |
| **SystemAgent (L3)** | Hierarchy coordination, state distribution |
| **CoordinatorAgent (L2)** | Subordinate management, protocol execution |
| **FieldAgent (L1)** | State representation, action execution |

### Data Flow

```
┌─────────────────┐
│   Environment   │ ← Maintains physics state
└────────┬────────┘
         │ update_from_environment(env_state)
         ▼
┌─────────────────┐
│ SystemAgent (L3)│ ← Propagates to coordinators
└────────┬────────┘
         │ update_from_environment()
         ▼
┌─────────────────┐
│Coordinator (L2) │ ← Propagates to subordinates
└────────┬────────┘
         │ update_from_environment()
         ▼
┌─────────────────┐
│ FieldAgent (L1) │ ← Updates state features
└─────────────────┘
```

### Why Shared Rewards?

In **cooperative** settings, agents should optimize collective goals:
```python
if self.share_reward:
    # All agents get same reward -> learn to cooperate
    rewards = {aid: collective_reward / num_agents for aid in agents}
```

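Shared rewards make the task fully cooperative. They pair naturally with **Centralized Training with Decentralized Execution (CTDE)** algorithms such as MAPPO, where a centralized critic is used during training but each agent acts on its own observation at execution time.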

### Why Penalty for Imbalance?

Power grids must balance supply and demand. The penalty:
```python
reward = -(cost + penalty * imbalance)
```
This encourages agents to coordinate generation and battery dispatch so that supply matches load.

## Step 4: Adding Configuration Support

For production, environments should be configurable via dicts/YAML.

In [None]:
# Example config dict (the kind of mapping load_setup() would return).
# 'train' and 'centralized' are carried along but not used by the env constructor.
env_config = dict(
    num_microgrids=3,
    max_steps=96,
    share_reward=True,
    penalty=10.0,
    train=True,
    centralized=True,
)

# Factory function for RLlib
def create_env(config: Optional[Dict] = None) -> SimpleMultiMicrogridEnv:
    """Create a ``SimpleMultiMicrogridEnv`` from a config mapping.

    Only the constructor's keyword arguments are forwarded. Keys missing from
    ``config`` fall back to the constructor's own defaults, so defaults live in
    one place. Unrelated keys (e.g. ``train``, ``centralized``) are ignored.
    Any dict-like mapping works, including dict subclasses such as RLlib's
    ``EnvContext``.

    Args:
        config: Mapping of constructor options; ``None`` means all defaults.

    Returns:
        A new environment instance.
    """
    env_kwargs = ('num_microgrids', 'max_steps', 'share_reward', 'penalty')
    config = config or {}
    return SimpleMultiMicrogridEnv(**{key: config[key] for key in env_kwargs if key in config})

# Smoke test: build an env from the example config and report its size.
env2 = create_env(env_config)
created_mgs, created_horizon = env2.num_microgrids, env2.max_steps
print(f"Created env with {created_mgs} microgrids, {created_horizon} max steps")

## Key Takeaways

1. **Environment Only Talks to L3 SystemAgent**
   ```python
   # All interaction through L3 - never directly touch L2/L1
   self._grid_system.observe(...)
   self._grid_system.act(...)
   self._grid_system.update_from_environment(...)
   ```

2. **Complete Data Flow in `step()`**
   ```python
   def step(self, actions):
       # 1. Observations: L1 → L2 → L3 → Env
       sys_obs = self._grid_system.observe(global_state={})
       
       # 2. Actions: Env → L3 → L2 → L1
       self._grid_system.act(sys_obs, upstream_action=actions)
       
       # 3. Physics (environment's job)
       # ... run simulation ...
       
       # 4. State updates: Env → L3 → L2 → L1
       self._grid_system.update_from_environment(env_state)
       
       # 5. Return observations: L1 → L2 → L3 → Env
       return observations, rewards, ...
   ```

3. **Environment Owns Physics State**
   ```python
   self._physics = {
       'frequency': 60.0,
       'microgrids': {'mg_0': {'battery': {...}, 'generator': {...}}}
   }
   ```

4. **Agents Don't Have `step()`**
   - `observe()` — collect observations from hierarchy
   - `act()` — distribute actions down hierarchy
   - `update_from_environment()` — receive state updates
   - Physics — environment's responsibility

---

**Next:** [04_training_with_rllib.ipynb](04_training_with_rllib.ipynb) - Training with MAPPO