# RL Policy Training

Train a DQN agent to learn a servicing policy, then compare it against the baseline policies.

In [None]:
import sys
sys.path.insert(0, '..')

import numpy as np
import matplotlib.pyplot as plt

from src.scenarios.basic_bathtub import BasicBathtubScenario
from src.policy import LinearIntervalPolicy, FixedIntervalPolicy, NoOpPolicy
from src.runner import run_scenario, compare_policies
from src.rl import ServiceEnv, train_dqn, evaluate_model, DQNPolicy

## 1. Setup Scenario

In [None]:
# Simulation horizon reused by the baseline comparison, DQN training and
# evaluation cells further down.
MAX_TIME = 150.0

# Same scenario configuration as the policy_optimisation notebook, so the
# results of the two notebooks can be compared directly.
scenario_config = {
    'scale1': 100.0,
    'scale2': 200.0,
    'service_cost': 0.5,
    'failure_cost': 150.0,
    'revenue_per_time': 1.50,
}
scenario = BasicBathtubScenario(**scenario_config)

# Echo the key parameters for reference. shape1, shape2 and delta_t are not
# set above, so the values shown are whatever the scenario itself provides.
summary_lines = [
    "Scenario parameters:",
    f"  Bathtub shape: shape1={scenario.failure_model.shape1}, shape2={scenario.failure_model.shape2}",
    f"  Scales: scale1={scenario.failure_model.scale1}, scale2={scenario.failure_model.scale2}",
    f"  Delta_t (age reduction per service): {scenario.failure_model.delta_t}",
    f"  Costs: service={scenario.costs.service_cost}, failure={scenario.costs.failure_cost}",
    f"  Revenue per time: {scenario.costs.revenue_per_time}",
]
print("\n".join(summary_lines))

## 2. Baseline Comparison

First, establish baseline performance with known policies.

In [None]:
# Baseline policies to benchmark the DQN against, including the optimised
# linear policy carried over from the policy_optimisation notebook.
# Best from multi-start: c=48.4, r=0.50 -> a=24.2, b=24.2
policies = {
    'no_service': NoOpPolicy(),
    'fixed_25': FixedIntervalPolicy(interval=25.0),
    'fixed_50': FixedIntervalPolicy(interval=50.0),
    'linear_15_10': LinearIntervalPolicy(a=15.0, b=10.0),  # baseline from policy_opt
    'optimised_linear': LinearIntervalPolicy(a=24.2, b=24.2),  # best from grid+NM
}

# Evaluate every policy on the same scenario, horizon and seed.
# `results` maps policy name -> stats; the 'mean' and 'std' entries are
# read below and again by the DQN comparison plot.
results = compare_policies(
    scenario,
    policies,
    n_subjects=2000,
    max_time=MAX_TIME,
    n_repeats=5,
    seed=42,
)

# Display results, best mean first.
print("Baseline Policy Comparison")
print("=" * 50)
ranked = sorted(results.items(), key=lambda item: item[1]['mean'], reverse=True)
for name, stats in ranked:
    # \u00b1 is the plus-minus sign; written as an escape so the source can
    # no longer be corrupted into the mojibake "Â±" by an encoding round-trip.
    print(f"{name:20s}: mean={stats['mean']:8.2f} \u00b1 {stats['std']:6.2f}")

## 3. Train DQN

In [None]:
# Train the DQN agent on the service scenario.
import torch.nn as nn
from stable_baselines3 import DQN
from stable_baselines3.common.monitor import Monitor

from src.rl.training import ActionStatsCallback, RewardLoggerCallback

# Show the discrete set of service delays the agent chooses between.
# Original author's note: [0.1, 5, 10, 15, ..., 100, inf] (22 actions)
print(f"Action space: {len(ServiceEnv.ACTION_DELAYS)} actions")
print(f"Delays: {ServiceEnv.ACTION_DELAYS[:5]} ... {ServiceEnv.ACTION_DELAYS[-3:]}")

# Q-network architecture; flip USE_TANH to switch the hidden activation.
USE_TANH = False  # True -> Tanh, False -> ReLU
hidden_activation = nn.Tanh if USE_TANH else nn.ReLU
policy_kwargs = {
    'net_arch': [256, 128, 64],  # smaller alternative tried earlier: [128, 128]
    'activation_fn': hidden_activation,
}

# Training environment wrapped in SB3's Monitor
# (presumably so per-episode statistics are recorded for logging -- confirm).
env = Monitor(ServiceEnv(scenario, max_time=MAX_TIME, seed=42))

# TensorBoard: run `tensorboard --logdir ./logs` to watch
print(f"\nNetwork: {policy_kwargs}")
print("Training DQN...")

# Hyperparameters are gathered in one place so they are easy to scan and tweak.
dqn_hyperparams = {
    'learning_rate': 1e-3,
    'buffer_size': 50_000,
    'batch_size': 128,
    'gamma': 0.99,
    'exploration_fraction': 0.5,
    'exploration_final_eps': 0.05,
    'target_update_interval': 500,
    'policy_kwargs': policy_kwargs,
    'verbose': 1,
    'seed': 42,
    'tensorboard_log': "./logs",
}
model = DQN('MlpPolicy', env, **dqn_hyperparams)

# Logging callbacks; reward_logger is kept by name because its
# episode_rewards attribute is read once training has finished.
reward_logger = RewardLoggerCallback()
action_stats = ActionStatsCallback(log_freq=1000)
callbacks = [reward_logger, action_stats]

model.learn(total_timesteps=200_000, callback=callbacks)

# Names consumed by the later plotting / evaluation cells.
trained_model = model
training_rewards = reward_logger.episode_rewards

print(f"\nTraining complete!")
print(f"Episodes: {len(training_rewards)}")

In [None]:
# Plot per-episode training reward together with a trailing rolling average.
fig, ax = plt.subplots(figsize=(10, 4))

rewards = training_rewards  # From previous cell
window = 50
if len(rewards) > window:
    smoothed = np.convolve(rewards, np.ones(window) / window, mode='valid')
    # mode='valid' returns len(rewards) - window + 1 points, and smoothed[i]
    # is the mean of episodes i .. i + window - 1. Anchor each point at the
    # last episode of its window so the curve lines up with the raw scatter
    # instead of being drawn (window - 1) episodes too early.
    smoothed_x = range(window - 1, len(rewards))
    ax.plot(smoothed_x, smoothed, label=f'Rolling avg (window={window})')
    ax.scatter(range(len(rewards)), rewards, alpha=0.1, s=1, label='Raw')
else:
    # Too few episodes to smooth: show the raw series, labelled so that the
    # legend call below always has at least one entry to display.
    ax.plot(rewards, label='Raw')

ax.set_xlabel('Episode')
ax.set_ylabel('Cumulative Reward')
ax.set_title('DQN Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 4. Evaluate Trained Model

In [None]:
# Evaluate the trained model with deterministic action selection.
# The episode count lives in one constant so the call and the printed
# header cannot drift apart.
N_EVAL_EPISODES = 500

dqn_rewards = evaluate_model(
    trained_model,
    scenario,
    n_episodes=N_EVAL_EPISODES,
    max_time=MAX_TIME,
    seed=42,
    deterministic=True,
)

# Summary statistics over the rewards returned by evaluate_model.
print(f"DQN Performance ({N_EVAL_EPISODES} episodes):")
print(f"  Mean reward: {np.mean(dqn_rewards):.2f}")
print(f"  Std reward:  {np.std(dqn_rewards):.2f}")
print(f"  Min reward:  {np.min(dqn_rewards):.2f}")
print(f"  Max reward:  {np.max(dqn_rewards):.2f}")

In [None]:
# Horizontal bar chart: DQN mean reward against every baseline policy.
fig, ax = plt.subplots(figsize=(10, 5))

# Collect all mean rewards (baselines from `results`, plus the DQN).
all_results = {name: stats['mean'] for name, stats in results.items()}
all_results['DQN'] = np.mean(dqn_rewards)

# Sort ascending so the best performer ends up as the top bar.
sorted_results = sorted(all_results.items(), key=lambda x: x[1])
names = [r[0] for r in sorted_results]
means = [r[1] for r in sorted_results]

# Highlight the DQN bar.
colors = ['green' if n == 'DQN' else 'steelblue' for n in names]
bars = ax.barh(names, means, color=colors)

ax.set_xlabel('Mean Reward')
ax.set_title('DQN vs Baseline Policies')
ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)

# Add value labels just beyond the end of each bar. Negative bars extend to
# the left, so their label goes on the left (right-aligned); a fixed
# "mean + 5" would draw it inside the bar.
for bar, mean in zip(bars, means):
    if mean >= 0:
        label_x, h_align = mean + 5, 'left'
    else:
        label_x, h_align = mean - 5, 'right'
    ax.text(label_x, bar.get_y() + bar.get_height() / 2,
            f'{mean:.1f}', va='center', ha=h_align, fontsize=9)

plt.tight_layout()
plt.show()

# Diagnostic: replay three greedy episodes and print the first few steps of
# each, followed by a one-line summary of how the episode ended.
print("\n--- Diagnostic: Episode traces ---")
print("Obs (normalised): [time, last_int, svc_cnt, avg_int, dur] - shown denormalised below")
env = ServiceEnv(scenario, max_time=MAX_TIME, seed=123)
for ep in range(3):
    obs, info = env.reset(seed=123 + ep)
    # The x10 / x MAX_TIME factors undo the observation scaling; they are
    # assumed to mirror ServiceEnv's normalisation -- confirm against the env.
    print(f"\nEpisode {ep+1}: durability={obs[4]*10:.2f}")
    total_reward = 0
    step = 0
    term = trunc = False
    while not (term or trunc):
        action, _ = trained_model.predict(obs, deterministic=True)
        delay = ServiceEnv.ACTION_DELAYS[int(action)]
        obs, reward, term, trunc, info = env.step(action)
        total_reward += reward
        step += 1
        if step > 5:
            # Only the first five steps are traced; keep stepping quietly.
            continue
        t = obs[0] * MAX_TIME
        last_int = obs[1] * MAX_TIME
        print(f"  step {step}: t={t:5.1f}, last_int={last_int:5.1f}, action={int(action)} (delay={delay:4.0f}), reward={reward:7.1f}")
    # The loop exits only once the env reports termination or truncation.
    status = "FAILED" if term else "TRUNCATED"
    print(f"  ... {status} at t={info['time']:.1f}, services={info['service_count']}, total={total_reward:.1f}")

## 5. Analyse DQN Behaviour

In [None]:
# Analyse which action the trained agent picks in the states it visits.
import pandas as pd

from src.rl.environment import ServiceEnv


def _decision_record(obs, action):
    """Return one denormalised (state, action) row for the analysis table.

    Obs (normalised): [current_time, last_interval, service_count,
    avg_interval, durability]. The scale factors used here are assumed to
    mirror ServiceEnv's normalisation -- confirm against the environment.
    """
    action_idx = int(action)
    return {
        'current_time': obs[0] * MAX_TIME,
        'last_interval': obs[1] * MAX_TIME,
        'service_count': obs[2] * 100,
        'avg_interval': obs[3] * MAX_TIME,
        'durability': obs[4] * 10,
        'action': action_idx,
        'delay': ServiceEnv.ACTION_DELAYS[action_idx],
    }


# Fresh, unwrapped environment for sampling states.
env = ServiceEnv(scenario, max_time=MAX_TIME, seed=42)

# Roll out 100 greedy episodes, logging the state the agent saw together
# with the action it chose (recorded before the environment is stepped).
state_actions = []
for ep in range(100):
    obs, _ = env.reset(seed=42 + ep)
    while True:
        action, _ = trained_model.predict(obs, deterministic=True)
        state_actions.append(_decision_record(obs, action))
        obs, _, terminated, truncated, _ = env.step(action)
        if terminated or truncated:
            break

df = pd.DataFrame(state_actions)

print("Action distribution:")
print(df['delay'].value_counts().sort_index())
print(f"\nTotal decisions: {len(df)}")

In [None]:
# Plot action choices: three views of what the agent does in which state.
# All drawing goes through explicit fig/ax handles rather than pyplot's
# implicit "current axes", so each call provably targets the intended panel.
fig, axes = plt.subplots(1, 3, figsize=(14, 4))

# Panel 1: chosen delay vs last interval, coloured by durability.
# NOTE(review): rows whose delay is inf ("no service") cannot be placed on a
# finite axis, so they are presumably absent from this panel -- confirm.
ax = axes[0]
scatter = ax.scatter(df['last_interval'], df['delay'], c=df['durability'],
                     alpha=0.5, s=20, cmap='viridis')
ax.set_xlabel('Last Interval')
ax.set_ylabel('Chosen Delay')
ax.set_title('Action vs Last Interval')
fig.colorbar(scatter, ax=ax, label='Durability')

# Panel 2: durability vs last interval, one scatter series (and legend
# entry) per action the agent actually used.
ax = axes[1]
for action_idx, delay in enumerate(ServiceEnv.ACTION_DELAYS):
    mask = df['action'] == action_idx
    if not mask.any():
        continue  # action never chosen: no series, no legend entry
    label = f'delay={delay}' if delay < float('inf') else 'no service'
    ax.scatter(df.loc[mask, 'durability'], df.loc[mask, 'last_interval'],
               alpha=0.3, s=10, label=label)
ax.set_xlabel('Durability')
ax.set_ylabel('Last Interval')
ax.set_title('Action by Durability & Last Interval')
ax.legend(markerscale=3, fontsize=8)

# Panel 3: average finite delay per durability bin.
ax = axes[2]
df['durability_bin'] = pd.cut(df['durability'], bins=8)
finite_delays = df[df['delay'] < float('inf')]  # inf would swamp the mean
if len(finite_delays) > 0:
    avg_delay = finite_delays.groupby('durability_bin', observed=True)['delay'].mean()
    avg_delay.plot(kind='bar', ax=ax, color='steelblue')
    ax.tick_params(axis='x', labelrotation=45)
# Labels are set outside the branch so that the panel is still identifiable
# when the agent never chose a finite delay.
ax.set_xlabel('Durability Bin')
ax.set_ylabel('Average Delay')
ax.set_title('Avg Delay by Durability')

plt.tight_layout()
plt.show()

# What interval does optimal linear use?
print("\nOptimal linear policy intervals for reference:")
print("  interval = 24.2 + 24.2 * durability")
for d in [0.7, 1.0, 1.3, 1.6]:
    print(f"  durability={d:.1f} -> interval={24.2 + 24.2*d:.1f}")

## 6. Save Model

In [None]:
# Persist the trained agent for reuse outside this notebook.
# NOTE(review): stable-baselines3 presumably appends ".zip" to a path that
# has no suffix, so the file on disk may be model_path + ".zip" -- confirm.
MODEL_DIR = '../models'
MODEL_NAME = 'dqn_service_policy'
model_path = MODEL_DIR + '/' + MODEL_NAME
trained_model.save(model_path)
print(f"Model saved to {model_path}")