# WSmart+ Route: Lightning-Based RL Training Tutorial

**Version:** 2.0  
**Last Updated:** January 2026

This tutorial covers the reinforcement learning (RL) training pipeline for combinatorial optimization problems, built on **PyTorch Lightning** and **Hydra**.

## Table of Contents

1. [Overview & Architecture](#1-overview--architecture)
2. [Hydra Configuration System](#2-hydra-configuration-system)
3. [PyTorch Lightning Modules](#3-pytorch-lightning-modules)
4. [Complete RL Algorithms Reference](#4-complete-rl-algorithms-reference)
5. [Baselines for Variance Reduction](#5-baselines-for-variance-reduction)
6. [Complete Neural Models Reference](#6-complete-neural-models-reference)
7. [Environments & Data Generation](#7-environments--data-generation)
8. [Meta-Learning](#8-meta-learning)
9. [Hyperparameter Optimization](#9-hyperparameter-optimization)
10. [Practical Examples](#10-practical-examples)
11. [Adding New Components Guide](#11-adding-new-components-guide)
12. [Hydra + Lightning + TorchRL Best Practices](#12-hydra--lightning--torchrl-best-practices)
13. [Troubleshooting & Common Patterns](#13-troubleshooting--common-patterns)

In [1]:
# Standard notebook setup
import warnings

warnings.filterwarnings('ignore')

from notebook_setup import setup_google_colab, setup_home_directory

NOTEBOOK_NAME = "lightning_rl_training_tutorial"
home_dir = setup_home_directory(NOTEBOOK_NAME)
IN_COLAB, gdrive, gfiles = setup_google_colab(NOTEBOOK_NAME)

# Core imports used throughout the notebook
from typing import Any, Dict, List, Optional

import pytorch_lightning as pl
import torch
import torch.nn as nn

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

Setup completed - added home_dir to system path: /home/pkhunter/Repositories/WSmart-Route
PyTorch version: 2.2.2+cu121
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3090 Ti


---

## 1. Overview & Architecture

The WSmart+ Route training pipeline is built on three foundational technologies:

### 1.1 Core Technologies

| Technology | Purpose | Key Benefits |
|------------|---------|-------------|
| **PyTorch Lightning** | Training orchestration | Automatic GPU management, logging, checkpointing |
| **Hydra** | Configuration management | Hierarchical configs, CLI overrides, experiment tracking |
| **TensorDict** | State management | Efficient batched operations, device-agnostic tensors |

### 1.2 Pipeline Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                      train_lightning.py                         │
│                    (Hydra CLI Entry Point)                      │
└─────────────────────────────────────────────────────────────────┘
                              │
              ┌───────────────┼───────────────┐
              │               │               │
              ▼               ▼               ▼
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│   Config        │ │   Environment   │ │   RL Module     │
│   (Hydra)       │ │   (RL4COEnv)    │ │   (Lightning)   │
└─────────────────┘ └─────────────────┘ └─────────────────┘
                              │
                              ▼
              ┌───────────────────────────────┐
              │        WSTrainer              │
              │   (PyTorch Lightning)         │
              └───────────────────────────────┘
                              │
                              ▼
              ┌───────────────────────────────┐
              │   Training Loop               │
              │   - Data Generation           │
              │   - Forward Pass              │
              │   - Loss Computation          │
              │   - Baseline Updates          │
              └───────────────────────────────┘
```

### 1.3 Complete Directory Structure

```
logic/src/
├── pipeline/rl/              # RL TRAINING PIPELINE
│   ├── core/                 # Core RL algorithms
│   │   ├── base.py          # RL4COLitModule (base Lightning module)
│   │   ├── baselines.py     # All baseline implementations
│   │   ├── reinforce.py     # REINFORCE algorithm
│   │   ├── ppo.py           # Proximal Policy Optimization
│   │   ├── sapo.py          # Self-Adaptive Policy Optimization
│   │   ├── gspo.py          # Gradient-Scaled Proxy Optimization
│   │   ├── gdpo.py          # Gradient-Divergence Policy Optimization
│   │   ├── dr_grpo.py       # Divergence-Regularized GRPO
│   │   ├── pomo.py          # Policy Optimization with Multiple Optima
│   │   ├── symnco.py        # Symmetry-aware NCO
│   │   ├── imitation.py     # Imitation Learning
│   │   └── adaptive_imitation.py  # IL-to-RL transition
│   ├── meta/                 # Meta-learning strategies
│   │   ├── module.py        # MetaRLModule wrapper
│   │   ├── hrl.py           # Hierarchical RL
│   │   ├── weight_optimizer.py
│   │   ├── contextual_bandits.py
│   │   ├── td_learning.py
│   │   └── multi_objective.py
│   ├── hpo/                  # Hyperparameter optimization
│   │   ├── optuna_hpo.py    # Optuna-based HPO
│   │   └── dehb.py          # DEHB multi-fidelity
│   └── features/             # Training utilities
│       ├── epoch.py         # Epoch preparation
│       ├── post_processing.py
│       └── time_training.py
│
├── models/                   # NEURAL NETWORK MODELS
│   ├── attention_model.py   # Attention Model (AM)
│   ├── deep_decoder_am.py   # Deep Decoder AM (DDAM)
│   ├── temporal_am.py       # Temporal AM (TAM)
│   ├── gat_lstm_manager.py  # HRL Manager network
│   ├── pointer_network.py   # Pointer Network
│   ├── critic_network.py    # Value network for PPO/baselines
│   ├── hypernet.py          # Hypernetwork for meta-learning
│   ├── meta_rnn.py          # Meta-learning RNN
│   ├── moe_model.py         # Mixture of Experts
│   ├── model_factory.py     # Factory for model instantiation
│   ├── modules/             # Atomic building blocks
│   │   ├── multi_head_attention.py
│   │   ├── graph_convolution.py
│   │   ├── feed_forward.py
│   │   ├── normalization.py
│   │   ├── activation_function.py
│   │   ├── skip_connection.py
│   │   └── ...
│   └── subnets/             # Encoders & Decoders
│       ├── gat_encoder.py   # Graph Attention Encoder
│       ├── gac_encoder.py   # Graph Attention Conv Encoder
│       ├── tgc_encoder.py   # Transformer Graph Conv Encoder
│       ├── gcn_encoder.py   # GCN Encoder
│       ├── attention_decoder.py
│       └── ...
│
├── envs/                     # ENVIRONMENTS
│   ├── base.py              # RL4COEnvBase
│   ├── vrpp.py              # VRPP, CVRPP environments
│   ├── wcvrp.py             # WCVRP, CWCVRP, SDWCVRP
│   ├── swcvrp.py            # SCWCVRP
│   └── generators.py        # Data generators
│
└── configs/                  # CONFIGURATION DATACLASSES
    ├── __init__.py          # Root Config
    ├── env.py               # EnvConfig
    ├── model.py             # ModelConfig
    ├── train.py             # TrainConfig
    ├── optim.py             # OptimConfig
    ├── rl.py                # RLConfig
    ├── meta_rl.py           # MetaRLConfig
    └── hpo.py               # HPOConfig
```

---

## 2. Hydra Configuration System

The training pipeline uses **Hydra** for configuration management, enabling:
- Hierarchical configuration with dataclasses
- Command-line overrides
- Experiment tracking and reproducibility

### 2.1 Configuration Dataclasses

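The configuration dataclasses live in `logic/src/configs/` (one file per section, with the root `Config` in `__init__.py`) and are all importable from `logic.src.configs`: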

In [3]:
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

# Examine the configuration structure
from logic.src.configs import (
    Config,
    EnvConfig,
    HPOConfig,
    ModelConfig,
    OptimConfig,
    RLConfig,
    TrainConfig,
)

# Print available config sections
print("Configuration Sections:")
print("=" * 50)
for config_cls in [EnvConfig, ModelConfig, TrainConfig, OptimConfig, RLConfig, HPOConfig]:
    print(f"\n{config_cls.__name__}:")
    for field_name, field_type in config_cls.__annotations__.items():
        default = getattr(config_cls, field_name, "<no default>")
        if callable(default) and hasattr(default, "__name__"):
            default = f"<factory: {default.__name__}>"
        print(f"  {field_name}: {field_type} = {default}")

Configuration Sections:

EnvConfig:
  name: <class 'str'> = vrpp
  num_loc: <class 'int'> = 50
  min_loc: <class 'float'> = 0.0
  max_loc: <class 'float'> = 1.0
  capacity: typing.Optional[float] = None
  overflow_penalty: <class 'float'> = 1.0
  collection_reward: <class 'float'> = 1.0
  cost_weight: <class 'float'> = 1.0
  prize_weight: <class 'float'> = 1.0
  area: <class 'str'> = riomaior
  waste_type: <class 'str'> = plastic
  focus_graph: typing.Optional[str] = None
  focus_size: <class 'int'> = 0
  eval_focus_size: <class 'int'> = 0
  distance_method: <class 'str'> = ogd
  dm_filepath: typing.Optional[str] = None
  waste_filepath: typing.Optional[str] = None
  vertex_method: <class 'str'> = mmn
  edge_threshold: <class 'float'> = 0.0
  edge_method: typing.Optional[str] = None
  data_distribution: typing.Optional[str] = None
  min_fill: <class 'float'> = 0.0
  max_fill: <class 'float'> = 1.0
  fill_distribution: <class 'str'> = uniform

ModelConfig:
  name: <class 'str'> = am
  e

### 2.2 Configuration Sections Explained

#### EnvConfig - Environment Settings
```python
@dataclass
class EnvConfig:
    name: str = "vrpp"           # Problem type: vrpp, wcvrp, cwcvrp, sdwcvrp
    num_loc: int = 50             # Number of locations (nodes)
    min_loc: float = 0.0          # Min coordinate value
    max_loc: float = 1.0          # Max coordinate value
    capacity: Optional[float] = None  # Vehicle capacity
    overflow_penalty: float = 1.0     # Penalty for bin overflow
    collection_reward: float = 1.0    # Reward for waste collection
    cost_weight: float = 1.0          # Weight for travel cost
    prize_weight: float = 1.0         # Weight for prizes (VRPP)
```

#### ModelConfig - Neural Network Architecture
```python
@dataclass
class ModelConfig:
    name: str = "am"              # Model: am, deep_decoder, temporal, pointer, symnco
    embed_dim: int = 128          # Embedding dimension
    hidden_dim: int = 512         # Hidden layer dimension
    num_encoder_layers: int = 3   # Number of encoder layers
    num_decoder_layers: int = 3   # Number of decoder layers (deep_decoder)
    n_heads: int = 8              # Attention heads
    encoder_type: str = "gat"     # Encoder type: gat, gcn, mlp
    normalization: str = "instance"  # Normalization: instance, batch, layer
    activation: str = "gelu"         # Activation function
    dropout: float = 0.1             # Dropout rate
```

#### RLConfig - RL Algorithm Settings
```python
@dataclass
class RLConfig:
    algorithm: str = "reinforce"  # reinforce, ppo, sapo, gspo, gdpo, dr_grpo, pomo, symnco, imitation, adaptive_imitation, hrl
    baseline: str = "rollout"     # none, exponential, rollout, critic, pomo
    entropy_weight: float = 0.0   # Entropy regularization
    max_grad_norm: float = 1.0    # Gradient clipping
    
    # PPO-specific
    ppo_epochs: int = 10          # Inner PPO epochs
    eps_clip: float = 0.2         # Clipping epsilon
    value_loss_weight: float = 0.5
    
    # Meta-learning
    use_meta: bool = False        # Enable meta-learning wrapper
    meta_strategy: str = "rnn"    # rnn, bandit, morl, tdl, hypernet
```

### 2.3 Command-Line Interface Usage

The main entry point is `train_lightning.py`, accessed via:

```bash
python main.py train_lightning [CONFIG_OVERRIDES]
```

#### Basic Training Examples

```bash
# Train Attention Model on VRPP with 50 nodes
python main.py train_lightning model=am env.name=vrpp env.num_loc=50

# Train with PPO algorithm
python main.py train_lightning rl.algorithm=ppo rl.ppo_epochs=10

# Use different baseline
python main.py train_lightning rl.baseline=exponential

# Custom training parameters
python main.py train_lightning \
    train.n_epochs=100 \
    train.batch_size=256 \
    optim.lr=1e-4
```
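
Each `key=value` token in the commands above is a dotted path into the nested configuration. The sketch below mimics that mapping with plain dataclasses. `apply_overrides` and the trimmed-down `EnvConfig`, `OptimConfig`, and `Config` classes are hypothetical stand-ins written for this illustration. They are not Hydra's implementation, which additionally handles config groups, interpolation, and type validation.

```python
import ast
from dataclasses import dataclass, field


@dataclass
class EnvConfig:  # trimmed-down stand-in for the real EnvConfig
    name: str = "vrpp"
    num_loc: int = 50


@dataclass
class OptimConfig:  # trimmed-down stand-in for the real OptimConfig
    lr: float = 1e-4


@dataclass
class Config:
    env: EnvConfig = field(default_factory=EnvConfig)
    optim: OptimConfig = field(default_factory=OptimConfig)


def apply_overrides(cfg, overrides):
    """Apply 'a.b=value' strings to a nested dataclass instance in place."""
    for item in overrides:
        path, raw = item.split("=", 1)
        *parents, leaf = path.split(".")
        node = cfg
        for name in parents:
            node = getattr(node, name)
        if not hasattr(node, leaf):
            raise AttributeError(f"unknown config key: {path}")
        try:
            value = ast.literal_eval(raw)  # numbers and lists
        except (ValueError, SyntaxError):
            value = raw  # plain strings such as "wcvrp"
        setattr(node, leaf, value)
    return cfg


cfg = apply_overrides(Config(), ["env.name=wcvrp", "env.num_loc=100", "optim.lr=3e-4"])
print(cfg)
```

Unknown keys raise an error instead of being silently added, which is the behaviour one usually wants from a typed configuration.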

#### Advanced Configurations

```bash
# POMO with data augmentation
python main.py train_lightning \
    rl.algorithm=pomo \
    rl.num_augment=8 \
    rl.augment_fn=dihedral8

# Hierarchical RL
python main.py train_lightning \
    rl.algorithm=hrl \
    rl.meta_hidden_dim=128

# With Meta-Learning wrapper
python main.py train_lightning \
    rl.use_meta=true \
    rl.meta_strategy=rnn \
    rl.meta_lr=1e-3

# Hyperparameter optimization
python main.py train_lightning \
    hpo.n_trials=50 \
    hpo.method=tpe
```

In [4]:
# Programmatically create a configuration
from logic.src.configs import Config, EnvConfig, ModelConfig, OptimConfig, RLConfig, TrainConfig

# Create custom configuration
cfg = Config(
    env=EnvConfig(
        name="vrpp",
        num_loc=50,
        capacity=100.0,
    ),
    model=ModelConfig(
        name="am",
        embed_dim=128,
        num_encoder_layers=3,
        n_heads=8,
    ),
    train=TrainConfig(
        n_epochs=10,
        batch_size=256,
        train_data_size=10000,
    ),
    optim=OptimConfig(
        optimizer="adam",
        lr=1e-4,
    ),
    rl=RLConfig(
        algorithm="reinforce",
        baseline="rollout",
    ),
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print("Configuration created:")
print(f"  Environment: {cfg.env.name} with {cfg.env.num_loc} locations")
print(f"  Model: {cfg.model.name} with {cfg.model.embed_dim}d embeddings")
print(f"  Algorithm: {cfg.rl.algorithm} with {cfg.rl.baseline} baseline")
print(f"  Device: {cfg.device}")

Configuration created:
  Environment: vrpp with 50 locations
  Model: am with 128d embeddings
  Algorithm: reinforce with rollout baseline
  Device: cuda


---

## 3. PyTorch Lightning Modules

All RL algorithms inherit from `RL4COLitModule`, which provides:
- Training/validation/test loops
- Automatic optimizer configuration
- Data loading with generators
- Baseline integration
- Metric logging

### 3.1 Base Module Architecture

In [4]:
import inspect

from logic.src.pipeline.rl.core.base import RL4COLitModule

# Show the base class signature
print("RL4COLitModule.__init__ signature:")
print(inspect.signature(RL4COLitModule.__init__))

print("\nKey Methods:")
for name, method in inspect.getmembers(RL4COLitModule, predicate=inspect.isfunction):
    if not name.startswith("_") or name in ["__init__"]:
        doc = method.__doc__
        first_line = doc.split("\n")[0].strip() if doc else "No docstring"
        print(f"  {name}: {first_line}")

RL4COLitModule.__init__ signature:
(self, env: 'RL4COEnvBase', policy: 'ConstructivePolicy', baseline: 'Optional[str]' = 'rollout', optimizer: 'str' = 'adam', optimizer_kwargs: 'Optional[dict]' = None, lr_scheduler: 'Optional[str]' = None, lr_scheduler_kwargs: 'Optional[dict]' = None, train_data_size: 'int' = 100000, val_data_size: 'int' = 10000, val_dataset_path: 'Optional[str]' = None, batch_size: 'int' = 256, num_workers: 'int' = 4, persistent_workers: 'bool' = True, pin_memory: 'bool' = False, **kwargs)

Key Methods:
  __init__: 
  add_module: Add a child module to the current module.
  all_gather: Gather tensors or collections of tensors from multiple processes.
  apply: Apply ``fn`` recursively to every submodule (as returned by ``.children()``) as well as self.
  backward: Called to perform backward on the loss returned in :meth:`training_step`. Override this hook with your own
  bfloat16: Casts all floating point parameters and buffers to ``bfloat16`` datatype.
  buffers: Retur

### 3.2 Training Lifecycle

```python
class RL4COLitModule(pl.LightningModule, ABC):
    """
    Training Lifecycle:
    
    1. setup(stage='fit'):
       - Create train_dataset (GeneratorDataset)
       - Create val_dataset
       
    2. for epoch in range(n_epochs):
       
       3. on_train_epoch_start():
          - Wrap dataset with baseline (if RolloutBaseline)
          
       4. for batch in train_dataloader():
          
          5. training_step(batch):
             - Unwrap batch (baseline values)
             - shared_step():
               a. env.reset(batch)
               b. policy(td, env, strategy="sampling")
               c. calculate_loss()  # Algorithm-specific
             - return loss
             
          6. on_before_optimizer_step():
             - Gradient clipping
             
       7. validation_step(batch):
          - shared_step() with strategy="greedy"
          
       8. on_train_epoch_end():
          - baseline.epoch_callback()
          - Optionally regenerate dataset
    """
```

### 3.3 Key Methods Explained

#### `shared_step()` - Common Training Logic
```python
def shared_step(self, batch, batch_idx, phase):
    # 1. Unwrap baseline values if present
    batch, baseline_val = self.baseline.unwrap_batch(batch)
    
    # 2. Move to device
    batch = batch.to(self.device)
    
    # 3. Reset environment with batch
    td = self.env.reset(batch)
    
    # 4. Run policy
    out = self.policy(
        td, self.env,
        strategy="sampling" if phase == "train" else "greedy"
    )
    
    # 5. Compute loss (training only)
    if phase == "train":
        out["loss"] = self.calculate_loss(td, out, batch_idx)
    
    # 6. Log metrics
    self.log(f"{phase}/reward", out["reward"].mean())
    
    return out
```
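
The `strategy` argument decides how the next node is picked from the policy's scores: `"sampling"` draws from the softmax distribution (exploration during training), while `"greedy"` takes the most likely node (deterministic evaluation). The sketch below shows the difference for a single decoding step. `select_action` is a hypothetical helper for illustration only; the actual policy works on batched tensors and masks infeasible nodes.

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def select_action(logits, strategy, rng):
    """Return (action index, log-probability) under the given decoding strategy."""
    probs = softmax(logits)
    if strategy == "greedy":
        action = max(range(len(probs)), key=probs.__getitem__)
    elif strategy == "sampling":
        action = rng.choices(range(len(probs)), weights=probs, k=1)[0]
    else:
        raise ValueError(f"unknown strategy: {strategy}")
    return action, math.log(probs[action])


rng = random.Random(0)
logits = [2.0, 0.5, 1.0]  # scores for three candidate next nodes
greedy_action, greedy_logp = select_action(logits, "greedy", rng)
sampled = [select_action(logits, "sampling", rng)[0] for _ in range(1000)]
print(greedy_action, {a: sampled.count(a) for a in sorted(set(sampled))})
```

Greedy decoding always returns the same node for the same scores, which is why it is used for validation and for the rollout baseline.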

#### `calculate_loss()` - Algorithm-Specific
This is the **abstract method** that each RL algorithm must implement.

In [5]:
# Example: REINFORCE loss implementation
from logic.src.pipeline.rl.core.reinforce import REINFORCE

# Show the calculate_loss implementation
print("REINFORCE.calculate_loss source:")
print(inspect.getsource(REINFORCE.calculate_loss))

REINFORCE.calculate_loss source:
    def calculate_loss(
        self,
        td: TensorDict,
        out: dict,
        batch_idx: int,
        env: Optional["RL4COEnvBase"] = None,
    ) -> torch.Tensor:
        """
        Compute REINFORCE loss.

        Loss = -E[(R - b) * log π(a|s)]
        """
        reward = out["reward"]
        log_likelihood = out["log_likelihood"]

        # Get baseline
        if hasattr(self, "_current_baseline_val") and self._current_baseline_val is not None:
            baseline_val = self._current_baseline_val
        else:
            baseline_val = self.baseline.eval(td, reward, env=env)

        # Advantage
        advantage = reward - baseline_val

        # Normalize advantage (optional but helps stability)
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)

        # Policy gradient loss
        loss = -(advantage.detach() * log_likelihood).mean()

        # Entropy bonus (if applicable)
        if self.entropy_weig

---

## 4. Complete RL Algorithms Reference

The pipeline supports **11 RL algorithms**, each suited for different scenarios.

### 4.1 Algorithm Overview

| Algorithm | File | Loss Function | Best For |
|-----------|------|---------------|----------|
| **REINFORCE** | `reinforce.py` | $-\mathbb{E}[(R-b) \log \pi]$ | Simple baseline, debugging |
| **PPO** | `ppo.py` | Clipped surrogate + Value | Stable training |
| **SAPO** | `sapo.py` | Soft adaptive clipping | When PPO is too aggressive |
| **GSPO** | `gspo.py` | Gradient-scaled proxy | Sequence-level optimization |
| **GDPO** | `gdpo.py` | Group distributional | Multi-objective optimization |
| **DRGRPO** | `dr_grpo.py` | Divergence-regularized | Preventing policy collapse |
| **POMO** | `pomo.py` | Multi-start baseline | Symmetric problems (TSP, VRP) |
| **SymNCO** | `symnco.py` | Symmetry-aware loss | Exploiting problem symmetries |
| **ImitationLearning** | `imitation.py` | Cross-entropy with expert | Bootstrapping from classical solvers |
| **AdaptiveImitation** | `adaptive_imitation.py` | IL → RL transition | Curriculum learning |
| **HRLModule** | `hrl.py` | Hierarchical PPO | Temporal decisions (when + where) |

### 4.2 Algorithm Selection Guide

```
                        ┌─────────────────────────┐
                        │   Starting a project?   │
                        └───────────┬─────────────┘
                                    │
                    ┌───────────────┼───────────────┐
                    │               │               │
              Have expert      Need stable     Problem has
               solutions?       training?       symmetry?
                    │               │               │
                    ▼               ▼               ▼
            ┌───────────┐   ┌───────────┐   ┌───────────┐
            │ Imitation │   │    PPO    │   │   POMO    │
            │ Adaptive  │   │   SAPO    │   │  SymNCO   │
            └───────────┘   │   GSPO    │   └───────────┘
                            └───────────┘
                                    │
                            Simple case?
                                    │
                                    ▼
                            ┌───────────┐
                            │ REINFORCE │
                            └───────────┘
```

### 4.3 Detailed Algorithm Descriptions

#### REINFORCE (Vanilla Policy Gradient)
```python
# Loss: -E[(R - baseline) * log π(a|s)]
from logic.src.pipeline.rl import REINFORCE

module = REINFORCE(
    env=env, policy=policy,
    baseline="rollout",        # none, exponential, rollout, critic
    entropy_weight=0.01,       # Entropy regularization
    max_grad_norm=1.0,         # Gradient clipping
)
```
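
To make the loss concrete, here is a small worked example in plain Python that follows the same steps as `REINFORCE.calculate_loss` (advantage, normalization, policy-gradient term). The numbers are made up for illustration, and the entropy bonus is left out.

```python
import statistics

rewards = [10.0, 12.0, 9.0, 13.0]          # R: reward of each sampled tour
baseline = [11.0, 11.0, 10.0, 11.0]        # b: e.g. greedy rollout rewards
log_likelihood = [-5.0, -4.0, -6.0, -3.5]  # log π of each sampled tour

# Advantage: how much better each sample did than its baseline.
advantage = [r - b for r, b in zip(rewards, baseline)]

# Normalize with the sample standard deviation (n - 1), torch.Tensor.std's default.
mean = statistics.fmean(advantage)
std = statistics.stdev(advantage)
normalized = [(a - mean) / (std + 1e-8) for a in advantage]

# The advantage acts as a constant weight (detached); only log π carries gradient.
products = [a * lp for a, lp in zip(normalized, log_likelihood)]
loss = -statistics.fmean(products)
print([round(a, 3) for a in normalized], round(loss, 4))
```

Samples with a positive normalized advantage have their log-likelihood pushed up by gradient descent on this loss, and samples with a negative one are pushed down.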

#### PPO (Proximal Policy Optimization)
```python
# Loss: -min(r*A, clip(r, 1-ε, 1+ε)*A) + c*L_value - β*H(π)
from logic.src.pipeline.rl import PPO

module = PPO(
    env=env, policy=policy, critic=critic,
    ppo_epochs=10,             # Inner optimization epochs
    eps_clip=0.2,              # Clipping epsilon
    value_loss_weight=0.5,     # Critic loss weight
    mini_batch_size=0.25,      # Fraction of batch per mini-batch
)
```

#### SAPO (Self-Adaptive Policy Optimization)
```python
# Soft gating instead of hard clipping
from logic.src.pipeline.rl import SAPO

module = SAPO(
    env=env, policy=policy, critic=critic,
    sapo_tau_pos=0.1,          # Positive advantage temperature
    sapo_tau_neg=1.0,          # Negative advantage temperature
)
```

#### POMO (Policy Optimization with Multiple Optima)
```python
# Multi-start + data augmentation for symmetric problems
from logic.src.pipeline.rl import POMO

module = POMO(
    env=env, policy=policy,
    num_augment=8,             # Dihedral group D8
    augment_fn="dihedral8",    # Augmentation function
    num_starts=None,           # Defaults to num_loc
)
```

#### SymNCO (Symmetry-aware NCO)
```python
# Exploits invariances in CO problems
from logic.src.pipeline.rl import SymNCO

module = SymNCO(
    env=env, policy=policy,
    num_augment=8,
    symnco_alpha=0.2,          # Symmetry loss weight
    symnco_beta=1.0,           # Consistency loss weight
)
```

#### ImitationLearning (Learn from Expert)
```python
# Supervised learning from classical solvers
from logic.src.pipeline.rl import ImitationLearning

module = ImitationLearning(
    env=env, policy=policy,
    imitation_mode="hgs",      # Expert: hgs, alns, gurobi
    imitation_weight=1.0,      # IL loss weight
)
```

#### AdaptiveImitation (IL → RL Transition)
```python
# Curriculum: start with IL, transition to RL
from logic.src.pipeline.rl import AdaptiveImitation

module = AdaptiveImitation(
    env=env, policy=policy,
    imitation_weight=1.0,      # Initial IL weight
    imitation_decay=0.95,      # Decay per epoch
    imitation_threshold=0.05,  # Minimum IL weight
    reannealing_threshold=0.05,# Re-anneal if gap exceeds this
    reannealing_patience=5,    # Epochs before re-annealing
)
```
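
With the values above, it helps to know how long the imitation term stays active. The sketch below assumes the weight is multiplied by `imitation_decay` once per epoch and that imitation stops contributing once the weight falls below `imitation_threshold`. Both are assumptions about the schedule, `imitation_schedule` is a hypothetical helper, and the re-annealing logic is not modeled.

```python
def imitation_schedule(weight=1.0, decay=0.95, threshold=0.05, max_epochs=200):
    """Return the per-epoch IL weights until the weight falls below `threshold`."""
    weights = []
    for _ in range(max_epochs):
        if weight < threshold:
            break
        weights.append(weight)
        weight *= decay  # multiplicative decay applied once per epoch
    return weights


weights = imitation_schedule()
print(f"IL active for {len(weights)} epochs, last weight {weights[-1]:.4f}")
```

Under these assumptions a decay of 0.95 keeps imitation active for several dozen epochs, so a faster hand-over to RL needs a smaller decay factor or a higher threshold.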

#### HRLModule (Hierarchical RL)
```python
# Manager (when to collect) + Worker (route planning)
from logic.src.pipeline.rl import HRLModule

module = HRLModule(
    env=env,
    manager=manager_network,   # GATLSTManager
    worker=worker_policy,      # AttentionModelPolicy
    gamma=0.99,                # Discount factor
)
```

In [6]:
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import REINFORCE

# Create environment
env = get_env("vrpp", num_loc=20, device="cpu")

# Create policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create REINFORCE module
reinforce_module = REINFORCE(
    env=env,
    policy=policy,
    baseline="exponential",  # Using exponential baseline for demo
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
    entropy_weight=0.01,
    max_grad_norm=1.0,
)

print("REINFORCE module created:")
print(f"  Environment: {env.__class__.__name__}")
print(f"  Policy parameters: {sum(p.numel() for p in policy.parameters()):,}")
print(f"  Baseline: {reinforce_module.baseline_type}")

REINFORCE module created:
  Environment: VRPPEnv
  Policy parameters: 91,520
  Baseline: exponential


### 4.4 PPO (Proximal Policy Optimization)

PPO performs multiple optimization epochs per batch with a clipped surrogate objective:

$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}\left[\min\left(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right]$$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$ is the probability ratio between the current and the old policy, and $\hat{A}_t$ is the advantage estimate. The objective is maximized, so the training loss uses its negative.
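
A minimal plain-Python sketch of the policy term follows, so the effect of clipping can be checked by hand. `clipped_surrogate_loss` is a hypothetical helper that covers only the clipped surrogate (no value loss, no entropy term) and uses made-up numbers.

```python
import math


def clipped_surrogate_loss(new_logp, old_logp, advantages, eps_clip=0.2):
    """Negative clipped surrogate, averaged over samples (policy term only)."""
    terms = []
    for lp_new, lp_old, adv in zip(new_logp, old_logp, advantages):
        ratio = math.exp(lp_new - lp_old)  # r = π_θ / π_θ_old
        clipped = min(max(ratio, 1.0 - eps_clip), 1.0 + eps_clip)
        terms.append(min(ratio * adv, clipped * adv))
    return -sum(terms) / len(terms)


old_logp = [-1.0, -1.0, -1.0]
new_logp = [-1.0, -0.5, -1.5]   # ratios: 1.0, about 1.65, about 0.61
advantages = [1.0, 1.0, -1.0]
loss = clipped_surrogate_loss(new_logp, old_logp, advantages)
print(round(loss, 4))
```

In the second sample the ratio exceeds $1+\epsilon$ with a positive advantage, so the term is capped at $1.2 \cdot A$. In the third sample the ratio is below $1-\epsilon$ with a negative advantage, so the term is capped at $0.8 \cdot A$. In both cases the policy gains nothing by moving further from the old policy.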

In [None]:
# PPO example: this cell recreates the environment and policy so it runs on its own
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.models.policies.critic import create_critic_from_actor
from logic.src.pipeline.rl import PPO

# Create environment
env = get_env("vrpp", num_loc=20, device="cpu")

# Create policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create critic from actor architecture
critic = create_critic_from_actor(
    policy,
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_layers=2,
    n_heads=4,
)

# Create PPO module
ppo_module = PPO(
    env=env,
    policy=policy,
    critic=critic,
    ppo_epochs=10,  # Inner optimization epochs
    eps_clip=0.2,  # Clipping epsilon
    value_loss_weight=0.5,  # Critic loss weight
    mini_batch_size=0.25,  # 25% of batch per mini-batch
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
)

print("PPO module created:")
print(f"  Actor parameters: {sum(p.numel() for p in policy.parameters()):,}")
print(f"  Critic parameters: {sum(p.numel() for p in critic.parameters()):,}")
print(f"  PPO epochs: {ppo_module.ppo_epochs}")
print(f"  Clip epsilon: {ppo_module.eps_clip}")

### 4.5 POMO (Policy Optimization with Multiple Optima)

POMO exploits problem symmetries through:
1. **Data Augmentation**: Dihedral transformations (rotations, reflections)
2. **Multi-start Decoding**: Try multiple starting nodes
3. **Shared Baseline**: Mean reward across all augmentations/starts
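
The first and third points can be illustrated without any model. The sketch below applies the eight symmetries of the unit square to a toy instance, checks that the tour length is unchanged, and then computes advantages against a shared mean baseline. `dihedral8` here is a hypothetical stand-alone function, not the pipeline's `dihedral8` augmentation, and the rewards are made up.

```python
import math


def dihedral8(coords):
    """Return the 8 symmetric copies of an instance with coordinates in [0, 1]^2."""
    transforms = [
        lambda x, y: (x, y),
        lambda x, y: (y, x),
        lambda x, y: (x, 1 - y),
        lambda x, y: (y, 1 - x),
        lambda x, y: (1 - x, y),
        lambda x, y: (1 - y, x),
        lambda x, y: (1 - x, 1 - y),
        lambda x, y: (1 - y, 1 - x),
    ]
    return [[t(x, y) for x, y in coords] for t in transforms]


def tour_length(coords, tour):
    """Length of the closed tour visiting `coords` in the order given by `tour`."""
    return sum(
        math.dist(coords[a], coords[b])
        for a, b in zip(tour, tour[1:] + tour[:1])
    )


coords = [(0.1, 0.2), (0.8, 0.3), (0.5, 0.9), (0.2, 0.7)]
tour = [0, 1, 2, 3]
lengths = [tour_length(c, tour) for c in dihedral8(coords)]

# Shared baseline: each rollout of an instance is compared against the group mean.
rewards = [-5.2, -4.8, -5.0, -4.6]  # made-up rewards of 4 rollouts of one instance
shared_baseline = sum(rewards) / len(rewards)
advantages = [r - shared_baseline for r in rewards]
print([round(length, 4) for length in lengths], [round(a, 2) for a in advantages])
```

Because every augmented copy has the same optimal cost, the rollouts of one instance are directly comparable, and their mean is a sensible baseline that needs no extra network.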

In [None]:
# POMO Example
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import POMO

# Create environment
env = get_env("vrpp", num_loc=20, device="cpu")

# Create policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create POMO module
pomo_module = POMO(
    env=env,
    policy=policy,
    num_augment=8,  # Dihedral group D8
    augment_fn="dihedral8",  # Augmentation function
    num_starts=None,  # Defaults to num_loc
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
)

print("POMO module created:")
print(f"  Number of augmentations: {pomo_module.num_augment}")
print(f"  Augmentation function: {pomo_module.augment_fn}")
print(f"  Number of starts: {pomo_module.num_starts or 'auto (num_loc)'}")

### 4.6 CLI Commands for All Algorithms

```bash
# REINFORCE (default)
python main.py train_lightning rl.algorithm=reinforce rl.baseline=rollout

# PPO
python main.py train_lightning rl.algorithm=ppo rl.ppo_epochs=10 rl.eps_clip=0.2

# SAPO
python main.py train_lightning rl.algorithm=sapo rl.sapo_tau_pos=0.1 rl.sapo_tau_neg=1.0

# GSPO
python main.py train_lightning rl.algorithm=gspo rl.gspo_epsilon=0.2 rl.gspo_epochs=3

# GDPO (multi-objective)
python main.py train_lightning rl.algorithm=gdpo \
    'rl.gdpo_objective_keys=["cost", "overflow"]' \
    'rl.gdpo_objective_weights=[0.8, 0.2]'

# DR-GRPO
python main.py train_lightning rl.algorithm=dr_grpo rl.dr_grpo_group_size=8 rl.dr_grpo_epsilon=0.2

# POMO
python main.py train_lightning rl.algorithm=pomo rl.num_augment=8 rl.augment_fn=dihedral8

# SymNCO
python main.py train_lightning rl.algorithm=symnco rl.num_augment=8 rl.symnco_alpha=0.2

# Imitation Learning
python main.py train_lightning rl.algorithm=imitation rl.imitation_mode=hgs

# Adaptive Imitation
python main.py train_lightning rl.algorithm=adaptive_imitation \
    rl.imitation_weight=1.0 rl.imitation_decay=0.95

# HRL (Hierarchical)
python main.py train_lightning rl.algorithm=hrl rl.gamma=0.99
```

---

## 5. Baselines for Variance Reduction

Baselines reduce variance in policy gradient estimates:

$$\nabla_\theta J(\theta) = \mathbb{E}\left[(R - b(s)) \nabla_\theta \log \pi_\theta(a|s)\right]$$
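
The effect can be verified exactly on a toy problem. The sketch below assumes a policy with two actions that picks action 1 with probability $p = \sigma(\theta)$, so the score function is $a - p$. It enumerates both actions to compute the mean and variance of the gradient estimate with and without a baseline. `gradient_stats` and the reward values are illustrative assumptions.

```python
def gradient_stats(p, rewards, baseline):
    """Exact mean and variance of the score-function gradient for a 2-action policy.

    The policy picks action 1 with probability p = sigmoid(theta), so
    d/d(theta) log pi(a) = a - p for a in {0, 1}.
    """
    probs = [1.0 - p, p]
    samples = [(rewards[a] - baseline) * (a - p) for a in (0, 1)]
    mean = sum(pr * g for pr, g in zip(probs, samples))
    var = sum(pr * (g - mean) ** 2 for pr, g in zip(probs, samples))
    return mean, var


p = 0.3
rewards = [10.0, 12.0]
mean_no_bl, var_no_bl = gradient_stats(p, rewards, baseline=0.0)
expected_reward = (1 - p) * rewards[0] + p * rewards[1]
mean_bl, var_bl = gradient_stats(p, rewards, baseline=expected_reward)
print(f"no baseline:   mean={mean_no_bl:.4f} var={var_no_bl:.4f}")
print(f"with baseline: mean={mean_bl:.4f} var={var_bl:.4f}")
```

The mean is identical in both cases because the baseline does not depend on the sampled action, while the variance drops by roughly two orders of magnitude when the baseline is close to the expected reward.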

### 5.1 Available Baselines

In [None]:
# Baselines Overview
from logic.src.pipeline.rl.core.baselines import (
    BASELINE_REGISTRY,
    CriticBaseline,
    ExponentialBaseline,
    NoBaseline,
    POMOBaseline,
    RolloutBaseline,
    WarmupBaseline,
    get_baseline,
)

print("Available Baselines:")
print("=" * 60)
for name, cls in BASELINE_REGISTRY.items():
    doc = cls.__doc__ or "No description"
    first_line = doc.split("\n")[0].strip()
    print(f"  {name:15} - {first_line}")

### 5.2 Baseline Comparison

| Baseline | Formula | Pros | Cons |
|----------|---------|------|------|
| **none** | $b = 0$ | Simple | High variance |
| **exponential** | $b \leftarrow \beta \cdot b + (1-\beta) \cdot \bar{R}$ | Low compute | Same value for every instance; lags behind the policy |
| **rollout** | $b = R^{greedy}_{\pi_{old}}$ | Unbiased | Expensive |
| **critic** | $b = V_\phi(s)$ | Learned | Requires critic network |
| **pomo** | $b = \text{mean}(R_{starts})$ | Multi-solution | POMO-specific |
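
As an illustration of the cheapest non-trivial option, the sketch below implements the exponential moving average update from the table. `ExponentialMovingBaseline` is a hypothetical class written for this example and is not the pipeline's `ExponentialBaseline`. Initializing with the first batch mean is an assumption.

```python
class ExponentialMovingBaseline:
    """Running average of batch-mean rewards: b <- beta * b + (1 - beta) * mean(R)."""

    def __init__(self, beta=0.8):
        self.beta = beta
        self.value = None  # no estimate until the first batch arrives

    def update(self, rewards):
        batch_mean = sum(rewards) / len(rewards)
        if self.value is None:
            self.value = batch_mean  # assumption: start from the first batch mean
        else:
            self.value = self.beta * self.value + (1 - self.beta) * batch_mean
        return self.value


bl = ExponentialMovingBaseline(beta=0.8)
history = [bl.update(batch) for batch in ([10.0, 12.0], [14.0, 16.0], [15.0, 15.0])]
print([round(v, 2) for v in history])
```

The baseline moves only a fraction $1-\beta$ of the way toward each new batch mean, which is why it trails a policy that is improving quickly.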

### 5.3 Rollout Baseline Deep Dive

The **RolloutBaseline** is the most commonly used baseline:

1. **Pre-compute**: At epoch start, run greedy rollout on training data
2. **Store**: Keep baseline values alongside training samples
3. **Use**: During training, advantage = reward - stored_baseline
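4. **Update**: Every `update_every` epochs, replace the baseline policy with the current policy if the current policy is significantly better (t-test at significance level `bl_alpha`)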
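
The update decision in step 4 can be sketched as a one-sided paired test on per-instance validation rewards. The code below is an assumption-laden illustration: `should_update_baseline` is a hypothetical helper, and it uses a normal approximation to the t distribution (reasonable for large validation sets) so that it runs with the standard library only. The actual `RolloutBaseline` may compute the test differently.

```python
import math
from statistics import NormalDist, fmean, stdev


def should_update_baseline(candidate_rewards, baseline_rewards, alpha=0.05):
    """One-sided paired test: is the candidate policy significantly better?"""
    diffs = [c - b for c, b in zip(candidate_rewards, baseline_rewards)]
    mean_diff = fmean(diffs)
    if mean_diff <= 0:
        return False  # candidate is not better on average
    std_err = stdev(diffs) / math.sqrt(len(diffs))
    if std_err == 0:
        return True  # better by the same margin on every instance
    t_stat = mean_diff / std_err
    p_value = 1.0 - NormalDist().cdf(t_stat)  # normal approximation
    return p_value < alpha


baseline_rewards = [10.0 + 0.01 * i for i in range(200)]
# Candidate that is consistently better by about 0.5 per instance.
better = [r + 0.5 + 0.1 * ((i % 5) - 2) for i, r in enumerate(baseline_rewards)]
# Candidate that is better on half of the instances and worse on the other half.
noisy = [r + (1.0 if i % 2 == 0 else -1.0) for i, r in enumerate(baseline_rewards)]
update_better = should_update_baseline(better, baseline_rewards)
update_noisy = should_update_baseline(noisy, baseline_rewards)
print(update_better, update_noisy)
```

Pairing the rewards per instance removes the variation between easy and hard instances, so even a small but consistent improvement is detected.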

In [None]:
# Demonstrate RolloutBaseline usage
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl.core.baselines import RolloutBaseline

# Create policy for demonstration
env = get_env("vrpp", num_loc=20, device="cpu")
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create baseline with policy
rollout_bl = RolloutBaseline(
    policy=policy,
    update_every=1,  # Check for update every epoch
    bl_alpha=0.05,  # Significance level for T-test
)

print("RolloutBaseline:")
print(f"  Update frequency: every {rollout_bl.update_every} epoch(s)")
print(f"  T-test alpha: {rollout_bl.bl_alpha}")
print(f"  Has baseline policy: {rollout_bl.baseline_policy is not None}")

# Key methods
print("\nKey Methods:")
print("  wrap_dataset(dataset, policy, env) - Pre-compute baseline values")
print("  unwrap_batch(batch) -> (data, baseline_val) - Extract baseline")
print("  eval(td, reward, env) - Compute baseline on-the-fly")
print("  epoch_callback(policy, epoch, val_dataset, env) - Update check")

### 5.4 Warmup Baseline

The **WarmupBaseline** provides a gradual transition from an exponential baseline to the target baseline:

$$b_{\text{warmup}} = \alpha \cdot b_{\text{target}} + (1 - \alpha) \cdot b_{\text{exponential}}$$

where $\alpha$ increases from 0 to 1 over warmup epochs.
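
As a quick numeric illustration of the blend, the sketch below evaluates the formula for scalar baseline values. The helpers `warmup_alpha` and `blend` are hypothetical, and the linear ramp is an assumption that mirrors the schedule printed in this section's example.

```python
def warmup_alpha(epoch: int, warmup_epochs: int) -> float:
    """Linear ramp toward 1.0, capped once warmup is over (assumed schedule)."""
    return min(1.0, (epoch + 1) / warmup_epochs)


def blend(b_target: float, b_exponential: float, alpha: float) -> float:
    """b_warmup = alpha * b_target + (1 - alpha) * b_exponential."""
    return alpha * b_target + (1.0 - alpha) * b_exponential


b_target, b_exponential = 10.0, 4.0  # Made-up baseline values
for epoch in range(6):
    alpha = warmup_alpha(epoch, warmup_epochs=5)
    print(f"epoch {epoch}: alpha={alpha:.2f}, b_warmup={blend(b_target, b_exponential, alpha):.2f}")
```

Early in training the cheap exponential estimate dominates; once `alpha` reaches 1 the blend reduces to the target baseline.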

In [None]:
# WarmupBaseline Example
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl.core.baselines import RolloutBaseline, WarmupBaseline

# Create policy
env = get_env("vrpp", num_loc=20, device="cpu")
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create warmup baseline wrapping rollout
target_baseline = RolloutBaseline(policy=policy)
warmup_bl = WarmupBaseline(
    baseline=target_baseline,
    warmup_epochs=5,
    beta=0.8,  # Exponential baseline decay factor
)

print("WarmupBaseline Configuration:")
print(f"  Target: {target_baseline.__class__.__name__}")
print(f"  Warmup epochs: {warmup_bl.warmup_epochs}")
print(f"  Current alpha: {warmup_bl.alpha}")
print("\nBlending schedule:")
for epoch in range(6):
    alpha = min(1.0, (epoch + 1) / warmup_bl.warmup_epochs)
    # Format as percentages; int() truncation would print 19% instead of 20% at α = 0.8
    print(f"  Epoch {epoch}: α = {alpha:.2f} ({alpha:.0%} target, {1 - alpha:.0%} exponential)")

---

## 6. Complete Neural Models Reference

### 6.1 Model Architecture Overview

| Model | File | Architecture | Use Case |
|-------|------|--------------|----------|
| **AttentionModel (AM)** | `attention_model.py` | Transformer Encoder-Decoder | Standard routing, VRPP, WCVRP |
| **DeepDecoderAM (DDAM)** | `deep_decoder_am.py` | Deep Transformer Decoder | Complex routing decisions |
| **TemporalAM (TAM)** | `temporal_am.py` | Time-aware Transformer | Multi-day waste collection |
| **GATLSTManager** | `gat_lstm_manager.py` | GAT + LSTM | HRL manager (dispatch decisions) |
| **PointerNetwork** | `pointer_network.py` | RNN + Attention | Classic seq2seq routing |
| **CriticNetwork** | `critic_network.py` | MLP Value Network | PPO baseline, actor-critic |
| **Hypernetwork** | `hypernet.py` | Meta-learning | Weight generation |
| **MetaRNN** | `meta_rnn.py` | LSTM | Adaptive weight adjustment |
| **MOEModel** | `moe_model.py` | Mixture of Experts | Multi-task routing |

### 6.2 Encoder Types

All models support pluggable encoders via `encoder_type` config:

| Encoder | Key | Architecture | Best For |
|---------|-----|--------------|----------|
| **GraphAttentionEncoder** | `gat` | Multi-head GAT | Default, general purpose |
| **GraphAttConvEncoder** | `gac` | GAT + Edge features | When edge attributes matter |
| **TransGraphConvEncoder** | `tgc` | Transformer-style GC | Large graphs |
| **GatedGraphAttConvEncoder** | `ggac` | Gated GAT | Complex node-edge interactions |
| **GCNEncoder** | `gcn` | Standard GCN | Simple baselines |
| **MLPEncoder** | `mlp` | MLP only | No graph structure |
| **PointerEncoder** | `ptr` | RNN-based | Pointer networks |
| **MOEEncoder** | `moe` | Mixture of Experts | Multi-task learning |

### 6.3 Module Building Blocks

| Module | File | Description |
|--------|------|-------------|
| **MultiHeadAttention** | `multi_head_attention.py` | Standard scaled dot-product attention |
| **GraphConvolution** | `graph_convolution.py` | Basic GCN message passing |
| **DistanceGraphConvolution** | `distance_graph_convolution.py` | Distance-weighted convolution |
| **GatedGraphConvolution** | `gated_graph_convolution.py` | GRU-style gating on graphs |
| **EfficientGraphConvolution** | `efficient_graph_convolution.py` | Lightweight multi-head with aggregators |
| **FeedForward** | `feed_forward.py` | 2-layer MLP block |
| **Normalization** | `normalization.py` | Batch/Layer/Instance/Group norm |
| **ActivationFunction** | `activation_function.py` | 21+ activations (ReLU, GELU, Mish, SwiGLU...) |
| **SkipConnection** | `skip_connection.py` | Residual connections |
| **HyperConnection** | `hyper_connection.py` | Dynamic depth mixing |
| **MOE** | `moe.py` | Expert routing mechanism |

### 6.4 Model Configuration Examples

```python
# Standard Attention Model
ModelConfig(
    name="am",
    embed_dim=128,
    hidden_dim=512,
    num_encoder_layers=3,
    num_decoder_layers=1,
    n_heads=8,
    encoder_type="gat",
    normalization="instance",
    activation="gelu",
    dropout=0.1,
)

# Deep Decoder for complex problems
ModelConfig(
    name="deep_decoder",
    embed_dim=128,
    hidden_dim=512,
    num_encoder_layers=3,
    num_decoder_layers=6,  # Deep decoder
    n_heads=8,
)

# Temporal model for multi-day scenarios
ModelConfig(
    name="temporal",
    embed_dim=128,
    temporal_horizon=7,  # 7-day lookahead
)
```

### 6.5 CLI Model Selection

```bash
# Attention Model (default)
python main.py train_lightning model=am model.embed_dim=128

# Deep Decoder
python main.py train_lightning model=deep_decoder model.num_decoder_layers=6

# Temporal Model
python main.py train_lightning model=temporal model.temporal_horizon=7

# Change encoder type
python main.py train_lightning model.encoder_type=gac  # Graph Attention Conv
python main.py train_lightning model.encoder_type=tgc  # Transformer Graph Conv
python main.py train_lightning model.encoder_type=gcn  # Standard GCN

# Change normalization
python main.py train_lightning model.normalization=layer  # or batch, instance, group

# Change activation
python main.py train_lightning model.activation=relu  # or gelu, mish, swish, swiglu
```

---

## 7. Environments & Data Generation

The environment system provides:
- Problem-specific state transitions
- Reward computation
- Action masking for valid moves
- On-the-fly data generation
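
Action masking is the piece that most directly shapes the policy's output. The following framework-free sketch shows one common way to apply a mask: set the logits of invalid actions to negative infinity before the softmax. The function `masked_softmax` is illustrative and is not taken from the pipeline.

```python
import math


def masked_softmax(logits, mask):
    """Softmax over valid actions only; masked-out actions get probability 0."""
    if not any(mask):
        raise ValueError("At least one action must be valid")
    masked = [x if ok else -math.inf for x, ok in zip(logits, mask)]
    peak = max(masked)  # Subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in masked]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.5, 3.0]
visited = [False, True, False, True]
mask = [not v for v in visited]  # Only unvisited nodes remain valid

probs = masked_softmax(logits, mask)
print(probs)
```

Visited nodes receive exactly zero probability, so the policy can never sample an infeasible move.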

### 7.1 Environment Registry

In [None]:
from logic.src.envs import ENV_REGISTRY, get_env

print("Available Environments:")
print("=" * 60)
for name, cls in ENV_REGISTRY.items():
    doc = cls.__doc__ or "No description"
    # Strip first so docstrings that open with a newline still yield a summary line
    first_line = doc.strip().split("\n")[0].strip()
    print(f"  {name:10} - {cls.__name__}: {first_line}")

### 7.2 Environment Types

| Environment | Description | Key Features |
|-------------|-------------|-------------|
| **vrpp** | Vehicle Routing with Profits | Prizes at nodes, maximize profit-cost |
| **cvrpp** | Capacitated VRPP | + Vehicle capacity constraints |
| **wcvrp** | Waste Collection VRP | Bin fill levels, accumulation |
| **cwcvrp** | Capacitated WCVRP | + Vehicle capacity |
| **sdwcvrp** | Stochastic Demand WCVRP | + Uncertain waste generation |
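
The "maximize profit-cost" objective of `vrpp` can be illustrated with a toy calculation. The function `vrpp_reward_sketch` and its `cost_weight` parameter are assumptions made for this example; the real environment's reward may include additional terms.

```python
from math import dist


def vrpp_reward_sketch(depot, locs, prizes, tour, cost_weight=1.0):
    """Collected prize minus weighted travel distance for a depot-to-depot tour."""
    path = [depot] + [locs[i] for i in tour] + [depot]
    travel = sum(dist(a, b) for a, b in zip(path, path[1:]))
    profit = sum(prizes[i] for i in tour)
    return profit - cost_weight * travel


depot = (0.0, 0.0)
locs = [(3.0, 4.0), (3.0, 0.0), (0.0, 4.0)]
prizes = [12.0, 3.0, 1.0]  # Made-up prize values

r_single = vrpp_reward_sketch(depot, locs, prizes, tour=[0])
r_pair = vrpp_reward_sketch(depot, locs, prizes, tour=[1, 0])
r_all = vrpp_reward_sketch(depot, locs, prizes, tour=[1, 0, 2])
print(r_single, r_pair, r_all)
```

In this toy instance, visiting two nodes scores higher than visiting all three, which is what separates prize-collecting problems from plain routing: skipping a node can be the better decision.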

### 7.3 TensorDict State Structure

All environments use TensorDict for state management:

In [None]:
from logic.src.envs import get_env

# Create VRPP environment
env = get_env(
    "vrpp",
    num_loc=20,
    min_loc=0.0,
    max_loc=1.0,
    device="cpu",
)

# Generate a batch of instances
batch_size = 4
td = env.generator(batch_size)

print("Generated TensorDict:")
print(f"  Batch size: {batch_size}")
print(f"  Keys: {list(td.keys())}")
print("\nShapes:")
for key in td.keys():
    print(f"  {key}: {td[key].shape}")

In [None]:
# Demonstrate environment step
td_reset = env.reset(td)

print("After reset:")
print(f"  Keys: {list(td_reset.keys())}")
print(f"  Done: {td_reset['done'] if 'done' in td_reset.keys() else 'Not set'}")

# Show action mask
if "action_mask" in td_reset.keys():
    mask = td_reset["action_mask"]
    print(f"\nAction mask shape: {mask.shape}")
    print(f"  Valid actions (first instance): {mask[0].sum().item()}/{mask[0].numel()}")

### 7.4 Data Generators

Generators create problem instances on-the-fly:

In [None]:
from logic.src.envs import GENERATOR_REGISTRY, get_generator

print("Available Generators:")
for name in GENERATOR_REGISTRY.keys():
    print(f"  {name}")

# Create VRPP generator
generator = get_generator(
    "vrpp",
    num_loc=50,
    min_loc=0.0,
    max_loc=1.0,
)

# Generate instances
instances = generator(batch_size=8)
print(f"\nGenerated {instances.batch_size[0]} instances with {instances['locs'].shape[-2]} locations")

### 7.5 Dataset Classes

The pipeline uses custom dataset classes:

```python
from logic.src.data.datasets import (
    GeneratorDataset,      # On-the-fly generation
    TensorDictDataset,     # Persistent storage
    BaselineDataset,       # Wraps dataset with baseline values
    tensordict_collate_fn, # Custom collation for TensorDict
)
```

In [None]:
from logic.src.data.datasets import GeneratorDataset, tensordict_collate_fn
from torch.utils.data import DataLoader

# Create dataset
dataset = GeneratorDataset(generator, size=1000)

# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=64,
    collate_fn=tensordict_collate_fn,
    num_workers=0,
)

# Get a batch
batch = next(iter(dataloader))
print(f"Batch keys: {list(batch.keys())}")
print(f"Batch size: {batch.batch_size}")

---

## 8. Meta-Learning

Meta-learning enables automatic adaptation of training parameters (e.g., reward weights) through bi-level optimization.

### 8.1 Meta-Learning Wrapper

In [None]:
# Meta-Learning Example
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import REINFORCE, MetaRLModule

# Create environment
env = get_env("vrpp", num_loc=20, device="cpu")

# Create policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create base RL module
base_module = REINFORCE(
    env=env,
    policy=policy,
    baseline="exponential",
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
)

# Wrap with meta-learning
meta_module = MetaRLModule(
    agent=base_module,
    meta_lr=1e-3,
    history_length=10,
    hidden_size=64,
)

print("MetaRLModule created:")
print(f"  Inner agent: {base_module.__class__.__name__}")
print(f"  Meta learning rate: {meta_module.meta_lr}")
print(f"  History length: {meta_module.history_length}")

### 8.2 Meta-Learning Strategies

| Strategy | Description | Best For |
|----------|-------------|----------|
| **rnn** | Recurrent network processes reward history | General adaptation |
| **bandit** | UCB/Thompson sampling for weight selection | Discrete weight choices |
| **morl** | Multi-objective Pareto optimization | Multiple objectives |
| **tdl** | Temporal difference learning | Online adaptation |
| **hypernet** | Hypernetwork generates weights | Problem-conditioned |

### 8.3 Bi-Level Optimization Flow

```
┌─────────────────────────────────────────────────────────────────┐
│                    Meta-Learning Training Loop                   │
└─────────────────────────────────────────────────────────────────┘

for epoch in range(n_epochs):
    
    ┌─────────────────────────────────────────────────────────────┐
    │ INNER LOOP: RL Training                                      │
    │                                                              │
    │   for batch in dataloader:                                   │
    │       loss = agent.training_step(batch)                      │
    │       optimizer.step()                                       │
    │                                                              │
    │   reward_signal = epoch_reward                               │
    └─────────────────────────────────────────────────────────────┘
                              │
                              ▼
    ┌─────────────────────────────────────────────────────────────┐
    │ OUTER LOOP: Meta-Strategy Update                             │
    │                                                              │
    │   meta_strategy.update(reward_signal)                        │
    │   new_weights = meta_strategy.propose_weights()              │
    │   env.update_weights(new_weights)                            │
    └─────────────────────────────────────────────────────────────┘
```
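
The loop above can be made runnable with a toy stand-in for the inner RL training. The sketch below pairs it with a UCB1 bandit over three candidate weights, in the spirit of the `bandit` strategy. The class `UCB1Strategy`, the function `run_inner_loop`, and the reward values are all hypothetical; only the method names `propose_weights` and `update` are borrowed from the diagram.

```python
import math


class UCB1Strategy:
    """Hypothetical bandit meta-strategy over a fixed set of weight candidates."""

    def __init__(self, candidates, exploration=2.0):
        self.candidates = list(candidates)
        self.exploration = exploration
        self.counts = [0] * len(self.candidates)
        self.totals = [0.0] * len(self.candidates)
        self._last_arm = None

    def propose_weights(self):
        """Pick the candidate with the highest upper confidence bound."""
        untried = [a for a, n in enumerate(self.counts) if n == 0]
        if untried:
            arm = untried[0]  # Try every candidate once first
        else:
            t = sum(self.counts)
            arm = max(
                range(len(self.candidates)),
                key=lambda a: self.totals[a] / self.counts[a]
                + math.sqrt(self.exploration * math.log(t) / self.counts[a]),
            )
        self._last_arm = arm
        return self.candidates[arm]

    def update(self, reward_signal):
        """Credit the epoch reward to the most recently proposed candidate."""
        self.counts[self._last_arm] += 1
        self.totals[self._last_arm] += reward_signal


def run_inner_loop(cost_weight):
    """Stand-in for one epoch of RL training; returns a toy epoch reward."""
    toy_rewards = {0.1: 0.2, 0.5: 1.0, 1.0: 0.0}  # Made-up values
    return toy_rewards[cost_weight]


strategy = UCB1Strategy(candidates=[0.1, 0.5, 1.0])
for epoch in range(50):
    weights = strategy.propose_weights()     # Outer loop: choose weights
    reward_signal = run_inner_loop(weights)  # Inner loop: train for one epoch
    strategy.update(reward_signal)           # Outer loop: learn from the result

print(strategy.counts)
```

Most of the pulls end up on the candidate with the highest toy reward, while the exploration bonus still revisits the other two occasionally.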

---

## 9. Hyperparameter Optimization

The pipeline supports two HPO methods: Optuna-based search and DEHB.

### 9.1 Optuna-Based HPO

Multiple sampling strategies:
- **TPE** (Tree-structured Parzen Estimator) - Default
- **Grid Search**
- **Random Search**
- **Hyperband** - Multi-fidelity

In [None]:
from logic.src.configs import HPOConfig

# HPO configuration
hpo_cfg = HPOConfig(
    method="tpe",  # tpe, grid, random, hyperband
    n_trials=50,  # Number of trials
    n_epochs_per_trial=10,  # Epochs per trial
    search_space={
        "optim.lr": [1e-5, 1e-3],  # Log-uniform
        "train.batch_size": [64, 512],  # Integer
        "model.embed_dim": [64, 128, 256],  # Categorical
        "rl.entropy_weight": [0.0, 0.1],  # Uniform
    },
)

print("HPO Configuration:")
print(f"  Method: {hpo_cfg.method}")
print(f"  Trials: {hpo_cfg.n_trials}")
print(f"  Epochs per trial: {hpo_cfg.n_epochs_per_trial}")
print("\nSearch Space:")
for param, range_val in hpo_cfg.search_space.items():
    print(f"  {param}: {range_val}")
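
The comments in the search space assign a distribution to each entry (log-uniform, integer, categorical, uniform). How `HPOConfig` infers these from plain lists is not shown here, so the sketch below makes the distribution explicit with a hypothetical `kind` field and draws one trial using only the standard library.

```python
import math
import random


def sample_param(spec, rng):
    """Draw one value according to an explicit distribution spec (assumed format)."""
    kind = spec["kind"]
    if kind == "log_uniform":
        low, high = spec["range"]
        return math.exp(rng.uniform(math.log(low), math.log(high)))
    if kind == "uniform":
        return rng.uniform(*spec["range"])
    if kind == "int":
        return rng.randint(*spec["range"])  # Both bounds inclusive
    if kind == "categorical":
        return rng.choice(spec["choices"])
    raise ValueError(f"Unknown distribution kind: {kind}")


search_space = {
    "optim.lr": {"kind": "log_uniform", "range": (1e-5, 1e-3)},
    "train.batch_size": {"kind": "int", "range": (64, 512)},
    "model.embed_dim": {"kind": "categorical", "choices": [64, 128, 256]},
    "rl.entropy_weight": {"kind": "uniform", "range": (0.0, 0.1)},
}

rng = random.Random(0)  # Seeded for reproducible trials
trial = {name: sample_param(spec, rng) for name, spec in search_space.items()}
print(trial)
```

Sampling the learning rate in log space spreads trials evenly across orders of magnitude, which a plain uniform draw over the same interval would not do.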

### 9.2 HPO via CLI

```bash
# TPE optimization with 50 trials
python main.py train_lightning \
    hpo.n_trials=50 \
    hpo.method=tpe \
    hpo.n_epochs_per_trial=10 \
    'hpo.search_space={"optim.lr": [1e-5, 1e-3], "train.batch_size": [64, 512]}'

# Grid search
python main.py train_lightning \
    hpo.n_trials=100 \
    hpo.method=grid \
    'hpo.search_space={"model.embed_dim": [64, 128, 256], "model.n_heads": [4, 8]}'

# DEHB (multi-fidelity)
python main.py train_lightning \
    hpo.method=dehb \
    hpo.min_fidelity=1 \
    hpo.max_fidelity=50 \
    hpo.fevals=100
```

### 9.3 DEHB (Differential Evolution Hyperband)

DEHB combines:
- **Differential Evolution** for global search
- **Hyperband** for multi-fidelity scheduling

Benefits:
- Early stopping of poor configurations
- Efficient use of compute budget
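
The early-stopping benefit comes from the successive-halving idea at the core of Hyperband. The sketch below shows only that ingredient, with a made-up objective; it is not DEHB itself, which additionally evolves configurations with differential evolution. The function names and the `eta` parameter are assumptions for illustration.

```python
import math


def successive_halving(configs, evaluate, min_budget=1, max_budget=8, eta=2):
    """Keep the top 1/eta of configs at each rung while multiplying the budget by eta."""
    survivors = list(configs)
    budget = min_budget
    history = []  # (budget, number of configs evaluated) per rung
    while True:
        scored = sorted(survivors, key=lambda c: evaluate(c, budget), reverse=True)
        history.append((budget, len(scored)))
        if budget >= max_budget or len(scored) == 1:
            return scored[0], history
        survivors = scored[: max(1, len(scored) // eta)]
        budget = min(max_budget, budget * eta)


def toy_score(lr, budget):
    """Made-up objective that peaks at lr = 1e-4; budget only rescales the penalty."""
    return -((math.log10(lr) + 4.0) ** 2) * (1.0 + 1.0 / budget)


learning_rates = [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2]
best_lr, history = successive_halving(learning_rates, toy_score)
print(best_lr, history)
```

Only one configuration is ever trained at the full budget; the other seven are dropped after cheaper evaluations.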

---

## 10. Practical Examples

### 10.1 Complete Training Script

In [None]:
import pytorch_lightning as pl
from logic.src.callbacks import SpeedMonitor
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import REINFORCE
from logic.src.pipeline.trainer import WSTrainer
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger

# Set seed for reproducibility
seed_everything(42)

# 1. Create Environment
env = get_env(
    "vrpp",
    num_loc=20,
    device="cpu",  # Use "cuda" if available
)

# 2. Create Policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# 3. Create RL Module
model = REINFORCE(
    env=env,
    policy=policy,
    baseline="exponential",
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=5000,
    val_data_size=500,
    batch_size=64,
    num_workers=0,  # Use 0 for notebooks
)

# 4. Create Trainer
trainer = WSTrainer(
    max_epochs=3,  # Short training for demo
    accelerator="auto",
    devices=1,
    logger=CSVLogger("logs", name="tutorial_demo"),
    callbacks=[SpeedMonitor(epoch_time=True)],
    enable_progress_bar=True,
)

print("Training configuration ready!")
print(f"  Environment: {env.__class__.__name__}")
print(f"  Policy: {policy.__class__.__name__}")
print(f"  Algorithm: REINFORCE with {model.baseline_type} baseline")
print(f"  Epochs: {trainer.max_epochs}")

In [None]:
# Run training (uncomment to execute)
# trainer.fit(model)

# Get final metrics
# print(f"\nFinal validation reward: {trainer.callback_metrics.get('val/reward', 'N/A')}")

### 10.2 CLI Examples Reference

#### Basic Training
```bash
# VRPP with Attention Model
python main.py train_lightning \
    model=am \
    env.name=vrpp \
    env.num_loc=50 \
    train.n_epochs=100

# WCVRP (Waste Collection)
python main.py train_lightning \
    model=am \
    env.name=wcvrp \
    env.num_loc=50 \
    env.capacity=100
```

#### Algorithm Variants
```bash
# PPO with custom parameters
python main.py train_lightning \
    rl.algorithm=ppo \
    rl.ppo_epochs=10 \
    rl.eps_clip=0.2 \
    rl.value_loss_weight=0.5

# POMO with augmentation
python main.py train_lightning \
    rl.algorithm=pomo \
    rl.num_augment=8 \
    rl.num_starts=50

# Imitation learning from HGS
python main.py train_lightning \
    rl.algorithm=imitation \
    rl.expert=hgs
```

#### Advanced Features
```bash
# Meta-learning with RNN strategy
python main.py train_lightning \
    rl.use_meta=true \
    rl.meta_strategy=rnn \
    rl.meta_lr=1e-3 \
    rl.meta_history_length=10

# Learning rate scheduling
python main.py train_lightning \
    optim.lr_scheduler=cosine \
    'optim.lr_scheduler_kwargs={"T_max": 100}'

# Mixed precision training
python main.py train_lightning \
    train.precision=16-mixed
```

---

## 10. Advanced Topics

### 10.1 Custom RL Algorithm

To implement a custom RL algorithm, inherit from `RL4COLitModule`:

In [None]:
from typing import Optional

import torch
from logic.src.pipeline.rl.core.base import RL4COLitModule
from tensordict import TensorDict


class CustomRL(RL4COLitModule):
    """
    Example custom RL algorithm.

    Implements a simple variant with custom advantage normalization.
    """

    def __init__(
        self,
        temperature: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.temperature = temperature

    def calculate_loss(
        self,
        td: TensorDict,
        out: dict,
        batch_idx: int,
        env: Optional["RL4COEnvBase"] = None,
    ) -> torch.Tensor:
        """
        Custom loss computation.

        Uses temperature-scaled advantage.
        """
        reward = out["reward"]
        log_likelihood = out["log_likelihood"]

        # Get baseline
        if hasattr(self, "_current_baseline_val") and self._current_baseline_val is not None:
            baseline_val = self._current_baseline_val
        else:
            baseline_val = self.baseline.eval(td, reward, env=env)

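        # Normalize the advantage, then apply temperature scaling.
        # (Scaling before normalizing would cancel the temperature.)
        advantage = reward - baseline_val
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        advantage = advantage / self.temperature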

        # Policy gradient loss
        loss = -(advantage.detach() * log_likelihood).mean()

        return loss


print("CustomRL algorithm defined successfully!")
print(f"  Key parameter: temperature={1.0}")
print("  Inherits from: RL4COLitModule")

### 10.2 Custom Baseline

To implement a custom baseline:

In [None]:
import torch
import torch.nn as nn
from logic.src.pipeline.rl.core.baselines import Baseline
from tensordict import TensorDict


class PercentileBaseline(Baseline):
    """
    Baseline using batch percentile.

    Returns the q-th percentile of rewards as baseline.
    """

    def __init__(self, percentile: float = 50.0):
        super().__init__()
        self.percentile = percentile

    def eval(
        self,
        td: TensorDict,
        reward: torch.Tensor,
        env=None,
    ) -> torch.Tensor:
        """
        Compute percentile baseline.

        Args:
            td: TensorDict (unused)
            reward: Batch rewards
            env: Environment (unused)

        Returns:
            Percentile value expanded to reward shape
        """
        # Compute percentile
        baseline_val = torch.quantile(reward.float(), self.percentile / 100.0)
        return baseline_val.expand_as(reward)


# Test the custom baseline
baseline = PercentileBaseline(percentile=75.0)
test_rewards = torch.randn(64)
baseline_val = baseline.eval(None, test_rewards)

print("PercentileBaseline (75th):")
print(f"  Input rewards shape: {test_rewards.shape}")
print(f"  Baseline value: {baseline_val[0].item():.4f}")
print(f"  Actual 75th percentile: {torch.quantile(test_rewards, 0.75).item():.4f}")

### 10.3 Multi-GPU Training

PyTorch Lightning automatically handles distributed training:

```bash
# Single GPU
python main.py train_lightning device=cuda

# Multiple GPUs (DDP)
python main.py train_lightning \
    device=cuda \
    --trainer.devices=4 \
    --trainer.strategy=ddp

# DeepSpeed (for large models)
python main.py train_lightning \
    --trainer.strategy=deepspeed_stage_2
```

### 10.4 Checkpoint Management

```bash
# Resume from checkpoint
python main.py train_lightning \
    --ckpt_path=/path/to/checkpoint.ckpt

# Custom checkpoint directory
python main.py train_lightning \
    output_dir=assets/model_weights/experiment1
```

### 10.5 Debugging Tips

```bash
# Fast dev run (runs a single train and validation batch, then exits)
python main.py train_lightning \
    --trainer.fast_dev_run=true

# Limit batches for debugging
python main.py train_lightning \
    --trainer.limit_train_batches=10 \
    --trainer.limit_val_batches=5

# Profiling
python main.py train_lightning \
    --trainer.profiler=simple
```

---

## Summary

This comprehensive tutorial covered the Lightning-based RL training pipeline for WSmart+ Route.

### Key Components Covered

| Category | Count | Description |
|----------|-------|-------------|
| **RL Algorithms** | 11 | REINFORCE, PPO, SAPO, GSPO, GDPO, DRGRPO, POMO, SymNCO, Imitation, AdaptiveImitation, HRL |
| **Baselines** | 6 | None, Exponential, Rollout, Critic, Warmup, POMO |
| **Environments** | 6 | VRPP, CVRPP, WCVRP, CWCVRP, SDWCVRP, SCWCVRP |
| **Models** | 9 | AM, DDAM, TAM, Pointer, Critic, HyperNet, MetaRNN, MOE, GATLSTManager |
| **Encoders** | 8 | GAT, GAC, TGC, GGAC, GCN, MLP, Pointer, MOE |
| **Modules** | 11 | Attention, GraphConv, FeedForward, Normalization, Activation, Skip, Hyper, MOE |

### Quick Reference Commands

```bash
# Basic training
python main.py train_lightning model=am env.name=vrpp

# PPO with rollout baseline
python main.py train_lightning rl.algorithm=ppo rl.baseline=rollout

# POMO with augmentation
python main.py train_lightning rl.algorithm=pomo rl.num_augment=8

# Imitation learning from HGS expert
python main.py train_lightning rl.algorithm=imitation rl.imitation_mode=hgs

# Adaptive IL → RL transition
python main.py train_lightning rl.algorithm=adaptive_imitation rl.imitation_decay=0.95

# Meta-learning
python main.py train_lightning meta_rl.enabled=true meta_rl.strategy=rnn

# Hyperparameter optimization
python main.py train_lightning hpo.n_trials=50 hpo.method=tpe

# Multi-GPU training
python main.py train_lightning --trainer.devices=4 --trainer.strategy=ddp
```

### Extension Points

This tutorial covered how to add:
1. **New Environment/Task** - Inherit from `RL4COEnvBase`
2. **New Encoder** - Create module in `subnets/`, register in factory
3. **New RL Algorithm** - Inherit from `RL4COLitModule`, implement `calculate_loss()`
4. **New Baseline** - Inherit from `Baseline`, implement `eval()`
5. **New Configuration** - Create dataclass, add to root `Config`

### Further Reading

| Resource | Location |
|----------|----------|
| Project Overview | `CLAUDE.md` |
| Architecture | `ARCHITECTURE.md` |
| CLI Reference | `python main.py --help` |
| RL Pipeline Source | `logic/src/pipeline/rl/` |
| Configuration | `logic/src/configs/` |
| Environments | `logic/src/envs/` |
| Models | `logic/src/models/` |

### Version History

| Version | Date | Changes |
|---------|------|---------|
| 2.0 | Jan 2026 | Added all algorithms, models, extension guides, best practices |
| 1.0 | Jan 2026 | Initial tutorial |

---

## 11. Adding New Components Guide

This section provides step-by-step guides for extending the framework with new components.

### 11.1 Adding a New Environment/Task

**Location:** `logic/src/envs/`

**Step 1: Create the environment file**

```python
# logic/src/envs/my_problem.py

import torch
from tensordict import TensorDict
from logic.src.envs.base import RL4COEnvBase
from logic.src.envs.generators import Generator


class MyProblemGenerator(Generator):
    """Generator for MyProblem instances."""
    
    def __init__(
        self,
        num_loc: int = 50,
        min_loc: float = 0.0,
        max_loc: float = 1.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_loc = num_loc
        self.min_loc = min_loc
        self.max_loc = max_loc
    
    def _generate(self, batch_size: int) -> TensorDict:
        """Generate random problem instances."""
        # Locations (batch, num_loc, 2)
        locs = torch.rand(batch_size, self.num_loc, 2)
        locs = locs * (self.max_loc - self.min_loc) + self.min_loc
        
        # Depot (batch, 2)
        depot = torch.rand(batch_size, 2)
        
        # Problem-specific features
        demands = torch.rand(batch_size, self.num_loc)
        
        return TensorDict({
            "locs": locs,
            "depot": depot,
            "demand": demands,
        }, batch_size=[batch_size])


class MyProblemEnv(RL4COEnvBase):
    """Environment for MyProblem."""
    
    name = "myproblem"
    
    def __init__(
        self,
        num_loc: int = 50,
        capacity: float = 1.0,
        **kwargs
    ):
        super().__init__(**kwargs)
        self.num_loc = num_loc
        self.capacity = capacity
        self.generator = MyProblemGenerator(num_loc=num_loc, **kwargs)
    
    def _reset(self, td: TensorDict) -> TensorDict:
        """Reset environment to initial state."""
        batch_size = td.batch_size[0]
        device = td.device
        
        # Initialize state
        td["current_node"] = torch.zeros(batch_size, dtype=torch.long, device=device)
        td["visited"] = torch.zeros(batch_size, self.num_loc, dtype=torch.bool, device=device)
        td["current_load"] = torch.zeros(batch_size, device=device)
        td["done"] = torch.zeros(batch_size, dtype=torch.bool, device=device)
        
        # Compute action mask
        td["action_mask"] = self._get_action_mask(td)
        
        return td
    
    def _step(self, td: TensorDict) -> TensorDict:
        """Execute one step in the environment."""
        action = td["action"]
        batch_size = td.batch_size[0]
        
        # Update visited
        td["visited"].scatter_(1, action.unsqueeze(-1), True)
        
        # Update current node
        td["current_node"] = action
        
        # Update load
        demand = td["demand"].gather(1, action.unsqueeze(-1)).squeeze(-1)
        td["current_load"] = td["current_load"] + demand
        
        # Check if done
        td["done"] = td["visited"].all(dim=-1)
        
        # Update action mask
        td["action_mask"] = self._get_action_mask(td)
        
        return td
    
    def _get_action_mask(self, td: TensorDict) -> torch.Tensor:
        """Compute valid actions mask."""
        mask = ~td["visited"]
        # Add capacity constraints, etc.
        return mask
    
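    def _get_reward(self, td: TensorDict, actions: torch.Tensor) -> torch.Tensor:
        """Compute reward (negative tour length)."""
        locs = td["locs"]                 # (batch, num_loc, 2)
        depot = td["depot"].unsqueeze(1)  # (batch, 1, 2)
        
        # Simplified cost: depot -> visited nodes in order -> depot.
        # Assumes `actions` has shape (batch, n_steps) and indexes into `locs`.
        idx = actions.unsqueeze(-1).expand(-1, -1, locs.size(-1))
        tour = torch.cat([depot, locs.gather(1, idx), depot], dim=1)
        cost = (tour[:, 1:] - tour[:, :-1]).norm(p=2, dim=-1).sum(dim=-1)
        
        return -cost  # Negative because we minimize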
```

**Step 2: Register in `__init__.py`**

```python
# logic/src/envs/__init__.py

from logic.src.envs.my_problem import MyProblemEnv, MyProblemGenerator

# Add to registry
ENV_REGISTRY = {
    # ... existing envs
    "myproblem": MyProblemEnv,
}

GENERATOR_REGISTRY = {
    # ... existing generators
    "myproblem": MyProblemGenerator,
}
```

**Step 3: Add configuration**

```python
# logic/src/configs/env.py

@dataclass
class EnvConfig:
    name: str = "vrpp"  # Add "myproblem" as valid option
    # ... add any new config fields
```

### 11.2 Adding a New Encoder

**Location:** `logic/src/models/subnets/`

**Step 1: Create the encoder file**

```python
# logic/src/models/subnets/my_encoder.py

import torch
import torch.nn as nn
from logic.src.models.modules import Normalization, ActivationFunction


class MyCustomEncoder(nn.Module):
    """
    Custom encoder for node embedding.
    
    Args:
        embed_dim: Embedding dimension
        hidden_dim: Hidden layer dimension
        n_layers: Number of encoder layers
        n_heads: Number of attention heads
        normalization: Type of normalization
        activation: Activation function
    """
    
    def __init__(
        self,
        embed_dim: int = 128,
        hidden_dim: int = 512,
        n_layers: int = 3,
        n_heads: int = 8,
        normalization: str = "instance",
        activation: str = "gelu",
        dropout: float = 0.1,
    ):
        super().__init__()
        
        self.embed_dim = embed_dim
        self.n_layers = n_layers
        
        # Initial embedding
        self.init_embed = nn.Linear(2, embed_dim)  # 2D coordinates
        
        # Encoder layers
        self.layers = nn.ModuleList([
            self._make_layer(embed_dim, hidden_dim, n_heads, normalization, activation, dropout)
            for _ in range(n_layers)
        ])
        
    def _make_layer(self, embed_dim, hidden_dim, n_heads, normalization, activation, dropout):
        """Create a single encoder layer."""
        return nn.ModuleDict({
            "attention": nn.MultiheadAttention(embed_dim, n_heads, dropout=dropout, batch_first=True),
            "norm1": Normalization(embed_dim, normalization),
            "ff": nn.Sequential(
                nn.Linear(embed_dim, hidden_dim),
                ActivationFunction(activation),
                nn.Dropout(dropout),
                nn.Linear(hidden_dim, embed_dim),
            ),
            "norm2": Normalization(embed_dim, normalization),
        })
    
    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """
        Forward pass.
        
        Args:
            x: Node features (batch, n_nodes, input_dim)
            mask: Attention mask (batch, n_nodes)
            
        Returns:
            Node embeddings (batch, n_nodes, embed_dim)
        """
        # Initial embedding
        h = self.init_embed(x)
        
        # Encoder layers
        for layer in self.layers:
            # Self-attention
            attn_out, _ = layer["attention"](h, h, h, key_padding_mask=mask)
            h = layer["norm1"](h + attn_out)
            
            # Feed-forward
            ff_out = layer["ff"](h)
            h = layer["norm2"](h + ff_out)
        
        return h
```

**Step 2: Register in `__init__.py`**

```python
# logic/src/models/subnets/__init__.py

from .my_encoder import MyCustomEncoder as MyCustomEncoder
```

**Step 3: Add to model factory**

```python
# logic/src/models/model_factory.py

ENCODER_REGISTRY = {
    "gat": GraphAttentionEncoder,
    "gcn": GCNEncoder,
    "my_encoder": MyCustomEncoder,  # Add here
}
```

### 11.3 Adding a New RL Algorithm

**Location:** `logic/src/pipeline/rl/core/`

**Step 1: Create the algorithm file**

```python
# logic/src/pipeline/rl/core/my_algorithm.py

from typing import Optional
import torch
from tensordict import TensorDict
from logic.src.pipeline.rl.core.base import RL4COLitModule


class MyAlgorithm(RL4COLitModule):
    """
    Custom RL algorithm.
    
    Implements a novel policy gradient variant with custom loss.
    """
    
    def __init__(
        self,
        # Algorithm-specific parameters
        my_param: float = 1.0,
        temperature: float = 1.0,
        **kwargs  # Pass to parent
    ):
        super().__init__(**kwargs)
        self.my_param = my_param
        self.temperature = temperature
        
        # Save hyperparameters for checkpointing
        self.save_hyperparameters(ignore=["env", "policy"])
    
    def calculate_loss(
        self,
        td: TensorDict,
        out: dict,
        batch_idx: int,
        env: Optional["RL4COEnvBase"] = None,
    ) -> torch.Tensor:
        """
        Compute the custom loss.
        
        Args:
            td: TensorDict with environment state
            out: Dict with policy outputs (reward, log_likelihood, etc.)
            batch_idx: Current batch index
            env: Environment instance
            
        Returns:
            Scalar loss tensor
        """
        reward = out["reward"]
        log_likelihood = out["log_likelihood"]
        
        # Get baseline value
        if hasattr(self, "_current_baseline_val") and self._current_baseline_val is not None:
            baseline_val = self._current_baseline_val
        else:
            baseline_val = self.baseline.eval(td, reward, env=env)
        
        # Compute advantage
        advantage = reward - baseline_val
        
        # Custom loss computation
        # Example: temperature-scaled policy gradient.
        # Standardize first, then apply the temperature: standardization is
        # scale-invariant, so dividing by the temperature beforehand has no effect.
        scaled_advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        scaled_advantage = scaled_advantage / self.temperature
        
        # Policy gradient loss
        loss = -(scaled_advantage.detach() * log_likelihood).mean()
        
        # Optional: scale the loss by a custom weight
        loss = loss * self.my_param
        
        # Log metrics
        self.log("train/advantage", advantage.mean(), prog_bar=False)
        self.log("train/baseline", baseline_val.mean(), prog_bar=False)
        
        return loss
```
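When a loss like the one above combines a temperature with advantage standardization, the order of the two operations matters, because standardization is invariant to positive rescaling. The check below is a pure-Python sketch with illustrative values (no tensors). It uses the sample standard deviation, which matches the default of `torch.Tensor.std()`.

```python
import statistics
from typing import List


def standardize(values: List[float], eps: float = 1e-8) -> List[float]:
    """Subtract the mean and divide by the sample standard deviation."""
    mean = statistics.fmean(values)
    std = statistics.stdev(values)  # sample std (n - 1)
    return [(v - mean) / (std + eps) for v in values]


advantage = [-2.0, -0.5, 0.5, 2.0]  # illustrative values
temperature = 4.0

plain = standardize(advantage)

# Order A: temperature first, then standardize -> the temperature cancels out
scaled_then_std = standardize([a / temperature for a in advantage])

# Order B: standardize first, then temperature -> the temperature is preserved
std_then_scaled = [a / temperature for a in plain]

print("plain          ", [round(v, 4) for v in plain])
print("scale, then std", [round(v, 4) for v in scaled_then_std])
print("std, then scale", [round(v, 4) for v in std_then_scaled])
```

Order A prints the same values as the plain standardization, so only Order B lets the temperature influence the gradient magnitude.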

**Step 2: Register in `__init__.py`**

```python
# logic/src/pipeline/rl/core/__init__.py

from logic.src.pipeline.rl.core.my_algorithm import MyAlgorithm

__all__ = [
    # ... existing
    "MyAlgorithm",
]
```

**Step 3: Add to config**

```python
# logic/src/configs/rl.py

@dataclass
class RLConfig:
    algorithm: str = "reinforce"  # Add "my_algorithm" as valid option
    # Add algorithm-specific parameters
    my_param: float = 1.0
    temperature: float = 1.0
```

### 11.4 Adding a New Baseline

**Location:** `logic/src/pipeline/rl/core/baselines.py`

```python
# Add to logic/src/pipeline/rl/core/baselines.py

class MyCustomBaseline(Baseline):
    """
    Custom baseline implementation.
    
    Example: Percentile-based baseline.
    """
    
    def __init__(self, percentile: float = 50.0):
        super().__init__()
        self.percentile = percentile
    
    def setup(
        self,
        policy: nn.Module,
        env: "RL4COEnvBase",
        **kwargs
    ) -> "Baseline":
        """Initialize baseline (called once at start)."""
        self.policy = policy
        self.env = env
        return self
    
    def eval(
        self,
        td: TensorDict,
        reward: torch.Tensor,
        env: Optional["RL4COEnvBase"] = None,
    ) -> torch.Tensor:
        """
        Compute baseline value.
        
        Args:
            td: TensorDict with state
            reward: Batch rewards
            env: Environment (optional)
            
        Returns:
            Baseline values (same shape as reward)
        """
        # Compute percentile
        baseline_val = torch.quantile(reward.float(), self.percentile / 100.0)
        return baseline_val.expand_as(reward)
    
    def epoch_callback(
        self,
        policy: nn.Module,
        epoch: int,
        val_dataset: "TensorDictDataset",
        env: "RL4COEnvBase",
        **kwargs
    ) -> dict:
        """Called at end of each epoch (optional)."""
        return {}  # Return any metrics to log


# Register in BASELINE_REGISTRY
BASELINE_REGISTRY = {
    # ... existing
    "percentile": MyCustomBaseline,
}
```
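To make the percentile baseline concrete, here is a small worked example in plain Python. It reimplements linear-interpolation quantiles (the default interpolation mode of `torch.quantile`) on a hand-picked reward batch; the helper name `percentile` is illustrative only.

```python
import math
from typing import List


def percentile(values: List[float], pct: float) -> float:
    """Percentile with linear interpolation between neighbouring order statistics."""
    ordered = sorted(values)
    pos = (pct / 100.0) * (len(ordered) - 1)
    lo, hi = math.floor(pos), math.ceil(pos)
    frac = pos - lo
    return ordered[lo] + (ordered[hi] - ordered[lo]) * frac


rewards = [4.0, 1.0, 3.0, 2.0]                 # illustrative batch of rewards
baseline = percentile(rewards, 50.0)           # median of the batch
advantages = [r - baseline for r in rewards]   # what the loss would see

print(baseline)    # 2.5
print(advantages)  # [1.5, -1.5, 0.5, -0.5]
```

With `percentile=50.0` the baseline is the batch median, so half of the samples receive a positive advantage. A higher percentile makes the baseline stricter.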

### 11.5 Adding a New Configuration Section

**Location:** `logic/src/configs/`

```python
# logic/src/configs/my_config.py

from dataclasses import dataclass, field
from typing import Optional, List, Dict, Any


@dataclass
class MyFeatureConfig:
    """Configuration for my new feature.
    
    Attributes:
        enabled: Whether feature is enabled
        param1: First parameter
        param2: Second parameter
        options: List of options
    """
    
    enabled: bool = False
    param1: float = 1.0
    param2: int = 10
    options: List[str] = field(default_factory=lambda: ["option1", "option2"])
    advanced_settings: Dict[str, Any] = field(default_factory=dict)
```

**Add to root config:**

```python
# logic/src/configs/__init__.py

from .my_config import MyFeatureConfig

@dataclass
class Config:
    # ... existing fields
    my_feature: MyFeatureConfig = field(default_factory=MyFeatureConfig)
```
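The nested section can then be read and overridden like any other dataclass. This is a minimal standalone sketch: `Config` here contains only the new field, whereas the real root config has many more.

```python
from dataclasses import asdict, dataclass, field, replace
from typing import Any, Dict, List


@dataclass
class MyFeatureConfig:
    enabled: bool = False
    param1: float = 1.0
    param2: int = 10
    options: List[str] = field(default_factory=lambda: ["option1", "option2"])
    advanced_settings: Dict[str, Any] = field(default_factory=dict)


@dataclass
class Config:
    # Only the new field is shown; the real root config has more fields
    my_feature: MyFeatureConfig = field(default_factory=MyFeatureConfig)


default_cfg = Config()

# Build a modified copy without mutating the defaults
custom_cfg = replace(
    default_cfg,
    my_feature=replace(default_cfg.my_feature, enabled=True, param2=20),
)

# default_factory gives every instance its own list
a, b = Config(), Config()
a.my_feature.options.append("option3")

print(asdict(custom_cfg))
print(b.my_feature.options)  # unaffected by the append on `a`
```

Using `default_factory` for lists and dicts avoids the shared-mutable-default problem, which is why the dataclass examples in this section rely on it.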

---

## 12. Hydra + Lightning + TorchRL Best Practices

### 12.1 Hydra Configuration Best Practices

#### Structured Configs with Dataclasses

```python
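from dataclasses import dataclass, field
from typing import List

from omegaconf import MISSING  # For required fields


@dataclass
class OptimizerConfig:
    # Placeholder nested config so the example is self-contained
    name: str = "adam"


@dataclass
class MyConfig: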
    # Required field (must be set)
    name: str = MISSING
    
    # Optional with default
    learning_rate: float = 1e-4
    
    # List with factory
    layers: List[int] = field(default_factory=lambda: [128, 256, 128])
    
    # Nested config
    optimizer: OptimizerConfig = field(default_factory=OptimizerConfig)
```

#### CLI Override Patterns

```bash
# Simple override
python main.py train_lightning model.embed_dim=256

# Nested override
python main.py train_lightning optim.lr_scheduler_kwargs.T_max=100

# List override (use quotes)
python main.py train_lightning 'model.layers=[64,128,64]'

# List-of-strings override (quote the whole argument)
python main.py train_lightning 'rl.gdpo_objective_keys=["cost","overflow"]'

# Boolean flags
python main.py train_lightning train.eval_only=true

# Override with None
python main.py train_lightning train.val_dataset=null

# Multiple overrides
python main.py train_lightning model=am env.name=vrpp train.n_epochs=50
```

#### Multi-run and Sweeps

```bash
# Grid search over multiple values
python main.py train_lightning -m model.embed_dim=64,128,256 optim.lr=1e-3,1e-4

# Range sweep
python main.py train_lightning -m 'train.batch_size=range(64,512,64)'

# Glob sweep over config-group options (matches option names, not files on disk)
python main.py train_lightning -m 'model=glob(*)'
```
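A comma-separated sweep launches one job per element of the Cartesian product of the value lists. The sketch below only mimics that expansion with `itertools.product` to show how quickly the job count grows. `expand_grid` is a hypothetical helper, not part of Hydra, and the ordering of real Hydra jobs may differ.

```python
from itertools import product
from typing import Dict, List


def expand_grid(sweep: Dict[str, List[str]]) -> List[List[str]]:
    """Return one list of `key=value` overrides per grid point."""
    keys = list(sweep)
    combos = product(*(sweep[k] for k in keys))
    return [[f"{k}={v}" for k, v in zip(keys, combo)] for combo in combos]


jobs = expand_grid({
    "model.embed_dim": ["64", "128", "256"],
    "optim.lr": ["1e-3", "1e-4"],
})

for job_id, overrides in enumerate(jobs):
    print(job_id, " ".join(overrides))
print(f"{len(jobs)} jobs in total")
```

Three embedding sizes and two learning rates already yield six runs, so it is worth estimating the grid size before launching a sweep.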

### 12.2 PyTorch Lightning Best Practices

#### Logging Metrics

```python
class MyModule(RL4COLitModule):
    def calculate_loss(self, td, out, batch_idx, env=None):
        # Log scalar metrics
        self.log("train/loss", loss, prog_bar=True)
        self.log("train/reward", reward.mean(), sync_dist=True)
        
        # Log multiple metrics at once
        self.log_dict({
            "train/advantage": advantage.mean(),
            "train/baseline": baseline_val.mean(),
            "train/entropy": entropy.mean(),
        })
        
        # Log with specific settings
        self.log(
            "train/custom",
            value,
            on_step=True,      # Log at each step
            on_epoch=True,     # Also log epoch average
            prog_bar=True,     # Show in progress bar
            logger=True,       # Send to logger
            sync_dist=True,    # Sync across GPUs
        )
        
        return loss
```

#### Checkpointing

```python
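from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

# Save best model by validation reward
checkpoint_callback = ModelCheckpoint(
    dirpath="checkpoints/",
    # Reference the logged metric by its exact name ("val/reward") and disable
    # auto-insertion so the "/" does not end up in the file name
    filename="best-epoch={epoch:02d}-val_reward={val/reward:.4f}",
    auto_insert_metric_name=False,
    monitor="val/reward",
    mode="max",
    save_top_k=3,
    save_last=True,
)

# Register the callback with the trainer
trainer = Trainer(callbacks=[checkpoint_callback])

# Resume from checkpoint
trainer.fit(model, ckpt_path="checkpoints/last.ckpt")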
```

#### Learning Rate Scheduling

```python
class MyModule(RL4COLitModule):
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        
        # Cosine annealing
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.trainer.max_epochs
        )
        
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "interval": "epoch",  # or "step"
                "frequency": 1,
                # "monitor" is only required for ReduceLROnPlateau
            }
        }
```
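For intuition about what the cosine schedule does to the learning rate, the closed-form expression can be evaluated by hand: `eta_t = eta_min + 0.5 * (base_lr - eta_min) * (1 + cos(pi * t / T_max))`. The sketch below computes it in plain Python. The base learning rate and `T_max` are illustrative values, and the helper is not the scheduler's own implementation.

```python
import math


def cosine_annealing_lr(base_lr: float, epoch: int, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + cosine)


base_lr, t_max = 1e-4, 100  # illustrative values
checkpoints = (0, 25, 50, 75, 100)
schedule = [cosine_annealing_lr(base_lr, e, t_max) for e in checkpoints]

for epoch, lr in zip(checkpoints, schedule):
    print(f"epoch {epoch:3d}: lr = {lr:.2e}")
```

The rate stays near `base_lr` early on, passes through half of it at `T_max / 2`, and decays to `eta_min` at the end, which is why `T_max` is usually tied to the total number of epochs.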

### 12.3 TensorDict Best Practices

#### Efficient Batched Operations

```python
from tensordict import TensorDict

# Create TensorDict
td = TensorDict({
    "locs": torch.rand(batch_size, n_nodes, 2),
    "demand": torch.rand(batch_size, n_nodes),
}, batch_size=[batch_size])

# Move to device (moves all tensors)
td = td.to(device)

# Clone for modification
td_clone = td.clone()

# Update in-place
td["visited"] = torch.zeros(batch_size, n_nodes, dtype=torch.bool)

# Batch indexing
subset = td[0:10]  # First 10 samples

# Apply function to all tensors
td = td.apply(lambda x: x.float())
```

#### Environment State Management

```python
# Reset returns TensorDict with initial state
td = env.reset(batch)

# Step applies the action and returns the updated state
td["action"] = action
td = env.step(td)

# Access state components
current_node = td["current_node"]
action_mask = td["action_mask"]
done = td["done"]
```

### 12.4 Common Pitfalls and Solutions

| Pitfall | Solution |
|---------|----------|
| **OOM on large batches** | Reduce `train.batch_size` and set `accumulate_grad_batches` in the trainer to keep the effective batch size |
| **Slow data loading** | Increase `num_workers`, use `pin_memory=True` |
| **NaN in loss** | Check gradient clipping, reduce learning rate |
| **Baseline not updating** | Verify `epoch_callback` is called |
| **Config not merging** | Use `OmegaConf.merge()` for nested configs |
| **GPU memory leak** | Call `.detach()` on baseline values |

### 12.5 Debugging Tips

```python
from pytorch_lightning import Trainer

# Fast dev run (one batch each of train, val, and test)
trainer = Trainer(fast_dev_run=True)

# Limit batches for debugging
trainer = Trainer(
    limit_train_batches=10,
    limit_val_batches=5,
)

# Enable gradient anomaly detection
torch.autograd.set_detect_anomaly(True)

# Profile training
trainer = Trainer(profiler="simple")  # or "advanced"

# Check model device
print(f"Model device: {next(model.parameters()).device}")

# Check TensorDict device
print(f"TD device: {td.device}")
```

### 12.6 Performance Optimization

```python
# Mixed precision training (faster on modern GPUs)
trainer = Trainer(precision="16-mixed")

# Gradient accumulation for effective larger batches
trainer = Trainer(accumulate_grad_batches=4)

# Compile model (PyTorch 2.0+)
model = torch.compile(model)

# Keep workers alive between epochs (requires num_workers > 0) and
# pin host memory for faster host-to-GPU transfer
loader = DataLoader(dataset, num_workers=4, persistent_workers=True, pin_memory=True)
```
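Gradient accumulation trades optimizer steps for memory. The arithmetic below is a rough sketch, assuming the last partial batch is kept and that a leftover accumulation window still triggers an optimizer step; both helper names are hypothetical.

```python
import math


def effective_batch_size(per_device_batch: int, accumulate_grad_batches: int = 1,
                         num_devices: int = 1) -> int:
    """Number of samples contributing to each optimizer step."""
    return per_device_batch * accumulate_grad_batches * num_devices


def optimizer_steps_per_epoch(num_samples: int, per_device_batch: int,
                              accumulate_grad_batches: int = 1,
                              num_devices: int = 1) -> int:
    """Optimizer steps per epoch under the assumptions stated above."""
    batches = math.ceil(num_samples / (per_device_batch * num_devices))
    return math.ceil(batches / accumulate_grad_batches)


# Halving twice the per-device batch while accumulating 4 batches keeps 256 samples per step
print(effective_batch_size(256))                              # 256
print(effective_batch_size(64, accumulate_grad_batches=4))    # 256
print(optimizer_steps_per_epoch(12_800, 64, accumulate_grad_batches=4))  # 50
```

Since the effective batch size is unchanged, the learning rate usually does not need retuning when switching between the two settings, although wall-clock time per optimizer step increases.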

---

## 13. Troubleshooting & Common Patterns

### 13.1 Common Errors and Solutions

#### "CUDA out of memory"
```python
# Solution 1: Reduce batch size (Hydra CLI override)
train.batch_size=128  # Instead of 256

# Solution 2: Use gradient accumulation
trainer = Trainer(accumulate_grad_batches=2)

# Solution 3: Use mixed precision
trainer = Trainer(precision="16-mixed")

# Solution 4: Release cached, unused blocks between batches
# (this does not free memory held by live tensors)
torch.cuda.empty_cache()
```

#### "NaN in loss"
```python
# Check 1: Gradient clipping (Hydra CLI override)
rl.max_grad_norm=0.5  # Reduce from 1.0

# Check 2: Learning rate (Hydra CLI override)
optim.lr=1e-5  # Reduce from 1e-4

# Check 3: Baseline values
baseline_val = baseline_val.detach()  # Must detach!

# Check 4: Log likelihood explosion
log_likelihood = log_likelihood.clamp(min=-100, max=0)
```

#### "Baseline not improving"
```python
# Solution: Relax the t-test threshold (Hydra CLI overrides)
rl.bl_alpha=0.1  # More lenient (default 0.05)

# Or use exponential baseline instead
rl.baseline="exponential"
rl.exp_beta=0.9
```
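To see what `exp_beta` controls, here is a minimal sketch of an exponential moving-average baseline in plain Python. It follows the common formulation `b <- beta * b + (1 - beta) * mean(reward)`; the project's `ExponentialBaseline` may differ in detail, and the class name below is illustrative.

```python
from statistics import fmean
from typing import List, Optional


class ExponentialMovingBaseline:
    """Tracks a moving average of the batch-mean reward."""

    def __init__(self, beta: float = 0.9):
        self.beta = beta
        self.value: Optional[float] = None

    def update(self, rewards: List[float]) -> float:
        batch_mean = fmean(rewards)
        if self.value is None:
            # First batch: initialise with the batch mean
            self.value = batch_mean
        else:
            self.value = self.beta * self.value + (1.0 - self.beta) * batch_mean
        return self.value


baseline = ExponentialMovingBaseline(beta=0.9)
first = baseline.update([10.0, 20.0])   # 15.0
second = baseline.update([25.0, 25.0])  # 0.9 * 15.0 + 0.1 * 25.0 = 16.0
print(first, second)
```

A larger `beta` gives a smoother but slower-moving baseline. Unlike the rollout baseline, it needs no significance test, so it updates on every batch.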

#### "Reward not improving"
```python
# Check 1: Increase exploration (Hydra CLI overrides)
rl.entropy_weight=0.01

# Check 2: Use warmup baseline
rl.bl_warmup_epochs=5

# Check 3: Try different algorithm
rl.algorithm=pomo  # If problem has symmetry
```
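The entropy weight encourages exploration by rewarding less peaked action distributions. The sketch below illustrates the usual entropy-bonus form, `loss = pg_loss - entropy_weight * H(p)`, on two hand-made distributions. How the project's algorithms apply `rl.entropy_weight` internally is an assumption here, and `pg_loss` is a placeholder value.

```python
import math
from typing import List


def entropy(probs: List[float]) -> float:
    """Shannon entropy (in nats) of a categorical distribution."""
    return -sum(p * math.log(p) for p in probs if p > 0.0)


peaked = [0.97, 0.01, 0.01, 0.01]    # near-deterministic policy
uniform = [0.25, 0.25, 0.25, 0.25]   # maximally exploratory policy

entropy_weight = 0.01
pg_loss = 1.0  # placeholder policy-gradient loss, identical for both cases

loss_peaked = pg_loss - entropy_weight * entropy(peaked)
loss_uniform = pg_loss - entropy_weight * entropy(uniform)

print(f"H(peaked)  = {entropy(peaked):.4f}, loss = {loss_peaked:.4f}")
print(f"H(uniform) = {entropy(uniform):.4f}, loss = {loss_uniform:.4f}")
```

For the same policy-gradient term, the more exploratory distribution gets the lower loss, so raising the weight pushes the policy away from premature determinism.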

### 13.2 Recommended Configurations by Problem

#### VRPP (Vehicle Routing with Profits)
```bash
python main.py train_lightning \
    env.name=vrpp \
    env.num_loc=50 \
    model=am \
    model.embed_dim=128 \
    rl.algorithm=reinforce \
    rl.baseline=rollout \
    train.n_epochs=100 \
    train.batch_size=256
```

#### CWCVRP (Capacitated Waste Collection VRP)
```bash
python main.py train_lightning \
    env.name=cwcvrp \
    env.num_loc=50 \
    env.capacity=100 \
    model=am \
    model.encoder_type=gat \
    rl.algorithm=pomo \
    rl.num_augment=8 \
    train.n_epochs=100
```

#### Large-scale (100+ nodes)
```bash
python main.py train_lightning \
    env.num_loc=100 \
    model.embed_dim=256 \
    model.num_encoder_layers=6 \
    train.batch_size=128 \
    train.precision=16-mixed \
    optim.lr=5e-5
```

#### Multi-day temporal
```bash
python main.py train_lightning \
    env.name=sdwcvrp \
    model=temporal \
    model.temporal_horizon=7 \
    train.train_time=true \
    train.eval_time_days=7
```

### 13.3 Quick Reference Cheatsheet

#### All Available Algorithms
| Key | Class | Best For |
|-----|-------|----------|
| `reinforce` | REINFORCE | Simple baseline |
| `ppo` | PPO | Stable training |
| `sapo` | SAPO | Adaptive clipping |
| `gspo` | GSPO | Sequence-level |
| `gdpo` | GDPO | Multi-objective |
| `dr_grpo` | DRGRPO | Divergence-regularized |
| `pomo` | POMO | Symmetric problems |
| `symnco` | SymNCO | Symmetry exploitation |
| `imitation` | ImitationLearning | Expert guidance |
| `adaptive_imitation` | AdaptiveImitation | IL → RL |
| `hrl` | HRLModule | Hierarchical |

#### All Available Baselines
| Key | Class | Description |
|-----|-------|-------------|
| `none` | NoBaseline | No baseline (high variance) |
| `exponential` | ExponentialBaseline | Moving average |
| `rollout` | RolloutBaseline | Greedy rollout |
| `critic` | CriticBaseline | Learned value network |
| `warmup` | WarmupBaseline | Gradual transition |
| `pomo` | POMOBaseline | Multi-start mean |

#### All Available Environments
| Key | Class | Description |
|-----|-------|-------------|
| `vrpp` | VRPPEnv | Vehicle Routing with Profits |
| `cvrpp` | CVRPPEnv | Capacitated VRPP |
| `wcvrp` | WCVRPEnv | Waste Collection VRP |
| `cwcvrp` | CWCVRPEnv | Capacitated WCVRP |
| `sdwcvrp` | SDWCVRPEnv | Stochastic Demand WCVRP |
| `scwcvrp` | SCWCVRPEnv | Selective Capacitated WCVRP |

#### All Available Encoders
| Key | Class | Architecture |
|-----|-------|--------------|
| `gat` | GraphAttentionEncoder | Multi-head GAT |
| `gac` | GraphAttConvEncoder | GAT + edge features |
| `tgc` | TransGraphConvEncoder | Transformer-style |
| `ggac` | GatedGraphAttConvEncoder | Gated GAT |
| `gcn` | GCNEncoder | Standard GCN |
| `mlp` | MLPEncoder | No graph structure |
| `ptr` | PointerEncoder | RNN-based |
| `moe` | MOEEncoder | Mixture of Experts |

#### All Available Models
| Key | Class | Use Case |
|-----|-------|----------|
| `am` | AttentionModel | Standard routing |
| `deep_decoder` | DeepDecoderAM | Complex problems |
| `temporal` | TemporalAM | Multi-day scenarios |
| `pointer` | PointerNetwork | Classic seq2seq |