
# 3D Motion Correction — Notebook

This notebook loads a pre-aligned 3D stack, synthesizes a displaced volume with known ground-truth motion, estimates the 3D optical flow between the two, applies motion compensation, and visualizes the results:

- Raw vs. Displaced: side-by-side 2D slices and 3D renderings
- Raw vs. Compensated: side-by-side 2D slices and 3D renderings
- Flow metrics, including the average end-point error (EPE) against the ground-truth flow

Backends for in-notebook 3D rendering:
- Primary: **napari + jupyter_rfb**
- Fallbacks: **PyVista/VTK** or **itkwidgets**

Run in a local Jupyter environment. With `MODE = "torch"`, flow estimation runs on a CUDA GPU when one is available.



## Setup

If needed, install packages in a dedicated environment and restart the kernel after installation.


In [1]:

# %pip install --upgrade pip
# %pip install napari jupyter_rfb
# %pip install pyvista ipyvtklink
# %pip install itkwidgets
# %pip install pyflowreg flowreg3d scipy numpy matplotlib ipywidgets


## Imports and backend selection

In [2]:

import os, time
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interact, IntSlider, Dropdown, VBox, HBox, Layout
from IPython.display import display

# Probe the optional dependencies. Each availability flag is True only if
# every import it guards succeeded.
try:
    import torch
except Exception:
    TORCH_OK = False
else:
    TORCH_OK = True

# Preferred 3D backend and flow-estimation backend.
BACKEND = "napari"
MODE = "torch"

NAPARI_OK = PYVISTA_OK = ITK_OK = False

# napari is only tried when explicitly requested; the later backends are
# probed only if every earlier one is unavailable.
if BACKEND == "napari":
    try:
        import napari  # noqa
        from jupyter_rfb import RemoteFrameBuffer  # noqa
    except Exception:
        NAPARI_OK = False
    else:
        NAPARI_OK = True

if not NAPARI_OK:
    try:
        import pyvista as pv  # noqa
        from ipyvtklink.viewer import ViewInteractiveWidget  # noqa
    except Exception:
        PYVISTA_OK = False
    else:
        PYVISTA_OK = True

if not NAPARI_OK and not PYVISTA_OK:
    try:
        from itkwidgets import view as itkview  # noqa
    except Exception:
        ITK_OK = False
    else:
        ITK_OK = True

assert not (MODE == "torch" and not TORCH_OK), "Switch MODE='numpy' or install torch."
assert any((NAPARI_OK, PYVISTA_OK, ITK_OK)), "Install at least one 3D backend."


ModuleNotFoundError: No module named 'matplotlib'

## IO and helpers

In [None]:

from scipy.ndimage import zoom, gaussian_filter
from pyflowreg.util.io.factory import get_video_file_reader
from flowreg3d.core.optical_flow_3d import get_displacement as get_displacement_numpy
from flowreg3d.core.torch.optical_flow_3d import get_displacement as get_displacement_torch
from flowreg3d.core.optical_flow_3d import imregister_wrapper
from flowreg3d.motion_generation.motion_generators import (
    get_default_3d_generator,
    get_low_disp_3d_generator,
    get_test_3d_generator,
    get_high_disp_3d_generator,
)
from flowreg3d.util.random import fix_seed

def process_3d_stack(video, resize_factor=1, crop=25):
    """Resample, crop and min-max normalise a 3D stack.

    Parameters
    ----------
    video : np.ndarray
        Stack of shape (depth, height, width) or (depth, height, width, C).
    resize_factor : float, optional
        Linear-interpolation zoom applied to axes 1 and 2 only; axis 0 and a
        trailing channel axis keep their size. A factor of 1 skips resampling.
    crop : int, optional
        Voxels removed from each side of axes 1 and 2 after resampling.
        ``0`` keeps the full extent.

    Returns
    -------
    np.ndarray
        float32 copy scaled to [0, 1] by its global min/max (left unscaled
        when the stack is constant).

    Raises
    ------
    ValueError
        If ``crop`` is negative.
    """
    if crop < 0:
        raise ValueError(f"crop must be non-negative, got {crop}")
    v = video.astype(np.float32)
    if resize_factor != 1:
        # Zoom only the two in-plane axes; depth and channels keep their size.
        factors = (1.0, resize_factor, resize_factor) + (1.0,) * (v.ndim - 3)
        v = zoom(v, factors, order=1)
    if crop > 0:
        # Explicit stop indices: `v[:, c:-c]` would be empty for c == 0.
        # A trailing channel axis is kept untouched by the 3-element index.
        v = v[:, crop:v.shape[1] - crop, crop:v.shape[2] - crop]
    mn, mx = v.min(), v.max()
    if mx > mn:
        v = (v - mn) / (mx - mn)
    return v

def warp_volume_splat3d(volume, flow):
    """Forward-warp ``volume`` by ``flow`` using trilinear splatting.

    Every source voxel (z, y, x) is pushed to the continuous target position
    (z + flow[..., 2], y + flow[..., 1], x + flow[..., 0]). Its value is
    distributed over the 8 surrounding grid voxels with trilinear weights.
    Each output voxel is the weight-normalised average of what it received.
    Voxels that receive nothing (holes) are 0. Contributions landing outside
    the grid are discarded.

    Parameters
    ----------
    volume : np.ndarray
        Array of shape (Z, H, W) or (Z, H, W, C); channels are splatted
        independently with the same weights.
    flow : np.ndarray
        Displacement field of shape (Z, H, W, 3), components ordered (x, y, z).

    Returns
    -------
    np.ndarray
        Warped array with the same shape and dtype as ``volume``.
    """
    Z, H, W = volume.shape[:3]
    n_vox = Z * H * W
    z, y, x = np.meshgrid(np.arange(Z), np.arange(H), np.arange(W), indexing="ij")
    tx = (x + flow[..., 0]).ravel()
    ty = (y + flow[..., 1]).ravel()
    tz = (z + flow[..., 2]).ravel()
    iz = np.floor(tz).astype(np.int64)
    iy = np.floor(ty).astype(np.int64)
    ix = np.floor(tx).astype(np.int64)
    fz, fy, fx = tz - iz, ty - iy, tx - ix

    # For each of the 8 trilinear corners, precompute once for all channels:
    # the flat target index, the weight and the mask of contributing sources.
    # Corners outside the grid are dropped rather than clamped onto the border,
    # which would otherwise pile out-of-view content into the edge voxels.
    corners = []
    for dz, wz in ((0, 1.0 - fz), (1, fz)):
        zz = iz + dz
        in_z = (zz >= 0) & (zz < Z)
        for dy, wy in ((0, 1.0 - fy), (1, fy)):
            yy = iy + dy
            in_zy = in_z & (yy >= 0) & (yy < H)
            for dx, wx in ((0, 1.0 - fx), (1, fx)):
                xx = ix + dx
                w = wx * wy * wz
                keep = in_zy & (xx >= 0) & (xx < W) & (w > 0)
                flat = (zz[keep] * H + yy[keep]) * W + xx[keep]
                corners.append((flat, w[keep], keep))

    # Total weight received per output voxel (shared by all channels).
    den = np.zeros(n_vox, dtype=np.float64)
    for flat, w, _ in corners:
        den += np.bincount(flat, weights=w, minlength=n_vox)
    den[den == 0] = 1.0  # holes: avoid 0/0; their numerator is 0 too

    def accum(values):
        # Splat one scalar channel and normalise by the received weight.
        src = values.ravel()
        out = np.zeros(n_vox, dtype=np.float64)
        for flat, w, keep in corners:
            out += np.bincount(flat, weights=src[keep] * w, minlength=n_vox)
        return (out / den).reshape(Z, H, W).astype(values.dtype)

    if volume.ndim == 4:
        return np.stack([accum(volume[..., c]) for c in range(volume.shape[3])], axis=-1)
    return accum(volume)

def create_displaced(video, generator="high_disp", crop=10):
    """Synthesize a displaced copy of ``video`` from a generated 3D motion field.

    Parameters
    ----------
    video : np.ndarray
        Volume of shape (depth, height, width) or (depth, height, width, C).
    generator : str, optional
        One of ``"default"``, ``"low_disp"``, ``"test"`` or ``"high_disp"``.
        Any other value falls back to the high-displacement generator.
    crop : int, optional
        Voxels removed from each side of the three spatial axes of both
        outputs. ``0`` (or less) keeps the full extent.

    Returns
    -------
    tuple of np.ndarray
        ``(displaced, flow_gt)``: the forward-splatted volume and the flow
        field used to produce it. The generator's ``invalid`` output is discarded.
    """
    factories = {
        "default": get_default_3d_generator,
        "low_disp": get_low_disp_3d_generator,
        "test": get_test_3d_generator,
        "high_disp": get_high_disp_3d_generator,
    }
    gen = factories.get(generator, get_high_disp_3d_generator)()
    flow_gt, _invalid = gen(depth=video.shape[0], height=video.shape[1], width=video.shape[2])
    displaced = warp_volume_splat3d(video, flow_gt)
    if crop > 0:
        # Explicit stop indices: `a[c:-c]` would be empty for c == 0. The
        # 3-element index keeps any trailing channel / vector axis intact.
        spatial = tuple(slice(crop, n - crop) for n in displaced.shape[:3])
        displaced = displaced[spatial]
        flow_gt = flow_gt[spatial]
    return displaced, flow_gt

def compute_flow(frame1, frame2, params, mode="torch"):
    """Estimate the 3D displacement field between two volumes.

    Parameters
    ----------
    frame1, frame2 : np.ndarray
        Reference and moving volume. A 3-D array is treated as single-channel
        and gets a trailing channel axis; other ranks are passed through.
    params : dict
        Keyword arguments forwarded unchanged to the backend ``get_displacement``.
    mode : str, optional
        ``"torch"`` runs the torch backend (on CUDA when available); any other
        value runs the NumPy backend.

    Returns
    -------
    np.ndarray
        The backend's flow field. The torch result is converted to a float64
        NumPy array.
    """
    if frame1.ndim == 3:
        f1, f2 = frame1[..., None], frame2[..., None]
    else:
        f1, f2 = frame1, frame2

    if mode == "torch":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        t1 = torch.from_numpy(f1).to(device=device, dtype=torch.float64)
        t2 = torch.from_numpy(f2).to(device=device, dtype=torch.float64)
        with torch.no_grad():
            flow = get_displacement_torch(t1, t2, **params)
        return flow.detach().cpu().numpy().astype(np.float64, copy=False)

    # NumPy backend. Promote non-floating inputs to float64 first: smoothing
    # into an integer buffer (np.empty_like of an int array) would truncate
    # the filtered values and corrupt the normalisation below.
    if not np.issubdtype(f1.dtype, np.floating):
        f1 = f1.astype(np.float64)
    if not np.issubdtype(f2.dtype, np.floating):
        f2 = f2.astype(np.float64)
    f1s = np.empty_like(f1)
    f2s = np.empty_like(f2)
    for c in range(f1.shape[-1]):
        f1s[..., c] = gaussian_filter(f1[..., c], sigma=0.5)
        f2s[..., c] = gaussian_filter(f2[..., c], sigma=0.5)
    # Normalise both volumes with frame1's per-channel min/range so they share
    # one intensity scale; constant channels get a unit range to avoid /0.
    mn = f1s.min(axis=(0, 1, 2), keepdims=True)
    rg = f1s.max(axis=(0, 1, 2), keepdims=True) - mn
    rg = np.where(rg > 0, rg, 1.0)
    return get_displacement_numpy((f1s - mn) / rg, (f2s - mn) / rg, **params)

def epe(flow_est, flow_gt, boundary=25):
    """Return the mean end-point error (EPE) between two flow fields.

    The EPE is the Euclidean norm of the per-voxel difference of the flow
    vectors (last axis). It is averaged over the interior that remains after
    removing ``boundary`` voxels from both ends of each of the first three axes.

    Parameters
    ----------
    flow_est, flow_gt : np.ndarray
        Flow fields with the vector components on the last axis, e.g. (Z, H, W, 3).
    boundary : int, optional
        Border width to ignore; ``0`` evaluates the whole volume.

    Raises
    ------
    ValueError
        If ``boundary`` is negative or leaves no voxels to evaluate.
    """
    if boundary < 0:
        raise ValueError(f"boundary must be non-negative, got {boundary}")

    def _interior(flow):
        # Explicit stop indices: `a[b:-b]` would be empty for b == 0.
        return flow[tuple(slice(boundary, n - boundary) for n in flow.shape[:3])]

    fe = _interior(flow_est)
    fg = _interior(flow_gt)
    if fe.size == 0 or fg.size == 0:
        raise ValueError(
            f"boundary={boundary} leaves no voxels for flow of shape {flow_est.shape[:3]}"
        )
    return np.mean(np.linalg.norm(fe - fg, axis=-1))


## Load data

In [None]:

fix_seed(seed=1, deterministic=True, verbose=False)
# Look for data/ next to the notebook's parent first, then in the CWD.
repo_root = Path.cwd().parent if (Path.cwd().parent / "data" / "aligned_sequence").exists() else Path.cwd()
aligned_file = repo_root / "data" / "aligned_sequence" / "compensated.HDF5"
if not aligned_file.exists():
    raise FileNotFoundError(f"Missing file: {aligned_file}")
reader = get_video_file_reader(str(aligned_file), buffer_size=100, bin_size=9)
try:
    buf = []
    while reader.has_batch():
        buf.append(reader.read_batch())
    video_3d = np.concatenate(buf, axis=0) if buf else reader[:]
finally:
    # Release the file handle even if reading fails part-way.
    reader.close()
processed = process_3d_stack(video_3d, resize_factor=1, crop=25)
boundary = 10
# Trim `boundary` voxels from each side of the three spatial axes; any
# trailing channel axis is kept by the 3-element index tuple.
original = processed[tuple(slice(boundary, n - boundary) for n in processed.shape[:3])]
displaced, flow_gt = create_displaced(original, generator="high_disp", crop=0)


## Flow estimation

In [None]:

flow_params = dict(
    alpha=(0.25, 0.25, 0.25),
    iterations=100,
    a_data=0.45,
    a_smooth=1.0,
    weight=np.array([0.5, 0.5], dtype=np.float64),
    levels=50,
    eta=0.8,
    update_lag=5,
    min_level=5,
    const_assumption="gc",
    uvw=None,
)

# Estimate the flow from `original` to `displaced` and time it.
t0 = time.perf_counter()
flow_est = compute_flow(original, displaced, flow_params, mode=MODE)
t_flow = time.perf_counter() - t0

# Compensate `displaced` using the three estimated flow components.
compensated = imregister_wrapper(
    displaced,
    flow_est[..., 0],
    flow_est[..., 1],
    flow_est[..., 2],
    original,
    interpolation_method="cubic",
)

# Metrics: EPE vs. ground truth (25-voxel border ignored) and mean absolute
# intensity difference to the reference before / after compensation.
score_epe = epe(flow_est, flow_gt, boundary=25)
mad_od = np.abs(original - displaced).mean()
mad_oc = np.abs(original - compensated).mean()
imp = mad_od / max(mad_oc, 1e-8)
print(f"flow_time_s={t_flow:.2f}, EPE={score_epe:.3f}, MAD(ori,disp)={mad_od:.4f}, MAD(ori,comp)={mad_oc:.4f}, improvement={imp:.2f}x")


## 2D slice viewers

In [None]:

Z = original.shape[0]


def _show_pair(left, right, right_title, z):
    """Plot slice ``z`` of ``left`` (titled "Raw") beside slice ``z`` of ``right``."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
    for ax, vol, title in zip(axs, (left, right), ("Raw", right_title)):
        ax.imshow(vol[z], cmap="gray", vmin=0, vmax=1)
        ax.set_title(title)
        ax.axis("off")
    plt.show()


@interact(z=IntSlider(0, 0, Z - 1, 1, description="z"))
def _view(z=0):
    """Raw vs. displaced at depth ``z``."""
    _show_pair(original, displaced, "Displaced", z)


@interact(z=IntSlider(0, 0, Z - 1, 1, description="z"))
def _view2(z=0):
    """Raw vs. motion-compensated at depth ``z``."""
    _show_pair(original, compensated, "Compensated", z)


## Flow magnitude slices

In [None]:

mag_est = np.linalg.norm(flow_est, axis=-1)
mag_gt = np.linalg.norm(flow_gt, axis=-1)
# Use one colour scale for both panels (the larger 99th percentile) so the
# estimated and ground-truth magnitudes are directly comparable. Computing it
# once here avoids a full-volume percentile on every slider move.
flow_vmax = max(np.percentile(mag_est, 99), np.percentile(mag_gt, 99))


@interact(z=IntSlider(0, 0, Z - 1, 1, description="z"))
def _view_flow(z=0):
    """Show estimated vs. ground-truth flow magnitude at depth ``z``."""
    fig, axs = plt.subplots(1, 2, figsize=(10, 5), constrained_layout=True)
    for ax, mag, title in zip(axs, (mag_est, mag_gt), ("Flow |est|", "Flow |gt|")):
        im = ax.imshow(mag[z], vmin=0, vmax=flow_vmax)
        ax.set_title(title)
        ax.axis("off")
    fig.colorbar(im, ax=axs, shrink=0.8, label="|flow|")
    plt.show()


## 3D visualization

In [None]:

def show_3d_napari(volumes, names, cmaps, opacities):
    """Open a 3D napari viewer with one additive image layer per volume.

    Each volume is added with its matching name, colormap and opacity, using
    fixed contrast limits [0, 1]. The viewer's canvas is returned wrapped in a
    jupyter_rfb ``RemoteFrameBuffer`` for display in the notebook.
    """
    import napari
    from jupyter_rfb import RemoteFrameBuffer

    viewer = napari.Viewer(title="3D Volumes", ndisplay=3)
    layer_specs = zip(volumes, names, cmaps, opacities)
    for data, layer_name, colormap, alpha in layer_specs:
        viewer.add_image(
            data,
            name=layer_name,
            colormap=colormap,
            blending="additive",
            opacity=alpha,
            contrast_limits=[0, 1],
        )
    # NOTE(review): confirm that `viewer.canvas` and the `events=` keyword are
    # supported by the installed napari / jupyter_rfb versions.
    return RemoteFrameBuffer(viewer.canvas, events=viewer)

def show_3d_pyvista(volumes, names):
    """Volume-render each array in its own linked PyVista subplot.

    Each volume (indexed [z, y, x]) is placed on a unit-spaced image grid,
    rendered with a sigmoid opacity transfer function and labelled with the
    matching entry of ``names``. The result is wrapped in an ipyvtklink widget.
    """
    import pyvista as pv
    from ipyvtklink.viewer import ViewInteractiveWidget

    # PyVista 0.43 renamed UniformGrid to ImageData; support both.
    grid_cls = getattr(pv, "ImageData", None) or pv.UniformGrid
    pl = pv.Plotter(shape=(1, len(volumes)))
    for i, vol in enumerate(volumes):
        pl.subplot(0, i)
        grid = grid_cls()
        nz, ny, nx = vol.shape[:3]
        grid.dimensions = nx, ny, nz
        grid.origin = (0, 0, 0)
        grid.spacing = (1, 1, 1)
        # VTK point ids run with x fastest, then y, then z. With `dimensions`
        # given as (x, y, z) and the array indexed [z, y, x], a C-order ravel
        # (last axis fastest) matches that layout; F order would scramble axes.
        grid.point_data["values"] = vol.ravel(order="C")
        pl.add_volume(grid, opacity="sigmoid", mapper="gpu")
        pl.add_text(names[i], font_size=10)
    pl.link_views()
    # NOTE(review): ViewInteractiveWidget normally wraps a VTK render window;
    # confirm `pl.show(jupyter_backend="ipyvtklink")` returns something it
    # accepts in the installed PyVista version (newer releases may have
    # dropped the ipyvtklink backend).
    return ViewInteractiveWidget(pl.show(jupyter_backend="ipyvtklink"))

def show_3d_itk(volumes, names):
    """Stack one itkwidgets viewer per volume vertically, each under a caption.

    Parameters
    ----------
    volumes : sequence of np.ndarray
        Volumes to display with a gray colormap.
    names : sequence of str
        Caption shown above the matching viewer.
    """
    from ipywidgets import Label
    from itkwidgets import view as itkview

    widgets = []
    for vol, name in zip(volumes, names):
        widgets.append(Label(value=name))
        # NOTE(review): assumes `itkview` returns an ipywidget that VBox can
        # hold; confirm for the installed itkwidgets version.
        widgets.append(itkview(image=vol, cmap="gray"))
    return VBox(widgets, layout=Layout(width="100%"))

vols1, names1 = [original, displaced], ["Raw", "Displaced"]
vols2, names2 = [original, compensated], ["Raw", "Compensated"]

# Render Raw-vs-Displaced, then Raw-vs-Compensated, with the first backend
# that imported successfully (napari, then PyVista, then itkwidgets).
if NAPARI_OK:
    display(show_3d_napari(vols1, names1, ["gray", "magenta"], [0.7, 0.7]))
    display(show_3d_napari(vols2, names2, ["gray", "cyan"], [0.7, 0.9]))
elif PYVISTA_OK or ITK_OK:
    _render_3d = show_3d_pyvista if PYVISTA_OK else show_3d_itk
    display(_render_3d(vols1, names1))
    display(_render_3d(vols2, names2))
