In [None]:
# 贝叶斯FINAL

In [None]:
import numpy as np 
import math
import logging
from scipy.stats import norm
from scipy.optimize import brentq, newton

# Logging configuration (INFO level).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# -------------------------------
# Tunable global parameters
# -------------------------------
TITRATED_VOLUME = 11.0    # volume of the solution being titrated (mL)
ANALYTE_CONC = 0.1        # acid concentration of the titrated solution (mol/L)

# Two user-configurable strengths each of hydrochloric acid and sodium hydroxide (mol/L).
HCL_CONC1, HCL_CONC2 = 0.1, 0.01
NAOH_CONC1, NAOH_CONC2 = 0.1, 0.01
TARGET_PH = 11            # target pH
MAX_STEPS = 50            # maximum number of titration steps

# -------------------------------
# Reagent name -> concentration (mol/L). The "_1" / "_2" suffix selects the
# primary or the secondary strength of each reagent.
# -------------------------------
REAGENTS = dict(
    dilute_acid_1=HCL_CONC1,
    dilute_acid_2=HCL_CONC2,
    dilute_base_1=NAOH_CONC1,
    dilute_base_2=NAOH_CONC2,
)

# -------------------------------
# pH计算函数（基于多个缓冲对）
# -------------------------------

def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total negative charge (mol/L) carried by the deprotonated acid species.

    The values in ``pKa_list`` are used as the *stepwise* dissociation
    constants K_1, K_2, ... of a single polyprotic acid H_nA with total
    concentration ``c_A``.  With beta_0 = 1 and beta_k = beta_{k-1} * K_k / [H+],
    the species that has lost k protons has concentration
    ``c_A * beta_k / sum(beta)`` and carries charge k.

    Args:
        c_A: Total analytical concentration of the acid (mol/L).
        H: Hydrogen-ion concentration [H+] (mol/L).
        pKa_list: Stepwise pKa values (any iterable of numbers, list or ndarray).

    Returns:
        ``sum_k k * [A^(k-)]`` in mol/L.  0.0 for an empty ``pKa_list``; also
        0.0 when ``H`` is 0, where every ratio K/[H+] is treated as vanishing.
    """
    if H == 0:
        return 0.0
    betas = []
    beta = 1.0
    for pKa in pKa_list:
        # Clip the exponent so 10**x cannot overflow for absurd pKa values.
        K = 10.0 ** min(max(-float(pKa), -100.0), 100.0)
        # Build beta_k incrementally instead of dividing by H**k, which
        # underflows to 0 (and yields nan) for long pKa lists at high pH.
        beta *= K / H
        betas.append(beta)
    fully_protonated = c_A / (1.0 + sum(betas))
    return sum((k * fully_protonated * b for k, b in enumerate(betas, start=1)), 0.0)


def f(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
    """Return the charge-balance residual of the solution at ``pH``.

    residual = [H+] + [Na+] - [OH-] - (acid anion charge) - [Cl-], with
    Kw = 1e-14.  :func:`solve_pH` searches for the pH where it is zero.

    Args:
        pH: Trial pH.
        c_A: Total acid (analyte) concentration (mol/L).
        c_Na: Sodium concentration from added NaOH (mol/L).
        c_HCl: Chloride concentration from added HCl (mol/L).
        pKa_list: Stepwise pKa values of the analyte.
    """
    H = 10 ** (-pH)
    Kw = 1e-14
    OH = Kw / H
    acid_anion_charge = calculate_acid_anion_charge(c_A, H, pKa_list)
    return H + c_Na - OH - acid_anion_charge - c_HCl


def solve_pH(c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
    """Solve the charge balance for pH by bisection on [0, 14].

    Args:
        c_A: Total acid (analyte) concentration (mol/L).
        c_Na: Sodium concentration from added NaOH (mol/L).
        c_HCl: Chloride concentration from added HCl (mol/L).
        pKa_list: Stepwise pKa values of the analyte.

    Returns:
        The pH where ``abs(f) < 1e-10`` or the bracket has shrunk to 1e-12.
        0.0 when the residual is already non-positive at pH 0 (the bracket
        holds no sign change on the acidic side); when the residual stays
        positive up to pH 14 the result converges to 14.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so evaluate it once and carry it along.
    f_lo = f(lo, c_A, c_Na, c_HCl, pKa_list)
    if f_lo <= 0:
        return lo
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f(mid, c_A, c_Na, c_HCl, pKa_list)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
        if hi - lo <= 1e-12:
            # Further halving cannot meaningfully change the result.
            break
    return (lo + hi) / 2.0

# -------------------------------
# pH 调整环境类
# -------------------------------
class PHAdjustmentEnv:
    """pH-adjustment environment with a particle-filter estimate of the buffer system.

    It keeps a ledger of the acid and base added to ``TITRATED_VOLUME`` mL of
    analyte (acid at ``ANALYTE_CONC`` mol/L), predicts the pH from a charge
    balance, recommends the next ``(reagent, volume)`` action and, after each
    measured pH, updates Gaussian priors over the buffer pKa values and amounts.

    Call :meth:`initialize` before :meth:`step` or :meth:`select_best_action`.
    """

    def __init__(self):
        self.steps_taken = 0
        self.done = False
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0
        self.reagents = REAGENTS.copy()

        # Smallest addition (mL) and the discrete addition volumes
        # 0.01 * i for i in 1..999, paired with every reagent.
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys()
                             for volume in self.addition_volumes]

        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4

        # Uncertain buffer parameters: random initial pKa values and amounts.
        # Draw order (pKa first, then amounts) on the global NumPy RNG is kept.
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        # The initially sampled pKa values serve as the reference.
        self.ref_pKa = np.copy(self.pKa_list)
        # Per-buffer posterior pKa std, read by get_effective_pka_array.
        # NOTE(review): starts at 0.2 while the pKa prior below uses scale 0.5
        # (an older comment said 0.5) -- confirm which is intended.
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)

        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None

        # Priors: pKa ~ N(initial pKa, 0.5) and total_moles ~ N(initial amount, 0.005).
        self.priors = []
        for i in range(self.num_buffers):
            prior = {
                'pKa': norm(loc=self.pKa_list[i], scale=0.5),
                'total_moles': norm(loc=self.buffer_total_moles[i], scale=0.005)
            }
            self.priors.append(prior)

        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5

        self.last_measured_ph = None
        self.prev_measured_ph = None

        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def initialize(self, init_pH: float, target_pH: float, max_steps: int,
                   initial_volume: 'float | None' = None) -> None:
        """Reset the episode state for a new titration.

        Args:
            init_pH: Measured starting pH.
            target_pH: pH to reach.
            max_steps: Step budget after which the episode ends.
            initial_volume: Starting volume in mL; ``None`` (default) means the
                current value of ``TITRATED_VOLUME``, read at call time like
                the rest of the class does.
        """
        if initial_volume is None:
            initial_volume = TITRATED_VOLUME
        self.initial_ph = init_pH
        self.current_ph = init_pH
        self.previous_ph = init_pH
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False
        self.total_volume = initial_volume
        self.previous_total_volume = initial_volume
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_measured_ph = init_pH
        self.prev_measured_ph = init_pH
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return 10**x with the exponent clipped to [-100, 100]."""
        return np.power(10, np.clip(x, -100, 100))

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading, shifting the previous one into ``prev_measured_ph``."""
        if self.last_measured_ph is not None:
            self.prev_measured_ph = self.last_measured_ph
        else:
            self.prev_measured_ph = pH
        self.current_ph = pH
        self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Return the ``num_buffers`` pKa values used for pH prediction.

        Each entry is ``ref + w * (mean - ref)`` with
        ``w = 0.2 * (1 - tanh(std))``: the posterior mean moves the reference
        pKa by at most 20 %, and less the larger the posterior std.
        """
        weight_max = 0.2
        k = 1.0
        n = self.num_buffers
        ref = np.asarray(self.ref_pKa[:n], dtype=float)
        mean = np.asarray(self.pKa_list[:n], dtype=float)
        weights = weight_max * (1 - np.tanh(k * np.asarray(self.pKa_std[:n], dtype=float)))
        return ref + weights * (mean - ref)

    @staticmethod
    def solve_pH(c_A: float, c_Na: float, c_HCl: float, pKa_list) -> float:
        """Solve the charge balance for pH by bisection on [0, 14].

        Kept on the class so :meth:`recalc_ph` and
        :meth:`compute_required_volume` do not depend on the signature of any
        module-level solver; it only uses ``calculate_acid_anion_charge``.

        Returns:
            The pH where the residual is below 1e-10 or the bracket has shrunk
            to 1e-12; 0.0 when the residual is already non-positive at pH 0.
        """
        def residual(pH):
            H = 10 ** (-pH)
            OH = 1e-14 / H
            return H + c_Na - OH - calculate_acid_anion_charge(c_A, H, pKa_list) - c_HCl

        lo, hi = 0.0, 14.0
        f_lo = residual(lo)
        if f_lo <= 0:
            return lo
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = residual(mid)
            if abs(f_mid) < 1e-10:
                return mid
            if f_lo * f_mid < 0:
                hi = mid
            else:
                lo, f_lo = mid, f_mid
            if hi - lo <= 1e-12:
                break
        return (lo + hi) / 2.0

    def _next_reagent(self, reference_ph: float) -> str:
        """Return the reagent that moves ``reference_ph`` toward the target.

        Base below the target, acid otherwise; the ``*_2`` strength once
        ``use_secondary_reagents`` is set.
        """
        secondary = self.use_secondary_reagents
        if reference_ph < self.target_ph:
            return 'dilute_base_2' if secondary else 'dilute_base_1'
        return 'dilute_acid_2' if secondary else 'dilute_acid_1'

    def compute_required_volume(self) -> float:
        """Estimate the volume (mL) of the next reagent needed to reach the target pH.

        The reagent is chosen by :meth:`_next_reagent` from ``current_ph``.
        The model pH uses :meth:`solve_pH` with the effective pKa values, and
        ``brentq`` searches [0, 10] mL for the volume that hits ``target_ph``.

        Returns:
            The volume in mL.  Returns 0.0 when the model pH does not cross the
            target within [0, 10] mL (``brentq`` raises ``ValueError``).  Also
            returns 0.0 when the root search fails to converge (``RuntimeError``).
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()
        reagent = self._next_reagent(self.current_ph)
        adding_base = 'base' in reagent
        conc = self.reagents[reagent]

        def f_vol(x):
            add_moles = conc * (x / 1000.0)
            base_moles = (self.base_added_moles + add_moles) if adding_base else self.base_added_moles
            acid_moles = self.acid_added_moles if adding_base else (self.acid_added_moles + add_moles)
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            pH_new = self.solve_pH(n_analyte / new_total_volume,
                                   base_moles / new_total_volume,
                                   acid_moles / new_total_volume,
                                   effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, 10)
        except (ValueError, RuntimeError):
            return 0.0

    def _ideal_volume(self, reference_ph: float) -> float:
        """Heuristic addition volume (mL) for the situation at ``reference_ph``.

        The pH error and 0.1 x the model-required volume are squashed with
        tanh into [min_addition_volume, max(addition_volumes)].  The tanh
        slope is scaled up when the pH recently changed little, and scaled
        slightly down by the mean prior pKa std.  It is also adjusted (+-5 %)
        by the mean buffer amount.
        """
        error = abs(reference_ph - self.target_ph)
        previous = self.prev_measured_ph if self.prev_measured_ph is not None else reference_ph
        ph_change = abs(reference_ph - previous)
        bonus_factor = 1 + self.ph_rate_bonus_factor * (
            1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        avg_uncertainty = np.mean([prior['pKa'].std() for prior in self.priors])
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = np.clip(1.0 + 0.1 * (buffer_mean - ref_buffer), 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        combined_value = error + 0.1 * self.compute_required_volume()
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        return min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

    @staticmethod
    def _prompt_ph() -> float:
        """Ask the operator for the measured pH until a number is entered."""
        while True:
            user_input = input("请输入当前测得的 pH 值: ")
            try:
                return float(user_input)
            except ValueError:
                print("输入格式不正确，请输入一个数字（例如 7.0）。")

    def step(self, action: tuple, mode: str = 'simulate') -> tuple:
        """Apply one reagent addition and return ``(pH, reward, done, info)``.

        Args:
            action: ``(reagent_name, volume_mL)``; the name must be a key of
                ``self.reagents``.
            mode: ``'simulate'`` recomputes the pH with :meth:`recalc_ph`;
                ``'manual'`` prompts the operator for the measured pH.  Any
                other value leaves the pH unchanged.

        Returns:
            ``(current_ph, reward, done, {})``.  The episode ends with reward
            -100 (and ``self.done`` set) in three cases: base added while the
            last pH is above target, or acid added while below; a pH outside
            [0, 14] or nan; or any exception while processing the step (logged).
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            name = reagent.lower()
            has_acid = 'acid' in name
            has_base = 'base' in name
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume

            # The addition is booked before the direction check; in manual mode
            # it has presumably already been performed by the operator.
            if has_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif has_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles

            ph_before = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            if (ph_before > self.target_ph and has_base) or (ph_before < self.target_ph and has_acid):
                # Wrong-direction reagent: end the episode and remember it, so
                # the returned done flag and self.done agree.
                self.done = True
                return self.current_ph, -100, True, {}

            if mode == 'simulate':
                self.update_exp_ph(self.recalc_ph())
            elif mode == 'manual':
                self.update_exp_ph(self._prompt_ph())

            # Oscillation around the target at the minimum addition volume:
            # after 3 occurrences switch to the secondary (``*_2``) reagents.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                crossed = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
                if crossed and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("检测到在最小滴加量下的pH振荡，累计次数：%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("达到连续震荡阈值，切换到次级试剂滴定。")

            self.steps_taken += 1

            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            ideal_volume = self._ideal_volume(self.current_ph)

            current_error = abs(self.current_ph - self.target_ph)
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement - current_error - action_cost - time_penalty

            # Penalise a reagent that, judged from the new reading, pushes the
            # wrong way (i.e. the addition overshot the target).
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            ph_after = self.last_measured_ph if self.last_measured_ph is not None else ph_before
            if self.target_ph > ph_after and has_acid:
                reward -= dynamic_direction_penalty * (self.target_ph - ph_after) / max(self.target_ph, 1)
            if self.target_ph < ph_after and has_base:
                reward -= dynamic_direction_penalty * (ph_after - self.target_ph) / max((14 - self.target_ph), 1)

            if has_acid:
                reagent_conc, last_added = self.reagents[reagent], self.last_acid_added
            elif has_base:
                reagent_conc, last_added = self.reagents[reagent], self.last_base_added
            else:
                reagent_conc, last_added = 1.0, 0.0

            overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                               self.target_ph, reagent,
                                                               last_added, reagent_conc,
                                                               self.min_addition_volume)
            if overshoot_flag:
                self.overshoot_occurred = True
                self.overshoot_reagent = reagent
                # The per-addition volume cap only ever tightens.
                if new_thresh is not None and (self.overshoot_threshold is None
                                               or new_thresh < self.overshoot_threshold):
                    self.overshoot_threshold = new_thresh

            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True

            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.exception("执行 step 时出现异常：%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot and propose a tighter volume cap.

        An overshoot is a crossing of ``target_ph`` or a growing error.  The
        proposed cap is half of the volume just added
        (``last_added_moles * 1000 / reagent_conc`` mL), at least ``min_addition``.
        ``reagent`` is accepted but not used.

        Returns:
            ``(overshoot, new_threshold)``; ``new_threshold`` is ``None`` when
            there is no overshoot.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of this environment's full state.

        Every attribute is carried over (including overshoot state,
        ``initial_ph`` and ``last_*_added``).  Arrays, the prior dicts and the
        lists are duplicated so the copy can be modified freely.  ``__init__``
        is not run, so no random numbers are drawn and the action space is not
        rebuilt -- this is called once per particle in :meth:`update_posteriors`.
        """
        import copy

        env_copied = copy.copy(self)
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        # Copy each prior dict too: update_posteriors replaces entries in place.
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.reagents = self.reagents.copy()
        env_copied.addition_volumes = list(self.addition_volumes)
        env_copied.action_space = list(self.action_space)
        return env_copied

    def recalc_ph(self) -> float:
        """Predict the current pH from the reagent ledger and the effective pKa values."""
        V_total = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        c_A = n_analyte / V_total
        c_Na = self.base_added_moles / V_total
        c_HCl = self.acid_added_moles / V_total
        pKa_list = self.get_effective_pka_array().tolist()
        return self.solve_pH(c_A, c_Na, c_HCl, pKa_list)

    def select_best_action(self) -> tuple:
        """Recommend the next ``(reagent, volume)`` action.

        The reagent always moves the last reading toward the target (see
        :meth:`_next_reagent`), so :meth:`step` never rejects it.  Any pending
        overshoot notice is consumed here.  Volumes are capped by
        ``overshoot_threshold`` when that leaves any choice.  The volume
        closest to :meth:`_ideal_volume` is picked.

        Returns:
            ``(best_action, self.done)``.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
        reagent = self._next_reagent(current_for_direction)
        # The overshoot notice only matters for the cap below; reset it here
        # (previously it was never cleared in secondary-reagent mode).
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        candidate_actions = [a for a in self.action_space if a[0] == reagent]
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped

        ideal_volume = self._ideal_volume(current_for_direction)
        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one pKa and one total-moles value per buffer from the current priors."""
        sampled_pKa = []
        sampled_total_moles = []
        for prior in self.priors:
            sampled_pKa.append(prior['pKa'].rvs())
            sampled_total_moles.append(prior['total_moles'].rvs())
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH of the current state under sampled buffer parameters.

        A copy of the environment gets the sampled ``pKa_list`` and
        ``buffer_total_moles`` and its :meth:`recalc_ph` is returned.  The
        action is assumed to be already applied (this is called after
        :meth:`step`), so ``action`` itself is not used.  ``recalc_ph`` does not
        read ``buffer_total_moles``, so only the sampled pKa values (through
        :meth:`get_effective_pka_array`) affect the result.
        """
        env_copy = self.env_copy()
        env_copy.pKa_list = np.array(sampled_pKa)
        env_copy.buffer_total_moles = np.array(sampled_total_moles)
        return env_copy.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float,
                          num_particles: int = 1000, obs_noise: float = 0.01) -> None:
        """Refine the buffer priors from one observed pH with a particle filter.

        1. Sample ``num_particles`` parameter sets from the current priors.
        2. Predict the pH for each with :meth:`predict_ph`.
        3. Weight each particle by the Gaussian likelihood of ``observed_ph``
           (std ``obs_noise``).
        4. Resample by weight.
        5. Fit a normal to the resampled values of every buffer (std + 1e-3).
           Use it as the new prior, and copy mean/std into ``pKa_list``,
           ``buffer_total_moles`` and ``pKa_std``.

        Weights are computed in log space relative to the best particle, so
        the most likely particles still dominate when every raw likelihood
        underflows to 0.  If no likelihood is finite, all particles are
        weighted equally.

        Raises:
            ValueError: if ``num_particles`` < 1.
        """
        if num_particles < 1:
            raise ValueError("num_particles must be >= 1")
        pKa_particles = np.column_stack(
            [prior['pKa'].rvs(size=num_particles) for prior in self.priors])
        moles_particles = np.column_stack(
            [prior['total_moles'].rvs(size=num_particles) for prior in self.priors])
        predicted = np.array([
            self.predict_ph(action, pKa_particles[j], moles_particles[j])
            for j in range(num_particles)
        ])

        log_w = norm.logpdf(observed_ph, loc=predicted, scale=obs_noise)
        finite = np.isfinite(log_w)
        if finite.any():
            weights = np.where(finite, np.exp(log_w - log_w[finite].max()), 0.0)
        else:
            weights = np.ones(num_particles)
        weights = weights / weights.sum()

        indices = np.random.choice(num_particles, size=num_particles, p=weights)
        pKa_res = pKa_particles[indices]
        moles_res = moles_particles[indices]
        pKa_mean = pKa_res.mean(axis=0)
        pKa_sd = pKa_res.std(axis=0) + 1e-3
        moles_mean = moles_res.mean(axis=0)
        moles_sd = moles_res.std(axis=0) + 1e-3

        for i in range(self.num_buffers):
            self.priors[i]['pKa'] = norm(loc=pKa_mean[i], scale=pKa_sd[i])
            self.priors[i]['total_moles'] = norm(loc=moles_mean[i], scale=moles_sd[i])
            self.pKa_list[i] = pKa_mean[i]
            self.buffer_total_moles[i] = moles_mean[i]
            self.pKa_std[i] = pKa_sd[i]

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Record the operator's result of ``action`` and recommend the next one.

        Returns ``(None, True)`` without prompting when ``observed_ph`` is
        already within 0.1 of the target.  Otherwise it performs a manual
        :meth:`step`, updates the posteriors with the entered pH and returns
        ``(next_action, done)``.
        """
        if abs(observed_ph - self.target_ph) < 0.1:
            self.done = True
            return None, True
        new_ph, _reward, done, _ = self.step(action, mode='manual')
        self.update_posteriors(action, new_ph)
        next_action, _ = self.select_best_action()
        return next_action, done

# -------------------------------
# 主程序
# -------------------------------
def main():
    """Run the interactive titration loop toward ``TARGET_PH``.

    Each round prints the recommended action.  It then asks the operator for
    the measured pH (through ``PHAdjustmentEnv.suggest_next_action``).  It
    stops once the pH is within 0.1 of the target or the environment
    reports done, then prints a summary.
    """
    # Re-apply the configured reagent strengths to the shared dictionary.
    REAGENTS.update(
        dilute_acid_1=HCL_CONC1,
        dilute_acid_2=HCL_CONC2,
        dilute_base_1=NAOH_CONC1,
        dilute_base_2=NAOH_CONC2,
    )

    # Starting pH, set by hand.
    initial_ph = 6.9
    logging.info("初始 pH = %.2f", initial_ph)

    env = PHAdjustmentEnv()
    env.initialize(init_pH=initial_ph, target_pH=TARGET_PH, max_steps=MAX_STEPS,
                   initial_volume=TITRATED_VOLUME)

    measured_ph = env.current_ph
    action, done = env.select_best_action()

    while not (done or abs(measured_ph - env.target_ph) < 0.1):
        if env.overshoot_threshold is None:
            limit_note = ""
        else:
            limit_note = "（过冲限制：最大滴加体积为 {:.2f} mL）".format(env.overshoot_threshold)
        print("当前 pH = {:.2f}，推荐操作: 加 {} {}".format(measured_ph, action, limit_note))
        action, done = env.suggest_next_action(action, measured_ph)
        measured_ph = env.current_ph

    print("达到目标 pH 或超过最大步数，实验结束。")
    summary = "总加酸量：{:.2f} mL, 总加碱量：{:.2f} mL, 总步骤数：{}，最终 pH = {:.2f}"
    print(summary.format(env.acid_volume, env.base_volume, env.steps_taken, measured_ph))


if __name__ == '__main__':
    main()


In [None]:
# 生成数据

In [None]:
import numpy as np
import math
import logging
import json
import random
from scipy.stats import norm
from scipy.optimize import brentq

# Logging configuration (INFO level). basicConfig does nothing if the root
# logger already has handlers, e.g. when an earlier cell configured it.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Seed both random generators so experiments are reproducible.
np.random.seed(42)
random.seed(42)

# -------------------------------
# Tunable global parameters (intended to match the first cell).
# NOTE(review): TITRATED_VOLUME is 10.0 here but 11.0 in the first cell, and
# TARGET_PH is not defined in this cell -- confirm the intended values.
# -------------------------------
TITRATED_VOLUME = 10.0    # volume of the solution being titrated (mL)
ANALYTE_CONC = 0.1        # acid concentration of the titrated solution (mol/L)
HCL_CONC1, HCL_CONC2 = 0.1, 0.01       # HCl strengths 1 and 2 (mol/L)
NAOH_CONC1, NAOH_CONC2 = 0.1, 0.01     # NaOH strengths 1 and 2 (mol/L)
MAX_STEPS = 50            # maximum number of titration steps

# Reagent name -> concentration (mol/L).
REAGENTS = dict(
    dilute_acid_1=HCL_CONC1,
    dilute_acid_2=HCL_CONC2,
    dilute_base_1=NAOH_CONC1,
    dilute_base_2=NAOH_CONC2,
)

# Reagent name -> discrete action index, in the explicit order below.
reagent_mapping = {
    reagent_name: index
    for index, reagent_name in enumerate(
        ('dilute_acid_1', 'dilute_acid_2', 'dilute_base_1', 'dilute_base_2'))
}

# -------------------------------
# pH计算函数（与第一段代码一致）
# -------------------------------
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total negative charge (mol/L) carried by the deprotonated acid species.

    The values in ``pKa_list`` are used as the *stepwise* dissociation
    constants K_1, K_2, ... of a single polyprotic acid H_nA with total
    concentration ``c_A``.  With beta_0 = 1 and beta_k = beta_{k-1} * K_k / [H+],
    the species that has lost k protons has concentration
    ``c_A * beta_k / sum(beta)`` and carries charge k.

    Args:
        c_A: Total analytical concentration of the acid (mol/L).
        H: Hydrogen-ion concentration [H+] (mol/L).
        pKa_list: Stepwise pKa values (any iterable of numbers, list or ndarray).

    Returns:
        ``sum_k k * [A^(k-)]`` in mol/L.  0.0 for an empty ``pKa_list``; also
        0.0 when ``H`` is 0, where every ratio K/[H+] is treated as vanishing.
    """
    if H == 0:
        return 0.0
    betas = []
    beta = 1.0
    for pKa in pKa_list:
        # Clip the exponent so 10**x cannot overflow for absurd pKa values.
        K = 10.0 ** min(max(-float(pKa), -100.0), 100.0)
        # Build beta_k incrementally instead of dividing by H**k, which
        # underflows to 0 (and yields nan) for long pKa lists at high pH.
        beta *= K / H
        betas.append(beta)
    fully_protonated = c_A / (1.0 + sum(betas))
    return sum((k * fully_protonated * b for k, b in enumerate(betas, start=1)), 0.0)


def f(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_eff_array: np.ndarray) -> float:
    """Return the charge-balance residual of the solution at ``pH``.

    residual = [H+] + [Na+] - [OH-] - (acid anion charge) - [Cl-], with
    Kw = 1e-14.  :func:`solve_pH` searches for the pH where it is zero.

    Args:
        pH: Trial pH.
        c_A: Total acid (analyte) concentration (mol/L).
        c_Na: Sodium concentration from added NaOH (mol/L).
        c_HCl: Chloride concentration from added HCl (mol/L).
        pKa_eff_array: Stepwise pKa values; an ndarray or any other sequence.
    """
    H = 10 ** (-pH)
    Kw = 1e-14
    OH = Kw / H
    # Passed through directly: lists work as well as arrays (no .tolist() needed).
    acid_anion_charge = calculate_acid_anion_charge(c_A, H, pKa_eff_array)
    return H + c_Na - OH - acid_anion_charge - c_HCl


def solve_pH(c_A: float, c_Na: float, c_HCl: float, pKa_eff_array: np.ndarray) -> float:
    """Solve the charge balance for pH by bisection on [0, 14].

    Args:
        c_A: Total acid (analyte) concentration (mol/L).
        c_Na: Sodium concentration from added NaOH (mol/L).
        c_HCl: Chloride concentration from added HCl (mol/L).
        pKa_eff_array: Stepwise pKa values of the analyte.

    Returns:
        The pH where ``abs(f) < 1e-10`` or the bracket has shrunk to 1e-12.
        0.0 when the residual is already non-positive at pH 0 (the bracket
        holds no sign change on the acidic side); when the residual stays
        positive up to pH 14 the result converges to 14.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so evaluate it once and carry it along.
    f_lo = f(lo, c_A, c_Na, c_HCl, pKa_eff_array)
    if f_lo <= 0:
        return lo
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f(mid, c_A, c_Na, c_HCl, pKa_eff_array)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
        if hi - lo <= 1e-12:
            # Further halving cannot meaningfully change the result.
            break
    return (lo + hi) / 2.0

def calculate_pH_custom(acid_total_moles: float, base_total_moles: float,
                        acid_volume: float, base_volume: float, pKa_list=None) -> float:
    """Return the model pH, rounded to 2 decimals, after the given additions.

    Args:
        acid_total_moles: HCl added so far (mol).
        base_total_moles: NaOH added so far (mol).
        acid_volume: Volume of acid added (mL).
        base_volume: Volume of base added (mL).
        pKa_list: Analyte pKa values; ``None`` means ``[4.21]``.
    """
    litres = (TITRATED_VOLUME + acid_volume + base_volume) / 1000.0
    analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
    concentrations = (analyte_moles / litres,
                      base_total_moles / litres,
                      acid_total_moles / litres)
    pKa_values = np.array([4.21] if pKa_list is None else pKa_list)
    return round(solve_pH(*concentrations, pKa_values), 2)

# -------------------------------
# pH 调整环境类（滴定方案与第一段代码一致）
# -------------------------------
class PHAdjustmentEnv:
    """Simulated acid/base titration environment with Bayesian pKa tracking.

    The vessel starts with ``TITRATED_VOLUME`` mL of analyte acid at
    ``ANALYTE_CONC`` mol/L. Actions are ``(reagent_name, volume_mL)`` pairs over
    ``REAGENTS`` (two HCl and two NaOH strengths). The pH after each addition is
    simulated with ``calculate_pH_custom`` / ``solve_pH`` using the effective pKa
    array from ``get_effective_pka_array``. Each buffer's pKa and total-moles
    estimate has a normal prior that ``update_posteriors`` refines by
    sampling-importance-resampling. Call ``initialize`` before ``step``.
    """

    def __init__(self):
        self.steps_taken = 0
        self.done = False
        # Solution volume in mL (starts at the analyte volume).
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        # Cumulative titrant moles and volumes (mL) added so far.
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        # Moles delivered by the most recent acid / base addition.
        self.last_acid_added = 0.0
        self.last_base_added = 0.0
        self.reagents = REAGENTS.copy()
        # Discrete action grid: every reagent x volumes 0.01, 0.02, ..., 9.99 mL.
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys()
                             for volume in self.addition_volumes]
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4
        # Placeholder buffer model; initialize() overwrites these fields.
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None
        self.priors = self._build_priors()
        # Reward-shaping knobs used by _ideal_volume().
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5
        self.last_measured_ph = None
        self.prev_measured_ph = None
        # Overshoot / oscillation bookkeeping used by select_best_action().
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.last_action = None
        # Set properly by initialize(); defined here so env_copy() works on a fresh env.
        self.acid_type = None

    def _build_priors(self) -> list:
        """Build one normal prior per buffer around the current pKa / total-moles estimates."""
        priors = []
        for i in range(self.num_buffers):
            prior = {
                'pKa': norm(loc=self.pKa_list[i], scale=0.5),
                'total_moles': norm(loc=self.buffer_total_moles[i], scale=0.005)
            }
            priors.append(prior)
        return priors

    def get_state(self):
        """Return a snapshot dict of the observable state.

        ``ph_diff`` and ``error`` are both ``current_ph - target_ph`` (rounded);
        ``ph_delta`` is the change between the last two measured pH values.
        """
        ph_diff = round(self.current_ph - self.target_ph, 2) if self.target_ph is not None else None
        last_added_volume = self.last_action[1] if self.last_action is not None else 0
        return {
            'pH': round(self.current_ph, 2),
            'target_ph': self.target_ph,
            'ph_diff': ph_diff,
            'acid_volume': self.acid_volume,
            'base_volume': self.base_volume,
            'total_volume': self.total_volume,
            'steps_taken': self.steps_taken,
            'error': round(self.current_ph - self.target_ph, 2) if self.target_ph is not None else None,
            'last_action': self.last_action,
            'ph_delta': round(self.last_measured_ph - self.prev_measured_ph, 2) if self.prev_measured_ph is not None else None,
            'last_added_volume': last_added_volume
        }

    def initialize(self, init_pH: float, target_pH: float, max_steps: int,
                   acid_type: str = 'mono', pKa_list=None, initial_volume: float = TITRATED_VOLUME) -> None:
        """Reset the environment for a new episode.

        When ``pKa_list`` is None, random pKa values are drawn according to
        ``acid_type`` ('mono', 'di', 'tri'; anything else gives [4.0]). pKa values
        are rounded to 2 decimals; buffer total-moles estimates are redrawn.
        """
        if pKa_list is None:
            if acid_type == 'mono':
                pKa_list = [np.random.uniform(1, 5)]
            elif acid_type == 'di':
                pKa_list = [np.random.uniform(1, 3), np.random.uniform(4, 6)]
            elif acid_type == 'tri':
                pKa_list = [np.random.uniform(1, 3), np.random.uniform(4, 5), np.random.uniform(6, 7)]
            else:
                pKa_list = [4.0]
        pKa_list = [round(val, 2) for val in pKa_list]
        self.num_buffers = len(pKa_list)
        self.pKa_list = np.array(pKa_list)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        self.priors = self._build_priors()
        self.acid_type = acid_type
        self.initial_ph = init_pH
        self.current_ph = init_pH
        self.previous_ph = init_pH
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False
        self.total_volume = initial_volume
        self.previous_total_volume = initial_volume
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_measured_ph = init_pH
        self.prev_measured_ph = init_pH
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return 10**x with the exponent clipped to [-100, 100]."""
        return np.power(10, np.clip(x, -100, 100))

    def update_exp_ph(self, pH: float) -> None:
        """Record a new (2-decimal rounded) pH measurement, shifting the previous one back."""
        pH = round(pH, 2)
        if self.last_measured_ph is not None:
            self.prev_measured_ph = self.last_measured_ph
        else:
            self.prev_measured_ph = pH
        self.current_ph = pH
        self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend reference and posterior-mean pKa values.

        pKa_eff = ref + w * (posterior_mean - ref) with w = 0.2 * (1 - tanh(std)),
        so a buffer's pKa moves at most 20% of the way toward the posterior mean,
        and less when the posterior is uncertain.
        """
        weight_max = 0.2
        k = 1.0
        pKa_eff_array = np.zeros(self.num_buffers)
        for i in range(self.num_buffers):
            weight_i = weight_max * (1 - np.tanh(k * self.pKa_std[i]))
            pKa_eff_array[i] = self.ref_pKa[i] + weight_i * (self.pKa_list[i] - self.ref_pKa[i])
        return pKa_eff_array

    def compute_required_volume(self) -> float:
        """Estimate the mL of titrant needed to reach ``target_ph`` from the current state.

        Uses base when below target and acid otherwise (the *_2 reagents once
        ``use_secondary_reagents`` is set), and solves pH(x) = target for x in
        [0, 10] mL with ``brentq`` under the effective-pKa model.

        Returns:
            The required volume in mL, or 0.0 when the target is not bracketed
            in [0, 10] mL (or brentq fails to converge).
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()
        adding_base = self.current_ph < self.target_ph
        if adding_base:
            reagent = 'dilute_base_2' if self.use_secondary_reagents else 'dilute_base_1'
        else:
            reagent = 'dilute_acid_2' if self.use_secondary_reagents else 'dilute_acid_1'
        conc = self.reagents[reagent]

        def f_vol(x):
            # pH error after adding x mL of the chosen reagent.
            add_moles = conc * (x / 1000.0)
            base_moles = self.base_added_moles + add_moles if adding_base else self.base_added_moles
            acid_moles = self.acid_added_moles if adding_base else self.acid_added_moles + add_moles
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            pH_new = solve_pH(n_analyte / new_total_volume,
                              base_moles / new_total_volume,
                              acid_moles / new_total_volume,
                              effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, 10)
        except (ValueError, RuntimeError):
            # ValueError: no sign change on [0, 10]; RuntimeError: no convergence.
            return 0.0

    def _ideal_volume(self, error: float, ph_change: float) -> float:
        """Heuristic "ideal" addition volume (mL) used for reward shaping and action choice.

        Grows from ``min_addition_volume`` toward the largest grid volume as
        tanh(alpha * (error + 0.1 * required_volume)), where alpha scales
        ``vol_ideal_factor`` by a pH-rate bonus (larger when ``ph_change`` is
        small), a prior-uncertainty factor and a clipped buffering factor.
        """
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['pKa'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
        buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        return min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

    def step(self, action: tuple, _: float = None) -> tuple:
        """Apply one titrant addition and return ``(pH, reward, done, info)``.

        ``action`` is ``(reagent_name, volume_mL)``; the second parameter is
        unused. The episode ends when |pH - target| < 0.1, ``max_steps`` is
        reached, or the pH leaves [0, 14]. Any exception raised inside the step
        is logged and ends the episode with reward -100.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if 'acid' in reagent.lower():
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif 'base' in reagent.lower():
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles
            # Wrong-direction check against the pH measured *before* this addition.
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            penalty = 0
            if current_for_direction > self.target_ph and 'base' in reagent.lower():
                penalty = -100
                logging.info("使用错误试剂（base），给予惩罚，但继续实验。")
            if current_for_direction < self.target_ph and 'acid' in reagent.lower():
                penalty = -100
                logging.info("使用错误试剂（acid），给予惩罚，但继续实验。")
            simulated_ph = calculate_pH_custom(self.acid_added_moles, self.base_added_moles,
                                               self.acid_volume, self.base_volume,
                                               pKa_list=self.get_effective_pka_array().tolist())
            self.update_exp_ph(simulated_ph)
            self.last_action = action
            # Count crossings of the target with a >0.1 pH jump at the minimum
            # volume; after three, switch to the *_2 reagents.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("检测到在最小滴加量下的pH震荡，累计次数：%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("达到连续震荡阈值，切换到次级试剂滴定。")
            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}
            current_error = abs(self.current_ph - self.target_ph)
            previous_error = abs(self.previous_ph - self.target_ph)
            ph_change = abs(self.current_ph - self.prev_measured_ph) if self.prev_measured_ph is not None else 0.0
            ideal_volume = self._ideal_volume(current_error, ph_change)
            # reward = improvement - |error| - quadratic deviation from ideal volume - time cost + wrong-reagent penalty
            error_reward = -current_error
            improvement = previous_error - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty + penalty
            reward = round(reward, 2)
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            # NOTE(review): unlike the check above, this penalty uses the pH measured
            # *after* the addition (last_measured_ph was just updated) — confirm intended.
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and 'acid' in reagent.lower():
                pen = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= pen
            if self.target_ph < current_for_direction and 'base' in reagent.lower():
                pen = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= pen
            reward = round(reward, 2)
            # Always true here (steps_taken was incremented above).
            if self.steps_taken > 0:
                if 'acid' in reagent.lower():
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_acid_added
                elif 'base' in reagent.lower():
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_base_added
                else:
                    reagent_conc = 1.0
                    last_added = 0.0
                overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                                   self.target_ph, reagent,
                                                                   last_added, reagent_conc,
                                                                   self.min_addition_volume)
                if overshoot_flag:
                    self.overshoot_occurred = True
                    self.overshoot_reagent = reagent
                    # Keep the tightest (smallest) volume cap seen so far.
                    if new_thresh is not None:
                        if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                            self.overshoot_threshold = new_thresh
            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.error("执行 step 时出现异常：%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot when the target was crossed or the pH error grew.

        Returns:
            ``(overshoot, new_threshold)`` where new_threshold is half the volume
            (mL) of the last addition, floored at ``min_addition``; None when no
            overshoot was detected.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of this environment's state.

        Arrays, lists and each prior dict are copied, so mutating the copy's
        priors (as ``update_posteriors`` does) leaves this env untouched.
        Constructing the copy runs ``__init__``, which draws from NumPy's
        global RNG.
        """
        env_copied = PHAdjustmentEnv()
        env_copied.total_volume = self.total_volume
        env_copied.previous_total_volume = self.previous_total_volume
        env_copied.acid_added_moles = self.acid_added_moles
        env_copied.base_added_moles = self.base_added_moles
        env_copied.acid_volume = self.acid_volume
        env_copied.base_volume = self.base_volume
        env_copied.last_acid_added = self.last_acid_added
        env_copied.last_base_added = self.last_base_added
        env_copied.initial_ph = self.initial_ph
        env_copied.current_ph = self.current_ph
        env_copied.previous_ph = self.previous_ph
        env_copied.target_ph = self.target_ph
        env_copied.steps_taken = self.steps_taken
        env_copied.done = self.done
        env_copied.num_buffers = self.num_buffers
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        # Copy each prior dict too; a plain list copy would share the dicts.
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.epsilon = self.epsilon
        env_copied.direction_penalty_factor = self.direction_penalty_factor
        env_copied.tol = self.tol
        env_copied.reagents = self.reagents.copy()
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.max_steps = self.max_steps
        env_copied.vol_ideal_factor = self.vol_ideal_factor
        env_copied.ph_rate_threshold = self.ph_rate_threshold
        env_copied.ph_rate_bonus_factor = self.ph_rate_bonus_factor
        env_copied.last_measured_ph = self.last_measured_ph
        env_copied.prev_measured_ph = self.prev_measured_ph
        env_copied.overshoot_threshold = self.overshoot_threshold
        env_copied.overshoot_occurred = self.overshoot_occurred
        env_copied.overshoot_reagent = self.overshoot_reagent
        env_copied.oscillation_count = self.oscillation_count
        env_copied.use_secondary_reagents = self.use_secondary_reagents
        env_copied.acid_type = self.acid_type
        env_copied.last_action = self.last_action
        return env_copied

    def select_best_action(self) -> tuple:
        """Heuristically choose the next ``(reagent, volume)`` action.

        Reagent: after an overshoot, the reagent opposite to the one that
        overshot; otherwise base when below target and acid when at/above it.
        The *_2 reagents are used once ``use_secondary_reagents`` is set.
        Volume: the grid volume closest to ``_ideal_volume``, restricted to
        volumes <= ``overshoot_threshold`` when such volumes exist.

        Returns:
            ``(action, self.done)``.
        """
        def filter_by_global_threshold(candidates):
            # Cap volumes at the overshoot threshold unless that would leave no candidates.
            if self.overshoot_threshold is not None:
                filtered = [a for a in candidates if a[1] <= self.overshoot_threshold]
                if filtered:
                    return filtered
            return candidates

        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph

        if self.use_secondary_reagents:
            # NOTE(review): unlike the primary branch, the overshoot flag is not cleared here — confirm intended.
            if self.overshoot_occurred:
                if 'base' in self.overshoot_reagent.lower():
                    allowed_reagent = [r for r in self.reagents.keys() if 'acid_2' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'base_2' in r.lower()]
            else:
                if current_for_direction < self.target_ph:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_base_2' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_acid_2' in r.lower()]
        else:
            if self.overshoot_occurred:
                if 'base' in self.overshoot_reagent.lower():
                    allowed_reagent = [r for r in self.reagents.keys() if 'acid_1' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'base_1' in r.lower()]
                self.overshoot_occurred = False
                self.overshoot_reagent = None
            else:
                if current_for_direction < self.target_ph:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_base_1' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_acid_1' in r.lower()]

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        candidate_actions = filter_by_global_threshold(candidate_actions)

        error = abs(current_for_direction - self.target_ph)
        ph_change = abs(current_for_direction - (self.prev_measured_ph if self.prev_measured_ph is not None else current_for_direction))
        ideal_volume = self._ideal_volume(error, ph_change)

        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one (pKa list, total-moles list) particle from the current priors."""
        sampled_pKa = []
        sampled_total_moles = []
        for prior in self.priors:
            sampled_pKa.append(prior['pKa'].rvs())
            sampled_total_moles.append(prior['total_moles'].rvs())
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the (unrounded) pH after applying ``action`` with the sampled pKa values.

        Does not modify this env. ``sampled_total_moles`` is accepted for
        interface compatibility but does not enter the pH model.
        """
        reagent, volume = action
        volume = float(volume)
        added_moles = self.reagents[reagent] * (volume / 1000.0)
        acid_moles, base_moles = self.acid_added_moles, self.base_added_moles
        acid_volume, base_volume = self.acid_volume, self.base_volume
        if 'acid' in reagent.lower():
            acid_moles += added_moles
            acid_volume += volume
        elif 'base' in reagent.lower():
            base_moles += added_moles
            base_volume += volume
        V_total = (TITRATED_VOLUME + acid_volume + base_volume) / 1000.0
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        c_A = n_analyte / V_total
        c_Na = base_moles / V_total
        c_HCl = acid_moles / V_total
        return solve_pH(c_A, c_Na, c_HCl, np.array(sampled_pKa))

    def update_posteriors(self, action: tuple, observed_ph: float, already_applied: bool = False) -> None:
        """Refine the pKa / total-moles priors from one pH observation.

        Draws 1000 particles from the priors, weights each by the N(predicted, 0.01)
        likelihood of ``observed_ph``, resamples, and refits each normal prior to the
        resampled mean and std (+1e-3 so the std never collapses to 0). The new
        means/stds are also written to ``pKa_list``, ``buffer_total_moles`` and
        ``pKa_std``.

        Args:
            action: the ``(reagent, volume)`` action that produced ``observed_ph``.
            observed_ph: the measured pH.
            already_applied: True when ``action`` has already been applied to this
                env (e.g. by ``step``); the prediction is then made for the current
                state instead of adding the action a second time.
        """
        # A zero-volume action leaves the state unchanged, so predict_ph then
        # predicts the pH of the current (post-action) state.
        predict_action = (action[0], 0.0) if already_applied else action
        num_particles = 1000
        particles = []
        weights = []
        for _ in range(num_particles):
            sampled_pKa, sampled_total_moles = self.sample_parameters()
            predicted_ph = self.predict_ph(predict_action, sampled_pKa, sampled_total_moles)
            likelihood = norm.pdf(observed_ph, loc=predicted_ph, scale=0.01)
            particles.append((sampled_pKa, sampled_total_moles))
            weights.append(likelihood)
        # Small floor keeps the weights normalisable if every likelihood underflows to 0.
        weights = np.array(weights) + 1e-10
        weights /= np.sum(weights)
        indices = np.random.choice(range(num_particles), size=num_particles, p=weights)
        new_pKa = []
        new_total_moles = []
        new_pKa_std = []
        for i in range(self.num_buffers):
            pKa_samples = np.array([particles[idx][0][i] for idx in indices])
            total_moles_samples = np.array([particles[idx][1][i] for idx in indices])
            mean_pKa = np.mean(pKa_samples)
            std_pKa = np.std(pKa_samples) + 1e-3
            mean_total_moles = np.mean(total_moles_samples)
            std_total_moles = np.std(total_moles_samples) + 1e-3
            new_pKa.append((mean_pKa, std_pKa))
            new_total_moles.append((mean_total_moles, std_total_moles))
            new_pKa_std.append(std_pKa)
        for i in range(self.num_buffers):
            self.priors[i]['pKa'] = norm(loc=new_pKa[i][0], scale=new_pKa[i][1])
            self.priors[i]['total_moles'] = norm(loc=new_total_moles[i][0], scale=new_total_moles[i][1])
            self.pKa_list[i] = new_pKa[i][0]
            self.buffer_total_moles[i] = new_total_moles[i][0]
            self.pKa_std[i] = new_pKa_std[i]

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Apply ``action``, update the posteriors, and return ``(next_action, done)``.

        Returns ``(None, True)`` without acting when ``observed_ph`` is already
        within 0.1 of the target.
        """
        if abs(observed_ph - self.target_ph) < 0.1:
            self.done = True
            return None, True
        new_ph, reward, done, _ = self.step(action)
        # step() has already applied the action; don't add it again when scoring particles.
        self.update_posteriors(action, new_ph, already_applied=True)
        next_action, _ = self.select_best_action()
        return next_action, done

# -------------------------------
# 单次实验生成函数
# -------------------------------
def generate_single_experiment(acid_type: str) -> dict:
    """Simulate one titration episode driven by ``PHAdjustmentEnv.select_best_action``.

    Draws random pKa values for ``acid_type`` ('mono', 'di', 'tri'; anything else
    gets a single pKa of 4.0) and a random target pH between 2 and 11, then steps
    the environment until it reports done.

    Returns:
        dict with 'acid_type', 'pKa_list', 'target_ph', 'initial_ph',
        'steps_taken', 'success' (within MAX_STEPS and final pH within 0.1 of
        target) and 'transitions' (state/action/reward/next_state/done dicts).
    """
    pka_samplers = {
        'mono': lambda: [np.random.uniform(1, 5)],
        'di': lambda: [np.random.uniform(1, 4), np.random.uniform(4, 7)],
        'tri': lambda: [np.random.uniform(1, 3), np.random.uniform(3, 5), np.random.uniform(5, 7)],
    }
    sampler = pka_samplers.get(acid_type)
    drawn = sampler() if sampler is not None else [4.0]
    pKa_list = [round(value, 2) for value in drawn]
    target_ph = round(np.random.uniform(2, 11), 2)
    init_ph = calculate_pH_custom(0, 0, 0, 0, pKa_list=pKa_list)

    env = PHAdjustmentEnv()
    env.initialize(init_pH=init_ph, target_pH=target_ph, max_steps=MAX_STEPS,
                   acid_type=acid_type, pKa_list=pKa_list, initial_volume=TITRATED_VOLUME)

    transitions = []
    state = env.get_state()
    action, done = env.select_best_action()
    while not env.done:
        _, reward, done, _ = env.step(action)
        next_state = env.get_state()
        transitions.append({
            'state': state,
            'action': action,
            'reward': reward,
            'next_state': next_state,
            'done': done,
        })
        state = next_state
        if done:
            break
        action, done = env.select_best_action()

    return {
        'acid_type': acid_type,
        'pKa_list': pKa_list,
        'target_ph': target_ph,
        'initial_ph': init_ph,
        'steps_taken': env.steps_taken,
        'success': (env.steps_taken <= MAX_STEPS and abs(env.current_ph - target_ph) < 0.1),
        'transitions': transitions,
    }

# -------------------------------
# 辅助函数：将状态和动作转换为数值向量
# -------------------------------
def convert_state(state: dict) -> list:
    """Flatten a ``PHAdjustmentEnv.get_state()`` dict into a 9-element feature list.

    Order: pH, target_ph, acid_volume, base_volume, total_volume, steps_taken,
    error, ph_delta, last_added_volume. Missing keys become 0; ``error`` and
    ``ph_delta`` also become 0 when stored as None.
    """
    plain_keys = ('pH', 'target_ph', 'acid_volume', 'base_volume', 'total_volume', 'steps_taken')
    features = [state.get(key, 0) for key in plain_keys]
    for nullable_key in ('error', 'ph_delta'):
        value = state.get(nullable_key)
        features.append(0 if value is None else value)
    features.append(state.get('last_added_volume', 0))
    return features

def convert_action(action: tuple, mapping: "dict | None" = None) -> list:
    """Encode a ``(reagent, volume)`` action as ``[reagent_idx, volume]``.

    Args:
        action: (reagent name, volume in mL) pair.
        mapping: reagent-name -> integer index table. When omitted, the
            module-level ``reagent_mapping`` is used (it must then be defined
            elsewhere in the notebook).

    Returns:
        ``[reagent_idx, volume]``; reagents absent from the mapping get -1.
    """
    reagent, volume = action
    if mapping is None:
        mapping = reagent_mapping
    return [mapping.get(reagent, -1), volume]

# -------------------------------
# 主函数：生成成功实验，并将转换后的 transition 数据聚合保存
# -------------------------------
def main():
    """Generate successful titration episodes and save shuffled train/val/test splits.

    Runs random experiments until 8 succeed, flattens their transitions with
    ``convert_state`` / ``convert_action``, shuffles them, and writes a
    70% / 15% / 15% split to three JSON files.
    """
    # Imported locally so main() does not depend on another cell having imported them.
    import json
    import random

    desired_success = 8
    successful_experiments = []
    acid_types = ['mono', 'di', 'tri']
    total_generated = 0
    while len(successful_experiments) < desired_success:
        acid_type = random.choice(acid_types)
        experiment = generate_single_experiment(acid_type)
        total_generated += 1
        if experiment['success']:
            successful_experiments.append(experiment)
        if total_generated % 100 == 0:
            logging.info("生成实验 %d 次，成功实验数量：%d", total_generated, len(successful_experiments))
    logging.info("成功实验生成完毕，总共生成实验 %d 次", total_generated)

    avg_steps = sum(exp['steps_taken'] for exp in successful_experiments) / len(successful_experiments)
    logging.info("成功实验的平均步数：%.2f", avg_steps)

    all_transitions = [trans for exp in successful_experiments for trans in exp['transitions']]
    total_samples = len(all_transitions)
    indices = list(range(total_samples))
    random.shuffle(indices)
    observations = [convert_state(all_transitions[i]['state']) for i in indices]
    actions = [convert_action(all_transitions[i]['action']) for i in indices]
    rewards = [all_transitions[i]['reward'] for i in indices]
    train_end = int(0.7 * total_samples)
    valid_end = train_end + int(0.15 * total_samples)

    splits = {
        'train_set_big_new_test.json': slice(0, train_end),
        'validation_set_big_new_test.json': slice(train_end, valid_end),
        'test_set_big_new_test.json': slice(valid_end, None),
    }
    split_sizes = []
    for filename, part in splits.items():
        subset = {
            'observations': observations[part],
            'actions': actions[part],
            'rewards': rewards[part],
        }
        # `fh` rather than `f` avoids shadowing the module-level charge-balance function f().
        with open(filename, 'w') as fh:
            json.dump(subset, fh, indent=2)
        split_sizes.append(len(subset['observations']))
    logging.info("数据集划分完毕：训练集 %d 条，验证集 %d 条，测试集 %d 条", *split_sizes)

if __name__ == '__main__':
    main()

In [None]:
# 训练离散回归模型

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Fix the random seeds so results are reproducible.
np.random.seed(42)
torch.manual_seed(42)

# -------------------------------
# Global parameters
# -------------------------------
INPUT_DIM = 5    # features: current pH, target pH, pH change, error (current pH - target pH), last added volume
HIDDEN_DIM1 = 256
HIDDEN_DIM2 = 256
BATCH_SIZE = 64
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3

# Discrete action space: volumes in [0.01, 10.00] mL with a 0.01 mL step.
MIN_VOLUME = 0.01
MAX_VOLUME = 10.0
STEP = 0.01
# round() before int(): the float quotient can land just below an integer and int() truncates.
NUM_ACTIONS = int(round((MAX_VOLUME - MIN_VOLUME) / STEP)) + 1  # 1000 discrete actions

# -------------------------------
# 数据集：转换连续标签为离散类别
# -------------------------------
class VolumePredictionDataset(Dataset):
    """Torch dataset mapping titration observations to discrete volume classes.

    ``dataset`` holds parallel lists ``'observations'`` and ``'actions'``
    (``[reagent_idx, volume]`` rows). Only rows whose reagent index is 0 or 2
    are kept. Observation columns used: 0 = current pH, 1 = target pH,
    7 = pH change, 8 = last added volume (presumably the ``convert_state``
    layout — confirm against the data producer).
    """

    def __init__(self, dataset):
        observations = np.array(dataset['observations'])
        actions = np.array(dataset['actions'])

        # Keep only samples whose reagent category is 0 or 2.
        keep = np.isin(actions[:, 0], [0, 2])
        observations, actions = observations[keep], actions[keep]

        # Features: current pH, target pH, pH change, error (current - target), last volume.
        ph_now = observations[:, 0]
        ph_target = observations[:, 1]
        features = np.column_stack([
            ph_now,
            ph_target,
            observations[:, 7],
            ph_now - ph_target,
            observations[:, 8],
        ])
        self.inputs = torch.tensor(features, dtype=torch.float32)

        # Volume -> class index: round((volume - MIN_VOLUME) / STEP).
        class_ids = np.rint((actions[:, 1] - MIN_VOLUME) / STEP).astype(np.int64)
        self.labels = torch.tensor(class_ids, dtype=torch.long)

    def __len__(self):
        return self.inputs.shape[0]

    def __getitem__(self, idx):
        features = self.inputs[idx]
        label = self.labels[idx]
        return features, label

# -------------------------------
# 离散动作策略模型：输出 NUM_ACTIONS 个 logits，对应 0.01~10.00 mL 的离散体积
# -------------------------------
class DiscreteVolumeRegressor(nn.Module):
    """MLP policy that classifies the next titrant volume into ``num_actions`` bins.

    Class index i corresponds to volume ``round(MIN_VOLUME + i * STEP, 2)`` mL
    (0.01 .. 10.00 mL with the default constants).
    """

    def __init__(self, input_dim=INPUT_DIM, num_actions=NUM_ACTIONS):
        super().__init__()
        self.num_actions = num_actions
        self.net = nn.Sequential(
            nn.Linear(input_dim, HIDDEN_DIM1),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM1, HIDDEN_DIM2),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM2, num_actions)
        )
        # Lookup table mapping a class index back to a volume in mL.
        self.discrete_volumes = [round(MIN_VOLUME + i * STEP, 2) for i in range(num_actions)]

    def forward(self, x):
        """Return unnormalised logits over the ``num_actions`` volume classes."""
        return self.net(x)

    def _indices_to_volumes(self, indices):
        """Map class indices (any shape, incl. 0-d) to a float32 (N, 1) tensor of volumes."""
        volumes = [self.discrete_volumes[idx] for idx in indices.reshape(-1).tolist()]
        return torch.tensor(volumes, dtype=torch.float32).unsqueeze(1)

    def predict_volume(self, x):
        """Return the argmax-class volume for each row of a batched input, shape (batch, 1)."""
        logits = self.forward(x)
        _, predicted_indices = torch.max(logits, dim=1)
        return self._indices_to_volumes(predicted_indices)

    def sample_action(self, x):
        """Sample a volume class from the Categorical policy.

        Returns:
            ``(volumes, log_prob)``: volumes has shape (batch, 1) (a single row
            for an unbatched input); log_prob is the log-probability of each
            sampled class.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        # Supports batches; the original .item() only worked for a single sample.
        return self._indices_to_volumes(action_index), log_prob

# -------------------------------
# 工具函数：加载 JSON 数据
# -------------------------------
def load_json_file(filename):
    """Read ``filename`` and return its parsed JSON content."""
    with open(filename, 'r') as handle:
        return json.load(handle)

# -------------------------------
# 工具函数：在数据加载器上评估模型（计算交叉熵损失和准确率）
# -------------------------------
def evaluate(model, dataloader, criterion):
    """Return ``(mean loss, accuracy)`` of a classifier over every sample in ``dataloader``.

    The model is switched to eval mode (and left there) and run under
    ``torch.no_grad()``. Each batch loss is weighted by its batch size, which
    assumes ``criterion`` averages over the batch (the ``nn.CrossEntropyLoss``
    default). The predicted class is the argmax of the logits along dim 1.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_correct = 0
    count = 0
    with torch.no_grad():
        for inputs, labels in dataloader:
            logits = model(inputs)
            batch_size = inputs.size(0)
            total_loss += criterion(logits, labels).item() * batch_size
            _, predicted = torch.max(logits, dim=1)
            total_correct += (predicted == labels).sum().item()
            count += batch_size
    if count == 0:
        # Fail with a clear message instead of a bare ZeroDivisionError.
        raise ValueError("evaluate() received a dataloader with no samples")
    return total_loss / count, total_correct / count

# -------------------------------
# 主训练流程：加载 train_set_big_new1.json、validation_set_big_new1.json、test_set_big_new1.json 进行训练、验证和测试
# -------------------------------
def main():
    """Train the discrete volume classifier, test the best checkpoint, and run one example.

    Reads the train / validation / test splits from ``*_big_new1.json`` in the
    working directory, saves the weights with the lowest validation loss to
    ``volume_regressor_best_big_discrete_new1-test.pth``, reloads that file for
    the test pass, and prints the predicted volume for a sample state.
    """
    split_files = ('train_set_big_new1.json', 'validation_set_big_new1.json', 'test_set_big_new1.json')
    raw_splits = [load_json_file(name) for name in split_files]
    train_dataset, val_dataset, test_dataset = (VolumePredictionDataset(split) for split in raw_splits)

    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader, test_loader = (
        DataLoader(ds, batch_size=BATCH_SIZE, shuffle=False) for ds in (val_dataset, test_dataset)
    )

    # Classifier trained with cross-entropy over the discrete volume classes.
    model = DiscreteVolumeRegressor(INPUT_DIM, NUM_ACTIONS)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()
    checkpoint_path = "volume_regressor_best_big_discrete_new1-test.pth"

    best_val_loss = float('inf')
    for epoch in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            loss = criterion(model(inputs), labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item() * inputs.size(0)
        avg_train_loss = running_loss / len(train_dataset)

        val_loss, val_acc = evaluate(model, val_loader, criterion)
        print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model with Val Loss: {:.4f}".format(val_loss))

    # Score the best (not the last) epoch on the held-out test split.
    model.load_state_dict(torch.load(checkpoint_path))
    test_loss, test_acc = evaluate(model, test_loader, criterion)
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Example state: [current pH=9.0, target pH=2.0, pH change=-0.5, error=7.0, last added volume=0.05]
    example_input = [9.0, 2.0, -0.5, 7.0, 0.05]
    model.eval()
    with torch.no_grad():
        batch = torch.tensor(example_input, dtype=torch.float32).unsqueeze(0)
        predicted_volume = model.predict_volume(batch).item()
    print("对于输入", example_input, "预测体积为:", predicted_volume)

if __name__ == "__main__":
    # Run the full train / validate / test pipeline when executed as a script or notebook cell.
    main()


In [None]:
# 在测试集上测试离散回归模型

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np

# -------------------------------
# 数据集类，与训练时一致
# -------------------------------
class VolumePredictionDataset(Dataset):
    """Regression view of logged titration transitions: 5 state features -> added volume.

    Only rows whose action type (``actions[:, 0]``) is 0 or 2 are kept. Per row:

    * inputs (float32, 5 columns): current pH (obs col 0), target pH (col 1),
      pH change (col 7), error = current pH - target pH, last added volume (col 8);
    * label (float32, shape (1,)): the action's volume, ``actions[:, 1]``.
    """

    def __init__(self, dataset):
        observations = np.array(dataset['observations'])
        actions = np.array(dataset['actions'])

        keep = np.isin(actions[:, 0], [0, 2])
        observations, actions = observations[keep], actions[keep]

        current_ph = observations[:, 0]
        target_ph = observations[:, 1]
        features = np.column_stack((
            current_ph,
            target_ph,
            observations[:, 7],
            current_ph - target_ph,
            observations[:, 8],
        ))
        self.inputs = torch.tensor(features, dtype=torch.float32)
        # Volume column reshaped to (N, 1) so it lines up with (N, 1) model predictions.
        self.labels = torch.tensor(actions[:, 1], dtype=torch.float32).unsqueeze(-1)

    def __len__(self):
        return self.inputs.size(0)

    def __getitem__(self, idx):
        features, target = self.inputs[idx], self.labels[idx]
        return features, target

# -------------------------------
# 离散动作策略模型
# -------------------------------
INPUT_DIM = 5
HIDDEN_DIM1 = 256
HIDDEN_DIM2 = 256

class DiscreteVolumeRegressor(nn.Module):
    """MLP classifier over evenly spaced titrant volumes.

    Bins are ``round(min_volume + i * step, 2)`` for ``i`` in ``0..N-1``, covering
    ``min_volume``..``max_volume`` inclusive (0.01 .. 10.00 mL, 1000 bins, by
    default). Volumes are rounded to 2 decimals, so ``step`` should be a
    multiple of 0.01.
    """

    def __init__(self, input_dim=INPUT_DIM, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        # round() rather than int(): (max - min) / step is often a hair below an integer
        # in floating point (0.3 / 0.1 == 2.9999999999999996), which would drop the last bin.
        n_bins = int(round((max_volume - min_volume) / step)) + 1
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_bins)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, HIDDEN_DIM1),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM1, HIDDEN_DIM2),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM2, self.num_actions)
        )

    def forward(self, x):
        return self.net(x)

    def predict_volume(self, x):
        """Return the argmax volume for every row of ``x`` as a (batch, 1) float32 tensor."""
        logits = self.forward(x)
        _, predicted_indices = torch.max(logits, dim=1)
        volumes = [self.discrete_volumes[i] for i in predicted_indices.tolist()]
        return torch.tensor(volumes, dtype=torch.float32, device=logits.device).unsqueeze(1)

# -------------------------------
# 工具函数：加载 JSON 数据
# -------------------------------
def load_json_file(filename):
    """Parse and return the JSON document stored in ``filename``.

    The file is decoded as UTF-8, the encoding JSON text is required to use,
    instead of the platform's locale default (e.g. GBK on Chinese Windows).
    """
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

# -------------------------------
# 测试集评价函数：计算 MSE, MAE 和 R²
# -------------------------------
def evaluate_model(model, dataloader):
    """Print and return regression metrics of ``model.predict_volume`` over ``dataloader``.

    Every row of every batch is passed to ``predict_volume`` on its own, as a
    ``(1, features)`` tensor, under ``torch.no_grad()``. The predictions and
    labels are concatenated before computing MSE, MAE and R².

    Returns:
        ``(mse, mae, r2)`` as floats (previously nothing was returned).
        ``r2`` is ``nan`` when all labels are identical, where R² is undefined.

    Raises:
        ValueError: if ``dataloader`` yields no samples.
    """
    mse_criterion = nn.MSELoss()
    mae_criterion = nn.L1Loss()
    all_preds = []
    all_labels = []

    model.eval()
    with torch.no_grad():
        for inputs, labels in dataloader:
            # Rows are fed one at a time so predict_volume only ever sees a single-sample batch.
            all_preds.extend(model.predict_volume(inputs[i].unsqueeze(0)) for i in range(inputs.size(0)))
            all_labels.append(labels)

    if not all_preds:
        raise ValueError("evaluate_model() received a dataloader with no samples")

    all_preds = torch.cat(all_preds, dim=0)
    all_labels = torch.cat(all_labels, dim=0)
    mse_loss = mse_criterion(all_preds, all_labels).item()
    mae_loss = mae_criterion(all_preds, all_labels).item()

    ss_res = torch.sum((all_labels - all_preds) ** 2)
    ss_tot = torch.sum((all_labels - torch.mean(all_labels)) ** 2)
    # Guard the division when every label is identical (zero total variance).
    r2_score = (1 - ss_res / ss_tot).item() if ss_tot.item() > 0 else float('nan')

    print("Test MSE Loss: {:.4f}".format(mse_loss))
    print("Test MAE Loss: {:.4f}".format(mae_loss))
    print("Test R² Score: {:.4f}".format(r2_score))
    return mse_loss, mae_loss, r2_score

# -------------------------------
# 主测试流程
# -------------------------------
if __name__ == '__main__':
    # Held-out test split, evaluated in file order.
    test_data = load_json_file('test_set_big_new1.json')
    test_dataset = VolumePredictionDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Rebuild the 0.01-10.00 mL classifier and restore the best supervised checkpoint on CPU.
    model = DiscreteVolumeRegressor(INPUT_DIM, min_volume=0.01, max_volume=10.0, step=0.01)
    cpu = torch.device('cpu')
    model.load_state_dict(torch.load("volume_regressor_best_big_discrete_new1-test.pth", map_location=cpu))
    model.eval()

    evaluate_model(model, test_loader)


In [None]:
# 现在使用的强化

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
from scipy.optimize import fsolve
import json

##############################################
# 固定随机种子，确保实验可重复
##############################################
seed = 255
# Seed every RNG used below (Python, NumPy, PyTorch) so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Force deterministic cuDNN kernels and disable autotuning, which may pick non-deterministic ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

##############################################
# Global constants
##############################################
TITRANT_CONC = 0.1          # titrant concentration, mol/L (0.1 M)
MAX_STEPS = 50              # maximum number of steps per episode
INITIAL_ACID_VOL = 11.0     # initial volume of the weak acid being titrated, mL
SUCCESS_THRESHOLD = 0.1     # |pH - target| below this counts as success

##############################################
# pH 计算函数：单元酸
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual for a monoprotic weak acid HA mixed with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - [A-] - [Cl-]`` (mol/L) at ``pH``, with
    ``[A-] = c_A * r / (1 + r)`` and ``r = 10**(pH - pKa)``. The residual is zero
    at the equilibrium pH.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc          # Kw / [H+]
    ratio = 10 ** (pH - pKa)          # [A-] / [HA]
    dissociated = c_A * (ratio / (1 + ratio))
    return h_conc + c_Na - oh_conc - dissociated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Return the pH in [0, 14] where ``f_monoprotic`` is zero, by bisection.

    Runs at most 100 halvings and returns early once ``|residual| < 1e-10``.
    The residual at the lower bracket end is cached, so each iteration costs a
    single ``f_monoprotic`` call instead of two; the sequence of brackets and the
    result are unchanged. If the residual never changes sign, the lower bound
    keeps advancing and the result approaches 14.0 rather than raising.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float, *,
                            analyte_vol_mL=None, analyte_conc=0.1,
                            naoh_conc=None, hcl_conc=None) -> float:
    """Return the pH (rounded to 2 dp) of a monoprotic weak acid after adding titrants.

    Args:
        base_added_mL: cumulative NaOH volume added, mL.
        acid_added_mL: cumulative HCl volume added, mL.
        pKa: acid dissociation constant of the analyte.
        analyte_vol_mL: analyte volume, mL (default: ``INITIAL_ACID_VOL``, read at call time).
        analyte_conc: analyte concentration, mol/L (default 0.1).
        naoh_conc: NaOH concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
        hcl_conc: HCl concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
    """
    acid_vol_mL = INITIAL_ACID_VOL if analyte_vol_mL is None else analyte_vol_mL
    base_conc = TITRANT_CONC if naoh_conc is None else naoh_conc
    acid_added_conc = TITRANT_CONC if hcl_conc is None else hcl_conc
    # Moles of each species, then concentrations in the combined volume (litres).
    n_acid = acid_vol_mL / 1000.0 * analyte_conc
    n_Na = base_added_mL / 1000.0 * base_conc
    n_HCl = acid_added_mL / 1000.0 * acid_added_conc
    V_total = (acid_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa), 2)

##############################################
# pH 计算函数：双元酸
##############################################
def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual for a diprotic weak acid H2A mixed with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - c_A * (a1 + 2*a2) - [Cl-]``, where a1 and a2
    are the fractions of HA- and A2-. Exponents are clipped to +/-100 so
    ``np.power`` stays finite for extreme pH / pKa combinations.
    """
    h_plus = 10 ** (-pH)
    oh_minus = 1e-14 / h_plus  # Kw / [H+]
    ratio1 = np.power(10, np.clip(pH - pKa1, -100, 100))             # [HA-] / [H2A]
    ratio2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))  # [A2-] / [H2A]
    denom = 1 + ratio1 + ratio2
    mean_charge = ratio1 / denom + 2 * (ratio2 / denom)
    return h_plus + c_Na - oh_minus - c_A * mean_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Return the pH in [0, 14] where ``f_diprotic`` is zero, by bisection.

    Runs at most 100 halvings and returns early once ``|residual| < 1e-10``.
    The residual at the lower bracket end is cached, so each iteration costs a
    single ``f_diprotic`` call instead of two; the sequence of brackets and the
    result are unchanged. If the residual never changes sign, the lower bound
    keeps advancing and the result approaches 14.0 rather than raising.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, *,
                          analyte_vol_mL=None, analyte_conc=0.1,
                          naoh_conc=None, hcl_conc=None) -> float:
    """Return the pH (rounded to 2 dp) of a diprotic weak acid after adding titrants.

    Args:
        base_added_mL: cumulative NaOH volume added, mL.
        acid_added_mL: cumulative HCl volume added, mL.
        pKa1, pKa2: successive dissociation constants of the analyte.
        analyte_vol_mL: analyte volume, mL (default: ``INITIAL_ACID_VOL``, read at call time).
        analyte_conc: analyte concentration, mol/L (default 0.1).
        naoh_conc: NaOH concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
        hcl_conc: HCl concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
    """
    acid_vol_mL = INITIAL_ACID_VOL if analyte_vol_mL is None else analyte_vol_mL
    base_conc = TITRANT_CONC if naoh_conc is None else naoh_conc
    acid_added_conc = TITRANT_CONC if hcl_conc is None else hcl_conc
    # Moles of each species, then concentrations in the combined volume (litres).
    n_acid = acid_vol_mL / 1000.0 * analyte_conc
    n_Na = base_added_mL / 1000.0 * base_conc
    n_HCl = acid_added_mL / 1000.0 * acid_added_conc
    V_total = (acid_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2), 2)

##############################################
# pH 计算函数：三元酸
##############################################
def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual for a triprotic weak acid H3A mixed with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - c_A * (a1 + 2*a2 + 3*a3) - [Cl-]``, where
    a1..a3 are the fractions of H2A-, HA2- and A3-. Exponents are clipped to
    +/-100 so ``np.power`` stays finite for extreme pH / pKa combinations.
    """
    h_plus = 10 ** (-pH)
    oh_minus = 1e-14 / h_plus  # Kw / [H+]
    r1 = np.power(10, np.clip(pH - pKa1, -100, 100))                    # [H2A-] / [H3A]
    r2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))         # [HA2-] / [H3A]
    r3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))  # [A3-] / [H3A]
    total = 1 + r1 + r2 + r3
    mean_charge = r1 / total + 2 * (r2 / total) + 3 * (r3 / total)
    return h_plus + c_Na - oh_minus - c_A * mean_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Return the pH in [0, 14] where ``f_triprotic`` is zero, by bisection.

    Runs at most 100 halvings and returns early once ``|residual| < 1e-10``.
    The residual at the lower bracket end is cached, so each iteration costs a
    single ``f_triprotic`` call instead of two; the sequence of brackets and the
    result are unchanged. If the residual never changes sign, the lower bound
    keeps advancing and the result approaches 14.0 rather than raising.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float, *,
                           analyte_vol_mL=None, analyte_conc=0.1,
                           naoh_conc=None, hcl_conc=None) -> float:
    """Return the pH (rounded to 2 dp) of a triprotic weak acid after adding titrants.

    Args:
        base_added_mL: cumulative NaOH volume added, mL.
        acid_added_mL: cumulative HCl volume added, mL.
        pKa1, pKa2, pKa3: successive dissociation constants of the analyte.
        analyte_vol_mL: analyte volume, mL (default: ``INITIAL_ACID_VOL``, read at call time).
        analyte_conc: analyte concentration, mol/L (default 0.1).
        naoh_conc: NaOH concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
        hcl_conc: HCl concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
    """
    acid_vol_mL = INITIAL_ACID_VOL if analyte_vol_mL is None else analyte_vol_mL
    base_conc = TITRANT_CONC if naoh_conc is None else naoh_conc
    acid_added_conc = TITRANT_CONC if hcl_conc is None else hcl_conc
    # Moles of each species, then concentrations in the combined volume (litres).
    n_acid = acid_vol_mL / 1000.0 * analyte_conc
    n_Na = base_added_mL / 1000.0 * base_conc
    n_HCl = acid_added_mL / 1000.0 * acid_added_conc
    V_total = (acid_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3), 2)

##############################################
# 奖励计算函数（修改版）
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag, prev_overshoot_volume, last_action_volume):
    """Shaped reward for one titration step; returns ``(reward, is_terminal)``.

    The reward is the sum of these terms (weights read from ``reward_config``,
    with defaults):

    * dense: ``dense_lambda * (previous_error - current_error) * (1 + remaining_steps / max_steps)``;
    * ``step_penalty`` once per step;
    * overshoot: ``-overshoot_weight * sigmoid(current_error - overshoot_threshold)`` when the pH
      crossed the target this step and either error exceeds ``overshoot_threshold``;
    * wrong direction: ``-wrong_dir_factor * current_error`` when the pH ends above the target
      after a base reagent, or below it after an acid reagent;
    * after an overshoot on the previous step: ``-overshoot_volume_penalty * last_action_volume``,
      plus ``overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)`` when the
      volume shrank.

    The episode terminates on success (``current_error < SUCCESS_THRESHOLD``) or when
    ``steps_taken >= max_steps``. Only success earns ``terminal_bonus``, doubled when reached
    in under half of ``max_steps``. Running out of steps ends the episode without a bonus;
    rewarding a timeout would pay the policy for failing. Non-terminal rewards are clipped to
    ``[reward_clip_min, reward_clip_max]``; terminal rewards are not clipped.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("dense_lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("step_penalty", -0.005)
    overshoot_weight = reward_config.get("overshoot_weight", 0.2)
    overshoot_threshold = reward_config.get("overshoot_threshold", 0.1)

    # Overshoot: the pH moved from one side of the target to the other this step.
    if (previous_ph - target_ph) * (current_ph - target_ph) < 0 and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(- (overshoot_magnitude - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("wrong_dir_factor", 1.0)
    wrong_dir_penalty = 0
    if (current_ph > target_ph and 'base' in reagent.lower()) or (current_ph < target_ph and 'acid' in reagent.lower()):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph)

    # After an overshoot on the previous step, penalise the current volume and
    # reward shrinking it relative to the overshooting volume.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None:
        overshoot_volume_penalty = reward_config.get("overshoot_volume_penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume
        overshoot_volume_bonus = reward_config.get("overshoot_volume_bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    succeeded = current_error < SUCCESS_THRESHOLD
    is_terminal = succeeded or steps_taken >= max_steps
    if succeeded:
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        raw_reward += reward_config.get("terminal_bonus", 3.0) * bonus_factor

    if is_terminal:
        return raw_reward, True

    reward_clip_max = reward_config.get("reward_clip_max", 4.0)
    reward_clip_min = reward_config.get("reward_clip_min", -4.0)
    return max(min(raw_reward, reward_clip_max), reward_clip_min), False

##############################################
# pH 模拟环境：PHSimEnv（修改版）
##############################################
class PHSimEnv:
    """Titration environment: bring a random weak-acid solution to a random target pH.

    ``reset()`` picks an acid type (mono-, di- or triprotic), one of 30 pKa sets
    pre-sampled in ``__init__`` for that type, and a target pH in [2, 11].
    ``step(volume)`` adds ``volume`` mL of strong base when the current pH is
    below the target, and strong acid otherwise. It then recomputes the pH from
    the charge balance and scores the move with ``calculate_reward``.

    Args:
        initial_acid_vol: volume of the acid solution being titrated, mL.
        analyte_conc: concentration of that acid, mol/L.
        titrant_conc: concentration of both the NaOH and HCl titrants, mol/L.
            All three are used in the pH computation.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol  # mL
        self.analyte_conc = analyte_conc          # mol/L
        self.titrant_conc = titrant_conc          # mol/L
        # Moles of analyte acid in the flask; fixed for the environment's lifetime.
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reward_config = {
            # NOTE(review): a negative dense_lambda makes calculate_reward's dense term
            # negative for steps that *reduce* |pH - target|. Confirm this sign is intended.
            "dense_lambda": -0.03,
            "step_penalty": 0,
            "terminal_bonus": 3.9,
            "overshoot_weight": 0.2,
            "overshoot_threshold": 0.1,
            "wrong_dir_factor": 1,
            "reward_clip_max": 4.1,
            "reward_clip_min": -4.1,
            "overshoot_volume_penalty": 0.1,
            "overshoot_volume_bonus": 0.1
        }
        # Pre-sample 30 parameter sets per acid type. With seeded RNGs the draw order
        # determines the sampled values, so keep it stable.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            self.diprotic_pKa_list.append((pKa1, pKa2))
        self.triprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            self.triprotic_pKa_list.append((pKa1, pKa2, pKa3))
        self.reset()

    def _compute_ph(self):
        """Return the pH (rounded to 2 dp) of the flask's current contents.

        Uses this environment's own volume and concentrations rather than module
        globals, so the constructor arguments take effect.
        """
        v_total = (self.initial_acid_vol + self.base_added_mL + self.acid_added_mL) / 1000.0
        c_A = self.n_acid / v_total
        c_Na = self.base_added_mL / 1000.0 * self.titrant_conc / v_total
        c_HCl = self.acid_added_mL / 1000.0 * self.titrant_conc / v_total
        if self.acid_type == 'monoprotic':
            ph = solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, self.acid_params)
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            ph = solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2)
        else:  # triprotic
            pKa1, pKa2, pKa3 = self.acid_params
            ph = solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        return round(ph, 2)

    def reset(self):
        """Start a new episode with a random acid and target; return the initial state."""
        self.acid_type = random.choice(['monoprotic', 'diprotic', 'triprotic'])
        if self.acid_type == 'monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        else:  # triprotic
            self.acid_params = random.choice(self.triprotic_pKa_list)
        self.target_ph = round(random.uniform(2, 11), 2)
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0
        # No overshoot has happened yet in this episode.
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        self.current_ph = self._compute_ph()
        # No step taken yet, so the "previous" pH equals the current one (pH change = 0).
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return [current pH, target pH, pH change, current - target, last volume] as float32."""
        pH_delta = round(self.current_ph - self.previous_ph, 2)
        error = round(self.current_ph - self.target_ph, 2)
        return np.array([self.current_ph, self.target_ph, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def step(self, action):
        """Add ``action`` mL of titrant; return ``(state, reward, done, {'reagent': ...})``."""
        volume = float(action)
        self.last_action_volume = volume
        self.steps += 1
        # Titrate toward the target: base when below it, acid otherwise (including equality).
        if self.current_ph < self.target_ph:
            reagent = "strong_base"
            self.base_added_mL += volume
        else:
            reagent = "strong_acid"
            self.acid_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base_added_mL + self.acid_added_mL

        self.previous_ph = self.current_ph
        self.current_ph = self._compute_ph()

        state = self._get_state()
        reward, done = calculate_reward(
            previous_ph=self.previous_ph,
            current_ph=self.current_ph,
            target_ph=self.target_ph,
            steps_taken=self.steps,
            max_steps=MAX_STEPS,
            reagent=reagent,
            reward_config=self.reward_config,
            SUCCESS_THRESHOLD=SUCCESS_THRESHOLD,
            prev_overshoot_flag=self.prev_overshoot_flag,
            prev_overshoot_volume=self.prev_overshoot_volume,
            last_action_volume=self.last_action_volume
        )

        # Overshoot = this step carried the pH across this episode's target
        # (self.target_ph, not any module-level name). Remember it for the next reward.
        crossed = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
        self.prev_overshoot_flag = crossed
        self.prev_overshoot_volume = self.last_action_volume if crossed else None

        return state, reward, done, {'reagent': reagent}

##############################################
# 离散动作策略模型：DiscreteVolumeRegressor
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Stochastic policy over evenly spaced titrant volumes.

    Bins are ``round(min_volume + i * step, 2)`` for ``i`` in ``0..N-1``, covering
    ``min_volume``..``max_volume`` inclusive (0.01 .. 10.00 mL, 1000 bins, by
    default). Volumes are rounded to 2 decimals, so ``step`` should be a multiple
    of 0.01. ``sample_action`` and ``predict_volume`` accept any batch size and
    return a ``(batch, 1)`` float32 tensor on the logits' device.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        # round() rather than int(): (max - min) / step is often a hair below an integer
        # in floating point (0.3 / 0.1 == 2.9999999999999996), which would drop the last bin.
        n_bins = int(round((max_volume - min_volume) / step)) + 1
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_bins)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )

    def forward(self, x):
        return self.net(x)

    def _indices_to_volumes(self, indices, device):
        """Map a tensor of bin indices to a (batch, 1) float32 tensor of volumes."""
        volumes = [self.discrete_volumes[i] for i in indices.reshape(-1).tolist()]
        return torch.tensor(volumes, dtype=torch.float32, device=device).unsqueeze(1)

    def sample_action(self, x):
        """Sample one volume per row from Categorical(logits).

        Returns ``(volumes, log_prob)``: volumes of shape (batch, 1) and the
        log-probability of each sampled bin, shape (batch,). NaN logits are
        printed for debugging before the distribution is built.
        """
        logits = self.forward(x)
        if torch.isnan(logits).any():
            print("Logits contain NaN:", logits)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        return self._indices_to_volumes(action_index, logits.device), log_prob

    def predict_volume(self, x):
        """Return the argmax volume for every row of ``x``, shape (batch, 1)."""
        logits = self.forward(x)
        _, predicted_index = torch.max(logits, dim=-1)
        return self._indices_to_volumes(predicted_index, logits.device)

##############################################
# 在线训练：使用 REINFORCE 算法更新策略模型
##############################################
def train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=0.99,
                    save_path="volume_regressor_best_big_discrete_new1_trained-1-test.pth"):
    """Fine-tune ``policy_model`` on ``env`` with REINFORCE (Monte-Carlo policy gradient).

    Each episode runs until ``env`` reports done. Discounted returns are computed
    and standardised as a baseline (mean 0 / std 1 when the episode has more than
    one step, otherwise only mean-centred). One optimizer step is then taken on
    ``sum(-log_prob * G)``, with the gradient norm clipped to 1.0.

    After each episode, the final ``|env.current_ph - env.target_ph|`` is compared
    with the best so far. On improvement, a snapshot of the weights is saved to
    ``save_path``.

    Returns:
        ``(best_error, best_model_state)``. ``best_model_state`` is an independent
        copy of the weights (``None`` if ``num_episodes`` is 0).
    """
    import copy

    best_error = float('inf')
    best_model_state = None

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        log_probs = []
        rewards = []
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action, log_prob = policy_model.sample_action(state_tensor)
            state, reward, done, _ = env.step(action.item())
            log_probs.append(log_prob)
            rewards.append(reward)

        # Discounted returns G_t = r_t + gamma * G_{t+1}, built back to front.
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = r + gamma * running
            returns.append(running)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        else:
            returns = returns - returns.mean()

        loss = 0
        for log_prob, G in zip(log_probs, returns):
            loss += -log_prob * G

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

        current_error = abs(env.current_ph - env.target_ph)
        if current_error < best_error:
            best_error = current_error
            # state_dict() tensors alias the live parameters and dict.copy() is shallow,
            # so deep-copy to keep a snapshot that later optimizer steps cannot change.
            best_model_state = copy.deepcopy(policy_model.state_dict())
            torch.save(best_model_state, save_path)
            print(f"Episode {episode}, Loss: {loss.item():.4f}, Updated Best Model with Error: {best_error:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")
        elif episode % 50 == 0:  # otherwise report progress every 50 episodes
            total_reward = sum(rewards)
            print(f"Episode {episode}, Loss: {loss.item():.4f}, Total Reward: {total_reward:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")

    return best_error, best_model_state

##############################################
# 测试函数：运行实验并打印每一步详情
##############################################
def test_model(policy_model, env, num_experiments=10):
    """Roll out ``num_experiments`` episodes with sampled actions and print each trajectory.

    Actions come from ``policy_model.sample_action`` under ``torch.no_grad()``, so
    rollouts are stochastic. The per-step trace of each episode is printed after
    it finishes.
    """
    for experiment_no in range(1, num_experiments + 1):
        print(f"\n==== 实验 {experiment_no} 开始 ====")
        state = env.reset()
        print(f"初始状态: {state}")
        print(f"酸类型: {env.acid_type}, 参数: {env.acid_params}, 目标 pH: {env.target_ph}")

        trace = []
        finished = False
        while not finished:
            with torch.no_grad():
                obs = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                volume_tensor, _ = policy_model.sample_action(obs)
            volume = volume_tensor.item()
            state, _, finished, info = env.step(volume)
            trace.append((state, volume, info.get('reagent', '')))

        for step_no, (snapshot, volume, reagent) in enumerate(trace, start=1):
            print(f"  Step {step_no}: State = {snapshot}, Action = {volume:.4f}, Reagent = {reagent}")
        print(f"实验结束，共用步数: {len(trace)}")

##############################################
# 主程序：加载预训练模型（如果可用）并训练、测试
##############################################
if __name__ == "__main__":
    input_dim = 5
    learning_rate = 1e-4
    gamma = 0.99

    # Environment first, then policy: with seeded RNGs this order fixes the sampled values.
    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=TITRANT_CONC)
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)

    # Warm-start from the supervised checkpoint; fall back to random weights on any failure.
    pretrained_path = "volume_regressor_best_big_discrete_new1-test.pth"
    try:
        state_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        policy_model.load_state_dict(state_dict)
    except Exception as e:
        print("未能加载预训练模型，使用随机初始化模型。", e)
    else:
        print("加载预训练模型成功。")

    optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
    train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=gamma)

    # NOTE(review): this writes the *final* weights to the same filename train_reinforce
    # uses for its best-episode checkpoint, so the best snapshot is overwritten here.
    # Confirm that is intended.
    final_path = "volume_regressor_best_big_discrete_new1_trained-1-test.pth"
    torch.save(policy_model.state_dict(), final_path)
    print("模型已保存。")

    test_model(policy_model, env, num_experiments=10)

In [None]:
# 边训边保存

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
from scipy.optimize import fsolve
import json

##############################################
# 固定随机种子，确保实验可重复
##############################################
seed = 255
# Seed every RNG used below (Python, NumPy, PyTorch) so runs are repeatable.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Force deterministic cuDNN kernels and disable autotuning, which may pick non-deterministic ones.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

##############################################
# Global constants
##############################################
TITRANT_CONC = 0.1          # titrant concentration, mol/L (0.1 M)
MAX_STEPS = 50              # maximum number of steps per episode
INITIAL_ACID_VOL = 11.0     # initial volume of the weak acid being titrated, mL
SUCCESS_THRESHOLD = 0.1     # |pH - target| below this counts as success

##############################################
# pH 计算函数：单元酸
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual for a monoprotic weak acid HA mixed with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - [A-] - [Cl-]`` (mol/L) at ``pH``, with
    ``[A-] = c_A * r / (1 + r)`` and ``r = 10**(pH - pKa)``. The residual is zero
    at the equilibrium pH.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc          # Kw / [H+]
    ratio = 10 ** (pH - pKa)          # [A-] / [HA]
    dissociated = c_A * (ratio / (1 + ratio))
    return h_conc + c_Na - oh_conc - dissociated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Return the pH in [0, 14] where ``f_monoprotic`` is zero, by bisection.

    Runs at most 100 halvings and returns early once ``|residual| < 1e-10``.
    The residual at the lower bracket end is cached, so each iteration costs a
    single ``f_monoprotic`` call instead of two; the sequence of brackets and the
    result are unchanged. If the residual never changes sign, the lower bound
    keeps advancing and the result approaches 14.0 rather than raising.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float, *,
                            analyte_vol_mL=None, analyte_conc=0.1,
                            naoh_conc=None, hcl_conc=None) -> float:
    """Return the pH (rounded to 2 dp) of a monoprotic weak acid after adding titrants.

    Args:
        base_added_mL: cumulative NaOH volume added, mL.
        acid_added_mL: cumulative HCl volume added, mL.
        pKa: acid dissociation constant of the analyte.
        analyte_vol_mL: analyte volume, mL (default: ``INITIAL_ACID_VOL``, read at call time).
        analyte_conc: analyte concentration, mol/L (default 0.1).
        naoh_conc: NaOH concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
        hcl_conc: HCl concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
    """
    acid_vol_mL = INITIAL_ACID_VOL if analyte_vol_mL is None else analyte_vol_mL
    base_conc = TITRANT_CONC if naoh_conc is None else naoh_conc
    acid_added_conc = TITRANT_CONC if hcl_conc is None else hcl_conc
    # Moles of each species, then concentrations in the combined volume (litres).
    n_acid = acid_vol_mL / 1000.0 * analyte_conc
    n_Na = base_added_mL / 1000.0 * base_conc
    n_HCl = acid_added_mL / 1000.0 * acid_added_conc
    V_total = (acid_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa), 2)

##############################################
# pH 计算函数：双元酸
##############################################
def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual for a diprotic weak acid H2A mixed with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - c_A * (a1 + 2*a2) - [Cl-]``, where a1 and a2
    are the fractions of HA- and A2-. Exponents are clipped to +/-100 so
    ``np.power`` stays finite for extreme pH / pKa combinations.
    """
    h_plus = 10 ** (-pH)
    oh_minus = 1e-14 / h_plus  # Kw / [H+]
    ratio1 = np.power(10, np.clip(pH - pKa1, -100, 100))             # [HA-] / [H2A]
    ratio2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))  # [A2-] / [H2A]
    denom = 1 + ratio1 + ratio2
    mean_charge = ratio1 / denom + 2 * (ratio2 / denom)
    return h_plus + c_Na - oh_minus - c_A * mean_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Return the pH in [0, 14] where ``f_diprotic`` is zero, by bisection.

    Runs at most 100 halvings and returns early once ``|residual| < 1e-10``.
    The residual at the lower bracket end is cached, so each iteration costs a
    single ``f_diprotic`` call instead of two; the sequence of brackets and the
    result are unchanged. If the residual never changes sign, the lower bound
    keeps advancing and the result approaches 14.0 rather than raising.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, *,
                          analyte_vol_mL=None, analyte_conc=0.1,
                          naoh_conc=None, hcl_conc=None) -> float:
    """Return the pH (rounded to 2 dp) of a diprotic weak acid after adding titrants.

    Args:
        base_added_mL: cumulative NaOH volume added, mL.
        acid_added_mL: cumulative HCl volume added, mL.
        pKa1, pKa2: successive dissociation constants of the analyte.
        analyte_vol_mL: analyte volume, mL (default: ``INITIAL_ACID_VOL``, read at call time).
        analyte_conc: analyte concentration, mol/L (default 0.1).
        naoh_conc: NaOH concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
        hcl_conc: HCl concentration, mol/L (default: ``TITRANT_CONC``, read at call time).
    """
    acid_vol_mL = INITIAL_ACID_VOL if analyte_vol_mL is None else analyte_vol_mL
    base_conc = TITRANT_CONC if naoh_conc is None else naoh_conc
    acid_added_conc = TITRANT_CONC if hcl_conc is None else hcl_conc
    # Moles of each species, then concentrations in the combined volume (litres).
    n_acid = acid_vol_mL / 1000.0 * analyte_conc
    n_Na = base_added_mL / 1000.0 * base_conc
    n_HCl = acid_added_mL / 1000.0 * acid_added_conc
    V_total = (acid_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2), 2)

##############################################
# pH 计算函数：三元酸
##############################################
def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual for a triprotic weak acid H3A mixed with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - c_A * (a1 + 2*a2 + 3*a3) - [Cl-]``, where
    a1..a3 are the fractions of H2A-, HA2- and A3-. Exponents are clipped to
    +/-100 so ``np.power`` stays finite for extreme pH / pKa combinations.
    """
    h_plus = 10 ** (-pH)
    oh_minus = 1e-14 / h_plus  # Kw / [H+]
    r1 = np.power(10, np.clip(pH - pKa1, -100, 100))                    # [H2A-] / [H3A]
    r2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))         # [HA2-] / [H3A]
    r3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))  # [A3-] / [H3A]
    total = 1 + r1 + r2 + r3
    mean_charge = r1 / total + 2 * (r2 / total) + 3 * (r3 / total)
    return h_plus + c_Na - oh_minus - c_A * mean_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Solve the triprotic-acid charge balance ``f_triprotic(pH, ...) == 0`` for pH.

    Bisection on the bracket [0, 14].  The residual falls as pH rises, so the root
    is unique.

    Args:
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ concentration from added base (mol/L).
        c_HCl: Cl- concentration from added strong acid (mol/L).
        pKa1, pKa2, pKa3: Successive acid dissociation constants.

    Returns:
        The pH where ``|f_triprotic| < 1e-10``, otherwise the bracket midpoint after
        100 halvings.  A root at or below 0 yields 0.0; a root above 14 converges to 14.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    # Root at or below the bracket: without this guard lo would drift up to 14.
    if f_lo <= 0:
        return lo
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        # Cached f(lo): it only changes when lo moves.
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float,
                           *, acid_conc: float = 0.1, initial_acid_vol=None, titrant_conc=None) -> float:
    """Return the pH (rounded to 2 decimals) of a triprotic weak acid after titrant additions.

    Moles of analyte acid, NaOH and HCl are each diluted into the combined volume
    (analyte + base + acid), then the charge balance is solved with ``solve_pH_triprotic``.

    Args:
        base_added_mL: Cumulative NaOH volume added (mL).
        acid_added_mL: Cumulative HCl volume added (mL).
        pKa1, pKa2, pKa3: Acid dissociation constants.
        acid_conc: Analyte acid concentration (mol/L); default 0.1.
        initial_acid_vol: Analyte volume (mL); ``None`` uses module-level ``INITIAL_ACID_VOL``.
        titrant_conc: Concentration of both NaOH and HCl (mol/L); ``None`` uses
            module-level ``TITRANT_CONC``.  Globals are read at call time.
    """
    vol_mL = INITIAL_ACID_VOL if initial_acid_vol is None else initial_acid_vol
    conc_titrant = TITRANT_CONC if titrant_conc is None else titrant_conc
    n_acid = vol_mL / 1000.0 * acid_conc
    n_Na = base_added_mL / 1000.0 * conc_titrant
    n_HCl = acid_added_mL / 1000.0 * conc_titrant
    V_total = (vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3), 2)

##############################################
# 奖励计算函数（修改版）
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag, prev_overshoot_volume, last_action_volume):
    """Compute the shaped reward for one titration step and whether the episode ends.

    Components (weights from ``reward_config``; defaults in parentheses):

    * progress: ``dense_lambda (1.0) * (previous_error - current_error) * (1 + remaining_ratio)``;
    * constant ``step_penalty`` (-0.005);
    * overshoot penalty, sigmoid-scaled by ``overshoot_weight`` (0.2), when pH crosses
      the target and either error exceeds ``overshoot_threshold`` (0.1);
    * wrong-direction penalty ``-wrong_dir_factor (1.0) * current_error`` when the pH
      ends on the far side of the target for the reagent just added;
    * after an overshoot on the previous step: ``-overshoot_volume_penalty (0.1) *
      last_action_volume``, plus ``overshoot_volume_bonus (0.1) * (prev_overshoot_volume
      - last_action_volume)`` when the new volume is smaller.

    The episode ends on success (``current_error < SUCCESS_THRESHOLD``) or when
    ``steps_taken >= max_steps``.  Only success earns ``terminal_bonus`` (3.0),
    doubled if reached in fewer than half of ``max_steps``.  A timeout gets no
    bonus, because paying it would reward failing to reach the target.  Rewards
    without a bonus are clipped to ``[reward_clip_min, reward_clip_max]`` (-4.0, 4.0).

    Returns:
        ``(reward, is_terminal)``.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("dense_lambda", 1.0)
    # Progress term; weighted more heavily early in the episode.
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("step_penalty", -0.005)
    overshoot_weight = reward_config.get("overshoot_weight", 0.2)
    overshoot_threshold = reward_config.get("overshoot_threshold", 0.1)

    # Overshoot: pH moved from one side of the target to the other.
    crossed_target = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    if crossed_target and max(previous_error, current_error) > overshoot_threshold:
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(- (current_error - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("wrong_dir_factor", 1.0)
    reagent_name = reagent.lower()
    wrong_dir_penalty = 0
    if (current_ph > target_ph and 'base' in reagent_name) or (current_ph < target_ph and 'acid' in reagent_name):
        wrong_dir_penalty = -wrong_dir_factor * current_error

    # After an overshoot, penalise the size of the next addition and reward shrinking it.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None:
        volume_penalty = -reward_config.get("overshoot_volume_penalty", 0.1) * last_action_volume
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = reward_config.get("overshoot_volume_bonus", 0.1) * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    succeeded = bool(current_error < SUCCESS_THRESHOLD)
    is_terminal = succeeded or steps_taken >= max_steps

    if succeeded:
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        # Success rewards are left unclipped so the bonus dominates the shaping terms.
        return raw_reward + reward_config.get("terminal_bonus", 3.0) * bonus_factor, is_terminal

    reward_clip_max = reward_config.get("reward_clip_max", 4.0)
    reward_clip_min = reward_config.get("reward_clip_min", -4.0)
    return max(min(raw_reward, reward_clip_max), reward_clip_min), is_terminal

##############################################
# pH 模拟环境：PHSimEnv（修改版）
##############################################
class PHSimEnv:
    """Simulated weak-acid titration environment for RL training.

    Each episode draws an acid type (mono-, di- or triprotic), with pKa values
    from pools pre-generated in ``__init__``, and a random target pH in [2, 11].
    An action is a volume in mL.  The reagent is chosen automatically: strong
    base while the current pH is below the target, strong acid otherwise.  The
    new pH comes from the matching ``calculate_pH_*`` helper; the reward and
    ``done`` flag come from ``calculate_reward``.

    State vector (float32): [current pH, target pH, pH change since the previous
    step, current pH - target pH, last action volume].
    """
    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol  # mL
        self.analyte_conc = analyte_conc          # mol/L (default 0.1 M)
        self.titrant_conc = titrant_conc          # mol/L (default 0.1 M)
        # NOTE(review): analyte_conc / titrant_conc are stored but never passed to the
        # pH helpers.  The di-/triprotic helpers read module-level INITIAL_ACID_VOL /
        # TITRANT_CONC and a hard-coded 0.1 M analyte, while total_volume in step()
        # uses self.initial_acid_vol.  Confirm these stay in sync.
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc  # mol of analyte acid (not used elsewhere in this class)
        self.reward_config = {
            # NOTE(review): calculate_reward's dense term is
            # dense_lambda * (previous_error - current_error), so a negative value
            # penalises steps that move pH toward the target. Confirm this is intended.
            "dense_lambda": -0.03,
            "step_penalty": 0,
            "terminal_bonus": 3.9, 
            "overshoot_weight": 0.2,
            "overshoot_threshold": 0.1,
            "wrong_dir_factor": 1,
            "reward_clip_max": 4.1,
            "reward_clip_min": -4.1,
            "overshoot_volume_penalty": 0.1,
            "overshoot_volume_bonus": 0.1
        }
        # Pre-generate 30 pKa sets per acid type; reset() draws from these pools.
        # The order of draws from `np.random` / `random` determines the pools when
        # the RNGs are seeded, so keep it stable for reproducibility.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            self.diprotic_pKa_list.append((pKa1, pKa2))
        self.triprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            self.triprotic_pKa_list.append((pKa1, pKa2, pKa3))
        self.reset()

    def reset(self):
        """Start a new episode and return the initial state vector.

        Picks an acid type and its pKa values, draws a target pH uniformly from
        [2, 11] (rounded to 2 decimals), clears the added volumes, step count and
        overshoot flags, and computes the starting pH with no titrant added.
        """
        # Randomly choose the acid type and its parameters
        self.acid_type = random.choice(['monoprotic', 'diprotic', 'triprotic'])
        if self.acid_type == 'monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        else:  # triprotic
            self.acid_params = random.choice(self.triprotic_pKa_list)
        self.target_ph = round(random.uniform(2, 11), 2)
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0
        # Reset the overshoot-tracking flags consumed by calculate_reward
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        # Starting pH for the chosen acid (no base or acid added yet)
        if self.acid_type == 'monoprotic':
            self.current_ph = calculate_pH_monoprotic(0.0, 0.0, pKa=self.acid_params)
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(0.0, 0.0, pKa1, pKa2)
        else:  # triprotic
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(0.0, 0.0, pKa1, pKa2, pKa3)
        # At the start, the previous pH equals the current pH (so pH_delta is 0)
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Build the 5-element float32 state; pH_delta and error are rounded to 2 decimals."""
        pH_delta = round(self.current_ph - self.previous_ph, 2)
        error = round(self.current_ph - self.target_ph, 2)
        # State vector: current pH, target pH, pH change, error, last action volume
        return np.array([self.current_ph, self.target_ph, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def step(self, action):
        """Add ``action`` mL of titrant and advance the simulation by one step.

        The reagent is chosen from the pH *before* the addition: strong base if it
        is below the target, strong acid otherwise.

        Returns:
            ``(state, reward, done, info)`` with ``info = {'reagent': reagent}``;
            ``reward`` and ``done`` come from ``calculate_reward``.
        """
        volume = float(action)
        self.last_action_volume = volume
        self.steps += 1
        # Pick base or acid depending on which side of the target the pH is on
        if self.current_ph < self.target_ph:
            reagent = "strong_base"
            self.base_added_mL += volume
        else:
            reagent = "strong_acid"
            self.acid_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base_added_mL + self.acid_added_mL

        # Remember the pre-addition pH
        self.previous_ph = self.current_ph

        # Recompute pH with the helper matching the acid type (cumulative volumes)
        if self.acid_type == 'monoprotic':
            self.current_ph = calculate_pH_monoprotic(self.base_added_mL, self.acid_added_mL, self.acid_params)
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2)
        else:
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2, pKa3)

        state = self._get_state()
        # Reward and termination; the overshoot flags passed in describe the PREVIOUS step
        reward, done = calculate_reward(
            previous_ph=self.previous_ph,
            current_ph=self.current_ph,
            target_ph=self.target_ph,
            steps_taken=self.steps,
            max_steps=MAX_STEPS,
            reagent=reagent,
            reward_config=self.reward_config,
            SUCCESS_THRESHOLD=SUCCESS_THRESHOLD,
            prev_overshoot_flag=self.prev_overshoot_flag,
            prev_overshoot_volume=self.prev_overshoot_volume,
            last_action_volume=self.last_action_volume
        )
        
        # Did this step overshoot, i.e. did pH cross from one side of the target to
        # the other?  Unlike calculate_reward's overshoot penalty, no threshold applies here.
        current_overshoot = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
        if current_overshoot:
            self.prev_overshoot_flag = True
            self.prev_overshoot_volume = self.last_action_volume
        else:
            self.prev_overshoot_flag = False
            self.prev_overshoot_volume = None
        
        return state, reward, done, {'reagent': reagent}

##############################################
# 离散动作策略模型：DiscreteVolumeRegressor
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a fixed grid of addition volumes (mL).

    An MLP (input_dim -> 256 -> 256 -> num_actions, ReLU activations) scores every
    volume in ``discrete_volumes``: ``min_volume, min_volume + step, ...`` up to
    ``max_volume`` (inclusive when it lies on the grid), each rounded to 2 decimals.
    The defaults give 1000 actions, 0.01 ... 10.00 mL.
    """
    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The 1e-9 absorbs floating-point error in the division.  For example,
        # (0.3 - 0.1) / 0.1 evaluates to just under 2.0; plain int() truncated it
        # to 1 and silently dropped max_volume from the grid.
        num_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2)
                                 for i in range(num_steps + 1)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )
        
    def forward(self, x):
        """Return one logit per action (last dimension of size ``num_actions``)."""
        return self.net(x)
    
    def sample_action(self, x):
        """Sample a volume from the categorical policy for a single state.

        ``x`` must hold exactly one state, because ``.item()`` is taken on the
        sampled index.

        Returns:
            ``(volume, log_prob)``: a (1, 1) float32 tensor holding the volume in mL,
            and the log-probability of the sampled action.
        """
        logits = self.forward(x)
        if torch.isnan(logits).any():
            print("Logits contain NaN:", logits)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        volume = self.discrete_volumes[action_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32), log_prob
    
    def predict_volume(self, x):
        """Return the highest-scoring volume as a (1, 1) float32 tensor (greedy, single state)."""
        logits = self.forward(x)
        _, predicted_index = torch.max(logits, dim=1)
        volume = self.discrete_volumes[predicted_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32)

##############################################
# 在线训练：使用 REINFORCE 算法更新策略模型
##############################################
def train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=0.99):
    """Train ``policy_model`` on ``env`` with REINFORCE (Monte-Carlo policy gradient).

    For each episode:

    1. Roll out with sampled actions until ``env.step`` reports done.
    2. Compute discounted returns.  When the episode has more than one step they
       are normalised to zero mean and unit std; otherwise only mean-centred.
    3. Take one optimiser step on ``sum(-log_prob * G)``, with the gradient norm
       clipped to 1.0.

    Checkpointing: an episode is a success when its final pH is within
    ``SUCCESS_THRESHOLD`` of the target.  The weights are saved to disk whenever
    the running success count rises, or stays equal with a lower mean step count
    over successful episodes.  The count is cumulative, so in practice this saves
    after every successful episode.

    Args:
        env: Environment with ``reset()``, ``step(volume)`` and attributes
            ``current_ph``, ``target_ph`` and ``steps``.
        policy_model: Policy exposing ``sample_action(state_tensor) -> (action, log_prob)``.
        optimizer: Optimiser over ``policy_model``'s parameters.
        num_episodes: Number of training episodes.
        gamma: Discount factor.

    Returns:
        A snapshot (cloned tensors) of the best ``state_dict`` recorded, or ``None``
        if no episode succeeded.
    """
    best_success_count = 0         # success count at the last checkpoint
    best_avg_steps = float('inf')  # mean steps per success at the last checkpoint
    best_model_state = None
    success_count = 0              # number of successful episodes so far
    successful_steps = 0           # total steps taken by those successful episodes

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        log_probs = []
        rewards = []
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action, log_prob = policy_model.sample_action(state_tensor)
            next_state, reward, done, _ = env.step(action.item())
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state

        # Discounted returns, built backwards: G_t = r_t + gamma * G_{t+1}.
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.append(R)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        else:
            returns = returns - returns.mean()

        loss = torch.stack([-log_prob * G for log_prob, G in zip(log_probs, returns)]).sum()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

        # Ending inside the tolerance is a success, even when it happens on the
        # last allowed step.
        if abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD:
            success_count += 1
            successful_steps += env.steps

        current_avg_steps = successful_steps / success_count if success_count > 0 else float('inf')

        # Checkpoint condition: more successes, or equal successes with fewer mean steps.
        if success_count > best_success_count or \
           (success_count == best_success_count and current_avg_steps < best_avg_steps):
            best_success_count = success_count
            best_avg_steps = current_avg_steps
            # state_dict() holds references to the live parameters; clone them so the
            # snapshot is not overwritten by later optimiser steps.
            best_model_state = {k: v.detach().clone() for k, v in policy_model.state_dict().items()}
            torch.save(best_model_state, "volume_regressor_best_big_discrete_new1_trained-1-test-保存最优.pth")
            print(f"Episode {episode}, Loss: {loss.item():.4f}, Updated Best Model with Success Count: {best_success_count}, Avg Steps: {best_avg_steps:.2f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")
        elif episode % 50 == 0:  # otherwise, log progress every 50 episodes
            total_reward = sum(rewards)
            print(f"Episode {episode}, Loss: {loss.item():.4f}, Total Reward: {total_reward:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")

    return best_model_state

##############################################
# 测试函数：运行实验并打印每一步详情
##############################################
def test_model(policy_model, env, num_experiments=10):
    """Run ``num_experiments`` roll-outs of ``policy_model`` in ``env`` and print each one.

    Actions are *sampled* from the policy (not arg-maxed) under ``torch.no_grad()``.
    The per-step trace (state after the step, volume, reagent) is printed once the
    experiment finishes.
    """
    for exp_no in range(1, num_experiments + 1):
        print(f"\n==== 实验 {exp_no} 开始 ====")
        obs = env.reset()
        print(f"初始状态: {obs}")
        print(f"酸类型: {env.acid_type}, 参数: {env.acid_params}, 目标 pH: {env.target_ph}")

        trace = []
        finished = False
        while not finished:
            obs_tensor = torch.tensor(obs, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                sampled, _ = policy_model.sample_action(obs_tensor)
            volume = sampled.item()
            obs, _reward, finished, info = env.step(volume)
            trace.append((obs, volume, info.get('reagent', '')))

        for step_no, (s, a, reagent) in enumerate(trace, start=1):
            print(f"  Step {step_no}: State = {s}, Action = {a:.4f}, Reagent = {reagent}")
        print(f"实验结束，共用步数: {len(trace)}")

##############################################
# 主程序：加载预训练模型（如果可用）并训练、测试
##############################################
if __name__ == "__main__":
    # Entry point for this training cell:
    #   1. build the simulated environment and the discrete policy;
    #   2. warm-start from a checkpoint if one loads;
    #   3. train with REINFORCE;
    #   4. run sampled test roll-outs.
    # Object construction consumes RNG state, so keep this order stable for
    # reproducibility under fixed seeds.
    input_dim = 5          # length of PHSimEnv's state vector
    learning_rate = 1e-4
    gamma = 0.99           # discount factor for REINFORCE returns

    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=TITRANT_CONC)
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
    
    # NOTE(review): this warm-start path differs from the checkpoint path that
    # train_reinforce writes to; confirm which checkpoint is intended.
    # torch.load unpickles the file, so only load trusted checkpoints.
    pretrained_path = "volume_regressor_best_big_discrete_new1-test.pth"
    try:
        state_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        policy_model.load_state_dict(state_dict)
        print("加载预训练模型成功。")
    except Exception as e:
        # Best effort: fall back to random initialisation if the file is missing or incompatible.
        print("未能加载预训练模型，使用随机初始化模型。", e)
    
    optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
    train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=gamma)
    print("训练完成，最优模型已保存。")
    
    # NOTE: test_model runs with the weights left after the final training
    # episode, not the best checkpoint saved during training.
    test_model(policy_model, env, num_experiments=10)

In [None]:
# 强化学习做实验

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
from scipy.optimize import fsolve

# Fix the RNG seeds (Python `random`, NumPy, PyTorch) so that pKa/target sampling
# and network initialisation are reproducible from run to run.
seed = 555
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

##############################################
# Global constants
##############################################
TITRANT_CONC = 0.1          # Titrant concentration (0.1 M); the pH helpers below use it for both NaOH and HCl
MAX_STEPS = 50              # Maximum number of steps per episode
INITIAL_ACID_VOL = 11.0     # Initial volume of the weak-acid analyte being titrated (mL)
SUCCESS_THRESHOLD = 0.1     # pH error tolerance: |pH - target| below this counts as success

##############################################
# pH 计算函数（单、双、三元酸） —— 保持与训练时一致
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual for a monoprotic weak acid HA titrated with NaOH and HCl.

    Returns ``[H+] + [Na+] - [OH-] - [A-] - [Cl-]`` at the trial ``pH``; the
    equilibrium pH is where this is zero.

    Args:
        pH: Trial pH.
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ from added base (mol/L).
        c_HCl: Cl- from added strong acid (mol/L).
        pKa: Acid dissociation constant.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    ratio = 10 ** (pH - pKa)                 # [A-] / [HA]
    deprotonated = ratio / (1 + ratio)       # fraction of the acid present as A-
    return hydronium + c_Na - hydroxide - c_A * deprotonated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Solve the monoprotic-acid charge balance ``f_monoprotic(pH, ...) == 0`` for pH.

    Bisection on the bracket [0, 14].  The residual falls as pH rises, so the root
    is unique.

    Args:
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ concentration from added base (mol/L).
        c_HCl: Cl- concentration from added strong acid (mol/L).
        pKa: Acid dissociation constant.

    Returns:
        The pH where ``|f_monoprotic| < 1e-10``, otherwise the bracket midpoint after
        100 halvings.  A root at or below 0 yields 0.0; a root above 14 converges to 14.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    # Root at or below the bracket: without this guard lo would drift up to 14.
    if f_lo <= 0:
        return lo
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        # Cached f(lo): it only changes when lo moves.
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float,
                            *, acid_conc: float = 0.1, initial_acid_vol=None, titrant_conc=None) -> float:
    """Return the pH (rounded to 2 decimals) of a monoprotic weak acid after titrant additions.

    Moles of analyte acid, NaOH and HCl are each diluted into the combined volume
    (analyte + base + acid), then the charge balance is solved with
    ``solve_pH_monoprotic_balance``.

    Args:
        base_added_mL: Cumulative NaOH volume added (mL).
        acid_added_mL: Cumulative HCl volume added (mL).
        pKa: Acid dissociation constant.
        acid_conc: Analyte acid concentration (mol/L); default 0.1.
        initial_acid_vol: Analyte volume (mL); ``None`` uses module-level ``INITIAL_ACID_VOL``.
        titrant_conc: Concentration of both NaOH and HCl (mol/L); ``None`` uses
            module-level ``TITRANT_CONC``.  Globals are read at call time.
    """
    vol_mL = INITIAL_ACID_VOL if initial_acid_vol is None else initial_acid_vol
    conc_titrant = TITRANT_CONC if titrant_conc is None else titrant_conc
    n_acid = vol_mL / 1000.0 * acid_conc
    n_Na = base_added_mL / 1000.0 * conc_titrant
    n_HCl = acid_added_mL / 1000.0 * conc_titrant
    V_total = (vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa), 2)

def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual for a diprotic acid H2A titrated with NaOH and HCl.

    Evaluates ``[H+] + [Na+] - [OH-] - ([HA-] + 2[A2-]) - [Cl-]`` at the trial pH;
    the equilibrium pH is the root of this function.

    Args:
        pH: Trial pH.
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ from added base (mol/L).
        c_HCl: Cl- from added strong acid (mol/L).
        pKa1, pKa2: Successive dissociation constants.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    # log10 of [HA-]/[H2A] and [A2-]/[H2A], clipped to [-100, 100] so that
    # np.power never overflows.
    exponents = (pH - pKa1, 2 * pH - pKa1 - pKa2)
    t1, t2 = (np.power(10, np.clip(e, -100, 100)) for e in exponents)
    denom = 1 + t1 + t2
    frac_HA, frac_A = t1 / denom, t2 / denom
    anion_charge = c_A * (frac_HA + 2 * frac_A)
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Solve the diprotic-acid charge balance ``f_diprotic(pH, ...) == 0`` for pH.

    Bisection on the bracket [0, 14].  The residual falls as pH rises ([H+] drops
    while [OH-] and the anion charge grow), so the root is unique.

    Args:
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ concentration from added base (mol/L).
        c_HCl: Cl- concentration from added strong acid (mol/L).
        pKa1, pKa2: Successive acid dissociation constants.

    Returns:
        The pH where ``|f_diprotic| < 1e-10``, otherwise the bracket midpoint after
        100 halvings.  A root at or below 0 yields 0.0; a root above 14 converges to 14.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    # Root at or below the bracket: without this guard lo would drift up to 14.
    if f_lo <= 0:
        return lo
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        # Cached f(lo): it only changes when lo moves.
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float,
                          *, acid_conc: float = 0.1, initial_acid_vol=None, titrant_conc=None) -> float:
    """Return the pH (rounded to 2 decimals) of a diprotic weak acid after titrant additions.

    Moles of analyte acid, NaOH and HCl are each diluted into the combined volume
    (analyte + base + acid), then the charge balance is solved with ``solve_pH_diprotic``.

    Args:
        base_added_mL: Cumulative NaOH volume added (mL).
        acid_added_mL: Cumulative HCl volume added (mL).
        pKa1, pKa2: Acid dissociation constants.
        acid_conc: Analyte acid concentration (mol/L); default 0.1.
        initial_acid_vol: Analyte volume (mL); ``None`` uses module-level ``INITIAL_ACID_VOL``.
        titrant_conc: Concentration of both NaOH and HCl (mol/L); ``None`` uses
            module-level ``TITRANT_CONC``.  Globals are read at call time.
    """
    vol_mL = INITIAL_ACID_VOL if initial_acid_vol is None else initial_acid_vol
    conc_titrant = TITRANT_CONC if titrant_conc is None else titrant_conc
    n_acid = vol_mL / 1000.0 * acid_conc
    n_Na = base_added_mL / 1000.0 * conc_titrant
    n_HCl = acid_added_mL / 1000.0 * conc_titrant
    V_total = (vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2), 2)

def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual for a triprotic acid H3A titrated with NaOH and HCl.

    Evaluates ``[H+] + [Na+] - [OH-] - ([H2A-] + 2[HA2-] + 3[A3-]) - [Cl-]`` at the
    trial pH; the equilibrium pH is the root of this function.

    Args:
        pH: Trial pH.
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ from added base (mol/L).
        c_HCl: Cl- from added strong acid (mol/L).
        pKa1, pKa2, pKa3: Successive dissociation constants.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    # log10 of each deprotonated species relative to H3A, clipped to [-100, 100]
    # so that np.power never overflows.
    exponents = (
        pH - pKa1,
        2 * pH - pKa1 - pKa2,
        3 * pH - pKa1 - pKa2 - pKa3,
    )
    t1, t2, t3 = (np.power(10, np.clip(e, -100, 100)) for e in exponents)
    denom = 1 + t1 + t2 + t3
    a1, a2, a3 = t1 / denom, t2 / denom, t3 / denom  # species fractions
    anion_charge = c_A * (a1 + 2 * a2 + 3 * a3)
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Solve the triprotic-acid charge balance ``f_triprotic(pH, ...) == 0`` for pH.

    Bisection on the bracket [0, 14].  The residual falls as pH rises, so the root
    is unique.

    Args:
        c_A: Total acid concentration (mol/L).
        c_Na: Na+ concentration from added base (mol/L).
        c_HCl: Cl- concentration from added strong acid (mol/L).
        pKa1, pKa2, pKa3: Successive acid dissociation constants.

    Returns:
        The pH where ``|f_triprotic| < 1e-10``, otherwise the bracket midpoint after
        100 halvings.  A root at or below 0 yields 0.0; a root above 14 converges to 14.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    # Root at or below the bracket: without this guard lo would drift up to 14.
    if f_lo <= 0:
        return lo
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        # Cached f(lo): it only changes when lo moves.
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float,
                           *, acid_conc: float = 0.1, initial_acid_vol=None, titrant_conc=None) -> float:
    """Return the pH (rounded to 2 decimals) of a triprotic weak acid after titrant additions.

    Moles of analyte acid, NaOH and HCl are each diluted into the combined volume
    (analyte + base + acid), then the charge balance is solved with ``solve_pH_triprotic``.

    Args:
        base_added_mL: Cumulative NaOH volume added (mL).
        acid_added_mL: Cumulative HCl volume added (mL).
        pKa1, pKa2, pKa3: Acid dissociation constants.
        acid_conc: Analyte acid concentration (mol/L); default 0.1.
        initial_acid_vol: Analyte volume (mL); ``None`` uses module-level ``INITIAL_ACID_VOL``.
        titrant_conc: Concentration of both NaOH and HCl (mol/L); ``None`` uses
            module-level ``TITRANT_CONC``.  Globals are read at call time.
    """
    vol_mL = INITIAL_ACID_VOL if initial_acid_vol is None else initial_acid_vol
    conc_titrant = TITRANT_CONC if titrant_conc is None else titrant_conc
    n_acid = vol_mL / 1000.0 * acid_conc
    n_Na = base_added_mL / 1000.0 * conc_titrant
    n_HCl = acid_added_mL / 1000.0 * conc_titrant
    V_total = (vol_mL + base_added_mL + acid_added_mL) / 1000.0
    c_A = n_acid / V_total
    c_Na = n_Na / V_total
    c_HCl = n_HCl / V_total
    return round(solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3), 2)

##############################################
# 环境与奖励函数
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD,
                     prev_overshoot_flag=None, prev_overshoot_volume=None, last_action_volume=None):
    """Compute the shaped reward for one titration step and whether the episode ends.

    Components (weights from ``reward_config``; defaults in parentheses):

    * progress: ``dense_lambda (1.0) * (previous_error - current_error) * (1 + remaining_ratio)``;
    * constant ``step_penalty`` (-0.005);
    * overshoot penalty, sigmoid-scaled by ``overshoot_weight`` (0.2), when pH crosses
      the target and either error exceeds ``overshoot_threshold`` (0.1);
    * wrong-direction penalty ``-wrong_dir_factor (1.0) * current_error`` when the pH
      ends on the far side of the target for the reagent just added;
    * after an overshoot on the previous step (only when all three overshoot
      arguments are given): ``-overshoot_volume_penalty (0.1) * last_action_volume``,
      plus ``overshoot_volume_bonus (0.1) * (prev_overshoot_volume - last_action_volume)``
      when the new volume is smaller.

    The episode ends on success (``current_error < SUCCESS_THRESHOLD``) or when
    ``steps_taken >= max_steps``.  Only success earns ``terminal_bonus`` (3.0),
    doubled if reached in fewer than half of ``max_steps``.  A timeout gets no
    bonus, because paying it would reward failing to reach the target.  Rewards
    without a bonus are clipped to ``[reward_clip_min, reward_clip_max]`` (-4.0, 4.0).

    Returns:
        ``(reward, is_terminal)``.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("dense_lambda", 1.0)
    # Progress term; weighted more heavily early in the episode.
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("step_penalty", -0.005)
    overshoot_weight = reward_config.get("overshoot_weight", 0.2)
    overshoot_threshold = reward_config.get("overshoot_threshold", 0.1)

    # Overshoot: pH moved from one side of the target to the other.
    crossed_target = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    if crossed_target and max(previous_error, current_error) > overshoot_threshold:
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(- (current_error - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("wrong_dir_factor", 1.0)
    reagent_name = reagent.lower()
    wrong_dir_penalty = 0
    if (current_ph > target_ph and 'base' in reagent_name) or (current_ph < target_ph and 'acid' in reagent_name):
        wrong_dir_penalty = -wrong_dir_factor * current_error

    # After an overshoot, penalise the size of the next addition and reward shrinking it.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None and last_action_volume is not None:
        volume_penalty = -reward_config.get("overshoot_volume_penalty", 0.1) * last_action_volume
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = reward_config.get("overshoot_volume_bonus", 0.1) * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    succeeded = bool(current_error < SUCCESS_THRESHOLD)
    is_terminal = succeeded or steps_taken >= max_steps

    if succeeded:
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        # Success rewards are left unclipped so the bonus dominates the shaping terms.
        return raw_reward + reward_config.get("terminal_bonus", 3.0) * bonus_factor, is_terminal

    reward_clip_max = reward_config.get("reward_clip_max", 4.0)
    reward_clip_min = reward_config.get("reward_clip_min", -4.0)
    return max(min(raw_reward, reward_clip_max), reward_clip_min), is_terminal

class PHSimEnv:
    """Titration environment used here mainly for state bookkeeping.

    It samples episodes the same way as the training environment: acid type,
    pKa pools, and a random target pH.  ``step`` is a stub, though.  The
    interactive routine writes ``current_ph``, ``previous_ph``, the added volumes
    and ``steps`` directly from user-measured values, and reads the state vector
    via ``_get_state``.
    """
    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol  # mL
        self.analyte_conc = analyte_conc          # mol/L
        self.titrant_conc = titrant_conc          # mol/L
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc  # mol of analyte acid
        # step() below is a stub, so this class itself never reads reward_config.
        self.reward_config = {
            "dense_lambda": 1.0,
            "step_penalty": -0.005,
            "terminal_bonus": 80,
            "overshoot_weight": 0.2,
            "overshoot_threshold": 0.1,
            "wrong_dir_factor": 1.0,
            "reward_clip_max": 2.0,
            "reward_clip_min": -2.0
        }
        # Pre-generate 30 pKa sets per acid type.  The order of draws from
        # `np.random` / `random` fixes the pools under seeded RNGs; keep it stable.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            self.diprotic_pKa_list.append((pKa1, pKa2))
        self.triprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            self.triprotic_pKa_list.append((pKa1, pKa2, pKa3))
        self.reset()

    def reset(self):
        """Start a new episode and return the initial 5-element state vector.

        Picks an acid type and pKa values, draws a target pH uniformly from [2, 11]
        (rounded to 2 decimals), clears volumes, steps and overshoot flags, and
        computes the starting pH with no titrant added.
        """
        self.acid_type = random.choice(['monoprotic', 'diprotic', 'triprotic'])
        if self.acid_type == 'monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        elif self.acid_type == 'triprotic':
            self.acid_params = random.choice(self.triprotic_pKa_list)
        # The target pH is drawn at random here; interactive use overwrites it with user input afterwards.
        self.target_ph = round(random.uniform(2, 11), 2)
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        if self.acid_type == 'monoprotic':
            self.current_ph = calculate_pH_monoprotic(0.0, 0.0, pKa=self.acid_params)
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(0.0, 0.0, pKa1, pKa2)
        elif self.acid_type == 'triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(0.0, 0.0, pKa1, pKa2, pKa3)
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return [current pH, target pH, pH change, pH error, last volume] as float32.

        pH change and error are rounded to 2 decimals.
        """
        pH_delta = round(self.current_ph - self.previous_ph, 2)
        error = round(self.current_ph - self.target_ph, 2)
        return np.array([self.current_ph, self.target_ph, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def step(self, action):
        """No-op stub; returns None.

        In automatic simulation this method would update the state.  In manual
        experiments the state is instead updated from user-entered pH, so this
        method is not called directly.
        """
        pass

##############################################
# 离散动作策略模型：动作空间 [0.01, 10] mL，步长 0.01 mL，共 1000 个离散动作
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a fixed grid of addition volumes (mL).

    An MLP (input_dim -> 256 -> 256 -> num_actions, ReLU activations) scores every
    volume in ``discrete_volumes``: ``min_volume, min_volume + step, ...`` up to
    ``max_volume`` (inclusive when it lies on the grid), each rounded to 2 decimals.
    The defaults give 1000 actions, 0.01 ... 10.00 mL.
    """
    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The 1e-9 absorbs floating-point error in the division.  For example,
        # (0.3 - 0.1) / 0.1 evaluates to just under 2.0; plain int() truncated it
        # to 1 and silently dropped max_volume from the grid.
        num_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(num_steps + 1)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )
    
    def forward(self, x):
        """Return one logit per action (last dimension of size ``num_actions``)."""
        return self.net(x)
    
    def sample_action(self, x):
        """Sample a volume from the categorical policy for a single state.

        ``x`` must hold exactly one state, because ``.item()`` is taken on the
        sampled index.

        Returns:
            ``(volume, log_prob)``: a (1, 1) float32 tensor holding the volume in mL,
            and the log-probability of the sampled action.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        volume = self.discrete_volumes[action_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32), log_prob
    
    def predict_volume(self, x):
        """Return the highest-scoring volume as a (1, 1) float32 tensor (greedy, single state)."""
        logits = self.forward(x)
        _, predicted_indices = torch.max(logits, dim=1)
        predicted_volume = self.discrete_volumes[predicted_indices.item()]
        return torch.tensor([[predicted_volume]], dtype=torch.float32)

##############################################
# 交互式手动滴定实验
# 说明：先输入初始 pH 与目标 pH；随后模型给出建议动作，
#        你在实验室进行操作后输入测得的 pH，状态更新后继续给出建议。
##############################################
def _parse_finite_float(text, min_value=None):
    """Parse *text* as a finite float, optionally requiring ``value >= min_value``.

    Raises:
        ValueError: if *text* is not a number, is NaN/inf, or is below ``min_value``.
    """
    value = float(text)
    if not math.isfinite(value):
        raise ValueError(f"non-finite value: {text!r}")
    if min_value is not None and value < min_value:
        raise ValueError(f"value below {min_value}: {text!r}")
    return value


def interactive_titration_manual(env, policy_model):
    """Run a bench titration interactively, with the policy suggesting each addition.

    The user first enters the starting and target pH.  Each round then goes:

    1. The model suggests a volume, sampled from the policy.
    2. The user accepts it or types a custom volume.
    3. The user performs the addition in the lab and enters the measured pH.

    The environment's bookkeeping (pH history, volumes added, step count) is
    updated from those inputs; ``env.step`` is not used.  The run stops once
    ``|pH - target| < SUCCESS_THRESHOLD`` or after ``MAX_STEPS`` additions.

    Non-finite entries ("nan", "inf") are rejected like malformed numbers, and a
    negative custom volume falls back to the suggestion.  This keeps them from
    corrupting the state vector fed to the model or the accumulated volumes.
    """
    # Read the initial and target pH
    try:
        init_ph = _parse_finite_float(input("请输入初始 pH 值: "))
        target_ph = _parse_finite_float(input("请输入目标 pH 值: "))
    except ValueError:
        print("输入格式错误，使用环境默认值。")
        init_ph = env.current_ph
        target_ph = env.target_ph

    # Start a fresh episode, then overwrite its simulated pH values with the user's
    env.reset()
    env.current_ph = init_ph
    env.previous_ph = init_ph
    env.target_ph = target_ph
    print(f"\n初始 pH: {env.current_ph:.2f}，目标 pH: {env.target_ph:.2f}\n")

    done = False
    while not done:
        print(f"当前 pH: {env.current_ph:.2f}")
        # Rebuild the state vector from the latest measured pH
        state = env._get_state()
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        # Model's suggested addition volume for the current state
        with torch.no_grad():
            recommended_action, _ = policy_model.sample_action(state_tensor)
            recommended_volume = recommended_action.item()

        # Same reagent rule as the simulator: base below target, acid otherwise
        if env.current_ph < env.target_ph:
            recommended_reagent = "strong_base"
        else:
            recommended_reagent = "strong_acid"
        print(f"建议: 加入 {recommended_volume:.2f} mL {recommended_reagent}")

        # Let the user accept the suggestion or enter a custom volume
        user_choice = input("是否采用建议体积？(直接回车采用，n 输入自定义体积): ")
        if user_choice.strip().lower() == "n":
            try:
                action = _parse_finite_float(input("请输入实际加入体积 (mL): "), min_value=0.0)
            except ValueError:
                print("输入错误，采用建议值。")
                action = recommended_volume
        else:
            action = recommended_volume

        # After the lab operation, ask for the measured pH until a valid number is given
        measured_ph = None
        while measured_ph is None:
            try:
                measured_ph = _parse_finite_float(input("请输入操作后测得的 pH 值: "))
            except ValueError:
                print("输入格式错误，请输入数字。")

        # Shift pH history and record the action
        env.previous_ph = env.current_ph
        env.current_ph = measured_ph
        env.last_action_volume = action
        env.steps += 1

        # Book the addition against the reagent implied by the pre-addition pH
        if env.previous_ph < env.target_ph:
            env.base_added_mL += action
            reagent_used = "strong_base"
        else:
            env.acid_added_mL += action
            reagent_used = "strong_acid"
        env.total_volume = env.initial_acid_vol + env.base_added_mL + env.acid_added_mL

        print(f"操作: 加入 {action:.2f} mL {reagent_used}，测得 pH: {env.current_ph:.2f}\n")

        # Termination: within tolerance, or out of steps
        if abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD:
            print("成功达到目标 pH！")
            done = True
        elif env.steps >= MAX_STEPS:
            print("达到最大步数，实验结束。")
            done = True

    print(f"实验结束，共用了 {env.steps} 步, 最终 pH: {env.current_ph:.2f}")

##############################################
# 主程序：加载预训练模型（如果存在）并进入交互模式
##############################################
if __name__ == "__main__":
    # Entry point for the interactive cell: build the environment and policy,
    # load trained weights if possible, then hand control to the user.
    input_dim = 5          # length of PHSimEnv's state vector
    learning_rate = 1e-3   # not used below (no training in this cell)
    gamma = 0.99           # not used below (no training in this cell)

    # Initialise the environment
    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=0.1)
    
    # Initialise the discrete-action policy model
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
    
    # Try to load pretrained weights (keep the file name consistent with the training run).
    # NOTE(review): the training cell saves its best checkpoint under a different
    # file name; confirm which file is meant.  torch.load unpickles the file, so
    # only load trusted checkpoints.
    try:
        policy_model.load_state_dict(torch.load("volume_regressor_best_big_discrete_new1_trained-1-test.pth", map_location=torch.device('cpu')))
        print("加载离散预训练模型成功。\n")
    except Exception as e:
        # Best effort: continue with a randomly initialised model
        print("未能加载离散预训练模型，使用随机初始化模型。\n", e)
    
    # The network is Linear/ReLU only, so eval() does not change its outputs.
    # Suggestions are still sampled stochastically by sample_action.
    policy_model.eval()
    
    # Enter interactive manual titration mode
    interactive_titration_manual(env, policy_model)


In [None]:
# 网络做同样的实验

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
import logging
from scipy.optimize import fsolve
import csv
import ast

# Logging: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Seed every RNG this cell relies on so runs are reproducible.
seed = 555
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Global constants
TITRANT_CONC1 = 0.1         # primary titrant concentration (M)
TITRANT_CONC2 = 0.01        # secondary (dilute) titrant concentration (M)
MAX_STEPS = 50              # maximum steps per episode
INITIAL_ACID_VOL = 11.0     # initial volume of the weak-acid analyte (mL)
SUCCESS_THRESHOLD = 0.1     # |pH - target| below this counts as success
MIN_ADDITION_VOLUME = 0.01  # smallest addition volume (mL)

# Reagent name -> concentration (M). "_1" reagents use the primary titrant,
# "_2" reagents the secondary one; insertion order is base_1, base_2, acid_1, acid_2.
REAGENTS = {
    f'strong_{kind}_{tier}': conc
    for kind in ('base', 'acid')
    for tier, conc in ((1, TITRANT_CONC1), (2, TITRANT_CONC2))
}

# pH 计算函数
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual for a monoprotic acid HA titrated with NaOH/HCl.

    Returns ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` where ``alpha`` is the
    deprotonated fraction ``10**(pH-pKa) / (1 + 10**(pH-pKa))``. The residual is
    zero at the equilibrium pH.

    Args:
        pH: trial pH.
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa: acid dissociation constant of the analyte.
    """
    H = 10 ** (-pH)
    Kw = 1e-14
    OH = Kw / H
    # Clip the exponent like f_diprotic/f_triprotic do: an unclipped
    # 10 ** (pH - pKa) raises OverflowError once the exponent passes ~308,
    # while 10**100 already yields alpha == 1.0.
    term = np.power(10, np.clip(pH - pKa, -100, 100))
    alpha = term / (1 + term)
    return H + c_Na - OH - c_A * alpha - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Solve ``f_monoprotic(pH, ...) == 0`` for pH by bisection on [0, 14].

    Args:
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa: acid dissociation constant of the analyte.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, otherwise the
        midpoint of the final interval after 100 halvings. If the residual does
        not change sign on [0, 14], ``lo`` keeps advancing and the result
        converges toward 14.0.
    """
    lo, hi = 0.0, 14.0
    # The residual at ``lo`` only changes when ``lo`` moves, so cache it instead
    # of re-evaluating it every iteration (same result, half the evaluations).
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa: float, *, acid_vol_mL: "float | None" = None, acid_conc: float = 0.1) -> float:
    """Equilibrium pH of the monoprotic analyte after the given titrant additions.

    Args:
        base1_mL, base2_mL: cumulative mL of strong base at TITRANT_CONC1 / TITRANT_CONC2.
        acid1_mL, acid2_mL: cumulative mL of strong acid at TITRANT_CONC1 / TITRANT_CONC2.
        pKa: acid dissociation constant of the analyte.
        acid_vol_mL: initial analyte volume (mL); defaults to INITIAL_ACID_VOL,
            read at call time (previously always hard-coded).
        acid_conc: analyte concentration (mol/L); defaults to the previously
            hard-coded 0.1.

    Returns:
        pH rounded to 2 decimals.
    """
    if acid_vol_mL is None:
        acid_vol_mL = INITIAL_ACID_VOL
    n_acid = acid_vol_mL / 1000.0 * acid_conc
    V_total = (acid_vol_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    # All species are diluted by the total mixed volume (litres).
    c_A = n_acid / V_total
    c_Na = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / V_total
    c_HCl = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / V_total
    return round(solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa), 2)

def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual for a diprotic acid H2A titrated with NaOH/HCl.

    Returns ``[H+] + c_Na - [OH-] - c_A*(a1 + 2*a2) - c_HCl``, where a1/a2 are
    the singly/doubly deprotonated fractions. Zero at the equilibrium pH.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    # Ratio terms 10**(pH-pKa1) and 10**(2pH-pKa1-pKa2); exponents are clipped
    # to [-100, 100] so np.power never overflows.
    first = np.power(10, np.clip(pH - pKa1, -100, 100))
    second = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    denom = 1 + first + second
    charge = c_A * (first / denom + 2 * (second / denom))
    return h_conc + c_Na - oh_conc - charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Solve ``f_diprotic(pH, ...) == 0`` for pH by bisection on [0, 14].

    Args:
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa1, pKa2: successive dissociation constants of the analyte.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, otherwise the
        midpoint of the final interval after 100 halvings. If the residual does
        not change sign on [0, 14] the result converges toward 14.0.
    """
    lo, hi = 0.0, 14.0
    # Cache the residual at ``lo``; it only changes when ``lo`` moves.
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float, *, acid_vol_mL: "float | None" = None, acid_conc: float = 0.1) -> float:
    """Equilibrium pH of the diprotic analyte after the given titrant additions.

    Args:
        base1_mL, base2_mL: cumulative mL of strong base at TITRANT_CONC1 / TITRANT_CONC2.
        acid1_mL, acid2_mL: cumulative mL of strong acid at TITRANT_CONC1 / TITRANT_CONC2.
        pKa1, pKa2: successive dissociation constants of the analyte.
        acid_vol_mL: initial analyte volume (mL); defaults to INITIAL_ACID_VOL,
            read at call time (previously always hard-coded).
        acid_conc: analyte concentration (mol/L); defaults to the previously
            hard-coded 0.1.

    Returns:
        pH rounded to 2 decimals.
    """
    if acid_vol_mL is None:
        acid_vol_mL = INITIAL_ACID_VOL
    n_acid = acid_vol_mL / 1000.0 * acid_conc
    V_total = (acid_vol_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    # All species are diluted by the total mixed volume (litres).
    c_A = n_acid / V_total
    c_Na = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / V_total
    c_HCl = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / V_total
    return round(solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2), 2)

def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual for a triprotic acid H3A titrated with NaOH/HCl.

    Returns ``[H+] + c_Na - [OH-] - c_A*(a1 + 2*a2 + 3*a3) - c_HCl``, where a_k
    is the fraction that has lost k protons. Zero at the equilibrium pH.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    # Ratio terms 10**(k*pH - pKa1 - ... - pKak) for k = 1, 2, 3; exponents are
    # clipped to [-100, 100] so np.power never overflows.
    t1 = np.power(10, np.clip(pH - pKa1, -100, 100))
    t2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    t3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))
    denom = 1 + t1 + t2 + t3
    charge = c_A * (t1 / denom + 2 * (t2 / denom) + 3 * (t3 / denom))
    return h_conc + c_Na - oh_conc - charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Solve ``f_triprotic(pH, ...) == 0`` for pH by bisection on [0, 14].

    Args:
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa1, pKa2, pKa3: successive dissociation constants of the analyte.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, otherwise the
        midpoint of the final interval after 100 halvings. If the residual does
        not change sign on [0, 14] the result converges toward 14.0.
    """
    lo, hi = 0.0, 14.0
    # Cache the residual at ``lo``; it only changes when ``lo`` moves.
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float, pKa3: float, *, acid_vol_mL: "float | None" = None, acid_conc: float = 0.1) -> float:
    """Equilibrium pH of the triprotic analyte after the given titrant additions.

    Args:
        base1_mL, base2_mL: cumulative mL of strong base at TITRANT_CONC1 / TITRANT_CONC2.
        acid1_mL, acid2_mL: cumulative mL of strong acid at TITRANT_CONC1 / TITRANT_CONC2.
        pKa1, pKa2, pKa3: successive dissociation constants of the analyte.
        acid_vol_mL: initial analyte volume (mL); defaults to INITIAL_ACID_VOL,
            read at call time (previously always hard-coded).
        acid_conc: analyte concentration (mol/L); defaults to the previously
            hard-coded 0.1.

    Returns:
        pH rounded to 2 decimals.
    """
    if acid_vol_mL is None:
        acid_vol_mL = INITIAL_ACID_VOL
    n_acid = acid_vol_mL / 1000.0 * acid_conc
    V_total = (acid_vol_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    # All species are diluted by the total mixed volume (litres).
    c_A = n_acid / V_total
    c_Na = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / V_total
    c_HCl = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / V_total
    return round(solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3), 2)

# 环境与奖励函数
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag=None, prev_overshoot_volume=None, last_action_volume=None):
    """Shape the reward for one titration action.

    Terms (each weight read from ``reward_config`` with a default):
      * progress: error reduction, scaled by ``1 + remaining step fraction``;
      * a constant per-step penalty;
      * a sigmoid penalty when pH crosses the target by more than the threshold;
      * a penalty proportional to the error when base is used above target or
        acid below it;
      * after a recorded overshoot, a volume penalty and a bonus for using less
        than ``prev_overshoot_volume``.
    Non-terminal rewards are clipped; terminal ones (success or step limit)
    receive ``terminal_bonus``, doubled when finishing in under half the steps.

    Returns:
        ``(reward, is_terminal)``.
    """
    cfg = reward_config
    err_before = abs(previous_ph - target_ph)
    err_after = abs(current_ph - target_ph)

    steps_left_frac = (max_steps - steps_taken) / max_steps
    progress = cfg.get("dense_lambda", 1.0) * (err_before - err_after) * (1 + steps_left_frac)
    per_step = cfg.get("step_penalty", -0.005)

    weight = cfg.get("overshoot_weight", 0.2)
    cross_tol = cfg.get("overshoot_threshold", 0.1)
    crossed = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    crossing = 0
    if crossed and max(err_before, err_after) > cross_tol:
        magnitude = abs(current_ph - target_ph)
        crossing = -weight * (1 / (1 + math.exp(- (magnitude - cross_tol))))

    dir_factor = cfg.get("wrong_dir_factor", 1.0)
    lowered = reagent.lower()
    base_above = current_ph > target_ph and 'base' in lowered
    acid_below = current_ph < target_ph and 'acid' in lowered
    direction = 0
    if base_above or acid_below:
        direction = -dir_factor * abs(current_ph - target_ph)

    vol_pen = 0
    vol_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None and last_action_volume is not None:
        vol_pen = -cfg.get("overshoot_volume_penalty", 0.1) * last_action_volume
        bonus_rate = cfg.get("overshoot_volume_bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            vol_bonus = bonus_rate * (prev_overshoot_volume - last_action_volume)

    raw = progress + per_step + crossing + direction + vol_pen + vol_bonus

    done = abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps
    if not done:
        upper = cfg.get("reward_clip_max", 2.0)
        lower = cfg.get("reward_clip_min", -2.0)
        return max(min(raw, upper), lower), False

    multiplier = 2.0 if steps_taken < max_steps * 0.5 else 1.0
    return raw + cfg.get("terminal_bonus", 80.0) * multiplier, True

class PHSimEnv:
    """Simulated weak-acid titration used to replay a discrete volume policy.

    Actions are ``(reagent, volume_mL)`` pairs: ``reagent`` is a key of
    ``REAGENTS`` and ``volume_mL`` is ``MIN_ADDITION_VOLUME * k`` for k = 1..1000.
    After every addition the pH is recomputed with the ``calculate_pH_*``
    function that matches ``acid_type``. ``reset`` must be called with explicit
    acid parameters before ``step`` or ``select_best_action``.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1):
        self.initial_acid_vol = initial_acid_vol
        self.analyte_conc = analyte_conc
        # NOTE(review): n_acid / analyte_conc are stored but reset() and step()
        # do not pass them to the calculate_pH_* helpers.
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = MIN_ADDITION_VOLUME
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1001)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys() for volume in self.addition_volumes]
        self.reward_config = {
            "dense_lambda": 1.0,
            "step_penalty": -0.005,
            "terminal_bonus": 80,
            "overshoot_weight": 0.2,
            "overshoot_threshold": 0.1,
            "wrong_dir_factor": 60.0,
            "reward_clip_max": 2.0,
            "reward_clip_min": -2.0,
            "overshoot_volume_penalty": 0.1,
            "overshoot_volume_bonus": 0.1
        }
        # reset() is deliberately not called here; test_model supplies the CSV parameters.
        self.acid_type = None
        self.acid_params = None
        self.target_ph = None
        self.current_ph = None

    def reset(self, acid_type=None, acid_params=None, target_ph=None, initial_ph=None):
        """Start a new episode for one experiment.

        Args:
            acid_type: 'monoprotic', 'diprotic' or 'triprotic'.
            acid_params: one pKa (monoprotic) or a sequence of 2 / 3 pKa values.
            target_ph: target pH; 7.0 when None.
            initial_ph: starting pH; when None it is computed with no titrant added.

        Returns:
            The initial state vector (see ``_get_state``).

        Raises:
            ValueError: if acid_type/acid_params are missing or acid_type is unknown.
        """
        if acid_type is None or acid_params is None:
            raise ValueError("acid_type and acid_params must be provided from CSV data")

        self.acid_type = acid_type
        self.acid_params = acid_params
        self.target_ph = float(target_ph) if target_ph is not None else 7.0
        self.base1_added_mL = 0.0
        self.base2_added_mL = 0.0
        self.acid1_added_mL = 0.0
        self.acid2_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.last_added_moles = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        if self.acid_type == 'monoprotic':
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_monoprotic(0.0, 0.0, 0.0, 0.0, float(self.acid_params))
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_diprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2)
        elif self.acid_type == 'triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_triprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2, pKa3)
        else:
            raise ValueError(f"Unknown acid_type: {self.acid_type}")

        self.previous_ph = self.current_ph
        self.last_measured_ph = self.current_ph
        self.prev_measured_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return ``[current_ph, target_ph, pH_delta, error, last_action_volume]`` as float32.

        Missing values fall back to 0.0 (pH, delta, error) and 7.0 (target).
        """
        pH_delta = round(self.current_ph - self.previous_ph, 2) if self.current_ph is not None and self.previous_ph is not None else 0.0
        error = round(self.current_ph - self.target_ph, 2) if self.current_ph is not None and self.target_ph is not None else 0.0
        # Explicit None checks: ``target_ph or 7.0`` silently replaced a real
        # target of 0.0 with 7.0.
        current = self.current_ph if self.current_ph is not None else 0.0
        target = self.target_ph if self.target_ph is not None else 7.0
        return np.array([current, target, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot when pH crosses the target or the error grows.

        Returns:
            ``(overshoot, new_threshold)`` where ``new_threshold`` is half the
            volume just added (but at least ``min_addition``), or None.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            # Recover the added volume (mL) from moles and concentration.
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def step(self, action):
        """Add ``volume`` mL of ``reagent``, recompute pH and score the action.

        Repeated target-crossing pH jumps (> 0.1) at the minimum volume are
        counted per episode; at 3 the environment switches to the ``*_2``
        reagents. Overshoots tighten ``overshoot_threshold`` (never loosen it).

        Returns:
            ``(state, reward, done, {'reagent': reagent})``.
        """
        reagent, volume = action
        volume = float(volume)
        self.last_action_volume = volume
        self.last_added_moles = self.reagents[reagent] * (volume / 1000.0)
        self.steps += 1
        self.previous_ph = self.current_ph
        self.prev_measured_ph = self.last_measured_ph
        if reagent == 'strong_base_1':
            self.base1_added_mL += volume
        elif reagent == 'strong_base_2':
            self.base2_added_mL += volume
        elif reagent == 'strong_acid_1':
            self.acid1_added_mL += volume
        elif reagent == 'strong_acid_2':
            self.acid2_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base1_added_mL + self.base2_added_mL + self.acid1_added_mL + self.acid2_added_mL
        if self.acid_type == 'monoprotic':
            self.current_ph = calculate_pH_monoprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, float(self.acid_params)
            )
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2
            )
        elif self.acid_type == 'triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2, pKa3
            )
        self.last_measured_ph = self.current_ph

        if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
            if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                self.oscillation_count += 1
                logging.info(f"检测到在最小滴加量下的pH振荡，累计次数：{self.oscillation_count}")
                if self.oscillation_count >= 3:
                    self.use_secondary_reagents = True
                    logging.info("达到连续震荡阈值，切换到次级试剂滴定。")

        overshoot_flag, new_threshold = self.detect_overshoot(
            self.previous_ph, self.current_ph, self.target_ph, reagent,
            self.last_added_moles, self.reagents[reagent], self.min_addition_volume
        )
        if overshoot_flag:
            self.overshoot_occurred = True
            self.overshoot_reagent = reagent
            if new_threshold is not None:
                if self.overshoot_threshold is None or new_threshold < self.overshoot_threshold:
                    self.overshoot_threshold = new_threshold

        state = self._get_state()
        reward, done = calculate_reward(
            self.previous_ph, self.current_ph, self.target_ph, self.steps, MAX_STEPS, reagent,
            self.reward_config, SUCCESS_THRESHOLD, self.overshoot_occurred, self.overshoot_threshold, volume
        )
        return state, reward, done, {'reagent': reagent}

    def select_best_action(self, state_tensor, policy_model):
        """Pick the highest-scoring allowed ``(reagent, volume)`` action.

        Reagent choice: after an overshoot, the opposite class to the reagent
        that overshot; otherwise base below target and acid at/above it. Primary
        mode uses ``*_1`` reagents and consumes (clears) the overshoot flag;
        secondary mode uses ``*_2`` reagents and leaves the flag set. Volumes are
        capped at ``overshoot_threshold`` when that leaves at least one choice.

        Args:
            state_tensor: batched state passed straight to ``policy_model``.
            policy_model: callable returning logits indexed as ``logits[0, i]``,
                where ``i`` is the position of a volume in ``self.addition_volumes``.

        Returns:
            The chosen tuple from ``self.action_space``.
        """
        def filter_by_global_threshold(candidates):
            if self.overshoot_threshold is not None:
                filtered = [a for a in candidates if a[1] <= self.overshoot_threshold]
                if filtered:
                    return filtered
            return candidates

        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
        suffix = '_2' if self.use_secondary_reagents else '_1'
        if self.overshoot_occurred:
            # Reverse away from the reagent that caused the last overshoot.
            wanted = ('acid' if 'base' in self.overshoot_reagent.lower() else 'base') + suffix
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            wanted = ('base' if current_for_direction < self.target_ph else 'acid') + suffix
        allowed_reagent = [r for r in self.reagents.keys() if wanted in r.lower()]

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        candidate_actions = filter_by_global_threshold(candidate_actions)

        logging.info(f"候选动作: {candidate_actions[:5]}... (共{len(candidate_actions)}个)")

        # Build volume -> logit column once (first occurrence wins, like list.index)
        # instead of an O(n) list.index per candidate.
        volume_to_index = {}
        for idx, vol in enumerate(self.addition_volumes):
            volume_to_index.setdefault(vol, idx)

        with torch.no_grad():
            logits = policy_model(state_tensor)
            candidate_indices = [volume_to_index[a[1]] for a in candidate_actions]
            candidate_logits = logits[0, candidate_indices]
            best_action = candidate_actions[candidate_logits.argmax().item()]

        logging.info(f"选择动作: {best_action}")
        return best_action

class DiscreteVolumeRegressor(nn.Module):
    """Policy network scoring a fixed grid of addition volumes.

    Maps an ``input_dim`` state vector to one logit per entry of
    ``discrete_volumes`` (``min_volume + k * step`` rounded to 2 decimals). The
    layer layout (``net.0``/``net.2``/``net.4`` Linear layers) is what saved
    ``state_dict`` checkpoints are keyed on.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        n_bins = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(n_bins)]
        self.num_actions = len(self.discrete_volumes)
        hidden = 256
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_actions),
        )

    def forward(self, x):
        """Return raw logits over ``discrete_volumes``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the softmax policy.

        Returns:
            ``(volume as a 1x1 float32 tensor, log-probability of the sample)``.
            Uses ``.item()``, so ``x`` must hold a single state.
        """
        dist = torch.distributions.Categorical(logits=self.forward(x))
        picked = dist.sample()
        log_prob = dist.log_prob(picked)
        chosen_volume = self.discrete_volumes[picked.item()]
        return torch.tensor([[chosen_volume]], dtype=torch.float32), log_prob

    def predict_volume(self, x):
        """Return the greedy (highest-logit) volume as a 1x1 float32 tensor."""
        _, best = torch.max(self.forward(x), dim=1)
        return torch.tensor([[self.discrete_volumes[best.item()]]], dtype=torch.float32)

def load_experiment_conditions(csv_file):
    """Read experiment definitions from a CSV file.

    Expects columns ``Acid_Type``, ``Acid_Params``, ``Initial_pH`` and
    ``Target_pH``. ``Acid_Params`` is parsed with ``ast.literal_eval`` (Python
    literals only, so a single pKa or a tuple/list of pKa values); the two pH
    columns are converted to float.

    Returns:
        A list of dicts with keys ``acid_type``, ``acid_params``, ``initial_ph``
        and ``target_ph``, in file order.
    """
    with open(csv_file, 'r', encoding='utf-8') as handle:
        return [
            {
                'acid_type': record['Acid_Type'],
                'acid_params': ast.literal_eval(record['Acid_Params']),
                'initial_ph': float(record['Initial_pH']),
                'target_ph': float(record['Target_pH']),
            }
            for record in csv.DictReader(handle)
        ]

def test_model(policy_model, csv_file="experiment_summary.csv", output_file="test_output2_modified.txt", summary_file="experiment_summary_rl.csv", env=None):
    """Replay every experiment in *csv_file* with the greedy policy and report results.

    For each CSV row the environment is reset to that row's acid type, pKa
    parameters, initial pH and target pH. ``env.select_best_action`` and
    ``env.step`` are then called until ``step`` reports ``done``. The per-step
    trace is printed and written to *output_file*; one row per experiment is
    written to *summary_file*, followed by an overall success rate.

    Args:
        policy_model: model passed to ``env.select_best_action``.
        csv_file: experiment-conditions CSV read by ``load_experiment_conditions``.
        output_file: text log path (overwritten).
        summary_file: per-experiment CSV summary path (overwritten).
        env: environment to run in. When omitted, the module-level ``env``
            global is used, which is what this function always did before.

    Raises:
        NameError: if *env* is omitted and no module-level ``env`` exists.
    """
    if env is None:
        # Backward compatible fallback to the notebook-global environment.
        try:
            env = globals()["env"]
        except KeyError:
            raise NameError("name 'env' is not defined; pass env=... to test_model") from None

    experiments = load_experiment_conditions(csv_file)
    num_experiments = len(experiments)
    success_count = 0
    total_steps_success = []

    with open(output_file, 'w', encoding='utf-8') as log_f, open(summary_file, 'w', newline='', encoding='utf-8') as summary_f:
        def log_and_print(message):
            print(message)
            log_f.write(message + '\n')

        csv_writer = csv.writer(summary_f)
        csv_writer.writerow(['Experiment', 'Acid_Type', 'Acid_Params', 'Initial_pH', 'Target_pH', 'Final_pH', 'Steps_Taken', 'Success'])

        for i, exp in enumerate(experiments, 1):
            log_and_print(f"\n==== 实验 {i} 开始 ====")
            acid_type = exp['acid_type']
            acid_params = exp['acid_params']
            initial_ph = exp['initial_ph']
            target_ph = exp['target_ph']
            state = env.reset(acid_type=acid_type, acid_params=acid_params, target_ph=target_ph, initial_ph=initial_ph)
            log_and_print(f"初始状态: {state}")
            log_and_print(f"酸类型: {acid_type}, 参数: {acid_params}, 目标 pH: {target_ph}")
            done = False
            steps = 0
            experiment_trace = []
            while not done:
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                log_and_print(f"当前状态向量: {state}")
                action = env.select_best_action(state_tensor, policy_model)
                state, reward, done, info = env.step(action)
                experiment_trace.append((state, action[1], info.get('reagent', '')))
                steps += 1
            log_and_print("状态-动作-试剂对:")
            for j, (s, a, reagent) in enumerate(experiment_trace):
                log_and_print(f"  Step {j+1}: State = {s}, Action = {a:.4f}, Reagent = {reagent}")
            log_and_print(f"实验结束，共用步数: {steps}, 最终 pH: {env.current_ph:.2f}")
            success = abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD
            if success:
                success_count += 1
                total_steps_success.append(steps)

            # Sequences (di/triprotic) are written as-is; a single pKa with 2 decimals.
            acid_params_str = f"{acid_params}" if isinstance(acid_params, (list, tuple)) else f"{acid_params:.2f}"
            csv_writer.writerow([i, acid_type, acid_params_str, f"{initial_ph:.2f}", f"{target_ph:.2f}",
                                f"{env.current_ph:.2f}", steps, 'Yes' if success else 'No'])

        # An empty experiment CSV previously raised ZeroDivisionError here.
        success_rate = success_count / num_experiments * 100 if num_experiments else 0.0
        avg_steps_success = np.mean(total_steps_success) if total_steps_success else 0
        summary_stats = f"\n测试完成：成功率 = {success_rate:.2f}%, 成功实验平均步数 = {avg_steps_success:.2f}"
        log_and_print(summary_stats)

if __name__ == "__main__":
    # Length of the state vector produced by PHSimEnv._get_state.
    input_dim = 5
    # NOTE(review): learning_rate and gamma are not used in this block; presumably
    # kept for parity with the training cells (they remain notebook globals).
    learning_rate = 1e-3
    gamma = 0.99

    # Environment replayed by test_model (test_model reads this module-level `env`).
    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1)
    
    # Discrete-volume policy: one logit per volume on a 0.01 mL grid from 0.01 to 10.0 mL.
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
    
    # Load the saved best checkpoint onto the CPU; on any failure print the error
    # and continue with the randomly initialised network.
    try:
        policy_model.load_state_dict(torch.load("volume_regressor_best_big_discrete_new1_trained-1-test-保存最优.pth", map_location=torch.device('cpu')))
        print("加载离散预训练模型成功。")
    except Exception as e:
        print("未能加载离散预训练模型，使用随机初始化模型。", e)
    
    # Put the network in evaluation mode before replaying experiments.
    policy_model.eval()
    
    # Replay the CSV experiments and write the text log and CSV summary.
    test_model(policy_model, csv_file="experiment_summary.csv", output_file="保存最优神经网络同样的实验只用浓酸碱.txt", summary_file="保存最优神经网络同样的实验只用浓酸碱.csv")

In [None]:
# 评估强化前的模型

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
import logging
from scipy.optimize import fsolve
import csv
import ast

# Logging: INFO level, timestamped lines.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Seed every RNG this cell relies on so runs are reproducible.
seed = 555
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Global constants
TITRANT_CONC1 = 0.1         # primary titrant concentration (0.1 M)
# NOTE: in this cell the secondary titrant is 0.1 M, the same as the primary one,
# so the "_2" reagents are NOT dilute here (the old comment wrongly said 0.01 M).
# Set this to 0.01 to restore a dilute secondary titrant.
TITRANT_CONC2 = 0.1         # secondary titrant concentration (0.1 M in this cell)
MAX_STEPS = 50              # maximum steps per episode
INITIAL_ACID_VOL = 11.0     # initial volume of the weak-acid analyte (mL)
SUCCESS_THRESHOLD = 0.1     # |pH - target| below this counts as success
MIN_ADDITION_VOLUME = 0.01  # smallest addition volume (mL)

# Reagent name -> concentration (M).
REAGENTS = {
    'strong_base_1': TITRANT_CONC1,
    'strong_base_2': TITRANT_CONC2,
    'strong_acid_1': TITRANT_CONC1,
    'strong_acid_2': TITRANT_CONC2,
}

# pH 计算函数
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual for a monoprotic acid HA titrated with NaOH/HCl.

    Returns ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` where ``alpha`` is the
    deprotonated fraction ``10**(pH-pKa) / (1 + 10**(pH-pKa))``. The residual is
    zero at the equilibrium pH.

    Args:
        pH: trial pH.
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa: acid dissociation constant of the analyte.
    """
    H = 10 ** (-pH)
    Kw = 1e-14
    OH = Kw / H
    # Clip the exponent like f_diprotic/f_triprotic do: an unclipped
    # 10 ** (pH - pKa) raises OverflowError once the exponent passes ~308,
    # while 10**100 already yields alpha == 1.0.
    term = np.power(10, np.clip(pH - pKa, -100, 100))
    alpha = term / (1 + term)
    return H + c_Na - OH - c_A * alpha - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Solve ``f_monoprotic(pH, ...) == 0`` for pH by bisection on [0, 14].

    Args:
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa: acid dissociation constant of the analyte.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, otherwise the
        midpoint of the final interval after 100 halvings. If the residual does
        not change sign on [0, 14], ``lo`` keeps advancing and the result
        converges toward 14.0.
    """
    lo, hi = 0.0, 14.0
    # The residual at ``lo`` only changes when ``lo`` moves, so cache it instead
    # of re-evaluating it every iteration (same result, half the evaluations).
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa: float, *, acid_vol_mL: "float | None" = None, acid_conc: float = 0.1) -> float:
    """Equilibrium pH of the monoprotic analyte after the given titrant additions.

    Args:
        base1_mL, base2_mL: cumulative mL of strong base at TITRANT_CONC1 / TITRANT_CONC2.
        acid1_mL, acid2_mL: cumulative mL of strong acid at TITRANT_CONC1 / TITRANT_CONC2.
        pKa: acid dissociation constant of the analyte.
        acid_vol_mL: initial analyte volume (mL); defaults to INITIAL_ACID_VOL,
            read at call time (previously always hard-coded).
        acid_conc: analyte concentration (mol/L); defaults to the previously
            hard-coded 0.1.

    Returns:
        pH rounded to 2 decimals.
    """
    if acid_vol_mL is None:
        acid_vol_mL = INITIAL_ACID_VOL
    n_acid = acid_vol_mL / 1000.0 * acid_conc
    V_total = (acid_vol_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    # All species are diluted by the total mixed volume (litres).
    c_A = n_acid / V_total
    c_Na = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / V_total
    c_HCl = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / V_total
    return round(solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa), 2)

def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual for a diprotic acid H2A titrated with NaOH/HCl.

    Returns ``[H+] + c_Na - [OH-] - c_A*(a1 + 2*a2) - c_HCl``, where a1/a2 are
    the singly/doubly deprotonated fractions. Zero at the equilibrium pH.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    # Ratio terms 10**(pH-pKa1) and 10**(2pH-pKa1-pKa2); exponents are clipped
    # to [-100, 100] so np.power never overflows.
    first = np.power(10, np.clip(pH - pKa1, -100, 100))
    second = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    denom = 1 + first + second
    charge = c_A * (first / denom + 2 * (second / denom))
    return h_conc + c_Na - oh_conc - charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Solve ``f_diprotic(pH, ...) == 0`` for pH by bisection on [0, 14].

    Args:
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa1, pKa2: successive dissociation constants of the analyte.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, otherwise the
        midpoint of the final interval after 100 halvings. If the residual does
        not change sign on [0, 14] the result converges toward 14.0.
    """
    lo, hi = 0.0, 14.0
    # Cache the residual at ``lo``; it only changes when ``lo`` moves.
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float, *, acid_vol_mL: "float | None" = None, acid_conc: float = 0.1) -> float:
    """Equilibrium pH of the diprotic analyte after the given titrant additions.

    Args:
        base1_mL, base2_mL: cumulative mL of strong base at TITRANT_CONC1 / TITRANT_CONC2.
        acid1_mL, acid2_mL: cumulative mL of strong acid at TITRANT_CONC1 / TITRANT_CONC2.
        pKa1, pKa2: successive dissociation constants of the analyte.
        acid_vol_mL: initial analyte volume (mL); defaults to INITIAL_ACID_VOL,
            read at call time (previously always hard-coded).
        acid_conc: analyte concentration (mol/L); defaults to the previously
            hard-coded 0.1.

    Returns:
        pH rounded to 2 decimals.
    """
    if acid_vol_mL is None:
        acid_vol_mL = INITIAL_ACID_VOL
    n_acid = acid_vol_mL / 1000.0 * acid_conc
    V_total = (acid_vol_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    # All species are diluted by the total mixed volume (litres).
    c_A = n_acid / V_total
    c_Na = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / V_total
    c_HCl = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / V_total
    return round(solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2), 2)

def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual for a triprotic acid H3A titrated with NaOH/HCl.

    Returns ``[H+] + c_Na - [OH-] - c_A*(a1 + 2*a2 + 3*a3) - c_HCl``, where a_k
    is the fraction that has lost k protons. Zero at the equilibrium pH.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    # Ratio terms 10**(k*pH - pKa1 - ... - pKak) for k = 1, 2, 3; exponents are
    # clipped to [-100, 100] so np.power never overflows.
    t1 = np.power(10, np.clip(pH - pKa1, -100, 100))
    t2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    t3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))
    denom = 1 + t1 + t2 + t3
    charge = c_A * (t1 / denom + 2 * (t2 / denom) + 3 * (t3 / denom))
    return h_conc + c_Na - oh_conc - charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Solve ``f_triprotic(pH, ...) == 0`` for pH by bisection on [0, 14].

    Args:
        c_A: total analyte acid concentration (mol/L).
        c_Na: concentration contributed by added strong base (mol/L).
        c_HCl: concentration contributed by added strong acid (mol/L).
        pKa1, pKa2, pKa3: successive dissociation constants of the analyte.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, otherwise the
        midpoint of the final interval after 100 halvings. If the residual does
        not change sign on [0, 14] the result converges toward 14.0.
    """
    lo, hi = 0.0, 14.0
    # Cache the residual at ``lo``; it only changes when ``lo`` moves.
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float, pKa3: float, *, acid_vol_mL: "float | None" = None, acid_conc: float = 0.1) -> float:
    """Equilibrium pH of the triprotic analyte after the given titrant additions.

    Args:
        base1_mL, base2_mL: cumulative mL of strong base at TITRANT_CONC1 / TITRANT_CONC2.
        acid1_mL, acid2_mL: cumulative mL of strong acid at TITRANT_CONC1 / TITRANT_CONC2.
        pKa1, pKa2, pKa3: successive dissociation constants of the analyte.
        acid_vol_mL: initial analyte volume (mL); defaults to INITIAL_ACID_VOL,
            read at call time (previously always hard-coded).
        acid_conc: analyte concentration (mol/L); defaults to the previously
            hard-coded 0.1.

    Returns:
        pH rounded to 2 decimals.
    """
    if acid_vol_mL is None:
        acid_vol_mL = INITIAL_ACID_VOL
    n_acid = acid_vol_mL / 1000.0 * acid_conc
    V_total = (acid_vol_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    # All species are diluted by the total mixed volume (litres).
    c_A = n_acid / V_total
    c_Na = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / V_total
    c_HCl = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / V_total
    return round(solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3), 2)

# 环境与奖励函数
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag=None, prev_overshoot_volume=None, last_action_volume=None):
    """Shape the reward for one titration action.

    Terms (each weight read from ``reward_config`` with a default):
      * progress: error reduction, scaled by ``1 + remaining step fraction``;
      * a constant per-step penalty;
      * a sigmoid penalty when pH crosses the target by more than the threshold;
      * a penalty proportional to the error when base is used above target or
        acid below it;
      * after a recorded overshoot, a volume penalty and a bonus for using less
        than ``prev_overshoot_volume``.
    Non-terminal rewards are clipped; terminal ones (success or step limit)
    receive ``terminal_bonus``, doubled when finishing in under half the steps.

    Returns:
        ``(reward, is_terminal)``.
    """
    cfg = reward_config
    err_before = abs(previous_ph - target_ph)
    err_after = abs(current_ph - target_ph)

    steps_left_frac = (max_steps - steps_taken) / max_steps
    progress = cfg.get("dense_lambda", 1.0) * (err_before - err_after) * (1 + steps_left_frac)
    per_step = cfg.get("step_penalty", -0.005)

    weight = cfg.get("overshoot_weight", 0.2)
    cross_tol = cfg.get("overshoot_threshold", 0.1)
    crossed = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    crossing = 0
    if crossed and max(err_before, err_after) > cross_tol:
        magnitude = abs(current_ph - target_ph)
        crossing = -weight * (1 / (1 + math.exp(- (magnitude - cross_tol))))

    dir_factor = cfg.get("wrong_dir_factor", 1.0)
    lowered = reagent.lower()
    base_above = current_ph > target_ph and 'base' in lowered
    acid_below = current_ph < target_ph and 'acid' in lowered
    direction = 0
    if base_above or acid_below:
        direction = -dir_factor * abs(current_ph - target_ph)

    vol_pen = 0
    vol_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None and last_action_volume is not None:
        vol_pen = -cfg.get("overshoot_volume_penalty", 0.1) * last_action_volume
        bonus_rate = cfg.get("overshoot_volume_bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            vol_bonus = bonus_rate * (prev_overshoot_volume - last_action_volume)

    raw = progress + per_step + crossing + direction + vol_pen + vol_bonus

    done = abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps
    if not done:
        upper = cfg.get("reward_clip_max", 2.0)
        lower = cfg.get("reward_clip_min", -2.0)
        return max(min(raw, upper), lower), False

    multiplier = 2.0 if steps_taken < max_steps * 0.5 else 1.0
    return raw + cfg.get("terminal_bonus", 80.0) * multiplier, True

class PHSimEnv:
    """Simulated titration environment used to evaluate a discrete-volume policy.

    The environment records how many mL of each of four titrants (primary and
    secondary acid/base, keyed as in ``REAGENTS``) have been added to a sample
    of ``initial_acid_vol`` mL of acid at ``analyte_conc``.  After every
    addition it recomputes the pH with ``calculate_pH_monoprotic`` /
    ``calculate_pH_diprotic`` / ``calculate_pH_triprotic``, chosen by the acid
    type passed to :meth:`reset`.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1):
        self.initial_acid_vol = initial_acid_vol
        self.analyte_conc = analyte_conc
        # Analyte amount in moles (volume given in mL, hence / 1000).
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = MIN_ADDITION_VOLUME
        # Candidate volumes are the multiples 1..1000 of the minimum addition.
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1001)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys() for volume in self.addition_volumes]
        self.reward_config = {
            "dense_lambda": 1.0,
            "step_penalty": -0.005,
            "terminal_bonus": 80,
            "overshoot_weight": 0.2,
            "overshoot_threshold": 0.1,
            "wrong_dir_factor": 60.0,
            "reward_clip_max": 2.0,
            "reward_clip_min": -2.0,
            "overshoot_volume_penalty": 0.1,
            "overshoot_volume_bonus": 0.1
        }
        # reset() is deliberately not called here: test_model supplies the CSV parameters.
        self.acid_type = None
        self.acid_params = None
        self.target_ph = None
        self.current_ph = None

    def reset(self, acid_type=None, acid_params=None, target_ph=None, initial_ph=None):
        """Start a new episode and return the initial observation.

        Args:
            acid_type: 'monoprotic', 'diprotic' or 'triprotic'.
            acid_params: one pKa (monoprotic) or a sequence of 2 / 3 pKa values.
            target_ph: target pH; defaults to 7.0 when None.
            initial_ph: starting pH; when None it is computed from the acid
                model with no titrant added.

        Raises:
            ValueError: if acid_type/acid_params are missing or acid_type is unknown.
        """
        if acid_type is None or acid_params is None:
            raise ValueError("acid_type and acid_params must be provided from CSV data")

        self.acid_type = acid_type
        self.acid_params = acid_params
        self.target_ph = float(target_ph) if target_ph is not None else 7.0
        self.base1_added_mL = 0.0
        self.base2_added_mL = 0.0
        self.acid1_added_mL = 0.0
        self.acid2_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.last_added_moles = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        if self.acid_type == 'monoprotic':
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_monoprotic(0.0, 0.0, 0.0, 0.0, float(self.acid_params))
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_diprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2)
        elif self.acid_type == 'triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_triprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2, pKa3)
        else:
            raise ValueError(f"Unknown acid_type: {self.acid_type}")

        self.previous_ph = self.current_ph
        self.last_measured_ph = self.current_ph
        self.prev_measured_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return the float32 observation [pH, target, ΔpH, error, last volume].

        ΔpH and error are rounded to 2 decimals.  A missing pH is reported as
        0.0 and a missing target as 7.0.  The checks are explicit ``is None``
        tests, so a genuine target of 0.0 is reported as 0.0 rather than 7.0.
        """
        pH_delta = round(self.current_ph - self.previous_ph, 2) if self.current_ph is not None and self.previous_ph is not None else 0.0
        error = round(self.current_ph - self.target_ph, 2) if self.current_ph is not None and self.target_ph is not None else 0.0
        ph = self.current_ph if self.current_ph is not None else 0.0
        target = self.target_ph if self.target_ph is not None else 7.0
        return np.array([ph, target, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot and propose a volume cap for later additions.

        An overshoot is either the pH crossing the target, or the absolute
        error growing.  The proposed cap is half the volume of the last
        addition (moles * 1000 / concentration), but never below
        ``min_addition``.  ``reagent`` is unused.

        Returns:
            (overshoot, new_threshold); new_threshold is None when no overshoot.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def _record_addition(self, reagent, volume):
        """Add ``volume`` mL to the running total of the matching titrant.

        The titrant is identified by 'acid'/'base' in its name plus its
        ``_1``/``_2`` suffix.  Both ``dilute_base_1``-style keys (as in the
        ``REAGENTS`` dicts of this file) and ``strong_base_1``-style names are
        accepted.  Names matching neither pattern are ignored.
        """
        name = reagent.lower()
        if name.endswith('_1'):
            primary = True
        elif name.endswith('_2'):
            primary = False
        else:
            return
        if 'base' in name:
            if primary:
                self.base1_added_mL += volume
            else:
                self.base2_added_mL += volume
        elif 'acid' in name:
            if primary:
                self.acid1_added_mL += volume
            else:
                self.acid2_added_mL += volume

    def step(self, action):
        """Apply ``action = (reagent, volume_mL)`` and advance the simulation.

        Returns:
            (state, reward, done, {'reagent': reagent}), where reward and done
            come from ``calculate_reward``.
        """
        reagent, volume = action
        volume = float(volume)
        self.last_action_volume = volume
        self.last_added_moles = self.reagents[reagent] * (volume / 1000.0)
        self.steps += 1
        self.previous_ph = self.current_ph
        self.prev_measured_ph = self.last_measured_ph
        self._record_addition(reagent, volume)
        self.total_volume = self.initial_acid_vol + self.base1_added_mL + self.base2_added_mL + self.acid1_added_mL + self.acid2_added_mL
        if self.acid_type == 'monoprotic':
            self.current_ph = calculate_pH_monoprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, float(self.acid_params)
            )
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2
            )
        elif self.acid_type == 'triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2, pKa3
            )
        self.last_measured_ph = self.current_ph

        # Repeated target crossings at the minimum volume mean the primary
        # titrants are too coarse; after 3 of them, switch to the secondary ones.
        if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
            if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                self.oscillation_count += 1
                logging.info(f"检测到在最小滴加量下的pH振荡，累计次数：{self.oscillation_count}")
                if self.oscillation_count >= 3:
                    self.use_secondary_reagents = True
                    logging.info("达到连续震荡阈值，切换到次级试剂滴定。")

        overshoot_flag, new_threshold = self.detect_overshoot(
            self.previous_ph, self.current_ph, self.target_ph, reagent,
            self.last_added_moles, self.reagents[reagent], self.min_addition_volume
        )
        if overshoot_flag:
            self.overshoot_occurred = True
            self.overshoot_reagent = reagent
            # The volume cap only ever tightens.
            if new_threshold is not None:
                if self.overshoot_threshold is None or new_threshold < self.overshoot_threshold:
                    self.overshoot_threshold = new_threshold

        state = self._get_state()
        reward, done = calculate_reward(
            self.previous_ph, self.current_ph, self.target_ph, self.steps, MAX_STEPS, reagent,
            self.reward_config, SUCCESS_THRESHOLD, self.overshoot_occurred, self.overshoot_threshold, volume
        )
        return state, reward, done, {'reagent': reagent}

    def select_best_action(self, state_tensor, policy_model):
        """Pick the highest-scoring (reagent, volume) action allowed by the rules.

        Reagent choice:
          * After an overshoot, use the titrant opposite to the one that
            overshot.  In primary mode the overshoot flag is consumed here;
            in secondary mode it is left set, as before.
          * Otherwise use base when the last measured pH is below the target,
            and acid when it is not.
          * The ``_2`` titrants are used once ``use_secondary_reagents`` is set,
            and the ``_1`` titrants before that.
        Volumes are capped by ``overshoot_threshold`` when any candidate fits.

        ``policy_model(state_tensor)`` must return logits whose row 0 is
        indexed like ``self.addition_volumes`` (assumed to match the model's
        discrete volume grid — TODO confirm for non-default grids).

        Raises:
            ValueError: if no action in ``action_space`` matches the allowed reagent.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
        suffix = '_2' if self.use_secondary_reagents else '_1'
        if self.overshoot_occurred:
            kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            kind = 'base' if current_for_direction < self.target_ph else 'acid'
        tag = kind + suffix
        allowed_reagent = {r for r in self.reagents.keys() if tag in r.lower()}

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped

        logging.info(f"候选动作: {candidate_actions[:5]}... (共{len(candidate_actions)}个)")
        if not candidate_actions:
            raise ValueError(f"no candidate actions for reagents matching {tag!r}")

        # Position of each volume in addition_volumes.  setdefault keeps the
        # first occurrence, mirroring list.index, without an O(n) scan per candidate.
        volume_index = {}
        for idx, vol in enumerate(self.addition_volumes):
            volume_index.setdefault(vol, idx)

        with torch.no_grad():
            logits = policy_model(state_tensor)
            candidate_indices = [volume_index[a[1]] for a in candidate_actions]
            best_pos = int(logits[0, candidate_indices].argmax().item())
        best_action = candidate_actions[best_pos]

        logging.info(f"选择动作: {best_action}")
        return best_action

class DiscreteVolumeRegressor(nn.Module):
    """MLP policy that scores a fixed grid of discrete addition volumes.

    The grid is ``min_volume, min_volume + step, ..., <= max_volume``, with
    each value rounded to 2 decimals.  The network maps an ``input_dim``
    state vector to one logit per grid volume (256-256 hidden ReLU layers).
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        # A small tolerance before truncating: (max - min) / step is often a
        # hair below the intended integer (e.g. 0.3 / 0.1 == 2.9999999999999996),
        # and plain int() would then silently drop max_volume from the grid.
        n_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_steps + 1)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )

    def forward(self, x):
        """Return the logits over ``discrete_volumes`` for state(s) ``x``."""
        return self.net(x)

    def _volumes_for(self, indices):
        """Map a tensor of grid indices to a float32 tensor of shape (N, 1)."""
        return torch.tensor(
            [[self.discrete_volumes[i]] for i in indices.reshape(-1).tolist()],
            dtype=torch.float32,
        )

    def sample_action(self, x):
        """Sample one volume per state from the categorical policy.

        Returns:
            (volumes, log_prob): ``volumes`` has shape (N, 1), and ``log_prob``
            is the log-probability of each sampled index.  For a single state
            the shapes are (1, 1) and (1,).
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        return self._volumes_for(action_index), log_prob

    def predict_volume(self, x):
        """Return the greedy (arg-max) volume per state as an (N, 1) tensor."""
        logits = self.forward(x)
        # argmax returns the first maximal index, like torch.max(dim=...).
        predicted_indices = logits.argmax(dim=-1)
        return self._volumes_for(predicted_indices)

def load_experiment_conditions(csv_file):
    """Read experiment definitions from a CSV file.

    Required columns: ``Acid_Type``, ``Acid_Params``, ``Initial_pH`` and
    ``Target_pH``.  ``Acid_Params`` is parsed with ``ast.literal_eval``, so
    it may be a number (e.g. ``4.2``) or a list literal (e.g. ``[2.1, 5.3]``).

    Returns:
        A list of dicts with keys 'acid_type', 'acid_params', 'initial_ph'
        and 'target_ph'; the pH values are floats.
    """
    # newline='' is what the csv module requires so that quoted fields
    # containing newlines are read correctly.
    with open(csv_file, 'r', encoding='utf-8', newline='') as fh:
        return [
            {
                'acid_type': row['Acid_Type'],
                'acid_params': ast.literal_eval(row['Acid_Params']),
                'initial_ph': float(row['Initial_pH']),
                'target_ph': float(row['Target_pH']),
            }
            for row in csv.DictReader(fh)
        ]

def test_model(policy_model, csv_file="experiment_summary.csv", output_file="test_output2_modified.txt",
               summary_file="experiment_summary_rl.csv", sim_env=None):
    """Replay every experiment in ``csv_file`` with ``policy_model`` and report results.

    For each row (see ``load_experiment_conditions``) the environment is reset
    to the row's acid, target and initial pH.  Then ``select_best_action`` and
    ``step`` are called until the episode reports ``done``.  Every
    state/action/reagent triple is printed and written to ``output_file``, and
    one summary row per experiment goes to ``summary_file``.

    Args:
        policy_model: passed to ``select_best_action``; called with a (1, 5)
            float32 state tensor.
        csv_file: experiment definitions to load.
        output_file: text log (overwritten) with the full traces.
        summary_file: CSV (overwritten) with one row per experiment.
        sim_env: environment to run.  Defaults to the module-level ``env``
            (created in the ``__main__`` block), which the original signature
            relied on implicitly.

    Returns:
        (success_rate_percent, mean_steps_of_successful_runs).  Both are 0
        when ``csv_file`` holds no experiments; previously this case raised
        ZeroDivisionError.
    """
    run_env = sim_env if sim_env is not None else env
    experiments = load_experiment_conditions(csv_file)
    num_experiments = len(experiments)
    success_count = 0
    total_steps_success = []

    with open(output_file, 'w', encoding='utf-8') as out_f, open(summary_file, 'w', newline='', encoding='utf-8') as summary_f:
        def log_and_print(message):
            print(message)
            out_f.write(message + '\n')

        csv_writer = csv.writer(summary_f)
        csv_writer.writerow(['Experiment', 'Acid_Type', 'Acid_Params', 'Initial_pH', 'Target_pH', 'Final_pH', 'Steps_Taken', 'Success'])

        for i, exp in enumerate(experiments, 1):
            log_and_print(f"\n==== 实验 {i} 开始 ====")
            acid_type = exp['acid_type']
            acid_params = exp['acid_params']
            initial_ph = exp['initial_ph']
            target_ph = exp['target_ph']
            state = run_env.reset(acid_type=acid_type, acid_params=acid_params, target_ph=target_ph, initial_ph=initial_ph)
            log_and_print(f"初始状态: {state}")
            log_and_print(f"酸类型: {acid_type}, 参数: {acid_params}, 目标 pH: {target_ph}")
            done = False
            steps = 0
            experiment_trace = []
            while not done:
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                log_and_print(f"当前状态向量: {state}")
                action = run_env.select_best_action(state_tensor, policy_model)
                state, reward, done, info = run_env.step(action)
                experiment_trace.append((state, action[1], info.get('reagent', '')))
                steps += 1
            log_and_print("状态-动作-试剂对:")
            for j, (s, a, reagent) in enumerate(experiment_trace):
                log_and_print(f"  Step {j+1}: State = {s}, Action = {a:.4f}, Reagent = {reagent}")
            log_and_print(f"实验结束，共用步数: {steps}, 最终 pH: {run_env.current_ph:.2f}")
            success = abs(run_env.current_ph - run_env.target_ph) < SUCCESS_THRESHOLD
            if success:
                success_count += 1
                total_steps_success.append(steps)

            acid_params_str = f"{acid_params}" if isinstance(acid_params, (list, tuple)) else f"{acid_params:.2f}"
            csv_writer.writerow([i, acid_type, acid_params_str, f"{initial_ph:.2f}", f"{target_ph:.2f}",
                                 f"{run_env.current_ph:.2f}", steps, 'Yes' if success else 'No'])

        # Guard against an empty experiment file.
        success_rate = success_count / num_experiments * 100 if num_experiments else 0.0
        avg_steps_success = np.mean(total_steps_success) if total_steps_success else 0
        summary_stats = f"\n测试完成：成功率 = {success_rate:.2f}%, 成功实验平均步数 = {avg_steps_success:.2f}"
        log_and_print(summary_stats)

    return success_rate, avg_steps_success

if __name__ == "__main__":
    # Evaluation entry point: build the environment and policy, try to load the
    # pretrained discrete checkpoint, then replay the CSV experiments.
    # learning_rate and gamma are not used in this evaluation-only script.
    input_dim = 5
    learning_rate = 1e-3
    gamma = 0.99

    # `env` is bound at module level on purpose: test_model() looks it up by name.
    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1)
    
    # Grid of 0.01 ... 10.0 in 0.01 steps (1000 outputs).
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
    
    # Best effort: on any load failure, continue with the randomly initialised weights.
    # NOTE(review): torch.load unpickles the file; consider weights_only=True
    # if the checkpoint can come from an untrusted source.
    try:
        policy_model.load_state_dict(torch.load("volume_regressor_best_big_discrete_new1-test.pth", map_location=torch.device('cpu')))
        print("加载离散预训练模型成功。")
    except Exception as e:
        print("未能加载离散预训练模型，使用随机初始化模型。", e)
    
    policy_model.eval()
    
    # Output names: "pre-RL neural net, same experiments, concentrated acid/base only".
    test_model(policy_model, csv_file="experiment_summary.csv", output_file="强化前神经网络同样的实验只用浓酸碱.txt", summary_file="强化前神经网络同样的实验只用浓酸碱.csv")

In [None]:
# 评估贝叶斯final

In [None]:
import numpy as np
import math
import random
import logging
from scipy.stats import norm
from scipy.optimize import brentq

# Logging: INFO level, timestamped messages.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)

# Global parameters (kept identical to the earlier cell).
TITRATED_VOLUME = 11.0                # titrated sample volume (mL)
ANALYTE_CONC = 0.1                    # acid concentration of the titrated sample (mol/L)
HCL_CONC1, HCL_CONC2 = 0.1, 0.01      # HCl titrants: primary / secondary (mol/L)
NAOH_CONC1, NAOH_CONC2 = 0.1, 0.01    # NaOH titrants: primary / secondary (mol/L)
MAX_STEPS = 50

# Titrant name -> concentration (mol/L); '_1' = primary, '_2' = secondary.
REAGENTS = dict(
    dilute_acid_1=HCL_CONC1,
    dilute_acid_2=HCL_CONC2,
    dilute_base_1=NAOH_CONC1,
    dilute_base_2=NAOH_CONC2,
)

# pH calculation helper (same contract as the copy in the earlier cell).
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total negative charge carried by a polyprotic acid's anions.

    For an acid H_nA at total concentration ``c_A`` with stepwise constants
    ``K_i = 10**-pKa_i``, the species with k protons removed is proportional
    to ``beta_k / H**k``, where ``beta_k = K_1 * ... * K_k`` and ``beta_0 = 1``.
    The result is ``sum_k k * [H_(n-k)A^(k-)]``, the anion term of the charge
    balance.

    The relative abundances are computed in log10 space and normalised by
    their maximum.  The direct products ``beta_k / H**k`` overflow to inf
    (giving NaN) for strongly acidic pKa values that the [-100, 100] clip
    still allows.

    Args:
        c_A: total analytical concentration of the acid.
        H: hydrogen-ion concentration (``10**-pH``).
        pKa_list: stepwise pKa values (list or 1-D numpy array); each is
            clipped to [-100, 100].

    Returns:
        The anion charge concentration as a float.  ``H <= 0`` is treated as
        the infinitely basic limit: the acid is fully deprotonated and the
        charge is ``n * c_A``.
    """
    n = len(pKa_list)
    if H <= 0:
        return float(n * c_A)
    pH = -np.log10(H)
    # log10(beta_k / H**k) for k = 0..n; each step adds (-pKa_k + pH).
    log_terms = [0.0]
    running = 0.0
    for pKa in pKa_list:
        running += np.clip(-pKa, -100, 100) + pH
        log_terms.append(running)
    shift = max(log_terms)
    weights = [10.0 ** (t - shift) for t in log_terms]
    # The largest weight is exactly 1, so the total is >= 1 (no zero division).
    total = sum(weights)
    charge_weighted = sum(k * w for k, w in enumerate(weights))
    return float(c_A * charge_weighted / total)

class PHAdjustmentEnv:
    def __init__(self):
        """Create an environment with default settings and randomly drawn buffers.

        Per-episode values (pH, target, step budget, acid type) stay None
        until ``initialize`` is called.  Construction draws from NumPy's global
        RNG (``np.random.uniform``) for the buffer pKa values and amounts, so
        the statement order below fixes the random stream.
        """
        self.steps_taken = 0
        self.done = False
        # Running volumes are in mL, starting from the titrated sample.
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        # Cumulative moles and mL of acid / base titrant added so far.
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        # Moles added by the most recent acid / base addition.
        self.last_acid_added = 0.0
        self.last_base_added = 0.0
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        # 999 candidate volumes: 0.01 * 1 ... 0.01 * 999.
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys()
                             for volume in self.addition_volumes]
        # epsilon and tol are only copied by env_copy in the visible part of this class.
        self.epsilon = 0
        # Scale of the direction penalty applied to the reward in step().
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4
        # Three model buffers: pKa drawn uniformly from [2, 6); ref_pKa keeps the
        # initial draw and pKa_std the per-buffer uncertainty used by
        # get_effective_pka_array().
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        # Buffer amounts drawn uniformly from [1e-6, 0.5); step() uses their mean.
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None
        # Normal priors per buffer over pKa (sd 0.5) and amount (sd 0.005);
        # step() reads the pKa prior standard deviations.
        self.priors = []
        for i in range(self.num_buffers):
            prior = {
                'pKa': norm(loc=self.pKa_list[i], scale=0.5),
                'total_moles': norm(loc=self.buffer_total_moles[i], scale=0.005)
            }
            self.priors.append(prior)
        # Parameters of the "ideal volume" heuristic in step().
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5
        self.last_measured_ph = None
        self.prev_measured_ph = None
        # Overshoot / oscillation bookkeeping shared with step() and select_best_action().
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.acid_type = None
        self.acid_params = None

    def get_state(self):
        """Return the float32 observation vector.

        Layout: [current pH, target pH, pH change since the previous step
        (0.0 when there is no previous pH), signed error (current - target),
        last added volume (0.0 if no action has been recorded yet)].
        """
        if self.previous_ph is None:
            delta = 0.0
        else:
            delta = self.current_ph - self.previous_ph
        signed_error = self.current_ph - self.target_ph
        last_volume = getattr(self, 'last_action_volume', 0.0)
        features = [self.current_ph, self.target_ph, delta, signed_error, last_volume]
        return np.array(features, dtype=np.float32)

    def initialize(self, init_pH: float, target_pH: float, max_steps: int, initial_volume: float = TITRATED_VOLUME) -> None:
        """Start a new episode.

        Draws a random acid type and pKa parameters from Python's ``random``
        module.  Stores the supplied initial/target pH and step budget, and
        clears all per-episode bookkeeping: added moles/volumes, measured-pH
        history, overshoot/oscillation state and last-action records.

        Note: ``acid_type``/``acid_params`` are only stored here.
        ``recalc_ph`` computes the model pH from the buffer arrays instead.
        ``recalc_ph`` and ``compute_required_volume`` use ``TITRATED_VOLUME``
        rather than ``initial_volume``.
        """
        self.acid_type = random.choice(["monoprotic", "diprotic", "triprotic"])
        if self.acid_type == "monoprotic":
            self.acid_params = random.uniform(2, 6)
        elif self.acid_type == "diprotic":
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            self.acid_params = [pKa1, pKa2]
        elif self.acid_type == "triprotic":
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            self.acid_params = [pKa1, pKa2, pKa3]
        self.initial_ph = init_pH
        self.current_ph = init_pH
        self.previous_ph = init_pH
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False
        self.total_volume = initial_volume
        self.previous_total_volume = initial_volume
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_measured_ph = init_pH
        self.prev_measured_ph = init_pH
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        # These previously survived from the last episode, leaking the old
        # last-added volume into the first observation of the new one.
        self.last_action_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0

    def safe_pow10(self, x: float) -> float:
        """Return 10**x with the exponent clamped to [-100, 100] to avoid overflow."""
        bounded_exponent = np.clip(x, -100, 100)
        return np.power(10, bounded_exponent)

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading.

        The previous reading moves to ``prev_measured_ph``; when there is none
        yet, the new value is used for both.  ``current_ph`` and
        ``last_measured_ph`` are set to ``pH``.
        """
        earlier = pH if self.last_measured_ph is None else self.last_measured_ph
        self.prev_measured_ph = earlier
        self.current_ph = pH
        self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend each buffer's reference pKa with its current estimate.

        Effective pKa = ref + w * (estimate - ref), where
        w = 0.2 * (1 - tanh(std)).  The weight shrinks as the pKa
        uncertainty ``pKa_std`` grows.  Returns a float64 array of length
        ``num_buffers``.
        """
        max_weight = 0.2
        steepness = 1.0

        def blend(idx):
            shrink = max_weight * (1 - np.tanh(steepness * self.pKa_std[idx]))
            return self.ref_pKa[idx] + shrink * (self.pKa_list[idx] - self.ref_pKa[idx])

        return np.array([blend(idx) for idx in range(self.num_buffers)], dtype=float)

    def compute_required_volume(self, max_volume: float = 10.0) -> float:
        """Estimate the titrant volume (mL) that brings the model pH to ``target_ph``.

        Base is used when the current pH is below the target and acid
        otherwise.  The secondary (``_2``) grade is used once
        ``use_secondary_reagents`` is set, and the primary (``_1``) grade
        before that.  ``brentq`` solves for the volume x in [0, max_volume]
        at which ``solve_pH`` (with the effective pKa values) hits the target.

        Args:
            max_volume: upper end of the search bracket (default 10.0, as before).

        Returns:
            The required volume, or 0.0 when no root can be bracketed in
            [0, max_volume] or the solve fails.
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()
        add_base = self.current_ph < self.target_ph
        grade = '2' if self.use_secondary_reagents else '1'
        reagent = ('dilute_base_' if add_base else 'dilute_acid_') + grade
        conc = self.reagents[reagent]

        def f_vol(x):
            # Model pH after adding x mL of the chosen titrant, minus the target.
            add_moles = conc * (x / 1000.0)
            new_base = self.base_added_moles + (add_moles if add_base else 0.0)
            new_acid = self.acid_added_moles + (0.0 if add_base else add_moles)
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            c_A_new = n_analyte / new_total_volume
            c_Na_new = new_base / new_total_volume
            c_HCl_new = new_acid / new_total_volume
            pH_new = self.solve_pH(c_A_new, c_Na_new, c_HCl_new, effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, max_volume)
        except Exception:
            # Deliberate best effort: brentq raises ValueError when f_vol(0) and
            # f_vol(max_volume) share a sign (target unreachable in the bracket).
            return 0.0

    def f(self, pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Charge-balance residual at a trial ``pH``.

        Returns [H+] + [Na+] - [OH-] - (acid anion charge) - [Cl-], using
        Kw = 1e-14 and treating ``c_HCl`` as the chloride concentration.  The
        residual is zero at the equilibrium pH.
        """
        hydronium = 10 ** (-pH)
        hydroxide = 1e-14 / hydronium
        anion_charge = calculate_acid_anion_charge(c_A, hydronium, pKa_list)
        residual = hydronium + c_Na - hydroxide - anion_charge - c_HCl
        return residual

    def solve_pH(self, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Find the pH in [0, 14] where the charge balance ``self.f`` is zero.

        Uses bisection with at most 100 halvings.  It returns early when
        |f| < 1e-10, and otherwise returns the midpoint of the final interval.

        The residual at ``lo`` is cached rather than recomputed every
        iteration.  The loop also stops once the interval can no longer change
        (adjacent floats); further iterations would be exact repeats.  Both
        changes leave the returned value identical.
        """
        lo, hi = 0.0, 14.0
        f_lo = self.f(lo, c_A, c_Na, c_HCl, pKa_list)
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = self.f(mid, c_A, c_Na, c_HCl, pKa_list)
            if abs(f_mid) < 1e-10:
                return mid
            prev_bounds = (lo, hi)
            if f_lo * f_mid < 0:
                hi = mid
            else:
                lo, f_lo = mid, f_mid
            if (lo, hi) == prev_bounds:
                # Fixed point: every later iteration would repeat this one.
                break
        return (lo + hi) / 2.0

    def step(self, action: tuple, mode: str = 'simulate') -> tuple:
        """Apply one titrant addition and return ``(pH, reward, done, info)``.

        Args:
            action: ``(reagent_name, volume)``.  The added moles are
                ``reagents[name] * volume / 1000``.
            mode: ``'simulate'`` recomputes the pH with ``recalc_ph``.  Any
                other value leaves ``current_ph`` untouched (presumably so a
                measured pH can be fed in via ``update_exp_ph`` — TODO confirm).

        Returns:
            ``(current_ph, reward, done, {})``.  A reward of -100 with done=True
            is returned, and ``self.done`` is set, in three cases: the reagent
            pushes the pH the wrong way, the pH is NaN or outside [0, 14], or
            an exception is raised.  Once done, calls are no-ops with reward 0.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            name = reagent.lower()
            self.last_action_volume = volume
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if 'acid' in name:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif 'base' in name:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            # Wrong-direction additions end the episode.  self.done is now set
            # too, so the env agrees with the returned flag.
            # NOTE(review): the rejected addition has already been applied above.
            if current_for_direction > self.target_ph and 'base' in name:
                self.done = True
                return self.current_ph, -100, True, {}
            if current_for_direction < self.target_ph and 'acid' in name:
                self.done = True
                return self.current_ph, -100, True, {}
            if mode == 'simulate':
                new_pH = self.recalc_ph()
                self.update_exp_ph(new_pH)
            # Repeated target crossings at the minimum volume trigger a switch
            # to the secondary titrants after 3 occurrences.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("检测到在最小滴加量下的pH振荡，累计次数：%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("达到连续震荡阈值，切换到次级试剂滴定。")
            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}
            current_error = abs(self.current_ph - self.target_ph)

            # "Ideal" volume heuristic.  The scale alpha grows when the pH
            # changed slowly, shrinks slightly with pKa prior uncertainty, and
            # is nudged by mean buffer amount (clipped to [0.95, 1.05]).
            ph_change = abs(self.current_ph - (self.prev_measured_ph if self.prev_measured_ph is not None else self.current_ph))
            bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
            uncertainties = [prior['pKa'].std() for prior in self.priors]
            avg_uncertainty = np.mean(uncertainties)
            max_uncertainty = 1.0
            uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
            buffer_mean = np.mean(self.buffer_total_moles)
            ref_buffer = 0.5
            buffering_factor = np.clip(1.0 + 0.1 * (buffer_mean - ref_buffer), 0.95, 1.05)
            alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
            required_vol = self.compute_required_volume()
            combined_value = current_error + 0.1 * required_vol
            min_vol = self.min_addition_volume
            max_vol = max(self.addition_volumes)
            ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

            # Reward: improvement in error, minus current error, squared
            # deviation from the ideal volume, and a growing time penalty.
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            # Direction penalty, measured against the post-addition pH (so it
            # penalises additions that carried the pH past the target).
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and 'acid' in name:
                penalty = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= penalty
            if self.target_ph < current_for_direction and 'base' in name:
                penalty = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= penalty

            # Overshoot tracking (steps_taken >= 1 here, so the former
            # `steps_taken > 0` guard was always true).
            if 'acid' in name:
                reagent_conc, last_added = self.reagents[reagent], self.last_acid_added
            elif 'base' in name:
                reagent_conc, last_added = self.reagents[reagent], self.last_base_added
            else:
                reagent_conc, last_added = 1.0, 0.0
            overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                               self.target_ph, reagent,
                                                               last_added, reagent_conc,
                                                               self.min_addition_volume)
            if overshoot_flag:
                self.overshoot_occurred = True
                self.overshoot_reagent = reagent
                # The volume cap only ever tightens.
                if new_thresh is not None and (self.overshoot_threshold is None or new_thresh < self.overshoot_threshold):
                    self.overshoot_threshold = new_thresh
            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.error("执行 step 时出现异常：%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot and propose a volume cap for later additions.

        An overshoot is either the pH crossing the target, or the absolute
        error growing.  The proposed cap is half of the last addition's volume
        (moles * 1000 / concentration), floored at ``min_addition``.
        ``reagent`` is accepted but unused.

        Returns:
            (True, cap) on overshoot, otherwise (False, None).
        """
        prev_offset = prev_ph - target_ph
        curr_offset = current_ph - target_ph
        crossed = prev_offset * curr_offset < 0
        worse = abs(curr_offset) > abs(prev_offset)
        if not (crossed or worse):
            return False, None
        added_volume = last_added_moles * 1000.0 / reagent_conc
        return True, max(added_volume / 2, min_addition)

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of the environment state.

        A fresh instance is constructed first, which consumes np.random just
        like the original did, and then every piece of state is copied over.
        Compared with the previous version, the copy now also carries the
        overshoot flag and reagent, ``initial_ph``, ``min_addition_volume`` and
        the last-action records.  Without them, a copy's reagent selection and
        observation could differ from the source environment's.
        """
        env_copied = PHAdjustmentEnv()
        env_copied.total_volume = self.total_volume
        env_copied.previous_total_volume = self.previous_total_volume
        env_copied.acid_added_moles = self.acid_added_moles
        env_copied.base_added_moles = self.base_added_moles
        env_copied.acid_volume = self.acid_volume
        env_copied.base_volume = self.base_volume
        env_copied.last_acid_added = self.last_acid_added
        env_copied.last_base_added = self.last_base_added
        env_copied.initial_ph = self.initial_ph
        env_copied.current_ph = self.current_ph
        env_copied.previous_ph = self.previous_ph
        env_copied.target_ph = self.target_ph
        env_copied.steps_taken = self.steps_taken
        env_copied.done = self.done
        env_copied.num_buffers = self.num_buffers
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        # Shallow copy: the prior dicts themselves are shared with the source.
        env_copied.priors = self.priors.copy()
        env_copied.epsilon = self.epsilon
        env_copied.direction_penalty_factor = self.direction_penalty_factor
        env_copied.tol = self.tol
        env_copied.reagents = self.reagents.copy()
        env_copied.min_addition_volume = self.min_addition_volume
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.max_steps = self.max_steps
        env_copied.vol_ideal_factor = self.vol_ideal_factor
        env_copied.ph_rate_threshold = self.ph_rate_threshold
        env_copied.ph_rate_bonus_factor = self.ph_rate_bonus_factor
        env_copied.last_measured_ph = self.last_measured_ph
        env_copied.prev_measured_ph = self.prev_measured_ph
        env_copied.overshoot_threshold = self.overshoot_threshold
        env_copied.overshoot_occurred = self.overshoot_occurred
        env_copied.overshoot_reagent = self.overshoot_reagent
        env_copied.oscillation_count = self.oscillation_count
        env_copied.use_secondary_reagents = self.use_secondary_reagents
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.acid_type = self.acid_type
        env_copied.acid_params = self.acid_params
        # last_action_volume only exists after a step (or initialize); keep
        # the copy's hasattr() behaviour identical to the source.
        if hasattr(self, 'last_action_volume'):
            env_copied.last_action_volume = self.last_action_volume
        return env_copied

    def recalc_ph(self) -> float:
        """Recompute the model pH from the cumulative titrant additions.

        Concentrations are moles divided by the current total volume in L:
        the titrated sample plus all acid and base added.  The effective buffer
        pKa values from ``get_effective_pka_array`` are passed to ``solve_pH``
        as a list.
        """
        volume_litres = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        acid_conc = analyte_moles / volume_litres
        sodium_conc = self.base_added_moles / volume_litres
        chloride_conc = self.acid_added_moles / volume_litres
        effective_pkas = self.get_effective_pka_array().tolist()
        return self.solve_pH(acid_conc, sodium_conc, chloride_conc, effective_pkas)

    def select_best_action(self) -> tuple:
        """Pick the next ``(reagent, volume)`` action; returns ``(action, self.done)``.

        Reagent: after an overshoot, the reagent opposite to the one that
        overshot; otherwise base when the pH is below target and acid when it is
        not. The suffix ``_2`` (secondary) or ``_1`` (primary) follows
        ``use_secondary_reagents``. Only in primary mode is the overshoot flag
        consumed (reset) here.

        Volume: the candidate closest to a tanh-shaped "ideal" volume, restricted
        to volumes <= ``overshoot_threshold`` whenever at least one qualifies.
        """
        ph_now = self.current_ph if self.last_measured_ph is None else self.last_measured_ph
        tier = '2' if self.use_secondary_reagents else '1'

        if self.overshoot_occurred:
            # Counter-titrate with the opposite reagent of the one that overshot.
            countering = 'acid_' if 'base' in self.overshoot_reagent.lower() else 'base_'
            wanted = countering + tier
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        elif ph_now < self.target_ph:
            wanted = 'dilute_base_' + tier
        else:
            wanted = 'dilute_acid_' + tier

        allowed = {name for name in self.reagents.keys() if wanted in name.lower()}
        candidates = [act for act in self.action_space if act[0] in allowed]
        if self.overshoot_threshold is not None:
            capped = [act for act in candidates if act[1] <= self.overshoot_threshold]
            if capped:
                candidates = capped

        # Ideal-volume heuristic: grows with distance to target (plus the
        # model-required volume), scaled by pH-rate, uncertainty and buffering gains.
        distance = abs(ph_now - self.target_ph)
        last_ph = self.prev_measured_ph if self.prev_measured_ph is not None else ph_now
        drift = abs(ph_now - last_ph)
        rate_gain = 1 + self.ph_rate_bonus_factor * (
            1 - min(drift, self.ph_rate_threshold) / self.ph_rate_threshold)
        mean_sd = np.mean([p['pKa'].std() for p in self.priors])
        sd_cap = 1.0
        certainty_gain = 1 - 0.1 * min(mean_sd / sd_cap, 1)
        buffer_gain = np.clip(1.0 + 0.1 * (np.mean(self.buffer_total_moles) - 0.5), 0.95, 1.05)
        alpha = self.vol_ideal_factor * rate_gain * certainty_gain * buffer_gain
        combined = distance + 0.1 * self.compute_required_volume()
        v_lo = self.min_addition_volume
        v_hi = max(self.addition_volumes)
        target_volume = v_lo + (v_hi - v_lo) * np.tanh(alpha * combined)

        chosen = min(candidates, key=lambda act: abs(act[1] - target_volume))
        return chosen, self.done

    def sample_parameters(self) -> tuple:
        """Draw one (pKa, total moles) sample from every buffer prior.

        Returns two lists aligned with ``self.priors``. For each prior the pKa
        is drawn before the total moles, which fixes the order in which the
        random stream is consumed.
        """
        draws = [(prior['pKa'].rvs(), prior['total_moles'].rvs()) for prior in self.priors]
        pka_draws = [pair[0] for pair in draws]
        mole_draws = [pair[1] for pair in draws]
        return pka_draws, mole_draws

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH of the current mixture under sampled buffer parameters.

        Works on a scratch copy (``env_copy``) so this instance is untouched.
        ``action`` is not used: the prediction is for the current state.
        """
        scratch = self.env_copy()
        scratch.pKa_list, scratch.buffer_total_moles = np.array(sampled_pKa), np.array(sampled_total_moles)
        return scratch.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit the buffer priors to ``observed_ph`` by importance resampling.

        1000 particles are drawn from the current priors and weighted by a
        Gaussian likelihood (sd 0.01 pH) of the observation. After resampling
        with replacement, each buffer's pKa and total-moles priors are replaced
        by normals with the resampled mean and std (+1e-3 floor). ``pKa_list``,
        ``buffer_total_moles`` and ``pKa_std`` are updated to match.
        """
        n_particles = 1000
        noise_sd = 0.01
        pka_draws, mole_draws, likelihoods = [], [], []
        for _ in range(n_particles):
            pka_s, moles_s = self.sample_parameters()
            ph_pred = self.predict_ph(action, pka_s, moles_s)
            likelihoods.append(norm.pdf(observed_ph, loc=ph_pred, scale=noise_sd))
            pka_draws.append(pka_s)
            mole_draws.append(moles_s)

        # Small floor keeps the weights valid even if every likelihood underflows.
        w = np.array(likelihoods) + 1e-10
        w /= np.sum(w)
        picks = np.random.choice(range(n_particles), size=n_particles, p=w)

        fitted = []
        for b in range(self.num_buffers):
            pka_col = np.array([pka_draws[j][b] for j in picks])
            mole_col = np.array([mole_draws[j][b] for j in picks])
            fitted.append((np.mean(pka_col), np.std(pka_col) + 1e-3,
                           np.mean(mole_col), np.std(mole_col) + 1e-3))

        for b, (pka_mu, pka_sd, mole_mu, mole_sd) in enumerate(fitted):
            self.priors[b]['pKa'] = norm(loc=pka_mu, scale=pka_sd)
            self.priors[b]['total_moles'] = norm(loc=mole_mu, scale=mole_sd)
            self.pKa_list[b] = pka_mu
            self.buffer_total_moles[b] = mole_mu
            self.pKa_std[b] = pka_sd

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Return ``(next_action, done)`` after applying ``action``.

        If ``observed_ph`` is already within 0.1 of the target, the episode is
        marked done and ``(None, True)`` is returned. Otherwise the action is
        simulated, the posteriors are refit to the simulated pH, and the next
        best action is chosen.
        """
        close_enough = abs(observed_ph - self.target_ph) < 0.1
        if close_enough:
            self.done = True
            return None, True
        simulated_ph, _reward, finished, _info = self.step(action, mode='simulate')
        self.update_posteriors(action, simulated_ph)
        follow_up = self.select_best_action()[0]
        return follow_up, finished

# 辅助函数：计算初始 pH（模拟代码1的行为）
def calculate_initial_ph(acid_type: str, acid_params, env: PHAdjustmentEnv):
    """Return the pH of the analyte before any titrant has been added.

    The analyte is ANALYTE_CONC mol/L (its moles over the initial
    TITRATED_VOLUME), with no added base or strong acid. ``acid_params`` is a
    single pKa for ``"monoprotic"`` and a list of pKa values for
    ``"diprotic"`` / ``"triprotic"``. Any other ``acid_type`` returns 7.0.
    The charge balance itself is solved by ``env.solve_pH``.
    """
    if acid_type not in ("monoprotic", "diprotic", "triprotic"):
        return 7.0
    pka_values = [acid_params] if acid_type == "monoprotic" else acid_params
    volume_l = TITRATED_VOLUME / 1000.0
    analyte_conc = (volume_l * ANALYTE_CONC) / volume_l
    return env.solve_pH(analyte_conc, 0.0, 0.0, pka_values)

# 修改后的主程序
def main():
    """Run a batch of simulated titration experiments and print a summary.

    Each experiment draws a random target pH and a random analyte acid, then
    computes the initial pH from that acid. It runs the greedy
    ``select_best_action`` / ``step`` loop until the environment finishes.
    Success means the final pH is within 0.1 of the target.
    """
    # Imported locally so this function does not depend on some other cell
    # having imported ``random`` (the first cell's visible imports do not).
    import random

    # Fixed seeds for reproducibility (same as "code 1").
    seed = 555
    random.seed(seed)
    np.random.seed(seed)

    num_experiments = 3000
    success_count = 0
    steps_success = []

    for exp in range(num_experiments):
        # Random target pH and random analyte acid.
        target_ph = round(random.uniform(2, 11), 2)
        env = PHAdjustmentEnv()

        env.acid_type = random.choice(["monoprotic", "diprotic", "triprotic"])
        if env.acid_type == "monoprotic":
            env.acid_params = random.uniform(2, 6)
        elif env.acid_type == "diprotic":
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            env.acid_params = [pKa1, pKa2]
        elif env.acid_type == "triprotic":
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            env.acid_params = [pKa1, pKa2, pKa3]
        initial_ph = calculate_initial_ph(env.acid_type, env.acid_params, env)

        # initialize() may redraw the acid; restore the one initial_ph was
        # computed from so the printed parameters match the initial pH.
        acid_type, acid_params = env.acid_type, env.acid_params
        env.initialize(init_pH=initial_ph, target_pH=target_ph, max_steps=MAX_STEPS, initial_volume=TITRATED_VOLUME)
        env.acid_type, env.acid_params = acid_type, acid_params

        print(f"==== 实验 {exp+1} 开始 ====")
        initial_state = env.get_state()
        print(f"初始状态: {np.round(initial_state, 2)}")
        print(f"酸类型: {env.acid_type}, 参数: {env.acid_params}, 目标 pH: {env.target_ph}")
        print("状态-动作-试剂对:")

        trace = []
        action, _ = env.select_best_action()
        while not env.done:
            _, _, done, _ = env.step(action, mode='simulate')
            trace.append((env.get_state(), action, action[0]))  # record the reagent name
            # step() can report done=True without setting env.done (e.g. a
            # rejected wrong-direction action); honour the returned flag so
            # the loop cannot spin forever on the same action.
            if done:
                break
            action, _ = env.select_best_action()

        for i, (s, a, reagent) in enumerate(trace, start=1):
            s_formatted = np.round(s, 2)
            print(f"  Step {i}: State = {s_formatted}, Action = {a[1]:.4f}, Reagent = {reagent}")
        print(f"实验结束，共用步数: {env.steps_taken}, 最终 pH: {env.current_ph:.2f}\n")

        if abs(env.current_ph - env.target_ph) < 0.1:
            success_count += 1
            steps_success.append(env.steps_taken)

    success_rate = success_count / num_experiments * 100
    avg_steps = np.mean(steps_success) if steps_success else 0
    print("总实验数: {}, 成功实验数: {}, 成功率: {:.2f}%, 成功实验平均步数: {:.2f}".format(
        num_experiments, success_count, success_rate, avg_steps))

# Entry point: run the full batch of experiments defined in main().
if __name__ == '__main__':
    main()

In [None]:
# 贝叶斯多日志

In [None]:
import numpy as np
import math
import random
import logging
from scipy.stats import norm
from scipy.optimize import brentq
import csv

# Configure logging (INFO level).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Global parameters (the original note says these match "code 2").
# NOTE(review): unlike the earlier cell (HCL_CONC2 = NAOH_CONC2 = 0.01), the
# secondary reagents here are 0.1 mol/L, identical to the primary ones, so
# switching to secondary reagents does not change titrant strength. This is
# presumably intentional (main() logs to a file named for "concentrated
# acid/base only") -- confirm.
TITRATED_VOLUME = 11.0  # volume of the solution being titrated (mL)
ANALYTE_CONC = 0.1      # acid concentration of the titrated solution (mol/L)
HCL_CONC1 = 0.1         # primary HCl titrant (mol/L)
HCL_CONC2 = 0.1         # secondary HCl titrant (mol/L)
NAOH_CONC1 = 0.1        # primary NaOH titrant (mol/L)
NAOH_CONC2 = 0.1        # secondary NaOH titrant (mol/L)
MAX_STEPS = 50          # maximum titration steps per experiment

# Reagent name -> concentration (mol/L). The environment classifies reagents by
# substring ("acid"/"base", and "_1"/"_2" for primary/secondary), so new names
# must follow the same pattern.
REAGENTS = {
    'dilute_acid_1': HCL_CONC1,
    'dilute_acid_2': HCL_CONC2,
    'dilute_base_1': NAOH_CONC1,
    'dilute_base_2': NAOH_CONC2,
}

# pH 计算函数（直接复用代码2）
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total negative charge (mol/L) carried by a polyprotic acid's anions.

    For an acid H_nA at formal concentration ``c_A`` with stepwise dissociation
    constants K_1..K_n (``10**-pKa``, exponent clipped to [-100, 100]), let
    ``r_k = (K_1 * ... * K_k) / H**k``. Then
    ``[H_nA] = c_A / (1 + r_1 + ... + r_n)`` and the species carrying charge
    ``-k`` has concentration ``[H_nA] * r_k``. The result is
    ``sum_k k * [H_nA] * r_k``.

    When ``H == 0`` every ``H**k`` is taken as +inf, so all r_k vanish and the
    charge is 0.
    """
    stepwise_K = [np.power(10, np.clip(-pKa, -100, 100)) for pKa in pKa_list]

    def powered_H(order):
        # H**order, or +inf when H is zero (avoids a division by zero below).
        return np.power(H, order, where=H != 0, out=np.array(np.inf))

    ratios = []
    running_K = 1.0
    for order, K_i in enumerate(stepwise_K, start=1):
        running_K *= K_i
        ratios.append(running_K / powered_H(order))

    denominator = 1.0
    for r in ratios:
        denominator += r
    undissociated = c_A / denominator if denominator != 0 else 0.0

    charge = 0.0
    for order, r in enumerate(ratios, start=1):
        charge += order * (undissociated * r)
    return charge

class PHAdjustmentEnv:
    """Simulated acid/base titration environment with a Bayesian buffer model.

    The solution is TITRATED_VOLUME mL of analyte acid at ANALYTE_CONC mol/L,
    plus the cumulative acid and base (moles and mL) added by actions. pH is
    found by bisection on the charge balance ``f``. The balance uses the
    "effective" pKa values from ``get_effective_pka_array``, which blend the
    reference pKa values with the current posterior means.

    Actions are ``(reagent_name, volume_mL)`` tuples from ``action_space``.
    Reagent names are classified by substring: ``"acid"`` / ``"base"`` for
    the direction, and ``"_1"`` / ``"_2"`` for primary / secondary reagents.

    NOTE(review): ``acid_type`` / ``acid_params`` are recorded and copied but
    not used by the pH model, which uses ``pKa_list`` -- confirm intended.
    """

    def __init__(self):
        self.steps_taken = 0
        self.done = False
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0
        self.last_action_volume = 0.0  # volume (mL) of the most recent action, reported by get_state
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        # Candidate volumes: 0.01, 0.02, ..., 9.99 mL.
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [(reagent, volume) for reagent in self.reagents
                             for volume in self.addition_volumes]
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4
        self.num_buffers = 3
        # Random buffer model; both draws consume the global NumPy RNG.
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None
        self.priors = [
            {
                'pKa': norm(loc=self.pKa_list[i], scale=0.5),
                'total_moles': norm(loc=self.buffer_total_moles[i], scale=0.005),
            }
            for i in range(self.num_buffers)
        ]
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5
        self.last_measured_ph = None
        self.prev_measured_ph = None
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.acid_type = None
        self.acid_params = None

    @staticmethod
    def _draw_random_acid() -> tuple:
        """Draw a random ``(acid_type, acid_params)`` pair using the ``random`` module."""
        acid_type = random.choice(["monoprotic", "diprotic", "triprotic"])
        if acid_type == "monoprotic":
            return acid_type, random.uniform(2, 6)
        if acid_type == "diprotic":
            return acid_type, [random.uniform(2, 4), random.uniform(4, 7)]
        return acid_type, [random.uniform(2, 4), random.uniform(4, 6), random.uniform(6, 8)]

    def get_state(self):
        """Return the float32 observation.

        The vector is [current pH, target pH, pH change since the previous
        step, signed error, last action volume].
        """
        pH_delta = self.current_ph - self.previous_ph if self.previous_ph is not None else 0.0
        error = self.current_ph - self.target_ph
        return np.array([self.current_ph, self.target_ph, pH_delta, error,
                         getattr(self, 'last_action_volume', 0.0)], dtype=np.float32)

    def initialize(self, init_pH: float, target_pH: float, max_steps: int, initial_volume: float = TITRATED_VOLUME) -> None:
        """Reset the episode to ``init_pH`` with goal ``target_pH``.

        A caller-chosen analyte acid (``acid_type`` / ``acid_params``) is kept:
        ``main`` sets it and derives ``init_pH`` from it before calling this
        method. Overwriting it would make the logged parameters disagree with
        the initial pH. A random acid is drawn only when none has been set.
        """
        if self.acid_type is None or self.acid_params is None:
            self.acid_type, self.acid_params = self._draw_random_acid()
        self.initial_ph = init_pH
        self.current_ph = init_pH
        self.previous_ph = init_pH
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False
        self.total_volume = initial_volume
        self.previous_total_volume = initial_volume
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_measured_ph = init_pH
        self.prev_measured_ph = init_pH
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return ``10**x`` with the exponent clipped to [-100, 100] to avoid overflow."""
        return np.power(10, np.clip(x, -100, 100))

    def update_exp_ph(self, pH: float) -> None:
        """Record a newly measured pH.

        The previous measurement is shifted into ``prev_measured_ph``; on the
        first call ``pH`` itself is used.
        """
        self.prev_measured_ph = self.last_measured_ph if self.last_measured_ph is not None else pH
        self.current_ph = pH
        self.last_measured_ph = pH

    def _effective_pka(self, pKa_values) -> np.ndarray:
        """Blend ``pKa_values`` with ``ref_pKa``, trusting confident posteriors more.

        weight_i = 0.2 * (1 - tanh(pKa_std[i])), so the effective pKa moves at
        most 20% of the way from the reference value towards ``pKa_values[i]``.
        It moves less when the posterior std is large.
        """
        weight_max = 0.2
        k = 1.0
        pKa_eff_array = np.zeros(self.num_buffers)
        for i in range(self.num_buffers):
            weight_i = weight_max * (1 - np.tanh(k * self.pKa_std[i]))
            pKa_eff_array[i] = self.ref_pKa[i] + weight_i * (pKa_values[i] - self.ref_pKa[i])
        return pKa_eff_array

    def get_effective_pka_array(self) -> np.ndarray:
        """Return the effective pKa values for the current posterior means ``pKa_list``."""
        return self._effective_pka(self.pKa_list)

    def compute_required_volume(self) -> float:
        """Return the mL of the active titrant that the model says reaches the target pH.

        Base is used when the current pH is below target, acid otherwise; the
        secondary reagent applies when ``use_secondary_reagents`` is set. The
        root is searched on [0, 10] mL, and 0.0 is returned when brentq cannot
        bracket or converge.
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()
        adding_base = self.current_ph < self.target_ph
        tier = '2' if self.use_secondary_reagents else '1'
        reagent = ('dilute_base_' if adding_base else 'dilute_acid_') + tier
        conc = self.reagents[reagent]

        def f_vol(x):
            # Model pH minus target after adding x mL of the chosen reagent.
            add_moles = conc * (x / 1000.0)
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            if adding_base:
                base_moles, acid_moles = self.base_added_moles + add_moles, self.acid_added_moles
            else:
                base_moles, acid_moles = self.base_added_moles, self.acid_added_moles + add_moles
            pH_new = self.solve_pH(n_analyte / new_total_volume,
                                   base_moles / new_total_volume,
                                   acid_moles / new_total_volume,
                                   effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, 10)
        except (ValueError, RuntimeError):
            # ValueError: target not bracketed in [0, 10] mL; RuntimeError: no convergence.
            return 0.0

    def f(self, pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Charge-balance residual [H+] + [Na+] - [OH-] - [anion charge] - [Cl-] at ``pH``."""
        H = 10 ** (-pH)
        Kw = 1e-14
        OH = Kw / H
        acid_anion_charge = calculate_acid_anion_charge(c_A, H, pKa_list)
        return H + c_Na - OH - acid_anion_charge - c_HCl

    def solve_pH(self, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Find the pH in [0, 14] where ``f`` vanishes, by up to 100 bisection steps.

        Assumes ``f`` changes sign on [0, 14]; otherwise the result drifts to
        one end of the interval.
        """
        lo, hi = 0.0, 14.0
        f_lo = self.f(lo, c_A, c_Na, c_HCl, pKa_list)  # cached: only changes when lo moves
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = self.f(mid, c_A, c_Na, c_HCl, pKa_list)
            if abs(f_mid) < 1e-10:
                return mid
            if f_lo * f_mid < 0:
                hi = mid
            else:
                lo, f_lo = mid, f_mid
        return (lo + hi) / 2.0

    def _ideal_volume(self, ph: float) -> float:
        """Return the heuristic ideal next addition volume (mL) when the pH is ``ph``.

        tanh maps (distance to target + 0.1 * model-required volume) onto
        [min, max] addition volume. The slope is scaled by gains for slow pH
        change, mean posterior pKa std, and mean buffer moles.
        """
        error = abs(ph - self.target_ph)
        prev = self.prev_measured_ph if self.prev_measured_ph is not None else ph
        ph_change = abs(ph - prev)
        bonus_factor = 1 + self.ph_rate_bonus_factor * (
            1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        avg_uncertainty = np.mean([prior['pKa'].std() for prior in self.priors])
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        ref_buffer = 0.5
        buffering_factor = np.clip(1.0 + 0.1 * (np.mean(self.buffer_total_moles) - ref_buffer), 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        combined_value = error + 0.1 * self.compute_required_volume()
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        return min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

    def step(self, action: tuple, mode: str = 'simulate') -> tuple:
        """Apply ``action = (reagent, volume_mL)`` and return ``(pH, reward, done, info)``.

        In ``'simulate'`` mode the new pH is computed by ``recalc_ph``. In other
        modes the caller is expected to supply it via ``update_exp_ph``.

        Adding base while the last measured pH is above target, or acid while it
        is below, is rejected. The solution is left unchanged, the episode ends
        (``self.done`` is set), and the reward is -100. Out-of-range pH and any
        exception also end the episode with -100.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            self.last_action_volume = volume
            reagent_lower = reagent.lower()
            is_acid = 'acid' in reagent_lower
            is_base = 'base' in reagent_lower
            added_moles = self.reagents[reagent] * (volume / 1000.0)

            # Direction check on the pH *before* this action. It runs before any
            # mutation, and it sets self.done so loops on env.done terminate.
            current_for_direction = (self.last_measured_ph if self.last_measured_ph is not None
                                     else self.current_ph)
            if ((current_for_direction > self.target_ph and is_base)
                    or (current_for_direction < self.target_ph and is_acid)):
                self.done = True
                return self.current_ph, -100, True, {}

            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if is_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif is_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles
            if mode == 'simulate':
                self.update_exp_ph(self.recalc_ph())

            # At the minimum volume, an error sign flip with a jump > 0.1 pH
            # counts as an oscillation; 3 of them switch to secondary reagents.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                crossed = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
                if crossed and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("检测到在最小滴加量下的pH振荡，累计次数：%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("达到连续震荡阈值，切换到次级试剂滴定。")

            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            # Reward: improvement minus remaining error, a quadratic cost for
            # deviating from the ideal volume, and a growing time penalty.
            current_error = abs(self.current_ph - self.target_ph)
            ideal_volume = self._ideal_volume(self.current_ph)
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            # Direction penalty, evaluated on the latest measured pH (halved far from target).
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and is_acid:
                reward -= dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
            if self.target_ph < current_for_direction and is_base:
                reward -= dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)

            # Overshoot tracking: remember the offending reagent and tighten the
            # global volume cap.
            if is_acid:
                reagent_conc, last_added = self.reagents[reagent], self.last_acid_added
            elif is_base:
                reagent_conc, last_added = self.reagents[reagent], self.last_base_added
            else:
                reagent_conc, last_added = 1.0, 0.0
            overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                               self.target_ph, reagent,
                                                               last_added, reagent_conc,
                                                               self.min_addition_volume)
            if overshoot_flag:
                self.overshoot_occurred = True
                self.overshoot_reagent = reagent
                if new_thresh is not None and (self.overshoot_threshold is None
                                               or new_thresh < self.overshoot_threshold):
                    self.overshoot_threshold = new_thresh

            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            # Deliberate best-effort boundary: any failure ends the episode with a penalty.
            logging.error("执行 step 时出现异常：%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Return ``(overshoot, new_threshold)`` for the last addition.

        An overshoot is a sign change of the error or a growth in its magnitude.
        The threshold is then half the added volume (moles * 1000 / conc, in mL),
        with ``min_addition`` as the floor.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of this environment's state.

        Constructing the copy calls ``PHAdjustmentEnv()``, which draws from the
        global NumPy RNG; those values are then overwritten. Prior dicts are
        copied individually so posterior updates on the copy cannot leak back.
        """
        env_copied = PHAdjustmentEnv()
        for name in ('total_volume', 'previous_total_volume', 'acid_added_moles',
                     'base_added_moles', 'acid_volume', 'base_volume', 'last_acid_added',
                     'last_base_added', 'last_action_volume', 'initial_ph', 'current_ph',
                     'previous_ph', 'target_ph', 'steps_taken', 'done', 'num_buffers',
                     'epsilon', 'direction_penalty_factor', 'tol', 'min_addition_volume',
                     'max_steps', 'vol_ideal_factor', 'ph_rate_threshold',
                     'ph_rate_bonus_factor', 'last_measured_ph', 'prev_measured_ph',
                     'overshoot_threshold', 'overshoot_occurred', 'overshoot_reagent',
                     'oscillation_count', 'use_secondary_reagents', 'acid_type'):
            setattr(env_copied, name, getattr(self, name))
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.reagents = self.reagents.copy()
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.acid_params = list(self.acid_params) if isinstance(self.acid_params, list) else self.acid_params
        return env_copied

    def _ph_with_pka(self, pKa_values) -> float:
        """Solve the pH of the current mixture for the given (effective) pKa values."""
        V_total = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        c_A = n_analyte / V_total
        c_Na = self.base_added_moles / V_total
        c_HCl = self.acid_added_moles / V_total
        return self.solve_pH(c_A, c_Na, c_HCl, pKa_values)

    def recalc_ph(self) -> float:
        """Return the model pH of the current mixture using the effective pKa values."""
        return self._ph_with_pka(self.get_effective_pka_array().tolist())

    def select_best_action(self) -> tuple:
        """Choose the next ``(reagent, volume)`` action; returns ``(action, self.done)``.

        Reagent choice:
        - After an overshoot, use the reagent opposite to the one that overshot.
        - Otherwise use base below target and acid at or above target.
        - The tier (``_1`` / ``_2``) follows ``use_secondary_reagents``.
        - In primary mode the overshoot flag is consumed (reset) here.

        Volume: the candidate closest to ``_ideal_volume``. It is restricted to
        volumes <= ``overshoot_threshold`` when any such candidate exists.
        """
        current_for_direction = (self.last_measured_ph if self.last_measured_ph is not None
                                 else self.current_ph)
        tier = '2' if self.use_secondary_reagents else '1'
        if self.overshoot_occurred:
            pattern = ('acid_' if 'base' in self.overshoot_reagent.lower() else 'base_') + tier
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        elif current_for_direction < self.target_ph:
            pattern = 'dilute_base_' + tier
        else:
            pattern = 'dilute_acid_' + tier
        allowed_reagent = {r for r in self.reagents if pattern in r.lower()}
        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped
        ideal_volume = self._ideal_volume(current_for_direction)
        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one (pKa, total moles) sample per buffer prior; returns two lists."""
        sampled_pKa = []
        sampled_total_moles = []
        for prior in self.priors:
            sampled_pKa.append(prior['pKa'].rvs())
            sampled_total_moles.append(prior['total_moles'].rvs())
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the current mixture's pH if the posterior pKa means were ``sampled_pKa``.

        Equivalent to solving on a copy whose ``pKa_list`` is ``sampled_pKa``,
        but without building one. ``action`` and ``sampled_total_moles`` are
        accepted for interface compatibility. They do not affect the result:
        the prediction is for the current state, and the charge-balance model
        does not use buffer moles.
        """
        return self._ph_with_pka(self._effective_pka(np.array(sampled_pKa)).tolist())

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit the buffer priors to ``observed_ph`` by importance resampling.

        1000 particles drawn from the priors are weighted by a Gaussian
        likelihood (sd 0.01 pH) and resampled with replacement. Each buffer's
        priors become normals with the resampled mean and std (+1e-3 floor).

        NOTE(review): predictions depend only on the pKa draws (see
        ``predict_ph``), so the total-moles posterior is not actually informed
        by the observation.
        """
        num_particles = 1000
        particles = []
        weights = []
        for _ in range(num_particles):
            sampled_pKa, sampled_total_moles = self.sample_parameters()
            predicted_ph = self.predict_ph(action, sampled_pKa, sampled_total_moles)
            weights.append(norm.pdf(observed_ph, loc=predicted_ph, scale=0.01))
            particles.append((sampled_pKa, sampled_total_moles))
        # Small floor keeps the weights valid even if every likelihood underflows.
        weights = np.array(weights) + 1e-10
        weights /= np.sum(weights)
        indices = np.random.choice(num_particles, size=num_particles, p=weights)
        fitted = []
        for i in range(self.num_buffers):
            pKa_samples = np.array([particles[idx][0][i] for idx in indices])
            total_moles_samples = np.array([particles[idx][1][i] for idx in indices])
            fitted.append((np.mean(pKa_samples), np.std(pKa_samples) + 1e-3,
                           np.mean(total_moles_samples), np.std(total_moles_samples) + 1e-3))
        for i, (mean_pKa, std_pKa, mean_moles, std_moles) in enumerate(fitted):
            self.priors[i]['pKa'] = norm(loc=mean_pKa, scale=std_pKa)
            self.priors[i]['total_moles'] = norm(loc=mean_moles, scale=std_moles)
            self.pKa_list[i] = mean_pKa
            self.buffer_total_moles[i] = mean_moles
            self.pKa_std[i] = std_pKa

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Return ``(next_action, done)``.

        If ``observed_ph`` is within 0.1 of the target, the episode is marked
        done and ``(None, True)`` is returned. Otherwise ``action`` is
        simulated, the posteriors are refit, and the next action is chosen.
        """
        if abs(observed_ph - self.target_ph) < 0.1:
            self.done = True
            return None, True
        new_ph, reward, done, _ = self.step(action, mode='simulate')
        self.update_posteriors(action, new_ph)
        next_action, _ = self.select_best_action()
        return next_action, done

# 辅助函数：计算初始 pH（模拟代码1的行为）
def calculate_initial_ph(acid_type: str, acid_params, env: PHAdjustmentEnv):
    """Return the pH of the un-titrated analyte solution.

    ``acid_params`` is a single pKa for ``"monoprotic"`` and a list of pKa
    values for ``"diprotic"``/``"triprotic"``. The analyte concentration is
    ``ANALYTE_CONC`` in ``TITRATED_VOLUME`` mL with no added base or HCl; the
    pH is solved with ``env.solve_pH``. Unknown acid types give 7.0.
    """
    if acid_type == "monoprotic":
        pka_values = [acid_params]
    elif acid_type in ("diprotic", "triprotic"):
        pka_values = acid_params
    else:
        return 7.0

    volume_litres = TITRATED_VOLUME / 1000.0
    analyte_moles = volume_litres * ANALYTE_CONC
    return env.solve_pH(analyte_moles / volume_litres, 0.0, 0.0, pka_values)

# 修改后的主程序
def main():
    """Run ``num_experiments`` simulated titrations and record each one.

    For every experiment a random target pH (2-11) and a random mono-, di- or
    triprotic analyte are drawn, the initial pH is computed from that analyte,
    and the environment is stepped with ``select_best_action``/``step`` until
    it reports ``done``. The per-step trace is written to
    ``贝叶斯只用浓酸碱.txt`` and a one-row summary per experiment to
    ``experiment_summary.csv``. Overall success statistics
    (|final pH - target| < 0.1) are appended to the log and printed.
    """
    # Imported locally so this function does not depend on some other cell
    # having imported them (the notebook's header cell imports neither).
    import csv
    import random

    # Fixed random seed for reproducibility (same seed as code 1).
    seed = 555
    random.seed(seed)
    np.random.seed(seed)

    num_experiments = 3000
    success_count = 0
    steps_success = []

    # Text log of every experiment plus a CSV summary.
    with open('贝叶斯只用浓酸碱.txt', 'w', encoding='utf-8') as log_file, \
         open('experiment_summary.csv', 'w', newline='', encoding='utf-8') as summary_file:

        csv_writer = csv.writer(summary_file)
        csv_writer.writerow(['Experiment', 'Acid_Type', 'Acid_Params', 'Initial_pH', 'Target_pH', 'Final_pH', 'Steps_Taken', 'Success'])

        for exp in range(num_experiments):
            # Random target pH.
            target_ph = round(random.uniform(2, 11), 2)
            env = PHAdjustmentEnv()

            # Random analyte (acid type + pKa values) and its initial pH.
            env.acid_type = random.choice(["monoprotic", "diprotic", "triprotic"])
            if env.acid_type == "monoprotic":
                env.acid_params = random.uniform(2, 6)
            elif env.acid_type == "diprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 7)
                env.acid_params = [pKa1, pKa2]
            elif env.acid_type == "triprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 6)
                pKa3 = random.uniform(6, 8)
                env.acid_params = [pKa1, pKa2, pKa3]
            initial_ph = calculate_initial_ph(env.acid_type, env.acid_params, env)

            # initialize() may draw a fresh random acid and overwrite these
            # attributes; restore the analyte initial_ph was computed from so
            # the log and CSV describe the acid that was actually used.
            acid_type, acid_params = env.acid_type, env.acid_params
            env.initialize(init_pH=initial_ph, target_pH=target_ph, max_steps=MAX_STEPS, initial_volume=TITRATED_VOLUME)
            env.acid_type, env.acid_params = acid_type, acid_params

            log_file.write(f"==== 实验 {exp+1} 开始 ====\n")
            initial_state = env.get_state()
            log_file.write(f"初始状态: {np.round(initial_state, 2)}\n")
            log_file.write(f"酸类型: {env.acid_type}, 参数: {env.acid_params}, 目标 pH: {env.target_ph}\n")
            log_file.write("状态-动作-试剂对:\n")

            trace = []
            action, _ = env.select_best_action()
            # Safety cap: a step that returns without setting env.done (and
            # without counting a step) must not hang the whole batch.
            max_iterations = 10 * MAX_STEPS
            for _iteration in range(max_iterations):
                if env.done:
                    break
                env.step(action, mode='simulate')
                trace.append((env.get_state(), action, action[0]))  # record reagent name
                action, _ = env.select_best_action()

            for i, (s, a, reagent) in enumerate(trace, start=1):
                s_formatted = np.round(s, 2)
                log_file.write(f"  Step {i}: State = {s_formatted}, Action = {a[1]:.4f}, Reagent = {reagent}\n")
            log_file.write(f"实验结束，共用步数: {env.steps_taken}, 最终 pH: {env.current_ph:.2f}\n\n")

            # Success: final pH within 0.1 of the target.
            success = abs(env.current_ph - env.target_ph) < 0.1
            if success:
                success_count += 1
                steps_success.append(env.steps_taken)

            acid_params_str = f"{env.acid_params}" if isinstance(env.acid_params, list) else f"{env.acid_params:.2f}"
            csv_writer.writerow([exp+1, env.acid_type, acid_params_str, f"{initial_ph:.2f}", f"{env.target_ph:.2f}",
                                f"{env.current_ph:.2f}", env.steps_taken, 'Yes' if success else 'No'])

        success_rate = success_count / num_experiments * 100
        avg_steps = np.mean(steps_success) if steps_success else 0
        summary_stats = f"总实验数: {num_experiments}, 成功实验数: {success_count}, 成功率: {success_rate:.2f}%, 成功实验平均步数: {avg_steps:.2f}\n"
        log_file.write(summary_stats)
        print(summary_stats)

if __name__ == '__main__':
    main()

In [None]:
import numpy as np
import math
import random
import logging
from scipy.stats import norm
from scipy.optimize import brentq
import sys
from io import StringIO

# Logging configuration (INFO level).
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Global parameters.
# NOTE: in this cell the secondary ("_2") reagents have the same 0.1 mol/L
# concentration as the primary ("_1") ones, so switching to the secondary
# reagents does not change titrant strength (the notebook's first cell uses
# 0.01 mol/L for them).
TITRATED_VOLUME = 11.0  # analyte volume (mL)
ANALYTE_CONC = 0.1      # acid concentration of the analyte (mol/L)
HCL_CONC1 = 0.1         # primary HCl concentration (mol/L)
HCL_CONC2 = 0.1         # secondary HCl concentration (mol/L)
NAOH_CONC1 = 0.1        # primary NaOH concentration (mol/L)
NAOH_CONC2 = 0.1        # secondary NaOH concentration (mol/L)
MAX_STEPS = 50          # maximum number of titration steps per experiment

# Reagent name -> concentration (mol/L).
REAGENTS = dict(
    dilute_acid_1=HCL_CONC1,
    dilute_acid_2=HCL_CONC2,
    dilute_base_1=NAOH_CONC1,
    dilute_base_2=NAOH_CONC2,
)

# pH 计算函数（直接复用代码2）
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total negative charge (mol/L) carried by the acid's anions.

    For an acid with total concentration ``c_A`` and successive pKa values
    ``pKa_list``, the k-th deprotonated species has concentration
    ``[fully protonated] * (K1*...*Kk) / H**k`` and contributes ``k`` times
    that concentration to the charge. An empty ``pKa_list`` gives 0.0.
    """
    # Dissociation constants (exponent clipped to avoid overflow).
    ka_values = [np.power(10, np.clip(-pka, -100, 100)) for pka in pKa_list]

    # ratios[k-1] = (K1*...*Kk) / H**k. When H == 0 the guarded np.power keeps
    # its preset output of inf, so the corresponding ratio becomes 0.
    ratios = []
    running_product = 1.0
    for order, ka in enumerate(ka_values, start=1):
        running_product *= ka
        ratios.append(running_product / np.power(H, order, where=H != 0, out=np.array(np.inf)))

    denominator = 1.0
    for ratio in ratios:
        denominator += ratio
    fully_protonated = c_A / denominator if denominator != 0 else 0.0

    charge = 0.0
    for order, ratio in enumerate(ratios, start=1):
        charge += order * (fully_protonated * ratio)
    return charge

class PHAdjustmentEnv:
    """Simulated acid titration environment with a particle-filter buffer model.

    The analyte is ``TITRATED_VOLUME`` mL of acid at ``ANALYTE_CONC`` mol/L.
    An action is a ``(reagent_name, volume_mL)`` pair from ``self.action_space``
    (every key of ``REAGENTS`` combined with volumes 0.01, 0.02, ..., 9.99 mL).
    In ``'simulate'`` mode ``step`` recomputes the pH from the charge balance
    using ``num_buffers`` pKa values; these are drawn at random in ``__init__``
    and may be refined by ``update_posteriors``.
    """

    def __init__(self):
        # Episode bookkeeping.
        self.steps_taken = 0
        self.done = False
        # Solution state: volumes in mL, added reagent amounts in mol.
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0
        # Action space: every reagent x volumes 0.01 .. 9.99 mL in 0.01 mL steps.
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys()
                             for volume in self.addition_volumes]
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4
        # Buffer model. The random draws below consume the global NumPy RNG.
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None
        # Gaussian priors over each buffer's pKa and total amount.
        self.priors = []
        for i in range(self.num_buffers):
            prior = {
                'pKa': norm(loc=self.pKa_list[i], scale=0.5),
                'total_moles': norm(loc=self.buffer_total_moles[i], scale=0.005)
            }
            self.priors.append(prior)
        # Parameters of the "ideal volume" heuristic (see _ideal_volume).
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5
        self.last_measured_ph = None
        self.prev_measured_ph = None
        # Overshoot / oscillation tracking.
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        # Analyte description (only used for reporting/copying in this class).
        self.acid_type = None
        self.acid_params = None

    def get_state(self):
        """Return ``[pH, target pH, pH change, pH error, last volume]`` as float32.

        The pH change is relative to ``previous_ph`` (0.0 if unknown); the last
        volume is 0.0 until ``step`` has recorded one.
        """
        pH_delta = self.current_ph - self.previous_ph if self.previous_ph is not None else 0.0
        error = self.current_ph - self.target_ph
        return np.array([self.current_ph, self.target_ph, pH_delta, error,
                         getattr(self, 'last_action_volume', 0.0)], dtype=np.float32)

    def initialize(self, init_pH: float, target_pH: float, max_steps: int,
                   initial_volume: float = TITRATED_VOLUME, *, randomize_acid: bool = True) -> None:
        """Reset the episode state for a new titration.

        Args:
            init_pH: starting pH of the solution.
            target_pH: pH to titrate to.
            max_steps: episode length limit.
            initial_volume: starting solution volume (mL).
            randomize_acid: when True (the historical default) a random acid
                type and pKa set are drawn from the global ``random`` module,
                overwriting any ``acid_type``/``acid_params`` the caller set.
                Pass False to keep the caller's analyte.
        """
        if randomize_acid:
            self.acid_type = random.choice(["monoprotic", "diprotic", "triprotic"])
            if self.acid_type == "monoprotic":
                self.acid_params = random.uniform(2, 6)
            elif self.acid_type == "diprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 7)
                self.acid_params = [pKa1, pKa2]
            elif self.acid_type == "triprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 6)
                pKa3 = random.uniform(6, 8)
                self.acid_params = [pKa1, pKa2, pKa3]
        self.initial_ph = init_pH
        self.current_ph = init_pH
        self.previous_ph = init_pH
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False
        self.total_volume = initial_volume
        self.previous_total_volume = initial_volume
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_measured_ph = init_pH
        self.prev_measured_ph = init_pH
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return 10**x with the exponent clipped to [-100, 100]."""
        return np.power(10, np.clip(x, -100, 100))

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading, shifting the previous reading into ``prev_measured_ph``."""
        if self.last_measured_ph is not None:
            self.prev_measured_ph = self.last_measured_ph
        else:
            self.prev_measured_ph = pH
        self.current_ph = pH
        self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend the reference pKa with the current estimate, trusting it less when uncertain.

        Each buffer moves from ``ref_pKa`` towards ``pKa_list`` by a weight of at
        most 0.2, shrinking as ``pKa_std`` grows (via ``1 - tanh(std)``).
        """
        weight_max = 0.2
        k = 1.0
        pKa_eff_array = np.zeros(self.num_buffers)
        for i in range(self.num_buffers):
            weight_i = weight_max * (1 - np.tanh(k * self.pKa_std[i]))
            pKa_eff_array[i] = self.ref_pKa[i] + weight_i * (self.pKa_list[i] - self.ref_pKa[i])
        return pKa_eff_array

    def compute_required_volume(self) -> float:
        """Estimate the reagent volume (mL) that brings the model pH to the target.

        Base is assumed when the current pH is below the target, acid
        otherwise (secondary reagents when ``use_secondary_reagents``). The
        root is searched with ``brentq`` on [0, 10] mL; 0.0 is returned when
        no root is bracketed there or the search does not converge.
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()
        adding_base = self.current_ph < self.target_ph
        if adding_base:
            reagent = 'dilute_base_2' if self.use_secondary_reagents else 'dilute_base_1'
        else:
            reagent = 'dilute_acid_2' if self.use_secondary_reagents else 'dilute_acid_1'
        conc = self.reagents[reagent]

        def f_vol(x):
            # pH error after adding x mL of the chosen reagent.
            add_moles = conc * (x / 1000.0)
            if adding_base:
                base_moles = self.base_added_moles + add_moles
                acid_moles = self.acid_added_moles
            else:
                base_moles = self.base_added_moles
                acid_moles = self.acid_added_moles + add_moles
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            c_A_new = n_analyte / new_total_volume
            c_Na_new = base_moles / new_total_volume
            c_HCl_new = acid_moles / new_total_volume
            pH_new = self.solve_pH(c_A_new, c_Na_new, c_HCl_new, effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, 10)
        except (ValueError, RuntimeError):
            # ValueError: no sign change on [0, 10]; RuntimeError: no convergence.
            return 0.0

    def f(self, pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Charge-balance residual: [H+] + [Na+] - [OH-] - anion charge - [Cl-] (Kw = 1e-14)."""
        H = 10 ** (-pH)
        Kw = 1e-14
        OH = Kw / H
        acid_anion_charge = calculate_acid_anion_charge(c_A, H, pKa_list)
        return H + c_Na - OH - acid_anion_charge - c_HCl

    def solve_pH(self, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Solve the charge balance for pH by bisection on [0, 14].

        Stops early when |residual| < 1e-10, otherwise returns the midpoint
        after 100 halvings. Assumes the residual changes sign on [0, 14].
        """
        lo, hi = 0.0, 14.0
        # f(lo) only changes when lo moves, so it is cached instead of being
        # re-evaluated on every iteration.
        f_lo = self.f(lo, c_A, c_Na, c_HCl, pKa_list)
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = self.f(mid, c_A, c_Na, c_HCl, pKa_list)
            if abs(f_mid) < 1e-10:
                return mid
            if f_lo * f_mid < 0:
                hi = mid
            else:
                lo, f_lo = mid, f_mid
        return (lo + hi) / 2.0

    def _ideal_volume(self, ph: float) -> float:
        """Heuristic target dose (mL) given the pH value ``ph`` the decision is based on.

        Grows (via tanh) with the pH error and the model-required volume,
        scaled by bonuses for a slow recent pH change, low pKa uncertainty and
        the mean buffer amount.
        """
        error = abs(ph - self.target_ph)
        ph_change = abs(ph - (self.prev_measured_ph if self.prev_measured_ph is not None else ph))
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['pKa'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
        buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        return min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

    def step(self, action: tuple, mode: str = 'simulate') -> tuple:
        """Apply ``action`` and return ``(pH, reward, done, info)``.

        ``action`` is ``(reagent_name, volume_mL)``. Adding base while the pH
        is above the target (or acid while below it) is rejected without
        changing the solution and ends the episode with reward -100. In
        ``'simulate'`` mode the pH is recomputed with ``recalc_ph``; in any
        other mode ``current_ph`` is not updated here (presumably a measured
        value is supplied via ``update_exp_ph`` — confirm with callers).
        The episode also ends on an invalid pH (NaN or outside [0, 14]), when
        the pH is within 0.1 of the target, or after ``max_steps`` steps. Any
        exception is logged and ends the episode with reward -100.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            reagent_key = reagent.lower()
            is_acid = 'acid' in reagent_key
            is_base = 'base' in reagent_key
            added_moles = self.reagents[reagent] * (volume / 1000.0)

            # Direction rules use the last measured pH when one is available.
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            # Reject wrong-direction additions *before* touching the solution
            # state, and actually mark the episode finished so loops driven by
            # ``self.done`` stop, as the returned flag says.
            if (current_for_direction > self.target_ph and is_base) or \
                    (current_for_direction < self.target_ph and is_acid):
                self.done = True
                return self.current_ph, -100, True, {}

            self.last_action_volume = volume
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if is_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif is_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles

            if mode == 'simulate':
                self.update_exp_ph(self.recalc_ph())

            # Count crossings of the target by more than 0.1 pH at the minimum
            # dose; after three (cumulative) such events use secondary reagents.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                crossed = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
                if crossed and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("检测到在最小滴加量下的pH振荡，累计次数：%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("达到连续震荡阈值，切换到次级试剂滴定。")

            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            # Reward: improvement in error, minus current error, minus squared
            # deviation from the heuristic ideal dose, minus a time penalty.
            current_error = abs(self.current_ph - self.target_ph)
            ideal_volume = self._ideal_volume(self.current_ph)
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            # Direction penalty, judged against the newest reading (so it
            # penalises ending up on the far side of the target).
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and is_acid:
                reward -= dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
            if self.target_ph < current_for_direction and is_base:
                reward -= dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)

            if self.steps_taken > 0:
                if is_acid:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_acid_added
                elif is_base:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_base_added
                else:
                    reagent_conc = 1.0
                    last_added = 0.0
                overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                                   self.target_ph, reagent,
                                                                   last_added, reagent_conc,
                                                                   self.min_addition_volume)
                if overshoot_flag:
                    self.overshoot_occurred = True
                    self.overshoot_reagent = reagent
                    # Keep the smallest overshoot-derived cap on future doses.
                    if new_thresh is not None:
                        if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                            self.overshoot_threshold = new_thresh
            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.error("执行 step 时出现异常：%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot (target crossed or error grew) and suggest a dose cap.

        Returns ``(overshoot, new_threshold)`` where the threshold is half the
        last dose volume (mL), but at least ``min_addition``; it is None when no
        overshoot is detected.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of the environment state.

        Note: constructing the copy runs ``__init__``, which draws from the
        global NumPy RNG before its random fields are overwritten below.
        """
        env_copied = PHAdjustmentEnv()
        env_copied.total_volume = self.total_volume
        env_copied.previous_total_volume = self.previous_total_volume
        env_copied.acid_added_moles = self.acid_added_moles
        env_copied.base_added_moles = self.base_added_moles
        env_copied.acid_volume = self.acid_volume
        env_copied.base_volume = self.base_volume
        env_copied.last_acid_added = self.last_acid_added
        env_copied.last_base_added = self.last_base_added
        env_copied.initial_ph = self.initial_ph
        env_copied.current_ph = self.current_ph
        env_copied.previous_ph = self.previous_ph
        env_copied.target_ph = self.target_ph
        env_copied.steps_taken = self.steps_taken
        env_copied.done = self.done
        env_copied.num_buffers = self.num_buffers
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        # Copy each prior dict so posterior updates on the copy cannot leak
        # back into this environment (a plain list copy would share them).
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.epsilon = self.epsilon
        env_copied.direction_penalty_factor = self.direction_penalty_factor
        env_copied.tol = self.tol
        env_copied.reagents = self.reagents.copy()
        env_copied.min_addition_volume = self.min_addition_volume
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.max_steps = self.max_steps
        env_copied.vol_ideal_factor = self.vol_ideal_factor
        env_copied.ph_rate_threshold = self.ph_rate_threshold
        env_copied.ph_rate_bonus_factor = self.ph_rate_bonus_factor
        env_copied.last_measured_ph = self.last_measured_ph
        env_copied.prev_measured_ph = self.prev_measured_ph
        env_copied.overshoot_threshold = self.overshoot_threshold
        env_copied.overshoot_occurred = self.overshoot_occurred
        env_copied.overshoot_reagent = self.overshoot_reagent
        env_copied.oscillation_count = self.oscillation_count
        env_copied.use_secondary_reagents = self.use_secondary_reagents
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.acid_type = self.acid_type
        env_copied.acid_params = self.acid_params
        if hasattr(self, 'last_action_volume'):
            env_copied.last_action_volume = self.last_action_volume
        return env_copied

    def recalc_ph(self) -> float:
        """Recompute the pH of the current mixture from the effective pKa values."""
        V_total = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        c_A = n_analyte / V_total
        c_Na = self.base_added_moles / V_total
        c_HCl = self.acid_added_moles / V_total
        pKa_list = self.get_effective_pka_array().tolist()
        return self.solve_pH(c_A, c_Na, c_HCl, pKa_list)

    def select_best_action(self) -> tuple:
        """Pick the action whose volume is closest to the heuristic ideal dose.

        The reagent is chosen from the direction to the target (or the opposite
        of the last overshooting reagent), using secondary reagents when
        ``use_secondary_reagents`` is set; volumes are capped by
        ``overshoot_threshold`` when any candidate satisfies the cap.
        Returns ``(action, self.done)``.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
        if self.use_secondary_reagents:
            if self.overshoot_occurred:
                if 'base' in self.overshoot_reagent.lower():
                    allowed_reagent = [r for r in self.reagents.keys() if 'acid_2' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'base_2' in r.lower()]
            else:
                if current_for_direction < self.target_ph:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_base_2' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_acid_2' in r.lower()]
        else:
            if self.overshoot_occurred:
                if 'base' in self.overshoot_reagent.lower():
                    allowed_reagent = [r for r in self.reagents.keys() if 'acid_1' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'base_1' in r.lower()]
                self.overshoot_occurred = False
                self.overshoot_reagent = None
            else:
                if current_for_direction < self.target_ph:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_base_1' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'dilute_acid_1' in r.lower()]
        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped
        ideal_volume = self._ideal_volume(current_for_direction)
        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one (pKa list, total-moles list) sample from the current priors."""
        sampled_pKa = []
        sampled_total_moles = []
        for prior in self.priors:
            sampled_pKa.append(prior['pKa'].rvs())
            sampled_total_moles.append(prior['total_moles'].rvs())
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH of the current mixture under sampled buffer parameters.

        ``action`` is not applied here: the pH is recomputed for the state as
        it stands (``suggest_next_action`` calls this after ``step`` has
        already applied the action).
        """
        env_copy = self.env_copy()
        env_copy.pKa_list = np.array(sampled_pKa)
        env_copy.buffer_total_moles = np.array(sampled_total_moles)
        return env_copy.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit the Gaussian priors with importance sampling + resampling.

        1000 parameter samples are weighted by a N(predicted, 0.01) likelihood
        of ``observed_ph``, resampled, and each buffer's prior is replaced by a
        normal with the resampled mean and std (+1e-3). ``pKa_list``,
        ``buffer_total_moles`` and ``pKa_std`` are updated to match.
        """
        num_particles = 1000
        particles = []
        weights = []
        for _ in range(num_particles):
            sampled_pKa, sampled_total_moles = self.sample_parameters()
            predicted_ph = self.predict_ph(action, sampled_pKa, sampled_total_moles)
            likelihood = norm.pdf(observed_ph, loc=predicted_ph, scale=0.01)
            particles.append((sampled_pKa, sampled_total_moles))
            weights.append(likelihood)
        # Small floor so the weights never all vanish.
        weights = np.array(weights) + 1e-10
        weights /= np.sum(weights)
        indices = np.random.choice(range(num_particles), size=num_particles, p=weights)
        new_pKa = []
        new_total_moles = []
        new_pKa_std = []
        for i in range(self.num_buffers):
            pKa_samples = np.array([particles[idx][0][i] for idx in indices])
            total_moles_samples = np.array([particles[idx][1][i] for idx in indices])
            mean_pKa = np.mean(pKa_samples)
            std_pKa = np.std(pKa_samples) + 1e-3
            mean_total_moles = np.mean(total_moles_samples)
            std_total_moles = np.std(total_moles_samples) + 1e-3
            new_pKa.append((mean_pKa, std_pKa))
            new_total_moles.append((mean_total_moles, std_total_moles))
            new_pKa_std.append(std_pKa)
        for i in range(self.num_buffers):
            self.priors[i]['pKa'] = norm(loc=new_pKa[i][0], scale=new_pKa[i][1])
            self.priors[i]['total_moles'] = norm(loc=new_total_moles[i][0], scale=new_total_moles[i][1])
            self.pKa_list[i] = new_pKa[i][0]
            self.buffer_total_moles[i] = new_total_moles[i][0]
            self.pKa_std[i] = new_pKa_std[i]

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Apply ``action``, update posteriors, and return ``(next_action, done)``.

        Returns ``(None, True)`` (and marks the episode done) when
        ``observed_ph`` is already within 0.1 of the target.
        """
        if abs(observed_ph - self.target_ph) < 0.1:
            self.done = True
            return None, True
        new_ph, reward, done, _ = self.step(action, mode='simulate')
        self.update_posteriors(action, new_ph)
        next_action, _ = self.select_best_action()
        return next_action, done

# 辅助函数：计算初始 pH（模拟代码1的行为）
def calculate_initial_ph(acid_type: str, acid_params, env: PHAdjustmentEnv):
    """Return the pH of the un-titrated analyte solution.

    ``acid_params`` is a single pKa for ``"monoprotic"`` and a list of pKa
    values for ``"diprotic"``/``"triprotic"``. The analyte concentration is
    ``ANALYTE_CONC`` in ``TITRATED_VOLUME`` mL with no added base or HCl; the
    pH is solved with ``env.solve_pH``. Unknown acid types give 7.0.
    """
    if acid_type == "monoprotic":
        pka_values = [acid_params]
    elif acid_type in ("diprotic", "triprotic"):
        pka_values = acid_params
    else:
        return 7.0

    volume_litres = TITRATED_VOLUME / 1000.0
    analyte_moles = volume_litres * ANALYTE_CONC
    return env.solve_pH(analyte_moles / volume_litres, 0.0, 0.0, pka_values)

# 自定义类用于同时输出到控制台和文件
class Tee:
    """File-like object that duplicates every write to several streams.

    Each ``write`` is forwarded to all wrapped files and flushed immediately.
    Any other attribute (``closed``, ``encoding``, ``isatty``, ``fileno``, ...)
    is looked up on the first wrapped file, so code that inspects
    ``sys.stdout`` while it is replaced by a ``Tee`` does not fail.
    """

    def __init__(self, *files):
        self.files = files

    def write(self, obj):
        """Write ``obj`` to every file and return its length, like ``TextIOBase.write``."""
        for f in self.files:
            f.write(obj)
            f.flush()
        return len(obj)

    def flush(self):
        """Flush every wrapped file."""
        for f in self.files:
            f.flush()

    def __getattr__(self, name):
        # Only reached when normal lookup fails. Read ``files`` from __dict__
        # to avoid recursion if the instance is not fully initialised.
        files = self.__dict__.get('files')
        if not files:
            raise AttributeError(name)
        return getattr(files[0], name)

# 修改后的主程序
def main():
    """Run ``num_experiments`` simulated titrations, echoing all output to a file.

    For every experiment a random target pH (2-11) and a random mono-, di- or
    triprotic analyte are drawn, the initial pH is computed from that analyte,
    and the environment is stepped with ``select_best_action``/``step`` until
    it reports ``done``. Everything printed goes both to the console and to
    ``experiment_output.txt``; overall success statistics
    (|final pH - target| < 0.1) are printed at the end.
    """
    import contextlib

    # Fixed random seed for reproducibility (same seed as code 1).
    seed = 555
    random.seed(seed)
    np.random.seed(seed)

    num_experiments = 3000
    success_count = 0
    steps_success = []

    # stdout is restored and the file closed even if an experiment raises.
    with open('experiment_output.txt', 'w', encoding='utf-8') as output_file, \
            contextlib.redirect_stdout(Tee(sys.stdout, output_file)):
        for exp in range(num_experiments):
            # Random target pH.
            target_ph = round(random.uniform(2, 11), 2)
            env = PHAdjustmentEnv()

            # Random analyte (acid type + pKa values) and its initial pH.
            env.acid_type = random.choice(["monoprotic", "diprotic", "triprotic"])
            if env.acid_type == "monoprotic":
                env.acid_params = random.uniform(2, 6)
            elif env.acid_type == "diprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 7)
                env.acid_params = [pKa1, pKa2]
            elif env.acid_type == "triprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 6)
                pKa3 = random.uniform(6, 8)
                env.acid_params = [pKa1, pKa2, pKa3]
            initial_ph = calculate_initial_ph(env.acid_type, env.acid_params, env)

            # initialize() may draw a fresh random acid and overwrite these
            # attributes; restore the analyte initial_ph was computed from so
            # the printed description matches the acid actually used.
            acid_type, acid_params = env.acid_type, env.acid_params
            env.initialize(init_pH=initial_ph, target_pH=target_ph, max_steps=MAX_STEPS, initial_volume=TITRATED_VOLUME)
            env.acid_type, env.acid_params = acid_type, acid_params

            print(f"==== 实验 {exp+1} 开始 ====")
            initial_state = env.get_state()
            print(f"初始状态: {np.round(initial_state, 2)}")
            print(f"酸类型: {env.acid_type}, 参数: {env.acid_params}, 目标 pH: {env.target_ph}")
            print("状态-动作-试剂对:")

            trace = []
            action, _ = env.select_best_action()
            # Safety cap: a step that returns without setting env.done (and
            # without counting a step) must not hang the whole batch.
            max_iterations = 10 * MAX_STEPS
            for _iteration in range(max_iterations):
                if env.done:
                    break
                env.step(action, mode='simulate')
                trace.append((env.get_state(), action, action[0]))  # record reagent name
                action, _ = env.select_best_action()

            for i, (s, a, reagent) in enumerate(trace, start=1):
                s_formatted = np.round(s, 2)
                print(f"  Step {i}: State = {s_formatted}, Action = {a[1]:.4f}, Reagent = {reagent}")
            print(f"实验结束，共用步数: {env.steps_taken}, 最终 pH: {env.current_ph:.2f}\n")

            if abs(env.current_ph - env.target_ph) < 0.1:
                success_count += 1
                steps_success.append(env.steps_taken)

        success_rate = success_count / num_experiments * 100
        avg_steps = np.mean(steps_success) if steps_success else 0
        print("总实验数: {}, 成功实验数: {}, 成功率: {:.2f}%, 成功实验平均步数: {:.2f}".format(
            num_experiments, success_count, success_rate, avg_steps))

if __name__ == '__main__':
    main()

In [None]:
# shap分析

In [None]:
import torch
import torch.nn as nn
import numpy as np
import shap
import matplotlib.pyplot as plt
import re

# Fix the torch and NumPy random seeds for reproducibility.
seed = 555
for _seed_fn in (torch.manual_seed, np.random.seed):
    _seed_fn(seed)
del _seed_fn

# 离散动作策略模型
class DiscreteVolumeRegressor(nn.Module):
    """MLP policy mapping a state vector to logits over a discrete volume grid.

    ``discrete_volumes`` is the grid ``min_volume, min_volume + step, ...,
    max_volume`` (values rounded to 2 decimals). The output layer has one
    logit per grid point, so the grid size must match any loaded checkpoint.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        # round() rather than int(): float division such as 0.3 / 0.1 gives
        # 2.9999999999999996, which int() truncates, silently dropping
        # max_volume from the grid. The defaults give 1000 points either way.
        num_volumes = round((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(num_volumes)]
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, len(self.discrete_volumes))
        )

    def forward(self, x):
        """Return unnormalised logits over ``discrete_volumes`` for input ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one volume from the categorical policy for a single input.

        Uses ``.item()`` on the sampled index, so ``x`` must describe exactly
        one state (e.g. shape ``(1, input_dim)``).
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        volume = self.discrete_volumes[action_index.item()]
        return volume

# Extract state vectors from a txt log file (at most max_vectors, default 10).
def extract_state_vectors(file_path, max_vectors=10):
    """Parse up to ``max_vectors`` five-element state vectors from a log file.

    Every ``State = [...]`` entry (as written by the experiment loops from
    ``np.round(state, 2)``) is split on whitespace and converted to a float32
    array. Numbers may be in plain or scientific notation: NumPy prints arrays
    as ``[1.1e+01 5.0e+00 1.0e-02 ...]`` when their magnitudes differ by more
    than a factor of 1e3, which the previous digits-only regex silently
    skipped. Entries that do not contain exactly five finite numbers are
    ignored.

    Returns:
        list of ``np.ndarray`` (float32, shape ``(5,)``), in file order.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    if max_vectors <= 0:
        return []

    state_pattern = re.compile(r"State = \[([^\]]*)\]")
    state_vectors = []
    for match in state_pattern.finditer(content):
        tokens = match.group(1).split()
        if len(tokens) != 5:
            continue
        try:
            vector = np.array([float(token) for token in tokens], dtype=np.float32)
        except ValueError:
            continue
        if not np.all(np.isfinite(vector)):
            continue
        state_vectors.append(vector)
        if len(state_vectors) >= max_vectors:
            break
    return state_vectors

# SHAP feature-importance analysis of the volume policy.
def analyze_shap_importance(state_vectors, model_path="volume_regressor_best_big_discrete_new1_trained-1-test.pth", nsamples=500):
    """Explain the policy's volume output with KernelSHAP and plot feature importance.

    ``sample_action`` draws a random volume, but KernelExplainer needs a
    deterministic function (a random f(x) turns attributions into noise and
    breaks SHAP additivity). The explained quantity is therefore the policy's
    expected volume ``sum_k softmax(logits)_k * discrete_volumes[k]``, computed
    in batches instead of one forward pass per row.

    Args:
        state_vectors: sequence of 5-element state vectors, used both as the
            background data and as the rows to explain.
        model_path: state-dict checkpoint for ``DiscreteVolumeRegressor``.
            ``torch.load`` unpickles the file, so only load trusted checkpoints.
        nsamples: coalition samples per explained row for KernelExplainer.

    Returns:
        ``(avg_shap, normalized_shap)``: dicts keyed by feature name holding the
        mean |SHAP| (mL) and the same scores divided by their total; or
        ``(None, None)`` if the model could not be loaded. Saves
        shap_importance.png and shap_summary.png.
    """
    # Load the trained policy; any failure is reported and aborts the analysis.
    model = DiscreteVolumeRegressor()
    try:
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("加载预训练模型成功。")
    except Exception as e:
        print("未能加载模型:", e)
        return None, None

    model.eval()
    volume_grid = torch.tensor(model.discrete_volumes, dtype=torch.float32)
    rows_per_batch = 4096  # caps the (rows x n_actions) logits tensor per forward pass

    def model_predict(inputs):
        """Deterministic expected volume (mL) of the policy for each input row."""
        batch = torch.as_tensor(np.atleast_2d(np.asarray(inputs, dtype=np.float32)))
        with torch.no_grad():
            expected = [torch.softmax(model(chunk), dim=1) @ volume_grid
                        for chunk in torch.split(batch, rows_per_batch)]
        return torch.cat(expected).numpy().astype(np.float64)

    state_vectors_np = np.array(state_vectors, dtype=np.float32)

    # The same states serve as the background distribution and as the rows to explain.
    explainer = shap.KernelExplainer(model_predict, state_vectors_np)
    shap_values = explainer.shap_values(state_vectors_np, nsamples=nsamples)

    feature_names = ['current_ph', 'target_ph', 'pH_delta', 'error', 'last_action_volume']

    # Mean |SHAP| per feature, and the same scores as shares of their total.
    avg_shap = np.abs(shap_values).mean(axis=0)
    total_shap = avg_shap.sum()
    normalized_shap = avg_shap / total_shap if total_shap > 0 else avg_shap

    print("\n平均SHAP值（绝对贡献，mL）：")
    for name, score in zip(feature_names, avg_shap):
        print(f"{name}: {score:.4f}")
    print("\n归一化SHAP值（比例）：")
    for name, score in zip(feature_names, normalized_shap):
        print(f"{name}: {score:.4f}")

    # Bar chart of normalised importance; the title reports the actual number of states.
    plt.figure(figsize=(8, 6))
    plt.bar(feature_names, normalized_shap)
    plt.xlabel('特征')
    plt.ylabel('归一化 SHAP 值')
    plt.title(f'SHAP 特征重要性分析（前{len(state_vectors_np)}个状态向量）')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("shap_importance.png")
    plt.show()

    # SHAP summary (beeswarm) plot.
    shap.summary_plot(shap_values, state_vectors_np, feature_names=feature_names, show=False)
    plt.savefig("shap_summary.png")
    plt.show()

    return dict(zip(feature_names, avg_shap)), dict(zip(feature_names, normalized_shap))

# Script entry: parse up to 500 logged state vectors and run the SHAP analysis on them.
if __name__ == "__main__":
    file_path = "test_output2_modified.txt"
    state_vectors = extract_state_vectors(file_path, max_vectors=500)

    print(f"\n提取到 {len(state_vectors)} 个状态向量。")

    if not state_vectors:
        print("未提取到状态向量，无法进行分析。")
    else:
        avg_shap, normalized_shap = analyze_shap_importance(state_vectors)

In [None]:
# shap分析，error取绝对值

In [None]:
import torch
import torch.nn as nn
import numpy as np
import shap
import matplotlib.pyplot as plt
import re

# Fix the NumPy and PyTorch RNGs so repeated runs sample identical policy actions.
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy network: state features -> logits over a fixed volume grid.
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a grid of titrant volumes (mL).

    The grid runs from ``min_volume`` to ``max_volume`` in ``step`` increments,
    each value rounded to 2 decimals; the MLP outputs one logit per grid point.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The epsilon absorbs float error in the range/step ratio
        # (0.29 / 0.01 == 28.999999999999996); plain int() truncation would
        # silently drop max_volume from the grid.
        n_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_steps + 1)]
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, len(self.discrete_volumes))
        )

    def forward(self, x):
        """Return one unnormalised logit per entry of ``discrete_volumes``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one volume (a float taken from ``discrete_volumes``) from the policy.

        ``x`` must hold a single state: ``.item()`` is applied to the sampled index.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        return self.discrete_volumes[action_index.item()]

# Parse logged state vectors from a text file.
def extract_state_vectors(file_path, max_vectors=500):
    """Return up to ``max_vectors`` 5-element state vectors parsed from ``file_path``.

    Matches every ``State = [a b c d e]`` occurrence, as printed from
    ``np.round(state, 2)``. NumPy pads elements with trailing spaces
    (``[ 3.5  7.   0. ]``) and may switch to scientific notation
    (``1.05e+01``) when magnitudes differ widely; both forms are accepted so
    such states are not silently skipped.

    Returns:
        list of float32 arrays of shape (5,), in file order; empty when
        ``max_vectors <= 0`` or nothing matches.
    """
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    state_pattern = re.compile(r"State = \[\s*" + r"\s+".join([number] * 5) + r"\s*\]")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    state_vectors = []
    if max_vectors <= 0:
        return state_vectors
    for match in state_pattern.finditer(content):
        state_vectors.append(np.array([float(g) for g in match.groups()], dtype=np.float32))
        if len(state_vectors) >= max_vectors:
            break
    return state_vectors

# SHAP feature-importance analysis of the volume policy, with |error| as the error feature.
def analyze_shap_importance(state_vectors, model_path="volume_regressor_best_big_discrete_new1_trained-1-test.pth", nsamples=500):
    """Explain the policy's volume output with KernelSHAP and plot feature importance.

    Column 3 (error) is replaced by its absolute value before explaining.
    ``sample_action`` draws a random volume, but KernelExplainer needs a
    deterministic function (a random f(x) turns attributions into noise and
    breaks SHAP additivity). The explained quantity is therefore the policy's
    expected volume ``sum_k softmax(logits)_k * discrete_volumes[k]``, computed
    in batches instead of one forward pass per row.

    Args:
        state_vectors: sequence of 5-element state vectors, used both as the
            background data and as the rows to explain.
        model_path: state-dict checkpoint for ``DiscreteVolumeRegressor``.
            ``torch.load`` unpickles the file, so only load trusted checkpoints.
        nsamples: coalition samples per explained row for KernelExplainer.

    Returns:
        ``(avg_shap, normalized_shap)``: dicts keyed by feature name holding the
        mean |SHAP| (mL) and the same scores divided by their total; or
        ``(None, None)`` if the model could not be loaded. Saves
        shap_importance.png and shap_summary.png.
    """
    # Load the trained policy; any failure is reported and aborts the analysis.
    model = DiscreteVolumeRegressor()
    try:
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("加载预训练模型成功。")
    except Exception as e:
        print("未能加载模型:", e)
        return None, None

    model.eval()
    volume_grid = torch.tensor(model.discrete_volumes, dtype=torch.float32)
    rows_per_batch = 4096  # caps the (rows x n_actions) logits tensor per forward pass

    def model_predict(inputs):
        """Deterministic expected volume (mL) of the policy for each input row."""
        batch = torch.as_tensor(np.atleast_2d(np.asarray(inputs, dtype=np.float32)))
        with torch.no_grad():
            expected = [torch.softmax(model(chunk), dim=1) @ volume_grid
                        for chunk in torch.split(batch, rows_per_batch)]
        return torch.cat(expected).numpy().astype(np.float64)

    state_vectors_np = np.array(state_vectors, dtype=np.float32)
    # NOTE(review): the model is queried on |error| here; if it was trained on the
    # signed error (current_ph - target_ph), rows with negative error are explained
    # at out-of-distribution inputs — confirm this is intended.
    state_vectors_np[:, 3] = np.abs(state_vectors_np[:, 3])

    # The same states serve as the background distribution and as the rows to explain.
    explainer = shap.KernelExplainer(model_predict, state_vectors_np)
    shap_values = explainer.shap_values(state_vectors_np, nsamples=nsamples)

    feature_names = ['current_ph', 'target_ph', 'pH_delta', 'error', 'last_action_volume']

    # Mean |SHAP| per feature, and the same scores as shares of their total.
    avg_shap = np.abs(shap_values).mean(axis=0)
    total_shap = avg_shap.sum()
    normalized_shap = avg_shap / total_shap if total_shap > 0 else avg_shap

    print("\n平均SHAP值（绝对贡献，mL）：")
    for name, score in zip(feature_names, avg_shap):
        print(f"{name}: {score:.4f}")
    print("\n归一化SHAP值（比例）：")
    for name, score in zip(feature_names, normalized_shap):
        print(f"{name}: {score:.4f}")

    # Bar chart of normalised importance; the title reports the actual number of states.
    plt.figure(figsize=(8, 6))
    plt.bar(feature_names, normalized_shap)
    plt.xlabel('特征')
    plt.ylabel('归一化 SHAP 值')
    plt.title(f'SHAP 特征重要性分析（前{len(state_vectors_np)}个状态向量）')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("shap_importance.png")
    plt.show()

    # SHAP summary (beeswarm) plot.
    shap.summary_plot(shap_values, state_vectors_np, feature_names=feature_names, show=False)
    plt.savefig("shap_summary.png")
    plt.show()

    return dict(zip(feature_names, avg_shap)), dict(zip(feature_names, normalized_shap))

# Script entry: load up to 500 logged states and run the |error| SHAP analysis.
if __name__ == "__main__":
    file_path = "test_output2_modified.txt"
    state_vectors = extract_state_vectors(file_path, max_vectors=500)
    print(f"\n提取到 {len(state_vectors)} 个状态向量。")

    if not state_vectors:
        print("未提取到状态向量，无法进行分析。")
    else:
        avg_shap, normalized_shap = analyze_shap_importance(state_vectors)

In [None]:
# 相关性分析

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import re

# Fix the NumPy and PyTorch RNGs: the volumes correlated below are sampled from the policy.
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy network: state features -> logits over a fixed volume grid.
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a grid of titrant volumes (mL).

    The grid runs from ``min_volume`` to ``max_volume`` in ``step`` increments,
    each value rounded to 2 decimals; the MLP outputs one logit per grid point.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The epsilon absorbs float error in the range/step ratio
        # (0.29 / 0.01 == 28.999999999999996); plain int() truncation would
        # silently drop max_volume from the grid.
        n_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_steps + 1)]
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, len(self.discrete_volumes))
        )

    def forward(self, x):
        """Return one unnormalised logit per entry of ``discrete_volumes``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one volume (a float taken from ``discrete_volumes``) from the policy.

        ``x`` must hold a single state: ``.item()`` is applied to the sampled index.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        return self.discrete_volumes[action_index.item()]

# Parse logged state vectors from a text file.
def extract_state_vectors(file_path, max_vectors=100):
    """Return up to ``max_vectors`` 5-element state vectors parsed from ``file_path``.

    Matches every ``State = [a b c d e]`` occurrence, as printed from
    ``np.round(state, 2)``. NumPy pads elements with trailing spaces
    (``[ 3.5  7.   0. ]``) and may switch to scientific notation
    (``1.05e+01``) when magnitudes differ widely; both forms are accepted so
    such states are not silently skipped.

    Returns:
        list of float32 arrays of shape (5,), in file order; empty when
        ``max_vectors <= 0`` or nothing matches.
    """
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    state_pattern = re.compile(r"State = \[\s*" + r"\s+".join([number] * 5) + r"\s*\]")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    state_vectors = []
    if max_vectors <= 0:
        return state_vectors
    for match in state_pattern.finditer(content):
        state_vectors.append(np.array([float(g) for g in match.groups()], dtype=np.float32))
        if len(state_vectors) >= max_vectors:
            break
    return state_vectors

# Correlation analysis between state features and the policy's sampled volume.
def correlation_analysis(file_path, model_path="volume_regressor_best_big_discrete_new1_trained-1-test.pth", max_vectors=10000):
    """Correlate each state feature with the volume the policy samples for that state.

    Loads up to ``max_vectors`` states from ``file_path``, draws one volume per
    state with the stochastic ``sample_action``, and reports Pearson and
    Spearman coefficients with p-values. Saves correlation_bar.png and
    correlation_heatmap.png.

    Returns:
        ``(pearson_corrs, spearman_corrs)``: dicts mapping feature name to
        ``(coefficient, p_value)``; ``(None, None)`` when fewer than two states
        are available (correlation is undefined) or the model cannot be loaded.
    """
    state_vectors = extract_state_vectors(file_path, max_vectors)
    if not state_vectors:
        print("未提取到状态向量，无法进行分析。")
        return None, None

    print(f"\n提取到 {len(state_vectors)} 个状态向量。")
    if len(state_vectors) < 2:
        # pearsonr/spearmanr need at least two observations.
        print("状态向量少于 2 个，无法计算相关系数。")
        return None, None

    # Load the trained policy; any failure is reported and aborts the analysis.
    model = DiscreteVolumeRegressor()
    try:
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("加载预训练模型成功。")
    except Exception as e:
        print("未能加载模型:", e)
        return None, None

    model.eval()

    # One sampled volume per state, drawn in file order.
    state_vectors_np = np.array(state_vectors, dtype=np.float32)
    n_states = len(state_vectors_np)
    volumes = []
    with torch.no_grad():
        for state in state_vectors_np:
            volume = model.sample_action(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            volumes.append(volume)
    volumes = np.array(volumes)

    feature_names = ['current_ph', 'target_ph', 'pH_delta', 'error', 'last_action_volume']

    # (coefficient, p_value) per feature; a constant column yields NaN.
    pearson_corrs = {}
    spearman_corrs = {}
    for i, name in enumerate(feature_names):
        pearson_corrs[name] = tuple(pearsonr(state_vectors_np[:, i], volumes))
        spearman_corrs[name] = tuple(spearmanr(state_vectors_np[:, i], volumes))

    print("\nPearson 相关系数（相关系数, p 值）：")
    for name, (corr, p) in pearson_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")
    print("\nSpearman 相关系数（相关系数, p 值）：")
    for name, (corr, p) in spearman_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")

    # Grouped bar chart; the title reports the actual number of states analysed.
    plt.figure(figsize=(10, 6))
    corr_df = pd.DataFrame({
        'Pearson': [corr for corr, _ in pearson_corrs.values()],
        'Spearman': [corr for corr, _ in spearman_corrs.values()]
    }, index=feature_names)
    corr_df.plot(kind='bar', ax=plt.gca())
    plt.xlabel('特征')
    plt.ylabel('相关系数')
    plt.title(f'状态向量与预测体积的相关性分析（前{n_states}个状态向量）')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("correlation_bar.png")
    plt.show()

    # Heatmap: rows are features, columns are the two coefficient types.
    plt.figure(figsize=(8, 6))
    corr_matrix = np.zeros((len(feature_names), 2))
    for i, name in enumerate(feature_names):
        corr_matrix[i, 0] = pearson_corrs[name][0]
        corr_matrix[i, 1] = spearman_corrs[name][0]
    sns.heatmap(corr_matrix, annot=True, xticklabels=['Pearson', 'Spearman'], yticklabels=feature_names, cmap='coolwarm', vmin=-1, vmax=1)
    plt.title('相关性热图')
    plt.tight_layout()
    plt.savefig("correlation_heatmap.png")
    plt.show()

    return pearson_corrs, spearman_corrs

# Script entry: correlate logged state features with the policy's sampled volumes.
if __name__ == "__main__":
    file_path = "test_output2_modified.txt"
    pearson_corrs, spearman_corrs = correlation_analysis(file_path=file_path)

In [None]:
# 相关性分析，error取绝对值

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import re

# Make the sampled policy volumes reproducible across runs.
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy network: state features -> logits over a fixed volume grid.
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a grid of titrant volumes (mL).

    The grid runs from ``min_volume`` to ``max_volume`` in ``step`` increments,
    each value rounded to 2 decimals; the MLP outputs one logit per grid point.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The epsilon absorbs float error in the range/step ratio
        # (0.29 / 0.01 == 28.999999999999996); plain int() truncation would
        # silently drop max_volume from the grid.
        n_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_steps + 1)]
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, len(self.discrete_volumes))
        )

    def forward(self, x):
        """Return one unnormalised logit per entry of ``discrete_volumes``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one volume (a float taken from ``discrete_volumes``) from the policy.

        ``x`` must hold a single state: ``.item()`` is applied to the sampled index.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        return self.discrete_volumes[action_index.item()]

# Parse logged state vectors from a text file.
def extract_state_vectors(file_path, max_vectors=10000):
    """Return up to ``max_vectors`` 5-element state vectors parsed from ``file_path``.

    Matches every ``State = [a b c d e]`` occurrence, as printed from
    ``np.round(state, 2)``. NumPy pads elements with trailing spaces
    (``[ 3.5  7.   0. ]``) and may switch to scientific notation
    (``1.05e+01``) when magnitudes differ widely; both forms are accepted so
    such states are not silently skipped.

    Returns:
        list of float32 arrays of shape (5,), in file order; empty when
        ``max_vectors <= 0`` or nothing matches.
    """
    number = r"([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)"
    state_pattern = re.compile(r"State = \[\s*" + r"\s+".join([number] * 5) + r"\s*\]")

    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()

    state_vectors = []
    if max_vectors <= 0:
        return state_vectors
    for match in state_pattern.finditer(content):
        state_vectors.append(np.array([float(g) for g in match.groups()], dtype=np.float32))
        if len(state_vectors) >= max_vectors:
            break
    return state_vectors

# Correlation analysis between state features (|error| for the error column) and sampled volume.
def correlation_analysis(file_path, model_path="volume_regressor_best_big_discrete_new1_trained-1-test.pth", max_vectors=20000):
    """Correlate each state feature with the volume the policy samples for that state.

    Loads up to ``max_vectors`` states from ``file_path``, draws one volume per
    state with the stochastic ``sample_action`` (the model sees the raw state),
    and reports Pearson and Spearman coefficients with p-values. For the
    ``error`` feature the absolute value is correlated. Saves
    correlation_bar.png and correlation_heatmap.png.

    Returns:
        ``(pearson_corrs, spearman_corrs)``: dicts mapping feature name to
        ``(coefficient, p_value)``; ``(None, None)`` when fewer than two states
        are available (correlation is undefined) or the model cannot be loaded.
    """
    state_vectors = extract_state_vectors(file_path, max_vectors)
    if not state_vectors:
        print("未提取到状态向量，无法进行分析。")
        return None, None

    print(f"\n提取到 {len(state_vectors)} 个状态向量。")
    if len(state_vectors) < 2:
        # pearsonr/spearmanr need at least two observations.
        print("状态向量少于 2 个，无法计算相关系数。")
        return None, None

    # Load the trained policy; any failure is reported and aborts the analysis.
    model = DiscreteVolumeRegressor()
    try:
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("加载预训练模型成功。")
    except Exception as e:
        print("未能加载模型:", e)
        return None, None

    model.eval()

    # One sampled volume per state, drawn in file order.
    state_vectors_np = np.array(state_vectors, dtype=np.float32)
    n_states = len(state_vectors_np)
    volumes = []
    with torch.no_grad():
        for state in state_vectors_np:
            volume = model.sample_action(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            volumes.append(volume)
    volumes = np.array(volumes)

    feature_names = ['current_ph', 'target_ph', 'pH_delta', 'error', 'last_action_volume']

    # (coefficient, p_value) per feature; a constant column yields NaN.
    pearson_corrs = {}
    spearman_corrs = {}
    for i, name in enumerate(feature_names):
        column = state_vectors_np[:, i]
        # Correlate the distance from target (|error|) rather than the signed error.
        feature_values = np.abs(column) if name == 'error' else column
        pearson_corrs[name] = tuple(pearsonr(feature_values, volumes))
        spearman_corrs[name] = tuple(spearmanr(feature_values, volumes))

    print("\nPearson 相关系数（相关系数, p 值）：")
    for name, (corr, p) in pearson_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")
    print("\nSpearman 相关系数（相关系数, p 值）：")
    for name, (corr, p) in spearman_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")

    # Grouped bar chart; the title reports the actual number of states analysed.
    plt.figure(figsize=(10, 6))
    corr_df = pd.DataFrame({
        'Pearson': [corr for corr, _ in pearson_corrs.values()],
        'Spearman': [corr for corr, _ in spearman_corrs.values()]
    }, index=feature_names)
    corr_df.plot(kind='bar', ax=plt.gca())
    plt.xlabel('特征')
    plt.ylabel('相关系数')
    plt.title(f'状态向量与预测体积的相关性分析（前{n_states}个状态向量）')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("correlation_bar.png")
    plt.show()

    # Heatmap: rows are features, columns are the two coefficient types.
    plt.figure(figsize=(8, 6))
    corr_matrix = np.zeros((len(feature_names), 2))
    for i, name in enumerate(feature_names):
        corr_matrix[i, 0] = pearson_corrs[name][0]
        corr_matrix[i, 1] = spearman_corrs[name][0]
    sns.heatmap(corr_matrix, annot=True, xticklabels=['Pearson', 'Spearman'], yticklabels=feature_names, cmap='coolwarm', vmin=-1, vmax=1)
    plt.title('相关性热图')
    plt.tight_layout()
    plt.savefig("correlation_heatmap.png")
    plt.show()

    return pearson_corrs, spearman_corrs

# Script entry: correlate logged state features (|error| variant) with sampled volumes.
if __name__ == "__main__":
    file_path = "test_output2_modified.txt"
    pearson_corrs, spearman_corrs = correlation_analysis(file_path=file_path)

In [None]:
# 消融实验

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
from scipy.optimize import fsolve
import json
import os

##############################################
# 固定随机种子，确保实验可重复
##############################################
seed = 255
# Seed every RNG in play (PyTorch, NumPy, Python) and force deterministic cuDNN
# kernels so ablation runs are repeatable.
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

##############################################
# Global constants
##############################################
TITRANT_CONC = 0.1          # titrant (NaOH / HCl) concentration, mol/L
MAX_STEPS = 50              # maximum steps per episode
INITIAL_ACID_VOL = 11.0     # initial weak-acid analyte volume, mL
SUCCESS_THRESHOLD = 0.1     # |pH - target| below which an episode succeeds

##############################################
# pH 计算函数：单元酸
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual (mol/L) for a monoprotic acid HA at ``pH``.

    Returns [H+] + [Na+] - [OH-] - [A-] - [Cl-], with Kw = 1e-14 and
    [A-] = c_A * r / (1 + r), r = 10**(pH - pKa). The mixture's pH is the
    root of this function.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    ratio = 10 ** (pH - pKa)            # [A-] / [HA]
    dissociated = ratio / (1 + ratio)   # fraction of the acid present as A-
    cations = h_conc + c_Na
    return cations - oh_conc - c_A * dissociated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Bisect for the pH in [0, 14] where ``f_monoprotic`` changes sign.

    Returns a midpoint as soon as |f| < 1e-10, otherwise the midpoint of the
    final bracket after 100 halvings.
    """
    lo, hi = 0.0, 14.0
    # f(lo) is cached and refreshed only when lo moves; re-evaluating it on
    # every iteration gave identical results at twice the cost.
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float) -> float:
    """pH (rounded to 2 dp) of the monoprotic analyte after adding titrants.

    Uses INITIAL_ACID_VOL mL of 0.1 M analyte and TITRANT_CONC for both the
    added base and acid. Each species is diluted into the combined volume, then
    the charge balance is solved with ``solve_pH_monoprotic_balance``.
    """
    total_L = (INITIAL_ACID_VOL + base_added_mL + acid_added_mL) / 1000.0
    moles_acid = INITIAL_ACID_VOL / 1000.0 * 0.1
    moles_na = base_added_mL / 1000.0 * TITRANT_CONC
    moles_hcl = acid_added_mL / 1000.0 * TITRANT_CONC
    c_A, c_Na, c_HCl = (n / total_L for n in (moles_acid, moles_na, moles_hcl))
    return round(solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa), 2)

##############################################
# pH 计算函数：双元酸
##############################################
def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual (mol/L) for a diprotic acid H2A at ``pH``.

    Returns [H+] + [Na+] - [OH-] - ([HA-] + 2[A2-]) - [Cl-], with Kw = 1e-14.
    Species ratios relative to H2A are 10**(pH - pKa1) and
    10**(2pH - pKa1 - pKa2); exponents are clipped to [-100, 100] before
    ``np.power`` so the ratios stay finite.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    r_ha = np.power(10, np.clip(pH - pKa1, -100, 100))
    r_a = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    total = 1 + r_ha + r_a
    mean_charge = r_ha / total + 2 * (r_a / total)
    anion_charge = c_A * mean_charge
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Bisect for the pH in [0, 14] where ``f_diprotic`` changes sign.

    Returns a midpoint as soon as |f| < 1e-10, otherwise the midpoint of the
    final bracket after 100 halvings.
    """
    lo, hi = 0.0, 14.0
    # f(lo) is cached and refreshed only when lo moves; re-evaluating it on
    # every iteration gave identical results at twice the cost.
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float) -> float:
    """pH (rounded to 2 dp) of the diprotic analyte after adding titrants.

    Uses INITIAL_ACID_VOL mL of 0.1 M analyte and TITRANT_CONC for both the
    added base and acid. Each species is diluted into the combined volume, then
    the charge balance is solved with ``solve_pH_diprotic``.
    """
    volumes_mL = (INITIAL_ACID_VOL, base_added_mL, acid_added_mL)
    concs = (0.1, TITRANT_CONC, TITRANT_CONC)
    total_L = (INITIAL_ACID_VOL + base_added_mL + acid_added_mL) / 1000.0
    c_A, c_Na, c_HCl = [v / 1000.0 * c / total_L for v, c in zip(volumes_mL, concs)]
    return round(solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2), 2)

##############################################
# pH 计算函数：三元酸
##############################################
def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual (mol/L) for a triprotic acid H3A at ``pH``.

    Returns [H+] + [Na+] - [OH-] - ([H2A-] + 2[HA2-] + 3[A3-]) - [Cl-], with
    Kw = 1e-14. The cumulative species ratios relative to H3A use exponents
    clipped to [-100, 100] before ``np.power`` so they stay finite.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc
    exponents = (
        pH - pKa1,
        2 * pH - pKa1 - pKa2,
        3 * pH - pKa1 - pKa2 - pKa3,
    )
    r1, r2, r3 = (np.power(10, np.clip(e, -100, 100)) for e in exponents)
    denom = 1 + r1 + r2 + r3
    mean_charge = r1 / denom + 2 * (r2 / denom) + 3 * (r3 / denom)
    anion_charge = c_A * mean_charge
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Bisect for the pH in [0, 14] where ``f_triprotic`` changes sign.

    Returns a midpoint as soon as |f| < 1e-10, otherwise the midpoint of the
    final bracket after 100 halvings.
    """
    lo, hi = 0.0, 14.0
    # f(lo) is cached and refreshed only when lo moves; re-evaluating it on
    # every iteration gave identical results at twice the cost.
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """pH (rounded to 2 dp) of the triprotic analyte after adding titrants.

    Uses INITIAL_ACID_VOL mL of 0.1 M analyte and TITRANT_CONC for both the
    added base and acid. Each species is diluted into the combined volume, then
    the charge balance is solved with ``solve_pH_triprotic``.
    """
    components = (
        (INITIAL_ACID_VOL, 0.1),          # weak-acid analyte
        (base_added_mL, TITRANT_CONC),    # Na+ from added base
        (acid_added_mL, TITRANT_CONC),    # Cl- from added acid
    )
    total_L = sum(vol for vol, _ in components) / 1000.0
    c_A, c_Na, c_HCl = [vol / 1000.0 * conc / total_L for vol, conc in components]
    return round(solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3), 2)

##############################################
# 奖励计算函数（支持消融实验）
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag, prev_overshoot_volume, last_action_volume, ablate_component=None):
    """Shaped reward for one titration step, with optional single-term ablation.

    Terms (each forced to 0 when ``ablate_component`` equals its name):
      * dense_reward: dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
      * step_penalty: constant per-step term
      * overshoot_penalty: sigmoid-scaled penalty when the pH crosses the target
        and the larger of the two errors exceeds overshoot_threshold
      * wrong_dir_penalty: -wrong_dir_factor * |current_ph - target_ph| when the pH
        is above target after base, or below target after acid
      * volume_penalty / volume_bonus: only right after an overshoot; penalise the
        current volume and reward using less than the overshooting volume
      * terminal_bonus: paid only on success (|current_ph - target_ph| <
        SUCCESS_THRESHOLD), doubled when reached before half of max_steps.
        Set ``reward_config["bonus_on_timeout"] = True`` to also pay it when the
        episode times out (the previous behaviour).

    Returns:
        ``(reward, is_terminal)``. The step is terminal on success or when
        ``steps_taken >= max_steps``. Non-terminal rewards are clipped to
        [reward_clip_min, reward_clip_max]; terminal rewards are not clipped.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("dense_lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio) if ablate_component != "dense_reward" else 0
    step_penalty = reward_config.get("step_penalty", -0.005) if ablate_component != "step_penalty" else 0
    overshoot_weight = reward_config.get("overshoot_weight", 0.2)
    overshoot_threshold = reward_config.get("overshoot_threshold", 0.1)

    # Overshoot: the pH crossed the target on this step.
    if (previous_ph - target_ph) * (current_ph - target_ph) < 0 and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(- (overshoot_magnitude - overshoot_threshold)))) if ablate_component != "overshoot_penalty" else 0
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("wrong_dir_factor", 1.0)
    wrong_dir_penalty = 0
    if (current_ph > target_ph and 'base' in reagent.lower()) or (current_ph < target_ph and 'acid' in reagent.lower()):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph) if ablate_component != "wrong_dir_penalty" else 0

    # After an overshoot on the previous step, discourage large follow-up volumes.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None:
        overshoot_volume_penalty = reward_config.get("overshoot_volume_penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume if ablate_component != "volume_penalty" else 0
        overshoot_volume_bonus = reward_config.get("overshoot_volume_bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume) if ablate_component != "volume_bonus" else 0

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    succeeded = abs(current_ph - target_ph) < SUCCESS_THRESHOLD
    is_terminal = succeeded or steps_taken >= max_steps
    # Previously a timed-out (failed) episode also received terminal_bonus, the same
    # amount as a success after max_steps/2 steps; it is now success-only by default.
    if is_terminal and (succeeded or reward_config.get("bonus_on_timeout", False)):
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("terminal_bonus", 3.0) * bonus_factor if ablate_component != "terminal_bonus" else 0
        raw_reward += terminal_bonus

    if is_terminal:
        return raw_reward, True

    reward_clip_max = reward_config.get("reward_clip_max", 4.0)
    reward_clip_min = reward_config.get("reward_clip_min", -4.0)
    return max(min(raw_reward, reward_clip_max), reward_clip_min), False

##############################################
# pH 模拟环境：PHSimEnv
##############################################
class PHSimEnv:
    """Simulated titration of a randomly chosen weak acid towards a random target pH.

    ``reset`` draws an acid type (mono-, di- or triprotic), its pKa value(s)
    from pools generated once in ``__init__``, and a target pH in [2, 11].
    ``step(action)`` adds ``action`` mL of strong base when the pH is below
    target (otherwise strong acid), re-solves the charge balance and scores the
    step with ``calculate_reward``.

    State vector: [current_ph, target_ph, pH_delta, error, last_action_volume]
    where pH_delta = current - previous pH and error = current - target pH.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol  # analyte volume, mL
        self.analyte_conc = analyte_conc          # analyte acid concentration, mol/L
        self.titrant_conc = titrant_conc          # added base/acid concentration, mol/L
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc  # analyte moles
        self.reward_config = {
            "dense_lambda": -0.03,
            "step_penalty": 0,
            "terminal_bonus": 3.9,
            "overshoot_weight": 0.2,
            "overshoot_threshold": 0.1,
            "wrong_dir_factor": 1,
            "reward_clip_max": 4.1,
            "reward_clip_min": -4.1,
            "overshoot_volume_penalty": 0.1,
            "overshoot_volume_bonus": 0.1
        }
        # Fixed per-instance pools of acids; episodes sample from these.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = [(random.uniform(2, 4), random.uniform(4, 7)) for _ in range(30)]
        self.triprotic_pKa_list = [(random.uniform(2, 4), random.uniform(4, 6), random.uniform(6, 8)) for _ in range(30)]
        self.reset()

    def _compute_ph(self):
        """Solve for the current pH (rounded to 2 dp) from this env's own settings.

        The calculate_pH_* helpers hard-code INITIAL_ACID_VOL, TITRANT_CONC and a
        0.1 M analyte, so the constructor arguments used to be ignored. Solving
        here honours them; with the default arguments the arithmetic is identical.
        """
        V_total = (self.initial_acid_vol + self.base_added_mL + self.acid_added_mL) / 1000.0
        c_A = self.initial_acid_vol / 1000.0 * self.analyte_conc / V_total
        c_Na = self.base_added_mL / 1000.0 * self.titrant_conc / V_total
        c_HCl = self.acid_added_mL / 1000.0 * self.titrant_conc / V_total
        if self.acid_type == 'monoprotic':
            ph = solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, self.acid_params)
        elif self.acid_type == 'diprotic':
            pKa1, pKa2 = self.acid_params
            ph = solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2)
        else:
            pKa1, pKa2, pKa3 = self.acid_params
            ph = solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        return round(ph, 2)

    def reset(self, acid_type=None, acid_params=None, target_ph=None):
        """Start a new episode; unspecified acid type, pKa(s) and target are drawn at random."""
        if acid_type is None:
            self.acid_type = random.choice(['monoprotic', 'diprotic', 'triprotic'])
        else:
            self.acid_type = acid_type

        if acid_params is None:
            if self.acid_type == 'monoprotic':
                self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
            elif self.acid_type == 'diprotic':
                self.acid_params = random.choice(self.diprotic_pKa_list)
            else:
                self.acid_params = random.choice(self.triprotic_pKa_list)
        else:
            self.acid_params = acid_params

        self.target_ph = round(random.uniform(2, 11), 2) if target_ph is None else target_ph
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None

        self.current_ph = self._compute_ph()
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return the 5-element float32 observation for the current step."""
        pH_delta = round(self.current_ph - self.previous_ph, 2)
        error = round(self.current_ph - self.target_ph, 2)
        return np.array([self.current_ph, self.target_ph, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def step(self, action, ablate_component=None):
        """Add ``action`` mL of titrant; return ``(state, reward, done, {'reagent': ...})``."""
        volume = float(action)
        self.last_action_volume = volume
        self.steps += 1
        # The reagent is chosen from the pH before this addition.
        if self.current_ph < self.target_ph:
            reagent = "strong_base"
            self.base_added_mL += volume
        else:
            reagent = "strong_acid"
            self.acid_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base_added_mL + self.acid_added_mL

        self.previous_ph = self.current_ph
        self.current_ph = self._compute_ph()

        state = self._get_state()
        reward, done = calculate_reward(
            previous_ph=self.previous_ph,
            current_ph=self.current_ph,
            target_ph=self.target_ph,
            steps_taken=self.steps,
            max_steps=MAX_STEPS,
            reagent=reagent,
            reward_config=self.reward_config,
            SUCCESS_THRESHOLD=SUCCESS_THRESHOLD,
            prev_overshoot_flag=self.prev_overshoot_flag,
            prev_overshoot_volume=self.prev_overshoot_volume,
            last_action_volume=self.last_action_volume,
            ablate_component=ablate_component
        )

        # Remember whether this step crossed the target, for next step's volume terms.
        current_overshoot = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
        if current_overshoot:
            self.prev_overshoot_flag = True
            self.prev_overshoot_volume = self.last_action_volume
        else:
            self.prev_overshoot_flag = False
            self.prev_overshoot_volume = None

        return state, reward, done, {'reagent': reagent}

##############################################
# 离散动作策略模型：DiscreteVolumeRegressor
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a fixed grid of titrant volumes (mL).

    The grid spans ``min_volume`` .. ``max_volume`` in ``step`` increments
    (values rounded to 2 decimals); the MLP maps an ``input_dim`` state to one
    logit per grid point.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The epsilon absorbs float error in the range/step ratio
        # (0.29 / 0.01 == 28.999999999999996); plain int() truncation would
        # silently drop max_volume from the grid.
        n_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(n_steps + 1)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )

    def forward(self, x):
        """Return one unnormalised logit per entry of ``discrete_volumes``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the policy.

        Returns ``(volume, log_prob)``: a (1, 1) float32 tensor holding the
        sampled volume, and the log-probability of the sampled index (still
        attached to the autograd graph). ``x`` must hold a single state, since
        ``.item()`` is applied to the sampled index.
        """
        logits = self.forward(x)
        if torch.isnan(logits).any():
            print("Logits contain NaN:", logits)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        volume = self.discrete_volumes[action_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32), log_prob

    def predict_volume(self, x):
        """Return the highest-logit (greedy) volume as a (1, 1) float32 tensor; single state only."""
        logits = self.forward(x)
        _, predicted_index = torch.max(logits, dim=1)
        volume = self.discrete_volumes[predicted_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32)

##############################################
# 在线训练：使用 REINFORCE 算法更新策略模型
##############################################
def train_reinforce(env, policy_model, optimizer, num_episodes=200, gamma=0.99, ablate_component=None,
                    model_path=None):
    """Train *policy_model* on *env* with REINFORCE (Monte-Carlo policy gradient).

    Each episode is rolled out until ``env.step`` reports ``done``. The
    discounted returns are standardised and used to weight the action
    log-probabilities. After training, the ``state_dict`` is saved with
    ``torch.save``.

    Args:
        env: environment exposing ``reset()``, ``step(volume, ablate_component=...)``,
            ``target_ph`` and ``current_ph``.
        policy_model: model with ``sample_action(state_tensor) -> (action, log_prob)``.
        optimizer: torch optimizer over ``policy_model``'s parameters.
        num_episodes: number of training episodes.
        gamma: discount factor for the returns.
        ablate_component: reward component forwarded to ``env.step`` (None = full reward).
        model_path: where to save the weights. When None, the historical file
            names are used; note that they always say ``200ep`` regardless of
            ``num_episodes``.
    """
    for episode in range(num_episodes):
        state = env.reset()
        done = False
        log_probs = []
        rewards = []
        # NOTE(review): no step cap here; this relies on env.step eventually returning done=True.
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action, log_prob = policy_model.sample_action(state_tensor)
            next_state, reward, done, _ = env.step(action.item(), ablate_component=ablate_component)
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state

        # Discounted return G_t = r_t + gamma * G_{t+1}, built back-to-front in O(n).
        discounted = []
        running = 0.0
        for r in reversed(rewards):
            running = r + gamma * running
            discounted.append(running)
        discounted.reverse()
        returns = torch.tensor(discounted, dtype=torch.float32)

        # Standardise returns; std of a single element is NaN, so only centre it then.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        else:
            returns = returns - returns.mean()

        # reshape(-1) accepts both 0-d and (1,)-shaped log-probs.
        log_prob_tensor = torch.cat([lp.reshape(-1) for lp in log_probs])
        loss = -(log_prob_tensor * returns).sum()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

        if episode % 50 == 0:
            total_reward = sum(rewards)
            print(f"Episode {episode}, Loss: {loss.item():.4f}, Total Reward: {total_reward:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")

    if model_path is None:
        model_path = f"volume_regressor_ablation_no_{ablate_component}_200ep.pth" if ablate_component else "volume_regressor_full_200ep.pth"
    torch.save(policy_model.state_dict(), model_path)
    print(f"模型已保存至 {model_path}")

##############################################
# 测试函数：对 test_configs 中的每个固定配置运行一次实验，统计成功率和成功实验的平均步数
##############################################
def test_model(policy_model, env, test_configs, ablate_component=None):
    """Evaluate *policy_model* once on every ``(acid_type, acid_params, target_ph)`` config.

    Each run is capped at ``MAX_STEPS`` actions. A run counts as a success when
    the final ``|env.current_ph - env.target_ph|`` is below ``SUCCESS_THRESHOLD``.

    Returns:
        dict with ``success_rate`` (percent), ``avg_steps`` (mean steps of the
        successful runs, 0.0 if none), ``success_count`` and ``total_experiments``.
        An empty *test_configs* yields a 0.0 success rate instead of dividing by zero.
    """
    total_experiments = len(test_configs)
    success_steps = []

    for acid_type, acid_params, target_ph in test_configs:
        state = env.reset(acid_type=acid_type, acid_params=acid_params, target_ph=target_ph)
        done = False
        steps = 0
        while not done and steps < MAX_STEPS:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action, _ = policy_model.sample_action(state_tensor)
            state, _, done, _ = env.step(action.item(), ablate_component=ablate_component)
            steps += 1

        # The loop already caps steps at MAX_STEPS, so only the final pH error matters.
        if abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD:
            success_steps.append(steps)

    success_count = len(success_steps)
    success_rate = success_count / total_experiments * 100 if total_experiments else 0.0
    avg_steps = np.mean(success_steps) if success_steps else 0.0

    print(f"\n测试结果 ({'No ' + ablate_component if ablate_component else 'Full Reward'}):")
    print(f"成功率: {success_rate:.2f}% ({success_count}/{total_experiments})")
    print(f"成功实验平均步数: {avg_steps:.2f}")

    return {"success_rate": success_rate, "avg_steps": avg_steps, "success_count": success_count, "total_experiments": total_experiments}

##############################################
# 生成可复现的固定测试配置（默认 200 个）
##############################################
def generate_test_configs(num_configs=200, seed=123):
    """Build a reproducible list of ``(acid_type, acid_params, target_ph)`` tuples.

    Side effect: re-seeds the *global* ``numpy.random`` and ``random``
    generators with *seed*, so any later global draws continue from here.

    ``acid_params`` is a float pKa for ``'monoprotic'`` and a tuple of 2 or 3
    ascending-range pKa values for ``'diprotic'`` / ``'triprotic'``.
    ``target_ph`` is drawn uniformly from [2, 11] and rounded to 2 decimals.
    """
    np.random.seed(seed)
    random.seed(seed)

    # Pools of 30 candidate pKa sets per acid type, sampled once up front.
    mono_pool = np.random.uniform(2, 6, size=30)
    di_pool = [tuple(random.uniform(lo, hi) for lo, hi in ((2, 4), (4, 7)))
               for _ in range(30)]
    tri_pool = [tuple(random.uniform(lo, hi) for lo, hi in ((2, 4), (4, 6), (6, 8)))
                for _ in range(30)]

    # Each picker draws one parameter set for its acid type.
    pickers = {
        'monoprotic': lambda: float(np.random.choice(mono_pool)),
        'diprotic': lambda: random.choice(di_pool),
        'triprotic': lambda: random.choice(tri_pool),
    }
    acid_kinds = list(pickers)

    def draw_one():
        kind = random.choice(acid_kinds)
        params = pickers[kind]()
        return kind, params, round(random.uniform(2, 11), 2)

    return [draw_one() for _ in range(num_configs)]

##############################################
# 主程序：运行消融实验并测试
##############################################
if __name__ == "__main__":
    # Train and evaluate one policy with the full reward, then one per ablated
    # reward component, and dump all test metrics to JSON.
    input_dim = 5
    learning_rate = 1e-4
    gamma = 0.99
    num_episodes = 500
    num_test_experiments = 500

    reward_components = [
        "dense_reward",
        "step_penalty",
        "overshoot_penalty",
        "wrong_dir_penalty",
        "volume_penalty",
        "volume_bonus",
        "terminal_bonus"
    ]

    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=TITRANT_CONC)
    test_configs = generate_test_configs(num_configs=num_test_experiments, seed=123)

    results = {}

    # None = full reward; every other entry removes one reward component.
    for component in [None] + reward_components:
        if component is None:
            print("\n=== 运行完整奖励的训练 ===")
            result_key = "full_reward"
        else:
            print(f"\n=== 消融实验：移除 {component} ===")
            result_key = f"no_{component}"
        # A fresh model and optimizer per run so runs do not share weights.
        policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
        optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
        train_reinforce(env, policy_model, optimizer, num_episodes=num_episodes, gamma=gamma, ablate_component=component)
        results[result_key] = test_model(policy_model, env, test_configs, ablate_component=component)

    # Use a dedicated name: binding `f` here would overwrite the module-level function `f`.
    with open("ablation_test_results.json", "w") as results_file:
        json.dump(results, results_file, indent=4)
    print("\n测试结果已保存至 ablation_test_results.json")

In [None]:
# 测试pid

In [None]:
import numpy as np
import csv
import ast
import math
import statistics

# Global parameter configuration for the PID titration benchmark.
TITRANT_CONC: float = 0.1       # concentration of the acid and base titrants (mol/L)
INITIAL_ACID_VOL: float = 11.0  # initial analyte acid volume (mL)
MAX_STEPS: int = 50             # maximum number of doses allowed per experiment
SUCCESS_THRESHOLD: float = 0.1  # allowed |pH - target pH| for a successful run
MIN_VOLUME: float = 0.01        # smallest dose the controller may return (mL)

# --- 物理化学引擎模块 ---

def get_acid_charge_factor(pH, pKas):
    """Return the mean negative charge carried per mole of a polyprotic weak acid.

    With K_i = 10**-pKa_i (pKas sorted ascending), the species carrying charge
    -i has weight beta_i * [H+]**(n - i), where beta_i = K_1 * ... * K_i and
    beta_0 = 1. The result is sum(i * w_i) / sum(w_i). It is 0.0 for an empty
    *pKas* and approaches len(pKas) at high pH.
    """
    h_conc = 10 ** (-pH)
    dissociation = [10 ** (-pk) for pk in sorted(pKas)]
    order = len(dissociation)

    # Cumulative products beta_0 .. beta_n.
    betas = [1.0]
    for ka in dissociation:
        betas.append(betas[-1] * ka)

    weights = [beta * (h_conc ** (order - idx)) for idx, beta in enumerate(betas)]
    normaliser = sum(weights)
    charge_sum = sum(idx * w for idx, w in enumerate(weights))
    return charge_sum / normaliser

def charge_balance_equation(pH, c_A, c_Na, c_HCl, pKas):
    """Return the charge-balance residual of the mixture at a trial *pH*.

    residual = [H+] + [Na+] - [OH-] - [Cl-] - c_A * (mean acid anion charge)

    All concentrations are in mol/L. ``solve_pH`` treats a positive residual
    as "trial pH too low" and bisects until it vanishes.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium  # Kw = 1e-14
    anion_charge = c_A * get_acid_charge_factor(pH, pKas)
    cations = hydronium + c_Na
    return cations - hydroxide - c_HCl - anion_charge

def solve_pH(base_vol, acid_vol, pKas, analyte_conc=0.1):
    """Return the equilibrium pH (rounded to 2 decimals) of the titration mixture.

    Args:
        base_vol: cumulative base titrant added (mL).
        acid_vol: cumulative acid titrant added (mL).
        pKas: list of pKa values of the weak analyte acid.
        analyte_conc: analyte acid concentration in the initial
            ``INITIAL_ACID_VOL`` sample (mol/L). It was previously hard-coded
            as 0.1, which is still the default.

    Both titrants are taken to have concentration ``TITRANT_CONC``.

    The charge-balance residual falls as pH rises (less [H+], more [OH-] and
    more acid anion charge). A 100-step bisection on [0, 14] therefore finds
    its root. A root outside that interval is clamped to the nearest end.
    """
    total_vol_L = (INITIAL_ACID_VOL + base_vol + acid_vol) / 1000
    c_A = (INITIAL_ACID_VOL * analyte_conc / 1000) / total_vol_L
    c_Na = (base_vol * TITRANT_CONC / 1000) / total_vol_L
    c_HCl = (acid_vol * TITRANT_CONC / 1000) / total_vol_L

    lo, hi = 0.0, 14.0
    for _ in range(100):
        mid = (lo + hi) / 2
        # Positive residual -> excess positive charge -> true pH lies above mid.
        if charge_balance_equation(mid, c_A, c_Na, c_HCl, pKas) > 0:
            lo = mid
        else:
            hi = mid
    return round((lo + hi) / 2, 2)

# --- 环境与控制器模块 ---

class TitrationEnv:
    """Stepwise titration simulator driven by an external dosing controller.

    Tracks cumulative acid/base titrant volumes and recomputes the pH with
    ``solve_pH`` after every dose.
    """

    def __init__(self):
        self.pKas = []          # pKa values of the analyte acid
        self.target_ph = 0      # pH the current experiment is trying to reach
        self.base_added = 0.0   # cumulative base titrant added (mL)
        self.acid_added = 0.0   # cumulative acid titrant added (mL)
        self.current_ph = 0.0   # pH after the most recent dose
        self.steps = 0          # doses applied in the current experiment

    @staticmethod
    def _normalize_pkas(pKas):
        """Return *pKas* as a flat list, wrapping a single scalar pKa.

        Tuples and arrays must be unpacked too. ``ast.literal_eval("(2.1, 5.3)")``
        yields a tuple, and wrapping it as ``[(2.1, 5.3)]`` would make
        ``get_acid_charge_factor`` fail when it negates a tuple.
        """
        if isinstance(pKas, (list, tuple, np.ndarray)):
            return list(pKas)
        return [pKas]

    def reset_state(self, pKas, target_ph):
        """Start a new experiment and return its initial pH (no titrant added).

        *pKas* may be a single pKa value or a list/tuple of pKa values.
        """
        self.pKas = self._normalize_pkas(pKas)
        self.target_ph = target_ph
        self.base_added = 0.0
        self.acid_added = 0.0
        self.current_ph = solve_pH(0, 0, self.pKas)
        self.steps = 0
        return self.current_ph

    def step(self, volume):
        """Add *volume* mL of titrant and return ``(pH, reagent, volume, is_overshoot)``.

        Base is added while the pH is below target, otherwise acid.
        ``is_overshoot`` is True when this dose moved the pH strictly across the target.
        """
        prev_ph = self.current_ph
        # Direction is decided here from the sign of the error, not by the controller.
        reagent = "Base" if self.current_ph < self.target_ph else "Acid"

        if reagent == "Base":
            self.base_added += volume
        else:
            self.acid_added += volume

        self.current_ph = solve_pH(self.base_added, self.acid_added, self.pKas)
        self.steps += 1

        # Overshoot: the pH ended on the opposite side of the target from where it started.
        is_overshoot = False
        if (prev_ph < self.target_ph and self.current_ph > self.target_ph) or \
           (prev_ph > self.target_ph and self.current_ph < self.target_ph):
            is_overshoot = True

        return self.current_ph, reagent, volume, is_overshoot

class PIDController:
    """Heuristic PID-style dose calculator.

    It works on the *magnitude* of the pH error only. Whether acid or base is
    added is decided by the environment, not by this controller.
    """

    def __init__(self, Kp=0.5, Ki=0.05, Kd=0.02):
        self.Kp = Kp
        self.Ki = Ki
        self.Kd = Kd
        self.reset()

    def reset(self):
        """Clear the accumulated integral and the remembered previous error."""
        self.integral = 0
        self.prev_error = None

    def get_volume(self, current_ph, target_ph):
        """Return the next dose volume (mL), never smaller than ``MIN_VOLUME``.

        Far from target (|error| > 2) the PID output is log-compressed;
        near target it is scaled down linearly for fine dosing.
        """
        err = abs(target_ph - current_ph)
        # NOTE: the integral sums |error| and is never decremented, so it only
        # grows during an experiment (there is no anti-windup).
        self.integral += err
        last_err, self.prev_error = self.prev_error, err
        d_err = 0 if last_err is None else err - last_err

        control = self.Kp * err + self.Ki * self.integral + self.Kd * d_err
        raw_volume = 1.2 * math.log1p(control) if err > 2.0 else 0.05 * control
        return max(round(raw_volume, 3), MIN_VOLUME)

# --- 主逻辑执行 ---

def _run_single_experiment(env, pid, row, out_f):
    """Run one CSV-described titration to completion, logging every step to *out_f*.

    Returns ``(success, steps, overshoots)`` for that experiment.
    """
    exp_id = row['Experiment']
    acid_type = row['Acid_Type']
    # literal_eval only parses Python literals, so CSV content cannot execute code.
    pKas = ast.literal_eval(row['Acid_Params'])
    target_ph = float(row['Target_pH'])

    curr_ph = env.reset_state(pKas, target_ph)
    pid.reset()

    out_f.write(f"实验 ID: {exp_id} | 酸类型: {acid_type} | 目标 pH: {target_ph}\n")
    out_f.write(f"起始状态: pH = {curr_ph:.2f}\n")

    overshoots = 0
    # At least one dose is always applied, even if the start pH is already in range.
    while True:
        vol = pid.get_volume(curr_ph, target_ph)
        curr_ph, reagent, added_vol, overshot = env.step(vol)

        status_msg = f"  步次 {env.steps:02d}: 添加 {reagent} {added_vol:.3f}mL -> 当前 pH: {curr_ph:.2f}"
        if overshot:
            overshoots += 1
            status_msg += " [过冲!]"
        out_f.write(status_msg + "\n")

        # Done on success, or on failure once the step budget is exhausted.
        if abs(curr_ph - target_ph) <= SUCCESS_THRESHOLD:
            success = True
            break
        if env.steps >= MAX_STEPS:
            success = False
            break

    out_f.write(f"实验结论: {'成功' if success else '失败'} | 最终 pH: {curr_ph:.2f} | 总步数: {env.steps}\n")
    out_f.write("-" * 40 + "\n\n")
    return success, env.steps, overshoots


def _write_summary(out_f, all_results, total_steps_global, total_overshoots_global):
    """Append the aggregate statistics section to the report."""
    total_exps = len(all_results)
    success_count = sum(1 for r in all_results if r['success'])
    step_counts = [r['steps'] for r in all_results]

    # Each statistic is guarded so an empty CSV gives a zeroed summary instead of
    # ZeroDivisionError / statistics.StatisticsError.
    success_rate = (success_count / total_exps) * 100 if total_exps else 0.0
    avg_steps = statistics.mean(step_counts) if step_counts else 0.0
    # Sample variance needs at least two data points.
    var_steps = statistics.variance(step_counts) if total_exps > 1 else 0
    # Overshoot rate = overshoot count / total dosing steps.
    overshoot_rate = (total_overshoots_global / total_steps_global) * 100 if total_steps_global else 0.0

    out_f.write("\n" + "="*60 + "\n")
    out_f.write("汇总统计报告\n")
    out_f.write("-" * 60 + "\n")
    out_f.write(f"总实验次数      : {total_exps}\n")
    out_f.write(f"操作成功率      : {success_rate:.2f}%\n")
    out_f.write(f"平均步数 (Mean) : {avg_steps:.2f}\n")
    out_f.write(f"步数方差 (Var)  : {var_steps:.2f}\n")
    out_f.write(f"平均步数±方差   : {avg_steps:.2f} ± {var_steps:.2f}\n")
    out_f.write(f"总操作步数      : {total_steps_global}\n")
    out_f.write(f"总过冲次数      : {total_overshoots_global}\n")
    out_f.write(f"整体过冲率      : {overshoot_rate:.2f}% (过冲次数/总步数)\n")
    out_f.write("="*60 + "\n")


def run_all_experiments(csv_path, report_path):
    """Run the PID baseline on every experiment in *csv_path* and write a text report.

    The CSV must provide ``Experiment``, ``Acid_Type``, ``Acid_Params`` (a
    Python literal) and ``Target_pH`` columns. The report at *report_path*
    contains a per-step log for each experiment followed by aggregate
    statistics.
    """
    env = TitrationEnv()
    pid = PIDController()

    all_results = []
    total_steps_global = 0
    total_overshoots_global = 0

    with open(csv_path, 'r', encoding='utf-8') as f:
        reader = list(csv.DictReader(f))

    with open(report_path, 'w', encoding='utf-8') as out_f:
        out_f.write("滴定实验操作详细记录报告\n")
        out_f.write("="*60 + "\n\n")

        for row in reader:
            success, steps, overshoots = _run_single_experiment(env, pid, row, out_f)
            total_steps_global += steps
            total_overshoots_global += overshoots
            all_results.append({'success': success, 'steps': steps})

        _write_summary(out_f, all_results, total_steps_global, total_overshoots_global)

    print(f"所有实验完成！详细日志已写入至: {report_path}")

if __name__ == "__main__":
    # Experiment definitions to read, and where the per-step report is written.
    input_csv = r'C:\Users\Admin\Documents\GitHub\ph_adjust_algorithm\upload\experiment_summary.csv'
    output_txt = r'C:\Users\Admin\Documents\GitHub\ph_adjust_algorithm\upload\experiment_report.txt'

    run_all_experiments(csv_path=input_csv, report_path=output_txt)