# Tutorial 4: Training with PyTorch Lightning

**WSmart+ Route Tutorial Series**

This tutorial covers end-to-end training of neural routing models using reinforcement learning. You'll learn:

1. The **configuration system** (`Config`, `EnvConfig`, `ModelConfig`, etc.)
2. **REINFORCE** training with different baselines
3. **PPO** training with a critic network
4. The `create_model()` high-level helper
5. **Monitoring** training progress
6. **Saving and loading** trained models

**Previous**: [03_models_and_policies.ipynb](03_models_and_policies.ipynb) | **Next**: [05_evaluation_and_decoding.ipynb](05_evaluation_and_decoding.ipynb)

> **Note**: This tutorial uses small problem sizes and few epochs for fast execution on CPU. For real training, use GPU with larger settings.

In [None]:
import os
import sys
import warnings

warnings.filterwarnings("ignore")

PROJECT_ROOT = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

import matplotlib.pyplot as plt
import numpy as np
import torch

torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

---
## 1. Configuration System

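WSmart+ Route configures training with nested dataclasses. A top-level `Config` holds one sub-config per concern (`env`, `model`, `train`, `optim`, `rl`) plus global fields such as `device` and `seed`.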
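To illustrate the nesting pattern in isolation, here is a minimal sketch built on Python's standard `dataclasses` module. The `Sketch*` classes are hypothetical stand-ins that mimic the shape of `Config`/`EnvConfig`/`TrainConfig`. They are not the real `logic.src.configs` definitions.

```python
from dataclasses import asdict, dataclass, field


# Hypothetical stand-ins: they mirror the *shape* of the tutorial's Config
# classes but are NOT the real logic.src.configs definitions.
@dataclass
class SketchEnvConfig:
    name: str = "vrpp"
    num_loc: int = 20


@dataclass
class SketchTrainConfig:
    n_epochs: int = 3
    batch_size: int = 64


@dataclass
class SketchConfig:
    # default_factory gives every config its own sub-config instances,
    # so mutating one config never leaks into another.
    env: SketchEnvConfig = field(default_factory=SketchEnvConfig)
    train: SketchTrainConfig = field(default_factory=SketchTrainConfig)
    seed: int = 42


cfg_sketch = SketchConfig(env=SketchEnvConfig(num_loc=50))
print(cfg_sketch.env.num_loc)  # → 50 (nested attribute access)
print(asdict(cfg_sketch))      # the whole tree as a plain nested dict
```

Using `field(default_factory=...)` instead of a shared default instance is what keeps each top-level config independent.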

In [None]:
from logic.src.configs import Config, EnvConfig, ModelConfig, OptimConfig, RLConfig, TrainConfig

# Create a configuration for a small training run
cfg = Config(
    env=EnvConfig(
        name="vrpp",
        num_loc=20,
    ),
    model=ModelConfig(
        name="am",
        embed_dim=64,
        num_encoder_layers=2,
        num_decoder_layers=2,
        n_heads=4,
    ),
    train=TrainConfig(
        n_epochs=3,
        batch_size=64,
        train_data_size=640,   # Small for tutorial
        val_data_size=128,
        num_workers=0,
        precision="32-true",   # Use full precision for CPU
    ),
    optim=OptimConfig(
        optimizer="adam",
        lr=1e-4,
    ),
    rl=RLConfig(
        algorithm="reinforce",
        baseline="rollout",
    ),
    device="cpu",
    seed=42,
)

print("Configuration Summary:")
print(f"  Environment:  {cfg.env.name} with {cfg.env.num_loc} nodes")
print(f"  Model:        {cfg.model.name} (embed={cfg.model.embed_dim}, layers={cfg.model.num_encoder_layers})")
print(f"  Training:     {cfg.train.n_epochs} epochs, batch_size={cfg.train.batch_size}")
print(f"  RL Algorithm: {cfg.rl.algorithm} with {cfg.rl.baseline} baseline")
print(f"  Optimizer:    {cfg.optim.optimizer}, lr={cfg.optim.lr}")

---
## 2. Building Components Manually

Before using the high-level `create_model()` helper, let's build each component by hand: the environment, the policy network, and the RL training module.

In [None]:
from logic.src.envs import get_env
from logic.src.models.policies import AttentionModelPolicy
from logic.src.pipeline.rl import REINFORCE

# Step 1: Create environment
env = get_env("vrpp", num_loc=20)
print(f"Environment: {env.name}")

# Step 2: Create policy (neural network)
policy = AttentionModelPolicy(
    env_name="vrpp",
    embed_dim=64,
    n_encode_layers=2,
    n_decode_layers=2,
    n_heads=4,
)
print(f"Policy: {type(policy).__name__} ({sum(p.numel() for p in policy.parameters()):,} params)")

# Step 3: Create RL module (REINFORCE)
model = REINFORCE(
    env=env,
    policy=policy,
    baseline="rollout",
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=640,
    val_data_size=128,
    batch_size=64,
    num_workers=0,
)
print(f"RL Module: {type(model).__name__}")
print(f"  Baseline: {model.baseline_type}")

In [None]:
# Evaluate before training
td_eval = env.generator(batch_size=64)
td_eval = env.reset(td_eval)

with torch.no_grad():
    out_before = policy(td_eval.clone(), env, strategy="greedy", return_actions=True)

print("Before training (untrained model):")
print(f"  Mean reward (greedy): {out_before['reward'].mean():.4f}")
print(f"  Std reward:           {out_before['reward'].std():.4f}")

---
## 3. REINFORCE Training

The REINFORCE algorithm optimizes the policy using the gradient:

$$\nabla_\theta J(\theta) = \mathbb{E}\left[(R - b) \nabla_\theta \log \pi_\theta(a|s)\right]$$

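where $a$ is a complete solution sampled from the policy $\pi_\theta$ for problem instance $s$, $R$ is the reward of that solution, and $b$ is a baseline used for variance reduction.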
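In PyTorch this gradient is usually obtained by minimizing a surrogate loss. The following minimal sketch assumes per-instance rewards and the summed log-probabilities of the sampled solutions are already available. The tensors are made-up toy values, not output from WSmart+ Route, and this is not the library's `REINFORCE` implementation.

```python
import torch

# Toy values (assumption): 4 instances, each with log pi_theta(a|s) of its
# sampled solution as a leaf tensor we can differentiate.
log_prob = torch.tensor([-2.0, -1.5, -3.0, -2.5], requires_grad=True)
reward = torch.tensor([1.0, 3.0, 2.0, 2.0])    # R for each sampled solution
baseline = torch.tensor([2.0, 2.0, 2.0, 2.0])  # b, e.g. a rollout or EMA estimate

# The advantage (R - b) is treated as a constant; if b came from a learned
# critic, detach() would stop gradients from flowing into it here.
advantage = (reward - baseline).detach()

# Minimizing this loss performs gradient *ascent* on J(theta).
loss = -(advantage * log_prob).mean()
loss.backward()

print(advantage)
print(log_prob.grad)  # equals -advantage / batch_size
```

Instances whose reward beats the baseline get their log-probability pushed up; instances below it get pushed down, and instances matching the baseline contribute no gradient.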

In [None]:
from logic.src.pipeline.rl.common.trainer import WSTrainer

# Create trainer
trainer = WSTrainer(
    max_epochs=3,
    accelerator="cpu",
    devices=1,
    precision="32-true",
    log_every_n_steps=5,
    enable_progress_bar=True,
    reload_dataloaders_every_n_epochs=1,
    logger=False,  # Disable logging for tutorial
)

print("Starting REINFORCE training (3 epochs)...")
trainer.fit(model)
print("Training complete!")

In [None]:
# Evaluate after training
with torch.no_grad():
    out_after = policy(td_eval.clone(), env, strategy="greedy", return_actions=True)

print("Training Results:")
print(f"  Before training - Mean reward: {out_before['reward'].mean():.4f}")
print(f"  After training  - Mean reward: {out_after['reward'].mean():.4f}")
improvement = out_after['reward'].mean() - out_before['reward'].mean()
print(f"  Improvement: {improvement:.4f}")

---
## 4. Baselines for Variance Reduction

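A baseline $b$ that does not depend on the sampled solution leaves the gradient estimate unbiased, because $\mathbb{E}\left[\nabla_\theta \log \pi_\theta(a|s)\right] = 0$, while a well-chosen $b$ reduces its variance. WSmart+ Route supports several baseline types: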

| Baseline | Description |
|----------|-------------|
| `"rollout"` | Uses greedy rollout of a periodically updated policy |
| `"exponential"` | Exponential moving average of past rewards |
| `"critic"` | Learned value network |
| `"warmup"` | Transitions from one baseline to another |
| `"no"` | No baseline (vanilla REINFORCE) |
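The exponential baseline is the simplest to write down. Here is a minimal sketch of an exponential moving average of batch-mean rewards, $b \leftarrow \beta b + (1-\beta)\,\bar R$. The class name `EMABaselineSketch` and the decay `beta=0.8` are assumptions for illustration. The library's `"exponential"` baseline may differ in its details.

```python
import torch


class EMABaselineSketch:
    """Hypothetical exponential-moving-average baseline (not the library's class).

    b <- beta * b + (1 - beta) * mean(R), initialised from the first batch.
    """

    def __init__(self, beta: float = 0.8):  # beta=0.8 is an assumed value
        self.beta = beta
        self.value = None  # no estimate until the first batch arrives

    def update_and_get(self, reward: torch.Tensor) -> torch.Tensor:
        batch_mean = reward.mean().item()
        if self.value is None:
            self.value = batch_mean
        else:
            self.value = self.beta * self.value + (1.0 - self.beta) * batch_mean
        # The same scalar baseline is used for every instance in the batch
        return torch.full_like(reward, self.value)


bl = EMABaselineSketch(beta=0.8)
b1 = bl.update_and_get(torch.tensor([1.0, 3.0]))  # first batch: b = 2.0
b2 = bl.update_and_get(torch.tensor([4.0, 6.0]))  # 0.8 * 2.0 + 0.2 * 5.0 = 2.6
```

Unlike `"rollout"`, this baseline costs no extra forward passes, but it gives every instance the same value regardless of how hard that instance is.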

In [None]:
# Train with exponential baseline for comparison
env2 = get_env("vrpp", num_loc=20)
policy2 = AttentionModelPolicy(
    env_name="vrpp", embed_dim=64, n_encode_layers=2, n_decode_layers=2, n_heads=4,
)

model_exp = REINFORCE(
    env=env2,
    policy=policy2,
    baseline="exponential",
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=640,
    val_data_size=128,
    batch_size=64,
    num_workers=0,
)

trainer2 = WSTrainer(
    max_epochs=3, accelerator="cpu", devices=1, precision="32-true",
    log_every_n_steps=5, enable_progress_bar=True, logger=False,
    reload_dataloaders_every_n_epochs=1,
)

print("Training with exponential baseline...")
trainer2.fit(model_exp)

with torch.no_grad():
    out_exp = policy2(td_eval.clone(), env2, strategy="greedy", return_actions=True)

print("\nBaseline Comparison:")
print(f"  Rollout baseline     - Mean reward: {out_after['reward'].mean():.4f}")
print(f"  Exponential baseline - Mean reward: {out_exp['reward'].mean():.4f}")

---
## 5. PPO Training

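PPO reuses each batch for several optimization epochs (`ppo_epochs`) and limits how far each update can move the policy with a clipped surrogate objective (`eps_clip`). Instead of a fixed baseline, it uses a learned critic network. This often makes training more stable than plain REINFORCE.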
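The clipped surrogate objective from the PPO paper can be sketched in a few lines of PyTorch. This is a generic illustration, not WSmart+ Route's `PPO` class. The advantages here are made-up values. In the tutorial they would come from rewards and the critic, and how the library combines this term with the critic's loss is not shown here.

```python
import torch


def ppo_clip_loss(new_log_prob, old_log_prob, advantage, eps_clip=0.2):
    """Clipped surrogate objective of PPO, negated so it can be minimized."""
    # Probability ratio r = pi_theta(a|s) / pi_theta_old(a|s), computed in log space
    ratio = torch.exp(new_log_prob - old_log_prob.detach())
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, 1.0 - eps_clip, 1.0 + eps_clip) * advantage
    # The elementwise minimum removes any incentive to push r outside [1 - eps, 1 + eps]
    return -torch.min(unclipped, clipped).mean()


old = torch.log(torch.tensor([0.5, 0.5]))
new = torch.log(torch.tensor([0.8, 0.2]))  # ratios 1.6 and 0.4
adv = torch.tensor([1.0, 1.0])
ppo_loss = ppo_clip_loss(new, old, adv, eps_clip=0.2)
```

With positive advantages, the first ratio (1.6) is capped at 1.2, while the second (0.4) is left unclipped because the minimum already picks the more pessimistic term. This gives a loss of -(1.2 + 0.4) / 2 = -0.8.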

In [None]:
from logic.src.models.policies.critic import create_critic_from_actor
from logic.src.pipeline.rl import PPO

# Create fresh components
env_ppo = get_env("vrpp", num_loc=20)
policy_ppo = AttentionModelPolicy(
    env_name="vrpp", embed_dim=64, n_encode_layers=2, n_decode_layers=2, n_heads=4,
)

# Create critic network (shares encoder architecture with policy)
critic = create_critic_from_actor(
    policy_ppo,
    env_name="vrpp",
    embed_dim=64,
    hidden_dim=64,
    n_layers=2,
    n_heads=4,
)
print(f"Critic parameters: {sum(p.numel() for p in critic.parameters()):,}")

# Create PPO module
model_ppo = PPO(
    env=env_ppo,
    policy=policy_ppo,
    critic=critic,
    ppo_epochs=3,
    eps_clip=0.2,
    baseline="no",  # PPO uses critic instead of traditional baselines
    optimizer="adam",
    optimizer_kwargs={"lr": 1e-4},
    train_data_size=640,
    val_data_size=128,
    batch_size=64,
    num_workers=0,
)

trainer_ppo = WSTrainer(
    max_epochs=3, accelerator="cpu", devices=1, precision="32-true",
    log_every_n_steps=5, enable_progress_bar=True, logger=False,
    reload_dataloaders_every_n_epochs=1,
)

print("Training with PPO (3 epochs)...")
trainer_ppo.fit(model_ppo)

with torch.no_grad():
    out_ppo = policy_ppo(td_eval.clone(), env_ppo, strategy="greedy", return_actions=True)

print("\nPPO Results:")
print(f"  Mean reward (greedy): {out_ppo['reward'].mean():.4f}")

---
## 6. Using `create_model()` Helper

For convenience, `create_model()` builds the environment, policy, and RL module from a single `Config` object, so you do not have to wire them together by hand.
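A helper like this typically dispatches on the algorithm name stored in the config. The following minimal sketch shows that registry pattern with hypothetical stand-in classes (`ReinforceSketch`, `PPOSketch`, `create_model_sketch`). It is an illustration of the idea, not the library's actual `create_model()` code.

```python
from dataclasses import dataclass


# Hypothetical stand-ins for RL modules; the real create_model() may differ.
class ReinforceSketch:
    def __init__(self, baseline: str):
        self.baseline = baseline


class PPOSketch:
    def __init__(self, baseline: str):
        self.baseline = baseline


# Registry: algorithm name -> RL module class
ALGORITHMS = {"reinforce": ReinforceSketch, "ppo": PPOSketch}


@dataclass
class RLSketchConfig:
    algorithm: str = "reinforce"
    baseline: str = "rollout"


def create_model_sketch(rl_cfg: RLSketchConfig):
    """Look up the RL module class by name and construct it from the config."""
    try:
        cls = ALGORITHMS[rl_cfg.algorithm]
    except KeyError:
        raise ValueError(f"Unknown algorithm: {rl_cfg.algorithm!r}") from None
    return cls(baseline=rl_cfg.baseline)


m = create_model_sketch(RLSketchConfig(algorithm="reinforce", baseline="exponential"))
print(type(m).__name__, m.baseline)
```

A registry keeps the config as plain data (strings and numbers), which is what lets the same settings be supplied either in Python or from the command line.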

In [None]:
from logic.src.pipeline.features.train import create_model

# Build a fresh config: same sizes as `cfg` above, but 2 epochs and an exponential baseline
cfg_auto = Config(
    env=EnvConfig(name="vrpp", num_loc=20),
    model=ModelConfig(
        name="am", embed_dim=64, num_encoder_layers=2, num_decoder_layers=2, n_heads=4,
    ),
    train=TrainConfig(
        n_epochs=2,
        batch_size=64,
        train_data_size=640,
        val_data_size=128,
        num_workers=0,
        precision="32-true",
    ),
    optim=OptimConfig(optimizer="adam", lr=1e-4),
    rl=RLConfig(algorithm="reinforce", baseline="exponential"),
    device="cpu",
)

model_auto = create_model(cfg_auto)
print(f"Auto-created model: {type(model_auto).__name__}")
print(f"  Environment: {model_auto.env.name}")
print(f"  Policy: {type(model_auto.policy).__name__}")
print(f"  Baseline: {model_auto.baseline_type}")

---
## 7. Saving and Loading Models
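The cell below saves weights through `model.save_weights()`; the file format it uses is not shown in this tutorial. As a library-independent alternative, plain PyTorch can round-trip the `state_dict` of any `nn.Module`. The following minimal sketch uses a small stand-in network. The file name `policy_state.pt` is arbitrary, and applying the same steps to `policy` is an assumption, not a documented WSmart+ Route workflow.

```python
import os
import tempfile

import torch
from torch import nn

# Stand-in network (assumption): any nn.Module, e.g. the tutorial's `policy`,
# is saved and restored the same way.
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))

path = os.path.join(tempfile.mkdtemp(), "policy_state.pt")
torch.save(net.state_dict(), path)  # save only parameters and buffers

# Rebuild the same architecture, then load the saved tensors into it
restored = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))
restored.load_state_dict(torch.load(path, map_location="cpu"))
restored.eval()  # switch to inference mode before evaluating

x = torch.randn(2, 4)
with torch.no_grad():
    same = torch.allclose(net(x), restored(x))
print(same)
```

Saving the `state_dict` rather than the whole module keeps the file independent of the Python class definition, at the cost of having to rebuild the architecture with matching hyperparameters before loading.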

In [None]:
import tempfile

# Save trained model weights
save_dir = tempfile.mkdtemp()
save_path = os.path.join(save_dir, "model.pt")

model.save_weights(save_path)
print(f"Saved model to: {save_path}")
print(f"File size: {os.path.getsize(save_path) / 1024:.1f} KB")

# List saved files
for f in os.listdir(save_dir):
    print(f"  {f}: {os.path.getsize(os.path.join(save_dir, f)) / 1024:.1f} KB")

---
## 8. Algorithm Comparison

Let's visualize the performance of different algorithms trained in this tutorial.

In [None]:
# Collect all results
results = {
    "Untrained": out_before["reward"].mean().item(),
    "REINFORCE\n(rollout)": out_after["reward"].mean().item(),
    "REINFORCE\n(exponential)": out_exp["reward"].mean().item(),
    "PPO": out_ppo["reward"].mean().item(),
}

fig, ax = plt.subplots(figsize=(9, 5))
names = list(results.keys())
values = list(results.values())
colors = ["lightgray", "steelblue", "coral", "seagreen"]

bars = ax.bar(range(len(names)), values, color=colors, edgecolor="black", alpha=0.85)
ax.set_xticks(range(len(names)))
ax.set_xticklabels(names, fontsize=10)
ax.set_ylabel("Mean Reward (greedy)")
ax.set_title("Training Algorithm Comparison\n(3 epochs, 20-node VRPP, CPU)")
ax.grid(True, alpha=0.3, axis="y")

# Annotate each bar with its value (bar_label also places labels correctly for negative bars)
ax.bar_label(bars, fmt="%.3f", padding=2, fontsize=10)

plt.tight_layout()
plt.show()

print("\nNote: After only 3 epochs on 640 training instances, these differences are small and noisy;")
print("more epochs and more data are needed for a reliable comparison.")
print("For production training, use GPU with:")
print("  - num_loc=50-100")
print("  - embed_dim=128")
print("  - n_epochs=100+")
print("  - train_data_size=100000")

---
## Summary

In this tutorial, you learned:

- **Config dataclasses** (`Config`, `EnvConfig`, `ModelConfig`, `TrainConfig`, `RLConfig`, `OptimConfig`) control all training parameters
- **REINFORCE** is a classic policy gradient algorithm whose variance is reduced with a configurable baseline (`rollout`, `exponential`, `critic`, `warmup`, or `no`)
- **PPO** uses a critic network and a clipped surrogate objective for more stable training
- **WSTrainer** wraps the PyTorch Lightning `Trainer` with RL-specific optimizations and accepts standard `Trainer` arguments such as `max_epochs`, `accelerator`, and `precision`
- **`create_model(cfg)`** is the high-level helper that wires env + policy + RL module from config
- Trained weights can be **saved** with `model.save_weights()`

### CLI Equivalents

Training runs like the ones above can also be launched from the command line, setting config fields with `key=value` overrides:

```bash
# REINFORCE with rollout baseline
python main.py train_lightning model=am env.name=vrpp env.num_loc=50 train.n_epochs=100

# PPO
python main.py train_lightning model=am env.name=vrpp rl.algorithm=ppo train.n_epochs=100

# With custom settings
python main.py train_lightning model=deep_decoder env.name=wcvrp env.num_loc=100 \
    rl.algorithm=reinforce rl.baseline=exponential optim.lr=5e-5 train.batch_size=512
```
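Each `key=value` argument above names a nested config field by its dotted path (for example, `env.num_loc` is `cfg.env.num_loc`). The real CLI's override handling is not shown in this tutorial, so the following is only a minimal sketch of the idea with hypothetical dataclasses and a hypothetical `apply_override` helper.

```python
from dataclasses import dataclass, field


# Hypothetical mini config; the real CLI's override parser may work differently.
@dataclass
class EnvSketch:
    name: str = "vrpp"
    num_loc: int = 20


@dataclass
class OptimSketch:
    lr: float = 1e-4


@dataclass
class RootSketch:
    env: EnvSketch = field(default_factory=EnvSketch)
    optim: OptimSketch = field(default_factory=OptimSketch)


def apply_override(cfg, override: str) -> None:
    """Apply one 'dotted.key=value' override in place.

    Naive cast: converts the string with the type of the field's current value,
    which works for str/int/float but would need special handling for bools.
    """
    key, raw = override.split("=", 1)
    *parents, leaf = key.split(".")
    node = cfg
    for part in parents:
        node = getattr(node, part)  # walk down the nested dataclasses
    current = getattr(node, leaf)   # raises AttributeError for unknown keys
    setattr(node, leaf, type(current)(raw))


cfg_cli = RootSketch()
for ov in ["env.name=wcvrp", "env.num_loc=100", "optim.lr=5e-5"]:
    apply_override(cfg_cli, ov)
print(cfg_cli)
```

Looking up the existing attribute before setting it means a typo such as `env.numloc=100` fails loudly instead of silently adding a new field.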

### Next Steps

Continue to **[Tutorial 5: Evaluation and Decoding Strategies](05_evaluation_and_decoding.ipynb)** to learn how to evaluate trained models and compare decoding strategies.