# In [None]:
#compute_trajectory
import numpy as np
from scipy.integrate import solve_ivp
import time # Optional: for timing checks if needed

# compute_trajectory calculates the trajectory from the initial polar coordinates and the launch angle psi (which fixes the impact parameter b)
# t_end is the maximum lambda (K) value to integrate to; r_max is the maximum radius before integration stops
# x_stop: If provided (number), integration stops if x decreases below this value.
# x_targets: A list/array of x-values. Record y-values when the trajectory crosses each x-value moving from x > x_target to x < x_target.
#            Uses interpolation for higher accuracy at the crossing point.
# num_interp_points: The number of points to use for the final interpolated output trajectory.
# Compute_trajectory returns:
#   x, y, r, phi, K (trajectory arrays, interpolated to num_interp_points)
#   crossings_y_at_x_targets (list of lists for y-values at x_targets, calculated accurately at event times)
def compute_trajectory(r_0, phi_0, M, psi, r_max, x_stop=None, x_targets=None, t_end=10**6, num_interp_points=1000):
    """Integrate a ray around a point mass M and return an evenly resampled trajectory.

    The orbit is integrated in the (r, phi) plane with parameter K:

        dr/dK   = sign * sqrt(|1 - (1 - 2M/r) * b**2 / r**2|)
        dphi/dK = b / r**2

    with b = cos(psi) * r_0 / sqrt(1 - 2M/r_0). The radial direction starts
    inward (sign = -1) when psi mod 2*pi > pi, outward otherwise, and is
    reversed every time the radicand crosses zero (a radial turning point).

    Integration stops when r falls to 2M, when r rises through r_max, when
    x = r*cos(phi) decreases through x_stop (if x_stop is a real number), or
    when K reaches t_end.

    Args:
        r_0, phi_0: Initial polar coordinates.
        M: Mass parameter.
        psi: Launch angle; fixes b and the initial radial direction (see above).
        r_max: Radius at which an outgoing ray is stopped.
        x_stop: Optional real number (Python or NumPy scalar). Integration stops
            when x decreases through this value.
        x_targets: Optional iterable of x-values. A y-value is recorded each time
            the ray crosses a target moving from x > x_target to x < x_target.
        t_end: Maximum K value to integrate to.
        num_interp_points: Number of evenly spaced K samples in the output.

    Returns:
        x, y, r, phi, K: 1-D arrays sampled on np.linspace(0, K_final,
            num_interp_points); phi is reduced modulo 2*pi and any NaN samples
            are dropped. If no integration segment was produced, a single
            point at K = 0 (the initial condition) is returned.
        crossings_y_at_x_targets: One list per x_target with the y-values of
            its crossings, evaluated from the dense solution at the event times.
    """
    import numbers  # local import keeps this notebook cell self-contained

    # --- Basic Setup ---
    b = np.cos(psi) * r_0 / np.sqrt(1 - 2 * M / r_0)
    # One-element list so the nested ODE closure sees sign flips made in the loop.
    sign = [1]
    if psi % (2 * np.pi) > np.pi:
        sign[0] *= -1

    # --- ODE Definition ---
    def ODE_Solver(K, u):
        r, phi = u
        if r <= 2 * M:
            return [0.0, 0.0]
        metric_term = (1 - 2 * M / r)
        term_b_r_sq = (b / r) ** 2
        fr = 1 - metric_term * term_b_r_sq
        # abs() avoids a sqrt domain error on the step that overshoots a turning
        # point; event_fr_zero detects the sign change and the loop flips `sign`.
        drdK = sign[0] * np.sqrt(abs(fr))
        dphidK = b / r**2
        return [drdK, dphidK]

    # --- Standard Terminal Event Definitions ---
    def event_fr_zero(K, u):  # Index 0: radicand of dr/dK crosses zero
        r, phi = u
        if r <= (2 * M + 1e-12): return 0.0
        metric_term = (1 - 2 * M / r)
        term_b_r_sq = (b / r) ** 2
        return 1 - metric_term * term_b_r_sq
    event_fr_zero.terminal = True
    event_fr_zero.direction = 0

    def event_r_leq_2M(K, u):  # Index 1: ray falls to r = 2M
        r, phi = u
        return r - (2 * M + 1e-12)
    event_r_leq_2M.terminal = True
    event_r_leq_2M.direction = -1

    def event_r_max(K, u):  # Index 2: ray rises through r_max
        r, phi = u
        return r - r_max
    event_r_max.terminal = True
    event_r_max.direction = 1

    # --- Initialize Event List & Indices ---
    active_events = [event_fr_zero, event_r_leq_2M, event_r_max]
    event_x_stop_index = -1
    x_target_event_indices = []

    # --- Optional Terminal Event: x_stop ---
    # numbers.Real also accepts NumPy scalars such as np.int64, which an
    # isinstance(x_stop, (int, float)) test would silently ignore.
    event_x_stop_active = isinstance(x_stop, numbers.Real)
    if event_x_stop_active:
        def event_x_stop_func(K, u):
            r, phi = u
            if r < 1e-10: return 1.0
            x = r * np.cos(phi)
            return x - x_stop
        event_x_stop_func.terminal = True
        event_x_stop_func.direction = -1
        event_x_stop_index = len(active_events)
        active_events.append(event_x_stop_func)

    # --- Optional Non-Terminal Events: x_targets ---
    x_targets_list = []
    if x_targets is not None:
        try:
            x_targets_list = list(x_targets)
        except TypeError:
            print("Warning: x_targets provided but not iterable. Ignoring.")
    num_x_targets = len(x_targets_list)

    def make_event_func(target_x):
        # Factory binds target_x per event (avoids late-binding closures).
        def event_x_target_dynamic(K, u):
            r, phi = u
            if r < 1e-10: return 1.0
            x = r * np.cos(phi)
            return x - target_x
        event_x_target_dynamic.terminal = False
        event_x_target_dynamic.direction = -1
        return event_x_target_dynamic

    base_idx = len(active_events)
    for i, xt in enumerate(x_targets_list):
        active_events.append(make_event_func(xt))
        x_target_event_indices.append(base_idx + i)

    terminal_indices = [0, 1, 2]
    if event_x_stop_active:
        terminal_indices.append(event_x_stop_index)

    # --- Initialization for Integration Loop ---
    segments_data = []  # dicts {'t_start', 't_end', 'interpolator'} per segment
    crossings_y_at_x_targets = [[] for _ in range(num_x_targets)]
    current_t, current_y = 0.0, [r_0, phi_0]
    final_K = 0.0  # largest K reached

    # --- Integration Loop: one solve_ivp call per segment between turning points ---
    while current_t < t_end:
        sol = solve_ivp(
            ODE_Solver,
            (current_t, t_end),
            current_y,
            events=active_events,
            dense_output=True,  # needed for event states and final resampling
            rtol=1e-10, atol=1e-10,
        )

        if sol.t.size > 1:
            segments_data.append({
                't_start': sol.t[0],
                't_end': sol.t[-1],
                'interpolator': sol.sol,
            })
            final_K = sol.t[-1]
        elif sol.t.size == 1:
            final_K = sol.t[0]

        # Non-terminal x_target crossings, evaluated on the dense solution.
        # (OdeSolution extrapolates rather than raising, so no try/except needed.)
        if num_x_targets > 0 and sol.sol is not None:
            for i, event_idx in enumerate(x_target_event_indices):
                for t_event in sol.t_events[event_idx]:
                    r_event, phi_event = sol.sol(t_event)
                    if r_event < 2 * M or r_event < 1e-10:
                        continue
                    crossings_y_at_x_targets[i].append(r_event * np.sin(phi_event))

        # Earliest terminal event; on near-ties prefer the turning point (0),
        # then the r = 2M event (1).
        min_terminal_event_time = t_end
        triggering_event_index = -1
        for term_idx in terminal_indices:
            if sol.t_events[term_idx].size == 0:
                continue
            first_event_time = sol.t_events[term_idx][0]
            if first_event_time > min_terminal_event_time:
                continue
            if abs(first_event_time - min_terminal_event_time) < 1e-12:
                if term_idx == 0 or (term_idx == 1 and triggering_event_index != 0):
                    min_terminal_event_time = first_event_time
                    triggering_event_index = term_idx
            else:
                min_terminal_event_time = first_event_time
                triggering_event_index = term_idx

        if triggering_event_index != -1:
            event_time = min_terminal_event_time
            event_state = np.array(sol.sol(event_time), dtype=float)

            if triggering_event_index == 0:
                # Turning point: reverse the radial direction and restart just past it.
                sign[0] *= -1
                current_t = event_time + 1e-7
                current_y = event_state
                # Nudge r in the new direction so the event does not re-trigger at once.
                current_y[0] = current_y[0] + 1e-7 * sign[0]
                if current_y[0] <= 2 * M or current_y[0] >= r_max: break
                if event_x_stop_active and (current_y[0] * np.cos(current_y[1]) <= x_stop): break
                continue

            # Any other terminal event ends the integration.
            final_K = event_time
            break

        # No terminal event: inspect the solver status.
        if sol.status == 1:
            final_K = t_end
            break
        elif sol.status < 0:
            print(f"Warning: Solver failed with status {sol.status} at K={current_t}, state={current_y}")
            final_K = sol.t[-1] if sol.t.size > 0 else current_t
            break
        elif sol.status == 0 and sol.t[-1] >= t_end:
            final_K = sol.t[-1]
            break

    # --- Post-Integration Resampling ---
    if not segments_data:
        # Nothing to interpolate: return the initial condition as a single point.
        K_out = np.array([0.0])
        r_out = np.array([r_0])
        phi_out = np.array([phi_0 % (2 * np.pi)])
        x_out = r_out * np.cos(phi_out)
        y_out = r_out * np.sin(phi_out)
        return x_out, y_out, r_out, phi_out, K_out, crossings_y_at_x_targets

    K_interp = np.linspace(0, final_K, num_interp_points)

    # Assign each K to the last segment starting at or before it. K values inside
    # the 1e-7 gap left by a turning-point restart go to the preceding segment
    # (OdeSolution extrapolates over that distance) instead of becoming NaN.
    seg_starts = np.array([seg['t_start'] for seg in segments_data])
    owner = np.searchsorted(seg_starts, K_interp, side='right') - 1
    owner = np.clip(owner, 0, len(segments_data) - 1)

    r_interp = np.full(K_interp.shape, np.nan)
    phi_interp = np.full(K_interp.shape, np.nan)
    for seg_idx, segment in enumerate(segments_data):
        mask = owner == seg_idx
        if np.any(mask):
            state = segment['interpolator'](K_interp[mask])  # shape (2, mask.sum())
            r_interp[mask] = state[0]
            phi_interp[mask] = state[1]

    # Final post-processing
    phi_interp = phi_interp % (2 * np.pi)
    x_interp = r_interp * np.cos(phi_interp)
    y_interp = r_interp * np.sin(phi_interp)

    # Drop NaN samples (only possible if the solver itself produced NaNs).
    nan_mask = np.isnan(r_interp) | np.isnan(phi_interp)
    if np.any(nan_mask):
        print(f"Warning: Removing {np.sum(nan_mask)} NaN values from interpolated output.")
        keep = ~nan_mask
        return (x_interp[keep], y_interp[keep], r_interp[keep], phi_interp[keep],
                K_interp[keep], crossings_y_at_x_targets)

    return x_interp, y_interp, r_interp, phi_interp, K_interp, crossings_y_at_x_targets

# In [None]:
# A slimmed-down version of the one above; the only difference is that this trajectory function returns only the crossings
import numpy as np
from scipy.integrate import solve_ivp
# import time # Optional: for timing checks if needed

# Optimized version based on compute_trajectory_1 logic,
# focused solely on returning accurate y-crossings at x_targets.
def compute_trajectory_crossings_only(r_0, phi_0, M, psi, r_max, x_stop=None, x_targets=None, t_end=10**6, rtol=1e-8, atol=1e-10):
    """
    Calculate where a ray crosses each x_target, using the dense solution for accuracy.

    Uses the same ray equations, events and turning-point handling as
    compute_trajectory, but stores no trajectory samples. It is therefore
    cheaper when only the crossings are needed.

    Args:
        r_0, phi_0: Initial polar coordinates.
        M: Mass parameter.
        psi: Launch angle. Sets b = cos(psi) * r_0 / sqrt(1 - 2M/r_0) and the
             initial radial direction (inward when psi mod 2*pi > pi).
        r_max: Radius at which an outgoing ray is stopped.
        x_stop: Optional real number (Python or NumPy scalar). Integration stops
                when x = r*cos(phi) decreases through this value.
        x_targets: Optional iterable of x-values. A y-value is recorded each time
                   the ray crosses a target with x decreasing.
        t_end: Maximum integration parameter (lambda/K) value.
        rtol, atol: Tolerances passed to solve_ivp (defaults 1e-8 and 1e-10).

    Returns:
        list: One list per x_target holding the y-values of its crossings, e.g.
              [[y_cross1_xtarget1, y_cross2_xtarget1], [y_cross1_xtarget2], ...].
              The result is [] when x_targets is None, empty or not iterable. It
              is a list of empty lists when r_0 <= 2M.
    """
    import numbers  # local import keeps this notebook cell self-contained

    # --- Basic Setup ---
    x_targets_list = []
    if x_targets is not None:
        try:
            x_targets_list = list(x_targets)
        except TypeError:
            print("Warning: x_targets provided but not iterable. Will return empty lists.")
    num_x_targets = len(x_targets_list)

    # The launch geometry (b) needs r_0 outside r = 2M.
    if r_0 <= 2 * M: return [[] for _ in range(num_x_targets)]
    metric_term_r0 = (1 - 2 * M / r_0)
    if metric_term_r0 <= 0: return [[] for _ in range(num_x_targets)]
    b = np.cos(psi) * r_0 / np.sqrt(metric_term_r0)
    # One-element list so the nested ODE closure sees sign flips made in the loop.
    sign = [1]
    if psi % (2 * np.pi) > np.pi: sign[0] *= -1

    # --- ODE Definition ---
    def ODE_Solver(K, u):
        r, phi = u
        if r <= 2 * M: return [0.0, 0.0]
        metric_term = (1 - 2 * M / r)
        term_b_r_sq = (b / r) ** 2
        fr = 1 - metric_term * term_b_r_sq
        drdK = sign[0] * np.sqrt(abs(fr))  # abs() tolerates overshoot at a turning point
        if r < 1e-15: dphidK = 0.0  # Avoid division by zero
        else: dphidK = b / r**2
        return [drdK, dphidK]

    # --- Standard Terminal Event Definitions ---
    def event_fr_zero(K, u):  # Index 0: radicand of dr/dK crosses zero
        r, phi = u
        if r <= (2 * M + 1e-12): return 0.0
        metric_term = (1 - 2 * M / r)
        term_b_r_sq = (b / r) ** 2
        return 1 - metric_term * term_b_r_sq
    event_fr_zero.terminal = True
    event_fr_zero.direction = 0

    def event_r_leq_2M(K, u):  # Index 1: ray falls to r = 2M
        r, phi = u
        return r - (2 * M + 1e-12)
    event_r_leq_2M.terminal = True
    event_r_leq_2M.direction = -1

    def event_r_max(K, u):  # Index 2: ray rises through r_max
        r, phi = u
        return r - r_max
    event_r_max.terminal = True
    event_r_max.direction = 1

    # --- Initialize Event List & Indices ---
    active_events = [event_fr_zero, event_r_leq_2M, event_r_max]
    event_x_stop_index = -1
    x_target_event_indices = []
    # numbers.Real also accepts NumPy scalars such as np.int64.
    event_x_stop_active = isinstance(x_stop, numbers.Real)
    if event_x_stop_active:
        def event_x_stop_func(K, u):
            r, phi = u
            if r < 1e-10: return 1.0
            return r * np.cos(phi) - x_stop
        event_x_stop_func.terminal = True
        event_x_stop_func.direction = -1
        event_x_stop_index = len(active_events)
        active_events.append(event_x_stop_func)

    def make_event_func(target_x):
        # Factory binds target_x per event (avoids late-binding closures).
        def event_x_target_dynamic(K, u):
            r, phi = u
            if r < 1e-10: return 1.0
            return r * np.cos(phi) - target_x
        event_x_target_dynamic.terminal = False
        event_x_target_dynamic.direction = -1
        return event_x_target_dynamic

    base_idx = len(active_events)
    for i, xt in enumerate(x_targets_list):
        active_events.append(make_event_func(xt))
        x_target_event_indices.append(base_idx + i)

    terminal_indices = [0, 1, 2]
    if event_x_stop_active: terminal_indices.append(event_x_stop_index)

    # --- Initialization for Integration Loop ---
    crossings_y_at_x_targets = [[] for _ in range(num_x_targets)]
    current_t, current_y = 0.0, [r_0, phi_0]

    # --- Integration Loop: one solve_ivp call per segment between turning points ---
    while current_t < t_end:
        sol = solve_ivp(
            ODE_Solver,
            (current_t, t_end),
            current_y,
            events=active_events,
            dense_output=True,  # needed to evaluate the state at event times
            rtol=rtol, atol=atol,
        )

        # Non-terminal x_target crossings, evaluated on the dense solution.
        if num_x_targets > 0 and sol.sol is not None:
            for i, event_idx in enumerate(x_target_event_indices):
                for t_event in sol.t_events[event_idx]:
                    if not (sol.t[0] <= t_event <= sol.t[-1]): continue
                    r_event, phi_event = sol.sol(t_event)
                    if r_event < 2 * M or r_event < 1e-10: continue
                    crossings_y_at_x_targets[i].append(r_event * np.sin(phi_event))

        # Earliest terminal event; on near-ties prefer the turning point (0),
        # then the r = 2M event (1).
        min_terminal_event_time = t_end
        triggering_event_index = -1
        for term_idx in terminal_indices:
            if sol.t_events[term_idx].size == 0: continue
            first_event_time = sol.t_events[term_idx][0]
            if first_event_time < min_terminal_event_time - 1e-12:
                min_terminal_event_time = first_event_time
                triggering_event_index = term_idx
            elif abs(first_event_time - min_terminal_event_time) < 1e-12:
                if term_idx == 0 and triggering_event_index != 0:
                    min_terminal_event_time = first_event_time
                    triggering_event_index = term_idx
                elif term_idx == 1 and triggering_event_index not in (0, 1):
                    min_terminal_event_time = first_event_time
                    triggering_event_index = term_idx

        if triggering_event_index != -1 and min_terminal_event_time <= sol.t[-1] + 1e-12:
            event_time = min_terminal_event_time
            t_interp = np.clip(event_time, sol.t[0], sol.t[-1])
            event_state = np.array(sol.sol(t_interp), dtype=float)

            if triggering_event_index == 0:
                # Turning point: reverse the radial direction and restart just past it.
                sign[0] *= -1
                current_t = event_time + 1e-7
                current_y = event_state
                # Nudge r in the new direction so the event does not re-trigger at once.
                current_y[0] = current_y[0] + 1e-7 * sign[0]
                if current_y[0] <= 2 * M or current_y[0] >= r_max: break
                if event_x_stop_active and (current_y[0] * np.cos(current_y[1]) <= x_stop): break
                continue

            break  # any other terminal event ends the integration

        # No terminal event in this segment.
        if sol.status < 0:
            print(f"Warning: Solver failed with status {sol.status} at K={current_t}, state={current_y}")
            break
        if sol.t[-1] >= t_end - 1e-12:
            break  # reached the overall t_end
        current_t = sol.t[-1]
        current_y = sol.y[:, -1].copy()

    return crossings_y_at_x_targets

# In [None]:
#Image_Map Cartesian
#Takes in the starting position (y,z) and returns where the photon hits the window and hits the target image in xy_{phi_{prime}}
def Image_map(Y,Z,x_0,x_1,x_2,a):
    """Trace one ray from x = x_2 aimed at offset (Y, Z) and report its hits.

    The launch angle is psi = -arccos(rho / sqrt((x_2 - x_1)**2 + rho**2)),
    with rho = sqrt(Y**2 + Z**2). The ray is traced with M = 1 from r = x_2,
    phi = 0. Integration is stopped at x = x_0 - 1 or at
    r = sqrt(2*a'**2 + x_0**2), where a' = a * (x_2 - x_0) / (x_2 - x_1).

    Returns:
        (y_window, y_image): y-values of the first crossings of x = x_1 and
        x = x_0 respectively, or the string "Miss" if either crossing is absent.
    """
    ratio = (x_2 - x_0) / (x_2 - x_1)
    a_prime = ratio * a
    stop_radius = np.sqrt(2 * a_prime**2 + x_0**2)

    rho = np.sqrt(Y**2 + Z**2)
    hypotenuse = np.sqrt((x_2 - x_1)**2 + Y**2 + Z**2)
    launch_psi = -np.arccos(rho / hypotenuse)

    crossings = compute_trajectory_crossings_only(
        r_0=x_2, phi_0=0, M=1, psi=launch_psi, r_max=stop_radius,
        x_targets=[x_0, x_1], x_stop=x_0 - 1,
    )
    image_hits = crossings[0]   # crossings of x = x_0
    window_hits = crossings[1]  # crossings of x = x_1

    if not (image_hits and window_hits):
        return "Miss"
    return window_hits[0], image_hits[0]
    

# In [None]:
#Image_Map_Radial
# Takes the radial offset r of the starting position and returns where the photon hits the window and hits the target image in xy_{phi_{prime}}
# coordinates described in the lab notebook
def Image_map_radial(r,x_0,x_1,x_2,a):
    """Trace one ray from x = x_2 aimed at radial offset r and report its hits.

    The launch angle is psi = -arccos(sqrt(r**2) / sqrt((x_2 - x_1)**2 + r**2)).
    The ray is traced with M = 1 from r = x_2, phi = 0. Integration is stopped
    at x = x_0 - 1 or at radius sqrt(2*a'**2 + x_0**2), where
    a' = a * (x_2 - x_0) / (x_2 - x_1).

    Returns:
        (y_window, y_image): y-values of the first crossings of x = x_1 and
        x = x_0 respectively, or the string "Miss" if either crossing is absent.
    """
    ratio = (x_2 - x_0) / (x_2 - x_1)
    a_prime = ratio * a
    stop_radius = np.sqrt(2 * a_prime**2 + x_0**2)

    offset = np.sqrt(r**2)
    launch_psi = -np.arccos(offset / np.sqrt((x_2 - x_1)**2 + r**2))

    crossings = compute_trajectory_crossings_only(
        r_0=x_2, phi_0=0, M=1, psi=launch_psi, r_max=stop_radius,
        x_targets=[x_0, x_1], x_stop=x_0 - 1,
    )
    image_hits = crossings[0]   # crossings of x = x_0
    window_hits = crossings[1]  # crossings of x = x_1

    if not (image_hits and window_hits):
        return "Miss"
    return window_hits[0], image_hits[0]

# In [None]:
#Results_cartesian
import numpy as np
import time
import os # Added for path operations

# NOTE: This function assumes that a function named 'Image_map'
#       is defined or imported in the scope where Results_Cartesian is called.

def Results_Cartesian(x_0,x_1,x_2,a,n, chunk_size=1e7): # Added chunk_size parameter
    """Ray-trace a triangular grid of (y, z) offsets and save the hits to .npy chunks.

    Only the octant 0 <= z <= y <= a (step a/n) is traced with Image_map. Each
    hit (y_window, y_image) is replicated to its symmetric images at the angles
    phi' + k*pi/2 and k*pi/2 - phi' (k = 0..3, phi' = arctan2(z, y)). Duplicates
    are avoided: axis and diagonal points get 4 copies, and the origin gets 1.
    Each stored row is
    (y_window*cos(phi), y_window*sin(phi), y_image*cos(phi), y_image*sin(phi)).

    A chunk file is written whenever at least chunk_size rows are buffered; the
    remainder is written at the end. On an exception or KeyboardInterrupt the
    scan stops, but the buffered rows are still saved.

    Returns:
        list[str]: Paths of the chunk files that were written.
    """
    start_time = time.time()

    results_chunk = []
    saved_files = []
    saved_counts = []  # rows per saved file, so totals need no reload from disk
    chunk_index = 0

    print("Starting Results_Light_Ring generation...")

    # --- Define base filename ---
    x0_str = str(x_0).replace('.','p').replace('-','neg')
    x1_str = str(x_1).replace('.','p').replace('-','neg')
    x2_str = str(x_2).replace('.','p').replace('-','neg')
    base_save_path = f'Results_Light_Ring_Cartesian_{x0_str}_{x1_str}_{x2_str}_{n}'
    output_directory = "."
    os.makedirs(output_directory, exist_ok=True)

    # --- Helper function to save chunk ---
    def save_chunk(current_chunk_list, current_chunk_index):
        if not current_chunk_list: return None
        chunk_save_path = os.path.join(output_directory, f"{base_save_path}_chunk_{current_chunk_index:03d}.npy")
        print(f"\nSaving chunk {current_chunk_index} ({len(current_chunk_list)} results) to {chunk_save_path}...")
        try:
            np.save(chunk_save_path, np.array(current_chunk_list, dtype=np.float64))
            print(f"Chunk {current_chunk_index} saved.")
            return chunk_save_path
        except Exception as e:
            print(f"\nERROR saving chunk {current_chunk_index} to {chunk_save_path}: {e}")
            return None

    # --- Main Loops ---
    try:
        for i in range(n+1):
            y = a*i/n

            for j in range(i+1):
                z = a*j/n

                if j % 10**4 == 0:
                    print(y,"y",z," z ")

                phi_prime = np.arctan2(z, y)

                output = Image_map(Y=y,Z=z,x_0=x_0,x_1=x_1,x_2=x_2,a=a)
                if output == 'Miss':
                    continue
                y_window, y_image = output

                if y == 0 and z == 0:
                    # The origin is fixed by every symmetry: store it once.
                    phi_values = [phi_prime]
                elif y == z or z == 0:
                    # Axis / diagonal: the reflections coincide with the rotations.
                    phi_values = [phi_prime, phi_prime + np.pi/2, phi_prime + np.pi, phi_prime + 3/2 * np.pi]
                else:
                    phi_values = [
                        phi_prime, phi_prime + np.pi/2, phi_prime + np.pi, phi_prime + 3/2 * np.pi,
                        np.pi/2 - phi_prime, np.pi - phi_prime, 3*np.pi/2 - phi_prime, 2*np.pi - phi_prime]

                for phi in phi_values:
                    cos_phi, sin_phi = np.cos(phi), np.sin(phi)
                    results_chunk.append((
                        y_window * cos_phi, y_window * sin_phi,
                        y_image * cos_phi, y_image * sin_phi))

                # --- Check chunk size and save if needed ---
                if len(results_chunk) >= chunk_size:
                    saved_path = save_chunk(results_chunk, chunk_index)
                    if not saved_path:
                        raise IOError(f"Failed to save chunk {chunk_index}")
                    saved_files.append(saved_path)
                    saved_counts.append(len(results_chunk))
                    results_chunk = []
                    chunk_index += 1

    except KeyboardInterrupt:
        print("\n--- Process Interrupted ---")
    except Exception as e:
        print(f"\n--- Error during processing: {e} ---")
        import traceback
        traceback.print_exc()
    finally:
        # --- Save the final (potentially incomplete) chunk ---
        print("\nAttempting to save final chunk...")
        final_saved_path = save_chunk(results_chunk, chunk_index)
        if final_saved_path:
            saved_files.append(final_saved_path)
            saved_counts.append(len(results_chunk))

        duration = time.time() - start_time
        total_results_saved = sum(saved_counts)

        print(f"\nFinished generation attempt.")
        print(f"Total results saved: {total_results_saved}")
        print(f"Number of chunk files created: {len(saved_files)}")
        print(f"Total execution time: {duration:.2f} seconds")

    # Returned outside `finally` so unexpected BaseExceptions are not swallowed.
    return saved_files

# In [None]:
#Results_Radial
import numpy as np
import time
import os # Added for path operations

# NOTE: Results_Radial looks up a global function named 'Image_map_radial';
#       it must be defined or imported in this notebook's namespace before Results_Radial is called.

def Results_Radial(x_0,x_1,x_2,R_max,n, chunk_size=1e7): # Added chunk_size parameter
    """Sample ``Image_map_radial`` on concentric rings and save the photon pairs in .npy chunks.

    For each radius ``r = R_max * i / n`` (i = 0..n), ``Image_map_radial(r, x_0, x_1, x_2, a)``
    with ``a = x_2 - x_1`` is evaluated once. Unless it returns the string ``"Miss"``, its
    result is unpacked as ``(y_window, y_image)`` and rotated through the ``k + 1`` angles
    ``phi_j = j / k * 2*pi`` (j = 0..k, so both 0 and 2*pi are present), where
    ``k = max(int(5e4 * r), 1)``. Each angle contributes one row
    ``(y_window*cos(phi), y_window*sin(phi), y_image*cos(phi), y_image*sin(phi))``.

    Rows are written to ``./Results_rainbow_thesis_<x_0>_<x_1>_<x_2>_<n>_chunk_NNN.npy``
    (with ``.`` -> ``p`` and ``-`` -> ``neg`` in the numbers). Every file except possibly the
    last holds exactly ``ceil(chunk_size)`` rows.

    A KeyboardInterrupt or any Exception during generation is reported (not raised). The
    rows buffered so far are saved and the files written up to that point are returned.

    Args:
        x_0, x_1, x_2: Parameters forwarded unchanged to ``Image_map_radial``
            (``x_2 - x_1`` is also passed as its last argument).
        R_max: Largest ring radius.
        n: Number of radial steps (n + 1 radii from 0 to R_max); must be >= 1.
        chunk_size: Rows per output file.

    Returns:
        list[str]: Paths of the chunk files written, in order.
    """
    start_time = time.time()

    a = 1 * (x_2 - x_1)
    # A file is written as soon as the buffer holds ceil(chunk_size) rows.
    rows_per_chunk = max(1, int(np.ceil(chunk_size)))
    buffered = []           # (m, 4) row blocks not yet written to disk
    buffered_rows = 0       # total number of rows held in `buffered`
    saved_files = []
    chunk_index = 0
    total_results_saved = 0

    print("Starting Results_Radial generation...")

    # --- Define base filename ---
    x0_str = str(x_0).replace('.','p').replace('-','neg')
    x1_str = str(x_1).replace('.','p').replace('-','neg')
    x2_str = str(x_2).replace('.','p').replace('-','neg')
    base_save_path = f'Results_rainbow_thesis_{x0_str}_{x1_str}_{x2_str}_{n}'
    output_directory = "."
    os.makedirs(output_directory, exist_ok=True)

    def save_chunk(blocks, current_chunk_index):
        """Write the concatenated row blocks to one .npy file; return its path, or None if empty or on error."""
        if not blocks:
            return None
        data = np.concatenate(blocks).astype(np.float64, copy=False)
        chunk_save_path = os.path.join(output_directory, f"{base_save_path}_chunk_{current_chunk_index:03d}.npy")
        print(f"\nSaving chunk {current_chunk_index} ({len(data)} results) to {chunk_save_path}...")
        try:
            np.save(chunk_save_path, data)
            print(f"Chunk {current_chunk_index} saved.")
            return chunk_save_path
        except Exception as e:
            print(f"\nERROR saving chunk {current_chunk_index} to {chunk_save_path}: {e}")
            return None

    def flush_full_chunk():
        """Save the buffer as the next chunk and reset it; raise IOError if saving fails."""
        nonlocal buffered, buffered_rows, chunk_index, total_results_saved
        saved_path = save_chunk(buffered, chunk_index)
        if saved_path is None:
            # Buffer is kept so the final save below can retry it.
            raise IOError(f"Failed to save chunk {chunk_index}")
        saved_files.append(saved_path)
        total_results_saved += buffered_rows
        buffered, buffered_rows = [], 0
        chunk_index += 1

    def add_rows(rows):
        """Buffer `rows`, saving a chunk every time the buffer reaches rows_per_chunk rows."""
        nonlocal buffered_rows
        start = 0
        while start < len(rows):
            take = min(rows_per_chunk - buffered_rows, len(rows) - start)
            buffered.append(rows[start:start + take])
            buffered_rows += take
            start += take
            if buffered_rows >= rows_per_chunk:
                flush_full_chunk()

    # --- Main Loop ---
    try:
        for i in range(n + 1):
            r = R_max * i / n
            k = max(int(5 * 10**4 * r), int(1))

            # Image_map_radial is evaluated once per ring; only the angle varies within a ring.
            output = Image_map_radial(r, x_0, x_1, x_2, a)
            # Only a string can be the "Miss" sentinel; `output != "Miss"` on an array-valued
            # result would be element-wise and ambiguous in an `if`.
            is_miss = isinstance(output, str) and output == "Miss"

            # Progress report every 10**4 angular samples (printed for misses too).
            for j in range(0, k + 1, 10**4):
                print(r, "r", j / k * 2 * np.pi, " phi_prime  ")

            if is_miss:
                continue

            y_window, y_image = output
            # Same arithmetic as j/k*2*np.pi for every j = 0..k, evaluated in one pass.
            phi = np.arange(k + 1) / k * 2 * np.pi
            cos_phi = np.cos(phi)
            sin_phi = np.sin(phi)
            add_rows(np.column_stack((y_window * cos_phi, y_window * sin_phi,
                                      y_image * cos_phi, y_image * sin_phi)))

    except KeyboardInterrupt:
        print("\n--- Process Interrupted ---")
    except Exception as e:
        print(f"\n--- Error during processing: {e} ---")
        import traceback
        traceback.print_exc()
    finally:
        # --- Save the final (potentially incomplete) chunk, even when stopping early ---
        print("\nAttempting to save final chunk...")
        final_saved_path = save_chunk(buffered, chunk_index)
        if final_saved_path:
            saved_files.append(final_saved_path)
            total_results_saved += buffered_rows

    duration = time.time() - start_time

    print(f"\nFinished generation attempt.")
    print(f"Total results saved: {total_results_saved}")
    print(f"Number of chunk files created: {len(saved_files)}")
    print(f"Total execution time: {duration:.2f} seconds")

    return saved_files

In [None]:
#Results_Light_Ring
import numpy as np
import time
import os # Added for path operations

# NOTE: Results_Radial_Light_Ring looks up a global function named 'Image_map_radial';
#       it must be defined or imported in this notebook's namespace before Results_Radial_Light_Ring is called.

def Results_Radial_Light_Ring(x_0,x_1,x_2,n, chunk_size=1e7, r_start=0.2543706590950274, r_span=0.0008610946154524735): # Added chunk_size parameter
    """Densely sample a thin radial window and a 0.1-degree angular wedge, saving photon pairs in .npy chunks.

    For each radius ``r = r_start + i / n * r_span`` (i = 0..n), ``Image_map_radial(r, x_0, x_1, x_2, a)``
    with ``a = x_2 - x_1`` is evaluated once. Unless it returns the string ``"Miss"``, its
    result is unpacked as ``(y_window, y_image)`` and rotated through the ``k + 1`` angles
    ``phi_j = 3*pi/2 + j / k * pi/180/10`` (j = 0..k, i.e. a 0.1-degree wedge starting at 3*pi/2),
    where ``k = max(int(6e4 * r), 1)``. Each angle contributes one row
    ``(y_window*cos(phi), y_window*sin(phi), y_image*cos(phi), y_image*sin(phi))``.

    Rows are written to ``./Light_Ring_Segement_0th_large_<x_0>_<x_1>_<x_2>_<n>_chunk_NNN.npy``
    (with ``.`` -> ``p`` and ``-`` -> ``neg`` in the numbers). Every file except possibly the
    last holds exactly ``ceil(chunk_size)`` rows.

    A KeyboardInterrupt or any Exception during generation is reported (not raised). The
    rows buffered so far are saved and the results gathered up to that point are returned.

    Args:
        x_0, x_1, x_2: Parameters forwarded unchanged to ``Image_map_radial``
            (``x_2 - x_1`` is also passed as its last argument).
        n: Number of radial steps (n + 1 radii); must be >= 1.
        chunk_size: Rows per output file.
        r_start: First radius of the window (default: the original hard-coded value).
        r_span: Width of the radial window (default: the original hard-coded value).

    Returns:
        tuple[list[str], list[float], list[float]]: ``(saved_files, misses, hits)``, where
        ``saved_files`` are the chunk paths in order. ``misses``/``hits`` are the radii for
        which ``Image_map_radial`` returned ``"Miss"`` / a usable result.
    """
    start_time = time.time()

    a = 1 * (x_2 - x_1)
    # A file is written as soon as the buffer holds ceil(chunk_size) rows.
    rows_per_chunk = max(1, int(np.ceil(chunk_size)))
    buffered = []           # (m, 4) row blocks not yet written to disk
    buffered_rows = 0       # total number of rows held in `buffered`
    saved_files = []
    chunk_index = 0
    total_results_saved = 0
    misses = []
    hits = []

    print("Starting Results_Light_Ring generation...")

    # --- Define base filename ---
    x0_str = str(x_0).replace('.','p').replace('-','neg')
    x1_str = str(x_1).replace('.','p').replace('-','neg')
    x2_str = str(x_2).replace('.','p').replace('-','neg')
    base_save_path = f'Light_Ring_Segement_0th_large_{x0_str}_{x1_str}_{x2_str}_{n}'
    output_directory = "."
    os.makedirs(output_directory, exist_ok=True)

    def save_chunk(blocks, current_chunk_index):
        """Write the concatenated row blocks to one .npy file; return its path, or None if empty or on error."""
        if not blocks:
            return None
        data = np.concatenate(blocks).astype(np.float64, copy=False)
        chunk_save_path = os.path.join(output_directory, f"{base_save_path}_chunk_{current_chunk_index:03d}.npy")
        print(f"\nSaving chunk {current_chunk_index} ({len(data)} results) to {chunk_save_path}...")
        try:
            np.save(chunk_save_path, data)
            print(f"Chunk {current_chunk_index} saved.")
            return chunk_save_path
        except Exception as e:
            print(f"\nERROR saving chunk {current_chunk_index} to {chunk_save_path}: {e}")
            return None

    def flush_full_chunk():
        """Save the buffer as the next chunk and reset it; raise IOError if saving fails."""
        nonlocal buffered, buffered_rows, chunk_index, total_results_saved
        saved_path = save_chunk(buffered, chunk_index)
        if saved_path is None:
            # Buffer is kept so the final save below can retry it.
            raise IOError(f"Failed to save chunk {chunk_index}")
        saved_files.append(saved_path)
        total_results_saved += buffered_rows
        buffered, buffered_rows = [], 0
        chunk_index += 1

    def add_rows(rows):
        """Buffer `rows`, saving a chunk every time the buffer reaches rows_per_chunk rows."""
        nonlocal buffered_rows
        start = 0
        while start < len(rows):
            take = min(rows_per_chunk - buffered_rows, len(rows) - start)
            buffered.append(rows[start:start + take])
            buffered_rows += take
            start += take
            if buffered_rows >= rows_per_chunk:
                flush_full_chunk()

    # --- Main Loop ---
    try:
        for i in range(n + 1):
            r = r_start + i / n * r_span
            # At least one angular step, so j / k below never divides by zero.
            k = max(int(6 * 10**4 * r), 1)

            # Image_map_radial is evaluated once per radius; only the angle varies.
            output = Image_map_radial(r, x_0, x_1, x_2, a)
            # Only a string can be the "Miss" sentinel; `output != "Miss"` on an array-valued
            # result would be element-wise and ambiguous in an `if`.
            is_miss = isinstance(output, str) and output == "Miss"

            # Progress report every 10**4 angular samples (printed for misses too).
            for j in range(0, k + 1, 10**4):
                print(r, "r", 3 / 2 * np.pi + j / k * np.pi / 180 / 10, " phi_prime  ")

            if is_miss:
                misses.append(r)
                continue

            y_window, y_image = output
            hits.append(r)
            # Same arithmetic as 3/2*np.pi + j/k*np.pi/180/10 for every j = 0..k, in one pass.
            phi = 3 / 2 * np.pi + np.arange(k + 1) / k * np.pi / 180 / 10
            cos_phi = np.cos(phi)
            sin_phi = np.sin(phi)
            add_rows(np.column_stack((y_window * cos_phi, y_window * sin_phi,
                                      y_image * cos_phi, y_image * sin_phi)))

    except KeyboardInterrupt:
        print("\n--- Process Interrupted ---")
    except Exception as e:
        print(f"\n--- Error during processing: {e} ---")
        import traceback
        traceback.print_exc()
    finally:
        # --- Save the final (potentially incomplete) chunk, even when stopping early ---
        print("\nAttempting to save final chunk...")
        final_saved_path = save_chunk(buffered, chunk_index)
        if final_saved_path:
            saved_files.append(final_saved_path)
            total_results_saved += buffered_rows

    duration = time.time() - start_time

    print(f"\nFinished generation attempt.")
    print(f"Total results saved: {total_results_saved}")
    print(f"Number of chunk files created: {len(saved_files)}")
    print(f"Total execution time: {duration:.2f} seconds")

    return saved_files,misses,hits

In [None]:
#Map_photons (Takes in multiple files, can zoom in on a region [a,b]x[c,d] in the window, will do normal region if [a,b]x[c,d] isn't specified)
import numpy as np
import time
import os
import sys
from PIL import Image, UnidentifiedImageError
import traceback
import math # Import math for ceiling function

# Helper function _save_image: converts an array to uint8 RGB and writes it to disk with PIL.
def _save_image(image_array, save_path):
    """Save a NumPy array as an RGB image file using PIL.

    Before saving, the array is converted to uint8:
      * floating dtypes (any width, incl. float16): NaN -> 0, +inf -> 255, -inf -> 0,
        then clipped to [0, 255];
      * integer dtypes: clipped to [0, 255] (with a printed warning) when out of range;
      * any other dtype (e.g. bool): cast with ``astype(np.uint8)``.
    The converted array must have shape (H, W, 3) with H > 0 and W > 0.

    Args:
        image_array (np.ndarray): Image data.
        save_path (str): Destination file. Missing parent directories are created, and
            PIL chooses the file format from the extension.

    Returns:
        bool: True if the image was written, False otherwise. Failures are printed,
        never raised.
    """
    try:
        parent_dir = os.path.dirname(save_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

        if image_array.dtype != np.uint8:
            if np.issubdtype(image_array.dtype, np.floating):
                # Match the whole floating family: float16/longdouble used to fall through to an
                # unclipped astype(np.uint8), which wraps out-of-range values and mangles NaN.
                image_array = np.nan_to_num(image_array, nan=0.0, posinf=255.0, neginf=0.0)
                image_array = np.clip(image_array, 0, 255).astype(np.uint8)
            else:
                try:
                    if np.issubdtype(image_array.dtype, np.integer):
                        min_val, max_val = np.min(image_array), np.max(image_array)
                        if min_val < 0 or max_val > 255:
                            print(f"Warning: Integer data out of uint8 range [{min_val}, {max_val}]. Clipping.")
                            image_array = np.clip(image_array, 0, 255)
                    image_array = image_array.astype(np.uint8)
                except (ValueError, OverflowError, TypeError) as e:
                    print(f"Warning: Could not convert final image dtype {image_array.dtype} to uint8. Error: {e}")
                    return False

        if image_array.ndim != 3 or image_array.shape[0] == 0 or image_array.shape[1] == 0 or image_array.shape[2] != 3:
            print(f"Error: Final image is not valid RGB (shape: {image_array.shape}). Cannot save.")
            return False

        # image_array is uint8 here, so it cannot contain NaN/Inf. PIL infers mode 'RGB' from a
        # uint8 (H, W, 3) array; the explicit `mode=` argument of fromarray is deprecated.
        pil_image = Image.fromarray(image_array)
        pil_image.save(save_path)
        print(f"Successfully saved image to {save_path}")
        return True
    except Exception as e:
        print(f"Error saving image to {save_path}: {e}")
        traceback.print_exc(limit=2)
        return False


def map_photons(
    image_source,
    photon_chunk_files,
    save_path,
    dest_logical_bounds=None,      # Format: [y_min(horiz), y_max(horiz), z_min(vert), z_max(vert)]
    pixels_per_logical_unit=None, # Resolution scale parameter
    output_shape=None,           # Optional, only used if dest_logical_bounds is None
    default_color=(0, 0, 0),
    epsilon=1e-9,
    flip_z_axis_render=True # New parameter to control Z-axis rendering direction
):
    """
    Maps photons. Uses max extent for source bounds.
    VERSION 3: Adds flip_z_axis_render to control Z-axis mapping for "bottom-up" view.

    Two main modes:
    1. Manual Bounds + Scale: Provide 'dest_logical_bounds' and 'pixels_per_logical_unit'.
       The output width and height are calculated automatically based on the logical range
       and the specified scale, preserving the logical aspect ratio (no stretch/compression).
       'output_shape' must be None.
    2. Auto Bounds + Manual Shape: Provide 'output_shape' (height, width).
       'dest_logical_bounds' and 'pixels_per_logical_unit' must be None. Bounds are calculated
       automatically from max photon extent and output_shape aspect ratio.

    **Axis Convention Change:** Destination coordinate y maps to the HORIZONTAL axis (width),
    and z maps to the VERTICAL axis (height).
    **Photon Chunk Convention:** Assumes y0,y1 are HORIZONTAL; z0,z1 are VERTICAL in chunk files.

    Args:
        image_source (str | np.ndarray): Path to image or NumPy array (H,W[,C]).
        photon_chunk_files (list[str]): REQUIRED. List of file paths to .npy chunks.
        save_path (str): REQUIRED. Path where the output image will be saved.
        dest_logical_bounds (list|tuple, optional): Defines the destination mapping window
            in *logical* coordinates [y_min(horiz), y_max(horiz), z_min(vert), z_max(vert)].
            If set, 'pixels_per_logical_unit' MUST also be set. Defaults to None.
            The z_min(vert) and z_max(vert) define the logical range.
        pixels_per_logical_unit (float | int, optional): Resolution scale (pixels per 1 unit
            of logical distance). REQUIRED if 'dest_logical_bounds' is set. Ignored otherwise.
        output_shape (tuple, optional): Output (height, width) in pixels.
            REQUIRED if 'dest_logical_bounds' is None. Ignored otherwise.
        default_color (int | tuple, optional): Background RGB color. Defaults to (0, 0, 0).
        epsilon (float, optional): Small value for division-by-zero checks.
        flip_z_axis_render (bool, optional): If True, renders the Z-axis such that
            dest_map_bound_min_z maps to the bottom of the image (pixel row mapped_h-1)
            and dest_map_bound_max_z maps to the top (pixel row 0). This gives a
            "bottom-up" view if z_min is logically "lower". Defaults to True for your case.
            If False, behaves as original (min_z to top, max_z to bottom).

    Returns:
        bool: True if the image was processed and saved successfully, False otherwise.
    """
    start_time_total = time.time()
    mapped_image = None
    mapped_h, mapped_w = 1, 1 # Will be overwritten
    using_manual_dest_bounds = False # Flag to track mode

    # --- DEBUG COUNTERS ---
    total_photons_loaded_from_chunks = 0
    photons_lost_nan = 0
    photons_lost_source_bounds = 0
    photons_lost_dest_logical_filter = 0
    photons_lost_dest_pixel_map = 0
    photons_lost_invalid_source_idx = 0
    photons_accumulated = 0
    # --- END DEBUG COUNTERS ---

    # --- 0. Validate save_path & chunk list ---
    if not save_path or not isinstance(save_path, str): raise ValueError("A valid string save_path is required.")
    if not isinstance(photon_chunk_files, list) or not all(isinstance(f, str) for f in photon_chunk_files):
        raise TypeError("photon_chunk_files must be a list of file path strings.")

    # --- Determine Mode and Validate Inputs / Set Output Shape ---
    dest_map_bound_min_y, dest_map_bound_max_y = 0, 0 # y now maps to horizontal
    dest_map_bound_min_z, dest_map_bound_max_z = 0, 0 # z now maps to vertical

    if dest_logical_bounds is not None:
        # --- Mode 1: Manual Bounds + Scale ---
        using_manual_dest_bounds = True
        print("Using MANUAL destination bounds + SCALE mode.")
        if flip_z_axis_render:
            print("  Z-axis rendering: FLIPPED (min_z to image bottom, max_z to image top).")
        else:
            print("  Z-axis rendering: NORMAL (min_z to image top, max_z to image bottom).")


        # Check required inputs for this mode
        if pixels_per_logical_unit is None:
            raise ValueError("Must provide 'pixels_per_logical_unit' when 'dest_logical_bounds' is set.")
        # Check for conflicting inputs
        if output_shape is not None:
            print("    Warning: 'output_shape' parameter provided but ignored because 'dest_logical_bounds' is set.")

        # Validate dest_logical_bounds
        if not isinstance(dest_logical_bounds, (list, tuple)) or len(dest_logical_bounds) != 4:
            raise ValueError("dest_logical_bounds must be a list or tuple of four numbers: [y_min(horiz), y_max(horiz), z_min(vert), z_max(vert)]")
        try:
            a, b, c, d = [float(x) for x in dest_logical_bounds]
            if a > b or c > d: # c is min_z, d is max_z
                raise ValueError("Destination logical bounds must have min <= max for both y ([a,b], horizontal) and z ([c,d], vertical).")
            dest_map_bound_min_y = a
            dest_map_bound_max_y = b
            dest_map_bound_min_z = c # This is the logical minimum Z for the range
            dest_map_bound_max_z = d # This is the logical maximum Z for the range
            print(f"  Logical Bounds: y(horiz)=[{a:.3f}, {b:.3f}], z(vert)=[{c:.3f}, {d:.3f}]")
        except (ValueError, TypeError) as e:
             raise ValueError(f"Invalid dest_logical_bounds format or values: {e}")

        # Validate pixels_per_logical_unit
        try:
            scale = float(pixels_per_logical_unit)
            if scale <= 0:
                raise ValueError("pixels_per_logical_unit must be a positive number.")
            print(f"  Scale: {scale:.3f} pixels per logical unit")
        except (ValueError, TypeError):
            raise ValueError("pixels_per_logical_unit must be a positive number.")

        # Calculate output dimensions based on logical range and scale
        logical_width = dest_map_bound_max_y - dest_map_bound_min_y
        logical_height = dest_map_bound_max_z - dest_map_bound_min_z # This is the extent of the Z range

        # Use ceiling to ensure the pixel grid covers the entire logical range
        # Ensure minimum dimension is 1 pixel
        mapped_w = max(1, int(math.ceil(logical_width * scale)))
        mapped_h = max(1, int(math.ceil(logical_height * scale)))

        print(f"--> Calculated Output Shape (Preserving Aspect Ratio): (Height={mapped_h}, Width={mapped_w})")

    else:
        # --- Mode 2: Auto Bounds + Manual Shape ---
        print("Using AUTO destination bounds + SHAPE mode.")
        if flip_z_axis_render:
            print("  Z-axis rendering: FLIPPED (min_z to image bottom, max_z to image top).")
        else:
            print("  Z-axis rendering: NORMAL (min_z to image top, max_z to image bottom).")


        # Check required input for this mode
        if output_shape is None:
             raise ValueError("Must provide 'output_shape' when 'dest_logical_bounds' is not set.")
        # Check for conflicting inputs
        if pixels_per_logical_unit is not None:
             print("    Warning: 'pixels_per_logical_unit' parameter provided but ignored because 'dest_logical_bounds' is not set.")

        # Validate output_shape
        if not isinstance(output_shape, tuple) or len(output_shape) != 2:
            raise ValueError("output_shape must be a tuple of (height, width).")
        try:
            mapped_h, mapped_w = int(output_shape[0]), int(output_shape[1])
            if mapped_h <= 0 or mapped_w <= 0:
                raise ValueError("output_shape dimensions must be positive integers.")
        except (ValueError, TypeError):
            raise ValueError("output_shape dimensions must be positive integers.")
        print(f"  Using provided Output Shape: (Height={mapped_h}, Width={mapped_w})")


    # --- Handle Empty Chunk List (Uses calculated/provided mapped_h/w) ---
    if not photon_chunk_files:
        print("Warning: photon_chunk_files list is empty. Nothing to process.")
        try:
            if isinstance(default_color, (int,float,np.number)): default_color_rgb = (default_color,)*3
            elif isinstance(default_color, (tuple, list)) and len(default_color) == 3: default_color_rgb = tuple(default_color)
            else: default_color_rgb = (0, 0, 0); print("Warning: Invalid default_color. Using (0,0,0).")
            default_color_rgb = tuple(np.clip(int(c), 0, 255) for c in default_color_rgb) # Ensure default_color_rgb is defined
            mapped_image = np.full((mapped_h, mapped_w, 3), default_color_rgb, dtype=np.uint8)
            return _save_image(mapped_image, save_path)
        except Exception as e: print(f"Error setting up background image dimensions: {e}"); return False


    # --- Continue Processing ---
    try:
        # --- 1. Calculate Output Aspect Ratio ---
        output_pixel_aspect = mapped_w / mapped_h if mapped_h > 0 else 1.0

        # --- 2. Load Source Image ---
        if isinstance(image_source, str):
            try:
                with Image.open(image_source) as img:
                    img_mode = img.mode;
                    if img.mode != 'RGB': img = img.convert('RGB')
                    original_image = np.array(img)
            except FileNotFoundError: raise FileNotFoundError(f"Source image file not found: {image_source}")
            except UnidentifiedImageError: raise RuntimeError(f"Could not identify or open image file: {image_source}")
            except Exception as e: raise RuntimeError(f"Error loading image file {image_source}: {e}")
        elif isinstance(image_source, np.ndarray):
            original_image = image_source.copy()
        else:
            raise TypeError("image_source must be a file path (string) or a NumPy array.")

        # --- Force RGB, Get Source Dimensions, Aspect ---
        if original_image.ndim == 2: original_image_rgb = np.stack([original_image] * 3, axis=-1)
        elif original_image.ndim == 3 and original_image.shape[2] == 1: original_image_rgb = np.repeat(original_image, 3, axis=-1)
        elif original_image.ndim == 3 and original_image.shape[2] == 3: original_image_rgb = original_image
        elif original_image.ndim == 3 and original_image.shape[2] == 4: original_image_rgb = original_image[:, :, :3] # Drop alpha
        else: raise ValueError(f"Unsupported source image format (shape: {original_image.shape})")

        orig_h, orig_w, _ = original_image_rgb.shape
        if orig_h == 0 or orig_w == 0: raise ValueError("Source image dimensions cannot be zero.")
        source_pixel_aspect = orig_w / orig_h

        # --- Rescale Logic ---
        final_image_for_sampling = original_image_rgb
        try:
            source_max_val = np.max(final_image_for_sampling); source_min_val = np.min(final_image_for_sampling)
            needs_rescaling = False; rescale_threshold = 32 # Rescale if max value is low (e.g., < 32 for uint8)
            if final_image_for_sampling.dtype == np.uint8 and source_max_val > source_min_val and source_max_val < rescale_threshold: needs_rescaling = True

            if needs_rescaling:
                 print("    Rescaling low-contrast source image to full range [0, 255].")
                 if source_max_val > source_min_val:
                     scaled_image_float = ((final_image_for_sampling.astype(np.float64) - source_min_val) / (source_max_val - source_min_val) * 255.0)
                 else: # Handle flat image case
                     scaled_image_float = np.full_like(final_image_for_sampling, 128.0)
                 final_image_for_sampling = np.clip(scaled_image_float, 0, 255).astype(np.uint8)
            elif final_image_for_sampling.dtype != np.uint8: # Ensure uint8 otherwise
                 print(f"    Converting source image from {final_image_for_sampling.dtype} to uint8, clipping to [0, 255].")
                 final_image_for_sampling = np.clip(final_image_for_sampling, 0, 255).astype(np.uint8)
        except Exception as scale_e: print(f"  Warning: Could not perform source image range check/rescaling: {scale_e}.")

        original_image_rgb = final_image_for_sampling # Use potentially rescaled image

        print(f"Source image loaded: shape={original_image_rgb.shape}, aspect={source_pixel_aspect:.3f}")
        print(f"Output shape set to: {(mapped_h, mapped_w)}, aspect={output_pixel_aspect:.3f}")


        # --- 4. Prepare Default Color ---
        if isinstance(default_color, (int, float, np.number)): default_color_rgb = (default_color,) * 3
        elif isinstance(default_color, (tuple, list)) and len(default_color) == 3: default_color_rgb = tuple(default_color)
        else: default_color_rgb = (0, 0, 0); print("Warning: Invalid default_color. Using (0,0,0).")
        try:
            default_color_rgb = tuple(np.clip(int(c), 0, 255) for c in default_color_rgb)
        except (ValueError, TypeError):
            default_color_rgb = (0, 0, 0); print("Warning: Could not parse default_color. Using (0,0,0).")


        # --- 5. Initialize Accumulation Buffers (Uses determined mapped_h/w) ---
        print("Initializing accumulation buffers...")
        sum_array = np.zeros((mapped_h, mapped_w, 3), dtype=np.float64)
        count_array = np.zeros((mapped_h, mapped_w), dtype=np.int64)
        mapped_image = np.full((mapped_h, mapped_w, 3), default_color_rgb, dtype=np.uint8)

        # --- Determine Source & Auto Destination Bounds (Scan needed for source + auto dest) ---
        print("Scanning chunks to determine coordinate bounds...")
        global_max_abs_y0=epsilon; global_max_abs_z0=epsilon # Needed only for Auto mode
        global_max_abs_y1=epsilon; global_max_abs_z1=epsilon # Needed for Source bounds
        scan_start_time = time.time()
        photons_found_in_scan = False
        for chunk_idx, chunk_file in enumerate(photon_chunk_files):
             try:
                photon_list_chunk = np.load(chunk_file)
                if photon_list_chunk.ndim != 2 or photon_list_chunk.shape[1] != 4 or photon_list_chunk.shape[0] == 0: continue
                y0c, z0c, y1c, z1c = photon_list_chunk[:,0], photon_list_chunk[:,1], photon_list_chunk[:,2], photon_list_chunk[:,3]
                nan_mask_chunk = np.isnan(y0c)|np.isnan(z0c)|np.isnan(y1c)|np.isnan(z1c)
                valid_indices = ~nan_mask_chunk
                if not np.any(valid_indices): continue
                photons_found_in_scan = True
                if not using_manual_dest_bounds:
                    global_max_abs_y0=max(global_max_abs_y0, np.max(np.abs(y0c[valid_indices])))
                    global_max_abs_z0=max(global_max_abs_z0, np.max(np.abs(z0c[valid_indices])))
                global_max_abs_y1=max(global_max_abs_y1, np.max(np.abs(y1c[valid_indices])))
                global_max_abs_z1=max(global_max_abs_z1, np.max(np.abs(z1c[valid_indices])))
                del photon_list_chunk
             except FileNotFoundError: print(f"\nWarn: File not found during scan: {chunk_file}. Skip."); continue
             except Exception as e: print(f"\nWarn: Error scanning {chunk_file} for bounds: {e}")
        print(f"\nBounds scan complete. Time: {time.time() - scan_start_time:.2f}s")

        if not photons_found_in_scan:
             print("\nError: No valid (non-NaN) photons found in any chunk during scan. Cannot proceed.")
             return _save_image(mapped_image, save_path)


        # --- Calculate SOURCE Mapping Bounds ---
        source_bound_y = max(global_max_abs_y1, epsilon); source_bound_z = max(global_max_abs_z1, epsilon)
        if source_pixel_aspect >= 1.0:
             temp_source_map_bound_y = source_bound_y
             temp_source_map_bound_z = max(source_bound_z, source_bound_y / source_pixel_aspect)
        else:
             temp_source_map_bound_z = source_bound_z
             temp_source_map_bound_y = max(source_bound_y, source_bound_z * source_pixel_aspect)
        source_map_bound_y = temp_source_map_bound_y
        source_map_bound_z = temp_source_map_bound_z
        source_denom_y = max(2.0 * source_map_bound_y, epsilon); source_denom_z = max(2.0 * source_map_bound_z, epsilon)
        print(f"Source mapping bounds (Aspect Corrected): y1(horiz):[+/-{source_map_bound_y:.3f}], z1(vert):[+/-{source_map_bound_z:.3f}]")


        # --- Calculate DESTINATION Mapping Denominators ---
        dest_denom_y, dest_denom_z = epsilon, epsilon

        if using_manual_dest_bounds:
            # Mode 1: dest_map_bound_min_z and dest_map_bound_max_z already set from input
            dest_range_y = dest_map_bound_max_y - dest_map_bound_min_y
            dest_range_z = dest_map_bound_max_z - dest_map_bound_min_z # This is (logical_max_z - logical_min_z)
            dest_denom_y = max(dest_range_y, epsilon)
            dest_denom_z = max(dest_range_z, epsilon) # Denominator for Z mapping
        else: # Mode 2 (Auto Bounds)
            logical_dest_bound_y = max(global_max_abs_y0, epsilon)
            logical_dest_bound_z = max(global_max_abs_z0, epsilon)
            print(f"Auto-calculating destination bounds from max extent: y0(horiz)=+/-{logical_dest_bound_y:.3f}, z0(vert)=+/-{logical_dest_bound_z:.3f}")
            temp_dest_map_bound_y, temp_dest_map_bound_z = 0, 0
            if output_pixel_aspect >= 1.0:
                 temp_dest_map_bound_y = logical_dest_bound_y
                 temp_dest_map_bound_z = max(logical_dest_bound_z, logical_dest_bound_y / output_pixel_aspect)
            else:
                 temp_dest_map_bound_z = logical_dest_bound_z
                 temp_dest_map_bound_y = max(logical_dest_bound_y, logical_dest_bound_z * output_pixel_aspect)
            dest_map_bound_min_y = -temp_dest_map_bound_y
            dest_map_bound_max_y = temp_dest_map_bound_y
            dest_map_bound_min_z = -temp_dest_map_bound_z # This is logical_min_z for auto mode
            dest_map_bound_max_z = temp_dest_map_bound_z # This is logical_max_z for auto mode
            dest_denom_y = max(2.0 * temp_dest_map_bound_y, epsilon)
            dest_denom_z = max(2.0 * temp_dest_map_bound_z, epsilon) # Denominator for Z mapping
            print(f"Destination mapping bounds (Auto, Aspect Corrected): y(horiz)=[{dest_map_bound_min_y:.3f},{dest_map_bound_max_y:.3f}], z(vert)=[{dest_map_bound_min_z:.3f},{dest_map_bound_max_z:.3f}]")

        print(f"Final Mapping Denominators: src_y1(horiz)={source_denom_y:.3f}, src_z1(vert)={source_denom_z:.3f}, dst_y0(horiz)={dest_denom_y:.3f}, dst_z0(vert)={dest_denom_z:.3f}")


        # --- Loop Through Chunks for Processing ---
        print("\nProcessing photon data chunk by chunk...")
        start_proc_time = time.time()
        for chunk_idx, chunk_file in enumerate(photon_chunk_files):
            if (chunk_idx + 1) % 10 == 0 or chunk_idx == 0 or chunk_idx == len(photon_chunk_files) - 1:
                 print(f"\rProcessing chunk {chunk_idx + 1}/{len(photon_chunk_files)}: {os.path.basename(chunk_file)}...", end="")
            try:
                photon_list = np.load(chunk_file)
                count_loaded_chunk = 0
                if photon_list.ndim == 2 and photon_list.shape[1] == 4 and photon_list.shape[0] > 0:
                    count_loaded_chunk = photon_list.shape[0]
                    total_photons_loaded_from_chunks += count_loaded_chunk
                else:
                    print(f"\nWarn: Invalid shape in chunk {chunk_file}: {photon_list.shape}. Skipping.")
                    del photon_list; continue
                y0_f, z0_f, y1_f, z1_f = photon_list[:, 0], photon_list[:, 1], photon_list[:, 2], photon_list[:, 3]
                nan_mask = np.isnan(y0_f) | np.isnan(z0_f) | np.isnan(y1_f) | np.isnan(z1_f)
                valid_indices_nan = np.where(~nan_mask)[0]
                count_after_nan = len(valid_indices_nan)
                photons_lost_nan += (count_loaded_chunk - count_after_nan)
                if count_after_nan == 0: del photon_list; continue
                source_bounds_ok = (np.abs(y1_f[valid_indices_nan]) <= source_map_bound_y + epsilon) & \
                                   (np.abs(z1_f[valid_indices_nan]) <= source_map_bound_z + epsilon)
                valid_indices_src_bounds = valid_indices_nan[source_bounds_ok]
                count_after_src_bounds = len(valid_indices_src_bounds)
                photons_lost_source_bounds += (count_after_nan - count_after_src_bounds)
                if count_after_src_bounds == 0: del photon_list; continue
                y0_f_filt = y0_f[valid_indices_src_bounds]; z0_f_filt = z0_f[valid_indices_src_bounds]
                dest_logical_ok = (y0_f_filt >= dest_map_bound_min_y - epsilon) & (y0_f_filt <= dest_map_bound_max_y + epsilon) & \
                                  (z0_f_filt >= dest_map_bound_min_z - epsilon) & (z0_f_filt <= dest_map_bound_max_z + epsilon)
                valid_indices_dest_logical = valid_indices_src_bounds[dest_logical_ok]
                count_after_dest_logical = len(valid_indices_dest_logical)
                photons_lost_dest_logical_filter += (count_after_src_bounds - count_after_dest_logical)
                if count_after_dest_logical == 0: del photon_list; continue
                y0_f_pv = y0_f[valid_indices_dest_logical]; z0_f_pv = z0_f[valid_indices_dest_logical]
                y1_f_pv = y1_f[valid_indices_dest_logical]; z1_f_pv = z1_f[valid_indices_dest_logical]

                row_indices_f = ((z1_f_pv + source_map_bound_z) / source_denom_z) * (orig_h - 1)
                col_indices_f = ((y1_f_pv + source_map_bound_y) / source_denom_y) * (orig_w - 1)
                y1_physical = np.round(row_indices_f).astype(int)
                z1_physical = np.round(col_indices_f).astype(int)
                y1_physical = np.clip(y1_physical, 0, orig_h - 1)
                z1_physical = np.clip(z1_physical, 0, orig_w - 1)

                # --- MODIFIED Z-AXIS MAPPING ---
                if using_manual_dest_bounds: # Mode 1
                    y0_idx_f = ((y0_f_pv - dest_map_bound_min_y) / dest_denom_y) * (mapped_w - 1)
                    if flip_z_axis_render:
                        # Maps min_z to bottom row, max_z to top row
                        z0_idx_f = ((dest_map_bound_max_z - z0_f_pv) / dest_denom_z) * (mapped_h - 1)
                    else: # Original mapping
                        z0_idx_f = ((z0_f_pv - dest_map_bound_min_z) / dest_denom_z) * (mapped_h - 1)
                else: # Mode 2 (Auto Bounds)
                    # dest_map_bound_min_z is -auto_bound_z, dest_map_bound_max_z is +auto_bound_z
                    # dest_denom_z is 2 * auto_bound_z
                    auto_bound_y = dest_map_bound_max_y
                    auto_bound_z_positive_extent = dest_map_bound_max_z # This is the positive extent for auto mode

                    y0_idx_f = ((y0_f_pv + auto_bound_y) / dest_denom_y) * (mapped_w - 1)
                    if flip_z_axis_render:
                        # Maps -auto_bound_z (logical bottom) to bottom row, +auto_bound_z (logical top) to top row
                        z0_idx_f = ((auto_bound_z_positive_extent - z0_f_pv) / dest_denom_z) * (mapped_h - 1)
                    else: # Original mapping
                        z0_idx_f = ((z0_f_pv + auto_bound_z_positive_extent) / dest_denom_z) * (mapped_h - 1)
                # --- END MODIFIED Z-AXIS MAPPING ---

                y0_idx_clipped = np.round(y0_idx_f).astype(int)
                z0_idx_clipped = np.round(z0_idx_f).astype(int)
                y0_idx_clipped = np.clip(y0_idx_clipped, 0, mapped_w - 1)
                z0_idx_clipped = np.clip(z0_idx_clipped, 0, mapped_h - 1)
                pixel_map_ok_mask = (y0_idx_f >= -epsilon) & (y0_idx_f < mapped_w + epsilon) & \
                                    (z0_idx_f >= -epsilon) & (z0_idx_f < mapped_h + epsilon)
                count_after_pixel_map = np.sum(pixel_map_ok_mask)
                photons_lost_dest_pixel_map += (count_after_dest_logical - count_after_pixel_map)
                if count_after_pixel_map == 0: del photon_list; continue
                y0_idx_final = y0_idx_clipped[pixel_map_ok_mask]
                z0_idx_final = z0_idx_clipped[pixel_map_ok_mask]
                y1_physical_final = y1_physical[pixel_map_ok_mask]
                z1_physical_final = z1_physical[pixel_map_ok_mask]
                invalid_src_idx_mask = (y1_physical_final >= orig_h) | (z1_physical_final >= orig_w) | \
                                       (y1_physical_final < 0) | (z1_physical_final < 0)
                num_invalid_src_idx = np.sum(invalid_src_idx_mask)
                photons_lost_invalid_source_idx += num_invalid_src_idx
                if num_invalid_src_idx > 0 :
                     valid_accumulation_mask = ~invalid_src_idx_mask
                     if not np.any(valid_accumulation_mask):
                         del photon_list; continue
                     y0_idx_final = y0_idx_final[valid_accumulation_mask]
                     z0_idx_final = z0_idx_final[valid_accumulation_mask]
                     y1_physical_final = y1_physical_final[valid_accumulation_mask]
                     z1_physical_final = z1_physical_final[valid_accumulation_mask]
                count_accumulated_chunk = len(y0_idx_final)
                photons_accumulated += count_accumulated_chunk
                if count_accumulated_chunk > 0:
                    source_colors = original_image_rgb[y1_physical_final, z1_physical_final].astype(np.float64)
                    np.add.at(sum_array, (z0_idx_final, y0_idx_final), source_colors)
                    np.add.at(count_array, (z0_idx_final, y0_idx_final), 1)
                del photon_list, y0_f, z0_f, y1_f, z1_f, valid_indices_nan, valid_indices_src_bounds, valid_indices_dest_logical
                del y0_idx_final, z0_idx_final, y1_physical_final, z1_physical_final
                if count_accumulated_chunk > 0: del source_colors
            except FileNotFoundError: print(f"\nWarn: Chunk file not found during processing: {chunk_file}. Skip."); continue
            except Exception as e: print(f"\nWarn: Error processing chunk {chunk_file}: {e}. Skip."); traceback.print_exc(limit=1, file=sys.stdout); continue
        print(f"\nFinished processing all chunks. Time: {time.time() - start_proc_time:.2f}s")
        print("\n--- Photon Loss Debug Summary ---")
        print(f"Destination Mode: {'MANUAL Bounds + Scale' if using_manual_dest_bounds else 'AUTO Bounds + Shape'}")
        print(f"  Z-axis Render Flip: {flip_z_axis_render}")
        print(f"Total photons loaded from chunks: {total_photons_loaded_from_chunks}")
        print(f"  Lost due to NaN values:         {photons_lost_nan}")
        count_rem_1 = total_photons_loaded_from_chunks - photons_lost_nan
        print(f"  Remaining after NaN filter:     {count_rem_1}")
        print(f"  Lost due to source bounds filter [y1(horiz):+/-{source_map_bound_y:.3f}, z1(vert):+/-{source_map_bound_z:.3f}]: {photons_lost_source_bounds}")
        count_rem_2 = count_rem_1 - photons_lost_source_bounds
        print(f"  Remaining after source bounds:  {count_rem_2}")
        print(f"  Lost due to DEST LOGICAL filter [y0(horiz)=[{dest_map_bound_min_y:.3f},{dest_map_bound_max_y:.3f}], z0(vert)=[{dest_map_bound_min_z:.3f},{dest_map_bound_max_z:.3f}]]: {photons_lost_dest_logical_filter}")
        count_rem_3 = count_rem_2 - photons_lost_dest_logical_filter
        print(f"  Remaining after dest logical:   {count_rem_3}")
        print(f"  Lost due to dest pixel map/clip (outside vert=[{0}-{mapped_h-1}], horiz=[{0}-{mapped_w-1}]): {photons_lost_dest_pixel_map}")
        count_rem_4 = count_rem_3 - photons_lost_dest_pixel_map
        print(f"  Remaining after dest pixel map: {count_rem_4}")
        print(f"  Lost due to invalid source idx (post-map/clip check): {photons_lost_invalid_source_idx}")
        print(f"Total photons accumulated (pre-add.at): {photons_accumulated}")
        print(f"Calculated total loss:            {total_photons_loaded_from_chunks - photons_accumulated}")
        calc_total = photons_lost_nan + photons_lost_source_bounds + photons_lost_dest_logical_filter + photons_lost_dest_pixel_map + photons_lost_invalid_source_idx + photons_accumulated
        if abs(calc_total - total_photons_loaded_from_chunks) > 0:
             print(f"!!! WARNING: Photon count sanity check failed! Calculated Total ({calc_total}) != Loaded ({total_photons_loaded_from_chunks}). Diff: {total_photons_loaded_from_chunks - calc_total}")
        print("----------------------------------")
        print("Calculating averages and finalizing image...")
        hit_mask = count_array > 0
        num_hit_pixels = np.sum(hit_mask)
        if num_hit_pixels > 0:
            print(f"  Averaging colors for {num_hit_pixels} hit pixels.")
            valid_counts = count_array[hit_mask]
            average_colors_float = sum_array[hit_mask] / valid_counts[..., np.newaxis]
            mapped_image[hit_mask] = np.clip(average_colors_float, 0, 255).astype(np.uint8)
        else:
            print("  No pixels were hit by valid photons.")
        end_time_total = time.time()
        print(f"Processing finished. Total time: {end_time_total - start_time_total:.2f} seconds.")
        return _save_image(mapped_image, save_path)
    except Exception as e:
        print(f"\n--- An critical error occurred during mapping setup or processing ---"); traceback.print_exc()
        if 'mapped_h' in locals() and 'mapped_w' in locals() and save_path:
             try:
                 if 'default_color_rgb' not in locals(): default_color_rgb = (0,0,0)
                 if 'mapped_image' not in locals() or mapped_image is None:
                     print("Initializing empty background for error image.")
                     mapped_image = np.full((mapped_h, mapped_w, 3), default_color_rgb, dtype=np.uint8)
                 elif mapped_image.shape != (mapped_h, mapped_w, 3):
                      print(f"Warning: Mapped image shape {mapped_image.shape} doesn't match expected {(mapped_h, mapped_w, 3)}. Creating empty background.")
                      mapped_image = np.full((mapped_h, mapped_w, 3), default_color_rgb, dtype=np.uint8)
                 print("Attempting to save background/partial image due to critical error.")
                 _save_image(mapped_image, save_path + "_CRITICAL_ERROR.png")
             except Exception as save_err:
                 print(f"Failed to save error image: {save_err}")
        return False