# Fitting a lens image with Levenberg–Marquardt

In [None]:
import caustics
import numpy as np
import torch
from lenstronomy.Util.param_util import ellipticity2phi_q
import matplotlib.pyplot as plt

## Data specifications

In [None]:
# Data specs

# --- Noise / exposure ------------------------------------------------------
background_rms = 0.005  # standard deviation of the per-pixel background noise
exp_time = 500.0  # exposure time (arbitrary units; flux per pixel is photons / exp_time unit)
psf_type = "GAUSSIAN"  # one of 'GAUSSIAN', 'PIXEL', 'NONE'

# --- Detector geometry ------------------------------------------------------
numPix = 60  # cutout size in pixels along each axis
pixelscale = 0.05  # pixel size in arcsec (area per pixel = pixelscale**2)

# --- PSF width --------------------------------------------------------------
# Gaussian FWHM -> standard deviation: sigma = FWHM / (2 * sqrt(2 * ln 2)).
fwhm = 0.05  # full width at half maximum of the PSF
psf_sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))

# --- Redshifts of the mock system -------------------------------------------
z_l = torch.tensor(0.5, dtype=torch.float32)  # lens redshift
z_s = torch.tensor(1.5, dtype=torch.float32)  # source redshift

# --- Cosmology (single precision, like the coordinate grid) ------------------
cosmology = caustics.FlatLambdaCDM(name="cosmo")
cosmology.to(dtype=torch.float32)

# --- Angular coordinate grid at the (optionally) super-sampled resolution ----
upsample_factor = 1
thx, thy = caustics.utils.get_meshgrid(
    pixelscale / upsample_factor,
    numPix * upsample_factor,
    numPix * upsample_factor,
    dtype=torch.float32,
)

## Build the simulator forward model

In [None]:
# Set up the forward model.
# The lens and source redshifts are taken from the data-spec variables z_l and
# z_s rather than hard-coded here, so the declared configuration is the one that
# is actually simulated.

# Lens mass model (SIE + external shear) on a single lens plane.
# The shear centre is given fixed values (x0 = y0 = 0); true_params only
# supplies gamma_1 / gamma_2 for it.
lens_sie = caustics.SIE(name="galaxylens", cosmology=cosmology, z_l=z_l)
lens_shear = caustics.ExternalShear(
    name="externalshear", cosmology=cosmology, x0=0.0, y0=0.0, z_l=z_l
)
lens_mass_model = caustics.SinglePlane(
    name="lensmass", cosmology=cosmology, lenses=[lens_sie, lens_shear], z_l=z_l
)

# Lens light model (Sersic)
lens_light_model = caustics.Sersic(name="lenslight")

# Source light model (Sersic)
source_light_model = caustics.Sersic(name="sourcelight")

# Gaussian PSF kernel, (6 * upsample_factor + 1) pixels on a side, evaluated at
# the (optionally) super-sampled pixel scale. upsample=2 presumably sub-samples
# each kernel pixel when evaluating the Gaussian -- confirm in caustics.utils.
psf_image = caustics.utils.gaussian(
    nx=upsample_factor * 6 + 1,
    ny=upsample_factor * 6 + 1,
    pixelscale=pixelscale / upsample_factor,
    sigma=psf_sigma,
    upsample=2,
)

# Image-plane simulator: lens mass + lens light + source light, PSF-convolved,
# producing a numPix-wide image at the given pixel scale.
sim = caustics.Lens_Source(
    lens=lens_mass_model,
    lens_light=lens_light_model,
    source=source_light_model,
    psf=psf_image,
    pixels_x=numPix,
    pixelscale=pixelscale,
    upsample_factor=upsample_factor,
    z_s=z_s,
)

## Sample some mock data

In [None]:
# Generate the mock data
# Convert (e1, e2) ellipticities into the position angle phi and axis ratio q
# expected by the caustics profiles.
phi_lensmass, q_lensmass = ellipticity2phi_q(e1=0.07, e2=-0.03)
phi_lenslight, q_lenslight = ellipticity2phi_q(e1=-0.1, e2=0.1)
phi_sourcelight, q_sourcelight = ellipticity2phi_q(e1=-0.1, e2=0.1)
true_params = {
    "galaxylens": {
        "x0": 0.05,
        "y0": 0.0,
        "q": q_lensmass,
        "phi": phi_lensmass,
        "b": 0.66,
    },
    "externalshear": {"gamma_1": 0.0, "gamma_2": -0.05},
    "sourcelight": {
        "x0": 0.1,
        "y0": 0.0,
        "q": q_sourcelight,
        "phi": phi_sourcelight,
        "n": 1.0,
        "Re": 0.1 / np.sqrt(q_sourcelight),
        "Ie": 16 * pixelscale**2,
    },
    "lenslight": {
        "x0": 0.05,
        "y0": 0.0,
        "q": q_lenslight,
        "phi": phi_lenslight,
        "n": 2.0,
        "Re": 0.6 / np.sqrt(q_lenslight),
        "Ie": 16 * pixelscale**2,
    },
}

# Flatten into one parameter vector for `sim`. The model order and each model's
# key order above must match the order in which `sim` consumes its free
# parameters. The lenstronomy helpers return float64 NumPy scalars, which would
# silently promote the tensor to float64; cast explicitly so it matches the
# float32 cosmology and coordinate grid.
allparams = torch.tensor(
    [
        value
        for model_params in true_params.values()
        for value in model_params.values()
    ],
    dtype=torch.float32,
)
print(true_params)

# Noise-free image of the lens system
true_system = sim(allparams)
print(true_system.shape)

# Noise model: shot noise with variance |flux| / exp_time plus Gaussian
# background with standard deviation background_rms. `variance` is the total
# per-pixel variance of that model (used later to weight the fit).
variance = (torch.abs(true_system) / exp_time) + background_rms**2
shot_std = torch.sqrt(torch.abs(true_system) / exp_time)
shot_noise = shot_std * torch.randn_like(true_system)
background = background_rms * torch.randn_like(true_system)
obs_system = true_system + shot_noise + background

# log10 is undefined for the non-positive pixels that background noise
# produces; clip to a small floor so they render as faint pixels instead of
# NaN holes in the plot.
log_floor = 1e-4

fig, axarr = plt.subplots(1, 2, figsize=(15, 8))
panels = (
    (true_system, "Mock Lens System"),
    (obs_system, "Mock Observation"),
)
for ax, (image, title) in zip(axarr, panels):
    pixels = image.detach().cpu().numpy()
    ax.imshow(np.log10(np.clip(pixels, log_floor, None)), origin="lower")
    ax.axis("off")
    ax.set_title(title)
plt.savefig("mock_obs")
plt.show()

## Fit with Levenberg–Marquardt (LM)

In [None]:
# Multi-start Levenberg-Marquardt fit: run n_starts fits in one batch, each
# starting from a randomly jittered copy of the true parameters.
n_starts = 10
init_jitter = 0.02  # std of the Gaussian perturbation applied to each start

# Cast before jittering so the perturbation is drawn in the fitting precision.
batch_inits = allparams.to(dtype=torch.float32).repeat(n_starts, 1)
batch_inits = batch_inits + init_jitter * torch.randn_like(batch_inits)

# Every start is fitted to the same flattened observation.
batch_obs = obs_system.reshape(1, -1).to(dtype=batch_inits.dtype).repeat(n_starts, 1)

# Per-pixel noise variance, one row per start, passed as a diagonal covariance
# so the chi^2 is weighted by the actual noise model. NOTE(review): relies on
# caustics.utils.batch_lm treating a 2-D `C` of shape (B, Dout) as per-pixel
# variances (and otherwise defaulting to a dense identity covariance) --
# confirm against the installed caustics version.
batch_var = variance.reshape(1, -1).to(dtype=batch_inits.dtype).repeat(n_starts, 1)

res = caustics.utils.batch_lm(
    batch_inits,
    batch_obs,
    lambda x: sim(x).reshape(-1),
    C=batch_var,
)
# res[0] holds the fitted parameter vectors and res[2] the per-start chi^2
# (the next cell selects the smallest).
print(res[2])

In [None]:
# Display the model image of the best multi-start fit (lowest value in res[2]).
# Selecting with torch.argmin keeps the reduction in torch, so it works even
# when the results are on a GPU or still carry autograd history, cases where
# calling .numpy() directly on res[2] would raise.
best_index = torch.argmin(res[2]).item()
with torch.no_grad():
    best_fit_image = sim(res[0][best_index])
plt.imshow(
    np.log10(best_fit_image.detach().cpu().numpy()),
    origin="lower",
)
plt.show()