# Creating Custom Environments for RLlib

**Prerequisites**: Complete [02_rllib_basics](../02_rllib_basics/01_ray_setup.ipynb)

You've trained on CartPole. Now let's train on YOUR problems!

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                    THE GYMNASIUM ENVIRONMENT INTERFACE                      │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Every RL environment follows the same pattern:                             │
│                                                                             │
│     ┌─────────┐                         ┌─────────────────┐                 │
│     │  Agent  │                         │   Environment   │                 │
│     │         │                         │                 │                 │
│     │ (RLlib) │ ────── action ────────> │  (YOUR CODE!)   │                 │
│     │         │                         │                 │                 │
│     │         │ <── state, reward ───── │                 │                 │
│     └─────────┘                         └─────────────────┘                 │
│                                                                             │
│  Your job: Define how the environment responds to actions                   │
│  RLlib's job: Figure out what actions to take                               │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

## The 5 Things Every Environment Needs

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                      GYMNASIUM ENV REQUIREMENTS                             │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  class MyEnv(gym.Env):                                                      │
│                                                                             │
│      1. __init__(self)                                                      │
│         └─> Define observation_space and action_space                       │
│                                                                             │
│      2. observation_space = ...                                             │
│         └─> What does the agent SEE? (state shape and bounds)               │
│                                                                             │
│      3. action_space = ...                                                  │
│         └─> What can the agent DO? (action shape and bounds)                │
│                                                                             │
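│      4. reset(self, *, seed=None, options=None) -> (observation, info)      │
│         └─> Start a new episode, return initial state                       │
│                                                                             │
│      5. step(self, action) -> (observation, reward, terminated,             │
│                                truncated, info)                             │
│         └─> Take action, return next state and reward                       │
│             terminated: episode ended naturally (goal reached, failure)     │
│             truncated:  episode cut off early (e.g., a time limit)          │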
│                                                                             │
│                                                                             │
│  That's it! Just these 5 things and you have an RL environment.             │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

In [None]:
# Suppress warnings
import warnings
import logging
warnings.filterwarnings("ignore")
logging.getLogger("ray").setLevel(logging.ERROR)

import gymnasium as gym
from gymnasium import spaces
import numpy as np
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Dict, Any

---

## Understanding Spaces

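Spaces declare the type, shape, and bounds of what the agent observes and what it can do. RLlib uses them to build the policy network's inputs and outputs.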

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                          GYMNASIUM SPACES                                   │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  spaces.Discrete(n)                                                         │
│  ──────────────────                                                         │
│  Integer from 0 to n-1                                                      │
│                                                                             │
│  Example: Discrete(4) = {0, 1, 2, 3}                                        │
│           (up, right, down, left)                                           │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  spaces.Box(low, high, shape)                                               │
│  ────────────────────────────                                               │
│  Continuous values in a bounded box                                         │
│                                                                             │
│  Example: Box(low=-1.0, high=1.0, shape=(3,))                               │
│           = array of 3 floats, each between -1 and 1                        │
│           like [0.5, -0.3, 0.8]                                             │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  spaces.MultiDiscrete([n1, n2, ...])                                        │
│  ───────────────────────────────────                                        │
│  Multiple discrete values                                                   │
│                                                                             │
│  Example: MultiDiscrete([3, 4]) = two integers, first 0-2, second 0-3       │
│           like [2, 1]                                                       │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  spaces.Dict({...})                                                         │
│  ──────────────────                                                         │
│  Dictionary of spaces (for complex observations)                            │
│                                                                             │
│  Example: Dict({"position": Box(...), "velocity": Box(...)})                │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

In [None]:
# Let's see what each space looks like

print("SPACE EXAMPLES")
print("=" * 50)

# Discrete - single integer action
discrete = spaces.Discrete(4)  # 0, 1, 2, or 3
print(f"\nDiscrete(4): {[discrete.sample() for _ in range(5)]}")
print("  Use for: up/down/left/right, buy/sell/hold")

# Box - continuous values
box = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)
print(f"\nBox(low=-1, high=1, shape=(3,)): {box.sample()}")
print("  Use for: robot joint angles, continuous control")

# MultiDiscrete - multiple discrete values
multi_discrete = spaces.MultiDiscrete([3, 4, 2])
print(f"\nMultiDiscrete([3,4,2]): {multi_discrete.sample()}")
print("  Use for: multiple independent choices")

# Dict - nested spaces
dict_space = spaces.Dict({
    "position": spaces.Box(low=-10, high=10, shape=(2,)),
    "velocity": spaces.Box(low=-1, high=1, shape=(2,)),
    "inventory": spaces.Discrete(5)
})
print(f"\nDict space: {dict_space.sample()}")
print("  Use for: complex observations with multiple types")

---

## Example 1: Simple Trading Environment

Let's build a trading environment from scratch!

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                        TRADING ENVIRONMENT                                  │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  SCENARIO:                                                                  │
│  You're a trader. You can buy, sell, or hold a stock.                       │
│  Goal: Maximize profit over 200 time steps.                                 │
│                                                                             │
│  STATE (what the agent sees):                                               │
│  ┌──────────────────────────────────────────┐                               │
│  │ [position, cash, price, price_change]    │                               │
│  │                                          │                               │
│  │  position: how many shares you own       │                               │
│  │  cash: how much money you have           │                               │
│  │  price: current stock price              │                               │
│  │  price_change: change vs. starting price │                               │
│  └──────────────────────────────────────────┘                               │
│                                                                             │
│  ACTIONS (what the agent can do):                                           │
│  ┌──────────────────────────────────────────┐                               │
│  │  0 = HOLD   (do nothing)                 │                               │
│  │  1 = BUY    (buy as many shares as       │                               │
│  │             possible with current cash)  │                               │
│  │  2 = SELL   (sell all shares)            │                               │
│  └──────────────────────────────────────────┘                               │
│                                                                             │
│  REWARD:                                                                    │
│  Change in portfolio value (cash + shares × price) ÷ initial cash           │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```
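The REWARD definition has a useful property. Each step's reward is the difference between consecutive portfolio values, divided by the initial cash. The rewards therefore telescope: an episode's return equals the total profit as a fraction of starting cash. A small worked example with made-up portfolio values:

```python
import math


def step_rewards(portfolio_values, initial_cash):
    """Per-step rewards in the style of SimpleTradingEnv: (value_t - value_{t-1}) / initial_cash."""
    return [
        (curr - prev) / initial_cash
        for prev, curr in zip(portfolio_values, portfolio_values[1:])
    ]


# Hypothetical portfolio values: at reset, then after steps 1, 2 and 3
values = [10_000.0, 10_150.0, 9_900.0, 10_400.0]
rewards = step_rewards(values, initial_cash=10_000.0)
episode_return = sum(rewards)

print("rewards:", rewards)
print("episode return:", episode_return)
# The sum telescopes to (final - initial) / initial_cash = 400 / 10_000
print("equals total profit fraction:",
      math.isclose(episode_return, (values[-1] - values[0]) / 10_000.0))
```

An episode return of 0.04 therefore means the portfolio ended 4% above the starting cash.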

In [None]:
class SimpleTradingEnv(gym.Env):
    """
    Simple trading environment.
    
    The agent learns to trade a single stock to maximize profit.
    """
    
    metadata = {"render_modes": ["human"]}
    
    def __init__(self, config: Optional[Dict] = None):
        super().__init__()
        config = config or {}
        
        # Environment parameters
        self.initial_cash = config.get("initial_cash", 10000)
        self.max_steps = config.get("max_steps", 200)
        self.volatility = config.get("volatility", 0.02)
        
        # ============================================================
        # REQUIREMENTS 1-3: __init__ defines action and observation spaces
        # ============================================================
        
        # Action space: 3 discrete choices
        self.action_space = spaces.Discrete(3)  # 0=hold, 1=buy, 2=sell
        
        # Observation space: [position, cash, price, change]
        #   position: shares owned (>= 0, no short selling)
        #   cash, price: divided by their starting values
        #   change: fractional price change since reset (> -1 due to the $1 floor)
        # Upper bounds are infinite: a random-walk price (and therefore the
        # number of shares you can afford) has no fixed ceiling.
        self.observation_space = spaces.Box(
            low=np.array([0.0, 0.0, 0.0, -1.0], dtype=np.float32),
            high=np.full(4, np.inf, dtype=np.float32),
            dtype=np.float32
        )
        
        self.reset()
    
    # ================================================================
    # REQUIREMENT 4: reset() method
    # ================================================================
    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        """Reset the environment for a new episode."""
        super().reset(seed=seed)
        
        # Reset state
        self.cash = self.initial_cash
        self.position = 0  # Number of shares owned
        self.price = 100.0
        self.initial_price = self.price
        self.step_count = 0
        self.prev_portfolio_value = self.cash
        
        return self._get_obs(), {}  # Return (observation, info)
    
    def _get_obs(self) -> np.ndarray:
        """Convert internal state to observation."""
        price_change = (self.price - self.initial_price) / self.initial_price
        return np.array([
            self.position,
            self.cash / self.initial_cash,      # Normalized cash
            self.price / self.initial_price,    # Normalized price
            price_change
        ], dtype=np.float32)
    
    def _get_portfolio_value(self) -> float:
        """Total value = cash + shares × price."""
        return self.cash + self.position * self.price
    
    # ================================================================
    # REQUIREMENT 5: step() method
    # ================================================================
    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """Execute one step in the environment."""
        
        # Execute action
        if action == 1:  # BUY
            if self.cash >= self.price:
                shares_to_buy = int(self.cash // self.price)
                self.position += shares_to_buy
                self.cash -= shares_to_buy * self.price
        elif action == 2:  # SELL
            if self.position > 0:
                self.cash += self.position * self.price
                self.position = 0
        # action == 0 means HOLD (do nothing)
        
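        # Simulate price movement (random walk). self.np_random is seeded by
        # reset(seed=...), so seeded episodes are reproducible (np.random is not)
        price_change = self.np_random.normal(0.0001, self.volatility)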
        self.price *= (1 + price_change)
        self.price = max(self.price, 1.0)  # Price floor
        
        # Calculate reward (change in portfolio value)
        current_value = self._get_portfolio_value()
        reward = (current_value - self.prev_portfolio_value) / self.initial_cash
        self.prev_portfolio_value = current_value
        
        # Check termination
        self.step_count += 1
        terminated = False  # No early termination
        truncated = self.step_count >= self.max_steps  # Episode ends after max_steps
        
        info = {
            "portfolio_value": current_value,
            "position": self.position,
            "price": self.price
        }
        
        return self._get_obs(), reward, terminated, truncated, info
    
    def render(self):
        """Print current state."""
        print(f"Step {self.step_count}: Price=${self.price:.2f}, "
              f"Position={self.position}, Cash=${self.cash:.2f}, "
              f"Value=${self._get_portfolio_value():.2f}")

In [None]:
# Test the environment
print("TESTING SimpleTradingEnv")
print("=" * 50)

env = SimpleTradingEnv()
obs, _ = env.reset()

print(f"\nObservation space: {env.observation_space}")
print(f"Action space: {env.action_space}")
print(f"\nInitial observation: {obs}")
print("  [position, cash_norm, price_norm, price_change]")

# Run a few steps
action_names = ["HOLD", "BUY", "SELL"]
print("\nRunning 10 random steps:")
total_reward = 0
for i in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    total_reward += reward
    print(f"  Action: {action_names[action]:>4}, Reward: {reward:>+.4f}, Value: ${info['portfolio_value']:.2f}")

print(f"\nTotal reward: {total_reward:.4f}")

---

## Registering with RLlib

To use your environment with RLlib, you need to register it.

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                    ENVIRONMENT REGISTRATION                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  Step 1: Create a "creator" function                                        │
│  ───────────────────────────────────                                        │
│                                                                             │
│      def env_creator(env_config):                                           │
│          return MyEnv(env_config)                                           │
│                                                                             │
│  Step 2: Register with a name                                               │
│  ──────────────────────────────                                             │
│                                                                             │
│      register_env("MyEnv-v0", env_creator)                                  │
│                                                                             │
│  Step 3: Use in config                                                      │
│  ─────────────────────                                                      │
│                                                                             │
│      config = PPOConfig().environment("MyEnv-v0", env_config={...})         │
│                                                                             │
│                                                                             │
│  WHY A CREATOR FUNCTION?                                                    │
│  RLlib builds copies of your environment inside its EnvRunner (worker)      │
│  processes, possibly several per EnvRunner. It needs a recipe, not an       │
│  object: the creator function is called for each copy with env_config.      │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```

In [None]:
# Initialize Ray
ray.init(
    num_cpus=4,
    object_store_memory=1 * 1024 * 1024 * 1024,
    ignore_reinit_error=True,
)

# Register the environment
def trading_env_creator(env_config):
    return SimpleTradingEnv(env_config)

register_env("SimpleTradingEnv", trading_env_creator)
print("Environment registered as 'SimpleTradingEnv'")

In [None]:
# Train on the custom environment
config = (
    PPOConfig()
    .environment(
        env="SimpleTradingEnv",  # Use registered name
        env_config={             # Passed to env_creator
            "initial_cash": 10000,
            "max_steps": 200,
            "volatility": 0.02,
        }
    )
    .framework("torch")
    .env_runners(num_env_runners=2)
    .training(
        lr=3e-4,
        train_batch_size=2000,
    )
)

algo = config.build_algo()

print("Training on SimpleTradingEnv...")
print("=" * 50)

for i in range(10):
    result = algo.train()
    # Episode return = sum of per-step rewards = profit as a fraction of initial cash
    mean_return = result["env_runners"]["episode_return_mean"]
    print(f"Iter {i+1:>2}: Mean episode return = {mean_return:.3f}")

algo.stop()
print("\nTraining complete!")

---

## Example 2: Resource Management Environment

A more complex example with **continuous actions**.

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                    RESOURCE MANAGEMENT ENVIRONMENT                          │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  SCENARIO:                                                                  │
│  You manage servers for a website. Traffic varies throughout the day.       │
│  Goal: Minimize costs while keeping users happy.                            │
│                                                                             │
│      Traffic Pattern (24 hours):                                            │
│                                                                             │
│      Load │         _____                                                   │
│       0.7 │        /     \        Peak (~0.7) at noon                       │
│           │       /       \                                                 │
│       0.5 │      /         \                                                │
│           │_____/           \_____                                          │
│       0.0 └───────────────────────                                          │
│           0:00  6:00  12:00  18:00  24:00                                   │
│                                                                             │
│  STATE: [current_load, num_servers, queue_length, time_of_day]              │
│                                                                             │
│  ACTION: Continuous value from -1 to +1                                     │
│          -1 = remove 2 servers                                              │
│          +1 = add 2 servers                                                 │
│                                                                             │
│  REWARD:                                                                    │
│          - Server cost (more servers = higher cost)                         │
│          - Queue penalty (long queue = unhappy users)                       │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```
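The ACTION mapping deserves a closer look, because rounding decides which continuous values lead to which server change. Below is a minimal sketch of the same scaling that `step()` applies (`round(action * 2)`). It adds a clip to [-1, 1] as an extra safeguard, an assumption here, in case a policy emits values outside the Box. `action_to_server_change` is a hypothetical helper name.

```python
import numpy as np


def action_to_server_change(action, max_change=2):
    """Map a 1-D continuous action to an integer server delta in [-max_change, max_change].

    The clip to [-1, 1] is an added safeguard against out-of-range policy outputs.
    """
    a = float(np.clip(np.asarray(action, dtype=np.float64).reshape(-1)[0], -1.0, 1.0))
    return int(np.round(a * max_change))


for a in [-1.0, -0.6, -0.2, 0.0, 0.3, 0.8, 1.0, 1.7]:
    print(f"action {a:+.1f} -> server change {action_to_server_change([a]):+d}")
```

`np.round` rounds halves to the nearest even integer, so 0.25 maps to 0 and 0.75 maps to +2. As a result, the extreme changes of ±2 occupy only the outer quarters [0.75, 1] and [-1, -0.75] of the action range. The other outcomes each get a width of 0.5.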

In [None]:
class ResourceManagementEnv(gym.Env):
    """Environment for learning to manage compute resources."""
    
    def __init__(self, config: Optional[Dict] = None):
        super().__init__()
        config = config or {}
        
        self.max_servers = config.get("max_servers", 10)
        self.server_cost = config.get("server_cost", 1.0)
        self.queue_penalty = config.get("queue_penalty", 0.5)
        self.max_steps = config.get("max_steps", 288)  # 24 hours × 12 (5-min intervals)
        
        # CONTINUOUS action space!
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(1,), dtype=np.float32
        )
        
        # Observation space: [load, servers, queue, time_of_day], each scaled to [0, 1]
        self.observation_space = spaces.Box(
            low=np.zeros(4, dtype=np.float32),
            high=np.ones(4, dtype=np.float32),
            dtype=np.float32
        )
        
        self.reset()
    
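    def _generate_load_pattern(self) -> np.ndarray:
        """Generate a daily load pattern: sinusoid plus Gaussian noise."""
        t = np.linspace(0, 2 * np.pi, self.max_steps)
        # Noise-free load: -0.1 at midnight (clipped to 0.1 below), 0.7 at noon
        base_load = 0.3 + 0.4 * np.sin(t - np.pi/2)
        # self.np_random is seeded by reset(seed=...), so patterns are reproducible
        noise = self.np_random.normal(0, 0.05, self.max_steps)
        return np.clip(base_load + noise, 0.1, 1.0)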
    
    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        super().reset(seed=seed)
        self.load_pattern = self._generate_load_pattern()
        self.num_servers = 3
        self.queue_length = 0
        self.step_count = 0
        return self._get_obs(), {}
    
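    def _get_obs(self) -> np.ndarray:
        # After the final step, step_count == max_steps (one past the last index
        # of load_pattern), so clamp the index to avoid an IndexError
        idx = min(self.step_count, self.max_steps - 1)
        current_load = self.load_pattern[idx]
        time_of_day = self.step_count / self.max_steps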
        return np.array([
            current_load,
            self.num_servers / self.max_servers,
            min(self.queue_length / 100, 1.0),
            time_of_day
        ], dtype=np.float32)
    
    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
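        # Continuous action: clip to [-1, 1] (a safeguard in case the policy
        # emits out-of-range values), then scale to a server change of -2..+2
        server_change = int(np.round(np.clip(action[0], -1.0, 1.0) * 2))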
        self.num_servers = np.clip(
            self.num_servers + server_change, 1, self.max_servers
        )
        
        # Calculate capacity vs load
        current_load = self.load_pattern[self.step_count]
        capacity = self.num_servers / self.max_servers
        
        # Update queue
        if current_load > capacity:
            self.queue_length += (current_load - capacity) * 50
        else:
            self.queue_length = max(0, self.queue_length - 10)
        
        # Calculate reward (negative = cost)
        server_cost = self.num_servers * self.server_cost
        queue_cost = self.queue_length * self.queue_penalty
        reward = -(server_cost + queue_cost) / 100  # Normalize
        
        self.step_count += 1
        terminated = False
        truncated = self.step_count >= self.max_steps
        
        info = {"servers": self.num_servers, "queue": self.queue_length, "load": current_load}
        return self._get_obs(), reward, terminated, truncated, info

In [None]:
# Test and visualize
env = ResourceManagementEnv()
obs, _ = env.reset()

print("ResourceManagementEnv")
print("=" * 50)
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}  <- CONTINUOUS!")

# Visualize load pattern
plt.figure(figsize=(10, 4))
plt.plot(env.load_pattern)
plt.xlabel('Time Step (5-min intervals)')
plt.ylabel('Load')
plt.title('Daily Traffic Pattern')
plt.grid(True, alpha=0.3)
plt.show()
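
The plot includes noise, so the exact shape of the pattern is easier to read from the noise-free part of `_generate_load_pattern`. The sketch below recomputes only that deterministic part, assuming the default `max_steps=288`. The noise-free load peaks at about 0.7 near noon, not 1.0. It also sits at the 0.1 floor for roughly a third of the day (about 8 p.m. to 4 a.m.).

Capacity is `num_servers / max_servers`, so with 10 servers a reasonable policy roughly tracks the load. About 7 servers cover the noise-free noon peak, and 1 server covers the overnight floor.

```python
import numpy as np

max_steps = 288  # 24 hours of 5-minute steps, as in ResourceManagementEnv
t = np.linspace(0, 2 * np.pi, max_steps)
base_load = 0.3 + 0.4 * np.sin(t - np.pi / 2)  # deterministic part only (no noise)
clipped = np.clip(base_load, 0.1, 1.0)

peak_idx = int(np.argmax(clipped))
peak_hour = peak_idx * 24 / max_steps
floor_fraction = float(np.mean(base_load < 0.1))  # steps pinned at the 0.1 floor

print(f"peak load {clipped[peak_idx]:.3f} at step {peak_idx} (~{peak_hour:.1f} h)")
print(f"minimum load {clipped.min():.2f}")
print(f"fraction of the day at the 0.1 floor: {floor_fraction:.2f}")
```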

---

## Tips for Building Good Environments

```
┌─────────────────────────────────────────────────────────────────────────────┐
│                     ENVIRONMENT DESIGN TIPS                                 │
├─────────────────────────────────────────────────────────────────────────────┤
│                                                                             │
│  1. NORMALIZE OBSERVATIONS                                                  │
│     ─────────────────────────                                               │
│     Keep values roughly between -1 and 1 (or 0 and 1)                       │
│     Neural networks train faster with normalized inputs                     │
│                                                                             │
│     Bad:  obs = [100000, 0.001, 5000]  (wildly different scales)            │
│     Good: obs = [1.0, 0.1, 0.5]        (all similar scale)                  │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  2. SHAPE REWARDS CAREFULLY                                                 │
│     ───────────────────────────                                             │
│     Make rewards dense (give feedback often)                                │
│     Normalize to prevent huge values                                        │
│                                                                             │
│     Bad:  reward = 0 for 999 steps, then +1000000 at end                    │
│     Good: reward = small bonus every step you're doing well                 │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  3. INCLUDE ALL RELEVANT STATE                                              │
│     ──────────────────────────────                                          │
│     The agent only knows what you tell it!                                  │
│     If something affects the reward, include it in the observation          │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  4. TEST MANUALLY FIRST                                                     │
│     ──────────────────────                                                  │
│     Play your environment yourself!                                         │
│     If no sensible strategy works, the agent will likely struggle too       │
│                                                                             │
│  ───────────────────────────────────────────────────────────────────────    │
│                                                                             │
│  5. USE THE RIGHT SPACE TYPE                                                │
│     ────────────────────────────                                            │
│     Discrete: finite choices (up/down, buy/sell)                            │
│     Box: continuous values (torque, speed)                                  │
│     MultiDiscrete: multiple independent choices                             │
│                                                                             │
└─────────────────────────────────────────────────────────────────────────────┘
```
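Tip 1 can be applied in two common ways:

- For features with known, finite bounds, rescale linearly to [-1, 1].
- For features without fixed bounds, such as a random-walk price, keep a running mean and standard deviation and standardize with them.

Below is a minimal NumPy sketch of both. `scale_to_unit_range` and `RunningMeanStd` are hypothetical names. The running statistics use Welford's algorithm. Recent Gymnasium versions also ship ready-made `gymnasium.wrappers.NormalizeObservation` and `NormalizeReward` wrappers built on the same running-statistics idea.

```python
import numpy as np


def scale_to_unit_range(obs, low, high):
    """Linearly map values from [low, high] to [-1, 1]; needs finite, known bounds."""
    obs, low, high = (np.asarray(x, dtype=np.float64) for x in (obs, low, high))
    return 2.0 * (obs - low) / (high - low) - 1.0


class RunningMeanStd:
    """Running mean/std (Welford's algorithm) for features without known bounds."""

    def __init__(self, shape):
        self.count = 0
        self.mean = np.zeros(shape, dtype=np.float64)
        self.m2 = np.zeros(shape, dtype=np.float64)  # sum of squared deviations

    def update(self, x):
        x = np.asarray(x, dtype=np.float64)
        self.count += 1
        delta = x - self.mean
        self.mean += delta / self.count
        self.m2 += delta * (x - self.mean)  # uses the updated mean

    @property
    def std(self):
        if self.count < 2:
            return np.ones_like(self.mean)  # too little data: leave the scale unchanged
        return np.sqrt(self.m2 / self.count)  # population std (ddof=0)

    def normalize(self, x, eps=1e-8):
        return (np.asarray(x, dtype=np.float64) - self.mean) / (self.std + eps)


# Bounded features: the "Bad" example from tip 1 with hypothetical bounds
scaled = scale_to_unit_range([100000, 0.001, 5000],
                             low=[0, 0, 0], high=[200000, 0.002, 10000])
print("scaled:", scaled)

# Unbounded features: learn the scale from the data as it arrives
rng = np.random.default_rng(0)
data = rng.normal(loc=[5.0, -3.0], scale=[2.0, 0.5], size=(1000, 2))
rms = RunningMeanStd(shape=(2,))
for row in data:
    rms.update(row)
normalized = rms.normalize(data)
print("mean after:", normalized.mean(axis=0).round(3),
      "std after:", normalized.std(axis=0).round(3))
```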

## Key Takeaways

1. **5 requirements**: `__init__`, `observation_space`, `action_space`, `reset()`, `step()`

2. **Choose the right space**: Discrete for choices, Box for continuous

3. **Register with RLlib** using `register_env()`

4. **Normalize** observations and rewards for better training

## What's Next?

```
┌──────────────────┐          ┌──────────────────┐          ┌──────────────────┐
│ 03 Custom Envs   │   ───>   │ 05 Distributed   │   ───>   │ 06 Tune          │
│ (you are here)   │          │                  │          │                  │
│                  │          │ Scale training   │          │ Find best        │
│ - Gymnasium API  │          │ to many workers  │          │ hyperparameters  │
│ - Spaces         │          │                  │          │                  │
└──────────────────┘          └──────────────────┘          └──────────────────┘
```

In [None]:
ray.shutdown()