# WSmart+ Route: Lightning-Based RL Training Tutorial

**Version:** 1.0  
**Last Updated:** January 2026

This tutorial covers the reinforcement learning training pipeline for combinatorial optimization problems, built on **PyTorch Lightning** and **Hydra**.

## Table of Contents

1. [Overview & Architecture](#1-overview--architecture)
2. [Hydra Configuration System](#2-hydra-configuration-system)
3. [PyTorch Lightning Modules](#3-pytorch-lightning-modules)
4. [RL Algorithms](#4-rl-algorithms)
5. [Baselines for Variance Reduction](#5-baselines-for-variance-reduction)
6. [Environments & Data Generation](#6-environments--data-generation)
7. [Meta-Learning](#7-meta-learning)
8. [Hyperparameter Optimization](#8-hyperparameter-optimization)
9. [Practical Examples](#9-practical-examples)
10. [Advanced Topics](#10-advanced-topics)

In [None]:
# Standard notebook setup
from notebook_setup import setup_google_colab, setup_home_directory

NOTEBOOK_NAME = "lightning_rl_training_tutorial"
home_dir = setup_home_directory(NOTEBOOK_NAME)
IN_COLAB, gdrive, gfiles = setup_google_colab(NOTEBOOK_NAME)

---

## 1. Overview & Architecture

The WSmart+ Route training pipeline is built on three foundational technologies:

### 1.1 Core Technologies

| Technology | Purpose | Key Benefits |
|------------|---------|-------------|
| **PyTorch Lightning** | Training orchestration | Automatic GPU management, logging, checkpointing |
| **Hydra** | Configuration management | Hierarchical configs, CLI overrides, experiment tracking |
| **TensorDict** | State management | Efficient batched operations, device-agnostic tensors |

### 1.2 Pipeline Architecture

```
┌─────────────────────────────────────────────────────────────────┐
│                      train_lightning.py                         │
│                    (Hydra CLI Entry Point)                      │
└─────────────────────────────────────────────────────────────────┘
                              │
              ┌───────────────┼───────────────┐
              │               │               │
              ▼               ▼               ▼
┌─────────────────┐ ┌─────────────────┐ ┌─────────────────┐
│   Config        │ │   Environment   │ │   RL Module     │
│   (Hydra)       │ │   (RL4COEnv)    │ │   (Lightning)   │
└─────────────────┘ └─────────────────┘ └─────────────────┘
                              │
                              ▼
              ┌───────────────────────────────┐
              │        WSTrainer              │
              │   (PyTorch Lightning)         │
              └───────────────────────────────┘
                              │
                              ▼
              ┌───────────────────────────────┐
              │   Training Loop               │
              │   - Data Generation           │
              │   - Forward Pass              │
              │   - Loss Computation          │
              │   - Baseline Updates          │
              └───────────────────────────────┘
```

### 1.3 Directory Structure

```
logic/src/pipeline/rl/
├── core/                    # RL algorithms
│   ├── base.py             # RL4COLitModule (base Lightning module)
│   ├── baselines.py        # Variance reduction baselines
│   ├── reinforce.py        # REINFORCE algorithm
│   ├── ppo.py              # Proximal Policy Optimization
│   ├── sapo.py             # Soft Adaptive PPO
│   ├── gspo.py             # Group Sequence PO
│   ├── dr_grpo.py          # Divergence-Regularized GRPO
│   ├── gdpo.py             # Group Distributional PO
│   ├── pomo.py             # Policy Optimization Multiple Optima
│   ├── symnco.py           # Symmetry-aware NCO
│   ├── hrl.py              # Hierarchical RL
│   ├── imitation.py        # Imitation Learning
│   └── adaptive_imitation.py
├── meta/                    # Meta-learning
│   ├── module.py           # MetaRLModule wrapper
│   ├── weight_optimizer.py
│   ├── contextual_bandits.py
│   ├── td_learning.py
│   └── multi_objective.py
├── hpo/                     # Hyperparameter optimization
│   ├── optuna_hpo.py
│   └── dehb.py
└── features/                # Training utilities
    ├── epoch.py
    ├── post_processing.py
    └── time_training.py
```

---

## 2. Hydra Configuration System

The training pipeline uses **Hydra** for configuration management, enabling:
- Hierarchical configuration with dataclasses
- Command-line overrides
- Experiment tracking and reproducibility

### 2.1 Configuration Dataclasses

All configurations are defined in `logic/src/configs/__init__.py`:

In [None]:
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

# Examine the configuration structure
from logic.src.configs import (
    Config,
    EnvConfig,
    ModelConfig,
    TrainConfig,
    OptimConfig,
    RLConfig,
    HPOConfig,
)

from dataclasses import MISSING, fields

# Print available config sections
print("Configuration Sections:")
print("=" * 50)
for config_cls in [EnvConfig, ModelConfig, TrainConfig, OptimConfig, RLConfig, HPOConfig]:
    print(f"\n{config_cls.__name__}:")
    # dataclasses.fields() also reports default_factory fields, which are
    # not exposed as class attributes and would otherwise show "<no default>"
    for f in fields(config_cls):
        if f.default is not MISSING:
            default = f.default
        elif f.default_factory is not MISSING:
            default = f"<factory: {getattr(f.default_factory, '__name__', f.default_factory)}>"
        else:
            default = "<no default>"
        print(f"  {f.name}: {f.type} = {default}")

### 2.2 Configuration Sections Explained

#### EnvConfig - Environment Settings
```python
@dataclass
class EnvConfig:
    name: str = "vrpp"           # Problem type: vrpp, wcvrp, cwcvrp, sdwcvrp
    num_loc: int = 50             # Number of locations (nodes)
    min_loc: float = 0.0          # Min coordinate value
    max_loc: float = 1.0          # Max coordinate value
    capacity: Optional[float] = None  # Vehicle capacity
    overflow_penalty: float = 1.0     # Penalty for bin overflow
    collection_reward: float = 1.0    # Reward for waste collection
    cost_weight: float = 1.0          # Weight for travel cost
    prize_weight: float = 1.0         # Weight for prizes (VRPP)
```

#### ModelConfig - Neural Network Architecture
```python
@dataclass
class ModelConfig:
    name: str = "am"              # Model: am, deep_decoder, temporal, pointer, symnco
    embed_dim: int = 128          # Embedding dimension
    hidden_dim: int = 512         # Hidden layer dimension
    num_encoder_layers: int = 3   # Number of encoder layers
    num_decoder_layers: int = 3   # Number of decoder layers (deep_decoder)
    num_heads: int = 8            # Attention heads
    encoder_type: str = "gat"     # Encoder type: gat, gcn, mlp
    normalization: str = "instance"  # Normalization: instance, batch, layer
    activation: str = "gelu"         # Activation function
    dropout: float = 0.1             # Dropout rate
```

#### RLConfig - RL Algorithm Settings
```python
@dataclass
class RLConfig:
    algorithm: str = "reinforce"  # reinforce, ppo, sapo, gspo, pomo, symnco, hrl
    baseline: str = "rollout"     # none, exponential, rollout, critic, pomo
    entropy_weight: float = 0.0   # Entropy regularization
    max_grad_norm: float = 1.0    # Gradient clipping
    
    # PPO-specific
    ppo_epochs: int = 10          # Inner PPO epochs
    eps_clip: float = 0.2         # Clipping epsilon
    value_loss_weight: float = 0.5
    
    # Meta-learning
    use_meta: bool = False        # Enable meta-learning wrapper
    meta_strategy: str = "rnn"    # rnn, bandit, morl, tdl, hypernet
```

### 2.3 Command-Line Interface Usage

The main entry point is `train_lightning.py`, accessed via:

```bash
python main.py train_lightning [CONFIG_OVERRIDES]
```

#### Basic Training Examples

```bash
# Train Attention Model on VRPP with 50 nodes
python main.py train_lightning model=am env.name=vrpp env.num_loc=50

# Train with PPO algorithm
python main.py train_lightning rl.algorithm=ppo rl.ppo_epochs=10

# Use different baseline
python main.py train_lightning rl.baseline=exponential

# Custom training parameters
python main.py train_lightning \
    train.n_epochs=100 \
    train.batch_size=256 \
    optim.lr=1e-4
```
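
To make the override syntax concrete, the following stdlib-only sketch shows how a dotted override such as `env.num_loc=20` can map onto nested dataclasses. It is an illustration of the idea, not Hydra's implementation: `EnvCfg`, `OptimCfg`, `Cfg`, and `apply_override` are hypothetical names, and the type casting only handles `int`, `float`, and `str` fields.

```python
from dataclasses import dataclass, field


@dataclass
class EnvCfg:  # hypothetical, trimmed-down stand-in for EnvConfig
    name: str = "vrpp"
    num_loc: int = 50


@dataclass
class OptimCfg:  # hypothetical stand-in for OptimConfig
    lr: float = 1e-4


@dataclass
class Cfg:
    env: EnvCfg = field(default_factory=EnvCfg)
    optim: OptimCfg = field(default_factory=OptimCfg)


def apply_override(cfg, override):
    """Apply one 'section.key=value' override to a nested dataclass config."""
    path, raw_value = override.split("=", 1)
    *parents, leaf = path.split(".")
    node = cfg
    for name in parents:
        node = getattr(node, name)  # walk down to the owning section
    current = getattr(node, leaf)  # raises AttributeError for unknown keys
    # Cast the string to the type of the existing default (int, float, str only).
    setattr(node, leaf, type(current)(raw_value))
    return cfg


cfg = Cfg()
for item in ["env.name=wcvrp", "env.num_loc=20", "optim.lr=1e-3"]:
    apply_override(cfg, item)
print(cfg)
```

Rejecting unknown keys mirrors the benefit of structured configs: a typo in an override fails fast instead of being silently ignored.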

#### Advanced Configurations

```bash
# POMO with data augmentation
python main.py train_lightning \
    rl.algorithm=pomo \
    rl.num_augment=8 \
    rl.augment_fn=dihedral8

# Hierarchical RL
python main.py train_lightning \
    rl.algorithm=hrl \
    rl.meta_hidden_dim=128

# With Meta-Learning wrapper
python main.py train_lightning \
    rl.use_meta=true \
    rl.meta_strategy=rnn \
    rl.meta_lr=1e-3

# Hyperparameter optimization
python main.py train_lightning \
    hpo.n_trials=50 \
    hpo.method=tpe
```

In [None]:
# Programmatically create a configuration
import torch  # needed for the device check below

from logic.src.configs import Config, EnvConfig, ModelConfig, TrainConfig, OptimConfig, RLConfig

# Create custom configuration
cfg = Config(
    env=EnvConfig(
        name="vrpp",
        num_loc=50,
        capacity=100.0,
    ),
    model=ModelConfig(
        name="am",
        embed_dim=128,
        num_encoder_layers=3,
        num_heads=8,
    ),
    train=TrainConfig(
        n_epochs=10,
        batch_size=256,
        train_data_size=10000,
    ),
    optim=OptimConfig(
        optimizer="adam",
        lr=1e-4,
    ),
    rl=RLConfig(
        algorithm="reinforce",
        baseline="rollout",
    ),
    seed=42,
    device="cuda" if torch.cuda.is_available() else "cpu",
)

print(f"Configuration created:")
print(f"  Environment: {cfg.env.name} with {cfg.env.num_loc} locations")
print(f"  Model: {cfg.model.name} with {cfg.model.embed_dim}d embeddings")
print(f"  Algorithm: {cfg.rl.algorithm} with {cfg.rl.baseline} baseline")
print(f"  Device: {cfg.device}")

---

## 3. PyTorch Lightning Modules

All RL algorithms inherit from `RL4COLitModule`, which provides:
- Training/validation/test loops
- Automatic optimizer configuration
- Data loading with generators
- Baseline integration
- Metric logging

### 3.1 Base Module Architecture

In [None]:
from logic.src.pipeline.rl.core.base import RL4COLitModule
import inspect

# Show the base class signature
print("RL4COLitModule.__init__ signature:")
print(inspect.signature(RL4COLitModule.__init__))

print("\nKey Methods:")
for name, method in inspect.getmembers(RL4COLitModule, predicate=inspect.isfunction):
    if not name.startswith('_') or name in ['__init__']:
        doc = method.__doc__
        first_line = doc.split('\n')[0].strip() if doc else "No docstring"
        print(f"  {name}: {first_line}")

### 3.2 Training Lifecycle

```python
class RL4COLitModule(pl.LightningModule, ABC):
    """
    Training Lifecycle:
    
    1. setup(stage='fit'):
       - Create train_dataset (GeneratorDataset)
       - Create val_dataset
       
    2. for epoch in range(n_epochs):
       
       3. on_train_epoch_start():
          - Wrap dataset with baseline (if RolloutBaseline)
          
       4. for batch in train_dataloader():
          
          5. training_step(batch):
             - Unwrap batch (baseline values)
             - shared_step():
               a. env.reset(batch)
               b. policy(td, env, decode_type="sampling")
               c. calculate_loss()  # Algorithm-specific
             - return loss
             
          6. on_before_optimizer_step():
             - Gradient clipping
             
       7. validation_step(batch):
          - shared_step() with decode_type="greedy"
          
       8. on_train_epoch_end():
          - baseline.epoch_callback()
          - Optionally regenerate dataset
    """
```

### 3.3 Key Methods Explained

#### `shared_step()` - Common Training Logic
```python
def shared_step(self, batch, batch_idx, phase):
    # 1. Unwrap baseline values if present
    batch, baseline_val = self.baseline.unwrap_batch(batch)
    
    # 2. Move to device
    batch = batch.to(self.device)
    
    # 3. Reset environment with batch
    td = self.env.reset(batch)
    
    # 4. Run policy
    out = self.policy(
        td, self.env,
        decode_type="sampling" if phase == "train" else "greedy"
    )
    
    # 5. Compute loss (training only)
    if phase == "train":
        out["loss"] = self.calculate_loss(td, out, batch_idx)
    
    # 6. Log metrics
    self.log(f"{phase}/reward", out["reward"].mean())
    
    return out
```

#### `calculate_loss()` - Algorithm-Specific
This is the **abstract method** that each RL algorithm must implement.
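
The relationship between `shared_step()` and `calculate_loss()` is a template-method pattern: the base class owns the step logic and defers only the loss to subclasses. A minimal sketch of that pattern, using plain Python lists instead of tensors, might look as follows. `ToyRLModule` and `ToyReinforce` are hypothetical names and do not reproduce the real `RL4COLitModule` signature.

```python
from abc import ABC, abstractmethod


class ToyRLModule(ABC):
    """Hypothetical base class: owns the step logic, defers the loss."""

    def shared_step(self, out, phase):
        # The loss is only needed during training.
        if phase == "train":
            out["loss"] = self.calculate_loss(out)
        return out

    @abstractmethod
    def calculate_loss(self, out):
        """Algorithm-specific loss; every subclass must implement it."""


class ToyReinforce(ToyRLModule):
    def __init__(self, baseline=0.0):
        self.baseline = baseline

    def calculate_loss(self, out):
        # REINFORCE: negative mean of advantage * log-likelihood.
        advantages = [r - self.baseline for r in out["reward"]]
        terms = [a * ll for a, ll in zip(advantages, out["log_likelihood"])]
        return -sum(terms) / len(terms)


module = ToyReinforce(baseline=3.0)
train_out = module.shared_step(
    {"reward": [2.0, 4.0], "log_likelihood": [-1.0, -2.0]}, phase="train"
)
val_out = module.shared_step(
    {"reward": [2.0, 4.0], "log_likelihood": [-1.0, -2.0]}, phase="val"
)
print(train_out["loss"], "loss" in val_out)
```

Because `calculate_loss` is abstract, instantiating the base class directly raises a `TypeError`, which catches incomplete algorithm implementations early.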

In [None]:
# Example: REINFORCE loss implementation
from logic.src.pipeline.rl.core.reinforce import REINFORCE

# Show the calculate_loss implementation
print("REINFORCE.calculate_loss source:")
print(inspect.getsource(REINFORCE.calculate_loss))

---

## 4. RL Algorithms

The pipeline supports multiple RL algorithms, each suited for different scenarios:

| Algorithm | Best For | Key Features |
|-----------|----------|-------------|
| **REINFORCE** | Simple baseline | Standard policy gradient |
| **PPO** | Stable training | Clipped surrogate, multiple epochs |
| **SAPO** | Adaptive clipping | Soft gating instead of hard clip |
| **GSPO** | Sequence-level | Group-based importance ratios |
| **POMO** | Multiple solutions | Data augmentation, multi-start |
| **SymNCO** | Symmetry exploitation | Invariance-aware training |
| **HRL** | Hierarchical decisions | Manager-Worker architecture |
| **Imitation** | Expert guidance | Learn from HGS/ALNS experts |

### 4.1 REINFORCE with Baseline

In [None]:
from logic.src.pipeline.rl import REINFORCE
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy

# Create environment
env = get_env("vrpp", num_loc=20, device="cpu")

# Create policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# Create REINFORCE module
reinforce_module = REINFORCE(
    env=env,
    policy=policy,
    baseline="exponential",  # Using exponential baseline for demo
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
    entropy_weight=0.01,
    max_grad_norm=1.0,
)

print(f"REINFORCE module created:")
print(f"  Environment: {env.__class__.__name__}")
print(f"  Policy parameters: {sum(p.numel() for p in policy.parameters()):,}")
print(f"  Baseline: {reinforce_module.baseline_type}")

### 4.2 PPO (Proximal Policy Optimization)

PPO performs multiple optimization epochs per batch with a clipped surrogate objective:

$$\mathcal{L}^{CLIP}(\theta) = \mathbb{E}\left[\min\left(r_t(\theta)\hat{A}_t, \text{clip}(r_t(\theta), 1-\epsilon, 1+\epsilon)\hat{A}_t\right)\right]$$

where $r_t(\theta) = \frac{\pi_\theta(a_t|s_t)}{\pi_{\theta_{old}}(a_t|s_t)}$
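
As a worked illustration of the formula, the sketch below evaluates the clipped surrogate on scalar values in plain Python so the arithmetic is easy to follow. The function name `clipped_surrogate` is hypothetical. A tensor implementation would apply the same operations element-wise and minimize the negated objective.

```python
import math


def clipped_surrogate(logp_new, logp_old, advantages, eps_clip=0.2):
    """Mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A) over the samples."""
    terms = []
    for lp_new, lp_old, adv in zip(logp_new, logp_old, advantages):
        ratio = math.exp(lp_new - lp_old)  # r_t(theta)
        clipped = min(max(ratio, 1.0 - eps_clip), 1.0 + eps_clip)
        terms.append(min(ratio * adv, clipped * adv))
    return sum(terms) / len(terms)


# A ratio of 1.5 with a positive advantage is capped at 1 + eps = 1.2.
capped = clipped_surrogate([math.log(1.5)], [0.0], [1.0])
# With a negative advantage, the unclipped (more pessimistic) term is kept.
uncapped = clipped_surrogate([math.log(1.5)], [0.0], [-1.0])
print(capped, uncapped)
```

The `min` makes the objective a pessimistic bound: the policy gains nothing from pushing the ratio outside the clip range in the direction that would improve the objective.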

In [None]:
from logic.src.pipeline.rl import PPO
from logic.src.models.policies.critic import create_critic_from_actor

# Create critic from actor architecture
critic = create_critic_from_actor(
    policy,
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_layers=2,
    n_heads=4,
)

# Create PPO module
ppo_module = PPO(
    env=env,
    policy=policy,
    critic=critic,
    ppo_epochs=10,           # Inner optimization epochs
    eps_clip=0.2,            # Clipping epsilon
    value_loss_weight=0.5,   # Critic loss weight
    mini_batch_size=0.25,    # 25% of batch per mini-batch
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
)

print(f"PPO module created:")
print(f"  Actor parameters: {sum(p.numel() for p in policy.parameters()):,}")
print(f"  Critic parameters: {sum(p.numel() for p in critic.parameters()):,}")
print(f"  PPO epochs: {ppo_module.ppo_epochs}")
print(f"  Clip epsilon: {ppo_module.eps_clip}")

### 4.3 POMO (Policy Optimization with Multiple Optima)

POMO exploits problem symmetries through:
1. **Data Augmentation**: Dihedral transformations (rotations, reflections)
2. **Multi-start Decoding**: Try multiple starting nodes
3. **Shared Baseline**: Mean reward across all augmentations/starts
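
The first and third ingredients can be sketched in a few lines of plain Python. The eight transformations below are the usual symmetries of the unit square. Whether the pipeline's `dihedral8` function uses exactly this set and ordering is an assumption, and both function names are illustrative.

```python
import math


def dihedral8(x, y):
    """Return the 8 symmetric images of a point in the unit square."""
    return [
        (x, y), (y, x),
        (x, 1 - y), (y, 1 - x),
        (1 - x, y), (1 - y, x),
        (1 - x, 1 - y), (1 - y, 1 - x),
    ]


def shared_baseline_advantages(rewards):
    """rewards[i][k] is the reward of rollout k (start or augmentation) on instance i."""
    advantages = []
    for row in rewards:
        baseline = sum(row) / len(row)  # shared across the instance's rollouts
        advantages.append([r - baseline for r in row])
    return advantages


# Each transformation preserves distances, so tour lengths are unchanged.
a, b = (0.2, 0.7), (0.9, 0.4)
original = math.dist(a, b)
for pa, pb in zip(dihedral8(*a), dihedral8(*b)):
    assert math.isclose(math.dist(pa, pb), original)

print(shared_baseline_advantages([[1.0, 2.0, 3.0], [4.0, 4.0, 4.0]]))
```

Because every augmented copy has the same optimal cost as the original instance, the mean reward over the copies is a meaningful per-instance baseline.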

In [None]:
from logic.src.pipeline.rl import POMO

# Create POMO module
pomo_module = POMO(
    env=env,
    policy=policy,
    num_augment=8,           # Dihedral group D8
    augment_fn="dihedral8",  # Augmentation function
    num_starts=None,         # Defaults to num_loc
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
)

print(f"POMO module created:")
print(f"  Number of augmentations: {pomo_module.num_augment}")
print(f"  Augmentation function: {pomo_module.augment_fn}")
print(f"  Number of starts: {pomo_module.num_starts or 'auto (num_loc)'}")

### 4.4 Algorithm Selection Guide

```
┌─────────────────────────────────────────────────────────────────┐
│                    Algorithm Selection Tree                      │
└─────────────────────────────────────────────────────────────────┘
                              │
          ┌───────────────────┼───────────────────┐
          │                   │                   │
    Need stability?    Need multiple      Learning from
          │              solutions?          expert?
          │                   │                   │
          ▼                   ▼                   ▼
    ┌─────────┐        ┌─────────┐        ┌─────────────┐
    │   PPO   │        │  POMO   │        │  Imitation  │
    │  SAPO   │        │ SymNCO  │        │  Adaptive   │
    │  GSPO   │        └─────────┘        └─────────────┘
    └─────────┘
          │
          │ Simple case?
          ▼
    ┌─────────┐
    │REINFORCE│
    └─────────┘
```

**Recommendations:**
- **Start with REINFORCE + Rollout baseline** for initial experiments
- **Use PPO** when training is unstable or gradients are noisy
- **Use POMO** for problems with symmetric solutions (TSP, VRP)
- **Use Imitation + Adaptive** to bootstrap from classical solvers

---

## 5. Baselines for Variance Reduction

Baselines reduce variance in policy gradient estimates:

$$\nabla_\theta J(\theta) = \mathbb{E}\left[(R - b(s)) \nabla_\theta \log \pi_\theta(a|s)\right]$$
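
A small exact calculation shows why this helps. The sketch below assumes a hypothetical two-action policy with $\pi(a{=}1) = \sigma(\theta)$ and enumerates both actions, so the mean and variance of the single-sample gradient estimator are computed without sampling. The function name is illustrative.

```python
def grad_estimator_stats(p, rewards, baseline):
    """Exact mean and variance of (R - b) * d/dtheta log pi(a) for a Bernoulli policy.

    p is pi(a=1) = sigmoid(theta); rewards = (R for a=0, R for a=1).
    """
    probs = (1.0 - p, p)  # pi(a=0), pi(a=1)
    scores = (-p, 1.0 - p)  # d/dtheta log pi(a) for a=0 and a=1
    samples = [(r - baseline) * s for r, s in zip(rewards, scores)]
    mean = sum(q * g for q, g in zip(probs, samples))
    var = sum(q * g * g for q, g in zip(probs, samples)) - mean ** 2
    return mean, var


mean_none, var_none = grad_estimator_stats(0.5, (8.0, 10.0), baseline=0.0)
mean_bl, var_bl = grad_estimator_stats(0.5, (8.0, 10.0), baseline=9.0)
print(f"no baseline:   mean={mean_none}, variance={var_none}")
print(f"with baseline: mean={mean_bl}, variance={var_bl}")
```

Both estimators have the same expected gradient, but the baseline close to the average reward removes the variance contributed by the reward offset shared by both actions.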

### 5.1 Available Baselines

In [None]:
from logic.src.pipeline.rl.core.baselines import (
    BASELINE_REGISTRY,
    NoBaseline,
    ExponentialBaseline,
    RolloutBaseline,
    CriticBaseline,
    WarmupBaseline,
    POMOBaseline,
    get_baseline,
)

print("Available Baselines:")
print("="*60)
for name, cls in BASELINE_REGISTRY.items():
    doc = cls.__doc__ or "No description"
    first_line = doc.split('\n')[0].strip()
    print(f"  {name:15} - {first_line}")

### 5.2 Baseline Comparison

| Baseline | Formula | Pros | Cons |
|----------|---------|------|------|
| **none** | $b = 0$ | Simple | High variance |
| **exponential** | $b = \beta \cdot b + (1-\beta) \cdot \bar{R}$ | Low compute | Instance-independent, lags behind the policy |
| **rollout** | $b = R^{greedy}_{\pi_{old}}$ | Unbiased | Expensive |
| **critic** | $b = V_\phi(s)$ | Learned | Requires critic network |
| **pomo** | $b = \text{mean}(R_{starts})$ | Multi-solution | POMO-specific |
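
The exponential row can be illustrated with a short sketch of the moving-average update. `ExponentialMovingBaseline` is a hypothetical class, not the pipeline's `ExponentialBaseline`, and initializing the value with the first batch mean is an assumption.

```python
class ExponentialMovingBaseline:
    """Hypothetical sketch of b <- beta * b + (1 - beta) * mean(R)."""

    def __init__(self, beta=0.8):
        self.beta = beta
        self.value = None  # no estimate before the first batch

    def update(self, batch_rewards):
        mean_reward = sum(batch_rewards) / len(batch_rewards)
        if self.value is None:
            # Assumption: start from the first batch mean instead of zero.
            self.value = mean_reward
        else:
            self.value = self.beta * self.value + (1.0 - self.beta) * mean_reward
        return self.value


baseline = ExponentialMovingBaseline(beta=0.8)
first = baseline.update([10.0, 10.0])
second = baseline.update([20.0, 20.0])
print(first, second)
```

A single scalar is shared by every instance in the batch, which is why this baseline is cheap but cannot account for some instances being harder than others.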

### 5.3 Rollout Baseline Deep Dive

The **RolloutBaseline** is the default baseline (`rl.baseline=rollout`) and works in four steps:

1. **Pre-compute**: At epoch start, run greedy rollout on training data
2. **Store**: Keep baseline values alongside training samples
3. **Use**: During training, advantage = reward - stored_baseline
4. **Update**: Periodically update baseline policy if improved
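
The four steps can be sketched with a toy policy in plain Python. This is a simplified illustration: `ToyPolicy` and `ToyRolloutBaseline` are hypothetical, and the update check compares validation means only, whereas the pipeline's baseline applies a significance test.

```python
import copy
from statistics import mean


class ToyPolicy:
    """Hypothetical policy: the greedy reward on instance x is gain * x."""

    def __init__(self, gain):
        self.gain = gain

    def __call__(self, x):
        return self.gain * x


class ToyRolloutBaseline:
    def __init__(self, policy):
        self.baseline_policy = copy.deepcopy(policy)  # frozen snapshot

    def wrap_dataset(self, dataset):
        # Steps 1-2: pre-compute the greedy reward and store it with each sample.
        return [(x, self.baseline_policy(x)) for x in dataset]

    def epoch_callback(self, policy, val_dataset):
        # Step 4: adopt the current policy if it beats the snapshot on validation.
        # Simplification: compares means only, without a significance test.
        candidate = mean(policy(x) for x in val_dataset)
        incumbent = mean(self.baseline_policy(x) for x in val_dataset)
        if candidate > incumbent:
            self.baseline_policy = copy.deepcopy(policy)
            return True
        return False


policy = ToyPolicy(gain=1.0)
baseline = ToyRolloutBaseline(policy)
wrapped = baseline.wrap_dataset([1.0, 2.0])

# Step 3: advantage = sampled reward - stored baseline value.
x, stored = wrapped[1]
advantage = 3.0 - stored

policy.gain = 2.0  # pretend training improved the policy
updated = baseline.epoch_callback(policy, val_dataset=[1.0, 3.0])
print(advantage, updated, baseline.baseline_policy.gain)
```

The deep copy matters: the baseline must stay fixed while the training policy changes, otherwise the advantage would collapse toward zero.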

In [None]:
# Demonstrate RolloutBaseline usage
from logic.src.pipeline.rl.core.baselines import RolloutBaseline

# Create baseline with policy
rollout_bl = RolloutBaseline(
    policy=policy,
    update_every=1,   # Check for update every epoch
    bl_alpha=0.05,    # Significance level for the t-test
)

print("RolloutBaseline:")
print(f"  Update frequency: every {rollout_bl.update_every} epoch(s)")
print(f"  t-test alpha: {rollout_bl.bl_alpha}")
print(f"  Has baseline policy: {rollout_bl.baseline_policy is not None}")

# Key methods
print("\nKey Methods:")
print("  wrap_dataset(dataset, policy, env) - Pre-compute baseline values")
print("  unwrap_batch(batch) -> (data, baseline_val) - Extract baseline")
print("  eval(td, reward, env) - Compute baseline on-the-fly")
print("  epoch_callback(policy, epoch, val_dataset, env) - Update check")

### 5.4 Warmup Baseline

The **WarmupBaseline** provides a gradual transition from the exponential baseline to the target baseline:

$$b_{\text{warmup}} = \alpha \cdot b_{\text{target}} + (1 - \alpha) \cdot b_{\text{exponential}}$$

where $\alpha$ increases from 0 to 1 over warmup epochs.

In [None]:
from logic.src.pipeline.rl.core.baselines import WarmupBaseline, RolloutBaseline

# Create warmup baseline wrapping rollout
target_baseline = RolloutBaseline(policy=policy)
warmup_bl = WarmupBaseline(
    baseline=target_baseline,
    warmup_epochs=5,
    beta=0.8,  # Exponential baseline decay factor
)

print(f"WarmupBaseline Configuration:")
print(f"  Target: {target_baseline.__class__.__name__}")
print(f"  Warmup epochs: {warmup_bl.warmup_epochs}")
print(f"  Current alpha: {warmup_bl.alpha}")
print(f"\nBlending schedule:")
for epoch in range(6):
    alpha = min(1.0, (epoch + 1) / warmup_bl.warmup_epochs)
    # round() avoids truncation artifacts such as int((1 - 0.8) * 100) == 19
    print(f"  Epoch {epoch}: α = {alpha:.2f} ({round(alpha * 100)}% target, {round((1 - alpha) * 100)}% exponential)")

---

## 6. Environments & Data Generation

The environment system provides:
- Problem-specific state transitions
- Reward computation
- Action masking for valid moves
- On-the-fly data generation
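
To show how an action mask is typically consumed, the sketch below applies a boolean mask before a softmax so that invalid moves receive zero probability. It is a generic plain-Python illustration: the function name is hypothetical and the pipeline's policies may apply the mask differently.

```python
import math


def masked_softmax(logits, mask):
    """Softmax over logits where mask[i] is False for invalid actions."""
    if not any(mask):
        raise ValueError("at least one action must be valid")
    # Invalid actions get -inf, which becomes probability 0 after exp().
    masked = [l if valid else -math.inf for l, valid in zip(logits, mask)]
    peak = max(masked)  # subtract the max for numerical stability
    exps = [math.exp(v - peak) for v in masked]
    total = sum(exps)
    return [e / total for e in exps]


probs = masked_softmax([1.0, 1.0, 5.0], [True, True, False])
print(probs)
```

Masking before normalization guarantees that sampling and greedy decoding can never select an infeasible node, regardless of how large its logit is.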

### 6.1 Environment Registry

In [None]:
from logic.src.envs import ENV_REGISTRY, get_env

print("Available Environments:")
print("="*60)
for name, cls in ENV_REGISTRY.items():
    doc = cls.__doc__ or "No description"
    # Strip first so docstrings that start on a new line still yield a summary
    first_line = doc.strip().split('\n')[0].strip()
    print(f"  {name:10} - {cls.__name__}: {first_line}")

### 6.2 Environment Types

| Environment | Description | Key Features |
|-------------|-------------|-------------|
| **vrpp** | Vehicle Routing with Profits | Prizes at nodes, maximize profit-cost |
| **cvrpp** | Capacitated VRPP | + Vehicle capacity constraints |
| **wcvrp** | Waste Collection VRP | Bin fill levels, accumulation |
| **cwcvrp** | Capacitated WCVRP | + Vehicle capacity |
| **sdwcvrp** | Stochastic Demand WCVRP | + Uncertain waste generation |

### 6.3 TensorDict State Structure

All environments use TensorDict for state management:

In [None]:
from logic.src.envs import get_env

# Create VRPP environment
env = get_env(
    "vrpp",
    num_loc=20,
    min_loc=0.0,
    max_loc=1.0,
    device="cpu",
)

# Generate a batch of instances
batch_size = 4
td = env.generator(batch_size)

print("Generated TensorDict:")
print(f"  Batch size: {batch_size}")
print(f"  Keys: {list(td.keys())}")
print(f"\nShapes:")
for key in td.keys():
    print(f"  {key}: {td[key].shape}")

In [None]:
# Demonstrate environment step
td_reset = env.reset(td)

print("After reset:")
print(f"  Keys: {list(td_reset.keys())}")
print(f"  Done: {td_reset['done'] if 'done' in td_reset.keys() else 'Not set'}")

# Show action mask
if 'action_mask' in td_reset.keys():
    mask = td_reset['action_mask']
    print(f"\nAction mask shape: {mask.shape}")
    print(f"  Valid actions (first instance): {mask[0].sum().item()}/{mask[0].numel()}")

### 6.4 Data Generators

Generators create problem instances on-the-fly:

In [None]:
from logic.src.envs import GENERATOR_REGISTRY, get_generator

print("Available Generators:")
for name in GENERATOR_REGISTRY.keys():
    print(f"  {name}")

# Create VRPP generator
generator = get_generator(
    "vrpp",
    num_loc=50,
    min_loc=0.0,
    max_loc=1.0,
)

# Generate instances
instances = generator(batch_size=8)
print(f"\nGenerated {instances.batch_size[0]} instances with {instances['locs'].shape[-2]} locations")

### 6.5 Dataset Classes

The pipeline uses custom dataset classes:

```python
from logic.src.data.datasets import (
    GeneratorDataset,      # On-the-fly generation
    TensorDictDataset,     # Persistent storage
    BaselineDataset,       # Wraps dataset with baseline values
    tensordict_collate_fn, # Custom collation for TensorDict
)
```

In [None]:
from logic.src.data.datasets import GeneratorDataset, tensordict_collate_fn
from torch.utils.data import DataLoader

# Create dataset
dataset = GeneratorDataset(generator, size=1000)

# Create dataloader
dataloader = DataLoader(
    dataset,
    batch_size=64,
    collate_fn=tensordict_collate_fn,
    num_workers=0,
)

# Get a batch
batch = next(iter(dataloader))
print(f"Batch keys: {list(batch.keys())}")
print(f"Batch size: {batch.batch_size}")

---

## 7. Meta-Learning

Meta-learning enables automatic adaptation of training parameters (e.g., reward weights) through bi-level optimization.

### 7.1 Meta-Learning Wrapper

In [None]:
from logic.src.pipeline.rl import MetaRLModule

# Create base RL module
base_module = REINFORCE(
    env=env,
    policy=policy,
    baseline="exponential",
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=1000,
    val_data_size=100,
    batch_size=64,
)

# Wrap with meta-learning
meta_module = MetaRLModule(
    agent=base_module,
    meta_lr=1e-3,
    history_length=10,
    hidden_size=64,
)

print("MetaRLModule created:")
print(f"  Inner agent: {base_module.__class__.__name__}")
print(f"  Meta learning rate: {meta_module.meta_lr}")
print(f"  History length: {meta_module.history_length}")

### 7.2 Meta-Learning Strategies

| Strategy | Description | Best For |
|----------|-------------|----------|
| **rnn** | Recurrent network processes reward history | General adaptation |
| **bandit** | UCB/Thompson sampling for weight selection | Discrete weight choices |
| **morl** | Multi-objective Pareto optimization | Multiple objectives |
| **tdl** | Temporal difference learning | Online adaptation |
| **hypernet** | Hypernetwork generates weights | Problem-conditioned |

### 7.3 Bi-Level Optimization Flow

```
┌─────────────────────────────────────────────────────────────────┐
│                    Meta-Learning Training Loop                   │
└─────────────────────────────────────────────────────────────────┘

for epoch in range(n_epochs):
    
    ┌─────────────────────────────────────────────────────────────┐
    │ INNER LOOP: RL Training                                      │
    │                                                              │
    │   for batch in dataloader:                                   │
    │       loss = agent.training_step(batch)                      │
    │       optimizer.step()                                       │
    │                                                              │
    │   reward_signal = epoch_reward                               │
    └─────────────────────────────────────────────────────────────┘
                              │
                              ▼
    ┌─────────────────────────────────────────────────────────────┐
    │ OUTER LOOP: Meta-Strategy Update                             │
    │                                                              │
    │   meta_strategy.update(reward_signal)                        │
    │   new_weights = meta_strategy.propose_weights()              │
    │   env.update_weights(new_weights)                            │
    └─────────────────────────────────────────────────────────────┘
```
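
The outer loop above can be sketched with a UCB1 bandit that picks among a few discrete weight settings. This is a hedged illustration of the `bandit` strategy, not the code in `contextual_bandits.py`: `UCBWeightSelector`, `run_inner_loop`, and the candidate weights are hypothetical, and only the method names `update` and `propose_weights` are taken from the diagram.

```python
import math


class UCBWeightSelector:
    """Hypothetical UCB1 bandit over a discrete set of reward-weight settings."""

    def __init__(self, candidates, c=1.0):
        self.candidates = list(candidates)
        self.c = c
        self.counts = [0] * len(self.candidates)
        self.means = [0.0] * len(self.candidates)
        self._last = None

    def propose_weights(self):
        # Try every candidate once before using the confidence bound.
        for i, n in enumerate(self.counts):
            if n == 0:
                self._last = i
                return self.candidates[i]
        total = sum(self.counts)
        scores = [
            m + self.c * math.sqrt(2.0 * math.log(total) / n)
            for m, n in zip(self.means, self.counts)
        ]
        self._last = scores.index(max(scores))
        return self.candidates[self._last]

    def update(self, reward_signal):
        # Incremental mean of the rewards observed for the last proposal.
        i = self._last
        self.counts[i] += 1
        self.means[i] += (reward_signal - self.means[i]) / self.counts[i]


def run_inner_loop(weights):
    """Hypothetical stand-in for one epoch of RL training; returns the epoch reward."""
    return -abs(weights["cost_weight"] - 1.0)


candidates = [{"cost_weight": 0.5}, {"cost_weight": 1.0}, {"cost_weight": 2.0}]
selector = UCBWeightSelector(candidates)
for epoch in range(50):
    weights = selector.propose_weights()      # outer loop: choose weights
    reward_signal = run_inner_loop(weights)   # inner loop: train one epoch
    selector.update(reward_signal)            # outer loop: learn from the result
print(selector.counts)
```

The exploration bonus shrinks as a setting is tried more often, so the selector concentrates on the best-performing weights while still occasionally revisiting the others.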

---

## 8. Hyperparameter Optimization

The pipeline supports two HPO backends: Optuna (`optuna_hpo.py`) and DEHB (`dehb.py`).

### 8.1 Optuna-Based HPO

Multiple sampling strategies:
- **TPE** (Tree-structured Parzen Estimator) - Default
- **Grid Search**
- **Random Search**
- **Hyperband** - Multi-fidelity

In [None]:
from logic.src.configs import HPOConfig

# HPO configuration
hpo_cfg = HPOConfig(
    method="tpe",           # tpe, grid, random, hyperband
    n_trials=50,            # Number of trials
    n_epochs_per_trial=10,  # Epochs per trial
    search_space={
        "optim.lr": [1e-5, 1e-3],           # Log-uniform
        "train.batch_size": [64, 512],       # Integer
        "model.embed_dim": [64, 128, 256],   # Categorical
        "rl.entropy_weight": [0.0, 0.1],     # Uniform
    },
)

print("HPO Configuration:")
print(f"  Method: {hpo_cfg.method}")
print(f"  Trials: {hpo_cfg.n_trials}")
print(f"  Epochs per trial: {hpo_cfg.n_epochs_per_trial}")
print(f"\nSearch Space:")
for param, range_val in hpo_cfg.search_space.items():
    print(f"  {param}: {range_val}")

### 8.2 HPO via CLI

```bash
# TPE optimization with 50 trials
python main.py train_lightning \
    hpo.n_trials=50 \
    hpo.method=tpe \
    hpo.n_epochs_per_trial=10 \
    'hpo.search_space={"optim.lr": [1e-5, 1e-3], "train.batch_size": [64, 512]}'

# Grid search
python main.py train_lightning \
    hpo.n_trials=100 \
    hpo.method=grid \
    'hpo.search_space={"model.embed_dim": [64, 128, 256], "model.num_heads": [4, 8]}'

# DEHB (multi-fidelity)
python main.py train_lightning \
    hpo.method=dehb \
    hpo.min_fidelity=1 \
    hpo.max_fidelity=50 \
    hpo.fevals=100
```

### 8.3 DEHB (Differential Evolution Hyperband)

DEHB combines:
- **Differential Evolution** for global search
- **Hyperband** for multi-fidelity scheduling

Benefits:
- Early stopping of poor configurations
- Efficient use of compute budget
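
The early-stopping benefit comes from the Hyperband side, whose core routine is successive halving. The sketch below shows only that routine on a toy objective. It omits the differential evolution step entirely, does not use the `dehb` package, and `successive_halving` and `toy_score` are hypothetical names.

```python
import math


def successive_halving(configs, evaluate, min_fidelity=1, max_fidelity=9, eta=3):
    """Keep the top 1/eta of configs at each rung while multiplying the budget by eta."""
    survivors = list(configs)
    budget = min_fidelity
    while True:
        ranked = sorted(survivors, key=lambda c: evaluate(c, budget), reverse=True)
        if len(ranked) == 1 or budget >= max_fidelity:
            return ranked[0]
        survivors = ranked[: max(1, len(ranked) // eta)]  # early-stop the rest
        budget = min(budget * eta, max_fidelity)


def toy_score(lr, epochs):
    """Hypothetical validation reward: best near lr=1e-4, slightly lower at small budgets."""
    return -abs(math.log10(lr) + 4.0) - 0.1 / epochs


learning_rates = [1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1]
best = successive_halving(learning_rates, toy_score)
print(f"best lr: {best}")
```

With nine candidates and `eta=3`, six configurations are discarded after one epoch and two more after three epochs, so only one configuration ever receives the full budget.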

---

## 9. Practical Examples

### 9.1 Complete Training Script

In [None]:
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.loggers import CSVLogger

from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import REINFORCE
from logic.src.pipeline.trainer import WSTrainer
from logic.src.callbacks import SpeedMonitor

# Set seed for reproducibility
seed_everything(42)

# 1. Create Environment
env = get_env(
    "vrpp",
    num_loc=20,
    device="cpu",  # Use "cuda" if available
)

# 2. Create Policy
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=128,
    n_encode_layers=2,
    n_heads=4,
)

# 3. Create RL Module
model = REINFORCE(
    env=env,
    policy=policy,
    baseline="exponential",
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=5000,
    val_data_size=500,
    batch_size=64,
    num_workers=0,  # Use 0 for notebooks
)

# 4. Create Trainer
trainer = WSTrainer(
    max_epochs=3,  # Short training for demo
    accelerator="auto",
    devices=1,
    logger=CSVLogger("logs", name="tutorial_demo"),
    callbacks=[SpeedMonitor(epoch_time=True)],
    enable_progress_bar=True,
)

print("Training configuration ready!")
print(f"  Environment: {env.__class__.__name__}")
print(f"  Policy: {policy.__class__.__name__}")
print(f"  Algorithm: REINFORCE with {model.baseline_type} baseline")
print(f"  Epochs: {trainer.max_epochs}")

In [None]:
# Run training (uncomment to execute)
# trainer.fit(model)

# Get final metrics
# print(f"\nFinal validation reward: {trainer.callback_metrics.get('val/reward', 'N/A')}")

### 9.2 CLI Examples Reference

#### Basic Training
```bash
# VRPP with Attention Model
python main.py train_lightning \
    model=am \
    env.name=vrpp \
    env.num_loc=50 \
    train.n_epochs=100

# WCVRP (Waste Collection)
python main.py train_lightning \
    model=am \
    env.name=wcvrp \
    env.num_loc=50 \
    env.capacity=100
```

#### Algorithm Variants
```bash
# PPO with custom parameters
python main.py train_lightning \
    rl.algorithm=ppo \
    rl.ppo_epochs=10 \
    rl.eps_clip=0.2 \
    rl.value_loss_weight=0.5

# POMO with augmentation
python main.py train_lightning \
    rl.algorithm=pomo \
    rl.num_augment=8 \
    rl.num_starts=50

# Imitation learning from HGS
python main.py train_lightning \
    rl.algorithm=imitation \
    rl.expert=hgs
```

#### Advanced Features
```bash
# Meta-learning with RNN strategy
python main.py train_lightning \
    rl.use_meta=true \
    rl.meta_strategy=rnn \
    rl.meta_lr=1e-3 \
    rl.meta_history_length=10

# Learning rate scheduling
python main.py train_lightning \
    optim.lr_scheduler=cosine \
    'optim.lr_scheduler_kwargs={"T_max": 100}'

# Mixed precision training
python main.py train_lightning \
    train.precision=16-mixed
```
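To make the cosine schedule in the block above concrete, the sketch below evaluates the closed-form cosine annealing curve. It assumes that `optim.lr_scheduler=cosine` maps to PyTorch's `CosineAnnealingLR` and that the scheduler is stepped once per epoch; neither is confirmed here. The base learning rate of `1e-4` is purely illustrative, and `cosine_lr` is a hypothetical helper.

```python
import math


def cosine_lr(step: int, base_lr: float, t_max: int, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing from base_lr (step 0) to eta_min (step t_max)."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * step / t_max))


# Learning rate at the start, midpoint, and end of a T_max=100 schedule.
schedule = [cosine_lr(t, base_lr=1e-4, t_max=100) for t in (0, 50, 100)]
```

With `T_max` equal to the number of epochs, the learning rate decays smoothly to `eta_min` exactly at the end of training.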

---

## 10. Advanced Topics

### 10.1 Custom RL Algorithm

To implement a custom RL algorithm, inherit from `RL4COLitModule` and override `calculate_loss`:

In [None]:
from logic.src.pipeline.rl.core.base import RL4COLitModule
from typing import Optional
import torch
from tensordict import TensorDict

class CustomRL(RL4COLitModule):
    """
    Example custom RL algorithm.
    
    Implements a simple variant with custom advantage normalization.
    """
    
    def __init__(
        self,
        temperature: float = 1.0,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.temperature = temperature
    
    def calculate_loss(
        self,
        td: TensorDict,
        out: dict,
        batch_idx: int,
        env: Optional["RL4COEnvBase"] = None,
    ) -> torch.Tensor:
        """
        Custom loss computation.
        
        Uses temperature-scaled advantage.
        """
        reward = out["reward"]
        log_likelihood = out["log_likelihood"]
        
        # Get baseline
        if hasattr(self, "_current_baseline_val") and self._current_baseline_val is not None:
            baseline_val = self._current_baseline_val
        else:
            baseline_val = self.baseline.eval(td, reward, env=env)
        
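        # Standardize the advantage, then apply the temperature.
        # Dividing by the temperature before standardizing would cancel out,
        # because standardization is scale-invariant.
        advantage = reward - baseline_val
        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-8)
        advantage = advantage / self.temperature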
        
        # Policy gradient loss
        loss = -(advantage.detach() * log_likelihood).mean()
        
        return loss

print("CustomRL algorithm defined successfully!")
print("  Key parameter: temperature (default 1.0)")
print(f"  Inherits from: RL4COLitModule")
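
Because standardization (subtract the mean, divide by the standard deviation) is scale-invariant, the order in which temperature scaling and standardization are applied determines whether `temperature` has any effect. The small plain-Python sketch below, independent of the pipeline classes, illustrates this. The helper name and the toy values are hypothetical, and it uses the population standard deviation for simplicity.

```python
from statistics import mean, pstdev


def standardize(xs):
    """Subtract the mean and divide by the (population) standard deviation."""
    m, s = mean(xs), pstdev(xs)
    return [(x - m) / (s + 1e-8) for x in xs]


raw_advantages = [-2.0, -1.0, 0.5, 3.0]
temperature = 2.0

# Scale, then standardize: the temperature is (almost exactly) cancelled out.
scale_then_std = standardize([a / temperature for a in raw_advantages])
# Standardize, then scale: every advantage shrinks by a factor of 1/temperature.
std_then_scale = [a / temperature for a in standardize(raw_advantages)]

plain = standardize(raw_advantages)
```

Here `scale_then_std` matches `plain` up to the epsilon term, while `std_then_scale` is `plain` divided by the temperature.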

### 10.2 Custom Baseline

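To implement a custom baseline, subclass `Baseline` and override `eval`, which returns a baseline value with the same shape as the reward: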

In [None]:
from logic.src.pipeline.rl.core.baselines import Baseline
import torch
import torch.nn as nn
from tensordict import TensorDict

class PercentileBaseline(Baseline):
    """
    Baseline using batch percentile.
    
    Returns the q-th percentile of rewards as baseline.
    """
    
    def __init__(self, percentile: float = 50.0):
        super().__init__()
        self.percentile = percentile
    
    def eval(
        self,
        td: TensorDict,
        reward: torch.Tensor,
        env=None,
    ) -> torch.Tensor:
        """
        Compute percentile baseline.
        
        Args:
            td: TensorDict (unused)
            reward: Batch rewards
            env: Environment (unused)
            
        Returns:
            Percentile value expanded to reward shape
        """
        # Compute percentile
        baseline_val = torch.quantile(
            reward.float(), 
            self.percentile / 100.0
        )
        return baseline_val.expand_as(reward)

# Test the custom baseline
baseline = PercentileBaseline(percentile=75.0)
test_rewards = torch.randn(64)
baseline_val = baseline.eval(None, test_rewards)

print("PercentileBaseline (75th):")
print(f"  Input rewards shape: {test_rewards.shape}")
print(f"  Baseline value: {baseline_val[0].item():.4f}")
print(f"  Actual 75th percentile: {torch.quantile(test_rewards, 0.75).item():.4f}")
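
To see what the percentile choice does to the learning signal, the plain-Python sketch below computes a linearly interpolated percentile (the default interpolation mode of `torch.quantile`) and counts how many samples end up with a positive advantage. The helper names and the toy rewards are illustrative assumptions.

```python
def percentile(values, q):
    """q-th percentile (0-100) with linear interpolation between order statistics."""
    xs = sorted(values)
    pos = (len(xs) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def positive_advantage_fraction(rewards, q):
    """Fraction of samples whose reward exceeds the q-th percentile baseline."""
    b = percentile(rewards, q)
    return sum(r > b for r in rewards) / len(rewards)


toy_rewards = [float(r) for r in range(101)]  # 0.0, 1.0, ..., 100.0
frac_median = positive_advantage_fraction(toy_rewards, 50.0)
frac_q75 = positive_advantage_fraction(toy_rewards, 75.0)
```

A higher percentile leaves fewer samples with a positive advantage, so only the best rollouts in a batch are reinforced.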

### 10.3 Multi-GPU Training

PyTorch Lightning handles process launching and gradient synchronization for distributed training; the device count and strategy are chosen at launch:

```bash
# Single GPU
python main.py train_lightning device=cuda

# Multiple GPUs (DDP)
python main.py train_lightning \
    device=cuda \
    --trainer.devices=4 \
    --trainer.strategy=ddp

# DeepSpeed (for large models)
python main.py train_lightning \
    --trainer.strategy=deepspeed_stage_2
```
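One practical consequence of DDP is that the DataLoader batch size applies per process, so the number of samples contributing to each optimizer step grows with the device count. Below is a minimal sketch of that arithmetic, together with the Python-side `Trainer` arguments that correspond to a single-node, 4-GPU DDP run. The helper name is hypothetical, and it is an assumption that `WSTrainer` forwards these arguments to Lightning's `Trainer`.

```python
def effective_batch_size(per_device_batch: int, devices: int,
                         num_nodes: int = 1, accumulate_grad_batches: int = 1) -> int:
    """Samples contributing to one optimizer step under data-parallel training."""
    return per_device_batch * devices * num_nodes * accumulate_grad_batches


# Standard Lightning Trainer arguments for single-node, 4-GPU DDP.
ddp_kwargs = {"accelerator": "gpu", "devices": 4, "strategy": "ddp"}

# A per-device batch of 64 on 4 GPUs gives 256 samples per optimizer step.
global_batch = effective_batch_size(64, ddp_kwargs["devices"])
```

When moving from one GPU to several, the learning rate or the per-device batch size may need retuning to keep results comparable.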

### 10.4 Checkpoint Management

```bash
# Resume from checkpoint
python main.py train_lightning \
    --ckpt_path=/path/to/checkpoint.ckpt

# Custom checkpoint directory
python main.py train_lightning \
    output_dir=assets/model_weights/experiment1
```
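In a notebook, resuming can also be done programmatically through Lightning's `trainer.fit(model, ckpt_path=...)`. The sketch below finds the most recently modified checkpoint in a directory. It assumes checkpoints are stored as `.ckpt` files somewhere under the output directory; `latest_checkpoint` is a hypothetical helper, not part of the pipeline.

```python
from pathlib import Path
from typing import Optional


def latest_checkpoint(ckpt_dir: str) -> Optional[Path]:
    """Return the most recently modified .ckpt file under ckpt_dir, if any."""
    candidates = list(Path(ckpt_dir).rglob("*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Hypothetical usage with a Lightning trainer:
# ckpt = latest_checkpoint("assets/model_weights/experiment1")
# trainer.fit(model, ckpt_path=str(ckpt) if ckpt else None)
```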

### 10.5 Debugging Tips

```bash
# Fast dev run (a single train and validation batch, then exit)
python main.py train_lightning \
    --trainer.fast_dev_run=true

# Limit batches for debugging
python main.py train_lightning \
    --trainer.limit_train_batches=10 \
    --trainer.limit_val_batches=5

# Profiling
python main.py train_lightning \
    --trainer.profiler=simple
```
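In a notebook, the same Lightning options can be passed directly when constructing the trainer. The sketch below collects them as keyword-argument presets. `fast_dev_run`, `limit_train_batches`, `limit_val_batches`, and `profiler` are standard Lightning `Trainer` arguments; the `debug_trainer_kwargs` helper, the preset names, and the assumption that `WSTrainer` accepts these arguments unchanged are illustrative.

```python
def debug_trainer_kwargs(mode: str) -> dict:
    """Return Trainer keyword arguments for a named debugging preset."""
    presets = {
        # One batch each of train/val, then stop; loggers and checkpointing are disabled.
        "smoke": {"fast_dev_run": True},
        # Short epochs: 10 training batches and 5 validation batches.
        "short": {"limit_train_batches": 10, "limit_val_batches": 5},
        # Print a timing summary of training hooks at the end of the run.
        "profile": {"profiler": "simple"},
    }
    if mode not in presets:
        raise ValueError(f"unknown mode {mode!r}; choose from {sorted(presets)}")
    return dict(presets[mode])


# Hypothetical usage: trainer = WSTrainer(max_epochs=3, **debug_trainer_kwargs("short"))
short_kwargs = debug_trainer_kwargs("short")
```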

---

## Summary

This tutorial covered the Lightning-based RL training pipeline:

### Key Components

1. **Hydra Configuration**: Flexible, hierarchical configs with CLI overrides
2. **PyTorch Lightning**: Automatic training loop, logging, checkpointing
3. **RL Algorithms**: REINFORCE, PPO, POMO, SymNCO, HRL, Imitation
4. **Baselines**: None, Exponential, Rollout, Critic, POMO, Warmup
5. **Environments**: VRPP, WCVRP variants with TensorDict state
6. **Meta-Learning**: Bi-level optimization for adaptive training
7. **HPO**: Optuna (TPE, Grid, Random) and DEHB

### Quick Reference

```bash
# Basic training
python main.py train_lightning model=am env.name=vrpp

# PPO with rollout baseline
python main.py train_lightning rl.algorithm=ppo rl.baseline=rollout

# POMO with augmentation
python main.py train_lightning rl.algorithm=pomo rl.num_augment=8

# Meta-learning
python main.py train_lightning rl.use_meta=true rl.meta_strategy=rnn

# Hyperparameter optimization
python main.py train_lightning hpo.n_trials=50 hpo.method=tpe
```

### Further Reading

- `CLAUDE.md` - Project overview and coding standards
- `logic/src/pipeline/rl/` - Source code for all RL components
- `logic/src/configs/` - Configuration dataclasses
- `logic/src/envs/` - Environment implementations