# Frangi GPU playground
Interactively sweep Frangi parameters on a NIfTI volume with optional GPU acceleration (cucim/CuPy).
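- The volume path is auto-detected below; override `image_path` to choose a specific file. Then adjust the controls — the output slice re-renders when you release a slider.
- Use the GPU toggle when cucim/CuPy is available; otherwise the filter falls back to scikit-image on the CPU.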


In [18]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt

import ipywidgets as widgets
from IPython.display import display
from pathlib import Path

%matplotlib inline

# Optional GPU path via cucim. Every failure mode (missing packages, CUDA
# runtime errors, zero visible devices) degrades to the CPU path instead of raising.
cp = None
cu_frangi = None
HAS_CUPY = False
gpu_setup_error = ""
try:
    import cupy as cp
    from cucim.skimage.filters import frangi as cu_frangi
    HAS_CUPY = cp.cuda.runtime.getDeviceCount() > 0
except Exception as exc:  # broad on purpose: CUDA runtime failures are not ImportError
    cp = None
    cu_frangi = None
    HAS_CUPY = False
    gpu_setup_error = repr(exc)
else:
    if not HAS_CUPY:
        # Imports worked but there is no device to run on; record why so the
        # status line below explains the missing GPU instead of staying silent.
        gpu_setup_error = "cupy/cucim imported but no CUDA devices were found"

# CPU fallback
try:
    from skimage.filters import frangi as sk_frangi
except Exception as exc:
    sk_frangi = None
    cpu_setup_error = repr(exc)
else:
    cpu_setup_error = ""

print(f"cucim/CuPy GPU available: {HAS_CUPY}")
if not HAS_CUPY and gpu_setup_error:
    print(f"GPU setup issue (ignored): {gpu_setup_error}")
if sk_frangi is None:
    print(f"scikit-image unavailable; GPU path required. ({cpu_setup_error})")


cucim/CuPy GPU available: True


In [19]:
# --- Set your volume path here ---
# Auto-detects the repo root so it works when launched from visualization/.
from typing import Optional

def find_data_images_dir(start: Optional[Path] = None) -> Optional[Path]:
    """Return the nearest ``data/images`` directory at or above ``start``.

    Walks from ``start`` up through each parent directory, so the notebook
    works whether it is launched from the repo root or a subfolder such as
    ``visualization/``.

    Args:
        start: Directory to begin the search from; defaults to ``Path.cwd()``.

    Returns:
        The first ``<ancestor>/data/images`` that is a directory, or ``None``
        if no ancestor contains one.
    """
    origin = Path.cwd() if start is None else Path(start).resolve()
    for base in (origin, *origin.parents):
        candidate = base / 'data' / 'images'
        # is_dir, not exists: a stray *file* named "images" must not match,
        # because the caller globs inside the returned path.
        if candidate.is_dir():
            return candidate
    return None

data_images_dir = find_data_images_dir()
if data_images_dir is None:
    raise FileNotFoundError('Could not find data/images folder by walking up from current working directory.')

default_candidates = sorted(data_images_dir.glob('image_*.nii*'))
if not default_candidates:
    raise FileNotFoundError(f'No NIfTI files matching image_* found under {data_images_dir}')
default_image = default_candidates[0]

# Override this if you want a specific file
image_path = default_image

if not image_path.exists():
    raise FileNotFoundError(f'Missing volume at {image_path}')

img = nib.load(str(image_path))
# Request float32 directly. get_fdata() defaults to float64 (and nibabel keeps
# that array cached on `img`), which for a 512x512x258 volume is ~540 MB of
# memory that was only ever downcast.
volume = img.get_fdata(dtype=np.float32)
# NIfTI files often store 3-D scans with trailing singleton dims (x, y, z, 1);
# drop them so the 3-axis slice controls and axis-2 defaults below apply.
while volume.ndim > 3 and volume.shape[-1] == 1:
    volume = volume[..., 0]
if volume.ndim != 3:
    raise ValueError(f'Expected a 3-D volume, got shape {volume.shape} from {image_path}')
spacing = img.header.get_zooms()[: volume.ndim]

# Upload once so repeated GPU filter runs don't re-copy the volume each time.
gpu_volume = cp.asarray(volume) if HAS_CUPY else None

print(f'Loaded {image_path} with shape {volume.shape}')
print(f'Spacing (used for mm->voxel conversion): {spacing}')
# Initial view: axis 2 at its middle slice.
slice_defaults = {
    'axis': 2,
    'idx': volume.shape[2] // 2,
}


Loaded /home/haozhe/noisy_segment/data/images/image_001.nii.gz with shape (512, 512, 258)
Spacing (used for mm->voxel conversion): (np.float32(0.6464844), np.float32(0.6464844), np.float32(1.0))


In [20]:
def parse_sigmas(text: str):
    """Parse a comma- or semicolon-separated list of Frangi scales.

    Args:
        text: e.g. ``"0.5, 1; 2"``. Empty, ``None`` or whitespace-only text
            selects the default single scale ``[1.0]``.

    Returns:
        List of floats in the order given.

    Raises:
        ValueError: if an entry is not a number, is not a finite positive
            value, or if the text contains only separators.
    """
    if not text or not text.strip():
        return [1.0]
    vals = []
    for chunk in text.replace(";", ",").split(","):
        chunk = chunk.strip()
        if not chunk:
            continue
        try:
            value = float(chunk)
        except ValueError:
            raise ValueError(f"Sigma {chunk!r} is not a number.") from None
        # One chained comparison rejects 0, negatives, inf and NaN (NaN fails both sides).
        if not 0.0 < value < float("inf"):
            raise ValueError(f"Sigma {chunk!r} must be a finite positive number.")
        vals.append(value)
    if not vals:
        raise ValueError("Provide at least one numeric sigma.")
    return vals


def sigmas_mm_to_vox(sigmas_mm, spacing):
    """Convert physical-unit sigmas (mm) into voxel units using ``spacing``.

    Isotropic spacing (all entries ``np.allclose``) yields one float per sigma.
    Otherwise each sigma becomes a tuple with one voxel-unit value per axis.

    NOTE(review): the tuple form assumes the downstream ``frangi`` treats each
    list entry as a per-axis sigma; confirm the installed scikit-image/cucim
    version does not flatten these into separate isotropic scales.
    """
    voxel_size = np.asarray(spacing, dtype=float)
    if np.allclose(voxel_size, voxel_size[0]):
        reference = voxel_size[0]
        return [float(sigma / reference) for sigma in sigmas_mm]
    return [tuple(float(sigma / axis_mm) for axis_mm in voxel_size) for sigma in sigmas_mm]


def pick_slice(volume: np.ndarray, axis: int, idx: int):
    """Take the slice of ``volume`` at position ``idx`` along ``axis``.

    Out-of-range indices are pinned to the first/last slice.

    Returns:
        ``(clamped_idx, plane)``: the index actually used (a Python int) and
        the extracted array (via ``np.take``) with ``axis`` removed.
    """
    last = volume.shape[axis] - 1
    position = int(min(max(idx, 0), last))
    plane = np.take(volume, indices=position, axis=axis)
    return position, plane


def get_vrange(vol_a: np.ndarray, vol_b: Optional[np.ndarray] = None, low=1, high=99):
    """Return display limits ``(vmin, vmax)`` from intensity percentiles.

    Args:
        vol_a: Array whose values set the range.
        vol_b: Optional second array pooled with ``vol_a`` (e.g. to give two
            images one shared color scale). ``None`` uses ``vol_a`` alone.
        low, high: Percentiles (0-100) for the lower and upper limit.

    Returns:
        ``(lo, hi)`` as Python floats with ``hi > lo``. Non-finite values
        (NaN/inf) are ignored; if nothing finite remains, ``(0.0, 1.0)``.
    """
    arrays = [vol_a] if vol_b is None else [vol_a, vol_b]
    combined = np.concatenate([np.ravel(a) for a in arrays])
    # NaN/inf would propagate through np.percentile and give imshow an unusable range.
    finite = combined[np.isfinite(combined)]
    if finite.size == 0:
        return 0.0, 1.0
    lo, hi = np.percentile(finite, [low, high])
    if hi <= lo:
        # Constant image or low == high: widen so imshow still has a span.
        hi = lo + 1e-3
    return float(lo), float(hi)


# Single-entry memo for frangi_volume. The full-volume filter is the expensive
# step, yet most control changes (slice axis/index, threshold, clip %) do not
# alter its result. Holds at most one filtered volume in host memory.
_frangi_cache = {"key": None, "source": None, "result": None}


def frangi_volume(alpha: float, beta: float, gamma: float, sigmas_text: str, use_gpu: bool, white_ridges: bool, sigma_unit: str = "mm") -> np.ndarray:
    """Apply the Frangi vesselness filter to the whole loaded ``volume``.

    Args:
        alpha, beta, gamma: Frangi parameters, passed through unchanged.
        sigmas_text: Scale list in the format accepted by ``parse_sigmas``.
        use_gpu: Prefer cucim on the GPU. If CuPy/cucim is unavailable, a
            notice is printed and scikit-image on the CPU is used instead.
        white_ridges: Passed as ``black_ridges=not white_ridges``.
        sigma_unit: ``"mm"`` converts sigmas with the volume ``spacing``;
            any other value uses them directly as voxel units.

    Returns:
        NumPy array with the filter response. Results are memoised for repeated
        calls with identical parameters on the same ``volume`` object, so
        callers must treat the returned array as read-only.

    Raises:
        ValueError: propagated from ``parse_sigmas`` for malformed sigma text.
        RuntimeError: if the GPU path is not taken and scikit-image is missing.
    """
    sigmas = parse_sigmas(sigmas_text)
    sigmas_for_filter = sigmas_mm_to_vox(sigmas, spacing) if sigma_unit == "mm" else sigmas
    if use_gpu and not HAS_CUPY:
        print("GPU requested but CuPy/cucim not available; using CPU instead.")
    on_gpu = bool(use_gpu and HAS_CUPY and cu_frangi is not None)
    if not on_gpu and sk_frangi is None:
        raise RuntimeError("GPU Frangi unavailable and scikit-image not installed.")

    # spacing is already folded into sigmas_for_filter. Checking the identity of
    # `volume` (the cache keeps a reference, so it cannot be recycled) prevents a
    # reloaded volume from reusing a stale result.
    key = (float(alpha), float(beta), float(gamma), tuple(sigmas_for_filter), bool(white_ridges), on_gpu)
    if _frangi_cache["source"] is volume and _frangi_cache["key"] == key:
        return _frangi_cache["result"]

    filter_kwargs = dict(
        sigmas=sigmas_for_filter,
        alpha=alpha,
        beta=beta,
        gamma=gamma,
        black_ridges=not white_ridges,
        mode="nearest",
    )
    if on_gpu:
        x = gpu_volume if gpu_volume is not None else cp.asarray(volume, dtype=cp.float32)
        result = cp.asnumpy(cu_frangi(x, **filter_kwargs))
    else:
        result = sk_frangi(volume, **filter_kwargs)

    _frangi_cache.update(key=key, source=volume, result=result)
    return result


In [None]:
# --- Controls ---
# Sliders re-render only when released: an update can run the full 3-D Frangi
# filter, which is far too slow to follow a drag.
_on_release = {"continuous_update": False}

sigma_text = widgets.Text(
    value="0.5,1,1.5,2,2.5,3,3.5",
    description="Sigmas",
    layout=widgets.Layout(width="240px"),
)
# Values "mm" / "vox" are what update() receives as sigma_unit.
sigma_unit = widgets.ToggleButtons(
    options=[("mm", "mm"), ("voxels", "vox")],
    value="mm",
    description="Unit",
)

# Frangi parameters. Reference for the parameter choices (presumably these
# alpha/beta/gamma defaults — confirm): https://oa.upm.es/28902/1/INVE_MEM_2013_166109.pdf
alpha_slider = widgets.FloatSlider(
    value=0.5, min=0.05, max=2.0, step=0.05, readout_format=".2f", description="alpha", **_on_release
)
beta_slider = widgets.FloatSlider(
    value=0.5, min=0.05, max=2.0, step=0.05, readout_format=".2f", description="beta", **_on_release
)
gamma_slider = widgets.FloatSlider(
    value=15.0, min=1.0, max=300.0, step=0.5, readout_format=".1f", description="gamma", **_on_release
)
thr_slider = widgets.FloatSlider(
    value=0.20, min=0.0, max=1.5, step=0.01, readout_format=".2f", description="Thresh", **_on_release
)

white_checkbox = widgets.Checkbox(value=False, description="White ridges")
# Greyed out (and unchecked) when no usable cucim/CuPy GPU was detected.
gpu_checkbox = widgets.Checkbox(value=HAS_CUPY, description="Use GPU (cucim)", disabled=not HAS_CUPY)
clip_slider = widgets.IntRangeSlider(
    value=[1, 99], min=0, max=100, step=1, description="Clip %", **_on_release
)

initial_axis = slice_defaults["axis"]
slice_axis = widgets.Dropdown(
    options=[("Sagittal (0)", 0), ("Coronal (1)", 1), ("Axial (2)", 2)],
    value=initial_axis,
    description="Slice axis",
)
slice_idx = widgets.IntSlider(
    value=slice_defaults["idx"],
    min=0,
    max=volume.shape[initial_axis] - 1,
    step=1,
    description="Slice idx",
    **_on_release,
)

def _update_slice_idx(change):
    """Re-range the slice slider for the newly selected axis and recenter it.

    Registered below as the ``slice_axis`` value observer; ``change["new"]``
    is the selected axis index.
    """
    axis = change["new"]
    # Batch both edits. Lowering `max` below the current value makes the slider
    # clamp `value` on the spot, which would otherwise fire an extra redraw (and
    # filter run) of an intermediate slice before the recentre lands.
    with slice_idx.hold_trait_notifications():
        slice_idx.max = volume.shape[axis] - 1
        slice_idx.value = volume.shape[axis] // 2

slice_axis.observe(_update_slice_idx, names="value")


In [22]:
def update(alpha, beta, gamma, sigmas_text, white_ridges, use_gpu, slice_axis, slice_idx, threshold, sigma_unit, clip_percentiles):
    """Run the Frangi filter for the current controls and draw one slice.

    Called by ``widgets.interactive_output`` with every control's current
    value; parameter names must match the keys of the ``controls`` dict.

    With ``threshold > 0`` the slice is shown as a red overlay of
    ``frangi >= threshold`` on the grayscale input. Otherwise the raw Frangi
    response is shown with a colorbar. Filter errors (e.g. malformed sigma
    text) are printed instead of raised so the widget keeps working.
    """
    try:
        frangi_vol = frangi_volume(alpha, beta, gamma, sigmas_text, use_gpu, white_ridges, sigma_unit=sigma_unit)
    except Exception as exc:
        print(f"Frangi failed: {exc}")
        return

    clip_low, clip_high = sorted([clip_percentiles[0], clip_percentiles[1]])
    slice_idx, orig_slice = pick_slice(volume, slice_axis, slice_idx)
    _, frangi_slice = pick_slice(frangi_vol, slice_axis, slice_idx)

    # Rotate for display (counter-clockwise 90 degrees)
    orig_slice = np.rot90(orig_slice)
    frangi_slice = np.rot90(frangi_slice)

    fig, ax = plt.subplots(figsize=(6, 6))
    if threshold > 0:
        orig_vmin, orig_vmax = get_vrange(orig_slice, orig_slice, low=clip_low, high=clip_high)
        mask = frangi_slice >= threshold
        ax.imshow(orig_slice, cmap="gray", vmin=orig_vmin, vmax=orig_vmax)
        # Mask out the False pixels so they are fully transparent. Drawing the
        # whole boolean image would lay the near-white low end of "Reds" over
        # every background pixel and wash out the grayscale slice.
        ax.imshow(np.ma.masked_where(~mask, mask), cmap="Reds", alpha=0.35, vmin=0, vmax=1)
        ax.set_title(f"Mask >= {threshold:.2f} (axis {slice_axis}, idx {slice_idx})")
    else:
        frangi_vmin, frangi_vmax = get_vrange(frangi_slice, frangi_slice, low=clip_low, high=clip_high)
        im = ax.imshow(frangi_slice, cmap="magma", vmin=frangi_vmin, vmax=frangi_vmax)
        fig.colorbar(im, ax=ax, shrink=0.75, label="Intensity")
        ax.set_title(f"Frangi output (axis {slice_axis}, idx {slice_idx})")
    ax.axis("off")

    plt.tight_layout()
    plt.show()
    plt.close(fig)

# Map each update() keyword to the widget feeding it; interactive_output passes
# every widget's current value under these names.
controls = dict(
    alpha=alpha_slider,
    beta=beta_slider,
    gamma=gamma_slider,
    sigmas_text=sigma_text,
    white_ridges=white_checkbox,
    use_gpu=gpu_checkbox,
    slice_axis=slice_axis,
    slice_idx=slice_idx,
    threshold=thr_slider,
    sigma_unit=sigma_unit,
    clip_percentiles=clip_slider,
)

# Rows: filter inputs / Frangi + threshold sliders / slice selection + display clip.
control_rows = [
    [sigma_text, sigma_unit, gpu_checkbox, white_checkbox],
    [alpha_slider, beta_slider, gamma_slider, thr_slider],
    [slice_axis, slice_idx, clip_slider],
]
ui = widgets.VBox([widgets.HBox(row) for row in control_rows])

plot_output = widgets.interactive_output(update, controls)
display(ui, plot_output)


VBox(children=(HBox(children=(Text(value='0.5,1,1.5,2,2.5,3,3.5', description='Sigmas', layout=Layout(width='2…

Output()