Image alignment using cross-correlation

In [7]:
import numpy as np
from astropy.io import fits
from astropy.stats import sigma_clipped_stats
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
from scipy.fft import fft2, ifft2, fftshift
from scipy.ndimage import shift as apply_shift
import warnings
# Silence every warning category for the rest of the session.
# NOTE(review): presumably aimed at noisy astropy FITS-header warnings — confirm.
# Because no `category` is given, this also hides numpy RuntimeWarnings such as
# divide-by-zero / invalid-value warnings from NaN or flat image statistics;
# consider narrowing it to the specific warning class that is actually noisy.
warnings.filterwarnings('ignore')

def _standardize_for_correlation(img, label):
    """
    Return ``img`` rescaled to zero mean and unit standard deviation.

    Statistics are computed over finite pixels only; non-finite pixels (NaN/inf)
    are set to 0 (the mean), so they contribute nothing to the correlation sum.

    Returns ``(standardized_image, n_finite_pixels)``.
    Raises ValueError if the image has no finite pixels or zero variance.
    """
    img = np.asarray(img, dtype=float)
    finite = np.isfinite(img)
    n_finite = int(np.count_nonzero(finite))
    if n_finite == 0:
        raise ValueError(f"{label} image has no finite pixels")
    values = img[finite]
    std = values.std()
    if not std > 0:
        raise ValueError(f"{label} image has zero variance; cannot cross-correlate")
    out = np.zeros_like(img)
    out[finite] = (values - values.mean()) / std
    return out, n_finite


def _parabolic_vertex(c_minus, c_center, c_plus):
    """Offset of the vertex of the parabola through three equally spaced samples (|offset| <= 0.5 at a maximum)."""
    curvature = c_minus + c_plus - 2.0 * c_center
    if curvature == 0:
        return 0.0
    return 0.5 * (c_minus - c_plus) / curvature


def cross_correlate_images(ref_img, src_img):
    """
    Calculate cross-correlation to find the shift between two images.

    The images are standardized (zero mean, unit std over finite pixels) and
    correlated circularly via FFTs. The integer peak is refined to sub-pixel
    precision with a 3-point parabolic fit, independently along each axis.

    Parameters
    ----------
    ref_img, src_img : 2-D array-like of equal shape
        Reference and source images. NaN/inf pixels are ignored.

    Returns
    -------
    (shift_x, shift_y, correlation_strength)
        shift_x, shift_y : float
            Shift in pixels such that
            ``scipy.ndimage.shift(src_img, [shift_y, shift_x])`` lines the
            source up with the reference.
        correlation_strength : float
            Correlation peak divided by ``sqrt(n_ref * n_src)`` (finite-pixel
            counts). When every pixel is finite this is the Pearson
            correlation at the best integer lag, in [-1, 1].

    Raises
    ------
    ValueError
        If the shapes differ, or either image is constant or has no finite
        pixels (previously this produced NaN output that slipped through the
        caller's quality check).
    """
    ref_norm, n_ref = _standardize_for_correlation(ref_img, "Reference")
    src_norm, n_src = _standardize_for_correlation(src_img, "Source")
    if ref_norm.shape != src_norm.shape:
        raise ValueError(f"Image shapes differ: {ref_norm.shape} vs {src_norm.shape}")

    # Correlation theorem: ifft2(F(ref) * conj(F(src)))[k] = sum_p ref[p + k] * src[p]
    # (circular). fftshift moves lag 0 to index n // 2 along each axis.
    cross_corr = fftshift(ifft2(fft2(ref_norm) * np.conj(fft2(src_norm))).real)

    peak_y, peak_x = np.unravel_index(np.argmax(cross_corr), cross_corr.shape)
    peak_val = cross_corr[peak_y, peak_x]
    n_rows, n_cols = cross_corr.shape

    # Integer lag of the peak relative to the zero-lag centre.
    shift_y = float(peak_y - n_rows // 2)
    shift_x = float(peak_x - n_cols // 2)

    # Cauchy-Schwarz bounds the peak by sqrt(n_ref * n_src), so this is in [-1, 1].
    corr_strength = float(peak_val / np.sqrt(float(n_ref) * float(n_src)))

    # Sub-pixel refinement: each axis is refined whenever its peak has a
    # neighbour on both sides (index 0 < peak < n - 1).
    if 0 < peak_y < n_rows - 1:
        shift_y += _parabolic_vertex(
            cross_corr[peak_y - 1, peak_x], peak_val, cross_corr[peak_y + 1, peak_x])
    if 0 < peak_x < n_cols - 1:
        shift_x += _parabolic_vertex(
            cross_corr[peak_y, peak_x - 1], peak_val, cross_corr[peak_y, peak_x + 1])

    return shift_x, shift_y, corr_strength

def align_image(image, shift_x, shift_y, output_shape=None):
    """
    Shift an image by (shift_x, shift_y) pixels with cubic-spline interpolation.

    Uses ``scipy.ndimage.shift`` (order=3, mode='constant', cval=0.0): pixels
    that come from outside the input are filled with 0.

    Parameters
    ----------
    image : 2-D ndarray
        Image to shift.
    shift_x, shift_y : float
        Shift along columns (x) and rows (y) in pixels; output[y, x] samples
        input[y - shift_y, x - shift_x].
    output_shape : tuple of int, optional
        Shape of the returned array. When it differs from ``image.shape`` the
        shifted image is cropped or zero-padded, anchored at pixel (0, 0).
        Defaults to ``image.shape``.

    Returns
    -------
    ndarray
        The shifted image with shape ``output_shape``.

    Raises
    ------
    ValueError
        If ``output_shape`` has a different number of dimensions than ``image``.
    """
    aligned = apply_shift(image, [shift_y, shift_x], order=3, mode='constant', cval=0.0)
    if output_shape is None:
        return aligned

    target = tuple(int(n) for n in output_shape)
    if target == aligned.shape:
        return aligned
    if len(target) != aligned.ndim:
        raise ValueError(
            f"output_shape {target} has {len(target)} dims; image has {aligned.ndim}")

    # Zero-fill (same value as cval) where the requested shape is larger than
    # the shifted image; crop where it is smaller.
    result = np.zeros(target, dtype=aligned.dtype)
    overlap = tuple(slice(0, min(a, b)) for a, b in zip(aligned.shape, target))
    result[overlap] = aligned[overlap]
    return result

def select_reference_frame(fits_files, method='middle', stride=5):
    """
    Select the index of the reference frame.

    Parameters
    ----------
    fits_files : sequence of path-like
        Frames in acquisition order.
    method : {'first', 'middle', 'brightest'}, optional
        'first'     -> index 0.
        'middle'    -> ``len(fits_files) // 2`` (also used for unknown methods).
        'brightest' -> frame with the largest NaN-ignoring median pixel value
                       among every ``stride``-th file.
    stride : int, optional
        Sampling step for 'brightest' (default 5, as before); values < 1 are
        treated as 1.

    Returns
    -------
    int
        Index into ``fits_files`` (0 when the list is empty).
    """
    n_files = len(fits_files)
    if method == 'first':
        return 0
    if method != 'brightest':
        # 'middle' and any unrecognised method fall back to the middle frame.
        return n_files // 2

    print("\n  Finding brightest frame...")
    step = max(1, int(stride))
    best_idx = 0
    # Start at -inf so frames with zero or negative medians (e.g. sky-subtracted
    # data) can still be selected.
    best_median = -np.inf

    for idx in range(0, n_files, step):
        with fits.open(fits_files[idx]) as hdul:
            data = hdul[0].data.astype(float)
        # nanmedian so a few NaN pixels do not turn the whole statistic into NaN.
        median_val = np.nanmedian(data)
        if median_val > best_median:
            best_median = median_val
            best_idx = idx

    return best_idx

def create_shift_visualization(shifts_x, shifts_y, correlations, filenames, output_path, filter_name,
                               corr_threshold=0.7, frame_numbers=None):
    """
    Create a 2x2 figure of alignment diagnostics and save it to ``output_path``.

    Panels: drift trajectory (X vs Y shift), X/Y shift per frame, total shift
    magnitude per frame, and correlation strength per frame.

    Parameters
    ----------
    shifts_x, shifts_y : sequence of float
        Per-frame shifts in pixels. NaN entries (failed frames) are drawn as gaps.
    correlations : sequence of float
        Per-frame correlation strength.
    filenames : sequence of str
        Per-frame file names (not drawn; kept for interface compatibility).
    output_path : path-like
        Destination of the PNG.
    filter_name : str
        Used in the trajectory panel title.
    corr_threshold : float, optional
        Level of the dashed reference line in the correlation panel. The
        default 0.7 matches the rejection threshold used in ``align_images``.
    frame_numbers : sequence of int, optional
        X-axis frame numbers; defaults to 1..len(shifts_x).
    """
    shifts_x = np.asarray(shifts_x, dtype=float)
    shifts_y = np.asarray(shifts_y, dtype=float)
    correlations = np.asarray(correlations, dtype=float)
    if frame_numbers is None:
        frame_numbers = np.arange(1, len(shifts_x) + 1)
    total_shifts = np.hypot(shifts_x, shifts_y)

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # Plot 1: Shift trajectory
    ax1 = axes[0, 0]
    ax1.plot(shifts_x, shifts_y, 'bo-', markersize=6, linewidth=1.5, alpha=0.7)
    ax1.plot(0, 0, 'r*', markersize=20, label='Reference frame')
    ax1.set_xlabel('X shift (pixels)', fontsize=11)
    ax1.set_ylabel('Y shift (pixels)', fontsize=11)
    ax1.set_title(f'Drift Pattern - Filter {filter_name.upper()}', fontsize=12, fontweight='bold')
    ax1.grid(True, alpha=0.3)
    ax1.legend()
    ax1.axis('equal')

    # Plot 2: Shifts over time
    ax2 = axes[0, 1]
    ax2.plot(frame_numbers, shifts_x, 'b.-', label='X shift', markersize=5, linewidth=1.5)
    ax2.plot(frame_numbers, shifts_y, 'r.-', label='Y shift', markersize=5, linewidth=1.5)
    ax2.set_xlabel('Frame Number', fontsize=11)
    ax2.set_ylabel('Shift (pixels)', fontsize=11)
    ax2.set_title('Shifts Over Time', fontsize=12, fontweight='bold')
    ax2.legend()
    ax2.grid(True, alpha=0.3)

    # Plot 3: Total shift magnitude
    ax3 = axes[1, 0]
    ax3.plot(frame_numbers, total_shifts, 'g.-', markersize=5, linewidth=1.5)
    ax3.set_xlabel('Frame Number', fontsize=11)
    ax3.set_ylabel('Total Shift (pixels)', fontsize=11)
    ax3.set_title('Total Drift Magnitude', fontsize=12, fontweight='bold')
    ax3.grid(True, alpha=0.3)

    # Plot 4: Correlation strength, with the actual rejection threshold marked
    ax4 = axes[1, 1]
    ax4.plot(frame_numbers, correlations, 'mo-', markersize=5, linewidth=1.5)
    ax4.set_xlabel('Frame Number', fontsize=11)
    ax4.set_ylabel('Correlation Strength', fontsize=11)
    ax4.set_title('Alignment Quality', fontsize=12, fontweight='bold')
    ax4.grid(True, alpha=0.3)
    ax4.axhline(corr_threshold, color='r', linestyle='--', alpha=0.5,
                label=f'{corr_threshold} threshold')
    ax4.legend()

    fig.tight_layout()
    fig.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

def _report_alignment(filter_name, n_files, successful, failed,
                      shifts_x, shifts_y, correlations):
    """Print shift/correlation statistics and the success/failure summary for one filter."""
    shifts_x = np.asarray(shifts_x, dtype=float)
    shifts_y = np.asarray(shifts_y, dtype=float)
    correlations = np.asarray(correlations, dtype=float)
    valid = np.isfinite(shifts_x) & np.isfinite(shifts_y)
    rule = '=' * 50

    if valid.any():
        total_shifts = np.hypot(shifts_x[valid], shifts_y[valid])
        valid_corrs = correlations[valid]
        print(f"\n  {rule}")
        print("  Alignment Statistics:")
        print(f"  - Mean shift: {np.mean(total_shifts):.2f} ± {np.std(total_shifts):.2f} pixels")
        print(f"  - Max shift: {np.max(total_shifts):.2f} pixels")
        print(f"  - Mean correlation: {np.mean(valid_corrs):.4f}")
        print(f"  - Min correlation: {np.min(valid_corrs):.4f}")
        print(f"  {rule}")

    print(f"\n  {rule}")
    print(f"  Filter {filter_name.upper()} Summary:")
    print(f"  Successfully aligned: {successful}/{n_files}")
    if failed:
        print(f"  Failed: {len(failed)}")
        for fname, error in failed[:5]:
            print(f"    - {fname}: {error}")
        if len(failed) > 5:
            print(f"    ... and {len(failed) - 5} more")
    print(f"  {rule}")


def align_images(filter_name, light_dir, output_dir, min_correlation=0.7):
    """
    Align all FITS frames of one filter onto a reference frame using cross-correlation.

    Frames are read from ``light_dir / filter_name / *.fits``. The middle frame
    (via ``select_reference_frame``) is the reference and is written unchanged.
    Each other frame's shift is measured with ``cross_correlate_images`` and
    applied with ``align_image``. Results are written as ``<stem>_aligned.fits``
    (header copied from the input) under ``output_dir / filter_name``, together
    with a shift-diagnostics PNG.

    Parameters
    ----------
    filter_name : str
        Sub-directory name of the filter (e.g. 'r').
    light_dir, output_dir : pathlib.Path
        Input and output root directories.
    min_correlation : float, optional
        Frames whose correlation strength is below this value, or is not
        finite, are not written and are reported as failed. Default 0.7.

    Returns
    -------
    None. Progress, statistics and failures are printed.
    """
    print(f"\n{'='*60}")
    print(f"Processing filter: {filter_name.upper()}")
    print(f"{'='*60}")

    filter_path = light_dir / filter_name
    fits_files = sorted(filter_path.glob('*.fits'))
    if not fits_files:
        print(f"  WARNING: No FITS files found in {filter_path}")
        return
    print(f"  Found {len(fits_files)} FITS files")

    output_filter_dir = output_dir / filter_name
    output_filter_dir.mkdir(parents=True, exist_ok=True)

    # Middle frame: likely to have median tracking within the sequence.
    ref_idx = select_reference_frame(fits_files, method='middle')
    ref_file = fits_files[ref_idx]
    print(f"  Using reference frame: {ref_file.name} (frame {ref_idx+1}/{len(fits_files)})")

    with fits.open(ref_file) as hdul:
        ref_data = hdul[0].data.astype(float)
        ref_header = hdul[0].header
    print(f"  Reference image shape: {ref_data.shape}")

    # The reference frame is written unchanged.
    fits.writeto(output_filter_dir / f"{ref_file.stem}_aligned.fits",
                 ref_data, ref_header, overwrite=True)

    # One entry per input frame (NaN for failures) so list position == frame index.
    shifts_x, shifts_y, correlations, filenames = [], [], [], []
    successful = 0
    failed = []

    print(f"\n  Aligning images using cross-correlation...")

    for idx, fits_file in enumerate(tqdm(fits_files, desc=f"  {filter_name}")):
        filenames.append(fits_file.name)

        if idx == ref_idx:
            # Reference frame - no shift needed
            shifts_x.append(0.0)
            shifts_y.append(0.0)
            correlations.append(1.0)
            successful += 1
            continue

        try:
            with fits.open(fits_file) as hdul:
                src_data = hdul[0].data.astype(float)
                src_header = hdul[0].header

            shift_x, shift_y, corr = cross_correlate_images(ref_data, src_data)

            # `corr < threshold` is False for NaN, so reject non-finite values explicitly.
            if not np.isfinite(corr) or corr < min_correlation:
                raise ValueError(f"Low correlation: {corr:.3f} (threshold: {min_correlation})")
            if not (np.isfinite(shift_x) and np.isfinite(shift_y)):
                raise ValueError(f"Non-finite shift: ({shift_x}, {shift_y})")

            aligned_data = align_image(src_data, shift_x, shift_y, ref_data.shape)
            fits.writeto(output_filter_dir / f"{fits_file.stem}_aligned.fits",
                         aligned_data, src_header, overwrite=True)
        except Exception as e:
            # Per-frame boundary: record the failure and continue with the next frame.
            failed.append((fits_file.name, str(e)))
            shift_x = shift_y = corr = np.nan
        else:
            successful += 1

        shifts_x.append(shift_x)
        shifts_y.append(shift_y)
        correlations.append(corr)

    print(f"\n  Creating alignment visualization...")
    viz_path = output_filter_dir / f"alignment_shifts_{filter_name}.png"
    shifts_x_arr = np.asarray(shifts_x, dtype=float)
    if np.isfinite(shifts_x_arr).any():
        # Pass the full per-frame arrays (NaN for failed frames): matplotlib
        # leaves gaps at NaN, and the plotted frame numbers stay equal to the
        # true frame index instead of being renumbered after a failure.
        create_shift_visualization(
            shifts_x_arr,
            np.asarray(shifts_y, dtype=float),
            np.asarray(correlations, dtype=float),
            np.array(filenames),
            viz_path,
            filter_name,
        )
        print(f"  Saved visualization: {viz_path}")

    _report_alignment(filter_name, len(fits_files), successful, failed,
                      shifts_x, shifts_y, correlations)

def main():
    """
    Run cross-correlation alignment for each filter sub-directory.

    Reads calibrated frames from
    ``<base>/Image_reduction_workspace/LIGHT/<filter>``. Writes aligned frames
    and per-filter diagnostic plots to
    ``<base>/Image_reduction_workspace/LIGHT_aligned_cross_correlation/<filter>``.
    Missing filter directories are reported and skipped.
    """
    workspace = Path("/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19") / "Image_reduction_workspace"
    light_dir = workspace / "LIGHT"
    output_dir = workspace / "LIGHT_aligned_cross_correlation"

    # Nothing to do without the calibrated input tree.
    if not light_dir.exists():
        print(f"ERROR: Input directory not found: {light_dir}")
        return

    rule = "=" * 60
    print(rule)
    print("CROSS-CORRELATION IMAGE ALIGNMENT")
    print(rule)
    print(f"Input directory: {light_dir}")
    print(f"Output directory: {output_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)

    for filter_name in ('clear', 'r', 'g', 'i'):
        filter_path = light_dir / filter_name
        if not filter_path.exists():
            print(f"\nWARNING: Filter directory not found: {filter_path}")
            continue
        align_images(filter_name, light_dir, output_dir)

    print(f"\n{rule}")
    print("ALIGNMENT COMPLETE!")
    print(f"Aligned images saved to: {output_dir}")
    print("Visualization plots saved in each filter subdirectory")
    print(rule)

# Entry point: run the alignment pipeline when this code is executed directly
# (a notebook kernel also runs cells with __name__ == "__main__").
if __name__ == "__main__":
    main()

CROSS-CORRELATION IMAGE ALIGNMENT
Input directory: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT
Output directory: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned_cross_correlation

Processing filter: CLEAR
  Found 62 FITS files
  Using reference frame: NGC1365-0032_2x2_60s_calibrated.fits (frame 32/62)
  Reference image shape: (2048, 2048)

  Aligning images using cross-correlation...


  clear: 100%|██████████████████████████████████| 62/62 [00:38<00:00,  1.62it/s]



  Creating alignment visualization...
  Saved visualization: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned_cross_correlation/clear/alignment_shifts_clear.png

  Alignment Statistics:
  - Mean shift: 3.03 ± 1.87 pixels
  - Max shift: 7.60 pixels
  - Mean correlation: 0.9705
  - Min correlation: 0.9442

  Filter CLEAR Summary:
  Successfully aligned: 62/62

Processing filter: R
  Found 62 FITS files
  Using reference frame: NGC1365-0032_2x2_r_60s_calibrated.fits (frame 32/62)
  Reference image shape: (2048, 2048)

  Aligning images using cross-correlation...


  r: 100%|██████████████████████████████████████| 62/62 [00:38<00:00,  1.60it/s]



  Creating alignment visualization...
  Saved visualization: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned_cross_correlation/r/alignment_shifts_r.png

  Alignment Statistics:
  - Mean shift: 5.17 ± 2.90 pixels
  - Max shift: 10.13 pixels
  - Mean correlation: 0.9553
  - Min correlation: 0.9094

  Filter R Summary:
  Successfully aligned: 62/62

Processing filter: G
  Found 62 FITS files
  Using reference frame: NGC1365-0032_2x2_g_60s_calibrated.fits (frame 32/62)
  Reference image shape: (2048, 2048)

  Aligning images using cross-correlation...


  g: 100%|██████████████████████████████████████| 62/62 [00:38<00:00,  1.60it/s]



  Creating alignment visualization...
  Saved visualization: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned_cross_correlation/g/alignment_shifts_g.png

  Alignment Statistics:
  - Mean shift: 4.38 ± 2.65 pixels
  - Max shift: 9.88 pixels
  - Mean correlation: 0.8955
  - Min correlation: 0.7026

  Filter G Summary:
  Successfully aligned: 61/62
  Failed: 1
    - NGC1365-0020_2x2_g_60s_calibrated.fits: Low correlation: 0.692 (threshold: 0.7)

Processing filter: I
  Found 62 FITS files
  Using reference frame: NGC1365-0032_2x2_i_60s_calibrated.fits (frame 32/62)
  Reference image shape: (2048, 2048)

  Aligning images using cross-correlation...


  i: 100%|██████████████████████████████████████| 62/62 [00:38<00:00,  1.63it/s]



  Creating alignment visualization...
  Saved visualization: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned_cross_correlation/i/alignment_shifts_i.png

  Alignment Statistics:
  - Mean shift: 3.50 ± 2.06 pixels
  - Max shift: 8.06 pixels
  - Mean correlation: 0.9464
  - Min correlation: 0.9087

  Filter I Summary:
  Successfully aligned: 62/62

ALIGNMENT COMPLETE!
Aligned images saved to: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned_cross_correlation
Visualization plots saved in each filter subdirectory


In [8]:
import numpy as np
from astropy.io import fits
from astropy.stats import sigma_clipped_stats
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from pathlib import Path

def load_and_normalize(fits_path):
    """
    Read the primary HDU of a FITS file and derive a display stretch for it.

    Returns
    -------
    (data, vmin, vmax)
        data : the primary-HDU image converted to float.
        vmin, vmax : median - 2*std and median + 15*std of the 3-sigma-clipped
            pixel statistics, suitable as imshow limits.
    """
    with fits.open(fits_path) as hdul:
        image = hdul[0].data.astype(float)

    # Only the clipped median and standard deviation drive the stretch.
    _, clipped_median, clipped_std = sigma_clipped_stats(image, sigma=3.0)
    lower = clipped_median - 2 * clipped_std
    upper = clipped_median + 15 * clipped_std
    return image, lower, upper

def create_blink_comparison(raw_files, aligned_files, output_path, filter_name, n_frames=4):
    """
    Save a two-column figure of raw (left) vs aligned (right) frames.

    Up to ``n_frames`` frames are drawn, evenly spaced from the first to the
    last. Each row uses one shared display stretch (the wider of the two
    frames' limits). ``raw_files[i]`` and ``aligned_files[i]`` are treated as
    the same frame.
    """
    total = len(raw_files)
    if total > n_frames:
        picks = np.linspace(0, total - 1, n_frames, dtype=int)
    else:
        n_frames = total
        picks = range(total)

    fig, axes = plt.subplots(n_frames, 2, figsize=(16, 4 * n_frames))
    if n_frames == 1:
        # A single row comes back as a 1-D array; make it 2-D so axes[row, col] works.
        axes = axes.reshape(1, -1)

    print(f"\n  Loading frames for visualization...")

    for row, frame_idx in enumerate(picks):
        raw_path = raw_files[frame_idx]
        aligned_path = aligned_files[frame_idx]

        raw_img, raw_lo, raw_hi = load_and_normalize(raw_path)
        aln_img, aln_lo, aln_hi = load_and_normalize(aligned_path)

        # A shared stretch keeps brightness differences from being stretch artefacts.
        display = dict(cmap='gray', origin='lower',
                       vmin=min(raw_lo, aln_lo), vmax=max(raw_hi, aln_hi))

        panels = (
            (axes[row, 0], raw_img, f'RAW - Frame {frame_idx+1}: {raw_path.name}'),
            (axes[row, 1], aln_img, f'ALIGNED - Frame {frame_idx+1}: {aligned_path.name}'),
        )
        for ax, img, title in panels:
            ax.imshow(img, **display)
            ax.set_title(title, fontsize=10)
            ax.axis('off')

        print(f"    Frame {frame_idx+1}/{total}: {raw_path.name}")

    plt.suptitle(f'Raw vs Aligned Frames - Filter {filter_name.upper()}',
                 fontsize=14, fontweight='bold', y=0.995)
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"  Saved: {output_path}")

def _crop_window(center, size, length):
    """Return ``(start, end)`` of a ``size``-pixel window around ``center`` kept inside ``[0, length)``."""
    length = int(length)
    size = max(1, min(int(size), length))
    start = int(center) - size // 2
    start = min(max(start, 0), length - size)
    return start, start + size


def create_detail_comparison(raw_files, aligned_files, output_path, filter_name, crop_size=300):
    """
    Create a zoomed-in comparison showing alignment detail.

    For the first, middle and last frame (or every frame when fewer than three
    are given), one row of three panels is drawn:
      1. the full raw frame with a red box marking the zoom region,
      2. the full aligned frame with the same box,
      3. the raw and aligned crops side by side.
    The crop is centred on the maximum of a sigma=10 Gaussian-smoothed copy of
    the raw frame, and the same pixel region is cut from both images.

    Parameters
    ----------
    raw_files, aligned_files : sequence of pathlib.Path
        Paired FITS paths; element i of each list must be the same frame.
    output_path : path-like
        Destination of the PNG.
    filter_name : str
        Used in the figure title.
    crop_size : int, optional
        Side length of the zoom box in pixels. Near an image edge the box is
        moved inward rather than truncated.
    """
    from scipy.ndimage import gaussian_filter

    n_files = len(raw_files)
    if n_files == 0:
        print("\n  No frames available; skipping detailed comparison.")
        return
    indices = [0, n_files // 2, n_files - 1] if n_files >= 3 else list(range(n_files))
    n_rows = len(indices)

    plt.figure(figsize=(18, 6 * n_rows))

    print(f"\n  Creating detailed comparison...")

    for plot_idx, idx in enumerate(indices):
        raw_file = raw_files[idx]
        aligned_file = aligned_files[idx]

        raw_data, _, _ = load_and_normalize(raw_file)
        aligned_data, _, _ = load_and_normalize(aligned_file)

        # Locate the brightest extended region (e.g. the galaxy) on a heavily
        # smoothed copy. Non-finite pixels are zero-filled first: gaussian_filter
        # would spread NaN, and np.argmax would then return the first NaN.
        smoothed = gaussian_filter(
            np.nan_to_num(raw_data, nan=0.0, posinf=0.0, neginf=0.0), sigma=10)
        center_y, center_x = np.unravel_index(np.argmax(smoothed), smoothed.shape)

        y_start, y_end = _crop_window(center_y, crop_size, raw_data.shape[0])
        x_start, x_end = _crop_window(center_x, crop_size, raw_data.shape[1])

        raw_crop = raw_data[y_start:y_end, x_start:x_end]
        aligned_crop = aligned_data[y_start:y_end, x_start:x_end]

        # Display stretch for the zoom panel, derived from the raw crop.
        _, median, std = sigma_clipped_stats(raw_crop, sigma=3.0)
        vmin = median - 2 * std
        vmax = median + 15 * std

        base_idx = plot_idx * 3

        # Full raw image with crop box (nanpercentile: NaN pixels would make the limits NaN).
        ax1 = plt.subplot(n_rows, 3, base_idx + 1)
        ax1.imshow(raw_data, cmap='gray', origin='lower',
                   vmin=np.nanpercentile(raw_data, 1), vmax=np.nanpercentile(raw_data, 99))
        ax1.add_patch(Rectangle((x_start, y_start), x_end - x_start, y_end - y_start,
                                linewidth=2, edgecolor='red', facecolor='none'))
        ax1.set_title(f'RAW Full Frame {idx+1}\n(Red box = zoom region)', fontsize=10)
        ax1.axis('off')

        # Full aligned image with crop box
        ax2 = plt.subplot(n_rows, 3, base_idx + 2)
        ax2.imshow(aligned_data, cmap='gray', origin='lower',
                   vmin=np.nanpercentile(aligned_data, 1), vmax=np.nanpercentile(aligned_data, 99))
        ax2.add_patch(Rectangle((x_start, y_start), x_end - x_start, y_end - y_start,
                                linewidth=2, edgecolor='red', facecolor='none'))
        ax2.set_title(f'ALIGNED Full Frame {idx+1}\n(Red box = zoom region)', fontsize=10)
        ax2.axis('off')

        # Zoomed comparison: raw crop | aligned crop
        ax3 = plt.subplot(n_rows, 3, base_idx + 3)
        combined = np.hstack([raw_crop, aligned_crop])
        crop_w = raw_crop.shape[1]
        ax3.imshow(combined, cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
        ax3.axvline(crop_w, color='yellow', linewidth=2, linestyle='--')
        ax3.text(crop_w // 2, 10, 'RAW', color='yellow',
                 fontsize=12, ha='center', weight='bold',
                 bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))
        ax3.text(crop_w + crop_w // 2, 10, 'ALIGNED',
                 color='yellow', fontsize=12, ha='center', weight='bold',
                 bbox=dict(boxstyle='round', facecolor='black', alpha=0.7))
        ax3.set_title(f'Zoomed Detail Comparison - Frame {idx+1}', fontsize=10)
        ax3.axis('off')

        print(f"    Frame {idx+1}: Crop center at ({center_x}, {center_y})")

    plt.suptitle(f'Detailed Alignment Verification - Filter {filter_name.upper()}',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close()

    print(f"  Saved: {output_path}")

def create_animated_blink(raw_files, aligned_files, output_path, filter_name, n_frames=10):
    """
    Save a grid of evenly spaced aligned frames for visual "blinking".

    Despite the name this writes one static PNG: 5 columns and as many rows as
    needed. Flip between neighbouring panels to spot residual misalignment.

    Parameters
    ----------
    raw_files : sequence
        Only its length is used, to choose which frame indices to show.
    aligned_files : sequence of path-like
        Aligned FITS frames; ``aligned_files[i]`` is drawn for each chosen index.
    output_path : path-like
        Destination of the PNG.
    filter_name : str
        Used in the figure title.
    n_frames : int, optional
        Maximum number of panels (default 10).
    """
    n_total = len(raw_files)
    if n_total == 0 or n_frames < 1:
        print("\n  No frames available; skipping blink sequence.")
        return

    if n_total > n_frames:
        indices = np.linspace(0, n_total - 1, n_frames, dtype=int)
    else:
        indices = range(n_total)
        n_frames = n_total

    n_cols = 5
    n_rows = (n_frames + n_cols - 1) // n_cols

    # squeeze=False always returns a 2-D array of Axes, so ravel() yields a flat
    # array for every grid size. The previous `[axes]` fallback wrapped an
    # ndarray when n_frames == 1, and `.imshow` then failed.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(20, 4 * n_rows), squeeze=False)
    axes = axes.ravel()

    print(f"\n  Creating blink sequence with {n_frames} frames...")

    for ax, idx in zip(axes, indices):
        aligned_data, vmin, vmax = load_and_normalize(aligned_files[idx])
        ax.imshow(aligned_data, cmap='gray', origin='lower', vmin=vmin, vmax=vmax)
        ax.set_title(f'Frame {idx+1}', fontsize=9)
        ax.axis('off')

    # Hide unused subplots in the last row.
    for ax in axes[n_frames:]:
        ax.axis('off')

    plt.suptitle(f'Aligned Frame Sequence - Filter {filter_name.upper()}\n(Blink through these to check alignment)',
                 fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    plt.close(fig)

    print(f"  Saved: {output_path}")

def _pair_raw_and_aligned(raw_files, aligned_files):
    """Return (raw, aligned) lists paired by the '<raw stem>_aligned.fits' naming convention."""
    aligned_by_name = {p.name: p for p in aligned_files}
    paired_raw, paired_aligned = [], []
    for raw in raw_files:
        match = aligned_by_name.get(f"{raw.stem}_aligned.fits")
        if match is not None:
            paired_raw.append(raw)
            paired_aligned.append(match)

    if not paired_raw and len(raw_files) == len(aligned_files):
        # Names do not follow the convention but the counts agree: fall back
        # to pairing the two sorted lists by position.
        print("  NOTE: aligned names do not follow '<raw>_aligned.fits'; pairing by sorted order")
        return list(raw_files), list(aligned_files)

    n_unmatched = len(raw_files) - len(paired_raw)
    if n_unmatched:
        print(f"  NOTE: {n_unmatched} raw frame(s) have no aligned counterpart and are skipped")
    return paired_raw, paired_aligned


def verify_alignment(filter_name, raw_dir, aligned_dir, output_dir):
    """
    Create the three verification figures for one filter.

    Raw frames (``raw_dir/<filter>/*.fits``) are matched to aligned frames
    (``aligned_dir/<filter>/*_aligned.fits``) by file name, not list position.
    When alignment rejected a frame (for example a low-correlation frame), the
    positional pairing shifted every later pair or raised IndexError. Three
    PNGs are then written to ``output_dir/<filter>``: side-by-side, zoomed
    detail and blink grid.
    """
    print(f"\n{'='*60}")
    print(f"Creating verification plots for filter: {filter_name.upper()}")
    print(f"{'='*60}")

    raw_path = raw_dir / filter_name
    aligned_path = aligned_dir / filter_name

    raw_files = sorted(raw_path.glob('*.fits'))
    aligned_files = sorted(aligned_path.glob('*_aligned.fits'))

    if len(raw_files) == 0 or len(aligned_files) == 0:
        print(f"  WARNING: No files found!")
        return

    print(f"  Found {len(raw_files)} raw files")
    print(f"  Found {len(aligned_files)} aligned files")

    raw_files, aligned_files = _pair_raw_and_aligned(raw_files, aligned_files)
    if not raw_files:
        print("  WARNING: Could not match raw frames to aligned frames; skipping.")
        return

    filter_output = output_dir / filter_name
    filter_output.mkdir(parents=True, exist_ok=True)

    # 1. Side-by-side comparison
    print(f"\n  [1/3] Creating side-by-side comparison...")
    sidebyside_path = filter_output / f"verification_sidebyside_{filter_name}.png"
    create_blink_comparison(raw_files, aligned_files, sidebyside_path, filter_name, n_frames=4)

    # 2. Detailed zoom comparison
    print(f"\n  [2/3] Creating detailed zoom comparison...")
    detail_path = filter_output / f"verification_detail_{filter_name}.png"
    create_detail_comparison(raw_files, aligned_files, detail_path, filter_name, crop_size=400)

    # 3. Blink sequence
    print(f"\n  [3/3] Creating blink sequence...")
    blink_path = filter_output / f"verification_blink_{filter_name}.png"
    create_animated_blink(raw_files, aligned_files, blink_path, filter_name, n_frames=10)

    print(f"\n  ✓ Verification plots complete for {filter_name.upper()}")

def main():
    """
    Build raw-vs-aligned verification figures for each filter.

    Inputs:  <base>/Image_reduction_workspace/LIGHT/<filter>/*.fits
             <base>/Image_reduction_workspace/LIGHT_aligned/<filter>/*_aligned.fits
    Output:  <base>/Image_reduction_workspace/alignment_verification_visual/<filter>/
    A filter is skipped when either of its input directories is missing.
    """
    workspace = Path("/home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19") / "Image_reduction_workspace"
    raw_dir = workspace / "LIGHT"
    # NOTE(review): the cross-correlation cell writes to
    # "LIGHT_aligned_cross_correlation", but this reads "LIGHT_aligned".
    # Confirm which alignment run is meant to be verified.
    aligned_dir = workspace / "LIGHT_aligned"
    output_dir = workspace / "alignment_verification_visual"

    if not raw_dir.exists():
        print(f"ERROR: Raw directory not found: {raw_dir}")
        return
    if not aligned_dir.exists():
        print(f"ERROR: Aligned directory not found: {aligned_dir}")
        return

    rule = "=" * 60
    print(rule)
    print("ALIGNMENT VERIFICATION VISUALIZATION")
    print(rule)
    print(f"Raw images: {raw_dir}")
    print(f"Aligned images: {aligned_dir}")
    print(f"Output: {output_dir}")

    output_dir.mkdir(parents=True, exist_ok=True)

    for filter_name in ('clear', 'r', 'g', 'i'):
        has_raw = (raw_dir / filter_name).exists()
        has_aligned = (aligned_dir / filter_name).exists()
        if has_raw and has_aligned:
            verify_alignment(filter_name, raw_dir, aligned_dir, output_dir)
        else:
            print(f"\nWARNING: Skipping {filter_name} - directories not found")

    print(f"\n{rule}")
    print("VERIFICATION VISUALIZATION COMPLETE!")
    print(f"All plots saved to: {output_dir}")
    print(rule)
    print("\nGenerated plots:")
    for line in (
        "  1. verification_sidebyside_*.png - Raw vs Aligned comparison",
        "  2. verification_detail_*.png - Zoomed detail comparison",
        "  3. verification_blink_*.png - Sequence for blinking through",
    ):
        print(line)

# Entry point: build the verification figures when this code is executed
# directly (a notebook kernel also runs cells with __name__ == "__main__").
if __name__ == "__main__":
    main()

ALIGNMENT VERIFICATION VISUALIZATION
Raw images: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT
Aligned images: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/LIGHT_aligned
Output: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/alignment_verification_visual

Creating verification plots for filter: CLEAR
  Found 62 raw files
  Found 62 aligned files

  [1/3] Creating side-by-side comparison...

  Loading frames for visualization...
    Frame 1/62: NGC1365-0001_2x2_60s_calibrated.fits
    Frame 21/62: NGC1365-0021_2x2_60s_calibrated.fits
    Frame 41/62: NGC1365-0041_2x2_60s_calibrated.fits
    Frame 62/62: NGC1365-0062_2x2_60s_calibrated.fits
  Saved: /home/devika/PhD/S2/Obs_Astronomy/Image_Processing/Ckoirama_2025-12-19/Image_reduction_workspace/alignment_verification_visual/clear/verification_sidebyside_clear.png

  [2/3] Creating 