In [None]:
import numpy as np
import time
import matplotlib.pyplot as plt
from IPython.display import clear_output # Import the clear_output function
import multiprocessing # <-- Import multiprocessing
from simulation_parameters import *
# --- MODIFIED: Import property getters from utils ---
from utils import (check_convergence, calculate_residual, 
                   get_effective_thermal_conductivity, get_effective_permeability)
# --- END MODIFICATION ---
from physics_solvers import (solve_pressure, solve_heat_equation, update_fin_and_htf,
                             solve_inter_fin_conduction, solve_heat_equation_empty, solve_pressure_empty)
from hydration_kinetics import (predict_final_alphas_parallel, calculate_avg_sources,
                                  calculate_conversion)
# Import both plotting functions
from plotting_utils import plot_reactor_heatmap, plot_zone_cross_section

# --- Cross-section plot selection (both settings are lists) ---

# Zone indices whose cross-sections are drawn (e.g. [0, 5, 9]).
PLOT_CROSS_SECTION_ZONE_INDICES = [0]

# NUM_CROSS_SECTION_PLOTS evenly spaced time-step indices covering
# [0, NUM_TIMESTEPS - 1]. Rounding can yield repeated indices when the two
# counts are close, so np.unique de-duplicates (and sorts) them.
indices = np.unique(
    np.round(np.linspace(0, NUM_TIMESTEPS - 1, NUM_CROSS_SECTION_PLOTS)).astype(int)
).tolist()
print(f"num_points: {indices}")

PLOT_CROSS_SECTION_TIME_STEP_INDICES = indices

def initialize_state():
    """Create the uniform initial state of every reactor zone.

    Returns:
        tuple: ``(T_zones, P_zones, alphas_zones, avg_mass_source_zones,
        T_htf_zones, T_fin_zones, history)`` where

        - ``T_zones`` / ``P_zones``: one ``(NZ, NX)`` float array per zone,
          filled with ``T_INITIAL`` / ``P_INITIAL``;
        - ``alphas_zones``: one ``np.tile(ALPHA_INITIAL, (NZ, NX, 1))`` array
          per zone;
        - ``avg_mass_source_zones``: one ``(NZ, NX)`` array of zeros per zone;
        - ``T_htf_zones`` / ``T_fin_zones``: length-``NUM_ZONES`` arrays
          filled with ``T_INITIAL``;
        - ``history``: dict with empty lists under ``'htf'``, ``'avg_temp'``,
          ``'avg_pressure'`` and ``'avg_conversion'``.
    """
    print("Initializing reactor...")

    shape = (NZ, NX)
    temperature0 = float(T_INITIAL)
    pressure0 = float(P_INITIAL)
    zones = range(NUM_ZONES)

    # Every zone receives its own freshly allocated array (no shared buffers).
    T_zones = [np.full(shape, temperature0) for _ in zones]
    P_zones = [np.full(shape, pressure0) for _ in zones]
    alphas_zones = [np.tile(ALPHA_INITIAL, (NZ, NX, 1)) for _ in zones]
    avg_mass_source_zones = [np.zeros(shape) for _ in zones]

    # One scalar per zone for the heat-transfer fluid and for the fin.
    T_htf_zones = np.full(NUM_ZONES, temperature0)
    T_fin_zones = np.full(NUM_ZONES, temperature0)

    # Per-time-step records consumed by the final heatmaps.
    history = {key: [] for key in ('htf', 'avg_temp', 'avg_pressure', 'avg_conversion')}

    return T_zones, P_zones, alphas_zones, avg_mass_source_zones, T_htf_zones, T_fin_zones, history

# --- NEW: Worker function for parallel P/T solves ---
def solve_physics_worker(args):
    """Solve the pressure field and then the temperature field of one zone.

    Used as the ``multiprocessing.Pool.map`` target, so it receives one packed
    tuple and has to remain a module-level (picklable) function.

    Args:
        args: ``(P_guess, avg_mass_source, T_guess, alphas_old, P_inf,
            T_old, avg_heat_source, T_fin_guess, dt)``.

    Returns:
        ``(P_new, T_new)``: the results of ``solve_pressure`` (warm-started
        from ``P_guess``) and ``solve_heat_equation``.
    """
    (p_start, mass_src, t_start, alphas_prev, p_far,
     t_prev, heat_src, t_fin, step) = args

    pressure = solve_pressure(p_start, mass_src, t_start, alphas_prev, p_far,
                              use_warm_start=True)
    temperature = solve_heat_equation(t_prev, heat_src, alphas_prev, t_fin, step)

    return pressure, temperature
# --- END NEW FUNCTION ---


def solve_coupling_loop(n, dt, T_zones, P_zones, alphas_zones, avg_mass_source_zones, T_fin_zones, T_htf_zones):
    """Iterate the kinetics / P-T / HTF-fin coupling for one time step.

    Each coupling iteration:
      a. predicts end-of-step conversions from the current T/P guesses and
         converts them into average heat and mass sources;
      b. under-relaxes those sources with the adaptive factor ``omega``;
      c. solves pressure and temperature of every zone in parallel;
      d. marches the HTF through the zones sequentially, updating HTF outlet
         and fin temperatures;
      e. applies inter-fin conduction;
      f. checks convergence and adapts ``omega`` from the residual trend;
      g. under-relaxes the T, P and fin guesses, capping the largest step.

    Args:
        n: Index of the current time step (not used by the solve itself).
        dt: Time-step size passed to the kinetics and heat solvers.
        T_zones: Per-zone temperatures at the start of the step (used as the
            "old" temperature of the time discretisation; not modified).
        P_zones: Per-zone pressures at the start of the step (not modified).
        alphas_zones: Per-zone conversion state at the start of the step.
        avg_mass_source_zones: Per-zone mass source of the previous step,
            used as the initial mass-source guess.
        T_fin_zones: Fin temperatures at the start of the step.
        T_htf_zones: HTF outlet temperature per zone; updated IN PLACE.

    Returns:
        ``(T_guess_zones, P_guess_zones, T_fin_guess_zones,
        alphas_final_zones, avg_mass_source_zones, T_htf_zones)`` from the
        last coupling iteration, whether or not it converged.
        ``alphas_final_zones`` is the kinetics prediction made at the start of
        that iteration.

    Raises:
        RuntimeError: if the kinetics produce NaN/Inf sources or a residual
            becomes NaN (typically a sign that ``DT`` is too large).
    """
    # Snapshot of the start-of-step state; every iteration steps from these.
    alphas_old_zones = [a.copy() for a in alphas_zones]
    mass_source_guess_zones = [m.copy() for m in avg_mass_source_zones]
    # NOTE(review): the heat-source guess restarts from zero each time step,
    # whereas the mass-source guess starts from the previous step's value.
    avg_heat_source_zones_guess = [np.zeros((NZ, NX)) for _ in range(NUM_ZONES)]

    # Iterates of the coupling loop.
    T_guess_zones = [T.copy() for T in T_zones]
    P_guess_zones = [P.copy() for P in P_zones]
    T_fin_guess_zones = np.copy(T_fin_zones)

    current_omega = OMEGA_MASS_SOURCE_INITIAL
    last_P_residual = np.inf
    last_T_residual = np.inf
    # Returned unchanged only if MAX_ITER_COUPLING is 0; overwritten otherwise.
    alphas_final_zones = alphas_old_zones

    print("Starting coupling loop...")
    # One worker pool for the whole time step. Starting worker processes is far
    # more expensive than a coupling iteration, so the pool must not be
    # recreated inside the loop. Leaving the `with` block terminates it.
    with multiprocessing.Pool() as pool:
        for coupling_iter in range(MAX_ITER_COUPLING):

            # --- 3a. Predict Sources (Parallel) ---
            alphas_final_zones = predict_final_alphas_parallel(
                alphas_old_zones, T_guess_zones, P_guess_zones, dt
            )

            avg_heat_source_zones_new_calc = []
            avg_mass_source_zones_new_calc = []
            for i in range(NUM_ZONES):
                avg_heat, avg_mass = calculate_avg_sources(
                    alphas_old_zones[i], alphas_final_zones[i], dt
                )
                avg_heat_source_zones_new_calc.append(avg_heat)
                avg_mass_source_zones_new_calc.append(avg_mass)

            # Debug: report the mass-source range on the first iteration only.
            if coupling_iter == 0:
                all_mass_sources = np.concatenate([arr.flatten() for arr in avg_mass_source_zones_new_calc])
                print(f"  [Debug] Mass Source Range: {np.min(all_mass_sources):.2e} to {np.max(all_mass_sources):.2e} kg/m^3/s")

            # --- Abort on NaN/Inf from the kinetics solver (mass AND heat sources) ---
            kinetics_failed = any(
                not np.all(np.isfinite(arr))
                for arr in (*avg_mass_source_zones_new_calc, *avg_heat_source_zones_new_calc)
            )
            if kinetics_failed:
                print(f"  [Coupling Iter {coupling_iter+1}] CRITICAL: Kinetics solver failed (NaN/Inf detected).")
                print("  This is likely due to numerical instability (oscillating T/P guesses).")
                print(f"  RECOMMENDATION: Reduce DT (currently {DT}) in simulation_parameters.py.")
                raise RuntimeError(f"Kinetics solver failed (NaN/Inf). Reduce DT from {DT}.")

            # --- 3b. Apply Source Relaxation ---
            for i in range(NUM_ZONES):
                mass_source_guess_zones[i] = (
                    mass_source_guess_zones[i] * (1.0 - current_omega)
                    + avg_mass_source_zones_new_calc[i] * current_omega
                )
                avg_heat_source_zones_guess[i] = (
                    avg_heat_source_zones_guess[i] * (1.0 - current_omega)
                    + avg_heat_source_zones_new_calc[i] * current_omega
                )

            avg_mass_source_zones = mass_source_guess_zones

            # --- 3c. Solve Spatial Physics (Parallel) ---
            physics_args_list = [
                (P_guess_zones[i], avg_mass_source_zones[i], T_guess_zones[i], alphas_old_zones[i], P_INF,
                 T_zones[i], avg_heat_source_zones_guess[i], T_fin_guess_zones[i], dt)
                for i in range(NUM_ZONES)
            ]
            physics_results = pool.map(solve_physics_worker, physics_args_list)
            P_new_guess_zones = [P_new for P_new, _ in physics_results]
            T_new_guess_zones = [T_new for _, T_new in physics_results]

            # --- 3d. Update HTF and Fin Temps (Sequential) ---
            # The HTF leaving zone i is the inlet of zone i + 1, so this march
            # must run in order, after the parallel solve whose T it uses.
            T_fin_new_zones = np.copy(T_fin_guess_zones)
            T_htf_inlet_current_zone = T_HTF_IN
            for i in range(NUM_ZONES):
                T_htf_outlet, T_fin_new_zones[i] = update_fin_and_htf(
                    T_htf_inlet_current_zone, T_new_guess_zones[i], HTF_MASS_FLOW / NUM_ZONES, CP_HTF
                )
                T_htf_zones[i] = T_htf_outlet  # writes into the caller's array
                T_htf_inlet_current_zone = T_htf_outlet

            # --- 3e. Inter-Zone Conduction ---
            T_fin_new_zones = solve_inter_fin_conduction(T_fin_new_zones)

            # --- 3f. Check Convergence & Adapt ---
            T_converged = check_convergence(T_new_guess_zones, T_guess_zones, TOLERANCE_COUPLING)
            P_converged = check_convergence(P_new_guess_zones, P_guess_zones, TOLERANCE_COUPLING)
            Fin_converged = check_convergence(T_fin_new_zones, T_fin_guess_zones, TOLERANCE_COUPLING)

            current_T_residual = calculate_residual(T_new_guess_zones, T_guess_zones)
            current_P_residual = calculate_residual(P_new_guess_zones, P_guess_zones)

            if np.isnan(current_P_residual) or np.isnan(current_T_residual):
                print(f"  [Coupling Iter {coupling_iter+1}] CRITICAL: Residual is NaN.")
                print(f"  RECOMMENDATION: Reduce DT (currently {DT}) in simulation_parameters.py.")
                raise RuntimeError(f"Solver failed (NaN residual). Reduce DT from {DT}.")

            # Grow omega while either residual improves, shrink it otherwise.
            if coupling_iter > 0:
                if current_P_residual < last_P_residual or current_T_residual < last_T_residual:
                    current_omega = min(current_omega * ADAPTIVE_OMEGA_INCREASE, OMEGA_MASS_SOURCE_MAX)
                else:
                    current_omega = max(current_omega * ADAPTIVE_OMEGA_DECREASE, OMEGA_MASS_SOURCE_MIN)
            # Record the residuals on EVERY iteration (including the first) so
            # the next comparison is against the previous iterate, not inf.
            last_P_residual = current_P_residual
            last_T_residual = current_T_residual

            print(f"  [Coupling Iter {coupling_iter+1}] P_residual: {current_P_residual:.2e}, T_residual: {current_T_residual:.2e}, new omega: {current_omega:.3f}")

            # --- 3g. Apply State Relaxation ---
            # Relaxation ceilings halve every 100 coupling iterations.
            halving_factor = 0.5 ** (coupling_iter // 100)
            current_omega_T_ceiling = OMEGA_STATE_RELAXATION_T * halving_factor
            current_omega_P_ceiling = OMEGA_STATE_RELAXATION_P * halving_factor

            max_delta_T = 0.0
            max_delta_P = 0.0
            for i in range(NUM_ZONES):
                max_delta_T = max(max_delta_T, np.max(np.abs(T_new_guess_zones[i] - T_guess_zones[i])))
                max_delta_P = max(max_delta_P, np.max(np.abs(P_new_guess_zones[i] - P_guess_zones[i])))

            # Step limiter: choose omega so the largest per-cell update this
            # iteration is at most max_change_T (K) / max_change_P (Pa).
            max_change_T = 1.0
            max_change_P = 500.0
            omega_T = min(current_omega_T_ceiling, max_change_T / (max_delta_T + 1e-10))
            omega_P = min(current_omega_P_ceiling, max_change_P / (max_delta_P + 1e-10))

            for i in range(NUM_ZONES):
                T_guess_zones[i] = (1.0 - omega_T) * T_guess_zones[i] + omega_T * T_new_guess_zones[i]
                P_guess_zones[i] = (1.0 - omega_P) * P_guess_zones[i] + omega_P * P_new_guess_zones[i]

            # Fins use the same under-relaxation as T: keep (1 - omega_T) of
            # the previous guess and take omega_T of the newly computed value.
            T_fin_guess_zones = (1.0 - omega_T) * T_fin_guess_zones + omega_T * T_fin_new_zones

            # Require a few iterations before accepting convergence.
            if T_converged and P_converged and Fin_converged and coupling_iter >= 4:
                print(f"Coupling loop converged in {coupling_iter + 1} iterations.")
                break
        else:
            print(f"  WARNING: Coupling loop did not converge in {MAX_ITER_COUPLING} iterations.")

    return T_guess_zones, P_guess_zones, T_fin_guess_zones, alphas_final_zones, avg_mass_source_zones, T_htf_zones


def log_and_plot_timestep(n, T_zones, P_zones, alphas_zones, avg_mass_source_zones, T_htf_zones, history, dt, nx, nz):
    """Report a finished time step, optionally plot it, and record its history.

    Prints the HTF temperature of the last zone. When ``n`` is listed in
    ``PLOT_CROSS_SECTION_TIME_STEP_INDICES``, draws a cross-section for each
    zone in ``PLOT_CROSS_SECTION_ZONE_INDICES``. Finally appends the HTF
    temperatures and per-zone mean temperature, pressure and conversion to
    ``history`` (mutated in place).
    """
    T_outlet = T_htf_zones[NUM_ZONES - 1]
    print(f"Time step {n+1} complete. HTF Outlet Temperature: {T_outlet:.2f} K")

    if n in PLOT_CROSS_SECTION_TIME_STEP_INDICES:
        print(f"Generating cross-section plots for Time Step {n}...")

        for zone in PLOT_CROSS_SECTION_ZONE_INDICES:
            zone_alphas = alphas_zones[zone]
            # Only the (0, 0) entry of each effective-property field is passed on.
            conductivity = get_effective_thermal_conductivity(zone_alphas)[0, 0]
            permeability = get_effective_permeability(zone_alphas)[0, 0]

            plot_zone_cross_section(
                T_zones[zone], P_zones[zone], zone_alphas,
                avg_mass_source_zones[zone],
                W, H, zone, n,
                dt, nx, nz, conductivity, permeability,
            )

    # Per-zone summaries for the final heatmaps.
    history['htf'].append(np.copy(T_htf_zones))
    history['avg_temp'].append(list(map(np.mean, T_zones)))
    history['avg_pressure'].append(list(map(np.mean, P_zones)))
    history['avg_conversion'].append([np.mean(calculate_conversion(zone_a)) for zone_a in alphas_zones])
    history['avg_conversion'].append([np.mean(calculate_conversion(a)) for a in alphas_zones])

def plot_final_heatmaps(history):
    """Draw a 2x2 grid of heatmaps summarising the recorded history.

    Panels, in row-major order: HTF temperature, average solid temperature,
    average vapor pressure and average conversion. Each is drawn by
    ``plot_reactor_heatmap`` against ``T_SIMULATION`` and ``HTF_PATH_LENGTH``.
    """
    print("Generating heatmaps...")

    # (subplot position, history key, panel title, colour-bar label)
    panels = (
        ((0, 0), 'htf', 'HTF Temperature Evolution', 'HTF Temperature (K)'),
        ((0, 1), 'avg_temp', 'Average Solid Temperature Evolution', 'Avg. Temperature (K)'),
        ((1, 0), 'avg_pressure', 'Average Vapor Pressure Evolution', 'Avg. Pressure (Pa)'),
        ((1, 1), 'avg_conversion', 'Average Conversion Evolution', 'Conversion (0.0 - 1.0)'),
    )

    fig, axes = plt.subplots(2, 2, figsize=(8, 6))
    fig.suptitle('Reactor Performance Over Time', fontsize=16)

    for position, key, panel_title, colorbar_label in panels:
        plot_reactor_heatmap(axes[position], history[key], T_SIMULATION, HTF_PATH_LENGTH,
                             title=panel_title, cbar_label=colorbar_label)

    # Reserve the top 5% of the figure for the suptitle.
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

def run_simulation():
    """Run the full transient simulation and plot the results.

    Advances ``NUM_TIMESTEPS`` steps of size ``DT``. Each step is solved by
    ``solve_coupling_loop``, and its result becomes the starting state of the
    next step. After every step the notebook output is cleared before the step
    is logged/plotted; the summary heatmaps are drawn once at the end.
    """
    (T_zones, P_zones, alphas_zones, avg_mass_source_zones,
     T_htf_zones, T_fin_zones, history) = initialize_state()

    for n in range(NUM_TIMESTEPS):
        t_current = n * DT
        print(f"\n--- Time step {n+1}/{NUM_TIMESTEPS}, Time: {t_current:.2f}s ---")

        # Solve the step; the returned state replaces the current one.
        (T_zones, P_zones, T_fin_zones, alphas_zones,
         avg_mass_source_zones, T_htf_zones) = solve_coupling_loop(
            n, DT, T_zones, P_zones, alphas_zones,
            avg_mass_source_zones, T_fin_zones, T_htf_zones,
        )

        # Notebook only: replace the previous step's output with this one's.
        clear_output(wait=True)
        log_and_plot_timestep(
            n, T_zones, P_zones, alphas_zones,
            avg_mass_source_zones, T_htf_zones, history,
            DT, NX, NZ,
        )

    plot_final_heatmaps(history)


if __name__ == "__main__":
    # Run the simulation only when this module is the entry point, not when it
    # is imported. With multiprocessing start methods that re-import the main
    # module in each worker (e.g. "spawn"), this guard keeps workers from
    # launching their own simulation.
    # NOTE(review): when run from a notebook, the worker function defined here
    # (solve_physics_worker) may not be importable by "spawn" workers
    # (Windows/macOS default) — confirm, or move it into a module.
    run_simulation()



num_points: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99]
Initializing reactor...

--- Time step 1/100, Time: 0.00s ---
Starting coupling loop...
  [Coupling Iter 1] T/P for cell (0,0,0): 353.00 K, 1200.00 Pa
  [Debug] Mass Source Range: 1.19e+00 to 1.19e+00 kg/m^3/s
  [Coupling Iter 1] P_residual: 1.11e-01, T_residual: 3.56e-03, new omega: 0.100
  [Coupling Iter 2] T/P for cell (0,0,0): 352.91 K, 1245.22 Pa
  [Coupling Iter 2] P_residual: 1.66e-01, T_residual: 6.50e-03, new omega: 0.120
  [Coupling Iter 3] T/P for cell (0,0,0): 352.76 K, 1314.11 Pa
  [Coupling Iter 3] P_residual: 1.97e-01, T_residual: 9.36e-03, new omega: 0.060
  [Coupling Iter 4