In [16]:
import os
from pathlib import Path
from typing import List, Tuple
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# ========= User paths =========
# Root data directory; the FITS cutouts and the input catalogue live below it.
data_root = Path("/Users/hezizhao/Downloads/jupython-workspace/FH2/data")
fits_dir  = data_root / "fitscutout"
csv_path  = data_root / "fh2_desils.csv"

# Output directory for the annotated PNGs.
# NOTE: it is created as a side effect of running/importing this module.
png_dir   = fits_dir / "marked_png"
png_dir.mkdir(parents=True, exist_ok=True)

# Naming / plotting parameters
# Number of decimals used when formatting RA/Dec into FITS and PNG file names.
ROUND_DIGITS = 5
DEFAULT_PIXSCALE = 0.262    # arcsec / pix
# Matplotlib scatter sizes for the centre marker and the match markers.
CENTER_MARKER_SIZE = 80
MATCH_MARKER_SIZE  = 40


# ---------- Parse multi-valued columns ----------
def parse_multi_value(val) -> List[float]:
    """Parse a multi-valued catalogue cell into a list of floats.

    Accepted inputs include strings such as ``'123.4/123.5'``,
    ``'123.4 123.5'`` or ``'[123.4,123.5]'``, plain numeric scalars, and
    list-like containers (list, tuple, ndarray) of any of those.

    Args:
        val: The raw cell value.

    Returns:
        The parsed numbers in their original order. Missing cells, NaN
        tokens (in any spelling ``float()`` accepts, e.g. ``nan``/``NaN``)
        and tokens that are not numbers are skipped, so the result never
        contains NaN.
    """
    # Containers must be handled before pd.isna(): for list-likes it returns
    # an array, whose truth value is ambiguous inside an ``if``.
    if isinstance(val, np.ndarray):
        val = val.ravel().tolist()
    if isinstance(val, (list, tuple)):
        return [number for item in val for number in parse_multi_value(item)]

    if pd.isna(val):
        return []
    if isinstance(val, (float, int, np.floating, np.integer)):
        return [float(val)]

    # Treat brackets and the supported separators as plain whitespace.
    text = str(val)
    for ch in "[](){}/;,":
        text = text.replace(ch, " ")

    values = []
    for token in text.split():
        try:
            number = float(token)
        except ValueError:
            # Not a number (e.g. 'None' or stray text): ignore the token.
            continue
        # float() happily parses 'NaN', 'NAN', '+nan', ...; drop them all so
        # the filter no longer depends on the exact spelling.
        if not np.isnan(number):
            values.append(number)
    return values


# ---------- Build the FITS index (RA/Dec parsed from file names) ----------
def build_fits_index(fits_dir: Path) -> pd.DataFrame:
    """Index the cutout files in *fits_dir* by the coordinates in their names.

    Only files named ``ra_<RA>_dec_<DEC>-subimage.fits`` are considered;
    files whose RA or Dec part is not a valid float are skipped.

    Args:
        fits_dir: Directory that is searched (non-recursively).

    Returns:
        A DataFrame with the columns ``ra_name``, ``dec_name`` (floats
        parsed from the file name) and ``path`` (the file's ``Path``). The
        columns are present even when no file matched, and the rows are
        ordered by path so the result does not depend on directory order.
    """
    suffix = "-subimage"
    records = []
    for fp in sorted(fits_dir.glob("ra_*_dec_*-subimage.fits")):
        ra_part, sep, dec_part = fp.stem.partition("_dec_")
        if not sep or not ra_part.startswith("ra_"):
            continue
        if dec_part.endswith(suffix):
            dec_part = dec_part[: -len(suffix)]
        try:
            ra_val = float(ra_part[len("ra_"):])
            dec_val = float(dec_part)
        except ValueError:
            # The name matched the glob but does not carry numeric coordinates.
            continue
        records.append({"ra_name": ra_val, "dec_name": dec_val, "path": fp})
    # Explicit columns keep the schema stable for an empty directory.
    return pd.DataFrame(records, columns=["ra_name", "dec_name", "path"])


def find_fits_for_radec(ra: float, dec: float, index_df: pd.DataFrame, tol_deg: float = 1e-4):
    """Look up the indexed cutout whose file-name coordinates are closest to (ra, dec).

    The distance is the plain Euclidean distance in (RA, Dec) degree space,
    i.e. without any cos(Dec) weighting.

    Args:
        ra: Right ascension to look up, in degrees.
        dec: Declination to look up, in degrees.
        index_df: Index table with ``ra_name``, ``dec_name`` and ``path`` columns.
        tol_deg: Largest accepted distance in degrees (1e-4 deg is about 0.36").

    Returns:
        The ``Path`` of the nearest file, or ``None`` when the index is empty
        or the nearest file is farther away than *tol_deg*.
    """
    if index_df.empty:
        return None

    delta_ra = index_df["ra_name"] - ra
    delta_dec = index_df["dec_name"] - dec
    dist_sq = delta_ra**2 + delta_dec**2

    nearest = dist_sq.idxmin()
    if nearest is None or np.sqrt(dist_sq.loc[nearest]) > tol_deg:
        return None
    return Path(index_df.loc[nearest, "path"])


# ---------- FITS I/O ----------
def load_fits_image_next_to_last(fits_path: Path):
    """Load the second-to-last HDU of a FITS file.

    Args:
        fits_path: Path of the FITS file to open.

    Returns:
        A tuple ``(data, header, ext_index)`` for the selected HDU.
        ``ext_index`` is ``len(hdul) - 2``; for a file with a single HDU it
        is clamped to 0, so that HDU is returned with a non-negative index.

    Raises:
        Whatever ``astropy.io.fits.open`` or the HDU access raises (for
        example ``OSError`` for a missing or corrupt file).
    """
    from astropy.io import fits

    # The context manager closes the file even if reading the HDU fails.
    with fits.open(fits_path) as hdul:
        ext_index = max(len(hdul) - 2, 0)
        hdu = hdul[ext_index]
        # Touch .data while the file is still open: astropy loads it lazily.
        data = hdu.data
        hdr = hdu.header
    return data, hdr, ext_index


# ---------- WCS ----------
def build_wcs_from_header(hdr):
    """Return a celestial WCS built from *hdr*, or ``None`` if there is none.

    A header without celestial (longitude/latitude) axes yields ``None`` so
    the caller can fall back to its approximate projection. Checking only
    the number of axes is not enough: astropy builds a default WCS for any
    image header, even one without WCS keywords.

    Args:
        hdr: FITS header (or any object ``astropy.wcs.WCS`` accepts).

    Returns:
        The WCS itself when it is a 2-axis celestial WCS, its celestial
        2-axis sub-WCS when the header describes more axes, otherwise
        ``None``. ``None`` is also returned when astropy is unavailable or
        the WCS cannot be constructed.
    """
    try:
        from astropy.wcs import WCS

        wcs = WCS(hdr)
        if wcs.has_celestial:
            # Extra (e.g. spectral/band) axes would break 2-column
            # world->pixel calls, so hand back only the sky axes.
            return wcs if wcs.naxis == 2 else wcs.celestial
    except Exception:
        # Deliberately best-effort: any failure means "no usable WCS" and
        # the caller switches to the fallback projection.
        pass
    return None


def fallback_world2pix(ra_deg, dec_deg, img_shape, ra_c, dec_c, pixscale_arcsec=DEFAULT_PIXSCALE):
    """Approximate world->pixel conversion for images without a WCS.

    Uses a small-field tangent-plane approximation centred on the image:
    the RA offset is scaled by cos(dec_c) and both offsets are divided by
    the pixel scale.

    NOTE(review): assumes the cutout is oriented north-up / east-left, so X
    decreases as RA increases — confirm for cutouts from other sources.

    Args:
        ra_deg: RA value(s) in degrees (scalar or sequence).
        dec_deg: Dec value(s) in degrees (scalar or sequence).
        img_shape: ``(ny, nx)`` shape of the image.
        ra_c: RA of the image centre, in degrees.
        dec_c: Dec of the image centre, in degrees.
        pixscale_arcsec: Pixel scale in arcsec per pixel.

    Returns:
        ``(x, y)`` arrays of 0-based pixel coordinates for ``origin='lower'``.
    """
    ny, nx = img_shape
    deg_per_pix = pixscale_arcsec / 3600.0
    x_center = (nx - 1) / 2.0
    y_center = (ny - 1) / 2.0

    ra_arr = np.asarray(ra_deg, dtype=float)
    dec_arr = np.asarray(dec_deg, dtype=float)

    # Wrap the RA offset into [-180, 180) so fields straddling RA = 0/360
    # do not produce offsets of ~360 degrees.
    d_ra = (ra_arr - ra_c + 180.0) % 360.0 - 180.0

    # East is to the left: larger RA -> smaller X.
    x = x_center - d_ra * np.cos(np.deg2rad(dec_c)) / deg_per_pix
    y = y_center + (dec_arr - dec_c) / deg_per_pix
    return x, y


# ---------- Plotting ----------
def draw_marked_png(
    data,
    hdr,
    png_path: Path,
    ra_center: float,
    dec_center: float,
    ra_matches: List[float],
    dec_matches: List[float],
    match_ids: List[int],        # per-group IDs (1..N)
    show_scale_bar: bool = False,
):
    """Render a cutout with the centre and its matches marked, and save it as PNG.

    The centre is drawn as a yellow '+', every match as a red open circle
    labelled with its ID. Pixel positions come from the header WCS when
    ``build_wcs_from_header`` returns one, otherwise from
    ``fallback_world2pix``.

    Args:
        data: 2-D image array.
        hdr: Header of the image HDU, used to build the WCS.
        png_path: Destination of the PNG file.
        ra_center: RA of the input position, in degrees.
        dec_center: Dec of the input position, in degrees.
        ra_matches: RA of each match, in degrees.
        dec_matches: Dec of each match, in degrees.
        match_ids: Label drawn next to each match.
        show_scale_bar: If true, draw a 10" bar based on ``DEFAULT_PIXSCALE``.
    """
    n_matches = len(ra_matches)

    w = build_wcs_from_header(hdr)
    if w is not None:
        x0, y0 = np.array(w.all_world2pix([[ra_center, dec_center]], 0))[0]
        if n_matches > 0:
            xy_match = np.array(w.all_world2pix(np.column_stack([ra_matches, dec_matches]), 0))
        else:
            # Keep a (0, 2) array so the column slicing below still works.
            xy_match = np.zeros((0, 2))
        xs, ys = xy_match[:, 0], xy_match[:, 1]
    else:
        xs, ys = fallback_world2pix(ra_matches, dec_matches, data.shape, ra_center, dec_center)
        x0, y0 = fallback_world2pix([ra_center], [dec_center], data.shape, ra_center, dec_center)
        x0, y0 = x0[0], y0[0]

    # Percentile stretch; NaN pixels are ignored.
    vmin, vmax = np.nanpercentile(data, (5, 99.5))

    fig, ax = plt.subplots(figsize=(4,4), dpi=150)
    try:
        ax.imshow(data, origin='lower', vmin=vmin, vmax=vmax, cmap='gray')

        # Centre (input ra/dec)
        ax.scatter([x0], [y0], s=CENTER_MARKER_SIZE, marker='+', color='yellow', linewidths=1.5, label='Center')

        # Matches
        if n_matches > 0:
            ax.scatter(xs, ys, s=MATCH_MARKER_SIZE, facecolors='none', edgecolors='red', linewidths=1.0, label='Matches')
            for mid, xx, yy in zip(match_ids, xs, ys):
                ax.text(xx+1, yy+1, str(mid), color='red', fontsize=6)

        # Top-left corner: group information
        ax.text(
            0.01, 0.99,
            f"({ra_center:.5f}, {dec_center:.5f})  N={len(match_ids)}",
            color='cyan', fontsize=7,
            ha='left', va='top',
            transform=ax.transAxes,
            bbox=dict(facecolor='black', alpha=0.35, edgecolor='none', pad=1.0),
        )

        # Pin the axes to the image so off-image markers do not rescale the view.
        ax.set_xlim(-0.5, data.shape[1]-0.5)
        ax.set_ylim(-0.5, data.shape[0]-0.5)
        ax.set_xlabel("X [pix]")
        ax.set_ylabel("Y [pix]")
        ax.set_title(f"RA={ra_center:.5f}, Dec={dec_center:.5f}")
        ax.legend(loc='upper right', fontsize=6, frameon=False)

        # Optional scale bar
        if show_scale_bar:
            scale_arcsec = 10.0
            n_pix = scale_arcsec / DEFAULT_PIXSCALE
            x_bar_start = data.shape[1]*0.05
            y_bar = data.shape[0]*0.05
            ax.plot([x_bar_start, x_bar_start+n_pix], [y_bar, y_bar], color='white', lw=1)
            ax.text(x_bar_start+n_pix/2, y_bar+2, f"{scale_arcsec:.0f}\"", color='white', ha='center', va='bottom', fontsize=6)

        fig.tight_layout()
        fig.savefig(png_path, bbox_inches='tight')
    finally:
        # Always release the figure: this runs once per catalogue group, and
        # figures left open after an error would pile up in pyplot.
        plt.close(fig)


# ---------- Main pipeline ----------
def main():
    """Explode the match catalogue, number the matches and draw one PNG per source.

    Reads ``csv_path`` (columns ``ra``, ``dec``, ``ra_out``, ``dec_out``),
    writes ``<csv stem>_exploded_withID.csv`` next to it, and saves one
    annotated PNG per unique ``(ra, dec)`` into ``png_dir``. Sources whose
    cutout is missing, unreadable or cannot be drawn are reported and
    skipped so a single bad file does not stop the batch.

    Raises:
        FileNotFoundError: If no cutout FITS file is found in ``fits_dir``.
    """
    # 1. Read the source catalogue
    df = pd.read_csv(csv_path)

    # 2. Parse ra_out / dec_out into lists
    df["ra_out_list"]  = df["ra_out"].apply(parse_multi_value)
    df["dec_out_list"] = df["dec_out"].apply(parse_multi_value)

    # 3. Explode each row (one match = one row); both lists of a row must
    #    have the same length for the multi-column explode.
    df_exp = df.explode(["ra_out_list", "dec_out_list"], ignore_index=True)
    df_exp = df_exp.rename(columns={"ra_out_list":"ra_out_ex","dec_out_list":"dec_out_ex"})

    # Drop rows without a match
    df_exp = df_exp[~df_exp["ra_out_ex"].isna() & ~df_exp["dec_out_ex"].isna()].copy()

    # 4. Per-group numbering: group by (ra, dec), keeping the original order
    df_exp["ID"] = df_exp.groupby(["ra","dec"]).cumcount() + 1

    # 5. Save the exploded table with IDs
    df_exp_out = csv_path.with_name(csv_path.stem + "_exploded_withID.csv")
    df_exp.to_csv(df_exp_out, index=False)

    # 6. Prepare the FITS index
    idx_df = build_fits_index(fits_dir)
    if idx_df.empty:
        raise FileNotFoundError(f"没有在 {fits_dir} 找到 cutout FITS。")

    # 7. Draw one figure per unique (ra, dec)
    for (ra_c, dec_c), g in df_exp.groupby(["ra","dec"], sort=False):
        # Collect the matches of this source
        ra_m = g["ra_out_ex"].astype(float).to_list()
        dec_m = g["dec_out_ex"].astype(float).to_list()
        match_ids = g["ID"].astype(int).to_list()

        # Locate the FITS file: exact name first, nearest indexed file second
        fname = f"ra_{float(ra_c):.{ROUND_DIGITS}f}_dec_{float(dec_c):.{ROUND_DIGITS}f}-subimage.fits"
        fpath = fits_dir / fname
        if not fpath.exists():
            fpath = find_fits_for_radec(float(ra_c), float(dec_c), idx_df)
            if fpath is None:
                print(f"[WARN] 无FITS: ({ra_c:.5f},{dec_c:.5f}) 跳过绘图。")
                continue

        # Open the second-to-last HDU
        try:
            data, hdr, ext_idx = load_fits_image_next_to_last(fpath)
        except Exception as e:
            print(f"[ERR] 打开 {fpath.name} 出错: {e}")
            continue

        # An HDU without pixel data cannot be drawn; skip it like a missing file.
        if data is None:
            print(f"[WARN] {fpath.name} 的 HDU {ext_idx} 没有图像数据，跳过绘图。")
            continue

        # Output file
        png_name = f"ra_{float(ra_c):.{ROUND_DIGITS}f}_dec_{float(dec_c):.{ROUND_DIGITS}f}_N{len(match_ids)}.png"
        png_path = png_dir / png_name

        # Same policy as for loading: report the failure and move on.
        try:
            draw_marked_png(
                data=data,
                hdr=hdr,
                png_path=png_path,
                ra_center=float(ra_c),
                dec_center=float(dec_c),
                ra_matches=ra_m,
                dec_matches=dec_m,
                match_ids=match_ids,
            )
        except Exception as e:
            print(f"[ERR] 绘制 {png_name} 出错: {e}")
            continue

    print("\n=== 完成 ===")
    print(f"爆炸后星表写出: {df_exp_out}")
    print(f"图像输出目录:   {png_dir}")


# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()




=== 完成 ===
爆炸后星表写出: /Users/hezizhao/Downloads/jupython-workspace/FH2/data/fh2_desils_exploded_withID.csv
图像输出目录:   /Users/hezizhao/Downloads/jupython-workspace/FH2/data/fitscutout/marked_png
