In [None]:
import numpy as np

# Load the coincidence endpoints. The archive is opened as a context manager so
# its file handle is released as soon as both arrays have been read, and the
# name `lor_data` is only ever bound to the final (N, 6) array.
with np.load("/home/h/Opengate/Post_process/Outputs/coincidence_pairs.npz") as coincidence_file:
    xyz1 = coincidence_file['xyz1']  # first endpoint per event, expected shape (N, 3)
    xyz2 = coincidence_file['xyz2']  # second endpoint per event, expected shape (N, 3)

# One row per LOR: [x1, y1, z1, x2, y2, z2]
lor_data = np.hstack([xyz1, xyz2])  # Shape: (N, 6)

In [None]:
import numpy as np
import numba
from numba import cuda
import math


@cuda.jit(device=True)
def siddon_cuda(x1, y1, z1, x2, y2, z2, grid_size, voxel_size, grid_origin,
                voxel_indices, lengths):
    """
    Siddon-style ray tracing through a regular voxel grid (CUDA device function).

    Walks the segment (x1, y1, z1) -> (x2, y2, z2) through an axis-aligned grid
    and records every voxel it crosses together with the intersection length.

    Parameters:
    - x1, y1, z1, x2, y2, z2: segment endpoints (same length unit as voxel_size)
    - grid_size: (nx, ny, nz) number of voxels along each axis
    - voxel_size: voxel edge length; the same value is used on all three axes
    - grid_origin: (x0, y0, z0) coordinates of the grid's minimum corner
    - voxel_indices: output buffer receiving flat voxel indices in C order,
      (ix * ny + iy) * nz + iz, so that a flat image reshaped with
      ndarray.reshape((nx, ny, nz)) is addressed as image[ix, iy, iz]
    - lengths: output buffer receiving the intersection length of each
      recorded voxel

    Returns:
    - number of entries written to voxel_indices / lengths. Traversal stops
      once len(voxel_indices) entries have been written, so an undersized
      buffer silently truncates the path.
    """
    nx, ny, nz = grid_size
    x0, y0, z0 = grid_origin

    # Direction vector
    dx = x2 - x1
    dy = y2 - y1
    dz = z2 - z1

    # Ray length; a degenerate (zero-length) segment intersects nothing
    ray_length = math.sqrt(dx*dx + dy*dy + dz*dz)
    if ray_length < 1e-10:
        return 0

    # Normalize direction so the parameter t below is a distance from point 1
    dx /= ray_length
    dy /= ray_length
    dz /= ray_length

    # Compute parametric t values where the ray crosses the grid's bounding
    # planes. A ray (almost) parallel to an axis never crosses that axis'
    # planes, which is modelled with a +-1e10 interval.
    if abs(dx) > 1e-10:
        tx_min = (x0 - x1) / dx
        tx_max = (x0 + nx * voxel_size - x1) / dx
        if tx_min > tx_max:
            tx_min, tx_max = tx_max, tx_min
    else:
        tx_min = -1e10
        tx_max = 1e10

    if abs(dy) > 1e-10:
        ty_min = (y0 - y1) / dy
        ty_max = (y0 + ny * voxel_size - y1) / dy
        if ty_min > ty_max:
            ty_min, ty_max = ty_max, ty_min
    else:
        ty_min = -1e10
        ty_max = 1e10

    if abs(dz) > 1e-10:
        tz_min = (z0 - z1) / dz
        tz_max = (z0 + nz * voxel_size - z1) / dz
        if tz_min > tz_max:
            tz_min, tz_max = tz_max, tz_min
    else:
        tz_min = -1e10
        tz_max = 1e10

    # Overall entry and exit: latest entry and earliest exit over the three slabs
    t_min = max(max(tx_min, ty_min), tz_min)
    t_max = min(min(tx_max, ty_max), tz_max)

    # No overlap between the segment [0, ray_length] and the grid
    if t_min >= t_max or t_max < 0 or t_min > ray_length:
        return 0

    # Clip to the segment itself (endpoints may lie inside the grid)
    t_min = max(t_min, 0.0)
    t_max = min(t_max, ray_length)

    # Starting position
    x_start = x1 + t_min * dx
    y_start = y1 + t_min * dy
    z_start = z1 + t_min * dz

    # Starting voxel
    ix = int((x_start - x0) / voxel_size)
    iy = int((y_start - y0) / voxel_size)
    iz = int((z_start - z0) / voxel_size)

    # An entry point exactly on the far face (or off by rounding) would index
    # one voxel outside the grid, so clamp into range.
    ix = max(0, min(ix, nx - 1))
    iy = max(0, min(iy, ny - 1))
    iz = max(0, min(iz, nz - 1))

    # Step directions
    step_x = 1 if dx > 0 else -1
    step_y = 1 if dy > 0 else -1
    step_z = 1 if dz > 0 else -1

    # t of the next voxel boundary along each axis, and the t increment
    # between consecutive boundaries of that axis
    if abs(dx) > 1e-10:
        if dx > 0:
            t_next_x = ((ix + 1) * voxel_size + x0 - x1) / dx
        else:
            t_next_x = (ix * voxel_size + x0 - x1) / dx
        t_delta_x = voxel_size / abs(dx)
    else:
        t_next_x = 1e10
        t_delta_x = 1e10

    if abs(dy) > 1e-10:
        if dy > 0:
            t_next_y = ((iy + 1) * voxel_size + y0 - y1) / dy
        else:
            t_next_y = (iy * voxel_size + y0 - y1) / dy
        t_delta_y = voxel_size / abs(dy)
    else:
        t_next_y = 1e10
        t_delta_y = 1e10

    if abs(dz) > 1e-10:
        if dz > 0:
            t_next_z = ((iz + 1) * voxel_size + z0 - z1) / dz
        else:
            t_next_z = (iz * voxel_size + z0 - z1) / dz
        t_delta_z = voxel_size / abs(dz)
    else:
        t_next_z = 1e10
        t_delta_z = 1e10

    # Traverse voxels
    count = 0
    t_current = t_min
    max_voxels = len(voxel_indices)

    while t_current < t_max and count < max_voxels:
        # Nearest upcoming boundary, never beyond the exit point
        t_next = min(min(t_next_x, t_next_y), t_next_z)
        t_next = min(t_next, t_max)

        length = t_next - t_current

        # Skip zero-length segments (ray passing through an edge or corner)
        if length > 1e-10:
            # C-order flat index: z varies fastest, matching a C-order
            # reshape of the flat image to (nx, ny, nz).
            voxel_indices[count] = (ix * ny + iy) * nz + iz
            lengths[count] = length
            count += 1

        if t_next >= t_max:
            break

        # Advance along every axis whose boundary coincides with t_next; more
        # than one axis steps when the ray passes through an edge or corner.
        if abs(t_next - t_next_x) < 1e-9:
            ix += step_x
            if ix < 0 or ix >= nx:
                break
            t_next_x += t_delta_x

        if abs(t_next - t_next_y) < 1e-9:
            iy += step_y
            if iy < 0 or iy >= ny:
                break
            t_next_y += t_delta_y

        if abs(t_next - t_next_z) < 1e-9:
            iz += step_z
            if iz < 0 or iz >= nz:
                break
            t_next_z += t_delta_z

        t_current = t_next

    return count


# Capacity of the per-thread scratch buffers used by the kernel below.
# A straight line through an nx x ny x nz grid crosses at most
# nx + ny + nz - 2 voxels (382 for a 128^3 grid), and the ray tracer silently
# stops once the buffer is full, so this value must be at least that large for
# the grid in use. It has to be a compile-time constant because
# cuda.local.array only accepts constant shapes.
MAX_VOXELS_PER_LOR = 512


@cuda.jit
def simple_backproject_kernel(lors, backproj, nx, ny, nz, voxel_size, x0, y0, z0):
    """
    Simple backprojection kernel - just accumulate path lengths.

    One thread handles one LOR: it traces the LOR through the voxel grid and
    atomically adds each intersection length to the corresponding voxel.

    Parameters:
    - lors: 2-D device array, one row per LOR; columns 0-5 are read as
      [x1, y1, z1, x2, y2, z2]
    - backproj: flat device array of voxel accumulators (updated in place)
    - nx, ny, nz: grid dimensions in voxels
    - voxel_size: voxel edge length
    - x0, y0, z0: coordinates of the grid's minimum corner
    """
    idx = cuda.grid(1)

    # The launch grid is rounded up to whole blocks; surplus threads do nothing
    if idx >= lors.shape[0]:
        return

    # Per-thread scratch space for the traced path
    voxels = cuda.local.array(MAX_VOXELS_PER_LOR, numba.int32)
    lens = cuda.local.array(MAX_VOXELS_PER_LOR, numba.float32)

    x1 = lors[idx, 0]
    y1 = lors[idx, 1]
    z1 = lors[idx, 2]
    x2 = lors[idx, 3]
    y2 = lors[idx, 4]
    z2 = lors[idx, 5]

    count = siddon_cuda(x1, y1, z1, x2, y2, z2,
                       (nx, ny, nz), voxel_size, (x0, y0, z0),
                       voxels, lens)

    # Several LORs can hit the same voxel concurrently, hence the atomic add
    for i in range(count):
        v_idx = voxels[i]
        cuda.atomic.add(backproj, v_idx, lens[i])


def generate_cylindrical_detectors(radius, n_rings, n_detectors_per_ring, axial_fov):
    """
    Generate detector positions for a cylindrical PET scanner.

    Detectors are spaced uniformly in angle around each ring (starting at
    angle 0 on the +x axis) and rings are spaced uniformly along z, spanning
    [-axial_fov/2, +axial_fov/2]. A single-ring scanner is centred at z = 0 so
    the geometry stays symmetric about the origin.

    Parameters:
    - radius: scanner radius in mm
    - n_rings: number of axial rings
    - n_detectors_per_ring: number of detectors in each ring
    - axial_fov: axial field of view in mm

    Returns:
    - detectors: (n_rings * n_detectors_per_ring, 3) float32 array of detector
      positions, ordered ring by ring (index = ring * n_detectors_per_ring + det)
    """
    # Axial positions (z). np.linspace with a single sample would return only
    # the lower end point (-axial_fov/2), which is off-centre.
    if n_rings == 1:
        z_positions = np.zeros(1)
    else:
        z_positions = np.linspace(-axial_fov/2, axial_fov/2, n_rings)

    # Angular position of every detector within a ring
    angles = 2 * np.pi * np.arange(n_detectors_per_ring) / n_detectors_per_ring

    # Fill a (ring, detector, xyz) cube by broadcasting, then flatten it so the
    # detector index runs fastest within each ring.
    detectors = np.empty((n_rings, n_detectors_per_ring, 3), dtype=np.float32)
    detectors[:, :, 0] = radius * np.cos(angles)
    detectors[:, :, 1] = radius * np.sin(angles)
    detectors[:, :, 2] = z_positions[:, np.newaxis]

    return detectors.reshape(n_rings * n_detectors_per_ring, 3)


def generate_lors_from_detectors(detectors, n_rings, n_detectors_per_ring, 
                                 max_ring_difference=None):
    """
    Generate all valid LORs from detector pairs.

    Within a ring, a detector is only paired with detectors at least a quarter
    of the ring away from it (angular offset of n/4 .. 3n/4 detector positions
    relative to the first detector), so neighbouring detectors are never
    connected. Between two different rings every detector pair is used.
    Each unordered pair appears exactly once.

    Parameters:
    - detectors: (N, 3) array of detector positions, ordered ring by ring
      (index = ring * n_detectors_per_ring + detector)
    - n_rings: number of axial rings
    - n_detectors_per_ring: detectors per ring
    - max_ring_difference: maximum ring difference for coincidences (None = all)

    Returns:
    - lors: (M, 6) float32 array of LORs [x1,y1,z1,x2,y2,z2]; shape (0, 6)
      when no pair qualifies
    """
    if max_ring_difference is None:
        max_ring_difference = n_rings - 1

    detectors = np.asarray(detectors)
    n_det = n_detectors_per_ring

    print(f"Generating LORs from detector geometry...")
    print(f"  Max ring difference: {max_ring_difference}")

    det_ids = np.arange(n_det)

    # Same ring: offsets (in detector positions) measured from the first
    # detector, covering the opposite half of the ring.
    offsets = np.arange(n_det // 4, 3 * n_det // 4 + 1)
    same_det1 = np.repeat(det_ids, len(offsets))
    same_det2 = (same_det1 + np.tile(offsets, n_det)) % n_det
    # Every unordered pair is produced from both of its detectors; keep it once.
    once = same_det1 < same_det2
    same_det1 = same_det1[once]
    same_det2 = same_det2[once]

    # Different rings: every detector of one ring against every detector of
    # the other. ring1 < ring2 already guarantees uniqueness.
    cross_det1 = np.repeat(det_ids, n_det)
    cross_det2 = np.tile(det_ids, n_det)

    lor_chunks = []

    for ring1 in range(n_rings):
        for ring2 in range(ring1, min(ring1 + max_ring_difference + 1, n_rings)):
            if ring1 == ring2:
                det1, det2 = same_det1, same_det2
            else:
                det1, det2 = cross_det1, cross_det2

            idx1 = ring1 * n_det + det1
            idx2 = ring2 * n_det + det2
            # One vectorised gather per ring pair instead of one hstack per LOR
            lor_chunks.append(np.hstack([detectors[idx1], detectors[idx2]]))

        if (ring1 + 1) % 10 == 0:
            print(f"  Processed ring {ring1 + 1}/{n_rings}")

    if lor_chunks:
        lors = np.vstack(lor_chunks).astype(np.float32)
    else:
        # Keep the documented (M, 6) layout even when there is nothing to pair
        lors = np.empty((0, 6), dtype=np.float32)
    print(f"Total LORs generated: {len(lors):,}")

    return lors


def backproject_gpu_batch(lors, grid_size, voxel_size, grid_origin, batch_size=100000):
    """
    Backproject LORs on the GPU, uploading them in batches.

    Parameters:
    - lors: array with one LOR per row [x1,y1,z1,x2,y2,z2]
    - grid_size: tuple (nx, ny, nz) - voxel grid dimensions
    - voxel_size: voxel edge length
    - grid_origin: tuple (x0, y0, z0) - minimum corner of the grid
    - batch_size: number of LORs sent to the GPU per kernel launch

    Returns:
    - flat float32 array of length nx * ny * nz with the accumulated
      path lengths
    """
    dim_x, dim_y, dim_z = grid_size
    origin_x, origin_y, origin_z = grid_origin
    total_lors = lors.shape[0]

    # The accumulator lives on the device for the whole run; only the LOR
    # batches travel host -> device.
    device_image = cuda.to_device(np.zeros(dim_x * dim_y * dim_z, dtype=np.float32))

    block_dim = 256

    for first in range(0, total_lors, batch_size):
        last = min(first + batch_size, total_lors)
        chunk = lors[first:last]

        device_chunk = cuda.to_device(chunk.astype(np.float32))

        # Ceiling division: enough blocks for one thread per LOR
        grid_dim = (chunk.shape[0] + block_dim - 1) // block_dim
        simple_backproject_kernel[grid_dim, block_dim](
            device_chunk, device_image,
            dim_x, dim_y, dim_z,
            voxel_size,
            origin_x, origin_y, origin_z,
        )

        # Report on round multiples and once more at the very end
        finished = last == total_lors
        if last % 500000 == 0 or finished:
            print(f"    Processed {last:,}/{total_lors:,} LORs")

    # Bring the accumulated image back to host memory
    return device_image.copy_to_host()


def compute_sensitivity_map(scanner_radius, n_rings, n_detectors_per_ring, axial_fov,
                           grid_size, voxel_size, max_ring_difference=None, 
                           batch_size=100000):
    """
    Build the sensitivity image of a virtual cylindrical PET scanner.

    Every LOR allowed by the scanner geometry is backprojected onto a voxel
    grid centred at the origin; the accumulated path lengths form the map.

    Parameters:
    - scanner_radius: scanner radius in mm
    - n_rings: number of axial rings
    - n_detectors_per_ring: number of detectors per ring
    - axial_fov: axial field of view in mm
    - grid_size: tuple (nx, ny, nz) - voxel grid dimensions
    - voxel_size: voxel size in mm
    - max_ring_difference: maximum ring difference for coincidences (None = all rings)
    - batch_size: LORs per GPU batch

    Returns:
    - sensitivity_map: flat backprojection reshaped to grid_size
    """
    rule = "=" * 60
    print("\n" + rule)
    print("COMPUTING SENSITIVITY MAP")
    print(rule)

    # Echo the scanner description before the expensive part starts
    print("\nScanner geometry:")
    print(f"  Radius: {scanner_radius} mm")
    print(f"  Rings: {n_rings}")
    print(f"  Detectors per ring: {n_detectors_per_ring}")
    print(f"  Axial FOV: {axial_fov} mm")

    # Detector positions -> all geometric LORs
    detector_positions = generate_cylindrical_detectors(
        scanner_radius, n_rings, n_detectors_per_ring, axial_fov
    )
    geometric_lors = generate_lors_from_detectors(
        detector_positions, n_rings, n_detectors_per_ring, max_ring_difference
    )

    # Minimum corner of a grid centred on the origin
    nx, ny, nz = grid_size
    grid_origin = tuple(-n * voxel_size / 2.0 for n in (nx, ny, nz))

    print("\nBackprojecting LORs to compute sensitivity...")
    print(f"  Grid size: {grid_size}")
    print(f"  Voxel size: {voxel_size} mm")
    print(f"  Batch size: {batch_size:,}")

    sensitivity_map = backproject_gpu_batch(
        geometric_lors, grid_size, voxel_size, grid_origin, batch_size
    ).reshape(grid_size)

    lowest = sensitivity_map.min()
    highest = sensitivity_map.max()
    print("\nSensitivity map computed!")
    print(f"  Range: [{lowest:.3f}, {highest:.3f}]")
    print(f"  Non-zero voxels: {np.count_nonzero(sensitivity_map):,}")

    return sensitivity_map


def normalized_backproject_gpu(data_lors, sensitivity_map, grid_size, voxel_size, 
                               batch_size=100000):
    """
    Backproject measured LORs and divide the result by a sensitivity map.

    Parameters:
    - data_lors: (N, 6) array of data LOR endpoints [x1,y1,z1,x2,y2,z2] in mm
    - sensitivity_map: pre-computed sensitivity map with the same shape as
      the backprojected image (grid_size)
    - grid_size: tuple (nx, ny, nz) - voxel grid dimensions
    - voxel_size: voxel size in mm
    - batch_size: number of LORs per GPU batch

    Returns:
    - normalized_image: backprojection divided by sensitivity; voxels whose
      sensitivity is not above 1e-10 are left at zero
    """
    rule = "=" * 60
    print("\n" + rule)
    print("BACKPROJECTING DATA")
    print(rule)

    nx, ny, nz = grid_size
    total_lors = data_lors.shape[0]

    # Minimum corner of a grid centred on the origin
    grid_origin = tuple(-n * voxel_size / 2.0 for n in (nx, ny, nz))

    print("\nData backprojection:")
    print(f"  Total data LORs: {total_lors:,}")
    print(f"  Grid size: {grid_size}")
    print(f"  Voxel size: {voxel_size} mm")

    backproj_image = backproject_gpu_batch(
        data_lors, grid_size, voxel_size, grid_origin, batch_size
    ).reshape(grid_size)

    print("\nNormalizing with sensitivity map...")

    # Divide only where the sensitivity is meaningfully positive; everything
    # else stays zero, which avoids division by zero.
    min_sensitivity = 1e-10
    usable = sensitivity_map > min_sensitivity
    normalized_image = np.zeros_like(backproj_image)
    normalized_image[usable] = backproj_image[usable] / sensitivity_map[usable]

    lowest = normalized_image.min()
    highest = normalized_image.max()
    print("\nNormalized backprojection complete!")
    print(f"  Range: [{lowest:.3f}, {highest:.3f}]")
    print(f"  Non-zero voxels: {np.count_nonzero(normalized_image):,}")

    return normalized_image


# Report whether a CUDA device is usable before launching any kernels
_gpu_present = cuda.is_available()
print(f"CUDA available: {_gpu_present}")
if _gpu_present:
    print(f"GPU: {cuda.get_current_device().name.decode()}")

# --- Virtual scanner geometry ---
scanner_radius = 400.0        # mm
n_rings = 10
n_detectors_per_ring = 128
axial_fov = 160.0             # mm
max_ring_difference = None    # None = accept every ring pair (no axial limit)

# --- Reconstruction grid ---
grid_size = (128, 128, 128)   # voxels along each axis
voxel_size = 2.0              # mm
batch_size = 500_000          # LORs per GPU batch

# Step 1: sensitivity map from the scanner geometry alone
sensitivity_map = compute_sensitivity_map(
    scanner_radius,
    n_rings,
    n_detectors_per_ring,
    axial_fov,
    grid_size,
    voxel_size,
    max_ring_difference=max_ring_difference,
    batch_size=batch_size,
)

# Step 2: backproject the measured coincidences and normalize by sensitivity
reconstructed = normalized_backproject_gpu(
    lor_data,
    sensitivity_map,
    grid_size,
    voxel_size,
    batch_size=batch_size,
)

# # Visualization
# import matplotlib.pyplot as plt

# fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# # Sensitivity map
# axes[0].imshow(sensitivity_map[:, :, 64], cmap='hot')
# axes[0].set_title('Sensitivity Map (Central Slice)')
# axes[0].set_xlabel('X')
# axes[0].set_ylabel('Y')

# # Raw backprojection (before normalization)
# backproj_raw = backproject_gpu_batch(
#     data_lors, grid_size, voxel_size,
#     (-grid_size[0]*voxel_size/2, -grid_size[1]*voxel_size/2, -grid_size[2]*voxel_size/2),
#     batch_size
# ).reshape(grid_size)
# axes[1].imshow(backproj_raw[:, :, 64], cmap='hot')
# axes[1].set_title('Raw Backprojection')
# axes[1].set_xlabel('X')
# axes[1].set_ylabel('Y')

# # Normalized reconstruction
# axes[2].imshow(reconstructed[:, :, 64], cmap='hot')
# axes[2].set_title('Normalized Reconstruction')
# axes[2].set_xlabel('X')
# axes[2].set_ylabel('Y')

# for ax in axes:
#     ax.axis('equal')

# plt.tight_layout()
# plt.savefig('pet_reconstruction_comparison.png', dpi=150, bbox_inches='tight')
# plt.show()

import plotly
import plotly.express as px

# Animate through slices on a log10 colour scale.
# Voxels that are not strictly positive have no finite logarithm; they are left
# as NaN instead of producing -inf values and a RuntimeWarning.
positive_voxels = reconstructed > 0
log_reconstructed = np.full(reconstructed.shape, np.nan, dtype=np.float32)
log_reconstructed[positive_voxels] = np.log10(reconstructed[positive_voxels])

# The colour limits must be in the same (log10) units as the plotted image,
# and are taken from the finite values only.
if positive_voxels.any():
    log_zmin = float(log_reconstructed[positive_voxels].min())
    log_zmax = float(log_reconstructed[positive_voxels].max())
else:
    log_zmin = log_zmax = None

fig = px.imshow(
    log_reconstructed,
    animation_frame=0,
    zmin=log_zmin,
    zmax=log_zmax,
    color_continuous_scale='Hot',
)
# .show() returns None, so it is called separately to keep `fig` usable
fig.show()

In [None]:

# Animate through slices of the sensitivity map (one frame per index of axis 0)
fig = px.imshow(
    sensitivity_map,
    animation_frame=0,
    zmax=sensitivity_map.max(),
    color_continuous_scale='Hot',
)
# .show() returns None, so it is called separately to keep `fig` usable
fig.show()