# IN PROGRESS

# Example notebook 2: simulating strong lenses in the HLWAS

This notebook walks through simulating the population of detectable galaxy-galaxy strong lenses in the Nancy Grace Roman Space Telescope's High Latitude Wide Area Survey (HLWAS).

In [48]:
import os
import sys
from glob import glob
from pprint import pprint

import numpy as np
import matplotlib.pyplot as plt
from hydra import initialize, compose
import astropy.cosmology as astropy_cosmo
from astropy.cosmology import default_cosmology
from astropy.units import Quantity
from pyHalo.preset_models import CDM
import speclite.filters
from slsim.lens_pop import LensPop
from tqdm import tqdm

# set paths to various directories based on the machine this code is being executed on
with initialize(version_base=None, config_path='config'):
    config = compose(config_name='config.yaml')

# enable use of local modules: the mejiro imports below resolve against repo_dir
repo_dir = config.machine.repo_dir
if repo_dir not in sys.path:
    sys.path.append(repo_dir)

import mejiro
from mejiro.helpers import survey_sim
from mejiro.utils import util

# set matplotlib style
plt.style.use(f'{repo_dir}/mejiro/mplstyle/science.mplstyle')

# notebook-wide settings read by the cells below
debugging = True  # print progress messages and show tqdm progress bars
run = 0  # run index; zero-padded into output file names and progress messages

# TODO(review): later cells also read `pipeline_params`, `survey_params`, `output_dir` and
# `lens_output_dir`, none of which are defined in this notebook yet. Define them here
# (presumably from `config`) before running the notebook end to end.

In [None]:
# Load the averaged Roman filter response curves with speclite (presumably so later
# magnitude calculations can refer to the Roman bands by name).
filter_dir = os.path.join(repo_dir, 'mejiro', 'data', 'avg_filter_responses')
roman_filters = sorted(glob(os.path.join(filter_dir, 'Roman-*.ecsv')))
if not roman_filters:
    # fail loudly instead of loading an empty filter set
    raise FileNotFoundError(f'No Roman filter response files (Roman-*.ecsv) found in {filter_dir}')

# only the first 8 files (in sorted order) are loaded; report exactly those
loaded_filters = roman_filters[:8]
_ = speclite.filters.load_filters(*loaded_filters)
if debugging:
    print('Configured Roman filters. Loaded:')
    pprint(loaded_filters)

In [None]:
# locate the SkyPy configuration file shipped inside the mejiro package
module_path = os.path.dirname(mejiro.__file__)
skypy_config = os.path.join(module_path, 'data', 'roman_hlwas.yml')
if not os.path.isfile(skypy_config):
    raise FileNotFoundError(f'SkyPy configuration file not found: {skypy_config}')
if debugging:
    print(f'Loaded SkyPy configuration file {skypy_config}')

In [None]:
# set HLWAS parameters
config_file = util.load_skypy_config(skypy_config)  # read skypy config file to get survey area
# NOTE(review): assumes `fsky` is a string such as '0.5 deg2'; the last 5 characters
# (' deg2') are stripped before converting to float. Confirm against roman_hlwas.yml.
survey_area = float(config_file['fsky'][:-5])
sky_area = Quantity(value=survey_area, unit='deg2')
cosmo = default_cosmology.get()
bands = pipeline_params['bands']

In [None]:
# define cuts on the intrinsic deflector and source populations (in addition to the skypy config file)
# each cut is a magnitude limit in one band plus a redshift window, read from survey_params
kwargs_deflector_cut = {
    'band': survey_params['deflector_cut_band'],
    'band_max': survey_params['deflector_cut_band_max'],
    'z_min': survey_params['deflector_z_min'],
    'z_max': survey_params['deflector_z_max'],
}
kwargs_source_cut = {
    'band': survey_params['source_cut_band'],
    'band_max': survey_params['source_cut_band_max'],
    'z_min': survey_params['source_z_min'],
    'z_max': survey_params['source_z_max'],
}

In [None]:
# create the lens population from the SkyPy galaxy catalog, restricted by the cuts above
if debugging: print('Defining galaxy population...')
lens_pop = LensPop(deflector_type="all-galaxies",
                   source_type="galaxies",
                   kwargs_deflector_cut=kwargs_deflector_cut,
                   kwargs_source_cut=kwargs_source_cut,
                   kwargs_mass2light=None,
                   skypy_config=skypy_config,
                   sky_area=sky_area,
                   cosmo=cosmo)
if debugging: print('Defined galaxy population')

In [None]:
# Optional: report the intrinsic deflector and source counts, scaled from the simulated
# `survey_area` to the full HLWAS area. Before uncommenting, define `area_hlwas`
# (the HLWAS area in sq deg); it is not defined anywhere in this notebook.
# num_lenses = lens_pop.deflector_number()
# num_sources = lens_pop.source_number()
# print(f'Number of deflectors: {num_lenses}, scaled to HLWAS ({area_hlwas} sq deg): {int((area_hlwas / survey_area) * num_lenses)}')
# print(f'Number of sources: {num_sources}, scaled to HLWAS ({area_hlwas} sq deg): {int((area_hlwas / survey_area) * num_sources)}')

In [None]:
# draw the total lens population
if debugging: print('Identifying lenses...')
# deliberately loose cuts: detectability is decided by the criteria in the next cell
kwargs_lens_total_cut = {
    'min_image_separation': 0,
    'max_image_separation': 10,
    'mag_arc_limit': None,
}
total_lens_population = lens_pop.draw_population(kwargs_lens_cuts=kwargs_lens_total_cut)
if debugging: print(f'Number of total lenses: {len(total_lens_population)}')

In [None]:
# apply additional detectability criteria
limit = None  # set to an int to stop once that many detectable lenses are found
num_samples = 16  # arbitrary: how many rejected candidates to keep per criterion for inspection
filter_1, filter_2 = 0, 0  # number of candidates rejected by criterion 1 / criterion 2
filtered_sample = {'filter_1': [], 'filter_2': []}  # first `num_samples` rejects per criterion

# `snr_list` holds the SNR of each lens in `detectable_gglenses`, in the same order, so the
# two lists can be zipped together later; SNRs of candidates failing the SNR cut are kept
# separately in `rejected_snr_list`.
detectable_gglenses, snr_list, rejected_snr_list = [], [], []
for candidate in tqdm(total_lens_population, disable=not debugging):
    # 1. Einstein radius and Sersic radius
    _, kwargs_params = candidate.lenstronomy_kwargs(band=survey_params['large_lens_band'])
    lens_mag = candidate.deflector_magnitude(band=survey_params['large_lens_band'])

    # reject bright deflectors (magnitude below the cap) whose Einstein radius is smaller
    # than the deflector light's Sersic radius
    theta_e = kwargs_params['kwargs_lens'][0]['theta_E']
    r_sersic = kwargs_params['kwargs_lens_light'][0]['R_sersic']
    if theta_e < r_sersic and lens_mag < survey_params['large_lens_mag_max']:
        filter_1 += 1
        if filter_1 <= num_samples:
            filtered_sample['filter_1'].append(candidate)
        continue

    # 2. SNR
    snr, _ = survey_sim.get_snr(candidate, survey_params['snr_band'],
                                mask_mult=survey_params['snr_mask_multiplier'])

    if snr < survey_params['snr_threshold']:
        rejected_snr_list.append(snr)
        filter_2 += 1
        if filter_2 <= num_samples:
            filtered_sample['filter_2'].append(candidate)
        continue

    # if both criteria satisfied, consider detectable and record its SNR alongside it
    detectable_gglenses.append(candidate)
    snr_list.append(snr)

    # if a limit was imposed above, stop once it is reached
    if limit is not None and len(detectable_gglenses) == limit:
        break

if debugging: print(f'Run {str(run).zfill(2)}: {len(detectable_gglenses)} detectable lens(es)')

In [None]:
# Collect the lenstronomy parameters, magnitudes and SNR of every detectable lens, pickle one
# dict per lens into lens_output_dir, and write the detectable population to a CSV in output_dir.
if debugging: print('Retrieving lenstronomy parameters...')

# each detectable lens is paired with its SNR by position, so the two lists must line up;
# fail loudly rather than silently attaching the wrong SNR to a lens
if len(snr_list) != len(detectable_gglenses):
    raise ValueError(
        f'snr_list has {len(snr_list)} entries but there are {len(detectable_gglenses)} '
        'detectable lenses; each detectable lens needs exactly one SNR')

dict_list = []
for gglens, snr in tqdm(zip(detectable_gglenses, snr_list), disable=not debugging, total=len(detectable_gglenses)):

    # get lens params from gglens object
    # NOTE(review): parameters are always taken in F106, independent of `bands`
    kwargs_model, kwargs_params = gglens.lenstronomy_kwargs(band='F106')

    # build dicts of lens and source magnitudes, keyed by band, for every band in `bands`
    lens_mags, source_mags = {}, {}
    for band in bands:
        lens_mags[band] = gglens.deflector_magnitude(band)
        source_mags[band] = gglens.extended_source_magnitude(band)

    z_lens, z_source = gglens.deflector_redshift, gglens.source_redshift
    kwargs_lens = kwargs_params['kwargs_lens']

    # add additional necessary key/value pairs to kwargs_model
    kwargs_model['lens_redshift_list'] = [z_lens] * len(kwargs_lens)  # one redshift per lens model component
    kwargs_model['source_redshift_list'] = [z_source]
    kwargs_model['cosmo'] = cosmo
    kwargs_model['z_source'] = z_source
    kwargs_model['z_source_convention'] = 5  # NOTE(review): fixed value passed to lenstronomy; confirm intent

    # create dict to pickle
    gglens_dict = {
        'kwargs_model': kwargs_model,
        'kwargs_params': kwargs_params,
        'lens_mags': lens_mags,
        'source_mags': source_mags,
        'deflector_stellar_mass': gglens.deflector_stellar_mass(),
        'deflector_velocity_dispersion': gglens.deflector_velocity_dispersion(),
        'snr': snr
    }

    dict_list.append(gglens_dict)

if debugging: print('Pickling lenses...')
for i, each in tqdm(enumerate(dict_list), disable=not debugging, total=len(dict_list)):
    save_path = os.path.join(lens_output_dir, f'detectable_lens_{str(run).zfill(2)}_{str(i).zfill(5)}.pkl')
    util.pickle(save_path, each)

detectable_pop_csv = os.path.join(output_dir, f'detectable_pop_{str(run).zfill(2)}.csv')
survey_sim.write_lens_pop_to_csv(detectable_pop_csv, detectable_gglenses, bands)