jitcdde and jax #52

Open
boxaio opened this issue Jan 10, 2024 · 3 comments

Comments

@boxaio

boxaio commented Jan 10, 2024

Is JiTCDDE compatible with JAX?

@Wrzlprmft
Contributor

Wrzlprmft commented Jan 10, 2024

As far as I can tell, Jax is for vectorisable problems (SPMD: single program, multiple data). The main strength of JiTC*DE is that it can handle non-vectorisable problems (MPMD: multiple programs, multiple data). Therefore, combining it with Jax makes little sense. If your problem is vectorisable, e.g., a typical PDE, there are usually better-suited tools than JiTC*DE, and they may use Jax or even be Jax.

I acknowledge, however, that there is a lack of tools for DDEs, so you might want the best of both worlds for something like a PDDE (partial delay differential equation). In that case, you face the additional problem that JiTCDDE’s integrator is strongly intertwined with the computation of the derivative (which is the part where Jax might help).

That being said, there may be a solution for some problems where you can benefit from both modules, but for that I would need to know the problem in detail.

@boxaio
Author

boxaio commented Jan 10, 2024

Thanks for your reply.
Actually, I was trying to solve DDEs using JiTCDDE, starting from an enormous number of different initial conditions. For solving ODEs with various initial conditions, this is easily done with jax.vmap() function in JAX.
However, in JiTCDDE, the data are represented using numpy arrays instead of jax.numpy arrays; see the file jitcdde/jitced_template.c,
line 3: # include <numpy/arrayobject.h>
I guess I would have to convert from numpy to jax.numpy, but I did not find a similar header file for jax.numpy.

@Wrzlprmft
Contributor

I see two main ways in which parallelising over initial conditions could give you an advantage over JiTCDDE:

  • Reducing overhead. This mostly matters for small DDEs and should become irrelevant for larger ones. Also note that JiTCDDE has some features that reduce its overhead for such applications, in particular saving the compilate, so you only need to compile once.

  • Using a GPU. This may speed things up for suitable problems and if you parallelise a lot of initial conditions at once.

In both cases, the more initial conditions you run in parallel, the worse the drawbacks of parallelising:

First, suppose that you parallelise by “copying” your differential equations, e.g., instead of having six three-dimensional systems, you have one eighteen-dimensional system consisting of six non-interacting subsystems.
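
To make the copying approach concrete, here is a minimal sketch in plain Python. It assumes a hypothetical interface in which a subsystem's right-hand side is a function f(state, delayed, t) of its current state, its delayed state, and time; this is not JiTCDDE's API, which works with symbolic expressions. The sketch stacks six copies of a hypothetical three-dimensional subsystem into one eighteen-dimensional system whose copies never interact.

```python
from typing import Callable, List, Sequence

# Hypothetical interface (not JiTCDDE's API): f(state, delayed_state, t) -> derivative.
RHS = Callable[[Sequence[float], Sequence[float], float], List[float]]


def copy_system(f: RHS, dim: int, n_copies: int) -> RHS:
    """Stack n_copies non-interacting copies of a dim-dimensional system into one system."""

    def combined(state: Sequence[float], delayed: Sequence[float], t: float) -> List[float]:
        if len(state) != dim * n_copies or len(delayed) != dim * n_copies:
            raise ValueError("state and delayed state must have dim * n_copies entries")
        derivative: List[float] = []
        for k in range(n_copies):
            block = slice(k * dim, (k + 1) * dim)
            # Each copy only sees its own block of the state, so the copies never interact.
            derivative.extend(f(state[block], delayed[block], t))
        return derivative

    return combined


def subsystem(x: Sequence[float], x_tau: Sequence[float], t: float) -> List[float]:
    # Hypothetical three-dimensional subsystem with a delayed, ring-like coupling.
    return [x_tau[1] - x[0], x_tau[2] - x[1], x_tau[0] - x[2]]


# Six three-dimensional systems become one eighteen-dimensional system.
big_system = copy_system(subsystem, dim=3, n_copies=6)
```

Any integrator applied to big_system has to advance all six copies with the same step, which is the root of the drawbacks discussed next.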

One of the strengths of JiTCDDE is that it uses an adaptive integrator, i.e., the step size changes according to how easily the dynamics can be integrated. The problem with the copying approach is that the integrator will choose the smallest step size needed by any of the six subsystems, and thus many of them will be integrated with a smaller step size than necessary. A variation of this approach that avoids this would be to adapt the step size for each subsystem separately. However, then some subsystems would take longer to integrate than others.
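
Why the worst subsystem dictates the step size can be sketched numerically. This assumes a textbook step-size controller of the form h · safety · (tol/err)^(1/(order+1)) and an error estimate measured with a maximum norm over all components; both are common choices but are assumptions here, not necessarily what JiTCDDE does internally. The error estimates are made-up numbers for illustration.

```python
from typing import List


def proposed_step(h: float, error: float, tol: float, order: int, safety: float = 0.9) -> float:
    """Textbook step-size controller: the larger the error estimate, the smaller the next step."""
    return h * safety * (tol / error) ** (1.0 / (order + 1))


# Hypothetical error estimates of six subsystems after one step of size h.
errors: List[float] = [1e-7, 2e-6, 5e-9, 3e-5, 1e-6, 4e-8]
h, tol, order = 0.01, 1e-6, 3

# Integrated separately, each subsystem gets its own next step size.
individual_steps = [proposed_step(h, e, tol, order) for e in errors]

# Copied into one system with a maximum-norm error estimate, the worst subsystem sets the step.
combined_step = proposed_step(h, max(errors), tol, order)

print(f"individual: {[round(s, 5) for s in individual_steps]}")
print(f"combined:   {combined_step:.5f}")
```

Since the controller decreases monotonically with the error, the combined step equals the smallest of the individual steps, so the other five subsystems are stepped more finely than they need.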

Moreover, in many scenarios where you run a lot of initial conditions, you have some criterion for aborting a run, e.g., because the dynamics have clearly converged to a fixed point. With the copying approach, you lose that option.
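
The cost of losing the abort criterion can be estimated with a toy count. For brevity, this sketch uses the ODE x' = −rate·x (a hypothetical stand-in, not a DDE) with explicit Euler and a hypothetical convergence threshold, and it counts subsystem steps as a proxy for work.

```python
def steps_until_converged(x0: float, rate: float, dt: float = 0.01,
                          threshold: float = 1e-3, max_steps: int = 100_000) -> int:
    """Euler-integrate x' = -rate * x and return the number of steps until |x| < threshold."""
    x = x0
    for step in range(max_steps):
        if abs(x) < threshold:  # abort criterion: the run has converged to the fixed point 0
            return step
        x -= dt * rate * x
    return max_steps


rates = [0.5, 1.0, 2.0, 5.0]  # hypothetical runs that converge at different speeds
steps = [steps_until_converged(1.0, r) for r in rates]

# Separate runs: each run stops as soon as it has converged.
work_separate = sum(steps)
# Copying approach: every subsystem is stepped until the slowest one has converged.
work_copied = len(rates) * max(steps)
print(steps, work_separate, work_copied)
```

The gap between the two totals grows with the spread of convergence times across initial conditions.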

Probably the most worthwhile way to use the copying approach is with a fixed step size, since for a large number of subsystems the step size will be rather constant anyway. This way, you save the overhead required for step-size computations, but then you are so far removed from what JiTCDDE does that you might as well write your own tool. However, since you do not have to handle changing step sizes, implementing such a tool becomes a lot easier.
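
A rough idea of how simple such a fixed-step tool can be: the sketch below advances many initial conditions at once for the hypothetical scalar DDE x'(t) = −x(t − τ) with a constant past. It uses explicit Euler purely for brevity (a real tool would use a higher-order method) and assumes τ is a multiple of the step size, so delayed states fall on the time grid and need no interpolation.

```python
from typing import List


def integrate_fixed_step(x0s: List[float], tau: float, dt: float, t_end: float) -> List[float]:
    """Explicit Euler for the hypothetical DDE x'(t) = -x(t - tau), run for many initial
    conditions at once with one fixed step size.

    Each run has a constant past, x(t) = x0 for t <= 0. Assumes tau is a multiple of dt.
    """
    lag = round(tau / dt)  # delay measured in steps
    n_steps = round(t_end / dt)
    # history[k] holds the states of all runs at time k * dt.
    history = [list(x0s)]
    for k in range(n_steps):
        current = history[-1]
        # For k - lag <= 0, the delayed time lies in the constant past, i.e., at x0.
        delayed = history[max(k - lag, 0)]
        history.append([x - dt * x_d for x, x_d in zip(current, delayed)])
    return history[-1]


final_states = integrate_fixed_step([1.0, -2.0, 3.0], tau=1.0, dt=0.01, t_end=10.0)
print(final_states)
```

Keeping the whole history is wasteful; a ring buffer of lag + 1 states would suffice, and with a fixed step the memory layout is known in advance, which is what makes such a tool easy to vectorise.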

> For solving ODEs with various initial conditions, this is easily done with jax.vmap() function in JAX.

At a quick glance, unless you create your own solver using Jax routines, this will either give you no speed-up over basic multi-core parallelisation tools or amount to the copying approach outlined above, with all its drawbacks.
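
For comparison, basic multi-core parallelisation can look like the following sketch, which uses Python's standard concurrent.futures to run one independent integration per task. Each run keeps its own step size and could apply its own abort criterion. run_single is a hypothetical stand-in (explicit Euler for x'(t) = −x(t − τ) with a constant past); in practice, it would wrap a complete JiTCDDE run, ideally reusing a saved compilate so that compilation happens only once.

```python
from concurrent.futures import ProcessPoolExecutor
from typing import List, Optional, Sequence


def run_single(x0: float, tau: float = 1.0, dt: float = 0.01, t_end: float = 10.0) -> float:
    """Hypothetical stand-in for one complete, independent DDE integration.

    Explicit Euler for x'(t) = -x(t - tau) with constant past x(t) = x0 for t <= 0.
    """
    lag = round(tau / dt)
    history = [x0]
    for k in range(round(t_end / dt)):
        history.append(history[-1] - dt * history[max(k - lag, 0)])
    return history[-1]


def run_all(x0s: Sequence[float], max_workers: Optional[int] = None) -> List[float]:
    """Distribute independent runs over worker processes, one initial condition per task."""
    if max_workers == 1:
        # Serial fallback, e.g., for debugging without worker processes.
        return [run_single(x0) for x0 in x0s]
    with ProcessPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_single, x0s))


if __name__ == "__main__":  # guard required for process pools on platforms that spawn workers
    print(run_all([0.1 * i for i in range(8)]))
```

Because the runs never interact, this scales with the number of cores without forcing a common step size on all initial conditions.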
