In [1]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from glob import glob
from matplotlib.backends.backend_pdf import PdfPages
from astropy import units as u
from astropy.coordinates import SkyCoord
#pd.options.mode.copy_on_write = True
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM
import multiprocessing as mp
from tqdm import tqdm
# Flat Lambda-CDM cosmology with Planck 2018 parameters (labelled 'Planck18').
cosmo = FlatLambdaCDM(
    name='Planck18',
    H0=67.66,
    Om0=0.30966,
    Tcmb0=2.7255,
    Neff=3.046,
    m_nu=[0., 0., 0.06] * u.eV,
    Ob0=0.04897,
)
# Worker count; presumably intended for the multiprocessing pool (not used in the cells shown).
cores = 8
plt.style.use('ggplot')

In [2]:
from desispec.io import read_spectra
from desitrip.preproc import rebin_flux, rescale_flux

from glob import glob

In [3]:
import warnings
warnings.filterwarnings('ignore')
warnings.filterwarnings('ignore', message='.*read_spectra.*')

In [6]:
# Target grid for rebin_flux: log-spaced bins between minw and maxw
# (presumably Angstrom -- confirm against the input spectra).
minw = 2500.0   # lower wavelength bound
maxw = 9000.0   # upper wavelength bound
nbins = 160     # number of output bins per spectrum

In [9]:
def drop_nan(mytable):
    """Return the rows of ``mytable`` that contain no NaN.

    Only floating-point columns (dtype kind ``'f'``) are inspected; columns of
    any other kind are ignored. Infinities are not treated as NaN and are kept.

    Parameters
    ----------
    mytable : astropy.table.Table
        Input table; it is not modified.

    Returns
    -------
    astropy.table.Table
        ``mytable`` indexed with a boolean row mask that keeps only rows whose
        float columns are all non-NaN.
    """
    keep = np.ones(len(mytable), dtype=bool)
    float_columns = (column for column in mytable.itercols()
                     if column.info.dtype.kind == 'f')
    for column in float_columns:
        keep &= ~np.isnan(column)
    return mytable[keep]

In [10]:
# Column positions (wave, flux, flux_err) for each supported TDE ascii layout,
# keyed by the number of columns in the file. ``None`` means "no error column".
_TDE_ASCII_LAYOUTS = {
    2: (0, 1, None),  # wave, flux
    3: (0, 1, 2),     # wave, flux, flux_err
    # wave, flux, sky_flux, flux_err, xpixel, ypixel[, response]
    # NOTE(review): positions follow the 7-column name list this reader used to
    # pass; confirm the 6-column variant against a real file.
    6: (0, 1, 3),
    7: (0, 1, 3),
}


def _flux_err_to_ivar(flux_err):
    """Convert 1-sigma flux errors to inverse variance (1 / err**2).

    Pixels whose error is zero, negative or non-finite get ivar 0 (no weight)
    instead of producing inf/NaN weights.
    """
    err = np.asarray(flux_err, dtype=float)
    ivar = np.zeros_like(err)
    good = np.isfinite(err) & (err > 0)
    ivar[good] = 1.0 / err[good] ** 2
    return ivar


def _read_tde_ascii(path):
    """Read one ascii TDE spectrum and return ``(wave, flux, ivar)``.

    The file is read once; rows with a NaN in any float column are dropped and
    columns are picked by position via ``_TDE_ASCII_LAYOUTS``. ``ivar`` is
    ``None`` for files without an error column.

    Raises
    ------
    ValueError
        If the file's column count is not in ``_TDE_ASCII_LAYOUTS``.
    """
    spectra = drop_nan(Table.read(path, format="ascii"))
    colnames = spectra.colnames
    try:
        i_wave, i_flux, i_err = _TDE_ASCII_LAYOUTS[len(colnames)]
    except KeyError:
        # Without this, an unknown layout would reuse the previous file's
        # arrays (or hit an unbound name on the first file).
        raise ValueError(
            f"{path}: unsupported ascii spectrum with {len(colnames)} columns; "
            f"expected one of {sorted(_TDE_ASCII_LAYOUTS)}"
        ) from None
    wave = spectra[colnames[i_wave]]
    flux = spectra[colnames[i_flux]]
    # rebin_flux expects an inverse variance, not a 1-sigma error: passing the
    # raw error would give the noisiest pixels the largest weights.
    ivar = None if i_err is None else _flux_err_to_ivar(spectra[colnames[i_err]])
    return wave, flux, ivar


def condition_tde(coadd_files):
    """Read ascii TDE spectra, rebin to a logarithmic wavelength grid, and rescale.

    Parameters
    ----------
    coadd_files : list or ndarray
        Paths of ascii spectrum files with 2, 3, 6 or 7 columns
        (see ``_TDE_ASCII_LAYOUTS``).

    Returns
    -------
    fluxes : list of ndarray
        One rebinned, rescaled flux array per input file, in input order.

    Raises
    ------
    ValueError
        If a file has an unsupported number of columns.
    """
    fluxes = []
    for cf in coadd_files:
        wave, flux, ivar = _read_tde_ascii(cf)
        # Rebin onto the shared log grid, then rescale each flux to [0, 1].
        _, reflux, _ = rebin_flux(wave, flux, ivar, minwave=minw, maxwave=maxw,
                                  nbins=nbins, log=True, clip=True)
        fluxes.append(rescale_flux(reflux))
    return fluxes

In [11]:
def condition_spectra(coadd_files, truth_files):
    """Read DESI spectra, rebin to a subsampled logarithmic wavelength grid, and rescale.

    Each coadd file is paired with the truth file at the same position; that
    file's ``TRUEZ`` redshifts are passed to ``rebin_flux`` together with the
    ``'brz'`` wavelength, flux and inverse-variance arrays.

    Parameters
    ----------
    coadd_files : list or ndarray
        List of FITS files on disk with DESI spectra.
    truth_files : list or ndarray
        Truth files, one per coadd file and in the same order; each is read
        from its ``TRUTH`` HDU and must contain a ``TRUEZ`` column.

    Returns
    -------
    fluxes : ndarray or None
        Rebinned, rescaled fluxes of all files concatenated along the first
        axis, or ``None`` when no files are given.

    Raises
    ------
    ValueError
        If ``coadd_files`` and ``truth_files`` have different lengths.
    KeyError
        If a truth table has no ``TRUEZ`` column.
    """
    coadd_files = list(coadd_files)
    truth_files = list(truth_files)
    if len(coadd_files) != len(truth_files):
        # zip() would silently drop the unpaired tail.
        raise ValueError(
            f"got {len(coadd_files)} coadd files but {len(truth_files)} truth files"
        )

    chunks = []
    for cf, tf in zip(coadd_files, truth_files):
        spectra = read_spectra(cf)
        wave = spectra.wave['brz']
        flux = spectra.flux['brz']
        ivar = spectra.ivar['brz']

        truth = Table.read(tf, 'TRUTH')
        try:
            truez = truth['TRUEZ']
        except KeyError as err:
            # Continuing here would reuse the previous file's redshifts (or hit
            # an unbound name on the first file), so fail with context instead.
            raise KeyError(
                f"truth file {tf!r} (paired with {cf!r}) has no 'TRUEZ' column; "
                f"columns are {truth.colnames}"
            ) from err

        # Rebin and rescale fluxes so that each is normalized between 0 and 1.
        _, reflux, _ = rebin_flux(wave, flux, ivar, truez, minwave=minw, maxwave=maxw,
                                  nbins=nbins, log=True, clip=True)
        chunks.append(rescale_flux(reflux))

    # Concatenate once instead of re-copying the accumulated array every file.
    if not chunks:
        return None
    return np.concatenate(chunks)

In [12]:
# Load every TDE spectrum on disk and condition it onto the common wavelength grid.
tde_files = sorted(glob("TDE_Spectra/*.ascii"))
if not tde_files:
    # An empty match (e.g. wrong working directory) would otherwise silently
    # produce an empty array of shape (0,).
    raise FileNotFoundError(
        f"No TDE spectra matched 'TDE_Spectra/*.ascii' (cwd: {os.getcwd()})"
    )
tde_flux = np.asarray(condition_tde(tde_files))
# One row per file, each rebinned to the same nbins-long grid.
assert tde_flux.shape == (len(tde_files), nbins), tde_flux.shape
tde_flux.shape

(56, 160)