# Tutorial 2: RL Environments and Problem Formulation

**WSmart+ Route Tutorial Series**

This tutorial covers the RL4CO-style environment abstraction used in WSmart+ Route. You'll learn:

1. The **environment registry** and factory pattern
2. **VRPPEnv** deep dive: reset, step, action masks, rewards
3. **Manual rollout**: stepping through episodes
4. **WCVRPEnv** and **CVRPPEnv** variants
5. **Visualizing** routes and state evolution

**Previous**: [01_data_generation.ipynb](01_data_generation.ipynb) | **Next**: [03_models_and_policies.ipynb](03_models_and_policies.ipynb)

In [None]:
import os
import sys
import warnings

warnings.filterwarnings("ignore")

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import matplotlib.pyplot as plt
import numpy as np
import torch

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---
## 1. Environment Registry

WSmart+ Route provides several environment types, all accessible through a unified registry.

In [None]:
from logic.src.envs import ENV_REGISTRY, get_env

print("Available Environments:")
print("-" * 50)
for name, env_cls in ENV_REGISTRY.items():
    print(f"  {name:10s} -> {env_cls.__name__}")

### Environment Types

| Environment | Description | Key Features |
|------------|-------------|--------------|
| `vrpp` | Vehicle Routing with Profits | Maximize prize - travel cost |
| `cvrpp` | Capacitated VRPP | VRPP + vehicle capacity limits |
| `wcvrp` | Waste Collection VRP | Collect waste with fill levels |
| `cwcvrp` | Capacitated WCVRP | WCVRP + capacity constraints |
| `sdwcvrp` | Stochastic Demand WCVRP | Noisy demand observations |
| `scwcvrp` | Selective Capacitated WCVRP | Selective collection under noise |
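To make the registry-plus-factory pattern concrete, here is a minimal pure-Python sketch. `TOY_REGISTRY`, `register`, and `make_env` are hypothetical names that only mimic the roles of `ENV_REGISTRY` and `get_env`. WSmart+ Route's actual implementation may register and construct classes differently.

```python
from typing import Callable, Dict, Type


class ToyEnv:
    """Hypothetical base class standing in for an RL4CO-style environment."""

    name = "base"

    def __init__(self, num_loc: int = 20, **kwargs):
        self.num_loc = num_loc
        self.options = kwargs  # extra settings, e.g. reward weights


# Hypothetical name -> class mapping, playing the role of ENV_REGISTRY
TOY_REGISTRY: Dict[str, Type[ToyEnv]] = {}


def register(name: str) -> Callable[[Type[ToyEnv]], Type[ToyEnv]]:
    """Class decorator that stores an environment class under `name`."""

    def decorator(cls: Type[ToyEnv]) -> Type[ToyEnv]:
        if name in TOY_REGISTRY:
            raise KeyError(f"Environment '{name}' is already registered")
        cls.name = name
        TOY_REGISTRY[name] = cls
        return cls

    return decorator


@register("vrpp")
class ToyVRPPEnv(ToyEnv):
    pass


@register("cvrpp")
class ToyCVRPPEnv(ToyVRPPEnv):
    def __init__(self, num_loc: int = 20, capacity: float = 1.0, **kwargs):
        super().__init__(num_loc=num_loc, **kwargs)
        self.capacity = capacity  # extra constructor argument for this variant


def make_env(name: str, **kwargs) -> ToyEnv:
    """Factory: look the class up by name and forward all keyword arguments."""
    if name not in TOY_REGISTRY:
        raise ValueError(f"Unknown env '{name}'. Available: {sorted(TOY_REGISTRY)}")
    return TOY_REGISTRY[name](**kwargs)


toy_env = make_env("cvrpp", num_loc=50, capacity=2.0)
print(type(toy_env).__name__, toy_env.num_loc, toy_env.capacity)
```

The decorator keeps registration next to each class definition, and the factory forwards keyword arguments unchanged. This mirrors how calls such as `get_env("vrpp", num_loc=20)` pass options through in the cells below.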

---
## 2. VRPPEnv Deep Dive

The `VRPPEnv` is the foundational environment. The agent selects which nodes to visit to maximize **total prize** minus **travel cost**.
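As a worked illustration of this objective, the sketch below scores a fixed tour as total prize minus Euclidean tour length. It assumes unit weights, node 0 as the depot (matching the depot-first indexing used later in `plot_tour`), and a tour that ends at the first depot visit. `vrpp_reward` is a hypothetical helper, not the environment's reward function.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def vrpp_reward(depot: Point, locs: Sequence[Point], prizes: Sequence[float],
                tour: Sequence[int]) -> float:
    """Total prize of the visited customers minus the length of the closed tour.

    Assumed indexing: 0 = depot, 1..N = customers, so customer n has location
    locs[n - 1] and prize prizes[n - 1]. The tour is cut at the first depot visit.
    """
    tour = list(tour)
    if 0 in tour:
        tour = tour[: tour.index(0)]
    all_locs = [depot] + list(locs)
    path = [0] + tour + [0]  # start and end at the depot
    length = sum(math.dist(all_locs[a], all_locs[b]) for a, b in zip(path, path[1:]))
    prize = sum(prizes[n - 1] for n in set(tour))  # each customer's prize counts once
    return prize - length
```

For a depot at (0, 0) and customers at (0, 3) and (4, 3) with prizes 5 and 14, the tour `[1, 2, 0]` travels 3 + 4 + 5 = 12 and collects 19, giving a reward of 7.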

In [None]:
# Create VRPP environment with 20 nodes
env = get_env("vrpp", num_loc=20)

# Generate instances using the environment's built-in generator
td = env.generator(batch_size=4)
print("Generated instance keys:", list(td.keys()))
print(f"  locs:     {td['locs'].shape}  (batch, nodes, 2)")
print(f"  depot:    {td['depot'].shape}     (batch, 2)")
print(f"  waste:    {td['waste'].shape}    (batch, nodes)")
print(f"  prize:    {td['prize'].shape}    (batch, nodes)")

# Reset environment - initializes episode state
td = env.reset(td)
print("\nAfter reset, new keys added:")
new_keys = [k for k in td.keys() if k not in ["locs", "depot", "waste", "prize", "capacity", "max_length", "max_waste"]]
for key in sorted(new_keys):
    shape = td[key].shape
    print(f"  {key:20s} shape={shape}  value={td[key][0].tolist() if td[key][0].numel() <= 5 else '...'}")

In [None]:
# Action mask: which nodes can we visit?
mask = td["action_mask"]
print(f"Action mask shape: {mask.shape}")
print(f"Action mask for instance 0: {mask[0].int().tolist()}")
print(f"  Number of valid actions: {mask[0].sum().item()}")
print(f"  Depot (index 0) reachable: {mask[0, 0].item()}")
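The exact masking rule lives inside `VRPPEnv`. A plausible version, assuming the `max_length` key returned by the generator acts as a tour-length budget, marks a customer as valid only if it is unvisited and the vehicle can reach it and still return to the depot within that budget. `vrpp_action_mask` is hypothetical, and it simply keeps the depot valid.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def vrpp_action_mask(all_locs: Sequence[Point], visited: Sequence[bool],
                     current: int, tour_length: float,
                     max_length: float) -> List[bool]:
    """One flag per node (index 0 = depot): True if the node may be visited next.

    Assumed rule: a customer is valid if it is unvisited AND going there and
    then back to the depot fits in the remaining length budget.
    """
    depot = all_locs[0]
    here = all_locs[current]
    mask = [True]  # the depot is always allowed in this sketch
    for node in range(1, len(all_locs)):
        detour = math.dist(here, all_locs[node]) + math.dist(all_locs[node], depot)
        mask.append((not visited[node]) and tour_length + detour <= max_length)
    return mask
```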

In [None]:
# Take a step: each batch element moves to a different node
td["action"] = torch.tensor([5, 3, 7, 1])  # Different action per batch element
result = env.step(td)

# The step returns a dict with "next" containing the updated state
td_next = result["next"]
print("After stepping to nodes [5, 3, 7, 1]:")
print(f"  Current node: {td_next['current_node'].squeeze(-1).tolist()}")
print(f"  Tour length:  {td_next['tour_length'].tolist()}")
print(f"  Prize collected: {td_next['collected_prize'].tolist()}")
print(f"  Done: {td_next['done'].squeeze(-1).tolist()}")
print(f"  Visited nodes: {td_next['visited'][0].int().tolist()}")
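Conceptually, `env.step` reads the `action` key and writes an updated state under `"next"`. The sketch below reproduces that transition for a single instance in pure Python. `ToyVRPPState`, `toy_reset`, and `toy_step` are hypothetical, and the rule that the episode ends as soon as the depot is selected is an assumption based on the cells in this section.

```python
import math
from dataclasses import dataclass, field
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


@dataclass
class ToyVRPPState:
    """Hypothetical single-instance state mirroring the keys printed above."""

    current_node: int = 0
    tour_length: float = 0.0
    collected_prize: float = 0.0
    visited: List[bool] = field(default_factory=list)
    done: bool = False


def toy_reset(num_nodes: int) -> ToyVRPPState:
    """Start at the depot (node 0) with only the depot marked as visited."""
    return ToyVRPPState(visited=[True] + [False] * (num_nodes - 1))


def toy_step(state: ToyVRPPState, action: int, all_locs: Sequence[Point],
             prizes: Sequence[float]) -> ToyVRPPState:
    """Apply one action (node index, 0 = depot) and return a new state.

    Assumed transition: add the distance travelled, collect a customer's prize
    the first time it is visited, and finish when the depot is chosen.
    """
    if state.done:
        return state  # finished episodes are left unchanged
    visited = list(state.visited)
    prize = state.collected_prize
    if action != 0 and not visited[action]:
        prize += prizes[action - 1]  # customer n's prize is prizes[n - 1]
    visited[action] = True
    return ToyVRPPState(
        current_node=action,
        tour_length=state.tour_length + math.dist(all_locs[state.current_node], all_locs[action]),
        collected_prize=prize,
        visited=visited,
        done=(action == 0),
    )
```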

In [None]:
# Take another step (again one action per batch element)
td_next["action"] = torch.tensor([10, 8, 2, 15])
result2 = env.step(td_next)
td_next2 = result2["next"]

print("After second step:")
print(f"  Current node: {td_next2['current_node'].squeeze(-1).tolist()}")
print(f"  Tour length:  {[f'{x:.4f}' for x in td_next2['tour_length'].tolist()]}")
print(f"  Prize collected: {[f'{x:.4f}' for x in td_next2['collected_prize'].tolist()]}")
print(f"  Visited count: {td_next2['visited'].sum(dim=-1).tolist()}")

# Return to depot to end episode
td_next2["action"] = torch.zeros(4, dtype=torch.long)  # Action 0 = depot
result3 = env.step(td_next2)
td_final = result3["next"]
print("\nAfter returning to depot:")
print(f"  Done: {td_final['done'].squeeze(-1).tolist()}")
print(f"  Reward: {td_final['reward'].squeeze(-1).tolist()}")

---
## 3. Manual Rollout

Let's step through complete episodes with different action selection strategies.

In [None]:
def random_rollout(env, td, max_steps=100):
    """Execute a complete episode with random valid actions."""
    td = env.reset(td.clone())
    actions = []
    step = 0

    while not td["done"].all() and step < max_steps:
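        # Select random valid actions (uniform over the unmasked nodes)
        mask = td["action_mask"].float().clone()
        # Guard: rows with no valid action fall back to the depot so multinomial does not raise
        mask[mask.sum(dim=-1) == 0, 0] = 1.0
        action = torch.multinomial(mask, 1).squeeze(-1)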
        actions.append(action.clone())
        td["action"] = action
        td = env.step(td)["next"]
        step += 1

    reward = td["reward"].squeeze(-1) if "reward" in td.keys() else env._get_reward(td)
    return td, actions, reward


def greedy_nearest_rollout(env, td, max_steps=100):
    """Greedy nearest-neighbor heuristic."""
    td = env.reset(td.clone())
    actions = []
    step = 0

    while not td["done"].all() and step < max_steps:
        mask = td["action_mask"]  # (batch, nodes)
        current = td["current_node"].squeeze(-1)  # (batch,)
        locs = td["locs"]  # (batch, nodes, 2)

        # Compute distances from current node to all nodes
        current_loc = locs.gather(1, current[:, None, None].expand(-1, -1, 2)).squeeze(1)
        distances = torch.norm(locs - current_loc.unsqueeze(1), dim=-1)  # (batch, nodes)

        # Mask invalid nodes with large distance
        distances = distances.masked_fill(~mask, float("inf"))

        # Don't go to depot unless it's the only option
        non_depot_mask = mask.clone()
        non_depot_mask[:, 0] = False
        has_non_depot = non_depot_mask.any(dim=-1)
        distances[:, 0] = distances[:, 0].masked_fill(has_non_depot, float("inf"))

        action = distances.argmin(dim=-1)
        actions.append(action.clone())
        td["action"] = action
        td = env.step(td)["next"]
        step += 1

    reward = td["reward"].squeeze(-1) if "reward" in td.keys() else env._get_reward(td)
    return td, actions, reward
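Stripped of batching, the greedy policy above is the classic nearest-neighbor construction: move to the closest valid customer and head back to the depot only when no customer is left. This single-instance sketch replaces the environment's action mask with an assumed `max_length` feasibility check; `nearest_neighbor_tour` is a hypothetical helper.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def nearest_neighbor_tour(all_locs: Sequence[Point], max_length: float) -> List[int]:
    """Greedy nearest-neighbor tour for one instance (node 0 = depot).

    Assumed feasibility rule: a customer is a candidate if it is unvisited and
    the vehicle can reach it and still return to the depot within max_length.
    """
    depot = all_locs[0]
    visited = [True] + [False] * (len(all_locs) - 1)
    current, length, tour = 0, 0.0, []
    while True:
        here = all_locs[current]
        candidates = [
            n for n in range(1, len(all_locs))
            if not visited[n]
            and length + math.dist(here, all_locs[n]) + math.dist(all_locs[n], depot) <= max_length
        ]
        if not candidates:  # go to the depot only when no customer is feasible
            tour.append(0)
            return tour
        nxt = min(candidates, key=lambda n: math.dist(here, all_locs[n]))
        length += math.dist(here, all_locs[nxt])
        visited[nxt] = True
        tour.append(nxt)
        current = nxt
```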

In [None]:
# Compare random vs greedy nearest-neighbor
env = get_env("vrpp", num_loc=20)
td_test = env.generator(batch_size=32)

td_rand, actions_rand, reward_rand = random_rollout(env, td_test)
td_greedy, actions_greedy, reward_greedy = greedy_nearest_rollout(env, td_test)

print("Rollout Comparison (32 instances, 20 nodes):")
print(f"  Random  - Mean reward: {reward_rand.mean():.4f} (+/- {reward_rand.std():.4f})")
print(f"  Greedy  - Mean reward: {reward_greedy.mean():.4f} (+/- {reward_greedy.std():.4f})")
print(f"  Greedy improvement: {((reward_greedy.mean() - reward_rand.mean()) / reward_rand.mean().abs() * 100):.1f}%")

In [None]:
def plot_tour(td_initial, td_final, actions, idx=0, title="Tour"):
    """Visualize a completed tour."""
    locs = td_initial["locs"][idx].cpu().numpy()
    depot = td_initial["depot"][idx].cpu().numpy()
    prizes = td_initial["prize"][idx].cpu().numpy()
    visited = td_final["visited"][idx].cpu().numpy()

    fig, ax = plt.subplots(1, 1, figsize=(8, 7))

    # Plot unvisited nodes
    unvisited_mask = ~visited[1:]  # Exclude depot
    if unvisited_mask.any():
        ax.scatter(locs[unvisited_mask, 0], locs[unvisited_mask, 1],
                   c="lightgray", s=60, edgecolors="gray", linewidth=0.5, zorder=2, label="Unvisited")

    # Plot visited nodes colored by prize
    visited_mask = visited[1:]
    if visited_mask.any():
        sc = ax.scatter(locs[visited_mask, 0], locs[visited_mask, 1],
                        c=prizes[visited_mask], cmap="YlOrRd", s=80,
                        edgecolors="black", linewidth=0.5, zorder=3, label="Visited")
        plt.colorbar(sc, ax=ax, label="Prize", shrink=0.8)

    # Plot depot
    ax.scatter(depot[0], depot[1], c="blue", s=200, marker="s",
               edgecolors="black", linewidth=1.5, zorder=4, label="Depot")

    # Draw tour path
    tour_nodes = [a[idx].item() for a in actions]
    all_locs = np.vstack([depot.reshape(1, 2), locs])  # depot at index 0
    path = [0] + tour_nodes  # Start from depot

    for i in range(len(path) - 1):
        start = all_locs[path[i]]
        end = all_locs[path[i + 1]]
        ax.annotate("", xy=end, xytext=start,
                     arrowprops=dict(arrowstyle="->", color="steelblue", lw=1.5))

    reward = td_final["reward"][idx].item() if "reward" in td_final.keys() else 0
    ax.set_title(f"{title}\nReward: {reward:.4f} | Nodes visited: {visited.sum() - 1}")
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.legend(loc="upper right")
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)

    plt.tight_layout()
    plt.show()


# Visualize random vs greedy tour for same instance
plot_tour(td_test, td_rand, actions_rand, idx=0, title="Random Policy Tour")
plot_tour(td_test, td_greedy, actions_greedy, idx=0, title="Greedy Nearest-Neighbor Tour")
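Because both rollout functions keep stepping until every instance in the batch is done, the per-step action list for a single instance can contain extra steps after that instance has already finished. Assuming finished instances keep selecting the depot, the actual tour is everything up to the first depot action. `trim_tour` is a hypothetical helper for that.

```python
from typing import List, Sequence


def trim_tour(step_actions: Sequence[int]) -> List[int]:
    """Cut a per-step action sequence at the first return to the depot (node 0).

    Assumes a finished instance keeps choosing the depot while the rest of the
    batch finishes, so everything after the first 0 is padding.
    """
    tour = []
    for node in step_actions:
        tour.append(node)
        if node == 0:
            break
    return tour
```

For example, `trim_tour([a[0].item() for a in actions_greedy])` would give the greedy tour of instance 0 without the padding steps.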

---
## 4. Waste Collection VRP Environment

The `WCVRPEnv` models waste collection where bins fill over time. The objective balances **waste collected** against **travel cost** and **overflow penalties**.
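One way to write such an objective is a weighted sum, sketched below for a single instance. The weight names echo the `collection_reward`, `cost_weight`, and `overflow_penalty` fields of `EnvConfig` shown in Section 6, but the formula, the default values, and the rule that a bin overflows once its fill level reaches 1.0 are all assumptions, not `WCVRPEnv`'s actual reward.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def wcvrp_reward(all_locs: Sequence[Point], fill: Sequence[float], tour: Sequence[int],
                 collection_reward: float = 1.0, cost_weight: float = 1.0,
                 overflow_penalty: float = 10.0, overflow_threshold: float = 1.0) -> float:
    """Score one WCVRP tour (hypothetical weights and threshold).

    Node 0 is the depot; `tour` lists customer nodes in visit order (depot
    excluded), and fill[n] is the fill level of node n.
    """
    path = [0] + list(tour) + [0]  # leave from and return to the depot
    distance = sum(math.dist(all_locs[a], all_locs[b]) for a, b in zip(path, path[1:]))
    served = set(tour)
    collected = sum(fill[n] for n in served)
    # Bins at or above the threshold that were not emptied count as overflows
    missed_overflows = sum(
        1 for n in range(1, len(all_locs)) if n not in served and fill[n] >= overflow_threshold
    )
    return collection_reward * collected - cost_weight * distance - overflow_penalty * missed_overflows
```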

In [None]:
# Create waste collection environment
wc_env = get_env("wcvrp", num_loc=20)
td_wc = wc_env.generator(batch_size=8)
td_wc = wc_env.reset(td_wc)

print("WCVRP State after reset:")
for key in sorted(td_wc.keys()):
    shape = td_wc[key].shape
    print(f"  {key:20s} {str(shape):20s}")

In [None]:
# Run greedy rollout on WCVRP
td_wc_test = wc_env.generator(batch_size=32)
td_wc_final, actions_wc, reward_wc = greedy_nearest_rollout(wc_env, td_wc_test)

print("WCVRP Greedy Rollout Results:")
print(f"  Mean reward: {reward_wc.mean():.4f}")
print(f"  Std reward:  {reward_wc.std():.4f}")

---
## 5. Capacitated VRPP Environment

The `CVRPPEnv` adds **vehicle capacity constraints** to VRPP. The vehicle can only carry a limited amount of waste/goods before returning to the depot.
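A capacity constraint typically shows up in two places: the action mask (a customer is valid only if its demand fits in the remaining capacity) and the state update (capacity shrinks as waste is loaded). The sketch below assumes exactly that, plus a depot visit that empties the vehicle. `capacity_mask` and `apply_visit` are hypothetical helpers whose names only echo the `remaining_capacity` and `collected` keys printed below.

```python
from typing import List, Sequence, Tuple


def capacity_mask(remaining_capacity: float, demand: Sequence[float],
                  visited: Sequence[bool]) -> List[bool]:
    """One flag per node (index 0 = depot). The depot is always allowed here;
    a customer is allowed if it is unvisited and its demand fits in the vehicle."""
    return [True] + [
        (not visited[n]) and demand[n] <= remaining_capacity for n in range(1, len(demand))
    ]


def apply_visit(remaining_capacity: float, collected: float, demand: Sequence[float],
                node: int, capacity: float) -> Tuple[float, float]:
    """Return (remaining_capacity, collected) after visiting `node`."""
    if node == 0:
        return capacity, collected  # assumed: unloading at the depot restores full capacity
    return remaining_capacity - demand[node], collected + demand[node]


demand = [0.0, 0.25, 0.5, 0.75]  # demand[0] belongs to the depot
visited = [True, False, False, False]
rem, col = apply_visit(1.0, 0.0, demand, node=3, capacity=1.0)
visited[3] = True
print(rem, col, capacity_mask(rem, demand, visited))
```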

In [None]:
# Create capacitated environment
c_env = get_env("cvrpp", num_loc=20)
td_c = c_env.generator(batch_size=4)
td_c = c_env.reset(td_c)

print("CVRPP State after reset:")
print(f"  remaining_capacity: {td_c['remaining_capacity'].tolist()}")
print(f"  collected: {td_c['collected'].tolist()}")

# Step to a node - capacity decreases
td_c["action"] = torch.tensor([3, 5, 7, 2])
td_c_next = c_env.step(td_c)["next"]
print("\nAfter visiting a node:")
print(f"  remaining_capacity: {[f'{x:.4f}' for x in td_c_next['remaining_capacity'].tolist()]}")
print(f"  collected: {[f'{x:.4f}' for x in td_c_next['collected'].tolist()]}")

---
## 6. Environment Configuration

Environments can be configured programmatically using the `EnvConfig` dataclass.
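The snippet below sketches the general idea with a frozen dataclass. `ToyEnvConfig` is a hypothetical stand-in with placeholder defaults, not the real `EnvConfig`. `dataclasses.replace` derives a variant without mutating the base config, and `asdict` turns it into keyword arguments.

```python
from dataclasses import asdict, dataclass, replace
from typing import Any, Dict, Tuple


@dataclass(frozen=True)
class ToyEnvConfig:
    """Hypothetical stand-in for EnvConfig; the defaults are placeholders."""

    name: str = "vrpp"
    num_loc: int = 20
    cost_weight: float = 1.0
    prize_weight: float = 1.0


def config_to_kwargs(cfg: ToyEnvConfig) -> Tuple[str, Dict[str, Any]]:
    """Split a config into an environment name and constructor keyword arguments."""
    kwargs = asdict(cfg)
    name = kwargs.pop("name")
    return name, kwargs


base_cfg = ToyEnvConfig()
big_cfg = replace(base_cfg, num_loc=50, cost_weight=2.0)  # new object; base_cfg is unchanged
print(config_to_kwargs(big_cfg))
```

Since `get_env` accepts keyword arguments such as `num_loc`, `prize_weight`, and `cost_weight` (see the reward-weight comparison below), a config like this could be unpacked as `get_env(name, **kwargs)`; whether WSmart+ Route consumes `EnvConfig` this way is an assumption.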

In [None]:
from logic.src.configs import EnvConfig

# Default configuration
default_cfg = EnvConfig()
print("Default EnvConfig:")
for field_name in ["name", "num_loc", "capacity", "cost_weight", "prize_weight",
                   "overflow_penalty", "collection_reward"]:
    print(f"  {field_name}: {getattr(default_cfg, field_name)}")

# Custom configuration
custom_cfg = EnvConfig(
    name="vrpp",
    num_loc=50,
    cost_weight=2.0,     # Penalize distance more
    prize_weight=1.0,
)
print(f"\nCustom config: {custom_cfg.num_loc} nodes, cost_weight={custom_cfg.cost_weight}")

In [None]:
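# Show how reward weights change the objective. The greedy nearest-neighbor
# policy ignores prizes and weights, so all three envs should produce the same
# tours; only the reward used to score those tours differs.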
env_balanced = get_env("vrpp", num_loc=20, prize_weight=1.0, cost_weight=1.0)
env_prize_focus = get_env("vrpp", num_loc=20, prize_weight=2.0, cost_weight=0.5)
env_cost_focus = get_env("vrpp", num_loc=20, prize_weight=0.5, cost_weight=2.0)

td_compare = env_balanced.generator(batch_size=64)

configs = [
    ("Balanced (1:1)", env_balanced),
    ("Prize Focus (2:0.5)", env_prize_focus),
    ("Cost Focus (0.5:2)", env_cost_focus),
]

print("Reward Weight Comparison (greedy nearest-neighbor on 64 instances):")
print("-" * 55)
for name, env_cfg in configs:
    _, _, rewards = greedy_nearest_rollout(env_cfg, td_compare)
    print(f"  {name:25s} Mean reward: {rewards.mean():.4f}")
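Reweighting matters most when it changes which tour is preferred. The toy calculation below assumes a linear objective, `prize_weight * prize - cost_weight * tour_length`, and compares two hypothetical tours: a short one with little prize (A) and a long one with more prize (B).

```python
def weighted_reward(prize: float, tour_length: float,
                    prize_weight: float, cost_weight: float) -> float:
    """Assumed linear objective: prize_weight * prize - cost_weight * tour_length."""
    return prize_weight * prize - cost_weight * tour_length


# Hypothetical tours as (prize, tour_length): A is short and cheap, B is long and rich
tours = {"A": (2.0, 1.0), "B": (4.0, 2.5)}
weight_settings = {
    "Balanced (1:1)": (1.0, 1.0),
    "Prize Focus (2:0.5)": (2.0, 0.5),
    "Cost Focus (0.5:2)": (0.5, 2.0),
}

preferred = {}
for label, (pw, cw) in weight_settings.items():
    scores = {t: weighted_reward(p, length, pw, cw) for t, (p, length) in tours.items()}
    preferred[label] = max(scores, key=scores.get)
    print(f"{label:22s} A={scores['A']:+.2f}  B={scores['B']:+.2f}  -> prefer {preferred[label]}")
```

Under these assumptions, B wins with the balanced and prize-focused weights (1.50 vs 1.00 and 6.75 vs 3.50), while the cost-focused weights favor A (-1.00 vs -3.00). A policy trained on each objective would therefore be pushed toward different tours, even though the fixed greedy heuristic above cannot react to the weights.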

---
## Summary

In this tutorial, you learned:

- **Environment registry** provides 6 problem types accessible via `get_env(name)`
- **VRPPEnv** manages state (visited nodes, tour length, prize collected) through `reset()` and `step()` calls
- **Action masks** prevent invalid actions, such as revisiting a node that has already been visited
- **Manual rollouts** with random and greedy nearest-neighbor strategies
- **WCVRPEnv** adds waste-collection dynamics (bin fill levels and overflow penalties), and **CVRPPEnv** adds vehicle capacity constraints
- **Reward weights** (`prize_weight`, `cost_weight`) control the optimization objective

### Next Steps

Continue to **[Tutorial 3: Neural Models and Classical Policies](03_models_and_policies.ipynb)** to learn how trained neural networks and classical optimization algorithms solve these routing problems.