In [1]:
"""
根据患者状态预测最优策略（药物剂量）

使用方法：
    python predict_action.py
    或在代码中导入使用
"""

import numpy as np
import torch
import torch.nn as nn
from sklearn.preprocessing import StandardScaler

# Model definitions (must stay identical to the training script so checkpoints load).
def create_mlp(input_dim, output_dim, hidden_sizes=(256, 256), activation=nn.ReLU):
    """Build a feed-forward network as an ``nn.Sequential``.

    Each hidden width contributes a ``Linear`` layer followed by a fresh
    ``activation()`` module; a final ``Linear`` layer (no activation) maps to
    ``output_dim``. With an empty ``hidden_sizes`` the result is a single
    ``Linear(input_dim, output_dim)``.
    """
    widths = [input_dim, *hidden_sizes]
    modules = []
    for fan_in, fan_out in zip(widths[:-1], widths[1:]):
        modules += [nn.Linear(fan_in, fan_out), activation()]
    modules.append(nn.Linear(widths[-1], output_dim))
    return nn.Sequential(*modules)

class GaussianPolicy(nn.Module):
    """Diagonal Gaussian policy whose samples are squashed with ``tanh``.

    ``self.net`` maps a state to ``2 * action_dim`` values, split into a mean
    and a raw log-std. The raw log-std is passed through ``tanh`` and rescaled
    affinely into ``[log_std_min, log_std_max]`` = ``[-5, 2]``.
    """

    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        # Attribute names appear in state_dict keys; keep them stable.
        self.net = create_mlp(state_dim, 2 * action_dim, hidden_sizes)
        self.log_std_min = -5
        self.log_std_max = 2

    def forward(self, state):
        """Return ``(mean, std)`` of the pre-tanh Gaussian for ``state``."""
        raw_mean, raw_log_std = torch.chunk(self.net(state), 2, dim=-1)
        span = self.log_std_max - self.log_std_min
        unit = 0.5 * (torch.tanh(raw_log_std) + 1)  # rescaled to [0, 1]
        log_std = self.log_std_min + unit * span
        return raw_mean, torch.exp(log_std)

    def sample(self, state):
        """Draw a reparameterized, tanh-squashed action.

        Returns ``(action, log_prob, tanh(mean))`` where ``log_prob`` is summed
        over the last dimension (kept as a size-1 dim) and includes the tanh
        change-of-variables correction.
        """
        mean, std = self(state)
        dist = torch.distributions.Normal(mean, std)
        pre_tanh = dist.rsample()
        squashed = torch.tanh(pre_tanh)
        # log|d tanh(x)/dx| correction; the 1e-6 keeps log() finite as |a| -> 1.
        correction = torch.log(1 - squashed.pow(2) + 1e-6)
        log_prob = (dist.log_prob(pre_tanh) - correction).sum(dim=-1, keepdim=True)
        return squashed, log_prob, torch.tanh(mean)

    def get_deterministic_action(self, state):
        """Return the squashed mean action ``tanh(mean)`` (no sampling)."""
        return torch.tanh(self(state)[0])

class QNetwork(nn.Module):
    """Critic mapping a (state, action) pair to a single scalar Q estimate."""

    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        # Input is the state and action concatenated along the last dimension.
        self.net = create_mlp(state_dim + action_dim, 1, hidden_sizes)

    def forward(self, state, action):
        """Return Q(state, action) with a trailing dimension of size 1."""
        return self.net(torch.cat((state, action), dim=-1))

class CQLAgent(nn.Module):
    """Holds the CQL networks: twin critics, their target copies, and the policy.

    This class only groups the modules so a saved state_dict can be restored;
    it defines no forward pass of its own.
    """

    def __init__(self, state_dim, action_dim, hidden_sizes=(256, 256)):
        super().__init__()
        dims = (state_dim, action_dim, hidden_sizes)
        self.q1 = QNetwork(*dims)
        self.q2 = QNetwork(*dims)
        self.q1_target = QNetwork(*dims)
        self.q2_target = QNetwork(*dims)
        self.policy = GaussianPolicy(*dims)
        # Target critics start as exact copies of the online critics.
        for online, target in ((self.q1, self.q1_target), (self.q2, self.q2_target)):
            target.load_state_dict(online.state_dict())

# Inference device: CUDA when available, otherwise CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Raw state feature order expected by the fitted state scaler and the policy
# input; it must match the column order used during training.
STATE_COLS = [
    "vanco_level(ug/mL)",  # vancomycin concentration
    "creatinine(mg/dL)",   # creatinine
    "wbc(K/uL)",           # white blood cell count
    "bun(mg/dL)",          # blood urea nitrogen
    "temperature",         # body temperature
    "sbp",                 # systolic blood pressure
    "heart_rate",          # heart rate
]


def load_model(model_path='cql_final_model.pt', device=DEVICE):
    """
    Load a trained CQL checkpoint for inference.

    Args:
        model_path: path to a checkpoint saved with ``torch.save`` containing
            the keys 'agent' (state_dict), 'state_scaler', 'action_scaler',
            'config', 'state_dim' and 'action_dim'.
        device: device (or device string) to map tensors onto and to place
            the agent on. Defaults to the module-level DEVICE, matching the
            default used by ``predict_action``; pass the same device there.

    Returns:
        agent: the CQL agent with loaded weights, in eval mode
        state_scaler: the state scaler stored in the checkpoint
        action_scaler: the action scaler stored in the checkpoint
        config: the training config stored in the checkpoint

    Raises:
        KeyError: if one of the expected checkpoint keys is missing.
    """
    device = torch.device(device)
    print(f"加载模型: {model_path}")
    # SECURITY: weights_only=False unpickles arbitrary Python objects (needed
    # here for the stored scalers) and can execute code during loading.
    # Only load checkpoints from trusted sources.
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)

    state_scaler = checkpoint['state_scaler']
    action_scaler = checkpoint['action_scaler']
    config = checkpoint['config']
    state_dim = checkpoint['state_dim']
    action_dim = checkpoint['action_dim']

    # NOTE(review): hidden_sizes is left at CQLAgent's default (256, 256); if
    # training used a different architecture, load_state_dict will fail.
    # Confirm whether `config` records the hidden sizes and pass them through.
    agent = CQLAgent(state_dim, action_dim).to(device)
    agent.load_state_dict(checkpoint['agent'])
    agent.eval()

    print("✅ 模型加载成功")
    print(f"   状态维度: {state_dim}")
    print(f"   动作维度: {action_dim}")
    print(f"   设备: {device}")

    return agent, state_scaler, action_scaler, config


def predict_action(agent, patient_state_raw, state_scaler, action_scaler, device=DEVICE):
    """
    Predict the deterministic (mean) dose for a single patient state.

    Args:
        agent: trained CQL agent exposing ``policy``, ``q1`` and ``q2``.
        patient_state_raw: one unscaled state vector ordered like STATE_COLS
            [vanco_level, creatinine, wbc, bun, temperature, sbp, heart_rate].
        state_scaler: fitted scaler applied to the state via ``transform``.
        action_scaler: fitted scaler whose ``inverse_transform`` maps the
            squashed action back to the raw dose scale.
        device: torch device the agent lives on.

    Returns:
        dict with keys:
            - 'action': recommended dose (mg, raw scale), clipped at 0
            - 'action_mean_norm': tanh(policy mean), in (-1, 1)
            - 'action_std_norm': policy std (pre-tanh) of the first action dim
            - 'q_value': average of the two critics at the mean action
            - 'confidence': heuristic 1 / (1 + std); smaller std -> higher value

    Only the first action dimension is reported.
    """
    agent.eval()

    scaled = state_scaler.transform([patient_state_raw])
    state_t = torch.FloatTensor(scaled).to(device)

    with torch.no_grad():
        # Deterministic action: squash the policy mean instead of sampling.
        mean_t, std_t = agent.policy(state_t)
        squashed = torch.tanh(mean_t)

        # Critic estimate of the chosen action, averaged over the twin Qs.
        q_avg = (agent.q1(state_t, squashed) + agent.q2(state_t, squashed)) / 2

        # Map the squashed action back to the raw dose scale.
        squashed_np = squashed.cpu().numpy()
        dose = action_scaler.inverse_transform(squashed_np.reshape(1, -1))[0][0]
        dose = max(0, dose)  # a dose cannot be negative

        std_first = std_t.cpu().numpy()[0, 0]
        confidence = 1.0 / (1.0 + std_first)

    return {
        'action': dose,
        'action_mean_norm': squashed_np[0, 0],
        'action_std_norm': std_first,
        'q_value': q_avg.cpu().numpy()[0, 0],
        'confidence': confidence,
    }


def predict_action_with_uncertainty(agent, patient_state_raw, state_scaler, action_scaler,
                                    n_samples=100, device=DEVICE):
    """
    Estimate the policy's dose distribution for one patient by sampling.

    All ``n_samples`` stochastic actions are drawn in a single batched forward
    pass. They are mapped back to the raw dose scale and clipped at 0, then
    summarized.

    Args:
        agent: trained CQL agent exposing ``policy.sample``, ``q1`` and ``q2``.
        patient_state_raw: one unscaled state vector ordered like STATE_COLS.
        state_scaler: fitted scaler applied to the state via ``transform``.
        action_scaler: fitted scaler whose ``inverse_transform`` maps the
            squashed action back to the raw dose scale.
        n_samples: number of policy samples to draw (must be >= 1).
        device: torch device the agent lives on.

    Returns:
        dict with keys:
            - 'action_mean': mean sampled dose
            - 'action_median': median sampled dose
            - 'action_std': standard deviation of sampled doses
            - 'action_range': (2.5th, 97.5th) percentiles of the sampled doses,
              i.e. the central 95% of the policy's samples. This describes
              the policy's spread; it is not a statistical confidence interval.
            - 'q_value_mean': mean twin-critic Q value over the samples
            - 'uncertainty': action_std / (action_mean + 1e-6), the
              coefficient of variation

    Raises:
        ValueError: if ``n_samples`` < 1.
    """
    if n_samples < 1:
        raise ValueError(f"n_samples must be >= 1, got {n_samples}")

    agent.eval()

    patient_state = state_scaler.transform([patient_state_raw])
    state_tensor = torch.as_tensor(patient_state, dtype=torch.float32, device=device)
    # One identical row per sample, so every draw happens in one forward pass.
    state_batch = state_tensor.repeat(n_samples, 1)

    with torch.no_grad():
        policy_actions, _, _ = agent.policy.sample(state_batch)
        q_values = (agent.q1(state_batch, policy_actions)
                    + agent.q2(state_batch, policy_actions)) / 2

    # Only the first action dimension (the single dose) is used, as in predict_action.
    actions_norm = policy_actions[:, :1].cpu().numpy()
    actions_raw = action_scaler.inverse_transform(actions_norm)[:, 0]
    # A dose cannot be negative.
    actions_array = np.maximum(actions_raw, 0.0).astype(np.float64)
    q_values_array = q_values[:, 0].cpu().numpy()

    mean_action = np.mean(actions_array)
    std_action = np.std(actions_array)
    median_action = np.median(actions_array)
    q_mean = np.mean(q_values_array)

    # Central 95% of the sampled doses.
    action_lower = np.percentile(actions_array, 2.5)
    action_upper = np.percentile(actions_array, 97.5)

    return {
        'action_mean': mean_action,
        'action_median': median_action,
        'action_std': std_action,
        'action_range': (action_lower, action_upper),
        'q_value_mean': q_mean,
        'uncertainty': std_action / (mean_action + 1e-6),
    }


# ========== Usage example ==========
if __name__ == "__main__":
    print("=" * 80)
    print("CQL模型推理 - 根据患者状态预测最优策略")
    print("=" * 80)

    # Number of policy samples for the uncertainty estimate. Used both in the
    # call and in the printed heading so the two cannot drift apart.
    n_samples = 100

    # 1. Load the model
    agent, state_scaler, action_scaler, config = load_model('cql_final_model.pt')

    # 2. Raw (unscaled) patient state; entries must follow STATE_COLS order.
    patient_state = np.array([
        12.0,    # vanco_level(ug/mL) - vancomycin concentration
        1.2,     # creatinine(mg/dL) - creatinine
        8.0,     # wbc(K/uL) - white blood cell count
        20.0,    # bun(mg/dL) - blood urea nitrogen
        37.5,    # temperature - body temperature
        120,     # sbp - systolic blood pressure
        85,      # heart_rate - heart rate
    ])
    if patient_state.shape != (len(STATE_COLS),):
        raise ValueError(
            f"patient_state must have {len(STATE_COLS)} entries (one per STATE_COLS), "
            f"got shape {patient_state.shape}"
        )

    print("\n" + "=" * 80)
    print("患者状态（原始值）:")
    print("=" * 80)
    for col, value in zip(STATE_COLS, patient_state):
        print(f"  {col}: {value}")

    # 3. Deterministic prediction (policy mean)
    print("\n" + "=" * 80)
    print("策略预测结果（确定性）:")
    print("=" * 80)
    result = predict_action(agent, patient_state, state_scaler, action_scaler, DEVICE)

    print(f"推荐药物剂量: {result['action']:.1f} mg")
    print(f"Q值估计: {result['q_value']:.4f}")
    print(f"策略置信度: {result['confidence']:.4f}")
    print(f"策略标准差: {result['action_std_norm']:.4f}")

    # 4. Sampled prediction (policy uncertainty)
    print("\n" + "=" * 80)
    print(f"策略预测结果（考虑不确定性，采样{n_samples}次）:")
    print("=" * 80)
    result_uncertainty = predict_action_with_uncertainty(
        agent, patient_state, state_scaler, action_scaler,
        n_samples=n_samples, device=DEVICE
    )

    print(f"平均推荐剂量: {result_uncertainty['action_mean']:.1f} mg")
    print(f"中位数推荐剂量: {result_uncertainty['action_median']:.1f} mg")
    print(f"标准差: {result_uncertainty['action_std']:.1f} mg")
    print(f"95%置信区间: [{result_uncertainty['action_range'][0]:.1f}, {result_uncertainty['action_range'][1]:.1f}] mg")
    print(f"相对不确定性: {result_uncertainty['uncertainty']*100:.2f}%")
    print(f"平均Q值: {result_uncertainty['q_value_mean']:.4f}")

    print("\n" + "=" * 80)
    print("使用说明:")
    print("=" * 80)
    print("1. 确定性预测：使用策略均值，给出单一推荐剂量")
    print("2. 不确定性预测：多次采样，考虑策略的不确定性")
    print("3. Q值：评估动作质量，值越大表示动作越好")
    print("4. 置信区间：95%的采样动作落在此范围内")
    print("=" * 80)



CQL模型推理 - 根据患者状态预测最优策略
加载模型: cql_final_model.pt
✅ 模型加载成功
   状态维度: 7
   动作维度: 1
   设备: cpu

患者状态（原始值）:
  vanco_level(ug/mL): 12.0
  creatinine(mg/dL): 1.2
  wbc(K/uL): 8.0
  bun(mg/dL): 20.0
  temperature: 37.5
  sbp: 120.0
  heart_rate: 85.0

策略预测结果（确定性）:
推荐药物剂量: 69.6 mg
Q值估计: 62.0689
策略置信度: 0.5144
策略标准差: 0.9440

策略预测结果（考虑不确定性，采样100次）:




平均推荐剂量: 88.1 mg
中位数推荐剂量: 44.8 mg
标准差: 94.2 mg
95%置信区间: [0.0, 264.9] mg
相对不确定性: 106.96%
平均Q值: 62.1276

使用说明:
1. 确定性预测：使用策略均值，给出单一推荐剂量
2. 不确定性预测：多次采样，考虑策略的不确定性
3. Q值：评估动作质量，值越大表示动作越好
4. 置信区间：95%的采样动作落在此范围内
