In [1]:
###############################################################################
# Initial Loading, Filtering, and Coordinate Range Calculation
###############################################################################
import numpy as np

# Load coordinates
coordinates = np.load(fr"C:\Users\h\Desktop\PetStuff\Image_Processing\ground_truth.npy")

# Confirm shape should be (pairs, coords=6), coords are (x1, y1, z1, x2, y2, z2)
print(f"\nData Shape (pairs, coords) : {coordinates.shape}\n")  

# Remove pairs where any coordinate value is exactly 0
filtered_coordinates = coordinates[~np.any(coordinates == 0, axis=1)]
print(f"Filtered shape: {filtered_coordinates.shape}\n")

# Extract all x, y, z pairs - Coordinates are in the order (x1, y1, z1, x2, y2, z2)
all_xyz = filtered_coordinates.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
print(f"x range: min={x_vals.min()}, max={x_vals.max()}")
print(f"y range: min={y_vals.min()}, max={y_vals.max()}")
print(f"z range: min={z_vals.min()}, max={z_vals.max()}")


Data Shape (pairs, coords) : (62660, 6)

Filtered shape: (6591, 6)

x range: min=-278.12942504882807, max=278.1666564941406
y range: min=-278.41946411132807, max=277.8843688964844
z range: min=-147.99453735351562, max=147.9492950439453


In [2]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import interact, IntSlider, FloatSlider

def visualize_voxel_tensor_3d(voxel_tensor, initial_threshold=None, voxel_size_mm=1.0, world_origin=None):
    """
    Interactive 3D visualization of voxel tensor with threshold slider.
    
    Args:
        voxel_tensor: (nx, ny, nz) numpy array with voxel counts
        initial_threshold: Initial threshold value for the slider (default: min_val)
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0) (optional)
    """
    # Extract non-zero voxel coordinates and values
    coords = np.where(voxel_tensor > 0)
    x_coords, y_coords, z_coords = coords
    values = voxel_tensor[coords]
    
    # Convert voxel indices to world coordinates if world_origin provided
    if world_origin is not None:
        x_min, y_min, z_min = world_origin
        x_coords_world = x_coords * voxel_size_mm + x_min
        y_coords_world = y_coords * voxel_size_mm + y_min
        z_coords_world = z_coords * voxel_size_mm + z_min
        coord_suffix = " (mm)"
    else:
        x_coords_world = x_coords * voxel_size_mm
        y_coords_world = y_coords * voxel_size_mm
        z_coords_world = z_coords * voxel_size_mm
        coord_suffix = f" (×{voxel_size_mm}mm)"
    
    # Get value range for slider
    min_val = int(np.min(values))
    max_val = int(np.max(values))
    
    # Set initial threshold
    if initial_threshold is None:
        initial_threshold = min_val
    else:
        initial_threshold = max(min_val, min(max_val, int(initial_threshold)))  # Clamp to valid range
    
    print(f"Voxel value range: {min_val} to {max_val}")
    print(f"Total non-zero voxels: {len(values)}")
    print(f"Initial threshold: {initial_threshold}")
    print(f"Voxel resolution: {voxel_size_mm}mm")
    
    def update_plot(threshold):
        # Filter voxels above threshold
        mask = values >= threshold
        if not np.any(mask):
            print(f"No voxels above threshold {threshold}")
            return
        
        filtered_x = x_coords_world[mask]
        filtered_y = y_coords_world[mask]
        filtered_z = z_coords_world[mask]
        filtered_values = values[mask]
        
        # Create 3D scatter plot
        fig = go.Figure(data=go.Scatter3d(
            x=filtered_x,
            y=filtered_y,
            z=filtered_z,
            mode='markers',
            marker=dict(
                size=3,
                color=filtered_values,
                colorscale='Viridis',
                opacity=0.8,
                colorbar=dict(title="Voxel Count"),
                line=dict(width=0)
            ),
            text=[f'Count: {v}' for v in filtered_values],
            hovertemplate='<b>Voxel (%{x:.1f}, %{y:.1f}, %{z:.1f})</b><br>%{text}<extra></extra>'
        ))
        
        # Set layout for orbital controls
        fig.update_layout(
            title=f'3D Voxel Visualization (Threshold: {threshold}, Showing: {len(filtered_values)} voxels)',
            scene=dict(
                xaxis_title=f'X{coord_suffix}',
                yaxis_title=f'Y{coord_suffix}',
                zaxis_title=f'Z{coord_suffix}',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                aspectmode='cube'
            ),
            width=800,
            height=600
        )
        
        fig.show()
    
    # Create interactive slider with initial threshold
    threshold_slider =  FloatSlider(
        value=initial_threshold,
        min=min_val,
        max=max_val,
        step=1,
        description='Threshold:',
        continuous_update=False
    )
    
    # Display interactive widget
    interact(update_plot, threshold=threshold_slider)

# different thresholding thin 
# def visualize_voxel_tensor_3d(voxel_tensor, initial_threshold=None, voxel_size_mm=1.0, world_origin=None):
#     """
#     Interactive 3D visualization of voxel tensor with threshold slider positioned near colorbar.
#     """
#     # Extract non-zero voxel coordinates and values
#     coords = np.where(voxel_tensor > 0)
#     x_coords, y_coords, z_coords = coords
#     values = voxel_tensor[coords]
    
#     # Convert voxel indices to world coordinates if world_origin provided
#     if world_origin is not None:
#         x_min, y_min, z_min = world_origin
#         x_coords_world = x_coords * voxel_size_mm + x_min
#         y_coords_world = y_coords * voxel_size_mm + y_min
#         z_coords_world = z_coords * voxel_size_mm + z_min
#         coord_suffix = " (mm)"
#     else:
#         x_coords_world = x_coords * voxel_size_mm
#         y_coords_world = y_coords * voxel_size_mm
#         z_coords_world = z_coords * voxel_size_mm
#         coord_suffix = f" (×{voxel_size_mm}mm)"
    
#     # Get value range
#     min_val = float(np.min(values))
#     max_val = float(np.max(values))
    
#     # Set initial threshold
#     if initial_threshold is None:
#         initial_threshold = min_val
#     else:
#         initial_threshold = max(min_val, min(max_val, float(initial_threshold)))
    
#     print(f"Voxel value range: {min_val} to {max_val}")
#     print(f"Total non-zero voxels: {len(values)}")
#     print(f"Initial threshold: {initial_threshold}")
    
#     # Create steps for threshold values
#     n_steps = 100
#     threshold_values = np.linspace(min_val, max_val, n_steps)
    
#     # Find initial step index
#     initial_step = np.argmin(np.abs(threshold_values - initial_threshold))
    
#     steps = []
#     for i, threshold in enumerate(threshold_values):
#         # Filter data for this threshold
#         mask = values >= threshold
#         n_visible = np.sum(mask)
        
#         if n_visible > 0:
#             step = dict(
#                 method="update",
#                 args=[{
#                     "x": [x_coords_world[mask]],
#                     "y": [y_coords_world[mask]], 
#                     "z": [z_coords_world[mask]],
#                     "marker.color": [values[mask]],
#                     "text": [[f'Count: {v:.2f}' for v in values[mask]]]
#                 }, {
#                     "title": f'3D Voxel Visualization (Threshold: {threshold:.2f}, Showing: {n_visible} voxels)'
#                 }],
#                 label=f"{threshold:.2f}"
#             )
#         else:
#             # Empty plot when no voxels pass threshold
#             step = dict(
#                 method="update",
#                 args=[{
#                     "x": [[]],
#                     "y": [[]],
#                     "z": [[]],
#                     "marker.color": [[]],
#                     "text": [[]]
#                 }, {
#                     "title": f'3D Voxel Visualization (Threshold: {threshold:.2f}, Showing: 0 voxels)'
#                 }],
#                 label=f"{threshold:.2f}"
#             )
#         steps.append(step)
    
#     # Create initial filtered data
#     initial_mask = values >= initial_threshold
    
#     # Create figure with initial data
#     fig = go.Figure()
    
#     scatter = go.Scatter3d(
#         x=x_coords_world[initial_mask],
#         y=y_coords_world[initial_mask],
#         z=z_coords_world[initial_mask],
#         mode='markers',
#         marker=dict(
#             size=3,
#             color=values[initial_mask],
#             colorscale='Viridis',
#             opacity=0.8,
#             colorbar=dict(
#                 title="Voxel Count",
#                 x=1.02,
#                 thickness=15,
#                 len=0.6,
#                 y=0.7
#             ),
#             cmin=min_val,
#             cmax=max_val,
#             line=dict(width=0)
#         ),
#         text=[f'Count: {v:.2f}' for v in values[initial_mask]],
#         hovertemplate='<b>Voxel (%{x:.1f}, %{y:.1f}, %{z:.1f})</b><br>%{text}<extra></extra>',
#         name='Voxels'
#     )
    
#     fig.add_trace(scatter)
    
#     # Add horizontal slider positioned near colorbar
#     sliders = [dict(
#         active=initial_step,
#         currentvalue={"prefix": "Threshold: ", "xanchor": "right", "font": {"size": 12}},
#         pad={"t": 20, "b": 20},
#         steps=steps,
#         x=0.85,  # Position near colorbar
#         y=0.1,   # Bottom of plot
#         len=0.3, # Length of slider
#         xanchor="left"
#     )]
    
#     fig.update_layout(
#         sliders=sliders,
#         title=f'3D Voxel Visualization (Threshold: {initial_threshold:.2f}, Showing: {np.sum(initial_mask)} voxels)',
#         scene=dict(
#             xaxis_title=f'X{coord_suffix}',
#             yaxis_title=f'Y{coord_suffix}',
#             zaxis_title=f'Z{coord_suffix}',
#             camera=dict(eye=dict(x=1.5, y=1.5, z=1.5)),
#             aspectmode='cube'
#         ),
#         width=1000,
#         height=700
#     )
    
#     fig.show()


In [3]:
# Rasterisation function
def binary_rasterize_lors_3d_dda(pairs_coords, voxel_size_mm=1.0):
    """
    Rasterize lines of response into 3D voxel space using exact 3D DDA traversal.
    
    Args:
        pairs_coords: (N, 6) array where each row is [x1,y1,z1,x2,y2,z2]
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
    
    Returns:
        voxel_tensor: (nx, ny, nz) array with line traversal counts
    """
    # World coordinate ranges (mm)
    all_xyz = pairs_coords.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]

    x_min = x_vals.min()
    x_max = x_vals.max()
    y_min = y_vals.min()
    y_max = y_vals.max()
    z_min = z_vals.min()
    z_max = z_vals.max()

    # Round the min and max values to voxel boundaries
    x_min = np.floor(x_min / voxel_size_mm) * voxel_size_mm
    x_max = np.ceil(x_max / voxel_size_mm) * voxel_size_mm
    y_min = np.floor(y_min / voxel_size_mm) * voxel_size_mm
    y_max = np.ceil(y_max / voxel_size_mm) * voxel_size_mm
    z_min = np.floor(z_min / voxel_size_mm) * voxel_size_mm
    z_max = np.ceil(z_max / voxel_size_mm) * voxel_size_mm

    # Calculate voxel tensor shape based on resolution
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    voxel_shape = (nx, ny, nz)
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: {voxel_shape} (nx, ny, nz)")
    
    voxel_tensor = np.zeros(voxel_shape, dtype=np.int32)
    
    # Coordinate transformation: world -> voxel indices
    def world_to_voxel(coords):
        x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
        vx = (x - x_min) / voxel_size_mm
        vy = (y - y_min) / voxel_size_mm
        vz = (z - z_min) / voxel_size_mm
        return np.column_stack([vx, vy, vz])
    
    def dda_3d(p1, p2):
        """3D DDA algorithm - returns list of (x,y,z) voxel indices along line"""
        x1, y1, z1 = p1
        x2, y2, z2 = p2
        
        # Direction and step sizes
        dx = x2 - x1
        dy = y2 - y1
        dz = z2 - z1
        
        # Number of steps is maximum of absolute differences
        steps = int(max(abs(dx), abs(dy), abs(dz)))
        if steps == 0:
            return [(int(round(x1)), int(round(y1)), int(round(z1)))]
        
        # Step increments
        x_inc = dx / steps
        y_inc = dy / steps
        z_inc = dz / steps
        
        voxels = []
        x, y, z = x1, y1, z1
        
        for _ in range(steps + 1):
            vx, vy, vz = int(round(x)), int(round(y)), int(round(z))
            
            # Check bounds
            if 0 <= vx < nx and 0 <= vy < ny and 0 <= vz < nz:
                voxels.append((vx, vy, vz))
            
            x += x_inc
            y += y_inc
            z += z_inc
        
        return voxels
    
    # Transform all coordinates to voxel space
    points1 = world_to_voxel(pairs_coords[:, :3])  # First points
    points2 = world_to_voxel(pairs_coords[:, 3:])  # Second points
    
    # Process each LOR
    for i in range(len(pairs_coords)):
        p1 = points1[i]
        p2 = points2[i]
        
        # Get voxels along this line
        voxels = dda_3d(p1, p2)
        
        # Increment voxel counts
        for vx, vy, vz in voxels:
            voxel_tensor[vx, vy, vz] += 1
    
    return voxel_tensor

voxel_tensor = binary_rasterize_lors_3d_dda(filtered_coordinates, voxel_size_mm=1)

print(f"Voxel tensor shape: {voxel_tensor.shape}")

Voxel size: 1mm
Voxel shape: (558, 557, 296) (nx, ny, nz)
Voxel tensor shape: (558, 557, 296)


In [4]:
import cupy as cp
import numpy as np

def rasterize_lors_vectorized_gpu(pairs_coords, voxel_size_mm=1.0, samples_per_mm=3):
    """
    GPU-accelerated line rasterization using vectorized sampling.
    
    Args:
        pairs_coords: (N, 6) array where each row is [x1,y1,z1,x2,y2,z2]
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        samples_per_mm: Number of samples per mm along each line (default: 3)
                       Higher = more accurate but slower
    
    Returns:
        voxel_tensor: (nx, ny, nz) array with line traversal counts
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0)
    """
    print("Moving data to GPU...")
    pairs_gpu = cp.asarray(pairs_coords)
    
    # Calculate world bounds
    all_xyz = pairs_gpu.reshape(-1, 3)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
    
    x_min = float(cp.floor(cp.min(x_vals) / voxel_size_mm) * voxel_size_mm)
    x_max = float(cp.ceil(cp.max(x_vals) / voxel_size_mm) * voxel_size_mm)
    y_min = float(cp.floor(cp.min(y_vals) / voxel_size_mm) * voxel_size_mm)
    y_max = float(cp.ceil(cp.max(y_vals) / voxel_size_mm) * voxel_size_mm)
    z_min = float(cp.floor(cp.min(z_vals) / voxel_size_mm) * voxel_size_mm)
    z_max = float(cp.ceil(cp.max(z_vals) / voxel_size_mm) * voxel_size_mm)
    
    # Calculate voxel tensor shape
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: ({nx}, {ny}, {nz})")
    print(f"Samples per mm: {samples_per_mm}")
    
    # Extract line endpoints
    p1 = pairs_gpu[:, :3]  # (N, 3) start points
    p2 = pairs_gpu[:, 3:]  # (N, 3) end points
    
    # Calculate line properties
    line_vectors = p2 - p1  # (N, 3)
    line_lengths = cp.linalg.norm(line_vectors, axis=1)  # (N,)
    
    # Calculate number of samples per line
    n_samples_per_line = cp.maximum(1, (line_lengths * samples_per_mm).astype(cp.int32))
    max_samples = int(cp.max(n_samples_per_line))
    
    print(f"Max samples per line: {max_samples}")
    print(f"Total lines: {len(pairs_coords)}")
    
    # Create voxel tensor on GPU
    voxel_tensor_gpu = cp.zeros((nx, ny, nz), dtype=cp.int32)
    
    # Process lines in chunks to manage memory
    chunk_size = min(1000, len(pairs_coords))  # Adjust based on GPU memory
    
    for chunk_start in range(0, len(pairs_coords), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(pairs_coords))
        chunk_size_actual = chunk_end - chunk_start
        
        print(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(pairs_coords)-1)//chunk_size + 1}")
        
        # Get chunk data
        p1_chunk = p1[chunk_start:chunk_end]  # (chunk_size, 3)
        p2_chunk = p2[chunk_start:chunk_end]  # (chunk_size, 3)
        line_vectors_chunk = line_vectors[chunk_start:chunk_end]  # (chunk_size, 3)
        n_samples_chunk = n_samples_per_line[chunk_start:chunk_end]  # (chunk_size,)
        
        # Create parameter array for sampling
        t_values = cp.linspace(0, 1, max_samples)  # (max_samples,)
        
        # Vectorized sampling: (max_samples, chunk_size, 3)
        sample_points = (p1_chunk[None, :, :] + 
                        t_values[:, None, None] * line_vectors_chunk[None, :, :])
        
        # Mask for valid samples (within line length)
        sample_mask = t_values[:, None] <= (n_samples_chunk[None, :] - 1) / cp.maximum(1, n_samples_chunk[None, :] - 1)
        
        # Convert to voxel indices
        vx = ((sample_points[:, :, 0] - x_min) / voxel_size_mm).astype(cp.int32)
        vy = ((sample_points[:, :, 1] - y_min) / voxel_size_mm).astype(cp.int32)
        vz = ((sample_points[:, :, 2] - z_min) / voxel_size_mm).astype(cp.int32)
        
        # Bounds checking
        valid_mask = (sample_mask & 
                     (vx >= 0) & (vx < nx) & 
                     (vy >= 0) & (vy < ny) & 
                     (vz >= 0) & (vz < nz))
        
        # Get valid indices
        valid_vx = vx[valid_mask]
        valid_vy = vy[valid_mask]
        valid_vz = vz[valid_mask]
        
        # Accumulate in voxel tensor using atomic operations
        if len(valid_vx) > 0:
            # Use bincount for efficient accumulation
            flat_indices = valid_vx * (ny * nz) + valid_vy * nz + valid_vz
            flat_tensor = voxel_tensor_gpu.ravel()
            
            # Count occurrences and add to tensor
            counts = cp.bincount(flat_indices, minlength=nx*ny*nz)
            flat_tensor += counts
    
    print("Moving result back to CPU...")
    voxel_tensor = cp.asnumpy(voxel_tensor_gpu)
    world_origin = (x_min, y_min, z_min)
    
    return voxel_tensor, world_origin

def rasterize_lors_vectorized_cpu_fallback(pairs_coords, voxel_size_mm=1.0, samples_per_mm=3):
    """
    CPU fallback version if CuPy is not available.
    """
    print("Using CPU fallback (install cupy for GPU acceleration)")
    
    # Calculate world bounds
    all_xyz = pairs_coords.reshape(-1, 3)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
    
    x_min = np.floor(np.min(x_vals) / voxel_size_mm) * voxel_size_mm
    x_max = np.ceil(np.max(x_vals) / voxel_size_mm) * voxel_size_mm
    y_min = np.floor(np.min(y_vals) / voxel_size_mm) * voxel_size_mm
    y_max = np.ceil(np.max(y_vals) / voxel_size_mm) * voxel_size_mm
    z_min = np.floor(np.min(z_vals) / voxel_size_mm) * voxel_size_mm
    z_max = np.ceil(np.max(z_vals) / voxel_size_mm) * voxel_size_mm
    
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: ({nx}, {ny}, {nz})")
    
    voxel_tensor = np.zeros((nx, ny, nz), dtype=np.int32)
    
    # Process each line
    for i, line in enumerate(pairs_coords):
        if i % 10000 == 0:
            print(f"Processing line {i}/{len(pairs_coords)}")
            
        p1, p2 = line[:3], line[3:]
        line_length = np.linalg.norm(p2 - p1)
        n_samples = max(1, int(line_length * samples_per_mm))
        
        # Sample along line
        t_values = np.linspace(0, 1, n_samples)
        sample_points = p1[None, :] + t_values[:, None] * (p2 - p1)[None, :]
        
        # Convert to voxel indices
        vx = ((sample_points[:, 0] - x_min) / voxel_size_mm).astype(int)
        vy = ((sample_points[:, 1] - y_min) / voxel_size_mm).astype(int)
        vz = ((sample_points[:, 2] - z_min) / voxel_size_mm).astype(int)
        
        # Bounds checking and accumulation
        valid = ((vx >= 0) & (vx < nx) & 
                (vy >= 0) & (vy < ny) & 
                (vz >= 0) & (vz < nz))
        
        if np.any(valid):
            for j in np.where(valid)[0]:
                voxel_tensor[vx[j], vy[j], vz[j]] += 1
    
    return voxel_tensor, (x_min, y_min, z_min)

# Smart wrapper that tries GPU first, falls back to CPU
def rasterize_lors_vectorized(pairs_coords, voxel_size_mm=1.0, samples_per_mm=3):
    """
    Rasterize LORs using vectorized sampling with automatic GPU/CPU selection.
    """
    try:
        import cupy as cp
        return rasterize_lors_vectorized_gpu(pairs_coords, voxel_size_mm, samples_per_mm)
    except ImportError:
        print("CuPy not available, using CPU version")
        return rasterize_lors_vectorized_cpu_fallback(pairs_coords, voxel_size_mm, samples_per_mm)
    except Exception as e:
        print(f"GPU version failed ({e}), falling back to CPU")
        return rasterize_lors_vectorized_cpu_fallback(pairs_coords, voxel_size_mm, samples_per_mm)

# Usage:
voxel_tensor, world_origin = rasterize_lors_vectorized(filtered_coordinates, voxel_size_mm=5, samples_per_mm=5)
visualize_voxel_tensor_3d(voxel_tensor, initial_threshold=15, voxel_size_mm=5, world_origin=world_origin)

Moving data to GPU...
Voxel size: 5mm
Voxel shape: (112, 112, 60)
Samples per mm: 5
Max samples per line: 2970
Total lines: 6591
Processing chunk 1/7
Processing chunk 2/7
Processing chunk 3/7
Processing chunk 4/7
Processing chunk 5/7
Processing chunk 6/7
Processing chunk 7/7
Moving result back to CPU...
Voxel value range: 1 to 3003
Total non-zero voxels: 265513
Initial threshold: 15
Voxel resolution: 5mm


interactive(children=(FloatSlider(value=15.0, continuous_update=False, description='Threshold:', max=3003.0, m…

In [5]:
from scipy.ndimage import gaussian_filter

def gaussian_blur_3d(voxel_tensor, sigma=1):
    """
    Apply a 3D Gaussian blur to the voxel tensor.

    Args:
        voxel_tensor: 3D numpy array to blur.
        sigma: Standard deviation for Gaussian kernel. Can be a float or a tuple of 3 floats.

    Returns:
        Blurred 3D numpy array.
    """
    blurred = gaussian_filter(voxel_tensor, sigma=sigma)
    return blurred

# Example usage:
blurred_tensor = gaussian_blur_3d(voxel_tensor, sigma=2)

visualize_voxel_tensor_3d(blurred_tensor, initial_threshold=15, voxel_size_mm=5, world_origin=world_origin)


Voxel value range: 1 to 222
Total non-zero voxels: 628886
Initial threshold: 15
Voxel resolution: 5mm


interactive(children=(FloatSlider(value=15.0, continuous_update=False, description='Threshold:', max=222.0, mi…

In [6]:
import pytomography
from pytomography.algorithms import OSEM, MLEM
from pytomography.metadata import ObjectMeta, ProjMeta
from pytomography.projectors import SystemMatrix
from pytomography.likelihoods import NegativeMSELikelihood
import matplotlib.pyplot as plt
import torch

class EXSListmodeProjMeta(ProjMeta):
    def __init__(self, shape, scanner_LUT, detector_ids, sensitivity_factor):
        self.scanner_LUT = scanner_LUT
        self.detector_ids = detector_ids
        self.sensitivity_at_ids = sensitivity_factor[*scanner_LUT[:,1:].T]
        self.shape = shape
        if (sensitivity_factor.shape[0]!=M)*(sensitivity_factor.shape[1]!=M):
            raise ValueError("sensitivity_factor should have side dimensions M")

def bin_to_detectors(coincidence_data, n_circumferential=80, n_axial=200, 
                    scanner_radius=280, axial_length=300):
    """
    Convert raw XYZ coincidence coordinates to detector pair indices for pytomography.
    
    Parameters:
    -----------
    coincidence_data : numpy.ndarray
        Shape (n_pairs, 6) containing [x1, y1, z1, x2, y2, z2] coordinates in mm
    n_circumferential : int
        Number of detector elements around circumference (default: 80)
    n_axial : int
        Number of detector rows in axial direction (default: 200)
    scanner_radius : float
        Scanner radius in mm (default: 280)
    axial_length : float 
        Scanner axial length in mm (default: 300)
    
    Returns:
    --------
    detector_ids : numpy.ndarray
        Shape (n_pairs*2,) containing flattened detector IDs for pytomography
    scanner_LUT : numpy.ndarray
        Shape (n_detectors, 3) lookup table [angle, axial, circumferential]
    detector_positions : dict
        Dictionary containing detector geometry information
    """
    
    # Input validation
    if not isinstance(coincidence_data, np.ndarray):
        raise ValueError("coincidence_data must be a numpy array")
    
    if coincidence_data.shape[1] != 6:
        raise ValueError(f"Expected shape (n_pairs, 6), got {coincidence_data.shape}")
    
    n_pairs = coincidence_data.shape[0]
    print(f"Processing {n_pairs} coincidence pairs")
    
    # Validate coordinate ranges (allowing some tolerance)
    x_range = [coincidence_data[:, [0, 3]].min(), coincidence_data[:, [0, 3]].max()]
    y_range = [coincidence_data[:, [1, 4]].min(), coincidence_data[:, [1, 4]].max()]
    z_range = [coincidence_data[:, [2, 5]].min(), coincidence_data[:, [2, 5]].max()]
    
    print(f"Input coordinate ranges:")
    print(f"  X: {x_range[0]:.2f} to {x_range[1]:.2f} mm")
    print(f"  Y: {y_range[0]:.2f} to {y_range[1]:.2f} mm") 
    print(f"  Z: {z_range[0]:.2f} to {z_range[1]:.2f} mm")
    
    # Define detector geometry
    angular_spacing = 2 * np.pi / n_circumferential  # radians per detector
    axial_spacing = axial_length / n_axial  # mm per detector row
    z_min = -axial_length / 2  # Center the axial range
    z_max = axial_length / 2
    
    print(f"\nDetector geometry:")
    print(f"  Circumferential detectors: {n_circumferential}")
    print(f"  Axial detectors: {n_axial}")
    print(f"  Angular spacing: {np.degrees(angular_spacing):.2f} degrees")
    print(f"  Axial spacing: {axial_spacing:.2f} mm")
    print(f"  Scanner radius: {scanner_radius} mm")
    print(f"  Axial range: {z_min} to {z_max} mm")
    
    def xyz_to_detector_id(x, y, z):
        """Convert XYZ coordinates to detector ID"""
        # Convert to cylindrical coordinates
        r = np.sqrt(x**2 + y**2)
        theta = np.arctan2(y, x)  # Returns [-π, π]
        
        # Normalize theta to [0, 2π]
        theta = np.where(theta < 0, theta + 2*np.pi, theta)
        
        # Bin to nearest angular detector
        angular_bin = np.round(theta / angular_spacing).astype(int)
        # Handle wraparound (detector 80 -> detector 0)
        angular_bin = angular_bin % n_circumferential
        
        # Bin to nearest axial detector
        # Clamp z values to valid range
        z_clamped = np.clip(z, z_min, z_max)
        axial_bin = np.round((z_clamped - z_min) / axial_spacing).astype(int)
        # Ensure within bounds
        axial_bin = np.clip(axial_bin, 0, n_axial - 1)
        
        # Convert to linear detector ID
        detector_id = axial_bin * n_circumferential + angular_bin
        
        return detector_id, angular_bin, axial_bin, r
    
    # Process both endpoints of each coincidence pair
    x1, y1, z1 = coincidence_data[:, 0], coincidence_data[:, 1], coincidence_data[:, 2]
    x2, y2, z2 = coincidence_data[:, 3], coincidence_data[:, 4], coincidence_data[:, 5]
    
    # Convert to detector IDs
    det_id_1, ang_1, ax_1, r_1 = xyz_to_detector_id(x1, y1, z1)
    det_id_2, ang_2, ax_2, r_2 = xyz_to_detector_id(x2, y2, z2)
    
    # Create detector pairs array - pytomography expects flattened detector IDs
    # For coincidence pairs, we need to interleave the detector IDs
    detector_ids_flat = np.empty(n_pairs * 2, dtype=int)
    detector_ids_flat[0::2] = det_id_1  # Even indices get first detector
    detector_ids_flat[1::2] = det_id_2  # Odd indices get second detector
    
    # Check for points significantly outside scanner radius
    outside_1 = r_1 > scanner_radius * 1.1  # 10% tolerance
    outside_2 = r_2 > scanner_radius * 1.1
    n_outside = np.sum(outside_1 | outside_2)
    
    if n_outside > 0:
        print(f"\nWarning: {n_outside} points are >10% outside scanner radius")
        print(f"  Max radius point 1: {r_1.max():.2f} mm")
        print(f"  Max radius point 2: {r_2.max():.2f} mm")
    
    # Create scanner lookup table (LUT) compatible with pytomography
    # Format: [angle_index, axial_index, circumferential_index]
    # For cylindrical scanners, we typically have 2 angles (0, 1) for opposing detector pairs
    angle_indices = np.array([0, 1])  # Two opposing angles
    axial_indices = np.arange(n_axial)
    circumferential_indices = np.arange(n_circumferential)
    
    # Create cartesian product for scanner LUT
    scanner_LUT = np.array(np.meshgrid(angle_indices, axial_indices, circumferential_indices, indexing='ij')).T.reshape(-1, 3)
    
    # Create detector geometry information for pytomography
    detector_positions = {
        'n_detectors_total': len(scanner_LUT),
        'n_circumferential': n_circumferential,
        'n_axial': n_axial,
        'scanner_radius': scanner_radius,
        'axial_length': axial_length,
        'angular_spacing_deg': np.degrees(angular_spacing),
        'axial_spacing_mm': axial_spacing,
        'scanner_LUT': scanner_LUT,
        'detector_ids': detector_ids_flat,
        'shape': (2, n_axial, n_circumferential)
    }
    
    # Statistics
    unique_det_1 = len(np.unique(det_id_1))
    unique_det_2 = len(np.unique(det_id_2))
    unique_detector_ids = len(np.unique(detector_ids_flat))
    
    print(f"\nBinning results:")
    print(f"  Unique detectors hit (endpoint 1): {unique_det_1}")
    print(f"  Unique detectors hit (endpoint 2): {unique_det_2}")
    print(f"  Total unique detector IDs: {unique_detector_ids}")
    print(f"  Total detector events: {len(detector_ids_flat)}")
    print(f"  Detector ID range: {detector_ids_flat.min()} to {detector_ids_flat.max()}")
    print(f"  Scanner LUT shape: {scanner_LUT.shape}")
    
    return detector_ids_flat, scanner_LUT, detector_positions


detector_ids, scanner_LUT, geometry = bin_to_detectors(filtered_coordinates,
                                                       n_circumferential=80,
                                                       n_axial=200,
                                                       scanner_radius=280,
                                                       axial_length=300)



# You can then create the pytomography projection metadata:
# proj_meta_listmode = EXSListmodeProjMeta(
#     shape=geometry['shape'], 
#     scanner_LUT=scanner_LUT, 
#     detector_ids=detector_ids, 
#     sensitivity_factor=your_sensitivity_matrix
# )

Processing 6591 coincidence pairs
Input coordinate ranges:
  X: -278.13 to 278.17 mm
  Y: -278.42 to 277.88 mm
  Z: -147.99 to 147.95 mm

Detector geometry:
  Circumferential detectors: 80
  Axial detectors: 200
  Angular spacing: 4.50 degrees
  Axial spacing: 1.50 mm
  Scanner radius: 280 mm
  Axial range: -150.0 to 150.0 mm

Binning results:
  Unique detectors hit (endpoint 1): 5216
  Unique detectors hit (endpoint 2): 5119
  Total unique detector IDs: 8468
  Total detector events: 13182
  Detector ID range: 91 to 15985
  Scanner LUT shape: (32000, 3)
