# Predictor-Guided Exploration 与 Dueling DQN 强化学习

本 Notebook 展示如何利用已训练的 health predictor 引导智能体的探索策略，并基于 Dueling DQN 进行训练与评估。

## 2. 实现 Predictor-Guided Exploration 策略
实现一个 epsilon-greedy 变体：在探索期，先用 health predictor 计算 $P_{fail}$，根据阈值动态调整可选动作池，实现安全约束的启发式采样。

In [None]:
import numpy as np
import random
import joblib
# Load the pre-trained health predictor; it is used below through predict_proba().
# NOTE(review): joblib.load unpickles arbitrary objects — only load files from a trusted source.
health_predictor = joblib.load('data/health_predictor_lr.pkl')  # or load the MLP variant instead

# 预测器引导的 epsilon-greedy 策略

def predictor_guided_epsilon_greedy(obs, repair_cnt, q_values, epsilon=0.6, danger_threshold=0.7):
    """
    Choose an action with an epsilon-greedy rule whose exploration branch is
    constrained by the health predictor.

    Args:
        obs: Current observation; it is reshaped to a single row before being
            passed to ``health_predictor.predict_proba``.
        repair_cnt: Number of repairs so far. Accepted for interface
            compatibility; the body does not read it.
        q_values: Q-value of every action in the current state.
        epsilon: Probability of taking the exploration branch.
        danger_threshold: Failure probability above which exploration is
            restricted to actions 1 and 2.

    Returns:
        The selected action index.
    """
    # Exploitation: a uniform draw above epsilon means "take the greedy action".
    exploit = np.random.rand() > epsilon
    if exploit:
        return int(np.argmax(q_values))

    # Exploration: column 1 of predict_proba is used as the failure probability.
    class_proba = health_predictor.predict_proba(obs.reshape(1, -1))
    p_fail = class_proba[0, 1]
    print(f"Predicted failure probability: {p_fail:.4f}")

    in_danger = p_fail > danger_threshold
    if in_danger:
        # Above the danger threshold only actions 1 and 2 remain in the pool.
        print("Danger detected! Forcing repair action.")
    candidates = [1, 2] if in_danger else [0, 1, 2]
    return random.choice(candidates)

## 3. 环境与状态空间初始化
初始化 Gym 环境，定义状态、动作空间，设置随机种子，准备 DQN 训练所需的环境交互接口。

In [8]:
import torch
import random
from student_client import create_student_gym_env

def set_seed(seed=42):
    """Seed NumPy, Python's ``random`` module and PyTorch with the same value.

    Args:
        seed: Seed applied to every generator.
    """
    for seed_fn in (np.random.seed, random.seed, torch.manual_seed):
        seed_fn(seed)
    # CUDA generators are only seeded when a CUDA device is available.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG before the environment is created.
set_seed(42)

# Create the environment and read the sizes used to build the networks.
env = create_student_gym_env()
obs_dim = env.observation_space.shape[0]  # length of the observation vector
act_dim = env.action_space.n  # number of discrete actions
print(f"状态维度: {obs_dim}, 动作数: {act_dim}")

2026-02-22 13:27:24,885 - student_client.student_gym_env - INFO - Client version 0.3 is newer than latest 0.2
2026-02-22 13:27:24,911 - student_client.student_gym_env - INFO - Created new session: 4d4a2a0c-142c-4653-945d-e13e01b8a720
2026-02-22 13:27:25,820 - student_client.student_gym_env - INFO - Created new episode: 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd
2026-02-22 13:27:25,824 - student_client.student_gym_env - INFO - StudentGymEnv initialized with episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd


状态维度: 9, 动作数: 3


## 4. Dueling DQN 网络结构定义
用 PyTorch 定义 Dueling DQN 网络，包括 value stream 和 advantage stream，输出每个动作的 Q 值。

In [9]:
import torch.nn as nn

class DuelingDQN(nn.Module):
    """Dueling DQN: a shared feature layer feeding separate value and advantage heads.

    Args:
        obs_dim: Size of the observation vector (number of input features).
        act_dim: Number of discrete actions (number of Q-values produced).
    """

    def __init__(self, obs_dim, act_dim):
        super().__init__()
        # Shared trunk used by both heads.
        self.feature = nn.Sequential(
            nn.Linear(obs_dim, 128),
            nn.ReLU(),
        )
        # Value head: one scalar V(s) per observation.
        self.value_stream = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, 1)
        )
        # Advantage head: one A(s, a) per action.
        self.advantage_stream = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(),
            nn.Linear(64, act_dim)
        )

    def forward(self, x):
        """Compute Q(s, a) = V(s) + A(s, a) - mean_a A(s, a).

        Args:
            x: Observations whose last axis has size ``obs_dim``. Any number of
                leading axes is accepted, including none (a single
                unbatched observation).

        Returns:
            Q-values with the same leading axes as ``x`` and a last axis of
            size ``act_dim``.
        """
        feat = self.feature(x)
        value = self.value_stream(feat)
        adv = self.advantage_stream(feat)
        # The action axis is always the last one. Averaging over dim=-1 (rather
        # than a hard-coded dim=1) keeps the centring correct for unbatched
        # inputs and for inputs with extra leading axes.
        q = value + adv - adv.mean(dim=-1, keepdim=True)
        return q

## 5. 经验回放与优化器设置
实现经验回放缓冲区（Replay Buffer），设置优化器、损失函数、目标网络同步等 DQN 训练组件。

In [12]:
from collections import deque
import torch.optim as optim

class ReplayBuffer:
    """Fixed-capacity FIFO store of ``(s, a, r, s_, d)`` transitions."""

    def __init__(self, capacity=10000):
        # A bounded deque drops the oldest transition once capacity is reached.
        self.buffer = deque(maxlen=capacity)

    def push(self, s, a, r, s_, d):
        """Append one transition to the buffer."""
        transition = (s, a, r, s_, d)
        self.buffer.append(transition)

    def sample(self, batch_size):
        """Draw ``batch_size`` distinct transitions uniformly at random.

        Returns:
            A 5-tuple of arrays ``(s, a, r, s_, d)``; the two state fields are
            stacked along a new leading axis.
        """
        picks = np.random.choice(len(self.buffer), batch_size, replace=False)
        states, actions, rewards, next_states, dones = zip(
            *(self.buffer[i] for i in picks)
        )
        return (
            np.stack(states),
            np.array(actions),
            np.array(rewards),
            np.stack(next_states),
            np.array(dones),
        )

    def __len__(self):
        """Return the number of stored transitions."""
        return len(self.buffer)

# Replay buffer shared by the training loop.
buffer = ReplayBuffer(20000)

# Training is pinned to the CPU; the commented line below would auto-select CUDA instead.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')
# Two networks with identical architecture: policy_net is optimized,
# target_net starts as a copy of its weights and is kept in eval mode.
policy_net = DuelingDQN(obs_dim, act_dim).to(device)
target_net = DuelingDQN(obs_dim, act_dim).to(device)
target_net.load_state_dict(policy_net.state_dict())
target_net.eval()
# Only the policy network's parameters are handed to the optimizer.
optimizer = optim.Adam(policy_net.parameters(), lr=1e-3)
loss_fn = nn.MSELoss()

def sync_target():
    """Copy the policy network's current weights into the target network."""
    policy_weights = policy_net.state_dict()
    target_net.load_state_dict(policy_weights)

## 6. 训练 Dueling DQN 智能体（含引导探索）
主训练循环：每步用 predictor-guided epsilon-greedy 选择动作并收集经验；回放缓冲区样本数达到 batch size 后每步优化一次 DQN 网络，每隔若干回合同步一次目标网络，并记录每回合的累计奖励与步数。

In [None]:
from utils import normalize_obs, process_reward
# --- Hyperparameters ---
num_episodes = 300
batch_size = 64
gamma = 0.99  # discount factor
target_update = 20  # episodes between target-network syncs
epsilon_start = 0.2
epsilon_end = 0.05
epsilon_decay = 200  # time constant (in episodes) of the exponential epsilon decay

all_rewards = []  # cumulative processed reward of each episode
all_lengths = []  # number of environment steps of each episode


def _to_state(raw_obs):
    """Normalize one raw observation exactly once and flatten it to a 1-D float32 vector."""
    raw = np.asarray(raw_obs, dtype=np.float32).reshape(1, -1)
    # normalize_obs is called with a (1, obs_dim) row; its return shape is not
    # visible here, so the result is flattened explicitly.
    return np.asarray(normalize_obs(raw), dtype=np.float32).reshape(-1)


for ep in range(num_episodes):
    print(f"Episode {ep+1}/{num_episodes}")
    # Epsilon depends only on the episode index, so compute it once per episode.
    epsilon = epsilon_end + (epsilon_start - epsilon_end) * np.exp(-1. * ep / epsilon_decay)
    obs, info = env.reset()
    # From here on `obs` always holds the normalized, flat state. The same
    # representation is fed to the network and stored in the replay buffer,
    # so every buffered state has an identical shape and is normalized once.
    obs = _to_state(obs)
    total_reward = 0
    done = False
    repair_cnt = 0
    t = 0
    while not done:
        obs_normalized = obs.reshape(1, -1)
        with torch.no_grad():
            obs_tensor = torch.tensor(obs_normalized, dtype=torch.float32, device=device)
            q_values = policy_net(obs_tensor).cpu().numpy()[0]
        action = predictor_guided_epsilon_greedy(obs_normalized, repair_cnt, q_values, epsilon=epsilon)
        next_obs, reward, terminated, truncated, info = env.step(action)
        reward = process_reward(reward)
        next_obs = _to_state(next_obs)
        # Only a true termination removes the bootstrap term; a truncated
        # episode still has a valid next state to bootstrap from.
        buffer.push(obs, action, reward, next_obs, bool(terminated))
        obs = next_obs
        total_reward += reward
        if action == 1:
            repair_cnt += 1
        done = terminated or truncated
        t += 1
        # DQN optimization step, run once the buffer can fill a batch.
        if len(buffer) >= batch_size:
            s, a, r, s_, d = buffer.sample(batch_size)
            s = torch.tensor(s, dtype=torch.float32, device=device)
            a = torch.tensor(a, dtype=torch.long, device=device)
            r = torch.tensor(r, dtype=torch.float32, device=device)
            s_ = torch.tensor(s_, dtype=torch.float32, device=device)
            d = torch.tensor(d, dtype=torch.float32, device=device)
            # Q(s, a) of the actions that were actually taken.
            q = policy_net(s).gather(1, a.unsqueeze(1)).squeeze(1)
            with torch.no_grad():
                # One-step TD target from the frozen target network.
                q_next = target_net(s_).max(1)[0]
                q_target = r + gamma * q_next * (1 - d)
            loss = loss_fn(q, q_target)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    all_rewards.append(total_reward)
    all_lengths.append(t)
    if ep % target_update == 0:
        sync_target()
    if (ep+1) % 20 == 0:
        print(f"Episode {ep+1}, Reward: {total_reward:.2f}, Epsilon: {epsilon:.3f}")

2026-02-22 13:32:19,407 - student_client.student_gym_env - INFO - Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reset successfully


Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reached termination state, reason: sold


2026-02-22 13:32:22,697 - student_client.student_gym_env - INFO - Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reset successfully


Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reached termination state, reason: sold


2026-02-22 13:32:35,168 - student_client.student_gym_env - INFO - Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reset successfully


Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reached termination state, reason: sold


2026-02-22 13:33:06,130 - student_client.student_gym_env - INFO - Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reset successfully


Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reached termination state, reason: sold


2026-02-22 13:33:10,491 - student_client.student_gym_env - INFO - Episode 91a91e63-2c02-42d8-bcea-9d4f7b8ea5cd reset successfully


KeyboardInterrupt: 

## 7. 训练过程可视化与评估
绘制训练曲线（每回合的 episode reward）。如需评估 predictor-guided 策略带来的探索效率提升，可另行训练普通 epsilon-greedy 基线（并统计成功率等指标），再在同一张图中对比。

In [None]:
import matplotlib.pyplot as plt
# Plot the reward collected in each training episode.
plt.figure(figsize=(10,4))
plt.plot(all_rewards, label='Episode Reward')
plt.xlabel('Episode')
plt.ylabel('Reward')
plt.title('Dueling DQN with Predictor-Guided Exploration')
plt.legend()
plt.show()

# Optional: compare against vanilla epsilon-greedy (requires training a baseline separately)
# plt.plot(baseline_rewards, label='Vanilla Epsilon-Greedy')
# plt.legend()
# plt.show()