In [None]:
import pandas as pd
import numpy as np
from astropy.io import fits
import matplotlib.pyplot as plt
import os
import re

from astroquery.jplhorizons import Horizons

from astropy.io import fits
from astropy.nddata import Cutout2D
from astropy.wcs import WCS
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.visualization import simple_norm
from astropy.time import Time
from astropy.config import set_temp_cache

In [None]:
# Load the comet list and show it as the cell output.
COMET_LIST_CSV = 'data/comets_list2.csv'
df_init = pd.read_csv(COMET_LIST_CSV)
df_init

In [None]:
def ztf_name_to_irsa_url(name: str,
                         base: str = "https://irsa.ipac.caltech.edu/ibe/data/ztf/products",
                         product: str | None = None) -> str:
    """
    Convert a ZTF filename to its IRSA IBE URL.

    Parameters
    ----------
    name : str
        e.g. 'ztf_20180508456111_000386_zr_c08_o_q2_sciimg.fits'
    base : str
        Base IRSA IBE URL (rarely needs changing).
    product : str | None
        ZTF product subtree. If None, try to infer ('sci', 'raw', 'cal') from the name.

    Returns
    -------
    str
        Full URL to the file on IRSA.

    Raises
    ------
    ValueError if the name doesn’t match the expected ZTF pattern.
    """
    # Try to infer product subtree if not provided
    if product is None:
        lowered = name.lower()
        if "sci" in lowered:
            product = "sci"
        elif "raw" in lowered:
            product = "raw"
        elif "cal" in lowered:
            product = "cal"
        else:
            product = "sci"  # sensible default

    # Parse ztf_{YYYY}{MM}{DD}{FIELD6}_{...}.fits
    m = re.match(r"^ztf_(\d{4})(\d{4})(\d{6})_", name, flags=re.IGNORECASE)
    if not m:
        raise ValueError("Filename doesn't look like a ZTF name: expected 'ztf_YYYYMMDDFFFFFF_...'.")
    year, mmdd, field6 = m.groups()

    return f"{base}/{product}/{year}/{mmdd}/{field6}/{name}.fits"


In [None]:
# Matplotlib colormap name to use for each ZTF FILTER header value.
ztf_colors = dict(
    ZTF_g='Greens_r',
    ZTF_r='Reds_r',
    ZTF_i='Oranges_r',
)

In [None]:
# --- Configuration -----------------------------------------------------------
# Download cache, cutout output and figure output directories.
ZTF_CACHE_DIR = './data/comets/fits/ztf/'
ZTF_CUTOUT_DIR = 'data/comets/cutouts/ztf'
ZTF_FIG_DIR = 'data/comets/figs/ztf'
# Side length of the square cutout, in arcminutes.
CUTOUT_ARCMIN = 1
# The exploratory version of this cell stopped after the first ZTF image of the
# first comet that had any ZTF entries. These limits keep that behaviour but
# make it explicit; set either one to None to lift it.
MAX_COMETS = 1
MAX_IMAGES_PER_COMET = 1
# Colormap used when the FILTER header value is not a key of ztf_colors.
FALLBACK_CMAP = 'gray'

# The cache directory and both output directories must exist before use.
for required_dir in (ZTF_CACHE_DIR, ZTF_CUTOUT_DIR, ZTF_FIG_DIR):
    os.makedirs(required_dir, exist_ok=True)

# --- Processing --------------------------------------------------------------
# For every comet in df_init: read its CADC search result, keep the ZTF rows,
# download each ZTF image, recompute the comet position at the image epoch via
# JPL Horizons and, if the comet falls on the image, write a FITS cutout and a
# PNG preview.
n_comets_done = 0
for _, comet_row in df_init.iterrows():
    comet_name = comet_row['Comet']
    cadc_filename = f"data/cadc2/{comet_name.replace('/', '_')}_cadc.csv"
    df_cadc = pd.read_csv(cadc_filename)
    print(f"Processing {comet_name} from {cadc_filename}")

    # get observations from the ZTF
    df_ztf = df_cadc[df_cadc['Telescope/Instrument'] == 'ZTF']
    print(f"File: {cadc_filename}, entries: {len(df_ztf)}")
    if len(df_ztf) == 0:
        continue

    ztf_rows = df_ztf if MAX_IMAGES_PER_COMET is None else df_ztf.head(MAX_IMAGES_PER_COMET)

    # Distinct loop-variable names: the inner loop used to rebind the outer `row`.
    for index, obs_row in ztf_rows.iterrows():
        print(f"Row index: {index}, MJD: {obs_row['MJD']}")
        url = ztf_name_to_irsa_url(obs_row['Image'])

        print(f"Downloading from URL: {url}")
        with set_temp_cache(path=ZTF_CACHE_DIR, delete=False):
            hdul = fits.open(url, cache=True, memmap=True)

        # Keep all use of the (memory-mapped) image data inside this block so the
        # file is closed deterministically once the cutout has been written.
        with hdul:
            hdu = hdul[0]
            header = hdu.header

            # Without an observation epoch the ephemeris cannot be computed.
            if 'OBSMJD' not in header:
                print(f"Skipping {url}: no OBSMJD keyword in the primary header")
                continue

            # Get WCS and observation time
            w = WCS(header)
            obs_time = Time(header['OBSMJD'], format='mjd')

            # Get filter band
            filter_band = header.get('FILTER', 'Unknown').strip()

            # Get exposure time
            exposure = header.get('EXPTIME', 'Unknown')

            # Recompute comet position at observation time
            print(f"Recomuting the comet position at MJD={obs_time.mjd}")
            obj = Horizons(id=comet_name, location='I41', epochs=obs_time.mjd)
            eph = obj.ephemerides()
            ra_comet = eph['RA'][0]
            dec_comet = eph['DEC'][0]
            print(f"Comet position at observation time: RA={ra_comet}, Dec={dec_comet}")

            # Create SkyCoord for the comet position
            coord = SkyCoord(ra_comet*u.deg, dec_comet*u.deg, frame='icrs')

            # Check if the position is within the image FOV
            ny, nx = hdu.data.shape
            x_pix, y_pix = w.world_to_pixel(coord)
            in_fov = (0 <= x_pix < nx) and (0 <= y_pix < ny)
            if not in_fov:
                print(f"Comet position is outside the image ({nx}x{ny} px); no cutout written")
                continue

            size = u.Quantity((CUTOUT_ARCMIN, CUTOUT_ARCMIN), u.arcmin)
            cutout = Cutout2D(hdu.data, (x_pix, y_pix), size, wcs=w)

            # Save cutout to a new FITS file
            hdu_out = fits.PrimaryHDU(data=cutout.data, header=cutout.wcs.to_header())
            hdu_out.header['COMET'] = comet_name
            hdu_out.header['OBSMJD'] = obs_time.mjd
            hdu_out.header['FILTER'] = filter_band
            hdu_out.header['ORIGFILE'] = os.path.basename(url)
            hdu_out.header['RA_COM'] = ra_comet
            hdu_out.header['DEC_COM'] = dec_comet
            hdu_out.header['CUTSIZE'] = f"{CUTOUT_ARCMIN} arcmin"
            hdu_out.header['EXPTIME'] = exposure

            # NOTE(review): the file name only carries the integer MJD, so two
            # images of the same comet from the same night overwrite each other
            # once MAX_IMAGES_PER_COMET is lifted. Left unchanged in case
            # downstream code relies on this naming.
            out_stem = f"{comet_name.replace('/', '_')}_ztf_{int(obs_time.mjd)}"
            out_filename = f"{ZTF_CUTOUT_DIR}/{out_stem}.fits"
            hdu_out.writeto(out_filename, overwrite=True)
            print(f'Wrote cutout to {out_filename}')

            # Plot the cutout with the comet position
            fig = plt.figure(figsize=(5, 5))
            ax = fig.add_subplot(1, 1, 1, projection=cutout.wcs)
            norm = simple_norm(cutout.data, 'sqrt', percent=95.0)
            # Unknown filters fall back to a neutral colormap instead of raising KeyError.
            cmap = ztf_colors.get(filter_band, FALLBACK_CMAP)
            ax.imshow(cutout.data, origin='lower', cmap=cmap, norm=norm)
            x_cut, y_cut = cutout.wcs.world_to_pixel(coord)
            ax.scatter(x_cut, y_cut,
                       s=100, edgecolor='white', facecolor='none', marker='o', lw=1, alpha=0.7)
            # WCSAxes niceties
            ax.set_xlabel("RA")
            ax.set_ylabel("Dec")
            ax.coords.grid(True, alpha=0.3, linestyle="--")
            plt.tight_layout()
            plt.savefig(f"{ZTF_FIG_DIR}/{out_stem}_{filter_band}.png")
            plt.show()
            # Release the figure so long runs do not accumulate open figures.
            plt.close(fig)

    # Only comets that actually had ZTF entries count towards the limit.
    n_comets_done += 1
    if MAX_COMETS is not None and n_comets_done >= MAX_COMETS:
        break