Performance optimization

To fully realize the capability of Grad DFT to push the frontiers of accuracy for XC functionals, we must address several performance bottlenecks that prevent us from training models at scale on modern HPC platforms.

Our goal in this milestone is to have Grad DFT efficiently compute a batched loss function with a reasonable batch size (16-64 structures) in parallel on a supercomputer. To do this, we will need:

(1) Single-program, multiple-data (SPMD) parallelism. This is implemented in JAX with sharding. Depending on the specs of the HPC system, we should aim to run DFT calculations with 1-4 structures per node and scale to 10-100 nodes. A sketch of this pattern follows the list.

(2) Better handling of ERIs, or bypassing the need for them entirely. Holding on to ERIs loaded in from PySCF uses a large amount of memory. For periodic systems, we can use FFTs to compute the Hartree potential efficiently, so we will no longer need ERIs. A sketch of this also follows the list.

(3) Using symmetry. Many computations, for molecules and solids alike, can be sped up by exploiting point-group and space-group symmetries. We should implement this where convenient. A sketch of one such saving follows the list.
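
A minimal sketch of (1), assuming placeholder names throughout (the `predict_energy` stand-in, densities, targets, and parameters are illustrative, not Grad DFT's actual API): the structure dimension of the batch is sharded over a device mesh, and the jitted, vmapped loss is then partitioned across devices by XLA.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical batch: one row of grid densities per structure.
n_dev = len(jax.devices())
batch_size, n_grid = 4 * n_dev, 2048
densities = jnp.ones((batch_size, n_grid))   # placeholder densities
targets = jnp.zeros((batch_size,))           # placeholder target energies
params = jnp.full((n_grid,), 0.01)           # placeholder functional parameters

# One mesh axis named "batch"; shard the leading (structure) dimension over it.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
sharding = NamedSharding(mesh, P("batch"))
densities = jax.device_put(densities, sharding)
targets = jax.device_put(targets, sharding)

def predict_energy(params, density):
    # Stand-in for a per-structure Grad DFT energy evaluation.
    return jnp.sum(params * density)

@jax.jit
def batched_loss(params, densities, targets):
    # vmap expresses the per-structure computation; because the inputs are
    # sharded, XLA places each shard of structures on its own device.
    energies = jax.vmap(predict_energy, in_axes=(None, 0))(params, densities)
    return jnp.mean((energies - targets) ** 2)

print(batched_loss(params, densities, targets))
```

On a multi-node allocation the same pattern should extend by calling `jax.distributed.initialize()` first, so that the mesh spans the devices of all processes.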
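
A sketch of (2) for the periodic case, assuming a cubic cell, a uniform real-space grid, and Hartree atomic units; the density below is synthetic. Two FFTs and a multiplication by 4π/|G|² replace the ERI contraction, at O(N log N) cost in the number of grid points.

```python
import jax.numpy as jnp

def hartree_potential_fft(density, cell_length):
    """Solve the periodic Poisson equation, V_H(G) = 4*pi*n(G)/|G|^2, with FFTs."""
    n = density.shape[0]                                     # grid points per axis
    g = 2.0 * jnp.pi * jnp.fft.fftfreq(n, d=cell_length / n)
    gx, gy, gz = jnp.meshgrid(g, g, g, indexing="ij")
    g2 = gx**2 + gy**2 + gz**2

    rho_g = jnp.fft.fftn(density)
    g2 = g2.at[0, 0, 0].set(1.0)          # avoid 0/0 at G = 0
    v_g = 4.0 * jnp.pi * rho_g / g2
    v_g = v_g.at[0, 0, 0].set(0.0)        # uniform-background convention
    return jnp.fft.ifftn(v_g).real

# Synthetic density: a smooth periodic modulation in a 10 Bohr cubic cell.
n, length = 32, 10.0
x = jnp.linspace(0.0, length, n, endpoint=False)
xx, _, _ = jnp.meshgrid(x, x, x, indexing="ij")
density = 1.0 + 0.1 * jnp.cos(2.0 * jnp.pi * xx / length)
v_h = hartree_potential_fft(density, length)
```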
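
A sketch of one concrete saving for (3), assuming spglib is available (used here purely for illustration, not as an existing Grad DFT dependency): space-group symmetry reduces a full k-point mesh to its irreducible wedge, so only the symmetry-distinct k-points need a DFT evaluation, each carrying a multiplicity weight.

```python
import numpy as np
import spglib

# Illustrative cell: simple cubic lattice (5 Bohr) with one atom at the origin.
lattice = 5.0 * np.eye(3)
positions = [[0.0, 0.0, 0.0]]
numbers = [1]
cell = (lattice, positions, numbers)

# Reduce an 8x8x8 Gamma-centred mesh to its irreducible k-points.
mapping, grid = spglib.get_ir_reciprocal_mesh([8, 8, 8], cell, is_shift=[0, 0, 0])
irreducible, counts = np.unique(mapping, return_counts=True)
weights = counts / len(mapping)

print(f"{len(mapping)} k-points reduced to {len(irreducible)} irreducible ones")
# Symmetry-averaged quantities are then sums over irreducible k with weights w_k.
```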
