In [12]:
import matplotlib.pyplot as plt
import numpy as np
from astropy.io import fits
import pandas as pd
from astropy.coordinates import SkyCoord
from astropy import units as u
from astropy.table import Table
from astropy.config import set_temp_cache
import os
import time
import urllib.error

In [2]:
df = pd.read_csv('data/dr1_selected_galaxy_byReff_equal_4Alex.csv')
df

Unnamed: 0,TARGETID,SURVEY,PROGRAM,HEALPIX,Z,ZERR,TARGET_RA,TARGET_DEC,OBJTYPE,BRICKID,...,MASSCOR_SL,V0_SL,VD_SL,AV_SL,XJ_SL,AGE_SL,Z_SL,DN4000,DN4000_ERR,Re_kpc
0,39633251291106150,main,dark,6637,0.423628,0.000082,118.716722,50.207093,TGT,585122,...,88364.101562,71.099998,216.309998,0.0884,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.078243e+10,0.023103,1.793298,0.161450,53.835911
1,39633417075163752,main,bright,16008,0.469346,0.000098,283.536197,62.495071,TGT,624648,...,46264.800781,99.959999,258.140015,-0.3780,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",7.692512e+09,0.035361,1.685602,0.223025,34.164835
2,39628514546160743,special,bright,20335,0.349231,0.000055,2.615321,31.665865,TGT,504370,...,130564.000000,71.870003,206.429993,0.1363,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",7.072529e+09,0.042743,1.794236,0.103359,56.366876
3,39627763589580919,sv3,bright,25599,0.263134,0.000048,180.089005,-0.983064,TGT,325328,...,102036.000000,86.550003,173.720001,-0.1394,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.4090999960899...",7.526512e+09,0.038023,1.713095,0.093255,36.278299
4,39628279400892215,main,dark,8961,0.368840,0.000064,225.222316,20.726499,TGT,448307,...,77068.296875,67.900002,187.500000,-0.3192,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",9.468909e+09,0.033063,1.740005,0.092325,72.902456
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
995,39627781046273084,main,bright,27279,0.310520,0.000069,140.539579,-0.170391,TGT,329490,...,132203.000000,66.449997,273.029999,-0.1604,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",7.395761e+09,0.035856,1.812492,0.127308,25.877943
996,39627774897422433,main,bright,21844,0.404497,0.000086,134.002597,-0.516236,TGT,328024,...,39412.601562,79.360001,235.839996,0.6254,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",2.838210e+09,0.043295,1.366483,0.284115,31.443220
997,39633338247415516,main,bright,7071,0.401326,0.000121,99.559496,56.372717,TGT,605854,...,91888.101562,90.320000,157.779999,0.0841,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.300000e+10,0.020000,6.663783,5.857314,8.480782
998,39627841872073028,main,dark,27039,0.454288,0.000075,166.193209,2.168347,TGT,343992,...,72960.296875,88.199997,256.769989,-0.3745,"(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...",1.147934e+10,0.023582,1.843497,0.106270,17.260873


In [3]:
df_healpix = df.groupby(['SURVEY', 'PROGRAM', 'HEALPIX']).size().reset_index(name='COUNT').sort_values(by='COUNT', ascending=False)
df_healpix

Unnamed: 0,SURVEY,PROGRAM,HEALPIX,COUNT
763,main,dark,20353,3
127,main,bright,9172,2
212,main,bright,12675,2
742,main,dark,19327,2
812,main,dark,26252,2
...,...,...,...,...
329,main,bright,20328,1
330,main,bright,20379,1
331,main,bright,20391,1
332,main,bright,21826,1


# Get spectrum from DESI

In [None]:
# Define the spectroscopic product directory
desi_root = "https://data.desi.lbl.gov/public/dr1//spectro/redux/iron"


In [None]:
def save_spectrum_to_fits(filename, wave, flux, ivar, mask):
    """
    Save the spectrum data to a FITS file.
    
    Parameters:
    - targetid: Unique identifier for the target
    - wave: Dictionary with wavelength data for each camera
    - flux: Dictionary with flux data for each camera
    - ivar: Dictionary with inverse variance data for each camera
    - mask: Dictionary with mask data for each camera
    """
    # Ensure the output directory exists
    out_folder = './data/spectra1000/'
    os.makedirs(out_folder, exist_ok=True)
    
    # Define the output path
    out_path = os.path.join(out_folder, filename)
    
    # Create HDUList with PrimaryHDU and BinTableHDUs for each camera
    hdus = []
    hdus.append(fits.PrimaryHDU())
    
    for cam in ['B', 'R', 'Z']:
        cols = [
            fits.Column(name='WAVELENGTH', format='D', array=wave[cam]),
            fits.Column(name='FLUX'      , format='D', array=flux[cam]),
            fits.Column(name='IVAR'      , format='D', array=ivar[cam]),
            fits.Column(name='MASK'      , format='J', array=mask[cam]),
        ]
        table_hdu = fits.BinTableHDU.from_columns(cols, name=f'SPECTRUM_{cam}')
        hdus.append(table_hdu)
    
    # Write the HDUList to a FITS file
    hdulist = fits.HDUList(hdus)
    hdulist.writeto(out_path, overwrite=True)
    print(f"Saved spectrum for TARGETID {targetid} to {out_path}")

In [None]:
def load_fits_spectrum(target, out_folder='./data/spectra1000/'):
    """
    Load the spectrum data from a FITS file.
    
    Parameters:
    - targetid: Unique identifier for the target
    - out_folder: Directory containing the FITS files

    Returns:
    - wave: Dictionary with wavelength data for each camera
    - flux: Dictionary with flux data for each camera
    - ivar: Dictionary with inverse variance data for each camera
    - mask: Dictionary with mask data for each camera
    """
    
    # out_folder = './data/spectra'
    targetid = target['TARGETID']
    survey = target['SURVEY']
    program = target['PROGRAM']
    healpix = target['HEALPIX']

    out_path = os.path.join(out_folder, f'{targetid}_{survey}_{program}_{healpix}.fits')
    print(out_path)

    if not os.path.exists(out_path):
        raise FileNotFoundError(f"Spectrum file for TARGETID {targetid} not found.")
    
    hdu_spectrum = fits.open(out_path)
    
    wave = dict()
    flux = dict()
    ivar = dict()
    mask = dict()
    
    for cam in ['B', 'R', 'Z']:
        wave[cam] = hdu_spectrum[f'SPECTRUM_{cam}'].data['WAVELENGTH']
        flux[cam] = hdu_spectrum[f'SPECTRUM_{cam}'].data['FLUX']
        ivar[cam] = hdu_spectrum[f'SPECTRUM_{cam}'].data['IVAR']
        mask[cam] = hdu_spectrum[f'SPECTRUM_{cam}'].data['MASK']
    
    return wave, flux, ivar, mask


In [6]:
def get_values(hdu_coadd, target_index):
    """
    Extracts wavelength, flux, inverse variance, and mask data for a given target index
    from the coadd HDU.
    """
    # Extract data for the specified target index
    wave = dict()
    flux = dict()
    ivar = dict()
    mask = dict()
    for camera in ['B', 'R', 'Z']:
        wave[camera] = hdu_coadd[f'{camera}_WAVELENGTH'].data
        flux[camera] = hdu_coadd[f'{camera}_FLUX'].data[target_index]
        ivar[camera] = hdu_coadd[f'{camera}_IVAR'].data[target_index]
        mask[camera] = hdu_coadd[f'{camera}_MASK'].data[target_index]
    return wave, flux, ivar, mask


In [8]:
def plot_spectrum(wave, flux, ivar, mask, targetid, redshift, issave=False):
    """
    Plots the spectrum for a given target.
    """
    plt.figure(figsize=(8, 3))
    for camera in ['B', 'R', 'Z']:
        w = wave[camera]
        f = flux[camera]
        i = ivar[camera]
        m = np.bool(mask[camera])
        plt.plot(w[~m], f[~m], label=f'{camera} band', lw=0.5)
        # plt.fill_between(w[~m], f[~m] - 1/np.sqrt(i[~m]), f[~m] + 1/np.sqrt(i[~m]), alpha=0.2)
    plt.xlabel('Wavelength (Angstroms)')
    plt.ylabel('Flux (arbitrary units)')
    plt.title(f'Spectrum for TARGETID {targetid} at redshift {redshift:.3f}')
    plt.legend()
    plt.grid(linestyle='--', alpha=0.5)
    plt.tight_layout()
    if issave:
        plt.savefig(f'./figs/spectra/{targetid}.png', dpi=300)
        plt.close()
    else:
        plt.show()


In [None]:

def download_with_retry(coadd_filepath, max_retries=3, timeout=600, cache_path='./data/cache/'):
    """Download FITS file with retry logic and timeout handling."""
    for attempt in range(max_retries):
        try:
            print(f"Attempt {attempt + 1}/{max_retries} to download: {coadd_filepath}")
            with set_temp_cache(path=cache_path, delete=False):
                hdu_coadd = fits.open(coadd_filepath, cache=True, timeout=timeout)
            return hdu_coadd
        except (TimeoutError, urllib.error.URLError, OSError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # Exponential backoff
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} attempts failed for {coadd_filepath}")
                return None
    return None

In [None]:
def download_with_no_cache_with_retry(coadd_filepath, max_retries=3, timeout=600, cache_path='./data/cache/'):
    """
    Download FITS file with retry logic and timeout handling, without using cache.
    
    Parameters:
    - coadd_filepath: URL of the FITS file to download
    - max_retries: Maximum number of retry attempts
    - timeout: Timeout for each download attempt
    - cache_path: Directory to store the downloaded file
    
    Returns:
    - hdu_coadd: HDUList object containing the FITS data
    """
    for attempt in range(max_retries):
        try:
            print(f"Attempt {attempt + 1}/{max_retries} to download: {coadd_filepath}")
            with set_temp_cache(path=cache_path, delete=False):
                hdu_coadd = fits.open(coadd_filepath, cache=False, timeout=timeout)
            return hdu_coadd
        except (TimeoutError, urllib.error.URLError, OSError) as e:
            print(f"Attempt {attempt + 1} failed: {e}")
            if attempt < max_retries - 1:
                wait_time = 2 ** attempt  # Exponential backoff
                print(f"Waiting {wait_time} seconds before retry...")
                time.sleep(wait_time)
            else:
                print(f"All {max_retries} attempts failed for {coadd_filepath}")
                return None
    return None

In [None]:
# Define the spectroscopic product directory
desi_root = "https://data.desi.lbl.gov/public/dr1//spectro/redux/iron"
out_folder = './data/spectra1000/'
cache_path = './data/cache/'

start = 0
num = 100
for i, (_, row_healpix) in enumerate(df_healpix[start:start + num].iterrows()):
    survey = row_healpix['SURVEY']
    program = row_healpix['PROGRAM']
    healpix = row_healpix['HEALPIX']
    
    # Calculate the HPixGroup
    hpixgroup = healpix // 100
    print(f"Row {i + start}: SURVEY={survey}, PROGRAM={program}, HEALPIX={healpix}, HPixGroup={hpixgroup}")
    # Filename
    coadd_filepath = f'{desi_root}/healpix/{survey}/{program}/{hpixgroup}/{healpix}/coadd-{survey}-{program}-{healpix}.fits'

    cond = (df['SURVEY'] == survey) & (df['PROGRAM'] == program) & (df['HEALPIX'] == healpix)
    df_temp = df[cond]

    num_targets = len(df_temp)
    count = 0
    for _, row in df_temp.iterrows():
        targetid = row['TARGETID']
        filename = f'{targetid}_{survey}_{program}_{healpix}.fits'
        out_path = os.path.join(out_folder, filename)
        if os.path.exists(out_path):
            count += 1
    
    if count == num_targets:
        print(f"All {num_targets} targets for HEALPIX {healpix} already processed. Skipping download.")
        continue

    # Try to download with retry logic
    # hdu_coadd = download_with_retry(coadd_filepath, max_retries=5, timeout=600, cache_path='./data/ucmgs7411/')
    hdu_coadd = download_with_retry(coadd_filepath, max_retries=5, timeout=600, cache_path=cache_path)

    if hdu_coadd is None:
        print(f"Failed to download or open the FITS file for HEALPIX {healpix}. Skipping this HEALPIX.")
        continue

    for i, row in df[cond].iterrows():
        redshift = row['Z']
        targetid = row['TARGETID']
    
        df_fiber = pd.DataFrame(hdu_coadd['FIBERMAP'].data)
        cond = df_fiber['TARGETID'] == targetid
        target_index = df_fiber.index[cond][0]  # Get the first index where condition is True

        wave, flux, ivar, mask = get_values(hdu_coadd, target_index)

        # plot_spectrum(wave, flux, ivar, mask, targetid, redshift, issave=True)
        # Save the spectrum to a FITS file
        filename = f'{targetid}_{survey}_{program}_{healpix}.fits'
        save_spectrum_to_fits(filename, wave, flux, ivar, mask)
        # Load the spectrum from the FITS file
        # 
    # Clear the cache folder
    if (i + start) % 10 == 0:
        print(f"Processed {i + start} HEALPIX groups. Clearing cache folder.")
        folder_to_delete = os.path.join(cache_path, 'astropy')
        os.system(f'rm -rf {folder_to_delete}')
        print(f"Cache folder {folder_to_delete} cleared.")
    


Row 0: SURVEY=main, PROGRAM=dark, HEALPIX=20353, HPixGroup=203
All 3 targets for HEALPIX 20353 already processed. Skipping download.
Row 1: SURVEY=main, PROGRAM=bright, HEALPIX=9172, HPixGroup=91
Attempt 1/5 to download: https://data.desi.lbl.gov/public/dr1//spectro/redux/iron/healpix/main/bright/91/9172/coadd-main-bright-9172.fits
Saved spectrum for TARGETID 39632971581358715 to ./data/spectra7411/39632971581358715_main_bright_9172.fits
Saved spectrum for TARGETID 39632976559999708 to ./data/spectra7411/39632976559999708_main_bright_9172.fits
Processed 900 HEALPIX groups. Clearing cache folder.
Cache folder ./data/cache/astropy cleared.
Row 2: SURVEY=main, PROGRAM=bright, HEALPIX=12675, HPixGroup=126
Attempt 1/5 to download: https://data.desi.lbl.gov/public/dr1//spectro/redux/iron/healpix/main/bright/126/12675/coadd-main-bright-12675.fits
Attempt 1 failed: <urlopen error File was supposed to be 524134080 bytes but we only got 279319662 bytes. Download failed.>
Waiting 1 seconds before