In [None]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional
import math

class CUDAMLEMReconstructor:
    """
    Maximum Likelihood Expectation Maximization for PET reconstruction
    with scatter, randoms, and attenuation corrections using CUDA acceleration.
    """
    
    def __init__(self, lors: torch.Tensor, voxel_size: float = 2.0, device: str = 'cuda'):
        """
        Initialize MLEM reconstructor.
        
        Args:
            lors: Tensor of shape (num_lors, 6) with [x1,y1,z1,x2,y2,z2] endpoints in mm
            voxel_size: Voxel size in mm
            device: 'cuda' or 'cpu'
        """
        self.device = torch.device(device)
        self.lors = lors.to(self.device).float()
        self.voxel_size = voxel_size
        self.num_lors = lors.shape[0]
        
        # Calculate voxel space bounds
        self._calculate_voxel_space()
        
        # Initialize system matrix storage
        self.system_matrix = None
        self.sensitivity_image = None
        
    def _calculate_voxel_space(self):
        """Calculate voxel grid dimensions from LOR coordinates."""
        # Get coordinate bounds with small padding
        coords = self.lors.view(-1, 3)  # Reshape to (2*num_lors, 3)
        min_coords = coords.min(dim=0)[0] - self.voxel_size
        max_coords = coords.max(dim=0)[0] + self.voxel_size
        
        # Calculate grid dimensions
        self.grid_size = ((max_coords - min_coords) / self.voxel_size).ceil().int()
        self.grid_origin = min_coords
        
        print(f"Voxel grid: {self.grid_size.tolist()} voxels")
        print(f"Grid origin: {self.grid_origin.tolist()} mm")
        print(f"Voxel size: {self.voxel_size} mm")
        
    def _siddon_ray_trace(self, lors_batch: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        GPU-accelerated Siddon ray tracing for LOR-voxel intersections.
        
        Returns:
            voxel_indices: Flattened voxel indices
            lor_indices: Corresponding LOR indices  
            weights: Intersection lengths
        """
        batch_size = lors_batch.shape[0]
        
        # Convert LOR endpoints to voxel coordinates
        p1 = (lors_batch[:, :3] - self.grid_origin) / self.voxel_size
        p2 = (lors_batch[:, 3:] - self.grid_origin) / self.voxel_size
        
        # Ray direction and length
        ray_dir = p2 - p1
        ray_length = torch.norm(ray_dir, dim=1, keepdim=True)
        ray_dir = ray_dir / (ray_length + 1e-10)
        
        # Initialize storage for intersections
        max_intersections = int(torch.max(self.grid_size).item()) * 2
        voxel_indices = []
        lor_indices = []
        weights = []
        
        # Process each LOR
        for i in range(batch_size):
            voxels, lens = self._trace_single_ray(p1[i], p2[i], ray_dir[i])
            if len(voxels) > 0:
                voxel_indices.extend(voxels)
                lor_indices.extend([i] * len(voxels))
                weights.extend(lens)
        
        if len(voxel_indices) == 0:
            return torch.empty(0, dtype=torch.long, device=self.device), \
                   torch.empty(0, dtype=torch.long, device=self.device), \
                   torch.empty(0, dtype=torch.float, device=self.device)
        
        return torch.tensor(voxel_indices, device=self.device), \
               torch.tensor(lor_indices, device=self.device), \
               torch.tensor(weights, device=self.device) * self.voxel_size
    
    def _trace_single_ray(self, p1: torch.Tensor, p2: torch.Tensor, ray_dir: torch.Tensor):
        """Trace a single ray through the voxel grid using Siddon's algorithm."""
        voxels = []
        lengths = []
        
        # Current position
        pos = p1.clone()
        end_pos = p2
        
        # Step sizes for each dimension
        step = torch.sign(ray_dir)
        delta = torch.abs(1.0 / (ray_dir + 1e-10))
        
        # Current voxel
        current_voxel = pos.floor().int()
        
        # Distance to next voxel boundary
        if ray_dir[0] > 0:
            next_x = (current_voxel[0] + 1).float()
        else:
            next_x = current_voxel[0].float()
        
        if ray_dir[1] > 0:
            next_y = (current_voxel[1] + 1).float()
        else:
            next_y = current_voxel[1].float()
            
        if ray_dir[2] > 0:
            next_z = (current_voxel[2] + 1).float()
        else:
            next_z = current_voxel[2].float()
        
        t_max = torch.tensor([
            abs((next_x - pos[0]) / (ray_dir[0] + 1e-10)),
            abs((next_y - pos[1]) / (ray_dir[1] + 1e-10)), 
            abs((next_z - pos[2]) / (ray_dir[2] + 1e-10))
        ], device=self.device)
        
        t_current = 0.0
        t_end = torch.norm(end_pos - pos).item()
        
        # Traverse voxels
        while t_current < t_end:
            # Check if current voxel is valid
            if (current_voxel >= 0).all() and (current_voxel < self.grid_size).all():
                # Calculate intersection length
                t_next = min(t_max.min().item(), t_end)
                length = t_next - t_current
                
                if length > 1e-6:
                    voxel_idx = (current_voxel[0] * self.grid_size[1] * self.grid_size[2] + 
                               current_voxel[1] * self.grid_size[2] + current_voxel[2]).item()
                    voxels.append(voxel_idx)
                    lengths.append(length)
            
            # Move to next voxel
            if t_max.min() >= t_end:
                break
                
            min_dim = t_max.argmin()
            t_current = t_max[min_dim].item()
            current_voxel[min_dim] += step[min_dim].int()
            t_max[min_dim] += delta[min_dim]
        
        return voxels, lengths
    
    def build_system_matrix(self, batch_size: int = 10):
        """Build sparse system matrix using batched ray tracing."""
        print("Building system matrix...")
        
        all_voxel_indices = []
        all_lor_indices = []
        all_weights = []
        
        # Process LORs in batches to manage memory
        num_batches = (self.num_lors + batch_size - 1) // batch_size
        
        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, self.num_lors)
            
            lors_batch = self.lors[start_idx:end_idx]
            voxel_idx, lor_idx, weights = self._siddon_ray_trace(lors_batch)
            
            # Adjust LOR indices for global indexing
            lor_idx += start_idx
            
            all_voxel_indices.append(voxel_idx)
            all_lor_indices.append(lor_idx)
            all_weights.append(weights)
            
            if (batch_idx + 1) % 10 == 0:
                print(f"Processed batch {batch_idx + 1}/{num_batches}")
        
        # Concatenate all results
        voxel_indices = torch.cat(all_voxel_indices)
        lor_indices = torch.cat(all_lor_indices)
        weights = torch.cat(all_weights)
        
        # Create sparse system matrix
        total_voxels = self.grid_size.prod().item()
        indices = torch.stack([lor_indices, voxel_indices])
        
        self.system_matrix = torch.sparse_coo_tensor(
            indices, weights, (self.num_lors, total_voxels), device=self.device
        ).coalesce()
        
        print(f"System matrix: {self.system_matrix.shape}, {len(weights)} non-zero elements")
        
        # Calculate sensitivity image (column sums)
        self.sensitivity_image = torch.sparse.sum(self.system_matrix, dim=0).to_dense()
        self.sensitivity_image[self.sensitivity_image == 0] = 1.0  # Avoid division by zero
    
    def estimate_scatter(self, image: torch.Tensor, scatter_fraction: float = 0.15) -> torch.Tensor:
        """
        Simple scatter estimation using convolution-based approach.
        More sophisticated scatter models (Monte Carlo, single scatter simulation) 
        could be implemented here.
        """
        # Reshape image to 3D for convolution
        img_3d = image.view(self.grid_size[0], self.grid_size[1], self.grid_size[2]).unsqueeze(0).unsqueeze(0)
        
        # Simple Gaussian scatter kernel (approximation)
        kernel_size = 7
        sigma = 2.0
        kernel = torch.zeros(1, 1, kernel_size, kernel_size, kernel_size, device=self.device)
        
        # Create 3D Gaussian kernel
        center = kernel_size // 2
        for i in range(kernel_size):
            for j in range(kernel_size):
                for k in range(kernel_size):
                    dist_sq = (i - center)**2 + (j - center)**2 + (k - center)**2
                    kernel[0, 0, i, j, k] = math.exp(-dist_sq / (2 * sigma**2))
        
        kernel = kernel / kernel.sum()
        
        # Apply convolution for scatter estimate
        scattered = F.conv3d(img_3d, kernel, padding=kernel_size//2)
        scattered = scattered.squeeze().flatten()
        
        # Forward project scattered activity
        scatter_sino = torch.sparse.mm(self.system_matrix, scattered.unsqueeze(1)).squeeze()
        
        return scatter_sino * scatter_fraction
    
    def estimate_randoms(self, singles_rate: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Estimate random coincidences. 
        For simplicity, using uniform randoms. In practice, this would use
        singles rates and detector geometry.
        """
        if singles_rate is None:
            # Simple uniform randoms estimate
            randoms_rate = torch.full((self.num_lors,), 0.05, device=self.device)
        else:
            # More sophisticated randoms calculation could be implemented
            randoms_rate = singles_rate * 0.01  # Simplified calculation
        
        return randoms_rate
    
    def apply_attenuation(self, sino: torch.Tensor, attenuation_mask: torch.Tensor) -> torch.Tensor:
        """
        Apply attenuation correction using provided attenuation mask.
        
        Args:
            sino: Sinogram data
            attenuation_mask: 3D attenuation factors, same shape as voxel grid
        """
        # Forward project attenuation mask to get LOR attenuation factors
        atten_flat = attenuation_mask.flatten()
        atten_factors = torch.sparse.mm(self.system_matrix, atten_flat.unsqueeze(1)).squeeze()
        
        # Apply exponential attenuation
        atten_correction = torch.exp(-atten_factors)
        return sino * atten_correction
    
    def forward_project(self, image: torch.Tensor) -> torch.Tensor:
        """Forward project image to sinogram space."""
        return torch.sparse.mm(self.system_matrix, image.unsqueeze(1)).squeeze()
    
    def back_project(self, sino: torch.Tensor) -> torch.Tensor:
        """Back project sinogram to image space."""
        return torch.sparse.mm(self.system_matrix.t(), sino.unsqueeze(1)).squeeze()
    
    def reconstruct(self, measured_data: torch.Tensor, attenuation_mask: torch.Tensor,
                   num_iterations: int = 20, convergence_threshold: float = 1e-4) -> torch.Tensor:
        """
        Perform MLEM reconstruction with scatter, randoms, and attenuation corrections.
        
        Args:
            measured_data: Measured sinogram data (num_lors,)
            attenuation_mask: 3D attenuation correction factors
            num_iterations: Maximum number of iterations
            convergence_threshold: Relative change threshold for convergence
            
        Returns:
            Reconstructed image as 1D tensor
        """
        print("Starting MLEM reconstruction...")
        
        # Build system matrix if not already done
        if self.system_matrix is None:
            self.build_system_matrix()
        
        # Initialize image with uniform activity
        total_voxels = self.grid_size.prod().item()
        image = torch.ones(total_voxels, device=self.device)
        
        # Apply attenuation to measured data
        measured_data = self.apply_attenuation(measured_data.to(self.device), 
                                             attenuation_mask.to(self.device))
        
        prev_likelihood = float('inf')
        
        for iteration in range(num_iterations):
            # Forward project current estimate
            estimated_sino = self.forward_project(image)
            
            # Add scatter and randoms
            scatter = self.estimate_scatter(image)
            randoms = self.estimate_randoms()
            
            total_estimated = estimated_sino + scatter + randoms
            total_estimated = torch.clamp(total_estimated, min=1e-10)  # Avoid division by zero
            
            # Calculate likelihood ratio
            ratio = measured_data / total_estimated
            
            # Back project ratio
            correction = self.back_project(ratio)
            
            # MLEM update
            image = image * correction / self.sensitivity_image
            image = torch.clamp(image, min=0)  # Enforce non-negativity
            
            # Calculate log-likelihood for convergence check
            likelihood = torch.sum(measured_data * torch.log(total_estimated) - total_estimated).item()
            
            # Check convergence
            if iteration > 0:
                rel_change = abs(likelihood - prev_likelihood) / abs(prev_likelihood)
                print(f"Iteration {iteration + 1}: Log-likelihood = {likelihood:.2f}, "
                      f"Rel. change = {rel_change:.6f}")
                
                if rel_change < convergence_threshold:
                    print(f"Converged after {iteration + 1} iterations")
                    break
            else:
                print(f"Iteration {iteration + 1}: Log-likelihood = {likelihood:.2f}")
            
            prev_likelihood = likelihood
        
        return image
    
    def get_image_3d(self, image_1d: torch.Tensor) -> torch.Tensor:
        """Convert 1D image to 3D array."""
        return image_1d.view(self.grid_size[0], self.grid_size[1], self.grid_size[2])

# Example usage
def example_usage():
    """Example of how to use the MLEM reconstructor."""
    
    # Generate example LOR data (normally this would be real PET data)
    num_lors = 500
    lors = torch.randn(num_lors, 6) * 100  # Random LORs in mm
    
    # Generate example measured data
    measured_counts = torch.poisson(torch.ones(num_lors) * 100)
    
    # Initialize reconstructor
    reconstructor = CUDAMLEMReconstructor(lors, voxel_size=2.0, device='cuda')
    
    # Create example attenuation mask (uniform attenuation)
    grid_shape = reconstructor.grid_size
    attenuation_mask = torch.ones(grid_shape[0], grid_shape[1], grid_shape[2]) * 0.1
    
    # Perform reconstruction
    reconstructed_image = reconstructor.reconstruct(
        measured_counts, 
        attenuation_mask,
        num_iterations=10
    )
    
    # Convert to 3D for visualization
    image_3d = reconstructor.get_image_3d(reconstructed_image)
    
    print(f"Reconstructed image shape: {image_3d.shape}")
    print(f"Image statistics: min={image_3d.min():.3f}, max={image_3d.max():.3f}, "
          f"mean={image_3d.mean():.3f}")
    
    return reconstructed_image, image_3d


def MLEM(input_data):
    
    
    # Generate example measured data
    measured_counts = torch.poisson(torch.ones(input_data.shape[0]) * 100)
    
    # Initialize reconstructor
    reconstructor = CUDAMLEMReconstructor(input_data, voxel_size=2.0, device='cuda')
    
    # Create example attenuation mask (uniform attenuation)
    grid_shape = reconstructor.grid_size
    attenuation_mask = torch.ones(grid_shape[0], grid_shape[1], grid_shape[2]) * 0.1
    
    # Perform reconstruction
    reconstructed_image = reconstructor.reconstruct(
        measured_counts, 
        attenuation_mask,
        num_iterations=1
    )
    
    # Convert to 3D for visualization
    image_3d = reconstructor.get_image_3d(reconstructed_image)
    
    print(f"Reconstructed image shape: {image_3d.shape}")
    print(f"Image statistics: min={image_3d.min():.3f}, max={image_3d.max():.3f}, "
          f"mean={image_3d.mean():.3f}")
    
    return reconstructed_image, image_3d


# LOAD LOR DATA
# Initial Loading, Filtering, and Coordinate Range Calculation
import numpy as np

# Load coordinates
coordinates = np.load(fr"C:\Users\h\Desktop\PetStuff\Image_Processing\ground_truth.npy")

# Confirm shape should be (pairs, coords=6), coords are (x1, y1, z1, x2, y2, z2)
print(f"\nData Shape (pairs, coords) : {coordinates.shape}\n")  

# Remove pairs where any coordinate value is exactly 0
filtered_coordinates = coordinates[~np.any(coordinates == 0, axis=1)]
print(f"Filtered shape: {filtered_coordinates.shape}\n")
filtered_coordinates = torch.from_numpy(filtered_coordinates).float()

# Extract all x, y, z pairs - Coordinates are in the order (x1, y1, z1, x2, y2, z2)
all_xyz = filtered_coordinates.reshape(-1, 3) # Reshape to (pairs, 3) for (x, y, z)
x_vals, y_vals, z_vals = all_xyz[:, 0], all_xyz[:, 1], all_xyz[:, 2]
print(f"x range: min={x_vals.min()}, max={x_vals.max()}")
print(f"y range: min={y_vals.min()}, max={y_vals.max()}")
print(f"z range: min={z_vals.min()}, max={z_vals.max()}")

image, image3d = MLEM(filtered_coordinates[:100,:])



Data Shape (pairs, coords) : (62660, 6)

Filtered shape: (6591, 6)

x range: min=-278.1294250488281, max=278.1666564941406
y range: min=-278.4194641113281, max=277.8843688964844
z range: min=-147.99453735351562, max=147.9492950439453
Voxel grid: [278, 276, 149] voxels
Grid origin: [-275.20611572265625, -275.9847412109375, -148.61593627929688] mm
Voxel size: 2.0 mm
Starting MLEM reconstruction...
Building system matrix...
Processed batch 10/10
System matrix: torch.Size([100, 11432472]), 16938 non-zero elements


In [12]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import interact, IntSlider, FloatSlider

# Convert tensor to numpy for plotting
recon_np = image3d.cpu().numpy()
nx, ny, nz = recon_np.shape

def plot_cross_sections_horizontal(x_idx=nx//2, y_idx=ny//2, z_idx=nz//2, vmax=0.5, cmap='Magma'):
    fig = make_subplots(rows=1, cols=3, subplot_titles=[
        f'XY plane @ z={z_idx}',
        f'XZ plane @ y={y_idx}',
        f'YZ plane @ x={x_idx}'
    ])

    # XY plane at z=z_idx
    fig.add_trace(go.Heatmap(
        z=recon_np[:, :, z_idx].T,
        colorscale=cmap,
        zmax=vmax,
        zmin=0,
        showscale=True,
        name=f'XY @ z={z_idx}'
    ), row=1, col=1)

    # XZ plane at y=y_idx
    fig.add_trace(go.Heatmap(
        z=recon_np[:, y_idx, :].T,
        colorscale=cmap,
        zmax=vmax,
        zmin=0,
        showscale=True,
        name=f'XZ @ y={y_idx}'
    ), row=1, col=2)

    # YZ plane at x=x_idx
    fig.add_trace(go.Heatmap(
        z=recon_np[x_idx, :, :].T,
        colorscale=cmap,
        zmax=vmax,
        zmin=0,
        showscale=True,
        name=f'YZ @ x={x_idx}'
    ), row=1, col=3)

    fig.update_layout(
        width=1200,
        height=400,
        title_text="Orthogonal Cross Sections"
    )
    fig.show()

interact(
    plot_cross_sections_horizontal,
    x_idx=IntSlider(min=0, max=nx-1, step=1, value=nx//2, description='X index'),
    y_idx=IntSlider(min=0, max=ny-1, step=1, value=ny//2, description='Y index'),
    z_idx=IntSlider(min=0, max=nz-1, step=1, value=nz//2, description='Z index'),
    vmax=FloatSlider(min=0, max=1, step=0.01, value=0.05, description='vmax'),
    cmap=['Magma','Greys', 'Viridis', 'Cividis', 'Plasma']
)

interactive(children=(IntSlider(value=138, description='X index', max=275), IntSlider(value=137, description='…

<function __main__.plot_cross_sections_horizontal(x_idx=138, y_idx=137, z_idx=72, vmax=0.5, cmap='Magma')>

In [13]:

def visualize_voxel_tensor_3d(voxel_tensor, initial_min_threshold=None, initial_max_threshold=None, 
                               voxel_size_mm=1.0, world_origin=None, min_threshold=None, max_threshold=None):
    """
    Interactive 3D visualization of voxel tensor with dual threshold sliders.

    Args:
        voxel_tensor: (nx, ny, nz) numpy array with voxel counts
        initial_min_threshold: Initial minimum threshold value for the slider (default: min_val)
        initial_max_threshold: Initial maximum threshold value for the slider (default: max_val)
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0) (optional)
        min_threshold: Minimum threshold value for slider range (optional)
        max_threshold: Maximum threshold value for slider range (optional)
    """
    # Extract non-zero voxel coordinates and values
    coords = np.where(voxel_tensor > 0)
    x_coords, y_coords, z_coords = coords
    values = voxel_tensor[coords]

    # Convert voxel indices to world coordinates if world_origin provided
    if world_origin is not None:
        x_min, y_min, z_min = world_origin
        x_coords_world = x_coords * voxel_size_mm + x_min
        y_coords_world = y_coords * voxel_size_mm + y_min
        z_coords_world = z_coords * voxel_size_mm + z_min
        coord_suffix = " (mm)"
    else:
        x_coords_world = x_coords * voxel_size_mm
        y_coords_world = y_coords * voxel_size_mm
        z_coords_world = z_coords * voxel_size_mm
        coord_suffix = f" (×{voxel_size_mm}mm)"

    # Get value range for sliders
    min_val = float(np.min(values))
    max_val = float(np.max(values))

    # Use user-specified min/max threshold range if provided
    slider_min = min_threshold if min_threshold is not None else min_val
    slider_max = max_threshold if max_threshold is not None else max_val

    # Set initial thresholds with defaults
    if initial_min_threshold is None:
        initial_min_threshold = slider_min
    else:
        initial_min_threshold = max(slider_min, min(slider_max, float(initial_min_threshold)))
    
    if initial_max_threshold is None:
        initial_max_threshold = slider_max
    else:
        initial_max_threshold = max(slider_min, min(slider_max, float(initial_max_threshold)))

    # Ensure min <= max
    if initial_min_threshold > initial_max_threshold:
        initial_min_threshold, initial_max_threshold = initial_max_threshold, initial_min_threshold

    print(f"Voxel value range: {min_val} to {max_val}")
    print(f"Total non-zero voxels: {len(values)}")
    print(f"Initial thresholds: {initial_min_threshold} to {initial_max_threshold}")
    print(f"Slider range: {slider_min} to {slider_max}")
    print(f"Voxel resolution: {voxel_size_mm}mm")

    def update_plot(min_thresh, max_thresh):
        # Ensure min <= max
        if min_thresh > max_thresh:
            min_thresh, max_thresh = max_thresh, min_thresh

        # Filter voxels within threshold range
        mask = (values >= min_thresh) & (values <= max_thresh)
        if not np.any(mask):
            print(f"No voxels in threshold range [{min_thresh}, {max_thresh}]")
            return

        filtered_x = x_coords_world[mask]
        filtered_y = y_coords_world[mask]
        filtered_z = z_coords_world[mask]
        filtered_values = values[mask]

        # Create 3D scatter plot
        fig = go.Figure(data=go.Scatter3d(
            x=filtered_x,
            y=filtered_y,
            z=filtered_z,
            mode='markers',
            marker=dict(
                size=1,
                color=filtered_values,
                colorscale='Viridis',
                opacity=0.8,
                colorbar=dict(title="Voxel Count"),
                line=dict(width=0)
            ),
            text=[f'Count: {v}' for v in filtered_values],
            hovertemplate='<b>Voxel (%{x:.1f}, %{y:.1f}, %{z:.1f})</b><br>%{text}<extra></extra>'
        ))

        fig.update_layout(
            title=f'3D Voxel Visualization (Range: [{min_thresh:.6f}, {max_thresh:.6f}], Showing: {len(filtered_values)} voxels)',
            scene=dict(
                xaxis_title=f'X{coord_suffix}',
                yaxis_title=f'Y{coord_suffix}',
                zaxis_title=f'Z{coord_suffix}',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                aspectmode='cube'
            ),
            width=800,
            height=600
        )

        fig.show()

    # Create interactive sliders with linked constraints
    min_threshold_slider = FloatSlider(
        value=initial_min_threshold,
        min=slider_min,
        max=slider_max,
        step=0.01,
        description='Min Threshold:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    max_threshold_slider = FloatSlider(
        value=initial_max_threshold,
        min=slider_min,
        max=slider_max,
        step=0.01,
        description='Max Threshold:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    # Link sliders to maintain min <= max constraint
    def on_min_change(change):
        if change['new'] > max_threshold_slider.value:
            max_threshold_slider.value = change['new']

    def on_max_change(change):
        if change['new'] < min_threshold_slider.value:
            min_threshold_slider.value = change['new']

    min_threshold_slider.observe(on_min_change, names='value')
    max_threshold_slider.observe(on_max_change, names='value')

    interact(update_plot, 
             min_thresh=min_threshold_slider, 
             max_thresh=max_threshold_slider)

visualize_voxel_tensor_3d(image3d.cpu().numpy(),
                         initial_min_threshold=0.0, 
                         initial_max_threshold=0.5)

Voxel value range: 2.795579327507096e-25 to 4.270153045654297
Total non-zero voxels: 8172
Initial thresholds: 2.795579327507096e-25 to 0.5
Slider range: 2.795579327507096e-25 to 4.270153045654297
Voxel resolution: 1.0mm


interactive(children=(FloatSlider(value=2.795579327507096e-25, continuous_update=False, description='Min Thres…

In [14]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional
import math
import cupy as cp
from cupyx.scipy.sparse import csr_matrix as cupy_csr_matrix

class OptimizedCUDAMLEM:
    """
    Highly optimized CUDA-accelerated MLEM with custom CUDA kernels,
    memory-efficient operations, and vectorized computations.
    """
    
    def __init__(self, lors: torch.Tensor, voxel_size: float = 2.0, device: str = 'cuda'):
        """Initialize with maximum GPU optimization."""
        self.device = torch.device(device)
        self.lors = lors.to(self.device, dtype=torch.float32, non_blocking=True)
        self.voxel_size = voxel_size
        self.num_lors = lors.shape[0]
        
        # Pre-allocate GPU memory
        torch.cuda.empty_cache()
        
        # Calculate voxel space
        self._calculate_voxel_space()
        
        # Pre-compile CUDA kernels
        self._setup_cuda_kernels()
        
        # System matrix storage
        self.system_matrix_csr = None
        self.sensitivity_image = None
        
    def _calculate_voxel_space(self):
        """Optimized voxel space calculation."""
        # Vectorized min/max calculation
        coords = self.lors.view(-1, 3)
        min_coords, _ = torch.min(coords, dim=0)
        max_coords, _ = torch.max(coords, dim=0)
        
        # Add padding
        padding = self.voxel_size
        min_coords -= padding
        max_coords += padding
        
        # Calculate grid dimensions
        self.grid_size = ((max_coords - min_coords) / self.voxel_size).ceil().int()
        self.grid_origin = min_coords
        self.total_voxels = self.grid_size.prod().item()
        
        print(f"Optimized grid: {self.grid_size.tolist()}, {self.total_voxels:,} voxels")
        
    def _setup_cuda_kernels(self):
        """Setup custom CUDA kernels for maximum performance."""
        # Ray tracing kernel
        self.ray_trace_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void siddon_ray_trace(const float* lors, const float* grid_origin, 
                             const int* grid_size, float voxel_size,
                             int* voxel_indices, int* lor_indices, 
                             float* weights, int* counts, int num_lors) {
            
            int lor_idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (lor_idx >= num_lors) return;
            
            // Get LOR endpoints
            float x1 = lors[lor_idx * 6 + 0];
            float y1 = lors[lor_idx * 6 + 1]; 
            float z1 = lors[lor_idx * 6 + 2];
            float x2 = lors[lor_idx * 6 + 3];
            float y2 = lors[lor_idx * 6 + 4];
            float z2 = lors[lor_idx * 6 + 5];
            
            // Convert to voxel coordinates
            float p1x = (x1 - grid_origin[0]) / voxel_size;
            float p1y = (y1 - grid_origin[1]) / voxel_size;
            float p1z = (z1 - grid_origin[2]) / voxel_size;
            float p2x = (x2 - grid_origin[0]) / voxel_size;
            float p2y = (y2 - grid_origin[1]) / voxel_size;
            float p2z = (z2 - grid_origin[2]) / voxel_size;
            
            // Ray direction
            float dx = p2x - p1x;
            float dy = p2y - p1y;
            float dz = p2z - p1z;
            float ray_length = sqrtf(dx*dx + dy*dy + dz*dz);
            
            if (ray_length < 1e-6) return;
            
            dx /= ray_length;
            dy /= ray_length; 
            dz /= ray_length;
            
            // Siddon algorithm - simplified for speed
            float stepX = (dx > 0) ? 1.0f : -1.0f;
            float stepY = (dy > 0) ? 1.0f : -1.0f;
            float stepZ = (dz > 0) ? 1.0f : -1.0f;
            
            float tDeltaX = fabsf(1.0f / dx);
            float tDeltaY = fabsf(1.0f / dy);
            float tDeltaZ = fabsf(1.0f / dz);
            
            int voxelX = (int)floorf(p1x);
            int voxelY = (int)floorf(p1y);
            int voxelZ = (int)floorf(p1z);
            
            float tMaxX = (dx > 0) ? (voxelX + 1 - p1x) * tDeltaX : (p1x - voxelX) * tDeltaX;
            float tMaxY = (dy > 0) ? (voxelY + 1 - p1y) * tDeltaY : (p1y - voxelY) * tDeltaY;
            float tMaxZ = (dz > 0) ? (voxelZ + 1 - p1z) * tDeltaZ : (p1z - voxelZ) * tDeltaZ;
            
            float t = 0.0f;
            int count = 0;
            int max_steps = grid_size[0] + grid_size[1] + grid_size[2];
            
            while (t < ray_length && count < max_steps) {
                // Check bounds
                if (voxelX >= 0 && voxelX < grid_size[0] &&
                    voxelY >= 0 && voxelY < grid_size[1] &&
                    voxelZ >= 0 && voxelZ < grid_size[2]) {
                    
                    float t_next = fminf(fminf(tMaxX, tMaxY), tMaxZ);
                    t_next = fminf(t_next, ray_length);
                    float length = t_next - t;
                    
                    if (length > 1e-6) {
                        int voxel_idx = voxelX * grid_size[1] * grid_size[2] + 
                                       voxelY * grid_size[2] + voxelZ;
                        
                        int base_idx = lor_idx * max_steps + count;
                        voxel_indices[base_idx] = voxel_idx;
                        lor_indices[base_idx] = lor_idx;
                        weights[base_idx] = length * voxel_size;
                        count++;
                    }
                    
                    t = t_next;
                }
                
                // Move to next voxel
                if (tMaxX < tMaxY && tMaxX < tMaxZ) {
                    voxelX += (int)stepX;
                    tMaxX += tDeltaX;
                } else if (tMaxY < tMaxZ) {
                    voxelY += (int)stepY;
                    tMaxY += tDeltaY;
                } else {
                    voxelZ += (int)stepZ;
                    tMaxZ += tDeltaZ;
                }
            }
            
            counts[lor_idx] = count;
        }
        ''', 'siddon_ray_trace')
        
        # Forward projection kernel
        self.forward_proj_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void forward_project(const float* image, const int* row_ptr, 
                           const int* col_idx, const float* data,
                           float* sino, int num_lors) {
            int lor_idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (lor_idx >= num_lors) return;
            
            float sum = 0.0f;
            for (int i = row_ptr[lor_idx]; i < row_ptr[lor_idx + 1]; i++) {
                sum += data[i] * image[col_idx[i]];
            }
            sino[lor_idx] = sum;
        }
        ''', 'forward_project')
        
        # Back projection kernel  
        self.back_proj_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void back_project(const float* sino, const int* row_ptr,
                         const int* col_idx, const float* data, 
                         float* image, int num_lors) {
            int lor_idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (lor_idx >= num_lors) return;
            
            for (int i = row_ptr[lor_idx]; i < row_ptr[lor_idx + 1]; i++) {
                atomicAdd(&image[col_idx[i]], data[i] * sino[lor_idx]);
            }
        }
        ''', 'back_project')
    
    def build_system_matrix_gpu(self, max_intersections_per_lor: int = 1000):
        """Ultra-fast GPU system matrix building with custom CUDA kernels."""
        print("Building system matrix with GPU acceleration...")
        
        # Allocate GPU memory for intersections
        max_total = self.num_lors * max_intersections_per_lor
        voxel_indices = cp.zeros(max_total, dtype=cp.int32)
        lor_indices = cp.zeros(max_total, dtype=cp.int32)
        weights = cp.zeros(max_total, dtype=cp.float32)
        counts = cp.zeros(self.num_lors, dtype=cp.int32)
        
        # Convert tensors to CuPy arrays
        lors_cp = cp.asarray(self.lors)
        grid_origin_cp = cp.asarray(self.grid_origin)
        grid_size_cp = cp.asarray(self.grid_size)
        
        # Launch ray tracing kernel
        threads_per_block = 256
        blocks = (self.num_lors + threads_per_block - 1) // threads_per_block
        
        self.ray_trace_kernel(
            (blocks,), (threads_per_block,),
            (lors_cp, grid_origin_cp, grid_size_cp, self.voxel_size,
             voxel_indices, lor_indices, weights, counts, self.num_lors)
        )
        cp.cuda.Stream.null.synchronize()
        
        # Compact results (remove zeros)
        total_intersections = int(cp.sum(counts))
        print(f"Total intersections: {total_intersections:,}")
        
        # Create compact arrays
        valid_voxels = []
        valid_lors = []
        valid_weights = []
        
        for lor_idx in range(self.num_lors):
            count = int(counts[lor_idx])
            if count > 0:
                start_idx = lor_idx * max_intersections_per_lor
                end_idx = start_idx + count
                valid_voxels.extend(voxel_indices[start_idx:end_idx].tolist())
                valid_lors.extend([lor_idx] * count)
                valid_weights.extend(weights[start_idx:end_idx].tolist())
        
        # Create CSR matrix for ultra-fast SpMV
        row_indices = cp.array(valid_lors, dtype=cp.int32)
        col_indices = cp.array(valid_voxels, dtype=cp.int32)
        data = cp.array(valid_weights, dtype=cp.float32)
        
        self.system_matrix_csr = cupy_csr_matrix(
            (data, (row_indices, col_indices)), 
            shape=(self.num_lors, self.total_voxels)
        )
        
        # Calculate sensitivity image (column sums) - GPU accelerated
        self.sensitivity_image = cp.array(self.system_matrix_csr.sum(axis=0)).flatten()
        self.sensitivity_image[self.sensitivity_image == 0] = 1.0
        
        print(f"System matrix: {self.system_matrix_csr.shape}, "
              f"{self.system_matrix_csr.nnz:,} non-zeros")
    
    def forward_project_gpu(self, image_cp: cp.ndarray) -> cp.ndarray:
        """Ultra-fast GPU forward projection using custom CUDA kernel."""
        sino = cp.zeros(self.num_lors, dtype=cp.float32)
        
        threads_per_block = 256
        blocks = (self.num_lors + threads_per_block - 1) // threads_per_block
        
        self.forward_proj_kernel(
            (blocks,), (threads_per_block,),
            (image_cp, self.system_matrix_csr.indptr, 
             self.system_matrix_csr.indices, self.system_matrix_csr.data,
             sino, self.num_lors)
        )
        
        return sino
    
    def back_project_gpu(self, sino_cp: cp.ndarray) -> cp.ndarray:
        """Ultra-fast GPU back projection using custom CUDA kernel."""
        image = cp.zeros(self.total_voxels, dtype=cp.float32)
        
        threads_per_block = 256
        blocks = (self.num_lors + threads_per_block - 1) // threads_per_block
        
        self.back_proj_kernel(
            (blocks,), (threads_per_block,),
            (sino_cp, self.system_matrix_csr.indptr,
             self.system_matrix_csr.indices, self.system_matrix_csr.data,
             image, self.num_lors)
        )
        
        return image
    
    def estimate_scatter_gpu(self, image_cp: cp.ndarray, scatter_fraction: float = 0.15) -> cp.ndarray:
        """GPU-accelerated scatter estimation with 3D convolution."""
        # Reshape to 3D
        img_3d = image_cp.reshape(self.grid_size[0].item(), 
                                 self.grid_size[1].item(), 
                                 self.grid_size[2].item())
        
        # Fast 3D Gaussian convolution using separable filters
        sigma = 2.0
        kernel_size = 7
        
        # Create 1D Gaussian kernel
        x = cp.arange(kernel_size) - kernel_size // 2
        kernel_1d = cp.exp(-x**2 / (2 * sigma**2))
        kernel_1d /= kernel_1d.sum()
        
        # Apply separable 3D convolution (much faster than full 3D)
        from cupyx.scipy.ndimage import convolve1d
        
        # Convolve along each axis
        scattered = convolve1d(img_3d, kernel_1d, axis=0, mode='constant')
        scattered = convolve1d(scattered, kernel_1d, axis=1, mode='constant')
        scattered = convolve1d(scattered, kernel_1d, axis=2, mode='constant')
        
        # Forward project and scale
        scattered_flat = scattered.flatten()
        scatter_sino = self.forward_project_gpu(scattered_flat)
        
        return scatter_sino * scatter_fraction
    
    def reconstruct_gpu(self, measured_data: torch.Tensor, 
                       attenuation_mask: torch.Tensor,
                       num_iterations: int = 20, 
                       convergence_threshold: float = 1e-4) -> torch.Tensor:
        """Ultra-fast GPU MLEM reconstruction."""
        print("Starting optimized GPU MLEM reconstruction...")
        
        # Build system matrix if needed
        if self.system_matrix_csr is None:
            self.build_system_matrix_gpu()
        
        # Convert to CuPy for maximum speed
        measured_cp = cp.asarray(measured_data.to(self.device))
        atten_cp = cp.asarray(attenuation_mask.to(self.device).flatten())
        
        # Initialize image
        image_cp = cp.ones(self.total_voxels, dtype=cp.float32)
        
        # Apply attenuation correction to measured data
        atten_factors = self.forward_project_gpu(atten_cp)
        atten_correction = cp.exp(-atten_factors)
        measured_corrected = measured_cp * atten_correction
        
        # Pre-allocate arrays for speed
        estimated_sino = cp.zeros(self.num_lors, dtype=cp.float32)
        correction = cp.zeros(self.total_voxels, dtype=cp.float32)
        
        prev_likelihood = float('inf')
        
        # MLEM iterations
        for iteration in range(num_iterations):
            # Forward project - GPU accelerated
            estimated_sino = self.forward_project_gpu(image_cp)
            
            # Add scatter and randoms - GPU operations
            if iteration % 2 == 0:  # Update scatter every other iteration for speed
                scatter = self.estimate_scatter_gpu(image_cp)
                randoms = cp.full(self.num_lors, 0.05, dtype=cp.float32)
                self._cached_corrections = scatter + randoms
            
            total_estimated = estimated_sino + self._cached_corrections
            total_estimated = cp.maximum(total_estimated, 1e-10)
            
            # Calculate ratio and back project - GPU accelerated
            ratio = measured_corrected / total_estimated
            correction = self.back_project_gpu(ratio)
            
            # MLEM update - GPU operations
            image_cp *= correction / self.sensitivity_image
            image_cp = cp.maximum(image_cp, 0)  # Non-negativity
            
            # Fast likelihood calculation (subset for speed)
            if iteration % 2 == 0:
                likelihood = float(cp.sum(measured_corrected * cp.log(total_estimated) - total_estimated))
                
                if iteration > 0:
                    rel_change = abs(likelihood - prev_likelihood) / abs(prev_likelihood)
                    print(f"Iteration {iteration + 1}: Log-likelihood = {likelihood:.2f}, "
                          f"Rel. change = {rel_change:.6f}")
                    
                    if rel_change < convergence_threshold:
                        print(f"Converged after {iteration + 1} iterations")
                        break
                else:
                    print(f"Iteration {iteration + 1}: Log-likelihood = {likelihood:.2f}")
                
                prev_likelihood = likelihood
            else:
                print(f"Iteration {iteration + 1}: (likelihood check skipped for speed)")
        
        # Convert back to PyTorch
        result = torch.as_tensor(image_cp, device=self.device)
        return result
    
    def get_image_3d(self, image_1d: torch.Tensor) -> torch.Tensor:
        """Convert 1D image to 3D array."""
        return image_1d.view(self.grid_size[0], self.grid_size[1], self.grid_size[2])

# Optimized example usage
def optimized_example():
    """Example with GPU optimization features."""
    
    # Enable GPU optimizations
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    
    # Generate test data
    num_lors = 100000  # Larger dataset to show performance gains
    lors = torch.randn(num_lors, 6, device='cuda') * 100
    measured_counts = torch.poisson(torch.ones(num_lors, device='cuda') * 50)
    
    # Initialize optimized reconstructor
    reconstructor = OptimizedCUDAMLEM(lors, voxel_size=2.0, device='cuda')
    
    # Create attenuation mask
    grid_shape = reconstructor.grid_size
    attenuation_mask = torch.ones(grid_shape[0], grid_shape[1], grid_shape[2], 
                                device='cuda') * 0.1
    
    # Time the reconstruction
    import time
    start_time = time.time()
    
    # Perform optimized reconstruction
    reconstructed_image = reconstructor.reconstruct_gpu(
        measured_counts, 
        attenuation_mask,
        num_iterations=10
    )
    
    end_time = time.time()
    print(f"\nReconstruction completed in {end_time - start_time:.2f} seconds")
    
    # Convert to 3D
    image_3d = reconstructor.get_image_3d(reconstructed_image)
    
    print(f"Final image shape: {image_3d.shape}")
    print(f"Image statistics: min={image_3d.min():.3f}, max={image_3d.max():.3f}, "
          f"mean={image_3d.mean():.3f}")
    
    return reconstructed_image, image_3d

if __name__ == "__main__":
    optimized_example()

Optimized grid: [459, 458, 455], 95,651,010 voxels
Starting optimized GPU MLEM reconstruction...
Building system matrix with GPU acceleration...
Total intersections: 0
System matrix: (100000, 95651010), 0 non-zeros
Iteration 1: Log-likelihood = -14987221.00
Iteration 2: (likelihood check skipped for speed)
Iteration 3: Log-likelihood = -14987221.00, Rel. change = 0.000000
Converged after 3 iterations

Reconstruction completed in 11.26 seconds
Final image shape: torch.Size([459, 458, 455])
Image statistics: min=0.000, max=0.000, mean=0.000


In [15]:
# Enable GPU optimizations
torch.backends.cudnn.benchmark = True
torch.cuda.empty_cache()

# Generate test data
measured_counts = torch.poisson(torch.ones(filtered_coordinates.shape[0], device='cuda') * 50)

# Initialize optimized reconstructor
reconstructor = OptimizedCUDAMLEM(filtered_coordinates, voxel_size=2.0, device='cuda')

# Create attenuation mask
grid_shape = reconstructor.grid_size
attenuation_mask = torch.ones(grid_shape[0], grid_shape[1], grid_shape[2], 
                            device='cuda') * 0.1

# Time the reconstruction
import time
start_time = time.time()

# Perform optimized reconstruction
reconstructed_image = reconstructor.reconstruct_gpu(
    measured_counts, 
    attenuation_mask,
    num_iterations=10
)

end_time = time.time()
print(f"\nReconstruction completed in {end_time - start_time:.2f} seconds")

# Convert to 3D
image_3d = reconstructor.get_image_3d(reconstructed_image)

print(f"Final image shape: {image_3d.shape}")
print(f"Image statistics: min={image_3d.min():.3f}, max={image_3d.max():.3f}, "
        f"mean={image_3d.mean():.3f}")




Optimized grid: [281, 281, 150], 11,844,150 voxels
Starting optimized GPU MLEM reconstruction...
Building system matrix with GPU acceleration...
Total intersections: 0
System matrix: (6591, 11844150), 0 non-zeros
Iteration 1: Log-likelihood = -987962.62
Iteration 2: (likelihood check skipped for speed)
Iteration 3: Log-likelihood = -987962.62, Rel. change = 0.000000
Converged after 3 iterations

Reconstruction completed in 0.26 seconds
Final image shape: torch.Size([281, 281, 150])
Image statistics: min=0.000, max=0.000, mean=0.000


In [16]:
visualize_voxel_tensor_3d(image_3d.cpu().numpy(),
                         initial_min_threshold=0.0, 
                         initial_max_threshold=0.5)

ValueError: zero-size array to reduction operation minimum which has no identity

In [28]:
import torch
import torch.nn.functional as F
import numpy as np
from typing import Tuple, Optional
import math
import cupy as cp
from cupyx.scipy.sparse import csr_matrix as cupy_csr_matrix

class OptimizedCUDAMLEM:
    """
    Highly optimized CUDA-accelerated MLEM with custom CUDA kernels,
    memory-efficient operations, and vectorized computations.
    """
    
    def __init__(self, lors: torch.Tensor, voxel_size: float = 2.0, device: str = 'cuda'):
        """Initialize with maximum GPU optimization."""
        self.device = torch.device(device)
        self.lors = lors.to(self.device, dtype=torch.float32, non_blocking=True)
        self.voxel_size = voxel_size
        self.num_lors = lors.shape[0]
        
        # Pre-allocate GPU memory
        torch.cuda.empty_cache()
        
        # Calculate voxel space
        self._calculate_voxel_space()
        
        # Pre-compile CUDA kernels
        self._setup_cuda_kernels()
        
        # System matrix storage
        self.system_matrix_csr = None
        self.sensitivity_image = None
        
    def _calculate_voxel_space(self):
        """Optimized voxel space calculation."""
        # Vectorized min/max calculation
        coords = self.lors.view(-1, 3)
        min_coords, _ = torch.min(coords, dim=0)
        max_coords, _ = torch.max(coords, dim=0)
        
        # Add padding
        padding = self.voxel_size
        min_coords -= padding
        max_coords += padding
        
        # Calculate grid dimensions
        self.grid_size = ((max_coords - min_coords) / self.voxel_size).ceil().int()
        self.grid_origin = min_coords
        self.total_voxels = self.grid_size.prod().item()
        
        print(f"Optimized grid: {self.grid_size.tolist()}, {self.total_voxels:,} voxels")
        
    def _setup_cuda_kernels(self):
        """Setup custom CUDA kernels with proper debugging."""
        # Robust ray tracing kernel with fixed weight calculation
        self.ray_trace_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void siddon_ray_trace(const float* lors, const float* grid_origin, 
                             const int* grid_size, float voxel_size,
                             int* voxel_indices, int* lor_indices, 
                             float* weights, int* counts, int num_lors,
                             int max_intersections_per_lor) {
            
            int lor_idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (lor_idx >= num_lors) return;
            
            // Get LOR endpoints
            float x1 = lors[lor_idx * 6 + 0];
            float y1 = lors[lor_idx * 6 + 1]; 
            float z1 = lors[lor_idx * 6 + 2];
            float x2 = lors[lor_idx * 6 + 3];
            float y2 = lors[lor_idx * 6 + 4];
            float z2 = lors[lor_idx * 6 + 5];
            
            // Convert to voxel coordinates
            float p1x = (x1 - grid_origin[0]) / voxel_size;
            float p1y = (y1 - grid_origin[1]) / voxel_size;
            float p1z = (z1 - grid_origin[2]) / voxel_size;
            float p2x = (x2 - grid_origin[0]) / voxel_size;
            float p2y = (y2 - grid_origin[1]) / voxel_size;
            float p2z = (z2 - grid_origin[2]) / voxel_size;
            
            // Ray direction and length in voxel space
            float dx = p2x - p1x;
            float dy = p2y - p1y;
            float dz = p2z - p1z;
            float ray_length_voxels = sqrtf(dx*dx + dy*dy + dz*dz);
            
            if (ray_length_voxels < 1e-6f) {
                counts[lor_idx] = 0;
                return;
            }
            
            // Normalize direction
            dx /= ray_length_voxels;
            dy /= ray_length_voxels; 
            dz /= ray_length_voxels;
            
            // DDA-style traversal with fixed step size
            int num_steps = (int)(ray_length_voxels * 1.5f) + 50;  // Ensure good sampling
            if (num_steps > max_intersections_per_lor) num_steps = max_intersections_per_lor;
            if (num_steps < 10) num_steps = 10;  // Minimum sampling
            
            int count = 0;
            int prev_voxel_idx = -1;  // Track previous voxel to avoid duplicates
            
            // Sample along the ray
            for (int step = 0; step < num_steps && count < max_intersections_per_lor - 1; step++) {
                float t = (float)step / (float)(num_steps - 1) * ray_length_voxels;
                
                // Current position along ray in voxel coordinates
                float px = p1x + t * dx;
                float py = p1y + t * dy;
                float pz = p1z + t * dz;
                
                // Current voxel indices
                int vx = (int)floorf(px);
                int vy = (int)floorf(py);
                int vz = (int)floorf(pz);
                
                // Check bounds
                if (vx >= 0 && vx < grid_size[0] &&
                    vy >= 0 && vy < grid_size[1] &&
                    vz >= 0 && vz < grid_size[2]) {
                    
                    // Calculate flat voxel index
                    int voxel_idx = vx * grid_size[1] * grid_size[2] + 
                                   vy * grid_size[2] + vz;
                    
                    // Only add if this is a new voxel
                    if (voxel_idx != prev_voxel_idx) {
                        int base_idx = lor_idx * max_intersections_per_lor + count;
                        
                        voxel_indices[base_idx] = voxel_idx;
                        lor_indices[base_idx] = lor_idx;
                        
                        // Fixed weight calculation: length per step in mm
                        float weight = ray_length_voxels * voxel_size / (float)num_steps;
                        
                        // Ensure weight is finite and positive
                        if (weight > 0.0f && isfinite(weight)) {
                            weights[base_idx] = weight;
                        } else {
                            weights[base_idx] = voxel_size * 0.1f;  // Fallback weight
                        }
                        
                        prev_voxel_idx = voxel_idx;
                        count++;
                    }
                }
            }
            
            counts[lor_idx] = count;
        }
        ''', 'siddon_ray_trace')
        
        # Forward projection kernel
        self.forward_proj_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void forward_project(const float* image, const int* row_ptr, 
                           const int* col_idx, const float* data,
                           float* sino, int num_lors) {
            int lor_idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (lor_idx >= num_lors) return;
            
            float sum = 0.0f;
            for (int i = row_ptr[lor_idx]; i < row_ptr[lor_idx + 1]; i++) {
                sum += data[i] * image[col_idx[i]];
            }
            sino[lor_idx] = sum;
        }
        ''', 'forward_project')
        
        # Back projection kernel  
        self.back_proj_kernel = cp.RawKernel(r'''
        extern "C" __global__
        void back_project(const float* sino, const int* row_ptr,
                         const int* col_idx, const float* data, 
                         float* image, int num_lors) {
            int lor_idx = blockIdx.x * blockDim.x + threadIdx.x;
            if (lor_idx >= num_lors) return;
            
            for (int i = row_ptr[lor_idx]; i < row_ptr[lor_idx + 1]; i++) {
                atomicAdd(&image[col_idx[i]], data[i] * sino[lor_idx]);
            }
        }
        ''', 'back_project')
    
    def build_system_matrix_gpu(self, max_intersections_per_lor: int = 1000):
        """Ultra-fast GPU system matrix building with custom CUDA kernels."""
        print("Building system matrix with GPU acceleration...")
        
        # Allocate GPU memory for intersections
        max_total = self.num_lors * max_intersections_per_lor
        voxel_indices = cp.zeros(max_total, dtype=cp.int32)
        lor_indices = cp.zeros(max_total, dtype=cp.int32)
        weights = cp.zeros(max_total, dtype=cp.float32)
        counts = cp.zeros(self.num_lors, dtype=cp.int32)
        
        # Convert tensors to CuPy arrays
        lors_cp = cp.asarray(self.lors)
        grid_origin_cp = cp.asarray(self.grid_origin)
        grid_size_cp = cp.asarray(self.grid_size)
        
        # Debug: Check LOR data first
        print(f"LOR data check:")
        print(f"  LOR range: X=[{float(lors_cp[:, 0].min()):.1f}, {float(lors_cp[:, 0].max()):.1f}]")
        print(f"             Y=[{float(lors_cp[:, 1].min()):.1f}, {float(lors_cp[:, 1].max()):.1f}]") 
        print(f"             Z=[{float(lors_cp[:, 2].min()):.1f}, {float(lors_cp[:, 2].max()):.1f}]")
        print(f"  Grid origin: {[float(x) for x in grid_origin_cp]}")
        print(f"  Grid size: {[int(x) for x in grid_size_cp]}")
        print(f"  Voxel size: {self.voxel_size}")
        
        # Launch ray tracing kernel with fixed parameters
        threads_per_block = 256
        blocks = (self.num_lors + threads_per_block - 1) // threads_per_block
        
        self.ray_trace_kernel(
            (blocks,), (threads_per_block,),
            (lors_cp, grid_origin_cp, grid_size_cp, self.voxel_size,
             voxel_indices, lor_indices, weights, counts, self.num_lors,
             max_intersections_per_lor)
        )
        cp.cuda.Stream.null.synchronize()
        
        # Compact results with debugging
        total_intersections = int(cp.sum(counts))
        print(f"Total intersections: {total_intersections:,}")
        
        if total_intersections == 0:
            print("WARNING: No intersections found!")
            print("Debug info:")
            # Check a few LORs manually
            sample_lors = lors_cp[:5]
            for i, lor in enumerate(sample_lors):
                x1, y1, z1, x2, y2, z2 = lor
                print(f"  LOR {i}: ({x1:.1f},{y1:.1f},{z1:.1f}) -> ({x2:.1f},{y2:.1f},{z2:.1f})")
                
                # Convert to voxel coords manually
                vx1 = (x1 - grid_origin_cp[0]) / self.voxel_size
                vy1 = (y1 - grid_origin_cp[1]) / self.voxel_size  
                vz1 = (z1 - grid_origin_cp[2]) / self.voxel_size
                vx2 = (x2 - grid_origin_cp[0]) / self.voxel_size
                vy2 = (y2 - grid_origin_cp[1]) / self.voxel_size
                vz2 = (z2 - grid_origin_cp[2]) / self.voxel_size
                print(f"    Voxel coords: ({vx1:.1f},{vy1:.1f},{vz1:.1f}) -> ({vx2:.1f},{vy2:.1f},{vz2:.1f})")
                
                # Check if any part is in bounds
                in_bounds = ((0 <= vx1 < grid_size_cp[0] or 0 <= vx2 < grid_size_cp[0]) and
                           (0 <= vy1 < grid_size_cp[1] or 0 <= vy2 < grid_size_cp[1]) and  
                           (0 <= vz1 < grid_size_cp[2] or 0 <= vz2 < grid_size_cp[2]))
                print(f"    In bounds: {in_bounds}")
            
            # Fallback to CPU ray tracing for comparison
            print("Falling back to CPU ray tracing...")
            return self._build_system_matrix_cpu_fallback()
        
        # Create compact arrays
        valid_voxels = []
        valid_lors = []
        valid_weights = []
        
        for lor_idx in range(self.num_lors):
            count = int(counts[lor_idx])
            if count > 0:
                start_idx = lor_idx * max_intersections_per_lor
                end_idx = start_idx + count
                valid_voxels.extend(voxel_indices[start_idx:end_idx].tolist())
                valid_lors.extend([lor_idx] * count)
                valid_weights.extend(weights[start_idx:end_idx].tolist())
        
        # Validate and clean data before creating matrix
        print(f"Validating {len(valid_weights)} intersection weights...")
        valid_weights_clean = []
        valid_voxels_clean = []
        valid_lors_clean = []
        
        for i, (weight, voxel, lor) in enumerate(zip(valid_weights, valid_voxels, valid_lors)):
            if np.isfinite(weight) and weight > 0 and 0 <= voxel < self.total_voxels:
                valid_weights_clean.append(weight)
                valid_voxels_clean.append(voxel)
                valid_lors_clean.append(lor)
        
        print(f"After cleaning: {len(valid_weights_clean)} valid intersections")
        
        if len(valid_weights_clean) == 0:
            print("ERROR: No valid intersections after cleaning!")
            # Create minimal dummy matrix
            valid_weights_clean = [1.0]
            valid_voxels_clean = [0]
            valid_lors_clean = [0]
        
        # Create CSR matrix with cleaned data
        row_indices = cp.array(valid_lors_clean, dtype=cp.int32)
        col_indices = cp.array(valid_voxels_clean, dtype=cp.int32)
        data = cp.array(valid_weights_clean, dtype=cp.float32)
        
        self.system_matrix_csr = cupy_csr_matrix(
            (data, (row_indices, col_indices)), 
            shape=(self.num_lors, self.total_voxels)
        )
        
        # Verify matrix was created properly
        print(f"System matrix created: shape={self.system_matrix_csr.shape}")
        print(f"Data range: min={float(self.system_matrix_csr.data.min()):.6f}, "
              f"max={float(self.system_matrix_csr.data.max()):.6f}")
        print(f"NaN in data: {int(cp.sum(cp.isnan(self.system_matrix_csr.data)))}")
        
        # Calculate sensitivity image (column sums) - GPU accelerated with fix
        print("Calculating sensitivity image...")
        
        # Method 1: Try direct sum
        try:
            sensitivity_raw = cp.array(self.system_matrix_csr.sum(axis=0)).flatten()
            if cp.any(cp.isnan(sensitivity_raw)) or cp.any(cp.isinf(sensitivity_raw)):
                print("WARNING: NaN/Inf in raw sensitivity - using alternative method")
                raise ValueError("Invalid sensitivity values")
            self.sensitivity_image = sensitivity_raw
        except:
            # Method 2: Manual calculation for robustness
            print("Using manual sensitivity calculation...")
            self.sensitivity_image = cp.zeros(self.total_voxels, dtype=cp.float32)
            
            # Add weights manually - use cleaned data
            for i in range(len(data)):
                col_idx = int(col_indices[i])
                weight = float(data[i])
                if 0 <= col_idx < self.total_voxels:
                    self.sensitivity_image[col_idx] += weight
        
        # Ensure no zeros or invalid values
        self.sensitivity_image = cp.maximum(self.sensitivity_image, 1e-10)
        self.sensitivity_image = cp.nan_to_num(self.sensitivity_image, nan=1e-10, posinf=1e10, neginf=1e-10)
        
        print(f"Sensitivity image stats: min={float(self.sensitivity_image.min()):.6f}, "
              f"max={float(self.sensitivity_image.max()):.6f}, "
              f"mean={float(self.sensitivity_image.mean()):.6f}, "
              f"zeros={int(cp.sum(self.sensitivity_image <= 1e-10))}")
        
        print(f"System matrix: {self.system_matrix_csr.shape}, "
              f"{self.system_matrix_csr.nnz:,} non-zeros")
    
    def _build_system_matrix_cpu_fallback(self):
        """CPU fallback when GPU ray tracing fails."""
        print("Building system matrix with CPU fallback...")
        
        # Convert back to PyTorch for CPU processing
        lors_torch = torch.as_tensor(self.lors)
        
        all_voxel_indices = []
        all_lor_indices = []
        all_weights = []
        
        # Process in batches
        batch_size = 1000
        num_batches = (self.num_lors + batch_size - 1) // batch_size
        
        for batch_idx in range(num_batches):
            start_idx = batch_idx * batch_size
            end_idx = min((batch_idx + 1) * batch_size, self.num_lors)
            
            lors_batch = lors_torch[start_idx:end_idx]
            
            # Simple ray tracing
            for i, lor in enumerate(lors_batch):
                x1, y1, z1, x2, y2, z2 = lor
                
                # Convert to voxel coordinates  
                p1 = torch.tensor([(x1 - self.grid_origin[0]) / self.voxel_size,
                                 (y1 - self.grid_origin[1]) / self.voxel_size,
                                 (z1 - self.grid_origin[2]) / self.voxel_size])
                p2 = torch.tensor([(x2 - self.grid_origin[0]) / self.voxel_size,
                                 (y2 - self.grid_origin[1]) / self.voxel_size,
                                 (z2 - self.grid_origin[2]) / self.voxel_size])
                
                # Simple sampling along ray
                direction = p2 - p1
                length = torch.norm(direction)
                
                if length > 1e-6:
                    direction = direction / length
                    
                    # Sample points along ray
                    num_samples = int(length * 2) + 10
                    for step in range(num_samples):
                        t = step / num_samples * length
                        pos = p1 + t * direction
                        
                        vx, vy, vz = int(pos[0]), int(pos[1]), int(pos[2])
                        
                        if (0 <= vx < self.grid_size[0] and 
                            0 <= vy < self.grid_size[1] and
                            0 <= vz < self.grid_size[2]):
                            
                            voxel_idx = vx * self.grid_size[1] * self.grid_size[2] + vy * self.grid_size[2] + vz
                            lor_idx = start_idx + i
                            weight = length / num_samples * self.voxel_size
                            
                            all_voxel_indices.append(int(voxel_idx))
                            all_lor_indices.append(int(lor_idx))
                            all_weights.append(float(weight))
            
            if (batch_idx + 1) % 10 == 0:
                print(f"CPU fallback: batch {batch_idx + 1}/{num_batches}")
        
        if len(all_voxel_indices) == 0:
            print("ERROR: Even CPU fallback found no intersections!")
            # Create minimal dummy matrix to avoid crashes
            all_voxel_indices = [0]
            all_lor_indices = [0] 
            all_weights = [1e-10]
        
        # Create CSR matrix
        row_indices = cp.array(all_lor_indices, dtype=cp.int32)
        col_indices = cp.array(all_voxel_indices, dtype=cp.int32)
        data = cp.array(all_weights, dtype=cp.float32)
        
        self.system_matrix_csr = cupy_csr_matrix(
            (data, (row_indices, col_indices)), 
            shape=(self.num_lors, self.total_voxels)
        )
        
        # Calculate sensitivity image
        self.sensitivity_image = cp.array(self.system_matrix_csr.sum(axis=0)).flatten()
        self.sensitivity_image[self.sensitivity_image == 0] = 1.0
    
    def forward_project_gpu(self, image_cp: cp.ndarray) -> cp.ndarray:
        """Ultra-fast GPU forward projection using custom CUDA kernel."""
        sino = cp.zeros(self.num_lors, dtype=cp.float32)
        
        threads_per_block = 256
        blocks = (self.num_lors + threads_per_block - 1) // threads_per_block
        
        self.forward_proj_kernel(
            (blocks,), (threads_per_block,),
            (image_cp, self.system_matrix_csr.indptr, 
             self.system_matrix_csr.indices, self.system_matrix_csr.data,
             sino, self.num_lors)
        )
        
        return sino
    
    def back_project_gpu(self, sino_cp: cp.ndarray) -> cp.ndarray:
        """Ultra-fast GPU back projection using custom CUDA kernel."""
        image = cp.zeros(self.total_voxels, dtype=cp.float32)
        
        threads_per_block = 256
        blocks = (self.num_lors + threads_per_block - 1) // threads_per_block
        
        self.back_proj_kernel(
            (blocks,), (threads_per_block,),
            (sino_cp, self.system_matrix_csr.indptr,
             self.system_matrix_csr.indices, self.system_matrix_csr.data,
             image, self.num_lors)
        )
        
        return image
    
    def estimate_scatter_gpu(self, image_cp: cp.ndarray, scatter_fraction: float = 0.15) -> cp.ndarray:
        """GPU-accelerated scatter estimation with 3D convolution."""
        # Reshape to 3D
        img_3d = image_cp.reshape(self.grid_size[0].item(), 
                                 self.grid_size[1].item(), 
                                 self.grid_size[2].item())
        
        # Fast 3D Gaussian convolution using separable filters
        sigma = 2.0
        kernel_size = 7
        
        # Create 1D Gaussian kernel
        x = cp.arange(kernel_size) - kernel_size // 2
        kernel_1d = cp.exp(-x**2 / (2 * sigma**2))
        kernel_1d /= kernel_1d.sum()
        
        # Apply separable 3D convolution (much faster than full 3D)
        from cupyx.scipy.ndimage import convolve1d
        
        # Convolve along each axis
        scattered = convolve1d(img_3d, kernel_1d, axis=0, mode='constant')
        scattered = convolve1d(scattered, kernel_1d, axis=1, mode='constant')
        scattered = convolve1d(scattered, kernel_1d, axis=2, mode='constant')
        
        # Forward project and scale
        scattered_flat = scattered.flatten()
        scatter_sino = self.forward_project_gpu(scattered_flat)
        
        return scatter_sino * scatter_fraction
    
    def reconstruct_gpu(self, measured_data: torch.Tensor, 
                       attenuation_mask: torch.Tensor,
                       num_iterations: int = 20, 
                       convergence_threshold: float = 1e-4) -> torch.Tensor:
        """Ultra-fast GPU MLEM reconstruction with numerical stability."""
        print("Starting optimized GPU MLEM reconstruction...")
        
        # Build system matrix if needed
        if self.system_matrix_csr is None:
            self.build_system_matrix_gpu()
        
        # Convert to CuPy for maximum speed
        measured_cp = cp.asarray(measured_data.to(self.device), dtype=cp.float32)
        atten_cp = cp.asarray(attenuation_mask.to(self.device).flatten(), dtype=cp.float32)
        
        # Debug system matrix data quality
        print(f"System matrix data check:")
        data_stats = self.system_matrix_csr.data
        print(f"  Data stats: min={float(data_stats.min()):.6f}, max={float(data_stats.max()):.6f}")
        print(f"  NaN count: {int(cp.sum(cp.isnan(data_stats)))}")
        print(f"  Inf count: {int(cp.sum(cp.isinf(data_stats)))}")
        print(f"  Zero count: {int(cp.sum(data_stats == 0))}")
        
        # Clean system matrix data
        if cp.any(cp.isnan(data_stats)) or cp.any(cp.isinf(data_stats)):
            print("Cleaning system matrix data...")
            self.system_matrix_csr.data = cp.nan_to_num(data_stats, nan=0.0, posinf=1e6, neginf=0.0)
            # Remove zeros after cleaning
            self.system_matrix_csr.eliminate_zeros()
        
        # Debug system matrix
        print(f"System matrix nnz: {self.system_matrix_csr.nnz}")
        print(f"Sensitivity image stats: min={float(self.sensitivity_image.min()):.6f}, "
              f"max={float(self.sensitivity_image.max()):.6f}, "
              f"mean={float(self.sensitivity_image.mean()):.6f}")
        
        # Check for zeros in sensitivity
        zero_sens = cp.sum(self.sensitivity_image <= 1e-10)
        if zero_sens > 0:
            print(f"WARNING: {zero_sens} voxels have near-zero sensitivity")
            
        # More robust sensitivity handling
        sens_mean = float(self.sensitivity_image.mean())
        sens_threshold = max(1e-10, sens_mean * 1e-6)  # Adaptive threshold
        self.sensitivity_image = cp.maximum(self.sensitivity_image, sens_threshold)
        
        print(f"Applied sensitivity threshold: {sens_threshold:.2e}")
        print(f"Updated sensitivity stats: min={float(self.sensitivity_image.min()):.6f}, "
              f"max={float(self.sensitivity_image.max()):.6f}")
        
        # Initialize image with better scaling
        total_sensitivity = float(cp.sum(self.sensitivity_image))
        image_cp = cp.ones(self.total_voxels, dtype=cp.float32)
        total_measured = float(cp.sum(measured_cp))
        if total_measured > 0 and total_sensitivity > 0:
            # Scale by ratio of measured counts to total sensitivity
            scale_factor = total_measured / total_sensitivity
            image_cp *= scale_factor
            print(f"Image initialization scale factor: {scale_factor:.2e}")
        
        print(f"Initial image stats: min={float(image_cp.min()):.3e}, "
              f"max={float(image_cp.max()):.3e}, mean={float(image_cp.mean()):.3e}")
        
        print(f"Measured data stats: sum={total_measured:.1f}, "
              f"min={float(measured_cp.min()):.3f}, max={float(measured_cp.max()):.3f}")
        
        # Apply attenuation correction more carefully
        print("Applying attenuation correction...")
        try:
            atten_factors = self.forward_project_gpu(atten_cp)
            
            # Check attenuation factors
            print(f"Attenuation factors: min={float(atten_factors.min()):.3f}, "
                  f"max={float(atten_factors.max()):.3f}, "
                  f"mean={float(atten_factors.mean()):.3f}")
            
            if cp.any(cp.isnan(atten_factors)) or cp.any(cp.isinf(atten_factors)):
                print("WARNING: Invalid attenuation factors detected")
                atten_factors = cp.nan_to_num(atten_factors, nan=0.1, posinf=10, neginf=0)
            
            atten_factors = cp.clip(atten_factors, 0, 10)  # Reasonable attenuation range
            atten_correction = cp.exp(-atten_factors)
            atten_correction = cp.clip(atten_correction, 1e-6, 1.0)
            
        except Exception as e:
            print(f"Attenuation correction failed: {e}, using no correction")
            atten_correction = cp.ones(self.num_lors, dtype=cp.float32)
        
        measured_corrected = measured_cp * atten_correction
        
        print(f"Attenuation correction stats: min={float(atten_correction.min()):.6f}, "
              f"max={float(atten_correction.max()):.6f}")
        print(f"Corrected data sum: {float(cp.sum(measured_corrected)):.1f}")
        
        # Pre-allocate arrays
        estimated_sino = cp.zeros(self.num_lors, dtype=cp.float32)
        correction = cp.zeros(self.total_voxels, dtype=cp.float32)
        
        # Initialize corrections
        scatter = cp.zeros(self.num_lors, dtype=cp.float32)
        randoms = cp.full(self.num_lors, max(0.01, total_measured * 0.001), dtype=cp.float32)
        
        prev_likelihood = float('inf')
        
        # MLEM iterations with enhanced stability
        for iteration in range(num_iterations):
            # Forward project with stability check
            estimated_sino = self.forward_project_gpu(image_cp)
            
            # Check for issues
            if cp.any(cp.isnan(estimated_sino)) or cp.any(cp.isinf(estimated_sino)):
                print(f"WARNING: NaN/Inf in forward projection at iteration {iteration + 1}")
                estimated_sino = cp.nan_to_num(estimated_sino, nan=1e-10, posinf=1e10, neginf=1e-10)
            
            # Add corrections with stability
            if iteration % 3 == 0:  # Update scatter less frequently for stability
                try:
                    scatter = self.estimate_scatter_gpu(image_cp)
                    scatter = cp.clip(scatter, 0, total_measured * 0.5)  # Limit scatter
                except:
                    print("Scatter estimation failed, using previous values")
            
            total_estimated = estimated_sino + scatter + randoms
            total_estimated = cp.clip(total_estimated, 1e-8, 1e10)  # Robust clipping
            
            # Check estimated data
            if cp.any(cp.isnan(total_estimated)):
                print(f"WARNING: NaN in total_estimated at iteration {iteration + 1}")
                total_estimated = cp.nan_to_num(total_estimated, nan=1e-8)
            
            # Calculate ratio with stability
            ratio = measured_corrected / total_estimated
            ratio = cp.clip(ratio, 0, 100)  # Limit extreme ratios
            ratio = cp.nan_to_num(ratio, nan=1.0, posinf=1.0, neginf=0.0)
            
            # Back project with stability check
            correction = self.back_project_gpu(ratio)
            
            if cp.any(cp.isnan(correction)) or cp.any(cp.isinf(correction)):
                print(f"WARNING: NaN/Inf in back projection at iteration {iteration + 1}")
                correction = cp.nan_to_num(correction, nan=1.0, posinf=1.0, neginf=0.0)
            
            # MLEM update with enhanced stability
            update_factor = correction / self.sensitivity_image
            update_factor = cp.clip(update_factor, 0.1, 10.0)  # Limit update magnitude
            update_factor = cp.nan_to_num(update_factor, nan=1.0)
            
            image_cp *= update_factor
            image_cp = cp.clip(image_cp, 1e-10, 1e6)  # Robust non-negativity and upper bound
            
            # Stability check on image
            if cp.any(cp.isnan(image_cp)) or cp.any(cp.isinf(image_cp)):
                print(f"ERROR: NaN/Inf in image at iteration {iteration + 1}")
                image_cp = cp.nan_to_num(image_cp, nan=1e-10, posinf=1e6, neginf=1e-10)
            
            # Robust likelihood calculation
            if iteration % 2 == 0:
                try:
                    # Ensure positive values for log
                    log_arg = cp.clip(total_estimated, 1e-10, 1e10)
                    likelihood = float(cp.sum(measured_corrected * cp.log(log_arg) - total_estimated))
                    
                    if not cp.isfinite(likelihood):
                        likelihood = -1e10  # Fallback for numerical issues
                    
                    if iteration > 0 and cp.isfinite(prev_likelihood):
                        rel_change = abs(likelihood - prev_likelihood) / (abs(prev_likelihood) + 1e-10)
                        print(f"Iteration {iteration + 1}: Log-likelihood = {likelihood:.2f}, "
                              f"Rel. change = {rel_change:.6f}")
                        
                        # Check image stats
                        img_stats = f"Image: min={float(image_cp.min()):.3e}, max={float(image_cp.max()):.3e}, mean={float(image_cp.mean()):.3e}"
                        print(f"  {img_stats}")
                        
                        if rel_change < convergence_threshold:
                            print(f"Converged after {iteration + 1} iterations")
                            break
                    else:
                        print(f"Iteration {iteration + 1}: Log-likelihood = {likelihood:.2f}")
                    
                    prev_likelihood = likelihood
                    
                except Exception as e:
                    print(f"Likelihood calculation failed: {e}")
                    likelihood = -1e10
            else:
                print(f"Iteration {iteration + 1}: (likelihood check skipped)")
        
        # Final stability check
        image_cp = cp.nan_to_num(image_cp, nan=0.0, posinf=1e6, neginf=0.0)
        image_cp = cp.clip(image_cp, 0, 1e6)
        
        # Convert back to PyTorch
        result = torch.as_tensor(image_cp, device=self.device)
        return result
    
    def get_image_3d(self, image_1d: torch.Tensor) -> torch.Tensor:
        """Convert 1D image to 3D array."""
        return image_1d.view(self.grid_size[0], self.grid_size[1], self.grid_size[2])

# Optimized example usage
def optimized_example():
    """Example with GPU optimization features."""
    
    # Enable GPU optimizations
    torch.backends.cudnn.benchmark = True
    torch.cuda.empty_cache()
    
    # Generate test data
    num_lors = 100000  # Larger dataset to show performance gains
    lors = torch.randn(num_lors, 6, device='cuda') * 100
    measured_counts = torch.poisson(torch.ones(num_lors, device='cuda') * 50)
    
    # Initialize optimized reconstructor
    reconstructor = OptimizedCUDAMLEM(lors, voxel_size=2.0, device='cuda')
    
    # Create attenuation mask
    grid_shape = reconstructor.grid_size
    attenuation_mask = torch.ones(grid_shape[0], grid_shape[1], grid_shape[2], 
                                device='cuda') * 0.1
    
    # Time the reconstruction
    import time
    start_time = time.time()
    
    # Perform optimized reconstruction
    reconstructed_image = reconstructor.reconstruct_gpu(
        measured_counts, 
        attenuation_mask,
        num_iterations=10
    )
    
    end_time = time.time()
    print(f"\nReconstruction completed in {end_time - start_time:.2f} seconds")
    
    # Convert to 3D
    image_3d = reconstructor.get_image_3d(reconstructed_image)
    
    print(f"Final image shape: {image_3d.shape}")
    print(f"Image statistics: min={image_3d.min():.3f}, max={image_3d.max():.3f}, "
          f"mean={image_3d.mean():.3f}")
    
    return reconstructed_image, image_3d

if __name__ == "__main__":
    optimized_example()


# # Optimized example usage
# def optimized_example():
#     """Example with GPU optimization features."""
    
#     # Enable GPU optimizations
#     torch.backends.cudnn.benchmark = True
#     torch.cuda.empty_cache()
#     # Generate test data
#     num_lors = filtered_coordinates.shape[0]  # Use existing LOR count
#     lors = filtered_coordinates.clone().to('cuda')  # Use filtered coordinates directly
#     measured_counts = torch.poisson(torch.ones(num_lors, device='cuda') * 50)  # Fix ones() argument
#     # Initialize optimized reconstructor
#     reconstructor = OptimizedCUDAMLEM(lors, voxel_size=2.0, device='cuda')
    
#     # Create attenuation mask
#     grid_shape = reconstructor.grid_size
#     attenuation_mask = torch.ones(grid_shape[0], grid_shape[1], grid_shape[2], 
#                                 device='cuda') * 0.1
    
#     # Time the reconstruction
#     import time
#     start_time = time.time()
    
#     # Perform optimized reconstruction
#     reconstructed_image = reconstructor.reconstruct_gpu(
#         measured_counts, 
#         attenuation_mask,
#         num_iterations=100
#     )
    
#     end_time = time.time()
#     print(f"\nReconstruction completed in {end_time - start_time:.2f} seconds")
    
#     # Convert to 3D
#     image_3d = reconstructor.get_image_3d(reconstructed_image)
    
#     print(f"Final image shape: {image_3d.shape}")
#     print(f"Image statistics: min={image_3d.min():.3f}, max={image_3d.max():.3f}, "
#           f"mean={image_3d.mean():.3f}")
    
#     # Convert tensor to numpy for plotting
#     recon_np = image3d.cpu().numpy()
#     nx, ny, nz = recon_np.shape

#     def plot_cross_sections_horizontal(x_idx=nx//2, y_idx=ny//2, z_idx=nz//2, vmax=0.5, cmap='Magma'):
#         fig = make_subplots(rows=1, cols=3, subplot_titles=[
#             f'XY plane @ z={z_idx}',
#             f'XZ plane @ y={y_idx}',
#             f'YZ plane @ x={x_idx}'
#         ])

#         # XY plane at z=z_idx
#         fig.add_trace(go.Heatmap(
#             z=recon_np[:, :, z_idx].T,
#             colorscale=cmap,
#             zmax=vmax,
#             zmin=0,
#             showscale=True,
#             name=f'XY @ z={z_idx}'
#         ), row=1, col=1)

#         # XZ plane at y=y_idx
#         fig.add_trace(go.Heatmap(
#             z=recon_np[:, y_idx, :].T,
#             colorscale=cmap,
#             zmax=vmax,
#             zmin=0,
#             showscale=True,
#             name=f'XZ @ y={y_idx}'
#         ), row=1, col=2)

#         # YZ plane at x=x_idx
#         fig.add_trace(go.Heatmap(
#             z=recon_np[x_idx, :, :].T,
#             colorscale=cmap,
#             zmax=vmax,
#             zmin=0,
#             showscale=True,
#             name=f'YZ @ x={x_idx}'
#         ), row=1, col=3)

#         fig.update_layout(
#             width=1200,
#             height=400,
#             title_text="Orthogonal Cross Sections"
#         )
#         fig.show()

#     interact(
#         plot_cross_sections_horizontal,
#         x_idx=IntSlider(min=0, max=nx-1, step=1, value=nx//2, description='X index'),
#         y_idx=IntSlider(min=0, max=ny-1, step=1, value=ny//2, description='Y index'),
#         z_idx=IntSlider(min=0, max=nz-1, step=1, value=nz//2, description='Z index'),
#         vmax=FloatSlider(min=0, max=1, step=0.01, value=0.05, description='vmax'),
#         cmap=['Magma','Greys', 'Viridis', 'Cividis', 'Plasma'])
    
#     return reconstructed_image, image_3d

# if __name__ == "__main__":
#     optimized_example()

Optimized grid: [490, 477, 479], 111,956,670 voxels
Starting optimized GPU MLEM reconstruction...
Building system matrix with GPU acceleration...
LOR data check:
  LOR range: X=[-392.9, 528.4]
             Y=[-449.2, 500.0]
             Z=[-488.9, 463.3]
  Grid origin: [-448.3731384277344, -451.216064453125, -490.93780517578125]
  Grid size: [490, 477, 479]
  Voxel size: 2.0
Total intersections: 100,000
Validating 100000 intersection weights...
After cleaning: 0 valid intersections
ERROR: No valid intersections after cleaning!
System matrix created: shape=(100000, 111956670)
Data range: min=1.000000, max=1.000000
NaN in data: 0
Calculating sensitivity image...
Sensitivity image stats: min=0.000000, max=1.000000, mean=0.000000, zeros=111956669
System matrix: (100000, 111956670), 1 non-zeros
System matrix data check:
  Data stats: min=1.000000, max=1.000000
  NaN count: 0
  Inf count: 0
  Zero count: 0
System matrix nnz: 1
Sensitivity image stats: min=0.000000, max=1.000000, mean=0.00000

KeyboardInterrupt: 