In [None]:
# --- Batch: transform suite2p ROIs into common reference space ---

import numpy as np
import pandas as pd
from pathlib import Path
import tifffile as tiff
import ants
# Setup & imports
%load_ext autoreload
%autoreload 2
import sys, math, warnings, json
import numpy as np
import nrrd

import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display


import os

# Get parent of the notebook dir (project_root)
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))

# Add sibling directory "src" to sys.path
sys.path.append(os.path.join(project_root, "src"))

import alignSubstackUtils as asu


In [None]:


# Root folder that contains one sub-folder per animal.
BASE_FOLDER = Path(r"Y:/07_Data/Matilde")

# Animal selection for the batch run:
#   * an animal folder name (e.g. "L395_f11") -> process only that animal (testing)
#   * None                                    -> process every animal folder in BASE_FOLDER
# Assigned exactly once so the active value is unambiguous.
TEST_ANIMAL = "L427_f01"

# Sub-paths, all relative to an animal folder
FUNC_S2P_SUBPATH = Path(r"03_analysis/functional/suite2P")            # Suite2p plane folders
REG_RESULTS_SUBPATH = Path(r"02_reg/07_2pf-a")                        # registration results CSV
ANTS_TRANS_SUBPATH = Path(r"02_reg/08_2pa-ref/transMatrices")         # ANTs affine + warp files
ROI_OUT_SUBPATH = Path(r"03_analysis/functional/ROI_transformed")     # output CSV folder

# Skip / overwrite behavior
OVERWRITE_ROI_CSV = True  # False -> skip animals with existing outputs; True -> recompute

# Reference brain shared by all animals
REF_PATH = Path(r"Y:/03_Common_Use/reference brains/ref_05_LB_Perrino_2p/average_2p_noRot_8bit_flipX_reverse.nrrd")





In [None]:
# Reference brain: same file for all animals, so load it once.
# The stack and its spacing/origin are both taken from REF_PATH so they can
# never drift apart if the reference file is changed in the config cell.
ref_stack = asu.read_good_nrrd_uint8(REF_PATH, flip_horizontal=True)
REF_SPACING_ZYX, REF_ORIGIN_ZYX = asu.get_spacing_origin_ZYX(REF_PATH)


# Ensure the ROI output folder exists for every directory under BASE_FOLDER
for animal_dir in BASE_FOLDER.iterdir():
    if animal_dir.is_dir():
        (animal_dir / ROI_OUT_SUBPATH).mkdir(parents=True, exist_ok=True)


def load_suite2p_rois(plane_dir: Path, animal: str, plane_idx: int, verbose: bool = True) -> np.ndarray:
    """
    Load the ROI centres of one Suite2p plane.

    The stat file is searched in ``plane_dir`` under two names, in this order:
    ``{animal}_plane{plane_idx}_stat.npy`` first, then plain ``stat.npy``.

    Parameters
    ----------
    plane_dir : Path
        Folder of the Suite2p plane.
    animal : str
        Animal ID used in the animal-specific stat file name.
    plane_idx : int
        Plane number used in the animal-specific stat file name.
    verbose : bool
        When True, print how many ROIs were loaded and from which file.

    Returns
    -------
    np.ndarray
        float32 array with the ``"med"`` entry of every ROI, one row per ROI
        (interpreted as (y, x) by the callers). A stat file without ROIs
        yields an empty array of shape (0, 2) instead of (0,), so downstream
        (y, x) arithmetic keeps working.

    Raises
    ------
    FileNotFoundError
        If neither candidate stat file exists in ``plane_dir``.
    """
    stat_candidates = [
        plane_dir / f"{animal}_plane{plane_idx}_stat.npy",
        plane_dir / "stat.npy",
    ]

    stat_p = next((p for p in stat_candidates if p.exists()), None)
    if stat_p is None:
        raise FileNotFoundError(f"Missing Suite2p stat.npy in {plane_dir}")

    # NOTE: allow_pickle=True unpickles arbitrary objects; only load stat files
    # from trusted sources.
    stat = np.load(stat_p, allow_pickle=True)
    rois = np.array([s["med"] for s in stat], dtype=np.float32)

    # np.array([]) is 1-D; give the empty case the same (N, 2) layout as the
    # non-empty case. Non-empty arrays are left untouched so malformed "med"
    # entries are not silently reshaped.
    if rois.size == 0:
        rois = rois.reshape(0, 2)

    if verbose:
        print(f"    [s2p] loaded {len(rois)} ROIs from {stat_p.name}")

    return rois


def apply_func2anat_transform(rois, reg_row):
    """Map functional-plane ROI centres into the anatomy stack.

    The registration row supplies an isotropic scale
    (``scale_moving_to_fixed``), an in-plane offset (``y_px``, ``x_px``) and
    the matched anatomy plane (``z_index``). Each (y, x) centre is scaled,
    shifted, and prefixed with that constant z.

    Returns an (N, 3) array ordered z, y, x. Coordinates are still in the
    *flipped* anatomy space; no un-flip is applied here.
    """
    scale, shift_y, shift_x, plane_z = (
        reg_row[key]
        for key in ("scale_moving_to_fixed", "y_px", "x_px", "z_index")
    )

    # (N, 2) y,x: scale first, then translate into fixed (flipped) coords
    offset_yx = np.array([shift_y, shift_x], dtype=np.float32)
    yx_in_anat = rois * scale + offset_yx

    # Constant z column, one entry per ROI
    z_column = np.full((rois.shape[0], 1), plane_z, dtype=np.float32)

    # z column first, then y,x -> (N, 3)
    return np.concatenate((z_column, yx_in_anat), axis=1)



def apply_anat2ref_transform(
    rois_anat_zyx_idx: np.ndarray,
    trans_dir: Path,
    animal: str,
    anat_nrrd_path: Path,
    ref_spacing_zyx: np.ndarray,
    ref_origin_zyx: np.ndarray,
) -> np.ndarray:
    """
    Map ROI coordinates from an animal's anatomy stack into the reference stack.

    Parameters
    ----------
    rois_anat_zyx_idx : np.ndarray
        Anatomy **index** coords, one (z, y, x) row per ROI (unflipped to
        match ANTs, per the caller).
    trans_dir : Path
        Folder holding ``{animal}_2P_to_ref_0GenericAffine.mat`` and
        ``{animal}_2P_to_ref_1Warp.nii.gz``.
    animal : str
        Animal ID used in the transform file names.
    anat_nrrd_path : Path
        Anatomy NRRD; its spacing/origin convert index -> physical coords.
    ref_spacing_zyx, ref_origin_zyx : np.ndarray
        Reference spacing/origin used to convert physical -> index coords.

    Returns
    -------
    np.ndarray
        float32 reference **index** coords, one (z, y, x) row per ROI.
        An empty input returns an empty (0, 3) array without calling ANTs.

    Raises
    ------
    FileNotFoundError
        If the affine or the warp file is missing in ``trans_dir``.
    """
    affine = trans_dir / f"{animal}_2P_to_ref_0GenericAffine.mat"
    warp   = trans_dir / f"{animal}_2P_to_ref_1Warp.nii.gz"
    if not affine.exists() or not warp.exists():
        raise FileNotFoundError(f"Missing transforms for {animal}: {trans_dir}")

    # Nothing to transform: skip the header read and avoid handing an empty
    # point table to ANTs.
    rois_anat_zyx_idx = np.asarray(rois_anat_zyx_idx)
    if rois_anat_zyx_idx.size == 0:
        return np.empty((0, 3), dtype=np.float32)

    anat_spacing_zyx, anat_origin_zyx = asu.get_spacing_origin_ZYX(anat_nrrd_path)

    # (1) anat index -> anat physical (z,y,x)
    anat_phys_zyx = asu.indexZYX_to_physZYX(rois_anat_zyx_idx.astype(np.float64), anat_spacing_zyx, anat_origin_zyx)

    # (2) reorder to (x,y,z) for antspy
    anat_phys_xyz = anat_phys_zyx[:, [2,1,0]]
    pts_df = pd.DataFrame(anat_phys_xyz, columns=["x","y","z"])

    # (3) animal -> ref on points: transform list is [warp, affine] with the
    #     warp applied as stored and the affine INVERTED (whichtoinvert below).
    # NOTE(review): ANTs maps points in the opposite direction to images; the
    #     usual point mapping moving->fixed pairs the inverted affine with the
    #     *inverse* warp file. Confirm this warp/affine combination against the
    #     registration call that produced these files.
    tx = [str(warp), str(affine)]
    df_out = ants.apply_transforms_to_points(
        3,
        pts_df,
        tx,
        whichtoinvert=[False, True]   # keep warp as stored, invert affine
    )

    # (4) back to (z,y,x) physical
    ref_phys_zyx = df_out[["z","y","x"]].to_numpy(dtype=np.float64)

    # (5) ref physical -> ref index
    ref_idx_zyx = asu.physZYX_to_indexZYX(ref_phys_zyx, ref_spacing_zyx, ref_origin_zyx)
    return ref_idx_zyx.astype(np.float32)



def process_animal(animal_dir: Path):
    """
    Transform every Suite2p ROI centre of one animal into reference space and
    write the result to ``<animal>_rois_transformed.csv``.

    For each row of the animal's registration CSV (one registered plane):
      1. load the ROI centres (y, x) from the plane's Suite2p stat file,
      2. functional -> anatomy index coords (``apply_func2anat_transform``),
      3. mirror X as ``(W_anat - 1) - x`` to leave the horizontally flipped
         anatomy space (per the original author's note, ANTs was run on the
         *unflipped* anatomy),
      4. anatomy -> reference index coords (``apply_anat2ref_transform``).

    Uses the module-level settings ``REG_RESULTS_SUBPATH``, ``FUNC_S2P_SUBPATH``,
    ``ANTS_TRANS_SUBPATH``, ``ROI_OUT_SUBPATH``, ``OVERWRITE_ROI_CSV`` and the
    reference spacing/origin ``REF_SPACING_ZYX`` / ``REF_ORIGIN_ZYX`` computed
    once at the top of this cell.

    Parameters
    ----------
    animal_dir : Path
        Animal folder; its name is used as the animal ID in all file names.

    Returns
    -------
    str
        "processed"        - CSV written,
        "skipped_exists"   - CSV already present and OVERWRITE_ROI_CSV is False,
        "skipped_missing"  - a required input (anatomy NRRD, registration CSV,
                             Suite2p folder, transforms folder) is missing,
        "skipped_empty"    - no ROIs could be collected from any plane.
    """
    animal = animal_dir.name
    print(f"\n=== Processing {animal} ===")

    # I/O paths
    reg_csv   = animal_dir / REG_RESULTS_SUBPATH / f"{animal}_registration_results.csv"
    s2p_dir   = animal_dir / FUNC_S2P_SUBPATH
    trans_dir = animal_dir / ANTS_TRANS_SUBPATH
    out_dir   = animal_dir / ROI_OUT_SUBPATH
    out_csv   = out_dir / f"{animal}_rois_transformed.csv"
    anat_dir  = animal_dir / Path(r"02_reg/00_preprocessing/2p_anatomy")

    # Resolve the anatomy NRRD for THIS animal (cheap directory listing only).
    # With several matches, prefer the single file whose name contains "gcamp";
    # otherwise fall back to the first match in sorted order.
    fixed_matches = sorted(anat_dir.glob(f"{animal}*.nrrd"))
    if not fixed_matches:
        print(f"[SKIP:missing] anatomy NRRD for {animal} in {anat_dir}")
        return "skipped_missing"
    if len(fixed_matches) == 1:
        fixed_p = fixed_matches[0]
    else:
        gcamp = [p for p in fixed_matches if "gcamp" in p.name.lower()]
        fixed_p = gcamp[0] if len(gcamp) == 1 else fixed_matches[0]
        print(f"  [NOTE] picked anatomy file: {fixed_p.name}")

    # Skip if output exists and not overwriting
    if out_csv.exists() and not OVERWRITE_ROI_CSV:
        print(f"[SKIP:exists] {animal}: {out_csv.name}")
        return "skipped_exists"

    # Basic presence checks
    if not reg_csv.exists():
        print(f"[SKIP:missing] registration CSV not found: {reg_csv}")
        return "skipped_missing"
    if not s2p_dir.exists():
        print(f"[SKIP:missing] suite2p dir not found: {s2p_dir}")
        return "skipped_missing"
    if not trans_dir.exists():
        print(f"[SKIP:missing] ANTs transforms dir not found: {trans_dir}")
        return "skipped_missing"

    # Only now pay for reading the anatomy stack; just its X extent is needed
    # to mirror ROI x coordinates back out of the flipped space.
    anat_stack_flipped = asu.read_good_nrrd_uint8(fixed_p, flip_horizontal=True)
    _, _, W_anat = anat_stack_flipped.shape

    df_reg = pd.read_csv(reg_csv)
    all_rows = []
    # Assume row index corresponds to plane index; otherwise map via 'moving_plane' column if present
    plane_col = "moving_plane" if "moving_plane" in df_reg.columns else None

    for idx, reg_row in df_reg.iterrows():
        plane_idx = int(reg_row[plane_col]) if plane_col else int(idx)
        plane_dir = s2p_dir / f"plane{plane_idx}"
        try:
            rois_yx = load_suite2p_rois(plane_dir, animal, plane_idx)
        except FileNotFoundError as e:
            print(f"  [WARN] {e}")
            continue

        # A stat file without ROIs has nothing to transform (and an empty
        # array cannot be broadcast against the (y, x) offset).
        if rois_yx.size == 0:
            print(f"  [WARN] plane {plane_idx}: stat file contains no ROIs, skipping")
            continue

        # functional -> anatomy (index, flipped space); (N,3): z,y,x
        rois_anat_zyx = apply_func2anat_transform(rois_yx, reg_row)

        # un-flip X in place to match ANTs (ANTs was run on *unflipped* anatomy)
        rois_anat_zyx[:, 2] = (W_anat - 1) - rois_anat_zyx[:, 2]

        # anatomy -> reference with index/physical handling
        rois_ref_zyx = apply_anat2ref_transform(
            rois_anat_zyx,
            trans_dir,
            animal,
            fixed_p,            # anatomy NRRD supplies spacing/origin
            REF_SPACING_ZYX,    # module-level, computed once for all animals
            REF_ORIGIN_ZYX,
        )

        # collect rows (plain Python floats, one dict per ROI)
        n_rois = rois_ref_zyx.shape[0]
        all_rows.extend(
            {
                "animal": animal,
                "plane": plane_idx,
                "roi_id": int(rid),

                # intermediate anatomy index coords AFTER the X un-flip above,
                # i.e. they address the unflipped anatomy stack
                "z_anat": float(zA),
                "y_anat": float(yA),
                "x_anat": float(xA),

                # final reference coords
                "z_ref": float(zR),
                "y_ref": float(yR),
                "x_ref": float(xR),
            }
            for rid, ((zA, yA, xA), (zR, yR, xR)) in enumerate(zip(rois_anat_zyx, rois_ref_zyx))
        )
        print(f"  plane {plane_idx}: {n_rois} ROIs transformed")

    if not all_rows:
        print(f"[SKIP] no ROIs found for {animal}")
        return "skipped_empty"

    df_out = pd.DataFrame(
        all_rows,
        columns=[
            "animal","plane","roi_id",
            "z_anat","y_anat","x_anat",
            "z_ref","y_ref","x_ref",
        ],
    )
    # Do not rely on another cell having created the output folder.
    out_dir.mkdir(parents=True, exist_ok=True)
    df_out.to_csv(out_csv, index=False)
    print(f"[SAVED] {out_csv}  ({len(df_out)} rows)")
    return "processed"

# ---- run with summary ----
# One counter per status string that process_animal can return; any other
# return value is ignored, exactly like an unmatched if/elif chain.
status_tally = {
    "processed": 0,
    "skipped_exists": 0,
    "skipped_missing": 0,
    "skipped_empty": 0,
}

# Single test animal when TEST_ANIMAL is set, otherwise every sub-folder.
if TEST_ANIMAL is None:
    dirs_to_run = (entry for entry in BASE_FOLDER.iterdir() if entry.is_dir())
else:
    dirs_to_run = iter([BASE_FOLDER / TEST_ANIMAL])

for animal_dir in dirs_to_run:
    status = process_animal(animal_dir)
    if status in status_tally:
        status_tally[status] += 1

# Keep the individual counters available under their usual names.
processed = status_tally["processed"]
skipped_exists = status_tally["skipped_exists"]
skipped_missing = status_tally["skipped_missing"]
skipped_empty = status_tally["skipped_empty"]

print("\n--- ROI transform summary ---")
print(f"Processed:        {processed}")
print(f"Skipped (exists): {skipped_exists}")
print(f"Skipped (missing):{skipped_missing}")
print(f"Skipped (empty):  {skipped_empty}")

In [None]:
# === Flexible ROI checker GUI (with flip toggle) ===
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
import ipywidgets as widgets
from IPython.display import display, clear_output

# --- Locations (repeated here so this cell can be read on its own) ---
BASE_FOLDER = Path(r"Y:/07_Data/Matilde")
REF_PATH = Path(r"Y:/03_Common_Use/reference brains/ref_05_LB_Perrino_2p/average_2p_noRot_8bit_flipX_reverse.nrrd")
ROI_OUT_SUBPATH = Path(r"03_analysis/functional/ROI_transformed")   # per-animal ROI CSV folder
ANAT_SUBPATH = Path(r"02_reg/00_preprocessing/2p_anatomy")          # per-animal anatomy NRRD folder

# --- Display parameters ---
# NOTE(review): the widgets below are created with literal defaults; these
# three names are not referenced elsewhere in the visible notebook.
Z_TOL = 3.0            # show ROIs within ±Z_TOL of selected plane
POINT_SIZE = 6
ALPHA = 0.9

# --- Reference background ---
# Loaded WITHOUT horizontal flip for display; this rebinds `ref_stack`.
ref_stack = asu.read_good_nrrd_uint8(REF_PATH, flip_horizontal=False)
Z_ref, Y_ref, X_ref = ref_stack.shape

# --- Load ROI CSVs ---
# Collect one DataFrame per animal that has a transformed-ROI CSV containing
# all expected columns. Unreadable or incomplete CSVs are reported (not
# silently dropped) so it is clear why an animal is absent from the GUI.
_REQUIRED_ROI_COLUMNS = {
    "animal", "plane", "roi_id",
    "z_anat", "y_anat", "x_anat",
    "z_ref", "y_ref", "x_ref",
}

animals = []
roi_tables = {}
for animal_dir in sorted(p for p in BASE_FOLDER.iterdir() if p.is_dir()):
    animal = animal_dir.name
    csv_p = animal_dir / ROI_OUT_SUBPATH / f"{animal}_rois_transformed.csv"
    if not csv_p.exists():
        continue

    try:
        df = pd.read_csv(csv_p)
    except Exception as exc:  # best effort: one bad file must not block the GUI
        print(f"[WARN] could not read {csv_p}: {exc}")
        continue

    missing_cols = _REQUIRED_ROI_COLUMNS - set(df.columns)
    if missing_cols:
        print(f"[WARN] {csv_p.name}: missing columns {sorted(missing_cols)}; animal skipped")
        continue

    animals.append(animal)
    roi_tables[animal] = df
animals = sorted(animals)

# animal name -> anatomy stack, filled lazily by get_anat_stack
_anat_cache = {}

def get_anat_stack(animal: str):
    """
    Return the (unflipped) anatomy stack of ``animal``, reading it at most once.

    The NRRD is searched as ``{animal}*.nrrd`` under the animal's anatomy
    folder. If several files match and exactly one of them has "gcamp" in its
    name, that one is used; otherwise the first match in sorted order.

    Raises
    ------
    FileNotFoundError
        If no matching NRRD exists.
    """
    try:
        return _anat_cache[animal]
    except KeyError:
        pass  # not loaded yet

    search_dir = BASE_FOLDER / animal / ANAT_SUBPATH
    candidates = sorted(search_dir.glob(f"{animal}*.nrrd"))
    if len(candidates) == 0:
        raise FileNotFoundError(f"No anatomy NRRD for {animal} under {search_dir}")

    gcamp_only = [c for c in candidates if "gcamp" in c.name.lower()]
    use_gcamp = len(candidates) > 1 and len(gcamp_only) == 1
    picked = gcamp_only[0] if use_gcamp else candidates[0]

    stack = asu.read_good_nrrd_uint8(picked, flip_horizontal=False)
    _anat_cache[animal] = stack
    return stack

# --- Widgets ---
# Coordinate space in which ROIs and background are shown.
space_dd = widgets.Dropdown(
    options=[
        ("reference space (ref)", "ref"),
        ("anatomy space (animal)", "anat"),
    ],
    value="ref",
    description="Display space:",
)

# Animals whose ROIs are overlaid; up to the first three are preselected.
animals_ms = widgets.SelectMultiple(
    options=animals,
    value=tuple(animals[:3]),
    description="ROIs from:",
    rows=min(10, len(animals)),
)

# Background image source: the reference stack or one animal's anatomy.
bg_dd = widgets.Dropdown(
    options=[("reference (ref stack)", "__REF__")] + [(a, a) for a in animals],
    value="__REF__",
    description="Background:",
)

# Plane selection and ROI z-window.
plane_slider = widgets.IntSlider(
    value=0, min=0, max=Z_ref - 1, step=1,
    description="Plane (z):", continuous_update=False,
)
z_tol_slider = widgets.FloatSlider(
    value=3.0, min=0, max=10, step=0.5,
    description="z tolerance:", continuous_update=False,
)

# Marker appearance.
pt_sz = widgets.FloatSlider(
    value=8, min=2, max=20, step=1,
    description="Point size:", continuous_update=False,
)
alpha_pts = widgets.FloatSlider(
    value=0.9, min=0.1, max=1, step=0.05,
    description="Alpha:", continuous_update=False,
)

# Mirrors the displayed background image only.
flip_cb = widgets.Checkbox(value=False, description="Flip display (horiz)")

# --- Update bg options ---
def _update_bg(*_):
    """
    Keep the background dropdown and the plane slider consistent with the
    selected display space.

    * reference space: the only background is the reference stack; the
      dropdown is locked and the slider spans the reference planes.
    * anatomy space: the dropdown lists the animals and the slider spans the
      planes of the selected animal's anatomy stack.

    When no animal is available there is nothing to offer in anatomy space, so
    the dropdown falls back to the locked reference entry instead of assigning
    a value that is not among its options (which the widget rejects).
    """
    if space_dd.value == "ref" or not animals:
        bg_dd.options = [("reference (ref stack)", "__REF__")]
        bg_dd.value = "__REF__"
        bg_dd.disabled = True
        plane_slider.max = ref_stack.shape[0] - 1
        return

    bg_dd.options = [(a, a) for a in animals]
    bg_dd.disabled = False
    if bg_dd.value == "__REF__" or bg_dd.value not in animals:
        bg_dd.value = animals[0]
    plane_slider.max = get_anat_stack(bg_dd.value).shape[0] - 1

space_dd.observe(_update_bg, names="value")
bg_dd.observe(_update_bg, names="value")
_update_bg()

# --- Plotting ---
cmap = plt.get_cmap("tab10")

def _color_for(animal):
    """Colour for an animal: its position in `animals` modulo 10 looked up in
    `cmap`, or "red" for names that are not in `animals`."""
    if animal not in animals:
        return "red"
    return cmap(animals.index(animal) % 10)

out = widgets.Output()
def render(*_):
    """
    Redraw the ROI overlay into the `out` widget from the current widget state.

    Shows one plane of the background (reference stack, or the selected
    animal's anatomy stack), contrast-stretched between its 1st and 99th
    percentile, and scatters the ROIs of every selected animal whose z
    coordinate lies within the z tolerance of that plane. The flip checkbox
    mirrors the background image only; ROI coordinates are plotted unchanged.
    """
    with out:
        clear_output(wait=True)
        if space_dd.value == "ref":
            bg = ref_stack
            z_col, y_col, x_col = "z_ref", "y_ref", "x_ref"
        else:
            # No valid animal selected as background (e.g. no ROI tables found)
            if bg_dd.value not in animals:
                print("No anatomy background available for the anatomy display space.")
                return
            try:
                bg = get_anat_stack(bg_dd.value)
            except FileNotFoundError as exc:
                print(f"[WARN] {exc}")
                return
            z_col, y_col, x_col = "z_anat", "y_anat", "x_anat"

        # The slider max is updated by another observer; clamp so a stale
        # slider value can never index past the current background stack.
        z = int(np.clip(int(plane_slider.value), 0, bg.shape[0] - 1))
        ztol = float(z_tol_slider.value)

        # Robust contrast stretch to [0, 1]; fall back to min/max for flat images
        img = bg[z].astype(np.float32)
        p1, p99 = np.percentile(img, (1, 99))
        if p99 <= p1:
            p1, p99 = img.min(), max(img.max(), img.min() + 1e-3)
        img_n = np.clip((img - p1) / (p99 - p1 + 1e-6), 0, 1)

        if flip_cb.value:  # <- Flip display only
            img_n = np.fliplr(img_n)

        H, W = img_n.shape
        fig, ax = plt.subplots(figsize=(6, 6))
        ax.imshow(img_n, cmap="gray", interpolation="nearest")
        ax.set_xlim(0, W); ax.set_ylim(H, 0); ax.set_aspect("equal"); ax.set_axis_off()
        ax.set_title(f"{'Reference' if space_dd.value=='ref' else bg_dd.value} | z={z} | "
                     f"ROIs from {', '.join(animals_ms.value)} | space={space_dd.label}, flip={flip_cb.value}")

        plotted_any = False
        for ani in animals_ms.value:
            if ani not in roi_tables:
                continue
            df = roi_tables[ani]
            rows = df[np.abs(df[z_col] - z) <= ztol]
            if rows.empty:
                continue
            ax.scatter(rows[x_col], rows[y_col], s=pt_sz.value, alpha=alpha_pts.value,
                       c=[_color_for(ani)], label=ani, edgecolors="none")
            plotted_any = True

        # Only ask for a legend when at least one labelled scatter exists;
        # otherwise matplotlib warns about missing labelled artists.
        if plotted_any:
            ax.legend(loc="lower right", fontsize=8, frameon=True)
        plt.show()

# Redraw whenever any control changes, then draw once for the initial state.
redraw_triggers = (
    space_dd, animals_ms, bg_dd, plane_slider,
    z_tol_slider, pt_sz, alpha_pts, flip_cb,
)
for w in redraw_triggers:
    w.observe(render, names="value")
render()

# --- Layout ---
# Top row: space/background/flip/z-window; second row: animal list next to
# the plane and marker sliders; the plot output goes underneath.
top_row = widgets.HBox([space_dd, bg_dd, flip_cb, z_tol_slider])
slider_column = widgets.VBox([plane_slider, pt_sz, alpha_pts])
second_row = widgets.HBox([animals_ms, slider_column])
display(top_row, second_row, out)
