Sharding the computation between GPUs #83

Open
PabloAMC opened this issue Sep 23, 2023 · 5 comments
Labels
enhancement New feature or request

Comments

@PabloAMC
Collaborator

Add sharding following https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
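
For reference, a minimal sketch of the pattern from that notebook; the array shape, axis name, and toy computation are placeholders, not Grad DFT quantities:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Build a 1D mesh over all visible devices (GPUs on a cluster, CPU devices locally).
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Place a large array so that its leading axis is split across devices;
# the shape here is a placeholder, assumed divisible by the device count.
x = jnp.ones((8192, 512))
x = jax.device_put(x, NamedSharding(mesh, P("data", None)))

# jit-compiled computations on sharded inputs are parallelized automatically;
# the output sharding is inferred by the compiler.
y = jax.jit(lambda a: jnp.tanh(a) @ a.T)(x)
print(y.sharding)
```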

@PabloAMC PabloAMC added the enhancement New feature or request label Sep 23, 2023
@PabloAMC
Collaborator Author

PabloAMC commented Sep 28, 2023

Before I start working on this, I would like to have a discussion:
What would be the target of sharding here? There are at least two possible objectives:

  1. Sharding very large arrays that are kept in memory; the chi and repulsion tensors are the obvious candidates here.
  2. Sharding a list of molecules to parallelize the computation across them. The problem here is that, to the best of my knowledge, sharding operates on arrays rather than on lists of separate objects, so I don't think this is a great target.

I'd be interested in your comments @Matematija.

@Matematija
Collaborator

Sorry for not replying earlier.

I would say that option 1 is definitely the way to go.

Backing up a little bit, my intuition is that breaking up the integration grid into parts that get committed to different devices is a good starting point. Many large tensors in DFT inherit their "largeness" from the grid.
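
As an illustration of that direction, here is a rough sketch that shards a grid-indexed tensor along its grid axis so each device holds one slice of the integration grid; the names (`chi`, `weights`) and shapes are assumptions for the example, not the actual Grad DFT data layout:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n_devices = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), axis_names=("grid",))

n_grid, n_orbitals = 4096 * n_devices, 64
chi = jnp.ones((n_grid, n_orbitals))   # grid-indexed tensor (placeholder values)
weights = jnp.ones((n_grid,))          # quadrature weights (placeholder values)

# Split the grid axis across devices; the orbital axis stays replicated.
chi = jax.device_put(chi, NamedSharding(mesh, P("grid", None)))
weights = jax.device_put(weights, NamedSharding(mesh, P("grid")))

@jax.jit
def grid_contraction(chi, weights):
    # A grid-sum like this is computed per shard and reduced across devices
    # by the compiler, so per-device memory scales with the local grid slice.
    return jnp.einsum("g,gi,gj->ij", weights, chi, chi)

out = grid_contraction(chi, weights)
```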

@jackbaker1001 jackbaker1001 added this to the Performance optimization milestone Dec 7, 2023
@jackbaker1001
Collaborator

I think I have to disagree with @Matematija on this one.

While parallelizing single DFT executions is nice, a lot of performance can be extracted by simply sending all data to a single (good) GPU on an HPC cluster. On Perlmutter, for example, I can perform differentiable SCF calculations in a matter of seconds for reasonable solid materials with reasonable basis sets.

Given that the primary use case for Grad DFT will usually be learning from small systems (the domain in which accurate wavefunction calculations are feasible for generating training data), I think it is best to put some effort into parallelizing batched loss function computation with sharding.

This, in principle, should be pretty easy as our data is stored in flax.struct.dataclass instances, which we can put into a standard Python list and shard with jax.device_put_sharded (see https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put_sharded.html). This would be single-program, multiple-data (SPMD) parallelism and would allow us to train realistic models, as we could approach usual machine-learning batch sizes in training.
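
For concreteness, a sketch of what that batch-level SPMD scheme could look like; `Sample`, its fields, and `loss_fn` are placeholder stand-ins for the actual flax.struct.dataclass and loss used in Grad DFT:

```python
import functools

import jax
import jax.numpy as jnp
import flax.struct

@flax.struct.dataclass
class Sample:
    features: jnp.ndarray
    target: jnp.ndarray

def loss_fn(params, batch):
    # Stand-in linear model and squared error, just to have something to differentiate.
    pred = batch.features @ params
    return jnp.mean((pred - batch.target) ** 2)

devices = jax.devices()

# One shard of the batch per device; flax.struct.dataclass instances are pytrees,
# so device_put_sharded stacks their leaves along a new leading device axis.
shards = [
    Sample(features=jnp.ones((8, 16)), target=jnp.zeros((8,)))
    for _ in devices
]
batch = jax.device_put_sharded(shards, devices)

params = jnp.zeros((16,))
replicated_params = jax.device_put_replicated(params, devices)

# pmap computes the gradient on each device's local shard and averages the
# gradients across devices with a collective mean.
@functools.partial(jax.pmap, axis_name="batch")
def grad_step(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    return jax.lax.pmean(grads, axis_name="batch")

grads = grad_step(replicated_params, batch)
```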

@Matematija
Collaborator

If the target use case is learning smaller systems, then I agree that the problem I was describing doesn't exist by definition. Simple batch-level parallelism should do the trick. I was imagining large molecules with tens of millions of grid points at inference time.

@jackbaker1001
Collaborator

I guess that in the (possibly far) future we may wish to add such parallelization. I just think that, for now, batch parallelism is a good target.

I'm experimenting with this now...
