# Tutorial 7: Extending the Codebase

**WSmart+ Route Tutorial Series**

One of the core design goals of WSmart+ Route is extensibility. This tutorial demonstrates how to add new components without modifying the core library code. You'll learn how to:

1. **Create a custom classical policy** (e.g., a geometric heuristic)
2. **Implement a custom neural module** (e.g., a specialized encoder)
3. **Integrate custom components** into the training pipeline

**Previous**: [06_simulation_testing.ipynb](06_simulation_testing.ipynb)

In [None]:
import os
import sys
import warnings

warnings.filterwarnings("ignore")

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import math
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---
## 1. Adding a Custom Classical Policy

Policies in WSmart+ Route follow the **Adapter Pattern**. To add a new policy, you must:

1. Inherit from the `IPolicy` interface (in `logic.src.policies.adapters`).
2. Implement the `execute(self, **kwargs)` method.
3. Register the class using the `@PolicyRegistry.register` decorator.
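
The registration step can be pictured with a minimal, self-contained sketch. `ToyRegistry` and `ToyPolicy` are hypothetical stand-ins written for this illustration, not the library's classes; the real `PolicyRegistry` may store or validate entries differently.

```python
from typing import Any, Callable, Dict, List, Tuple, Type


class ToyRegistry:
    """Hypothetical name -> class registry (illustration only, not library code)."""

    _policies: Dict[str, Type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type], Type]:
        # Returns a decorator that records the class under `name`
        # and hands the class back unchanged.
        def decorator(policy_cls: Type) -> Type:
            cls._policies[name] = policy_cls
            return policy_cls

        return decorator

    @classmethod
    def get(cls, name: str) -> Type:
        return cls._policies[name]


@ToyRegistry.register("noop")
class ToyPolicy:
    def execute(self, **kwargs: Any) -> Tuple[List[int], float, Any]:
        # Visits nothing: empty tour, zero cost, no extra output.
        return [], 0.0, None


# Look the class up by name, instantiate it, and call it with keyword arguments.
policy = ToyRegistry.get("noop")()
result = policy.execute(bins=None)
```

Because the decorator returns the class unchanged, a registered policy can still be imported and instantiated directly, which is convenient in tests.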

Let's implement a **Spiral Policy** properly wrapped as an `IPolicy`.

In [None]:
from typing import Any, List, Tuple
from logic.src.policies.adapters import IPolicy, PolicyRegistry

@PolicyRegistry.register("spiral")
class SpiralPolicy(IPolicy):
    """
    Collects bins at or above the fill threshold, visiting them in order of
    polar angle around the depot (an angular sweep).
    """
    
    def execute(self, **kwargs: Any) -> Tuple[List[int], float, Any]:
        """
        Execute the spiral policy.
        
        Expected kwargs:
            bins: Bins object with .fill_levels
            distance_matrix: Distance matrix (depot at index 0)
            coords: Coordinates array (depot at index 0)
            config: Configuration dictionary (optional)
        """
        # 1. Extract data from kwargs
        bins = kwargs.get("bins")
        coords = kwargs.get("coords")
        dist_matrix = kwargs.get("distance_matrix")
        
        # Parse config for threshold (default 50%)
        config = kwargs.get("config", {})
        spiral_cfg = config.get("spiral", {})
        threshold = spiral_cfg.get("threshold", 50.0)
        
        # 2. Policy Logic
        fill_levels = bins.fill_levels
        
        # Identify candidates
        candidates = np.where(fill_levels >= threshold)[0]
        if len(candidates) == 0:
            return [], 0.0, None
            
        # Adjust indices for coordinates (depot is 0, bin i is i+1)
        # Note: bins 0..N-1 correspond to coords 1..N
        candidate_coords = coords[candidates + 1]
        depot_coord = coords[0]
        
        # Calculate angles relative to depot
        diffs = candidate_coords - depot_coord
        angles = np.arctan2(diffs[:, 1], diffs[:, 0])
        
        # Sort by angle
        sort_indices = np.argsort(angles)
        tour = candidates[sort_indices].tolist()  # plain Python ints, matching List[int]
        
        # 3. Calculate cost (helper function or manual)
        # Simple manual calculation for tutorial:
        def calc_tour_cost(tour, dist_matrix):
            if not tour: return 0.0
            route = [0] + [t + 1 for t in tour] + [0]
            cost = 0.0
            for i in range(len(route) - 1):
                cost += dist_matrix[route[i], route[i+1]]
            return cost
            
        cost = calc_tour_cost(tour, dist_matrix)
        
        return tour, cost, None

print("SpiralPolicy registered successfully!")
print(f"Current Registry: {PolicyRegistry.list_policies()}")

### Testing the Custom Policy

Now we can instantiate and run it, simulating how the `Simulator` would call it.
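
Before running on random data, a hand-checkable case helps confirm the ordering and the index offset. The sketch below repeats the angle sort from `SpiralPolicy.execute` on four made-up bins placed on the axes around a depot at the origin. Because `np.arctan2` returns angles in (-π, π], the sweep starts just past the negative x-axis and proceeds counter-clockwise.

```python
import numpy as np

# Made-up layout: depot at the origin (row 0), then bins 0..3 on the axes.
coords = np.array([
    [0.0, 0.0],    # depot
    [1.0, 0.0],    # bin 0, angle 0
    [0.0, 1.0],    # bin 1, angle pi/2
    [-1.0, 0.0],   # bin 2, angle pi
    [0.0, -1.0],   # bin 3, angle -pi/2
])
candidates = np.array([0, 1, 2, 3])  # pretend every bin is above threshold

# Same steps as in the policy: shift by +1 to index coords, then sort by angle.
diffs = coords[candidates + 1] - coords[0]
angles = np.arctan2(diffs[:, 1], diffs[:, 0])
tour = candidates[np.argsort(angles)].tolist()

# Closed route in coordinate indices: depot -> bins -> depot.
route = [0] + [t + 1 for t in tour] + [0]
dist_matrix = np.linalg.norm(coords[:, None] - coords[None, :], axis=-1)
cost = float(sum(dist_matrix[a, b] for a, b in zip(route[:-1], route[1:])))

print(tour)            # [3, 0, 1, 2]
print(round(cost, 4))  # 6.2426
```

The cost is two unit legs to and from the depot plus three diagonal legs of length √2 between neighbouring bins.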

In [None]:
# Mock objects to simulate the environment state
class MockBins:
    def __init__(self, n=20):
        self.n = n
        self.fill_levels = np.random.uniform(0, 100, n)

# Prepare simulation data
n_bins = 20
mock_bins = MockBins(n_bins)
coords = np.random.rand(n_bins + 1, 2)
dist_matrix = np.sqrt(((coords[:, None] - coords[None, :]) ** 2).sum(axis=-1))

# Instantiate policy via Registry (or directly)
policy_cls = PolicyRegistry.get("spiral")
policy_instance = policy_cls()

# Execute
tour, cost, _ = policy_instance.execute(
    bins=mock_bins,
    coords=coords,
    distance_matrix=dist_matrix,
    config={"spiral": {"threshold": 30.0}}
)

print(f"Fill levels >= 30%: {(mock_bins.fill_levels >= 30).sum()}")
print(f"Selected by Spiral Policy: {len(tour)}")
print(f"Tour order: {tour}")
print(f"Tour cost: {cost:.2f}")

In [None]:
# Visualizing our custom policy
def plot_tour(coords, tour, title="Policy Tour"):
    fig, ax = plt.subplots(figsize=(6, 6))
    
    # Plot all nodes
    ax.scatter(coords[1:, 0], coords[1:, 1], c='lightgray', s=30, label='Bin')
    ax.scatter(coords[0, 0], coords[0, 1], c='red', s=100, marker='*', label='Depot')
    
    if len(tour) > 0:
        # Construct path: Depot -> Tour -> Depot
        path_indices = [0] + [t + 1 for t in tour] + [0]
        path_coords = coords[path_indices]
        
        ax.plot(path_coords[:, 0], path_coords[:, 1], c='steelblue', linewidth=1.5, alpha=0.8)
        ax.scatter(coords[[t+1 for t in tour], 0], coords[[t+1 for t in tour], 1], 
                   c='steelblue', s=50, zorder=3, label='Collected')
        
        # Annotate order
        for i, idx in enumerate(path_indices[1:-1]):
            ax.text(coords[idx, 0], coords[idx, 1], str(i+1), fontsize=8, color='white', ha='center', va='center')

    ax.legend()
    ax.set_title(title)
    plt.tight_layout()
    plt.show()

plot_tour(coords, tour, title="Custom Spiral Policy")

---
## 2. Adding a Custom Neural Encoder

You can customize neural architectures by subclassing `nn.Module` and swapping components in existing models.

Let's create a **Residual MLP Encoder**. Unlike the standard Graph Attention Encoder (GAT), it embeds each node independently with a deep residual MLP. This is faster than attention, but no information is exchanged between nodes, so the graph topology is ignored.
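
The independence property can be checked directly: when every layer acts only on the last (feature) dimension, perturbing one node's input leaves every other node's embedding unchanged. The following is a minimal sketch using a hypothetical `TinyNodeMLP`, which is neither the encoder defined in this tutorial nor a library class.

```python
import torch
import torch.nn as nn


class TinyNodeMLP(nn.Module):
    """Hypothetical per-node encoder: linear embed followed by one residual block."""

    def __init__(self, input_dim: int = 2, embed_dim: int = 8):
        super().__init__()
        self.embed = nn.Linear(input_dim, embed_dim)
        self.block = nn.Sequential(
            nn.Linear(embed_dim, embed_dim), nn.ReLU(), nn.Linear(embed_dim, embed_dim)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.embed(x)          # applied to each node separately
        return h + self.block(h)   # skip connection, still per node


torch.manual_seed(0)
encoder = TinyNodeMLP()
x = torch.rand(1, 5, 2)            # (batch, nodes, features)

x_perturbed = x.clone()
x_perturbed[0, 0] += 1.0           # move node 0 only

with torch.no_grad():
    h = encoder(x)
    h_perturbed = encoder(x_perturbed)

# Compare embeddings of the untouched nodes (1..4) and of the moved node (0).
others_unchanged = torch.allclose(h[0, 1:], h_perturbed[0, 1:])
node0_changed = not torch.allclose(h[0, 0], h_perturbed[0, 0])
print(others_unchanged, node0_changed)
```

An attention-based encoder would not pass the first check, since each node's embedding there depends on all other nodes.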

In [None]:
from logic.src.models.policies import AttentionModelPolicy

class ResidualBlock(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.linear1 = nn.Linear(dim, dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(dim, dim)
    
    def forward(self, x):
        residual = x
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x + residual  # Skip connection

class ResidualMLPEncoder(nn.Module):
    def __init__(self, input_dim=2, embed_dim=128, num_layers=3):
        super().__init__()
        self.init_embed = nn.Linear(input_dim, embed_dim)
        self.layers = nn.ModuleList([ResidualBlock(embed_dim) for _ in range(num_layers)])
    
    def forward(self, x, mask=None):
        """
        Args:
            x: Input features (batch_size, num_nodes, input_dim)
            mask: Unused here, kept for API compatibility
        Returns:
            (embeddings, mean_embedding) tuple
        """
        h = self.init_embed(x)
        for layer in self.layers:
            h = layer(h)
        
        # Return (node_embeddings, graph_embedding)
        # Graph embedding is just mean over nodes
        graph_emb = h.mean(dim=1)
        return h, graph_emb

# Instantiate our custom encoder
custom_encoder = ResidualMLPEncoder(input_dim=2, embed_dim=64, num_layers=4)
print("Custom Encoder Created:")
print(custom_encoder)

### Integrating the Custom Encoder

We can create a standard `AttentionModelPolicy` and simply replace its encoder.
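
Assigning a new `nn.Module` to an attribute re-registers it as a submodule, so `parameters()` and `state_dict()` reflect the replacement. One practical consequence is that an optimizer built before the swap still holds the old encoder's parameters, so it is safest to swap first and create the optimizer afterwards. A minimal sketch with a hypothetical `TinyPolicy` (a stand-in, not the library's `AttentionModelPolicy`):

```python
import torch
import torch.nn as nn


class TinyPolicy(nn.Module):
    """Hypothetical two-part model standing in for an encoder/decoder policy."""

    def __init__(self, embed_dim: int = 4):
        super().__init__()
        self.encoder = nn.Linear(2, embed_dim)
        self.decoder = nn.Linear(embed_dim, 1)


def tracked_ids(optimizer: torch.optim.Optimizer) -> set:
    # Identity of every tensor the optimizer will update.
    return {id(p) for group in optimizer.param_groups for p in group["params"]}


policy = TinyPolicy()
old_encoder = policy.encoder
stale_optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)

# Swap in a deeper encoder with the same input and output widths.
policy.encoder = nn.Sequential(nn.Linear(2, 4), nn.ReLU(), nn.Linear(4, 4))

new_ids = {id(p) for p in policy.encoder.parameters()}
old_ids = {id(p) for p in old_encoder.parameters()}

# The optimizer created before the swap tracks the old encoder, not the new one.
stale_tracks_old = old_ids <= tracked_ids(stale_optimizer)
stale_tracks_new = bool(new_ids & tracked_ids(stale_optimizer))

# Rebuilding the optimizer after the swap picks up the new parameters.
fresh_optimizer = torch.optim.Adam(policy.parameters(), lr=1e-3)
fresh_tracks_new = new_ids <= tracked_ids(fresh_optimizer)
```

The same reasoning applies to checkpoints: a `state_dict` saved before the swap will not load into the modified model without key mismatches.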

In [None]:
# 1. Create standard policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,  # Must match our custom encoder
    n_encode_layers=2, 
    n_decode_layers=2,
    n_heads=4
)

# 2. Swap the encoder
print(f"Original Encoder: {type(policy.encoder).__name__}")
policy.encoder = custom_encoder
print(f"New Encoder:      {type(policy.encoder).__name__}")

# 3. Verify it runs
from logic.src.envs import get_env
env = get_env("vrpp", num_loc=20)
td = env.reset(env.generator(batch_size=2))

with torch.no_grad():
    # Standard forward pass triggers our new encoder
    out = policy(td, env, strategy="greedy")
    print(f"Forward pass successful! Reward: {out['reward'].mean():.4f}")

---
## 3. Creating a Subclass Model

For cleaner integration, you can subclass `AttentionModelPolicy` to permanently use your architecture.

In [None]:
class CustomMLPPolicy(AttentionModelPolicy):
    def __init__(self, mlp_layers=3, **kwargs):
        # Initialize parent standard components
        super().__init__(**kwargs)
        
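        # Overwrite encoder with our custom one.
        # Note: this reads the input width from the parent's encoder, so it
        # assumes that encoder exposes an `init_embed` linear layer.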
        self.encoder = ResidualMLPEncoder(
            input_dim=self.encoder.init_embed.in_features,
            embed_dim=kwargs.get('embed_dim', 128),
            num_layers=mlp_layers
        )

custom_policy = CustomMLPPolicy(
    env_name="vrpp", embed_dim=64, mlp_layers=5
)

print(f"Custom Policy: {custom_policy}")

---
## 4. Customizing the Loss Function (RL)

WSmart+ Route uses PyTorch Lightning `LightningModule`s for training logic. You can subclass `RL4COLitModule` or `REINFORCE` to modify the loss calculation.

Let's implement **Entropy-Regularized REINFORCE**. Standard REINFORCE maximizes expected reward only; adding an entropy bonus to the objective encourages exploration by penalizing policies that become near-deterministic too early.
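
Written as a loss to minimize, with reward $R$, baseline $b$, tour log-likelihood $\log p_\theta$, policy entropy $H$ and coefficient $\beta$ (called `entropy_coeff` in the code), the objective is

$$\mathcal{L}(\theta) = -\,\mathbb{E}\big[(R - b)\,\log p_\theta\big] \;-\; \beta\,\mathbb{E}\big[H(p_\theta)\big],$$

with the baseline's own loss term left out for brevity. The toy sketch below evaluates this expression on made-up numbers; the mean baseline and the single-step entropy are simplifying assumptions for illustration, not how the library computes them.

```python
import math

import torch
from torch.distributions import Categorical

# Toy batch of 2 sampled tours (all values are made up for illustration).
reward = torch.tensor([-1.0, -3.0])          # higher is better
baseline = reward.mean()                     # simple mean baseline (assumption)
log_likelihood = torch.tensor([-0.5, -1.5])  # summed log-probs of each tour

# Action distribution at one decoding step per sample: uniform vs. peaked.
logits = torch.tensor([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0]])
entropy = Categorical(logits=logits).entropy()   # shape (2,)


def entropy_regularized_loss(beta: float) -> torch.Tensor:
    advantage = reward - baseline                      # [1.0, -1.0]
    reinforce_loss = -(advantage * log_likelihood).mean()
    return reinforce_loss - beta * entropy.mean()      # entropy bonus lowers the loss


loss_plain = entropy_regularized_loss(0.0)
loss_reg = entropy_regularized_loss(0.1)
print(f"uniform-step entropy = {entropy[0].item():.4f} (ln 3 = {math.log(3):.4f})")
print(f"loss without bonus = {loss_plain.item():.4f}, with bonus = {loss_reg.item():.4f}")
```

With `beta = 0` the expression reduces to the plain REINFORCE loss; any positive `beta` lowers the loss for higher-entropy action distributions.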

In [None]:
from logic.src.pipeline.rl import REINFORCE

class EntropyREINFORCE(REINFORCE):
    def __init__(self, entropy_coeff=0.01, **kwargs):
        super().__init__(**kwargs)
        self.entropy_coeff = entropy_coeff
    
    def training_step(self, batch, batch_idx):
        # 1. Run policy
        out = self.policy(batch, self.env, strategy="sampling", return_actions=True)
        
        # 2. Compute advantage (Reward - Baseline)
        # (Using internal baseline logic from parent)
        bl_val, bl_loss = self.baseline.eval(batch, out['reward'])
        advantage = out['reward'] - bl_val
        
        # 3. Policy Loss: -(Advantage * log_prob)
        log_likelihood = out['log_likelihood']
        reinforce_loss = -(advantage * log_likelihood).mean()
        
        # 4. Entropy bonus: subtracting beta * entropy from the loss rewards
        # higher-entropy (more exploratory) policies, since the loss is minimized.
        # If the policy output has no 'entropy' entry, the zero fallback
        # silently disables this term.
        entropy = out.get('entropy', torch.zeros(1, device=self.device))
        entropy_loss = -self.entropy_coeff * entropy.mean()
        
        # 5. Total loss
        loss = reinforce_loss + entropy_loss + bl_loss
        
        # Logging
        self.log("train/loss", loss)
        self.log("train/entropy", entropy.mean())
        self.log("train/reward", out['reward'].mean())
        
        return loss

print("Custom Entropy-Regularized RL Module Defined.")

In [None]:
# Brief training loop verification
from logic.src.pipeline.rl.common.trainer import WSTrainer

# Create module
model_ent = EntropyREINFORCE(
    entropy_coeff=0.1,  # Deliberately large entropy coefficient, for demonstration
    env=env,
    policy=custom_policy,
    baseline="rollout",
    train_data_size=128,
    val_data_size=32,
    batch_size=32
)

# Train for 1 epoch
trainer = WSTrainer(
    max_epochs=1, 
    accelerator="cpu", 
    devices=1, 
    logger=False, 
    enable_progress_bar=False
)

print("Running one training epoch with custom loss...")
trainer.fit(model_ent)
print("Training epoch completed!")

---
## Summary

In this tutorial, we demonstrated how to extend WSmart+ Route:

1.  **Custom Policies**: Implemented a `SpiralPolicy` using the `IPolicy` interface and `PolicyRegistry`.
2.  **Custom Neural Modules**: Created a `ResidualMLPEncoder` and swapped it into a standard `AttentionModelPolicy`.
3.  **Inheritance**: Subclassed `AttentionModelPolicy` as `CustomMLPPolicy` so the custom encoder is built in at construction time.
4.  **Custom Loss**: Subclassed `REINFORCE` to add an entropy regularization term in the training step.

These mechanisms allow you to test novel ideas without needing to fork the entire library.