In [None]:
import pandas as pd
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
import os
import re

from astroquery.jplhorizons import Horizons

from astropy.io import fits
from astropy.nddata import Cutout2D
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.visualization import simple_norm
from astropy.time import Time
from astropy.config import set_temp_cache

In [None]:
# Load the list of comets to process. The processing loop further down iterates
# over these rows and reads the 'Comet' column of each one.
df_init = pd.read_csv('data/comets_list2.csv')
# Bare expression so the notebook displays the loaded table.
df_init

In [None]:
def ztf_name_to_irsa_url(name: str,
                         base: str = "https://irsa.ipac.caltech.edu/ibe/data/ztf/products",
                         product: str | None = None) -> str:
    """
    Convert a ZTF filename to its IRSA IBE URL.

    Parameters
    ----------
    name : str
        e.g. 'ztf_20180508456111_000386_zr_c08_o_q2_sciimg.fits'
    base : str
        Base IRSA IBE URL (rarely needs changing).
    product : str | None
        ZTF product subtree. If None, try to infer ('sci', 'raw', 'cal') from the name.

    Returns
    -------
    str
        Full URL to the file on IRSA.

    Raises
    ------
    ValueError if the name doesn’t match the expected ZTF pattern.
    """
    # Try to infer product subtree if not provided
    if product is None:
        lowered = name.lower()
        if "sci" in lowered:
            product = "sci"
        elif "raw" in lowered:
            product = "raw"
        elif "cal" in lowered:
            product = "cal"
        else:
            product = "sci"  # sensible default

    # Parse ztf_{YYYY}{MM}{DD}{FIELD6}_{...}.fits
    m = re.match(r"^ztf_(\d{4})(\d{4})(\d{6})_", name, flags=re.IGNORECASE)
    if not m:
        raise ValueError("Filename doesn't look like a ZTF name: expected 'ztf_YYYYMMDDFFFFFF_...'.")
    year, mmdd, field6 = m.groups()

    return f"{base}/{product}/{year}/{mmdd}/{field6}/{name}.fits"


In [None]:
# Matplotlib colormap name used for the quick-look figure of each ZTF filter,
# keyed by the FILTER value read from the image header.
ztf_colors = dict(
    ZTF_g='Greens_r',
    ZTF_r='Reds_r',
    ZTF_i='Oranges_r',
)


In [None]:
# Robust ephemeris fetch: Horizons with retries + local CSV fallback
from time import sleep
from functools import lru_cache
import warnings

@lru_cache(maxsize=4096)
def _horizons_ephem_cached(name: str, mjd: float, location: str = 'I41'):
    """
    Query JPL Horizons for one object at one epoch and return ``(ra, dec, tmag)``.

    The values are taken from the first row of the returned ephemeris table and
    converted to plain floats. ``tmag`` is NaN when the table has no 'Tmag'
    column. ``mjd`` is forwarded to Horizons unchanged (as a float) through the
    ``epochs`` argument.

    Results are memoised by ``lru_cache`` on the ``(name, mjd, location)``
    arguments, so repeating a successful query does not hit the service again.
    A call that raises is not cached.
    """
    query = Horizons(id=name, location=location, epochs=float(mjd))
    table = query.ephemerides()

    right_ascension = float(table['RA'][0])
    declination = float(table['DEC'][0])
    if 'Tmag' in table.colnames:
        total_mag = float(table['Tmag'][0])
    else:
        total_mag = float('nan')

    return right_ascension, declination, total_mag


def get_comet_position(name: str, mjd: float, location: str = 'I41', retries: int = 4, backoff_s: float = 3.0):
    """
    Get the comet's (RA, Dec, Tmag) at a given MJD, preferring JPL Horizons.

    Horizons is queried up to ``retries`` times with exponential backoff
    (``backoff_s``, then 2x, 4x, ... between attempts). If every attempt fails
    (or ``retries`` <= 0), fall back to the local table
    ``data/jpl2/<name with '/' -> '_'>_jplhorizons.csv`` and take the row whose
    'datetime_jd' is nearest to the requested epoch. No limit is placed on how
    far away that nearest row may be.

    In the fallback, rows whose 'datetime_jd', 'RA' or 'DEC' cannot be parsed
    as numbers are ignored, and a missing or non-numeric 'Tmag' yields NaN
    instead of discarding an otherwise valid position.

    Parameters
    ----------
    name : str
        Object designation passed to Horizons and used to build the CSV path.
    mjd : float
        Epoch as Modified Julian Date.
    location : str
        Observer location code passed to Horizons.
    retries : int
        Number of Horizons attempts before falling back.
    backoff_s : float
        Initial wait in seconds between attempts; doubles after each failure.

    Returns
    -------
    dict
        ``{'ra', 'dec', 'tmag', 'source'}`` where ``source`` is 'horizons' or
        'local-jpl2'.

    Raises
    ------
    RuntimeError
        If neither Horizons nor the local CSV yields a position. The message
        carries the last underlying error.
    """
    last_err = None
    attempts = max(retries, 0)
    for i in range(attempts):
        try:
            ra, dec, tmag = _horizons_ephem_cached(name, float(mjd), location)
            return {'ra': ra, 'dec': dec, 'tmag': tmag, 'source': 'horizons'}
        except Exception as e:
            last_err = e
            # No point sleeping after the final attempt.
            if i < attempts - 1:
                sleep(backoff_s * (2 ** i))

    # Fallback to local CSV (precomputed Horizons table)
    csv_path = f"data/jpl2/{name.replace('/', '_')}_jplhorizons.csv"
    if os.path.exists(csv_path):
        try:
            df = pd.read_csv(csv_path)
            # Expect columns: 'datetime_jd', 'RA', 'DEC', 'Tmag' (as seen in repo files)
            if {'datetime_jd', 'RA', 'DEC'}.issubset(df.columns):
                # Keep only rows with a usable epoch and position so the nearest
                # match can never be a row of NaNs.
                usable = (
                    df[['datetime_jd', 'RA', 'DEC']]
                    .apply(pd.to_numeric, errors='coerce')
                    .dropna()
                )
                if usable.empty:
                    raise ValueError(f"no usable rows in {csv_path}")

                jd_target = float(mjd) + 2400000.5  # MJD -> JD
                idx = (usable['datetime_jd'] - jd_target).abs().idxmin()
                ra = float(usable.loc[idx, 'RA'])
                dec = float(usable.loc[idx, 'DEC'])

                # Tmag is optional: placeholders such as 'n.a.' become NaN
                # rather than invalidating the whole fallback.
                tmag = float('nan')
                if 'Tmag' in df.columns:
                    tmag = float(pd.to_numeric(df.loc[idx, 'Tmag'], errors='coerce'))
                return {'ra': ra, 'dec': dec, 'tmag': tmag, 'source': 'local-jpl2'}
        except Exception as e:
            last_err = e

    raise RuntimeError(f"Ephemeris unavailable for {name} at MJD={mjd}: {last_err}")


In [None]:
# For every comet in df_init: read its visibility table, download each ZTF frame
# listed there, cut out a 1 arcmin stamp centred on the comet's ephemeris
# position, and save the stamp (FITS) together with a quick-look figure (PNG).
FITS_CACHE_DIR = './data/comets/fits/ztf/'
CUTOUT_DIR = 'data/comets/cutouts/ztf'
FIG_DIR = 'data/comets/figs/ztf'

# set_temp_cache needs an existing directory, so create the download cache
# alongside the output folders once, before any frame is processed.
for folder in (FITS_CACHE_DIR, CUTOUT_DIR, FIG_DIR):
    os.makedirs(folder, exist_ok=True)

for _, comet_row in df_init.iterrows():
    comet_name = comet_row['Comet']
    safe_name = comet_name.replace('/', '_')
    visible_filename = f"data/visible2/{safe_name.replace(' ', '')}_visible.csv"
    print(f"Processing {comet_name} from {visible_filename}")

    try:
        df_cadc = pd.read_csv(visible_filename)
    except FileNotFoundError:
        print(f"No visibility table at {visible_filename}. Skipping this comet.")
        continue

    # get observations from the ZTF
    df_ztf = df_cadc[df_cadc['Telescope/Instrument'] == 'ZTF']
    print(f"File: {visible_filename}, entries: {len(df_ztf)}")
    if df_ztf.empty:
        continue

    # Distinct loop variable: the outer row must not be shadowed by the frame row.
    for index, obs_row in df_ztf.iterrows():
        print(f"Row index: {index}, MJD: {obs_row['MJD']}")
        try:
            url = ztf_name_to_irsa_url(obs_row['Image'])
        except (ValueError, TypeError, AttributeError) as e:
            print(f"Cannot build URL for image {obs_row['Image']!r}: {e}. Skipping this frame.")
            continue

        print(f"Downloading from URL: {url}")
        try:
            with set_temp_cache(path=FITS_CACHE_DIR, delete=False):
                hdul = fits.open(url, cache=True, memmap=True)
        except Exception as e:
            # One unreachable or missing frame should not abort the whole run.
            print(f"Download failed: {e}. Skipping this frame.")
            continue

        # The context manager closes the file on every exit path, including `continue`.
        with hdul:
            hdu = hdul[0]
            header = hdu.header

            # Get WCS and observation time
            w = WCS(header)
            obs_mjd = header.get('OBSMJD')
            if obs_mjd is None:
                print("Header has no OBSMJD keyword. Skipping this frame.")
                continue
            obs_time = Time(obs_mjd, format='mjd')

            # Get filter band and exposure time
            filter_band = str(header.get('FILTER', 'Unknown')).strip()
            exposure = header.get('EXPTIME', 'Unknown')

            # Checking if the file already processed
            out_stem = f"{safe_name}_ztf_{obs_time.iso}"
            out_filename = f"{CUTOUT_DIR}/{out_stem}.fits"
            if os.path.exists(out_filename):
                print("File already processed")
                continue

            # Get comet position at observation time with retries and fallback
            print(f"Computing the comet position at MJD={obs_time.mjd}")
            try:
                ephem = get_comet_position(comet_name, obs_time.mjd, location='I41', retries=4, backoff_s=3.0)
            except Exception as e:
                print(f"Ephemeris fetch failed: {e}. Skipping this frame.")
                continue

            ra_comet = ephem['ra']
            dec_comet = ephem['dec']
            tmag = ephem.get('tmag', float('nan'))
            print(f"Comet position at observation time: RA={ra_comet}, Dec={dec_comet} (source={ephem['source']})")
            print(f"Filter: {filter_band}, Exposure: {exposure} sec")
            if pd.notna(tmag):
                print(f"Predicted comet magnitude: {tmag}")

            # Create SkyCoord for the comet position
            coord = SkyCoord(ra_comet*u.deg, dec_comet*u.deg, frame='icrs')

            # Check if the position is within the image FOV
            ny, nx = hdu.data.shape
            x_pix, y_pix = w.world_to_pixel(coord)
            if not ((0 <= x_pix < nx) and (0 <= y_pix < ny)):
                print("Comet position is outside this image. Skipping this frame.")
                continue

            size = u.Quantity((1, 1), u.arcmin)
            cutout = Cutout2D(hdu.data, (x_pix, y_pix), size, wcs=w)

            # Save cutout to a new FITS file
            hdu_out = fits.PrimaryHDU(data=cutout.data, header=cutout.wcs.to_header())
            hdu_out.header['COMET'] = comet_name
            hdu_out.header['OBSMJD'] = obs_time.mjd
            hdu_out.header['FILTER'] = filter_band
            hdu_out.header['ORIGFILE'] = os.path.basename(url)
            hdu_out.header['RA_COM'] = ra_comet
            hdu_out.header['DEC_COM'] = dec_comet
            hdu_out.header['CUTSIZE'] = '1 arcmin'
            hdu_out.header['EXPTIME'] = exposure
            hdu_out.writeto(out_filename, overwrite=True)
            print(f'Wrote cutout to {out_filename}')

            # Plot the cutout with the comet position
            fig = plt.figure(figsize=(5, 5))
            ax = fig.add_subplot(1, 1, 1, projection=cutout.wcs)
            norm = simple_norm(cutout.data, 'sqrt', percent=95.0)
            cmap = ztf_colors.get(filter_band, 'gray')
            ax.imshow(cutout.data, origin='lower', cmap=cmap, norm=norm)
            x_mark, y_mark = cutout.wcs.world_to_pixel(coord)
            ax.scatter(x_mark, y_mark,
                       s=400, edgecolor='yellow', facecolor='none', marker='o', lw=1, alpha=0.7)
            # WCSAxes niceties
            ax.set_xlabel("RA")
            ax.set_ylabel("Dec")
            ax.coords.grid(True, alpha=0.3, linestyle="--")
            plt.tight_layout()
            fig_path = f"{FIG_DIR}/{out_stem}_{filter_band}.png"
            plt.savefig(fig_path)
            print(f'Saved figure to {fig_path}')
            plt.show()
            # Release the figure so a long run does not accumulate open figures.
            plt.close(fig)

In [None]:
# Ad-hoc check: query Horizons directly for a single comet at a single epoch and
# display the returned ephemeris table.
# Note: this reassigns the notebook-level `comet_name`, which the processing
# loop in the cell above also sets.
comet_name = "C/2009 F2"
# Epoch given as an MJD value on the UTC scale.
t = Time(58204.5145255113, format='mjd', scale='utc')
# t.iso  (uncomment to view the same epoch as an ISO date string)
# 'I41' is the same observer location code that get_comet_position uses by default;
# the MJD value is passed to Horizons as-is through `epochs`.
obj = Horizons(id=comet_name, location='I41', epochs=t.mjd)
eph = obj.ephemerides()
# Bare expression so the notebook displays the table.
eph
