In [1]:
import os
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits
from scipy import ndimage as ndi
from scipy.spatial import cKDTree



RAW_FITS = "mosaic.fits"                      # raw mosaic; not referenced elsewhere in this script
CLEAN_FITS = "mosaic_cleaned_enhanced.fits"   # cleaned mosaic read by the __main__ block

# All outputs go to ./outputs under the current working directory.
# NOTE: the directory is created as a side effect whenever this module is executed.
OUT_DIR = os.path.join(os.getcwd(), "outputs")
os.makedirs(OUT_DIR, exist_ok=True)

OUT_OVERLAY_JPG   = os.path.join(OUT_DIR, "detections_overlay_knn.jpg")
OUT_LOGN_ALL_JPG  = os.path.join(OUT_DIR, "logN_all_stars_gals.jpg")
OUT_LOGN_GALFIT_JPG = os.path.join(OUT_DIR, "logN_galaxies_weightedfit.jpg")

# detector
BAD_VALUE = 3421.0      # missing-pixel sentinel; pixels equal to it are ignored everywhere
SAT_LEVEL = 36000.0     # pixels at/above this are excluded from the detection mask (not from photometry)

# photometric calibration
ZP = 25.3               # zero point [mag]
K_EXT = None            # extinction coefficient; m -= K_EXT * AIRMASS is applied only when both are set
AIRMASS = None

# detection
DETECT_SIGMA = 5.0      # threshold = robust background + DETECT_SIGMA * robust sigma
MIN_AREA = 6            # minimum connected-group size [pixels]

# kNN dense-halo removal
KNN_K = 10              # crowding = distance to the KNN_K-th nearest other detection
CROWDED_FRACTION = 0.09 # fraction of detections (smallest k-th-neighbour distance) treated as crowded
GRID = 10               # density-grid cell size [pixels]
HOT_PERCENTILE = 90.0   # cells with counts >= this percentile of occupied-cell counts are "hot"
DILATE_CELLS = 5        # binary-dilation iterations (in grid cells) applied to the chosen hot region
MIN_REGION_CELLS = 10   # if the densest hot component is smaller, only the crowded points themselves are dropped

# aperture photometry: R_AP = 1.5 / PIX_SCALE pixels, i.e. a 3-unit aperture diameter in
# PIX_SCALE's angular unit (presumably arcsec per pixel, giving a 3" aperture -- confirm)
PIX_SCALE = 0.258
R_AP = 1.5 / PIX_SCALE
R_IN = R_AP + 3.0       # sky-annulus inner radius [pixels]
R_OUT = R_AP + 8.0      # sky-annulus outer radius [pixels]
EDGE_PAD = int(np.ceil(R_OUT + 2))  # positions closer than this to the border are rejected

# star / galaxy separation
R_SMALL = 2.0           # small-aperture radius [pixels] for the concentration index C = m_small - m_big
STAR_C_MAX = 0.8        # star if C <= STAR_C_MAX ...
STAR_RE_MAX = 3.8       # ... and equivalent-circle radius r_eq <= STAR_RE_MAX [pixels]

# logN(m)
DM = 0.25               # magnitude grid step
FIT_M_MIN = 11.0        # magnitude window used for the galaxy logN fit
FIT_M_MAX = 16.5

# overlay
CIRCLE_RADIUS_PX = 6    # overlay circle size, named as a radius in image pixels
CIRCLE_LW = 0.6         # overlay circle line width
DOWNSAMPLE = 2          # display decimation: every DOWNSAMPLE-th pixel is shown
P_LOW, P_HIGH = 5.0, 99.7  # display-stretch percentiles



# Basic functions

def robust_bg_sigma(image, bad_value=BAD_VALUE):
    """Estimate the sky level and noise of ``image`` robustly.

    Pixels that are non-finite or equal to ``bad_value`` are ignored. The
    level is the median of the remaining pixels. The noise is the median
    absolute deviation scaled by 1.4826 (Gaussian-equivalent sigma), or the
    plain standard deviation when the MAD is zero.

    Returns
    -------
    (bg, sig) : tuple of float
    """
    flat = image.ravel()
    usable = flat[np.isfinite(flat) & (flat != bad_value)]

    center = float(np.median(usable))
    abs_dev_median = float(np.median(np.abs(usable - center)))

    if abs_dev_median > 0:
        spread = float(1.4826 * abs_dev_median)
    else:
        # MAD collapses to zero on very flat data; fall back to the std.
        spread = float(np.std(usable))
    return center, spread


def stretch_for_display(img):
    """Return ``(lo, hi)`` grey-scale limits for showing ``img``.

    The limits are the P_LOW and P_HIGH percentiles of the pixels that are
    finite and not BAD_VALUE. If ``hi`` does not exceed ``lo`` it is replaced
    by ``lo + 1.0`` so the display range is never empty.
    """
    usable = img[np.isfinite(img) & (img != BAD_VALUE)]
    lo, hi = (float(np.percentile(usable, pct)) for pct in (P_LOW, P_HIGH))
    hi = lo + 1.0 if hi <= lo else hi
    return lo, hi


# Detection and kNN halo filter

def detect_sources(img):
    """Detect sources as 8-connected groups of pixels above a sigma threshold.

    A pixel is a candidate when it is finite, not BAD_VALUE, at or above
    ``bg + DETECT_SIGMA * sig`` (bg, sig from robust_bg_sigma) and strictly
    below SAT_LEVEL. Groups of fewer than MIN_AREA pixels are discarded, as
    are groups whose centroid lies within EDGE_PAD pixels of the border.

    Returns
    -------
    cat : list of dict
        One entry per kept group. Keys:
        ``id`` -- label number;
        ``x``/``y`` -- background-subtracted, flux-weighted centroid;
        ``x_peak``/``y_peak`` -- brightest pixel;
        ``area`` -- size in pixels;
        ``r_eq`` -- radius of the circle with the same area.
    info : dict
        ``bg``, ``sigma``, ``thr``, ``n_labels`` (all connected groups) and
        ``n_cat`` (entries kept).
    """
    bg, sig = robust_bg_sigma(img, BAD_VALUE)
    thr = bg + DETECT_SIGMA * sig

    usable = np.isfinite(img) & (img != BAD_VALUE)
    candidate = usable & (img >= thr) & (img < SAT_LEVEL)
    labels, nlab = ndi.label(candidate, structure=np.ones((3, 3), int))
    ny, nx = img.shape

    cat = []
    for lab_id, box in enumerate(ndi.find_objects(labels), start=1):
        if box is None:
            continue
        member = labels[box] == lab_id
        area = int(member.sum())
        if area < MIN_AREA:
            continue

        row0, col0 = box[0].start, box[1].start
        cutout = img[box]

        # Brightest pixel among this group's own pixels only.
        masked = np.where(member, cutout, -np.inf)
        py, px = np.unravel_index(int(np.argmax(masked)), masked.shape)
        y_peak, x_peak = row0 + py, col0 + px

        # Background-subtracted, non-negative centroid weights.
        weights = np.clip(cutout[member] - bg, 0, None)
        if weights.sum() > 0:
            rows, cols = np.nonzero(member)
            y_cent = float(row0 + np.sum(rows * weights) / np.sum(weights))
            x_cent = float(col0 + np.sum(cols * weights) / np.sum(weights))
        else:
            y_cent, x_cent = float(y_peak), float(x_peak)

        near_edge = any((
            x_cent < EDGE_PAD, x_cent > nx - 1 - EDGE_PAD,
            y_cent < EDGE_PAD, y_cent > ny - 1 - EDGE_PAD,
        ))
        if near_edge:
            continue

        cat.append(dict(
            id=lab_id,
            x=x_cent,
            y=y_cent,
            x_peak=int(x_peak),
            y_peak=int(y_peak),
            area=area,
            r_eq=float(np.sqrt(area / np.pi)),
        ))

    info = dict(bg=bg, sigma=sig, thr=thr, n_labels=int(nlab), n_cat=len(cat))
    return cat, info


def build_region_from_points(xs, ys, shape, grid=GRID,
                             hot_percentile=HOT_PERCENTILE,
                             dilate_cells=DILATE_CELLS,
                             min_region_cells=MIN_REGION_CELLS):
    """Turn a dense cloud of (x, y) positions into a boolean pixel mask.

    Steps:
    1. Bin the points into ``grid`` x ``grid``-pixel cells.
    2. Mark as "hot" every cell whose count is at least the
       ``hot_percentile``-th percentile of the occupied-cell counts.
    3. Keep the 8-connected hot component that contains the most points;
       ties go to the lowest label.
    4. Reject it if it spans fewer than ``min_region_cells`` cells.
    5. Dilate it by ``dilate_cells`` cells.
    6. Expand it back to a ``shape``-sized (ny, nx) pixel mask.

    Returns None when there are no points, no hot component, or the chosen
    component is too small.
    """
    ny, nx = shape
    n_cols = int(np.ceil(nx / grid))
    n_rows = int(np.ceil(ny / grid))

    col = np.clip((xs / grid).astype(int), 0, n_cols - 1)
    row = np.clip((ys / grid).astype(int), 0, n_rows - 1)

    counts = np.zeros((n_rows, n_cols), dtype=np.int32)
    np.add.at(counts, (row, col), 1)

    occupied = counts[counts > 0]
    if occupied.size == 0:
        return None

    hot = counts >= float(np.percentile(occupied, hot_percentile))
    comp_labels, n_comp = ndi.label(hot, structure=np.ones((3, 3), int))
    if n_comp == 0:
        return None

    # Points per component; argmax keeps the first maximum.
    scores = ndi.sum(counts, labels=comp_labels, index=np.arange(1, n_comp + 1))
    region = comp_labels == (int(np.argmax(scores)) + 1)

    if int(region.sum()) < min_region_cells:
        return None

    if dilate_cells and dilate_cells > 0:
        region = ndi.binary_dilation(region, iterations=int(dilate_cells))

    # Blow each cell up to grid x grid pixels, then crop the ceil() overhang.
    pixels = np.repeat(np.repeat(region, grid, axis=0), grid, axis=1)
    return pixels[:ny, :nx]


def knn_filter(cat, shape):
    """Remove detections lying in the densest crowded patch of the image.

    Crowding is measured by each detection's distance to its KNN_K-th
    nearest neighbour. The query asks for KNN_K + 1 neighbours among the
    same points, so the first hit is normally the point itself.

    The CROWDED_FRACTION of detections with the smallest such distance are
    passed to build_region_from_points. Every detection whose truncated
    pixel position falls inside the returned region is dropped. If no
    region is built, only those crowded detections are dropped.

    Empty catalogues, and catalogues with fewer than KNN_K + 2 entries, are
    returned unchanged (the same list object).

    Parameters
    ----------
    cat : list of dict
        Detections with ``x`` and ``y`` keys, e.g. from detect_sources.
    shape : tuple
        (ny, nx) of the image the detections came from.
    """
    n = len(cat)
    if n == 0:
        return cat

    coords = np.array([(o["x"], o["y"]) for o in cat], dtype=float)
    if n < KNN_K + 2:
        return cat

    dist, _ = cKDTree(coords).query(coords, k=min(KNN_K + 1, n))
    kth = dist[:, -1]

    n_crowded = max(1, int(np.floor(CROWDED_FRACTION * n)))
    crowded = np.argpartition(kth, n_crowded - 1)[:n_crowded]

    region = build_region_from_points(
        coords[crowded, 0], coords[crowded, 1], shape,
        grid=GRID,
        hot_percentile=HOT_PERCENTILE,
        dilate_cells=DILATE_CELLS,
        min_region_cells=MIN_REGION_CELLS,
    )

    if region is None:
        keep = np.ones(n, dtype=bool)
        keep[crowded] = False
    else:
        col = np.clip(coords[:, 0].astype(int), 0, shape[1] - 1)
        row = np.clip(coords[:, 1].astype(int), 0, shape[0] - 1)
        keep = ~region[row, col]

    return [obj for obj, kept in zip(cat, keep) if kept]



# Aperture photometry and classification

def aperture_photometry(img, x0, y0, r_ap, r_in, r_out):
    """Circular-aperture photometry with a median sky from a local annulus.

    Non-finite and BAD_VALUE pixels are ignored in both the aperture
    (r <= ``r_ap``) and the annulus (``r_in`` <= r <= ``r_out``). Radii are in
    pixels, measured from (x0, y0) to pixel centres.

    Returns None when:
    - the position lies within EDGE_PAD of the border;
    - fewer than 10 aperture or 30 annulus pixels are usable; or
    - the sky-subtracted flux is not finite and positive.

    Otherwise returns a dict with:
    - ``flux_net``: aperture sum minus median sky times the aperture pixel
      count;
    - ``flux_err``: sky-noise-only error, i.e. per-pixel sky scatter plus the
      uncertainty of the sky level, with no source Poisson term;
    - ``bkg``: median sky;
    - ``Nap``, ``Nann``: usable pixel counts;
    - ``sigma_sky``: 1.4826 * MAD of the annulus, or its std if the MAD is 0.
    """
    ny, nx = img.shape
    if any((x0 < EDGE_PAD, x0 > nx - 1 - EDGE_PAD,
            y0 < EDGE_PAD, y0 > ny - 1 - EDGE_PAD)):
        return None

    xc, yc = float(x0), float(y0)
    half = int(np.ceil(r_out + 2))
    ix, iy = int(np.floor(xc)), int(np.floor(yc))
    xlo, xhi = max(0, ix - half), min(nx - 1, ix + half)
    ylo, yhi = max(0, iy - half), min(ny - 1, iy + half)

    cut = img[ylo:yhi + 1, xlo:xhi + 1]
    gy, gx = np.ogrid[ylo:yhi + 1, xlo:xhi + 1]
    d2 = (gx - xc)**2 + (gy - yc)**2

    usable = np.isfinite(cut) & (cut != BAD_VALUE)
    in_ap = (d2 <= r_ap**2) & usable
    in_ann = (d2 >= r_in**2) & (d2 <= r_out**2) & usable

    n_ap = int(in_ap.sum())
    n_ann = int(in_ann.sum())
    if n_ap < 10 or n_ann < 30:
        return None

    sky = cut[in_ann]
    bkg = float(np.median(sky))
    flux_net = float(float(cut[in_ap].sum()) - bkg * n_ap)
    if not (np.isfinite(flux_net) and flux_net > 0):
        return None

    # Robust sky scatter (MAD scaled to a Gaussian sigma); std as a fallback.
    sky_mad = np.median(np.abs(sky - np.median(sky)))
    sigma_sky = float(1.4826 * sky_mad) if sky_mad > 0 else float(np.std(sky))

    # Per-pixel sky noise over the aperture plus the error of the sky level
    # (sigma / sqrt(N_annulus)) scaled to the aperture pixel count.
    flux_err = float(np.sqrt(n_ap * sigma_sky**2 + (n_ap * (sigma_sky / np.sqrt(n_ann)))**2))

    return dict(
        flux_net=flux_net,
        flux_err=flux_err,
        bkg=bkg,
        Nap=n_ap,
        Nann=n_ann,
        sigma_sky=sigma_sky,
    )


def flux_to_cal_mag(F):
    """Convert a net flux to a calibrated magnitude.

    m = ZP - 2.5 * log10(F). When both K_EXT and AIRMASS are set,
    K_EXT * AIRMASS is subtracted as well. ``F`` is expected to be positive;
    callers only pass fluxes that aperture_photometry accepted.
    """
    mag = ZP - 2.5 * np.log10(F)
    apply_extinction = K_EXT is not None and AIRMASS is not None
    if apply_extinction:
        mag -= float(K_EXT) * float(AIRMASS)
    return float(mag)


def classify_star_galaxy(img, cat):
    """
    Measure magnitudes and split detections into stars and galaxies.

    Each source gets photometry in two apertures, R_AP and R_SMALL, both
    sharing the R_IN..R_OUT sky annulus. The concentration index is
    C = m_small - m_big. A source with usable photometry in both apertures
    is a star when C <= STAR_C_MAX and its equivalent radius r_eq
    <= STAR_RE_MAX; every other usable source is a galaxy.

    Returns
    -------
    m_big : ndarray
        Large-aperture magnitude; NaN where photometry failed.
    m_err : ndarray
        1.0857362 * flux_err / flux_net (i.e. 2.5/ln10 times the relative
        flux error); NaN where photometry failed.
    C : ndarray
        Concentration index; NaN where photometry failed.
    is_star, is_gal : ndarray of bool
        Disjoint masks over usable sources.
    """
    n_obj = len(cat)
    mag = np.full(n_obj, np.nan, float)
    mag_err = np.full(n_obj, np.nan, float)
    conc = np.full(n_obj, np.nan, float)
    radius = np.array([src["r_eq"] for src in cat], float)

    for idx, src in enumerate(cat):
        big = aperture_photometry(img, src["x"], src["y"], R_AP, R_IN, R_OUT)
        small = aperture_photometry(img, src["x"], src["y"], R_SMALL, R_IN, R_OUT)
        if big is None or small is None:
            continue

        m_b = flux_to_cal_mag(big["flux_net"])
        m_s = flux_to_cal_mag(small["flux_net"])
        mag[idx] = m_b
        mag_err[idx] = 1.0857362 * (big["flux_err"] / big["flux_net"])
        conc[idx] = float(m_s - m_b)

    usable = np.isfinite(mag) & np.isfinite(conc) & np.isfinite(radius)

    is_star = np.zeros(n_obj, dtype=bool)
    is_star[usable] = (conc[usable] <= STAR_C_MAX) & (radius[usable] <= STAR_RE_MAX)
    is_gal = usable & ~is_star

    return mag, mag_err, conc, is_star, is_gal



# log N(m) and weighted fit

def logN_curve(mags, dm=DM):
    """Cumulative source counts N(<= m) on a regular magnitude grid.

    Parameters
    ----------
    mags : array_like
        Magnitudes; non-finite values are ignored.
    dm : float
        Grid spacing in magnitudes; must be positive.

    Returns
    -------
    m_grid : ndarray or None
        Runs from floor(min/dm)*dm to ceil(max/dm)*dm in steps of dm.
    logN : ndarray or None
        log10(N), with NaN where N == 0.
    N : ndarray or None
        N[i] is the number of magnitudes <= m_grid[i].

    All three are None when fewer than 5 finite magnitudes are supplied.

    Raises
    ------
    ValueError
        If ``dm`` is not positive.
    """
    if not dm > 0:
        raise ValueError(f"dm must be positive, got {dm!r}")

    mags = np.asarray(mags, float)
    mags = np.sort(mags[np.isfinite(mags)])
    if mags.size < 5:
        return None, None, None

    m_min = np.floor(mags[0] / dm) * dm
    m_max = np.ceil(mags[-1] / dm) * dm

    # Build the grid from an integer bin count. np.arange with a float step
    # (and a float stop of m_max + dm) can gain an extra element through
    # rounding, e.g. for dm = 0.1.
    n_bins = int(np.rint((m_max - m_min) / dm)) + 1
    m_grid = m_min + dm * np.arange(n_bins)

    # mags is sorted, so searchsorted(side="right") counts entries <= m.
    N = np.searchsorted(mags, m_grid, side="right")
    logN = np.full(m_grid.shape, np.nan, float)
    positive = N > 0
    logN[positive] = np.log10(N[positive])
    return m_grid, logN, N


def weighted_logN_fit(m_grid, logN, N, m_min_fit, m_max_fit, Nmin=20, frac_max=0.95):
    """Fit logN = a*m + b by Poisson-weighted least squares over a window.

    A bin enters the fit only if all of these hold:
    - logN is finite;
    - N >= Nmin;
    - m_min_fit <= m <= m_max_fit;
    - N <= frac_max * Nmax, where Nmax is the largest N among bins passing
      the first three cuts (or the largest N overall if none pass).

    With frac_max < 1 the last cut always drops the highest-N bin in the
    window.

    Each bin is weighted by 1/sigma_y**2, with sigma_y = 1/(ln10*sqrt(N)),
    the Poisson error on log10 N. If N is cumulative (as produced by
    logN_curve), neighbouring bins are correlated, so sa/sb likely
    understate the true uncertainty.

    Returns
    -------
    dict or None
        Keys:
        - ``a``, ``b``: slope and intercept;
        - ``sa``, ``sb``: 1-sigma errors from the inverse weighted normal
          matrix;
        - ``m_fit``, ``y_fit``, ``sigma_y``: the bins used.
        None if m_grid is None or fewer than 3 bins survive the cuts.
    """
    if m_grid is None:
        return None

    mag = np.asarray(m_grid)
    y_all = np.asarray(logN)
    counts = np.asarray(N)

    sel = (np.isfinite(y_all) & (counts >= Nmin)
           & (mag >= m_min_fit) & (mag <= m_max_fit))
    n_peak = float(np.max(counts[sel])) if sel.any() else float(np.max(counts))
    sel = sel & (counts <= frac_max * n_peak)

    m_used, y_used, n_used = mag[sel], y_all[sel], counts[sel]
    if m_used.size < 3:
        return None

    sigma_y = 1.0 / (np.log(10.0) * np.sqrt(n_used))
    w = 1.0 / (sigma_y**2)
    root_w = np.sqrt(w)

    design = np.vstack([m_used, np.ones_like(m_used)]).T
    (a, b), *_ = np.linalg.lstsq(design * root_w[:, None], y_used * root_w, rcond=None)
    cov = np.linalg.inv(design.T @ (w[:, None] * design))

    return {
        "a": float(a),
        "b": float(b),
        "sa": float(np.sqrt(cov[0, 0])),
        "sb": float(np.sqrt(cov[1, 1])),
        "m_fit": m_used,
        "y_fit": y_used,
        "sigma_y": sigma_y,
    }



# Plotting

def plot_overlay(img, cat, fname):
    """Save a grey-scale view of ``img`` with a red circle on every detection.

    For display the image is decimated by DOWNSAMPLE: every ds-th pixel is
    shown, not binned. ``extent`` keeps the axes in full-resolution pixel
    coordinates, so tick labels ("x [pix]", "y [pix]") and circle positions
    refer to the same pixel grid as the catalogue.

    Circles are data-space patches of radius CIRCLE_RADIUS_PX full-resolution
    pixels. ``plt.scatter``'s ``s`` cannot express this: it is a marker area
    in points**2, independent of the image scale.
    """
    from matplotlib.collections import PatchCollection
    from matplotlib.patches import Circle

    ds = int(max(1, DOWNSAMPLE))
    img_show = img[::ds, ::ds]
    lo, hi = stretch_for_display(img)

    # Shown pixel j is original pixel j*ds; centre it there.
    ny_show, nx_show = img_show.shape
    extent = (-0.5 * ds, (nx_show - 0.5) * ds, -0.5 * ds, (ny_show - 0.5) * ds)

    fig, ax = plt.subplots(figsize=(8, 10))
    ax.imshow(img_show, origin="lower", cmap="gray", vmin=lo, vmax=hi,
              interpolation="nearest", extent=extent)

    if len(cat) > 0:
        circles = [Circle((o["x"], o["y"]), CIRCLE_RADIUS_PX) for o in cat]
        ax.add_collection(
            PatchCollection(circles, facecolors="none", edgecolors="red",
                            linewidths=CIRCLE_LW),
            autolim=False,  # keep the view limits set by imshow
        )

    ax.set_xlabel("x [pix]")
    ax.set_ylabel("y [pix]")
    ax.set_title(f"Detections overlay (kNN halo removed, N={len(cat)})")
    fig.tight_layout()
    fig.savefig(fname, dpi=200, bbox_inches="tight")
    plt.close(fig)


def plot_logN_all_stars_gals(m_all, y_all, N_all,
                             m_s,   y_s,   N_s,
                             m_g,   y_g,   N_g,
                             fit_res, fname):
    """
    Overview plot of log10 N(m) for all sources, stars and galaxies.

    Each curve is drawn with Poisson error bars, sigma = 1/(ln10*sqrt(N)),
    on the bins where N > 0 and log10 N is finite. A curve whose grid,
    values or counts are None is skipped; logN_curve returns None for fewer
    than 5 sources. Previously a missing "all" curve raised TypeError.

    When ``fit_res`` (from weighted_logN_fit) is given and non-empty, its
    straight line is overplotted in red over the fitted magnitude range.
    The figure is written to ``fname``.
    """
    series = (
        (m_all, y_all, N_all, "-o", "All"),
        (m_s,   y_s,   N_s,   "-s", "Stars"),
        (m_g,   y_g,   N_g,   "^-", "Galaxies"),
    )

    plt.figure(figsize=(8, 5))

    for m, y, N, fmt, label in series:
        if m is None or y is None or N is None:
            continue
        m, y, N = np.asarray(m), np.asarray(y, float), np.asarray(N)
        ok = (N > 0) & np.isfinite(y)
        if not np.any(ok):
            continue
        sigma = 1.0 / (np.log(10.0) * np.sqrt(N[ok]))
        plt.errorbar(m[ok], y[ok], yerr=sigma,
                     fmt=fmt, ms=3, lw=1.0, capsize=2, label=label)

    # optional: overplot the galaxy fit line if available
    if (fit_res is not None) and (fit_res["m_fit"].size > 0):
        a, sa = fit_res["a"], fit_res["sa"]
        m_fit = fit_res["m_fit"]
        plt.plot(m_fit, a*m_fit + fit_res["b"], "r-",
                 lw=2, label=f"Gal fit: {a:.3f}±{sa:.3f}")

    plt.xlabel("Calibrated magnitude limit m")
    plt.ylabel(r"$\log_{10} N(m)$")
    plt.title(r"$\log N(m)$ vs $m$ (kNN-filtered) — All / Stars / Galaxies")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fname, dpi=200, bbox_inches="tight")
    plt.close()


def plot_logN_galaxies_with_errors(m_grid_g, logN_g, N_g, fit_res, fname):
    """Plot the galaxy log10 N(m) curve with Poisson errors and the fit.

    Without ``fit_res`` every bin is drawn with error bars
    sigma = 1/(ln10*sqrt(N)); sigma is NaN where N == 0.

    With ``fit_res`` (from weighted_logN_fit):
    - bins with N > 0 that were not used in the fit are drawn in light grey
      as "excluded";
    - the fitted bins are drawn with the fit's own sigma_y;
    - the fitted line is drawn in green over the fitted range.

    Nothing is drawn when ``m_grid_g`` is None. The figure is written to
    ``fname``.
    """
    if m_grid_g is None:
        return

    y = logN_g.copy()
    counts = N_g.copy()
    has_counts = counts > 0
    sigma_y = np.full_like(y, np.nan, float)
    sigma_y[has_counts] = 1.0 / (np.log(10.0) * np.sqrt(counts[has_counts]))

    plt.figure(figsize=(8, 5))

    if fit_res is None:
        plt.errorbar(m_grid_g, y, yerr=sigma_y, fmt="o", ms=3, capsize=2,
                     label="galaxies (Poisson err)")
    else:
        excluded = ~np.isin(m_grid_g, fit_res["m_fit"]) & has_counts
        if np.any(excluded):
            plt.errorbar(m_grid_g[excluded], y[excluded],
                         yerr=sigma_y[excluded],
                         fmt="o", ms=3, capsize=2,
                         color="lightgray", label="excluded")

        plt.errorbar(fit_res["m_fit"], fit_res["y_fit"],
                     yerr=fit_res["sigma_y"], fmt="o", ms=3,
                     capsize=2, label="galaxies (Poisson err)")

        slope, slope_err = fit_res["a"], fit_res["sa"]
        m_line = fit_res["m_fit"]
        plt.plot(m_line, slope * m_line + fit_res["b"], "g-", lw=2,
                 label=f"weighted fit: {slope:.3f}±{slope_err:.3f}")

    plt.xlabel("Calibrated magnitude limit m")
    plt.ylabel(r"$\log_{10} N(m)$")
    plt.title("log N(m) vs m (galaxies, kNN-filtered) with weighted fit")
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.savefig(fname, dpi=200, bbox_inches="tight")
    plt.close()


# Main process

if __name__ == "__main__":
    # Load the cleaned mosaic from the primary HDU as a float array. astype()
    # makes a copy, so the array stays valid after the file is closed.
    with fits.open(CLEAN_FITS) as hdul:
        data = hdul[0].data
        if data is None:
            raise ValueError(f"{CLEAN_FITS}: primary HDU contains no image data")
        img = data.astype(float)

    # detection and kNN halo removal
    cat_raw, det_info = detect_sources(img)
    print("[detect]", det_info)
    cat = knn_filter(cat_raw, img.shape)
    print(f"[kNN] kept {len(cat)} / {len(cat_raw)} detections after halo removal")

    # overlay plot
    plot_overlay(img, cat, OUT_OVERLAY_JPG)
    print("Saved overlay:", OUT_OVERLAY_JPG)

    # star / galaxy classification
    mags, mag_err, C, is_star, is_gal = classify_star_galaxy(img, cat)
    print(f"[classify] usable photometry for {np.isfinite(mags).sum()} objects")
    print(f"[classify] stars={is_star.sum()}, galaxies={is_gal.sum()}")

    # logN curves; each is (None, None, None) when fewer than 5 magnitudes
    m_all, y_all, N_all = logN_curve(mags, dm=DM)
    m_s,   y_s,   N_s   = logN_curve(mags[is_star], dm=DM)
    m_g,   y_g,   N_g   = logN_curve(mags[is_gal], dm=DM)

    # galaxy-only weighted fit
    fit_res = None
    if m_g is not None:
        fit_res = weighted_logN_fit(m_g, y_g, N_g,
                                    m_min_fit=FIT_M_MIN,
                                    m_max_fit=FIT_M_MAX,
                                    Nmin=20, frac_max=0.95)
        if fit_res is not None:
            print("[fit] slope a = {:.3f} ± {:.3f}, intercept b = {:.2f} ± {:.2f}".format(
                fit_res["a"], fit_res["sa"], fit_res["b"], fit_res["sb"]))
            print("[fit] mag range used: [{:.2f}, {:.2f}], points = {}".format(
                fit_res["m_fit"].min(), fit_res["m_fit"].max(), fit_res["m_fit"].size))
        else:
            print("[fit] not enough points for galaxy fit")

    # Overview plot: all / stars / galaxies with error bars. Skipped when the
    # full sample yields no curve (fewer than 5 usable magnitudes), since
    # there is then nothing to plot.
    if m_all is not None:
        plot_logN_all_stars_gals(m_all, y_all, N_all,
                                 m_s, y_s, N_s,
                                 m_g, y_g, N_g,
                                 fit_res, OUT_LOGN_ALL_JPG)
        print("Saved logN (all/stars/gals):", OUT_LOGN_ALL_JPG)
    else:
        print("[logN] fewer than 5 usable magnitudes; overview plot skipped")

    # galaxies-only plot with Poisson error bars and fit line
    if m_g is not None:
        plot_logN_galaxies_with_errors(m_g, y_g, N_g, fit_res,
                                       OUT_LOGN_GALFIT_JPG)
        print("Saved logN (galaxies+fit):", OUT_LOGN_GALFIT_JPG)


[detect] {'bg': 3419.0, 'sigma': 13.343399999999999, 'thr': 3485.717, 'n_labels': 13846, 'n_cat': 2307}
[kNN] kept 2100 / 2307 detections after halo removal
Saved overlay: c:\Users\asus\OneDrive - Imperial College London\Lab Astro\Astro\Astro\Fits_Data\outputs\detections_overlay_knn.jpg
[classify] usable photometry for 2026 objects
[classify] stars=21, galaxies=2005
[fit] slope a = 0.368 ± 0.005, intercept b = -2.87 ± 0.08
[fit] mag range used: [11.50, 16.25], points = 20
Saved logN (all/stars/gals): c:\Users\asus\OneDrive - Imperial College London\Lab Astro\Astro\Astro\Fits_Data\outputs\logN_all_stars_gals.jpg
Saved logN (galaxies+fit): c:\Users\asus\OneDrive - Imperial College London\Lab Astro\Astro\Astro\Fits_Data\outputs\logN_galaxies_weightedfit.jpg
