# Tutorial 5: Training with RLlib

**Goal:** Train a multi-agent policy using MAPPO on our microgrid environment.

**Time:** ~15 minutes (including training)

---

## End-to-End Workflow

This tutorial demonstrates the complete HERON + RLlib workflow:

```
1. Define Features (FeatureProvider)
       ↓
2. Create Agents (FieldAgent, CoordinatorAgent)
       ↓
3. Build Environment (PettingZooParallelEnv)
       ↓
4. Register with RLlib (ParallelPettingZooEnv wrapper)
       ↓
5. Train MAPPO (shared policy for CTDE)
       ↓
6. Evaluate (in synchronous mode)
       ↓
7. Test robustness (event-driven mode → Tutorial 06)
```

## Step 1: Complete Environment Definition

Let's put everything from Tutorials 2-4 into a single, self-contained module.

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from typing import Any, Dict, List, Optional, Tuple
from dataclasses import dataclass
from gymnasium.spaces import Box

from heron.core.feature import FeatureProvider
from heron.core.state import FieldAgentState
from heron.agents.field_agent import FieldAgent
from heron.agents.coordinator_agent import CoordinatorAgent
from heron.protocols.vertical import SetpointProtocol
from heron.envs.adapters import PettingZooParallelEnv  # Use HERON adapter!


# ============================================
# Features
# ============================================
@dataclass
class BatterySOC(FeatureProvider):
    """Battery state of charge - visible to owner and coordinator."""
    visibility = ['owner', 'upper_level']
    soc: float = 0.5
    
    def vector(self) -> np.ndarray:
        return np.array([self.soc], dtype=np.float32)
    def names(self): return ['soc']
    def to_dict(self): return {'soc': self.soc}
    @classmethod
    def from_dict(cls, d): return cls(**d)
    def set_values(self, **kw): 
        if 'soc' in kw: self.soc = np.clip(float(kw['soc']), 0.1, 0.9)


@dataclass
class GenOutput(FeatureProvider):
    """Generator output - visible to owner, coordinator, and system."""
    visibility = ['owner', 'upper_level', 'system']
    p_mw: float = 0.0
    p_max: float = 5.0
    
    def vector(self) -> np.ndarray:
        return np.array([self.p_mw / max(self.p_max, 1e-6)], dtype=np.float32)
    def names(self): return ['p_norm']
    def to_dict(self): return {'p_mw': self.p_mw, 'p_max': self.p_max}
    @classmethod
    def from_dict(cls, d): return cls(**d)
    def set_values(self, **kw): 
        if 'p_mw' in kw: self.p_mw = float(kw['p_mw'])


# ============================================
# Field Agents (Devices)
# ============================================
class SimpleBattery(FieldAgent):
    """Simple battery agent with charge/discharge control."""
    def __init__(self, agent_id: str, capacity: float = 2.0, max_power: float = 0.5, upstream_id: Optional[str] = None):
        self.capacity = capacity
        self.max_power = max_power
        super().__init__(agent_id=agent_id, upstream_id=upstream_id, config={'name': agent_id})
        self.state = FieldAgentState(owner_id=agent_id, owner_level=1)
        self.state.register_feature('soc', BatterySOC(soc=0.5))
        self.action_space = Box(-1, 1, (1,), np.float32)
        self.observation_space = Box(-np.inf, np.inf, (1,), np.float32)
    
    def observe(self, gs=None): return self.state.vector()
    
    def step(self, action, dt=1.0):
        power = float(action[0]) * self.max_power
        soc = self.state.features['soc'].soc + power * dt / self.capacity
        self.state.features['soc'].set_values(soc=soc)
        return {'power_mw': power, 'soc': self.state.features['soc'].soc}
    
    def reset(self, seed=None):
        self.state.features['soc'].soc = 0.5
        return self.observe()


class SimpleGen(FieldAgent):
    """Simple generator agent with power output control."""
    def __init__(self, agent_id: str, p_max: float = 5.0, cost: float = 50.0, upstream_id: Optional[str] = None):
        self.p_max = p_max
        self.cost_per_mwh = cost
        super().__init__(agent_id=agent_id, upstream_id=upstream_id, config={'name': agent_id})
        self.state = FieldAgentState(owner_id=agent_id, owner_level=1)
        self.state.register_feature('output', GenOutput(p_mw=0, p_max=p_max))
        self.action_space = Box(0, 1, (1,), np.float32)
        self.observation_space = Box(-np.inf, np.inf, (1,), np.float32)
    
    def observe(self, gs=None): return self.state.vector()
    
    def step(self, action, dt=1.0):
        p = float(action[0]) * self.p_max
        self.state.features['output'].set_values(p_mw=p)
        return {'power_mw': p, 'cost': p * dt * self.cost_per_mwh}
    
    def reset(self, seed=None):
        self.state.features['output'].set_values(p_mw=0)
        return self.observe()


# ============================================
# Coordinator Agent (Microgrid)
# ============================================
class SimpleMicrogrid(CoordinatorAgent):
    """Microgrid coordinator managing battery and generator."""
    def __init__(self, agent_id: str, load: float = 3.0, upstream_id: Optional[str] = None):
        self.load = load
        super().__init__(agent_id=agent_id, upstream_id=upstream_id, protocol=SetpointProtocol())
        
        self.battery = SimpleBattery(f'{agent_id}_bat', upstream_id=agent_id)
        self.gen = SimpleGen(f'{agent_id}_gen', upstream_id=agent_id)
        self.subordinates = {self.battery.agent_id: self.battery, self.gen.agent_id: self.gen}
        
        self.observation_space = Box(-np.inf, np.inf, (3,), np.float32)
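        # Explicit float32 bounds avoid an int-to-float cast inside Box
        self.action_space = Box(
            low=np.array([-1.0, 0.0], dtype=np.float32),   # [battery, generator]
            high=np.array([1.0, 1.0], dtype=np.float32),
        )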
    
    def observe(self, gs=None) -> np.ndarray:
        return np.array([
            self.battery.state.features['soc'].soc,
            self.gen.state.features['output'].p_mw / self.gen.p_max,
            self.load / 10.0
        ], dtype=np.float32)
    
    def step(self, action: np.ndarray, dt: float = 1.0) -> Dict:
        bat_res = self.battery.step(action[0:1], dt)
        gen_res = self.gen.step(action[1:2], dt)
        net_power = gen_res['power_mw'] - bat_res['power_mw']
        imbalance = abs(self.load - net_power)
        return {'net_power': net_power, 'imbalance': imbalance, 'cost': gen_res['cost']}
    
    def reset(self, seed=None):
        self.battery.reset(seed)
        self.gen.reset(seed)
        return self.observe()


# ============================================
# Multi-Agent Environment (using HERON adapter)
# ============================================
class SimpleMicrogridEnv(PettingZooParallelEnv):
    """Multi-agent microgrid environment using HERON's PettingZoo adapter.
    
    Using PettingZooParallelEnv gives us:
    - Built-in agent registration and management
    - Event-driven execution support (for testing)
    - Message broker integration (for distributed mode)
    - Automatic space initialization helpers
    """
    
    metadata = {'render_modes': ['human'], 'name': 'simple_microgrids_v0'}
    
    def __init__(self, config: Optional[Dict] = None):
        # Initialize HERON's adapter first
        super().__init__(env_id="simple_microgrids")
        
        config = config or {}
        self.num_microgrids = config.get('num_microgrids', 3)
        self.max_steps = config.get('max_episode_steps', 96)
        self.share_reward = config.get('share_reward', True)
        self.penalty = config.get('penalty', 10.0)
        
        # Create agents with varying loads
        loads = [3.0, 4.0, 2.5]
        self.agents_dict = {}
        for i in range(self.num_microgrids):
            agent_id = f'mg_{i}'
            agent = SimpleMicrogrid(agent_id, load=loads[i % len(loads)])
            self.agents_dict[agent_id] = agent
            # Register with HERON (enables event-driven, messaging, etc.)
            self.register_agent(agent)
        
        # Setup PettingZoo attributes using HERON helpers
        self._set_agent_ids(list(self.agents_dict.keys()))
        self._init_spaces(
            action_spaces={aid: a.action_space for aid, a in self.agents_dict.items()},
            observation_spaces={aid: a.observation_space for aid, a in self.agents_dict.items()},
        )
        
        self._step_count = 0
    
    @property
    def observation_space(self): return self.observation_spaces
    
    @property
    def action_space(self): return self.action_spaces
    
    def reset(self, seed=None, options=None):
        self._step_count = 0
        self._agents = self._possible_agents.copy()
        
        # Use HERON's reset helper
        self.reset_agents(seed=seed)
        
        obs = {aid: a.observe() for aid, a in self.agents_dict.items()}
        return obs, {aid: {} for aid in self.agents}
    
    def step(self, actions):
        self._step_count += 1
        self._timestep = self._step_count
        
        results = {}
        total_cost, total_imbalance = 0.0, 0.0
        
        for aid, agent in self.agents_dict.items():
            # Fall back to a random action only when this agent's action is missing
            action = actions[aid] if aid in actions else agent.action_space.sample()
            results[aid] = agent.step(action)
            total_cost += results[aid]['cost']
            total_imbalance += results[aid]['imbalance']
        
        # Reward: minimize cost and imbalance
        collective_reward = -(total_cost + self.penalty * total_imbalance)
        
        if self.share_reward:
            rewards = {aid: collective_reward / self.num_microgrids for aid in self.agents}
        else:
            rewards = {aid: -(results[aid]['cost'] + self.penalty * results[aid]['imbalance']) 
                      for aid in self.agents}
        
        obs = {aid: a.observe() for aid, a in self.agents_dict.items()}
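        # Episodes end only by reaching the step limit
        done = self._step_count >= self.max_steps
        terminateds = {aid: done for aid in self.agents}
        # '__all__' is the RLlib-style whole-episode flag; the Step 4 evaluation loop reads it
        terminateds['__all__'] = done
        truncateds = {aid: False for aid in self.agents}
        truncateds['__all__'] = False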
        infos = {aid: {'cost': results[aid]['cost'], 'imbalance': results[aid]['imbalance']} 
                for aid in self.agents}
        
        return obs, rewards, terminateds, truncateds, infos
    
    def render(self): pass


print("Environment module ready (using HERON's PettingZooParallelEnv)!")
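
To make the dynamics in the cell above concrete, here is a minimal standalone sketch of one `SimpleMicrogrid` step written without HERON. The function name `microgrid_step` and the unit suffixes on the constants are assumptions made for illustration; the numbers are the defaults from the classes above.

```python
# Defaults copied from SimpleBattery / SimpleGen above (unit suffixes are assumed).
CAPACITY_MWH = 2.0
MAX_POWER_MW = 0.5
P_MAX_MW = 5.0
COST_PER_MWH = 50.0


def microgrid_step(soc, action, load, dt=1.0):
    """Hypothetical plain-Python mirror of one SimpleMicrogrid.step call."""
    bat_power = action[0] * MAX_POWER_MW   # > 0 charges, < 0 discharges
    gen_power = action[1] * P_MAX_MW
    # Same clipping range as BatterySOC.set_values
    new_soc = min(max(soc + bat_power * dt / CAPACITY_MWH, 0.1), 0.9)
    net_power = gen_power - bat_power      # charging consumes generation
    result = {
        'net_power': net_power,
        'imbalance': abs(load - net_power),
        'cost': gen_power * dt * COST_PER_MWH,
    }
    return new_soc, result


# Generator exactly covers a 2.5 MW load while the battery idles.
soc, res = microgrid_step(0.5, (0.0, 0.5), load=2.5)
print(soc, res)

# Charging at full power twice: the second step runs into the 0.9 SOC clip.
soc1, _ = microgrid_step(0.5, (1.0, 0.0), load=2.5)
soc2, _ = microgrid_step(soc1, (1.0, 0.0), load=2.5)
print(soc1, soc2)
```

As in the tutorial code, the clip limits the stored SOC but not the reported battery power, so a fully charged battery still counts as drawing power in `net_power`.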

## Step 2: Setup RLlib Training

Now let's configure RLlib to train MAPPO on our environment.

In [None]:
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.env.wrappers.pettingzoo_env import ParallelPettingZooEnv
from ray.tune.registry import register_env

# Initialize Ray
ray.init(ignore_reinit_error=True, num_cpus=2)

# Environment creator function for RLlib
def env_creator(config):
    """Create environment wrapped for RLlib."""
    env = SimpleMicrogridEnv(config)
    return ParallelPettingZooEnv(env)

# Register the environment
register_env("simple_microgrids", env_creator)

print("RLlib initialized and environment registered!")

In [None]:
# Create a test environment to get spaces
env_config = {
    'num_microgrids': 3,
    'max_episode_steps': 48,  # Shorter episodes for tutorial
    'share_reward': True,
    'penalty': 10.0,
}

test_env = env_creator(env_config)

# Get agent IDs and spaces
possible_agents = test_env.par_env.possible_agents
print(f"Agents: {possible_agents}")

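# MAPPO-style parameter sharing: all agents map to one shared policy.
# (With this setup the value function still sees only each agent's own observation.)
# Use first agent's spaces (they're all identical)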
first_agent = possible_agents[0]
policies = {
    'shared_policy': (
        None,  # Policy class (None = default)
        test_env.observation_space[first_agent],
        test_env.action_space[first_agent],
        {}  # Policy config
    )
}

# All agents use the shared policy
policy_mapping_fn = lambda agent_id, *args, **kwargs: 'shared_policy'

print(f"Observation space: {test_env.observation_space[first_agent]}")
print(f"Action space: {test_env.action_space[first_agent]}")
print(f"Policy mapping: All agents -> 'shared_policy' (MAPPO)")
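
The mapping function is the only thing that decides which policy an agent uses. The sketch below contrasts the shared mapping from the cell above with a hypothetical per-agent alternative; `independent_mapping` and the `policy_<id>` naming are illustrative assumptions, not part of the tutorial.

```python
agent_ids = ['mg_0', 'mg_1', 'mg_2']


def shared_mapping(agent_id, *args, **kwargs):
    """Parameter sharing, as in this tutorial: every agent uses one policy."""
    return 'shared_policy'


def independent_mapping(agent_id, *args, **kwargs):
    """Hypothetical alternative: one separately trained policy per agent."""
    return f'policy_{agent_id}'


shared = {aid: shared_mapping(aid) for aid in agent_ids}
independent = {aid: independent_mapping(aid) for aid in agent_ids}

print(shared)
print(independent)
print(len(set(shared.values())), len(set(independent.values())))
```

Switching to independent policies would also require one entry per returned name in the `policies` dict, since RLlib looks up whatever ID the mapping function returns.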

In [None]:
# Configure PPO algorithm
config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment(
        env="simple_microgrids",
        env_config=env_config,
        disable_env_checking=True,
    )
    .framework("torch")
    .training(
        lr=5e-4,  # Learning rate
        gamma=0.99,  # Discount factor
        lambda_=0.95,  # GAE lambda
        entropy_coeff=0.01,  # Exploration bonus
        clip_param=0.2,  # PPO clip parameter
    )
    .multi_agent(
        policies=policies,
        policy_mapping_fn=policy_mapping_fn,
    )
    .resources(
        num_gpus=0,  # CPU only for tutorial
    )
    .env_runners(
        num_env_runners=1,  # Single worker for tutorial
        num_envs_per_env_runner=1,
    )
)

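# Additional settings (attribute names as used with the old API stack selected above;
# some of them are renamed in newer Ray releases, so check your installed version)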
config.train_batch_size = 500
config.sgd_minibatch_size = 64
config.num_sgd_iter = 5
config.model = {
    'fcnet_hiddens': [64, 64],  # Small network for tutorial
    'fcnet_activation': 'relu',
}
config.preprocessor_pref = None
config.enable_connectors = False

print("PPO config ready!")
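
A quick back-of-envelope check of what these settings imply for the amount of experience collected. This sketch assumes RLlib counts `train_batch_size` in environment steps (its `count_steps_by="env_steps"` setting), and it uses the 10 training iterations from Step 3.

```python
train_batch_size = 500    # env steps collected per training iteration (assumed unit)
max_episode_steps = 48    # from env_config
num_agents = 3            # num_microgrids
num_iterations = 10       # training loop length in Step 3

episodes_per_iter = train_batch_size / max_episode_steps
agent_steps_per_iter = train_batch_size * num_agents
total_env_steps = train_batch_size * num_iterations

print(f"Episodes per iteration: about {episodes_per_iter:.1f}")
print(f"Agent steps per iteration: {agent_steps_per_iter}")
print(f"Env steps over the whole run: {total_env_steps}")
```

Roughly ten episodes per iteration and a few thousand environment steps in total is enough for a demonstration, but too little to expect a converged policy.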

## Step 3: Train the Agent

Now let's train for a few iterations.

In [None]:
# Build the algorithm
algo = config.build()

print("Training MAPPO on Simple Microgrids...")
print("="*60)
print(f"{'Iter':>5} | {'Reward':>12} | {'Episodes':>10} | {'Steps':>10}")
print("-"*60)

# Training loop - 10 iterations for tutorial
num_iterations = 10
rewards = []

for i in range(num_iterations):
    result = algo.train()
    
    # Extract metrics
    env_runners = result.get('env_runners', {})
    reward_mean = env_runners.get('episode_reward_mean', 0)
    episodes = env_runners.get('episodes_this_iter', 0)
    timesteps = result.get('timesteps_total', 0)
    
    rewards.append(reward_mean)
    print(f"{i+1:5d} | {reward_mean:12.2f} | {episodes:10d} | {timesteps:10d}")

print("-"*60)
print(f"Training complete! Final reward: {rewards[-1]:.2f}")

In [None]:
# Plot training progress
import matplotlib.pyplot as plt

plt.figure(figsize=(10, 4))
plt.plot(range(1, len(rewards)+1), rewards, 'b-o')
plt.xlabel('Iteration')
plt.ylabel('Mean Episode Reward')
plt.title('MAPPO Training on Simple Microgrids')
plt.grid(True, alpha=0.3)
plt.show()
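
With only 10 iterations the raw curve can be noisy, and RLlib may report `NaN` for the mean reward before any episode has finished. A minimal smoothing helper is sketched below; `moving_average` is a hypothetical name, and the reward history is made up for illustration rather than taken from a real run.

```python
import math


def moving_average(values, window=3):
    """Trailing mean over up to `window` entries; NaN entries are skipped."""
    smoothed = []
    for i in range(len(values)):
        recent = [v for v in values[max(0, i - window + 1): i + 1]
                  if not math.isnan(v)]
        smoothed.append(sum(recent) / len(recent) if recent else float('nan'))
    return smoothed


# Made-up history for illustration only (not real training output).
history = [float('nan'), -900.0, -700.0, -800.0, -500.0]
print(moving_average(history, window=3))
```

Passing `moving_average(rewards)` to `plt.plot` alongside the raw values is one way to see the trend more clearly.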

## Step 4: Evaluate the Learned Policy

Let's see how the trained agents behave.

In [None]:
# Create evaluation environment
eval_env = SimpleMicrogridEnv(env_config)

print("Evaluating trained policy...")
print("="*60)

obs, info = eval_env.reset()
done = False
total_rewards = {aid: 0.0 for aid in eval_env.agents}
total_cost = 0.0
total_imbalance = 0.0
step = 0

while not done:
    # Get actions from trained policy
    actions = {}
    for agent_id, agent_obs in obs.items():
        actions[agent_id] = algo.compute_single_action(
            agent_obs,
            policy_id='shared_policy',
            explore=False
        )
    
    # Step environment
    # step_rewards avoids overwriting the `rewards` list from the training cell
    obs, step_rewards, terminateds, truncateds, infos = eval_env.step(actions)
    
    # Accumulate metrics
    for aid in eval_env.agents:
        total_rewards[aid] += step_rewards[aid]
        total_cost += infos[aid]['cost']
        total_imbalance += infos[aid]['imbalance']
    
    done = terminateds.get('__all__', False)
    step += 1
    
    # Print every 10 steps
    if step % 10 == 0:
        print(f"Step {step:3d}: Cost={total_cost/step:.2f}/step, Imbalance={total_imbalance/step:.2f}/step")

print("="*60)
print(f"Episode complete after {step} steps")
print(f"Mean episode return per agent: {sum(total_rewards.values())/len(total_rewards):.2f}")
print(f"Avg cost per step: {total_cost/step:.2f}")
print(f"Avg imbalance per step: {total_imbalance/step:.2f}")
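
To put the evaluation numbers in context, it helps to compare them with hand-written baselines. The sketch below recomputes the shared reward for two fixed strategies using the tutorial's defaults (loads 3.0, 4.0, 2.5; cost 50; penalty 10; 48 steps) with the battery idle. The function `per_agent_return` is a hypothetical helper, not part of HERON or RLlib.

```python
LOADS_MW = [3.0, 4.0, 2.5]
P_MAX_MW = 5.0
COST_PER_MWH = 50.0
PENALTY = 10.0
STEPS = 48


def per_agent_return(gen_fraction_of_load):
    """Shared per-agent episode return with the battery idle.

    gen_fraction_of_load: 1.0 means each generator matches its load,
    0.0 means every generator stays off.
    """
    total_cost = 0.0
    total_imbalance = 0.0
    for load in LOADS_MW:
        gen = min(gen_fraction_of_load * load, P_MAX_MW)
        total_cost += gen * COST_PER_MWH          # dt = 1.0
        total_imbalance += abs(load - gen)
    step_reward = -(total_cost + PENALTY * total_imbalance) / len(LOADS_MW)
    return step_reward * STEPS


match_load = per_agent_return(1.0)
generator_off = per_agent_return(0.0)
print(f"Match load:    {match_load:.1f}")
print(f"Generator off: {generator_off:.1f}")
```

Under these defaults the two baselines come out at roughly -7600 and -1520 per agent, so leaving the generators off scores higher: one MW of generation costs 50 while one MW of imbalance is penalized by only 10. A trained policy that drifts toward low generator output is therefore consistent with the reward as written; if balancing the load is the real objective, a `penalty` above the generator cost would change that trade-off.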

## Step 5: Cleanup

In [None]:
# Stop the algorithm and shutdown Ray
algo.stop()
ray.shutdown()
print("Cleanup complete!")

## What We Built

In this tutorial series, we built a **complete MARL case study from scratch**:

| Component | What It Does | HERON Abstraction |
|-----------|--------------|-------------------|
| `BatterySOC`, `GenOutput` | Observable features | `FeatureProvider` |
| `SimpleBattery`, `SimpleGen` | Device agents | `FieldAgent` |
| `SimpleMicrogrid` | Coordinator | `CoordinatorAgent` |
| `SimpleMicrogridEnv` | Environment | `PettingZooParallelEnv` |
| MAPPO training | Policy learning | RLlib + PettingZoo |

### Key HERON Patterns Used

1. **HERON Environment Adapters**
   - Use `PettingZooParallelEnv` (not raw `ParallelEnv`), as this tutorial does
   - `RLlibMultiAgentEnv` is the alternative for direct RLlib integration
   - The adapters enable event-driven execution, the message broker, and agent management

2. **Agent Registration**
   ```python
   self.register_agent(agent)  # HERON tracks agents
   self._set_agent_ids([...])  # Setup PettingZoo attributes
   self._init_spaces(...)      # Initialize spaces
   ```

3. **Visibility-based observation filtering**
   - Features declare who can see them
   - No manual filtering code

4. **Hierarchical agent structure**
   - Devices (L1) -> Microgrids (L2)
   - Clear parent-child relationships

5. **Protocol-based coordination**
   - `SetpointProtocol` for hierarchical control
   - Swappable without changing agent code
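
The visibility pattern from item 3 can be illustrated with a small standalone sketch. This is not HERON's implementation: the `Feature` class and `visible_to` function are hypothetical stand-ins that only show how a declared visibility list can drive filtering.

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class Feature:
    """Hypothetical stand-in for a FeatureProvider: a value plus who may see it."""
    value: float
    visibility: List[str] = field(default_factory=list)


def visible_to(features: Dict[str, Feature], role: str) -> Dict[str, float]:
    """Return only the features whose visibility list includes `role`."""
    return {name: f.value for name, f in features.items() if role in f.visibility}


# Visibility lists copied from BatterySOC and GenOutput in Step 1.
features = {
    'soc': Feature(0.5, ['owner', 'upper_level']),
    'p_norm': Feature(0.0, ['owner', 'upper_level', 'system']),
}

coordinator_view = visible_to(features, 'upper_level')
system_view = visible_to(features, 'system')
print(coordinator_view)
print(system_view)
```

The coordinator sees both features, while a system-level observer sees only the generator output, matching the visibility declared on each feature.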

### Next Steps

To build a production case study like `examples/05_mappo_training.py`:

1. **Add more features**: Voltage, frequency, prices, weather
2. **Use real physics**: PandaPower, OpenDSS, SUMO
3. **Add event-driven mode**: See Tutorial 06 for heterogeneous tick rates, delays
4. **Compare protocols**: Setpoint vs PriceSignal vs Consensus
5. **Add visibility ablation**: Test different observability levels

## Summary

To recap, here is what each step produced:

| Step | Time | What You Built |
|------|------|----------------|
| Features | 2 min | `BatterySOC`, `GenOutput` with visibility |
| Agents | 3 min | `SimpleBattery`, `SimpleGen`, `SimpleMicrogrid` |
| Environment | 2 min | `SimpleMicrogridEnv` with HERON adapter |
| Training | 5 min | MAPPO with shared policy (CTDE) |

**One caveat: we trained in synchronous mode.** Will this policy still work in the real world, where agents have different update rates and communication delays?

---

**Next:** [06_event_driven_testing.ipynb](06_event_driven_testing.ipynb) — Validate policy robustness with realistic timing

For production examples:
- `examples/05_mappo_training.py` — Full MAPPO script
- `powergrid/` — Complete case study (14 features, 4 protocols)