
# Functional → Anatomy → HCR Mapping (Zebrafish, 2P + HCR)

This notebook automates a practical pipeline to map motion-corrected functional planes to the best matching **2P anatomy** Z-plane and then into **HCR/confocal** space.

**Core steps**
1. Build a *crisp* per-plane functional reference using a Suite2p-style **top‑correlated mean**.
2. For each plane, find the **best matching Z** in the 2P anatomy stack via **normalized cross‑correlation**.
3. Estimate an **in‑plane transform** (shift/similarity) from functional reference → anatomy[best‑Z].
4. Apply a precomputed **3D warp** (anatomy → HCR) to place the functional plane/slab into HCR space.
5. Transfer **ROI labels** (nearest‑neighbor) or **fluorescence images** (linear) as appropriate.

> Tip: Set the voxel sizes (µm) correctly for both stacks before registration, and keep a record of transforms so you can compose and resample **once** wherever possible.
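
To illustrate the "compose and resample once" advice, here is a minimal sketch that chains two 2D transforms by multiplying their homogeneous matrices. The transform values and the names `func_to_anat`, `anat_to_out`, and `composed` are illustrative assumptions, not outputs of this notebook.

```python
import numpy as np
from skimage.transform import AffineTransform, SimilarityTransform

# Hypothetical transforms; the numbers are illustrative only.
func_to_anat = SimilarityTransform(scale=1.02, rotation=np.deg2rad(1.5),
                                   translation=(4.0, -3.0))
anat_to_out = AffineTransform(scale=(0.5, 0.5), translation=(10.0, 20.0))

# Matrices act on homogeneous column vectors (x, y, 1),
# so the transform applied first goes on the right.
composed = AffineTransform(matrix=anat_to_out.params @ func_to_anat.params)

# Check on a few (x, y) points that one composed step equals two chained steps.
pts = np.array([[0.0, 0.0], [100.0, 50.0], [256.0, 256.0]])
two_step = anat_to_out(func_to_anat(pts))
one_step = composed(pts)
max_diff = float(np.abs(two_step - one_step).max())
print("max difference (px):", max_diff)
```

Passing `composed.inverse` to a single `warp` call then interpolates the image once instead of twice, which avoids compounding blur.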



## Requirements

This notebook uses common scientific Python packages:

- `numpy`, `scipy`, `pandas`
- `tifffile`
- `scikit-image` (`skimage`)
- `opencv-python` (optional; speeds up template matching)
- `matplotlib` for quick QA plots

If an import fails, install the package in your environment and re-run the cell.


In [None]:

import os, json, math, random
from pathlib import Path
import numpy as np
import pandas as pd
from tifffile import imread, imwrite
from skimage import filters, exposure, transform, feature, measure, registration, img_as_float32
from skimage.transform import SimilarityTransform, AffineTransform, warp
from skimage.util import img_as_ubyte
from scipy import ndimage as ndi

# Optional OpenCV (accelerated NCC); guarded import
try:
    import cv2
    HAS_CV2 = True
except ImportError:
    HAS_CV2 = False

import matplotlib.pyplot as plt

print("HAS_CV2:", HAS_CV2)



## Paths & I/O

Edit these paths to point to your data. You can also point to your colleague's helper notebook (`/mnt/data/fToA_registration_jl.ipynb`) if you want to copy functions from it.


In [None]:

# --- User paths (EDIT ME) ---
FUNC_STACK_PATH = '/mnt/data/functional_stack.tif'   # motion-corrected, single plane over time OR volume over time
ANAT_STACK_PATH = '/mnt/data/anatomy_2p_stack.tif'   # 2P anatomy (structural) stack
HCR_STACK_PATH  = '/mnt/data/hcr_stack.tif'          # HCR/confocal stack (already cleared sample)
OUTDIR = Path('/mnt/data/f2h_outputs'); OUTDIR.mkdir(exist_ok=True, parents=True)

# Optional label image to map (e.g., Cellpose labels)
LABELS_PATH = None   # e.g., '/mnt/data/cellpose_labels.tif' (uint16 IDs)

# Voxel sizes in microns (used for metadata/logging; algorithms here work in pixels)
VOX_FUNC = (1.0, 1.0, 1.0)   # (z, y, x) for functional
VOX_ANAT = (1.0, 1.0, 1.0)   # (z, y, x) for anatomy
VOX_HCR  = (1.0, 1.0, 1.0)   # (z, y, x) for HCR

# Random seed for reproducibility in frame subsampling
RNG_SEED = 42
random.seed(RNG_SEED)
np.random.seed(RNG_SEED)



## Utility functions


In [None]:

def zproject_mean(stack):
    return stack.mean(axis=0)

def norm01(img):
    img = img.astype(np.float32)
    m, M = np.percentile(img, (1, 99))
    if M <= m:
        M = img.max(); m = img.min()
    out = np.clip((img - m) / (M - m + 1e-6), 0, 1)
    return out

def local_unsharp(img, blur_sigma=1.0, amount=0.6):
    # Unsharp mask: add back a fraction of the high-pass detail (img - blurred)
    base = ndi.gaussian_filter(img, blur_sigma)
    return np.clip(img + amount*(img - base), 0, 1)

def corrcoef_img(a, b):
    # Pearson correlation between 2D arrays
    a = a.astype(np.float32); b = b.astype(np.float32)
    am = a.mean(); bm = b.mean()
    num = ((a - am)*(b - bm)).sum()
    den = np.sqrt(((a - am)**2).sum() * ((b - bm)**2).sum()) + 1e-8
    return float(num / den)

def top_correlated_mean(stack_t, take_k=20, pre_smooth_sigma=0.5):
    """Suite2p-like: build crisp reference by selecting top-K frames most correlated to a provisional mean."""
    T, H, W = stack_t.shape
    # Provisional mean
    m0 = stack_t.mean(axis=0)
    # Optional pre-smoothing to reduce shot noise
    if pre_smooth_sigma and pre_smooth_sigma > 0:
        m0s = ndi.gaussian_filter(m0, pre_smooth_sigma)
    else:
        m0s = m0
    # Correlate each frame with provisional mean
    corrs = np.empty(T, dtype=np.float32)
    for i in range(T):
        fi = stack_t[i]
        if pre_smooth_sigma and pre_smooth_sigma > 0:
            fi = ndi.gaussian_filter(fi, pre_smooth_sigma)
        corrs[i] = corrcoef_img(fi, m0s)
    # Take top-K
    k = min(take_k, T)
    idx = np.argsort(corrs)[-k:]
    ref = stack_t[idx].mean(axis=0)
    return ref, idx, corrs

def best_z_by_ncc(template, anat_stack, use_cv2=True):
    """Return best Z index and NCC scores over Z for a 2D template vs 3D stack."""
    template = norm01(template)
    H, W = template.shape
    scores = []
    if use_cv2 and HAS_CV2:
        templ = (template*255).astype(np.uint8)
        for z in range(anat_stack.shape[0]):
            sl = norm01(anat_stack[z])
            sl8 = (sl*255).astype(np.uint8)
            # TM_CCOEFF_NORMED is the zero-mean NCC, consistent with the Pearson fallback below
            res = cv2.matchTemplate(sl8, templ, cv2.TM_CCOEFF_NORMED)
            # Whole-image match: the template must not be larger than the slice; here we assume same FOV/size
            if res.size == 1:
                s = float(res.ravel()[0])
            else:
                s = float(res.max())
            scores.append(s)
    else:
        # fallback: simple correlation on same-size images
        for z in range(anat_stack.shape[0]):
            sl = norm01(anat_stack[z])
            s = corrcoef_img(template, sl)
            scores.append(s)
    scores = np.asarray(scores, dtype=np.float32)
    best_z = int(np.argmax(scores))
    return best_z, scores

def estimate_inplane_transform(mov, ref, method='similarity'):
    """Estimate 2D transform from moving image (mov) to reference (ref).
    Tries ORB+RANSAC; falls back to phase cross-correlation (shift only)."""
    m = norm01(mov); r = norm01(ref)
    # ORB keypoints
    try:
        detector = feature.ORB(n_keypoints=2000, fast_threshold=0.05)
        detector.detect_and_extract(img_as_float32(m))
        kp1 = detector.keypoints; d1 = detector.descriptors
        detector.detect_and_extract(img_as_float32(r))
        kp2 = detector.keypoints; d2 = detector.descriptors
        if len(kp1) >= 10 and len(kp2) >= 10 and d1 is not None and d2 is not None:
            matches12 = feature.match_descriptors(d1, d2, cross_check=True, max_ratio=0.8)
            src = kp1[matches12[:, 0]][:, ::-1]  # (x,y)
            dst = kp2[matches12[:, 1]][:, ::-1]
            if method == 'similarity':
                model, inliers = measure.ransac((src, dst), SimilarityTransform,
                                                min_samples=3, residual_threshold=2.0, max_trials=2000)
            else:
                model, inliers = measure.ransac((src, dst), AffineTransform,
                                                min_samples=3, residual_threshold=2.0, max_trials=2000)
            if model is not None:
                return model
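    except Exception as e:
        # Feature matching can fail on low-texture images; report it and fall back
        print(f"ORB+RANSAC failed ({e}); falling back to phase correlation.")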
    # Fallback: phase correlation for shift
    shift, _, _ = registration.phase_cross_correlation(r, m, upsample_factor=10)
    tform = SimilarityTransform(translation=(shift[1], shift[0]))
    return tform

def apply_transform_2d(img, tform, output_shape=None, order=1, preserve_range=True):
    if output_shape is None:
        output_shape = img.shape
    warped = warp(img, inverse_map=tform.inverse, output_shape=output_shape, order=order,
                  preserve_range=preserve_range, mode='constant', cval=0.0, clip=True)
    return warped

def resample_labels_nn(img, tform, output_shape=None):
    # nearest-neighbor for label images
    return apply_transform_2d(img, tform, output_shape=output_shape, order=0, preserve_range=True)

def apply_anat_to_hcr_warp_2d(slice_img, z_index, warp3d_func):
    """Hook to apply a 3D warp (anatomy→HCR) to a 2D slice.
    `warp3d_func` is called as `warp3d_func(slice_img, z_index)` and should return the slice resampled into HCR coordinates.
    This is a placeholder hook; implement it with your BigWarp/ANTs output.
    """
    return warp3d_func(slice_img, z_index)

def quickshow(img, title='', vmin=None, vmax=None):
    plt.figure(figsize=(5,5))
    plt.imshow(img, vmin=vmin, vmax=vmax)
    plt.title(title); plt.axis('off'); plt.show()



## 1) Build a crisp functional reference per plane

If your functional input is **single-plane over time**: this outputs one reference.  
If it's **volume over time**: set `PLANE_INDEX` to the plane you want, or loop planes.
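
If you prefer to loop over planes rather than pick one `PLANE_INDEX`, the sketch below shows one way to do it. It is self-contained and uses a compact, vectorized stand-in (`topk_correlated_mean`, a hypothetical helper) rather than the notebook's `top_correlated_mean`, and it runs on synthetic data in an assumed `(T, Z, Y, X)` layout.

```python
import numpy as np

def topk_correlated_mean(frames, k=20):
    """Average the k frames most correlated with the provisional mean.

    frames: array of shape (T, H, W).
    """
    flat = frames.reshape(frames.shape[0], -1).astype(np.float32)
    flat = flat - flat.mean(axis=1, keepdims=True)   # zero-mean each frame
    mean_img = flat.mean(axis=0)                     # provisional mean (already zero-mean)
    denom = np.linalg.norm(flat, axis=1) * np.linalg.norm(mean_img) + 1e-8
    corrs = (flat @ mean_img) / denom                # Pearson r per frame
    keep = np.argsort(corrs)[-min(k, len(corrs)):]
    return frames[keep].mean(axis=0)

# Synthetic stand-in for a (T, Z, Y, X) functional recording.
rng = np.random.default_rng(0)
func_tzyx = rng.poisson(5.0, size=(50, 3, 32, 32)).astype(np.float32)

# One reference image per plane, stacked as (Z, Y, X).
refs = np.stack([topk_correlated_mean(func_tzyx[:, z], k=20)
                 for z in range(func_tzyx.shape[1])])
print(refs.shape)
```

Each `refs[z]` can then be passed through the best-Z and in-plane steps independently.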


In [None]:

# Load functional data
func = imread(FUNC_STACK_PATH)
print("Functional shape:", func.shape)

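# Detect dimensionality: (T, Y, X) or (T, Z, Y, X)
if func.ndim not in (3, 4):
    raise ValueError(f"Expected (T, Y, X) or (T, Z, Y, X) data, got shape {func.shape}")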
if func.ndim == 3:
    # Single plane over time
    T, H, W = func.shape
    ref2d, kept_idx, corrs = top_correlated_mean(func, take_k=min(20, T), pre_smooth_sigma=0.5)
    ref2d = norm01(ref2d)
    imwrite(OUTDIR/'func_ref_plane0.tif', (ref2d*65535).astype(np.uint16))
    print(f"Saved reference to {OUTDIR/'func_ref_plane0.tif'}")
    quickshow(ref2d, "Functional reference (plane 0)")
else:
    # Volume over time
    T, Z, H, W = func.shape
    PLANE_INDEX = 0  # EDIT: choose plane
    plane_t = func[:, PLANE_INDEX, :, :]
    ref2d, kept_idx, corrs = top_correlated_mean(plane_t, take_k=min(20, T), pre_smooth_sigma=0.5)
    ref2d = norm01(ref2d)
    imwrite(OUTDIR/f'func_ref_plane{PLANE_INDEX}.tif', (ref2d*65535).astype(np.uint16))
    print(f"Saved reference to {OUTDIR/f'func_ref_plane{PLANE_INDEX}.tif'}")
    quickshow(ref2d, f"Functional reference (plane {PLANE_INDEX})")



## 2) Find the best matching Z in the 2P anatomy stack
We correlate the functional reference against each anatomy slice and take the Z with the maximum NCC.
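
As a sanity check of the idea, the following sketch plants a noisy copy of one slice in a synthetic stack and recovers its Z index with a zero-mean NCC. The data and the names `ncc`, `anat_demo`, and `true_z` are assumptions made for illustration.

```python
import numpy as np

def ncc(a, b):
    """Zero-mean normalized cross-correlation (Pearson r) of two same-size images."""
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / (np.sqrt((a**2).sum() * (b**2).sum()) + 1e-8))

rng = np.random.default_rng(1)
anat_demo = rng.random((12, 64, 64)).astype(np.float32)   # synthetic "anatomy" stack

# The "functional reference" is slice 7 plus noise.
true_z = 7
noise = 0.3 * rng.standard_normal((64, 64)).astype(np.float32)
template = anat_demo[true_z] + noise

scores_demo = np.array([ncc(template, sl) for sl in anat_demo])
best = int(np.argmax(scores_demo))

# Margin between the best and second-best score is a simple confidence measure.
margin = float(np.sort(scores_demo)[-1] - np.sort(scores_demo)[-2])
print("best Z:", best, "margin:", round(margin, 3))
```

On real data a small margin suggests the match is ambiguous and worth inspecting by eye.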


In [None]:

anat = imread(ANAT_STACK_PATH)
print("Anatomy shape:", anat.shape)

# Optional pre-filter to enhance structure
anat_f = np.stack([local_unsharp(norm01(s), 1.0, 0.6) for s in anat], axis=0)
ref_f  = local_unsharp(norm01(ref2d), 1.0, 0.6)

best_z, scores = best_z_by_ncc(ref_f, anat_f, use_cv2=True)
print("Best Z in anatomy:", best_z)

# Save scores for QA
pd.Series(scores).to_csv(OUTDIR/'bestZ_scores.csv', index=False)

plt.figure()
plt.plot(scores); plt.axvline(best_z, linestyle='--')
plt.xlabel('Z'); plt.ylabel('NCC score'); plt.title('Best-Z scores')
plt.show()

quickshow(anat[best_z], f'Anatomy slice Z={best_z}')



## 3) Estimate in-plane transform (functional → anatomy[best‑Z])
We try ORB+RANSAC to get a **similarity** (shift/scale/rotation) transform, and fall back to phase correlation (shift only).
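
The sign convention of the phase-correlation fallback is easy to get wrong, so here is a small check on synthetic data. It assumes a pure integer displacement applied with `np.roll`; the names `ref_demo`, `mov_demo`, and `tform_demo` are illustrative.

```python
import numpy as np
from skimage.registration import phase_cross_correlation
from skimage.transform import SimilarityTransform

rng = np.random.default_rng(2)
ref_demo = rng.random((64, 64))

# Moving image = reference displaced by +3 rows and -5 columns (circular shift).
mov_demo = np.roll(ref_demo, shift=(3, -5), axis=(0, 1))

# The first return value is the (row, col) shift that registers the moving image
# onto the reference, i.e. the negative of the applied displacement.
shift_rc = phase_cross_correlation(ref_demo, mov_demo)[0]

# skimage transforms use (x, y) order, so swap (row, col) -> (col, row).
tform_demo = SimilarityTransform(translation=(shift_rc[1], shift_rc[0]))
print("shift (row, col):", shift_rc, "translation (x, y):", tform_demo.translation)
```

The resulting transform maps moving coordinates to reference coordinates, which is why `apply_transform_2d` passes `tform.inverse` to `warp`.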


In [None]:

tform = estimate_inplane_transform(ref_f, anat_f[best_z], method='similarity')
print("Estimated transform:\n", tform.params)

# Warp the functional reference into anatomy space for visual QA
ref_warped = apply_transform_2d(ref_f, tform, output_shape=anat_f[best_z].shape, order=1)
quickshow(ref_warped, 'Functional ref warped → anatomy')
quickshow(anat_f[best_z], 'Anatomy best-Z')

# Overlay QA
plt.figure(figsize=(6,6))
plt.imshow(anat_f[best_z], alpha=0.7)
plt.imshow(ref_warped, alpha=0.3)
plt.title('Overlay (anatomy + warped functional)')
plt.axis('off'); plt.show()

# Save the top two rows of the 3x3 homogeneous matrix (the 2x3 affine part)
np.save(OUTDIR/'func_to_anat_similarity_2x3.npy', tform.params[:2])
print("Saved transform to", OUTDIR/'func_to_anat_similarity_2x3.npy')



## 4) (Optional) Transfer ROI labels safely (nearest-neighbor)

If you have a label image (e.g., Cellpose IDs), warp it with **order=0** (nearest‑neighbor) to avoid fragmentation or mixing of labels.
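
The sketch below shows why interpolation order matters for labels: linear interpolation creates fractional values at ROI edges that are not valid IDs, while nearest-neighbor only ever returns IDs already present. The toy label image and the transform are assumptions for illustration.

```python
import numpy as np
from skimage.transform import SimilarityTransform, warp

# Toy label image with two ROIs (IDs 5 and 9) on background 0.
labels_demo = np.zeros((40, 40), dtype=np.uint16)
labels_demo[5:15, 5:15] = 5
labels_demo[20:32, 18:30] = 9

# Hypothetical small rotation plus sub-pixel shift.
tform_demo = SimilarityTransform(rotation=np.deg2rad(7), translation=(2.5, -1.5))

def warp_labels(img, tform, order):
    # preserve_range keeps the original ID values instead of rescaling to [0, 1]
    return warp(img, inverse_map=tform.inverse, order=order,
                preserve_range=True, mode='constant', cval=0)

nn = warp_labels(labels_demo, tform_demo, order=0)    # nearest-neighbor
lin = warp_labels(labels_demo, tform_demo, order=1)   # linear

ids_in = set(np.unique(labels_demo).tolist())
ids_nn = set(np.unique(nn).tolist())
ids_lin = set(np.unique(lin).tolist())
print("input IDs:", sorted(ids_in))
print("nearest-neighbor IDs:", sorted(ids_nn))
print("linear produced", len(ids_lin - ids_in), "values that are not valid IDs")
```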


In [None]:

if LABELS_PATH and os.path.exists(LABELS_PATH):
    labels = imread(LABELS_PATH)
    if labels.ndim == 2:
        labels_warp = resample_labels_nn(labels, tform, output_shape=anat.shape[1:])
        imwrite(OUTDIR/'labels_in_anat_space.tif', labels_warp.astype(np.uint16))
        print("Saved labels_in_anat_space.tif")
        quickshow(labels_warp>0, 'Labels (any>0) in anatomy space')
    else:
        print("LABELS_PATH is not a 2D label image; adapt code for your layout.")
else:
    print("No LABELS_PATH provided; skipping labels transfer.")



## 5) Apply anatomy → HCR 3D warp (hook)

Provide a function `warp3d_func(slice_img, z_index)` that applies your 3D deformation field to a 2D slice at Z.  
This is where you plug in BigWarp/ANTs/SimpleITK. Below is a **mock** that passes data through unchanged.
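
If your anatomy → HCR registration can be approximated by a single affine matrix, a `warp3d_func` could look like the sketch below. This is an assumption-laden illustration, not a BigWarp/ANTs replacement: `make_affine_warp3d` is a hypothetical helper, the 4x4 matrix is assumed to act on `(z, y, x)` voxel coordinates, and the slice is treated as a slab of thickness `2 * half_thickness` voxels.

```python
import numpy as np
from scipy import ndimage as ndi

def make_affine_warp3d(anat_to_hcr, hcr_shape, half_thickness=0.5):
    """Build warp3d_func(slice_img, z_index) from a 4x4 affine in (z, y, x) voxels."""
    hcr_to_anat = np.linalg.inv(anat_to_hcr)
    # Homogeneous coordinates of every HCR voxel, shape (4, N)
    grid = np.indices(hcr_shape).reshape(3, -1).astype(np.float64)
    homog = np.vstack([grid, np.ones((1, grid.shape[1]))])
    # Where each HCR voxel falls in anatomy space
    za, ya, xa = (hcr_to_anat @ homog)[:3]

    def warp3d_func(slice_img, z_index):
        # Sample the 2D slice at the in-plane anatomy coordinates (linear)
        vals = ndi.map_coordinates(slice_img.astype(np.float32), [ya, xa],
                                   order=1, mode='constant', cval=0.0)
        # Keep only HCR voxels whose anatomy Z lies within the slab
        vals[np.abs(za - z_index) > half_thickness] = 0.0
        return vals.reshape(hcr_shape)

    return warp3d_func

# Identity check on synthetic data: the slice should land unchanged at Z=2.
rng = np.random.default_rng(3)
slice_demo = rng.random((16, 16)).astype(np.float32)
warp_identity = make_affine_warp3d(np.eye(4), hcr_shape=(4, 16, 16))
vol = warp_identity(slice_demo, z_index=2)
print(vol.shape)
```

Note that this returns a 3D HCR-space volume, so take the relevant plane or a max projection before passing it to `quickshow`. It also enumerates every HCR voxel, which may be memory-hungry for full-size stacks; use `order=0` for labels.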


In [None]:

def dummy_warp3d(slice_img, z_index):
    # Replace with your BigWarp/ANTs application. This is identity.
    return slice_img

# Example usage:
slice_in_hcr = apply_anat_to_hcr_warp_2d(ref_warped, best_z, warp3d_func=dummy_warp3d)
quickshow(slice_in_hcr, 'Slice in HCR space (dummy)')



## 6) Save a small JSON with parameters and results


In [None]:

meta = {
    "FUNC_STACK_PATH": FUNC_STACK_PATH,
    "ANAT_STACK_PATH": ANAT_STACK_PATH,
    "HCR_STACK_PATH": HCR_STACK_PATH,
    "best_z": int(best_z),
    # Top two rows of the 3x3 homogeneous matrix, i.e. the 2x3 affine part
    "transform_params_2x3": tform.params[:2].tolist(),
    "voxels": {"func": VOX_FUNC, "anat": VOX_ANAT, "hcr": VOX_HCR},
    "rng_seed": RNG_SEED,
}

with open(OUTDIR/'run_metadata.json', 'w') as f:
    json.dump(meta, f, indent=2)

print("Wrote", OUTDIR/'run_metadata.json')



## Notes & next steps

- If the NCC best‑Z curve is **broad** or has multiple peaks, consider matching a **slab** (±2–3 slices) instead of a single slice.
- For in‑plane alignment, try `method='affine'` if you suspect slight shear/scale differences.
- To register **anatomy → HCR** in Python, consider:
  - **SimpleITK (Elastix)**: rigid→affine→B‑spline with mutual information.
  - **ANTsPy**: SyN-based deformable registration.
- Once you have a 3D transform, apply it to **grayscale** images with linear interpolation (order=1) and to **labels** with nearest‑neighbor (order=0).
- Compose transforms and resample once if you can (functional → anatomy → HCR → output).
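
The slab suggestion in the first note can be sketched as follows: instead of scoring each anatomy slice alone, score the mean of a small window of slices around each Z. The helper names (`pearson`, `best_z_by_slab`) and the synthetic data are assumptions for illustration.

```python
import numpy as np

def pearson(a, b):
    """Pearson correlation between two same-size images."""
    a = a - a.mean()
    b = b - b.mean()
    return float((a * b).sum() / (np.sqrt((a**2).sum() * (b**2).sum()) + 1e-8))

def best_z_by_slab(template, stack, half_width=2):
    """Score each Z by correlating the template with the mean of slices z±half_width."""
    n_z = stack.shape[0]
    scores = []
    for z in range(n_z):
        lo, hi = max(0, z - half_width), min(n_z, z + half_width + 1)
        scores.append(pearson(template, stack[lo:hi].mean(axis=0)))
    scores = np.asarray(scores)
    return int(np.argmax(scores)), scores

# Synthetic check: the "functional plane" integrates anatomy slices 6..8.
rng = np.random.default_rng(4)
stack_demo = rng.random((15, 48, 48))
template_demo = stack_demo[6:9].mean(axis=0)

z_slab, slab_scores = best_z_by_slab(template_demo, stack_demo, half_width=1)
print("best slab centre:", z_slab)
```

A slab is most useful when the functional plane's axial extent spans several anatomy slices; choose `half_width` from the ratio of the two Z resolutions.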