In [3]:
"""
CQL (Conservative Q-Learning) - 连续动作空间训练脚本
ICU药物剂量优化 - 离线强化学习

特点：
- 连续动作空间（药物剂量 mg）
- 以stay_id为轨迹，每4小时为一个状态
- 使用高斯策略 + Twin Q网络
- 包含策略评估（FQE）和推理接口

运行方式：
1. 直接运行: python cql_continuous_train.py
2. 或在Jupyter Notebook中按cell执行
"""

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Tuple, Dict, List
import warnings
import json
warnings.filterwarnings('ignore')

# ========== 配置 ==========
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)

# 数据列定义
STATE_COLS = [
    "vanco_level(ug/mL)",
    "creatinine(mg/dL)",
    "wbc(K/uL)",
    "bun(mg/dL)",
    "temperature",
    "sbp",
    "heart_rate"
]
ACTION_COL = "totalamount_mg"
REWARD_COL = "step_reward"
TIME_COLS = ["stay_id", "step_4hr"]

print(f"使用设备: {DEVICE}")
print("=" * 80)

# ========== 1. 数据加载 ==========
print("1. 加载数据...")
df = pd.read_csv("ready_data.csv")
print(f"数据形状: {df.shape}")
print(f"唯一stay_id数: {df['stay_id'].nunique()}")
print(f"\n动作统计:")
print(df[ACTION_COL].describe())
print(f"\n奖励统计:")
print(df[REWARD_COL].describe())

# ========== 2. 数据预处理 ==========
print("\n2. 数据预处理...")
df[STATE_COLS] = df[STATE_COLS].fillna(df[STATE_COLS].median())

# 按stay_id划分训练集和验证集
# 注意：验证集的作用是防止过拟合，不是评估策略好坏
# - 验证集的动作不一定是最优的
# - 验证损失只是训练过程的监控指标
# - 真正的策略评估应该用FQE（策略价值），见 evaluate_cql_continuous.py
stay_ids = df["stay_id"].unique()
train_stay_ids, val_stay_ids = train_test_split(stay_ids, test_size=0.2, random_state=42)
train_df = df[df["stay_id"].isin(train_stay_ids)].reset_index(drop=True)
val_df = df[df["stay_id"].isin(val_stay_ids)].reset_index(drop=True)

print(f"训练集: {len(train_df)} 条记录, {len(train_stay_ids)} 个stay_id")
print(f"验证集: {len(val_df)} 条记录, {len(val_stay_ids)} 个stay_id")

# 标准化
state_scaler = StandardScaler()
train_states = state_scaler.fit_transform(train_df[STATE_COLS])
val_states = state_scaler.transform(val_df[STATE_COLS])

action_scaler = StandardScaler()
train_actions = action_scaler.fit_transform(train_df[[ACTION_COL]]).flatten()
val_actions = action_scaler.transform(val_df[[ACTION_COL]]).flatten()

# 奖励归一化
r_min = train_df[REWARD_COL].min()
r_max = train_df[REWARD_COL].max()
r_range = r_max - r_min if r_max != r_min else 1.0
train_rewards = (train_df[REWARD_COL].values - r_min) / r_range
val_rewards = (val_df[REWARD_COL].values - r_min) / r_range

print(f"状态维度: {train_states.shape[1]}")
print(f"动作范围（原始）: [{train_df[ACTION_COL].min():.1f}, {train_df[ACTION_COL].max():.1f}] mg")

# 构建转移
def build_transitions(df_src, scaled_states):
    next_states = np.zeros_like(scaled_states)
    dones = np.zeros(len(df_src), dtype=np.float32)
    for stay_id in df_src['stay_id'].unique():
        stay_mask = df_src['stay_id'] == stay_id
        stay_data = df_src[stay_mask].sort_values('step_4hr').reset_index(drop=True)
        stay_indices = np.where(stay_mask)[0]
        for i, idx in enumerate(stay_indices):
            if i < len(stay_indices) - 1:
                next_states[idx] = scaled_states[stay_indices[i + 1]]
                dones[idx] = 0.0
            else:
                next_states[idx] = scaled_states[idx]
                dones[idx] = 1.0
    return next_states.astype(np.float32), dones.astype(np.float32)

train_next_states, train_dones = build_transitions(train_df, train_states)
val_next_states, val_dones = build_transitions(val_df, val_states)

# ========== 3. 数据集 ==========
print("\n3. 构建数据集...")
class RLDataset(Dataset):
    def __init__(self, states, actions, rewards, next_states, dones):
        self.states = torch.FloatTensor(states)
        self.actions = torch.FloatTensor(actions)
        self.rewards = torch.FloatTensor(rewards)
        self.next_states = torch.FloatTensor(next_states)
        self.dones = torch.FloatTensor(dones)
    def __len__(self):
        return len(self.states)
    def __getitem__(self, idx):
        return (self.states[idx], self.actions[idx], self.rewards[idx], 
                self.next_states[idx], self.dones[idx])

train_dataset = RLDataset(train_states, train_actions, train_rewards, train_next_states, train_dones)
val_dataset = RLDataset(val_states, val_actions, val_rewards, val_next_states, val_dones)
train_loader = DataLoader(train_dataset, batch_size=256, shuffle=True, drop_last=True)
val_loader = DataLoader(val_dataset, batch_size=256, shuffle=False, drop_last=False)

# ========== 4. 模型定义 ==========
print("\n4. 定义模型...")
def create_mlp(input_dim, output_dim, hidden_sizes=(256, 256), activation=nn.ReLU):
    layers = []
    last_dim = input_dim
    for h in hidden_sizes:
        layers.extend([nn.Linear(last_dim, h), activation()])
        last_dim = h
    layers.append(nn.Linear(last_dim, output_dim))
    return nn.Sequential(*layers)

class GaussianPolicy(nn.Module):
    """高斯策略（连续动作空间）"""
    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        self.net = create_mlp(state_dim, 2 * action_dim, hidden_sizes)
        self.log_std_min = -5
        self.log_std_max = 2
    def forward(self, state):
        mean_logstd = self.net(state)
        mean, log_std = torch.chunk(mean_logstd, 2, dim=-1)
        log_std = torch.tanh(log_std)
        log_std = self.log_std_min + 0.5 * (log_std + 1) * (self.log_std_max - self.log_std_min)
        std = torch.exp(log_std)
        return mean, std
    def sample(self, state):
        mean, std = self(state)
        normal = torch.distributions.Normal(mean, std)
        x_t = normal.rsample()
        action = torch.tanh(x_t)
        log_prob = normal.log_prob(x_t) - torch.log(1 - action.pow(2) + 1e-6)
        log_prob = log_prob.sum(dim=-1, keepdim=True)
        return action, log_prob, torch.tanh(mean)
    def get_deterministic_action(self, state):
        mean, _ = self(state)
        return torch.tanh(mean)

class QNetwork(nn.Module):
    """Q网络"""
    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        self.net = create_mlp(state_dim + action_dim, 1, hidden_sizes)
    def forward(self, state, action):
        sa = torch.cat([state, action], dim=-1)
        return self.net(sa)

class CQLAgent(nn.Module):
    """CQL Agent: Twin Q + 高斯策略"""
    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        self.q1 = QNetwork(state_dim, action_dim, hidden_sizes)
        self.q2 = QNetwork(state_dim, action_dim, hidden_sizes)
        self.q1_target = QNetwork(state_dim, action_dim, hidden_sizes)
        self.q2_target = QNetwork(state_dim, action_dim, hidden_sizes)
        self.policy = GaussianPolicy(state_dim, action_dim, hidden_sizes)
        self.q1_target.load_state_dict(self.q1.state_dict())
        self.q2_target.load_state_dict(self.q2.state_dict())
    @torch.no_grad()
    def soft_update(self, tau=0.005):
        for target_param, param in zip(self.q1_target.parameters(), self.q1.parameters()):
            target_param.data.mul_(1 - tau).add_(tau * param.data)
        for target_param, param in zip(self.q2_target.parameters(), self.q2.parameters()):
            target_param.data.mul_(1 - tau).add_(tau * param.data)

state_dim = train_states.shape[1]
action_dim = 1
agent = CQLAgent(state_dim, action_dim).to(DEVICE)

# ========== 修复：重新初始化Q网络为小值 ==========
# 默认初始化可能导致Q值过大或过小，导致训练不稳定
def init_q_network_small(m):
    """将Q网络初始化为接近0的小值，提高训练稳定性"""
    if isinstance(m, nn.Linear):
        if m.out_features == 1:  # Q网络输出层
            # 输出层初始化为接近0的小值
            nn.init.uniform_(m.weight, -0.01, 0.01)
            nn.init.constant_(m.bias, 0.0)
        else:
            # 隐藏层使用较小的Xavier初始化
            nn.init.xavier_uniform_(m.weight, gain=0.5)
            nn.init.constant_(m.bias, 0.0)

agent.q1.apply(init_q_network_small)
agent.q2.apply(init_q_network_small)
agent.q1_target.load_state_dict(agent.q1.state_dict())
agent.q2_target.load_state_dict(agent.q2.state_dict())

# 验证初始化
agent.eval()
with torch.no_grad():
    sample_state = torch.FloatTensor(train_states[:5]).to(DEVICE)
    sample_action = torch.FloatTensor(train_actions[:5]).unsqueeze(1).to(DEVICE)
    q1_init = agent.q1(sample_state, sample_action)
    q2_init = agent.q2(sample_state, sample_action)
    print(f"初始化后Q1值范围: [{q1_init.min().item():.4f}, {q1_init.max().item():.4f}]")
    print(f"初始化后Q2值范围: [{q2_init.min().item():.4f}, {q2_init.max().item():.4f}]")
agent.train()

print(f"状态维度: {state_dim}, 动作维度: {action_dim}")
print(f"模型参数数量: {sum(p.numel() for p in agent.parameters()):,}")

# ========== 5. CQL损失函数 ==========
print("\n5. 定义CQL损失函数...")
def compute_cql_loss(agent, batch, config):
    states, actions, rewards, next_states, dones = batch
    states = states.to(DEVICE)
    actions = actions.unsqueeze(1).to(DEVICE)
    rewards = rewards.unsqueeze(1).to(DEVICE)
    next_states = next_states.to(DEVICE)
    dones = dones.unsqueeze(1).to(DEVICE)
    
    # Bellman误差
    q1_data = agent.q1(states, actions)
    q2_data = agent.q2(states, actions)
    
    with torch.no_grad():
        next_actions, next_logp, _ = agent.policy.sample(next_states)
        next_q1_target = agent.q1_target(next_states, next_actions)
        next_q2_target = agent.q2_target(next_states, next_actions)
        next_q_target = torch.min(next_q1_target, next_q2_target) - next_logp
        backup = rewards + config['gamma'] * (1 - dones) * next_q_target
    
    q1_loss = F.huber_loss(q1_data, backup, delta=1.0)
    q2_loss = F.huber_loss(q2_data, backup, delta=1.0)
    
    # CQL正则项
    batch_size = states.shape[0]
    cql_samples = config.get('cql_samples', 10)
    random_actions = torch.empty(batch_size, cql_samples, action_dim, device=DEVICE).uniform_(-1, 1)
    policy_actions, _, _ = agent.policy.sample(states)
    states_rep = states.unsqueeze(1).expand(-1, cql_samples, -1).reshape(-1, states.shape[-1])
    random_actions_flat = random_actions.reshape(-1, action_dim)
    q1_rand = agent.q1(states_rep, random_actions_flat).reshape(batch_size, cql_samples, 1)
    q2_rand = agent.q2(states_rep, random_actions_flat).reshape(batch_size, cql_samples, 1)
    q1_policy = agent.q1(states, policy_actions)
    q2_policy = agent.q2(states, policy_actions)
    q1_cat = torch.cat([q1_rand, q1_policy.unsqueeze(1)], dim=1)
    q2_cat = torch.cat([q2_rand, q2_policy.unsqueeze(1)], dim=1)
    cql1 = torch.logsumexp(q1_cat, dim=1).mean() - q1_data.mean()
    cql2 = torch.logsumexp(q2_cat, dim=1).mean() - q2_data.mean()
    
    # 策略损失
    policy_dist = torch.distributions.Normal(*agent.policy(states))
    policy_log_prob = policy_dist.log_prob(actions.squeeze(1)).sum(dim=-1, keepdim=True)
    policy_loss = -policy_log_prob.mean()
    
    total_loss = q1_loss + q2_loss + config['alpha'] * (cql1 + cql2) + 0.1 * policy_loss
    
    info = {
        'total_loss': total_loss.item(),
        'q_loss': (q1_loss + q2_loss).item(),
        'cql1': cql1.item(),
        'cql2': cql2.item(),
        'policy_loss': policy_loss.item(),
        'q1_mean': q1_data.mean().item(),
    }
    return total_loss, info

# ========== 6. 训练配置 ==========
print("\n6. 训练配置...")
config = {
    'batch_size': 256,
    'lr': 1e-4,
    'gamma': 0.99,
    'alpha': 0.01,  # 降低到0.01，避免Q值被CQL正则项拉低
    'tau': 0.005,
    'cql_samples': 10,
    'epochs': 30,
    'steps_per_epoch': 1000,
    'val_interval': 2,
}

optimizer = optim.Adam(
    list(agent.q1.parameters()) + 
    list(agent.q2.parameters()) + 
    list(agent.policy.parameters()),
    lr=config['lr']
)

# ========== 7. 训练循环 ==========
print("\n7. 开始训练...")
print("=" * 80)

def evaluate(agent, val_loader, config):
    """
    评估验证集损失（用于防止过拟合，不是策略评估）
    
    注意：
    - 验证集的动作不一定是最优的
    - 验证损失只是训练过程的监控指标
    - 真正的策略评估应该用FQE（见 evaluate_cql_continuous.py）
    """
    agent.eval()
    losses = []
    with torch.no_grad():
        for batch in val_loader:
            loss, _ = compute_cql_loss(agent, batch, config)
            losses.append(loss.item())
    agent.train()
    return np.mean(losses)

train_losses = []
val_losses = []
best_val_loss = float('inf')
patience = 10
patience_counter = 0

for epoch in range(1, config['epochs'] + 1):
    agent.train()
    epoch_losses = []
    for step in range(config['steps_per_epoch']):
        batch = next(iter(train_loader))
        loss, info = compute_cql_loss(agent, batch, config)
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(agent.parameters(), 5.0)
        optimizer.step()
        agent.soft_update(config['tau'])
        epoch_losses.append(info['total_loss'])
    
    avg_loss = np.mean(epoch_losses)
    train_losses.append(avg_loss)
    
    if epoch % config['val_interval'] == 0:
        # 验证损失用于防止过拟合，选择最佳模型
        # 注意：这不是策略评估！真正的策略评估用FQE（见 evaluate_cql_continuous.py）
        val_loss = evaluate(agent, val_loader, config)
        val_losses.append(val_loss)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            torch.save({
                'agent': agent.state_dict(),
                'state_scaler': state_scaler,
                'action_scaler': action_scaler,
                'config': config,
                'r_min': r_min,
                'r_range': r_range,
            }, 'cql_best_model.pt')
            print(f"Epoch {epoch:3d} | 训练损失: {avg_loss:.4f} | 验证损失: {val_loss:.4f} ✓ (最佳模型已保存)")
        else:
            patience_counter += 1
            print(f"Epoch {epoch:3d} | 训练损失: {avg_loss:.4f} | 验证损失: {val_loss:.4f} (耐心: {patience_counter}/{patience})")
            if patience_counter >= patience:
                print(f"\n早停触发！最佳验证损失: {best_val_loss:.4f}")
                print("提示：验证损失只是训练指标，策略评估请运行 evaluate_cql_continuous.py")
                break
    else:
        print(f"Epoch {epoch:3d} | 训练损失: {avg_loss:.4f}")

print(f"\n训练完成！最佳验证损失: {best_val_loss:.4f}")

# ========== 8. 保存最终模型 ==========
print("\n8. 保存最终模型...")
torch.save({
    'agent': agent.state_dict(),
    'state_scaler': state_scaler,
    'action_scaler': action_scaler,
    'config': config,
    'r_min': r_min,
    'r_range': r_range,
    'state_dim': state_dim,
    'action_dim': action_dim,
    'STATE_COLS': STATE_COLS,
    'ACTION_COL': ACTION_COL,
    'REWARD_COL': REWARD_COL,
}, 'cql_final_model.pt')

print("✅ 模型已保存: cql_final_model.pt")
print("\n" + "=" * 80)
print("重要提示：")
print("=" * 80)
print("1. 验证损失只是训练过程的监控指标（防止过拟合）")
print("2. 验证集的动作不一定是最优的，所以验证损失不能评估策略好坏")
print("3. 真正的策略评估应该使用FQE方法（策略价值）")
print("4. 运行评估脚本: python evaluate_cql_continuous.py")
print("=" * 80)



使用设备: cpu
1. 加载数据...
数据形状: (2113, 17)
唯一stay_id数: 58

动作统计:
count    2113.000000
mean       72.404638
std       221.459863
min         0.000000
25%         0.000000
50%         0.000000
75%         0.000000
max      1500.000000
Name: totalamount_mg, dtype: float64

奖励统计:
count    2113.000000
mean       -0.734737
std         1.173248
min        -5.700000
25%        -1.700000
50%        -0.700000
75%         0.200000
max         2.500000
Name: step_reward, dtype: float64

2. 数据预处理...
训练集: 1805 条记录, 46 个stay_id
验证集: 308 条记录, 12 个stay_id
状态维度: 7
动作范围（原始）: [0.0, 1500.0] mg

3. 构建数据集...

4. 定义模型...
初始化后Q1值范围: [-0.0028, 0.0006]
初始化后Q2值范围: [-0.0004, 0.0023]
状态维度: 7, 动作维度: 1
模型参数数量: 341,766

5. 定义CQL损失函数...

6. 训练配置...

7. 开始训练...
Epoch   1 | 训练损失: 39.0913
Epoch   2 | 训练损失: 37.2220 | 验证损失: 41.1803 ✓ (最佳模型已保存)
Epoch   3 | 训练损失: 37.7510
Epoch   4 | 训练损失: 38.0299 | 验证损失: 40.5099 ✓ (最佳模型已保存)
Epoch   5 | 训练损失: 38.2090
Epoch   6 | 训练损失: 38.6193 | 验证损失: 44.1039 (耐心: 1/10)
Epoch   7 | 训练损失: 38.8604
Epo