In [None]:
import io
from contextlib import redirect_stdout
from pathlib import Path

# Removed unnecessary coverage installation
import astropy.coordinates as coord
import astropy.units as u
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from astropy.cosmology import FlatLambdaCDM
from astropy.units import Quantity

import slsim.Sources as sources
import slsim.Deflectors as deflectors
import slsim.Pipelines as pipelines
from slsim.Sources.SourceCatalogues.QuasarCatalog.quasar_pop import QuasarRate
from slsim.Lenses.lens_pop import LensPop
from slsim.ImageSimulation.image_simulation import (
    point_source_coordinate_properties,
    lens_image_series,
    rgb_image_from_image_list,
)
from slsim.Plots.plot_functions import create_image_montage_from_image_list
from slsim.Util.all_plotting_functions import make_contour

%load_ext autoreload
%autoreload 2

## Lensed quasar population

This notebook shows how to simulate a population of lensed quasars with intrinsic variability.

In [None]:
# Cosmology shared by every population built below.
cosmo = FlatLambdaCDM(H0=70, Om0=0.3)

# Sky areas: galaxies (deflectors) and quasars (sources) are sampled over small
# patches, while full_sky_area is the area handed to LensPop.
galaxy_sky_area = Quantity(value=5, unit="deg2")  # galaxies are sampled over this area
quasar_sky_area = Quantity(value=5, unit="deg2")  # quasars are sampled over this area
full_sky_area = Quantity(value=500, unit="deg2")  # area of the lens population

# Cuts on the intrinsic deflector and source populations, applied in addition to
# those already imposed by the SkyPy config file.
kwargs_deflector_cut = dict(band="i", band_max=28, z_min=0.01, z_max=2.5)
kwargs_source_cut = dict(band="i", band_max=26, z_min=0.001, z_max=5.0)

In [None]:
# Generate the galaxy population with the SkyPy pipeline; skypy_config=None
# presumably selects the package's built-in configuration.
survey_filters = list("ugrizy")
galaxy_simulation_pipeline = pipelines.SkyPyPipeline(
    skypy_config=None,
    sky_area=galaxy_sky_area,
    filters=survey_filters,
    cosmo=cosmo,
)

In [None]:
# Initiate the deflector population from the red and blue SkyPy galaxies.
# The power-law mass slope gamma_pl is parametrized by mean 2 and std_dev 0.16.
gamma_pl_distribution = {"mean": 2, "std_dev": 0.16}
lens_galaxies = deflectors.AllLensGalaxies(
    red_galaxy_list=galaxy_simulation_pipeline.red_galaxies,
    blue_galaxy_list=galaxy_simulation_pipeline.blue_galaxies,
    kwargs_cut=kwargs_deflector_cut,
    kwargs_mass2light={},
    cosmo=cosmo,
    gamma_pl=gamma_pl_distribution,
    sky_area=galaxy_sky_area,
)

In [None]:
# Inspect the deflector table built by AllLensGalaxies.
deflector_table = lens_galaxies.deflector_table
deflector_table

In [None]:
# Redshift grid on which QuasarRate is evaluated; chosen to match the general
# slsim redshift range of the SkyPy pipeline.
quasar_redshift_grid = np.linspace(0.001, 5.01, 100)

# Initiate the QuasarRate class used to generate the quasar sample.
quasar_class = QuasarRate(
    cosmo=cosmo,
    sky_area=quasar_sky_area,
    noise=True,
    redshifts=quasar_redshift_grid,
)

# Quasar sample (magnitude range m_min=15 .. m_max=28) including host galaxies.
quasar_source_plus_galaxy = quasar_class.quasar_sample(
    host_galaxy=True, m_min=15, m_max=28
)

### The driving variability parameters need to be sampled from the distribution below

In [None]:
# Empirical distribution of AGN variability parameters in the OM10 x cosmoDC2 sample.
# Read in the DC2 source file.
sources_dc2 = pd.read_csv("../../data/OM10/sources3.csv")

# log10 black-hole mass from the bulge stellar mass:
#   M_BH = 0.0049 * M_bulge * (M_bulge / 1e11) ** 0.15
bulge_mass = sources_dc2["stellar_mass_bulge"]
sources_dc2["bh_mass"] = np.log10(0.0049 * bulge_mass * (bulge_mass / 1e11) ** 0.15)
sources_dc2["log_tau_i"] = np.log10(sources_dc2["tau_i"])
sources_dc2["log_sf_i"] = np.log10(sources_dc2["sf_i"])

# Raw strings keep the LaTeX backslashes (\log, \odot, \tau) from being parsed as
# (invalid) Python escape sequences; the rendered labels are unchanged.
fig = make_contour(
    [sources_dc2[["bh_mass", "M_i", "log_sf_i", "log_tau_i", "ZSRC"]]],
    labels=[
        r"$\log(M_{BH}/M_\odot)$",
        r"$M_i$",
        r"$ \log(SF_i / mag)$",
        r"$\log(\tau_i/days)$",
        r"$z_{src}$",
    ],
    categories=["om10-cosmoDC2 sample"],
    colors=["blue"],
    show_correlation=True,
)
# fig.savefig('../../../bh_var_corr.pdf', dpi=300)

In [None]:
# Parameters of the bending-power-law driving signal for the AGN variability model.
variable_agn_kwarg_dict = dict(
    length_of_light_curve=500,
    time_resolution=1,
    log_breakpoint_frequency=1 / 20,
    low_frequency_slope=1,
    high_frequency_slope=3,
    standard_deviation=0.9,
)

# Point-source (quasar) settings: light-curve variability driven by the model above,
# sampled on a 500-point grid from 0 to 1000.
lightcurve_time = np.linspace(0, 1000, 500)
kwargs_quasar = dict(
    variability_model="light_curve",
    # A *set* of names; presumably the light-curve key plus the bands to model.
    # NOTE(review): confirm against the slsim point-source documentation.
    kwargs_variability={"agn_lightcurve", "u", "g", "r", "i", "z", "y"},
    agn_driving_variability_model="bending_power_law",
    agn_driving_kwargs_variability=variable_agn_kwarg_dict,
    lightcurve_time=lightcurve_time,
)

# Source population: quasar point sources plus single-Sersic host galaxies.
source_quasar_plus_galaxies = sources.PointPlusExtendedSources(
    point_plus_extended_sources_list=quasar_source_plus_galaxy,
    cosmo=cosmo,
    sky_area=quasar_sky_area,
    kwargs_cut=kwargs_source_cut,
    list_type="astropy_table",
    catalog_type="skypy",
    point_source_type="quasar",
    extended_source_type="single_sersic",
    point_source_kwargs=kwargs_quasar,
)

In [None]:
# Combine deflectors and quasar sources into a lensed-quasar population over the
# full sky area.
lenspop_kwargs = dict(
    deflector_population=lens_galaxies,
    source_population=source_quasar_plus_galaxies,
    cosmo=cosmo,
    sky_area=full_sky_area,
)
quasar_lens_pop = LensPop(**lenspop_kwargs)

### Check the variability model of the sources

## Draw lenses

In [None]:
# LSST magnitude cuts for reference:
#   u 23.9, g 25.0, r 24.7, i 24.0, z 23.3, y 22.1
kwargs_lens_cuts = {
    "min_image_separation": 0.7,
    "max_image_separation": 10,
    # magnitude cut applied to the second-brightest image (g band)
    "second_brightest_image_cut": {"g": 25.0},
}

# Draw the lens population.
# NOTE(review): drawing time may depend mostly on whether magnitude cuts are
# requested — unverified.
quasar_lens_population = quasar_lens_pop.draw_population(
    kwargs_lens_cuts=kwargs_lens_cuts,
    speed_factor=1000,
)

### View association catalog

In [None]:
# Collect every lens of the drawn population into a single DataFrame. Anything the
# conversion writes to stdout is captured in a buffer so it does not clutter the output.
stdout_buffer = io.StringIO()
full_pop_df = pd.DataFrame()
with redirect_stdout(stdout_buffer):
    for lens_index, lens in enumerate(quasar_lens_population):
        full_pop_df = lens.lens_to_dataframe(index=lens_index, df=full_pop_df)

In [None]:
# DataFrame summary of the first lens in the population.
first_lens = quasar_lens_population[0]
first_lens.lens_to_dataframe()

In [None]:
# Evaluate kappa_star (presumably the stellar convergence) at the position of the
# first lensed image of the first lens.
# NOTE(review): assumes point_source_image_positions() returns one (x_array, y_array)
# pair per source, consistent with the [0][i] indexing of point_source_magnitude
# used later in this notebook. Indexing [0][0] and unpacking it would take the
# x-coordinates of two different images instead of one (ra, dec) position.
first_lens = quasar_lens_population[0]
image_ra, image_dec = first_lens.point_source_image_positions()[0]
ra, dec = image_ra[0], image_dec[0]
first_lens.kappa_star(ra, dec)

In [None]:
# Distribution of deflector Einstein radii across the drawn lens population
# (one row of full_pop_df per lens).
fig, ax = plt.subplots()
ax.hist(full_pop_df["deflector_mass_theta_E"])
ax.set_xlabel(r"Deflector Einstein radius $\theta_E$")
ax.set_ylabel("Number of lenses")
ax.set_title("Einstein radii of the lensed-quasar population");

## Select a lens to visualize

In [None]:
# Pick one lens at random, requiring an image separation between 2 and 10.
selection_cuts = dict(min_image_separation=2, max_image_separation=10)
lens_class = quasar_lens_pop.select_lens_at_random(**selection_cuts)

## Set up the observations (sky positions and magnitude zero points)

In [None]:
# Draw N random sky positions, uniform in area:
#   RA uniform in [0, 360) deg, wrapped to [-180, 180) deg,
#   Dec in [-72, +12] deg with sin(Dec) uniform (equal-area on the sphere).
# A seeded Generator makes the draw reproducible on Restart & Run All.
N = 10
rng = np.random.default_rng(seed=42)
ra_points = coord.Angle(rng.uniform(low=0, high=360, size=N) * u.degree)
ra_points = ra_points.wrap_at(180 * u.degree)

dec_min, dec_max = -72 * u.deg, 12 * u.deg
sin_dec = rng.uniform(
    low=np.sin(dec_min).value, high=np.sin(dec_max).value, size=N
)
dec_points = coord.Angle(np.degrees(np.arcsin(sin_dec)) * u.deg)
dec_points

In [None]:
def compute_magnitude_zeropoint(mag_zp_1s, exposure_time=30, gain=1):
    """Convert a 1-second magnitude zero point into the zero point of an exposure.

    Adds ``2.5 * log10(exposure_time / gain)`` to ``mag_zp_1s``.

    :param mag_zp_1s: zero point(s) for a 1 s exposure [mag]; scalar or array
    :param exposure_time: exposure time [s]
    :param gain: gain factor that divides the exposure time
    :return: zero point(s) for the given exposure [mag]
    """
    exposure_scale = exposure_time / gain
    return mag_zp_1s + 2.5 * np.log10(exposure_scale)


# LSST ugrizy zero points for a 1 s exposure (taken from https://smtn-002.lsst.io/).
bands = list("ugrizy")
mag_zps = np.array([26.52, 28.51, 28.36, 28.17, 27.78, 26.82])  # mag
mag_zero_points_1_second = dict(zip(bands, mag_zps))  # mag

# Image geometry and exposure. These constants drive everything below, so they are
# defined first and reused rather than repeated as literals.
delta_pix = 0.2  # arcsec/pixel
num_pix = 33  # pixels
exp_time = 30  # s
mag_zero_points_30_seconds = dict(
    zip(bands, compute_magnitude_zeropoint(mag_zps, exposure_time=exp_time))
)  # mag

# Pixel coordinates of the point-source images of the selected lens, one entry per
# band g..y (u is skipped), so the image centres can be overplotted on cutouts.
pix2angle_matrix = np.array([[delta_pix, 0], [0, delta_pix]])
pix_coord_list = [
    point_source_coordinate_properties(
        lens_class,
        band=band,
        mag_zero_point=mag_zero_points_30_seconds[band],
        delta_pix=delta_pix,
        num_pix=num_pix,
        transform_pix2angle=pix2angle_matrix,
    )
    for band in bands[1:]
]

## See the light curves of the selected lensed quasar

In [None]:
# Time grid in days. NOTE: `time` is reassigned in the light-curve cell below
# before it is ever used.
time = np.linspace(start=-500, stop=1500, num=500)

In [None]:
# Number of lensed images of the selected lens (may be a list, one entry per source).
n_images = lens_class.image_number
n_images

In [None]:
# Light curves of every lensed image of the selected quasar, followed by an i-band
# image series of the same lens at a handful of sampled epochs.
pix_coord = pix_coord_list[0]["image_pix"]  # image positions in pixels (num_pix grid)
time = np.linspace(0, 365, 365)  # dense grid for the light curves [days]
# Five observed epochs: every third point of a 15-point grid over one year.
time_sampled = np.linspace(0, 365, 15)[np.arange(0, 15, 3)]

image_number = lens_class.image_number
if isinstance(image_number, list):
    # one image count per source plane; take the first source
    image_number = image_number[0]

# squeeze=False keeps `ax` an array even for a single image, so flatten() always works.
fig, ax = plt.subplots(1, image_number, figsize=(30, 7), squeeze=False)
ax = ax.flatten()
for band in "grizy":
    # Evaluate each band once, then plot it for every image.
    mags_dense = lens_class.point_source_magnitude(band=band, lensed=True, time=time)[0]
    mags_sampled = lens_class.point_source_magnitude(
        band=band, lensed=True, time=time_sampled
    )[0]
    for i in range(image_number):
        (line,) = ax[i].plot(time, mags_dense[i], label=f"{band}-band, image-{i+1}")
        # Sampled epochs use the colour of their band's light curve.
        ax[i].scatter(
            time_sampled, mags_sampled[i], marker="*", s=100, color=line.get_color()
        )

ax[0].set_ylabel("Magnitude")
fig.supxlabel("Time [Days]")
for a in ax:
    a.invert_yaxis()  # magnitudes: brighter is up
    a.legend()
fig.suptitle(f"Light curves for {image_number} images of a multiply-imaged quasar")
fig.tight_layout()

# --- i-band image series at the sampled epochs ---
# Every per-exposure list below needs exactly one entry per epoch in t_obs.
repeats = len(time_sampled)

# Load the PSF kernel (negative pixels clipped to zero). If you have your own PSF,
# provide it here. Both relative locations used in this notebook are tried.
psf_path_candidates = [
    Path("../../tests/TestData/psf_kernels_for_deflector.npy"),
    Path("../tests/TestData/psf_kernels_for_deflector.npy"),
]
path = next((p for p in psf_path_candidates if p.exists()), psf_path_candidates[0])
psf_kernel = 1 * np.load(path)
psf_kernel[psf_kernel < 0] = 0
transform_matrix = np.array([[delta_pix, 0], [0, delta_pix]])

# The same PSF, pixel-to-angle matrix, zero point and exposure time are reused for
# every exposure; replace them with per-exposure values if available.
psf_kernels_all = [psf_kernel] * repeats
transform_matrix_all = [transform_matrix] * repeats
mag_zero_points_all = [mag_zero_points_30_seconds["i"]] * repeats
exposure_time_all = [exp_time] * repeats

# NOTE(review): psf_kernel=None is passed for every exposure, so psf_kernels_all is
# not used by this series (presumably no PSF convolution) — confirm this is intended.
image_lens_series_all = lens_image_series(
    lens_class=lens_class,
    band="i",
    mag_zero_point=mag_zero_points_all,
    num_pix=num_pix,
    psf_kernel=[None] * repeats,
    transform_pix2angle=transform_matrix_all,
    exposure_time=exposure_time_all,
    std_gaussian_noise=None,
    t_obs=time_sampled,  # must match the per-exposure lists and the montage labels
    with_deflector=True,
    with_ps=True,
    with_source=True,
    add_noise=False,
)
plot_montage = create_image_montage_from_image_list(
    num_rows=1,
    num_cols=repeats,
    images=image_lens_series_all,
    time=time_sampled,
    image_center=pix_coord,
)
plot_montage.suptitle("All Light Included", y=1.01)

In [None]:
# Compare the i, r and g light curves of the first two lensed images side by side.
# Assumes the selected lens has at least two images.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
for band in ("i", "r", "g"):
    # One evaluation per band, shared by both panels.
    magnitudes = lens_class.point_source_magnitude(band=band, lensed=True, time=time)[0]
    for image_index, axis in enumerate(axes):
        axis.plot(
            time,
            magnitudes[image_index],
            label=f"{band}-band, image-{image_index + 1}",
        )

for axis in axes:
    axis.set_ylabel("Magnitude")
    axis.set_xlabel("Time [Days]")
    axis.invert_yaxis()  # magnitudes: brighter is up
    axis.legend()

## Set observation time and image configuration

In [None]:
# Observation epochs and per-exposure configuration for the RGB image series.
time = np.linspace(0, 250, 10)  # 10 epochs [days]
repeats = len(time)

# Load the PSF kernel (negative pixels clipped to zero). If you have your own PSF,
# provide it here. Both relative locations used in this notebook are tried, so the
# cell does not depend on which one matches the current working directory.
psf_path_candidates = [
    Path("../tests/TestData/psf_kernels_for_deflector.npy"),
    Path("../../tests/TestData/psf_kernels_for_deflector.npy"),
]
path = next((p for p in psf_path_candidates if p.exists()), psf_path_candidates[0])
psf_kernel = 1 * np.load(path)
psf_kernel[psf_kernel < 0] = 0
transform_matrix = np.array([[delta_pix, 0], [0, delta_pix]])

# One entry per exposure. The same PSF, pixel-to-angle matrix, zero point and
# exposure time are reused for every epoch; supply per-exposure values if available.
psf_kernels_all = [psf_kernel] * repeats
transform_matrix_all = [transform_matrix] * repeats
# NOTE(review): 31.0 differs from the LSST 30 s i-band zero point computed earlier.
mag_zero_points_all = [31.0] * repeats  # mag
exposure_time_all = [exp_time] * repeats  # s

## Simulate Image

In [None]:
# Simulate the lens image series (deflector and source light included) in the
# i, g and r bands at every epoch in `time`, one series per band.
num_pix_rgb = 64  # pixels per side of each simulated image
image_lens_series_by_band = {
    band: lens_image_series(
        lens_class=lens_class,
        band=band,
        mag_zero_point=mag_zero_points_all,
        num_pix=num_pix_rgb,
        psf_kernel=psf_kernels_all,
        transform_pix2angle=transform_matrix_all,
        exposure_time=exposure_time_all,
        t_obs=time,
        with_deflector=True,
        with_source=True,
    )
    for band in ("i", "g", "r")  # same call order as before
}
image_lens_series_i = image_lens_series_by_band["i"]
image_lens_series_g = image_lens_series_by_band["g"]
image_lens_series_r = image_lens_series_by_band["r"]

In [None]:
# Combine the i, r and g images of each epoch (in that order) into one RGB image.
rgb_image_list = [
    rgb_image_from_image_list(
        image_list=[
            image_lens_series_i[epoch],
            image_lens_series_r[epoch],
            image_lens_series_g[epoch],
        ],
        stretch=0.5,
    )
    for epoch in range(len(image_lens_series_i))
]

## Visualize simulated images

In [None]:
# `pix_coord` was computed for the smaller num_pix grid. The RGB images are larger,
# so the image-centre markers are recomputed for the grid actually used here.
rgb_num_pix = image_lens_series_i[0].shape[0]
rgb_transform = transform_matrix_all[0]
rgb_pix_coord = point_source_coordinate_properties(
    lens_class,
    band="i",
    mag_zero_point=mag_zero_points_all[0],
    delta_pix=abs(rgb_transform[0][0]),
    num_pix=rgb_num_pix,
    transform_pix2angle=rgb_transform,
)["image_pix"]
plot_montage = create_image_montage_from_image_list(
    num_rows=2, num_cols=5, images=rgb_image_list, time=time, image_center=rgb_pix_coord
)

### Graveyard

In [None]:
# Graveyard: commented-out i-band image series that isolate single light components:
#   *_lens_only -> deflector light only
#   *_ps_only   -> point-source (quasar) light only
#   *_src_only  -> extended-source light only
# NOTE(review): these pass t_obs=time while the montages in the next cell use
# time_sampled; if re-enabled, make sure t_obs, the per-exposure lists and the
# montage times all have the same length.
# image_lens_series_lens_only = lens_image_series(
#     lens_class=lens_class,
#     band="i",
#     mag_zero_point=mag_zero_points_all,
#     num_pix=num_pix,
#     psf_kernel=[None] * repeats,
#     transform_pix2angle=transform_matrix_all,
#     exposure_time=exposure_time_all,
#     std_gaussian_noise=None,
#     t_obs=time,
#     with_deflector=True,
#     with_ps=False,
#     with_source=False,
#     add_noise=False,
# )
# image_lens_series_ps_only = lens_image_series(
#     lens_class=lens_class,
#     band="i",
#     mag_zero_point=mag_zero_points_all,
#     num_pix=num_pix,
#     psf_kernel=[None] * repeats,
#     transform_pix2angle=transform_matrix_all,
#     exposure_time=exposure_time_all,
#     std_gaussian_noise=None,
#     t_obs=time,
#     with_deflector=False,
#     with_ps=True,
#     with_source=False,
#     add_noise=False,
# )
# image_lens_series_src_only = lens_image_series(
#     lens_class=lens_class,
#     band="i",
#     mag_zero_point=mag_zero_points_all,
#     num_pix=num_pix,
#     psf_kernel=[None] * repeats,
#     transform_pix2angle=transform_matrix_all,
#     exposure_time=exposure_time_all,
#     std_gaussian_noise=None,
#     t_obs=time,
#     with_deflector=False,
#     with_ps=False,
#     with_source=True,
#     add_noise=False,
# )

In [None]:
# Graveyard: commented-out montages of the single-component series from the previous
# cell. Title mapping: "Only Lens Light" -> *_lens_only, "Lens Light Subtracted" ->
# *_ps_only (point source only), "Lens and AGN Light Subtracted" -> *_src_only.
# plot_montage = create_image_montage_from_image_list(
#     num_rows=1,
#     num_cols=5,
#     images=image_lens_series_lens_only,
#     time=time_sampled,
#     image_center=pix_coord,
# )
# plot_montage.suptitle("Only Lens Light", y=1.01)

# plot_montage = create_image_montage_from_image_list(
#     num_rows=1,
#     num_cols=5,
#     images=image_lens_series_ps_only,
#     time=time_sampled,
#     image_center=pix_coord,
# )
# plot_montage.suptitle("Lens Light Subtracted", y=1.01)

# plot_montage = create_image_montage_from_image_list(
#     num_rows=1,
#     num_cols=5,
#     images=image_lens_series_src_only,
#     time=time_sampled,
#     image_center=pix_coord,
# )
# plot_montage.suptitle("Lens and AGN Light Subtracted", y=1.01)