
GPU training #19

Open · zhangn2015 opened this issue Apr 23, 2023 · 1 comment

Comments
zhangn2015 commented Apr 23, 2023

Hi, how can I speed up training on the GPU for models such as VariationalGP?

zhangn2015 changed the title from "dump and load model" to "GPU training" on Aug 7, 2023
DanWaxman commented:

If you have Jax with CUDA installed, things should run on the GPU automatically. You can check the backend Jax is using with

from jax.lib import xla_bridge

# Prints the active platform: "gpu" when a CUDA-enabled jaxlib is found, otherwise "cpu"
print(xla_bridge.get_backend().platform)

You can install Jax with CUDA via pip; here are the instructions for version 0.4.2, which BayesNewton currently uses (see e.g. the Jax v0.4.2 README, as it differs a bit from the current version):

 pip install -U "jax[cuda]==0.4.2" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html jaxlib==0.4.2

You can double check that a jax.numpy array is on the correct device by calling arr.device().
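For example, a minimal sketch (the array name arr is just a placeholder for whatever array you are working with):

import jax.numpy as jnp

arr = jnp.arange(10)   # any jax.numpy array
print(arr.device())    # reports the device the array lives on, e.g. a GPU device when the CUDA backend is active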
