In [1]:
###############################################################################
# Initial Loading, Filtering, and Coordinate Range Calculation
###############################################################################
import numpy as np

# Load coordinates
coordinates = np.load(fr"C:\Users\h\Desktop\PetStuff\Image_Processing\ground_truth.npy")

# Confirm shape should be (pairs, coords=6), coords are (x1, y1, z1, x2, y2, z2)
print(f"\nData Shape (pairs, coords) : {coordinates.shape}\n")  

# Remove pairs where any coordinate value is exactly 0
filtered_coordinates = coordinates[~np.any(coordinates == 0, axis=1)]
print(f"Filtered shape: {filtered_coordinates.shape}\n")

# Extract all x, y, z pairs - Coordinates are in the order (x1, y1, z1, x2, y2, z2)
all_xyz = filtered_coordinates.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
print(f"x range: min={x_vals.min()}, max={x_vals.max()}")
print(f"y range: min={y_vals.min()}, max={y_vals.max()}")
print(f"z range: min={z_vals.min()}, max={z_vals.max()}")


Data Shape (pairs, coords) : (62660, 6)

Filtered shape: (6591, 6)

x range: min=-278.12942504882807, max=278.1666564941406
y range: min=-278.41946411132807, max=277.8843688964844
z range: min=-147.99453735351562, max=147.9492950439453


In [2]:
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import interact, IntSlider, FloatSlider

def visualize_voxel_tensor_3d(voxel_tensor, initial_threshold=None, voxel_size_mm=1.0, world_origin=None):
    """
    Interactive 3D visualization of voxel tensor with threshold slider.
    
    Args:
        voxel_tensor: (nx, ny, nz) numpy array with voxel counts
        initial_threshold: Initial threshold value for the slider (default: halfway between min and max)
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0) (optional)
    """
    # Extract non-zero voxel coordinates and values
    coords = np.where(voxel_tensor > 0)
    x_coords, y_coords, z_coords = coords
    values = voxel_tensor[coords]
    
    # Convert voxel indices to world coordinates if world_origin provided
    if world_origin is not None:
        x_min, y_min, z_min = world_origin
        x_coords_world = x_coords * voxel_size_mm + x_min
        y_coords_world = y_coords * voxel_size_mm + y_min
        z_coords_world = z_coords * voxel_size_mm + z_min
        coord_suffix = " (mm)"
    else:
        x_coords_world = x_coords * voxel_size_mm
        y_coords_world = y_coords * voxel_size_mm
        z_coords_world = z_coords * voxel_size_mm
        coord_suffix = f" (×{voxel_size_mm}mm)"
    
    # Get value range for slider
    min_val = int(np.min(values))
    max_val = int(np.max(values))
    
    # Set initial threshold to halfway between min and max if not provided
    if initial_threshold is None:
        initial_threshold = int((min_val + max_val) / 2)
    else:
        initial_threshold = max(min_val, min(max_val, int(initial_threshold)))  # Clamp to valid range
    
    print(f"Voxel value range: {min_val} to {max_val}")
    print(f"Total non-zero voxels: {len(values)}")
    print(f"Initial threshold: {initial_threshold}")
    print(f"Voxel resolution: {voxel_size_mm}mm")
    
    def update_plot(threshold):
        # Filter voxels above threshold
        mask = values >= threshold
        if not np.any(mask):
            print(f"No voxels above threshold {threshold}")
            return
        
        filtered_x = x_coords_world[mask]
        filtered_y = y_coords_world[mask]
        filtered_z = z_coords_world[mask]
        filtered_values = values[mask]
        
        # Create 3D scatter plot
        fig = go.Figure(data=go.Scatter3d(
            x=filtered_x,
            y=filtered_y,
            z=filtered_z,
            mode='markers',
            marker=dict(
                size=3,
                color=filtered_values,
                colorscale='Viridis',
                opacity=0.8,
                colorbar=dict(title="Voxel Count"),
                line=dict(width=0)
            ),
            text=[f'Count: {v}' for v in filtered_values],
            hovertemplate='<b>Voxel (%{x:.1f}, %{y:.1f}, %{z:.1f})</b><br>%{text}<extra></extra>'
        ))
        
        # Set layout for orbital controls
        fig.update_layout(
            title=f'3D Voxel Visualization (Threshold: {threshold}, Showing: {len(filtered_values)} voxels)',
            scene=dict(
                xaxis_title=f'X{coord_suffix}',
                yaxis_title=f'Y{coord_suffix}',
                zaxis_title=f'Z{coord_suffix}',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                aspectmode='cube'
            ),
            width=800,
            height=600
        )
        
        fig.show()
    
    # Create interactive slider with initial threshold
    threshold_slider =  FloatSlider(
        value=initial_threshold,
        min=min_val,
        max=max_val,
        step=1,
        description='Threshold:',
        continuous_update=False
    )
    
    # Display interactive widget
    interact(update_plot, threshold=threshold_slider)




In [3]:
# Rasterisation function
def binary_rasterize_lors_3d_dda(pairs_coords, voxel_size_mm=1.0):
    """
    Rasterize lines of response into 3D voxel space using exact 3D DDA traversal.
    
    Args:
        pairs_coords: (N, 6) array where each row is [x1,y1,z1,x2,y2,z2]
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
    
    Returns:
        voxel_tensor: (nx, ny, nz) array with line traversal counts
    """
    # World coordinate ranges (mm)
    all_xyz = pairs_coords.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]

    x_min = x_vals.min()
    x_max = x_vals.max()
    y_min = y_vals.min()
    y_max = y_vals.max()
    z_min = z_vals.min()
    z_max = z_vals.max()

    # Round the min and max values to voxel boundaries
    x_min = np.floor(x_min / voxel_size_mm) * voxel_size_mm
    x_max = np.ceil(x_max / voxel_size_mm) * voxel_size_mm
    y_min = np.floor(y_min / voxel_size_mm) * voxel_size_mm
    y_max = np.ceil(y_max / voxel_size_mm) * voxel_size_mm
    z_min = np.floor(z_min / voxel_size_mm) * voxel_size_mm
    z_max = np.ceil(z_max / voxel_size_mm) * voxel_size_mm

    # Calculate voxel tensor shape based on resolution
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    voxel_shape = (nx, ny, nz)
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: {voxel_shape} (nx, ny, nz)")
    
    voxel_tensor = np.zeros(voxel_shape, dtype=np.int32)
    
    # Coordinate transformation: world -> voxel indices
    def world_to_voxel(coords):
        x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
        vx = (x - x_min) / voxel_size_mm
        vy = (y - y_min) / voxel_size_mm
        vz = (z - z_min) / voxel_size_mm
        return np.column_stack([vx, vy, vz])
    
    def dda_3d(p1, p2):
        """3D DDA algorithm - returns list of (x,y,z) voxel indices along line"""
        x1, y1, z1 = p1
        x2, y2, z2 = p2
        
        # Direction and step sizes
        dx = x2 - x1
        dy = y2 - y1
        dz = z2 - z1
        
        # Number of steps is maximum of absolute differences
        steps = int(max(abs(dx), abs(dy), abs(dz)))
        if steps == 0:
            return [(int(round(x1)), int(round(y1)), int(round(z1)))]
        
        # Step increments
        x_inc = dx / steps
        y_inc = dy / steps
        z_inc = dz / steps
        
        voxels = []
        x, y, z = x1, y1, z1
        
        for _ in range(steps + 1):
            vx, vy, vz = int(round(x)), int(round(y)), int(round(z))
            
            # Check bounds
            if 0 <= vx < nx and 0 <= vy < ny and 0 <= vz < nz:
                voxels.append((vx, vy, vz))
            
            x += x_inc
            y += y_inc
            z += z_inc
        
        return voxels
    
    # Transform all coordinates to voxel space
    points1 = world_to_voxel(pairs_coords[:, :3])  # First points
    points2 = world_to_voxel(pairs_coords[:, 3:])  # Second points
    
    # Process each LOR
    for i in range(len(pairs_coords)):
        p1 = points1[i]
        p2 = points2[i]
        
        # Get voxels along this line
        voxels = dda_3d(p1, p2)
        
        # Increment voxel counts
        for vx, vy, vz in voxels:
            voxel_tensor[vx, vy, vz] += 1
    
    return voxel_tensor

voxel_tensor = binary_rasterize_lors_3d_dda(filtered_coordinates, voxel_size_mm=1)

print(f"Voxel tensor shape: {voxel_tensor.shape}")

Voxel size: 1mm
Voxel shape: (558, 557, 296) (nx, ny, nz)
Voxel tensor shape: (558, 557, 296)


In [4]:
import cupy as cp
import numpy as np

def rasterize_lors_vectorized_gpu(pairs_coords, voxel_size_mm=1.0, samples_per_mm=3):
    """
    GPU-accelerated line rasterization using vectorized sampling.
    
    Args:
        pairs_coords: (N, 6) array where each row is [x1,y1,z1,x2,y2,z2]
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        samples_per_mm: Number of samples per mm along each line (default: 3)
                       Higher = more accurate but slower
    
    Returns:
        voxel_tensor: (nx, ny, nz) array with line traversal counts
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0)
    """
    print("Moving data to GPU...")
    pairs_gpu = cp.asarray(pairs_coords)
    
    # Calculate world bounds
    all_xyz = pairs_gpu.reshape(-1, 3)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
    
    x_min = float(cp.floor(cp.min(x_vals) / voxel_size_mm) * voxel_size_mm)
    x_max = float(cp.ceil(cp.max(x_vals) / voxel_size_mm) * voxel_size_mm)
    y_min = float(cp.floor(cp.min(y_vals) / voxel_size_mm) * voxel_size_mm)
    y_max = float(cp.ceil(cp.max(y_vals) / voxel_size_mm) * voxel_size_mm)
    z_min = float(cp.floor(cp.min(z_vals) / voxel_size_mm) * voxel_size_mm)
    z_max = float(cp.ceil(cp.max(z_vals) / voxel_size_mm) * voxel_size_mm)
    
    # Calculate voxel tensor shape
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: ({nx}, {ny}, {nz})")
    print(f"Samples per mm: {samples_per_mm}")
    
    # Extract line endpoints
    p1 = pairs_gpu[:, :3]  # (N, 3) start points
    p2 = pairs_gpu[:, 3:]  # (N, 3) end points
    
    # Calculate line properties
    line_vectors = p2 - p1  # (N, 3)
    line_lengths = cp.linalg.norm(line_vectors, axis=1)  # (N,)
    
    # Calculate number of samples per line
    n_samples_per_line = cp.maximum(1, (line_lengths * samples_per_mm).astype(cp.int32))
    max_samples = int(cp.max(n_samples_per_line))
    
    print(f"Max samples per line: {max_samples}")
    print(f"Total lines: {len(pairs_coords)}")
    
    # Create voxel tensor on GPU
    voxel_tensor_gpu = cp.zeros((nx, ny, nz), dtype=cp.int32)
    
    # Process lines in chunks to manage memory
    chunk_size = min(1000, len(pairs_coords))  # Adjust based on GPU memory
    
    for chunk_start in range(0, len(pairs_coords), chunk_size):
        chunk_end = min(chunk_start + chunk_size, len(pairs_coords))
        chunk_size_actual = chunk_end - chunk_start
        
        print(f"Processing chunk {chunk_start//chunk_size + 1}/{(len(pairs_coords)-1)//chunk_size + 1}")
        
        # Get chunk data
        p1_chunk = p1[chunk_start:chunk_end]  # (chunk_size, 3)
        p2_chunk = p2[chunk_start:chunk_end]  # (chunk_size, 3)
        line_vectors_chunk = line_vectors[chunk_start:chunk_end]  # (chunk_size, 3)
        n_samples_chunk = n_samples_per_line[chunk_start:chunk_end]  # (chunk_size,)
        
        # Create parameter array for sampling
        t_values = cp.linspace(0, 1, max_samples)  # (max_samples,)
        
        # Vectorized sampling: (max_samples, chunk_size, 3)
        sample_points = (p1_chunk[None, :, :] + 
                        t_values[:, None, None] * line_vectors_chunk[None, :, :])
        
        # Mask for valid samples (within line length)
        sample_mask = t_values[:, None] <= (n_samples_chunk[None, :] - 1) / cp.maximum(1, n_samples_chunk[None, :] - 1)
        
        # Convert to voxel indices
        vx = ((sample_points[:, :, 0] - x_min) / voxel_size_mm).astype(cp.int32)
        vy = ((sample_points[:, :, 1] - y_min) / voxel_size_mm).astype(cp.int32)
        vz = ((sample_points[:, :, 2] - z_min) / voxel_size_mm).astype(cp.int32)
        
        # Bounds checking
        valid_mask = (sample_mask & 
                     (vx >= 0) & (vx < nx) & 
                     (vy >= 0) & (vy < ny) & 
                     (vz >= 0) & (vz < nz))
        
        # Get valid indices
        valid_vx = vx[valid_mask]
        valid_vy = vy[valid_mask]
        valid_vz = vz[valid_mask]
        
        # Accumulate in voxel tensor using atomic operations
        if len(valid_vx) > 0:
            # Use bincount for efficient accumulation
            flat_indices = valid_vx * (ny * nz) + valid_vy * nz + valid_vz
            flat_tensor = voxel_tensor_gpu.ravel()
            
            # Count occurrences and add to tensor
            counts = cp.bincount(flat_indices, minlength=nx*ny*nz)
            flat_tensor += counts
    
    print("Moving result back to CPU...")
    voxel_tensor = cp.asnumpy(voxel_tensor_gpu)
    world_origin = (x_min, y_min, z_min)
    
    return voxel_tensor, world_origin

def rasterize_lors_vectorized_cpu_fallback(pairs_coords, voxel_size_mm=1.0, samples_per_mm=3):
    """
    CPU fallback version if CuPy is not available.
    """
    print("Using CPU fallback (install cupy for GPU acceleration)")
    
    # Calculate world bounds
    all_xyz = pairs_coords.reshape(-1, 3)
    x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
    
    x_min = np.floor(np.min(x_vals) / voxel_size_mm) * voxel_size_mm
    x_max = np.ceil(np.max(x_vals) / voxel_size_mm) * voxel_size_mm
    y_min = np.floor(np.min(y_vals) / voxel_size_mm) * voxel_size_mm
    y_max = np.ceil(np.max(y_vals) / voxel_size_mm) * voxel_size_mm
    z_min = np.floor(np.min(z_vals) / voxel_size_mm) * voxel_size_mm
    z_max = np.ceil(np.max(z_vals) / voxel_size_mm) * voxel_size_mm
    
    nx = int((x_max - x_min) / voxel_size_mm)
    ny = int((y_max - y_min) / voxel_size_mm)
    nz = int((z_max - z_min) / voxel_size_mm)
    
    print(f"Voxel size: {voxel_size_mm}mm")
    print(f"Voxel shape: ({nx}, {ny}, {nz})")
    
    voxel_tensor = np.zeros((nx, ny, nz), dtype=np.int32)
    
    # Process each line
    for i, line in enumerate(pairs_coords):
        if i % 10000 == 0:
            print(f"Processing line {i}/{len(pairs_coords)}")
            
        p1, p2 = line[:3], line[3:]
        line_length = np.linalg.norm(p2 - p1)
        n_samples = max(1, int(line_length * samples_per_mm))
        
        # Sample along line
        t_values = np.linspace(0, 1, n_samples)
        sample_points = p1[None, :] + t_values[:, None] * (p2 - p1)[None, :]
        
        # Convert to voxel indices
        vx = ((sample_points[:, 0] - x_min) / voxel_size_mm).astype(int)
        vy = ((sample_points[:, 1] - y_min) / voxel_size_mm).astype(int)
        vz = ((sample_points[:, 2] - z_min) / voxel_size_mm).astype(int)
        
        # Bounds checking and accumulation
        valid = ((vx >= 0) & (vx < nx) & 
                (vy >= 0) & (vy < ny) & 
                (vz >= 0) & (vz < nz))
        
        if np.any(valid):
            for j in np.where(valid)[0]:
                voxel_tensor[vx[j], vy[j], vz[j]] += 1
    
    return voxel_tensor, (x_min, y_min, z_min)

# Smart wrapper that tries GPU first, falls back to CPU
def rasterize_lors_vectorized(pairs_coords, voxel_size_mm=1.0, samples_per_mm=3):
    """
    Rasterize LORs using vectorized sampling with automatic GPU/CPU selection.
    """
    try:
        import cupy as cp
        return rasterize_lors_vectorized_gpu(pairs_coords, voxel_size_mm, samples_per_mm)
    except ImportError:
        print("CuPy not available, using CPU version")
        return rasterize_lors_vectorized_cpu_fallback(pairs_coords, voxel_size_mm, samples_per_mm)
    except Exception as e:
        print(f"GPU version failed ({e}), falling back to CPU")
        return rasterize_lors_vectorized_cpu_fallback(pairs_coords, voxel_size_mm, samples_per_mm)

voxel_tensor, world_origin = rasterize_lors_vectorized(filtered_coordinates, voxel_size_mm=5, samples_per_mm=5)


Moving data to GPU...
Voxel size: 5mm
Voxel shape: (112, 112, 60)
Samples per mm: 5
Max samples per line: 2970
Total lines: 6591
Processing chunk 1/7
Processing chunk 2/7
Processing chunk 3/7
Processing chunk 4/7
Processing chunk 5/7
Processing chunk 6/7
Processing chunk 7/7
Moving result back to CPU...


In [5]:
visualize_voxel_tensor_3d(voxel_tensor, voxel_size_mm=5, world_origin=world_origin)

Voxel value range: 1 to 3003
Total non-zero voxels: 265513
Initial threshold: 1502
Voxel resolution: 5mm


interactive(children=(FloatSlider(value=1502.0, continuous_update=False, description='Threshold:', max=3003.0,…

In [None]:
from scipy.ndimage import gaussian_filter

def gaussian_blur_3d(voxel_tensor, sigma=1):
    """
    Apply a 3D Gaussian blur to the voxel tensor.

    Args:
        voxel_tensor: 3D numpy array to blur.
        sigma: Standard deviation for Gaussian kernel. Can be a float or a tuple of 3 floats.

    Returns:
        Blurred 3D numpy array.
    """
    blurred = gaussian_filter(voxel_tensor, sigma=sigma)
    return blurred

# Example usage:
blurred_tensor = gaussian_blur_3d(voxel_tensor, sigma=2)

visualize_voxel_tensor_3d(blurred_tensor, voxel_size_mm=5, world_origin=world_origin)


Voxel value range: 1 to 222
Total non-zero voxels: 628886
Initial threshold: 111
Voxel resolution: 5mm


interactive(children=(FloatSlider(value=111.0, continuous_update=False, description='Threshold:', max=222.0, m…

In [None]:
import numpy as np

try:
    import cupy as cp
    GPU_AVAILABLE = True
    print("GPU acceleration available")
except ImportError:
    GPU_AVAILABLE = False
    print("Using CPU only")

def rasterize_lors_gpu(pairs_coords, voxel_size=1.0, samples_per_mm=3):
    """Optimized GPU LOR rasterization"""
    pairs_gpu = cp.asarray(pairs_coords, dtype=cp.float32)
    
    # Calculate bounds
    all_coords = pairs_gpu.reshape(-1, 3)
    mins = cp.floor(cp.min(all_coords, axis=0) / voxel_size) * voxel_size
    maxs = cp.ceil(cp.max(all_coords, axis=0) / voxel_size) * voxel_size
    
    # Grid dimensions
    grid_size = ((maxs - mins) / voxel_size).astype(cp.int32)
    nx, ny, nz = int(grid_size[0]), int(grid_size[1]), int(grid_size[2])
    
    # Line properties
    p1, p2 = pairs_gpu[:, :3], pairs_gpu[:, 3:]
    lengths = cp.linalg.norm(p2 - p1, axis=1)
    n_samples = cp.maximum(2, (lengths * samples_per_mm).astype(cp.int32))
    max_samples = int(cp.max(n_samples))
    
    # Sample all lines at once
    t = cp.linspace(0, 1, max_samples, dtype=cp.float32)[:, None, None]  # (max_samples, 1, 1)
    directions = p2 - p1  # (N, 3)
    
    # Broadcast: (max_samples, N, 3)
    points = p1[None, :, :] + t * directions[None, :, :]
    
    # Valid sample mask
    sample_ids = cp.arange(max_samples)[:, None]  # (max_samples, 1)
    valid_mask = sample_ids < n_samples[None, :]  # (max_samples, N)
    
    # Convert to voxel indices
    voxel_coords = (points - mins[None, None, :]) / voxel_size
    vx = voxel_coords[:, :, 0].astype(cp.int32)
    vy = voxel_coords[:, :, 1].astype(cp.int32)
    vz = voxel_coords[:, :, 2].astype(cp.int32)
    
    # Bounds check
    bounds_mask = ((vx >= 0) & (vx < nx) & 
                   (vy >= 0) & (vy < ny) & 
                   (vz >= 0) & (vz < nz))
    
    final_mask = valid_mask & bounds_mask
    
    # Get valid indices
    valid_vx = vx[final_mask]
    valid_vy = vy[final_mask]
    valid_vz = vz[final_mask]
    
    # Accumulate using bincount
    flat_indices = valid_vx * (ny * nz) + valid_vy * nz + valid_vz
    voxel_tensor = cp.bincount(flat_indices, minlength=nx*ny*nz).reshape(nx, ny, nz)
    
    return voxel_tensor, mins

def rasterize_lors_cpu(pairs_coords, voxel_size=1.0, samples_per_mm=3):
    """Optimized CPU LOR rasterization"""
    # Calculate bounds
    all_coords = pairs_coords.reshape(-1, 3)
    mins = np.floor(np.min(all_coords, axis=0) / voxel_size) * voxel_size
    maxs = np.ceil(np.max(all_coords, axis=0) / voxel_size) * voxel_size
    
    grid_size = ((maxs - mins) / voxel_size).astype(int)
    nx, ny, nz = grid_size[0], grid_size[1], grid_size[2]
    
    voxel_tensor = np.zeros((nx, ny, nz), dtype=np.int32)
    
    for line in pairs_coords:
        p1, p2 = line[:3], line[3:]
        length = np.linalg.norm(p2 - p1)
        n_samples = max(2, int(length * samples_per_mm))
        
        t = np.linspace(0, 1, n_samples)
        points = p1[None, :] + t[:, None] * (p2 - p1)[None, :]
        
        # Convert to voxel indices
        voxel_coords = (points - mins[None, :]) / voxel_size
        vx = voxel_coords[:, 0].astype(int)
        vy = voxel_coords[:, 1].astype(int)
        vz = voxel_coords[:, 2].astype(int)
        
        # Bounds check and accumulate
        valid = ((vx >= 0) & (vx < nx) & 
                (vy >= 0) & (vy < ny) & 
                (vz >= 0) & (vz < nz))
        
        voxel_tensor[vx[valid], vy[valid], vz[valid]] += 1
    
    return voxel_tensor, mins

class MLEMReconstructor:
    def __init__(self, use_gpu=None):
        if use_gpu is None:
            self.use_gpu = GPU_AVAILABLE
        else:
            self.use_gpu = use_gpu and GPU_AVAILABLE
            
        self.xp = cp if self.use_gpu else np
        print(f"Using {'GPU' if self.use_gpu else 'CPU'}")
    
    def rasterize_lors(self, pairs_coords, voxel_size=1.0):
        """Unified LOR rasterization"""
        if self.use_gpu:
            return rasterize_lors_gpu(pairs_coords, voxel_size)
        else:
            return rasterize_lors_cpu(pairs_coords, voxel_size)
    
    def forward_project_vectorized(self, image, lors, world_origin, voxel_size=1.0, attenuation_map=None):
        """Vectorized forward projection with attenuation correction"""
        if self.use_gpu and not isinstance(image, cp.ndarray):
            image = cp.asarray(image)
        if attenuation_map is not None and self.use_gpu and not isinstance(attenuation_map, cp.ndarray):
            attenuation_map = cp.asarray(attenuation_map)
        
        # Process all LORs at once
        p1, p2 = lors[:, :3], lors[:, 3:]
        lengths = self.xp.linalg.norm(p2 - p1, axis=1)
        
        # Use fixed sampling rate for all lines
        n_samples = 8  # More samples for better attenuation modeling
        t = self.xp.linspace(0, 1, n_samples, dtype=self.xp.float32)[:, None, None]
        
        # Sample all lines: (n_samples, N_lors, 3)
        points = p1[None, :, :] + t * (p2 - p1)[None, :, :]
        
        # Convert to voxel indices
        voxel_coords = (points - world_origin[None, None, :]) / voxel_size
        vx = self.xp.floor(voxel_coords[:, :, 0]).astype(int)
        vy = self.xp.floor(voxel_coords[:, :, 1]).astype(int)
        vz = self.xp.floor(voxel_coords[:, :, 2]).astype(int)
        
        # Bounds check
        valid = ((vx >= 0) & (vx < image.shape[0]) & 
                (vy >= 0) & (vy < image.shape[1]) & 
                (vz >= 0) & (vz < image.shape[2]))
        
        # Sum activity values along each LOR
        expected = self.xp.zeros(len(lors), dtype=self.xp.float32)
        attenuation_factors = self.xp.ones(len(lors), dtype=self.xp.float32)
        
        for i in range(n_samples):
            mask = valid[i]
            if self.xp.any(mask):
                # Add activity contribution
                expected[mask] += image[vx[i][mask], vy[i][mask], vz[i][mask]]
                
                # Accumulate attenuation along path (more conservative)
                if attenuation_map is not None:
                    # Clamp attenuation to reasonable values
                    mu = self.xp.clip(attenuation_map[vx[i][mask], vy[i][mask], vz[i][mask]], 0.0, 1.0)
                    # Path length per sample
                    dl = lengths[mask] / n_samples
                    # More conservative attenuation: exp(-mu * dl * 0.1)
                    attenuation_factors[mask] *= self.xp.exp(-mu * dl * 0.1)
        
        # Apply attenuation correction to expected counts (clamp to prevent zeros)
        if attenuation_map is not None:
            # Prevent attenuation from going to zero
            attenuation_factors = self.xp.maximum(attenuation_factors, 0.01)
            expected *= attenuation_factors
        
        # Ensure minimum expected counts to prevent instability
        expected = self.xp.maximum(expected, 0.01)
        
        return expected
    
    def backproject_vectorized(self, ratios, lors, image_shape, world_origin, voxel_size=1.0, attenuation_map=None):
        """Vectorized backprojection with attenuation correction"""
        update = self.xp.ones(image_shape, dtype=self.xp.float32)
        
        # Clamp ratios to prevent instability
        ratios = self.xp.clip(ratios, 0.1, 10.0)  # Reasonable ratio bounds
        
        # Process all LORs at once with fixed sampling
        p1, p2 = lors[:, :3], lors[:, 3:]
        lengths = self.xp.linalg.norm(p2 - p1, axis=1)
        n_samples = 8  # Match forward projection
        t = self.xp.linspace(0, 1, n_samples, dtype=self.xp.float32)[:, None, None]
        
        # Sample all lines: (n_samples, N_lors, 3)
        points = p1[None, :, :] + t * (p2 - p1)[None, :, :]
        
        # Convert to voxel indices
        voxel_coords = (points - world_origin[None, None, :]) / voxel_size
        vx = self.xp.floor(voxel_coords[:, :, 0]).astype(int)
        vy = self.xp.floor(voxel_coords[:, :, 1]).astype(int)
        vz = self.xp.floor(voxel_coords[:, :, 2]).astype(int)
        
        # Bounds check
        valid = ((vx >= 0) & (vx < image_shape[0]) & 
                (vy >= 0) & (vy < image_shape[1]) & 
                (vz >= 0) & (vz < image_shape[2]))
        
        # Calculate attenuation factors for each LOR if provided
        attenuation_factors = self.xp.ones(len(lors), dtype=self.xp.float32)
        if attenuation_map is not None:
            if self.use_gpu and not isinstance(attenuation_map, cp.ndarray):
                attenuation_map = cp.asarray(attenuation_map)
            
            for i in range(n_samples):
                mask = valid[i]
                if self.xp.any(mask):
                    mu = self.xp.clip(attenuation_map[vx[i][mask], vy[i][mask], vz[i][mask]], 0.0, 1.0)
                    dl = lengths[mask] / n_samples
                    attenuation_factors[mask] *= self.xp.exp(-mu * dl * 0.1)
            
            # Prevent attenuation from going to zero
            attenuation_factors = self.xp.maximum(attenuation_factors, 0.01)
        
        # Apply corrected ratios to all samples along each LOR
        corrected_ratios = ratios * attenuation_factors
        
        for i in range(n_samples):
            mask = valid[i]
            if self.xp.any(mask):
                valid_ratios = corrected_ratios[mask]
                update[vx[i][mask], vy[i][mask], vz[i][mask]] *= valid_ratios
        
        return update
    
    def reconstruct(self, lors, measured_counts, voxel_size=1.0, iterations=10, attenuation_map=None):
        """Optimized MLEM reconstruction with optional attenuation correction"""
        print(f"Starting MLEM reconstruction: {len(lors)} LORs, {iterations} iterations")
        if attenuation_map is not None:
            print("Using attenuation correction")
        
        # Convert inputs to appropriate arrays
        lors = self.xp.asarray(lors, dtype=self.xp.float32)
        measured_counts = self.xp.asarray(measured_counts, dtype=self.xp.float32)
        
        # Initial rasterization
        print("Generating initial image...")
        voxel_tensor, world_origin = self.rasterize_lors(lors, voxel_size)
        
        # Initialize image (rasterization + small uniform)
        image = voxel_tensor.astype(self.xp.float32) + 0.1
        
        # Handle attenuation map
        if attenuation_map is not None:
            if self.use_gpu and not isinstance(attenuation_map, cp.ndarray):
                attenuation_map = cp.asarray(attenuation_map)
            
            if attenuation_map.shape != image.shape:
                raise ValueError(f"Attenuation map shape {attenuation_map.shape} must match image shape {image.shape}")
            
            # Apply attenuation prior to initial image
            # Higher attenuation = lower initial activity
            attenuation_prior = self.xp.exp(-attenuation_map * 0.1)  # Gentle prior
            image *= attenuation_prior
        
        print(f"Image shape: {image.shape}")
        print(f"World origin: {world_origin}")
        
        # MLEM iterations
        for it in range(iterations):
            print(f"Iteration {it+1}/{iterations}")
            
            # Forward projection with attenuation
            expected = self.forward_project_vectorized(image, lors, world_origin, voxel_size, attenuation_map)
            expected = self.xp.maximum(expected, 1e-6)  # More conservative minimum
            
            # Calculate ratios with bounds
            ratios = self.xp.clip(measured_counts / expected, 0.01, 100.0)
            
            # Backprojection with attenuation
            update = self.backproject_vectorized(ratios, lors, image.shape, world_origin, voxel_size, attenuation_map)
            
            # Conservative update to prevent explosion
            update = self.xp.clip(update, 0.5, 2.0)
            
            # Update image
            image *= update
            
            # Prevent image explosion
            image = self.xp.clip(image, 0.0, 1000.0)
            
            # Apply attenuation regularization (optional)
            if attenuation_map is not None:
                # Gently suppress activity in high-attenuation regions
                attenuation_regularizer = self.xp.exp(-attenuation_map * 0.01)
                image *= (1.0 - 0.1) + 0.1 * attenuation_regularizer
            
            # Stats
            ratio_stats = f"ratios: {float(self.xp.min(ratios)):.3f} to {float(self.xp.max(ratios)):.3f}"
            image_stats = f"image: {float(self.xp.min(image)):.3f} to {float(self.xp.max(image)):.3f}"
            print(f"  {ratio_stats}, {image_stats}")
        
        # Return CPU version for compatibility
        if self.use_gpu:
            image = cp.asnumpy(image)
            world_origin = cp.asnumpy(world_origin)
        
        return image, world_origin

# Simple interface
def mlem_reconstruct(lors, measured_counts, voxel_size=1.0, iterations=10, use_gpu=None, attenuation_map=None):
    """Simple MLEM reconstruction interface with attenuation correction"""
    reconstructor = MLEMReconstructor(use_gpu=use_gpu)
    return reconstructor.reconstruct(lors, measured_counts, voxel_size, iterations, attenuation_map)





shape = (558, 557, 296)

# Create a 3D array of shape (558, 557, 296) with values increasing from 0 to max_attenuation along the z axis
def create_rectangular_attenuation(shape, max_attenuation):
    return np.broadcast_to(
        np.linspace(0, max_attenuation, shape[2], dtype=np.float32),
        shape
    )

def create_radial_corner_gradient(shape, max_intensity):
    # Create a grid of coordinates
    x = np.arange(shape[0])[:, None, None]
    y = np.arange(shape[1])[None, :, None]
    z = np.arange(shape[2])[None, None, :]
    # Diagonal corner at (0,0,0), opposite at (shape[0]-1, shape[1]-1, shape[2]-1)
    center = np.array([0, 0, 0])
    far_corner = np.array([shape[0]-1, shape[1]-1, shape[2]-1])
    # Compute distance from (0,0,0) for each voxel
    dist = np.sqrt((x - center[0])**2 + (y - center[1])**2 + (z - center[2])**2)
    max_dist = np.linalg.norm(far_corner - center)
    # Intensity decreases from max_intensity at (0,0,0) to 0 at far corner
    intensity = max_intensity * (1 - dist / max_dist)
    intensity = np.clip(intensity, 0, max_intensity)
    return intensity.astype(np.float32)


# Example usage:
max_attenuation = 10  # User-specified maximum attenuation value
rectangular = create_rectangular_attenuation(shape, max_attenuation)
radial_corner_gradient = create_radial_corner_gradient(shape, max_attenuation)


counts = np.random.poisson(100, filtered_coordinates.shape[0])  # Poisson noise
reconstructed, origin = mlem_reconstruct(filtered_coordinates,
                                         counts,
                                         voxel_size=1.0,
                                         attenuation_map=radial_corner_gradient,
                                         iterations=500)

GPU acceleration available
Using GPU
Starting MLEM reconstruction: 6591 LORs, 500 iterations
Using attenuation correction
Generating initial image...
Image shape: (558, 557, 296)
World origin: [-279. -279. -148.]
Iteration 1/500
  ratios: 6.248 to 100.000, image: 0.021 to 79.762
Iteration 2/500
  ratios: 3.262 to 100.000, image: 0.010 to 113.468
Iteration 3/500
  ratios: 1.634 to 100.000, image: 0.005 to 118.089
Iteration 4/500
  ratios: 1.200 to 100.000, image: 0.003 to 118.089
Iteration 5/500
  ratios: 1.486 to 100.000, image: 0.001 to 118.089
Iteration 6/500
  ratios: 1.109 to 100.000, image: 0.001 to 118.089
Iteration 7/500
  ratios: 1.369 to 100.000, image: 0.000 to 118.089
Iteration 8/500
  ratios: 1.111 to 100.000, image: 0.000 to 118.089
Iteration 9/500
  ratios: 1.364 to 100.000, image: 0.000 to 118.089
Iteration 10/500
  ratios: 1.112 to 100.000, image: 0.000 to 118.088
Iteration 11/500
  ratios: 1.364 to 100.000, image: 0.000 to 118.088
Iteration 12/500
  ratios: 1.113 to 10

In [None]:
reconstructed.shape
visualize_voxel_tensor_3d(reconstructed, voxel_size_mm=1, world_origin=world_origin, initial_threshold=20)

Voxel value range: 0 to 118
Total non-zero voxels: 91946780
Initial threshold: 20
Voxel resolution: 1mm


interactive(children=(FloatSlider(value=20.0, continuous_update=False, description='Threshold:', max=118.0, st…