In [40]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import os
import sys
from glob import glob
from matplotlib.backends.backend_pdf import PdfPages
from astropy import units as u
from astropy.coordinates import SkyCoord
#pd.options.mode.copy_on_write = True
from astropy.table import Table
from astropy.cosmology import FlatLambdaCDM
import multiprocessing as mp
from tqdm import tqdm
import astropy.io.fits as fits
cosmo = FlatLambdaCDM(name='Planck18', H0=67.66, Om0=0.30966, Tcmb0=2.7255, Neff=3.046, m_nu=[0.  , 0.  , 0.06]* u.eV, Ob0=0.04897)
cores = 8
plt.style.use('ggplot')

In [22]:
from desispec.io import read_spectra
from desitrip.preproc import rebin_flux, rescale_flux

from glob import glob

In [49]:
minw, maxw, nbins = 3500., 8000., 160

In [15]:
def drop_nan(mytable):
    has_nan = np.zeros(len(mytable), dtype=bool)
    for col in mytable.itercols():
        if col.info.dtype.kind == 'f':
            has_nan |= np.isnan(col)
    mytable_no_nan = mytable[~has_nan]
    return mytable_no_nan

In [16]:
def condition_tde(coadd_files):
    """Read DESI spectra, rebin to a subsampled logarithmic wavelength grid, and rescale.
    
    Parameters
    ----------
    coadd_files : list or ndarray
        List of ascii files
    
    Returns
    -------
    fluxes : ndarray
        Array of fluxes rebinned to a logarithmic wavelength grid.
    """
    fluxes = []
    
    for cf in coadd_files:
        which = len(Table.read(cf, format = "ascii").colnames)
        if(which == 3):
            spectra = drop_nan(Table.read(cf, format = "ascii", names = ["wave", "flux", "flux_err"]))
            wave = spectra["wave"]
            flux = spectra["flux"]
            ivar = spectra["flux_err"]
        elif(which == 6):
            spectra = drop_nan(Table.read(cf, format = "ascii", names = ["wave", "flux", "sky_flux", "flux_err", "xpixel", "ypixel", "resopnse"]))
            wave = spectra["wave"]
            flux = spectra["flux"]
            ivar = spectra["flux_err"]
        elif(which == 2):
            spectra = drop_nan(Table.read(cf, format = "ascii", names = ["wave", "flux"]))
            wave = spectra["wave"]
            flux = spectra["flux"]
            ivar = None

#         # Pre-condition: remove spectra with NaNs and zero flux values.
#         mask = np.isnan(flux).any(axis=1) | (np.count_nonzero(flux, axis=1) == 0)
#         mask_idx = np.argwhere(mask)
#         flux = np.delete(flux, mask_idx, axis=0)
#         ivar = np.delete(ivar, mask_idx, axis=0)

        # Rebin and rescale fluxes so that each is normalized between 0 and 1.
        spectra_table = Table()
        spectra_table["wave"] = wave
        spectra_table["flux"] = flux

        fluxes.append(spectra_table)
    return fluxes

In [17]:
def normalize_tde(tables):
    """Read DESI spectra, rebin to a subsampled logarithmic wavelength grid, and rescale.
    
    Parameters
    ----------
    coadd_files : list or ndarray
        List of ascii files
    
    Returns
    -------
    fluxes : ndarray
        Array of fluxes rebinned to a logarithmic wavelength grid.
    """
    fluxes = []
    
    for cf in tables:
        wave = cf["wave"]
        flux = cf["flux"]

#         # Pre-condition: remove spectra with NaNs and zero flux values.
#         mask = np.isnan(flux).any(axis=1) | (np.count_nonzero(flux, axis=1) == 0)
#         mask_idx = np.argwhere(mask)
#         flux = np.delete(flux, mask_idx, axis=0)
#         ivar = np.delete(ivar, mask_idx, axis=0)

        # Rebin and rescale fluxes so that each is normalized between 0 and 1.
        rewave, reflux, reivar = rebin_flux(wave, flux, minwave=minw, maxwave=maxw, nbins=nbins, log=True, clip=True)
        rsflux = rescale_flux(reflux)

        fluxes.append(rsflux)
    return fluxes

In [41]:
tde_files = sorted(glob("TDE_Spectra/*.ascii"))
tde_tables = condition_tde(tde_files)

In [50]:
unredshifted_tdes = []
zs = np.linspace(0.21, 1.21, 21)
for i in tde_tables:
    for z in zs:
        new_tde = Table()
        redshifted_wave = i["wave"]/(1 + z)
        redshifted_wave_cutoffs = np.where((redshifted_wave > minw) & (redshifted_wave < maxw))
        new_tde["wave"] = i["wave"][redshifted_wave_cutoffs]
        new_tde["flux"] = i["flux"][redshifted_wave_cutoffs]
        unredshifted_tdes.append(new_tde)

In [54]:
unredshifted_tdes[6] 

wave,flux
float64,float64
5286.66047935,2.85724e-15
5288.83587437,2.94108e-15
5291.01126938,2.91795e-15
5293.18666439,2.96881e-15
5295.3620594,2.904e-15
5297.53745441,2.88082e-15
5299.71284943,2.90797e-15
5301.88824444,2.91961e-15
5304.06363945,2.95796e-15
5306.23903446,2.93858e-15


In [27]:
tde_flux = np.asarray(normalize_tde(unredshifted_tdes))
tde_flux.shape

(1176, 160)

In [29]:
path = "/global/cfs/cdirs/desi/spectro/templates"

In [39]:
Table(fits.open("/global/cfs/cdirs/desi/spectro/templates/sne_templates/v1.0/sne_templates_v1.0.fits")[1].data)

TEMPLATEID,EPOCH
int32,float32
0,-18.0
1,-17.0
2,-16.0
3,-15.0
4,-14.0
5,-13.0
6,-12.0
7,-11.0
8,-10.0
9,-9.0
