# Supernova Visual Inspection Notebook

Quick visual inspection of downloaded reference/science image pairs to:
1. Verify data quality
2. Identify same-mission pairs suitable for differencing
3. Visually spot supernovae in science images
4. Assess alignment/WCS quality

**Data source:** `output/fits_training/` (organized from MAST downloads)


In [1]:
# Setup: make the project root importable so later `from src...` imports resolve.
import sys
from pathlib import Path

project_root = Path.cwd().parent
_project_root_str = str(project_root)

# Re-running this cell must not keep prepending duplicates to sys.path:
# drop any existing copies, then put the root first so it wins import resolution.
sys.path[:] = [p for p in sys.path if p != _project_root_str]
sys.path.insert(0, _project_root_str)

print(f"Project root: {project_root}")


Project root: /mnt/astrid/AstrID


In [2]:
# Imports: standard library first, then third-party.
import json
import warnings
from collections import defaultdict
from typing import Optional, Tuple  # used by find_matching_filter_pair's signature

import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
from astropy.visualization import AsinhStretch, ImageNormalize, ZScaleInterval
from astropy.wcs import WCS
from matplotlib.colors import LogNorm, SymLogNorm

# Silence FITS header-verification warnings so they don't flood cell output.
warnings.filterwarnings('ignore', category=fits.verify.VerifyWarning)

print("Imports complete")


Imports complete


In [3]:
# Configuration
TRAINING_DIR = project_root / "output" / "fits_training"
MANIFEST_FILE = TRAINING_DIR / "training_manifest.json"

# Load manifest. JSON text is UTF-8, so pin the encoding instead of relying on
# the platform's default encoding.
with open(MANIFEST_FILE, encoding="utf-8") as f:
    manifest = json.load(f)

# Later cells iterate the manifest as a list of per-SN dicts.
assert isinstance(manifest, list), (
    f"Expected a JSON list in {MANIFEST_FILE}, got {type(manifest).__name__}"
)

print(f"Training directory: {TRAINING_DIR}")
print(f"Total SNe in manifest: {len(manifest)}")


Training directory: /mnt/astrid/AstrID/output/fits_training
Total SNe in manifest: 223


## 1. Inventory: Same-Mission Pairs

For meaningful image differencing, we need reference and science images from the **same mission** (ideally same filter too).


In [4]:
def get_mission_from_filename(filename: str) -> str:
    """Return the mission tag encoded as the filename prefix.

    Only the basename is inspected (directories are stripped). A name starting
    with ``SWIFT_``, ``PS1_``, ``GALEX_`` or ``TESS_`` maps to that mission;
    anything else maps to ``"UNKNOWN"``.
    """
    basename = Path(filename).name
    known_missions = ("SWIFT", "PS1", "GALEX", "TESS")
    return next(
        (mission for mission in known_missions if basename.startswith(mission + "_")),
        "UNKNOWN",
    )

def get_swift_filter(filename: str) -> str:
    """Return the SWIFT UVOT filter code found in the file's basename.

    The lower-cased basename is searched for each known code in a fixed order
    and the first code contained in it wins; ``"unknown"`` if none match.
    """
    lowered = Path(filename).name.lower()
    candidates = ('uvw2', 'uvm2', 'uvw1', 'uuu', 'ubb', 'uvv', 'uwh')
    hits = [code for code in candidates if code in lowered]
    return hits[0] if hits else "unknown"

def analyze_sn_missions(entry: dict) -> dict:
    """Group one manifest entry's files by mission and find the shared missions.

    Args:
        entry: Manifest entry. Reads ``'sn_name'`` (required) and the optional
            ``'reference_files'`` / ``'science_files'`` lists of file paths.

    Returns:
        dict with ``sn_name``; ``ref_missions`` / ``sci_missions`` (mission ->
        list of files); ``common_missions`` (sorted list of missions present in
        both); and ``has_same_mission_pair`` (True if any mission is shared).
    """
    def _group_by_mission(files):
        # Map mission tag -> files, preserving each file's original order.
        grouped = defaultdict(list)
        for f in files:
            grouped[get_mission_from_filename(f)].append(f)
        return dict(grouped)

    ref_missions = _group_by_mission(entry.get('reference_files', []))
    sci_missions = _group_by_mission(entry.get('science_files', []))

    # Sorted so printed/serialized output is stable across kernel restarts:
    # iteration order of a set of str depends on per-process hash randomization.
    # NOTE(review): 'UNKNOWN' files on both sides also count as a "same mission"
    # pair here; confirm whether that is intended.
    common_missions = sorted(set(ref_missions) & set(sci_missions))

    return {
        'sn_name': entry['sn_name'],
        'ref_missions': ref_missions,
        'sci_missions': sci_missions,
        'common_missions': common_missions,
        'has_same_mission_pair': len(common_missions) > 0,
    }

# Analyze every SN in the manifest, then split by whether a same-mission pair exists
analyses = list(map(analyze_sn_missions, manifest))

same_mission_sne, cross_mission_sne = [], []
for a in analyses:
    bucket = same_mission_sne if a['has_same_mission_pair'] else cross_mission_sne
    bucket.append(a)

banner = "=" * 60
print(banner)
print("MISSION ANALYSIS")
print(banner)

print(f"\n‚úÖ Same-mission pairs (usable for differencing): {len(same_mission_sne)}")
for a in same_mission_sne:
    print(f"   {a['sn_name']}: {', '.join(a['common_missions'])}")

print(f"\n‚ö†Ô∏è  Cross-mission only (not ideal for differencing): {len(cross_mission_sne)}")
for a in cross_mission_sne:
    print(f"   {a['sn_name']}: REF={list(a['ref_missions'])} vs SCI={list(a['sci_missions'])}")


MISSION ANALYSIS

‚úÖ Same-mission pairs (usable for differencing): 168
   2005V: GALEX
   2005ai: GALEX
   2005ay: GALEX
   2005az: GALEX
   2005bj: GALEX
   2005bk: GALEX
   2005bn: GALEX
   2005bt: GALEX
   2005ca: GALEX
   2005cc: GALEX
   2005ck: GALEX
   2005cr: GALEX
   2005ct: GALEX
   2005dh: GALEX
   2005dn: GALEX
   2005dq: GALEX
   2005dt: GALEX
   2005dy: GALEX
   2005ea: GALEX
   2005eb: GALEX
   2005ej: GALEX
   2005eu: GALEX
   2005ev: GALEX
   2005hh: GALEX
   2005kl: GALEX
   2005lt: GALEX
   2005lx: GALEX
   2005mf: GALEX
   2005mr: GALEX
   2005nc: SWIFT
   2006P: GALEX
   2006X: SWIFT
   2006Z: GALEX
   2006af: GALEX
   2006ao: GALEX
   2006aq: GALEX
   2006as: GALEX
   2006ax: GALEX
   2006ay: GALEX
   2006bk: GALEX
   2006br: GALEX
   2006bt: GALEX
   2006by: GALEX
   2006ce: GALEX
   2006cg: GALEX
   2006cj: GALEX
   2006cs: GALEX
   2006ct: GALEX
   2006cx: GALEX
   2006da: GALEX
   2006db: GALEX
   2006dg: GALEX
   2006dh: GALEX
   2006di: GALEX
   2006dk: GAL

## 2. Helper Functions for Visualization


In [5]:
def load_fits_image(filepath: Path) -> tuple:
    """
    Load the first 2-D image plane from a FITS file, plus summary metadata.

    HDUs are searched in order (primary first) for the first one whose data has
    at least two dimensions. Higher-dimensional data is reduced to 2-D by
    repeatedly taking index 0 along the leading axis.

    Returns:
        (image_data, info_dict) on success: image_data is a float array and
        info_dict holds shape/dtype/min/max/mean/has_wcs/filter/exptime/date_obs.
        (None, {'error': message}) if no image HDU exists or reading fails.
    """
    try:
        with fits.open(filepath) as hdul:
            # A 1-D or empty primary must not be treated as the image; keep looking.
            image_hdu = next(
                (hdu for hdu in hdul if hdu.data is not None and np.ndim(hdu.data) >= 2),
                None,
            )
            if image_hdu is None:
                return None, {'error': 'No image data found'}

            data = image_hdu.data
            header = image_hdu.header

            # Collapse cubes of any depth (3-D, 4-D, ...) to their first 2-D plane.
            while data.ndim > 2:
                data = data[0]

            # Get WCS info (celestial axes only)
            try:
                wcs = WCS(header, naxis=2)
                has_wcs = wcs.has_celestial
            except Exception:
                has_wcs = False

            info = {
                'shape': data.shape,
                'dtype': str(data.dtype),
                'min': float(np.nanmin(data)),
                'max': float(np.nanmax(data)),
                'mean': float(np.nanmean(data)),
                'has_wcs': has_wcs,
                'filter': header.get('FILTER', header.get('FILTER1', 'unknown')),
                'exptime': header.get('EXPTIME', header.get('EXPOSURE', 'unknown')),
                'date_obs': header.get('DATE-OBS', 'unknown'),
            }

            # Convert while the file is still open (the data may be memory-mapped).
            return data.astype(float), info

    except Exception as e:
        return None, {'error': str(e)}


def display_image(ax, data: np.ndarray, title: str = "", cmap: str = 'gray'):
    """Show an astronomical image on ``ax`` with a ZScale stretch and a colorbar.

    If ZScale raises, the 1st/99th percentiles of the finite pixels are used
    instead. When ``data`` is None or has no finite pixels, a placeholder message
    is written on the axes and nothing is plotted.
    """
    finite = None if data is None else data[np.isfinite(data)]
    if finite is None or finite.size == 0:
        # Nothing to stretch: ZScale/percentiles would fail or return NaN limits.
        message = 'No data' if data is None else 'No finite pixels'
        ax.text(0.5, 0.5, message, ha='center', va='center', fontsize=12)
        ax.set_title(title)
        return

    try:
        vmin, vmax = ZScaleInterval().get_limits(finite)
    except Exception:
        vmin, vmax = np.percentile(finite, [1, 99])

    im = ax.imshow(data, origin='lower', cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title, fontsize=10)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)


def show_ref_sci_comparison(sn_name: str, ref_file: str, sci_file: str, base_dir: Path):
    """Plot reference, science and a naive pixel difference for one SN.

    Both files are resolved relative to ``base_dir``. The third panel is a plain
    ``science - reference`` subtraction with no WCS alignment, so it is only drawn
    when both arrays have identical shapes. Header metadata is printed afterwards.
    """
    ref_data, ref_info = load_fits_image(base_dir / ref_file)
    sci_data, sci_info = load_fits_image(base_dir / sci_file)

    fig, (ax_ref, ax_sci, ax_diff) = plt.subplots(1, 3, figsize=(15, 5))
    fig.suptitle(f"SN {sn_name}", fontsize=14, fontweight='bold')

    panels = (
        (ax_ref, "Reference", ref_file, ref_data, ref_info),
        (ax_sci, "Science", sci_file, sci_data, sci_info),
    )
    for panel_ax, label, fname, image, meta in panels:
        panel_title = f"{label}\n{Path(fname).name[:40]}...\nShape: {meta.get('shape', 'N/A')}"
        display_image(panel_ax, image, panel_title)

    if ref_data is None or sci_data is None:
        ax_diff.text(0.5, 0.5, 'Missing data', ha='center', va='center')
        ax_diff.set_title("Cannot difference")
    elif ref_data.shape != sci_data.shape:
        ax_diff.text(0.5, 0.5, f"Shape mismatch\nRef: {ref_data.shape}\nSci: {sci_data.shape}",
                     ha='center', va='center', fontsize=10)
        ax_diff.set_title("Cannot difference")
    else:
        display_image(ax_diff, sci_data - ref_data, "Simple Difference\n(Science - Reference)", cmap='RdBu_r')

    plt.tight_layout()
    plt.show()

    print("\nMetadata:")
    print(f"   Reference: filter={ref_info.get('filter')}, exptime={ref_info.get('exptime')}, WCS={ref_info.get('has_wcs')}")
    print(f"   Science:   filter={sci_info.get('filter')}, exptime={sci_info.get('exptime')}, WCS={sci_info.get('has_wcs')}")
    print(f"   Date (sci): {sci_info.get('date_obs')}")

print("Helper functions defined")


Helper functions defined


## 3. Inspect Same-Mission Pairs

Let's look at the SNe that have matching missions in reference and science.


In [6]:
# Per-mission file counts for every same-mission SN, with a filter breakdown for SWIFT
separator = "=" * 60
print(separator)
print("SAME-MISSION SNe - DETAILED VIEW")
print(separator)

for analysis in same_mission_sne:
    sn = analysis['sn_name']
    print(f"\nüåü {sn}")

    for mission in analysis['common_missions']:
        ref_files = analysis['ref_missions'].get(mission, [])
        sci_files = analysis['sci_missions'].get(mission, [])

        print(f"   üì° {mission}:")
        print(f"      Reference files: {len(ref_files)}")
        print(f"      Science files:   {len(sci_files)}")

        if mission != "SWIFT":
            continue

        ref_filters = list(map(get_swift_filter, ref_files))
        sci_filters = list(map(get_swift_filter, sci_files))
        print(f"      Ref filters:  {sorted(set(ref_filters))}")
        print(f"      Sci filters:  {sorted(set(sci_filters))}")

        common_filters = set(ref_filters).intersection(sci_filters)
        if common_filters:
            print(f"      ‚úÖ Matching filters: {sorted(common_filters)}")


SAME-MISSION SNe - DETAILED VIEW

üåü 2005V
   üì° GALEX:
      Reference files: 6
      Science files:   6

üåü 2005ai
   üì° GALEX:
      Reference files: 6
      Science files:   3

üåü 2005ay
   üì° GALEX:
      Reference files: 9
      Science files:   9

üåü 2005az
   üì° GALEX:
      Reference files: 3
      Science files:   9

üåü 2005bj
   üì° GALEX:
      Reference files: 3
      Science files:   3

üåü 2005bk
   üì° GALEX:
      Reference files: 3
      Science files:   3

üåü 2005bn
   üì° GALEX:
      Reference files: 6
      Science files:   3

üåü 2005bt
   üì° GALEX:
      Reference files: 6
      Science files:   3

üåü 2005ca
   üì° GALEX:
      Reference files: 6
      Science files:   6

üåü 2005cc
   üì° GALEX:
      Reference files: 6
      Science files:   3

üåü 2005ck
   üì° GALEX:
      Reference files: 3
      Science files:   9

üåü 2005cr
   üì° GALEX:
      Reference files: 6
      Science files:   3

üåü 2005ct
   üì° GALEX:
    

## 4. Visual Inspection: Pick a SN to Examine

Select a supernova with same-mission data and visually compare reference vs science.


In [7]:
# Pick a SN to inspect
TARGET_SN = "2014J"

# Look the target up in both the raw manifest and the per-SN mission analysis
# (first record with a matching name, or None).
target_entry, target_analysis = (
    next((record for record in records if record['sn_name'] == TARGET_SN), None)
    for records in (manifest, analyses)
)

if target_entry is not None:
    print(f"‚úÖ Found SN {TARGET_SN}")
    print(f"   Reference files: {len(target_entry['reference_files'])}")
    print(f"   Science files: {len(target_entry['science_files'])}")
    print(f"   Common missions: {target_analysis['common_missions']}")
else:
    print(f"‚ùå SN {TARGET_SN} not found in manifest")


‚ùå SN 2014J not found in manifest


In [8]:
# List available files for the target SN (same cap for reference and science)
MAX_FILES_LISTED = 10


def _print_file_listing(label: str, files: list, limit: int = MAX_FILES_LISTED) -> None:
    """Print up to ``limit`` file names with their SWIFT filter, then a '... and N more' line."""
    print(f"\n{label} files for {TARGET_SN}:")
    for i, f in enumerate(files[:limit]):
        filt = get_swift_filter(f) if 'SWIFT' in f else 'N/A'
        print(f"   [{i}] {Path(f).name} (filter: {filt})")
    if len(files) > limit:
        print(f"   ... and {len(files) - limit} more")


if target_entry:
    _print_file_listing("Reference", target_entry['reference_files'])
    _print_file_listing("Science", target_entry['science_files'])


In [9]:
# Compare a reference and science image with a MATCHING SWIFT filter
if target_entry and target_analysis['has_same_mission_pair']:
    # Group SWIFT files by filter code, separately for each epoch
    ref_by_filter = defaultdict(list)
    sci_by_filter = defaultdict(list)

    for files, by_filter in ((target_entry['reference_files'], ref_by_filter),
                             (target_entry['science_files'], sci_by_filter)):
        for f in files:
            if 'SWIFT' not in f:
                continue
            filt = get_swift_filter(f)
            # Files whose filter can't be parsed must not be "matched" to each other
            if filt != "unknown":
                by_filter[filt].append(f)

    common_filters = set(ref_by_filter) & set(sci_by_filter)
    print(f"Common SWIFT filters: {sorted(common_filters)}")

    if common_filters:
        # Pick the alphabetically first common filter and its first file in each epoch
        chosen_filter = sorted(common_filters)[0]
        ref_file = ref_by_filter[chosen_filter][0]
        sci_file = sci_by_filter[chosen_filter][0]

        print(f"\nComparing {chosen_filter} filter images:")
        show_ref_sci_comparison(TARGET_SN, ref_file, sci_file, TRAINING_DIR)
    else:
        print("No SWIFT filter is present in both reference and science; nothing to compare.")


In [10]:
# Filter matching alone is not enough: reference and science images sit on
# different pixel grids, so the science image is reprojected onto the reference
# WCS before differencing.

from reproject import reproject_interp


def _load_first_image_hdu(path: Path):
    """Return ``(data, header)`` for the first HDU in ``path`` whose data is >= 2-D.

    ``data`` is converted to float and reduced to 2-D by repeatedly taking index 0
    along the leading axis. Raises ValueError when the file has no such HDU.
    """
    with fits.open(path) as hdul:
        for hdu in hdul:
            if hdu.data is not None and hdu.data.ndim >= 2:
                data = hdu.data.astype(float)
                while data.ndim > 2:
                    data = data[0]
                return data, hdu.header
    raise ValueError(f"No 2-D image data found in {path}")


def show_aligned_comparison(sn_name: str, ref_file: str, sci_file: str, base_dir: Path):
    """Reproject science onto the reference WCS and plot ref / sci / diff / footprint.

    Args:
        sn_name: Label used in the figure title.
        ref_file, sci_file: Paths relative to ``base_dir``.
        base_dir: Root directory the two paths are resolved against.

    Returns:
        (diff, footprint): ``diff`` is reprojected science minus reference, with
        non-finite pixels (outside the overlap or NaN in either input) set to 0;
        ``footprint`` is the coverage map returned by ``reproject_interp``.

    Raises:
        ValueError: if either file has no 2-D image, or the images do not overlap.
    """
    ref_data, ref_header = _load_first_image_hdu(base_dir / ref_file)
    sci_data, sci_header = _load_first_image_hdu(base_dir / sci_file)

    # naxis=2 keeps only the two image axes, matching the 2-D arrays above
    ref_wcs = WCS(ref_header, naxis=2)
    sci_wcs = WCS(sci_header, naxis=2)

    print(f"Reference: {Path(ref_file).name}")
    print(f"  Date: {ref_header.get('DATE-OBS', 'N/A')}")
    print(f"  Shape: {ref_data.shape}")

    print(f"\nScience: {Path(sci_file).name}")
    print(f"  Date: {sci_header.get('DATE-OBS', 'N/A')}")
    print(f"  Shape: {sci_data.shape}")

    # Reproject science to match reference WCS. Use the 2-D reference WCS rather
    # than the raw header: a cube header still describes NAXIS > 2 and would not
    # match the 2-D shape_out.
    print("\nüîÑ Reprojecting science image to match reference WCS...")
    sci_reproj, footprint = reproject_interp(
        (sci_data, sci_wcs),
        ref_wcs,
        shape_out=ref_data.shape,
    )

    diff = sci_reproj - ref_data
    in_overlap = np.isfinite(diff)
    if not in_overlap.any():
        raise ValueError(f"SN {sn_name}: science image does not overlap the reference")
    diff[~in_overlap] = 0

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    fig.suptitle(f"SN {sn_name} - WCS Aligned (Same Filter)", fontsize=14, fontweight='bold')

    interval = ZScaleInterval()

    # Reference
    vmin, vmax = interval.get_limits(ref_data[np.isfinite(ref_data)])
    im1 = axes[0].imshow(ref_data, origin='lower', cmap='gray', vmin=vmin, vmax=vmax)
    axes[0].set_title(f"Reference (PRE-SN)\n{ref_header.get('DATE-OBS', '')[:10]}")
    plt.colorbar(im1, ax=axes[0], fraction=0.046)

    # Science (reprojected)
    vmin2, vmax2 = interval.get_limits(sci_reproj[np.isfinite(sci_reproj)])
    im2 = axes[1].imshow(sci_reproj, origin='lower', cmap='gray', vmin=vmin2, vmax=vmax2)
    axes[1].set_title(f"Science (POST-SN)\n{sci_header.get('DATE-OBS', '')[:10]}")
    plt.colorbar(im2, ax=axes[1], fraction=0.046)

    # Difference: stretch from overlap pixels only, so the zero-filled region
    # outside the overlap does not pull the 99th percentile down
    vmax_diff = np.percentile(np.abs(diff[in_overlap]), 99)
    im3 = axes[2].imshow(diff, origin='lower', cmap='RdBu_r', vmin=-vmax_diff, vmax=vmax_diff)
    axes[2].set_title("Difference\n(Science - Reference)")
    plt.colorbar(im3, ax=axes[2], fraction=0.046)

    # Footprint (overlap region)
    axes[3].imshow(footprint, origin='lower', cmap='gray')
    axes[3].set_title("Overlap Footprint")

    plt.tight_layout()
    plt.show()

    return diff, footprint

print("‚úÖ WCS-aligned comparison function ready")


‚úÖ WCS-aligned comparison function ready


In [11]:
# Now compare with MATCHING filter (uvv) and WCS alignment

# Find files containing "uvv" in the reference and science folders.
# sorted(): glob order is filesystem-dependent, so sort to make the [0] pick reproducible.
ref_dir = TRAINING_DIR / TARGET_SN / "reference"
sci_dir = TRAINING_DIR / TARGET_SN / "science"

ref_uvv_files = sorted(ref_dir.glob("*uvv*.fits"))
sci_uvv_files = sorted(sci_dir.glob("*uvv*.fits"))

if ref_uvv_files and sci_uvv_files:
    # Use the first match from each folder, as a path relative to TRAINING_DIR
    ref_uvv = str(ref_uvv_files[0].relative_to(TRAINING_DIR))
    sci_uvv = str(sci_uvv_files[0].relative_to(TRAINING_DIR))

    print("üî¨ SN "+TARGET_SN+" - Proper same-filter (uvv) WCS-aligned comparison")
    print("=" * 60)
    print(f"Reference: {ref_uvv}")
    print(f"Science: {sci_uvv}")
    diff_result, footprint = show_aligned_comparison(TARGET_SN, ref_uvv, sci_uvv, TRAINING_DIR)
else:
    print(f"‚ö†Ô∏è No uvv filter files found for {TARGET_SN}")
    if not ref_uvv_files:
        print("   Missing in reference folder")
    if not sci_uvv_files:
        print("   Missing in science folder")


‚ö†Ô∏è No uvv filter files found for 2014J
   Missing in reference folder
   Missing in science folder


In [12]:
# Also try the uuu filter pair.
# sorted(): glob order is filesystem-dependent, so sort to make the [0] pick reproducible.
ref_dir = TRAINING_DIR / TARGET_SN / "reference"
sci_dir = TRAINING_DIR / TARGET_SN / "science"

ref_uuu_files = sorted(ref_dir.glob("*uuu*.fits"))
sci_uuu_files = sorted(sci_dir.glob("*uuu*.fits"))

if ref_uuu_files and sci_uuu_files:
    # Use the first match from each folder, as a path relative to TRAINING_DIR
    ref_uuu = str(ref_uuu_files[0].relative_to(TRAINING_DIR))
    sci_uuu = str(sci_uuu_files[0].relative_to(TRAINING_DIR))

    print("üî¨ SN "+TARGET_SN+" - Proper same-filter (uuu) WCS-aligned comparison")
    print("=" * 60)
    print(f"Reference: {ref_uuu}")
    print(f"Science: {sci_uuu}")
    diff_result, footprint = show_aligned_comparison(TARGET_SN, ref_uuu, sci_uuu, TRAINING_DIR)
else:
    print(f"‚ö†Ô∏è No uuu filter files found for {TARGET_SN}")
    if not ref_uuu_files:
        print("   Missing in reference folder")
    if not sci_uuu_files:
        print("   Missing in science folder")


‚ö†Ô∏è No uuu filter files found for 2014J
   Missing in reference folder
   Missing in science folder


## 7. Full Differencing Pipeline

A reusable pipeline with:
1. WCS alignment (reprojection)
2. Background estimation and subtraction
3. PSF estimation and matching
4. Flux normalization
5. ZOGY optimal differencing
6. Source detection on difference image
7. Known SN position marking


In [13]:
# Import from shared module
from src.domains.differencing.pipeline import (
    DifferencingResult as PipelineResult,
    SNDifferencingPipeline,
)

# Create wrapper for notebook visualization (preserves images)
def process_with_images(pipeline, ref_path, sci_path, sn_name="unknown", sn_coords=None):
    """Run ``pipeline.process`` and return a dict that also carries the image arrays.

    The pipeline output does not keep the input images, so both files are loaded
    again and the science image is reprojected onto the reference WCS for display.

    NOTE(review): assumes ``pipeline.process`` returns ``(diff, sig, mask, result)``
    and ``pipeline.load_fits`` returns ``(data, header, wcs)``; confirm against
    ``src.domains.differencing.pipeline``.
    """
    difference, significance, _mask, summary = pipeline.process(ref_path, sci_path, sn_name, sn_coords)

    ref_data, ref_header, ref_wcs = pipeline.load_fits(ref_path)
    sci_data, sci_header, sci_wcs = pipeline.load_fits(sci_path)

    from reproject import reproject_interp
    sci_aligned, footprint = reproject_interp((sci_data, sci_wcs), ref_wcs, shape_out=ref_data.shape)

    return dict(
        sn_name=summary.sn_name,
        filter_name=summary.filter_name,
        reference=ref_data,
        science=sci_data,
        science_aligned=sci_aligned,
        difference=difference,
        significance=significance,
        footprint=footprint,
        ref_header=ref_header,
        sci_header=sci_header,
        metrics=dict(
            overlap_fraction=summary.overlap_fraction,
            sig_max=summary.sig_max,
        ),
        sn_pixel=summary.sn_pixel,
    )

INFO:src.core.db.session:No SSL certificate path provided, using certifi CA bundle
INFO:src.core.db.session:Using certifi CA bundle for SSL verification
INFO:src.core.db.session:Creating database engine with URL: postgresql+asyncpg://postgres:None@aws-0-us-west-1.pooler.supabase.com:5432/postgres
INFO:src.core.db.session:Database engine created successfully


ValueError: Supabase URL and anon key must be configured

In [None]:
def visualize_result(result: "PipelineResult", figsize=(20, 10)):
    """
    Visualize differencing pipeline results in a 2x4 panel figure.

    Row 1 shows the reference, the WCS-aligned science image, the difference and
    the significance map, with the known SN pixel (if any) marked on each. Row 2
    shows the overlap footprint, a histogram of significance inside the overlap,
    and zoomed difference/significance cutouts around the SN.

    Args:
        result: Object exposing ``sn_name``, ``filter_name``, ``reference``,
            ``science_aligned``, ``difference``, ``significance``, ``footprint``,
            ``ref_header``, ``sci_header``, ``metrics`` (dict with ``sig_max`` and
            ``overlap_fraction``) and ``sn_pixel`` (``(x, y)`` or None). A dict
            with those keys, as built by ``process_with_images``, is also accepted.
        figsize: Matplotlib figure size.
    """
    from collections.abc import Mapping
    from types import SimpleNamespace

    # process_with_images returns a plain dict; give it attribute access.
    if isinstance(result, Mapping):
        result = SimpleNamespace(**result)

    fig, axes = plt.subplots(2, 4, figsize=figsize)
    fig.suptitle(f"SN {result.sn_name} - {result.filter_name} filter\n"
                 f"Ref: {result.ref_header.get('DATE-OBS', '')[:10]} | "
                 f"Sci: {result.sci_header.get('DATE-OBS', '')[:10]}",
                 fontsize=14, fontweight='bold')

    interval = ZScaleInterval()
    # Pixels with footprint above 0.5 are treated as inside the overlap
    valid = result.footprint > 0.5

    # Row 1: Images
    # Reference
    vmin, vmax = interval.get_limits(result.reference[np.isfinite(result.reference)])
    im1 = axes[0, 0].imshow(result.reference, origin='lower', cmap='gray', vmin=vmin, vmax=vmax)
    axes[0, 0].set_title("Reference (PRE-SN)")
    plt.colorbar(im1, ax=axes[0, 0], fraction=0.046)

    # Science aligned
    vmin2, vmax2 = interval.get_limits(result.science_aligned[valid])
    im2 = axes[0, 1].imshow(result.science_aligned, origin='lower', cmap='gray', vmin=vmin2, vmax=vmax2)
    axes[0, 1].set_title("Science Aligned (POST-SN)")
    plt.colorbar(im2, ax=axes[0, 1], fraction=0.046)

    # Difference
    vmax_diff = np.nanpercentile(np.abs(result.difference[valid]), 99)
    im3 = axes[0, 2].imshow(result.difference, origin='lower', cmap='RdBu_r',
                            vmin=-vmax_diff, vmax=vmax_diff)
    axes[0, 2].set_title("Difference (Sci - Ref)")
    plt.colorbar(im3, ax=axes[0, 2], fraction=0.046)

    # Significance (color range capped at 20 sigma)
    vmax_sig = min(np.nanpercentile(np.abs(result.significance[valid]), 99.5), 20)
    im4 = axes[0, 3].imshow(result.significance, origin='lower', cmap='RdBu_r',
                            vmin=-vmax_sig, vmax=vmax_sig)
    axes[0, 3].set_title(f"Significance (œÉ)\nMax: {result.metrics['sig_max']:.1f}œÉ")
    plt.colorbar(im4, ax=axes[0, 3], fraction=0.046)

    # Mark SN position on all top-row plots
    if result.sn_pixel is not None:
        x, y = result.sn_pixel
        for ax in axes[0, :]:
            ax.scatter([x], [y], marker='o', s=200, facecolors='none',
                       edgecolors='lime', linewidths=2, label='Known SN')
            ax.scatter([x], [y], marker='+', s=100, c='lime', linewidths=2)

    # Row 2: Diagnostics
    # Footprint
    axes[1, 0].imshow(result.footprint, origin='lower', cmap='gray')
    axes[1, 0].set_title(f"Overlap Footprint\n{result.metrics['overlap_fraction']:.1f}%")

    # Histogram of significance inside the overlap
    sig_vals = result.significance[valid].flatten()
    sig_vals = sig_vals[np.isfinite(sig_vals)]
    axes[1, 1].hist(sig_vals, bins=100, range=(-10, 10), density=True, alpha=0.7)
    axes[1, 1].axvline(0, color='k', linestyle='--', alpha=0.5)
    axes[1, 1].axvline(5, color='r', linestyle='--', alpha=0.5, label='5œÉ')
    axes[1, 1].axvline(-5, color='r', linestyle='--', alpha=0.5)
    axes[1, 1].set_xlabel("Significance (œÉ)")
    axes[1, 1].set_ylabel("Density")
    axes[1, 1].set_title("Significance Distribution")
    axes[1, 1].legend()

    # Zoomed difference/significance around SN (if known)
    if result.sn_pixel is not None:
        x, y = result.sn_pixel
        half = 100  # cutout half-width in pixels
        ny, nx = result.difference.shape
        xi, yi = int(x), int(y)
        y1, y2 = max(0, yi - half), min(ny, yi + half)
        x1, x2 = max(0, xi - half), min(nx, xi + half)

        cutout = result.difference[y1:y2, x1:x2]
        cutout_sig = result.significance[y1:y2, x1:x2]

        if cutout.size == 0:
            # The SN pixel lies outside the image, so there is nothing to zoom into
            for ax in (axes[1, 2], axes[1, 3]):
                ax.text(0.5, 0.5, "SN position\noutside image", ha='center', va='center')
        else:
            # SN location inside the cutout; this differs from (half, half) when
            # the cutout is clipped at an image edge.
            cx, cy = x - x1, y - y1

            vmax_cut = np.nanpercentile(np.abs(cutout), 99)
            im5 = axes[1, 2].imshow(cutout, origin='lower', cmap='RdBu_r',
                                    vmin=-vmax_cut, vmax=vmax_cut)
            axes[1, 2].scatter([cx], [cy], marker='+', s=200, c='lime', linewidths=2)
            axes[1, 2].set_title(f"Zoomed Difference\n({cutout.shape[1]}x{cutout.shape[0]} px around SN)")
            plt.colorbar(im5, ax=axes[1, 2], fraction=0.046)

            # Zoomed significance (color range capped at 15 sigma)
            vmax_sig_cut = min(np.nanpercentile(np.abs(cutout_sig), 99), 15)
            im6 = axes[1, 3].imshow(cutout_sig, origin='lower', cmap='RdBu_r',
                                    vmin=-vmax_sig_cut, vmax=vmax_sig_cut)
            axes[1, 3].scatter([cx], [cy], marker='+', s=200, c='lime', linewidths=2)
            axes[1, 3].set_title("Zoomed Significance")
            plt.colorbar(im6, ax=axes[1, 3], fraction=0.046)
    else:
        axes[1, 2].text(0.5, 0.5, "No SN coords\nprovided", ha='center', va='center')
        axes[1, 3].text(0.5, 0.5, "No SN coords\nprovided", ha='center', va='center')

    plt.tight_layout()
    plt.show()

    # Print metrics
    print("\nüìä Pipeline Metrics:")
    for key, val in result.metrics.items():
        print(f"   {key}: {val:.4f}" if isinstance(val, float) else f"   {key}: {val}")

print("‚úÖ Visualization function defined")


In [None]:
# Known SN coordinates from catalogs/literature (J2000 / ICRS).
# SN 2014J (in M82): RA = 09h55m42.14s, Dec = +69°40′26.0″ (discovery position,
# Fossey et al. 2014, CBET 3792). The MNRAS 424, 1297 (2012) position
# (06h21m44.86s, −59°44′26″) predates SN 2014J and is not its location.
from astropy.coordinates import SkyCoord

SN_COORDINATES = {
    "2014J": SkyCoord("9h 55m 42.14s", "+69d 40m 26.0s", frame='icrs'),
    # Add more as we find them...
}

# Test the full pipeline on SN 2014J with the uuu filter
pipeline = SNDifferencingPipeline(
    psf_fwhm=4.0,  # SWIFT UVOT typical FWHM
    background_box_size=64,
    detection_threshold=5.0
)

# Process SN 2014J (uuu filter - best overlap)
ref_path = TRAINING_DIR / "2014J" / "reference" / "SWIFT_sw00031102001uuu_sk.fits"
sci_path = TRAINING_DIR / "2014J" / "science" / "SWIFT_sw00031562001uuu_sk.fits"

# NOTE(review): process_with_images unpacks pipeline.process(...) as a 4-tuple
# (diff, sig, mask, result); if that is right, result_2014J is a tuple here while
# visualize_result expects a result object; confirm the return type.
result_2014J = pipeline.process(
    ref_path=ref_path,
    sci_path=sci_path,
    sn_name="2014J",
    sn_coords=SN_COORDINATES.get("2014J")
)


In [None]:
# Show the diagnostic figure for the SN 2014J run above
visualize_result(result=result_2014J)


## 8. Process All Same-Mission SNe

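Run the pipeline on the 8 SNe in the hard-coded `SAME_MISSION_SNE` list below, which were identified as having SWIFT–SWIFT reference/science pairs. The GALEX same-mission pairs found in Section 1 are not processed here.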


In [None]:
from typing import Optional, Tuple

# SWIFT UVOT filter codes recognised in filenames, in the order they are tested
# (the first code contained in the lower-cased name wins).
_SWIFT_FILTER_CODES = ('uvw2', 'uvm2', 'uvw1', 'uuu', 'ubb', 'uvv')
# Preference when several filters exist in both epochs: U-band / uvw1 first for SN visibility.
_SWIFT_FILTER_PREFERENCE = ('uuu', 'uvw1', 'uvm2', 'uvw2', 'ubb', 'uvv')


def find_matching_filter_pair(sn_name: str, base_dir: Path) -> Optional[Tuple[Path, Path, str]]:
    """
    Find a reference/science pair with matching SWIFT filter.

    Looks in ``base_dir/<sn_name>/reference`` and ``.../science`` for
    ``SWIFT_*.fits`` files, groups them by the filter code in the filename, and
    picks the most preferred filter present in both epochs. Within a filter, the
    alphabetically first file of each epoch is returned, so the choice is the
    same on every run.

    Returns:
        (ref_path, sci_path, filter_name), or None if either directory is
        missing or no filter is shared.
    """
    ref_dir = base_dir / sn_name / "reference"
    sci_dir = base_dir / sn_name / "science"

    if not ref_dir.exists() or not sci_dir.exists():
        return None

    def _group_by_filter(directory: Path) -> dict:
        # sorted(): glob order is filesystem-dependent; sorting makes [0] reproducible
        grouped = defaultdict(list)
        for f in sorted(directory.glob("SWIFT_*.fits")):
            name = f.name.lower()
            filt = next((code for code in _SWIFT_FILTER_CODES if code in name), None)
            if filt is not None:
                grouped[filt].append(f)
        return grouped

    ref_by_filter = _group_by_filter(ref_dir)
    sci_by_filter = _group_by_filter(sci_dir)

    common = set(ref_by_filter) & set(sci_by_filter)
    if not common:
        return None

    # Every recognised code is in the preference list; min() is a deterministic safety net.
    chosen = next((f for f in _SWIFT_FILTER_PREFERENCE if f in common), min(common))
    return ref_by_filter[chosen][0], sci_by_filter[chosen][0], chosen

# List of same-mission SNe
SAME_MISSION_SNE = ['2013gc', '2013hh', '2013hn', '2014J', '2014L', '2014ai', '2014bh', '2014bi']

print(f"Processing {len(SAME_MISSION_SNE)} same-mission SNe...")
print("=" * 60)


In [None]:
# Process all same-mission SNe
# Set to True to run (takes a few minutes)
PROCESS_ALL = True

all_results = {}

if not PROCESS_ALL:
    print("Set PROCESS_ALL = True to process all SNe")
else:
    for sn_name in SAME_MISSION_SNE:
        print("\n" + "=" * 60)

        pair = find_matching_filter_pair(sn_name, TRAINING_DIR)
        if pair is None:
            print(f"‚ö†Ô∏è  {sn_name}: No matching filter pair found")
            continue

        ref_path, sci_path, filter_name = pair
        print(f"üìÅ {sn_name}: Using {filter_name} filter")

        try:
            result = pipeline.process(
                ref_path=ref_path,
                sci_path=sci_path,
                sn_name=sn_name,
                sn_coords=SN_COORDINATES.get(sn_name),
            )
            all_results[sn_name] = result
            # Quick look at each SN as soon as it finishes
            visualize_result(result)
        except Exception as err:
            # Report and keep going so one bad SN does not stop the batch
            print(f"‚ùå {sn_name}: Processing failed - {err}")
            import traceback
            traceback.print_exc()


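## 9. Save Results Summary

Write a JSON summary of the differencing results to `differencing_results.json` in the training directory. For each processed SN, the summary records the filter, the observation dates, the metrics, and the SN pixel position. The pipeline class itself lives in `src.domains.differencing.pipeline`, and other notebooks and scripts can import it from there.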


In [None]:
# Save results summary: every SN from the batch run, or just 2014J if the batch produced nothing
if all_results or result_2014J:
    results_to_save = all_results if all_results else {'2014J': result_2014J}

    per_sn = {}
    for sn_name, result in results_to_save.items():
        per_sn[sn_name] = {
            'filter': result.filter_name,
            'ref_date': result.ref_header.get('DATE-OBS', 'N/A')[:10],
            'sci_date': result.sci_header.get('DATE-OBS', 'N/A')[:10],
            'metrics': result.metrics,
            'sn_pixel': result.sn_pixel,
        }

    summary = {
        'pipeline_version': '1.0',
        'processed_sne': list(results_to_save),
        'results': per_sn,
    }

    output_file = TRAINING_DIR / 'differencing_results.json'
    with open(output_file, 'w') as f:
        # default=str covers values json can't serialize natively
        json.dump(summary, f, indent=2, default=str)

    print(f"üíæ Results saved to: {output_file}")

closing_rule = "=" * 60
print("\n" + closing_rule)
print("DIFFERENCING PIPELINE COMPLETE")
print(closing_rule)
print("""
‚úÖ SNDifferencingPipeline ready!

Next: Run triplet generation below (Phase 4) to create CNN training data.
""")
