Skip to content

Use auto diff to get at jacobian of deflection angles#206

Merged
CKrawczyk merged 1 commit intofeature/jax_wrapperfrom
feature/jax_jacobian
Oct 25, 2024
Merged

Use auto diff to get at jacobian of deflection angles#206
CKrawczyk merged 1 commit intofeature/jax_wrapperfrom
feature/jax_jacobian

Conversation

@CKrawczyk
Copy link
Copy Markdown
Collaborator

These changes add the ability to use JAX to find the jacobian of deflection angles. To do this it first needs a version of the deflection angle function that takes in two scalars and returns a 2 element vector. From this jax.jacfwd is used to get the 2x2 jacobian for a single (y,x) value. This function is vectorized to take in a set of (y,x) values (of arbitrary shape) and return a (*shape, 2, 2) array of jacobian values.

The final step is to pass the value from a grid object into this function.

These changes add the ability to use JAX to find the jacobian of deflection angles.  To do this it first needs a version of the deflection angle function that takes in two scalars and returns a 2 element vector.  From this `jax.jacfwd` is used to get the 2x2 jacobian for a single (y,x) value.  This function is vectorized to take in a set of (y,x) values (of arbitrary shape) and return a (*shape, 2, 2) array of jacobian values.

The final step is to pass the value from a `grid` object into this function.
Comment on lines +730 to +731
if not use_jax:
return
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this ever get called if jax is off?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't expect it to, but figured just to be safe.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be worth raising an exception instead?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Something like:

if not use_jax:
    raise AssertionError("Function should not be called unless JAX is in use")

return ...

@CKrawczyk CKrawczyk merged commit 758acbb into feature/jax_wrapper Oct 25, 2024
@CKrawczyk CKrawczyk deleted the feature/jax_jacobian branch October 25, 2024 15:26
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants