# Full Pipeline: coronagraphoto + coronablink

This notebook demonstrates the complete workflow from image simulation to SNR calculation:

1. **coronagraphoto** generates simulated coronagraphic observations
2. **coronablink** performs PSF subtraction and SNR analysis

We use `pooch` to automatically download example data files.

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

import coronablink as cb

## Download Example Data

coronablink provides example data through `pooch`. The first time you run the cell below, it downloads (and then caches) the following files:
- A coronagraph yield input package (YIP): `eac1_aavc_512`, created by Susan Redmond
- An ExoVista scene (modified Solar System)

In [None]:
# Resolve local paths to the example files. They are downloaded (from GitHub)
# on first use and served from the local cache afterwards.
coronagraph_path = cb.fetch_coronagraph()
scene_path = cb.fetch_scene()

for label, path in (("Coronagraph", coronagraph_path), ("Scene", scene_path)):
    print(f"{label}: {path}")

## Load Data with coronagraphoto and yippy

In [None]:
from yippy import Coronagraph as YippyCoronagraph
from coronagraphoto import Exposure, OpticalPath, load_sky_scene_from_exovista
from coronagraphoto.optical_elements import (
    ConstantThroughputElement,
    PrimaryAperture,
    SimpleDetector,
    from_yippy,
)
from coronagraphoto.core.simulation import sim_disk, sim_planets, sim_star, sim_zodi

# Astrophysical scene (stars, planets, ...) read from the ExoVista file.
scene = load_sky_scene_from_exovista(scene_path)

# yippy loads the coronagraph YIP; from_yippy converts it into the coronagraph
# element that OpticalPath expects.
yippy_coro = YippyCoronagraph(
    coronagraph_path,
    use_jax=True,
    quarter_symmetric_datacube=True,
)
coronagraph = from_yippy(yippy_coro)

print(f"Loaded scene with {len(scene.planets)} planets")

## Define Optical Path

In [None]:
# Observing system: a 6 m primary aperture, one 90%-throughput attenuating
# element, the yippy-derived coronagraph, and a 512x512-pixel detector.
optical_path = OpticalPath(
    primary=PrimaryAperture(diameter_m=6.0),
    attenuating_elements=(
        ConstantThroughputElement(throughput=0.9),
    ),
    coronagraph=coronagraph,
    detector=SimpleDetector(
        pixel_scale=1 / 512,
        shape=(512, 512),
    ),
)

# PSF full width at half maximum, in pixels. This is a hand-set approximation
# for this coronagraph, not a value computed from the YIP.
fwhm = 4.5

## Simulate Observation

In [None]:
from coronagraphoto import conversions

# One 1-hour exposure in a single 100 nm band centred on 550 nm, starting at 2030.0.
exposure = Exposure(
    start_time_jd=conversions.decimal_year_to_jd(2030.0),
    exposure_time_s=3600.0,  # 1 hour
    central_wavelength_nm=jnp.array([550.0]),
    bin_width_nm=jnp.array([100.0]),
    position_angle_deg=0.0
)

# Seed everything from one key, then give each stochastic step its own subkey.
# Under JAX's PRNG design a key must not be reused once it has been split, so the
# parent `key` is never passed to a sampler directly.
# k1 -> star, k2 -> planets, k3 -> detector noise, k4 -> spare (e.g. sim_disk / sim_zodi).
key = jax.random.PRNGKey(0)
k1, k2, k3, k4 = jax.random.split(key, 4)

# Positional arguments shared by the sim_* calls: start time, exposure time,
# and the centre/width of the (single) wavelength bin.
args = (
    exposure.start_time_jd,
    exposure.exposure_time_s,
    exposure.central_wavelength_nm[0],
    exposure.bin_width_nm[0]
)

star_electrons = sim_star(*args, scene.stars, optical_path, k1)
planet_electrons = sim_planets(*args, exposure.position_angle_deg, scene.planets, optical_path, k2)

# Full observation = star + planets
observation = star_electrons + planet_electrons

# Add detector noise, drawn from its own subkey (k3) rather than the already-split parent key.
noise_electrons = optical_path.detector.readout_noise_electrons(exposure.exposure_time_s, k3)
noisy_observation = observation + noise_electrons

print(f"Observation shape: {noisy_observation.shape}")
print(f"Max signal: {float(jnp.max(noisy_observation)):.1f} e-")

## Visualizing the Raw Observation

Stellar speckles dominate the image; the much fainter planets are buried beneath them and the noise.

In [None]:
obs_np = np.array(noisy_observation)
star_np = np.array(star_electrons)

# Shared colour stretch: both panels use the observation's 99.5th percentile as
# the upper limit, so data and model can be compared by eye.
vmax = np.percentile(obs_np, 99.5)

# NOTE(review): sim_star is given a PRNG key (k1), so star_electrons may be one
# noisy realization rather than a true noiseless expectation. Confirm this before
# relying on the second panel's title.
panel_specs = (
    (obs_np, 'Noisy Observation\n(Star + Planets + Noise)'),
    (star_np, 'Stellar Model\n(Noiseless Expectation)'),
)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))
for ax, (image, title) in zip(axes, panel_specs):
    im = ax.imshow(image, origin='lower', cmap='magma', vmin=0, vmax=vmax)
    plt.colorbar(im, ax=ax, label='e-', shrink=0.8)
    ax.set_title(title, fontsize=12, fontweight='bold')
    ax.set_xlabel('X (pixels)')
    ax.set_ylabel('Y (pixels)')

plt.tight_layout()
plt.show()

## PSF Subtraction with subtract_star

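Now we use coronablink's `subtract_star` to remove the stellar model from the observation, which reveals the planets. Note that the model here is the very same simulated stellar image (`star_electrons`) that was added into the observation. This makes it an idealized, perfect-knowledge subtraction: the residual should contain only the planets plus detector noise. With real data the stellar model has to be estimated, and its errors leave speckle residuals behind.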

In [None]:
# Remove the stellar PSF model from the observation. With a disk model you would
# follow up with: residual = cb.subtract_disk(residual, disk_model)
residual = cb.subtract_star(noisy_observation, star_electrons)

res_np = np.array(residual)
# Symmetric colour limit for the (signed) residual.
vabs = np.percentile(np.abs(res_np), 99)

# Observation - star model = residual, laid out left to right.
# Each entry: (image, colormap, vmin, vmax, title).
subtraction_panels = (
    (obs_np, 'magma', 0, vmax, 'Observation'),
    (star_np, 'magma', 0, vmax, '- Star Model'),
    (res_np, 'RdBu_r', -vabs, vabs, '= Residual (Planets Revealed!)'),
)

fig, axes = plt.subplots(1, 3, figsize=(16, 5))
for ax, (image, cmap, lo, hi, title) in zip(axes, subtraction_panels):
    im = ax.imshow(image, origin='lower', cmap=cmap, vmin=lo, vmax=hi)
    plt.colorbar(im, ax=ax, label='e-', shrink=0.8)
    ax.set_title(title, fontsize=11, fontweight='bold')

plt.tight_layout()
plt.show()

## Locate a Planet Candidate and Calculate Its SNR

In [None]:
# Choose the position at which to measure SNR. This demo does NOT read the true
# planet positions from `scene.planets`. Instead it takes the brightest residual
# pixel inside the measurable search region. In practice, use the known positions
# or a detection algorithm.
INNER_MASK_FWHM = 2.0      # inner exclusion radius, in units of FWHM (proxy for the IWA)
DETECTION_THRESHOLD = 5.0  # SNR required to call a detection

center = np.array(noisy_observation.shape) / 2

# Separation of every pixel from the image centre (taken as the star position), in pixels.
y, x = np.ogrid[:res_np.shape[0], :res_np.shape[1]]
r = np.sqrt((y - center[0])**2 + (x - center[1])**2)

# Search only between the inner mask and an outer limit one FWHM inside the
# image's inscribed circle. The Mawet method compares the candidate with reference
# apertures on the annulus at the same separation. Beyond this limit (e.g. in the
# corners) part of that annulus would fall off the detector.
r_outer = min(res_np.shape) / 2 - fwhm
search_region = (r > INNER_MASK_FWHM * fwhm) & (r < r_outer)
# Without this check, argmax over an all -inf array would silently return pixel (0, 0).
assert search_region.any(), "Empty search region: check fwhm against the image size"

masked_res = np.where(search_region, res_np, -np.inf)
planet_idx = np.unravel_index(np.argmax(masked_res), masked_res.shape)
planet_pos = (float(planet_idx[0]), float(planet_idx[1]))

print(f"Brightest detection at: {planet_pos}")

# SNR at the candidate using coronablink's Mawet-method estimator.
positions = jnp.array([[planet_pos[0], planet_pos[1]]])  # (N, 2) array of (y, x)
snr_mawet = float(cb.snr(residual, positions, fwhm)[0])

print(f"\nPlanet Detection SNR: {snr_mawet:.1f}")

if snr_mawet > DETECTION_THRESHOLD:
    print(f"\n  -> Detected at {snr_mawet:.1f}σ significance!")
else:
    print(f"\n  -> Below {DETECTION_THRESHOLD:.0f}σ threshold")

## Visualizing the Detection

In [None]:
# Single-panel view of the residual with the measured candidate highlighted.
fig, ax = plt.subplots(figsize=(10, 10))

im = ax.imshow(res_np, origin='lower', cmap='RdBu_r', vmin=-vabs, vmax=vabs)
plt.colorbar(im, ax=ax, label='e-', shrink=0.8)

# planet_pos is (row, col) = (y, x); matplotlib artists take (x, y).
cand_y, cand_x = planet_pos
marker_color = '#00ff00'

circle = Circle((cand_x, cand_y), fwhm, fill=False, color=marker_color, linewidth=3)
ax.add_patch(circle)
ax.annotate(
    f'SNR = {snr_mawet:.1f}',
    (cand_x + fwhm + 5, cand_y),  # label sits 5 px right of the circle's edge
    color=marker_color,
    fontsize=12,
    fontweight='bold',
)

# '+' marks the image centre, which the SNR cell uses as the star position.
ax.scatter([center[1]], [center[0]], marker='+', s=200, color='black', linewidths=2)

ax.set_title(f'Residual with Detection\nSNR = {snr_mawet:.1f}', fontsize=12, fontweight='bold')
ax.set_xlabel('X (pixels)')
ax.set_ylabel('Y (pixels)')

plt.tight_layout()
plt.show()

## Complete Workflow Summary

```python
import coronablink as cb
import jax.numpy as jnp

# 1. Fetch example data
coronagraph_path = cb.fetch_coronagraph()
scene_path = cb.fetch_scene()

# 2. Generate images with coronagraphoto
star = sim_star(...)
planets = sim_planets(...)
observation = star + planets + noise

# 3. Subtract stellar PSF with coronablink
residual = cb.subtract_star(observation, star)

# 4. If you have a disk model:
# residual = cb.subtract_disk(residual, disk_model)

# 5. Calculate SNR (Mawet method)
positions = jnp.array([[py, px]])  # (N, 2) array of (y, x) coords
snrs = cb.snr(residual, positions, fwhm)
```