In [None]:
#bayesianfinal

In [None]:
import numpy as np 
import math
import logging
from scipy.stats import norm
from scipy.optimize import brentq, newton

# Configuration log (INFO level)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# -------------------------------
# Global tunable parameters
# -------------------------------
TITRATED_VOLUME = 11.0    # Initial volume of the solution being titrated (mL); divided by 1000 where litres are needed
ANALYTE_CONC = 0.1        # Concentration of the acid analyte in the titrated solution (mol/L)

# Two hydrochloric acid and two sodium hydroxide stock concentrations that the user can adjust.
HCL_CONC1 = 0.1           # Hydrochloric acid 1 concentration (mol/L)
HCL_CONC2 = 0.01          # Hydrochloric acid 2 concentration (mol/L)
NAOH_CONC1 = 0.1          # Sodium hydroxide 1 concentration (mol/L)
NAOH_CONC2 = 0.01         # Sodium hydroxide 2 concentration (mol/L)
TARGET_PH = 11           # Target pH passed to PHAdjustmentEnv.initialize() by main()
MAX_STEPS = 50            # Maximum number of titration steps per run

# -------------------------------
# Global reagent concentration dictionary: reagent name -> concentration (mol/L).
# The trailing "1" / "2" in each name distinguishes the two stock solutions.
# -------------------------------
REAGENTS = {
    'Dilute acid 1': HCL_CONC1,
    'Dilute acid 2': HCL_CONC2,
    'Dilute base 1': NAOH_CONC1,
    'Dilute base 2': NAOH_CONC2,
}

# -------------------------------
# pH calculation function (based on multiple buffer pairs)
# -------------------------------

def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the charge concentration carried by a polyprotic acid's anions.

    With cumulative dissociation constants ``beta_k = K_1 * ... * K_k`` the
    k-fold deprotonated species has concentration ``[H_nA] * beta_k / H**k``;
    the result is ``sum(k * [species_k])`` over all k.

    Args:
        c_A: Total concentration of the acid (all protonation states).
        H: Hydrogen-ion concentration.  ``H == 0`` is treated as an infinite
            power of H, so every anion term becomes 0.
        pKa_list: Stepwise pKa values (list or 1-D array, ints or floats).

    Returns:
        Sum of ``k * concentration`` over the deprotonated species.
    """
    # A float base is required: np.power(10, <negative integer>) raises
    # "Integers to negative integer powers are not allowed" for integer pKa.
    K = [np.power(10.0, np.clip(-pKa, -100, 100)) for pKa in pKa_list]

    # beta_k / H**k for k = 1..n, computed once and reused below.
    ratios = []
    cumulative_K = 1.0
    for k, K_k in enumerate(K, start=1):
        cumulative_K *= K_k
        h_power = np.power(H, k) if H != 0 else np.inf
        ratios.append(cumulative_K / h_power)

    denominator = 1.0
    for ratio in ratios:
        denominator += ratio
    # denominator is 1 plus non-negative terms, so it can never be zero.
    H_nA = c_A / denominator

    anion_charge = 0.0
    for k, ratio in enumerate(ratios, start=1):
        anion_charge += k * (H_nA * ratio)
    return anion_charge


def f(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
    """Return the charge-balance residual ``[H+] + [Na+] - [OH-] - anions - [Cl-]``.

    The residual is zero at the equilibrium pH.  This is a module-level
    function (no ``self``): callers invoke it with the pH and concentrations
    only.

    Args:
        pH: Trial pH.
        c_A: Total concentration of the weak acid.
        c_Na: Sodium concentration contributed by added base.
        c_HCl: Chloride concentration contributed by added acid.
        pKa_list: Stepwise pKa values of the weak acid.
    """
    H = 10 ** (-pH)
    Kw = 1e-14
    OH = Kw / H
    acid_anion_charge = calculate_acid_anion_charge(c_A, H, pKa_list)
    return H + c_Na - OH - acid_anion_charge - c_HCl


def solve_pH(c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
    """Solve the charge balance for pH by bisection on [0, 14].

    Stops early when ``|f(mid)| < 1e-10``; otherwise returns the midpoint of
    the bracket after 100 halvings.  No sign-change check is made up front, so
    when ``f`` has the same sign at both ends the search drifts to pH 14.

    Args:
        c_A: Total concentration of the weak acid.
        c_Na: Sodium concentration contributed by added base.
        c_HCl: Chloride concentration contributed by added acid.
        pKa_list: Stepwise pKa values of the weak acid.

    Returns:
        The pH at which the charge-balance residual (approximately) vanishes.
    """
    lo, hi = 0.0, 14.0
    # Cache f(lo): it only changes when lo moves, and then it equals f(mid).
    f_lo = f(lo, c_A, c_Na, c_HCl, pKa_list)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f(mid, c_A, c_Na, c_HCl, pKa_list)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo = mid
            f_lo = f_mid
    return (lo + hi) / 2.0

# -------------------------------
# pH adjustment environment
# -------------------------------
class PHAdjustmentEnv:
    def __init__(self):
        """Set up default bookkeeping, the action grid and random buffer priors.

        Draws ``num_buffers`` pKa values uniformly from [2, 6] and total mole
        amounts uniformly from [1e-6, 0.5], then centres a normal prior on each
        draw.  Episode-specific values (initial/target pH, step limit) stay
        ``None`` until :meth:`initialize` is called.
        """
        # Episode progress.
        self.done = False
        self.steps_taken = 0

        # Solution volume and cumulative reagent additions.
        self.total_volume = self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0
        self.last_acid_added = self.last_base_added = 0.0
        self.reagents = REAGENTS.copy()

        # Discrete action grid: every reagent paired with every multiple
        # (1x .. 999x) of the minimum addition volume (mL).
        self.min_addition_volume = 0.01
        self.addition_volumes = []
        for multiple in range(1, 1000):
            self.addition_volumes.append(self.min_addition_volume * multiple)
        self.action_space = []
        for reagent in self.reagents.keys():
            for volume in self.addition_volumes:
                self.action_space.append((reagent, volume))

        # Reward / policy tuning constants.
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5

        # Uncertain buffer system.  The pKa draw happens before the mole draw.
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        # The initial draw is kept as the fixed reference pKa.
        self.ref_pKa = self.pKa_list.copy()
        # Per-buffer pKa standard deviation tracked by the posterior update;
        # it starts at 0.2 (the prior below uses its own scale of 0.5).
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)

        # Priors: pKa ~ N(draw, 0.5) and total moles ~ N(draw, 0.005).
        self.priors = [
            {
                'P the': norm(loc=pka_draw, scale=0.5),
                'Total moles': norm(loc=moles_draw, scale=0.005),
            }
            for pka_draw, moles_draw in zip(self.pKa_list, self.buffer_total_moles)
        ]

        # Filled in by initialize().
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None

        # Measurement history.
        self.last_measured_ph = None
        self.prev_measured_ph = None

        # Overshoot tracking.
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        # Oscillation tracking and the switch to the "2" reagents.
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def initialize(self, init_pH: float, target_pH: float, max_steps: int, initial_volume: float = TITRATED_VOLUME) -> None:
        """Reset the episode state for a new titration run.

        Args:
            init_pH: Starting pH; also seeds the current, previous and
                measured pH values.
            target_pH: pH the titration should reach.
            max_steps: Step limit after which the episode is marked done.
            initial_volume: Starting solution volume; defaults to the
                module-level ``TITRATED_VOLUME``.
        """
        # Episode targets and limits.
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False

        # Every pH tracker starts from the initial reading.
        self.initial_ph = self.current_ph = self.previous_ph = init_pH
        self.last_measured_ph = self.prev_measured_ph = init_pH

        # Volumes and cumulative additions start from scratch.
        self.total_volume = self.previous_total_volume = initial_volume
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0

        # Clear overshoot / oscillation history.
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return ``10 ** x`` with the exponent clamped to [-100, 100].

        A float base is used on purpose: with an integer base, an integer
        ``x`` makes NumPy do integer arithmetic, which raises ValueError for
        negative exponents and overflows for large ones.

        Args:
            x: Exponent (int or float).

        Returns:
            ``10.0 ** clip(x, -100, 100)`` as a floating-point value.
        """
        return np.power(10.0, np.clip(x, -100, 100))
    
    def update_exp_ph(self, pH: float) -> None:
        if self.last_measured_ph is not None:
            self.prev_measured_ph = self.last_measured_ph
        else:
            self.prev_measured_ph = pH
        self.current_ph = pH
        self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """
        Calculate dynamic weights based on the current pKa_list, ref_pKa and pKa_std, construct a valid pKa array,
        The length of the array is equal to the number of buffer pairs.
        """
        weight_max = 0.2
        k = 1.0
        pKa_eff_array = np.zeros(self.num_buffers)
        for i in range(self.num_buffers):
            weight_i = weight_max * (1 - np.tanh(k * self.pKa_std[i]))
            pKa_eff_array[i] = self.ref_pKa[i] + weight_i * (self.pKa_list[i] - self.ref_pKa[i])
        return pKa_eff_array

    def compute_required_volume(self) -> float:
        """Estimate the reagent volume (mL) needed to move from the current pH to the target.

        Base is assumed when the current pH is below the target, acid
        otherwise; the "2" reagent is used once ``use_secondary_reagents`` is
        set.  The volume is the root, found with ``brentq`` on [0, 10] mL, of
        ``predicted_pH(volume) - target_ph``, where the prediction uses the
        effective pKa array.

        Returns:
            The required volume in mL, or 0.0 when no root can be found.
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()

        # Too acidic -> add base; otherwise add acid.
        adding_base = self.current_ph < self.target_ph
        tier = '2' if self.use_secondary_reagents else '1'
        reagent = ('Dilute base ' if adding_base else 'Dilute acid ') + tier
        conc = self.reagents[reagent]

        def f_vol(x):
            """Predicted pH after adding ``x`` mL of the reagent, minus the target."""
            add_moles = conc * (x / 1000.0)
            if adding_base:
                base_moles = self.base_added_moles + add_moles
                acid_moles = self.acid_added_moles
            else:
                base_moles = self.base_added_moles
                acid_moles = self.acid_added_moles + add_moles
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            c_A_new = n_analyte / new_total_volume
            c_Na_new = base_moles / new_total_volume
            c_HCl_new = acid_moles / new_total_volume
            pH_new = solve_pH(c_A_new, c_Na_new, c_HCl_new, effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, 10)
        except ValueError as exc:
            # brentq raises ValueError when f_vol has the same sign at both
            # ends of the bracket, i.e. the target is not reachable in 10 mL.
            logging.debug("No required volume found in [0, 10] mL: %s", exc)
        except Exception as exc:
            # Keep the best-effort fallback, but make the failure visible.
            logging.warning("Required-volume estimate failed, using 0.0: %s", exc)
        return 0.0

    def step(self, action: tuple, mode: str = 'Simulate') -> tuple:
        """
        mode parameter:
          -'simulate': automatically calculate the current pH (call recalc_ph calculation),
          -'manual': Prompts the user for pH value (interactive).
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume

            if 'Acid' in reagent.lower():
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif 'Base' in reagent.lower():
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles

            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            if current_for_direction > self.target_ph and 'Base' in reagent.lower():
                return self.current_ph, -100, True, {}
            if current_for_direction < self.target_ph and 'Acid' in reagent.lower():
                return self.current_ph, -100, True, {}

            # Update measured pH based on mode selection
            if mode == 'Simulate':
                new_pH = self.recalc_ph()
                self.update_exp_ph(new_pH)
            elif mode == 'Manual':
                while True:
                    user_input = input("Please enter the currently measured pH value: ")
                    try:
                        manual_ph = float(user_input)
                        break
                    except ValueError:
                        print("Incorrect input format, please enter a number (e.g. 7.0).")
                self.update_exp_ph(manual_ph)

            # Detection of pH oscillations at minimum dosage
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("The pH oscillation at the minimum dripping amount was detected, the cumulative number of times:%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")

            self.steps_taken += 1

            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            # ---Modified ideal volume calculation part ---
            error = abs(self.current_ph - self.target_ph)
            ph_change = abs(self.current_ph - (self.prev_measured_ph if self.prev_measured_ph is not None else self.current_ph))
            bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
            uncertainties = [prior['P the'].std() for prior in self.priors]
            avg_uncertainty = np.mean(uncertainties)
            max_uncertainty = 1.0
            uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
            buffer_mean = np.mean(self.buffer_total_moles)
            ref_buffer = 0.5
            buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
            buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
            alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
            required_vol = self.compute_required_volume()
            combined_value = error + 0.1 * required_vol
            min_vol = self.min_addition_volume
            max_vol = max(self.addition_volumes)
            ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)
            # --------------------------------------------

            current_error = abs(self.current_ph - self.target_ph)
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and 'Acid' in reagent.lower():
                penalty = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= penalty
            if self.target_ph < current_for_direction and 'Base' in reagent.lower():
                penalty = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= penalty

            if self.steps_taken > 0:
                if 'Acid' in reagent.lower():
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_acid_added
                elif 'Base' in reagent.lower():
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_base_added
                else:
                    reagent_conc = 1.0
                    last_added = 0.0

                overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                                     self.target_ph, reagent,
                                                                     last_added, reagent_conc,
                                                                     self.min_addition_volume)
                if overshoot_flag:
                    self.overshoot_occurred = True
                    self.overshoot_reagent = reagent
                    if new_thresh is not None:
                        if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                            self.overshoot_threshold = new_thresh

            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True

            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.error("An exception occurs when executing step:%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Decide whether the last addition overshot the target.

        An overshoot is flagged when the pH crossed the target or ended up
        farther from it than before.  In that case a new volume cap is
        proposed: half of the volume just added (``last_added_moles`` converted
        back through ``reagent_conc``), but never below ``min_addition``.

        Returns:
            ``(overshoot, new_threshold)``; ``new_threshold`` is ``None`` when
            no overshoot was detected.
        """
        offset_before = prev_ph - target_ph
        offset_after = current_ph - target_ph
        crossed_target = offset_before * offset_after < 0
        moved_away = abs(offset_after) > abs(offset_before)
        if not (crossed_target or moved_away):
            return False, None
        dosed_volume = last_added_moles * 1000.0 / reagent_conc
        return True, max(dosed_volume / 2, min_addition)

    def env_copy(self) -> 'Ph adjustment env':
        env_copied = PHAdjustmentEnv()
        env_copied.total_volume = self.total_volume
        env_copied.previous_total_volume = self.previous_total_volume
        env_copied.acid_added_moles = self.acid_added_moles
        env_copied.base_added_moles = self.base_added_moles
        env_copied.acid_volume = self.acid_volume
        env_copied.base_volume = self.base_volume
        env_copied.current_ph = self.current_ph
        env_copied.previous_ph = self.previous_ph
        env_copied.target_ph = self.target_ph
        env_copied.steps_taken = self.steps_taken
        env_copied.done = self.done
        env_copied.num_buffers = self.num_buffers
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        env_copied.priors = self.priors.copy()
        env_copied.epsilon = self.epsilon
        env_copied.direction_penalty_factor = self.direction_penalty_factor
        env_copied.tol = self.tol
        env_copied.reagents = self.reagents.copy()
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.max_steps = self.max_steps
        env_copied.vol_ideal_factor = self.vol_ideal_factor
        env_copied.ph_rate_threshold = self.ph_rate_threshold
        env_copied.ph_rate_bonus_factor = self.ph_rate_bonus_factor
        env_copied.last_measured_ph = self.last_measured_ph
        env_copied.prev_measured_ph = self.prev_measured_ph
        env_copied.overshoot_threshold = self.overshoot_threshold
        env_copied.oscillation_count = self.oscillation_count
        env_copied.use_secondary_reagents = self.use_secondary_reagents
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        return env_copied

    def recalc_ph(self) -> float:
        """Recompute the solution pH from the cumulative acid and base additions.

        Concentrations are the analyte, added-base and added-acid moles divided
        by the current total volume in litres; the pH is then obtained from the
        module-level ``solve_pH`` using the effective pKa values.

        Returns:
            The model-predicted pH of the current solution.
        """
        V_total = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        c_A = n_analyte / V_total
        c_Na = self.base_added_moles / V_total
        c_HCl = self.acid_added_moles / V_total
        pKa_list = self.get_effective_pka_array().tolist()
        # solve_pH is a module-level function, not a method of this class;
        # calling it through ``self`` raises AttributeError.
        return solve_pH(c_A, c_Na, c_HCl, pKa_list)


    def select_best_action(self) -> tuple:
        def filter_by_global_threshold(candidates):
            if self.overshoot_threshold is not None:
                filtered = [a for a in candidates if a[1] <= self.overshoot_threshold]
                if filtered:
                    return filtered
            return candidates

        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph

        if self.use_secondary_reagents:
            if self.overshoot_occurred:
                if 'Base' in self.overshoot_reagent.lower():
                    allowed_reagent = [r for r in self.reagents.keys() if 'Acid 2' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'Base 2' in r.lower()]
            else:
                if current_for_direction < self.target_ph:
                    allowed_reagent = [r for r in self.reagents.keys() if 'Dilute base 2' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'Dilute acid 2' in r.lower()]
        else:
            if self.overshoot_occurred:
                if 'Base' in self.overshoot_reagent.lower():
                    allowed_reagent = [r for r in self.reagents.keys() if 'Acid 1' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'Base 1' in r.lower()]
                self.overshoot_occurred = False
                self.overshoot_reagent = None
            else:
                if current_for_direction < self.target_ph:
                    allowed_reagent = [r for r in self.reagents.keys() if 'Dilute base 1' in r.lower()]
                else:
                    allowed_reagent = [r for r in self.reagents.keys() if 'Dilute acid 1' in r.lower()]
        
        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        candidate_actions = filter_by_global_threshold(candidate_actions)

        error = abs(current_for_direction - self.target_ph)
        ph_change = abs(current_for_direction - (self.prev_measured_ph if self.prev_measured_ph is not None else current_for_direction))
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['P the'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
        buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        sampled_pKa = []
        sampled_total_moles = []
        for prior in self.priors:
            sampled_pKa.append(prior['P the'].rvs())
            sampled_total_moles.append(prior['Total moles'].rvs())
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH the current state would show under sampled parameters.

        A copy of the environment receives the sampled pKa values and total
        moles, and its pH is recomputed.  ``action`` is accepted but not used
        here: the prediction reflects the additions already recorded.

        Returns:
            The pH recomputed on the copy.
        """
        scratch_env = self.env_copy()
        scratch_env.pKa_list = np.array(sampled_pKa)
        scratch_env.buffer_total_moles = np.array(sampled_total_moles)
        return scratch_env.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refresh the buffer priors from one pH observation (particle filter).

        Steps:
          1. Sample 1000 particles from the current priors.
          2. Predict the pH for each particle.
          3. Weight each particle by a normal likelihood (scale 0.01) of the
             observed pH around its prediction.
          4. Resample the particles in proportion to their weights.
          5. Replace each buffer's priors, ``pKa_list``, ``buffer_total_moles``
             and ``pKa_std`` with the mean / std of the resampled values.
        """
        num_particles = 1000

        # Steps 1-3: draw, predict and score every particle.
        particles = []
        likelihoods = []
        for _ in range(num_particles):
            pka_draw, moles_draw = self.sample_parameters()
            predicted_ph = self.predict_ph(action, pka_draw, moles_draw)
            likelihoods.append(norm.pdf(observed_ph, loc=predicted_ph, scale=0.01))
            particles.append((pka_draw, moles_draw))

        # The small offset keeps the weights valid when every likelihood
        # underflows to zero.
        weights = np.array(likelihoods) + 1e-10
        weights /= np.sum(weights)

        # Step 4: resample with replacement.
        indices = np.random.choice(range(num_particles), size=num_particles, p=weights)

        # Step 5a: summarise the survivors per buffer.  The 1e-3 added to each
        # std keeps the refreshed priors from collapsing to zero width.
        summaries = []
        for i in range(self.num_buffers):
            pKa_samples = np.array([particles[idx][0][i] for idx in indices])
            total_moles_samples = np.array([particles[idx][1][i] for idx in indices])
            summaries.append((
                np.mean(pKa_samples),
                np.std(pKa_samples) + 1e-3,
                np.mean(total_moles_samples),
                np.std(total_moles_samples) + 1e-3,
            ))

        # Step 5b: install the new priors and point estimates.
        for i, (mean_pKa, std_pKa, mean_total_moles, std_total_moles) in enumerate(summaries):
            self.priors[i]['P the'] = norm(loc=mean_pKa, scale=std_pKa)
            self.priors[i]['Total moles'] = norm(loc=mean_total_moles, scale=std_total_moles)
            self.pKa_list[i] = mean_pKa
            self.buffer_total_moles[i] = mean_total_moles
            self.pKa_std[i] = std_pKa

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        if abs(observed_ph - self.target_ph) < 0.1:
            self.done = True
            return None, True
        new_ph, reward, done, _ = self.step(action, mode='Manual')
        self.update_posteriors(action, new_ph)
        next_action, _ = self.select_best_action()
        return next_action, done

# -------------------------------
# main program
# -------------------------------
def main():
    """Run an interactive titration session towards ``TARGET_PH``.

    Starts from a hard-coded initial pH of 6.9, prints one recommended
    addition per round and lets the environment prompt for the measured pH,
    until the target is reached or the environment reports that it is done.
    A summary of the added volumes is printed at the end.
    """
    initial_volume = TITRATED_VOLUME

    # Re-sync the shared reagent table with the module-level concentrations.
    REAGENTS.update({
        'Dilute acid 1': HCL_CONC1,
        'Dilute acid 2': HCL_CONC2,
        'Dilute base 1': NAOH_CONC1,
        'Dilute base 2': NAOH_CONC2,
    })

    # The initial pH is set by hand.
    initial_ph = 6.9
    logging.info("Initial pH = %.2f", initial_ph)

    env = PHAdjustmentEnv()
    env.initialize(init_pH=initial_ph, target_pH=TARGET_PH, max_steps=MAX_STEPS, initial_volume=initial_volume)

    measured_ph = env.current_ph
    action, done = env.select_best_action()

    # Stop as soon as the environment is done or the pH is within 0.1 of target.
    while not done and not (abs(measured_ph - env.target_ph) < 0.1):
        cap = env.overshoot_threshold
        overshoot_msg = (
            "(Overshoot limit: the maximum dripping volume is {:.2f} M l）".format(cap)
            if cap is not None else ""
        )
        print("Current pH = {:.2f}, recommended operation: add {} {}".format(measured_ph, action, overshoot_msg))
        action, done = env.suggest_next_action(action, measured_ph)
        measured_ph = env.current_ph

    print("The experiment ends when the target pH is reached or the maximum number of steps is exceeded.")
    summary = "Total amount of added acid:{:.2f} mL, total alkali added amount:{:.2f} mL, total number of steps:{}, final pH = {:.2f}"
    print(summary.format(env.acid_volume, env.base_volume, env.steps_taken, measured_ph))

# Python sets __name__ to '__main__' for the top-level script (and for a
# notebook cell); no module is ever named 'Main', so that test never ran main().
if __name__ == '__main__':
    main()


In [None]:
# Generate data

In [None]:
import numpy as np
import math
import logging
import json
import random
from scipy.stats import norm
from scipy.optimize import brentq

# Configuration log (INFO level)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Fixed random seeds to ensure experimental repeatability
np.random.seed(42)
random.seed(42)

# -------------------------------
# Global tunable parameters (consistent with the first code)
# -------------------------------
# NOTE(review): the first notebook cell sets TITRATED_VOLUME = 11.0 although this
# section is described as consistent with it — confirm which value is intended.
TITRATED_VOLUME = 10.0    # Initial volume of the solution being titrated (mL); divided by 1000 where litres are needed
ANALYTE_CONC = 0.1        # Concentration of the acid analyte in the titrated solution (mol/L)
HCL_CONC1 = 0.1           # Hydrochloric acid 1 concentration (mol/L)
HCL_CONC2 = 0.01          # Hydrochloric acid 2 concentration (mol/L)
NAOH_CONC1 = 0.1          # Sodium hydroxide 1 concentration (mol/L)
NAOH_CONC2 = 0.01         # Sodium hydroxide 2 concentration (mol/L)
MAX_STEPS = 50            # Maximum number of titration steps per run

# Reagent concentration dictionary: reagent name -> concentration (mol/L)
REAGENTS = {
    'Dilute acid 1': HCL_CONC1,
    'Dilute acid 2': HCL_CONC2,
    'Dilute base 1': NAOH_CONC1,
    'Dilute base 2': NAOH_CONC2,
}

# Mapping from reagent name to a discrete action index (0-3), in the same
# order as the REAGENTS keys.
reagent_mapping = {
    'Dilute acid 1': 0,
    'Dilute acid 2': 1,
    'Dilute base 1': 2,
    'Dilute base 2': 3,
}

# -------------------------------
# pH calculation function (consistent with the first code)
# -------------------------------
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the charge concentration carried by a polyprotic acid's anions.

    With cumulative dissociation constants ``beta_k = K_1 * ... * K_k`` the
    k-fold deprotonated species has concentration ``[H_nA] * beta_k / H**k``;
    the result is ``sum(k * [species_k])`` over all k.

    Args:
        c_A: Total concentration of the acid (all protonation states).
        H: Hydrogen-ion concentration.  ``H == 0`` is treated as an infinite
            power of H, so every anion term becomes 0.
        pKa_list: Stepwise pKa values (list or 1-D array, ints or floats).

    Returns:
        Sum of ``k * concentration`` over the deprotonated species.
    """
    # A float base is required: np.power(10, <negative integer>) raises
    # "Integers to negative integer powers are not allowed" for integer pKa.
    K = [np.power(10.0, np.clip(-pKa, -100, 100)) for pKa in pKa_list]

    # beta_k / H**k for k = 1..n, computed once and reused below.
    ratios = []
    cumulative_K = 1.0
    for k, K_k in enumerate(K, start=1):
        cumulative_K *= K_k
        h_power = np.power(H, k) if H != 0 else np.inf
        ratios.append(cumulative_K / h_power)

    denominator = 1.0
    for ratio in ratios:
        denominator += ratio
    # denominator is 1 plus non-negative terms, so it can never be zero.
    H_nA = c_A / denominator

    anion_charge = 0.0
    for k, ratio in enumerate(ratios, start=1):
        anion_charge += k * (H_nA * ratio)
    return anion_charge

def f(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_eff_array: np.ndarray) -> float:
    """Return the charge-balance residual ``[H+] + [Na+] - [OH-] - anions - [Cl-]``.

    The residual is zero at the equilibrium pH.

    Args:
        pH: Trial pH.
        c_A: Total concentration of the weak acid.
        c_Na: Sodium concentration contributed by added base.
        c_HCl: Chloride concentration contributed by added acid.
        pKa_eff_array: Stepwise pKa values; an ndarray or any sequence that
            ``np.asarray`` accepts (a plain list works too).
    """
    H = 10 ** (-pH)
    Kw = 1e-14
    OH = Kw / H
    # np.asarray lets callers pass a plain list as well as an ndarray;
    # calling .tolist() directly would fail on a list.
    pKa_values = np.asarray(pKa_eff_array).tolist()
    acid_anion_charge = calculate_acid_anion_charge(c_A, H, pKa_values)
    return H + c_Na - OH - acid_anion_charge - c_HCl

def solve_pH(c_A: float, c_Na: float, c_HCl: float, pKa_eff_array: np.ndarray) -> float:
    """Solve the charge balance for pH by bisection on [0, 14].

    Stops early when ``|f(mid)| < 1e-10``; otherwise returns the midpoint of
    the bracket after 100 halvings.  No sign-change check is made up front, so
    when ``f`` has the same sign at both ends the search drifts to pH 14.

    Args:
        c_A: Total concentration of the weak acid.
        c_Na: Sodium concentration contributed by added base.
        c_HCl: Chloride concentration contributed by added acid.
        pKa_eff_array: Stepwise pKa values of the weak acid.

    Returns:
        The pH at which the charge-balance residual (approximately) vanishes.
    """
    lo, hi = 0.0, 14.0
    # Cache f(lo): it only changes when lo moves, and then it equals f(mid),
    # so there is no need to re-evaluate it on every iteration.
    f_lo = f(lo, c_A, c_Na, c_HCl, pKa_eff_array)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f(mid, c_A, c_Na, c_HCl, pKa_eff_array)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo = mid
            f_lo = f_mid
    return (lo + hi) / 2.0

def calculate_pH_custom(acid_total_moles: float, base_total_moles: float,
                        acid_volume: float, base_volume: float, pKa_list=None) -> float:
    """Return the solution pH, rounded to two decimals, for given cumulative additions.

    Args:
        acid_total_moles: Total moles of acid added so far.
        base_total_moles: Total moles of base added so far.
        acid_volume: Total acid volume added; summed with ``TITRATED_VOLUME``
            and divided by 1000 to obtain litres.
        base_volume: Total base volume added, handled like ``acid_volume``.
        pKa_list: Stepwise pKa values; a single pKa of 4.21 is used when
            omitted.
    """
    litres = (TITRATED_VOLUME + acid_volume + base_volume) / 1000.0
    analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
    pKa_eff_array = np.array([4.21] if pKa_list is None else pKa_list)
    raw_ph = solve_pH(
        analyte_moles / litres,
        base_total_moles / litres,
        acid_total_moles / litres,
        pKa_eff_array,
    )
    return round(raw_ph, 2)

# -------------------------------
# pH adjustment environment class (titration scheme is consistent with the first code)
# -------------------------------
class PHAdjustmentEnv:
    def __init__(self):
        """Create an environment with empty dosing state and placeholder chemistry.

        Target pH, current pH and the step budget are left as ``None`` here;
        ``initialize`` assigns them together with the actual pKa values.
        """
        # --- dosing bookkeeping -------------------------------------------
        self.steps_taken = 0
        self.done = False
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0

        # --- action space: every reagent x every multiple of the minimum dose
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys()
                             for volume in self.addition_volumes]

        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4

        # --- placeholder buffer model (replaced by ``initialize``) ---------
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        # Defined here so that ``env_copy`` can read it even when
        # ``initialize`` has not been called yet (it used to raise
        # AttributeError in that case).
        self.acid_type = None

        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None

        # One pair of normal priors (pKa, total moles) per buffer.
        self.priors = []
        for i in range(self.num_buffers):
            prior = {
                'P the': norm(loc=self.pKa_list[i], scale=0.5),
                'Total moles': norm(loc=self.buffer_total_moles[i], scale=0.005)
            }
            self.priors.append(prior)

        # --- dose-size heuristic parameters --------------------------------
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5

        # --- measurement history and overshoot tracking --------------------
        self.last_measured_ph = None
        self.prev_measured_ph = None
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.last_action = None

    def get_state(self):
        """Return a dictionary snapshot of the observable environment state.

        ``'Ph diff'`` and ``'Error'`` carry the same value (current minus
        target pH, two decimals) and are ``None`` while no target is set.
        ``'Ph delta'`` is the change between the two most recent measurements,
        or ``None`` when there is no previous measurement.
        """
        if self.target_ph is None:
            signed_error = None
        else:
            signed_error = round(self.current_ph - self.target_ph, 2)

        if self.last_action is None:
            previous_dose = 0
        else:
            previous_dose = self.last_action[1]

        if self.prev_measured_ph is None:
            measured_delta = None
        else:
            measured_delta = round(self.last_measured_ph - self.prev_measured_ph, 2)

        snapshot = {}
        snapshot['P h'] = round(self.current_ph, 2)
        snapshot['Target ph'] = self.target_ph
        snapshot['Ph diff'] = signed_error
        snapshot['Acid volume'] = self.acid_volume
        snapshot['Base volume'] = self.base_volume
        snapshot['Total volume'] = self.total_volume
        snapshot['Steps taken'] = self.steps_taken
        snapshot['Error'] = signed_error
        snapshot['Last action'] = self.last_action
        snapshot['Ph delta'] = measured_delta
        snapshot['Last added volume'] = previous_dose
        return snapshot

    def initialize(self, init_pH: float, target_pH: float, max_steps: int,
                   acid_type: str = 'Mono', pKa_list=None, initial_volume: float = TITRATED_VOLUME) -> None:
        """Reset the environment for a new titration episode.

        Args:
            init_pH: Starting pH of the sample.
            target_pH: pH the episode tries to reach.
            max_steps: Step budget after which the episode ends.
            acid_type: ``'Mono'``, ``'From'`` or ``'Tri'``; only used to draw
                random pKa values when ``pKa_list`` is omitted. Any other value
                falls back to a single pKa of 4.0.
            pKa_list: Explicit pKa values; rounded to two decimals.
            initial_volume: Starting sample volume in mL.
        """
        if pKa_list is None:
            if acid_type == 'Mono':
                pKa_list = [np.random.uniform(1, 5)]
            elif acid_type == 'From':
                pKa_list = [np.random.uniform(1, 3), np.random.uniform(4, 6)]
            elif acid_type == 'Tri':
                pKa_list = [np.random.uniform(1, 3), np.random.uniform(4, 5), np.random.uniform(6, 7)]
            else:
                pKa_list = [4.0]
        pKa_list = [round(val, 2) for val in pKa_list]

        # Buffer model and its priors.
        self.num_buffers = len(pKa_list)
        self.pKa_list = np.array(pKa_list)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)
        self.priors = [
            {
                'P the': norm(loc=self.pKa_list[i], scale=0.5),
                'Total moles': norm(loc=self.buffer_total_moles[i], scale=0.005)
            }
            for i in range(self.num_buffers)
        ]
        self.acid_type = acid_type

        # Episode targets and pH history.
        self.initial_ph = init_pH
        self.current_ph = init_pH
        self.previous_ph = init_pH
        self.target_ph = target_pH
        self.max_steps = max_steps
        self.last_measured_ph = init_pH
        self.prev_measured_ph = init_pH

        # Dosing bookkeeping. The "last added" fields and the last action are
        # cleared as well so that a reused environment does not carry the
        # previous episode's final dose into the new one.
        self.steps_taken = 0
        self.done = False
        self.total_volume = initial_volume
        self.previous_total_volume = initial_volume
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0
        self.last_action = None

        # Overshoot / oscillation tracking.
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return ``10 ** x`` with the exponent clamped to [-100, 100]."""
        bounded_exponent = np.clip(x, -100, 100)
        return np.power(10, bounded_exponent)

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading (rounded to two decimals).

        The former latest reading becomes the previous one; when there is no
        earlier reading, the new value is used for both.
        """
        reading = round(pH, 2)
        earlier = self.last_measured_ph
        self.prev_measured_ph = reading if earlier is None else earlier
        self.current_ph = reading
        self.last_measured_ph = reading

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend each reference pKa with its current estimate.

        Every buffer gets ``ref + w * (estimate - ref)`` with
        ``w = 0.2 * (1 - tanh(std))``, so a larger ``pKa_std`` keeps the
        result closer to the reference value.
        """
        weight_max = 0.2
        k = 1.0
        blended = np.zeros(self.num_buffers)
        for idx in range(self.num_buffers):
            reference = self.ref_pKa[idx]
            trust = weight_max * (1 - np.tanh(k * self.pKa_std[idx]))
            blended[idx] = reference + trust * (self.pKa_list[idx] - reference)
        return blended

    def compute_required_volume(self) -> float:
        """Estimate the titrant volume (mL, within 0-10) that reaches the target pH.

        Base is assumed when the current pH is below the target, acid
        otherwise; the primary or secondary reagent is chosen by
        ``use_secondary_reagents``. The root of "predicted pH minus target" is
        searched with ``brentq``; any failure yields 0.0.
        """
        analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        pka_eff = self.get_effective_pka_array()

        raising_ph = self.current_ph < self.target_ph
        tier = '2' if self.use_secondary_reagents else '1'
        reagent_name = ('Dilute base ' if raising_ph else 'Dilute acid ') + tier
        titrant_conc = self.reagents[reagent_name]

        def ph_gap(extra_ml):
            # Predicted pH after adding ``extra_ml`` more titrant, relative to the target.
            dose = titrant_conc * (extra_ml / 1000.0)
            if raising_ph:
                base_moles = self.base_added_moles + dose
                acid_moles = self.acid_added_moles
            else:
                base_moles = self.base_added_moles
                acid_moles = self.acid_added_moles + dose
            litres = (TITRATED_VOLUME + self.acid_volume + self.base_volume + extra_ml) / 1000.0
            trial_ph = solve_pH(analyte_moles / litres, base_moles / litres,
                                acid_moles / litres, pka_eff)
            return trial_ph - self.target_ph

        try:
            return brentq(ph_gap, 0, 10)
        except Exception:
            # Best effort: e.g. no sign change inside the 0-10 mL bracket.
            return 0.0

    def step(self, action: tuple, _: float = None) -> tuple:
        """Apply one titrant addition and return ``(pH, reward, done, info)``.

        The dose is added to the acid or base totals, the new pH is simulated
        with ``calculate_pH_custom`` and recorded, and a reward is computed
        from the error improvement, the remaining error, the distance of the
        dose from a heuristic ideal volume, a per-step time penalty and
        wrong-direction penalties. Overshoot and oscillation trackers are
        updated on the way.

        Args:
            action: ``(reagent_name, volume_ml)``.
            _: Unused; kept for call compatibility.

        Returns:
            ``(current_ph, reward, done, {})``. When the episode is already
            finished the reward is 0; an out-of-range/NaN pH or any exception
            ends the episode with reward -100.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            # Reagent names look like 'Dilute acid 1'. They are lower-cased
            # before classification, so the keywords must be lower case too;
            # the previous capitalised keywords could never match, which
            # meant no dose was ever recorded and no penalty ever applied.
            reagent_kind = reagent.lower()
            is_acid = 'acid' in reagent_kind
            is_base = 'base' in reagent_kind

            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if is_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif is_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles

            # Flat penalty, judged against the pH measured before this dose.
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            penalty = 0
            if current_for_direction > self.target_ph and is_base:
                penalty = -100
                logging.info("Use the wrong reagent (base), give a penalty, but continue the experiment.")
            if current_for_direction < self.target_ph and is_acid:
                penalty = -100
                logging.info("Use the wrong reagent (acid), give a penalty, but continue the experiment.")

            simulated_ph = calculate_pH_custom(self.acid_added_moles, self.base_added_moles,
                                               self.acid_volume, self.base_volume,
                                               pKa_list=self.get_effective_pka_array().tolist())
            self.update_exp_ph(simulated_ph)
            self.last_action = action

            # Crossing the target with the smallest possible dose, three
            # times, switches the policy to the secondary reagents.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("The pH oscillation at the minimum dripping amount was detected, the cumulative number of times:%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")
            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            current_error = abs(self.current_ph - self.target_ph)
            previous_error = abs(self.previous_ph - self.target_ph)

            # Heuristic "ideal" dose used to price the chosen volume.
            ph_change = abs(self.current_ph - self.prev_measured_ph) if self.prev_measured_ph is not None else 0.0
            bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
            uncertainties = [prior['P the'].std() for prior in self.priors]
            avg_uncertainty = np.mean(uncertainties)
            max_uncertainty = 1.0
            uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
            buffer_mean = np.mean(self.buffer_total_moles)
            ref_buffer = 0.5
            buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
            buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
            alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
            required_vol = self.compute_required_volume()
            combined_value = current_error + 0.1 * required_vol
            max_vol = max(self.addition_volumes)
            ideal_volume = self.min_addition_volume + (max_vol - self.min_addition_volume) * np.tanh(alpha * combined_value)

            error_reward = -current_error
            improvement = previous_error - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty + penalty
            reward = round(reward, 2)

            # Proportional direction penalty. Note that this one is judged
            # against the pH measured *after* the dose (last_measured_ph was
            # refreshed by update_exp_ph above).
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and is_acid:
                pen = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= pen
            if self.target_ph < current_for_direction and is_base:
                pen = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= pen
            reward = round(reward, 2)

            # Overshoot bookkeeping (steps_taken is at least 1 here).
            if is_acid:
                reagent_conc = self.reagents[reagent]
                last_added = self.last_acid_added
            elif is_base:
                reagent_conc = self.reagents[reagent]
                last_added = self.last_base_added
            else:
                reagent_conc = 1.0
                last_added = 0.0
            overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                               self.target_ph, reagent,
                                                               last_added, reagent_conc,
                                                               self.min_addition_volume)
            if overshoot_flag:
                self.overshoot_occurred = True
                self.overshoot_reagent = reagent
                if new_thresh is not None:
                    # Only ever tighten the dose cap.
                    if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                        self.overshoot_threshold = new_thresh

            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.error("An exception occurs when executing step:%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Decide whether the last dose overshot and propose a new dose cap.

        An overshoot is reported when the pH crossed the target or ended up
        further from it than before. The proposed cap is half of the dosed
        volume (moles * 1000 / concentration), but never below ``min_addition``.

        Returns:
            ``(overshoot, new_threshold)``; ``(False, None)`` when there is no
            overshoot. ``reagent`` is accepted but not used.
        """
        gap_before = prev_ph - target_ph
        gap_after = current_ph - target_ph
        crossed_target = gap_before * gap_after < 0
        moved_away = abs(gap_after) > abs(gap_before)
        if not (crossed_target or moved_away):
            return False, None
        dosed_volume = last_added_moles * 1000.0 / reagent_conc
        return True, max(dosed_volume / 2, min_addition)

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of this environment.

        Arrays, the reagent table, the action lists and the per-buffer prior
        dictionaries are duplicated so that changes to the copy do not leak
        back into this instance. The frozen distributions inside the priors
        are shared; they are only ever replaced, not modified.
        """
        env_copied = PHAdjustmentEnv()

        # Dosing state.
        env_copied.total_volume = self.total_volume
        env_copied.previous_total_volume = self.previous_total_volume
        env_copied.acid_added_moles = self.acid_added_moles
        env_copied.base_added_moles = self.base_added_moles
        env_copied.acid_volume = self.acid_volume
        env_copied.base_volume = self.base_volume
        env_copied.last_acid_added = self.last_acid_added
        env_copied.last_base_added = self.last_base_added
        env_copied.steps_taken = self.steps_taken
        env_copied.done = self.done
        env_copied.max_steps = self.max_steps
        env_copied.last_action = self.last_action

        # pH history and target.
        env_copied.initial_ph = self.initial_ph
        env_copied.current_ph = self.current_ph
        env_copied.previous_ph = self.previous_ph
        env_copied.target_ph = self.target_ph
        env_copied.last_measured_ph = self.last_measured_ph
        env_copied.prev_measured_ph = self.prev_measured_ph

        # Buffer model.
        env_copied.num_buffers = self.num_buffers
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        # Copy each prior dict: a plain list copy would still share the dicts,
        # so updating the copy's posteriors would rewrite this instance's too.
        env_copied.priors = [dict(prior) for prior in self.priors]
        # ``acid_type`` is only assigned by ``initialize``; tolerate its absence.
        env_copied.acid_type = getattr(self, 'acid_type', None)

        # Configuration.
        env_copied.epsilon = self.epsilon
        env_copied.direction_penalty_factor = self.direction_penalty_factor
        env_copied.tol = self.tol
        env_copied.reagents = self.reagents.copy()
        env_copied.min_addition_volume = self.min_addition_volume
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.vol_ideal_factor = self.vol_ideal_factor
        env_copied.ph_rate_threshold = self.ph_rate_threshold
        env_copied.ph_rate_bonus_factor = self.ph_rate_bonus_factor

        # Overshoot / oscillation tracking.
        env_copied.overshoot_threshold = self.overshoot_threshold
        env_copied.overshoot_occurred = self.overshoot_occurred
        env_copied.overshoot_reagent = self.overshoot_reagent
        env_copied.oscillation_count = self.oscillation_count
        env_copied.use_secondary_reagents = self.use_secondary_reagents
        return env_copied

    def select_best_action(self) -> tuple:
        """Pick the next ``(reagent, volume)`` and return it with the done flag.

        Reagent choice: after an overshoot the opposite kind of the
        overshooting reagent is used; otherwise base when the last measured pH
        is below the target and acid when it is not. The primary ("1") or
        secondary ("2") reagent is selected by ``use_secondary_reagents``. In
        primary mode the overshoot flag is consumed by this call.

        Volume choice: candidates are capped by ``overshoot_threshold`` when at
        least one candidate fits under it, and the one closest to the
        heuristic ideal volume wins.

        Returns:
            ``(best_action, self.done)``.

        Raises:
            ValueError: If no action matches the selected reagent.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph

        # Names are matched case-insensitively: the reagent name is
        # lower-cased, so the search text has to be lower case as well. The
        # former capitalised search text never matched and left the candidate
        # list empty.
        tier = '2' if self.use_secondary_reagents else '1'
        if self.overshoot_occurred:
            kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            needle = kind + ' ' + tier
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            kind = 'base' if current_for_direction < self.target_ph else 'acid'
            needle = 'dilute ' + kind + ' ' + tier
        allowed_reagent = {r for r in self.reagents if needle in r.lower()}

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if not candidate_actions:
            raise ValueError("No action available for a reagent matching %r" % needle)
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped

        # Heuristic ideal dose: grows with the pH error and the estimated
        # required volume, saturating through tanh.
        error = abs(current_for_direction - self.target_ph)
        ph_change = abs(current_for_direction - (self.prev_measured_ph if self.prev_measured_ph is not None else current_for_direction))
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['P the'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
        buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one (pKa, total moles) sample from every buffer's priors.

        Returns:
            ``(sampled_pKa, sampled_total_moles)``, two lists with one entry
            per prior.
        """
        # Each tuple is evaluated left to right, so for every prior the pKa
        # is drawn before the total moles.
        paired_draws = [(prior['P the'].rvs(), prior['Total moles'].rvs())
                        for prior in self.priors]
        sampled_pKa = [pka for pka, _moles in paired_draws]
        sampled_total_moles = [moles for _pka, moles in paired_draws]
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH after ``action`` for one sampled set of pKa values.

        The dose is applied on top of the environment's *current* cumulative
        additions without modifying the environment.

        Args:
            action: ``(reagent_name, volume_ml)``.
            sampled_pKa: pKa values to use instead of the current estimates.
            sampled_total_moles: Accepted for interface compatibility; the pH
                computation does not use it.

        Returns:
            The predicted pH (not rounded).
        """
        reagent, volume = action
        volume = float(volume)
        added_moles = self.reagents[reagent] * (volume / 1000.0)

        # Work on local copies of the four totals the calculation needs.
        # Cloning the whole environment for this was expensive (it rebuilds
        # the action space and priors) and consumed random numbers.
        acid_moles = self.acid_added_moles
        base_moles = self.base_added_moles
        acid_volume = self.acid_volume
        base_volume = self.base_volume

        # Classify on the lower-cased name with lower-case keywords; the
        # previous capitalised keywords never matched, so the dose was ignored.
        reagent_kind = reagent.lower()
        if 'acid' in reagent_kind:
            acid_moles += added_moles
            acid_volume += volume
        elif 'base' in reagent_kind:
            base_moles += added_moles
            base_volume += volume

        V_total = (TITRATED_VOLUME + acid_volume + base_volume) / 1000.0
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        c_A = n_analyte / V_total
        c_Na = base_moles / V_total
        c_HCl = acid_moles / V_total
        return solve_pH(c_A, c_Na, c_HCl, np.array(sampled_pKa))

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit the per-buffer priors from an observed pH with 1000 particles.

        Each particle is a draw from the current priors, weighted by the normal
        likelihood (scale 0.01) of ``observed_ph`` around the pH predicted for
        that draw. Particles are resampled by weight, and the mean and standard
        deviation (plus 1e-3) of the resampled values become the new priors and
        the new point estimates.
        """
        particle_count = 1000
        draws = []
        likelihoods = []
        for _ in range(particle_count):
            pka_draw, moles_draw = self.sample_parameters()
            ph_hat = self.predict_ph(action, pka_draw, moles_draw)
            likelihoods.append(norm.pdf(observed_ph, loc=ph_hat, scale=0.01))
            draws.append((pka_draw, moles_draw))

        # A small floor keeps the normalisation defined when every likelihood underflows.
        probabilities = np.array(likelihoods) + 1e-10
        probabilities /= np.sum(probabilities)
        picked = np.random.choice(range(particle_count), size=particle_count, p=probabilities)

        # First gather all statistics, then write them back.
        fitted = []
        for buffer_idx in range(self.num_buffers):
            pka_values = np.array([draws[j][0][buffer_idx] for j in picked])
            moles_values = np.array([draws[j][1][buffer_idx] for j in picked])
            fitted.append((
                np.mean(pka_values),
                np.std(pka_values) + 1e-3,
                np.mean(moles_values),
                np.std(moles_values) + 1e-3,
            ))

        for buffer_idx, (pka_mu, pka_sigma, moles_mu, moles_sigma) in enumerate(fitted):
            self.priors[buffer_idx]['P the'] = norm(loc=pka_mu, scale=pka_sigma)
            self.priors[buffer_idx]['Total moles'] = norm(loc=moles_mu, scale=moles_sigma)
            self.pKa_list[buffer_idx] = pka_mu
            self.buffer_total_moles[buffer_idx] = moles_mu
            self.pKa_std[buffer_idx] = pka_sigma

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Advance one step with ``action`` and propose the following action.

        When ``observed_ph`` is already within 0.1 of the target the episode is
        marked done and ``(None, True)`` is returned. Otherwise the action is
        simulated with ``step``, the posteriors are updated against the
        simulated pH, and ``(next_action, done)`` is returned.
        """
        # NOTE(review): the posteriors are fitted to the *simulated* pH, not to
        # ``observed_ph``, and ``update_posteriors`` runs after ``step`` has
        # already committed the dose while ``predict_ph`` is also handed the
        # same action -- confirm that this ordering is intended.
        on_target = abs(observed_ph - self.target_ph) < 0.1
        if on_target:
            self.done = True
            return None, True
        step_outcome = self.step(action)
        simulated_ph = step_outcome[0]
        finished = step_outcome[2]
        self.update_posteriors(action, simulated_ph)
        follow_up = self.select_best_action()[0]
        return follow_up, finished

# -------------------------------
# Single experiment generation function
# -------------------------------
def generate_single_experiment(acid_type: str) -> dict:
    """Run one simulated titration episode and return its full record.

    pKa values are drawn for the requested acid type (a single pKa of 4.0 for
    unknown types), a target pH is drawn from [2, 11], and the environment's
    own heuristic policy is followed until the episode ends.

    Returns:
        A dict with the acid type, pKa list, target and initial pH, number of
        steps, a success flag (final pH within 0.1 of the target inside the
        step budget) and the list of recorded transitions.
    """
    # Samplers are callables so that random numbers are only drawn for the
    # branch that is actually selected.
    pka_samplers = {
        'Mono': lambda: [np.random.uniform(1, 5)],
        'From': lambda: [np.random.uniform(1, 4), np.random.uniform(4, 7)],
        'Tri': lambda: [np.random.uniform(1, 3), np.random.uniform(3, 5), np.random.uniform(5, 7)],
    }
    draw_pka = pka_samplers.get(acid_type, lambda: [4.0])
    pKa_list = [round(value, 2) for value in draw_pka()]

    target_ph = round(np.random.uniform(2, 11), 2)
    init_ph = calculate_pH_custom(0, 0, 0, 0, pKa_list=pKa_list)

    env = PHAdjustmentEnv()
    env.initialize(init_pH=init_ph, target_pH=target_ph, max_steps=MAX_STEPS,
                   acid_type=acid_type, pKa_list=pKa_list, initial_volume=TITRATED_VOLUME)

    history = []
    observation = env.get_state()
    action, finished = env.select_best_action()
    while not env.done:
        _ph, reward, finished, _info = env.step(action)
        following = env.get_state()
        history.append({
            'State': observation,
            'Action': action,
            'Reward': reward,
            'Next state': following,
            'Done': finished
        })
        observation = following
        if finished:
            break
        action, finished = env.select_best_action()

    within_budget = env.steps_taken <= MAX_STEPS
    on_target = abs(env.current_ph - target_ph) < 0.1
    return {
        'Acid type': acid_type,
        'P is a list': pKa_list,
        'Target ph': target_ph,
        'Initial ph': init_ph,
        'Steps taken': env.steps_taken,
        'Success': (within_budget and on_target),
        'Transitions': history
    }

# -------------------------------
# Helper functions: convert states and actions into numeric vectors
# -------------------------------
def convert_state(state: dict) -> list:
    """Flatten a state dictionary into a fixed-order list of nine values.

    Order: pH, target pH, acid volume, base volume, total volume, steps taken,
    error, pH delta, last added volume. Missing keys become 0; for the error
    and the pH delta an explicit ``None`` is replaced by 0 as well.
    """
    def _zero_if_none(key):
        # Missing and None are treated alike for these two fields.
        value = state.get(key)
        return 0 if value is None else value

    leading_keys = ('P h', 'Target ph', 'Acid volume', 'Base volume',
                    'Total volume', 'Steps taken')
    vector = [state.get(key, 0) for key in leading_keys]
    vector.append(_zero_if_none('Error'))
    vector.append(_zero_if_none('Ph delta'))
    vector.append(state.get('Last added volume', 0))
    return vector

def convert_action(action: tuple) -> list:
    """Encode ``(reagent_name, volume)`` as ``[reagent_index, volume]``.

    The index is looked up in the module-level ``reagent_mapping``; names
    that are not in the mapping are encoded as -1.
    """
    name, amount = action
    return [reagent_mapping.get(name, -1), amount]

# -------------------------------
# Main function: generate successful experiments, then aggregate, split and save the converted transition data
# -------------------------------
def main():
    """Generate successful experiments and save them as train/validation/test JSON.

    Experiments are generated with random acid types until eight succeed.
    Their transitions are shuffled, converted to numeric vectors, split
    70 % / 15 % / remainder, and written to three JSON files in the current
    directory.
    """
    # Imported locally so the function does not depend on the enclosing
    # cell having imported these modules.
    import json
    import random

    desired_success = 8
    successful_experiments = []
    acid_types = ['Mono', 'From', 'Tri']
    total_generated = 0
    while len(successful_experiments) < desired_success:
        acid_type = random.choice(acid_types)
        experiment = generate_single_experiment(acid_type)
        total_generated += 1
        if experiment['Success']:
            successful_experiments.append(experiment)
        if total_generated % 100 == 0:
            logging.info("Generate experiments %d times, number of successful experiments:%d", total_generated, len(successful_experiments))
    logging.info("The successful experiment is generated, and a total of experiments are generated. %d Second-rate", total_generated)

    avg_steps = sum(exp['Steps taken'] for exp in successful_experiments) / len(successful_experiments)
    logging.info("Average number of steps for a successful experiment:%.2f", avg_steps)

    all_transitions = [trans for exp in successful_experiments for trans in exp['Transitions']]
    total_samples = len(all_transitions)
    indices = list(range(total_samples))
    random.shuffle(indices)
    observations = [convert_state(all_transitions[i]['State']) for i in indices]
    actions = [convert_action(all_transitions[i]['Action']) for i in indices]
    rewards = [all_transitions[i]['Reward'] for i in indices]
    train_end = int(0.7 * total_samples)
    valid_end = train_end + int(0.15 * total_samples)

    def subset(start, stop):
        # One consistent slice across the three parallel lists.
        return {
            'Observations': observations[start:stop],
            'Actions': actions[start:stop],
            'Rewards': rewards[start:stop],
        }

    train_set = subset(None, train_end)
    valid_set = subset(train_end, valid_end)
    test_set = subset(valid_end, None)

    # The mode must be lower-case 'w'; open() rejects 'W' with ValueError.
    outputs = (
        ('Train set big new test.json', train_set),
        ('Validation set big new test.json', valid_set),
        ('Test set big new test.json', test_set),
    )
    for filename, payload in outputs:
        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(payload, f, indent=2)
    logging.info("The data set is divided: training set %d strips, validation set %d strips, test set %d strip",
                 len(train_set['Observations']), len(valid_set['Observations']), len(test_set['Observations']))

# Run the generator when this code is executed as the top-level script or
# notebook cell. The module name to compare against is '__main__'; the
# previous value 'Main' never matched, so main() was never called.
if __name__ == '__main__':
    main()

In [None]:
# Train a discrete regression model

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import json
from torch.utils.data import Dataset, DataLoader
import numpy as np

# Fixed random seeds to ensure reproducible results
np.random.seed(42)
torch.manual_seed(42)

# -------------------------------
# global parameters
# -------------------------------
INPUT_DIM = 5    # Features: current pH, target pH, pH change, error (current pH - target pH) and volume added in the previous step
HIDDEN_DIM1 = 256
HIDDEN_DIM2 = 256
BATCH_SIZE = 64
NUM_EPOCHS = 80
LEARNING_RATE = 1e-3

# Discrete action space parameters: Volume range [0.01, 10.00] mL, step size 0.01 mL
MIN_VOLUME = 0.01
MAX_VOLUME = 10.0
STEP = 0.01
# Round before truncating: a float quotient such as 998.9999999999999 would
# otherwise silently drop one action when the range or step is changed.
NUM_ACTIONS = int(round((MAX_VOLUME - MIN_VOLUME) / STEP)) + 1  # 1000 discrete actions

# -------------------------------
# Dataset: Convert continuous labels to discrete categories
# -------------------------------
class VolumePredictionDataset(Dataset):
    """Dataset of (feature vector, discrete volume class) pairs.

    Built from a dict with ``'Observations'`` and ``'Actions'`` lists. Only
    samples whose action category (first action column) is 0 or 2 are kept.
    """

    def __init__(self, dataset):
        observations = np.array(dataset['Observations'])
        actions = np.array(dataset['Actions'])

        # Keep only the samples with action category 0 or 2.
        keep = np.isin(actions[:, 0], [0, 2])
        observations, actions = observations[keep], actions[keep]

        # Features: current pH (column 0), target pH (column 1), pH change
        # (column 7), error recomputed as current - target, and the volume
        # added in the previous step (column 8).
        ph_now = observations[:, 0]
        ph_goal = observations[:, 1]
        feature_columns = [
            ph_now,
            ph_goal,
            observations[:, 7],
            ph_now - ph_goal,
            observations[:, 8],
        ]
        self.inputs = torch.tensor(np.stack(feature_columns, axis=1), dtype=torch.float32)

        # Labels: the action volume (second column) mapped onto its class
        # index, round((volume - MIN_VOLUME) / STEP).
        volume_steps = (actions[:, 1] - MIN_VOLUME) / STEP
        class_indices = np.rint(volume_steps).astype(np.int64)
        self.labels = torch.tensor(class_indices, dtype=torch.long)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        sample = self.inputs[idx]
        target = self.labels[idx]
        return sample, target

# -------------------------------
# Discrete action strategy model: output NUM_ACTIONS logits, corresponding to discrete volumes of 0.01~10.00 mL
# -------------------------------
class DiscreteVolumeRegressor(nn.Module):
    """Two-hidden-layer MLP that scores every discrete dosing volume.

    ``forward`` returns ``num_actions`` logits per input row; index ``k``
    corresponds to the volume ``round(MIN_VOLUME + k * STEP, 2)``.
    """

    def __init__(self, input_dim=INPUT_DIM, num_actions=NUM_ACTIONS):
        super().__init__()
        self.num_actions = num_actions
        layers = [
            nn.Linear(input_dim, HIDDEN_DIM1),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM1, HIDDEN_DIM2),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM2, num_actions),
        ]
        self.net = nn.Sequential(*layers)
        # Lookup table: class index -> volume in mL.
        self.discrete_volumes = [round(MIN_VOLUME + k * STEP, 2) for k in range(num_actions)]

    def forward(self, x):
        return self.net(x)

    def predict_volume(self, x):
        """Greedy inference: return the highest-scoring volume per row, shape (batch, 1)."""
        scores = self.forward(x)
        best = torch.max(scores, dim=1)[1]
        volumes = [self.discrete_volumes[k] for k in best.tolist()]
        return torch.tensor(volumes, dtype=torch.float32).unsqueeze(1)

    def sample_action(self, x):
        """Sample one volume from the categorical policy (single-row input).

        Returns the sampled volume as a (1, 1) float32 tensor together with
        the log-probability of the sampled class.
        """
        policy = torch.distributions.Categorical(logits=self.forward(x))
        chosen = policy.sample()
        chosen_log_prob = policy.log_prob(chosen)
        # .item() requires exactly one sampled index, i.e. a batch of one.
        sampled_volume = self.discrete_volumes[chosen.item()]
        return torch.tensor([[sampled_volume]], dtype=torch.float32), chosen_log_prob

# -------------------------------
# Tool function: load JSON data
# -------------------------------
def load_json_file(filename):
    """Read a JSON document from ``filename`` and return the parsed object.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The deserialized Python object (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # open() mode strings are case-sensitive; the former 'R' raised
    # "ValueError: invalid mode". The encoding is pinned so the result does
    # not depend on the platform default.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

# -------------------------------
# Utility function: Evaluate the model on the data loader (compute cross-entropy loss and accuracy)
# -------------------------------
def evaluate(model, dataloader, criterion):
    """Return ``(average loss, accuracy)`` of ``model`` over ``dataloader``.

    The model is switched to eval mode and gradients are disabled. The loss
    is weighted by batch size so the average is per sample; accuracy is the
    fraction of rows whose arg-max logit equals the label.
    """
    model.eval()
    loss_sum, hits, seen = 0.0, 0, 0
    with torch.no_grad():
        for batch_x, batch_y in dataloader:
            n_batch = batch_x.size(0)
            scores = model(batch_x)
            # criterion returns a batch mean; undo it to accumulate a sum.
            loss_sum += criterion(scores, batch_y).item() * n_batch
            guesses = torch.max(scores, dim=1)[1]
            hits += (guesses == batch_y).sum().item()
            seen += n_batch
    mean_loss = loss_sum / seen
    return mean_loss, hits / seen

# -------------------------------
# Main training process: load 'Train set big new1.json', 'Validation set big new1.json' and 'Test set big new1.json' for training, validation and testing
# -------------------------------
def main():
    """Train the discrete volume classifier, keep the best checkpoint, then test it.

    Loads the three JSON splits, trains for NUM_EPOCHS epochs with Adam and
    cross-entropy, saves the weights whenever the validation loss improves,
    reloads that checkpoint for the test-set evaluation and finally prints
    the prediction for one hand-written example input.
    """
    checkpoint_path = "Volume regressor best big discrete new1 test.pth"

    # Load dataset (all three files are read before any dataset is built)
    train_data, val_data, test_data = (
        load_json_file(name)
        for name in (
            'Train set big new1.json',
            'Validation set big new1.json',
            'Test set big new1.json',
        )
    )

    train_dataset = VolumePredictionDataset(train_data)
    val_dataset = VolumePredictionDataset(val_data)
    test_dataset = VolumePredictionDataset(test_data)

    # Only the training split is shuffled.
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    # Model, optimizer and loss function (cross-entropy over volume classes)
    model = DiscreteVolumeRegressor(INPUT_DIM, NUM_ACTIONS)
    optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.CrossEntropyLoss()

    lowest_val_loss = float('Inf')
    for epoch_idx in range(NUM_EPOCHS):
        model.train()
        running_loss = 0.0
        for batch_x, batch_y in train_loader:
            batch_loss = criterion(model(batch_x), batch_y)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            # Weight by batch size so the epoch average is per sample.
            running_loss += batch_loss.item() * batch_x.size(0)
        avg_train_loss = running_loss / len(train_dataset)
        val_loss, val_acc = evaluate(model, val_loader, criterion)
        print(f"epoch {epoch_idx+1}/{NUM_EPOCHS}, Train Loss: {avg_train_loss:.4f}, Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")
        if val_loss < lowest_val_loss:
            lowest_val_loss = val_loss
            torch.save(model.state_dict(), checkpoint_path)
            print("Saved best model with Val Loss: {:.4f}".format(val_loss))

    # Testing phase: evaluate the best checkpoint, not the last epoch.
    model.load_state_dict(torch.load(checkpoint_path))
    test_loss, test_acc = evaluate(model, test_loader, criterion)
    print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}")

    # Inference example: [current pH=9.0, target pH=2.0, pH change=-0.5, error=7.0, previous drop volume=0.05]
    example_input = [9.0, 2.0, -0.5, 7.0, 0.05]
    example_batch = torch.tensor(example_input, dtype=torch.float32).unsqueeze(0)
    model.eval()
    with torch.no_grad():
        predicted_volume = model.predict_volume(example_batch).item()
    print("for input", example_input, "The predicted volume is:", predicted_volume)

# Run the training pipeline when this code is executed as a script or as a
# notebook cell. Python sets __name__ to exactly '__main__' in that case; the
# former comparison against 'Main' never matched, so main() was never called.
if __name__ == '__main__':
    main()


In [None]:
# Test a discrete regression model on the test set

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import json
import numpy as np

# -------------------------------
# Data set class, consistent with training
# -------------------------------
class VolumePredictionDataset(Dataset):
    """Evaluation dataset: titration features paired with the dosed volume.

    ``dataset`` must provide the keys ``'Observations'`` and ``'Actions'``.
    Only rows whose action category (column 0 of ``Actions``) is 0 or 2 are
    kept. Each sample is a 5-feature float32 tensor
    ``[current pH, target pH, pH change, error, last added volume]`` and a
    float32 label of shape ``(1,)`` holding the action volume.
    """

    def __init__(self, dataset):
        observations = np.array(dataset['Observations'])
        actions = np.array(dataset['Actions'])

        # Drop every row whose action category is not 0 or 2.
        keep = np.isin(actions[:, 0], [0, 2])
        observations, actions = observations[keep], actions[keep]

        # Observation columns: 0 = current pH, 1 = target pH, 7 = pH change,
        # 8 = volume added in the previous step; the error is derived here.
        ph_now = observations[:, 0]
        ph_goal = observations[:, 1]
        feature_columns = [
            ph_now,
            ph_goal,
            observations[:, 7],
            ph_now - ph_goal,
            observations[:, 8],
        ]
        self.inputs = torch.tensor(np.stack(feature_columns, axis=1), dtype=torch.float32)

        # Regression target: the volume stored in column 1 of the action.
        self.labels = torch.tensor(actions[:, 1], dtype=torch.float32).unsqueeze(1)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.labels[idx]

# -------------------------------
# discrete action strategy model
# -------------------------------
INPUT_DIM = 5        # features per sample: pH, target pH, pH change, error, last volume
HIDDEN_DIM1 = 256    # width of the first hidden layer
HIDDEN_DIM2 = 256    # width of the second hidden layer

class DiscreteVolumeRegressor(nn.Module):
    """Two-hidden-layer MLP that scores every discrete dosing volume.

    The action grid runs from ``min_volume`` to ``max_volume`` in increments
    of ``step`` (each value rounded to 2 decimals). ``forward`` returns one
    logit per grid point.

    Args:
        input_dim: Number of input features.
        min_volume: Smallest volume on the grid (mL).
        max_volume: Largest volume on the grid (mL).
        step: Grid spacing (mL).
    """

    def __init__(self, input_dim=INPUT_DIM, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # Generate discrete action list: 0.01 ~ 10.00 mL by default.
        # The small epsilon keeps a quotient such as 0.3 / 0.1 == 2.999...96
        # from being truncated one step short, which would drop max_volume.
        num_steps = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2)
                                 for i in range(num_steps + 1)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, HIDDEN_DIM1),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM1, HIDDEN_DIM2),
            nn.ReLU(),
            nn.Linear(HIDDEN_DIM2, self.num_actions)
        )

    def forward(self, x):
        return self.net(x)

    def predict_volume(self, x):
        """Greedy inference: map each row's arg-max class to its volume.

        Accepts any batch size and returns a float32 tensor of shape
        ``(batch, 1)``. The previous implementation called ``.item()`` on the
        index tensor and therefore failed for batches larger than one.
        """
        logits = self.forward(x)
        _, predicted_indices = torch.max(logits, dim=1)
        predicted_volumes = [self.discrete_volumes[i] for i in predicted_indices.tolist()]
        return torch.tensor(predicted_volumes, dtype=torch.float32).unsqueeze(1)

# -------------------------------
# Tool function: load JSON data
# -------------------------------
def load_json_file(filename):
    """Read a JSON document from ``filename`` and return the parsed object.

    Args:
        filename: Path of the JSON file to read.

    Returns:
        The deserialized Python object (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the content is not valid JSON.
    """
    # open() mode strings are case-sensitive; the former 'R' raised
    # "ValueError: invalid mode". The encoding is pinned so the result does
    # not depend on the platform default.
    with open(filename, 'r', encoding='utf-8') as f:
        return json.load(f)

# -------------------------------
# Test set evaluation function: Calculate MSE, MAE and R²
# -------------------------------
def evaluate_model(model, dataloader):
    """Print MSE, MAE and R² of the model's predicted volumes on ``dataloader``.

    Predictions come from ``model.predict_volume`` and are requested one
    sample at a time. MSE is the batch-size-weighted mean of per-batch MSE;
    MAE and R² are computed once over all samples. Nothing is returned.
    """
    squared_error = nn.MSELoss()
    absolute_error = nn.L1Loss()

    weighted_mse_sum = 0.0
    n_seen = 0
    pred_chunks, label_chunks = [], []

    model.eval()
    with torch.no_grad():
        for features, targets in dataloader:
            n_rows = features.size(0)
            # Feed each row as its own batch of one, then stitch the
            # per-row predictions back into a (n_rows, 1) tensor.
            per_row = [model.predict_volume(features[row].unsqueeze(0)) for row in range(n_rows)]
            preds = torch.cat(per_row, dim=0)

            weighted_mse_sum += squared_error(preds, targets).item() * n_rows
            n_seen += n_rows
            pred_chunks.append(preds)
            label_chunks.append(targets)

    mean_mse = weighted_mse_sum / n_seen

    predictions = torch.cat(pred_chunks, dim=0)
    truths = torch.cat(label_chunks, dim=0)
    mean_mae = absolute_error(predictions, truths).item()

    # R² = 1 - SS_res / SS_tot
    residual_ss = torch.sum((truths - predictions) ** 2)
    total_ss = torch.sum((truths - torch.mean(truths)) ** 2)
    r_squared = 1 - residual_ss / total_ss

    print("Test MSE Loss: {:.4f}".format(mean_mse))
    print("Test MAE Loss: {:.4f}".format(mean_mae))
    print("Test R² Score: {:.4f}".format(r_squared.item()))

# -------------------------------
# Main test process
# -------------------------------
# Python sets __name__ to exactly '__main__' for a script or notebook cell;
# the former comparison against 'Main' never matched, so nothing below ran.
if __name__ == '__main__':
    # Load test set data
    test_data = load_json_file('Test set big new1.json')
    test_dataset = VolumePredictionDataset(test_data)
    test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

    # Initialize the discrete model and load the pre-trained model state.
    # Device type strings are lowercase: torch.device('Cpu') is rejected.
    model = DiscreteVolumeRegressor(INPUT_DIM, min_volume=0.01, max_volume=10.0, step=0.01)
    model.load_state_dict(torch.load("Volume regressor best big discrete new1 test.pth", map_location=torch.device('cpu')))
    model.eval()

    # Evaluate the model on the test set
    evaluate_model(model, test_loader)


In [None]:
# Reinforcement learning (REINFORCE) fine-tuning: the version currently in use

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
from scipy.optimize import fsolve
import json

##############################################
# Fixed random seeds to ensure repeatable experiments
##############################################
# One seed is applied to every random number generator this cell uses
# (Python's `random`, NumPy and PyTorch, plus all CUDA devices if present).
seed = 255
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and turn off the cuDNN auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

##############################################
# Global constants shared by the pH helpers and the environment below
##############################################
TITRANT_CONC = 0.1          # Titrant concentration (0.1 M), used for both the added base and the added acid
MAX_STEPS = 50              # Maximum number of steps per episode
INITIAL_ACID_VOL = 11.0     # Initial volume of weak acid to be titrated (mL)
SUCCESS_THRESHOLD = 0.1     # pH error threshold below which an episode counts as finished

##############################################
# pH calculation functions: monoprotic acid
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual of a monoprotic weak acid solution at a trial pH.

    Returns ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` where ``alpha`` is
    the dissociated fraction of the acid; the solution's pH is the root.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration of cation from the added strong base.
        c_HCl: Concentration of anion from the added strong acid.
        pKa: Dissociation constant of the weak acid.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc  # Kw / [H+]
    ratio = 10 ** (pH - pKa)  # [A-] / [HA]
    dissociated = ratio / (1 + ratio)
    return h_conc + c_Na - oh_conc - c_A * dissociated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Find the pH in [0, 14] at which the monoprotic charge balance is zero.

    Plain bisection on ``f_monoprotic``: at most 100 halvings, stopping early
    once the residual magnitude drops below 1e-10.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration of cation from the added strong base.
        c_HCl: Concentration of anion from the added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The midpoint at which the residual test passed, or the midpoint of
        the final bracket when the iteration budget is exhausted.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so evaluate it once and carry it
    # along instead of recomputing it on every iteration (same results,
    # half the function evaluations).
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change between lo and mid: the root is in the lower half.
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float) -> float:
    """pH (rounded to 2 decimals) of the monoprotic analyte after dosing.

    The analyte is INITIAL_ACID_VOL mL of 0.1 M weak acid; both the added
    base and the added acid have concentration TITRANT_CONC. Amounts are
    converted to concentrations in the combined volume before solving.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa: Dissociation constant of the weak acid.
    """
    analyte_vol_mL = INITIAL_ACID_VOL
    moles_acid = analyte_vol_mL / 1000.0 * 0.1
    moles_base = base_added_mL / 1000.0 * TITRANT_CONC
    moles_hcl = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (analyte_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    ph = solve_pH_monoprotic_balance(
        moles_acid / total_L,
        moles_base / total_L,
        moles_hcl / total_L,
        pKa,
    )
    return round(ph, 2)

##############################################
# pH Calculation Function: Dibasic Acid
##############################################
def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual of a diprotic weak acid solution at a trial pH.

    Returns ``[H+] + c_Na - [OH-] - c_A * (alpha1 + 2 * alpha2) - c_HCl``,
    where alpha1 and alpha2 are the fractions of singly and doubly
    deprotonated acid. Exponents are clipped to [-100, 100] before the power
    is taken.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration of cation from the added strong base.
        c_HCl: Concentration of anion from the added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc  # Kw / [H+]
    exponents = (pH - pKa1, 2 * pH - pKa1 - pKa2)
    once, twice = (np.power(10, np.clip(e, -100, 100)) for e in exponents)
    denom = 1 + once + twice
    anion_charge = c_A * (once / denom + 2 * (twice / denom))
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Find the pH in [0, 14] at which the diprotic charge balance is zero.

    Plain bisection on ``f_diprotic``: at most 100 halvings, stopping early
    once the residual magnitude drops below 1e-10.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration of cation from the added strong base.
        c_HCl: Concentration of anion from the added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        The midpoint at which the residual test passed, or the midpoint of
        the final bracket when the iteration budget is exhausted.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so evaluate it once and carry it
    # along instead of recomputing it on every iteration (same results,
    # half the function evaluations).
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change between lo and mid: the root is in the lower half.
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float) -> float:
    """pH (rounded to 2 decimals) of the diprotic analyte after dosing.

    The analyte is INITIAL_ACID_VOL mL of 0.1 M weak acid; both the added
    base and the added acid have concentration TITRANT_CONC. Amounts are
    converted to concentrations in the combined volume before solving.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
    """
    analyte_vol_mL = INITIAL_ACID_VOL
    moles_acid = analyte_vol_mL / 1000.0 * 0.1
    moles_base = base_added_mL / 1000.0 * TITRANT_CONC
    moles_hcl = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (analyte_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    ph = solve_pH_diprotic(
        moles_acid / total_L,
        moles_base / total_L,
        moles_hcl / total_L,
        pKa1,
        pKa2,
    )
    return round(ph, 2)

##############################################
# pH Calculation Function: Tribasic Acid
##############################################
def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual of a triprotic weak acid solution at a trial pH.

    Returns ``[H+] + c_Na - [OH-] - c_A * (alpha1 + 2*alpha2 + 3*alpha3) - c_HCl``
    where alpha_k is the fraction of acid that has lost k protons. Exponents
    are clipped to [-100, 100] before the power is taken.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration of cation from the added strong base.
        c_HCl: Concentration of anion from the added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.
    """
    h_conc = 10 ** (-pH)
    oh_conc = 1e-14 / h_conc  # Kw / [H+]
    exponents = (
        pH - pKa1,
        2 * pH - pKa1 - pKa2,
        3 * pH - pKa1 - pKa2 - pKa3,
    )
    once, twice, thrice = (np.power(10, np.clip(e, -100, 100)) for e in exponents)
    denom = 1 + once + twice + thrice
    anion_charge = c_A * (once / denom + 2 * (twice / denom) + 3 * (thrice / denom))
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Find the pH in [0, 14] at which the triprotic charge balance is zero.

    Plain bisection on ``f_triprotic``: at most 100 halvings, stopping early
    once the residual magnitude drops below 1e-10.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration of cation from the added strong base.
        c_HCl: Concentration of anion from the added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The midpoint at which the residual test passed, or the midpoint of
        the final bracket when the iteration budget is exhausted.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so evaluate it once and carry it
    # along instead of recomputing it on every iteration (same results,
    # half the function evaluations).
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change between lo and mid: the root is in the lower half.
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """pH (rounded to 2 decimals) of the triprotic analyte after dosing.

    The analyte is INITIAL_ACID_VOL mL of 0.1 M weak acid; both the added
    base and the added acid have concentration TITRANT_CONC. Amounts are
    converted to concentrations in the combined volume before solving.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.
    """
    analyte_vol_mL = INITIAL_ACID_VOL
    moles_acid = analyte_vol_mL / 1000.0 * 0.1
    moles_base = base_added_mL / 1000.0 * TITRANT_CONC
    moles_hcl = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (analyte_vol_mL + base_added_mL + acid_added_mL) / 1000.0
    ph = solve_pH_triprotic(
        moles_acid / total_L,
        moles_base / total_L,
        moles_hcl / total_L,
        pKa1,
        pKa2,
        pKa3,
    )
    return round(ph, 2)

##############################################
# Reward calculation function (modified version)
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag, prev_overshoot_volume, last_action_volume):
    """Compute the shaped reward for one titration step and whether the episode ends.

    The raw reward is the sum of:
      * a dense term proportional to the reduction in |pH - target|, scaled
        up while many steps remain;
      * a constant per-step penalty;
      * an overshoot penalty when the pH crossed the target in this step;
      * a wrong-direction penalty when the pH lies above the target after a
        base was dosed, or below it after an acid was dosed;
      * after an overshoot in the previous step, a penalty proportional to
        the current volume and a bonus if it is smaller than the
        overshooting volume.
    A terminal bonus is added when the episode ends; non-terminal rewards are
    clipped to the configured range.

    Args:
        previous_ph: pH before the action.
        current_ph: pH after the action.
        target_ph: pH the episode is aiming for.
        steps_taken: Number of steps executed so far, including this one.
        max_steps: Step budget of the episode.
        reagent: Name of the dosed reagent; matched case-insensitively
            against the substrings 'base' and 'acid'.
        reward_config: Dict of weights; missing keys fall back to defaults.
        SUCCESS_THRESHOLD: |pH - target| below which the episode succeeds.
        prev_overshoot_flag: Whether the previous step crossed the target.
        prev_overshoot_volume: Volume dosed in that overshooting step, or None.
        last_action_volume: Volume dosed in this step.

    Returns:
        Tuple ``(reward, is_terminal)``.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("Dense lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("Step penalty", -0.005)
    overshoot_weight = reward_config.get("Overshoot weight", 0.2)
    overshoot_threshold = reward_config.get("Overshoot threshold", 0.1)

    # If overshoot occurs (pH changed sides of the target), apply a penalty
    # that grows smoothly (sigmoid) with the overshoot magnitude.
    if (previous_ph - target_ph) * (current_ph - target_ph) < 0 and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(- (overshoot_magnitude - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("Wrong dir factor", 1.0)
    wrong_dir_penalty = 0
    # The reagent name is lower-cased, so the needles must be lower-case too.
    # The former 'Base' / 'Acid' could never match and silently disabled
    # this penalty.
    reagent_name = reagent.lower()
    if (current_ph > target_ph and 'base' in reagent_name) or (current_ph < target_ph and 'acid' in reagent_name):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph)

    # If an overshoot occurred in the previous step, penalise the current
    # action volume; if the current volume is smaller than the overshooting
    # one, grant a positive bonus proportional to the reduction.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None:
        overshoot_volume_penalty = reward_config.get("Overshoot volume penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume
        overshoot_volume_bonus = reward_config.get("Overshoot volume bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    # NOTE(review): the terminal bonus is also granted when the episode ends
    # only because the step budget ran out -- confirm that this is intended.
    is_terminal = False
    if abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps:
        is_terminal = True
        # The bonus is doubled when finishing within the first half of the budget.
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("Terminal bonus", 3.0) * bonus_factor
        raw_reward += terminal_bonus

    # Clip rewards in non-terminal states only
    if not is_terminal:
        reward_clip_max = reward_config.get("Reward clip max", 4.0)
        reward_clip_min = reward_config.get("Reward clip min", -4.0)
        reward = max(min(raw_reward, reward_clip_max), reward_clip_min)
    else:
        reward = raw_reward

    return reward, is_terminal

##############################################
# pH simulation environment: PHSimEnv (modified version)
##############################################
class PHSimEnv:
    """Simulated acid/base titration environment used for REINFORCE training.

    Each episode draws a random weak acid (monoprotic, diprotic or triprotic,
    with pKa values taken from pools generated once in ``__init__``) and a
    random target pH. Every step doses a volume of strong base or strong
    acid, whichever moves the solution toward the target, and the new pH is
    obtained from the matching ``calculate_pH_*`` helper.

    NOTE(review): ``analyte_conc`` and ``titrant_conc`` are stored on the
    instance, but the ``calculate_pH_*`` helpers called here receive no
    concentration arguments, so non-default values do not reach the pH
    calculation -- confirm whether that is intended.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol  # mL
        self.analyte_conc = analyte_conc          # 0.1 M
        self.titrant_conc = titrant_conc          # 0.1 M
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reward_config = {
            "Dense lambda": -0.03,
            "Step penalty": 0,
            "Terminal bonus": 3.9, 
            "Overshoot weight": 0.2,
            "Overshoot threshold": 0.1,
            "Wrong dir factor": 1,
            "Reward clip max": 4.1,
            "Reward clip min": -4.1,
            "Overshoot volume penalty": 0.1,
            "Overshoot volume bonus": 0.1
        }
        # Randomly generate 30 sets of acid parameters for each acid type
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            self.diprotic_pKa_list.append((pKa1, pKa2))
        self.triprotic_pKa_list = []
        for _ in range(30):
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            self.triprotic_pKa_list.append((pKa1, pKa2, pKa3))
        self.reset()

    def reset(self):
        """Start a new episode and return its initial state vector."""
        # Randomly select acid type and parameters
        self.acid_type = random.choice(['Monoprotic', 'Diprotic', 'Triprotic'])
        if self.acid_type == 'Monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'Diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        else:  # Triprotic
            self.acid_params = random.choice(self.triprotic_pKa_list)
        self.target_ph = round(random.uniform(2, 11), 2)
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0
        # Overshoot bookkeeping carried from one step to the next
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        # Initialize current pH based on acid type (nothing dosed yet)
        if self.acid_type == 'Monoprotic':
            self.current_ph = calculate_pH_monoprotic(0.0, 0.0, pKa=self.acid_params)
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(0.0, 0.0, pKa1, pKa2)
        else:  # Triprotic
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(0.0, 0.0, pKa1, pKa2, pKa3)
        # Initially the previous pH equals the current pH (pH change of 0)
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return the float32 state vector of the current step."""
        pH_delta = round(self.current_ph - self.previous_ph, 2)
        error = round(self.current_ph - self.target_ph, 2)
        # State vector: current pH, target pH, pH change, error, last action volume
        return np.array([self.current_ph, self.target_ph, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def step(self, action):
        """Dose ``action`` mL of reagent and advance the simulation one step.

        The reagent is a strong base when the current pH is below the target
        and a strong acid otherwise.

        Returns:
            Tuple ``(state, reward, done, info)`` where ``info['Reagent']``
            names the reagent that was dosed.
        """
        volume = float(action)
        self.last_action_volume = volume
        self.steps += 1
        # Choose base or acid from the relationship between current and target pH
        if self.current_ph < self.target_ph:
            reagent = "Strong base"
            self.base_added_mL += volume
        else:
            reagent = "Strong acid"
            self.acid_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base_added_mL + self.acid_added_mL

        # Save current pH as previous state
        self.previous_ph = self.current_ph

        # Update the current pH (call the function matching the acid type)
        if self.acid_type == 'Monoprotic':
            self.current_ph = calculate_pH_monoprotic(self.base_added_mL, self.acid_added_mL, self.acid_params)
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2)
        else:
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2, pKa3)

        state = self._get_state()
        # The reward still sees the overshoot flags of the *previous* step;
        # they are refreshed only afterwards.
        reward, done = calculate_reward(
            previous_ph=self.previous_ph,
            current_ph=self.current_ph,
            target_ph=self.target_ph,
            steps_taken=self.steps,
            max_steps=MAX_STEPS,
            reagent=reagent,
            reward_config=self.reward_config,
            SUCCESS_THRESHOLD=SUCCESS_THRESHOLD,
            prev_overshoot_flag=self.prev_overshoot_flag,
            prev_overshoot_volume=self.prev_overshoot_volume,
            last_action_volume=self.last_action_volume
        )

        # Determine whether this step overshot (pH crossed the target from
        # one side). The target must be read from the instance: the bare
        # name `target_ph` used previously is undefined in this scope and
        # raised NameError on every step.
        current_overshoot = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
        if current_overshoot:
            self.prev_overshoot_flag = True
            self.prev_overshoot_volume = self.last_action_volume
        else:
            self.prev_overshoot_flag = False
            self.prev_overshoot_volume = None

        return state, reward, done, {'Reagent': reagent}

##############################################
# Discrete action strategy model: DiscreteVolumeRegressor
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy network over a grid of dosing volumes.

    The grid runs from ``min_volume`` in increments of ``step`` (values
    rounded to 2 decimals); the network is a 256-256 ReLU MLP emitting one
    logit per grid point.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        n_bins = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(n_bins)]
        self.num_actions = len(self.discrete_volumes)
        hidden = 256
        layers = [
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, self.num_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume for a single state.

        Returns the volume as a (1, 1) float32 tensor and the log-probability
        of the sampled class. A diagnostic line is printed when the logits
        contain NaN.
        """
        scores = self.forward(x)
        if torch.isnan(scores).any():
            print("Logits contain NaN:", scores)
        policy = torch.distributions.Categorical(logits=scores)
        chosen = policy.sample()
        chosen_log_prob = policy.log_prob(chosen)
        # .item() requires exactly one sampled index, i.e. a batch of one.
        sampled_volume = self.discrete_volumes[chosen.item()]
        return torch.tensor([[sampled_volume]], dtype=torch.float32), chosen_log_prob

    def predict_volume(self, x):
        """Greedy choice for a single state, returned as a (1, 1) float32 tensor."""
        best = torch.max(self.forward(x), dim=1)[1]
        greedy_volume = self.discrete_volumes[best.item()]
        return torch.tensor([[greedy_volume]], dtype=torch.float32)

##############################################
# Online training: updating policy model using REINFORCE algorithm
##############################################
def train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=0.99):
    """Fine-tune ``policy_model`` on ``env`` with the REINFORCE policy gradient.

    For every episode the policy is rolled out until the environment reports
    ``done``, discounted returns are computed and normalised, and a single
    gradient step (norm-clipped to 1.0) is taken on ``sum(-log_prob * return)``.
    Whenever an episode ends with a smaller |final pH - target pH| than any
    earlier episode, the current weights are written to disk.

    Args:
        env: Environment exposing ``reset()``, ``step(volume)``,
            ``current_ph`` and ``target_ph``.
        policy_model: Policy exposing ``sample_action(state_tensor)`` that
            returns ``(action tensor, log-probability)``.
        optimizer: Optimizer over ``policy_model``'s parameters.
        num_episodes: Number of episodes to run.
        gamma: Discount factor for the returns.

    Returns:
        None. Progress is printed and checkpoints are saved as side effects.
    """
    best_error = float('Inf')  # tracking best error
    best_model_state = None    # Store the best model state

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        log_probs = []
        rewards = []
        # Roll out one full episode, recording log-probabilities and rewards.
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action, log_prob = policy_model.sample_action(state_tensor)
            action_scalar = action.item()  # for passing in environment
            next_state, reward, done, _ = env.step(action_scalar)
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state
        
        # Calculate discounted returns (walk backwards, prepend each value
        # so the list ends up in chronological order)
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns, dtype=torch.float32)
        # Normalise the returns. With a single step the std is undefined, so
        # only the mean is subtracted -- which makes that one return (and
        # hence the loss for such an episode) exactly zero.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        else:
            returns = returns - returns.mean()
        
        # REINFORCE objective: minimise -sum(log pi(a|s) * G)
        loss = 0
        for log_prob, G in zip(log_probs, returns):
            loss += -log_prob * G

        optimizer.zero_grad()
        loss.backward()
        # Clip the global gradient norm before the update.
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()
        
        # Calculate the final error of the current episode
        current_error = abs(env.current_ph - env.target_ph)
        
        # If the current error is better than the optimal error, save the model.
        # NOTE(review): "best" is judged on the final error of one sampled
        # episode, measured after this episode's gradient step -- confirm
        # this is the intended selection criterion.
        if current_error < best_error:
            best_error = current_error
            # NOTE(review): dict.copy() is shallow, so the tensors presumably
            # still alias the live parameters; the snapshot is only reliable
            # because it is written to disk immediately.
            best_model_state = policy_model.state_dict().copy()
            torch.save(best_model_state, "Volume regressor best big discrete new1 trained 1 test.pth")
            print(f"episode {episode}, Loss: {loss.item():.4f}, Updated Best Model with Error: {best_error:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")
        elif episode % 50 == 0:  # Only print every 50 episodes if not optimal
            total_reward = sum(rewards)
            print(f"episode {episode}, Loss: {loss.item():.4f}, Total Reward: {total_reward:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")

##############################################
# Test function: run the experiment and print the details of each step
##############################################
def test_model(policy_model, env, num_experiments=10):
    """Run ``num_experiments`` episodes with the policy and print every step.

    Actions are drawn with ``policy_model.sample_action`` (stochastic, with
    gradients disabled). For each episode the initial state, the acid
    description, each step's resulting state / action / reagent and the
    total step count are printed. Nothing is returned.
    """
    for run_idx in range(num_experiments):
        print(f"\n==== Experiment {run_idx+1} Start ====")
        state = env.reset()
        print(f"Initial state: {state}")
        print(f"Acid type: {env.acid_type}, parameters: {env.acid_params}, target pH: {env.target_ph}")

        # Collect (state after step, dosed volume, reagent) for every step.
        trace = []
        finished = False
        while not finished:
            batch = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                sampled, _ = policy_model.sample_action(batch)
            volume = sampled.item()
            state, _reward, finished, info = env.step(volume)
            trace.append((state, volume, info.get('Reagent', '')))

        for step_no, (obs, vol, reagent) in enumerate(trace, start=1):
            print(f"  Step {step_no}: State = {obs}, Action = {vol:.4f}, Reagent = {reagent}")
        print(f"The experiment is over, the number of shared steps is: {len(trace)}")

##############################################
# Main program: load the pre-trained model (if available) and train and test
##############################################
# Python sets __name__ to exactly "__main__" for a script or notebook cell;
# the former comparison against "Main" never matched, so nothing below ran.
if __name__ == "__main__":
    input_dim = 5
    learning_rate = 1e-4
    gamma = 0.99

    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=TITRANT_CONC)
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
    
    # Warm-start from the supervised checkpoint when it is available. This is
    # deliberately best-effort: any load failure falls back to random weights.
    pretrained_path = "Volume regressor best big discrete new1 test.pth"
    try:
        # Device type strings are lowercase: torch.device('Cpu') is rejected,
        # which previously made this load fail every time.
        state_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        policy_model.load_state_dict(state_dict)
        print("Loading the pre-trained model successfully.")
    except Exception as e:
        print("Failed to load pretrained model, using randomly initialized model.", e)
    
    optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
    train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=gamma)
    # Save the last model at the end of training (optional). Note that this
    # writes to the same path as the best-episode checkpoint saved during
    # training, so it overwrites that file.
    torch.save(policy_model.state_dict(), "Volume regressor best big discrete new1 trained 1 test.pth")
    print("Model saved.")
    
    test_model(policy_model, env, num_experiments=10)

In [None]:
# Save while training

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
from scipy.optimize import fsolve
import json

##############################################
# Fixed random seeds to ensure repeatable experiments
##############################################
# One seed shared by every random number generator this cell relies on.
seed = 255
random.seed(seed)        # Python `random` (pKa tuples, acid type, target pH)
np.random.seed(seed)     # NumPy global RNG (monoprotic pKa pool and its sampling)
torch.manual_seed(seed)  # PyTorch RNG (weight init, action sampling)
if torch.cuda.is_available():
    # Also seed every CUDA device when a GPU is present.
    torch.cuda.manual_seed_all(seed)
# Request deterministic cuDNN kernels and disable its benchmark auto-tuner,
# as PyTorch's reproducibility notes recommend for repeatable runs.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

##############################################
# global constants
##############################################
TITRANT_CONC = 0.1          # Titrant concentration (0.1 M); applied to both the added strong base and the added strong acid
MAX_STEPS = 50              # Maximum number of steps per episode; reaching it terminates the episode
INITIAL_ACID_VOL = 11.0     # Initial volume of weak acid to be titrated (mL)
SUCCESS_THRESHOLD = 0.1     # pH error threshold: |current pH - target pH| below this ends the episode

##############################################
# pH calculation functions: monoprotic acid
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual of a monoprotic weak-acid solution at a trial pH.

    The residual is ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` where
    ``alpha`` is the dissociated fraction of the weak acid.  The pH at which
    it is zero is the solution's equilibrium pH.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by the added strong base.
        c_HCl: Concentration contributed by the added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The signed charge-balance residual.
    """
    water_ion_product = 1e-14
    hydronium = 10 ** (-pH)
    hydroxide = water_ion_product / hydronium
    # Henderson-Hasselbalch ratio [A-]/[HA] and the resulting anion fraction.
    base_to_acid = 10 ** (pH - pKa)
    dissociated_fraction = base_to_acid / (1 + base_to_acid)
    return hydronium + c_Na - hydroxide - c_A * dissociated_fraction - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Find the pH in [0, 14] where the monoprotic charge balance is zero.

    Plain bisection on ``f_monoprotic``.  The residual at the lower bound is
    cached and refreshed only when the bound moves, so each iteration costs a
    single residual evaluation instead of two; the sequence of brackets (and
    therefore the result) is identical to re-evaluating it every time.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The midpoint at which ``|residual| < 1e-10``, or the final bracket
        midpoint after 100 halvings.  If the residual never changes sign on
        the interval, the search drifts to the upper bound (14).
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change lies in the lower half.
            hi = mid
        else:
            # Root is in the upper half; the midpoint becomes the new lower
            # bound and its residual is already known.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float) -> float:
    """Return the pH (2 decimals) of the monoprotic sample after titrant additions.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M weak acid; both the added
    base and the added acid are taken at ``TITRANT_CONC``.  Amounts are
    converted to concentrations in the combined volume and handed to
    ``solve_pH_monoprotic_balance``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa: Dissociation constant of the weak acid.
    """
    sample_mL = INITIAL_ACID_VOL
    # Moles of each species (volumes converted from mL to L).
    mol_weak_acid = sample_mL / 1000.0 * 0.1
    mol_sodium = base_added_mL / 1000.0 * TITRANT_CONC
    mol_chloride = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    equilibrium_pH = solve_pH_monoprotic_balance(
        mol_weak_acid / total_L,
        mol_sodium / total_L,
        mol_chloride / total_L,
        pKa,
    )
    return round(equilibrium_pH, 2)

##############################################
# pH calculation functions: diprotic acid
##############################################
def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual of a diprotic weak-acid solution at a trial pH.

    Species weights relative to the fully protonated acid are ``10**(pH-pKa1)``
    for the singly charged anion and ``10**(2*pH-pKa1-pKa2)`` for the doubly
    charged one; exponents are clipped to [-100, 100] to avoid overflow.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - (anion charge) - c_HCl``.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    singly = np.power(10, np.clip(pH - pKa1, -100, 100))
    doubly = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    partition = 1 + singly + doubly
    # Mean negative charge per acid molecule, scaled by its concentration.
    anion_charge = c_A * (singly / partition + 2 * (doubly / partition))
    return hydronium + c_Na - hydroxide - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Find the pH in [0, 14] where the diprotic charge balance is zero.

    Bisection on ``f_diprotic``.  The residual at the lower bound is cached
    and updated only when that bound moves, halving the number of residual
    evaluations without changing any bracket or the returned value.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        The midpoint at which ``|residual| < 1e-10``, or the final bracket
        midpoint after 100 halvings.  Without a sign change on the interval
        the search drifts to the upper bound (14).
    """
    lo, hi = 0.0, 14.0
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # The midpoint becomes the lower bound; reuse its residual.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float) -> float:
    """Return the pH (2 decimals) of the diprotic sample after titrant additions.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M weak acid; added base and
    added acid are both taken at ``TITRANT_CONC``.  Concentrations in the
    combined volume are passed to ``solve_pH_diprotic``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
    """
    sample_mL = INITIAL_ACID_VOL
    # Moles of each species (volumes converted from mL to L).
    mol_weak_acid = sample_mL / 1000.0 * 0.1
    mol_sodium = base_added_mL / 1000.0 * TITRANT_CONC
    mol_chloride = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    equilibrium_pH = solve_pH_diprotic(
        mol_weak_acid / total_L,
        mol_sodium / total_L,
        mol_chloride / total_L,
        pKa1,
        pKa2,
    )
    return round(equilibrium_pH, 2)

##############################################
# pH calculation functions: triprotic acid
##############################################
def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual of a triprotic weak-acid solution at a trial pH.

    The three deprotonated species get weights ``10**(k*pH - pKa1 - ... - pKak)``
    (k = 1, 2, 3) relative to the fully protonated acid; exponents are clipped
    to [-100, 100] to avoid overflow.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - (anion charge) - c_HCl``.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    w1 = np.power(10, np.clip(pH - pKa1, -100, 100))
    w2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    w3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))
    total_weight = 1 + w1 + w2 + w3
    # Mean negative charge carried per acid molecule.
    mean_charge = w1 / total_weight + 2 * (w2 / total_weight) + 3 * (w3 / total_weight)
    anion_charge = c_A * mean_charge
    return hydronium + c_Na - hydroxide - anion_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Find the pH in [0, 14] where the triprotic charge balance is zero.

    Bisection on ``f_triprotic``.  The residual at the lower bound is cached
    and updated only when that bound moves, halving the number of residual
    evaluations without changing any bracket or the returned value.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The midpoint at which ``|residual| < 1e-10``, or the final bracket
        midpoint after 100 halvings.  Without a sign change on the interval
        the search drifts to the upper bound (14).
    """
    lo, hi = 0.0, 14.0
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # The midpoint becomes the lower bound; reuse its residual.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Return the pH (2 decimals) of the triprotic sample after titrant additions.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M weak acid; added base and
    added acid are both taken at ``TITRANT_CONC``.  Concentrations in the
    combined volume are passed to ``solve_pH_triprotic``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.
    """
    sample_mL = INITIAL_ACID_VOL
    # Moles of each species (volumes converted from mL to L).
    mol_weak_acid = sample_mL / 1000.0 * 0.1
    mol_sodium = base_added_mL / 1000.0 * TITRANT_CONC
    mol_chloride = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    equilibrium_pH = solve_pH_triprotic(
        mol_weak_acid / total_L,
        mol_sodium / total_L,
        mol_chloride / total_L,
        pKa1,
        pKa2,
        pKa3,
    )
    return round(equilibrium_pH, 2)

##############################################
# Reward calculation function (modified version)
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag, prev_overshoot_volume, last_action_volume):
    """Score one titration step and report whether the episode has ended.

    The reward is the sum of:
      * a dense term proportional to the reduction in |pH - target|, weighted
        up early in the episode;
      * a constant per-step penalty;
      * an overshoot penalty when the pH crossed the target during this step;
      * a wrong-direction penalty when the pH ends above the target after a
        base addition, or below it after an acid addition;
      * a volume penalty/bonus on the step that follows an overshoot.

    Args:
        previous_ph: pH before the action.
        current_ph: pH after the action.
        target_ph: pH the episode is trying to reach.
        steps_taken: Number of steps taken so far, including this one.
        max_steps: Step budget of the episode.
        reagent: Name of the reagent that was added; matched
            case-insensitively for the substrings "base" and "acid".
        reward_config: Mapping with the weights of the terms above; missing
            keys fall back to the defaults written in the code.
        SUCCESS_THRESHOLD: |pH - target| below which the target counts as hit.
        prev_overshoot_flag: Whether the previous step crossed the target.
        prev_overshoot_volume: Volume of that overshooting step, or None.
        last_action_volume: Volume added in the current step.

    Returns:
        ``(reward, is_terminal)``.  Non-terminal rewards are clipped to
        [``Reward clip min``, ``Reward clip max``]; terminal ones are not.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("Dense lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("Step penalty", -0.005)
    overshoot_weight = reward_config.get("Overshoot weight", 0.2)
    overshoot_threshold = reward_config.get("Overshoot threshold", 0.1)

    # Overshoot: the pH moved from one side of the target to the other and at
    # least one of the two errors is larger than the tolerance.
    crossed_target = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    if crossed_target and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        # Logistic squashing keeps the penalty within (-overshoot_weight, 0).
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(-(overshoot_magnitude - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("Wrong dir factor", 1.0)
    wrong_dir_penalty = 0
    # The reagent name is lower-cased, so the substrings must be lower-case
    # too; the previous capitalised 'Base'/'Acid' could never match, which
    # silently disabled this penalty.
    reagent_name = reagent.lower()
    if (current_ph > target_ph and 'base' in reagent_name) or (current_ph < target_ph and 'acid' in reagent_name):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph)

    # If the previous step overshot, penalize the current action volume, and
    # give a bonus when it is smaller than the volume that overshot.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None:
        overshoot_volume_penalty = reward_config.get("Overshoot volume penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume
        overshoot_volume_bonus = reward_config.get("Overshoot volume bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    is_terminal = False
    if abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps:
        is_terminal = True
        # NOTE(review): the terminal bonus is also paid when the episode ends
        # only because the step budget ran out (factor 1.0) - confirm that
        # rewarding a timeout is intended.
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("Terminal bonus", 3.0) * bonus_factor
        raw_reward += terminal_bonus

    # Clip rewards only in non-terminal states.
    if not is_terminal:
        reward_clip_max = reward_config.get("Reward clip max", 4.0)
        reward_clip_min = reward_config.get("Reward clip min", -4.0)
        reward = max(min(raw_reward, reward_clip_max), reward_clip_min)
    else:
        reward = raw_reward

    return reward, is_terminal

##############################################
# pH simulation environment: PHSimEnv (modified version)
##############################################
class PHSimEnv:
    """Simulated titration of a randomly chosen weak acid toward a random target pH.

    Each episode draws an acid (mono-, di- or triprotic, with pKa values taken
    from pools generated once at construction) and a target pH.  Every step
    adds the requested volume of strong base when the solution is below the
    target and of strong acid otherwise, then recomputes the pH and the reward.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol  # mL
        self.analyte_conc = analyte_conc          # mol/L
        self.titrant_conc = titrant_conc          # mol/L
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reward_config = {
            "Dense lambda": -0.03,
            "Step penalty": 0,
            "Terminal bonus": 3.9,
            "Overshoot weight": 0.2,
            "Overshoot threshold": 0.1,
            "Wrong dir factor": 1,
            "Reward clip max": 4.1,
            "Reward clip min": -4.1,
            "Overshoot volume penalty": 0.1,
            "Overshoot volume bonus": 0.1
        }
        # Pools of 30 parameter sets per acid family.  The draw order below is
        # kept fixed so that a seeded run yields the same pools.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = [
            (random.uniform(2, 4), random.uniform(4, 7)) for _ in range(30)
        ]
        self.triprotic_pKa_list = [
            (random.uniform(2, 4), random.uniform(4, 6), random.uniform(6, 8))
            for _ in range(30)
        ]
        self.reset()

    def _compute_current_ph(self):
        """pH for the volumes added so far, using the solver of the current acid type."""
        if self.acid_type == 'Monoprotic':
            return calculate_pH_monoprotic(self.base_added_mL, self.acid_added_mL, self.acid_params)
        if self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            return calculate_pH_diprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2)
        pKa1, pKa2, pKa3 = self.acid_params
        return calculate_pH_triprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2, pKa3)

    def reset(self):
        """Start a new episode and return its initial state vector."""
        # Draw the acid family, one parameter set from its pool, then the target.
        self.acid_type = random.choice(['Monoprotic', 'Diprotic', 'Triprotic'])
        if self.acid_type == 'Monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'Diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        else:
            self.acid_params = random.choice(self.triprotic_pKa_list)
        self.target_ph = round(random.uniform(2, 11), 2)

        # Nothing has been added yet.
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0

        # No overshoot is remembered at the start of an episode.
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None

        # pH of the untouched sample; previous and current pH start out equal.
        self.current_ph = self._compute_current_ph()
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Observation: current pH, target pH, pH change, signed error, last volume."""
        features = [
            self.current_ph,
            self.target_ph,
            round(self.current_ph - self.previous_ph, 2),
            round(self.current_ph - self.target_ph, 2),
            self.last_action_volume,
        ]
        return np.array(features, dtype=np.float32)

    def step(self, action):
        """Add ``action`` mL of titrant and return ``(state, reward, done, info)``."""
        dose = float(action)
        self.last_action_volume = dose
        self.steps += 1

        # The reagent is chosen by the environment: base below target, acid otherwise.
        if self.current_ph < self.target_ph:
            reagent = "Strong base"
            self.base_added_mL += dose
        else:
            reagent = "Strong acid"
            self.acid_added_mL += dose
        self.total_volume = self.initial_acid_vol + self.base_added_mL + self.acid_added_mL

        # Remember the pH before the addition, then recompute it.
        self.previous_ph = self.current_ph
        self.current_ph = self._compute_current_ph()

        state = self._get_state()
        # The overshoot memory passed here still describes the *previous* step.
        reward, done = calculate_reward(
            previous_ph=self.previous_ph,
            current_ph=self.current_ph,
            target_ph=self.target_ph,
            steps_taken=self.steps,
            max_steps=MAX_STEPS,
            reagent=reagent,
            reward_config=self.reward_config,
            SUCCESS_THRESHOLD=SUCCESS_THRESHOLD,
            prev_overshoot_flag=self.prev_overshoot_flag,
            prev_overshoot_volume=self.prev_overshoot_volume,
            last_action_volume=self.last_action_volume
        )

        # Record whether this step crossed the target, for the next reward call.
        crossed_target = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
        if crossed_target:
            self.prev_overshoot_flag = True
            self.prev_overshoot_volume = self.last_action_volume
        else:
            self.prev_overshoot_flag = False
            self.prev_overshoot_volume = None

        return state, reward, done, {'Reagent': reagent}

##############################################
# Discrete action strategy model: DiscreteVolumeRegressor
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Categorical policy over a fixed grid of titrant volumes.

    The grid runs from ``min_volume`` upward in increments of ``step``
    (values rounded to 2 decimals).  A three-layer MLP maps a state vector to
    one logit per grid entry.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        grid_size = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(grid_size)]
        self.num_actions = len(self.discrete_volumes)
        hidden_width = 256
        # Layer order fixes the parameter names (net.0, net.2, net.4) that
        # saved checkpoints rely on.
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, self.num_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the action logits for a batch of states."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the policy for a single state.

        Returns the volume as a ``(1, 1)`` float tensor together with the
        log-probability of the sampled grid index.
        """
        logits = self.forward(x)
        has_nan = torch.isnan(logits).any()
        if has_nan:
            print("Logits contain NaN:", logits)
        policy = torch.distributions.Categorical(logits=logits)
        chosen = policy.sample()
        chosen_log_prob = policy.log_prob(chosen)
        picked_volume = self.discrete_volumes[chosen.item()]
        return torch.tensor([[picked_volume]], dtype=torch.float32), chosen_log_prob

    def predict_volume(self, x):
        """Return the highest-logit volume for a single state as a ``(1, 1)`` tensor."""
        best_index = torch.max(self.forward(x), dim=1).indices
        picked_volume = self.discrete_volumes[best_index.item()]
        return torch.tensor([[picked_volume]], dtype=torch.float32)

##############################################
# Online training: updating policy model using REINFORCE algorithm
##############################################
def train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=0.99,
                    best_model_path="Volume regressor best big discrete new1 trained 1 test save the best.pth"):
    """Train ``policy_model`` on ``env`` with the REINFORCE policy-gradient update.

    One episode is rolled out per iteration, its discounted returns are
    standardized, and a single gradient step (norm-clipped to 1.0) is taken.
    Success statistics are cumulative over the whole run; whenever they
    improve, the current weights are snapshotted and written to
    ``best_model_path``.

    Args:
        env: Environment exposing ``reset()``, ``step(volume)`` and the
            attributes ``current_ph``, ``target_ph`` and ``steps``.
        policy_model: Module exposing ``sample_action(state_tensor)`` that
            returns ``(action_tensor, log_prob)``.
        optimizer: Optimizer built over ``policy_model.parameters()``.
        num_episodes: Number of training episodes.
        gamma: Discount factor for the returns.
        best_model_path: File the best snapshot is saved to.

    Returns:
        A detached copy of the state dict at the last improvement, or None if
        no episode ever succeeded.
    """
    import copy  # local import: only needed for the checkpoint snapshot

    best_success_count = 0          # Most successes seen at a checkpoint
    best_avg_steps = float('inf')   # Average steps per success at that checkpoint
    best_model_state = None         # Snapshot of the weights at that checkpoint
    success_count = 0               # Number of successful episodes so far
    successful_steps = 0            # Total steps spent in successful episodes

    for episode in range(num_episodes):
        state = env.reset()
        done = False
        log_probs = []
        rewards = []
        while not done:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            action, log_prob = policy_model.sample_action(state_tensor)
            action_scalar = action.item()  # plain float for the environment
            next_state, reward, done, _ = env.step(action_scalar)
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state

        # Discounted returns, accumulated from the last step backwards.
        returns = []
        R = 0
        for r in reversed(rewards):
            R = r + gamma * R
            returns.append(R)
        returns.reverse()
        returns = torch.tensor(returns, dtype=torch.float32)
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        else:
            # std() of a single element is undefined; only centre the return.
            returns = returns - returns.mean()

        loss = 0
        for log_prob, G in zip(log_probs, returns):
            loss += -log_prob * G

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

        # An episode succeeds when it hits the target before the step budget
        # is exhausted.
        is_success = abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD and env.steps < MAX_STEPS
        if is_success:
            # Count successes and their steps separately; using the step total
            # as the success count made the reported average always 1.00.
            success_count += 1
            successful_steps += env.steps

        current_avg_steps = successful_steps / success_count if success_count else float('inf')

        # Save when there are more successes, or equally many with fewer
        # average steps.
        if success_count > best_success_count or \
           (success_count == best_success_count and current_avg_steps < best_avg_steps):
            best_success_count = success_count
            best_avg_steps = current_avg_steps
            # Deep copy: state_dict() values alias the live parameters, so a
            # shallow .copy() would keep changing as training continues.
            best_model_state = copy.deepcopy(policy_model.state_dict())
            torch.save(best_model_state, best_model_path)
            print(f"episode {episode}, Loss: {loss.item():.4f}, Updated Best Model with Success Count: {best_success_count}, Avg Steps: {best_avg_steps:.2f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")
        elif episode % 50 == 0:  # Otherwise only report every 50 episodes
            total_reward = sum(rewards)
            print(f"episode {episode}, Loss: {loss.item():.4f}, Total Reward: {total_reward:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")

    return best_model_state

##############################################
# Test function: run the experiment and print the details of each step
##############################################
def test_model(policy_model, env, num_experiments=10):
    """Run ``num_experiments`` episodes with the policy and print a step-by-step trace.

    Actions are sampled from the policy (not taken greedily) with gradients
    disabled.  Nothing is returned; all results go to stdout.
    """
    for i in range(num_experiments):
        print(f"\n==== Experiment {i+1} Start ====")
        state = env.reset()
        print(f"Initial state: {state}")
        print(f"Acid type: {env.acid_type}, parameters: {env.acid_params}, target pH: {env.target_ph}")

        # Roll the episode out, collecting (state, volume, reagent) per step.
        experiment_trace = []
        finished = False
        while not finished:
            with torch.no_grad():
                observation = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                action, _ = policy_model.sample_action(observation)
            action_scalar = action.item()
            state, _reward, finished, info = env.step(action_scalar)
            experiment_trace.append((state, action_scalar, info.get('Reagent', '')))

        # One trace entry was recorded per environment step.
        steps = len(experiment_trace)
        for j, (s, a, reagent) in enumerate(experiment_trace):
            print(f"  Step {j+1}: State = {s}, Action = {a:.4f}, Reagent = {reagent}")
        print(f"The experiment is over, the number of shared steps is: {steps}")

##############################################
# Main program: load the pre-trained model (if available) and train and test
##############################################
# Script entry point: build the environment and policy, optionally warm-start
# from a checkpoint, train with REINFORCE (which saves the best model as it
# goes), then run a few printed test episodes.
# The guard must compare against "__main__"; the previous value "Main" never
# matched, so none of the code below ever ran.
if __name__ == "__main__":
    input_dim = 5
    learning_rate = 1e-4
    gamma = 0.99

    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=TITRANT_CONC)
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)

    pretrained_path = "Volume regressor best big discrete new1 test.pth"
    try:
        # Device type strings are case-sensitive: 'Cpu' is rejected by
        # torch.device, which made this load fail unconditionally and silently
        # fall back to random weights.
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        state_dict = torch.load(pretrained_path, map_location=torch.device('cpu'))
        policy_model.load_state_dict(state_dict)
        print("Loading the pre-trained model successfully.")
    except Exception as e:
        # Deliberate best-effort: a missing or incompatible checkpoint is not
        # fatal, training simply starts from the random initialization.
        print("Failed to load pretrained model, using randomly initialized model.", e)

    optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
    train_reinforce(env, policy_model, optimizer, num_episodes=1000, gamma=gamma)
    print("Training is completed and the optimal model has been saved.")

    test_model(policy_model, env, num_experiments=10)

In [None]:
# Reinforcement learning experiments

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
from scipy.optimize import fsolve

# Fixed random seed shared by the Python, NumPy and PyTorch generators so that
# the randomly drawn acid pools and sampled actions are repeatable.
seed = 555
random.seed(seed)        # Python `random` (pKa tuples, acid type, target pH)
np.random.seed(seed)     # NumPy global RNG (monoprotic pKa pool and its sampling)
torch.manual_seed(seed)  # PyTorch RNG (weight init, action sampling)

##############################################
# global constants
##############################################
TITRANT_CONC = 0.1          # Titrant concentration (0.1 M); applied to both the added strong base and the added strong acid
MAX_STEPS = 50              # Maximum number of steps
INITIAL_ACID_VOL = 11.0     # Initial volume of weak acid to be titrated (mL)
SUCCESS_THRESHOLD = 0.1     # pH error threshold

##############################################
# pH calculation functions (mono-, di- and triprotic acids) - kept consistent with the training cell
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual of a monoprotic weak-acid solution at a trial pH.

    Computes ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` with ``alpha`` the
    dissociated fraction of the acid.  The root in pH is the equilibrium pH.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by the added strong base.
        c_HCl: Concentration contributed by the added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The signed charge-balance residual.
    """
    water_ion_product = 1e-14
    hydronium = 10 ** (-pH)
    hydroxide = water_ion_product / hydronium
    # Ratio [A-]/[HA] and the resulting anion fraction.
    base_to_acid = 10 ** (pH - pKa)
    dissociated_fraction = base_to_acid / (1 + base_to_acid)
    return hydronium + c_Na - hydroxide - c_A * dissociated_fraction - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Find the pH in [0, 14] where the monoprotic charge balance is zero.

    Plain bisection on ``f_monoprotic``.  The residual at the lower bound is
    cached and refreshed only when the bound moves, so each iteration costs a
    single residual evaluation instead of two; brackets and result are the
    same as when it is re-evaluated every time.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The midpoint at which ``|residual| < 1e-10``, or the final bracket
        midpoint after 100 halvings.  If the residual never changes sign on
        the interval, the search drifts to the upper bound (14).
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change lies in the lower half.
            hi = mid
        else:
            # The midpoint becomes the lower bound; reuse its residual.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float) -> float:
    """Return the pH (2 decimals) of the monoprotic sample after titrant additions.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M weak acid; added base and
    added acid are both taken at ``TITRANT_CONC``.  Concentrations in the
    combined volume are passed to ``solve_pH_monoprotic_balance``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa: Dissociation constant of the weak acid.
    """
    sample_mL = INITIAL_ACID_VOL
    # Moles of each species (volumes converted from mL to L).
    mol_weak_acid = sample_mL / 1000.0 * 0.1
    mol_sodium = base_added_mL / 1000.0 * TITRANT_CONC
    mol_chloride = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    equilibrium_pH = solve_pH_monoprotic_balance(
        mol_weak_acid / total_L,
        mol_sodium / total_L,
        mol_chloride / total_L,
        pKa,
    )
    return round(equilibrium_pH, 2)

def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual of a diprotic weak-acid solution at a trial pH.

    The singly and doubly charged anions get weights ``10**(pH-pKa1)`` and
    ``10**(2*pH-pKa1-pKa2)`` relative to the fully protonated acid; exponents
    are clipped to [-100, 100] to avoid overflow.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - (anion charge) - c_HCl``.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    singly = np.power(10, np.clip(pH - pKa1, -100, 100))
    doubly = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    partition = 1 + singly + doubly
    # Mean negative charge per acid molecule, scaled by its concentration.
    anion_charge = c_A * (singly / partition + 2 * (doubly / partition))
    return hydronium + c_Na - hydroxide - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Find the pH in [0, 14] where the diprotic charge balance is zero.

    Bisection on ``f_diprotic``.  The residual at the lower bound is cached
    and updated only when that bound moves, halving the number of residual
    evaluations without changing any bracket or the returned value.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        The midpoint at which ``|residual| < 1e-10``, or the final bracket
        midpoint after 100 halvings.  Without a sign change on the interval
        the search drifts to the upper bound (14).
    """
    lo, hi = 0.0, 14.0
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # The midpoint becomes the lower bound; reuse its residual.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float) -> float:
    """Return the pH (2 decimals) of the diprotic sample after titrant additions.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M weak acid; added base and
    added acid are both taken at ``TITRANT_CONC``.  Concentrations in the
    combined volume are passed to ``solve_pH_diprotic``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
    """
    sample_mL = INITIAL_ACID_VOL
    # Moles of each species (volumes converted from mL to L).
    mol_weak_acid = sample_mL / 1000.0 * 0.1
    mol_sodium = base_added_mL / 1000.0 * TITRANT_CONC
    mol_chloride = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    equilibrium_pH = solve_pH_diprotic(
        mol_weak_acid / total_L,
        mol_sodium / total_L,
        mol_chloride / total_L,
        pKa1,
        pKa2,
    )
    return round(equilibrium_pH, 2)

def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual of a triprotic weak-acid solution at a trial pH.

    The three deprotonated species get weights ``10**(k*pH - pKa1 - ... - pKak)``
    (k = 1, 2, 3) relative to the fully protonated acid; exponents are clipped
    to [-100, 100] to avoid overflow.

    Args:
        pH: Trial pH value.
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - (anion charge) - c_HCl``.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    w1 = np.power(10, np.clip(pH - pKa1, -100, 100))
    w2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    w3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))
    total_weight = 1 + w1 + w2 + w3
    # Mean negative charge carried per acid molecule.
    mean_charge = w1 / total_weight + 2 * (w2 / total_weight) + 3 * (w3 / total_weight)
    anion_charge = c_A * mean_charge
    return hydronium + c_Na - hydroxide - anion_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Find the pH in [0, 14] where the triprotic charge balance is zero.

    Bisection on ``f_triprotic``.  The residual at the lower bound is cached
    and updated only when that bound moves, halving the number of residual
    evaluations without changing any bracket or the returned value.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration from added strong base.
        c_HCl: Concentration from added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The midpoint at which ``|residual| < 1e-10``, or the final bracket
        midpoint after 100 halvings.  Without a sign change on the interval
        the search drifts to the upper bound (14).
    """
    lo, hi = 0.0, 14.0
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # The midpoint becomes the lower bound; reuse its residual.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Return the pH (2 decimals) of the triprotic sample after titrant additions.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M weak acid; added base and
    added acid are both taken at ``TITRANT_CONC``.  Concentrations in the
    combined volume are passed to ``solve_pH_triprotic``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.
    """
    sample_mL = INITIAL_ACID_VOL
    # Moles of each species (volumes converted from mL to L).
    mol_weak_acid = sample_mL / 1000.0 * 0.1
    mol_sodium = base_added_mL / 1000.0 * TITRANT_CONC
    mol_chloride = acid_added_mL / 1000.0 * TITRANT_CONC
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    equilibrium_pH = solve_pH_triprotic(
        mol_weak_acid / total_L,
        mol_sodium / total_L,
        mol_chloride / total_L,
        pKa1,
        pKa2,
        pKa3,
    )
    return round(equilibrium_pH, 2)

##############################################
# Environment and reward function
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD,
                     prev_overshoot_flag=None, prev_overshoot_volume=None, last_action_volume=None):
    """Score one titration step and report whether the episode has ended.

    The reward is the sum of a dense error-reduction term, a per-step penalty,
    an overshoot penalty (pH crossed the target), a wrong-direction penalty
    (pH above target after base, or below target after acid) and, when the
    previous step overshot, a volume penalty/bonus.

    Args:
        previous_ph: pH before the action.
        current_ph: pH after the action.
        target_ph: pH the experiment is trying to reach.
        steps_taken: Number of steps taken so far, including this one.
        max_steps: Step budget.
        reagent: Name of the reagent that was added; matched
            case-insensitively for the substrings "base" and "acid".
        reward_config: Mapping with the weights of the terms above; missing
            keys fall back to the defaults written in the code.
        SUCCESS_THRESHOLD: |pH - target| below which the target counts as hit.
        prev_overshoot_flag: Whether the previous step crossed the target.
        prev_overshoot_volume: Volume of that overshooting step, or None.
        last_action_volume: Volume added in the current step, or None.

    Returns:
        ``(reward, is_terminal)``.  Non-terminal rewards are clipped to
        [``Reward clip min``, ``Reward clip max``]; terminal ones are not.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("Dense lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("Step penalty", -0.005)
    overshoot_weight = reward_config.get("Overshoot weight", 0.2)
    overshoot_threshold = reward_config.get("Overshoot threshold", 0.1)

    # Overshoot: the pH moved from one side of the target to the other and at
    # least one of the two errors is larger than the tolerance.
    crossed_target = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    if crossed_target and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        # Logistic squashing keeps the penalty within (-overshoot_weight, 0).
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(-(overshoot_magnitude - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("Wrong dir factor", 1.0)
    wrong_dir_penalty = 0
    # The reagent name is lower-cased, so the substrings must be lower-case
    # too; the previous capitalised 'Base'/'Acid' could never match, which
    # silently disabled this penalty.
    reagent_name = reagent.lower()
    if (current_ph > target_ph and 'base' in reagent_name) or (current_ph < target_ph and 'acid' in reagent_name):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph)

    # Volume shaping applies only when the caller supplies the overshoot
    # history and the current volume.
    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None and last_action_volume is not None:
        overshoot_volume_penalty = reward_config.get("Overshoot volume penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume
        overshoot_volume_bonus = reward_config.get("Overshoot volume bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    is_terminal = False
    if abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps:
        is_terminal = True
        # NOTE(review): the terminal bonus is also paid when the episode ends
        # only because the step budget ran out (factor 1.0) - confirm that
        # rewarding a timeout is intended.
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("Terminal bonus", 3.0) * bonus_factor
        raw_reward += terminal_bonus

    if not is_terminal:
        reward_clip_max = reward_config.get("Reward clip max", 4.0)
        reward_clip_min = reward_config.get("Reward clip min", -4.0)
        reward = max(min(raw_reward, reward_clip_max), reward_clip_min)
    else:
        reward = raw_reward

    return reward, is_terminal

class PHSimEnv:
    """State holder for a titration experiment against a randomly drawn weak acid.

    ``reset`` draws an acid (mono-, di- or triprotic, from pools generated at
    construction) and a target pH and computes the starting pH.  ``step`` is
    intentionally a no-op in this cell.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol
        self.analyte_conc = analyte_conc
        self.titrant_conc = titrant_conc
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reward_config = {
            "Dense lambda": 1.0,
            "Step penalty": -0.005,
            "Terminal bonus": 80,
            "Overshoot weight": 0.2,
            "Overshoot threshold": 0.1,
            "Wrong dir factor": 1.0,
            "Reward clip max": 2.0,
            "Reward clip min": -2.0
        }
        # Pools of 30 parameter sets per acid family.  The draw order below is
        # kept fixed so that a seeded run yields the same pools.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)
        self.diprotic_pKa_list = [
            (random.uniform(2, 4), random.uniform(4, 7)) for _ in range(30)
        ]
        self.triprotic_pKa_list = [
            (random.uniform(2, 4), random.uniform(4, 6), random.uniform(6, 8))
            for _ in range(30)
        ]
        self.reset()

    def reset(self):
        """Draw a new acid and target pH, clear the counters, return the state."""
        self.acid_type = random.choice(['Monoprotic', 'Diprotic', 'Triprotic'])
        if self.acid_type == 'Monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'Diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        elif self.acid_type == 'Triprotic':
            self.acid_params = random.choice(self.triprotic_pKa_list)
        # The target pH is drawn at random here; it is later overwritten by
        # user input.
        self.target_ph = round(random.uniform(2, 11), 2)

        # Nothing has been added yet and no overshoot is remembered.
        self.acid_added_mL = 0.0
        self.base_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None

        # pH of the untouched sample for the chosen acid family.
        if self.acid_type == 'Monoprotic':
            starting_ph = calculate_pH_monoprotic(0.0, 0.0, pKa=self.acid_params)
            self.current_ph = starting_ph
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            starting_ph = calculate_pH_diprotic(0.0, 0.0, pKa1, pKa2)
            self.current_ph = starting_ph
        elif self.acid_type == 'Triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            starting_ph = calculate_pH_triprotic(0.0, 0.0, pKa1, pKa2, pKa3)
            self.current_ph = starting_ph
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Observation: current pH, target pH, pH change, signed error, last volume."""
        features = [
            self.current_ph,
            self.target_ph,
            round(self.current_ph - self.previous_ph, 2),
            round(self.current_ph - self.target_ph, 2),
            self.last_action_volume,
        ]
        return np.array(features, dtype=np.float32)

    def step(self, action):
        """Do nothing and return None.

        In an automatic simulation this method would advance the state.  In
        the manual experiments the state is updated from the pH entered by
        the user instead, so this method is not called directly.
        """
        return None

##############################################
# Discrete action policy model: action space [0.01, 10] mL, step size 0.01 mL, 1000 discrete actions in total
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Policy network that scores a fixed grid of addition volumes.

    The grid runs from ``min_volume`` to ``max_volume`` in increments of
    ``step`` (values rounded to two decimals); with the defaults this is
    1000 volumes from 0.01 to 10.00. The network maps a state vector to
    one logit per grid volume.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        """Build the volume grid and the three-layer MLP.

        Args:
            input_dim: Length of the state vector fed to the network.
            min_volume: Smallest volume in the grid.
            max_volume: Upper bound of the grid.
            step: Spacing between consecutive grid volumes.
        """
        super().__init__()
        # Float division such as 0.6 / 0.1 can land just below the exact
        # integer; the small tolerance keeps truncation from dropping the
        # last grid point while leaving genuinely fractional counts alone.
        num_intervals = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(num_intervals + 1)]
        self.num_actions = len(self.discrete_volumes)
        # Layer layout (and therefore the state_dict keys) must stay as is
        # so that previously saved weights keep loading.
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )

    def forward(self, x):
        """Return one logit per grid volume for each row of ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the categorical distribution over the grid.

        Args:
            x: State tensor; ``.item()`` is called on the sampled index, so
                exactly one state is expected.

        Returns:
            ``(volume, log_prob)`` where ``volume`` is a ``(1, 1)`` float32
            tensor and ``log_prob`` is the log-probability of the sample.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        volume = self.discrete_volumes[action_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32), log_prob

    def predict_volume(self, x):
        """Return the grid volume with the highest logit as a ``(1, 1)`` tensor.

        Like ``sample_action``, this expects a single state because it calls
        ``.item()`` on the arg-max index.
        """
        logits = self.forward(x)
        _, predicted_indices = torch.max(logits, dim=1)
        predicted_volume = self.discrete_volumes[predicted_indices.item()]
        return torch.tensor([[predicted_volume]], dtype=torch.float32)

##############################################
# Interactive manual titration experiments
# Description: First enter the initial pH and target pH; then the model will give suggested actions.
#        You enter the measured pH after working in the lab, and the status updates to continue giving recommendations.
##############################################
def interactive_titration_manual(env, policy_model):
    """Run a console-driven titration in which the user reports measured pH.

    The user first enters the initial and target pH. On each round the
    policy suggests a volume, the user either accepts it or types a custom
    one, performs the addition in the lab and enters the measured pH. The
    loop ends when the pH is within ``SUCCESS_THRESHOLD`` of the target or
    after ``MAX_STEPS`` rounds (both are module-level constants).

    Args:
        env: Environment object; its ``reset`` and ``_get_state`` methods are
            called and its pH, volume and step attributes are overwritten
            directly from user input.
        policy_model: Model exposing ``sample_action(state_tensor)``.
    """
    # Enter initial pH and target pH
    try:
        init_ph = float(input("Please enter initial pH value: "))
        target_ph = float(input("Please enter target pH value: "))
    except ValueError:
        print("Input format error, use environment default value.")
        init_ph = env.current_ph
        target_ph = env.target_ph

    # Reset environment and override initial pH and target pH
    env.reset()
    env.current_ph = init_ph
    env.previous_ph = init_ph
    env.target_ph = target_ph
    print(f"\nInitial pH: {env.current_ph:.2f}, target pH: {env.target_ph:.2f}\n")

    done = False
    while not done:
        # Print current status
        print(f"Current pH: {env.current_ph:.2f}")
        # Update state vector (with latest pH value)
        state = env._get_state()
        state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)

        # Based on the current status, the model gives a recommended adding volume
        with torch.no_grad():
            recommended_action, _ = policy_model.sample_action(state_tensor)
            recommended_volume = recommended_action.item()

        # Determine recommended reagents based on current pH and target pH (consistent with simulation environment logic)
        if env.current_ph < env.target_ph:
            recommended_reagent = "Strong base"
        else:
            recommended_reagent = "Strong acid"
        print(f"Suggestion: Join {recommended_volume:.2f} M l {recommended_reagent}")

        # Allow the user to choose whether to follow the recommendations directly or enter a custom volume
        user_choice = input("Is the recommended volume used? (Press enter directly to use, n enter a custom volume): ")
        # The answer is lower-cased, so it must be compared with a lower-case
        # "n"; comparing with "N" made the custom-volume branch unreachable.
        if user_choice.strip().lower() == "n":
            try:
                action = float(input("Please enter the actual added volume (mL): "))
            except ValueError:
                print("Input error, use recommended value.")
                action = recommended_volume
        else:
            action = recommended_volume

        # Prompts the user to enter the measured pH value after working in the laboratory
        measured_ph = None
        while measured_ph is None:
            try:
                measured_ph = float(input("Please enter the pH value measured after the operation: "))
            except ValueError:
                print("Input format error, please enter a number.")

        # Update status: record the previous pH and update the current pH to the measured value entered by the user
        env.previous_ph = env.current_ph
        env.current_ph = measured_ph
        env.last_action_volume = action
        env.steps += 1

        # Update the cumulative amount of reagent added (according to the fixed logic in the simulation: if the previous pH is less than the target, add alkali, otherwise add acid)
        if env.previous_ph < env.target_ph:
            env.base_added_mL += action
            reagent_used = "Strong base"
        else:
            env.acid_added_mL += action
            reagent_used = "Strong acid"
        env.total_volume = env.initial_acid_vol + env.base_added_mL + env.acid_added_mL

        print(f"Action: Join {action:.2f} M l {reagent_used}, measured pH: {env.current_ph:.2f}\n")

        # Check whether the termination condition is met
        if abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD:
            print("Target pH successfully achieved!")
            done = True
        elif env.steps >= MAX_STEPS:
            print("When the maximum number of steps is reached, the experiment ends.")
            done = True

    print(f"The experiment is over and shared {env.steps} Step, final pH: {env.current_ph:.2f}")

##############################################
# Main program: Load the pre-trained model (if it exists) and enter interactive mode
##############################################
# Entry point. The guard must compare against "__main__" (the value of
# __name__ when the cell/script is executed); "Main" never matches, so the
# body was dead code.
if __name__ == "__main__":
    input_dim = 5
    learning_rate = 1e-3
    gamma = 0.99

    # Initialize environment
    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=0.1)

    # Initialize the discrete action model
    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)

    # Try to load the pre-trained model status (please keep the file name consistent)
    try:
        # Device strings are lower-case; 'Cpu' raises, and the broad except
        # below would then silently fall back to random weights.
        policy_model.load_state_dict(torch.load("Volume regressor best big discrete new1 trained 1 test.pth", map_location=torch.device('cpu')))
        print("Loading the discrete pre-trained model was successful. \n")
    except Exception as e:
        # Deliberate best-effort: continue with a randomly initialized model.
        print("Failed to load discrete pretrained model, using randomly initialized model. \n", e)

    policy_model.eval()

    # Enter interactive manual titration experiment mode
    interactive_titration_manual(env, policy_model)


In [None]:
# Do the same experiment online

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
import logging
from scipy.optimize import fsolve
import csv
import ast

# Configure root logging: INFO level, with timestamp, level name and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Fix the seeds of the Python, NumPy and PyTorch random generators.
seed = 555
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Global constants
TITRANT_CONC1 = 0.1         # Main titrant concentration (0.1 M)
TITRANT_CONC2 = 0.01        # Secondary titrant concentration (0.01 M)
MAX_STEPS = 50              # Maximum number of steps
INITIAL_ACID_VOL = 11.0     # Initial volume of weak acid to be titrated (mL)
SUCCESS_THRESHOLD = 0.1     # pH error threshold
MIN_ADDITION_VOLUME = 0.01  # Minimum drop volume (mL)

# Reagent name -> titrant concentration. Bases and acids share the same
# primary ("1") and secondary ("2") concentrations.
REAGENTS = {
    'Strong base 1': TITRANT_CONC1,
    'Strong base 2': TITRANT_CONC2,
    'Strong acid 1': TITRANT_CONC1,
    'Strong acid 2': TITRANT_CONC2,
}

# pH calculation function
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Charge-balance residual for a monoprotic weak acid at a trial pH.

    Evaluates ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` where ``alpha``
    is the dissociated fraction ``10**(pH - pKa) / (1 + 10**(pH - pKa))``
    and ``[OH-] = 1e-14 / [H+]``.

    Args:
        pH: Trial pH.
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The residual; zero at the equilibrium pH.
    """
    Kw = 1e-14
    h_conc = 10 ** (-pH)
    oh_conc = Kw / h_conc
    ratio = 10 ** (pH - pKa)
    dissociated_fraction = ratio / (1 + ratio)
    return h_conc + c_Na - oh_conc - c_A * dissociated_fraction - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Find the pH in [0, 14] where ``f_monoprotic`` is zero, by bisection.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, or the
        midpoint of the final bracket after 100 halvings. If the residual
        has the same sign at pH 0 and pH 14 (no root in the interval), the
        endpoint with the smaller residual magnitude is returned.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    f_hi = f_monoprotic(hi, c_A, c_Na, c_HCl, pKa)
    if f_lo * f_hi > 0:
        # Without a sign change the loop below would always drift to `hi`,
        # even when the residual is closest to zero at `lo`.
        return lo if abs(f_lo) < abs(f_hi) else hi
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # Reuse the value just computed instead of re-evaluating f(lo)
            # on every iteration.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa: float) -> float:
    """Compute the pH of the monoprotic analyte after the given additions.

    The analyte is ``INITIAL_ACID_VOL`` mL at a fixed 0.1 concentration;
    reagents "1" and "2" have concentrations ``TITRANT_CONC1`` and
    ``TITRANT_CONC2``.

    Args:
        base1_mL: Cumulative volume of base 1 added (mL).
        base2_mL: Cumulative volume of base 2 added (mL).
        acid1_mL: Cumulative volume of acid 1 added (mL).
        acid2_mL: Cumulative volume of acid 2 added (mL).
        pKa: Dissociation constant of the weak acid.

    Returns:
        The solved pH rounded to two decimals.
    """
    sample_mL = INITIAL_ACID_VOL
    sample_conc = 0.1
    weak_acid_moles = sample_mL / 1000.0 * sample_conc
    total_litres = (sample_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    weak_acid_conc = weak_acid_moles / total_litres
    base_conc = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / total_litres
    strong_acid_conc = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / total_litres
    solved_pH = solve_pH_monoprotic_balance(weak_acid_conc, base_conc, strong_acid_conc, pKa)
    return round(solved_pH, 2)

def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Charge-balance residual for a diprotic weak acid at a trial pH.

    The singly and doubly deprotonated fractions are built from
    ``10**(pH - pKa1)`` and ``10**(2*pH - pKa1 - pKa2)`` (exponents clipped
    to [-100, 100]); the acid's anion charge is
    ``c_A * (alpha1 + 2 * alpha2)``.

    Args:
        pH: Trial pH.
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - anion charge - c_HCl``; zero at equilibrium.
    """
    Kw = 1e-14
    h_conc = 10 ** (-pH)
    oh_conc = Kw / h_conc
    once, twice = (
        np.power(10, np.clip(exponent, -100, 100))
        for exponent in (pH - pKa1, 2 * pH - pKa1 - pKa2)
    )
    denom = 1 + once + twice
    anion_charge = c_A * (once / denom + 2 * (twice / denom))
    return h_conc + c_Na - oh_conc - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Find the pH in [0, 14] where ``f_diprotic`` is zero, by bisection.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, or the
        midpoint of the final bracket after 100 halvings. If the residual
        has the same sign at pH 0 and pH 14 (no root in the interval), the
        endpoint with the smaller residual magnitude is returned.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    f_hi = f_diprotic(hi, c_A, c_Na, c_HCl, pKa1, pKa2)
    if f_lo * f_hi > 0:
        # Without a sign change the loop below would always drift to `hi`,
        # even when the residual is closest to zero at `lo`.
        return lo if abs(f_lo) < abs(f_hi) else hi
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # Reuse the value just computed instead of re-evaluating f(lo)
            # on every iteration.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float) -> float:
    """Compute the pH of the diprotic analyte after the given additions.

    The analyte is ``INITIAL_ACID_VOL`` mL at a fixed 0.1 concentration;
    reagents "1" and "2" have concentrations ``TITRANT_CONC1`` and
    ``TITRANT_CONC2``.

    Args:
        base1_mL: Cumulative volume of base 1 added (mL).
        base2_mL: Cumulative volume of base 2 added (mL).
        acid1_mL: Cumulative volume of acid 1 added (mL).
        acid2_mL: Cumulative volume of acid 2 added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        The solved pH rounded to two decimals.
    """
    sample_mL = INITIAL_ACID_VOL
    sample_conc = 0.1
    weak_acid_moles = sample_mL / 1000.0 * sample_conc
    total_litres = (sample_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    weak_acid_conc = weak_acid_moles / total_litres
    base_conc = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / total_litres
    strong_acid_conc = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / total_litres
    solved_pH = solve_pH_diprotic(weak_acid_conc, base_conc, strong_acid_conc, pKa1, pKa2)
    return round(solved_pH, 2)

def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual for a triprotic weak acid at a trial pH.

    The k-fold deprotonated fraction is built from
    ``10**(k*pH - pKa1 - ... - pKak)`` (exponents clipped to [-100, 100]);
    the acid's anion charge is ``c_A * (alpha1 + 2*alpha2 + 3*alpha3)``.

    Args:
        pH: Trial pH.
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - anion charge - c_HCl``; zero at equilibrium.
    """
    Kw = 1e-14
    h_conc = 10 ** (-pH)
    oh_conc = Kw / h_conc
    exponents = (
        pH - pKa1,
        2 * pH - pKa1 - pKa2,
        3 * pH - pKa1 - pKa2 - pKa3,
    )
    first, second, third = (np.power(10, np.clip(e, -100, 100)) for e in exponents)
    denom = 1 + first + second + third
    mean_charge = first / denom + 2 * (second / denom) + 3 * (third / denom)
    return h_conc + c_Na - oh_conc - c_A * mean_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Find the pH in [0, 14] where ``f_triprotic`` is zero, by bisection.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, or the
        midpoint of the final bracket after 100 halvings. If the residual
        has the same sign at pH 0 and pH 14 (no root in the interval), the
        endpoint with the smaller residual magnitude is returned.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    f_hi = f_triprotic(hi, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    if f_lo * f_hi > 0:
        # Without a sign change the loop below would always drift to `hi`,
        # even when the residual is closest to zero at `lo`.
        return lo if abs(f_lo) < abs(f_hi) else hi
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # Reuse the value just computed instead of re-evaluating f(lo)
            # on every iteration.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Compute the pH of the triprotic analyte after the given additions.

    The analyte is ``INITIAL_ACID_VOL`` mL at a fixed 0.1 concentration;
    reagents "1" and "2" have concentrations ``TITRANT_CONC1`` and
    ``TITRANT_CONC2``.

    Args:
        base1_mL: Cumulative volume of base 1 added (mL).
        base2_mL: Cumulative volume of base 2 added (mL).
        acid1_mL: Cumulative volume of acid 1 added (mL).
        acid2_mL: Cumulative volume of acid 2 added (mL).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The solved pH rounded to two decimals.
    """
    sample_mL = INITIAL_ACID_VOL
    sample_conc = 0.1
    weak_acid_moles = sample_mL / 1000.0 * sample_conc
    total_litres = (sample_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    weak_acid_conc = weak_acid_moles / total_litres
    base_conc = (base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2) / 1000.0 / total_litres
    strong_acid_conc = (acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2) / 1000.0 / total_litres
    solved_pH = solve_pH_triprotic(weak_acid_conc, base_conc, strong_acid_conc, pKa1, pKa2, pKa3)
    return round(solved_pH, 2)

# Environment and reward function
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag=None, prev_overshoot_volume=None, last_action_volume=None):
    """Score one titration step and report whether the episode is over.

    The raw reward is the sum of:
      * a dense term proportional to the reduction in |pH - target|, scaled
        up while more steps remain;
      * a constant step penalty;
      * an overshoot penalty when the pH crossed the target;
      * a wrong-direction penalty when the reagent just used pushes the pH
        further from the target side it is now on;
      * a volume penalty/bonus applied only when ``prev_overshoot_flag`` is
        truthy and both volumes are provided.
    On terminal steps a terminal bonus is added and the reward is returned
    unclipped; otherwise it is clipped to the configured range.

    Args:
        previous_ph: pH before the action.
        current_ph: pH after the action.
        target_ph: Target pH.
        steps_taken: Steps taken so far, including this one.
        max_steps: Step budget of the episode.
        reagent: Name of the reagent used; matched case-insensitively for
            the substrings "base" and "acid".
        reward_config: Dict of weights; missing keys fall back to defaults.
        SUCCESS_THRESHOLD: Maximum |pH - target| that counts as success.
        prev_overshoot_flag: Whether an overshoot was recorded earlier.
        prev_overshoot_volume: Volume associated with that overshoot.
        last_action_volume: Volume added in this step.

    Returns:
        ``(reward, is_terminal)``.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("Dense lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("Step penalty", -0.005)
    overshoot_weight = reward_config.get("Overshoot weight", 0.2)
    overshoot_threshold = reward_config.get("Overshoot threshold", 0.1)

    # A sign change of (pH - target) means the target was crossed.
    crossed_target = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    if crossed_target and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        # Logistic ramp so the penalty saturates at -overshoot_weight.
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(-(overshoot_magnitude - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("Wrong dir factor", 1.0)
    wrong_dir_penalty = 0
    # The name is lower-cased, so the substrings must be lower-case too;
    # the capitalised 'Base'/'Acid' never matched and disabled this penalty.
    reagent_name = reagent.lower()
    if (current_ph > target_ph and 'base' in reagent_name) or (current_ph < target_ph and 'acid' in reagent_name):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph)

    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None and last_action_volume is not None:
        overshoot_volume_penalty = reward_config.get("Overshoot volume penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume
        overshoot_volume_bonus = reward_config.get("Overshoot volume bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    is_terminal = False
    if abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps:
        is_terminal = True
        # NOTE(review): the bonus is also granted when the episode ends by
        # exhausting the step budget without reaching the target - confirm
        # this is intended.
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("Terminal bonus", 80.0) * bonus_factor
        raw_reward += terminal_bonus

    if not is_terminal:
        reward_clip_max = reward_config.get("Reward clip max", 2.0)
        reward_clip_min = reward_config.get("Reward clip min", -2.0)
        reward = max(min(raw_reward, reward_clip_max), reward_clip_min)
    else:
        reward = raw_reward

    return reward, is_terminal

class PHSimEnv:
    """Simulated weak-acid titration whose conditions are supplied per experiment.

    Four reagents are available (two strong bases and two strong acids,
    copied from ``REAGENTS``). An action is a ``(reagent_name, volume_mL)``
    pair. ``reset`` must be called with the acid description before
    ``step`` or ``select_best_action`` is used.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1):
        """Store the analyte description and build the discrete action space.

        Args:
            initial_acid_vol: Volume of the analyte solution (mL).
            analyte_conc: Analyte concentration (mol/L).
        """
        self.initial_acid_vol = initial_acid_vol
        self.analyte_conc = analyte_conc
        # NOTE(review): n_acid is stored here, but the pH helpers called in
        # reset/step receive only titrant volumes and pKa values.
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = MIN_ADDITION_VOLUME
        # 1000 volumes: 1x .. 1000x the minimum addition volume.
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1001)]
        self.action_space = [(reagent, volume) for reagent in self.reagents for volume in self.addition_volumes]
        self.reward_config = {
            "Dense lambda": 1.0,
            "Step penalty": -0.005,
            "Terminal bonus": 80,
            "Overshoot weight": 0.2,
            "Overshoot threshold": 0.1,
            "Wrong dir factor": 60.0,
            "Reward clip max": 2.0,
            "Reward clip min": -2.0,
            "Overshoot volume penalty": 0.1,
            "Overshoot volume bonus": 0.1
        }
        # Do not call reset during initialization, wait for test_model to provide CSV parameters
        self.acid_type = None
        self.acid_params = None
        self.target_ph = None
        self.current_ph = None
        # Defined up front so that the None-guards in _get_state and
        # select_best_action work before the first reset instead of raising
        # AttributeError.
        self.previous_ph = None
        self.last_measured_ph = None
        self.prev_measured_ph = None
        self.last_action_volume = 0.0

    def reset(self, acid_type=None, acid_params=None, target_ph=None, initial_ph=None):
        """Start a new experiment and return the initial state vector.

        Args:
            acid_type: 'Monoprotic', 'Diprotic' or 'Triprotic'.
            acid_params: A single pKa for a monoprotic acid, otherwise a
                sequence of two or three pKa values.
            target_ph: Target pH; 7.0 when omitted.
            initial_ph: Starting pH. When omitted it is computed from the
                acid parameters with no reagent added.

        Returns:
            The state vector produced by ``_get_state``.

        Raises:
            ValueError: If ``acid_type`` or ``acid_params`` is missing, or
                ``acid_type`` is not one of the three known values.
        """
        if acid_type is None or acid_params is None:
            raise ValueError("acid_type and acid_params must be provided from CSV data")

        self.acid_type = acid_type
        self.acid_params = acid_params
        self.target_ph = float(target_ph) if target_ph is not None else 7.0
        self.base1_added_mL = 0.0
        self.base2_added_mL = 0.0
        self.acid1_added_mL = 0.0
        self.acid2_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.last_added_moles = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        if self.acid_type == 'Monoprotic':
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_monoprotic(0.0, 0.0, 0.0, 0.0, float(self.acid_params))
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_diprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2)
        elif self.acid_type == 'Triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_triprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2, pKa3)
        else:
            raise ValueError(f"Unknown acid_type: {self.acid_type}")

        self.previous_ph = self.current_ph
        self.last_measured_ph = self.current_ph
        self.prev_measured_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return ``[pH, target pH, pH change, signed error, last volume]`` as float32.

        The pH change and the error are rounded to two decimals and default
        to 0.0 while the values they depend on are still ``None``; a missing
        pH is reported as 0.0 and a missing target as 7.0.
        """
        pH_delta = round(self.current_ph - self.previous_ph, 2) if self.current_ph is not None and self.previous_ph is not None else 0.0
        error = round(self.current_ph - self.target_ph, 2) if self.current_ph is not None and self.target_ph is not None else 0.0
        return np.array([self.current_ph or 0.0, self.target_ph or 7.0, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Decide whether the last addition overshot and propose a volume cap.

        An overshoot is reported when the pH crossed the target or the
        absolute error grew. The proposed cap is half of the volume that was
        just added (recovered from moles and concentration), but never less
        than ``min_addition``.

        Returns:
            ``(overshoot, new_threshold)``; ``new_threshold`` is ``None``
            when no overshoot was detected.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def step(self, action):
        """Apply one ``(reagent, volume)`` action and advance the simulation.

        Args:
            action: ``(reagent_name, volume_mL)``; the reagent must be a key
                of ``self.reagents``.

        Returns:
            ``(state, reward, done, info)`` where ``info`` is
            ``{'Reagent': reagent_name}``.
        """
        reagent, volume = action
        volume = float(volume)
        self.last_action_volume = volume
        self.last_added_moles = self.reagents[reagent] * (volume / 1000.0)
        self.steps += 1
        self.previous_ph = self.current_ph
        self.prev_measured_ph = self.last_measured_ph
        if reagent == 'Strong base 1':
            self.base1_added_mL += volume
        elif reagent == 'Strong base 2':
            self.base2_added_mL += volume
        elif reagent == 'Strong acid 1':
            self.acid1_added_mL += volume
        elif reagent == 'Strong acid 2':
            self.acid2_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base1_added_mL + self.base2_added_mL + self.acid1_added_mL + self.acid2_added_mL
        if self.acid_type == 'Monoprotic':
            self.current_ph = calculate_pH_monoprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, float(self.acid_params)
            )
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2
            )
        elif self.acid_type == 'Triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2, pKa3
            )
        self.last_measured_ph = self.current_ph

        # Even the smallest possible addition crossing the target by more
        # than 0.1 pH counts as an oscillation; after three of them the
        # secondary reagents are used.
        if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
            if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                self.oscillation_count += 1
                logging.info(f"PH oscillation at minimum dripping volume was detected, cumulative number of times:{self.oscillation_count}")
                if self.oscillation_count >= 3:
                    self.use_secondary_reagents = True
                    logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")

        overshoot_flag, new_threshold = self.detect_overshoot(
            self.previous_ph, self.current_ph, self.target_ph, reagent,
            self.last_added_moles, self.reagents[reagent], self.min_addition_volume
        )
        if overshoot_flag:
            self.overshoot_occurred = True
            self.overshoot_reagent = reagent
            # The volume cap only ever tightens.
            if new_threshold is not None:
                if self.overshoot_threshold is None or new_threshold < self.overshoot_threshold:
                    self.overshoot_threshold = new_threshold

        state = self._get_state()
        reward, done = calculate_reward(
            self.previous_ph, self.current_ph, self.target_ph, self.steps, MAX_STEPS, reagent,
            self.reward_config, SUCCESS_THRESHOLD, self.overshoot_occurred, self.overshoot_threshold, volume
        )
        return state, reward, done, {'Reagent': reagent}

    def select_best_action(self, state_tensor, policy_model):
        """Pick the allowed ``(reagent, volume)`` with the highest policy logit.

        The reagent kind is the opposite of the one that caused a pending
        overshoot, otherwise base when the pH is below target and acid when
        it is not. The strength is "2" once ``use_secondary_reagents`` is
        set, else "1". Volumes above ``overshoot_threshold`` are excluded
        when at least one volume remains. With primary reagents a pending
        overshoot is consumed (flag and reagent are cleared) by this call.

        Args:
            state_tensor: Batched state passed to ``policy_model``.
            policy_model: Callable returning logits whose row 0 has one
                entry per element of ``self.addition_volumes``.

        Returns:
            The chosen ``(reagent_name, volume_mL)`` tuple.

        Raises:
            ValueError: If no action matches the required reagent.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph

        # Reagent names are matched lower-cased, so the patterns must be
        # lower-case; capitalised patterns never matched and left the
        # candidate list empty.
        if self.overshoot_occurred:
            wanted_kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            wanted_kind = 'base' if current_for_direction < self.target_ph else 'acid'
        strength = '2' if self.use_secondary_reagents else '1'
        pattern = f"{wanted_kind} {strength}"
        allowed_reagent = [r for r in self.reagents if pattern in r.lower()]

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped
        if not candidate_actions:
            raise ValueError(f"No candidate action for reagent pattern '{pattern}'")

        logging.info(f"Candidate actions: {candidate_actions[:5]}... (common{len(candidate_actions)}indivual)")

        # One pass to map volume -> first index, instead of a linear
        # list.index() search per candidate.
        volume_to_index = {}
        for idx, vol in enumerate(self.addition_volumes):
            volume_to_index.setdefault(vol, idx)

        with torch.no_grad():
            logits = policy_model(state_tensor)
            candidate_indices = [volume_to_index[a[1]] for a in candidate_actions]
            candidate_logits = logits[0, candidate_indices]
            best_action = candidate_actions[candidate_logits.argmax().item()]

        logging.info(f"Select action: {best_action}")
        return best_action

class DiscreteVolumeRegressor(nn.Module):
    """Policy network that scores a fixed grid of addition volumes.

    The grid runs from ``min_volume`` to ``max_volume`` in increments of
    ``step`` (values rounded to two decimals); with the defaults this is
    1000 volumes from 0.01 to 10.00. The network maps a state vector to
    one logit per grid volume.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        """Build the volume grid and the three-layer MLP.

        Args:
            input_dim: Length of the state vector fed to the network.
            min_volume: Smallest volume in the grid.
            max_volume: Upper bound of the grid.
            step: Spacing between consecutive grid volumes.
        """
        super().__init__()
        # Float division such as 0.6 / 0.1 can land just below the exact
        # integer; the small tolerance keeps truncation from dropping the
        # last grid point while leaving genuinely fractional counts alone.
        num_intervals = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2) for i in range(num_intervals + 1)]
        self.num_actions = len(self.discrete_volumes)
        # Layer layout (and therefore the state_dict keys) must stay as is
        # so that previously saved weights keep loading.
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )

    def forward(self, x):
        """Return one logit per grid volume for each row of ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the categorical distribution over the grid.

        Args:
            x: State tensor; ``.item()`` is called on the sampled index, so
                exactly one state is expected.

        Returns:
            ``(volume, log_prob)`` where ``volume`` is a ``(1, 1)`` float32
            tensor and ``log_prob`` is the log-probability of the sample.
        """
        logits = self.forward(x)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        volume = self.discrete_volumes[action_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32), log_prob

    def predict_volume(self, x):
        """Return the grid volume with the highest logit as a ``(1, 1)`` tensor.

        Like ``sample_action``, this expects a single state because it calls
        ``.item()`` on the arg-max index.
        """
        logits = self.forward(x)
        _, predicted_indices = torch.max(logits, dim=1)
        predicted_volume = self.discrete_volumes[predicted_indices.item()]
        return torch.tensor([[predicted_volume]], dtype=torch.float32)

def load_experiment_conditions(csv_file):
    """Read experiment conditions from a CSV file with a header row.

    The file must provide the columns 'Acid type', 'Acid params',
    'Initial p h' and 'Target p h'. 'Acid params' is parsed as a Python
    literal (a number or a tuple/list of numbers); the two pH columns are
    converted to float.

    Args:
        csv_file: Path of the CSV file.

    Returns:
        A list of dicts with the keys 'Acid type', 'Acid params',
        'Initial ph' and 'Target ph', one per data row.
    """
    experiments = []
    # File modes are lower-case ('R' is rejected by open()); newline='' is
    # the setting the csv module asks for on files it reads.
    with open(csv_file, 'r', newline='', encoding='utf-8') as f:
        for row in csv.DictReader(f):
            experiments.append({
                'Acid type': row['Acid type'],
                'Acid params': ast.literal_eval(row['Acid params']),
                'Initial ph': float(row['Initial p h']),
                'Target ph': float(row['Target p h'])
            })
    return experiments

def test_model(policy_model, csv_file="Experiment summary.csv", output_file="Test output2 modified.txt", summary_file="Experiment summary rl.csv"):
    """Replay every experiment from ``csv_file`` with the policy and log the outcome.

    Each experiment resets the environment with the CSV conditions and
    repeatedly applies ``env.select_best_action`` until ``env.step`` reports
    termination. A full trace is printed and written to ``output_file``;
    one summary row per experiment is written to ``summary_file``.

    NOTE: the environment is taken from a module-level variable named
    ``env``; it is not a parameter of this function.

    Args:
        policy_model: Model passed to ``env.select_best_action``.
        csv_file: CSV file with the experiment conditions.
        output_file: Text file receiving the detailed log.
        summary_file: CSV file receiving the per-experiment summary.
    """
    experiments = load_experiment_conditions(csv_file)
    num_experiments = len(experiments)
    success_count = 0
    total_steps_success = []

    # File modes are lower-case; 'W' is rejected by open().
    with open(output_file, 'w', encoding='utf-8') as f, open(summary_file, 'w', newline='', encoding='utf-8') as summary_f:
        def log_and_print(message):
            """Echo ``message`` to stdout and append it to the log file."""
            print(message)
            f.write(message + '\n')

        csv_writer = csv.writer(summary_f)
        csv_writer.writerow(['Experiment', 'Acid type', 'Acid params', 'Initial p h', 'Target p h', 'Final p h', 'Steps taken', 'Success'])

        for i, exp in enumerate(experiments, 1):
            log_and_print(f"\n==== Experiment {i} Start ====")
            acid_type = exp['Acid type']
            acid_params = exp['Acid params']
            initial_ph = exp['Initial ph']
            target_ph = exp['Target ph']
            state = env.reset(acid_type=acid_type, acid_params=acid_params, target_ph=target_ph, initial_ph=initial_ph)
            log_and_print(f"Initial state: {state}")
            log_and_print(f"Acid type: {acid_type}, parameters: {acid_params}, target pH: {target_ph}")
            done = False
            steps = 0
            experiment_trace = []
            while not done:
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                log_and_print(f"Current state vector: {state}")
                action = env.select_best_action(state_tensor, policy_model)
                state, reward, done, info = env.step(action)
                experiment_trace.append((state, action[1], info.get('Reagent', '')))
                steps += 1
            log_and_print("Status Action Reagent pair:")
            for j, (s, a, reagent) in enumerate(experiment_trace):
                log_and_print(f"  Step {j+1}: State = {s}, Action = {a:.4f}, Reagent = {reagent}")
            log_and_print(f"The experiment is over, the number of shared steps is: {steps}, final pH: {env.current_ph:.2f}")
            success = abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD
            if success:
                success_count += 1
                total_steps_success.append(steps)

            # Sequences are written verbatim; a single pKa gets two decimals.
            acid_params_str = f"{acid_params}" if isinstance(acid_params, (list, tuple)) else f"{acid_params:.2f}"
            csv_writer.writerow([i, acid_type, acid_params_str, f"{initial_ph:.2f}", f"{target_ph:.2f}",
                                f"{env.current_ph:.2f}", steps, 'Yes' if success else 'No'])

        # An empty CSV must not crash the summary with a division by zero.
        success_rate = success_count / num_experiments * 100 if num_experiments else 0.0
        avg_steps_success = np.mean(total_steps_success) if total_steps_success else 0
        summary_stats = f"\nTest completed: success rate = {success_rate:.2f}%, average number of steps for successful experiments = {avg_steps_success:.2f}"
        log_and_print(summary_stats)

# Entry point. The guard must compare against "__main__" (the value of
# __name__ when the cell/script is executed); "Main" never matches, so the
# body was dead code.
if __name__ == "__main__":
    input_dim = 5
    learning_rate = 1e-3
    gamma = 0.99

    # test_model reads this module-level `env`.
    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1)

    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)

    try:
        # Device strings are lower-case; 'Cpu' raises, and the broad except
        # below would then silently fall back to random weights.
        policy_model.load_state_dict(torch.load("Volume regressor best big discrete new1 trained 1 test save the best.pth", map_location=torch.device('cpu')))
        print("Loading the discrete pre-trained model was successful.")
    except Exception as e:
        # Deliberate best-effort: continue with a randomly initialized model.
        print("Failed to load discrete pretrained model, using randomly initialized model.", e)

    policy_model.eval()

    test_model(policy_model, csv_file="Experiment summary.csv", output_file="Save the optimal neural network for the same experiment using only concentrated acids and bases.txt", summary_file="Save the optimal neural network for the same experiment using only concentrated acids and bases.csv")

In [None]:
# Evaluate the model before reinforcement

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import math
import random
import logging
from scipy.optimize import fsolve
import csv
import ast

# Configure root logging: INFO level, with timestamp, level name and message.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Fix the seeds of the Python, NumPy and PyTorch random generators.
seed = 555
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Global constants
TITRANT_CONC1 = 0.1         # Main titrant concentration (0.1 M)
TITRANT_CONC2 = 0.1        # Secondary titrant concentration; NOTE(review): set to 0.1 M here, equal to the main titrant (the previous cell uses 0.01 M) - confirm this is intended
MAX_STEPS = 50              # Maximum number of steps
INITIAL_ACID_VOL = 11.0     # Initial volume of weak acid to be titrated (mL)
SUCCESS_THRESHOLD = 0.1     # pH error threshold
MIN_ADDITION_VOLUME = 0.01  # Minimum drop volume (mL)

# Reagent name -> titrant concentration. Bases and acids share the same
# primary ("1") and secondary ("2") concentrations.
REAGENTS = {
    'Strong base 1': TITRANT_CONC1,
    'Strong base 2': TITRANT_CONC2,
    'Strong acid 1': TITRANT_CONC1,
    'Strong acid 2': TITRANT_CONC2,
}

# pH calculation function
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Residual of the monoprotic charge balance at a trial pH.

    Computes ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl`` with
    ``[OH-] = 1e-14 / [H+]`` and ``alpha`` the dissociated fraction
    ``10**(pH - pKa) / (1 + 10**(pH - pKa))``.

    Args:
        pH: Trial pH.
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The residual; it is zero at the equilibrium pH.
    """
    Kw = 1e-14
    hydronium = 10 ** (-pH)
    hydroxide = Kw / hydronium
    odds = 10 ** (pH - pKa)
    fraction_dissociated = odds / (1 + odds)
    return hydronium + c_Na - hydroxide - c_A * fraction_dissociated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Find the pH in [0, 14] where ``f_monoprotic`` is zero, by bisection.

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration contributed by added strong base.
        c_HCl: Concentration contributed by added strong acid.
        pKa: Dissociation constant of the weak acid.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, or the
        midpoint of the final bracket after 100 halvings. If the residual
        has the same sign at pH 0 and pH 14 (no root in the interval), the
        endpoint with the smaller residual magnitude is returned.
    """
    lo, hi = 0.0, 14.0
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    f_hi = f_monoprotic(hi, c_A, c_Na, c_HCl, pKa)
    if f_lo * f_hi > 0:
        # Without a sign change the loop below would always drift to `hi`,
        # even when the residual is closest to zero at `lo`.
        return lo if abs(f_lo) < abs(f_hi) else hi
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            # Reuse the value just computed instead of re-evaluating f(lo)
            # on every iteration.
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa: float) -> float:
    """Return the pH (rounded to 2 decimals) of the monoprotic analyte after titrant additions.

    The analyte is ``INITIAL_ACID_VOL`` mL of a 0.1 mol/L weak acid. Each
    argument is the cumulative volume in mL of one titrant: ``base1``/``acid1``
    use ``TITRANT_CONC1`` and ``base2``/``acid2`` use ``TITRANT_CONC2``.

    Args:
        base1_mL: Volume of strong base 1 added so far.
        base2_mL: Volume of strong base 2 added so far.
        acid1_mL: Volume of strong acid 1 added so far.
        acid2_mL: Volume of strong acid 2 added so far.
        pKa: Dissociation constant of the weak acid.

    Returns:
        Equilibrium pH from ``solve_pH_monoprotic_balance``, rounded to 0.01.
    """
    acid_conc = 0.1
    analyte_mL = INITIAL_ACID_VOL
    analyte_moles = analyte_mL / 1000.0 * acid_conc
    # Total liquid volume in litres after every addition
    V_total = (analyte_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    base_mmol = base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2
    acid_mmol = acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2
    c_A = analyte_moles / V_total
    c_Na = base_mmol / 1000.0 / V_total
    c_HCl = acid_mmol / 1000.0 / V_total
    pH = solve_pH_monoprotic_balance(c_A, c_Na, c_HCl, pKa)
    return round(pH, 2)

def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Return the charge-balance residual of a diprotic weak acid at a trial pH.

    Args:
        pH: Trial pH value.
        c_A: Total concentration of the weak acid.
        c_Na: Concentration of added strong-base cation.
        c_HCl: Concentration of added strong-acid anion.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - (anion charge of the acid) - c_HCl``.
    """
    Kw = 1e-14
    H = 10 ** (-pH)
    OH = Kw / H
    # Ratios [HA-]/[H2A] and [A2-]/[H2A]; exponents are clipped so 10**x cannot overflow
    singly_ratio = np.power(10, np.clip(pH - pKa1, -100, 100))
    doubly_ratio = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    D = 1 + singly_ratio + doubly_ratio
    frac_singly = singly_ratio / D
    frac_doubly = doubly_ratio / D
    # Each species contributes its fraction times its charge number
    bound_negative_charge = c_A * (frac_singly + 2 * frac_doubly)
    return H + c_Na - OH - bound_negative_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Solve the diprotic charge balance for pH by bisection on [0, 14].

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration of added strong-base cation.
        c_HCl: Concentration of added strong-acid anion.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, or the
        midpoint of the final bracket after 100 iterations. If the residual
        never changes sign inside [0, 14] the result saturates towards 14.
    """
    lo, hi = 0.0, 14.0
    # Cache the residual at `lo`: it only changes when `lo` moves to `mid`,
    # where the residual has just been computed.
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change between lo and mid: the root is in the lower half.
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float) -> float:
    """Return the pH (rounded to 2 decimals) of the diprotic analyte after titrant additions.

    The analyte is ``INITIAL_ACID_VOL`` mL of a 0.1 mol/L weak acid. Each
    volume argument is the cumulative amount in mL of one titrant:
    ``base1``/``acid1`` use ``TITRANT_CONC1``, ``base2``/``acid2`` use
    ``TITRANT_CONC2``.

    Args:
        base1_mL: Volume of strong base 1 added so far.
        base2_mL: Volume of strong base 2 added so far.
        acid1_mL: Volume of strong acid 1 added so far.
        acid2_mL: Volume of strong acid 2 added so far.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.

    Returns:
        Equilibrium pH from ``solve_pH_diprotic``, rounded to 0.01.
    """
    acid_conc = 0.1
    analyte_mL = INITIAL_ACID_VOL
    analyte_moles = analyte_mL / 1000.0 * acid_conc
    # Total liquid volume in litres after every addition
    V_total = (analyte_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    base_mmol = base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2
    acid_mmol = acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2
    c_A = analyte_moles / V_total
    c_Na = base_mmol / 1000.0 / V_total
    c_HCl = acid_mmol / 1000.0 / V_total
    pH = solve_pH_diprotic(c_A, c_Na, c_HCl, pKa1, pKa2)
    return round(pH, 2)

def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Return the charge-balance residual of a triprotic weak acid at a trial pH.

    Args:
        pH: Trial pH value.
        c_A: Total concentration of the weak acid.
        c_Na: Concentration of added strong-base cation.
        c_HCl: Concentration of added strong-acid anion.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        ``[H+] + c_Na - [OH-] - (anion charge of the acid) - c_HCl``.
    """
    Kw = 1e-14
    H = 10 ** (-pH)
    OH = Kw / H
    # Ratios of the singly/doubly/triply deprotonated species to the fully
    # protonated one; exponents are clipped so 10**x cannot overflow.
    ratio_1 = np.power(10, np.clip(pH - pKa1, -100, 100))
    ratio_2 = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    ratio_3 = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))
    D = 1 + ratio_1 + ratio_2 + ratio_3
    frac_1 = ratio_1 / D
    frac_2 = ratio_2 / D
    frac_3 = ratio_3 / D
    # Each species contributes its fraction times its charge number
    bound_negative_charge = c_A * (frac_1 + 2 * frac_2 + 3 * frac_3)
    return H + c_Na - OH - bound_negative_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Solve the triprotic charge balance for pH by bisection on [0, 14].

    Args:
        c_A: Total weak-acid concentration.
        c_Na: Concentration of added strong-base cation.
        c_HCl: Concentration of added strong-acid anion.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The first midpoint whose residual magnitude is below 1e-10, or the
        midpoint of the final bracket after 100 iterations. If the residual
        never changes sign inside [0, 14] the result saturates towards 14.
    """
    lo, hi = 0.0, 14.0
    # Cache the residual at `lo`: it only changes when `lo` moves to `mid`,
    # where the residual has just been computed.
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change between lo and mid: the root is in the lower half.
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base1_mL: float, base2_mL: float, acid1_mL: float, acid2_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Return the pH (rounded to 2 decimals) of the triprotic analyte after titrant additions.

    The analyte is ``INITIAL_ACID_VOL`` mL of a 0.1 mol/L weak acid. Each
    volume argument is the cumulative amount in mL of one titrant:
    ``base1``/``acid1`` use ``TITRANT_CONC1``, ``base2``/``acid2`` use
    ``TITRANT_CONC2``.

    Args:
        base1_mL: Volume of strong base 1 added so far.
        base2_mL: Volume of strong base 2 added so far.
        acid1_mL: Volume of strong acid 1 added so far.
        acid2_mL: Volume of strong acid 2 added so far.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        Equilibrium pH from ``solve_pH_triprotic``, rounded to 0.01.
    """
    acid_conc = 0.1
    analyte_mL = INITIAL_ACID_VOL
    analyte_moles = analyte_mL / 1000.0 * acid_conc
    # Total liquid volume in litres after every addition
    V_total = (analyte_mL + base1_mL + base2_mL + acid1_mL + acid2_mL) / 1000.0
    base_mmol = base1_mL * TITRANT_CONC1 + base2_mL * TITRANT_CONC2
    acid_mmol = acid1_mL * TITRANT_CONC1 + acid2_mL * TITRANT_CONC2
    c_A = analyte_moles / V_total
    c_Na = base_mmol / 1000.0 / V_total
    c_HCl = acid_mmol / 1000.0 / V_total
    pH = solve_pH_triprotic(c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    return round(pH, 2)

# Environment and reward function
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag=None, prev_overshoot_volume=None, last_action_volume=None):
    """Compute the shaped reward for one titration step.

    The reward is the sum of a dense progress term, a per-step penalty, an
    overshoot penalty, a wrong-direction penalty and (after an overshoot) a
    volume penalty/bonus. Non-terminal rewards are clipped; terminal rewards
    receive a bonus and are returned unclipped.

    Args:
        previous_ph: pH before the action.
        current_ph: pH after the action.
        target_ph: pH the episode is trying to reach.
        steps_taken: Number of steps taken so far, including this one.
        max_steps: Step budget of the episode.
        reagent: Name of the reagent that was added; matched
            case-insensitively for the substrings "base" and "acid".
        reward_config: Mapping of tunable weights; missing keys fall back to
            the defaults used below.
        SUCCESS_THRESHOLD: Absolute pH error below which the episode ends.
        prev_overshoot_flag: Truthy when an overshoot has been recorded.
        prev_overshoot_volume: Volume associated with that overshoot, or None.
        last_action_volume: Volume of the action just taken, or None.

    Returns:
        Tuple ``(reward, is_terminal)``.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("Dense lambda", 1.0)
    # Progress towards the target is worth more early in the episode.
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio)
    step_penalty = reward_config.get("Step penalty", -0.005)
    overshoot_weight = reward_config.get("Overshoot weight", 0.2)
    overshoot_threshold = reward_config.get("Overshoot threshold", 0.1)

    # Overshoot: the pH crossed the target and at least one side is far from it.
    crossed_target = (previous_ph - target_ph) * (current_ph - target_ph) < 0
    if crossed_target and max(previous_error, current_error) > overshoot_threshold:
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(-(current_error - overshoot_threshold))))
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("Wrong dir factor", 1.0)
    wrong_dir_penalty = 0
    # The name is lower-cased, so it must be compared against lower-case
    # substrings; the previous capitalised 'Base'/'Acid' could never match and
    # the penalty was never applied.
    reagent_name = reagent.lower()
    added_base = 'base' in reagent_name
    added_acid = 'acid' in reagent_name
    if (current_ph > target_ph and added_base) or (current_ph < target_ph and added_acid):
        wrong_dir_penalty = -wrong_dir_factor * current_error

    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None and last_action_volume is not None:
        overshoot_volume_penalty = reward_config.get("Overshoot volume penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume
        overshoot_volume_bonus = reward_config.get("Overshoot volume bonus", 0.1)
        # Reward dosing less than the volume recorded for the overshoot.
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume)

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    is_terminal = False
    if current_error < SUCCESS_THRESHOLD or steps_taken >= max_steps:
        is_terminal = True
        # NOTE(review): the terminal bonus is also granted when the episode ends
        # only because the step budget ran out - confirm this is intended.
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("Terminal bonus", 80.0) * bonus_factor
        raw_reward += terminal_bonus

    if not is_terminal:
        reward_clip_max = reward_config.get("Reward clip max", 2.0)
        reward_clip_min = reward_config.get("Reward clip min", -2.0)
        reward = max(min(raw_reward, reward_clip_max), reward_clip_min)
    else:
        reward = raw_reward

    return reward, is_terminal

class PHSimEnv:
    """Simulated acid/base titration environment driven by a volume-scoring policy.

    The environment tracks the cumulative volume of four titrants (two strong
    bases, two strong acids), recomputes the pH after every addition with the
    module-level ``calculate_pH_*`` functions and scores each step with
    ``calculate_reward``. ``reset`` must be called before ``step`` or
    ``select_best_action``.

    NOTE(review): the pH is computed by the module-level functions, which use
    ``INITIAL_ACID_VOL`` and a fixed 0.1 analyte concentration rather than the
    ``initial_acid_vol`` / ``analyte_conc`` passed to the constructor.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1):
        """Build the action space and reward configuration.

        Args:
            initial_acid_vol: Starting analyte volume in mL.
            analyte_conc: Analyte concentration used to compute ``n_acid``.
        """
        self.initial_acid_vol = initial_acid_vol
        self.analyte_conc = analyte_conc
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = MIN_ADDITION_VOLUME
        # 1000 candidate volumes: 1x .. 1000x the minimum addition volume
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1001)]
        # Every (reagent, volume) pair is a possible action
        self.action_space = [(reagent, volume) for reagent in self.reagents.keys() for volume in self.addition_volumes]
        self.reward_config = {
            "Dense lambda": 1.0,
            "Step penalty": -0.005,
            "Terminal bonus": 80,
            "Overshoot weight": 0.2,
            "Overshoot threshold": 0.1,
            "Wrong dir factor": 60.0,
            "Reward clip max": 2.0,
            "Reward clip min": -2.0,
            "Overshoot volume penalty": 0.1,
            "Overshoot volume bonus": 0.1
        }
        # Do not call reset during initialization, wait for test_model to provide CSV parameters
        self.acid_type = None
        self.acid_params = None
        self.target_ph = None
        self.current_ph = None

    def reset(self, acid_type=None, acid_params=None, target_ph=None, initial_ph=None):
        """Start a new episode for the given acid and target.

        Args:
            acid_type: One of 'Monoprotic', 'Diprotic', 'Triprotic'.
            acid_params: A single pKa for a monoprotic acid, otherwise a
                sequence of two or three pKa values.
            target_ph: Target pH; defaults to 7.0 when None.
            initial_ph: Starting pH; when None it is computed from the acid
                parameters with no titrant added.

        Returns:
            The initial state vector (see ``_get_state``).

        Raises:
            ValueError: If ``acid_type`` or ``acid_params`` is missing, or
                ``acid_type`` is not recognised.
        """
        if acid_type is None or acid_params is None:
            raise ValueError("acid_type and acid_params must be provided from CSV data")

        self.acid_type = acid_type
        self.acid_params = acid_params
        self.target_ph = float(target_ph) if target_ph is not None else 7.0
        self.base1_added_mL = 0.0
        self.base2_added_mL = 0.0
        self.acid1_added_mL = 0.0
        self.acid2_added_mL = 0.0
        self.total_volume = self.initial_acid_vol
        self.last_action_volume = 0.0
        self.last_added_moles = 0.0
        self.steps = 0
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None

        if self.acid_type == 'Monoprotic':
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_monoprotic(0.0, 0.0, 0.0, 0.0, float(self.acid_params))
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_diprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2)
        elif self.acid_type == 'Triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = initial_ph if initial_ph is not None else calculate_pH_triprotic(0.0, 0.0, 0.0, 0.0, pKa1, pKa2, pKa3)
        else:
            raise ValueError(f"Unknown acid_type: {self.acid_type}")

        self.previous_ph = self.current_ph
        self.last_measured_ph = self.current_ph
        self.prev_measured_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Return ``[current pH, target pH, pH change, signed error, last volume]`` as float32."""
        pH_delta = round(self.current_ph - self.previous_ph, 2) if self.current_ph is not None and self.previous_ph is not None else 0.0
        error = round(self.current_ph - self.target_ph, 2) if self.current_ph is not None and self.target_ph is not None else 0.0
        return np.array([self.current_ph or 0.0, self.target_ph or 7.0, pH_delta, error, self.last_action_volume], dtype=np.float32)

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Flag an overshoot and propose a smaller volume cap.

        An overshoot is reported when the pH crossed the target or the
        absolute error grew. ``reagent`` is accepted but not used.

        Returns:
            Tuple ``(overshoot, new_threshold)``; ``new_threshold`` is half of
            the volume just added (never below ``min_addition``) or None when
            no overshoot was detected.
        """
        overshoot = False
        new_threshold = None
        sign_change = (prev_ph - target_ph) * (current_ph - target_ph) < 0
        error_increased = abs(current_ph - target_ph) > abs(prev_ph - target_ph)
        if sign_change or error_increased:
            overshoot = True
            # Convert the added moles back to the mL that were dosed
            overshoot_volume = last_added_moles * 1000.0 / reagent_conc
            new_threshold = max(overshoot_volume / 2, min_addition)
        return overshoot, new_threshold

    def step(self, action):
        """Apply one ``(reagent, volume)`` action and advance the simulation.

        Args:
            action: Tuple of reagent name (a key of ``self.reagents``) and
                volume in mL.

        Returns:
            Tuple ``(state, reward, done, info)`` where ``info`` is
            ``{'Reagent': reagent}``.
        """
        reagent, volume = action
        volume = float(volume)
        self.last_action_volume = volume
        self.last_added_moles = self.reagents[reagent] * (volume / 1000.0)
        self.steps += 1
        self.previous_ph = self.current_ph
        self.prev_measured_ph = self.last_measured_ph
        if reagent == 'Strong base 1':
            self.base1_added_mL += volume
        elif reagent == 'Strong base 2':
            self.base2_added_mL += volume
        elif reagent == 'Strong acid 1':
            self.acid1_added_mL += volume
        elif reagent == 'Strong acid 2':
            self.acid2_added_mL += volume
        self.total_volume = self.initial_acid_vol + self.base1_added_mL + self.base2_added_mL + self.acid1_added_mL + self.acid2_added_mL
        if self.acid_type == 'Monoprotic':
            self.current_ph = calculate_pH_monoprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, float(self.acid_params)
            )
        elif self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            self.current_ph = calculate_pH_diprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2
            )
        elif self.acid_type == 'Triprotic':
            pKa1, pKa2, pKa3 = self.acid_params
            self.current_ph = calculate_pH_triprotic(
                self.base1_added_mL, self.base2_added_mL, self.acid1_added_mL, self.acid2_added_mL, pKa1, pKa2, pKa3
            )
        self.last_measured_ph = self.current_ph

        # Oscillation: even the smallest possible dose flips the pH across the
        # target by more than 0.1; after three such events use the secondary reagents.
        if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
            if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                self.oscillation_count += 1
                logging.info(f"PH oscillation at minimum dripping volume was detected, cumulative number of times:{self.oscillation_count}")
                if self.oscillation_count >= 3:
                    self.use_secondary_reagents = True
                    logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")

        overshoot_flag, new_threshold = self.detect_overshoot(
            self.previous_ph, self.current_ph, self.target_ph, reagent,
            self.last_added_moles, self.reagents[reagent], self.min_addition_volume
        )
        if overshoot_flag:
            self.overshoot_occurred = True
            self.overshoot_reagent = reagent
            # The volume cap only ever tightens during an episode
            if new_threshold is not None:
                if self.overshoot_threshold is None or new_threshold < self.overshoot_threshold:
                    self.overshoot_threshold = new_threshold

        state = self._get_state()
        reward, done = calculate_reward(
            self.previous_ph, self.current_ph, self.target_ph, self.steps, MAX_STEPS, reagent,
            self.reward_config, SUCCESS_THRESHOLD, self.overshoot_occurred, self.overshoot_threshold, volume
        )
        return state, reward, done, {'Reagent': reagent}

    def select_best_action(self, state_tensor, policy_model):
        """Pick the admissible action whose volume the policy scores highest.

        The reagent is chosen by rule: after an overshoot the opposite kind of
        reagent is used, otherwise base is added below the target and acid
        above it. The reagent tier ("1" or "2") follows
        ``self.use_secondary_reagents``. The policy only ranks the volumes.

        Args:
            state_tensor: Input passed unchanged to ``policy_model``.
            policy_model: Callable returning logits indexed as
                ``logits[0, i]`` for ``self.addition_volumes[i]``.

        Returns:
            The selected ``(reagent, volume)`` tuple.

        Raises:
            ValueError: If no reagent name matches the required kind and tier.
        """
        def filter_by_global_threshold(candidates):
            # Keep only volumes under the overshoot cap, unless that leaves nothing
            if self.overshoot_threshold is not None:
                filtered = [a for a in candidates if a[1] <= self.overshoot_threshold]
                if filtered:
                    return filtered
            return candidates

        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
        tier = '2' if self.use_secondary_reagents else '1'
        if self.overshoot_occurred:
            # Counteract the reagent that caused the overshoot.
            kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            # The overshoot flag is consumed only while on the primary reagents.
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            kind = 'base' if current_for_direction < self.target_ph else 'acid'

        # Names are lower-cased before matching, so the pattern must be lower
        # case too; the previous capitalised patterns ('Base 1', 'Acid 2', ...)
        # never matched and left the candidate list empty.
        wanted = f"{kind} {tier}"
        allowed_reagent = [r for r in self.reagents.keys() if wanted in r.lower()]
        if not allowed_reagent:
            raise ValueError(f"No reagent name contains '{wanted}'; available reagents: {list(self.reagents.keys())}")

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        candidate_actions = filter_by_global_threshold(candidate_actions)

        logging.info(f"Candidate actions: {candidate_actions[:5]}... (common{len(candidate_actions)}indivual)")

        # Volume -> first position in addition_volumes, built once instead of
        # calling list.index() for every candidate.
        volume_to_index = {}
        for position, candidate_volume in enumerate(self.addition_volumes):
            volume_to_index.setdefault(candidate_volume, position)

        with torch.no_grad():
            logits = policy_model(state_tensor)
            candidate_indices = [volume_to_index[a[1]] for a in candidate_actions]
            candidate_logits = logits[0, candidate_indices]
            best_action = candidate_actions[candidate_logits.argmax().item()]

        logging.info(f"Select action: {best_action}")
        return best_action

class DiscreteVolumeRegressor(nn.Module):
    """MLP that scores a fixed grid of addition volumes.

    The grid runs from ``min_volume`` in increments of ``step``; the network
    maps a state vector to one logit per grid point.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        """Create the volume grid and the three-layer network.

        Args:
            input_dim: Size of the state vector.
            min_volume: First volume on the grid.
            max_volume: Upper end used to size the grid.
            step: Spacing between consecutive grid volumes.
        """
        super().__init__()
        grid_size = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(grid_size)]
        self.num_actions = len(self.discrete_volumes)
        hidden_width = 256
        # Layer order fixes the parameter names (net.0, net.2, net.4)
        # that saved state dicts refer to.
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, hidden_width),
            nn.ReLU(),
            nn.Linear(hidden_width, self.num_actions),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return one logit per grid volume for each row of ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the categorical distribution defined by the logits.

        Returns:
            Tuple of a ``(1, 1)`` float32 tensor holding the sampled volume
            and the log-probability of the sampled index.
        """
        policy = torch.distributions.Categorical(logits=self.forward(x))
        chosen = policy.sample()
        log_prob = policy.log_prob(chosen)
        sampled_volume = self.discrete_volumes[chosen.item()]
        return torch.tensor([[sampled_volume]], dtype=torch.float32), log_prob

    def predict_volume(self, x):
        """Return the highest-scoring grid volume as a ``(1, 1)`` float32 tensor."""
        scores = self.forward(x)
        top_index = scores.max(dim=1).indices
        best_volume = self.discrete_volumes[top_index.item()]
        return torch.tensor([[best_volume]], dtype=torch.float32)

def load_experiment_conditions(csv_file):
    """Read the experiment definitions from a CSV file.

    The file must have a header row containing the columns 'Acid type',
    'Acid params', 'Initial p h' and 'Target p h'. 'Acid params' is parsed as
    a Python literal (a number or a list of numbers).

    Args:
        csv_file: Path of the CSV file to read.

    Returns:
        A list of dicts with the keys 'Acid type', 'Acid params',
        'Initial ph' and 'Target ph', in file order.
    """
    experiments = []
    # open() only accepts the lower-case mode 'r'; 'R' raises ValueError.
    # newline='' is what the csv module expects for files it reads.
    with open(csv_file, 'r', newline='', encoding='utf-8') as f:
        reader = csv.DictReader(f)
        for row in reader:
            experiments.append({
                'Acid type': row['Acid type'],
                # literal_eval only evaluates literals, never arbitrary code
                'Acid params': ast.literal_eval(row['Acid params']),
                'Initial ph': float(row['Initial p h']),
                'Target ph': float(row['Target p h'])
            })
    return experiments

def test_model(policy_model, csv_file="Experiment summary.csv", output_file="Test output2 modified.txt", summary_file="Experiment summary rl.csv"):
    """Run every experiment from ``csv_file`` with the policy and record the results.

    For each experiment the module-level ``env`` is reset, actions are chosen
    with ``env.select_best_action`` until the environment reports ``done``,
    and the trace is printed and written to ``output_file``. One row per
    experiment is written to ``summary_file``.

    Args:
        policy_model: Model passed to ``env.select_best_action``.
        csv_file: CSV with the experiment conditions.
        output_file: Text file receiving the detailed log.
        summary_file: CSV file receiving one summary row per experiment.

    Note:
        Relies on a module-level ``env`` object being defined before the call.
    """
    experiments = load_experiment_conditions(csv_file)
    num_experiments = len(experiments)
    success_count = 0
    total_steps_success = []

    # open() only accepts the lower-case mode 'w'; 'W' raises ValueError.
    with open(output_file, 'w', encoding='utf-8') as f, open(summary_file, 'w', newline='', encoding='utf-8') as summary_f:
        def log_and_print(message):
            # Mirror every message to stdout and to the log file
            print(message)
            f.write(message + '\n')

        csv_writer = csv.writer(summary_f)
        csv_writer.writerow(['Experiment', 'Acid type', 'Acid params', 'Initial p h', 'Target p h', 'Final p h', 'Steps taken', 'Success'])

        for i, exp in enumerate(experiments, 1):
            log_and_print(f"\n==== Experiment {i} Start ====")
            acid_type = exp['Acid type']
            acid_params = exp['Acid params']
            initial_ph = exp['Initial ph']
            target_ph = exp['Target ph']
            state = env.reset(acid_type=acid_type, acid_params=acid_params, target_ph=target_ph, initial_ph=initial_ph)
            log_and_print(f"Initial state: {state}")
            log_and_print(f"Acid type: {acid_type}, parameters: {acid_params}, target pH: {target_ph}")
            done = False
            steps = 0
            experiment_trace = []
            while not done:
                state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
                log_and_print(f"Current state vector: {state}")
                action = env.select_best_action(state_tensor, policy_model)
                state, reward, done, info = env.step(action)
                experiment_trace.append((state, action[1], info.get('Reagent', '')))
                steps += 1
            log_and_print("Status Action Reagent pair:")
            for j, (s, a, reagent) in enumerate(experiment_trace):
                log_and_print(f"  Step {j+1}: State = {s}, Action = {a:.4f}, Reagent = {reagent}")
            log_and_print(f"The experiment is over, the number of shared steps is: {steps}, final pH: {env.current_ph:.2f}")
            success = abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD
            if success:
                success_count += 1
                total_steps_success.append(steps)

            acid_params_str = f"{acid_params}" if isinstance(acid_params, (list, tuple)) else f"{acid_params:.2f}"
            csv_writer.writerow([i, acid_type, acid_params_str, f"{initial_ph:.2f}", f"{target_ph:.2f}",
                                f"{env.current_ph:.2f}", steps, 'Yes' if success else 'No'])

        # An empty experiment list would otherwise divide by zero
        success_rate = success_count / num_experiments * 100 if num_experiments else 0.0
        avg_steps_success = np.mean(total_steps_success) if total_steps_success else 0
        summary_stats = f"\nTest completed: success rate = {success_rate:.2f}%, average number of steps for successful experiments = {avg_steps_success:.2f}"
        log_and_print(summary_stats)

# Script entry point: evaluate the pre-trained (not yet reinforced) policy.
# __name__ is "__main__" when the cell/script is executed directly; the previous
# comparison against "Main" was never true, so nothing below ever ran.
if __name__ == "__main__":
    input_dim = 5
    learning_rate = 1e-3
    gamma = 0.99

    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1)

    policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)

    try:
        # torch device strings are lower case; 'Cpu' raised an error that the
        # handler below swallowed, so the weights were never actually loaded.
        policy_model.load_state_dict(torch.load("Volume regressor best big discrete new1 test.pth", map_location=torch.device('cpu')))
        print("Loading the discrete pre-trained model was successful.")
    except Exception as e:
        # Deliberate best effort: fall back to random weights but say so.
        print("Failed to load discrete pretrained model, using randomly initialized model.", e)

    # Inference mode
    policy_model.eval()

    test_model(policy_model, csv_file="Experiment summary.csv", output_file="The same experiment of neural network before strengthening only uses concentrated acid and alkali.txt", summary_file="The same experiment of neural network before strengthening only uses concentrated acid and alkali.csv")

In [None]:
# Evaluate Bayesian final

In [None]:
import numpy as np
import math
import random
import logging
from scipy.stats import norm
from scipy.optimize import brentq

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Global parameters (kept consistent with code 2)
TITRATED_VOLUME = 11.0  # Volume of the titrated solution (mL)
ANALYTE_CONC = 0.1      # Concentration of acid in the titrated solution (mol/L)
HCL_CONC1 = 0.1         # Hydrochloric acid 1 concentration (mol/L)
HCL_CONC2 = 0.01        # Hydrochloric acid 2 concentration (mol/L)
NAOH_CONC1 = 0.1        # Sodium hydroxide 1 concentration (mol/L)
NAOH_CONC2 = 0.01       # Sodium hydroxide 2 concentration (mol/L)
MAX_STEPS = 50          # Maximum number of steps

# Reagent name -> concentration; the trailing "1"/"2" selects the concentration tier
REAGENTS = {
    'Dilute acid 1': HCL_CONC1,
    'Dilute acid 2': HCL_CONC2,
    'Dilute base 1': NAOH_CONC1,
    'Dilute base 2': NAOH_CONC2,
}

# pH calculation helper (reused directly from code 2)
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total negative charge concentration carried by a polyprotic acid.

    For an acid with dissociation constants ``K1..Kn`` the k-fold
    deprotonated species has concentration ``[HnA] * K1*...*Kk / H**k``;
    this function sums ``k`` times that concentration over all ``k``.

    Args:
        c_A: Total concentration of the acid.
        H: Hydrogen-ion concentration. When it is exactly 0 every power of
            ``H`` is treated as infinity, so all ratio terms become 0.
        pKa_list: The pKa values in dissociation order.

    Returns:
        Sum over species of (charge number x concentration); 0.0 for an
        empty ``pKa_list``.
    """
    # ratios[k-1] = K1*...*Kk / H**k, i.e. [species k] / [fully protonated species]
    ratios = []
    running_K = 1.0
    for order, pKa in enumerate(pKa_list, start=1):
        # Clip the exponent so 10**(-pKa) cannot overflow
        running_K *= np.power(10, np.clip(-pKa, -100, 100))
        ratios.append(running_K / np.power(H, order, where=H != 0, out=np.array(np.inf)))

    denominator = 1.0
    for ratio in ratios:
        denominator += ratio
    # Concentration of the fully protonated species
    H_nA = c_A / denominator if denominator != 0 else 0.0

    anion_charge = 0.0
    for charge_number, ratio in enumerate(ratios, start=1):
        anion_charge += charge_number * (H_nA * ratio)
    return anion_charge

class PHAdjustmentEnv:
    def __init__(self):
        """Create the environment with randomly drawn buffer parameters.

        Three buffer pKa values and buffer amounts are drawn from NumPy's
        global random generator, and a normal prior is built around each.
        Episode-specific values (pH values, target, step budget) stay ``None``
        until ``initialize`` is called.
        """
        # --- episode bookkeeping ---
        self.steps_taken = 0
        self.done = False
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = 0.0
        self.base_added_moles = 0.0
        self.acid_volume = 0.0
        self.base_volume = 0.0
        self.last_acid_added = 0.0
        self.last_base_added = 0.0

        # --- action space: every reagent paired with 1x..999x the minimum volume ---
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [
            (reagent, volume)
            for reagent in self.reagents.keys()
            for volume in self.addition_volumes
        ]

        # --- tunables ---
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4

        # --- buffer model (the pKa draw must precede the moles draw to keep
        # the random stream unchanged) ---
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)

        # --- episode state filled in by initialize() ---
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None

        # One pair of normal priors per buffer, centred on the drawn values
        self.priors = [
            {
                'P the': norm(loc=pKa_mean, scale=0.5),
                'Total moles': norm(loc=moles_mean, scale=0.005)
            }
            for pKa_mean, moles_mean in zip(self.pKa_list, self.buffer_total_moles)
        ]

        # --- volume heuristic parameters ---
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5

        # --- measurement and overshoot tracking ---
        self.last_measured_ph = None
        self.prev_measured_ph = None
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False
        self.acid_type = None
        self.acid_params = None

    def get_state(self):
        """Return the observation vector as a float32 array.

        Returns:
            ``[current pH, target pH, pH change since the previous step,
            signed error (current - target), volume of the last action]``.
            The pH change is 0.0 when no previous pH is known and the last
            volume is 0.0 until a step has set ``last_action_volume``.
        """
        pH_delta = self.current_ph - self.previous_ph if self.previous_ph is not None else 0.0
        error = self.current_ph - self.target_ph
        # The attribute is only created by step(); the previous
        # hasattr(self, 'Last action volume') check used a name that is never
        # an attribute, so the last volume was always reported as 0.0.
        last_volume = getattr(self, 'last_action_volume', 0.0)
        return np.array([self.current_ph, self.target_ph, pH_delta, error, last_volume], dtype=np.float32)

    def initialize(self, init_pH: float, target_pH: float, max_steps: int, initial_volume: float = TITRATED_VOLUME) -> None:
        """Start a new episode with a randomly chosen acid.

        An acid type is drawn with ``random.choice`` and its pKa value(s)
        with ``random.uniform``; then all per-episode counters are reset.

        Args:
            init_pH: Starting pH, used as current, previous and measured pH.
            target_pH: pH the episode should reach.
            max_steps: Step budget for the episode.
            initial_volume: Starting solution volume.
        """
        # (low, high) sampling window for each successive pKa of the acid type
        pKa_windows = {
            "Monoprotic": ((2, 6),),
            "Diprotic": ((2, 4), (4, 7)),
            "Triprotic": ((2, 4), (4, 6), (6, 8)),
        }
        self.acid_type = random.choice(["Monoprotic", "Diprotic", "Triprotic"])
        drawn_pKas = [random.uniform(low, high) for low, high in pKa_windows[self.acid_type]]
        # A monoprotic acid stores a bare float, the others a list
        self.acid_params = drawn_pKas[0] if self.acid_type == "Monoprotic" else drawn_pKas

        # pH bookkeeping
        self.initial_ph = self.current_ph = self.previous_ph = init_pH
        self.last_measured_ph = self.prev_measured_ph = init_pH
        self.target_ph = target_pH

        # Step bookkeeping
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False

        # Volumes and added amounts
        self.total_volume = self.previous_total_volume = initial_volume
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0

        # Overshoot / oscillation tracking
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return ``10**x`` with the exponent limited to the range [-100, 100]."""
        bounded_exponent = np.clip(x, -100, 100)
        return np.power(10, bounded_exponent)

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading.

        The previous reading becomes ``prev_measured_ph`` (or the new value
        itself when there was none), and ``pH`` becomes both the current and
        the last measured pH.
        """
        had_reading = self.last_measured_ph is not None
        self.prev_measured_ph = self.last_measured_ph if had_reading else pH
        self.current_ph = pH
        self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend each current pKa estimate with its reference value.

        For every buffer the result is
        ``ref + w * (estimate - ref)`` with
        ``w = 0.2 * (1 - tanh(pKa_std))``, so a larger standard deviation
        pulls the effective pKa closer to the reference.

        Returns:
            Array of length ``num_buffers`` with the effective pKa values.
        """
        weight_max = 0.2
        k = 1.0
        blended = np.zeros(self.num_buffers)
        for idx in range(self.num_buffers):
            reference = self.ref_pKa[idx]
            trust = weight_max * (1 - np.tanh(k * self.pKa_std[idx]))
            blended[idx] = reference + trust * (self.pKa_list[idx] - reference)
        return blended

    def compute_required_volume(self) -> float:
        """Estimate the reagent volume that would bring the pH to the target.

        Base is assumed when the current pH is below the target, acid
        otherwise; the secondary ("2") reagent is used once
        ``use_secondary_reagents`` is set. The volume is the root, between 0
        and 10, of ``predicted pH - target pH`` found with ``brentq``.

        Returns:
            The required volume, or 0.0 if the root search fails for any
            reason (for example when there is no sign change on [0, 10]).
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()
        raise_ph = self.current_ph < self.target_ph
        if raise_ph:
            reagent = 'Dilute base 2' if self.use_secondary_reagents else 'Dilute base 1'
        else:
            reagent = 'Dilute acid 2' if self.use_secondary_reagents else 'Dilute acid 1'
        conc = self.reagents[reagent]

        def f_vol(x):
            # Predicted pH minus target after dosing x more of the chosen reagent
            add_moles = conc * (x / 1000.0)
            if raise_ph:
                base_moles = self.base_added_moles + add_moles
                acid_moles = self.acid_added_moles
            else:
                base_moles = self.base_added_moles
                acid_moles = self.acid_added_moles + add_moles
            new_total_volume = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            c_A_new = n_analyte / new_total_volume
            c_Na_new = base_moles / new_total_volume
            c_HCl_new = acid_moles / new_total_volume
            pH_new = self.solve_pH(c_A_new, c_Na_new, c_HCl_new, effective_pKa)
            return pH_new - self.target_ph

        try:
            return brentq(f_vol, 0, 10)
        except Exception:
            # Best effort: report "no volume" when the root search fails
            return 0.0

    def f(self, pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Return the charge-balance residual at a trial pH.

        Args:
            pH: Trial pH value.
            c_A: Total concentration of the weak acid.
            c_Na: Concentration of added strong-base cation.
            c_HCl: Concentration of added strong-acid anion.
            pKa_list: pKa values passed to ``calculate_acid_anion_charge``.

        Returns:
            ``[H+] + c_Na - [OH-] - (acid anion charge) - c_HCl``.
        """
        Kw = 1e-14
        H = 10 ** (-pH)
        OH = Kw / H
        bound_negative_charge = calculate_acid_anion_charge(c_A, H, pKa_list)
        return H + c_Na - OH - bound_negative_charge - c_HCl

    def solve_pH(self, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Solve the charge balance ``self.f`` for pH by bisection on [0, 14].

        Args:
            c_A: Total weak-acid concentration.
            c_Na: Concentration of added strong-base cation.
            c_HCl: Concentration of added strong-acid anion.
            pKa_list: pKa values forwarded to ``self.f``.

        Returns:
            The first midpoint whose residual magnitude is below 1e-10, or the
            midpoint of the final bracket after 100 iterations. If the residual
            never changes sign inside [0, 14] the result saturates towards 14.
        """
        lo, hi = 0.0, 14.0
        # The residual at `lo` only changes when `lo` moves to `mid`, where it
        # has just been computed, so it is cached rather than re-evaluated on
        # every iteration (this halves the number of residual evaluations).
        f_lo = self.f(lo, c_A, c_Na, c_HCl, pKa_list)
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = self.f(mid, c_A, c_Na, c_HCl, pKa_list)
            if abs(f_mid) < 1e-10:
                return mid
            if f_lo * f_mid < 0:
                # Sign change between lo and mid: the root is in the lower half.
                hi = mid
            else:
                lo, f_lo = mid, f_mid
        return (lo + hi) / 2.0

    def step(self, action: tuple, mode: str = 'Simulate') -> tuple:
        """Apply one titration action, update the bookkeeping and score the move.

        Args:
            action: ``(reagent_name, volume)``. The name must be a key of
                ``self.reagents``; the volume is converted with ``float`` and
                divided by 1000 before being multiplied by the reagent
                concentration.
            mode: With ``'Simulate'`` the pH is recomputed through
                ``recalc_ph`` and stored with ``update_exp_ph``. Any other
                value leaves the stored pH values untouched.

        Returns:
            ``(current_ph, reward, done, info)``; ``info`` is always ``{}``.
            The reward is ``0`` when the episode had already ended and
            ``-100`` when the action moves away from the target, when the pH
            is NaN or outside ``[0, 14]``, or when an exception is caught.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            self.last_action_volume = volume
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            # Reagent names are classified case-insensitively, so the needle
            # has to be lowercase as well: a capitalised 'Acid'/'Base' can
            # never occur inside ``reagent.lower()``.
            reagent_key = reagent.lower()
            is_acid = 'acid' in reagent_key
            is_base = 'base' in reagent_key
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if is_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif is_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            moving_away = ((current_for_direction > self.target_ph and is_base)
                           or (current_for_direction < self.target_ph and is_acid))
            if moving_away:
                # Keep the environment flag in sync with the returned flag so
                # loops that poll ``self.done`` terminate as well.
                self.done = True
                return self.current_ph, -100, self.done, {}
            if mode == 'Simulate':
                new_pH = self.recalc_ph()
                self.update_exp_ph(new_pH)
            # Oscillation check: only for minimum-volume additions that jump
            # across the target by more than 0.1 pH units.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("The pH oscillation at the minimum dripping amount was detected, the cumulative number of times:%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")
            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            # Heuristic "ideal" volume used to price the chosen volume.
            current_error = abs(self.current_ph - self.target_ph)
            ph_change = abs(self.current_ph - (self.prev_measured_ph if self.prev_measured_ph is not None else self.current_ph))
            bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
            uncertainties = [prior['P the'].std() for prior in self.priors]
            avg_uncertainty = np.mean(uncertainties)
            max_uncertainty = 1.0
            uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
            buffer_mean = np.mean(self.buffer_total_moles)
            ref_buffer = 0.5
            buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
            buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
            alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
            required_vol = self.compute_required_volume()
            combined_value = current_error + 0.1 * required_vol
            min_vol = self.min_addition_volume
            max_vol = max(self.addition_volumes)
            ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

            # Reward = improvement - remaining error - volume cost - time cost.
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            # Extra penalty when the reagent kind points away from the target
            # relative to the most recently stored pH.
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and is_acid:
                penalty = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= penalty
            if self.target_ph < current_for_direction and is_base:
                penalty = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= penalty

            if self.steps_taken > 0:
                if is_acid:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_acid_added
                elif is_base:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_base_added
                else:
                    reagent_conc = 1.0
                    last_added = 0.0
                overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                                   self.target_ph, reagent,
                                                                   last_added, reagent_conc,
                                                                   self.min_addition_volume)
                if overshoot_flag:
                    self.overshoot_occurred = True
                    self.overshoot_reagent = reagent
                    # Only ever tighten the global volume cap.
                    if new_thresh is not None:
                        if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                            self.overshoot_threshold = new_thresh
            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            # Best effort: any failure ends the episode with the failure reward.
            logging.error("An exception occurs when executing step:%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Report whether the last addition overshot and propose a volume cap.

        An overshoot is flagged when the pH moved to the other side of the
        target, or when it ended up further from the target than before.

        Args:
            prev_ph: pH before the addition.
            current_ph: pH after the addition.
            target_ph: Desired pH.
            reagent: Name of the reagent used (not consulted here).
            last_added_moles: Moles delivered by the addition.
            reagent_conc: Concentration of the reagent.
            min_addition: Lower bound for the proposed cap.

        Returns:
            ``(True, cap)`` on overshoot, where ``cap`` is half of the volume
            reconstructed from moles and concentration but never below
            ``min_addition``; ``(False, None)`` otherwise.
        """
        offset_before = prev_ph - target_ph
        offset_after = current_ph - target_ph
        crossed_target = offset_before * offset_after < 0
        drifted_away = abs(offset_after) > abs(offset_before)
        if not (crossed_target or drifted_away):
            return False, None
        dispensed_volume = last_added_moles * 1000.0 / reagent_conc
        return True, max(dispensed_volume / 2, min_addition)

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return an independent copy of this environment.

        A fresh ``PHAdjustmentEnv`` is constructed and the attributes listed
        below are copied onto it. NumPy arrays, the reagent dict, the volume
        list and the action list are duplicated. Each prior dict is
        duplicated as well, so replacing a distribution on the copy does not
        alter the original's priors (the distribution objects themselves are
        shared).

        Returns:
            The copied environment.
        """
        env_copied = PHAdjustmentEnv()
        env_copied.total_volume = self.total_volume
        env_copied.previous_total_volume = self.previous_total_volume
        env_copied.acid_added_moles = self.acid_added_moles
        env_copied.base_added_moles = self.base_added_moles
        env_copied.acid_volume = self.acid_volume
        env_copied.base_volume = self.base_volume
        env_copied.current_ph = self.current_ph
        env_copied.previous_ph = self.previous_ph
        env_copied.target_ph = self.target_ph
        env_copied.steps_taken = self.steps_taken
        env_copied.done = self.done
        env_copied.num_buffers = self.num_buffers
        env_copied.pKa_list = np.copy(self.pKa_list)
        env_copied.buffer_total_moles = np.copy(self.buffer_total_moles)
        # Copy every prior dict: a plain ``list.copy()`` would leave the dicts
        # shared, so ``update_posteriors`` on the copy would rewrite the
        # original's priors too.
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.epsilon = self.epsilon
        env_copied.direction_penalty_factor = self.direction_penalty_factor
        env_copied.tol = self.tol
        env_copied.reagents = self.reagents.copy()
        env_copied.min_addition_volume = self.min_addition_volume
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        env_copied.max_steps = self.max_steps
        env_copied.vol_ideal_factor = self.vol_ideal_factor
        env_copied.ph_rate_threshold = self.ph_rate_threshold
        env_copied.ph_rate_bonus_factor = self.ph_rate_bonus_factor
        env_copied.last_measured_ph = self.last_measured_ph
        env_copied.prev_measured_ph = self.prev_measured_ph
        env_copied.overshoot_threshold = self.overshoot_threshold
        # Overshoot state steers reagent selection, so the copy must carry it.
        env_copied.overshoot_occurred = getattr(self, 'overshoot_occurred', False)
        env_copied.overshoot_reagent = getattr(self, 'overshoot_reagent', None)
        env_copied.oscillation_count = self.oscillation_count
        env_copied.use_secondary_reagents = self.use_secondary_reagents
        env_copied.ref_pKa = np.copy(self.ref_pKa)
        env_copied.pKa_std = np.copy(self.pKa_std)
        env_copied.acid_type = self.acid_type
        env_copied.acid_params = self.acid_params
        return env_copied

    def recalc_ph(self) -> float:
        """Recompute the pH from the current dosing totals.

        The total volume is the module-level sample volume plus all acid and
        base volumes added so far, divided by 1000. Analyte, base and acid
        amounts are divided by that volume and handed to ``solve_pH``
        together with the effective pKa values.

        Returns:
            The pH returned by ``solve_pH``.
        """
        litres = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        analyte_conc = analyte_moles / litres
        base_conc = self.base_added_moles / litres
        acid_conc = self.acid_added_moles / litres
        effective_pka = self.get_effective_pka_array().tolist()
        return self.solve_pH(analyte_conc, base_conc, acid_conc, effective_pka)

    def select_best_action(self) -> tuple:
        """Choose the next ``(reagent, volume)`` action heuristically.

        The reagent is picked first: after a recorded overshoot the opposite
        kind of the overshooting reagent is used, otherwise base is used when
        the reference pH is below the target and acid when it is not. The
        "2" reagents are used once ``use_secondary_reagents`` is set, the "1"
        reagents otherwise. In the primary tier the overshoot flag and
        reagent are cleared after being consumed. The volume is the
        candidate closest to a heuristic ideal volume, restricted to the
        overshoot threshold when any candidate satisfies it.

        Returns:
            ``(best_action, self.done)``.

        Raises:
            ValueError: If no action matches the selected reagent.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
        tier = '2' if self.use_secondary_reagents else '1'
        # All needles are lowercase because they are matched against
        # ``name.lower()``; capitalised needles could never match.
        if self.overshoot_occurred:
            kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            needle = f'{kind} {tier}'
            if not self.use_secondary_reagents:
                # The overshoot record is only consumed in the primary tier.
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            kind = 'base' if current_for_direction < self.target_ph else 'acid'
            needle = f'dilute {kind} {tier}'
        allowed_reagent = [r for r in self.reagents if needle in r.lower()]

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if self.overshoot_threshold is not None:
            # Respect the volume cap unless it would leave no candidate.
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:
                candidate_actions = capped
        if not candidate_actions:
            raise ValueError(f"no action available for a reagent matching {needle!r}")

        error = abs(current_for_direction - self.target_ph)
        ph_change = abs(current_for_direction - (self.prev_measured_ph if self.prev_measured_ph is not None else current_for_direction))
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['P the'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
        buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        # tanh keeps the ideal volume between the smallest and largest dose.
        ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)
        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one value per prior for the pKa and total-moles parameters.

        For every prior, the ``'P the'`` distribution is sampled first and the
        ``'Total moles'`` distribution second.

        Returns:
            ``(sampled_pKa, sampled_total_moles)``, two lists with one entry
            per prior.
        """
        # Tuple elements are evaluated left to right, which keeps the
        # per-prior draw order (pKa, then moles).
        draws = [(prior['P the'].rvs(), prior['Total moles'].rvs()) for prior in self.priors]
        pka_values = [pair[0] for pair in draws]
        mole_values = [pair[1] for pair in draws]
        return pka_values, mole_values

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH for one sampled parameter set.

        A copy of the environment is made, its ``pKa_list`` and
        ``buffer_total_moles`` are replaced by the sampled values (converted
        to arrays) and ``recalc_ph`` is evaluated on the copy.

        Args:
            action: Accepted for interface symmetry; not used by this method.
            sampled_pKa: Sampled pKa values.
            sampled_total_moles: Sampled total-moles values.

        Returns:
            The pH computed by the copy's ``recalc_ph``.
        """
        scratch = self.env_copy()
        scratch.pKa_list = np.array(sampled_pKa)
        scratch.buffer_total_moles = np.array(sampled_total_moles)
        return scratch.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit the priors from a weighted, resampled set of parameter draws.

        1000 parameter sets are drawn with ``sample_parameters``. Each is
        weighted by the normal density (scale 0.01) of ``observed_ph`` around
        the pH predicted for that draw. The draws are resampled with
        replacement according to the normalised weights, and for every buffer
        the mean and standard deviation (plus 1e-3) of the resampled values
        define the new normal priors. ``pKa_list``, ``buffer_total_moles`` and
        ``pKa_std`` are updated with the new means and pKa spread.

        Args:
            action: Forwarded to ``predict_ph``.
            observed_ph: pH value the predictions are compared against.
        """
        num_particles = 1000
        particles, weights = [], []
        for _ in range(num_particles):
            pka_draw, moles_draw = self.sample_parameters()
            expected_ph = self.predict_ph(action, pka_draw, moles_draw)
            weights.append(norm.pdf(observed_ph, loc=expected_ph, scale=0.01))
            particles.append((pka_draw, moles_draw))

        # The small offset keeps the normalisation finite when every
        # likelihood underflows to zero.
        weights = np.array(weights) + 1e-10
        weights /= np.sum(weights)
        indices = np.random.choice(range(num_particles), size=num_particles, p=weights)

        # Summarise all buffers first, then write the results back.
        summaries = []
        for b in range(self.num_buffers):
            pka_resampled = np.array([particles[idx][0][b] for idx in indices])
            moles_resampled = np.array([particles[idx][1][b] for idx in indices])
            summaries.append((
                np.mean(pka_resampled),
                np.std(pka_resampled) + 1e-3,
                np.mean(moles_resampled),
                np.std(moles_resampled) + 1e-3,
            ))
        for b, (pka_mean, pka_sd, moles_mean, moles_sd) in enumerate(summaries):
            self.priors[b]['P the'] = norm(loc=pka_mean, scale=pka_sd)
            self.priors[b]['Total moles'] = norm(loc=moles_mean, scale=moles_sd)
            self.pKa_list[b] = pka_mean
            self.buffer_total_moles[b] = moles_mean
            self.pKa_std[b] = pka_sd

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Process one action/observation pair and propose the next action.

        If ``observed_ph`` is within 0.1 of the target, the episode is marked
        done and ``(None, True)`` is returned. Otherwise the action is run
        through ``step`` in ``'Simulate'`` mode, the posteriors are updated
        with the pH returned by that step, and ``select_best_action`` supplies
        the proposal.

        NOTE(review): ``observed_ph`` is only used for the stop test; the
        posterior update receives the simulated pH. Confirm this is intended.

        Returns:
            ``(next_action, done)`` where ``done`` is the flag returned by
            ``step`` (or ``True`` on the early exit).
        """
        on_target = abs(observed_ph - self.target_ph) < 0.1
        if on_target:
            self.done = True
            return None, True
        simulated_ph, _reward, finished, _info = self.step(action, mode='Simulate')
        self.update_posteriors(action, simulated_ph)
        proposal, _ = self.select_best_action()
        return proposal, finished

# Helper: compute the initial pH (mimics the behavior of code 1)
def calculate_initial_ph(acid_type: str, acid_params, env: PHAdjustmentEnv):
    """Return the pH of the untouched sample for the given acid description.

    Args:
        acid_type: ``"Monoprotic"``, ``"Diprotic"`` or ``"Triprotic"``.
        acid_params: A single pKa for a monoprotic acid, otherwise the list
            of pKa values.
        env: Environment whose ``solve_pH`` performs the calculation.

    Returns:
        The value of ``env.solve_pH`` with no acid or base added, or ``7.0``
        when ``acid_type`` is not one of the three known names.
    """
    if acid_type == "Monoprotic":
        pKa = [acid_params]
    elif acid_type in ("Diprotic", "Triprotic"):
        pKa = acid_params
    else:
        return 7.0
    analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
    sample_litres = TITRATED_VOLUME / 1000.0
    return env.solve_pH(analyte_moles / sample_litres, 0.0, 0.0, pKa)

# Revised main program
def main():
    """Run a batch of simulated titrations and print traces plus a summary.

    For each of 3000 experiments a random target pH and a random acid
    (type and pKa values) are drawn, the initial pH is computed with
    ``calculate_initial_ph`` and the environment is stepped with the actions
    proposed by ``select_best_action`` until the episode ends. An experiment
    counts as a success when the final pH is within 0.1 of the target.
    """
    # Fix the random seeds for reproducibility (same as code 1)
    seed = 555
    random.seed(seed)
    np.random.seed(seed)

    num_experiments = 3000
    success_count = 0
    steps_success = []

    for exp in range(num_experiments):
        # Draw a random target pH
        target_ph = round(random.uniform(2, 11), 2)
        env = PHAdjustmentEnv()

        # Pick a random acid type with its pKa value(s) and derive the initial pH
        env.acid_type = random.choice(["Monoprotic", "Diprotic", "Triprotic"])
        if env.acid_type == "Monoprotic":
            env.acid_params = random.uniform(2, 6)
        elif env.acid_type == "Diprotic":
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 7)
            env.acid_params = [pKa1, pKa2]
        elif env.acid_type == "Triprotic":
            pKa1 = random.uniform(2, 4)
            pKa2 = random.uniform(4, 6)
            pKa3 = random.uniform(6, 8)
            env.acid_params = [pKa1, pKa2, pKa3]
        initial_ph = calculate_initial_ph(env.acid_type, env.acid_params, env)

        # Initialise the environment
        env.initialize(init_pH=initial_ph, target_pH=target_ph, max_steps=MAX_STEPS, initial_volume=TITRATED_VOLUME)

        print(f"==== Experiment {exp+1} Start ====")
        initial_state = env.get_state()
        print(f"Initial state: {np.round(initial_state, 2)}")
        print(f"Acid type: {env.acid_type}, parameters: {env.acid_params}, target pH: {env.target_ph}")
        print("Status Action Reagent pair:")

        trace = []
        action, _ = env.select_best_action()
        while not env.done:
            _, _, done, _ = env.step(action, mode='Simulate')
            trace.append((env.get_state(), action, action[0]))  # keep the reagent name
            # Honour the flag returned by step(): it can report termination
            # without env.done being set, which would otherwise loop forever.
            if done:
                break
            action, _ = env.select_best_action()

        for i, (s, a, reagent) in enumerate(trace, start=1):
            s_formatted = np.round(s, 2)
            print(f"  Step {i}: State = {s_formatted}, Action = {a[1]:.4f}, Reagent = {reagent}")
        print(f"The experiment is over, the number of shared steps is: {env.steps_taken}, final pH: {env.current_ph:.2f}\n")

        if abs(env.current_ph - env.target_ph) < 0.1:
            success_count += 1
            steps_success.append(env.steps_taken)

    success_rate = success_count / num_experiments * 100
    avg_steps = np.mean(steps_success) if steps_success else 0
    print("Total number of experiments: {}, number of successful experiments: {}, success rate: {:.2f}%, average number of steps for successful experiments: {:.2f}".format(
        num_experiments, success_count, success_rate, avg_steps))

# Run the experiments only when this code is executed directly; the module
# name is '__main__' in that case (a value of 'Main' never occurs).
if __name__ == '__main__':
    main()

In [None]:
# Bayesian multi-log

In [None]:
import numpy as np
import math
import random
import logging
from scipy.stats import norm
from scipy.optimize import brentq
import csv

# Configure logging: INFO level, each record prefixed with time and level
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Global parameters (consistent with code 2)
# Volumes are divided by 1000.0 before being multiplied by a concentration
# throughout this cell, so volumes are presumably in mL and concentrations
# in mol/L (the first cell of the notebook labels them that way).
# NOTE(review): both "2" reagents have the same concentration as the "1"
# reagents here, whereas the first cell uses 0.01 for them; confirm intended.
TITRATED_VOLUME = 11.0
ANALYTE_CONC = 0.1
HCL_CONC1 = 0.1
HCL_CONC2 = 0.1
NAOH_CONC1 = 0.1
NAOH_CONC2 = 0.1
MAX_STEPS = 50

# Reagent name -> concentration; the names are also the first element of
# every action tuple.
REAGENTS = {
    'Dilute acid 1': HCL_CONC1,
    'Dilute acid 2': HCL_CONC2,
    'Dilute base 1': NAOH_CONC1,
    'Dilute base 2': NAOH_CONC2,
}

# pH calculation function (directly reuses code 2)
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the charge concentration carried by the anions of a polyprotic acid.

    With ``K_1..K_n`` derived from ``pKa_list`` and
    ``r_k = (K_1 * ... * K_k) / H**k``, the fully protonated form has
    concentration ``c_A / (1 + sum(r_k))`` and the result is
    ``sum(k * r_k)`` times that concentration.

    Args:
        c_A: Total concentration of the acid.
        H: Hydrogen-ion concentration (a scalar). For ``H == 0`` every
            ``H**k`` is treated as infinite, so all terms vanish.
        pKa_list: pKa values in dissociation order; integers are accepted.

    Returns:
        The summed anion charge concentration (``0.0`` for an empty list).
    """
    # A float base is required: NumPy refuses integer bases raised to
    # negative integer exponents, which is what an integer pKa produces.
    K = [np.power(10.0, np.clip(-pKa, -100, 100)) for pKa in pKa_list]

    # ratios[k - 1] holds r_k as defined in the docstring.
    ratios = []
    cumulative_K = 1.0
    for k, K_k in enumerate(K, start=1):
        cumulative_K *= K_k
        H_power = np.power(H, k) if H != 0 else np.inf
        ratios.append(cumulative_K / H_power)

    denominator = 1.0
    for ratio in ratios:
        denominator += ratio
    H_nA = c_A / denominator if denominator != 0 else 0.0

    anion_charge = 0.0
    for k, ratio in enumerate(ratios, start=1):
        anion_charge += k * (H_nA * ratio)
    return anion_charge

class PHAdjustmentEnv:
    def __init__(self):
        """Create an environment with default settings and random buffer parameters.

        Episode-specific values (initial, current, previous and target pH and
        the step budget) stay ``None`` until ``initialize`` is called. The pKa
        values and buffer amounts are drawn from NumPy's global random state,
        pKa first.
        """
        # Episode progress.
        self.steps_taken = 0
        self.done = False

        # Volume and dosing bookkeeping.
        self.total_volume = TITRATED_VOLUME
        self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0
        self.last_acid_added = self.last_base_added = 0.0

        # Action grid: every reagent paired with 1..999 multiples of the
        # minimum addition volume.
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * multiple for multiple in range(1, 1000)]
        self.action_space = []
        for name in self.reagents.keys():
            for dose in self.addition_volumes:
                self.action_space.append((name, dose))

        # Tunable scalars.
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4

        # Buffer model; the order of the two random draws is significant for
        # reproducibility under a fixed seed.
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)

        # Filled in by initialize().
        self.initial_ph = None
        self.current_ph = None
        self.previous_ph = None
        self.target_ph = None
        self.max_steps = None

        # One normal prior per buffer for the pKa and for the total moles.
        self.priors = [
            {
                'P the': norm(loc=self.pKa_list[b], scale=0.5),
                'Total moles': norm(loc=self.buffer_total_moles[b], scale=0.005),
            }
            for b in range(self.num_buffers)
        ]

        # Parameters of the ideal-volume heuristic.
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5

        # Measurement history and overshoot / oscillation tracking.
        self.last_measured_ph = None
        self.prev_measured_ph = None
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

        # Description of the titrated acid.
        self.acid_type = None
        self.acid_params = None

    def get_state(self):
        """Return the observation vector as a ``float32`` array.

        Returns:
            ``[current_ph, target_ph, pH_delta, error, last_action_volume]``
            where ``pH_delta`` is the change since ``previous_ph`` (``0.0``
            when there is none), ``error`` is ``current_ph - target_ph`` and
            the last element is ``0.0`` until an action has been taken.
        """
        pH_delta = self.current_ph - self.previous_ph if self.previous_ph is not None else 0.0
        error = self.current_ph - self.target_ph
        # ``last_action_volume`` is first assigned in step(); look it up under
        # its real attribute name and fall back to 0.0 before the first step.
        last_volume = getattr(self, 'last_action_volume', 0.0)
        return np.array([self.current_ph, self.target_ph, pH_delta, error, last_volume], dtype=np.float32)

    def initialize(self, init_pH: float, target_pH: float, max_steps: int, initial_volume: float = TITRATED_VOLUME) -> None:
        """Reset the environment for a new episode.

        A random acid type is chosen and its pKa value(s) are drawn with the
        ``random`` module; the pH values, step budget, volumes, dosing totals
        and the overshoot / oscillation tracking are reset.

        Args:
            init_pH: Starting pH; also stored as the previous and last
                measured pH.
            target_pH: pH to reach.
            max_steps: Step budget of the episode.
            initial_volume: Starting volume of the sample.
        """
        # Acid description: one pKa for a monoprotic acid, otherwise one pKa
        # per dissociation step drawn from consecutive windows.
        self.acid_type = random.choice(["Monoprotic", "Diprotic", "Triprotic"])
        if self.acid_type == "Monoprotic":
            self.acid_params = random.uniform(2, 6)
        else:
            pka_windows = {
                "Diprotic": ((2, 4), (4, 7)),
                "Triprotic": ((2, 4), (4, 6), (6, 8)),
            }
            self.acid_params = [random.uniform(low, high) for low, high in pka_windows[self.acid_type]]

        # pH state.
        self.initial_ph = self.current_ph = self.previous_ph = init_pH
        self.last_measured_ph = self.prev_measured_ph = init_pH
        self.target_ph = target_pH

        # Episode progress.
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False

        # Volumes and dosing totals.
        self.total_volume = self.previous_total_volume = initial_volume
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0

        # Overshoot / oscillation tracking.
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return ``10**x`` with the exponent clipped to ``[-100, 100]``.

        The base is a float so that integer exponents work as well: with an
        integer base NumPy raises for negative integer exponents and
        overflows for large positive ones.

        Args:
            x: Exponent (int or float).

        Returns:
            ``10.0 ** clip(x, -100, 100)`` as a floating-point value.
        """
        return np.power(10.0, np.clip(x, -100, 100))

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading.

        The previous reading is moved to ``prev_measured_ph`` (or set to the
        new value when there was none) and the new value becomes both
        ``current_ph`` and ``last_measured_ph``.

        Args:
            pH: The new pH value.
        """
        earlier = self.last_measured_ph
        self.prev_measured_ph = pH if earlier is None else earlier
        self.current_ph = self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend the reference pKa values with the current estimates.

        For every buffer the estimate in ``pKa_list`` is mixed into
        ``ref_pKa`` with weight ``0.2 * (1 - tanh(pKa_std))``, so a larger
        spread pulls the result closer to the reference value.

        Returns:
            Array of length ``num_buffers`` with the blended pKa values.
        """
        weight_max, k = 0.2, 1.0
        blended = np.zeros(self.num_buffers)
        for idx in range(self.num_buffers):
            trust = weight_max * (1 - np.tanh(k * self.pKa_std[idx]))
            shift = self.pKa_list[idx] - self.ref_pKa[idx]
            blended[idx] = self.ref_pKa[idx] + trust * shift
        return blended

    def compute_required_volume(self) -> float:
        """Estimate the reagent volume that would bring the pH to the target.

        Base is assumed when ``current_ph`` is below the target, acid
        otherwise; the "2" reagent is used once ``use_secondary_reagents`` is
        set. The volume is the root, searched with ``brentq`` on ``[0, 10]``,
        of the difference between the pH predicted after adding that volume
        and the target.

        Returns:
            The root found by ``brentq``, or ``0.0`` when the search raises
            (for example when the interval does not bracket a root).
        """
        analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        pka_eff = self.get_effective_pka_array()
        raising_ph = self.current_ph < self.target_ph
        if raising_ph:
            reagent = 'Dilute base 2' if self.use_secondary_reagents else 'Dilute base 1'
        else:
            reagent = 'Dilute acid 2' if self.use_secondary_reagents else 'Dilute acid 1'
        strength = self.reagents[reagent]

        def ph_gap(x):
            # pH predicted after adding volume x of the chosen reagent, minus the target.
            dose = strength * (x / 1000.0)
            base_moles = self.base_added_moles + dose if raising_ph else self.base_added_moles
            acid_moles = self.acid_added_moles if raising_ph else self.acid_added_moles + dose
            litres = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            predicted = self.solve_pH(analyte_moles / litres, base_moles / litres,
                                      acid_moles / litres, pka_eff)
            return predicted - self.target_ph

        try:
            return brentq(ph_gap, 0, 10)
        except Exception:
            # Best effort: no usable root means "no volume estimate".
            return 0.0

    def f(self, pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Evaluate the charge-balance residual at a trial pH.

        Args:
            pH: Trial pH.
            c_A: Total acid (analyte) concentration.
            c_Na: Concentration contributed by added base.
            c_HCl: Concentration contributed by added acid.
            pKa_list: pKa values of the analyte.

        Returns:
            ``H + c_Na - OH - anion_charge - c_HCl`` with ``H = 10**(-pH)``
            and ``OH = 1e-14 / H``; ``solve_pH`` searches for its zero.
        """
        hydronium = 10 ** (-pH)
        hydroxide = 1e-14 / hydronium
        anion_charge = calculate_acid_anion_charge(c_A, hydronium, pKa_list)
        return hydronium + c_Na - hydroxide - anion_charge - c_HCl

    def solve_pH(self, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Find the pH at which ``self.f`` vanishes by bisection on ``[0, 14]``.

        At most 100 halvings are performed; the search stops early once
        ``|f(mid)| < 1e-10``. The interval is not checked for a sign change:
        when ``f`` has the same sign everywhere the lower bound walks up to
        the upper bound and a value close to 14 is returned.

        Args:
            c_A: Total acid (analyte) concentration.
            c_Na: Concentration contributed by added base.
            c_HCl: Concentration contributed by added acid.
            pKa_list: pKa values of the analyte.

        Returns:
            The midpoint at which the search stopped.
        """
        lo, hi = 0.0, 14.0
        # f(lo) only changes when lo moves, so it is cached instead of being
        # re-evaluated on every iteration. It is computed lazily so that a
        # first midpoint that already satisfies the tolerance costs one call.
        f_lo = None
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = self.f(mid, c_A, c_Na, c_HCl, pKa_list)
            if abs(f_mid) < 1e-10:
                return mid
            if f_lo is None:
                f_lo = self.f(lo, c_A, c_Na, c_HCl, pKa_list)
            if f_lo * f_mid < 0:
                hi = mid
            else:
                lo = mid
                f_lo = f_mid
        return (lo + hi) / 2.0

    def step(self, action: tuple, mode: str = 'Simulate') -> tuple:
        """Apply one titration action, update the bookkeeping and score the move.

        Args:
            action: ``(reagent_name, volume)``. The name must be a key of
                ``self.reagents``; the volume is converted with ``float`` and
                divided by 1000 before being multiplied by the reagent
                concentration.
            mode: With ``'Simulate'`` the pH is recomputed through
                ``recalc_ph`` and stored with ``update_exp_ph``. Any other
                value leaves the stored pH values untouched.

        Returns:
            ``(current_ph, reward, done, info)``; ``info`` is always ``{}``.
            The reward is ``0`` when the episode had already ended and
            ``-100`` when the action moves away from the target, when the pH
            is NaN or outside ``[0, 14]``, or when an exception is caught.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            self.last_action_volume = volume
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            # Reagent names are classified case-insensitively, so the needle
            # has to be lowercase as well: a capitalised 'Acid'/'Base' can
            # never occur inside ``reagent.lower()``.
            reagent_key = reagent.lower()
            is_acid = 'acid' in reagent_key
            is_base = 'base' in reagent_key
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume
            if is_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif is_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            moving_away = ((current_for_direction > self.target_ph and is_base)
                           or (current_for_direction < self.target_ph and is_acid))
            if moving_away:
                # Keep the environment flag in sync with the returned flag so
                # loops that poll ``self.done`` terminate as well.
                self.done = True
                return self.current_ph, -100, self.done, {}
            if mode == 'Simulate':
                new_pH = self.recalc_ph()
                self.update_exp_ph(new_pH)
            # Oscillation check: only for minimum-volume additions that jump
            # across the target by more than 0.1 pH units.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("The pH oscillation at the minimum dripping amount was detected, the cumulative number of times:%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")
            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            # Heuristic "ideal" volume used to price the chosen volume.
            current_error = abs(self.current_ph - self.target_ph)
            ph_change = abs(self.current_ph - (self.prev_measured_ph if self.prev_measured_ph is not None else self.current_ph))
            bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
            uncertainties = [prior['P the'].std() for prior in self.priors]
            avg_uncertainty = np.mean(uncertainties)
            max_uncertainty = 1.0
            uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
            buffer_mean = np.mean(self.buffer_total_moles)
            ref_buffer = 0.5
            buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
            buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
            alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
            required_vol = self.compute_required_volume()
            combined_value = current_error + 0.1 * required_vol
            min_vol = self.min_addition_volume
            max_vol = max(self.addition_volumes)
            ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

            # Reward = improvement - remaining error - volume cost - time cost.
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            # Extra penalty when the reagent kind points away from the target
            # relative to the most recently stored pH.
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and is_acid:
                penalty = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= penalty
            if self.target_ph < current_for_direction and is_base:
                penalty = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= penalty

            if self.steps_taken > 0:
                if is_acid:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_acid_added
                elif is_base:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_base_added
                else:
                    reagent_conc = 1.0
                    last_added = 0.0
                overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                                   self.target_ph, reagent,
                                                                   last_added, reagent_conc,
                                                                   self.min_addition_volume)
                if overshoot_flag:
                    self.overshoot_occurred = True
                    self.overshoot_reagent = reagent
                    # Only ever tighten the global volume cap.
                    if new_thresh is not None:
                        if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                            self.overshoot_threshold = new_thresh
            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            # Best effort: any failure ends the episode with the failure reward.
            logging.error("An exception occurs when executing step:%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Report whether the last addition overshot and propose a volume cap.

        An overshoot is flagged when the pH moved to the other side of the
        target, or when it ended up further from the target than before.

        Args:
            prev_ph: pH before the addition.
            current_ph: pH after the addition.
            target_ph: Desired pH.
            reagent: Name of the reagent used (not consulted here).
            last_added_moles: Moles delivered by the addition.
            reagent_conc: Concentration of the reagent.
            min_addition: Lower bound for the proposed cap.

        Returns:
            ``(True, cap)`` on overshoot, where ``cap`` is half of the volume
            reconstructed from moles and concentration but never below
            ``min_addition``; ``(False, None)`` otherwise.
        """
        offset_before = prev_ph - target_ph
        offset_after = current_ph - target_ph
        crossed_target = offset_before * offset_after < 0
        drifted_away = abs(offset_after) > abs(offset_before)
        if not (crossed_target or drifted_away):
            return False, None
        dispensed_volume = last_added_moles * 1000.0 / reagent_conc
        return True, max(dispensed_volume / 2, min_addition)

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return a snapshot of this environment that can be modified independently.

        A new ``PHAdjustmentEnv`` is constructed and the current state is
        transferred onto it. NumPy arrays and the reagent/volume/action
        containers are copied. Each prior dict is copied as well, so
        replacing a distribution on the copy does not touch the original.
        The overshoot flags are carried over so that the copy selects actions
        the same way as the original would.

        Returns:
            The populated copy.
        """
        env_copied = PHAdjustmentEnv()

        # Values that are transferred as-is (numbers, flags, strings, None)
        plain_attrs = (
            'total_volume', 'previous_total_volume',
            'acid_added_moles', 'base_added_moles',
            'acid_volume', 'base_volume',
            'current_ph', 'previous_ph', 'target_ph',
            'steps_taken', 'done', 'num_buffers',
            'epsilon', 'direction_penalty_factor', 'tol', 'max_steps',
            'vol_ideal_factor', 'ph_rate_threshold', 'ph_rate_bonus_factor',
            'last_measured_ph', 'prev_measured_ph',
            'overshoot_threshold', 'overshoot_occurred', 'overshoot_reagent',
            'oscillation_count', 'use_secondary_reagents',
            'acid_type', 'acid_params',
        )
        for attr in plain_attrs:
            setattr(env_copied, attr, getattr(self, attr))

        # Arrays are duplicated so in-place updates on the copy stay local
        for attr in ('pKa_list', 'buffer_total_moles', 'ref_pKa', 'pKa_std'):
            setattr(env_copied, attr, np.copy(getattr(self, attr)))

        # Containers: new outer objects; prior dicts are copied one by one
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.reagents = self.reagents.copy()
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        return env_copied

    def recalc_ph(self) -> float:
        """Recompute the pH from the amounts of acid and base added so far.

        The total volume is the titrated volume plus all added acid and base
        volumes, divided by 1000. Analyte, base and acid amounts are turned
        into concentrations over that volume and passed to ``solve_pH``
        together with the effective pKa values.

        Returns:
            The pH reported by ``solve_pH``.
        """
        volume_total = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array().tolist()
        return self.solve_pH(
            analyte_moles / volume_total,
            self.base_added_moles / volume_total,
            self.acid_added_moles / volume_total,
            effective_pKa,
        )

    def select_best_action(self) -> tuple:
        """Choose the ``(reagent, volume)`` action closest to a heuristic ideal dose.

        Reagent choice:
          * after an overshoot, the opposite kind (acid vs. base) of the
            reagent that overshot;
          * otherwise base when the reference pH is below the target, acid
            when it is not;
          * the "2" reagents when ``use_secondary_reagents`` is set, the "1"
            reagents otherwise.

        The overshoot flags are cleared here only on the primary-reagent path.
        If ``overshoot_threshold`` is set, volumes above it are dropped unless
        that would leave nothing to choose from.

        Returns:
            ``(best_action, self.done)``.

        Raises:
            ValueError: If no action matches the selected reagent.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph

        # Names are compared after .lower(), so the search strings must be
        # lower case too; capitalised needles could never match.
        suffix = '2' if self.use_secondary_reagents else '1'
        if self.overshoot_occurred:
            kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            needle = f'{kind} {suffix}'
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            kind = 'base' if current_for_direction < self.target_ph else 'acid'
            needle = f'dilute {kind} {suffix}'
        allowed_reagent = {r for r in self.reagents if needle in r.lower()}

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if not candidate_actions:
            raise ValueError(f"No action available for a reagent matching {needle!r}")
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:  # keep the unfiltered list when the cap removes everything
                candidate_actions = capped

        # Heuristic ideal volume: grows with the pH error and the estimated
        # required volume, squashed into [min_vol, max_vol] by tanh.
        error = abs(current_for_direction - self.target_ph)
        reference_ph = self.prev_measured_ph if self.prev_measured_ph is not None else current_for_direction
        ph_change = abs(current_for_direction - reference_ph)
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['P the'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = np.clip(1.0 + 0.1 * (buffer_mean - ref_buffer), 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one pKa and one total-moles value from every prior.

        For each prior the pKa is drawn first and the total moles second, so
        the random stream is consumed prior by prior.

        Returns:
            ``(sampled_pKa, sampled_total_moles)`` as two lists with one entry
            per prior.
        """
        # Tuple elements are evaluated left to right, which keeps the
        # pKa-then-moles draw order for each prior.
        draws = [(prior['P the'].rvs(), prior['Total moles'].rvs()) for prior in self.priors]
        sampled_pKa = [pKa for pKa, _ in draws]
        sampled_total_moles = [moles for _, moles in draws]
        return sampled_pKa, sampled_total_moles

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Predict the pH of the current state under sampled buffer parameters.

        Works on a copy of the environment, so this instance is not modified.
        ``action`` is accepted for interface symmetry but is not applied here.

        Args:
            action: ``(reagent, volume)`` pair; unused.
            sampled_pKa: pKa values to place on the copy.
            sampled_total_moles: Buffer total moles to place on the copy.

        Returns:
            The pH returned by the copy's ``recalc_ph``.
        """
        trial_env = self.env_copy()
        trial_env.pKa_list = np.array(sampled_pKa)
        trial_env.buffer_total_moles = np.array(sampled_total_moles)
        return trial_env.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit the normal priors from likelihood-weighted parameter samples.

        Steps:
          1. Draw 1000 parameter sets from the current priors and predict a
             pH for each.
          2. Weight each set by a normal likelihood of ``observed_ph`` around
             its prediction (standard deviation 0.01).
          3. Resample the sets with those weights.
          4. Replace each buffer's priors, point estimates and pKa spread
             with the mean/std of the resampled values (std floored by 1e-3).

        Args:
            action: Forwarded to ``predict_ph``.
            observed_ph: The pH to explain.
        """
        num_particles = 1000
        particles = []
        likelihoods = []
        for _ in range(num_particles):
            sampled_pKa, sampled_total_moles = self.sample_parameters()
            predicted_ph = self.predict_ph(action, sampled_pKa, sampled_total_moles)
            likelihoods.append(norm.pdf(observed_ph, loc=predicted_ph, scale=0.01))
            particles.append((sampled_pKa, sampled_total_moles))

        # Small offset keeps the weights normalisable when every likelihood underflows
        weights = np.array(likelihoods) + 1e-10
        weights /= np.sum(weights)
        indices = np.random.choice(range(num_particles), size=num_particles, p=weights)

        def resampled_stats(buffer_idx, slot):
            # slot 0 -> pKa draws, slot 1 -> total-moles draws
            values = np.array([particles[idx][slot][buffer_idx] for idx in indices])
            return np.mean(values), np.std(values) + 1e-3

        # Compute every statistic first; the state is only touched afterwards
        fitted = [(resampled_stats(i, 0), resampled_stats(i, 1)) for i in range(self.num_buffers)]

        for i, ((pKa_mean, pKa_sd), (moles_mean, moles_sd)) in enumerate(fitted):
            self.priors[i]['P the'] = norm(loc=pKa_mean, scale=pKa_sd)
            self.priors[i]['Total moles'] = norm(loc=moles_mean, scale=moles_sd)
            self.pKa_list[i] = pKa_mean
            self.buffer_total_moles[i] = moles_mean
            self.pKa_std[i] = pKa_sd

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Simulate ``action``, refresh the posteriors and propose the next action.

        If ``observed_ph`` is already within 0.1 of the target, the episode is
        marked done and no action is proposed.

        Args:
            action: ``(reagent, volume)`` pair to simulate.
            observed_ph: Latest observed pH, used only for the stop check.

        Returns:
            ``(next_action, done)``; ``(None, True)`` when the target has been reached.
        """
        on_target = abs(observed_ph - self.target_ph) < 0.1
        if on_target:
            self.done = True
            return None, True
        simulated_ph, _reward, done, _info = self.step(action, mode='Simulate')
        self.update_posteriors(action, simulated_ph)
        proposal = self.select_best_action()
        return proposal[0], done

# Helper function: Calculate initial pH (emulates the behavior of code 1)
def calculate_initial_ph(acid_type: str, acid_params, env: PHAdjustmentEnv):
    """Compute the pH of the analyte before any acid or base has been added.

    Args:
        acid_type: ``"Monoprotic"``, ``"Diprotic"`` or ``"Triprotic"``.
        acid_params: A single pKa for a monoprotic acid, otherwise the pKa
            sequence, which is passed through unchanged.
        env: Environment whose ``solve_pH`` performs the calculation.

    Returns:
        The solved pH, or 7.0 for an unrecognised ``acid_type``.
    """
    if acid_type == "Monoprotic":
        pKa = [acid_params]
    elif acid_type in ("Diprotic", "Triprotic"):
        pKa = acid_params
    else:
        return 7.0

    # No titrant yet: only the analyte in its original volume
    volume_total = TITRATED_VOLUME / 1000.0
    analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
    return env.solve_pH(analyte_moles / volume_total, 0.0, 0.0, pKa)

# Modified main program
def main():
    """Run the batch of simulated titration experiments and record the results.

    For each experiment a target pH and an acid (type and pKa values) are
    drawn at random, the environment is initialised at the matching initial
    pH, and actions chosen by ``select_best_action`` are applied until the
    environment reports ``done``. A step-by-step trace goes to a text log and
    one summary row per experiment to a CSV file. Overall statistics are
    appended to the log and printed.
    """
    # Imported here so the function does not rely on another cell having done it
    import csv
    import random

    # Fixed random seed to ensure reproducibility (consistent with Code 1)
    seed = 555
    random.seed(seed)
    np.random.seed(seed)

    num_experiments = 3000
    success_count = 0
    steps_success = []

    # Open files to save experiment records and summaries.
    # open() only accepts the lower-case mode 'w'; 'W' raises ValueError.
    with open('Bayesian only uses concentrated acids and bases.txt', 'w', encoding='utf-8') as log_file, \
         open('Experiment summary.csv', 'w', newline='', encoding='utf-8') as summary_file:

        # Initialize the CSV writer
        csv_writer = csv.writer(summary_file)
        csv_writer.writerow(['Experiment', 'Acid type', 'Acid params', 'Initial p h', 'Target p h', 'Final p h', 'Steps taken', 'Success'])

        for exp in range(num_experiments):
            # Randomly generate the target pH
            target_ph = round(random.uniform(2, 11), 2)
            env = PHAdjustmentEnv()

            # Randomly select acid type and generate initial pH
            env.acid_type = random.choice(["Monoprotic", "Diprotic", "Triprotic"])
            if env.acid_type == "Monoprotic":
                env.acid_params = random.uniform(2, 6)
            elif env.acid_type == "Diprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 7)
                env.acid_params = [pKa1, pKa2]
            elif env.acid_type == "Triprotic":
                pKa1 = random.uniform(2, 4)
                pKa2 = random.uniform(4, 6)
                pKa3 = random.uniform(6, 8)
                env.acid_params = [pKa1, pKa2, pKa3]
            initial_ph = calculate_initial_ph(env.acid_type, env.acid_params, env)

            # Initialize environment
            env.initialize(init_pH=initial_ph, target_pH=target_ph, max_steps=MAX_STEPS, initial_volume=TITRATED_VOLUME)

            # Write to log file
            log_file.write(f"==== Experiment {exp+1} Start ====\n")
            initial_state = env.get_state()
            log_file.write(f"Initial state: {np.round(initial_state, 2)}\n")
            log_file.write(f"Acid type: {env.acid_type}, parameters: {env.acid_params}, target pH: {env.target_ph}\n")
            log_file.write("Status Action Reagent Pair:\n")

            trace = []
            action, _ = env.select_best_action()
            while not env.done:
                env.step(action, mode='Simulate')
                state_after = env.get_state()
                trace.append((state_after, action, action[0]))  # Record reagent name
                action, _ = env.select_best_action()

            for i, (s, a, reagent) in enumerate(trace, start=1):
                s_formatted = np.round(s, 2)
                log_file.write(f"  Step {i}: State = {s_formatted}, Action = {a[1]:.4f}, Reagent = {reagent}\n")
            log_file.write(f"The experiment is over, the number of shared steps is: {env.steps_taken}, final pH: {env.current_ph:.2f}\n\n")

            # Determine whether it is successful
            success = abs(env.current_ph - env.target_ph) < 0.1
            if success:
                success_count += 1
                steps_success.append(env.steps_taken)

            # Write CSV summary
            acid_params_str = f"{env.acid_params}" if isinstance(env.acid_params, list) else f"{env.acid_params:.2f}"
            csv_writer.writerow([exp+1, env.acid_type, acid_params_str, f"{initial_ph:.2f}", f"{env.target_ph:.2f}",
                                f"{env.current_ph:.2f}", env.steps_taken, 'Yes' if success else 'No'])

        success_rate = success_count / num_experiments * 100
        avg_steps = np.mean(steps_success) if steps_success else 0
        summary_stats = f"Total number of experiments: {num_experiments}, number of successful experiments: {success_count}, success rate: {success_rate:.2f}%, average number of steps for successful experiments: {avg_steps:.2f}\n"
        log_file.write(summary_stats)
        print(summary_stats)

# Run the experiments only when this file is executed as a script or cell;
# the interpreter sets __name__ to '__main__' in that case.
if __name__ == '__main__':
    main()

In [None]:
import numpy as np
import math
import random
import logging
from scipy.stats import norm
from scipy.optimize import brentq
import sys
from io import StringIO

# Logging setup: INFO level, timestamped messages
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Global parameters (consistent with code 2)
TITRATED_VOLUME = 11.0
ANALYTE_CONC = 0.1
HCL_CONC1 = HCL_CONC2 = 0.1    # both acid reagents share one concentration in this cell
NAOH_CONC1 = NAOH_CONC2 = 0.1  # likewise for the two base reagents
MAX_STEPS = 50

# Reagent name -> concentration, in the order acid 1, acid 2, base 1, base 2
REAGENTS = dict(zip(
    ('Dilute acid 1', 'Dilute acid 2', 'Dilute base 1', 'Dilute base 2'),
    (HCL_CONC1, HCL_CONC2, NAOH_CONC1, NAOH_CONC2),
))

# pH calculation function (directly reuse code 2)
def calculate_acid_anion_charge(c_A: float, H: float, pKa_list: list) -> float:
    """Return the total anion charge concentration of a polyprotic acid.

    With K_i = 10**(-pKa_i) (exponent clipped to [-100, 100]) the term for
    the k-times deprotonated species is (K_1*...*K_k) / H**k. The fully
    protonated concentration is ``c_A`` divided by one plus the sum of those
    terms, and each species contributes k times its concentration.

    Args:
        c_A: Total acid concentration.
        H: Hydrogen-ion concentration; an exact zero makes every term 0.
        pKa_list: pKa values in dissociation order.

    Returns:
        Sum over k of k * [species k]; 0.0 when ``pKa_list`` is empty.
    """
    def h_power(exponent):
        # H**exponent, or +inf when H == 0 so the ratio below becomes 0
        return np.power(H, exponent, where=H != 0, out=np.array(np.inf))

    # ratios[k-1] = (K_1 * ... * K_k) / H**k, reused by both sums below
    ratios = []
    running_K = 1.0
    for order, pKa in enumerate(pKa_list, start=1):
        running_K *= np.power(10, np.clip(-pKa, -100, 100))
        ratios.append(running_K / h_power(order))

    denominator = 1.0
    for ratio in ratios:
        denominator += ratio
    H_nA = c_A / denominator if denominator != 0 else 0.0

    anion_charge = 0.0
    for order, ratio in enumerate(ratios, start=1):
        anion_charge += order * (H_nA * ratio)
    return anion_charge

class PHAdjustmentEnv:
    def __init__(self):
        """Create an environment with default settings and randomly drawn buffers.

        Initial pH, target pH and the step limit stay ``None`` until
        ``initialize`` is called. Buffer pKa values and total moles are drawn
        from NumPy's global random generator (pKa first, then moles).
        """
        # Episode progress and titration totals
        self.steps_taken = 0
        self.done = False
        self.total_volume = self.previous_total_volume = TITRATED_VOLUME
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0
        self.last_acid_added = self.last_base_added = 0.0

        # Reagents and the discrete action grid (every reagent x every volume)
        self.reagents = REAGENTS.copy()
        self.min_addition_volume = 0.01
        self.addition_volumes = [self.min_addition_volume * i for i in range(1, 1000)]
        self.action_space = [
            (reagent, volume)
            for reagent in self.reagents.keys()
            for volume in self.addition_volumes
        ]

        # Tuning constants
        self.epsilon = 0
        self.direction_penalty_factor = 60.0
        self.tol = 1e-4
        self.vol_ideal_factor = 0.2
        self.ph_rate_threshold = 1.0
        self.ph_rate_bonus_factor = 0.5

        # Buffer model: random starting guesses and their uncertainty
        self.num_buffers = 3
        self.pKa_list = np.random.uniform(2, 6, size=self.num_buffers)
        self.ref_pKa = np.copy(self.pKa_list)
        self.pKa_std = np.full(self.num_buffers, 0.2)
        self.buffer_total_moles = np.random.uniform(1e-6, 0.5, size=self.num_buffers)

        # pH bookkeeping, filled in by initialize()
        self.initial_ph = self.current_ph = self.previous_ph = None
        self.target_ph = None
        self.max_steps = None

        # One pair of normal priors per buffer, centred on the drawn values
        self.priors = [
            {
                'P the': norm(loc=self.pKa_list[i], scale=0.5),
                'Total moles': norm(loc=self.buffer_total_moles[i], scale=0.005),
            }
            for i in range(self.num_buffers)
        ]

        # Measurement history and overshoot / oscillation tracking
        self.last_measured_ph = self.prev_measured_ph = None
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

        # Acid description, set later
        self.acid_type = None
        self.acid_params = None

    def get_state(self):
        """Return the observation vector for the current state.

        Returns:
            A float32 array ``[current_ph, target_ph, pH_delta, error,
            last_action_volume]``. ``pH_delta`` is the change since
            ``previous_ph`` (0.0 if there is none), ``error`` is
            ``current_ph - target_ph``, and the last entry is the volume of
            the most recent action (0.0 before the first step).
        """
        pH_delta = self.current_ph - self.previous_ph if self.previous_ph is not None else 0.0
        error = self.current_ph - self.target_ph
        # The attribute is only created by step(); the lookup must use its
        # real name, otherwise the fallback 0.0 is always reported.
        last_volume = getattr(self, 'last_action_volume', 0.0)
        return np.array([self.current_ph, self.target_ph, pH_delta, error, last_volume], dtype=np.float32)

    def initialize(self, init_pH: float, target_pH: float, max_steps: int, initial_volume: float = TITRATED_VOLUME) -> None:
        """Reset the environment for a new episode.

        A random acid type and matching pKa parameters are drawn with the
        ``random`` module, then all pH, volume and overshoot bookkeeping is
        reset.

        NOTE(review): the acid type and parameters drawn here replace any
        values the caller assigned beforehand, while ``init_pH`` is taken
        from the caller — confirm this is intended.

        Args:
            init_pH: Starting pH; also used as previous and measured pH.
            target_pH: pH to reach.
            max_steps: Step limit for the episode.
            initial_volume: Starting total volume.
        """
        # (low, high) sampling range for each successive pKa
        pKa_ranges = {
            "Diprotic": [(2, 4), (4, 7)],
            "Triprotic": [(2, 4), (4, 6), (6, 8)],
        }
        self.acid_type = random.choice(["Monoprotic", "Diprotic", "Triprotic"])
        if self.acid_type == "Monoprotic":
            # A monoprotic acid is described by a single number, not a list
            self.acid_params = random.uniform(2, 6)
        elif self.acid_type in pKa_ranges:
            self.acid_params = [random.uniform(low, high) for low, high in pKa_ranges[self.acid_type]]

        # pH bookkeeping
        self.initial_ph = self.current_ph = self.previous_ph = init_pH
        self.last_measured_ph = self.prev_measured_ph = init_pH
        self.target_ph = target_pH

        # Episode progress
        self.max_steps = max_steps
        self.steps_taken = 0
        self.done = False

        # Volumes and added amounts
        self.total_volume = self.previous_total_volume = initial_volume
        self.acid_added_moles = self.base_added_moles = 0.0
        self.acid_volume = self.base_volume = 0.0

        # Overshoot / oscillation tracking
        self.overshoot_threshold = None
        self.overshoot_occurred = False
        self.overshoot_reagent = None
        self.oscillation_count = 0
        self.use_secondary_reagents = False

    def safe_pow10(self, x: float) -> float:
        """Return 10 raised to ``x`` with the exponent limited to [-100, 100]."""
        bounded_exponent = np.clip(x, -100, 100)
        return np.power(10, bounded_exponent)

    def update_exp_ph(self, pH: float) -> None:
        """Record a new pH reading.

        The previous reading moves into ``prev_measured_ph`` (or ``pH``
        itself when there was none), and ``pH`` becomes both the current and
        the last measured value.
        """
        had_reading = self.last_measured_ph is not None
        self.prev_measured_ph = self.last_measured_ph if had_reading else pH
        self.current_ph = self.last_measured_ph = pH

    def get_effective_pka_array(self) -> np.ndarray:
        """Blend each reference pKa with its current estimate.

        The estimate's weight is ``0.2 * (1 - tanh(pKa_std))``, so it is at
        most 0.2 and shrinks as the recorded spread grows.

        Returns:
            Array of length ``num_buffers`` holding
            ``ref + weight * (estimate - ref)`` for every buffer.
        """
        weight_max = 0.2
        k = 1.0

        def blended(i):
            weight = weight_max * (1 - np.tanh(k * self.pKa_std[i]))
            return self.ref_pKa[i] + weight * (self.pKa_list[i] - self.ref_pKa[i])

        return np.array([blended(i) for i in range(self.num_buffers)], dtype=float)

    def compute_required_volume(self) -> float:
        """Estimate the reagent volume that would bring the pH to the target.

        A base reagent is assumed when the current pH is below the target,
        an acid reagent otherwise ("2" variants when secondary reagents are
        active). The volume is the root, searched with ``brentq`` on
        [0, 10], of "pH after adding x" minus the target.

        Returns:
            The root found, or 0.0 if the search raises (for example when
            the bracket contains no sign change).
        """
        n_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        effective_pKa = self.get_effective_pka_array()

        raising_ph = self.current_ph < self.target_ph
        if raising_ph:
            reagent = 'Dilute base 2' if self.use_secondary_reagents else 'Dilute base 1'
        else:
            reagent = 'Dilute acid 2' if self.use_secondary_reagents else 'Dilute acid 1'
        conc = self.reagents[reagent]

        def ph_gap_after(x):
            # pH reached after hypothetically adding volume x, minus the target
            extra_moles = conc * (x / 1000.0)
            base_moles = self.base_added_moles
            acid_moles = self.acid_added_moles
            if raising_ph:
                base_moles = base_moles + extra_moles
            else:
                acid_moles = acid_moles + extra_moles
            volume_after = (TITRATED_VOLUME + self.acid_volume + self.base_volume + x) / 1000.0
            pH_new = self.solve_pH(
                n_analyte / volume_after,
                base_moles / volume_after,
                acid_moles / volume_after,
                effective_pKa,
            )
            return pH_new - self.target_ph

        try:
            return brentq(ph_gap_after, 0, 10)
        except Exception:
            # Best effort: report "no volume" when no root can be found
            return 0.0

    def f(self, pH: float, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Charge-balance residual at a trial pH.

        Computes ``[H+] + c_Na - [OH-] - (acid anion charge) - c_HCl`` with
        ``[H+] = 10**(-pH)`` and ``[OH-] = 1e-14 / [H+]``. ``solve_pH``
        searches for the pH where this is zero.
        """
        Kw = 1e-14
        hydrogen = 10 ** (-pH)
        hydroxide = Kw / hydrogen
        return hydrogen + c_Na - hydroxide - calculate_acid_anion_charge(c_A, hydrogen, pKa_list) - c_HCl

    def solve_pH(self, c_A: float, c_Na: float, c_HCl: float, pKa_list: list) -> float:
        """Find the pH in [0, 14] where the charge balance ``f`` vanishes.

        Plain bisection with at most 100 halvings. The search stops early
        once ``|f(mid)|`` drops below 1e-10.

        Args:
            c_A: Total analyte concentration.
            c_Na: Concentration contributed by added base.
            c_HCl: Concentration contributed by added acid.
            pKa_list: pKa values forwarded to ``f``.

        Returns:
            The midpoint that met the tolerance, or the midpoint of the final
            bracket if the iteration limit is reached.
        """
        lo, hi = 0.0, 14.0
        # f(lo) only changes when lo moves, so it is carried along instead of
        # being re-evaluated on every iteration.
        f_lo = self.f(lo, c_A, c_Na, c_HCl, pKa_list)
        for _ in range(100):
            mid = (lo + hi) / 2.0
            f_mid = self.f(mid, c_A, c_Na, c_HCl, pKa_list)
            if abs(f_mid) < 1e-10:
                return mid
            if f_lo * f_mid < 0:
                hi = mid
            else:
                lo, f_lo = mid, f_mid
        return (lo + hi) / 2.0

    def step(self, action: tuple, mode: str = 'Simulate') -> tuple:
        """Apply one reagent addition, update the pH and compute the reward.

        Args:
            action: ``(reagent, volume)``; ``reagent`` must be a key of
                ``self.reagents``.
            mode: When equal to ``'Simulate'`` the new pH is obtained from
                ``recalc_ph``; otherwise the stored pH is left as it is.

        Returns:
            ``(current_ph, reward, done, info)`` with ``info`` always ``{}``.
            The reward is -100 for an addition in the wrong direction, a pH
            outside [0, 14] (or NaN), or any exception. The call is a no-op
            with reward 0 when the episode is already done.
        """
        if self.done:
            return self.current_ph, 0, self.done, {}
        try:
            reagent, volume = action
            volume = float(volume)
            self.last_action_volume = volume
            added_moles = self.reagents[reagent] * (volume / 1000.0)
            self.previous_ph = self.current_ph
            self.previous_total_volume = self.total_volume
            self.total_volume += volume

            # Classify the reagent once. The name is lower-cased, so the
            # search words must be lower case too; capitalised words never
            # match and the addition would not be recorded at all.
            reagent_lc = reagent.lower()
            is_acid = 'acid' in reagent_lc
            is_base = (not is_acid) and 'base' in reagent_lc
            if is_acid:
                self.acid_added_moles += added_moles
                self.acid_volume += volume
                self.last_acid_added = added_moles
            elif is_base:
                self.base_added_moles += added_moles
                self.base_volume += volume
                self.last_base_added = added_moles
            is_base = 'base' in reagent_lc

            # Adding base above the target or acid below it ends the attempt.
            # NOTE(review): these returns report done=True without setting
            # self.done, and the addition above has already been recorded.
            current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph
            if current_for_direction > self.target_ph and is_base:
                return self.current_ph, -100, True, {}
            if current_for_direction < self.target_ph and is_acid:
                return self.current_ph, -100, True, {}

            if mode == 'Simulate':
                new_pH = self.recalc_ph()
                self.update_exp_ph(new_pH)

            # Oscillation around the target at the smallest dose: after three
            # occurrences switch to the secondary reagents.
            if self.previous_ph is not None and abs(volume - self.min_addition_volume) < 1e-6:
                if (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0 and abs(self.current_ph - self.previous_ph) > 0.1:
                    self.oscillation_count += 1
                    logging.info("The pH oscillation at the minimum dripping amount was detected, the cumulative number of times:%d", self.oscillation_count)
                    if self.oscillation_count >= 3:
                        self.use_secondary_reagents = True
                        logging.info("When the continuous shaking threshold is reached, switch to secondary reagent titration.")
            self.steps_taken += 1
            if np.isnan(self.current_ph) or self.current_ph < 0 or self.current_ph > 14:
                self.done = True
                return self.current_ph, -100, self.done, {}

            # Heuristic ideal volume, same formula as in select_best_action
            error = abs(self.current_ph - self.target_ph)
            ph_change = abs(self.current_ph - (self.prev_measured_ph if self.prev_measured_ph is not None else self.current_ph))
            bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
            uncertainties = [prior['P the'].std() for prior in self.priors]
            avg_uncertainty = np.mean(uncertainties)
            max_uncertainty = 1.0
            uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
            buffer_mean = np.mean(self.buffer_total_moles)
            ref_buffer = 0.5
            buffering_factor = 1.0 + 0.1 * (buffer_mean - ref_buffer)
            buffering_factor = np.clip(buffering_factor, 0.95, 1.05)
            alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
            required_vol = self.compute_required_volume()
            combined_value = error + 0.1 * required_vol
            min_vol = self.min_addition_volume
            max_vol = max(self.addition_volumes)
            ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

            # Reward: improvement and closeness, minus dose and time costs
            current_error = abs(self.current_ph - self.target_ph)
            error_reward = -current_error
            improvement = abs(self.previous_ph - self.target_ph) - current_error
            lambda_cost = 0.05
            action_cost = lambda_cost * ((volume - ideal_volume) ** 2)
            time_penalty = self.steps_taken * 0.1
            reward = improvement + error_reward - action_cost - time_penalty

            # Direction penalty, evaluated against the latest measured pH
            dynamic_direction_penalty = self.direction_penalty_factor * (0.5 if current_error > 2.0 else 1.0)
            if self.last_measured_ph is not None:
                current_for_direction = self.last_measured_ph
            if self.target_ph > current_for_direction and is_acid:
                penalty = dynamic_direction_penalty * (self.target_ph - current_for_direction) / max(self.target_ph, 1)
                reward -= penalty
            if self.target_ph < current_for_direction and is_base:
                penalty = dynamic_direction_penalty * (current_for_direction - self.target_ph) / max((14 - self.target_ph), 1)
                reward -= penalty

            # Overshoot tracking: remember the reagent and tighten the volume cap
            if self.steps_taken > 0:
                if is_acid:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_acid_added
                elif is_base:
                    reagent_conc = self.reagents[reagent]
                    last_added = self.last_base_added
                else:
                    reagent_conc = 1.0
                    last_added = 0.0
                overshoot_flag, new_thresh = self.detect_overshoot(self.previous_ph, self.current_ph,
                                                                   self.target_ph, reagent,
                                                                   last_added, reagent_conc,
                                                                   self.min_addition_volume)
                if overshoot_flag:
                    self.overshoot_occurred = True
                    self.overshoot_reagent = reagent
                    if new_thresh is not None:
                        if self.overshoot_threshold is None or new_thresh < self.overshoot_threshold:
                            self.overshoot_threshold = new_thresh
            if current_error < 0.1 or self.steps_taken >= self.max_steps:
                self.done = True
            return self.current_ph, reward, self.done, {}
        except Exception as e:
            logging.error("An exception occurs when executing step:%s", e)
            self.done = True
            return self.current_ph, -100, self.done, {}

    def detect_overshoot(self, prev_ph, current_ph, target_ph, reagent, last_added_moles, reagent_conc, min_addition):
        """Report whether the latest addition overshot the target pH.

        Overshoot means the pH changed sides relative to ``target_ph`` or the
        absolute distance to the target became larger than it was at
        ``prev_ph``.

        Args:
            prev_ph: pH before the addition.
            current_ph: pH after the addition.
            target_ph: Desired pH.
            reagent: Reagent name (currently unused).
            last_added_moles: Moles delivered by the addition.
            reagent_conc: Concentration used to turn the moles into a volume.
            min_addition: Smallest threshold that may be returned.

        Returns:
            ``(True, threshold)`` on overshoot, where ``threshold`` is half of
            the implied addition volume but at least ``min_addition``;
            ``(False, None)`` otherwise.
        """
        gap_before = prev_ph - target_ph
        gap_after = current_ph - target_ph
        changed_side = gap_before * gap_after < 0
        got_worse = abs(gap_after) > abs(gap_before)
        if changed_side or got_worse:
            # moles * 1000 / concentration gives back the dosed volume
            dosed_volume = last_added_moles * 1000.0 / reagent_conc
            return True, max(dosed_volume / 2, min_addition)
        return False, None

    def env_copy(self) -> 'PHAdjustmentEnv':
        """Return a snapshot of this environment that can be modified independently.

        A new ``PHAdjustmentEnv`` is constructed and the current state is
        transferred onto it. NumPy arrays and the reagent/volume/action
        containers are copied, and each prior dict is copied so that
        replacing a distribution on the copy leaves the original untouched.
        Overshoot flags, the last added amounts, the minimum addition volume
        and the initial pH are carried over as well, so the copy behaves like
        the original in ``step`` and ``select_best_action``.

        NOTE(review): constructing the new instance draws from NumPy's global
        random generator (see ``__init__``), so every copy advances the
        random stream.

        Returns:
            The populated copy.
        """
        env_copied = PHAdjustmentEnv()

        # Values that are transferred as-is (numbers, flags, strings, None)
        plain_attrs = (
            'total_volume', 'previous_total_volume',
            'acid_added_moles', 'base_added_moles',
            'acid_volume', 'base_volume',
            'last_acid_added', 'last_base_added',
            'initial_ph', 'current_ph', 'previous_ph', 'target_ph',
            'steps_taken', 'done', 'num_buffers',
            'epsilon', 'direction_penalty_factor', 'tol', 'max_steps',
            'min_addition_volume',
            'vol_ideal_factor', 'ph_rate_threshold', 'ph_rate_bonus_factor',
            'last_measured_ph', 'prev_measured_ph',
            'overshoot_threshold', 'overshoot_occurred', 'overshoot_reagent',
            'oscillation_count', 'use_secondary_reagents',
            'acid_type', 'acid_params',
        )
        for attr in plain_attrs:
            setattr(env_copied, attr, getattr(self, attr))
        # Only exists once step() has run at least once
        if hasattr(self, 'last_action_volume'):
            env_copied.last_action_volume = self.last_action_volume

        # Arrays are duplicated so in-place updates on the copy stay local
        for attr in ('pKa_list', 'buffer_total_moles', 'ref_pKa', 'pKa_std'):
            setattr(env_copied, attr, np.copy(getattr(self, attr)))

        # Containers: new outer objects; prior dicts are copied one by one
        env_copied.priors = [dict(prior) for prior in self.priors]
        env_copied.reagents = self.reagents.copy()
        env_copied.addition_volumes = self.addition_volumes.copy()
        env_copied.action_space = self.action_space.copy()
        return env_copied

    def recalc_ph(self) -> float:
        """Recompute the pH from everything added so far.

        The titrated volume plus the added acid and base volumes (divided by
        1000) forms the total volume. Analyte, base and acid amounts are
        converted to concentrations over it and handed to ``solve_pH`` with
        the effective pKa values.

        Returns:
            The pH reported by ``solve_pH``.
        """
        volume_now = (TITRATED_VOLUME + self.acid_volume + self.base_volume) / 1000.0
        moles_analyte = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
        pKa_values = self.get_effective_pka_array().tolist()
        return self.solve_pH(
            moles_analyte / volume_now,
            self.base_added_moles / volume_now,
            self.acid_added_moles / volume_now,
            pKa_values,
        )

    def select_best_action(self) -> tuple:
        """Choose the ``(reagent, volume)`` action closest to a heuristic ideal dose.

        Reagent choice:
          * after an overshoot, the opposite kind (acid vs. base) of the
            reagent that overshot;
          * otherwise base when the reference pH is below the target, acid
            when it is not;
          * the "2" reagents when ``use_secondary_reagents`` is set, the "1"
            reagents otherwise.

        The overshoot flags are cleared here only on the primary-reagent path.
        If ``overshoot_threshold`` is set, volumes above it are dropped unless
        that would leave nothing to choose from.

        Returns:
            ``(best_action, self.done)``.

        Raises:
            ValueError: If no action matches the selected reagent.
        """
        current_for_direction = self.last_measured_ph if self.last_measured_ph is not None else self.current_ph

        # Names are compared after .lower(), so the search strings must be
        # lower case too; capitalised needles could never match.
        suffix = '2' if self.use_secondary_reagents else '1'
        if self.overshoot_occurred:
            kind = 'acid' if 'base' in self.overshoot_reagent.lower() else 'base'
            needle = f'{kind} {suffix}'
            if not self.use_secondary_reagents:
                self.overshoot_occurred = False
                self.overshoot_reagent = None
        else:
            kind = 'base' if current_for_direction < self.target_ph else 'acid'
            needle = f'dilute {kind} {suffix}'
        allowed_reagent = {r for r in self.reagents if needle in r.lower()}

        candidate_actions = [a for a in self.action_space if a[0] in allowed_reagent]
        if not candidate_actions:
            raise ValueError(f"No action available for a reagent matching {needle!r}")
        if self.overshoot_threshold is not None:
            capped = [a for a in candidate_actions if a[1] <= self.overshoot_threshold]
            if capped:  # keep the unfiltered list when the cap removes everything
                candidate_actions = capped

        # Heuristic ideal volume: grows with the pH error and the estimated
        # required volume, squashed into [min_vol, max_vol] by tanh.
        error = abs(current_for_direction - self.target_ph)
        reference_ph = self.prev_measured_ph if self.prev_measured_ph is not None else current_for_direction
        ph_change = abs(current_for_direction - reference_ph)
        bonus_factor = 1 + self.ph_rate_bonus_factor * (1 - min(ph_change, self.ph_rate_threshold) / self.ph_rate_threshold)
        uncertainties = [prior['P the'].std() for prior in self.priors]
        avg_uncertainty = np.mean(uncertainties)
        max_uncertainty = 1.0
        uncertainty_factor = 1 - 0.1 * min(avg_uncertainty / max_uncertainty, 1)
        buffer_mean = np.mean(self.buffer_total_moles)
        ref_buffer = 0.5
        buffering_factor = np.clip(1.0 + 0.1 * (buffer_mean - ref_buffer), 0.95, 1.05)
        alpha = self.vol_ideal_factor * bonus_factor * uncertainty_factor * buffering_factor
        required_vol = self.compute_required_volume()
        combined_value = error + 0.1 * required_vol
        min_vol = self.min_addition_volume
        max_vol = max(self.addition_volumes)
        ideal_volume = min_vol + (max_vol - min_vol) * np.tanh(alpha * combined_value)

        best_action = min(candidate_actions, key=lambda a: abs(a[1] - ideal_volume))
        return best_action, self.done

    def sample_parameters(self) -> tuple:
        """Draw one value from every prior's 'P the' and 'Total moles' distribution.

        Returns:
            Two lists in prior order: the sampled 'P the' values and the
            sampled 'Total moles' values.
        """
        # Tuple elements are evaluated left to right, so the draws stay
        # interleaved per prior ('P the' first, then 'Total moles').
        draws = [(entry['P the'].rvs(), entry['Total moles'].rvs()) for entry in self.priors]
        return [pair[0] for pair in draws], [pair[1] for pair in draws]

    def predict_ph(self, action: tuple, sampled_pKa, sampled_total_moles) -> float:
        """Return the pH a copy of this environment computes for sampled parameters.

        The copy gets ``sampled_pKa`` and ``sampled_total_moles`` as NumPy
        arrays and its ``recalc_ph()`` result is returned. ``action`` is
        accepted but not used here.
        """
        scratch = self.env_copy()
        scratch.pKa_list = np.array(sampled_pKa)
        scratch.buffer_total_moles = np.array(sampled_total_moles)
        return scratch.recalc_ph()

    def update_posteriors(self, action: tuple, observed_ph: float) -> None:
        """Refit every Gaussian prior from a likelihood-weighted resample.

        1000 parameter sets are drawn with ``sample_parameters``; each is
        weighted by a normal density (scale 0.01) of ``observed_ph`` around the
        pH that ``predict_ph`` returns for it. The sets are resampled with
        replacement according to those weights, and for each buffer the mean
        and standard deviation (+1e-3) of the resampled values become the new
        'P the' and 'Total moles' priors. ``pKa_list``, ``buffer_total_moles``
        and ``pKa_std`` are overwritten with the new means / spread.
        """
        particle_count = 1000
        drawn = []
        likelihoods = []
        for _ in range(particle_count):
            pka_draw, moles_draw = self.sample_parameters()
            expected_ph = self.predict_ph(action, pka_draw, moles_draw)
            likelihoods.append(norm.pdf(observed_ph, loc=expected_ph, scale=0.01))
            drawn.append((pka_draw, moles_draw))

        # The 1e-10 floor keeps the normalisation finite when every density is 0.
        probabilities = np.array(likelihoods) + 1e-10
        probabilities /= np.sum(probabilities)
        picks = np.random.choice(range(particle_count), size=particle_count, p=probabilities)
        resampled = [drawn[pick] for pick in picks]

        # Summarise first, then write back, so all statistics come from one resample.
        summaries = []
        for b in range(self.num_buffers):
            pka_values = np.array([pair[0][b] for pair in resampled])
            moles_values = np.array([pair[1][b] for pair in resampled])
            pka_mu = np.mean(pka_values)
            pka_sigma = np.std(pka_values) + 1e-3
            moles_mu = np.mean(moles_values)
            moles_sigma = np.std(moles_values) + 1e-3
            summaries.append((pka_mu, pka_sigma, moles_mu, moles_sigma))

        for b, (pka_mu, pka_sigma, moles_mu, moles_sigma) in enumerate(summaries):
            self.priors[b]['P the'] = norm(loc=pka_mu, scale=pka_sigma)
            self.priors[b]['Total moles'] = norm(loc=moles_mu, scale=moles_sigma)
            self.pKa_list[b] = pka_mu
            self.buffer_total_moles[b] = moles_mu
            self.pKa_std[b] = pka_sigma

    def suggest_next_action(self, action: tuple, observed_ph: float) -> tuple:
        """Advance one simulated step and propose the following action.

        When ``observed_ph`` is within 0.1 of the target the episode is marked
        done and ``(None, True)`` is returned. Otherwise ``action`` is applied
        through ``step(..., mode='Simulate')``, the posteriors are updated with
        the pH that step returned, and the next action from
        ``select_best_action`` is returned with the step's done flag.

        NOTE(review): the posterior update receives the simulated pH, not
        ``observed_ph`` - confirm this is intended.
        """
        if not (abs(observed_ph - self.target_ph) < 0.1):
            simulated_ph, _reward, finished, _info = self.step(action, mode='Simulate')
            self.update_posteriors(action, simulated_ph)
            proposal, _ = self.select_best_action()
            return proposal, finished
        self.done = True
        return None, True

# Helper function: compute the pH of the analyte before anything is added
def calculate_initial_ph(acid_type: str, acid_params, env: PHAdjustmentEnv):
    """Return the starting pH of the untitrated analyte via ``env.solve_pH``.

    Args:
        acid_type: "Monoprotic", "Diprotic" or "Triprotic".
        acid_params: A single pKa for a monoprotic acid, otherwise the sequence
            of pKa values, which is passed through unchanged.
        env: Object providing ``solve_pH(c_A, c_Na, c_HCl, pKa)``.

    Returns:
        The solved pH, or 7.0 for any other ``acid_type``.
    """
    if acid_type not in ("Monoprotic", "Diprotic", "Triprotic"):
        return 7.0
    # A monoprotic acid carries a bare number; the solver receives a list.
    pKa = [acid_params] if acid_type == "Monoprotic" else acid_params
    sample_litres = TITRATED_VOLUME / 1000.0
    analyte_moles = (TITRATED_VOLUME / 1000.0) * ANALYTE_CONC
    # Nothing has been added yet, so the Na+ and HCl concentrations are zero.
    return env.solve_pH(analyte_moles / sample_litres, 0.0, 0.0, pKa)

# Write-through stream that duplicates output to several file-like objects
class Tee:
    """Fan ``write`` and ``flush`` calls out to every wrapped stream."""

    def __init__(self, *files):
        # Kept as the tuple of positional arguments, in the order given.
        self.files = files

    def write(self, obj):
        """Write ``obj`` to each stream, flushing each one right after its write."""
        for stream in self.files:
            stream.write(obj)
            stream.flush()

    def flush(self):
        """Flush each wrapped stream in order."""
        for stream in self.files:
            stream.flush()

# Modified main program
def main():
    """Run 3000 simulated pH-adjustment experiments and report the success rate.

    Every experiment draws a random target pH and a random mono-, di- or
    triprotic acid, initialises a ``PHAdjustmentEnv`` at the acid's computed
    initial pH, and repeatedly applies ``select_best_action`` / ``step`` until
    the environment reports ``done``. An experiment counts as a success when
    the final pH is within 0.1 of the target.

    All printed output is mirrored to 'Experiment output.txt' through ``Tee``;
    ``sys.stdout`` is restored and the file closed even if an experiment raises.
    """
    # Imported here so the entry point does not rely on another cell having
    # brought these modules into scope.
    import random
    import sys

    # Fixed random seed to ensure reproducibility (consistent with Code 1)
    seed = 555
    random.seed(seed)
    np.random.seed(seed)

    num_experiments = 3000
    success_count = 0
    steps_success = []

    original_stdout = sys.stdout
    # The write mode is lower-case 'w'; open() rejects 'W' with ValueError.
    with open('Experiment output.txt', 'w', encoding='Utf 8') as output_file:
        sys.stdout = Tee(original_stdout, output_file)
        try:
            for exp in range(num_experiments):
                # Randomly generate the target pH
                target_ph = round(random.uniform(2, 11), 2)
                env = PHAdjustmentEnv()

                # Randomly select acid type and pKa value(s), then derive the initial pH
                env.acid_type = random.choice(["Monoprotic", "Diprotic", "Triprotic"])
                if env.acid_type == "Monoprotic":
                    env.acid_params = random.uniform(2, 6)
                elif env.acid_type == "Diprotic":
                    pKa1 = random.uniform(2, 4)
                    pKa2 = random.uniform(4, 7)
                    env.acid_params = [pKa1, pKa2]
                elif env.acid_type == "Triprotic":
                    pKa1 = random.uniform(2, 4)
                    pKa2 = random.uniform(4, 6)
                    pKa3 = random.uniform(6, 8)
                    env.acid_params = [pKa1, pKa2, pKa3]
                initial_ph = calculate_initial_ph(env.acid_type, env.acid_params, env)

                # Initialize environment
                env.initialize(init_pH=initial_ph, target_pH=target_ph, max_steps=MAX_STEPS, initial_volume=TITRATED_VOLUME)

                print(f"==== Experiment {exp+1} Start ====")
                initial_state = env.get_state()
                print(f"Initial state: {np.round(initial_state, 2)}")
                print(f"Acid type: {env.acid_type}, parameters: {env.acid_params}, target pH: {env.target_ph}")
                print("Status Action Reagent pair:")

                trace = []
                action, _ = env.select_best_action()
                while not env.done:
                    env.step(action, mode='Simulate')
                    # Record the post-step state, the action and its reagent name
                    trace.append((env.get_state(), action, action[0]))
                    action, _ = env.select_best_action()

                for i, (s, a, reagent) in enumerate(trace, start=1):
                    s_formatted = np.round(s, 2)
                    print(f"  Step {i}: State = {s_formatted}, Action = {a[1]:.4f}, Reagent = {reagent}")
                print(f"The experiment is over, the number of shared steps is: {env.steps_taken}, final pH: {env.current_ph:.2f}\n")

                if abs(env.current_ph - env.target_ph) < 0.1:
                    success_count += 1
                    steps_success.append(env.steps_taken)

            success_rate = success_count / num_experiments * 100
            avg_steps = np.mean(steps_success) if steps_success else 0
            print("Total number of experiments: {}, number of successful experiments: {}, success rate: {:.2f}%, average number of steps for successful experiments: {:.2f}".format(
                num_experiments, success_count, success_rate, avg_steps))
        finally:
            # Restore stdout before the with-block closes the file
            sys.stdout = original_stdout

# Entry point: __name__ is '__main__' when this code runs as the top-level
# script (or notebook cell); any other spelling never matches.
if __name__ == '__main__':
    main()

In [None]:
# Shap analysis

In [None]:
import torch
import torch.nn as nn
import numpy as np
import shap
import matplotlib.pyplot as plt
import re

# Seed NumPy and PyTorch with one shared value so runs are repeatable
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy model
class DiscreteVolumeRegressor(nn.Module):
    """MLP producing one logit per volume on a fixed grid.

    ``discrete_volumes`` is the grid ``min_volume + i * step`` rounded to two
    decimals. ``net`` maps an ``input_dim``-wide input through two 256-unit
    ReLU hidden layers to ``len(discrete_volumes)`` logits.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        grid_size = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(grid_size)]
        hidden = 256
        # Layer order fixes the parameter names (net.0, net.2, net.4).
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, len(self.discrete_volumes)),
        )

    def forward(self, x):
        """Return the raw logits over ``discrete_volumes`` for ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one grid volume from the categorical distribution of the logits.

        The sampled index is read with ``.item()``, so ``x`` must yield exactly
        one sample.
        """
        chosen = torch.distributions.Categorical(logits=self.forward(x)).sample()
        return self.discrete_volumes[chosen.item()]

# Extract state vectors from a txt log file (at most `max_vectors` of them)
def extract_state_vectors(file_path, max_vectors=10):
    """Parse ``State = [a b c d e]`` entries out of a text log.

    Args:
        file_path: Path of the text file to scan.
        max_vectors: Upper bound on the number of vectors returned.

    Returns:
        A list of length-5 float32 NumPy arrays in file order; empty when
        nothing matches or ``max_vectors`` is not positive.
    """
    state_pattern = r"State = \[\s*([\d.-]+)\s+([\d.-]+)\s+([\d.-]+)\s+([\d.+-]+)\s+([\d.+-]+)\]"

    # The read mode is lower-case 'r'; open() rejects 'R' with ValueError.
    with open(file_path, 'r', encoding='Utf 8') as f:
        content = f.read()

    state_vectors = []
    for match in re.finditer(state_pattern, content):
        if len(state_vectors) >= max_vectors:
            break
        state = [float(token) for token in match.groups()]
        state_vectors.append(np.array(state, dtype=np.float32))

    return state_vectors

# SHAP analysis function
def analyze_shap_importance(state_vectors, model_path="Volume regressor best big discrete new1 trained 1 test.pth", nsamples=500):
    """Estimate per-feature SHAP importance of the volume model and plot it.

    Args:
        state_vectors: Sequence of length-5 state vectors. They are used both
            as the KernelExplainer background data and as the samples explained.
        model_path: Checkpoint loaded into a fresh ``DiscreteVolumeRegressor``.
        nsamples: Forwarded to ``explainer.shap_values``.

    Returns:
        ``(avg, normalized)`` dicts keyed by feature name (mean absolute SHAP
        value and its share of the total), or ``(None, None)`` when the
        checkpoint cannot be loaded.

    Side effects: prints the scores, saves "Shap importance.png" and
    "Shap summary.png", and shows both figures.
    """
    # Initialize model
    model = DiscreteVolumeRegressor()
    try:
        # The device string must be lower-case 'cpu'; 'Cpu' is rejected and the
        # broad except below would then report every load as failed.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("Loading the pre-trained model successfully.")
    except Exception as e:
        print("Failed to load model:", e)
        return None, None

    model.eval()

    # Wrap the model as a NumPy-in / NumPy-out function for SHAP
    def model_predict(inputs):
        # NOTE(review): sample_action draws from a distribution, so the
        # explained function is stochastic - confirm that is intended.
        inputs_tensor = torch.tensor(inputs, dtype=torch.float32)
        with torch.no_grad():
            outputs = []
            for i in range(inputs.shape[0]):
                volume = model.sample_action(inputs_tensor[i].unsqueeze(0))
                outputs.append(volume)
        return np.array(outputs)

    # Convert to NumPy array
    state_vectors_np = np.array(state_vectors, dtype=np.float32)

    # Using KernelExplainer
    explainer = shap.KernelExplainer(model_predict, state_vectors_np)
    shap_values = explainer.shap_values(state_vectors_np, nsamples=nsamples)

    # Feature name
    feature_names = ['Current ph', 'Target ph', 'P h delta', 'Error', 'Last action volume']

    # Mean absolute SHAP value per feature, and its share of the total
    avg_shap = np.abs(shap_values).mean(axis=0)
    total_shap = avg_shap.sum()
    normalized_shap = avg_shap / total_shap if total_shap > 0 else avg_shap

    # Print results
    print("\nAverage shap value (absolute contribution, m l):")
    for name, score in zip(feature_names, avg_shap):
        print(f"{name}: {score:.4f}")
    print("\nNormalized shap value (proportion):")
    for name, score in zip(feature_names, normalized_shap):
        print(f"{name}: {score:.4f}")

    # Visualization: Bar Chart
    plt.figure(figsize=(8, 6))
    plt.bar(feature_names, normalized_shap)
    plt.xlabel('feature')
    plt.ylabel('Normalized SHAP value')
    # Report the number of vectors actually analysed instead of a fixed count
    plt.title(f'SHAP feature importance analysis (first {len(state_vectors_np)} state vectors)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("Shap importance.png")
    plt.show()

    # Visualization: SHAP Summary Plot
    shap.summary_plot(shap_values, state_vectors_np, feature_names=feature_names, show=False)
    plt.savefig("Shap summary.png")
    plt.show()

    return dict(zip(feature_names, avg_shap)), dict(zip(feature_names, normalized_shap))

# main program
# __name__ is "__main__" for the top-level script / notebook cell; any other
# spelling never matches and the analysis would silently not run.
if __name__ == "__main__":
    # Read txt file and extract up to 500 state vectors
    file_path = "Test output2 modified.txt"
    state_vectors = extract_state_vectors(file_path, max_vectors=500)

    print(f"\nExtracted to {len(state_vectors)} a state vector. ")

    # Run SHAP analysis
    if state_vectors:
        avg_shap, normalized_shap = analyze_shap_importance(state_vectors)
    else:
        print("The state vector has not been extracted and cannot be analyzed.")

In [None]:
# Shap analysis, error takes the absolute value

In [None]:
import torch
import torch.nn as nn
import numpy as np
import shap
import matplotlib.pyplot as plt
import re

# Seed NumPy and PyTorch with one shared value so runs are repeatable
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy model
class DiscreteVolumeRegressor(nn.Module):
    """MLP producing one logit per volume on a fixed grid.

    ``discrete_volumes`` is the grid ``min_volume + i * step`` rounded to two
    decimals. ``net`` maps an ``input_dim``-wide input through two 256-unit
    ReLU hidden layers to ``len(discrete_volumes)`` logits.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        grid_size = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(grid_size)]
        hidden = 256
        # Layer order fixes the parameter names (net.0, net.2, net.4).
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, len(self.discrete_volumes)),
        )

    def forward(self, x):
        """Return the raw logits over ``discrete_volumes`` for ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one grid volume from the categorical distribution of the logits.

        The sampled index is read with ``.item()``, so ``x`` must yield exactly
        one sample.
        """
        chosen = torch.distributions.Categorical(logits=self.forward(x)).sample()
        return self.discrete_volumes[chosen.item()]

# Extract state vectors from a txt log file (at most `max_vectors` of them)
def extract_state_vectors(file_path, max_vectors=500):
    """Parse ``State = [a b c d e]`` entries out of a text log.

    Args:
        file_path: Path of the text file to scan.
        max_vectors: Upper bound on the number of vectors returned.

    Returns:
        A list of length-5 float32 NumPy arrays in file order; empty when
        nothing matches or ``max_vectors`` is not positive.
    """
    state_pattern = r"State = \[\s*([\d.-]+)\s+([\d.-]+)\s+([\d.-]+)\s+([\d.+-]+)\s+([\d.+-]+)\]"

    # The read mode is lower-case 'r'; open() rejects 'R' with ValueError.
    with open(file_path, 'r', encoding='Utf 8') as f:
        content = f.read()

    state_vectors = []
    for match in re.finditer(state_pattern, content):
        if len(state_vectors) >= max_vectors:
            break
        state = [float(token) for token in match.groups()]
        state_vectors.append(np.array(state, dtype=np.float32))

    return state_vectors

# SHAP analysis function (the error feature is analysed as an absolute value)
def analyze_shap_importance(state_vectors, model_path="Volume regressor best big discrete new1 trained 1 test.pth", nsamples=500):
    """Estimate per-feature SHAP importance of the volume model and plot it.

    Column 3 of the state matrix (the error feature) is replaced by its
    absolute value before the analysis.

    Args:
        state_vectors: Sequence of length-5 state vectors. They are used both
            as the KernelExplainer background data and as the samples explained.
        model_path: Checkpoint loaded into a fresh ``DiscreteVolumeRegressor``.
        nsamples: Forwarded to ``explainer.shap_values``.

    Returns:
        ``(avg, normalized)`` dicts keyed by feature name (mean absolute SHAP
        value and its share of the total), or ``(None, None)`` when the
        checkpoint cannot be loaded.

    Side effects: prints the scores, saves "Shap importance.png" and
    "Shap summary.png", and shows both figures.
    """
    # Initialize model
    model = DiscreteVolumeRegressor()
    try:
        # The device string must be lower-case 'cpu'; 'Cpu' is rejected and the
        # broad except below would then report every load as failed.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("Loading the pre-trained model successfully.")
    except Exception as e:
        print("Failed to load model:", e)
        return None, None

    model.eval()

    # Wrap the model as a NumPy-in / NumPy-out function for SHAP
    def model_predict(inputs):
        # NOTE(review): sample_action draws from a distribution, so the
        # explained function is stochastic - confirm that is intended.
        inputs_tensor = torch.tensor(inputs, dtype=torch.float32)
        with torch.no_grad():
            outputs = []
            for i in range(inputs.shape[0]):
                volume = model.sample_action(inputs_tensor[i].unsqueeze(0))
                outputs.append(volume)
        return np.array(outputs)

    # Convert to a fresh NumPy array and take the absolute value of the error column
    state_vectors_np = np.array(state_vectors, dtype=np.float32)
    state_vectors_np[:, 3] = np.abs(state_vectors_np[:, 3])

    # Using KernelExplainer
    explainer = shap.KernelExplainer(model_predict, state_vectors_np)
    shap_values = explainer.shap_values(state_vectors_np, nsamples=nsamples)

    # Feature name
    feature_names = ['Current ph', 'Target ph', 'P h delta', 'Error', 'Last action volume']

    # Mean absolute SHAP value per feature, and its share of the total
    avg_shap = np.abs(shap_values).mean(axis=0)
    total_shap = avg_shap.sum()
    normalized_shap = avg_shap / total_shap if total_shap > 0 else avg_shap

    # Print results
    print("\nAverage shap value (absolute contribution, m l):")
    for name, score in zip(feature_names, avg_shap):
        print(f"{name}: {score:.4f}")
    print("\nNormalized shap value (proportion):")
    for name, score in zip(feature_names, normalized_shap):
        print(f"{name}: {score:.4f}")

    # Visualization: Bar Chart
    plt.figure(figsize=(8, 6))
    plt.bar(feature_names, normalized_shap)
    plt.xlabel('feature')
    plt.ylabel('Normalized SHAP value')
    # Report the number of vectors actually analysed instead of a fixed count
    plt.title(f'SHAP feature importance analysis (first {len(state_vectors_np)} state vectors)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("Shap importance.png")
    plt.show()

    # Visualization: SHAP Summary Plot
    shap.summary_plot(shap_values, state_vectors_np, feature_names=feature_names, show=False)
    plt.savefig("Shap summary.png")
    plt.show()

    return dict(zip(feature_names, avg_shap)), dict(zip(feature_names, normalized_shap))

# main program
# __name__ is "__main__" for the top-level script / notebook cell; any other
# spelling never matches and the analysis would silently not run.
if __name__ == "__main__":
    # Read txt file and extract up to 500 state vectors
    file_path = "Test output2 modified.txt"
    state_vectors = extract_state_vectors(file_path, max_vectors=500)

    print(f"\nExtracted to {len(state_vectors)} a state vector. ")

    # Run SHAP analysis
    if state_vectors:
        avg_shap, normalized_shap = analyze_shap_importance(state_vectors)
    else:
        print("The state vector has not been extracted and cannot be analyzed.")

In [None]:
# correlation analysis

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import re

# Seed NumPy and PyTorch with one shared value so runs are repeatable
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy model
class DiscreteVolumeRegressor(nn.Module):
    """MLP producing one logit per volume on a fixed grid.

    ``discrete_volumes`` is the grid ``min_volume + i * step`` rounded to two
    decimals. ``net`` maps an ``input_dim``-wide input through two 256-unit
    ReLU hidden layers to ``len(discrete_volumes)`` logits.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        grid_size = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(grid_size)]
        hidden = 256
        # Layer order fixes the parameter names (net.0, net.2, net.4).
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, len(self.discrete_volumes)),
        )

    def forward(self, x):
        """Return the raw logits over ``discrete_volumes`` for ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one grid volume from the categorical distribution of the logits.

        The sampled index is read with ``.item()``, so ``x`` must yield exactly
        one sample.
        """
        chosen = torch.distributions.Categorical(logits=self.forward(x)).sample()
        return self.discrete_volumes[chosen.item()]

# Extract state vectors from a txt log file (at most `max_vectors` of them)
def extract_state_vectors(file_path, max_vectors=100):
    """Parse ``State = [a b c d e]`` entries out of a text log.

    Args:
        file_path: Path of the text file to scan.
        max_vectors: Upper bound on the number of vectors returned.

    Returns:
        A list of length-5 float32 NumPy arrays in file order; empty when
        nothing matches or ``max_vectors`` is not positive.
    """
    state_pattern = r"State = \[\s*([\d.-]+)\s+([\d.-]+)\s+([\d.-]+)\s+([\d.+-]+)\s+([\d.+-]+)\]"

    # The read mode is lower-case 'r'; open() rejects 'R' with ValueError.
    with open(file_path, 'r', encoding='Utf 8') as f:
        content = f.read()

    state_vectors = []
    for match in re.finditer(state_pattern, content):
        if len(state_vectors) >= max_vectors:
            break
        state = [float(token) for token in match.groups()]
        state_vectors.append(np.array(state, dtype=np.float32))

    return state_vectors

# Correlation analysis function
def correlation_analysis(file_path, model_path="Volume regressor best big discrete new1 trained 1 test.pth", max_vectors=10000):
    """Correlate each state feature with the volume the model samples for it.

    Args:
        file_path: Text log passed to ``extract_state_vectors``.
        model_path: Checkpoint loaded into a fresh ``DiscreteVolumeRegressor``.
        max_vectors: Upper bound on the number of state vectors analysed.

    Returns:
        ``(pearson, spearman)`` dicts mapping feature name to
        ``(coefficient, p_value)``, or ``(None, None)`` when no state vector is
        found or the checkpoint cannot be loaded.

    Side effects: prints the coefficients, saves "Correlation bar.png" and
    "Correlation heatmap.png", and shows both figures.
    """
    # Extract state vector
    state_vectors = extract_state_vectors(file_path, max_vectors)
    if not state_vectors:
        print("The state vector has not been extracted and cannot be analyzed.")
        return None, None

    print(f"\nExtracted to {len(state_vectors)} a state vector. ")

    # Initialize model
    model = DiscreteVolumeRegressor()
    try:
        # The device string must be lower-case 'cpu'; 'Cpu' is rejected and the
        # broad except below would then report every load as failed.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("Loading the pre-trained model successfully.")
    except Exception as e:
        print("Failed to load model:", e)
        return None, None

    model.eval()

    # Get one sampled volume per state vector
    state_vectors_np = np.array(state_vectors, dtype=np.float32)
    volumes = []
    with torch.no_grad():
        for state in state_vectors_np:
            volume = model.sample_action(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            volumes.append(volume)
    volumes = np.array(volumes)

    # Feature name
    feature_names = ['Current ph', 'Target ph', 'P h delta', 'Error', 'Last action volume']

    # Calculate correlation coefficient
    pearson_corrs = {}
    spearman_corrs = {}
    for i, name in enumerate(feature_names):
        pearson_corr, pearson_p = pearsonr(state_vectors_np[:, i], volumes)
        spearman_corr, spearman_p = spearmanr(state_vectors_np[:, i], volumes)
        pearson_corrs[name] = (pearson_corr, pearson_p)
        spearman_corrs[name] = (spearman_corr, spearman_p)

    # Print results
    print("\nPearson correlation coefficient (correlation coefficient, p-value):")
    for name, (corr, p) in pearson_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")
    print("\nSpearman correlation coefficient (correlation coefficient, p-value):")
    for name, (corr, p) in spearman_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")

    # Visualization: Bar Chart
    plt.figure(figsize=(10, 6))
    corr_df = pd.DataFrame({
        'Pearson': [corr for corr, _ in pearson_corrs.values()],
        'Spearman': [corr for corr, _ in spearman_corrs.values()]
    }, index=feature_names)
    # pandas plot kinds are lower-case; 'Bar' is rejected as an invalid kind.
    corr_df.plot(kind='bar', ax=plt.gca())
    plt.xlabel('feature')
    plt.ylabel('Correlation coefficient')
    # Report the number of vectors actually analysed instead of a fixed count
    plt.title(f'Correlation analysis between state vectors and predicted volumes (first {len(state_vectors)} state vectors)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("Correlation bar.png")
    plt.show()

    # Visualization: Heatmap
    plt.figure(figsize=(8, 6))
    corr_matrix = np.zeros((len(feature_names), 2))
    for i, name in enumerate(feature_names):
        corr_matrix[i, 0] = pearson_corrs[name][0]
        corr_matrix[i, 1] = spearman_corrs[name][0]
    # Matplotlib colormap names are case-sensitive; the registered name is 'coolwarm'.
    sns.heatmap(corr_matrix, annot=True, xticklabels=['Pearson', 'Spearman'], yticklabels=feature_names, cmap='coolwarm', vmin=-1, vmax=1)
    plt.title('Correlation heat map')
    plt.tight_layout()
    plt.savefig("Correlation heatmap.png")
    plt.show()

    return pearson_corrs, spearman_corrs

# main program
# __name__ is "__main__" for the top-level script / notebook cell; any other
# spelling never matches and the analysis would silently not run.
if __name__ == "__main__":
    file_path = "Test output2 modified.txt"
    pearson_corrs, spearman_corrs = correlation_analysis(file_path)

In [None]:
# Correlation analysis, error takes the absolute value

In [None]:
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
import re

# Seed NumPy and PyTorch with one shared value so runs are repeatable
seed = 555
np.random.seed(seed)
torch.manual_seed(seed)

# Discrete-action policy model
class DiscreteVolumeRegressor(nn.Module):
    """MLP producing one logit per volume on a fixed grid.

    ``discrete_volumes`` is the grid ``min_volume + i * step`` rounded to two
    decimals. ``net`` maps an ``input_dim``-wide input through two 256-unit
    ReLU hidden layers to ``len(discrete_volumes)`` logits.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super().__init__()
        grid_size = int((max_volume - min_volume) / step) + 1
        self.discrete_volumes = [round(min_volume + k * step, 2) for k in range(grid_size)]
        hidden = 256
        # Layer order fixes the parameter names (net.0, net.2, net.4).
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Linear(hidden, len(self.discrete_volumes)),
        )

    def forward(self, x):
        """Return the raw logits over ``discrete_volumes`` for ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample one grid volume from the categorical distribution of the logits.

        The sampled index is read with ``.item()``, so ``x`` must yield exactly
        one sample.
        """
        chosen = torch.distributions.Categorical(logits=self.forward(x)).sample()
        return self.discrete_volumes[chosen.item()]

# Extract state vectors from a txt log file (at most `max_vectors` of them)
def extract_state_vectors(file_path, max_vectors=10000):
    """Parse ``State = [a b c d e]`` entries out of a text log.

    Args:
        file_path: Path of the text file to scan.
        max_vectors: Upper bound on the number of vectors returned.

    Returns:
        A list of length-5 float32 NumPy arrays in file order; empty when
        nothing matches or ``max_vectors`` is not positive.
    """
    state_pattern = r"State = \[\s*([\d.-]+)\s+([\d.-]+)\s+([\d.-]+)\s+([\d.+-]+)\s+([\d.+-]+)\]"

    # The read mode is lower-case 'r'; open() rejects 'R' with ValueError.
    with open(file_path, 'r', encoding='Utf 8') as f:
        content = f.read()

    state_vectors = []
    for match in re.finditer(state_pattern, content):
        if len(state_vectors) >= max_vectors:
            break
        state = [float(token) for token in match.groups()]
        state_vectors.append(np.array(state, dtype=np.float32))

    return state_vectors

# Correlation analysis function (the error feature is correlated as an absolute value)
def correlation_analysis(file_path, model_path="Volume regressor best big discrete new1 trained 1 test.pth", max_vectors=20000):
    """Correlate each state feature with the volume the model samples for it.

    The 'Error' feature is replaced by its absolute value for the correlation
    only; the model still receives the signed state vectors.

    Args:
        file_path: Text log passed to ``extract_state_vectors``.
        model_path: Checkpoint loaded into a fresh ``DiscreteVolumeRegressor``.
        max_vectors: Upper bound on the number of state vectors analysed.

    Returns:
        ``(pearson, spearman)`` dicts mapping feature name to
        ``(coefficient, p_value)``, or ``(None, None)`` when no state vector is
        found or the checkpoint cannot be loaded.

    Side effects: prints the coefficients, saves "Correlation bar.png" and
    "Correlation heatmap.png", and shows both figures.
    """
    # Extract state vector
    state_vectors = extract_state_vectors(file_path, max_vectors)
    if not state_vectors:
        print("The state vector has not been extracted and cannot be analyzed.")
        return None, None

    print(f"\nExtracted to {len(state_vectors)} a state vector. ")

    # Initialize model
    model = DiscreteVolumeRegressor()
    try:
        # The device string must be lower-case 'cpu'; 'Cpu' is rejected and the
        # broad except below would then report every load as failed.
        model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
        print("Loading the pre-trained model successfully.")
    except Exception as e:
        print("Failed to load model:", e)
        return None, None

    model.eval()

    # Get one sampled volume per state vector
    state_vectors_np = np.array(state_vectors, dtype=np.float32)
    volumes = []
    with torch.no_grad():
        for state in state_vectors_np:
            volume = model.sample_action(torch.tensor(state, dtype=torch.float32).unsqueeze(0))
            volumes.append(volume)
    volumes = np.array(volumes)

    # Feature name
    feature_names = ['Current ph', 'Target ph', 'P h delta', 'Error', 'Last action volume']

    # Calculate correlation coefficient
    pearson_corrs = {}
    spearman_corrs = {}
    for i, name in enumerate(feature_names):
        # Take the absolute value of error
        if name == 'Error':
            feature_values = np.abs(state_vectors_np[:, i])
        else:
            feature_values = state_vectors_np[:, i]
        pearson_corr, pearson_p = pearsonr(feature_values, volumes)
        spearman_corr, spearman_p = spearmanr(feature_values, volumes)
        pearson_corrs[name] = (pearson_corr, pearson_p)
        spearman_corrs[name] = (spearman_corr, spearman_p)

    # Print results
    print("\nPearson correlation coefficient (correlation coefficient, p-value):")
    for name, (corr, p) in pearson_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")
    print("\nSpearman correlation coefficient (correlation coefficient, p-value):")
    for name, (corr, p) in spearman_corrs.items():
        print(f"{name}: {corr:.4f} (p={p:.4f})")

    # Visualization: Bar Chart
    plt.figure(figsize=(10, 6))
    corr_df = pd.DataFrame({
        'Pearson': [corr for corr, _ in pearson_corrs.values()],
        'Spearman': [corr for corr, _ in spearman_corrs.values()]
    }, index=feature_names)
    # pandas plot kinds are lower-case; 'Bar' is rejected as an invalid kind.
    corr_df.plot(kind='bar', ax=plt.gca())
    plt.xlabel('feature')
    plt.ylabel('Correlation coefficient')
    # Report the number of vectors actually analysed instead of a fixed count
    plt.title(f'Correlation analysis between state vectors and predicted volumes (first {len(state_vectors)} state vectors)')
    plt.xticks(rotation=45)
    plt.tight_layout()
    plt.savefig("Correlation bar.png")
    plt.show()

    # Visualization: Heatmap
    plt.figure(figsize=(8, 6))
    corr_matrix = np.zeros((len(feature_names), 2))
    for i, name in enumerate(feature_names):
        corr_matrix[i, 0] = pearson_corrs[name][0]
        corr_matrix[i, 1] = spearman_corrs[name][0]
    # Matplotlib colormap names are case-sensitive; the registered name is 'coolwarm'.
    sns.heatmap(corr_matrix, annot=True, xticklabels=['Pearson', 'Spearman'], yticklabels=feature_names, cmap='coolwarm', vmin=-1, vmax=1)
    plt.title('Correlation heat map')
    plt.tight_layout()
    plt.savefig("Correlation heatmap.png")
    plt.show()

    return pearson_corrs, spearman_corrs

# main program
# __name__ is "__main__" for the top-level script / notebook cell; any other
# spelling never matches and the analysis would silently not run.
if __name__ == "__main__":
    file_path = "Test output2 modified.txt"
    pearson_corrs, spearman_corrs = correlation_analysis(file_path)

In [None]:
# ablation experiment

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import math
from scipy.optimize import fsolve
import json
import os

##############################################
# Reproducibility: fixed seeds and deterministic cuDNN settings
##############################################
seed = 255
# cuDNN flags are independent of the seeding calls below.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
# Seed the Python, NumPy and PyTorch generators with the same value.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

##############################################
# Global constants
##############################################
# Sample being titrated
INITIAL_ACID_VOL = 11.0     # initial volume of the weak acid (mL)
# Titrant
TITRANT_CONC = 0.1          # titrant concentration (0.1 M)
# Episode limits
MAX_STEPS = 50              # maximum number of steps
SUCCESS_THRESHOLD = 0.1     # pH error threshold

##############################################
# pH calculation functions: monoprotic acid
##############################################
def f_monoprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Return the charge-balance residual of a monoprotic acid at ``pH``.

    The residual is ``[H+] + c_Na - [OH-] - c_A * alpha - c_HCl``, where
    ``alpha = 10**(pH - pKa) / (1 + 10**(pH - pKa))`` and ``[OH-]`` is
    ``1e-14 / [H+]``.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    ratio = 10 ** (pH - pKa)
    dissociated = ratio / (1 + ratio)
    return hydronium + c_Na - hydroxide - c_A * dissociated - c_HCl

def solve_pH_monoprotic_balance(c_A: float, c_Na: float, c_HCl: float, pKa: float) -> float:
    """Bisect ``f_monoprotic`` on pH in [0, 14] and return the located pH.

    Runs at most 100 halvings and stops early once ``|f(mid)| < 1e-10``.
    No check is made that the residual changes sign on [0, 14]; without a
    sign change the search simply drifts toward the upper bound.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so it is carried along instead of being
    # re-evaluated on every iteration.
    f_lo = f_monoprotic(lo, c_A, c_Na, c_HCl, pKa)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_monoprotic(mid, c_A, c_Na, c_HCl, pKa)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_monoprotic(base_added_mL: float, acid_added_mL: float, pKa: float) -> float:
    """Return the pH (rounded to 2 decimals) after adding base and acid titrant.

    The sample is ``INITIAL_ACID_VOL`` mL of 0.1 M monoprotic acid; both the
    added base and the added acid have concentration ``TITRANT_CONC``. All
    three amounts are converted to concentrations in the combined volume and
    solved with ``solve_pH_monoprotic_balance``.
    """
    sample_mL = INITIAL_ACID_VOL
    total_L = (sample_mL + base_added_mL + acid_added_mL) / 1000.0
    analyte_mol = sample_mL / 1000.0 * 0.1  # analyte concentration is fixed at 0.1 M
    sodium_mol = base_added_mL / 1000.0 * TITRANT_CONC
    chloride_mol = acid_added_mL / 1000.0 * TITRANT_CONC
    balanced_pH = solve_pH_monoprotic_balance(
        analyte_mol / total_L,
        sodium_mol / total_L,
        chloride_mol / total_L,
        pKa,
    )
    return round(balanced_pH, 2)

##############################################
# pH calculation functions: diprotic acid
##############################################
def f_diprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Return the charge-balance residual of a diprotic acid at ``pH``.

    The singly and doubly deprotonated fractions come from
    ``10**(pH - pKa1)`` and ``10**(2*pH - pKa1 - pKa2)`` (exponents clipped to
    [-100, 100]) over ``1 +`` their sum; the residual is
    ``[H+] + c_Na - [OH-] - c_A * (alpha1 + 2*alpha2) - c_HCl``.
    """
    hydronium = 10 ** (-pH)
    hydroxide = 1e-14 / hydronium
    exponents = (pH - pKa1, 2 * pH - pKa1 - pKa2)
    t1, t2 = [np.power(10, np.clip(e, -100, 100)) for e in exponents]
    denom = 1 + t1 + t2
    anion_charge = c_A * (t1 / denom + 2 * (t2 / denom))
    return hydronium + c_Na - hydroxide - anion_charge - c_HCl

def solve_pH_diprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float) -> float:
    """Bisect ``f_diprotic`` on pH in [0, 14] and return the located pH.

    Runs at most 100 halvings and stops early once ``|f(mid)| < 1e-10``.
    No check is made that the residual changes sign on [0, 14]; without a
    sign change the search simply drifts toward the upper bound.
    """
    lo, hi = 0.0, 14.0
    # f(lo) only changes when lo moves, so it is carried along instead of being
    # re-evaluated on every iteration.
    f_lo = f_diprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_diprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_diprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float) -> float:
    """Return the pH (2 decimals) of the diprotic analyte after the given additions.

    The analyte is ``INITIAL_ACID_VOL`` mL at a fixed 0.1 mol/L; both the added
    base and the added strong acid are at ``TITRANT_CONC``. Amounts are turned
    into concentrations in the combined volume and handed to
    ``solve_pH_diprotic``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant of the analyte.
        pKa2: Second dissociation constant of the analyte.

    Returns:
        The solved pH rounded to two decimal places.
    """
    analyte_conc = 0.1  # mol/L, hard-coded analyte concentration

    # Moles of each component (volumes are in mL, hence the /1000).
    moles_analyte = INITIAL_ACID_VOL / 1000.0 * analyte_conc
    moles_base = base_added_mL / 1000.0 * TITRANT_CONC
    moles_strong_acid = acid_added_mL / 1000.0 * TITRANT_CONC

    total_volume_L = (INITIAL_ACID_VOL + base_added_mL + acid_added_mL) / 1000.0

    solved = solve_pH_diprotic(
        moles_analyte / total_volume_L,
        moles_base / total_volume_L,
        moles_strong_acid / total_volume_L,
        pKa1,
        pKa2,
    )
    return round(solved, 2)

##############################################
# pH Calculation Function: Tribasic Acid
##############################################
def f_triprotic(pH: float, c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Charge-balance residual of a triprotic acid solution at a trial pH.

    Evaluates ``[H+] + c_Na - [OH-] - (acid anion charge) - c_HCl`` with
    Kw = 1e-14. A root of this expression in pH is the solution pH.

    Args:
        pH: Trial pH.
        c_A: Analytical concentration of the triprotic acid.
        c_Na: Concentration of added strong-base cation.
        c_HCl: Concentration of added strong acid.
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The signed residual of the charge balance.
    """
    hydrogen = 10 ** (-pH)
    hydroxide = 1e-14 / hydrogen

    # Weights of the species that lost one, two and three protons, relative to
    # the fully protonated acid; exponents are clipped to [-100, 100] first.
    lost_one = np.power(10, np.clip(pH - pKa1, -100, 100))
    lost_two = np.power(10, np.clip(2 * pH - pKa1 - pKa2, -100, 100))
    lost_three = np.power(10, np.clip(3 * pH - pKa1 - pKa2 - pKa3, -100, 100))
    partition = 1 + lost_one + lost_two + lost_three

    # Each species contributes its fraction times its charge number.
    mean_charge = lost_one / partition + 2 * (lost_two / partition) + 3 * (lost_three / partition)
    anion_charge = c_A * mean_charge
    return hydrogen + c_Na - hydroxide - anion_charge - c_HCl

def solve_pH_triprotic(c_A: float, c_Na: float, c_HCl: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Solve the triprotic charge balance for pH by bisection on [0, 14].

    The interval is halved at most 100 times. The search returns early as soon
    as ``|f_triprotic(mid, ...)| < 1e-10``.

    Args:
        c_A: Analytical concentration of the triprotic acid (mol/L).
        c_Na: Concentration of added strong base cation (mol/L).
        c_HCl: Concentration of added strong acid (mol/L).
        pKa1: First dissociation constant.
        pKa2: Second dissociation constant.
        pKa3: Third dissociation constant.

    Returns:
        The pH at which the residual (approximately) vanishes, unrounded.
    """
    lo, hi = 0.0, 14.0
    # f_triprotic is deterministic, so the residual at `lo` is cached and only
    # refreshed when `lo` moves; this halves the number of evaluations.
    f_lo = f_triprotic(lo, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
    for _ in range(100):
        mid = (lo + hi) / 2.0
        f_mid = f_triprotic(mid, c_A, c_Na, c_HCl, pKa1, pKa2, pKa3)
        if abs(f_mid) < 1e-10:
            return mid
        if f_lo * f_mid < 0:
            # Sign change between lo and mid: the root is in the lower half.
            hi = mid
        else:
            lo, f_lo = mid, f_mid
    return (lo + hi) / 2.0

def calculate_pH_triprotic(base_added_mL: float, acid_added_mL: float, pKa1: float, pKa2: float, pKa3: float) -> float:
    """Return the pH (2 decimals) of the triprotic analyte after the given additions.

    The analyte is ``INITIAL_ACID_VOL`` mL at a fixed 0.1 mol/L; both the added
    base and the added strong acid are at ``TITRANT_CONC``. Amounts are turned
    into concentrations in the combined volume and handed to
    ``solve_pH_triprotic``.

    Args:
        base_added_mL: Cumulative volume of strong base added (mL).
        acid_added_mL: Cumulative volume of strong acid added (mL).
        pKa1: First dissociation constant of the analyte.
        pKa2: Second dissociation constant of the analyte.
        pKa3: Third dissociation constant of the analyte.

    Returns:
        The solved pH rounded to two decimal places.
    """
    analyte_conc = 0.1  # mol/L, hard-coded analyte concentration

    # Moles of each component (volumes are in mL, hence the /1000).
    moles_analyte = INITIAL_ACID_VOL / 1000.0 * analyte_conc
    moles_base = base_added_mL / 1000.0 * TITRANT_CONC
    moles_strong_acid = acid_added_mL / 1000.0 * TITRANT_CONC

    total_volume_L = (INITIAL_ACID_VOL + base_added_mL + acid_added_mL) / 1000.0

    solved = solve_pH_triprotic(
        moles_analyte / total_volume_L,
        moles_base / total_volume_L,
        moles_strong_acid / total_volume_L,
        pKa1,
        pKa2,
        pKa3,
    )
    return round(solved, 2)

##############################################
# Reward calculation function (supports ablation experiments)
##############################################
def calculate_reward(previous_ph, current_ph, target_ph, steps_taken, max_steps, reagent, reward_config, SUCCESS_THRESHOLD, prev_overshoot_flag, prev_overshoot_volume, last_action_volume, ablate_component=None):
    """Compute the shaped reward for one titration step and the terminal flag.

    The reward is the sum of several components, each of which can be zeroed
    by passing its name as ``ablate_component``:

    * "Dense reward": reduction of the absolute pH error, scaled by
      ``Dense lambda`` and by ``1 + remaining step ratio``.
    * "Step penalty": constant per-step term.
    * "Overshoot penalty": applied when the pH crossed the target during this
      step; a sigmoid of how far past the target the pH ended up.
    * "Wrong dir penalty": applied when the reagent just dosed pushes the pH
      further away from the target than it now is (base while above the
      target, acid while below it).
    * "Volume penalty" / "Volume bonus": only active on the step following an
      overshoot; they penalise the dosed volume and reward dosing less than
      the overshooting step did.
    * "Terminal bonus": added when the episode ends.

    Args:
        previous_ph: pH before the action.
        current_ph: pH after the action.
        target_ph: Desired pH.
        steps_taken: Number of steps taken so far, including this one.
        max_steps: Step budget of the episode.
        reagent: Name of the reagent dosed; matched case-insensitively for
            the substrings "base" and "acid".
        reward_config: Mapping of component weights; missing keys fall back to
            the defaults written in the body.
        SUCCESS_THRESHOLD: Absolute pH error below which the episode succeeds.
        prev_overshoot_flag: Whether the previous step crossed the target.
        prev_overshoot_volume: Volume dosed on that overshooting step, or None.
        last_action_volume: Volume dosed on this step.
        ablate_component: Name of the single component to disable, or None.

    Returns:
        ``(reward, is_terminal)``. Non-terminal rewards are clipped to
        ``[Reward clip min, Reward clip max]``; terminal rewards are not.
    """
    previous_error = abs(previous_ph - target_ph)
    current_error = abs(current_ph - target_ph)
    remaining_ratio = (max_steps - steps_taken) / max_steps
    dense_lambda = reward_config.get("Dense lambda", 1.0)
    dense_reward = dense_lambda * (previous_error - current_error) * (1 + remaining_ratio) if ablate_component != "Dense reward" else 0
    step_penalty = reward_config.get("Step penalty", -0.005) if ablate_component != "Step penalty" else 0
    overshoot_weight = reward_config.get("Overshoot weight", 0.2)
    overshoot_threshold = reward_config.get("Overshoot threshold", 0.1)

    # A sign change of (pH - target) across the step means the target was crossed.
    if (previous_ph - target_ph) * (current_ph - target_ph) < 0 and max(previous_error, current_error) > overshoot_threshold:
        overshoot_magnitude = abs(current_ph - target_ph)
        overshoot_penalty = -overshoot_weight * (1 / (1 + math.exp(- (overshoot_magnitude - overshoot_threshold)))) if ablate_component != "Overshoot penalty" else 0
    else:
        overshoot_penalty = 0

    wrong_dir_factor = reward_config.get("Wrong dir factor", 1.0)
    wrong_dir_penalty = 0
    # The reagent name is lower-cased, so the needles must be lower-case too;
    # capitalised needles could never match and silently disabled this term.
    if (current_ph > target_ph and 'base' in reagent.lower()) or (current_ph < target_ph and 'acid' in reagent.lower()):
        wrong_dir_penalty = -wrong_dir_factor * abs(current_ph - target_ph) if ablate_component != "Wrong dir penalty" else 0

    volume_penalty = 0
    volume_bonus = 0
    if prev_overshoot_flag and prev_overshoot_volume is not None:
        overshoot_volume_penalty = reward_config.get("Overshoot volume penalty", 0.1)
        volume_penalty = -overshoot_volume_penalty * last_action_volume if ablate_component != "Volume penalty" else 0
        overshoot_volume_bonus = reward_config.get("Overshoot volume bonus", 0.1)
        if last_action_volume < prev_overshoot_volume:
            volume_bonus = overshoot_volume_bonus * (prev_overshoot_volume - last_action_volume) if ablate_component != "Volume bonus" else 0

    raw_reward = dense_reward + step_penalty + overshoot_penalty + wrong_dir_penalty + volume_penalty + volume_bonus

    is_terminal = False
    # NOTE(review): the terminal bonus is also granted when the episode ends
    # only because the step budget ran out -- confirm this is intended.
    if abs(current_ph - target_ph) < SUCCESS_THRESHOLD or steps_taken >= max_steps:
        is_terminal = True
        # Finishing within the first half of the budget doubles the bonus.
        bonus_factor = 2.0 if steps_taken < max_steps * 0.5 else 1.0
        terminal_bonus = reward_config.get("Terminal bonus", 3.0) * bonus_factor if ablate_component != "Terminal bonus" else 0
        raw_reward += terminal_bonus

    if not is_terminal:
        reward_clip_max = reward_config.get("Reward clip max", 4.0)
        reward_clip_min = reward_config.get("Reward clip min", -4.0)
        reward = max(min(raw_reward, reward_clip_max), reward_clip_min)
    else:
        reward = raw_reward

    return reward, is_terminal

##############################################
# pH simulation environment: PHSimEnv
##############################################
class PHSimEnv:
    """Simulated titration environment for the pH-adjustment policy.

    Each episode picks (or is given) an acid type, its pKa value(s) and a
    target pH. An action is a volume in mL; the environment itself decides
    whether that volume is dosed as strong base or strong acid, depending on
    which side of the target the current pH lies.

    The observation is a float32 vector
    ``[current_ph, target_ph, pH_delta, error, last_action_volume]``.
    """

    def __init__(self, initial_acid_vol=11.0, analyte_conc=0.1, titrant_conc=0.1):
        self.initial_acid_vol = initial_acid_vol
        self.analyte_conc = analyte_conc
        self.titrant_conc = titrant_conc
        self.n_acid = self.initial_acid_vol / 1000.0 * self.analyte_conc

        # Weights handed to calculate_reward on every step.
        self.reward_config = {
            "Dense lambda": -0.03,
            "Step penalty": 0,
            "Terminal bonus": 3.9,
            "Overshoot weight": 0.2,
            "Overshoot threshold": 0.1,
            "Wrong dir factor": 1,
            "Reward clip max": 4.1,
            "Reward clip min": -4.1,
            "Overshoot volume penalty": 0.1,
            "Overshoot volume bonus": 0.1,
        }

        # Pools of 30 candidate acids per type; reset() samples from these.
        self.monoprotic_pKa_list = np.random.uniform(2, 6, size=30)

        self.diprotic_pKa_list = []
        for _ in range(30):
            first = random.uniform(2, 4)
            second = random.uniform(4, 7)
            self.diprotic_pKa_list.append((first, second))

        self.triprotic_pKa_list = []
        for _ in range(30):
            first = random.uniform(2, 4)
            second = random.uniform(4, 6)
            third = random.uniform(6, 8)
            self.triprotic_pKa_list.append((first, second, third))

        self.reset()

    def _solve_current_ph(self):
        """Return the pH for the volumes added so far, using the solver of the active acid type."""
        if self.acid_type == 'Monoprotic':
            return calculate_pH_monoprotic(self.base_added_mL, self.acid_added_mL, self.acid_params)
        if self.acid_type == 'Diprotic':
            pKa1, pKa2 = self.acid_params
            return calculate_pH_diprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2)
        pKa1, pKa2, pKa3 = self.acid_params
        return calculate_pH_triprotic(self.base_added_mL, self.acid_added_mL, pKa1, pKa2, pKa3)

    def reset(self, acid_type=None, acid_params=None, target_ph=None):
        """Start a new episode and return the initial observation.

        Any argument left as None is sampled: the acid type uniformly from the
        three types, the pKa value(s) from the matching pool, and the target
        pH uniformly from [2, 11] rounded to two decimals.
        """
        self.acid_type = (
            acid_type
            if acid_type is not None
            else random.choice(['Monoprotic', 'Diprotic', 'Triprotic'])
        )

        if acid_params is not None:
            self.acid_params = acid_params
        elif self.acid_type == 'Monoprotic':
            self.acid_params = float(np.random.choice(self.monoprotic_pKa_list))
        elif self.acid_type == 'Diprotic':
            self.acid_params = random.choice(self.diprotic_pKa_list)
        else:
            self.acid_params = random.choice(self.triprotic_pKa_list)

        self.target_ph = round(random.uniform(2, 11), 2) if target_ph is None else target_ph

        # Per-episode bookkeeping.
        self.steps = 0
        self.base_added_mL = 0.0
        self.acid_added_mL = 0.0
        self.last_action_volume = 0.0
        self.total_volume = self.initial_acid_vol
        self.prev_overshoot_flag = False
        self.prev_overshoot_volume = None

        # Nothing has been added yet, so this is the pH of the bare analyte.
        self.current_ph = self._solve_current_ph()
        self.previous_ph = self.current_ph
        return self._get_state()

    def _get_state(self):
        """Build the observation vector from the current attributes."""
        features = [
            self.current_ph,
            self.target_ph,
            round(self.current_ph - self.previous_ph, 2),
            round(self.current_ph - self.target_ph, 2),
            self.last_action_volume,
        ]
        return np.array(features, dtype=np.float32)

    def step(self, action, ablate_component=None):
        """Dose ``action`` mL and return ``(state, reward, done, info)``.

        Base is dosed when the current pH is below the target, acid otherwise.
        ``info`` carries the chosen reagent under the key 'Reagent'.
        """
        dose = float(action)
        self.last_action_volume = dose
        self.steps += 1

        below_target = self.current_ph < self.target_ph
        if below_target:
            reagent = "Strong base"
            self.base_added_mL += dose
        else:
            reagent = "Strong acid"
            self.acid_added_mL += dose
        self.total_volume = self.initial_acid_vol + self.base_added_mL + self.acid_added_mL

        self.previous_ph = self.current_ph
        self.current_ph = self._solve_current_ph()

        state = self._get_state()
        reward, done = calculate_reward(
            previous_ph=self.previous_ph,
            current_ph=self.current_ph,
            target_ph=self.target_ph,
            steps_taken=self.steps,
            max_steps=MAX_STEPS,
            reagent=reagent,
            reward_config=self.reward_config,
            SUCCESS_THRESHOLD=SUCCESS_THRESHOLD,
            prev_overshoot_flag=self.prev_overshoot_flag,
            prev_overshoot_volume=self.prev_overshoot_volume,
            last_action_volume=self.last_action_volume,
            ablate_component=ablate_component,
        )

        # Remember, for the next step's reward, whether this step crossed the target.
        crossed_target = (self.previous_ph - self.target_ph) * (self.current_ph - self.target_ph) < 0
        self.prev_overshoot_flag = True if crossed_target else False
        self.prev_overshoot_volume = self.last_action_volume if crossed_target else None

        return state, reward, done, {'Reagent': reagent}

##############################################
# Discrete action strategy model: DiscreteVolumeRegressor
##############################################
class DiscreteVolumeRegressor(nn.Module):
    """Policy network that scores a fixed grid of dosing volumes.

    The action space is the list ``discrete_volumes`` (``min_volume``,
    ``min_volume + step``, ... up to ``max_volume``, each rounded to two
    decimals). The network maps a state vector to one logit per volume.

    Args:
        input_dim: Size of the state vector.
        min_volume: Smallest volume on the grid.
        max_volume: Largest volume on the grid.
        step: Spacing between neighbouring volumes.
    """

    def __init__(self, input_dim=5, min_volume=0.01, max_volume=10.0, step=0.01):
        super(DiscreteVolumeRegressor, self).__init__()
        # The division can land just below an integer because of floating-point
        # error (e.g. (0.3 - 0.1) / 0.1 == 1.9999999999999998), which would
        # silently drop max_volume from the grid. A tiny epsilon before
        # truncation keeps the last grid point.
        num_intervals = int((max_volume - min_volume) / step + 1e-9)
        self.discrete_volumes = [round(min_volume + i * step, 2)
                                 for i in range(num_intervals + 1)]
        self.num_actions = len(self.discrete_volumes)
        self.net = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 256),
            nn.ReLU(),
            nn.Linear(256, self.num_actions)
        )

    def forward(self, x):
        """Return the unnormalised logits over ``discrete_volumes`` for ``x``."""
        return self.net(x)

    def sample_action(self, x):
        """Sample a volume from the categorical policy.

        ``x`` must describe a single state: the sampled index is read with
        ``.item()``, which only works for one element.

        Returns:
            ``(volume, log_prob)`` where ``volume`` is a ``(1, 1)`` float32
            tensor and ``log_prob`` is the log-probability of the sampled index.
        """
        logits = self.forward(x)
        if torch.isnan(logits).any():
            # Diagnostic only; building the distribution below is still attempted.
            print("Logits contain NaN:", logits)
        dist = torch.distributions.Categorical(logits=logits)
        action_index = dist.sample()
        log_prob = dist.log_prob(action_index)
        volume = self.discrete_volumes[action_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32), log_prob

    def predict_volume(self, x):
        """Return the highest-scoring volume for a single state as a ``(1, 1)`` float32 tensor."""
        logits = self.forward(x)
        _, predicted_index = torch.max(logits, dim=1)
        volume = self.discrete_volumes[predicted_index.item()]
        return torch.tensor([[volume]], dtype=torch.float32)

##############################################
# Online training: updating policy model using REINFORCE algorithm
##############################################
def train_reinforce(env, policy_model, optimizer, num_episodes=200, gamma=0.99, ablate_component=None):
    """Train ``policy_model`` on ``env`` with episodic REINFORCE and save its weights.

    For every episode the policy is rolled out until the environment reports
    ``done``, discounted returns are computed and normalised, and one gradient
    step (with gradient-norm clipping at 1.0) is taken. Progress is printed
    every 50 episodes. After the last episode the state dict is written to a
    file whose name depends on ``ablate_component``.

    Args:
        env: Environment exposing ``reset()``, ``step(volume, ablate_component=...)``,
            ``target_ph`` and ``current_ph``.
        policy_model: Model exposing ``sample_action(state_tensor)``.
        optimizer: Optimizer over the model's parameters.
        num_episodes: Number of training episodes.
        gamma: Discount factor.
        ablate_component: Reward component forwarded to ``env.step``; also
            selects the output file name.
    """
    for episode in range(num_episodes):
        # ---- roll out one episode ----
        observation = env.reset()
        episode_log_probs = []
        episode_rewards = []
        finished = False
        while not finished:
            obs_tensor = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
            action, log_prob = policy_model.sample_action(obs_tensor)
            observation, reward, finished, _ = env.step(action.item(), ablate_component=ablate_component)
            episode_log_probs.append(log_prob)
            episode_rewards.append(reward)

        # ---- discounted returns, accumulated from the last step backwards ----
        discounted = []
        running = 0
        for r in episode_rewards[::-1]:
            running = r + gamma * running
            discounted.append(running)
        discounted.reverse()
        returns = torch.tensor(discounted, dtype=torch.float32)

        # Standardise; a single-step episode has no spread, so only centre it.
        if returns.numel() > 1:
            returns = (returns - returns.mean()) / (returns.std() + 1e-9)
        else:
            returns = returns - returns.mean()

        # ---- policy-gradient step ----
        loss = sum(-lp * G for lp, G in zip(episode_log_probs, returns))

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(policy_model.parameters(), max_norm=1.0)
        optimizer.step()

        if episode % 50 == 0:
            total_reward = sum(episode_rewards)
            print(f"episode {episode}, Loss: {loss.item():.4f}, Total Reward: {total_reward:.4f}, Target pH: {env.target_ph:.2f}, Final pH: {env.current_ph:.2f}")

    if ablate_component:
        model_path = f"volume regressor ablation no{ablate_component}200ep.pth"
    else:
        model_path = "Volume regressor full 200ep.pth"
    torch.save(policy_model.state_dict(), model_path)
    print(f"The model has been saved to {model_path}")

##############################################
# Test function: run a fixed 200 experiments and count the success rate and average number of steps
##############################################
def test_model(policy_model, env, test_configs, ablate_component=None):
    """Evaluate a policy on fixed configurations and report success statistics.

    Each configuration is an ``(acid_type, acid_params, target_ph)`` tuple that
    is passed to ``env.reset``. The policy is then rolled out, sampling actions
    with ``policy_model.sample_action`` (i.e. stochastically), until the
    environment reports ``done`` or ``MAX_STEPS`` steps were taken. A run counts
    as a success when the final pH is within ``SUCCESS_THRESHOLD`` of the target.

    Args:
        policy_model: Model exposing ``sample_action(state_tensor)``.
        env: Environment exposing ``reset``, ``step``, ``current_ph`` and ``target_ph``.
        test_configs: Sequence of configuration tuples; may be empty.
        ablate_component: Reward component forwarded to ``env.step``; also
            used in the printed header.

    Returns:
        Dict with keys "Success rate" (percent), "Avg steps" (mean steps of the
        successful runs, 0.0 if none), "Success count" and "Total experiments".
    """
    success_count = 0
    success_steps = []

    for acid_type, acid_params, target_ph in test_configs:
        state = env.reset(acid_type=acid_type, acid_params=acid_params, target_ph=target_ph)
        done = False
        steps = 0
        while not done and steps < MAX_STEPS:
            state_tensor = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
            with torch.no_grad():
                action, _ = policy_model.sample_action(state_tensor)
            state, _, done, _ = env.step(action.item(), ablate_component=ablate_component)
            steps += 1

        # The loop above already guarantees steps <= MAX_STEPS.
        if abs(env.current_ph - env.target_ph) < SUCCESS_THRESHOLD:
            success_count += 1
            success_steps.append(steps)

    total = len(test_configs)
    # An empty configuration list yields a 0% rate instead of a ZeroDivisionError.
    success_rate = success_count / total * 100 if total else 0.0
    avg_steps = np.mean(success_steps) if success_steps else 0.0

    print(f"\nTest results ({'No ' + ablate_component if ablate_component else 'Full Reward'}):")
    print(f"Success rate: {success_rate:.2f}% ({success_count}/{total})")
    print(f"Average number of steps for successful experiments: {avg_steps:.2f}")

    return {"Success rate": success_rate, "Avg steps": avg_steps, "Success count": success_count, "Total experiments": total}

##############################################
# Generate fixed 200 test configurations
##############################################
def generate_test_configs(num_configs=200, seed=123):
    """Build a reproducible list of ``(acid_type, acid_params, target_ph)`` test cases.

    Both NumPy's and the standard library's global random generators are
    re-seeded with ``seed``. Three pools of 30 candidate acids are drawn first;
    every configuration then picks an acid type, a member of the matching pool
    and a target pH in [2, 11] rounded to two decimals.

    Args:
        num_configs: Number of configurations to generate.
        seed: Seed applied to ``np.random`` and ``random``.

    Returns:
        A list of ``num_configs`` tuples. ``acid_params`` is a float for
        monoprotic acids and a tuple of 2 or 3 floats otherwise.
    """
    np.random.seed(seed)
    random.seed(seed)

    # Candidate pools (the draw order below determines the generated values).
    mono_pool = np.random.uniform(2, 6, size=30)

    di_pool = []
    for _ in range(30):
        first = random.uniform(2, 4)
        second = random.uniform(4, 7)
        di_pool.append((first, second))

    tri_pool = []
    for _ in range(30):
        first = random.uniform(2, 4)
        second = random.uniform(4, 6)
        third = random.uniform(6, 8)
        tri_pool.append((first, second, third))

    def draw_params(kind):
        # Monoprotic acids use NumPy's generator; the others use `random`.
        if kind == 'Monoprotic':
            return float(np.random.choice(mono_pool))
        return random.choice(di_pool if kind == 'Diprotic' else tri_pool)

    configs = []
    for _ in range(num_configs):
        kind = random.choice(['Monoprotic', 'Diprotic', 'Triprotic'])
        params = draw_params(kind)
        goal_ph = round(random.uniform(2, 11), 2)
        configs.append((kind, params, goal_ph))

    return configs

##############################################
# Main program: run ablation experiment and test
##############################################
# The guard must compare against "__main__"; any other string never matches,
# which turned the whole experiment into dead code.
if __name__ == "__main__":
    # Hyper-parameters shared by the baseline run and every ablation run.
    input_dim = 5
    learning_rate = 1e-4
    gamma = 0.99
    num_episodes = 500
    num_test_experiments = 500

    # Names understood by calculate_reward's `ablate_component` argument.
    reward_components = [
        "Dense reward",
        "Step penalty",
        "Overshoot penalty",
        "Wrong dir penalty",
        "Volume penalty",
        "Volume bonus",
        "Terminal bonus"
    ]

    env = PHSimEnv(initial_acid_vol=INITIAL_ACID_VOL, analyte_conc=0.1, titrant_conc=TITRANT_CONC)
    test_configs = generate_test_configs(num_configs=num_test_experiments, seed=123)

    results = {}

    # First the full reward (component None), then one run per removed component.
    # Every run trains a freshly initialised model and evaluates it on the
    # same fixed test configurations.
    for component in [None] + reward_components:
        if component is None:
            print("\n=== Run training with full reward ===")
            result_key = "Full reward"
        else:
            print(f"\n=== Ablation Experiment: Removed {component} ===")
            result_key = f"no{component}"

        policy_model = DiscreteVolumeRegressor(input_dim=input_dim, min_volume=0.01, max_volume=10.0, step=0.01)
        optimizer = optim.Adam(policy_model.parameters(), lr=learning_rate)
        train_reinforce(env, policy_model, optimizer, num_episodes=num_episodes, gamma=gamma, ablate_component=component)
        results[result_key] = test_model(policy_model, env, test_configs, ablate_component=component)

    # open() only accepts lower-case mode characters; "W" raises ValueError.
    results_path = "Ablation test results.json"
    with open(results_path, "w") as f:
        json.dump(results, f, indent=4)
    # Report the path that was actually written.
    print(f"\nTest results have been saved to {results_path}")

In [None]:
# Test pid

In [None]:
import numpy as np
import csv
import ast
import math
import statistics

# Global parameter configuration for the PID titration baseline
TITRANT_CONC = 0.1  # Concentration (mol/L) used by solve_pH for both the added base and the added acid
INITIAL_ACID_VOL = 11.0  # Initial analyte volume (mL); solve_pH assumes it is at 0.1 mol/L
MAX_STEPS = 50  # Maximum number of dosing steps per experiment before it is counted as failed
SUCCESS_THRESHOLD = 0.1  # An experiment succeeds once |pH - target| <= this value
MIN_VOLUME = 0.01  # Lower bound (mL) on the volume returned by PIDController.get_volume

# ---Physics and Chemistry Engine Module ---

def get_acid_charge_factor(pH, pKas):
    """Return the mean negative charge per acid molecule at the given pH.

    The pKa values are sorted ascending and treated as successive
    dissociation steps. The species that has lost ``i`` protons is weighted by
    ``K1*...*Ki * [H+]**(n - i)``; the result is the weight-averaged ``i``.
    An empty ``pKas`` yields 0.
    """
    hydrogen = 10 ** (-pH)
    ordered_pkas = sorted(pKas)
    n_steps = len(ordered_pkas)

    # weights[i] is proportional to the amount of the i-fold deprotonated species.
    cumulative_k = 1.0
    weights = [cumulative_k * (hydrogen ** n_steps)]
    for lost, pk in enumerate(ordered_pkas, start=1):
        cumulative_k *= 10 ** (-pk)
        weights.append(cumulative_k * (hydrogen ** (n_steps - lost)))

    normaliser = sum(weights)
    charge_weighted = sum(charge * weight for charge, weight in enumerate(weights))
    return charge_weighted / normaliser

def charge_balance_equation(pH, c_A, c_Na, c_HCl, pKas):
    """Residual of the charge balance at a trial pH.

    Computes [H+] + [Na+] - [OH-] - [Cl-] - (negative charge carried by the
    acid), with Kw = 1e-14. The solution pH is where this residual is zero.
    """
    hydrogen = 10 ** (-pH)
    hydroxide = 1e-14 / hydrogen
    # Mean charge per molecule times the acid concentration.
    anion_charge = c_A * get_acid_charge_factor(pH, pKas)
    return hydrogen + c_Na - hydroxide - c_HCl - anion_charge

def solve_pH(base_vol, acid_vol, pKas):
    """Return the pH (2 decimals) of the mixture after adding the given volumes.

    ``base_vol`` and ``acid_vol`` are cumulative volumes in mL of titrant at
    ``TITRANT_CONC``; the analyte is ``INITIAL_ACID_VOL`` mL at 0.1 mol/L. The
    charge balance is solved by 100 bisection steps on pH in [0, 14].
    """
    litres = (INITIAL_ACID_VOL + base_vol + acid_vol) / 1000
    analyte = (INITIAL_ACID_VOL * 0.1 / 1000) / litres
    sodium = (base_vol * TITRANT_CONC / 1000) / litres
    chloride = (acid_vol * TITRANT_CONC / 1000) / litres

    low, high = 0.0, 14.0
    remaining = 100
    while remaining:
        remaining -= 1
        centre = (low + high) / 2
        residual = charge_balance_equation(centre, analyte, sodium, chloride, pKas)
        # A positive residual moves the lower bound up; otherwise the upper
        # bound comes down.
        if residual > 0:
            low = centre
        else:
            high = centre
    return round((low + high) / 2, 2)

# ---Environment and Controller Module ---

class TitrationEnv:
    """Minimal titration simulator driven by dosing volumes.

    Tracks the cumulative base and acid volumes added to the analyte and
    recomputes the pH with ``solve_pH`` after every dose.
    """

    def __init__(self):
        self.pKas = []
        self.target_ph = 0
        self.base_added = 0.0
        self.acid_added = 0.0
        self.current_ph = 0.0
        self.steps = 0

    @staticmethod
    def _normalize_pkas(pKas):
        """Return ``pKas`` as a flat list.

        Lists and tuples are copied element-wise; anything else is treated as
        a single pKa value. Tuples matter because a parsed "(a, b)" literal is
        a tuple, and wrapping it as ``[(a, b)]`` would make the later
        ``10 ** (-pk)`` fail with a TypeError.
        """
        if isinstance(pKas, (list, tuple)):
            return list(pKas)
        return [pKas]

    def reset_state(self, pKas, target_ph):
        """Start a new experiment and return the initial pH.

        Args:
            pKas: A single pKa, or a list/tuple of pKa values.
            target_ph: pH the experiment should reach.
        """
        self.pKas = self._normalize_pkas(pKas)
        self.target_ph = target_ph
        self.base_added = 0.0
        self.acid_added = 0.0
        self.current_ph = solve_pH(0, 0, self.pKas)
        self.steps = 0
        return self.current_ph

    def step(self, volume):
        """Dose ``volume`` mL and return ``(current_ph, reagent, volume, is_overshoot)``.

        Base is added when the current pH is below the target, acid otherwise.
        ``is_overshoot`` is True when the pH moved strictly from one side of
        the target to the other during this step.
        """
        prev_ph = self.current_ph
        # Decision: choose the reagent from the side of the target we are on.
        reagent = "Base" if self.current_ph < self.target_ph else "Acid"

        if reagent == "Base":
            self.base_added += volume
        else:
            self.acid_added += volume

        self.current_ph = solve_pH(self.base_added, self.acid_added, self.pKas)
        self.steps += 1

        # Overshoot: the pH after the dose lies strictly on the other side of the target.
        is_overshoot = (
            (prev_ph < self.target_ph and self.current_ph > self.target_ph)
            or (prev_ph > self.target_ph and self.current_ph < self.target_ph)
        )

        return self.current_ph, reagent, volume, is_overshoot

class PIDController:
    """PID-style dosing rule that turns the pH error into a volume.

    The error is the absolute distance to the target, so the controller only
    decides *how much* to dose, not which reagent.
    """

    def __init__(self, Kp=0.5, Ki=0.05, Kd=0.02):
        self.Kp = Kp
        self.Ki = Ki
        self.Kd = Kd
        self.reset()

    def reset(self):
        """Clear the accumulated integral and the remembered previous error."""
        self.integral = 0
        self.prev_error = None

    def get_volume(self, current_ph, target_ph):
        """Return the next dose, rounded to 3 decimals and never below ``MIN_VOLUME``."""
        distance = abs(target_ph - current_ph)
        self.integral += distance

        # No derivative term on the first call after a reset.
        if self.prev_error is None:
            rate = 0
        else:
            rate = distance - self.prev_error
        self.prev_error = distance

        signal = self.Kp * distance + self.Ki * self.integral + self.Kd * rate

        # Far from the target (> 2 pH units) the dose grows logarithmically
        # with the control signal; closer in it is a small linear fraction.
        dose = 1.2 * math.log1p(signal) if distance > 2.0 else 0.05 * signal
        return max(round(dose, 3), MIN_VOLUME)

# ---Main logic execution ---

def run_all_experiments(csv_path, report_path):
    """Run the PID controller on every experiment in a CSV file and write a text report.

    For each row the environment is reset with the row's pKa value(s) and
    target pH, then the controller doses until the pH is within
    ``SUCCESS_THRESHOLD`` of the target (success) or ``MAX_STEPS`` steps were
    taken (failure). Every step and a final summary are written to
    ``report_path``.

    Args:
        csv_path: Input CSV with the columns 'Experiment', 'Acid type',
            'Acid params' (a Python literal) and 'Target p h'.
        report_path: Path of the text report to create or overwrite.
    """
    env = TitrationEnv()
    pid = PIDController()

    all_results = []
    total_steps_global = 0
    total_overshoots_global = 0

    # Read CSV data. open() only accepts lower-case mode characters, and the
    # csv module expects files opened with newline=''.
    with open(csv_path, 'r', encoding='utf-8', newline='') as f:
        rows = list(csv.DictReader(f))

    # Open the TXT file and prepare to write the results
    with open(report_path, 'w', encoding='utf-8') as out_f:
        out_f.write("Titration experiment operation detailed record report\n")
        out_f.write("="*60 + "\n\n")

        for row in rows:
            # NOTE(review): column names must match the CSV header exactly -- verify against the input file.
            exp_id = row['Experiment']
            acid_type = row['Acid type']
            pKas = ast.literal_eval(row['Acid params'])
            # A "(a, b)" literal parses to a tuple; hand the environment a flat list.
            if isinstance(pKas, tuple):
                pKas = list(pKas)
            target_ph = float(row['Target p h'])

            curr_ph = env.reset_state(pKas, target_ph)
            pid.reset()

            out_f.write(f"Experiment ID: {exp_id} | Acid type: {acid_type} | Target pH: {target_ph}\n")
            out_f.write(f"Initial state: pH = {curr_ph:.2f}\n")

            done = False

            while not done:
                vol = pid.get_volume(curr_ph, target_ph)
                curr_ph, reagent, added_vol, overshot = env.step(vol)

                if overshot:
                    total_overshoots_global += 1

                status_msg = f" steps {env.steps:02d}: Add to {reagent} {added_vol:.3f}mL -> Current pH: {curr_ph:.2f}"
                if overshot:
                    status_msg += " [Overshoot!]"
                out_f.write(status_msg + "\n")

                # Determine completion or failure
                if abs(curr_ph - target_ph) <= SUCCESS_THRESHOLD:
                    success = True
                    done = True
                elif env.steps >= MAX_STEPS:
                    success = False
                    done = True

            total_steps_global += env.steps
            all_results.append({'Success': success, 'Steps': env.steps})

            out_f.write(f"Experimental conclusion: {'success' if success else 'fail'} | Final pH: {curr_ph:.2f} | Total steps: {env.steps}\n")
            out_f.write("-" * 40 + "\n\n")

        # ---Calculate final statistical indicators ---
        # Every ratio falls back to 0 when there is nothing to divide by
        # (empty input file), so the summary is still written.
        total_exps = len(all_results)
        success_count = sum(1 for r in all_results if r['Success'])
        success_rate = (success_count / total_exps) * 100 if total_exps else 0.0

        step_counts = [r['Steps'] for r in all_results]
        avg_steps = statistics.mean(step_counts) if step_counts else 0.0
        # Sample variance needs at least two data points.
        var_steps = statistics.variance(step_counts) if total_exps > 1 else 0
        # Overshoot rate = number of overshoots / total number of dosing steps
        overshoot_rate = (total_overshoots_global / total_steps_global) * 100 if total_steps_global else 0.0

        # Write summary statistics
        out_f.write("\n" + "="*60 + "\n")
        out_f.write("Summary statistics report\n")
        out_f.write("-" * 60 + "\n")
        out_f.write(f"Total number of experiments: {total_exps}\n")
        out_f.write(f"Operation success rate: {success_rate:.2f}%\n")
        out_f.write(f"Average steps (Mean): {avg_steps:.2f}\n")
        out_f.write(f"Step variance (Var): {var_steps:.2f}\n")
        out_f.write(f"Average number of steps±variance: {avg_steps:.2f} ± {var_steps:.2f}\n")
        out_f.write(f"Total number of steps: {total_steps_global}\n")
        out_f.write(f"Total number of overshoots: {total_overshoots_global}\n")
        out_f.write(f"Overall overshoot rate: {overshoot_rate:.2f}% (number of overshoots/total number of steps)\n")
        out_f.write("="*60 + "\n")

    print(f"All experiments completed! Detailed logs have been written to: {report_path}")

# The guard must compare against "__main__"; any other string never matches,
# so the experiments were never started.
if __name__ == "__main__":
    # Configure input and output paths
    input_csv = r'C:\users\Admin\Documents\gitHub\ph_adjust_algorithm\upload\experiment_summary.csv'
    output_txt = r'C:\users\Admin\Documents\gitHub\ph_adjust_algorithm\upload\experiment_report.txt'

    run_all_experiments(input_csv, output_txt)