In [None]:
import numpy as np

# Load the coincidence pairs and pack each LOR as one row [x1, y1, z1, x2, y2, z2].
# np.load on an .npz returns an NpzFile that keeps the archive open; the context
# manager closes it once both arrays have been read into memory.
with np.load("/home/h/Opengate/Post_process/Outputs/coincidence_pairs.npz") as npz:
    xyz1 = npz['xyz1']  # Shape: (N, 3)
    xyz2 = npz['xyz2']  # Shape: (N, 3)
lor_data = np.hstack([xyz1, xyz2])  # Shape: (N, 6)

In [None]:
import numpy as np
import numba
from numba import cuda
import math


@cuda.jit(device=True)
def _slab_interval(p1, d, lo, n, voxel_size):
    """Return (t_enter, t_exit): the range of t for which p1 + t * d lies in
    the slab [lo, lo + n * voxel_size] along one axis.

    If the direction component is (numerically) zero, the ray is either inside
    the slab for every t or for none. The full interval (-1e10, 1e10) or the
    empty interval (1e10, -1e10) is returned accordingly, so the caller's
    `t_min >= t_max` test rejects rays that run parallel to, but outside, the grid.
    """
    hi = lo + n * voxel_size
    if abs(d) > 1e-10:
        t_a = (lo - p1) / d
        t_b = (hi - p1) / d
        if t_a > t_b:
            return t_b, t_a
        return t_a, t_b
    if p1 < lo or p1 > hi:
        return 1e10, -1e10
    return -1e10, 1e10


@cuda.jit(device=True)
def _first_crossing(p1, d, lo, idx, voxel_size):
    """Return (t_next, t_delta) for one axis.

    t_next is the t at which the ray leaves voxel `idx` across a boundary of
    this axis. t_delta is the t spacing between consecutive boundaries of this
    axis. An axis with a (numerically) zero direction component never crosses,
    so (1e10, 1e10) is returned.
    """
    if abs(d) > 1e-10:
        if d > 0:
            t_next = ((idx + 1) * voxel_size + lo - p1) / d
        else:
            t_next = (idx * voxel_size + lo - p1) / d
        return t_next, voxel_size / abs(d)
    return 1e10, 1e10


@cuda.jit(device=True)
def _clamped_index(p, lo, n, voxel_size):
    """Return the voxel index containing coordinate p, clamped to [0, n - 1]."""
    i = int((p - lo) / voxel_size)
    return max(0, min(i, n - 1))


@cuda.jit(device=True)
def siddon_cuda(x1, y1, z1, x2, y2, z2, grid_size, voxel_size, grid_origin,
                voxel_indices, lengths):
    """
    Siddon-style ray tracing for CUDA (device function).

    Walks the segment (x1, y1, z1) -> (x2, y2, z2) through a regular voxel grid
    and records every voxel it crosses, in order along the segment.

    grid_size     -- (nx, ny, nz) number of voxels per axis
    voxel_size    -- voxel edge length, identical on all axes, in the same
                     units as the coordinates
    grid_origin   -- (x0, y0, z0): low corner of voxel (0, 0, 0)
    voxel_indices -- output buffer; entry k receives the flat index
                     ix + iy * nx + iz * nx * ny of the k-th crossed voxel
    lengths       -- output buffer; entry k receives the segment's length
                     inside that voxel

    Returns the count of entries written. The walk stops once
    len(voxel_indices) entries have been written, so any further voxels
    are silently dropped; size the buffers for nx + ny + nz - 2 entries.
    """
    nx, ny, nz = grid_size
    x0, y0, z0 = grid_origin

    # Direction vector, normalised so that t measures distance along the segment.
    dx = x2 - x1
    dy = y2 - y1
    dz = z2 - z1
    ray_length = math.sqrt(dx*dx + dy*dy + dz*dz)
    if ray_length < 1e-10:
        return 0
    dx /= ray_length
    dy /= ray_length
    dz /= ray_length

    # Entry/exit t of the grid bounding box = intersection of the three slabs.
    tx_min, tx_max = _slab_interval(x1, dx, x0, nx, voxel_size)
    ty_min, ty_max = _slab_interval(y1, dy, y0, ny, voxel_size)
    tz_min, tz_max = _slab_interval(z1, dz, z0, nz, voxel_size)

    t_min = max(max(tx_min, ty_min), tz_min)
    t_max = min(min(tx_max, ty_max), tz_max)

    # Reject rays that miss the box, or whose box overlap lies outside [0, ray_length].
    if t_min >= t_max or t_max < 0 or t_min > ray_length:
        return 0

    t_min = max(t_min, 0.0)
    t_max = min(t_max, ray_length)

    # Starting voxel (clamped: the entry point may sit exactly on the far face).
    ix = _clamped_index(x1 + t_min * dx, x0, nx, voxel_size)
    iy = _clamped_index(y1 + t_min * dy, y0, ny, voxel_size)
    iz = _clamped_index(z1 + t_min * dz, z0, nz, voxel_size)

    step_x = 1 if dx > 0 else -1
    step_y = 1 if dy > 0 else -1
    step_z = 1 if dz > 0 else -1

    t_next_x, t_delta_x = _first_crossing(x1, dx, x0, ix, voxel_size)
    t_next_y, t_delta_y = _first_crossing(y1, dy, y0, iy, voxel_size)
    t_next_z, t_delta_z = _first_crossing(z1, dz, z0, iz, voxel_size)

    # Traverse voxels: each step ends at the nearest boundary crossing on any axis.
    count = 0
    t_current = t_min
    max_voxels = len(voxel_indices)

    while t_current < t_max and count < max_voxels:
        t_next = min(min(t_next_x, t_next_y), t_next_z)
        t_next = min(t_next, t_max)

        length = t_next - t_current

        # Zero-length pieces occur when the walk starts exactly on a boundary.
        if length > 1e-10:
            voxel_indices[count] = ix + iy * nx + iz * nx * ny
            lengths[count] = length
            count += 1

        if t_next >= t_max:
            break

        # Separate ifs (not elif) so that simultaneous crossings of an edge or
        # corner advance every affected axis.
        if abs(t_next - t_next_x) < 1e-9:
            ix += step_x
            if ix < 0 or ix >= nx:
                break
            t_next_x += t_delta_x

        if abs(t_next - t_next_y) < 1e-9:
            iy += step_y
            if iy < 0 or iy >= ny:
                break
            t_next_y += t_delta_y

        if abs(t_next - t_next_z) < 1e-9:
            iz += step_z
            if iz < 0 or iz >= nz:
                break
            t_next_z += t_delta_z

        t_current = t_next

    return count


# Capacity of each thread's voxel buffer. A straight segment can cross up to
# nx + ny + nz - 2 voxels of an (nx, ny, nz) grid, and siddon_cuda silently
# drops voxels beyond the buffer size. The default 256^3 grid needs 766 entries,
# and 512^3 needs 1534. Grids whose nx + ny + nz - 2 exceeds this value need it
# raised, which costs 8 bytes of per-thread local memory per entry.
# Numba treats this module global as a compile-time constant.
MAX_VOXELS_PER_LOR = 2048


@cuda.jit
def simple_backproject_kernel(lors, backproj, nx, ny, nz, voxel_size, x0, y0, z0):
    """One thread per LOR: trace row `idx` of `lors` through the grid and add
    each voxel's intersection length into the flat `backproj` array.

    Several threads can hit the same voxel, so accumulation uses atomic adds.
    """
    idx = cuda.grid(1)

    # The launch grid is rounded up to whole blocks; surplus threads exit.
    if idx >= lors.shape[0]:
        return

    voxels = cuda.local.array(MAX_VOXELS_PER_LOR, numba.int32)
    lens = cuda.local.array(MAX_VOXELS_PER_LOR, numba.float32)

    count = siddon_cuda(lors[idx, 0], lors[idx, 1], lors[idx, 2],
                        lors[idx, 3], lors[idx, 4], lors[idx, 5],
                        (nx, ny, nz), voxel_size, (x0, y0, z0),
                        voxels, lens)

    # Simply accumulate the path lengths
    for i in range(count):
        cuda.atomic.add(backproj, voxels[i], lens[i])


def simple_backproject_gpu(lors, grid_size, voxel_size, batch_size=100000,
                           grid_origin=None):
    """
    Simple unfiltered backprojection on GPU.

    Every LOR adds, to each voxel it crosses, the length of its intersection
    with that voxel. No normalisation or filtering is applied.

    Parameters:
    - lors: (N, 6) array of LOR endpoints [x1,y1,z1,x2,y2,z2] in mm
    - grid_size: tuple (nx, ny, nz) - voxel grid dimensions
    - voxel_size: voxel size in mm (same on all three axes)
    - batch_size: number of LORs per GPU batch (adjust based on GPU memory)
    - grid_origin: optional (x0, y0, z0) low corner of voxel (0, 0, 0) in mm;
      when None (default) the grid is centred on (0, 0, 0)

    Returns:
    - image: (nx, ny, nz) backprojected image, indexed image[ix, iy, iz]

    Raises:
    - ValueError: if lors is not an (N, 6) array or batch_size < 1
    """
    lors = np.asarray(lors)
    if lors.ndim != 2 or lors.shape[1] != 6:
        raise ValueError(f"lors must have shape (N, 6), got {lors.shape}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    nx, ny, nz = grid_size
    n_voxels = nx * ny * nz
    n_lors = lors.shape[0]

    if grid_origin is None:
        # Center grid at origin
        grid_origin = (
            -nx * voxel_size / 2.0,
            -ny * voxel_size / 2.0,
            -nz * voxel_size / 2.0,
        )
    x0, y0, z0 = (float(c) for c in grid_origin)

    # The accumulation image stays on the device for the whole run; each batch
    # only transfers its LORs.
    d_backproj = cuda.to_device(np.zeros(n_voxels, dtype=np.float32))

    threads_per_block = 256
    progress_every = 500_000  # LORs between progress lines

    print("Starting backprojection on GPU...")
    print(f"Grid size: {grid_size}, Voxel size: {voxel_size} mm")
    print(f"Grid origin: ({x0:.1f}, {y0:.1f}, {z0:.1f}) mm")
    print(f"Total LORs: {n_lors}, Batch size: {batch_size}")

    for start_idx in range(0, n_lors, batch_size):
        end_idx = min(start_idx + batch_size, n_lors)
        n_batch = end_idx - start_idx

        d_lors = cuda.to_device(lors[start_idx:end_idx].astype(np.float32))

        blocks_per_grid = (n_batch + threads_per_block - 1) // threads_per_block
        simple_backproject_kernel[blocks_per_grid, threads_per_block](
            d_lors, d_backproj, nx, ny, nz, voxel_size, x0, y0, z0
        )

        # Compare interval indices rather than testing end_idx % progress_every,
        # so that batch sizes not dividing progress_every still report progress.
        if (end_idx // progress_every > start_idx // progress_every) or (end_idx == n_lors):
            print(f"  Processed {end_idx}/{n_lors} LORs")

    backproj = d_backproj.copy_to_host()

    # The kernel's flat index is ix + iy*nx + iz*nx*ny (x varies fastest), which
    # is Fortran order for an (nx, ny, nz) array. A C-order reshape would
    # yield (z, y, x)-ordered data and scramble non-cubic grids.
    return backproj.reshape((nx, ny, nz), order="F")


print(f"CUDA available: {cuda.is_available()}")
if cuda.is_available():
    gpu_name = cuda.get_current_device().name
    # Device.name may be bytes or str depending on the numba version;
    # decode only when needed so this works with either.
    if isinstance(gpu_name, bytes):
        gpu_name = gpu_name.decode()
    print(f"GPU: {gpu_name}")

n_lors = len(lor_data)

# Reconstruction grid (voxels per axis). Other sizes tried:
# (128, 128, 128), (512, 512, 512), (1024, 1024, 1024)
grid_size = (256, 256, 256)

voxel_size = 2.0  # mm; 0.5 was also tried

reconstructed = simple_backproject_gpu(
    lors=lor_data,
    grid_size=grid_size,  # (nx, ny, nz)
    voxel_size=voxel_size,  # mm
    batch_size=500_000,
)

print(f"\nImage shape: {reconstructed.shape}")
print(f"Image range: [{reconstructed.min():.3f}, {reconstructed.max():.3f}]")
print(f"Non-zero voxels: {np.count_nonzero(reconstructed)}")

# Display central slice
# import matplotlib.pyplot as plt
# plt.figure(figsize=(8, 8))
# plt.imshow(reconstructed[:, grid_size[0]//2, :], cmap='hot')
# plt.colorbar(label='Accumulated path length')
# plt.title('Central Slice - Simple Backprojection')
# plt.xlabel('X')
# plt.ylabel('Y')
# plt.tight_layout()
# plt.show()

import plotly
import plotly.express as px

# Animate through slices along axis 0 of the volume.
# Figure.show() returns None, so keep the Figure in `fig` and show it separately.
fig = px.imshow(
    reconstructed,
    animation_frame=0,
    zmax=reconstructed.max(),
    color_continuous_scale='Hot',
)
fig.show()