Add multi-GPU support with accelerate
#76
Conversation
PackedGraph objects do not have a .to() method, but a custom .cuda() method. Pretty messed up. https://github.com/DeepGraphLearning/torchdrug/blob/master/torchdrug/data/graph.py
So one main thing about this PR: I don't think it's necessary to use Accelerate to add GPU support. We only have to keep track of devices more cleverly. I'm not against using Accelerate, since it has some nice add-ons for multi-GPU etc., but I wouldn't depend on it to solve the original problem. Perhaps we can monkey patch a to() into the PackedGraph class.
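The monkey-patch idea above could look roughly like the sketch below. Since torchdrug may not be available, `FakeGraph` is a hypothetical stand-in for its `PackedGraph` (which only exposes `.cuda()`); the real class and patch would differ in detail.

```python
class FakeGraph:
    """Hypothetical stand-in for torchdrug's PackedGraph, which only ships .cuda()."""

    def __init__(self):
        self.device = "cpu"

    def cuda(self):
        self.device = "cuda"
        return self


def _graph_to(self, device):
    """A minimal to()-like method that delegates to the existing .cuda() API."""
    if str(device).startswith("cuda"):
        return self.cuda()
    return self  # already on CPU; nothing to move


# The monkey patch: attach the method to the class at runtime.
FakeGraph.to = _graph_to

graph = FakeGraph().to("cuda")
print(graph.device)  # prints: cuda
```

With a patch like this, pipeline code can call `.to(device)` uniformly on tensors and graphs alike, without special-casing the graph type.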
Hi @cthoyt 😄 My thinking was that this approach avoids us having to re-invent the wheel and quickly solves the current GPU need. The original issue (#65) doesn't describe requirements to aim for. Two questions:
I'm not attached to Accelerate as a solution either, just trying to understand the limits and needs.
Codecov Report

```diff
@@            Coverage Diff             @@
##             main      #76      +/-   ##
==========================================
- Coverage   94.65%   94.00%     -0.66%
==========================================
  Files          34       34
  Lines        1478     1500       +22
==========================================
+ Hits         1399     1410       +11
- Misses         79       90       +11
==========================================
```

Continue to review full report at Codecov.
I'm saying that implementing GPU usability and implementing
Additionally, I've opened a PR on torchdrug to solve the problem upstream, which will be much more elegant than us hacking it in: DeepGraphLearning/torchdrug#70. In the meantime, we could provide a
chemicalx/pipeline.py (Outdated)

```python
prediction = model(*model.unpack(batch))
loss_value = loss(prediction, batch.labels)
```

```python
device_batch = to_device(model.unpack(batch), device)
```
I think the batch generator should know what device to put the batches on, so this doesn't have to be changed in the pipeline.
I.e., the `BatchGenerator.__init__` should take an optional `torch.device` (if not given, assume CPU), and the generation steps should take care of moving the tensors over to the appropriate device.
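The suggestion above could be sketched as follows. To keep the example self-contained without torch installed, `FakeTensor` is a hypothetical stand-in for `torch.Tensor`, and the class names are illustrative rather than chemicalx's actual API.

```python
class FakeTensor:
    """Hypothetical stand-in for torch.Tensor; only tracks which device it lives on."""

    def __init__(self, device="cpu"):
        self.device = device

    def to(self, device):
        return FakeTensor(device)


class BatchGenerator:
    """Illustrative generator: accepts an optional device at construction
    (defaulting to CPU) and moves every batch onto it as it is yielded."""

    def __init__(self, batches, device=None):
        self.batches = batches
        self.device = "cpu" if device is None else device

    def __iter__(self):
        for batch in self.batches:
            # Move tensors here, so the training pipeline never has to.
            yield [t.to(self.device) for t in batch]


gen = BatchGenerator([[FakeTensor(), FakeTensor()]], device="cuda:0")
for batch in gen:
    print([t.device for t in batch])  # prints: ['cuda:0', 'cuda:0']
```

Centralizing the device move in the generator keeps the training loop device-agnostic, which is the design point being argued for in this thread.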
Closes AstraZeneca#76. This PR first requires AstraZeneca#84 to be tested and merged.

## Blocked by

- [ ] AstraZeneca#84
@GavEdwards please note we've already merged a simple solution into the main branch and have now updated your PR with it; please check it out.
```diff
@@ -15,6 +15,9 @@
     "pystow",
     "pytdc",
     "more-itertools",
+    "accelerate",
+    # FIXME what is packaging for?
+    "packaging",
```
what's `packaging` for?
```diff
@@ -70,6 +71,45 @@ def save(self, directory: Union[str, Path]) -> None:
         )


+def to_device(objects, device):
```
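A helper with this signature could plausibly be implemented as a recursive walk, as sketched below. This is an assumption about its behavior, not the PR's actual body: it descends into common containers and calls `.to(device)` on anything that supports it.

```python
def to_device(objects, device):
    """Recursively move anything exposing a .to() method onto ``device``,
    descending into lists, tuples, and dicts; other values pass through."""
    if hasattr(objects, "to"):
        return objects.to(device)
    if isinstance(objects, (list, tuple)):
        return type(objects)(to_device(obj, device) for obj in objects)
    if isinstance(objects, dict):
        return {key: to_device(value, device) for key, value in objects.items()}
    # Non-tensor leaves (ints, strings, None, ...) are returned unchanged.
    return objects
```

Because it dispatches on the presence of `.to`, such a helper would also pick up a monkey-patched `PackedGraph.to()` automatically, tying the two ideas in this thread together.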
Agreed, will close the PR
@GavEdwards I was only referring to the `to_device` function. There's still some benefit to considering adding Accelerate for multi-GPU or TPU usage.
Summary
Enable GPU support (and more) via the Accelerate library.
This is still a work in progress; there are still some bugs to be ironed out around multi-GPU support and some models.
TODO:
Changes