In [None]:
import tomllib

import matplotlib.pyplot as plt
from astropy.table import Table
import numpy as np
import arya

import astropy.units as u
from astropy.nddata import CCDData
from astropy.nddata import Cutout2D
from astropy.coordinates import SkyCoord
from astropy.stats import mad_std, sigma_clip

In [None]:
import numpy as np

In [None]:
import astropy

In [None]:
import sys
sys.path.append("..")
sys.path.append("../../imaging/")
from phot_utils import to_mag, get_atm_extinction, show_image, swap_byteorder
from photutils.aperture import CircularAperture, SkyCircularAperture



In [None]:
def read_catalogue(filename):
    """Read the object table stored in HDU 2 of a FITS catalogue file.

    NOTE(review): the layout (objects in the second extension) looks like a
    SExtractor FITS_LDAC catalogue -- confirm against the extraction config.

    Parameters
    ----------
    filename : str or path-like
        Path of the catalogue file.

    Returns
    -------
    astropy.table.Table
        The table read from HDU 2.
    """
    catalogue = Table.read(filename, format="fits", hdu=2)
    return catalogue

In [None]:
# Which filter / object / frame to analyse.
filtname = "i"
objname = "yasone2"
objid = "01"   # frame id, used only when `frame` is True
stdid = "11"   # frame id of the std1 zeropoint solution
frame = True   # True: a single reduced frame; False: the stacked coadd

# Single frames live in img_<filter>_<id>; the stack lives in stacked_<filter>.
if frame:
    imgdir = f"../{objname}/img_{filtname}_{objid}"
else:
    imgdir = f"../{objname}/stacked_{filtname}"
stdname = f"../std1/img_{filtname}_{stdid}/flat_fielded-astrom-zeropoint.toml"


In [None]:

# Image to analyse: nobkg.fits (presumably background-subtracted) for a single
# frame, coadd.fits for the stack.
imagename = imgdir + ("/nobkg.fits" if frame else "/coadd.fits")

# Companion products in the same directory.
filename = imgdir + "/detection.cat"  # source catalogue
flagname = imgdir + "/flag.fits"      # flag map; pixels with flag > 0 are later set to NaN

In [None]:
# Zeropoint solution from the std1 frame (keys such as "zeropoint" and
# "ap_corr" are read by get_zeropoint).
with open(stdname, mode="rb") as toml_file:
    std_data = tomllib.load(toml_file)
std_data

In [None]:
img = CCDData.read(imagename, unit="adu")
flag = CCDData.read(flagname, unit="adu")
# Replace every flagged pixel (flag > 0) with NaN. NaNs propagate through plain
# numpy sums/medians downstream, so later steps must tolerate them.
# NOTE(review): assigning NaN needs a float image -- assumes nobkg/coadd are float.
img.data[flag.data > 0] = np.nan
# Filter consistency check between the zeropoint file and the image (disabled):
# assert std_data["filter"] == img.header["filter2"]

In [None]:
# Aperture sizes in pixels. Index i is assumed to line up with column i of the
# catalogue's MAG_APER / FLUX_APER (ap_idx_best indexes both).
# NOTE(review): if the catalogue is from SExtractor, PHOT_APERTURES are
# *diameters*, whereas this array is used as radii by CircularAperture and
# sep.sum_circle -- confirm which is intended.
apertures = np.array([1, 2,3,  5, 7, 10, 20])

In [None]:
# Nominal pixel scale. The WCS pixel-scale matrix (x3600: deg -> arcsec,
# assuming the WCS is in degrees) is displayed alongside, so the hard-coded
# value can be checked against the astrometric solution.
pixel_scale = 0.254 * u.arcsec
img.wcs.pixel_scale_matrix * 3600, pixel_scale

In [None]:
# Aperture used for the photometry: index into `apertures` and MAG_APER/FLUX_APER.
ap_idx_best = 3
ap_best = apertures[ap_idx_best]
ap_best_world = ap_best * pixel_scale
# Minimum nearest-neighbour separation for a star to count as isolated.
r_sep_min = 5 * u.arcsec
ap_best, ap_best_world, r_sep_min

In [None]:
# Aperture sizes in angular units.
apertures_world = apertures * pixel_scale

In [None]:
# Source catalogue of this image (HDU 2 of detection.cat).
tab = read_catalogue(filename)
tab

In [None]:
# Airmass used for the atmospheric-extinction term in get_zeropoint.
# NOTE(review): hard-coded; get_airmass() can read AIRMASS from a frame's
# header instead -- TODO confirm 1.2 matches this image.
airmass = 1.2

In [None]:
def get_zeropoint(filt, exposure=190, gain=1.9):
    """Zeropoint added to the catalogue's MAG_APER values to build the MAG column.

    Sum of:
      * the standard-frame solution, ``std_data["zeropoint"] + std_data["ap_corr"]``;
      * an exposure term, ``+2.5 log10(exposure)``;
      * minus the atmospheric extinction
        ``get_atm_extinction(airmass, f"Sloan_{filt}")[0]``, which uses the
        module-level ``airmass``;
      * minus a gain term, ``2.5 log10(gain)``.

    Parameters
    ----------
    filt : str
        Filter letter, e.g. "i"; looked up as ``f"Sloan_{filt}"``.
    exposure : float
        Presumably the exposure time in seconds -- TODO confirm.
    gain : float
        Presumably the detector gain in e-/ADU -- TODO confirm.
    """
    std_term = std_data["zeropoint"] + std_data["ap_corr"]
    exposure_term = 2.5 * np.log10(exposure)
    extinction = get_atm_extinction(airmass, f"Sloan_{filt}")[0]
    gain_term = 2.5 * np.log10(gain)
    return std_term + exposure_term - extinction - gain_term

In [None]:
def get_airmass(file):
    """Return the AIRMASS header value of the frame that ``file`` belongs to.

    The header is read from ``flat_fielded.fits`` in the same directory as
    ``file``.

    Parameters
    ----------
    file : str or path-like
        Any file inside the frame's reduction directory.

    Returns
    -------
    The value of the ``AIRMASS`` header card.
    """
    # Local import: pathlib is not among the notebook's top-level imports,
    # so the bare name `Path` was undefined here.
    from pathlib import Path

    path = Path(file).parent
    # unit="adu" matches how the science image is read; without it CCDData.read
    # fails on headers that carry no BUNIT keyword.
    img = CCDData.read(path / "flat_fielded.fits", unit="adu")
    return img.header["AIRMASS"]

In [None]:
# Calibrated magnitude in the chosen aperture (column ap_idx_best of MAG_APER).
tab["MAG"] = tab["MAG_APER"][:, ap_idx_best] + get_zeropoint(filtname)

In [None]:
# Sky positions of all detections.
# NOTE(review): SkyCoord needs these columns to carry angle units (deg) --
# otherwise pass unit="deg" explicitly.
coords_all = SkyCoord(tab["ALPHA_J2000"], tab["DELTA_J2000"])

In [None]:
# Separation to the nearest *other* source: nthneighbor=2 skips the self-match,
# and [1] selects the on-sky separation from match_to_catalog_sky's
# (idx, sep2d, dist3d) result.
tab["DIST_NN_WORLD"] = coords_all.match_to_catalog_sky(coords_all, nthneighbor=2)[1]

In [None]:
plt.hist(tab["DIST_NN_WORLD"].to("arcsec"))

In [None]:
# Select round, isolated, reasonably bright sources with a typical FWHM.
# NOTE(review): if FLAGS follows SExtractor's bit convention, `< 8` still keeps
# bit 4 (saturated) -- consider also requiring (tab["FLAGS"] & 4) == 0 for PSF stars.
filt_good = tab["FLAGS"] < 8
# filt_good &= tab["FLAGS_WEIGHT"] == 0
# filt_good &= tab["IMAFLAGS_ISO"] == 0
filt_good &= tab["MAG"] < 21
# filt_good &= tab["MAG"] > 17
filt_good &= tab["ELLIPTICITY"] < 0.2          # nearly round sources
filt_good &= tab["DIST_NN_WORLD"] > r_sep_min   # no neighbour within r_sep_min
# Reject FWHM outliers: 3-sigma clip with a MAD-based scatter estimate.
# getmaskarray always yields a full boolean array, even when nothing was clipped.
filt_good &= ~np.ma.getmaskarray(sigma_clip(tab["FWHM_WORLD"], sigma=3, stdfunc=mad_std))

In [None]:
# Magnitude distribution: all detections vs. the selected stars.
bins = np.linspace(15, 27, 50)
plt.figure()
plt.hist(tab["MAG"], bins)              # all detections
plt.hist(tab["MAG"][filt_good], bins)   # stars passing the selection
plt.yscale("log")
plt.xlabel("magnitude")
plt.ylabel("count")

In [None]:
# Boolean selection mask, one entry per catalogue row.
filt_good

In [None]:
# Catalogue restricted to the selected stars.
good_cat = tab[filt_good]

In [None]:
# Number of selected stars.
np.sum(filt_good)

In [None]:
# Interactive (ipympl) backend for the overview figure that follows.
%matplotlib ipympl

In [None]:
# Star centres as 0-indexed (x, y) pixel positions.
cens = [a for a in zip(good_cat["XWIN_IMAGE"]-1, good_cat["YWIN_IMAGE"]-1)] # adjust from fits to numpy coordinates

In [None]:
# Sky positions of the same 0-indexed centres via the image WCS.
coords = img.wcs.pixel_to_world(good_cat["XWIN_IMAGE"]-1, good_cat["YWIN_IMAGE"]-1)

In [None]:
# Overview: the full image with an r_sep_min circle around each selected star.
fig, ax = plt.subplots(subplot_kw=dict(projection=img.wcs))

show_image(img, ax=ax, fig=fig, log=True)
for coord in coords:
    sky_ap = SkyCircularAperture(coord, r=r_sep_min)
    pix_ap = sky_ap.to_pixel(img.wcs)
    pix_ap.plot(ax=ax, color='red', lw=0.5)
# XWIN/YWIN are 1-indexed (FITS convention); subtract 1 so the markers sit on
# the same 0-indexed pixel grid as the apertures derived from `coords`.
ax.scatter(good_cat["XWIN_IMAGE"] - 1, good_cat["YWIN_IMAGE"] - 1, s=0.1, edgecolor="red")


In [None]:
# Back to static inline figures for the rest of the notebook.
%matplotlib inline

In [None]:
# Close the current figure and open an empty one.
plt.close()
plt.figure();

In [None]:
# Median FWHM of the selected stars, in arcsec.
fwhm = np.median(good_cat["FWHM_WORLD"].to("arcsec")) 
fwhm

In [None]:
# Upper FWHM bound: median + 3 x MAD-based sigma (not used further below).
fwhm_max = 3 * astropy.stats.mad_std(good_cat["FWHM_WORLD"].to("arcsec")) + fwhm
fwhm_max

In [None]:
# FWHM distribution with the median marked.
plt.hist(good_cat["FWHM_WORLD"].to("arcsec"))
plt.axvline(fwhm / u.arcsec, color=arya.COLORS[1])

In [None]:
# FWHM of each selected star at its detector position, to spot PSF variation.
fwhm_arcsec = good_cat["FWHM_WORLD"].to("arcsec") / u.arcsec
s = plt.scatter(good_cat["X_IMAGE"], good_cat["Y_IMAGE"], c=fwhm_arcsec)
plt.colorbar(s, label="FWHM / arcsec")

In [None]:
def plot_pixel_cutouts(
    ccd,
    centers_xy,
    cutout_size=10 * u.arcsec,
    aperture_radius=ap_best,
):
    """
    Plot one pixel-coordinate cutout per centre, with a circular aperture overlaid.

    Each cutout is drawn in its own figure, which is shown immediately.

    Parameters
    ----------
    ccd : astropy.nddata.CCDData
        Input image. Its ``wcs`` is handed to Cutout2D, which needs it when
        ``cutout_size`` is an angular Quantity.
    centers_xy : iterable of (x, y)
        Pixel coordinates of the centres (0-indexed).
    cutout_size : Quantity or int
        Size of the square cutout. A Quantity is converted to pixels by
        Cutout2D via the WCS.
    aperture_radius : float
        Aperture radius in pixels. Defaults to ``ap_best`` as it was when this
        function was defined.
    """
    for i, (x, y) in enumerate(centers_xy):
        cutout = Cutout2D(
            data=ccd.data,
            position=(x, y),
            size=cutout_size,
            wcs=ccd.wcs,
        )

        # Source position *within* the cutout. This differs from the cutout's
        # geometric centre when the cutout was trimmed at an image edge.
        x_c, y_c = cutout.to_cutout_position([x, y])

        aperture = CircularAperture((x_c, y_c), r=aperture_radius)

        fig, ax = plt.subplots()
        # origin="lower" makes y increase upwards, as in the WCS-projected
        # cutouts (WCSAxes.imshow defaults to origin="lower").
        im = ax.imshow(cutout.data, origin="lower", norm="asinh", vmin=1)

        aperture.plot(ax=ax, color='red', lw=1.5)
        ax.scatter(x_c, y_c, color="red")

        ax.set_title(f'Cutout {i}')
        ax.set_xlabel('x [pix]')
        ax.set_ylabel('y [pix]')

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        fig.tight_layout()
        plt.show()

In [None]:
def plot_sky_cutouts(ccd, centers, size=10*u.arcsec, aperture_radius=ap_best_world):
    """
    Show a WCS-projected cutout around each sky position, one figure per centre.

    A SkyCircularAperture of radius ``aperture_radius`` and a marker are drawn
    at every centre.

    Parameters
    ----------
    ccd : astropy.nddata.CCDData
        Image with a WCS.
    centers : iterable of astropy.coordinates.SkyCoord
        Cutout centres.
    size : astropy.units.Quantity
        Side length of the square cutout.
    aperture_radius : astropy.units.Quantity
        Angular aperture radius. Defaults to ``ap_best_world`` as it was when
        this function was defined.
    """
    for idx, centre in enumerate(centers):
        stamp = Cutout2D(data=ccd.data, position=centre, size=size, wcs=ccd.wcs)
        stamp_wcs = stamp.wcs

        fig, ax = plt.subplots(subplot_kw={'projection': stamp_wcs}, figsize=(4, 4))
        im = ax.imshow(stamp.data, norm="asinh", vmin=1)

        # Aperture defined on the sky, projected onto this stamp's pixel grid.
        sky_aperture = SkyCircularAperture(centre, r=aperture_radius)
        sky_aperture.to_pixel(stamp_wcs).plot(ax=ax, color='red', lw=1.5)
        ax.scatter(*stamp_wcs.world_to_pixel(centre), color="red")

        ax.set_title(f'Cutout {idx}')
        ax.set_xlabel('RA')
        ax.set_ylabel('Dec')

        fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
        plt.tight_layout()
        plt.show()


In [None]:
# Inspect the first selected stars on the sky...
plot_sky_cutouts(img, coords[:10])

In [None]:
# ...and in pixel space with 40-pixel cutouts.
# NOTE(review): cens[1:10] starts at index 1, unlike coords[:10] above -- intentional?
plot_pixel_cutouts(img, cens[1:10], cutout_size=40)

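## Stacked PSF

Build an empirical PSF: cut out the selected stars, normalise each cutout by its catalogue aperture flux, and median-combine the full-size stamps.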

In [None]:
from convenience_functions import combine_images

In [None]:
def get_cutouts(
    ccd,
    centers_xy,
    cutout_size=101,
):
    """Cut a square stamp around every centre.

    Parameters
    ----------
    ccd : astropy.nddata.CCDData
        Source image; its WCS is attached to each cutout.
    centers_xy : iterable of (x, y)
        Pixel centres (0-indexed).
    cutout_size : int or Quantity
        Side length passed to Cutout2D. Stamps touching the image edge are
        trimmed (Cutout2D's default mode), so they can be smaller.

    Returns
    -------
    list of astropy.nddata.Cutout2D
        One cutout per centre, in input order.
    """
    return [
        Cutout2D(data=ccd.data, position=(x, y), size=cutout_size, wcs=ccd.wcs)
        for x, y in centers_xy
    ]

In [None]:
# Stamps around every selected star (get_cutouts' default size of 101 px;
# stamps at the image edge come out trimmed and smaller).
cutouts = get_cutouts(img, cens)


In [None]:
import sep

In [None]:
from photutils.centroids import centroid_com
from scipy.ndimage import shift


In [None]:
def recentre_cutout(data, target_xy=None, order=3, centroid_func=None):
    """
    Shift a cutout so that its centroid lands on ``target_xy``.

    Parameters
    ----------
    data : 2D ndarray
        Cutout image.
    target_xy : tuple, optional
        Desired (x, y) pixel location. Defaults to the image centre,
        ((nx - 1) / 2, (ny - 1) / 2).
    order : int
        Spline interpolation order passed to ``scipy.ndimage.shift``
        (3 = cubic).
    centroid_func : callable, optional
        Function mapping a 2D array to its (x, y) centroid. Defaults to
        photutils' ``centroid_com`` (flux-weighted centre of mass).

    Returns
    -------
    shifted : 2D ndarray
        Recentred copy of ``data``; the input array is not modified. Pixels
        shifted in from outside the cutout are 0, and non-finite input pixels
        are treated as 0 before shifting.
    """
    if centroid_func is None:
        centroid_func = centroid_com

    ny, nx = data.shape

    if target_xy is None:
        target_xy = ((nx - 1) / 2, (ny - 1) / 2)

    x_c, y_c = centroid_func(data)

    dx = target_xy[0] - x_c
    dy = target_xy[1] - y_c

    # Spline prefiltering (order > 1) spreads a single NaN over the whole
    # output, so replace non-finite pixels (e.g. flagged ones) with the same
    # fill value used outside the cutout.
    finite_data = np.where(np.isfinite(data), data, 0.0)

    # ndimage shifts in array-axis order: axis 0 is y (rows), axis 1 is x (cols).
    shifted = shift(
        finite_data,
        shift=(dy, dx),
        order=order,
        mode='constant',
        cval=0.0,
    )

    return shifted


In [None]:
# Normalise each cutout by the star's catalogue flux in the photometry
# aperture, so stars of different brightness can be stacked.
cutouts_normalized = [
    CCDData(cutout.data / flux, unit="adu")
    for cutout, flux in zip(cutouts, good_cat["FLUX_APER"][:, ap_idx_best])
]

# Keep only untrimmed stamps: cutouts clipped at the image edge are smaller and
# cannot be stacked. The full stamp shape is the largest one present, which
# avoids hard-coding the size used by get_cutouts.
full_shape = max((c.shape for c in cutouts_normalized), key=np.prod, default=None)
cutouts_good = [cutout for cutout in cutouts_normalized if cutout.shape == full_shape]

In [None]:
# Centroid-recentred versions of the normalised cutouts (not used further below).
cutouts_centred = [
    CCDData(recentre_cutout(im.data), unit="adu") for im in cutouts_normalized
]
    

In [None]:
# Full-size normalised cutouts that go into the stack.
cutouts_good

In [None]:
# Median stack via combine_images from the local convenience_functions module.
median_psf = combine_images(cutouts_good, method="median")

In [None]:
# Cross-check stack with plain numpy. nanmedian ignores the NaNs carried by
# flagged pixels (img.data was NaN-masked); np.median would make any stacked
# pixel NaN as soon as one star had a flagged pixel there.
median_psf2 = np.nanmedian(np.stack([cutout.data for cutout in cutouts_good]), axis=0)

In [None]:
show_image(median_psf, dpi=100, clim=(None, None), log=True)

In [None]:
# Save the stacked PSF into the image directory.
CCDData(median_psf).write(imgdir + "/psf.fits", overwrite=True)

In [None]:
# Curve of growth of the stacked PSF, measured at the stamp centre and
# expressed in magnitudes relative to aperture index 5.
# sep.sum_circle takes x (column, shape[1]) before y (row, shape[0]).
ny_psf, nx_psf = median_psf.shape
psf_fluxes = sep.sum_circle(median_psf.data, [(nx_psf - 1) / 2], [(ny_psf - 1) / 2], apertures)[0]
# to_mag comes from phot_utils; element [0] is used as the magnitude array.
psf_mags = to_mag(psf_fluxes, 0, 0, 0)[0]
psf_mags -= psf_mags[5]

In [None]:
# Same curve of growth for the numpy-median stack.
# sep.sum_circle takes x (column, shape[1]) before y (row, shape[0]).
ny_psf2, nx_psf2 = median_psf2.shape
psf_fluxes2 = sep.sum_circle(median_psf2, [(nx_psf2 - 1) / 2], [(ny_psf2 - 1) / 2], apertures)[0]
psf_mags2 = to_mag(psf_fluxes2, 0, 0, 0)[0]
psf_mags2 -= psf_mags2[5]

In [None]:
# Isolation radius; used as the x-limit of the next plot.
r_sep_min

In [None]:
# Curves of growth of the two stacked PSFs, out to the isolation radius.
fig, ax = plt.subplots()
ax.plot(apertures, psf_mags, label="combine_images median")
ax.plot(apertures, psf_mags2, label="numpy median")
ax.set_ylim(2, -0.5)  # magnitudes: brighter (more negative) at the top
# r_sep_min / pixel_scale is a dimensionless Quantity; pass a plain float.
ax.set_xlim(0, (r_sep_min / pixel_scale).to_value(u.dimensionless_unscaled))
ax.set_xlabel("aperture radius / pixel")
ax.set_ylabel("magnitude relative to aperture index 5")
ax.legend()

In [None]:
# Catalogue curves of growth: each star's MAG_APER minus its aperture-5 magnitude
# (reshape(-1, 1) broadcasts the reference across all apertures of a row).
profiles = good_cat["MAG_APER"] - good_cat["MAG_APER"][:, 5].reshape(-1, 1)

In [None]:
# Median curve of growth over the selected stars.
profile_median = np.median(profiles, axis=0)

In [None]:
# Curves of growth measured with sep on each cutout, centred on the star's
# position inside the cutout, relative to aperture index 5.
# (Overwritten by the full-image measurement in a later cell.)
profiles_sep = []
profiles_sep_flux = []

for cutout, centre in zip(cutouts, cens):
    # to_cutout_position accounts for stamps trimmed at the image edge, where
    # the star is not at the cutout's geometric centre (center_cutout).
    x, y = cutout.to_cutout_position(centre)

    fluxes = sep.sum_circle(swap_byteorder(cutout.data), [x], [y], apertures)[0]
    mags = to_mag(fluxes, 0, 0, 0)[0]
    mags -= mags[5]
    profiles_sep.append(mags)
    profiles_sep_flux.append(fluxes)
    

In [None]:
# swap_byteorder (phot_utils) presumably yields the native byte order sep
# needs; compute it once instead of twice.
img_native = swap_byteorder(img.data)
# Flagged pixels are NaN; mask them so they stay out of the background model.
img_nobkg = img_native - sep.Background(
    img_native, mask=~np.isfinite(img_native), bw=128, bh=128, fw=6, fh=6
)

In [None]:
# Curves of growth measured with sep on the background-subtracted full image
# at each star's 0-indexed centre, relative to aperture index 5.
# This replaces the cutout-based profiles_sep computed above.
profiles_sep = []
profiles_sep_flux = []

for x, y in cens:
    fluxes = sep.sum_circle(img_nobkg, [x], [y], apertures)[0]
    mags = to_mag(fluxes, 0, 0, 0)[0]
    mags -= mags[5]
    profiles_sep.append(mags)
    profiles_sep_flux.append(fluxes)
    

In [None]:
# Aperture sizes (pixels) used as the x-axis of the profile plot below.
apertures

In [None]:
# Each star's catalogue curve of growth (faint lines), with the median profile
# and the stacked-PSF curve on top.
# NOTE(review): if MAG_APER comes from SExtractor PHOT_APERTURES (diameters),
# its x positions are not directly comparable with the stacked-PSF curve, which
# treats `apertures` as radii.
plt.figure()

for star_profile in good_cat["MAG_APER"]:
    plt.plot(apertures, star_profile - star_profile[5], alpha=0.2, color=arya.COLORS[0])

plt.ylim(2, -1)
plt.xlabel("aperture size / pixel")
plt.ylabel("relative aperture magnitude")

plt.plot(apertures, profile_median, color=arya.COLORS[1], label="median profile")
plt.plot(apertures, psf_mags, color="black", label="stacked psf")

plt.legend()