In [None]:
from astropy.table import Table
from astropy.nddata import CCDData, Cutout2D
from astropy.coordinates import SkyCoord
import astropy.units as u
import astropy

import numpy as np
import matplotlib.pyplot as plt

In [None]:
import arya

In [None]:
import tomllib

In [None]:
import sys
sys.path.append("../")
sys.path.append("../../imaging")
import phot_utils
from phot_utils import show_image

In [None]:
import photutils
import photutils.psf

In [None]:
# Load the SExtractor catalogue, background-subtracted image, flag map, weight map and PSF.
imgdir = "../yasone2/img_i_01"
cat = Table.read(imgdir + "/detection.cat", hdu=2)
img = CCDData.read(imgdir + "/nobkg.fits", unit="adu")
mask = CCDData.read(imgdir + "/flag.fits", unit="adu")
img_err = CCDData.read(imgdir + "/flat_fielded.weight.fits", unit="adu")
psf = CCDData.read("../psf_osiris_r.fits", unit="adu")  # per-image alternative: CCDData.read(imgdir + "/psf.fits", unit="adu")

# Display copy with every flagged pixel blanked. For an integer flag map,
# `mask.data > 0` is the same selection as `== 1` OR `>= 2`, and it matches the
# mask used for the PSF fit further down.
# Cast to a floating dtype first: NaN cannot be stored in an integer array.
# float32 and float64 data keep their dtype; astype() always returns a copy.
float_dtype = np.result_type(img.data.dtype, np.float32)
img_masked = img.data.astype(float_dtype)
img_masked[mask.data > 0] = np.nan
img_masked = CCDData(img_masked, unit="adu")

In [None]:
# Median of the map read from flat_fielded.weight.fits.
# NOTE(review): despite the name this is not a sky background level; it is the same
# quantity as `global_rms` below. If the file is a weight (inverse-variance) map
# rather than a 1-sigma error map, neither name is accurate — confirm.
global_bkg = np.median(img_err)

In [None]:
# Inspect the median weight/error-map value.
global_bkg

In [None]:
# Visual check of the reference PSF image (log stretch).
show_image(psf, log=True)

In [None]:
# Flag-masked science image; display limits 150-1500 ADU (15_00 == 1500).
show_image(img_masked, log=True, clim=(150, 15_00))

In [None]:
import photutils.detection

In [None]:
# Global noise estimate used for detection thresholds and display limits below.
# NOTE(review): identical to `global_bkg`; assumes img_err holds per-pixel 1-sigma
# noise, but it is read from a *.weight.fits file — confirm its meaning.
global_rms = np.median(img_err) 

In [None]:
# Quick-look source detection at 5x the global noise level.
# NOTE(review): `fwhm` is in pixels for DAOStarFinder; 0.60 px looks small for
# ground-based seeing — confirm the intended units.
dao_finder = photutils.detection.DAOStarFinder(threshold=5 * global_rms, fwhm=0.60)
sources = dao_finder(img.data)
sources

# Isolating sources

In [None]:
# Observing airmass and filter of this frame, read from the FITS header.
airmass = img.header["airmass"]
filter_keyword = img.header["FILTER2"]
filtname = filter_keyword.split("Sloan_")[1]  # e.g. "Sloan_r" -> "r"

In [None]:
# Zeropoint solution derived from standard-star frame `stdid` in the same filter.
stdid = "11"
std_dir = f"../std1/img_{filtname}_{stdid}"
stdname = f"{std_dir}/flat_fielded-astrom-zeropoint.toml"

In [None]:
# tomllib requires a binary file handle.
with open(stdname, mode="rb") as toml_file:
    std_data = tomllib.load(toml_file)
std_data

In [None]:
# Column of the SExtractor MAG_APER array used for the calibrated magnitude.
ap_idx_best = 3
# Minimum nearest-neighbour separation for "isolated" stars
# (its cut is currently commented out in the quality filter).
r_sep_min = 5 * u.arcsec


In [None]:
def get_zeropoint(filt, exposure=190, gain=1.9):
    """Magnitude offset to add to instrumental SExtractor ``MAG_APER`` values.

    The offset is built from two parts.

    * The standard-star calibration in the global ``std_data``
      (``zeropoint + ap_corr``).
    * Three per-image terms:

      * ``+ 2.5 * log10(exposure)``: exposure normalisation (presumably seconds).
      * ``- extinction``: the first element returned by
        ``phot_utils.get_atm_extinction`` for the global ``airmass`` and filter
        ``"Sloan_<filt>"``.
      * ``- 2.5 * log10(gain)``: gain term (presumably e-/ADU — confirm).

    NOTE(review): ``std_data`` was loaded for the global filter, not for ``filt``.
    The default ``exposure`` is hard-coded rather than read from the header.

    Parameters
    ----------
    filt : str
        Filter suffix such as ``"r"``; prefixed with ``"Sloan_"`` for the lookup.
    exposure : float, default 190
    gain : float, default 1.9

    Returns
    -------
    scalar
        The value to add to the instrumental magnitude.
    """
    extinction = phot_utils.get_atm_extinction(airmass, f"Sloan_{filt}")[0]
    exposure_term = 2.5 * np.log10(exposure)
    gain_term = 2.5 * np.log10(gain)
    calibration = std_data["zeropoint"] + std_data["ap_corr"]
    return calibration + exposure_term - extinction - gain_term

In [None]:
def get_airmass(file):
    """Return the AIRMASS header value of ``flat_fielded.fits`` next to ``file``.

    Parameters
    ----------
    file : str or path-like
        Any file inside the image directory. Only its parent directory is used.

    Returns
    -------
    The ``AIRMASS`` keyword from the header of ``<dir>/flat_fielded.fits``.
    """
    # Path was previously used without ever being imported (NameError).
    from pathlib import Path

    path = Path(file).parent
    # CCDData.read requires a unit when the file has no BUNIT keyword. "adu"
    # matches every other read in this notebook; only the header is used here.
    flat_img = CCDData.read(path / "flat_fielded.fits", unit="adu")
    return flat_img.header["AIRMASS"]

In [None]:
# Calibrated magnitude: chosen aperture magnitude plus the photometric zeropoint.
mag_zeropoint = get_zeropoint(filtname)
cat["MAG"] = cat["MAG_APER"][:, ap_idx_best] + mag_zeropoint

In [None]:
# Sky positions of all catalogue sources. SExtractor ALPHA/DELTA_J2000 are in
# degrees. Passing the unit explicitly keeps this working even if the catalogue
# columns carry no unit metadata.
coords_all = SkyCoord(cat["ALPHA_J2000"], cat["DELTA_J2000"], unit=u.deg)

In [None]:
# Angular distance from each source to its nearest *other* source.
# The catalogue is matched against itself, so nthneighbor=2 skips the self-match.
nn_idx, nn_sep2d, nn_dist3d = coords_all.match_to_catalog_sky(coords_all, nthneighbor=2)
cat["DIST_NN_WORLD"] = nn_sep2d

In [None]:
# `import astropy` alone does not guarantee that astropy.stats is loaded,
# so import what is used explicitly.
from astropy.stats import mad_std, sigma_clip

# Quality cuts defining the PSF-star sample (all combined with logical AND).
# SExtractor FLAGS values 0-14 pass.
# NOTE(review): 15 (= 1|2|4|8) is also rejected, which is not a clean bit test.
# If the intent is "no flag >= 16", use `< 16`.
filt_good = cat["FLAGS"] < 15  # alternative: np.isin(cat["FLAGS"], [1, 2, 3, 4, 0])
# Optional cuts, currently disabled:
# filt_good &= cat["FLAGS_WEIGHT"] == 0
# filt_good &= cat["IMAFLAGS_ISO"] == 0
# filt_good &= cat["MAG"] > 17
# filt_good &= cat["DIST_NN_WORLD"] > r_sep_min
filt_good &= cat["MAG"] < 20
filt_good &= cat["ELLIPTICITY"] < 0.2
# Reject FWHM outliers (3-sigma clip with MAD-based scatter), presumably to drop
# extended sources and blends. getmaskarray always yields a full boolean array.
fwhm_clipped = sigma_clip(cat["FWHM_WORLD"], sigma=3, stdfunc=mad_std)
filt_good &= ~np.ma.getmaskarray(fwhm_clipped)

In [None]:
# Subset of the catalogue passing all quality cuts (candidate PSF stars).
cat_good = cat[filt_good]

In [None]:
# Overlay the selected stars on the masked image.
# The -1 converts SExtractor's 1-based pixel coordinates to 0-based ones.
show_image(img_masked, log=True, clim=(150, 15_00))
plt.scatter(cat_good["X_IMAGE"]-1, cat_good["Y_IMAGE"]-1)

In [None]:
# photutils.psf.extract_stars reads star positions from "x"/"y" columns in 0-based
# pixel coordinates. SExtractor's X_IMAGE/Y_IMAGE are 1-based.
for new_col, sex_col in (("x", "X_IMAGE"), ("y", "Y_IMAGE")):
    cat_good[new_col] = cat_good[sex_col] - 1

In [None]:
# Cut a square stamp around each selected star for ePSF building.
star_cutout_size = (51, 51)
stars = photutils.psf.extract_stars(img, cat_good, size=star_cutout_size)

In [None]:
from astropy.visualization import simple_norm

# Preview up to nrows * ncols of the extracted star stamps.
nrows = 10
ncols = 5
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(20, 20))
ax = ax.ravel()
# Indexing stars[i] for every panel raised IndexError when fewer stars than
# panels were extracted; only draw what exists.
n_show = min(len(stars), nrows * ncols)
for i in range(n_show):
    norm = simple_norm(stars[i], 'log', percent=99.0)
    ax[i].imshow(stars[i], norm=norm, origin='lower', cmap='viridis')
# Blank out any leftover empty panels.
for unused_ax in ax[n_show:]:
    unused_ax.set_axis_off()


In [None]:
# Build the effective PSF from the star stamps (at most 5 iterations).
epsf_builder = photutils.psf.EPSFBuilder(maxiters=5)
epsf, model_stars = epsf_builder(stars)


In [None]:
# Visual check of the fitted ePSF.
show_image(epsf.data, log=True, dpi=100)

In [None]:
# Bundle data, per-pixel uncertainty and bad-pixel mask into one CCDData.
# In astropy NDData a mask value of True marks *invalid* pixels. The old
# `mask.data == 0` flagged every good pixel instead. Flagged pixels (flag > 0,
# as used for the PSF fit) must be True.
# NOTE(review): img_err comes from a *.weight.fits file. If it is a weight /
# inverse-variance map rather than 1-sigma errors, convert it before use.
img_full = CCDData(img.data, uncertainty=img_err.data, mask=mask.data > 0, unit="adu")

In [None]:
# Fitting box (pixels) around each source.
fit_shape = (11, 11)
# Deeper detection (2x the global noise) for the PSF-photometry pass.
finder = photutils.detection.DAOStarFinder(threshold=2 * global_rms, fwhm=0.60)
psf_phot = photutils.psf.PSFPhotometry(
    epsf,
    fit_shape,
    finder=finder,
    aperture_radius=5,
)

In [None]:
# Fit the ePSF to every detection. Pixels with a non-zero flag are excluded.
bad_pixels = mask.data > 0
phot = psf_phot(img.data, mask=bad_pixels, error=img_err.data)

In [None]:
# Inspect the PSF-photometry result table.
phot

In [None]:
# PSFPhotometry `flags` is a bitmask; 0 means a normal fit. The previous
# `flags == 1` cut kept only sources with masked pixels in their fit box and
# discarded every cleanly fitted (flag 0) source.
# Bit 1 (masked pixels in the fit box) is still tolerated, preserving the original
# selection of those sources. Any other bit rejects the source. Set
# PSF_FLAGS_ALLOWED = 0 for a strict cut.
PSF_FLAGS_ALLOWED = 1
phot_good = phot[(phot["flags"] & ~PSF_FLAGS_ALLOWED) == 0]

In [None]:
# Break the integer `flags` bitmask into its individual flag components for
# inspection (see the photutils docs for the exact return format).
photutils.psf.decode_psf_flags(phot["flags"])

In [None]:
# All fitted positions, with the retained subset drawn in a second colour on top.
show_image(img_masked, log=True, clim=(global_rms, 10*global_rms))
plt.scatter(phot["x_fit"], phot["y_fit"])
plt.scatter(phot_good["x_fit"], phot_good["y_fit"])

# NOTE(review): hard-coded limits; presumably the image extent — consider img.data.shape.
plt.xlim(0, 2300)
plt.ylim(0, 2000)

In [None]:
# Data minus the fitted PSF models (for a visual check of fit quality).
resid = psf_phot.make_residual_image(img.data)

In [None]:
from astropy.nddata import NDData

# Model image = data - residual.
# make_residual_image may return an NDData or a plain ndarray depending on its
# input. On an ndarray, `.data` is the raw memoryview buffer rather than the pixel
# array, so unwrap explicitly.
resid_array = resid.data if isinstance(resid, NDData) else np.asarray(resid)
show_image(img_full.data - resid_array, log=True, clim=(0, 10*global_rms))

In [None]:
# Residuals on a symmetric +/-10 sigma diverging scale.
show_image(resid,  clim=(-10*global_rms, 10*global_rms), cmap="RdBu")

In [None]:
# Retained PSF-photometry sources over the masked image.
show_image(img_masked, log=True, clim=(150, 1_500))
# Attach the colorbar *before* scattering. plt.colorbar() uses the most recent
# mappable, and a plain scatter (no `c=`) would otherwise be picked instead of
# the image.
# NOTE(review): assumes show_image draws via imshow on the current axes — confirm.
plt.colorbar()
plt.scatter(phot_good["x_fit"], phot_good["y_fit"])