# Synthetic Strong Lens Generator

This notebook grabs pairs of spectra from observed DESI tiles and combines them according to the expression

$$
\mathcal{M}_3 = \alpha\mathcal{M}_1 + (1-\alpha)\mathcal{M}_2,
$$

where $\mathcal{M}_1$ is a model (redrock template) fit to the first observation and $\mathcal{M}_2$ is a model fit to the second spectrum. The value $\alpha$ is a uniformly generated parameter, e.g., $\alpha\sim U(0.1, 0.9)$. Given a model $\mathcal{M}_3$, the code computes a new realization of the flux of the combined spectra $f_3$.
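
The expression above can be tried on plain NumPy arrays. This is only a minimal sketch: the arrays `m1` and `m2` and the constant noise level `sigma` are made-up stand-ins for the redrock model fits and the per-pixel noise (the notebook itself derives the noise from the residuals between the observed fluxes and the models).

```python
import numpy as np

rng = np.random.default_rng(seed=1)

# Hypothetical stand-ins for two redrock model fits on a shared wavelength grid.
nwave = 50
m1 = np.linspace(1.0, 2.0, nwave)      # "lens" model
m2 = np.linspace(3.0, 0.5, nwave)      # "background" model

# Mixing fraction drawn uniformly, as in alpha ~ U(0.1, 0.9).
alpha = rng.uniform(0.1, 0.9)

# Combined model M3 = alpha*M1 + (1 - alpha)*M2.
m3 = alpha * m1 + (1 - alpha) * m2

# One noisy realization f3 around the combined model (sigma is an assumed noise level).
sigma = 0.1
f3 = rng.normal(loc=m3, scale=sigma)
```

Because $\alpha$ lies in $[0,1]$, every pixel of the combined model falls between the two input models.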

Pairs of spectra are chosen (as of June 2021) from the same tile, and only pairs with equal total exposure times are kept. Good redshift fits (`DELTACHI2>=25` and `ZWARN==0`) are required, and objects classified as stars are excluded. The member of the pair with the lower redshift value is referred to as the "lens," while the higher-redshift spectrum is the "background" object.
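
A rough sketch of the selection and pairing logic, using made-up catalog values in plain NumPy arrays rather than the redrock tables. The cut mirrors the one applied in `_get_redrock_data`; the column values and the choice `tilefrac = 0.4` are assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(seed=2)

# Made-up redshift catalog for one tile.
targetid  = np.array([101, 102, 103, 104, 105, 106, 107, 108])
z         = np.array([0.35, 1.20, 0.80, 0.15, 0.95, 0.60, 1.45, 0.25])
zwarn     = np.array([0, 0, 4, 0, 0, 0, 0, 0])
deltachi2 = np.array([120., 30., 500., 10., 25., 80., 60., 44.])

# Keep only good redshift fits.
good = (zwarn == 0) & (deltachi2 >= 25)
ids, zs = targetid[good], z[good]

# Draw disjoint pairs without replacement; tilefrac sets the number of pairs.
tilefrac = 0.4
npairs = int(tilefrac * ids.size)
picks = rng.choice(ids.size, size=(npairs, 2), replace=False)

# Within each pair, put the lower-redshift member (the "lens") first.
order = np.argsort(zs[picks], axis=1)
picks = np.take_along_axis(picks, order, axis=1)
lens_ids, bkgd_ids = ids[picks[:, 0]], ids[picks[:, 1]]
```

Sampling without replacement guarantees that no object appears in more than one pair.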

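All spectra are saved in `desispec.Spectra` format. The following files are kept, where `XXXXXX` is the zero-padded six-digit tile ID:
- `tileXXXXXX_sgl_spectra.fits`: Spectra of objects with good redshifts that were not used in any pair.
- `tileXXXXXX_fgr_spectra.fits`: Spectra of the lower-redshift ("lens") objects in the selected pairs.
- `tileXXXXXX_bkg_spectra.fits`: Spectra of the higher-redshift ("background") objects in the selected pairs.
- `tileXXXXXX_simlens_spectra.fits`: Combined spectra of the pairs.
- `tileXXXXXX_sgl_coadd.fits` and `tileXXXXXX_simlens_coadd.fits`: Camera-coadded versions of the single-object and combined spectra.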

The `Spectra` are stored with two additional tables:
1. `extra`: a table of model fits from redrock, keyed by camera band ('b', 'r', 'z').
2. `extra_catalog`: a table of best-fit redshifts (read from the redrock `REDSHIFTS` HDU) and template fit coefficients.
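
To make the layout of `extra` concrete, here is a small sketch of the nested structure that `_get_redrock_models` builds: one entry per camera band, each holding a `'model'` array with one row per spectrum. The shapes below are arbitrary placeholders, and the arrays are filled with zeros rather than real model fluxes.

```python
import numpy as np

# Assumed shapes for illustration only: 3 spectra, 5 wavelength pixels per camera.
nspec, nwave = 3, 5

# `extra`: one entry per camera band, each holding a 'model' array of shape (nspec, nwave).
extra = {band: {'model': np.zeros((nspec, nwave))} for band in 'brz'}

# Row i of each 'model' array is the model flux for spectrum i in that camera.
model_b_first = extra['b']['model'][0]
```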

### For the Impatient

The majority of this notebook contains functions that handle the spectrum bookkeeping. The code that actually generates the lenses is in the bottom 1/3 of the notebook starting in the section on [Generating Lenses](#generate_lens).

In [1]:
from glob import glob

from astropy.io import fits
from astropy.table import Table, join, vstack, hstack, unique

from desispec.io import read_spectra, write_spectra
from desispec.spectra import stack as specstack
from desispec.spectra import Spectra
from desispec.coaddition import coadd, coadd_cameras
from desispec.interpolation import resample_flux
from desispec.resolution import Resolution
from desispec.specscore import compute_coadd_scores

import redrock.templates

import copy
import shutil
import os

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt

In [2]:
mpl.rc('font', size=16)
mpl.rc('axes', titlesize='medium')

## Set up Redrock Templates

Instantiate a set of redrock templates for later use in constructing $\mathcal{M}_1$ and $\mathcal{M}_2$.
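
For orientation, a model is later built from a template as a linear combination of its basis spectra, shifted to the observed frame. The sketch below mimics that step with made-up arrays. The names `basis`, `coeff`, and `obs_wave` are hypothetical, `np.interp` is used as a simple stand-in for `desispec.interpolation.resample_flux`, and the resolution-matrix step is omitted.

```python
import numpy as np

# Hypothetical template: 3 basis spectra on a rest-frame wavelength grid (Angstroms).
rest_wave = np.linspace(3000.0, 6000.0, 301)
basis = np.vstack([np.ones_like(rest_wave),
                   rest_wave / 4500.0,
                   np.exp(-0.5 * ((rest_wave - 3727.0) / 5.0) ** 2)])   # shape (ncoeff, nwave)

# Hypothetical best-fit coefficients and redshift.
coeff = np.array([1.0, 0.5, 2.0])
z = 0.5

# Linear combination of the basis spectra, then shift to the observed frame.
tflux = basis.T.dot(coeff)
twave = rest_wave * (1 + z)

# Interpolate onto an observed wavelength grid (stand-in for flux-conserving resampling).
obs_wave = np.linspace(4600.0, 8900.0, 200)
model = np.interp(obs_wave, twave, tflux)
```

The emission feature placed near 3727 Å in the rest frame shows up near $3727\,(1+z)$ Å in the observed-frame model.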

In [3]:
templates = dict()
for f in redrock.templates.find_templates():
    t = redrock.templates.Template(f)
    templates[(t.template_type, t.sub_type)] = t

DEBUG: Read templates from /global/common/software/desi/cori/desiconda/20190804-1.3.0-spec/code/redrock-templates/master
DEBUG: Using default redshift range -0.0050-1.6997 for rrtemplate-galaxy.fits
DEBUG: Using default redshift range 0.0500-5.9934 for rrtemplate-qso.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-A.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-B.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-CV.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-F.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-G.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-K.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-M.fits
DEBUG: Using default redshift range -0.0020-0.0020 for rrtemplate-star-WD.fits


## Bookkeeping Functions

A set of convenience functions to extract redshifts and spectra for the tile from a given spectroscopic reduction.
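
One recurring bookkeeping step is reordering spectra so that their `TARGETID` order matches a redshift table. The class does this with a NumPy broadcasting idiom, sketched here with made-up IDs. The sketch assumes every ID is unique and present in both arrays.

```python
import numpy as np

# Made-up TARGETIDs: the order wanted (catalog) and the order found in the spectra.
catalog_ids = np.array([30, 10, 20])
spectra_ids = np.array([10, 20, 30])

# Compare every catalog ID against every spectrum ID. The column indices of the
# matches give, for each catalog row in turn, the position of that target in the spectra.
idx = np.nonzero(catalog_ids[:, None] == spectra_ids)[1]

# Indexing the spectra with idx puts them in catalog order.
reordered = spectra_ids[idx]
```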

In [4]:
class SynthLens:
    
    def __init__(self, tileid, tilepath, destpath, tilefrac=0.15, alpha_min=0.1, alpha_max=0.9):
        """Generate synthetic strong lenses from one tile.
        
        Parameters
        ----------
        tileid : int
            DESI tile ID.
        tilepath : str
            Path to coadd and redrock FITS files for one tile.
        destpath : str
            Output directory for generated spectra.
        tilefrac : float
            Fraction of spectra in a tile to pair up in "lens" systems.
        alpha_min : float
            Number in [0,1] giving the minimum fractional contribution of the lens to the combined spectrum.
        alpha_max : float
            Number in [0,1] giving the maximum fractional contribution of the lens to the combined spectrum.
        """
        self.tileid   = tileid
        self.tilepath = tilepath
        self.destpath = destpath
        self.tilefrac = tilefrac
        self.alphamin = alpha_min
        self.alphamax = alpha_max
        
        self.redshifts = None          # Table of redrock best-fit redshifts after quality cuts.
        self.fibermap = None           # Fibermap corresponding to extracted spectra.
        self.exp_fibermap = None       # Exposure fibermap corresponding to extracted spectra.
        
        self.z_fgr = None              # Table of redshift data for the randomly selected low-redshift "foreground" object.
        self.z_bkg = None              # Table of redshift data for the randomly selected high-redshift "background" object.
        self.z_join = None             # A Cartesian join of the z_fgr and z_bkg rows, plus an ALPHA column.
        self.z_sgl = None              # Table of redshift data for objects not used as synthetic strong lenses.
        
        self.target_fgr = None         # The TARGETID, TILEID, PETAL_LOC of the "foreground" objects in the lens system.
        self.target_bkg = None         # The TARGETID, TILEID, PETAL_LOC of the "background" objects in the lens system.
        self.target_sgl = None         # The TARGETID, TILEID, PETAL_LOC of objects *not* used as synthetic strong lenses.
        
        self.spec_sgl = None           # Spectra of single objects in the tile.
        self.spec_fgr = None           # Spectra of foreground objects (the lensing galaxies).
        self.spec_bkg = None           # Spectra of background objects.
        self.spec_len = None           # Spectra of synthetic strong lensing systems.
        
        # 1. Grab redrock data for this tile and extract good spectra.
        self._get_redrock_data()
        
        # 2. Randomly pair up objects to form "strong lenses."
        self._select_pairs()
        
        # 3. Extract coadd spectra for the selected objects.
        self._get_coadds()
        
        # 4. Add foreground and background spectra to produce lens spectra.
        self._generate_lens_spectra()
        
        # 5. Write output.
        self._write_fits()
        
        # 6. Update group permissions and copy to output.
        self._copy_output()
        
    def _get_redrock_data(self):
        """Given a path to a tile folder, extract all REDSHIFT and EXP_FIBERMAP data from the tile.
            
        Returns
        -------
        nspec : int
            Number of spectra with passing redshifts pulled from this tile in total.
        """
        rrfiles = sorted(glob('{}/redrock*.fits'.format(self.tilepath)))
        
        # Loop through all files corresponding to individual petals.
        for rrfile in rrfiles:
            redshifts = Table.read(rrfile, 'REDSHIFTS')
            
            # Select all non-stellar objects with solid redshifts.
            select = (redshifts['SPECTYPE'] != 'STAR') & (redshifts['ZWARN'] == 0) & (redshifts['DELTACHI2'] >= 25) & (redshifts['TARGETID'] > 0)
            redshifts = redshifts[select]
            
            # Select the exp_fibermap table with all selected redshifts.
            exp_fmap = Table.read(rrfile, 'EXP_FIBERMAP')
            idx = np.in1d(exp_fmap['TARGETID'], redshifts['TARGETID'])
            exp_fmap = exp_fmap[idx]
            
            # Accumulate all redrock and exp_fibermap tables for the tile.
            if self.redshifts is None or self.exp_fibermap is None:
                self.redshifts = redshifts
                self.exp_fibermap = exp_fmap
            else:
                self.redshifts = vstack([self.redshifts, redshifts])
                self.exp_fibermap = vstack([self.exp_fibermap, exp_fmap])
        
        return len(self.redshifts)
    
    def _select_pairs(self):
        """Given a redshift and fibermap table, select random pairs of spectra (without replacement).
        """
        # Select pairs of galaxies and QSOs.
        pairs = np.random.choice(self.redshifts['TARGETID'], [int(self.tilefrac*len(self.redshifts)), 2], replace=False)

        # Loop through pairs of TARGETIDs.
        for pair in pairs:
            # Check the fibermap table to ensure the exposure times of the lens and background are equal.
            # Unequal total exposure times could occur if a petal or a CANBus was disabled during one exposure. 
            select0 = np.in1d(self.exp_fibermap['TARGETID'], pair[0])
            select1 = np.in1d(self.exp_fibermap['TARGETID'], pair[1])
            i, j = np.where(select0)[0], np.where(select1)[0]
            exptime0, exptime1 = [np.sum(self.exp_fibermap[_]['EXPTIME']) for _ in (i,j)]
            if exptime0 != exptime1:
                continue

            # Sort to make index k the "lens" and index l the "background" galaxy.
            select = np.in1d(self.redshifts['TARGETID'], pair)
            pairdata = self.redshifts[select]
            k, l = np.where(select)[0] if pairdata[0]['Z'] < pairdata[1]['Z'] else np.where(select)[0][::-1]

            # Then join the two redshift table entries into one row.
            row = join(self.redshifts[k], self.redshifts[l], join_type='cartesian')

            # Generate a random number giving the relative contribution of objects 1 and 2.
            alpha = np.random.uniform(self.alphamin, self.alphamax)
            row['ALPHA'] = alpha

            # Accumulate the rows of pairs of "lenses" and "background" galaxies.
            if self.z_join is None:
                self.z_join = row
                self.z_fgr = self.redshifts[k]
                self.z_bkg = self.redshifts[l]
            else:
                self.z_join = vstack([self.z_join, row])
                self.z_fgr = vstack([self.z_fgr, self.redshifts[k]])
                self.z_bkg = vstack([self.z_bkg, self.redshifts[l]])

        # Index j selects rows from the fibermap that passed the redshift cuts.
        j = np.in1d(self.exp_fibermap['TARGETID'], self.redshifts['TARGETID'])

        # Index i selects targets chosen as "foreground" objects.
        # Remove these from index j.
        i = np.in1d(self.exp_fibermap['TARGETID'], self.z_fgr['TARGETID'])
        self.target_fgr = unique(self.exp_fibermap[i]['TARGETID', 'TILEID', 'PETAL_LOC'])

        j = ~i & j

        # Index i selects targets paired with lenses as "background" objects.
        # Remove these from index j.
        i = np.in1d(self.exp_fibermap['TARGETID'], self.z_bkg['TARGETID'])
        self.target_bkg = unique(self.exp_fibermap[i]['TARGETID', 'TILEID', 'PETAL_LOC'])

        j = ~i & j

        # Tabulate all objects with good redshifts that are neither "lenses" nor "background" objects.
        # Also output the redshift table for this list of objects.
        self.target_sgl = unique(self.exp_fibermap[j]['TARGETID', 'TILEID', 'PETAL_LOC'])
        idx = np.nonzero(self.target_sgl['TARGETID'][:,None] == self.redshifts['TARGETID'])[1]
        self.z_sgl = self.redshifts[idx]
        
    def _get_coadds(self):
        """Given REDSHIFT data and TARGET info, extract coadds to Spectra objects.
        """
        cofiles = sorted(glob('{}/coadd*.fits'.format(self.tilepath)))
        
        # Loop over all petals.
        ef_sgl, ef_fgr, ef_bkg = None, None, None
        
        for cofile in cofiles:
            petal, tile = [int(_) for _ in os.path.basename(cofile).split('-')[1:3]]
            coadds = read_spectra(cofile)
            
            # Extract spectra for the single-object (unpaired) galaxies in this petal.
            targetids = self.target_sgl[self.target_sgl['PETAL_LOC'] == petal]['TARGETID']
            i = np.in1d(coadds.fibermap['TARGETID'], targetids)
            j = np.in1d(coadds.exp_fibermap['TARGETID'], targetids)

            if self.spec_sgl is None:
                self.spec_sgl = coadds[i]
                ef_sgl = coadds.exp_fibermap[j]
            else:
                self.spec_sgl = specstack([self.spec_sgl, coadds[i]])
                ef_sgl = vstack([ef_sgl, coadds.exp_fibermap[j]])
            
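            # Debug output: number of exposure-fibermap rows for the unpaired targets in this petal.
            # Use mask j (built from exp_fibermap); mask i has the length of fibermap instead.
            if coadds.exp_fibermap is not None:
                print(len(coadds.exp_fibermap[j]))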
                
            # Extract spectra for the foreground galaxies in this petal.
            targetids = self.target_fgr[self.target_fgr['PETAL_LOC'] == petal]['TARGETID']
            i = np.in1d(coadds.fibermap['TARGETID'], targetids)
            j = np.in1d(coadds.exp_fibermap['TARGETID'], targetids)

            if self.spec_fgr is None:
                self.spec_fgr = coadds[i]
                ef_fgr = coadds.exp_fibermap[j]
            else:
                self.spec_fgr = specstack([self.spec_fgr, coadds[i]])
                ef_fgr = vstack([ef_fgr, coadds.exp_fibermap[j]])
                
            # Extract spectra for the background galaxies in this petal.
            targetids = self.target_bkg[self.target_bkg['PETAL_LOC'] == petal]['TARGETID']
            i = np.in1d(coadds.fibermap['TARGETID'], targetids)
            j = np.in1d(coadds.exp_fibermap['TARGETID'], targetids)

            if self.spec_bkg is None:
                self.spec_bkg = coadds[i]
                ef_bkg = coadds.exp_fibermap[j]
            else:
                self.spec_bkg = specstack([self.spec_bkg, coadds[i]])
                ef_bkg = vstack([ef_bkg, coadds.exp_fibermap[j]])

        # Unscramble the indices so the order of the spectra matches our coadd.
        # Then store redrock outputs in the extra_catalog and extra members of Spectra.
        idx = np.nonzero(self.z_sgl['TARGETID'][:,None] == self.spec_sgl.fibermap['TARGETID'])[1]
        self.spec_sgl = self.spec_sgl[idx]
        self.spec_sgl.extra_catalog = self.z_sgl
        self.spec_sgl.extra = self._get_redrock_models(self.spec_sgl)
        self.spec_sgl.exp_fibermap = ef_sgl
        
        idx = np.nonzero(self.z_fgr['TARGETID'][:,None] == self.spec_fgr.fibermap['TARGETID'])[1]
        self.spec_fgr = self.spec_fgr[idx]
        self.spec_fgr.extra_catalog = self.z_fgr
        self.spec_fgr.extra = self._get_redrock_models(self.spec_fgr)
        self.spec_fgr.exp_fibermap = ef_fgr
        
        idx = np.nonzero(self.z_bkg['TARGETID'][:,None] == self.spec_bkg.fibermap['TARGETID'])[1]
        self.spec_bkg = self.spec_bkg[idx]
        self.spec_bkg.extra_catalog = self.z_bkg
        self.spec_bkg.extra = self._get_redrock_models(self.spec_bkg)
        self.spec_bkg.exp_fibermap = ef_bkg
        
    def _get_redrock_models(self, targspec):
        """Given Spectra + redshift, compute best-fit redrock model fluxes.
        
        Parameters
        ----------
        targspec : Spectra
            Input spectra with redshift info in the extra_catalog member.
        
        Returns
        -------
        model : dict
            Model fluxes keyed by spectrograph camera.
        """
        
        model = {}
        for band in 'brz':
            bandmodel = []
            
            for i in range(targspec.num_spectra()):
                z = targspec.extra_catalog[i]['Z']
                sp, sb = targspec.extra_catalog[i]['SPECTYPE'], targspec.extra_catalog[i]['SUBTYPE']
                ncoeff = templates[(sp, sb)].flux.shape[0]
                coeff = targspec.extra_catalog[i]['COEFF'][0:ncoeff]
                tflux = templates[(sp, sb)].flux.T.dot(coeff)
                twave = templates[(sp, sb)].wave * (1 + z)

                R = Resolution(targspec.resolution_data[band][i])
                txflux = R.dot(resample_flux(targspec.wave[band], twave, tflux))
                bandmodel.append(txflux)

            model[band] = { 'model' : np.asarray(bandmodel) }
        return model

#         # Turn off scores.
#         if not hasattr(targspec, 'scores_comments'):
#             targspec.scores_comments = None

#         return targspec

    def _generate_lens_spectra(self):
        """Generate realizations of strong lenses by adding the spectra.
        The result is combo = alpha*lens + (1-alpha)*bkgd
        """
        # Build up a list of arrays and dictionaries needed to instantiate Spectra.
        bands = []
        wave = {} 
        flux = {}
        ivar = {}
        mask = {}
        resolution = {}
        fibermap = None
        extra = {}
        extra_catalog=None

        # Loop through the observed bands and merge the model fits.
        for band in 'brz':
            f1, w1 = self.spec_fgr.flux[band], self.spec_fgr.ivar[band]
            m1 = []
            f2, w2 = self.spec_bkg.flux[band], self.spec_bkg.ivar[band]
            m2 = []
            alpha = self.z_join['ALPHA'][:,None]
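            # Inverse variance of the combination. Pixels with zero weight in both spectra get zero weight (avoids 0/0).
            denom = alpha*w2 + (1-alpha)*w1
            w3 = np.divide(w1*w2, denom, out=np.zeros_like(denom), where=denom > 0)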

            # Add the models using the alpha parameter to tune the relative contribution of the lens and background object.
            m1 = self.spec_fgr.extra[band]['model']
            m2 = self.spec_bkg.extra[band]['model']
            m3 = alpha*m1 + (1-alpha)*m2

            # Compute a "noise" vector using the differences between the observed fluxes and model fits.
            n3 = np.sqrt(alpha*(f1-m1)**2 + (1-alpha)*(f2-m2)**2)

            # Create a realized flux as a Gaussian with expectation given by the model and 
            # width given by the noise vector.
            f3 = np.random.normal(loc=m3, scale=n3)

            # Set up the spectrum wavelength, flux, variance, mask, bands.
            wave[band] = self.spec_fgr.wave[band]
            flux[band] = f3
            ivar[band] = w3
            mask[band] = self.spec_fgr.mask[band] | self.spec_bkg.mask[band]
            resolution[band] = self.spec_fgr.resolution_data[band]       # Maybe try to add the resolution matrices properly?
            bands.append(band)
            extra[band] = { 'model' : m3 }

        # Set up the fibermap as a join of the lens and background fibermaps.
    #     for row1, row2 in zip(lenspec.fibermap, bkgspec.fibermap):
    #         newrow = join(row1, row2, join_type='cartesian')
    #         if fibermap is None:
    #             fibermap = newrow
    #         else:
    #             fibermap = vstack([fibermap, newrow])
    #     fibermap['TARGETID'] = fibermap['TARGETID_1']

        # Add redshift info from the two individual spectra as an extra catalog.
        extra_catalog = self.z_join

        self.spec_len = Spectra(bands, wave, flux, ivar, mask, resolution_data=resolution,
                                fibermap=self.spec_fgr.fibermap,
                                exp_fibermap=self.spec_fgr.exp_fibermap,
                                extra=extra,
                                extra_catalog=extra_catalog)
        
        compute_coadd_scores(self.spec_len)

    def _write_fits(self):
        """Write extracted spectra to FITS output.
        """
        write_spectra('tile{:06d}_sgl_spectra.fits'.format(self.tileid), self.spec_sgl)
        write_spectra('tile{:06d}_fgr_spectra.fits'.format(self.tileid), self.spec_fgr)
        write_spectra('tile{:06d}_bkg_spectra.fits'.format(self.tileid), self.spec_bkg)
        write_spectra('tile{:06d}_simlens_spectra.fits'.format(self.tileid), self.spec_len)
        
        # Coadd the cameras and store the output.
        extra, self.spec_sgl.extra = copy.copy(self.spec_sgl.extra), None
        coadd = coadd_cameras(self.spec_sgl)
        self.spec_sgl.extra = extra
        write_spectra('tile{:06d}_sgl_coadd.fits'.format(self.tileid), coadd)
        
        extra, self.spec_len.extra = copy.copy(self.spec_len.extra), None
        coadd = coadd_cameras(self.spec_len)
        self.spec_len.extra = extra
        write_spectra('tile{:06d}_simlens_coadd.fits'.format(self.tileid), coadd)
        
    def _copy_output(self):
        """Copy output spectra to desired final location.
        """
        outspecs = sorted(glob('tile{:06d}*.fits'.format(self.tileid)))
        for outspec in outspecs:
            shutil.chown(os.path.abspath(outspec), group='desi')
            shutil.move(os.path.abspath(outspec), os.path.join(self.destpath, outspec))
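
The weight arithmetic in `_generate_lens_spectra` can be examined in isolation. The sketch below uses made-up inverse variances and mixes the variances linearly in $\alpha$, as the class does. It uses `np.divide` with a `where` mask so that pixels masked in both input spectra receive zero weight rather than NaN.

```python
import numpy as np

# Made-up inverse variances for 2 spectra x 4 pixels; zeros mark masked pixels.
w1 = np.array([[4.0, 4.0, 0.0, 1.0],
               [2.0, 0.0, 0.0, 8.0]])
w2 = np.array([[4.0, 1.0, 0.0, 1.0],
               [2.0, 5.0, 0.0, 8.0]])
alpha = np.array([0.25, 0.5])[:, None]    # one mixing fraction per pair

# Variances mixed linearly in alpha: 1/w3 = alpha/w1 + (1 - alpha)/w2,
# rewritten as w3 = w1*w2 / (alpha*w2 + (1 - alpha)*w1).
denom = alpha * w2 + (1 - alpha) * w1

# Pixels where both inputs are masked would give 0/0; assign them zero weight instead.
w3 = np.divide(w1 * w2, denom, out=np.zeros_like(denom), where=denom > 0)
```

A pixel masked in only one of the two spectra also ends up with zero weight, since the numerator `w1*w2` vanishes.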

## Generate Lensed Spectra
<a id='generate_lens'></a>

Build a list of SV tiles and generate synthetic strong lenses by pairing spectra.
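
The tile selection amounts to walking a directory tree of the form `<tileid>/<date>` and keeping the most recent date for each tile in a given ID range. A self-contained sketch on a throwaway directory tree follows. The helper name `latest_tile_dirs` and the tile IDs and dates used to populate the tree are hypothetical.

```python
import os
import tempfile
from glob import glob

def latest_tile_dirs(root, min_id=None, max_id=None):
    """Return the last (sorted) subdirectory of each tile directory under root.

    Tile directories are assumed to be named by integer tile ID and to hold
    date-named subdirectories; min_id/max_id are exclusive bounds on the tile ID.
    """
    selected = []
    for tile in sorted(glob(os.path.join(root, '*'))):
        tileid = int(os.path.basename(tile))
        if min_id is not None and tileid <= min_id:
            continue
        if max_id is not None and tileid >= max_id:
            continue
        subdirs = sorted(glob(os.path.join(tile, '*')))
        if subdirs:
            selected.append(subdirs[-1])
    return selected

# Build a throwaway directory tree to exercise the function (IDs and dates are made up).
with tempfile.TemporaryDirectory() as root:
    for tileid, date in [(1, '20210405'), (1, '20210513'), (80605, '20201215')]:
        os.makedirs(os.path.join(root, str(tileid), date))
    low = [os.path.relpath(d, root) for d in latest_tile_dirs(root, max_id=20000)]
    high = [os.path.relpath(d, root) for d in latest_tile_dirs(root, min_id=80000)]
```

Here `low` holds the latest date directory for tile 1 and `high` the one for tile 80605, matching the `tileid < 20000` and `tileid > 80000` split used in the cell below.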

In [5]:
# Use the cumulative tile reductions from the everest release, always grabbing the most recent spectra.
redux = 'everest/tiles/cumulative'

tiles = sorted(glob('{}/{}/*'.format(os.environ['DESI_SPECTRO_REDUX'], redux)))

sv1tiles = []
sv3tiles = []
for tile in tiles:
    tileid = int(os.path.basename(tile))
    d = sorted(glob('{}/*'.format(tile)))
    if tileid < 20000:
        sv3tiles.append(d[-1])
    if tileid > 80000:
        sv1tiles.append(d[-1])

In [6]:
len(sv3tiles)

793

In [None]:
nlens, nsgl = 0, 0
destdir = '/global/project/projectdirs/desi/science/gqp/stronglens/training/everest'

for i, sv3tile in enumerate(sv3tiles):
    tileid, date = [int(x) for x in sv3tile.split('/')[-2:]]
    print('\nGenerating strong lenses using TILE {}'.format(tileid), flush=True)
    
    slens = SynthLens(tileid, tilepath=sv3tile, destpath=destdir)
    nlens += slens.spec_len.num_spectra()
    nsgl  += slens.spec_sgl.num_spectra()

    print('Cumulative: {} single object spectra and {} "lens" spectra.'.format(nsgl, nlens), flush=True)
    
    if nlens > 10000:
        break


Generating strong lenses using TILE 1


### Make Plots

In [None]:
mpl.rc('figure', max_open_warning = 0)

In [None]:
from scipy.ndimage import gaussian_filter1d

# `slens` is the last SynthLens instance created in the loop above.
fgrspec = slens.spec_fgr
bkgspec = slens.spec_bkg
simspec = slens.spec_len

for j in range(fgrspec.num_spectra()):
    
    fig, axes = plt.subplots(1,3, figsize=(16,4.5), sharex=True, sharey=True, tight_layout=True)

    for b in 'brz':
        ax = axes[0]
        smoothed = gaussian_filter1d(fgrspec.flux[b][j], 5)
        ax.plot(fgrspec.wave[b], smoothed, lw=1, alpha=0.8, label=b)
        smoothed = gaussian_filter1d(fgrspec.extra[b]['model'][j], 5)
        ax.plot(fgrspec.wave[b], smoothed, lw=1, color='k', ls='--')
        ax.set(title='Lens: $z={:.3f}$ ({})'.format(fgrspec.extra_catalog[j]['Z'], fgrspec.extra_catalog[j]['SPECTYPE']),
               xlabel=r'$\lambda_\mathrm{obs}$ [$\AA$]',
               ylabel=r'flux [erg s$^{-1}$ cm$^{-2}$ $\AA^{-1}$]')
        ax.grid(ls=':')

        ax = axes[1]
        smoothed = gaussian_filter1d(bkgspec.flux[b][j], 5)
        ax.plot(bkgspec.wave[b], smoothed, lw=1, alpha=0.8)
        smoothed = gaussian_filter1d(bkgspec.extra[b]['model'][j], 5)
        ax.plot(bkgspec.wave[b], smoothed, lw=1, color='k', ls='--')
        ax.set(title='Bkg: $z={:.3f}$ ({})'.format(bkgspec.extra_catalog[j]['Z'], bkgspec.extra_catalog[j]['SPECTYPE']),
               xlabel=r'$\lambda_\mathrm{obs}$ [$\AA$]')
        ax.grid(ls=':')

        ax = axes[2]
        smoothed = gaussian_filter1d(simspec.flux[b][j], 5)
        ax.plot(simspec.wave[b], smoothed, lw=1, alpha=0.8)
        alpha = simspec.extra_catalog['ALPHA'][j]
        ax.set(title=r'Combined: ${:.2f}\mathcal{{M}}_1 + {:.2f}\mathcal{{M}}_2$'.format(alpha, 1-alpha),
               xlabel=r'$\lambda_\mathrm{obs}$ [$\AA$]')
        ax.grid(ls=':')
        
    ax = axes[0]
    ax.legend(fontsize=12, ncol=1, loc='best')
        
#     fig.savefig('figures/tile{:06d}_synthlens{:03d}.png'.format(slens.tileid, j), dpi=120)
    if j >= 3:
        break