# Tutorial 5: Evaluation and Decoding Strategies

**WSmart+ Route Tutorial Series**

This tutorial covers how to evaluate trained models and explore different decoding strategies. You'll learn:

1. **Training a small model** for evaluation demos
2. **Decoding strategies**: greedy, sampling, multi-sample
3. **Temperature scaling** for sampling
4. **Evaluation metrics** and cost decomposition
5. **Visualizing tours** and comparing solution quality

**Previous**: [04_training_with_lightning.ipynb](04_training_with_lightning.ipynb) | **Next**: [06_simulation_testing.ipynb](06_simulation_testing.ipynb)

In [None]:
import os
import sys
import warnings

warnings.filterwarnings("ignore")

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import matplotlib.pyplot as plt
import numpy as np
import torch

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---
## 1. Quick Training Setup

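First, let's train a small model so we have something to evaluate. We'll use a short (5-epoch) REINFORCE run with an exponential baseline on 20-node VRPP instances. The resulting model is unlikely to be converged, but it is enough to demonstrate the evaluation tools.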
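For intuition, here is a minimal, framework-free sketch of the REINFORCE loss with an exponential moving-average baseline. It is not WSmart+'s `REINFORCE` class. The function names and the decay factor `beta=0.8` are assumptions chosen for illustration. The baseline tracks a running average of batch-mean rewards, and each tour's log-likelihood is weighted by its advantage (reward minus baseline).

```python
def update_ema_baseline(baseline, batch_rewards, beta=0.8):
    """Exponential moving average of the batch-mean reward (hypothetical helper).

    On the first call (baseline is None) the baseline starts at the batch mean.
    """
    batch_mean = sum(batch_rewards) / len(batch_rewards)
    if baseline is None:
        return batch_mean
    return beta * baseline + (1.0 - beta) * batch_mean


def reinforce_loss(rewards, log_likelihoods, baseline):
    """REINFORCE surrogate loss: -mean((reward - baseline) * log p(tour))."""
    terms = [(r - baseline) * ll for r, ll in zip(rewards, log_likelihoods)]
    return -sum(terms) / len(terms)


# Toy batch of three tours; rewards and summed log-probabilities are made up
toy_rewards = [1.0, 2.0, 3.0]
toy_log_likelihoods = [-5.0, -4.0, -3.0]

toy_baseline = update_ema_baseline(None, toy_rewards)  # 2.0 on the first batch
print(toy_baseline, reinforce_loss(toy_rewards, toy_log_likelihoods, toy_baseline))
```

Minimising this loss with autograd raises the log-probability of tours that beat the baseline and lowers it for tours that fall short. In a real implementation the baseline is treated as a constant, so no gradient flows through it.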

In [None]:
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import REINFORCE
from logic.src.pipeline.rl.common.trainer import WSTrainer

# Create components
env = get_env("vrpp", num_loc=20)
policy = AttentionModelPolicy(
    env_name="vrpp", embed_dim=64, n_encode_layers=2, n_decode_layers=2, n_heads=4,
)

model = REINFORCE(
    env=env, policy=policy, baseline="exponential",
    optimizer="adam", optimizer_kwargs={"lr": 1e-4},
    train_data_size=1280, val_data_size=256, batch_size=64, num_workers=0,
)

# Quick training
trainer = WSTrainer(
    max_epochs=5, accelerator="cpu", devices=1, precision="32-true",
    log_every_n_steps=5, enable_progress_bar=True, logger=False,
    reload_dataloaders_every_n_epochs=1,
)

print("Training model for evaluation demos (5 epochs)...")
trainer.fit(model)
print("Training complete!")

In [None]:
# Generate a fixed test dataset
torch.manual_seed(999)  # Different seed from training
test_data = env.generator(batch_size=128)
print(f"Test dataset: {test_data.batch_size[0]} instances, {test_data['locs'].shape[1]} nodes")

---
## 2. Decoding Strategies

At each step, the decoder outputs a probability distribution over the nodes that can still be visited. The decoding strategy decides how the next node is picked from that distribution. Different strategies trade off speed against solution quality.
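As a single-step illustration, the sketch below shows how greedy and sampling decoding pick the next node from one probability vector once infeasible nodes are masked out. WSmart+ does this inside the policy on batched PyTorch tensors. `select_next_node` is a hypothetical plain-Python helper, not part of the library.

```python
import random


def select_next_node(probs, feasible, strategy, rng=None):
    """Pick the next node index from a probability vector (illustrative only).

    probs:    non-negative scores, one per node
    feasible: booleans; infeasible nodes get zero probability
    strategy: "greedy" (argmax) or "sampling" (draw from the distribution)
    """
    masked = [p if ok else 0.0 for p, ok in zip(probs, feasible)]
    total = sum(masked)
    if total <= 0:
        raise ValueError("no feasible node with positive probability")
    dist = [p / total for p in masked]  # renormalise over feasible nodes

    if strategy == "greedy":
        return max(range(len(dist)), key=dist.__getitem__)
    if strategy == "sampling":
        rng = rng or random.Random()
        return rng.choices(range(len(dist)), weights=dist, k=1)[0]
    raise ValueError(f"unknown strategy: {strategy}")


toy_probs = [0.1, 0.6, 0.3]
toy_feasible = [True, False, True]  # node 1 was already visited

print(select_next_node(toy_probs, toy_feasible, "greedy"))  # 2, because node 1 is masked out
rng = random.Random(0)
draws = [select_next_node(toy_probs, toy_feasible, "sampling", rng) for _ in range(1000)]
print(draws.count(0) / len(draws), draws.count(2) / len(draws))  # roughly 0.25 and 0.75
```

Greedy always returns the same node for the same input. Sampling returns node 0 about a quarter of the time, and that variation is the source of exploration.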

In [None]:
# Strategy 1: Greedy - always pick the highest probability node
policy.eval()  # inference mode for layers such as dropout or batch norm, if the policy uses them
td_test = env.reset(test_data.clone())

with torch.no_grad():
    out_greedy = policy(td_test, env, strategy="greedy", return_actions=True)

print("Greedy Decoding:")
print(f"  Mean reward:  {out_greedy['reward'].mean():.4f}")
print(f"  Std reward:   {out_greedy['reward'].std():.4f}")
print(f"  Best reward:  {out_greedy['reward'].max():.4f}")
print(f"  Worst reward: {out_greedy['reward'].min():.4f}")

In [None]:
# Strategy 2: Sampling - sample from the probability distribution
td_test = env.reset(test_data.clone())

with torch.no_grad():
    out_sample = policy(td_test, env, strategy="sampling", return_actions=True)

print("Sampling Decoding (single sample):")
print(f"  Mean reward:  {out_sample['reward'].mean():.4f}")
print(f"  Std reward:   {out_sample['reward'].std():.4f}")

In [None]:
# Strategy 3: Multi-sample - run N samples, keep the best for each instance
def multi_sample_eval(policy, env, test_data, n_samples=8):
    """Run `n_samples` independent sampling rollouts and keep the best reward per instance.

    Returns:
        best_rewards: (batch,) best reward found for each instance
        stacked_rewards: (n_samples, batch) reward of every rollout
    """
    all_rewards = []
    for _ in range(n_samples):
        td = env.reset(test_data.clone())
        with torch.no_grad():
            out = policy(td, env, strategy="sampling", return_actions=True)
        all_rewards.append(out["reward"])

    # Stack rollouts and take the maximum reward per instance
    stacked_rewards = torch.stack(all_rewards, dim=0)  # (n_samples, batch)
    best_rewards = stacked_rewards.max(dim=0).values  # (batch,)

    return best_rewards, stacked_rewards


# Compare different sample counts
n_samples_list = [1, 2, 4, 8, 16, 32, 64]
multi_results = {}

for n in n_samples_list:
    best_rewards, _ = multi_sample_eval(policy, env, test_data, n_samples=n)
    multi_results[n] = best_rewards.mean().item()
    print(f"  {n:3d} samples -> Mean best reward: {multi_results[n]:.4f}")

In [None]:
fig, ax = plt.subplots(figsize=(9, 5))

ax.plot(n_samples_list, list(multi_results.values()), "o-", linewidth=2,
        markersize=8, color="steelblue", label="Multi-sample best")
ax.axhline(y=out_greedy["reward"].mean().item(), color="red", linestyle="--",
           linewidth=1.5, label=f"Greedy: {out_greedy['reward'].mean():.4f}")
ax.set_xlabel("Number of Samples", fontsize=12)
ax.set_ylabel("Mean Best Reward", fontsize=12)
ax.set_title("Reward vs Number of Samples")
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xscale("log", base=2)
ax.set_xticks(n_samples_list)
ax.set_xticklabels(n_samples_list)

plt.tight_layout()
plt.show()

print("Multi-sampling trades computation time for solution quality.")
print("Greedy gives a strong deterministic baseline; sampling with enough")
print("trials can exceed it by exploring diverse solutions.")

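Because each `n` in the comparison loop draws fresh rollouts, the curve can dip slightly from noise, even though the expected best-of-N reward cannot decrease as N grows. Taking nested prefixes of a single pool of rollouts guarantees a non-decreasing curve and needs 64 rollouts instead of 1 + 2 + ... + 64 = 127. The sketch below shows the idea in plain Python with made-up rewards. With the tutorial's `multi_sample_eval`, the equivalent would be one call with `n_samples=64`, followed by `stacked_rewards[:n].max(dim=0).values.mean()` for each `n`. The helper names are hypothetical.

```python
import random


def best_of_n(sample_rewards):
    """Per-instance best reward; sample_rewards[s][i] is the reward of sample s on instance i."""
    return [max(column) for column in zip(*sample_rewards)]


def mean_of(values):
    return sum(values) / len(values)


# Made-up rewards standing in for 64 independent sampling rollouts on 5 instances
rng = random.Random(0)
toy_pool = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(64)]

# Nested prefixes of one pool: the mean best-of-N reward can only stay flat or grow
for n in [1, 2, 4, 8, 16, 32, 64]:
    print(f"{n:3d} samples -> mean best reward {mean_of(best_of_n(toy_pool[:n])):.3f}")
```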
---
## 3. Temperature Scaling

Temperature controls the randomness of sampling. Lower temperatures make the distribution sharper (more greedy-like), while higher temperatures increase exploration.
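A minimal sketch of the usual convention follows, in which logits are divided by the temperature before the softmax. We assume, without having confirmed it in the WSmart+ source, that `softmax_temp` is applied this way. The entropy printout makes "sharper" versus "more exploratory" concrete.

```python
import math


def softmax_with_temperature(logits, temp):
    """softmax(logits / temp), computed stably by subtracting the max."""
    scaled = [z / temp for z in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def entropy(p):
    """Shannon entropy in nats; 0 for a one-hot distribution."""
    return -sum(x * math.log(x) for x in p if x > 0)


toy_logits = [2.0, 1.0, 0.0]
for temp in [0.1, 0.5, 1.0, 2.0, 5.0]:
    p = softmax_with_temperature(toy_logits, temp)
    print(f"T={temp:<4} probs={[round(x, 3) for x in p]} entropy={entropy(p):.3f}")
```

As `temp` approaches 0 the distribution approaches the greedy argmax. As it grows, the distribution flattens towards uniform and the entropy rises.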

In [None]:
temperatures = [0.1, 0.5, 1.0, 2.0, 5.0]
temp_results = {}

for temp in temperatures:
    td = env.reset(test_data.clone())
    with torch.no_grad():
        out = policy(td, env, strategy="sampling", softmax_temp=temp, return_actions=True)
    temp_results[temp] = out["reward"].mean().item()

print("Temperature Scaling (single sample):")
print(f"{'Temperature':<15} {'Mean Reward':<15}")
print("-" * 30)
for temp, reward in temp_results.items():
    print(f"{temp:<15.1f} {reward:<15.4f}")

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))

temps = list(temp_results.keys())
rewards = list(temp_results.values())

ax.plot(temps, rewards, "s-", linewidth=2, markersize=8, color="coral")
ax.axhline(y=out_greedy["reward"].mean().item(), color="steelblue", linestyle="--",
           linewidth=1.5, label=f"Greedy: {out_greedy['reward'].mean():.4f}")
ax.set_xlabel("Temperature", fontsize=12)
ax.set_ylabel("Mean Reward (single sample)", fontsize=12)
ax.set_title("Effect of Softmax Temperature on Sampling Quality")
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

---
## 4. Evaluation Metrics

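Beyond the mean reward, we can look at its spread, the log-likelihood the policy assigns to its own solutions (and the entropy, if the policy returns it), and tour statistics derived from the action sequences.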
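For prize-collecting problems such as VRPP, a natural way to decompose a solution is into the prize it collects and the distance it travels. The sketch below computes both for one instance from an action sequence. It assumes the index convention used by `plot_solution` in Section 5 (0 is the depot and also padding, and node `i >= 1` maps to `locs[i - 1]`), plus a single route. It does not show how the environment weights and combines these terms into `reward`, so check the VRPP environment before comparing against `out["reward"]`. The helper names are hypothetical.

```python
import math


def tour_distance(depot, locs, tour):
    """Euclidean length of depot -> customers in tour order -> depot (0 entries are skipped)."""
    points = [depot] + [locs[i - 1] for i in tour if i != 0] + [depot]
    return sum(math.dist(a, b) for a, b in zip(points, points[1:]))


def collected_prize(prizes, tour):
    """Sum of prizes of distinct visited customers (depot index 0 ignored)."""
    return sum(prizes[i - 1] for i in set(tour) if i != 0)


toy_depot = (0.0, 0.0)
toy_locs = [(3.0, 4.0), (0.0, 4.0), (10.0, 10.0)]
toy_prizes = [2.0, 1.0, 5.0]
toy_tour = [1, 2, 0, 0]  # visits customers 1 and 2, then padding

print(tour_distance(toy_depot, toy_locs, toy_tour))  # 5 + 3 + 4 = 12.0
print(collected_prize(toy_prizes, toy_tour))  # 2.0 + 1.0 = 3.0
```

To apply these helpers to the notebook's tensors, convert one instance with `.tolist()`, for example `tour_distance(test_data["depot"][i].tolist(), test_data["locs"][i].tolist(), out["actions"][i].tolist())`.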

In [None]:
# Evaluate and decompose metrics
td_metrics = env.reset(test_data.clone())

with torch.no_grad():
    out = policy(td_metrics, env, strategy="greedy", return_actions=True)

print(f"Evaluation Metrics ({test_data.batch_size[0]} test instances, greedy decoding):")
print("=" * 50)
print(f"  Reward:          {out['reward'].mean():.4f} +/- {out['reward'].std():.4f}")
print(f"  Log-likelihood:  {out['log_likelihood'].mean():.4f}")

if "entropy" in out:
    print(f"  Entropy:         {out['entropy'].mean():.4f}")

# Tour statistics from the action sequences (index 0 = depot / padding)
actions = out["actions"]
tour_lengths = (actions != 0).sum(dim=-1).float()  # number of customer visits per tour
print(f"  Avg nodes visited: {tour_lengths.mean():.1f}")
print(f"  Min nodes visited: {tour_lengths.min():.0f}")
print(f"  Max nodes visited: {tour_lengths.max():.0f}")

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Reward distribution
ax = axes[0]
rewards = out["reward"].numpy()
ax.hist(rewards, bins=25, alpha=0.7, color="steelblue", edgecolor="white")
ax.axvline(x=rewards.mean(), color="red", linestyle="--", label=f"Mean: {rewards.mean():.4f}")
ax.set_xlabel("Reward")
ax.set_ylabel("Count")
ax.set_title("Reward Distribution (Greedy)")
ax.legend()

# Right: Tour length distribution
ax = axes[1]
lengths = tour_lengths.numpy()
ax.hist(lengths, bins=15, alpha=0.7, color="coral", edgecolor="white")
ax.axvline(x=lengths.mean(), color="red", linestyle="--", label=f"Mean: {lengths.mean():.1f}")
ax.set_xlabel("Tour Length (nodes visited)")
ax.set_ylabel("Count")
ax.set_title("Tour Length Distribution")
ax.legend()

plt.tight_layout()
plt.show()

---
## 5. Tour Visualization

In [None]:
def plot_solution(test_data, actions, idx=0, title="Solution"):
    """Plot a solution tour."""
    locs = test_data["locs"][idx].numpy()
    depot = test_data["depot"][idx].numpy()
    prizes = test_data["prize"][idx].numpy()
    tour = actions[idx].numpy()

    # Drop depot indices (0), which also strips trailing padding; the depot is
    # re-added at both ends of the path below
    tour = tour[tour != 0]

    fig, ax = plt.subplots(figsize=(8, 7))

    # All locations
    all_locs = np.vstack([depot.reshape(1, 2), locs])

    # Determine visited nodes
    visited_set = set(tour.tolist())
    unvisited = [i for i in range(1, len(locs) + 1) if i not in visited_set]
    visited = [i for i in range(1, len(locs) + 1) if i in visited_set]

    # Plot unvisited
    if unvisited:
        uv_locs = all_locs[unvisited]
        ax.scatter(uv_locs[:, 0], uv_locs[:, 1], c="lightgray", s=50,
                   edgecolors="gray", linewidth=0.5, zorder=2, label="Unvisited")

    # Plot visited colored by prize
    if visited:
        v_locs = all_locs[visited]
        v_prizes = prizes[[v - 1 for v in visited]]
        sc = ax.scatter(v_locs[:, 0], v_locs[:, 1], c=v_prizes, cmap="YlOrRd",
                        s=80, edgecolors="black", linewidth=0.5, zorder=3, label="Visited")
        plt.colorbar(sc, ax=ax, label="Prize", shrink=0.8)

    # Depot
    ax.scatter(depot[0], depot[1], c="blue", s=200, marker="s",
               edgecolors="black", linewidth=1.5, zorder=4, label="Depot")

    # Draw path
    path = [0] + tour.tolist() + [0]
    for i in range(len(path) - 1):
        start = all_locs[path[i]]
        end = all_locs[path[i + 1]]
        ax.annotate("", xy=end, xytext=start,
                     arrowprops=dict(arrowstyle="->", color="steelblue", lw=1.5))

    # Add node labels
    for i, (x, y) in enumerate(all_locs):
        if i == 0:
            continue
        ax.annotate(str(i), (x, y), textcoords="offset points",
                    xytext=(5, 5), fontsize=7, alpha=0.6)

    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.legend(loc="upper right", fontsize=9)
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    plt.tight_layout()
    plt.show()


# Visualize best and worst solutions
best_idx = out["reward"].argmax().item()
worst_idx = out["reward"].argmin().item()

plot_solution(test_data, out["actions"], idx=best_idx,
              title=f"Best Solution (reward: {out['reward'][best_idx]:.4f})")
plot_solution(test_data, out["actions"], idx=worst_idx,
              title=f"Worst Solution (reward: {out['reward'][worst_idx]:.4f})")

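Both plotting cells rebuild a closed path from a padded action row in the same way. A small helper keeps that logic in one place. The sketch assumes, as `plot_solution` does, that index 0 is both the depot and the padding value, so depot returns in the middle of a sequence are also dropped. `actions_to_path` and `path_edges` are hypothetical names. In the notebook you would call them as `actions_to_path(out["actions"][idx].tolist())` and draw one arrow per pair returned by `path_edges`.

```python
def actions_to_path(action_row):
    """Turn a padded action sequence into a closed depot-to-depot path.

    Assumes index 0 is the depot and is also used as padding, so every 0 is dropped
    and the depot is added back at both ends.
    """
    customers = [int(a) for a in action_row if int(a) != 0]
    return [0] + customers + [0]


def path_edges(path):
    """Consecutive (from, to) node pairs along the path, e.g. one per arrow to draw."""
    return list(zip(path, path[1:]))


print(actions_to_path([3, 1, 4, 0, 0]))  # [0, 3, 1, 4, 0]
print(path_edges([0, 3, 1, 0]))  # [(0, 3), (3, 1), (1, 0)]
print(actions_to_path([0, 0]))  # [0, 0]: nothing visited
```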
---
## 6. Greedy vs Sampling Side-by-Side

In [None]:
# Get greedy and sampling solutions for same instance
idx = 0

td_g = env.reset(test_data.clone())
td_s = env.reset(test_data.clone())

with torch.no_grad():
    out_g = policy(td_g, env, strategy="greedy", return_actions=True)
    out_s = policy(td_s, env, strategy="sampling", return_actions=True)

fig, axes = plt.subplots(1, 2, figsize=(16, 7))

for ax, (out_d, decode_name) in zip(axes, [(out_g, "Greedy"), (out_s, "Sampling")]):
    locs = test_data["locs"][idx].numpy()
    depot = test_data["depot"][idx].numpy()
    tour = out_d["actions"][idx].numpy()
    tour = tour[tour != 0] if (tour != 0).any() else tour[:1]
    all_locs = np.vstack([depot.reshape(1, 2), locs])

    ax.scatter(locs[:, 0], locs[:, 1], c="lightblue", s=50, edgecolors="gray", linewidth=0.5, zorder=2)
    ax.scatter(depot[0], depot[1], c="blue", s=200, marker="s", edgecolors="black", linewidth=1.5, zorder=4)

    path = [0] + tour.tolist() + [0]
    for i in range(len(path) - 1):
        start = all_locs[path[i]]
        end = all_locs[path[i + 1]]
        ax.annotate("", xy=end, xytext=start,
                     arrowprops=dict(arrowstyle="->", color="steelblue", lw=1.5))

    reward = out_d["reward"][idx].item()
    ax.set_title(f"{decode_name} (reward: {reward:.4f})")
    ax.set_xlim(-0.05, 1.05)
    ax.set_ylim(-0.05, 1.05)
    ax.set_aspect("equal")

plt.suptitle("Greedy vs Sampling Decoding", fontsize=14)
plt.tight_layout()
plt.show()

---
## 7. Strategy Selection Guide

| Strategy | Speed | Quality | Use Case |
|----------|-------|---------|----------|
| Greedy | Fast | Good baseline | Real-time deployment, quick evaluation |
| Sampling (1x) | Fast | Variable | Training (exploration needed) |
| Sampling (Nx) | ~N x compute | Often better than greedy with enough samples | When quality matters more than speed |
| Low temperature | Fast | Near-greedy | Tuning the exploration-exploitation trade-off |
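One way to act on this table is to keep the choice in a single place. The sketch below maps a use case to the keyword arguments used with `policy(...)` in this tutorial (`strategy`, `softmax_temp`) and to a sample count. `decoding_kwargs` is a hypothetical helper. The temperature of 0.5 and the 32 samples are illustrative values, not recommendations from WSmart+.

```python
def decoding_kwargs(use_case):
    """Return (policy kwargs, number of samples) for a row of the strategy table.

    Hypothetical helper; the numeric values are illustrative choices.
    """
    table = {
        "deployment": ({"strategy": "greedy"}, 1),
        "exploration": ({"strategy": "sampling"}, 1),
        "quality": ({"strategy": "sampling"}, 32),
        "near_greedy": ({"strategy": "sampling", "softmax_temp": 0.5}, 1),
    }
    if use_case not in table:
        raise KeyError(f"unknown use case {use_case!r}; choose from {sorted(table)}")
    kwargs, n_samples = table[use_case]
    return dict(kwargs), n_samples  # copy so callers can modify the dict safely


kwargs, n_samples = decoding_kwargs("quality")
print(kwargs, n_samples)  # {'strategy': 'sampling'} 32
```

For `n_samples == 1` the kwargs can be passed straight to `policy(td, env, return_actions=True, **kwargs)`. For larger counts, `multi_sample_eval(policy, env, test_data, n_samples=n_samples)` covers the plain-sampling case.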

---
## Summary

In this tutorial, you learned:

- **Greedy decoding** always picks the most likely node: fast and deterministic
- **Sampling decoding** introduces randomness, which provides exploration during training
- **Multi-sample** (best-of-N) can improve quality at roughly N times the computation
- **Temperature** controls the sharpness of the sampling distribution
- **Evaluation metrics** include reward, log-likelihood, nodes visited per tour, and per-instance statistics
- **Tour visualization** helps understand model behavior and solution quality

### Next Steps

Continue to **[Tutorial 6: Multi-Day Simulation Testing](06_simulation_testing.ipynb)** to learn how to test routing policies (both neural and classical) in realistic multi-day waste collection scenarios.