In [30]:
import numpy as np
from tqdm import trange

def mm_to_voxel(coords_mm, vol_shape, voxel_size_mm):
    """Map millimetre coordinates to fractional voxel coordinates.

    The world point (0, 0, 0) mm is placed at the centre of the volume, i.e. it
    maps to ``vol_shape / 2`` in voxel units. The result is not rounded or
    clipped, so points outside the volume give coordinates outside
    ``[0, vol_shape]``.
    """
    # Half of the volume's physical extent along each axis, in mm.
    half_extent_mm = np.array(vol_shape) * voxel_size_mm / 2
    shifted_mm = coords_mm + half_extent_mm
    return shifted_mm / voxel_size_mm

def siddon_ray_trace(start, end, vol_shape, voxel_size):
    """Trace the segment ``start -> end`` through a voxel grid.

    Coordinates are in voxel units: voxel ``(i, j, k)`` occupies
    ``[i, i+1) x [j, j+1) x [k, k+1)`` and the volume spans
    ``[0, vol_shape[a]]`` along every axis ``a``. The ray is parameterised as
    ``start + t * (end - start)`` and only the part with ``0 <= t <= 1`` that
    lies inside the volume is traversed.

    Args:
      start, end: length-3 sequences with the segment end points (voxel units).
      vol_shape: (nx, ny, nz) number of voxels along each axis.
      voxel_size: voxel edge length; every intersection length is multiplied
        by it (so lengths come out in the unit of ``voxel_size``).

    Returns:
      ``(voxels, lengths)`` where ``voxels`` is an ``(n, 3)`` int array of the
      voxel indices crossed, in traversal order, and ``lengths`` is the
      matching ``(n,)`` float array of intersection lengths. Both are empty
      when the segment misses the volume, has zero length, or contains
      non-finite coordinates.
    """
    start = np.asarray(start, dtype=np.float64)
    end = np.asarray(end, dtype=np.float64)
    shape = np.asarray(vol_shape, dtype=np.int64)
    no_hit = (np.empty((0, 3), dtype=int), np.array([], dtype=float))

    d = end - start
    ray_len = float(np.linalg.norm(d))
    # A zero-length (start == end) or NaN/inf ray has no direction to march
    # along; without this guard the traversal loop below could never finish.
    if not np.isfinite(ray_len) or ray_len == 0.0:
        return no_hit

    # Slab clipping against [0, shape] on each axis, restricted to the
    # segment itself (t in [0, 1]) so the ray does not run past `end`.
    t_enter, t_exit = 0.0, 1.0
    for ax in range(3):
        if d[ax] == 0.0:
            # Ray is parallel to this slab: it is either always inside or
            # always outside. Handled explicitly to avoid 0/0 -> NaN.
            if not (0.0 <= start[ax] < shape[ax]):
                return no_hit
            continue
        t_a = (0.0 - start[ax]) / d[ax]
        t_b = (shape[ax] - start[ax]) / d[ax]
        t_enter = max(t_enter, min(t_a, t_b))
        t_exit = min(t_exit, max(t_a, t_b))

    if not t_enter < t_exit:
        return no_hit

    # Entry voxel. Clipping absorbs round-off at the entry face (e.g. -1e-17
    # flooring to -1) and the case of entering through the upper face, where
    # the coordinate equals shape and the voxel is shape - 1.
    p = start + t_enter * d
    voxel = np.clip(np.floor(p), 0, shape - 1).astype(int)

    step = np.sign(d).astype(int)
    # t_max[a]: absolute ray parameter (measured from `start`) at which the
    # ray crosses the next voxel boundary along axis a; t_delta[a]: parameter
    # increment per voxel along axis a. Axes with d == 0 never get crossed.
    t_max = np.full(3, np.inf)
    t_delta = np.full(3, np.inf)
    for ax in range(3):
        if d[ax] != 0.0:
            next_plane = voxel[ax] + 1 if step[ax] > 0 else voxel[ax]
            t_max[ax] = (next_plane - start[ax]) / d[ax]
            t_delta[ax] = abs(1.0 / d[ax])

    voxels = []
    lengths = []
    t = t_enter

    while t < t_exit:
        # Safety net: leave as soon as the current voxel is outside the grid.
        if not (0 <= voxel[0] < shape[0] and
                0 <= voxel[1] < shape[1] and
                0 <= voxel[2] < shape[2]):
            break

        axis = int(np.argmin(t_max))
        # Never integrate beyond the exit point of the clipped segment.
        t_next = min(t_max[axis], t_exit)

        # Zero-length crossings (ray passing exactly through an edge/corner,
        # or starting on a boundary) contribute nothing and are skipped.
        if t_next > t:
            voxels.append(voxel.copy())
            lengths.append((t_next - t) * ray_len * voxel_size)
            t = t_next

        voxel[axis] += step[axis]
        t_max[axis] += t_delta[axis]

    if not voxels:
        return no_hit

    return np.array(voxels, dtype=int), np.array(lengths, dtype=float)

def forward_project(image, lors, vol_shape, voxel_size):
    """Compute one line integral of `image` per LOR.

    Each row of `lors` holds the two end points (first three / last three
    entries) handed to `siddon_ray_trace`. The projection is the sum of the
    traversed voxel values weighted by their intersection lengths. LORs for
    which the tracer returns no voxels get the placeholder value 1e-8.

    Returns a float32 array with one entry per LOR.
    """
    # Start from the "no voxels hit" placeholder and overwrite real hits.
    result = np.full(len(lors), 1e-8, dtype=np.float32)
    for row, endpoints in enumerate(lors):
        hit_voxels, seg_lengths = siddon_ray_trace(
            endpoints[:3], endpoints[3:], vol_shape, voxel_size
        )
        if len(hit_voxels) == 0:
            continue
        flat_index = tuple(hit_voxels.T)
        result[row] = np.sum(image[flat_index] * seg_lengths)
    return result

def back_project(ratio, lors, vol_shape, voxel_size):
    """Smear one weight per LOR back into a volume.

    For LOR ``i`` every voxel returned by `siddon_ray_trace` receives
    ``ratio[i] * intersection_length``. LORs for which the tracer returns no
    voxels are skipped.

    Returns a float32 array of shape `vol_shape`.
    """
    volume = np.zeros(vol_shape, dtype=np.float32)
    for row, endpoints in enumerate(lors):
        hit_voxels, seg_lengths = siddon_ray_trace(
            endpoints[:3], endpoints[3:], vol_shape, voxel_size
        )
        if len(hit_voxels) == 0:
            continue
        flat_index = tuple(hit_voxels.T)
        volume[flat_index] += ratio[row] * seg_lengths
    return volume

def mlem_reconstruction(lors_mm, num_iters=10, vol_shape=(128,128,128), voxel_size=2.0):
    """
    Run MLEM on a list of LORs given in millimetres.

    Every LOR is treated as a single event: the per-LOR ratio that gets
    back-projected is ``1 / forward_projection``. The sensitivity image is the
    back-projection of ones along the same LORs that are passed in.

    Args:
      lors_mm: np.array of shape (num_pairs, 6) with (x1,y1,z1,x2,y2,z2) in mm
      num_iters: number of MLEM iterations
      vol_shape: output volume shape (nx, ny, nz)
      voxel_size: voxel size in mm

    Returns:
      3D float32 numpy array of shape `vol_shape` with the reconstructed image
    """
    # Both end points of each LOR, expressed in voxel coordinates.
    first_points = mm_to_voxel(lors_mm[:, :3], vol_shape, voxel_size)
    second_points = mm_to_voxel(lors_mm[:, 3:], vol_shape, voxel_size)
    lors_vox = np.hstack([first_points, second_points])

    # Sensitivity image; voxels no LOR touches get a tiny value so the
    # division in the update step is defined.
    unit_weights = np.ones(len(lors_vox))
    sensitivity = back_project(unit_weights, lors_vox, vol_shape, voxel_size)
    sensitivity[sensitivity == 0] = 1e-8

    # Uniform initial estimate.
    estimate = np.ones(vol_shape, dtype=np.float32)

    for _ in trange(num_iters, desc="MLEM iterations"):
        expected = forward_project(estimate, lors_vox, vol_shape, voxel_size)
        expected[expected == 0] = 1e-8  # keep the reciprocal finite
        correction = back_project(1.0 / expected, lors_vox, vol_shape, voxel_size)
        estimate *= correction / sensitivity

    return estimate


In [31]:
# LOAD LOR DATA
# Initial Loading, Filtering, and Coordinate Range Calculation
import numpy as np

# Load the LOR end points from disk
coordinates = np.load(fr"C:\Users\h\Desktop\PetStuff\Image_Processing\ground_truth.npy")

# Expected layout: (pairs, 6) with columns (x1, y1, z1, x2, y2, z2)
print(f"\nData Shape (pairs, coords) : {coordinates.shape}\n")  

# Keep only the pairs in which none of the six coordinates is exactly 0
has_zero_coord = (coordinates == 0).any(axis=1)
filtered_coordinates = coordinates[~has_zero_coord]
print(f"Filtered shape: {filtered_coordinates.shape}\n")

# Stack both end points of every pair into one (2 * pairs, 3) array of (x, y, z)
# points so the per-axis extent can be read off directly.
all_xyz = filtered_coordinates.reshape(-1, 3)
x_vals, y_vals, z_vals = all_xyz.T
print(f"x range: min={x_vals.min()}, max={x_vals.max()}")
print(f"y range: min={y_vals.min()}, max={y_vals.max()}")
print(f"z range: min={z_vals.min()}, max={z_vals.max()}")


Data Shape (pairs, coords) : (62660, 6)

Filtered shape: (6591, 6)

x range: min=-278.12942504882807, max=278.1666564941406
y range: min=-278.41946411132807, max=277.8843688964844
z range: min=-147.99453735351562, max=147.9492950439453


In [35]:
# Reconstruct from the filtered LORs (shape (num_pairs, 6), units in mm) on a
# 600 x 600 x 600 grid of 1 mm voxels, using 4 MLEM iterations.
reconstructed_volume = mlem_reconstruction(
    filtered_coordinates,
    num_iters=4,
    vol_shape=(600, 600, 600),
    voxel_size=1.0,
)


MLEM iterations: 100%|██████████| 4/4 [02:52<00:00, 43.20s/it]


In [36]:
# VISUALISATION FUNCTIONS
import plotly.graph_objects as go
import plotly.express as px
import numpy as np
from ipywidgets import interact, IntSlider, FloatSlider

def visualize_voxel_tensor_3d(voxel_tensor, initial_min_threshold=None, initial_max_threshold=None, 
                               voxel_size_mm=1.0, world_origin=None, min_threshold=None, max_threshold=None):
    """
    Interactive 3D scatter view of a voxel tensor with min/max threshold sliders.

    Only voxels with a value > 0 are considered. Each time a slider changes,
    the voxels whose value lies inside [min, max] are drawn as a Plotly
    Scatter3d coloured by value. If the tensor has no voxel > 0, a message is
    printed and nothing is displayed.

    Args:
        voxel_tensor: (nx, ny, nz) numpy array of voxel values
        initial_min_threshold: Initial minimum threshold value for the slider
            (default: lower end of the slider range); clamped to the slider range
        initial_max_threshold: Initial maximum threshold value for the slider
            (default: upper end of the slider range); clamped to the slider range
        voxel_size_mm: Size of each voxel in mm (default: 1.0mm)
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0) (optional)
        min_threshold: Lower end of the slider range (default: smallest positive voxel value)
        max_threshold: Upper end of the slider range (default: largest voxel value)

    Returns:
        None. The widget and figure are displayed as a side effect.
    """
    # Extract non-zero voxel coordinates and values
    coords = np.where(voxel_tensor > 0)
    x_coords, y_coords, z_coords = coords
    values = voxel_tensor[coords]

    # np.min / np.max raise on an empty array, so bail out early when there is
    # nothing to plot instead of failing with a ValueError.
    if values.size == 0:
        print("No voxels with value > 0 to display")
        return

    # Convert voxel indices to world coordinates if world_origin provided
    if world_origin is not None:
        x_min, y_min, z_min = world_origin
        x_coords_world = x_coords * voxel_size_mm + x_min
        y_coords_world = y_coords * voxel_size_mm + y_min
        z_coords_world = z_coords * voxel_size_mm + z_min
        coord_suffix = " (mm)"
    else:
        x_coords_world = x_coords * voxel_size_mm
        y_coords_world = y_coords * voxel_size_mm
        z_coords_world = z_coords * voxel_size_mm
        coord_suffix = f" (×{voxel_size_mm}mm)"

    # Get value range for sliders
    min_val = float(np.min(values))
    max_val = float(np.max(values))

    # Use user-specified min/max threshold range if provided
    slider_min = min_threshold if min_threshold is not None else min_val
    slider_max = max_threshold if max_threshold is not None else max_val

    # An inverted range would make the slider widgets reject min > max;
    # interpret it as the same range given in the other order.
    if slider_min > slider_max:
        slider_min, slider_max = slider_max, slider_min

    # Set initial thresholds with defaults, clamped into the slider range
    if initial_min_threshold is None:
        initial_min_threshold = slider_min
    else:
        initial_min_threshold = max(slider_min, min(slider_max, float(initial_min_threshold)))
    
    if initial_max_threshold is None:
        initial_max_threshold = slider_max
    else:
        initial_max_threshold = max(slider_min, min(slider_max, float(initial_max_threshold)))

    # Ensure min <= max
    if initial_min_threshold > initial_max_threshold:
        initial_min_threshold, initial_max_threshold = initial_max_threshold, initial_min_threshold

    print(f"Voxel value range: {min_val} to {max_val}")
    print(f"Total non-zero voxels: {len(values)}")
    print(f"Initial thresholds: {initial_min_threshold} to {initial_max_threshold}")
    print(f"Slider range: {slider_min} to {slider_max}")
    print(f"Voxel resolution: {voxel_size_mm}mm")

    def update_plot(min_thresh, max_thresh):
        """Redraw the scatter plot for the current slider values."""
        # Ensure min <= max
        if min_thresh > max_thresh:
            min_thresh, max_thresh = max_thresh, min_thresh

        # Filter voxels within threshold range
        mask = (values >= min_thresh) & (values <= max_thresh)
        if not np.any(mask):
            print(f"No voxels in threshold range [{min_thresh}, {max_thresh}]")
            return

        filtered_x = x_coords_world[mask]
        filtered_y = y_coords_world[mask]
        filtered_z = z_coords_world[mask]
        filtered_values = values[mask]

        # Create 3D scatter plot
        fig = go.Figure(data=go.Scatter3d(
            x=filtered_x,
            y=filtered_y,
            z=filtered_z,
            mode='markers',
            marker=dict(
                size=1,
                color=filtered_values,
                colorscale='Viridis',
                opacity=0.8,
                colorbar=dict(title="Voxel Count"),
                line=dict(width=0)
            ),
            text=[f'Count: {v}' for v in filtered_values],
            hovertemplate='<b>Voxel (%{x:.1f}, %{y:.1f}, %{z:.1f})</b><br>%{text}<extra></extra>'
        ))

        fig.update_layout(
            title=f'3D Voxel Visualization (Range: [{min_thresh:.6f}, {max_thresh:.6f}], Showing: {len(filtered_values)} voxels)',
            scene=dict(
                xaxis_title=f'X{coord_suffix}',
                yaxis_title=f'Y{coord_suffix}',
                zaxis_title=f'Z{coord_suffix}',
                camera=dict(
                    eye=dict(x=1.5, y=1.5, z=1.5)
                ),
                aspectmode='cube'
            ),
            width=800,
            height=600
        )

        fig.show()

    # Create interactive sliders with linked constraints
    min_threshold_slider = FloatSlider(
        value=initial_min_threshold,
        min=slider_min,
        max=slider_max,
        step=0.01,
        description='Min Threshold:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    max_threshold_slider = FloatSlider(
        value=initial_max_threshold,
        min=slider_min,
        max=slider_max,
        step=0.01,
        description='Max Threshold:',
        continuous_update=False,
        style={'description_width': 'initial'}
    )

    # Link sliders to maintain min <= max constraint: moving one slider past
    # the other drags the other one along.
    def on_min_change(change):
        if change['new'] > max_threshold_slider.value:
            max_threshold_slider.value = change['new']

    def on_max_change(change):
        if change['new'] < min_threshold_slider.value:
            min_threshold_slider.value = change['new']

    min_threshold_slider.observe(on_min_change, names='value')
    max_threshold_slider.observe(on_max_change, names='value')

    interact(update_plot, 
             min_thresh=min_threshold_slider, 
             max_thresh=max_threshold_slider)

# visualize_voxel_tensor_3d(reconstructed_volume)

In [None]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import numpy as np
from ipywidgets import interact, IntSlider, FloatSlider

def visualize_voxel_tensor_slices(voxel_tensor, voxel_size_mm=1.0, world_origin=None,
                                            initial_min_threshold=None, initial_max_threshold=None,
                                            min_threshold=None, max_threshold=None):
    """
    Interactive XY / XZ / YZ slice viewer for a 3D voxel tensor.

    Three heatmaps are shown side by side; sliders choose the slice index on
    each axis and a [min, max] value window. Values outside the window are
    replaced by NaN in the heatmaps.

    Args:
        voxel_tensor: (nx, ny, nz) numpy array of voxel values
        voxel_size_mm: Size of each voxel in mm, used to scale the axis coordinates
        world_origin: (x_min, y_min, z_min) world coordinates of voxel (0,0,0) (optional)
        initial_min_threshold: Initial value of the min slider (default: lower end of slider range)
        initial_max_threshold: Initial value of the max slider (default: upper end of slider range)
        min_threshold: Lower end of the slider range (default: smallest positive
            voxel value, or 0.0 if there is none)
        max_threshold: Upper end of the slider range (default: largest positive
            voxel value, or 1.0 if there is none)
    """
    nx, ny, nz = voxel_tensor.shape

    # Axis coordinates of the voxel indices, optionally shifted to world space.
    if world_origin is None:
        x_coords = np.arange(nx) * voxel_size_mm
        y_coords = np.arange(ny) * voxel_size_mm
        z_coords = np.arange(nz) * voxel_size_mm
        coord_suffix = f" (×{voxel_size_mm}mm)"
    else:
        x_min, y_min, z_min = world_origin
        x_coords = np.arange(nx) * voxel_size_mm + x_min
        y_coords = np.arange(ny) * voxel_size_mm + y_min
        z_coords = np.arange(nz) * voxel_size_mm + z_min
        coord_suffix = " (mm)"

    # Default slider range comes from the strictly positive voxel values.
    positive_values = voxel_tensor[voxel_tensor > 0]
    if positive_values.size > 0:
        min_val = float(np.min(positive_values))
        max_val = float(np.max(positive_values))
    else:
        min_val, max_val = 0.0, 1.0

    slider_min = min_val if min_threshold is None else min_threshold
    slider_max = max_val if max_threshold is None else max_threshold

    def clamp_to_slider(value):
        return max(slider_min, min(slider_max, value))

    initial_min = clamp_to_slider(slider_min if initial_min_threshold is None else initial_min_threshold)
    initial_max = clamp_to_slider(slider_max if initial_max_threshold is None else initial_max_threshold)
    if initial_min > initial_max:
        initial_min, initial_max = initial_max, initial_min

    def update_slices(x_slice_idx, y_slice_idx, z_slice_idx, min_thresh, max_thresh):
        """Redraw the three heatmaps for the current slider values."""
        if min_thresh > max_thresh:
            min_thresh, max_thresh = max_thresh, min_thresh

        x_title = f'X{coord_suffix}'
        y_title = f'Y{coord_suffix}'
        z_title = f'Z{coord_suffix}'

        # One entry per subplot column:
        # (2-D slice, horizontal coords, vertical coords, horizontal title, vertical title)
        panels = [
            (voxel_tensor[:, :, z_slice_idx], x_coords, y_coords, x_title, y_title),  # XY plane
            (voxel_tensor[:, y_slice_idx, :], x_coords, z_coords, x_title, z_title),  # XZ plane
            (voxel_tensor[x_slice_idx, :, :], y_coords, z_coords, y_title, z_title),  # YZ plane
        ]

        fig = make_subplots(rows=1, cols=3,
                            subplot_titles=[f'XY slice at Z = {z_coords[z_slice_idx]:.2f}{coord_suffix}',
                                            f'XZ slice at Y = {y_coords[y_slice_idx]:.2f}{coord_suffix}',
                                            f'YZ slice at X = {x_coords[x_slice_idx]:.2f}{coord_suffix}'])

        for col, (plane, h_coords, v_coords, h_title, v_title) in enumerate(panels, start=1):
            in_window = (plane >= min_thresh) & (plane <= max_thresh)

            # Only the first panel carries the (shared) colour bar.
            if col == 1:
                scale_options = dict(showscale=True, colorbar=dict(title='Voxel value'))
            else:
                scale_options = dict(showscale=False)

            # Transposed so that the first array axis runs along the plot's x axis.
            fig.add_trace(go.Heatmap(
                z=np.where(in_window, plane, np.nan).T,
                x=h_coords,
                y=v_coords,
                colorscale='Viridis',
                zmin=min_thresh,
                zmax=max_thresh,
                **scale_options), row=1, col=col)
            fig.update_xaxes(title_text=h_title, row=1, col=col)
            fig.update_yaxes(title_text=v_title, autorange='reversed', row=1, col=col)

        fig.update_layout(width=1200, height=400, title_text='Voxel Tensor Slices (Horizontal Layout)')
        fig.show()

    # Slice-position sliders start in the middle of each axis.
    slider_x = IntSlider(min=0, max=nx-1, step=1, value=nx//2, description='X slice', continuous_update=False)
    slider_y = IntSlider(min=0, max=ny-1, step=1, value=ny//2, description='Y slice', continuous_update=False)
    slider_z = IntSlider(min=0, max=nz-1, step=1, value=nz//2, description='Z slice', continuous_update=False)
    slider_min_th = FloatSlider(min=slider_min, max=slider_max, step=0.01, value=initial_min,
                               description='Min Threshold', continuous_update=False)
    slider_max_th = FloatSlider(min=slider_min, max=slider_max, step=0.01, value=initial_max,
                               description='Max Threshold', continuous_update=False)

    # Keep min <= max: moving one threshold past the other drags it along.
    def on_min_thresh_change(change):
        new_min = change['new']
        if new_min > slider_max_th.value:
            slider_max_th.value = new_min

    def on_max_thresh_change(change):
        new_max = change['new']
        if new_max < slider_min_th.value:
            slider_min_th.value = new_max

    slider_min_th.observe(on_min_thresh_change, names='value')
    slider_max_th.observe(on_max_thresh_change, names='value')

    interact(update_slices,
             x_slice_idx=slider_x,
             y_slice_idx=slider_y,
             z_slice_idx=slider_z,
             min_thresh=slider_min_th,
             max_thresh=slider_max_th)

# Browse the reconstruction as three orthogonal slices; every optional
# argument (voxel size, world origin, thresholds) is left at its default.
visualize_voxel_tensor_slices(reconstructed_volume)

interactive(children=(IntSlider(value=300, continuous_update=False, description='X slice', max=599), IntSlider…