# Stage 3 — Thermal Control: DQN vs PID

This notebook demonstrates the custom **ThermalControlEnv** — a Gymnasium
environment simulating avionics thermal management.

**Scenario:** A CPU inside an aircraft electronics box generates heat.
The agent sets a fan to one of 5 discrete speed levels to keep the temperature
inside a safe band (50–60°C) while minimising energy consumption.

We compare:
1. **Classical PID controller** (hand-tuned baseline)
2. **Trained DQN agent** (policy learned by the `DDQNAgent` class, loaded from a saved checkpoint)

In [None]:
import sys
from pathlib import Path

# Locate the project root so that the `src.*` packages can be imported:
# when the kernel starts inside a `notebooks` directory the root is its parent,
# otherwise the current working directory is taken as the root.
_cwd = Path.cwd()
PROJECT_ROOT = _cwd.parent if _cwd.name == "notebooks" else _cwd

# Put the root at the front of the import path, but only once, so that
# re-running this cell does not keep growing `sys.path`.
_root_entry = str(PROJECT_ROOT)
if _root_entry not in sys.path:
    sys.path.insert(0, _root_entry)

import numpy as np
import torch
import matplotlib.pyplot as plt

from src.environments.thermal_control_env import ThermalControlEnv
from src.control.pid_controller import PIDController
from src.control.control_utils import (
    compute_settling_time, compute_overshoot,
    compute_steady_state_error, compute_integral_absolute_error,
    compute_energy_cost,
)
from src.agents.ddqn_agent import DDQNAgent
from src.utils.config_loader import load_config, get_device
from src.utils.plotting import plot_thermal_trajectory, plot_controller_comparison

# Record the PyTorch version in the cell output so a run can be tied to the
# library version that produced it.
print(f"PyTorch {torch.__version__}")

## 1. Explore the Environment

In [None]:
# Read the experiment configuration, select the compute device from it, and
# build the thermal environment that every later cell reuses.
config = load_config(PROJECT_ROOT / "config" / "thermal_control.yaml")
device = get_device(config)
env = ThermalControlEnv(config=config)

# Spaces first: what the agent observes and what it may choose.
print(f"Observation space: {env.observation_space}")
print(f"Action space:      {env.action_space} (fan levels)")

# Then the physical parameters, read straight from the environment's attributes.
print(f"\nThermal parameters:")
thermal_summary = (
    f"  Target temp:    {env.target_temp}°C ± {env.temp_tolerance}°C",
    f"  Critical temp:  {env.critical_temp}°C",
    f"  Thermal mass:   {env.thermal_mass} J/°C",
    f"  Thermal R:      {env.thermal_resistance} °C/W",
    f"  Fan cooling:    {env.fan_cooling} W",
    f"  Fan cost:       {env.fan_cost}",
)
for summary_line in thermal_summary:
    print(summary_line)

## 2. Random Agent Baseline

In [None]:
# Random-policy baseline: sample a fan level uniformly at every step and record
# the resulting temperature, fan level and reward.
#
# Reproducibility: in Gymnasium, `env.reset(seed=...)` seeds the environment's
# own RNG, while `action_space.sample()` draws from a separate RNG owned by the
# space. Seed both so that Restart & Run All reproduces the same baseline.
state, info = env.reset(seed=42)
env.action_space.seed(42)
random_temps, random_fans, random_rewards = [], [], []

for _ in range(500):
    action = env.action_space.sample()
    state, reward, terminated, truncated, info = env.step(action)
    random_temps.append(info["temperature"])
    random_fans.append(info["fan_level"])
    random_rewards.append(reward)
    # Stop as soon as the episode ends, whether by termination or truncation.
    if terminated or truncated:
        break

print(f"Random agent: {len(random_temps)} steps, "
      f"total reward = {sum(random_rewards):.1f}, "
      f"temp range = [{min(random_temps):.1f}, {max(random_temps):.1f}]°C")

# Top panel: temperature trace with the 55 °C target line and the 50–60 °C band.
# Bottom panel: the fan level chosen at each step.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 6), sharex=True)
ax1.plot(random_temps, color="tomato")
ax1.axhline(55, color="green", linestyle="--", alpha=0.6)
ax1.axhspan(50, 60, alpha=0.1, color="green")
ax1.set_ylabel("Temperature (°C)")
ax1.set_title("Random Agent")
ax2.step(range(len(random_fans)), random_fans, color="steelblue", where="post")
ax2.set_ylabel("Fan Level")
ax2.set_xlabel("Step")
fig.tight_layout()
plt.show()

## 3. PID Controller Baseline

In [None]:
# Build the PID baseline from the shared config and start it from a clean
# internal state on a freshly reset (seeded) environment.
pid = PIDController.from_config(config)
state, info = env.reset(seed=42)
pid.reset()

# Trajectory record: "states" and "infos" start with the reset values, so they
# end up one entry longer than "actions" and "rewards".
pid_traj = {
    "states": [state.copy()],
    "actions": [],
    "rewards": [],
    "infos": [info],
}

for _ in range(500):
    # The first observation component is fed to the controller as the measured
    # temperature. NOTE(review): assumes state[0] is in °C, not normalised — confirm.
    temperature = state[0]
    action = pid.compute_action(temperature)
    state, reward, terminated, truncated, info = env.step(action)

    for field, value in (
        ("actions", action),
        ("rewards", reward),
        ("states", state.copy()),
        ("infos", info),
    ):
        pid_traj[field].append(value)

    if terminated or truncated:
        break

# Summary statistics come from the per-step info dicts (reset info included).
pid_temps = np.array([step_info["temperature"] for step_info in pid_traj["infos"]])
pid_total = sum(pid_traj["rewards"])
print(f"PID: {len(pid_traj['actions'])} steps, total reward = {pid_total:.1f}")
print(f"  Temp range: [{pid_temps.min():.1f}, {pid_temps.max():.1f}]°C")

fig = plot_thermal_trajectory(pid_traj, config, title="PID Controller")
plt.show()

## 4. Load Trained DQN Agent

In [None]:
# Construct the agent, then restore the best saved weights if a checkpoint exists.
agent = DDQNAgent(config, device)

ckpt_path = PROJECT_ROOT.joinpath(
    "outputs", "models", "thermal_control", "checkpoint_best.pt"
)
if not ckpt_path.exists():
    # NOTE(review): nothing stops execution here, so the later rollout cells
    # still run and use this freshly constructed agent.
    print("No checkpoint found. Train first:")
    print("  python train.py --config config/thermal_control.yaml")
else:
    ckpt = agent.load(ckpt_path)
    # Disable exploration for evaluation.
    agent.epsilon = 0.0
    print(f"Loaded checkpoint from episode {ckpt.get('episode', '?')}")
    print(f"Best eval reward: {ckpt.get('best_eval_reward', 'N/A')}")

print(f"\nNetwork:\n{agent.online_net}")

## 5. DQN Agent Rollout

In [None]:
# Roll out the agent in evaluation mode from the same seeded reset used for the
# PID baseline, recording the trajectory in the same dict layout.
state, info = env.reset(seed=42)
dqn_traj = {
    "states": [state.copy()],
    "actions": [],
    "rewards": [],
    "infos": [info],
}

for _ in range(500):
    action = agent.select_action(state, eval_mode=True)
    state, reward, terminated, truncated, info = env.step(action)

    for field, value in (
        ("actions", action),
        ("rewards", reward),
        ("states", state.copy()),
        ("infos", info),
    ):
        dqn_traj[field].append(value)

    if terminated or truncated:
        break

# Summary statistics come from the per-step info dicts (reset info included).
dqn_temps = np.array([step_info["temperature"] for step_info in dqn_traj["infos"]])
dqn_total = sum(dqn_traj["rewards"])
print(f"DQN: {len(dqn_traj['actions'])} steps, total reward = {dqn_total:.1f}")
print(f"  Temp range: [{dqn_temps.min():.1f}, {dqn_temps.max():.1f}]°C")

fig = plot_thermal_trajectory(dqn_traj, config, title="DQN Agent")
plt.show()

## 6. Head-to-Head Comparison

In [None]:
# Plot both recorded trajectories with the project's comparison helper.
# Requires `dqn_traj` and `pid_traj` from the rollout cells above.
fig = plot_controller_comparison(dqn_traj, pid_traj, config)
plt.show()

## 7. Quantitative Control Metrics

In [None]:
# Thermal settings used by the metrics; each falls back to a default when the
# key is missing from the config.
thermal_cfg = config.get("environment", {}).get("thermal", {})
target = thermal_cfg.get("target_temp", 55.0)
dt = thermal_cfg.get("dt", 1.0)
fan_cost_table = thermal_cfg.get("fan_energy_cost", [0, 0.5, 1.5, 3.5, 7.0])

print(f"{'Metric':<30} {'DQN':>12} {'PID':>12}")
print("─" * 55)

# Compute the same set of control metrics for each recorded trajectory.
metrics = {}
for label, traj in [("DQN", dqn_traj), ("PID", pid_traj)]:
    temps = np.array([i["temperature"] for i in traj["infos"]])
    actions = np.array(traj["actions"])

    # Only a missing result (None) is shown as NaN. A settling time of 0.0 is a
    # valid value and must not be replaced, which `value or nan` would do
    # because 0.0 is falsy.
    settling_time = compute_settling_time(temps, target, 0.05, dt)

    metrics[label] = {
        "Total Reward": sum(traj["rewards"]),
        "Settling Time (s)": float("nan") if settling_time is None else settling_time,
        "Overshoot (%)": compute_overshoot(temps, target),
        "Steady-State Error (°C)": compute_steady_state_error(temps, target),
        "IAE": compute_integral_absolute_error(temps, target, dt),
        "Energy Cost": compute_energy_cost(actions, fan_cost_table),
        # Share of samples within ±5 °C of the target.
        "In-Band %": 100 * np.mean(np.abs(temps - target) <= 5.0),
    }

# One table row per metric, DQN and PID side by side.
for metric_name in metrics["DQN"]:
    dqn_val = metrics["DQN"][metric_name]
    pid_val = metrics["PID"][metric_name]
    print(f"{metric_name:<30} {dqn_val:>12.2f} {pid_val:>12.2f}")

## 8. Environment Dynamics Exploration

Visualise how the thermal system responds to each fixed fan level, and plot the
sinusoidal heat-generation profile built from the configured workload parameters.

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
response_ax, workload_ax = axes

# Left panel: hold each fan level constant from the same seeded reset and trace
# the first observation component, which is plotted as the temperature.
for fan_level in range(5):
    state, _ = env.reset(seed=0)
    temps = [state[0]]
    for _ in range(200):
        state, _, terminated, truncated, _ = env.step(fan_level)
        temps.append(state[0])
        if terminated or truncated:
            break
    response_ax.plot(temps, label=f"Fan {fan_level}")

# Target line and safe band, for visual reference.
response_ax.axhline(55, color="green", linestyle="--", alpha=0.5)
response_ax.axhspan(50, 60, alpha=0.08, color="green")
response_ax.set(
    xlabel="Step",
    ylabel="Temperature (°C)",
    title="Step Response per Fan Level",
)
response_ax.legend(fontsize=9)

# Workload profile
# Right panel: a sinusoid rebuilt from the configured workload parameters,
# with defaults used for any missing key.
steps = np.arange(500)
freq = thermal_cfg.get("workload_frequency", 0.02)
base = thermal_cfg.get("heat_generation_base", 30.0)
var = thermal_cfg.get("heat_generation_var", 15.0)
phase = 2 * np.pi * freq * steps
workload = base + var * np.sin(phase)
workload_ax.plot(steps, workload, color="orange")
workload_ax.set(
    xlabel="Step",
    ylabel="Heat Generation (W)",
    title="Workload Profile (sinusoidal)",
)

fig.tight_layout()
plt.show()

# Last use of the environment in this notebook.
env.close()