# Build PSF with STARRED (single star)

目标：
- 从 `WFI2033` 原始 i2d 图像的 `SCI`/`ERR` 扩展中，以 `(x=8101, y=3465)` 为中心裁切 `101x101` 的星像。
- 用 `starred.procedures.build_psf` 拟合 supersampled PSF。
- 输出 `404x404` supersampled PSF 和 `101x101` detector PSF。

说明：
- 这里的坐标默认按 **0-based** 像素处理（与 Python/NumPy 数组索引一致）。
- 如果你手头的坐标来自 DS9（1-based），把 `coords_are_one_based` 设为 `True` 即可。


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import colors
from astropy.io import fits
from pathlib import Path

import starred.procedures


In [None]:
DATA_DIR = Path('../../Data/WFI2033')
raw_fits = DATA_DIR / 'jw01198-o004_t004_nircam_clear-f115w_i2d.fits'

# Input star position in the SCI image; see coords_are_one_based for the
# pixel-index convention these numbers are given in.
x_in, y_in = 8101, 3465
coords_are_one_based = False

cut_size = 101
supersampling_factor = 4

# STARRED optimizer iterations (author's note: on CPU, 400/1000 takes
# roughly 1-2 minutes).
n_iter_analytic = 400
n_iter_adabelief = 1000

# Output path is derived from the input position so it cannot go stale when
# x_in / y_in are edited (previously the coordinates were hard-coded here too).
out_fits = DATA_DIR / f'F115W_PSF_starred_x{x_in}_y{y_in}.fits'


In [None]:
# Convert the input star position to 0-based pixel indices (the convention
# used for all array slicing below). DS9 reports 1-based pixels, hence the
# optional shift by one. Round to the nearest pixel (half up) instead of
# truncating with int(), so a fractional position such as 8100.9 selects
# pixel 8101 rather than 8100. Integer inputs are unaffected.
coord_offset = 1 if coords_are_one_based else 0
x0 = int(np.floor(float(x_in) - coord_offset + 0.5))
y0 = int(np.floor(float(y_in) - coord_offset + 0.5))

if cut_size % 2 != 1:
    # The slices below always span 2*half + 1 pixels; an even cut_size would
    # silently produce a (cut_size + 1)-pixel cutout.
    raise ValueError(f'cut_size must be odd so the star sits on the central pixel, got {cut_size}')

half = cut_size // 2
x1, x2 = x0 - half, x0 + half + 1
y1, y2 = y0 - half, y0 + half + 1

# Slice the memory-mapped arrays *before* converting to float64. Converting
# the whole HDU with np.array(...) would copy the entire mosaic (SCI and ERR)
# into RAM, defeating memmap=True; here only the cutout is read from disk.
with fits.open(raw_fits, memmap=True) as hdul:
    sci_data = hdul['SCI'].data
    err_data = hdul['ERR'].data
    ny, nx = sci_data.shape
    if err_data.shape != sci_data.shape:
        raise ValueError(f'SCI shape {sci_data.shape} != ERR shape {err_data.shape}')
    if x1 < 0 or y1 < 0 or x2 > nx or y2 > ny:
        raise ValueError(f'Cutout out of bounds: x[{x1}:{x2}], y[{y1}:{y2}], image shape={(ny, nx)}')
    # np.array(..., dtype=float64) makes an independent in-memory copy, so the
    # cutouts stay valid after the file is closed.
    star_sci = np.array(sci_data[y1:y2, x1:x2], dtype=np.float64)
    star_err = np.array(err_data[y1:y2, x1:x2], dtype=np.float64)
    # Drop the memmap references so the file handle can actually be released.
    del sci_data, err_data

print('center (0-based):', (x0, y0))
print('cutout shape:', star_sci.shape)


In [None]:
# Robust local background: median of the k-pixel-wide frame around the cutout.
k = 10
if 2 * k >= min(star_sci.shape):
    raise ValueError(f'Border width k={k} too large for cutout shape {star_sci.shape}')

# Boolean frame mask, so each border pixel is counted exactly once (stacking
# four strips would count the k x k corner regions twice).
border = np.zeros(star_sci.shape, dtype=bool)
border[:k, :] = True
border[-k:, :] = True
border[:, :k] = True
border[:, -k:] = True
border_vals = star_sci[border]
local_bkg = np.nanmedian(border_vals)
if not np.isfinite(local_bkg):
    raise RuntimeError('SCI cutout border has no finite pixels; cannot estimate background.')
star_sci_bgsub = star_sci - local_bkg

# Sanitize ERR (noise map): replace non-finite / non-positive values with the
# median valid noise, then floor everything at 30% of that median so no pixel
# gets an unrealistically large weight.
valid_noise = np.isfinite(star_err) & (star_err > 0)
if not np.any(valid_noise):
    raise RuntimeError('ERR cutout has no finite positive pixels.')
noise_floor = np.nanmedian(star_err[valid_noise])
star_err_clean = np.where(valid_noise, star_err, noise_floor)
star_err_clean = np.clip(star_err_clean, 0.3 * noise_floor, None)

# Non-finite SCI pixels (e.g. NaN where the mosaic has no coverage) would turn
# the fit's residuals into NaN. Zero them and give them a very large noise so
# they carry negligible weight in the fit.
bad_sci = ~np.isfinite(star_sci_bgsub)
if np.any(bad_sci):
    star_sci_bgsub = np.where(bad_sci, 0.0, star_sci_bgsub)
    star_err_clean = np.where(bad_sci, 1e6 * noise_floor, star_err_clean)

print('local_bkg:', float(local_bkg))
print('noise_floor:', float(noise_floor))
print('non-finite SCI pixels masked:', int(bad_sci.sum()))


In [None]:
# Quick-look of the two arrays that will be handed to STARRED:
#   left  - background-subtracted SCI cutout, symmetric-log stretch whose
#           linear core is 5% of the cutout's standard deviation (min 1e-3),
#           clipped to its 5th..99.8th percentiles;
#   right - cleaned ERR map, log stretch between its 1st and 99.5th percentiles.
sci_spread = np.nanstd(star_sci_bgsub)
norm1 = colors.SymLogNorm(
    linthresh=max(sci_spread * 0.05, 1e-3),
    vmin=np.nanpercentile(star_sci_bgsub, 5),
    vmax=np.nanpercentile(star_sci_bgsub, 99.8),
)
err_lo = max(np.nanpercentile(star_err_clean, 1), 1e-6)
err_hi = np.nanpercentile(star_err_clean, 99.5)
norm2 = colors.LogNorm(vmin=err_lo, vmax=err_hi)

panel_specs = (
    (star_sci_bgsub, 'gray', norm1, 'SCI cutout (bg-subtracted)'),
    (star_err_clean, 'magma', norm2, 'ERR cutout (cleaned)'),
)

fig, axes = plt.subplots(1, 2, figsize=(10, 4.5))
for ax, (image, cmap_name, stretch, title) in zip(axes, panel_specs):
    shown = ax.imshow(image, origin='lower', cmap=cmap_name, norm=stretch)
    ax.set_title(title)
    plt.colorbar(shown, ax=ax, fraction=0.046, pad=0.04)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()


In [None]:
# Add a leading "star" axis so the single cutout and its noise map are passed
# as (n_stars, ny, nx) cubes with n_stars == 1 (the layout the original
# notebook states STARRED expects).
stars = np.expand_dims(star_sci_bgsub, axis=0)
stars_noise = np.expand_dims(star_err_clean, axis=0)

# Fit the supersampled PSF; iteration counts come from the config cell.
psf_output = starred.procedures.build_psf(
    stars, stars_noise, supersampling_factor,
    n_iter_analytic=n_iter_analytic, n_iter_adabelief=n_iter_adabelief,
)

# Copy the returned supersampled PSF into a float64 NumPy array.
full_psf = np.asarray(psf_output['full_psf']).astype(np.float64)
print('full_psf shape:', full_psf.shape)


In [None]:
# Bin the supersampled PSF down to detector resolution.
ny_psf, nx_psf = full_psf.shape
ss = supersampling_factor
if ny_psf % ss != 0 or nx_psf % ss != 0:
    raise RuntimeError(f'full_psf shape {full_psf.shape} not divisible by supersampling factor {ss}')

# Each detector pixel is the SUM of its ss x ss supersampled sub-pixels, so the
# total flux of full_psf is preserved.
det_psf_raw = full_psf.reshape(ny_psf // ss, ss, nx_psf // ss, ss).sum(axis=(1, 3))

# Clip negative wings, then normalise to unit sum. Guard the division: an
# all-non-positive or non-finite PSF would otherwise yield inf/NaN silently.
det_psf_clip = np.clip(det_psf_raw, 0, None)
clip_total = det_psf_clip.sum()
if not np.isfinite(clip_total) or clip_total <= 0:
    raise RuntimeError(f'Cannot normalise detector PSF: clipped sum = {clip_total!r}')
det_psf_norm = det_psf_clip / clip_total

print('det_psf_raw shape:', det_psf_raw.shape)
print('det_psf_norm sum:', float(det_psf_norm.sum()))


In [None]:
# Compare the fitted supersampled PSF with its detector-resolution versions
# (raw block sums, and the clipped unit-sum normalisation). Titles are built
# from the actual array shapes and supersampling_factor so they stay correct
# when cut_size or supersampling_factor change in the config cell.
full_shape_label = 'x'.join(map(str, full_psf.shape))
det_shape_label = 'x'.join(map(str, det_psf_raw.shape))

psf_panels = [
    (full_psf, 1e-7, f'FULL_PSF_SUPER{supersampling_factor} ({full_shape_label})'),
    (det_psf_raw, 1e-6, f'DET_PSF_RAW ({det_shape_label})'),
    (det_psf_norm, 1e-8, 'DET_PSF_NORM (sum=1)'),
]

fig, axes = plt.subplots(1, 3, figsize=(14, 4.2))
for ax, (image, linthresh, title) in zip(axes, psf_panels):
    im = ax.imshow(image, origin='lower', cmap='viridis', norm=colors.SymLogNorm(linthresh=linthresh))
    ax.set_title(title)
    plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
    ax.set_xticks([])
    ax.set_yticks([])

plt.tight_layout()
plt.show()


In [None]:
# Write the fit inputs, the fitted supersampled PSF and the derived detector
# PSFs into one multi-extension FITS file (all images stored as float32).
# The supersampled-PSF extension name follows supersampling_factor
# (FULL_PSF_SUPER4 for the default of 4) so it never mislabels the sampling.
full_psf_extname = f'FULL_PSF_SUPER{supersampling_factor}'

hdus = [
    fits.PrimaryHDU(),
    fits.ImageHDU(star_sci_bgsub.astype(np.float32), name='STAR_CUTOUT_SCI'),
    fits.ImageHDU(star_err_clean.astype(np.float32), name='STAR_CUTOUT_ERR'),
    fits.ImageHDU(full_psf.astype(np.float32), name=full_psf_extname),
    fits.ImageHDU(det_psf_raw.astype(np.float32), name='DET_PSF_RAW'),
    fits.ImageHDU(det_psf_norm.astype(np.float32), name='DET_PSF_NORM'),
]

# x0 / y0 are always 0-based (converted in the cutout cell) regardless of
# coords_are_one_based; record that in the card comments so downstream
# readers do not have to guess the convention.
for h in hdus[1:]:
    h.header['X_CENTER'] = (x0, '0-based x pixel of cutout centre in SCI')
    h.header['Y_CENTER'] = (y0, '0-based y pixel of cutout centre in SCI')
    h.header['CUTSIZE'] = (cut_size, 'cutout side length [detector pixels]')

hdus[0].header['COMMENT'] = 'Generated by Build_PSF_starred_x8101_y3465.ipynb'
hdus[0].header['SS_FACT'] = (supersampling_factor, 'PSF supersampling factor')
hdus[0].header['ITERANA'] = (n_iter_analytic, 'STARRED n_iter_analytic')
hdus[0].header['ITERADA'] = (n_iter_adabelief, 'STARRED n_iter_adabelief')

fits.HDUList(hdus).writeto(out_fits, overwrite=True)
print('saved:', out_fits)
