In [1]:
import os
import matplotlib.pyplot as plt
import streamlit as st
from plotly.subplots import make_subplots
import plotly.express as px
import plotly.graph_objects as go

In [2]:
import torch
import numpy as np
import caustics
from caustics.utils import get_meshgrid
from app_configs import (
    lens_slider_configs,
    source_slider_configs,
    name_map,
    default_params,
)

from skimage import measure

In [3]:
def caustic_critical_line(lens, x, z_s, res, simulation_size, upsample_factor=1, device="cpu"):
    """Locate the critical curves of ``lens`` and ray-trace them to the caustics.

    Critical curves are the zero-level contours of det(A), where A is the
    lens-equation Jacobian evaluated on a pixel grid. Each contour is then
    ray-traced to the source plane to obtain the corresponding caustic.

    Args:
        lens: caustics lens object providing ``pack``, ``jacobian_lens_equation``
            and ``raytrace``.
        x: parameter tensor for ``lens`` (converted with ``lens.pack``).
        z_s: source redshift.
        res: pixel scale of the output image grid (angular units per pixel).
        simulation_size: number of pixels per side of the output image grid.
        upsample_factor: integer factor by which the grid used to find det(A) = 0
            is refined relative to the output grid (same field of view).
        device: torch device on which the grid and contour tensors are built.

    Returns:
        ``(x1s, x2s, y1s, y2s)``: four lists holding one numpy array per contour.
        ``x1s``/``x2s`` are the critical-curve coordinates and ``y1s``/``y2s``
        the caustic coordinates. All are expressed in fractional pixel indices of
        the ``simulation_size`` x ``simulation_size`` output grid, so they can be
        overlaid directly on an image of that size.
    """
    # Fine grid covering the same field of view as the output grid.
    n_fine = upsample_factor * simulation_size
    fine_scale = res / upsample_factor
    thx, thy = get_meshgrid(
        fine_scale,
        n_fine,
        n_fine,
        dtype=torch.float32,
        device=device,
    )
    # Pack once; the same parameters are reused for the Jacobian and every raytrace.
    params = lens.pack(x)
    A = lens.jacobian_lens_equation(thx, thy, z_s, params)

    # Critical curves are where the lens-equation Jacobian is singular.
    detA = torch.linalg.det(A)
    contours = measure.find_contours(detA.detach().cpu().numpy(), 0.0)

    # get_meshgrid places pixel centres symmetrically about the origin, so index
    # i on an n-pixel grid sits at angle (i - (n - 1) / 2) * pixelscale. The fine
    # grid's own size and scale must be used to go from contour index to angle.
    fine_center = (n_fine - 1) / 2
    out_center = (simulation_size - 1) / 2

    x1s, x2s, y1s, y2s = [], [], [], []
    for contour in contours:
        contour = torch.as_tensor(contour, device=device)
        # find_contours yields (row, column) pairs; columns map to x, rows to y.
        theta1 = (contour[:, 1] - fine_center) * fine_scale
        theta2 = (contour[:, 0] - fine_center) * fine_scale
        beta1, beta2 = lens.raytrace(theta1, theta2, z_s, params=params)
        # Convert both curves to pixel indices of the output image grid.
        x1s.append((theta1 / res + out_center).detach().cpu().numpy())
        x2s.append((theta2 / res + out_center).detach().cpu().numpy())
        y1s.append((beta1 / res + out_center).detach().cpu().numpy())
        y2s.append((beta2 / res + out_center).detach().cpu().numpy())

    return x1s, x2s, y1s, y2s

In [4]:
# Output image grid: simulation_size x simulation_size pixels spanning a field
# of view `fov` (presumably arcsec, caustics' usual angular unit; TODO confirm).
# `deltam` is the resulting pixel scale.
simulation_size, fov = 1024, 6.5
deltam = fov / simulation_size

In [16]:
# Redshifts of the lens and source planes.
z_lens, z_source = 1.0, 2.0

# Parameter vectors for the lens and the source. Their element order presumably
# follows the free parameters of the SIE / Sersic models configured via
# default_params (TODO confirm). x_all stacks them lens-first.
x_lens = torch.tensor([0.0, 0.0, 0.5, 0.0, 2.0])
x_source = torch.tensor([-0.0, -0.0, 1.0, 0.0, 0.8, 0.3])
x_all = torch.cat([x_lens, x_source])

# Single-plane lens system built from one SIE placed at z_lens.
cosmology = caustics.FlatLambdaCDM(name="cosmo")
lenses = [
    name_map["SIE"](cosmology, z_l=z_lens, **default_params["SIE"]),
]
lens = caustics.SinglePlane(cosmology=cosmology, lenses=lenses, z_l=z_lens)

In [17]:
# Simulator that renders the Sersic source through `lens` on a
# simulation_size x simulation_size grid with pixel scale deltam.
src = name_map["Sersic"](**default_params["Sersic"], name="src")
minisim = caustics.Lens_Source(
    lens=lens,
    source=src,
    z_s=z_source,
    pixelscale=deltam,
    pixels_x=simulation_size,
)

# Critical curves (x1s, x2s) and caustics (y1s, y2s) for the lens parameters,
# returned as per-contour arrays in pixel coordinates of the simulated image.
x1s, x2s, y1s, y2s = caustic_critical_line(
    lens=lens,
    x=x_lens,
    z_s=z_source,
    res=deltam,
    simulation_size=simulation_size,
)


In [18]:
# Image rendered with lens_source=False (presumably the unlensed source plane),
# with the caustics overlaid; saved as sie1_back.pdf.
fig2 = make_subplots()
res = minisim(x_all, lens_source=False).detach().cpu().numpy()
# Min-max normalise to [0, 1]. A constant image would make the range zero and
# produce NaNs, so fall back to all zeros in that case.
res_min, res_max = np.min(res), np.max(res)
res = (res - res_min) / (res_max - res_min) if res_max > res_min else np.zeros_like(res)
fig2.add_heatmap(z=res, zmin=0, zmax=1, coloraxis="coloraxis", hoverinfo="skip")

# One white line per caustic contour returned by caustic_critical_line.
for curve_x, curve_y in zip(y1s, y2s):
    fig2.add_trace(
        go.Scatter(
            x=curve_x, y=curve_y, mode="lines", line=dict(color="white"), hoverinfo="skip"
        )
    )
fig2.update_layout(
    width=400,  # Adjust as needed
    height=400,  # Adjust as needed
    coloraxis_showscale=False,
    margin=dict(l=0, r=0, t=0, b=0),
    showlegend=False,
    # cmin below the data minimum (0) maps zero-valued pixels above the darkest
    # end of the colour scale; presumably to lift the background off pure black.
    coloraxis=dict(colorscale="inferno", cmin=-0.12, cmax=1),
)
fig2.write_image("sie1_back.pdf")

In [19]:
# Image rendered with lens_source=True (presumably the lensed image plane),
# with the critical curves overlaid; saved as sie1.pdf.
fig1 = make_subplots()
res = minisim(x_all, lens_source=True).detach().cpu().numpy()
# Min-max normalise to [0, 1]. A constant image would make the range zero and
# produce NaNs, so fall back to all zeros in that case.
res_min, res_max = np.min(res), np.max(res)
res = (res - res_min) / (res_max - res_min) if res_max > res_min else np.zeros_like(res)
fig1.add_heatmap(z=res, zmin=0, zmax=1, coloraxis="coloraxis", hoverinfo="skip")

# One white line per critical-curve contour returned by caustic_critical_line.
for curve_x, curve_y in zip(x1s, x2s):
    fig1.add_trace(
        go.Scatter(
            x=curve_x, y=curve_y, mode="lines", line=dict(color="white"), hoverinfo="skip"
        )
    )
fig1.update_layout(
    width=400,  # Adjust as needed
    height=400,  # Adjust as needed
    coloraxis_showscale=False,
    margin=dict(l=0, r=0, t=0, b=0),
    showlegend=False,
    # cmin below the data minimum (0) maps zero-valued pixels above the darkest
    # end of the colour scale; presumably to lift the background off pure black.
    coloraxis=dict(colorscale="inferno", cmin=-0.12, cmax=1),
)
fig1.write_image("sie1.pdf")