In [None]:
%load_ext autoreload
%autoreload 2
import numpy as np
import matplotlib.pyplot as plt
import torch
from torch import vmap
from torch.func import jacfwd

# Welcome to Caustics!

In this introduction, we will showcase some of the features and design principles of Caustics. We will see
1. How to get started with one of our pre-built `Simulator`s

2. Visualization of the `Simulator` graph (DAG of `caustics` modules)

3. Distinction between **Static** and **Dynamic** parameters

4. How to create a **batch** of simulations

5. Semantic structure of the Simulator input
6. Taking gradients w.r.t. parameters with `PyTorch` autodiff functionality
7. Swapping in flexible modules like the `Pixelated` representation for more advanced usage
8. How to create your own Simulator

## Getting started with the `LensSource` Simulator

For this first introduction, we use the simplest modules in caustics for the lens and source, namely the `SIE` and the `Sersic` modules. We also assume a `FlatLambdaCDM` cosmology. 

In [None]:
from caustics import LensSource, SIE, Sersic, FlatLambdaCDM

# Camera pixel grid: `pixels` pixels along x, each `pixelscale` arcsec wide
pixels = 100
pixelscale = 0.04  # arcsec/pixel

# Build the modules: the cosmology feeds the SIE lens, a Sersic profile is the source
cosmo = FlatLambdaCDM(name="cosmo")
lens = SIE(name="lens", cosmology=cosmo)
source = Sersic(name="source")
simulator = LensSource(
    lens,
    source,
    pixels_x=pixels,
    pixelscale=pixelscale,
)

### Generating a simulation of a strong gravitational lens

In [None]:
z_s = torch.tensor([1.0])  # source redshift

# SIE lens parameters; order matters, it must follow the module's parameter order
lens_values = {
    "z_l": 0.5,
    "x_0": 0.0,
    "y_0": 0.0,
    "q": 0.9,
    "phi": 0.4,
    "b": 1.0,  # Einstein radius
}
lens_params = torch.tensor(list(lens_values.values()))

# Sersic source parameters, same convention
source_values = {
    "x_0": 0.0,
    "y_0": 0.0,
    "q": 0.5,
    "phi": 0.9,
    "n": 1.0,
    "Re": 0.1,
    "I_e": 10.0,
}
source_params = torch.tensor(list(source_values.values()))

In [None]:
# Generate a lensed image from the semantic list [source redshift, lens parameters, source parameters]
semantic_input = [z_s, lens_params, source_params]
y = simulator(semantic_input)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 4))

# A meshgrid to show the source
x = torch.linspace(-0.5, 0.5, 100)
X, Y = torch.meshgrid(x, x, indexing="xy")

# Half-width (arcsec) of the simulated image: `pixels` pixels of `pixelscale` arcsec each
fov_half = pixelscale * pixels / 2

ax = axs[0]
ax.set_title(r"Sérsic source")
source_im = source.brightness(X, Y, source_params)
ax.imshow(source_im, origin="lower", extent=(-0.5, 0.5, -0.5, 0.5), cmap="gray")
ax.set_ylabel(r"$\beta_y$ ['']")
ax.set_xlabel(r"$\beta_x$ ['']")

ax = axs[1]
ax.set_title(r"SIE mass distribution")
# Same grid stretched by 2, so the convergence covers [-1, 1] arcsec
lens_im = lens.convergence(X * 2, Y * 2, z_s, lens_params)
ax.imshow(lens_im, origin="lower", extent=(-1, 1, -1, 1), cmap="hot")
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']")

ax = axs[2]
ax.set_title(r"Lensed image")
# The simulated image spans the simulator's full field of view, not [-1, 1] arcsec
ax.imshow(
    y,
    origin="lower",
    extent=(-fov_half, fov_half, -fov_half, fov_half),
    cmap="gray",
)
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']");

## Visualization of the `Simulator` DAG 


In [None]:
# Draw the simulator DAG. The first flag presumably toggles display of dynamic
# parameters and the second of static ones (see the comments in the next cells).
simulator.graph(True, True)

## **Static** vs **Dynamic** parameters

In the DAG shown above, 

- **Dynamic parameters** are shown in white boxes

- **Static parameters** are shown in grey boxes 

The distinction between the two types can be summarized as follows

- **Dynamic parameters** are fed as input to the simulator and can be batched over (data parallelism)

- **Static parameters** have fixed values. Their values are stored in the internal DAG and are broadcast across the batch when computations are batched.


In [None]:
# Making a parameter static: assigning a concrete value fixes it in the DAG
simulator.z_s = 1.0
simulator.graph(False, True)  # z_s turns grey

In [None]:
# Making a parameter dynamic: assigning None frees it again
simulator.z_s = None
# z_s turns white, which makes it disappear when we don't show the dynamic
# parameters (first option False)
simulator.graph(False, True)

## Simulating a batch of observations

We use `vmap` over the simulator to simulate a batch of parameter values. In this example, we create a batch of examples that only differ by their Einstein radius. To do this, we turn all the other parameters into static parameters. This is done in the hidden cell below.

In [None]:
# Make all parameters static except the Einstein radius.
# The static values match the single simulation above (including source phi = 0.9),
# so the b = 1.0 member of the batch reproduces that image.
simulator.lens.x0 = 0.0
simulator.lens.y0 = 0.0
simulator.lens.q = 0.9
simulator.lens.phi = 0.4
simulator.lens.b = None  # Make sure this one stays Dynamic
simulator.lens.z_l = 0.5

simulator.source.x0 = 0.0
simulator.source.y0 = 0.0
simulator.source.q = 0.5
simulator.source.phi = 0.9
simulator.source.n = 1.0
simulator.source.Re = 0.1
simulator.source.Ie = 10.0

simulator.z_s = 1.0

In [None]:
# Create a grid of Einstein radii. `vmap` (imported in the first cell) maps the
# simulator over the leading batch dimension, producing one image per radius.
b = torch.linspace(0.5, 1.5, 5).view(-1, 1)  # Shape is [B, 1]
ys = vmap(simulator)(b)

In [None]:
fig, axs = plt.subplots(1, 5, figsize=(20, 4))

# One panel per Einstein radius; origin="lower" matches the orientation of the
# other image plots in this notebook
for i, ax in enumerate(axs.flatten()):
    ax.axis("off")
    ax.imshow(ys[i], origin="lower", cmap="gray")
    ax.set_title(f"$b = {b[i].item():.2f}$")
plt.subplots_adjust(wspace=0, hspace=0)

## Semantic structure of the input

The simulator accepts its input in several formats to support different use cases:
1. Flattened tensor for deep neural network like in [Hezaveh et al. (2017)](https://arxiv.org/abs/1708.08842)

2. Semantic list, to separate the input in terms of high-level modules like the lens and the source
3. Low-level dictionary, to decompose the parameters at the level of the leaves of the DAG

Below, we illustrate how to use all of these structures. For completeness, we also use `vmap`.

In [None]:
# Make some parameters dynamic for this example.
# lens.b is already dynamic from the batching example; setting it again keeps
# this cell self-contained.
simulator.source.Ie = None
simulator.lens.b = None

### Flattened Tensor
To make sure the order of the parameters is correct, print the simulator. The order of the dynamic parameters is shown in the `x_order` field.

In [None]:
# Display the simulator; its representation lists the dynamic parameter order (`x_order`)
simulator

In [None]:
B = 5  # Batch dimension
b = torch.rand(B, 1)
Ie = torch.rand(B, 1)
# One row per simulation. Columns must follow `x_order`; this assumes b precedes Ie.
x = torch.cat((b, Ie), dim=1)

# vmap maps the simulator over the rows, simulating the whole batch at once
ys = vmap(simulator)(x)

### Semantic lists

A semantic list is simply a list of per-module parameters, like the one we used earlier: `[z_s, lens_params, source_params]`. Note that we can also include cosmological parameters in that list.

In [None]:
# Make some parameters dynamic for this example: the source intensity, the lens
# Einstein radius and x-position, and the cosmology's h0
simulator.source.Ie = None
simulator.lens.b = None
simulator.lens.x0 = None
simulator.lens.cosmology.h0 = None

# Display the simulator to check the updated `x_order`
simulator

In [None]:
B = 5
# Draw x0 and b separately so the Einstein radius stays non-negative
# (torch.randn would give negative radii about half the time)
x0 = torch.randn(B, 1)
b = torch.rand(B, 1)
lens_params = torch.concat([x0, b], dim=1)  # x0 and b
source_params = torch.rand(B, 1)  # Ie
cosmo_params = torch.rand(B, 1)  # h0

x = [lens_params, cosmo_params, source_params]
ys = vmap(simulator)(x)

### Low-level Dictionary

In [None]:
B = 5
x0 = torch.randn(B, 1)
b = torch.rand(B, 1)  # Einstein radius: non-negative, as in the flattened-tensor example
Ie = torch.rand(B, 1)
h0 = torch.rand(B, 1)

In [None]:
# Low-level dictionary: module name -> {parameter name: batched tensor}
lens_entry = dict(x0=x0, b=b)
source_entry = dict(Ie=Ie)
cosmo_entry = dict(h0=h0)
x = dict(lens=lens_entry, source=source_entry, cosmo=cosmo_entry)
ys = vmap(simulator)(x)

## Computing gradients with automatic differentiation

Computing gradients is particularly useful for optimization. Since taking gradients w.r.t. list or dictionary inputs is not possible with `torch.func.grad`, we will need a small wrapper around the simulator. For optimization, the wrapper will often be a log-likelihood function. For now we use a generic `lambda` wrapper. 

In the case of the semantic list input, the wrapper has the general form
```python
lambda *x: simulator(x)
```

The low-level dictionary input is a bit more involved but can be worked out on a case by case basis. 

**Note**: apply `vmap` around the gradient function (e.g. `jacfwd` or `grad`) to handle batched computation


In [None]:
# Point at which to evaluate the Jacobian (no batch dimension here)
x0_eval, b_eval = 0.0, 1.0
Ie_eval = 10.0
h0_eval = 0.7

lens_params = torch.tensor([x0_eval, b_eval])  # x0 and b
source_params = torch.tensor([Ie_eval])  # Ie
cosmo_params = torch.tensor([h0_eval])  # h0

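`jacfwd` will return a tuple of 3 tensors (one per entry in `argnums`) of shape [pixels, pixels, D], where D is the number of dynamic parameters in that module. There is no batch dimension here because the inputs above are unbatched; wrap the call in `vmap` to handle a batch.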

In [None]:
# Forward-mode Jacobian of the image w.r.t. every positional argument
# (`jacfwd` is imported in the first cell). The argument order (lens, cosmo,
# source) mirrors the semantic list used earlier.
jac = jacfwd(lambda *x: simulator(x), argnums=(0, 1, 2))(
    lens_params, cosmo_params, source_params
)

In [None]:
fig, axs = plt.subplots(1, 4, figsize=(20, 4))

# Titles follow the concatenation order of `jac`: lens (x0, b), cosmo (h0), source (Ie)
titles = [
    r"$\nabla_{x_0} f(\mathbf{x})$",
    r"$\nabla_{b} f(\mathbf{x})$",
    r"$\nabla_{h_0} f(\mathbf{x})$",
    r"$\nabla_{I_e} f(\mathbf{x})$",
]
# Join the per-module Jacobians along the last (parameter) axis: one slice per parameter
jacs = torch.concat(jac, dim=-1)
for i, ax in enumerate(axs.flatten()):
    ax.axis("off")
    # origin="lower" keeps the orientation consistent with the lensed-image plots;
    # the symmetric range centres the diverging colormap on zero
    ax.imshow(jacs[..., i], origin="lower", cmap="seismic", vmin=-10, vmax=10)
    ax.set_title(titles[i], fontsize=18)

## Pixelated representations

The examples above made use of very simplistic modules. Here, we will showcase how easily we can swap in flexible representations to represent more realistic systems. 

- `Pixelated` is the module used to represent the background source with a grid of pixels

- `PixelatedConvergence` is the module used to represent the convergence of the lens with a grid of pixels

For this example, we will use source samples from the PROBES dataset ([Stone et al., 2019](https://iopscience.iop.org/article/10.3847/1538-4357/ab3126/meta)) and convergence maps sampled from Illustris TNG ([Nelson et al., 2019](https://comp-astrophys-cosmol.springeropen.com/articles/10.1186/s40668-019-0028-x); see [Adam et al., 2023](https://iopscience.iop.org/article/10.3847/1538-4357/accf84/meta) for preprocessing, or use this [link](https://zenodo.org/records/6555463/files/hkappa128hst_TNG100_rau_trainset.h5?download=1) to download the maps). 

In [None]:
from caustics import Pixelated, PixelatedConvergence

# Some static parameters for the simulator
pixelscale = 0.07  # arcsec/pixel
# Camera size, stated explicitly instead of relying on the value left over from
# the first example (same value: 100 pixels)
pixels = 100
source_pixelscale = 0.25 * pixelscale  # source pixels are 4x finer than camera pixels
z_l = 0.5
z_s = 1.0
x0 = 0
y0 = 0

# Construct the Simulator with Pixelated and PixelatedConvergence modules
cosmo = FlatLambdaCDM(name="cosmo")
source = Pixelated(
    name="source", shape=(256, 256), pixelscale=source_pixelscale, x0=x0, y0=y0
)
lens = PixelatedConvergence(
    cosmology=cosmo,
    name="lens",
    pixelscale=pixelscale,
    shape=(128, 128),
    z_l=z_l,
)
simulator = LensSource(lens, source, pixelscale=pixelscale, pixels_x=pixels, z_s=z_s)

simulator.graph(True, True)

In the hidden cell below, we load example maps saved in the `assets` folder. If you downloaded the datasets mentioned above, you can instead uncomment the code at the top of that cell to load maps from them. 

In [None]:
# import h5py

# B = 10
# path_to_kappa_maps = "/path/to/hkappa128hst_TNG100_rau_trainset.h5"  # modify this to your system path
# index = [250] + sorted(list(np.random.randint(251, 1000, size=B-1)))
# kappa_map = torch.tensor(h5py.File(path_to_kappa_maps, "r")["kappa"][index])

# path_to_source_maps = "/path/to/probes.h5"  # modify this to your system path
# index = [101] + sorted(list(np.random.randint(251, 1000, size=B-1)))
# filter_ = 0  # grz filters: 0 is g, etc.
# source_map = torch.tensor(
#     h5py.File(path_to_source_maps, "r")["galaxies"][index, ..., filter_]
# )

# Load saved assets for demonstration. The `with` blocks close the .npz archives
# once the arrays have been copied into tensors.
# NOTE(review): allow_pickle=True lets np.load execute pickled code; only use it
# with trusted files such as the assets shipped alongside this notebook.
with np.load("assets/kappa_maps.npz", allow_pickle=True) as kappa_file:
    kappa_maps = torch.tensor(kappa_file["kappa_maps"])
with np.load("assets/source_maps.npz", allow_pickle=True) as source_file:
    source_maps = torch.tensor(source_file["source_maps"])

# Cherry picked example
source_map = source_maps[0]
kappa_map = kappa_maps[0]

Make a simulation by feeding the maps as input to the simulator (using a semantic list input).

In [None]:
# Semantic list for this simulator: [convergence map, source map]
pixelated_input = [kappa_map, source_map]
y = simulator(pixelated_input)

In [None]:
fig, axs = plt.subplots(1, 3, figsize=(15, 4))

# Angular extents (arcsec) of each map, from its own pixel count and pixelscale
beta_extent = [
    -source_pixelscale * source_map.shape[0] / 2,
    source_pixelscale * source_map.shape[0] / 2,
] * 2
# The convergence map has its own size (the lens was built with shape=(128, 128)),
# so it needs its own extent rather than the camera's
kappa_extent = [
    -pixelscale * kappa_map.shape[0] / 2,
    pixelscale * kappa_map.shape[0] / 2,
] * 2
theta_extent = [-pixelscale * pixels / 2, pixelscale * pixels / 2] * 2

ax = axs[0]
ax.set_title(r"Source map")
ax.imshow(source_map, origin="lower", cmap="gray", extent=beta_extent)
ax.set_ylabel(r"$\beta_y$ ['']")
ax.set_xlabel(r"$\beta_x$ ['']")

ax = axs[1]
ax.set_title(r"Convergence map")
ax.imshow(
    kappa_map,
    origin="lower",
    cmap="hot",
    extent=kappa_extent,
    norm=plt.cm.colors.LogNorm(vmin=1e-1, vmax=10),
)
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']")

ax = axs[2]
ax.set_title(r"Lensed image")
ax.imshow(y, origin="lower", extent=theta_extent, cmap="gray")
ax.set_ylabel(r"$\theta_y$ ['']")
ax.set_xlabel(r"$\theta_x$ ['']");

Of course, batching works the same way as before and is super fast. Below, we show the time it takes to make 4 batched simulations on a laptop.

In [None]:
%%timeit

# Time one batched simulation over all loaded maps. Variables assigned inside a
# %%timeit cell do not persist, so the next cell recomputes `ys`.
ys = vmap(simulator)([kappa_maps, source_maps])

In [None]:
fig, axs = plt.subplots(3, 3, figsize=(9, 9))

ys = vmap(simulator)([kappa_maps, source_maps])
n_sims = len(ys)
for row in range(3):
    # Rows show the last three simulations of the batch, last one first
    idx = n_sims - 1 - row
    panels = (
        (source_maps[idx], "gray", plt.cm.colors.LogNorm(vmin=1e-2, vmax=1, clip=True)),
        (kappa_maps[idx], "hot", plt.cm.colors.LogNorm(vmin=1e-1, vmax=10)),
        (ys[idx], "gray", plt.cm.colors.LogNorm(vmin=1e-2, vmax=1, clip=True)),
    )
    # Columns: source map, convergence map, lensed image
    for col, (image, cmap, norm) in enumerate(panels):
        ax = axs[row, col]
        ax.axis("off")
        ax.imshow(image, origin="lower", cmap=cmap, norm=norm)
axs[0, 0].set_title(r"Source map")
axs[0, 1].set_title(r"Convergence map")
axs[0, 2].set_title(r"Lensed image")
plt.subplots_adjust(wspace=0, hspace=0);

## Creating your own Simulator

Here, we only introduce the general design principles to create a Simulator. Worked examples can be found in [this notebook](./Simulators.ipynb). 

### A Simulator is very much like a neural network in PyTorch
A simulator inherits from the super class `Simulator`, similar to how a neural network inherits from the `nn.Module` class in `PyTorch`.

```python
from caustics import Simulator

class MySim(Simulator):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        ...
```

- The `__init__` method constructs the computation graph, initializes the `caustics` modules, and can prepare or store variables for the `forward` method. 
- The `forward` method is where the actual simulation happens.
- `x` generally denotes a set of parameters which affect the computations in the simulator graph. 

### How to use a Simulator in your workflow
Like a neural network, `MySim` (and in general any `caustics` module) must be instantiated **outside** the main workload. This is because `caustics` builds a graph internally every time a module is created. Ideally, this happens only once to avoid overhead. In general, you can use the following code pattern:
```python

# Instantiation
simulator = MySim()

# Heavy workload
for n in range(N):
    y = vmap(simulator)(x)
```
This allows you to perform expensive computations that only need to happen once in the `__init__` method while keeping your `forward` method lightweight.

### How to feed parameters to the different modules

This is probably the easiest part of building a Simulator: just pass `x` as the last argument of each module method you call, and you are done.

Here is a minimal example that shows how to feed the parameters `x` to different modules in the `forward` method:
```python
def forward(self, x):
    alpha_x, alpha_y = self.lens.reduced_deflection_angle(self.theta_x, self.theta_y, self.z_s, x)
    beta_x = self.theta_x - alpha_x  # lens equation
    ...
    lensed_image = self.source.brightness(beta_x, beta_y, x)
```

You might worry that `x` can have a relatively complex structure (flattened tensor, semantic list, low-level dictionary). 
`caustics` handles this complexity for you. 
You only need to make sure that `x` contains all the **dynamic** parameters required by your custom simulator. 
This design works for every `caustics` module and each of their methods, meaning that `x` is always the last argument in a `caustics` method call signature.  
 
The only details that you need to handle explicitly in your own simulator are things like the camera pixel positions (`theta_x` and `theta_y`) and the source redshift (`z_s`). Those are often constructed in the `__init__` method because they can be assumed fixed. This is why the example above retrieves them from `self`. A Simulator is often an abstraction of an instrument with many fixed variables to describe it, or is aimed at a specific observation. 

Of course, you could have more complex workflows for which this assumption is not true. For example, you might want to infer the PSF parameters of your instrument and need to feed this to the simulator as a dynamic parameter. 
The next section shows what you need to fully customize your simulator.

### Creating your own variables as leaves in the DAG

You can register new variables in the DAG for custom calculations as follows

In [None]:
from caustics import Simulator


class MySim(Simulator):
    """Skeleton of a custom simulator that registers its own leaves in the DAG.

    One dynamic parameter (value=None) and one static parameter (value=1.0)
    are added, both with shape (1,).
    """

    def __init__(self):
        super().__init__()  # Don't forget to use super!!
        # `shape` has to be a tuple, e.g. shape=(1,). This can be any shape you need.
        # Dynamic leaf: value=None, so it must be supplied through `x`.
        self.add_param("my_dynamic_arg", value=None, shape=(1,))
        # Static leaf: a concrete value fixes it in the DAG.
        self.add_param("my_static_arg", value=1.0, shape=(1,))

    def forward(self, x):
        # Parameters of this module are looked up under the key "MySim"
        # (presumably the module name defaults to the class name).
        # NOTE(review): index 1 presumably holds my_static_arg; confirm how
        # caustics packs static values into `x`.
        params = x["MySim"]
        my_arg = params[0]  # retrieve your arguments
        my_arg2 = params[1]

        # My very complex workflow
        ...


sim = MySim()
sim.graph(True, True)