
# Photon-Track Reconstruction by Iterative Voxel Peeling

This notebook builds voxel grids from ROOT **histograms** and reconstructs a track by
iteratively **peeling** photons using a greedy strategy. We support both seeding modes
(**max voxel** / **average of top-K voxels**) and both growth modes (**connected neighbors**
vs **global max**). Removal uses a physically motivated fraction per photon: either **uniform**
(1 / number of voxels traversed) or **length-weighted** (voxel path length / total path length).
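
A minimal sketch of the two share definitions for a single photon. It assumes toy numbers: the photon crosses three voxels with path lengths of 2, 1 and 1 mm. `photon_shares` is an illustrative name; it mirrors the `alpha_mode` branch in `build_incidence_from_hits` further down.

```python
import numpy as np

def photon_shares(lengths: np.ndarray, mode: str) -> np.ndarray:
    """Per-voxel share of one photon; both modes sum to 1 over the voxels it traverses."""
    lengths = np.asarray(lengths, dtype=float)
    if mode == 'uniform':
        return np.full_like(lengths, 1.0 / lengths.size)   # 1 / number of voxels traversed
    if mode == 'length':
        return lengths / lengths.sum()                       # voxel path length / total path length
    raise ValueError(f"unknown mode: {mode!r}")

lengths = np.array([2.0, 1.0, 1.0])           # mm inside each traversed voxel (toy values)
uniform = photon_shares(lengths, 'uniform')   # [1/3, 1/3, 1/3]
length = photon_shares(lengths, 'length')     # [0.5, 0.25, 0.25]
```

With `length`, the first voxel gets twice the share of the others because the photon travels twice as far inside it. With `uniform`, a voxel the photon barely clips counts as much as one it crosses fully.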

**Key ideas:**

- Build a sparse **photon↔voxel incidence** by tracing each photon back through the detector's axis-aligned bounding box (AABB) with a 3D DDA (Amanatides-Woo voxel traversal).
- The initial grid is just the sum over photons of those per-voxel contributions.
- At each step, choose a voxel (seed or next), **reduce** the weights of all photons that pass
  through it by their share attributable to that voxel, and **update** the grid **incrementally**
  without recomputing from scratch.
- Repeat until stopping criteria are met (min support, remaining weight, max steps, etc.).
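
The bookkeeping above can be written compactly. The notation here is ours, not the code's: $a_{pv}$ is photon $p$'s share of voxel $v$ (with $\sum_v a_{pv} = 1$ for every photon that traverses the grid), $w_p$ is its residual weight (initially 1), $\lambda$ is `ALPHA_SCALE`, and $s$ is the selected voxel. Selecting $s$ applies:

```latex
G_v = \sum_p w_p\, a_{pv}, \qquad
\Delta w_p = \min\!\left(w_p,\ \lambda\, a_{ps}\, w_p\right), \qquad
w_p \leftarrow w_p - \Delta w_p, \qquad
G_v \leftarrow G_v - \sum_{p:\, a_{ps} > 0} \Delta w_p\, a_{pv}.
```

Only voxels on the paths of photons through $s$ change. This is why the grid can be updated in place instead of being rebuilt.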

> **Note**: This notebook **re-implements the grid-building** logic in-place to facilitate
> many iterations (one per selected voxel). We still **reuse your repo helpers** to read the
> histogram hits and primary track points. You can modify and iterate quickly here.


In [None]:

# --- Environment & imports ----------------------------------------------------
import os, sys, math, gc, logging, itertools, time
from typing import Tuple, List, Dict, Optional
import numpy as np
import pandas as pd
from tqdm import tqdm
import uproot

# Allow imports from your repo (../python)
sys.path.append('../python')

# Domain helpers from your codebase
from constants import CM_PER_RAD, MM_PER_CM, Y_LIM, DETECTOR_SIZE_MM
from importMethods import (
    get_histogram_hits_tuple,
    get_histogram_nHits_total,
    get_primary_position,
    get_primary_pdg
)
from hitAccuracyMethods import (
    make_r, make_theta, make_phi, make_reconstructedVector_direction, make_relativeVector
)
from filterMethods import filter_r

# Optional plotting helpers from your repo
from plotMethods import *
import plotParameters

import matplotlib.pyplot as plt
import matplotlib.cm as cm

# Logging
logging.basicConfig(level=logging.INFO, format='[%(levelname)s] %(message)s')
LOGGER = logging.getLogger('peel_reco')

# Constants / utilities reused from script
def r_to_theta(r):
    """Convert radius to theta using your constants."""
    return r / (CM_PER_RAD * MM_PER_CM)


In [None]:

# --- User settings (tweak here) -----------------------------------------------
# IO
INPUT_DIR = '/path/to/your/root/files'   # Folder containing *.root files
FILE_LIMIT = None                        # e.g., 10 to test; or None for all
HIST_DIR = 'photoSensor_hits_histograms' # We use histogram-based hits
PRIMARY_TREE = 'primary;1'               # For optional truth visualization

# Grid
GRID_SIZE = (80, 80, 80)                 # (nx, ny, nz)
DETECTOR_MM = tuple(DETECTOR_SIZE_MM)    # (Lx, Ly, Lz) from your constants

# Cuts
APPLY_CUTS = True                        # Option to skip cuts
MIN_N_HITS = 0
MIN_PRIMARY_STEPS = 30
PRIMARY_PDG = 13

# Photon share / peeling
ALPHA_MODE = 'length'  # 'uniform' or 'length' (length-weighted is usually better)
ALPHA_SCALE = 1.0      # Multiplier on the share each photon loses when a voxel it crosses is selected

# Seeding & growth
SEED_MODE = 'max'      # 'max' or 'avg' (average of top-K voxels)
SEED_TOP_K = 50        # Only used for SEED_MODE='avg'
GROW_MODE = 'connected'  # 'connected' or 'global' (next global max)
CONNECTIVITY = 26        # 6, 18, or 26 neighbors for 'connected'
ALLOW_RESEED_IF_STUCK = True

# Stopping
MIN_SUPPORT = 3.0          # Minimum grid value for a voxel to be considered
MAX_STEPS = 500            # Limit of selected voxels
REMAINING_WEIGHT_FRAC = 0.05  # Stop if total residual photon weight < 5%

# Snapshots: full grid copies taken during peeling. Heavy: at 80^3 float64 each copy is ~4 MB.
SAVE_SNAPSHOTS = False
SNAPSHOT_EVERY = 1  # save every N steps (ignored if SAVE_SNAPSHOTS=False)

# Plotting (downsample to speed up in-notebook 3D plots)
RESHAPE_SIZE = 40      # e.g., 40 => average blocks to 40^3 for plotting
PLOT_MIN_VAL = 400     # min value threshold for rendering (tune per dataset)

# Misc
NUM_WORKERS = 1        # Placeholder: no cell below parallelizes yet
RANDOM_SEED = 42

# Output directory for any artifacts you may later want to persist
OUTPUT_DIR = './peel_outputs'
os.makedirs(OUTPUT_DIR, exist_ok=True)

np.random.seed(RANDOM_SEED)
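
The settings above are not validated. `neighbors_of` silently treats any `CONNECTIVITY` other than 6 or 18 as 26. `build_incidence_from_hits` treats any `ALPHA_MODE` other than `'uniform'` as `'length'`. `run_peeling` falls back to max seeding and global growth for unrecognized modes. `downsample_grid` requires every `GRID_SIZE` entry to be divisible by `RESHAPE_SIZE`. The helper below is a hypothetical early check (the name is ours, not from the repo):

```python
from typing import Tuple

def validate_settings(grid_size: Tuple[int, int, int], reshape_size: int,
                      alpha_mode: str, seed_mode: str, grow_mode: str,
                      connectivity: int) -> None:
    """Fail fast on settings the pipeline would otherwise accept silently (hypothetical helper)."""
    if reshape_size <= 0 or any(n % reshape_size != 0 for n in grid_size):
        raise ValueError(f"GRID_SIZE {grid_size} is not divisible by RESHAPE_SIZE={reshape_size}")
    checks = {
        'ALPHA_MODE': (alpha_mode, ('uniform', 'length')),
        'SEED_MODE': (seed_mode, ('max', 'avg')),
        'GROW_MODE': (grow_mode, ('connected', 'global')),
        'CONNECTIVITY': (connectivity, (6, 18, 26)),
    }
    for name, (value, allowed) in checks.items():
        if value not in allowed:
            raise ValueError(f"{name} must be one of {allowed}, got {value!r}")

# The default settings from the cell above pass silently
validate_settings((80, 80, 80), 40, 'length', 'max', 'connected', 26)
```

In the notebook, this would be called as `validate_settings(GRID_SIZE, RESHAPE_SIZE, ALPHA_MODE, SEED_MODE, GROW_MODE, CONNECTIVITY)`.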


In [None]:

# --- Discover ROOT files & (optionally) apply cuts ----------------------------
def discover_root_files(input_dir: str, file_limit: Optional[int]=None):
    files = sorted([os.path.join(input_dir, f) for f in os.listdir(input_dir) if f.endswith('.root')])
    if file_limit is not None:
        files = files[:file_limit]
    return files

def passes_cuts(path: str) -> bool:
    if not APPLY_CUTS:
        return True
    try:
        n_hits = get_histogram_nHits_total(path, directoryName=HIST_DIR)
        prim_steps = len(get_primary_position(path, PRIMARY_TREE))
        if n_hits >= MIN_N_HITS and prim_steps >= MIN_PRIMARY_STEPS:
            return True
        else:
            LOGGER.info(f"Cut {os.path.basename(path)}: hits={n_hits}, steps={prim_steps}")
            return False
    except Exception as e:
        LOGGER.warning(f"Error applying cuts to {path}: {e}")
        return False

all_files = discover_root_files(INPUT_DIR, FILE_LIMIT)
files = [p for p in all_files if passes_cuts(p)]

LOGGER.info(f"Found {len(files)} ROOT files after cuts (out of {len(all_files)}).")


In [None]:

# --- Build hits DataFrame for one file (from histograms) ----------------------
def build_hits_df_from_hist(path: str, hist_dir: str=HIST_DIR) -> pd.DataFrame:
    # Using your helper to read histogram-based hits
    ids, dirs, poss, walls, rb, rn = get_histogram_hits_tuple(path, hist_dir, True)
    df = pd.DataFrame({
        'sensor_name': ids,
        'sensor_direction': dirs,
        'sensor_position': poss,
        'sensor_wall': walls,
        'relativePosition_binned': rb,
        'relativePosition_nBin': rn,
    })
    # Reconstruct unit directions (following your pipeline)
    df = make_r(df)
    df = filter_r(df, Y_LIM)
    df = make_theta(df, r_to_theta)
    df = make_phi(df)
    df = make_reconstructedVector_direction(df)
    if 'initialPosition' in df.columns:
        df = make_relativeVector(df)
    return df

def get_primary_positions(path: str) -> np.ndarray:
    """Optional truth points for visualization (returns Nx3)."""
    try:
        pos = get_primary_position(path, PRIMARY_TREE)
        pdgs = get_primary_pdg(path, PRIMARY_TREE)
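        # reshape keeps an (N, 3) shape even when no step matches, so the bounds mask below works
        pos = np.asarray([p for p, q in zip(pos, pdgs) if q == PRIMARY_PDG], dtype=float).reshape(-1, 3)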
        # Keep only positions inside detector bounds
        half = np.asarray(DETECTOR_MM) / 2.0
        mask = np.all((pos > -half) & (pos < half), axis=1)
        return pos[mask]
    except Exception as e:
        LOGGER.warning(f"Primary load failed ({os.path.basename(path)}): {e}")
        return np.empty((0,3), dtype=float)


In [None]:

# --- Geometry: AABB intersection & 3D DDA traversal ---------------------------
EPS = 1e-9

def aabb_intersect(origin: np.ndarray, direction: np.ndarray,
                   bounds_min: np.ndarray, bounds_max: np.ndarray) -> Optional[Tuple[float,float]]:
    """Ray (origin + t*dir) vs axis-aligned box [min, max]. Return (t_enter, t_exit) or None."""
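    # Slab method. Axes with a (near-)zero direction component get an infinite inverse;
    # errstate silences the warnings from 1/0 (discarded by np.where) and from 0 * inf.
    with np.errstate(divide='ignore', invalid='ignore'):
        inv = np.where(np.abs(direction) < EPS, np.inf, 1.0 / direction)
        t0s = (bounds_min - origin) * inv
        t1s = (bounds_max - origin) * inv
    tmin = np.maximum.reduce(np.minimum(t0s, t1s))
    tmax = np.minimum.reduce(np.maximum(t0s, t1s))
    # NaN only arises for a ray lying exactly in a box face plane (0 * inf); treat it as a miss
    if np.isnan(tmin) or np.isnan(tmax) or tmax < max(tmin, 0.0):
        return None
    return max(tmin, 0.0), tmax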

def world_to_voxel(point_mm: np.ndarray, grid_shape: Tuple[int,int,int],
                   det_mm: Tuple[float,float,float]) -> np.ndarray:
    """Map mm coords to fractional voxel coords (not clipped)."""
    det_mm = np.asarray(det_mm, dtype=float)
    gshape = np.asarray(grid_shape, dtype=int)
    vsize = det_mm / gshape
    vmin = -det_mm / 2.0
    return (point_mm - vmin) / vsize  # fractional index coords

def voxel_center_ijk(ijk: Tuple[int,int,int], grid_shape, det_mm):
    gshape = np.asarray(grid_shape, dtype=int)
    det_mm = np.asarray(det_mm, dtype=float)
    vsize = det_mm / gshape
    vmin  = -det_mm / 2.0
    ijk = np.asarray(ijk, dtype=float)
    return vmin + (ijk + 0.5) * vsize

def dda_traverse(origin_mm: np.ndarray, direction_mm: np.ndarray,
                 grid_shape: Tuple[int,int,int], det_mm: Tuple[float,float,float],
                 t_enter: float, t_exit: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    3D DDA through voxel grid. Assumes `direction_mm` is **unit length** (mm units).
    Returns:
      - flat voxel indices (int)
      - lengths inside each voxel (float, in mm)
    """
    # Map entry point to fractional voxel coords
    gshape = np.asarray(grid_shape, dtype=int)
    det_mm = np.asarray(det_mm, dtype=float)
    vsize = det_mm / gshape
    vmin  = -det_mm / 2.0

    # Start at entry
    p = origin_mm + direction_mm * t_enter  # point in mm
    g = (p - vmin) / vsize                  # fractional voxel coords
    ijk = np.floor(g - EPS).astype(int)     # voxel index (clip later)

    # Handle boundary edge case
    ijk = np.clip(ijk, 0, gshape - 1)

    # Step direction per axis (-1, 0 or +1)
    step = np.sign(direction_mm).astype(int)

    # tMax: ray parameter t at which the first voxel boundary is crossed on each axis.
    # It is measured from the ray origin, like t_enter and t_exit, hence the t_enter offset.
    # tDelta: increase in t needed to cross one full voxel along that axis.
    tMax = np.empty(3, dtype=float)
    tDelta = np.empty(3, dtype=float)
    for ax in range(3):
        if abs(direction_mm[ax]) < EPS:
            tMax[ax] = np.inf
            tDelta[ax] = np.inf
        else:
            if direction_mm[ax] > 0:
                next_boundary = vmin[ax] + (ijk[ax] + 1) * vsize[ax]
            else:
                next_boundary = vmin[ax] + ijk[ax] * vsize[ax]
            tMax[ax] = t_enter + (next_boundary - p[ax]) / direction_mm[ax]
            tDelta[ax] = abs(vsize[ax] / direction_mm[ax])

    # DDA loop
    voxels = []
    lengths = []
    t = t_enter
    while t < t_exit - EPS:
        # Record current voxel and length to next boundary
        t_next = min(tMax[0], tMax[1], tMax[2], t_exit)
        dt = max(0.0, t_next - t)
        if dt > 0:
            flat = (ijk[0] * gshape[1] + ijk[1]) * gshape[2] + ijk[2]
            voxels.append(flat)
            lengths.append(dt)  # direction is unit length, so dt is the path length in mm

        # Step along the axis with smallest tMax
        ax = int(np.argmin(tMax))
        if tMax[ax] >= t_exit:
            break

        ijk[ax] += step[ax]
        # Exit if outside grid
        if (ijk[ax] < 0) or (ijk[ax] >= gshape[ax]):
            break

        t = tMax[ax]
        tMax[ax] += tDelta[ax]

    if not voxels:
        return np.empty((0,), dtype=int), np.empty((0,), dtype=float)

    return np.asarray(voxels, dtype=int), np.asarray(lengths, dtype=float)
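
A standalone sanity check for the traversal. It assumes a toy 10 mm cube split into 5³ voxels of 2 mm and compactly re-implements the slab test and the Amanatides-Woo DDA (`slab` and `dda` are illustrative names), with `tMax` measured in absolute ray parameter, i.e. offset by `t_enter`. It checks two properties that `dda_traverse` should also satisfy: the per-voxel lengths sum to the chord length `t_exit - t_enter`, and no voxel is visited twice. The ray whose origin lies outside the box is the case where an un-offset `tMax` would go wrong.

```python
import numpy as np

EPS = 1e-9  # same tolerance as the notebook

def slab(o, d, lo, hi):
    """Ray/AABB slab test; returns (t_enter, t_exit) or None."""
    with np.errstate(divide='ignore', invalid='ignore'):
        inv = np.where(np.abs(d) < EPS, np.inf, 1.0 / d)
        t0s, t1s = (lo - o) * inv, (hi - o) * inv
    tmin = np.max(np.minimum(t0s, t1s))
    tmax = np.min(np.maximum(t0s, t1s))
    if np.isnan(tmin) or np.isnan(tmax) or tmax < max(tmin, 0.0):
        return None
    return max(tmin, 0.0), tmax

def dda(o, d, shape, det_mm, t_enter, t_exit):
    """Amanatides-Woo traversal; returns visited (i, j, k) cells and path lengths (mm)."""
    shape = np.asarray(shape)
    det = np.asarray(det_mm, dtype=float)
    vsize, vmin = det / shape, -det / 2.0
    p = o + d * t_enter
    ijk = np.clip(np.floor((p - vmin) / vsize - EPS).astype(int), 0, shape - 1)
    step = np.sign(d).astype(int)
    t_max = np.full(3, np.inf)
    t_delta = np.full(3, np.inf)
    for a in range(3):
        if abs(d[a]) >= EPS:
            edge = vmin[a] + (ijk[a] + (1 if d[a] > 0 else 0)) * vsize[a]
            t_max[a] = t_enter + (edge - p[a]) / d[a]   # absolute t, hence the t_enter offset
            t_delta[a] = vsize[a] / abs(d[a])
    cells, lengths, t = [], [], t_enter
    while t < t_exit - EPS:
        t_next = min(t_max.min(), t_exit)
        if t_next > t:
            cells.append(tuple(int(c) for c in ijk))
            lengths.append(t_next - t)
        a = int(np.argmin(t_max))
        if t_max[a] >= t_exit:
            break
        ijk[a] += step[a]
        if not 0 <= ijk[a] < shape[a]:
            break
        t = t_max[a]
        t_max[a] += t_delta[a]
    return cells, np.asarray(lengths)

lo, hi = np.full(3, -5.0), np.full(3, 5.0)   # toy 10 mm cube, 5^3 voxels of 2 mm

# Axis-aligned ray starting 15 mm outside the box (t_enter = 15)
o1, d1 = np.array([-20.0, 0.3, 0.7]), np.array([1.0, 0.0, 0.0])
t0, t1 = slab(o1, d1, lo, hi)
cells1, len1 = dda(o1, d1, (5, 5, 5), (10.0, 10.0, 10.0), t0, t1)

# Oblique unit-length ray: the lengths must still sum to the chord t_exit - t_enter
d2 = np.array([1.0, 0.5, 0.25])
d2 /= np.linalg.norm(d2)
o2 = np.array([-8.0, -1.1, 0.3])
t0b, t1b = slab(o2, d2, lo, hi)
cells2, len2 = dda(o2, d2, (5, 5, 5), (10.0, 10.0, 10.0), t0b, t1b)
print(cells1)   # [(0, 2, 2), (1, 2, 2), (2, 2, 2), (3, 2, 2), (4, 2, 2)]
```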


In [None]:

# --- Build photon↔voxel incidence (sparse) and initial grid -------------------
def build_incidence_from_hits(df_hits: pd.DataFrame,
                              grid_shape: Tuple[int,int,int]=GRID_SIZE,
                              det_mm: Tuple[float,float,float]=DETECTOR_MM,
                              alpha_mode: str = ALPHA_MODE):
    """
    From hits DataFrame (with 'sensor_position' and 'reconstructedVector_direction'),
    build:
      - initial grid (sum over photons of voxel contributions)
      - photon->voxel CSR-like structure
      - voxel->photon CSR-like structure
      - per-photon residual weights (init = 1.0)
    """
    starts = np.vstack(df_hits['sensor_position'].to_numpy())
    dirs   = -np.vstack(df_hits['reconstructedVector_direction'].to_numpy())  # back into detector

    # Normalize directions
    norms = np.linalg.norm(dirs, axis=1, keepdims=True)
    safe = norms[:,0] > 0
    dirs[safe] /= norms[safe]

    bounds_min = -np.asarray(det_mm) / 2.0
    bounds_max =  np.asarray(det_mm) / 2.0

    # CSR-like containers
    photon_ptr = [0]
    photon_vox = []
    photon_frac = []  # per-voxel share for that photon (sum to 1 per photon)

    # COO for voxel->photon (we will sort & CSR it later)
    coo_vox = []
    coo_photon = []
    coo_frac = []

    # Initial grid (flat)
    nvox = int(grid_shape[0] * grid_shape[1] * grid_shape[2])
    grid_flat = np.zeros(nvox, dtype=np.float64)

    for p_idx in tqdm(range(len(starts)), desc='Tracing photons', leave=False):
        o = starts[p_idx].astype(float)
        d = dirs[p_idx].astype(float)
        inter = aabb_intersect(o, d, bounds_min, bounds_max)
        if inter is None:
            photon_ptr.append(photon_ptr[-1])
            continue
        t_enter, t_exit = inter
        vox_idx, lengths = dda_traverse(o, d, grid_shape, det_mm, t_enter, t_exit)
        if vox_idx.size == 0:
            photon_ptr.append(photon_ptr[-1])
            continue

        if alpha_mode == 'uniform':
            contrib = np.full_like(lengths, 1.0/len(lengths), dtype=np.float64)
        else:  # 'length'
            total_len = np.sum(lengths)
            if total_len <= 0:
                photon_ptr.append(photon_ptr[-1])
                continue
            contrib = lengths / total_len

        # Append CSR segments
        photon_vox.extend(vox_idx.tolist())
        photon_frac.extend(contrib.tolist())
        photon_ptr.append(photon_ptr[-1] + len(vox_idx))

        # Accumulate initial grid
        grid_flat[vox_idx] += contrib

        # Build COO for voxel->photon
        coo_vox.extend(vox_idx.tolist())
        coo_photon.extend([p_idx] * len(vox_idx))
        coo_frac.extend(contrib.tolist())

    photon_ptr = np.asarray(photon_ptr, dtype=np.int64)
    photon_vox = np.asarray(photon_vox, dtype=np.int32)
    photon_frac = np.asarray(photon_frac, dtype=np.float32)

    # Build voxel->photon CSR by sorting COO by voxel index
    if len(coo_vox) == 0:
        voxel_ptr = np.zeros(nvox+1, dtype=np.int64)
        voxel_photon = np.empty((0,), dtype=np.int32)
        voxel_frac = np.empty((0,), dtype=np.float32)
    else:
        coo_vox = np.asarray(coo_vox, dtype=np.int32)
        order = np.argsort(coo_vox, kind='mergesort')  # stable
        coo_vox = coo_vox[order]
        voxel_photon = np.asarray(coo_photon, dtype=np.int32)[order]
        voxel_frac = np.asarray(coo_frac, dtype=np.float32)[order]

        voxel_ptr = np.zeros(nvox+1, dtype=np.int64)
        # Count entries per voxel
        np.add.at(voxel_ptr, coo_vox + 1, 1)
        voxel_ptr = np.cumsum(voxel_ptr)

    grid = grid_flat.reshape(grid_shape)

    # Residual photon weights (init to 1)
    photon_w = np.ones(len(starts), dtype=np.float32)

    return {
        'grid': grid,
        'photon_ptr': photon_ptr,
        'photon_vox': photon_vox,
        'photon_frac': photon_frac,
        'voxel_ptr': voxel_ptr,
        'voxel_photon': voxel_photon,
        'voxel_frac': voxel_frac,
        'photon_w': photon_w,
        'grid_shape': tuple(grid_shape),
        'det_mm': tuple(det_mm)
    }
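
The voxel→photon CSR is a group-by on voxel index. The toy sketch below builds it from made-up incidence for two photons. It uses `np.bincount` + `np.cumsum` for the row pointer, which is equivalent to the `np.add.at` + `np.cumsum` used above (`coo_to_csr` is an illustrative name):

```python
import numpy as np

def coo_to_csr(rows: np.ndarray, cols: np.ndarray, vals: np.ndarray, n_rows: int):
    """Group (row, col, val) triplets by row; returns (ptr, cols_sorted, vals_sorted)."""
    order = np.argsort(rows, kind='mergesort')          # stable: keeps photon order within a voxel
    ptr = np.zeros(n_rows + 1, dtype=np.int64)
    ptr[1:] = np.cumsum(np.bincount(rows, minlength=n_rows))
    return ptr, cols[order], vals[order]

# Toy incidence: photon 0 crosses voxels 2 and 5, photon 1 crosses voxels 5 and 7
vox    = np.array([2, 5, 5, 7])
photon = np.array([0, 0, 1, 1])
share  = np.array([0.4, 0.6, 0.5, 0.5])
ptr, vp, vf = coo_to_csr(vox, photon, share, n_rows=8)

# Photons through voxel 5 are the slice ptr[5]:ptr[6]
print(vp[ptr[5]:ptr[6]], vf[ptr[5]:ptr[6]])   # [0 1] [0.6 0.5]

# Initial grid = per-voxel sum of shares (same result as accumulating contrib per photon)
toy_grid = np.bincount(vox, weights=share, minlength=8)
```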


In [None]:

# --- Utilities for seeding, neighbors, and peeling updates --------------------
def grid_argmax(grid: np.ndarray, visited: Optional[np.ndarray]=None, min_support: float=MIN_SUPPORT) -> Optional[int]:
    flat = grid.ravel()
    if visited is not None:
        flat = flat.copy()
        flat[visited.ravel()] = -np.inf
    idx = int(np.argmax(flat))
    if not np.isfinite(flat[idx]) or flat[idx] < min_support:
        return None
    return idx

def grid_seed_avg(grid: np.ndarray, top_k: int=SEED_TOP_K, det_mm: Tuple[float,float,float]=DETECTOR_MM) -> Optional[int]:
    """Center-of-mass of top-K voxels (weighted), snapped to nearest voxel."""
    gshape = np.asarray(grid.shape, dtype=int)
    flat = grid.ravel()
    if flat.size == 0:
        return None
    k = min(len(flat), max(1, int(top_k)))
    idxs = np.argpartition(flat, -k)[-k:]
    vals = flat[idxs]
    if np.all(vals <= 0):
        return None
    # Compute weighted CoM in voxel coordinates
    ijk = np.vstack(np.unravel_index(idxs, grid.shape)).T.astype(float)
    w = vals.astype(float)
    com = np.average(ijk, weights=w, axis=0)
    com = np.clip(np.round(com), 0, gshape - 1).astype(int)
    return int((com[0] * gshape[1] + com[1]) * gshape[2] + com[2])

def neighbors_of(flat_idx: int, grid_shape: Tuple[int,int,int], connectivity: int=CONNECTIVITY) -> List[int]:
    nx, ny, nz = grid_shape
    i = flat_idx // (ny * nz)
    j = (flat_idx % (ny * nz)) // nz
    k = flat_idx % nz

    offsets = []
    if connectivity == 6:
        offsets = [(1,0,0),(-1,0,0),(0,1,0),(0,-1,0),(0,0,1),(0,0,-1)]
    elif connectivity == 18:
        for di in [-1,0,1]:
            for dj in [-1,0,1]:
                for dk in [-1,0,1]:
                    if (di,dj,dk) == (0,0,0): continue
                    if abs(di) + abs(dj) + abs(dk) <= 2:
                        offsets.append((di,dj,dk))
    else:  # 26
        for di in [-1,0,1]:
            for dj in [-1,0,1]:
                for dk in [-1,0,1]:
                    if (di,dj,dk) == (0,0,0): continue
                    offsets.append((di,dj,dk))

    nbs = []
    for di,dj,dk in offsets:
        ii, jj, kk = i+di, j+dj, k+dk
        if 0 <= ii < nx and 0 <= jj < ny and 0 <= kk < nz:
            nbs.append((ii * ny + jj) * nz + kk)
    return nbs

def peel_once(selected_voxel: int, state: Dict, alpha_scale: float=ALPHA_SCALE) -> float:
    """
    Peel one voxel: reduce photon weights for photons traversing it by their fractional share,
    then update the grid incrementally. Returns total weight removed (sum dw_p).
    """
    grid = state['grid']
    photon_ptr = state['photon_ptr']
    photon_vox = state['photon_vox']
    photon_frac = state['photon_frac']
    voxel_ptr = state['voxel_ptr']
    voxel_photon = state['voxel_photon']
    voxel_frac = state['voxel_frac']
    photon_w = state['photon_w']
    gshape = state['grid_shape']

    # Photons hitting this voxel (CSR slice)
    start = voxel_ptr[selected_voxel]
    end   = voxel_ptr[selected_voxel + 1]
    if end <= start:
        return 0.0

    p_idx_arr = voxel_photon[start:end]
    frac_vp   = voxel_frac[start:end]  # per-photon share for this voxel
    # Weight reduction per photon: its share of this voxel times its residual weight,
    # capped at the weight it still has (the cap only matters when alpha_scale * share >= 1)
    w_sel = photon_w[p_idx_arr]
    dw = np.minimum(w_sel, alpha_scale * frac_vp * w_sel)
    if np.all(dw <= 0):
        return 0.0

    # Flat view of the grid: grid is C-contiguous, so reshape returns a view and the
    # in-place updates below modify state['grid'] itself (a copy would silently drop them)
    grid_flat = grid.reshape(-1)

    # For each affected photon p, subtract dw_p * share over every voxel it traverses
    for p_idx, dw_p in zip(p_idx_arr, dw):
        if dw_p <= 0:
            continue
        # Photon CSR segment
        ps = photon_ptr[p_idx]
        pe = photon_ptr[p_idx + 1]
        if pe <= ps:
            continue
        grid_flat[photon_vox[ps:pe]] -= dw_p * photon_frac[ps:pe]
        photon_w[p_idx] -= dw_p

    # Clamp at zero to suppress small negative values from floating-point drift
    np.maximum(grid, 0.0, out=grid)

    removed = float(np.sum(dw))
    return removed
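
`peel_once` loops over the affected photons in Python. Below is a hedged sketch of a vectorized variant of the same update rule on a toy incidence (3 photons, 6 voxels, made-up shares). `peel_vectorized` is our name, and it takes the flat grid and CSR arrays explicitly instead of the `state` dict. The final check rebuilds the grid from scratch as the share-weighted sum of residual photon weights and compares it with the incrementally updated grid.

```python
import numpy as np

def peel_vectorized(sel, grid_flat, photon_ptr, photon_vox, photon_frac,
                    voxel_ptr, voxel_photon, voxel_frac, photon_w, alpha_scale=1.0):
    """Same update rule as peel_once, without the per-photon Python loop (sketch)."""
    lo, hi = voxel_ptr[sel], voxel_ptr[sel + 1]
    p, a = voxel_photon[lo:hi], voxel_frac[lo:hi]            # photons through sel, their shares
    dw = np.minimum(photon_w[p], alpha_scale * a * photon_w[p])
    # Concatenate the CSR segments [photon_ptr[p], photon_ptr[p+1]) of all affected photons
    starts = photon_ptr[p]
    seg_len = photon_ptr[p + 1] - starts
    offsets = np.cumsum(np.r_[0, seg_len[:-1]])
    idx = np.repeat(starts - offsets, seg_len) + np.arange(seg_len.sum())
    # subtract.at accumulates when several photons share a voxel (plain fancy -= would not)
    np.subtract.at(grid_flat, photon_vox[idx], np.repeat(dw, seg_len) * photon_frac[idx])
    photon_w[p] -= dw          # no duplicates in p: a photon crosses each voxel at most once
    np.maximum(grid_flat, 0.0, out=grid_flat)
    return float(dw.sum())

# Toy incidence (3 photons, 6 voxels) in the same CSR layout as build_incidence_from_hits
photon_ptr  = np.array([0, 3, 5, 6])
photon_vox  = np.array([0, 1, 2, 1, 3, 4])
photon_frac = np.array([0.5, 0.25, 0.25, 0.5, 0.5, 1.0])
owner = np.repeat(np.arange(3), np.diff(photon_ptr))          # photon index of each entry
order = np.argsort(photon_vox, kind='mergesort')
voxel_ptr = np.r_[0, np.cumsum(np.bincount(photon_vox, minlength=6))]
voxel_photon, voxel_frac = owner[order], photon_frac[order]

photon_w = np.ones(3)
grid_flat = np.bincount(photon_vox, weights=photon_frac, minlength=6)

removed = peel_vectorized(1, grid_flat, photon_ptr, photon_vox, photon_frac,
                          voxel_ptr, voxel_photon, voxel_frac, photon_w)

# Invariant: the incrementally updated grid equals the grid rebuilt from residual weights
rebuilt = np.bincount(photon_vox, weights=photon_frac * photon_w[owner], minlength=6)
print(removed)   # 0.75  (0.25 from photon 0, 0.5 from photon 1)
```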


In [None]:

# --- Full peeling loop ---------------------------------------------------------
def run_peeling(state: Dict,
                seed_mode: str=SEED_MODE,
                grow_mode: str=GROW_MODE,
                connectivity: int=CONNECTIVITY,
                min_support: float=MIN_SUPPORT,
                max_steps: int=MAX_STEPS,
                remaining_weight_frac: float=REMAINING_WEIGHT_FRAC,
                allow_reseed: bool=ALLOW_RESEED_IF_STUCK,
                save_snapshots: bool=SAVE_SNAPSHOTS,
                snapshot_every: int=SNAPSHOT_EVERY):
    grid = state['grid']
    nvox = grid.size
    visited = np.zeros(nvox, dtype=bool)
    gshape = grid.shape

    # Seed
    if seed_mode == 'avg':
        seed = grid_seed_avg(grid, top_k=SEED_TOP_K)
    else:
        seed = grid_argmax(grid, visited=None, min_support=min_support)

    if seed is None:
        LOGGER.warning('No valid seed found (grid below min_support).')
        return {
            'track_flat': [],
            'snapshots': [],
            'grid_final': grid,
            'visited_mask': visited.reshape(gshape),
        }

    # Frontier for 'connected' growth
    frontier = set()
    track = []

    def consider_neighbors(v):
        for nb in neighbors_of(v, gshape, connectivity=connectivity):
            if not visited[nb]:
                frontier.add(nb)

    # Initialize
    current = seed
    track.append(current)
    visited[current] = True
    consider_neighbors(current)

    snapshots = []
    if save_snapshots:
        snapshots.append(grid.copy())

    # Peel the seed
    peel_once(current, state, alpha_scale=ALPHA_SCALE)

    for step in range(1, max_steps):
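        # Stopping: residual weight fraction over photons that actually traverse the grid
        # (photons that miss the detector box keep weight 1 forever and would bias the ratio)
        active = np.diff(state['photon_ptr']) > 0
        rem_frac = float(np.sum(state['photon_w'][active])) / max(1, int(np.count_nonzero(active)))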
        if rem_frac <= remaining_weight_frac:
            LOGGER.info(f'Stopping at step {step}: remaining photon weight frac={rem_frac:.3f}')
            break

        # Choose next voxel
        next_vox = None
        if grow_mode == 'connected':
            if frontier:
                # pick frontier voxel with highest current support
                frontier_list = list(frontier)
                vals = grid.reshape(-1)[frontier_list]
                # filter by min_support
                vals_mask = vals >= min_support
                if np.any(vals_mask):
                    best_idx = int(np.argmax(vals * vals_mask))
                    next_vox = frontier_list[best_idx]
                else:
                    frontier.clear()
            if next_vox is None and allow_reseed:
                # reseed globally if stuck
                next_vox = grid_argmax(grid, visited=visited.reshape(gshape), min_support=min_support)
        else:  # 'global'
            next_vox = grid_argmax(grid, visited=visited.reshape(gshape), min_support=min_support)

        if next_vox is None:
            LOGGER.info(f'Stopping at step {step}: no voxel meets min_support.')
            break

        # Update track; drop the chosen voxel from the frontier so it cannot be selected again
        track.append(next_vox)
        visited[next_vox] = True
        frontier.discard(next_vox)
        if grow_mode == 'connected':
            consider_neighbors(next_vox)

        # Peel
        peel_once(next_vox, state, alpha_scale=ALPHA_SCALE)

        # Snapshot
        if save_snapshots and (step % max(1, snapshot_every) == 0):
            snapshots.append(grid.copy())

    return {
        'track_flat': track,
        'snapshots': snapshots,
        'grid_final': grid,
        'visited_mask': visited.reshape(gshape),
    }
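
In `run_peeling`, the frontier is a Python set, so each step rebuilds `list(frontier)` and scans it. Peeling only ever lowers grid values: `peel_once` subtracts and then clamps at zero. Because of this, a lazily updated max-heap can return the highest-valued frontier voxel without a full scan. This is a hedged sketch (`LazyFrontier` is our name, not part of the notebook); ties may resolve differently than with `np.argmax`.

```python
import heapq
from typing import Optional
import numpy as np

class LazyFrontier:
    """Frontier that returns the highest-valued voxel, for grids whose values only decrease.

    heapq is a min-heap, so values are stored negated. An entry goes stale when a later
    peel lowers that voxel; stale entries are refreshed when they reach the top.
    """
    def __init__(self) -> None:
        self.heap = []
        self.members = set()

    def add(self, v: int, value: float) -> None:
        if v not in self.members:
            self.members.add(v)
            heapq.heappush(self.heap, (-float(value), int(v)))

    def pop_best(self, grid_flat: np.ndarray, min_support: float,
                 visited: Optional[np.ndarray] = None) -> Optional[int]:
        while self.heap:
            neg_val, v = heapq.heappop(self.heap)
            if visited is not None and visited[v]:
                self.members.discard(v)
                continue
            cur = float(grid_flat[v])
            if -neg_val > cur:                       # stale: voxel was peeled since the push
                heapq.heappush(self.heap, (-cur, v))
                continue
            self.members.discard(v)
            if cur >= min_support:
                return v
            # Stored values are upper bounds on current values, so nothing left qualifies
            self.heap.clear()
            self.members.clear()
            return None
        return None

grid = np.array([5.0, 9.0, 7.0, 1.0])
f = LazyFrontier()
for v in range(4):
    f.add(v, grid[v])
grid[1] = 4.0                                    # voxel 1 was peeled after being pushed
best = f.pop_best(grid, min_support=3.0)         # 2: its 7.0 is now the true maximum
```

Returning `None` once the top valid entry falls below `min_support` mirrors the `frontier.clear()` branch in `run_peeling`.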


In [None]:

# --- End-to-end for one file ---------------------------------------------------
def process_file(path: str,
                 grid_size: Tuple[int,int,int]=GRID_SIZE,
                 seed_mode: str=SEED_MODE,
                 grow_mode: str=GROW_MODE):
    LOGGER.info(f'Processing: {os.path.basename(path)}')
    df_hits = build_hits_df_from_hist(path, HIST_DIR)
    if len(df_hits) == 0:
        LOGGER.warning('No hits after filtering.')
        return None

    state = build_incidence_from_hits(df_hits, grid_size, DETECTOR_MM, ALPHA_MODE)

    # Optional truth
    primary_pos = get_primary_positions(path)

    # Run peeling
    res = run_peeling(state,
                      seed_mode=seed_mode,
                      grow_mode=grow_mode,
                      connectivity=CONNECTIVITY,
                      min_support=MIN_SUPPORT,
                      max_steps=MAX_STEPS,
                      remaining_weight_frac=REMAINING_WEIGHT_FRAC,
                      allow_reseed=ALLOW_RESEED_IF_STUCK,
                      save_snapshots=SAVE_SNAPSHOTS,
                      snapshot_every=SNAPSHOT_EVERY)

    res.update({
        'file': path,
        'primary_pos': primary_pos,
        'grid_shape': state['grid_shape'],
        'det_mm': state['det_mm']
    })
    return res


In [None]:

# --- Plotting helpers ----------------------------------------------------------
def downsample_grid(arr: np.ndarray, reshape_size: int) -> np.ndarray:
    arr = np.asarray(arr, dtype=float)
    RS = int(reshape_size)
    assert all(n % RS == 0 for n in arr.shape), 'Grid not divisible by reshape size'
    step = (arr.shape[0]//RS, arr.shape[1]//RS, arr.shape[2]//RS)
    # Block-average: split each axis into (RS, block) and average over the block axes
    return arr.reshape(RS, step[0], RS, step[1], RS, step[2]).mean(axis=(1,3,5))

def make_edges(det_mm: Tuple[float,float,float], grid_shape: Tuple[int,int,int]):
    gx, gy, gz = grid_shape
    xEdges = np.linspace(-det_mm[0]/2, det_mm[0]/2, gx + 1)
    yEdges = np.linspace(-det_mm[1]/2, det_mm[1]/2, gy + 1)
    zEdges = np.linspace(-det_mm[2]/2, det_mm[2]/2, gz + 1)
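    # 'ij' indexing gives corner arrays of shape (gx+1, gy+1, gz+1) with x varying along axis 0,
    # matching the grid's (i, j, k) layout; the old 'xy' swap only worked when gx == gy and Lx == Ly
    xEdges, yEdges, zEdges = np.meshgrid(xEdges, yEdges, zEdges, indexing='ij')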
    return xEdges, yEdges, zEdges

def plot_pred_true(pred_grid: np.ndarray, true_grid: Optional[np.ndarray],
                   det_mm: Tuple[float,float,float], reshape_size: int=RESHAPE_SIZE,
                   min_val: float=PLOT_MIN_VAL):
    pred = downsample_grid(pred_grid, reshape_size)
    true = downsample_grid(true_grid, reshape_size) if true_grid is not None else None

    fig = plt.figure(figsize=(10,10))
    ax = fig.add_subplot(111, projection='3d')

    gridSize = pred.shape
    xEdges, yEdges, zEdges = make_edges(det_mm, gridSize)

    maxVal = np.max(pred) if pred.size>0 else 1.0
    alpha_filled = 0.5
    globalColorNorm = cm.colors.Normalize(vmin=min_val, vmax=maxVal)
    pred_viz = np.where(pred < min_val, 0, pred)
    colors = cm.viridis(globalColorNorm(pred_viz))

    ax = plot_grid(
        ax,
        xEdges,
        yEdges,
        zEdges,
        recoGrid=pred_viz,
        recoGridFaceColors=colors,
        recoGridEdgeColors=np.clip(colors*2-0.5, 0, 1),
        recoGridAlpha=alpha_filled,
        trueGrid=true if true is not None else None,
        trueGridEdgeColors='red',
        trueGridAlpha=0,
        nullGridAlpha=0.3,
        linewidth=0.5,
        cbar=True,
        colorNorm=globalColorNorm,
        cmap=cm.viridis,
    )
    ax.set_axis_off()
    ax.set_aspect('equal')
    plt.show()
    plt.close(fig)
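
A small check of the two plotting transforms on toy arrays, assuming `plot_grid` follows the Matplotlib `Axes3D.voxels` convention, where each corner-coordinate array has shape `grid.shape + 1` per axis (an assumption; `plot_grid` lives in the repo). `block_mean` and `corner_edges` are illustrative names for the same operations as `downsample_grid` and `make_edges`.

```python
import numpy as np

def block_mean(arr: np.ndarray, rs: int) -> np.ndarray:
    """Average non-overlapping blocks so the result has shape (rs, rs, rs)."""
    nx, ny, nz = arr.shape
    assert nx % rs == 0 and ny % rs == 0 and nz % rs == 0, 'Grid not divisible by reshape size'
    return arr.reshape(rs, nx // rs, rs, ny // rs, rs, nz // rs).mean(axis=(1, 3, 5))

def corner_edges(det_mm, shape):
    """Voxel-corner coordinates, each of shape (nx+1, ny+1, nz+1), indexed like the grid."""
    axes = [np.linspace(-L / 2, L / 2, n + 1) for L, n in zip(det_mm, shape)]
    return np.meshgrid(*axes, indexing='ij')

arr = np.random.default_rng(0).random((8, 8, 8))
small = block_mean(arr, 4)                       # 2x2x2 blocks -> shape (4, 4, 4)
X, Y, Z = corner_edges((100.0, 60.0, 40.0), (4, 3, 2))
print(small.shape, X.shape)                      # (4, 4, 4) (5, 4, 3)
```

With the default `indexing='xy'`, `np.meshgrid` would return shape (4, 5, 3) for this non-cubic case, which no longer lines up with a (4, 3, 2) grid.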


In [None]:

# --- Build a simple truth grid from primary points (visual aid) ---------------
def primary_points_to_grid(primary_pos: np.ndarray,
                           grid_shape: Tuple[int,int,int],
                           det_mm: Tuple[float,float,float]) -> np.ndarray:
    """Nearest-voxel stamping of primary positions; visualization-only."""
    grid = np.zeros(grid_shape, dtype=np.float32)
    if primary_pos.size == 0:
        return grid
    det_mm = np.asarray(det_mm, dtype=float)
    vmin = -det_mm / 2.0
    vmax =  det_mm / 2.0
    vsize = det_mm / np.asarray(grid_shape, dtype=int)
    g = (primary_pos - vmin) / vsize
    ijk = np.floor(g).astype(int)
    ijk = np.clip(ijk, 0, np.asarray(grid_shape)-1)
    flat = (ijk[:,0] * grid_shape[1] + ijk[:,1]) * grid_shape[2] + ijk[:,2]
    vals = np.bincount(flat, minlength=grid.size).astype(np.float32)
    return vals.reshape(grid_shape)
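
`grid_final` is the residual grid after peeling, so the selected voxels show up as depleted regions rather than as a track. One option for plotting the reconstructed track itself is to stamp `track_flat` into a grid of the same shape, mirroring `primary_points_to_grid`. A minimal sketch (`track_to_grid` is a hypothetical helper):

```python
from typing import Optional, Sequence, Tuple
import numpy as np

def track_to_grid(track_flat: Sequence[int], grid_shape: Tuple[int, int, int],
                  values: Optional[np.ndarray] = None) -> np.ndarray:
    """Grid that is nonzero only on selected track voxels (1.0, or `values` at those voxels)."""
    out = np.zeros(int(np.prod(grid_shape)), dtype=np.float32)
    idx = np.asarray(track_flat, dtype=np.int64)
    out[idx] = 1.0 if values is None else np.asarray(values).reshape(-1)[idx]
    return out.reshape(grid_shape)

mask = track_to_grid([0, 5, 7], (2, 2, 2))       # flat 5 -> (1, 0, 1), flat 7 -> (1, 1, 1)
```

`values` could be, for example, a copy of the grid taken before peeling, since `peel_once` modifies `state['grid']` in place. When a binary mask is passed through `plot_pred_true`, note that `downsample_grid` averages 2×2×2 blocks at the default sizes. The plotted values are then multiples of 1/8, so `min_val` would need to be far below `PLOT_MIN_VAL`.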


In [None]:

# --- Run for the first file (demo both strategies) ----------------------------
if len(files) == 0:
    LOGGER.warning('No files to process. Please set INPUT_DIR correctly.')
else:
    test_file = files[0]
    LOGGER.info(f'Demo on: {os.path.basename(test_file)}')

    # Variant A: seed=max, grow=connected
    SEED_MODE_A = 'max'
    GROW_MODE_A = 'connected'
    resA = process_file(test_file, GRID_SIZE, seed_mode=SEED_MODE_A, grow_mode=GROW_MODE_A)

    # Variant B: seed=avg, grow=global
    SEED_MODE_B = 'avg'
    GROW_MODE_B = 'global'
    resB = process_file(test_file, GRID_SIZE, seed_mode=SEED_MODE_B, grow_mode=GROW_MODE_B)

    # Visualization: residual grid after peeling ('grid_final') vs. optional truth
    if resA is not None:
        true_gridA = primary_points_to_grid(resA['primary_pos'], resA['grid_shape'], resA['det_mm'])
        plot_pred_true(resA['grid_final'], true_gridA, resA['det_mm'], reshape_size=RESHAPE_SIZE, min_val=PLOT_MIN_VAL)
    if resB is not None:
        true_gridB = primary_points_to_grid(resB['primary_pos'], resB['grid_shape'], resB['det_mm'])
        plot_pred_true(resB['grid_final'], true_gridB, resB['det_mm'], reshape_size=RESHAPE_SIZE, min_val=PLOT_MIN_VAL)
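
The two variants are only compared visually above. One possible numeric summary, which is not part of the original pipeline, is the mean distance from each selected voxel center to the nearest primary point. A brute-force sketch (`mean_distance_to_truth` is a hypothetical name):

```python
import numpy as np

def mean_distance_to_truth(track_centers_mm, truth_mm) -> float:
    """Mean distance (mm) from each track voxel center to its nearest truth point.

    Brute force: builds an (n_track, n_truth) distance matrix, fine for a few thousand points.
    Returns NaN when either set is empty.
    """
    track = np.asarray(track_centers_mm, dtype=float).reshape(-1, 3)
    truth = np.asarray(truth_mm, dtype=float).reshape(-1, 3)
    if len(track) == 0 or len(truth) == 0:
        return float('nan')
    dist = np.linalg.norm(track[:, None, :] - truth[None, :, :], axis=2)
    return float(dist.min(axis=1).mean())

print(mean_distance_to_truth([[0, 0, 0], [3, 4, 0]], [[0, 0, 0]]))   # 2.5
```

It would be fed with `voxels_to_mm_centers(res['track_flat'], res['grid_shape'], res['det_mm'])` (defined in the next cell) and `res['primary_pos']`. It measures how close the track stays to the truth, not how much of the truth it covers; swapping the arguments gives the coverage direction. For large point sets, `scipy.spatial.cKDTree(truth).query(track)` avoids the dense matrix.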


In [None]:

# --- Export helpers: track voxels & centers -----------------------------------
def flat_to_ijk(flat_idx: int, grid_shape: Tuple[int,int,int]):
    nx, ny, nz = grid_shape
    i = flat_idx // (ny * nz)
    j = (flat_idx % (ny * nz)) // nz
    k = flat_idx % nz
    return i, j, k

def voxels_to_mm_centers(track_flat: List[int], grid_shape: Tuple[int,int,int], det_mm: Tuple[float,float,float]):
    centers = []
    for f in track_flat:
        ijk = flat_to_ijk(f, grid_shape)
        centers.append(voxel_center_ijk(ijk, grid_shape, det_mm))
    return np.vstack(centers) if centers else np.empty((0,3))

# Example: dump track centers for Variant A if available
if len(files) > 0 and 'resA' in globals() and resA is not None:
    centersA = voxels_to_mm_centers(resA['track_flat'], resA['grid_shape'], resA['det_mm'])
    np.save(os.path.join(OUTPUT_DIR, 'track_centers_variantA.npy'), centersA)
    LOGGER.info(f"Saved Variant A track centers: {os.path.join(OUTPUT_DIR, 'track_centers_variantA.npy')} ({centersA.shape})")
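
`flat_to_ijk` is C-order unraveling (last axis fastest), so `np.unravel_index` gives the same result and allows a loop-free version of `voxels_to_mm_centers`. A sketch (`voxels_to_mm_centers_vec` is an illustrative name) on a toy 100 mm cube with 10³ voxels:

```python
import numpy as np

def voxels_to_mm_centers_vec(track_flat, grid_shape, det_mm) -> np.ndarray:
    """Vectorized voxel-center lookup; returns an (N, 3) array in mm, (0, 3) if empty."""
    flat = np.asarray(track_flat, dtype=np.int64)
    ijk = np.stack(np.unravel_index(flat, grid_shape), axis=1)   # C order: last axis fastest
    det = np.asarray(det_mm, dtype=float)
    vsize = det / np.asarray(grid_shape, dtype=float)
    return -det / 2.0 + (ijk + 0.5) * vsize

centers = voxels_to_mm_centers_vec([0, 123, 999], (10, 10, 10), (100.0, 100.0, 100.0))
# rows: (-45, -45, -45), (-35, -25, -15), (45, 45, 45)
```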



## Notes & Next Steps

- **Fraction removal (`ALPHA_MODE`)**  
  - `uniform`: each voxel along a photon path gets equal share (1/N).  
  - `length`: each voxel's share is proportional to the path length inside that voxel (recommended).  
  When a voxel is selected, every photon traversing it loses `ALPHA_SCALE * share * w` of its current residual weight `w` (capped at `w`).

- **Multiple tracks**  
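  To extract multiple tracks, call `run_peeling(...)` again on the **residual** state left by the first
  run, record the second track, and repeat until a stopping criterion fires. This notebook extracts a
  single track; note that each call starts a fresh `visited` mask, so later runs can reselect voxels.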

- **Performance**  
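  The grid is **not** rebuilt each iteration; it is updated incrementally from the precomputed
  incidence, so the grid update in a peeling step only touches voxels crossed by photons through the
  selected voxel. The remaining hot spots are the per-photon Python loops in `build_incidence_from_hits`
  and `peel_once`. For very large inputs (>10k photons per file), consider running as a standalone
  Python script rather than in a notebook cell to avoid UI overhead, and batch files.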

- **Debug maps**  
  To inspect the photon↔voxel incidence, exporters for the CSR arrays (`photon_ptr`, `photon_vox`,
  `photon_frac`, `voxel_ptr`, `voxel_photon`, `voxel_frac`) are straightforward to add.

- **Safeguards**  
  We clamp the grid to ≥ 0 after updates to avoid negative drift. Min-support and remaining-weight
  stops prevent runaway loops.
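
A hedged sketch of such an exporter, assuming a single compressed `.npz` per file is acceptable. `export_incidence` and `load_incidence` are hypothetical names; the keys match the dict returned by `build_incidence_from_hits`. The round trip below uses a toy one-photon state.

```python
import os
import tempfile
import numpy as np

CSR_KEYS = ('photon_ptr', 'photon_vox', 'photon_frac',
            'voxel_ptr', 'voxel_photon', 'voxel_frac', 'photon_w')

def export_incidence(state: dict, path: str) -> None:
    """Save the incidence arrays plus grid metadata to one .npz file (hypothetical helper)."""
    np.savez_compressed(path,
                        grid_shape=np.asarray(state['grid_shape'], dtype=np.int64),
                        det_mm=np.asarray(state['det_mm'], dtype=float),
                        **{k: np.asarray(state[k]) for k in CSR_KEYS})

def load_incidence(path: str) -> dict:
    with np.load(path) as data:
        out = {k: data[k] for k in CSR_KEYS}
        out['grid_shape'] = tuple(int(n) for n in data['grid_shape'])
        out['det_mm'] = tuple(float(x) for x in data['det_mm'])
    return out

# Round trip on a toy state: one photon crossing voxels 3 and 4 of a (1, 1, 5) grid
toy = {
    'photon_ptr': np.array([0, 2]), 'photon_vox': np.array([3, 4]),
    'photon_frac': np.array([0.5, 0.5], dtype=np.float32),
    'voxel_ptr': np.array([0, 0, 0, 0, 1, 2]), 'voxel_photon': np.array([0, 0]),
    'voxel_frac': np.array([0.5, 0.5], dtype=np.float32),
    'photon_w': np.ones(1, dtype=np.float32),
    'grid_shape': (1, 1, 5), 'det_mm': (10.0, 10.0, 50.0),
}
with tempfile.TemporaryDirectory() as tmp:
    npz_path = os.path.join(tmp, 'incidence.npz')
    export_incidence(toy, npz_path)
    back = load_incidence(npz_path)
```

In the notebook, a natural place to call it is inside `process_file`, right after `build_incidence_from_hits` and before peeling modifies `photon_w` and `grid` in place.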
