In [None]:
%load_ext autoreload
%autoreload 2

from functools import partial

import torch
from torch.nn.functional import avg_pool2d
import matplotlib.pyplot as plt
from ipywidgets import interact
from astropy.io import fits
import numpy as np
from time import process_time as time

import caustics

First, let's create the lens we intend to fit — in this case a singular isothermal ellipsoid (SIE) model. The SIE parameters are not set at this point, so this is just a generic SIE model.

In [None]:
# Cosmology shared by the lens, cast to single precision to match the tensors below
cosmology = caustics.cosmology.FlatLambdaCDM(name="cosmo")
cosmology.to(dtype=torch.float32)

# Lens redshift and source redshift, as float32 scalars
z_l, z_s = (torch.tensor(z, dtype=torch.float32) for z in (0.5, 1.5))

# Generic SIE lens at z_l. Its remaining parameters are not fixed here; they are
# passed explicitly on each call. NOTE(review): `s` is presumably a small
# core/softening radius — confirm against caustics.lenses.SIE.
lens = caustics.lenses.SIE(
    name="sie",
    cosmology=cosmology,
    z_l=z_l,
    s=1e-3,
)

To create a mock dataset to fit, we choose a source position and lens parameters, then forward-raytrace to find the corresponding image-plane positions.

In [None]:
# Chosen position of the point source in the source plane
sp_x, sp_y = torch.tensor(0.2), torch.tensor(0.2)

# Ground-truth lens parameters, in the order the lens expects: x0, y0, q, phi, b
params = torch.tensor(
    [
        0.0,  # x0
        0.0,  # y0
        0.4,  # q
        np.pi / 5,  # phi
        1.0,  # b
    ]
)

# Forward-raytrace the source point to obtain its images in the image plane
x, y = lens.forward_raytrace(sp_x, sp_y, z_s, params)

Just to see what's going on, we plot the critical curve (blue), the caustic (red), and the image- and source-plane positions.

In [None]:
# Pixel grid on which the lens-equation Jacobian is evaluated
n_pix = 100
res = 0.05
upsample_factor = 1
fov = res * n_pix
thx, thy = caustics.utils.get_meshgrid(
    res / upsample_factor,
    upsample_factor * n_pix,
    upsample_factor * n_pix,
    dtype=torch.float32,
)

fig, ax = plt.subplots()

# Critical curve: the zero-level contour of det(Jacobian of the lens equation)
A = lens.jacobian_lens_equation(thx, thy, z_s, params)
detA = torch.linalg.det(A)
CS = ax.contour(thx, thy, detA, levels=[0.0], colors="b", zorder=1)

# Each entry of allsegs[0] is an (N, 2) array of vertices along one segment of
# the critical curve; ray-tracing those vertices to the source plane gives the caustic.
paths = CS.allsegs[0]
caustic_paths = []  # collected (y1, y2) source-plane coordinates, one pair per segment
for i, path in enumerate(paths):
    # Slice the vertex columns directly instead of iterating vertex by vertex
    x1 = torch.tensor(path[:, 0], dtype=torch.float32)
    x2 = torch.tensor(path[:, 1], dtype=torch.float32)
    y1, y2 = lens.raytrace(x1, x2, z_s, params)
    caustic_paths.append((y1, y2))

    # Label only the first segment so the legend gets a single "caustic" entry
    ax.plot(
        y1, y2, color="r", zorder=1, label="caustic" if i == 0 else "_nolegend_"
    )

ax.scatter(x, y, color="b", label="forward raytrace", zorder=10)
ax.scatter(sp_x, sp_y, color="r", marker="x", label="source plane", zorder=9)
ax.set(
    xlabel="x",
    ylabel="y",
    title="Critical curve (blue), caustic (red) and point positions",
)
ax.set_aspect("equal")
ax.legend()
plt.show()

In this experiment we have access to the image plane positions `x` and `y` but we will assume we don't know the source plane position or the lens model parameters (`x0, y0, q, phi, b`). Let's start from some slightly incorrect parameters and try to recover the true parameters.

In this example we run from 50 starting points and let the Levenberg–Marquardt algorithm search for lens parameters that map all of the image positions to a single source position. After the optimization finishes, we keep only the runs that actually found such parameters, i.e. those whose back-traced source positions coincide (a small $\chi^2$). The resulting parameters are incredibly close to the true values!

In [None]:
# Fix the RNG so the random restarts (and therefore the selected fits) are reproducible
torch.manual_seed(0)

n_starts = 50  # number of randomly perturbed starting points for the batched fit
init_params = params.repeat(n_starts, 1)
init_params = init_params + torch.normal(
    mean=0.0, std=0.1 * torch.ones_like(init_params)
)


def loss(P):
    """Spread of the back-traced image positions in the source plane.

    Ray-traces the observed image positions ``x``, ``y`` back to the source
    plane using the lens parameters ``P`` (order: x0, y0, q, phi, b). It then
    returns a length-2 tensor holding the summed squared offsets of every image
    from the first image, separately for the x and y coordinates. Both entries
    are zero when all images map to the same source-plane point.
    """
    # Trace to the *source* redshift z_s, consistent with forward_raytrace and
    # the caustic plot; using z_l would trace to the lens plane instead.
    bx, by = lens.raytrace(x, y, z_s, P)
    return torch.stack(
        (
            torch.sum((bx[0] - bx[1:]) ** 2),
            torch.sum((by[0] - by[1:]) ** 2),
        )
    )


# The target is all zeros: one length-2 residual per starting point
fit_params = caustics.utils.batch_lm(
    init_params, torch.zeros(init_params.shape[0], 2), loss, stopping=1e-8
)
# fit_params holds: the fitted parameter values, a Levenberg-Marquardt state
# parameter, and the chi^2 values
fitted, chi2 = fit_params[0], fit_params[2]

# Keep only the runs with a good chi^2. Lowering the threshold (e.g. from 1e-8
# to 1e-10) gives better fits, but fewer of them.
chi2_threshold = 1e-8
good_fits = fitted[chi2 < chi2_threshold]  # columns: x0, y0, q, phi, b
good_fits