# Tutorial 4: Training with RLlib

**Goal:** Train a multi-agent policy on our microgrid environment using PPO with one policy shared by all agents (MAPPO-style parameter sharing).

**Time:** ~15 minutes (including training)

---

## End-to-End Workflow

This tutorial demonstrates the complete HERON + RLlib workflow:

```
1. Define Features (FeatureProvider with ClassVar visibility)
       |
2. Create Agents (FieldAgent, CoordinatorAgent, SystemAgent)
       |
3. Build Environment (MultiAgentEnv with system_agent)
       |
4. Wrap for RLlib (Gymnasium-compatible interface)
       |
5. Train PPO with a shared policy (MAPPO-style parameter sharing)
       |
6. Evaluate (in synchronous mode)
```

## Step 1: Complete Environment Definition

Let's put everything from Tutorials 1-3 into a single, self-contained module.

In [None]:
import warnings
warnings.filterwarnings('ignore')

import numpy as np
from typing import Any, ClassVar, Dict, List, Optional, Sequence
from dataclasses import dataclass, field
from gymnasium.spaces import Box

from heron.core.feature import FeatureProvider
from heron.core.action import Action
from heron.agents.field_agent import FieldAgent
from heron.agents.coordinator_agent import CoordinatorAgent
from heron.agents.system_agent import SystemAgent
from heron.envs.base import MultiAgentEnv
from heron.scheduling.tick_config import TickConfig


# ============================================
# Features (with ClassVar visibility)
# ============================================
@dataclass(slots=True)
class BatterySOC(FeatureProvider):
    """Battery state of charge - visible to owner and coordinator.

    Attributes:
        soc: State of charge, presumably as a fraction of capacity (0-1);
            ``SimpleBattery.set_state`` clips it to [0.1, 0.9].
        capacity: Copy of the owning ``SimpleBattery``'s ``capacity`` argument.
            Note the SOC update in ``SimpleBattery.apply_action`` reads the
            agent attribute ``self.capacity``, not this field.
    """
    # Observer levels allowed to see this feature ('upper_level' = the
    # coordinator, per the docstring); interpreted by HERON's FeatureProvider.
    visibility: ClassVar[Sequence[str]] = ['owner', 'upper_level']
    soc: float = 0.5
    capacity: float = 2.0


@dataclass(slots=True)
class GenOutput(FeatureProvider):
    """Generator output - visible to owner, coordinator, and system.

    Attributes:
        p_mw: Current output (MW, per the field name); ``SimpleGen.apply_action``
            sets it to ``action * p_max``.
        p_max: Copy of the owning ``SimpleGen``'s ``p_max`` argument.
    """
    # Observer levels allowed to see this feature; interpreted by HERON's
    # FeatureProvider (exact semantics live outside this notebook).
    visibility: ClassVar[Sequence[str]] = ['owner', 'upper_level', 'system']
    p_mw: float = 0.0
    p_max: float = 5.0


@dataclass(slots=True)
class SystemFrequency(FeatureProvider):
    """System frequency feature, carried by the system agent.

    Attributes:
        frequency_hz: Grid frequency; starts at 60.0 and is overwritten in the
            serialized global state by
            ``SimpleMicrogridEnv.env_state_to_global_state``.
    """
    # Only the system level may observe this feature.
    visibility: ClassVar[Sequence[str]] = ['system']
    frequency_hz: float = 60.0


# ============================================
# Field Agents (L1)
# ============================================
class SimpleBattery(FieldAgent):
    """Battery field agent (L1) controlled by one continuous charge/discharge action.

    Each step the action ``a`` (clipped to ``[-1, 1]``) changes SOC by
    ``a * max_power / capacity``; ``set_state`` clips SOC to ``[0.1, 0.9]``.
    ``safety`` is the amount by which SOC lies outside ``[0.2, 0.8]``.

    Args:
        agent_id: Unique agent identifier.
        capacity: Divisor that converts power into SOC change; also stored on
            the ``BatterySOC`` feature.
        max_power: Power corresponding to an action of magnitude 1.
        **kwargs: Forwarded to ``FieldAgent`` (e.g. ``upstream_id``).
    """

    def __init__(self, agent_id: str, capacity: float = 2.0, max_power: float = 0.5, **kwargs):
        self.capacity = capacity
        self.max_power = max_power
        # Features are handed to the parent constructor rather than added afterwards.
        features = [BatterySOC(soc=0.5, capacity=capacity)]
        super().__init__(
            agent_id=agent_id,
            features=features,
            tick_config=TickConfig.deterministic(tick_interval=1.0),
            **kwargs
        )

    def init_action(self, features: Optional[List[FeatureProvider]] = None) -> Action:
        """Declare a 1-D continuous action bounded to ``[-1, 1]``.

        ``features`` is unused. It defaults to ``None`` rather than a shared
        mutable ``[]``.
        """
        action = Action()
        action.set_specs(dim_c=1, range=(np.array([-1.0]), np.array([1.0])))
        return action

    def set_action(self, action: Any, *args, **kwargs) -> None:
        """Store the first element of ``action`` (scalar, list, or array of any rank).

        Empty input is treated as 0.0 (no charge or discharge).
        """
        # np.asarray(...).ravel() also copes with 0-d numpy arrays, which
        # define __iter__ but raise TypeError on len().
        flat = np.asarray(action, dtype=np.float64).ravel()
        self.action.set_values(float(flat[0]) if flat.size else 0.0)

    def set_state(self, **kwargs) -> None:
        """Update SOC from ``soc=...``, clipped to ``[0.1, 0.9]``; other keys are ignored."""
        if 'soc' in kwargs:
            self.state.features['BatterySOC'].soc = np.clip(float(kwargs['soc']), 0.1, 0.9)

    def apply_action(self) -> None:
        """Advance SOC by one step of charging (a > 0) or discharging (a < 0)."""
        soc_feature = self.state.features['BatterySOC']
        # Keep the action inside its declared range; a policy shared with other
        # agent types may emit values outside [-1, 1].
        fraction = float(np.clip(self.action.vector()[0], -1.0, 1.0))
        power = fraction * self.max_power
        new_soc = soc_feature.soc + power / self.capacity
        self.set_state(soc=new_soc)

    def compute_local_reward(self, local_state: dict) -> float:
        """No local reward; the reward is computed collectively elsewhere."""
        return 0.0

    @property
    def cost(self) -> float:
        """Batteries have no operating cost in this model."""
        return 0.0

    @property
    def safety(self) -> float:
        """Distance of SOC outside the comfort band ``[0.2, 0.8]`` (0 inside it)."""
        soc = self.state.features['BatterySOC'].soc
        return max(0, 0.2 - soc) + max(0, soc - 0.8)


class SimpleGen(FieldAgent):
    """Generator field agent (L1) whose action sets output as a fraction of ``p_max``.

    Output is ``clip(a, 0, 1) * p_max``; ``cost`` is ``p_mw * cost_per_mwh``.

    Args:
        agent_id: Unique agent identifier.
        p_max: Output produced by an action of 1.0.
        cost_per_mwh: Multiplier applied to ``p_mw`` in ``cost``.
        **kwargs: Forwarded to ``FieldAgent`` (e.g. ``upstream_id``).
    """

    def __init__(self, agent_id: str, p_max: float = 5.0, cost_per_mwh: float = 50.0, **kwargs):
        self.p_max = p_max
        self.cost_per_mwh = cost_per_mwh
        features = [GenOutput(p_mw=0.0, p_max=p_max)]
        super().__init__(
            agent_id=agent_id,
            features=features,
            tick_config=TickConfig.deterministic(tick_interval=1.0),
            **kwargs
        )

    def init_action(self, features: Optional[List[FeatureProvider]] = None) -> Action:
        """Declare a 1-D continuous action bounded to ``[0, 1]``.

        ``features`` is unused. It defaults to ``None`` rather than a shared
        mutable ``[]``.
        """
        action = Action()
        action.set_specs(dim_c=1, range=(np.array([0.0]), np.array([1.0])))
        return action

    def set_action(self, action: Any, *args, **kwargs) -> None:
        """Store the first element of ``action`` (scalar, list, or array of any rank).

        Empty input is treated as 0.0 (no output).
        """
        # np.asarray(...).ravel() also copes with 0-d numpy arrays, which
        # define __iter__ but raise TypeError on len().
        flat = np.asarray(action, dtype=np.float64).ravel()
        self.action.set_values(float(flat[0]) if flat.size else 0.0)

    def set_state(self, **kwargs) -> None:
        """Update output from ``p_mw=...``; other keys are ignored."""
        if 'p_mw' in kwargs:
            self.state.features['GenOutput'].p_mw = float(kwargs['p_mw'])

    def apply_action(self) -> None:
        """Set output to the commanded fraction of ``p_max``."""
        # Clip to [0, 1]. A policy shared with batteries acts in [-1, 1]; a
        # negative value would give negative output and a negative
        # (reward-increasing) cost.
        fraction = float(np.clip(self.action.vector()[0], 0.0, 1.0))
        self.set_state(p_mw=fraction * self.p_max)

    def compute_local_reward(self, local_state: dict) -> float:
        """No local reward; the reward is computed collectively elsewhere."""
        return 0.0

    @property
    def cost(self) -> float:
        """Operating cost: current output times ``cost_per_mwh``."""
        return self.state.features['GenOutput'].p_mw * self.cost_per_mwh

    @property
    def safety(self) -> float:
        """Generators contribute no safety violation in this model."""
        return 0.0


# ============================================
# Coordinator Agent (L2)
# ============================================
class SimpleMicrogrid(CoordinatorAgent):
    """Coordinator (L2) that owns exactly one battery and one generator.

    The children are named ``<agent_id>_bat`` and ``<agent_id>_gen`` and
    point back to this coordinator through ``upstream_id``. ``cost`` and
    ``safety`` are plain sums over the children.

    Args:
        agent_id: Identifier of this microgrid.
        load: Demand attached to this microgrid (read by the environment).
        **kwargs: Forwarded to ``CoordinatorAgent``.
    """

    def __init__(self, agent_id: str, load: float = 3.0, **kwargs):
        self.load = load

        # Children are created first and handed to the parent constructor.
        children = (
            SimpleBattery(f'{agent_id}_bat', upstream_id=agent_id),
            SimpleGen(f'{agent_id}_gen', upstream_id=agent_id),
        )

        super().__init__(
            agent_id=agent_id,
            subordinates={child.agent_id: child for child in children},
            tick_config=TickConfig.deterministic(tick_interval=60.0),
            **kwargs
        )

    @property
    def cost(self) -> float:
        """Sum of the children's costs."""
        running = 0
        for child in self.subordinates.values():
            running += child.cost
        return running

    @property
    def safety(self) -> float:
        """Sum of the children's safety violations."""
        running = 0
        for child in self.subordinates.values():
            running += child.safety
        return running


# ============================================
# System Agent (L3)
# ============================================
class SimpleGridSystem(SystemAgent):
    """Top-level (L3) agent that owns every microgrid coordinator.

    Carries a single ``SystemFrequency`` feature (initialised to 60.0) and
    reports ``cost`` and ``safety`` as sums over its microgrids.

    Args:
        agent_id: Identifier of the system agent.
        subordinates: Mapping of microgrid id to coordinator.
    """

    def __init__(self, agent_id: str, subordinates: Dict[str, SimpleMicrogrid]):
        super().__init__(
            agent_id=agent_id,
            features=[SystemFrequency(frequency_hz=60.0)],
            subordinates=subordinates,
            tick_config=TickConfig.deterministic(tick_interval=300.0),
        )

    @property
    def cost(self) -> float:
        """Total cost across all microgrids."""
        accumulated = 0
        for microgrid in self.subordinates.values():
            accumulated += microgrid.cost
        return accumulated

    @property
    def safety(self) -> float:
        """Total safety violation across all microgrids."""
        accumulated = 0
        for microgrid in self.subordinates.values():
            accumulated += microgrid.safety
        return accumulated


print("Agents defined with current HERON API!")

In [None]:
# ============================================
# Environment State (bridge to simulation)
# ============================================
@dataclass
class SimpleEnvState:
    """Plain container passed between HERON state and the toy physics step.

    Attributes:
        battery_soc: Microgrid id -> battery state of charge.
        gen_power: Microgrid id -> generator output.
        frequency: System frequency written by the simulation (default 60.0).
    """
    # Each instance gets its own fresh dicts.
    battery_soc: Dict[str, float] = field(default_factory=lambda: {})
    gen_power: Dict[str, float] = field(default_factory=lambda: {})
    frequency: float = 60.0


# ============================================
# Multi-Agent Environment
# ============================================
class SimpleMicrogridEnv(MultiAgentEnv):
    """HERON ``MultiAgentEnv`` for the microgrid tutorial.

    Bridges HERON's serialized agent state and a toy physics model. The model
    compares total generator output with the fixed microgrid loads and maps
    the imbalance to frequency: ``60 + 0.01 * (generation - load)``.

    Args:
        system_agent: Root of the agent hierarchy.
        max_steps: Episode length, read by wrappers via ``_step_count``.
        share_reward: Stored for callers; not read by this class.
        penalty: Weight on safety violations, read by wrappers.
    """

    def __init__(
        self,
        system_agent: SystemAgent,
        max_steps: int = 96,
        share_reward: bool = True,
        penalty: float = 10.0,
    ):
        # All bookkeeping is in place before the parent constructor runs.
        self.max_steps = max_steps
        self.share_reward = share_reward
        self.penalty = penalty
        self._step_count = 0

        # Snapshot each microgrid's load; coordinators without one count as 3.0.
        self._mg_loads = {
            mg_id: getattr(mg, 'load', 3.0)
            for mg_id, mg in system_agent.subordinates.items()
        }

        super().__init__(
            system_agent=system_agent,
            env_id="simple_microgrids",
        )

    def global_state_to_env_state(self, global_state: Dict) -> SimpleEnvState:
        """Collect battery SOC and generator output per microgrid.

        The microgrid key is everything before the last underscore of the
        agent id (``mg_0_bat`` -> ``mg_0``).
        """
        sim_state = SimpleEnvState()
        for aid, serialized in global_state.get("agent_states", {}).items():
            feats = serialized.get("features", {})
            if "BatterySOC" in feats:
                owner = "_".join(aid.split("_")[:-1])
                sim_state.battery_soc[owner] = feats["BatterySOC"].get("soc", 0.5)
            if "GenOutput" in feats:
                owner = "_".join(aid.split("_")[:-1])
                sim_state.gen_power[owner] = feats["GenOutput"].get("p_mw", 0.0)
        return sim_state

    def run_simulation(self, env_state: SimpleEnvState, *args, **kwargs) -> SimpleEnvState:
        """Update frequency from the supply/demand imbalance and count the step."""
        supply = sum(env_state.gen_power.values())
        demand = sum(self._mg_loads.values())
        env_state.frequency = 60.0 + 0.01 * (supply - demand)
        self._step_count += 1
        return env_state

    def env_state_to_global_state(self, env_state: SimpleEnvState) -> Dict:
        """Write the simulated frequency into every serialized ``SystemFrequency`` feature."""
        serialized_states = self.proxy_agent.get_serialized_agent_states()
        for serialized in serialized_states.values():
            feats = serialized.get("features", {})
            if "SystemFrequency" in feats:
                feats["SystemFrequency"]["frequency_hz"] = env_state.frequency
        return {"agent_states": serialized_states}

    def reset(self, *, seed: Optional[int] = None, **kwargs):
        """Zero the step counter, then defer to HERON's reset."""
        self._step_count = 0
        return super().reset(seed=seed, **kwargs)


print("Environment defined with MultiAgentEnv pattern!")

In [None]:
# ============================================
# Factory function to create environment
# ============================================
def create_env(config: Optional[Dict] = None) -> SimpleMicrogridEnv:
    """Build a ``SimpleMicrogridEnv`` from a config dict, constructing agents bottom-up.

    Recognized keys (all optional):
        num_microgrids (int, default 3): number of ``SimpleMicrogrid`` coordinators.
        loads (sequence of float, default ``[3.0, 4.0, 2.5]``): per-microgrid
            loads, reused cyclically when there are more microgrids than entries.
            An empty or missing value falls back to the default.
        max_episode_steps (int, default 96): forwarded as ``max_steps``.
        share_reward (bool, default True): forwarded to the environment.
        penalty (float, default 10.0): forwarded to the environment.

    Other keys are ignored, so the same dict can be used as RLlib's ``env_config``.
    """
    config = config or {}
    num_microgrids = config.get('num_microgrids', 3)
    # `or` also catches an empty sequence, which would break the modulo below.
    loads = list(config.get('loads') or (3.0, 4.0, 2.5))

    # Coordinators first (each builds its own field agents), then the system agent.
    microgrids = {
        f'mg_{i}': SimpleMicrogrid(agent_id=f'mg_{i}', load=loads[i % len(loads)])
        for i in range(num_microgrids)
    }

    system_agent = SimpleGridSystem(
        agent_id='grid_system',
        subordinates=microgrids,
    )

    return SimpleMicrogridEnv(
        system_agent=system_agent,
        max_steps=config.get('max_episode_steps', 96),
        share_reward=config.get('share_reward', True),
        penalty=config.get('penalty', 10.0),
    )

# Smoke-test the factory with three microgrids and 48-step episodes.
test_env = create_env({'num_microgrids': 3, 'max_episode_steps': 48})
num_built = len(test_env._mg_loads)
print(f"Created env with {num_built} microgrids")
agent_names = list(test_env.registered_agents.keys())
print(f"Registered agents: {agent_names}")

## Step 2: Setup RLlib Training

Now let's configure RLlib to train MAPPO on our environment.

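**Note:** RLlib's multi-agent API expects per-agent dictionaries of observations, rewards, and done flags, keyed by agent ID. We'll create a simple wrapper that exposes our environment in that format.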

In [None]:
import gymnasium as gym
from gymnasium.spaces import Dict as DictSpace
from ray.rllib.env.multi_agent_env import MultiAgentEnv as RLlibMultiAgentEnv


class RLlibEnvWrapper(RLlibMultiAgentEnv):
    """Expose the HERON ``SimpleMicrogridEnv`` through RLlib's multi-agent API.

    RLlib routes per-agent dicts to policies only for subclasses of its own
    ``MultiAgentEnv``. A plain ``gym.Env`` is treated as a single agent whose
    observation is the whole dict.

    Every field agent (battery or generator) is an RL agent. Its observation
    is ``[soc or p_mw / p_max, load / 10]``. All agents get an equal share of
    the collective reward ``-(cost + penalty * safety)``; the ``share_reward``
    config value is not consulted here. The episode terminates once the HERON
    env has run ``max_steps`` simulation steps.
    """

    def __init__(self, config: Dict = None):
        self.env = create_env(config or {})

        # The field agents under each microgrid are the learning agents.
        self.agent_ids = [
            sub_id
            for mg in self.env._system_agent.subordinates.values()
            for sub_id in mg.subordinates.keys()
        ]

        # Per-agent spaces
        self._obs_space = Box(-np.inf, np.inf, (2,), np.float32)  # [soc or p_norm, load_norm]
        self._act_space_bat = Box(np.array([-1.0]), np.array([1.0]), dtype=np.float32)
        self._act_space_gen = Box(np.array([0.0]), np.array([1.0]), dtype=np.float32)

        self.observation_spaces = {aid: self._obs_space for aid in self.agent_ids}
        self.action_spaces = {
            aid: self._act_space_bat if 'bat' in aid else self._act_space_gen
            for aid in self.agent_ids
        }
        # Dict spaces keyed by agent ID (the format RLlib's old API stack reads).
        self.observation_space = DictSpace(self.observation_spaces)
        self.action_space = DictSpace(self.action_spaces)

        # Set the agent bookkeeping before the base __init__ so it is not
        # replaced by empty defaults.
        self._agent_ids = set(self.agent_ids)
        self.possible_agents = list(self.agent_ids)
        self.agents = list(self.agent_ids)

        super().__init__()

    def _get_obs(self) -> Dict[str, np.ndarray]:
        """Extract observations for each field agent."""
        obs = {}
        for mg_id, mg in self.env._system_agent.subordinates.items():
            load_norm = mg.load / 10.0
            for sub_id, sub in mg.subordinates.items():
                if 'BatterySOC' in sub.state.features:
                    soc = sub.state.features['BatterySOC'].soc
                    obs[sub_id] = np.array([soc, load_norm], dtype=np.float32)
                elif 'GenOutput' in sub.state.features:
                    p_norm = sub.state.features['GenOutput'].p_mw / sub.p_max
                    obs[sub_id] = np.array([p_norm, load_norm], dtype=np.float32)
        return obs

    def _clip_action(self, agent_id, action):
        """Clip ``action`` to ``agent_id``'s own Box bounds; unknown IDs pass through."""
        space = self.action_spaces.get(agent_id)
        if space is None:
            return action
        return np.clip(np.asarray(action, dtype=np.float32), space.low, space.high)

    def reset(self, *, seed=None, options=None):
        """Reset the HERON env and reactivate all agents; returns ``(obs, infos)``."""
        self.env.reset(seed=seed)
        self.agents = list(self.agent_ids)
        return self._get_obs(), {aid: {} for aid in self.agents}

    def step(self, actions: Dict[str, np.ndarray]):
        """Apply per-agent actions and return RLlib's five per-agent dicts."""
        # With one shared policy, the policy's action space can be wider than
        # an agent's own (a battery's [-1, 1] vs a generator's [0, 1]), so
        # clip each action to that agent's bounds.
        clipped = {aid: self._clip_action(aid, act) for aid, act in actions.items()}
        self.env.step(clipped)

        field_obs = self._get_obs()

        # Collective objective, split evenly across agents.
        total_cost = self.env._system_agent.cost
        total_safety = self.env._system_agent.safety
        collective_reward = -(total_cost + self.env.penalty * total_safety)
        field_rewards = {
            aid: collective_reward / len(self.agent_ids)
            for aid in self.agent_ids
        }

        done = self.env._step_count >= self.env.max_steps
        field_terminateds = {aid: done for aid in self.agent_ids}
        field_terminateds['__all__'] = done
        field_truncateds = {aid: False for aid in self.agent_ids}
        field_truncateds['__all__'] = False
        field_infos = {aid: {'cost': total_cost, 'safety': total_safety} for aid in self.agent_ids}

        return field_obs, field_rewards, field_terminateds, field_truncateds, field_infos


# Smoke-test the wrapper and show the per-agent spaces it exposes.
wrapped_env = RLlibEnvWrapper({'num_microgrids': 3, 'max_episode_steps': 48})
ids_seen = wrapped_env.agent_ids
print(f"Agent IDs: {wrapped_env.agent_ids}")
obs_spec = wrapped_env.observation_space
print(f"Observation space: {wrapped_env.observation_space}")
act_spec = wrapped_env.action_space
print(f"Action space: {wrapped_env.action_space}")

In [None]:
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env

# Initialize Ray
# ignore_reinit_error=True lets this cell be re-run in the same kernel;
# num_cpus=2 limits the logical CPUs Ray schedules on for this tutorial.
ray.init(ignore_reinit_error=True, num_cpus=2)

# Environment creator function for RLlib
def env_creator(config):
    """Build an RLlibEnvWrapper from the env_config dict RLlib passes in."""
    return RLlibEnvWrapper(config)

# Register the environment
# The name must match the `env=` string given to PPOConfig.environment(...).
register_env("simple_microgrids", env_creator)

print("RLlib initialized and environment registered!")

In [None]:
# Create a throwaway wrapper instance just to read the per-agent spaces.
env_config = {
    'num_microgrids': 3,
    'max_episode_steps': 48,
    'share_reward': True,
    'penalty': 10.0,
}

test_env = env_creator(env_config)
possible_agents = test_env.possible_agents
print(f"Agents: {possible_agents}")

# Parameter sharing: every agent maps to one policy whose spaces are copied
# from the first agent. With the default hierarchy that is a battery, so the
# shared action space is [-1, 1]. Generator agents can then be handed
# actions below 0 unless something downstream clips them to [0, 1].
first_agent = possible_agents[0]
policies = {
    'shared_policy': (
        None,  # None -> use the algorithm's default policy class
        test_env.observation_space[first_agent],
        test_env.action_space[first_agent],
        {},  # no per-policy config overrides
    )
}


def policy_mapping_fn(agent_id, *args, **kwargs):
    """Map every agent to the single shared policy."""
    return 'shared_policy'


print(f"Observation space: {test_env.observation_space[first_agent]}")
print(f"Action space: {test_env.action_space[first_agent]}")
print("Policy mapping: All agents -> 'shared_policy' (MAPPO)")

In [None]:
# Configure PPO on RLlib's old API stack, with one shared policy for all agents.
config = (
    PPOConfig()
    .api_stack(
        enable_rl_module_and_learner=False,
        enable_env_runner_and_connector_v2=False,
    )
    .environment(
        env="simple_microgrids",
        env_config=env_config,
        disable_env_checking=True,
    )
    .framework("torch")
    .training(
        lr=5e-4,
        gamma=0.99,
        lambda_=0.95,
        entropy_coeff=0.01,
        clip_param=0.2,
        train_batch_size=500,
        # `.training(model=...)` merges these keys into RLlib's default model
        # config. Assigning `config.model = {...}` would drop every other default.
        model={
            'fcnet_hiddens': [64, 64],
            'fcnet_activation': 'relu',
        },
    )
    .multi_agent(
        policies=policies,
        policy_mapping_fn=policy_mapping_fn,
    )
    .resources(
        num_gpus=0,
    )
    .env_runners(
        num_env_runners=1,
        num_envs_per_env_runner=1,
    )
)

# Remaining knobs are set as plain attributes, as in the original tutorial.
config.sgd_minibatch_size = 64
config.num_sgd_iter = 5
config.preprocessor_pref = None
config.enable_connectors = False

print("PPO config ready!")

## Step 3: Train the Agent

Now let's train for a few iterations.

In [None]:
def _first_metric(sources, keys, default):
    """Return the first value stored under any of ``keys`` in any of ``sources``.

    RLlib result keys have been renamed across releases. Trying several
    spellings avoids silently reporting ``default`` when only the other name
    is present.
    """
    for key in keys:
        for source in sources:
            if key in source:
                return source[key]
    return default


# Build the algorithm
algo = config.build()

print("Training MAPPO on Simple Microgrids...")
print("="*60)
print(f"{'Iter':>5} | {'Reward':>12} | {'Episodes':>10} | {'Steps':>10}")
print("-"*60)

# Training loop
num_iterations = 10
rewards = []

for i in range(num_iterations):
    result = algo.train()
    # Episode metrics sit under 'env_runners'; lifetime counters may be top-level.
    metric_sources = (result.get('env_runners', {}), result)

    # NaN rather than 0 when no reward metric is available, so the plot shows
    # a gap instead of a fake zero.
    reward_mean = float(_first_metric(
        metric_sources, ('episode_reward_mean', 'episode_return_mean'), float('nan')))
    episodes = int(_first_metric(
        metric_sources, ('episodes_this_iter', 'num_episodes'), 0))
    timesteps = int(_first_metric(
        metric_sources, ('timesteps_total', 'num_env_steps_sampled_lifetime'), 0))

    rewards.append(reward_mean)
    print(f"{i+1:5d} | {reward_mean:12.2f} | {episodes:10d} | {timesteps:10d}")

print("-"*60)
if rewards:
    print(f"Training complete! Final reward: {rewards[-1]:.2f}")
else:
    print("Training complete! (no iterations run)")

In [None]:
# Plot the mean episode reward recorded after each training iteration.
import matplotlib.pyplot as plt

iteration_axis = range(1, len(rewards) + 1)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(iteration_axis, rewards, 'b-o')
ax.set_xlabel('Iteration')
ax.set_ylabel('Mean Episode Reward')
ax.set_title('MAPPO Training on Simple Microgrids')
ax.grid(True, alpha=0.3)
plt.show()

## Step 4: Evaluate the Learned Policy

Let's see how the trained agents behave.

In [None]:
# Create evaluation environment
eval_env = RLlibEnvWrapper(env_config)

print("Evaluating trained policy...")
print("="*60)

obs, info = eval_env.reset()
done = False
total_rewards = {aid: 0.0 for aid in eval_env.agents}
step = 0

while not done:
    # Greedy (explore=False) actions from the shared policy, one call per agent.
    actions = {
        agent_id: algo.compute_single_action(
            agent_obs,
            policy_id='shared_policy',
            explore=False,
        )
        for agent_id, agent_obs in obs.items()
    }

    # Named `step_rewards`, not `rewards`: `rewards` holds the training curve
    # plotted above, and overwriting it would break re-running that cell.
    obs, step_rewards, terminateds, truncateds, infos = eval_env.step(actions)

    for aid in eval_env.agents:
        total_rewards[aid] += step_rewards[aid]

    # Stop on termination OR truncation, so a time-limit truncation cannot
    # loop forever.
    done = terminateds.get('__all__', False) or truncateds.get('__all__', False)
    step += 1

    if step % 10 == 0:
        first = eval_env.agents[0]
        cost = infos[first]['cost']
        safety = infos[first]['safety']
        print(f"Step {step:3d}: Cost={cost:.2f}, Safety={safety:.2f}")

print("="*60)
print(f"Episode complete after {step} steps")
print(f"Total reward per agent: {sum(total_rewards.values())/len(total_rewards):.2f}")

## Step 5: Cleanup

In [None]:
# Release the algorithm's resources (env runners / workers) first, then stop
# the local Ray runtime started by ray.init().
algo.stop()
ray.shutdown()
print("Cleanup complete!")

## Key Takeaways

### HERON Patterns Used

1. **Feature Definition with ClassVar Visibility**
   ```python
   @dataclass(slots=True)
   class MyFeature(FeatureProvider):
       visibility: ClassVar[Sequence[str]] = ['owner', 'upper_level']
       value: float = 0.0
   ```

2. **Agent Hierarchy (Bottom-Up)**
   ```python
   # L1: Field agents build their own features and pass them to FieldAgent.__init__
   battery = SimpleBattery(agent_id="mg_0_bat", upstream_id="mg_0")
   
   # L2: Coordinators build their own subordinates before CoordinatorAgent.__init__
   microgrid = SimpleMicrogrid(agent_id="mg_0", load=3.0)
   
   # L3: System agent receives the coordinators explicitly
   system = SimpleGridSystem(agent_id="grid_system", subordinates={"mg_0": microgrid})
   ```

3. **Environment with system_agent**
   ```python
   env = MyEnv(system_agent=system_agent)
   ```

4. **Factory Function for RLlib**
   ```python
   def create_env(config):
       # Build agents bottom-up
       # Return the HERON environment (RLlibEnvWrapper adapts it for RLlib)
   ```

### Training Considerations

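- **Shared Policy (parameter sharing)**: All agents use the same neural network. Full MAPPO would also train a centralized critic on the global state; here each agent's value function sees only its own observation.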
- **Cooperative Rewards**: Agents optimize collective objective
- **Synchronous Training**: All agents act simultaneously

---

**Next:** [05_event_driven_testing.ipynb](05_event_driven_testing.ipynb) - Validate policy robustness with realistic timing