# 3.1 Creating Custom Gymnasium Environments

## Learning Objectives
- Understand the Gymnasium environment interface
- Create custom environments from scratch
- Register environments with RLlib
- Handle different observation and action spaces

In [None]:
import gymnasium as gym
from gymnasium import spaces
import numpy as np
import ray
from ray.rllib.algorithms.ppo import PPOConfig
from ray.tune.registry import register_env
import matplotlib.pyplot as plt
from typing import Optional, Tuple, Dict, Any

## The Gymnasium Interface

Every Gymnasium environment must define an `observation_space` and an `action_space`, and implement `reset()` and `step()`:

```python
class MyEnv(gym.Env):
    """Skeleton showing the attributes and methods a custom environment provides."""

    def __init__(self, config=None):
        # `config` is an optional dict of settings; the examples in this
        # notebook read their parameters from it with `config.get(...)`.
        self.observation_space = ...  # Define observation space
        self.action_space = ...       # Define action space

    def reset(self, *, seed=None, options=None):
        # Start a new episode. The examples in this notebook call
        # `super().reset(seed=seed)` first, then rebuild their episode state.
        # Return (observation, info)
        return obs, info

    def step(self, action):
        # Apply `action` and advance the simulation by one step.
        # `terminated` and `truncated` are returned as two separate flags.
        # Return (observation, reward, terminated, truncated, info)
        return obs, reward, terminated, truncated, info
```

## Example 1: Simple Trading Environment

In [None]:
class SimpleTradingEnv(gym.Env):
    """
    Simple single-asset trading environment.

    State: [position, cash / initial_cash, price / initial_price, price_change]
        where price_change is the relative move since the episode started.
        The vector is float32 and clipped to ``observation_space``.
    Actions: 0=hold, 1=buy as many whole shares as cash allows, 2=sell all

    Config keys (all optional):
        "initial_cash" (default 10000), "max_steps" (default 200),
        "volatility" (default 0.02, std-dev of the per-step price move).

    Raises:
        ValueError: if "initial_cash" is not positive.
    """

    metadata = {"render_modes": ["human"]}

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()
        config = config or {}

        self.initial_cash = config.get("initial_cash", 10000)
        self.max_steps = config.get("max_steps", 200)
        self.volatility = config.get("volatility", 0.02)

        # Cash is used as a divisor for observations and rewards.
        if self.initial_cash <= 0:
            raise ValueError("initial_cash must be positive")

        # Action space: hold, buy, sell
        self.action_space = spaces.Discrete(3)

        # Observation space: [position, cash_normalized, price_normalized, price_change]
        # Bounds are built as float32 so they match the space dtype exactly.
        self.observation_space = spaces.Box(
            low=np.array([-100, 0, 0, -1], dtype=np.float32),
            high=np.array([100, 2, 2, 1], dtype=np.float32),
            dtype=np.float32,
        )

        # Populate the episode state so the env is usable right after construction.
        self.reset()

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        """Start a new episode and return ``(observation, info)``."""
        # Seeds self.np_random, which step() uses for the price walk.
        super().reset(seed=seed)

        self.cash = self.initial_cash
        self.position = 0  # Number of shares
        self.price = 100.0
        self.initial_price = self.price
        self.step_count = 0
        self.prev_portfolio_value = self.cash

        return self._get_obs(), {}

    def _get_obs(self) -> np.ndarray:
        """Return the current observation, clipped to the declared bounds."""
        price_change = (self.price - self.initial_price) / self.initial_price
        raw_obs = np.array([
            self.position,
            self.cash / self.initial_cash,
            self.price / self.initial_price,
            price_change,
        ], dtype=np.float32)
        # Position, cash and price are unbounded in the simulation (e.g. more
        # than 100 shares become affordable after a price drop), so clip to
        # keep every observation inside observation_space.
        return np.clip(raw_obs, self.observation_space.low, self.observation_space.high)

    def _get_portfolio_value(self) -> float:
        """Calculate total portfolio value (cash plus shares at current price)."""
        return self.cash + self.position * self.price

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """
        Execute one trade decision, then move the price.

        Returns ``(observation, reward, terminated, truncated, info)`` where the
        reward is the change in portfolio value divided by ``initial_cash``.
        The episode never terminates; it is truncated after ``max_steps`` steps.
        """
        # Execute action
        if action == 1:  # Buy
            if self.cash >= self.price:
                shares_to_buy = int(self.cash // self.price)
                self.position += shares_to_buy
                self.cash -= shares_to_buy * self.price
        elif action == 2:  # Sell
            if self.position > 0:
                self.cash += self.position * self.price
                self.position = 0

        # Simulate price movement (random walk with drift). Drawing from
        # self.np_random makes episodes reproducible via reset(seed=...).
        pct_move = self.np_random.normal(0.0001, self.volatility)
        self.price = max(float(self.price * (1 + pct_move)), 1.0)  # Price floor

        # Calculate reward (change in portfolio value)
        current_value = self._get_portfolio_value()
        reward = float((current_value - self.prev_portfolio_value) / self.initial_cash)
        self.prev_portfolio_value = current_value

        self.step_count += 1
        terminated = False
        truncated = self.step_count >= self.max_steps

        info = {
            "portfolio_value": current_value,
            "position": self.position,
            "price": self.price
        }

        return self._get_obs(), reward, terminated, truncated, info

    def render(self):
        """Print a one-line summary of the current step to stdout."""
        print(f"Step {self.step_count}: Price=${self.price:.2f}, "
              f"Position={self.position}, Cash=${self.cash:.2f}, "
              f"Value=${self._get_portfolio_value():.2f}")

In [None]:
# Smoke test: build the trading env and show what it exposes
env = SimpleTradingEnv()
obs, _ = env.reset()
print(f"Initial observation: {obs}")
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")

# Drive the env with 10 random actions, collecting the per-step rewards
step_rewards = []
for i in range(10):
    action = env.action_space.sample()
    obs, reward, terminated, truncated, info = env.step(action)
    step_rewards.append(reward)
    env.render()

total_reward = sum(step_rewards)
print(f"\nTotal reward: {total_reward:.4f}")

## Example 2: Resource Management Environment

In [None]:
class ResourceManagementEnv(gym.Env):
    """
    Environment for learning to manage compute resources over one simulated day.

    State: [current_load, num_servers / max_servers,
            queue_length / 100 (capped at 1), time_of_day], each in [0, 1]
    Action: one float in [-1, 1]; it is multiplied by 2 and rounded, giving a
        change of -2..+2 servers for the step.

    Config keys (all optional):
        "max_servers" (default 10), "server_cost" (default 1.0),
        "queue_penalty" (default 0.5), "max_steps" (default 288).

    Raises:
        ValueError: if "max_servers" or "max_steps" is smaller than 1.
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()
        config = config or {}

        self.max_servers = config.get("max_servers", 10)
        self.server_cost = config.get("server_cost", 1.0)
        self.queue_penalty = config.get("queue_penalty", 0.5)
        self.max_steps = config.get("max_steps", 288)  # 24 hours * 12 (5-min intervals)

        if self.max_servers < 1:
            raise ValueError("max_servers must be at least 1")
        if self.max_steps < 1:
            raise ValueError("max_steps must be at least 1")

        # Continuous action space for scaling
        self.action_space = spaces.Box(
            low=-1.0, high=1.0, shape=(1,), dtype=np.float32
        )

        # Observation space: four features, each normalized to [0, 1]
        self.observation_space = spaces.Box(
            low=np.zeros(4, dtype=np.float32),
            high=np.ones(4, dtype=np.float32),
            dtype=np.float32,
        )

        # Populate the episode state so the env is usable right after construction.
        self.reset()

    def _generate_load_pattern(self) -> np.ndarray:
        """Generate a noisy daily load pattern with one value per step."""
        # Sinusoidal pattern with peak at noon
        t = np.linspace(0, 2 * np.pi, self.max_steps)
        base_load = 0.3 + 0.4 * np.sin(t - np.pi/2)  # Peak at t=π
        # Use the env's own generator so reset(seed=...) reproduces the pattern.
        noise = self.np_random.normal(0, 0.05, self.max_steps)
        return np.clip(base_load + noise, 0.1, 1.0)

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        """Start a new day and return ``(observation, info)``."""
        super().reset(seed=seed)

        self.load_pattern = self._generate_load_pattern()
        # Start with 3 servers, but never more than the configured maximum.
        self.num_servers = min(3, self.max_servers)
        self.queue_length = 0
        self.step_count = 0

        return self._get_obs(), {}

    def _current_load(self) -> float:
        """Return the load for the current step.

        After the last step ``step_count`` equals ``max_steps``, which is one
        past the end of ``load_pattern``; the index is clamped so the final
        observation can still be built.
        """
        idx = min(self.step_count, self.max_steps - 1)
        return float(self.load_pattern[idx])

    def _get_obs(self) -> np.ndarray:
        """Build the normalized float32 observation vector."""
        time_of_day = self.step_count / self.max_steps

        return np.array([
            self._current_load(),
            self.num_servers / self.max_servers,
            min(self.queue_length / 100, 1.0),
            time_of_day
        ], dtype=np.float32)

    def step(self, action: np.ndarray) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """
        Apply a scaling action and advance one interval.

        Returns ``(observation, reward, terminated, truncated, info)``. The
        reward is the negative sum of server cost and queue cost, divided by
        100. The episode never terminates; it is truncated after ``max_steps``.
        """
        # Scale action to server change. The value is clipped to the declared
        # action range so an out-of-bounds input cannot move more than 2 servers.
        scaled = float(np.clip(np.asarray(action).reshape(-1)[0], -1.0, 1.0))
        server_change = int(np.round(scaled * 2))  # -2 to +2
        self.num_servers = int(np.clip(
            self.num_servers + server_change, 1, self.max_servers
        ))

        # Calculate capacity and queue
        current_load = self._current_load()
        capacity = self.num_servers / self.max_servers

        # Update queue based on load vs capacity
        if current_load > capacity:
            self.queue_length += (current_load - capacity) * 50
        else:
            self.queue_length = max(0, self.queue_length - 10)

        # Calculate reward
        server_cost = self.num_servers * self.server_cost
        queue_cost = self.queue_length * self.queue_penalty
        reward = -server_cost - queue_cost

        # Normalize reward
        reward = float(reward / 100)

        self.step_count += 1
        terminated = False
        truncated = self.step_count >= self.max_steps

        info = {
            "servers": self.num_servers,
            "queue": self.queue_length,
            "load": current_load
        }

        return self._get_obs(), reward, terminated, truncated, info

In [None]:
# Build the resource management env and show its spaces
env = ResourceManagementEnv()
obs, _ = env.reset()
print(f"Observation space: {env.observation_space}")
print(f"Action space: {env.action_space}")

# Plot the load curve generated for this episode
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(env.load_pattern)
ax.set_xlabel('Time Step (5-min intervals)')
ax.set_ylabel('Load')
ax.set_title('Daily Load Pattern')
ax.grid(True, alpha=0.3)
plt.show()

## Registering with RLlib

In [None]:
# Start Ray; ignore_reinit_error lets this cell be re-run in the same session
ray.init(ignore_reinit_error=True)


# Creator callables: each receives the env_config dict and returns a new env
def trading_env_creator(env_config):
    """Build a SimpleTradingEnv from the given env_config."""
    trading_env = SimpleTradingEnv(env_config)
    return trading_env


def resource_env_creator(env_config):
    """Build a ResourceManagementEnv from the given env_config."""
    resource_env = ResourceManagementEnv(env_config)
    return resource_env


# Register each creator under the name used later in PPOConfig.environment()
for env_name, creator in (
    ("SimpleTradingEnv", trading_env_creator),
    ("ResourceManagementEnv", resource_env_creator),
):
    register_env(env_name, creator)

print("Environments registered!")

In [None]:
# Train on custom trading environment
trading_config = (
    PPOConfig()
    .environment(
        env="SimpleTradingEnv",
        env_config={
            "initial_cash": 10000,
            "max_steps": 200,
            "volatility": 0.02,
        }
    )
    .framework("torch")
    .env_runners(num_env_runners=2)
    .training(
        lr=3e-4,
        train_batch_size=2000,
    )
)

algo = trading_config.build()

print("Training on SimpleTradingEnv...")
try:
    for i in range(10):
        result = algo.train()
        # The mean-return metric is named differently across RLlib versions:
        # "episode_return_mean" on the new API stack, "episode_reward_mean" on
        # the old one. Try both, and fall back to NaN if no episode has
        # finished yet, instead of failing with a KeyError.
        runner_metrics = result.get("env_runners", result)
        mean_return = runner_metrics.get(
            "episode_return_mean",
            runner_metrics.get("episode_reward_mean", float("nan")),
        )
        print(f"Iter {i+1}: Reward = {mean_return:.2f}")
finally:
    # Always release the algorithm's workers, even if training fails midway.
    algo.stop()

## Different Space Types

In [None]:
# Tour of the common Gymnasium space types: build them all first, then
# draw and print one random sample from each.

# Discrete - one integer out of n choices (here 0, 1, 2, 3, 4)
discrete_space = spaces.Discrete(5)

# Box - bounded continuous values with a fixed shape
box_space = spaces.Box(low=-1.0, high=1.0, shape=(3,), dtype=np.float32)

# MultiDiscrete - several discrete choices, each with its own range
multi_discrete = spaces.MultiDiscrete([3, 4, 2])

# MultiBinary - a vector of binary flags
multi_binary = spaces.MultiBinary(4)

# Dict - named sub-spaces
position_space = spaces.Box(low=-10, high=10, shape=(2,))
velocity_space = spaces.Box(low=-1, high=1, shape=(2,))
dict_space = spaces.Dict({
    "position": position_space,
    "velocity": velocity_space,
    "target": spaces.Discrete(4)
})

# Tuple - ordered collection of sub-spaces
tuple_space = spaces.Tuple((
    spaces.Discrete(3),
    spaces.Box(low=0, high=1, shape=(2,))
))

print(f"Discrete: {discrete_space.sample()}")
print(f"Box: {box_space.sample()}")
print(f"MultiDiscrete: {multi_discrete.sample()}")
print(f"MultiBinary: {multi_binary.sample()}")
print(f"Dict: {dict_space.sample()}")
print(f"Tuple: {tuple_space.sample()}")

## Example 3: Image-Based Environment

In [None]:
class SimpleVisualEnv(gym.Env):
    """
    Simple environment with image observations.
    Agent must find target in a grid (represented as image).

    Observation: RGB image of shape (image_size, image_size, 3), dtype uint8.
        The agent is drawn blue and the target green.
    Actions: 0=up, 1=right, 2=down, 3=left (moves are clipped at the border).
    Reward: +10 on reaching the target (episode terminates), -0.1 otherwise.

    Config keys (all optional):
        "grid_size" (default 8), "image_size" (default 64),
        "max_steps" (default 100, step count after which the episode is truncated).

    Raises:
        ValueError: if "grid_size" is below 2 or the cells are too small to draw.
    """

    def __init__(self, config: Optional[Dict] = None):
        super().__init__()
        config = config or {}

        self.grid_size = config.get("grid_size", 8)
        self.image_size = config.get("image_size", 64)
        self.max_steps = config.get("max_steps", 100)

        # With a single cell the target could never differ from the agent's
        # start cell, and reset() would resample forever.
        if self.grid_size < 2:
            raise ValueError("grid_size must be at least 2")
        # Markers are inset by 2 pixels on each side, so a cell narrower than
        # 5 pixels would leave agent and target invisible in the image.
        if self.image_size // self.grid_size < 5:
            raise ValueError("image_size must be at least 5 * grid_size")

        # Image observation (RGB)
        self.observation_space = spaces.Box(
            low=0, high=255,
            shape=(self.image_size, self.image_size, 3),
            dtype=np.uint8
        )

        # Movement actions: up, right, down, left
        self.action_space = spaces.Discrete(4)

        # Populate the episode state so the env is usable right after construction.
        self.reset()

    def _sample_cell(self) -> np.ndarray:
        """Draw a random (row, col) cell from the env's seeded generator."""
        return self.np_random.integers(0, self.grid_size, size=2)

    def reset(self, *, seed: Optional[int] = None, options: Optional[Dict] = None):
        """Start a new episode and return ``(image_observation, info)``."""
        # Seeds self.np_random so the target position is reproducible.
        super().reset(seed=seed)

        # The agent always starts in the top-left cell; only the target is random.
        self.agent_pos = np.array([0, 0])
        self.target_pos = self._sample_cell()

        # Ensure they're not the same
        while np.array_equal(self.agent_pos, self.target_pos):
            self.target_pos = self._sample_cell()

        self.steps = 0
        return self._render_image(), {}

    def _render_image(self) -> np.ndarray:
        """Render the grid as an RGB image."""
        cell_size = self.image_size // self.grid_size
        img = np.full((self.image_size, self.image_size, 3), 255, dtype=np.uint8)

        # Draw grid lines (light grey, one pixel wide)
        for i in range(self.grid_size + 1):
            pos = i * cell_size
            img[pos:pos+1, :] = 200
            img[:, pos:pos+1] = 200

        # Draw target (green)
        ty, tx = self.target_pos * cell_size
        img[ty+2:ty+cell_size-2, tx+2:tx+cell_size-2] = [0, 255, 0]

        # Draw agent (blue); drawn last so it covers the target when they overlap
        ay, ax = self.agent_pos * cell_size
        img[ay+2:ay+cell_size-2, ax+2:ax+cell_size-2] = [0, 0, 255]

        return img

    def step(self, action: int) -> Tuple[np.ndarray, float, bool, bool, Dict]:
        """
        Move the agent one cell and return
        ``(image_observation, reward, terminated, truncated, info)``.
        """
        # Move agent
        moves = [(-1, 0), (0, 1), (1, 0), (0, -1)]  # up, right, down, left
        dy, dx = moves[action]

        # Clip so the agent stays inside the grid
        new_pos = self.agent_pos + np.array([dy, dx])
        self.agent_pos = np.clip(new_pos, 0, self.grid_size - 1)

        self.steps += 1

        # Check if reached target
        if np.array_equal(self.agent_pos, self.target_pos):
            return self._render_image(), 10.0, True, False, {}

        # Small negative reward for each step
        truncated = self.steps >= self.max_steps
        return self._render_image(), -0.1, False, truncated, {}

In [None]:
# Build the visual env and grab the first frame
env = SimpleVisualEnv()
obs, _ = env.reset()

print(f"Observation shape: {obs.shape}")

# One panel for the initial frame, then one per action taken
fig, axes = plt.subplots(1, 4, figsize=(16, 4))
first_panel = axes[0]
first_panel.imshow(obs)
first_panel.set_title("Initial")

# Actions: right, down, down
for panel, action in zip(axes[1:], [1, 2, 2]):
    obs, reward, done, _, _ = env.step(action)
    panel.imshow(obs)
    panel.set_title(f"After action {action}")

plt.tight_layout()
plt.show()

## Using CNN with RLlib for Image Environments

In [None]:
# Make the image-based env available to RLlib under its class name
register_env("SimpleVisualEnv", lambda env_config: SimpleVisualEnv(env_config))

# Convolutional stack for the image input.
# Each entry is [num_filters, kernel_size, stride].
conv_stack = [
    [16, [4, 4], 2],
    [32, [4, 4], 2],
    [64, [4, 4], 2],
]

# NOTE(review): this `model` dict is passed through `.training()`; confirm the
# installed RLlib version still honours it for PPO before training with it.
cnn_model = {
    # Use CNN for image input
    "conv_filters": conv_stack,
    "fcnet_hiddens": [256],
    "fcnet_activation": "relu",
}

# PPO configuration for the visual env (built step by step)
visual_config = PPOConfig()
visual_config = visual_config.environment(
    env="SimpleVisualEnv",
    env_config={"grid_size": 8, "image_size": 64},
)
visual_config = visual_config.framework("torch")
visual_config = visual_config.env_runners(num_env_runners=2)
visual_config = visual_config.training(
    lr=1e-4,
    train_batch_size=2000,
    model=cnn_model,
)

print("Visual environment config created!")
# Training would be similar to before
# Training would be similar to before

## Key Takeaways

1. **Gymnasium interface** requires `reset()`, `step()`, and space definitions

2. **Register environments** with RLlib using `register_env()`

3. **Choose appropriate spaces** for your problem (Discrete, Box, Dict, etc.)

4. **Normalize observations** for better training stability

## Next Steps

In the next notebook, we'll explore multi-agent environments.

In [None]:
# Shut down the Ray session that this notebook started with ray.init()
ray.shutdown()