# Two-Channel Image Alignment Tool

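This notebook corrects an x/y offset between two fluorescence channels (e.g., one caused by chromatic aberration or by registration issues). Only a pure translation is applied — rotation and scaling are not corrected.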

**Workflow:**
1. Set your file paths in Cell 1
2. Run Cell 2 to load images and use interactive sliders to find optimal alignment
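3. Run Cell 3 to export the aligned C2 channel (it uses the shift currently set by the sliders, so run it in the same kernel session after Cell 2)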

In [66]:
# =============================================================================
# CELL 1: CONFIGURATION - Edit these paths!
# =============================================================================

# Input TIF files:
#   c1_path -> fixed reference channel, drawn in magenta in the overlay
#   c2_path -> channel that gets shifted onto the reference, drawn in green
c1_path, c2_path = (
    "/Volumes/Expansion/Pleomorphisms/20230720CsSpotAssayLiqMed/Jan26/MB_1x_464_30min_NoWash_002_Cy5.tif",
    "/Volumes/Expansion/Pleomorphisms/20230720CsSpotAssayLiqMed/Jan26/MB_1x_464_30min_NoWash_002_TRITC.tif",
)

# Folder for the aligned C2 file; None means "same folder as the C2 input".
output_dir = None

In [67]:
# =============================================================================
# CELL 2: LOAD & INTERACTIVE ALIGNMENT
# =============================================================================

import numpy as np
import matplotlib.pyplot as plt
from skimage import transform
from skimage.registration import phase_cross_correlation
import tifffile
import os
from pathlib import Path
from ipywidgets import interact, IntSlider, FloatSlider, Checkbox, HBox, VBox, Output
from IPython.display import display, clear_output

# -----------------------------------------------------------------------------
# Load images
# -----------------------------------------------------------------------------
print("Loading image stacks...")
c1_stack = tifffile.imread(c1_path)
c2_stack = tifffile.imread(c2_path)
print(f"C1 stack shape: {c1_stack.shape}, dtype: {c1_stack.dtype}")
print(f"C2 stack shape: {c2_stack.shape}, dtype: {c2_stack.dtype}")

# The overlay, the per-slice shift and the export all assume both channels sit
# on the same (Z, Y, X) or (Y, X) pixel grid. Fail early with a clear message
# instead of a later broadcast/unpack error.
if c1_stack.shape != c2_stack.shape:
    raise ValueError(
        f"C1 and C2 must have the same shape, got {c1_stack.shape} vs {c2_stack.shape}"
    )
if c1_stack.ndim not in (2, 3):
    raise ValueError(
        f"Expected 2D (Y, X) or 3D (Z, Y, X) images, got {c1_stack.ndim}D "
        f"with shape {c1_stack.shape}"
    )

# -----------------------------------------------------------------------------
# Handle 2D vs 3D data
# -----------------------------------------------------------------------------
if c1_stack.ndim == 2:
    # Add a singleton z axis so later cells can always index stack[z]
    c1_stack = c1_stack[np.newaxis, :, :]
    c2_stack = c2_stack[np.newaxis, :, :]
    is_3d = False
    print("Detected 2D images - treating as single z-slice")
else:
    is_3d = True
    print(f"Detected 3D stack with {c1_stack.shape[0]} z-slices")

n_slices, height, width = c1_stack.shape
mid_z = n_slices // 2
original_dtype = c2_stack.dtype

# -----------------------------------------------------------------------------
# Try auto-detection of shift
# -----------------------------------------------------------------------------
# phase_cross_correlation returns the (row, col) shift that must be APPLIED to
# the moving image (C2) to register it onto the reference (C1):
#     registered(p) = moving(p - shift)
# The preview and export cells instead hand AffineTransform(translation=(x, y))
# to transform.warp, which uses it as the inverse (output -> input) map:
#     output(p) = input(p + (x, y))
# i.e. the content moves by -(x, y). The slider values are therefore the
# NEGATED detected shift.
try:
    shift, error, diffphase = phase_cross_correlation(
        c1_stack[mid_z], c2_stack[mid_z], upsample_factor=10
    )
    initial_y = -int(round(shift[0]))
    initial_x = -int(round(shift[1]))
    print(f"Auto-detected shift: x={initial_x}, y={initial_y} pixels")
except Exception as e:
    # Best-effort only: fall back to zero so the manual sliders still work.
    print(f"Could not auto-detect shift: {e}")
    initial_x, initial_y = 0, 0

# -----------------------------------------------------------------------------
# Interactive preview function
# -----------------------------------------------------------------------------
def preview_alignment(x_shift, y_shift, 
                      c1_min=0.0, c1_max=1.0, 
                      c2_min=0.0, c2_max=1.0, c2_gamma=1.0, c2_gain=1.0,
                      z_slice=0, 
                      enable_zoom=False, zoom_x=0, zoom_y=0, zoom_size=300):
    """
    Interactive preview with magenta/green overlay.

    Draws three panels for one z-slice: C1, the shifted C2, and an RGB overlay
    (C1 in red+blue = magenta, C2 in green). Adjustments are display-only;
    export preserves original intensities.

    Parameters
    ----------
    x_shift, y_shift : int
        Translation given to ``transform.AffineTransform``. ``transform.warp``
        uses it as the inverse (output -> input) coordinate map, so the C2
        content moves by (-x_shift, -y_shift) pixels on screen.
    c1_min, c1_max, c2_min, c2_max : float
        Display window applied after normalising each channel to its own
        slice maximum (0-1).
    c2_gamma : float
        C2 display gamma, applied as ``value ** (1 / c2_gamma)``.
    c2_gain : float
        C2 display multiplier applied after gamma; the result is clipped to 1.
    z_slice : int
        Index along the first axis of ``c1_stack`` / ``c2_stack``.
    enable_zoom, zoom_x, zoom_y, zoom_size
        When ``enable_zoom`` is True, show only a ``zoom_size``-wide window
        centred on (zoom_x, zoom_y), clipped to the image bounds.

    Side effects
    ------------
    Stores ``x_shift`` / ``y_shift`` in the globals ``current_x_shift`` /
    ``current_y_shift`` so the export cell can pick them up.
    """
    global current_x_shift, current_y_shift

    # Apply shift to C2; pixels pulled in from outside the image get the
    # constant fill value (mode='constant').
    transform_shift = transform.AffineTransform(translation=(x_shift, y_shift))
    c2_shifted = transform.warp(
        c2_stack[z_slice],
        transform_shift,
        mode='constant',
        preserve_range=True
    ).astype(float)

    # Normalize each channel to 0-1 by its own maximum (leave all-zero slices as-is)
    c1_norm = c1_stack[z_slice].astype(float)
    c1_norm = c1_norm / c1_norm.max() if c1_norm.max() > 0 else c1_norm
    c2_norm = c2_shifted / c2_shifted.max() if c2_shifted.max() > 0 else c2_shifted

    # Apply display adjustments; the 1e-10 avoids dividing by a zero-width window
    c1_adj = np.clip((c1_norm - c1_min) / (c1_max - c1_min + 1e-10), 0, 1)
    c2_adj = np.clip((c2_norm - c2_min) / (c2_max - c2_min + 1e-10), 0, 1)
    c2_adj = np.clip(np.power(c2_adj, 1 / c2_gamma) * c2_gain, 0, 1)

    # RGB overlay: C1 in red + blue (magenta), C2 in green
    rgb = np.zeros((height, width, 3))
    rgb[:, :, 0] = c1_adj
    rgb[:, :, 1] = c2_adj
    rgb[:, :, 2] = c1_adj

    # Optional zoom window, clipped to the image bounds
    if enable_zoom and zoom_size > 0:
        half = zoom_size // 2
        y1, y2 = max(0, zoom_y - half), min(height, zoom_y + half)
        x1, x2 = max(0, zoom_x - half), min(width, zoom_x + half)
        rgb_display = rgb[y1:y2, x1:x2]
        c1_display = c1_adj[y1:y2, x1:x2]
        c2_display = c2_adj[y1:y2, x1:x2]
        title_suffix = f" (Zoom: {x1}-{x2}, {y1}-{y2})"
    else:
        rgb_display = rgb
        c1_display = c1_adj
        c2_display = c2_adj
        title_suffix = ""

    # Plot: build exactly one figure per callback (no stray empty figure).
    fig, axes = plt.subplots(1, 3, figsize=(18, 6))

    axes[0].imshow(c1_display, cmap='magma')
    axes[0].set_title(f'C1 (Reference){title_suffix}')
    axes[0].axis('off')

    axes[1].imshow(c2_display, cmap='viridis')
    axes[1].set_title(f'C2 (Shifted: x={x_shift}, y={y_shift}){title_suffix}')
    axes[1].axis('off')

    axes[2].imshow(rgb_display)
    axes[2].set_title(f'Overlay (Magenta=C1, Green=C2){title_suffix}')
    axes[2].axis('off')

    fig.tight_layout()
    plt.show()

    # Store current shift values globally for export
    current_x_shift = x_shift
    current_y_shift = y_shift

# -----------------------------------------------------------------------------
# Create interactive widget
# -----------------------------------------------------------------------------
print("\n" + "="*60)
print("INTERACTIVE ALIGNMENT")
print("="*60)
print("• Adjust X/Y Shift until channels align in the overlay")
print("• B&C sliders are display-only (won't affect export)")
print("• Enable Zoom to inspect fine details")
print("="*60 + "\n")

# Initialize global shift values (the preview callback overwrites them on
# every slider change; the export cell reads them)
current_x_shift = initial_x
current_y_shift = initial_y

# IntSlider clamps a value outside [min, max] without warning, which would
# silently discard a large auto-detected shift. Keep the default ±100 px range
# but widen it whenever the detected shift lies outside it.
shift_limit = max(100, abs(initial_x), abs(initial_y))

interact(
    preview_alignment,
    x_shift=IntSlider(min=-shift_limit, max=shift_limit, step=1, value=initial_x, description='X Shift:', continuous_update=False),
    y_shift=IntSlider(min=-shift_limit, max=shift_limit, step=1, value=initial_y, description='Y Shift:', continuous_update=False),
    c1_min=FloatSlider(min=0, max=0.5, step=0.01, value=0.0, description='C1 Min:', continuous_update=False),
    c1_max=FloatSlider(min=0.5, max=1.0, step=0.01, value=1.0, description='C1 Max:', continuous_update=False),
    c2_min=FloatSlider(min=0, max=0.5, step=0.01, value=0.0, description='C2 Min:', continuous_update=False),
    c2_max=FloatSlider(min=0.5, max=1.0, step=0.01, value=1.0, description='C2 Max:', continuous_update=False),
    c2_gamma=FloatSlider(min=0.1, max=3.0, step=0.1, value=1.0, description='C2 Gamma:', continuous_update=False),
    c2_gain=FloatSlider(min=1.0, max=10.0, step=0.5, value=1.0, description='C2 Gain:', continuous_update=False),
    z_slice=IntSlider(min=0, max=n_slices-1, step=1, value=mid_z, description='Z-Slice:', continuous_update=False),
    enable_zoom=Checkbox(value=False, description='Enable Zoom'),
    zoom_x=IntSlider(min=0, max=width-1, step=10, value=width//2, description='Zoom X:', continuous_update=False),
    zoom_y=IntSlider(min=0, max=height-1, step=10, value=height//2, description='Zoom Y:', continuous_update=False),
    zoom_size=IntSlider(min=50, max=1000, step=50, value=300, description='Zoom Size:', continuous_update=False)
);

Loading image stacks...
C1 stack shape: (35, 2304, 2304), dtype: uint16
C2 stack shape: (35, 2304, 2304), dtype: uint16
Detected 3D stack with 35 z-slices
Auto-detected shift: x=0, y=0 pixels

INTERACTIVE ALIGNMENT
• Adjust X/Y Shift until channels align in the overlay
• B&C sliders are display-only (won't affect export)
• Enable Zoom to inspect fine details



interactive(children=(IntSlider(value=0, continuous_update=False, description='X Shift:', min=-100), IntSlider…

In [63]:
# =============================================================================
# CELL 3: EXPORT ALIGNED C2
# =============================================================================

# Option 1: Use the current slider values automatically
# (current_x_shift / current_y_shift are set in Cell 2 and updated by the
#  preview sliders, so Cell 2 must have run in this kernel session)
x_shift = current_x_shift
y_shift = current_y_shift

# Option 2: Or manually override here:
# x_shift = 3  
# y_shift = 4

print(f"Exporting with shift: x={x_shift}, y={y_shift}")
print(f"Original dtype: {original_dtype}")

# Same transform and warp settings as the preview, so the export matches the
# overlay. The shift is identical for every slice, so build it once.
tform = transform.AffineTransform(translation=(x_shift, y_shift))


def _warp_to_original_dtype(plane):
    """Shift one 2D plane with `tform` and cast the result back to `original_dtype`.

    warp(..., preserve_range=True) returns floats in the input's value range.
    For integer dtypes the result is rounded to the nearest integer and clipped
    to the dtype's limits before casting; a bare astype() would truncate toward
    zero and bias interpolated values (e.g. from a fractional manual shift) down.
    """
    warped = transform.warp(plane, tform, mode='constant', preserve_range=True)
    if np.issubdtype(original_dtype, np.integer):
        limits = np.iinfo(original_dtype)
        warped = np.clip(np.rint(warped), limits.min, limits.max)
    return warped.astype(original_dtype)


# Apply alignment to full stack
if is_3d:
    aligned_c2 = np.zeros_like(c2_stack)
    for z in range(n_slices):
        aligned_c2[z] = _warp_to_original_dtype(c2_stack[z])
        if (z + 1) % 10 == 0 or z == n_slices - 1:
            print(f"  Processed {z + 1}/{n_slices} slices...")
else:
    # Remove the dummy z dimension for 2D export
    aligned_c2 = _warp_to_original_dtype(c2_stack[0])

# Determine output path (default: next to the C2 input file)
c2_file = Path(c2_path)
if output_dir is None:
    out_folder = c2_file.parent
else:
    out_folder = Path(output_dir)
    out_folder.mkdir(parents=True, exist_ok=True)

out_name = f"{c2_file.stem}_aligned_x{x_shift}_y{y_shift}.tif"
out_path = out_folder / out_name

# Save
tifffile.imwrite(out_path, aligned_c2)
print(f"\n✅ Saved aligned C2 to:\n   {out_path}")
print(f"   Shape: {aligned_c2.shape}, dtype: {aligned_c2.dtype}")

Exporting with shift: x=12, y=6
Original dtype: uint16
  Processed 10/35 slices...
  Processed 20/35 slices...
  Processed 30/35 slices...
  Processed 35/35 slices...

✅ Saved aligned C2 to:
   /Volumes/Expansion/Pleomorphisms/20230720CsSpotAssayLiqMed/Jan26/MB_1x_FA_30min_NoWash_004_TRITC_aligned_x12_y6.tif
   Shape: (35, 2304, 2304), dtype: uint16
