In [1]:
import tifffile
import mbo_utilities as mbo

Image(value=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x01,\x00\x00\x007\x08\x06\x00\x00\x00\xb6\x1bw\x99\x…

Valid,Device,Type,Backend,Driver
✅ (default),Intel(R) Graphics (RPL-P),IntegratedGPU,Vulkan,Mesa 24.0.3-1pop1~1711635559~22.04~7a9f319
❗ limited,"llvmpipe (LLVM 15.0.7, 256 bits)",CPU,Vulkan,Mesa 24.0.3-1pop1~1711635559~22.04~7a9f319 (LLVM 15.0.7)
❌,Mesa Intel(R) Graphics (RPL-P),IntegratedGPU,OpenGL,4.6 (Core Profile) Mesa 24.2.8-1~bpo12+1pop1~1744225826~22.04~b077665


In [2]:
files = mbo.get_files("/home/flynn/lbm_data/raw")
files

['/home/flynn/lbm_data/raw/mk301_03_01_2025_2roi_17p07hz_224x448px_2umpx_180mw_green_00002.tif',
 '/home/flynn/lbm_data/raw/mk301_03_01_2025_2roi_17p07hz_224x448px_2umpx_180mw_green_00001.tif']

In [16]:
data = tifffile.imread(files[0])[:, 7, :, :]
data.shape

(101, 912, 224)

In [29]:
import numpy as np
from scipy.ndimage import fourier_shift
from skimage.registration import phase_cross_correlation
from mbo_utilities import log

TWO_DIM_PHASECORR_METHODS = {"frame", "mean", "max", "std", "mean-sub", "mean-sub-std"}
THREE_DIM_PHASECORR_METHODS = ["mean", "max", "std", "mean-sub"]

MBO_WINDOW_METHODS = {
    "mean": lambda X: np.mean(X, axis=0),
    "max": lambda X: np.max(X, axis=0),
    "std": lambda X: np.std(X, axis=0),
    "mean-sub": lambda X: X - np.mean(X, axis=0),
    "mean-sub-std": lambda X: (X - np.mean(X, axis=0)) / (np.std(X, axis=0) + 1e-8),
}

logger = log.get("phasecorr")


def _phase_corr_2d(
        frame,
        upsample=1,
        border=0,
        max_offset=4
):
    if frame.ndim != 2:
        raise ValueError("Expected a 2D frame, got a 3D array.")

    h, w = frame.shape

    if isinstance(border, int):
        t = b = l = r = border
    else:
        t, b, l, r = border

    pre, post = frame[::2], frame[1::2]
    m = min(pre.shape[0], post.shape[0])

    row_start = t
    row_end = m - b if b else m
    col_start = l
    col_end = w - r if r else w

    a = pre[row_start:row_end, col_start:col_end]
    b_ = post[row_start:row_end, col_start:col_end]

    shift, *_ = phase_cross_correlation(a, b_, upsample_factor=upsample)
    dx = float(shift[1])
    if max_offset:
        return np.sign(dx) * min(abs(dx), max_offset)
    return dx


def _apply_offset(frame, shift):
    if frame.ndim < 2:
        return frame
    rows = frame[1::2]
    f = np.fft.fftn(rows)
    shift_vec = (0, shift)[:rows.ndim]
    rows[:] = np.fft.ifftn(fourier_shift(f, shift_vec)).real
    return frame


def nd_windowed_compute_optimal_offset(
        arr,
        method="mean",
        upsample=1,
        max_offset=4,
        border=2,
):
    """
    Compute scan‐phase offsets. If `arr` is 2D, always run a single‐image offset.
    If `arr` is 3D (time × height × width), one of:

      - "frame"      → compute offset frame‐by‐frame (returns a 1D array of length T)
      - "mean", "max", "std"       → collapse along time with np.mean/np.max/np.std first
      - "mean-sub"   → subtract the temporal mean from each frame, then run offset on that difference
      - "mean-sub-std" → first z‐score each pixel over time, then compute offset on that z‐scored image

    """
    a = np.asarray(arr)
    if a.ndim == 2:
        if method not in TWO_DIM_PHASECORR_METHODS:
            logger.debug(
                "Attempted to use a windowed phase-corr method on 2D data."
                f"Available 2D methods: {TWO_DIM_PHASECORR_METHODS}"
            )
        return _phase_corr_2d(
            a,
            upsample=upsample,
            border=border,
            max_offset=max_offset
        )
    # flatten z/t
    flat = a.reshape(a.shape[0], *a.shape[-2:])

    # one offset per frame
    if method == "frame":
        return np.array([_phase_corr_2d(
            f,
            upsample=upsample,
            border=border,
            max_offset=max_offset,
        ) for f in flat])  # dtype=np.float32)

    if method not in MBO_WINDOW_METHODS:
        raise ValueError(f"Unknown phase‐corr method: {method!r}")

    image = MBO_WINDOW_METHODS[method](flat)
    return _phase_corr_2d(image, upsample=upsample, border=border, max_offset=max_offset)

def nd_windowed(arr, *, method="mean", upsample=1,
                max_offset=4, border=2):
    a = np.asarray(arr)
    if a.ndim == 2:
        offs = _phase_corr_2d(a, upsample, border, max_offset)
    else:
        flat = a.reshape(a.shape[0], *a.shape[-2:])
        if method == "frame":
            offs = np.array([_phase_corr_2d(f, upsample, border, max_offset)
                             for f in flat])
        else:
            if method not in MBO_WINDOW_METHODS:
                raise ValueError(f"unknown method {method}")
            img = MBO_WINDOW_METHODS[method](flat)
            offs = _phase_corr_2d(img, upsample, border, max_offset)

    if np.isscalar(offs):
        corrected = _apply_offset(a.copy(), offs)
    else:
        corrected = np.stack([
            _apply_offset(f.copy(), s) for f, s in zip(a, offs)
        ])
    return corrected, offs


def apply_scan_phase_offsets(arr, offs):
    out = np.asarray(arr).copy()
    if np.isscalar(offs):
        return _apply_offset(out, offs)
    for k, off in enumerate(offs):
        out[k] = _apply_offset(out[k], off)
    return out


In [30]:
def nd_windowed_compute_optimal_offset(
        arr,
        method="mean",
        upsample=1,
        max_offset=4,
        border=2,
):
    """
    Compute scan‐phase offsets. If `arr` is 2D, always run a single‐image offset.
    If `arr` is 3D (time × height × width), one of:

      - "frame"      → compute offset frame‐by‐frame (returns a 1D array of length T)
      - "mean", "max", "std"       → collapse along time with np.mean/np.max/np.std first
      - "mean-sub"   → subtract the temporal mean from each frame, then run offset on that difference
      - "mean-sub-std" → first z‐score each pixel over time, then compute offset on that z‐scored image

    """
    a = np.asarray(arr)
    if a.ndim == 2:
        if method not in TWO_DIM_PHASECORR_METHODS:
            logger.debug(
                "Attempted to use a windowed phase-corr method on 2D data."
                f"Available 2D methods: {TWO_DIM_PHASECORR_METHODS}"
            )
        return _phase_corr_2d(
            a,
            upsample=upsample,
            border=border,
            max_offset=max_offset
        )
    # flatten z/t
    flat = a.reshape(a.shape[0], *a.shape[-2:])

    # one offset per frame
    if method == "frame":
        return np.array([_phase_corr_2d(
            f,
            upsample=upsample,
            border=border,
            max_offset=max_offset,
        ) for f in flat])  # dtype=np.float32)

    if method not in MBO_WINDOW_METHODS:
        raise ValueError(f"Unknown phase‐corr method: {method!r}")

    image = MBO_WINDOW_METHODS[method](flat)
    return _phase_corr_2d(image, upsample=upsample, border=border, max_offset=max_offset)

In [34]:
%%timeit
_ = nd_windowed(data, method="mean", upsample=8)

RuntimeError: sequence argument must have length equal to input rank

In [33]:
%%timeit
offs = nd_windowed_compute_optimal_offset(
    data,
    method="frame",
    upsample=8,
)
corrected = apply_scan_phase_offsets(data, offs)

1.89 s ± 17 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [21]:
def run_chunks(n):
    for s in range(0, data.shape[0], n):
        nd_windowed(data[s:s+n], method="mean", upsample=8)

In [25]:
n = 3
for s in range(0, data.shape[0], n):
    print(data[s:s+n].shape)

(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(3, 912, 224)
(2, 912, 224)


In [22]:
%%timeit
run_chunks(3)

RuntimeError: sequence argument must have length equal to input rank