# Compare Prematch Approaches: Legacy vs Modern

This notebook compares two approaches for handling manual prematch in confocal→anatomy registration:

1. **Legacy Mode**: Pre-apply rotation with scipy, then hard crop, seed translation only into FireANTs
2. **Modern Mode**: Seed full prematch (rotation + translation) into FireANTs, no pre-rotation

We test on L395_f06's second confocal session (59° rotation) and compare:
- Registration loss values (reported in the registration logs; this notebook does not load them)
- Transform matrices
- Visual quality of warped outputs

In [None]:
from pathlib import Path
import json
import numpy as np
import tifffile
import matplotlib.pyplot as plt
from datetime import datetime

from social_imaging_scripts.metadata.config import StageMode, load_project_config
from social_imaging_scripts.pipeline import run_pipeline

cfg = load_project_config()  # project-wide settings (output roots, registration params)

# --- Test parameters -------------------------------------------------------
ANIMAL_ID = "L395_f06"
# 0-based index into this animal's confocal sessions: 1 selects the second
# session, the one with the ~59° rotation this comparison is built around.
CONFOCAL_SESSION_INDEX = 1

# Echo the run parameters so the notebook output is self-describing.
print(f"Testing animal: {ANIMAL_ID}")
print(f"Confocal session index: {CONFOCAL_SESSION_INDEX}")
print(f"Timestamp: {datetime.now().isoformat()}")

## Setup Output Directories

We'll save outputs to separate folders for comparison.

In [None]:
# Every artefact of this comparison goes under one dedicated root, with one
# sub-folder per mode, so the two runs never overwrite each other.
comparison_root = Path(cfg.output_base_dir) / "_comparison_prematch_approaches"
legacy_output, modern_output = (
    comparison_root / "legacy_mode",
    comparison_root / "modern_mode",
)

for _mode_dir in (legacy_output, modern_output):
    _mode_dir.mkdir(parents=True, exist_ok=True)

print(f"Comparison root: {comparison_root}")
print(f"Legacy output: {legacy_output}")
print(f"Modern output: {modern_output}")

## Run Legacy Mode

**Direct registration call**: this bypasses the full pipeline to save time. It reuses the existing preprocessed data and runs only the confocal→anatomy registration step.

In [None]:
print("="*80)
print("RUNNING LEGACY MODE (hard crop approach)")
print("="*80)

import warnings

from social_imaging_scripts.registration.confocal_to_anatomy import register_confocal_to_anatomy
from social_imaging_scripts.metadata.loader import load_animal_file

# ---------------------------------------------------------------------------
# 1. Locate and load the animal metadata YAML
# ---------------------------------------------------------------------------
# Search the working directory and each of its ancestors for
# `metadata/animals`. This works from exampleNotebooks/, from the repo root,
# or from any other checkout location. The previously hard-coded absolute
# path is kept only as a last-resort fallback.
_FALLBACK_METADATA_DIR = Path("/home/jlarsch/social_imaging_scripts/metadata/animals")
notebook_dir = Path.cwd()
metadata_dir = next(
    (
        root / "metadata" / "animals"
        for root in (notebook_dir, *notebook_dir.parents)
        if (root / "metadata" / "animals").is_dir()
    ),
    _FALLBACK_METADATA_DIR,
)
print(f"Metadata dir: {metadata_dir}")

animal_yaml = metadata_dir / f"{ANIMAL_ID}.yaml"
if not animal_yaml.exists():
    raise FileNotFoundError(f"Animal metadata not found: {animal_yaml}")

animal_meta = load_animal_file(animal_yaml)
print(f"Loaded metadata for: {animal_meta.animal_id}")

# Keep only anatomy stacks acquired on the confocal.
confocal_sessions = [
    s for s in animal_meta.sessions
    if s.session_type == "anatomy_stack" and getattr(s.session_data, "stack_type", "") == "confocal"
]
if not confocal_sessions:
    raise ValueError(f"No confocal sessions found for {ANIMAL_ID}")
print(f"Found {len(confocal_sessions)} confocal session(s)")

# CONFOCAL_SESSION_INDEX is documented as 0-based; reject negatives too, which
# plain list indexing would otherwise silently wrap to the end of the list.
if not 0 <= CONFOCAL_SESSION_INDEX < len(confocal_sessions):
    raise IndexError(
        f"Confocal session index {CONFOCAL_SESSION_INDEX} out of range "
        f"(found {len(confocal_sessions)} session(s))"
    )

confocal_session = confocal_sessions[CONFOCAL_SESSION_INDEX]
confocal_session_id = confocal_session.session_id
print(f"Confocal session ID: {confocal_session_id}")

# ---------------------------------------------------------------------------
# 2. Paths to existing preprocessed data
# ---------------------------------------------------------------------------
animal_output_base = Path(cfg.output_base_dir) / ANIMAL_ID
anatomy_dir = animal_output_base / cfg.anatomy_preprocessing.root_subdir
anatomy_path = anatomy_dir / f"{ANIMAL_ID}_anatomy_stack.tif"
confocal_dir = animal_output_base / cfg.confocal_preprocessing.root_subdir / confocal_session_id

print(f"Animal output base: {animal_output_base}")
print(f"Anatomy path: {anatomy_path}")
print(f"Confocal dir: {confocal_dir}")

if not anatomy_path.exists():
    raise FileNotFoundError(f"Anatomy file not found: {anatomy_path}")
if not confocal_dir.exists():
    raise FileNotFoundError(f"Confocal directory not found: {confocal_dir}")

# ---------------------------------------------------------------------------
# 3. Voxel spacing for moving (confocal) and fixed (anatomy) stacks
# ---------------------------------------------------------------------------
confocal_metadata_path = confocal_dir / f"{confocal_session_id}_confocal_metadata.json"
if not confocal_metadata_path.exists():
    raise FileNotFoundError(f"Confocal metadata not found: {confocal_metadata_path}")
with open(confocal_metadata_path, encoding="utf-8") as f:
    confocal_meta = json.load(f)

voxel_spacing_um = tuple(confocal_meta["voxel_size_um"])
print(f"Confocal voxel spacing: {voxel_spacing_um} µm")

anatomy_metadata_path = anatomy_dir / f"{ANIMAL_ID}_anatomy_metadata.json"
with open(anatomy_metadata_path, encoding="utf-8") as f:
    anatomy_meta = json.load(f)

# The anatomy preprocessing metadata only carries an XY pixel size
# (`pixel_size_xy_um`), so the Z step has to be supplied separately.
pixel_size_xy = anatomy_meta["pixel_size_xy_um"]
pixel_xy = pixel_size_xy[0]
if len(pixel_size_xy) > 1 and not np.isclose(pixel_size_xy[0], pixel_size_xy[1]):
    # The fixed spacing below uses one value for both Y and X; make any
    # anisotropy visible instead of silently dropping it.
    warnings.warn(
        f"Anatomy XY pixel sizes differ ({pixel_size_xy}); "
        f"using {pixel_xy} µm for both Y and X."
    )

# NOTE(review): 3.0 µm is an assumed 2p Z step, not a value read from data.
# Confirm it against the raw acquisition metadata for this animal. A wrong
# value distorts the fixed volume along Z and biases the registration.
z_spacing = 3.0
warnings.warn(
    f"Anatomy Z spacing is not read from metadata; assuming {z_spacing} µm "
    f"for {ANIMAL_ID}. Verify before trusting the registration."
)
fixed_spacing_um = (z_spacing, pixel_xy, pixel_xy)  # (Z, Y, X) order
print(f"Anatomy voxel spacing: {fixed_spacing_um} µm (Z, Y, X)")

# ---------------------------------------------------------------------------
# 4. Reference (moving) channel and additional channels to warp
# ---------------------------------------------------------------------------
reg_settings = cfg.confocal_to_anatomy_registration

moving_channel_path = confocal_dir / f"channel_{reg_settings.reference_channel_name}.tif"
if not moving_channel_path.exists():
    raise FileNotFoundError(f"Reference channel not found: {moving_channel_path}")

# Sorted so the channel order (and therefore output order) is reproducible.
# The glob guarantees the "channel_" prefix, so strip exactly that prefix
# rather than every occurrence of the substring.
_CHANNEL_PREFIX = "channel_"
all_channels = sorted(confocal_dir.glob(f"{_CHANNEL_PREFIX}*.tif"))
additional_channels = {
    ch.stem[len(_CHANNEL_PREFIX):]: ch
    for ch in all_channels
    if ch != moving_channel_path
}

print(f"Reference channel: {moving_channel_path.name}")
print(f"Additional channels: {list(additional_channels.keys())}")
print("✓ Found preprocessed data")

# ---------------------------------------------------------------------------
# 5. Run the registration with legacy rotation + hard-crop handling
# ---------------------------------------------------------------------------
legacy_reg_output = legacy_output / "02_reg" / f"05_{confocal_session_id}_to_anatomy"
legacy_reg_output.mkdir(parents=True, exist_ok=True)
print(f"Legacy registration output: {legacy_reg_output}")

result_legacy = register_confocal_to_anatomy(
    animal_id=ANIMAL_ID,
    confocal_session_id=confocal_session_id,
    anatomy_session_id="anatomy",
    moving_channel_path=moving_channel_path,
    fixed_stack_path=anatomy_path,
    additional_channels=additional_channels,
    output_root=legacy_reg_output,
    config=reg_settings.fireants,
    voxel_spacing_um=voxel_spacing_um,
    fixed_spacing_um=fixed_spacing_um,
    prematch_settings=None,  # left to register_confocal_to_anatomy (original note: loaded from processing log)
    warped_channel_template=reg_settings.warped_channel_template,
    metadata_filename=reg_settings.metadata_filename_template,
    transforms_subdir=reg_settings.transforms_subdir,
    qc_subdir=reg_settings.qc_subdir,
    reference_channel_name=reg_settings.reference_channel_name,
    mask_margin_xy=reg_settings.mask_margin_xy,
    mask_margin_z=reg_settings.mask_margin_z,
    mask_soft_edges=reg_settings.mask_soft_edges,
    histogram_match=reg_settings.histogram_match,
    histogram_levels=reg_settings.histogram_levels,
    histogram_match_points=reg_settings.histogram_match_points,
    histogram_threshold_at_mean=reg_settings.histogram_threshold_at_mean,
    initial_translation_mode=reg_settings.initial_translation_mode,
    crop_to_extent=reg_settings.crop_to_extent,
    crop_padding_um=reg_settings.crop_padding_um,
    output_base_dir=Path(cfg.output_base_dir),
    processing_log_config=cfg.processing_log,
    use_legacy_rotation_cropping=True,  # ENABLE LEGACY MODE
)

print("\n" + "="*80)
print("LEGACY MODE COMPLETE")
print("="*80)
print(f"Output saved to: {legacy_reg_output}")

## Run Modern Mode

**Direct registration call**: same preprocessed inputs as legacy mode. Only the registration approach differs: the full prematch is seeded into FireANTs, with mask-only handling and no hard crop.

In [None]:
print("="*80)
print("RUNNING MODERN MODE (mask-only approach, no hard crop)")
print("="*80)

# Same inputs as the legacy run (session ID, channel paths, voxel spacings from
# the previous cell); only the output folder and the legacy flag change.
modern_reg_output = modern_output / "02_reg" / f"05_{confocal_session_id}_to_anatomy"
modern_reg_output.mkdir(parents=True, exist_ok=True)

print(f"Modern registration output: {modern_reg_output}")

reg_settings = cfg.confocal_to_anatomy_registration

modern_kwargs = dict(
    animal_id=ANIMAL_ID,
    confocal_session_id=confocal_session_id,
    anatomy_session_id="anatomy",
    moving_channel_path=moving_channel_path,
    fixed_stack_path=anatomy_path,
    additional_channels=additional_channels,
    output_root=modern_reg_output,
    config=reg_settings.fireants,
    voxel_spacing_um=voxel_spacing_um,
    fixed_spacing_um=fixed_spacing_um,
    prematch_settings=None,  # left to register_confocal_to_anatomy (original note: loaded from processing log)
    warped_channel_template=reg_settings.warped_channel_template,
    metadata_filename=reg_settings.metadata_filename_template,
    transforms_subdir=reg_settings.transforms_subdir,
    qc_subdir=reg_settings.qc_subdir,
    reference_channel_name=reg_settings.reference_channel_name,
    mask_margin_xy=reg_settings.mask_margin_xy,
    mask_margin_z=reg_settings.mask_margin_z,
    mask_soft_edges=reg_settings.mask_soft_edges,
    histogram_match=reg_settings.histogram_match,
    histogram_levels=reg_settings.histogram_levels,
    histogram_match_points=reg_settings.histogram_match_points,
    histogram_threshold_at_mean=reg_settings.histogram_threshold_at_mean,
    initial_translation_mode=reg_settings.initial_translation_mode,
    crop_to_extent=reg_settings.crop_to_extent,
    crop_padding_um=reg_settings.crop_padding_um,
    output_base_dir=Path(cfg.output_base_dir),
    processing_log_config=cfg.processing_log,
    use_legacy_rotation_cropping=False,  # modern approach: no pre-rotation / hard crop
)
result_modern = register_confocal_to_anatomy(**modern_kwargs)

print("\n" + "="*80)
print("MODERN MODE COMPLETE")
print("="*80)
print(f"Output saved to: {modern_reg_output}")

## Load and Compare Registration Metadata

In [None]:
def load_registration_metadata(reg_dir: Path):
    """Load the registration metadata JSON from a registration output folder.

    Files matching ``*registration_metadata.json`` are looked for directly
    inside ``reg_dir`` first. Only if none are found there is the whole tree
    below ``reg_dir`` searched. Matches are sorted, so the same file is picked
    on every run even when several exist.

    Args:
        reg_dir: Registration output directory (the ``output_root`` that was
            passed to the registration call). A ``str`` is accepted as well.

    Returns:
        Tuple ``(metadata, reg_dir)``: the parsed JSON object and the searched
        directory as a ``Path``.

    Raises:
        FileNotFoundError: If no matching metadata file exists under ``reg_dir``.
    """
    reg_dir = Path(reg_dir)
    pattern = "*registration_metadata.json"

    # glob() order is filesystem-dependent; sort for a deterministic choice.
    metadata_files = sorted(reg_dir.glob(pattern)) or sorted(reg_dir.rglob(pattern))
    if not metadata_files:
        raise FileNotFoundError(f"No metadata JSON found in {reg_dir}")

    metadata_path = metadata_files[0]
    if len(metadata_files) > 1:
        print(f"Warning: {len(metadata_files)} metadata files match in {reg_dir}; using the first (sorted).")
    # For files directly in reg_dir this prints just the file name.
    print(f"Loading metadata from: {metadata_path.relative_to(reg_dir)}")

    with open(metadata_path, encoding="utf-8") as f:
        metadata = json.load(f)

    return metadata, reg_dir

# Read the metadata JSON written by each registration run, legacy first.
_metadata_runs = {}
for _prefix, _label, _reg_dir in (
    ("", "LEGACY", legacy_reg_output),
    ("\n", "MODERN", modern_reg_output),
):
    print(f"{_prefix}Loading {_label} metadata...")
    _metadata_runs[_label] = load_registration_metadata(_reg_dir)

metadata_legacy, session_dir_legacy = _metadata_runs["LEGACY"]
metadata_modern, session_dir_modern = _metadata_runs["MODERN"]

print("\n✓ Metadata loaded successfully")

## Compare Prematch Settings

In [None]:
print("="*80)
print("PREMATCH COMPARISON")
print("="*80)

# The "prematch" section of each run's metadata; an empty dict when absent.
prematch_legacy = metadata_legacy.get("prematch", {})
prematch_modern = metadata_modern.get("prematch", {})

# Print rotation / translation / note side by side, "N/A" where missing.
for _header, _prematch in (
    ("\nLEGACY MODE:", prematch_legacy),
    ("\nMODERN MODE:", prematch_modern),
):
    print(_header)
    _result = _prematch.get('result', {})
    print(f"  Rotation: {_result.get('rotation_deg', 'N/A')}°")
    print(f"  Translation: {_result.get('translation_um', 'N/A')} µm")
    print(f"  Note: {_prematch.get('note', 'N/A')}")

## Compare Transform Matrices

In [None]:
print("="*80)
print("TRANSFORM MATRIX COMPARISON")
print("="*80)

# Load the prematch affine matrices. `or []` also maps an explicit JSON null
# to an empty array. np.array(None) would instead give a size-1 object array
# that passes the `.size > 0` check below and then fails on arithmetic.
affine_legacy = np.array(prematch_legacy.get("affine_matrix") or [])
affine_modern = np.array(prematch_modern.get("affine_matrix") or [])

print("\nLEGACY MODE Prematch Affine Matrix:")
print(affine_legacy)

print("\nMODERN MODE Prematch Affine Matrix:")
print(affine_modern)

if affine_legacy.size > 0 and affine_modern.size > 0:
    if affine_legacy.shape != affine_modern.shape:
        # An element-wise difference is meaningless (and would raise) here.
        print(
            f"\nCannot diff matrices: shapes differ "
            f"(legacy {affine_legacy.shape}, modern {affine_modern.shape})"
        )
    else:
        print("\nDifference (Modern - Legacy):")
        diff = affine_modern - affine_legacy
        print(diff)
        print(f"\nMax absolute difference: {np.max(np.abs(diff)):.6f}")

    # Rotation analysis needs homogeneous 4x4 matrices for BOTH modes.
    if affine_legacy.shape == (4, 4) and affine_modern.shape == (4, 4):
        print("\n" + "="*80)
        print("ROTATION COMPONENT ANALYSIS")
        print("="*80)

        # In-plane angle from the upper-left 2x2 of the linear part.
        # NOTE(review): this assumes matrix axes 0/1 are the image's in-plane
        # (X/Y) axes. If the stored affine is in (Z, Y, X) order, the [1:3, 1:3]
        # sub-block is the in-plane one instead. TODO: confirm the convention.
        rot_legacy = affine_legacy[:3, :3]
        angle_legacy = np.degrees(np.arctan2(rot_legacy[1, 0], rot_legacy[0, 0]))
        print(f"\nLegacy rotation angle: {angle_legacy:.2f}° (should be ~0° since rotation pre-applied)")

        rot_modern = affine_modern[:3, :3]
        angle_modern = np.degrees(np.arctan2(rot_modern[1, 0], rot_modern[0, 0]))
        print(f"Modern rotation angle: {angle_modern:.2f}° (should be ~59° from prematch)")

        # Wrap into [-180, 180) so e.g. 179° vs -179° reports ±2°, not ±358°.
        angle_delta = (angle_modern - angle_legacy + 180.0) % 360.0 - 180.0
        print(f"\nRotation difference: {angle_delta:.2f}°")

## Load and Compare Warped Outputs

In [None]:
def load_warped_channel(session_dir: Path, channel_name: str = "gcamp"):
    """Load one warped channel TIFF from a registration output directory.

    Files matching ``*<channel_name>_warped.tif`` are looked for directly
    inside ``session_dir`` first. Only if none are found there is the whole
    tree below it searched. Matches are sorted, so the same file is chosen on
    every run.

    Args:
        session_dir: Registration output directory to search (``str`` accepted).
        channel_name: Channel label embedded in the warped file name.

    Returns:
        Tuple ``(volume, path)``: the array returned by ``tifffile.imread``
        and the path it was read from.

    Raises:
        FileNotFoundError: If no matching warped file exists under ``session_dir``.
    """
    session_dir = Path(session_dir)
    pattern = f"*{channel_name}_warped.tif"

    # glob() order is filesystem-dependent; sort for a deterministic choice.
    warped_files = sorted(session_dir.glob(pattern)) or sorted(session_dir.rglob(pattern))
    if not warped_files:
        raise FileNotFoundError(f"No warped {channel_name} file found in {session_dir}")

    warped_path = warped_files[0]
    if len(warped_files) > 1:
        print(f"Warning: {len(warped_files)} warped {channel_name} files match in {session_dir}; using the first (sorted).")
    # For files directly in session_dir this prints just the file name.
    print(f"Loading: {warped_path.relative_to(session_dir)}")

    warped = tifffile.imread(warped_path)
    print(f"  Shape: {warped.shape}")
    print(f"  Dtype: {warped.dtype}")
    print(f"  Range: [{warped.min():.2f}, {warped.max():.2f}]")

    return warped, warped_path

# Read the warped reference channel produced by each run, legacy first.
_warped_runs = {}
for _prefix, _label, _session_dir in (
    ("", "LEGACY", session_dir_legacy),
    ("\n", "MODERN", session_dir_modern),
):
    print(f"{_prefix}Loading {_label} warped output...")
    _warped_runs[_label] = load_warped_channel(_session_dir)

warped_legacy, path_legacy = _warped_runs["LEGACY"]
warped_modern, path_modern = _warped_runs["MODERN"]

print("\n✓ Warped outputs loaded")

## Visual Comparison: Middle Slices

In [None]:
# The difference panel subtracts the two volumes voxel by voxel, which only
# works if both runs produced arrays of the same shape. Fail with a clear
# message instead of an IndexError / broadcasting error further down.
if warped_legacy.shape != warped_modern.shape:
    raise ValueError(
        f"Warped volumes differ in shape: legacy {warped_legacy.shape} "
        f"vs modern {warped_modern.shape}"
    )

# Select middle Z slice (axis 0 is Z, matching the (Z, Y, X) convention above)
z_mid = warped_legacy.shape[0] // 2
slice_legacy = warped_legacy[z_mid]
slice_modern = warped_modern[z_mid]

fig, axes = plt.subplots(1, 3, figsize=(18, 6))

# Legacy / modern panels: grey scale clipped at the 99.5th percentile so a
# few very bright voxels don't wash out the rest of the slice.
for ax, img, label in (
    (axes[0], slice_legacy, "Legacy Mode"),
    (axes[1], slice_modern, "Modern Mode"),
):
    ax.imshow(img, cmap='gray', vmin=0, vmax=np.percentile(img, 99.5))
    ax.set_title(f'{label} (Z={z_mid})', fontsize=14, fontweight='bold')
    ax.axis('off')

# Difference panel: colour range symmetric around 0 (99th percentile of |diff|).
diff_slice = slice_modern.astype(float) - slice_legacy.astype(float)
vmax_diff = float(np.percentile(np.abs(diff_slice), 99))
if vmax_diff == 0:
    # Identical slices would give a degenerate vmin == vmax colour scale.
    vmax_diff = 1.0
im = axes[2].imshow(diff_slice, cmap='RdBu_r', vmin=-vmax_diff, vmax=vmax_diff)
axes[2].set_title('Difference (Modern - Legacy)', fontsize=14, fontweight='bold')
axes[2].axis('off')
plt.colorbar(im, ax=axes[2], fraction=0.046, pad=0.04)

plt.tight_layout()
figure_path = comparison_root / 'slice_comparison.png'
plt.savefig(figure_path, dpi=150, bbox_inches='tight')
plt.show()

print(f"\nSaved comparison figure to: {figure_path}")

## Quantitative Comparison

In [None]:
print("="*80)
print("QUANTITATIVE COMPARISON")
print("="*80)

# Voxel-wise metrics need identical shapes.
if warped_legacy.shape != warped_modern.shape:
    raise ValueError(
        f"Warped volumes differ in shape: legacy {warped_legacy.shape} "
        f"vs modern {warped_modern.shape}"
    )

# Recomputed here so this cell does not depend on the plotting cell.
z_mid = warped_legacy.shape[0] // 2

# Compute differences
diff_vol = warped_modern.astype(float) - warped_legacy.astype(float)
abs_diff = np.abs(diff_vol)

print(f"\nVolume shape: {warped_legacy.shape}")
print("\nIntensity Statistics:")
print(f"  Legacy mean: {warped_legacy.mean():.2f}")
print(f"  Modern mean: {warped_modern.mean():.2f}")
print("\nDifference Statistics:")
print(f"  Mean absolute difference: {abs_diff.mean():.4f}")
print(f"  Max absolute difference: {abs_diff.max():.4f}")
print(f"  RMS difference: {np.sqrt((diff_vol**2).mean()):.4f}")

# Pearson correlation is undefined when either volume is constant. Report NaN
# explicitly rather than letting np.corrcoef emit a divide warning.
flat_legacy = warped_legacy.ravel()
flat_modern = warped_modern.ravel()
if flat_legacy.std() == 0 or flat_modern.std() == 0:
    correlation = float("nan")
    print("\nPearson correlation: undefined (at least one volume is constant)")
else:
    correlation = np.corrcoef(flat_legacy, flat_modern)[0, 1]
    print(f"\nPearson correlation: {correlation:.6f}")

# Structural similarity (if skimage is installed), on the middle Z slice
try:
    from skimage.metrics import structural_similarity as ssim
except ImportError:
    print("\n(skimage not available for SSIM computation)")
else:
    ssim_legacy = warped_legacy[z_mid].astype(float)
    ssim_modern = warped_modern[z_mid].astype(float)
    # data_range must span BOTH images. Taking it from the legacy slice alone
    # misstates SSIM whenever the modern slice covers a different range.
    data_range = max(ssim_legacy.max(), ssim_modern.max()) - min(ssim_legacy.min(), ssim_modern.min())
    if data_range == 0:
        print(f"\nStructural Similarity (SSIM) on Z={z_mid}: undefined (both slices constant)")
    else:
        ssim_val = ssim(ssim_legacy, ssim_modern, data_range=data_range)
        print(f"\nStructural Similarity (SSIM) on Z={z_mid}: {ssim_val:.6f}")

## Summary Report

In [None]:
def _mode_summary(output_dir, prematch, affine, warped):
    """Build the per-mode section of the comparison summary.

    Contains the output folder, the prematch rotation / translation / note,
    the prematch affine (None when empty) and the mean warped intensity.
    """
    prematch_result = prematch.get('result', {})
    return {
        "output_dir": str(output_dir),
        "prematch_rotation_deg": prematch_result.get('rotation_deg'),
        "prematch_translation_um": prematch_result.get('translation_um'),
        "prematch_note": prematch.get('note'),
        "affine_matrix": affine.tolist() if affine.size > 0 else None,
        "warped_mean_intensity": float(warped.mean()),
    }


_abs_diff_vol = np.abs(diff_vol)  # reused for the mean and max below

summary = {
    "animal_id": ANIMAL_ID,
    "confocal_session_index": CONFOCAL_SESSION_INDEX,
    "timestamp": datetime.now().isoformat(),
    "legacy_mode": _mode_summary(session_dir_legacy, prematch_legacy, affine_legacy, warped_legacy),
    "modern_mode": _mode_summary(session_dir_modern, prematch_modern, affine_modern, warped_modern),
    "comparison": {
        "mean_absolute_difference": float(_abs_diff_vol.mean()),
        "max_absolute_difference": float(_abs_diff_vol.max()),
        "rms_difference": float(np.sqrt((diff_vol**2).mean())),
        "pearson_correlation": float(correlation),
    },
}

# Persist the summary next to the comparison figure, then echo it.
summary_path = comparison_root / "comparison_summary.json"
with open(summary_path, 'w') as f:
    json.dump(summary, f, indent=2)

print("="*80)
print("SUMMARY REPORT")
print("="*80)
print(json.dumps(summary, indent=2))
print(f"\n✓ Summary saved to: {summary_path}")

## Conclusion

What to check:

1. **Transform chain**: Modern mode includes the rotation in the affine matrix; legacy mode pre-applies it to the image instead.
2. **Registration quality**: Compare loss values from the registration logs with the visual quality above.
3. **Output consistency**: Check the correlation and difference metrics.

The modern approach is preferred because:
- The transform chain is complete (the rotation is included in the FireANTs output).
- Additional channels are warped with that same complete transform, so they should align correctly.
- It is more flexible for downstream applications.