# SNR Calculation with coronablink

This notebook demonstrates SNR calculation using the **Mawet et al. (2014)** method, the standard small-sample approach for assessing the significance of point-source (exoplanet) detections in coronagraphic images.

| Method | Description | Use Case |
|--------|-------------|----------|
| **Mawet SNR** | Aperture-based with small-sample correction | Detection claims, publications |

> **Note:** An experimental matched-filter method is available via `from coronablink.core.matched_filter import matched_filter_snr` for research comparison.

In [None]:
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Circle
from matplotlib.lines import Line2D

# SNR API imports
from coronablink import snr, snr_estimator
from coronablink.core.geometry import calculate_n_apertures, generate_aperture_coords

## Test Setup

In [None]:
# Create test image
size = 101
fwhm = 5.0
noise_level = 100.0
center = size // 2

np.random.seed(0)
image = np.random.normal(0, noise_level, (size, size))

# Add a planet
planet_sep = 25
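planet_flux = 500.0     # peak amplitude of the Gaussian PSF (e-), not integrated flux
sigma = fwhm / 2.355    # FWHM = 2*sqrt(2 ln 2)*sigma ≈ 2.355*sigma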

planet_y = center + planet_sep
planet_x = center

y, x = np.ogrid[:size, :size]
r2 = (y - planet_y)**2 + (x - planet_x)**2
image += planet_flux * np.exp(-r2 / (2 * sigma**2))

image_jax = jnp.array(image)
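
For a rough expectation of what `snr()` should return for this setup, the snippet below does a back-of-envelope estimate. It assumes (none of this is taken from coronablink) a top-hat aperture of diameter one FWHM, uncorrelated pixel noise, and no small-sample correction; the library's actual kernel and normalization may differ, so treat the result only as an order-of-magnitude check. Because `planet_flux` above is the Gaussian's peak amplitude, the integrated flux is $2\pi\sigma^2$ times larger.

```python
import numpy as np

fwhm = 5.0
peak = 500.0          # Gaussian peak amplitude, as in the test image
noise_level = 100.0   # per-pixel white-noise standard deviation

sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))        # exact FWHM-to-sigma conversion
total_flux = 2 * np.pi * sigma**2 * peak           # integral of the 2D Gaussian
r_ap = fwhm / 2                                    # assumed top-hat aperture radius
enclosed = 1 - np.exp(-r_ap**2 / (2 * sigma**2))   # fraction of flux inside r_ap
n_pix = np.pi * r_ap**2                            # aperture area in pixels
signal = enclosed * total_flux
noise = noise_level * np.sqrt(n_pix)               # std of a white-noise aperture sum
print(f"expected SNR ~ {signal / noise:.1f}")
```

At $r = \mathrm{FWHM}/2$ a 2D Gaussian encloses exactly half of its flux, which is why `enclosed` comes out as 0.5.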

## Mawet SNR (Aperture-Based)

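The `snr()` function places discrete apertures around the annulus at the planet's separation: one on the planet and the rest as noise references. It uses the **standard deviation** of the reference-aperture fluxes as the noise estimate and applies a **small-sample correction**, which matters most close to the star, where only a few apertures fit on the annulus.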

Reference: Mawet et al. (2014), ApJ, 792, 97
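
For a single test aperture with flux $x_1$ and $n_2$ reference apertures with mean $\bar{x}_2$ and sample standard deviation $s_2$, Mawet et al. (2014) use the two-sample t statistic $\mathrm{SNR} = \frac{x_1 - \bar{x}_2}{s_2\sqrt{1 + 1/n_2}}$, which follows a Student's t distribution with $n_2 - 1$ degrees of freedom. The sketch below implements that formula on already-measured aperture fluxes. `mawet_snr_from_fluxes` is a hypothetical helper for illustration, not part of coronablink, and it does not reproduce the library's aperture photometry.

```python
import numpy as np


def mawet_snr_from_fluxes(planet_flux, ref_fluxes):
    """Hypothetical helper: Mawet et al. (2014) t statistic for one test aperture."""
    ref = np.asarray(ref_fluxes, dtype=float)
    n = ref.size
    if n < 2:
        raise ValueError("need at least two reference apertures")
    # Sample std (ddof=1), inflated by sqrt(1 + 1/n) for the small-sample penalty.
    noise = ref.std(ddof=1) * np.sqrt(1.0 + 1.0 / n)
    return (planet_flux - ref.mean()) / noise


result = mawet_snr_from_fluxes(10.0, [1.0, -1.0, 1.0, -1.0])
naive = 10.0 / np.std([1.0, -1.0, 1.0, -1.0], ddof=1)  # no small-sample factor
```

The $\sqrt{1 + 1/n_2}$ factor is what penalizes small samples: with only three reference apertures it lowers the SNR by about 13% relative to the naive ratio, on top of the heavier tails of the t distribution.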

In [None]:
# Calculate Mawet SNR
positions = jnp.array([[planet_y, planet_x]])
snr_mawet = float(snr(image_jax, positions, fwhm)[0])

print(f"Mawet SNR: {snr_mawet:.2f}")

In [None]:
# Visualize Mawet aperture placement using the ACTUAL geometry function
fig, ax = plt.subplots(figsize=(8, 8))

vmax = np.percentile(image, 99)
ax.imshow(image, origin='lower', cmap='viridis', vmin=-200, vmax=vmax)

# Calculate planet angle (same as used in the actual SNR calculation)
planet_angle = np.arctan2(planet_y - center, planet_x - center)

# Calculate number of apertures
n_apertures = calculate_n_apertures(radius=planet_sep, fwhm=fwhm)

# Generate aperture coords using the ACTUAL library function
y_coords, x_coords, mask = generate_aperture_coords(
    center=(center, center),
    radius=planet_sep,
    planet_angle=planet_angle,
    n_apertures=n_apertures,
    fwhm=fwhm
)

# Draw planet aperture in red
planet_circle = Circle((planet_x, planet_y), fwhm/2, fill=False, color='red', linewidth=3)
ax.add_patch(planet_circle)

# Draw reference apertures using the actual computed coordinates
y_arr = np.array(y_coords)
x_arr = np.array(x_coords)
mask_arr = np.array(mask)

for i in range(len(mask_arr)):
    if mask_arr[i]:
        circle = Circle((x_arr[i], y_arr[i]), fwhm/2, fill=False, 
                        color='white', linewidth=1.5, alpha=0.8)
        ax.add_patch(circle)

ax.scatter([center], [center], marker='+', s=200, color='white', linewidths=2)
ax.set_title(f'Mawet Method: {n_apertures} ref apertures at r={planet_sep}px\nSNR = {snr_mawet:.1f}', 
             fontsize=12, fontweight='bold')
ax.set_xlabel('X (pixels)')
ax.set_ylabel('Y (pixels)')

# Legend
legend_elements = [
    Line2D([0], [0], color='red', linewidth=3, label='Planet aperture'),
    Line2D([0], [0], color='white', linewidth=1.5, label=f'{n_apertures} reference apertures')
]
ax.legend(handles=legend_elements, loc='upper right', facecolor='black', labelcolor='white')

plt.colorbar(ax.images[0], ax=ax, label='e-', shrink=0.8)
plt.tight_layout()
plt.show()
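
The ring drawn above comes from `calculate_n_apertures` and `generate_aperture_coords`, whose internals are not shown here. A minimal NumPy sketch of the usual Mawet-style geometry looks like the following. It assumes (not confirmed for coronablink) that the number of apertures is $\lfloor 2\pi r / \mathrm{FWHM} \rfloor$ and that they are spaced evenly in angle starting at the planet; `aperture_centers` is a hypothetical name.

```python
import numpy as np


def aperture_centers(center, radius, planet_angle, fwhm):
    """Hypothetical sketch: evenly spaced aperture centres on a ring through the planet."""
    n = int(np.floor(2 * np.pi * radius / fwhm))   # resolution elements on the ring
    angles = planet_angle + 2 * np.pi * np.arange(n) / n
    cy, cx = center
    ys = cy + radius * np.sin(angles)   # y follows arctan2(dy, dx) convention
    xs = cx + radius * np.cos(angles)
    # Index 0 is the planet aperture; the remaining n - 1 are references.
    return ys, xs


ys, xs = aperture_centers((50, 50), 25, np.pi / 2, 5.0)
chord = np.hypot(ys[1] - ys[0], xs[1] - xs[0])   # distance between neighbours
```

Flooring keeps the arc length between neighbouring centres, $2\pi r / n$, at least one FWHM, so the apertures sample nearly independent resolution elements.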

## SNR vs Flux

In [None]:
# Measure SNR across different flux levels
np.random.seed(0)  # For reproducibility
flux_values = [100, 200, 300, 400, 500, 600]
snrs_mawet = []

for flux in flux_values:
    # Create test image with this flux
    test_image = np.random.normal(0, noise_level, (size, size))
    test_image += flux * np.exp(-r2 / (2 * sigma**2))
    test_jax = jnp.array(test_image)
    
    snrs_mawet.append(float(snr(test_jax, positions, fwhm)[0]))

# Plot
fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(flux_values, snrs_mawet, 'o-', label='Mawet SNR', color='tab:blue', linewidth=2)
ax.axhline(y=5, color='red', linestyle=':', alpha=0.7, label='5σ threshold')

ax.set_xlabel('Planet peak amplitude (e-)', fontsize=11)
ax.set_ylabel('SNR', fontsize=11)
ax.set_title('SNR vs Planet Flux', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
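
Because the noise does not depend on the planet, the expected SNR grows approximately linearly with flux, so a noisy curve like the one above can be summarized by a line through the origin. The hypothetical helper below (not a coronablink function) fits that slope by least squares and returns the flux at which the fit crosses a threshold. With one noise realization per flux, the result inherits that scatter.

```python
import numpy as np


def flux_at_threshold(fluxes, snrs, threshold=5.0):
    """Hypothetical helper: fit SNR = k * flux (no intercept), then invert at `threshold`."""
    f = np.asarray(fluxes, dtype=float)
    s = np.asarray(snrs, dtype=float)
    k = np.dot(f, s) / np.dot(f, f)   # least-squares slope through the origin
    return threshold / k
```

After running the cell above, `flux_at_threshold(flux_values, snrs_mawet)` gives a rough 5σ detection limit at this separation.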

## Using Estimators for Pipelines

For high-performance iterative pipelines, use `snr_estimator()` to pre-compute the aperture kernel:
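
The point of pre-computing is that everything depending only on `fwhm` (here, an aperture mask) is built once, and only the per-image work runs inside a JIT-compiled function. The sketch below shows that pattern in plain JAX under stated assumptions: a top-hat aperture of diameter one FWHM and pixel positions given as `(y, x)`. `make_aperture_sum` is hypothetical, computes only aperture sums rather than the full Mawet SNR, and is not how `snr_estimator` is implemented internally.

```python
import math

import jax
import jax.numpy as jnp


def make_aperture_sum(fwhm):
    """Hypothetical factory: build the aperture mask once, return a jitted photometry function."""
    r = fwhm / 2
    hw = math.ceil(r)                 # half-width of the square cutout
    size = 2 * hw + 1
    yy, xx = jnp.mgrid[-hw:hw + 1, -hw:hw + 1]
    kernel = (yy**2 + xx**2 <= r**2).astype(jnp.float32)  # top-hat mask, computed once

    @jax.jit
    def aperture_sum(image, positions):
        def one(pos):
            # Round to the nearest pixel and cut a size x size patch centred there.
            # Note: lax.dynamic_slice clamps the start so the patch stays inside the image.
            start = jnp.round(pos).astype(jnp.int32) - hw
            patch = jax.lax.dynamic_slice(image, (start[0], start[1]), (size, size))
            return jnp.sum(patch * kernel)

        # Evaluate all positions in one vectorized call.
        return jax.vmap(one)(positions)

    return aperture_sum
```

Calling `aperture_sum = make_aperture_sum(fwhm)` once and then `aperture_sum(image_jax, positions)` repeatedly reuses both the mask and the compiled code, which is also why the timing cell below warms the estimator up before measuring.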

In [None]:
import time

# Create estimator once
estimator = snr_estimator(fwhm, fast=True)

# Warmup (JIT compilation)
_ = estimator(image_jax, positions).block_until_ready()

# Time repeated calls (perf_counter is a monotonic, high-resolution clock)
t0 = time.perf_counter()
for _ in range(100):
    estimator(image_jax, positions).block_until_ready()
elapsed = (time.perf_counter() - t0) / 100 * 1000  # ms per call

print(f"SNR Estimator: {elapsed:.2f} ms/call")

## Summary

| Function | Class | Use Case |
|----------|-------|----------|
| `snr()` | `SNREstimator` | Mawet et al. (2014) method: publications, detection claims |

For iterative pipelines, use the estimator:
```python
estimator = snr_estimator(fwhm, fast=True)  # Pre-compute kernel
snrs = estimator(image, positions)          # Fast repeated calls
```