# Understanding the Multi-Plane Fermat Potential

## Setting up multi_plane and multi_plane_base

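When `multi_plane=True` is set in the lens model, a `multi_plane_base` (`MultiPlaneBase`) instance is also initialised. Its constructor proceeds as follows.

First, `z_interp_stop`, the maximum redshift up to which cosmological distances are interpolated, is set to `z_source_convention` (the source redshift used to define the reduced lensing quantities) if it has not been specified.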

After this, the `self._cosmo_bkg` class instance is initialised. This is the `Background` class from *Cosmo*, initialised as:

`self._cosmo_bkg = Background(cosmo, interp=cosmo_interp, z_stop=z_interp_stop, num_interp=num_z_interp)`

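This class computes the cosmological distances. `cosmo_interp` is a boolean. If it is True, distances are not evaluated directly from `cosmo`. Instead, the transverse comoving distance is interpolated (sampled at `num_z_interp` redshifts up to `z_interp_stop`), and the angular diameter distances are derived from this interpolation.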
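To make "interpolated distances" concrete, the sketch below tabulates the comoving distance of a hypothetical flat ΛCDM cosmology (H0 = 70 km/s/Mpc, Ωm = 0.3) on a redshift grid and interpolates it with `np.interp`. The class `InterpDistances` is invented for illustration. It is not lenstronomy's `CosmoInterp`. Only its method names `T_xy` and `d_xy` mirror the `Background` calls used in the code below. It also handles only the flat case, where T(z1, z2) = T(0, z2) - T(0, z1).

```python
import numpy as np

C_KM_S = 299792.458  # speed of light in km/s


def comoving_distance_grid(z_max, num, h0=70.0, om=0.3):
    """Line-of-sight comoving distance (Mpc) on a redshift grid, flat LambdaCDM."""
    z = np.linspace(0.0, z_max, num)
    inv_e = 1.0 / np.sqrt(om * (1.0 + z) ** 3 + (1.0 - om))
    # cumulative trapezoidal integral of 1 / E(z)
    steps = 0.5 * (inv_e[1:] + inv_e[:-1]) * np.diff(z)
    return z, C_KM_S / h0 * np.concatenate(([0.0], np.cumsum(steps)))


class InterpDistances:
    """Hypothetical stand-in for interpolated distances (flat universe only)."""

    def __init__(self, z_stop, num_interp=200):
        self._z, self._T = comoving_distance_grid(z_stop, num_interp)

    def T_xy(self, z1, z2):
        # transverse comoving distance from z1 to z2; in a flat universe it is a difference
        return np.interp(z2, self._z, self._T) - np.interp(z1, self._z, self._T)

    def d_xy(self, z1, z2):
        # angular diameter distance from z1 to z2, derived from the interpolated T
        return self.T_xy(z1, z2) / (1.0 + z2)


bkg = InterpDistances(z_stop=3.0)
print(bkg.d_xy(0.0, 1.0))  # roughly 1650 Mpc for these parameters
```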

The code then checks that the largest lens redshift is strictly smaller than `z_source_convention`. A deflector at or behind that redshift would lead to negative reduced deflection angles, so a `ValueError` is raised.

We then check that the number of lens models we have matches the number of redshift values.

If there are no lens models, the ordered list of redshift indices (`_sorted_redshift_index`) is empty.

Otherwise, `_index_ordering` returns the indices of `lens_redshift_list` sorted by increasing redshift.

Next, the code loops over the lens planes in order of increasing redshift and precomputes the distances needed later. The starting redshift `z_before` and the starting value of `T_z` are both set to zero. Note that:

- `T_z` is the transverse comoving distance from the observer to redshift z.
- `T_ij` (`delta_T` in the code, stored in `self._T_ij_list`) is the transverse comoving distance from the previous lens plane to the current one.

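The code also computes `_reduced2physical_factor`. This is the ratio D_s / D_ds for each lens plane (in sorted order), and it converts reduced deflection angles into physical ones. It does not enter the delay formulas directly. However, `_add_deflection` uses it, so it affects the ray positions at which the delays are evaluated.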

The code then iterates through the sorted indices. If the current lens has the same redshift as the previous one, the transverse comoving distance between them is 0 and `T_z` keeps its previous value. Otherwise, `T_z` and `T_ij` are computed for this redshift. In both cases the two values are appended to `_T_z_list` and `_T_ij_list`, and `z_before` is set to the current redshift before moving on to the next index.
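The same precomputation can be written as a standalone function. This is a sketch, not lenstronomy's code. `build_plane_lists` is a hypothetical name. A stable `np.argsort` stands in for `_index_ordering`, which is an assumption about its behaviour. `T_xy` is any callable that returns the transverse comoving distance between two redshifts.

```python
import numpy as np


def build_plane_lists(lens_redshift_list, T_xy):
    """Sketch of the distance precomputation done when multi_plane_base is initialised."""
    # assumption: a stable argsort reproduces the ordering of _index_ordering
    sorted_index = [int(k) for k in np.argsort(lens_redshift_list, kind="stable")]
    z_before, T_z = 0.0, 0.0
    T_ij_list, T_z_list = [], []
    for idex in sorted_index:
        z_lens = lens_redshift_list[idex]
        if z_before == z_lens:
            delta_T = 0.0  # same plane as the previous lens: T_z is unchanged
        else:
            T_z = T_xy(0.0, z_lens)
            delta_T = T_xy(z_before, z_lens)
        T_ij_list.append(delta_T)
        T_z_list.append(T_z)
        z_before = z_lens
    return sorted_index, T_ij_list, T_z_list


# toy, hypothetical distance law: T(z1, z2) = 1000 Mpc * (z2 - z1)
idx, T_ij, T_z = build_plane_lists([0.5, 0.3, 0.5], lambda z1, z2: 1000.0 * (z2 - z1))
print(idx, T_ij, T_z)
```

With these inputs the sorted order is `[1, 0, 2]`. `T_ij` is about `[300, 200, 0]` because the third lens shares the second plane, and `T_z` is about `[300, 500, 500]`.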

## Calling Arrival Time

If the `arrival_time` function is called, it essentially just calls `geo_shapiro_delay`:

`dt_geo, dt_grav = self.geo_shapiro_delay(theta_x, theta_y, kwargs_lens, check_convention=check_convention)`

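This returns `dt_geo` and `dt_grav`, which are added together to give the arrival time. `geo_shapiro_delay` in `multi_plane` does only two things. First, it applies the lens-position convention (when `check_convention` is True and observed positions are not ignored). Second, it delegates to `geo_shapiro_delay` in `multi_plane_base`, passing the source redshift as `z_stop` together with the precomputed `T_z_stop` and `T_ij_end`.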

## The Shapiro Time Delay

### the setup

`geo_shapiro_delay` in `multi_plane_base` works as follows. First, the arrays are set up. They have the same shape as `theta_x`, with one entry per image position (i.e. per ray being traced). Their values are updated at each lens plane:

```python
dt_grav = np.zeros_like(theta_x, dtype=float)
dt_geo = np.zeros_like(theta_x, dtype=float)
x = np.zeros_like(theta_x, dtype=float)
y = np.zeros_like(theta_y, dtype=float)
alpha_x = np.array(theta_x, dtype=float)
alpha_y = np.array(theta_y, dtype=float)
```

`dt_grav` is the total gravitational (Shapiro) time delay, and `dt_geo` is the total geometric time delay. At each lens plane, the contribution from that part of the light path is computed and added to the running totals.

`x` and `y` are the transverse comoving positions at which the light ray crosses the current lens plane. They start at zero because every ray is traced backwards from the observer. They are then updated plane by plane using the ray direction.

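`alpha_x` and `alpha_y` hold the current propagation angle of the ray, not the deflection of a single lens. Multiplying them by a transverse comoving distance gives the change in position between two planes. They are initialised to `theta_x` and `theta_y` because, before reaching the first lens, the ray travels in a straight line from the observer in the observed direction θ. Each lens then changes this direction through `_add_deflection`.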
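The choice of initial values can be checked numerically. In the sketch below (plain NumPy), the image angles and the distance `T_1` to the first lens plane are made-up values, and `ray_step` is a local re-implementation of the `_ray_step` formula. One straight-line step from the observer with `alpha = theta` puts the ray at `x = theta * T_1`. Its angular position on the first plane, `x / T_1`, is therefore exactly the observed angle.

```python
import numpy as np


def ray_step(x, y, alpha_x, alpha_y, delta_T):
    # straight-line propagation in the small-angle approximation, as in _ray_step
    return x + alpha_x * delta_T, y + alpha_y * delta_T


theta_x = np.array([1.2, -0.4])  # observed image angles (arcsec), made-up values
theta_y = np.array([0.3, 0.9])
x = np.zeros_like(theta_x)       # every ray starts at the observer
y = np.zeros_like(theta_y)
alpha_x, alpha_y = theta_x.copy(), theta_y.copy()  # initial direction = observed angle

T_1 = 1200.0  # hypothetical transverse comoving distance to the first lens plane (Mpc)
x1, y1 = ray_step(x, y, alpha_x, alpha_y, T_1)
print(x1 / T_1, y1 / T_1)  # identical to theta_x, theta_y
```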

We then iterate through the lens planes, starting with the first (lowest-redshift) lens plane, `i = 0`. The counter `i` is also set to 0 before the loop starts. We also set

`z_lens_last = 0`

This is the redshift of the previous lens plane, and it is updated at each iteration.

The `for` loop then runs over `enumerate(self._sorted_redshift_index)`. Here `i = 0, 1, 2, ...` counts the planes in order of increasing redshift, while `index` is the position of the corresponding lens in the original, unsorted `lens_redshift_list`.

### the for loop

1. `z_lens = self._lens_redshift_list[index]`, i.e. `z_lens` is defined to be the redshift of that lens

2. `if z_lens <= z_stop:`, i.e. lenses behind the redshift at which we stop (that of the source) are skipped.

3. `T_ij = self._T_ij_list[i]`. Recall that `_T_ij_list`, computed when `multi_plane_base` was initialised, is already in order of increasing redshift, so it is indexed with `i` rather than `index`. `T_ij` is the transverse comoving distance from the previous plane to the current one, and it is overwritten at every step of the loop.

4. `x_new, y_new = self._ray_step(x, y, alpha_x, alpha_y, T_ij)`. This uses `_ray_step` from `multi_plane_base` (see below) to propagate the ray in a straight line from the previous plane to the current one, giving its position in the current lens plane.

5. `if i == 0: pass`. No geometric delay is computed for the first lens plane. Between the observer and the first lens the ray travels in a straight line at the observed angle θ, so its angular position does not change and there is no excess path length. In addition, `T_i` would be zero at the observer, so `x / T_i` would be undefined.

6. `elif T_ij > 0:` the geometric delay is only computed if this lens plane is not at the same redshift as the previous one. In that case:

- `T_j = self._T_z_list[i]`: `T_j` is the transverse comoving distance to the current redshift.

- `T_i = self._T_z_list[i - 1]`: `T_i` is the transverse comoving distance to the previous redshift.

- `beta_i_x, beta_i_y = x / T_i, y / T_i`: the angular position of the ray in the previous plane is its transverse comoving position divided by the transverse comoving distance to that plane (small-angle approximation).

- `beta_j_x, beta_j_y = x_new / T_j, y_new / T_j`: exactly as above, but for the current plane.

- `dt_geo_new = self._geometrical_delay(beta_i_x, beta_i_y, beta_j_x, beta_j_y, T_i, T_j, T_ij)`: the geometrical delay between the previous plane and this plane is calculated (see below) and added to the total with `dt_geo += dt_geo_new`.


7. `x, y = x_new, y_new`: the ray's position in the current plane becomes the "previous" position for the next step of the loop.

8. `dt_grav_new = self._gravitational_delay(x, y, kwargs_lens, i, z_lens)`: the gravitational delay is evaluated at the ray's position in the current plane. It does not require `T_ij > 0`, because multiple lenses in the same plane each contribute their own potential. See below for how it is calculated.

9. `alpha_x, alpha_y = self._add_deflection(x, y, alpha_x, alpha_y, kwargs_lens, i)`: the ray direction is updated by subtracting the physical deflection angle caused by this specific lens (see below).

10. `dt_grav += dt_grav_new` the total gravitational time delay is updated with the contribution determined above

11. `z_lens_last = z_lens` the current redshift becomes the previous redshift for the next iteration

### finishing up - the final calculation

The loop above only goes up to the final lens, not all the way to the source plane. We therefore need to add the geometric delay for the last leg, from the final lens plane to the source. The source plane contributes no gravitational delay.

Firstly, if the transverse comoving distance between the last lens and the source plane hasn't been specified, we calculate this:

```python
if T_ij_end is None:
    T_ij_end = self._cosmo_bkg.T_xy(z_lens_last, z_stop)
T_ij = T_ij_end
```

The ray's transverse comoving position in the source plane is then calculated as before, using the distance computed above:

`x_new, y_new = self._ray_step(x, y, alpha_x, alpha_y, T_ij)`

We perform a similar calculation for `T_z_stop` as for `T_ij_end` if its value hasn't been specified

```python
if T_z_stop is None:
    T_z_stop = self._cosmo_bkg.T_xy(0, z_stop)
T_j = T_z_stop
```

The transverse comoving distance to the previous redshift is that of the final lens. Since the value of i is the last value from the for loop, we use 

`T_i = self._T_z_list[i]`

The angular positions in the last lens plane and in the source plane are, respectively,

```python
beta_i_x, beta_i_y = x / T_i, y / T_i
beta_j_x, beta_j_y = x_new / T_j, y_new / T_j
```

The geometric time delay is then calculated and updated as

```python
dt_geo_new = self._geometrical_delay(beta_i_x, beta_i_y, beta_j_x, beta_j_y, T_i, T_j, T_ij)
dt_geo += dt_geo_new
```

And then the geometric and gravitational time delays are returned

`return dt_geo, dt_grav`
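The whole procedure can be collected in one place in a compact NumPy re-implementation of the loop and the final step. This is a sketch under stated assumptions, not lenstronomy's code. It assumes a flat ΛCDM cosmology (H0 = 70 km/s/Mpc, Ωm = 0.3), so that transverse comoving distances are additive. The lenses are singular isothermal spheres (SIS) given as hypothetical `(potential, deflection)` callables in arcsec units, and `z_source_convention` is taken equal to the source redshift. All function names are invented for illustration.

```python
import numpy as np

MPC_M = 3.0856775814913673e22            # metres per Mpc
C_M_S = 299792458.0                      # speed of light in m/s
DAY_S = 86400.0                          # seconds per day
ARCSEC = np.pi / 180.0 / 3600.0          # radians per arcsec
K_DAYS = MPC_M / C_M_S / DAY_S * ARCSEC ** 2  # (Mpc * arcsec^2) -> days


def comoving(z, h0=70.0, om=0.3, n=4000):
    """Comoving distance (Mpc) to redshift z in a flat LambdaCDM cosmology."""
    zz = np.linspace(0.0, z, n)
    inv_e = 1.0 / np.sqrt(om * (1.0 + zz) ** 3 + 1.0 - om)
    return 299792.458 / h0 * np.sum(0.5 * (inv_e[1:] + inv_e[:-1]) * np.diff(zz))


def T_xy(z1, z2):
    # flat universe: transverse comoving distances are additive
    return comoving(z2) - comoving(z1)


def sis(theta_e):
    """Hypothetical SIS lens: potential (arcsec^2) and reduced deflection (arcsec)."""
    def potential(tx, ty):
        return theta_e * np.hypot(tx, ty)

    def deflection(tx, ty):
        r = np.hypot(tx, ty)
        return theta_e * tx / r, theta_e * ty / r

    return potential, deflection


def geo_shapiro_delay(theta_x, theta_y, lenses, z_source):
    """lenses: non-empty list of (z, potential, deflection).
    Returns dt_geo, dt_grav (days) and the source-plane angles."""
    lenses = sorted(lenses, key=lambda lens: lens[0])
    x = np.zeros_like(theta_x, dtype=float)
    y = np.zeros_like(theta_y, dtype=float)
    alpha_x = np.array(theta_x, dtype=float)  # ray direction starts at the observed angle
    alpha_y = np.array(theta_y, dtype=float)
    dt_geo = np.zeros_like(x)
    dt_grav = np.zeros_like(x)
    T_s = T_xy(0.0, z_source)
    z_last, T_last = 0.0, 0.0
    for i, (z, potential, deflection) in enumerate(lenses):
        T_ij, T_j = T_xy(z_last, z), T_xy(0.0, z)
        x_new, y_new = x + alpha_x * T_ij, y + alpha_y * T_ij  # straight step (_ray_step)
        if i > 0 and T_ij > 0:  # geometric delay between the previous plane and this one
            tau = T_last * T_j / T_ij * K_DAYS
            dt_geo += tau * ((x_new / T_j - x / T_last) ** 2 + (y_new / T_j - y / T_last) ** 2) / 2
        x, y = x_new, y_new
        tx, ty = x / T_j, y / T_j  # comoving position -> angle (_co_moving2angle)
        T_ds = T_xy(z, z_source)
        # Shapiro term, using D_dt = T_d * T_s / T_ds in a flat universe
        dt_grav -= T_j * T_s / T_ds * K_DAYS * potential(tx, ty)
        ax, ay = deflection(tx, ty)
        alpha_x = alpha_x - ax * T_s / T_ds  # reduced -> physical deflection, then subtract
        alpha_y = alpha_y - ay * T_s / T_ds
        z_last, T_last = z, T_j
    # final leg: last lens plane -> source plane
    T_ij = T_xy(z_last, z_source)
    x_new, y_new = x + alpha_x * T_ij, y + alpha_y * T_ij
    tau = T_last * T_s / T_ij * K_DAYS
    dt_geo += tau * ((x_new / T_s - x / T_last) ** 2 + (y_new / T_s - y / T_last) ** 2) / 2
    return dt_geo, dt_grav, x_new / T_s, y_new / T_s


theta_x, theta_y = np.array([1.3, -0.6]), np.array([0.4, 0.2])
lens = (0.5,) + sis(1.0)
dt_geo, dt_grav, beta_x, beta_y = geo_shapiro_delay(theta_x, theta_y, [lens], z_source=2.0)
print(dt_geo + dt_grav)  # arrival times in days
```

With a single lens plane, this reproduces the single-plane result: `dt_geo = D_dt/c (θ - β)²/2` and `dt_grav = -D_dt/c ψ(θ)`. Adding a massless plane between the lens and the source leaves both delays unchanged. This is because splitting a straight segment of the path into two pieces does not change its excess travel time.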

## Gravitational and Geometric Delays

### the gravitational delay

The gravitational time delay is called as

`self._gravitational_delay(x, y, kwargs_lens, i, z_lens)`

First, x and y are converted into angles with `_co_moving2angle`, which divides them by the transverse comoving distance to the plane:

`theta_x, theta_y = self._co_moving2angle(x, y, index)`

The argument passed as `i` (called `index` inside the function) numbers the planes in order of increasing redshift. To get the index of the lens in the original, unsorted lists, we use

`k = self._sorted_redshift_index[index]`

Then the potential is simply the potential of this lens, evaluated at the angular positions determined above

`potential = self.func_list[k].function(theta_x, theta_y, **kwargs_lens[k])`

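To convert this potential (in arcsec²) into a time delay in days, `_lensing_potential2time_delay` multiplies it by the time-delay distance D_dt = (1 + z_lens) D_d D_s / D_ds. This distance is computed with the source redshift convention (`ddt`), and the conversion itself is done by `const.delay_arcsec2days`: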

`delay_days = self._lensing_potential2time_delay(potential, z_lens, z_source=self._z_source_convention)`

The Shapiro term enters the Fermat potential with a minus sign, t(θ) = D_dt / c [ (θ - β)² / 2 - ψ(θ) ], so the function returns the negative of this delay:

`return -delay_days`
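The conversion from potential to days can be written out explicitly. This sketch assumes that `const.delay_arcsec2days(potential, D_dt)` applies the factor D_dt · Mpc / c / day · arcsec², which `_geometrical_delay` writes out explicitly below. The helper names and the numbers in the example are hypothetical.

```python
import math

MPC_M = 3.0856775814913673e22       # metres per Mpc
C_M_S = 299792458.0                 # speed of light in m/s
DAY_S = 86400.0                     # seconds per day
ARCSEC = math.pi / 180.0 / 3600.0   # radians per arcsec


def potential_to_days(potential_arcsec2, d_dt_mpc):
    """Hypothetical helper: D_dt / c * potential, with the potential in arcsec^2 and D_dt in Mpc."""
    return d_dt_mpc * MPC_M / C_M_S / DAY_S * ARCSEC ** 2 * potential_arcsec2


def gravitational_delay(potential_arcsec2, d_dt_mpc):
    # the Shapiro term enters the Fermat potential with a minus sign
    return -potential_to_days(potential_arcsec2, d_dt_mpc)


print(gravitational_delay(1.0, 3000.0))  # about -84 days
```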

### the geometrical delay

The geometrical delay is called as

`self._geometrical_delay(beta_i_x, beta_i_y, beta_j_x, beta_j_y, T_i, T_j, T_ij)`

Firstly, we need the difference in angular positions

```python
d_beta_x = beta_j_x - beta_i_x
d_beta_y = beta_j_y - beta_i_y
```

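The time-delay distance between planes i and j, T_i T_j / T_ij (in Mpc), is converted into a time scale in days per arcsec². The factor `const.Mpc / const.c / const.day_s` turns Mpc into days, and `const.arcsec**2` turns arcsec² into rad²: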

`tau_ij = T_i * T_j / T_ij * const.Mpc / const.c / const.day_s * const.arcsec**2`

Then the geometric time delay is very simply calculated via

`return tau_ij * (d_beta_x ** 2 + d_beta_y ** 2) / 2`
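The combination `T_i * T_j / T_ij` is a time-delay distance. To see this, assume a flat universe, where angular diameter distances relate to transverse comoving distances by D_i = T_i / (1 + z_i) and D_ij = T_ij / (1 + z_j). The comoving combination then reduces to the familiar one. In the single-plane case (plane i = deflector d, plane j = source s, β_i = θ, β_j = β), the delay recovers the geometric term of the Fermat potential.

$$
\frac{T_i T_j}{T_{ij}} = \frac{(1+z_i)\,D_i\,(1+z_j)\,D_j}{(1+z_j)\,D_{ij}} = (1+z_i)\,\frac{D_i D_j}{D_{ij}}
$$

$$
\Delta t^{\rm geo}_{ij} = \frac{1}{c}\,\frac{T_i T_j}{T_{ij}}\,\frac{|\boldsymbol{\beta}_j - \boldsymbol{\beta}_i|^2}{2},
\qquad
\Delta t^{\rm geo}_{ds} = \frac{D_{\Delta t}}{c}\,\frac{|\boldsymbol{\theta} - \boldsymbol{\beta}|^2}{2}
$$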

## Other relevant functions

### _ray_step

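This propagates the ray in a straight line (small-angle approximation) over a transverse comoving distance `delta_T`. Combined with the deflections added at each plane, it plays the role of the lens equation in comoving coordinates. It is a static method, defined as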

`def _ray_step(x, y, alpha_x, alpha_y, delta_T):`

and the updated positions are returned as

```python
x_ = x + alpha_x * delta_T
y_ = y + alpha_y * delta_T
return x_, y_
```

### _add_deflection

This updates the ray's propagation direction by subtracting the physical deflection angle of a single lens (in the backwards ray-tracing picture):

`def _add_deflection(self, x, y, alpha_x, alpha_y, kwargs_lens, index):`

Firstly, the physical coordinates are converted to angular coordinates via the `_co_moving2angle` function (see below)

`theta_x, theta_y = self._co_moving2angle(x, y, index)`

As before, to get the actual index of the relevant lens, we use

`k = self._sorted_redshift_index[index]`

The reduced deflection is obtained from the derivatives of the relevant lens's potential via

`alpha_x_red, alpha_y_red = self.func_list[k].derivatives(theta_x, theta_y, **kwargs_lens[k])`

and these are then converted to physical deflection angles via (see below)

```python
alpha_x_phys = self._reduced2physical_deflection(alpha_x_red, index)
alpha_y_phys = self._reduced2physical_deflection(alpha_y_red, index)
```

The angles returned are

`return alpha_x - alpha_x_phys, alpha_y - alpha_y_phys`
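A small numerical check clarifies the sign and the distance factor. The sketch below traces one ray through a single plane, using hypothetical function names, plain Python, a flat universe (so that T_ds = T_s - T_d), and made-up distances. It first steps from the observer to the deflector with direction θ. It then subtracts the physical deflection (D_s / D_ds) α_red, which in comoving distances is (T_s / T_ds) α_red because the (1 + z_s) factors cancel. Finally it steps on to the source. The resulting angular source position is θ - α_red, i.e. the usual lens equation.

```python
def add_deflection(alpha, alpha_reduced, T_s, T_ds):
    # physical deflection = reduced deflection * D_s / D_ds = reduced * T_s / T_ds (flat universe)
    return alpha - alpha_reduced * T_s / T_ds


def trace_single_plane(theta, alpha_reduced, T_d, T_s):
    """Backwards ray trace through one lens plane; returns the angular source position."""
    T_ds = T_s - T_d                       # additive comoving distances (flat universe)
    x_d = theta * T_d                      # step observer -> deflector with direction theta
    alpha = add_deflection(theta, alpha_reduced, T_s, T_ds)
    x_s = x_d + alpha * T_ds               # step deflector -> source
    return x_s / T_s                       # beta


print(trace_single_plane(1.5, 0.8, 1200.0, 3000.0))  # 0.7 up to rounding, i.e. 1.5 - 0.8
```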

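### _co_moving2angle

This converts transverse comoving positions into angles on the sky. It divides them by the transverse comoving distance from the observer to plane `index`, `T_z = self._T_z_list[index]`, so that `theta_x = x / T_z` and `theta_y = y / T_z`.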

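### _reduced2physical_deflection

This converts a reduced deflection angle into a physical one by multiplying it by the precomputed factor `self._reduced2physical_factor[index_lens]`, i.e. D_s / D_ds for that plane. This inverts `alpha_reduced = D_ds / D_s * alpha_physical`.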

In [None]:
### in multi_plane 
    
    def arrival_time(self, theta_x, theta_y, kwargs_lens, check_convention=True):
        """
        light travel time relative to a straight path through the coordinate (0,0)
        Negative sign means earlier arrival time

        :param theta_x: angle in x-direction on the image
        :param theta_y: angle in y-direction on the image
        :param kwargs_lens: lens model keyword argument list
        :return: travel time in units of days
        """
        dt_geo, dt_grav = self.geo_shapiro_delay(theta_x, theta_y, kwargs_lens, check_convention=check_convention)
        return dt_geo + dt_grav

    def geo_shapiro_delay(self, theta_x, theta_y, kwargs_lens, check_convention=True):
        """
        geometric and Shapiro (gravitational) light travel time relative to a straight path through the coordinate (0,0)
        Negative sign means earlier arrival time

        :param theta_x: angle in x-direction on the image
        :param theta_y: angle in y-direction on the image
        :param kwargs_lens: lens model keyword argument list
        :param check_convention: boolean, if True goes through the lens model list and checks whether the positional
         conventions are satisfied.
        :return: geometric delay, gravitational delay [days]
        """
        if check_convention and not self.ignore_observed_positions:
            kwargs_lens = self._convention(kwargs_lens)
        return self._multi_plane_base.geo_shapiro_delay(theta_x, theta_y, kwargs_lens, z_stop=self._z_source,
                                                   T_z_stop=self._T_z_source, T_ij_end=self._T_ij_stop)

In [None]:
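### in multi_plane_base (excerpt; assumes numpy as np, lenstronomy.Util.constants as const and Background from lenstronomy.Cosmo.background are imported)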

        if z_interp_stop is None:
            z_interp_stop = z_source_convention
        self._cosmo_bkg = Background(cosmo, interp=cosmo_interp, z_stop=z_interp_stop, num_interp=num_z_interp)
        self._z_source_convention = z_source_convention
        if len(lens_redshift_list) > 0:
            z_lens_max = np.max(lens_redshift_list)
            if z_lens_max >= z_source_convention:
                raise ValueError('deflector redshifts higher or equal the source redshift convention (%s >= %s for the reduced lens'
                                 ' model quantities not allowed (leads to negative reduced deflection angles!'
                                 % (z_lens_max, z_source_convention))
        if not len(lens_model_list) == len(lens_redshift_list):
            raise ValueError("The length of lens_model_list does not correspond to redshift_list")

        self._lens_redshift_list = lens_redshift_list
        super(MultiPlaneBase, self).__init__(lens_model_list, numerical_alpha_class=numerical_alpha_class,
                                             lens_redshift_list=lens_redshift_list,
                                             z_source_convention=z_source_convention, kwargs_interp=kwargs_interp)

        if len(lens_model_list) < 1:
            self._sorted_redshift_index = []
        else:
            self._sorted_redshift_index = self._index_ordering(lens_redshift_list)
        z_before = 0
        T_z = 0
        self._T_ij_list = []
        self._T_z_list = []
        # Sort redshift for vectorized reduced2physical factor calculation
        if len(lens_model_list)<1:
            self._reduced2physical_factor = []
        else:
            z_sort = np.array(self._lens_redshift_list)[self._sorted_redshift_index]
            z_source_array = np.ones(z_sort.shape)*z_source_convention
            self._reduced2physical_factor = self._cosmo_bkg.d_xy(0, z_source_convention) / self._cosmo_bkg.d_xy(z_sort, z_source_array)
        for idex in self._sorted_redshift_index:
            z_lens = self._lens_redshift_list[idex]
            if z_before == z_lens:
                delta_T = 0
            else:
                T_z = self._cosmo_bkg.T_xy(0, z_lens)
                delta_T = self._cosmo_bkg.T_xy(z_before, z_lens)
            self._T_ij_list.append(delta_T)
            self._T_z_list.append(T_z)
            z_before = z_lens

    def geo_shapiro_delay(self, theta_x, theta_y, kwargs_lens, z_stop, T_z_stop=None, T_ij_end=None):
        """
        geometric and Shapiro (gravitational) light travel time relative to a straight path through the coordinate (0,0)
        Negative sign means earlier arrival time

        :param theta_x: angle in x-direction on the image
        :param theta_y: angle in y-direction on the image
        :param kwargs_lens: lens model keyword argument list
        :param z_stop: redshift of the source to stop the backwards ray-tracing
        :param T_z_stop: optional, transversal angular distance from z=0 to z_stop
        :param T_ij_end: optional, transversal angular distance between the last lensing plane and the source plane
        :return: dt_geo, dt_shapiro, [days]
        """
        dt_grav = np.zeros_like(theta_x, dtype=float)
        dt_geo = np.zeros_like(theta_x, dtype=float)
        x = np.zeros_like(theta_x, dtype=float)
        y = np.zeros_like(theta_y, dtype=float)
        alpha_x = np.array(theta_x, dtype=float)
        alpha_y = np.array(theta_y, dtype=float)
        i = 0
        z_lens_last = 0
        for i, index in enumerate(self._sorted_redshift_index):
            z_lens = self._lens_redshift_list[index]
            if z_lens <= z_stop:
                T_ij = self._T_ij_list[i]
                x_new, y_new = self._ray_step(x, y, alpha_x, alpha_y, T_ij)
                if i == 0:
                    pass
                elif T_ij > 0:
                    T_j = self._T_z_list[i]
                    T_i = self._T_z_list[i - 1]
                    beta_i_x, beta_i_y = x / T_i, y / T_i
                    beta_j_x, beta_j_y = x_new / T_j, y_new / T_j
                    dt_geo_new = self._geometrical_delay(beta_i_x, beta_i_y, beta_j_x, beta_j_y, T_i, T_j, T_ij)
                    dt_geo += dt_geo_new
                x, y = x_new, y_new
                dt_grav_new = self._gravitational_delay(x, y, kwargs_lens, i, z_lens)
                alpha_x, alpha_y = self._add_deflection(x, y, alpha_x, alpha_y, kwargs_lens, i)

                dt_grav += dt_grav_new
                z_lens_last = z_lens
        if T_ij_end is None:
            T_ij_end = self._cosmo_bkg.T_xy(z_lens_last, z_stop)
        T_ij = T_ij_end
        x_new, y_new = self._ray_step(x, y, alpha_x, alpha_y, T_ij)
        if T_z_stop is None:
            T_z_stop = self._cosmo_bkg.T_xy(0, z_stop)
        T_j = T_z_stop
        T_i = self._T_z_list[i]
        beta_i_x, beta_i_y = x / T_i, y / T_i
        beta_j_x, beta_j_y = x_new / T_j, y_new / T_j
        dt_geo_new = self._geometrical_delay(beta_i_x, beta_i_y, beta_j_x, beta_j_y, T_i, T_j, T_ij)
        dt_geo += dt_geo_new
        return dt_geo, dt_grav

    def _gravitational_delay(self, x, y, kwargs_lens, index, z_lens):
        """

        :param x: co-moving coordinate at the lens plane
        :param y: co-moving coordinate at the lens plane
        :param kwargs_lens: lens model keyword arguments
        :param z_lens: redshift of the deflector
        :param index: index of the lens model in sorted redshift convention
        :return: gravitational delay in units of days as seen at z=0
        """
        theta_x, theta_y = self._co_moving2angle(x, y, index)
        k = self._sorted_redshift_index[index]
        potential = self.func_list[k].function(theta_x, theta_y, **kwargs_lens[k])
        delay_days = self._lensing_potential2time_delay(potential, z_lens, z_source=self._z_source_convention)
        return -delay_days

    @staticmethod
    def _geometrical_delay(beta_i_x, beta_i_y, beta_j_x, beta_j_y, T_i, T_j, T_ij):
        """

        :param beta_i_x: angle on the sky at plane i
        :param beta_i_y: angle on the sky at plane i
        :param beta_j_x: angle on the sky at plane j
        :param beta_j_y: angle on the sky at plane j
        :param T_i: transverse diameter distance to z_i
        :param T_j: transverse diameter distance to z_j
        :param T_ij: transverse diameter distance from z_i to z_j
        :return: excess delay relative to a straight line
        """
        d_beta_x = beta_j_x - beta_i_x
        d_beta_y = beta_j_y - beta_i_y
        tau_ij = T_i * T_j / T_ij * const.Mpc / const.c / const.day_s * const.arcsec**2
        return tau_ij * (d_beta_x ** 2 + d_beta_y ** 2) / 2
    

    @staticmethod
    def _ray_step(x, y, alpha_x, alpha_y, delta_T):
        """
        ray propagation with small angle approximation

        :param x: co-moving x-position
        :param y: co-moving y-position
        :param alpha_x: deflection angle in x-direction at (x, y)
        :param alpha_y: deflection angle in y-direction at (x, y)
        :param delta_T: transverse angular diameter distance to the next step
        :return: co-moving position at the next step (backwards)
        """
        x_ = x + alpha_x * delta_T
        y_ = y + alpha_y * delta_T
        return x_, y_
    
    def _add_deflection(self, x, y, alpha_x, alpha_y, kwargs_lens, index):
        """
        adds the physical deflection angle of a single lens plane to the deflection field

        :param x: co-moving distance at the deflector plane
        :param y: co-moving distance at the deflector plane
        :param alpha_x: physical angle (radian) before the deflector plane
        :param alpha_y: physical angle (radian) before the deflector plane
        :param kwargs_lens: lens model parameter kwargs
        :param index: index of the lens model to be added in sorted redshift list convention
        :return: updated physical deflection after deflector plane (in a backwards ray-tracing perspective)
        """
        theta_x, theta_y = self._co_moving2angle(x, y, index)
        k = self._sorted_redshift_index[index]
        alpha_x_red, alpha_y_red = self.func_list[k].derivatives(theta_x, theta_y, **kwargs_lens[k])
        alpha_x_phys = self._reduced2physical_deflection(alpha_x_red, index)
        alpha_y_phys = self._reduced2physical_deflection(alpha_y_red, index)
        return alpha_x - alpha_x_phys, alpha_y - alpha_y_phys
    
    def _co_moving2angle(self, x, y, index):
        """
        transforms co-moving distances Mpc into angles on the sky (radian)

        :param x: co-moving distance
        :param y: co-moving distance
        :param index: index of plane
        :return: angles on the sky
        """
        T_z = self._T_z_list[index]
        theta_x = x / T_z
        theta_y = y / T_z
        return theta_x, theta_y
    
    def _reduced2physical_deflection(self, alpha_reduced, index_lens):
        """
        alpha_reduced = D_ds/Ds alpha_physical

        :param alpha_reduced: reduced deflection angle
        :param index_lens: integer, index of the deflector plane
        :return: physical deflection angle
        """
        factor = self._reduced2physical_factor[index_lens]
        return alpha_reduced * factor
    
    def _lensing_potential2time_delay(self, potential, z_lens, z_source):
        """
        transforms the lensing potential (in units arcsec^2) to a gravitational time-delay as measured at z=0

        :param potential: lensing potential
        :param z_lens: redshift of the deflector
        :param z_source: redshift of source for the definition of the lensing quantities
        :return: gravitational time-delay in units of days
        """
        D_dt = self._cosmo_bkg.ddt(z_lens, z_source)
        delay_days = const.delay_arcsec2days(potential, D_dt)
        return delay_days    
    
    # ddt belongs to the cosmology Background class (called above via self._cosmo_bkg.ddt)
    def ddt(self, z_lens, z_source):
        """
        time-delay distance

        :param z_lens: redshift of lens
        :param z_source: redshift of source
        :return: time-delay distance in units of proper Mpc
        """
        return self.d_xy(0, z_lens) * self.d_xy(0, z_source) / self.d_xy(z_lens, z_source) * (1 + z_lens)    
