In [51]:
# IMPORTS
from __future__ import annotations
import torch    
import pytomography
from pytomography.metadata import ObjectMeta
from pytomography.metadata.PET import PETLMProjMeta
from pytomography.projectors.PET import PETLMSystemMatrix
from pytomography.algorithms import OSEM, MLEM
from pytomography.io.PET import gate, shared
from pytomography.likelihoods import PoissonLogLikelihood
import os
from pytomography.transforms.shared import GaussianFilter
import matplotlib.pyplot as plt
from pytomography.utils import sss

import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'

torch.cuda.is_available = lambda: False
print(f"CUDA available: {torch.cuda.is_available()}")
pytomography.device = 'cpu'
print(f"Current device: {pytomography.device}")
torch.cuda.empty_cache()

CUDA available: False
Current device: cpu


In [52]:
# # Visualisation Functions
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import interact, IntSlider, FloatSlider


def visualize_voxel_tensor_3d(voxel_tensor, initial_min_threshold=None, initial_max_threshold=None, 
                               voxel_size_mm=1.0, world_origin=None, min_threshold=None, max_threshold=None):
    """
    Interactive 3D visualization of voxel tensor with dual threshold sliders.

    Args:
        voxel_tensor: (nx, ny, nz) numpy array with voxel counts
        initial_min_threshold: Initial minimum threshold value for the slider (default: min_val)
        initial_max_threshold: Initial maximum threshold value for the slider (default: max_val)
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0) (optional)
        min_threshold: Minimum threshold value for slider range (optional)
        max_threshold: Maximum threshold value for slider range (optional)
    """
    # Extract non-zero voxel coordinates and values
    coords = np.where(voxel_tensor > 0)
    x_coords, y_coords, z_coords = coords
    values = voxel_tensor[coords]

    # Convert voxel indices to world coordinates if world_origin provided
    if world_origin is not None:
        x_min, y_min, z_min = world_origin
        x_coords_world = x_coords * voxel_size_mm + x_min
        y_coords_world = y_coords * voxel_size_mm + y_min
        z_coords_world = z_coords * voxel_size_mm + z_min
        coord_suffix = " (mm)"
    else:
        x_coords_world = x_coords * voxel_size_mm
        y_coords_world = y_coords * voxel_size_mm
        z_coords_world = z_coords * voxel_size_mm
        coord_suffix = f" (×{voxel_size_mm}mm)"

    # Get value range for sliders
    min_val = float(np.min(values))
    max_val = float(np.max(values))

    # Use user-specified min/max threshold range if provided
    slider_min = min_threshold if min_threshold is not None else min_val
    slider_max = max_threshold if max_threshold is not None else max_val

    # Set initial thresholds with defaults
    if initial_min_threshold is None:
        initial_min_threshold = slider_min
    else:
        initial_min_threshold = max(slider_min, min(slider_max, float(initial_min_threshold)))
    
    if initial_max_threshold is None:
        initial_max_threshold = slider_max
    else:
        initial_max_threshold = max(slider_min, min(slider_max, float(initial_max_threshold)))

    # Ensure min <= max
    if initial_min_threshold > initial_max_threshold:
        initial_min_threshold, initial_max_threshold = initial_max_threshold, initial_min_threshold

    print(f"Voxel value range: {min_val} to {max_val}")
    print(f"Total non-zero voxels: {len(values)}")
    print(f"Initial thresholds: {initial_min_threshold} to {initial_max_threshold}")
    print(f"Slider range: {slider_min} to {slider_max}")
    print(f"Voxel resolution: {voxel_size_mm}mm")

    def update_plot(min_thresh, max_thresh):
        # Ensure min <= max
        if min_thresh > max_thresh:
            min_thresh, max_thresh = max_thresh, min_thresh

        # Filter voxels within threshold range
        mask = (values >= min_thresh) & (values <= max_thresh)
        if not np.any(mask):
            print(f"No voxels in threshold range [{min_thresh}, {max_thresh}]")
            return

        filtered_x = x_coords_world[mask]
        filtered_y = y_coords_world[mask]
        filtered_z = z_coords_world[mask]
        filtered_values = values[mask]

        # Create 3D scatter plot
        fig = go.Figure(data=go.Scatter3d(
            x=filtered_x,
            y=filtered_y,
            z=filtered_z,
            mode='markers',
            marker=dict(
                size=1,
                color=filtered_values,
                colorscale='Viridis',
                opacity=0.8,
                colorbar=dict(title="Voxel Count"),
                line=dict(width=0)
            ),
            text=[f'Count: {v}' for v in filtered_values],
            hovertemplate='<b>Voxel (%{x:.1f}, %{y:.1f}, %{z:.1f})</b><br>%{text}<extra></extra>'
        ))

        fig.update_layout(
            title=f'3D Voxel Visualization (Range: [{min_thresh:.6f}, {max_thresh:.6f}], Showing: {len(filtered_values)} voxels)',
            scene=dict(
                xaxis_title=f'X{coord_suffix}',
                yaxis_title=f'Y{coord_suffix}',
                zaxis_title=f'Z{coord_suffix}',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                aspectmode='cube'
            ),
            width=800,
            height=600
        )

        fig.show()

    # Create interactive sliders with linked constraints
    min_threshold_slider = FloatSlider(
        value=initial_min_threshold,
        min=slider_min,
        max=slider_max,
        step=0.01,
        description='Min Threshold:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    max_threshold_slider = FloatSlider(
        value=initial_max_threshold,
        min=slider_min,
        max=slider_max,
        step=0.01,
        description='Max Threshold:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    # Link sliders to maintain min <= max constraint
    def on_min_change(change):
        if change['new'] > max_threshold_slider.value:
            max_threshold_slider.value = change['new']

    def on_max_change(change):
        if change['new'] < min_threshold_slider.value:
            min_threshold_slider.value = change['new']

    min_threshold_slider.observe(on_min_change, names='value')
    max_threshold_slider.observe(on_max_change, names='value')

    interact(update_plot, 
             min_thresh=min_threshold_slider, 
             max_thresh=max_threshold_slider)


from plotly.subplots import make_subplots
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import interact, IntSlider, FloatSlider

def plot_cross_sections_interactive(numpy_array, vmax=None, title="Orthogonal Cross Sections"):
    """
    Create an interactive 3D cross-section viewer for a 3D numpy array.
    
    Parameters:
    -----------
    numpy_array : np.ndarray
        3D numpy array to visualize (shape: nx, ny, nz)
    vmax : float, optional
        Maximum value for color scale. If None, uses array maximum
    title : str, optional
        Title for the plot
    """
    
    # Ensure input is numpy array
    if not isinstance(numpy_array, np.ndarray):
        raise ValueError("Input must be a numpy array")
    
    if numpy_array.ndim != 3:
        raise ValueError("Input must be a 3D array")
    
    nx, ny, nz = numpy_array.shape
    
    # Set default vmax if not provided
    if vmax is None:
        vmax = numpy_array.max()
    
    def plot_cross_sections_horizontal(x_idx=nx//2, y_idx=ny//2, z_idx=nz//2, 
                                     vmax_slider=vmax, cmap='Magma'):
        fig = make_subplots(rows=1, cols=3, subplot_titles=[
            f'XY plane @ z={z_idx}',
            f'XZ plane @ y={y_idx}',
            f'YZ plane @ x={x_idx}'
        ])

        # XY plane at z=z_idx
        fig.add_trace(go.Heatmap(
            z=numpy_array[:, :, z_idx].T,
            colorscale=cmap,
            zmax=vmax_slider,
            zmin=0,
            showscale=True,
            name=f'XY @ z={z_idx}'
        ), row=1, col=1)

        # XZ plane at y=y_idx
        fig.add_trace(go.Heatmap(
            z=numpy_array[:, y_idx, :].T,
            colorscale=cmap,
            zmax=vmax_slider,
            zmin=0,
            showscale=True,
            name=f'XZ @ y={y_idx}'
        ), row=1, col=2)

        # YZ plane at x=x_idx
        fig.add_trace(go.Heatmap(
            z=numpy_array[x_idx, :, :].T,
            colorscale=cmap,
            zmax=vmax_slider,
            zmin=0,
            showscale=True,
            name=f'YZ @ x={x_idx}'
        ), row=1, col=3)

        fig.update_layout(
            width=1200,
            height=400,
            title_text=title
        )
        fig.show()

    # Create interactive widget
    interact(
        plot_cross_sections_horizontal,
        x_idx=IntSlider(min=0, max=nx-1, step=1, value=nx//2, description='X index'),
        y_idx=IntSlider(min=0, max=ny-1, step=1, value=ny//2, description='Y index'),
        z_idx=IntSlider(min=0, max=nz-1, step=1, value=nz//2, description='Z index'),
        vmax_slider=FloatSlider(min=0, max=numpy_array.max(), step=0.01, value=vmax, description='vmax'),
        cmap=['Magma','Greys', 'Viridis', 'Cividis', 'Plasma']
    )


In [53]:
# Initial Loading, Filtering, and Coordinate Range Calculation
import numpy as np

# Load coordinates
coordinates = np.load(fr"C:\Users\h\Desktop\PET_Recons\Reconstructions\ground_truth.npy")

# Confirm shape should be (pairs, coords=6), coords are (x1, y1, z1, x2, y2, z2)
print(f"\nData Shape (pairs, coords) : {coordinates.shape}\n")  

# Remove pairs where any coordinate value is exactly 0
filtered_coordinates = coordinates[~np.any(coordinates == 0, axis=1)]
filtered_coordinates = torch.from_numpy(filtered_coordinates).float()
# filtered_coordinates = filtered_coordinates[:100,:]
print(f"Filtered shape: {filtered_coordinates.shape}\n")

# Extract all x, y, z pairs - Coordinates are in the order (x1, y1, z1, x2, y2, z2)
all_xyz = filtered_coordinates.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
print(f"x range: min={x_vals.min()}, max={x_vals.max()}")
print(f"y range: min={y_vals.min()}, max={y_vals.max()}")
print(f"z range: min={z_vals.min()}, max={z_vals.max()}")


Data Shape (pairs, coords) : (62660, 6)

Filtered shape: torch.Size([6591, 6])

x range: min=-278.1294250488281, max=278.1666564941406
y range: min=-278.4194641113281, max=277.8843688964844
z range: min=-147.99453735351562, max=147.9492950439453


In [54]:
# Artificial intersecting lines

import torch
import math

def create_intersecting_lines(counts=1000, intersection_point=(0, 0, 0), coord_range=(100, 100, 100)):
    """
    Create a torch tensor of shape (counts, 6) containing lines that all intersect at a specified point.
    
    Args:
        counts: Number of lines to generate
        intersection_point: (x, y, z) coordinates where all lines intersect
        coord_range: (x_range, y_range, z_range) maximum absolute values for coordinates
    
    Returns:
        torch.Tensor: Shape (counts, 6) with format [x1, y1, z1, x2, y2, z2] per row
    """
    
    # Convert to tensors
    intersection = torch.tensor(intersection_point, dtype=torch.float32)
    ranges = torch.tensor(coord_range, dtype=torch.float32)
    
    # Generate random unit directions for each line
    # Using normal distribution then normalizing gives uniform distribution on sphere
    directions = torch.randn(counts, 3)
    directions = directions / torch.norm(directions, dim=1, keepdim=True)
    
    # Generate random distances for each endpoint from intersection point
    # Use different distances for each end of the line
    distances1 = torch.rand(counts, 1) * ranges.max()  # Distance for first endpoint
    distances2 = torch.rand(counts, 1) * ranges.max()  # Distance for second endpoint
    
    # Calculate endpoints
    endpoint1 = intersection + distances1 * directions
    endpoint2 = intersection - distances2 * directions
    
    # Clamp coordinates to stay within specified ranges
    endpoint1 = torch.clamp(endpoint1, -ranges, ranges)
    endpoint2 = torch.clamp(endpoint2, -ranges, ranges)
    
    # Combine into final tensor shape (counts, 6)
    lines = torch.cat([endpoint1, endpoint2], dim=1)
    
    return lines.float()

# Generate default tensor
# lines_tensor = create_intersecting_lines()
# print(f"Generated tensor shape: {lines_tensor.shape}")
# print(f"First 5 lines:")
# print(lines_tensor[:5])

# Verify lines intersect at origin (check a few lines)
def verify_intersection(lines, intersection_point=(0, 0, 0)):
    """Verify that lines pass through the intersection point"""
    intersection = torch.tensor(intersection_point, dtype=torch.float32)
    
    for i in range(min(5, len(lines))):
        p1 = lines[i, :3]
        p2 = lines[i, 3:]
        
        # Parametric line equation: point = p1 + t*(p2-p1)
        # Find t where line passes closest to intersection
        direction = p2 - p1
        to_intersection = intersection - p1
        
        # t = dot(to_intersection, direction) / dot(direction, direction)
        t = torch.dot(to_intersection, direction) / torch.dot(direction, direction)
        closest_point = p1 + t * direction
        
        distance = torch.norm(closest_point - intersection)
        print(f"Line {i}: closest distance to intersection = {distance:.6f}")

# print("\nVerification (should be very close to 0):")
# verify_intersection(lines_tensor)

# custom_lines = create_intersecting_lines(
#     counts=1000, 
#     intersection_point=(10, -5, 20), 
#     coord_range=(50, 75, 30)
# )

# custom_lines.shape
# filtered_coordinates = custom_lines

In [55]:
# # # FAKE DATA GENERATION FOR TESTING
# # Create a 1000x2 tensor with random integers, ensuring pairs are never equal
# detector_ids = torch.full((1000, 2), 0, dtype=torch.long)
# detector_ids[:, 0] = 3
# detector_ids[:, 1] = 6
# detector_ids = torch.randint(low=0, high=64*8, size=(1000, 2))
# print(detector_ids.shape)

In [56]:
# Data shape extraction and basic 3d backprojection
voxel_size = 4.0

def binary_rasterize_lors_3d_dda(pairs_coords, voxel_size_mm=1.0):
    """
    Rasterize lines of response into 3D voxel space using exact 3D DDA traversal.
    
    Args:
        pairs_coords: (N, 6) array where each row is [x1,y1,z1,x2,y2,z2]
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
    
    Returns:
        voxel_tensor: (nx, ny, nz) array with line traversal counts
    """
    # World coordinate ranges (mm)
    all_xyz = pairs_coords.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]

    x_min = x_vals.min()
    x_max = x_vals.max()
    y_min = y_vals.min()
    y_max = y_vals.max()
    z_min = z_vals.min()
    z_max = z_vals.max()

    # Round the min and max values to voxel boundaries
    x_min = np.floor(x_min / voxel_size_mm) * voxel_size_mm
    x_max = np.ceil(x_max / voxel_size_mm) * voxel_size_mm
    y_min = np.floor(y_min / voxel_size_mm) * voxel_size_mm
    y_max = np.ceil(y_max / voxel_size_mm) * voxel_size_mm
    z_min = np.floor(z_min / voxel_size_mm) * voxel_size_mm
    z_max = np.ceil(z_max / voxel_size_mm) * voxel_size_mm

    # Calculate voxel tensor shape based on resolution
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    voxel_shape = (nx, ny, nz)
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: {voxel_shape} (nx, ny, nz)")
    
    voxel_tensor = np.zeros(voxel_shape, dtype=np.int32)
    
    # Coordinate transformation: world -> voxel indices
    def world_to_voxel(coords):
        x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
        vx = (x - x_min) / voxel_size_mm
        vy = (y - y_min) / voxel_size_mm
        vz = (z - z_min) / voxel_size_mm
        return np.column_stack([vx, vy, vz])
    
    def dda_3d(p1, p2):
        """3D DDA algorithm - returns list of (x,y,z) voxel indices along line"""
        x1, y1, z1 = p1
        x2, y2, z2 = p2
        
        # Direction and step sizes
        dx = x2 - x1
        dy = y2 - y1
        dz = z2 - z1
        
        # Number of steps is maximum of absolute differences
        steps = int(max(abs(dx), abs(dy), abs(dz)))
        if steps == 0:
            return [(int(round(x1)), int(round(y1)), int(round(z1)))]
        
        # Step increments
        x_inc = dx / steps
        y_inc = dy / steps
        z_inc = dz / steps
        
        voxels = []
        x, y, z = x1, y1, z1
        
        for _ in range(steps + 1):
            vx, vy, vz = int(round(x)), int(round(y)), int(round(z))
            
            # Check bounds
            if 0 <= vx < nx and 0 <= vy < ny and 0 <= vz < nz:
                voxels.append((vx, vy, vz))
            
            x += x_inc
            y += y_inc
            z += z_inc
        
        return voxels
    
    # Transform all coordinates to voxel space
    points1 = world_to_voxel(pairs_coords[:, :3])  # First points
    points2 = world_to_voxel(pairs_coords[:, 3:])  # Second points
    
    # Process each LOR
    for i in range(len(pairs_coords)):
        p1 = points1[i]
        p2 = points2[i]
        
        # Get voxels along this line
        voxels = dda_3d(p1, p2)
        
        # Increment voxel counts
        for vx, vy, vz in voxels:
            voxel_tensor[vx, vy, vz] += 1
    
    return voxel_tensor

voxel_tensor = binary_rasterize_lors_3d_dda(filtered_coordinates, voxel_size_mm=voxel_size)
print(f"Voxel tensor shape: {voxel_tensor.shape}")

# Specify object space for reconstruction
# voxel_space = voxel_tensor.shape # voxels
voxel_space = voxel_tensor.shape # voxels

print(f"x range: min={x_vals.min()}, max={x_vals.max()}")
print(f"y range: min={y_vals.min()}, max={y_vals.max()}")
print(f"z range: min={z_vals.min()}, max={z_vals.max()}")

Voxel size: 4.0mm
Voxel shape: (140, 140, 74) (nx, ny, nz)
Voxel tensor shape: (140, 140, 74)
x range: min=-278.1294250488281, max=278.1666564941406
y range: min=-278.4194641113281, max=277.8843688964844
z range: min=-147.99453735351562, max=147.9492950439453


In [57]:
# Configure and validate virtual PET scanner geometry

info = {'min_rsector_difference': 0,
 'crystal_length': 20.0,
 'radius': 337.0,
 'crystalTransNr': 8,
 'crystalTransSpacing': 4.0,
 'crystalAxialNr': 8,
 'crystalAxialSpacing': 4.0,
 'submoduleAxialNr': 1,
 'submoduleAxialSpacing': 0,
 'submoduleTransNr': 1,
 'submoduleTransSpacing': 0,
 'moduleTransNr': 1,
 'moduleTransSpacing': 0.0,
 'moduleAxialNr': 8,
 'moduleAxialSpacing': 32.25,
 'rsectorTransNr': 56,
 'rsectorAxialNr': 1,
 'NrCrystalsPerRing': 448,
 'NrRings': 64,
 'firstCrystalAxis': 1}

import plotly.graph_objects as go
from plotly.graph_objs import FigureWidget
import numpy as np
from scipy.spatial import cKDTree
from ipywidgets import interact, IntSlider, VBox
from IPython.display import display

def visualize_lors_and_detectors_3d(coordinates, scanner_info, sample_lors=None, 
                                   highlight_detector_ids=None, highlight_color='green',
                                   validate_lor_index=None, verbose=True, 
                                   distance_threshold=5.0):
    """
    3D visualization of LOR endpoints and PET detector positions with interactive slider for LOR validation.
    Camera orientation is preserved during slider interactions.
    
    Parameters:
    -----------
    coordinates : numpy.ndarray
        Shape (n_lors, 6) containing [x1, y1, z1, x2, y2, z2] for each LOR
    scanner_info : dict
        Dictionary containing scanner geometry parameters
    sample_lors : int, optional
        Number of LORs to sample for visualization (default: all)
    highlight_detector_ids : list or set, optional
        Detector IDs to highlight with different color
    highlight_color : str, optional
        Color for highlighted detectors (default: 'green')
    validate_lor_index : int, optional
        Initial LOR index to display in the slider (default: 0)
    verbose : bool, optional
        If False, suppresses all print outputs (default: True)
    distance_threshold : float, optional
        Maximum allowed distance (in mm) between extended endpoints and assigned detectors.
        LOR pairs where either endpoint exceeds this distance are excluded (default: 5.0)
        
    Returns:
    --------
    tuple
        (full_detector_id_pairs, valid_detector_id_pairs, valid_lor_indices)
    """
    
    def extend_endpoints_to_radius(endpoints1, endpoints2, target_radius):
        """
        Extend LOR endpoints along the LOR direction until they reach the target radius.
        """
        # Calculate LOR direction vectors (from endpoint1 to endpoint2)
        lor_directions = endpoints2 - endpoints1
        lor_lengths = np.linalg.norm(lor_directions, axis=1, keepdims=True)
        
        # Avoid division by zero
        lor_lengths = np.where(lor_lengths == 0, 1e-10, lor_lengths)
        lor_unit_vectors = lor_directions / lor_lengths
        
        def find_radius_intersection(start_point, direction, target_radius):
            """Find where a ray intersects the cylinder at target_radius."""
            x0, y0, z0 = start_point.T
            dx, dy, dz = direction.T
            
            # Quadratic equation coefficients: at² + bt + c = 0
            a = dx**2 + dy**2
            b = 2 * (x0 * dx + y0 * dy)
            c = x0**2 + y0**2 - target_radius**2
            
            # Solve quadratic equation
            valid_mask = np.abs(a) > 1e-10  # Direction has xy component
            result = start_point.copy()
            
            if np.any(valid_mask):
                a_valid = a[valid_mask]
                b_valid = b[valid_mask]
                c_valid = c[valid_mask]
                
                discriminant = b_valid**2 - 4 * a_valid * c_valid
                solvable_mask = discriminant >= 0
                
                if np.any(solvable_mask):
                    sqrt_disc = np.sqrt(discriminant[solvable_mask])
                    t1 = (-b_valid[solvable_mask] + sqrt_disc) / (2 * a_valid[solvable_mask])
                    t2 = (-b_valid[solvable_mask] - sqrt_disc) / (2 * a_valid[solvable_mask])
                    
                    # Choose the positive t value
                    t_chosen = np.where(
                        (t1 > 0) & (t2 > 0), np.minimum(t1, t2),
                        np.where(t1 > 0, t1, 
                                np.where(t2 > 0, t2, np.maximum(t1, t2)))
                    )
                    
                    # Calculate intersection points
                    valid_indices = np.where(valid_mask)[0]
                    solvable_indices = valid_indices[solvable_mask]
                    
                    result[solvable_indices] = (start_point[solvable_indices] + 
                                            t_chosen.reshape(-1, 1) * direction[solvable_indices])
            
            return result
        
        # Extend endpoints in OPPOSITE directions along the LOR line
        extended_endpoints1 = find_radius_intersection(endpoints1, -lor_unit_vectors, target_radius)
        extended_endpoints2 = find_radius_intersection(endpoints2, lor_unit_vectors, target_radius)
        
        return extended_endpoints1, extended_endpoints2
    
    # Sample LORs if specified
    if sample_lors is not None and sample_lors < coordinates.shape[0]:
        indices = np.random.choice(coordinates.shape[0], sample_lors, replace=False)
        coordinates_sample = coordinates[indices]
        if verbose:
            print(f"Sampling {sample_lors} LORs out of {coordinates.shape[0]} total")
    else:
        coordinates_sample = coordinates
        if verbose:
            print(f"Visualizing all {coordinates.shape[0]} LORs")
    
    # Extract LOR endpoints
    n_lors = coordinates_sample.shape[0]
    endpoint1 = coordinates_sample[:, :3]  # [x1, y1, z1]
    endpoint2 = coordinates_sample[:, 3:]  # [x2, y2, z2]
    
    # Generate detector positions and IDs
    radius = scanner_info['radius']
    crystals_per_ring = scanner_info['NrCrystalsPerRing']
    n_rings = scanner_info['NrRings']
    crystal_axial_spacing = scanner_info['crystalAxialSpacing']
    module_axial_spacing = scanner_info['moduleAxialSpacing']
    module_axial_nr = scanner_info['moduleAxialNr']
    crystals_axial_per_module = scanner_info['crystalAxialNr']
    
    # Calculate total axial extent
    crystals_per_module_axial = crystals_axial_per_module
    modules_span = (module_axial_nr - 1) * module_axial_spacing
    crystals_within_modules_span = (crystals_per_module_axial - 1) * crystal_axial_spacing * module_axial_nr
    total_axial_extent = modules_span + crystals_within_modules_span
    
    # Generate detector positions and IDs
    detector_positions = []
    detector_ids = []
    
    for ring in range(n_rings):
        # Calculate z position for this ring
        if n_rings == 1:
            z_pos = 0.0
        else:
            z_pos = -total_axial_extent/2 + ring * (total_axial_extent / (n_rings - 1))
        
        for crystal in range(crystals_per_ring):
            # Calculate angular position
            angle = 2 * np.pi * crystal / crystals_per_ring
            
            # Convert to Cartesian coordinates
            x_pos = radius * np.cos(angle)
            y_pos = radius * np.sin(angle)
            
            detector_positions.append([x_pos, y_pos, z_pos])
            detector_id = ring * crystals_per_ring + crystal
            detector_ids.append(detector_id)
    
    detector_positions = np.array(detector_positions)
    detector_ids = np.array(detector_ids)
    
    # Prepare detector colors
    if highlight_detector_ids is not None:
        highlight_set = set(highlight_detector_ids)
        highlight_mask = np.isin(detector_ids, list(highlight_set))
        regular_mask = ~highlight_mask
        regular_positions = detector_positions[regular_mask]
        highlight_positions = detector_positions[highlight_mask]
        regular_ids = detector_ids[regular_mask]
        highlight_ids = detector_ids[highlight_mask]
        
        if verbose:
            print(f"Highlighting {len(highlight_positions)} detectors: {sorted(highlight_set)}")
    else:
        regular_positions = detector_positions
        highlight_positions = np.array([]).reshape(0, 3)
        regular_ids = detector_ids
        highlight_ids = np.array([])
    
    # Calculate detector assignments for all LORs
    extended_endpoint1, extended_endpoint2 = extend_endpoints_to_radius(endpoint1, endpoint2, radius)
    
    # Build KDTree for efficient nearest neighbor search
    kdtree = cKDTree(detector_positions)
    
    # Find nearest detector for each extended endpoint
    _, nearest_indices1 = kdtree.query(extended_endpoint1)
    _, nearest_indices2 = kdtree.query(extended_endpoint2)
    
    # Convert to detector IDs
    full_detector_id_pairs = np.column_stack([
        detector_ids[nearest_indices1],
        detector_ids[nearest_indices2]
    ])
    
    # Calculate distances and apply filtering
    distances1 = np.linalg.norm(extended_endpoint1 - detector_positions[nearest_indices1], axis=1)
    distances2 = np.linalg.norm(extended_endpoint2 - detector_positions[nearest_indices2], axis=1)
    valid_mask = (distances1 <= distance_threshold) & (distances2 <= distance_threshold)
    
    valid_detector_id_pairs = full_detector_id_pairs[valid_mask]
    valid_lor_indices = np.where(valid_mask)[0]
    
    # Print initial statistics
    if verbose:
        print(f"\nDistance Filtering (threshold: {distance_threshold} mm):")
        print(f"Original LORs: {len(full_detector_id_pairs)}")
        print(f"Valid LORs: {len(valid_detector_id_pairs)}")
        print(f"Filtered out: {len(full_detector_id_pairs) - len(valid_detector_id_pairs)}")
        print(f"Filtering efficiency: {len(valid_detector_id_pairs)/len(full_detector_id_pairs)*100:.1f}%")
    
    # Create the persistent FigureWidget
    fig = FigureWidget()
    
    # Add static background elements that don't change with slider
    all_lor_points = np.vstack([endpoint1, endpoint2])
    bg_opacity = 0.3
    detector_opacity = 0.4
    
    # Add ALL LOR endpoints (static background)
    fig.add_trace(go.Scatter3d(
        x=all_lor_points[:, 0],
        y=all_lor_points[:, 1],
        z=all_lor_points[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color='red',
            opacity=bg_opacity,
            line=dict(width=0)
        ),
        name=f'All LOR Endpoints',
        hovertemplate='<b>LOR Endpoint</b><br>' +
                      'X: %{x:.1f} mm<br>' +
                      'Y: %{y:.1f} mm<br>' +
                      'Z: %{z:.1f} mm<extra></extra>'
    ))
    
    # Add regular detectors (static)
    if len(regular_positions) > 0:
        fig.add_trace(go.Scatter3d(
            x=regular_positions[:, 0],
            y=regular_positions[:, 1],
            z=regular_positions[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color='blue',
                opacity=detector_opacity,
                symbol='diamond',
                line=dict(width=0)
            ),
            name='Regular Detectors',
            customdata=regular_ids,
            hovertemplate='<b>Detector ID: %{customdata}</b><br>' +
                          'X: %{x:.1f} mm<br>' +
                          'Y: %{y:.1f} mm<br>' +
                          'Z: %{z:.1f} mm<extra></extra>'
        ))
    
    # Add highlighted detectors (static)
    if len(highlight_positions) > 0:
        fig.add_trace(go.Scatter3d(
            x=highlight_positions[:, 0],
            y=highlight_positions[:, 1],
            z=highlight_positions[:, 2],
            mode='markers',
            marker=dict(
                size=5,
                color=highlight_color,
                opacity=1.0,
                symbol='diamond',
                line=dict(width=1, color='black')
            ),
            name=f'Highlighted Detectors ({highlight_color})',
            customdata=highlight_ids,
            hovertemplate='<b>Highlighted Detector ID: %{customdata}</b><br>' +
                          'X: %{x:.1f} mm<br>' +
                          'Y: %{y:.1f} mm<br>' +
                          'Z: %{z:.1f} mm<extra></extra>'
        ))
    
    # Add placeholder traces for dynamic LOR validation elements
    # These will be updated by the slider function
    trace_names = [
        'LOR - Original Endpoints',
        'LOR - Extended Endpoints', 
        'LOR - Assigned Detectors',
        'LOR - Original Line',
        'Extension Path 1',
        'Extension Path 2', 
        'Detector Assignment 1',
        'Detector Assignment 2'
    ]
    
    # Add empty traces that will be populated by slider updates
    for i, name in enumerate(trace_names):
        fig.add_trace(go.Scatter3d(
            x=[], y=[], z=[],
            mode='markers' if 'Endpoints' in name or 'Detectors' in name else 'lines',
            name=name,
            showlegend=True if 'Endpoints' in name or 'Detectors' in name else False
        ))
    
    # Set initial layout
    fig.update_layout(
        title=f'3D LOR Validation with Persistent Camera View',
        scene=dict(
            xaxis_title='X (mm)',
            yaxis_title='Y (mm)',
            zaxis_title='Z (mm)',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='cube'
        ),
        width=900,
        height=700,
        legend=dict(
            x=0.02,
            y=0.98,
            bgcolor='rgba(255,255,255,0.8)'
        )
    )
    
    def update_lor_visualization(lor_index):
        """Update only the LOR-specific traces without recreating the entire plot"""
        
        # Get validation data for the selected LOR
        validation_data = None
        if lor_index < coordinates_sample.shape[0]:
            # Check if this LOR survived filtering
            original_lor_survived = lor_index in valid_lor_indices
            
            # Get the specific LOR's data (regardless of whether it passed filtering)
            val_endpoint1 = coordinates_sample[lor_index, :3]
            val_endpoint2 = coordinates_sample[lor_index, 3:]
            val_extended1 = extended_endpoint1[lor_index]
            val_extended2 = extended_endpoint2[lor_index]
            
            # Get detector assignments
            val_detector1_id = full_detector_id_pairs[lor_index, 0]
            val_detector2_id = full_detector_id_pairs[lor_index, 1]
            
            # Find detector positions
            det_idx1 = np.where(detector_ids == val_detector1_id)[0][0]
            det_idx2 = np.where(detector_ids == val_detector2_id)[0][0]
            val_detector1_pos = detector_positions[det_idx1]
            val_detector2_pos = detector_positions[det_idx2]
            
            validation_data = {
                'original_endpoints': [val_endpoint1, val_endpoint2],
                'extended_endpoints': [val_extended1, val_extended2],
                'detector_ids': [val_detector1_id, val_detector2_id],
                'detector_positions': [val_detector1_pos, val_detector2_pos],
                'distances': [distances1[lor_index], distances2[lor_index]],
                'within_threshold': original_lor_survived
            }
        
        # Update traces with new data
        if validation_data:
            orig_ep1, orig_ep2 = validation_data['original_endpoints']
            ext_ep1, ext_ep2 = validation_data['extended_endpoints']
            det_pos1, det_pos2 = validation_data['detector_positions']
            det_id1, det_id2 = validation_data['detector_ids']
            is_valid = validation_data['within_threshold']
            dist1, dist2 = validation_data['distances']
            
            # Color coding based on validity
            endpoint_color = 'orange' if is_valid else 'red'
            extended_color = 'purple' if is_valid else 'magenta'
            detector_color = 'yellow' if is_valid else 'pink'
            line_color_1 = 'orange' if is_valid else 'red'
            line_color_2 = 'purple' if is_valid else 'magenta'
            
            status_text = f"VALID (≤{distance_threshold}mm)" if is_valid else f"FILTERED (>{distance_threshold}mm)"
            
            # Update dynamic traces (starting from index after static traces)
            static_traces = 3 if highlight_detector_ids is not None else 2
            
            # Original endpoints
            with fig.batch_update():
                fig.data[static_traces].update(
                    x=[orig_ep1[0], orig_ep2[0]],
                    y=[orig_ep1[1], orig_ep2[1]],
                    z=[orig_ep1[2], orig_ep2[2]],
                    marker=dict(
                        size=8,
                        color=endpoint_color,
                        opacity=1.0,
                        symbol='circle',
                        line=dict(width=2, color='black')
                    ),
                    name=f'LOR {lor_index} - Original Endpoints ({status_text})',
                    hovertemplate='<b>Original Endpoint</b><br>' +
                                  'X: %{x:.1f} mm<br>' +
                                  'Y: %{y:.1f} mm<br>' +
                                  'Z: %{z:.1f} mm<extra></extra>'
                )
                
                # Extended endpoints
                fig.data[static_traces + 1].update(
                    x=[ext_ep1[0], ext_ep2[0]],
                    y=[ext_ep1[1], ext_ep2[1]],
                    z=[ext_ep1[2], ext_ep2[2]],
                    marker=dict(
                        size=6,
                        color=extended_color,
                        opacity=1.0,
                        symbol='square',
                        line=dict(width=2, color='black')
                    ),
                    name=f'LOR {lor_index} - Extended Endpoints',
                    hovertemplate='<b>Extended Endpoint</b><br>' +
                                  'X: %{x:.1f} mm<br>' +
                                  'Y: %{y:.1f} mm<br>' +
                                  'Z: %{z:.1f} mm<extra></extra>'
                )
                
                # Assigned detectors
                fig.data[static_traces + 2].update(
                    x=[det_pos1[0], det_pos2[0]],
                    y=[det_pos1[1], det_pos2[1]],
                    z=[det_pos1[2], det_pos2[2]],
                    marker=dict(
                        size=8,
                        color=detector_color,
                        opacity=1.0,
                        symbol='diamond-open',
                        line=dict(width=2, color='black')
                    ),
                    name=f'LOR {lor_index} - Assigned Detectors (d1:{dist1:.1f}mm, d2:{dist2:.1f}mm)',
                    customdata=[det_id1, det_id2],
                    hovertemplate='<b>Assigned Detector ID: %{customdata}</b><br>' +
                                  'X: %{x:.1f} mm<br>' +
                                  'Y: %{y:.1f} mm<br>' +
                                  'Z: %{z:.1f} mm<extra></extra>'
                )
                
                # Original LOR line
                fig.data[static_traces + 3].update(
                    x=[orig_ep1[0], orig_ep2[0]],
                    y=[orig_ep1[1], orig_ep2[1]],
                    z=[orig_ep1[2], orig_ep2[2]],
                    line=dict(color='black', width=6),
                    name=f'LOR {lor_index} - Original LOR Line'
                )
                
                # Extension paths
                fig.data[static_traces + 4].update(
                    x=[orig_ep1[0], ext_ep1[0]],
                    y=[orig_ep1[1], ext_ep1[1]],
                    z=[orig_ep1[2], ext_ep1[2]],
                    line=dict(color=line_color_1, width=4, dash='dash')
                )
                
                fig.data[static_traces + 5].update(
                    x=[orig_ep2[0], ext_ep2[0]],
                    y=[orig_ep2[1], ext_ep2[1]],
                    z=[orig_ep2[2], ext_ep2[2]],
                    line=dict(color=line_color_1, width=4, dash='dash')
                )
                
                # Detector assignment paths
                fig.data[static_traces + 6].update(
                    x=[ext_ep1[0], det_pos1[0]],
                    y=[ext_ep1[1], det_pos1[1]],
                    z=[ext_ep1[2], det_pos1[2]],
                    line=dict(color=line_color_2, width=4, dash='dot')
                )
                
                fig.data[static_traces + 7].update(
                    x=[ext_ep2[0], det_pos2[0]],
                    y=[ext_ep2[1], det_pos2[1]],
                    z=[ext_ep2[2], det_pos2[2]],
                    line=dict(color=line_color_2, width=4, dash='dot')
                )
                
                # Update title
                fig.layout.title.text = f'3D LOR Validation: LOR {lor_index} ({status_text})'
        
        else:
            # Clear dynamic traces if no validation data
            static_traces = 3 if highlight_detector_ids is not None else 2
            with fig.batch_update():
                for i in range(len(trace_names)):
                    fig.data[static_traces + i].update(x=[], y=[], z=[])
                fig.layout.title.text = f'3D LOR Validation: LOR {lor_index} (INVALID INDEX)'
    
    # Set up the interactive slider
    max_lor_index = coordinates_sample.shape[0] - 1
    initial_lor_index = validate_lor_index if validate_lor_index is not None else 0
    
    slider = IntSlider(
        value=initial_lor_index,
        min=0,
        max=max_lor_index,
        step=1,
        description='LOR Index:',
        style={'description_width': 'initial'},
        layout={'width': '400px'}
    )
    
    if verbose:
        print(f"\nInteractive slider created for LOR indices 0-{max_lor_index}")
        print(f"Distance threshold: {distance_threshold} mm")
        print("Camera orientation will persist during slider interactions!")
    
    # Connect slider to update function
    def on_slider_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            update_lor_visualization(change['new'])
    
    slider.observe(on_slider_change)
    
    # Initialize with the starting LOR
    update_lor_visualization(initial_lor_index)
    
    # Display the widget and figure
    display(VBox([slider, fig]))
    
    return full_detector_id_pairs, valid_detector_id_pairs, valid_lor_indices



# Silent mode with 5mm distance filtering (default):
# fig, full_detector_ids, valid_detector_ids, valid_indices = visualize_lors_and_detectors_3d(
#     filtered_coordinates.numpy(), info, verbose=False)

# Custom distance threshold with verbose output:
# fig, full_detector_ids, valid_detector_ids, valid_indices = visualize_lors_and_detectors_3d(
#     filtered_coordinates.numpy(), info, distance_threshold=3.0, verbose=True)

# Interactive slider mode with distance filtering:
# fig, full_detector_ids, valid_detector_ids, valid_indices = visualize_lors_and_detectors_3d(
#     filtered_coordinates.numpy(), info, interactive_slider=True, 
#     distance_threshold=5.0, verbose=True)

# Validate specific LOR silently with custom threshold:
# fig, full_detector_ids, valid_detector_ids, valid_indices = visualize_lors_and_detectors_3d(
#     filtered_coordinates.numpy(), info, validate_lor_index=1, 
#     distance_threshold=2.0, verbose=False)

full_detector_ids, valid_detector_ids, valid_indices = visualize_lors_and_detectors_3d(filtered_coordinates.numpy(), info,
                                                    #  highlight_detector_ids=[0, 1, 5, 4370],
                                                     verbose=True,
                                                    #  verbose=False,
                                                     validate_lor_index=1,
                                                     distance_threshold=5.0,)

print(full_detector_ids[0])


Visualizing all 6591 LORs

Distance Filtering (threshold: 5.0 mm):
Original LORs: 6591
Valid LORs: 3987
Filtered out: 2604
Filtering efficiency: 60.5%

Interactive slider created for LOR indices 0-6590
Distance threshold: 5.0 mm
Camera orientation will persist during slider interactions!


VBox(children=(IntSlider(value=1, description='LOR Index:', layout=Layout(width='400px'), max=6590, style=Slid…

[  193 28543]


In [58]:
print(f"Detector IDs shape: {detector_ids.shape}")
print(f"Detector ID range: {detector_ids.min()} to {detector_ids.max()}")
print(f"Unique detector IDs: {len(torch.unique(detector_ids))}")
print(f"Expected total detectors: {info['NrCrystalsPerRing'] * info['NrRings']}")

Detector IDs shape: torch.Size([3987, 2])
Detector ID range: 0 to 28670
Unique detector IDs: 6945
Expected total detectors: 28672


In [59]:
# Optional detector assignment along the axial plane first, insead of the circumferencial plane first

import plotly.graph_objects as go
from plotly.graph_objs import FigureWidget
import numpy as np
from scipy.spatial import cKDTree
from ipywidgets import interact, IntSlider, VBox
from IPython.display import display

def visualize_lors_and_detectors_3d_axial_ids(coordinates, scanner_info, sample_lors=None, 
                                             highlight_detector_ids=None, highlight_color='green',
                                             validate_lor_index=None, verbose=True, 
                                             distance_threshold=5.0):
    """
    3D visualization of LOR endpoints and PET detector positions with interactive slider for LOR validation.
    Camera orientation is preserved during slider interactions.
    
    DETECTOR ID ORDERING: Axial-first (same angular position across rings, then next angular position)
    
    Parameters:
    -----------
    coordinates : numpy.ndarray
        Shape (n_lors, 6) containing [x1, y1, z1, x2, y2, z2] for each LOR
    scanner_info : dict
        Dictionary containing scanner geometry parameters
    sample_lors : int, optional
        Number of LORs to sample for visualization (default: all)
    highlight_detector_ids : list or set, optional
        Detector IDs to highlight with different color
    highlight_color : str, optional
        Color for highlighted detectors (default: 'green')
    validate_lor_index : int, optional
        Initial LOR index to display in the slider (default: 0)
    verbose : bool, optional
        If False, suppresses all print outputs (default: True)
    distance_threshold : float, optional
        Maximum allowed distance (in mm) between extended endpoints and assigned detectors.
        LOR pairs where either endpoint exceeds this distance are excluded (default: 5.0)
        
    Returns:
    --------
    tuple
        (full_detector_id_pairs, valid_detector_id_pairs, valid_lor_indices)
    """
    
    def extend_endpoints_to_radius(endpoints1, endpoints2, target_radius):
        """
        Extend LOR endpoints along the LOR direction until they reach the target radius.
        """
        # Calculate LOR direction vectors (from endpoint1 to endpoint2)
        lor_directions = endpoints2 - endpoints1
        lor_lengths = np.linalg.norm(lor_directions, axis=1, keepdims=True)
        
        # Avoid division by zero
        lor_lengths = np.where(lor_lengths == 0, 1e-10, lor_lengths)
        lor_unit_vectors = lor_directions / lor_lengths
        
        def find_radius_intersection(start_point, direction, target_radius):
            """Find where a ray intersects the cylinder at target_radius."""
            x0, y0, z0 = start_point.T
            dx, dy, dz = direction.T
            
            # Quadratic equation coefficients: at² + bt + c = 0
            a = dx**2 + dy**2
            b = 2 * (x0 * dx + y0 * dy)
            c = x0**2 + y0**2 - target_radius**2
            
            # Solve quadratic equation
            valid_mask = np.abs(a) > 1e-10  # Direction has xy component
            result = start_point.copy()
            
            if np.any(valid_mask):
                a_valid = a[valid_mask]
                b_valid = b[valid_mask]
                c_valid = c[valid_mask]
                
                discriminant = b_valid**2 - 4 * a_valid * c_valid
                solvable_mask = discriminant >= 0
                
                if np.any(solvable_mask):
                    sqrt_disc = np.sqrt(discriminant[solvable_mask])
                    t1 = (-b_valid[solvable_mask] + sqrt_disc) / (2 * a_valid[solvable_mask])
                    t2 = (-b_valid[solvable_mask] - sqrt_disc) / (2 * a_valid[solvable_mask])
                    
                    # Choose the positive t value
                    t_chosen = np.where(
                        (t1 > 0) & (t2 > 0), np.minimum(t1, t2),
                        np.where(t1 > 0, t1, 
                                np.where(t2 > 0, t2, np.maximum(t1, t2)))
                    )
                    
                    # Calculate intersection points
                    valid_indices = np.where(valid_mask)[0]
                    solvable_indices = valid_indices[solvable_mask]
                    
                    result[solvable_indices] = (start_point[solvable_indices] + 
                                            t_chosen.reshape(-1, 1) * direction[solvable_indices])
            
            return result
        
        # Extend endpoints in OPPOSITE directions along the LOR line
        extended_endpoints1 = find_radius_intersection(endpoints1, -lor_unit_vectors, target_radius)
        extended_endpoints2 = find_radius_intersection(endpoints2, lor_unit_vectors, target_radius)
        
        return extended_endpoints1, extended_endpoints2
    
    # Sample LORs if specified
    if sample_lors is not None and sample_lors < coordinates.shape[0]:
        indices = np.random.choice(coordinates.shape[0], sample_lors, replace=False)
        coordinates_sample = coordinates[indices]
        if verbose:
            print(f"Sampling {sample_lors} LORs out of {coordinates.shape[0]} total")
    else:
        coordinates_sample = coordinates
        if verbose:
            print(f"Visualizing all {coordinates.shape[0]} LORs")
    
    # Extract LOR endpoints
    n_lors = coordinates_sample.shape[0]
    endpoint1 = coordinates_sample[:, :3]  # [x1, y1, z1]
    endpoint2 = coordinates_sample[:, 3:]  # [x2, y2, z2]
    
    # Generate detector positions and IDs with AXIAL-FIRST ordering
    radius = scanner_info['radius']
    crystals_per_ring = scanner_info['NrCrystalsPerRing']
    n_rings = scanner_info['NrRings']
    crystal_axial_spacing = scanner_info['crystalAxialSpacing']
    module_axial_spacing = scanner_info['moduleAxialSpacing']
    module_axial_nr = scanner_info['moduleAxialNr']
    crystals_axial_per_module = scanner_info['crystalAxialNr']
    
    # Calculate total axial extent
    crystals_per_module_axial = crystals_axial_per_module
    modules_span = (module_axial_nr - 1) * module_axial_spacing
    crystals_within_modules_span = (crystals_per_module_axial - 1) * crystal_axial_spacing * module_axial_nr
    total_axial_extent = modules_span + crystals_within_modules_span
    
    # Generate detector positions and IDs with AXIAL-FIRST ordering
    detector_positions = []
    detector_ids = []
    
    # CHANGED: Loop over crystals (angular positions) first, then rings (axial positions)
    for crystal in range(crystals_per_ring):
        # Calculate angular position
        angle = 2 * np.pi * crystal / crystals_per_ring
        x_base = radius * np.cos(angle)
        y_base = radius * np.sin(angle)
        
        for ring in range(n_rings):
            # Calculate z position for this ring
            if n_rings == 1:
                z_pos = 0.0
            else:
                z_pos = -total_axial_extent/2 + ring * (total_axial_extent / (n_rings - 1))
            
            detector_positions.append([x_base, y_base, z_pos])
            # CHANGED: ID assignment is now crystal * n_rings + ring (axial-first)
            detector_id = crystal * n_rings + ring
            detector_ids.append(detector_id)
    
    detector_positions = np.array(detector_positions)
    detector_ids = np.array(detector_ids)
    
    if verbose:
        print(f"\nDetector ID Ordering: AXIAL-FIRST")
        print(f"Example: Angular position 0 has detector IDs: {list(range(0, n_rings))}")
        print(f"Example: Angular position 1 has detector IDs: {list(range(n_rings, 2*n_rings))}")
    
    # Prepare detector colors
    if highlight_detector_ids is not None:
        highlight_set = set(highlight_detector_ids)
        highlight_mask = np.isin(detector_ids, list(highlight_set))
        regular_mask = ~highlight_mask
        regular_positions = detector_positions[regular_mask]
        highlight_positions = detector_positions[highlight_mask]
        regular_ids = detector_ids[regular_mask]
        highlight_ids = detector_ids[highlight_mask]
        
        if verbose:
            print(f"Highlighting {len(highlight_positions)} detectors: {sorted(highlight_set)}")
    else:
        regular_positions = detector_positions
        highlight_positions = np.array([]).reshape(0, 3)
        regular_ids = detector_ids
        highlight_ids = np.array([])
    
    # Calculate detector assignments for all LORs
    extended_endpoint1, extended_endpoint2 = extend_endpoints_to_radius(endpoint1, endpoint2, radius)
    
    # Build KDTree for efficient nearest neighbor search
    kdtree = cKDTree(detector_positions)
    
    # Find nearest detector for each extended endpoint
    _, nearest_indices1 = kdtree.query(extended_endpoint1)
    _, nearest_indices2 = kdtree.query(extended_endpoint2)
    
    # Convert to detector IDs
    full_detector_id_pairs = np.column_stack([
        detector_ids[nearest_indices1],
        detector_ids[nearest_indices2]
    ])
    
    # Calculate distances and apply filtering
    distances1 = np.linalg.norm(extended_endpoint1 - detector_positions[nearest_indices1], axis=1)
    distances2 = np.linalg.norm(extended_endpoint2 - detector_positions[nearest_indices2], axis=1)
    valid_mask = (distances1 <= distance_threshold) & (distances2 <= distance_threshold)
    
    valid_detector_id_pairs = full_detector_id_pairs[valid_mask]
    valid_lor_indices = np.where(valid_mask)[0]
    
    # Print initial statistics
    if verbose:
        print(f"\nDistance Filtering (threshold: {distance_threshold} mm):")
        print(f"Original LORs: {len(full_detector_id_pairs)}")
        print(f"Valid LORs: {len(valid_detector_id_pairs)}")
        print(f"Filtered out: {len(full_detector_id_pairs) - len(valid_detector_id_pairs)}")
        print(f"Filtering efficiency: {len(valid_detector_id_pairs)/len(full_detector_id_pairs)*100:.1f}%")
    
    # Create the persistent FigureWidget
    fig = FigureWidget()
    
    # Add static background elements that don't change with slider
    all_lor_points = np.vstack([endpoint1, endpoint2])
    bg_opacity = 0.3
    detector_opacity = 0.4
    
    # Add ALL LOR endpoints (static background)
    fig.add_trace(go.Scatter3d(
        x=all_lor_points[:, 0],
        y=all_lor_points[:, 1],
        z=all_lor_points[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color='red',
            opacity=bg_opacity,
            line=dict(width=0)
        ),
        name=f'All LOR Endpoints',
        hovertemplate='<b>LOR Endpoint</b><br>' +
                      'X: %{x:.1f} mm<br>' +
                      'Y: %{y:.1f} mm<br>' +
                      'Z: %{z:.1f} mm<extra></extra>'
    ))
    
    # Add regular detectors (static)
    if len(regular_positions) > 0:
        fig.add_trace(go.Scatter3d(
            x=regular_positions[:, 0],
            y=regular_positions[:, 1],
            z=regular_positions[:, 2],
            mode='markers',
            marker=dict(
                size=3,
                color='blue',
                opacity=detector_opacity,
                symbol='diamond',
                line=dict(width=0)
            ),
            name='Regular Detectors',
            customdata=regular_ids,
            hovertemplate='<b>Detector ID: %{customdata}</b><br>' +
                          'X: %{x:.1f} mm<br>' +
                          'Y: %{y:.1f} mm<br>' +
                          'Z: %{z:.1f} mm<extra></extra>'
        ))
    
    # Add highlighted detectors (static)
    if len(highlight_positions) > 0:
        fig.add_trace(go.Scatter3d(
            x=highlight_positions[:, 0],
            y=highlight_positions[:, 1],
            z=highlight_positions[:, 2],
            mode='markers',
            marker=dict(
                size=5,
                color=highlight_color,
                opacity=1.0,
                symbol='diamond',
                line=dict(width=1, color='black')
            ),
            name=f'Highlighted Detectors ({highlight_color})',
            customdata=highlight_ids,
            hovertemplate='<b>Highlighted Detector ID: %{customdata}</b><br>' +
                          'X: %{x:.1f} mm<br>' +
                          'Y: %{y:.1f} mm<br>' +
                          'Z: %{z:.1f} mm<extra></extra>'
        ))
    
    # Add placeholder traces for dynamic LOR validation elements
    # These will be updated by the slider function
    trace_names = [
        'LOR - Original Endpoints',
        'LOR - Extended Endpoints', 
        'LOR - Assigned Detectors',
        'LOR - Original Line',
        'Extension Path 1',
        'Extension Path 2', 
        'Detector Assignment 1',
        'Detector Assignment 2'
    ]
    
    # Add empty traces that will be populated by slider updates
    for i, name in enumerate(trace_names):
        fig.add_trace(go.Scatter3d(
            x=[], y=[], z=[],
            mode='markers' if 'Endpoints' in name or 'Detectors' in name else 'lines',
            name=name,
            showlegend=True if 'Endpoints' in name or 'Detectors' in name else False
        ))
    
    # Set initial layout
    fig.update_layout(
        title=f'3D LOR Validation with Persistent Camera View (Axial ID Ordering)',
        scene=dict(
            xaxis_title='X (mm)',
            yaxis_title='Y (mm)',
            zaxis_title='Z (mm)',
            camera=dict(
                eye=dict(x=1.5, y=1.5, z=1.5),
                center=dict(x=0, y=0, z=0),
                up=dict(x=0, y=0, z=1)
            ),
            aspectmode='cube'
        ),
        width=900,
        height=700,
        legend=dict(
            x=0.02,
            y=0.98,
            bgcolor='rgba(255,255,255,0.8)'
        )
    )
    
    def update_lor_visualization(lor_index):
        """Update only the LOR-specific traces without recreating the entire plot"""
        
        # Get validation data for the selected LOR
        validation_data = None
        if lor_index < coordinates_sample.shape[0]:
            # Check if this LOR survived filtering
            original_lor_survived = lor_index in valid_lor_indices
            
            # Get the specific LOR's data (regardless of whether it passed filtering)
            val_endpoint1 = coordinates_sample[lor_index, :3]
            val_endpoint2 = coordinates_sample[lor_index, 3:]
            val_extended1 = extended_endpoint1[lor_index]
            val_extended2 = extended_endpoint2[lor_index]
            
            # Get detector assignments
            val_detector1_id = full_detector_id_pairs[lor_index, 0]
            val_detector2_id = full_detector_id_pairs[lor_index, 1]
            
            # Find detector positions
            det_idx1 = np.where(detector_ids == val_detector1_id)[0][0]
            det_idx2 = np.where(detector_ids == val_detector2_id)[0][0]
            val_detector1_pos = detector_positions[det_idx1]
            val_detector2_pos = detector_positions[det_idx2]
            
            validation_data = {
                'original_endpoints': [val_endpoint1, val_endpoint2],
                'extended_endpoints': [val_extended1, val_extended2],
                'detector_ids': [val_detector1_id, val_detector2_id],
                'detector_positions': [val_detector1_pos, val_detector2_pos],
                'distances': [distances1[lor_index], distances2[lor_index]],
                'within_threshold': original_lor_survived
            }
        
        # Update traces with new data
        if validation_data:
            orig_ep1, orig_ep2 = validation_data['original_endpoints']
            ext_ep1, ext_ep2 = validation_data['extended_endpoints']
            det_pos1, det_pos2 = validation_data['detector_positions']
            det_id1, det_id2 = validation_data['detector_ids']
            is_valid = validation_data['within_threshold']
            dist1, dist2 = validation_data['distances']
            
            # Color coding based on validity
            endpoint_color = 'orange' if is_valid else 'red'
            extended_color = 'purple' if is_valid else 'magenta'
            detector_color = 'yellow' if is_valid else 'pink'
            line_color_1 = 'orange' if is_valid else 'red'
            line_color_2 = 'purple' if is_valid else 'magenta'
            
            status_text = f"VALID (≤{distance_threshold}mm)" if is_valid else f"FILTERED (>{distance_threshold}mm)"
            
            # Update dynamic traces (starting from index after static traces)
            static_traces = 3 if highlight_detector_ids is not None else 2
            
            # Original endpoints
            with fig.batch_update():
                fig.data[static_traces].update(
                    x=[orig_ep1[0], orig_ep2[0]],
                    y=[orig_ep1[1], orig_ep2[1]],
                    z=[orig_ep1[2], orig_ep2[2]],
                    marker=dict(
                        size=8,
                        color=endpoint_color,
                        opacity=1.0,
                        symbol='circle',
                        line=dict(width=2, color='black')
                    ),
                    name=f'LOR {lor_index} - Original Endpoints ({status_text})',
                    hovertemplate='<b>Original Endpoint</b><br>' +
                                  'X: %{x:.1f} mm<br>' +
                                  'Y: %{y:.1f} mm<br>' +
                                  'Z: %{z:.1f} mm<extra></extra>'
                )
                
                # Extended endpoints
                fig.data[static_traces + 1].update(
                    x=[ext_ep1[0], ext_ep2[0]],
                    y=[ext_ep1[1], ext_ep2[1]],
                    z=[ext_ep1[2], ext_ep2[2]],
                    marker=dict(
                        size=6,
                        color=extended_color,
                        opacity=1.0,
                        symbol='square',
                        line=dict(width=2, color='black')
                    ),
                    name=f'LOR {lor_index} - Extended Endpoints',
                    hovertemplate='<b>Extended Endpoint</b><br>' +
                                  'X: %{x:.1f} mm<br>' +
                                  'Y: %{y:.1f} mm<br>' +
                                  'Z: %{z:.1f} mm<extra></extra>'
                )
                
                # Assigned detectors
                fig.data[static_traces + 2].update(
                    x=[det_pos1[0], det_pos2[0]],
                    y=[det_pos1[1], det_pos2[1]],
                    z=[det_pos1[2], det_pos2[2]],
                    marker=dict(
                        size=8,
                        color=detector_color,
                        opacity=1.0,
                        symbol='diamond-open',
                        line=dict(width=2, color='black')
                    ),
                    name=f'LOR {lor_index} - Assigned Detectors (d1:{dist1:.1f}mm, d2:{dist2:.1f}mm)',
                    customdata=[det_id1, det_id2],
                    hovertemplate='<b>Assigned Detector ID: %{customdata}</b><br>' +
                                  'X: %{x:.1f} mm<br>' +
                                  'Y: %{y:.1f} mm<br>' +
                                  'Z: %{z:.1f} mm<extra></extra>'
                )
                
                # Original LOR line
                fig.data[static_traces + 3].update(
                    x=[orig_ep1[0], orig_ep2[0]],
                    y=[orig_ep1[1], orig_ep2[1]],
                    z=[orig_ep1[2], orig_ep2[2]],
                    line=dict(color='black', width=6),
                    name=f'LOR {lor_index} - Original LOR Line'
                )
                
                # Extension paths
                fig.data[static_traces + 4].update(
                    x=[orig_ep1[0], ext_ep1[0]],
                    y=[orig_ep1[1], ext_ep1[1]],
                    z=[orig_ep1[2], ext_ep1[2]],
                    line=dict(color=line_color_1, width=4, dash='dash')
                )
                
                fig.data[static_traces + 5].update(
                    x=[orig_ep2[0], ext_ep2[0]],
                    y=[orig_ep2[1], ext_ep2[1]],
                    z=[orig_ep2[2], ext_ep2[2]],
                    line=dict(color=line_color_1, width=4, dash='dash')
                )
                
                # Detector assignment paths
                fig.data[static_traces + 6].update(
                    x=[ext_ep1[0], det_pos1[0]],
                    y=[ext_ep1[1], det_pos1[1]],
                    z=[ext_ep1[2], det_pos1[2]],
                    line=dict(color=line_color_2, width=4, dash='dot')
                )
                
                fig.data[static_traces + 7].update(
                    x=[ext_ep2[0], det_pos2[0]],
                    y=[ext_ep2[1], det_pos2[1]],
                    z=[ext_ep2[2], det_pos2[2]],
                    line=dict(color=line_color_2, width=4, dash='dot')
                )
                
                # Update title
                fig.layout.title.text = f'3D LOR Validation: LOR {lor_index} ({status_text}) - Axial ID Ordering'
        
        else:
            # Clear dynamic traces if no validation data
            static_traces = 3 if highlight_detector_ids is not None else 2
            with fig.batch_update():
                for i in range(len(trace_names)):
                    fig.data[static_traces + i].update(x=[], y=[], z=[])
                fig.layout.title.text = f'3D LOR Validation: LOR {lor_index} (INVALID INDEX) - Axial ID Ordering'
    
    # Set up the interactive slider
    max_lor_index = coordinates_sample.shape[0] - 1
    initial_lor_index = validate_lor_index if validate_lor_index is not None else 0
    
    slider = IntSlider(
        value=initial_lor_index,
        min=0,
        max=max_lor_index,
        step=1,
        description='LOR Index:',
        style={'description_width': 'initial'},
        layout={'width': '400px'}
    )
    
    if verbose:
        print(f"\nInteractive slider created for LOR indices 0-{max_lor_index}")
        print(f"Distance threshold: {distance_threshold} mm")
        print("Camera orientation will persist during slider interactions!")
    
    # Connect slider to update function
    def on_slider_change(change):
        if change['type'] == 'change' and change['name'] == 'value':
            update_lor_visualization(change['new'])
    
    slider.observe(on_slider_change)
    
    # Initialize with the starting LOR
    update_lor_visualization(initial_lor_index)
    
    # Display the widget and figure
    display(VBox([slider, fig]))
    
    return full_detector_id_pairs, valid_detector_id_pairs, valid_lor_indices

# full_detector_ids, valid_detector_ids, valid_indices = visualize_lors_and_detectors_3d_axial_ids(filtered_coordinates.numpy(), info,
#                                              highlight_detector_ids=[0, 1, 5, 4370],
#                                                 validate_lor_index=1,
#                                                 verbose=False,
#                                                 distance_threshold=20.0,)

In [60]:
detector_ids = torch.tensor(valid_detector_ids)
detector_ids.shape

torch.Size([3987, 2])

In [61]:
object_meta = ObjectMeta(
    dr=(voxel_size,voxel_size,voxel_size), #mm
    shape=voxel_space #voxels
    # shape=(128,128,96) #voxels

)

In [62]:
# Get or Create attenuation maps

# atten_map = gate.get_aligned_attenuation_map(os.path.join(path, 'gate_simulation/simple_phantom/umap_mMR_brainSimplePhantom.hv'), object_meta).to(pytomography.device)

ones_atten_map = torch.ones(voxel_space)
zeros_atten_map = torch.zeros(voxel_space)

# Create a half-rectangular attenuation map
half_rect_atten_map = torch.zeros(voxel_space)
mid_x = voxel_space[0] // 2
half_rect_atten_map[:mid_x, :, :] = torch.linspace(0, 0.1, mid_x).view(-1, 1, 1)

# Create a gradient attenuation map
coords = torch.stack(torch.meshgrid(
    torch.arange(voxel_space[0]),
    torch.arange(voxel_space[1]),
    torch.arange(voxel_space[2]),
    indexing='ij'
), dim=-1).float()
gradient_atten_map = coords.mean(dim=-1)
gradient_atten_map = gradient_atten_map / gradient_atten_map.max()

# Create a cylindrical attenuation map with radial gradient
radius_mm = 278  # cylinder radius in mm
height = voxel_space[2] * voxel_size  # cylinder height in mm
center_x = voxel_space[0] // 2
center_y = voxel_space[1] // 2
center_z = voxel_space[2] // 2
radius_vox = radius_mm / voxel_size

# Generate grid of voxel indices
x = torch.arange(voxel_space[0]).float()
y = torch.arange(voxel_space[1]).float()
z = torch.arange(voxel_space[2]).float()
xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')

# Compute distance from central axis for each voxel
dist_from_axis = torch.sqrt((xx - center_x)**2 + (yy - center_y)**2)

# Create gradient cylinder attenuation map
gradient_cylindrical_atten_map = torch.ones(voxel_space)
normalized_dist = dist_from_axis / radius_vox  # Normalize distances to [0,1]

# Create radial gradient: high attenuation (1.0) at center, decreasing outward
gradient_cylindrical_atten_map = 1.0 - normalized_dist.clone()  # Invert the gradient
gradient_cylindrical_atten_map[gradient_cylindrical_atten_map < 0] = 0  # Clip negative values
gradient_cylindrical_atten_map[dist_from_axis > radius_vox] = 0  # Set outside cylinder to 0
gradient_cylindrical_atten_map /= 100

# Also create a solid cylinder for comparison
cylindrical_atten_map = torch.zeros(voxel_space)
cylindrical_atten_map[dist_from_axis <= radius_vox] = 0.01 # 0.14  cm⁻¹ is the attenuation for dense bone apparently


# Create a spherical attenuation map centered at origin
def create_spherical_atten_map(voxel_space, voxel_size, radius_mm=None, atten_value=0.1):
    """
    Create a spherical attenuation map centered at the midpoint of voxel space.
    
    Args:
        voxel_space: tuple (x, y, z) dimensions of the voxel grid
        voxel_size: size of each voxel in mm
        radius_mm: sphere radius in mm. If None, defaults to 1/4 of smallest axis
        atten_value: attenuation value inside the sphere (default 0.1)
    
    Returns:
        torch.Tensor: spherical attenuation map with shape voxel_space
    """
    # Calculate default radius if not specified
    if radius_mm is None:
        min_axis_mm = min(voxel_space) * voxel_size
        radius_mm = min_axis_mm / 4
    
    # Convert radius to voxel units
    radius_vox = radius_mm / voxel_size
    
    # Calculate center coordinates
    center_x = voxel_space[0] // 2
    center_y = voxel_space[1] // 2
    center_z = voxel_space[2] // 2
    
    # Generate grid of voxel indices
    x = torch.arange(voxel_space[0]).float()
    y = torch.arange(voxel_space[1]).float()
    z = torch.arange(voxel_space[2]).float()
    xx, yy, zz = torch.meshgrid(x, y, z, indexing='ij')
    
    # Compute distance from center for each voxel
    dist_from_center = torch.sqrt((xx - center_x)**2 + (yy - center_y)**2 + (zz - center_z)**2)
    
    # Create spherical attenuation map
    spherical_atten_map = torch.zeros(voxel_space)
    spherical_atten_map[dist_from_center <= radius_vox] = atten_value
    
    return spherical_atten_map

# Usage example with your existing variables:
spherical_atten_map = create_spherical_atten_map(voxel_space, voxel_size, atten_value=0.01)
# spherical_atten_map = create_spherical_atten_map(voxel_space, voxel_size, radius_mm=100, atten_value=0.1)

In [63]:
# Normalisation Weights if Applicable
# normalisation_weights = torch.ones((411027456,), dtype=torch.float32).to(pytomography.device)
# normalisation_weights = torch.load(os.path.join(path, 'normalization_weights.pt'))
# normalisation_weights = None

In [None]:
# atten_map = half_rect_atten_map.to(pytomography.device)
# atten_map = gradient_atten_map.to(pytomography.device)
atten_map = cylindrical_atten_map.to(pytomography.device)
# atten_map = gradient_cylindrical_atten_map.to(pytomography.device)
# atten_map = ones_atten_map.to(pytomography.device)
# atten_map = zeros_atten_map.to(pytomography.device)
# atten_map = spherical_atten_map.to(pytomography.device)
# atten_map = None

proj_meta = PETLMProjMeta(
    detector_ids,
    info,
    # weights_sensitivity=normalisation_weights,
    )

psf_transform = GaussianFilter(3.)

system_matrix = PETLMSystemMatrix(
       object_meta,
       proj_meta,
       obj2obj_transforms = [psf_transform],
       N_splits=1,
       device=pytomography.device,
       attenuation_map=atten_map,
)

initial_image = torch.ones(object_meta.shape)

# Forward projection
projections = system_matrix.forward(initial_image)

# Back projection
back_projected = system_matrix.backward(projections)

print(f"Input image shape: {initial_image.shape}")
print(f"Projections shape: {projections.shape}")
print(f"Back projected shape: {back_projected.shape}")

plot_cross_sections_interactive(back_projected.cpu().numpy())

Input image shape: torch.Size([140, 140, 74])
Projections shape: torch.Size([3987])
Back projected shape: torch.Size([140, 140, 74])


interactive(children=(IntSlider(value=70, description='X index', max=139), IntSlider(value=70, description='Y …

In [65]:
sinogram_delays  = gate.listmode_to_sinogram(detector_ids, info)
sinogram_delays  = gate.smooth_randoms_sinogram(sinogram_delays , info, sigma_r=4, sigma_theta=4, sigma_z=4)
lm_delays = shared.sinogram_to_listmode(detector_ids, sinogram_delays , info)

In [None]:
# Get additive term (without scatter term):
lm_sensitivity = system_matrix._compute_sensitivity_projection(all_ids=False)
additive_term = lm_delays / lm_sensitivity
additive_term[additive_term.isnan()] = 0 # remove NaN values

# Recon
likelihood = PoissonLogLikelihood(
        system_matrix,
        projections=projections,
        additive_term = additive_term
    )

recon_algorithm = OSEM(likelihood, object_initial=back_projected)
recon_without_scatter_estimation = recon_algorithm(4,14)

plot_cross_sections_interactive(recon_without_scatter_estimation.cpu().numpy())


interactive(children=(IntSlider(value=70, description='X index', max=139), IntSlider(value=70, description='Y …

In [72]:
print(back_projected.cpu().numpy().mean())
print(recon_without_scatter_estimation.cpu().numpy().mean())
diff = back_projected.cpu().numpy() - recon_without_scatter_estimation.cpu().numpy()
print(diff.max())

2.0966645e-06
2.0966645e-06
0.0


In [68]:
# sinogram_scatter = sss.get_sss_scatter_estimate(
#     object_meta = object_meta,
#     proj_meta = proj_meta,
#     pet_image = recon_without_scatter_estimation,
#     attenuation_image = atten_map,
#     system_matrix = system_matrix,
#     # proj_data = None, # assumes listmode
#     image_stepsize = 4,
#     attenuation_cutoff = 0.004,
#     sinogram_interring_stepsize = 4,
#     sinogram_intraring_stepsize = 4,
#     sinogram_random = sinogram_delays
#     )

# # Now convert to listmode and make additive term
# lm_scatter = shared.sinogram_to_listmode(proj_meta.detector_ids, sinogram_scatter, proj_meta.info)
# additive_term = (lm_scatter + lm_delays) / lm_sensitivity
# additive_term[additive_term.isnan()] = 0

In [69]:
likelihood = PoissonLogLikelihood(
        system_matrix,
        additive_term = additive_term
    )
recon_algorithm = OSEM(likelihood)
recon_lm_nontof = recon_algorithm(4,14)

In [70]:
plot_cross_sections_interactive(recon_lm_nontof.cpu().numpy())

interactive(children=(IntSlider(value=70, description='X index', max=139), IntSlider(value=70, description='Y …

In [71]:
# visualize_voxel_tensor_3d(recon_without_scatter_estimation.cpu().numpy(), voxel_size_mm=voxel_size)
visualize_voxel_tensor_3d(recon_lm_nontof.cpu().numpy(), voxel_size_mm=voxel_size)

Voxel value range: 1.401298464324817e-45 to 1.829294354882549e-11
Total non-zero voxels: 47665
Initial thresholds: 1.401298464324817e-45 to 1.829294354882549e-11
Slider range: 1.401298464324817e-45 to 1.829294354882549e-11
Voxel resolution: 4.0mm


interactive(children=(FloatSlider(value=1.401298464324817e-45, continuous_update=False, description='Min Thres…