Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 204 additions & 7 deletions autogalaxy/operate/deflections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from functools import wraps
from functools import wraps, partial
import logging
from autofit.jax_wrapper import numpy as np, use_jax
from autofit.jax_wrapper import numpy as np, use_jax, jit
from typing import List, Tuple, Union

if use_jax:
Expand Down Expand Up @@ -91,6 +91,26 @@ def wrapper(
return wrapper


def one_step(r, _, theta, fun, fun_dr):
r = np.abs(r - fun(r, theta) / fun_dr(r, theta))
return r, None


@partial(jit, static_argnums=(4,))
def step_r(r, theta, fun, fun_dr, N=20):
one_step_partial = jax.tree_util.Partial(
one_step,
theta=theta,
fun=fun,
fun_dr=fun_dr
)
new_r = jax.lax.scan(one_step_partial, r, xs=np.arange(N))[0]
return np.stack([
new_r * np.sin(theta),
new_r * np.cos(theta)
]).T


class OperateDeflections:
"""
Packages methods which manipulate the 2D deflection angle map returned from the `deflections_yx_2d_from` function
Expand Down Expand Up @@ -126,6 +146,9 @@ def deflections_yx_scalar(self, y, x, pixel_scales):

def __eq__(self, other):
return self.__dict__ == other.__dict__ and self.__class__ is other.__class__

def __hash__(self):
return hash(repr(self))

@precompute_jacobian
def tangential_eigen_value_from(self, grid, jacobian=None) -> aa.Array2D:
Expand Down Expand Up @@ -748,6 +771,162 @@ def jacobian_stack_vector(self, y, x, pixel_scales):
),
signature='(),()->(i,i)'
)(y, x)

def convergence_mag_shear_yx(self, y, x):
J = self.jacobian_stack_vector(y, x, 0.05)
K = 0.5 * (J[..., 0, 0] + J[..., 1, 1])
mag_shear = 0.5 * np.sqrt(
(J[..., 0, 1] + J[..., 1, 0])**2 + (J[..., 0, 0] - J[..., 1, 1])**2
)
return K, mag_shear

@partial(jit, static_argnums=(0,))
def tangential_eigen_value_yx(self, y, x):
K, mag_shear = self.convergence_mag_shear_yx(y, x)
return 1 - K - mag_shear

@partial(jit, static_argnums=(0, 3))
def tangential_eigen_value_rt(self, r, theta, centre=(0.0, 0.0)):
y = r * np.sin(theta) + centre[0]
x = r * np.cos(theta) + centre[1]
return self.tangential_eigen_value_yx(y, x)

@partial(jit, static_argnums=(0, 3))
def grad_r_tangential_eigen_value(self, r, theta, centre=(0.0, 0.0)):
# ignore `self` with the `argnums` below
tangential_eigen_part = partial(
self.tangential_eigen_value_rt,
centre=centre
)
return np.vectorize(
jax.jacfwd(tangential_eigen_part, argnums=(0,)),
signature='(),()->()'
)(r, theta)[0]

@partial(jit, static_argnums=(0,))
def radial_eigen_value_yx(self, y, x):
K, mag_shear = self.convergence_mag_shear_yx(y, x)
return 1 - K + mag_shear

@partial(jit, static_argnums=(0, 3))
def radial_eigen_value_rt(self, r, theta, centre=(0.0, 0.0)):
y = r * np.sin(theta) + centre[0]
x = r * np.cos(theta) + centre[1]
return self.radial_eigen_value_yx(y, x)

@partial(jit, static_argnums=(0, 3))
def grad_r_radial_eigen_value(self, r, theta, centre=(0.0, 0.0)):
# ignore `self` with the `argnums` below
radial_eigen_part = partial(
self.radial_eigen_value_rt,
centre=centre
)
return np.vectorize(
jax.jacfwd(radial_eigen_part, argnums=(0,)),
signature='(),()->()'
)(r, theta)[0]

def tangential_critical_curve_jax(
self,
init_r=0.1,
init_centre=(0.0, 0.0),
n_points=300,
n_steps=20,
threshold=1e-5
):
"""
Returns all tangential critical curves of the lensing system, which are computed as follows:

1) Create a set of `n_points` initial points in a circle of radius `init_r` and centred on `init_centre`
2) Apply `n_steps` of Newton's method to these points in the "radial" direction only (i.e. keeping angle fixed).
Jax's auto differentiation is used to find the radial derivatives of the tangential eigen value function for
this step.
3) Filter the results and only keep point that have their tangential eigen value `threshold` of 0

No underlying grid is needed for the method, but the quality of the results are dependent on the initial
circle of points.

Parameters
----------
init_r : float
Radius of the circle of initial guess points
init_centre : tuple
centre of the circle of initial guess points as `(y, x)`
n_points : Int
Number of initial guess points to create (evenly spaced in angle around `init_centre`)
n_steps : Int
Number of iterations of Newton's method to apply
threshold : float
Only keep points whose tangential eigen value is within this value of zero (inclusive)
"""
r = np.ones(n_points) * init_r
theta = np.linspace(0, 2 * np.pi, n_points + 1)[:-1]
new_yx = step_r(
r,
theta,
jax.tree_util.Partial(self.tangential_eigen_value_rt, centre=init_centre),
jax.tree_util.Partial(self.grad_r_tangential_eigen_value, centre=init_centre),
n_steps
)
new_yx = new_yx + np.array(init_centre)
# filter out nan values
fdx = np.isfinite(new_yx).all(axis=1)
new_yx = new_yx[fdx]
# filter out failed points
value = np.abs(self.tangential_eigen_value_yx(new_yx[:, 0], new_yx[:, 1]))
gdx = value <= threshold
return aa.structures.grids.irregular_2d.Grid2DIrregular(values=new_yx[gdx])

def radial_critical_curve_jax(
self,
init_r=0.01,
init_centre=(0.0, 0.0),
n_points=300,
n_steps=20,
threshold=1e-5
):
"""
Returns all radial critical curves of the lensing system, which are computed as follows:

1) Create a set of `n_points` initial points in a circle of radius `init_r` and centred on `init_centre`
2) Apply `n_steps` of Newton's method to these points in the "radial" direction only (i.e. keeping angle fixed).
Jax's auto differentiation is used to find the radial derivatives of the radial eigen value function for
this step.
3) Filter the results and only keep point that have their radial eigen value `threshold` of 0

No underlying grid is needed for the method, but the quality of the results are dependent on the initial
circle of points.

Parameters
----------
init_r : float
Radius of the circle of initial guess points
init_centre : tuple
centre of the circle of initial guess points as `(y, x)`
n_points : Int
Number of initial guess points to create (evenly spaced in angle around `init_centre`)
n_steps : Int
Number of iterations of Newton's method to apply
threshold : float
Only keep points whose radial eigen value is within this value of zero (inclusive)
"""
r = np.ones(n_points) * init_r
theta = np.linspace(0, 2 * np.pi, n_points + 1)[:-1]
new_yx = step_r(
r,
theta,
jax.tree_util.Partial(self.radial_eigen_value_rt, centre=init_centre),
jax.tree_util.Partial(self.grad_r_radial_eigen_value, centre=init_centre),
n_steps
)
new_yx = new_yx + np.array(init_centre)
# filter out nan values
fdx = np.isfinite(new_yx).all(axis=1)
new_yx = new_yx[fdx]
# filter out failed points
value = np.abs(self.radial_eigen_value_yx(new_yx[:, 0], new_yx[:, 1]))
gdx = value <= threshold
return aa.structures.grids.irregular_2d.Grid2DIrregular(values=new_yx[gdx])

def jacobian_from(self, grid):
"""
Expand Down Expand Up @@ -798,14 +977,26 @@ def jacobian_from(self, grid):

return [[a11, a12], [a21, a22]]
else:
a = self.jacobian_stack_vector(
A = self.jacobian_stack_vector(
grid.array[:, 0],
grid.array[:, 1],
grid.pixel_scales
)
a = np.eye(2).reshape(1, 2, 2) - A
return [
[
aa.Array2D(values=a[..., 1, 1], mask=grid.mask),
aa.Array2D(values=a[..., 1, 0], mask=grid.mask)
],
[
aa.Array2D(values=a[..., 0, 1], mask=grid.mask),
aa.Array2D(values=a[..., 0, 0], mask=grid.mask)
]
]

# transpose the result
# use `moveaxis` as grid might not be nx2
return np.moveaxis(np.moveaxis(a, -1, 0), -1, 0)
# return np.moveaxis(np.moveaxis(a, -1, 0), -1, 0)

@precompute_jacobian
def convergence_2d_via_jacobian_from(self, grid, jacobian=None) -> aa.Array2D:
Expand Down Expand Up @@ -855,9 +1046,15 @@ def shear_yx_2d_via_jacobian_from(
A precomputed lensing jacobian, which is passed throughout the `CalcLens` functions for efficiency.
"""

shear_yx_2d = np.zeros(shape=(grid.shape_slim, 2))
shear_yx_2d[:, 0] = -0.5 * (jacobian[0][1] + jacobian[1][0])
shear_yx_2d[:, 1] = 0.5 * (jacobian[1][1] - jacobian[0][0])
if not use_jax:
shear_yx_2d = np.zeros(shape=(grid.shape_slim, 2))
shear_yx_2d[:, 0] = -0.5 * (jacobian[0][1] + jacobian[1][0])
shear_yx_2d[:, 1] = 0.5 * (jacobian[1][1] - jacobian[0][0])

else:
shear_y = -0.5 * (jacobian[0][1] + jacobian[1][0]).array
shear_x = 0.5 * (jacobian[1][1] - jacobian[0][0]).array
shear_yx_2d = np.stack([shear_y, shear_x]).T

if isinstance(grid, aa.Grid2DIrregular):
return ShearYX2DIrregular(values=shear_yx_2d, grid=grid)
Expand Down
12 changes: 10 additions & 2 deletions autogalaxy/profiles/geometry_profiles.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@

if os.environ.get("USE_JAX", "0") == "1":
import jax.numpy as np
use_jax = True
else:
import numpy as np
use_jax = False

import autoarray as aa

Expand Down Expand Up @@ -129,7 +131,10 @@ def _cartesian_grid_via_radial_from(
radius
The circular radius of each coordinate from the profile center.
"""
grid_angles = np.arctan2(grid[:, 0], grid[:, 1])
if use_jax:
grid_angles = np.arctan2(grid.array[:, 0], grid.array[:, 1])
else:
grid_angles = np.arctan2(grid[:, 0], grid[:, 1])
cos_theta, sin_theta = self.angle_to_profile_grid_from(grid_angles=grid_angles)
return np.multiply(radius[:, None], np.vstack((sin_theta, cos_theta)).T)

Expand All @@ -145,7 +150,10 @@ def transformed_to_reference_frame_grid_from(self, grid, **kwargs):
grid
The (y, x) coordinates in the original reference frame of the grid.
"""
return np.subtract(grid, self.centre)
if use_jax:
return np.subtract(grid.array, np.array(self.centre))
else:
return np.subtract(grid, self.centre)

@aa.grid_dec.to_grid
def transformed_from_reference_frame_grid_from(self, grid, **kwargs):
Expand Down
32 changes: 22 additions & 10 deletions autogalaxy/profiles/light/standard/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import numpy as np
# import numpy as np
from autofit.jax_wrapper import numpy as np, use_jax
from typing import Optional, Tuple

import autoarray as aa
Expand Down Expand Up @@ -59,15 +60,26 @@ def image_2d_via_radii_from(self, grid_radii: np.ndarray) -> np.ndarray:
grid_radii
The radial distances from the centre of the profile, for each coordinate on the grid.
"""
return np.multiply(
self._intensity,
np.exp(
-0.5
* np.square(
np.divide(grid_radii, self.sigma / np.sqrt(self.axis_ratio))
)
),
)
if use_jax:
return np.multiply(
self._intensity,
np.exp(
-0.5
* np.square(
np.divide(grid_radii.array, self.sigma / np.sqrt(self.axis_ratio))
)
),
)
else:
return np.multiply(
self._intensity,
np.exp(
-0.5
* np.square(
np.divide(grid_radii, self.sigma / np.sqrt(self.axis_ratio))
)
),
)

@aa.over_sample
@aa.grid_dec.to_array
Expand Down
12 changes: 11 additions & 1 deletion autogalaxy/profiles/mass/stellar/gaussian.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import copy
import numpy as np
from autofit.jax_wrapper import use_jax
if use_jax:
import jax
from scipy.special import wofz
from scipy.integrate import quad
from typing import Tuple
Expand Down Expand Up @@ -188,7 +191,14 @@ def image_2d_via_radii_from(self, grid_radii: np.ndarray):
@property
def axis_ratio(self):
axis_ratio = super().axis_ratio
return axis_ratio if axis_ratio < 0.9999 else 0.9999
if use_jax:
return jax.lax.select(
axis_ratio < 0.9999,
axis_ratio,
0.9999
)
else:
return axis_ratio if axis_ratio < 0.9999 else 0.9999

def zeta_from(self, grid: aa.type.Grid2DLike):
q2 = self.axis_ratio**2.0
Expand Down