# Confocal central crop FireANTs test

Manual experiment to crop the confocal stack to the two-photon field of view and rerun FireANTs using the existing pipeline parameters.

In [None]:
from __future__ import annotations
import copy
import json
from pathlib import Path

import numpy as np
import tifffile
from skimage.exposure import match_histograms

from social_imaging_scripts.metadata.loader import load_animals
from social_imaging_scripts.metadata.config import load_project_config
from social_imaging_scripts.registration.confocal_to_anatomy import register_confocal_to_anatomy

# --- Experiment identifiers -------------------------------------------------
animal_id = "L331_f01"
conf_session_id = "L331_f01_confocal_round1_2025-05-20"
anat_session_id = "L331_f01_anatomy_2025-05-20"

# --- Project configuration and the three stage sections used below ----------
cfg = load_project_config()
conf_preproc_cfg = cfg.confocal_preprocessing
conf_stage_cfg = cfg.confocal_to_anatomy_registration
anat_stage_cfg = cfg.anatomy_preprocessing

# The repository root is taken to be the parent of the current working
# directory, so this must be run from a direct subdirectory of the repo.
repo_root = Path.cwd().resolve().parent
metadata_dir = repo_root / "metadata" / "animals"

# --- Preprocessing output directories ---------------------------------------
# Confocal outputs are nested per session; anatomy outputs are per animal.
animal_output_dir = Path(cfg.output_base_dir) / animal_id
conf_preproc_dir = animal_output_dir / conf_preproc_cfg.root_subdir / conf_session_id
anat_preproc_dir = animal_output_dir / anat_stage_cfg.root_subdir

# Keyword sets substituted into the configured filename templates.
conf_template_fields = {"session_id": conf_session_id, "animal_id": animal_id}
anat_template_fields = {"animal_id": animal_id, "session_id": anat_session_id}

# --- Confocal preprocessing metadata ----------------------------------------
conf_meta_path = conf_preproc_dir / conf_preproc_cfg.metadata_filename_template.format(
    **conf_template_fields
)
with conf_meta_path.open() as conf_meta_file:
    conf_meta = json.load(conf_meta_file)

# --- Anatomy preprocessing metadata and stack path --------------------------
anat_meta_path = anat_preproc_dir / anat_stage_cfg.metadata_filename_template.format(
    **anat_template_fields
)
with anat_meta_path.open() as anat_meta_file:
    anat_meta = json.load(anat_meta_file)

anat_stack_path = anat_preproc_dir / anat_stage_cfg.stack_filename_template.format(
    **anat_template_fields
)

# Query the shape of the first series of each stack via TiffFile, without
# calling imread on the full volumes here.
with tifffile.TiffFile(conf_meta["channels"]["gcamp"]) as tif:
    conf_shape = tif.series[0].shape  # (z, y, x)
with tifffile.TiffFile(str(anat_stack_path)) as tif:
    anat_shape = tif.series[0].shape

# Voxel/pixel sizes from the preprocessing metadata. The crop computation
# further down indexes both arrays as [0] -> x and [1] -> y.
conf_spacing = np.asarray(conf_meta["voxel_size_um"], dtype=float)
anat_spacing_xy = np.asarray(anat_meta["pixel_size_xy_um"], dtype=float)

# Look up the anatomy session record to obtain its z spacing.
animals = load_animals(base_dir=metadata_dir)
animal = animals.by_id(animal_id)
anat_session = next(
    (s for s in animal.sessions if s.session_id == anat_session_id),
    None,
)
if anat_session is None:
    # A bare next() would surface as an uninformative StopIteration.
    raise LookupError(
        f"Session {anat_session_id!r} not found for animal {animal_id!r}"
    )

# Prefer `plane_spacing`, then `z_step_um`; missing, None or zero values fall
# through to the next candidate.
# NOTE(review): the final 1.0 fallback is silent -- confirm that a default
# z spacing of 1.0 is acceptable when the session metadata carries neither.
z_spacing = (
    getattr(anat_session.session_data, "plane_spacing", None)
    or getattr(anat_session.session_data, "z_step_um", None)
    or 1.0
)

# `existing_flip_z` is what the preprocessing metadata recorded (default
# False); `target_flip_z` is what the current config requests. The cropping
# step flips along z only when the target asks for it and it was not already
# applied.
existing_flip_z = bool(conf_meta.get("flip_z", False))
target_flip_z = bool(conf_preproc_cfg.flip_z)
print(f"Existing flip_z={existing_flip_z}, target flip_z={target_flip_z}")

# In-plane extent covered by the anatomy stack: pixel count times pixel size
# (units follow the `_um` metadata fields).
anat_extent_x_um = anat_shape[2] * anat_spacing_xy[0]
anat_extent_y_um = anat_shape[1] * anat_spacing_xy[1]

# Same extent expressed in confocal voxels, capped at the confocal frame size.
conf_size_y, conf_size_x = conf_shape[1], conf_shape[2]
crop_x_vox = min(int(round(anat_extent_x_um / conf_spacing[0])), conf_size_x)
crop_y_vox = min(int(round(anat_extent_y_um / conf_spacing[1])), conf_size_y)

# Centre the crop window inside the confocal frame and clamp it to the frame.
start_x = max(0, (conf_size_x - crop_x_vox) // 2)
start_y = max(0, (conf_size_y - crop_y_vox) // 2)
end_x = min(conf_size_x, start_x + crop_x_vox)
end_y = min(conf_size_y, start_y + crop_y_vox)

print(f"Confocal shape: {conf_shape}")
print(f"Target crop (y, x): {crop_y_vox} x {crop_x_vox}")
print(f"Crop index ranges: y={start_y}:{end_y}, x={start_x}:{end_x}")

# Cropped stacks are written next to the preprocessed confocal data.
crop_dir = conf_preproc_dir / "central_crop"
crop_dir.mkdir(parents=True, exist_ok=True)

# Flip along z only if the config asks for it and the stored data is not
# already flipped according to its metadata.
needs_z_flip = target_flip_z and not existing_flip_z

cropped_channels: dict[str, Path] = {}
for name, channel_path in conf_meta["channels"].items():
    channel_path = Path(channel_path)
    out_path = crop_dir / f"{name}_central_crop.tif"

    # An existing output is reused as-is; it is not checked against the
    # current crop window or flip setting.
    if out_path.exists():
        cropped_channels[name] = out_path
        continue

    # Memory-map the source stack and pull out the central y/x window.
    data_mm = tifffile.memmap(channel_path, mode="r")
    cropped = np.asarray(data_mm[:, start_y:end_y, start_x:end_x])
    if needs_z_flip:
        cropped = np.flip(cropped, axis=0)
    tifffile.imwrite(out_path, cropped, dtype=cropped.dtype)
    del data_mm
    cropped_channels[name] = out_path

# Report the shape of the cropped gcamp stack as stored on disk.
with tifffile.TiffFile(str(cropped_channels["gcamp"])) as tif:
    cropped_shape = tif.series[0].shape
print(f"Cropped gcamp shape: {cropped_shape}")

# Build a binary (0/1) mask that keeps the central region and zeroes the margins
def build_mask(shape, margin_xy: int, margin_z: int) -> np.ndarray:
    """Return a float32 mask that is 1 inside a central box and 0 in the margins.

    Args:
        shape: Volume shape; the first three axes are indexed as (z, y, x).
        margin_xy: Number of voxels zeroed at each end of the y and x axes.
        margin_z: Number of planes zeroed at each end of the z axis.

    Returns:
        An array of the given shape with dtype float32. If the margins leave
        no voxel inside the box, an all-ones array is returned instead so the
        mask never wipes out the whole volume.
    """
    depth, height, width = shape[0], shape[1], shape[2]
    axis_bounds = ((margin_z, depth), (margin_xy, height), (margin_xy, width))

    # Each axis keeps [margin, size - margin); the stop is clamped to the
    # start so an oversized margin yields an empty slice.
    interior = tuple(
        slice(lo, max(lo, size - lo)) for lo, size in axis_bounds
    )

    mask = np.zeros(shape, dtype=np.float32)
    mask[interior] = 1.0
    if not mask.any():
        return np.ones(shape, dtype=np.float32)
    return mask

def _edge_margin(size, fraction, floor):
    """Margin in voxels: a fraction of the axis size, but at least `floor`."""
    return max(floor, int(fraction * size))


# Confocal mask: margins derived from the cropped stack (y size drives xy).
conf_margin_xy = _edge_margin(cropped_shape[1], 0.08, 10)
conf_margin_z = _edge_margin(cropped_shape[0], 0.05, 2)
conf_mask = build_mask(cropped_shape, conf_margin_xy, conf_margin_z)

# Anatomy mask: the stack is loaded in full because it is reused later on.
anat_arr = tifffile.imread(anat_stack_path)
anat_margin_xy = _edge_margin(anat_shape[1], 0.1, 10)
anat_margin_z = _edge_margin(anat_shape[0], 0.05, 2)
anat_mask = build_mask(anat_shape, anat_margin_xy, anat_margin_z)

mask_dir = crop_dir / "masked"
mask_dir.mkdir(exist_ok=True)

# Apply the confocal mask to every cropped channel and store it as float32.
masked_channels: dict[str, Path] = {}
for name, path in cropped_channels.items():
    masked_path = mask_dir / f"{name}_central_crop_masked.tif"
    vol = tifffile.imread(path).astype(np.float32)
    masked = vol * conf_mask
    tifffile.imwrite(masked_path, masked)
    masked_channels[name] = masked_path

# Masked anatomy stack, used as the fixed image in the registrations below.
anat_masked_path = mask_dir / "anat_central_masked.tif"
anat_masked = (anat_arr * anat_mask).astype(np.float32)
tifffile.imwrite(anat_masked_path, anat_masked)

# Spacings handed to the registration as plain Python floats.
voxel_spacing_um = tuple(map(float, conf_spacing))
fixed_spacing_um = tuple(
    float(v) for v in (anat_spacing_xy[0], anat_spacing_xy[1], z_spacing)
)

# Every masked channel except the reference one (case-insensitive match).
additional_channels = {}
for name, path in masked_channels.items():
    if name.lower() == conf_stage_cfg.reference_channel_name.lower():
        continue
    additional_channels[name] = path

# Outputs of this experiment are kept apart from the regular pipeline results.
output_root = Path(cfg.output_base_dir).joinpath(
    animal_id, "02_reg", "central_crop_test", conf_session_id
)
output_root.mkdir(parents=True, exist_ok=True)

# Registration of the masked (not histogram-matched) stacks, reusing the
# stage configuration; the FireANTs config is deep-copied before being passed.
masked_run_kwargs = {
    "animal_id": animal_id,
    "confocal_session_id": conf_session_id,
    "anatomy_session_id": anat_session_id,
    "moving_channel_path": masked_channels[conf_stage_cfg.reference_channel_name],
    "fixed_stack_path": anat_masked_path,
    "additional_channels": additional_channels,
    "output_root": output_root,
    "config": copy.deepcopy(conf_stage_cfg.fireants),
    "voxel_spacing_um": voxel_spacing_um,
    "fixed_spacing_um": fixed_spacing_um,
    "warped_channel_template": "{animal_id}_{confocal_session_id}_{channel}_central_warped.tif",
    "metadata_filename": "central_crop_registration_metadata.json",
    "transforms_subdir": conf_stage_cfg.transforms_subdir,
    "qc_subdir": conf_stage_cfg.qc_subdir,
    "reference_channel_name": conf_stage_cfg.reference_channel_name,
    "mask_margin_xy": conf_stage_cfg.mask_margin_xy,
    "mask_margin_z": conf_stage_cfg.mask_margin_z,
    "histogram_match": False,
    "histogram_levels": conf_stage_cfg.histogram_levels,
    "histogram_match_points": conf_stage_cfg.histogram_match_points,
    "histogram_threshold_at_mean": conf_stage_cfg.histogram_threshold_at_mean,
    "crop_to_extent": True,
    "crop_padding_um": 0.0,
}
original_result = register_confocal_to_anatomy(**masked_run_kwargs)

print("Original masked registration outputs:")
for key, value in original_result.items():
    print(f"  {key}: {value}")

# Histogram-matching on masked data
histmatch_dir = crop_dir / "histmatch"
histmatch_dir.mkdir(exist_ok=True)

ref_name = conf_stage_cfg.reference_channel_name

# Match each moving reference-channel plane to the masked anatomy plane with
# the same index; indices past the end of the anatomy stack reuse its last
# plane.
# NOTE(review): planes are paired by index, not by physical depth -- confirm
# this is acceptable given the different z spacings of the two stacks.
ref_arr = anat_arr * anat_mask
moving_ref = tifffile.imread(masked_channels[ref_name]).astype(np.float32)
matched = np.empty_like(moving_ref)
last_ref_z = ref_arr.shape[0] - 1
for z, moving_plane in enumerate(moving_ref):
    ref_z = ref_arr[min(z, last_ref_z)]
    matched[z] = match_histograms(moving_plane, ref_z, channel_axis=None)

# The output filename is fixed to "gcamp" regardless of the configured
# reference channel name.
matched_path = histmatch_dir / "gcamp_central_crop_masked_histmatch.tif"
# `matched` is already float32 (it mirrors `moving_ref`), so no cast is needed.
tifffile.imwrite(matched_path, matched)

# Non-reference channels keep their masked paths; only the reference channel
# points at the histogram-matched stack.
matched_channels: dict[str, Path] = {**additional_channels, ref_name: matched_path}

histmatch_output = output_root / "histmatch"
histmatch_output.mkdir(parents=True, exist_ok=True)

# NOTE(review): the moving image was histogram-matched above and
# `histogram_match=True` is also passed here -- confirm that matching twice
# is intended.
matched_result = register_confocal_to_anatomy(
    animal_id=animal_id,
    confocal_session_id=conf_session_id,
    anatomy_session_id=anat_session_id,
    moving_channel_path=matched_channels[ref_name],
    fixed_stack_path=anat_masked_path,
    additional_channels=dict(additional_channels),
    output_root=histmatch_output,
    config=copy.deepcopy(conf_stage_cfg.fireants),
    voxel_spacing_um=voxel_spacing_um,
    fixed_spacing_um=fixed_spacing_um,
    warped_channel_template="{animal_id}_{confocal_session_id}_{channel}_central_warped.tif",
    metadata_filename="central_crop_registration_metadata.json",
    transforms_subdir=conf_stage_cfg.transforms_subdir,
    qc_subdir=conf_stage_cfg.qc_subdir,
    reference_channel_name=ref_name,
    mask_margin_xy=conf_stage_cfg.mask_margin_xy,
    mask_margin_z=conf_stage_cfg.mask_margin_z,
    histogram_match=True,
    histogram_levels=conf_stage_cfg.histogram_levels,
    histogram_match_points=conf_stage_cfg.histogram_match_points,
    histogram_threshold_at_mean=conf_stage_cfg.histogram_threshold_at_mean,
    crop_to_extent=True,
    crop_padding_um=0.0,
)

# The original literal was split across two physical lines (a lost "\n"
# escape), which is a SyntaxError; the escape restores the leading blank line.
print("\nHistogram-matched masked registration outputs:")
for key, value in matched_result.items():
    print(f"  {key}: {value}")