In [1]:
import rk_solver_cpp
from scipy.integrate import ode
import SundialsPy as SP
import numpy as np
import cantera as ct
import matplotlib.pyplot as plt
import pandas as pd
from typing import Tuple, List, Dict, Any, Optional
import os
from datetime import datetime
import time
from tqdm import tqdm

In [2]:
# Mixture definition. `fuel` is a composition string, `fuel_species` the bare
# species name; `oxidizer` is the oxidizer composition string.
fuel = 'nc12h26:1.0'
oxidizer = 'N2:3.76, O2:1.0'
fuel_species = 'nc12h26'

# Mechanism file. The default is the original local path; set the
# NDODECANE_MECHANISM environment variable to run the notebook elsewhere
# without editing this cell.
mechanism_file = os.environ.get(
    'NDODECANE_MECHANISM',
    '/Users/elotech/Downloads/research_code/large_mechanism/n-dodecane.yaml',
)

# Default relative / absolute integrator tolerances.
rtol = 1e-6
atol = 1e-8

In [3]:
# ============================================================================
# ODE SYSTEM DEFINITION
# ============================================================================

def combustion_rhs(t: float, y: np.ndarray, gas: ct.Solution, pressure: float) -> np.ndarray:
    """Evaluate the time derivative of the reactor state at fixed pressure.

    The state is pushed into ``gas`` (so ``gas`` is mutated on every call)
    and the derivatives are assembled from the properties Cantera reports.

    Args:
        t: Current time (not used by the expressions below)
        y: State vector [T, Y1, Y2, ...]
        gas: Cantera gas object, overwritten with the state in ``y``
        pressure: Pressure imposed together with ``y``

    Returns:
        dydt: Derivative vector [dT/dt, dY1/dt, dY2/dt, ...]
    """
    # Synchronise Cantera with the integrator's current state.
    gas.TPY = y[0], pressure, y[1:]

    density = gas.density_mass
    production = gas.net_production_rates
    heat_capacity = gas.cp_mass
    molar_enthalpies = gas.partial_molar_enthalpies

    # Energy equation: heat released by the reactions changes the temperature.
    temperature_rate = -(np.dot(molar_enthalpies, production) / (density * heat_capacity))

    # Species equations: production rates scaled by molecular weight and density.
    species_rate = production * gas.molecular_weights / density

    return np.hstack([temperature_rate, species_rate])


# ============================================================================
# SOLVER CREATION FUNCTIONS
# ============================================================================

def create_sundials_solver(method: str, y: np.ndarray, t: float, system_size: int, rtol: float, atol: np.ndarray,
                          gas: ct.Solution, pressure: float, table_id: Optional[SP.arkode.ButcherTable] = None) -> Any:
    """Build and initialise a SundialsPy solver for the combustion system.

    Args:
        method: One of 'cvode_bdf', 'cvode_adams', 'arkode_erk', 'arkode_dirk'
        y: Initial state passed to ``solver.initialize``
        t: Initial time passed to ``solver.initialize``
        system_size: Size of the ODE system
        rtol: Relative tolerance
        atol: Absolute tolerance array
        gas: Cantera gas object used by the right-hand side
        pressure: Constant pressure used by the right-hand side
        table_id: Butcher table for the ARKODE methods; when None a per-method
            default is used (ARK548L2SA_ERK_8_4_5 for 'arkode_erk',
            SDIRK_2_1_2 for 'arkode_dirk'). Ignored by the CVODE methods.

    Returns:
        solver: Initialized SUNDIALS solver

    Raises:
        ValueError: If ``method`` is not one of the names above.
    """
    def fresh_rhs():
        # A new closure per call, so each callback slot owns its own function object.
        return lambda t, y: combustion_rhs(t, y, gas, pressure)

    if method in ('cvode_bdf', 'cvode_adams'):
        if method == 'cvode_bdf':
            iteration = SP.cvode.IterationType.NEWTON
        else:
            iteration = SP.cvode.IterationType.FUNCTIONAL
        solver = SP.cvode.CVodeSolver(
            system_size=system_size,
            rhs_fn=fresh_rhs(),
            iter_type=iteration
        )
    elif method == 'arkode_erk':
        table = SP.arkode.ButcherTable.ARK548L2SA_ERK_8_4_5 if table_id is None else table_id
        solver = SP.arkode.ARKodeSolver(
            system_size=system_size,
            explicit_fn=fresh_rhs(),
            implicit_fn=None,
            butcher_table=table
        )
    elif method == 'arkode_dirk':
        # NOTE(review): the same RHS is supplied as both the explicit and the
        # implicit function - confirm against SundialsPy that this does not
        # count the RHS twice for a purely implicit table.
        table = SP.arkode.ButcherTable.SDIRK_2_1_2 if table_id is None else table_id
        solver = SP.arkode.ARKodeSolver(
            system_size=system_size,
            explicit_fn=fresh_rhs(),
            implicit_fn=fresh_rhs(),
            butcher_table=table,
            linsol_type=SP.cvode.LinearSolverType.DENSE
        )
        # Additional references to the callbacks are stored on the solver object.
        solver._py_explicit = fresh_rhs()
        solver._py_implicit = fresh_rhs()
    else:
        raise ValueError(f"Unknown SUNDIALS method: {method}")

    solver.initialize(y, t, rtol, atol)
    return solver

def create_cpp_solver(method: str, t: float, y: np.ndarray, t_end: float, 
                     rtol: float, atol: float, gas: ct.Solution, pressure: float) -> Any:
    """Instantiate a solver from the ``rk_solver_cpp`` extension.

    Args:
        method: Solver method; only 'cpp_rk23' is handled here
        t: Current time
        y: Current state
        t_end: End time
        rtol: Relative tolerance
        atol: Absolute tolerance
        gas: Cantera gas object
        pressure: Constant pressure

    Returns:
        solver: Initialized C++ solver

    Raises:
        ValueError: If ``method`` is anything other than 'cpp_rk23'.
    """
    # Reject unsupported names first so the happy path reads straight through.
    if method != 'cpp_rk23':
        raise ValueError(f"Unknown C++ method: {method}")

    def rhs(t, y):
        return combustion_rhs(t, y, gas, pressure)

    start = float(t)
    state = np.array(y)
    stop = float(t_end)
    return rk_solver_cpp.RK23(rhs, start, state, stop, rtol=rtol, atol=atol)

def create_scipy_solver(method: str, t: float, y: np.ndarray, rtol: float, atol: float,
                       gas: ct.Solution, pressure: float) -> Any:
    """Configure a ``scipy.integrate.ode`` object for the combustion system.

    Args:
        method: Three underscore-separated fields, 'scipy_<integrator>_<method>';
            the second field is passed to ``set_integrator`` as the integrator
            name and the third as its ``method`` option
        t: Current time
        y: Current state
        rtol: Relative tolerance
        atol: Absolute tolerance
        gas: Cantera gas object
        pressure: Constant pressure

    Returns:
        solver: Initialized SciPy solver

    Raises:
        ValueError: If ``method`` does not split into exactly three fields.
    """
    fields = method.split('_')
    if len(fields) != 3:
        raise ValueError(f"Invalid SciPy method format: {method}")
    _, solver_type, method_name = fields

    integrator = ode(lambda t, y: combustion_rhs(t, y, gas, pressure))
    # set_integrator returns the same ode object, so the result is kept in place.
    integrator = integrator.set_integrator(
        solver_type,
        method=method_name,
        rtol=rtol,
        atol=atol,
        nsteps=10000,
    )
    integrator.set_initial_value(y, t)
    return integrator

def create_solver(method: str, gas: ct.Solution, y: np.ndarray, t: float, 
                 rtol: float, atol: float, t_end: Optional[float] = None, pressure: float = ct.one_atm, table_id: Optional[SP.arkode.ButcherTable] = None) -> Any:
    """Dispatch to the solver factory selected by the prefix of ``method``.

    Args:
        method: Solver method string; prefix 'cvode_'/'arkode_', 'cpp_' or 'scipy_'
        gas: Cantera gas object
        y: Current state
        t: Current time
        rtol: Relative tolerance
        atol: Absolute tolerance (scalar or array). The SUNDIALS factory receives
            it expanded to one entry per state variable; the other factories
            receive it unchanged
        t_end: End time (forwarded to the C++ factory only)
        pressure: Constant pressure forwarded to the right-hand side
        table_id: Butcher table forwarded to the SUNDIALS factory

    Returns:
        solver: Initialized solver

    Raises:
        ValueError: If the prefix of ``method`` is not recognised.
    """
    system_size = 1 + gas.n_species

    # Per-component absolute tolerances: a scalar or a one-element array is broadcast.
    if np.isscalar(atol):
        abs_tol = atol * np.ones(system_size)
    else:
        abs_tol = np.asarray(atol)
        if len(abs_tol) == 1:
            abs_tol = abs_tol[0] * np.ones(system_size)

    if method.startswith(('cvode_', 'arkode_')):
        return create_sundials_solver(method, y, t, system_size, rtol, abs_tol, gas, pressure, table_id)
    if method.startswith('cpp_'):
        return create_cpp_solver(method, t, y, t_end, rtol, atol, gas, pressure)
    if method.startswith('scipy_'):
        return create_scipy_solver(method, t, y, rtol, atol, gas, pressure)
    raise ValueError(f"Unknown solver method: {method}")

In [4]:

# ============================================================================
# INTEGRATION FUNCTIONS
# ============================================================================

def integrate_single_step(method: str, gas: ct.Solution, y: np.ndarray, t: float, 
                         timestep: float, rtol: float, atol: float, fuel: str, pressure: float=ct.one_atm, table_id: Optional[SP.arkode.ButcherTable] = None) -> Dict[str, Any]:
    """Integrate one step with the specified method.

    A new solver is created for every call; only the integration itself is
    timed, not the solver construction.

    Args:
        method: Solver method
        gas: Cantera gas object (mutated by the right-hand side and left at the
            returned state on success)
        y: Current state
        t: Current time
        timestep: Time step size
        rtol: Relative tolerance
        atol: Absolute tolerance
        fuel: Species name whose mass fraction is reported
        pressure: Constant pressure used by the right-hand side
        table_id: Optional Butcher table forwarded to the solver factory

    Returns:
        result: Dictionary with keys 'success', 't', 'y', 'cpu_time',
            'fuel_mass_fraction', 'error', 'message', 'timed_out' and
            'previous_state'. On failure 't' and 'y' are the inputs and
            'error' is infinity.
    """
    t_end = t + timestep
    previous_state = y.copy()

    def fuel_mass_fraction() -> float:
        # Single dictionary build; a missing key is reported as 0.0.
        return gas.mass_fraction_dict().get(fuel, 0.0)

    try:
        # Create solver
        solver = create_solver(method, gas, y, t, rtol, atol, t_end, pressure=pressure, table_id=table_id)

        # perf_counter is monotonic and high-resolution, which matters because
        # a single step can be far shorter than the resolution of time.time().
        start_time = time.perf_counter()

        if method.startswith('cpp_'):
            result = rk_solver_cpp.solve_ivp(solver, np.array(t_end))
            new_y = result['y'][-1]
            # ensure that new_y is not empty
            if len(new_y) == 0:
                print("new_y is empty")
                print(result)
        elif method.startswith('scipy_'):
            solver.integrate(t_end)
            new_y = solver.y
        else:  # SUNDIALS
            new_y = solver.solve_single(t_end)

        cpu_time = time.perf_counter() - start_time

        # The RHS callback leaves `gas` at whatever state the solver evaluated
        # last, which need not be the accepted solution. Synchronise it with
        # the returned state before reading the fuel mass fraction.
        if len(new_y) > 0:
            gas.TPY = new_y[0], pressure, new_y[1:]

        return {
            'success': True,
            't': t_end,
            'y': new_y,
            'cpu_time': cpu_time,
            'fuel_mass_fraction': fuel_mass_fraction(),
            'error': 0.0,
            'message': 'Success',
            'timed_out': False,
            'previous_state': previous_state
        }

    except Exception as e:
        # Best effort by design: report the failure and hand back the old state
        # so the caller can decide to stop.
        print(f"Step {t} failed: {e}")
        return {
            'success': False,
            't': t,
            'y': previous_state,
            'fuel_mass_fraction': fuel_mass_fraction(),
            'cpu_time': 0.0,
            'error': float('inf'),
            'message': str(e),
            'timed_out': False,
            'previous_state': previous_state
        }

def run_integration_experiment(method: str, gas: ct.Solution, y0: np.ndarray, 
                             t0: float, end_time: float, timestep: float,
                             rtol: float, atol: float, species_to_track: List[str],
                             fuel: str, pressure: float=ct.one_atm,
                             time_limit: float = 300.0, table_id: Optional[SP.arkode.ButcherTable] = None) -> Dict[str, Any]:
    """Run a complete integration experiment with the specified method.

    The interval [t0, end_time) is covered in fixed increments of ``timestep``;
    each increment is handled by ``integrate_single_step``. The loop stops
    early on a failed step, an empty returned state, or when ``time_limit`` is
    exceeded.

    Args:
        method: Solver method to test
        gas: Cantera gas object
        y0: Initial state
        t0: Start time
        end_time: End time
        timestep: Time step size
        rtol: Relative tolerance
        atol: Absolute tolerance
        species_to_track: List of species to monitor
        fuel: Fuel name
        pressure: Constant pressure forwarded to each step
        time_limit: Maximum allowed wall clock time in seconds (default 300s)
        table_id: Optional Butcher table forwarded to each step

    Returns:
        results: Dictionary with complete integration results. 'times',
            'temperatures' and each species profile include the initial point;
            'cpu_times' and 'fuel_mass_fractions' have one entry per completed
            step.
    """
    # Species positions in the state vector never change, so resolve them once
    # instead of once per species per step.
    state_index = {spec: gas.species_index(spec) + 1 for spec in species_to_track}

    # Initialize tracking arrays
    times = [t0]
    temperatures = [y0[0]]
    species_profiles = {spec: [y0[state_index[spec]]] for spec in species_to_track}
    cpu_times = []
    fuel_mass_fractions = []

    # Integration loop
    t = t0
    y = y0.copy()
    step_count = 0
    running_cpu_time = 0.0  # display-only running total; avoids re-summing the list each step
    start_time = time.time()

    # The bar advances by `timestep`, so its total is the integrated interval.
    bar = tqdm(total=end_time - t0, desc=f"Running {method}-{str(table_id)} with rtol={rtol} and atol={atol}")
    while t < end_time:
        bar.update(timestep)
        # Check if time limit exceeded
        if time.time() - start_time > time_limit:
            print(f"Time limit of {time_limit}s exceeded after {step_count} steps")
            break

        result = integrate_single_step(method, gas, y, t, timestep, rtol, atol, fuel, pressure=pressure, table_id=table_id)

        if not result['success']:
            print(f"Step {step_count} failed: {result['message']}")
            break

        # An empty state cannot be recorded; stop before counting or timing the
        # step so every output array stays consistent with `steps`.
        if len(result['y']) == 0:
            print(f"Step {step_count + 1} failed: y is empty")
            print(result)
            print(result['y'])
            break

        # Update state
        y = result['y']
        t = result['t']
        cpu_times.append(result['cpu_time'])
        running_cpu_time += result['cpu_time']
        step_count += 1
        fuel_mass_fractions.append(result['fuel_mass_fraction'])

        # Record data
        times.append(t)
        temperatures.append(y[0])
        for spec in species_to_track:
            species_profiles[spec].append(y[state_index[spec]])

        bar.set_postfix({
            'step': f"{step_count}",
            'temperature': f"{y[0]:.1f}K",
            'cpu_time': f"{cpu_times[-1]:.2e}s",
            'total_cpu_time': f"{running_cpu_time:.2e}s"
        })
    bar.close()
    total_wall_time = time.time() - start_time
    return {
        'method': method,
        'phi': gas.equivalence_ratio,
        'rtol': rtol,
        'atol': atol,
        'times': np.array(times),
        'fuel_mass_fractions': np.array(fuel_mass_fractions),
        'temperatures': np.array(temperatures),
        'species_profiles': species_profiles,
        'cpu_times': np.array(cpu_times),
        'total_cpu_time': np.sum(cpu_times),
        'total_wall_time': total_wall_time,
        'steps': step_count,
        'success': step_count > 0,
        'timed_out': total_wall_time > time_limit
    }


In [5]:
def setup_combustion_chemistry_with_data(mechanism: str,temperature: float, pressure: float, data: np.ndarray) -> ct.Solution:
    """Load a mechanism and put the gas in a prescribed state.

    Args:
        mechanism: Path to mechanism file
        temperature: Temperature assigned to the gas
        pressure: Pressure assigned to the gas
        data: Composition assigned through ``TPX`` together with the
            temperature and pressure

    Returns:
        gas: Initialized Cantera gas object
    """
    solution = ct.Solution(mechanism)
    state = (temperature, pressure, data)
    solution.TPX = state
    return solution

def get_initial_state(gas: ct.Solution) -> np.ndarray:
    """Pack the gas temperature and mass fractions into one state vector.

    Args:
        gas: Cantera gas object

    Returns:
        y: State vector [T, Y1, Y2, ...] read from ``gas.T`` and ``gas.Y``
    """
    temperature = gas.T
    mass_fractions = gas.Y
    return np.concatenate(([temperature], mass_fractions))

In [6]:
import pickle
from scipy.ndimage import gaussian_filter1d

def detect_ignition_regions(temperature_profile, time_array=None, 
                           gradient_threshold=None, smooth_sigma=1.0,
                           min_ignition_length=5):
    """
    Detect pre-ignition, ignition, and post-ignition regions in a temperature profile.

    The profile is Gaussian-smoothed, differentiated with respect to time, and
    the longest contiguous run of samples whose gradient exceeds the threshold
    is taken as the ignition region.

    Parameters:
    -----------
    temperature_profile : array-like
        Temperature values over time
    time_array : array-like, optional
        Time values corresponding to temperature measurements; the spacing may
        be non-uniform. If None, assumes uniform spacing with indices.
    gradient_threshold : float, optional
        Threshold for detecting ignition based on temperature gradient.
        If None, mean + 3 * standard deviation of the gradient is used.
    smooth_sigma : float, default=1.0
        Gaussian smoothing parameter for gradient calculation
    min_ignition_length : int, default=5
        Minimum index span of the ignition region; a shorter detected region
        is widened symmetrically about its centre and clipped to the array

    Returns:
    --------
    tuple : (pre_ignition_end_idx, ignition_start_idx, ignition_end_idx)
        - pre_ignition_end_idx: Last index of pre-ignition region
        - ignition_start_idx: First index of ignition region  
        - ignition_end_idx: Last index of ignition region
        Post-ignition starts at ignition_end_idx + 1. All three are Python ints.
        When no sample exceeds the threshold the profile is split into thirds;
        a single-point profile yields (0, 0, 0).

    Raises:
    -------
    ValueError
        If the profile is empty or ``time_array`` has a different shape.
    """
    # Work in floating point so the smoothing is not truncated for integer input.
    temp = np.asarray(temperature_profile, dtype=float)
    n_points = len(temp)

    if n_points == 0:
        raise ValueError("temperature_profile must contain at least one point")
    if n_points == 1:
        # A gradient needs two samples; there is nothing to segment.
        return 0, 0, 0

    if time_array is None:
        time_array = np.arange(n_points, dtype=float)
    else:
        time_array = np.asarray(time_array, dtype=float)
        if time_array.shape != temp.shape:
            raise ValueError("time_array must have the same shape as temperature_profile")

    # Differentiate against the actual sample times. Dividing an index-space
    # central difference by the forward spacing is only right on uniform grids.
    temp_smooth = gaussian_filter1d(temp, sigma=smooth_sigma)
    gradient = np.gradient(temp_smooth, time_array)

    # Auto-determine threshold if not provided
    if gradient_threshold is None:
        gradient_threshold = np.mean(gradient) + 3 * np.std(gradient)

    high_gradient_indices = np.where(gradient > gradient_threshold)[0]

    if len(high_gradient_indices) == 0:
        # No ignition detected, return boundaries assuming late ignition
        return n_points // 3, 2 * n_points // 3, n_points - 1

    # Split the above-threshold indices into contiguous runs and keep the
    # longest one (the first such run on ties).
    breaks = np.where(np.diff(high_gradient_indices) > 1)[0]
    runs = np.split(high_gradient_indices, breaks + 1)
    longest_run = max(runs, key=len)
    ignition_start_idx = int(longest_run[0])
    ignition_end_idx = int(longest_run[-1])

    # Extend ignition region if too short
    if ignition_end_idx - ignition_start_idx < min_ignition_length:
        center = (ignition_start_idx + ignition_end_idx) // 2
        half_length = min_ignition_length // 2
        ignition_start_idx = max(0, center - half_length)
        ignition_end_idx = min(n_points - 1, center + half_length)

    # Pre-ignition ends just before ignition starts
    pre_ignition_end_idx = max(0, ignition_start_idx - 1)

    # Ensure ignition_end_idx doesn't exceed array bounds
    ignition_end_idx = min(ignition_end_idx, n_points - 1)

    return int(pre_ignition_end_idx), int(ignition_start_idx), int(ignition_end_idx)


def plot_regions(temperature_profile, time_array=None, region_indices=None):
    """
    Draw a temperature trace and shade its pre-ignition, ignition and
    post-ignition spans.

    Parameters:
    -----------
    temperature_profile : array-like
        Temperature values
    time_array : array-like, optional
        Time values; sample indices are used when omitted
    region_indices : tuple, optional
        (pre_ignition_end_idx, ignition_start_idx, ignition_end_idx)
        If None, they are obtained from detect_ignition_regions
    """
    if time_array is None:
        time_array = np.arange(len(temperature_profile))
    if region_indices is None:
        region_indices = detect_ignition_regions(temperature_profile, time_array)
    pre_end, ign_start, ign_end = region_indices

    plt.figure(figsize=(10, 6), dpi=200)
    plt.plot(time_array, temperature_profile, 'b-', linewidth=2, label='Temperature')

    # (left edge, right edge, colour, legend entry) for each shaded span
    shaded_spans = (
        (time_array[0], time_array[pre_end], 'green', 'Pre-ignition'),
        (time_array[ign_start], time_array[ign_end], 'red', 'Ignition'),
        (time_array[ign_end], time_array[-1], 'blue', 'Post-ignition'),
    )
    for left_edge, right_edge, shade, legend_entry in shaded_spans:
        plt.axvspan(left_edge, right_edge, alpha=0.3, color=shade, label=legend_entry)

    # Dashed markers on the three region boundaries
    for boundary_idx, shade in ((pre_end, 'green'), (ign_start, 'red'), (ign_end, 'red')):
        plt.axvline(time_array[boundary_idx], color=shade, linestyle='--', alpha=0.7)

    plt.xlabel('Time')
    plt.ylabel('Temperature (K)')
    plt.title('Temperature Profile with Detected Regions')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


In [7]:

def running_average_forward(array, window_size):
    """
    Block-average an array: the input is cut into consecutive, non-overlapping
    windows of ``window_size`` samples (the last window may be shorter) and
    every sample is replaced by the mean of its window.
    Returns array of same length as input.

    Parameters:
    -----------
    array : array-like
        Input array
    window_size : int
        Size of the averaging window; must be at least 1

    Returns:
    --------
    numpy.ndarray
        Piecewise-constant array of window means, same length as input

    Raises:
    -------
    ValueError
        If ``window_size`` is smaller than 1.
    """
    # A zero window divides by zero and a negative one silently yields zeros.
    if window_size < 1:
        raise ValueError(f"window_size must be a positive integer, got {window_size}")

    array = np.asarray(array)
    n = len(array)
    result = np.zeros(n)

    # Walk the window start positions; the final window is clipped to the array.
    for start_idx in range(0, n, window_size):
        end_idx = min(start_idx + window_size, n)
        result[start_idx:end_idx] = np.mean(array[start_idx:end_idx])

    return result

# TOLERANCE ANALYSIS

In [10]:
def calculate_rmse(ref_data, test_data, species_name, use_log=False):
    """Per-sample absolute error between a reference run and a test run.

    Despite the name, no mean is taken: for each requested quantity the result
    is an array with one absolute difference per sample, which is what the
    plotting cells draw against the step index.

    Args:
        ref_data: Result dictionary with 'temperatures' and 'species_profiles'
        test_data: Result dictionary with the same layout
        species_name: Iterable of quantity names; 'temperature' selects the
            temperature trace, anything else is looked up in 'species_profiles'
        use_log: Compare log10 of the values (floored at 1e-20) instead of the
            raw values

    Returns:
        rmse_dict: Maps each requested name to its error array. Both profiles
            are truncated to their common length before comparison.
    """
    rmse_dict = {}
    for specie_name in species_name:
        if specie_name == 'temperature':
            ref_profile = ref_data['temperatures']
            test_profile = test_data['temperatures']
        else:
            ref_profile = ref_data['species_profiles'][specie_name]
            test_profile = test_data['species_profiles'][specie_name]

        ref_profile = np.asarray(ref_profile, dtype=float)
        test_profile = np.asarray(test_profile, dtype=float)

        # Runs may stop at different steps; align them in both the linear and
        # the log branch before subtracting.
        size = min(ref_profile.shape[0], test_profile.shape[0])
        ref_profile = ref_profile[:size]
        test_profile = test_profile[:size]

        if use_log:
            ref_profile = np.log10(np.maximum(ref_profile, 1e-20))
            test_profile = np.log10(np.maximum(test_profile, 1e-20))

        # |a - b| equals sqrt((a - b)**2) without squaring tiny differences.
        rmse_dict[specie_name] = np.abs(ref_profile - test_profile)
    return rmse_dict



In [None]:
def reset_gas(temperature=600, pressure=101325, phi=1):
    """Build a fresh gas at the requested state and return it with its state vector.

    Reads the module-level ``mechanism_file`` and ``fuel``; the oxidizer
    composition is fixed to 'O2:1, N2:3.76'.

    Args:
        temperature: Temperature assigned to the gas
        pressure: Pressure assigned to the gas
        phi: Equivalence ratio passed to ``set_equivalence_ratio``

    Returns:
        (gas, y0): the new gas object and the vector from ``get_initial_state``
    """
    fresh_gas = ct.Solution(mechanism_file)
    fresh_gas.set_equivalence_ratio(phi, fuel, 'O2:1, N2:3.76')
    # Keep the composition just set; only temperature and pressure change.
    fresh_gas.TPX = temperature, pressure, fresh_gas.X
    initial_state = get_initial_state(fresh_gas)
    return fresh_gas, initial_state

In [17]:
# Butcher tables run with method 'arkode_dirk' (membership in this list is
# what the run cells test for).
implicit_solvers = [
    SP.arkode.ButcherTable.ARK2_DIRK_3_1_2,
    SP.arkode.ButcherTable.ESDIRK325L2SA_5_2_3,
    SP.arkode.ButcherTable.TRBDF2_3_3_2,
    SP.arkode.ButcherTable.ESDIRK436L2SA_6_3_4,
    SP.arkode.ButcherTable.ESDIRK43I6L2SA_6_3_4,
    SP.arkode.ButcherTable.QESDIRK436L2SA_6_3_4,
    SP.arkode.ButcherTable.CASH_5_2_4,
    SP.arkode.ButcherTable.CASH_5_3_4,
    SP.arkode.ButcherTable.SDIRK_5_3_4,
    SP.arkode.ButcherTable.ARK436L2SA_DIRK_6_3_4,
    SP.arkode.ButcherTable.ESDIRK437L2SA_7_3_4,
    SP.arkode.ButcherTable.ARK437L2SA_DIRK_7_3_4,
]

print(f"Number of implicit solvers = {len(implicit_solvers)}")

# Explicit Runge-Kutta tables.
explicit_solvers = [
    SP.arkode.ButcherTable.HEUN_EULER_2_1_2,
    SP.arkode.ButcherTable.BOGACKI_SHAMPINE_4_2_3,
    SP.arkode.ButcherTable.ARK324L2SA_ERK_4_2_3,
    SP.arkode.ButcherTable.ZONNEVELD_5_3_4,
    SP.arkode.ButcherTable.ARK436L2SA_ERK_6_3_4,
    SP.arkode.ButcherTable.ARK437L2SA_ERK_7_3_4,
    SP.arkode.ButcherTable.ARK548L2SA_ERK_8_4_5,
    SP.arkode.ButcherTable.VERNER_8_5_6,
    SP.arkode.ButcherTable.FEHLBERG_13_7_8,
]

print(f"Number of explicit solvers = {len(explicit_solvers)}")

Number of implicit solvers = 12
Number of explicit solvers = 9


In [None]:
# Tables that are run and plotted by the cells below: two from
# implicit_solvers followed by two from explicit_solvers.
solver_to_plot = [
    SP.arkode.ButcherTable.ARK2_DIRK_3_1_2,
    SP.arkode.ButcherTable.TRBDF2_3_3_2,
    SP.arkode.ButcherTable.HEUN_EULER_2_1_2,
    SP.arkode.ButcherTable.BOGACKI_SHAMPINE_4_2_3,
]

print(f"Number of solvers to plot = {len(solver_to_plot)}")

In [None]:
# Initial state handed to reset_gas(); the plot titles print these as K and Pa.
temperature = 500
pressure = 101325
phi = 1
# reset_gas() reads the module-level `fuel`. On the first run of this cell that
# is still the value from the configuration cell, because `fuel` is only
# rebound a few lines below.
gas, y0 = reset_gas(temperature, pressure, phi)
# Integration window and fixed outer step passed to run_integration_experiment
# (presumably seconds - confirm against the solver time units).
t0 = 0.0
end_time = 2e-2
timestep = 1e-6
# Track every species in the mechanism.
species_to_track = gas.species_names
# Rebinds the global `fuel` to the bare species name; from here on reset_gas()
# and the `fuel` argument of run_integration_experiment both use this value.
fuel = 'nc12h26'
# Wall-clock budget per experiment, in seconds.
time_limit = 120.0
# NOTE(review): the run cells visible below either pass table_id=None or
# rebind table_id as a loop variable, so this value looks unused - confirm.
table_id = SP.arkode.ButcherTable.HEUN_EULER_2_1_2


In [None]:
# CVODE BDF runs: (result key, rtol, atol). The tightest pair serves as reference.
run_params = [
    ('reference', 1e-12, 1e-10),
    ('bdf_results_high', 1e-8, 1e-10), 
    ('bdf_results_low', 1e-6, 1e-8)
]

bdf_results = {}

# One experiment per tolerance pair, each on a freshly created gas.
for result_name, rtol, atol in run_params:
    gas, y0 = reset_gas(temperature, pressure, phi)
    method = 'cvode_bdf'

    results = run_integration_experiment(
        method, gas, y0,
        t0, end_time, timestep,
        rtol, atol,
        species_to_track, fuel,
        pressure=gas.P,
        time_limit=time_limit,
        table_id=None,
    )

    bdf_results[result_name] = results
    

In [None]:
# ARKODE runs at the tighter tolerance pair (rtol=1e-8, atol=1e-10), one per
# table in solver_to_plot, keyed by Butcher table.
solvers_results_high = {}
for table_id in solver_to_plot:
    # Tables listed in implicit_solvers go through the DIRK path, the rest through ERK.
    if table_id in implicit_solvers:
        method = 'arkode_dirk'
    else:
        method = 'arkode_erk'
    gas, y0 = reset_gas(temperature, pressure, phi)
    rtol, atol = 1e-8, 1e-10

    # Use a dedicated name: assigning to `bdf_results` here would overwrite
    # the dictionary of CVODE runs that the plotting cells index by name.
    solver_run = run_integration_experiment(
                method, gas, y0, t0, end_time, timestep,
                rtol, atol, species_to_track,
                fuel, pressure=gas.P,
                time_limit=time_limit,
                table_id=table_id
            )
    solvers_results_high[table_id] = solver_run


In [None]:
# ARKODE runs at the looser tolerance pair (rtol=1e-6, atol=1e-8), one per
# table in solver_to_plot, keyed by Butcher table.
solvers_results_low = {}
for table_id in solver_to_plot:
    # Tables listed in implicit_solvers go through the DIRK path, the rest through ERK.
    if table_id in implicit_solvers:
        method = 'arkode_dirk'
    else:
        method = 'arkode_erk'
    gas, y0 = reset_gas(temperature, pressure, phi)
    rtol, atol = 1e-6, 1e-8

    # Use a dedicated name: assigning to `bdf_results` here would overwrite
    # the dictionary of CVODE runs that the plotting cells index by name.
    solver_run = run_integration_experiment(
                method, gas, y0, t0, end_time, timestep,
                rtol, atol, species_to_track,
                fuel, pressure=gas.P,
                time_limit=time_limit,
                table_id=table_id
            )
    solvers_results_low[table_id] = solver_run

In [None]:
# Block-averaged CPU time per step: every table in solver_to_plot (tight
# tolerances) against the three CVODE BDF runs, all on one set of axes.
line_styles = ['-', '--', '-.', ':'] * 10  # Repeat basic line styles
colors = [ 'purple', 'orange', 'brown', 'pink', 'gray', 'olive', 'cyan', 'magenta', 'lime', 'teal', 'navy', 'maroon', 'gold', 'silver', 'indigo', 'turquoise']

fig, ax = plt.subplots(figsize=(15, 10), dpi=200)

# One curve per ARKODE table; the legend carries the summed CPU time.
for i, table_id in enumerate(solver_to_plot):
    data = solvers_results_high[table_id]
    ax.plot(
        running_average_forward(data['cpu_times'], 1000),
        label=f"{str(table_id)} - {np.sum(data['cpu_times']):.2e}",
        linestyle=line_styles[i],
        linewidth=2,
        color=colors[i]
    )

# The f-strings below use double quotes outside and single quotes inside:
# reusing one quote character inside an f-string is a SyntaxError before
# Python 3.12.
ax.plot(
    running_average_forward(bdf_results['reference']['cpu_times'], 1000),
    label=f"Reference(1e-12,1e-10) - {np.sum(bdf_results['reference']['cpu_times']):.2e}",
    linestyle='-',
    linewidth=3,
    color='red'
)

ax.plot(
    running_average_forward(bdf_results['bdf_results_low']['cpu_times'], 1000),
    label=f"BDF(1e-6,1e-8) - {np.sum(bdf_results['bdf_results_low']['cpu_times']):.2e}",
    linestyle='-',
    linewidth=2,
    color='blue'
)

ax.plot(
    running_average_forward(bdf_results['bdf_results_high']['cpu_times'], 1000),
    label=f"BDF(1e-8,1e-10) - {np.sum(bdf_results['bdf_results_high']['cpu_times']):.2e}",
    linestyle='-',
    linewidth=2,
    color='green'
)

# A figure should be readable on its own: label both axes.
ax.set_xlabel('Step Number')
ax.set_ylabel('CPU Time (s)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

fig.suptitle(f'CPU Time per Step for Different Implicit Solvers - {temperature}K - {pressure}Pa', fontsize=16)

plt.tight_layout()
plt.show()


In [None]:
# Block-averaged CPU time per step: every table in solver_to_plot (tight
# tolerances) against the three CVODE BDF runs, all on one set of axes.
# NOTE(review): this cell draws the same figure as the CPU-time cell before
# it; confirm whether it was meant to show solvers_results_low instead.
line_styles = ['-', '--', '-.', ':'] * 10  # Repeat basic line styles
colors = [ 'purple', 'orange', 'brown', 'pink', 'gray', 'olive', 'cyan', 'magenta', 'lime', 'teal', 'navy', 'maroon', 'gold', 'silver', 'indigo', 'turquoise']

fig, ax = plt.subplots(figsize=(15, 10), dpi=200)

# One curve per ARKODE table; the legend carries the summed CPU time.
for i, table_id in enumerate(solver_to_plot):
    data = solvers_results_high[table_id]
    ax.plot(
        running_average_forward(data['cpu_times'], 1000),
        label=f"{str(table_id)} - {np.sum(data['cpu_times']):.2e}",
        linestyle=line_styles[i],
        linewidth=2,
        color=colors[i]
    )

# The CVODE runs live in the bdf_results dictionary under the keys assigned in
# the BDF run cell; the previous standalone names (reference_results,
# bdf_results_high) are not defined anywhere in the notebook as shown.
# Double-quoted f-strings keep the cell valid before Python 3.12.
ax.plot(
    running_average_forward(bdf_results['reference']['cpu_times'], 1000),
    label=f"Reference(1e-12,1e-10) - {np.sum(bdf_results['reference']['cpu_times']):.2e}",
    linestyle='-',
    linewidth=3,
    color='red'
)

ax.plot(
    running_average_forward(bdf_results['bdf_results_low']['cpu_times'], 1000),
    label=f"BDF(1e-6,1e-8) - {np.sum(bdf_results['bdf_results_low']['cpu_times']):.2e}",
    linestyle='-',
    linewidth=2,
    color='blue'
)

ax.plot(
    running_average_forward(bdf_results['bdf_results_high']['cpu_times'], 1000),
    label=f"BDF(1e-8,1e-10) - {np.sum(bdf_results['bdf_results_high']['cpu_times']):.2e}",
    linestyle='-',
    linewidth=2,
    color='green'
)

# A figure should be readable on its own: label both axes.
ax.set_xlabel('Step Number')
ax.set_ylabel('CPU Time (s)')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

fig.suptitle(f'CPU Time per Step for Different Implicit Solvers - {temperature}K - {pressure}Pa', fontsize=16)

plt.tight_layout()
plt.show()


In [None]:
# Per-sample errors of every run against the tight-tolerance CVODE reference.
solver_errors = {}
solver_errors_high = {}
explicit_solver_errors = {}
# From here on species_to_track names the quantities compared, no longer the
# full species list used for the integration runs.
species_to_track = ['temperature', 'h', 'h2', 'o', 'o2', 'h2o', 'ho2', 'h2o2', 'oh']

# The CVODE runs are stored in bdf_results under the keys assigned in the BDF
# run cell; standalone names such as reference_results are not defined in the
# notebook as shown.
for table_id in solver_to_plot:
    solver_errors[table_id] = calculate_rmse(bdf_results['reference'], solvers_results_low[table_id], species_to_track, use_log=False)
    solver_errors_high[table_id] = calculate_rmse(bdf_results['reference'], solvers_results_high[table_id], species_to_track, use_log=False)


bdf_results_errors = calculate_rmse(bdf_results['reference'], bdf_results['bdf_results_low'], species_to_track, use_log=False)

bdf_results_errors_high = calculate_rmse(bdf_results['reference'], bdf_results['bdf_results_high'], species_to_track, use_log=False)



In [None]:
# Tables that the (currently commented-out) filter in the plotting cells would
# skip. HEUN_EULER_2_1_2 appears twice, as in the original list.
solver_to_exclude = [
    SP.arkode.ButcherTable.HEUN_EULER_2_1_2,
    SP.arkode.ButcherTable.ARK324L2SA_ERK_4_2_3,
    SP.arkode.ButcherTable.ARK437L2SA_DIRK_7_3_4,
    SP.arkode.ButcherTable.BOGACKI_SHAMPINE_4_2_3,
    SP.arkode.ButcherTable.CASH_5_3_4,
    SP.arkode.ButcherTable.HEUN_EULER_2_1_2,
]

# Redefines solver_to_plot with three more tables than the list used for the
# runs above.
# NOTE(review): no visible cell computes results for the three added tables -
# confirm they are run before being plotted.
solver_to_plot = [
    SP.arkode.ButcherTable.ARK2_DIRK_3_1_2,
    SP.arkode.ButcherTable.TRBDF2_3_3_2,
    SP.arkode.ButcherTable.HEUN_EULER_2_1_2,
    SP.arkode.ButcherTable.BOGACKI_SHAMPINE_4_2_3,
    SP.arkode.ButcherTable.CASH_5_3_4,
    SP.arkode.ButcherTable.ARK437L2SA_DIRK_7_3_4,
    SP.arkode.ButcherTable.ARK324L2SA_ERK_4_2_3,
]

In [None]:

# One panel per tracked quantity: error of each ARKODE table (tight
# tolerances) and of the two CVODE BDF runs, on a log scale.
fig, axes = plt.subplots(3, 3, figsize=(30, 30), dpi=300)

for i, specie_name in enumerate(species_to_track):
    ax = axes[i // 3, i % 3]
    for j, table_id in enumerate(solver_to_plot):
        # solver_to_plot may list tables for which no errors were computed;
        # skip those (and say so) instead of failing with a KeyError.
        if table_id not in solver_errors_high:
            if i == 0:
                print(f"No error data for {str(table_id)}; skipping")
            continue
        data = solver_errors_high[table_id]
        # Floor at 1e-20 so exact zeros survive the log axis.
        ax.plot(np.maximum(data[specie_name], 1e-20),
                label=f"{str(table_id)}",
                linestyle=line_styles[j % len(line_styles)],
                linewidth=2,
                color=colors[j % len(colors)])

    ax.plot(np.maximum(bdf_results_errors[specie_name], 1e-20),
            label="BDF(1e-6,1e-8)",
            linestyle='-',
            linewidth=2,
            color='red')
    ax.plot(np.maximum(bdf_results_errors_high[specie_name], 1e-20),
            label="BDF(1e-8,1e-10)",
            linestyle='-',
            linewidth=2,
            color='blue')
    ax.set_title(f'{specie_name}')
    ax.set_yscale('log')
    ax.grid(True)
    if i == len(species_to_track) - 1:  # Only show legend on last plot
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

fig.suptitle('RMSE for Different Species Across All Implicit Solvers', fontsize=16)
plt.tight_layout()
plt.show()


In [None]:
# Error of every implicit solver against the reference, per tracked quantity.
species_to_track = ['temperature', 'h', 'h2', 'o', 'o2', 'h2o', 'ho2', 'h2o2', 'oh']
implicit_solver_errors = {
    table_id: calculate_rmse(reference_results, implicit_results[table_id], species_to_track, use_log=False)
    for table_id in implicit_solvers
}

# One panel per tracked quantity on a shared 3x3 grid.
fig, axes = plt.subplots(3, 3, figsize=(30, 30), dpi=300)

for i, specie_name in enumerate(species_to_track):
    ax = axes[divmod(i, 3)]
    for j, table_id in enumerate(implicit_solvers):
        solver_error = implicit_solver_errors[table_id][specie_name]
        ax.plot(
            np.maximum(solver_error, 1e-20),  # floor at 1e-20 before the log axis
            label=str(table_id),
            linestyle=line_styles[j % len(line_styles)],
            linewidth=2,
            color=colors[j % len(colors)],
        )
    ax.set_title(f'{specie_name}')
    ax.set_yscale('log')
    ax.grid(True)

    # Only the last panel carries the legend.
    if i == len(species_to_track) - 1:
        ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

fig.suptitle('RMSE for Different Species Across All Implicit Solvers', fontsize=16)
plt.tight_layout()
plt.show()


In [None]:
# Plot the per-step CPU time of the implicit solvers on a 2x2 grid.
# NOTE: running_average_forward is defined in a later cell of this notebook;
# that cell has to run before this one (or the definition has to move up).
line_styles = ['-', '--', '-.', ':'] * 10  # Repeat basic line styles
colors = ['blue', 'green', 'purple', 'orange', 'brown', 'pink', 'gray', 'olive', 'cyan', 'magenta', 'lime', 'teal', 'navy', 'maroon', 'gold', 'silver', 'indigo', 'turquoise']

# 2x2 panels sharing both axes so the groups can be compared directly.
fig, axs = plt.subplots(2, 2, figsize=(25, 20), dpi=300, sharex=True, sharey=True)
axs = axs.ravel()  # Flatten axes array for easier indexing

# Three groups of six solvers; everything that remains goes to the fourth panel.
group_size = 6
n_full_groups = 3
split_at = group_size * n_full_groups
solver_groups = [implicit_solvers[k:k + group_size] for k in range(0, split_at, group_size)]
solver_groups.append(implicit_solvers[split_at:])

# Computed once; an intermediate name also avoids reusing the same quote
# character inside the f-string, which is a SyntaxError before Python 3.12.
reference_total = np.sum(reference_results['cpu_times'])

first_solver = 1  # 1-based number of the first solver in the current panel
for subplot_idx, solver_group in enumerate(solver_groups):
    for i, table_id in enumerate(solver_group):
        axs[subplot_idx].plot(
            running_average_forward(implicit_results[table_id]['cpu_times'], 200),
            label=f"{str(table_id)} - {np.sum(implicit_results[table_id]['cpu_times']):.2e}",
            # Wrap around so a large remainder group cannot index past the lists.
            linestyle=line_styles[i % len(line_styles)],
            linewidth=2,
            color=colors[i % len(colors)]
        )
    # NOTE(review): the reference is smoothed with a window of 100 while the
    # solvers use 200 — confirm this difference is intended.
    axs[subplot_idx].plot(
        running_average_forward(reference_results['cpu_times'], 100),
        label=f'Reference - {reference_total:.2e}',
        linestyle='-',
        linewidth=3,
        color='red'
    )
    # Title from the real group bounds, so the last panel reports its true range.
    last_solver = first_solver + len(solver_group) - 1
    if solver_group:
        axs[subplot_idx].set_title(f'Solvers {first_solver}-{last_solver}')
    else:
        axs[subplot_idx].set_title('No solvers')
    first_solver = last_solver + 1
    axs[subplot_idx].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

fig.suptitle(f'CPU Time per Step for Different Implicit Solvers - {temperature}K - {pressure}Pa', fontsize=16)

plt.tight_layout()
plt.show()

In [None]:
# Fresh gas state, then a 'cvode_bdf' run with rtol = atol = 1e-8.
gas, y0 = reset_gas(temperature, pressure, mixture_fractions)
method = 'cvode_bdf'
rtol = atol = 1e-8

bdf_results = run_integration_experiment(
    method, gas, y0, t0, end_time, timestep, rtol, atol, species_to_track, fuel,
    pressure=gas.P, time_limit=time_limit, table_id=table_id,
)

In [None]:
# Scratch arithmetic: ratio of 37.8 to 1.95.
# NOTE(review): the meaning of the two numbers is not recorded here — annotate or remove.
37.8/1.95

In [None]:
# Fresh gas state, then a 'cpp_rk23' run at the tight tolerances (rtol=1e-12, atol=1e-10).
gas, y0 = reset_gas(temperature, pressure, mixture_fractions)
method = 'cpp_rk23'
rtol = 1e-12
atol = 1e-10

rk_results_high = run_integration_experiment(
    method, gas, y0, t0, end_time, timestep, rtol, atol, species_to_track, fuel,
    pressure=gas.P, time_limit=time_limit, table_id=table_id,
)

In [None]:
# Fresh gas state, then a 'cpp_rk23' run at the looser tolerances (rtol = atol = 1e-8).
gas, y0 = reset_gas(temperature, pressure, mixture_fractions)
method = 'cpp_rk23'
rtol = atol = 1e-8

rk_results_low = run_integration_experiment(
    method, gas, y0, t0, end_time, timestep, rtol, atol, species_to_track, fuel,
    pressure=gas.P, time_limit=time_limit, table_id=table_id,
)

In [None]:
# Error of each run against the reference solution for the compared quantities.
compared_quantities = ['temperature', 'h', 'h2', 'o', 'o2', 'h2o', 'ho2', 'h2o2', 'oh']
error_comparison_rk_high = calculate_rmse(reference_results, rk_results_high, compared_quantities, use_log=False)
error_comparison_rk_low = calculate_rmse(reference_results, rk_results_low, compared_quantities, use_log=False)
error_comparison_bdf = calculate_rmse(reference_results, bdf_results, compared_quantities, use_log=False)
fig, axes = plt.subplots(3, 3, figsize=(20, 20), dpi=200)

# (errors, legend label, colour, line style). The labels state the tolerances
# each run was launched with: rk_results_high used (1e-12, 1e-10), while
# rk_results_low and bdf_results used (1e-8, 1e-8).
error_curves = (
    (error_comparison_rk_high, 'RK23 (rtol=1e-12, atol=1e-10)', 'red', '-.'),
    (error_comparison_rk_low, 'RK23 (rtol=1e-8, atol=1e-8)', 'blue', '--'),
    (error_comparison_bdf, 'BDF (rtol=1e-8, atol=1e-8)', 'green', '--'),
)

for i, specie_name in enumerate(error_comparison_rk_high.keys()):
    ax = axes[i // 3, i % 3]
    for curve_errors, curve_label, curve_color, curve_style in error_curves:
        ax.plot(
            reference_results['times'],
            np.maximum(curve_errors[specie_name], 1e-20),  # floor very small errors at 1e-20
            label=curve_label,
            color=curve_color,
            linestyle=curve_style,
        )
    ax.set_title(f"{specie_name} RMSE")
    ax.legend()

# Footnote explaining the floor applied to the plotted errors.
fig.text(0.5, 0.01, "NOTE: Errors are floored at 1e-20 for better visualization", ha='center', va='center', fontsize=12)
fig.suptitle(f"Temperature {temperature}K - Pressure {pressure}Pa")
plt.show()

In [None]:
# --- Solver configuration ---------------------------------------------------
method = 'arkode_dirk'
table_id = SP.arkode.ButcherTable.TRBDF2_3_3_2
rtol, atol = 1e-6, 1e-8

# --- Time window and step size ----------------------------------------------
t0 = 0.0
end_time = 2e-2
timestep = 1e-5

# --- Initial condition: row 37 of df_filtered_array at 400 ------------------
temperature = 400
index = 37
sample = df_filtered_array[index]
pressure = sample[1]
mixture_fractions = sample[4:]
gas, y0 = reset_gas(temperature, pressure, mixture_fractions)

# --- Empty per-step records -------------------------------------------------
cpu_times, fuel_mass_fractions, times, temperatures = [], [], [], []
species_profiles = {spec: [] for spec in species_to_track}

# --- Integration state ------------------------------------------------------
t = t0
y = y0.copy()
step_count = 0
start_time = time.time()

In [None]:
# Take one integration step from the current (y, t) with the configured
# method, tolerances and Butcher table, and keep what the call returns.
result = integrate_single_step(method, gas, y, t, timestep, rtol, atol, fuel, pressure=pressure, table_id=table_id)
    

In [None]:
# Display the 'cpu_time' entry of the step result.
result['cpu_time']

In [None]:

# Manual time-stepping run with ARKODE DIRK (BILLINGTON_3_3_2), recording
# per-step CPU time, temperature, fuel mass fraction and tracked species.
cpu_times = []
fuel_mass_fractions = []
times = []
temperatures = []
species_profiles = {spec: [] for spec in species_to_track}
method = 'arkode_dirk'
table_id = SP.arkode.ButcherTable.BILLINGTON_3_3_2
rtol, atol = 1e-6, 1e-8

end_time = 2e-2
t0 = 0.0
temperature = 400
index = 37
sample = df_filtered_array[index]
pressure = sample[1]
mixture_fractions = sample[4:]
gas, y0 = reset_gas(temperature, pressure, mixture_fractions)

timestep = 1e-5
t = t0
y = y0.copy()
step_count = 0
start_time = time.time()

# The bar's total is the simulated end time, so it is advanced by the amount
# of simulated time each step covers. The context manager closes it even when
# a step raises.
with tqdm(total=end_time) as bar:
    while t < end_time:

        result = integrate_single_step(method, gas, y, t, timestep, rtol, atol, fuel, pressure=pressure, table_id=table_id)

        if not result['success']:
            print(f"Step {step_count} failed: {result['message']}")
            break

        # Reject an empty state before recording anything, so every per-step
        # list keeps the same length when the loop stops here.
        if len(result['y']) == 0:
            print(f"Step {step_count} failed: y is empty")
            print(result)
            print(result['y'])
            break

        # Advance the progress bar by the simulated time of this step.
        bar.update(result['t'] - t)

        # Update state
        y = result['y']
        t = result['t']
        step_count += 1

        # Record data
        cpu_times.append(result['cpu_time'])
        fuel_mass_fractions.append(result['fuel_mass_fraction'])
        times.append(t)
        temperatures.append(y[0])
        for spec in species_to_track:
            # +1 skips the temperature stored in slot 0 of the state vector.
            species_profiles[spec].append(y[gas.species_index(spec) + 1])

        bar.set_postfix({
            'step': f"{step_count}",
            'temperature': f"{y[0]:.1f}K",
            'cpu_time': f"{cpu_times[-1]:.2e}s",
            'total_cpu_time': f"{np.sum(cpu_times):.2e}s"
        })
total_wall_time = time.time() - start_time

In [None]:
# Temperature and per-step CPU time over simulated time, plus the total CPU
# time spent before and after t = 1e-2.
split_time = 1e-2

fig, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(10, 10), dpi=200)
ax1.plot(times, temperatures)
ax1.set_ylabel('Temperature (K)')
ax2.plot(times, cpu_times)
ax2.set_xlabel('Time (s)')
ax2.set_ylabel('CPU time per step (s)')

# Mark the split time on both time-series panels.
ax1.axvline(split_time, color='red', linestyle='--')
ax2.axvline(split_time, color='red', linestyle='--')

# cpu_times[k] belongs to the step that ended at times[k], so the split index
# is looked up in the recorded times instead of assuming a fixed step count.
split_idx = int(np.searchsorted(times, split_time, side='right'))
cpu_before = np.sum(cpu_times[:split_idx])
cpu_after = np.sum(cpu_times[split_idx:])

# Bar chart of the two totals, each annotated with its value.
ax3.bar(['Before 1e-2', 'After 1e-2'], [cpu_before, cpu_after])
ax3.text(0, cpu_before, f"{cpu_before:.2e}", ha='center', va='bottom')
ax3.text(1, cpu_after, f"{cpu_after:.2e}", ha='center', va='bottom')
ax3.set_ylabel('Total CPU time (s)')
plt.show()

In [None]:
# Collect the per-step records of the manual integration into one result dict.
rk_result = {
    'method': method,
    # NOTE(review): stored without calling it — if `equivalence_ratio` is a
    # method in the installed Cantera, this keeps the bound method, not a number.
    'phi': gas.equivalence_ratio,
    'rtol': rtol,
    'atol': atol,
    'times': np.array(times),
    'fuel_mass_fractions': np.array(fuel_mass_fractions),
    'temperatures': np.array(temperatures),
    'species_profiles': species_profiles,
    'cpu_times': np.array(cpu_times),
    'total_cpu_time': np.sum(cpu_times),
    'total_wall_time': total_wall_time,
    'steps': step_count,
    'success': step_count > 0,
    'timed_out': total_wall_time > time_limit,
}

# Error of that run against the reference, one panel per compared quantity.
compared_quantities = ['temperature', 'h', 'h2', 'o', 'o2', 'h2o', 'ho2', 'h2o2', 'oh']
error_comparison_rk = calculate_rmse(reference_results, rk_result, compared_quantities, use_log=False)
fig, axes = plt.subplots(3, 3, figsize=(20, 20), dpi=200)

for i, specie_name in enumerate(error_comparison_rk):
    ax = axes[divmod(i, 3)]
    floored_error = np.maximum(error_comparison_rk[specie_name], 1e-10)  # floor at 1e-10
    ax.plot(reference_results['times'], floored_error, label='RK23')
    ax.set_title(f"{specie_name} RMSE")
    ax.legend()

fig.suptitle('Rk23 (1e-6, 1e-8) - BDF (1e-12, 1e-10)')
plt.show()

In [None]:
def running_average_forward(array, window_size):
    """
    Block-average an array over consecutive, non-overlapping windows.

    The input is cut into chunks of ``window_size`` samples starting at index 0
    (the last chunk may be shorter). Every sample is replaced by the mean of
    the chunk it belongs to, so the output is piecewise constant and has the
    same length as the input. This is not a sliding-window average.

    Parameters:
    -----------
    array : array-like
        Input values
    window_size : int
        Number of samples per chunk; must be at least 1

    Returns:
    --------
    numpy.ndarray
        Float array of the same length as the input

    Raises:
    -------
    ValueError
        If ``window_size`` is smaller than 1
    """
    # A zero window used to divide by zero and a negative one silently
    # returned all zeros; reject both explicitly.
    if window_size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")

    array = np.asarray(array)
    n = len(array)
    result = np.zeros(n)

    # Walk the chunk start indices; the final chunk is clipped to the array end.
    for start_idx in range(0, n, window_size):
        end_idx = min(start_idx + window_size, n)
        result[start_idx:end_idx] = np.mean(array[start_idx:end_idx])

    return result


In [None]:

# Imported here because the notebook's own `import pickle` cell comes later;
# without it this cell raises NameError on a fresh Restart & Run All.
import pickle

# Load the stored results for every (rtol, atol) combination that has a file.
rtol_tolerances = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]
atol_tolerances = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6]
all_tol_data = {}
for rtol in rtol_tolerances:
    for atol in atol_tolerances:
        try:
            # NOTE: pickle.load executes arbitrary code; only load trusted files.
            with open(f'tolerance_analysis_results_rk/tolerance_results_rtol_{rtol}_atol_{atol}.pkl', 'rb') as f:
                tolerance_results = pickle.load(f)
            all_tol_data[(rtol, atol)] = tolerance_results
        except FileNotFoundError:
            # Missing combinations are reported and left out of the dictionary.
            print(f'File not found for tolerance {rtol} {atol}')


In [None]:
# Pull the per-step CPU times and temperatures out of the RK23-BDF result and
# report their lengths; each label names the array actually measured.
cpu_times = rk_bdf_result['cpu_times']
temperatures = rk_bdf_result['temperatures']
print(f"Length of cpu_times: {len(cpu_times)}")
print(f"Length of temperatures: {len(temperatures)}")

In [None]:
# Temperature (top) and block-averaged CPU time per step (bottom) of the
# RK23-BDF run against the CVODE BDF reference, with the ignition phases shaded.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 10), dpi=200)
times = reference_results['times']

# One label per run, shared by both panels so the legends cannot disagree.
# NOTE(review): the original panels showed (rtol=1e-6, atol=1e-8) and
# (rtol=1e-8, atol=1e-6) for the same run; the former is kept because it
# matches the 'Rk23 (1e-6, 1e-8)' figure titles used elsewhere — confirm.
rk_bdf_label = 'rk23-bdf (rtol=1e-6, atol=1e-8)'
reference_label = 'cvode_bdf (rtol=1e-12, atol=1e-10)'

ax1.plot(times, temperatures, label=rk_bdf_label)
ax1.plot(times, reference_results['temperatures'], label=reference_label, linestyle='--')
# times[1:] — the CPU-time arrays are plotted against all but the first time point.
ax2.plot(times[1:], running_average_forward(cpu_times, 1000), label=rk_bdf_label)
ax2.plot(times[1:], running_average_forward(reference_results['cpu_times'], 1000), label=reference_label, linestyle='--')

# (start, stop, colour, label) of each shaded phase, in reference time.
time_array = reference_results['times']
phase_spans = (
    (time_array[0], time_array[pre_end], 'green', 'Pre-ignition'),
    (time_array[ign_start], time_array[ign_end], 'red', 'Ignition'),
    (time_array[ign_end], time_array[-1], 'blue', 'Post-ignition'),
)
for phase_ax in (ax1, ax2):
    for span_start, span_stop, span_color, span_label in phase_spans:
        phase_ax.axvspan(span_start, span_stop, alpha=0.3, color=span_color, label=span_label)
    phase_ax.legend()
plt.show()

In [None]:
# Pick the stored run under the tuple key (1e-10, 1e-10).
# NOTE(review): this needs all_tol_data to be keyed by (rtol, atol) tuples;
# another cell rebuilds all_tol_data with scalar keys — check the run order.
rk_result = all_tol_data[(1e-10, 1e-10)]

In [None]:
# Error of the RK23-BDF run against the reference for the compared quantities.
compared_quantities = ['temperature', 'h', 'h2', 'o', 'o2', 'h2o', 'ho2', 'h2o2', 'oh']
error_comparison_rk_bdf = calculate_rmse(
    reference_results,
    rk_bdf_result,
    compared_quantities,
    use_log=False,
)

In [None]:
# Display the 'h' entry of the BDF error comparison.
error_comparison_bdf['h']

In [None]:

# Recompute the RK23-BDF error and draw it, one panel per compared quantity.
compared_quantities = ['temperature', 'h', 'h2', 'o', 'o2', 'h2o', 'ho2', 'h2o2', 'oh']
error_comparison_rk_bdf = calculate_rmse(reference_results, rk_bdf_result, compared_quantities, use_log=False)
fig, axes = plt.subplots(3, 3, figsize=(20, 20), dpi=200)

for i, specie_name in enumerate(error_comparison_rk_bdf):
    ax = axes[divmod(i, 3)]
    floored_error = np.maximum(error_comparison_rk_bdf[specie_name], 1e-10)  # floor at 1e-10
    ax.plot(reference_results['times'], floored_error, label='RK23-BDF')
    ax.set_title(f"{specie_name} RMSE")
    ax.legend()

fig.suptitle('Rk23 (1e-6, 1e-8) - BDF (1e-12, 1e-10)')
plt.show()

In [None]:


# Same panels as the RMSE figure, but showing log10 of the floored error.
fig, axes = plt.subplots(3, 3, figsize=(20, 20), dpi=200)

for i, specie_name in enumerate(error_comparison_rk_bdf):
    ax = axes[divmod(i, 3)]
    floored_error = np.maximum(error_comparison_rk_bdf[specie_name], 1e-10)  # floor keeps log10 finite
    ax.plot(reference_results['times'], np.log10(floored_error), label='RK23-BDF')
    ax.set_title(f"{specie_name} Log RMSE")
    ax.legend()

# Figure-level title.
fig.suptitle('Rk23 (1e-6, 1e-8) - BDF (1e-12, 1e-10)')
plt.show()

In [None]:
# Tolerance sweep: each value is used as both rtol and atol of one run.
tolerances = [1e-14, 1e-13, 1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
tolerance_results = {}

for tolerance in tolerances:
    # Rebuild the gas from the mechanism file so every run starts from the same state.
    gas = ct.Solution(mechanism_file)
    gas.TPX = 1000, sample[1], sample[4:]
    y0 = get_initial_state(gas)

    run_output = run_integration_experiment(
        method, gas, y0, t0, end_time, timestep,
        tolerance, tolerance, species_to_track, fuel,
        pressure=gas.P, time_limit=time_limit,
    )
    tolerance_results[tolerance] = run_output


In [None]:
# Narrower list of tolerance values, from 1e-12 up to 1e-4.
tolerances = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5, 1e-4]

In [None]:

# Tolerance sweep: each value is used as both rtol and atol of one run.
tolerance_results = {}

for tolerance in tolerances:
    # Rebuild the gas from the mechanism file so every run starts from the same state.
    gas = ct.Solution(mechanism_file)
    gas.TPX = 1000, sample[1], sample[4:]
    # Derive the initial state from this gas rather than reusing whatever y0
    # an earlier cell left in the kernel.
    y0 = get_initial_state(gas)
    tolerance_results[tolerance] = run_integration_experiment(
        method, gas, y0, t0, end_time, timestep,
        tolerance, tolerance, species_to_track,
        fuel, pressure=gas.P,
        time_limit=time_limit
    )


In [None]:
import pickle

In [None]:
# Load previously pickled results from the working directory.
# NOTE: unpickling executes arbitrary code; only load files you trust.
tol_data = pd.read_pickle('tolerance_results2.pkl')

In [None]:
# Name of the tolerance being varied; it appears in the legend labels of the plots.
tol_2_analyze = "rtol"

In [None]:
# Collect one stored result per tolerance; a missing file falls back to the
# matching entry of tol_data.
tolerances = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]
all_tol_data = {}
for tolerance in tolerances:
    try:
        with open(f'tolerance_study_results/tolerance_results_{tolerance}.pkl', 'rb') as f:
            tolerance_results = pickle.load(f)
    except FileNotFoundError:
        print(f'File not found for tolerance {tolerance}')
        all_tol_data[tolerance] = tol_data[tolerance]
    else:
        all_tol_data[tolerance] = tolerance_results

all_tol_data.keys()

In [None]:
# Temperature history of every available tolerance run on one axis.
fig, ax = plt.subplots(figsize=(10, 5))
line_types = ['-', '--', '-.', ':', '-.', '--', '-', ':', '-.', '--', '-', ':']
line_styles = ['solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted']

for i, tolerance in enumerate(tolerances):
    tol_run = all_tol_data[tolerance]
    if tol_run is None:  # no data stored for this tolerance
        continue
    ax.plot(
        tol_run['times'],
        tol_run['temperatures'],
        label=f'{tol_2_analyze} {tolerance}',
        linestyle=line_styles[i],
        linewidth=2,
    )

ax.legend()
plt.show()

In [None]:
# log10 of the per-step CPU time for every available run, skipping the
# loosest tolerances (1e-4 and above).
fig, ax = plt.subplots(figsize=(10, 5))
line_types = ['-', '--', '-.', ':', '-.', '--', '-', ':', '-.', '--', '-', ':']
line_styles = ['solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted']

for i, tolerance in enumerate(tolerances):
    if tolerance in [0.0001, 0.001, 0.01, 0.1]:
        continue
    tol_run = all_tol_data[tolerance]
    if tol_run is None:  # no data stored for this tolerance
        continue
    # times[1:] — the CPU times are plotted against all but the first time point.
    ax.plot(
        tol_run['times'][1:],
        np.log10(tol_run['cpu_times']),
        label=f'{tol_2_analyze} {tolerance}',
        linestyle=line_styles[i],
        linewidth=2,
    )

ax.legend()
plt.show()

In [None]:
# Cumulative running average of the per-step CPU time for every available run,
# skipping the loosest tolerances (1e-4 and above).
fig, ax = plt.subplots(figsize=(10, 5))
line_types = ['-', '--', '-.', ':', '-.', '--', '-', ':', '-.', '--', '-', ':']
line_styles = ['solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted', 'solid', 'dashed', 'dashdot', 'dotted']

for i, tolerance in enumerate(tolerances):
    if tolerance in [0.0001, 0.001, 0.01, 0.1]:
        continue
    tol_run = all_tol_data[tolerance]
    if tol_run is None:  # no data stored for this tolerance
        continue

    # Mean of all CPU times up to and including each step.
    step_cpu_times = np.asarray(tol_run['cpu_times'])
    running_avg = np.cumsum(step_cpu_times) / np.arange(1, len(step_cpu_times) + 1)

    # times[1:] — the CPU times are plotted against all but the first time point.
    ax.plot(tol_run['times'][1:],
            running_avg,
            label=f'{tol_2_analyze} {tolerance}',
            linestyle=line_styles[i],
            linewidth=2)

# The data is plotted as-is on a logarithmic axis, so the label names the
# quantity itself rather than its log10.
ax.set_yscale('log')
ax.set_xlabel('Time (s)')
ax.set_ylabel('Running Average CPU Time (s)')
ax.legend()
plt.show()

In [None]:
from scipy.interpolate import interp1d

def compare_tolerance_effectiveness(all_tol_data, reference_tolerance=1e-14, metrics=('temperature', 'species', 'cpu_time')):
    """
    Compare different tolerance values against a reference tolerance.

    Args:
        all_tol_data: Dictionary mapping a tolerance value to its result
        reference_tolerance: The tolerance to compare against (default: 1e-14)
        metrics: Iterable of metrics to compare; recognised names are
            'temperature', 'species' and 'cpu_time'. Any other name is
            reported and mapped to an empty dictionary.

    Returns:
        Dictionary mapping each requested metric to its comparison results,
        or None when the reference tolerance is not a key of all_tol_data.
    """
    if reference_tolerance not in all_tol_data:
        print(f"Reference tolerance {reference_tolerance} not found in data")
        return None

    # Every key except the reference itself and the loosest tolerances.
    excluded = [reference_tolerance, 0.0001, 0.001, 0.01, 0.1]
    test_tolerances = sorted(tol for tol in all_tol_data.keys() if tol not in excluded)

    print(f"Comparing {len(test_tolerances)} tolerance values against reference {reference_tolerance}")
    print(f"Available test tolerances: {test_tolerances}")

    comparison_results = {}
    for metric in metrics:
        if metric == 'temperature':
            comparison_results[metric] = compare_temperature_profiles(all_tol_data, reference_tolerance, test_tolerances)
        elif metric == 'species':
            comparison_results[metric] = compare_species_profiles(all_tol_data, reference_tolerance, test_tolerances)
        elif metric == 'cpu_time':
            comparison_results[metric] = compare_cpu_times(all_tol_data, reference_tolerance, test_tolerances)
        else:
            # Keep the empty entry earlier callers received, but say why it is empty.
            print(f"Unknown metric '{metric}' ignored")
            comparison_results[metric] = {}

    return comparison_results

def compare_temperature_profiles(all_tol_data, reference_tolerance, test_tolerances):
    """Compare temperature profiles between reference and test tolerances.

    Args:
        all_tol_data: Dictionary mapping a tolerance to a result with 'times'
            and 'temperatures' entries, or to None when no result exists.
        reference_tolerance: Key of the reference result.
        test_tolerances: Keys of the results to compare with the reference.

    Returns:
        Dictionary mapping each test tolerance to a dictionary with the keys
        'error', 'max_error', 'rmse', 'relative_error' and
        'max_relative_error'. Runs that cannot be compared (missing result, or
        fewer than two time points on either side) hold NaN under every key.
    """
    ref_data = all_tol_data[reference_tolerance]
    ref_times = np.asarray(ref_data['times'])
    ref_temps = np.asarray(ref_data['temperatures'])

    # Placeholder for runs that cannot be compared. It carries every key of a
    # real entry so consumers can index 'max_relative_error' without a KeyError.
    nan_entry = {
        'error': np.nan,
        'max_error': np.nan,
        'rmse': np.nan,
        'relative_error': np.nan,
        'max_relative_error': np.nan,
    }

    temp_comparison = {}

    for tol in test_tolerances:
        test_data = all_tol_data[tol]
        if test_data is None:
            temp_comparison[tol] = dict(nan_entry)
            continue

        test_times = np.asarray(test_data['times'])
        test_temps = np.asarray(test_data['temperatures'])

        # Interpolation needs at least two points on both sides.
        if len(ref_times) <= 1 or len(test_times) <= 1:
            temp_comparison[tol] = dict(nan_entry)
            continue

        # Bring the reference onto the test run's time points.
        ref_interp = interp1d(ref_times, ref_temps, bounds_error=False, fill_value='extrapolate')
        ref_temps_interp = ref_interp(test_times)

        errors = test_temps - ref_temps_interp
        max_error = np.max(np.abs(errors))
        rmse = np.sqrt(np.mean(errors**2))

        # Relative error; reference magnitudes at or below 1e-10 are replaced by 1e-10.
        ref_temps_safe = np.where(np.abs(ref_temps_interp) > 1e-10, ref_temps_interp, 1e-10)
        relative_errors = np.abs(errors) / np.abs(ref_temps_safe)
        max_relative_error = np.max(relative_errors)

        temp_comparison[tol] = {
            'error': errors,
            'max_error': max_error,
            'rmse': rmse,
            'relative_error': relative_errors,
            'max_relative_error': max_relative_error
        }

    return temp_comparison

def compare_species_profiles(all_tol_data, reference_tolerance, test_tolerances):
    """Compare species profiles of each test tolerance with the reference run.

    For every test tolerance with a result, and every reference species that
    the test result also contains, the reference profile is interpolated onto
    the test time points and the error statistics are stored under the species
    name. A tolerance whose result is None maps to an empty dictionary.
    """
    reference = all_tol_data[reference_tolerance]
    ref_times = reference['times']
    ref_species = reference['species_profiles']

    nan_stats_keys = ('error', 'max_error', 'rmse', 'relative_error', 'max_relative_error')
    species_comparison = {}

    for tol in test_tolerances:
        per_species = {}
        species_comparison[tol] = per_species

        candidate = all_tol_data[tol]
        if candidate is None:
            continue

        test_times = candidate['times']
        test_species = candidate['species_profiles']
        # Interpolation needs at least two points on both time axes.
        can_interpolate = len(ref_times) > 1 and len(test_times) > 1

        for species_name, ref_profile in ref_species.items():
            if species_name not in test_species:
                continue

            if not can_interpolate:
                per_species[species_name] = {key: np.nan for key in nan_stats_keys}
                continue

            test_profile = test_species[species_name]

            # Reference values at the test run's time points.
            to_test_grid = interp1d(ref_times, ref_profile, bounds_error=False, fill_value='extrapolate')
            ref_on_test_grid = to_test_grid(test_times)

            errors = test_profile - ref_on_test_grid
            abs_errors = np.abs(errors)

            # Reference magnitudes at or below 1e-10 are replaced by 1e-10 in the denominator.
            denominator = np.where(np.abs(ref_on_test_grid) > 1e-10, ref_on_test_grid, 1e-10)
            relative_errors = abs_errors / np.abs(denominator)

            per_species[species_name] = {
                'error': errors,
                'max_error': np.max(abs_errors),
                'rmse': np.sqrt(np.mean(errors**2)),
                'relative_error': relative_errors,
                'max_relative_error': np.max(relative_errors),
            }

    return species_comparison

def compare_cpu_times(all_tol_data, reference_tolerance, test_tolerances):
    """Relate the total CPU time of each test tolerance to the reference run.

    Each test tolerance maps to a dictionary with its 'cpu_time', the
    'speedup' (reference time divided by test time, NaN when the test time is
    not positive) and an 'efficiency' equal to the speedup when that is at
    most 1 and to its reciprocal otherwise. A tolerance whose result is None
    gets NaN for all three values.
    """
    reference_total = all_tol_data[reference_tolerance]['total_cpu_time']

    cpu_comparison = {}
    for tol in test_tolerances:
        candidate = all_tol_data[tol]
        if candidate is None:
            cpu_comparison[tol] = dict(cpu_time=np.nan, speedup=np.nan, efficiency=np.nan)
            continue

        candidate_total = candidate['total_cpu_time']

        if candidate_total > 0:
            speedup = reference_total / candidate_total
        else:
            speedup = np.nan

        # A NaN speedup fails the comparison and takes the reciprocal branch, staying NaN.
        if speedup <= 1:
            efficiency = speedup
        else:
            efficiency = 1 / speedup

        cpu_comparison[tol] = dict(cpu_time=candidate_total, speedup=speedup, efficiency=efficiency)

    return cpu_comparison


# Example usage:
# Compare every loaded tolerance against the 1e-12 run, which is passed as the
# reference here (the function's own default reference is 1e-14).
comparison_results = compare_tolerance_effectiveness(
    all_tol_data, 
    reference_tolerance=1e-12, 
    metrics=['temperature', 'species', 'cpu_time']
)




In [None]:
# save comparison results to a pickle file
# Persist the comparison results next to the notebook.
with open('comparison_results_rtol.pkl', 'wb') as results_file:
    pickle.dump(comparison_results, results_file)


In [None]:
# load comparison results from a pickle file
# Reload the stored rtol and atol comparison results, one file at a time.
with open('comparison_results_rtol.pkl', 'rb') as rtol_file:
    comparison_results_rtol = pickle.load(rtol_file)

with open('comparison_results_atol.pkl', 'rb') as atol_file:
    comparison_results_atol = pickle.load(atol_file)



In [None]:
# Load one stored tolerance result and list its keys.
# NOTE(review): absolute, machine-specific path; the directory name also ends
# in an underscore, unlike the relative 'tolerance_study_results/' used by the
# per-tolerance loader — confirm this is the intended folder.
with open('/Users/elotech/Downloads/research_code/tolerance_study_results_/tolerance_results_1e-12.pkl', 'rb') as f:
    single_result = pickle.load(f)

single_result.keys()


In [None]:
# Total CPU time of the loaded run: sum over its 'cpu_times' entries.
single_cpu_time = np.sum(single_result['cpu_times'])


print(f"Single CPU time: {single_cpu_time}")










In [None]:
# Display the CPU-time comparison entry stored under 1e-10 in the rtol results.
comparison_results_rtol['cpu_time'][1e-10]

In [None]:
# Display the CPU-time comparison entry stored under 1e-10 in the atol results.
comparison_results_atol['cpu_time'][1e-10]

In [None]:
# Scratch arithmetic: difference between two pasted values.
# NOTE(review): presumably the two CPU times displayed above — confirm, or
# compute it from the variables instead of pasted numbers.
136.78664755821228 - 128.32024765014648

In [None]:
# RMSE
# RTOL - 1e-12, ATOL - 1e-10 = 0.06409171894867949
# ATOL - 1e-12, RTOL - 1e-10 = 0.11352238592234361

In [None]:
def plot_tolerance_comparison(comparison_results, reference_tolerance=1e-14, tol_2_analyze='atol', timestep=1e-6):
    """Plot comparison results for different metrics.

    Args:
        comparison_results: Dictionary with optional 'temperature', 'species'
            and 'cpu_time' entries, each mapping a tolerance to its statistics.
        reference_tolerance: Accepted for interface compatibility; not used by
            the plots.
        tol_2_analyze: Name of the varied tolerance, used in labels and titles.
        timestep: Spacing assumed between consecutive error samples when
            building the time axis of the error profiles (default: 1e-6, the
            value that used to be hard-coded).

    Returns:
        None. Figures are shown with plt.show().
    """

    # Temperature comparison
    if 'temperature' in comparison_results:
        temperature_results = comparison_results['temperature']
        cpu_results = comparison_results.get('cpu_time')
        tolerances = list(temperature_results.keys())

        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Error profiles over time, one curve per tolerance.
        plotted_profile = False
        for tol in tolerances:
            errors = temperature_results[tol].get('error')
            # Runs without a comparison hold a scalar NaN here; skip them.
            if errors is None or not hasattr(errors, '__len__'):
                continue
            times = np.arange(len(errors)) * timestep
            # Errors are signed, so take the magnitude before the log; the
            # floor of 1e-10 keeps log10 finite where the error is zero.
            log_abs_errors = np.log10(np.maximum(np.abs(errors), 1e-10))
            axes[0, 0].plot(times, log_abs_errors, '-', linewidth=2, label=f'{tol_2_analyze}={tol}')
            plotted_profile = True

        axes[0, 0].set_xlabel('Time (s)')
        axes[0, 0].set_ylabel('log10(Absolute Error / K)')
        axes[0, 0].set_title('Temperature Error Profiles')
        axes[0, 0].grid(True, alpha=0.3)
        if plotted_profile:
            axes[0, 0].legend(bbox_to_anchor=(1.05, 1), loc='upper left')

        # RMSE vs tolerance (missing values plot as gaps).
        rmses = [temperature_results[tol].get('rmse', np.nan) for tol in tolerances]
        axes[0, 1].semilogx(tolerances, rmses, 's-', linewidth=2, markersize=8, color='orange')
        axes[0, 1].set_xlabel(f'{tol_2_analyze}')
        axes[0, 1].set_ylabel('RMSE (K)')
        axes[0, 1].set_title(f'Temperature: RMSE vs {tol_2_analyze}')
        axes[0, 1].grid(True, alpha=0.3)

        # Max relative error vs tolerance.
        max_rel_errors = [temperature_results[tol].get('max_relative_error', np.nan) for tol in tolerances]
        axes[1, 0].semilogx(tolerances, max_rel_errors, '^-', linewidth=2, markersize=8, color='green')
        axes[1, 0].set_xlabel(f'{tol_2_analyze}')
        axes[1, 0].set_ylabel('Max Relative Error')
        axes[1, 0].set_title(f'Temperature: Max Relative Error vs {tol_2_analyze}')
        axes[1, 0].grid(True, alpha=0.3)

        # CPU time vs tolerance.
        if cpu_results is not None:
            cpu_times = [cpu_results.get(tol, {}).get('cpu_time', np.nan) for tol in tolerances]
            axes[1, 1].semilogx(tolerances, cpu_times, 'd-', linewidth=2, markersize=8, color='red')
            axes[1, 1].set_xlabel(f'{tol_2_analyze}')
            axes[1, 1].set_ylabel('CPU Time (s)')
            axes[1, 1].set_title(f'CPU Time vs {tol_2_analyze}')
            axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    # Species comparison summary
    if 'species' in comparison_results:
        species_results = comparison_results['species']
        tolerances = list(species_results.keys())

        fig, ax = plt.subplots(figsize=(12, 8))

        # Average RMSE over the species each tolerance actually holds, so an
        # empty first entry no longer blanks out every other tolerance.
        avg_rmses = []
        for tol in tolerances:
            rmses = [entry['rmse'] for entry in species_results[tol].values() if 'rmse' in entry]
            avg_rmses.append(np.mean(rmses) if rmses else np.nan)

        ax.semilogx(tolerances, avg_rmses, 'o-', linewidth=2, markersize=8, color='purple')
        ax.set_xlabel(f'{tol_2_analyze}')
        ax.set_ylabel('Average RMSE across all species')
        ax.set_title(f'Species Profiles: Average RMSE vs {tol_2_analyze}')
        ax.grid(True, alpha=0.3)
        plt.show()

def print_tolerance_summary(comparison_results, reference_tolerance=1e-14):
    """Write a text report of a reference-based tolerance comparison to stdout.

    Args:
        comparison_results: Mapping that may contain the sections 'temperature',
            'cpu_time' and 'species'. Each section maps a tolerance to its
            statistics; for 'species' the statistics are nested per species name.
        reference_tolerance: Value shown in the report header only.

    Rows whose key statistic ('max_error', 'cpu_time', or the species-averaged
    'rmse') is NaN are left out of the report.
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"TOLERANCE COMPARISON SUMMARY (Reference: {reference_tolerance})")
    print(f"{banner}")

    if 'temperature' in comparison_results:
        print("\nTEMPERATURE COMPARISON:")
        print("-" * 40)
        temperature_stats = comparison_results['temperature']
        for tolerance in sorted(temperature_stats):
            stats = temperature_stats[tolerance]
            if np.isnan(stats['max_error']):
                continue
            print(f"Tolerance {tolerance:>8}: Max Error = {stats['max_error']:>8.2f} K, "
                  f"RMSE = {stats['rmse']:>8.2f} K, "
                  f"Max Rel Error = {stats['max_relative_error']:>8.2e}")

    if 'cpu_time' in comparison_results:
        print("\nCPU TIME COMPARISON:")
        print("-" * 40)
        timing_stats = comparison_results['cpu_time']
        for tolerance in sorted(timing_stats):
            stats = timing_stats[tolerance]
            if np.isnan(stats['cpu_time']):
                continue
            print(f"Tolerance {tolerance:>8}: CPU Time = {stats['cpu_time']:>8.2f} s, "
                  f"Speedup = {stats['speedup']:>8.2f}x")

    if 'species' in comparison_results:
        print("\nSPECIES PROFILES COMPARISON:")
        print("-" * 40)
        species_stats = comparison_results['species']
        tolerance_keys = list(species_stats)
        if tolerance_keys:
            # The species list is taken from the first tolerance in insertion order.
            species_names = list(species_stats[tolerance_keys[0]])
            print(f"Number of species compared: {len(species_names)}")

            # One line per tolerance: RMSE averaged over all listed species.
            for tolerance in sorted(tolerance_keys):
                per_species = species_stats[tolerance]
                if not per_species:
                    continue
                rmse_values = []
                for name in species_names:
                    if 'rmse' in per_species[name]:
                        rmse_values.append(per_species[name]['rmse'])
                if not rmse_values:
                    continue
                mean_rmse = np.mean(rmse_values)
                if np.isnan(mean_rmse):
                    continue
                print(f"Tolerance {tolerance:>8}: Average RMSE = {mean_rmse:>8.2e}")
# Plot the reference-based comparison when results are available.
# `comparison_results` and `tol_2_analyze` are read from the notebook namespace.
if comparison_results:
    # The text summary is currently disabled; uncomment the next line to print it.
    # print_tolerance_summary(comparison_results, reference_tolerance=1e-14)
    
    # Plot results
    plot_tolerance_comparison(comparison_results, reference_tolerance=1e-14, tol_2_analyze=tol_2_analyze)

In [None]:
def compare_consecutive_tolerances(all_tol_data, available_tolerance=None, metrics=['temperature', 'species', 'cpu_time']):
    """
    Run the cascading (neighbour-to-neighbour) tolerance comparison.

    The usable tolerances are sorted ascending and every tolerance is compared
    against the one just before it, which acts as its reference:
    - 1e-13 vs 1e-14 (reference)
    - 1e-12 vs 1e-13 (reference)
    - 1e-11 vs 1e-12 (reference)
    - ... and so on

    Args:
        all_tol_data: Dictionary containing tolerance results, keyed by tolerance.
            Entries whose value is None are ignored.
        available_tolerance: Optional iterable restricting which tolerances are
            used; values missing from all_tol_data are dropped. When None, every
            key of all_tol_data is a candidate.
        metrics: List of metrics to compare ('temperature', 'species', 'cpu_time').
            Any other name yields an empty dict for that metric.

    Returns:
        Dictionary mapping each requested metric to its cascading comparison
        results, or None when fewer than two tolerances are usable.
    """
    candidates = all_tol_data.keys() if available_tolerance is None else available_tolerance

    # Keep only tolerances that have data, smallest (most precise) first.
    ordered_tolerances = sorted(
        tol for tol in candidates
        if tol in all_tol_data and all_tol_data[tol] is not None
    )

    if len(ordered_tolerances) < 2:
        print("Need at least 2 tolerance values for comparison")
        return None

    print(f"Performing cascading tolerance comparison for {len(ordered_tolerances)} tolerance values")
    print(f"Tolerance order (most precise to least precise): {ordered_tolerances}")

    # Lazy dispatch: a comparison only runs when its metric is requested.
    comparators = {
        'temperature': lambda: compare_consecutive_temperature_profiles(all_tol_data, ordered_tolerances),
        'species': lambda: compare_consecutive_species_profiles(all_tol_data, ordered_tolerances),
        'cpu_time': lambda: compare_consecutive_cpu_times(all_tol_data, ordered_tolerances),
    }

    cascading_results = {}
    for metric in metrics:
        run_comparison = comparators.get(metric)
        cascading_results[metric] = {} if run_comparison is None else run_comparison()

    return cascading_results

def compare_consecutive_temperature_profiles(all_tol_data, available_tolerances):
    """Compare temperature profiles between consecutive tolerance values.

    Tolerances are paired in the order given: each entry is the reference for
    the entry that follows it. The reference profile is interpolated onto the
    test run's time points (``interp1d`` with extrapolation) before differencing.

    Args:
        all_tol_data: Mapping tolerance -> run data providing 'times' and
            'temperatures' sequences.
        available_tolerances: Sequence of tolerances, reference before test.

    Returns:
        Dict keyed by the test tolerance of each pair. Every value holds
        'reference_tolerance', 'error' (test minus reference), 'max_error',
        'rmse', 'relative_error', 'max_relative_error' and
        'error_growth_factor'. If either run of a pair has fewer than two time
        points, all statistics of that pair are NaN.
    """
    temp_comparison = {}
    # Max error of the most recent pair that could be evaluated; None until one exists.
    previous_max_error = None

    for current_tol, next_tol in zip(available_tolerances, available_tolerances[1:]):
        # current_tol is the more precise run (reference), next_tol the run under test.
        print(f"Comparing {next_tol} vs {current_tol} (reference)")

        current_data = all_tol_data[current_tol]
        next_data = all_tol_data[next_tol]

        current_times = current_data['times']
        current_temps = current_data['temperatures']
        next_times = next_data['times']
        next_temps = next_data['temperatures']

        # Interpolation needs at least two points on both sides.
        if len(next_times) > 1 and len(current_times) > 1:
            current_interp = interp1d(current_times, current_temps, bounds_error=False, fill_value='extrapolate')
            current_temps_interp = current_interp(next_times)

            errors = next_temps - current_temps_interp
            max_error = np.max(np.abs(errors))
            rmse = np.sqrt(np.mean(errors**2))

            # Replace near-zero reference values by 1e-10 to avoid dividing by zero.
            current_temps_safe = np.where(np.abs(current_temps_interp) > 1e-10, current_temps_interp, 1e-10)
            relative_errors = np.abs(errors) / np.abs(current_temps_safe)
            max_relative_error = np.max(relative_errors)

            # Growth factor = this pair's max error / previous evaluated pair's max error.
            # Testing for None (rather than the loop index) keeps this safe when the
            # leading pair(s) were skipped and no previous error exists yet.
            if previous_max_error is None:
                error_growth_factor = 1.0  # First evaluated comparison: nothing to compare against
            elif previous_max_error > 0:
                error_growth_factor = max_error / previous_max_error
            else:
                error_growth_factor = np.nan

            temp_comparison[next_tol] = {
                'reference_tolerance': current_tol,
                'error': errors,
                'max_error': max_error,
                'rmse': rmse,
                'relative_error': relative_errors,
                'max_relative_error': max_relative_error,
                'error_growth_factor': error_growth_factor
            }

            previous_max_error = max_error

        else:
            # Not enough samples to interpolate: record the pair with NaN placeholders.
            temp_comparison[next_tol] = {
                'reference_tolerance': current_tol,
                'error': np.nan, 'max_error': np.nan, 'rmse': np.nan,
                'relative_error': np.nan, 'max_relative_error': np.nan,
                'error_growth_factor': np.nan
            }

    return temp_comparison

def compare_consecutive_species_profiles(all_tol_data, available_tolerances):
    """Compare species profiles between consecutive tolerance values.

    For every adjacent pair in ``available_tolerances`` the earlier entry is the
    reference and the later one the test run. Only species present in both runs
    are compared; the reference profile is interpolated onto the test run's time
    points before the difference is taken.

    Returns:
        Dict keyed by the test tolerance with 'reference_tolerance' and a
        'species_errors' dict (species name -> 'error', 'max_error', 'rmse',
        'relative_error', 'max_relative_error'). Statistics are NaN when either
        run has fewer than two time points.
    """
    stat_names = ('error', 'max_error', 'rmse', 'relative_error', 'max_relative_error')
    species_comparison = {}

    for idx in range(1, len(available_tolerances)):
        ref_tol = available_tolerances[idx - 1]   # more precise run (reference)
        test_tol = available_tolerances[idx]      # less precise run (test)

        ref_run = all_tol_data[ref_tol]
        test_run = all_tol_data[test_tol]

        ref_times = ref_run['times']
        ref_profiles = ref_run['species_profiles']
        test_times = test_run['times']
        test_profiles = test_run['species_profiles']

        per_species = {}
        species_comparison[test_tol] = {
            'reference_tolerance': ref_tol,
            'species_errors': per_species
        }

        for name, test_profile in test_profiles.items():
            if name not in ref_profiles:
                continue
            ref_profile = ref_profiles[name]

            # Too few samples to interpolate: store NaN placeholders.
            if not (len(test_times) > 1 and len(ref_times) > 1):
                per_species[name] = dict.fromkeys(stat_names, np.nan)
                continue

            resample = interp1d(ref_times, ref_profile, bounds_error=False, fill_value='extrapolate')
            ref_on_test_grid = resample(test_times)

            diff = test_profile - ref_on_test_grid
            abs_diff = np.abs(diff)

            # Near-zero reference values are replaced by 1e-10 before dividing.
            denominator = np.where(np.abs(ref_on_test_grid) > 1e-10, ref_on_test_grid, 1e-10)
            rel_diff = abs_diff / np.abs(denominator)

            per_species[name] = {
                'error': diff,
                'max_error': np.max(abs_diff),
                'rmse': np.sqrt(np.mean(diff**2)),
                'relative_error': rel_diff,
                'max_relative_error': np.max(rel_diff)
            }

    return species_comparison

def compare_consecutive_cpu_times(all_tol_data, available_tolerances):
    """Compare CPU times between consecutive tolerance values.

    Each adjacent pair in ``available_tolerances`` is treated as
    (reference, test). The result is keyed by the test tolerance and stores
    both 'total_cpu_time' readings plus three derived ratios:

    - 'speedup': reference time / test time (NaN if the test time is not > 0)
    - 'efficiency': the speedup folded to at most 1 (speedup or its reciprocal)
    - 'cost_ratio': test time / reference time (NaN if the reference time is not > 0)
    """
    cpu_comparison = {}

    for ref_tol, test_tol in zip(available_tolerances, available_tolerances[1:]):
        ref_run = all_tol_data[ref_tol]
        test_run = all_tol_data[test_tol]

        ref_seconds = ref_run['total_cpu_time']
        test_seconds = test_run['total_cpu_time']

        if test_seconds > 0:
            speedup = ref_seconds / test_seconds
        else:
            speedup = np.nan

        if speedup <= 1:
            efficiency = speedup
        else:
            efficiency = 1 / speedup

        if ref_seconds > 0:
            cost_ratio = test_seconds / ref_seconds
        else:
            cost_ratio = np.nan

        cpu_comparison[test_tol] = dict(
            reference_tolerance=ref_tol,
            cpu_time=test_seconds,
            reference_cpu_time=ref_seconds,
            speedup=speedup,
            efficiency=efficiency,
            cost_ratio=cost_ratio,
        )

    return cpu_comparison

# Example usage:
# `all_tol_data` is read from the notebook namespace. Tolerances listed here that
# are not keys of `all_tol_data` are dropped inside compare_consecutive_tolerances.
# NOTE(review): a later cell rebuilds `all_tol_data` with (rtol, atol) tuple keys;
# if that cell has been run first, none of these scalar values will match and the
# call returns None - confirm the intended execution order.
available_tolerances = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]
cascading_results = compare_consecutive_tolerances(all_tol_data, available_tolerances, metrics=['temperature', 'species', 'cpu_time'])


In [None]:


def _species_grid_shape(n_species):
    """Return the (rows, cols) subplot grid used for n_species panels (capped at 4x4)."""
    if n_species <= 4:
        return 2, 2
    if n_species <= 6:
        return 2, 3
    if n_species <= 9:
        return 3, 3
    return 4, 4


def _plot_species_metric_grid(species_results, species_names, test_tolerances, tol_2_analyze,
                              metric_key, use_loglog, fmt, style, ylabel, panel_title, suptitle):
    """Draw one figure with a panel per species showing metric_key against tolerance."""
    rows, cols = _species_grid_shape(len(species_names))
    fig, axes = plt.subplots(rows, cols, figsize=(15, 12))
    axes = axes.flatten()  # the grid is always at least 2x2, so axes is a 2-D array
    tolerance_array = np.array(test_tolerances)

    # zip() stops at the grid size, so species beyond rows*cols are not drawn.
    for ax, species_name in zip(axes, species_names):
        values = []
        for tol in test_tolerances:
            species_errors = species_results[tol]['species_errors']
            if species_name in species_errors:
                values.append(species_errors[species_name][metric_key])
            else:
                values.append(np.nan)
        values = np.array(values, dtype=float)

        # Log axes cannot show NaN, zero or negative values: drop them.
        drawable = ~np.isnan(values)
        drawable[drawable] = values[drawable] > 0
        if np.any(drawable):
            draw = ax.loglog if use_loglog else ax.semilogx
            draw(tolerance_array[drawable], values[drawable], fmt,
                 linewidth=2, markersize=6, **style)

        ax.set_xlabel(f'{tol_2_analyze}')
        ax.set_ylabel(ylabel)
        ax.set_title(f'{species_name}\n{panel_title} vs {tol_2_analyze}')
        ax.grid(True, alpha=0.3)

    # Hide panels of the grid that received no species.
    for ax in axes[len(species_names):]:
        ax.set_visible(False)

    plt.suptitle(suptitle, fontsize=14)
    plt.tight_layout()
    plt.show()


def plot_cascading_comparison(cascading_results, available_tolerances, tol_2_analyze):
    """Plot results from cascading tolerance comparison.

    Args:
        cascading_results: Dict with optional sections 'temperature', 'species'
            and 'cpu_time', each keyed by test tolerance (the structure returned
            by compare_consecutive_tolerances). Falsy input prints a message and
            returns without plotting.
        available_tolerances: Not used by the plots (the tolerances are read from
            the result keys); kept so existing calls keep working.
        tol_2_analyze: Name of the tolerance being varied; used in labels, titles
            and legend entries.

    Figures produced (one per section that is present):
        - temperature: 2x2 grid with per-sample log10|error| traces, RMSE,
          max relative error and error growth factor.
        - species: two grids (max absolute error on log-log axes, max relative
          error on semilog-x axes), one panel per species, at most 16 species.
        - cpu_time: speedup and cost ratio against tolerance.
    """
    if not cascading_results:
        print("No cascading results to plot")
        return

    # Temperature comparison
    if 'temperature' in cascading_results:
        temperature_results = cascading_results['temperature']
        test_tolerances = sorted(temperature_results)  # smallest to largest
        fig, axes = plt.subplots(2, 2, figsize=(15, 10))

        # Panel [0, 0]: error trace of every comparison.
        trace_ax = axes[0, 0]
        has_trace = False
        for tol in test_tolerances:
            # Skipped comparisons store a scalar NaN instead of an array; atleast_1d
            # lets both cases go through the same path.
            errors = np.atleast_1d(
                np.asarray(temperature_results[tol].get('error', np.nan), dtype=float))
            if not np.any(np.isfinite(errors)):
                continue
            # NOTE(review): the x-axis is the sample index scaled by 1e-6, presumably
            # the output time step in seconds - confirm against the solver setup.
            sample_axis = np.arange(errors.size) * 1e-6
            # Errors are signed, so take the magnitude before the log; the 1e-10
            # floor keeps exact zeros finite.
            log_abs_errors = np.log10(np.maximum(np.abs(errors), 1e-10))
            trace_ax.plot(sample_axis, log_abs_errors, '-', linewidth=2,
                          label=f'{tol_2_analyze}={tol}')
            has_trace = True

        trace_ax.set_xlabel('Sample index × 1e-6')
        trace_ax.set_ylabel('log10 |Temperature Error| (K)')
        trace_ax.set_title(f'Temperature: Error Trace per {tol_2_analyze}\n(Compared to next more precise tolerance)')
        if has_trace:
            trace_ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        trace_ax.grid(True, alpha=0.3)

        # Plot RMSE vs tolerance
        rmses = [temperature_results[tol]['rmse'] for tol in test_tolerances]
        axes[0, 1].semilogx(test_tolerances, rmses, 's-', linewidth=2, markersize=8, color='orange')
        axes[0, 1].set_xlabel(f'{tol_2_analyze}')
        axes[0, 1].set_ylabel('RMSE (K)')
        axes[0, 1].set_title(f'Temperature: RMSE vs {tol_2_analyze}\n(Compared to next more precise tolerance)')
        axes[0, 1].grid(True, alpha=0.3)

        # Plot max relative error vs tolerance
        max_rel_errors = [temperature_results[tol]['max_relative_error'] for tol in test_tolerances]
        axes[1, 0].semilogx(test_tolerances, max_rel_errors, '^-', linewidth=2, markersize=8, color='green')
        axes[1, 0].set_xlabel(f'{tol_2_analyze}')
        axes[1, 0].set_ylabel('Max Relative Error')
        axes[1, 0].set_title(f'Temperature: Max Relative Error vs {tol_2_analyze}\n(Compared to next more precise tolerance)')
        axes[1, 0].grid(True, alpha=0.3)

        # Plot error growth factor vs tolerance
        error_growth_factors = [temperature_results[tol]['error_growth_factor'] for tol in test_tolerances]
        axes[1, 1].semilogx(test_tolerances, error_growth_factors, 'd-', linewidth=2, markersize=8, color='red')
        axes[1, 1].set_xlabel(f'{tol_2_analyze}')
        axes[1, 1].set_ylabel('Error Growth Factor')
        axes[1, 1].set_title(f'Temperature: Error Growth Factor vs {tol_2_analyze}\n(How much error increases from one tolerance to the next)')
        axes[1, 1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

    # Species comparison
    if 'species' in cascading_results:
        species_results = cascading_results['species']
        test_tolerances = sorted(species_results)  # smallest to largest

        if test_tolerances:
            # The species list comes from the smallest test tolerance.
            species_names = list(species_results[test_tolerances[0]]['species_errors'].keys())

            if species_names:
                _plot_species_metric_grid(
                    species_results, species_names, test_tolerances, tol_2_analyze,
                    metric_key='max_error', use_loglog=True, fmt='o-', style={},
                    ylabel='Max Absolute Error (log scale)', panel_title='Max Error',
                    suptitle='Species: Max Error vs Tolerance (Log-Log Scale)\n(Compared to next more precise tolerance)')

                _plot_species_metric_grid(
                    species_results, species_names, test_tolerances, tol_2_analyze,
                    metric_key='max_relative_error', use_loglog=False, fmt='s-',
                    style={'color': 'orange'},
                    ylabel='Max Relative Error', panel_title='Max Relative Error',
                    suptitle=f'Species: Max Relative Error vs {tol_2_analyze}\n(Compared to next more precise tolerance)')

    # CPU time comparison
    if 'cpu_time' in cascading_results:
        cpu_results = cascading_results['cpu_time']
        test_tolerances = sorted(cpu_results)  # smallest to largest
        fig, axes = plt.subplots(1, 2, figsize=(12, 5))

        # Plot speedup vs tolerance
        speedups = [cpu_results[tol]['speedup'] for tol in test_tolerances]
        axes[0].semilogx(test_tolerances, speedups, 'o-', linewidth=2, markersize=8, color='blue')
        axes[0].set_xlabel(f'{tol_2_analyze}')
        axes[0].set_ylabel('Speedup vs Next More Precise Tolerance')
        axes[0].set_title(f'Speedup vs {tol_2_analyze}\n(Compared to next more precise tolerance)')
        axes[0].grid(True, alpha=0.3)

        # Plot cost ratio vs tolerance
        cost_ratios = [cpu_results[tol]['cost_ratio'] for tol in test_tolerances]
        axes[1].semilogx(test_tolerances, cost_ratios, 's-', linewidth=2, markersize=8, color='purple')
        axes[1].set_xlabel(f'{tol_2_analyze}')
        axes[1].set_ylabel('Computational Cost Ratio')
        axes[1].set_title(f'Computational Cost Ratio vs {tol_2_analyze}\n(Compared to next more precise tolerance)')
        axes[1].grid(True, alpha=0.3)

        plt.tight_layout()
        plt.show()

def print_cascading_summary(cascading_results, available_tolerances):
    """Print a summary of cascading tolerance comparison results.

    Args:
        cascading_results: Dict with optional 'temperature' and 'cpu_time'
            sections keyed by test tolerance (the structure returned by
            compare_consecutive_tolerances). Falsy input prints a message and
            returns, matching plot_cascading_comparison.
        available_tolerances: Tolerance list echoed in the header; its length
            minus one is reported as the number of comparisons.

    Table rows are listed from the smallest (most precise) test tolerance to the
    largest, and rows whose key statistic is NaN are omitted.
    """
    if not cascading_results:
        print("No cascading results to summarize")
        return

    print(f"\n{'='*80}")
    print("CASCADING TOLERANCE COMPARISON SUMMARY")
    print(f"{'='*80}")
    print(f"Tolerance order (most precise to least precise): {available_tolerances}")
    print(f"Number of comparisons: {len(available_tolerances) - 1}")

    if 'temperature' in cascading_results:
        temperature_results = cascading_results['temperature']
        print(f"\n{'TEMPERATURE CASCADING COMPARISON':-^80}")
        print(f"{'Tolerance':>12} {'Reference':>12} {'Max Error':>12} {'RMSE':>12} {'Max Rel Error':>15}")
        print("-" * 80)

        # Ascending sort = most precise to least precise, as announced in the header.
        for tol in sorted(temperature_results):
            result = temperature_results[tol]
            if not np.isnan(result['max_error']):
                print(f"{tol:>12} {result['reference_tolerance']:>12} "
                      f"{result['max_error']:>12.2f} {result['rmse']:>12.2f} "
                      f"{result['max_relative_error']:>15.2e}")

    if 'cpu_time' in cascading_results:
        cpu_results = cascading_results['cpu_time']
        print(f"\n{'CPU TIME CASCADING COMPARISON':-^80}")
        print(f"{'Tolerance':>12} {'Reference':>12} {'CPU Time':>12} {'Speedup':>12} {'Cost Ratio':>12}")
        print("-" * 80)

        for tol in sorted(cpu_results):
            result = cpu_results[tol]
            if not np.isnan(result['cpu_time']):
                print(f"{tol:>12} {result['reference_tolerance']:>12} "
                      f"{result['cpu_time']:>12.2f} {result['speedup']:>12.2f} "
                      f"{result['cost_ratio']:>12.2f}")

    # Analyze error propagation
    if 'temperature' in cascading_results:
        print(f"\n{'ERROR PROPAGATION ANALYSIS':-^80}")

        # Collect every non-NaN growth factor. The statistics below do not depend
        # on ordering.
        error_growth_factors = [
            result['error_growth_factor']
            for result in cascading_results['temperature'].values()
            if not np.isnan(result['error_growth_factor'])
        ]

        if error_growth_factors:
            avg_growth = np.mean(error_growth_factors)
            max_growth = np.max(error_growth_factors)
            min_growth = np.min(error_growth_factors)

            print(f"Average error growth factor: {avg_growth:.2f}")
            print(f"Maximum error growth factor: {max_growth:.2f}")
            print(f"Minimum error growth factor: {min_growth:.2f}")

            if avg_growth > 1.5:
                print("⚠️  Warning: High error propagation detected - errors are growing significantly between tolerance levels")
            elif avg_growth < 1.1:
                print("✅ Good: Low error propagation - errors are growing slowly between tolerance levels")
            else:
                print("⚠️  Moderate: Some error propagation - errors are growing moderately between tolerance levels")


# Render the cascading comparison plots. `cascading_results`, `available_tolerances`
# and `tol_2_analyze` are read from the notebook namespace.
# The text summary is currently disabled; uncomment the next line to print it.
# print_cascading_summary(cascading_results, available_tolerances)
plot_cascading_comparison(cascading_results, available_tolerances, tol_2_analyze=tol_2_analyze)

In [None]:
# Load saved results for every (rtol, atol) combination of the sweep grid.
rtol_tolerances = [1e-12, 1e-11, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]
atol_tolerances = [1e-10, 1e-9, 1e-8, 1e-7, 1e-6]
all_tol_data = {}
# The loop variables are deliberately NOT named `rtol` / `atol`: those names hold the
# notebook-wide solver settings defined in the configuration cell and must not be
# overwritten as a side effect of loading results.
for rtol_value in rtol_tolerances:
    for atol_value in atol_tolerances:
        try:
            with open(f'tolerance_analysis_results/tolerance_results_rtol_{rtol_value}_atol_{atol_value}.pkl', 'rb') as f:
                # NOTE: unpickling can execute arbitrary code; only load result files
                # produced by this project.
                tolerance_results = pickle.load(f)
            all_tol_data[(rtol_value, atol_value)] = tolerance_results
        except FileNotFoundError:
            # Best effort: combinations without a saved result file are skipped.
            pass


# Show which (rtol, atol) combinations were found on disk.
all_tol_data.keys()