# Making a new lens model

Here we demonstrate how you can make your own lens model just by defining a potential; `caustics` will take care of the rest.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib.pyplot as plt
import torch
import caustics
from caustics import forward, Param

Below we define a class that inherits from `caustics.ThinLens`, the abstract base class for all single-plane lenses in `caustics`. The base class needs a cosmology, a lens redshift, and a source redshift, which we pass along via `super().__init__()`. After that we define the `Param`s needed by our class (here just the parameters of a Gaussian). Finally, we define the potential function for our class, which in this case is a Gaussian. The potential is convenient because the other lensing quantities (such as the deflection angle and convergence) can be determined from its derivatives. This is why, given only the potential, `caustics` is able to build a full model.
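
As a reminder of why the potential is enough, here are the standard thin-lens relations. The symbols are introduced here for illustration and are not names from the `caustics` API: $\psi$ is the lensing potential, $\vec{\alpha}$ the reduced deflection angle, $\kappa$ the convergence, and subscripts denote partial derivatives with respect to the image-plane coordinates.

$$
\vec{\alpha} = \nabla \psi = \left( \psi_x,\ \psi_y \right),
\qquad
\kappa = \frac{1}{2} \nabla^2 \psi = \frac{1}{2} \left( \psi_{xx} + \psi_{yy} \right)
$$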

In [None]:
class GaussianPotential(caustics.ThinLens):

    def __init__(self, cosmology, z_l, z_s, x0, y0, A, sigma):
        super().__init__(cosmology=cosmology, z_l=z_l, z_s=z_s)

        self.x0 = Param("x0", x0)
        self.y0 = Param("y0", y0)
        self.A = Param("A", A)
        self.sigma = Param("sigma", sigma)

    @forward
    def potential(self, x, y, x0, y0, A, sigma):
        return -A * torch.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))

Now we can run a very basic simulation to check that everything is working. We take a Sersic source model and use the `LensSource` simulator to make an image of the source as lensed by our Gaussian potential model.

In [None]:
cosmo = caustics.FlatLambdaCDM()
lens = GaussianPotential(cosmo, z_l=0.5, z_s=1.0, x0=0.0, y0=0.0, A=2.0, sigma=1.0)
src = caustics.Sersic(x0=0.2, y0=0.2, q=0.6, phi=1.0, Ie=1.0, Re=1.0, n=2.0)

sim = caustics.LensSource(lens, src, pixels_x=100, pixelscale=0.05, upsample_factor=2)

plt.imshow(sim().numpy(), origin="lower")
plt.axis("off")
plt.title("Sersic lensed with Gaussian potential")
plt.show()

Now that we've tried lensing, let's look at all the basic lensing quantities and map them out for our new lens. The potential is exactly as we specified, the deflection angles are its first derivatives, the convergence comes from its second derivatives, and so on. We can compute the shear, magnification, and time delay field as well.
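
For reference, the shear and magnification follow from the same second derivatives of the potential. This is a sketch in standard lensing notation, which is an assumption of this note rather than something defined by `caustics`: $\psi$ is the potential, subscripts are partial derivatives, and $\kappa = \frac{1}{2}(\psi_{xx} + \psi_{yy})$ is the convergence.

$$
\gamma_1 = \frac{1}{2} \left( \psi_{xx} - \psi_{yy} \right),
\qquad
\gamma_2 = \psi_{xy},
\qquad
\mu = \frac{1}{(1 - \kappa)^2 - \gamma_1^2 - \gamma_2^2}
$$

The magnification diverges wherever the denominator vanishes (the critical curves), which is why the magnification panel in the plot is clamped to a finite range.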

In [None]:
fig, axarr = plt.subplots(2, 4, figsize=(20, 10))
n_pix = 100
res = 0.05
thx, thy = caustics.utils.meshgrid(res, n_pix, dtype=torch.float32)
axarr[0][0].imshow(lens.potential(thx, thy).numpy(), origin="lower")
axarr[0][0].set_title("Potential")
axarr[0][0].axis("off")
axarr[0][1].imshow(lens.reduced_deflection_angle(thx, thy)[0].numpy(), origin="lower")
axarr[0][1].set_title("Deflection x")
axarr[0][1].axis("off")
axarr[0][2].imshow(lens.reduced_deflection_angle(thx, thy)[1].numpy(), origin="lower")
axarr[0][2].set_title("Deflection y")
axarr[0][2].axis("off")
axarr[0][3].imshow(lens.convergence(thx, thy).numpy(), origin="lower")
axarr[0][3].set_title("Convergence")
axarr[0][3].axis("off")
axarr[1][0].imshow(lens.shear(thx, thy)[0].numpy(), origin="lower")
axarr[1][0].set_title("Shear g1")
axarr[1][0].axis("off")
axarr[1][1].imshow(lens.shear(thx, thy)[1].numpy(), origin="lower")
axarr[1][1].set_title("Shear g2")
axarr[1][1].axis("off")
axarr[1][2].imshow(
    torch.clamp(lens.magnification(thx, thy), -10.0, 20.0).numpy(), origin="lower"
)
axarr[1][2].set_title("Magnification")
axarr[1][2].axis("off")
axarr[1][3].imshow(lens.time_delay(thx, thy).numpy(), origin="lower")
axarr[1][3].set_title("Time delay")
axarr[1][3].axis("off")

If you know the analytic form of one of these quantities, you may want to write out the appropriate function and override the base class method, which otherwise uses autograd to compute it. This will be faster, since you have already done some of the work for the code by figuring out the analytic form.
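
For the Gaussian potential used in this notebook, the analytic forms can be worked out by hand. Writing $\psi$ for the potential and $r^2 = (x - x_0)^2 + (y - y_0)^2$ (notation introduced here for illustration), the chain rule gives:

$$
\begin{aligned}
\psi &= -A \exp\left( -\frac{r^2}{2 \sigma^2} \right) \\
\psi_x &= -\frac{x - x_0}{\sigma^2}\, \psi,
\qquad
\psi_y = -\frac{y - y_0}{\sigma^2}\, \psi \\
\psi_{xx} &= \left( \frac{(x - x_0)^2}{\sigma^4} - \frac{1}{\sigma^2} \right) \psi,
\qquad
\psi_{yy} = \left( \frac{(y - y_0)^2}{\sigma^4} - \frac{1}{\sigma^2} \right) \psi \\
\kappa &= \frac{1}{2} \left( \psi_{xx} + \psi_{yy} \right)
= \frac{1}{2} \left( \frac{r^2}{\sigma^4} - \frac{2}{\sigma^2} \right) \psi
\end{aligned}
$$

The first derivatives are what an analytic `reduced_deflection_angle` needs to return, and the last line is the analytic `convergence`.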

In [None]:
class GaussianPotentialFast(caustics.ThinLens):

    def __init__(self, cosmology, z_l, z_s, x0, y0, A, sigma):
        super().__init__(cosmology=cosmology, z_l=z_l, z_s=z_s)

        self.x0 = Param("x0", x0)
        self.y0 = Param("y0", y0)
        self.A = Param("A", A)
        self.sigma = Param("sigma", sigma)

    @forward
    def potential(self, x, y, x0, y0, A, sigma):
        return -A * torch.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))

    @forward
    def reduced_deflection_angle(self, x, y, x0, y0, A, sigma):
        ax = -(x - x0) / sigma**2  # derivative of exponent
        ay = -(y - y0) / sigma**2
        p = self.potential(x, y)  # exponential stays after derivative
        return ax * p, ay * p

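    @forward
    def convergence(self, x, y, x0, y0, A, sigma):
        # kappa = 0.5 * (d2psi/dx2 + d2psi/dy2), written as a factor times psi
        p = self.potential(x, y)
        dx = (x - x0) ** 2 / sigma**4  # square of the exponent's x-derivative
        dxdx = -1 / sigma**2  # second derivative of the exponent (same in x and y)
        dy = (y - y0) ** 2 / sigma**4  # square of the exponent's y-derivative
        return 0.5 * (2 * dxdx + dx + dy) * p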
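
Before trusting hand-derived formulas like the ones in the class above, it can help to check them numerically. The following is a minimal standalone sketch that uses only the standard library and does not touch `caustics`. The helper names (`fd_deflection`, `fd_convergence`, and so on) are hypothetical, and the parameter defaults simply mirror the values used in this notebook. It compares the analytic expressions against central finite differences of the potential at one test point.

```python
import math


def potential(x, y, x0=0.0, y0=0.0, A=2.0, sigma=1.0):
    """Gaussian potential, same functional form as in the class above."""
    return -A * math.exp(-((x - x0) ** 2 + (y - y0) ** 2) / (2 * sigma**2))


def analytic_deflection(x, y, x0=0.0, y0=0.0, A=2.0, sigma=1.0):
    """First derivatives of the potential, derived by hand."""
    p = potential(x, y, x0, y0, A, sigma)
    return -(x - x0) / sigma**2 * p, -(y - y0) / sigma**2 * p


def analytic_convergence(x, y, x0=0.0, y0=0.0, A=2.0, sigma=1.0):
    """Half the Laplacian of the potential, derived by hand."""
    p = potential(x, y, x0, y0, A, sigma)
    r2 = (x - x0) ** 2 + (y - y0) ** 2
    return 0.5 * (r2 / sigma**4 - 2 / sigma**2) * p


def fd_deflection(x, y, h=1e-5):
    """Central finite-difference gradient of the potential."""
    ax = (potential(x + h, y) - potential(x - h, y)) / (2 * h)
    ay = (potential(x, y + h) - potential(x, y - h)) / (2 * h)
    return ax, ay


def fd_convergence(x, y, h=1e-4):
    """Half the five-point finite-difference Laplacian of the potential."""
    lap = (
        potential(x + h, y)
        + potential(x - h, y)
        + potential(x, y + h)
        + potential(x, y - h)
        - 4 * potential(x, y)
    ) / h**2
    return 0.5 * lap


# Arbitrary off-centre test point
x, y = 0.7, -0.3
ax_fd, ay_fd = fd_deflection(x, y)
ax_an, ay_an = analytic_deflection(x, y)
kappa_fd = fd_convergence(x, y)
kappa_an = analytic_convergence(x, y)

print("deflection error:", abs(ax_fd - ax_an), abs(ay_fd - ay_an))
print("convergence error:", abs(kappa_fd - kappa_an))
```

Both errors should be tiny compared with the values themselves. A sign error or a missing factor of `sigma` in the hand derivation would show up as an error of order one.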

In [None]:
lens_basic = GaussianPotential(
    cosmo, z_l=0.5, z_s=1.0, x0=0.0, y0=0.0, A=2.0, sigma=1.0
)
lens_fast = GaussianPotentialFast(
    cosmo, z_l=0.5, z_s=1.0, x0=0.0, y0=0.0, A=2.0, sigma=1.0
)

fig, axarr = plt.subplots(2, 3, figsize=(15, 10))
axarr[0][0].imshow(
    lens_basic.reduced_deflection_angle(thx, thy)[0].numpy(), origin="lower"
)
axarr[0][0].set_title("Deflection x basic")
axarr[0][0].axis("off")
axarr[0][1].imshow(
    lens_fast.reduced_deflection_angle(thx, thy)[0].numpy(), origin="lower"
)
axarr[0][1].set_title("Deflection x fast")
axarr[0][1].axis("off")
axarr[0][2].imshow(
    lens_basic.reduced_deflection_angle(thx, thy)[0].numpy()
    - lens_fast.reduced_deflection_angle(thx, thy)[0].numpy(),
    origin="lower",
)
axarr[0][2].set_title("Difference")
axarr[0][2].axis("off")
axarr[1][0].imshow(
    lens_basic.reduced_deflection_angle(thx, thy)[1].numpy(), origin="lower"
)
axarr[1][0].set_title("Deflection y basic")
axarr[1][0].axis("off")
axarr[1][1].imshow(
    lens_fast.reduced_deflection_angle(thx, thy)[1].numpy(), origin="lower"
)
axarr[1][1].set_title("Deflection y fast")
axarr[1][1].axis("off")
axarr[1][2].imshow(
    lens_basic.reduced_deflection_angle(thx, thy)[1].numpy()
    - lens_fast.reduced_deflection_angle(thx, thy)[1].numpy(),
    origin="lower",
)
axarr[1][2].set_title("Difference")
axarr[1][2].axis("off")
fig.suptitle("Comparison of basic and fast lensing, the two are identical", fontsize=16)
plt.show()

In [None]:
%%timeit
ax, ay = lens_basic.reduced_deflection_angle(thx, thy)

In [None]:
%%timeit
ax, ay = lens_fast.reduced_deflection_angle(thx, thy)

Here we see that the fast version is almost 10x faster than the basic one, which relies only on automatic differentiation of the potential. There are a few reasons for this. Even in the most straightforward setups it is normal for autograd to be about 2-3x slower than an analytic derivative. Further, because `ax` and `ay` share many calculations in this case, we save a lot of work by doing the shared part only once.
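
To make the two routes concrete, here is a minimal sketch in plain `torch`. It is not necessarily how `caustics` computes derivatives internally; the grid, the constants, and the use of `torch.autograd.grad` are assumptions chosen for illustration. The autograd route has to record the forward computation and run a backward pass, while the analytic route evaluates the exponential once and reuses it for both components.

```python
import torch

# Illustrative constants matching the notebook's lens parameters
X0, Y0, A, SIGMA = 0.0, 0.0, 2.0, 1.0


def potential(x, y):
    """Gaussian potential, same functional form as in the lens classes."""
    return -A * torch.exp(-((x - X0) ** 2 + (y - Y0) ** 2) / (2 * SIGMA**2))


# Small grid of image-plane positions that autograd can track
coords = torch.linspace(-2.5, 2.5, 50)
x, y = torch.meshgrid(coords, coords, indexing="xy")
x = x.clone().requires_grad_(True)
y = y.clone().requires_grad_(True)

# Autograd route: each pixel's potential depends only on its own (x, y),
# so the gradient of the summed potential gives the per-pixel derivatives.
psi = potential(x, y)
ax_auto, ay_auto = torch.autograd.grad(psi.sum(), (x, y))

# Analytic route: evaluate the exponential once and reuse it for both components.
with torch.no_grad():
    p = potential(x, y)
    ax_analytic = -(x - X0) / SIGMA**2 * p
    ay_analytic = -(y - Y0) / SIGMA**2 * p

print(torch.allclose(ax_auto, ax_analytic, atol=1e-5))
print(torch.allclose(ay_auto, ay_analytic, atol=1e-5))
```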

Next let's look at the convergence, which uses the Hessian of the potential. Below we compare the two ways of computing the convergence, one using autograd and the other using analytic derivatives. The two are nearly identical: the residuals are at the level of `10^-7`, which is about the precision of the single-precision (`float32`) arithmetic used here. The two therefore agree to within our current numerical precision.
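
The `10^-7` figure can be checked directly. This small standard-library sketch is independent of `caustics`, and `to_float32` is a hypothetical helper. It shows that the spacing between `1.0` and the next representable `float32` value is `2**-23`, so differences much smaller than that are lost for values of order one.

```python
import struct


def to_float32(value):
    """Round a Python float (float64) to the nearest float32 and back."""
    return struct.unpack("f", struct.pack("f", value))[0]


eps32 = 2.0**-23  # spacing between 1.0 and the next float32 value
print(eps32)  # 1.1920928955078125e-07
print(to_float32(1.0 + eps32) > 1.0)  # True: this step is resolvable
print(to_float32(1.0 + eps32 / 4) == 1.0)  # True: a smaller step rounds away
```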

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(15, 5))
axarr[0].imshow(lens_basic.convergence(thx, thy).numpy(), origin="lower")
axarr[0].set_title("Convergence basic")
axarr[0].axis("off")
axarr[1].imshow(lens_fast.convergence(thx, thy).numpy(), origin="lower")
axarr[1].set_title("Convergence fast")
axarr[1].axis("off")
im = axarr[2].imshow(
    lens_basic.convergence(thx, thy).numpy() - lens_fast.convergence(thx, thy).numpy(),
    origin="lower",
)
fig.colorbar(im, ax=axarr[2])
axarr[2].set_title("Difference")
axarr[2].axis("off")
fig.suptitle(
    "Comparison of basic and fast convergence, the two are identical", fontsize=16
)
plt.show()

In [None]:
%%timeit
kappa = lens_basic.convergence(thx, thy)

In [None]:
%%timeit
kappa = lens_fast.convergence(thx, thy)

Since the convergence uses second derivatives, we see an even more dramatic difference between the basic autograd calculation from the potential and the analytic calculation. The analytic version is now almost 30x faster; the additional factor of about 3 comes from the extra autograd operation needed for the basic calculation.

The conclusion here is that using autograd from the potential is easy and reasonably fast, but if performance matters then it's worth doing the extra work to derive the derivatives yourself. In `caustics`, all of the base models have analytic `potential`, `deflection_angle`, and `convergence` so that they are as performant as possible. If you can use a built-in method of `caustics` then it is worth doing so, but if you need to make your own model, you now know how to make it as fast as possible!