# Multi-Source Modelling

Sometimes a lensing system is aligned with more than one background source of light. These sources generally lie at different redshifts, so the same lens deflects their light by different amounts, and the model must account for this to reproduce the combined image. Here we demonstrate the case of a single gravitational lens imaging two sources.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import matplotlib.pyplot as plt
from numpy import pi
import torch
from torch.nn.functional import avg_pool2d
import caustics
from caskade import OverrideParam
from caustics import Module, forward

Let's define the objects we will need in the lensing system. The cosmology and lens are set up as usual, but now we have two sources. The Einstein radius of the lens is usually set to a single value, but that value encodes information about the lens configuration and depends on the redshift of the source. With two sources we can no longer conveniently fold that dependence into one number, so we set the Einstein radius to be a function of the source redshift. In this case we choose an arbitrary function, but in general you would choose a more meaningful function that represents the true configuration of your system.

In [None]:
cosmology = caustics.FlatLambdaCDM()
lens = caustics.SIE(cosmology=cosmology, name="lens", z_l=0.5, z_s=0.0)
src1 = caustics.Sersic(name="source1")
src2 = caustics.Sersic(name="source2")
# Einstein radius a function of source redshift
lens.b = lambda p: p["z_s"].value * 0.8
lens.b.link(lens.z_s)
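As an illustration of what a more meaningful function could look like, the sketch below computes the Einstein radius of a singular isothermal sphere, $\theta_E = 4\pi (\sigma_v / c)^2 D_{ls} / D_s$, as a function of source redshift in a flat $\Lambda$CDM universe. It is standalone and is not wired into caustics. The velocity dispersion `sigma_v=250.0` km/s and `omega_m=0.3` are assumed values chosen only for the example, and the helper names are hypothetical.

```python
import math

C_KM_S = 299_792.458  # speed of light in km/s
RAD_TO_ARCSEC = 180.0 / math.pi * 3600.0


def comoving_distance(z, omega_m=0.3, steps=2000):
    """Comoving distance in units of c/H0 for flat LambdaCDM (assumed omega_m)."""
    if z <= 0.0:
        return 0.0

    def inv_E(zz):
        return 1.0 / math.sqrt(omega_m * (1.0 + zz) ** 3 + (1.0 - omega_m))

    # Composite trapezoid rule on [0, z]
    h = z / steps
    total = 0.5 * (inv_E(0.0) + inv_E(z))
    total += sum(inv_E(i * h) for i in range(1, steps))
    return total * h


def distance_ratio(z_l, z_s, omega_m=0.3):
    """D_ls / D_s in a flat universe; zero if the source is not behind the lens."""
    if z_s <= z_l:
        return 0.0
    return 1.0 - comoving_distance(z_l, omega_m) / comoving_distance(z_s, omega_m)


def sis_einstein_radius(z_l, z_s, sigma_v=250.0, omega_m=0.3):
    """SIS Einstein radius in arcsec; sigma_v in km/s is an assumed value."""
    theta_rad = 4.0 * math.pi * (sigma_v / C_KM_S) ** 2 * distance_ratio(z_l, z_s, omega_m)
    return theta_rad * RAD_TO_ARCSEC


for z_s in (0.8, 1.5):
    print(f"z_s = {z_s}: theta_E = {sis_einstein_radius(0.5, z_s):.3f} arcsec")
```

Unlike the linear scaling used above, this ratio vanishes when the source sits at the lens redshift and saturates below one for distant sources, so the Einstein radius grows with source redshift but remains bounded.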

In [None]:
# Define the pixel grid for imaging
n_pix = 100
res = 0.05
upsample_factor = 3

Below we define a new Module modelled on ``LensSource``, except that it takes two sources, so we call it ``LensTwoSources``. At runtime the source redshift of the lens is temporarily overridden with each of the two values in turn. For each source the lensing then proceeds as usual, and the two lensed images are summed on the same pixel grid.

In [None]:
class LensTwoSources(Module):
    def __init__(self, lens, src1, src2, z_s1, z_s2):
        super().__init__()
        self.lens = lens
        self.src1 = src1
        self.src2 = src2
        self.z_s1 = torch.as_tensor(z_s1)
        self.z_s2 = torch.as_tensor(z_s2)
        theta_x, theta_y = caustics.utils.meshgrid(
            res / upsample_factor,
            upsample_factor * n_pix,
            dtype=torch.float32,
        )
        self.theta_x = theta_x
        self.theta_y = theta_y

    @forward
    def __call__(self, source1=True, source2=True):
        mu = torch.zeros_like(self.theta_x)

        if source1:
            with OverrideParam(self.lens.z_s, self.z_s1):
                bx, by = self.lens.raytrace(self.theta_x, self.theta_y)
                mu += self.src1.brightness(bx, by)

        if source2:
            with OverrideParam(self.lens.z_s, self.z_s2):
                bx, by = self.lens.raytrace(self.theta_x, self.theta_y)
                mu += self.src2.brightness(bx, by)

        return avg_pool2d(mu[None][None], upsample_factor).squeeze()
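To make the logic of the class easier to see without the library machinery, here is a minimal pure-Python sketch of the same idea. It assumes a singular isothermal sphere at the origin and circular Gaussian sources in place of the SIE and Sersic profiles, and it reuses the arbitrary scaling $b = 0.8\,z_s$ from earlier. All function names are hypothetical and none of this is caustics API.

```python
import math


def sis_raytrace(theta_x, theta_y, theta_e):
    """Lens equation for an SIS at the origin: beta = theta - theta_E * theta / |theta|."""
    r = math.hypot(theta_x, theta_y)
    if r == 0.0:
        return 0.0, 0.0  # deflection direction is undefined exactly at the centre
    scale = 1.0 - theta_e / r
    return theta_x * scale, theta_y * scale


def gaussian_source(beta_x, beta_y, x0, y0, sigma, amplitude):
    """Circular Gaussian brightness profile in the source plane."""
    d2 = (beta_x - x0) ** 2 + (beta_y - y0) ** 2
    return amplitude * math.exp(-0.5 * d2 / sigma**2)


def einstein_radius(z_s):
    """Same arbitrary scaling as the notebook: b = 0.8 * z_s."""
    return 0.8 * z_s


def render(sources, n_pix=20, res=0.25):
    """Sum the lensed images of every (z_s, source_kwargs) pair on one image grid."""
    half = (n_pix - 1) / 2
    image = [[0.0] * n_pix for _ in range(n_pix)]
    for z_s, src in sources:
        theta_e = einstein_radius(z_s)  # lens strength as seen by this source
        for i in range(n_pix):
            for j in range(n_pix):
                tx, ty = (j - half) * res, (i - half) * res
                bx, by = sis_raytrace(tx, ty, theta_e)
                image[i][j] += gaussian_source(bx, by, **src)
    return image


src1 = dict(x0=0.0, y0=0.0, sigma=0.2, amplitude=0.5)
src2 = dict(x0=0.1, y0=0.2, sigma=0.3, amplitude=1.5)
both = render([(0.8, src1), (1.5, src2)])
only1 = render([(0.8, src1)])
only2 = render([(1.5, src2)])
```

The image-plane grid is shared, but each source is ray-traced with its own Einstein radius, so the combined image is simply the sum of the two single-source images. The same loop extends to any number of sources.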

In [None]:
lts_sim = LensTwoSources(lens, src1, src2, z_s1=0.8, z_s2=1.5)

# List params
#                          x0   y0   q    phi
lensparams = torch.tensor([0.0, 0.0, 0.5, 0.0])
#                          x0   y0   q    phi   n    Re   Ie
src1params = torch.tensor([0.0, 0.0, 0.6, pi / 3, 4.0, 0.4, 0.5])
src2params = torch.tensor([0.1, 0.2, 0.3, -pi / 4, 2.0, 1.5, 1.5])

params = torch.cat([lensparams, src1params, src2params])

# Sample the image
img = lts_sim(params)
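The flat parameter vector above is a concatenation of three blocks: 4 lens values, then 7 values for each source, 18 in total. A small bookkeeping helper such as the one sketched below can make that layout explicit and catch length mistakes early. The helper `split_params` and the `layout` list are hypothetical and not part of caustics; the block order is assumed to mirror the concatenation used in this notebook.

```python
import math


def split_params(flat, layout):
    """Split a flat sequence into named blocks, in the order given by `layout`."""
    expected = sum(len(names) for _, names in layout)
    if len(flat) != expected:
        raise ValueError(f"expected {expected} values, got {len(flat)}")
    out, start = {}, 0
    for block, names in layout:
        out[block] = dict(zip(names, flat[start:start + len(names)]))
        start += len(names)
    return out


# Assumed layout, matching the comments in the cell above
layout = [
    ("lens", ["x0", "y0", "q", "phi"]),
    ("source1", ["x0", "y0", "q", "phi", "n", "Re", "Ie"]),
    ("source2", ["x0", "y0", "q", "phi", "n", "Re", "Ie"]),
]

lensparams = [0.0, 0.0, 0.5, 0.0]
src1params = [0.0, 0.0, 0.6, math.pi / 3, 4.0, 0.4, 0.5]
src2params = [0.1, 0.2, 0.3, -math.pi / 4, 2.0, 1.5, 1.5]
params = lensparams + src1params + src2params

blocks = split_params(params, layout)
for name, values in blocks.items():
    print(name, values)
```

Note that the Einstein radius and source redshift do not appear in the vector: in this notebook the source redshift is supplied by the simulator and the Einstein radius is derived from it.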

In [None]:
fig, axarr = plt.subplots(1, 3, figsize=(15, 5))
axarr[0].imshow(img, cmap="inferno")
axarr[0].axis("off")
axarr[0].set_title("Two sources")
axarr[1].imshow(lts_sim(params, source2=False), cmap="inferno")
axarr[1].axis("off")
axarr[1].set_title("Source 1")
axarr[2].imshow(lts_sim(params, source1=False), cmap="inferno")
axarr[2].axis("off")
axarr[2].set_title("Source 2")
plt.show()