# Lensing playground

This is a fun notebook in which you can change lensing parameters interactively. It is a great way to build intuition for gravitational lensing and to make some cool pictures!

In [None]:
%load_ext autoreload
%autoreload 2

import torch
from torch.nn.functional import avg_pool2d
import matplotlib.pyplot as plt
from ipywidgets import interact
from astropy.io import fits
import numpy as np

import caustics

In [None]:
# Output images are n_pix x n_pix pixels of size `res`. All maps are evaluated
# on a grid that is `upsample_factor` times finer and then average-pooled back
# down to n_pix x n_pix in the plotting functions below.
n_pix, res, upsample_factor = 100, 0.05, 2
fov = n_pix * res  # full width of the field of view, in the same units as res

# Lens (z_l) and source (z_s) redshifts, kept in float32 like the grid.
z_l, z_s = (torch.tensor(z, dtype=torch.float32) for z in (0.5, 1.5))

# Fine-resolution coordinate grid: pixel size res / upsample_factor,
# (n_pix * upsample_factor) pixels along each axis.
thx, thy = caustics.utils.get_meshgrid(
    res / upsample_factor,
    n_pix * upsample_factor,
    n_pix * upsample_factor,
    dtype=torch.float32,
)

cosmology = caustics.FlatLambdaCDM(name="cosmo")
# Left as the cell's final expression so the notebook still displays it.
cosmology.to(dtype=torch.float32)

In [None]:
# SIE lens model: convergence, potential, time delay, magnification and
# deflection-angle maps (caustic curves are not drawn here).


def plot_lens_metrics(thx0, thy0, q, phi, b):
    """Plot a 2x3 panel of diagnostic maps for an SIE lens.

    The lens uses the notebook-level ``cosmology`` and ``z_l``, and every map
    is evaluated for the source redshift ``z_s``. Maps are computed on the fine
    grid ``(thx, thy)`` and average-pooled by ``upsample_factor`` before
    display.

    Args:
        thx0: Lens centre x, passed to ``caustics.SIE`` as ``x0``.
        thy0: Lens centre y, passed to ``caustics.SIE`` as ``y0``.
        q: SIE axis ratio (slider range 0.01-0.99).
        phi: SIE orientation angle (slider range 0-pi).
        b: SIE ``b`` parameter (presumably the Einstein radius - confirm
            against the caustics docs).
    """
    lens = caustics.SIE(
        cosmology=cosmology,
        z_l=z_l,
        x0=thx0,
        y0=thy0,
        q=q,
        phi=phi,
        b=b,
    )

    def downsample(image):
        # avg_pool2d wants a (batch, channel, H, W) tensor: add the two leading
        # axes, pool with kernel == stride == upsample_factor, then strip them.
        return avg_pool2d(image[None, None, :, :], upsample_factor)[0, 0]

    kappa = downsample(lens.convergence(thx, thy, z_s))
    psi = downsample(lens.potential(thx, thy, z_s))
    timedelay = downsample(lens.time_delay(thx, thy, z_s))
    magnification = downsample(lens.magnification(thx, thy, z_s))
    alpha_x, alpha_y = lens.reduced_deflection_angle(thx, thy, z_s)

    # abs() is a no-op if magnification is already unsigned. It keeps log10
    # from producing NaN pixels if the lens returns a signed magnification
    # (negative parity inside the critical curve).
    panels = [
        (torch.log10(kappa), "log(convergence)"),
        (psi, "potential"),
        (timedelay, "time delay"),
        (torch.log10(magnification.abs()), "log(magnification)"),
        (downsample(alpha_x), "deflection angle x"),
        (downsample(alpha_y), "deflection angle y"),
    ]

    # The figure is created only after every map has been computed, so a
    # failing lens evaluation does not leave an empty figure behind.
    _, axarr = plt.subplots(2, 3, figsize=(9, 6))
    # axarr.flat walks the 2x3 grid row by row, matching the panel order.
    for ax, (image, title) in zip(axarr.flat, panels):
        ax.imshow(image, origin="lower")
        ax.set_title(title)
    plt.show()

In [None]:
# Lens-centre sliders span the whole field of view (+/- fov / 2, i.e. +/- 2.5
# for the default grid), so they stay in sync if n_pix or res is changed.
p = interact(
    plot_lens_metrics,
    thx0=(-fov / 2, fov / 2, 0.1),
    thy0=(-fov / 2, fov / 2, 0.1),
    q=(0.01, 0.99, 0.01),
    phi=(0.0, np.pi, np.pi / 25),
    b=(0.1, 2.0, 0.1),
)

In [None]:
# Sersic source lensed by an SIE: unlensed source, lens convergence, lensed image.
def plot_lens_distortion(
    x0_lens,
    y0_lens,
    q_lens,
    phi_lens,
    b_lens,
    x0_src,
    y0_src,
    q_src,
    phi_src,
    n_src,
    Re_src,
    Ie_src,
):
    """Show a Sersic source before and after lensing by an SIE.

    Three panels are drawn: the unlensed source brightness, the lens
    log-convergence, and the lensed source. The lensed source is the source
    brightness evaluated at the ray-traced positions of the fine grid
    ``(thx, thy)``. Every image is average-pooled by ``upsample_factor``.

    Args:
        x0_lens, y0_lens, q_lens, phi_lens, b_lens: SIE parameters
            (``x0``, ``y0``, ``q``, ``phi``, ``b``).
        x0_src, y0_src, q_src, phi_src, n_src, Re_src, Ie_src: Sersic
            parameters (``x0``, ``y0``, ``q``, ``phi``, ``n``, ``Re``, ``Ie``).
    """
    # Keyword arguments, consistent with plot_lens_metrics.
    lens = caustics.SIE(
        cosmology=cosmology,
        z_l=z_l,
        x0=x0_lens,
        y0=y0_lens,
        q=q_lens,
        phi=phi_lens,
        b=b_lens,
    )
    source = caustics.Sersic(
        x0=x0_src,
        y0=y0_src,
        q=q_src,
        phi=phi_src,
        n=n_src,
        Re=Re_src,
        Ie=Ie_src,
    )

    def downsample(image):
        # avg_pool2d wants a (batch, channel, H, W) tensor: add the two leading
        # axes, pool with kernel == stride == upsample_factor, then strip them.
        return avg_pool2d(image[None, None, :, :], upsample_factor)[0, 0]

    unlensed = downsample(source.brightness(thx, thy))
    kappa = downsample(lens.convergence(thx, thy, z_s))
    # Map each fine-grid position through the lens, then sample the source
    # there to form the lensed image.
    beta_x, beta_y = lens.raytrace(thx, thy, z_s)
    lensed = downsample(source.brightness(beta_x, beta_y))

    panels = [
        (unlensed, "Sersic source"),
        (torch.log10(kappa), "lens log(convergence)"),
        (lensed, "Sersic lensed"),
    ]
    _, axarr = plt.subplots(1, 3, figsize=(18, 6))
    for ax, (image, title) in zip(axarr, panels):
        ax.imshow(image, origin="lower")
        ax.set_title(title)
    plt.show()

In [None]:
# Lens- and source-centre sliders span the whole field of view (+/- fov / 2,
# i.e. +/- 2.5 for the default grid), so they track changes to n_pix or res.
p = interact(
    plot_lens_distortion,
    x0_lens=(-fov / 2, fov / 2, 0.1),
    y0_lens=(-fov / 2, fov / 2, 0.1),
    q_lens=(0.01, 0.99, 0.01),
    phi_lens=(0.0, np.pi, np.pi / 25),
    b_lens=(0.1, 2.0, 0.1),
    x0_src=(-fov / 2, fov / 2, 0.1),
    y0_src=(-fov / 2, fov / 2, 0.1),
    q_src=(0.01, 0.99, 0.01),
    phi_src=(0.0, np.pi, np.pi / 25),
    n_src=(0.5, 4, 0.1),
    Re_src=(0.1, 2, 0.1),
    Ie_src=(0.1, 2.0, 0.1),
)